Compare commits
7 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
52a0fdfd89 | ||
![]() |
1392168126 | ||
![]() |
eff22cb33e | ||
![]() |
66c9477e4b | ||
9d522fa137 | |||
f7138dcb39 | |||
8b69ff195f |
.cursorrules.gitignoreDockerfileREADME.md
data
docker-compose.ymldocker
docs
main.pypyproject.tomlrequirements.txtsrc
audio_chunk.pyclient.pyconfig.py
test_main.pycore
functor
logic_trager.pymodels.pymodels
pipeline
runner
server.pyservice.pyutils
websockets
tests
XT_ZZY.wavXT_ZZY_denoise.wav
uv.lockfunctor
modelsuse.pypipeline
runner
spkverify_use.pytest_config.pyvad_example.wavwebsocket
19
.cursorrules
Normal file
19
.cursorrules
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
|
||||||
|
You are an AI assistant specialized in Python development. Your approach emphasizes:
|
||||||
|
|
||||||
|
1. Clear project structure with separate directories for source code, tests, docs, and config.
|
||||||
|
2. Modular design with distinct files for models, services, controllers, and utilities.
|
||||||
|
3. Configuration management using environment variables.
|
||||||
|
4. Robust error handling and logging, including context capture.
|
||||||
|
5. Comprehensive testing with pytest.
|
||||||
|
6. Detailed documentation using docstrings and README files.
|
||||||
|
7. Dependency management via https://github.com/astral-sh/rye and virtual environments.
|
||||||
|
8. Code style consistency using Ruff.
|
||||||
|
9. CI/CD implementation with GitHub Actions or GitLab CI.
|
||||||
|
10. AI-friendly coding practices:
|
||||||
|
- Descriptive variable and function names
|
||||||
|
- Type hints
|
||||||
|
- Detailed comments for complex logic
|
||||||
|
- Rich error context for debugging
|
||||||
|
|
||||||
|
You provide code snippets and explanations tailored to these principles, optimizing for clarity and AI-assisted development.
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -29,6 +29,7 @@ env/
|
|||||||
.coverage
|
.coverage
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
|
savePath/
|
||||||
|
|
||||||
# 编辑器相关
|
# 编辑器相关
|
||||||
.idea/
|
.idea/
|
||||||
|
27
Dockerfile
27
Dockerfile
@ -1,27 +0,0 @@
|
|||||||
FROM python:3.9-slim
|
|
||||||
|
|
||||||
# 安装系统依赖
|
|
||||||
RUN apt-get update && apt-get install -y \
|
|
||||||
build-essential \
|
|
||||||
libsndfile1 \
|
|
||||||
ffmpeg \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# 安装Python依赖
|
|
||||||
COPY requirements.txt .
|
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
|
||||||
|
|
||||||
# 复制应用代码
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
# 设置环境变量
|
|
||||||
ENV PYTHONPATH=/app
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
|
||||||
|
|
||||||
# 暴露WebSocket端口
|
|
||||||
EXPOSE 10095
|
|
||||||
|
|
||||||
# 启动服务
|
|
||||||
CMD ["python", "src/server.py"]
|
|
191
README.md
191
README.md
@ -1,113 +1,142 @@
|
|||||||
# FunASR WebSocket服务
|
# FunASR FastAPI WebSocket Service
|
||||||
|
|
||||||
## 简介
|
一个基于 FunASR 和 FastAPI 构建的高性能、实时的语音识别 WebSocket 服务。该项目核心特色是支持"一发多收"的广播模式,适用于会议实时字幕、在线教育、直播转写等需要将单一音源的识别结果分发给多个客户端的场景。
|
||||||
本项目基于FunASR实现了一个WebSocket语音识别服务,支持实时语音流的在线和离线识别。利用ModelScope开源语音模型,该服务可以进行高精度的中文语音识别,并支持语音活动检测(VAD)和自动添加标点符号。
|
|
||||||
|
## ✨ 功能特性
|
||||||
|
|
||||||
|
- **实时语音处理**: 集成 FunASR 的语音活动检测(VAD)、语音识别(ASR)和声纹识别(SPK)模型。
|
||||||
|
- **WebSocket 流式 API**: 提供低延迟、双向的实时通信接口。
|
||||||
|
- **"一发多收"架构**:
|
||||||
|
- **发送者 (Sender)**: 单一客户端作为音频来源,向服务器持续发送音频流。
|
||||||
|
- **接收者 (Receiver)**: 多个客户端可以订阅同一个会话,实时接收广播的识别结果。
|
||||||
|
- **异步核心**: 基于 FastAPI 和 `asyncio` 构建,可处理大量并发连接。
|
||||||
|
- **模块化设计**: 清晰地分离了服务层 (`server.py`)、会话管理层 (`ASRRunner`) 和核心处理流水线 (`ASRPipeline`)。
|
||||||
|
|
||||||
|
## 📂 项目结构
|
||||||
|
|
||||||
## 项目结构
|
|
||||||
```
|
```
|
||||||
.
|
.
|
||||||
├── src/ # 源代码目录
|
├── main.py # 应用程序主入口,使用 uvicorn 启动服务
|
||||||
│ ├── __init__.py # 包初始化文件
|
├── docker
|
||||||
│ ├── server.py # WebSocket服务器实现
|
│ ├── Dockerfile # Docker 镜像构建文件
|
||||||
│ ├── config.py # 配置处理模块
|
│ ├── docker-compose.yml # Docker 容器编排文件
|
||||||
│ ├── models.py # 模型加载模块
|
├── docs
|
||||||
│ ├── service.py # ASR服务实现
|
│ ├── WEBSOCKET_API.md # WebSocket API 详细使用文档和示例
|
||||||
│ └── client.py # 测试客户端
|
│ ├── SystemArchitecture.md # 系统架构文档
|
||||||
├── tests/ # 测试目录
|
├── src
|
||||||
│ ├── __init__.py # 测试包初始化文件
|
│ ├── server.py # FastAPI 应用核心,管理生命周期和全局资源
|
||||||
│ └── test_config.py # 配置模块测试
|
│ ├── runner
|
||||||
├── requirements.txt # Python依赖
|
│ │ └── ASRRunner.py # 核心会话管理器,负责创建和协调识别会话 (SAR)
|
||||||
├── Dockerfile # Docker配置
|
│ ├── pipeline
|
||||||
├── docker-compose.yml # Docker Compose配置
|
│ │ └── ASRpipeline.py # 同步的、基于线程的语音处理流水线
|
||||||
├── .gitignore # Git忽略文件
|
│ ├── functor # VAD, ASR, SPK 等原子操作的实现
|
||||||
└── README.md # 项目说明
|
│ ├── websockets
|
||||||
|
│ │ ├── adapter.py # WebSocket 适配器,处理数据格式转换
|
||||||
|
│ │ ├── endpoint
|
||||||
|
│ │ │ └── asr_endpoint.py # WebSocket 的业务逻辑端点
|
||||||
|
│ │ └── router.py # WebSocket 路由
|
||||||
|
│ ├── core
|
||||||
|
│ │ └── model_loader.py # 模型加载器
|
||||||
|
│ ├── utils
|
||||||
|
│ │ └── logger.py # 日志记录器
|
||||||
|
│ │ └── data_format.py # 数据格式转换
|
||||||
|
│ │ └── mock_websocket.py # 模拟 WebSocket 客户端
|
||||||
|
├── tests
|
||||||
|
│ ├── runner
|
||||||
|
│ │ └── asr_runner_test.py # ASRRunner 的单元测试 (异步)
|
||||||
|
│ ├── websocket
|
||||||
|
│ │ └── websocket_asr.py # WebSocket 服务的端到端测试
|
||||||
```
|
```
|
||||||
|
|
||||||
## 功能特性
|
## 🚀 快速开始
|
||||||
|
|
||||||
- **多模式识别**:支持离线(offline)、在线(online)和两阶段(2pass)识别模式
|
### 1. 环境与依赖
|
||||||
- **语音活动检测**:自动检测语音开始和结束
|
|
||||||
- **标点符号**:支持自动添加标点符号
|
|
||||||
- **WebSocket接口**:基于二进制WebSocket提供实时语音识别
|
|
||||||
- **Docker支持**:提供容器化部署支持
|
|
||||||
|
|
||||||
## 安装与使用
|
|
||||||
|
|
||||||
### 环境要求
|
|
||||||
- Python 3.8+
|
- Python 3.8+
|
||||||
- CUDA支持 (若需GPU加速)
|
- 项目依赖项记录在 `requirements.txt` 文件中。
|
||||||
- 内存 >= 8GB
|
|
||||||
|
|
||||||
### 安装依赖
|
### 2. 安装
|
||||||
|
|
||||||
|
建议在虚拟环境中安装依赖。在项目根目录下,运行:
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### 运行服务器
|
### 3. 运行服务
|
||||||
|
|
||||||
|
执行主入口文件来启动 FastAPI 服务:
|
||||||
|
```bash
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
服务启动后,将监听 `http://0.0.0.0:8000`。
|
||||||
|
|
||||||
|
## 💡 如何使用
|
||||||
|
|
||||||
|
服务通过 WebSocket 提供,客户端通过 `session_id` 来创建或加入一个识别会话,并通过 `mode` 参数声明自己的角色(`sender` 或 `receiver`)。
|
||||||
|
|
||||||
|
**详细的 API 说明、URL 格式以及 Python 和 JavaScript 的客户端连接示例,请参阅:**
|
||||||
|
|
||||||
|
➡️ **[WEBSOCKET_API.md](./docs/WEBSOCKET_API.md)**
|
||||||
|
|
||||||
|
## 🔬 测试
|
||||||
|
|
||||||
|
项目提供了两种测试方式来验证其功能。
|
||||||
|
|
||||||
|
### 1. 端到端 WebSocket 测试
|
||||||
|
|
||||||
|
此测试会模拟一个 `sender` 和一个 `receiver`,完整地测试一次识别会话。
|
||||||
|
|
||||||
|
**前提**: 确保 FastAPI 服务正在运行。
|
||||||
|
```bash
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
在项目根目录下执行:
|
||||||
|
```bash
|
||||||
|
python tests/websocket/websocket_asr.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. ASRRunner 单元测试
|
||||||
|
|
||||||
|
此测试针对核心的 `ASRRunner` 组件进行,验证其异步逻辑。
|
||||||
|
|
||||||
|
执行测试:
|
||||||
|
```bash
|
||||||
|
python test_main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🐳 使用 Docker 部署
|
||||||
|
|
||||||
|
### 1. 构建 Docker 镜像
|
||||||
|
|
||||||
|
在项目根目录下,运行:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/server.py
|
docker build -t asr-server:x.x.x -f docker/Dockerfile .
|
||||||
```
|
```
|
||||||
|
|
||||||
常用启动参数:
|
### 2. 运行 Docker 容器
|
||||||
- `--host`: 服务器监听地址,默认为 0.0.0.0
|
|
||||||
- `--port`: 服务器端口,默认为 10095
|
|
||||||
- `--device`: 设备类型(cuda或cpu),默认为 cuda
|
|
||||||
- `--ngpu`: GPU数量,0表示使用CPU,默认为 1
|
|
||||||
|
|
||||||
### 测试客户端
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/client.py --audio_file path/to/audio.wav
|
docker run -d -p 11096:11096 --name asr-server -v ~/.cache/modelscope:/root/.cache/modelscope asr-server:x.x.x
|
||||||
```
|
```
|
||||||
|
|
||||||
常用客户端参数:
|
### 环境变量说明
|
||||||
- `--audio_file`: 要识别的音频文件路径
|
|
||||||
- `--mode`: 识别模式,可选 2pass/online/offline,默认为 2pass
|
|
||||||
- `--host`: 服务器地址,默认为 localhost
|
|
||||||
- `--port`: 服务器端口,默认为 10095
|
|
||||||
|
|
||||||
## Docker部署
|
- `SPEAKERS_URL`: 说话人数据库 API 的 URL。
|
||||||
|
|
||||||
### 构建镜像
|
示例/默认 SPEAKERS_URL="http://172.23.30.120:11200/api/v1/speakers/"
|
||||||
|
|
||||||
```bash
|
此url为后端api提供的查询所有数据库中说话人信息的接口
|
||||||
docker build -t funasr-websocket .
|
|
||||||
```
|
|
||||||
|
|
||||||
### 使用Docker Compose启动
|
- `LOG_LEVEL`: 日志等级。
|
||||||
|
|
||||||
```bash
|
示例/默认 LOG_LEVEL="INFO"
|
||||||
docker-compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
## API说明
|
`LOG_LEVEL` 为总项目日志等级,
|
||||||
|
|
||||||
### WebSocket消息格式
|
- `LOG_LEVEL_ASR_SERVER`: 日志等级。
|
||||||
|
|
||||||
1. **客户端配置消息**:
|
示例/默认 LOG_LEVEL_ASR_SERVER="INFO"
|
||||||
```json
|
|
||||||
{
|
|
||||||
"mode": "2pass", // 可选: "2pass", "online", "offline"
|
|
||||||
"chunk_size": "5,10", // 块大小,格式为"encoder_size,decoder_size"
|
|
||||||
"wav_name": "audio1", // 音频标识名称
|
|
||||||
"is_speaking": true // 是否正在说话
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **客户端音频数据**:
|
`LOG_LEVEL_ASR_SERVER` 为 ASR 服务日志等级,优先级高于 `LOG_LEVEL`
|
||||||
二进制音频数据流,16kHz采样率,16位PCM格式
|
|
||||||
|
|
||||||
3. **服务器识别结果**:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"mode": "2pass-online", // 识别模式
|
|
||||||
"text": "识别的文本内容", // 识别结果
|
|
||||||
"wav_name": "audio1", // 音频标识
|
|
||||||
"is_final": false // 是否是最终结果
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 许可证
|
|
||||||
[MIT](LICENSE)
|
|
||||||
|
13
data/denoise.py
Normal file
13
data/denoise.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
|
ans = pipeline(
|
||||||
|
Tasks.acoustic_noise_suppression,
|
||||||
|
model='iic/speech_frcrn_ans_cirm_16k')
|
||||||
|
|
||||||
|
wav_file = 'speaker_wav/HaiaoDuan.wav'
|
||||||
|
output_path = 'denoise_output/HaiaoDuan_denoise_output.wav'
|
||||||
|
result = ans(
|
||||||
|
wav_file,
|
||||||
|
output_path=output_path)
|
BIN
data/denoise_output/HaiaoDuan_denoise_output.wav
Normal file
BIN
data/denoise_output/HaiaoDuan_denoise_output.wav
Normal file
Binary file not shown.
BIN
data/denoise_output/ZiyangZhang_denoise_output.wav
Normal file
BIN
data/denoise_output/ZiyangZhang_denoise_output.wav
Normal file
Binary file not shown.
35
data/record.py
Normal file
35
data/record.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
本地录音,保存为wav格式,存储在data/speaker_wav目录下
|
||||||
|
"""
|
||||||
|
import pyaudio
|
||||||
|
import wave
|
||||||
|
|
||||||
|
def record_audio(filename, duration=5, format=pyaudio.paInt16, channels=1, rate=16000):
|
||||||
|
"""
|
||||||
|
本地录音,保存为wav格式,存储在data/speaker_wav目录下
|
||||||
|
"""
|
||||||
|
p = pyaudio.PyAudio()
|
||||||
|
stream = p.open(format=format, channels=channels, rate=rate, input=True, frames_per_buffer=1024)
|
||||||
|
|
||||||
|
print("按下回车键开始录音...")
|
||||||
|
input()
|
||||||
|
frames = []
|
||||||
|
for i in range(0, int(rate / 1024 * duration)):
|
||||||
|
data = stream.read(1024)
|
||||||
|
frames.append(data)
|
||||||
|
print("录音结束")
|
||||||
|
stream.stop_stream()
|
||||||
|
stream.close()
|
||||||
|
p.terminate()
|
||||||
|
wav_file = wave.open(filename, 'wb')
|
||||||
|
wav_file.setnchannels(channels)
|
||||||
|
wav_file.setsampwidth(p.get_sample_size(format))
|
||||||
|
wav_file.setframerate(rate)
|
||||||
|
wav_file.writeframes(b''.join(frames))
|
||||||
|
wav_file.close()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
record_audio(
|
||||||
|
"data/speaker_wav/test.wav",
|
||||||
|
duration=5
|
||||||
|
)
|
BIN
data/speaker_wav/HaiaoDuan.wav
Normal file
BIN
data/speaker_wav/HaiaoDuan.wav
Normal file
Binary file not shown.
BIN
data/speaker_wav/HaiaoDuan_origin.wav
Normal file
BIN
data/speaker_wav/HaiaoDuan_origin.wav
Normal file
Binary file not shown.
BIN
data/speaker_wav/ZiyangZhang.wav
Normal file
BIN
data/speaker_wav/ZiyangZhang.wav
Normal file
Binary file not shown.
BIN
data/speaker_wav/ZiyangZhang_origin.wav
Normal file
BIN
data/speaker_wav/ZiyangZhang_origin.wav
Normal file
Binary file not shown.
400
data/speakers.json
Normal file
400
data/speakers.json
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"speaker_id": "137facd6-a1c9-47b9-b87d-16e6f62e07bd",
|
||||||
|
"speaker_name": "ZiyangZhang",
|
||||||
|
"wav_path": "/home/lyg/Code/funasr/data/speaker_wav/ZiyangZhang.wav",
|
||||||
|
"speaker_embs": [
|
||||||
|
-0.4249887466430664,
|
||||||
|
-0.12976674735546112,
|
||||||
|
1.6118208169937134,
|
||||||
|
1.3348901271820068,
|
||||||
|
0.1423041820526123,
|
||||||
|
0.16940945386886597,
|
||||||
|
-0.042910803109407425,
|
||||||
|
0.9634712934494019,
|
||||||
|
0.9677271246910095,
|
||||||
|
1.1112406253814697,
|
||||||
|
-2.0086846351623535,
|
||||||
|
1.729629635810852,
|
||||||
|
-0.3664000928401947,
|
||||||
|
2.4323978424072266,
|
||||||
|
-1.587996244430542,
|
||||||
|
-1.0803641080856323,
|
||||||
|
0.08011860400438309,
|
||||||
|
1.6515964269638062,
|
||||||
|
-1.1337167024612427,
|
||||||
|
-0.5088973045349121,
|
||||||
|
-1.0002555847167969,
|
||||||
|
0.11426643282175064,
|
||||||
|
-0.8616334199905396,
|
||||||
|
-0.006051262840628624,
|
||||||
|
0.44800689816474915,
|
||||||
|
0.6659525632858276,
|
||||||
|
-0.9864538908004761,
|
||||||
|
2.1259539127349854,
|
||||||
|
-0.49345871806144714,
|
||||||
|
-0.14384664595127106,
|
||||||
|
0.0742349922657013,
|
||||||
|
0.25577273964881897,
|
||||||
|
1.0516602993011475,
|
||||||
|
1.7297064065933228,
|
||||||
|
-0.44126248359680176,
|
||||||
|
1.3971654176712036,
|
||||||
|
0.04305446520447731,
|
||||||
|
-2.261837959289551,
|
||||||
|
-0.355578750371933,
|
||||||
|
-0.8388981819152832,
|
||||||
|
0.8178591728210449,
|
||||||
|
0.016942109912633896,
|
||||||
|
0.8212596774101257,
|
||||||
|
1.108891248703003,
|
||||||
|
-0.5182072520256042,
|
||||||
|
-0.07741295546293259,
|
||||||
|
0.9407528042793274,
|
||||||
|
0.026407398283481598,
|
||||||
|
-0.6210324168205261,
|
||||||
|
-2.0659642219543457,
|
||||||
|
0.13895569741725922,
|
||||||
|
-1.3570973873138428,
|
||||||
|
2.236407995223999,
|
||||||
|
-0.29706746339797974,
|
||||||
|
1.9819035530090332,
|
||||||
|
1.3580390214920044,
|
||||||
|
-0.5505754351615906,
|
||||||
|
0.7189999222755432,
|
||||||
|
-0.3190038502216339,
|
||||||
|
1.1075336933135986,
|
||||||
|
-1.4158482551574707,
|
||||||
|
0.20138776302337646,
|
||||||
|
0.8354343175888062,
|
||||||
|
0.1671304553747177,
|
||||||
|
-0.56927490234375,
|
||||||
|
1.057538390159607,
|
||||||
|
-0.2868591248989105,
|
||||||
|
0.005044424440711737,
|
||||||
|
0.49878695607185364,
|
||||||
|
-0.7493277192115784,
|
||||||
|
2.4639663696289062,
|
||||||
|
0.5516767501831055,
|
||||||
|
-0.2763596177101135,
|
||||||
|
-0.8769170641899109,
|
||||||
|
-1.296872615814209,
|
||||||
|
-0.5233777165412903,
|
||||||
|
-0.10551001876592636,
|
||||||
|
-0.5955559611320496,
|
||||||
|
-0.6046199202537537,
|
||||||
|
0.22645621001720428,
|
||||||
|
1.12480890750885,
|
||||||
|
-0.3678736388683319,
|
||||||
|
-1.1580262184143066,
|
||||||
|
-0.3625229299068451,
|
||||||
|
0.8251489996910095,
|
||||||
|
0.3464623987674713,
|
||||||
|
2.261840581893921,
|
||||||
|
-0.11341957747936249,
|
||||||
|
-0.6645990610122681,
|
||||||
|
0.8480257987976074,
|
||||||
|
-0.47770705819129944,
|
||||||
|
0.8085628747940063,
|
||||||
|
-0.26823946833610535,
|
||||||
|
-0.25040531158447266,
|
||||||
|
1.0610276460647583,
|
||||||
|
-0.14239133894443512,
|
||||||
|
-1.309299349784851,
|
||||||
|
-1.0987954139709473,
|
||||||
|
-0.1301683634519577,
|
||||||
|
-0.05199439451098442,
|
||||||
|
-0.07838833332061768,
|
||||||
|
-0.21310138702392578,
|
||||||
|
0.29347339272499084,
|
||||||
|
1.0793802738189697,
|
||||||
|
-1.813226342201233,
|
||||||
|
-1.1362330913543701,
|
||||||
|
-0.13013578951358795,
|
||||||
|
0.6647212505340576,
|
||||||
|
-0.34312230348587036,
|
||||||
|
0.5921282172203064,
|
||||||
|
0.26284533739089966,
|
||||||
|
0.9369505047798157,
|
||||||
|
0.1739131063222885,
|
||||||
|
0.7924790978431702,
|
||||||
|
0.3412249982357025,
|
||||||
|
0.16646981239318848,
|
||||||
|
-0.32468467950820923,
|
||||||
|
-0.5835385918617249,
|
||||||
|
0.05923287197947502,
|
||||||
|
1.191710352897644,
|
||||||
|
-0.3653518557548523,
|
||||||
|
-0.8665252923965454,
|
||||||
|
0.7419591546058655,
|
||||||
|
-1.7234965562820435,
|
||||||
|
0.3421083092689514,
|
||||||
|
-0.24517370760440826,
|
||||||
|
-0.8724228143692017,
|
||||||
|
-0.11004912108182907,
|
||||||
|
-0.10676378011703491,
|
||||||
|
-1.0688399076461792,
|
||||||
|
0.4397974908351898,
|
||||||
|
-0.9902229309082031,
|
||||||
|
-0.2676651179790497,
|
||||||
|
1.4346729516983032,
|
||||||
|
0.34571582078933716,
|
||||||
|
0.9091840386390686,
|
||||||
|
0.41458258032798767,
|
||||||
|
-0.7863419055938721,
|
||||||
|
0.6952191591262817,
|
||||||
|
0.8847752809524536,
|
||||||
|
0.15871241688728333,
|
||||||
|
-0.10740098357200623,
|
||||||
|
-0.5305340886116028,
|
||||||
|
1.0536329746246338,
|
||||||
|
-1.337695837020874,
|
||||||
|
0.23358777165412903,
|
||||||
|
-0.19285082817077637,
|
||||||
|
-0.5339606404304504,
|
||||||
|
-0.6768214106559753,
|
||||||
|
1.6815600395202637,
|
||||||
|
-0.36710524559020996,
|
||||||
|
-0.22888287901878357,
|
||||||
|
-0.2714850902557373,
|
||||||
|
-0.0895417258143425,
|
||||||
|
0.3480932116508484,
|
||||||
|
-0.19148986041545868,
|
||||||
|
0.44108960032463074,
|
||||||
|
0.03198949620127678,
|
||||||
|
-0.3665091097354889,
|
||||||
|
-0.6040502786636353,
|
||||||
|
0.37234461307525635,
|
||||||
|
-0.07462035119533539,
|
||||||
|
-0.18109525740146637,
|
||||||
|
-0.19882601499557495,
|
||||||
|
0.33298638463020325,
|
||||||
|
0.039957765489816666,
|
||||||
|
0.6185765266418457,
|
||||||
|
1.5921381711959839,
|
||||||
|
0.04164457693696022,
|
||||||
|
-0.7556226849555969,
|
||||||
|
-1.0537445545196533,
|
||||||
|
0.36932048201560974,
|
||||||
|
-0.2881639897823334,
|
||||||
|
-1.3762420415878296,
|
||||||
|
-0.6029151678085327,
|
||||||
|
-1.3592504262924194,
|
||||||
|
0.6726564168930054,
|
||||||
|
0.06349147856235504,
|
||||||
|
-0.4627697765827179,
|
||||||
|
1.1113581657409668,
|
||||||
|
-1.1767970323562622,
|
||||||
|
0.3900119662284851,
|
||||||
|
-0.3050364851951599,
|
||||||
|
-0.2807784676551819,
|
||||||
|
-0.7237444519996643,
|
||||||
|
-0.039161279797554016,
|
||||||
|
0.5845404267311096,
|
||||||
|
-0.4385261833667755,
|
||||||
|
-0.3988557755947113,
|
||||||
|
-1.235430359840393,
|
||||||
|
-0.648483395576477,
|
||||||
|
1.084520936012268
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"speaker_id": "81e1b806-c76a-468f-98b7-4a63f2996480",
|
||||||
|
"speaker_name": "HaiaoDuan",
|
||||||
|
"wav_path": "/home/lyg/Code/funasr/data/speaker_wav/HaiaoDuan.wav",
|
||||||
|
"speaker_embs": [
|
||||||
|
-1.3490606546401978,
|
||||||
|
-0.9654964208602905,
|
||||||
|
0.6671794652938843,
|
||||||
|
2.3401081562042236,
|
||||||
|
-1.374346137046814,
|
||||||
|
0.24404077231884003,
|
||||||
|
0.08137784898281097,
|
||||||
|
0.10915698111057281,
|
||||||
|
0.8208633065223694,
|
||||||
|
-1.0312862396240234,
|
||||||
|
1.721955418586731,
|
||||||
|
-0.16976028680801392,
|
||||||
|
-1.0259445905685425,
|
||||||
|
-0.9134035706520081,
|
||||||
|
-1.3709611892700195,
|
||||||
|
-0.6821202635765076,
|
||||||
|
1.0825326442718506,
|
||||||
|
1.4931895732879639,
|
||||||
|
-0.06801076978445053,
|
||||||
|
-0.5044959187507629,
|
||||||
|
-1.3154232501983643,
|
||||||
|
-1.1049765348434448,
|
||||||
|
0.6122218370437622,
|
||||||
|
1.1061663627624512,
|
||||||
|
-0.2288999855518341,
|
||||||
|
-0.03568289428949356,
|
||||||
|
-0.9260172247886658,
|
||||||
|
1.1030527353286743,
|
||||||
|
-0.7439772486686707,
|
||||||
|
1.4323620796203613,
|
||||||
|
0.2221372127532959,
|
||||||
|
-0.8355774283409119,
|
||||||
|
0.6758987307548523,
|
||||||
|
0.8520456552505493,
|
||||||
|
-0.0186605341732502,
|
||||||
|
-0.981821596622467,
|
||||||
|
0.11743613332509995,
|
||||||
|
-0.3539535701274872,
|
||||||
|
-0.33924832940101624,
|
||||||
|
-0.510174036026001,
|
||||||
|
0.6893219351768494,
|
||||||
|
-0.10966216027736664,
|
||||||
|
-1.5873743295669556,
|
||||||
|
1.7041956186294556,
|
||||||
|
-0.9844599366188049,
|
||||||
|
-1.368901252746582,
|
||||||
|
0.44316115975379944,
|
||||||
|
-2.406590700149536,
|
||||||
|
0.9880101680755615,
|
||||||
|
0.8344699740409851,
|
||||||
|
0.22896111011505127,
|
||||||
|
-1.4464795589447021,
|
||||||
|
2.222980260848999,
|
||||||
|
-0.22508130967617035,
|
||||||
|
0.8659772276878357,
|
||||||
|
0.7801474928855896,
|
||||||
|
1.824644923210144,
|
||||||
|
-0.2455991804599762,
|
||||||
|
-0.06682202965021133,
|
||||||
|
0.07106778025627136,
|
||||||
|
-1.8072712421417236,
|
||||||
|
0.7733234763145447,
|
||||||
|
0.20490191876888275,
|
||||||
|
-1.119908094406128,
|
||||||
|
-1.2623472213745117,
|
||||||
|
0.34426289796829224,
|
||||||
|
0.7909225821495056,
|
||||||
|
0.47128093242645264,
|
||||||
|
-0.9976771473884583,
|
||||||
|
-0.6703121662139893,
|
||||||
|
0.7459381818771362,
|
||||||
|
1.0664807558059692,
|
||||||
|
0.659284770488739,
|
||||||
|
-0.49438077211380005,
|
||||||
|
0.1974140703678131,
|
||||||
|
-0.07557231187820435,
|
||||||
|
-1.324866533279419,
|
||||||
|
-1.2217090129852295,
|
||||||
|
-1.0160834789276123,
|
||||||
|
0.7517350912094116,
|
||||||
|
0.06301767379045486,
|
||||||
|
0.8621189594268799,
|
||||||
|
-1.033493161201477,
|
||||||
|
-0.18051855266094208,
|
||||||
|
-0.2633781135082245,
|
||||||
|
0.5859690308570862,
|
||||||
|
1.5803791284561157,
|
||||||
|
-0.7071301341056824,
|
||||||
|
-0.016185184940695763,
|
||||||
|
-0.5259001851081848,
|
||||||
|
-0.6252623796463013,
|
||||||
|
1.4383807182312012,
|
||||||
|
0.6068354845046997,
|
||||||
|
0.39534664154052734,
|
||||||
|
0.22612401843070984,
|
||||||
|
-1.541978120803833,
|
||||||
|
-2.575181484222412,
|
||||||
|
-0.9924071431159973,
|
||||||
|
1.9649298191070557,
|
||||||
|
-1.1940282583236694,
|
||||||
|
-0.6481325030326843,
|
||||||
|
-1.5226261615753174,
|
||||||
|
1.6535273790359497,
|
||||||
|
0.7740333676338196,
|
||||||
|
-1.8780876398086548,
|
||||||
|
0.627184271812439,
|
||||||
|
1.0915889739990234,
|
||||||
|
1.694388508796692,
|
||||||
|
-0.47886598110198975,
|
||||||
|
-0.04895557090640068,
|
||||||
|
0.3620351552963257,
|
||||||
|
0.640113115310669,
|
||||||
|
-0.4149058163166046,
|
||||||
|
-0.18083086609840393,
|
||||||
|
-0.30447620153427124,
|
||||||
|
0.022528085857629776,
|
||||||
|
-0.6550383567810059,
|
||||||
|
-0.3812088668346405,
|
||||||
|
-0.478842169046402,
|
||||||
|
0.6615785360336304,
|
||||||
|
0.49959492683410645,
|
||||||
|
-0.249789759516716,
|
||||||
|
1.7448066473007202,
|
||||||
|
-0.9037050008773804,
|
||||||
|
-0.7441433668136597,
|
||||||
|
0.5949154496192932,
|
||||||
|
-1.1230697631835938,
|
||||||
|
-0.2552490830421448,
|
||||||
|
0.4216223657131195,
|
||||||
|
-0.5870983004570007,
|
||||||
|
0.7283152937889099,
|
||||||
|
-0.13834434747695923,
|
||||||
|
-1.3267407417297363,
|
||||||
|
1.1050132513046265,
|
||||||
|
1.731435775756836,
|
||||||
|
0.3724023103713989,
|
||||||
|
0.830539882183075,
|
||||||
|
-1.032881736755371,
|
||||||
|
0.8204181790351868,
|
||||||
|
0.05735205113887787,
|
||||||
|
0.5442802906036377,
|
||||||
|
-0.7974395751953125,
|
||||||
|
0.18374553322792053,
|
||||||
|
-0.17642715573310852,
|
||||||
|
-0.051413919776678085,
|
||||||
|
-0.2413552850484848,
|
||||||
|
-0.43316808342933655,
|
||||||
|
-0.2594863772392273,
|
||||||
|
1.5363879203796387,
|
||||||
|
0.5056991577148438,
|
||||||
|
-1.3894445896148682,
|
||||||
|
-1.2057586908340454,
|
||||||
|
-0.48546579480171204,
|
||||||
|
-0.2659154236316681,
|
||||||
|
0.9767322540283203,
|
||||||
|
-1.97313392162323,
|
||||||
|
-0.3016327917575836,
|
||||||
|
-0.6123557686805725,
|
||||||
|
0.288481205701828,
|
||||||
|
0.2976057827472687,
|
||||||
|
0.08243764936923981,
|
||||||
|
0.6122551560401917,
|
||||||
|
-0.6019028425216675,
|
||||||
|
-0.10548368841409683,
|
||||||
|
-0.016991911455988884,
|
||||||
|
1.75961172580719,
|
||||||
|
0.6418831944465637,
|
||||||
|
0.3137458264827728,
|
||||||
|
0.25365981459617615,
|
||||||
|
-0.45389246940612793,
|
||||||
|
0.238858163356781,
|
||||||
|
0.2631453275680542,
|
||||||
|
1.1121031045913696,
|
||||||
|
-0.9991472363471985,
|
||||||
|
-0.8904637694358826,
|
||||||
|
-1.1346020698547363,
|
||||||
|
-1.1918814182281494,
|
||||||
|
-1.1205440759658813,
|
||||||
|
-1.486283779144287,
|
||||||
|
1.0530670881271362,
|
||||||
|
-0.583172082901001,
|
||||||
|
0.26391518115997314,
|
||||||
|
1.2654175758361816,
|
||||||
|
-0.8430055975914001,
|
||||||
|
0.21697403490543365,
|
||||||
|
-0.30710718035697937,
|
||||||
|
2.191946506500244,
|
||||||
|
-0.19980488717556,
|
||||||
|
-0.5966204404830933,
|
||||||
|
0.04923265427350998,
|
||||||
|
-0.8815436959266663,
|
||||||
|
0.9289136528968811
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
14
data/speakers.json.backup
Normal file
14
data/speakers.json.backup
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
|
||||||
|
"speaker_name": "ZiyangZhang",
|
||||||
|
"wav_path": "/home/lyg/Code/funasr/data/speaker_wav/ZiyangZhang.wav",
|
||||||
|
"speaker_embs": ""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
|
||||||
|
"speaker_name": "HaiaoDuan",
|
||||||
|
"wav_path": "/home/lyg/Code/funasr/data/speaker_wav/HaiaoDuan.wav",
|
||||||
|
"speaker_embs": ""
|
||||||
|
}
|
||||||
|
]
|
@ -1,23 +0,0 @@
|
|||||||
version: '3.8'
|
|
||||||
|
|
||||||
services:
|
|
||||||
funasr:
|
|
||||||
build: .
|
|
||||||
container_name: funasr-websocket
|
|
||||||
volumes:
|
|
||||||
- .:/app
|
|
||||||
# 如果需要使用本地模型缓存
|
|
||||||
- ~/.cache/modelscope:/root/.cache/modelscope
|
|
||||||
ports:
|
|
||||||
- "10095:10095"
|
|
||||||
environment:
|
|
||||||
- PYTHONUNBUFFERED=1
|
|
||||||
deploy:
|
|
||||||
resources:
|
|
||||||
reservations:
|
|
||||||
devices:
|
|
||||||
- driver: nvidia
|
|
||||||
count: 1
|
|
||||||
capabilities: [gpu]
|
|
||||||
restart: unless-stopped
|
|
||||||
command: python src/server.py --device cuda --ngpu 1
|
|
36
docker/Dockerfile
Normal file
36
docker/Dockerfile
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
FROM python:3.10-slim
|
||||||
|
|
||||||
|
# 更换系统源
|
||||||
|
# 科大大的 Debian 12 (bookworm) 镜像源
|
||||||
|
RUN echo "deb http://mirrors.ustc.edu.cn/debian bookworm main contrib non-free" > /etc/apt/sources.list && \
|
||||||
|
echo "deb http://mirrors.ustc.edu.cn/debian bookworm-updates main contrib non-free" >> /etc/apt/sources.list && \
|
||||||
|
echo "deb http://mirrors.ustc.edu.cn/debian-security bookworm-security main" >> /etc/apt/sources.list && \
|
||||||
|
echo "deb http://mirrors.ustc.edu.cn/debian bookworm-backports main contrib non-free" >> /etc/apt/sources.list
|
||||||
|
# 安装系统依赖
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
build-essential \
|
||||||
|
libsndfile1 \
|
||||||
|
ffmpeg \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 安装Python依赖
|
||||||
|
COPY ../requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
|
# 复制应用代码
|
||||||
|
COPY ../src/ /app/src/
|
||||||
|
COPY ../logs/ /app/logs/
|
||||||
|
COPY ../main.py /app/
|
||||||
|
|
||||||
|
# 设置环境变量
|
||||||
|
ENV SPEAKERS_URL="http://172.23.30.120:11200/api/v1/speakers/" \
|
||||||
|
LOG_LEVEL="INFO" \
|
||||||
|
LOG_LEVEL_ASR_SERVER="INFO"
|
||||||
|
|
||||||
|
# 暴露WebSocket端口
|
||||||
|
EXPOSE 11096
|
||||||
|
|
||||||
|
# 启动服务
|
||||||
|
CMD ["python", "main.py"]
|
94
docs/SystemArchitecture.md
Normal file
94
docs/SystemArchitecture.md
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# 系统架构
|
||||||
|
|
||||||
|
本项目是一个基于 FunASR 和 FastAPI 构建的高性能、实时的语音识别(ASR)WebSocket 服务。其核心架构设计旨在处理实时的流式音频数据,并通过 "一发多收" 的广播模式,将识别结果分发给多个客户端。
|
||||||
|
|
||||||
|
## 核心组件
|
||||||
|
|
||||||
|
系统主要由以下几个核心组件构成,它们各司其职,通过异步和多线程协作,实现了高效的实时语音处理:
|
||||||
|
|
||||||
|
1. **WebSocket 服务 (FastAPI)**
|
||||||
|
- **文件**: `src/websockets/`, `src/server.py`, `main.py`
|
||||||
|
- **职责**: 作为系统的网络入口,负责处理 WebSocket 连接。它使用 FastAPI 构建,提供异步的、非阻塞的 I/O 处理能力,能够高效地管理大量并发客户端连接。`asr_endpoint.py` 是核心端点,负责根据客户端声明的 `mode` (sender/receiver) 将连接路由到 `ASRRunner`。
|
||||||
|
|
||||||
|
2. **会话管理器 (ASRRunner)**
|
||||||
|
- **文件**: `src/runner/ASRRunner.py`
|
||||||
|
- **职责**: 这是整个系统的"大脑"和协调中心。它管理所有活跃的语音识别会话(`SenderAndReceiver`,简称 SAR)。
|
||||||
|
- **会话生命周期**: 当一个 `sender` 连接时,`ASRRunner` 会创建一个新的 SAR 实例;当 `receiver` 连接时,会将其加入到指定的 SAR 中。
|
||||||
|
- **异步桥梁**: `ASRRunner` 运行在主 `asyncio` 事件循环中,负责从 `sender` 的 WebSocket 连接异步接收音频数据,然后通过线程安全的队列 (`queue.Queue`) 将数据传递给同步的 `ASRPipeline`。同时,它也负责接收来自 Pipeline 的最终结果,并将其异步广播给所有 `receiver`。
|
||||||
|
|
||||||
|
3. **语音处理流水线 (ASRPipeline)**
|
||||||
|
- **文件**: `src/pipeline/ASRpipeline.py`
|
||||||
|
- **职责**: 这是实际执行语音处理任务的核心引擎。每个 SAR 会话都拥有一个独立的 `ASRPipeline` 实例,该实例在自己的后台线程中运行。
|
||||||
|
- **模块化设计**: Pipeline 内部由多个 `Functor` (如 VAD, ASR, SPK) 组成,通过一系列内部队列连接,形成一个处理链。
|
||||||
|
- **处理流程**:
|
||||||
|
1. **VAD (Voice Activity Detection)**: 检测音频流中的有效语音片段。
|
||||||
|
2. **ASR (Automatic Speech Recognition)**: 将语音片段转换为文字。
|
||||||
|
3. **SPK (Speaker Recognition)**: 识别说话人(声纹识别)。
|
||||||
|
4. **ResultBinder**: 将 ASR 的文本结果和 SPK 的说话人结果合并,生成最终的识别消息。
|
||||||
|
|
||||||
|
4. **原子操作单元 (Functor)**
|
||||||
|
- **文件**: `src/functor/`
|
||||||
|
- **职责**: `Functor` 是 Pipeline 中执行具体原子任务的单元。每个 Functor 都是一个独立的类,负责调用底层 FunASR 模型来执行 VAD、ASR 或 SPK 等任务。这种设计使得处理流程更加清晰和模块化。
|
||||||
|
|
||||||
|
## 流程图
|
||||||
|
|
||||||
|
下面是系统处理一次完整语音识别请求的流程图,展示了从客户端连接到收到识别结果的全过程。
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
sequenceDiagram
|
||||||
|
participant Sender Client
|
||||||
|
participant Receiver Client
|
||||||
|
participant FastAPI WebSocket Endpoint
|
||||||
|
participant ASRRunner
|
||||||
|
participant ASRPipeline (Thread)
|
||||||
|
participant Functors (VAD, ASR, SPK)
|
||||||
|
|
||||||
|
par
|
||||||
|
Sender Client->>+FastAPI WebSocket Endpoint: 发起连接 (mode=sender, session_id=S1)
|
||||||
|
FastAPI WebSocket Endpoint->>+ASRRunner: new_SAR(ws, name="S1")
|
||||||
|
ASRRunner->>ASRRunner: 创建 SenderAndReceiver (SAR) 实例
|
||||||
|
ASRRunner->>ASRPipeline (Thread): 创建并运行 Pipeline 实例
|
||||||
|
ASRPipeline (Thread)->>Functors (VAD, ASR, SPK): 初始化 Functor 线程
|
||||||
|
ASRRunner->>-FastAPI WebSocket Endpoint: 返回成功
|
||||||
|
FastAPI WebSocket Endpoint->>-Sender Client: 连接建立
|
||||||
|
and
|
||||||
|
Receiver Client->>+FastAPI WebSocket Endpoint: 发起连接 (mode=receiver, session_id=S1)
|
||||||
|
FastAPI WebSocket Endpoint->>+ASRRunner: join_SAR(ws, name="S1")
|
||||||
|
ASRRunner->>ASRRunner: 将 Receiver 加入 S1 的接收者列表
|
||||||
|
ASRRunner->>-FastAPI WebSocket Endpoint: 返回成功
|
||||||
|
FastAPI WebSocket Endpoint->>-Receiver Client: 连接建立
|
||||||
|
end
|
||||||
|
|
||||||
|
loop 音频流传输与处理
|
||||||
|
Sender Client->>ASRRunner: 发送音频数据块
|
||||||
|
ASRRunner->>ASRPipeline (Thread): (via Queue) 传递音频数据
|
||||||
|
ASRPipeline (Thread)->>Functors (VAD, ASR, SPK): (via sub-queues) 分发数据
|
||||||
|
Note over Functors (VAD, ASR, SPK): 1. VAD检测语音<br/>2. ASR识别文本<br/>3. SPK识别说话人<br/>4. ResultBinder合并结果
|
||||||
|
Functors (VAD, ASR, SPK)->>ASRPipeline (Thread): (via callback) 返回最终识别结果
|
||||||
|
ASRPipeline (Thread)->>ASRRunner: (via thread-safe callback) 发送结果
|
||||||
|
end
|
||||||
|
|
||||||
|
ASRRunner->>ASRRunner: 收到结果,准备广播
|
||||||
|
ASRRunner-->>Sender Client: 广播识别结果
|
||||||
|
ASRRunner-->>Receiver Client: 广播识别结果
|
||||||
|
|
||||||
|
par
|
||||||
|
Sender Client->>FastAPI WebSocket Endpoint: 关闭连接
|
||||||
|
FastAPI WebSocket Endpoint->>ASRRunner: (触发异常)
|
||||||
|
ASRRunner->>ASRPipeline (Thread): 发送停止信号
|
||||||
|
ASRPipeline (Thread)->>Functors (VAD, ASR, SPK): 停止 Functor 线程
|
||||||
|
Note right of ASRRunner: 清理会话资源
|
||||||
|
and
|
||||||
|
Receiver Client->>FastAPI WebSocket Endpoint: 关闭连接
|
||||||
|
FastAPI WebSocket Endpoint->>ASRRunner: (触发异常)
|
||||||
|
ASRRunner->>ASRRunner: 从接收者列表移除
|
||||||
|
end
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## 架构优势
|
||||||
|
|
||||||
|
- **高并发和低延迟**: 采用 `asyncio` 和 WebSocket,网络层能够处理大量并发连接。音频处理在独立的线程中进行,避免了 CPU 密集型任务阻塞事件循环,保证了低延迟。
|
||||||
|
- **解耦与模块化**: `WebSocket Endpoint`、`ASRRunner` 和 `ASRPipeline` 职责清晰,相互解耦。`Functor` 的设计使得添加或修改处理步骤变得容易。
|
||||||
|
- **鲁棒性**: 每个识别会话(SAR)都是隔离的,一个会话的失败不会影响其他会话。优雅的关闭逻辑确保了资源的正确释放。
|
||||||
|
- **可扩展性**: "一发多收" 的广播模式可以轻松扩展到大量 `receiver`,适用于多种实时应用场景。
|
262
docs/WEBSOCKET_API.md
Normal file
262
docs/WEBSOCKET_API.md
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
# FunASR-FastAPI WebSocket API 文档
|
||||||
|
|
||||||
|
本文档详细介绍了如何连接和使用 FunASR-FastAPI 实时语音识别服务的 WebSocket 接口。
|
||||||
|
|
||||||
|
## 1. 连接端点 (Endpoint)
|
||||||
|
|
||||||
|
服务的 WebSocket 端点 URL 格式如下:
|
||||||
|
|
||||||
|
```
|
||||||
|
ws://<your_server_host>:8000/ws/asr/{session_id}?mode=<client_mode>
|
||||||
|
```
|
||||||
|
|
||||||
|
### 参数说明
|
||||||
|
|
||||||
|
- **`{session_id}`** (路径参数, `str`, **必需**):
|
||||||
|
用于唯一标识一个识别会话(例如,一场会议或一次直播)。所有属于同一次会话的客户端都应使用相同的 `session_id`。
|
||||||
|
|
||||||
|
- **`mode`** (查询参数, `str`, **必需**):
|
||||||
|
定义客户端的角色。
|
||||||
|
- `sender`: 音频发送者。一个会话中应该只有一个 `sender`。此客户端负责将实时音频流发送到服务器。
|
||||||
|
- `receiver`: 结果接收者。一个会话中可以有多个 `receiver`。此客户端只接收由服务器广播的识别结果,不发送音频。
|
||||||
|
|
||||||
|
## 2. 数据格式
|
||||||
|
|
||||||
|
### 2.1 发送数据 (Sender -> Server)
|
||||||
|
|
||||||
|
- **音频格式**: `sender` 必须发送原始的 **PCM 音频数据**。
|
||||||
|
- **采样率**: 16000 Hz
|
||||||
|
- **位深**: 32-bit (floating point)
|
||||||
|
- **声道数**: 单声道 (Mono)
|
||||||
|
- **传输格式**: 必须以**二进制 (bytes)** 格式发送。
|
||||||
|
- **结束信号**: 当音频流结束时,`sender` 应发送一个**文本消息** `"close"` 来通知服务器关闭会话。
|
||||||
|
|
||||||
|
### 2.2 接收数据 (Server -> Receiver)
|
||||||
|
|
||||||
|
服务器会将识别结果以 **JSON 文本** 格式广播给会话中的所有 `receiver`(以及 `sender` 自己)。JSON 对象的结构示例如下:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"asr": "你好,世界。",
|
||||||
|
"spk": {
|
||||||
|
"speaker_id": "uuid-of-the-speaker",
|
||||||
|
"speaker_name": "SpeakerName",
|
||||||
|
"score": 0.98
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3. Python 客户端示例
|
||||||
|
|
||||||
|
需要安装 `websockets` 库: `pip install websockets`
|
||||||
|
|
||||||
|
### 3.1 Python Sender 示例 (发送本地音频文件)
|
||||||
|
|
||||||
|
这个脚本会读取一个 WAV 文件,并将其内容以流式方式发送到服务器。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
import websockets
|
||||||
|
import soundfile as sf
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# --- 配置 ---
|
||||||
|
SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=sender"
|
||||||
|
SESSION_ID = str(uuid.uuid4()) # 为这次会话生成一个唯一的ID
|
||||||
|
AUDIO_FILE = "tests/XT_ZZY_denoise.wav" # 替换为你的音频文件路径
|
||||||
|
CHUNK_SIZE = 3200 # 对应 100ms 的 float32 数据 (16000 * 4 * 0.1)
|
||||||
|
|
||||||
|
async def send_audio():
|
||||||
|
"""连接到服务器,并流式发送音频文件"""
|
||||||
|
uri = SERVER_URI.format(session_id=SESSION_ID)
|
||||||
|
print(f"作为 Sender 连接到: {uri}")
|
||||||
|
|
||||||
|
async with websockets.connect(uri) as websocket:
|
||||||
|
try:
|
||||||
|
# 读取音频文件
|
||||||
|
with sf.SoundFile(AUDIO_FILE, 'r') as f:
|
||||||
|
assert f.samplerate == 16000, "音频文件采样率必须为 16kHz"
|
||||||
|
assert f.channels == 1, "音频文件必须为单声道"
|
||||||
|
|
||||||
|
print("开始发送音频...")
|
||||||
|
while True:
|
||||||
|
# 读取为 float32 类型
|
||||||
|
data = f.read(CHUNK_SIZE, dtype='float32')
|
||||||
|
if not data.any():
|
||||||
|
break
|
||||||
|
# 将 numpy 数组转换为原始字节流
|
||||||
|
await websocket.send(data.tobytes())
|
||||||
|
await asyncio.sleep(0.1) # 模拟实时音频输入
|
||||||
|
|
||||||
|
print("音频发送完毕,发送结束信号。")
|
||||||
|
await websocket.send("close")
|
||||||
|
|
||||||
|
# 等待服务器的最终确认或关闭连接
|
||||||
|
response = await websocket.recv()
|
||||||
|
print(f"收到服务器最终响应: {response}")
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
print(f"连接已关闭: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(send_audio())
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 Python Receiver 示例 (接收识别结果)
|
||||||
|
|
||||||
|
这个脚本会连接到指定的会话,并持续打印服务器广播的识别结果。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
# --- 配置 ---
|
||||||
|
# !!! 必须和 Sender 使用相同的 SESSION_ID !!!
|
||||||
|
SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=receiver"
|
||||||
|
SESSION_ID = "在此处粘贴你的Sender会话ID"
|
||||||
|
|
||||||
|
async def receive_results():
|
||||||
|
"""连接到服务器并接收识别结果"""
|
||||||
|
if "粘贴你的Sender会话ID" in SESSION_ID:
|
||||||
|
print("错误:请先设置有效的 SESSION_ID!")
|
||||||
|
return
|
||||||
|
|
||||||
|
uri = SERVER_URI.format(session_id=SESSION_ID)
|
||||||
|
print(f"作为 Receiver 连接到: {uri}")
|
||||||
|
|
||||||
|
async with websockets.connect(uri) as websocket:
|
||||||
|
try:
|
||||||
|
print("等待接收识别结果...")
|
||||||
|
while True:
|
||||||
|
message = await websocket.recv()
|
||||||
|
print(f"收到结果: {message}")
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
print(f"连接已关闭: {e.code} {e.reason}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(receive_results())
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. JavaScript 客户端示例 (浏览器)
|
||||||
|
|
||||||
|
这个示例展示了如何在网页上通过麦克风获取音频,并将其作为 `sender` 发送。
|
||||||
|
|
||||||
|
```html
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>WebSocket ASR Client</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>FunASR WebSocket Client (Sender)</h1>
|
||||||
|
<p><strong>Session ID:</strong> <span id="sessionId"></span></p>
|
||||||
|
<button id="startButton">开始识别</button>
|
||||||
|
<button id="stopButton" disabled>停止识别</button>
|
||||||
|
<h2>识别结果:</h2>
|
||||||
|
<div id="results"></div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
const startButton = document.getElementById('startButton');
|
||||||
|
const stopButton = document.getElementById('stopButton');
|
||||||
|
const resultsDiv = document.getElementById('results');
|
||||||
|
const sessionIdSpan = document.getElementById('sessionId');
|
||||||
|
|
||||||
|
let websocket;
|
||||||
|
let audioContext;
|
||||||
|
let scriptProcessor;
|
||||||
|
let mediaStream;
|
||||||
|
|
||||||
|
const CHUNK_DURATION_MS = 100; // 每100ms发送一次数据
|
||||||
|
const SAMPLE_RATE = 16000;
|
||||||
|
|
||||||
|
// 生成一个简单的UUID
|
||||||
|
function generateUUID() {
|
||||||
|
return ([1e7]+-1e3+-4e3+-8e3+-1e11).replace(/[018]/g, c =>
|
||||||
|
(c ^ crypto.getRandomValues(new Uint8Array(1))[0] & 15 >> c / 4).toString(16)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function startRecording() {
|
||||||
|
const sessionId = generateUUID();
|
||||||
|
sessionIdSpan.textContent = sessionId;
|
||||||
|
const wsUrl = `ws://${window.location.host}/ws/asr/${sessionId}?mode=sender`;
|
||||||
|
|
||||||
|
websocket = new WebSocket(wsUrl);
|
||||||
|
websocket.onopen = () => {
|
||||||
|
console.log("WebSocket 连接已打开");
|
||||||
|
startButton.disabled = true;
|
||||||
|
stopButton.disabled = false;
|
||||||
|
resultsDiv.innerHTML = '';
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onmessage = (event) => {
|
||||||
|
console.log("收到消息:", event.data);
|
||||||
|
const result = JSON.parse(event.data);
|
||||||
|
const asrText = result.asr || '';
|
||||||
|
const spkName = result.spk ? result.spk.speaker_name : 'Unknown';
|
||||||
|
resultsDiv.innerHTML += `<p><strong>${spkName}:</strong> ${asrText}</p>`;
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onclose = () => {
|
||||||
|
console.log("WebSocket 连接已关闭");
|
||||||
|
stopRecording();
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onerror = (error) => {
|
||||||
|
console.error("WebSocket 错误:", error);
|
||||||
|
alert("WebSocket 连接失败!");
|
||||||
|
stopRecording();
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true, video: false });
|
||||||
|
audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });
|
||||||
|
|
||||||
|
const source = audioContext.createMediaStreamSource(mediaStream);
|
||||||
|
const bufferSize = CHUNK_DURATION_MS * SAMPLE_RATE / 1000 * 4; // 计算缓冲区大小
|
||||||
|
scriptProcessor = audioContext.createScriptProcessor(bufferSize, 1, 1);
|
||||||
|
|
||||||
|
scriptProcessor.onaudioprocess = (e) => {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
const inputData = e.inputBuffer.getChannelData(0);
|
||||||
|
// 服务器期望 float32 数据,inputData 本身就是 Float32Array,直接发送其 buffer
|
||||||
|
websocket.send(inputData.buffer);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
source.connect(scriptProcessor);
|
||||||
|
scriptProcessor.connect(audioContext.destination);
|
||||||
|
|
||||||
|
} catch (err) {
|
||||||
|
console.error("无法获取麦克风:", err);
|
||||||
|
alert("无法获取麦克风权限!");
|
||||||
|
if (websocket) websocket.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function stopRecording() {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
websocket.send("close");
|
||||||
|
}
|
||||||
|
if (mediaStream) {
|
||||||
|
mediaStream.getTracks().forEach(track => track.stop());
|
||||||
|
}
|
||||||
|
if (scriptProcessor) {
|
||||||
|
scriptProcessor.disconnect();
|
||||||
|
}
|
||||||
|
if (audioContext) {
|
||||||
|
audioContext.close();
|
||||||
|
}
|
||||||
|
startButton.disabled = false;
|
||||||
|
stopButton.disabled = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
startButton.addEventListener('click', startRecording);
|
||||||
|
stopButton.addEventListener('click', stopRecording);
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
```
|
23
main.py
Normal file
23
main.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import uvicorn
|
||||||
|
import os
|
||||||
|
from src.server import app
|
||||||
|
from src.utils.logger import get_module_logger, setup_root_logger
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
# 日志等级 LOG_LEVEL_ASR_SERVER > LOG_LEVEL > INFO(default)
|
||||||
|
logger_level = os.getenv("LOG_LEVEL", "INFO")
|
||||||
|
logger_level = os.getenv("LOG_LEVEL_ASR_SERVER", logger_level)
|
||||||
|
setup_root_logger(level=logger_level, log_file=f"logs/main_{time}.log")
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger.info("启动 FunASR FastAPI 服务器...")
|
||||||
|
# 在生产环境中,推荐使用更强大的ASGI服务器,如Gunicorn,并配合Uvicorn workers。
|
||||||
|
# 例如: gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app
|
||||||
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=11096,
|
||||||
|
log_level="info"
|
||||||
|
)
|
20
pyproject.toml
Normal file
20
pyproject.toml
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
[project]
|
||||||
|
name = "asr-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
requires-python = ">=3.9.21"
|
||||||
|
dependencies = [
|
||||||
|
"addict==2.4.0",
|
||||||
|
"datasets==2.21.0",
|
||||||
|
"fastapi==0.115.14",
|
||||||
|
"funasr==1.2.6",
|
||||||
|
"numpy==2.0.1",
|
||||||
|
"pillow==11.1.0",
|
||||||
|
"pydantic==2.11.3",
|
||||||
|
"pydub>=0.25.1",
|
||||||
|
"simplejson==3.20.1",
|
||||||
|
"sortedcontainers==2.4.0",
|
||||||
|
"torch==2.3.1",
|
||||||
|
"torchaudio==2.3.1",
|
||||||
|
"uvicorn==0.35.0",
|
||||||
|
"websockets==12.0",
|
||||||
|
]
|
346
requirements.txt
346
requirements.txt
@ -1,11 +1,335 @@
|
|||||||
pytest==7.3.1
|
# This file was autogenerated by uv via the following command:
|
||||||
pytest-cov==4.1.0
|
# uv pip compile pyproject.toml -o requirements.txt
|
||||||
flake8==6.0.0
|
addict==2.4.0
|
||||||
black==23.3.0
|
# via asr-server (pyproject.toml)
|
||||||
isort==5.12.0
|
aiohappyeyeballs==2.6.1
|
||||||
flask==2.3.2
|
# via aiohttp
|
||||||
requests==2.31.0
|
aiohttp==3.12.13
|
||||||
websockets==11.0.3
|
# via
|
||||||
numpy==1.24.3
|
# datasets
|
||||||
funasr==0.10.0
|
# fsspec
|
||||||
modelscope==1.9.5
|
aiosignal==1.4.0
|
||||||
|
# via aiohttp
|
||||||
|
aliyun-python-sdk-core==2.16.0
|
||||||
|
# via
|
||||||
|
# aliyun-python-sdk-kms
|
||||||
|
# oss2
|
||||||
|
aliyun-python-sdk-kms==2.16.5
|
||||||
|
# via oss2
|
||||||
|
annotated-types==0.7.0
|
||||||
|
# via pydantic
|
||||||
|
antlr4-python3-runtime==4.9.3
|
||||||
|
# via
|
||||||
|
# hydra-core
|
||||||
|
# omegaconf
|
||||||
|
anyio==4.9.0
|
||||||
|
# via starlette
|
||||||
|
async-timeout==5.0.1
|
||||||
|
# via aiohttp
|
||||||
|
attrs==25.3.0
|
||||||
|
# via aiohttp
|
||||||
|
audioread==3.0.1
|
||||||
|
# via librosa
|
||||||
|
certifi==2025.6.15
|
||||||
|
# via requests
|
||||||
|
cffi==1.17.1
|
||||||
|
# via
|
||||||
|
# cryptography
|
||||||
|
# soundfile
|
||||||
|
charset-normalizer==3.4.2
|
||||||
|
# via requests
|
||||||
|
click==8.2.1
|
||||||
|
# via uvicorn
|
||||||
|
crcmod==1.7
|
||||||
|
# via oss2
|
||||||
|
cryptography==45.0.5
|
||||||
|
# via aliyun-python-sdk-core
|
||||||
|
datasets==2.21.0
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
decorator==5.2.1
|
||||||
|
# via librosa
|
||||||
|
dill==0.3.8
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# multiprocess
|
||||||
|
editdistance==0.8.1
|
||||||
|
# via funasr
|
||||||
|
exceptiongroup==1.3.0
|
||||||
|
# via anyio
|
||||||
|
fastapi==0.115.14
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
filelock==3.18.0
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
# torch
|
||||||
|
# triton
|
||||||
|
frozenlist==1.7.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# aiosignal
|
||||||
|
fsspec==2024.6.1
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
# torch
|
||||||
|
funasr==1.2.6
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
h11==0.16.0
|
||||||
|
# via uvicorn
|
||||||
|
hf-xet==1.1.5
|
||||||
|
# via huggingface-hub
|
||||||
|
huggingface-hub==0.33.2
|
||||||
|
# via datasets
|
||||||
|
hydra-core==1.3.2
|
||||||
|
# via funasr
|
||||||
|
idna==3.10
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# requests
|
||||||
|
# yarl
|
||||||
|
jaconv==0.4.0
|
||||||
|
# via funasr
|
||||||
|
jamo==0.4.1
|
||||||
|
# via funasr
|
||||||
|
jieba==0.42.1
|
||||||
|
# via funasr
|
||||||
|
jinja2==3.1.6
|
||||||
|
# via torch
|
||||||
|
jmespath==0.10.0
|
||||||
|
# via aliyun-python-sdk-core
|
||||||
|
joblib==1.5.1
|
||||||
|
# via
|
||||||
|
# librosa
|
||||||
|
# pynndescent
|
||||||
|
# scikit-learn
|
||||||
|
kaldiio==2.18.1
|
||||||
|
# via funasr
|
||||||
|
lazy-loader==0.4
|
||||||
|
# via librosa
|
||||||
|
librosa==0.11.0
|
||||||
|
# via funasr
|
||||||
|
llvmlite==0.44.0
|
||||||
|
# via
|
||||||
|
# numba
|
||||||
|
# pynndescent
|
||||||
|
markupsafe==3.0.2
|
||||||
|
# via jinja2
|
||||||
|
modelscope==1.27.1
|
||||||
|
# via funasr
|
||||||
|
mpmath==1.3.0
|
||||||
|
# via sympy
|
||||||
|
msgpack==1.1.1
|
||||||
|
# via librosa
|
||||||
|
multidict==6.6.3
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
|
multiprocess==0.70.16
|
||||||
|
# via datasets
|
||||||
|
networkx==3.4.2
|
||||||
|
# via torch
|
||||||
|
numba==0.61.2
|
||||||
|
# via
|
||||||
|
# librosa
|
||||||
|
# pynndescent
|
||||||
|
# umap-learn
|
||||||
|
numpy==2.0.1
|
||||||
|
# via
|
||||||
|
# asr-server (pyproject.toml)
|
||||||
|
# datasets
|
||||||
|
# kaldiio
|
||||||
|
# librosa
|
||||||
|
# numba
|
||||||
|
# pandas
|
||||||
|
# pytorch-wpe
|
||||||
|
# scikit-learn
|
||||||
|
# scipy
|
||||||
|
# soundfile
|
||||||
|
# soxr
|
||||||
|
# tensorboardx
|
||||||
|
# torch-complex
|
||||||
|
# umap-learn
|
||||||
|
nvidia-cublas-cu12==12.1.3.1
|
||||||
|
# via
|
||||||
|
# nvidia-cudnn-cu12
|
||||||
|
# nvidia-cusolver-cu12
|
||||||
|
# torch
|
||||||
|
nvidia-cuda-cupti-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-runtime-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cudnn-cu12==8.9.2.26
|
||||||
|
# via torch
|
||||||
|
nvidia-cufft-cu12==11.0.2.54
|
||||||
|
# via torch
|
||||||
|
nvidia-curand-cu12==10.3.2.106
|
||||||
|
# via torch
|
||||||
|
nvidia-cusolver-cu12==11.4.5.107
|
||||||
|
# via torch
|
||||||
|
nvidia-cusparse-cu12==12.1.0.106
|
||||||
|
# via
|
||||||
|
# nvidia-cusolver-cu12
|
||||||
|
# torch
|
||||||
|
nvidia-nccl-cu12==2.20.5
|
||||||
|
# via torch
|
||||||
|
nvidia-nvjitlink-cu12==12.9.86
|
||||||
|
# via
|
||||||
|
# nvidia-cusolver-cu12
|
||||||
|
# nvidia-cusparse-cu12
|
||||||
|
nvidia-nvtx-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
omegaconf==2.3.0
|
||||||
|
# via hydra-core
|
||||||
|
oss2==2.19.1
|
||||||
|
# via funasr
|
||||||
|
packaging==25.0
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
# hydra-core
|
||||||
|
# lazy-loader
|
||||||
|
# pooch
|
||||||
|
# tensorboardx
|
||||||
|
# torch-complex
|
||||||
|
pandas==2.3.0
|
||||||
|
# via datasets
|
||||||
|
pillow==11.1.0
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
platformdirs==4.3.8
|
||||||
|
# via pooch
|
||||||
|
pooch==1.8.2
|
||||||
|
# via librosa
|
||||||
|
propcache==0.3.2
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
|
protobuf==6.31.1
|
||||||
|
# via tensorboardx
|
||||||
|
pyarrow==20.0.0
|
||||||
|
# via datasets
|
||||||
|
pycparser==2.22
|
||||||
|
# via cffi
|
||||||
|
pycryptodome==3.23.0
|
||||||
|
# via oss2
|
||||||
|
pydantic==2.11.3
|
||||||
|
# via
|
||||||
|
# asr-server (pyproject.toml)
|
||||||
|
# fastapi
|
||||||
|
pydantic-core==2.33.1
|
||||||
|
# via pydantic
|
||||||
|
pydub==0.25.1
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
pynndescent==0.5.13
|
||||||
|
# via umap-learn
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
# via pandas
|
||||||
|
pytorch-wpe==0.0.1
|
||||||
|
# via funasr
|
||||||
|
pytz==2025.2
|
||||||
|
# via pandas
|
||||||
|
pyyaml==6.0.2
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# funasr
|
||||||
|
# huggingface-hub
|
||||||
|
# omegaconf
|
||||||
|
requests==2.32.4
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# funasr
|
||||||
|
# huggingface-hub
|
||||||
|
# modelscope
|
||||||
|
# oss2
|
||||||
|
# pooch
|
||||||
|
scikit-learn==1.7.0
|
||||||
|
# via
|
||||||
|
# librosa
|
||||||
|
# pynndescent
|
||||||
|
# umap-learn
|
||||||
|
scipy==1.15.3
|
||||||
|
# via
|
||||||
|
# funasr
|
||||||
|
# librosa
|
||||||
|
# pynndescent
|
||||||
|
# scikit-learn
|
||||||
|
# umap-learn
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
# via funasr
|
||||||
|
setuptools==80.9.0
|
||||||
|
# via modelscope
|
||||||
|
simplejson==3.20.1
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
six==1.17.0
|
||||||
|
# via
|
||||||
|
# oss2
|
||||||
|
# python-dateutil
|
||||||
|
sniffio==1.3.1
|
||||||
|
# via anyio
|
||||||
|
sortedcontainers==2.4.0
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
soundfile==0.13.1
|
||||||
|
# via
|
||||||
|
# funasr
|
||||||
|
# librosa
|
||||||
|
soxr==0.5.0.post1
|
||||||
|
# via librosa
|
||||||
|
starlette==0.46.2
|
||||||
|
# via fastapi
|
||||||
|
sympy==1.14.0
|
||||||
|
# via torch
|
||||||
|
tensorboardx==2.6.4
|
||||||
|
# via funasr
|
||||||
|
threadpoolctl==3.6.0
|
||||||
|
# via scikit-learn
|
||||||
|
torch==2.3.1
|
||||||
|
# via
|
||||||
|
# asr-server (pyproject.toml)
|
||||||
|
# torchaudio
|
||||||
|
torch-complex==0.4.4
|
||||||
|
# via funasr
|
||||||
|
torchaudio==2.3.1
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
tqdm==4.67.1
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# funasr
|
||||||
|
# huggingface-hub
|
||||||
|
# modelscope
|
||||||
|
# umap-learn
|
||||||
|
triton==2.3.1
|
||||||
|
# via torch
|
||||||
|
typing-extensions==4.14.1
|
||||||
|
# via
|
||||||
|
# aiosignal
|
||||||
|
# anyio
|
||||||
|
# exceptiongroup
|
||||||
|
# fastapi
|
||||||
|
# huggingface-hub
|
||||||
|
# librosa
|
||||||
|
# multidict
|
||||||
|
# pydantic
|
||||||
|
# pydantic-core
|
||||||
|
# torch
|
||||||
|
# typing-inspection
|
||||||
|
# uvicorn
|
||||||
|
typing-inspection==0.4.1
|
||||||
|
# via pydantic
|
||||||
|
tzdata==2025.2
|
||||||
|
# via pandas
|
||||||
|
umap-learn==0.5.9.post2
|
||||||
|
# via funasr
|
||||||
|
urllib3==2.5.0
|
||||||
|
# via
|
||||||
|
# modelscope
|
||||||
|
# requests
|
||||||
|
uvicorn==0.35.0
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
websockets==12.0
|
||||||
|
# via asr-server (pyproject.toml)
|
||||||
|
xxhash==3.5.0
|
||||||
|
# via datasets
|
||||||
|
yarl==1.20.1
|
||||||
|
# via aiohttp
|
||||||
|
194
src/audio_chunk.py
Normal file
194
src/audio_chunk.py
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
音频数据块管理类 - 用于存储和处理16KHz音频数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Dict
|
||||||
|
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBinary:
|
||||||
|
"""
|
||||||
|
音频数据存储单元
|
||||||
|
用于存储二进制数据
|
||||||
|
面向Slice, 向Slice提供数据与接口
|
||||||
|
|
||||||
|
self._audio_config: AudioBinary_Config -- 音频参数配置
|
||||||
|
self._binary_data_list: AudioBinary_data_list -- 音频数据列表
|
||||||
|
self._slice_listener: List[callable] -- 切片监听器
|
||||||
|
|
||||||
|
AudioBinary_Config: Dict -- 音频参数配置
|
||||||
|
AudioBinary_data_list: List[bytes] -- 音频数据列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args):
|
||||||
|
"""
|
||||||
|
初始化音频数据块
|
||||||
|
参数:
|
||||||
|
*args: 可变参数
|
||||||
|
"""
|
||||||
|
# 音频参数配置
|
||||||
|
self._audio_config = AudioBinary_Config()
|
||||||
|
# 音频片段
|
||||||
|
self._binary_data_list: AudioBinary_data_list = AudioBinary_data_list()
|
||||||
|
# 切片监听器
|
||||||
|
self._slice_listener: List = []
|
||||||
|
if isinstance(args, Dict):
|
||||||
|
self._audio_config = AudioBinary_Config.AudioBinary_Config_from_dict(args)
|
||||||
|
elif isinstance(args, AudioBinary_Config):
|
||||||
|
self._audio_config = args
|
||||||
|
else:
|
||||||
|
raise ValueError("参数类型错误")
|
||||||
|
|
||||||
|
def add_slice_listener(self, slice_listener: callable) -> None:
|
||||||
|
"""
|
||||||
|
添加切片监听器
|
||||||
|
参数:
|
||||||
|
slice_listener: callable -- 切片监听器
|
||||||
|
"""
|
||||||
|
self._slice_listener.append(slice_listener)
|
||||||
|
|
||||||
|
def __add__(self, other: bytes):
|
||||||
|
"""
|
||||||
|
__add__ 是 "+" 运算符的重载,
|
||||||
|
使用方法:
|
||||||
|
audio_binary = audio_binary + bytes
|
||||||
|
添加音频数据块 与 add_binary_data 等效,
|
||||||
|
但可以链式调用, 方便使用
|
||||||
|
参数:
|
||||||
|
other: bytes --音频数据块
|
||||||
|
"""
|
||||||
|
self._binary_data_list.append(other)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __iadd__(self, other: bytes):
|
||||||
|
"""
|
||||||
|
__iadd__ 是 "+=" 运算符的重载,
|
||||||
|
使用方法:
|
||||||
|
audio_binary += bytes
|
||||||
|
添加音频数据块 与 add_binary_data 等效,
|
||||||
|
但可以链式调用, 方便使用
|
||||||
|
参数:
|
||||||
|
other: bytes --音频数据块
|
||||||
|
"""
|
||||||
|
self._binary_data_list.append(other)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_binary_data(self, binary_data: bytes):
|
||||||
|
"""
|
||||||
|
添加音频数据块
|
||||||
|
参数:
|
||||||
|
binary_data: bytes --音频数据块
|
||||||
|
"""
|
||||||
|
self._binary_data_list.append(binary_data)
|
||||||
|
|
||||||
|
def rewrite_binary_data(self, target_index: int, binary_data: bytes):
|
||||||
|
"""
|
||||||
|
重写音频数据块
|
||||||
|
参数:
|
||||||
|
target_index: int -- 目标索引
|
||||||
|
binary_data: bytes --音频数据块
|
||||||
|
"""
|
||||||
|
self._binary_data_list.rewrite(target_index, binary_data)
|
||||||
|
|
||||||
|
def get_binary_data(
|
||||||
|
self,
|
||||||
|
start: int = 0,
|
||||||
|
end: Optional[int] = None,
|
||||||
|
) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
获取指定索引的音频数据块
|
||||||
|
参数:
|
||||||
|
start: 开始索引
|
||||||
|
end: 结束索引
|
||||||
|
返回:
|
||||||
|
List[bytes]: 音频数据块
|
||||||
|
"""
|
||||||
|
if start >= len(self._binary_data_list):
|
||||||
|
return None
|
||||||
|
if end is None:
|
||||||
|
end = start + 1
|
||||||
|
end = min(end, len(self._binary_data_list))
|
||||||
|
return self._binary_data_list[start:end]
|
||||||
|
|
||||||
|
|
||||||
|
class AudioChunk:
|
||||||
|
"""
|
||||||
|
音频数据块管理类
|
||||||
|
管理两部分内容, AudioBinary和Slice。
|
||||||
|
AudioBinary用于内部存储字节数据。
|
||||||
|
Slice是AudioBinary的切片,用于外部接口。
|
||||||
|
|
||||||
|
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional["AudioChunk"] = None
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
单例模式
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""
|
||||||
|
初始化AudioChunk实例
|
||||||
|
"""
|
||||||
|
self._audio_binary_list: Dict[str, AudioBinary] = {}
|
||||||
|
self._slice_listener: List[callable] = []
|
||||||
|
|
||||||
|
def get_audio_binary(
|
||||||
|
self,
|
||||||
|
binary_name: Optional[str] = None,
|
||||||
|
audio_config: Optional[AudioBinary_Config] = None,
|
||||||
|
) -> AudioBinary:
|
||||||
|
"""
|
||||||
|
获取音频数据块
|
||||||
|
参数:
|
||||||
|
binary_name: str -- 音频数据块名称
|
||||||
|
返回:
|
||||||
|
AudioBinary: 音频数据块
|
||||||
|
"""
|
||||||
|
if binary_name is None:
|
||||||
|
binary_name = "default"
|
||||||
|
if binary_name not in self._audio_binary_list:
|
||||||
|
self._audio_binary_list[binary_name] = AudioBinary(audio_config)
|
||||||
|
return self._audio_binary_list[binary_name]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _time2size(time_ms: int, audio_config: AudioBinary_Config) -> int:
|
||||||
|
"""
|
||||||
|
将时间(ms)转换为数据大小(字节)
|
||||||
|
参数:
|
||||||
|
time_ms: int -- 时间(ms)
|
||||||
|
audio_config: AudioBinary_Config -- 音频参数配置
|
||||||
|
返回:
|
||||||
|
int: 数据大小(字节)
|
||||||
|
"""
|
||||||
|
# 时间(ms)到字节(bytes)计算方法为: 时间(ms) * 采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24) / 1000
|
||||||
|
time_s = time_ms / 1000
|
||||||
|
bytes_per_sample = audio_config.sample_width * audio_config.channel
|
||||||
|
return int(time_s * audio_config.sample_rate * bytes_per_sample)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _size2time(size: int, audio_config: AudioBinary_Config) -> int:
|
||||||
|
"""
|
||||||
|
将数据大小(字节)转换为时间(ms)
|
||||||
|
参数:
|
||||||
|
size: int -- 数据大小(字节)
|
||||||
|
audio_config: AudioBinary_Config -- 音频参数配置
|
||||||
|
返回:
|
||||||
|
int: 时间(ms)
|
||||||
|
"""
|
||||||
|
# 字节(bytes)到时间(ms)计算方法为: 字节(bytes) * 1000 / (采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24))
|
||||||
|
bytes_per_sample = audio_config.sample_width * audio_config.channel
|
||||||
|
time_ms = size * 1000 // (audio_config.sample_rate * bytes_per_sample)
|
||||||
|
return time_ms
|
196
src/client.py
196
src/client.py
@ -1,196 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
WebSocket客户端示例 - 用于测试语音识别服务
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import websockets
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
|
||||||
import wave
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
"""解析命令行参数"""
|
|
||||||
parser = argparse.ArgumentParser(description="FunASR WebSocket客户端")
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
type=str,
|
|
||||||
default="localhost",
|
|
||||||
help="服务器主机地址"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
default=10095,
|
|
||||||
help="服务器端口"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--audio_file",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="要识别的音频文件路径"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--mode",
|
|
||||||
type=str,
|
|
||||||
default="2pass",
|
|
||||||
choices=["2pass", "online", "offline"],
|
|
||||||
help="识别模式: 2pass(默认), online, offline"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--chunk_size",
|
|
||||||
type=str,
|
|
||||||
default="5,10",
|
|
||||||
help="分块大小, 格式为'encoder_size,decoder_size'"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_ssl",
|
|
||||||
action="store_true",
|
|
||||||
help="是否使用SSL连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
async def send_audio(websocket, audio_file, mode, chunk_size):
|
|
||||||
"""
|
|
||||||
发送音频文件到服务器进行识别
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
audio_file: 音频文件路径
|
|
||||||
mode: 识别模式
|
|
||||||
chunk_size: 分块大小
|
|
||||||
"""
|
|
||||||
# 打开并读取WAV文件
|
|
||||||
with wave.open(audio_file, "rb") as wav_file:
|
|
||||||
params = wav_file.getparams()
|
|
||||||
frames = wav_file.readframes(wav_file.getnframes())
|
|
||||||
|
|
||||||
# 音频文件信息
|
|
||||||
print(f"音频文件: {os.path.basename(audio_file)}")
|
|
||||||
print(f"采样率: {params.framerate}Hz, 通道数: {params.nchannels}")
|
|
||||||
print(f"采样位深: {params.sampwidth * 8}位, 总帧数: {params.nframes}")
|
|
||||||
|
|
||||||
# 设置配置参数
|
|
||||||
config = {
|
|
||||||
"mode": mode,
|
|
||||||
"chunk_size": chunk_size,
|
|
||||||
"wav_name": os.path.basename(audio_file),
|
|
||||||
"is_speaking": True
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送配置
|
|
||||||
await websocket.send(json.dumps(config))
|
|
||||||
|
|
||||||
# 模拟实时发送音频数据
|
|
||||||
chunk_size_bytes = 3200 # 每次发送100ms的16kHz音频
|
|
||||||
total_chunks = len(frames) // chunk_size_bytes
|
|
||||||
|
|
||||||
print(f"开始发送音频数据,共 {total_chunks} 个数据块...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
for i in range(0, len(frames), chunk_size_bytes):
|
|
||||||
chunk = frames[i:i+chunk_size_bytes]
|
|
||||||
await websocket.send(chunk)
|
|
||||||
|
|
||||||
# 模拟实时,每100ms发送一次
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
# 显示进度
|
|
||||||
if (i // chunk_size_bytes) % 10 == 0:
|
|
||||||
print(f"已发送 {i // chunk_size_bytes}/{total_chunks} 数据块")
|
|
||||||
|
|
||||||
# 发送结束信号
|
|
||||||
await websocket.send(json.dumps({"is_speaking": False}))
|
|
||||||
print("音频数据发送完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"发送音频时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def receive_results(websocket):
|
|
||||||
"""
|
|
||||||
接收并显示识别结果
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
"""
|
|
||||||
online_text = ""
|
|
||||||
offline_text = ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for message in websocket:
|
|
||||||
# 解析服务器返回的JSON消息
|
|
||||||
result = json.loads(message)
|
|
||||||
|
|
||||||
mode = result.get("mode", "")
|
|
||||||
text = result.get("text", "")
|
|
||||||
is_final = result.get("is_final", False)
|
|
||||||
|
|
||||||
# 根据模式更新文本
|
|
||||||
if "online" in mode:
|
|
||||||
online_text = text
|
|
||||||
print(f"\r[在线识别] {online_text}", end="", flush=True)
|
|
||||||
elif "offline" in mode:
|
|
||||||
offline_text = text
|
|
||||||
print(f"\n[离线识别] {offline_text}")
|
|
||||||
|
|
||||||
# 如果是最终结果,打印完整信息
|
|
||||||
if is_final and offline_text:
|
|
||||||
print("\n最终识别结果:")
|
|
||||||
print(f"[离线识别] {offline_text}")
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"接收结果时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""主函数"""
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
# WebSocket URI
|
|
||||||
protocol = "wss" if args.use_ssl else "ws"
|
|
||||||
uri = f"{protocol}://{args.host}:{args.port}"
|
|
||||||
|
|
||||||
print(f"连接到服务器: {uri}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建WebSocket连接
|
|
||||||
async with websockets.connect(
|
|
||||||
uri,
|
|
||||||
subprotocols=["binary"]
|
|
||||||
) as websocket:
|
|
||||||
|
|
||||||
print("连接成功")
|
|
||||||
|
|
||||||
# 创建两个任务: 发送音频和接收结果
|
|
||||||
send_task = asyncio.create_task(
|
|
||||||
send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
receive_task = asyncio.create_task(
|
|
||||||
receive_results(websocket)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 等待任务完成
|
|
||||||
await asyncio.gather(send_task, receive_task)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"连接服务器失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 运行主函数
|
|
||||||
asyncio.run(main())
|
|
@ -1,11 +1,25 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
|
默认配置DefaultConfig
|
||||||
|
- audio_config: 音频配置
|
||||||
配置模块 - 处理命令行参数和配置项
|
配置模块 - 处理命令行参数和配置项
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
|
||||||
|
class DefaultConfig:
|
||||||
|
"""
|
||||||
|
默认配置
|
||||||
|
"""
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
chunk_size=200,
|
||||||
|
chunk_stride=1600,
|
||||||
|
sample_rate=16000,
|
||||||
|
sample_width=16,
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""
|
"""
|
||||||
@ -21,41 +35,23 @@ def parse_args():
|
|||||||
"--host",
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
default="0.0.0.0",
|
default="0.0.0.0",
|
||||||
help="服务器主机地址,例如:localhost, 0.0.0.0"
|
help="服务器主机地址,例如:localhost, 0.0.0.0",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
default=10095,
|
|
||||||
help="WebSocket服务器端口"
|
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")
|
||||||
|
|
||||||
# SSL配置
|
# SSL配置
|
||||||
parser.add_argument(
|
parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
|
||||||
"--certfile",
|
parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="SSL证书文件路径"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--keyfile",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="SSL密钥文件路径"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ASR模型配置
|
# ASR模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model",
|
"--asr_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
help="离线ASR模型(从ModelScope获取)"
|
help="离线ASR模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model_revision",
|
"--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="离线ASR模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 在线ASR模型配置
|
# 在线ASR模型配置
|
||||||
@ -63,13 +59,13 @@ def parse_args():
|
|||||||
"--asr_model_online",
|
"--asr_model_online",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
||||||
help="在线ASR模型(从ModelScope获取)"
|
help="在线ASR模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model_online_revision",
|
"--asr_model_online_revision",
|
||||||
type=str,
|
type=str,
|
||||||
default="v2.0.4",
|
default="v2.0.4",
|
||||||
help="在线ASR模型版本"
|
help="在线ASR模型版本",
|
||||||
)
|
)
|
||||||
|
|
||||||
# VAD模型配置
|
# VAD模型配置
|
||||||
@ -77,13 +73,10 @@ def parse_args():
|
|||||||
"--vad_model",
|
"--vad_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
help="VAD语音活动检测模型(从ModelScope获取)"
|
help="VAD语音活动检测模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad_model_revision",
|
"--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="VAD模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 标点符号模型配置
|
# 标点符号模型配置
|
||||||
@ -91,34 +84,18 @@ def parse_args():
|
|||||||
"--punc_model",
|
"--punc_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
|
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
|
||||||
help="标点符号模型(从ModelScope获取)"
|
help="标点符号模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--punc_model_revision",
|
"--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="标点符号模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 硬件配置
|
# 硬件配置
|
||||||
|
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量,0表示仅使用CPU")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ngpu",
|
"--device", type=str, default="cuda", help="设备类型:cuda或cpu"
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="GPU数量,0表示仅使用CPU"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--device",
|
|
||||||
type=str,
|
|
||||||
default="cuda",
|
|
||||||
help="设备类型:cuda或cpu"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ncpu",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="CPU核心数"
|
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
3
src/core/__init__.py
Normal file
3
src/core/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .model_loader import ModelLoader
|
||||||
|
|
||||||
|
__all__ = ["ModelLoader"]
|
143
src/core/model_loader.py
Normal file
143
src/core/model_loader.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
模型加载模块 - 负责加载各种语音识别相关模型
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 导入FunASR库
|
||||||
|
from funasr import AutoModel
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
# 日志模块
|
||||||
|
from src.utils import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# 单例模式
|
||||||
|
class ModelLoader:
|
||||||
|
"""
|
||||||
|
ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中。
|
||||||
|
一般的, 可以直接call ModelLoader()来获取加载的模型。
|
||||||
|
也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
单例模式
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self, args=None):
|
||||||
|
"""
|
||||||
|
初始化ModelLoader实例
|
||||||
|
"""
|
||||||
|
self.models = {}
|
||||||
|
logger.debug("初始化ModelLoader")
|
||||||
|
if args is not None:
|
||||||
|
self.__call__(args)
|
||||||
|
|
||||||
|
def __call__(self, args=None):
|
||||||
|
"""
|
||||||
|
调用ModelLoader实例时, 如果模型字典为空, 则加载模型
|
||||||
|
"""
|
||||||
|
# 如果模型字典为空, 则加载模型
|
||||||
|
if self.models == {} or self.models is None:
|
||||||
|
if args.asr_model is not None:
|
||||||
|
self.models = self.load_models(args)
|
||||||
|
# 直接调用等于调用self.models
|
||||||
|
return self.models
|
||||||
|
|
||||||
|
def _load_model(self, input_model_args: dict, model_type: str):
|
||||||
|
"""
|
||||||
|
加载单个模型
|
||||||
|
|
||||||
|
参数:
|
||||||
|
model_args: 模型加载字典
|
||||||
|
model_type: 模型类型, 用于确定使用哪个模型参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
AutoModel: 加载的模型实例
|
||||||
|
"""
|
||||||
|
# 默认配置
|
||||||
|
default_config = {
|
||||||
|
"model": None,
|
||||||
|
"model_revision": None,
|
||||||
|
"ngpu": 0,
|
||||||
|
"ncpu": 1,
|
||||||
|
"device": "cpu",
|
||||||
|
"disable_pbar": True,
|
||||||
|
"disable_log": True,
|
||||||
|
"disable_update": True,
|
||||||
|
}
|
||||||
|
# 从args中获取配置, 如果存在则覆盖默认值
|
||||||
|
model_args = default_config.copy()
|
||||||
|
for key, value in default_config.items():
|
||||||
|
if key in ["model", "model_revision"]:
|
||||||
|
# 特殊处理model和model_revision, 因为它们需要model_type前缀
|
||||||
|
if key == "model":
|
||||||
|
value = input_model_args.get(f"{model_type}_model", None)
|
||||||
|
else:
|
||||||
|
value = input_model_args.get(f"{model_type}_model_revision", None)
|
||||||
|
else:
|
||||||
|
value = input_model_args.get(key, None)
|
||||||
|
if value is not None:
|
||||||
|
logger.debug("替换%s模型参数: %s = %s", model_type, key, value)
|
||||||
|
model_args[key] = value
|
||||||
|
# 验证必要参数
|
||||||
|
if not model_args["model"]:
|
||||||
|
raise ValueError(f"未指定{model_type}模型路径")
|
||||||
|
try:
|
||||||
|
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
|
||||||
|
logger.debug("正在加载%s模型: %s", model_type, model_args["model"])
|
||||||
|
model = AutoModel(**model_args)
|
||||||
|
return model
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("加载%s模型失败: %s", model_type, str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
# def _load_pipeline(self, input_model_args: dict, model_type: str):
|
||||||
|
# """
|
||||||
|
# 加载pipeline
|
||||||
|
# """
|
||||||
|
# default_pipeline = pipeline(
|
||||||
|
# task='speaker-verification',
|
||||||
|
# model='iic/speech_campplus_sv_zh-cn_16k-common',
|
||||||
|
# model_revision='v1.0.0'
|
||||||
|
# )
|
||||||
|
|
||||||
|
def load_models(self, args):
|
||||||
|
"""
|
||||||
|
加载所有需要的模型
|
||||||
|
参数:
|
||||||
|
args: 命令行参数, 包含模型配置
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict: 包含所有加载的模型的字典
|
||||||
|
"""
|
||||||
|
logger.info("ModelLoader加载模型")
|
||||||
|
# 初始化模型字典
|
||||||
|
self.models = {}
|
||||||
|
# 加载离线ASR模型
|
||||||
|
# 检查对应键是否存在
|
||||||
|
model_list = ["asr", "asr_online", "vad", "punc"]
|
||||||
|
for model_name in model_list:
|
||||||
|
name_model = f"{model_name}_model"
|
||||||
|
name_model_revision = f"{model_name}_model_revision"
|
||||||
|
if name_model in args:
|
||||||
|
logger.debug("加载%s模型", model_name)
|
||||||
|
self.models[model_name] = self._load_model(args, model_name)
|
||||||
|
pipeline_list = ["spk"]
|
||||||
|
for pipeline_name in pipeline_list:
|
||||||
|
if pipeline_name == "spk":
|
||||||
|
self.models[pipeline_name] = pipeline(
|
||||||
|
task='speaker-verification',
|
||||||
|
model='iic/speech_campplus_sv_zh-cn_16k-common',
|
||||||
|
model_revision='v1.0.0'
|
||||||
|
)
|
||||||
|
return self.models
|
4
src/functor/__init__.py
Normal file
4
src/functor/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .vad_functor import VADFunctor
|
||||||
|
from .base import FunctorFactory
|
||||||
|
|
||||||
|
__all__ = ["VADFunctor", "FunctorFactory"]
|
164
src/functor/asr_functor.py
Normal file
164
src/functor/asr_functor.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
"""
|
||||||
|
ASRFunctor
|
||||||
|
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
|
||||||
|
from typing import Callable, List
|
||||||
|
from queue import Queue, Empty
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ASRFunctor(BaseFunctor):
|
||||||
|
"""
|
||||||
|
ASRFunctor
|
||||||
|
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||||
|
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||||
|
否则无法run()启动线程
|
||||||
|
|
||||||
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||||
|
|
||||||
|
使用stop()停止线程, 但需要等待input_queue为空
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# 资源与配置
|
||||||
|
self._model: dict = {} # 模型
|
||||||
|
self._callback: List[Callable] = [] # 回调函数
|
||||||
|
self._input_queue: Queue = None # 输入队列
|
||||||
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
|
||||||
|
# flag
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
|
||||||
|
# 线程资源
|
||||||
|
self._thread: threading.Thread = None
|
||||||
|
|
||||||
|
# 状态锁
|
||||||
|
self._status_lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
|
# 缓存
|
||||||
|
self._hotwords: List[str] = []
|
||||||
|
|
||||||
|
def reset_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
"""
|
||||||
|
self._audio_config = audio_config
|
||||||
|
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
|
if not isinstance(self._callback, list):
|
||||||
|
self._callback = []
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def _do_callback(self, result: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
回调函数
|
||||||
|
"""
|
||||||
|
text = result[0]["text"].replace(" ", "")
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(text)
|
||||||
|
|
||||||
|
def _process(self, data: VAD_Functor_result) -> None:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
"""
|
||||||
|
binary_data = data.audiobinary_data.binary_data
|
||||||
|
result = self._model["asr"].generate(
|
||||||
|
input=binary_data,
|
||||||
|
chunk_size=self._audio_config.chunk_size,
|
||||||
|
hotwords=self._hotwords,
|
||||||
|
)
|
||||||
|
self._do_callback(result)
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
# 运行逻辑
|
||||||
|
while self._is_running:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(True, timeout=1)
|
||||||
|
if data is None:
|
||||||
|
break
|
||||||
|
logger.debug("[ASRFunctor]获取到的数据length: %s", len(data))
|
||||||
|
self._process(data)
|
||||||
|
self._input_queue.task_done()
|
||||||
|
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
except Empty:
|
||||||
|
if self._stop_event:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# 其他异常
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("ASRFunctor运行时发生错误: %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run(self) -> threading.Thread:
|
||||||
|
"""
|
||||||
|
启动线程
|
||||||
|
Returns:
|
||||||
|
threading.Thread: 返回已运行线程实例
|
||||||
|
"""
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
raise ValueError("模型未设置")
|
||||||
|
if self._audio_config is None:
|
||||||
|
raise ValueError("音频配置未设置")
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise ValueError("输入队列未设置")
|
||||||
|
if self._callback is None:
|
||||||
|
raise ValueError("回调函数未设置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self) -> bool:
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._stop_event = True
|
||||||
|
self._thread.join()
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
return not self._thread.is_alive()
|
208
src/functor/base.py
Normal file
208
src/functor/base.py
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
"""
|
||||||
|
Functor基础模块
|
||||||
|
|
||||||
|
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。
|
||||||
|
基类提供了数据处理的基本框架,包括:
|
||||||
|
- 回调函数管理
|
||||||
|
- 模型配置管理
|
||||||
|
- 线程运行控制
|
||||||
|
|
||||||
|
主要类:
|
||||||
|
BaseFunctor: Functor抽象类
|
||||||
|
FunctorFactory: Functor工厂类
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Callable, List, Dict
|
||||||
|
from queue import Queue
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
class BaseFunctor(ABC):
|
||||||
|
"""
|
||||||
|
Functor抽象类
|
||||||
|
|
||||||
|
该抽象类规定了所有的Functor类必须实现run()方法启动自身线程
|
||||||
|
|
||||||
|
属性:
|
||||||
|
_callback (Callable): 处理完成后的回调函数
|
||||||
|
_model (dict): 存储模型相关的配置和实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
初始化函数器
|
||||||
|
|
||||||
|
参数:
|
||||||
|
callback (Callable): 处理完成后的回调函数
|
||||||
|
model (dict): 模型相关的配置和实例
|
||||||
|
"""
|
||||||
|
self._callback: List[Callable] = []
|
||||||
|
self._model: dict = {}
|
||||||
|
# flag
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
# 状态锁
|
||||||
|
self._status_lock: threading.Lock = threading.Lock()
|
||||||
|
# 线程资源
|
||||||
|
self._thread: threading.Thread = None
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable):
|
||||||
|
"""
|
||||||
|
添加回调函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
callback (Callable): 新的回调函数
|
||||||
|
"""
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def set_model(self, model: dict):
|
||||||
|
"""
|
||||||
|
设置模型配置
|
||||||
|
|
||||||
|
参数:
|
||||||
|
model (dict): 新的模型配置
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue):
|
||||||
|
"""
|
||||||
|
设置输入队列
|
||||||
|
|
||||||
|
参数:
|
||||||
|
queue (Queue): 新的输入队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _run(self):
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
|
||||||
|
返回:
|
||||||
|
当达到条件时触发callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
启动_run方法线程
|
||||||
|
|
||||||
|
返回:
|
||||||
|
线程实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _pre_check(self):
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
|
||||||
|
返回:
|
||||||
|
预检查结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
|
||||||
|
返回:
|
||||||
|
停止结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class FunctorFactory:
|
||||||
|
"""
|
||||||
|
Functor工厂类
|
||||||
|
|
||||||
|
该工厂类负责创建和配置Functor实例
|
||||||
|
|
||||||
|
主要方法:
|
||||||
|
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||||
|
创建并配置Functor实例
|
||||||
|
"""
|
||||||
|
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建VAD Functor实例
|
||||||
|
"""
|
||||||
|
from src.functor.vad_functor import VADFunctor
|
||||||
|
|
||||||
|
audio_config = config["audio_config"]
|
||||||
|
model = {"vad": models["vad"]}
|
||||||
|
|
||||||
|
vad_functor = VADFunctor()
|
||||||
|
vad_functor.set_audio_config(audio_config)
|
||||||
|
vad_functor.set_model(model)
|
||||||
|
|
||||||
|
return vad_functor
|
||||||
|
|
||||||
|
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建ASR Functor实例
|
||||||
|
"""
|
||||||
|
from src.functor.asr_functor import ASRFunctor
|
||||||
|
|
||||||
|
audio_config = config["audio_config"]
|
||||||
|
model = {"asr": models["asr"]}
|
||||||
|
|
||||||
|
asr_functor = ASRFunctor()
|
||||||
|
asr_functor.set_audio_config(audio_config)
|
||||||
|
asr_functor.set_model(model)
|
||||||
|
|
||||||
|
return asr_functor
|
||||||
|
|
||||||
|
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建SPK Functor实例
|
||||||
|
"""
|
||||||
|
from src.functor.spk_functor import SPKFunctor
|
||||||
|
|
||||||
|
logger.debug(f"创建spk functor[开始]")
|
||||||
|
audio_config = config["audio_config"]
|
||||||
|
# model = {"spk": models["spk"]}
|
||||||
|
|
||||||
|
spk_functor = SPKFunctor(sv_pipeline=models["spk"])
|
||||||
|
spk_functor.set_audio_config(audio_config)
|
||||||
|
# spk_functor.set_model(model)
|
||||||
|
spk_functor.bake()
|
||||||
|
|
||||||
|
logger.debug(f"创建spk functor[完成]")
|
||||||
|
return spk_functor
|
||||||
|
|
||||||
|
def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建ResultBinder Functor实例
|
||||||
|
"""
|
||||||
|
from src.functor.resultbinder_functor import ResultBinderFunctor
|
||||||
|
|
||||||
|
resultbinder_functor = ResultBinderFunctor()
|
||||||
|
|
||||||
|
return resultbinder_functor
|
||||||
|
|
||||||
|
factory_dict: Dict[str, Callable] = {
|
||||||
|
"vad": _make_vadfunctor,
|
||||||
|
"asr": _make_asrfunctor,
|
||||||
|
"spk": _make_spkfunctor,
|
||||||
|
"resultbinder": _make_resultbinderfunctor,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建并配置Functor实例
|
||||||
|
|
||||||
|
参数:
|
||||||
|
funtor_name (str): Functor名称
|
||||||
|
config (dict): 配置信息
|
||||||
|
models (dict): 模型信息
|
||||||
|
|
||||||
|
返回:
|
||||||
|
BaseFunctor: 创建的Functor实例
|
||||||
|
"""
|
||||||
|
if functor_name in cls.factory_dict:
|
||||||
|
return cls.factory_dict[functor_name](config=config, models=models)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的Functor类型: {functor_name}")
|
122
src/functor/readme.md
Normal file
122
src/functor/readme.md
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# 对于Functor的解释
|
||||||
|
|
||||||
|
## Functor 文件夹作用
|
||||||
|
|
||||||
|
Functor文件夹用于存放所有功能性的类,包括VAD、PUNC、ASR、SPK等。
|
||||||
|
|
||||||
|
## Functor 类的定义
|
||||||
|
|
||||||
|
所有类应继承于**基类**`BaseFunctor`。
|
||||||
|
|
||||||
|
为了方便使用,我们对于**基类**的定义如下:
|
||||||
|
|
||||||
|
1. 函数内部使用的变量以单下划线开头,基类中包含:
|
||||||
|
|
||||||
|
* _model: Dict 存放模型相关的配置和实例
|
||||||
|
* _input_queue: Queue 监听的输入消息队列
|
||||||
|
* _thread: Threading.Thread 运行的线程实例
|
||||||
|
* _callback: List[Callable] 回调函数列表
|
||||||
|
* _is_running: bool 线程运行状态标志
|
||||||
|
* _stop_event: bool 停止事件标志
|
||||||
|
* _status_lock: threading.Lock 状态锁,用于线程同步
|
||||||
|
|
||||||
|
2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`。
|
||||||
|
|
||||||
|
3. 基类定义的核心方法:
|
||||||
|
|
||||||
|
* `add_callback(callback: Callable)`: 添加结果处理的回调函数
|
||||||
|
* `set_model(model: dict)`: 设置模型配置和实例
|
||||||
|
* `set_input_queue(queue: Queue)`: 设置输入数据队列
|
||||||
|
* `run()`: 启动处理线程(抽象方法)
|
||||||
|
* `stop()`: 停止处理线程(抽象方法)
|
||||||
|
* `_run()`: 线程运行的具体逻辑(抽象方法)
|
||||||
|
* `_pre_check()`: 运行前的预检查(抽象方法)
|
||||||
|
|
||||||
|
## 派生类实现要求
|
||||||
|
|
||||||
|
1. 必须实现的抽象方法:
|
||||||
|
* `_pre_check()`:
|
||||||
|
- 检查必要的配置是否完整(如模型、队列等)
|
||||||
|
- 检查运行环境是否满足要求
|
||||||
|
- 返回检查结果
|
||||||
|
|
||||||
|
* `_run()`:
|
||||||
|
- 实现具体的数据处理逻辑
|
||||||
|
- 从 _input_queue 获取输入数据
|
||||||
|
- 使用 _model 进行处理
|
||||||
|
- 通过 _callback 返回处理结果
|
||||||
|
|
||||||
|
* `run()`:
|
||||||
|
- 调用 _pre_check() 进行预检查
|
||||||
|
- 创建并启动处理线程
|
||||||
|
- 设置相关状态标志
|
||||||
|
|
||||||
|
* `stop()`:
|
||||||
|
- 安全停止处理线程
|
||||||
|
- 清理资源
|
||||||
|
- 重置状态标志
|
||||||
|
|
||||||
|
2. 建议实现的方法:
|
||||||
|
* `__str__`: 返回当前实例的状态信息
|
||||||
|
* 错误处理方法:处理运行过程中的异常情况
|
||||||
|
|
||||||
|
## 使用示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyFunctor(BaseFunctor):
|
||||||
|
def _pre_check(self):
|
||||||
|
if not self._model or not self._input_queue:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
while not self._stop_event:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(timeout=1.0)
|
||||||
|
result = self._model['my_model'].process(data)
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(result)
|
||||||
|
except Queue.Empty:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理错误: {e}")
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
if not self._pre_check():
|
||||||
|
raise RuntimeError("预检查失败")
|
||||||
|
|
||||||
|
with self._status_lock:
|
||||||
|
if self._is_running:
|
||||||
|
return
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
self._thread = threading.Thread(target=self._run)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
with self._status_lock:
|
||||||
|
if not self._is_running:
|
||||||
|
return
|
||||||
|
self._stop_event = True
|
||||||
|
if self._thread:
|
||||||
|
self._thread.join()
|
||||||
|
self._is_running = False
|
||||||
|
```
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. 线程安全:
|
||||||
|
* 使用 _status_lock 保护状态变更
|
||||||
|
* 注意共享资源的访问控制
|
||||||
|
|
||||||
|
2. 错误处理:
|
||||||
|
* 在 _run() 中妥善处理异常
|
||||||
|
* 提供详细的错误日志
|
||||||
|
|
||||||
|
3. 资源管理:
|
||||||
|
* 确保在 stop() 中正确清理资源
|
||||||
|
* 避免资源泄露
|
||||||
|
|
||||||
|
4. 回调函数:
|
||||||
|
* 回调函数应该是非阻塞的
|
||||||
|
* 处理回调函数抛出的异常
|
165
src/functor/resultbinder_functor.py
Normal file
165
src/functor/resultbinder_functor.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
"""
|
||||||
|
ResultBinderFunctor
|
||||||
|
负责聚合结果,将所有input_queue中的结果进行聚合,并进行callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
from src.models import AudioBinary_Config, VAD_Functor_result
|
||||||
|
from typing import Callable, List, Dict, Any
|
||||||
|
from queue import Queue, Empty
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
# 日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ResultBinderFunctor(BaseFunctor):
|
||||||
|
"""
|
||||||
|
ResultBinderFunctor
|
||||||
|
负责聚合结果,将所有input_queue中的结果进行聚合,并进行callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# 资源与配置
|
||||||
|
self._callback: List[Callable] = [] # 回调函数
|
||||||
|
self._input_queue: Dict[str, Queue] = {} # 输入队列
|
||||||
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
|
||||||
|
def reset_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def add_input_queue(self,
|
||||||
|
name: str,
|
||||||
|
queue: Queue,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
|
self._input_queue[name] = queue
|
||||||
|
|
||||||
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
resultbinder_functor 不应设置模型
|
||||||
|
"""
|
||||||
|
logger.warning("ResultBinderFunctor不应设置模型")
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
resultbinder_functor 不应设置音频配置
|
||||||
|
"""
|
||||||
|
logger.warning("ResultBinderFunctor不应设置音频配置")
|
||||||
|
self._audio_config = audio_config
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
|
if not isinstance(self._callback, list):
|
||||||
|
self._callback = []
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def _do_callback(self, result: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
回调函数
|
||||||
|
"""
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(result)
|
||||||
|
|
||||||
|
def _process(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
{
|
||||||
|
"is_final": false,
|
||||||
|
"mode": "2pass-offline",
|
||||||
|
"text": ",等一下我回一下,ok你看这里就有了",
|
||||||
|
"wav_name": "h5",
|
||||||
|
"speaker_id":
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
logger.debug("ResultBinderFunctor处理数据: %s", data)
|
||||||
|
# 将data中的result进行聚合
|
||||||
|
# 此步暂时无意义,预留
|
||||||
|
results = {
|
||||||
|
"is_final": False,
|
||||||
|
"mode": "2pass-offline",
|
||||||
|
"text": data["asr"],
|
||||||
|
"wav_name": "h5",
|
||||||
|
"speaker_id": data["spk"]["speaker_id"]
|
||||||
|
}
|
||||||
|
|
||||||
|
# for name, result in data.items():
|
||||||
|
# results[name] = result
|
||||||
|
self._do_callback(results)
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
# 运行逻辑
|
||||||
|
while self._is_running:
|
||||||
|
try:
|
||||||
|
# 若有队列为空,则等待0.1s
|
||||||
|
for name, queue in self._input_queue.items():
|
||||||
|
if queue.empty():
|
||||||
|
time.sleep(0.1)
|
||||||
|
raise Empty
|
||||||
|
data = {}
|
||||||
|
for name, queue in self._input_queue.items():
|
||||||
|
data[name] = queue.get(True, timeout=1)
|
||||||
|
queue.task_done()
|
||||||
|
self._process(data)
|
||||||
|
# 当队列为空时, 检测是否进入停止事件。
|
||||||
|
except Empty:
|
||||||
|
if self._stop_event:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# 其他异常
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("SpkFunctor运行时发生错误: %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run(self) -> threading.Thread:
|
||||||
|
"""
|
||||||
|
启动线程
|
||||||
|
Returns:
|
||||||
|
threading.Thread: 返回已运行线程实例
|
||||||
|
"""
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
raise ValueError("模型未设置")
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise ValueError("输入队列未设置")
|
||||||
|
if self._callback is None:
|
||||||
|
raise ValueError("回调函数未设置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self) -> bool:
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._stop_event = True
|
||||||
|
self._thread.join()
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
return not self._thread.is_alive()
|
367
src/functor/spk_functor.py
Normal file
367
src/functor/spk_functor.py
Normal file
@ -0,0 +1,367 @@
|
|||||||
|
"""
|
||||||
|
SpkFunctor
|
||||||
|
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
from src.models import AudioBinary_Config, VAD_Functor_result, SpeakerCreate
|
||||||
|
from typing import Callable, List, Dict
|
||||||
|
from queue import Queue, Empty
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import threading
|
||||||
|
import numpy
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
# sv_pipeline = pipeline(
|
||||||
|
# task='speaker-verification',
|
||||||
|
# model='iic/speech_campplus_sv_zh-cn_16k-common',
|
||||||
|
# model_revision='v1.0.0'
|
||||||
|
# )
|
||||||
|
# speaker1_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_a_cn_16k.wav'
|
||||||
|
# speaker1_b_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_b_cn_16k.wav'
|
||||||
|
# speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
|
||||||
|
# # 相同说话人语音
|
||||||
|
# result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
|
||||||
|
# print(result)
|
||||||
|
# # 不同说话人语音
|
||||||
|
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
|
||||||
|
# print(result)
|
||||||
|
# # 可以自定义得分阈值来进行识别,阈值越高,判定为同一人的条件越严格
|
||||||
|
# result = sv_pipeline([speaker1_a_wav, speaker1_a_wav], thr=0.6)
|
||||||
|
# print(result)
|
||||||
|
# # 可以传入output_emb参数,输出结果中就会包含提取到的说话人embedding
|
||||||
|
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
|
||||||
|
# print(result['embs'], result['outputs'])
|
||||||
|
# # 可以传入save_dir参数,提取到的说话人embedding会存储在save_dir目录中
|
||||||
|
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
|
||||||
|
|
||||||
|
class SPKFunctor(BaseFunctor):
|
||||||
|
"""
|
||||||
|
SPKFunctor
|
||||||
|
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
||||||
|
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||||
|
否则无法run()启动线程
|
||||||
|
|
||||||
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||||
|
|
||||||
|
使用stop()停止线程, 但需要等待input_queue为空
|
||||||
|
"""
|
||||||
|
class speaker_verify:
|
||||||
|
"""
|
||||||
|
说话人验证
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._spk_data: List[SpeakerCreate] = []
|
||||||
|
|
||||||
|
def add_speaker(self, speaker: SpeakerCreate) -> None:
|
||||||
|
self._spk_data.append(speaker)
|
||||||
|
# logger.debug("添加说话人: %s", speaker)
|
||||||
|
|
||||||
|
def verify(self, emb: numpy.ndarray) -> Dict:
|
||||||
|
# 将输入的numpy embedding转换为tensor
|
||||||
|
input_emb_tensor = torch.from_numpy(emb).unsqueeze(0)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for speaker in self._spk_data:
|
||||||
|
if not speaker.speaker_embs:
|
||||||
|
continue
|
||||||
|
# 将列表中存储的embedding转换为tensor
|
||||||
|
speaker_emb_tensor = torch.tensor(speaker.speaker_embs).unsqueeze(0)
|
||||||
|
|
||||||
|
# 计算余弦相似度
|
||||||
|
score = torch.nn.functional.cosine_similarity(input_emb_tensor, speaker_emb_tensor).item()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"score": score,
|
||||||
|
"speaker_id": speaker.speaker_id,
|
||||||
|
"speaker_name": speaker.speaker_name,
|
||||||
|
"speaker_description": speaker.speaker_description
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {
|
||||||
|
"speaker_id": "unknown",
|
||||||
|
"speaker_name": "Unknown",
|
||||||
|
"speaker_description": "No registered speakers to verify against.",
|
||||||
|
"score": 0.0,
|
||||||
|
"results": []
|
||||||
|
}
|
||||||
|
|
||||||
|
results.sort(key=lambda x: x['score'], reverse=True)
|
||||||
|
best_match = results[0]
|
||||||
|
return {
|
||||||
|
"speaker_id": best_match['speaker_id'],
|
||||||
|
"speaker_name": best_match['speaker_name'],
|
||||||
|
"speaker_description": best_match['speaker_description'],
|
||||||
|
"score": best_match['score'],
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, sv_pipeline: pipeline) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# 资源与配置
|
||||||
|
self._spk_verify = self.speaker_verify()
|
||||||
|
self._sv_pipeline = sv_pipeline
|
||||||
|
self._model: dict = {} # 模型
|
||||||
|
self._callback: List[Callable] = [] # 回调函数
|
||||||
|
self._input_queue: Queue = None # 输入队列
|
||||||
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
|
||||||
|
# logger.debug("加载本地说话人数据")
|
||||||
|
# self.load_spk_data_local()
|
||||||
|
# import inspect
|
||||||
|
# func_name = inspect.currentframe().f_code.co_name
|
||||||
|
# logger.info(f"{func_name} 加载数据库说话人数据")
|
||||||
|
self._speakers_url = os.getenv("SPEAKERS_URL", None)
|
||||||
|
|
||||||
|
def bake(self, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
烘焙
|
||||||
|
"""
|
||||||
|
if self._speakers_url is None:
|
||||||
|
logger.error("未提供说话人数据库API的URL")
|
||||||
|
return
|
||||||
|
self.load_spk_database_api(
|
||||||
|
url=kwargs.get("speakers_url", self._speakers_url)
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_spk_database_api(
|
||||||
|
self,
|
||||||
|
url: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
加载api远程数据库说话人数据
|
||||||
|
"""
|
||||||
|
# 网络请求后端数据库拿到所有的说话人数据
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = url if url != None else self._speakers_url
|
||||||
|
if not url:
|
||||||
|
logger.error("未提供说话人数据库API的URL")
|
||||||
|
return
|
||||||
|
logger.debug("加载API说话人数据库: url: %s", url)
|
||||||
|
response = requests.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
items = data.get("items", [])
|
||||||
|
logger.debug("加载API说话人个数: %s", len(items))
|
||||||
|
for spk in items:
|
||||||
|
logger.debug("加载API说话人数据: %s %s", spk.get('speaker_name'), spk.get('speaker_id'))
|
||||||
|
# 兼容API返回的字段与本地字段
|
||||||
|
spk_dict = {
|
||||||
|
"speaker_name": spk.get("speaker_name", ""),
|
||||||
|
"speaker_id": spk.get("speaker_id", ""),
|
||||||
|
"speaker_description": spk.get("speaker_description", ""),
|
||||||
|
"avatar": spk.get("avatar", ""),
|
||||||
|
"wav_path": "", # API未提供本地wav路径
|
||||||
|
"speaker_embs": spk.get("audio_feature_vector", []),
|
||||||
|
}
|
||||||
|
# 如果API没有embs但有音频样本,可以尝试提取
|
||||||
|
if (not spk_dict["speaker_embs"] or spk_dict["speaker_embs"] is None) and spk.get("audio_sample"):
|
||||||
|
try:
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import numpy as np
|
||||||
|
from pydub import AudioSegment
|
||||||
|
|
||||||
|
audio_bytes = base64.b64decode(spk["audio_sample"])
|
||||||
|
|
||||||
|
# 使用 pydub 处理音频,它能兼容 wav, mp3 等多种格式 (需要 ffmpeg)
|
||||||
|
audio_segment = AudioSegment.from_file(io.BytesIO(audio_bytes))
|
||||||
|
|
||||||
|
# 将音频转换为模型期望的格式,通常是 16k 采样率、单声道
|
||||||
|
audio_segment = audio_segment.set_frame_rate(16000)
|
||||||
|
audio_segment = audio_segment.set_channels(1)
|
||||||
|
|
||||||
|
# 转换为 float32 的 numpy 数组并归一化
|
||||||
|
wav_data = np.array(audio_segment.get_array_of_samples()).astype(np.float32)
|
||||||
|
wav_data /= 32768.0 # 16-bit signed audio is in range [-32768, 32767]
|
||||||
|
|
||||||
|
spk_dict["speaker_embs"] = self._sv_pipeline([wav_data], output_emb=True)['embs'][0]
|
||||||
|
logger.debug("API音频样本转换为embs: %s", spk_dict["speaker_name"])
|
||||||
|
except Exception:
|
||||||
|
logger.error("API音频样本 '%s' 转换为embs失败", spk_dict.get("speaker_name", "Unknown"), exc_info=True)
|
||||||
|
spk_dict["speaker_embs"] = []
|
||||||
|
|
||||||
|
# 转为numpy后加入
|
||||||
|
try:
|
||||||
|
import numpy
|
||||||
|
spk_dict["speaker_embs"] = numpy.array(spk_dict["speaker_embs"])
|
||||||
|
self._spk_verify.add_speaker(SpeakerCreate(**spk_dict))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("添加API说话人到本地失败: %s", e)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("请求说话人数据库API失败: %s", e)
|
||||||
|
|
||||||
|
def load_spk_data_local(
|
||||||
|
self,
|
||||||
|
spk_data_path: str = 'data/speakers.json',
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
加载本地说话人数据
|
||||||
|
"""
|
||||||
|
with open(spk_data_path, 'r') as f:
|
||||||
|
spk_data = json.load(f)
|
||||||
|
for i, spk in enumerate(spk_data):
|
||||||
|
logger.debug("加载本地说话人数据: %s %s", spk['speaker_name'], spk['speaker_id'])
|
||||||
|
if spk['speaker_embs'] == "" and spk['wav_path'] != "":
|
||||||
|
logger.debug("尝试转换本地wav为embs: %s", spk['wav_path'])
|
||||||
|
try:
|
||||||
|
# 读取数据为numpy数组
|
||||||
|
import soundfile as sf
|
||||||
|
import numpy as np
|
||||||
|
wav_data, sr = sf.read(spk['wav_path'], dtype='int16')
|
||||||
|
# 确保是单通道
|
||||||
|
if wav_data.ndim > 1:
|
||||||
|
wav_data = wav_data[:, 0]
|
||||||
|
# 转为numpy数组后送入pipeline
|
||||||
|
spk['speaker_embs'] = self._sv_pipeline([wav_data], output_emb=True)['embs'][0]
|
||||||
|
logger.debug("转换本地wav为embs: length=%s type=%s", len(spk['speaker_embs']), type(spk['speaker_embs']))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("转换本地wav为embs失败: %s", e)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
# logger.debug("加载本地说话人数据: %s %s", spk['speaker_name'], spk['speaker_id'])
|
||||||
|
# 将spk的speaker_embs转换为numpy
|
||||||
|
spk['speaker_embs'] = numpy.array(spk['speaker_embs'])
|
||||||
|
self._spk_verify.add_speaker(SpeakerCreate(**spk))
|
||||||
|
spk['speaker_embs'] = spk['speaker_embs'].tolist()
|
||||||
|
spk_data[i] = spk
|
||||||
|
# 保存更新后的数据
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(spk_data_path, 'w') as f:
|
||||||
|
json.dump(spk_data, f, indent=4)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("保存更新后的数据失败: %s", e)
|
||||||
|
|
||||||
|
def reset_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_speakers_url(self, url: str) -> None:
|
||||||
|
"""
|
||||||
|
设置说话人数据库API的URL
|
||||||
|
"""
|
||||||
|
self._speakers_url = url
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
"""
|
||||||
|
self._audio_config = audio_config
|
||||||
|
logger.debug("SpkFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
|
if not isinstance(self._callback, list):
|
||||||
|
self._callback = []
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def _do_callback(self, result: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
回调函数
|
||||||
|
"""
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(result)
|
||||||
|
|
||||||
|
def _process(self, data: VAD_Functor_result) -> None:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
"""
|
||||||
|
binary_data = data.audiobinary_data.binary_data
|
||||||
|
# result = self._model["spk"].generate(
|
||||||
|
# input=binary_data,
|
||||||
|
# chunk_size=self._audio_config.chunk_size,
|
||||||
|
# )
|
||||||
|
sv_result = self._sv_pipeline([binary_data], output_emb=True)
|
||||||
|
embs = sv_result['embs'][0]
|
||||||
|
|
||||||
|
result = self._spk_verify.verify(embs)
|
||||||
|
|
||||||
|
self._do_callback(result)
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
# 运行逻辑
|
||||||
|
while self._is_running:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(True, timeout=1)
|
||||||
|
if data is None:
|
||||||
|
break
|
||||||
|
logger.debug("[SPKFunctor]获取到的数据length: %s", len(data))
|
||||||
|
self._process(data)
|
||||||
|
self._input_queue.task_done()
|
||||||
|
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
except Empty:
|
||||||
|
if self._stop_event:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# 其他异常
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("SpkFunctor运行时发生错误: %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run(self) -> threading.Thread:
|
||||||
|
"""
|
||||||
|
启动线程
|
||||||
|
Returns:
|
||||||
|
threading.Thread: 返回已运行线程实例
|
||||||
|
"""
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
raise ValueError("模型未设置")
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise ValueError("输入队列未设置")
|
||||||
|
if self._callback is None:
|
||||||
|
raise ValueError("回调函数未设置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self) -> bool:
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._stop_event = True
|
||||||
|
self._thread.join()
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
return not self._thread.is_alive()
|
276
src/functor/vad_functor.py
Normal file
276
src/functor/vad_functor.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
"""
|
||||||
|
VADFunctor
|
||||||
|
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from queue import Empty, Queue
|
||||||
|
from typing import List, Any, Callable
|
||||||
|
import numpy
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from src.models import (
|
||||||
|
VAD_Functor_result,
|
||||||
|
AudioBinary_Config,
|
||||||
|
AudioBinary_data_list,
|
||||||
|
)
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VADFunctor(BaseFunctor):
|
||||||
|
"""
|
||||||
|
VADFunctor
|
||||||
|
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||||
|
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
|
||||||
|
否则无法run()启动线程
|
||||||
|
|
||||||
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||||
|
|
||||||
|
使用stop()停止线程, 但需要等待input_queue为空
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# 资源与配置
|
||||||
|
self._model: dict = {} # 模型
|
||||||
|
self._callback: List[Callable] = [] # 回调函数
|
||||||
|
self._input_queue: Queue = None # 输入队列
|
||||||
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
|
||||||
|
|
||||||
|
# flag
|
||||||
|
# 此处用到两个锁,但都是为了截断_run线程,考虑后续优化
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
|
||||||
|
# 线程资源
|
||||||
|
self._thread: threading.Thread = None
|
||||||
|
|
||||||
|
# 状态锁
|
||||||
|
self._status_lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
|
# 缓存
|
||||||
|
self._audio_cache: numpy.ndarray = None
|
||||||
|
self._audio_cache_preindex: int = 0
|
||||||
|
self._model_cache: dict = {}
|
||||||
|
self._cache_result_list = []
|
||||||
|
self._audiobinary_cache = None
|
||||||
|
|
||||||
|
def reset_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
|
"""
|
||||||
|
self._audio_cache = None
|
||||||
|
self._audio_cache_preindex = 0
|
||||||
|
self._model_cache = {}
|
||||||
|
self._cache_result_list = []
|
||||||
|
self._audiobinary_cache = None
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
"""
|
||||||
|
self._audio_config = audio_config
|
||||||
|
logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
|
def set_audio_binary_data_list(
|
||||||
|
self, audio_binary_data_list: AudioBinary_data_list
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
设置音频数据列表, 为Class AudioBinary_data_list类型
|
||||||
|
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
|
||||||
|
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
|
||||||
|
"""
|
||||||
|
self._audio_binary_data_list = audio_binary_data_list
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
|
if not isinstance(self._callback, list):
|
||||||
|
self._callback = []
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def _do_callback(self, result: List[List[int]]) -> None:
|
||||||
|
"""
|
||||||
|
回调函数
|
||||||
|
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||||
|
|
||||||
|
输入:
|
||||||
|
result: List[[start,end]] 处理所得VAD端点
|
||||||
|
其中若start==-1, 则表示前无端点, 若end==-1, 则表示后无端点
|
||||||
|
当处理得到一个完成片段时, 存入AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||||
|
输出:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# 持久化缓存结果队列
|
||||||
|
for pair in result:
|
||||||
|
[start, end] = pair
|
||||||
|
# 若无前端点, 则向缓存队列中合并
|
||||||
|
if start == -1:
|
||||||
|
self._cache_result_list[-1][1] = end
|
||||||
|
else:
|
||||||
|
self._cache_result_list.append(pair)
|
||||||
|
logger.debug(f"VADFunctor结果: {self._cache_result_list}")
|
||||||
|
while len(self._cache_result_list) > 0 and self._cache_result_list[0][1] != -1:
|
||||||
|
logger.debug(f"VADFunctor结果: {self._cache_result_list}")
|
||||||
|
logger.debug(f"VADFunctor list[0][1]: {self._cache_result_list[0][1]}")
|
||||||
|
# 创建VAD片段
|
||||||
|
# 计算开始帧
|
||||||
|
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
|
||||||
|
start_frame -= self._audio_cache_preindex
|
||||||
|
# 计算结束帧
|
||||||
|
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
|
||||||
|
end_frame -= self._audio_cache_preindex
|
||||||
|
# 计算开始时间
|
||||||
|
timestamp = time.time()
|
||||||
|
format_time = datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
logger.debug(f"{format_time}创建VAD片段: {start_frame} - {end_frame}")
|
||||||
|
vad_result = VAD_Functor_result.create_from_push_data(
|
||||||
|
audiobinary_data_list=self._audio_binary_data_list,
|
||||||
|
data=self._audiobinary_cache[start_frame:end_frame],
|
||||||
|
start_time=self._cache_result_list[0][0],
|
||||||
|
end_time=self._cache_result_list[0][1],
|
||||||
|
)
|
||||||
|
logger.debug(f"{format_time}创建VAD片段成功: {vad_result}")
|
||||||
|
self._audio_cache_preindex += end_frame
|
||||||
|
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(vad_result)
|
||||||
|
self._cache_result_list.pop(0)
|
||||||
|
|
||||||
|
def _predeal_data(self, data: Any) -> None:
|
||||||
|
"""
|
||||||
|
预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
|
||||||
|
"""
|
||||||
|
if self._audio_cache is None:
|
||||||
|
self._audio_cache = data
|
||||||
|
else:
|
||||||
|
# 拼接音频数据
|
||||||
|
if isinstance(self._audio_cache, numpy.ndarray):
|
||||||
|
self._audio_cache = numpy.concatenate((self._audio_cache, data))
|
||||||
|
elif isinstance(self._audio_cache, list):
|
||||||
|
self._audio_cache.append(data)
|
||||||
|
|
||||||
|
if self._audiobinary_cache is None:
|
||||||
|
self._audiobinary_cache = data
|
||||||
|
else:
|
||||||
|
# 拼接音频数据
|
||||||
|
if isinstance(self._audiobinary_cache, numpy.ndarray):
|
||||||
|
self._audiobinary_cache = numpy.concatenate(
|
||||||
|
(self._audiobinary_cache, data)
|
||||||
|
)
|
||||||
|
elif isinstance(self._audiobinary_cache, list):
|
||||||
|
self._audiobinary_cache.append(data)
|
||||||
|
|
||||||
|
def _process(self, data: Any):
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
使用model进行生成, 并使用_do_callback进行回调
|
||||||
|
"""
|
||||||
|
if data is None:
|
||||||
|
result = self._model["vad"].generate(
|
||||||
|
input=self._audio_cache,
|
||||||
|
cache=self._model_cache,
|
||||||
|
chunk_size=self._audio_config.chunk_size,
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
self._do_callback(result[0]["value"])
|
||||||
|
return
|
||||||
|
|
||||||
|
self._predeal_data(data)
|
||||||
|
if len(self._audio_cache) >= self._audio_config.chunk_stride:
|
||||||
|
result = self._model["vad"].generate(
|
||||||
|
input=self._audio_cache,
|
||||||
|
cache=self._model_cache,
|
||||||
|
chunk_size=self._audio_config.chunk_size,
|
||||||
|
max_end_silence_time = 300,
|
||||||
|
is_final=False,
|
||||||
|
)
|
||||||
|
if len(result[0]["value"]) > 0:
|
||||||
|
self._do_callback(result[0]["value"])
|
||||||
|
logger.debug(f"VADFunctor结果: {result[0]['value']}")
|
||||||
|
self._audio_cache = None
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
监听输入队列, 当有数据时, 处理数据
|
||||||
|
当输入队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
"""
|
||||||
|
# 刷新运行状态
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
# 运行逻辑
|
||||||
|
while self._is_running:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(True, timeout=1)
|
||||||
|
# logger.debug("[VADFunctor]获取到的数据length: %s", len(data))
|
||||||
|
self._process(data)
|
||||||
|
self._input_queue.task_done()
|
||||||
|
if data is None:
|
||||||
|
break
|
||||||
|
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
except Empty:
|
||||||
|
if self._stop_event:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# 其他异常
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("VADFunctor运行时发生错误: %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
启动 _run 线程, 并返回线程对象
|
||||||
|
"""
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
检测硬性资源是否都已设置
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
raise ValueError("模型未设置")
|
||||||
|
if self._audio_config is None:
|
||||||
|
raise ValueError("音频配置未设置")
|
||||||
|
if self._audio_binary_data_list is None:
|
||||||
|
raise ValueError("音频数据列表未设置")
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise ValueError("输入队列未设置")
|
||||||
|
if self._callback is None:
|
||||||
|
raise ValueError("回调函数未设置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
通过设置_stop_event为True, 来在input_queue.get()循环为空时退出
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._stop_event = True
|
||||||
|
self._thread.join()
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
return not self._thread.is_alive()
|
176
src/logic_trager.py
Normal file
176
src/logic_trager.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
逻辑触发器类 - 用于处理音频数据并触发相应的处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
from typing import Any, Dict, Type, Callable
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoAfterMeta(type):
|
||||||
|
"""
|
||||||
|
自动调用__after__函数的元类
|
||||||
|
实现单例模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instances: Dict[Type, Any] = {} # 存储单例实例
|
||||||
|
|
||||||
|
def __new__(cls, name, bases, attrs):
|
||||||
|
# 遍历所有属性
|
||||||
|
for attr_name, attr_value in attrs.items():
|
||||||
|
# 如果是函数且不是以_开头
|
||||||
|
if callable(attr_value) and not attr_name.startswith("__"):
|
||||||
|
# 获取原函数
|
||||||
|
original_func = attr_value
|
||||||
|
|
||||||
|
# 创建包装函数
|
||||||
|
def make_wrapper(func):
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
# 执行原函数
|
||||||
|
result = func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
# 构建_after_函数名
|
||||||
|
after_func_name = f"__after__{func.__name__}"
|
||||||
|
|
||||||
|
# 检查是否存在对应的_after_函数
|
||||||
|
if hasattr(self, after_func_name):
|
||||||
|
after_func = getattr(self, after_func_name)
|
||||||
|
if callable(after_func):
|
||||||
|
try:
|
||||||
|
# 调用_after_函数
|
||||||
|
after_func()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"调用{after_func_name}时出错: {e}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
# 替换原函数
|
||||||
|
attrs[attr_name] = make_wrapper(original_func)
|
||||||
|
|
||||||
|
# 创建类
|
||||||
|
new_class = super().__new__(cls, name, bases, attrs)
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
def __call__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
重写__call__方法实现单例模式
|
||||||
|
当类被调用时(即创建实例时)执行
|
||||||
|
"""
|
||||||
|
if cls not in cls._instances:
|
||||||
|
# 如果实例不存在,创建新实例
|
||||||
|
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||||
|
logger.info(f"创建{cls.__name__}的新实例")
|
||||||
|
else:
|
||||||
|
logger.debug(f"返回{cls.__name__}的现有实例")
|
||||||
|
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
整体识别的处理逻辑:
|
||||||
|
1.压入二进制音频信息
|
||||||
|
2.不断检测VAD
|
||||||
|
3.当检测到完整VAD时,将VAD的音频信息压入音频块,并清除对应二进制信息
|
||||||
|
4.对音频块进行语音转文字offline,时间戳预测,说话人识别
|
||||||
|
5.将识别结果整合压入结果队列
|
||||||
|
6.结果队列被压入时调用回调函数
|
||||||
|
|
||||||
|
1->2 __after__push_binary_data 外部压入二进制信息
|
||||||
|
2,3->4 __after__push_audio_chunk 内部压入音频块
|
||||||
|
4->5 push_result_queue 压入结果队列
|
||||||
|
5->6 __after__push_result_queue 调用回调函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor import VAD
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
from src.models import AudioBinary_Chunk
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class LogicTrager(metaclass=AutoAfterMeta):
|
||||||
|
"""逻辑触发器类"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
audio_chunk_max_size: int = 1024 * 1024 * 10,
|
||||||
|
audio_config: AudioBinary_Config = None,
|
||||||
|
result_callback: Callable = None,
|
||||||
|
models: Dict[str, Any] = None,
|
||||||
|
):
|
||||||
|
"""初始化"""
|
||||||
|
# 存储音频块
|
||||||
|
self._audio_chunk: List[AudioBinary_Chunk] = []
|
||||||
|
# 存储二进制数据
|
||||||
|
self._audio_chunk_binary = b""
|
||||||
|
self._audio_chunk_max_size = audio_chunk_max_size
|
||||||
|
# 音频参数
|
||||||
|
self._audio_config = (
|
||||||
|
audio_config if audio_config is not None else AudioBinary_Config()
|
||||||
|
)
|
||||||
|
# 结果队列
|
||||||
|
self._result_queue = []
|
||||||
|
# 聚合结果回调函数
|
||||||
|
self._aggregate_result_callback = result_callback
|
||||||
|
# 组件
|
||||||
|
self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
|
||||||
|
self._vad.set_callback(self.push_audio_chunk)
|
||||||
|
|
||||||
|
logger.info("初始化LogicTrager")
|
||||||
|
|
||||||
|
def push_binary_data(self, chunk: bytes) -> None:
|
||||||
|
"""
|
||||||
|
压入音频块至VAD模块
|
||||||
|
|
||||||
|
参数:
|
||||||
|
chunk: 音频数据块
|
||||||
|
"""
|
||||||
|
# print("LogicTrager push_binary_data", len(chunk))
|
||||||
|
self._vad.push_binary_data(chunk)
|
||||||
|
self.__after__push_binary_data()
|
||||||
|
|
||||||
|
def __after__push_binary_data(self) -> None:
|
||||||
|
"""
|
||||||
|
添加音频块后处理
|
||||||
|
"""
|
||||||
|
# print("LogicTrager __after__push_binary_data")
|
||||||
|
self._vad.process_vad_result()
|
||||||
|
|
||||||
|
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
|
||||||
|
"""
|
||||||
|
音频处理
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
|
||||||
|
chunk.start_time, chunk.end_time, len(chunk.chunk)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._audio_chunk.append(chunk)
|
||||||
|
|
||||||
|
def __after__push_audio_chunk(self) -> None:
|
||||||
|
"""
|
||||||
|
压入音频块后处理
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def push_result_queue(self, result: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
压入结果队列
|
||||||
|
"""
|
||||||
|
self._result_queue.append(result)
|
||||||
|
|
||||||
|
def __after__push_result_queue(self) -> None:
|
||||||
|
"""
|
||||||
|
压入结果队列后处理
|
||||||
|
"""
|
||||||
|
logger.info("FINISH Result=")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
"""调用函数"""
|
||||||
|
pass
|
@ -1,79 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
模型加载模块 - 负责加载各种语音识别相关模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
def load_models(args):
|
|
||||||
"""
|
|
||||||
加载所有需要的模型
|
|
||||||
|
|
||||||
参数:
|
|
||||||
args: 命令行参数,包含模型配置
|
|
||||||
|
|
||||||
返回:
|
|
||||||
dict: 包含所有加载的模型的字典
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 导入FunASR库
|
|
||||||
from funasr import AutoModel
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("未找到funasr库,请先安装: pip install funasr")
|
|
||||||
|
|
||||||
# 初始化模型字典
|
|
||||||
models = {}
|
|
||||||
|
|
||||||
# 1. 加载离线ASR模型
|
|
||||||
print(f"正在加载ASR离线模型: {args.asr_model}")
|
|
||||||
models["asr"] = AutoModel(
|
|
||||||
model=args.asr_model,
|
|
||||||
model_revision=args.asr_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 加载在线ASR模型
|
|
||||||
print(f"正在加载ASR在线模型: {args.asr_model_online}")
|
|
||||||
models["asr_streaming"] = AutoModel(
|
|
||||||
model=args.asr_model_online,
|
|
||||||
model_revision=args.asr_model_online_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 加载VAD模型
|
|
||||||
print(f"正在加载VAD模型: {args.vad_model}")
|
|
||||||
models["vad"] = AutoModel(
|
|
||||||
model=args.vad_model,
|
|
||||||
model_revision=args.vad_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. 加载标点符号模型(如果指定)
|
|
||||||
if args.punc_model:
|
|
||||||
print(f"正在加载标点符号模型: {args.punc_model}")
|
|
||||||
models["punc"] = AutoModel(
|
|
||||||
model=args.punc_model,
|
|
||||||
model_revision=args.punc_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
models["punc"] = None
|
|
||||||
print("未指定标点符号模型,将不使用标点符号")
|
|
||||||
|
|
||||||
print("所有模型加载完成")
|
|
||||||
return models
|
|
11
src/models/__init__.py
Normal file
11
src/models/__init__.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
|
||||||
|
from .vad import VAD_Functor_result
|
||||||
|
|
||||||
|
from .spk import SpeakerCreate, SpeakerResponse
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AudioBinary_Config",
|
||||||
|
"AudioBinary_data_list",
|
||||||
|
"_AudioBinary_data",
|
||||||
|
"VAD_Functor_result",
|
||||||
|
]
|
158
src/models/audio.py
Normal file
158
src/models/audio.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from typing import List, Any
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
from src.utils import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
binary_data_types = (bytes, numpy.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBinary_Config(BaseModel):
|
||||||
|
"""二进制音频块配置信息"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
audio_data: binary_data_types = Field(description="音频数据", default=None)
|
||||||
|
chunk_size: int = Field(description="块大小", default=100)
|
||||||
|
chunk_stride: int = Field(description="块步长", default=1600)
|
||||||
|
sample_rate: int = Field(description="采样率", default=16000)
|
||||||
|
sample_width: int = Field(description="采样位宽", default=2)
|
||||||
|
channels: int = Field(description="通道数", default=1)
|
||||||
|
|
||||||
|
# 从Dict中加载
|
||||||
|
@classmethod
|
||||||
|
def AudioBinary_Config_from_dict(cls, data: dict):
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
def ms2frame(self, ms: int) -> int:
|
||||||
|
"""
|
||||||
|
将毫秒转换为帧
|
||||||
|
"""
|
||||||
|
return int(ms * self.sample_rate / 1000)
|
||||||
|
|
||||||
|
def frame2ms(self, frame: int) -> int:
|
||||||
|
"""
|
||||||
|
将帧转换为毫秒
|
||||||
|
"""
|
||||||
|
return int(frame * 1000 / self.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
class _AudioBinary_data(BaseModel):
|
||||||
|
"""音频数据"""
|
||||||
|
|
||||||
|
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@validator("binary_data")
|
||||||
|
def validate_binary_data(cls, v):
|
||||||
|
"""
|
||||||
|
验证音频数据
|
||||||
|
Args:
|
||||||
|
v: 音频数据
|
||||||
|
Returns:
|
||||||
|
binary_data_types: 音频数据
|
||||||
|
"""
|
||||||
|
if not isinstance(v, (bytes, numpy.ndarray)):
|
||||||
|
logger.warning(
|
||||||
|
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
|
||||||
|
cls.__class__.__name__,
|
||||||
|
type(v),
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
获取音频数据长度
|
||||||
|
Returns:
|
||||||
|
int: 音频数据长度
|
||||||
|
"""
|
||||||
|
return len(self.binary_data)
|
||||||
|
|
||||||
|
def __init__(self, binary_data: binary_data_types):
|
||||||
|
"""
|
||||||
|
初始化音频数据
|
||||||
|
Args:
|
||||||
|
binary_data: 音频数据
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
"[%s]初始化音频数据, 数据类型为%s",
|
||||||
|
self.__class__.__name__,
|
||||||
|
type(binary_data),
|
||||||
|
)
|
||||||
|
super().__init__(binary_data=binary_data)
|
||||||
|
|
||||||
|
def __getitem__(self):
|
||||||
|
"""
|
||||||
|
当获取数据时, 直接返回binary_data
|
||||||
|
Returns:
|
||||||
|
binary_data_types: 音频数据
|
||||||
|
"""
|
||||||
|
return self.binary_data
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBinary_data_list(BaseModel):
|
||||||
|
"""音频数据列表"""
|
||||||
|
|
||||||
|
binary_data_list: List[_AudioBinary_data] = Field(
|
||||||
|
description="音频数据列表", default=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def push_data(self, data: binary_data_types) -> int:
|
||||||
|
"""
|
||||||
|
添加音频数据
|
||||||
|
Args:
|
||||||
|
data: 音频数据
|
||||||
|
Returns:
|
||||||
|
int: 数据在binary_data_list中的索引
|
||||||
|
"""
|
||||||
|
self.binary_data_list.append(_AudioBinary_data(binary_data=data))
|
||||||
|
return len(self.binary_data_list) - 1
|
||||||
|
|
||||||
|
def __getitem__(self, index: int):
|
||||||
|
"""
|
||||||
|
获取音频数据
|
||||||
|
Args:
|
||||||
|
index: 音频数据在binary_data_list中的索引
|
||||||
|
Returns:
|
||||||
|
_AudioBinary_data: 音频数据
|
||||||
|
"""
|
||||||
|
return self.binary_data_list[index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
获取音频数据列表长度
|
||||||
|
Returns:
|
||||||
|
int: 音频数据列表长度
|
||||||
|
"""
|
||||||
|
return len(self.binary_data_list)
|
||||||
|
|
||||||
|
|
||||||
|
# class AudioBinary_Slice(BaseModel):
|
||||||
|
# """音频块切片"""
|
||||||
|
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
|
||||||
|
# start_index: int = Field(description="开始位置", default=0)
|
||||||
|
# end_index: int = Field(description="结束位置", default=-1)
|
||||||
|
|
||||||
|
# @validator('start_index')
|
||||||
|
# def validate_start_index(cls, v):
|
||||||
|
# if v < 0:
|
||||||
|
# raise ValueError("start_index必须大于0")
|
||||||
|
# return v
|
||||||
|
|
||||||
|
# @validator('end_index')
|
||||||
|
# def validate_end_index(cls, v):
|
||||||
|
# if v < cls.start_index:
|
||||||
|
# logger.debug("[%s]end_index小于start_index, 将end_index设置为start_index", cls.__class__.__name__)
|
||||||
|
# v = cls.start_index
|
||||||
|
# return v
|
||||||
|
|
||||||
|
# def __call__(self):
|
||||||
|
# return self.target_Binary(self.start_index, self.end_index)
|
70
src/models/spk.py
Normal file
70
src/models/spk.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
"""
|
||||||
|
src/models/spk.py
|
||||||
|
------------------------
|
||||||
|
此模块定义与说话人(speakers)表对应的 Pydantic 模型,用于 API 数据验证和序列化。
|
||||||
|
|
||||||
|
模型说明:
|
||||||
|
- SpeakerBase:定义说话人的基础字段,包括姓名与描述。
|
||||||
|
- SpeakerCreate:用于创建说话人时的数据验证,直接继承 SpeakerBase。
|
||||||
|
- SpeakerUpdate:用于更新说话人信息时,所有字段均为可选。
|
||||||
|
- SpeakerResponse:返回给客户端时使用,包含数据库生成的字段(如 speaker_id、created_at 等)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
from uuid import UUID
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from src.utils import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
# 基础模型,定义说话人的核心属性
|
||||||
|
class SpeakerBase(BaseModel):
|
||||||
|
speaker_id: UUID = Field(
|
||||||
|
...,
|
||||||
|
description="说话人唯一标识符"
|
||||||
|
)
|
||||||
|
speaker_name: str = Field(
|
||||||
|
...,
|
||||||
|
max_length=255,
|
||||||
|
description="说话人姓名,必填,最多255字符"
|
||||||
|
)
|
||||||
|
speaker_description: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="说话人描述信息,可选"
|
||||||
|
)
|
||||||
|
speaker_embs: Optional[List[float]] = Field(
|
||||||
|
None,
|
||||||
|
description="说话人embedding,可选"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 用于创建说话人时的数据模型,直接继承 SpeakerBase
|
||||||
|
class SpeakerCreate(SpeakerBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 用于更新说话人时的数据模型,所有字段都是可选的
|
||||||
|
class SpeakerUpdate(BaseModel):
|
||||||
|
speaker_name: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
max_length=255,
|
||||||
|
description="说话人姓名"
|
||||||
|
)
|
||||||
|
speaker_description: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="说话人描述信息"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 用于 API 返回的说话人响应模型,扩展基础模型以包含数据库生成的字段
|
||||||
|
class SpeakerResponse(SpeakerBase):
|
||||||
|
speaker_id: UUID = Field(
|
||||||
|
...,
|
||||||
|
description="说话人的唯一标识符(UUID)"
|
||||||
|
)
|
||||||
|
created_at: datetime = Field(
|
||||||
|
...,
|
||||||
|
description="记录创建时间"
|
||||||
|
)
|
||||||
|
updated_at: datetime = Field(
|
||||||
|
...,
|
||||||
|
description="最近更新时间"
|
||||||
|
)
|
91
src/models/vad.py
Normal file
91
src/models/vad.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from typing import List, Optional, Callable, Any
|
||||||
|
from .audio import AudioBinary_data_list, _AudioBinary_data
|
||||||
|
|
||||||
|
|
||||||
|
class VAD_Functor_result(BaseModel):
|
||||||
|
"""
|
||||||
|
VADFunctor结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
|
||||||
|
audiobinary_index: int = Field(description="音频数据索引")
|
||||||
|
audiobinary_data: _AudioBinary_data = Field(
|
||||||
|
description="音频数据, 指向AudioBinary_data"
|
||||||
|
)
|
||||||
|
start_time: int = Field(description="开始时间", is_required=True)
|
||||||
|
end_time: int = Field(description="结束时间", is_required=True)
|
||||||
|
|
||||||
|
@validator("audiobinary_data_list")
|
||||||
|
def validate_audiobinary_data_list(cls, v):
|
||||||
|
if not isinstance(v, AudioBinary_data_list):
|
||||||
|
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("audiobinary_index")
|
||||||
|
def validate_audiobinary_index(cls, v):
|
||||||
|
if not isinstance(v, int):
|
||||||
|
raise ValueError("audiobinary_index必须是int类型")
|
||||||
|
if v < 0:
|
||||||
|
raise ValueError("audiobinary_index必须大于0")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("audiobinary_data")
|
||||||
|
def validate_audiobinary_data(cls, v):
|
||||||
|
if not isinstance(v, _AudioBinary_data):
|
||||||
|
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("start_time")
|
||||||
|
def validate_start_time(cls, v):
|
||||||
|
if not isinstance(v, int):
|
||||||
|
raise ValueError("start_time必须是int类型")
|
||||||
|
if v < 0:
|
||||||
|
raise ValueError("start_time必须大于0")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("end_time")
|
||||||
|
def validate_end_time(cls, v, values):
|
||||||
|
if not isinstance(v, int):
|
||||||
|
raise ValueError("end_time必须是int类型")
|
||||||
|
if "start_time" in values and v <= values["start_time"]:
|
||||||
|
raise ValueError("end_time必须大于start_time")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_push_data(
|
||||||
|
cls,
|
||||||
|
audiobinary_data_list: AudioBinary_data_list,
|
||||||
|
data: Any,
|
||||||
|
start_time: int,
|
||||||
|
end_time: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建VAD片段
|
||||||
|
"""
|
||||||
|
index = audiobinary_data_list.push_data(data)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
audiobinary_data_list=audiobinary_data_list,
|
||||||
|
audiobinary_index=index,
|
||||||
|
audiobinary_data=audiobinary_data_list[index],
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
获取音频数据长度
|
||||||
|
"""
|
||||||
|
return len(self.audiobinary_data.binary_data)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""
|
||||||
|
字符串展示内容
|
||||||
|
"""
|
||||||
|
tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
|
||||||
|
tostr += f"start_time: {self.start_time}\n"
|
||||||
|
tostr += f"end_time: {self.end_time}\n"
|
||||||
|
tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
|
||||||
|
tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
|
||||||
|
return tostr
|
280
src/pipeline/ASRpipeline.py
Normal file
280
src/pipeline/ASRpipeline.py
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
from src.pipeline.base import PipelineBase
|
||||||
|
from typing import Dict, Any, Callable
|
||||||
|
from queue import Queue, Empty
|
||||||
|
from src.utils import get_module_logger
|
||||||
|
from src.models import AudioBinary_data_list
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ASRPipeline(PipelineBase):
|
||||||
|
"""
|
||||||
|
管道类
|
||||||
|
实现具体的处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
初始化管道
|
||||||
|
"""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._config: Dict[str, Any] = {}
|
||||||
|
self._functor_dict: Dict[str, Any] = {}
|
||||||
|
self._subqueue_dict: Dict[str, Any] = {}
|
||||||
|
self._is_baked: bool = False
|
||||||
|
self._input_queue: Queue = None
|
||||||
|
self._audio_binary_data_list: AudioBinary_data_list = None
|
||||||
|
|
||||||
|
self._status_lock = threading.Lock()
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
self._callback: Callable = None
|
||||||
|
|
||||||
|
def set_config(self, config: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
设置配置
|
||||||
|
参数:
|
||||||
|
config: Dict[str, Any] 配置
|
||||||
|
"""
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
def get_config(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取配置
|
||||||
|
返回:
|
||||||
|
Dict[str, Any] 配置
|
||||||
|
"""
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None:
|
||||||
|
"""
|
||||||
|
设置音频二进制存储单元
|
||||||
|
参数:
|
||||||
|
audio_binary: 音频二进制
|
||||||
|
"""
|
||||||
|
self._audio_binary = audio_binary
|
||||||
|
|
||||||
|
def set_models(self, models: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
设置模型
|
||||||
|
"""
|
||||||
|
self._models = models
|
||||||
|
|
||||||
|
def set_input_queue(self, input_queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置输入队列
|
||||||
|
"""
|
||||||
|
self._input_queue = input_queue
|
||||||
|
|
||||||
|
def set_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
设置回调函数
|
||||||
|
"""
|
||||||
|
self._callback = callback
|
||||||
|
|
||||||
|
def bake(self) -> None:
|
||||||
|
"""
|
||||||
|
烘焙管道
|
||||||
|
"""
|
||||||
|
self._pre_check_resource()
|
||||||
|
self._init_functor()
|
||||||
|
self._is_baked = True
|
||||||
|
|
||||||
|
def _pre_check_resource(self) -> None:
|
||||||
|
"""
|
||||||
|
预检查资源
|
||||||
|
"""
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise RuntimeError("[ASRpipeline]输入队列未设置")
|
||||||
|
if self._functor_dict is None:
|
||||||
|
raise RuntimeError("[ASRpipeline]functor字典未设置")
|
||||||
|
if self._subqueue_dict is None:
|
||||||
|
raise RuntimeError("[ASRpipeline]子队列字典未设置")
|
||||||
|
if self._audio_binary is None:
|
||||||
|
raise RuntimeError("[ASRpipeline]音频存储单元未设置")
|
||||||
|
|
||||||
|
def _init_functor(self) -> None:
|
||||||
|
"""
|
||||||
|
初始化函数
|
||||||
|
自身的functor流程图如下
|
||||||
|
self.input_queue(self.run检测输入到subqueue["original"])->vad
|
||||||
|
->vad2asr ->asrend
|
||||||
|
->vad2spk ->spkend
|
||||||
|
->asrend+spkend->resultbinder
|
||||||
|
->self.callback
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.functor import FunctorFactory
|
||||||
|
|
||||||
|
# 加载VAD、asr、spk functor
|
||||||
|
logger.debug(f"使用FunctorFactory创建vad functor")
|
||||||
|
self._functor_dict["vad"] = FunctorFactory.make_functor(
|
||||||
|
functor_name="vad", config=self._config, models=self._models
|
||||||
|
)
|
||||||
|
logger.debug(f"使用FunctorFactory创建asr functor")
|
||||||
|
self._functor_dict["asr"] = FunctorFactory.make_functor(
|
||||||
|
functor_name="asr", config=self._config, models=self._models
|
||||||
|
)
|
||||||
|
logger.debug(f"使用FunctorFactory创建spk functor")
|
||||||
|
self._functor_dict["spk"] = FunctorFactory.make_functor(
|
||||||
|
functor_name="spk", config=self._config, models=self._models
|
||||||
|
)
|
||||||
|
logger.debug(f"使用FunctorFactory创建resultbinder functor")
|
||||||
|
self._functor_dict["resultbinder"] = FunctorFactory.make_functor(
|
||||||
|
functor_name="resultbinder", config=self._config, models=self._models
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建音频数据存储单元
|
||||||
|
self._audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
|
self._functor_dict["vad"].set_audio_binary_data_list(
|
||||||
|
self._audio_binary_data_list
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化子队列
|
||||||
|
self._subqueue_dict["original"] = Queue()
|
||||||
|
self._subqueue_dict["vad2asr"] = Queue()
|
||||||
|
self._subqueue_dict["vad2spk"] = Queue()
|
||||||
|
self._subqueue_dict["asrend"] = Queue()
|
||||||
|
self._subqueue_dict["spkend"] = Queue()
|
||||||
|
# 输出队列
|
||||||
|
self._subqueue_dict["OUTPUT"] = Queue()
|
||||||
|
|
||||||
|
# 设置子队列的输入队列
|
||||||
|
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
|
||||||
|
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||||
|
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||||
|
# 设置resultbinder的输入队列
|
||||||
|
# 汇总 asr语音识别结果 和 说话人识别结果
|
||||||
|
self._functor_dict["resultbinder"].add_input_queue(
|
||||||
|
"asr", self._subqueue_dict["asrend"]
|
||||||
|
)
|
||||||
|
self._functor_dict["resultbinder"].add_input_queue(
|
||||||
|
"spk", self._subqueue_dict["spkend"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置回调函数——放置到对应队列中
|
||||||
|
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
|
||||||
|
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
|
||||||
|
# 设置asr与spk的回调函数
|
||||||
|
self._functor_dict["asr"].add_callback(self._subqueue_dict["asrend"].put)
|
||||||
|
self._functor_dict["spk"].add_callback(self._subqueue_dict["spkend"].put)
|
||||||
|
# 设置resultbinder的回调函数 为 自身被设置的回调函数,用于和外界交互
|
||||||
|
self._functor_dict["resultbinder"].add_callback(self._callback)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"functorFactory引入失败,ASRPipeline无法完成初始化: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"ASRPipeline初始化失败: {str(e)}")
|
||||||
|
|
||||||
|
def _check_result(self, result: Any) -> None:
|
||||||
|
"""
|
||||||
|
检查结果
|
||||||
|
"""
|
||||||
|
# 若asr和spk队列中都有数据,则合并数据
|
||||||
|
if (
|
||||||
|
self._subqueue_dict["asrend"].qsize()
|
||||||
|
& self._subqueue_dict["spkend"].qsize()
|
||||||
|
):
|
||||||
|
asr_data = self._subqueue_dict["asrend"].get()
|
||||||
|
spk_data = self._subqueue_dict["spkend"].get()
|
||||||
|
# 合并数据
|
||||||
|
result = {"asr_data": asr_data, "spk_data": spk_data}
|
||||||
|
# 通知回调函数
|
||||||
|
self._notify_callbacks(result)
|
||||||
|
|
||||||
|
def run(self) -> threading.Thread:
|
||||||
|
"""
|
||||||
|
运行管道
|
||||||
|
Returns:
|
||||||
|
threading.Thread: 返回已运行线程实例
|
||||||
|
"""
|
||||||
|
# 检查运行资源是否准备完毕
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
logger.info("[ASRpipeline]管道开始运行")
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> None:
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
"""
|
||||||
|
if self._is_baked is False:
|
||||||
|
raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行")
|
||||||
|
|
||||||
|
for functor_name, functor in self._functor_dict.items():
|
||||||
|
if functor is None:
|
||||||
|
raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常")
|
||||||
|
|
||||||
|
for subqueue_name, subqueue in self._subqueue_dict.items():
|
||||||
|
if subqueue is None:
|
||||||
|
raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常")
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
真实的运行逻辑
|
||||||
|
"""
|
||||||
|
# 运行所有functor
|
||||||
|
for functor_name, functor in self._functor_dict.items():
|
||||||
|
logger.info(f"[ASRpipeline]运行{functor_name}functor")
|
||||||
|
self._functor_dict[functor_name].run()
|
||||||
|
|
||||||
|
# 设置管道运行状态
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
|
||||||
|
while self._is_running and not self._stop_event:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(timeout=self._queue_timeout)
|
||||||
|
# logger.debug("[ASRpipeline]获取到的数据length: %s", len(data))
|
||||||
|
# 检查是否是结束信号
|
||||||
|
if data is None:
|
||||||
|
logger.info("收到结束信号,管道准备停止")
|
||||||
|
self._input_queue.task_done() # 标记结束信号已处理
|
||||||
|
break
|
||||||
|
|
||||||
|
# 处理数据
|
||||||
|
self._process(data)
|
||||||
|
|
||||||
|
# 标记任务完成
|
||||||
|
self._input_queue.task_done()
|
||||||
|
|
||||||
|
except Empty:
|
||||||
|
# 队列获取超时,继续等待
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ASRpipeline]管道处理数据出错: {str(e)}")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("[ASRpipeline]管道停止运行")
|
||||||
|
|
||||||
|
def _process(self, data: Any) -> Any:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
参数:
|
||||||
|
data: 输入数据
|
||||||
|
返回:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
# 子类实现具体的处理逻辑
|
||||||
|
self._subqueue_dict["original"].put(data)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""
|
||||||
|
停止管道
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
self._stop_event = True
|
||||||
|
for functor_name, functor in self._functor_dict.items():
|
||||||
|
# logger.info(f"停止{functor_name}functor")
|
||||||
|
if functor.stop():
|
||||||
|
logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止")
|
||||||
|
else:
|
||||||
|
logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败")
|
||||||
|
self._thread.join()
|
||||||
|
logger.info("[ASRpipeline]管道停止")
|
||||||
|
return True
|
3
src/pipeline/__init__.py
Normal file
3
src/pipeline/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from src.pipeline.base import PipelineBase, PipelineFactory
|
||||||
|
|
||||||
|
__all__ = ["PipelineBase", "PipelineFactory"]
|
151
src/pipeline/base.py
Normal file
151
src/pipeline/base.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from queue import Queue, Empty
|
||||||
|
from typing import List, Callable, Any, Optional
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineBase(ABC):
|
||||||
|
"""
|
||||||
|
管道基类
|
||||||
|
定义了管道的基本接口和通用功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_queue: Optional[Queue] = None):
|
||||||
|
"""
|
||||||
|
初始化管道
|
||||||
|
参数:
|
||||||
|
input_queue: 输入队列,用于接收数据
|
||||||
|
"""
|
||||||
|
self._input_queue = input_queue
|
||||||
|
self._callbacks: List[Callable] = []
|
||||||
|
self._is_running = False
|
||||||
|
self._stop_event = False
|
||||||
|
self._thread: Optional[threading.Thread] = None
|
||||||
|
self._stop_timeout = 5 # 默认停止超时时间(秒)
|
||||||
|
self._queue_timeout = 1 # 队列获取超时时间(秒)
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置输入队列
|
||||||
|
参数:
|
||||||
|
queue: 输入队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
添加回调函数
|
||||||
|
参数:
|
||||||
|
callback: 回调函数,接收处理结果
|
||||||
|
"""
|
||||||
|
self._callbacks.append(callback)
|
||||||
|
|
||||||
|
def _notify_callbacks(self, result: Any) -> None:
|
||||||
|
"""
|
||||||
|
通知所有回调函数
|
||||||
|
参数:
|
||||||
|
result: 处理结果
|
||||||
|
"""
|
||||||
|
for callback in self._callbacks:
|
||||||
|
try:
|
||||||
|
callback(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"回调函数执行出错: {str(e)}")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process(self, data: Any) -> Any:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
参数:
|
||||||
|
data: 输入数据
|
||||||
|
返回:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
运行管道
|
||||||
|
从输入队列获取数据并处理
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop(self, timeout: Optional[float] = None) -> bool:
|
||||||
|
"""
|
||||||
|
停止管道
|
||||||
|
参数:
|
||||||
|
timeout: 停止超时时间(秒),None表示使用默认超时时间
|
||||||
|
返回:
|
||||||
|
bool: 是否成功停止
|
||||||
|
"""
|
||||||
|
if not self._is_running:
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info("正在停止管道...")
|
||||||
|
self._stop_event = True
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
|
# 等待线程结束
|
||||||
|
if self._thread and self._thread.is_alive():
|
||||||
|
timeout = timeout if timeout is not None else self._stop_timeout
|
||||||
|
self._thread.join(timeout=timeout)
|
||||||
|
|
||||||
|
# 检查是否成功停止
|
||||||
|
if self._thread.is_alive():
|
||||||
|
logger.warning(f"管道停止超时({timeout}秒),强制终止")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info("管道已成功停止")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def force_stop(self) -> None:
|
||||||
|
"""
|
||||||
|
强制停止管道
|
||||||
|
注意:这可能会导致资源未正确释放
|
||||||
|
"""
|
||||||
|
logger.warning("强制停止管道")
|
||||||
|
self._stop_event = True
|
||||||
|
self._is_running = False
|
||||||
|
# 注意:Python的线程无法被强制终止,这里只是设置标志
|
||||||
|
# 实际终止需要依赖操作系统的进程管理
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineFactory:
|
||||||
|
"""
|
||||||
|
管道工厂类
|
||||||
|
用于创建管道实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline:
|
||||||
|
"""
|
||||||
|
创建ASR管道实例
|
||||||
|
"""
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
pipeline = ASRPipeline()
|
||||||
|
pipeline.set_config(kwargs["config"])
|
||||||
|
pipeline.set_models(kwargs["models"])
|
||||||
|
pipeline.set_audio_binary(kwargs["audio_binary"])
|
||||||
|
pipeline.set_input_queue(kwargs["input_queue"])
|
||||||
|
pipeline.set_callback(kwargs["callback"])
|
||||||
|
# pipeline.bake()
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
创建管道实例
|
||||||
|
"""
|
||||||
|
if pipeline_name == "ASRpipeline":
|
||||||
|
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的管道类型: {pipeline_name}")
|
||||||
|
|
0
src/pipeline/test.py
Normal file
0
src/pipeline/test.py
Normal file
231
src/runner/ASRRunner.py
Normal file
231
src/runner/ASRRunner.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
"""
|
||||||
|
-*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
ASRRunner
|
||||||
|
继承RunnerBase
|
||||||
|
专属pipeline为ASRPipeline
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
from src.pipeline import PipelineFactory
|
||||||
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
from src.config import DefaultConfig
|
||||||
|
import asyncio
|
||||||
|
from queue import Queue
|
||||||
|
import soundfile
|
||||||
|
import time
|
||||||
|
from typing import List, Optional
|
||||||
|
import uuid
|
||||||
|
from threading import Thread
|
||||||
|
from src.utils.mock_websocket import MockWebSocketClient as WebSocketClient
|
||||||
|
from .runner import RunnerBase
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
OVAERWATCH = False
|
||||||
|
|
||||||
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
class ASRRunner(RunnerBase):
|
||||||
|
"""
|
||||||
|
运行器类
|
||||||
|
负责管理资源和协调Pipeline的运行
|
||||||
|
"""
|
||||||
|
class SenderAndReceiver:
|
||||||
|
"""
|
||||||
|
对于单个pipeline的管理
|
||||||
|
包含 发送者 和 接收者
|
||||||
|
_sender: 发送者 唯一
|
||||||
|
_receiver: 接收者 可以有多个
|
||||||
|
_pipeline: 对应管道 唯一
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# 可选传入参数,
|
||||||
|
self._name: str = kwargs.get("name", "")
|
||||||
|
self._sender: Optional[WebSocketClient] = kwargs.get("sender", None)
|
||||||
|
self._receiver: List[WebSocketClient] = kwargs.get("receiver", [])
|
||||||
|
|
||||||
|
# 资源
|
||||||
|
self._audio_config: AudioBinary_Config = kwargs.get("audio_config", DefaultConfig.audio_config)
|
||||||
|
self._models: dict = kwargs.get("models", None)
|
||||||
|
self._audio_binary: AudioBinary_data_list = AudioBinary_data_list()
|
||||||
|
# id唯一标识
|
||||||
|
self._id: str = str(uuid.uuid4())
|
||||||
|
# 输入队列
|
||||||
|
self._input_queue: Queue = Queue()
|
||||||
|
self._pipeline: Optional[ASRPipeline] = None
|
||||||
|
self._task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
def set_name(self, name: str):
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
def set_id(self, id: str):
|
||||||
|
self._id = id
|
||||||
|
|
||||||
|
def set_sender(self, sender: WebSocketClient):
|
||||||
|
self._sender = sender
|
||||||
|
|
||||||
|
def set_pipeline(self, pipeline: ASRPipeline):
|
||||||
|
self._pipeline = pipeline
|
||||||
|
config = {
|
||||||
|
"audio_config": self._audio_config,
|
||||||
|
}
|
||||||
|
self._pipeline.set_config(config)
|
||||||
|
self._pipeline.set_models(self._models)
|
||||||
|
self._pipeline.set_audio_binary(self._audio_binary)
|
||||||
|
self._pipeline.set_input_queue(self._input_queue)
|
||||||
|
|
||||||
|
# --- 异步-同步桥梁 ---
|
||||||
|
# 创建一个线程安全的回调函数,用于从Pipeline的线程中调用Runner的异步方法
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
def thread_safe_callback(message):
|
||||||
|
asyncio.run_coroutine_threadsafe(self.deal_message(message), loop)
|
||||||
|
|
||||||
|
self._pipeline.set_callback(thread_safe_callback)
|
||||||
|
self._pipeline.bake()
|
||||||
|
|
||||||
|
def append_receiver(self, receiver: WebSocketClient):
|
||||||
|
self._receiver.append(receiver)
|
||||||
|
|
||||||
|
def delete_receiver(self, receiver: WebSocketClient):
|
||||||
|
self._receiver.remove(receiver)
|
||||||
|
|
||||||
|
async def deal_message(self, message: str):
|
||||||
|
await self.broadcast(message)
|
||||||
|
|
||||||
|
async def broadcast(self, message: str):
|
||||||
|
"""
|
||||||
|
广播发送给所有接收者
|
||||||
|
"""
|
||||||
|
logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: 消息长度:%s", self._name, len(message))
|
||||||
|
logger.info(f"SAR-{self._name} 的接收者列表: {self._receiver}")
|
||||||
|
tasks = [receiver.send(message) for receiver in self._receiver]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
async def _run(self):
|
||||||
|
"""
|
||||||
|
运行SAR
|
||||||
|
"""
|
||||||
|
self._pipeline.run()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = await self._sender.recv()
|
||||||
|
if data is None:
|
||||||
|
# `None` is used as a signal to end the stream
|
||||||
|
await loop.run_in_executor(None, self._input_queue.put, None)
|
||||||
|
break
|
||||||
|
# logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data))
|
||||||
|
await loop.run_in_executor(None, self._input_queue.put, data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ASRRunner][SAR-{self._name}] _run loop error: {e}")
|
||||||
|
break
|
||||||
|
await self.stop()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
运行SAR
|
||||||
|
"""
|
||||||
|
self._task = asyncio.create_task(self._run())
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""
|
||||||
|
停止SAR
|
||||||
|
"""
|
||||||
|
logger.info(f"Stopping SAR: {self._name}")
|
||||||
|
self._pipeline.stop()
|
||||||
|
|
||||||
|
# Close all receiver websockets
|
||||||
|
receiver_tasks = [ws.close() for ws in self._receiver]
|
||||||
|
await asyncio.gather(*receiver_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Close the sender websocket
|
||||||
|
if self._sender:
|
||||||
|
await self._sender.close()
|
||||||
|
|
||||||
|
# Cancel the main task if it's still running
|
||||||
|
if self._task and not self._task.done():
|
||||||
|
self._task.cancel()
|
||||||
|
logger.info(f"SAR stopped: {self._name}")
|
||||||
|
|
||||||
|
def __init__(self,*args,**kwargs):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# 接收资源
|
||||||
|
self._default_audio_config = kwargs.get("audio_config", DefaultConfig.audio_config)
|
||||||
|
# self._audio_binary_list = args.get("audio_binary_list", None)
|
||||||
|
self._default_models = kwargs.get("models", None)
|
||||||
|
self._SAR_list: List[self.SenderAndReceiver] = []
|
||||||
|
|
||||||
|
def set_default_config(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
设置配置
|
||||||
|
"""
|
||||||
|
self._default_audio_config = kwargs.get("audio_config", self._default_audio_config)
|
||||||
|
self._default_models = kwargs.get("models", self._default_models)
|
||||||
|
|
||||||
|
def new_SAR(
|
||||||
|
self,
|
||||||
|
ws: "WebSocketClient",
|
||||||
|
name: str = "",
|
||||||
|
audio_config: "AudioBinary_Config" = None,
|
||||||
|
models: dict = None
|
||||||
|
) -> uuid.UUID:
|
||||||
|
"""
|
||||||
|
创建新的SAR SenderAndReceiver
|
||||||
|
"""
|
||||||
|
if audio_config is None:
|
||||||
|
audio_config = self._default_audio_config
|
||||||
|
if models is None:
|
||||||
|
models = self._default_models
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_SAR = self.SenderAndReceiver(
|
||||||
|
name=name,
|
||||||
|
audio_config=audio_config,
|
||||||
|
models=models
|
||||||
|
)
|
||||||
|
new_pipeline = ASRPipeline()
|
||||||
|
new_SAR.set_pipeline(new_pipeline)
|
||||||
|
# new_SAR.set_pipeline()
|
||||||
|
logger.info("创建新的SAR: name %s, id %s", new_SAR._name, new_SAR._id)
|
||||||
|
new_SAR.set_sender(ws)
|
||||||
|
new_SAR.append_receiver(ws)
|
||||||
|
new_SAR.run()
|
||||||
|
self._SAR_list.append(new_SAR)
|
||||||
|
return new_SAR._id
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("创建管道失败: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def join_SAR(
|
||||||
|
self,
|
||||||
|
ws: "WebSocketClient",
|
||||||
|
name: Optional[str] = None,
|
||||||
|
id: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
加入SAR的Receiver
|
||||||
|
"""
|
||||||
|
# 使用next获取迭代器下一个元素,生成pipeline_list迭代器,按id停止
|
||||||
|
if id:
|
||||||
|
exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._id == id), None)
|
||||||
|
if name:
|
||||||
|
exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._name == name), None)
|
||||||
|
if exist_pipeline:
|
||||||
|
exist_pipeline.append_receiver(ws)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""
|
||||||
|
优雅地关闭所有SAR会话
|
||||||
|
"""
|
||||||
|
logger.info("Shutting down all SAR instances...")
|
||||||
|
tasks = [sar.stop() for sar in self._SAR_list]
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
logger.info("All SAR instances have been shut down.")
|
3
src/runner/__init__.py
Normal file
3
src/runner/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .ASRRunner import ASRRunner
|
||||||
|
|
||||||
|
__all__ = ["ASRRunner"]
|
105
src/runner/runner.py
Normal file
105
src/runner/runner.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
"""
|
||||||
|
-*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
Runner类
|
||||||
|
所有的Runner都对应一个fastapi的endpoint,
|
||||||
|
Runner需要处理:
|
||||||
|
1.新的websocket 进来后放到 unknow_websocket_pool中
|
||||||
|
2.收到特定消息后, 将消息转发给特定的pipeline处理
|
||||||
|
3.管理pipeline与websocket对应关系, 管理pipeline的ID
|
||||||
|
4.管理pipeline的启动和停止
|
||||||
|
5.管理所有pipeline用到的资源, 管理pipeline的存活时间。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from threading import Thread, Lock
|
||||||
|
from queue import Queue
|
||||||
|
import traceback
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.audio_chunk import AudioChunk
|
||||||
|
from src.pipeline import PipelineFactory
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
audio_chunk = AudioChunk()
|
||||||
|
models_loaded = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerBase(ABC):
|
||||||
|
"""
|
||||||
|
运行器基类
|
||||||
|
定义了运行器的基本接口
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class STTRunnerFactory:
|
||||||
|
"""
|
||||||
|
STT Runner工厂类
|
||||||
|
用于创建运行器实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_runner(
|
||||||
|
audio_binary_name: str,
|
||||||
|
model_name_list: List[str],
|
||||||
|
pipeline_name_list: List[str],
|
||||||
|
) -> RunnerBase:
|
||||||
|
"""
|
||||||
|
创建运行器
|
||||||
|
参数:
|
||||||
|
audio_binary_name: 音频二进制名称
|
||||||
|
model_name_list: 模型名称列表
|
||||||
|
pipeline_name_list: 管道名称列表
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
|
||||||
|
models: Dict[str, Any] = {
|
||||||
|
model_name: models_loaded.models[model_name]
|
||||||
|
for model_name in model_name_list
|
||||||
|
}
|
||||||
|
pipelines: List[Pipeline] = [
|
||||||
|
PipelineFactory.create_pipeline(pipeline_name)
|
||||||
|
for pipeline_name in pipeline_name_list
|
||||||
|
]
|
||||||
|
return RunnerBase(
|
||||||
|
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_runner_from_config(
|
||||||
|
cls,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> RunnerBase:
|
||||||
|
"""
|
||||||
|
从配置创建运行器
|
||||||
|
参数:
|
||||||
|
config: 配置字典
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary_name = config["audio_binary_name"]
|
||||||
|
model_name_list = config["model_name_list"]
|
||||||
|
pipeline_name_list = config["pipeline_name_list"]
|
||||||
|
return cls._create_runner(
|
||||||
|
audio_binary_name, model_name_list, pipeline_name_list
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_runner_normal(cls) -> RunnerBase:
|
||||||
|
"""
|
||||||
|
创建默认运行器
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary_name = None
|
||||||
|
model_name_list = list(models_loaded.models.keys())
|
||||||
|
pipeline_name_list = None
|
||||||
|
return cls._create_runner(
|
||||||
|
audio_binary_name, model_name_list, pipeline_name_list
|
||||||
|
)
|
273
src/server.py
273
src/server.py
@ -10,233 +10,80 @@ import json
|
|||||||
import websockets
|
import websockets
|
||||||
import ssl
|
import ssl
|
||||||
import argparse
|
import argparse
|
||||||
from config import parse_args
|
from src.runner import ASRRunner
|
||||||
from models import load_models
|
from src.config import DefaultConfig
|
||||||
from service import ASRService
|
from src.config import AudioBinary_Config
|
||||||
|
from src.websockets.router import websocket_router
|
||||||
|
from src.core import ModelLoader
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
# 全局变量,存储当前连接的WebSocket客户端
|
# 全局变量,存储当前连接的WebSocket客户端
|
||||||
websocket_users = set()
|
websocket_users = set()
|
||||||
|
|
||||||
|
# 使用 lifespan 上下文管理器来管理应用的生命周期
|
||||||
async def ws_reset(websocket):
|
@asynccontextmanager
|
||||||
"""重置WebSocket连接状态并关闭连接"""
|
async def lifespan(app: FastAPI):
|
||||||
print(f"重置WebSocket连接,当前连接数: {len(websocket_users)}")
|
|
||||||
|
|
||||||
# 重置状态字典
|
|
||||||
websocket.status_dict_asr_online["cache"] = {}
|
|
||||||
websocket.status_dict_asr_online["is_final"] = True
|
|
||||||
websocket.status_dict_vad["cache"] = {}
|
|
||||||
websocket.status_dict_vad["is_final"] = True
|
|
||||||
websocket.status_dict_punc["cache"] = {}
|
|
||||||
|
|
||||||
# 关闭连接
|
|
||||||
await websocket.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def clear_websocket():
|
|
||||||
"""清理所有WebSocket连接"""
|
|
||||||
for websocket in websocket_users:
|
|
||||||
await ws_reset(websocket)
|
|
||||||
websocket_users.clear()
|
|
||||||
|
|
||||||
|
|
||||||
async def ws_serve(websocket, path):
|
|
||||||
"""
|
"""
|
||||||
WebSocket服务主函数,处理客户端连接和消息
|
在应用启动时加载模型和初始化ASRRunner,
|
||||||
|
在应用关闭时优雅地关闭ASRRunner。
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接对象
|
|
||||||
path: 连接路径
|
|
||||||
"""
|
"""
|
||||||
frames = [] # 存储所有音频帧
|
logger.info("应用启动,开始加载模型和初始化Runner...")
|
||||||
frames_asr = [] # 存储用于离线ASR的音频帧
|
|
||||||
frames_asr_online = [] # 存储用于在线ASR的音频帧
|
|
||||||
|
|
||||||
global websocket_users
|
# 1. 加载模型
|
||||||
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
|
# 这里的参数可以从配置文件或环境变量中获取
|
||||||
|
args = {
|
||||||
|
"asr_model": "paraformer-zh",
|
||||||
|
"asr_model_revision": "v2.0.4",
|
||||||
|
"vad_model": "fsmn-vad",
|
||||||
|
"vad_model_revision": "v2.0.4",
|
||||||
|
"spk_model": "cam++",
|
||||||
|
"spk_model_revision": "v2.0.2",
|
||||||
|
}
|
||||||
|
model_loader = ModelLoader()
|
||||||
|
models = model_loader.load_models(args)
|
||||||
|
|
||||||
# 添加到用户集合
|
# 2. 初始化 ASRRunner
|
||||||
websocket_users.add(websocket)
|
_audio_config = AudioBinary_Config(
|
||||||
|
chunk_size=200, # ms
|
||||||
|
sample_rate=16000,
|
||||||
|
sample_width=2, # 16-bit
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
_audio_config.chunk_stride = int(_audio_config.chunk_size * _audio_config.sample_rate / 1000)
|
||||||
|
|
||||||
# 初始化连接状态
|
asr_runner = ASRRunner()
|
||||||
websocket.status_dict_asr = {}
|
asr_runner.set_default_config(
|
||||||
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
|
audio_config=_audio_config,
|
||||||
websocket.status_dict_vad = {"cache": {}, "is_final": False}
|
models=models,
|
||||||
websocket.status_dict_punc = {"cache": {}}
|
)
|
||||||
websocket.chunk_interval = 10
|
|
||||||
websocket.vad_pre_idx = 0
|
|
||||||
websocket.is_speaking = True # 默认用户正在说话
|
|
||||||
|
|
||||||
# 语音检测状态
|
# 3. 将 asr_runner 实例存储在 app.state 中
|
||||||
speech_start = False
|
app.state.asr_runner = asr_runner
|
||||||
speech_end_i = -1
|
logger.info("模型加载和Runner初始化完成。")
|
||||||
|
|
||||||
# 初始化配置
|
yield
|
||||||
websocket.wav_name = "microphone"
|
|
||||||
websocket.mode = "2pass" # 默认使用两阶段识别模式
|
|
||||||
|
|
||||||
print("新用户已连接", flush=True)
|
# --- 应用关闭时执行的代码 ---
|
||||||
|
logger.info("应用关闭,开始清理资源...")
|
||||||
|
await app.state.asr_runner.shutdown()
|
||||||
|
logger.info("资源清理完成。")
|
||||||
|
|
||||||
try:
|
# 初始化FastAPI应用,并指定lifespan
|
||||||
# 持续接收客户端消息
|
app = FastAPI(lifespan=lifespan)
|
||||||
async for message in websocket:
|
|
||||||
# 处理JSON配置消息
|
|
||||||
if isinstance(message, str):
|
|
||||||
try:
|
|
||||||
messagejson = json.loads(message)
|
|
||||||
|
|
||||||
# 更新各种配置参数
|
# 挂载WebSocket路由
|
||||||
if "is_speaking" in messagejson:
|
app.include_router(websocket_router, prefix="/ws")
|
||||||
websocket.is_speaking = messagejson["is_speaking"]
|
|
||||||
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
|
|
||||||
if "chunk_interval" in messagejson:
|
|
||||||
websocket.chunk_interval = messagejson["chunk_interval"]
|
|
||||||
if "wav_name" in messagejson:
|
|
||||||
websocket.wav_name = messagejson.get("wav_name")
|
|
||||||
if "chunk_size" in messagejson:
|
|
||||||
chunk_size = messagejson["chunk_size"]
|
|
||||||
if isinstance(chunk_size, str):
|
|
||||||
chunk_size = chunk_size.split(",")
|
|
||||||
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
|
|
||||||
if "encoder_chunk_look_back" in messagejson:
|
|
||||||
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
|
|
||||||
if "decoder_chunk_look_back" in messagejson:
|
|
||||||
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
|
|
||||||
if "hotword" in messagejson:
|
|
||||||
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
|
|
||||||
if "mode" in messagejson:
|
|
||||||
websocket.mode = messagejson["mode"]
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"无效的JSON消息: {message}")
|
|
||||||
|
|
||||||
# 根据chunk_interval更新VAD的chunk_size
|
|
||||||
websocket.status_dict_vad["chunk_size"] = int(
|
|
||||||
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
|
|
||||||
)
|
|
||||||
|
|
||||||
# 处理音频数据
|
|
||||||
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
|
|
||||||
if not isinstance(message, str): # 二进制音频数据
|
|
||||||
# 添加到帧缓冲区
|
|
||||||
frames.append(message)
|
|
||||||
duration_ms = len(message) // 32 # 计算音频时长
|
|
||||||
websocket.vad_pre_idx += duration_ms
|
|
||||||
|
|
||||||
# 处理在线ASR
|
|
||||||
frames_asr_online.append(message)
|
|
||||||
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
|
|
||||||
|
|
||||||
# 达到chunk_interval或最终帧时处理在线ASR
|
|
||||||
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
|
|
||||||
websocket.status_dict_asr_online["is_final"]):
|
|
||||||
if websocket.mode == "2pass" or websocket.mode == "online":
|
|
||||||
audio_in = b"".join(frames_asr_online)
|
|
||||||
try:
|
|
||||||
await asr_service.async_asr_online(websocket, audio_in)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"在线ASR处理错误: {e}")
|
|
||||||
frames_asr_online = []
|
|
||||||
|
|
||||||
# 如果检测到语音开始,收集帧用于离线ASR
|
|
||||||
if speech_start:
|
|
||||||
frames_asr.append(message)
|
|
||||||
|
|
||||||
# VAD处理 - 语音活动检测
|
|
||||||
try:
|
|
||||||
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"VAD处理错误: {e}")
|
|
||||||
|
|
||||||
# 检测到语音开始
|
|
||||||
if speech_start_i != -1:
|
|
||||||
speech_start = True
|
|
||||||
# 计算开始偏移并收集前面的帧
|
|
||||||
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
|
|
||||||
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
|
|
||||||
frames_asr = []
|
|
||||||
frames_asr.extend(frames_pre)
|
|
||||||
|
|
||||||
# 处理离线ASR (语音结束或用户停止说话)
|
|
||||||
if speech_end_i != -1 or not websocket.is_speaking:
|
|
||||||
if websocket.mode == "2pass" or websocket.mode == "offline":
|
|
||||||
audio_in = b"".join(frames_asr)
|
|
||||||
try:
|
|
||||||
await asr_service.async_asr(websocket, audio_in)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"离线ASR处理错误: {e}")
|
|
||||||
|
|
||||||
# 重置状态
|
|
||||||
frames_asr = []
|
|
||||||
speech_start = False
|
|
||||||
frames_asr_online = []
|
|
||||||
websocket.status_dict_asr_online["cache"] = {}
|
|
||||||
|
|
||||||
# 如果用户停止说话,完全重置
|
|
||||||
if not websocket.is_speaking:
|
|
||||||
websocket.vad_pre_idx = 0
|
|
||||||
frames = []
|
|
||||||
websocket.status_dict_vad["cache"] = {}
|
|
||||||
else:
|
|
||||||
# 保留最近的帧用于下一轮处理
|
|
||||||
frames = frames[-20:]
|
|
||||||
|
|
||||||
except websockets.ConnectionClosed:
|
|
||||||
print(f"连接已关闭...", flush=True)
|
|
||||||
await ws_reset(websocket)
|
|
||||||
websocket_users.remove(websocket)
|
|
||||||
except websockets.InvalidState:
|
|
||||||
print("无效的WebSocket状态...")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"发生异常: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def start_server(args, asr_service_instance):
|
|
||||||
"""
|
|
||||||
启动WebSocket服务器
|
|
||||||
|
|
||||||
参数:
|
|
||||||
args: 命令行参数
|
|
||||||
asr_service_instance: ASR服务实例
|
|
||||||
"""
|
|
||||||
global asr_service
|
|
||||||
asr_service = asr_service_instance
|
|
||||||
|
|
||||||
# 配置SSL (如果提供了证书)
|
|
||||||
if args.certfile and len(args.certfile) > 0:
|
|
||||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
|
||||||
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
|
|
||||||
|
|
||||||
start_server = websockets.serve(
|
|
||||||
ws_serve, args.host, args.port,
|
|
||||||
subprotocols=["binary"],
|
|
||||||
ping_interval=None,
|
|
||||||
ssl=ssl_context
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
start_server = websockets.serve(
|
|
||||||
ws_serve, args.host, args.port,
|
|
||||||
subprotocols=["binary"],
|
|
||||||
ping_interval=None
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
|
|
||||||
|
|
||||||
# 启动事件循环
|
|
||||||
asyncio.get_event_loop().run_until_complete(start_server)
|
|
||||||
asyncio.get_event_loop().run_forever()
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def read_root():
|
||||||
|
return {"message": "FunASR-FastAPI WebSocket Server is running."}
|
||||||
|
|
||||||
|
# 如果需要直接运行此文件进行测试
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 解析命令行参数
|
# 注意:在生产环境中,推荐使用Gunicorn + Uvicorn workers
|
||||||
args = parse_args()
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
print("正在加载模型...")
|
|
||||||
models = load_models(args)
|
|
||||||
print("模型加载完成!当前仅支持单个客户端同时连接!")
|
|
||||||
|
|
||||||
# 创建ASR服务
|
|
||||||
asr_service = ASRService(models)
|
|
||||||
|
|
||||||
# 启动服务器
|
|
||||||
start_server(args, asr_service)
|
|
||||||
|
127
src/service.py
127
src/service.py
@ -1,127 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
ASR服务模块 - 提供语音识别相关的核心功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
class ASRService:
|
|
||||||
"""ASR服务类,封装各种语音识别相关功能"""
|
|
||||||
|
|
||||||
def __init__(self, models):
|
|
||||||
"""
|
|
||||||
初始化ASR服务
|
|
||||||
|
|
||||||
参数:
|
|
||||||
models: 包含各种预加载模型的字典
|
|
||||||
"""
|
|
||||||
self.model_asr = models["asr"]
|
|
||||||
self.model_asr_streaming = models["asr_streaming"]
|
|
||||||
self.model_vad = models["vad"]
|
|
||||||
self.model_punc = models["punc"]
|
|
||||||
|
|
||||||
async def async_vad(self, websocket, audio_in):
|
|
||||||
"""
|
|
||||||
语音活动检测
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
audio_in: 二进制音频数据
|
|
||||||
|
|
||||||
返回:
|
|
||||||
tuple: (speech_start, speech_end) 语音开始和结束位置
|
|
||||||
"""
|
|
||||||
# 使用VAD模型分析音频段
|
|
||||||
segments_result = self.model_vad.generate(
|
|
||||||
input=audio_in,
|
|
||||||
**websocket.status_dict_vad
|
|
||||||
)[0]["value"]
|
|
||||||
|
|
||||||
speech_start = -1
|
|
||||||
speech_end = -1
|
|
||||||
|
|
||||||
# 解析VAD结果
|
|
||||||
if len(segments_result) == 0 or len(segments_result) > 1:
|
|
||||||
return speech_start, speech_end
|
|
||||||
|
|
||||||
if segments_result[0][0] != -1:
|
|
||||||
speech_start = segments_result[0][0]
|
|
||||||
if segments_result[0][1] != -1:
|
|
||||||
speech_end = segments_result[0][1]
|
|
||||||
|
|
||||||
return speech_start, speech_end
|
|
||||||
|
|
||||||
async def async_asr(self, websocket, audio_in):
|
|
||||||
"""
|
|
||||||
离线ASR处理
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
audio_in: 二进制音频数据
|
|
||||||
"""
|
|
||||||
if len(audio_in) > 0:
|
|
||||||
# 使用离线ASR模型处理音频
|
|
||||||
rec_result = self.model_asr.generate(
|
|
||||||
input=audio_in,
|
|
||||||
**websocket.status_dict_asr
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# 如果有标点符号模型且识别出文本,则添加标点
|
|
||||||
if self.model_punc is not None and len(rec_result["text"]) > 0:
|
|
||||||
rec_result = self.model_punc.generate(
|
|
||||||
input=rec_result["text"],
|
|
||||||
**websocket.status_dict_punc
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# 如果识别出文本,发送到客户端
|
|
||||||
if len(rec_result["text"]) > 0:
|
|
||||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
|
||||||
message = json.dumps({
|
|
||||||
"mode": mode,
|
|
||||||
"text": rec_result["text"],
|
|
||||||
"wav_name": websocket.wav_name,
|
|
||||||
"is_final": websocket.is_speaking,
|
|
||||||
})
|
|
||||||
await websocket.send(message)
|
|
||||||
else:
|
|
||||||
# 如果没有音频数据,发送空文本
|
|
||||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
|
||||||
message = json.dumps({
|
|
||||||
"mode": mode,
|
|
||||||
"text": "",
|
|
||||||
"wav_name": websocket.wav_name,
|
|
||||||
"is_final": websocket.is_speaking,
|
|
||||||
})
|
|
||||||
await websocket.send(message)
|
|
||||||
|
|
||||||
async def async_asr_online(self, websocket, audio_in):
|
|
||||||
"""
|
|
||||||
在线ASR处理
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
audio_in: 二进制音频数据
|
|
||||||
"""
|
|
||||||
if len(audio_in) > 0:
|
|
||||||
# 使用在线ASR模型处理音频
|
|
||||||
rec_result = self.model_asr_streaming.generate(
|
|
||||||
input=audio_in,
|
|
||||||
**websocket.status_dict_asr_online
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理)
|
|
||||||
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
|
|
||||||
return
|
|
||||||
|
|
||||||
# 如果识别出文本,发送到客户端
|
|
||||||
if len(rec_result["text"]):
|
|
||||||
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
|
|
||||||
message = json.dumps({
|
|
||||||
"mode": mode,
|
|
||||||
"text": rec_result["text"],
|
|
||||||
"wav_name": websocket.wav_name,
|
|
||||||
"is_final": websocket.is_speaking,
|
|
||||||
})
|
|
||||||
await websocket.send(message)
|
|
6
src/utils/__init__.py
Normal file
6
src/utils/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from .logger import get_module_logger, setup_root_logger
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_module_logger",
|
||||||
|
"setup_root_logger",
|
||||||
|
]
|
122
src/utils/data_format.py
Normal file
122
src/utils/data_format.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
处理各类音频数据与bytes的转换
|
||||||
|
"""
|
||||||
|
|
||||||
|
import wave
|
||||||
|
from pydub import AudioSegment
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
def wav_to_bytes(wav_path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
将WAV文件读取为bytes数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
wav_path (str): WAV文件的路径。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bytes: WAV文件的原始字节数据。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 如果WAV文件不存在。
|
||||||
|
wave.Error: 如果文件不是有效的WAV文件。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with wave.open(wav_path, "rb") as wf:
|
||||||
|
# 读取所有帧
|
||||||
|
frames = wf.readframes(wf.getnframes())
|
||||||
|
return frames
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
|
||||||
|
raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
|
||||||
|
except wave.Error as e:
|
||||||
|
raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_wav(
|
||||||
|
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
将bytes数据写入为WAV文件。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
bytes_data (bytes): 音频的字节数据。
|
||||||
|
wav_path (str): 保存WAV文件的路径。
|
||||||
|
nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)。
|
||||||
|
sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)。
|
||||||
|
framerate (int): 采样率 (例如 44100, 16000)。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
wave.Error: 如果写入WAV文件失败。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with wave.open(wav_path, "wb") as wf:
|
||||||
|
wf.setnchannels(nchannels)
|
||||||
|
wf.setsampwidth(sampwidth)
|
||||||
|
wf.setframerate(framerate)
|
||||||
|
wf.writeframes(bytes_data)
|
||||||
|
except wave.Error as e:
|
||||||
|
raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
|
||||||
|
except Exception as e:
|
||||||
|
# 捕获其他可能的写入错误
|
||||||
|
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def mp3_to_bytes(mp3_path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
将MP3文件转换为bytes数据 (原始PCM数据)。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
mp3_path (str): MP3文件的路径。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bytes: MP3文件解码后的原始PCM字节数据。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 如果MP3文件不存在。
|
||||||
|
pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
audio = AudioSegment.from_mp3(mp3_path)
|
||||||
|
# 获取原始PCM数据
|
||||||
|
return audio.raw_data
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
|
||||||
|
except Exception as e: # pydub 可能抛出多种解码相关的错误
|
||||||
|
raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_mp3(
|
||||||
|
bytes_data: bytes,
|
||||||
|
mp3_path: str,
|
||||||
|
frame_rate: int,
|
||||||
|
channels: int,
|
||||||
|
sample_width: int,
|
||||||
|
bitrate: str = "192k",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
将原始PCM bytes数据转换为MP3文件。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
bytes_data (bytes): 原始PCM字节数据。
|
||||||
|
mp3_path (str): 保存MP3文件的路径。
|
||||||
|
frame_rate (int): 原始PCM数据的采样率 (例如 44100)。
|
||||||
|
channels (int): 原始PCM数据的声道数 (例如 1 for mono, 2 for stereo)。
|
||||||
|
sample_width (int): 原始PCM数据的采样宽度 (字节数, 例如 2 for 16-bit)。
|
||||||
|
bitrate (str): MP3编码的比特率 (例如 "128k", "192k", "320k")。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
Exception: 如果转换或写入MP3文件失败。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从原始数据创建AudioSegment对象
|
||||||
|
audio = AudioSegment(
|
||||||
|
data=bytes_data,
|
||||||
|
sample_width=sample_width,
|
||||||
|
frame_rate=frame_rate,
|
||||||
|
channels=channels,
|
||||||
|
)
|
||||||
|
# 导出为MP3
|
||||||
|
audio.export(mp3_path, format="mp3", bitrate=bitrate)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
|
91
src/utils/logger.py
Normal file
91
src/utils/logger.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(
|
||||||
|
name: str = None,
|
||||||
|
level: str = "INFO",
|
||||||
|
log_file: Optional[str] = None,
|
||||||
|
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
date_format: str = "%Y-%m-%d %H:%M:%S",
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
设置并返回一个配置好的logger实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: logger的名称,默认为None(使用root logger)
|
||||||
|
level: 日志级别,默认为"INFO"
|
||||||
|
log_file: 日志文件路径,默认为None(仅控制台输出)
|
||||||
|
log_format: 日志格式
|
||||||
|
date_format: 日期格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logging.Logger: 配置好的logger实例
|
||||||
|
"""
|
||||||
|
# 获取logger实例
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
|
||||||
|
# 设置日志级别
|
||||||
|
level = getattr(logging, level.upper())
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
|
||||||
|
# 创建格式器
|
||||||
|
formatter = logging.Formatter(log_format, date_format)
|
||||||
|
|
||||||
|
# 添加控制台处理器
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# 如果指定了日志文件,添加文件处理器
|
||||||
|
if log_file:
|
||||||
|
# 确保日志目录存在
|
||||||
|
log_path = Path(log_file)
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
# 注意:移除了 propagate = False,允许日志传递
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
配置根日志器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: 日志级别
|
||||||
|
log_file: 日志文件路径
|
||||||
|
"""
|
||||||
|
setup_logger(None, level, log_file)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_logger(
|
||||||
|
module_name: str,
|
||||||
|
level: Optional[str] = None, # 改为可选参数
|
||||||
|
log_file: Optional[str] = None, # 一般不需要单独指定
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
获取模块级别的logger
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_name: 模块名称,通常传入__name__
|
||||||
|
level: 可选的日志级别,如果不指定则继承父级配置
|
||||||
|
log_file: 可选的日志文件路径,一般不需要指定
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(module_name)
|
||||||
|
|
||||||
|
# 只有显式指定了level才设置
|
||||||
|
if level:
|
||||||
|
logger.setLevel(getattr(logging, level.upper()))
|
||||||
|
|
||||||
|
# 只有显式指定了log_file才添加文件处理器
|
||||||
|
if log_file:
|
||||||
|
setup_logger(module_name, level or "INFO", log_file)
|
||||||
|
|
||||||
|
return logger
|
55
src/utils/mock_websocket.py
Normal file
55
src/utils/mock_websocket.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import queue
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
class MockWebSocketClient:
|
||||||
|
"""A mock WebSocket client to simulate a connection for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.sent_messages = []
|
||||||
|
self._is_closed = False
|
||||||
|
self.receive_queue = queue.Queue()
|
||||||
|
|
||||||
|
def send(self, message: dict):
|
||||||
|
"""Simulates sending a message (which is a dict)."""
|
||||||
|
if self._is_closed:
|
||||||
|
print("Warning: sending message on a closed websocket")
|
||||||
|
return
|
||||||
|
self.sent_messages.append(message)
|
||||||
|
print(f"Mock WS received: {message}")
|
||||||
|
|
||||||
|
def recv(self):
|
||||||
|
"""Simulates receiving data from the WebSocket."""
|
||||||
|
if self._is_closed:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
# Block until data is available, with a timeout to prevent hanging.
|
||||||
|
data = self.receive_queue.get(timeout=10)
|
||||||
|
if data is None:
|
||||||
|
self._is_closed = True
|
||||||
|
return data
|
||||||
|
except queue.Empty:
|
||||||
|
print("Mock WS recv timeout")
|
||||||
|
self._is_closed = True
|
||||||
|
return None
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Simulates closing the WebSocket connection."""
|
||||||
|
if not self._is_closed:
|
||||||
|
# Put None to unblock any waiting recv call
|
||||||
|
self.receive_queue.put(None)
|
||||||
|
self._is_closed = True
|
||||||
|
print("Mock WS closed")
|
||||||
|
|
||||||
|
def put_for_recv(self, data):
|
||||||
|
"""Puts data into the receive queue for the `recv` method to consume."""
|
||||||
|
if data is None:
|
||||||
|
return
|
||||||
|
# logger.debug("Mock WS put_for_recv length: %s", len(data))
|
||||||
|
self.receive_queue.put(data)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_closed(self):
|
||||||
|
return self._is_closed
|
66
src/websockets/adapter.py
Normal file
66
src/websockets/adapter.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import numpy as np
|
||||||
|
from fastapi import WebSocket
|
||||||
|
from typing import Union
|
||||||
|
import uuid
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
class FastAPIWebSocketAdapter:
|
||||||
|
"""
|
||||||
|
一个适配器类,用于将FastAPI的WebSocket对象包装成ASRRunner所期望的接口。
|
||||||
|
同时处理数据类型转换。
|
||||||
|
"""
|
||||||
|
def __init__(self, websocket: WebSocket, sample_rate: int = 16000, sample_width: int = 2):
|
||||||
|
self._ws = websocket
|
||||||
|
self._sample_rate = sample_rate
|
||||||
|
self._sample_width = sample_width
|
||||||
|
self._total_received = 0
|
||||||
|
async def recv(self) -> Union[np.ndarray, None]:
|
||||||
|
"""
|
||||||
|
接收来自FastAPI WebSocket的数据。
|
||||||
|
如果收到的是字节流,将其转换为Numpy数组。
|
||||||
|
如果收到的是文本"close",返回None以表示结束。
|
||||||
|
"""
|
||||||
|
message = await self._ws.receive()
|
||||||
|
if 'bytes' in message:
|
||||||
|
bytes_data = message['bytes']
|
||||||
|
audio_array = np.frombuffer(bytes_data, dtype=np.int16)
|
||||||
|
# 将int16转为float32
|
||||||
|
audio_array = audio_array.astype(np.float32)
|
||||||
|
# 归一化
|
||||||
|
audio_array = audio_array / 32768.0
|
||||||
|
|
||||||
|
# 使用回车符 \r 覆盖打印进度
|
||||||
|
self._total_received += len(bytes_data)
|
||||||
|
print(f"🎧 [Adapter] 正在接收音频... 总计: {self._total_received / 1024:.2f} KB", end='\r')
|
||||||
|
|
||||||
|
return audio_array
|
||||||
|
elif 'text' in message and message['text'].lower() == 'close':
|
||||||
|
print("\n🏁 [Adapter] 收到 'close' 信号。") # 在收到结束信号时换行
|
||||||
|
return None # 返回 None 来作为结束信号
|
||||||
|
return np.array([]) # 返回空数组以忽略其他类型的消息
|
||||||
|
|
||||||
|
async def send(self, message: dict):
|
||||||
|
"""
|
||||||
|
将字典消息作为JSON发送给客户端。
|
||||||
|
在发送前,将所有UUID对象转换为字符串以确保可序列化。
|
||||||
|
"""
|
||||||
|
def convert_uuids(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: convert_uuids(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [convert_uuids(elem) for elem in obj]
|
||||||
|
elif isinstance(obj, uuid.UUID):
|
||||||
|
return str(obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
serializable_message = convert_uuids(message)
|
||||||
|
logger.info(f"[Adapter] 发送消息: {serializable_message}")
|
||||||
|
await self._ws.send_json(serializable_message)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""
|
||||||
|
关闭WebSocket连接。
|
||||||
|
"""
|
||||||
|
await self._ws.close()
|
3
src/websockets/endpoint/__init__.py
Normal file
3
src/websockets/endpoint/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .asr_endpoint import router as asr_router
|
||||||
|
|
||||||
|
__all__ = ["asr_router"]
|
89
src/websockets/endpoint/asr_endpoint.py
Normal file
89
src/websockets/endpoint/asr_endpoint.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Request
|
||||||
|
from src.websockets.adapter import FastAPIWebSocketAdapter
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.websocket("/asr/{session_id}")
|
||||||
|
async def asr_websocket_endpoint(
|
||||||
|
websocket: WebSocket,
|
||||||
|
session_id: str,
|
||||||
|
mode: str = Query(default="sender", enum=["sender", "receiver"])
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
ASR WebSocket 端点
|
||||||
|
|
||||||
|
- **session_id**: 标识一个识别会话的唯一ID.
|
||||||
|
- **mode**: 客户端模式.
|
||||||
|
- `sender`: 作为音频发送方加入,将创建一个新的识别会话.
|
||||||
|
- `receiver`: 作为结果接收方加入,订阅一个已存在的会话.
|
||||||
|
"""
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
# 从websocket.app.state获取全局的ASRRunner实例
|
||||||
|
asr_runner = websocket.app.state.asr_runner
|
||||||
|
|
||||||
|
# 创建WebSocket适配器
|
||||||
|
# 注意:这里的audio_config应该与ASRRunner中的默认配置一致
|
||||||
|
audio_config = asr_runner._default_audio_config
|
||||||
|
adapter = FastAPIWebSocketAdapter(
|
||||||
|
websocket,
|
||||||
|
sample_rate=audio_config.sample_rate,
|
||||||
|
sample_width=audio_config.sample_width
|
||||||
|
)
|
||||||
|
|
||||||
|
if mode == "sender":
|
||||||
|
logger.info(f"客户端 {websocket.client} 作为 'sender' 加入会话: {session_id}")
|
||||||
|
# 创建一个新的SAR会话
|
||||||
|
sar_id = asr_runner.new_SAR(ws=adapter, name=session_id)
|
||||||
|
if sar_id is None:
|
||||||
|
logger.error(f"为会话 {session_id} 创建SAR失败")
|
||||||
|
await websocket.close(code=1011, reason="Failed to create ASR session")
|
||||||
|
return
|
||||||
|
|
||||||
|
sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
|
||||||
|
try:
|
||||||
|
# 端点函数等待后台任务完成。
|
||||||
|
# 真正的接收逻辑在SAR的_run方法中,该方法由new_SAR作为后台任务启动。
|
||||||
|
# 当客户端断开连接时,adapter.recv()会抛出异常,
|
||||||
|
# _run任务会捕获它,然后停止并清理,最后任务结束。
|
||||||
|
if sar and sar._task:
|
||||||
|
await sar._task
|
||||||
|
else:
|
||||||
|
# 如果任务没有被创建,记录一个错误并关闭连接
|
||||||
|
logger.error(f"SAR任务未能在会话 {session_id} 中启动")
|
||||||
|
await websocket.close(code=1011, reason="Failed to start ASR task")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 捕获任何意外的错误
|
||||||
|
logger.error(f"会话 {session_id} 的 'sender' 端点发生未知错误: {e}")
|
||||||
|
finally:
|
||||||
|
logger.info(f"'sender' {websocket.client} 在会话 {session_id} 的连接处理已结束")
|
||||||
|
|
||||||
|
elif mode == "receiver":
|
||||||
|
logger.info(f"客户端 {websocket.client} 作为 'receiver' 加入会话: {session_id}")
|
||||||
|
# 加入一个已存在的SAR会话
|
||||||
|
joined = asr_runner.join_SAR(ws=adapter, name=session_id)
|
||||||
|
if not joined:
|
||||||
|
logger.warning(f"无法找到会话 {session_id},'receiver' {websocket.client} 加入失败")
|
||||||
|
await websocket.close(code=1011, reason=f"Session '{session_id}' not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Receiver只需要保持连接,等待从SAR广播过来的消息
|
||||||
|
# 这个循环也用于检测断开
|
||||||
|
while True:
|
||||||
|
await websocket.receive_text()
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
logger.info(f"'receiver' {websocket.client} 在会话 {session_id} 中断开连接")
|
||||||
|
# Receiver断开时,需要将其从SAR的接收者列表中移除
|
||||||
|
sar = next((s for s in asr_runner._SAR_list if s._name == session_id), None)
|
||||||
|
if sar:
|
||||||
|
sar.delete_receiver(adapter)
|
||||||
|
logger.info(f"已从会话 {session_id} 中移除 'receiver' {websocket.client}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 理论上,由于FastAPI的enum校验,这里的代码不会被执行
|
||||||
|
logger.error(f"无效的模式: {mode}")
|
||||||
|
await websocket.close(code=1003, reason="Invalid mode specified")
|
9
src/websockets/router.py
Normal file
9
src/websockets/router.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
from .endpoint import asr_endpoint
|
||||||
|
|
||||||
|
websocket_router = APIRouter()
|
||||||
|
|
||||||
|
# 包含ASR端点路由
|
||||||
|
websocket_router.include_router(asr_endpoint.router)
|
||||||
|
|
||||||
|
__all__ = ["websocket_router"]
|
26
test_main.py
Normal file
26
test_main.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
测试主函数
|
||||||
|
请在tests目录下创建测试文件, 并在此文件中调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
from tests.pipeline.asr_test import test_asr_pipeline
|
||||||
|
from src.utils.logger import get_module_logger, setup_root_logger
|
||||||
|
from tests.runner.asr_runner_test import test_asr_runner
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
# 清空logs/test_main.log文件
|
||||||
|
with open("logs/test_main.log", "w") as f:
|
||||||
|
f.truncate()
|
||||||
|
|
||||||
|
# from tests.functor.vad_test import test_vad_functor
|
||||||
|
# logger.info("开始测试VAD函数器")
|
||||||
|
# test_vad_functor()
|
||||||
|
|
||||||
|
# logger.info("开始测试ASR管道")
|
||||||
|
# test_asr_pipeline()
|
||||||
|
|
||||||
|
logger.info("开始测试ASRRunner")
|
||||||
|
asyncio.run(test_asr_runner())
|
BIN
tests/XT_ZZY.wav
Normal file
BIN
tests/XT_ZZY.wav
Normal file
Binary file not shown.
BIN
tests/XT_ZZY_denoise.wav
Normal file
BIN
tests/XT_ZZY_denoise.wav
Normal file
Binary file not shown.
124
tests/functor/vad_test.py
Normal file
124
tests/functor/vad_test.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
Functor测试
|
||||||
|
VAD测试
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor.vad_functor import VADFunctor
|
||||||
|
from src.functor.asr_functor import ASRFunctor
|
||||||
|
from src.functor.spk_functor import SPKFunctor
|
||||||
|
from queue import Queue, Empty
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||||
|
from src.utils.data_format import wav_to_bytes
|
||||||
|
import time
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
from pydub import AudioSegment
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
# 观察参数
|
||||||
|
OVERWATCH = False
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
|
def test_vad_functor():
|
||||||
|
# 加载模型
|
||||||
|
args = {
|
||||||
|
"asr_model": "paraformer-zh",
|
||||||
|
"asr_model_revision": "v2.0.4",
|
||||||
|
"vad_model": "fsmn-vad",
|
||||||
|
"vad_model_revision": "v2.0.4",
|
||||||
|
"auto_update": False,
|
||||||
|
}
|
||||||
|
model_loader.load_models(args)
|
||||||
|
# 加载数据
|
||||||
|
f_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
chunk_size=200,
|
||||||
|
chunk_stride=1600,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_width=16,
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||||
|
audio_config.chunk_stride = chunk_stride
|
||||||
|
# 创建输入队列
|
||||||
|
input_queue = Queue()
|
||||||
|
vad2asr_queue = Queue()
|
||||||
|
vad2spk_queue = Queue()
|
||||||
|
# 创建音频数据列表
|
||||||
|
audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
|
# 创建VAD函数器
|
||||||
|
vad_functor = VADFunctor()
|
||||||
|
# 设置输入队列
|
||||||
|
vad_functor.set_input_queue(input_queue)
|
||||||
|
# 设置音频配置
|
||||||
|
vad_functor.set_audio_config(audio_config)
|
||||||
|
# 设置音频数据列表
|
||||||
|
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
|
||||||
|
# 设置回调函数
|
||||||
|
vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
|
||||||
|
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
|
||||||
|
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
|
||||||
|
# 设置模型
|
||||||
|
vad_functor.set_model({"vad": model_loader.models["vad"]})
|
||||||
|
# 启动VAD函数器
|
||||||
|
vad_functor.run()
|
||||||
|
|
||||||
|
# 创建ASR函数器
|
||||||
|
asr_functor = ASRFunctor()
|
||||||
|
# 设置输入队列
|
||||||
|
asr_functor.set_input_queue(vad2asr_queue)
|
||||||
|
# 设置音频配置
|
||||||
|
asr_functor.set_audio_config(audio_config)
|
||||||
|
# 设置回调函数
|
||||||
|
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
|
||||||
|
# 设置模型
|
||||||
|
asr_functor.set_model({"asr": model_loader.models["asr"]})
|
||||||
|
# 启动ASR函数器
|
||||||
|
asr_functor.run()
|
||||||
|
|
||||||
|
# 创建SPK函数器
|
||||||
|
spk_functor = SPKFunctor()
|
||||||
|
# 设置输入队列
|
||||||
|
spk_functor.set_input_queue(vad2spk_queue)
|
||||||
|
# 设置音频配置
|
||||||
|
spk_functor.set_audio_config(audio_config)
|
||||||
|
# 设置回调函数
|
||||||
|
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
|
||||||
|
# 设置模型
|
||||||
|
spk_functor.set_model(
|
||||||
|
{
|
||||||
|
# 'spk': model_loader.models['spk']
|
||||||
|
"spk": "fake_spk"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# 启动SPK函数器
|
||||||
|
spk_functor.run()
|
||||||
|
|
||||||
|
f_binary = f_data
|
||||||
|
audio_clip_len = 200
|
||||||
|
print(
|
||||||
|
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
|
||||||
|
)
|
||||||
|
for i in range(0, len(f_binary), audio_clip_len):
|
||||||
|
binary_data = f_binary[i : i + audio_clip_len]
|
||||||
|
input_queue.put(binary_data)
|
||||||
|
# 等待VAD函数器结束
|
||||||
|
|
||||||
|
vad_functor.stop()
|
||||||
|
print("[vad_test] VAD函数器结束")
|
||||||
|
|
||||||
|
asr_functor.stop()
|
||||||
|
print("[vad_test] ASR函数器结束")
|
||||||
|
|
||||||
|
# 保存音频数据
|
||||||
|
if OVERWATCH:
|
||||||
|
for index in range(len(audio_binary_data_list)):
|
||||||
|
save_path = f"tests/vad_test_output_{index}.wav"
|
||||||
|
soundfile.write(
|
||||||
|
save_path, audio_binary_data_list[index].binary_data, sample_rate
|
||||||
|
)
|
121
tests/modelsuse.py
Normal file
121
tests/modelsuse.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
模型使用测试
|
||||||
|
此处主要用于各类调用模型的处理数据与输出格式
|
||||||
|
请在主目录下test_main.py中调用
|
||||||
|
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from funasr import AutoModel
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from src.models import VADResponse
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
在线VAD模型使用
|
||||||
|
"""
|
||||||
|
chunk_size = 100 # ms
|
||||||
|
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||||
|
|
||||||
|
vad_result = VADResponse()
|
||||||
|
vad_result.time_chunk_index_callback = lambda index: print(f"回调: {index}")
|
||||||
|
items = []
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
|
||||||
|
for i in range(total_chunk_num):
|
||||||
|
time.sleep(0.1)
|
||||||
|
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||||
|
is_final = i == total_chunk_num - 1
|
||||||
|
res = model.generate(
|
||||||
|
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
if len(res[0]["value"]):
|
||||||
|
vad_result += VADResponse.from_raw(res)
|
||||||
|
for item in res[0]["value"]:
|
||||||
|
items.append(item)
|
||||||
|
vad_result.process_time_chunk()
|
||||||
|
|
||||||
|
# for item in items:
|
||||||
|
# print(item)
|
||||||
|
return vad_result
|
||||||
|
|
||||||
|
|
||||||
|
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
在线VAD模型使用
|
||||||
|
测试LogicTrager
|
||||||
|
在Rebuild版本后LogicTrager中已弃用
|
||||||
|
"""
|
||||||
|
from src.logic_trager import LogicTrager
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from src.config import parse_args
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# from src.functor.model_loader import load_models
|
||||||
|
# models = load_models(args)
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
|
||||||
|
models = ModelLoader(args)
|
||||||
|
|
||||||
|
chunk_size = 200 # ms
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
logic_trager = LogicTrager(models=models, audio_config=audio_config)
|
||||||
|
for i in range(len(speech) // chunk_stride + 1):
|
||||||
|
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||||
|
logic_trager.push_binary_data(speech_chunk)
|
||||||
|
|
||||||
|
# for item in items:
|
||||||
|
# print(item)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
ASR模型使用
|
||||||
|
离线ASR模型使用
|
||||||
|
"""
|
||||||
|
from funasr import AutoModel
|
||||||
|
|
||||||
|
model = AutoModel(
|
||||||
|
model="paraformer-zh",
|
||||||
|
model_revision="v2.0.4",
|
||||||
|
vad_model="fsmn-vad",
|
||||||
|
vad_model_revision="v2.0.4",
|
||||||
|
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
|
||||||
|
spk_model="cam++",
|
||||||
|
spk_model_revision="v2.0.2",
|
||||||
|
spk_mode="vad_segment",
|
||||||
|
auto_update=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
result = model.generate(speech)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# 请在主目录下调用test_main.py文件进行测试
|
||||||
|
# vad_result = vad_model_use_online("tests/vad_example.wav")
|
||||||
|
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
|
||||||
|
# print(vad_result)
|
94
tests/pipeline/asr_test.py
Normal file
94
tests/pipeline/asr_test.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
"""
|
||||||
|
Pipeline测试
|
||||||
|
VAD+ASR+SPK(FAKE)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
from src.pipeline import PipelineFactory
|
||||||
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
from queue import Queue
|
||||||
|
import soundfile
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
OVAERWATCH = False
|
||||||
|
|
||||||
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
|
def test_asr_pipeline():
|
||||||
|
# 加载模型
|
||||||
|
args = {
|
||||||
|
"asr_model": "paraformer-zh",
|
||||||
|
"asr_model_revision": "v2.0.4",
|
||||||
|
"vad_model": "fsmn-vad",
|
||||||
|
"vad_model_revision": "v2.0.4",
|
||||||
|
"spk_model": "cam++",
|
||||||
|
"spk_model_revision": "v2.0.2",
|
||||||
|
"audio_update": False,
|
||||||
|
}
|
||||||
|
models = model_loader.load_models(args)
|
||||||
|
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
chunk_size=200,
|
||||||
|
chunk_stride=1600,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_width=16,
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||||
|
audio_config.chunk_stride = chunk_stride
|
||||||
|
|
||||||
|
# 创建参数Dict
|
||||||
|
config = {
|
||||||
|
"audio_config": audio_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建音频数据列表
|
||||||
|
audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
|
input_queue = Queue()
|
||||||
|
|
||||||
|
# 创建Pipeline
|
||||||
|
# asr_pipeline = ASRPipeline()
|
||||||
|
# asr_pipeline.set_models(models)
|
||||||
|
# asr_pipeline.set_config(config)
|
||||||
|
# asr_pipeline.set_audio_binary(audio_binary_data_list)
|
||||||
|
# asr_pipeline.set_input_queue(input_queue)
|
||||||
|
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
|
||||||
|
# asr_pipeline.bake()
|
||||||
|
asr_pipeline = PipelineFactory.create_pipeline(
|
||||||
|
pipeline_name = "ASRpipeline",
|
||||||
|
models=models,
|
||||||
|
config=config,
|
||||||
|
audio_binary=audio_binary_data_list,
|
||||||
|
input_queue=input_queue,
|
||||||
|
callback=lambda x: print(f"pipeline callback: {x}")
|
||||||
|
)
|
||||||
|
|
||||||
|
asr_pipeline.bake()
|
||||||
|
# 运行Pipeline
|
||||||
|
asr_instance = asr_pipeline.run()
|
||||||
|
|
||||||
|
|
||||||
|
audio_clip_len = 200
|
||||||
|
print(
|
||||||
|
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
|
||||||
|
)
|
||||||
|
for i in range(0, len(audio_data), audio_clip_len):
|
||||||
|
input_queue.put(audio_data[i : i + audio_clip_len])
|
||||||
|
|
||||||
|
# time.sleep(10)
|
||||||
|
# input_queue.put(None)
|
||||||
|
|
||||||
|
# 等待Pipeline结束
|
||||||
|
# asr_instance.join()
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
asr_pipeline.stop()
|
||||||
|
# asr_pipeline.stop()
|
132
tests/runner/asr_runner_test.py
Normal file
132
tests/runner/asr_runner_test.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
"""
|
||||||
|
ASRRunner test
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import soundfile
|
||||||
|
import numpy as np
|
||||||
|
from src.runner.ASRRunner import ASRRunner
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
from asyncio import Queue as AsyncQueue
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMockWebSocketClient:
|
||||||
|
"""一个用于测试目的的异步WebSocket客户端模拟器。"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._recv_q = AsyncQueue()
|
||||||
|
self._send_q = AsyncQueue()
|
||||||
|
|
||||||
|
def put_for_recv(self, item):
|
||||||
|
"""允许测试将数据送入模拟的WebSocket中。"""
|
||||||
|
self._recv_q.put_nowait(item)
|
||||||
|
|
||||||
|
async def get_from_send(self):
|
||||||
|
"""允许测试从模拟的WebSocket中获取结果。"""
|
||||||
|
return await self._send_q.get()
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
"""ASRRunner将调用此方法来获取数据。"""
|
||||||
|
return await self._recv_q.get()
|
||||||
|
|
||||||
|
async def send(self, item):
|
||||||
|
"""ASRRunner将通过回调调用此方法来发送结果。"""
|
||||||
|
logger.info(f"Mock WS 收到结果: {item}")
|
||||||
|
await self._send_q.put(item)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""一个模拟的关闭方法。"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def test_asr_runner():
|
||||||
|
"""
|
||||||
|
针对ASRRunner的端到端测试,已适配异步操作。
|
||||||
|
1. 加载模型.
|
||||||
|
2. 配置并初始化ASRRunner.
|
||||||
|
3. 创建一个异步的模拟WebSocket客户端.
|
||||||
|
4. 在Runner中启动一个新的SenderAndReceiver (SAR)实例.
|
||||||
|
5. 通过模拟的WebSocket流式传输音频数据.
|
||||||
|
6. 等待处理任务完成并断言其无错误运行.
|
||||||
|
"""
|
||||||
|
# 1. 加载模型
|
||||||
|
model_loader = ModelLoader()
|
||||||
|
args = {
|
||||||
|
"asr_model": "paraformer-zh",
|
||||||
|
"asr_model_revision": "v2.0.4",
|
||||||
|
"vad_model": "fsmn-vad",
|
||||||
|
"vad_model_revision": "v2.0.4",
|
||||||
|
"spk_model": "cam++",
|
||||||
|
"spk_model_revision": "v2.0.2",
|
||||||
|
}
|
||||||
|
models = model_loader.load_models(args)
|
||||||
|
audio_file_path = "tests/XT_ZZY_denoise.wav"
|
||||||
|
audio_data, sample_rate = soundfile.read(audio_file_path, dtype='float32')
|
||||||
|
logger.info(
|
||||||
|
f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
|
||||||
|
)
|
||||||
|
# 进一步详细打印audio_data数据类型
|
||||||
|
# 详细打印audio_data的类型和结构信息,便于调试
|
||||||
|
logger.info(f"audio_data 类型: {type(audio_data)}")
|
||||||
|
logger.info(f"audio_data dtype: {getattr(audio_data, 'dtype', '未知')}")
|
||||||
|
logger.info(f"audio_data shape: {getattr(audio_data, 'shape', '未知')}")
|
||||||
|
logger.info(f"audio_data ndim: {getattr(audio_data, 'ndim', '未知')}")
|
||||||
|
logger.info(f"audio_data 示例前10个值: {audio_data[:10] if hasattr(audio_data, '__getitem__') else '不可切片'}")
|
||||||
|
|
||||||
|
# 2. 配置音频
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
chunk_size=200, # ms
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_width=2, # 16-bit
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||||
|
|
||||||
|
# 3. 设置ASRRunner
|
||||||
|
asr_runner = ASRRunner()
|
||||||
|
asr_runner.set_default_config(
|
||||||
|
audio_config=audio_config,
|
||||||
|
models=models,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 创建模拟WebSocket并启动SAR
|
||||||
|
mock_ws = AsyncMockWebSocketClient()
|
||||||
|
sar_id = asr_runner.new_SAR(
|
||||||
|
ws=mock_ws,
|
||||||
|
name="test_sar",
|
||||||
|
)
|
||||||
|
assert sar_id is not None, "创建新的SAR实例失败"
|
||||||
|
|
||||||
|
# 获取SAR实例以等待其任务
|
||||||
|
sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
|
||||||
|
assert sar is not None, "无法从Runner中获取SAR实例。"
|
||||||
|
assert sar._task is not None, "SAR任务未被创建。"
|
||||||
|
|
||||||
|
# 5. 在后台任务中模拟流式音频
|
||||||
|
async def feed_audio():
|
||||||
|
logger.info("Feeder任务已启动:开始流式传输音频数据...")
|
||||||
|
# 每次发送100ms的音频
|
||||||
|
audio_clip_len = int(sample_rate * 0.1)
|
||||||
|
for i in range(0, len(audio_data), audio_clip_len):
|
||||||
|
chunk = audio_data[i : i + audio_clip_len].astype(np.float32)
|
||||||
|
if chunk.size == 0:
|
||||||
|
break
|
||||||
|
mock_ws.put_for_recv(chunk)
|
||||||
|
await asyncio.sleep(0.1) # 模拟实时流
|
||||||
|
|
||||||
|
# 发送None来表示音频流结束
|
||||||
|
mock_ws.put_for_recv(None)
|
||||||
|
logger.info("Feeder任务已完成:所有音频数据已发送。")
|
||||||
|
|
||||||
|
feeder_task = asyncio.create_task(feed_audio())
|
||||||
|
|
||||||
|
# 6. 等待SAR处理完成
|
||||||
|
# SAR任务在从模拟WebSocket接收到None后会结束
|
||||||
|
await sar._task
|
||||||
|
await feeder_task # 确保feeder也已完成
|
||||||
|
|
||||||
|
logger.info("ASRRunner测试成功完成。")
|
25
tests/spkverify_use.py
Normal file
25
tests/spkverify_use.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
sv_pipeline = pipeline(
|
||||||
|
task='speaker-verification',
|
||||||
|
model='iic/speech_campplus_sv_zh-cn_16k-common',
|
||||||
|
model_revision='v1.0.0'
|
||||||
|
)
|
||||||
|
speaker1_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_a_cn_16k.wav'
|
||||||
|
speaker1_b_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_b_cn_16k.wav'
|
||||||
|
speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
|
||||||
|
# 相同说话人语音
|
||||||
|
result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
|
||||||
|
print(result)
|
||||||
|
# 不同说话人语音
|
||||||
|
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
|
||||||
|
print(result)
|
||||||
|
# 可以自定义得分阈值来进行识别,阈值越高,判定为同一人的条件越严格
|
||||||
|
result = sv_pipeline([speaker1_a_wav, speaker1_a_wav], thr=0.6)
|
||||||
|
print(result)
|
||||||
|
# 可以传入output_emb参数,输出结果中就会包含提取到的说话人embedding
|
||||||
|
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
|
||||||
|
print(result['embs'], result['outputs'])
|
||||||
|
# result1 = sv_pipeline([speaker1_a_wav], output_emb=True)
|
||||||
|
# print(result1['embs'], result1['outputs'])
|
||||||
|
# 可以传入save_dir参数,提取到的说话人embedding会存储在save_dir目录中
|
||||||
|
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
|
@ -10,13 +10,13 @@ import os
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
# 将src目录添加到路径
|
# 将src目录添加到路径
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from src.config import parse_args
|
from src.config import parse_args
|
||||||
|
|
||||||
|
|
||||||
def test_default_args():
|
def test_default_args():
|
||||||
"""测试默认参数值"""
|
"""测试默认参数值"""
|
||||||
with patch('sys.argv', ['script.py']):
|
with patch("sys.argv", ["script.py"]):
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 检查服务器参数
|
# 检查服务器参数
|
||||||
@ -46,17 +46,24 @@ def test_default_args():
|
|||||||
def test_custom_args():
|
def test_custom_args():
|
||||||
"""测试自定义参数值"""
|
"""测试自定义参数值"""
|
||||||
test_args = [
|
test_args = [
|
||||||
'script.py',
|
"script.py",
|
||||||
'--host', 'localhost',
|
"--host",
|
||||||
'--port', '8080',
|
"localhost",
|
||||||
'--certfile', 'cert.pem',
|
"--port",
|
||||||
'--keyfile', 'key.pem',
|
"8080",
|
||||||
'--asr_model', 'custom_model',
|
"--certfile",
|
||||||
'--ngpu', '0',
|
"cert.pem",
|
||||||
'--device', 'cpu'
|
"--keyfile",
|
||||||
|
"key.pem",
|
||||||
|
"--asr_model",
|
||||||
|
"custom_model",
|
||||||
|
"--ngpu",
|
||||||
|
"0",
|
||||||
|
"--device",
|
||||||
|
"cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch('sys.argv', test_args):
|
with patch("sys.argv", test_args):
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 检查自定义参数
|
# 检查自定义参数
|
||||||
|
BIN
tests/vad_example.wav
Normal file
BIN
tests/vad_example.wav
Normal file
Binary file not shown.
110
tests/websocket/websocket_asr.py
Normal file
110
tests/websocket/websocket_asr.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
import asyncio
|
||||||
|
import websockets
|
||||||
|
import soundfile as sf
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# --- 配置 ---
|
||||||
|
HOST = "localhost"
|
||||||
|
PORT = 11096
|
||||||
|
SESSION_ID = str(uuid.uuid4())
|
||||||
|
SENDER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=sender"
|
||||||
|
RECEIVER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=receiver"
|
||||||
|
|
||||||
|
AUDIO_FILE_PATH = "tests/XT_ZZY_denoise.wav" # 确保此测试文件存在且为 16kHz, 16-bit, 单声道
|
||||||
|
CHUNK_DURATION_MS = 100 # 每次发送100ms的音频数据
|
||||||
|
CHUNK_SIZE = int(16000 * 2 * CHUNK_DURATION_MS / 1000) # 3200 bytes
|
||||||
|
|
||||||
|
async def run_receiver():
|
||||||
|
"""作为接收者连接,并打印收到的所有消息。"""
|
||||||
|
print(f"▶️ [Receiver] 尝试连接到: {RECEIVER_URI}")
|
||||||
|
try:
|
||||||
|
async with websockets.connect(RECEIVER_URI) as websocket:
|
||||||
|
print("✅ [Receiver] 连接成功,等待消息...")
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message = await websocket.recv()
|
||||||
|
print(f"🎧 [Receiver] 收到结果: {message}")
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
print(f"✅ [Receiver] 连接已由服务器正常关闭: {e.reason}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ [Receiver] 连接失败: {e}")
|
||||||
|
|
||||||
|
async def run_sender():
|
||||||
|
"""
|
||||||
|
作为发送者连接,同时负责发送音频和接收自己会话的广播结果。
|
||||||
|
"""
|
||||||
|
await asyncio.sleep(1) # 等待receiver有机会先连接
|
||||||
|
print(f"▶️ [Sender] 尝试连接到: {SENDER_URI}")
|
||||||
|
try:
|
||||||
|
async with websockets.connect(SENDER_URI) as websocket:
|
||||||
|
print("✅ [Sender] 连接成功。")
|
||||||
|
|
||||||
|
# --- 并行任务:接收消息 ---
|
||||||
|
async def receive_task():
|
||||||
|
print("▶️ [Sender-Receiver] 开始监听广播消息...")
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message = await websocket.recv()
|
||||||
|
print(f"🎧 [Sender-Receiver] 收到结果: {message}")
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
print("✅ [Sender-Receiver] 连接已关闭,停止监听。")
|
||||||
|
|
||||||
|
receiver_sub_task = asyncio.create_task(receive_task())
|
||||||
|
|
||||||
|
# --- 主任务:发送音频 ---
|
||||||
|
try:
|
||||||
|
print("▶️ [Sender] 准备发送音频...")
|
||||||
|
audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='int16')
|
||||||
|
if sample_rate != 16000:
|
||||||
|
print(f"❌ [Sender] 错误:音频文件采样率必须是 16kHz。")
|
||||||
|
receiver_sub_task.cancel()
|
||||||
|
return
|
||||||
|
|
||||||
|
total_samples = len(audio_data)
|
||||||
|
chunk_samples = CHUNK_SIZE // 2
|
||||||
|
samples_sent = 0
|
||||||
|
print(f"音频加载成功,总长度: {total_samples} samples。开始分块发送...")
|
||||||
|
|
||||||
|
for i in range(0, total_samples, chunk_samples):
|
||||||
|
chunk = audio_data[i:i + chunk_samples]
|
||||||
|
if len(chunk) == 0:
|
||||||
|
break
|
||||||
|
await websocket.send(chunk.tobytes())
|
||||||
|
samples_sent += len(chunk)
|
||||||
|
print(f"🎧 [Sender] 正在发送: {samples_sent}/{total_samples} samples", end="\r")
|
||||||
|
await asyncio.sleep(CHUNK_DURATION_MS / 1000)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("🏁 [Sender] 音频流发送完毕,发送 'close' 信号。")
|
||||||
|
await websocket.send("close")
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"❌ [Sender] 错误:找不到音频文件 {AUDIO_FILE_PATH}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ [Sender] 发送过程中发生错误: {e}")
|
||||||
|
|
||||||
|
# 等待接收任务自然结束(当连接关闭时)
|
||||||
|
await receiver_sub_task
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ [Sender] 连接失败: {e}")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""同时运行 sender 和 receiver 任务。"""
|
||||||
|
print("--- 开始 WebSocket ASR 服务端到端测试 ---")
|
||||||
|
print(f"会话 ID: {SESSION_ID}")
|
||||||
|
|
||||||
|
# 创建 receiver 和 sender 任务
|
||||||
|
sender_task = asyncio.create_task(run_sender())
|
||||||
|
await asyncio.sleep(7)
|
||||||
|
receiver_task = asyncio.create_task(run_receiver())
|
||||||
|
|
||||||
|
# 等待两个任务完成
|
||||||
|
await asyncio.gather(receiver_task, sender_task)
|
||||||
|
|
||||||
|
print("--- 测试结束 ---")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 在运行此脚本前,请确保 FastAPI 服务器正在运行。
|
||||||
|
# python main.py
|
||||||
|
asyncio.run(main())
|
Loading…
x
Reference in New Issue
Block a user