[代码重构中]添加结果聚合节点,修改ASRpipeline,已通过基本测试。

This commit is contained in:
Ziyang.Zhang 2025-06-25 11:30:35 +08:00
parent 7b9a79942d
commit 6ac206b6b1
4 changed files with 220 additions and 49 deletions

View File

@ -13,7 +13,7 @@ Functor基础模块
"""
from abc import ABC, abstractmethod
from typing import Callable, List
from typing import Callable, List, Dict
from queue import Queue
import threading
@ -121,30 +121,6 @@ class FunctorFactory:
make_functor(functor_name: str, config: dict, models: dict) -> BaseFunctor:
创建并配置Functor实例
"""
@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
    """
    Create and configure a Functor instance.

    Args:
        functor_name (str): Functor name; "vad", "asr" and "spk" are handled.
        config (dict): Configuration forwarded to the matching factory method.
        models (dict): Model information forwarded to the matching factory method.

    Returns:
        BaseFunctor: The Functor instance built by the matching factory method.

    Raises:
        ValueError: If functor_name is not one of the handled names.
    """
    # Guard-clause dispatch: every known name returns immediately, so reaching
    # the end of the function means the name is unsupported.
    if functor_name == "vad":
        return cls._make_vadfunctor(config=config, models=models)
    if functor_name == "asr":
        return cls._make_asrfunctor(config=config, models=models)
    if functor_name == "spk":
        return cls._make_spkfunctor(config=config, models=models)
    raise ValueError(f"不支持的Functor类型: {functor_name}")
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建VAD Functor实例
@ -189,3 +165,38 @@ class FunctorFactory:
spk_functor.set_model(model)
return spk_functor
def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Create a ResultBinder Functor instance.

    Neither ``config`` nor ``models`` is used here; both are accepted because
    make_functor calls every factory with these two keyword arguments.

    Returns:
        BaseFunctor: A newly constructed ResultBinderFunctor.
    """
    # Imported inside the factory rather than at module level.
    from src.functor.resultbinder_functor import ResultBinderFunctor

    return ResultBinderFunctor()
# Dispatch table consulted by make_functor: maps each supported functor name
# to the factory callable that builds it. The values are the plain functions
# defined in this class body, stored unbound and called with config=/models=.
factory_dict: Dict[str, Callable] = {
"vad": _make_vadfunctor,
"asr": _make_asrfunctor,
"spk": _make_spkfunctor,
"resultbinder": _make_resultbinderfunctor,
}
@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
    """
    Create and configure a Functor instance.

    Args:
        functor_name (str): Functor name; must be a key of ``factory_dict``.
        config (dict): Configuration forwarded to the selected factory.
        models (dict): Model information forwarded to the selected factory.

    Returns:
        BaseFunctor: The Functor instance built by the selected factory.

    Raises:
        ValueError: If functor_name is not registered in ``factory_dict``.
    """
    # Single lookup; a missing name is rejected before any factory is called.
    builder = cls.factory_dict.get(functor_name)
    if builder is None:
        raise ValueError(f"不支持的Functor类型: {functor_name}")
    return builder(config=config, models=models)

View File

@ -0,0 +1,151 @@
"""
ResultBinderFunctor
负责聚合结果将所有input_queue中的结果进行聚合并进行callback
"""
import threading
import time
from queue import Queue, Empty
from typing import Callable, Dict, List

from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
# Logging: module-level logger obtained through get_module_logger
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ResultBinderFunctor(BaseFunctor):
    """
    Aggregate the results of several named input queues and pass them to callbacks.

    Every registered input queue is identified by a name. Whenever all queues
    hold at least one item, one item is taken from each queue, the items are
    collected into a ``{name: item}`` dict, and that dict is handed to every
    registered callback.
    """

    def __init__(self) -> None:
        super().__init__()
        # Resources and configuration
        self._callback: List[Callable] = []  # callbacks invoked with each aggregated result
        self._input_queue: Dict[str, Queue] = {}  # input queues keyed by source name
        self._audio_config: AudioBinary_Config = None  # audio config (not expected to be set)

    def reset_cache(self) -> None:
        """
        Reset cached state after a task so the functor is ready for the next one.

        This functor keeps no cache, so there is nothing to clear.
        """
        pass

    def add_input_queue(
        self,
        name: str,
        queue: Queue,
    ) -> None:
        """
        Register an input queue to listen on.

        Args:
            name (str): Key under which items from this queue appear in the
                aggregated result. Re-using a name replaces the earlier queue.
            queue (Queue): Queue to take items from.
        """
        self._input_queue[name] = queue

    def set_model(self, model: dict) -> None:
        """
        Set the inference model.

        ResultBinderFunctor is not supposed to have a model; a warning is
        logged and the value is stored anyway.
        """
        logger.warning("ResultBinderFunctor不应设置模型")
        self._model = model

    def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
        """
        Set the audio configuration.

        ResultBinderFunctor is not supposed to have an audio configuration; a
        warning is logged and the value is stored anyway.
        """
        logger.warning("ResultBinderFunctor不应设置音频配置")
        self._audio_config = audio_config

    def add_callback(self, callback: Callable) -> None:
        """
        Append a callback to this functor's callback list.

        Args:
            callback (Callable): Called with the aggregated ``{name: item}``
                dict each time one item from every input queue is combined.
        """
        # Re-create the list if the attribute was overwritten with a non-list.
        if not isinstance(self._callback, list):
            self._callback = []
        self._callback.append(callback)

    def _do_callback(self, result: dict) -> None:
        """
        Invoke every registered callback with the aggregated result.

        Args:
            result (dict): Aggregated ``{queue name: item}`` mapping.
        """
        for callback in self._callback:
            callback(result)

    def _process(self, data: dict) -> None:
        """
        Aggregate one item per input queue and dispatch it to the callbacks.

        Args:
            data (dict): ``{queue name: item}`` mapping built by ``_run``.
        """
        logger.debug("ResultBinderFunctor处理数据: %s", data)
        # The aggregation step is currently a plain shallow copy; it is kept
        # as a placeholder for future merging logic.
        results = dict(data)
        self._do_callback(results)

    def _run(self) -> None:
        """
        Worker-thread loop.

        Waits until every input queue has an item, then takes one item from
        each queue and processes them together. When some queue is empty the
        loop sleeps 0.1s and exits only if a stop has been requested.

        Raises:
            Exception: Any non-``Empty`` error raised while processing is
                logged and re-raised.
        """
        with self._status_lock:
            self._is_running = True
            self._stop_event = False

        while self._is_running:
            try:
                # Only aggregate when all queues are ready; otherwise back off
                # for 0.1s and fall into the Empty handler below.
                if any(q.empty() for q in self._input_queue.values()):
                    time.sleep(0.1)
                    raise Empty

                data = {}
                for name, q in self._input_queue.items():
                    data[name] = q.get(True, timeout=1)
                    q.task_done()
                self._process(data)
            # A queue was empty: check whether a stop has been requested.
            except Empty:
                if self._stop_event:
                    break
                continue
            # Any other error
            except Exception as e:
                logger.error("ResultBinderFunctor运行时发生错误: %s", e)
                raise

    def run(self) -> threading.Thread:
        """
        Start the worker thread.

        Returns:
            threading.Thread: The started (daemon) thread.

        Raises:
            ValueError: If the pre-check fails (see ``_pre_check``).
        """
        self._pre_check()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self._thread

    def _pre_check(self) -> bool:
        """
        Validate that the functor is ready to run.

        No model check is done here: this functor is not supposed to have a
        model (``set_model`` only warns), so requiring one would contradict
        that.

        Returns:
            bool: True when all checks pass.

        Raises:
            ValueError: If no input queue or no callback has been registered.
        """
        # Both containers are initialised empty rather than None, so test for
        # emptiness; an empty queue dict would also make _run spin forever.
        if not self._input_queue:
            raise ValueError("输入队列未设置")
        if not self._callback:
            raise ValueError("回调函数未设置")
        return True

    def stop(self) -> bool:
        """
        Request the worker thread to stop and wait for it to finish.

        Returns:
            bool: True if the thread is no longer alive (or was never started).
        """
        with self._status_lock:
            self._stop_event = True
        # run() may never have been called; then there is nothing to join.
        thread = getattr(self, "_thread", None)
        if thread is None:
            with self._status_lock:
                self._is_running = False
            return True
        thread.join()
        with self._status_lock:
            self._is_running = False
        return not thread.is_alive()

View File

@ -29,6 +29,7 @@ class ASRPipeline(PipelineBase):
self._status_lock = threading.Lock()
self._is_running: bool = False
self._stop_event: bool = False
self._callback: Callable = None
def set_config(self, config: Dict[str, Any]) -> None:
"""
@ -66,6 +67,12 @@ class ASRPipeline(PipelineBase):
"""
self._input_queue = input_queue
def set_callback(self, callback: Callable) -> None:
    """
    Store the callback for this pipeline.

    Args:
        callback (Callable): Callable kept in ``self._callback``; any
            previously stored value is replaced.
    """
    self._callback = callback
def bake(self) -> None:
"""
烘焙管道
@ -90,6 +97,12 @@ class ASRPipeline(PipelineBase):
def _init_functor(self) -> None:
"""
初始化函数
自身的functor流程图如下
self.input_queue(self.run检测输入到subqueue["original"])->vad
->vad2asr ->asrend
->vad2spk ->spkend
->asrend+spkend->resultbinder
->self.callback
"""
try:
from src.functor import FunctorFactory
@ -105,6 +118,10 @@ class ASRPipeline(PipelineBase):
functor_name="spk", config=self._config, models=self._models
)
self._functor_dict["resultbinder"] = FunctorFactory.make_functor(
functor_name="resultbinder", config=self._config, models=self._models
)
# 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list()
@ -118,38 +135,30 @@ class ASRPipeline(PipelineBase):
self._subqueue_dict["vad2spk"] = Queue()
self._subqueue_dict["asrend"] = Queue()
self._subqueue_dict["spkend"] = Queue()
# 输出队列
self._subqueue_dict["OUTPUT"] = Queue()
# 设置子队列的输入队列
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
# 设置resultbinder的输入队列
# 汇总 asr语音识别结果 和 说话人识别结果
self._functor_dict["resultbinder"].add_input_queue(
"asr", self._subqueue_dict["asrend"]
)
self._functor_dict["resultbinder"].add_input_queue(
"spk", self._subqueue_dict["spkend"]
)
# 设置回调函数——放置到对应队列中
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
# 构造带回调函数的put
def put_with_checkcallback(queue: Queue, callback: Callable) -> None:
"""
带回调函数的put
"""
def put_with_check(data: Any) -> None:
queue.put(data)
callback(data)
return put_with_check
self._functor_dict["asr"].add_callback(
put_with_checkcallback(
self._subqueue_dict["asrend"], self._check_result
)
)
self._functor_dict["spk"].add_callback(
put_with_checkcallback(
self._subqueue_dict["spkend"], self._check_result
)
)
# 设置asr与spk的回调函数
self._functor_dict["asr"].add_callback(self._subqueue_dict["asrend"].put)
self._functor_dict["spk"].add_callback(self._subqueue_dict["spkend"].put)
# 设置resultbinder的回调函数 为 自身被设置的回调函数,用于和外界交互
self._functor_dict["resultbinder"].add_callback(self._callback)
except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")

View File

@ -135,7 +135,7 @@ class PipelineFactory:
pipeline.set_models(kwargs["models"])
pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"])
pipeline.add_callback(kwargs["callback"])
pipeline.set_callback(kwargs["callback"])
pipeline.bake()
return pipeline