[Runner]完成ASRRunner的编写和测试,使用MockWebSocket完成虚拟网络连接。

This commit is contained in:
Ziyang.Zhang 2025-06-25 16:57:41 +08:00
parent d5b9953905
commit 5dac718dee
10 changed files with 317 additions and 192 deletions

View File

@ -128,7 +128,7 @@ class AudioChunk:
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑 此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑
""" """
_instance: Optional[AudioChunk] = None _instance: Optional["AudioChunk"] = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """

View File

@ -1,11 +1,25 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
默认配置DefaultConfig
- audio_config: 音频配置
配置模块 - 处理命令行参数和配置项 配置模块 - 处理命令行参数和配置项
""" """
import argparse import argparse
from src.models import AudioBinary_Config
class DefaultConfig:
    """
    Project-wide default settings.

    Values are stored as class attributes, so they are read directly from
    the class (``DefaultConfig.audio_config``) without creating an instance.
    """

    # One AudioBinary_Config object built when the class body executes;
    # every reader of DefaultConfig.audio_config gets this same object.
    # NOTE(review): sample_width is 16 here while the runner test passes 2
    # with a "16-bit" comment -- confirm whether the unit is bits or bytes.
    audio_config = AudioBinary_Config(
        sample_rate=16000,
        channels=1,
        sample_width=16,
        chunk_size=200,
        chunk_stride=1600,
    )
def parse_args(): def parse_args():
""" """

View File

@ -136,7 +136,7 @@ class PipelineFactory:
pipeline.set_audio_binary(kwargs["audio_binary"]) pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"]) pipeline.set_input_queue(kwargs["input_queue"])
pipeline.set_callback(kwargs["callback"]) pipeline.set_callback(kwargs["callback"])
pipeline.bake() # pipeline.bake()
return pipeline return pipeline
@classmethod @classmethod

View File

@ -10,9 +10,15 @@ from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config from src.models import AudioBinary_data_list, AudioBinary_Config
from src.core.model_loader import ModelLoader from src.core.model_loader import ModelLoader
from src.config import DefaultConfig
from queue import Queue from queue import Queue
import soundfile import soundfile
import time import time
from typing import List, Optional
import uuid
from threading import Thread
from src.utils.mock_websocket import MockWebSocketClient as WebSocketClient
from .runner import RunnerBase
from src.utils.logger import get_module_logger from src.utils.logger import get_module_logger
@ -23,205 +29,175 @@ OVAERWATCH = False
model_loader = ModelLoader() model_loader = ModelLoader()
class ASRRunner(RunnerBase):
def test_asr_pipeline():
# 加载模型
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
"audio_update": False,
}
models = model_loader.load_models(args)
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
audio_config = AudioBinary_Config(
chunk_size=200,
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
# 创建参数Dict
config = {
"audio_config": audio_config,
}
# 创建音频数据列表
audio_binary_data_list = AudioBinary_data_list()
input_queue = Queue()
# 创建Pipeline
# asr_pipeline = ASRPipeline()
# asr_pipeline.set_models(models)
# asr_pipeline.set_config(config)
# asr_pipeline.set_audio_binary(audio_binary_data_list)
# asr_pipeline.set_input_queue(input_queue)
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
# asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name = "ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)
# 运行Pipeline
asr_instance = asr_pipeline.run()
audio_clip_len = 200
print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i : i + audio_clip_len])
# time.sleep(10)
# input_queue.put(None)
# 等待Pipeline结束
# asr_instance.join()
time.sleep(5)
asr_pipeline.stop()
# asr_pipeline.stop()
class STTRunner(RunnerBase):
""" """
运行器类 运行器类
负责管理资源和协调Pipeline的运行 负责管理资源和协调Pipeline的运行
""" """
class SenderAndReceiver:
def __init__(
self,
*args,
**kwargs,
):
""" """
对于单个pipeline的管理
包含 发送者 接收者
_sender: 发送者 唯一
_receiver: 接收者 可以有多个
_pipeline: 对应管道 唯一
""" """
# ws资源 def __init__(self, *args, **kwargs):
self._ws_pool: Dict[str,List[WebSocketClient]] = {} # 可选传入参数,
# 接收资源 self._name: str = kwargs.get("name", "")
self._audio_binary_list = audio_binary_list self._sender: Optional[WebSocketClient] = kwargs.get("sender", None)
self._models = models self._receiver: List[WebSocketClient] = kwargs.get("receiver", [])
self._pipeline_list = pipeline_list
# 线程控制 # 资源
self._lock = Lock() self._audio_config: AudioBinary_Config = kwargs.get("audio_config", DefaultConfig.audio_config)
# 停止控制 self._models: dict = kwargs.get("models", None)
self._stop_timeout = 10.0 self._audio_binary: AudioBinary_data_list = AudioBinary_data_list()
self._is_stopping = False # id唯一标识
self._id: str = str(uuid.uuid4())
# 输入队列
self._input_queue: Queue = Queue()
self._pipeline: Optional[ASRPipeline] = None
# 配置资源 def set_name(self, name: str):
for pipeline in self._pipeline_list: self._name = name
# 设置输入队列
pipeline.set_input_queue(self._input_queue)
# 配置资源 def set_id(self, id: str):
pipeline.set_audio_binary( self._id = id
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
)
pipeline.set_models(self._models)
def run(self) -> None: def set_sender(self, sender: WebSocketClient):
self._sender = sender
def set_pipeline(self, pipeline: ASRPipeline):
self._pipeline = pipeline
config = {
"audio_config": self._audio_config,
}
self._pipeline.set_config(config)
self._pipeline.set_models(self._models)
self._pipeline.set_audio_binary(self._audio_binary)
self._pipeline.set_input_queue(self._input_queue)
self._pipeline.set_callback(self.deal_message)
self._pipeline.bake()
def append_receiver(self, receiver: WebSocketClient):
self._receiver.append(receiver)
def delete_receiver(self, receiver: WebSocketClient):
self._receiver.remove(receiver)
def deal_message(self, message: str):
self.broadcast(message)
def broadcast(self, message: str):
""" """
启动所有管道 广播发送给所有接收者
""" """
logger.info("[%s] 启动所有管道", self.__class__.__name__) logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: %s", self._name, message)
if not self._pipeline_list: for receiver in self._receiver:
raise RuntimeError("没有可用的管道") receiver.send(message)
# 启动所有管道 def _run(self):
for pipeline in self._pipeline_list:
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
thread.daemon = True
thread.start()
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
def stop(self, force: bool = False) -> bool:
""" """
停止所有管道 运行SAR
参数:
force: 是否强制停止
返回:
bool: 是否成功停止
""" """
if self._is_stopping: self._pipeline.run()
logger.warning("运行器已经在停止中") while True:
return False data = self._sender.recv()
if data is None:
self._is_stopping = True
logger.info("正在停止运行器...")
try:
# 发送结束信号
self._input_queue.put(None)
# 停止所有管道
success = True
for pipeline in self._pipeline_list:
if force:
pipeline.force_stop()
else:
if not pipeline.stop(timeout=self._stop_timeout):
logger.warning("管道 %s 停止超时", id(pipeline))
success = False
# 等待队列处理完成
try:
start_time = time.time()
while not self._input_queue.empty():
if time.time() - start_time > self._stop_timeout:
logger.warning(
"等待队列处理完成超时(%s秒), 队列中还有 %d 个任务未处理",
self._stop_timeout,
self._input_queue.qsize(),
)
success = False
break break
time.sleep(0.1) # 避免过度消耗CPU logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data))
self._input_queue.put(data)
self.stop()
def run(self):
"""
运行SAR
"""
self._thread = Thread(target=self._run, name=f"[ASRRunner]SAR-{self._name}")
self._thread.daemon = True
self._thread.start()
def stop(self):
"""
停止SAR
"""
self._pipeline.stop()
for ws in self._receiver:
ws.close()
self._sender.close()
def __init__(self,*args,**kwargs):
"""
"""
# 接收资源
self._default_audio_config = kwargs.get("audio_config", DefaultConfig.audio_config)
# self._audio_binary_list = args.get("audio_binary_list", None)
self._default_models = kwargs.get("models", None)
self._SAR_list: List[self.SenderAndReceiver] = []
def set_default_config(self, *args, **kwargs):
"""
设置配置
"""
self._default_audio_config = kwargs.get("audio_config", self._default_audio_config)
self._default_models = kwargs.get("models", self._default_models)
def new_SAR(
self,
ws: "WebSocketClient",
name: str = "",
audio_config: "AudioBinary_Config" = None,
models: dict = None
) -> uuid.UUID:
"""
创建新的SAR SenderAndReceiver
"""
if audio_config is None:
audio_config = self._default_audio_config
if models is None:
models = self._default_models
try:
new_SAR = self.SenderAndReceiver(
name=name,
audio_config=audio_config,
models=models
)
new_pipeline = ASRPipeline()
new_SAR.set_pipeline(new_pipeline)
# new_SAR.set_pipeline()
logger.info("创建新的SAR: name %s, id %s", new_SAR._name, new_SAR._id)
new_SAR.set_sender(ws)
new_SAR.append_receiver(ws)
new_SAR.run()
self._SAR_list.append(new_SAR)
return new_SAR._id
except Exception as e: except Exception as e:
error_type = type(e).__name__ logger.error("创建管道失败: %s", e)
error_msg = str(e) return None
error_traceback = traceback.format_exc()
logger.error(
"等待队列处理完成时发生错误:\n"
"错误类型: %s\n"
"错误信息: %s\n"
"错误堆栈:\n%s",
error_type,
error_msg,
error_traceback,
)
success = False
if success: def join_SAR(
logger.info("所有管道已成功停止") self,
else: ws: "WebSocketClient",
logger.warning( name: Optional[str] = None,
"部分管道停止失败, 队列状态: 大小=%d, 是否为空=%s", id: Optional[str] = None,
self._input_queue.qsize(), ) -> bool:
self._input_queue.empty(), """
) 加入SAR的Receiver
"""
return success # 使用next获取迭代器下一个元素生成pipeline_list迭代器按id停止
if id:
finally: exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._id == id), None)
self._is_stopping = False if name:
exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._name == name), None)
if exist_pipeline:
exist_pipeline.append_receiver(ws)
return True
return False
def __del__(self) -> None: def __del__(self) -> None:
""" """
析构函数 析构函数
""" """
self.stop(force=True) for sar in self._SAR_list:
sar.stop()

View File

@ -18,8 +18,8 @@ from queue import Queue
import traceback import traceback
import time import time
from src.audio_chunk import AudioChunk, AudioBinary from src.audio_chunk import AudioChunk
from src.pipeline import Pipeline, PipelineFactory from src.pipeline import PipelineFactory
from src.core.model_loader import ModelLoader from src.core.model_loader import ModelLoader
from src.utils.logger import get_module_logger from src.utils.logger import get_module_logger
@ -48,7 +48,7 @@ class STTRunnerFactory:
audio_binary_name: str, audio_binary_name: str,
model_name_list: List[str], model_name_list: List[str],
pipeline_name_list: List[str], pipeline_name_list: List[str],
) -> STTRunner: ) -> RunnerBase:
""" """
创建运行器 创建运行器
参数: 参数:
@ -67,7 +67,7 @@ class STTRunnerFactory:
PipelineFactory.create_pipeline(pipeline_name) PipelineFactory.create_pipeline(pipeline_name)
for pipeline_name in pipeline_name_list for pipeline_name in pipeline_name_list
] ]
return STTRunner( return RunnerBase(
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
) )
@ -75,7 +75,7 @@ class STTRunnerFactory:
def create_runner_from_config( def create_runner_from_config(
cls, cls,
config: Dict[str, Any], config: Dict[str, Any],
) -> STTRunner: ) -> RunnerBase:
""" """
从配置创建运行器 从配置创建运行器
参数: 参数:
@ -91,7 +91,7 @@ class STTRunnerFactory:
) )
@classmethod @classmethod
def create_runner_normal(cls) -> STTRunner: def create_runner_normal(cls) -> RunnerBase:
""" """
创建默认运行器 创建默认运行器
返回: 返回:

View File

@ -0,0 +1,55 @@
import queue
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class MockWebSocketClient:
    """In-memory stand-in for a WebSocket client, used to test runners offline.

    Outbound traffic (``send``) is recorded in ``sent_messages``. Inbound
    traffic is fed by the test through ``put_for_recv`` and consumed by
    ``recv``. A ``None`` item in the receive queue marks end-of-stream.
    """

    def __init__(self, recv_timeout: float = 10):
        """Create an open mock connection.

        Args:
            recv_timeout: Seconds ``recv`` waits for data before it gives up
                and marks the connection closed.
        """
        self.sent_messages = []
        self._is_closed = False
        self.receive_queue = queue.Queue()
        self._recv_timeout = recv_timeout

    def send(self, message: dict):
        """Record an outbound message; ignored (with a warning) once closed."""
        if self._is_closed:
            print("Warning: sending message on a closed websocket")
            return
        self.sent_messages.append(message)
        print(f"Mock WS received: {message}")

    def recv(self):
        """Return the next queued item, or ``None`` when the stream is over.

        ``None`` is returned when the connection is already closed, when the
        end-of-stream sentinel is dequeued, or when nothing arrives within
        ``recv_timeout`` seconds. The latter two also mark the connection
        closed.
        """
        if self._is_closed:
            return None
        try:
            # Bounded wait so a forgotten sentinel cannot hang the caller.
            data = self.receive_queue.get(timeout=self._recv_timeout)
        except queue.Empty:
            print("Mock WS recv timeout")
            self._is_closed = True
            return None
        if data is None:
            self._is_closed = True
        return data

    def close(self):
        """Close the connection; calling it again is a no-op."""
        if not self._is_closed:
            # Put None to unblock any recv call that is currently waiting.
            self.receive_queue.put(None)
            self._is_closed = True
            print("Mock WS closed")

    def put_for_recv(self, data):
        """Queue ``data`` for a later ``recv`` call.

        Passing ``None`` queues the end-of-stream sentinel, which makes the
        consumer's ``recv`` return ``None`` and mark the connection closed.
        """
        if data is None:
            # The sentinel has no length, so skip the size log and just
            # forward it instead of dropping it.
            self.receive_queue.put(None)
            return
        logger.debug("Mock WS put_for_recv length: %s", len(data))
        self.receive_queue.put(data)

    @property
    def is_closed(self):
        """Whether the connection has been closed."""
        return self._is_closed

View File

@ -5,6 +5,7 @@
from tests.pipeline.asr_test import test_asr_pipeline from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger from src.utils.logger import get_module_logger, setup_root_logger
from tests.runner.stt_runner_test import test_asr_runner
setup_root_logger(level="INFO", log_file="logs/test_main.log") setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
@ -13,5 +14,8 @@ logger = get_module_logger(__name__)
# logger.info("开始测试VAD函数器") # logger.info("开始测试VAD函数器")
# test_vad_functor() # test_vad_functor()
logger.info("开始测试ASR管道") # logger.info("开始测试ASR管道")
test_asr_pipeline() # test_asr_pipeline()
logger.info("开始测试ASRRunner")
test_asr_runner()

View File

@ -71,6 +71,7 @@ def test_asr_pipeline():
callback=lambda x: print(f"pipeline callback: {x}") callback=lambda x: print(f"pipeline callback: {x}")
) )
asr_pipeline.bake()
# 运行Pipeline # 运行Pipeline
asr_instance = asr_pipeline.run() asr_instance = asr_pipeline.run()

View File

@ -0,0 +1,75 @@
"""
ASRRunner test
"""
import queue
import time
import soundfile
import numpy as np
from src.runner.ASRRunner import ASRRunner
from src.core.model_loader import ModelLoader
from src.models import AudioBinary_Config
from src.utils.mock_websocket import MockWebSocketClient
def test_asr_runner():
    """
    End-to-end smoke test for ASRRunner.

    Steps:
    1. Load the ASR / VAD / speaker models.
    2. Build the audio configuration from the sample file.
    3. Create an ASRRunner and set its default config and models.
    4. Start a SenderAndReceiver (SAR) backed by a mock WebSocket.
    5. Stream the audio through the mock WebSocket in small clips.
    6. Wait for the pipeline to process, then end the stream.

    Results produced by the runner are delivered to ``mock_ws.send``; this
    test does not compare them against an expected transcript.
    """
    # 1. Load models
    model_loader = ModelLoader()
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
        "audio_update": False,
    }
    models = model_loader.load_models(args)
    audio_data, sample_rate = soundfile.read("tests/vad_example.wav")

    # 2. Configure audio
    audio_config = AudioBinary_Config(
        chunk_size=200,  # ms
        chunk_stride=1600,  # placeholder, recomputed from chunk_size below
        sample_rate=sample_rate,
        sample_width=2,  # 16-bit
        channels=1,
    )
    # Stride in samples = chunk duration (ms) * samples per ms.
    # NOTE(review): this was previously derived from chunk_stride itself
    # (1600 * sample_rate / 1000); confirm chunk_size is the intended input.
    audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)

    # 3. Setup ASRRunner
    asr_runner = ASRRunner()
    asr_runner.set_default_config(
        audio_config=audio_config,
        models=models,
    )

    # 4. Create Mock WebSocket and start SAR
    mock_ws = MockWebSocketClient()
    sar_id = asr_runner.new_SAR(
        ws=mock_ws,
        name="test_sar",
    )
    assert sar_id is not None, "Failed to create a new SAR instance"

    # 5. Simulate streaming audio
    print(f"Sending audio data of length {len(audio_data)} samples.")
    audio_clip_len = 200
    for i in range(0, len(audio_data), audio_clip_len):
        chunk = audio_data[i : i + audio_clip_len]
        if not isinstance(chunk, np.ndarray) or chunk.size == 0:
            break
        # Simulate receiving binary data over WebSocket
        mock_ws.put_for_recv(chunk)

    # 6. Give the pipeline time to consume the queued audio
    time.sleep(10)

    # End the stream: close() enqueues the end-of-stream sentinel and marks
    # the socket closed, so the SAR receive loop exits and stops its pipeline.
    mock_ws.close()