[fastAPI]完成fastAPI-websocket端口搭建与测试,接收float32的tobytes字节流。

This commit is contained in:
Ziyang.Zhang 2025-07-02 15:49:21 +08:00
parent 1a296d8309
commit 3083738db4
13 changed files with 758 additions and 437 deletions

265
WEBSOCKET_API.md Normal file
View File

@ -0,0 +1,265 @@
# FunASR-FastAPI WebSocket API 文档
本文档详细介绍了如何连接和使用 FunASR-FastAPI 实时语音识别服务的 WebSocket 接口。
## 1. 连接端点 (Endpoint)
服务的 WebSocket 端点 URL 格式如下:
```
ws://<your_server_host>:8000/ws/asr/{session_id}?mode=<client_mode>
```
### 参数说明
- **`{session_id}`** (路径参数, `str`, **必需**):
用于唯一标识一个识别会话(例如,一场会议或一次直播)。所有属于同一次会话的客户端都应使用相同的 `session_id`
- **`mode`** (查询参数, `str`, **必需**):
定义客户端的角色。
- `sender`: 音频发送者。一个会话中应该只有一个 `sender`。此客户端负责将实时音频流发送到服务器。
- `receiver`: 结果接收者。一个会话中可以有多个 `receiver`。此客户端只接收由服务器广播的识别结果,不发送音频。
## 2. 数据格式
### 2.1 发送数据 (Sender -> Server)
- **音频格式**: `sender` 必须发送原始的 **PCM 音频数据**
- **采样率**: 16000 Hz
- **样本格式**: 32-bit float (`float32`)。服务器端以 `np.frombuffer(..., dtype=np.float32)` 解码字节流,因此客户端应发送 float32 样本的原始字节。
- **声道数**: 单声道 (Mono)
- **传输格式**: 必须以**二进制 (bytes)** 格式发送。
- **结束信号**: 当音频流结束时,`sender` 应发送一个**文本消息** `"close"` 来通知服务器关闭会话。
### 2.2 接收数据 (Server -> Receiver)
服务器会将识别结果以 **JSON 文本** 格式广播给会话中的所有 `receiver`(以及 `sender` 自己)。JSON 对象的结构示例如下:
```json
{
"asr": "你好,世界。",
"spk": {
"speaker_id": "uuid-of-the-speaker",
"speaker_name": "SpeakerName",
"score": 0.98
}
}
```
## 3. Python 客户端示例
需要安装 `websockets` 库: `pip install websockets`
### 3.1 Python Sender 示例 (发送本地音频文件)
这个脚本会读取一个 WAV 文件,并将其内容以流式方式发送到服务器。
```python
import asyncio
import websockets
import soundfile as sf
import uuid

# --- 配置 ---
SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=sender"
SESSION_ID = str(uuid.uuid4())  # 为这次会话生成一个唯一的ID
AUDIO_FILE = "tests/XT_ZZY_denoise.wav"  # 替换为你的音频文件路径
# Frames (samples) per chunk, NOT bytes: sf.read() counts frames.
# 1600 frames = 100 ms of 16 kHz audio.
CHUNK_SIZE = 1600


async def send_audio():
    """Connect as a sender and stream the audio file to the server.

    The server-side adapter decodes each binary frame with
    ``np.frombuffer(..., dtype=np.float32)``, so samples are sent as
    float32 (ndarray.tobytes()), not int16.
    """
    uri = SERVER_URI.format(session_id=SESSION_ID)
    print(f"作为 Sender 连接到: {uri}")
    async with websockets.connect(uri) as websocket:
        try:
            # Read the audio file chunk by chunk.
            with sf.SoundFile(AUDIO_FILE, 'r') as f:
                assert f.samplerate == 16000, "音频文件采样率必须为 16kHz"
                assert f.channels == 1, "音频文件必须为单声道"
                print("开始发送音频...")
                while True:
                    data = f.read(CHUNK_SIZE, dtype='float32')
                    # len(data) == 0 means EOF. (Testing data.any() would
                    # also abort on a chunk of pure silence, which is wrong.)
                    if len(data) == 0:
                        break
                    # Send the raw bytes of the float32 numpy array.
                    await websocket.send(data.tobytes())
                    await asyncio.sleep(0.1)  # 模拟实时音频输入
            print("音频发送完毕,发送结束信号。")
            await websocket.send("close")
            # 等待服务器的最终确认或关闭连接
            response = await websocket.recv()
            print(f"收到服务器最终响应: {response}")
        except websockets.exceptions.ConnectionClosed as e:
            print(f"连接已关闭: {e}")
        except Exception as e:
            print(f"发生错误: {e}")


if __name__ == "__main__":
    asyncio.run(send_audio())
```
### 3.2 Python Receiver 示例 (接收识别结果)
这个脚本会连接到指定的会话,并持续打印服务器广播的识别结果。
```python
import asyncio
import websockets

# --- 配置 ---
# !!! 必须和 Sender 使用相同的 SESSION_ID !!!
SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=receiver"
SESSION_ID = "在此处粘贴你的Sender会话ID"


async def receive_results():
    """Connect as a receiver and print every result the server broadcasts."""
    # Guard clause: refuse to run with the placeholder session id.
    if "粘贴你的Sender会话ID" in SESSION_ID:
        print("错误:请先设置有效的 SESSION_ID")
        return
    endpoint = SERVER_URI.format(session_id=SESSION_ID)
    print(f"作为 Receiver 连接到: {endpoint}")
    async with websockets.connect(endpoint) as ws:
        try:
            print("等待接收识别结果...")
            # Loop until the server closes the connection.
            while True:
                incoming = await ws.recv()
                print(f"收到结果: {incoming}")
        except websockets.exceptions.ConnectionClosed as closed:
            print(f"连接已关闭: {closed.code} {closed.reason}")


if __name__ == "__main__":
    asyncio.run(receive_results())
```
## 4. JavaScript 客户端示例 (浏览器)
这个示例展示了如何在网页上通过麦克风获取音频,并将其作为 `sender` 发送。
```html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>WebSocket ASR Client</title>
</head>
<body>
<h1>FunASR WebSocket Client (Sender)</h1>
<p><strong>Session ID:</strong> <span id="sessionId"></span></p>
<button id="startButton">开始识别</button>
<button id="stopButton" disabled>停止识别</button>
<h2>识别结果:</h2>
<div id="results"></div>
<script>
const startButton = document.getElementById('startButton');
const stopButton = document.getElementById('stopButton');
const resultsDiv = document.getElementById('results');
const sessionIdSpan = document.getElementById('sessionId');
let websocket;
let audioContext;
let scriptProcessor;
let mediaStream;
const CHUNK_DURATION_MS = 100; // 每100ms发送一次数据
const SAMPLE_RATE = 16000;
// 生成一个简单的UUID
// Generate an RFC 4122 version-4 UUID string.
// Idiom: concatenating the numbers [1e7], -1e3, -4e3, -8e3, -1e11 yields the
// template "10000000-1000-4000-8000-100000000000"; every digit 0/1/8 in it is
// then replaced by a crypto-random hex digit, masked (& 15 >> c / 4) so the
// version nibble stays 4 and the variant bits stay valid.
function generateUUID() {
    return ([1e7]+-1e3+-4e3+-8e3+-1e11).replace(/[018]/g, c =>
        (c ^ crypto.getRandomValues(new Uint8Array(1))[0] & 15 >> c / 4).toString(16)
    );
}
// Open a sender WebSocket for a fresh session, capture microphone audio and
// stream it to the server as raw float32 samples. Recognition results pushed
// back by the server are appended to the results pane.
async function startRecording() {
    const sessionId = generateUUID();
    sessionIdSpan.textContent = sessionId;
    const wsUrl = `ws://${window.location.host}/ws/asr/${sessionId}?mode=sender`;
    websocket = new WebSocket(wsUrl);
    websocket.onopen = () => {
        console.log("WebSocket 连接已打开");
        startButton.disabled = true;
        stopButton.disabled = false;
        resultsDiv.innerHTML = '';
    };
    websocket.onmessage = (event) => {
        console.log("收到消息:", event.data);
        const result = JSON.parse(event.data);
        const asrText = result.asr || '';
        const spkName = result.spk ? result.spk.speaker_name : 'Unknown';
        resultsDiv.innerHTML += `<p><strong>${spkName}:</strong> ${asrText}</p>`;
    };
    websocket.onclose = () => {
        console.log("WebSocket 连接已关闭");
        stopRecording();
    };
    websocket.onerror = (error) => {
        console.error("WebSocket 错误:", error);
        alert("WebSocket 连接失败!");
        stopRecording();
    };
    try {
        mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true, video: false });
        audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });
        const source = audioContext.createMediaStreamSource(mediaStream);
        // createScriptProcessor requires a power-of-two buffer size in
        // [256, 16384]; the previous value (3200) violates the Web Audio
        // spec. 4096 frames ≈ 256 ms at 16 kHz.
        const bufferSize = 4096;
        scriptProcessor = audioContext.createScriptProcessor(bufferSize, 1, 1);
        scriptProcessor.onaudioprocess = (e) => {
            if (websocket && websocket.readyState === WebSocket.OPEN) {
                const inputData = e.inputBuffer.getChannelData(0);
                // The server adapter decodes the payload as float32
                // (np.frombuffer(..., dtype=np.float32)), so send the
                // samples unconverted. Copy first: the engine may reuse
                // the channel-data buffer between callbacks.
                websocket.send(new Float32Array(inputData).buffer);
            }
        };
        source.connect(scriptProcessor);
        scriptProcessor.connect(audioContext.destination);
    } catch (err) {
        console.error("无法获取麦克风:", err);
        alert("无法获取麦克风权限!");
        if (websocket) websocket.close();
    }
}
// Tear down the recording session: notify the server the stream is done,
// release the microphone and audio graph, and reset the buttons.
function stopRecording() {
    // Only signal "close" while the socket is actually open.
    if (websocket && websocket.readyState === WebSocket.OPEN) {
        websocket.send("close");
    }
    // Stop every microphone track we acquired.
    if (mediaStream) {
        for (const track of mediaStream.getTracks()) {
            track.stop();
        }
    }
    // Detach the processor node and shut the audio context down.
    if (scriptProcessor) {
        scriptProcessor.disconnect();
    }
    if (audioContext) {
        audioContext.close();
    }
    // Back to the idle UI state.
    startButton.disabled = false;
    stopButton.disabled = true;
}
startButton.addEventListener('click', startRecording);
stopButton.addEventListener('click', stopRecording);
</script>
</body>
</html>
```

19
main.py
View File

@ -1,12 +1,19 @@
from src.server import app
import uvicorn import uvicorn
from datetime import datetime from src.server import app
from src.utils.logger import get_module_logger, setup_root_logger from src.utils.logger import get_module_logger, setup_root_logger
from datetime import datetime
time = format(datetime.now(), "%Y-%m-%d %H:%M:%S") time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
setup_root_logger(level="DEBUG", log_file=f"logs/fastapiserver_{time}.log") setup_root_logger(level="INFO", log_file=f"logs/main_{time}.log")
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000) logger.info("启动 FunASR FastAPI 服务器...")
# 在生产环境中推荐使用更强大的ASGI服务器如Gunicorn并配合Uvicorn workers。
# 例如: gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app
uvicorn.run(
app,
host="0.0.0.0",
port=8000,
log_level="info"
)

View File

@ -11,6 +11,7 @@ from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config from src.models import AudioBinary_data_list, AudioBinary_Config
from src.core.model_loader import ModelLoader from src.core.model_loader import ModelLoader
from src.config import DefaultConfig from src.config import DefaultConfig
import asyncio
from queue import Queue from queue import Queue
import soundfile import soundfile
import time import time
@ -57,6 +58,7 @@ class ASRRunner(RunnerBase):
# 输入队列 # 输入队列
self._input_queue: Queue = Queue() self._input_queue: Queue = Queue()
self._pipeline: Optional[ASRPipeline] = None self._pipeline: Optional[ASRPipeline] = None
self._task: Optional[asyncio.Task] = None
def set_name(self, name: str): def set_name(self, name: str):
self._name = name self._name = name
@ -76,7 +78,14 @@ class ASRRunner(RunnerBase):
self._pipeline.set_models(self._models) self._pipeline.set_models(self._models)
self._pipeline.set_audio_binary(self._audio_binary) self._pipeline.set_audio_binary(self._audio_binary)
self._pipeline.set_input_queue(self._input_queue) self._pipeline.set_input_queue(self._input_queue)
self._pipeline.set_callback(self.deal_message)
# --- 异步-同步桥梁 ---
# 创建一个线程安全的回调函数用于从Pipeline的线程中调用Runner的异步方法
loop = asyncio.get_running_loop()
def thread_safe_callback(message):
asyncio.run_coroutine_threadsafe(self.deal_message(message), loop)
self._pipeline.set_callback(thread_safe_callback)
self._pipeline.bake() self._pipeline.bake()
def append_receiver(self, receiver: WebSocketClient): def append_receiver(self, receiver: WebSocketClient):
@ -85,46 +94,63 @@ class ASRRunner(RunnerBase):
def delete_receiver(self, receiver: WebSocketClient): def delete_receiver(self, receiver: WebSocketClient):
self._receiver.remove(receiver) self._receiver.remove(receiver)
def deal_message(self, message: str): async def deal_message(self, message: str):
self.broadcast(message) await self.broadcast(message)
def broadcast(self, message: str): async def broadcast(self, message: str):
""" """
广播发送给所有接收者 广播发送给所有接收者
""" """
logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: %s", self._name, message) logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: 消息长度:%s", self._name, len(message))
for receiver in self._receiver: logger.info(f"SAR-{self._name} 的接收者列表: {self._receiver}")
receiver.send(message) tasks = [receiver.send(message) for receiver in self._receiver]
await asyncio.gather(*tasks)
def _run(self): async def _run(self):
""" """
运行SAR 运行SAR
""" """
self._pipeline.run() self._pipeline.run()
loop = asyncio.get_running_loop()
while True: while True:
data = self._sender.recv() try:
data = await self._sender.recv()
if data is None: if data is None:
# `None` is used as a signal to end the stream
await loop.run_in_executor(None, self._input_queue.put, None)
break break
# logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data)) # logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data))
self._input_queue.put(data) await loop.run_in_executor(None, self._input_queue.put, data)
self.stop() except Exception as e:
logger.error(f"[ASRRunner][SAR-{self._name}] _run loop error: {e}")
break
await self.stop()
def run(self): def run(self):
""" """
运行SAR 运行SAR
""" """
self._thread = Thread(target=self._run, name=f"[ASRRunner]SAR-{self._name}") self._task = asyncio.create_task(self._run())
self._thread.daemon = True
self._thread.start()
def stop(self): async def stop(self):
""" """
停止SAR 停止SAR
""" """
logger.info(f"Stopping SAR: {self._name}")
self._pipeline.stop() self._pipeline.stop()
for ws in self._receiver:
ws.close() # Close all receiver websockets
self._sender.close() receiver_tasks = [ws.close() for ws in self._receiver]
await asyncio.gather(*receiver_tasks, return_exceptions=True)
# Close the sender websocket
if self._sender:
await self._sender.close()
# Cancel the main task if it's still running
if self._task and not self._task.done():
self._task.cancel()
logger.info(f"SAR stopped: {self._name}")
def __init__(self,*args,**kwargs): def __init__(self,*args,**kwargs):
""" """
@ -195,9 +221,11 @@ class ASRRunner(RunnerBase):
return True return True
return False return False
def __del__(self) -> None: async def shutdown(self):
""" """
析构函数 优雅地关闭所有SAR会话
""" """
for sar in self._SAR_list: logger.info("Shutting down all SAR instances...")
sar.stop() tasks = [sar.stop() for sar in self._SAR_list]
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("All SAR instances have been shut down.")

3
src/runner/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .ASRRunner import ASRRunner
__all__ = ["ASRRunner"]

View File

@ -10,255 +10,80 @@ import json
import websockets import websockets
import ssl import ssl
import argparse import argparse
from config import parse_args from src.runner import ASRRunner
from models import load_models from src.config import DefaultConfig
from service import ASRService from src.config import AudioBinary_Config
from src.websockets.router import websocket_router
from src.core import ModelLoader
import uvicorn
from fastapi import FastAPI
from contextlib import asynccontextmanager
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
# 全局变量存储当前连接的WebSocket客户端 # 全局变量存储当前连接的WebSocket客户端
websocket_users = set() websocket_users = set()
# 使用 lifespan 上下文管理器来管理应用的生命周期
async def ws_reset(websocket): @asynccontextmanager
"""重置WebSocket连接状态并关闭连接""" async def lifespan(app: FastAPI):
print(f"重置WebSocket连接当前连接数: {len(websocket_users)}")
# 重置状态字典
websocket.status_dict_asr_online["cache"] = {}
websocket.status_dict_asr_online["is_final"] = True
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
websocket.status_dict_punc["cache"] = {}
# 关闭连接
await websocket.close()
async def clear_websocket():
"""清理所有WebSocket连接"""
for websocket in websocket_users:
await ws_reset(websocket)
websocket_users.clear()
async def ws_serve(websocket, path):
""" """
WebSocket服务主函数处理客户端连接和消息 在应用启动时加载模型和初始化ASRRunner
在应用关闭时优雅地关闭ASRRunner
参数:
websocket: WebSocket连接对象
path: 连接路径
""" """
frames = [] # 存储所有音频帧 logger.info("应用启动开始加载模型和初始化Runner...")
frames_asr = [] # 存储用于离线ASR的音频帧
frames_asr_online = [] # 存储用于在线ASR的音频帧
global websocket_users # 1. 加载模型
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端) # 这里的参数可以从配置文件或环境变量中获取
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
}
model_loader = ModelLoader()
models = model_loader.load_models(args)
# 添加到用户集合 # 2. 初始化 ASRRunner
websocket_users.add(websocket) _audio_config = AudioBinary_Config(
chunk_size=200, # ms
# 初始化连接状态 sample_rate=16000,
websocket.status_dict_asr = {} sample_width=2, # 16-bit
websocket.status_dict_asr_online = {"cache": {}, "is_final": False} channels=1,
websocket.status_dict_vad = {"cache": {}, "is_final": False}
websocket.status_dict_punc = {"cache": {}}
websocket.chunk_interval = 10
websocket.vad_pre_idx = 0
websocket.is_speaking = True # 默认用户正在说话
# 语音检测状态
speech_start = False
speech_end_i = -1
# 初始化配置
websocket.wav_name = "microphone"
websocket.mode = "2pass" # 默认使用两阶段识别模式
print("新用户已连接", flush=True)
try:
# 持续接收客户端消息
async for message in websocket:
# 处理JSON配置消息
if isinstance(message, str):
try:
messagejson = json.loads(message)
# 更新各种配置参数
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = (
not websocket.is_speaking
) )
if "chunk_interval" in messagejson: _audio_config.chunk_stride = int(_audio_config.chunk_size * _audio_config.sample_rate / 1000)
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
websocket.wav_name = messagejson.get("wav_name")
if "chunk_size" in messagejson:
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [
int(x) for x in chunk_size
]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
messagejson["encoder_chunk_look_back"]
)
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
messagejson["decoder_chunk_look_back"]
)
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
websocket.mode = messagejson["mode"]
except json.JSONDecodeError:
print(f"无效的JSON消息: {message}")
# 根据chunk_interval更新VAD的chunk_size asr_runner = ASRRunner()
websocket.status_dict_vad["chunk_size"] = int( asr_runner.set_default_config(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] audio_config=_audio_config,
* 60 models=models,
/ websocket.chunk_interval
) )
# 处理音频数据 # 3. 将 asr_runner 实例存储在 app.state 中
if ( app.state.asr_runner = asr_runner
len(frames_asr_online) > 0 logger.info("模型加载和Runner初始化完成。")
or len(frames_asr) >= 0
or not isinstance(message, str)
):
if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区
frames.append(message)
duration_ms = len(message) // 32 # 计算音频时长
websocket.vad_pre_idx += duration_ms
# 处理在线ASR yield
frames_asr_online.append(message)
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
# 达到chunk_interval或最终帧时处理在线ASR # --- 应用关闭时执行的代码 ---
if ( logger.info("应用关闭,开始清理资源...")
len(frames_asr_online) % websocket.chunk_interval == 0 await app.state.asr_runner.shutdown()
or websocket.status_dict_asr_online["is_final"] logger.info("资源清理完成。")
):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
await asr_service.async_asr_online(websocket, audio_in)
except Exception as e:
print(f"在线ASR处理错误: {e}")
frames_asr_online = []
# 如果检测到语音开始收集帧用于离线ASR # 初始化FastAPI应用并指定lifespan
if speech_start: app = FastAPI(lifespan=lifespan)
frames_asr.append(message)
# VAD处理 - 语音活动检测 # 挂载WebSocket路由
try: app.include_router(websocket_router, prefix="/ws")
speech_start_i, speech_end_i = await asr_service.async_vad(
websocket, message
)
except Exception as e:
print(f"VAD处理错误: {e}")
# 检测到语音开始
if speech_start_i != -1:
speech_start = True
# 计算开始偏移并收集前面的帧
beg_bias = (
websocket.vad_pre_idx - speech_start_i
) // duration_ms
frames_pre = (
frames[-beg_bias:] if beg_bias < len(frames) else frames
)
frames_asr = []
frames_asr.extend(frames_pre)
# 处理离线ASR (语音结束或用户停止说话)
if speech_end_i != -1 or not websocket.is_speaking:
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
try:
await asr_service.async_asr(websocket, audio_in)
except Exception as e:
print(f"离线ASR处理错误: {e}")
# 重置状态
frames_asr = []
speech_start = False
frames_asr_online = []
websocket.status_dict_asr_online["cache"] = {}
# 如果用户停止说话,完全重置
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
websocket.status_dict_vad["cache"] = {}
else:
# 保留最近的帧用于下一轮处理
frames = frames[-20:]
except websockets.ConnectionClosed:
print(f"连接已关闭...", flush=True)
await ws_reset(websocket)
websocket_users.remove(websocket)
except websockets.InvalidState:
print("无效的WebSocket状态...")
except Exception as e:
print(f"发生异常: {e}")
def start_server(args, asr_service_instance):
"""
启动WebSocket服务器
参数:
args: 命令行参数
asr_service_instance: ASR服务实例
"""
global asr_service
asr_service = asr_service_instance
# 配置SSL (如果提供了证书)
if args.certfile and len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
start_server = websockets.serve(
ws_serve,
args.host,
args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context,
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
)
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
# 启动事件循环
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
@app.get("/")
async def read_root():
return {"message": "FunASR-FastAPI WebSocket Server is running."}
# 如果需要直接运行此文件进行测试
if __name__ == "__main__": if __name__ == "__main__":
# 解析命令行参数 # 注意在生产环境中推荐使用Gunicorn + Uvicorn workers
args = parse_args() uvicorn.run(app, host="0.0.0.0", port=8000)
# 加载模型
print("正在加载模型...")
models = load_models(args)
print("模型加载完成!当前仅支持单个客户端同时连接!")
# 创建ASR服务
asr_service = ASRService(models)
# 启动服务器
start_server(args, asr_service)

View File

@ -1,131 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
ASR服务模块 - 提供语音识别相关的核心功能
"""
import json
class ASRService:
    """Wrap the pre-loaded FunASR models behind a small async API.

    Exposes voice-activity detection, offline ASR (with optional punctuation
    restoration) and streaming/online ASR. Per-connection state (caches,
    mode, speaking flag, wav name) lives on the websocket object itself, as
    populated by the server loop.
    """

    def __init__(self, models):
        """
        Args:
            models: dict of pre-loaded models with keys "asr",
                "asr_streaming", "vad" and "punc" ("punc" may be None).
        """
        self.model_asr = models["asr"]
        self.model_asr_streaming = models["asr_streaming"]
        self.model_vad = models["vad"]
        self.model_punc = models["punc"]

    @staticmethod
    def _result_message(websocket, text, online):
        """Build the JSON payload sent back to the client.

        Extracted helper: the offline and online paths previously duplicated
        this dict construction verbatim.
        """
        if "2pass" in websocket.mode:
            mode = "2pass-online" if online else "2pass-offline"
        else:
            mode = websocket.mode
        return json.dumps(
            {
                "mode": mode,
                "text": text,
                "wav_name": websocket.wav_name,
                "is_final": websocket.is_speaking,
            }
        )

    async def async_vad(self, websocket, audio_in):
        """Run voice-activity detection on one audio chunk.

        Args:
            websocket: connection object carrying status_dict_vad.
            audio_in: binary audio data.

        Returns:
            tuple: (speech_start, speech_end); -1 means "not detected".
        """
        segments_result = self.model_vad.generate(
            input=audio_in, **websocket.status_dict_vad
        )[0]["value"]
        speech_start = -1
        speech_end = -1
        # Only a single unambiguous segment is usable; zero or multiple
        # segments are treated as "nothing detected".
        if len(segments_result) == 0 or len(segments_result) > 1:
            return speech_start, speech_end
        if segments_result[0][0] != -1:
            speech_start = segments_result[0][0]
        if segments_result[0][1] != -1:
            speech_end = segments_result[0][1]
        return speech_start, speech_end

    async def async_asr(self, websocket, audio_in):
        """Offline (full-utterance) recognition.

        Sends the recognized text to the client; when audio_in is empty an
        empty-text message is still sent so the client gets a final marker.
        When audio is present but nothing was recognized, stays silent
        (matching the original behavior).
        """
        text = ""
        if len(audio_in) > 0:
            rec_result = self.model_asr.generate(
                input=audio_in, **websocket.status_dict_asr
            )[0]
            # Restore punctuation when a punc model is available.
            if self.model_punc is not None and len(rec_result["text"]) > 0:
                rec_result = self.model_punc.generate(
                    input=rec_result["text"], **websocket.status_dict_punc
                )[0]
            text = rec_result["text"]
            if len(text) == 0:
                return  # recognized nothing: no message for this chunk
        await websocket.send(self._result_message(websocket, text, online=False))

    async def async_asr_online(self, websocket, audio_in):
        """Streaming recognition on one chunk; silent on empty input/result."""
        if len(audio_in) == 0:
            return
        rec_result = self.model_asr_streaming.generate(
            input=audio_in, **websocket.status_dict_asr_online
        )[0]
        # In 2pass mode the final chunk is left to the offline pass.
        if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
            "is_final", False
        ):
            return
        if len(rec_result["text"]):
            await websocket.send(
                self._result_message(websocket, rec_result["text"], online=True)
            )

View File

@ -1,3 +1,6 @@
from .logger import get_module_logger, setup_root_logger from .logger import get_module_logger, setup_root_logger
__all__ = ["get_module_logger", "setup_root_logger"] __all__ = [
"get_module_logger",
"setup_root_logger",
]

63
src/websockets/adapter.py Normal file
View File

@ -0,0 +1,63 @@
import numpy as np
from fastapi import WebSocket
from typing import Union
import uuid
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class FastAPIWebSocketAdapter:
    """Adapt a FastAPI WebSocket to the interface ASRRunner expects.

    ASRRunner awaits ``recv()`` for numpy audio chunks (``None`` meaning
    end-of-stream) and awaits ``send()`` / ``close()`` on the way out.
    """

    def __init__(self, websocket: "WebSocket", sample_rate: int = 16000, sample_width: int = 2):
        # sample_rate / sample_width are kept for reference only; decoding
        # below assumes the client sends raw float32 samples.
        self._ws = websocket
        self._sample_rate = sample_rate
        self._sample_width = sample_width
        self._total_received = 0  # running byte counter for progress output

    async def recv(self) -> Union[np.ndarray, None]:
        """Receive one message from the client.

        Returns:
            np.ndarray: float32 samples for a binary frame;
            None: the client sent the text "close" or disconnected;
            empty array: any other message (caller should ignore it).
        """
        message = await self._ws.receive()
        # A disconnect event must end the stream; otherwise the caller would
        # keep polling a dead socket.
        if message.get("type") == "websocket.disconnect":
            print("\n🔌 [Adapter] 客户端断开连接。")
            return None
        bytes_data = message.get("bytes")
        # Per the ASGI spec a receive event may carry 'bytes': None, so a
        # plain `'bytes' in message` membership check is not sufficient.
        if bytes_data is not None:
            # The sender transmits float32 samples via ndarray.tobytes().
            audio_array = np.frombuffer(bytes_data, dtype=np.float32)
            self._total_received += len(bytes_data)
            # \r keeps the progress report on a single console line.
            print(f"🎧 [Adapter] 正在接收音频... 总计: {self._total_received / 1024:.2f} KB", end='\r')
            return audio_array
        text_data = message.get("text")
        if text_data is not None and text_data.lower() == 'close':
            print("\n🏁 [Adapter] 收到 'close' 信号。")  # newline after the \r progress line
            return None  # None signals end-of-stream to the runner
        return np.array([])  # ignore any other message type

    async def send(self, message: dict):
        """Send a dict to the client as JSON.

        UUID objects are stringified first because ``uuid.UUID`` is not
        JSON-serializable.
        """
        def convert_uuids(obj):
            if isinstance(obj, dict):
                return {k: convert_uuids(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert_uuids(elem) for elem in obj]
            elif isinstance(obj, uuid.UUID):
                return str(obj)
            return obj
        serializable_message = convert_uuids(message)
        logger.info(f"[Adapter] 发送消息: {serializable_message}")
        await self._ws.send_json(serializable_message)

    async def close(self):
        """Close the underlying WebSocket connection."""
        await self._ws.close()

View File

@ -0,0 +1,89 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Request
from src.websockets.adapter import FastAPIWebSocketAdapter
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
router = APIRouter()
@router.websocket("/asr/{session_id}")
async def asr_websocket_endpoint(
    websocket: WebSocket,
    session_id: str,
    mode: str = Query(default="sender", enum=["sender", "receiver"])
):
    """
    ASR WebSocket endpoint.

    - **session_id**: unique ID identifying one recognition session.
    - **mode**: client role.
        - `sender`: joins as the audio producer and creates a new session.
        - `receiver`: joins as a result consumer of an existing session.
    """
    await websocket.accept()
    # Grab the application-wide ASRRunner instance from app.state
    # (populated during application startup).
    asr_runner = websocket.app.state.asr_runner
    # Build the adapter bridging FastAPI's WebSocket to the runner.
    # NOTE(review): this reuses the runner's default audio config, so the
    # two are assumed consistent — confirm if configs ever diverge.
    audio_config = asr_runner._default_audio_config
    adapter = FastAPIWebSocketAdapter(
        websocket,
        sample_rate=audio_config.sample_rate,
        sample_width=audio_config.sample_width
    )
    if mode == "sender":
        logger.info(f"客户端 {websocket.client} 作为 'sender' 加入会话: {session_id}")
        # Create a new SAR (sender-and-receivers) session for this client.
        sar_id = asr_runner.new_SAR(ws=adapter, name=session_id)
        if sar_id is None:
            logger.error(f"为会话 {session_id} 创建SAR失败")
            await websocket.close(code=1011, reason="Failed to create ASR session")
            return
        sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
        try:
            # Keep this endpoint coroutine alive until the background task
            # finishes. The actual receive loop runs in the SAR's _run
            # method, started by new_SAR as a background task. When the
            # client disconnects, adapter.recv() raises, the _run task
            # catches it, stops, cleans up, and the awaited task completes.
            if sar and sar._task:
                await sar._task
            else:
                # The task was never created: report and close the socket.
                logger.error(f"SAR任务未能在会话 {session_id} 中启动")
                await websocket.close(code=1011, reason="Failed to start ASR task")
        except Exception as e:
            # Catch-all so an unexpected error cannot escape the endpoint.
            logger.error(f"会话 {session_id}'sender' 端点发生未知错误: {e}")
        finally:
            logger.info(f"'sender' {websocket.client} 在会话 {session_id} 的连接处理已结束")
    elif mode == "receiver":
        logger.info(f"客户端 {websocket.client} 作为 'receiver' 加入会话: {session_id}")
        # Subscribe to an already-existing SAR session.
        joined = asr_runner.join_SAR(ws=adapter, name=session_id)
        if not joined:
            logger.warning(f"无法找到会话 {session_id},'receiver' {websocket.client} 加入失败")
            await websocket.close(code=1011, reason=f"Session '{session_id}' not found")
            return
        try:
            # A receiver only needs to keep the connection open and wait for
            # broadcasts from the SAR; this loop also detects disconnects.
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            logger.info(f"'receiver' {websocket.client} 在会话 {session_id} 中断开连接")
            # Remove the disconnected receiver from the SAR's receiver list.
            sar = next((s for s in asr_runner._SAR_list if s._name == session_id), None)
            if sar:
                sar.delete_receiver(adapter)
                logger.info(f"已从会话 {session_id} 中移除 'receiver' {websocket.client}")
    else:
        # Unreachable in practice: FastAPI's enum validation on `mode`
        # rejects anything but "sender"/"receiver".
        logger.error(f"无效的模式: {mode}")
        await websocket.close(code=1003, reason="Invalid mode specified")

View File

@ -1,3 +1,9 @@
from endpoint import asr_router from fastapi import APIRouter
from .endpoint import asr_endpoint
__all__ = ["asr_router"] websocket_router = APIRouter()
# 包含ASR端点路由
websocket_router.include_router(asr_endpoint.router)
__all__ = ["websocket_router"]

View File

@ -6,6 +6,7 @@
from tests.pipeline.asr_test import test_asr_pipeline from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger from src.utils.logger import get_module_logger, setup_root_logger
from tests.runner.asr_runner_test import test_asr_runner from tests.runner.asr_runner_test import test_asr_runner
import asyncio
setup_root_logger(level="INFO", log_file="logs/test_main.log") setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
@ -22,4 +23,4 @@ with open("logs/test_main.log", "w") as f:
# test_asr_pipeline() # test_asr_pipeline()
logger.info("开始测试ASRRunner") logger.info("开始测试ASRRunner")
test_asr_runner() asyncio.run(test_asr_runner())

View File

@ -1,30 +1,59 @@
""" """
ASRRunner test ASRRunner test
""" """
import queue import asyncio
import time
import soundfile import soundfile
import numpy as np import numpy as np
from src.runner.ASRRunner import ASRRunner from src.runner.ASRRunner import ASRRunner
from src.core.model_loader import ModelLoader from src.core.model_loader import ModelLoader
from src.models import AudioBinary_Config from src.models import AudioBinary_Config
from src.utils.mock_websocket import MockWebSocketClient from asyncio import Queue as AsyncQueue
from src.utils.logger import get_module_logger from src.utils.logger import get_module_logger
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
def test_asr_runner():
class AsyncMockWebSocketClient:
    """In-memory stand-in for a WebSocket peer, driven by two asyncio queues.

    The test pushes inbound data with put_for_recv() and reads outbound
    results with get_from_send(); ASRRunner talks to recv()/send()/close()
    exactly as it would to a real connection.
    """

    def __init__(self):
        # Inbound (test -> runner) and outbound (runner -> test) channels.
        self._recv_q = asyncio.Queue()
        self._send_q = asyncio.Queue()

    def put_for_recv(self, item):
        """Feed one inbound item (audio chunk, or None as end-of-stream)."""
        self._recv_q.put_nowait(item)

    async def get_from_send(self):
        """Await the next result the runner has published."""
        return await self._send_q.get()

    async def recv(self):
        """Called by ASRRunner to pull the next inbound item."""
        return await self._recv_q.get()

    async def send(self, item):
        """Called (via callback) by ASRRunner to publish a result."""
        logger.info(f"Mock WS 收到结果: {item}")
        await self._send_q.put(item)

    async def close(self):
        """No-op close, kept for interface compatibility."""
        pass
async def test_asr_runner():
""" """
End-to-end test for ASRRunner. 针对ASRRunner的端到端测试已适配异步操作
1. Loads models. 1. 加载模型.
2. Configures and initializes ASRRunner. 2. 配置并初始化ASRRunner.
3. Creates a mock WebSocket client. 3. 创建一个异步的模拟WebSocket客户端.
4. Starts a new SenderAndReceiver (SAR) instance in the runner. 4. 在Runner中启动一个新的SenderAndReceiver (SAR)实例.
5. Streams audio data via the mock WebSocket. 5. 通过模拟的WebSocket流式传输音频数据.
6. Asserts that the received transcription matches the expected text. 6. 等待处理任务完成并断言其无错误运行.
""" """
# 1. Load models # 1. 加载模型
model_loader = ModelLoader() model_loader = ModelLoader()
args = { args = {
"asr_model": "paraformer-zh", "asr_model": "paraformer-zh",
@ -33,48 +62,71 @@ def test_asr_runner():
"vad_model_revision": "v2.0.4", "vad_model_revision": "v2.0.4",
"spk_model": "cam++", "spk_model": "cam++",
"spk_model_revision": "v2.0.2", "spk_model_revision": "v2.0.2",
"audio_update": False,
} }
models = model_loader.load_models(args) models = model_loader.load_models(args)
audio_file_path = "tests/XT_ZZY_denoise.wav" audio_file_path = "tests/XT_ZZY_denoise.wav"
audio_data, sample_rate = soundfile.read(audio_file_path) audio_data, sample_rate = soundfile.read(audio_file_path)
logger.info(f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}") logger.info(
# 2. Configure audio f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
)
# 进一步详细打印audio_data数据类型
# 详细打印audio_data的类型和结构信息便于调试
logger.info(f"audio_data 类型: {type(audio_data)}")
logger.info(f"audio_data dtype: {getattr(audio_data, 'dtype', '未知')}")
logger.info(f"audio_data shape: {getattr(audio_data, 'shape', '未知')}")
logger.info(f"audio_data ndim: {getattr(audio_data, 'ndim', '未知')}")
logger.info(f"audio_data 示例前10个值: {audio_data[:10] if hasattr(audio_data, '__getitem__') else '不可切片'}")
# 2. 配置音频
audio_config = AudioBinary_Config( audio_config = AudioBinary_Config(
chunk_size=200, # ms chunk_size=200, # ms
chunk_stride=1000, # 10ms stride for 16kHz
sample_rate=sample_rate, sample_rate=sample_rate,
sample_width=2, # 16-bit sample_width=2, # 16-bit
channels=2, channels=1,
) )
audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000) audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
# 3. Setup ASRRunner # 3. 设置ASRRunner
asr_runner = ASRRunner() asr_runner = ASRRunner()
asr_runner.set_default_config( asr_runner.set_default_config(
audio_config=audio_config, audio_config=audio_config,
models=models, models=models,
) )
# 4. Create Mock WebSocket and start SAR # 4. 创建模拟WebSocket并启动SAR
mock_ws = MockWebSocketClient() mock_ws = AsyncMockWebSocketClient()
sar_id = asr_runner.new_SAR( sar_id = asr_runner.new_SAR(
ws=mock_ws, ws=mock_ws,
name="test_sar", name="test_sar",
) )
assert sar_id is not None, "Failed to create a new SAR instance" assert sar_id is not None, "创建新的SAR实例失败"
# 5. Simulate streaming audio # 获取SAR实例以等待其任务
print(f"Sending audio data of length {len(audio_data)} samples.") sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
audio_clip_len = 200 assert sar is not None, "无法从Runner中获取SAR实例。"
assert sar._task is not None, "SAR任务未被创建。"
# 5. 在后台任务中模拟流式音频
async def feed_audio():
logger.info("Feeder任务已启动开始流式传输音频数据...")
# 每次发送100ms的音频
audio_clip_len = int(sample_rate * 0.1)
for i in range(0, len(audio_data), audio_clip_len): for i in range(0, len(audio_data), audio_clip_len):
chunk = audio_data[i : i + audio_clip_len] chunk = audio_data[i : i + audio_clip_len]
if not isinstance(chunk, np.ndarray) or chunk.size == 0: if chunk.size == 0:
break break
# Simulate receiving binary data over WebSocket
mock_ws.put_for_recv(chunk) mock_ws.put_for_recv(chunk)
await asyncio.sleep(0.1) # 模拟实时流
# 6. Wait for results and assert # 发送None来表示音频流结束
time.sleep(30)
# Signal end of audio stream by sending None
mock_ws.put_for_recv(None) mock_ws.put_for_recv(None)
logger.info("Feeder任务已完成所有音频数据已发送。")
feeder_task = asyncio.create_task(feed_audio())
# 6. 等待SAR处理完成
# SAR任务在从模拟WebSocket接收到None后会结束
await sar._task
await feeder_task # 确保feeder也已完成
logger.info("ASRRunner测试成功完成。")

View File

@ -0,0 +1,110 @@
import asyncio
import websockets
import soundfile as sf
import uuid
# --- Configuration ---
HOST = "localhost"
PORT = 8000
# One fresh session per script run; sender and receiver must share this ID.
SESSION_ID = str(uuid.uuid4())
SENDER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=sender"
RECEIVER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=receiver"
AUDIO_FILE_PATH = "tests/XT_ZZY_denoise.wav" # Test fixture must exist; expected to be 16 kHz, 16-bit, mono
CHUNK_DURATION_MS = 100 # Send 100 ms of audio per chunk
# NOTE(review): the "* 2" assumes 2 bytes/sample (16-bit PCM), but run_sender
# streams float32 (4 bytes/sample). 3200/2 = 1600 samples is still 100 ms at
# 16 kHz, so timing works, but the byte count is misleading — confirm intent.
CHUNK_SIZE = int(16000 * 2 * CHUNK_DURATION_MS / 1000) # 3200 bytes
async def run_receiver():
    """Connect as a result receiver and print every message the server broadcasts."""
    print(f"▶️ [Receiver] 尝试连接到: {RECEIVER_URI}")
    try:
        async with websockets.connect(RECEIVER_URI) as ws:
            print("✅ [Receiver] 连接成功,等待消息...")
            try:
                # Drain broadcast messages until the server closes the session.
                while True:
                    incoming = await ws.recv()
                    print(f"🎧 [Receiver] 收到结果: {incoming}")
            except websockets.exceptions.ConnectionClosed as closed:
                # Expected shutdown path: server closes after the sender finishes.
                print(f"✅ [Receiver] 连接已由服务器正常关闭: {closed.reason}")
    except Exception as err:
        print(f"❌ [Receiver] 连接失败: {err}")
async def run_sender():
    """
    Connect as the audio sender: stream the local WAV file to the server in
    real-time-sized chunks while a parallel sub-task listens for the results
    the server broadcasts back on the same connection.

    Bug fix: on every failure path (missing file, wrong sample rate, send
    error) the listener sub-task is now cancelled before being awaited.
    Previously `await receiver_sub_task` ran unconditionally, but since the
    server only closes the connection after receiving the "close" signal,
    the listener would block on recv() forever and the script hung.
    """
    await asyncio.sleep(1)  # give the receiver a chance to connect first
    print(f"▶️ [Sender] 尝试连接到: {SENDER_URI}")
    try:
        async with websockets.connect(SENDER_URI) as websocket:
            print("✅ [Sender] 连接成功。")

            # --- Parallel task: listen for broadcast results ---
            async def receive_task():
                print("▶️ [Sender-Receiver] 开始监听广播消息...")
                try:
                    while True:
                        message = await websocket.recv()
                        print(f"🎧 [Sender-Receiver] 收到结果: {message}")
                except websockets.exceptions.ConnectionClosed:
                    print("✅ [Sender-Receiver] 连接已关闭,停止监听。")

            receiver_sub_task = asyncio.create_task(receive_task())

            # --- Main task: stream the audio file ---
            stream_completed = False  # True only once the "close" signal was sent
            try:
                print("▶️ [Sender] 准备发送音频...")
                audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='float32')
                if sample_rate != 16000:
                    print(f"❌ [Sender] 错误:音频文件采样率必须是 16kHz。")
                    return  # finally-block cancels and reaps the listener
                total_samples = len(audio_data)
                # NOTE(review): CHUNK_SIZE assumes 2 bytes/sample, but the data is
                # float32; 1600 samples still equals 100 ms at 16 kHz — confirm.
                chunk_samples = CHUNK_SIZE // 2
                samples_sent = 0
                print(f"音频加载成功,总长度: {total_samples} samples。开始分块发送...")
                for i in range(0, total_samples, chunk_samples):
                    chunk = audio_data[i:i + chunk_samples]
                    if len(chunk) == 0:
                        break
                    # Raw float32 bytes, matching what the server endpoint expects.
                    await websocket.send(chunk.tobytes())
                    samples_sent += len(chunk)
                    print(f"🎧 [Sender] 正在发送: {samples_sent}/{total_samples} samples", end="\r")
                    await asyncio.sleep(CHUNK_DURATION_MS / 1000)  # pace like a live stream
                print()
                print("🏁 [Sender] 音频流发送完毕,发送 'close' 信号。")
                await websocket.send("close")
                stream_completed = True
            except FileNotFoundError:
                print(f"❌ [Sender] 错误:找不到音频文件 {AUDIO_FILE_PATH}")
            except Exception as e:
                print(f"❌ [Sender] 发送过程中发生错误: {e}")
            finally:
                if not stream_completed:
                    # Without the "close" signal the server never closes the
                    # connection, so the listener must be cancelled or we hang.
                    receiver_sub_task.cancel()
                try:
                    # Wait for the listener to wind down (it exits when the
                    # server closes the connection, or on cancellation).
                    await receiver_sub_task
                except asyncio.CancelledError:
                    pass
    except Exception as e:
        print(f"❌ [Sender] 连接失败: {e}")
async def main():
    """Run the sender and receiver tasks together for one end-to-end test."""
    print("--- 开始 WebSocket ASR 服务端到端测试 ---")
    print(f"会话 ID: {SESSION_ID}")
    # Start the audio producer first; the listener joins the session later.
    producer = asyncio.create_task(run_sender())
    # NOTE(review): the receiver connects 7 s after the sender starts, so any
    # results broadcast before then are never seen by it — confirm intent.
    await asyncio.sleep(7)
    consumer = asyncio.create_task(run_receiver())
    # Block until both sides have finished.
    await asyncio.gather(consumer, producer)
    print("--- 测试结束 ---")
if __name__ == "__main__":
    # Before running this script, make sure the FastAPI server is running:
    # python main.py
    asyncio.run(main())