def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build a ResultBinder functor instance.

    ``config`` and ``models`` are accepted only so that this helper has the
    same keyword signature as the other factory helpers; neither is used.

    Returns:
        BaseFunctor: a freshly constructed ResultBinderFunctor.
    """
    # Imported lazily: resultbinder_functor itself imports BaseFunctor from
    # this module, so a top-level import would be circular.
    from src.functor.resultbinder_functor import ResultBinderFunctor

    return ResultBinderFunctor()
class ResultBinderFunctor(BaseFunctor):
    """
    Aggregate the results of several input queues and pass them to callbacks.

    Every registered input queue is identified by a name. The worker thread
    waits until each queue holds at least one item, takes one item from every
    queue and hands the resulting ``{name: item}`` mapping to all callbacks.
    """

    def __init__(self) -> None:
        super().__init__()
        # Resources and configuration
        self._callback: List[Callable] = []  # callbacks receiving each aggregated result
        # Builtin ``dict`` is used in the annotation because ``typing.Dict`` is
        # not imported by this module (attribute annotations inside a function
        # are not evaluated at runtime).
        self._input_queue: dict[str, Queue] = {}  # input queues keyed by source name
        self._audio_config: AudioBinary_Config = None  # audio configuration (unused here)

    def reset_cache(self) -> None:
        """
        Reset cached state after a task finishes.

        This functor keeps no cache, so there is nothing to do.
        """
        pass

    def add_input_queue(self,
                        name: str,
                        queue: Queue,
                        ) -> None:
        """
        Register an input queue to listen on.

        Args:
            name: key under which items from this queue appear in the
                aggregated result; an existing queue with the same name is
                replaced.
            queue: queue to read items from.
        """
        self._input_queue[name] = queue

    def set_model(self, model: dict) -> None:
        """
        Store an inference model.

        A result binder is not expected to have a model, so a warning is
        logged; the value is stored anyway.
        """
        logger.warning("ResultBinderFunctor不应设置模型")
        self._model = model

    def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
        """
        Store an audio configuration.

        A result binder is not expected to have an audio configuration, so a
        warning is logged; the value is stored anyway.
        """
        logger.warning("ResultBinderFunctor不应设置音频配置")
        self._audio_config = audio_config

    def add_callback(self, callback: Callable) -> None:
        """
        Append a callback to the list of callbacks.

        Args:
            callback: callable invoked with each aggregated result mapping.
        """
        # Defensive: recover if the attribute was overwritten with a non-list.
        if not isinstance(self._callback, list):
            self._callback = []
        self._callback.append(callback)

    def _do_callback(self, result: dict) -> None:
        """
        Invoke every registered callback with ``result``, in registration order.
        """
        for callback in self._callback:
            callback(result)

    def _process(self, data: dict) -> None:
        """
        Aggregate one item per input queue and dispatch it to the callbacks.

        Args:
            data: mapping of input-queue name to the item taken from it.
        """
        logger.debug("ResultBinderFunctor处理数据: %s", data)
        # The aggregation is currently a plain shallow copy; kept as a hook
        # for future merging logic.
        results = dict(data)
        self._do_callback(results)

    def _run(self) -> None:
        """
        Worker-thread loop.

        Repeatedly takes one item from every input queue and processes them
        together. When some queue is empty the loop sleeps briefly and exits
        if a stop was requested.
        """
        # _status_lock / _is_running / _stop_event are not created in this
        # class; presumably they come from BaseFunctor -- confirm.
        with self._status_lock:
            self._is_running = True
            self._stop_event = False
        # NOTE(review): a stop() issued before this reset has run would be
        # overwritten here -- confirm whether callers can hit that window.
        while self._is_running:
            try:
                # Only consume when every queue has data, so that each
                # aggregated result holds exactly one item per source.
                for queue in self._input_queue.values():
                    if queue.empty():
                        time.sleep(0.1)
                        raise Empty
                data = {}
                for name, queue in self._input_queue.items():
                    data[name] = queue.get(True, timeout=1)
                    queue.task_done()
                self._process(data)
            except Empty:
                # A queue is empty: leave the loop if a stop was requested.
                if self._stop_event:
                    break
                continue
            except Exception as e:
                logger.error("ResultBinderFunctor运行时发生错误: %s", e)
                raise

    def run(self) -> threading.Thread:
        """
        Validate the configuration and start the worker thread.

        Returns:
            threading.Thread: the started daemon thread.

        Raises:
            ValueError: if no input queue or no callback has been registered.
        """
        self._pre_check()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self._thread

    def _pre_check(self) -> bool:
        """
        Check that the functor is ready to run.

        No model is required: this functor only forwards results. At least one
        input queue and one callback must be registered, otherwise the worker
        loop would spin delivering empty results or silently drop them.

        Returns:
            bool: True when all checks pass.

        Raises:
            ValueError: if no input queue or no callback is registered.
        """
        if not self._input_queue:
            raise ValueError("输入队列未设置")
        if not self._callback:
            raise ValueError("回调函数未设置")
        return True

    def stop(self) -> bool:
        """
        Request the worker thread to stop and wait for it to finish.

        The worker only notices the request when one of its input queues is
        empty.

        Returns:
            bool: True if the thread is no longer alive.
        """
        with self._status_lock:
            self._stop_event = True
        self._thread.join()
        with self._status_lock:
            self._is_running = False
        return not self._thread.is_alive()
def set_callback(self, callback: Callable) -> None:
    """
    Store the callback used to hand this pipeline's results to the outside.

    ``_init_functor`` registers ``self._callback`` on the "resultbinder"
    functor, so this must be called before the functors are initialized;
    otherwise the default ``None`` would be registered as a callback.

    Args:
        callback: callable that receives the aggregated result.
    """
    self._callback = callback
self._subqueue_dict["asrend"], self._check_result - ) - ) - self._functor_dict["spk"].add_callback( - put_with_checkcallback( - self._subqueue_dict["spkend"], self._check_result - ) - ) + # 设置asr与spk的回调函数 + self._functor_dict["asr"].add_callback(self._subqueue_dict["asrend"].put) + self._functor_dict["spk"].add_callback(self._subqueue_dict["spkend"].put) + # 设置resultbinder的回调函数 为 自身被设置的回调函数,用于和外界交互 + self._functor_dict["resultbinder"].add_callback(self._callback) except ImportError: raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") diff --git a/src/pipeline/base.py b/src/pipeline/base.py index 7d5ef13..f594d59 100644 --- a/src/pipeline/base.py +++ b/src/pipeline/base.py @@ -135,7 +135,7 @@ class PipelineFactory: pipeline.set_models(kwargs["models"]) pipeline.set_audio_binary(kwargs["audio_binary"]) pipeline.set_input_queue(kwargs["input_queue"]) - pipeline.add_callback(kwargs["callback"]) + pipeline.set_callback(kwargs["callback"]) pipeline.bake() return pipeline