[代码重构中]添加结果聚合节点,修改ASRpipeline,已通过基本测试。
This commit is contained in:
parent
7b9a79942d
commit
6ac206b6b1
@ -13,7 +13,7 @@ Functor基础模块
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List
|
||||
from typing import Callable, List, Dict
|
||||
from queue import Queue
|
||||
import threading
|
||||
|
||||
@ -121,30 +121,6 @@ class FunctorFactory:
|
||||
make_functor(functor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||
创建并配置Functor实例
|
||||
"""
|
||||
|
||||
@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
    """
    Create and configure the functor identified by ``functor_name``.

    Args:
        functor_name: Kind of functor to build; one of ``"vad"``, ``"asr"``
            or ``"spk"``.
        config: Configuration passed through to the builder.
        models: Model collection passed through to the builder.

    Returns:
        The functor instance produced by the matching builder.

    Raises:
        ValueError: If ``functor_name`` is not one of the supported kinds.
    """
    # Resolve the builder first, then invoke it at a single exit point.
    if functor_name == "vad":
        builder = cls._make_vadfunctor
    elif functor_name == "asr":
        builder = cls._make_asrfunctor
    elif functor_name == "spk":
        builder = cls._make_spkfunctor
    else:
        raise ValueError(f"不支持的Functor类型: {functor_name}")

    return builder(config=config, models=models)
|
||||
|
||||
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建VAD Functor实例
|
||||
@ -189,3 +165,38 @@ class FunctorFactory:
|
||||
spk_functor.set_model(model)
|
||||
|
||||
return spk_functor
|
||||
|
||||
def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build a ResultBinder functor.

    Args:
        config: Accepted for signature parity with the other builders; not
            used by this builder.
        models: Accepted for signature parity with the other builders; not
            used by this builder.

    Returns:
        A freshly constructed ``ResultBinderFunctor``.
    """
    # Imported at call time rather than at module import time
    # (presumably to avoid a circular import - TODO confirm).
    from src.functor.resultbinder_functor import ResultBinderFunctor

    return ResultBinderFunctor()
|
||||
|
||||
# Registry of functor builders, keyed by functor name. The values are the
# plain builder functions; make_functor calls them with keyword arguments
# ``config`` and ``models``.
factory_dict: Dict[str, Callable] = {
    "vad": _make_vadfunctor,
    "asr": _make_asrfunctor,
    "spk": _make_spkfunctor,
    "resultbinder": _make_resultbinderfunctor,
}

@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
    """
    Create and configure the functor registered under ``functor_name``.

    Args:
        functor_name: Key into ``factory_dict`` selecting the builder.
        config: Configuration passed through to the builder.
        models: Model collection passed through to the builder.

    Returns:
        The functor instance produced by the selected builder.

    Raises:
        ValueError: If no builder is registered under ``functor_name``.
    """
    # Guard clause: reject unknown names before touching any builder.
    if functor_name not in cls.factory_dict:
        raise ValueError(f"不支持的Functor类型: {functor_name}")

    builder = cls.factory_dict[functor_name]
    return builder(config=config, models=models)
|
||||
|
151
src/functor/resultbinder_functor.py
Normal file
151
src/functor/resultbinder_functor.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""
|
||||
ResultBinderFunctor
|
||||
负责聚合结果,将所有input_queue中的结果进行聚合,并进行callback
|
||||
"""
|
||||
|
||||
import threading
import time
from queue import Empty, Queue
from typing import Any, Callable, Dict, List, Optional

from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
# Logging
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class ResultBinderFunctor(BaseFunctor):
    """
    Aggregate results from several named input queues and forward them.

    Each input queue is registered under a name. Whenever every registered
    queue holds at least one item, one item is taken from each queue, the
    items are collected into a ``{name: item}`` dict, and that dict is passed
    to every registered callback.
    """

    # Seconds to sleep when at least one input queue has nothing to read.
    _POLL_INTERVAL_SECONDS = 0.1
    # Upper bound, in seconds, for a single blocking ``Queue.get``.
    _GET_TIMEOUT_SECONDS = 1

    def __init__(self) -> None:
        super().__init__()
        # Resources and configuration.
        self._callback: List[Callable] = []  # callbacks receiving each aggregate
        self._input_queue: Dict[str, Queue] = {}  # input queues keyed by name
        self._audio_config: Optional[AudioBinary_Config] = None  # audio config

    def reset_cache(self) -> None:
        """
        Reset cached state after a task so the functor can be reused.

        This functor keeps no cache, so there is nothing to do.
        """

    def add_input_queue(
        self,
        name: str,
        queue: Queue,
    ) -> None:
        """
        Register an input queue to listen on.

        Args:
            name: Key under which items from ``queue`` appear in the
                aggregated result. Re-using a name replaces the earlier queue.
            queue: Queue to read items from.
        """
        self._input_queue[name] = queue

    def set_model(self, model: dict) -> None:
        """
        Store an inference model.

        This functor is not expected to have a model; a warning is logged and
        the value is stored anyway.
        """
        logger.warning("ResultBinderFunctor不应设置模型")
        self._model = model

    def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
        """
        Store an audio configuration.

        This functor is not expected to have an audio configuration; a
        warning is logged and the value is stored anyway.
        """
        logger.warning("ResultBinderFunctor不应设置音频配置")
        self._audio_config = audio_config

    def add_callback(self, callback: Callable) -> None:
        """
        Append a callback to the list invoked with every aggregated result.

        Args:
            callback: Callable taking the aggregated ``{name: item}`` dict.

        Raises:
            TypeError: If ``callback`` is not callable. Rejecting it here
                avoids a failure later on the worker thread, when the first
                result is dispatched.
        """
        if not callable(callback):
            raise TypeError(f"回调函数必须是可调用对象: {callback!r}")
        if not isinstance(self._callback, list):
            self._callback = []
        self._callback.append(callback)

    def _do_callback(self, result: Dict[str, Any]) -> None:
        """
        Invoke every registered callback, in registration order, with ``result``.
        """
        for callback in self._callback:
            callback(result)

    def _process(self, data: Dict[str, Any]) -> None:
        """
        Aggregate one item per input queue and dispatch it to the callbacks.

        Args:
            data: Mapping from input-queue name to the item read from it.
        """
        logger.debug("ResultBinderFunctor处理数据: %s", data)
        # Aggregation is currently a plain copy; kept as a hook for future
        # merging logic.
        results = dict(data)
        self._do_callback(results)

    def _run(self) -> None:
        """
        Worker-thread loop.

        Repeatedly waits until every input queue has an item, reads one item
        from each, and processes the batch. After ``stop()`` has been
        requested, the loop ends the next time an input queue is found empty.
        """
        with self._status_lock:
            self._is_running = True
            self._stop_event = False

        try:
            while self._is_running:
                try:
                    queues = list(self._input_queue.items())
                    # Nothing to aggregate while any queue is empty (or no
                    # queue is registered): back off, then re-check for stop.
                    if not queues or any(q.empty() for _, q in queues):
                        time.sleep(self._POLL_INTERVAL_SECONDS)
                        raise Empty

                    data = {}
                    for name, q in queues:
                        data[name] = q.get(True, timeout=self._GET_TIMEOUT_SECONDS)
                        q.task_done()
                    self._process(data)
                except Empty:
                    # An empty queue is the only point where a stop request
                    # is honoured.
                    if self._stop_event:
                        break
                    continue
                except Exception as e:
                    logger.error("ResultBinderFunctor运行时发生错误: %s", e)
                    raise
        finally:
            # Keep the running flag truthful even when the loop ends through
            # an exception.
            with self._status_lock:
                self._is_running = False

    def run(self) -> threading.Thread:
        """
        Validate the setup and start the worker thread.

        Returns:
            threading.Thread: The started daemon thread.

        Raises:
            ValueError: If no input queue or no callback has been registered.
        """
        self._pre_check()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self._thread

    def _pre_check(self) -> bool:
        """
        Check that the functor is ready to run.

        No model is required: this functor only forwards results, and
        ``set_model`` itself warns that a model should not be set.

        Returns:
            bool: Always ``True`` when the checks pass.

        Raises:
            ValueError: If no input queue or no callback has been registered.
        """
        if not self._input_queue:
            raise ValueError("输入队列未设置")
        if not self._callback:
            raise ValueError("回调函数未设置")
        return True

    def stop(self) -> bool:
        """
        Request the worker thread to stop and wait for it to finish.

        Returns:
            bool: ``True`` if the worker is no longer alive, or was never
            started.
        """
        with self._status_lock:
            self._stop_event = True

        worker = getattr(self, "_thread", None)
        if worker is not None:
            # Join outside the lock: the worker takes the same lock on exit.
            worker.join()

        with self._status_lock:
            self._is_running = False
        return worker is None or not worker.is_alive()
|
@ -29,6 +29,7 @@ class ASRPipeline(PipelineBase):
|
||||
self._status_lock = threading.Lock()
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
self._callback: Callable = None
|
||||
|
||||
def set_config(self, config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
@ -66,6 +67,12 @@ class ASRPipeline(PipelineBase):
|
||||
"""
|
||||
self._input_queue = input_queue
|
||||
|
||||
def set_callback(self, callback: Callable) -> None:
    """
    Set the callback that receives the pipeline's results.

    Args:
        callback: Callable to store, or ``None`` to clear the callback
            (``None`` is also the initial value).

    Raises:
        TypeError: If ``callback`` is neither ``None`` nor callable.
    """
    # Fail at configuration time instead of when the first result is
    # dispatched.
    if callback is not None and not callable(callback):
        raise TypeError(f"回调函数必须是可调用对象: {callback!r}")
    self._callback = callback
|
||||
|
||||
def bake(self) -> None:
|
||||
"""
|
||||
烘焙管道
|
||||
@ -90,6 +97,12 @@ class ASRPipeline(PipelineBase):
|
||||
def _init_functor(self) -> None:
|
||||
"""
|
||||
初始化函数
|
||||
自身的functor流程图如下
|
||||
self.input_queue(self.run检测输入到subqueue["original"])->vad
|
||||
->vad2asr ->asrend
|
||||
->vad2spk ->spkend
|
||||
->asrend+spkend->resultbinder
|
||||
->self.callback
|
||||
"""
|
||||
try:
|
||||
from src.functor import FunctorFactory
|
||||
@ -105,6 +118,10 @@ class ASRPipeline(PipelineBase):
|
||||
functor_name="spk", config=self._config, models=self._models
|
||||
)
|
||||
|
||||
self._functor_dict["resultbinder"] = FunctorFactory.make_functor(
|
||||
functor_name="resultbinder", config=self._config, models=self._models
|
||||
)
|
||||
|
||||
# 创建音频数据存储单元
|
||||
self._audio_binary_data_list = AudioBinary_data_list()
|
||||
|
||||
@ -118,38 +135,30 @@ class ASRPipeline(PipelineBase):
|
||||
self._subqueue_dict["vad2spk"] = Queue()
|
||||
self._subqueue_dict["asrend"] = Queue()
|
||||
self._subqueue_dict["spkend"] = Queue()
|
||||
# 输出队列
|
||||
self._subqueue_dict["OUTPUT"] = Queue()
|
||||
|
||||
# 设置子队列的输入队列
|
||||
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
|
||||
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||
# 设置resultbinder的输入队列
|
||||
# 汇总 asr语音识别结果 和 说话人识别结果
|
||||
self._functor_dict["resultbinder"].add_input_queue(
|
||||
"asr", self._subqueue_dict["asrend"]
|
||||
)
|
||||
self._functor_dict["resultbinder"].add_input_queue(
|
||||
"spk", self._subqueue_dict["spkend"]
|
||||
)
|
||||
|
||||
# 设置回调函数——放置到对应队列中
|
||||
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
|
||||
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
|
||||
|
||||
# 构造带回调函数的put
|
||||
def put_with_checkcallback(queue: Queue, callback: Callable) -> Callable[[Any], None]:
    """
    Build a one-argument function that enqueues an item and then notifies.

    Args:
        queue: Queue that receives every item passed to the returned function.
        callback: Callable invoked with the same item right after it has been
            enqueued.

    Returns:
        A function ``put_with_check(data)`` that puts ``data`` on ``queue``
        and then calls ``callback(data)``.
    """

    def put_with_check(data: Any) -> None:
        # Enqueue first, so the item is already on the queue when the
        # callback runs.
        queue.put(data)
        callback(data)

    return put_with_check
|
||||
|
||||
self._functor_dict["asr"].add_callback(
|
||||
put_with_checkcallback(
|
||||
self._subqueue_dict["asrend"], self._check_result
|
||||
)
|
||||
)
|
||||
self._functor_dict["spk"].add_callback(
|
||||
put_with_checkcallback(
|
||||
self._subqueue_dict["spkend"], self._check_result
|
||||
)
|
||||
)
|
||||
# 设置asr与spk的回调函数
|
||||
self._functor_dict["asr"].add_callback(self._subqueue_dict["asrend"].put)
|
||||
self._functor_dict["spk"].add_callback(self._subqueue_dict["spkend"].put)
|
||||
# 设置resultbinder的回调函数 为 自身被设置的回调函数,用于和外界交互
|
||||
self._functor_dict["resultbinder"].add_callback(self._callback)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
|
||||
|
@ -135,7 +135,7 @@ class PipelineFactory:
|
||||
pipeline.set_models(kwargs["models"])
|
||||
pipeline.set_audio_binary(kwargs["audio_binary"])
|
||||
pipeline.set_input_queue(kwargs["input_queue"])
|
||||
pipeline.add_callback(kwargs["callback"])
|
||||
pipeline.set_callback(kwargs["callback"])
|
||||
pipeline.bake()
|
||||
return pipeline
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user