[Code refactor complete] Finished the FastAPI-based WebSocket speech recognition and speaker identification server, with VAD, ASR, and SPKVerify (speakers loaded from the database).

This commit is contained in:
Ziyang.Zhang 2025-07-03 18:00:14 +08:00
parent 3083738db4
commit db811763d4
16 changed files with 309 additions and 137 deletions

README.md

@@ -1,110 +1,97 @@
-# FunASR WebSocket Service
+# FunASR FastAPI WebSocket Service
-## Introduction
+A high-performance, real-time speech recognition WebSocket service built on FunASR and FastAPI. Its core feature is a "one sender, many receivers" broadcast mode, suited to scenarios such as live meeting captions, online education, and livestream transcription, where recognition results from a single audio source must be distributed to multiple clients.
-This project implements a WebSocket speech recognition service based on FunASR, supporting online and offline recognition of real-time audio streams. Built on the open-source ModelScope speech models, it performs high-accuracy Chinese speech recognition and supports voice activity detection (VAD) and automatic punctuation.
+## ✨ Features
+- **Real-time speech processing**: integrates FunASR voice activity detection (VAD), speech recognition (ASR), and speaker verification (SPK) models.
+- **Streaming WebSocket API**: low-latency, bidirectional real-time communication.
+- **"One sender, many receivers" architecture**:
+  - **Sender**: a single client acts as the audio source and continuously streams audio to the server.
+  - **Receiver**: multiple clients can subscribe to the same session and receive the broadcast recognition results in real time.
+- **Asynchronous core**: built on FastAPI and `asyncio`, capable of handling many concurrent connections.
+- **Modular design**: clean separation of the service layer (`server.py`), session management (`ASRRunner`), and the core processing pipeline (`ASRPipeline`).
+## 📂 Project Structure
-## Project Structure
 ```
 .
-├── src/                    # source code
-│   ├── __init__.py         # package init
-│   ├── server.py           # WebSocket server implementation
-│   ├── config.py           # configuration handling
-│   ├── models.py           # model loading
-│   ├── service.py          # ASR service implementation
-│   └── client.py           # test client
-├── tests/                  # tests
-│   ├── __init__.py         # test package init
-│   └── test_config.py      # configuration tests
-├── requirements.txt        # Python dependencies
-├── Dockerfile              # Docker configuration
-├── docker-compose.yml      # Docker Compose configuration
-├── .gitignore              # Git ignore rules
-└── README.md               # project README
+├── main.py                      # application entry point, starts the service with uvicorn
+├── WEBSOCKET_API.md             # detailed WebSocket API documentation and examples
+├── src
+│   ├── server.py                # FastAPI application core, manages lifecycle and global resources
+│   ├── runner
+│   │   └── ASRRunner.py         # core session manager, creates and coordinates recognition sessions (SAR)
+│   ├── pipeline
+│   │   └── ASRpipeline.py       # synchronous, thread-based speech processing pipeline
+│   ├── functor                  # implementations of the atomic VAD, ASR, SPK operations
+│   ├── websockets
+│   │   ├── adapter.py           # WebSocket adapter, handles data format conversion
+│   │   ├── endpoint
+│   │   │   └── asr_endpoint.py  # WebSocket business-logic endpoint
+│   │   └── router.py            # WebSocket routing
+│   └── ...
+└── tests
+    ├── runner
+    │   └── asr_runner_test.py   # unit tests for ASRRunner (async)
+    └── websocket
+        └── websocket_asr.py     # end-to-end test of the WebSocket service
 ```
-## Features
+## 🚀 Quick Start
-- **Multiple recognition modes**: offline, online, and two-pass (2pass) recognition
+### 1. Environment and Dependencies
-- **Voice activity detection**: automatically detects where speech starts and ends
-- **Punctuation**: automatic punctuation insertion
-- **WebSocket interface**: real-time speech recognition over binary WebSocket
-- **Docker support**: containerized deployment
-## Installation and Usage
-### Requirements
 - Python 3.8+
-- CUDA (for GPU acceleration)
+- Project dependencies are listed in `requirements.txt`.
-- RAM >= 8 GB
-### Install Dependencies
+### 2. Installation
+Installing into a virtual environment is recommended. From the project root, run:
 ```bash
 pip install -r requirements.txt
 ```
-### Run the Service
+### 3. Run the Service
+Start the FastAPI service from the main entry point:
 ```bash
-python src/server.py
+python main.py
 ```
+Once started, the service listens on `http://0.0.0.0:8000`.
-Common startup options:
+## 💡 Usage
-- `--host`: listen address, default 0.0.0.0
-- `--port`: server port, default 10095
-- `--device`: device type (cuda or cpu), default cuda
-- `--ngpu`: number of GPUs (0 means CPU), default 1
-### Test Client
+The service is exposed over WebSocket. A client creates or joins a recognition session via `session_id` and declares its role (`sender` or `receiver`) with the `mode` parameter.
+**For the full API description, URL format, and Python and JavaScript client examples, see:**
+➡️ **[WEBSOCKET_API.md](./docs/WEBSOCKET_API.md)**
+## 🔬 Testing
+The project provides two kinds of tests.
+### 1. End-to-End WebSocket Test
+This test simulates one `sender` and one `receiver` and exercises a complete recognition session.
+**Prerequisite**: make sure the FastAPI service is running:
 ```bash
-python src/client.py --audio_file path/to/audio.wav
+python main.py
 ```
-Common client options:
+Then, from the project root, run:
-- `--audio_file`: path of the audio file to recognize
-- `--mode`: recognition mode, one of 2pass/online/offline, default 2pass
-- `--host`: server address, default localhost
-- `--port`: server port, default 10095
-## Docker Deployment
-### Build the Image
 ```bash
-docker build -t funasr-websocket .
+python tests/websocket/websocket_asr.py
 ```
-### Start with Docker Compose
+### 2. ASRRunner Unit Test
+This test targets the core `ASRRunner` component and verifies its asynchronous logic.
+Run it with:
 ```bash
-docker-compose up -d
+python test_main.py
 ```
-## API Reference
-### WebSocket Message Format
-1. **Client configuration message**:
-```json
-{
-    "mode": "2pass",        // one of "2pass", "online", "offline"
-    "chunk_size": "5,10",   // chunk size, "encoder_size,decoder_size"
-    "wav_name": "audio1",   // audio identifier
-    "is_speaking": true     // whether speech is in progress
-}
-```
-2. **Client audio data**:
-Binary audio stream (16 kHz sample rate, 16-bit PCM)
-3. **Server recognition result**:
-```json
-{
-    "mode": "2pass-online", // recognition mode
-    "text": "recognized text",  // recognition result
-    "wav_name": "audio1",   // audio identifier
-    "is_final": false       // whether this is the final result
-}
-```
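The sender workflow the new README describes (connect with `mode=sender`, stream binary audio chunks, finish with the text message `"close"`) can be sketched as follows. This is a minimal illustration, not project code: it assumes the `websockets` package is installed (the new `requirements.txt` no longer pins it) and that chunks are 100 ms of 16 kHz float32 audio.

```python
import asyncio
import uuid

import numpy as np

CHUNK_FRAMES = 1600  # 100 ms at 16 kHz

def chunk_audio(samples, frames=CHUNK_FRAMES):
    """Yield fixed-size float32 chunks ready to send as binary WebSocket messages."""
    for i in range(0, len(samples), frames):
        yield samples[i:i + frames].astype(np.float32).tobytes()

async def send(samples, host="localhost", port=8000):
    # assumption: the `websockets` package is available in the environment
    import websockets
    uri = f"ws://{host}:{port}/ws/asr/{uuid.uuid4()}?mode=sender"
    async with websockets.connect(uri) as ws:
        for chunk in chunk_audio(samples):
            await ws.send(chunk)       # binary frame carrying raw audio
            await asyncio.sleep(0.1)   # pace roughly at real time
        await ws.send("close")         # text frame asks the server to end the session
```

A `receiver` client would connect to the same `session_id` with `mode=receiver` and simply read JSON result messages.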


@@ -1,6 +1,6 @@
 [
     {
-        "speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
+        "speaker_id": "137facd6-a1c9-47b9-b87d-16e6f62e07bd",
         "speaker_name": "ZiyangZhang",
         "wav_path": "/home/lyg/Code/funasr/data/speaker_wav/ZiyangZhang.wav",
         "speaker_embs": [
@@ -199,7 +199,7 @@
         ]
     },
     {
-        "speaker_id": "b7e2c8e2-1f3a-4c2a-9e7a-2c1d4e8f9a3b",
+        "speaker_id": "81e1b806-c76a-468f-98b7-4a63f2996480",
         "speaker_name": "HaiaoDuan",
         "wav_path": "/home/lyg/Code/funasr/data/speaker_wav/HaiaoDuan.wav",
         "speaker_embs": [


@@ -26,7 +26,7 @@ ws://<your_server_host>:8000/ws/asr/{session_id}?mode=<client_mode>
 - **Audio format**: the `sender` must send raw **PCM audio data**.
 - **Sample rate**: 16000 Hz
-- **Bit depth**: 16-bit (signed integer)
+- **Bit depth**: 32-bit (floating point)
 - **Channels**: mono
 - **Transport**: must be sent as **binary (bytes)**.
 - **End signal**: when the audio stream ends, the `sender` should send the **text message** `"close"` to tell the server to close the session.
@@ -64,7 +64,7 @@ import uuid
 SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=sender"
 SESSION_ID = str(uuid.uuid4())  # generate a unique ID for this session
 AUDIO_FILE = "tests/XT_ZZY_denoise.wav"  # replace with your audio file path
-CHUNK_SIZE = 3200  # send 100 ms of audio at a time (16000 * 2 * 0.1)
+CHUNK_SIZE = 3200  # 100 ms worth of float32 data (16000 * 4 * 0.1)
 async def send_audio():
     """Connect to the server and stream the audio file"""
@@ -80,7 +80,8 @@ async def send_audio():
         print("Start sending audio...")
         while True:
-            data = f.read(CHUNK_SIZE, dtype='int16')
+            # read as float32
+            data = f.read(CHUNK_SIZE, dtype='float32')
             if not data.any():
                 break
             # convert the numpy array to a raw byte stream
@@ -215,18 +216,14 @@ if __name__ == "__main__":
 audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });
 const source = audioContext.createMediaStreamSource(mediaStream);
-const bufferSize = CHUNK_DURATION_MS * SAMPLE_RATE / 1000 * 2; // compute buffer size
+const bufferSize = CHUNK_DURATION_MS * SAMPLE_RATE / 1000 * 4; // compute buffer size
 scriptProcessor = audioContext.createScriptProcessor(bufferSize, 1, 1);
 scriptProcessor.onaudioprocess = (e) => {
     if (websocket && websocket.readyState === WebSocket.OPEN) {
         const inputData = e.inputBuffer.getChannelData(0);
-        // the server expects 16-bit PCM, so convert
-        const pcmData = new Int16Array(inputData.length);
-        for (let i = 0; i < inputData.length; i++) {
-            pcmData[i] = Math.max(-1, Math.min(1, inputData[i])) * 32767;
-        }
-        websocket.send(pcmData.buffer);
+        // the server expects float32; inputData is already a Float32Array, send its buffer directly
+        websocket.send(inputData.buffer);
     }
 };
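The buffer-size arithmetic that changes in both clients (the `* 2` vs `* 4` factor) follows from bytes = sample_rate × seconds × bytes_per_sample, with 2 bytes per int16 sample and 4 per float32 sample. A quick check:

```python
SAMPLE_RATE = 16000

def chunk_bytes(duration_ms: int, bytes_per_sample: int, rate: int = SAMPLE_RATE) -> int:
    """Bytes needed for `duration_ms` of mono audio at `rate`."""
    return int(rate * duration_ms / 1000) * bytes_per_sample

print(chunk_bytes(100, 2))  # int16: 3200 bytes per 100 ms
print(chunk_bytes(100, 4))  # float32: 6400 bytes per 100 ms
```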


@@ -4,7 +4,7 @@ from src.utils.logger import get_module_logger, setup_root_logger
 from datetime import datetime
 time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-setup_root_logger(level="INFO", log_file=f"logs/main_{time}.log")
+setup_root_logger(level="DEBUG", log_file=f"logs/main_{time}.log")
 logger = get_module_logger(__name__)
 if __name__ == "__main__":
@@ -14,6 +14,6 @@ if __name__ == "__main__":
     uvicorn.run(
         app,
         host="0.0.0.0",
-        port=8000,
+        port=11096,
         log_level="info"
     )

paerser.py (new file)

@@ -0,0 +1,32 @@
from typing import Tuple

def pre_remove_details(input_string: str) -> Tuple[str, str]:
    start_tag = '</details>'
    start = input_string.find(start_tag)
    if start == -1:
        return input_string, "could not find %s" % start_tag
    return input_string[start + len(start_tag):], "successfully removed %s" % start_tag

def pre_remove_markdown(input_string: str) -> Tuple[str, str]:
    start_tag = '```markdown'
    end_tag = '```'
    start = input_string.find(start_tag)
    if start == -1:
        return input_string, "could not find %s" % start_tag
    end = input_string.find(end_tag, start + len(start_tag))
    if end == -1:
        return input_string, "could not find %s" % end_tag
    return input_string[start + len(start_tag):end].strip(), "successfully removed %s" % start_tag

def main(input_string: str) -> dict:
    result = input_string
    statuses = []
    result, detail_status = pre_remove_details(result)
    statuses.append(detail_status)
    result, markdown_status = pre_remove_markdown(result)
    statuses.append(markdown_status)
    return {
        "result": result,
        "status": statuses
    }
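The two helpers above strip a leading `</details>` block and then unwrap a ```markdown fence. A condensed, self-contained equivalent of the same stripping logic, run on a hypothetical input:

```python
def strip_wrappers(s: str) -> str:
    """Drop everything up to a closing </details>, then unwrap a ```markdown fence."""
    tag = '</details>'
    i = s.find(tag)
    if i != -1:
        s = s[i + len(tag):]
    start_tag, end_tag = '```markdown', '```'
    i = s.find(start_tag)
    if i != -1:
        j = s.find(end_tag, i + len(start_tag))
        if j != -1:
            s = s[i + len(start_tag):j].strip()
    return s

sample = "<details>meta</details>\n```markdown\n# Title\nBody text\n```\n"
print(strip_wrappers(sample))
```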


@ -1,11 +1,11 @@
pytest==7.3.1 fastapi==0.115.14
pytest-cov==4.1.0 funasr==1.2.6
flake8==6.0.0 modelscope==1.27.1
black==23.3.0 numpy==2.0.1
isort==5.12.0 pyaudio==0.2.14
flask==2.3.2 pydantic==2.11.7
requests==2.31.0 pydub==0.25.1
websockets==11.0.3 pytest==8.3.5
numpy==1.24.3 soundfile==0.13.1
funasr==0.10.0 torch==2.3.1
modelscope==1.9.5 uvicorn==0.35.0

requirements.txt.backup (new file)

@@ -0,0 +1,11 @@
pytest==7.3.1
pytest-cov==4.1.0
flake8==6.0.0
black==23.3.0
isort==5.12.0
flask==2.3.2
requests==2.31.0
websockets==11.0.3
numpy==1.24.3
funasr==0.10.0
modelscope==1.9.5


@@ -8,7 +8,7 @@ try:
     from funasr import AutoModel
 except ImportError as exc:
     raise ImportError("funasr not found, install it first: pip install funasr") from exc
+from modelscope.pipelines import pipeline

 # logging
 from src.utils import get_module_logger
@@ -101,6 +101,16 @@ class ModelLoader:
             logger.error("Failed to load %s model: %s", model_type, str(e))
             raise

+    # def _load_pipeline(self, input_model_args: dict, model_type: str):
+    #     """
+    #     Load a pipeline
+    #     """
+    #     default_pipeline = pipeline(
+    #         task='speaker-verification',
+    #         model='iic/speech_campplus_sv_zh-cn_16k-common',
+    #         model_revision='v1.0.0'
+    #     )

     def load_models(self, args):
         """
         Load all required models
@@ -115,12 +125,19 @@ class ModelLoader:
         self.models = {}
         # load the offline ASR model
         # check whether the corresponding key exists
-        model_list = ["asr", "asr_online", "vad", "punc", "spk"]
+        model_list = ["asr", "asr_online", "vad", "punc"]
         for model_name in model_list:
             name_model = f"{model_name}_model"
             name_model_revision = f"{model_name}_model_revision"
             if name_model in args:
                 logger.debug("Loading %s model", model_name)
                 self.models[model_name] = self._load_model(args, model_name)
-        logger.info("All models loaded")
+        pipeline_list = ["spk"]
+        for pipeline_name in pipeline_list:
+            if pipeline_name == "spk":
+                self.models[pipeline_name] = pipeline(
+                    task='speaker-verification',
+                    model='iic/speech_campplus_sv_zh-cn_16k-common',
+                    model_revision='v1.0.0'
+                )
         return self.models


@@ -17,6 +17,9 @@ from typing import Callable, List, Dict
 from queue import Queue
 import threading

+from src.utils.logger import get_module_logger
+logger = get_module_logger(__name__)

 class BaseFunctor(ABC):
     """
@@ -157,13 +160,15 @@ class FunctorFactory:
         """
         from src.functor.spk_functor import SPKFunctor

+        logger.debug("Creating spk functor [start]")
         audio_config = config["audio_config"]
-        model = {"spk": models["spk"]}
+        # model = {"spk": models["spk"]}
-        spk_functor = SPKFunctor()
+        spk_functor = SPKFunctor(sv_pipeline=models["spk"])
         spk_functor.set_audio_config(audio_config)
-        spk_functor.set_model(model)
+        # spk_functor.set_model(model)
+        logger.debug("Creating spk functor [done]")
         return spk_functor

     def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
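The `SPKFunctor` now receives the ModelScope speaker-verification pipeline directly instead of a raw model. Matching an utterance embedding against the `speaker_embs` stored in the speakers database presumably reduces to nearest-neighbor search by cosine similarity; a hypothetical sketch with illustrative names and toy embeddings (not the project's actual API or data):

```python
import numpy as np

def best_speaker(utt_emb: np.ndarray, speakers: dict) -> tuple:
    """Return (speaker_name, score) of the closest enrolled speaker by cosine similarity."""
    best_name, best_score = None, -1.0
    u = utt_emb / np.linalg.norm(utt_emb)
    for name, emb in speakers.items():
        v = emb / np.linalg.norm(emb)
        score = float(np.dot(u, v))
        if score > best_score:
            best_name, best_score = name, score
    return best_name, best_score

# toy 3-dim embeddings; real CAM++ embeddings are much higher-dimensional
enrolled = {
    "ZiyangZhang": np.array([0.9, 0.1, 0.0]),
    "HaiaoDuan": np.array([0.0, 1.0, 0.2]),
}
name, score = best_speaker(np.array([1.0, 0.0, 0.1]), enrolled)
print(name, round(score, 3))
```

In practice a threshold on the score would decide between "known speaker" and "unknown".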


@@ -5,7 +5,7 @@ ResultBinderFunctor
 from src.functor.base import BaseFunctor
 from src.models import AudioBinary_Config, VAD_Functor_result
-from typing import Callable, List
+from typing import Callable, List, Dict, Any
 from queue import Queue, Empty
 import threading
 import time
@@ -74,16 +74,30 @@ class ResultBinderFunctor(BaseFunctor):
         for callback in self._callback:
             callback(result)

-    def _process(self, data: VAD_Functor_result) -> None:
+    def _process(self, data: Dict[str, Any]) -> None:
         """
         Process data
+        {
+            "is_final": false,
+            "mode": "2pass-offline",
+            "text": "等一下我回一下ok你看这里就有了",
+            "wav_name": "h5",
+            "speaker_id":
+        }
         """
         logger.debug("ResultBinderFunctor processing data: %s", data)
         # aggregate the results contained in data
         # currently a no-op, reserved for later use
-        results = {}
-        for name, result in data.items():
-            results[name] = result
+        results = {
+            "is_final": False,
+            "mode": "2pass-offline",
+            "text": data["asr"],
+            "wav_name": "h5",
+            "speaker_id": data["spk"]["speaker_id"]
+        }
+        # for name, result in data.items():
+        #     results[name] = result
         self._do_callback(results)

     def _run(self) -> None:

File diff suppressed because one or more lines are too long


@@ -7,6 +7,8 @@ import threading
 from queue import Empty, Queue
 from typing import List, Any, Callable
 import numpy
+import time
+from datetime import datetime
 from src.models import (
     VAD_Functor_result,
     AudioBinary_Config,
@@ -126,7 +128,10 @@ class VADFunctor(BaseFunctor):
             self._cache_result_list[-1][1] = end
         else:
             self._cache_result_list.append(pair)
-        while len(self._cache_result_list) > 1:
+        logger.debug(f"VADFunctor result: {self._cache_result_list}")
+        while len(self._cache_result_list) > 0 and self._cache_result_list[0][1] != -1:
+            logger.debug(f"VADFunctor result: {self._cache_result_list}")
+            logger.debug(f"VADFunctor list[0][1]: {self._cache_result_list[0][1]}")
             # create a VAD segment
             # compute the start frame
             start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
@@ -135,12 +140,16 @@ class VADFunctor(BaseFunctor):
             end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
             end_frame -= self._audio_cache_preindex
             # compute the start time
+            timestamp = time.time()
+            format_time = datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
+            logger.debug(f"{format_time} creating VAD segment: {start_frame} - {end_frame}")
             vad_result = VAD_Functor_result.create_from_push_data(
                 audiobinary_data_list=self._audio_binary_data_list,
                 data=self._audiobinary_cache[start_frame:end_frame],
                 start_time=self._cache_result_list[0][0],
                 end_time=self._cache_result_list[0][1],
             )
+            logger.debug(f"{format_time} VAD segment created: {vad_result}")
             self._audio_cache_preindex += end_frame
             self._audiobinary_cache = self._audiobinary_cache[end_frame:]
             for callback in self._callback:
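The loop-condition change above flushes a pending segment as soon as its end time is known, instead of waiting until a second segment arrives. The flushing rule in isolation, assuming `[start_ms, end_ms]` pairs where `-1` marks a still-open segment (a simplified sketch, not the functor itself):

```python
def flush_segments(pairs: list) -> tuple:
    """Split VAD results into completed segments (end != -1) and the still-open remainder."""
    done = []
    # same condition as the new while-loop: flush every leading pair whose end is known
    while len(pairs) > 0 and pairs[0][1] != -1:
        done.append(pairs.pop(0))
    return done, pairs

done, pending = flush_segments([[0, 480], [520, 990], [1020, -1]])
print(done)     # completed segments, ready to hand to ASR/SPK
print(pending)  # still-open segment kept in the cache
```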


@@ -108,16 +108,19 @@ class ASRPipeline(PipelineBase):
         from src.functor import FunctorFactory

         # create the VAD, ASR, SPK functors
+        logger.debug("Creating vad functor with FunctorFactory")
         self._functor_dict["vad"] = FunctorFactory.make_functor(
             functor_name="vad", config=self._config, models=self._models
         )
+        logger.debug("Creating asr functor with FunctorFactory")
         self._functor_dict["asr"] = FunctorFactory.make_functor(
             functor_name="asr", config=self._config, models=self._models
         )
+        logger.debug("Creating spk functor with FunctorFactory")
         self._functor_dict["spk"] = FunctorFactory.make_functor(
             functor_name="spk", config=self._config, models=self._models
         )
+        logger.debug("Creating resultbinder functor with FunctorFactory")
         self._functor_dict["resultbinder"] = FunctorFactory.make_functor(
             functor_name="resultbinder", config=self._config, models=self._models
         )
@@ -160,8 +163,10 @@ class ASRPipeline(PipelineBase):
         # set the resultbinder's callback to the pipeline's own callback, used to talk to the outside
         self._functor_dict["resultbinder"].add_callback(self._callback)

-        except ImportError:
-            raise ImportError("Failed to import functorFactory; ASRPipeline cannot finish initializing")
+        except ImportError as e:
+            raise ImportError(f"Failed to import functorFactory; ASRPipeline cannot finish initializing: {str(e)}")
+        except Exception as e:
+            raise Exception(f"ASRPipeline initialization failed: {str(e)}")

     def _check_result(self, result: Any) -> None:
         """


@@ -25,8 +25,11 @@ class FastAPIWebSocketAdapter:
         message = await self._ws.receive()
         if 'bytes' in message:
             bytes_data = message['bytes']
-            # convert the byte stream into a numpy array
-            audio_array = np.frombuffer(bytes_data, dtype=np.float32)
+            audio_array = np.frombuffer(bytes_data, dtype=np.int16)
+            # convert int16 to float32
+            audio_array = audio_array.astype(np.float32)
+            # normalize
+            audio_array = audio_array / 32768.0
             # overwrite-print progress with carriage return \r
             self._total_received += len(bytes_data)
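The adapter's conversion step in isolation: int16 PCM bytes become float32 samples in [-1.0, 1.0). A self-contained check of the same `frombuffer` / `astype` / divide sequence:

```python
import numpy as np

# two int16 samples: full-scale negative and half-scale positive
pcm_bytes = np.array([-32768, 16384], dtype=np.int16).tobytes()

# same conversion as the adapter: reinterpret as int16, cast, normalize
audio = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
print(audio)  # [-1.  0.5]
```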


@@ -65,7 +65,7 @@ async def test_asr_runner():
     }
     models = model_loader.load_models(args)
     audio_file_path = "tests/XT_ZZY_denoise.wav"
-    audio_data, sample_rate = soundfile.read(audio_file_path)
+    audio_data, sample_rate = soundfile.read(audio_file_path, dtype='float32')
     logger.info(
         f"Loaded data: {audio_file_path}, audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
     )
@@ -112,7 +112,7 @@ async def test_asr_runner():
     # send 100 ms of audio at a time
     audio_clip_len = int(sample_rate * 0.1)
     for i in range(0, len(audio_data), audio_clip_len):
-        chunk = audio_data[i : i + audio_clip_len]
+        chunk = audio_data[i : i + audio_clip_len].astype(np.float32)
         if chunk.size == 0:
             break
         mock_ws.put_for_recv(chunk)


@@ -5,7 +5,7 @@ import uuid
 # --- configuration ---
 HOST = "localhost"
-PORT = 8000
+PORT = 11096
 SESSION_ID = str(uuid.uuid4())
 SENDER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=sender"
 RECEIVER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=receiver"
@@ -54,7 +54,7 @@ async def run_sender():
     # --- main task: send the audio ---
     try:
         print("▶️ [Sender] Preparing to send audio...")
-        audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='float32')
+        audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='int16')
         if sample_rate != 16000:
             print(f"❌ [Sender] Error: the audio file sample rate must be 16 kHz.")
             receiver_sub_task.cancel()