From 5b94c4001615b7bca09b2c97e5be3b16ffca4da1 Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Fri, 6 Jun 2025 17:26:08 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BB=A3=E7=A0=81=E9=87=8D=E6=9E=84=E4=B8=AD]?= =?UTF-8?q?=E7=BC=96=E5=86=99=E8=9E=8D=E5=90=88VAD,ASR,SPK(FAKE)=E7=9A=84A?= =?UTF-8?q?SRPipeline=E5=B9=B6=E5=AE=8C=E6=88=90=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=EF=BC=8C=E6=AD=A3=E5=B8=B8=E8=BF=90=E8=A1=8C=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/functor/base.py | 64 +++++++++- src/functor/readme.md | 127 ++++++++++++++++---- src/functor/spk_functor.py | 4 +- src/model_loader.py | 10 +- src/pipeline/ASRpipeline.py | 226 ++++++++++++++++++++++-------------- src/pipeline/__init__.py | 4 +- src/pipeline/base.py | 6 +- test_main.py | 8 +- tests/functor/vad_test.py | 4 +- tests/pipeline/asr_test.py | 80 +++++++++++++ 10 files changed, 405 insertions(+), 128 deletions(-) create mode 100644 tests/pipeline/asr_test.py diff --git a/src/functor/base.py b/src/functor/base.py index b0f8eb4..a6bd904 100644 --- a/src/functor/base.py +++ b/src/functor/base.py @@ -110,6 +110,8 @@ class BaseFunctor(ABC): 停止结果 """ + + class FunctorFactory: """ Functor工厂类 @@ -121,8 +123,8 @@ class FunctorFactory: 创建并配置Functor实例 """ - @staticmethod - def make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor: + @classmethod + def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor: """ 创建并配置Functor实例 @@ -134,4 +136,60 @@ class FunctorFactory: 返回: BaseFunctor: 创建的Functor实例 """ - \ No newline at end of file + + if functor_name == "vad": + return cls._make_vadfunctor(config = config,models = models) + elif functor_name == "asr": + return cls._make_asrfunctor(config = config,models = models) + elif functor_name == "spk": + return cls._make_spkfunctor(config = config,models = models) + else: + raise ValueError(f"不支持的Functor类型: {functor_name}") + + def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor: + """ + 创建VAD Functor实例 + """ + from src.functor.vad_functor import VADFunctor + audio_config = config["audio_config"] + model = { + "vad": models["vad"] + } + + vad_functor = VADFunctor() + vad_functor.set_audio_config(audio_config) + vad_functor.set_model(model) + + return vad_functor + + def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor: + """ + 创建ASR Functor实例 + """ + from src.functor.asr_functor import ASRFunctor + audio_config = config["audio_config"] + model = { + "asr": models["asr"] + } + + asr_functor = ASRFunctor() + asr_functor.set_audio_config(audio_config) + asr_functor.set_model(model) + + return asr_functor + + def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor: + """ + 创建SPK Functor实例 + """ + from src.functor.spk_functor import SPKFunctor + audio_config = config["audio_config"] + model = { + "spk": models["spk"] + } + + spk_functor = SPKFunctor() + spk_functor.set_audio_config(audio_config) + spk_functor.set_model(model) + + return spk_functor \ No newline at end of file diff --git a/src/functor/readme.md b/src/functor/readme.md index a7dee70..7f918f1 100644 --- a/src/functor/readme.md +++ b/src/functor/readme.md @@ -6,38 +6,117 @@ Functor文件夹用于存放所有功能性的类,包括VAD、PUNC、ASR、SPK ## Functor 类的定义 -所有类应继承于 **基类** `BaseFunctor` ,应遵从 *压入数据* 与 *数据处理* 解绑。 +所有类应继承于**基类**`BaseFunctor`。 -为了方便使用,我们对于 **基类** 的定义如下: +为了方便使用,我们对于**基类**的定义如下: -1. 函数内部使用的变量以单下划线开头,预定有 `_data`, `_callback`, `_model`等 +1. 函数内部使用的变量以单下划线开头,基类中包含: + + * _model: Dict 存放模型相关的配置和实例 + * _input_queue: Queue 监听的输入消息队列 + * _thread: Threading.Thread 运行的线程实例 + * _callback: List[Callable] 回调函数列表 + * _is_running: bool 线程运行状态标志 + * _stop_event: bool 停止事件标志 + * _status_lock: threading.Lock 状态锁,用于线程同步 2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`。 -3. 定义了 - - `__call__`:可传入`data`,默认调用`push_data`,随后默认调用`process` - - `__add__` +3. 基类定义的核心方法: + + * `add_callback(callback: Callable)`: 添加结果处理的回调函数 + * `set_model(model: dict)`: 设置模型配置和实例 + * `set_input_queue(queue: Queue)`: 设置输入数据队列 + * `run()`: 启动处理线程(抽象方法) + * `stop()`: 停止处理线程(抽象方法) + * `_run()`: 线程运行的具体逻辑(抽象方法) + * `_pre_check()`: 运行前的预检查(抽象方法) + +## 派生类实现要求 + +1. 必须实现的抽象方法: + * `_pre_check()`: + - 检查必要的配置是否完整(如模型、队列等) + - 检查运行环境是否满足要求 + - 返回检查结果 + + * `_run()`: + - 实现具体的数据处理逻辑 + - 从 _input_queue 获取输入数据 + - 使用 _model 进行处理 + - 通过 _callback 返回处理结果 + + * `run()`: + - 调用 _pre_check() 进行预检查 + - 创建并启动处理线程 + - 设置相关状态标志 + + * `stop()`: + - 安全停止处理线程 + - 清理资源 + - 重置状态标志 + +2. 建议实现的方法: + * `__str__`: 返回当前实例的状态信息 + * 错误处理方法:处理运行过程中的异常情况 + +## 使用示例 ```python -class BaseFunctor: - def __init__(self): - self._data: dict = {} - self._callback: function = null - pass +class MyFunctor(BaseFunctor): + def _pre_check(self): + if not self._model or not self._input_queue: + return False + return True - def __call__(self, data): - result = - self._callback(process(data)) - return self.process(data) + def _run(self): + while not self._stop_event: + try: + data = self._input_queue.get(timeout=1.0) + result = self._model['my_model'].process(data) + for callback in self._callback: + callback(result) + except Queue.Empty: + continue + except Exception as e: + logger.error(f"处理错误: {e}") - def set_callback(self, callback: Callable): - self._callback = callback + def run(self): + if not self._pre_check(): + raise RuntimeError("预检查失败") + + with self._status_lock: + if self._is_running: + return + self._is_running = True + self._stop_event = False + self._thread = threading.Thread(target=self._run) + self._thread.start() - def push_data(): - pass + def stop(self): + with self._status_lock: + if not self._is_running: + return + self._stop_event = True + if self._thread: + self._thread.join() + self._is_running = False +``` - def process(self, data): - pass -``` \ No newline at end of file +## 注意事项 + +1. 线程安全: + * 使用 _status_lock 保护状态变更 + * 注意共享资源的访问控制 + +2. 错误处理: + * 在 _run() 中妥善处理异常 + * 提供详细的错误日志 + +3. 资源管理: + * 确保在 stop() 中正确清理资源 + * 避免资源泄露 + +4. 回调函数: + * 回调函数应该是非阻塞的 + * 处理回调函数抛出的异常 \ No newline at end of file diff --git a/src/functor/spk_functor.py b/src/functor/spk_functor.py index 8d0fa7a..c72bb98 100644 --- a/src/functor/spk_functor.py +++ b/src/functor/spk_functor.py @@ -13,9 +13,9 @@ from src.utils.logger import get_module_logger logger = get_module_logger(__name__) -class SpkFunctor(BaseFunctor): +class SPKFunctor(BaseFunctor): """ - SpkFunctor + SPKFunctor 负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback 需要配置好 _model, _callback, _input_queue, _audio_config 否则无法run()启动线程 diff --git a/src/model_loader.py b/src/model_loader.py index f4a116c..50717b2 100644 --- a/src/model_loader.py +++ b/src/model_loader.py @@ -38,7 +38,7 @@ class ModelLoader: 初始化ModelLoader实例 """ self.models = {} - logger.info("初始化ModelLoader") + logger.debug("初始化ModelLoader") if args is not None: self.__call__(args) @@ -87,14 +87,14 @@ class ModelLoader: else: value = input_model_args.get(key, None) if value is not None: - logger.info("替换%s模型参数: %s = %s", model_type, key, value) + logger.debug("替换%s模型参数: %s = %s", model_type, key, value) model_args[key] = value # 验证必要参数 if not model_args["model"]: raise ValueError(f"未指定{model_type}模型路径") try: # 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销 - logger.info("正在加载%s模型: %s", model_type, model_args["model"]) + logger.debug("正在加载%s模型: %s", model_type, model_args["model"]) model = AutoModel(**model_args) return model except Exception as e: @@ -115,12 +115,12 @@ class ModelLoader: self.models = {} # 加载离线ASR模型 # 检查对应键是否存在 - model_list = ['asr', 'asr_online', 'vad', 'punc'] + model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk'] for model_name in model_list: name_model = f"{model_name}_model" name_model_revision = f"{model_name}_model_revision" if name_model in args: - logger.info("加载%s模型", model_name) + logger.debug("加载%s模型", model_name) self.models[model_name] = self._load_model(args, model_name) logger.info("所有模型加载完成") return self.models diff --git a/src/pipeline/ASRpipeline.py b/src/pipeline/ASRpipeline.py index 85f1818..a50e2b7 100644 --- a/src/pipeline/ASRpipeline.py +++ b/src/pipeline/ASRpipeline.py @@ -1,7 +1,9 @@ from src.pipeline.base import PipelineBase -from typing import Dict, Any -from queue import Queue +from typing import Dict, Any, Callable +from queue import Queue, Empty from src.utils import get_module_logger +from src.models import AudioBinary_data_list +import threading logger = get_module_logger(__name__) @@ -19,6 +21,12 @@ class ASRPipeline(PipelineBase): self._functor_dict: Dict[str, Any] = {} self._subqueue_dict: Dict[str, Any] = {} self._is_baked: bool = False + self._input_queue: Queue = None + self._audio_binary_data_list: AudioBinary_data_list = None + + self._status_lock = threading.Lock() + self._is_running: bool = False + self._stop_event: bool = False def set_config(self, config: Dict[str, Any]) -> None: """ @@ -28,7 +36,15 @@ class ASRPipeline(PipelineBase): """ self._config = config - def set_audio_binary(self, audio_binary: AudioBinary) -> None: + def get_config(self) -> Dict[str, Any]: + """ + 获取配置 + 返回: + Dict[str, Any] 配置 + """ + return self._config + + def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None: """ 设置音频二进制存储单元 参数: @@ -42,44 +58,70 @@ class ASRPipeline(PipelineBase): """ self._models = models + def set_input_queue(self, input_queue: Queue) -> None: + """ + 设置输入队列 + """ + self._input_queue = input_queue + def bake(self) -> None: """ 烘焙管道 """ + self._pre_check_resource() self._init_functor() self._is_baked = True + def _pre_check_resource(self) -> None: + """ + 预检查资源 + """ + if self._input_queue is None: + raise RuntimeError("[ASRpipeline]输入队列未设置") + if self._functor_dict is None: + raise RuntimeError("[ASRpipeline]functor字典未设置") + if self._subqueue_dict is None: + raise RuntimeError("[ASRpipeline]子队列字典未设置") + if self._audio_binary is None: + raise RuntimeError("[ASRpipeline]音频存储单元未设置") + def _init_functor(self) -> None: """ 初始化函数 """ try: - from src.functor import functorFactory + from src.functor import FunctorFactory # 加载VAD、asr、spk functor - self._functor_dict["vad"] = functorFactory.make_functor( + self._functor_dict["vad"] = FunctorFactory.make_functor( functor_name = "vad", config = self._config, models = self._models ) - self._functor_dict["asr"] = functorFactory.make_functor( + self._functor_dict["asr"] = FunctorFactory.make_functor( functor_name = "asr", config = self._config, models = self._models ) - self._functor_dict["spk"] = functorFactory.make_functor( + self._functor_dict["spk"] = FunctorFactory.make_functor( functor_name = "spk", config = self._config, models = self._models ) + # 创建音频数据存储单元 + self._audio_binary_data_list = AudioBinary_data_list() + + self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list) + # 初始化子队列 + self._subqueue_dict["original"] = Queue() self._subqueue_dict["vad2asr"] = Queue() self._subqueue_dict["vad2spk"] = Queue() self._subqueue_dict["asrend"] = Queue() self._subqueue_dict["spkend"] = Queue() # 设置子队列的输入队列 - self._functor_dict["vad"].set_input_queue(self._input_queue) + self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"]) self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) @@ -94,7 +136,7 @@ class ASRPipeline(PipelineBase): """ def put_with_check(data: Any) -> None: queue.put(data) - callback() + callback(data) return put_with_check self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) @@ -102,76 +144,7 @@ class ASRPipeline(PipelineBase): except ImportError: raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") - - def get_config(self) -> Dict[str, Any]: - """ - 获取配置 - 返回: - Dict[str, Any] 配置 - """ - return self._config - - def process(self, data: Any) -> Any: - """ - 处理数据 - 参数: - data: 输入数据 - 返回: - 处理结果 - """ - # 子类实现具体的处理逻辑 - self._input_queue.put(data) - - def run(self) -> None: - """ - 运行管道 - """ - if not self._is_baked: - raise RuntimeError("管道未烘焙,无法运行") - - # 运行所有functor - for functor_name, functor in self._functor_dict.items(): - logger.info(f"运行{functor_name}functor") - self._functor_dict[functor_name].run() - - # 运行管道 - if not self._input_queue: - raise RuntimeError("输入队列未设置") - - # 设置管道运行状态 - self._is_running = True - self._stop_event = False - self._thread = threading.current_thread() - logger.info("ASR管道开始运行") - - while self._is_running and not self._stop_event: - try: - # 从队列获取数据 - try: - data = self._input_queue.get(timeout=self._queue_timeout) - # 检查是否是结束信号 - if data is None: - logger.info("收到结束信号,管道准备停止") - self._stop() - self._input_queue.task_done() # 标记结束信号已处理 - break - - # 处理数据 - self.process(data) - - # 标记任务完成 - self._input_queue.task_done() - - except Empty: - # 队列获取超时,继续等待 - continue - - except Exception as e: - logger.error(f"管道处理数据出错: {str(e)}") - continue - - logger.info("管道停止运行") - + def _check_result(self, result: Any) -> None: """ 检查结果 @@ -188,13 +161,96 @@ class ASRPipeline(PipelineBase): # 通知回调函数 self._notify_callbacks(result) + def run(self) -> threading.Thread: + """ + 运行管道 + Returns: + threading.Thread: 返回已运行线程实例 + """ + # 检查运行资源是否准备完毕 + self._pre_check() + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + logger.info("[ASRpipeline]管道开始运行") + return self._thread + + def _pre_check(self) -> None: + """ + 预检查 + """ + if self._is_baked is False: + raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行") + + for functor_name, functor in self._functor_dict.items(): + if functor is None: + raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常") + + for subqueue_name, subqueue in self._subqueue_dict.items(): + if subqueue is None: + raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常") + + def _run(self) -> None: + """ + 真实的运行逻辑 + """ + # 运行所有functor + for functor_name, functor in self._functor_dict.items(): + logger.info(f"[ASRpipeline]运行{functor_name}functor") + self._functor_dict[functor_name].run() + + # 设置管道运行状态 + with self._status_lock: + self._is_running = True + self._stop_event = False + + while self._is_running and not self._stop_event: + try: + data = self._input_queue.get(timeout=self._queue_timeout) + # 检查是否是结束信号 + if data is None: + logger.info("收到结束信号,管道准备停止") + self._input_queue.task_done() # 标记结束信号已处理 + break + + # 处理数据 + self._process(data) + + # 标记任务完成 + self._input_queue.task_done() + + except Empty: + # 队列获取超时,继续等待 + continue + except Exception as e: + logger.error(f"[ASRpipeline]管道处理数据出错: {str(e)}") + break + + logger.info("[ASRpipeline]管道停止运行") + + def _process(self, data: Any) -> Any: + """ + 处理数据 + 参数: + data: 输入数据 + 返回: + 处理结果 + """ + # 子类实现具体的处理逻辑 + self._subqueue_dict["original"].put(data) + def stop(self) -> None: """ 停止管道 """ - self._is_running = False - self._stop_event = True - for functor_name, functor in self._functor_dict.items(): - logger.info(f"停止{functor_name}functor") - functor.stop() - logger.info("子Functor停止") + with self._status_lock: + self._is_running = False + self._stop_event = True + for functor_name, functor in self._functor_dict.items(): + # logger.info(f"停止{functor_name}functor") + if functor.stop(): + logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止") + else: + logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败") + self._thread.join() + logger.info("[ASRpipeline]管道停止") + return True diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py index dd78c84..d1a171d 100644 --- a/src/pipeline/__init__.py +++ b/src/pipeline/__init__.py @@ -1,3 +1,3 @@ -from src.pipeline.base import PipelineBase, Pipeline, PipelineFactory +from src.pipeline.base import PipelineBase, PipelineFactory -__all__ = ["PipelineBase", "Pipeline", "PipelineFactory"] +__all__ = ["PipelineBase", "PipelineFactory"] diff --git a/src/pipeline/base.py b/src/pipeline/base.py index e348f1b..6d7968d 100644 --- a/src/pipeline/base.py +++ b/src/pipeline/base.py @@ -56,7 +56,7 @@ class PipelineBase(ABC): logger.error(f"回调函数执行出错: {str(e)}") @abstractmethod - def process(self, data: Any) -> Any: + def _process(self, data: Any) -> Any: """ 处理数据 参数: @@ -67,7 +67,7 @@ class PipelineBase(ABC): pass @abstractmethod - def run(self) -> None: + def _run(self) -> None: """ 运行管道 从输入队列获取数据并处理 @@ -121,7 +121,7 @@ class PipelineFactory: 用于创建管道实例 """ @staticmethod - def create_pipeline(pipeline_name: str) -> Pipeline: + def create_pipeline(pipeline_name: str) -> Any: """ 创建管道实例 """ diff --git a/test_main.py b/test_main.py index d5c834b..4fea328 100644 --- a/test_main.py +++ b/test_main.py @@ -1,8 +1,12 @@ from tests.functor.vad_test import test_vad_functor +from tests.pipeline.asr_test import test_asr_pipeline from src.utils.logger import get_module_logger, setup_root_logger setup_root_logger(level="INFO", log_file="logs/test_main.log") logger = get_module_logger(__name__) -logger.info("开始测试VAD函数器") -test_vad_functor() \ No newline at end of file +# logger.info("开始测试VAD函数器") +# test_vad_functor() + +logger.info("开始测试ASR管道") +test_asr_pipeline() diff --git a/tests/functor/vad_test.py b/tests/functor/vad_test.py index ccd2b02..ab3abc7 100644 --- a/tests/functor/vad_test.py +++ b/tests/functor/vad_test.py @@ -4,7 +4,7 @@ VAD测试 """ from src.functor.vad_functor import VADFunctor from src.functor.asr_functor import ASRFunctor -from src.functor.spk_functor import SpkFunctor +from src.functor.spk_functor import SPKFunctor from queue import Queue, Empty from src.model_loader import ModelLoader from src.models import AudioBinary_Config, AudioBinary_data_list @@ -84,7 +84,7 @@ def test_vad_functor(): asr_functor.run() # 创建SPK函数器 - spk_functor = SpkFunctor() + spk_functor = SPKFunctor() # 设置输入队列 spk_functor.set_input_queue(vad2spk_queue) # 设置音频配置 diff --git a/tests/pipeline/asr_test.py b/tests/pipeline/asr_test.py new file mode 100644 index 0000000..388009d --- /dev/null +++ b/tests/pipeline/asr_test.py @@ -0,0 +1,80 @@ +""" +Pipeline测试 +VAD+ASR+SPK(FAKE) +""" +from src.pipeline.ASRpipeline import ASRPipeline +from src.models import AudioBinary_data_list, AudioBinary_Config +from src.model_loader import ModelLoader +from queue import Queue +import soundfile +import time + +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) + + +OVAERWATCH = False + +model_loader = ModelLoader() + +def test_asr_pipeline(): + # 加载模型 + args = { + "asr_model": "paraformer-zh", + "asr_model_revision": "v2.0.4", + "vad_model": "fsmn-vad", + "vad_model_revision": "v2.0.4", + "spk_model": "cam++", + "spk_model_revision": "v2.0.2", + "audio_update": False, + } + models = model_loader.load_models(args) + audio_data, sample_rate = soundfile.read("tests/vad_example.wav") + audio_config = AudioBinary_Config( + chunk_size=200, + chunk_stride=1600, + sample_rate=sample_rate, + sample_width=16, + channels=1 + ) + chunk_stride = int(audio_config.chunk_size*sample_rate/1000) + audio_config.chunk_stride = chunk_stride + + # 创建参数Dict + config = { + "audio_config": audio_config, + } + + # 创建音频数据列表 + audio_binary_data_list = AudioBinary_data_list() + + input_queue = Queue() + + # 创建Pipeline + asr_pipeline = ASRPipeline() + asr_pipeline.set_models(models) + asr_pipeline.set_config(config) + asr_pipeline.set_audio_binary(audio_binary_data_list) + asr_pipeline.set_input_queue(input_queue) + asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}")) + asr_pipeline.bake() + + # 运行Pipeline + asr_instance = asr_pipeline.run() + + audio_clip_len = 200 + print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}") + for i in range(0, len(audio_data), audio_clip_len): + input_queue.put(audio_data[i:i+audio_clip_len]) + + # time.sleep(10) + # input_queue.put(None) + + # 等待Pipeline结束 + # asr_instance.join() + + time.sleep(5) + asr_pipeline.stop() + # asr_pipeline.stop() +