Compare commits

...

7 Commits
master ... dev

Author SHA1 Message Date
Ziyang.Zhang
52a0fdfd89 [Project refactor / speaker recognition] Completed the project refactor: an asr_server built on FunASR, FastAPI, and WebSocket with VAD, ASR, and SPK (speaker recognition); restructured the server into a three-tier implementation layer of Functor (feature modules), Pipeline (processing flow), and Runner (concurrent multi-user management over single Pipelines), plus a FastAPI & WebSocket user access layer; dependencies are managed with uv, and one-command Docker deployment is supported. 2025-07-10 10:17:21 +08:00
Ziyang.Zhang
1392168126 Merge branch 'feature_logger' into dev
[Feature] Added a logger to manage log output and tested the ASR, PUNC, and SPK models;
[BUG] Found a bug: some FunASR modules modify the logger; the fix needs further discussion
2025-04-16 14:30:40 +08:00
Ziyang.Zhang
eff22cb33e [Feature] Tested the downstream ASR, PUNC, and SPK behavior. BUG: after calling FunASR, the logger settings are altered, changing the log format and duplicating output. 2025-04-16 14:30:11 +08:00
Ziyang.Zhang
66c9477e4b [Feature] Added src/utils/logger to control the program's log output, including a root configurator and a logger factory. 2025-04-16 10:46:09 +08:00
9d522fa137 Merge branch 'feature_vad' into dev
[Project structure] Separated model loading, feature implementation, and the overall workflow
[Feature] Standardized data formats with pydantic; developed the VAD (voice endpoint detection) functor
[Testing] Completed local streaming (online) VAD detection and a workflow test of LogicTrager (VAD and VAD results only)
[Upcoming] 1. Finish ASR, timestamps, and speaker recognition; 2. wire up the WebSocket service.
2025-04-15 17:18:48 +08:00
f7138dcb39 [Feature] Adjusted the VAD workflow and standardized VAD output to the AudioBinary_Chunk format in models/audiobinary; fully tested the LogicTrager VAD online flow. 2025-04-15 17:15:13 +08:00
8b69ff195f [Feature] Add /tests/modelsuse to test real-time VAD detection. 2025-04-15 13:53:06 +08:00
70 changed files with 8824 additions and 845 deletions

19
.cursorrules Normal file

@ -0,0 +1,19 @@
You are an AI assistant specialized in Python development. Your approach emphasizes:
1. Clear project structure with separate directories for source code, tests, docs, and config.
2. Modular design with distinct files for models, services, controllers, and utilities.
3. Configuration management using environment variables.
4. Robust error handling and logging, including context capture.
5. Comprehensive testing with pytest.
6. Detailed documentation using docstrings and README files.
7. Dependency management via https://github.com/astral-sh/rye and virtual environments.
8. Code style consistency using Ruff.
9. CI/CD implementation with GitHub Actions or GitLab CI.
10. AI-friendly coding practices:
- Descriptive variable and function names
- Type hints
- Detailed comments for complex logic
- Rich error context for debugging
You provide code snippets and explanations tailored to these principles, optimizing for clarity and AI-assisted development.

1
.gitignore vendored

@ -29,6 +29,7 @@ env/
.coverage
htmlcov/
.pytest_cache/
savePath/
# Editor-related
.idea/

@ -1,27 +0,0 @@
FROM python:3.9-slim
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
libsndfile1 \
ffmpeg \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
# Expose the WebSocket port
EXPOSE 10095
# Start the service
CMD ["python", "src/server.py"]

191
README.md

@ -1,113 +1,142 @@
# FunASR FastAPI WebSocket Service
A high-performance, real-time speech recognition WebSocket service built on FunASR and FastAPI. Its core feature is a "one sender, many receivers" broadcast mode, suited to scenarios that distribute the recognition results of a single audio source to many clients, such as live meeting captions, online education, and livestream transcription.
## ✨ Features
- **Real-time speech processing**: integrates FunASR models for voice activity detection (VAD), speech recognition (ASR), and speaker recognition (SPK).
- **Streaming WebSocket API**: a low-latency, bidirectional real-time interface.
- **"One sender, many receivers" architecture**:
  - **Sender**: a single client acts as the audio source and continuously streams audio to the server.
  - **Receiver**: any number of clients can subscribe to the same session and receive the broadcast recognition results in real time.
- **Asynchronous core**: built on FastAPI and `asyncio`, able to handle a large number of concurrent connections.
- **Modular design**: clean separation of the service layer (`server.py`), the session management layer (`ASRRunner`), and the core processing pipeline (`ASRPipeline`).
## 📂 Project Structure
```
.
├── main.py                      # Application entry point; starts the service with uvicorn
├── docker
│   ├── Dockerfile               # Docker image build file
│   └── docker-compose.yml       # Docker container orchestration
├── docs
│   ├── WEBSOCKET_API.md         # Detailed WebSocket API usage and examples
│   └── SystemArchitecture.md    # System architecture document
├── src
│   ├── server.py                # FastAPI application core; manages lifecycle and global resources
│   ├── runner
│   │   └── ASRRunner.py         # Core session manager; creates and coordinates recognition sessions (SAR)
│   ├── pipeline
│   │   └── ASRpipeline.py       # Synchronous, thread-based speech processing pipeline
│   ├── functor                  # Implementations of the VAD, ASR, SPK, etc. atomic operations
│   ├── websockets
│   │   ├── adapter.py           # WebSocket adapter; handles data format conversion
│   │   ├── endpoint
│   │   │   └── asr_endpoint.py  # WebSocket business-logic endpoint
│   │   └── router.py            # WebSocket routing
│   ├── core
│   │   └── model_loader.py      # Model loader
│   └── utils
│       ├── logger.py            # Logger
│       ├── data_format.py       # Data format conversion
│       └── mock_websocket.py    # Mock WebSocket client
├── tests
│   ├── runner
│   │   └── asr_runner_test.py   # Unit tests for ASRRunner (async)
│   └── websocket
│       └── websocket_asr.py     # End-to-end tests for the WebSocket service
```
## 🚀 Quick Start
### 1. Environment and Dependencies
- Python 3.8+
- Project dependencies are listed in `requirements.txt`.
### 2. Installation
Installing into a virtual environment is recommended. From the project root, run:
```bash
pip install -r requirements.txt
```
### 3. Run the Service
Start the FastAPI service from the main entry point:
```bash
python main.py
```
Once started, the service listens on `http://0.0.0.0:8000`.
## 💡 Usage
The service is exposed over WebSocket. Clients create or join a recognition session via a `session_id` and declare their role (`sender` or `receiver`) with the `mode` parameter.
**For the full API description, URL format, and Python and JavaScript client connection examples, see:**
➡️ **[WEBSOCKET_API.md](./docs/WEBSOCKET_API.md)**
## 🔬 Testing
The project provides two ways to verify its functionality.
### 1. End-to-End WebSocket Test
This test simulates one `sender` and one `receiver` and exercises a full recognition session.
**Prerequisite**: make sure the FastAPI service is running.
```bash
python main.py
```
Then, from the project root, run:
```bash
python tests/websocket/websocket_asr.py
```
### 2. ASRRunner Unit Test
This test targets the core `ASRRunner` component and verifies its asynchronous logic.
Run it with:
```bash
python test_main.py
```
## 🐳 Deploying with Docker
### 1. Build the Docker Image
From the project root, run:
```bash
docker build -t asr-server:x.x.x -f docker/Dockerfile .
```
### 2. Run the Docker Container
```bash
docker run -d -p 11096:11096 --name asr-server -v ~/.cache/modelscope:/root/.cache/modelscope asr-server:x.x.x
```
### Environment Variables
- `SPEAKERS_URL`: URL of the speaker database API.
  Example/default: SPEAKERS_URL="http://172.23.30.120:11200/api/v1/speakers/"
  This URL is the backend API endpoint that returns all speaker records in the database.
- `LOG_LEVEL`: log level.
  Example/default: LOG_LEVEL="INFO"
  `LOG_LEVEL` is the project-wide log level.
- `LOG_LEVEL_ASR_SERVER`: log level for the ASR service.
  Example/default: LOG_LEVEL_ASR_SERVER="INFO"
  `LOG_LEVEL_ASR_SERVER` takes precedence over `LOG_LEVEL`.
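As a concrete example of the precedence rules above, the sketch below overrides both log levels at container start time; the host name in `SPEAKERS_URL` is a placeholder for your own backend:
```bash
# Hypothetical invocation: raise the ASR service's verbosity while keeping
# the rest of the project at INFO, and point at a different speakers API.
docker run -d -p 11096:11096 --name asr-server \
  -e LOG_LEVEL="INFO" \
  -e LOG_LEVEL_ASR_SERVER="DEBUG" \
  -e SPEAKERS_URL="http://your-backend:11200/api/v1/speakers/" \
  -v ~/.cache/modelscope:/root/.cache/modelscope \
  asr-server:x.x.x
```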

13
data/denoise.py Normal file

@ -0,0 +1,13 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build an acoustic noise suppression pipeline (FRCRN, 16 kHz)
ans = pipeline(
    Tasks.acoustic_noise_suppression,
    model='iic/speech_frcrn_ans_cirm_16k')
wav_file = 'speaker_wav/HaiaoDuan.wav'
output_path = 'denoise_output/HaiaoDuan_denoise_output.wav'
# Denoise the input file and write the result to output_path
result = ans(
    wav_file,
    output_path=output_path)

Binary file not shown.

Binary file not shown.

35
data/record.py Normal file

@ -0,0 +1,35 @@
"""
本地录音保存为wav格式存储在data/speaker_wav目录下
"""
import pyaudio
import wave
def record_audio(filename, duration=5, format=pyaudio.paInt16, channels=1, rate=16000):
"""
本地录音保存为wav格式存储在data/speaker_wav目录下
"""
p = pyaudio.PyAudio()
stream = p.open(format=format, channels=channels, rate=rate, input=True, frames_per_buffer=1024)
print("按下回车键开始录音...")
input()
frames = []
for i in range(0, int(rate / 1024 * duration)):
data = stream.read(1024)
frames.append(data)
print("录音结束")
stream.stop_stream()
stream.close()
p.terminate()
wav_file = wave.open(filename, 'wb')
wav_file.setnchannels(channels)
wav_file.setsampwidth(p.get_sample_size(format))
wav_file.setframerate(rate)
wav_file.writeframes(b''.join(frames))
wav_file.close()
if __name__ == "__main__":
record_audio(
"data/speaker_wav/test.wav",
duration=5
)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

400
data/speakers.json Normal file

@ -0,0 +1,400 @@
[
{
"speaker_id": "137facd6-a1c9-47b9-b87d-16e6f62e07bd",
"speaker_name": "ZiyangZhang",
"wav_path": "/home/lyg/Code/funasr/data/speaker_wav/ZiyangZhang.wav",
"speaker_embs": [
-0.4249887466430664,
-0.12976674735546112,
1.6118208169937134,
1.3348901271820068,
0.1423041820526123,
0.16940945386886597,
-0.042910803109407425,
0.9634712934494019,
0.9677271246910095,
1.1112406253814697,
-2.0086846351623535,
1.729629635810852,
-0.3664000928401947,
2.4323978424072266,
-1.587996244430542,
-1.0803641080856323,
0.08011860400438309,
1.6515964269638062,
-1.1337167024612427,
-0.5088973045349121,
-1.0002555847167969,
0.11426643282175064,
-0.8616334199905396,
-0.006051262840628624,
0.44800689816474915,
0.6659525632858276,
-0.9864538908004761,
2.1259539127349854,
-0.49345871806144714,
-0.14384664595127106,
0.0742349922657013,
0.25577273964881897,
1.0516602993011475,
1.7297064065933228,
-0.44126248359680176,
1.3971654176712036,
0.04305446520447731,
-2.261837959289551,
-0.355578750371933,
-0.8388981819152832,
0.8178591728210449,
0.016942109912633896,
0.8212596774101257,
1.108891248703003,
-0.5182072520256042,
-0.07741295546293259,
0.9407528042793274,
0.026407398283481598,
-0.6210324168205261,
-2.0659642219543457,
0.13895569741725922,
-1.3570973873138428,
2.236407995223999,
-0.29706746339797974,
1.9819035530090332,
1.3580390214920044,
-0.5505754351615906,
0.7189999222755432,
-0.3190038502216339,
1.1075336933135986,
-1.4158482551574707,
0.20138776302337646,
0.8354343175888062,
0.1671304553747177,
-0.56927490234375,
1.057538390159607,
-0.2868591248989105,
0.005044424440711737,
0.49878695607185364,
-0.7493277192115784,
2.4639663696289062,
0.5516767501831055,
-0.2763596177101135,
-0.8769170641899109,
-1.296872615814209,
-0.5233777165412903,
-0.10551001876592636,
-0.5955559611320496,
-0.6046199202537537,
0.22645621001720428,
1.12480890750885,
-0.3678736388683319,
-1.1580262184143066,
-0.3625229299068451,
0.8251489996910095,
0.3464623987674713,
2.261840581893921,
-0.11341957747936249,
-0.6645990610122681,
0.8480257987976074,
-0.47770705819129944,
0.8085628747940063,
-0.26823946833610535,
-0.25040531158447266,
1.0610276460647583,
-0.14239133894443512,
-1.309299349784851,
-1.0987954139709473,
-0.1301683634519577,
-0.05199439451098442,
-0.07838833332061768,
-0.21310138702392578,
0.29347339272499084,
1.0793802738189697,
-1.813226342201233,
-1.1362330913543701,
-0.13013578951358795,
0.6647212505340576,
-0.34312230348587036,
0.5921282172203064,
0.26284533739089966,
0.9369505047798157,
0.1739131063222885,
0.7924790978431702,
0.3412249982357025,
0.16646981239318848,
-0.32468467950820923,
-0.5835385918617249,
0.05923287197947502,
1.191710352897644,
-0.3653518557548523,
-0.8665252923965454,
0.7419591546058655,
-1.7234965562820435,
0.3421083092689514,
-0.24517370760440826,
-0.8724228143692017,
-0.11004912108182907,
-0.10676378011703491,
-1.0688399076461792,
0.4397974908351898,
-0.9902229309082031,
-0.2676651179790497,
1.4346729516983032,
0.34571582078933716,
0.9091840386390686,
0.41458258032798767,
-0.7863419055938721,
0.6952191591262817,
0.8847752809524536,
0.15871241688728333,
-0.10740098357200623,
-0.5305340886116028,
1.0536329746246338,
-1.337695837020874,
0.23358777165412903,
-0.19285082817077637,
-0.5339606404304504,
-0.6768214106559753,
1.6815600395202637,
-0.36710524559020996,
-0.22888287901878357,
-0.2714850902557373,
-0.0895417258143425,
0.3480932116508484,
-0.19148986041545868,
0.44108960032463074,
0.03198949620127678,
-0.3665091097354889,
-0.6040502786636353,
0.37234461307525635,
-0.07462035119533539,
-0.18109525740146637,
-0.19882601499557495,
0.33298638463020325,
0.039957765489816666,
0.6185765266418457,
1.5921381711959839,
0.04164457693696022,
-0.7556226849555969,
-1.0537445545196533,
0.36932048201560974,
-0.2881639897823334,
-1.3762420415878296,
-0.6029151678085327,
-1.3592504262924194,
0.6726564168930054,
0.06349147856235504,
-0.4627697765827179,
1.1113581657409668,
-1.1767970323562622,
0.3900119662284851,
-0.3050364851951599,
-0.2807784676551819,
-0.7237444519996643,
-0.039161279797554016,
0.5845404267311096,
-0.4385261833667755,
-0.3988557755947113,
-1.235430359840393,
-0.648483395576477,
1.084520936012268
]
},
{
"speaker_id": "81e1b806-c76a-468f-98b7-4a63f2996480",
"speaker_name": "HaiaoDuan",
"wav_path": "/home/lyg/Code/funasr/data/speaker_wav/HaiaoDuan.wav",
"speaker_embs": [
-1.3490606546401978,
-0.9654964208602905,
0.6671794652938843,
2.3401081562042236,
-1.374346137046814,
0.24404077231884003,
0.08137784898281097,
0.10915698111057281,
0.8208633065223694,
-1.0312862396240234,
1.721955418586731,
-0.16976028680801392,
-1.0259445905685425,
-0.9134035706520081,
-1.3709611892700195,
-0.6821202635765076,
1.0825326442718506,
1.4931895732879639,
-0.06801076978445053,
-0.5044959187507629,
-1.3154232501983643,
-1.1049765348434448,
0.6122218370437622,
1.1061663627624512,
-0.2288999855518341,
-0.03568289428949356,
-0.9260172247886658,
1.1030527353286743,
-0.7439772486686707,
1.4323620796203613,
0.2221372127532959,
-0.8355774283409119,
0.6758987307548523,
0.8520456552505493,
-0.0186605341732502,
-0.981821596622467,
0.11743613332509995,
-0.3539535701274872,
-0.33924832940101624,
-0.510174036026001,
0.6893219351768494,
-0.10966216027736664,
-1.5873743295669556,
1.7041956186294556,
-0.9844599366188049,
-1.368901252746582,
0.44316115975379944,
-2.406590700149536,
0.9880101680755615,
0.8344699740409851,
0.22896111011505127,
-1.4464795589447021,
2.222980260848999,
-0.22508130967617035,
0.8659772276878357,
0.7801474928855896,
1.824644923210144,
-0.2455991804599762,
-0.06682202965021133,
0.07106778025627136,
-1.8072712421417236,
0.7733234763145447,
0.20490191876888275,
-1.119908094406128,
-1.2623472213745117,
0.34426289796829224,
0.7909225821495056,
0.47128093242645264,
-0.9976771473884583,
-0.6703121662139893,
0.7459381818771362,
1.0664807558059692,
0.659284770488739,
-0.49438077211380005,
0.1974140703678131,
-0.07557231187820435,
-1.324866533279419,
-1.2217090129852295,
-1.0160834789276123,
0.7517350912094116,
0.06301767379045486,
0.8621189594268799,
-1.033493161201477,
-0.18051855266094208,
-0.2633781135082245,
0.5859690308570862,
1.5803791284561157,
-0.7071301341056824,
-0.016185184940695763,
-0.5259001851081848,
-0.6252623796463013,
1.4383807182312012,
0.6068354845046997,
0.39534664154052734,
0.22612401843070984,
-1.541978120803833,
-2.575181484222412,
-0.9924071431159973,
1.9649298191070557,
-1.1940282583236694,
-0.6481325030326843,
-1.5226261615753174,
1.6535273790359497,
0.7740333676338196,
-1.8780876398086548,
0.627184271812439,
1.0915889739990234,
1.694388508796692,
-0.47886598110198975,
-0.04895557090640068,
0.3620351552963257,
0.640113115310669,
-0.4149058163166046,
-0.18083086609840393,
-0.30447620153427124,
0.022528085857629776,
-0.6550383567810059,
-0.3812088668346405,
-0.478842169046402,
0.6615785360336304,
0.49959492683410645,
-0.249789759516716,
1.7448066473007202,
-0.9037050008773804,
-0.7441433668136597,
0.5949154496192932,
-1.1230697631835938,
-0.2552490830421448,
0.4216223657131195,
-0.5870983004570007,
0.7283152937889099,
-0.13834434747695923,
-1.3267407417297363,
1.1050132513046265,
1.731435775756836,
0.3724023103713989,
0.830539882183075,
-1.032881736755371,
0.8204181790351868,
0.05735205113887787,
0.5442802906036377,
-0.7974395751953125,
0.18374553322792053,
-0.17642715573310852,
-0.051413919776678085,
-0.2413552850484848,
-0.43316808342933655,
-0.2594863772392273,
1.5363879203796387,
0.5056991577148438,
-1.3894445896148682,
-1.2057586908340454,
-0.48546579480171204,
-0.2659154236316681,
0.9767322540283203,
-1.97313392162323,
-0.3016327917575836,
-0.6123557686805725,
0.288481205701828,
0.2976057827472687,
0.08243764936923981,
0.6122551560401917,
-0.6019028425216675,
-0.10548368841409683,
-0.016991911455988884,
1.75961172580719,
0.6418831944465637,
0.3137458264827728,
0.25365981459617615,
-0.45389246940612793,
0.238858163356781,
0.2631453275680542,
1.1121031045913696,
-0.9991472363471985,
-0.8904637694358826,
-1.1346020698547363,
-1.1918814182281494,
-1.1205440759658813,
-1.486283779144287,
1.0530670881271362,
-0.583172082901001,
0.26391518115997314,
1.2654175758361816,
-0.8430055975914001,
0.21697403490543365,
-0.30710718035697937,
2.191946506500244,
-0.19980488717556,
-0.5966204404830933,
0.04923265427350998,
-0.8815436959266663,
0.9289136528968811
]
}
]

14
data/speakers.json.backup Normal file

@ -0,0 +1,14 @@
[
{
"speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
"speaker_name": "ZiyangZhang",
"wav_path": "/home/lyg/Code/funasr/data/speaker_wav/ZiyangZhang.wav",
"speaker_embs": ""
},
{
"speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
"speaker_name": "HaiaoDuan",
"wav_path": "/home/lyg/Code/funasr/data/speaker_wav/HaiaoDuan.wav",
"speaker_embs": ""
}
]

@ -1,23 +0,0 @@
version: '3.8'
services:
funasr:
build: .
container_name: funasr-websocket
volumes:
- .:/app
# Mount the local model cache if needed
- ~/.cache/modelscope:/root/.cache/modelscope
ports:
- "10095:10095"
environment:
- PYTHONUNBUFFERED=1
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
restart: unless-stopped
command: python src/server.py --device cuda --ngpu 1

36
docker/Dockerfile Normal file

@ -0,0 +1,36 @@
FROM python:3.10-slim
# Switch the system package mirrors
# USTC mirror for Debian 12 (bookworm)
RUN echo "deb http://mirrors.ustc.edu.cn/debian bookworm main contrib non-free" > /etc/apt/sources.list && \
echo "deb http://mirrors.ustc.edu.cn/debian bookworm-updates main contrib non-free" >> /etc/apt/sources.list && \
echo "deb http://mirrors.ustc.edu.cn/debian-security bookworm-security main" >> /etc/apt/sources.list && \
echo "deb http://mirrors.ustc.edu.cn/debian bookworm-backports main contrib non-free" >> /etc/apt/sources.list
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
libsndfile1 \
ffmpeg \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Install Python dependencies
COPY ../requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# Copy application code
COPY ../src/ /app/src/
COPY ../logs/ /app/logs/
COPY ../main.py /app/
# Set environment variables
ENV SPEAKERS_URL="http://172.23.30.120:11200/api/v1/speakers/" \
LOG_LEVEL="INFO" \
LOG_LEVEL_ASR_SERVER="INFO"
# Expose the WebSocket port
EXPOSE 11096
# Start the service
CMD ["python", "main.py"]

@ -0,0 +1,94 @@
# System Architecture
This project is a high-performance, real-time speech recognition (ASR) WebSocket service built on FunASR and FastAPI. The core architecture is designed to handle real-time streaming audio and to distribute recognition results to multiple clients through a "one sender, many receivers" broadcast mode.
## Core Components
The system consists of the following core components. Each has a single responsibility, and together they cooperate through async I/O and worker threads to process speech efficiently in real time:
1. **WebSocket service (FastAPI)**
   - **Files**: `src/websockets/`, `src/server.py`, `main.py`
   - **Responsibility**: the network entry point of the system, handling WebSocket connections. Built with FastAPI, it provides asynchronous, non-blocking I/O and can manage a large number of concurrent client connections. `asr_endpoint.py` is the core endpoint; it routes each connection to the `ASRRunner` based on the `mode` (sender/receiver) declared by the client.
2. **Session manager (ASRRunner)**
   - **File**: `src/runner/ASRRunner.py`
   - **Responsibility**: the "brain" and coordination center of the whole system. It manages all active recognition sessions (`SenderAndReceiver`, SAR for short).
   - **Session lifecycle**: when a `sender` connects, the `ASRRunner` creates a new SAR instance; when a `receiver` connects, it is added to the specified SAR.
   - **Async bridge**: the `ASRRunner` runs on the main `asyncio` event loop. It asynchronously receives audio from the `sender`'s WebSocket connection and hands it to the synchronous `ASRPipeline` through a thread-safe queue (`queue.Queue`). It also receives final results back from the pipeline and broadcasts them asynchronously to all `receiver`s (see the sketch after this list).
3. **Speech processing pipeline (ASRPipeline)**
   - **File**: `src/pipeline/ASRpipeline.py`
   - **Responsibility**: the engine that actually performs speech processing. Each SAR session owns an independent `ASRPipeline` instance running in its own background thread.
   - **Modular design**: internally the pipeline is composed of several `Functor`s (VAD, ASR, SPK) connected by a series of internal queues, forming a processing chain.
   - **Processing steps**:
     1. **VAD (Voice Activity Detection)**: detects valid speech segments in the audio stream.
     2. **ASR (Automatic Speech Recognition)**: converts speech segments to text.
     3. **SPK (Speaker Recognition)**: identifies the speaker (voiceprint recognition).
     4. **ResultBinder**: merges the ASR text and the SPK speaker result into the final recognition message.
4. **Atomic operation units (Functor)**
   - **Directory**: `src/functor/`
   - **Responsibility**: a `Functor` is the unit that performs one concrete atomic task in the pipeline. Each Functor is an independent class that invokes the underlying FunASR models for VAD, ASR, or SPK. This design keeps the processing flow clear and modular.
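The following minimal sketch illustrates the async-to-thread bridge described for the `ASRRunner` above. It is illustrative only and does not reproduce the actual `ASRRunner` code; names such as `audio_queue`, `process`, and `broadcast` are assumptions:
```python
import asyncio
import queue

audio_queue: "queue.Queue[bytes]" = queue.Queue()  # asyncio side -> pipeline thread

async def receive_audio(websocket):
    """Runs on the event loop: pull audio off the sender's socket, push to the queue."""
    while True:
        data = await websocket.receive_bytes()
        audio_queue.put(data)  # queue.Queue.put is thread-safe

def pipeline_worker(loop: asyncio.AbstractEventLoop, broadcast):
    """Runs in a background thread: consume audio, hand results back to the loop."""
    while True:
        chunk = audio_queue.get()
        result = process(chunk)  # stand-in for the VAD/ASR/SPK chain
        # Schedule the async broadcast on the event loop without blocking it
        asyncio.run_coroutine_threadsafe(broadcast(result), loop)
```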
## Flow Diagram
The sequence diagram below shows how the system handles one complete recognition request, from client connection to delivery of the results. (Participant aliases are used because mermaid identifiers cannot contain spaces.)
```mermaid
sequenceDiagram
    participant S as Sender Client
    participant R as Receiver Client
    participant E as FastAPI WebSocket Endpoint
    participant A as ASRRunner
    participant P as ASRPipeline (Thread)
    participant F as Functors (VAD, ASR, SPK)
    par
        S->>+E: Connect (mode=sender, session_id=S1)
        E->>+A: new_SAR(ws, name="S1")
        A->>A: Create SenderAndReceiver (SAR) instance
        A->>P: Create and run pipeline instance
        P->>F: Initialize Functor threads
        A->>-E: Return success
        E->>-S: Connection established
    and
        R->>+E: Connect (mode=receiver, session_id=S1)
        E->>+A: join_SAR(ws, name="S1")
        A->>A: Add receiver to S1's receiver list
        A->>-E: Return success
        E->>-R: Connection established
    end
    loop Audio streaming and processing
        S->>A: Send audio chunk
        A->>P: (via Queue) pass audio data
        P->>F: (via sub-queues) distribute data
        Note over F: 1. VAD detects speech<br/>2. ASR transcribes text<br/>3. SPK identifies the speaker<br/>4. ResultBinder merges the results
        F->>P: (via callback) return final result
        P->>A: (via thread-safe callback) deliver result
    end
    A->>A: Result received; prepare broadcast
    A-->>S: Broadcast recognition result
    A-->>R: Broadcast recognition result
    par
        S->>E: Close connection
        E->>A: (exception raised)
        A->>P: Send stop signal
        P->>F: Stop Functor threads
        Note right of A: Clean up session resources
    and
        R->>E: Close connection
        E->>A: (exception raised)
        A->>A: Remove from receiver list
    end
```
## Architectural Strengths
- **High concurrency, low latency**: with `asyncio` and WebSocket, the network layer handles a large number of concurrent connections. Audio processing runs in separate threads, so CPU-bound work never blocks the event loop, keeping latency low.
- **Decoupling and modularity**: the `WebSocket Endpoint`, `ASRRunner`, and `ASRPipeline` have clear, decoupled responsibilities, and the `Functor` design makes it easy to add or modify processing steps.
- **Robustness**: each recognition session (SAR) is isolated, so one session failing does not affect the others; graceful shutdown logic ensures resources are released correctly.
- **Scalability**: the "one sender, many receivers" broadcast mode scales easily to large numbers of `receiver`s and fits many real-time application scenarios.

262
docs/WEBSOCKET_API.md Normal file

@ -0,0 +1,262 @@
# FunASR-FastAPI WebSocket API Documentation
This document describes how to connect to and use the WebSocket interface of the FunASR-FastAPI real-time speech recognition service.
## 1. Endpoint
The WebSocket endpoint URL has the following format:
```
ws://<your_server_host>:8000/ws/asr/{session_id}?mode=<client_mode>
```
### Parameters
- **`{session_id}`** (path parameter, `str`, **required**):
  Uniquely identifies a recognition session (for example, a meeting or a livestream). All clients belonging to the same session must use the same `session_id`.
- **`mode`** (query parameter, `str`, **required**):
  Defines the client's role.
  - `sender`: the audio source. A session should have exactly one `sender`. This client streams real-time audio to the server.
  - `receiver`: a result consumer. A session may have any number of `receiver`s. This client only receives the recognition results broadcast by the server and sends no audio.
## 2. Data Formats
### 2.1 Outgoing Data (Sender -> Server)
- **Audio format**: the `sender` must send raw **PCM audio data**.
- **Sample rate**: 16000 Hz
- **Bit depth**: 32-bit (floating point)
- **Channels**: mono
- **Transport format**: must be sent as **binary (bytes)**.
- **End-of-stream signal**: when the audio stream ends, the `sender` should send the **text message** `"close"` to tell the server to close the session.
### 2.2 Incoming Data (Server -> Receiver)
The server broadcasts recognition results as **JSON text** to every `receiver` in the session (and to the `sender` itself). An example of the JSON structure:
```json
{
  "asr": "你好,世界。",
  "spk": {
    "speaker_id": "uuid-of-the-speaker",
    "speaker_name": "SpeakerName",
    "score": 0.98
  }
}
```
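Since the server expects 32-bit float PCM while many recordings (including the `data/record.py` script in this repository) produce 16-bit integer PCM, a conversion step is often needed before sending. A minimal sketch, assuming `numpy` is available:
```python
import numpy as np

def pcm16_to_float32_bytes(pcm16: bytes) -> bytes:
    """Convert 16-bit signed PCM bytes to 32-bit float PCM bytes in [-1.0, 1.0]."""
    samples = np.frombuffer(pcm16, dtype=np.int16)
    return (samples.astype(np.float32) / 32768.0).tobytes()
```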
## 3. Python Client Examples
Requires the `websockets` library: `pip install websockets`
### 3.1 Python Sender Example (streaming a local audio file)
This script reads a WAV file and streams its contents to the server.
```python
import asyncio
import websockets
import soundfile as sf
import uuid

# --- Configuration ---
SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=sender"
SESSION_ID = str(uuid.uuid4())  # Generate a unique ID for this session
AUDIO_FILE = "tests/XT_ZZY_denoise.wav"  # Replace with your own audio file path
CHUNK_SIZE = 1600  # Frames per chunk: 100 ms at 16 kHz (sf.read counts frames, not bytes)

async def send_audio():
    """Connect to the server and stream the audio file."""
    uri = SERVER_URI.format(session_id=SESSION_ID)
    print(f"Connecting as sender to: {uri}")
    async with websockets.connect(uri) as websocket:
        try:
            # Open the audio file
            with sf.SoundFile(AUDIO_FILE, 'r') as f:
                assert f.samplerate == 16000, "Audio file must be sampled at 16 kHz"
                assert f.channels == 1, "Audio file must be mono"
                print("Sending audio...")
                while True:
                    # Read as float32
                    data = f.read(CHUNK_SIZE, dtype='float32')
                    if len(data) == 0:
                        break
                    # Convert the numpy array to a raw byte stream
                    await websocket.send(data.tobytes())
                    await asyncio.sleep(0.1)  # Simulate real-time audio input
            print("Audio finished; sending end-of-stream signal.")
            await websocket.send("close")
            # Wait for the server's final acknowledgement or connection close
            response = await websocket.recv()
            print(f"Final server response: {response}")
        except websockets.exceptions.ConnectionClosed as e:
            print(f"Connection closed: {e}")
        except Exception as e:
            print(f"Error: {e}")

if __name__ == "__main__":
    asyncio.run(send_audio())
```
### 3.2 Python Receiver Example (receiving recognition results)
This script connects to the given session and keeps printing the results broadcast by the server.
```python
import asyncio
import websockets

# --- Configuration ---
# !!! Must use the same SESSION_ID as the sender !!!
SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=receiver"
SESSION_ID = "paste your sender's session ID here"

async def receive_results():
    """Connect to the server and receive recognition results."""
    if "paste your sender's session ID" in SESSION_ID:
        print("Error: set a valid SESSION_ID first")
        return
    uri = SERVER_URI.format(session_id=SESSION_ID)
    print(f"Connecting as receiver to: {uri}")
    async with websockets.connect(uri) as websocket:
        try:
            print("Waiting for recognition results...")
            while True:
                message = await websocket.recv()
                print(f"Result: {message}")
        except websockets.exceptions.ConnectionClosed as e:
            print(f"Connection closed: {e.code} {e.reason}")

if __name__ == "__main__":
    asyncio.run(receive_results())
```
## 4. JavaScript Client Example (browser)
This example shows how to capture microphone audio in a web page and send it as a `sender`.
```html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>WebSocket ASR Client</title>
</head>
<body>
    <h1>FunASR WebSocket Client (Sender)</h1>
    <p><strong>Session ID:</strong> <span id="sessionId"></span></p>
    <button id="startButton">Start recognition</button>
    <button id="stopButton" disabled>Stop recognition</button>
    <h2>Results:</h2>
    <div id="results"></div>
    <script>
        const startButton = document.getElementById('startButton');
        const stopButton = document.getElementById('stopButton');
        const resultsDiv = document.getElementById('results');
        const sessionIdSpan = document.getElementById('sessionId');
        let websocket;
        let audioContext;
        let scriptProcessor;
        let mediaStream;
        const SAMPLE_RATE = 16000;
        // createScriptProcessor requires a power-of-two buffer size (256-16384);
        // 2048 frames is 128 ms at 16 kHz, close to the intended ~100 ms chunks.
        const BUFFER_SIZE = 2048;
        // Generate a simple UUID
        function generateUUID() {
            return ([1e7]+-1e3+-4e3+-8e3+-1e11).replace(/[018]/g, c =>
                (c ^ crypto.getRandomValues(new Uint8Array(1))[0] & 15 >> c / 4).toString(16)
            );
        }
        async function startRecording() {
            const sessionId = generateUUID();
            sessionIdSpan.textContent = sessionId;
            const wsUrl = `ws://${window.location.host}/ws/asr/${sessionId}?mode=sender`;
            websocket = new WebSocket(wsUrl);
            websocket.onopen = () => {
                console.log("WebSocket connection opened");
                startButton.disabled = true;
                stopButton.disabled = false;
                resultsDiv.innerHTML = '';
            };
            websocket.onmessage = (event) => {
                console.log("Message received:", event.data);
                const result = JSON.parse(event.data);
                const asrText = result.asr || '';
                const spkName = result.spk ? result.spk.speaker_name : 'Unknown';
                resultsDiv.innerHTML += `<p><strong>${spkName}:</strong> ${asrText}</p>`;
            };
            websocket.onclose = () => {
                console.log("WebSocket connection closed");
                stopRecording();
            };
            websocket.onerror = (error) => {
                console.error("WebSocket error:", error);
                alert("WebSocket connection failed!");
                stopRecording();
            };
            try {
                mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true, video: false });
                audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });
                const source = audioContext.createMediaStreamSource(mediaStream);
                scriptProcessor = audioContext.createScriptProcessor(BUFFER_SIZE, 1, 1);
                scriptProcessor.onaudioprocess = (e) => {
                    if (websocket && websocket.readyState === WebSocket.OPEN) {
                        const inputData = e.inputBuffer.getChannelData(0);
                        // The server expects float32 data; inputData is already a
                        // Float32Array, so its underlying buffer can be sent directly.
                        websocket.send(inputData.buffer);
                    }
                };
                source.connect(scriptProcessor);
                scriptProcessor.connect(audioContext.destination);
            } catch (err) {
                console.error("Could not access the microphone:", err);
                alert("Microphone permission denied!");
                if (websocket) websocket.close();
            }
        }
        function stopRecording() {
            if (websocket && websocket.readyState === WebSocket.OPEN) {
                websocket.send("close");
            }
            if (mediaStream) {
                mediaStream.getTracks().forEach(track => track.stop());
            }
            if (scriptProcessor) {
                scriptProcessor.disconnect();
            }
            if (audioContext) {
                audioContext.close();
            }
            startButton.disabled = false;
            stopButton.disabled = true;
        }
        startButton.addEventListener('click', startRecording);
        stopButton.addEventListener('click', stopRecording);
    </script>
</body>
</html>
```

23
main.py Normal file

@ -0,0 +1,23 @@
import uvicorn
import os
from src.server import app
from src.utils.logger import get_module_logger, setup_root_logger
from datetime import datetime
time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Log level precedence: LOG_LEVEL_ASR_SERVER > LOG_LEVEL > INFO (default)
logger_level = os.getenv("LOG_LEVEL", "INFO")
logger_level = os.getenv("LOG_LEVEL_ASR_SERVER", logger_level)
setup_root_logger(level=logger_level, log_file=f"logs/main_{time}.log")
logger = get_module_logger(__name__)

if __name__ == "__main__":
    logger.info("Starting the FunASR FastAPI server...")
    # In production, a more robust ASGI setup such as Gunicorn with Uvicorn workers is recommended,
    # e.g.: gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=11096,
        log_level="info"
    )

20
pyproject.toml Normal file

@ -0,0 +1,20 @@
[project]
name = "asr-server"
version = "0.1.0"
requires-python = ">=3.9.21"
dependencies = [
"addict==2.4.0",
"datasets==2.21.0",
"fastapi==0.115.14",
"funasr==1.2.6",
"numpy==2.0.1",
"pillow==11.1.0",
"pydantic==2.11.3",
"pydub>=0.25.1",
"simplejson==3.20.1",
"sortedcontainers==2.4.0",
"torch==2.3.1",
"torchaudio==2.3.1",
"uvicorn==0.35.0",
"websockets==12.0",
]
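The pinned `requirements.txt` shown below is generated from this `pyproject.toml`; per the header comment in that file, it can be regenerated with:
```bash
uv pip compile pyproject.toml -o requirements.txt
```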

@ -1,11 +1,335 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements.txt
addict==2.4.0
# via asr-server (pyproject.toml)
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.12.13
# via
# datasets
# fsspec
aiosignal==1.4.0
# via aiohttp
aliyun-python-sdk-core==2.16.0
# via
# aliyun-python-sdk-kms
# oss2
aliyun-python-sdk-kms==2.16.5
# via oss2
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.9.0
# via starlette
async-timeout==5.0.1
# via aiohttp
attrs==25.3.0
# via aiohttp
audioread==3.0.1
# via librosa
certifi==2025.6.15
# via requests
cffi==1.17.1
# via
# cryptography
# soundfile
charset-normalizer==3.4.2
# via requests
click==8.2.1
# via uvicorn
crcmod==1.7
# via oss2
cryptography==45.0.5
# via aliyun-python-sdk-core
datasets==2.21.0
# via asr-server (pyproject.toml)
decorator==5.2.1
# via librosa
dill==0.3.8
# via
# datasets
# multiprocess
editdistance==0.8.1
# via funasr
exceptiongroup==1.3.0
# via anyio
fastapi==0.115.14
# via asr-server (pyproject.toml)
filelock==3.18.0
# via
# datasets
# huggingface-hub
# torch
# triton
frozenlist==1.7.0
# via
# aiohttp
# aiosignal
fsspec==2024.6.1
# via
# datasets
# huggingface-hub
# torch
funasr==1.2.6
# via asr-server (pyproject.toml)
h11==0.16.0
# via uvicorn
hf-xet==1.1.5
# via huggingface-hub
huggingface-hub==0.33.2
# via datasets
hydra-core==1.3.2
# via funasr
idna==3.10
# via
# anyio
# requests
# yarl
jaconv==0.4.0
# via funasr
jamo==0.4.1
# via funasr
jieba==0.42.1
# via funasr
jinja2==3.1.6
# via torch
jmespath==0.10.0
# via aliyun-python-sdk-core
joblib==1.5.1
# via
# librosa
# pynndescent
# scikit-learn
kaldiio==2.18.1
# via funasr
lazy-loader==0.4
# via librosa
librosa==0.11.0
# via funasr
llvmlite==0.44.0
# via
# numba
# pynndescent
markupsafe==3.0.2
# via jinja2
modelscope==1.27.1
# via funasr
mpmath==1.3.0
# via sympy
msgpack==1.1.1
# via librosa
multidict==6.6.3
# via
# aiohttp
# yarl
multiprocess==0.70.16
# via datasets
networkx==3.4.2
# via torch
numba==0.61.2
# via
# librosa
# pynndescent
# umap-learn
numpy==2.0.1
# via
# asr-server (pyproject.toml)
# datasets
# kaldiio
# librosa
# numba
# pandas
# pytorch-wpe
# scikit-learn
# scipy
# soundfile
# soxr
# tensorboardx
# torch-complex
# umap-learn
nvidia-cublas-cu12==12.1.3.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.9.86
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
omegaconf==2.3.0
# via hydra-core
oss2==2.19.1
# via funasr
packaging==25.0
# via
# datasets
# huggingface-hub
# hydra-core
# lazy-loader
# pooch
# tensorboardx
# torch-complex
pandas==2.3.0
# via datasets
pillow==11.1.0
# via asr-server (pyproject.toml)
platformdirs==4.3.8
# via pooch
pooch==1.8.2
# via librosa
propcache==0.3.2
# via
# aiohttp
# yarl
protobuf==6.31.1
# via tensorboardx
pyarrow==20.0.0
# via datasets
pycparser==2.22
# via cffi
pycryptodome==3.23.0
# via oss2
pydantic==2.11.3
# via
# asr-server (pyproject.toml)
# fastapi
pydantic-core==2.33.1
# via pydantic
pydub==0.25.1
# via asr-server (pyproject.toml)
pynndescent==0.5.13
# via umap-learn
python-dateutil==2.9.0.post0
# via pandas
pytorch-wpe==0.0.1
# via funasr
pytz==2025.2
# via pandas
pyyaml==6.0.2
# via
# datasets
# funasr
# huggingface-hub
# omegaconf
requests==2.32.4
# via
# datasets
# funasr
# huggingface-hub
# modelscope
# oss2
# pooch
scikit-learn==1.7.0
# via
# librosa
# pynndescent
# umap-learn
scipy==1.15.3
# via
# funasr
# librosa
# pynndescent
# scikit-learn
# umap-learn
sentencepiece==0.2.0
# via funasr
setuptools==80.9.0
# via modelscope
simplejson==3.20.1
# via asr-server (pyproject.toml)
six==1.17.0
# via
# oss2
# python-dateutil
sniffio==1.3.1
# via anyio
sortedcontainers==2.4.0
# via asr-server (pyproject.toml)
soundfile==0.13.1
# via
# funasr
# librosa
soxr==0.5.0.post1
# via librosa
starlette==0.46.2
# via fastapi
sympy==1.14.0
# via torch
tensorboardx==2.6.4
# via funasr
threadpoolctl==3.6.0
# via scikit-learn
torch==2.3.1
# via
# asr-server (pyproject.toml)
# torchaudio
torch-complex==0.4.4
# via funasr
torchaudio==2.3.1
# via asr-server (pyproject.toml)
tqdm==4.67.1
# via
# datasets
# funasr
# huggingface-hub
# modelscope
# umap-learn
triton==2.3.1
# via torch
typing-extensions==4.14.1
# via
# aiosignal
# anyio
# exceptiongroup
# fastapi
# huggingface-hub
# librosa
# multidict
# pydantic
# pydantic-core
# torch
# typing-inspection
# uvicorn
typing-inspection==0.4.1
# via pydantic
tzdata==2025.2
# via pandas
umap-learn==0.5.9.post2
# via funasr
urllib3==2.5.0
# via
# modelscope
# requests
uvicorn==0.35.0
# via asr-server (pyproject.toml)
websockets==12.0
# via asr-server (pyproject.toml)
xxhash==3.5.0
# via datasets
yarl==1.20.1
# via aiohttp

194
src/audio_chunk.py Normal file

@ -0,0 +1,194 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
音频数据块管理类 - 用于存储和处理16KHz音频数据
"""
from typing import List, Optional, Dict
from src.models import AudioBinary_Config, AudioBinary_data_list
# 配置日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__, level="INFO")
class AudioBinary:
    """
    Audio data storage unit
    Stores binary data.
    Serves Slice objects: provides data and interfaces to them.
    self._audio_config: AudioBinary_Config -- audio configuration
    self._binary_data_list: AudioBinary_data_list -- list of audio data blocks
    self._slice_listener: List[callable] -- slice listeners
    AudioBinary_Config: Dict -- audio configuration
    AudioBinary_data_list: List[bytes] -- list of audio data blocks
    """

    def __init__(self, *args):
        """
        Initialize the audio data block
        Args:
            *args: variable arguments; the first may be a dict or an AudioBinary_Config
        """
        # Audio configuration
        self._audio_config = AudioBinary_Config()
        # Audio segments
        self._binary_data_list: AudioBinary_data_list = AudioBinary_data_list()
        # Slice listeners
        self._slice_listener: List = []
        # Note: the original checked isinstance(args, ...), but *args is always
        # a tuple, so the check must be made against the first argument instead.
        if len(args) == 0 or args[0] is None:
            pass
        elif isinstance(args[0], dict):
            self._audio_config = AudioBinary_Config.AudioBinary_Config_from_dict(args[0])
        elif isinstance(args[0], AudioBinary_Config):
            self._audio_config = args[0]
        else:
            raise ValueError("Invalid argument type")
    def add_slice_listener(self, slice_listener: callable) -> None:
        """
        Register a slice listener
        Args:
            slice_listener: callable -- slice listener
        """
        self._slice_listener.append(slice_listener)

    def __add__(self, other: bytes):
        """
        Overload of the "+" operator
        Usage:
            audio_binary = audio_binary + bytes
        Equivalent to add_binary_data, but chainable for convenience.
        Args:
            other: bytes -- audio data block
        """
        self._binary_data_list.append(other)
        return self

    def __iadd__(self, other: bytes):
        """
        Overload of the "+=" operator
        Usage:
            audio_binary += bytes
        Equivalent to add_binary_data, but chainable for convenience.
        Args:
            other: bytes -- audio data block
        """
        self._binary_data_list.append(other)
        return self

    def add_binary_data(self, binary_data: bytes):
        """
        Append an audio data block
        Args:
            binary_data: bytes -- audio data block
        """
        self._binary_data_list.append(binary_data)

    def rewrite_binary_data(self, target_index: int, binary_data: bytes):
        """
        Rewrite an audio data block in place
        Args:
            target_index: int -- target index
            binary_data: bytes -- audio data block
        """
        self._binary_data_list.rewrite(target_index, binary_data)

    def get_binary_data(
        self,
        start: int = 0,
        end: Optional[int] = None,
    ) -> Optional[bytes]:
        """
        Get the audio data blocks in the given index range
        Args:
            start: start index
            end: end index
        Returns:
            List[bytes]: audio data blocks
        """
        if start >= len(self._binary_data_list):
            return None
        if end is None:
            end = start + 1
        end = min(end, len(self._binary_data_list))
        return self._binary_data_list[start:end]
class AudioChunk:
    """
    Audio chunk manager
    Manages two kinds of objects: AudioBinary and Slice.
    AudioBinary stores the raw bytes internally;
    a Slice is a view over an AudioBinary used by external interfaces.
    This class only mediates between AudioBinary and Functors; it owns no other logic.
    """

    _instance: Optional["AudioChunk"] = None

    def __new__(cls, *args, **kwargs):
        """
        Singleton pattern
        """
        if cls._instance is None:
            # object.__new__ takes no extra arguments, so do not forward *args/**kwargs
            cls._instance = super(AudioChunk, cls).__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """
        Initialize the AudioChunk instance
        """
        self._audio_binary_list: Dict[str, AudioBinary] = {}
        self._slice_listener: List[callable] = []

    def get_audio_binary(
        self,
        binary_name: Optional[str] = None,
        audio_config: Optional[AudioBinary_Config] = None,
    ) -> AudioBinary:
        """
        Get (or lazily create) an audio binary by name
        Args:
            binary_name: str -- name of the audio binary
        Returns:
            AudioBinary: the audio binary
        """
        if binary_name is None:
            binary_name = "default"
        if binary_name not in self._audio_binary_list:
            self._audio_binary_list[binary_name] = AudioBinary(audio_config)
        return self._audio_binary_list[binary_name]
    @staticmethod
    def _time2size(time_ms: int, audio_config: AudioBinary_Config) -> int:
        """
        Convert a duration (ms) to a data size (bytes)
        Args:
            time_ms: int -- duration (ms)
            audio_config: AudioBinary_Config -- audio configuration
        Returns:
            int: data size (bytes)
        """
        # ms -> bytes: duration (ms) * sample rate (Hz) * channels * bytes per sample / 1000
        # (sample_width is treated as bytes per sample here)
        time_s = time_ms / 1000
        bytes_per_sample = audio_config.sample_width * audio_config.channel
        return int(time_s * audio_config.sample_rate * bytes_per_sample)

    @staticmethod
    def _size2time(size: int, audio_config: AudioBinary_Config) -> int:
        """
        Convert a data size (bytes) to a duration (ms)
        Args:
            size: int -- data size (bytes)
            audio_config: AudioBinary_Config -- audio configuration
        Returns:
            int: duration (ms)
        """
        # bytes -> ms: size (bytes) * 1000 / (sample rate (Hz) * channels * bytes per sample)
        bytes_per_sample = audio_config.sample_width * audio_config.channel
        time_ms = size * 1000 // (audio_config.sample_rate * bytes_per_sample)
        return time_ms
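A minimal usage sketch of the classes above; the AudioBinary_Config fields mirror the values used by DefaultConfig in src/config.py, and the byte payload is a placeholder:
```python
from src.audio_chunk import AudioChunk
from src.models import AudioBinary_Config

config = AudioBinary_Config(
    chunk_size=200, chunk_stride=1600,
    sample_rate=16000, sample_width=16, channels=1,
)
chunk = AudioChunk()                                  # singleton
binary = chunk.get_audio_binary("session-1", config)  # lazily created by name
binary += b"\x00\x00" * 1600                          # append a block via the overloaded +=
print(binary.get_binary_data(0))                      # read the first block back
```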

@ -1,196 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WebSocket客户端示例 - 用于测试语音识别服务
"""
import asyncio
import json
import websockets
import argparse
import numpy as np
import wave
import os
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description="FunASR WebSocket客户端")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="服务器主机地址"
)
parser.add_argument(
"--port",
type=int,
default=10095,
help="服务器端口"
)
parser.add_argument(
"--audio_file",
type=str,
required=True,
help="要识别的音频文件路径"
)
parser.add_argument(
"--mode",
type=str,
default="2pass",
choices=["2pass", "online", "offline"],
help="识别模式: 2pass(默认), online, offline"
)
parser.add_argument(
"--chunk_size",
type=str,
default="5,10",
help="分块大小, 格式为'encoder_size,decoder_size'"
)
parser.add_argument(
"--use_ssl",
action="store_true",
help="是否使用SSL连接"
)
return parser.parse_args()
async def send_audio(websocket, audio_file, mode, chunk_size):
"""
发送音频文件到服务器进行识别
参数:
websocket: WebSocket连接
audio_file: 音频文件路径
mode: 识别模式
chunk_size: 分块大小
"""
# 打开并读取WAV文件
with wave.open(audio_file, "rb") as wav_file:
params = wav_file.getparams()
frames = wav_file.readframes(wav_file.getnframes())
# 音频文件信息
print(f"音频文件: {os.path.basename(audio_file)}")
print(f"采样率: {params.framerate}Hz, 通道数: {params.nchannels}")
print(f"采样位深: {params.sampwidth * 8}位, 总帧数: {params.nframes}")
# 设置配置参数
config = {
"mode": mode,
"chunk_size": chunk_size,
"wav_name": os.path.basename(audio_file),
"is_speaking": True
}
# 发送配置
await websocket.send(json.dumps(config))
# 模拟实时发送音频数据
chunk_size_bytes = 3200 # 每次发送100ms的16kHz音频
total_chunks = len(frames) // chunk_size_bytes
print(f"开始发送音频数据,共 {total_chunks} 个数据块...")
try:
for i in range(0, len(frames), chunk_size_bytes):
chunk = frames[i:i+chunk_size_bytes]
await websocket.send(chunk)
# 模拟实时每100ms发送一次
await asyncio.sleep(0.1)
# 显示进度
if (i // chunk_size_bytes) % 10 == 0:
print(f"已发送 {i // chunk_size_bytes}/{total_chunks} 数据块")
# 发送结束信号
await websocket.send(json.dumps({"is_speaking": False}))
print("音频数据发送完成")
except Exception as e:
print(f"发送音频时出错: {e}")
async def receive_results(websocket):
"""
接收并显示识别结果
参数:
websocket: WebSocket连接
"""
online_text = ""
offline_text = ""
try:
async for message in websocket:
# 解析服务器返回的JSON消息
result = json.loads(message)
mode = result.get("mode", "")
text = result.get("text", "")
is_final = result.get("is_final", False)
# 根据模式更新文本
if "online" in mode:
online_text = text
print(f"\r[在线识别] {online_text}", end="", flush=True)
elif "offline" in mode:
offline_text = text
print(f"\n[离线识别] {offline_text}")
# 如果是最终结果,打印完整信息
if is_final and offline_text:
print("\n最终识别结果:")
print(f"[离线识别] {offline_text}")
return
except Exception as e:
print(f"接收结果时出错: {e}")
async def main():
"""主函数"""
args = parse_args()
# WebSocket URI
protocol = "wss" if args.use_ssl else "ws"
uri = f"{protocol}://{args.host}:{args.port}"
print(f"连接到服务器: {uri}")
try:
# 创建WebSocket连接
async with websockets.connect(
uri,
subprotocols=["binary"]
) as websocket:
print("连接成功")
# 创建两个任务: 发送音频和接收结果
send_task = asyncio.create_task(
send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
)
receive_task = asyncio.create_task(
receive_results(websocket)
)
# 等待任务完成
await asyncio.gather(send_task, receive_task)
except Exception as e:
print(f"连接服务器失败: {e}")
if __name__ == "__main__":
# 运行主函数
asyncio.run(main())

@ -1,11 +1,25 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Default configuration (DefaultConfig)
- audio_config: audio configuration
Configuration module - handles command-line arguments and settings
"""
import argparse
from src.models import AudioBinary_Config

class DefaultConfig:
    """
    Default configuration
    """
    audio_config = AudioBinary_Config(
        chunk_size=200,
        chunk_stride=1600,
        sample_rate=16000,
        sample_width=16,
        channels=1,
    )

def parse_args():
    """
@ -21,41 +35,23 @@ def parse_args():
        "--host",
        type=str,
        default="0.0.0.0",
        help="Server host address, e.g. localhost, 0.0.0.0",
    )
    parser.add_argument("--port", type=int, default=10095, help="WebSocket server port")
    # SSL configuration
    parser.add_argument("--certfile", type=str, default="", help="SSL certificate file path")
    parser.add_argument("--keyfile", type=str, default="", help="SSL key file path")
    # ASR model configuration
    parser.add_argument(
        "--asr_model",
        type=str,
        default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        help="Offline ASR model (from ModelScope)",
    )
    parser.add_argument(
        "--asr_model_revision", type=str, default="v2.0.4", help="Offline ASR model revision"
    )
    # Online ASR model configuration
@ -63,13 +59,13 @@ def parse_args():
        "--asr_model_online",
        type=str,
        default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
        help="Online ASR model (from ModelScope)",
    )
    parser.add_argument(
        "--asr_model_online_revision",
        type=str,
        default="v2.0.4",
        help="Online ASR model revision",
    )
    # VAD model configuration
@ -77,13 +73,10 @@ def parse_args():
        "--vad_model",
        type=str,
        default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        help="VAD (voice activity detection) model (from ModelScope)",
    )
    parser.add_argument(
        "--vad_model_revision", type=str, default="v2.0.4", help="VAD model revision"
    )
    # Punctuation model configuration
@ -91,34 +84,18 @@ def parse_args():
        "--punc_model",
        type=str,
        default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
        help="Punctuation model (from ModelScope)",
    )
    parser.add_argument(
        "--punc_model_revision", type=str, default="v2.0.4", help="Punctuation model revision"
    )
    # Hardware configuration
    parser.add_argument("--ngpu", type=int, default=1, help="Number of GPUs (0 means CPU only)")
    parser.add_argument(
        "--device", type=str, default="cuda", help="Device type (cuda or cpu)"
    )
    parser.add_argument("--ncpu", type=int, default=4, help="Number of CPU cores")
    return parser.parse_args()

3
src/core/__init__.py Normal file

@ -0,0 +1,3 @@
from .model_loader import ModelLoader
__all__ = ["ModelLoader"]

143
src/core/model_loader.py Normal file

@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
try:
# 导入FunASR库
from funasr import AutoModel
except ImportError as exc:
raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc
from modelscope.pipelines import pipeline
# 日志模块
from src.utils import get_module_logger
logger = get_module_logger(__name__)
# 单例模式
class ModelLoader:
"""
ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中
一般的, 可以直接call ModelLoader()来获取加载的模型
也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型
"""
_instance = None
def __new__(cls, *args, **kwargs):
"""
单例模式
"""
if cls._instance is None:
cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self, args=None):
"""
初始化ModelLoader实例
"""
self.models = {}
logger.debug("初始化ModelLoader")
if args is not None:
self.__call__(args)
def __call__(self, args=None):
"""
调用ModelLoader实例时, 如果模型字典为空, 则加载模型
"""
# 如果模型字典为空, 则加载模型
if self.models == {} or self.models is None:
if args.asr_model is not None:
self.models = self.load_models(args)
# 直接调用等于调用self.models
return self.models
def _load_model(self, input_model_args: dict, model_type: str):
"""
加载单个模型
参数:
model_args: 模型加载字典
model_type: 模型类型, 用于确定使用哪个模型参数
返回:
AutoModel: 加载的模型实例
"""
# 默认配置
default_config = {
"model": None,
"model_revision": None,
"ngpu": 0,
"ncpu": 1,
"device": "cpu",
"disable_pbar": True,
"disable_log": True,
"disable_update": True,
}
# 从args中获取配置, 如果存在则覆盖默认值
model_args = default_config.copy()
for key, value in default_config.items():
if key in ["model", "model_revision"]:
# 特殊处理model和model_revision, 因为它们需要model_type前缀
if key == "model":
value = input_model_args.get(f"{model_type}_model", None)
else:
value = input_model_args.get(f"{model_type}_model_revision", None)
else:
value = input_model_args.get(key, None)
if value is not None:
logger.debug("替换%s模型参数: %s = %s", model_type, key, value)
model_args[key] = value
# 验证必要参数
if not model_args["model"]:
raise ValueError(f"未指定{model_type}模型路径")
try:
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
logger.debug("正在加载%s模型: %s", model_type, model_args["model"])
model = AutoModel(**model_args)
return model
except Exception as e:
logger.error("加载%s模型失败: %s", model_type, str(e))
raise
# def _load_pipeline(self, input_model_args: dict, model_type: str):
# """
# 加载pipeline
# """
# default_pipeline = pipeline(
# task='speaker-verification',
# model='iic/speech_campplus_sv_zh-cn_16k-common',
# model_revision='v1.0.0'
# )
def load_models(self, args):
"""
加载所有需要的模型
参数:
args: 命令行参数, 包含模型配置
返回:
dict: 包含所有加载的模型的字典
"""
logger.info("ModelLoader加载模型")
# 初始化模型字典
self.models = {}
# 加载离线ASR模型
# 检查对应键是否存在
model_list = ["asr", "asr_online", "vad", "punc"]
for model_name in model_list:
name_model = f"{model_name}_model"
name_model_revision = f"{model_name}_model_revision"
if name_model in args:
logger.debug("加载%s模型", model_name)
self.models[model_name] = self._load_model(args, model_name)
pipeline_list = ["spk"]
for pipeline_name in pipeline_list:
if pipeline_name == "spk":
self.models[pipeline_name] = pipeline(
task='speaker-verification',
model='iic/speech_campplus_sv_zh-cn_16k-common',
model_revision='v1.0.0'
)
return self.models
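A minimal sketch of how the pieces above are typically wired together, reusing parse_args() from src/config.py; the exact call site in this repository may differ:
```python
from src.config import parse_args
from src.core import ModelLoader

args = parse_args()          # carries asr_model, vad_model, punc_model, ...
loader = ModelLoader(args)   # singleton; loads the models on first construction
models = loader()            # later calls just return the cached dict
vad_model = models["vad"]
```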

4
src/functor/__init__.py Normal file

@ -0,0 +1,4 @@
from .vad_functor import VADFunctor
from .base import FunctorFactory
__all__ = ["VADFunctor", "FunctorFactory"]

164
src/functor/asr_functor.py Normal file

@ -0,0 +1,164 @@
"""
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
from queue import Queue, Empty
import threading
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ASRFunctor(BaseFunctor):
"""
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config
否则无法run()启动线程
运行中, 使用reset_cache()重置缓存, 准备下次任务
使用stop()停止线程, 但需要等待input_queue为空
"""
def __init__(self) -> None:
super().__init__()
# 资源与配置
self._model: dict = {} # 模型
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
# flag
self._is_running: bool = False
self._stop_event: bool = False
# 线程资源
self._thread: threading.Thread = None
# 状态锁
self._status_lock: threading.Lock = threading.Lock()
# 缓存
self._hotwords: List[str] = []
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
pass
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None:
"""
回调函数
"""
text = result[0]["text"].replace(" ", "")
for callback in self._callback:
callback(text)
def _process(self, data: VAD_Functor_result) -> None:
"""
处理数据
"""
binary_data = data.audiobinary_data.binary_data
result = self._model["asr"].generate(
input=binary_data,
chunk_size=self._audio_config.chunk_size,
hotwords=self._hotwords,
)
self._do_callback(result)
def _run(self) -> None:
"""
线程运行逻辑
"""
with self._status_lock:
self._is_running = True
self._stop_event = False
# 运行逻辑
while self._is_running:
try:
data = self._input_queue.get(True, timeout=1)
if data is None:
break
logger.debug("[ASRFunctor]获取到的数据length: %s", len(data))
self._process(data)
self._input_queue.task_done()
# 当队列为空时, 间隔1s检测是否进入停止事件。
except Empty:
if self._stop_event:
break
continue
# 其他异常
except Exception as e:
logger.error("ASRFunctor运行时发生错误: %s", e)
raise e
def run(self) -> threading.Thread:
"""
启动线程
Returns:
threading.Thread: 返回已运行线程实例
"""
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
预检查
"""
if self._model is None:
raise ValueError("模型未设置")
if self._audio_config is None:
raise ValueError("音频配置未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self) -> bool:
"""
停止线程
"""
with self._status_lock:
self._stop_event = True
self._thread.join()
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()
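A minimal sketch of the ASRFunctor lifecycle described above; the models dict is assumed to come from ModelLoader, and the queue items from the upstream VAD functor:
```python
from queue import Queue

from src.functor.asr_functor import ASRFunctor
from src.config import DefaultConfig

asr = ASRFunctor()
asr.set_model({"asr": models["asr"]})              # models assumed from ModelLoader
asr.set_audio_config(DefaultConfig.audio_config)
in_queue: Queue = Queue()                          # fed by the VAD functor
asr.set_input_queue(in_queue)
asr.add_callback(lambda text: print("ASR:", text))
thread = asr.run()                                 # starts the worker thread
# ... feed VAD_Functor_result items into in_queue ...
asr.stop()                                         # drains, then joins the thread
```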

208
src/functor/base.py Normal file

@ -0,0 +1,208 @@
"""
Functor基础模块
该模块定义了Functor的基类,所有功能性的类(如VADPUNCASRSPK等)都应继承自这个基类
基类提供了数据处理的基本框架,包括:
- 回调函数管理
- 模型配置管理
- 线程运行控制
主要类:
BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类
"""
from abc import ABC, abstractmethod
from typing import Callable, List, Dict
from queue import Queue
import threading
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class BaseFunctor(ABC):
"""
Functor抽象类
该抽象类规定了所有的Functor类必须实现run()方法启动自身线程
属性:
_callback (Callable): 处理完成后的回调函数
_model (dict): 存储模型相关的配置和实例
"""
def __init__(self):
"""
初始化函数器
参数:
callback (Callable): 处理完成后的回调函数
model (dict): 模型相关的配置和实例
"""
self._callback: List[Callable] = []
self._model: dict = {}
# flag
self._is_running: bool = False
self._stop_event: bool = False
# 状态锁
self._status_lock: threading.Lock = threading.Lock()
# 线程资源
self._thread: threading.Thread = None
def add_callback(self, callback: Callable):
"""
添加回调函数
参数:
callback (Callable): 新的回调函数
"""
self._callback.append(callback)
def set_model(self, model: dict):
"""
设置模型配置
参数:
model (dict): 新的模型配置
"""
self._model = model
def set_input_queue(self, queue: Queue):
"""
设置输入队列
参数:
queue (Queue): 新的输入队列
"""
self._input_queue = queue
@abstractmethod
def _run(self):
"""
线程运行逻辑
返回:
当达到条件时触发callback
"""
@abstractmethod
def run(self):
"""
启动_run方法线程
返回:
线程实例
"""
@abstractmethod
def _pre_check(self):
"""
预检查
返回:
预检查结果
"""
@abstractmethod
def stop(self):
"""
停止线程
返回:
停止结果
"""
class FunctorFactory:
"""
Functor工厂类
该工厂类负责创建和配置Functor实例
主要方法:
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
创建并配置Functor实例
"""
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建VAD Functor实例
"""
from src.functor.vad_functor import VADFunctor
audio_config = config["audio_config"]
model = {"vad": models["vad"]}
vad_functor = VADFunctor()
vad_functor.set_audio_config(audio_config)
vad_functor.set_model(model)
return vad_functor
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建ASR Functor实例
"""
from src.functor.asr_functor import ASRFunctor
audio_config = config["audio_config"]
model = {"asr": models["asr"]}
asr_functor = ASRFunctor()
asr_functor.set_audio_config(audio_config)
asr_functor.set_model(model)
return asr_functor
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建SPK Functor实例
"""
from src.functor.spk_functor import SPKFunctor
logger.debug(f"创建spk functor[开始]")
audio_config = config["audio_config"]
# model = {"spk": models["spk"]}
spk_functor = SPKFunctor(sv_pipeline=models["spk"])
spk_functor.set_audio_config(audio_config)
# spk_functor.set_model(model)
spk_functor.bake()
logger.debug(f"创建spk functor[完成]")
return spk_functor
def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建ResultBinder Functor实例
"""
from src.functor.resultbinder_functor import ResultBinderFunctor
resultbinder_functor = ResultBinderFunctor()
return resultbinder_functor
factory_dict: Dict[str, Callable] = {
"vad": _make_vadfunctor,
"asr": _make_asrfunctor,
"spk": _make_spkfunctor,
"resultbinder": _make_resultbinderfunctor,
}
@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
"""
创建并配置Functor实例
参数:
funtor_name (str): Functor名称
config (dict): 配置信息
models (dict): 模型信息
返回:
BaseFunctor: 创建的Functor实例
"""
if functor_name in cls.factory_dict:
return cls.factory_dict[functor_name](config=config, models=models)
else:
raise ValueError(f"不支持的Functor类型: {functor_name}")

122
src/functor/readme.md Normal file

@ -0,0 +1,122 @@
# 对于Functor的解释
## Functor 文件夹作用
Functor文件夹用于存放所有功能性的类包括VAD、PUNC、ASR、SPK等。
## Functor 类的定义
所有类应继承于**基类**`BaseFunctor`
为了方便使用,我们对于**基类**的定义如下:
1. 函数内部使用的变量以单下划线开头,基类中包含:
* _model: Dict 存放模型相关的配置和实例
* _input_queue: Queue 监听的输入消息队列
* _thread: Threading.Thread 运行的线程实例
* _callback: List[Callable] 回调函数列表
* _is_running: bool 线程运行状态标志
* _stop_event: bool 停止事件标志
* _status_lock: threading.Lock 状态锁,用于线程同步
2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`
3. 基类定义的核心方法:
* `add_callback(callback: Callable)`: 添加结果处理的回调函数
* `set_model(model: dict)`: 设置模型配置和实例
* `set_input_queue(queue: Queue)`: 设置输入数据队列
* `run()`: 启动处理线程(抽象方法)
* `stop()`: 停止处理线程(抽象方法)
* `_run()`: 线程运行的具体逻辑(抽象方法)
* `_pre_check()`: 运行前的预检查(抽象方法)
## 派生类实现要求
1. 必须实现的抽象方法:
* `_pre_check()`:
- 检查必要的配置是否完整(如模型、队列等)
- 检查运行环境是否满足要求
- 返回检查结果
* `_run()`:
- 实现具体的数据处理逻辑
- 从 _input_queue 获取输入数据
- 使用 _model 进行处理
- 通过 _callback 返回处理结果
* `run()`:
- 调用 _pre_check() 进行预检查
- 创建并启动处理线程
- 设置相关状态标志
* `stop()`:
- 安全停止处理线程
- 清理资源
- 重置状态标志
2. 建议实现的方法:
* `__str__`: 返回当前实例的状态信息
* 错误处理方法:处理运行过程中的异常情况
## 使用示例
```python
import threading
from queue import Queue, Empty

from src.functor.base import BaseFunctor
from src.utils.logger import get_module_logger

logger = get_module_logger(__name__)


class MyFunctor(BaseFunctor):
    def _pre_check(self):
        if not self._model or not self._input_queue:
            return False
        return True

    def _run(self):
        while not self._stop_event:
            try:
                data = self._input_queue.get(timeout=1.0)
                result = self._model['my_model'].process(data)
                for callback in self._callback:
                    callback(result)
            except Empty:  # queue模块的Empty异常, 而非Queue类的属性
                continue
            except Exception as e:
                logger.error(f"处理错误: {e}")

    def run(self):
        if not self._pre_check():
            raise RuntimeError("预检查失败")
        with self._status_lock:
            if self._is_running:
                return
            self._is_running = True
            self._stop_event = False
        self._thread = threading.Thread(target=self._run)
        self._thread.start()

    def stop(self):
        with self._status_lock:
            if not self._is_running:
                return
            self._stop_event = True
        if self._thread:
            self._thread.join()
        with self._status_lock:
            self._is_running = False
```
## 注意事项
1. 线程安全:
* 使用 _status_lock 保护状态变更
* 注意共享资源的访问控制
2. 错误处理:
* 在 _run() 中妥善处理异常
* 提供详细的错误日志
3. 资源管理:
* 确保在 stop() 中正确清理资源
* 避免资源泄露
4. 回调函数:
* 回调函数应该是非阻塞的(参见文末示例)
* 处理回调函数抛出的异常
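
One way to keep callbacks non-blocking, as required above, is to make the callback a plain queue push and do the slow work in a separate consumer thread (a sketch, not project code):

```python
import threading
from queue import Queue

result_queue: Queue = Queue()

def enqueue_result(result) -> None:
    # O(1) and non-blocking: never do I/O inside the functor's worker thread
    result_queue.put(result)

def consumer() -> None:
    while True:
        result = result_queue.get()
        # slow work (logging, network, DB writes) happens here instead
        print(result)

threading.Thread(target=consumer, daemon=True).start()
# my_functor.add_callback(enqueue_result)
```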

165
src/functor/resultbinder_functor.py Normal file

@ -0,0 +1,165 @@
"""
ResultBinderFunctor
负责聚合结果将所有input_queue中的结果进行聚合并进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List, Dict, Any
from queue import Queue, Empty
import threading
import time
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ResultBinderFunctor(BaseFunctor):
"""
ResultBinderFunctor
负责聚合结果将所有input_queue中的结果进行聚合并进行callback
"""
def __init__(self) -> None:
super().__init__()
# 资源与配置
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Dict[str, Queue] = {} # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
pass
def add_input_queue(self,
name: str,
queue: Queue,
) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue[name] = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
resultbinder_functor 不应设置模型
"""
logger.warning("ResultBinderFunctor不应设置模型")
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
resultbinder_functor 不应设置音频配置
"""
logger.warning("ResultBinderFunctor不应设置音频配置")
self._audio_config = audio_config
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None:
"""
回调函数
"""
for callback in self._callback:
callback(result)
def _process(self, data: Dict[str, Any]) -> None:
"""
处理数据
{
"is_final": false,
"mode": "2pass-offline",
"text": "等一下我回一下ok你看这里就有了",
"wav_name": "h5",
"speaker_id":
}
"""
logger.debug("ResultBinderFunctor处理数据: %s", data)
# 将data中的result进行聚合
# 此步暂时无意义,预留
results = {
"is_final": False,
"mode": "2pass-offline",
"text": data["asr"],
"wav_name": "h5",
"speaker_id": data["spk"]["speaker_id"]
}
# for name, result in data.items():
# results[name] = result
self._do_callback(results)
def _run(self) -> None:
"""
线程运行逻辑
"""
with self._status_lock:
self._is_running = True
self._stop_event = False
# 运行逻辑
while self._is_running:
try:
# 若有队列为空则等待0.1s
for name, queue in self._input_queue.items():
if queue.empty():
time.sleep(0.1)
raise Empty
data = {}
for name, queue in self._input_queue.items():
data[name] = queue.get(True, timeout=1)
queue.task_done()
self._process(data)
# 当队列为空时, 检测是否进入停止事件。
except Empty:
if self._stop_event:
break
continue
# 其他异常
except Exception as e:
logger.error("SpkFunctor运行时发生错误: %s", e)
raise e
def run(self) -> threading.Thread:
"""
启动线程
Returns:
threading.Thread: 返回已运行线程实例
"""
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
预检查
"""
if self._model is None:
raise ValueError("模型未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self) -> bool:
"""
停止线程
"""
with self._status_lock:
self._stop_event = True
self._thread.join()
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

367
src/functor/spk_functor.py Normal file

@ -0,0 +1,367 @@
"""
SpkFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result, SpeakerCreate
from typing import Callable, List, Dict
from queue import Queue, Empty
import json
import torch
import threading
import numpy
import os
import requests
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
from modelscope.pipelines import pipeline
# sv_pipeline = pipeline(
# task='speaker-verification',
# model='iic/speech_campplus_sv_zh-cn_16k-common',
# model_revision='v1.0.0'
# )
# speaker1_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_a_cn_16k.wav'
# speaker1_b_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_b_cn_16k.wav'
# speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
# # 相同说话人语音
# result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
# print(result)
# # 不同说话人语音
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
# print(result)
# # 可以自定义得分阈值来进行识别,阈值越高,判定为同一人的条件越严格
# result = sv_pipeline([speaker1_a_wav, speaker1_a_wav], thr=0.6)
# print(result)
# # 可以传入output_emb参数输出结果中就会包含提取到的说话人embedding
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
# print(result['embs'], result['outputs'])
# # 可以传入save_dir参数提取到的说话人embedding会存储在save_dir目录中
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
class SPKFunctor(BaseFunctor):
"""
SPKFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config
否则无法run()启动线程
运行中, 使用reset_cache()重置缓存, 准备下次任务
使用stop()停止线程, 但需要等待input_queue为空
"""
class speaker_verify:
"""
说话人验证
"""
def __init__(self, *args, **kwargs):
self._spk_data: List[SpeakerCreate] = []
def add_speaker(self, speaker: SpeakerCreate) -> None:
self._spk_data.append(speaker)
# logger.debug("添加说话人: %s", speaker)
def verify(self, emb: numpy.ndarray) -> Dict:
# 将输入的numpy embedding转换为tensor
input_emb_tensor = torch.from_numpy(emb).unsqueeze(0)
results = []
for speaker in self._spk_data:
if not speaker.speaker_embs:
continue
# 将列表中存储的embedding转换为tensor
speaker_emb_tensor = torch.tensor(speaker.speaker_embs).unsqueeze(0)
# 计算余弦相似度
score = torch.nn.functional.cosine_similarity(input_emb_tensor, speaker_emb_tensor).item()
result = {
"score": score,
"speaker_id": speaker.speaker_id,
"speaker_name": speaker.speaker_name,
"speaker_description": speaker.speaker_description
}
results.append(result)
if not results:
return {
"speaker_id": "unknown",
"speaker_name": "Unknown",
"speaker_description": "No registered speakers to verify against.",
"score": 0.0,
"results": []
}
results.sort(key=lambda x: x['score'], reverse=True)
best_match = results[0]
return {
"speaker_id": best_match['speaker_id'],
"speaker_name": best_match['speaker_name'],
"speaker_description": best_match['speaker_description'],
"score": best_match['score'],
"results": results
}
def __init__(self, sv_pipeline: pipeline) -> None:
super().__init__()
# 资源与配置
self._spk_verify = self.speaker_verify()
self._sv_pipeline = sv_pipeline
self._model: dict = {} # 模型
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
# logger.debug("加载本地说话人数据")
# self.load_spk_data_local()
# import inspect
# func_name = inspect.currentframe().f_code.co_name
# logger.info(f"{func_name} 加载数据库说话人数据")
self._speakers_url = os.getenv("SPEAKERS_URL", None)
def bake(self, *args, **kwargs) -> None:
"""
烘焙
"""
if self._speakers_url is None:
logger.error("未提供说话人数据库API的URL")
return
self.load_spk_database_api(
url=kwargs.get("speakers_url", self._speakers_url)
)
def load_spk_database_api(
self,
url: str = None,
) -> None:
"""
加载api远程数据库说话人数据
"""
# 网络请求后端数据库拿到所有的说话人数据
try:
url = url if url is not None else self._speakers_url
if not url:
logger.error("未提供说话人数据库API的URL")
return
logger.debug("加载API说话人数据库: url: %s", url)
response = requests.get(url)
response.raise_for_status()
data = response.json()
items = data.get("items", [])
logger.debug("加载API说话人个数: %s", len(items))
for spk in items:
logger.debug("加载API说话人数据: %s %s", spk.get('speaker_name'), spk.get('speaker_id'))
# 兼容API返回的字段与本地字段
spk_dict = {
"speaker_name": spk.get("speaker_name", ""),
"speaker_id": spk.get("speaker_id", ""),
"speaker_description": spk.get("speaker_description", ""),
"avatar": spk.get("avatar", ""),
"wav_path": "", # API未提供本地wav路径
"speaker_embs": spk.get("audio_feature_vector", []),
}
# 如果API没有embs但有音频样本可以尝试提取
if (not spk_dict["speaker_embs"] or spk_dict["speaker_embs"] is None) and spk.get("audio_sample"):
try:
import base64
import io
import numpy as np
from pydub import AudioSegment
audio_bytes = base64.b64decode(spk["audio_sample"])
# 使用 pydub 处理音频,它能兼容 wav, mp3 等多种格式 (需要 ffmpeg)
audio_segment = AudioSegment.from_file(io.BytesIO(audio_bytes))
# 将音频转换为模型期望的格式,通常是 16k 采样率、单声道
audio_segment = audio_segment.set_frame_rate(16000)
audio_segment = audio_segment.set_channels(1)
# 转换为 float32 的 numpy 数组并归一化
wav_data = np.array(audio_segment.get_array_of_samples()).astype(np.float32)
wav_data /= 32768.0 # 16-bit signed audio is in range [-32768, 32767]
spk_dict["speaker_embs"] = self._sv_pipeline([wav_data], output_emb=True)['embs'][0]
logger.debug("API音频样本转换为embs: %s", spk_dict["speaker_name"])
except Exception:
logger.error("API音频样本 '%s' 转换为embs失败", spk_dict.get("speaker_name", "Unknown"), exc_info=True)
spk_dict["speaker_embs"] = []
# 转为numpy后加入
try:
import numpy
spk_dict["speaker_embs"] = numpy.array(spk_dict["speaker_embs"])
self._spk_verify.add_speaker(SpeakerCreate(**spk_dict))
except Exception as e:
logger.error("添加API说话人到本地失败: %s", e)
except Exception as e:
logger.error("请求说话人数据库API失败: %s", e)
def load_spk_data_local(
self,
spk_data_path: str = 'data/speakers.json',
) -> None:
"""
加载本地说话人数据
"""
with open(spk_data_path, 'r') as f:
spk_data = json.load(f)
for i, spk in enumerate(spk_data):
logger.debug("加载本地说话人数据: %s %s", spk['speaker_name'], spk['speaker_id'])
if spk['speaker_embs'] == "" and spk['wav_path'] != "":
logger.debug("尝试转换本地wav为embs: %s", spk['wav_path'])
try:
# 读取数据为numpy数组
import soundfile as sf
import numpy as np
wav_data, sr = sf.read(spk['wav_path'], dtype='int16')
# 确保是单通道
if wav_data.ndim > 1:
wav_data = wav_data[:, 0]
# 转为numpy数组后送入pipeline
spk['speaker_embs'] = self._sv_pipeline([wav_data], output_emb=True)['embs'][0]
logger.debug("转换本地wav为embs: length=%s type=%s", len(spk['speaker_embs']), type(spk['speaker_embs']))
except Exception as e:
logger.error("转换本地wav为embs失败: %s", e)
else:
pass
# logger.debug("加载本地说话人数据: %s %s", spk['speaker_name'], spk['speaker_id'])
# 将spk的speaker_embs转换为numpy
spk['speaker_embs'] = numpy.array(spk['speaker_embs'])
self._spk_verify.add_speaker(SpeakerCreate(**spk))
spk['speaker_embs'] = spk['speaker_embs'].tolist()
spk_data[i] = spk
# 保存更新后的数据
try:
with open(spk_data_path, 'w') as f:
json.dump(spk_data, f, indent=4)
except Exception as e:
logger.error("保存更新后的数据失败: %s", e)
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
pass
def set_speakers_url(self, url: str) -> None:
"""
设置说话人数据库API的URL
"""
self._speakers_url = url
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config
logger.debug("SpkFunctor设置音频配置: %s", self._audio_config)
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None:
"""
回调函数
"""
for callback in self._callback:
callback(result)
def _process(self, data: VAD_Functor_result) -> None:
"""
处理数据
"""
binary_data = data.audiobinary_data.binary_data
# result = self._model["spk"].generate(
# input=binary_data,
# chunk_size=self._audio_config.chunk_size,
# )
sv_result = self._sv_pipeline([binary_data], output_emb=True)
embs = sv_result['embs'][0]
result = self._spk_verify.verify(embs)
self._do_callback(result)
def _run(self) -> None:
"""
线程运行逻辑
"""
with self._status_lock:
self._is_running = True
self._stop_event = False
# 运行逻辑
while self._is_running:
try:
data = self._input_queue.get(True, timeout=1)
if data is None:
break
logger.debug("[SPKFunctor]获取到的数据length: %s", len(data))
self._process(data)
self._input_queue.task_done()
# 当队列为空时, 间隔1s检测是否进入停止事件。
except Empty:
if self._stop_event:
break
continue
# 其他异常
except Exception as e:
logger.error("SpkFunctor运行时发生错误: %s", e)
raise e
def run(self) -> threading.Thread:
"""
启动线程
Returns:
threading.Thread: 返回已运行线程实例
"""
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
预检查
"""
if self._model is None:
raise ValueError("模型未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self) -> bool:
"""
停止线程
"""
with self._status_lock:
self._stop_event = True
self._thread.join()
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()
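The verification score used by `speaker_verify.verify()` above is plain cosine similarity between the incoming embedding and each registered one. A standalone check of the same computation (the 192-dim size matches typical CAM++ speaker embeddings, but is only an assumption here):

```python
import numpy as np
import torch

# Two random embeddings standing in for pipeline outputs (assumed 192-dim).
emb_a = np.random.randn(192).astype(np.float32)
emb_b = np.random.randn(192).astype(np.float32)

score = torch.nn.functional.cosine_similarity(
    torch.from_numpy(emb_a).unsqueeze(0),
    torch.from_numpy(emb_b).unsqueeze(0),
).item()
print(f"cosine similarity: {score:.3f}")  # in [-1, 1]; higher = more similar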

276
src/functor/vad_functor.py Normal file

@ -0,0 +1,276 @@
"""
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
"""
import threading
from queue import Empty, Queue
from typing import List, Any, Callable
import numpy
import time
from datetime import datetime
from src.models import (
VAD_Functor_result,
AudioBinary_Config,
AudioBinary_data_list,
)
from src.functor.base import BaseFunctor
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class VADFunctor(BaseFunctor):
"""
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
否则无法run()启动线程
运行中, 使用reset_cache()重置缓存, 准备下次任务
使用stop()停止线程, 但需要等待input_queue为空
"""
def __init__(self) -> None:
super().__init__()
# 资源与配置
self._model: dict = {} # 模型
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
# flag
# 此处用到两个标志位(_is_running/_stop_event)和一把状态锁, 都是为了安全终止_run线程, 考虑后续优化
self._is_running: bool = False
self._stop_event: bool = False
# 线程资源
self._thread: threading.Thread = None
# 状态锁
self._status_lock: threading.Lock = threading.Lock()
# 缓存
self._audio_cache: numpy.ndarray = None
self._audio_cache_preindex: int = 0
self._model_cache: dict = {}
self._cache_result_list = []
self._audiobinary_cache = None
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
self._audio_cache = None
self._audio_cache_preindex = 0
self._model_cache = {}
self._cache_result_list = []
self._audiobinary_cache = None
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config
logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
def set_audio_binary_data_list(
self, audio_binary_data_list: AudioBinary_data_list
) -> None:
"""
设置音频数据列表, 为Class AudioBinary_data_list类型
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
"""
self._audio_binary_data_list = audio_binary_data_list
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[List[int]]) -> None:
"""
回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
输入:
result: List[[start,end]] 处理所得VAD端点
其中若start==-1, 则表示前无端点, 若end==-1, 则表示后无端点
当处理得到一个完成片段时, 存入AudioBinary中, 并向队列中添加AudioBinary_Slice
输出:
None
"""
# 持久化缓存结果队列
for pair in result:
[start, end] = pair
# 若无前端点, 则向缓存队列中合并
if start == -1:
self._cache_result_list[-1][1] = end
else:
self._cache_result_list.append(pair)
logger.debug(f"VADFunctor结果: {self._cache_result_list}")
while len(self._cache_result_list) > 0 and self._cache_result_list[0][1] != -1:
logger.debug(f"VADFunctor结果: {self._cache_result_list}")
logger.debug(f"VADFunctor list[0][1]: {self._cache_result_list[0][1]}")
# 创建VAD片段
# 计算开始帧
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
start_frame -= self._audio_cache_preindex
# 计算结束帧
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
end_frame -= self._audio_cache_preindex
# 计算开始时间
timestamp = time.time()
format_time = datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
logger.debug(f"{format_time}创建VAD片段: {start_frame} - {end_frame}")
vad_result = VAD_Functor_result.create_from_push_data(
audiobinary_data_list=self._audio_binary_data_list,
data=self._audiobinary_cache[start_frame:end_frame],
start_time=self._cache_result_list[0][0],
end_time=self._cache_result_list[0][1],
)
logger.debug(f"{format_time}创建VAD片段成功: {vad_result}")
self._audio_cache_preindex += end_frame
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
for callback in self._callback:
callback(vad_result)
self._cache_result_list.pop(0)
def _predeal_data(self, data: Any) -> None:
"""
预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
"""
if self._audio_cache is None:
self._audio_cache = data
else:
# 拼接音频数据
if isinstance(self._audio_cache, numpy.ndarray):
self._audio_cache = numpy.concatenate((self._audio_cache, data))
elif isinstance(self._audio_cache, list):
self._audio_cache.append(data)
if self._audiobinary_cache is None:
self._audiobinary_cache = data
else:
# 拼接音频数据
if isinstance(self._audiobinary_cache, numpy.ndarray):
self._audiobinary_cache = numpy.concatenate(
(self._audiobinary_cache, data)
)
elif isinstance(self._audiobinary_cache, list):
self._audiobinary_cache.append(data)
def _process(self, data: Any):
"""
处理数据
使用model进行生成, 并使用_do_callback进行回调
"""
if data is None:
result = self._model["vad"].generate(
input=self._audio_cache,
cache=self._model_cache,
chunk_size=self._audio_config.chunk_size,
is_final=True,
)
self._do_callback(result[0]["value"])
return
self._predeal_data(data)
if len(self._audio_cache) >= self._audio_config.chunk_stride:
result = self._model["vad"].generate(
input=self._audio_cache,
cache=self._model_cache,
chunk_size=self._audio_config.chunk_size,
max_end_silence_time=300,
is_final=False,
)
if len(result[0]["value"]) > 0:
self._do_callback(result[0]["value"])
logger.debug(f"VADFunctor结果: {result[0]['value']}")
self._audio_cache = None
def _run(self):
"""
线程运行逻辑
监听输入队列, 当有数据时, 处理数据
当输入队列为空时, 间隔1s检测是否进入停止事件
"""
# 刷新运行状态
with self._status_lock:
self._is_running = True
self._stop_event = False
# 运行逻辑
while self._is_running:
try:
data = self._input_queue.get(True, timeout=1)
# logger.debug("[VADFunctor]获取到的数据length: %s", len(data))
self._process(data)
self._input_queue.task_done()
if data is None:
break
# 当队列为空时, 间隔1s检测是否进入停止事件。
except Empty:
if self._stop_event:
break
continue
# 其他异常
except Exception as e:
logger.error("VADFunctor运行时发生错误: %s", e)
raise e
def run(self):
"""
启动 _run 线程, 并返回线程对象
"""
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
检测硬性资源是否都已设置
"""
if self._model is None:
raise ValueError("模型未设置")
if self._audio_config is None:
raise ValueError("音频配置未设置")
if self._audio_binary_data_list is None:
raise ValueError("音频数据列表未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self):
"""
停止线程
通过设置_stop_event为True, 来在input_queue.get()循环为空时退出
"""
with self._status_lock:
self._stop_event = True
self._thread.join()
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

176
src/logic_trager.py Normal file

@ -0,0 +1,176 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
逻辑触发器类 - 用于处理音频数据并触发相应的处理逻辑
"""
from src.utils.logger import get_module_logger
from typing import Any, Dict, Type, Callable
# 配置日志
logger = get_module_logger(__name__, level="INFO")
class AutoAfterMeta(type):
"""
自动调用__after__函数的元类
实现单例模式
"""
_instances: Dict[Type, Any] = {} # 存储单例实例
def __new__(cls, name, bases, attrs):
# 遍历所有属性
for attr_name, attr_value in attrs.items():
# 如果是可调用对象且名称不以双下划线开头
if callable(attr_value) and not attr_name.startswith("__"):
# 获取原函数
original_func = attr_value
# 创建包装函数
def make_wrapper(func):
def wrapper(self, *args, **kwargs):
# 执行原函数
result = func(self, *args, **kwargs)
# 构建__after__函数名
after_func_name = f"__after__{func.__name__}"
# 检查是否存在对应的__after__函数
# 注意: 双下划线前缀会触发Python的名称改写(name mangling),
# 钩子方法实际以 _<类名>__after__xxx 存储, 因此两种属性名都要尝试
after_func = getattr(self, after_func_name, None) or getattr(
self, f"_{type(self).__name__}{after_func_name}", None
)
if callable(after_func):
try:
# 调用__after__函数
after_func()
except Exception as e:
logger.error(f"调用{after_func_name}时出错: {e}")
return result
return wrapper
# 替换原函数
attrs[attr_name] = make_wrapper(original_func)
# 创建类
new_class = super().__new__(cls, name, bases, attrs)
return new_class
def __call__(cls, *args, **kwargs):
"""
重写__call__方法实现单例模式
当类被调用时即创建实例时执行
"""
if cls not in cls._instances:
# 如果实例不存在,创建新实例
cls._instances[cls] = super().__call__(*args, **kwargs)
logger.info(f"创建{cls.__name__}的新实例")
else:
logger.debug(f"返回{cls.__name__}的现有实例")
return cls._instances[cls]
"""
整体识别的处理逻辑
1.压入二进制音频信息
2.不断检测VAD
3.当检测到完整VAD时,将VAD的音频信息压入音频块,并清除对应二进制信息
4.对音频块进行语音转文字offline,时间戳预测,说话人识别
5.将识别结果整合压入结果队列
6.结果队列被压入时调用回调函数
1->2 __after__push_binary_data 外部压入二进制信息
2,3->4 __after__push_audio_chunk 内部压入音频块
4->5 push_result_queue 压入结果队列
5->6 __after__push_result_queue 调用回调函数
"""
from src.functor import VAD
from src.models import AudioBinary_Config
from src.models import AudioBinary_Chunk
from typing import List
class LogicTrager(metaclass=AutoAfterMeta):
"""逻辑触发器类"""
def __init__(
self,
audio_chunk_max_size: int = 1024 * 1024 * 10,
audio_config: AudioBinary_Config = None,
result_callback: Callable = None,
models: Dict[str, Any] = None,
):
"""初始化"""
# 存储音频块
self._audio_chunk: List[AudioBinary_Chunk] = []
# 存储二进制数据
self._audio_chunk_binary = b""
self._audio_chunk_max_size = audio_chunk_max_size
# 音频参数
self._audio_config = (
audio_config if audio_config is not None else AudioBinary_Config()
)
# 结果队列
self._result_queue = []
# 聚合结果回调函数
self._aggregate_result_callback = result_callback
# 组件
self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
self._vad.set_callback(self.push_audio_chunk)
logger.info("初始化LogicTrager")
def push_binary_data(self, chunk: bytes) -> None:
"""
压入音频块至VAD模块
参数:
chunk: 音频数据块
"""
# print("LogicTrager push_binary_data", len(chunk))
self._vad.push_binary_data(chunk)
# __after__push_binary_data 由 AutoAfterMeta 包装器在方法返回后自动调用
def __after__push_binary_data(self) -> None:
"""
添加音频块后处理
"""
# print("LogicTrager __after__push_binary_data")
self._vad.process_vad_result()
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
"""
音频处理
"""
logger.info(
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
chunk.start_time, chunk.end_time, len(chunk.chunk)
)
)
self._audio_chunk.append(chunk)
def __after__push_audio_chunk(self) -> None:
"""
压入音频块后处理
"""
pass
def push_result_queue(self, result: Dict[str, Any]) -> None:
"""
压入结果队列
"""
self._result_queue.append(result)
def __after__push_result_queue(self) -> None:
"""
压入结果队列后处理
"""
logger.info("FINISH Result=")
pass
def __call__(self):
"""调用函数"""
pass
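A minimal illustration of the `AutoAfterMeta` contract (hypothetical class, relying on the name-mangling-aware hook lookup in the wrapper): calling `push()` automatically triggers `__after__push` once the method returns, and the singleton `__call__` returns the same instance on repeated construction.

```python
class Counter(metaclass=AutoAfterMeta):
    def __init__(self):
        self.total = 0

    def push(self, n: int) -> None:
        self.total += n

    def __after__push(self) -> None:
        # invoked automatically by the metaclass wrapper after push()
        print(f"total is now {self.total}")

c = Counter()
c.push(3)              # prints "total is now 3"
assert Counter() is c  # singleton: the same instance is returned
```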

@ -1,79 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
def load_models(args):
"""
加载所有需要的模型
参数:
args: 命令行参数包含模型配置
返回:
dict: 包含所有加载的模型的字典
"""
try:
# 导入FunASR库
from funasr import AutoModel
except ImportError:
raise ImportError("未找到funasr库请先安装: pip install funasr")
# 初始化模型字典
models = {}
# 1. 加载离线ASR模型
print(f"正在加载ASR离线模型: {args.asr_model}")
models["asr"] = AutoModel(
model=args.asr_model,
model_revision=args.asr_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# 2. 加载在线ASR模型
print(f"正在加载ASR在线模型: {args.asr_model_online}")
models["asr_streaming"] = AutoModel(
model=args.asr_model_online,
model_revision=args.asr_model_online_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# 3. 加载VAD模型
print(f"正在加载VAD模型: {args.vad_model}")
models["vad"] = AutoModel(
model=args.vad_model,
model_revision=args.vad_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# 4. 加载标点符号模型(如果指定)
if args.punc_model:
print(f"正在加载标点符号模型: {args.punc_model}")
models["punc"] = AutoModel(
model=args.punc_model,
model_revision=args.punc_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
else:
models["punc"] = None
print("未指定标点符号模型,将不使用标点符号")
print("所有模型加载完成")
return models

11
src/models/__init__.py Normal file

@ -0,0 +1,11 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result
from .spk import SpeakerCreate, SpeakerResponse
__all__ = [
"AudioBinary_Config",
"AudioBinary_data_list",
"_AudioBinary_data",
"VAD_Functor_result",
"SpeakerCreate",
"SpeakerResponse",
]

158
src/models/audio.py Normal file

@ -0,0 +1,158 @@
from pydantic import BaseModel, Field, validator
from typing import List, Any, Union
import numpy
from src.utils import get_module_logger
logger = get_module_logger(__name__)
binary_data_types = Union[bytes, numpy.ndarray]
class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息"""
class Config:
arbitrary_types_allowed = True
audio_data: binary_data_types = Field(description="音频数据", default=None)
chunk_size: int = Field(description="块大小", default=100)
chunk_stride: int = Field(description="块步长", default=1600)
sample_rate: int = Field(description="采样率", default=16000)
sample_width: int = Field(description="采样位宽", default=2)
channels: int = Field(description="通道数", default=1)
# 从Dict中加载
@classmethod
def AudioBinary_Config_from_dict(cls, data: dict):
return cls(**data)
def ms2frame(self, ms: int) -> int:
"""
将毫秒转换为帧
"""
return int(ms * self.sample_rate / 1000)
def frame2ms(self, frame: int) -> int:
"""
将帧转换为毫秒
"""
return int(frame * 1000 / self.sample_rate)
class _AudioBinary_data(BaseModel):
"""音频数据"""
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
class Config:
arbitrary_types_allowed = True
@validator("binary_data")
def validate_binary_data(cls, v):
"""
验证音频数据
Args:
v: 音频数据
Returns:
binary_data_types: 音频数据
"""
if not isinstance(v, (bytes, numpy.ndarray)):
logger.warning(
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
cls.__class__.__name__,
type(v),
)
return v
def __len__(self):
"""
获取音频数据长度
Returns:
int: 音频数据长度
"""
return len(self.binary_data)
def __init__(self, binary_data: binary_data_types):
"""
初始化音频数据
Args:
binary_data: 音频数据
"""
logger.debug(
"[%s]初始化音频数据, 数据类型为%s",
self.__class__.__name__,
type(binary_data),
)
super().__init__(binary_data=binary_data)
def __getitem__(self, index=None):
"""
当获取数据时, 直接返回完整的binary_data(下标参数被忽略)
Returns:
binary_data_types: 音频数据
"""
return self.binary_data
class AudioBinary_data_list(BaseModel):
"""音频数据列表"""
binary_data_list: List[_AudioBinary_data] = Field(
description="音频数据列表", default=[]
)
class Config:
arbitrary_types_allowed = True
def push_data(self, data: binary_data_types) -> int:
"""
添加音频数据
Args:
data: 音频数据
Returns:
int: 数据在binary_data_list中的索引
"""
self.binary_data_list.append(_AudioBinary_data(binary_data=data))
return len(self.binary_data_list) - 1
def __getitem__(self, index: int):
"""
获取音频数据
Args:
index: 音频数据在binary_data_list中的索引
Returns:
_AudioBinary_data: 音频数据
"""
return self.binary_data_list[index]
def __len__(self):
"""
获取音频数据列表长度
Returns:
int: 音频数据列表长度
"""
return len(self.binary_data_list)
# class AudioBinary_Slice(BaseModel):
# """音频块切片"""
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
# start_index: int = Field(description="开始位置", default=0)
# end_index: int = Field(description="结束位置", default=-1)
# @validator('start_index')
# def validate_start_index(cls, v):
# if v < 0:
# raise ValueError("start_index必须大于0")
# return v
# @validator('end_index')
# def validate_end_index(cls, v):
# if v < cls.start_index:
# logger.debug("[%s]end_index小于start_index, 将end_index设置为start_index", cls.__class__.__name__)
# v = cls.start_index
# return v
# def __call__(self):
# return self.target_Binary(self.start_index, self.end_index)
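A quick sanity check of the `ms2frame`/`frame2ms` helpers defined on `AudioBinary_Config` above, at the default 16 kHz sample rate:

```python
from src.models.audio import AudioBinary_Config

cfg = AudioBinary_Config(sample_rate=16000)
assert cfg.ms2frame(100) == 1600   # 100 ms * 16000 samples/s / 1000
assert cfg.frame2ms(1600) == 100   # the inverse mapping
```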

70
src/models/spk.py Normal file

@ -0,0 +1,70 @@
"""
src/models/spk.py
------------------------
此模块定义与说话人(speakers)表对应的 Pydantic 模型, 用于 API 数据验证和序列化
模型说明:
- SpeakerBase: 定义说话人的基础字段, 包括姓名与描述
- SpeakerCreate: 用于创建说话人时的数据验证, 直接继承 SpeakerBase
- SpeakerUpdate: 用于更新说话人信息时使用, 所有字段均为可选
- SpeakerResponse: 返回给客户端时使用, 包含数据库生成的字段 speaker_id、created_at
"""
from datetime import datetime
from typing import Optional, List
from uuid import UUID
from pydantic import BaseModel, Field
from src.utils import get_module_logger
logger = get_module_logger(__name__)
# 基础模型,定义说话人的核心属性
class SpeakerBase(BaseModel):
speaker_id: UUID = Field(
...,
description="说话人唯一标识符"
)
speaker_name: str = Field(
...,
max_length=255,
description="说话人姓名必填最多255字符"
)
speaker_description: Optional[str] = Field(
None,
description="说话人描述信息,可选"
)
speaker_embs: Optional[List[float]] = Field(
None,
description="说话人embedding可选"
)
# 用于创建说话人时的数据模型,直接继承 SpeakerBase
class SpeakerCreate(SpeakerBase):
pass
# 用于更新说话人时的数据模型,所有字段都是可选的
class SpeakerUpdate(BaseModel):
speaker_name: Optional[str] = Field(
None,
max_length=255,
description="说话人姓名"
)
speaker_description: Optional[str] = Field(
None,
description="说话人描述信息"
)
# 用于 API 返回的说话人响应模型,扩展基础模型以包含数据库生成的字段
class SpeakerResponse(SpeakerBase):
speaker_id: UUID = Field(
...,
description="说话人的唯一标识符UUID"
)
created_at: datetime = Field(
...,
description="记录创建时间"
)
updated_at: datetime = Field(
...,
description="最近更新时间"
)

91
src/models/vad.py Normal file

@ -0,0 +1,91 @@
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable, Any
from .audio import AudioBinary_data_list, _AudioBinary_data
class VAD_Functor_result(BaseModel):
"""
VADFunctor结果
"""
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
audiobinary_index: int = Field(description="音频数据索引")
audiobinary_data: _AudioBinary_data = Field(
description="音频数据, 指向AudioBinary_data"
)
start_time: int = Field(..., description="开始时间")
end_time: int = Field(..., description="结束时间")
@validator("audiobinary_data_list")
def validate_audiobinary_data_list(cls, v):
if not isinstance(v, AudioBinary_data_list):
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
return v
@validator("audiobinary_index")
def validate_audiobinary_index(cls, v):
if not isinstance(v, int):
raise ValueError("audiobinary_index必须是int类型")
if v < 0:
raise ValueError("audiobinary_index必须大于0")
return v
@validator("audiobinary_data")
def validate_audiobinary_data(cls, v):
if not isinstance(v, _AudioBinary_data):
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
return v
@validator("start_time")
def validate_start_time(cls, v):
if not isinstance(v, int):
raise ValueError("start_time必须是int类型")
if v < 0:
raise ValueError("start_time必须大于0")
return v
@validator("end_time")
def validate_end_time(cls, v, values):
if not isinstance(v, int):
raise ValueError("end_time必须是int类型")
if "start_time" in values and v <= values["start_time"]:
raise ValueError("end_time必须大于start_time")
return v
@classmethod
def create_from_push_data(
cls,
audiobinary_data_list: AudioBinary_data_list,
data: Any,
start_time: int,
end_time: int,
):
"""
创建VAD片段
"""
index = audiobinary_data_list.push_data(data)
return cls(
audiobinary_data_list=audiobinary_data_list,
audiobinary_index=index,
audiobinary_data=audiobinary_data_list[index],
start_time=start_time,
end_time=end_time,
)
def __len__(self):
"""
获取音频数据长度
"""
return len(self.audiobinary_data.binary_data)
def __str__(self):
"""
字符串展示内容
"""
tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
tostr += f"start_time: {self.start_time}\n"
tostr += f"end_time: {self.end_time}\n"
tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
return tostr
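A small sketch of how `create_from_push_data` is meant to be used (here with a fabricated one-second silent segment; in the pipeline the data comes from the VAD functor's frame cache):

```python
import numpy as np

from src.models import AudioBinary_data_list, VAD_Functor_result

storage = AudioBinary_data_list()
segment = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz

vad_result = VAD_Functor_result.create_from_push_data(
    audiobinary_data_list=storage,
    data=segment,
    start_time=0,      # ms
    end_time=1000,     # ms
)
print(len(vad_result))  # 16000: length of the stored binary data
```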

280
src/pipeline/ASRpipeline.py Normal file

@ -0,0 +1,280 @@
from src.pipeline.base import PipelineBase
from typing import Dict, Any, Callable
from queue import Queue, Empty
from src.utils import get_module_logger
from src.models import AudioBinary_data_list
import threading
logger = get_module_logger(__name__)
class ASRPipeline(PipelineBase):
"""
管道类
实现具体的处理逻辑
"""
def __init__(self, *args, **kwargs):
"""
初始化管道
"""
super().__init__(*args, **kwargs)
self._config: Dict[str, Any] = {}
self._functor_dict: Dict[str, Any] = {}
self._subqueue_dict: Dict[str, Any] = {}
self._is_baked: bool = False
self._input_queue: Queue = None
self._audio_binary_data_list: AudioBinary_data_list = None
self._status_lock = threading.Lock()
self._is_running: bool = False
self._stop_event: bool = False
self._callback: Callable = None
def set_config(self, config: Dict[str, Any]) -> None:
"""
设置配置
参数:
config: Dict[str, Any] 配置
"""
self._config = config
def get_config(self) -> Dict[str, Any]:
"""
获取配置
返回:
Dict[str, Any] 配置
"""
return self._config
def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None:
"""
设置音频二进制存储单元
参数:
audio_binary: 音频二进制
"""
self._audio_binary = audio_binary
def set_models(self, models: Dict[str, Any]) -> None:
"""
设置模型
"""
self._models = models
def set_input_queue(self, input_queue: Queue) -> None:
"""
设置输入队列
"""
self._input_queue = input_queue
def set_callback(self, callback: Callable) -> None:
"""
设置回调函数
"""
self._callback = callback
def bake(self) -> None:
"""
烘焙管道
"""
self._pre_check_resource()
self._init_functor()
self._is_baked = True
def _pre_check_resource(self) -> None:
"""
预检查资源
"""
if self._input_queue is None:
raise RuntimeError("[ASRpipeline]输入队列未设置")
if self._functor_dict is None:
raise RuntimeError("[ASRpipeline]functor字典未设置")
if self._subqueue_dict is None:
raise RuntimeError("[ASRpipeline]子队列字典未设置")
if self._audio_binary is None:
raise RuntimeError("[ASRpipeline]音频存储单元未设置")
def _init_functor(self) -> None:
"""
初始化函数
自身的functor流程图如下:
self._input_queue --(_run转发)--> subqueue["original"] --> vad
vad --> vad2asr --> asr --> asrend
vad --> vad2spk --> spk --> spkend
asrend + spkend --> resultbinder --> self._callback
"""
try:
from src.functor import FunctorFactory
# 加载VAD、asr、spk functor
logger.debug(f"使用FunctorFactory创建vad functor")
self._functor_dict["vad"] = FunctorFactory.make_functor(
functor_name="vad", config=self._config, models=self._models
)
logger.debug(f"使用FunctorFactory创建asr functor")
self._functor_dict["asr"] = FunctorFactory.make_functor(
functor_name="asr", config=self._config, models=self._models
)
logger.debug(f"使用FunctorFactory创建spk functor")
self._functor_dict["spk"] = FunctorFactory.make_functor(
functor_name="spk", config=self._config, models=self._models
)
logger.debug(f"使用FunctorFactory创建resultbinder functor")
self._functor_dict["resultbinder"] = FunctorFactory.make_functor(
functor_name="resultbinder", config=self._config, models=self._models
)
# 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list()
self._functor_dict["vad"].set_audio_binary_data_list(
self._audio_binary_data_list
)
# 初始化子队列
self._subqueue_dict["original"] = Queue()
self._subqueue_dict["vad2asr"] = Queue()
self._subqueue_dict["vad2spk"] = Queue()
self._subqueue_dict["asrend"] = Queue()
self._subqueue_dict["spkend"] = Queue()
# 输出队列
self._subqueue_dict["OUTPUT"] = Queue()
# 设置子队列的输入队列
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
# 设置resultbinder的输入队列
# 汇总 asr语音识别结果 和 说话人识别结果
self._functor_dict["resultbinder"].add_input_queue(
"asr", self._subqueue_dict["asrend"]
)
self._functor_dict["resultbinder"].add_input_queue(
"spk", self._subqueue_dict["spkend"]
)
# 设置回调函数——放置到对应队列中
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
# 设置asr与spk的回调函数
self._functor_dict["asr"].add_callback(self._subqueue_dict["asrend"].put)
self._functor_dict["spk"].add_callback(self._subqueue_dict["spkend"].put)
# 设置resultbinder的回调函数 为 自身被设置的回调函数,用于和外界交互
self._functor_dict["resultbinder"].add_callback(self._callback)
except ImportError as e:
raise ImportError(f"functorFactory引入失败,ASRPipeline无法完成初始化: {str(e)}")
except Exception as e:
raise Exception(f"ASRPipeline初始化失败: {str(e)}")
def _check_result(self, result: Any) -> None:
"""
检查结果
"""
# 若asr和spk队列中都有数据则合并数据
if (
not self._subqueue_dict["asrend"].empty()
and not self._subqueue_dict["spkend"].empty()
):
asr_data = self._subqueue_dict["asrend"].get()
spk_data = self._subqueue_dict["spkend"].get()
# 合并数据
result = {"asr_data": asr_data, "spk_data": spk_data}
# 通知回调函数
self._notify_callbacks(result)
def run(self) -> threading.Thread:
"""
运行管道
Returns:
threading.Thread: 返回已运行线程实例
"""
# 检查运行资源是否准备完毕
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
logger.info("[ASRpipeline]管道开始运行")
return self._thread
def _pre_check(self) -> None:
"""
预检查
"""
if self._is_baked is False:
raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行")
for functor_name, functor in self._functor_dict.items():
if functor is None:
raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常")
for subqueue_name, subqueue in self._subqueue_dict.items():
if subqueue is None:
raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常")
def _run(self) -> None:
"""
真实的运行逻辑
"""
# 运行所有functor
for functor_name, functor in self._functor_dict.items():
logger.info(f"[ASRpipeline]运行{functor_name}functor")
self._functor_dict[functor_name].run()
# 设置管道运行状态
with self._status_lock:
self._is_running = True
self._stop_event = False
while self._is_running and not self._stop_event:
try:
data = self._input_queue.get(timeout=self._queue_timeout)
# logger.debug("[ASRpipeline]获取到的数据length: %s", len(data))
# 检查是否是结束信号
if data is None:
logger.info("收到结束信号,管道准备停止")
self._input_queue.task_done() # 标记结束信号已处理
break
# 处理数据
self._process(data)
# 标记任务完成
self._input_queue.task_done()
except Empty:
# 队列获取超时,继续等待
continue
except Exception as e:
logger.error(f"[ASRpipeline]管道处理数据出错: {str(e)}")
break
logger.info("[ASRpipeline]管道停止运行")
def _process(self, data: Any) -> Any:
"""
处理数据
参数:
data: 输入数据
返回:
处理结果
"""
# 子类实现具体的处理逻辑
self._subqueue_dict["original"].put(data)
def stop(self) -> bool:
"""
停止管道
"""
with self._status_lock:
self._is_running = False
self._stop_event = True
for functor_name, functor in self._functor_dict.items():
# logger.info(f"停止{functor_name}functor")
if functor.stop():
logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止")
else:
logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败")
self._thread.join()
logger.info("[ASRpipeline]管道停止")
return True

3
src/pipeline/__init__.py Normal file

@ -0,0 +1,3 @@
from src.pipeline.base import PipelineBase, PipelineFactory
__all__ = ["PipelineBase", "PipelineFactory"]

151
src/pipeline/base.py Normal file

@ -0,0 +1,151 @@
from abc import ABC, abstractmethod
from queue import Queue, Empty
from typing import List, Callable, Any, Optional
import logging
import threading
import time
# 配置日志
logger = logging.getLogger(__name__)
class PipelineBase(ABC):
"""
管道基类
定义了管道的基本接口和通用功能
"""
def __init__(self, input_queue: Optional[Queue] = None):
"""
初始化管道
参数:
input_queue: 输入队列用于接收数据
"""
self._input_queue = input_queue
self._callbacks: List[Callable] = []
self._is_running = False
self._stop_event = False
self._thread: Optional[threading.Thread] = None
self._stop_timeout = 5 # 默认停止超时时间(秒)
self._queue_timeout = 1 # 队列获取超时时间(秒)
def set_input_queue(self, queue: Queue) -> None:
"""
设置输入队列
参数:
queue: 输入队列
"""
self._input_queue = queue
def add_callback(self, callback: Callable) -> None:
"""
添加回调函数
参数:
callback: 回调函数接收处理结果
"""
self._callbacks.append(callback)
def _notify_callbacks(self, result: Any) -> None:
"""
通知所有回调函数
参数:
result: 处理结果
"""
for callback in self._callbacks:
try:
callback(result)
except Exception as e:
logger.error(f"回调函数执行出错: {str(e)}")
@abstractmethod
def _process(self, data: Any) -> Any:
"""
处理数据
参数:
data: 输入数据
返回:
处理结果
"""
pass
@abstractmethod
def _run(self) -> None:
"""
运行管道
从输入队列获取数据并处理
"""
pass
def stop(self, timeout: Optional[float] = None) -> bool:
"""
停止管道
参数:
timeout: 停止超时时间None表示使用默认超时时间
返回:
bool: 是否成功停止
"""
if not self._is_running:
return True
logger.info("正在停止管道...")
self._stop_event = True
self._is_running = False
# 等待线程结束
if self._thread and self._thread.is_alive():
timeout = timeout if timeout is not None else self._stop_timeout
self._thread.join(timeout=timeout)
# 检查是否成功停止
if self._thread.is_alive():
logger.warning(f"管道停止超时({timeout}秒),强制终止")
return False
else:
logger.info("管道已成功停止")
return True
return True
def force_stop(self) -> None:
"""
强制停止管道
注意这可能会导致资源未正确释放
"""
logger.warning("强制停止管道")
self._stop_event = True
self._is_running = False
# 注意Python的线程无法被强制终止这里只是设置标志
# 实际终止需要依赖操作系统的进程管理
class PipelineFactory:
"""
管道工厂类
用于创建管道实例
"""
def _create_pipeline_ASRpipeline(*args, **kwargs) -> "ASRPipeline":
"""
创建ASR管道实例
"""
from src.pipeline.ASRpipeline import ASRPipeline
pipeline = ASRPipeline()
pipeline.set_config(kwargs["config"])
pipeline.set_models(kwargs["models"])
pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"])
pipeline.set_callback(kwargs["callback"])
# pipeline.bake()
return pipeline
@classmethod
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
"""
创建管道实例
"""
if pipeline_name == "ASRpipeline":
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
else:
raise ValueError(f"不支持的管道类型: {pipeline_name}")

0
src/pipeline/test.py Normal file

231
src/runner/ASRRunner.py Normal file

@ -0,0 +1,231 @@
"""
-*- encoding: utf-8 -*-
ASRRunner
继承RunnerBase
专属pipeline为ASRPipeline
"""
from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.core.model_loader import ModelLoader
from src.config import DefaultConfig
import asyncio
from queue import Queue
import soundfile
import time
from typing import List, Optional
import uuid
from threading import Thread
from src.utils.mock_websocket import MockWebSocketClient as WebSocketClient
from .runner import RunnerBase
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
OVERWATCH = False
model_loader = ModelLoader()
class ASRRunner(RunnerBase):
"""
运行器类
负责管理资源和协调Pipeline的运行
"""
class SenderAndReceiver:
"""
对于单个pipeline的管理
包含 发送者 接收者
_sender: 发送者 唯一
_receiver: 接收者 可以有多个
_pipeline: 对应管道 唯一
"""
def __init__(self, *args, **kwargs):
# 可选传入参数,
self._name: str = kwargs.get("name", "")
self._sender: Optional[WebSocketClient] = kwargs.get("sender", None)
self._receiver: List[WebSocketClient] = kwargs.get("receiver", [])
# 资源
self._audio_config: AudioBinary_Config = kwargs.get("audio_config", DefaultConfig.audio_config)
self._models: dict = kwargs.get("models", None)
self._audio_binary: AudioBinary_data_list = AudioBinary_data_list()
# id唯一标识
self._id: str = str(uuid.uuid4())
# 输入队列
self._input_queue: Queue = Queue()
self._pipeline: Optional[ASRPipeline] = None
self._task: Optional[asyncio.Task] = None
def set_name(self, name: str):
self._name = name
def set_id(self, id: str):
self._id = id
def set_sender(self, sender: WebSocketClient):
self._sender = sender
def set_pipeline(self, pipeline: ASRPipeline):
self._pipeline = pipeline
config = {
"audio_config": self._audio_config,
}
self._pipeline.set_config(config)
self._pipeline.set_models(self._models)
self._pipeline.set_audio_binary(self._audio_binary)
self._pipeline.set_input_queue(self._input_queue)
# --- 异步-同步桥梁 ---
# 创建一个线程安全的回调函数用于从Pipeline的线程中调用Runner的异步方法
loop = asyncio.get_running_loop()
def thread_safe_callback(message):
asyncio.run_coroutine_threadsafe(self.deal_message(message), loop)
self._pipeline.set_callback(thread_safe_callback)
self._pipeline.bake()
def append_receiver(self, receiver: WebSocketClient):
self._receiver.append(receiver)
def delete_receiver(self, receiver: WebSocketClient):
self._receiver.remove(receiver)
async def deal_message(self, message: str):
await self.broadcast(message)
async def broadcast(self, message: str):
"""
广播发送给所有接收者
"""
logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: 消息长度:%s", self._name, len(message))
logger.info(f"SAR-{self._name} 的接收者列表: {self._receiver}")
tasks = [receiver.send(message) for receiver in self._receiver]
await asyncio.gather(*tasks)
async def _run(self):
"""
运行SAR
"""
self._pipeline.run()
loop = asyncio.get_running_loop()
while True:
try:
data = await self._sender.recv()
if data is None:
# `None` is used as a signal to end the stream
await loop.run_in_executor(None, self._input_queue.put, None)
break
# logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data))
await loop.run_in_executor(None, self._input_queue.put, data)
except Exception as e:
logger.error(f"[ASRRunner][SAR-{self._name}] _run loop error: {e}")
break
await self.stop()
def run(self):
"""
运行SAR
"""
self._task = asyncio.create_task(self._run())
async def stop(self):
"""
停止SAR
"""
logger.info(f"Stopping SAR: {self._name}")
self._pipeline.stop()
# Close all receiver websockets
receiver_tasks = [ws.close() for ws in self._receiver]
await asyncio.gather(*receiver_tasks, return_exceptions=True)
# Close the sender websocket
if self._sender:
await self._sender.close()
# Cancel the main task if it's still running
if self._task and not self._task.done():
self._task.cancel()
logger.info(f"SAR stopped: {self._name}")
def __init__(self, *args, **kwargs):
"""
初始化ASRRunner, 接收默认的音频配置(audio_config)与模型资源(models)
"""
# 接收资源
self._default_audio_config = kwargs.get("audio_config", DefaultConfig.audio_config)
# self._audio_binary_list = args.get("audio_binary_list", None)
self._default_models = kwargs.get("models", None)
self._SAR_list: List[self.SenderAndReceiver] = []
def set_default_config(self, *args, **kwargs):
"""
设置配置
"""
self._default_audio_config = kwargs.get("audio_config", self._default_audio_config)
self._default_models = kwargs.get("models", self._default_models)
def new_SAR(
self,
ws: "WebSocketClient",
name: str = "",
audio_config: "AudioBinary_Config" = None,
models: dict = None
) -> uuid.UUID:
"""
创建新的SAR SenderAndReceiver
"""
if audio_config is None:
audio_config = self._default_audio_config
if models is None:
models = self._default_models
try:
new_SAR = self.SenderAndReceiver(
name=name,
audio_config=audio_config,
models=models
)
new_pipeline = ASRPipeline()
new_SAR.set_pipeline(new_pipeline)
# new_SAR.set_pipeline()
logger.info("创建新的SAR: name %s, id %s", new_SAR._name, new_SAR._id)
new_SAR.set_sender(ws)
new_SAR.append_receiver(ws)
new_SAR.run()
self._SAR_list.append(new_SAR)
return new_SAR._id
except Exception as e:
logger.error("创建管道失败: %s", e)
return None
def join_SAR(
self,
ws: "WebSocketClient",
name: Optional[str] = None,
id: Optional[str] = None,
) -> bool:
"""
加入SAR的Receiver
"""
# 使用next在SAR列表中按id或name查找目标会话, 未命中时为None
exist_pipeline = None
if id:
exist_pipeline = next((sar for sar in self._SAR_list if sar._id == id), None)
elif name:
exist_pipeline = next((sar for sar in self._SAR_list if sar._name == name), None)
if exist_pipeline:
exist_pipeline.append_receiver(ws)
return True
return False
async def shutdown(self):
"""
优雅地关闭所有SAR会话
"""
logger.info("Shutting down all SAR instances...")
tasks = [sar.stop() for sar in self._SAR_list]
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("All SAR instances have been shut down.")

3
src/runner/__init__.py Normal file

@ -0,0 +1,3 @@
from .ASRRunner import ASRRunner
__all__ = ["ASRRunner"]

105
src/runner/runner.py Normal file

@ -0,0 +1,105 @@
"""
-*- encoding: utf-8 -*-
Runner类
所有的Runner都对应一个fastapi的endpoint,
Runner需要处理:
1.新的websocket 进来后放到 unknow_websocket_pool中
2.收到特定消息后, 将消息转发给特定的pipeline处理
3.管理pipeline与websocket对应关系, 管理pipeline的ID
4.管理pipeline的启动和停止
5.管理所有pipeline用到的资源, 管理pipeline的存活时间
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from threading import Thread, Lock
from queue import Queue
import traceback
import time
from src.audio_chunk import AudioChunk
from src.pipeline import PipelineBase, PipelineFactory
from src.core.model_loader import ModelLoader
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
audio_chunk = AudioChunk()
models_loaded = ModelLoader()
class RunnerBase(ABC):
"""
运行器基类
定义了运行器的基本接口
"""
def __init__(self, *args, **kwargs):
pass
class STTRunnerFactory:
"""
STT Runner工厂类
用于创建运行器实例
"""
@staticmethod
def _create_runner(
audio_binary_name: str,
model_name_list: List[str],
pipeline_name_list: List[str],
) -> RunnerBase:
"""
创建运行器
参数:
audio_binary_name: 音频二进制名称
model_name_list: 模型名称列表
pipeline_name_list: 管道名称列表
返回:
Runner实例
"""
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
models: Dict[str, Any] = {
model_name: models_loaded.models[model_name]
for model_name in model_name_list
}
pipelines: List[PipelineBase] = [
PipelineFactory.create_pipeline(pipeline_name)
for pipeline_name in pipeline_name_list
]
return RunnerBase(
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
)
@classmethod
def create_runner_from_config(
cls,
config: Dict[str, Any],
) -> RunnerBase:
"""
从配置创建运行器
参数:
config: 配置字典
返回:
Runner实例
"""
audio_binary_name = config["audio_binary_name"]
model_name_list = config["model_name_list"]
pipeline_name_list = config["pipeline_name_list"]
return cls._create_runner(
audio_binary_name, model_name_list, pipeline_name_list
)
@classmethod
def create_runner_normal(cls) -> RunnerBase:
"""
创建默认运行器
返回:
Runner实例
"""
audio_binary_name = None
model_name_list = list(models_loaded.models.keys())
pipeline_name_list = None
return cls._create_runner(
audio_binary_name, model_name_list, pipeline_name_list
)

@ -10,233 +10,80 @@ import json
import websockets import websockets
import ssl import ssl
import argparse import argparse
from config import parse_args from src.runner import ASRRunner
from models import load_models from src.config import DefaultConfig
from service import ASRService from src.config import AudioBinary_Config
from src.websockets.router import websocket_router
from src.core import ModelLoader
import uvicorn
from fastapi import FastAPI
from contextlib import asynccontextmanager
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
# 全局变量存储当前连接的WebSocket客户端 # 全局变量存储当前连接的WebSocket客户端
websocket_users = set() websocket_users = set()
# 使用 lifespan 上下文管理器来管理应用的生命周期
async def ws_reset(websocket): @asynccontextmanager
"""重置WebSocket连接状态并关闭连接""" async def lifespan(app: FastAPI):
print(f"重置WebSocket连接当前连接数: {len(websocket_users)}")
# 重置状态字典
websocket.status_dict_asr_online["cache"] = {}
websocket.status_dict_asr_online["is_final"] = True
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
websocket.status_dict_punc["cache"] = {}
# 关闭连接
await websocket.close()
async def clear_websocket():
"""清理所有WebSocket连接"""
for websocket in websocket_users:
await ws_reset(websocket)
websocket_users.clear()
async def ws_serve(websocket, path):
""" """
WebSocket服务主函数处理客户端连接和消息 在应用启动时加载模型和初始化ASRRunner
在应用关闭时优雅地关闭ASRRunner
参数:
websocket: WebSocket连接对象
path: 连接路径
""" """
frames = [] # 存储所有音频帧 logger.info("应用启动开始加载模型和初始化Runner...")
frames_asr = [] # 存储用于离线ASR的音频帧
frames_asr_online = [] # 存储用于在线ASR的音频帧
global websocket_users # 1. 加载模型
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端) # 这里的参数可以从配置文件或环境变量中获取
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
}
model_loader = ModelLoader()
models = model_loader.load_models(args)
# 添加到用户集合 # 2. 初始化 ASRRunner
websocket_users.add(websocket) _audio_config = AudioBinary_Config(
chunk_size=200, # ms
sample_rate=16000,
sample_width=2, # 16-bit
channels=1,
)
_audio_config.chunk_stride = int(_audio_config.chunk_size * _audio_config.sample_rate / 1000)
# 初始化连接状态 asr_runner = ASRRunner()
websocket.status_dict_asr = {} asr_runner.set_default_config(
websocket.status_dict_asr_online = {"cache": {}, "is_final": False} audio_config=_audio_config,
websocket.status_dict_vad = {"cache": {}, "is_final": False} models=models,
websocket.status_dict_punc = {"cache": {}} )
websocket.chunk_interval = 10
websocket.vad_pre_idx = 0
websocket.is_speaking = True # 默认用户正在说话
# 语音检测状态 # 3. 将 asr_runner 实例存储在 app.state 中
speech_start = False app.state.asr_runner = asr_runner
speech_end_i = -1 logger.info("模型加载和Runner初始化完成。")
# 初始化配置 yield
websocket.wav_name = "microphone"
websocket.mode = "2pass" # 默认使用两阶段识别模式
print("新用户已连接", flush=True) # --- 应用关闭时执行的代码 ---
logger.info("应用关闭,开始清理资源...")
await app.state.asr_runner.shutdown()
logger.info("资源清理完成。")
try: # 初始化FastAPI应用并指定lifespan
# 持续接收客户端消息 app = FastAPI(lifespan=lifespan)
async for message in websocket:
# 处理JSON配置消息
if isinstance(message, str):
try:
messagejson = json.loads(message)
# 更新各种配置参数 # 挂载WebSocket路由
if "is_speaking" in messagejson: app.include_router(websocket_router, prefix="/ws")
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
websocket.wav_name = messagejson.get("wav_name")
if "chunk_size" in messagejson:
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
websocket.mode = messagejson["mode"]
except json.JSONDecodeError:
print(f"无效的JSON消息: {message}")
# 根据chunk_interval更新VAD的chunk_size
websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
)
# 处理音频数据
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区
frames.append(message)
duration_ms = len(message) // 32 # 计算音频时长
websocket.vad_pre_idx += duration_ms
# 处理在线ASR
frames_asr_online.append(message)
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
# 达到chunk_interval或最终帧时处理在线ASR
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
websocket.status_dict_asr_online["is_final"]):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
await asr_service.async_asr_online(websocket, audio_in)
except Exception as e:
print(f"在线ASR处理错误: {e}")
frames_asr_online = []
# 如果检测到语音开始收集帧用于离线ASR
if speech_start:
frames_asr.append(message)
# VAD处理 - 语音活动检测
try:
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
except Exception as e:
print(f"VAD处理错误: {e}")
# 检测到语音开始
if speech_start_i != -1:
speech_start = True
# 计算开始偏移并收集前面的帧
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
frames_asr = []
frames_asr.extend(frames_pre)
# 处理离线ASR (语音结束或用户停止说话)
if speech_end_i != -1 or not websocket.is_speaking:
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
try:
await asr_service.async_asr(websocket, audio_in)
except Exception as e:
print(f"离线ASR处理错误: {e}")
# 重置状态
frames_asr = []
speech_start = False
frames_asr_online = []
websocket.status_dict_asr_online["cache"] = {}
# 如果用户停止说话,完全重置
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
websocket.status_dict_vad["cache"] = {}
else:
# 保留最近的帧用于下一轮处理
frames = frames[-20:]
except websockets.ConnectionClosed:
print(f"连接已关闭...", flush=True)
await ws_reset(websocket)
websocket_users.remove(websocket)
except websockets.InvalidState:
print("无效的WebSocket状态...")
except Exception as e:
print(f"发生异常: {e}")
def start_server(args, asr_service_instance):
"""
启动WebSocket服务器
参数:
args: 命令行参数
asr_service_instance: ASR服务实例
"""
global asr_service
asr_service = asr_service_instance
# 配置SSL (如果提供了证书)
if args.certfile and len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None
)
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
# 启动事件循环
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
@app.get("/")
async def read_root():
return {"message": "FunASR-FastAPI WebSocket Server is running."}
# 如果需要直接运行此文件进行测试
if __name__ == "__main__": if __name__ == "__main__":
# 解析命令行参数 # 注意在生产环境中推荐使用Gunicorn + Uvicorn workers
args = parse_args() uvicorn.run(app, host="0.0.0.0", port=8000)
# 加载模型
print("正在加载模型...")
models = load_models(args)
print("模型加载完成!当前仅支持单个客户端同时连接!")
# 创建ASR服务
asr_service = ASRService(models)
# 启动服务器
start_server(args, asr_service)
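
A note on the audio arithmetic used throughout this diff: the legacy handler's `duration_ms = len(message) // 32` and the recurring `chunk_stride = chunk_size * sample_rate / 1000` are the same PCM byte math. A minimal sketch, assuming the 16 kHz, 16-bit, mono format this code hard-codes:

```python
# PCM byte math for 16 kHz, 16-bit (2-byte), mono audio.
SAMPLE_RATE = 16000   # samples per second
SAMPLE_WIDTH = 2      # bytes per sample
CHANNELS = 1

bytes_per_ms = SAMPLE_RATE * SAMPLE_WIDTH * CHANNELS // 1000
assert bytes_per_ms == 32        # hence duration_ms = len(message) // 32

chunk_size_ms = 200              # chunk length used by the new pipeline
chunk_stride = chunk_size_ms * SAMPLE_RATE // 1000
assert chunk_stride == 3200      # samples per 200 ms chunk
```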

@ -1,127 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
ASR服务模块 - 提供语音识别相关的核心功能
"""
import json
class ASRService:
"""ASR服务类封装各种语音识别相关功能"""
def __init__(self, models):
"""
初始化ASR服务
参数:
models: 包含各种预加载模型的字典
"""
self.model_asr = models["asr"]
self.model_asr_streaming = models["asr_streaming"]
self.model_vad = models["vad"]
self.model_punc = models["punc"]
async def async_vad(self, websocket, audio_in):
"""
语音活动检测
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
返回:
tuple: (speech_start, speech_end) 语音开始和结束位置
"""
# 使用VAD模型分析音频段
segments_result = self.model_vad.generate(
input=audio_in,
**websocket.status_dict_vad
)[0]["value"]
speech_start = -1
speech_end = -1
# 解析VAD结果
if len(segments_result) == 0 or len(segments_result) > 1:
return speech_start, speech_end
if segments_result[0][0] != -1:
speech_start = segments_result[0][0]
if segments_result[0][1] != -1:
speech_end = segments_result[0][1]
return speech_start, speech_end
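
# Note: the `value` field parsed above is a list of [start_ms, end_ms] pairs,
# where -1 marks a boundary the streaming VAD has not emitted yet. A stand-alone
# mirror of that parsing (the sample inputs below are fabricated):
def parse_vad_value(segments):
    """Same logic as the segment parsing in async_vad above."""
    speech_start, speech_end = -1, -1
    if len(segments) != 1:  # empty or multi-segment result: no update
        return speech_start, speech_end
    if segments[0][0] != -1:
        speech_start = segments[0][0]
    if segments[0][1] != -1:
        speech_end = segments[0][1]
    return speech_start, speech_end

print(parse_vad_value([[1230, -1]]))  # speech started at 1230 ms -> (1230, -1)
print(parse_vad_value([[-1, 4560]]))  # speech ended at 4560 ms   -> (-1, 4560)
print(parse_vad_value([]))            # no event yet              -> (-1, -1)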
async def async_asr(self, websocket, audio_in):
"""
离线ASR处理
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
"""
if len(audio_in) > 0:
# 使用离线ASR模型处理音频
rec_result = self.model_asr.generate(
input=audio_in,
**websocket.status_dict_asr
)[0]
# 如果有标点符号模型且识别出文本,则添加标点
if self.model_punc is not None and len(rec_result["text"]) > 0:
rec_result = self.model_punc.generate(
input=rec_result["text"],
**websocket.status_dict_punc
)[0]
# 如果识别出文本,发送到客户端
if len(rec_result["text"]) > 0:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
await websocket.send(message)
else:
# 如果没有音频数据,发送空文本
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
await websocket.send(message)
async def async_asr_online(self, websocket, audio_in):
"""
在线ASR处理
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
"""
if len(audio_in) > 0:
# 使用在线ASR模型处理音频
rec_result = self.model_asr_streaming.generate(
input=audio_in,
**websocket.status_dict_asr_online
)[0]
# 在2pass模式下如果是最终帧则跳过(留给离线ASR处理)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
return
# 如果识别出文本,发送到客户端
if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
await websocket.send(message)
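
Both branches above serialize the same four-field JSON payload. A fabricated example of the two 2pass phases as a client would see them:

```python
import json

# Fabricated samples of the messages pushed by async_asr_online / async_asr.
for raw in (
    '{"mode": "2pass-online", "text": "今天天气", "wav_name": "microphone", "is_final": false}',
    '{"mode": "2pass-offline", "text": "今天天气不错。", "wav_name": "microphone", "is_final": true}',
):
    msg = json.loads(raw)
    print(msg["mode"], msg["text"], msg["is_final"])
```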

6
src/utils/__init__.py Normal file

@ -0,0 +1,6 @@
from .logger import get_module_logger, setup_root_logger
__all__ = [
"get_module_logger",
"setup_root_logger",
]

122
src/utils/data_format.py Normal file

@ -0,0 +1,122 @@
"""
处理各类音频数据与bytes的转换
"""
import wave
from pydub import AudioSegment
import io
def wav_to_bytes(wav_path: str) -> bytes:
"""
将WAV文件读取为bytes数据
参数:
wav_path (str): WAV文件的路径
返回:
bytes: WAV文件的原始字节数据
异常:
FileNotFoundError: 如果WAV文件不存在
wave.Error: 如果文件不是有效的WAV文件
"""
try:
with wave.open(wav_path, "rb") as wf:
# 读取所有帧
frames = wf.readframes(wf.getnframes())
return frames
except FileNotFoundError:
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
except wave.Error as e:
raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
def bytes_to_wav(
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
):
"""
将bytes数据写入为WAV文件
参数:
bytes_data (bytes): 音频的字节数据
wav_path (str): 保存WAV文件的路径
nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)
sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)
framerate (int): 采样率 (例如 44100, 16000)
异常:
wave.Error: 如果写入WAV文件失败
"""
try:
with wave.open(wav_path, "wb") as wf:
wf.setnchannels(nchannels)
wf.setsampwidth(sampwidth)
wf.setframerate(framerate)
wf.writeframes(bytes_data)
except wave.Error as e:
raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
except Exception as e:
# 捕获其他可能的写入错误
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
def mp3_to_bytes(mp3_path: str) -> bytes:
"""
将MP3文件转换为bytes数据 (原始PCM数据)
参数:
mp3_path (str): MP3文件的路径
返回:
bytes: MP3文件解码后的原始PCM字节数据
异常:
FileNotFoundError: 如果MP3文件不存在
pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码
"""
try:
audio = AudioSegment.from_mp3(mp3_path)
# 获取原始PCM数据
return audio.raw_data
except FileNotFoundError:
raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
except Exception as e: # pydub 可能抛出多种解码相关的错误
raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
def bytes_to_mp3(
bytes_data: bytes,
mp3_path: str,
frame_rate: int,
channels: int,
sample_width: int,
bitrate: str = "192k",
):
"""
将原始PCM bytes数据转换为MP3文件
参数:
bytes_data (bytes): 原始PCM字节数据
mp3_path (str): 保存MP3文件的路径
frame_rate (int): 原始PCM数据的采样率 (例如 44100)
channels (int): 原始PCM数据的声道数 (例如 1 for mono, 2 for stereo)
sample_width (int): 原始PCM数据的采样宽度 (字节数, 例如 2 for 16-bit)
bitrate (str): MP3编码的比特率 (例如 "128k", "192k", "320k")
异常:
Exception: 如果转换或写入MP3文件失败
"""
try:
# 从原始数据创建AudioSegment对象
audio = AudioSegment(
data=bytes_data,
sample_width=sample_width,
frame_rate=frame_rate,
channels=channels,
)
# 导出为MP3
audio.export(mp3_path, format="mp3", bitrate=bitrate)
except Exception as e:
raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")

91
src/utils/logger.py Normal file

@ -0,0 +1,91 @@
import logging
import sys
from pathlib import Path
from typing import Optional
def setup_logger(
name: str = None,
level: str = "INFO",
log_file: Optional[str] = None,
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
date_format: str = "%Y-%m-%d %H:%M:%S",
) -> logging.Logger:
"""
设置并返回一个配置好的logger实例
Args:
name: logger的名称默认为None使用root logger
level: 日志级别默认为"INFO"
log_file: 日志文件路径默认为None仅控制台输出
log_format: 日志格式
date_format: 日期格式
Returns:
logging.Logger: 配置好的logger实例
"""
# 获取logger实例
logger = logging.getLogger(name)
# 设置日志级别
level = getattr(logging, level.upper())
logger.setLevel(level)
print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
# 创建格式器
formatter = logging.Formatter(log_format, date_format)
# 添加控制台处理器
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 如果指定了日志文件,添加文件处理器
if log_file:
# 确保日志目录存在
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 注意:移除了 propagate = False允许日志传递
return logger
def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
"""
配置根日志器
Args:
level: 日志级别
log_file: 日志文件路径
"""
setup_logger(None, level, log_file)
def get_module_logger(
module_name: str,
level: Optional[str] = None, # 改为可选参数
log_file: Optional[str] = None, # 一般不需要单独指定
) -> logging.Logger:
"""
获取模块级别的logger
Args:
module_name: 模块名称通常传入__name__
level: 可选的日志级别如果不指定则继承父级配置
log_file: 可选的日志文件路径一般不需要指定
"""
logger = logging.getLogger(module_name)
# 只有显式指定了level才设置
if level:
logger.setLevel(getattr(logging, level.upper()))
# 只有显式指定了log_file才添加文件处理器
if log_file:
setup_logger(module_name, level or "INFO", log_file)
return logger
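
Typical wiring, mirroring how test_main.py uses this module: configure the root logger once at startup, then hand out per-module loggers that inherit its handlers (the log path below is illustrative):

```python
from src.utils import setup_root_logger, get_module_logger

setup_root_logger(level="INFO", log_file="logs/app.log")
logger = get_module_logger(__name__)
logger.info("logger ready")
```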

@ -0,0 +1,55 @@
import queue
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class MockWebSocketClient:
"""A mock WebSocket client to simulate a connection for testing."""
def __init__(self):
self.sent_messages = []
self._is_closed = False
self.receive_queue = queue.Queue()
def send(self, message: dict):
"""Simulates sending a message (which is a dict)."""
if self._is_closed:
print("Warning: sending message on a closed websocket")
return
self.sent_messages.append(message)
print(f"Mock WS received: {message}")
def recv(self):
"""Simulates receiving data from the WebSocket."""
if self._is_closed:
return None
try:
# Block until data is available, with a timeout to prevent hanging.
data = self.receive_queue.get(timeout=10)
if data is None:
self._is_closed = True
return data
except queue.Empty:
print("Mock WS recv timeout")
self._is_closed = True
return None
def close(self):
"""Simulates closing the WebSocket connection."""
if not self._is_closed:
# Put None to unblock any waiting recv call
self.receive_queue.put(None)
self._is_closed = True
print("Mock WS closed")
def put_for_recv(self, data):
"""Puts data into the receive queue for the `recv` method to consume."""
if data is None:
return
# logger.debug("Mock WS put_for_recv length: %s", len(data))
self.receive_queue.put(data)
@property
def is_closed(self):
return self._is_closed
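
Usage is symmetric with a real socket: the test pushes data in with `put_for_recv` and the code under test pulls it back out with `recv`. A tiny sketch with fabricated bytes:

```python
mock = MockWebSocketClient()
mock.put_for_recv(b"\x00\x01")  # fabricated audio bytes
print(mock.recv())              # -> b'\x00\x01'
mock.close()
print(mock.recv())              # -> None (connection closed)
```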

66
src/websockets/adapter.py Normal file

@ -0,0 +1,66 @@
import numpy as np
from fastapi import WebSocket
from typing import Union
import uuid
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class FastAPIWebSocketAdapter:
"""
一个适配器类用于将FastAPI的WebSocket对象包装成ASRRunner所期望的接口
同时处理数据类型转换
"""
def __init__(self, websocket: WebSocket, sample_rate: int = 16000, sample_width: int = 2):
self._ws = websocket
self._sample_rate = sample_rate
self._sample_width = sample_width
self._total_received = 0
async def recv(self) -> Union[np.ndarray, None]:
"""
接收来自FastAPI WebSocket的数据
如果收到的是字节流将其转换为Numpy数组
如果收到的是文本"close"返回None以表示结束
"""
message = await self._ws.receive()
if 'bytes' in message:
bytes_data = message['bytes']
audio_array = np.frombuffer(bytes_data, dtype=np.int16)
# 将int16转为float32
audio_array = audio_array.astype(np.float32)
# 归一化
audio_array = audio_array / 32768.0
# 使用回车符 \r 覆盖打印进度
self._total_received += len(bytes_data)
print(f"🎧 [Adapter] 正在接收音频... 总计: {self._total_received / 1024:.2f} KB", end='\r')
return audio_array
elif 'text' in message and message['text'].lower() == 'close':
print("\n🏁 [Adapter] 收到 'close' 信号。") # 在收到结束信号时换行
return None # 返回 None 来作为结束信号
return np.array([]) # 返回空数组以忽略其他类型的消息
async def send(self, message: dict):
"""
将字典消息作为JSON发送给客户端
在发送前将所有UUID对象转换为字符串以确保可序列化
"""
def convert_uuids(obj):
if isinstance(obj, dict):
return {k: convert_uuids(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_uuids(elem) for elem in obj]
elif isinstance(obj, uuid.UUID):
return str(obj)
return obj
serializable_message = convert_uuids(message)
logger.info(f"[Adapter] 发送消息: {serializable_message}")
await self._ws.send_json(serializable_message)
async def close(self):
"""
关闭WebSocket连接
"""
await self._ws.close()
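
The heart of `recv` is the PCM conversion: 16-bit samples are reinterpreted and scaled into the `[-1.0, 1.0)` float range the models expect. In isolation (input values fabricated):

```python
import numpy as np

bytes_data = np.array([0, 16384, -32768], dtype=np.int16).tobytes()
audio = np.frombuffer(bytes_data, dtype=np.int16).astype(np.float32) / 32768.0
print(audio)  # [ 0.   0.5 -1. ]
```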

3
src/websockets/endpoint/__init__.py Normal file

@ -0,0 +1,3 @@
from .asr_endpoint import router as asr_router
__all__ = ["asr_router"]

89
src/websockets/endpoint/asr_endpoint.py Normal file

@ -0,0 +1,89 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Request
from src.websockets.adapter import FastAPIWebSocketAdapter
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
router = APIRouter()
@router.websocket("/asr/{session_id}")
async def asr_websocket_endpoint(
websocket: WebSocket,
session_id: str,
mode: str = Query(default="sender", enum=["sender", "receiver"])
):
"""
ASR WebSocket 端点
- **session_id**: 标识一个识别会话的唯一ID.
- **mode**: 客户端模式.
- `sender`: 作为音频发送方加入将创建一个新的识别会话.
- `receiver`: 作为结果接收方加入订阅一个已存在的会话.
"""
await websocket.accept()
# 从websocket.app.state获取全局的ASRRunner实例
asr_runner = websocket.app.state.asr_runner
# 创建WebSocket适配器
# 注意这里的audio_config应该与ASRRunner中的默认配置一致
audio_config = asr_runner._default_audio_config
adapter = FastAPIWebSocketAdapter(
websocket,
sample_rate=audio_config.sample_rate,
sample_width=audio_config.sample_width
)
if mode == "sender":
logger.info(f"客户端 {websocket.client} 作为 'sender' 加入会话: {session_id}")
# 创建一个新的SAR会话
sar_id = asr_runner.new_SAR(ws=adapter, name=session_id)
if sar_id is None:
logger.error(f"为会话 {session_id} 创建SAR失败")
await websocket.close(code=1011, reason="Failed to create ASR session")
return
sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
try:
# 端点函数等待后台任务完成。
# 真正的接收逻辑在SAR的_run方法中该方法由new_SAR作为后台任务启动。
# 当客户端断开连接时adapter.recv()会抛出异常,
# _run任务会捕获它然后停止并清理最后任务结束。
if sar and sar._task:
await sar._task
else:
# 如果任务没有被创建,记录一个错误并关闭连接
logger.error(f"SAR任务未能在会话 {session_id} 中启动")
await websocket.close(code=1011, reason="Failed to start ASR task")
except Exception as e:
# 捕获任何意外的错误
logger.error(f"会话 {session_id}'sender' 端点发生未知错误: {e}")
finally:
logger.info(f"'sender' {websocket.client} 在会话 {session_id} 的连接处理已结束")
elif mode == "receiver":
logger.info(f"客户端 {websocket.client} 作为 'receiver' 加入会话: {session_id}")
# 加入一个已存在的SAR会话
joined = asr_runner.join_SAR(ws=adapter, name=session_id)
if not joined:
logger.warning(f"无法找到会话 {session_id}'receiver' {websocket.client} 加入失败")
await websocket.close(code=1011, reason=f"Session '{session_id}' not found")
return
try:
# Receiver只需要保持连接等待从SAR广播过来的消息
# 这个循环也用于检测断开
while True:
await websocket.receive_text()
except WebSocketDisconnect:
logger.info(f"'receiver' {websocket.client} 在会话 {session_id} 中断开连接")
# Receiver断开时需要将其从SAR的接收者列表中移除
sar = next((s for s in asr_runner._SAR_list if s._name == session_id), None)
if sar:
sar.delete_receiver(adapter)
logger.info(f"已从会话 {session_id} 中移除 'receiver' {websocket.client}")
else:
# 理论上由于FastAPI的enum校验这里的代码不会被执行
logger.error(f"无效的模式: {mode}")
await websocket.close(code=1003, reason="Invalid mode specified")
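
A minimal receiver-side client for this endpoint, assuming the uvicorn server from main.py on localhost:8000 (host, port, and session id are assumptions; a complete sender-plus-receiver test appears later in this diff):

```python
import asyncio
import websockets

async def subscribe(session_id: str):
    uri = f"ws://localhost:8000/ws/asr/{session_id}?mode=receiver"
    async with websockets.connect(uri) as ws:
        while True:
            print(await ws.recv())  # broadcast ASR results as JSON text

# asyncio.run(subscribe("my-session"))
```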

9
src/websockets/router.py Normal file

@ -0,0 +1,9 @@
from fastapi import APIRouter
from .endpoint import asr_endpoint
websocket_router = APIRouter()
# 包含ASR端点路由
websocket_router.include_router(asr_endpoint.router)
__all__ = ["websocket_router"]

26
test_main.py Normal file

@ -0,0 +1,26 @@
"""
测试主函数
请在tests目录下创建测试文件, 并在此文件中调用
"""
from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger
from tests.runner.asr_runner_test import test_asr_runner
import asyncio
setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)
# 清空logs/test_main.log文件
with open("logs/test_main.log", "w") as f:
f.truncate()
# from tests.functor.vad_test import test_vad_functor
# logger.info("开始测试VAD函数器")
# test_vad_functor()
# logger.info("开始测试ASR管道")
# test_asr_pipeline()
logger.info("开始测试ASRRunner")
asyncio.run(test_asr_runner())

BIN
tests/XT_ZZY.wav Normal file

Binary file not shown.

BIN
tests/XT_ZZY_denoise.wav Normal file

Binary file not shown.

124
tests/functor/vad_test.py Normal file

@ -0,0 +1,124 @@
"""
Functor测试
VAD测试
"""
from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor
from queue import Queue, Empty
from src.core.model_loader import ModelLoader
from src.models import AudioBinary_Config, AudioBinary_data_list
from src.utils.data_format import wav_to_bytes
import time
from src.utils.logger import get_module_logger
from pydub import AudioSegment
import soundfile
# 观察参数
OVERWATCH = False
logger = get_module_logger(__name__)
model_loader = ModelLoader()
def test_vad_functor():
# 加载模型
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"auto_update": False,
}
model_loader.load_models(args)
# 加载数据
f_data, sample_rate = soundfile.read("tests/vad_example.wav")
audio_config = AudioBinary_Config(
chunk_size=200,
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=2,  # 16-bit
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
# 创建输入队列
input_queue = Queue()
vad2asr_queue = Queue()
vad2spk_queue = Queue()
# 创建音频数据列表
audio_binary_data_list = AudioBinary_data_list()
# 创建VAD函数器
vad_functor = VADFunctor()
# 设置输入队列
vad_functor.set_input_queue(input_queue)
# 设置音频配置
vad_functor.set_audio_config(audio_config)
# 设置音频数据列表
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
# 设置回调函数
vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
# 设置模型
vad_functor.set_model({"vad": model_loader.models["vad"]})
# 启动VAD函数器
vad_functor.run()
# 创建ASR函数器
asr_functor = ASRFunctor()
# 设置输入队列
asr_functor.set_input_queue(vad2asr_queue)
# 设置音频配置
asr_functor.set_audio_config(audio_config)
# 设置回调函数
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
# 设置模型
asr_functor.set_model({"asr": model_loader.models["asr"]})
# 启动ASR函数器
asr_functor.run()
# 创建SPK函数器
spk_functor = SPKFunctor()
# 设置输入队列
spk_functor.set_input_queue(vad2spk_queue)
# 设置音频配置
spk_functor.set_audio_config(audio_config)
# 设置回调函数
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
# 设置模型
spk_functor.set_model(
{
# 'spk': model_loader.models['spk']
"spk": "fake_spk"
}
)
# 启动SPK函数器
spk_functor.run()
f_binary = f_data
audio_clip_len = 200
print(
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
)
for i in range(0, len(f_binary), audio_clip_len):
binary_data = f_binary[i : i + audio_clip_len]
input_queue.put(binary_data)
# 等待VAD函数器结束
vad_functor.stop()
print("[vad_test] VAD函数器结束")
asr_functor.stop()
print("[vad_test] ASR函数器结束")
# 保存音频数据
if OVERWATCH:
for index in range(len(audio_binary_data_list)):
save_path = f"tests/vad_test_output_{index}.wav"
soundfile.write(
save_path, audio_binary_data_list[index].binary_data, sample_rate
)

121
tests/modelsuse.py Normal file

@ -0,0 +1,121 @@
"""
模型使用测试
此处主要用于各类调用模型的处理数据与输出格式
请在主目录下test_main.py中调用
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配
"""
from funasr import AutoModel
from typing import List, Dict, Any
from src.models import VADResponse
import time
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
"""
chunk_size = 100 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
vad_result = VADResponse()
vad_result.time_chunk_index_callback = lambda index: print(f"回调: {index}")
items = []
import soundfile
speech, sample_rate = soundfile.read(file_path)
chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {}
total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
time.sleep(0.1)
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
)
if len(res[0]["value"]):
vad_result += VADResponse.from_raw(res)
for item in res[0]["value"]:
items.append(item)
vad_result.process_time_chunk()
# for item in items:
# print(item)
return vad_result
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
测试LogicTrager
在Rebuild版本后LogicTrager中已弃用
"""
from src.logic_trager import LogicTrager
import soundfile
from src.config import parse_args
args = parse_args()
# from src.functor.model_loader import load_models
# models = load_models(args)
from src.core.model_loader import ModelLoader
models = ModelLoader(args)
chunk_size = 200 # ms
from src.models import AudioBinary_Config
import soundfile
speech, sample_rate = soundfile.read(file_path)
chunk_stride = int(chunk_size * sample_rate / 1000)
audio_config = AudioBinary_Config(
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
)
logic_trager = LogicTrager(models=models, audio_config=audio_config)
for i in range(len(speech) // chunk_stride + 1):
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
logic_trager.push_binary_data(speech_chunk)
# for item in items:
# print(item)
return None
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
"""
ASR模型使用
离线ASR模型使用
"""
from funasr import AutoModel
model = AutoModel(
model="paraformer-zh",
model_revision="v2.0.4",
vad_model="fsmn-vad",
vad_model_revision="v2.0.4",
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
spk_model="cam++",
spk_model_revision="v2.0.2",
spk_mode="vad_segment",
auto_update=False,
)
import soundfile
from src.models import AudioBinary_Config
import soundfile
speech, sample_rate = soundfile.read(file_path)
result = model.generate(speech)
return result
# if __name__ == "__main__":
# 请在主目录下调用test_main.py文件进行测试
# vad_result = vad_model_use_online("tests/vad_example.wav")
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# print(vad_result)

94
tests/pipeline/asr_test.py Normal file

@ -0,0 +1,94 @@
"""
Pipeline测试
VAD+ASR+SPK(FAKE)
"""
from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.core.model_loader import ModelLoader
from queue import Queue
import soundfile
import time
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
OVERWATCH = False
model_loader = ModelLoader()
def test_asr_pipeline():
# 加载模型
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
"audio_update": False,
}
models = model_loader.load_models(args)
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
audio_config = AudioBinary_Config(
chunk_size=200,
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=2,  # 16-bit
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
# 创建参数Dict
config = {
"audio_config": audio_config,
}
# 创建音频数据列表
audio_binary_data_list = AudioBinary_data_list()
input_queue = Queue()
# 创建Pipeline
# asr_pipeline = ASRPipeline()
# asr_pipeline.set_models(models)
# asr_pipeline.set_config(config)
# asr_pipeline.set_audio_binary(audio_binary_data_list)
# asr_pipeline.set_input_queue(input_queue)
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
# asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name="ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)
asr_pipeline.bake()
# 运行Pipeline
asr_instance = asr_pipeline.run()
audio_clip_len = 200
print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i : i + audio_clip_len])
# time.sleep(10)
# input_queue.put(None)
# 等待Pipeline结束
# asr_instance.join()
time.sleep(5)
asr_pipeline.stop()
# asr_pipeline.stop()

132
tests/runner/asr_runner_test.py Normal file

@ -0,0 +1,132 @@
"""
ASRRunner test
"""
import asyncio
import soundfile
import numpy as np
from src.runner.ASRRunner import ASRRunner
from src.core.model_loader import ModelLoader
from src.models import AudioBinary_Config
from asyncio import Queue as AsyncQueue
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class AsyncMockWebSocketClient:
"""一个用于测试目的的异步WebSocket客户端模拟器。"""
def __init__(self):
self._recv_q = AsyncQueue()
self._send_q = AsyncQueue()
def put_for_recv(self, item):
"""允许测试将数据送入模拟的WebSocket中。"""
self._recv_q.put_nowait(item)
async def get_from_send(self):
"""允许测试从模拟的WebSocket中获取结果。"""
return await self._send_q.get()
async def recv(self):
"""ASRRunner将调用此方法来获取数据。"""
return await self._recv_q.get()
async def send(self, item):
"""ASRRunner将通过回调调用此方法来发送结果。"""
logger.info(f"Mock WS 收到结果: {item}")
await self._send_q.put(item)
async def close(self):
"""一个模拟的关闭方法。"""
pass
async def test_asr_runner():
"""
针对ASRRunner的端到端测试已适配异步操作
1. 加载模型.
2. 配置并初始化ASRRunner.
3. 创建一个异步的模拟WebSocket客户端.
4. 在Runner中启动一个新的SenderAndReceiver (SAR)实例.
5. 通过模拟的WebSocket流式传输音频数据.
6. 等待处理任务完成并断言其无错误运行.
"""
# 1. 加载模型
model_loader = ModelLoader()
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
}
models = model_loader.load_models(args)
audio_file_path = "tests/XT_ZZY_denoise.wav"
audio_data, sample_rate = soundfile.read(audio_file_path, dtype='float32')
logger.info(
f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
)
# 详细打印audio_data的类型和结构信息便于调试
logger.info(f"audio_data 类型: {type(audio_data)}")
logger.info(f"audio_data dtype: {getattr(audio_data, 'dtype', '未知')}")
logger.info(f"audio_data shape: {getattr(audio_data, 'shape', '未知')}")
logger.info(f"audio_data ndim: {getattr(audio_data, 'ndim', '未知')}")
logger.info(f"audio_data 示例前10个值: {audio_data[:10] if hasattr(audio_data, '__getitem__') else '不可切片'}")
# 2. 配置音频
audio_config = AudioBinary_Config(
chunk_size=200, # ms
sample_rate=sample_rate,
sample_width=2, # 16-bit
channels=1,
)
audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
# 3. 设置ASRRunner
asr_runner = ASRRunner()
asr_runner.set_default_config(
audio_config=audio_config,
models=models,
)
# 4. 创建模拟WebSocket并启动SAR
mock_ws = AsyncMockWebSocketClient()
sar_id = asr_runner.new_SAR(
ws=mock_ws,
name="test_sar",
)
assert sar_id is not None, "创建新的SAR实例失败"
# 获取SAR实例以等待其任务
sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
assert sar is not None, "无法从Runner中获取SAR实例。"
assert sar._task is not None, "SAR任务未被创建。"
# 5. 在后台任务中模拟流式音频
async def feed_audio():
logger.info("Feeder任务已启动开始流式传输音频数据...")
# 每次发送100ms的音频
audio_clip_len = int(sample_rate * 0.1)
for i in range(0, len(audio_data), audio_clip_len):
chunk = audio_data[i : i + audio_clip_len].astype(np.float32)
if chunk.size == 0:
break
mock_ws.put_for_recv(chunk)
await asyncio.sleep(0.1) # 模拟实时流
# 发送None来表示音频流结束
mock_ws.put_for_recv(None)
logger.info("Feeder任务已完成所有音频数据已发送。")
feeder_task = asyncio.create_task(feed_audio())
# 6. 等待SAR处理完成
# SAR任务在从模拟WebSocket接收到None后会结束
await sar._task
await feeder_task # 确保feeder也已完成
logger.info("ASRRunner测试成功完成。")

25
tests/spkverify_use.py Normal file

@ -0,0 +1,25 @@
from modelscope.pipelines import pipeline
sv_pipeline = pipeline(
task='speaker-verification',
model='iic/speech_campplus_sv_zh-cn_16k-common',
model_revision='v1.0.0'
)
speaker1_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_a_cn_16k.wav'
speaker1_b_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_b_cn_16k.wav'
speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
# 相同说话人语音
result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
print(result)
# 不同说话人语音
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
print(result)
# 可以自定义得分阈值来进行识别,阈值越高,判定为同一人的条件越严格
result = sv_pipeline([speaker1_a_wav, speaker1_a_wav], thr=0.6)
print(result)
# 可以传入output_emb参数输出结果中就会包含提取到的说话人embedding
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
print(result['embs'], result['outputs'])
# result1 = sv_pipeline([speaker1_a_wav], output_emb=True)
# print(result1['embs'], result1['outputs'])
# 可以传入save_dir参数提取到的说话人embedding会存储在save_dir目录中
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')

@ -10,13 +10,13 @@ import os
from unittest.mock import patch

# 将src目录添加到路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.config import parse_args


def test_default_args():
    """测试默认参数值"""
    with patch("sys.argv", ["script.py"]):
        args = parse_args()

        # 检查服务器参数

@ -46,17 +46,24 @@ def test_default_args():
def test_custom_args():
    """测试自定义参数值"""
    test_args = [
        "script.py",
        "--host",
        "localhost",
        "--port",
        "8080",
        "--certfile",
        "cert.pem",
        "--keyfile",
        "key.pem",
        "--asr_model",
        "custom_model",
        "--ngpu",
        "0",
        "--device",
        "cpu",
    ]

    with patch("sys.argv", test_args):
        args = parse_args()

        # 检查自定义参数

BIN
tests/vad_example.wav Normal file

Binary file not shown.

@ -0,0 +1,110 @@
import asyncio
import websockets
import soundfile as sf
import uuid
# --- 配置 ---
HOST = "localhost"
PORT = 11096
SESSION_ID = str(uuid.uuid4())
SENDER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=sender"
RECEIVER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=receiver"
AUDIO_FILE_PATH = "tests/XT_ZZY_denoise.wav" # 确保此测试文件存在且为 16kHz, 16-bit, 单声道
CHUNK_DURATION_MS = 100 # 每次发送100ms的音频数据
CHUNK_SIZE = int(16000 * 2 * CHUNK_DURATION_MS / 1000) # 3200 bytes
async def run_receiver():
"""作为接收者连接,并打印收到的所有消息。"""
print(f"▶️ [Receiver] 尝试连接到: {RECEIVER_URI}")
try:
async with websockets.connect(RECEIVER_URI) as websocket:
print("✅ [Receiver] 连接成功,等待消息...")
try:
while True:
message = await websocket.recv()
print(f"🎧 [Receiver] 收到结果: {message}")
except websockets.exceptions.ConnectionClosed as e:
print(f"✅ [Receiver] 连接已由服务器正常关闭: {e.reason}")
except Exception as e:
print(f"❌ [Receiver] 连接失败: {e}")
async def run_sender():
"""
作为发送者连接同时负责发送音频和接收自己会话的广播结果
"""
await asyncio.sleep(1) # 等待receiver有机会先连接
print(f"▶️ [Sender] 尝试连接到: {SENDER_URI}")
try:
async with websockets.connect(SENDER_URI) as websocket:
print("✅ [Sender] 连接成功。")
# --- 并行任务:接收消息 ---
async def receive_task():
print("▶️ [Sender-Receiver] 开始监听广播消息...")
try:
while True:
message = await websocket.recv()
print(f"🎧 [Sender-Receiver] 收到结果: {message}")
except websockets.exceptions.ConnectionClosed:
print("✅ [Sender-Receiver] 连接已关闭,停止监听。")
receiver_sub_task = asyncio.create_task(receive_task())
# --- 主任务:发送音频 ---
try:
print("▶️ [Sender] 准备发送音频...")
audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='int16')
if sample_rate != 16000:
print(f"❌ [Sender] 错误:音频文件采样率必须是 16kHz。")
receiver_sub_task.cancel()
return
total_samples = len(audio_data)
chunk_samples = CHUNK_SIZE // 2
samples_sent = 0
print(f"音频加载成功,总长度: {total_samples} samples。开始分块发送...")
for i in range(0, total_samples, chunk_samples):
chunk = audio_data[i:i + chunk_samples]
if len(chunk) == 0:
break
await websocket.send(chunk.tobytes())
samples_sent += len(chunk)
print(f"🎧 [Sender] 正在发送: {samples_sent}/{total_samples} samples", end="\r")
await asyncio.sleep(CHUNK_DURATION_MS / 1000)
print()
print("🏁 [Sender] 音频流发送完毕,发送 'close' 信号。")
await websocket.send("close")
except FileNotFoundError:
print(f"❌ [Sender] 错误:找不到音频文件 {AUDIO_FILE_PATH}")
except Exception as e:
print(f"❌ [Sender] 发送过程中发生错误: {e}")
# 等待接收任务自然结束(当连接关闭时)
await receiver_sub_task
except Exception as e:
print(f"❌ [Sender] 连接失败: {e}")
async def main():
"""同时运行 sender 和 receiver 任务。"""
print("--- 开始 WebSocket ASR 服务端到端测试 ---")
print(f"会话 ID: {SESSION_ID}")
# 先启动 sender稍后再启动 receiver验证接收方可中途订阅已存在的会话
sender_task = asyncio.create_task(run_sender())
await asyncio.sleep(7)
receiver_task = asyncio.create_task(run_receiver())
# 等待两个任务完成
await asyncio.gather(receiver_task, sender_task)
print("--- 测试结束 ---")
if __name__ == "__main__":
# 在运行此脚本前,请确保 FastAPI 服务器正在运行。
# python main.py
asyncio.run(main())

3332
uv.lock generated Normal file

File diff suppressed because it is too large