[代码重构中]添加结果聚合节点,修改ASRpipeline,已通过基本测试。

This commit is contained in:
Ziyang.Zhang 2025-06-25 11:30:35 +08:00
parent 7b9a79942d
commit 6ac206b6b1
4 changed files with 220 additions and 49 deletions

View File

@ -13,7 +13,7 @@ Functor基础模块
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List from typing import Callable, List, Dict
from queue import Queue from queue import Queue
import threading import threading
@ -121,30 +121,6 @@ class FunctorFactory:
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor: make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
创建并配置Functor实例 创建并配置Functor实例
""" """
@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
"""
创建并配置Functor实例
参数:
funtor_name (str): Functor名称
config (dict): 配置信息
models (dict): 模型信息
返回:
BaseFunctor: 创建的Functor实例
"""
if functor_name == "vad":
return cls._make_vadfunctor(config=config, models=models)
elif functor_name == "asr":
return cls._make_asrfunctor(config=config, models=models)
elif functor_name == "spk":
return cls._make_spkfunctor(config=config, models=models)
else:
raise ValueError(f"不支持的Functor类型: {functor_name}")
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor: def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
""" """
创建VAD Functor实例 创建VAD Functor实例
@ -189,3 +165,38 @@ class FunctorFactory:
spk_functor.set_model(model) spk_functor.set_model(model)
return spk_functor return spk_functor
def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建ResultBinder Functor实例
"""
from src.functor.resultbinder_functor import ResultBinderFunctor
resultbinder_functor = ResultBinderFunctor()
return resultbinder_functor
factory_dict: Dict[str, Callable] = {
"vad": _make_vadfunctor,
"asr": _make_asrfunctor,
"spk": _make_spkfunctor,
"resultbinder": _make_resultbinderfunctor,
}
@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
"""
创建并配置Functor实例
参数:
funtor_name (str): Functor名称
config (dict): 配置信息
models (dict): 模型信息
返回:
BaseFunctor: 创建的Functor实例
"""
if functor_name in cls.factory_dict:
return cls.factory_dict[functor_name](config=config, models=models)
else:
raise ValueError(f"不支持的Functor类型: {functor_name}")

View File

@ -0,0 +1,151 @@
"""
ResultBinderFunctor
负责聚合结果将所有input_queue中的结果进行聚合并进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
from queue import Queue, Empty
import threading
import time
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ResultBinderFunctor(BaseFunctor):
"""
ResultBinderFunctor
负责聚合结果将所有input_queue中的结果进行聚合并进行callback
"""
def __init__(self) -> None:
super().__init__()
# 资源与配置
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Dict[str, Queue] = {} # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
pass
def add_input_queue(self,
name: str,
queue: Queue,
) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue[name] = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
resultbinder_functor 不应设置模型
"""
logger.warning("ResultBinderFunctor不应设置模型")
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
resultbinder_functor 不应设置音频配置
"""
logger.warning("ResultBinderFunctor不应设置音频配置")
self._audio_config = audio_config
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None:
"""
回调函数
"""
for callback in self._callback:
callback(result)
def _process(self, data: VAD_Functor_result) -> None:
"""
处理数据
"""
logger.debug("ResultBinderFunctor处理数据: %s", data)
# 将data中的result进行聚合
# 此步暂时无意义,预留
results = {}
for name, result in data.items():
results[name] = result
self._do_callback(results)
def _run(self) -> None:
"""
线程运行逻辑
"""
with self._status_lock:
self._is_running = True
self._stop_event = False
# 运行逻辑
while self._is_running:
try:
# 若有队列为空则等待0.1s
for name, queue in self._input_queue.items():
if queue.empty():
time.sleep(0.1)
raise Empty
data = {}
for name, queue in self._input_queue.items():
data[name] = queue.get(True, timeout=1)
queue.task_done()
self._process(data)
# 当队列为空时, 检测是否进入停止事件。
except Empty:
if self._stop_event:
break
continue
# 其他异常
except Exception as e:
logger.error("SpkFunctor运行时发生错误: %s", e)
raise e
def run(self) -> threading.Thread:
"""
启动线程
Returns:
threading.Thread: 返回已运行线程实例
"""
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
预检查
"""
if self._model is None:
raise ValueError("模型未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self) -> bool:
"""
停止线程
"""
with self._status_lock:
self._stop_event = True
self._thread.join()
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

View File

@ -29,6 +29,7 @@ class ASRPipeline(PipelineBase):
self._status_lock = threading.Lock() self._status_lock = threading.Lock()
self._is_running: bool = False self._is_running: bool = False
self._stop_event: bool = False self._stop_event: bool = False
self._callback: Callable = None
def set_config(self, config: Dict[str, Any]) -> None: def set_config(self, config: Dict[str, Any]) -> None:
""" """
@ -66,6 +67,12 @@ class ASRPipeline(PipelineBase):
""" """
self._input_queue = input_queue self._input_queue = input_queue
def set_callback(self, callback: Callable) -> None:
"""
设置回调函数
"""
self._callback = callback
def bake(self) -> None: def bake(self) -> None:
""" """
烘焙管道 烘焙管道
@ -90,6 +97,12 @@ class ASRPipeline(PipelineBase):
def _init_functor(self) -> None: def _init_functor(self) -> None:
""" """
初始化函数 初始化函数
自身的functor流程图如下
self.input_queue(self.run检测输入到subqueue["original"])->vad
->vad2asr ->asrend
->vad2spk ->spkend
->asrend+spkend->resultbinder
->self.callback
""" """
try: try:
from src.functor import FunctorFactory from src.functor import FunctorFactory
@ -105,6 +118,10 @@ class ASRPipeline(PipelineBase):
functor_name="spk", config=self._config, models=self._models functor_name="spk", config=self._config, models=self._models
) )
self._functor_dict["resultbinder"] = FunctorFactory.make_functor(
functor_name="resultbinder", config=self._config, models=self._models
)
# 创建音频数据存储单元 # 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list() self._audio_binary_data_list = AudioBinary_data_list()
@ -118,38 +135,30 @@ class ASRPipeline(PipelineBase):
self._subqueue_dict["vad2spk"] = Queue() self._subqueue_dict["vad2spk"] = Queue()
self._subqueue_dict["asrend"] = Queue() self._subqueue_dict["asrend"] = Queue()
self._subqueue_dict["spkend"] = Queue() self._subqueue_dict["spkend"] = Queue()
# 输出队列
self._subqueue_dict["OUTPUT"] = Queue()
# 设置子队列的输入队列 # 设置子队列的输入队列
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"]) self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
# 设置resultbinder的输入队列
# 汇总 asr语音识别结果 和 说话人识别结果
self._functor_dict["resultbinder"].add_input_queue(
"asr", self._subqueue_dict["asrend"]
)
self._functor_dict["resultbinder"].add_input_queue(
"spk", self._subqueue_dict["spkend"]
)
# 设置回调函数——放置到对应队列中 # 设置回调函数——放置到对应队列中
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put) self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put) self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
# 设置asr与spk的回调函数
# 构造带回调函数的put self._functor_dict["asr"].add_callback(self._subqueue_dict["asrend"].put)
def put_with_checkcallback(queue: Queue, callback: Callable) -> None: self._functor_dict["spk"].add_callback(self._subqueue_dict["spkend"].put)
""" # 设置resultbinder的回调函数 为 自身被设置的回调函数,用于和外界交互
带回调函数的put self._functor_dict["resultbinder"].add_callback(self._callback)
"""
def put_with_check(data: Any) -> None:
queue.put(data)
callback(data)
return put_with_check
self._functor_dict["asr"].add_callback(
put_with_checkcallback(
self._subqueue_dict["asrend"], self._check_result
)
)
self._functor_dict["spk"].add_callback(
put_with_checkcallback(
self._subqueue_dict["spkend"], self._check_result
)
)
except ImportError: except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")

View File

@ -135,7 +135,7 @@ class PipelineFactory:
pipeline.set_models(kwargs["models"]) pipeline.set_models(kwargs["models"])
pipeline.set_audio_binary(kwargs["audio_binary"]) pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"]) pipeline.set_input_queue(kwargs["input_queue"])
pipeline.add_callback(kwargs["callback"]) pipeline.set_callback(kwargs["callback"])
pipeline.bake() pipeline.bake()
return pipeline return pipeline