[Refactoring complete] Finished the FastAPI-based WebSocket speech recognition and speaker identification service with VAD, ASR, and SPKVerify (speakers loaded from the database).
This commit is contained in:
parent 3083738db4
commit db811763d4

README.md (145 lines changed)
@@ -1,110 +1,97 @@
-# FunASR WebSocket Service
+# FunASR FastAPI WebSocket Service

 ## Introduction

-This project implements a WebSocket speech recognition service on top of FunASR, supporting both online and offline recognition of real-time audio streams. Built on open-source ModelScope speech models, it performs high-accuracy Chinese speech recognition with voice activity detection (VAD) and automatic punctuation.
+A high-performance, real-time speech recognition WebSocket service built on FunASR and FastAPI. Its core feature is a "one sender, many receivers" broadcast mode, suited to scenarios such as live meeting captions, online education, and livestream transcription, where recognition results from a single audio source must be distributed to multiple clients.
+
+## ✨ Features
+
+- **Real-time audio processing**: integrates FunASR voice activity detection (VAD), speech recognition (ASR), and speaker identification (SPK) models.
+- **Streaming WebSocket API**: a low-latency, bidirectional real-time interface.
+- **"One sender, many receivers" architecture**:
+  - **Sender**: a single client acts as the audio source and continuously streams audio to the server.
+  - **Receiver**: multiple clients can subscribe to the same session and receive the broadcast recognition results in real time.
+- **Asynchronous core**: built on FastAPI and `asyncio`; handles large numbers of concurrent connections.
+- **Modular design**: clean separation between the service layer (`server.py`), session management (`ASRRunner`), and the core processing pipeline (`ASRPipeline`).

-## Project Structure
+## 📂 Project Structure

 ```
 .
-├── src/                   # Source code
-│   ├── __init__.py        # Package init
-│   ├── server.py          # WebSocket server implementation
-│   ├── config.py          # Configuration handling
-│   ├── models.py          # Model loading
-│   ├── service.py         # ASR service implementation
-│   └── client.py          # Test client
-├── tests/                 # Tests
-│   ├── __init__.py        # Test package init
-│   └── test_config.py     # Configuration tests
-├── requirements.txt       # Python dependencies
-├── Dockerfile             # Docker configuration
-├── docker-compose.yml     # Docker Compose configuration
-├── .gitignore             # Git ignore rules
-└── README.md              # Project documentation
+├── main.py                # Application entry point; starts the service with uvicorn
+├── WEBSOCKET_API.md       # Detailed WebSocket API documentation and examples
+├── src
+│   ├── server.py          # FastAPI application core; manages lifecycle and global resources
+│   ├── runner
+│   │   └── ASRRunner.py   # Core session manager; creates and coordinates recognition sessions
+│   ├── pipeline
+│   │   └── ASRpipeline.py # Synchronous, thread-based audio processing pipeline
+│   ├── functor            # Atomic operations: VAD, ASR, SPK, etc.
+│   ├── websockets
+│   │   ├── adapter.py     # WebSocket adapter; handles data format conversion
+│   │   ├── endpoint
+│   │   │   └── asr_endpoint.py # WebSocket business-logic endpoint
+│   │   └── router.py      # WebSocket routes
+│   └── ...
+└── tests
+    ├── runner
+    │   └── asr_runner_test.py  # Unit tests for ASRRunner (async)
+    └── websocket
+        └── websocket_asr.py    # End-to-end test for the WebSocket service
 ```

-## Features
+## 🚀 Quick Start

-- **Multiple recognition modes**: offline, online, and two-pass (2pass)
-- **Voice activity detection**: automatic detection of speech start and end
-- **Punctuation**: automatic punctuation restoration
-- **WebSocket interface**: real-time speech recognition over binary WebSocket
-- **Docker support**: containerized deployment
-
-## Installation and Usage
+### 1. Environment and Dependencies

-### Requirements
-- Python 3.8+
-- CUDA support (for GPU acceleration)
-- RAM >= 8GB
+- Project dependencies are listed in `requirements.txt`.

-### Install Dependencies
+### 2. Installation

+Installing inside a virtual environment is recommended. From the project root, run:
 ```bash
 pip install -r requirements.txt
 ```

-### Run the Server
+### 3. Run the Service

+Start the FastAPI service from the main entry point:
 ```bash
-python src/server.py
+python main.py
 ```
+Once started, the service listens on `http://0.0.0.0:8000`.

-Common server options:
-- `--host`: listen address, default 0.0.0.0
-- `--port`: server port, default 10095
-- `--device`: device type (cuda or cpu), default cuda
-- `--ngpu`: number of GPUs; 0 means CPU, default 1
+## 💡 Usage

-### Test Client
+The service is exposed over WebSocket. A client creates or joins a recognition session via a `session_id` and declares its role with the `mode` parameter (`sender` or `receiver`).
+
+**For the full API description, URL format, and Python/JavaScript client examples, see:**
+
+➡️ **[WEBSOCKET_API.md](./docs/WEBSOCKET_API.md)**
+
+## 🔬 Testing
+
+The project provides two kinds of tests.
+
+### 1. End-to-End WebSocket Test
+
+This test simulates one `sender` and one `receiver` and exercises a complete recognition session.
+
+**Prerequisite**: make sure the FastAPI service is running.
 ```bash
-python src/client.py --audio_file path/to/audio.wav
+python main.py
 ```

-Common client options:
-- `--audio_file`: path of the audio file to recognize
-- `--mode`: recognition mode, one of 2pass/online/offline, default 2pass
-- `--host`: server address, default localhost
-- `--port`: server port, default 10095
-
-## Docker Deployment
-
-### Build the Image
-
-From the project root:
+Run the test:
 ```bash
-docker build -t funasr-websocket .
+python tests/websocket/websocket_asr.py
 ```

-### Start with Docker Compose
+### 2. ASRRunner Unit Test
+
+This test targets the core `ASRRunner` component and verifies its asynchronous logic.
+
+Run the test:
 ```bash
-docker-compose up -d
+python test_main.py
 ```
-
-## API Reference
-
-### WebSocket Message Format
-
-1. **Client configuration message**:
-```json
-{
-  "mode": "2pass",       // one of "2pass", "online", "offline"
-  "chunk_size": "5,10",  // chunk size, "encoder_size,decoder_size"
-  "wav_name": "audio1",  // audio identifier
-  "is_speaking": true    // whether the speaker is talking
-}
-```
-
-2. **Client audio data**:
-A binary audio stream: 16 kHz sample rate, 16-bit PCM
-
-3. **Server recognition result**:
-```json
-{
-  "mode": "2pass-online",    // recognition mode
-  "text": "recognized text", // recognition result
-  "wav_name": "audio1",      // audio identifier
-  "is_final": false          // whether this is the final result
-}
-```
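The session scheme described above (one `session_id`, roles declared via `mode`) can be sketched in plain Python. The URL shape follows the `ws://host:port/ws/asr/{session_id}?mode=...` format documented in WEBSOCKET_API.md; host and port here are placeholders.

```python
import uuid

def make_session_uris(host: str, port: int) -> dict:
    """Build sender/receiver WebSocket URIs sharing one session_id."""
    session_id = str(uuid.uuid4())
    base = f"ws://{host}:{port}/ws/asr/{session_id}"
    return {
        "session_id": session_id,
        "sender": f"{base}?mode=sender",      # single audio source
        "receiver": f"{base}?mode=receiver",  # any number of subscribers
    }

uris = make_session_uris("localhost", 8000)
```

Any number of receivers can connect with the same `session_id` to get the broadcast results.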
@@ -1,6 +1,6 @@
 [
   {
-    "speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
+    "speaker_id": "137facd6-a1c9-47b9-b87d-16e6f62e07bd",
     "speaker_name": "ZiyangZhang",
     "wav_path": "/home/lyg/Code/funasr/data/speaker_wav/ZiyangZhang.wav",
     "speaker_embs": [
@@ -199,7 +199,7 @@
     ]
   },
   {
-    "speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
+    "speaker_id": "81e1b806-c76a-468f-98b7-4a63f2996480",
     "speaker_name": "HaiaoDuan",
     "wav_path": "/home/lyg/Code/funasr/data/speaker_wav/HaiaoDuan.wav",
     "speaker_embs": [
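The `speaker_embs` vectors stored per speaker above are what speaker identification compares against. A minimal sketch of the usual comparison (cosine similarity against each enrolled embedding; the threshold value and the tiny 3-dimensional embeddings are illustrative, not from the project):

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def best_speaker(emb, enrolled, threshold=0.6):
    """Return the enrolled speaker_id with the highest cosine score, or None."""
    best_id, best_score = None, -1.0
    for spk in enrolled:
        score = cosine(emb, spk["speaker_embs"])
        if score > best_score:
            best_id, best_score = spk["speaker_id"], score
    return best_id if best_score >= threshold else None

enrolled = [
    {"speaker_id": "spk-a", "speaker_embs": [1.0, 0.0, 0.0]},
    {"speaker_id": "spk-b", "speaker_embs": [0.0, 1.0, 0.0]},
]
match = best_speaker([0.9, 0.1, 0.0], enrolled)
```

Real embeddings from the CAM++ model are much higher-dimensional, but the matching logic is the same.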
@@ -26,7 +26,7 @@ ws://<your_server_host>:8000/ws/asr/{session_id}?mode=<client_mode>

 - **Audio format**: the `sender` must send raw **PCM audio data**.
   - **Sample rate**: 16000 Hz
-  - **Bit depth**: 16-bit (signed integer)
+  - **Bit depth**: 32-bit (floating point)
   - **Channels**: mono
 - **Transport**: must be sent as **binary (bytes)** frames.
 - **End-of-stream signal**: when the audio stream ends, the `sender` should send the **text message** `"close"` to tell the server to close the session.
@@ -64,7 +64,7 @@ import uuid
 SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=sender"
 SESSION_ID = str(uuid.uuid4())  # generate a unique ID for this session
 AUDIO_FILE = "tests/XT_ZZY_denoise.wav"  # replace with your audio file path
-CHUNK_SIZE = 3200  # send 100 ms of audio at a time (16000 * 2 * 0.1)
+CHUNK_SIZE = 3200  # 100 ms of float32 data (16000 * 4 * 0.1)
@@ -80,7 +80,8 @@ async def send_audio():

     print("Sending audio...")
     while True:
-        data = f.read(CHUNK_SIZE, dtype='int16')
+        # read as float32
+        data = f.read(CHUNK_SIZE, dtype='float32')
         if not data.any():
             break
         # convert the numpy array to a raw byte stream
@@ -215,18 +216,14 @@ if __name__ == "__main__":
 audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });

 const source = audioContext.createMediaStreamSource(mediaStream);
-const bufferSize = CHUNK_DURATION_MS * SAMPLE_RATE / 1000 * 2; // compute the buffer size
+const bufferSize = CHUNK_DURATION_MS * SAMPLE_RATE / 1000 * 4; // compute the buffer size
 scriptProcessor = audioContext.createScriptProcessor(bufferSize, 1, 1);

 scriptProcessor.onaudioprocess = (e) => {
     if (websocket && websocket.readyState === WebSocket.OPEN) {
         const inputData = e.inputBuffer.getChannelData(0);
-        // the server expects 16-bit PCM, so convert
-        const pcmData = new Int16Array(inputData.length);
-        for (let i = 0; i < inputData.length; i++) {
-            pcmData[i] = Math.max(-1, Math.min(1, inputData[i])) * 32767;
-        }
-        websocket.send(pcmData.buffer);
+        // the server expects float32; inputData is already a Float32Array, so send its buffer directly
+        websocket.send(inputData.buffer);
     }
 };
main.py (4 lines changed)
@@ -4,7 +4,7 @@ from src.utils.logger import get_module_logger, setup_root_logger
 from datetime import datetime

 time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-setup_root_logger(level="INFO", log_file=f"logs/main_{time}.log")
+setup_root_logger(level="DEBUG", log_file=f"logs/main_{time}.log")
 logger = get_module_logger(__name__)

 if __name__ == "__main__":
@@ -14,6 +14,6 @@ if __name__ == "__main__":
     uvicorn.run(
         app,
         host="0.0.0.0",
-        port=8000,
+        port=11096,
         log_level="info"
     )
paerser.py (new file, 32 lines)
@@ -0,0 +1,32 @@
def pre_remove_details(input_string: str) -> tuple:
    start_tag = '</details>'
    start = input_string.find(start_tag, 0)
    if start == -1:
        return input_string, "unfind %s" % start_tag
    return input_string[start + len(start_tag):], "success remove %s" % start_tag


def pre_remove_markdown(input_string: str) -> tuple:
    start_tag = '```markdown'
    end_tag = '```'
    start = input_string.find(start_tag, 0)
    if start == -1:
        return input_string, "unfind %s" % start_tag
    end = input_string.find(end_tag, start + len(start_tag))
    if end == -1:
        return input_string, "unfind %s" % end_tag
    return input_string[start + len(start_tag):end].strip(), "success remove %s" % start_tag


def main(input_string: str) -> dict:
    result = input_string
    statuses = []

    result, detail_status = pre_remove_details(result)
    statuses.append(detail_status)

    result, markdown_status = pre_remove_markdown(result)
    statuses.append(markdown_status)

    return {
        "result": result,
        "status": statuses
    }
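The new `paerser.py` strips a leading `</details>` tag and then extracts the body of a fenced markdown block. The fence-extraction step can be sketched stand-alone (simplified: no status string; the sample input is made up):

```python
def strip_fenced_markdown(s: str) -> str:
    """Keep only the body of the first ```markdown ... ``` fence, if present."""
    start_tag, end_tag = '```markdown', '```'
    start = s.find(start_tag)
    if start == -1:
        return s
    end = s.find(end_tag, start + len(start_tag))
    if end == -1:
        return s
    return s[start + len(start_tag):end].strip()

sample = "intro\n```markdown\n# Title\nbody\n```\ntrailer"
cleaned = strip_fenced_markdown(sample)
```

Searching for the closing fence from `start + len(start_tag)` (rather than a hard-coded offset) keeps the two tags in sync if either ever changes.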
requirements.txt
@@ -1,11 +1,11 @@
-pytest==7.3.1
-pytest-cov==4.1.0
-flake8==6.0.0
-black==23.3.0
-isort==5.12.0
-flask==2.3.2
-requests==2.31.0
-websockets==11.0.3
-numpy==1.24.3
-funasr==0.10.0
-modelscope==1.9.5
+fastapi==0.115.14
+funasr==1.2.6
+modelscope==1.27.1
+numpy==2.0.1
+pyaudio==0.2.14
+pydantic==2.11.7
+pydub==0.25.1
+pytest==8.3.5
+soundfile==0.13.1
+torch==2.3.1
+uvicorn==0.35.0
requirements.txt.backup (new file, 11 lines)
@@ -0,0 +1,11 @@
pytest==7.3.1
pytest-cov==4.1.0
flake8==6.0.0
black==23.3.0
isort==5.12.0
flask==2.3.2
requests==2.31.0
websockets==11.0.3
numpy==1.24.3
funasr==0.10.0
modelscope==1.9.5
@@ -8,7 +8,7 @@ try:
     from funasr import AutoModel
 except ImportError as exc:
     raise ImportError("funasr not found; please install it first: pip install funasr") from exc

+from modelscope.pipelines import pipeline
 # logging
 from src.utils import get_module_logger

@@ -101,6 +101,16 @@ class ModelLoader:
             logger.error("Failed to load %s model: %s", model_type, str(e))
             raise

+    # def _load_pipeline(self, input_model_args: dict, model_type: str):
+    #     """
+    #     Load a pipeline
+    #     """
+    #     default_pipeline = pipeline(
+    #         task='speaker-verification',
+    #         model='iic/speech_campplus_sv_zh-cn_16k-common',
+    #         model_revision='v1.0.0'
+    #     )

     def load_models(self, args):
         """
         Load all required models
@@ -115,12 +125,19 @@ class ModelLoader:
         self.models = {}
-        # load the offline ASR model
+        # check whether each expected key exists
-        model_list = ["asr", "asr_online", "vad", "punc", "spk"]
+        model_list = ["asr", "asr_online", "vad", "punc"]
         for model_name in model_list:
             name_model = f"{model_name}_model"
             name_model_revision = f"{model_name}_model_revision"
             if name_model in args:
                 logger.debug("Loading %s model", model_name)
                 self.models[model_name] = self._load_model(args, model_name)
         logger.info("All models loaded")
+        pipeline_list = ["spk"]
+        for pipeline_name in pipeline_list:
+            if pipeline_name == "spk":
+                self.models[pipeline_name] = pipeline(
+                    task='speaker-verification',
+                    model='iic/speech_campplus_sv_zh-cn_16k-common',
+                    model_revision='v1.0.0'
+                )
         return self.models
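The loader change above splits models into AutoModel-backed entries (loaded only when `<name>_model` is present in the config) and pipeline-backed entries (always loaded). The control flow can be sketched with stub loaders; the lambdas and return strings here are illustrative stand-ins, not the real FunASR/ModelScope APIs:

```python
def load_models(args: dict, load_model, load_pipeline) -> dict:
    """Load AutoModel-style entries only when '<name>_model' is configured,
    then add pipeline-backed entries unconditionally."""
    models = {}
    for name in ["asr", "asr_online", "vad", "punc"]:
        if f"{name}_model" in args:
            models[name] = load_model(name)
    for name in ["spk"]:
        models[name] = load_pipeline(name)
    return models

models = load_models(
    {"asr_model": "paraformer", "vad_model": "fsmn-vad"},
    load_model=lambda n: f"automodel:{n}",
    load_pipeline=lambda n: f"pipeline:{n}",
)
```

Models whose key is absent from the config are simply skipped, while the speaker-verification pipeline is always available to the SPK functor.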
@@ -17,6 +17,9 @@ from typing import Callable, List, Dict
 from queue import Queue
 import threading

+from src.utils.logger import get_module_logger
+
+logger = get_module_logger(__name__)

 class BaseFunctor(ABC):
     """
@@ -157,13 +160,15 @@ class FunctorFactory:
         """
         from src.functor.spk_functor import SPKFunctor

+        logger.debug("Creating spk functor [start]")
         audio_config = config["audio_config"]
-        model = {"spk": models["spk"]}
+        # model = {"spk": models["spk"]}

-        spk_functor = SPKFunctor()
+        spk_functor = SPKFunctor(sv_pipeline=models["spk"])
         spk_functor.set_audio_config(audio_config)
-        spk_functor.set_model(model)
+        # spk_functor.set_model(model)

+        logger.debug("Creating spk functor [done]")
         return spk_functor

     def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
@@ -5,7 +5,7 @@ ResultBinderFunctor

 from src.functor.base import BaseFunctor
 from src.models import AudioBinary_Config, VAD_Functor_result
-from typing import Callable, List
+from typing import Callable, List, Dict, Any
 from queue import Queue, Empty
 import threading
 import time
@@ -74,16 +74,30 @@ class ResultBinderFunctor(BaseFunctor):
         for callback in self._callback:
             callback(result)

-    def _process(self, data: VAD_Functor_result) -> None:
+    def _process(self, data: Dict[str, Any]) -> None:
         """
         Process data
+        {
+            "is_final": false,
+            "mode": "2pass-offline",
+            "text": ",等一下我回一下,ok你看这里就有了",
+            "wav_name": "h5",
+            "speaker_id":
+        }
         """
         logger.debug("ResultBinderFunctor processing data: %s", data)
-        # aggregate the results carried in data
-        # currently a no-op placeholder
-        results = {}
-        for name, result in data.items():
-            results[name] = result
+        results = {
+            "is_final": False,
+            "mode": "2pass-offline",
+            "text": data["asr"],
+            "wav_name": "h5",
+            "speaker_id": data["spk"]["speaker_id"]
+        }
+        # for name, result in data.items():
+        #     results[name] = result
         self._do_callback(results)

     def _run(self) -> None:
File diff suppressed because one or more lines are too long
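The reshaped `_process` above merges the ASR text and the speaker-identification output into one outgoing message. As a stand-alone sketch (the fixed `mode` and `wav_name` mirror the values hard-coded in the diff; the input dict is illustrative):

```python
from typing import Any, Dict

def bind_results(data: Dict[str, Any]) -> Dict[str, Any]:
    """Combine per-functor outputs into the message broadcast to receivers."""
    return {
        "is_final": False,
        "mode": "2pass-offline",
        "text": data["asr"],
        "wav_name": "h5",
        "speaker_id": data["spk"]["speaker_id"],
    }

msg = bind_results({"asr": "hello", "spk": {"speaker_id": "spk-a", "score": 0.91}})
```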
@@ -7,6 +7,8 @@ import threading
 from queue import Empty, Queue
 from typing import List, Any, Callable
 import numpy
+import time
+from datetime import datetime
 from src.models import (
     VAD_Functor_result,
     AudioBinary_Config,
@@ -126,7 +128,10 @@ class VADFunctor(BaseFunctor):
             self._cache_result_list[-1][1] = end
         else:
             self._cache_result_list.append(pair)
-        while len(self._cache_result_list) > 1:
-            logger.debug(f"VADFunctor result: {self._cache_result_list}")
+        logger.debug(f"VADFunctor result: {self._cache_result_list}")
+        while len(self._cache_result_list) > 0 and self._cache_result_list[0][1] != -1:
+            logger.debug(f"VADFunctor list[0][1]: {self._cache_result_list[0][1]}")
             # create a VAD segment
             # compute the start frame
             start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
@@ -135,12 +140,16 @@ class VADFunctor(BaseFunctor):
             end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
             end_frame -= self._audio_cache_preindex
-            # compute the start time
+            # record the creation time
+            timestamp = time.time()
+            format_time = datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
+            logger.debug(f"{format_time} creating VAD segment: {start_frame} - {end_frame}")
             vad_result = VAD_Functor_result.create_from_push_data(
                 audiobinary_data_list=self._audio_binary_data_list,
                 data=self._audiobinary_cache[start_frame:end_frame],
                 start_time=self._cache_result_list[0][0],
                 end_time=self._cache_result_list[0][1],
             )
+            logger.debug(f"{format_time} VAD segment created: {vad_result}")
             self._audio_cache_preindex += end_frame
             self._audiobinary_cache = self._audiobinary_cache[end_frame:]
             for callback in self._callback:
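The revised flush condition above emits segments only once their end time is known (`[0][1] != -1`), instead of always holding back the last pair. The behavior can be simulated on plain lists of `[start_ms, end_ms]` pairs, where `-1` marks a still-open segment:

```python
def flush_segments(cache):
    """Pop leading [start_ms, end_ms] pairs whose end is known (!= -1)."""
    flushed = []
    while len(cache) > 0 and cache[0][1] != -1:
        flushed.append(cache.pop(0))
    return flushed

cache = [[0, 480], [520, 990], [1020, -1]]  # last segment still open
done = flush_segments(cache)
```

Both closed segments are flushed immediately; the open one stays cached until VAD reports its end.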
@@ -108,16 +108,19 @@ class ASRPipeline(PipelineBase):
             from src.functor import FunctorFactory

             # create the VAD, ASR, and SPK functors
+            logger.debug("Creating vad functor with FunctorFactory")
             self._functor_dict["vad"] = FunctorFactory.make_functor(
                 functor_name="vad", config=self._config, models=self._models
             )
+            logger.debug("Creating asr functor with FunctorFactory")
             self._functor_dict["asr"] = FunctorFactory.make_functor(
                 functor_name="asr", config=self._config, models=self._models
             )
+            logger.debug("Creating spk functor with FunctorFactory")
             self._functor_dict["spk"] = FunctorFactory.make_functor(
                 functor_name="spk", config=self._config, models=self._models
             )

+            logger.debug("Creating resultbinder functor with FunctorFactory")
             self._functor_dict["resultbinder"] = FunctorFactory.make_functor(
                 functor_name="resultbinder", config=self._config, models=self._models
             )
@@ -160,8 +163,10 @@ class ASRPipeline(PipelineBase):
             # set resultbinder's callback to the pipeline's own callback, used to talk to the outside world
             self._functor_dict["resultbinder"].add_callback(self._callback)

-        except ImportError:
-            raise ImportError("Failed to import FunctorFactory; ASRPipeline cannot finish initialization")
+        except ImportError as e:
+            raise ImportError(f"Failed to import FunctorFactory; ASRPipeline cannot finish initialization: {str(e)}")
+        except Exception as e:
+            raise Exception(f"ASRPipeline initialization failed: {str(e)}")

     def _check_result(self, result: Any) -> None:
         """
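The pipeline above obtains every functor through `FunctorFactory.make_functor(functor_name=...)`. The dispatch pattern can be sketched generically; the registry and stub makers here are illustrative, not the project's real classes:

```python
class FunctorFactory:
    _makers = {}

    @classmethod
    def register(cls, name, maker):
        cls._makers[name] = maker

    @classmethod
    def make_functor(cls, functor_name, config, models):
        """Look up the maker by name and build the functor from config/models."""
        if functor_name not in cls._makers:
            raise KeyError(f"unknown functor: {functor_name}")
        return cls._makers[functor_name](config, models)

# register stub makers for the four functor kinds used by the pipeline
for name in ["vad", "asr", "spk", "resultbinder"]:
    FunctorFactory.register(name, lambda cfg, mdl, n=name: f"{n}-functor")

functors = {n: FunctorFactory.make_functor(n, {}, {}) for n in ["vad", "asr", "spk", "resultbinder"]}
```

Centralizing construction this way keeps the pipeline code free of per-functor wiring details.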
@@ -25,8 +25,11 @@ class FastAPIWebSocketAdapter:
         message = await self._ws.receive()
         if 'bytes' in message:
             bytes_data = message['bytes']
-            # convert the byte stream to a float32 numpy array
-            audio_array = np.frombuffer(bytes_data, dtype=np.float32)
+            audio_array = np.frombuffer(bytes_data, dtype=np.int16)
+            # convert int16 to float32
+            audio_array = audio_array.astype(np.float32)
+            # normalize
+            audio_array = audio_array / 32768.0

             # use carriage return \r to overwrite the printed progress
             self._total_received += len(bytes_data)
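The adapter change above decodes incoming bytes as int16 and normalizes them into [-1.0, 1.0). The same conversion with only the standard library (the real code uses numpy's `frombuffer`/`astype`):

```python
import struct

def pcm16_bytes_to_float(data: bytes):
    """Decode little-endian 16-bit PCM and normalize by 32768."""
    count = len(data) // 2
    samples = struct.unpack(f"<{count}h", data[: count * 2])
    return [s / 32768.0 for s in samples]

raw = struct.pack("<3h", 0, 16384, -32768)  # three sample values
floats = pcm16_bytes_to_float(raw)
```

Dividing by 32768 maps the full int16 range onto [-1.0, 1.0), which is what the downstream float32 pipeline expects.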
@@ -65,7 +65,7 @@ async def test_asr_runner():
     }
     models = model_loader.load_models(args)
     audio_file_path = "tests/XT_ZZY_denoise.wav"
-    audio_data, sample_rate = soundfile.read(audio_file_path)
+    audio_data, sample_rate = soundfile.read(audio_file_path, dtype='float32')
     logger.info(
         f"Loaded data: {audio_file_path}, audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
     )
@@ -112,7 +112,7 @@ async def test_asr_runner():
     # send 100 ms of audio at a time
     audio_clip_len = int(sample_rate * 0.1)
     for i in range(0, len(audio_data), audio_clip_len):
-        chunk = audio_data[i : i + audio_clip_len]
+        chunk = audio_data[i : i + audio_clip_len].astype(np.float32)
         if chunk.size == 0:
             break
         mock_ws.put_for_recv(chunk)
@@ -5,7 +5,7 @@ import uuid

 # --- configuration ---
 HOST = "localhost"
-PORT = 8000
+PORT = 11096
 SESSION_ID = str(uuid.uuid4())
 SENDER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=sender"
 RECEIVER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=receiver"
@@ -54,7 +54,7 @@ async def run_sender():
     # --- main task: send audio ---
     try:
         print("▶️ [Sender] Preparing to send audio...")
-        audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='float32')
+        audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='int16')
         if sample_rate != 16000:
             print(f"❌ [Sender] Error: the audio file sample rate must be 16 kHz.")
             receiver_sub_task.cancel()
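The sender tests above stream the file in 100 ms slices. Chunking an in-memory signal works the same way regardless of dtype; here a plain Python list stands in for the numpy array:

```python
def iter_chunks(samples, sample_rate=16000, clip_ms=100):
    """Yield successive clip_ms-sized chunks of a sample sequence."""
    clip_len = int(sample_rate * clip_ms / 1000)
    for i in range(0, len(samples), clip_len):
        chunk = samples[i : i + clip_len]
        if len(chunk) == 0:
            break
        yield chunk

# 4000 samples at 16 kHz = 250 ms -> two full 100 ms chunks plus a 50 ms tail
chunks = list(iter_chunks(list(range(4000))))
```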