From 7b9a79942daf2619b7e211919c2f206aa9a7a7ff Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Tue, 24 Jun 2025 09:22:48 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BB=A3=E7=A0=81=E9=87=8D=E6=9E=84=E4=B8=AD]?= =?UTF-8?q?=20=E5=B0=86ModelLoader=E7=A7=BB=E8=87=B3core=E7=9B=AE=E5=BD=95?= =?UTF-8?q?=EF=BC=8C=E6=9B=B4=E6=96=B0=E7=9B=B8=E5=85=B3=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E7=9A=84=E5=AF=BC=E5=85=A5=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E3=80=82=E5=88=9B=E5=BB=BAASRRunner,=E5=88=9D=E6=AD=A5?= =?UTF-8?q?=E6=90=AD=E5=BB=BA=E6=A1=86=E6=9E=B6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/{ => core}/model_loader.py | 0 src/runner.py | 281 --------------------------------- src/runner/ASRRunner.py | 227 ++++++++++++++++++++++++++ src/runner/runner.py | 105 ++++++++++++ tests/functor/vad_test.py | 2 +- tests/modelsuse.py | 2 +- tests/pipeline/asr_test.py | 2 +- tests/runner/stt_runner.py | 0 8 files changed, 335 insertions(+), 284 deletions(-) rename src/{ => core}/model_loader.py (100%) delete mode 100644 src/runner.py create mode 100644 src/runner/ASRRunner.py create mode 100644 src/runner/runner.py create mode 100644 tests/runner/stt_runner.py diff --git a/src/model_loader.py b/src/core/model_loader.py similarity index 100% rename from src/model_loader.py rename to src/core/model_loader.py diff --git a/src/runner.py b/src/runner.py deleted file mode 100644 index 9d908cc..0000000 --- a/src/runner.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -运行器模块 -提供运行器基类和运行器类,用于管理音频数据和模型的交互。 -主要包含: -- RunnerBase: 运行器基类,定义了基本接口 -- Runner: 运行器类,工厂模式实现 -- RunnerFactory: 运行器工厂类,用于创建运行器 -""" - -from abc import ABC, abstractmethod -from typing import Dict, Any, List -from threading import Thread, Lock -from queue import Queue -import traceback -import time - -from src.audio_chunk import AudioChunk, AudioBinary -from src.pipeline import Pipeline, PipelineFactory -from src.model_loader import ModelLoader -from src.utils.logger import get_module_logger - -logger = get_module_logger(__name__, level="INFO") - -audio_chunk = AudioChunk() -models_loaded = ModelLoader() - - -class RunnerBase(ABC): - """ - 运行器基类 - 定义了运行器的基本接口 - """ - - @abstractmethod - def adder(self, data: Any) -> None: - """ - 添加数据 - 参数: - data: 要添加的数据 - """ - pass - - @abstractmethod - def add_recevier(self, receiver: callable) -> None: - """ - 添加数据接收者 - 参数: - receiver: 接收数据的回调函数 - """ - pass - - -class STTRunner(RunnerBase): - """ - 运行器类 - 负责管理资源和协调Pipeline的运行 - """ - - def __init__( - self, - *, - audio_binary_list: List[AudioBinary], - models: Dict[str, Any], - pipeline_list: List[Pipeline], - ): - """ - 初始化运行器 - 参数: - audio_binary_list: 音频二进制列表 - models: 模型字典 - pipeline_list: 管道列表 - queue_size: 队列大小 - stop_timeout: 停止超时时间(秒) - """ - # 接收资源 - self._audio_binary_list = audio_binary_list - self._models = models - self._pipeline_list = pipeline_list - - # 线程控制 - self._lock = Lock() - - # 消息队列 - self._input_queue = Queue(maxsize=1000) - - # 停止控制 - self._stop_timeout = 10.0 - self._is_stopping = False - - # 配置资源 - for pipeline in self._pipeline_list: - # 设置输入队列 - pipeline.set_input_queue(self._input_queue) - - # 配置资源 - pipeline.set_audio_binary( - self._audio_binary_list[pipeline.get_config("audio_binary_name")] - ) - pipeline.set_models(self._models) - - def adder(self, data: Any) -> None: - """ - 添加数据到输入队列 - 参数: - data: 要添加的数据 - """ - if not self._pipeline_list: - raise RuntimeError("没有可用的管道") - if self._is_stopping: - raise RuntimeError("运行器正在停止,无法添加数据") - self._input_queue.put(data) - - def add_recevier(self, receiver: callable) -> None: - """ - 添加数据接收者 - 参数: - receiver: 接收数据的回调函数 - """ - with self._lock: - for pipeline in self._pipeline_list: - pipeline.add_callback(receiver) - - def run(self) -> None: - """ - 启动所有管道 - """ - logger.info("[%s] 启动所有管道", self.__class__.__name__) - if not self._pipeline_list: - raise RuntimeError("没有可用的管道") - - # 启动所有管道 - for pipeline in self._pipeline_list: - thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}") - thread.daemon = True - thread.start() - logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline)) - - def stop(self, force: bool = False) -> bool: - """ - 停止所有管道 - 参数: - force: 是否强制停止 - 返回: - bool: 是否成功停止 - """ - if self._is_stopping: - logger.warning("运行器已经在停止中") - return False - - self._is_stopping = True - logger.info("正在停止运行器...") - - try: - # 发送结束信号 - self._input_queue.put(None) - - # 停止所有管道 - success = True - for pipeline in self._pipeline_list: - if force: - pipeline.force_stop() - else: - if not pipeline.stop(timeout=self._stop_timeout): - logger.warning("管道 %s 停止超时", id(pipeline)) - success = False - - # 等待队列处理完成 - try: - start_time = time.time() - while not self._input_queue.empty(): - if time.time() - start_time > self._stop_timeout: - logger.warning( - "等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理", - self._stop_timeout, - self._input_queue.qsize(), - ) - success = False - break - time.sleep(0.1) # 避免过度消耗CPU - except Exception as e: - error_type = type(e).__name__ - error_msg = str(e) - error_traceback = traceback.format_exc() - logger.error( - "等待队列处理完成时发生错误:\n" - "错误类型: %s\n" - "错误信息: %s\n" - "错误堆栈:\n%s", - error_type, - error_msg, - error_traceback, - ) - success = False - - if success: - logger.info("所有管道已成功停止") - else: - logger.warning( - "部分管道停止失败,队列状态: 大小=%d, 是否为空=%s", - self._input_queue.qsize(), - self._input_queue.empty(), - ) - - return success - - finally: - self._is_stopping = False - - def __del__(self) -> None: - """ - 析构函数 - """ - self.stop(force=True) - - -class STTRunnerFactory: - """ - STT Runner工厂类 - 用于创建运行器实例 - """ - - @staticmethod - def _create_runner( - audio_binary_name: str, - model_name_list: List[str], - pipeline_name_list: List[str], - ) -> STTRunner: - """ - 创建运行器 - 参数: - audio_binary_name: 音频二进制名称 - model_name_list: 模型名称列表 - pipeline_name_list: 管道名称列表 - 返回: - Runner实例 - """ - audio_binary = audio_chunk.get_audio_binary(audio_binary_name) - models: Dict[str, Any] = { - model_name: models_loaded.models[model_name] - for model_name in model_name_list - } - pipelines: List[Pipeline] = [ - PipelineFactory.create_pipeline(pipeline_name) - for pipeline_name in pipeline_name_list - ] - return STTRunner( - audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines - ) - - @classmethod - def create_runner_from_config( - cls, - config: Dict[str, Any], - ) -> STTRunner: - """ - 从配置创建运行器 - 参数: - config: 配置字典 - 返回: - Runner实例 - """ - audio_binary_name = config["audio_binary_name"] - model_name_list = config["model_name_list"] - pipeline_name_list = config["pipeline_name_list"] - return cls._create_runner( - audio_binary_name, model_name_list, pipeline_name_list - ) - - @classmethod - def create_runner_normal(cls) -> STTRunner: - """ - 创建默认运行器 - 返回: - Runner实例 - """ - audio_binary_name = None - model_name_list = list(models_loaded.models.keys()) - pipeline_name_list = None - return cls._create_runner( - audio_binary_name, model_name_list, pipeline_name_list - ) diff --git a/src/runner/ASRRunner.py b/src/runner/ASRRunner.py new file mode 100644 index 0000000..4f8c786 --- /dev/null +++ b/src/runner/ASRRunner.py @@ -0,0 +1,227 @@ +""" +-*- encoding: utf-8 -*- + +ASRRunner +继承RunnerBase +专属pipeline为ASRPipeline +""" + +from src.pipeline.ASRpipeline import ASRPipeline +from src.pipeline import PipelineFactory +from src.models import AudioBinary_data_list, AudioBinary_Config +from src.core.model_loader import ModelLoader +from queue import Queue +import soundfile +import time + +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) + + +OVAERWATCH = False + +model_loader = ModelLoader() + + +def test_asr_pipeline(): + # 加载模型 + args = { + "asr_model": "paraformer-zh", + "asr_model_revision": "v2.0.4", + "vad_model": "fsmn-vad", + "vad_model_revision": "v2.0.4", + "spk_model": "cam++", + "spk_model_revision": "v2.0.2", + "audio_update": False, + } + models = model_loader.load_models(args) + audio_data, sample_rate = soundfile.read("tests/vad_example.wav") + audio_config = AudioBinary_Config( + chunk_size=200, + chunk_stride=1600, + sample_rate=sample_rate, + sample_width=16, + channels=1, + ) + chunk_stride = int(audio_config.chunk_size * sample_rate / 1000) + audio_config.chunk_stride = chunk_stride + + # 创建参数Dict + config = { + "audio_config": audio_config, + } + + # 创建音频数据列表 + audio_binary_data_list = AudioBinary_data_list() + + input_queue = Queue() + + # 创建Pipeline + # asr_pipeline = ASRPipeline() + # asr_pipeline.set_models(models) + # asr_pipeline.set_config(config) + # asr_pipeline.set_audio_binary(audio_binary_data_list) + # asr_pipeline.set_input_queue(input_queue) + # asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}")) + # asr_pipeline.bake() + asr_pipeline = PipelineFactory.create_pipeline( + pipeline_name = "ASRpipeline", + models=models, + config=config, + audio_binary=audio_binary_data_list, + input_queue=input_queue, + callback=lambda x: print(f"pipeline callback: {x}") + ) + + # 运行Pipeline + asr_instance = asr_pipeline.run() + + + audio_clip_len = 200 + print( + f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}" + ) + for i in range(0, len(audio_data), audio_clip_len): + input_queue.put(audio_data[i : i + audio_clip_len]) + + # time.sleep(10) + # input_queue.put(None) + + # 等待Pipeline结束 + # asr_instance.join() + + time.sleep(5) + asr_pipeline.stop() + # asr_pipeline.stop() + + + +class STTRunner(RunnerBase): + """ + 运行器类 + 负责管理资源和协调Pipeline的运行 + """ + + def __init__( + self, + *args, + **kwargs, + ): + """ + """ + # ws资源 + self._ws_pool: Dict[str,List[WebSocketClient]] = {} + # 接收资源 + self._audio_binary_list = audio_binary_list + self._models = models + self._pipeline_list = pipeline_list + + # 线程控制 + self._lock = Lock() + # 停止控制 + self._stop_timeout = 10.0 + self._is_stopping = False + + # 配置资源 + for pipeline in self._pipeline_list: + # 设置输入队列 + pipeline.set_input_queue(self._input_queue) + + # 配置资源 + pipeline.set_audio_binary( + self._audio_binary_list[pipeline.get_config("audio_binary_name")] + ) + pipeline.set_models(self._models) + + def run(self) -> None: + """ + 启动所有管道 + """ + logger.info("[%s] 启动所有管道", self.__class__.__name__) + if not self._pipeline_list: + raise RuntimeError("没有可用的管道") + + # 启动所有管道 + for pipeline in self._pipeline_list: + thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}") + thread.daemon = True + thread.start() + logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline)) + + def stop(self, force: bool = False) -> bool: + """ + 停止所有管道 + 参数: + force: 是否强制停止 + 返回: + bool: 是否成功停止 + """ + if self._is_stopping: + logger.warning("运行器已经在停止中") + return False + + self._is_stopping = True + logger.info("正在停止运行器...") + + try: + # 发送结束信号 + self._input_queue.put(None) + + # 停止所有管道 + success = True + for pipeline in self._pipeline_list: + if force: + pipeline.force_stop() + else: + if not pipeline.stop(timeout=self._stop_timeout): + logger.warning("管道 %s 停止超时", id(pipeline)) + success = False + + # 等待队列处理完成 + try: + start_time = time.time() + while not self._input_queue.empty(): + if time.time() - start_time > self._stop_timeout: + logger.warning( + "等待队列处理完成超时(%s秒), 队列中还有 %d 个任务未处理", + self._stop_timeout, + self._input_queue.qsize(), + ) + success = False + break + time.sleep(0.1) # 避免过度消耗CPU + except Exception as e: + error_type = type(e).__name__ + error_msg = str(e) + error_traceback = traceback.format_exc() + logger.error( + "等待队列处理完成时发生错误:\n" + "错误类型: %s\n" + "错误信息: %s\n" + "错误堆栈:\n%s", + error_type, + error_msg, + error_traceback, + ) + success = False + + if success: + logger.info("所有管道已成功停止") + else: + logger.warning( + "部分管道停止失败, 队列状态: 大小=%d, 是否为空=%s", + self._input_queue.qsize(), + self._input_queue.empty(), + ) + + return success + + finally: + self._is_stopping = False + + def __del__(self) -> None: + """ + 析构函数 + """ + self.stop(force=True) diff --git a/src/runner/runner.py b/src/runner/runner.py new file mode 100644 index 0000000..4a5d21b --- /dev/null +++ b/src/runner/runner.py @@ -0,0 +1,105 @@ +""" +-*- encoding: utf-8 -*- + +Runner类 +所有的Runner都对应一个fastapi的endpoint, +Runner需要处理: +1.新的websocket 进来后放到 unknow_websocket_pool中 +2.收到特定消息后, 将消息转发给特定的pipeline处理 +3.管理pipeline与websocket对应关系, 管理pipeline的ID +4.管理pipeline的启动和停止 +5.管理所有pipeline用到的资源, 管理pipeline的存活时间。 +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, List +from threading import Thread, Lock +from queue import Queue +import traceback +import time + +from src.audio_chunk import AudioChunk, AudioBinary +from src.pipeline import Pipeline, PipelineFactory +from src.core.model_loader import ModelLoader +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) + +audio_chunk = AudioChunk() +models_loaded = ModelLoader() + + +class RunnerBase(ABC): + """ + 运行器基类 + 定义了运行器的基本接口 + """ + def __init__(self, *args, **kwargs): + pass + +class STTRunnerFactory: + """ + STT Runner工厂类 + 用于创建运行器实例 + """ + + @staticmethod + def _create_runner( + audio_binary_name: str, + model_name_list: List[str], + pipeline_name_list: List[str], + ) -> STTRunner: + """ + 创建运行器 + 参数: + audio_binary_name: 音频二进制名称 + model_name_list: 模型名称列表 + pipeline_name_list: 管道名称列表 + 返回: + Runner实例 + """ + audio_binary = audio_chunk.get_audio_binary(audio_binary_name) + models: Dict[str, Any] = { + model_name: models_loaded.models[model_name] + for model_name in model_name_list + } + pipelines: List[Pipeline] = [ + PipelineFactory.create_pipeline(pipeline_name) + for pipeline_name in pipeline_name_list + ] + return STTRunner( + audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines + ) + + @classmethod + def create_runner_from_config( + cls, + config: Dict[str, Any], + ) -> STTRunner: + """ + 从配置创建运行器 + 参数: + config: 配置字典 + 返回: + Runner实例 + """ + audio_binary_name = config["audio_binary_name"] + model_name_list = config["model_name_list"] + pipeline_name_list = config["pipeline_name_list"] + return cls._create_runner( + audio_binary_name, model_name_list, pipeline_name_list + ) + + @classmethod + def create_runner_normal(cls) -> STTRunner: + """ + 创建默认运行器 + 返回: + Runner实例 + """ + audio_binary_name = None + model_name_list = list(models_loaded.models.keys()) + pipeline_name_list = None + return cls._create_runner( + audio_binary_name, model_name_list, pipeline_name_list + ) diff --git a/tests/functor/vad_test.py b/tests/functor/vad_test.py index 97016dd..a8cfd13 100644 --- a/tests/functor/vad_test.py +++ b/tests/functor/vad_test.py @@ -7,7 +7,7 @@ from src.functor.vad_functor import VADFunctor from src.functor.asr_functor import ASRFunctor from src.functor.spk_functor import SPKFunctor from queue import Queue, Empty -from src.model_loader import ModelLoader +from src.core.model_loader import ModelLoader from src.models import AudioBinary_Config, AudioBinary_data_list from src.utils.data_format import wav_to_bytes import time diff --git a/tests/modelsuse.py b/tests/modelsuse.py index 8123d3f..f66a746 100644 --- a/tests/modelsuse.py +++ b/tests/modelsuse.py @@ -61,7 +61,7 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]: # from src.functor.model_loader import load_models # models = load_models(args) - from src.model_loader import ModelLoader + from src.core.model_loader import ModelLoader models = ModelLoader(args) diff --git a/tests/pipeline/asr_test.py b/tests/pipeline/asr_test.py index 631383c..fe0d064 100644 --- a/tests/pipeline/asr_test.py +++ b/tests/pipeline/asr_test.py @@ -6,7 +6,7 @@ VAD+ASR+SPK(FAKE) from src.pipeline.ASRpipeline import ASRPipeline from src.pipeline import PipelineFactory from src.models import AudioBinary_data_list, AudioBinary_Config -from src.model_loader import ModelLoader +from src.core.model_loader import ModelLoader from queue import Queue import soundfile import time diff --git a/tests/runner/stt_runner.py b/tests/runner/stt_runner.py new file mode 100644 index 0000000..e69de29