From 3083738db473ab2538182c2894b2472d91438914 Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Wed, 2 Jul 2025 15:49:21 +0800 Subject: [PATCH] =?UTF-8?q?[fastAPI]=E5=AE=8C=E6=88=90fastAPI-websocket?= =?UTF-8?q?=E7=AB=AF=E5=8F=A3=E6=90=AD=E5=BB=BA=E4=B8=8E=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=EF=BC=8C=E6=8E=A5=E6=94=B6float32=E7=9A=84tobytes=E5=AD=97?= =?UTF-8?q?=E8=8A=82=E6=B5=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- WEBSOCKET_API.md | 265 +++++++++++++++++++++ main.py | 19 +- src/runner/ASRRunner.py | 78 +++++-- src/runner/__init__.py | 3 + src/server.py | 299 +++++------------------- src/service.py | 131 ----------- src/utils/__init__.py | 5 +- src/websockets/adapter.py | 63 +++++ src/websockets/endpoint/asr_endpoint.py | 89 +++++++ src/websockets/router.py | 10 +- test_main.py | 3 +- tests/runner/asr_runner_test.py | 120 +++++++--- tests/websocket/websocket_asr.py | 110 +++++++++ 13 files changed, 758 insertions(+), 437 deletions(-) create mode 100644 WEBSOCKET_API.md create mode 100644 src/runner/__init__.py delete mode 100644 src/service.py create mode 100644 src/websockets/adapter.py create mode 100644 src/websockets/endpoint/asr_endpoint.py create mode 100644 tests/websocket/websocket_asr.py diff --git a/WEBSOCKET_API.md b/WEBSOCKET_API.md new file mode 100644 index 0000000..cf7e97e --- /dev/null +++ b/WEBSOCKET_API.md @@ -0,0 +1,265 @@ +# FunASR-FastAPI WebSocket API 文档 + +本文档详细介绍了如何连接和使用 FunASR-FastAPI 实时语音识别服务的 WebSocket 接口。 + +## 1. 连接端点 (Endpoint) + +服务的 WebSocket 端点 URL 格式如下: + +``` +ws://:8000/ws/asr/{session_id}?mode= +``` + +### 参数说明 + +- **`{session_id}`** (路径参数, `str`, **必需**): + 用于唯一标识一个识别会话(例如,一场会议或一次直播)。所有属于同一次会话的客户端都应使用相同的 `session_id`。 + +- **`mode`** (查询参数, `str`, **必需**): + 定义客户端的角色。 + - `sender`: 音频发送者。一个会话中应该只有一个 `sender`。此客户端负责将实时音频流发送到服务器。 + - `receiver`: 结果接收者。一个会话中可以有多个 `receiver`。此客户端只接收由服务器广播的识别结果,不发送音频。 + +## 2. 
数据格式 + +### 2.1 发送数据 (Sender -> Server) + +- **音频格式**: `sender` 必须发送原始的 **PCM 音频数据**。 + - **采样率**: 16000 Hz + - **位深**: 16-bit (signed integer) + - **声道数**: 单声道 (Mono) +- **传输格式**: 必须以**二进制 (bytes)** 格式发送。 +- **结束信号**: 当音频流结束时,`sender` 应发送一个**文本消息** `"close"` 来通知服务器关闭会话。 + +### 2.2 接收数据 (Server -> Receiver) + +服务器会将识别结果以 **JSON 文本** 格式广播给会话中的所有 `receiver`(以及 `sender` 自己)。JSON 对象的结构示例如下: + +```json +{ + "asr": "你好,世界。", + "spk": { + "speaker_id": "uuid-of-the-speaker", + "speaker_name": "SpeakerName", + "score": 0.98 + } +} +``` + +## 3. Python 客户端示例 + +需要安装 `websockets` 库: `pip install websockets` + +### 3.1 Python Sender 示例 (发送本地音频文件) + +这个脚本会读取一个 WAV 文件,并将其内容以流式方式发送到服务器。 + +```python +import asyncio +import websockets +import soundfile as sf +import uuid + +# --- 配置 --- +SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=sender" +SESSION_ID = str(uuid.uuid4()) # 为这次会话生成一个唯一的ID +AUDIO_FILE = "tests/XT_ZZY_denoise.wav" # 替换为你的音频文件路径 +CHUNK_SIZE = 3200 # 每次发送 100ms 的音频数据 (16000 * 2 * 0.1) + +async def send_audio(): + """连接到服务器,并流式发送音频文件""" + uri = SERVER_URI.format(session_id=SESSION_ID) + print(f"作为 Sender 连接到: {uri}") + + async with websockets.connect(uri) as websocket: + try: + # 读取音频文件 + with sf.SoundFile(AUDIO_FILE, 'r') as f: + assert f.samplerate == 16000, "音频文件采样率必须为 16kHz" + assert f.channels == 1, "音频文件必须为单声道" + + print("开始发送音频...") + while True: + data = f.read(CHUNK_SIZE, dtype='int16') + if not data.any(): + break + # 将 numpy 数组转换为原始字节流 + await websocket.send(data.tobytes()) + await asyncio.sleep(0.1) # 模拟实时音频输入 + + print("音频发送完毕,发送结束信号。") + await websocket.send("close") + + # 等待服务器的最终确认或关闭连接 + response = await websocket.recv() + print(f"收到服务器最终响应: {response}") + + except websockets.exceptions.ConnectionClosed as e: + print(f"连接已关闭: {e}") + except Exception as e: + print(f"发生错误: {e}") + +if __name__ == "__main__": + asyncio.run(send_audio()) +``` + +### 3.2 Python Receiver 示例 (接收识别结果) + +这个脚本会连接到指定的会话,并持续打印服务器广播的识别结果。 + +```python +import 
asyncio +import websockets + +# --- 配置 --- +# !!! 必须和 Sender 使用相同的 SESSION_ID !!! +SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=receiver" +SESSION_ID = "在此处粘贴你的Sender会话ID" + +async def receive_results(): + """连接到服务器并接收识别结果""" + if "粘贴你的Sender会话ID" in SESSION_ID: + print("错误:请先设置有效的 SESSION_ID!") + return + + uri = SERVER_URI.format(session_id=SESSION_ID) + print(f"作为 Receiver 连接到: {uri}") + + async with websockets.connect(uri) as websocket: + try: + print("等待接收识别结果...") + while True: + message = await websocket.recv() + print(f"收到结果: {message}") + except websockets.exceptions.ConnectionClosed as e: + print(f"连接已关闭: {e.code} {e.reason}") + +if __name__ == "__main__": + asyncio.run(receive_results()) +``` + +## 4. JavaScript 客户端示例 (浏览器) + +这个示例展示了如何在网页上通过麦克风获取音频,并将其作为 `sender` 发送。 + +```html + + + + + WebSocket ASR Client + + +

FunASR WebSocket Client (Sender)

+

Session ID:

+ + +

识别结果:

+
+ + + + +``` \ No newline at end of file diff --git a/main.py b/main.py index 8135688..ffa7a78 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,19 @@ -from src.server import app import uvicorn -from datetime import datetime +from src.server import app from src.utils.logger import get_module_logger, setup_root_logger +from datetime import datetime -time = format(datetime.now(), "%Y-%m-%d %H:%M:%S") -setup_root_logger(level="DEBUG", log_file=f"logs/fastapiserver_{time}.log") +time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") +setup_root_logger(level="INFO", log_file=f"logs/main_{time}.log") logger = get_module_logger(__name__) - if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + logger.info("启动 FunASR FastAPI 服务器...") + # 在生产环境中,推荐使用更强大的ASGI服务器,如Gunicorn,并配合Uvicorn workers。 + # 例如: gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app + uvicorn.run( + app, + host="0.0.0.0", + port=8000, + log_level="info" + ) \ No newline at end of file diff --git a/src/runner/ASRRunner.py b/src/runner/ASRRunner.py index e7898b9..dd50257 100644 --- a/src/runner/ASRRunner.py +++ b/src/runner/ASRRunner.py @@ -11,6 +11,7 @@ from src.pipeline import PipelineFactory from src.models import AudioBinary_data_list, AudioBinary_Config from src.core.model_loader import ModelLoader from src.config import DefaultConfig +import asyncio from queue import Queue import soundfile import time @@ -57,6 +58,7 @@ class ASRRunner(RunnerBase): # 输入队列 self._input_queue: Queue = Queue() self._pipeline: Optional[ASRPipeline] = None + self._task: Optional[asyncio.Task] = None def set_name(self, name: str): self._name = name @@ -76,7 +78,14 @@ class ASRRunner(RunnerBase): self._pipeline.set_models(self._models) self._pipeline.set_audio_binary(self._audio_binary) self._pipeline.set_input_queue(self._input_queue) - self._pipeline.set_callback(self.deal_message) + + # --- 异步-同步桥梁 --- + # 创建一个线程安全的回调函数,用于从Pipeline的线程中调用Runner的异步方法 + loop = 
asyncio.get_running_loop() + def thread_safe_callback(message): + asyncio.run_coroutine_threadsafe(self.deal_message(message), loop) + + self._pipeline.set_callback(thread_safe_callback) self._pipeline.bake() def append_receiver(self, receiver: WebSocketClient): @@ -85,47 +94,64 @@ class ASRRunner(RunnerBase): def delete_receiver(self, receiver: WebSocketClient): self._receiver.remove(receiver) - def deal_message(self, message: str): - self.broadcast(message) + async def deal_message(self, message: str): + await self.broadcast(message) - def broadcast(self, message: str): + async def broadcast(self, message: str): """ 广播发送给所有接收者 """ - logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: %s", self._name, message) - for receiver in self._receiver: - receiver.send(message) + logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: 消息长度:%s", self._name, len(message)) + logger.info(f"SAR-{self._name} 的接收者列表: {self._receiver}") + tasks = [receiver.send(message) for receiver in self._receiver] + await asyncio.gather(*tasks) - def _run(self): + async def _run(self): """ 运行SAR """ self._pipeline.run() + loop = asyncio.get_running_loop() while True: - data = self._sender.recv() - if data is None: + try: + data = await self._sender.recv() + if data is None: + # `None` is used as a signal to end the stream + await loop.run_in_executor(None, self._input_queue.put, None) + break + # logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data)) + await loop.run_in_executor(None, self._input_queue.put, data) + except Exception as e: + logger.error(f"[ASRRunner][SAR-{self._name}] _run loop error: {e}") break - # logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data)) - self._input_queue.put(data) - self.stop() + await self.stop() def run(self): """ 运行SAR """ - self._thread = Thread(target=self._run, name=f"[ASRRunner]SAR-{self._name}") - self._thread.daemon = True - self._thread.start() + self._task = asyncio.create_task(self._run()) - def stop(self): + async def 
stop(self): """ 停止SAR """ + logger.info(f"Stopping SAR: {self._name}") self._pipeline.stop() - for ws in self._receiver: - ws.close() - self._sender.close() - + + # Close all receiver websockets + receiver_tasks = [ws.close() for ws in self._receiver] + await asyncio.gather(*receiver_tasks, return_exceptions=True) + + # Close the sender websocket + if self._sender: + await self._sender.close() + + # Cancel the main task if it's still running + if self._task and not self._task.done(): + self._task.cancel() + logger.info(f"SAR stopped: {self._name}") + def __init__(self,*args,**kwargs): """ """ @@ -195,9 +221,11 @@ class ASRRunner(RunnerBase): return True return False - def __del__(self) -> None: + async def shutdown(self): """ - 析构函数 + 优雅地关闭所有SAR会话 """ - for sar in self._SAR_list: - sar.stop() + logger.info("Shutting down all SAR instances...") + tasks = [sar.stop() for sar in self._SAR_list] + await asyncio.gather(*tasks, return_exceptions=True) + logger.info("All SAR instances have been shut down.") diff --git a/src/runner/__init__.py b/src/runner/__init__.py new file mode 100644 index 0000000..3372b8e --- /dev/null +++ b/src/runner/__init__.py @@ -0,0 +1,3 @@ +from .ASRRunner import ASRRunner + +__all__ = ["ASRRunner"] diff --git a/src/server.py b/src/server.py index 8d48693..062b9b3 100644 --- a/src/server.py +++ b/src/server.py @@ -10,255 +10,80 @@ import json import websockets import ssl import argparse -from config import parse_args -from models import load_models -from service import ASRService +from src.runner import ASRRunner +from src.config import DefaultConfig +from src.config import AudioBinary_Config +from src.websockets.router import websocket_router +from src.core import ModelLoader +import uvicorn +from fastapi import FastAPI +from contextlib import asynccontextmanager +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) # 全局变量,存储当前连接的WebSocket客户端 websocket_users = set() - -async def ws_reset(websocket): - 
"""重置WebSocket连接状态并关闭连接""" - print(f"重置WebSocket连接,当前连接数: {len(websocket_users)}") - - # 重置状态字典 - websocket.status_dict_asr_online["cache"] = {} - websocket.status_dict_asr_online["is_final"] = True - websocket.status_dict_vad["cache"] = {} - websocket.status_dict_vad["is_final"] = True - websocket.status_dict_punc["cache"] = {} - - # 关闭连接 - await websocket.close() - - -async def clear_websocket(): - """清理所有WebSocket连接""" - for websocket in websocket_users: - await ws_reset(websocket) - websocket_users.clear() - - -async def ws_serve(websocket, path): +# 使用 lifespan 上下文管理器来管理应用的生命周期 +@asynccontextmanager +async def lifespan(app: FastAPI): """ - WebSocket服务主函数,处理客户端连接和消息 - - 参数: - websocket: WebSocket连接对象 - path: 连接路径 + 在应用启动时加载模型和初始化ASRRunner, + 在应用关闭时优雅地关闭ASRRunner。 """ - frames = [] # 存储所有音频帧 - frames_asr = [] # 存储用于离线ASR的音频帧 - frames_asr_online = [] # 存储用于在线ASR的音频帧 + logger.info("应用启动,开始加载模型和初始化Runner...") + + # 1. 加载模型 + # 这里的参数可以从配置文件或环境变量中获取 + args = { + "asr_model": "paraformer-zh", + "asr_model_revision": "v2.0.4", + "vad_model": "fsmn-vad", + "vad_model_revision": "v2.0.4", + "spk_model": "cam++", + "spk_model_revision": "v2.0.2", + } + model_loader = ModelLoader() + models = model_loader.load_models(args) + + # 2. 初始化 ASRRunner + _audio_config = AudioBinary_Config( + chunk_size=200, # ms + sample_rate=16000, + sample_width=2, # 16-bit + channels=1, + ) + _audio_config.chunk_stride = int(_audio_config.chunk_size * _audio_config.sample_rate / 1000) - global websocket_users - # await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端) + asr_runner = ASRRunner() + asr_runner.set_default_config( + audio_config=_audio_config, + models=models, + ) - # 添加到用户集合 - websocket_users.add(websocket) + # 3. 
将 asr_runner 实例存储在 app.state 中 + app.state.asr_runner = asr_runner + logger.info("模型加载和Runner初始化完成。") - # 初始化连接状态 - websocket.status_dict_asr = {} - websocket.status_dict_asr_online = {"cache": {}, "is_final": False} - websocket.status_dict_vad = {"cache": {}, "is_final": False} - websocket.status_dict_punc = {"cache": {}} - websocket.chunk_interval = 10 - websocket.vad_pre_idx = 0 - websocket.is_speaking = True # 默认用户正在说话 + yield - # 语音检测状态 - speech_start = False - speech_end_i = -1 + # --- 应用关闭时执行的代码 --- + logger.info("应用关闭,开始清理资源...") + await app.state.asr_runner.shutdown() + logger.info("资源清理完成。") - # 初始化配置 - websocket.wav_name = "microphone" - websocket.mode = "2pass" # 默认使用两阶段识别模式 +# 初始化FastAPI应用,并指定lifespan +app = FastAPI(lifespan=lifespan) - print("新用户已连接", flush=True) - - try: - # 持续接收客户端消息 - async for message in websocket: - # 处理JSON配置消息 - if isinstance(message, str): - try: - messagejson = json.loads(message) - - # 更新各种配置参数 - if "is_speaking" in messagejson: - websocket.is_speaking = messagejson["is_speaking"] - websocket.status_dict_asr_online["is_final"] = ( - not websocket.is_speaking - ) - if "chunk_interval" in messagejson: - websocket.chunk_interval = messagejson["chunk_interval"] - if "wav_name" in messagejson: - websocket.wav_name = messagejson.get("wav_name") - if "chunk_size" in messagejson: - chunk_size = messagejson["chunk_size"] - if isinstance(chunk_size, str): - chunk_size = chunk_size.split(",") - websocket.status_dict_asr_online["chunk_size"] = [ - int(x) for x in chunk_size - ] - if "encoder_chunk_look_back" in messagejson: - websocket.status_dict_asr_online["encoder_chunk_look_back"] = ( - messagejson["encoder_chunk_look_back"] - ) - if "decoder_chunk_look_back" in messagejson: - websocket.status_dict_asr_online["decoder_chunk_look_back"] = ( - messagejson["decoder_chunk_look_back"] - ) - if "hotword" in messagejson: - websocket.status_dict_asr["hotword"] = messagejson["hotwords"] - if "mode" in messagejson: - websocket.mode = 
messagejson["mode"] - except json.JSONDecodeError: - print(f"无效的JSON消息: {message}") - - # 根据chunk_interval更新VAD的chunk_size - websocket.status_dict_vad["chunk_size"] = int( - websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] - * 60 - / websocket.chunk_interval - ) - - # 处理音频数据 - if ( - len(frames_asr_online) > 0 - or len(frames_asr) >= 0 - or not isinstance(message, str) - ): - if not isinstance(message, str): # 二进制音频数据 - # 添加到帧缓冲区 - frames.append(message) - duration_ms = len(message) // 32 # 计算音频时长 - websocket.vad_pre_idx += duration_ms - - # 处理在线ASR - frames_asr_online.append(message) - websocket.status_dict_asr_online["is_final"] = speech_end_i != -1 - - # 达到chunk_interval或最终帧时处理在线ASR - if ( - len(frames_asr_online) % websocket.chunk_interval == 0 - or websocket.status_dict_asr_online["is_final"] - ): - if websocket.mode == "2pass" or websocket.mode == "online": - audio_in = b"".join(frames_asr_online) - try: - await asr_service.async_asr_online(websocket, audio_in) - except Exception as e: - print(f"在线ASR处理错误: {e}") - frames_asr_online = [] - - # 如果检测到语音开始,收集帧用于离线ASR - if speech_start: - frames_asr.append(message) - - # VAD处理 - 语音活动检测 - try: - speech_start_i, speech_end_i = await asr_service.async_vad( - websocket, message - ) - except Exception as e: - print(f"VAD处理错误: {e}") - - # 检测到语音开始 - if speech_start_i != -1: - speech_start = True - # 计算开始偏移并收集前面的帧 - beg_bias = ( - websocket.vad_pre_idx - speech_start_i - ) // duration_ms - frames_pre = ( - frames[-beg_bias:] if beg_bias < len(frames) else frames - ) - frames_asr = [] - frames_asr.extend(frames_pre) - - # 处理离线ASR (语音结束或用户停止说话) - if speech_end_i != -1 or not websocket.is_speaking: - if websocket.mode == "2pass" or websocket.mode == "offline": - audio_in = b"".join(frames_asr) - try: - await asr_service.async_asr(websocket, audio_in) - except Exception as e: - print(f"离线ASR处理错误: {e}") - - # 重置状态 - frames_asr = [] - speech_start = False - frames_asr_online = [] - 
websocket.status_dict_asr_online["cache"] = {} - - # 如果用户停止说话,完全重置 - if not websocket.is_speaking: - websocket.vad_pre_idx = 0 - frames = [] - websocket.status_dict_vad["cache"] = {} - else: - # 保留最近的帧用于下一轮处理 - frames = frames[-20:] - - except websockets.ConnectionClosed: - print(f"连接已关闭...", flush=True) - await ws_reset(websocket) - websocket_users.remove(websocket) - except websockets.InvalidState: - print("无效的WebSocket状态...") - except Exception as e: - print(f"发生异常: {e}") - - -def start_server(args, asr_service_instance): - """ - 启动WebSocket服务器 - - 参数: - args: 命令行参数 - asr_service_instance: ASR服务实例 - """ - global asr_service - asr_service = asr_service_instance - - # 配置SSL (如果提供了证书) - if args.certfile and len(args.certfile) > 0: - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile) - - start_server = websockets.serve( - ws_serve, - args.host, - args.port, - subprotocols=["binary"], - ping_interval=None, - ssl=ssl_context, - ) - else: - start_server = websockets.serve( - ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None - ) - - print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}") - - # 启动事件循环 - asyncio.get_event_loop().run_until_complete(start_server) - asyncio.get_event_loop().run_forever() +# 挂载WebSocket路由 +app.include_router(websocket_router, prefix="/ws") +@app.get("/") +async def read_root(): + return {"message": "FunASR-FastAPI WebSocket Server is running."} +# 如果需要直接运行此文件进行测试 if __name__ == "__main__": - # 解析命令行参数 - args = parse_args() - - # 加载模型 - print("正在加载模型...") - models = load_models(args) - print("模型加载完成!当前仅支持单个客户端同时连接!") - - # 创建ASR服务 - asr_service = ASRService(models) - - # 启动服务器 - start_server(args, asr_service) + # 注意:在生产环境中,推荐使用Gunicorn + Uvicorn workers + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/service.py b/src/service.py deleted file mode 100644 index e8a6dfb..0000000 --- a/src/service.py +++ /dev/null @@ -1,131 +0,0 @@ 
-#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -ASR服务模块 - 提供语音识别相关的核心功能 -""" - -import json - - -class ASRService: - """ASR服务类,封装各种语音识别相关功能""" - - def __init__(self, models): - """ - 初始化ASR服务 - - 参数: - models: 包含各种预加载模型的字典 - """ - self.model_asr = models["asr"] - self.model_asr_streaming = models["asr_streaming"] - self.model_vad = models["vad"] - self.model_punc = models["punc"] - - async def async_vad(self, websocket, audio_in): - """ - 语音活动检测 - - 参数: - websocket: WebSocket连接 - audio_in: 二进制音频数据 - - 返回: - tuple: (speech_start, speech_end) 语音开始和结束位置 - """ - # 使用VAD模型分析音频段 - segments_result = self.model_vad.generate( - input=audio_in, **websocket.status_dict_vad - )[0]["value"] - - speech_start = -1 - speech_end = -1 - - # 解析VAD结果 - if len(segments_result) == 0 or len(segments_result) > 1: - return speech_start, speech_end - - if segments_result[0][0] != -1: - speech_start = segments_result[0][0] - if segments_result[0][1] != -1: - speech_end = segments_result[0][1] - - return speech_start, speech_end - - async def async_asr(self, websocket, audio_in): - """ - 离线ASR处理 - - 参数: - websocket: WebSocket连接 - audio_in: 二进制音频数据 - """ - if len(audio_in) > 0: - # 使用离线ASR模型处理音频 - rec_result = self.model_asr.generate( - input=audio_in, **websocket.status_dict_asr - )[0] - - # 如果有标点符号模型且识别出文本,则添加标点 - if self.model_punc is not None and len(rec_result["text"]) > 0: - rec_result = self.model_punc.generate( - input=rec_result["text"], **websocket.status_dict_punc - )[0] - - # 如果识别出文本,发送到客户端 - if len(rec_result["text"]) > 0: - mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode - message = json.dumps( - { - "mode": mode, - "text": rec_result["text"], - "wav_name": websocket.wav_name, - "is_final": websocket.is_speaking, - } - ) - await websocket.send(message) - else: - # 如果没有音频数据,发送空文本 - mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode - message = json.dumps( - { - "mode": mode, - "text": "", - "wav_name": websocket.wav_name, - 
"is_final": websocket.is_speaking, - } - ) - await websocket.send(message) - - async def async_asr_online(self, websocket, audio_in): - """ - 在线ASR处理 - - 参数: - websocket: WebSocket连接 - audio_in: 二进制音频数据 - """ - if len(audio_in) > 0: - # 使用在线ASR模型处理音频 - rec_result = self.model_asr_streaming.generate( - input=audio_in, **websocket.status_dict_asr_online - )[0] - - # 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理) - if websocket.mode == "2pass" and websocket.status_dict_asr_online.get( - "is_final", False - ): - return - - # 如果识别出文本,发送到客户端 - if len(rec_result["text"]): - mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode - message = json.dumps( - { - "mode": mode, - "text": rec_result["text"], - "wav_name": websocket.wav_name, - "is_final": websocket.is_speaking, - } - ) - await websocket.send(message) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 19cb0bd..ec586b8 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,3 +1,6 @@ from .logger import get_module_logger, setup_root_logger -__all__ = ["get_module_logger", "setup_root_logger"] +__all__ = [ + "get_module_logger", + "setup_root_logger", + ] diff --git a/src/websockets/adapter.py b/src/websockets/adapter.py new file mode 100644 index 0000000..053d4c0 --- /dev/null +++ b/src/websockets/adapter.py @@ -0,0 +1,63 @@ +import numpy as np +from fastapi import WebSocket +from typing import Union +import uuid +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) + +class FastAPIWebSocketAdapter: + """ + 一个适配器类,用于将FastAPI的WebSocket对象包装成ASRRunner所期望的接口。 + 同时处理数据类型转换。 + """ + def __init__(self, websocket: WebSocket, sample_rate: int = 16000, sample_width: int = 2): + self._ws = websocket + self._sample_rate = sample_rate + self._sample_width = sample_width + self._total_received = 0 + async def recv(self) -> Union[np.ndarray, None]: + """ + 接收来自FastAPI WebSocket的数据。 + 如果收到的是字节流,将其转换为Numpy数组。 + 如果收到的是文本"close",返回None以表示结束。 + """ + message = await 
self._ws.receive() + if 'bytes' in message: + bytes_data = message['bytes'] + # 将字节流转换为float64的Numpy数组 + audio_array = np.frombuffer(bytes_data, dtype=np.float32) + + # 使用回车符 \r 覆盖打印进度 + self._total_received += len(bytes_data) + print(f"🎧 [Adapter] 正在接收音频... 总计: {self._total_received / 1024:.2f} KB", end='\r') + + return audio_array + elif 'text' in message and message['text'].lower() == 'close': + print("\n🏁 [Adapter] 收到 'close' 信号。") # 在收到结束信号时换行 + return None # 返回 None 来作为结束信号 + return np.array([]) # 返回空数组以忽略其他类型的消息 + + async def send(self, message: dict): + """ + 将字典消息作为JSON发送给客户端。 + 在发送前,将所有UUID对象转换为字符串以确保可序列化。 + """ + def convert_uuids(obj): + if isinstance(obj, dict): + return {k: convert_uuids(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_uuids(elem) for elem in obj] + elif isinstance(obj, uuid.UUID): + return str(obj) + return obj + + serializable_message = convert_uuids(message) + logger.info(f"[Adapter] 发送消息: {serializable_message}") + await self._ws.send_json(serializable_message) + + async def close(self): + """ + 关闭WebSocket连接。 + """ + await self._ws.close() \ No newline at end of file diff --git a/src/websockets/endpoint/asr_endpoint.py b/src/websockets/endpoint/asr_endpoint.py new file mode 100644 index 0000000..042bb3b --- /dev/null +++ b/src/websockets/endpoint/asr_endpoint.py @@ -0,0 +1,89 @@ +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Request +from src.websockets.adapter import FastAPIWebSocketAdapter +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) +router = APIRouter() + +@router.websocket("/asr/{session_id}") +async def asr_websocket_endpoint( + websocket: WebSocket, + session_id: str, + mode: str = Query(default="sender", enum=["sender", "receiver"]) +): + """ + ASR WebSocket 端点 + + - **session_id**: 标识一个识别会话的唯一ID. + - **mode**: 客户端模式. + - `sender`: 作为音频发送方加入,将创建一个新的识别会话. + - `receiver`: 作为结果接收方加入,订阅一个已存在的会话. 
+ """ + await websocket.accept() + + # 从websocket.app.state获取全局的ASRRunner实例 + asr_runner = websocket.app.state.asr_runner + + # 创建WebSocket适配器 + # 注意:这里的audio_config应该与ASRRunner中的默认配置一致 + audio_config = asr_runner._default_audio_config + adapter = FastAPIWebSocketAdapter( + websocket, + sample_rate=audio_config.sample_rate, + sample_width=audio_config.sample_width + ) + + if mode == "sender": + logger.info(f"客户端 {websocket.client} 作为 'sender' 加入会话: {session_id}") + # 创建一个新的SAR会话 + sar_id = asr_runner.new_SAR(ws=adapter, name=session_id) + if sar_id is None: + logger.error(f"为会话 {session_id} 创建SAR失败") + await websocket.close(code=1011, reason="Failed to create ASR session") + return + + sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None) + try: + # 端点函数等待后台任务完成。 + # 真正的接收逻辑在SAR的_run方法中,该方法由new_SAR作为后台任务启动。 + # 当客户端断开连接时,adapter.recv()会抛出异常, + # _run任务会捕获它,然后停止并清理,最后任务结束。 + if sar and sar._task: + await sar._task + else: + # 如果任务没有被创建,记录一个错误并关闭连接 + logger.error(f"SAR任务未能在会话 {session_id} 中启动") + await websocket.close(code=1011, reason="Failed to start ASR task") + + except Exception as e: + # 捕获任何意外的错误 + logger.error(f"会话 {session_id} 的 'sender' 端点发生未知错误: {e}") + finally: + logger.info(f"'sender' {websocket.client} 在会话 {session_id} 的连接处理已结束") + + elif mode == "receiver": + logger.info(f"客户端 {websocket.client} 作为 'receiver' 加入会话: {session_id}") + # 加入一个已存在的SAR会话 + joined = asr_runner.join_SAR(ws=adapter, name=session_id) + if not joined: + logger.warning(f"无法找到会话 {session_id},'receiver' {websocket.client} 加入失败") + await websocket.close(code=1011, reason=f"Session '{session_id}' not found") + return + + try: + # Receiver只需要保持连接,等待从SAR广播过来的消息 + # 这个循环也用于检测断开 + while True: + await websocket.receive_text() + except WebSocketDisconnect: + logger.info(f"'receiver' {websocket.client} 在会话 {session_id} 中断开连接") + # Receiver断开时,需要将其从SAR的接收者列表中移除 + sar = next((s for s in asr_runner._SAR_list if s._name == session_id), None) + if sar: + 
sar.delete_receiver(adapter) + logger.info(f"已从会话 {session_id} 中移除 'receiver' {websocket.client}") + + else: + # 理论上,由于FastAPI的enum校验,这里的代码不会被执行 + logger.error(f"无效的模式: {mode}") + await websocket.close(code=1003, reason="Invalid mode specified") \ No newline at end of file diff --git a/src/websockets/router.py b/src/websockets/router.py index 2694a6a..489b616 100644 --- a/src/websockets/router.py +++ b/src/websockets/router.py @@ -1,3 +1,9 @@ -from endpoint import asr_router +from fastapi import APIRouter +from .endpoint import asr_endpoint -__all__ = ["asr_router"] \ No newline at end of file +websocket_router = APIRouter() + +# 包含ASR端点路由 +websocket_router.include_router(asr_endpoint.router) + +__all__ = ["websocket_router"] \ No newline at end of file diff --git a/test_main.py b/test_main.py index 42bd7a9..c7a697a 100644 --- a/test_main.py +++ b/test_main.py @@ -6,6 +6,7 @@ from tests.pipeline.asr_test import test_asr_pipeline from src.utils.logger import get_module_logger, setup_root_logger from tests.runner.asr_runner_test import test_asr_runner +import asyncio setup_root_logger(level="INFO", log_file="logs/test_main.log") logger = get_module_logger(__name__) @@ -22,4 +23,4 @@ with open("logs/test_main.log", "w") as f: # test_asr_pipeline() logger.info("开始测试ASRRunner") -test_asr_runner() +asyncio.run(test_asr_runner()) diff --git a/tests/runner/asr_runner_test.py b/tests/runner/asr_runner_test.py index a4dbec8..79e047c 100644 --- a/tests/runner/asr_runner_test.py +++ b/tests/runner/asr_runner_test.py @@ -1,30 +1,59 @@ """ ASRRunner test """ -import queue -import time +import asyncio import soundfile import numpy as np from src.runner.ASRRunner import ASRRunner from src.core.model_loader import ModelLoader from src.models import AudioBinary_Config -from src.utils.mock_websocket import MockWebSocketClient +from asyncio import Queue as AsyncQueue from src.utils.logger import get_module_logger logger = get_module_logger(__name__) -def test_asr_runner(): + +class 
AsyncMockWebSocketClient: + """一个用于测试目的的异步WebSocket客户端模拟器。""" + + def __init__(self): + self._recv_q = AsyncQueue() + self._send_q = AsyncQueue() + + def put_for_recv(self, item): + """允许测试将数据送入模拟的WebSocket中。""" + self._recv_q.put_nowait(item) + + async def get_from_send(self): + """允许测试从模拟的WebSocket中获取结果。""" + return await self._send_q.get() + + async def recv(self): + """ASRRunner将调用此方法来获取数据。""" + return await self._recv_q.get() + + async def send(self, item): + """ASRRunner将通过回调调用此方法来发送结果。""" + logger.info(f"Mock WS 收到结果: {item}") + await self._send_q.put(item) + + async def close(self): + """一个模拟的关闭方法。""" + pass + + +async def test_asr_runner(): """ - End-to-end test for ASRRunner. - 1. Loads models. - 2. Configures and initializes ASRRunner. - 3. Creates a mock WebSocket client. - 4. Starts a new SenderAndReceiver (SAR) instance in the runner. - 5. Streams audio data via the mock WebSocket. - 6. Asserts that the received transcription matches the expected text. + 针对ASRRunner的端到端测试,已适配异步操作。 + 1. 加载模型. + 2. 配置并初始化ASRRunner. + 3. 创建一个异步的模拟WebSocket客户端. + 4. 在Runner中启动一个新的SenderAndReceiver (SAR)实例. + 5. 通过模拟的WebSocket流式传输音频数据. + 6. 等待处理任务完成并断言其无错误运行. """ - # 1. Load models + # 1. 加载模型 model_loader = ModelLoader() args = { "asr_model": "paraformer-zh", @@ -33,48 +62,71 @@ def test_asr_runner(): "vad_model_revision": "v2.0.4", "spk_model": "cam++", "spk_model_revision": "v2.0.2", - "audio_update": False, } models = model_loader.load_models(args) audio_file_path = "tests/XT_ZZY_denoise.wav" audio_data, sample_rate = soundfile.read(audio_file_path) - logger.info(f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}") - # 2. 
Configure audio + logger.info( + f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}" + ) + # 进一步详细打印audio_data数据类型 + # 详细打印audio_data的类型和结构信息,便于调试 + logger.info(f"audio_data 类型: {type(audio_data)}") + logger.info(f"audio_data dtype: {getattr(audio_data, 'dtype', '未知')}") + logger.info(f"audio_data shape: {getattr(audio_data, 'shape', '未知')}") + logger.info(f"audio_data ndim: {getattr(audio_data, 'ndim', '未知')}") + logger.info(f"audio_data 示例前10个值: {audio_data[:10] if hasattr(audio_data, '__getitem__') else '不可切片'}") + + # 2. 配置音频 audio_config = AudioBinary_Config( chunk_size=200, # ms - chunk_stride=1000, # 10ms stride for 16kHz sample_rate=sample_rate, sample_width=2, # 16-bit - channels=2, + channels=1, ) audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000) - # 3. Setup ASRRunner + # 3. 设置ASRRunner asr_runner = ASRRunner() asr_runner.set_default_config( audio_config=audio_config, models=models, ) - # 4. Create Mock WebSocket and start SAR - mock_ws = MockWebSocketClient() + # 4. 创建模拟WebSocket并启动SAR + mock_ws = AsyncMockWebSocketClient() sar_id = asr_runner.new_SAR( ws=mock_ws, name="test_sar", ) - assert sar_id is not None, "Failed to create a new SAR instance" + assert sar_id is not None, "创建新的SAR实例失败" - # 5. Simulate streaming audio - print(f"Sending audio data of length {len(audio_data)} samples.") - audio_clip_len = 200 - for i in range(0, len(audio_data), audio_clip_len): - chunk = audio_data[i : i + audio_clip_len] - if not isinstance(chunk, np.ndarray) or chunk.size == 0: - break - # Simulate receiving binary data over WebSocket - mock_ws.put_for_recv(chunk) + # 获取SAR实例以等待其任务 + sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None) + assert sar is not None, "无法从Runner中获取SAR实例。" + assert sar._task is not None, "SAR任务未被创建。" - # 6. 
import asyncio
import websockets
import soundfile as sf
import uuid

# --- Configuration ---
HOST = "localhost"
PORT = 8000
SESSION_ID = str(uuid.uuid4())
SENDER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=sender"
RECEIVER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=receiver"

AUDIO_FILE_PATH = "tests/XT_ZZY_denoise.wav"  # must exist: 16 kHz, mono
CHUNK_DURATION_MS = 100  # send 100 ms of audio per frame
# CHUNK_SIZE is the frame size in bytes FOR 16-bit PCM (2 bytes/sample).
# NOTE(review): the audio below is read as float32 (4 bytes/sample), so the
# bytes actually sent per frame are 2 * CHUNK_SIZE (6400). CHUNK_SIZE is kept
# because the sample count per frame (CHUNK_SIZE // 2 == 1600 == 100 ms)
# is what the chunking loop relies on.
CHUNK_SIZE = int(16000 * 2 * CHUNK_DURATION_MS / 1000)

async def run_receiver():
    """Connect as a receiver and print every message broadcast by the server."""
    print(f"▶️ [Receiver] 尝试连接到: {RECEIVER_URI}")
    try:
        async with websockets.connect(RECEIVER_URI) as websocket:
            print("✅ [Receiver] 连接成功,等待消息...")
            try:
                while True:
                    message = await websocket.recv()
                    print(f"🎧 [Receiver] 收到结果: {message}")
            except websockets.exceptions.ConnectionClosed as e:
                print(f"✅ [Receiver] 连接已由服务器正常关闭: {e.reason}")
    except Exception as e:
        print(f"❌ [Receiver] 连接失败: {e}")

async def run_sender():
    """Connect as the sender: stream the audio file while also listening for
    the broadcast results of this session.
    """
    await asyncio.sleep(1)  # give the receiver a chance to connect first
    print(f"▶️ [Sender] 尝试连接到: {SENDER_URI}")
    try:
        async with websockets.connect(SENDER_URI) as websocket:
            print("✅ [Sender] 连接成功。")

            # --- Parallel task: listen for broadcast messages ---
            async def receive_task():
                print("▶️ [Sender-Receiver] 开始监听广播消息...")
                try:
                    while True:
                        message = await websocket.recv()
                        print(f"🎧 [Sender-Receiver] 收到结果: {message}")
                except websockets.exceptions.ConnectionClosed:
                    print("✅ [Sender-Receiver] 连接已关闭,停止监听。")

            receiver_sub_task = asyncio.create_task(receive_task())

            # --- Main task: stream the audio file ---
            try:
                print("▶️ [Sender] 准备发送音频...")
                audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='float32')
                if sample_rate != 16000:
                    # Server expects 16 kHz input; abort and stop the listener.
                    print("❌ [Sender] 错误:音频文件采样率必须是 16kHz。")
                    receiver_sub_task.cancel()
                    return

                total_samples = len(audio_data)
                chunk_samples = CHUNK_SIZE // 2  # 1600 samples == 100 ms
                samples_sent = 0
                print(f"音频加载成功,总长度: {total_samples} samples。开始分块发送...")

                for i in range(0, total_samples, chunk_samples):
                    chunk = audio_data[i:i + chunk_samples]
                    if len(chunk) == 0:
                        break
                    # Raw float32 little-endian bytes, as the server expects.
                    await websocket.send(chunk.tobytes())
                    samples_sent += len(chunk)
                    print(f"🎧 [Sender] 正在发送: {samples_sent}/{total_samples} samples", end="\r")
                    await asyncio.sleep(CHUNK_DURATION_MS / 1000)  # simulate real time

                print()
                print("🏁 [Sender] 音频流发送完毕,发送 'close' 信号。")
                await websocket.send("close")

            except FileNotFoundError:
                print(f"❌ [Sender] 错误:找不到音频文件 {AUDIO_FILE_PATH}")
            except Exception as e:
                print(f"❌ [Sender] 发送过程中发生错误: {e}")

            # Wait for the listener to end naturally (when the server closes).
            await receiver_sub_task

    except Exception as e:
        print(f"❌ [Sender] 连接失败: {e}")

async def main():
    """Run the receiver and sender tasks for one end-to-end session."""
    print("--- 开始 WebSocket ASR 服务端到端测试 ---")
    print(f"会话 ID: {SESSION_ID}")

    # Fix: start the receiver BEFORE the sender. run_sender() already sleeps
    # 1 s precisely so the receiver can connect first; the previous code
    # instead delayed the receiver by 7 s after starting the sender, so the
    # receiver missed the first seconds of broadcast results.
    receiver_task = asyncio.create_task(run_receiver())
    sender_task = asyncio.create_task(run_sender())

    # Wait for both tasks to finish.
    await asyncio.gather(receiver_task, sender_task)

    print("--- 测试结束 ---")

if __name__ == "__main__":
    # The FastAPI server must already be running before executing this script:
    #   python main.py
    asyncio.run(main())