[Runner]完成ASRRunner的编写和测试,使用MockWebSocket完成虚拟网络连接。

This commit is contained in:
Ziyang.Zhang 2025-06-25 16:57:41 +08:00
parent d5b9953905
commit 5dac718dee
10 changed files with 317 additions and 192 deletions

View File

@ -128,7 +128,7 @@ class AudioChunk:
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑
"""
_instance: Optional[AudioChunk] = None
_instance: Optional["AudioChunk"] = None
def __new__(cls, *args, **kwargs):
"""

View File

@ -1,11 +1,25 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
默认配置DefaultConfig
- audio_config: 音频配置
配置模块 - 处理命令行参数和配置项
"""
import argparse
from src.models import AudioBinary_Config
class DefaultConfig:
"""
默认配置
"""
audio_config = AudioBinary_Config(
chunk_size=200,
chunk_stride=1600,
sample_rate=16000,
sample_width=16,
channels=1,
)
def parse_args():
"""

View File

@ -136,7 +136,7 @@ class PipelineFactory:
pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"])
pipeline.set_callback(kwargs["callback"])
pipeline.bake()
# pipeline.bake()
return pipeline
@classmethod

View File

@ -10,9 +10,15 @@ from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.core.model_loader import ModelLoader
from src.config import DefaultConfig
from queue import Queue
import soundfile
import time
from typing import List, Optional
import uuid
from threading import Thread
from src.utils.mock_websocket import MockWebSocketClient as WebSocketClient
from .runner import RunnerBase
from src.utils.logger import get_module_logger
@ -23,205 +29,175 @@ OVAERWATCH = False
model_loader = ModelLoader()
def test_asr_pipeline():
# 加载模型
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
"audio_update": False,
}
models = model_loader.load_models(args)
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
audio_config = AudioBinary_Config(
chunk_size=200,
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
# 创建参数Dict
config = {
"audio_config": audio_config,
}
# 创建音频数据列表
audio_binary_data_list = AudioBinary_data_list()
input_queue = Queue()
# 创建Pipeline
# asr_pipeline = ASRPipeline()
# asr_pipeline.set_models(models)
# asr_pipeline.set_config(config)
# asr_pipeline.set_audio_binary(audio_binary_data_list)
# asr_pipeline.set_input_queue(input_queue)
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
# asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name = "ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)
# 运行Pipeline
asr_instance = asr_pipeline.run()
audio_clip_len = 200
print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i : i + audio_clip_len])
# time.sleep(10)
# input_queue.put(None)
# 等待Pipeline结束
# asr_instance.join()
time.sleep(5)
asr_pipeline.stop()
# asr_pipeline.stop()
class STTRunner(RunnerBase):
class ASRRunner(RunnerBase):
"""
运行器类
负责管理资源和协调Pipeline的运行
"""
def __init__(
self,
*args,
**kwargs,
):
class SenderAndReceiver:
"""
对于单个pipeline的管理
包含 发送者 接收者
_sender: 发送者 唯一
_receiver: 接收者 可以有多个
_pipeline: 对应管道 唯一
"""
# ws资源
self._ws_pool: Dict[str,List[WebSocketClient]] = {}
# 接收资源
self._audio_binary_list = audio_binary_list
self._models = models
self._pipeline_list = pipeline_list
def __init__(self, *args, **kwargs):
# 可选传入参数,
self._name: str = kwargs.get("name", "")
self._sender: Optional[WebSocketClient] = kwargs.get("sender", None)
self._receiver: List[WebSocketClient] = kwargs.get("receiver", [])
# 线程控制
self._lock = Lock()
# 停止控制
self._stop_timeout = 10.0
self._is_stopping = False
# 资源
self._audio_config: AudioBinary_Config = kwargs.get("audio_config", DefaultConfig.audio_config)
self._models: dict = kwargs.get("models", None)
self._audio_binary: AudioBinary_data_list = AudioBinary_data_list()
# id唯一标识
self._id: str = str(uuid.uuid4())
# 输入队列
self._input_queue: Queue = Queue()
self._pipeline: Optional[ASRPipeline] = None
# 配置资源
for pipeline in self._pipeline_list:
# 设置输入队列
pipeline.set_input_queue(self._input_queue)
def set_name(self, name: str):
self._name = name
# 配置资源
pipeline.set_audio_binary(
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
)
pipeline.set_models(self._models)
def set_id(self, id: str):
self._id = id
def run(self) -> None:
def set_sender(self, sender: WebSocketClient):
self._sender = sender
def set_pipeline(self, pipeline: ASRPipeline):
self._pipeline = pipeline
config = {
"audio_config": self._audio_config,
}
self._pipeline.set_config(config)
self._pipeline.set_models(self._models)
self._pipeline.set_audio_binary(self._audio_binary)
self._pipeline.set_input_queue(self._input_queue)
self._pipeline.set_callback(self.deal_message)
self._pipeline.bake()
def append_receiver(self, receiver: WebSocketClient):
self._receiver.append(receiver)
def delete_receiver(self, receiver: WebSocketClient):
self._receiver.remove(receiver)
def deal_message(self, message: str):
self.broadcast(message)
def broadcast(self, message: str):
"""
启动所有管道
广播发送给所有接收者
"""
logger.info("[%s] 启动所有管道", self.__class__.__name__)
if not self._pipeline_list:
raise RuntimeError("没有可用的管道")
logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: %s", self._name, message)
for receiver in self._receiver:
receiver.send(message)
# 启动所有管道
for pipeline in self._pipeline_list:
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
thread.daemon = True
thread.start()
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
def stop(self, force: bool = False) -> bool:
def _run(self):
"""
停止所有管道
参数:
force: 是否强制停止
返回:
bool: 是否成功停止
运行SAR
"""
if self._is_stopping:
logger.warning("运行器已经在停止中")
return False
self._is_stopping = True
logger.info("正在停止运行器...")
try:
# 发送结束信号
self._input_queue.put(None)
# 停止所有管道
success = True
for pipeline in self._pipeline_list:
if force:
pipeline.force_stop()
else:
if not pipeline.stop(timeout=self._stop_timeout):
logger.warning("管道 %s 停止超时", id(pipeline))
success = False
# 等待队列处理完成
try:
start_time = time.time()
while not self._input_queue.empty():
if time.time() - start_time > self._stop_timeout:
logger.warning(
"等待队列处理完成超时(%s秒), 队列中还有 %d 个任务未处理",
self._stop_timeout,
self._input_queue.qsize(),
)
success = False
self._pipeline.run()
while True:
data = self._sender.recv()
if data is None:
break
time.sleep(0.1) # 避免过度消耗CPU
logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data))
self._input_queue.put(data)
self.stop()
def run(self):
"""
运行SAR
"""
self._thread = Thread(target=self._run, name=f"[ASRRunner]SAR-{self._name}")
self._thread.daemon = True
self._thread.start()
def stop(self):
"""
停止SAR
"""
self._pipeline.stop()
for ws in self._receiver:
ws.close()
self._sender.close()
def __init__(self,*args,**kwargs):
"""
"""
# 接收资源
self._default_audio_config = kwargs.get("audio_config", DefaultConfig.audio_config)
# self._audio_binary_list = args.get("audio_binary_list", None)
self._default_models = kwargs.get("models", None)
self._SAR_list: List[self.SenderAndReceiver] = []
def set_default_config(self, *args, **kwargs):
"""
设置配置
"""
self._default_audio_config = kwargs.get("audio_config", self._default_audio_config)
self._default_models = kwargs.get("models", self._default_models)
def new_SAR(
self,
ws: "WebSocketClient",
name: str = "",
audio_config: "AudioBinary_Config" = None,
models: dict = None
) -> uuid.UUID:
"""
创建新的SAR SenderAndReceiver
"""
if audio_config is None:
audio_config = self._default_audio_config
if models is None:
models = self._default_models
try:
new_SAR = self.SenderAndReceiver(
name=name,
audio_config=audio_config,
models=models
)
new_pipeline = ASRPipeline()
new_SAR.set_pipeline(new_pipeline)
# new_SAR.set_pipeline()
logger.info("创建新的SAR: name %s, id %s", new_SAR._name, new_SAR._id)
new_SAR.set_sender(ws)
new_SAR.append_receiver(ws)
new_SAR.run()
self._SAR_list.append(new_SAR)
return new_SAR._id
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
error_traceback = traceback.format_exc()
logger.error(
"等待队列处理完成时发生错误:\n"
"错误类型: %s\n"
"错误信息: %s\n"
"错误堆栈:\n%s",
error_type,
error_msg,
error_traceback,
)
success = False
logger.error("创建管道失败: %s", e)
return None
if success:
logger.info("所有管道已成功停止")
else:
logger.warning(
"部分管道停止失败, 队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(),
self._input_queue.empty(),
)
return success
finally:
self._is_stopping = False
def join_SAR(
self,
ws: "WebSocketClient",
name: Optional[str] = None,
id: Optional[str] = None,
) -> bool:
"""
加入SAR的Receiver
"""
# 使用next获取迭代器下一个元素生成pipeline_list迭代器按id停止
if id:
exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._id == id), None)
if name:
exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._name == name), None)
if exist_pipeline:
exist_pipeline.append_receiver(ws)
return True
return False
def __del__(self) -> None:
"""
析构函数
"""
self.stop(force=True)
for sar in self._SAR_list:
sar.stop()

View File

@ -18,8 +18,8 @@ from queue import Queue
import traceback
import time
from src.audio_chunk import AudioChunk, AudioBinary
from src.pipeline import Pipeline, PipelineFactory
from src.audio_chunk import AudioChunk
from src.pipeline import PipelineFactory
from src.core.model_loader import ModelLoader
from src.utils.logger import get_module_logger
@ -48,7 +48,7 @@ class STTRunnerFactory:
audio_binary_name: str,
model_name_list: List[str],
pipeline_name_list: List[str],
) -> STTRunner:
) -> RunnerBase:
"""
创建运行器
参数:
@ -67,7 +67,7 @@ class STTRunnerFactory:
PipelineFactory.create_pipeline(pipeline_name)
for pipeline_name in pipeline_name_list
]
return STTRunner(
return RunnerBase(
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
)
@ -75,7 +75,7 @@ class STTRunnerFactory:
def create_runner_from_config(
cls,
config: Dict[str, Any],
) -> STTRunner:
) -> RunnerBase:
"""
从配置创建运行器
参数:
@ -91,7 +91,7 @@ class STTRunnerFactory:
)
@classmethod
def create_runner_normal(cls) -> STTRunner:
def create_runner_normal(cls) -> RunnerBase:
"""
创建默认运行器
返回:

View File

@ -0,0 +1,55 @@
import queue
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class MockWebSocketClient:
"""A mock WebSocket client to simulate a connection for testing."""
def __init__(self):
self.sent_messages = []
self._is_closed = False
self.receive_queue = queue.Queue()
def send(self, message: dict):
"""Simulates sending a message (which is a dict)."""
if self._is_closed:
print("Warning: sending message on a closed websocket")
return
self.sent_messages.append(message)
print(f"Mock WS received: {message}")
def recv(self):
"""Simulates receiving data from the WebSocket."""
if self._is_closed:
return None
try:
# Block until data is available, with a timeout to prevent hanging.
data = self.receive_queue.get(timeout=10)
if data is None:
self._is_closed = True
return data
except queue.Empty:
print("Mock WS recv timeout")
self._is_closed = True
return None
def close(self):
"""Simulates closing the WebSocket connection."""
if not self._is_closed:
# Put None to unblock any waiting recv call
self.receive_queue.put(None)
self._is_closed = True
print("Mock WS closed")
def put_for_recv(self, data):
"""Puts data into the receive queue for the `recv` method to consume."""
if data is None:
return
logger.debug("Mock WS put_for_recv length: %s", len(data))
self.receive_queue.put(data)
@property
def is_closed(self):
return self._is_closed

View File

@ -5,6 +5,7 @@
from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger
from tests.runner.stt_runner_test import test_asr_runner
setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)
@ -13,5 +14,8 @@ logger = get_module_logger(__name__)
# logger.info("开始测试VAD函数器")
# test_vad_functor()
logger.info("开始测试ASR管道")
test_asr_pipeline()
# logger.info("开始测试ASR管道")
# test_asr_pipeline()
logger.info("开始测试ASRRunner")
test_asr_runner()

View File

@ -71,6 +71,7 @@ def test_asr_pipeline():
callback=lambda x: print(f"pipeline callback: {x}")
)
asr_pipeline.bake()
# 运行Pipeline
asr_instance = asr_pipeline.run()

View File

@ -0,0 +1,75 @@
"""
ASRRunner test
"""
import queue
import time
import soundfile
import numpy as np
from src.runner.ASRRunner import ASRRunner
from src.core.model_loader import ModelLoader
from src.models import AudioBinary_Config
from src.utils.mock_websocket import MockWebSocketClient
def test_asr_runner():
"""
End-to-end test for ASRRunner.
1. Loads models.
2. Configures and initializes ASRRunner.
3. Creates a mock WebSocket client.
4. Starts a new SenderAndReceiver (SAR) instance in the runner.
5. Streams audio data via the mock WebSocket.
6. Asserts that the received transcription matches the expected text.
"""
# 1. Load models
model_loader = ModelLoader()
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
"audio_update": False,
}
models = model_loader.load_models(args)
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
# 2. Configure audio
audio_config = AudioBinary_Config(
chunk_size=200, # ms
chunk_stride=1600, # 10ms stride for 16kHz
sample_rate=sample_rate,
sample_width=2, # 16-bit
channels=1,
)
audio_config.chunk_stride = int(audio_config.chunk_stride * sample_rate / 1000)
# 3. Setup ASRRunner
asr_runner = ASRRunner()
asr_runner.set_default_config(
audio_config=audio_config,
models=models,
)
# 4. Create Mock WebSocket and start SAR
mock_ws = MockWebSocketClient()
sar_id = asr_runner.new_SAR(
ws=mock_ws,
name="test_sar",
)
assert sar_id is not None, "Failed to create a new SAR instance"
# 5. Simulate streaming audio
print(f"Sending audio data of length {len(audio_data)} samples.")
audio_clip_len = 200
for i in range(0, len(audio_data), audio_clip_len):
chunk = audio_data[i : i + audio_clip_len]
if not isinstance(chunk, np.ndarray) or chunk.size == 0:
break
# Simulate receiving binary data over WebSocket
mock_ws.put_for_recv(chunk)
# 6. Wait for results and assert
time.sleep(10)
# Signal end of audio stream by sending None
mock_ws.put_for_recv(None)