[fastAPI]完成fastAPI-websocket端口搭建与测试,接收float32的tobytes字节流。
This commit is contained in:
parent
1a296d8309
commit
3083738db4
265
WEBSOCKET_API.md
Normal file
265
WEBSOCKET_API.md
Normal file
@ -0,0 +1,265 @@
|
||||
# FunASR-FastAPI WebSocket API 文档
|
||||
|
||||
本文档详细介绍了如何连接和使用 FunASR-FastAPI 实时语音识别服务的 WebSocket 接口。
|
||||
|
||||
## 1. 连接端点 (Endpoint)
|
||||
|
||||
服务的 WebSocket 端点 URL 格式如下:
|
||||
|
||||
```
|
||||
ws://<your_server_host>:8000/ws/asr/{session_id}?mode=<client_mode>
|
||||
```
|
||||
|
||||
### 参数说明
|
||||
|
||||
- **`{session_id}`** (路径参数, `str`, **必需**):
|
||||
用于唯一标识一个识别会话(例如,一场会议或一次直播)。所有属于同一次会话的客户端都应使用相同的 `session_id`。
|
||||
|
||||
- **`mode`** (查询参数, `str`, **必需**):
|
||||
定义客户端的角色。
|
||||
- `sender`: 音频发送者。一个会话中应该只有一个 `sender`。此客户端负责将实时音频流发送到服务器。
|
||||
- `receiver`: 结果接收者。一个会话中可以有多个 `receiver`。此客户端只接收由服务器广播的识别结果,不发送音频。
|
||||
|
||||
## 2. 数据格式
|
||||
|
||||
### 2.1 发送数据 (Sender -> Server)
|
||||
|
||||
- **音频格式**: `sender` 必须发送原始的 **PCM 音频数据**。
|
||||
- **采样率**: 16000 Hz
|
||||
- **位深**: 16-bit (signed integer)
|
||||
- **声道数**: 单声道 (Mono)
|
||||
- **传输格式**: 必须以**二进制 (bytes)** 格式发送。
|
||||
- **结束信号**: 当音频流结束时,`sender` 应发送一个**文本消息** `"close"` 来通知服务器关闭会话。
|
||||
|
||||
### 2.2 接收数据 (Server -> Receiver)
|
||||
|
||||
服务器会将识别结果以 **JSON 文本** 格式广播给会话中的所有 `receiver`(以及 `sender` 自己)。JSON 对象的结构示例如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"asr": "你好,世界。",
|
||||
"spk": {
|
||||
"speaker_id": "uuid-of-the-speaker",
|
||||
"speaker_name": "SpeakerName",
|
||||
"score": 0.98
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 3. Python 客户端示例
|
||||
|
||||
需要安装 `websockets` 库: `pip install websockets`
|
||||
|
||||
### 3.1 Python Sender 示例 (发送本地音频文件)
|
||||
|
||||
这个脚本会读取一个 WAV 文件,并将其内容以流式方式发送到服务器。
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import websockets
|
||||
import soundfile as sf
|
||||
import uuid
|
||||
|
||||
# --- 配置 ---
|
||||
SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=sender"
|
||||
SESSION_ID = str(uuid.uuid4()) # 为这次会话生成一个唯一的ID
|
||||
AUDIO_FILE = "tests/XT_ZZY_denoise.wav" # 替换为你的音频文件路径
|
||||
CHUNK_SIZE = 1600 # 每次发送 100ms 的音频帧数 (16000 帧/秒 × 0.1 秒;sf.read 以帧为单位,int16 下对应 3200 字节)
|
||||
|
||||
async def send_audio():
|
||||
"""连接到服务器,并流式发送音频文件"""
|
||||
uri = SERVER_URI.format(session_id=SESSION_ID)
|
||||
print(f"作为 Sender 连接到: {uri}")
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
try:
|
||||
# 读取音频文件
|
||||
with sf.SoundFile(AUDIO_FILE, 'r') as f:
|
||||
assert f.samplerate == 16000, "音频文件采样率必须为 16kHz"
|
||||
assert f.channels == 1, "音频文件必须为单声道"
|
||||
|
||||
print("开始发送音频...")
|
||||
while True:
|
||||
data = f.read(CHUNK_SIZE, dtype='int16')
|
||||
if len(data) == 0:  # 注意:不能用 data.any(),否则全零(静音)块会被误判为流结束
|
||||
break
|
||||
# 将 numpy 数组转换为原始字节流
|
||||
await websocket.send(data.tobytes())
|
||||
await asyncio.sleep(0.1) # 模拟实时音频输入
|
||||
|
||||
print("音频发送完毕,发送结束信号。")
|
||||
await websocket.send("close")
|
||||
|
||||
# 等待服务器的最终确认或关闭连接
|
||||
response = await websocket.recv()
|
||||
print(f"收到服务器最终响应: {response}")
|
||||
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
print(f"连接已关闭: {e}")
|
||||
except Exception as e:
|
||||
print(f"发生错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(send_audio())
|
||||
```
|
||||
|
||||
### 3.2 Python Receiver 示例 (接收识别结果)
|
||||
|
||||
这个脚本会连接到指定的会话,并持续打印服务器广播的识别结果。
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import websockets
|
||||
|
||||
# --- 配置 ---
|
||||
# !!! 必须和 Sender 使用相同的 SESSION_ID !!!
|
||||
SERVER_URI = "ws://localhost:8000/ws/asr/{session_id}?mode=receiver"
|
||||
SESSION_ID = "在此处粘贴你的Sender会话ID"
|
||||
|
||||
async def receive_results():
|
||||
"""连接到服务器并接收识别结果"""
|
||||
if "粘贴你的Sender会话ID" in SESSION_ID:
|
||||
print("错误:请先设置有效的 SESSION_ID!")
|
||||
return
|
||||
|
||||
uri = SERVER_URI.format(session_id=SESSION_ID)
|
||||
print(f"作为 Receiver 连接到: {uri}")
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
try:
|
||||
print("等待接收识别结果...")
|
||||
while True:
|
||||
message = await websocket.recv()
|
||||
print(f"收到结果: {message}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
print(f"连接已关闭: {e.code} {e.reason}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(receive_results())
|
||||
```
|
||||
|
||||
## 4. JavaScript 客户端示例 (浏览器)
|
||||
|
||||
这个示例展示了如何在网页上通过麦克风获取音频,并将其作为 `sender` 发送。
|
||||
|
||||
```html
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>WebSocket ASR Client</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>FunASR WebSocket Client (Sender)</h1>
|
||||
<p><strong>Session ID:</strong> <span id="sessionId"></span></p>
|
||||
<button id="startButton">开始识别</button>
|
||||
<button id="stopButton" disabled>停止识别</button>
|
||||
<h2>识别结果:</h2>
|
||||
<div id="results"></div>
|
||||
|
||||
<script>
|
||||
const startButton = document.getElementById('startButton');
|
||||
const stopButton = document.getElementById('stopButton');
|
||||
const resultsDiv = document.getElementById('results');
|
||||
const sessionIdSpan = document.getElementById('sessionId');
|
||||
|
||||
let websocket;
|
||||
let audioContext;
|
||||
let scriptProcessor;
|
||||
let mediaStream;
|
||||
|
||||
const CHUNK_DURATION_MS = 100; // 每100ms发送一次数据
|
||||
const SAMPLE_RATE = 16000;
|
||||
|
||||
// 生成一个简单的UUID
|
||||
function generateUUID() {
|
||||
return ([1e7]+-1e3+-4e3+-8e3+-1e11).replace(/[018]/g, c =>
|
||||
(c ^ crypto.getRandomValues(new Uint8Array(1))[0] & 15 >> c / 4).toString(16)
|
||||
);
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
const sessionId = generateUUID();
|
||||
sessionIdSpan.textContent = sessionId;
|
||||
const wsUrl = `ws://${window.location.host}/ws/asr/${sessionId}?mode=sender`;
|
||||
|
||||
websocket = new WebSocket(wsUrl);
|
||||
websocket.onopen = () => {
|
||||
console.log("WebSocket 连接已打开");
|
||||
startButton.disabled = true;
|
||||
stopButton.disabled = false;
|
||||
resultsDiv.innerHTML = '';
|
||||
};
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
console.log("收到消息:", event.data);
|
||||
const result = JSON.parse(event.data);
|
||||
const asrText = result.asr || '';
|
||||
const spkName = result.spk ? result.spk.speaker_name : 'Unknown';
|
||||
resultsDiv.innerHTML += `<p><strong>${spkName}:</strong> ${asrText}</p>`;
|
||||
};
|
||||
|
||||
websocket.onclose = () => {
|
||||
console.log("WebSocket 连接已关闭");
|
||||
stopRecording();
|
||||
};
|
||||
|
||||
websocket.onerror = (error) => {
|
||||
console.error("WebSocket 错误:", error);
|
||||
alert("WebSocket 连接失败!");
|
||||
stopRecording();
|
||||
};
|
||||
|
||||
try {
|
||||
mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true, video: false });
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });
|
||||
|
||||
const source = audioContext.createMediaStreamSource(mediaStream);
|
||||
const bufferSize = 4096; // createScriptProcessor 的缓冲区大小必须是 2 的幂 (256~16384),约 256ms @16kHz
|
||||
scriptProcessor = audioContext.createScriptProcessor(bufferSize, 1, 1);
|
||||
|
||||
scriptProcessor.onaudioprocess = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
const inputData = e.inputBuffer.getChannelData(0);
|
||||
// 服务器期望16-bit PCM,需要转换
|
||||
const pcmData = new Int16Array(inputData.length);
|
||||
for (let i = 0; i < inputData.length; i++) {
|
||||
pcmData[i] = Math.max(-1, Math.min(1, inputData[i])) * 32767;
|
||||
}
|
||||
websocket.send(pcmData.buffer);
|
||||
}
|
||||
};
|
||||
|
||||
source.connect(scriptProcessor);
|
||||
scriptProcessor.connect(audioContext.destination);
|
||||
|
||||
} catch (err) {
|
||||
console.error("无法获取麦克风:", err);
|
||||
alert("无法获取麦克风权限!");
|
||||
if (websocket) websocket.close();
|
||||
}
|
||||
}
|
||||
|
||||
function stopRecording() {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
websocket.send("close");
|
||||
}
|
||||
if (mediaStream) {
|
||||
mediaStream.getTracks().forEach(track => track.stop());
|
||||
}
|
||||
if (scriptProcessor) {
|
||||
scriptProcessor.disconnect();
|
||||
}
|
||||
if (audioContext) {
|
||||
audioContext.close();
|
||||
}
|
||||
startButton.disabled = false;
|
||||
stopButton.disabled = true;
|
||||
}
|
||||
|
||||
startButton.addEventListener('click', startRecording);
|
||||
stopButton.addEventListener('click', stopRecording);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
```
|
19
main.py
19
main.py
@ -1,12 +1,19 @@
|
||||
from src.server import app
|
||||
import uvicorn
|
||||
from datetime import datetime
|
||||
from src.server import app
|
||||
from src.utils.logger import get_module_logger, setup_root_logger
|
||||
from datetime import datetime
|
||||
|
||||
time = format(datetime.now(), "%Y-%m-%d %H:%M:%S")
|
||||
setup_root_logger(level="DEBUG", log_file=f"logs/fastapiserver_{time}.log")
|
||||
time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
setup_root_logger(level="INFO", log_file=f"logs/main_{time}.log")
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
logger.info("启动 FunASR FastAPI 服务器...")
|
||||
# 在生产环境中,推荐使用更强大的ASGI服务器,如Gunicorn,并配合Uvicorn workers。
|
||||
# 例如: gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app
|
||||
uvicorn.run(
|
||||
app,
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
log_level="info"
|
||||
)
|
@ -11,6 +11,7 @@ from src.pipeline import PipelineFactory
|
||||
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||
from src.core.model_loader import ModelLoader
|
||||
from src.config import DefaultConfig
|
||||
import asyncio
|
||||
from queue import Queue
|
||||
import soundfile
|
||||
import time
|
||||
@ -57,6 +58,7 @@ class ASRRunner(RunnerBase):
|
||||
# 输入队列
|
||||
self._input_queue: Queue = Queue()
|
||||
self._pipeline: Optional[ASRPipeline] = None
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
|
||||
def set_name(self, name: str):
|
||||
self._name = name
|
||||
@ -76,7 +78,14 @@ class ASRRunner(RunnerBase):
|
||||
self._pipeline.set_models(self._models)
|
||||
self._pipeline.set_audio_binary(self._audio_binary)
|
||||
self._pipeline.set_input_queue(self._input_queue)
|
||||
self._pipeline.set_callback(self.deal_message)
|
||||
|
||||
# --- 异步-同步桥梁 ---
|
||||
# 创建一个线程安全的回调函数,用于从Pipeline的线程中调用Runner的异步方法
|
||||
loop = asyncio.get_running_loop()
|
||||
def thread_safe_callback(message):
|
||||
asyncio.run_coroutine_threadsafe(self.deal_message(message), loop)
|
||||
|
||||
self._pipeline.set_callback(thread_safe_callback)
|
||||
self._pipeline.bake()
|
||||
|
||||
def append_receiver(self, receiver: WebSocketClient):
|
||||
@ -85,47 +94,64 @@ class ASRRunner(RunnerBase):
|
||||
def delete_receiver(self, receiver: WebSocketClient):
|
||||
self._receiver.remove(receiver)
|
||||
|
||||
def deal_message(self, message: str):
|
||||
self.broadcast(message)
|
||||
async def deal_message(self, message: str):
|
||||
await self.broadcast(message)
|
||||
|
||||
def broadcast(self, message: str):
|
||||
async def broadcast(self, message: str):
|
||||
"""
|
||||
广播发送给所有接收者
|
||||
"""
|
||||
logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: %s", self._name, message)
|
||||
for receiver in self._receiver:
|
||||
receiver.send(message)
|
||||
logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: 消息长度:%s", self._name, len(message))
|
||||
logger.info(f"SAR-{self._name} 的接收者列表: {self._receiver}")
|
||||
tasks = [receiver.send(message) for receiver in self._receiver]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _run(self):
|
||||
async def _run(self):
|
||||
"""
|
||||
运行SAR
|
||||
"""
|
||||
self._pipeline.run()
|
||||
loop = asyncio.get_running_loop()
|
||||
while True:
|
||||
data = self._sender.recv()
|
||||
if data is None:
|
||||
try:
|
||||
data = await self._sender.recv()
|
||||
if data is None:
|
||||
# `None` is used as a signal to end the stream
|
||||
await loop.run_in_executor(None, self._input_queue.put, None)
|
||||
break
|
||||
# logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data))
|
||||
await loop.run_in_executor(None, self._input_queue.put, data)
|
||||
except Exception as e:
|
||||
logger.error(f"[ASRRunner][SAR-{self._name}] _run loop error: {e}")
|
||||
break
|
||||
# logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data))
|
||||
self._input_queue.put(data)
|
||||
self.stop()
|
||||
await self.stop()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
运行SAR
|
||||
"""
|
||||
self._thread = Thread(target=self._run, name=f"[ASRRunner]SAR-{self._name}")
|
||||
self._thread.daemon = True
|
||||
self._thread.start()
|
||||
self._task = asyncio.create_task(self._run())
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
"""
|
||||
停止SAR
|
||||
"""
|
||||
logger.info(f"Stopping SAR: {self._name}")
|
||||
self._pipeline.stop()
|
||||
for ws in self._receiver:
|
||||
ws.close()
|
||||
self._sender.close()
|
||||
|
||||
|
||||
# Close all receiver websockets
|
||||
receiver_tasks = [ws.close() for ws in self._receiver]
|
||||
await asyncio.gather(*receiver_tasks, return_exceptions=True)
|
||||
|
||||
# Close the sender websocket
|
||||
if self._sender:
|
||||
await self._sender.close()
|
||||
|
||||
# Cancel the main task if it's still running
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
logger.info(f"SAR stopped: {self._name}")
|
||||
|
||||
def __init__(self,*args,**kwargs):
|
||||
"""
|
||||
"""
|
||||
@ -195,9 +221,11 @@ class ASRRunner(RunnerBase):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __del__(self) -> None:
|
||||
async def shutdown(self):
|
||||
"""
|
||||
析构函数
|
||||
优雅地关闭所有SAR会话
|
||||
"""
|
||||
for sar in self._SAR_list:
|
||||
sar.stop()
|
||||
logger.info("Shutting down all SAR instances...")
|
||||
tasks = [sar.stop() for sar in self._SAR_list]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.info("All SAR instances have been shut down.")
|
||||
|
3
src/runner/__init__.py
Normal file
3
src/runner/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .ASRRunner import ASRRunner
|
||||
|
||||
__all__ = ["ASRRunner"]
|
299
src/server.py
299
src/server.py
@ -10,255 +10,80 @@ import json
|
||||
import websockets
|
||||
import ssl
|
||||
import argparse
|
||||
from config import parse_args
|
||||
from models import load_models
|
||||
from service import ASRService
|
||||
from src.runner import ASRRunner
|
||||
from src.config import DefaultConfig
|
||||
from src.config import AudioBinary_Config
|
||||
from src.websockets.router import websocket_router
|
||||
from src.core import ModelLoader
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
# 全局变量,存储当前连接的WebSocket客户端
|
||||
websocket_users = set()
|
||||
|
||||
|
||||
async def ws_reset(websocket):
|
||||
"""重置WebSocket连接状态并关闭连接"""
|
||||
print(f"重置WebSocket连接,当前连接数: {len(websocket_users)}")
|
||||
|
||||
# 重置状态字典
|
||||
websocket.status_dict_asr_online["cache"] = {}
|
||||
websocket.status_dict_asr_online["is_final"] = True
|
||||
websocket.status_dict_vad["cache"] = {}
|
||||
websocket.status_dict_vad["is_final"] = True
|
||||
websocket.status_dict_punc["cache"] = {}
|
||||
|
||||
# 关闭连接
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def clear_websocket():
|
||||
"""清理所有WebSocket连接"""
|
||||
for websocket in websocket_users:
|
||||
await ws_reset(websocket)
|
||||
websocket_users.clear()
|
||||
|
||||
|
||||
async def ws_serve(websocket, path):
|
||||
# 使用 lifespan 上下文管理器来管理应用的生命周期
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
WebSocket服务主函数,处理客户端连接和消息
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接对象
|
||||
path: 连接路径
|
||||
在应用启动时加载模型和初始化ASRRunner,
|
||||
在应用关闭时优雅地关闭ASRRunner。
|
||||
"""
|
||||
frames = [] # 存储所有音频帧
|
||||
frames_asr = [] # 存储用于离线ASR的音频帧
|
||||
frames_asr_online = [] # 存储用于在线ASR的音频帧
|
||||
logger.info("应用启动,开始加载模型和初始化Runner...")
|
||||
|
||||
# 1. 加载模型
|
||||
# 这里的参数可以从配置文件或环境变量中获取
|
||||
args = {
|
||||
"asr_model": "paraformer-zh",
|
||||
"asr_model_revision": "v2.0.4",
|
||||
"vad_model": "fsmn-vad",
|
||||
"vad_model_revision": "v2.0.4",
|
||||
"spk_model": "cam++",
|
||||
"spk_model_revision": "v2.0.2",
|
||||
}
|
||||
model_loader = ModelLoader()
|
||||
models = model_loader.load_models(args)
|
||||
|
||||
# 2. 初始化 ASRRunner
|
||||
_audio_config = AudioBinary_Config(
|
||||
chunk_size=200, # ms
|
||||
sample_rate=16000,
|
||||
sample_width=2, # 16-bit
|
||||
channels=1,
|
||||
)
|
||||
_audio_config.chunk_stride = int(_audio_config.chunk_size * _audio_config.sample_rate / 1000)
|
||||
|
||||
global websocket_users
|
||||
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
|
||||
asr_runner = ASRRunner()
|
||||
asr_runner.set_default_config(
|
||||
audio_config=_audio_config,
|
||||
models=models,
|
||||
)
|
||||
|
||||
# 添加到用户集合
|
||||
websocket_users.add(websocket)
|
||||
# 3. 将 asr_runner 实例存储在 app.state 中
|
||||
app.state.asr_runner = asr_runner
|
||||
logger.info("模型加载和Runner初始化完成。")
|
||||
|
||||
# 初始化连接状态
|
||||
websocket.status_dict_asr = {}
|
||||
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
|
||||
websocket.status_dict_vad = {"cache": {}, "is_final": False}
|
||||
websocket.status_dict_punc = {"cache": {}}
|
||||
websocket.chunk_interval = 10
|
||||
websocket.vad_pre_idx = 0
|
||||
websocket.is_speaking = True # 默认用户正在说话
|
||||
yield
|
||||
|
||||
# 语音检测状态
|
||||
speech_start = False
|
||||
speech_end_i = -1
|
||||
# --- 应用关闭时执行的代码 ---
|
||||
logger.info("应用关闭,开始清理资源...")
|
||||
await app.state.asr_runner.shutdown()
|
||||
logger.info("资源清理完成。")
|
||||
|
||||
# 初始化配置
|
||||
websocket.wav_name = "microphone"
|
||||
websocket.mode = "2pass" # 默认使用两阶段识别模式
|
||||
# 初始化FastAPI应用,并指定lifespan
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
print("新用户已连接", flush=True)
|
||||
|
||||
try:
|
||||
# 持续接收客户端消息
|
||||
async for message in websocket:
|
||||
# 处理JSON配置消息
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
messagejson = json.loads(message)
|
||||
|
||||
# 更新各种配置参数
|
||||
if "is_speaking" in messagejson:
|
||||
websocket.is_speaking = messagejson["is_speaking"]
|
||||
websocket.status_dict_asr_online["is_final"] = (
|
||||
not websocket.is_speaking
|
||||
)
|
||||
if "chunk_interval" in messagejson:
|
||||
websocket.chunk_interval = messagejson["chunk_interval"]
|
||||
if "wav_name" in messagejson:
|
||||
websocket.wav_name = messagejson.get("wav_name")
|
||||
if "chunk_size" in messagejson:
|
||||
chunk_size = messagejson["chunk_size"]
|
||||
if isinstance(chunk_size, str):
|
||||
chunk_size = chunk_size.split(",")
|
||||
websocket.status_dict_asr_online["chunk_size"] = [
|
||||
int(x) for x in chunk_size
|
||||
]
|
||||
if "encoder_chunk_look_back" in messagejson:
|
||||
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
|
||||
messagejson["encoder_chunk_look_back"]
|
||||
)
|
||||
if "decoder_chunk_look_back" in messagejson:
|
||||
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
|
||||
messagejson["decoder_chunk_look_back"]
|
||||
)
|
||||
if "hotword" in messagejson:
|
||||
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
|
||||
if "mode" in messagejson:
|
||||
websocket.mode = messagejson["mode"]
|
||||
except json.JSONDecodeError:
|
||||
print(f"无效的JSON消息: {message}")
|
||||
|
||||
# 根据chunk_interval更新VAD的chunk_size
|
||||
websocket.status_dict_vad["chunk_size"] = int(
|
||||
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
|
||||
* 60
|
||||
/ websocket.chunk_interval
|
||||
)
|
||||
|
||||
# 处理音频数据
|
||||
if (
|
||||
len(frames_asr_online) > 0
|
||||
or len(frames_asr) >= 0
|
||||
or not isinstance(message, str)
|
||||
):
|
||||
if not isinstance(message, str): # 二进制音频数据
|
||||
# 添加到帧缓冲区
|
||||
frames.append(message)
|
||||
duration_ms = len(message) // 32 # 计算音频时长
|
||||
websocket.vad_pre_idx += duration_ms
|
||||
|
||||
# 处理在线ASR
|
||||
frames_asr_online.append(message)
|
||||
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
|
||||
|
||||
# 达到chunk_interval或最终帧时处理在线ASR
|
||||
if (
|
||||
len(frames_asr_online) % websocket.chunk_interval == 0
|
||||
or websocket.status_dict_asr_online["is_final"]
|
||||
):
|
||||
if websocket.mode == "2pass" or websocket.mode == "online":
|
||||
audio_in = b"".join(frames_asr_online)
|
||||
try:
|
||||
await asr_service.async_asr_online(websocket, audio_in)
|
||||
except Exception as e:
|
||||
print(f"在线ASR处理错误: {e}")
|
||||
frames_asr_online = []
|
||||
|
||||
# 如果检测到语音开始,收集帧用于离线ASR
|
||||
if speech_start:
|
||||
frames_asr.append(message)
|
||||
|
||||
# VAD处理 - 语音活动检测
|
||||
try:
|
||||
speech_start_i, speech_end_i = await asr_service.async_vad(
|
||||
websocket, message
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"VAD处理错误: {e}")
|
||||
|
||||
# 检测到语音开始
|
||||
if speech_start_i != -1:
|
||||
speech_start = True
|
||||
# 计算开始偏移并收集前面的帧
|
||||
beg_bias = (
|
||||
websocket.vad_pre_idx - speech_start_i
|
||||
) // duration_ms
|
||||
frames_pre = (
|
||||
frames[-beg_bias:] if beg_bias < len(frames) else frames
|
||||
)
|
||||
frames_asr = []
|
||||
frames_asr.extend(frames_pre)
|
||||
|
||||
# 处理离线ASR (语音结束或用户停止说话)
|
||||
if speech_end_i != -1 or not websocket.is_speaking:
|
||||
if websocket.mode == "2pass" or websocket.mode == "offline":
|
||||
audio_in = b"".join(frames_asr)
|
||||
try:
|
||||
await asr_service.async_asr(websocket, audio_in)
|
||||
except Exception as e:
|
||||
print(f"离线ASR处理错误: {e}")
|
||||
|
||||
# 重置状态
|
||||
frames_asr = []
|
||||
speech_start = False
|
||||
frames_asr_online = []
|
||||
websocket.status_dict_asr_online["cache"] = {}
|
||||
|
||||
# 如果用户停止说话,完全重置
|
||||
if not websocket.is_speaking:
|
||||
websocket.vad_pre_idx = 0
|
||||
frames = []
|
||||
websocket.status_dict_vad["cache"] = {}
|
||||
else:
|
||||
# 保留最近的帧用于下一轮处理
|
||||
frames = frames[-20:]
|
||||
|
||||
except websockets.ConnectionClosed:
|
||||
print(f"连接已关闭...", flush=True)
|
||||
await ws_reset(websocket)
|
||||
websocket_users.remove(websocket)
|
||||
except websockets.InvalidState:
|
||||
print("无效的WebSocket状态...")
|
||||
except Exception as e:
|
||||
print(f"发生异常: {e}")
|
||||
|
||||
|
||||
def start_server(args, asr_service_instance):
|
||||
"""
|
||||
启动WebSocket服务器
|
||||
|
||||
参数:
|
||||
args: 命令行参数
|
||||
asr_service_instance: ASR服务实例
|
||||
"""
|
||||
global asr_service
|
||||
asr_service = asr_service_instance
|
||||
|
||||
# 配置SSL (如果提供了证书)
|
||||
if args.certfile and len(args.certfile) > 0:
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
|
||||
|
||||
start_server = websockets.serve(
|
||||
ws_serve,
|
||||
args.host,
|
||||
args.port,
|
||||
subprotocols=["binary"],
|
||||
ping_interval=None,
|
||||
ssl=ssl_context,
|
||||
)
|
||||
else:
|
||||
start_server = websockets.serve(
|
||||
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
|
||||
)
|
||||
|
||||
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
|
||||
|
||||
# 启动事件循环
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
asyncio.get_event_loop().run_forever()
|
||||
# 挂载WebSocket路由
|
||||
app.include_router(websocket_router, prefix="/ws")
|
||||
|
||||
@app.get("/")
|
||||
async def read_root():
|
||||
return {"message": "FunASR-FastAPI WebSocket Server is running."}
|
||||
|
||||
# 如果需要直接运行此文件进行测试
|
||||
if __name__ == "__main__":
|
||||
# 解析命令行参数
|
||||
args = parse_args()
|
||||
|
||||
# 加载模型
|
||||
print("正在加载模型...")
|
||||
models = load_models(args)
|
||||
print("模型加载完成!当前仅支持单个客户端同时连接!")
|
||||
|
||||
# 创建ASR服务
|
||||
asr_service = ASRService(models)
|
||||
|
||||
# 启动服务器
|
||||
start_server(args, asr_service)
|
||||
# 注意:在生产环境中,推荐使用Gunicorn + Uvicorn workers
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
131
src/service.py
131
src/service.py
@ -1,131 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
ASR服务模块 - 提供语音识别相关的核心功能
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class ASRService:
|
||||
"""ASR服务类,封装各种语音识别相关功能"""
|
||||
|
||||
def __init__(self, models):
|
||||
"""
|
||||
初始化ASR服务
|
||||
|
||||
参数:
|
||||
models: 包含各种预加载模型的字典
|
||||
"""
|
||||
self.model_asr = models["asr"]
|
||||
self.model_asr_streaming = models["asr_streaming"]
|
||||
self.model_vad = models["vad"]
|
||||
self.model_punc = models["punc"]
|
||||
|
||||
async def async_vad(self, websocket, audio_in):
|
||||
"""
|
||||
语音活动检测
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接
|
||||
audio_in: 二进制音频数据
|
||||
|
||||
返回:
|
||||
tuple: (speech_start, speech_end) 语音开始和结束位置
|
||||
"""
|
||||
# 使用VAD模型分析音频段
|
||||
segments_result = self.model_vad.generate(
|
||||
input=audio_in, **websocket.status_dict_vad
|
||||
)[0]["value"]
|
||||
|
||||
speech_start = -1
|
||||
speech_end = -1
|
||||
|
||||
# 解析VAD结果
|
||||
if len(segments_result) == 0 or len(segments_result) > 1:
|
||||
return speech_start, speech_end
|
||||
|
||||
if segments_result[0][0] != -1:
|
||||
speech_start = segments_result[0][0]
|
||||
if segments_result[0][1] != -1:
|
||||
speech_end = segments_result[0][1]
|
||||
|
||||
return speech_start, speech_end
|
||||
|
||||
async def async_asr(self, websocket, audio_in):
|
||||
"""
|
||||
离线ASR处理
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接
|
||||
audio_in: 二进制音频数据
|
||||
"""
|
||||
if len(audio_in) > 0:
|
||||
# 使用离线ASR模型处理音频
|
||||
rec_result = self.model_asr.generate(
|
||||
input=audio_in, **websocket.status_dict_asr
|
||||
)[0]
|
||||
|
||||
# 如果有标点符号模型且识别出文本,则添加标点
|
||||
if self.model_punc is not None and len(rec_result["text"]) > 0:
|
||||
rec_result = self.model_punc.generate(
|
||||
input=rec_result["text"], **websocket.status_dict_punc
|
||||
)[0]
|
||||
|
||||
# 如果识别出文本,发送到客户端
|
||||
if len(rec_result["text"]) > 0:
|
||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": mode,
|
||||
"text": rec_result["text"],
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
||||
else:
|
||||
# 如果没有音频数据,发送空文本
|
||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": mode,
|
||||
"text": "",
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
||||
|
||||
async def async_asr_online(self, websocket, audio_in):
|
||||
"""
|
||||
在线ASR处理
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接
|
||||
audio_in: 二进制音频数据
|
||||
"""
|
||||
if len(audio_in) > 0:
|
||||
# 使用在线ASR模型处理音频
|
||||
rec_result = self.model_asr_streaming.generate(
|
||||
input=audio_in, **websocket.status_dict_asr_online
|
||||
)[0]
|
||||
|
||||
# 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理)
|
||||
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
|
||||
"is_final", False
|
||||
):
|
||||
return
|
||||
|
||||
# 如果识别出文本,发送到客户端
|
||||
if len(rec_result["text"]):
|
||||
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": mode,
|
||||
"text": rec_result["text"],
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
@ -1,3 +1,6 @@
|
||||
from .logger import get_module_logger, setup_root_logger
|
||||
|
||||
__all__ = ["get_module_logger", "setup_root_logger"]
|
||||
__all__ = [
|
||||
"get_module_logger",
|
||||
"setup_root_logger",
|
||||
]
|
||||
|
63
src/websockets/adapter.py
Normal file
63
src/websockets/adapter.py
Normal file
@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
from fastapi import WebSocket
|
||||
from typing import Union
|
||||
import uuid
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
class FastAPIWebSocketAdapter:
    """
    Adapter wrapping a FastAPI WebSocket so it exposes the async
    recv/send/close interface expected by ASRRunner, converting raw binary
    payloads into numpy arrays on the way in.
    """

    def __init__(self, websocket: "WebSocket", sample_rate: int = 16000, sample_width: int = 2):
        # Underlying FastAPI/Starlette websocket connection.
        self._ws = websocket
        # Audio parameters kept for reference; recv() currently decodes the
        # payload as float32 samples regardless of sample_width — TODO confirm.
        self._sample_rate = sample_rate
        self._sample_width = sample_width
        # Running total of received payload bytes, used for progress display.
        self._total_received = 0

    async def recv(self) -> Union[np.ndarray, None]:
        """
        Receive one message from the websocket.

        Returns:
            np.ndarray: float32 samples decoded from a binary frame.
            None: end of stream — the client sent the text "close", or the
                transport reported a disconnect.
            Empty ndarray: any other message type (callers should ignore it).
        """
        message = await self._ws.receive()

        # Treat a transport-level disconnect like an explicit end signal so
        # the consuming loop terminates instead of spinning on empty arrays.
        if message.get("type") == "websocket.disconnect":
            return None

        if message.get('bytes') is not None:
            bytes_data = message['bytes']
            # Interpret the raw byte stream as float32 samples
            # (matches the sender's float32 .tobytes() payload).
            audio_array = np.frombuffer(bytes_data, dtype=np.float32)

            # Overwrite the progress line in place using carriage return \r.
            self._total_received += len(bytes_data)
            print(f"🎧 [Adapter] 正在接收音频... 总计: {self._total_received / 1024:.2f} KB", end='\r')

            return audio_array

        text = message.get('text')
        if text is not None and text.lower() == 'close':
            print("\n🏁 [Adapter] 收到 'close' 信号。")  # newline to end the progress line
            return None  # None acts as the end-of-stream signal
        return np.array([])  # ignore any other kind of message

    async def send(self, message: dict):
        """
        Send a dict to the client as JSON.

        Before sending, recursively convert every uuid.UUID value to a
        string so the payload is JSON-serializable.
        """
        def convert_uuids(obj):
            # Walk dicts/lists recursively; stringify UUID leaves.
            if isinstance(obj, dict):
                return {k: convert_uuids(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert_uuids(elem) for elem in obj]
            elif isinstance(obj, uuid.UUID):
                return str(obj)
            return obj

        serializable_message = convert_uuids(message)
        logger.info(f"[Adapter] 发送消息: {serializable_message}")
        await self._ws.send_json(serializable_message)

    async def close(self):
        """Close the underlying websocket connection."""
        await self._ws.close()
|
89
src/websockets/endpoint/asr_endpoint.py
Normal file
89
src/websockets/endpoint/asr_endpoint.py
Normal file
@ -0,0 +1,89 @@
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Request
|
||||
from src.websockets.adapter import FastAPIWebSocketAdapter
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
@router.websocket("/asr/{session_id}")
async def asr_websocket_endpoint(
    websocket: WebSocket,
    session_id: str,
    mode: str = Query(default="sender", enum=["sender", "receiver"])
):
    """
    ASR WebSocket endpoint.

    - **session_id**: unique identifier of one recognition session.
    - **mode**: role of the connecting client.
        - `sender`: the audio source; creates a new recognition session.
        - `receiver`: a result subscriber; joins an existing session.
    """
    await websocket.accept()

    # The application-wide ASRRunner instance lives on app.state
    # (installed during application startup).
    asr_runner = websocket.app.state.asr_runner

    # Build the adapter with the same audio parameters the runner uses;
    # the two configurations must stay in sync.
    audio_config = asr_runner._default_audio_config
    adapter = FastAPIWebSocketAdapter(
        websocket,
        sample_rate=audio_config.sample_rate,
        sample_width=audio_config.sample_width
    )

    if mode == "sender":
        logger.info(f"客户端 {websocket.client} 作为 'sender' 加入会话: {session_id}")
        # A sender spins up a brand-new SAR session.
        sar_id = asr_runner.new_SAR(ws=adapter, name=session_id)
        if sar_id is None:
            logger.error(f"为会话 {session_id} 创建SAR失败")
            await websocket.close(code=1011, reason="Failed to create ASR session")
            return

        sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
        try:
            # Block here until the background task completes. The actual
            # receive loop lives in SAR._run, started by new_SAR; when the
            # client disconnects, adapter.recv() raises, the _run task
            # cleans up and finishes, which unblocks this await.
            if sar and sar._task:
                await sar._task
            else:
                # The background task never started — report and close.
                logger.error(f"SAR任务未能在会话 {session_id} 中启动")
                await websocket.close(code=1011, reason="Failed to start ASR task")

        except Exception as e:
            # Catch-all so an unexpected failure is logged rather than lost.
            logger.error(f"会话 {session_id} 的 'sender' 端点发生未知错误: {e}")
        finally:
            logger.info(f"'sender' {websocket.client} 在会话 {session_id} 的连接处理已结束")

    elif mode == "receiver":
        logger.info(f"客户端 {websocket.client} 作为 'receiver' 加入会话: {session_id}")
        # A receiver subscribes to an already-running SAR session.
        joined = asr_runner.join_SAR(ws=adapter, name=session_id)
        if not joined:
            logger.warning(f"无法找到会话 {session_id},'receiver' {websocket.client} 加入失败")
            await websocket.close(code=1011, reason=f"Session '{session_id}' not found")
            return

        try:
            # Receivers only listen for broadcasts pushed by the SAR; this
            # loop simply keeps the connection open and detects disconnects.
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            logger.info(f"'receiver' {websocket.client} 在会话 {session_id} 中断开连接")
            # Remove this receiver from the session's broadcast list.
            sar = next((s for s in asr_runner._SAR_list if s._name == session_id), None)
            if sar:
                sar.delete_receiver(adapter)
                logger.info(f"已从会话 {session_id} 中移除 'receiver' {websocket.client}")

    else:
        # Unreachable in practice: FastAPI's enum validation on `mode`
        # rejects any other value before this handler runs.
        logger.error(f"无效的模式: {mode}")
        await websocket.close(code=1003, reason="Invalid mode specified")
|
@ -1,3 +1,9 @@
|
||||
from fastapi import APIRouter

from .endpoint import asr_endpoint

# Aggregate router for all WebSocket endpoints.
# NOTE: the old absolute `from endpoint import asr_router` and the stale
# `__all__ = ["asr_router"]` were leftovers of the previous module layout;
# `asr_router` no longer exists and the absolute import breaks inside the
# package, so only the new router setup is kept.
websocket_router = APIRouter()

# Mount the ASR endpoint routes onto the aggregate router.
websocket_router.include_router(asr_endpoint.router)

__all__ = ["websocket_router"]
|
@ -6,6 +6,7 @@
|
||||
from tests.pipeline.asr_test import test_asr_pipeline
|
||||
from src.utils.logger import get_module_logger, setup_root_logger
|
||||
from tests.runner.asr_runner_test import test_asr_runner
|
||||
import asyncio
|
||||
|
||||
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||
logger = get_module_logger(__name__)
|
||||
@ -22,4 +23,4 @@ with open("logs/test_main.log", "w") as f:
|
||||
# test_asr_pipeline()

logger.info("开始测试ASRRunner")
# test_asr_runner is now a coroutine function: a bare `test_asr_runner()`
# call would only create an un-awaited coroutine object (RuntimeWarning)
# without running the test, so it must be driven through asyncio.run().
asyncio.run(test_asr_runner())
|
||||
|
@ -1,30 +1,59 @@
|
||||
"""
|
||||
ASRRunner test
|
||||
"""
|
||||
import queue
|
||||
import time
|
||||
import asyncio
|
||||
import soundfile
|
||||
import numpy as np
|
||||
from src.runner.ASRRunner import ASRRunner
|
||||
from src.core.model_loader import ModelLoader
|
||||
from src.models import AudioBinary_Config
|
||||
from src.utils.mock_websocket import MockWebSocketClient
|
||||
from asyncio import Queue as AsyncQueue
|
||||
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
def test_asr_runner():
|
||||
|
||||
class AsyncMockWebSocketClient:
    """Asynchronous stand-in for a WebSocket client, used by the tests.

    Two internal queues model the two directions of a socket:
    the test feeds inbound data with :meth:`put_for_recv` and reads
    outbound results with :meth:`get_from_send`, while the code under
    test uses the WebSocket-like :meth:`recv` / :meth:`send` pair.
    """

    def __init__(self):
        # Inbound queue: items the code under test will pull via recv().
        self._recv_q = AsyncQueue()
        # Outbound queue: items the code under test pushes via send().
        self._send_q = AsyncQueue()

    def put_for_recv(self, item):
        """Feed *item* into the mock socket so a later recv() returns it."""
        self._recv_q.put_nowait(item)

    async def recv(self):
        """Awaited by the ASRRunner to obtain the next inbound item."""
        return await self._recv_q.get()

    async def send(self, item):
        """Invoked (via callback) by the ASRRunner to publish a result."""
        logger.info(f"Mock WS 收到结果: {item}")
        await self._send_q.put(item)

    async def get_from_send(self):
        """Let the test pop the next result the runner sent out."""
        return await self._send_q.get()

    async def close(self):
        """No-op close, present only to satisfy the WebSocket interface."""
        pass
|
||||
|
||||
|
||||
async def test_asr_runner():
|
||||
"""
|
||||
End-to-end test for ASRRunner.
|
||||
1. Loads models.
|
||||
2. Configures and initializes ASRRunner.
|
||||
3. Creates a mock WebSocket client.
|
||||
4. Starts a new SenderAndReceiver (SAR) instance in the runner.
|
||||
5. Streams audio data via the mock WebSocket.
|
||||
6. Asserts that the received transcription matches the expected text.
|
||||
针对ASRRunner的端到端测试,已适配异步操作。
|
||||
1. 加载模型.
|
||||
2. 配置并初始化ASRRunner.
|
||||
3. 创建一个异步的模拟WebSocket客户端.
|
||||
4. 在Runner中启动一个新的SenderAndReceiver (SAR)实例.
|
||||
5. 通过模拟的WebSocket流式传输音频数据.
|
||||
6. 等待处理任务完成并断言其无错误运行.
|
||||
"""
|
||||
# 1. Load models
|
||||
# 1. 加载模型
|
||||
model_loader = ModelLoader()
|
||||
args = {
|
||||
"asr_model": "paraformer-zh",
|
||||
@ -33,48 +62,71 @@ def test_asr_runner():
|
||||
"vad_model_revision": "v2.0.4",
|
||||
"spk_model": "cam++",
|
||||
"spk_model_revision": "v2.0.2",
|
||||
"audio_update": False,
|
||||
}
|
||||
models = model_loader.load_models(args)
|
||||
audio_file_path = "tests/XT_ZZY_denoise.wav"
|
||||
audio_data, sample_rate = soundfile.read(audio_file_path)
|
||||
logger.info(f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}")
|
||||
# 2. Configure audio
|
||||
logger.info(
|
||||
f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
|
||||
)
|
||||
# 进一步详细打印audio_data数据类型
|
||||
# 详细打印audio_data的类型和结构信息,便于调试
|
||||
logger.info(f"audio_data 类型: {type(audio_data)}")
|
||||
logger.info(f"audio_data dtype: {getattr(audio_data, 'dtype', '未知')}")
|
||||
logger.info(f"audio_data shape: {getattr(audio_data, 'shape', '未知')}")
|
||||
logger.info(f"audio_data ndim: {getattr(audio_data, 'ndim', '未知')}")
|
||||
logger.info(f"audio_data 示例前10个值: {audio_data[:10] if hasattr(audio_data, '__getitem__') else '不可切片'}")
|
||||
|
||||
# 2. 配置音频
|
||||
audio_config = AudioBinary_Config(
|
||||
chunk_size=200, # ms
|
||||
chunk_stride=1000, # 10ms stride for 16kHz
|
||||
sample_rate=sample_rate,
|
||||
sample_width=2, # 16-bit
|
||||
channels=2,
|
||||
channels=1,
|
||||
)
|
||||
audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||
|
||||
# 3. Setup ASRRunner
|
||||
# 3. 设置ASRRunner
|
||||
asr_runner = ASRRunner()
|
||||
asr_runner.set_default_config(
|
||||
audio_config=audio_config,
|
||||
models=models,
|
||||
)
|
||||
|
||||
# 4. Create Mock WebSocket and start SAR
|
||||
mock_ws = MockWebSocketClient()
|
||||
# 4. 创建模拟WebSocket并启动SAR
|
||||
mock_ws = AsyncMockWebSocketClient()
|
||||
sar_id = asr_runner.new_SAR(
|
||||
ws=mock_ws,
|
||||
name="test_sar",
|
||||
)
|
||||
assert sar_id is not None, "Failed to create a new SAR instance"
|
||||
assert sar_id is not None, "创建新的SAR实例失败"
|
||||
|
||||
# 5. Simulate streaming audio
|
||||
print(f"Sending audio data of length {len(audio_data)} samples.")
|
||||
audio_clip_len = 200
|
||||
for i in range(0, len(audio_data), audio_clip_len):
|
||||
chunk = audio_data[i : i + audio_clip_len]
|
||||
if not isinstance(chunk, np.ndarray) or chunk.size == 0:
|
||||
break
|
||||
# Simulate receiving binary data over WebSocket
|
||||
mock_ws.put_for_recv(chunk)
|
||||
# 获取SAR实例以等待其任务
|
||||
sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
|
||||
assert sar is not None, "无法从Runner中获取SAR实例。"
|
||||
assert sar._task is not None, "SAR任务未被创建。"
|
||||
|
||||
# 6. Wait for results and assert
|
||||
time.sleep(30)
|
||||
# Signal end of audio stream by sending None
|
||||
mock_ws.put_for_recv(None)
|
||||
# 5. 在后台任务中模拟流式音频
|
||||
async def feed_audio():
|
||||
logger.info("Feeder任务已启动:开始流式传输音频数据...")
|
||||
# 每次发送100ms的音频
|
||||
audio_clip_len = int(sample_rate * 0.1)
|
||||
for i in range(0, len(audio_data), audio_clip_len):
|
||||
chunk = audio_data[i : i + audio_clip_len]
|
||||
if chunk.size == 0:
|
||||
break
|
||||
mock_ws.put_for_recv(chunk)
|
||||
await asyncio.sleep(0.1) # 模拟实时流
|
||||
|
||||
# 发送None来表示音频流结束
|
||||
mock_ws.put_for_recv(None)
|
||||
logger.info("Feeder任务已完成:所有音频数据已发送。")
|
||||
|
||||
feeder_task = asyncio.create_task(feed_audio())
|
||||
|
||||
# 6. 等待SAR处理完成
|
||||
# SAR任务在从模拟WebSocket接收到None后会结束
|
||||
await sar._task
|
||||
await feeder_task # 确保feeder也已完成
|
||||
|
||||
logger.info("ASRRunner测试成功完成。")
|
||||
|
110
tests/websocket/websocket_asr.py
Normal file
110
tests/websocket/websocket_asr.py
Normal file
@ -0,0 +1,110 @@
|
||||
import asyncio
|
||||
import websockets
|
||||
import soundfile as sf
|
||||
import uuid
|
||||
|
||||
# --- Configuration ---
HOST = "localhost"
PORT = 8000

# One fresh session id per test run; the sender and every receiver must
# share the same id to join the same ASR session.
SESSION_ID = str(uuid.uuid4())
SENDER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=sender"
RECEIVER_URI = f"ws://{HOST}:{PORT}/ws/asr/{SESSION_ID}?mode=receiver"

# Test fixture; expected to be 16 kHz mono audio.
# NOTE(review): the payload is actually sent as float32 (4 bytes/sample)
# below, while this byte count assumes 16-bit PCM — the derived sample
# count still works out to 100 ms, but confirm the intended wire format.
AUDIO_FILE_PATH = "tests/XT_ZZY_denoise.wav"
CHUNK_DURATION_MS = 100                                 # audio per message
CHUNK_SIZE = int(16000 * 2 * CHUNK_DURATION_MS / 1000)  # 3200 bytes
|
||||
|
||||
async def run_receiver():
    """Connect in 'receiver' mode and print every broadcast message.

    Runs until the server closes the connection (normal end of session)
    or the initial connection attempt fails.
    """
    print(f"▶️ [Receiver] 尝试连接到: {RECEIVER_URI}")
    try:
        async with websockets.connect(RECEIVER_URI) as websocket:
            print("✅ [Receiver] 连接成功,等待消息...")
            try:
                # Drain messages forever; exit via ConnectionClosed.
                while True:
                    message = await websocket.recv()
                    print(f"🎧 [Receiver] 收到结果: {message}")
            except websockets.exceptions.ConnectionClosed as closed:
                print(f"✅ [Receiver] 连接已由服务器正常关闭: {closed.reason}")
    except Exception as e:
        print(f"❌ [Receiver] 连接失败: {e}")
|
||||
|
||||
async def run_sender():
    """Connect in 'sender' mode: stream the audio file and, in parallel,
    listen for the results broadcast back on the same connection.

    The audio is read as float32 and sent as raw ``tobytes()`` chunks of
    ``CHUNK_DURATION_MS`` each, paced in (rough) real time; a final text
    message ``"close"`` tells the server the stream has ended.
    """
    await asyncio.sleep(1)  # give the receiver a chance to connect first
    print(f"▶️ [Sender] 尝试连接到: {SENDER_URI}")
    try:
        async with websockets.connect(SENDER_URI) as websocket:
            print("✅ [Sender] 连接成功。")

            # --- Parallel task: listen for broadcast results ---
            async def receive_task():
                print("▶️ [Sender-Receiver] 开始监听广播消息...")
                try:
                    while True:
                        message = await websocket.recv()
                        print(f"🎧 [Sender-Receiver] 收到结果: {message}")
                except websockets.exceptions.ConnectionClosed:
                    print("✅ [Sender-Receiver] 连接已关闭,停止监听。")

            receiver_sub_task = asyncio.create_task(receive_task())

            # --- Main task: stream the audio ---
            try:
                print("▶️ [Sender] 准备发送音频...")
                audio_data, sample_rate = sf.read(AUDIO_FILE_PATH, dtype='float32')
                if sample_rate != 16000:
                    print("❌ [Sender] 错误:音频文件采样率必须是 16kHz。")
                    receiver_sub_task.cancel()
                    return

                total_samples = len(audio_data)
                # Samples per chunk = 100 ms of 16 kHz audio (1600).
                # Derived from the duration directly instead of
                # CHUNK_SIZE // 2: CHUNK_SIZE assumes 2-byte (16-bit)
                # samples, but the payload below is float32, so coupling
                # the sample count to that byte constant was misleading.
                chunk_samples = int(16000 * CHUNK_DURATION_MS / 1000)
                samples_sent = 0
                print(f"音频加载成功,总长度: {total_samples} samples。开始分块发送...")

                for i in range(0, total_samples, chunk_samples):
                    chunk = audio_data[i:i + chunk_samples]
                    if len(chunk) == 0:
                        break
                    await websocket.send(chunk.tobytes())
                    samples_sent += len(chunk)
                    print(f"🎧 [Sender] 正在发送: {samples_sent}/{total_samples} samples", end="\r")
                    await asyncio.sleep(CHUNK_DURATION_MS / 1000)  # pace in real time

                print()
                print("🏁 [Sender] 音频流发送完毕,发送 'close' 信号。")
                await websocket.send("close")

            except FileNotFoundError:
                print(f"❌ [Sender] 错误:找不到音频文件 {AUDIO_FILE_PATH}")
                # BUGFIX: without cancelling here, the `await` below would
                # block forever on an open, idle connection.
                receiver_sub_task.cancel()
            except Exception as e:
                print(f"❌ [Sender] 发送过程中发生错误: {e}")
                receiver_sub_task.cancel()

            # Wait for the listener to finish: it ends naturally when the
            # server closes the connection, or via the cancellations above.
            try:
                await receiver_sub_task
            except asyncio.CancelledError:
                pass

    except Exception as e:
        print(f"❌ [Sender] 连接失败: {e}")
|
||||
|
||||
async def main():
    """Run the sender and the receiver tasks against a live server."""
    print("--- 开始 WebSocket ASR 服务端到端测试 ---")
    print(f"会话 ID: {SESSION_ID}")

    # Start the sender first; the receiver joins the session 7 s later.
    # NOTE(review): run_sender itself sleeps 1 s "so the receiver can
    # connect first", which this 7 s delay defeats — confirm the intent.
    sender_task = asyncio.create_task(run_sender())
    await asyncio.sleep(7)
    receiver_task = asyncio.create_task(run_receiver())

    # Block until both sides have finished.
    await asyncio.gather(receiver_task, sender_task)

    print("--- 测试结束 ---")


if __name__ == "__main__":
    # The FastAPI server must already be running:  python main.py
    asyncio.run(main())
|
Loading…
x
Reference in New Issue
Block a user