[Code refactor complete] Implemented a FastAPI-based WebSocket speech recognition and speaker identification server with VAD, ASR, and SPKVerify (speakers loaded from a database).
This commit is contained in:
parent 3083738db4
commit db811763d4
README.md (145 changed lines)
@@ -1,110 +1,97 @@
-# FunASR WebSocket Service
+# FunASR FastAPI WebSocket Service

-## Introduction
+A high-performance, real-time speech recognition WebSocket service built on FunASR and FastAPI. Its core feature is a "one sender, many receivers" broadcast mode, suited to scenarios such as live meeting captions, online education, and livestream transcription, where the recognition results of a single audio source are distributed to many clients.
-This project implements a WebSocket speech recognition service based on FunASR, supporting online and offline recognition of real-time audio streams. Built on open-source ModelScope speech models, the service performs high-accuracy Chinese speech recognition and supports voice activity detection (VAD) and automatic punctuation.

+## ✨ Features
+
+- **Real-time speech processing**: integrates FunASR voice activity detection (VAD), speech recognition (ASR), and speaker recognition (SPK) models.
+- **Streaming WebSocket API**: a low-latency, bidirectional real-time communication interface.
+- **"One sender, many receivers" architecture**:
+  - **Sender**: a single client acts as the audio source and continuously streams audio to the server.
+  - **Receiver**: multiple clients can subscribe to the same session and receive the broadcast recognition results in real time.
+- **Async core**: built on FastAPI and `asyncio`, able to handle a large number of concurrent connections.
+- **Modular design**: clean separation of the service layer (`server.py`), session management (`ASRRunner`), and the core processing pipeline (`ASRPipeline`).

+## 📂 Project Structure
-## Project Structure

 ```
 .
-├── src/                 # source code directory
+├── main.py              # application entry point; starts the service with uvicorn
-│   ├── __init__.py      # package init file
+├── WEBSOCKET_API.md     # detailed WebSocket API documentation and examples
-│   ├── server.py        # WebSocket server implementation
+├── src
-│   ├── config.py        # configuration module
+│   ├── server.py        # FastAPI application core; manages lifecycle and global resources
-│   ├── models.py        # model loading module
+│   ├── runner
-│   ├── service.py       # ASR service implementation
+│   │   └── ASRRunner.py # core session manager; creates and coordinates recognition sessions (SAR)
-│   └── client.py        # test client
+│   ├── pipeline
-├── tests/               # test directory
+│   │   └── ASRpipeline.py # synchronous, thread-based audio processing pipeline
-│   ├── __init__.py      # test package init file
+│   ├── functor          # implementations of the atomic VAD, ASR, SPK operations
-│   └── test_config.py   # configuration module tests
+│   ├── websockets
-├── requirements.txt     # Python dependencies
+│   │   ├── adapter.py   # WebSocket adapter; handles data format conversion
-├── Dockerfile           # Docker configuration
+│   │   ├── endpoint
-├── docker-compose.yml   # Docker Compose configuration
+│   │   │   └── asr_endpoint.py # WebSocket business-logic endpoint
-├── .gitignore           # Git ignore file
+│   │   └── router.py    # WebSocket routes
-└── README.md            # project README
+│   └── ...
+└── tests
+    ├── runner
+    │   └── asr_runner_test.py # unit tests for ASRRunner (async)
+    └── websocket
+        └── websocket_asr.py   # end-to-end test of the WebSocket service
 ```

-## Features
+## 🚀 Quick Start

-- **Multiple recognition modes**: offline, online, and two-pass (2pass) recognition
+### 1. Environment and Dependencies
-- **Voice activity detection**: automatically detects the start and end of speech
-- **Punctuation**: automatic punctuation insertion
-- **WebSocket interface**: real-time speech recognition over binary WebSocket
-- **Docker support**: containerized deployment

-## Installation and Usage

-### Requirements

 - Python 3.8+
-- CUDA support (for GPU acceleration)
+- Project dependencies are listed in `requirements.txt`.
-- Memory >= 8GB

-### Install Dependencies
+### 2. Installation

+Installing the dependencies in a virtual environment is recommended. From the project root, run:

 ```bash
 pip install -r requirements.txt
 ```

-### Run the Server
+### 3. Run the Service

+Start the FastAPI service via the main entry point:

 ```bash
-python src/server.py
+python main.py
 ```

+Once started, the service listens on `http://0.0.0.0:8000`.

-Common startup options:
+## 💡 Usage
-- `--host`: listen address, default 0.0.0.0
-- `--port`: server port, default 10095
-- `--device`: device type (cuda or cpu), default cuda
-- `--ngpu`: number of GPUs; 0 means CPU, default 1

-### Test Client
+The service is exposed over WebSocket. A client creates or joins a recognition session via `session_id` and declares its role through the `mode` parameter (`sender` or `receiver`).

+**For the full API description, URL format, and Python and JavaScript client examples, see:**

+➡️ **[WEBSOCKET_API.md](./docs/WEBSOCKET_API.md)**

+## 🔬 Testing

+The project offers two ways to verify its functionality.

+### 1. End-to-End WebSocket Test

+This test simulates one `sender` and one `receiver` through a complete recognition session.

+**Prerequisite**: make sure the FastAPI service is running.

 ```bash
-python src/client.py --audio_file path/to/audio.wav
+python main.py
 ```

-Common client options:
+Then, from the project root, run:
-- `--audio_file`: path of the audio file to recognize
-- `--mode`: recognition mode, one of 2pass/online/offline, default 2pass
-- `--host`: server address, default localhost
-- `--port`: server port, default 10095

-## Docker Deployment

-### Build the Image

 ```bash
-docker build -t funasr-websocket .
+python tests/websocket/websocket_asr.py
 ```

-### Start with Docker Compose
+### 2. ASRRunner Unit Test

+This test targets the core `ASRRunner` component and verifies its async logic.

+To run it:

 ```bash
-docker-compose up -d
+python test_main.py
 ```

-## API Reference

-### WebSocket Message Format

-1. **Client configuration message**:
-```json
-{
-    "mode": "2pass",        // one of "2pass", "online", "offline"
-    "chunk_size": "5,10",   // chunk size, formatted as "encoder_size,decoder_size"
-    "wav_name": "audio1",   // audio identifier
-    "is_speaking": true     // whether the speaker is currently talking
-}
-```

-2. **Client audio data**:
-A binary audio stream, 16 kHz sample rate, 16-bit PCM.

-3. **Server recognition result**:
-```json
-{
-    "mode": "2pass-online",            // recognition mode
-    "text": "recognized text content", // recognition result
-    "wav_name": "audio1",              // audio identifier
-    "is_final": false                  // whether this is the final result
-}
-```
@@ -1,6 +1,6 @@
 [
     {
-        "speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
+        "speaker_id": "137facd6-a1c9-47b9-b87d-16e6f62e07bd",
         "speaker_name": "ZiyangZhang",
         "wav_path": "/home/lyg/Code/funasr/data/speaker_wav/ZiyangZhang.wav",
         "speaker_embs": [
@@ -199,7 +199,7 @@
         ]
     },
     {
-        "speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
+        "speaker_id": "81e1b806-c76a-468f-98b7-4a63f2996480",
         "speaker_name": "HaiaoDuan",
         "wav_path": "/home/lyg/Code/funasr/data/speaker_wav/HaiaoDuan.wav",
         "speaker_embs": [
@@ -26,7 +26,7 @@ ws://<your_server_host>:8000/ws/asr/{session_id}?mode=<client_mode>

 - **Audio format**: the `sender` must send raw **PCM audio data**.
 - **Sample rate**: 16000 Hz
-- **Bit depth**: 16-bit (signed integer)
+- **Bit depth**: 32-bit (floating point)
 - **Channels**: mono
 - **Transport format**: must be sent as **binary (bytes)**.
 - **End signal**: when the audio stream ends, the `sender` should send the **text message** `"close"` to tell the server to close the session.
@@ -64,7 +64,7 @@ import uuid
 SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=sender"
 SESSION_ID = str(uuid.uuid4())  # generate a unique ID for this session
 AUDIO_FILE = "tests/XT_ZZY_denoise.wav"  # replace with your audio file path
-CHUNK_SIZE = 3200  # send 100 ms of audio at a time (16000 * 2 * 0.1)
+CHUNK_SIZE = 3200  # corresponds to 100 ms of float32 data (16000 * 4 * 0.1)

 async def send_audio():
     """Connect to the server and stream the audio file"""
@@ -80,7 +80,8 @@ async def send_audio():

     print("Start sending audio...")
     while True:
-        data = f.read(CHUNK_SIZE, dtype='int16')
+        # read as float32
+        data = f.read(CHUNK_SIZE, dtype='float32')
         if not data.any():
             break
         # convert the numpy array to a raw byte stream
@@ -215,18 +216,14 @@ if __name__ == "__main__":
 audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });

 const source = audioContext.createMediaStreamSource(mediaStream);
-const bufferSize = CHUNK_DURATION_MS * SAMPLE_RATE / 1000 * 2; // compute the buffer size
+const bufferSize = CHUNK_DURATION_MS * SAMPLE_RATE / 1000 * 4; // compute the buffer size
 scriptProcessor = audioContext.createScriptProcessor(bufferSize, 1, 1);

 scriptProcessor.onaudioprocess = (e) => {
     if (websocket && websocket.readyState === WebSocket.OPEN) {
         const inputData = e.inputBuffer.getChannelData(0);
-        // the server expects 16-bit PCM, so convert
+        // the server expects float32 data; inputData is already a Float32Array, so send its buffer directly
-        const pcmData = new Int16Array(inputData.length);
+        websocket.send(inputData.buffer);
-        for (let i = 0; i < inputData.length; i++) {
-            pcmData[i] = Math.max(-1, Math.min(1, inputData[i])) * 32767;
-        }
-        websocket.send(pcmData.buffer);
     }
 };
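Under the updated spec a client holding 16-bit PCM must rescale it to float32 in [-1, 1] before sending. A minimal standard-library sketch of that conversion (numpy's `astype` and a division would do the same in two lines):

```python
import struct

def int16_to_float32_bytes(pcm: bytes) -> bytes:
    """Rescale little-endian 16-bit PCM to float32 samples in [-1, 1]."""
    n = len(pcm) // 2
    ints = struct.unpack(f"<{n}h", pcm)
    floats = [s / 32768.0 for s in ints]
    return struct.pack(f"<{n}f", *floats)

out = int16_to_float32_bytes(struct.pack("<3h", -32768, 0, 16384))
print(struct.unpack("<3f", out))  # → (-1.0, 0.0, 0.5)
```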
main.py (4 changed lines)
@@ -4,7 +4,7 @@ from src.utils.logger import get_module_logger, setup_root_logger
 from datetime import datetime

 time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-setup_root_logger(level="INFO", log_file=f"logs/main_{time}.log")
+setup_root_logger(level="DEBUG", log_file=f"logs/main_{time}.log")
 logger = get_module_logger(__name__)

 if __name__ == "__main__":
@@ -14,6 +14,6 @@ if __name__ == "__main__":
     uvicorn.run(
         app,
         host="0.0.0.0",
-        port=8000,
+        port=11096,
         log_level="info"
     )
paerser.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+def pre_remove_details(input_string: str) -> str:
+    start_tag = '</details>'
+    start = input_string.find(start_tag, 0)
+    if start == -1:
+        return input_string, "unfind %s" % start_tag
+    return input_string[start + len(start_tag):], "success remove %s" % start_tag
+
+
+def pre_remove_markdown(input_string: str) -> str:
+    start_tag = '```markdown'
+    end_tag = '```'
+    start = input_string.find(start_tag, 0)
+    if start == -1:
+        return input_string, "unfind %s" % start_tag
+    end = input_string.find(end_tag, start + 11)
+    if end == -1:
+        return input_string, "unfind %s" % end_tag
+    return input_string[start + 11:end].strip(), "success remove %s" % start_tag
+
+
+def main(input_string: str) -> dict:
+    result = input_string
+    statuses = []
+
+    result, detail_status = pre_remove_details(result)
+    statuses.append(detail_status)
+
+    result, markdown_status = pre_remove_markdown(result)
+    statuses.append(markdown_status)
+
+    return {
+        "result": result,
+        "status": statuses
+    }
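The new `paerser.py` strips a leading `</details>` tag and an enclosing markdown code fence from a string. The same stripping logic as a self-contained sketch (the function name here is illustrative; note the committed functions actually return `(text, status)` tuples despite their `-> str` annotations):

```python
def strip_markdown_fence(text: str) -> str:
    # drop everything up to and including a leading '</details>' tag, if present
    pos = text.find('</details>')
    if pos != -1:
        text = text[pos + len('</details>'):]
    # unwrap an enclosing markdown code fence, if present
    start = text.find('```markdown')
    if start != -1:
        end = text.find('```', start + len('```markdown'))
        if end != -1:
            text = text[start + len('```markdown'):end]
    return text.strip()

raw = "<details>log</details>```markdown\n# Title\n```"
print(strip_markdown_fence(raw))  # → # Title
```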
@@ -1,11 +1,11 @@
-pytest==7.3.1
+fastapi==0.115.14
-pytest-cov==4.1.0
+funasr==1.2.6
-flake8==6.0.0
+modelscope==1.27.1
-black==23.3.0
+numpy==2.0.1
-isort==5.12.0
+pyaudio==0.2.14
-flask==2.3.2
+pydantic==2.11.7
-requests==2.31.0
+pydub==0.25.1
-websockets==11.0.3
+pytest==8.3.5
-numpy==1.24.3
+soundfile==0.13.1
-funasr==0.10.0
+torch==2.3.1
-modelscope==1.9.5
+uvicorn==0.35.0
requirements.txt.backup (new file, 11 lines)
@@ -0,0 +1,11 @@
+pytest==7.3.1
+pytest-cov==4.1.0
+flake8==6.0.0
+black==23.3.0
+isort==5.12.0
+flask==2.3.2
+requests==2.31.0
+websockets==11.0.3
+numpy==1.24.3
+funasr==0.10.0
+modelscope==1.9.5
@@ -8,7 +8,7 @@ try:
     from funasr import AutoModel
 except ImportError as exc:
     raise ImportError("funasr not found, install it first: pip install funasr") from exc
+from modelscope.pipelines import pipeline
 # logging module
 from src.utils import get_module_logger

@@ -101,6 +101,16 @@ class ModelLoader:
             logger.error("Failed to load %s model: %s", model_type, str(e))
             raise

+    # def _load_pipeline(self, input_model_args: dict, model_type: str):
+    #     """
+    #     Load a pipeline
+    #     """
+    #     default_pipeline = pipeline(
+    #         task='speaker-verification',
+    #         model='iic/speech_campplus_sv_zh-cn_16k-common',
+    #         model_revision='v1.0.0'
+    #     )
+
     def load_models(self, args):
         """
         Load all required models
@@ -115,12 +125,19 @@ class ModelLoader:
         self.models = {}
         # load the offline ASR model
         # check whether the corresponding key exists
-        model_list = ["asr", "asr_online", "vad", "punc", "spk"]
+        model_list = ["asr", "asr_online", "vad", "punc"]
         for model_name in model_list:
             name_model = f"{model_name}_model"
             name_model_revision = f"{model_name}_model_revision"
             if name_model in args:
                 logger.debug("Loading %s model", model_name)
                 self.models[model_name] = self._load_model(args, model_name)
-        logger.info("All models loaded")
+        pipeline_list = ["spk"]
+        for pipeline_name in pipeline_list:
+            if pipeline_name == "spk":
+                self.models[pipeline_name] = pipeline(
+                    task='speaker-verification',
+                    model='iic/speech_campplus_sv_zh-cn_16k-common',
+                    model_revision='v1.0.0'
+                )
         return self.models
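The loader now builds `AutoModel`s only for names whose `<name>_model` key appears in the argument dict, and constructs the speaker-verification pipeline separately. That key-probing step, isolated as a standalone sketch (the argument values are illustrative):

```python
def models_to_load(args: dict) -> list:
    # mirrors the loader's probing: a model is loaded only when the
    # "<name>_model" key is present in the argument dict
    model_list = ["asr", "asr_online", "vad", "punc"]
    return [name for name in model_list if f"{name}_model" in args]

args = {"asr_model": "paraformer", "vad_model": "fsmn-vad"}
print(models_to_load(args))  # → ['asr', 'vad']
```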
@@ -17,6 +17,9 @@ from typing import Callable, List, Dict
 from queue import Queue
 import threading

+from src.utils.logger import get_module_logger
+
+logger = get_module_logger(__name__)

 class BaseFunctor(ABC):
     """
@@ -157,13 +160,15 @@ class FunctorFactory:
         """
         from src.functor.spk_functor import SPKFunctor

+        logger.debug(f"creating spk functor [start]")
         audio_config = config["audio_config"]
-        model = {"spk": models["spk"]}
+        # model = {"spk": models["spk"]}

-        spk_functor = SPKFunctor()
+        spk_functor = SPKFunctor(sv_pipeline=models["spk"])
         spk_functor.set_audio_config(audio_config)
-        spk_functor.set_model(model)
+        # spk_functor.set_model(model)

+        logger.debug(f"creating spk functor [done]")
         return spk_functor

     def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
@@ -5,7 +5,7 @@ ResultBinderFunctor

 from src.functor.base import BaseFunctor
 from src.models import AudioBinary_Config, VAD_Functor_result
-from typing import Callable, List
+from typing import Callable, List, Dict, Any
 from queue import Queue, Empty
 import threading
 import time
@@ -74,16 +74,30 @@ class ResultBinderFunctor(BaseFunctor):
         for callback in self._callback:
             callback(result)

-    def _process(self, data: VAD_Functor_result) -> None:
+    def _process(self, data: Dict[str, Any]) -> None:
         """
         Process data
+        {
+            "is_final": false,
+            "mode": "2pass-offline",
+            "text": ",等一下我回一下,ok你看这里就有了",
+            "wav_name": "h5",
+            "speaker_id":
+        }
         """
         logger.debug("ResultBinderFunctor processing data: %s", data)
         # aggregate the results carried in data
         # this step is currently a no-op, reserved for later
-        results = {}
-        for name, result in data.items():
-            results[name] = result
+        results = {
+            "is_final": False,
+            "mode": "2pass-offline",
+            "text": data["asr"],
+            "wav_name": "h5",
+            "speaker_id": data["spk"]["speaker_id"]
+        }

+        # for name, result in data.items():
+        #     results[name] = result
         self._do_callback(results)

     def _run(self) -> None:
File diff suppressed because one or more lines are too long
@@ -7,6 +7,8 @@ import threading
 from queue import Empty, Queue
 from typing import List, Any, Callable
 import numpy
+import time
+from datetime import datetime
 from src.models import (
     VAD_Functor_result,
     AudioBinary_Config,
@@ -126,7 +128,10 @@ class VADFunctor(BaseFunctor):
             self._cache_result_list[-1][1] = end
         else:
             self._cache_result_list.append(pair)
-        while len(self._cache_result_list) > 1:
+        logger.debug(f"VADFunctor result: {self._cache_result_list}")
+        while len(self._cache_result_list) > 0 and self._cache_result_list[0][1] != -1:
+            logger.debug(f"VADFunctor result: {self._cache_result_list}")
+            logger.debug(f"VADFunctor list[0][1]: {self._cache_result_list[0][1]}")
             # create a VAD segment
             # compute the start frame
             start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
@@ -135,12 +140,16 @@ class VADFunctor(BaseFunctor):
             end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
             end_frame -= self._audio_cache_preindex
             # compute the start time
+            timestamp = time.time()
+            format_time = datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
+            logger.debug(f"{format_time} creating VAD segment: {start_frame} - {end_frame}")
             vad_result = VAD_Functor_result.create_from_push_data(
                 audiobinary_data_list=self._audio_binary_data_list,
                 data=self._audiobinary_cache[start_frame:end_frame],
                 start_time=self._cache_result_list[0][0],
                 end_time=self._cache_result_list[0][1],
             )
+            logger.debug(f"{format_time} VAD segment created: {vad_result}")
             self._audio_cache_preindex += end_frame
             self._audiobinary_cache = self._audiobinary_cache[end_frame:]
             for callback in self._callback:
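The reworked flush condition above only emits segments that have been closed. A hypothetical mirror of the caching logic, assuming the streaming-VAD convention of `[start, -1]` for a segment that has opened and `[-1, end]` for one that closes it:

```python
def merge_vad_pairs(cache: list, pair: list) -> list:
    # a pair starting with -1 closes the still-open segment at the tail
    if pair[0] == -1 and cache and cache[-1][1] == -1:
        cache[-1][1] = pair[1]
    else:
        cache.append(list(pair))
    # flush only fully closed segments (end != -1), matching the new
    # `while ... [0][1] != -1` condition in the diff above
    ready = []
    while cache and cache[0][1] != -1:
        ready.append(cache.pop(0))
    return ready

cache = []
print(merge_vad_pairs(cache, [0, -1]))    # → []
print(merge_vad_pairs(cache, [-1, 500]))  # → [[0, 500]]
```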
@@ -108,16 +108,19 @@ class ASRPipeline(PipelineBase):
             from src.functor import FunctorFactory

             # load the VAD, ASR, and SPK functors
+            logger.debug(f"creating vad functor via FunctorFactory")
             self._functor_dict["vad"] = FunctorFactory.make_functor(
                 functor_name="vad", config=self._config, models=self._models
             )
+            logger.debug(f"creating asr functor via FunctorFactory")
             self._functor_dict["asr"] = FunctorFactory.make_functor(
                 functor_name="asr", config=self._config, models=self._models
             )
+            logger.debug(f"creating spk functor via FunctorFactory")
             self._functor_dict["spk"] = FunctorFactory.make_functor(
                 functor_name="spk", config=self._config, models=self._models
             )
+            logger.debug(f"creating resultbinder functor via FunctorFactory")
             self._functor_dict["resultbinder"] = FunctorFactory.make_functor(
                 functor_name="resultbinder", config=self._config, models=self._models
             )
@@ -160,8 +163,10 @@ class ASRPipeline(PipelineBase):
             # set resultbinder's callback to the pipeline's own callback, used to talk to the outside world
             self._functor_dict["resultbinder"].add_callback(self._callback)

-        except ImportError:
-            raise ImportError("failed to import functorFactory; ASRPipeline cannot finish initializing")
+        except ImportError as e:
+            raise ImportError(f"failed to import functorFactory; ASRPipeline cannot finish initializing: {str(e)}")
+        except Exception as e:
+            raise Exception(f"ASRPipeline initialization failed: {str(e)}")

     def _check_result(self, result: Any) -> None:
         """
@@ -25,8 +25,11 @@ class FastAPIWebSocketAdapter:
         message = await self._ws.receive()
         if 'bytes' in message:
             bytes_data = message['bytes']
-            # convert the byte stream to a numpy float array
-            audio_array = np.frombuffer(bytes_data, dtype=np.float32)
+            audio_array = np.frombuffer(bytes_data, dtype=np.int16)
+            # convert int16 to float32
+            audio_array = audio_array.astype(np.float32)
+            # normalize
+            audio_array = audio_array / 32768.0

             # use carriage return \r to overwrite the progress printout
             self._total_received += len(bytes_data)
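The old receive path decoded the incoming bytes directly as float32; the new one decodes int16 and normalizes to [-1, 1]. Both decode paths, sketched with the standard library standing in for numpy:

```python
import struct

def decode_float32(data: bytes) -> list:
    # old path: bytes are already little-endian float32 samples
    n = len(data) // 4
    return list(struct.unpack(f"<{n}f", data))

def decode_int16(data: bytes) -> list:
    # new path: decode int16 samples, then normalize into [-1, 1]
    n = len(data) // 2
    return [s / 32768.0 for s in struct.unpack(f"<{n}h", data)]

payload = struct.pack("<2f", 0.25, -0.5)
print(decode_float32(payload))  # → [0.25, -0.5]
```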
@@ -65,7 +65,7 @@ async def test_asr_runner():
     }
     models = model_loader.load_models(args)
     audio_file_path = "tests/XT_ZZY_denoise.wav"
-    audio_data, sample_rate = soundfile.read(audio_file_path)
+    audio_data, sample_rate = soundfile.read(audio_file_path, dtype='float32')
     logger.info(
         f"Loaded data: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
     )
@@ -112,7 +112,7 @@ async def test_asr_runner():
     # send 100 ms of audio at a time
     audio_clip_len = int(sample_rate * 0.1)
     for i in range(0, len(audio_data), audio_clip_len):
-        chunk = audio_data[i : i + audio_clip_len]
+        chunk = audio_data[i : i + audio_clip_len].astype(np.float32)
         if chunk.size == 0:
             break
         mock_ws.put_for_recv(chunk)
@@ -5,7 +5,7 @@ import uuid

 # --- configuration ---
 HOST = "localhost"
-PORT = 8000
+PORT = 11096
 SESSION_ID = str(uuid.uuid4())
 SENDER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=sender"
 RECEIVER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=receiver"
@@ -54,7 +54,7 @@ async def run_sender():
     # --- main task: send the audio ---
     try:
         print("▶️ [Sender] preparing to send audio...")
-        audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='float32')
+        audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='int16')
         if sample_rate != 16000:
             print(f"❌ [Sender] error: the audio file sample rate must be 16 kHz.")
             receiver_sub_task.cancel()