STT_Server/tests/runner/asr_runner_test.py

133 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
ASRRunner test
"""
import asyncio
import soundfile
import numpy as np
from src.runner.ASRRunner import ASRRunner
from src.core.model_loader import ModelLoader
from src.models import AudioBinary_Config
from asyncio import Queue as AsyncQueue
from src.utils.logger import get_module_logger
# Module-level logger used by the mock WebSocket client and the test below.
logger = get_module_logger(__name__)
class AsyncMockWebSocketClient:
    """Asynchronous WebSocket client stand-in for tests.

    Two asyncio queues back the mock: one holds the items that ``recv``
    hands out, the other collects the items passed to ``send``. Both are
    created with the default ``maxsize`` and are therefore unbounded.
    """

    def __init__(self) -> None:
        # Items waiting to be returned by ``recv``.
        self._recv_q = AsyncQueue()
        # Items that were passed to ``send``.
        self._send_q = AsyncQueue()

    def put_for_recv(self, item) -> None:
        """Queue ``item`` so that a later ``recv`` call returns it.

        Synchronous on purpose: ``put_nowait`` never blocks on an unbounded
        queue, so tests can call this without awaiting.
        """
        self._recv_q.put_nowait(item)

    async def get_from_send(self):
        """Wait for and return the next item that was passed to ``send``."""
        return await self._send_q.get()

    async def recv(self):
        """Wait for and return the next item queued via ``put_for_recv``.

        Per the original author's note, ASRRunner calls this to obtain data.
        """
        return await self._recv_q.get()

    async def send(self, item) -> None:
        """Log ``item`` and store it so ``get_from_send`` can retrieve it.

        Per the original author's note, ASRRunner calls this through a
        callback to deliver results.
        """
        logger.info(f"Mock WS 收到结果: {item}")
        await self._send_q.put(item)

    async def close(self) -> None:
        """Mock close method; intentionally does nothing."""
        pass
async def test_asr_runner():
    """
    End-to-end test of ASRRunner, adapted for asynchronous operation.

    Steps:
        1. Load the models.
        2. Build the audio configuration from the test recording.
        3. Create and configure the ASRRunner.
        4. Create an async mock WebSocket client and start a new
           SenderAndReceiver (SAR) instance on it.
        5. Stream the audio data through the mock WebSocket.
        6. Wait for the SAR task and the feeder to finish. An exception
           raised by either one propagates out of this coroutine and so
           fails the test.
    """
    # 1. Load the models.
    model_loader = ModelLoader()
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
    }
    models = model_loader.load_models(args)

    audio_file_path = "tests/XT_ZZY_denoise.wav"
    # NOTE(review): soundfile.read returns float64 samples by default, while
    # the config below declares sample_width=2 (16-bit) -- confirm this is
    # what the runner expects.
    audio_data, sample_rate = soundfile.read(audio_file_path)
    logger.info(
        f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
    )
    # Log the type and layout of audio_data in detail to ease debugging.
    logger.info(f"audio_data 类型: {type(audio_data)}")
    logger.info(f"audio_data dtype: {getattr(audio_data, 'dtype', '未知')}")
    logger.info(f"audio_data shape: {getattr(audio_data, 'shape', '未知')}")
    logger.info(f"audio_data ndim: {getattr(audio_data, 'ndim', '未知')}")
    logger.info(f"audio_data 示例前10个值: {audio_data[:10] if hasattr(audio_data, '__getitem__') else '不可切片'}")

    # 2. Configure the audio.
    # NOTE(review): channels=1 is hard-coded; presumably the test recording
    # is mono -- verify against the file.
    audio_config = AudioBinary_Config(
        chunk_size=200,  # ms
        sample_rate=sample_rate,
        sample_width=2,  # 16-bit
        channels=1,
    )
    # Number of samples covered by one chunk_size-millisecond chunk.
    audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)

    # 3. Set up the ASRRunner.
    asr_runner = ASRRunner()
    asr_runner.set_default_config(
        audio_config=audio_config,
        models=models,
    )

    # 4. Create the mock WebSocket and start a SAR on it.
    mock_ws = AsyncMockWebSocketClient()
    sar_id = asr_runner.new_SAR(
        ws=mock_ws,
        name="test_sar",
    )
    assert sar_id is not None, "创建新的SAR实例失败"

    # Look up the SAR instance so its task can be awaited below.
    sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
    assert sar is not None, "无法从Runner中获取SAR实例。"
    assert sar._task is not None, "SAR任务未被创建。"

    # 5. Simulate streaming audio from a background task.
    async def feed_audio():
        """Push the recording into the mock WebSocket in 100 ms clips."""
        logger.info("Feeder任务已启动开始流式传输音频数据...")
        # Send 100 ms of audio per message.
        audio_clip_len = int(sample_rate * 0.1)
        for i in range(0, len(audio_data), audio_clip_len):
            chunk = audio_data[i : i + audio_clip_len]
            if chunk.size == 0:
                break
            mock_ws.put_for_recv(chunk)
            await asyncio.sleep(0.1)  # simulate a real-time stream
        # Send None to signal the end of the audio stream.
        mock_ws.put_for_recv(None)
        logger.info("Feeder任务已完成所有音频数据已发送。")

    feeder_task = asyncio.create_task(feed_audio())

    # 6. Wait for the SAR to finish processing.
    # Per the original author's note, the SAR task ends once it receives None
    # from the mock WebSocket. Both tasks are awaited together so that a
    # failure in either one surfaces immediately: awaiting the SAR task alone
    # would wait indefinitely if the feeder died before sending the
    # end-of-stream None.
    try:
        await asyncio.gather(sar._task, feeder_task)
    finally:
        # If the SAR task failed first, the feeder is still streaming; stop
        # it and drain it so no pending task outlives the test.
        if not feeder_task.done():
            feeder_task.cancel()
            try:
                await feeder_task
            except asyncio.CancelledError:
                pass
    logger.info("ASRRunner测试成功完成。")