From 5dac718dee80f2be4e6e63cfc2028443865459f9 Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Wed, 25 Jun 2025 16:57:41 +0800 Subject: [PATCH] =?UTF-8?q?[Runner]=E5=AE=8C=E6=88=90ASRRunner=E7=9A=84?= =?UTF-8?q?=E7=BC=96=E5=86=99=E5=92=8C=E6=B5=8B=E8=AF=95=EF=BC=8C=E4=BD=BF?= =?UTF-8?q?=E7=94=A8MockWebSocket=E5=AE=8C=E6=88=90=E8=99=9A=E6=8B=9F?= =?UTF-8?q?=E7=BD=91=E7=BB=9C=E8=BF=9E=E6=8E=A5=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/audio_chunk.py | 2 +- src/config.py | 14 ++ src/pipeline/base.py | 2 +- src/runner/ASRRunner.py | 340 +++++++++++++++----------------- src/runner/runner.py | 12 +- src/utils/mock_websocket.py | 55 ++++++ test_main.py | 8 +- tests/pipeline/asr_test.py | 1 + tests/runner/asr_runner_test.py | 75 +++++++ tests/runner/stt_runner.py | 0 10 files changed, 317 insertions(+), 192 deletions(-) create mode 100644 src/utils/mock_websocket.py create mode 100644 tests/runner/asr_runner_test.py delete mode 100644 tests/runner/stt_runner.py diff --git a/src/audio_chunk.py b/src/audio_chunk.py index 72b6241..d680950 100644 --- a/src/audio_chunk.py +++ b/src/audio_chunk.py @@ -128,7 +128,7 @@ class AudioChunk: 此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。 """ - _instance: Optional[AudioChunk] = None + _instance: Optional["AudioChunk"] = None def __new__(cls, *args, **kwargs): """ diff --git a/src/config.py b/src/config.py index 7b9a34f..c123ad4 100644 --- a/src/config.py +++ b/src/config.py @@ -1,11 +1,25 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ +默认配置DefaultConfig + - audio_config: 音频配置 配置模块 - 处理命令行参数和配置项 """ import argparse +from src.models import AudioBinary_Config +class DefaultConfig: + """ + 默认配置 + """ + audio_config = AudioBinary_Config( + chunk_size=200, + chunk_stride=1600, + sample_rate=16000, + sample_width=16, + channels=1, + ) def parse_args(): """ diff --git a/src/pipeline/base.py b/src/pipeline/base.py index f594d59..89b4c4a 100644 --- a/src/pipeline/base.py +++ b/src/pipeline/base.py @@ -136,7 +136,7 @@ class PipelineFactory: pipeline.set_audio_binary(kwargs["audio_binary"]) pipeline.set_input_queue(kwargs["input_queue"]) pipeline.set_callback(kwargs["callback"]) - pipeline.bake() + # pipeline.bake() return pipeline @classmethod diff --git a/src/runner/ASRRunner.py b/src/runner/ASRRunner.py index 4f8c786..c9d503a 100644 --- a/src/runner/ASRRunner.py +++ b/src/runner/ASRRunner.py @@ -10,9 +10,15 @@ from src.pipeline.ASRpipeline import ASRPipeline from src.pipeline import PipelineFactory from src.models import AudioBinary_data_list, AudioBinary_Config from src.core.model_loader import ModelLoader +from src.config import DefaultConfig from queue import Queue import soundfile import time +from typing import List, Optional +import uuid +from threading import Thread +from src.utils.mock_websocket import MockWebSocketClient as WebSocketClient +from .runner import RunnerBase from src.utils.logger import get_module_logger @@ -23,205 +29,175 @@ OVAERWATCH = False model_loader = ModelLoader() - -def test_asr_pipeline(): - # 加载模型 - args = { - "asr_model": "paraformer-zh", - "asr_model_revision": "v2.0.4", - "vad_model": "fsmn-vad", - "vad_model_revision": "v2.0.4", - "spk_model": "cam++", - "spk_model_revision": "v2.0.2", - "audio_update": False, - } - models = model_loader.load_models(args) - audio_data, sample_rate = soundfile.read("tests/vad_example.wav") - audio_config = AudioBinary_Config( - chunk_size=200, - chunk_stride=1600, - sample_rate=sample_rate, - sample_width=16, - channels=1, - ) - chunk_stride = int(audio_config.chunk_size * sample_rate / 1000) - audio_config.chunk_stride = chunk_stride - - # 创建参数Dict - config = { - "audio_config": audio_config, - } - - # 创建音频数据列表 - audio_binary_data_list = AudioBinary_data_list() - - input_queue = Queue() - - # 创建Pipeline - # asr_pipeline = ASRPipeline() - # asr_pipeline.set_models(models) - # asr_pipeline.set_config(config) - # asr_pipeline.set_audio_binary(audio_binary_data_list) - # asr_pipeline.set_input_queue(input_queue) - # asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}")) - # asr_pipeline.bake() - asr_pipeline = PipelineFactory.create_pipeline( - pipeline_name = "ASRpipeline", - models=models, - config=config, - audio_binary=audio_binary_data_list, - input_queue=input_queue, - callback=lambda x: print(f"pipeline callback: {x}") - ) - - # 运行Pipeline - asr_instance = asr_pipeline.run() - - - audio_clip_len = 200 - print( - f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}" - ) - for i in range(0, len(audio_data), audio_clip_len): - input_queue.put(audio_data[i : i + audio_clip_len]) - - # time.sleep(10) - # input_queue.put(None) - - # 等待Pipeline结束 - # asr_instance.join() - - time.sleep(5) - asr_pipeline.stop() - # asr_pipeline.stop() - - - -class STTRunner(RunnerBase): +class ASRRunner(RunnerBase): """ 运行器类 负责管理资源和协调Pipeline的运行 """ + class SenderAndReceiver: + """ + 对于单个pipeline的管理 + 包含 发送者 和 接收者 + _sender: 发送者 唯一 + _receiver: 接收者 可以有多个 + _pipeline: 对应管道 唯一 + """ + def __init__(self, *args, **kwargs): + # 可选传入参数, + self._name: str = kwargs.get("name", "") + self._sender: Optional[WebSocketClient] = kwargs.get("sender", None) + self._receiver: List[WebSocketClient] = kwargs.get("receiver", []) - def __init__( - self, - *args, - **kwargs, - ): + # 资源 + self._audio_config: AudioBinary_Config = kwargs.get("audio_config", DefaultConfig.audio_config) + self._models: dict = kwargs.get("models", None) + self._audio_binary: AudioBinary_data_list = AudioBinary_data_list() + # id唯一标识 + self._id: str = str(uuid.uuid4()) + # 输入队列 + self._input_queue: Queue = Queue() + self._pipeline: Optional[ASRPipeline] = None + + def set_name(self, name: str): + self._name = name + + def set_id(self, id: str): + self._id = id + + def set_sender(self, sender: WebSocketClient): + self._sender = sender + + def set_pipeline(self, pipeline: ASRPipeline): + self._pipeline = pipeline + config = { + "audio_config": self._audio_config, + } + self._pipeline.set_config(config) + self._pipeline.set_models(self._models) + self._pipeline.set_audio_binary(self._audio_binary) + self._pipeline.set_input_queue(self._input_queue) + self._pipeline.set_callback(self.deal_message) + self._pipeline.bake() + + def append_receiver(self, receiver: WebSocketClient): + self._receiver.append(receiver) + + def delete_receiver(self, receiver: WebSocketClient): + self._receiver.remove(receiver) + + def deal_message(self, message: str): + self.broadcast(message) + + def broadcast(self, message: str): + """ + 广播发送给所有接收者 + """ + logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: %s", self._name, message) + for receiver in self._receiver: + receiver.send(message) + + def _run(self): + """ + 运行SAR + """ + self._pipeline.run() + while True: + data = self._sender.recv() + if data is None: + break + logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data)) + self._input_queue.put(data) + self.stop() + + def run(self): + """ + 运行SAR + """ + self._thread = Thread(target=self._run, name=f"[ASRRunner]SAR-{self._name}") + self._thread.daemon = True + self._thread.start() + + def stop(self): + """ + 停止SAR + """ + self._pipeline.stop() + for ws in self._receiver: + ws.close() + self._sender.close() + + def __init__(self,*args,**kwargs): """ - """ - # ws资源 - self._ws_pool: Dict[str,List[WebSocketClient]] = {} + """ # 接收资源 - self._audio_binary_list = audio_binary_list - self._models = models - self._pipeline_list = pipeline_list + self._default_audio_config = kwargs.get("audio_config", DefaultConfig.audio_config) + # self._audio_binary_list = args.get("audio_binary_list", None) + self._default_models = kwargs.get("models", None) + self._SAR_list: List[self.SenderAndReceiver] = [] - # 线程控制 - self._lock = Lock() - # 停止控制 - self._stop_timeout = 10.0 - self._is_stopping = False - - # 配置资源 - for pipeline in self._pipeline_list: - # 设置输入队列 - pipeline.set_input_queue(self._input_queue) - - # 配置资源 - pipeline.set_audio_binary( - self._audio_binary_list[pipeline.get_config("audio_binary_name")] - ) - pipeline.set_models(self._models) - - def run(self) -> None: + def set_default_config(self, *args, **kwargs): """ - 启动所有管道 + 设置配置 """ - logger.info("[%s] 启动所有管道", self.__class__.__name__) - if not self._pipeline_list: - raise RuntimeError("没有可用的管道") + self._default_audio_config = kwargs.get("audio_config", self._default_audio_config) + self._default_models = kwargs.get("models", self._default_models) - # 启动所有管道 - for pipeline in self._pipeline_list: - thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}") - thread.daemon = True - thread.start() - logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline)) - - def stop(self, force: bool = False) -> bool: + def new_SAR( + self, + ws: "WebSocketClient", + name: str = "", + audio_config: "AudioBinary_Config" = None, + models: dict = None + ) -> uuid.UUID: """ - 停止所有管道 - 参数: - force: 是否强制停止 - 返回: - bool: 是否成功停止 + 创建新的SAR SenderAndReceiver """ - if self._is_stopping: - logger.warning("运行器已经在停止中") - return False - - self._is_stopping = True - logger.info("正在停止运行器...") + if audio_config is None: + audio_config = self._default_audio_config + if models is None: + models = self._default_models try: - # 发送结束信号 - self._input_queue.put(None) - - # 停止所有管道 - success = True - for pipeline in self._pipeline_list: - if force: - pipeline.force_stop() - else: - if not pipeline.stop(timeout=self._stop_timeout): - logger.warning("管道 %s 停止超时", id(pipeline)) - success = False - - # 等待队列处理完成 - try: - start_time = time.time() - while not self._input_queue.empty(): - if time.time() - start_time > self._stop_timeout: - logger.warning( - "等待队列处理完成超时(%s秒), 队列中还有 %d 个任务未处理", - self._stop_timeout, - self._input_queue.qsize(), - ) - success = False - break - time.sleep(0.1) # 避免过度消耗CPU - except Exception as e: - error_type = type(e).__name__ - error_msg = str(e) - error_traceback = traceback.format_exc() - logger.error( - "等待队列处理完成时发生错误:\n" - "错误类型: %s\n" - "错误信息: %s\n" - "错误堆栈:\n%s", - error_type, - error_msg, - error_traceback, - ) - success = False - - if success: - logger.info("所有管道已成功停止") - else: - logger.warning( - "部分管道停止失败, 队列状态: 大小=%d, 是否为空=%s", - self._input_queue.qsize(), - self._input_queue.empty(), - ) - - return success - - finally: - self._is_stopping = False + new_SAR = self.SenderAndReceiver( + name=name, + audio_config=audio_config, + models=models + ) + new_pipeline = ASRPipeline() + new_SAR.set_pipeline(new_pipeline) + # new_SAR.set_pipeline() + logger.info("创建新的SAR: name %s, id %s", new_SAR._name, new_SAR._id) + new_SAR.set_sender(ws) + new_SAR.append_receiver(ws) + new_SAR.run() + self._SAR_list.append(new_SAR) + return new_SAR._id + except Exception as e: + logger.error("创建管道失败: %s", e) + return None + def join_SAR( + self, + ws: "WebSocketClient", + name: Optional[str] = None, + id: Optional[str] = None, + ) -> bool: + """ + 加入SAR的Receiver + """ + # 使用next获取迭代器下一个元素,生成pipeline_list迭代器,按id停止 + if id: + exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._id == id), None) + if name: + exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._name == name), None) + if exist_pipeline: + exist_pipeline.append_receiver(ws) + return True + return False + def __del__(self) -> None: """ 析构函数 """ - self.stop(force=True) + for sar in self._SAR_list: + sar.stop() diff --git a/src/runner/runner.py b/src/runner/runner.py index 4a5d21b..6966c6f 100644 --- a/src/runner/runner.py +++ b/src/runner/runner.py @@ -18,8 +18,8 @@ from queue import Queue import traceback import time -from src.audio_chunk import AudioChunk, AudioBinary -from src.pipeline import Pipeline, PipelineFactory +from src.audio_chunk import AudioChunk +from src.pipeline import PipelineFactory from src.core.model_loader import ModelLoader from src.utils.logger import get_module_logger @@ -48,7 +48,7 @@ class STTRunnerFactory: audio_binary_name: str, model_name_list: List[str], pipeline_name_list: List[str], - ) -> STTRunner: + ) -> RunnerBase: """ 创建运行器 参数: @@ -67,7 +67,7 @@ class STTRunnerFactory: PipelineFactory.create_pipeline(pipeline_name) for pipeline_name in pipeline_name_list ] - return STTRunner( + return RunnerBase( audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines ) @@ -75,7 +75,7 @@ class STTRunnerFactory: def create_runner_from_config( cls, config: Dict[str, Any], - ) -> STTRunner: + ) -> RunnerBase: """ 从配置创建运行器 参数: @@ -91,7 +91,7 @@ class STTRunnerFactory: ) @classmethod - def create_runner_normal(cls) -> STTRunner: + def create_runner_normal(cls) -> RunnerBase: """ 创建默认运行器 返回: diff --git a/src/utils/mock_websocket.py b/src/utils/mock_websocket.py new file mode 100644 index 0000000..6ae84a1 --- /dev/null +++ b/src/utils/mock_websocket.py @@ -0,0 +1,55 @@ +import queue + +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) + +class MockWebSocketClient: + """A mock WebSocket client to simulate a connection for testing.""" + + def __init__(self): + self.sent_messages = [] + self._is_closed = False + self.receive_queue = queue.Queue() + + def send(self, message: dict): + """Simulates sending a message (which is a dict).""" + if self._is_closed: + print("Warning: sending message on a closed websocket") + return + self.sent_messages.append(message) + print(f"Mock WS received: {message}") + + def recv(self): + """Simulates receiving data from the WebSocket.""" + if self._is_closed: + return None + try: + # Block until data is available, with a timeout to prevent hanging. + data = self.receive_queue.get(timeout=10) + if data is None: + self._is_closed = True + return data + except queue.Empty: + print("Mock WS recv timeout") + self._is_closed = True + return None + + def close(self): + """Simulates closing the WebSocket connection.""" + if not self._is_closed: + # Put None to unblock any waiting recv call + self.receive_queue.put(None) + self._is_closed = True + print("Mock WS closed") + + def put_for_recv(self, data): + """Puts data into the receive queue for the `recv` method to consume.""" + if data is None: + return + logger.debug("Mock WS put_for_recv length: %s", len(data)) + self.receive_queue.put(data) + + @property + def is_closed(self): + return self._is_closed \ No newline at end of file diff --git a/test_main.py b/test_main.py index caaa40e..850efe4 100644 --- a/test_main.py +++ b/test_main.py @@ -5,6 +5,7 @@ from tests.pipeline.asr_test import test_asr_pipeline from src.utils.logger import get_module_logger, setup_root_logger +from tests.runner.stt_runner_test import test_asr_runner setup_root_logger(level="INFO", log_file="logs/test_main.log") logger = get_module_logger(__name__) @@ -13,5 +14,8 @@ logger = get_module_logger(__name__) # logger.info("开始测试VAD函数器") # test_vad_functor() -logger.info("开始测试ASR管道") -test_asr_pipeline() +# logger.info("开始测试ASR管道") +# test_asr_pipeline() + +logger.info("开始测试ASRRunner") +test_asr_runner() diff --git a/tests/pipeline/asr_test.py b/tests/pipeline/asr_test.py index fe0d064..1efd6b8 100644 --- a/tests/pipeline/asr_test.py +++ b/tests/pipeline/asr_test.py @@ -71,6 +71,7 @@ def test_asr_pipeline(): callback=lambda x: print(f"pipeline callback: {x}") ) + asr_pipeline.bake() # 运行Pipeline asr_instance = asr_pipeline.run() diff --git a/tests/runner/asr_runner_test.py b/tests/runner/asr_runner_test.py new file mode 100644 index 0000000..51008cd --- /dev/null +++ b/tests/runner/asr_runner_test.py @@ -0,0 +1,75 @@ +""" +ASRRunner test +""" +import queue +import time +import soundfile +import numpy as np +from src.runner.ASRRunner import ASRRunner +from src.core.model_loader import ModelLoader +from src.models import AudioBinary_Config +from src.utils.mock_websocket import MockWebSocketClient + +def test_asr_runner(): + """ + End-to-end test for ASRRunner. + 1. Loads models. + 2. Configures and initializes ASRRunner. + 3. Creates a mock WebSocket client. + 4. Starts a new SenderAndReceiver (SAR) instance in the runner. + 5. Streams audio data via the mock WebSocket. + 6. Asserts that the received transcription matches the expected text. + """ + # 1. Load models + model_loader = ModelLoader() + args = { + "asr_model": "paraformer-zh", + "asr_model_revision": "v2.0.4", + "vad_model": "fsmn-vad", + "vad_model_revision": "v2.0.4", + "spk_model": "cam++", + "spk_model_revision": "v2.0.2", + "audio_update": False, + } + models = model_loader.load_models(args) + audio_data, sample_rate = soundfile.read("tests/vad_example.wav") + + # 2. Configure audio + audio_config = AudioBinary_Config( + chunk_size=200, # ms + chunk_stride=1600, # 10ms stride for 16kHz + sample_rate=sample_rate, + sample_width=2, # 16-bit + channels=1, + ) + audio_config.chunk_stride = int(audio_config.chunk_stride * sample_rate / 1000) + + # 3. Setup ASRRunner + asr_runner = ASRRunner() + asr_runner.set_default_config( + audio_config=audio_config, + models=models, + ) + + # 4. Create Mock WebSocket and start SAR + mock_ws = MockWebSocketClient() + sar_id = asr_runner.new_SAR( + ws=mock_ws, + name="test_sar", + ) + assert sar_id is not None, "Failed to create a new SAR instance" + + # 5. Simulate streaming audio + print(f"Sending audio data of length {len(audio_data)} samples.") + audio_clip_len = 200 + for i in range(0, len(audio_data), audio_clip_len): + chunk = audio_data[i : i + audio_clip_len] + if not isinstance(chunk, np.ndarray) or chunk.size == 0: + break + # Simulate receiving binary data over WebSocket + mock_ws.put_for_recv(chunk) + + # 6. Wait for results and assert + time.sleep(10) + # Signal end of audio stream by sending None + mock_ws.put_for_recv(None) diff --git a/tests/runner/stt_runner.py b/tests/runner/stt_runner.py deleted file mode 100644 index e69de29..0000000