133 lines
4.5 KiB
Python
133 lines
4.5 KiB
Python
"""
|
||
ASRRunner test
|
||
"""
|
||
import asyncio
|
||
import soundfile
|
||
import numpy as np
|
||
from src.runner.ASRRunner import ASRRunner
|
||
from src.core.model_loader import ModelLoader
|
||
from src.models import AudioBinary_Config
|
||
from asyncio import Queue as AsyncQueue
|
||
|
||
from src.utils.logger import get_module_logger
|
||
|
||
# Module-level logger for this test module, obtained from the project's
# get_module_logger helper and keyed by this module's name.
logger = get_module_logger(__name__)
|
||
|
||
|
||
class AsyncMockWebSocketClient:
    """Asynchronous WebSocket client stand-in used for testing.

    Two independent asyncio queues back the mock:

    * the receive queue holds items the test injects and that ``recv()``
      hands out;
    * the send queue holds items passed to ``send()`` and that the test
      reads back with ``get_from_send()``.
    """

    def __init__(self) -> None:
        # Test -> consumer direction (drained by recv()).
        self._recv_q: asyncio.Queue = asyncio.Queue()
        # Consumer -> test direction (filled by send()).
        self._send_q: asyncio.Queue = asyncio.Queue()

    def put_for_recv(self, item) -> None:
        """Let the test push ``item`` into the mock socket without awaiting."""
        inbound = self._recv_q
        inbound.put_nowait(item)

    async def get_from_send(self):
        """Let the test fetch the next item that was passed to ``send()``."""
        result = await self._send_q.get()
        return result

    async def recv(self):
        """Return the next injected item, waiting until one is available."""
        incoming = await self._recv_q.get()
        return incoming

    async def send(self, item) -> None:
        """Log ``item`` and store it so the test can retrieve it later."""
        logger.info(f"Mock WS 收到结果: {item}")
        await self._send_q.put(item)

    async def close(self) -> None:
        """Mock close method; intentionally does nothing."""
        return None
|
||
|
||
|
||
async def test_asr_runner():
    """
    End-to-end test for ASRRunner, adapted to asynchronous operation.

    Steps:
      1. Load the models.
      2. Build the audio configuration and initialise the ASRRunner.
      3. Create an asynchronous mock WebSocket client.
      4. Start a new SenderAndReceiver (SAR) instance in the runner.
      5. Stream the audio data through the mock WebSocket.
      6. Wait for the processing task to finish; an exception raised by
         that task propagates out of this coroutine and fails the test.

    Raises:
        AssertionError: if no SAR id is returned, the SAR instance cannot be
            found in the runner, or the SAR has no task.
    """
    # 1. Load the models.
    model_loader = ModelLoader()
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
    }
    models = model_loader.load_models(args)

    audio_file_path = "tests/XT_ZZY_denoise.wav"
    # NOTE(review): soundfile.read returns float64 samples by default, while
    # the config below declares a 16-bit sample width -- confirm that
    # ASRRunner accepts (or converts) this representation.
    audio_data, sample_rate = soundfile.read(audio_file_path)
    logger.info(
        f"加载数据: {audio_file_path} , audio_data_length: {len(audio_data)}, audio_data_type: {type(audio_data)}, sample_rate: {sample_rate}"
    )
    # Log the type and structure of audio_data in detail to ease debugging.
    logger.info(f"audio_data 类型: {type(audio_data)}")
    logger.info(f"audio_data dtype: {getattr(audio_data, 'dtype', '未知')}")
    logger.info(f"audio_data shape: {getattr(audio_data, 'shape', '未知')}")
    logger.info(f"audio_data ndim: {getattr(audio_data, 'ndim', '未知')}")
    logger.info(f"audio_data 示例前10个值: {audio_data[:10] if hasattr(audio_data, '__getitem__') else '不可切片'}")

    # 2. Configure the audio.
    audio_config = AudioBinary_Config(
        chunk_size=200,  # ms
        sample_rate=sample_rate,
        sample_width=2,  # 16-bit
        channels=1,
    )
    # Chunk stride in samples: chunk duration (ms) * samples per ms.
    audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)

    # 3. Set up the ASRRunner.
    asr_runner = ASRRunner()
    asr_runner.set_default_config(
        audio_config=audio_config,
        models=models,
    )

    # 4. Create the mock WebSocket and start the SAR.
    mock_ws = AsyncMockWebSocketClient()
    sar_id = asr_runner.new_SAR(
        ws=mock_ws,
        name="test_sar",
    )
    assert sar_id is not None, "创建新的SAR实例失败"

    # Look up the SAR instance so its task can be awaited below.
    sar = next((s for s in asr_runner._SAR_list if s._id == sar_id), None)
    assert sar is not None, "无法从Runner中获取SAR实例。"
    assert sar._task is not None, "SAR任务未被创建。"

    # 5. Simulate streaming audio from a background task.
    async def feed_audio():
        logger.info("Feeder任务已启动:开始流式传输音频数据...")
        # Send 100 ms of audio per iteration.
        audio_clip_len = int(sample_rate * 0.1)
        # range() stops before len(audio_data), so every slice holds at
        # least one frame; no empty-chunk check is needed.
        for start in range(0, len(audio_data), audio_clip_len):
            mock_ws.put_for_recv(audio_data[start : start + audio_clip_len])
            await asyncio.sleep(0.1)  # simulate a real-time stream

        # A None item marks the end of the audio stream.
        mock_ws.put_for_recv(None)
        logger.info("Feeder任务已完成:所有音频数据已发送。")

    feeder_task = asyncio.create_task(feed_audio())

    # 6. Wait for the SAR to finish processing.
    # The SAR task is expected to end once it receives None from the mock
    # WebSocket.
    try:
        await sar._task
        await feeder_task  # make sure the feeder has finished as well
    finally:
        # If the SAR task failed (or this coroutine was cancelled) while the
        # feeder is still streaming, stop the feeder instead of leaking it.
        if not feeder_task.done():
            feeder_task.cancel()
            # Drain the cancelled task; return_exceptions=True keeps its
            # CancelledError from masking the original exception.
            await asyncio.gather(feeder_task, return_exceptions=True)

    logger.info("ASRRunner测试成功完成。")
|