208 lines
5.1 KiB
Python
208 lines
5.1 KiB
Python
"""
|
|
Functor基础模块
|
|
|
|
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。
|
|
基类提供了数据处理的基本框架,包括:
|
|
- 回调函数管理
|
|
- 模型配置管理
|
|
- 线程运行控制
|
|
|
|
主要类:
|
|
BaseFunctor: Functor抽象类
|
|
FunctorFactory: Functor工厂类
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Callable, List, Dict
|
|
from queue import Queue
|
|
import threading
|
|
|
|
from src.utils.logger import get_module_logger
|
|
|
|
logger = get_module_logger(__name__)
|
|
|
|
class BaseFunctor(ABC):
|
|
"""
|
|
Functor抽象类
|
|
|
|
该抽象类规定了所有的Functor类必须实现run()方法启动自身线程
|
|
|
|
属性:
|
|
_callback (Callable): 处理完成后的回调函数
|
|
_model (dict): 存储模型相关的配置和实例
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""
|
|
初始化函数器
|
|
|
|
参数:
|
|
callback (Callable): 处理完成后的回调函数
|
|
model (dict): 模型相关的配置和实例
|
|
"""
|
|
self._callback: List[Callable] = []
|
|
self._model: dict = {}
|
|
# flag
|
|
self._is_running: bool = False
|
|
self._stop_event: bool = False
|
|
# 状态锁
|
|
self._status_lock: threading.Lock = threading.Lock()
|
|
# 线程资源
|
|
self._thread: threading.Thread = None
|
|
|
|
def add_callback(self, callback: Callable):
|
|
"""
|
|
添加回调函数
|
|
|
|
参数:
|
|
callback (Callable): 新的回调函数
|
|
"""
|
|
self._callback.append(callback)
|
|
|
|
def set_model(self, model: dict):
|
|
"""
|
|
设置模型配置
|
|
|
|
参数:
|
|
model (dict): 新的模型配置
|
|
"""
|
|
self._model = model
|
|
|
|
def set_input_queue(self, queue: Queue):
|
|
"""
|
|
设置输入队列
|
|
|
|
参数:
|
|
queue (Queue): 新的输入队列
|
|
"""
|
|
self._input_queue = queue
|
|
|
|
@abstractmethod
|
|
def _run(self):
|
|
"""
|
|
线程运行逻辑
|
|
|
|
返回:
|
|
当达到条件时触发callback
|
|
"""
|
|
|
|
@abstractmethod
|
|
def run(self):
|
|
"""
|
|
启动_run方法线程
|
|
|
|
返回:
|
|
线程实例
|
|
"""
|
|
|
|
@abstractmethod
|
|
def _pre_check(self):
|
|
"""
|
|
预检查
|
|
|
|
返回:
|
|
预检查结果
|
|
"""
|
|
|
|
@abstractmethod
|
|
def stop(self):
|
|
"""
|
|
停止线程
|
|
|
|
返回:
|
|
停止结果
|
|
"""
|
|
|
|
|
|
class FunctorFactory:
|
|
"""
|
|
Functor工厂类
|
|
|
|
该工厂类负责创建和配置Functor实例
|
|
|
|
主要方法:
|
|
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
|
|
创建并配置Functor实例
|
|
"""
|
|
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
|
|
"""
|
|
创建VAD Functor实例
|
|
"""
|
|
from src.functor.vad_functor import VADFunctor
|
|
|
|
audio_config = config["audio_config"]
|
|
model = {"vad": models["vad"]}
|
|
|
|
vad_functor = VADFunctor()
|
|
vad_functor.set_audio_config(audio_config)
|
|
vad_functor.set_model(model)
|
|
|
|
return vad_functor
|
|
|
|
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
|
|
"""
|
|
创建ASR Functor实例
|
|
"""
|
|
from src.functor.asr_functor import ASRFunctor
|
|
|
|
audio_config = config["audio_config"]
|
|
model = {"asr": models["asr"]}
|
|
|
|
asr_functor = ASRFunctor()
|
|
asr_functor.set_audio_config(audio_config)
|
|
asr_functor.set_model(model)
|
|
|
|
return asr_functor
|
|
|
|
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
|
|
"""
|
|
创建SPK Functor实例
|
|
"""
|
|
from src.functor.spk_functor import SPKFunctor
|
|
|
|
logger.debug(f"创建spk functor[开始]")
|
|
audio_config = config["audio_config"]
|
|
# model = {"spk": models["spk"]}
|
|
|
|
spk_functor = SPKFunctor(sv_pipeline=models["spk"])
|
|
spk_functor.set_audio_config(audio_config)
|
|
# spk_functor.set_model(model)
|
|
|
|
logger.debug(f"创建spk functor[完成]")
|
|
return spk_functor
|
|
|
|
def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
|
|
"""
|
|
创建ResultBinder Functor实例
|
|
"""
|
|
from src.functor.resultbinder_functor import ResultBinderFunctor
|
|
|
|
resultbinder_functor = ResultBinderFunctor()
|
|
|
|
return resultbinder_functor
|
|
|
|
factory_dict: Dict[str, Callable] = {
|
|
"vad": _make_vadfunctor,
|
|
"asr": _make_asrfunctor,
|
|
"spk": _make_spkfunctor,
|
|
"resultbinder": _make_resultbinderfunctor,
|
|
}
|
|
|
|
@classmethod
|
|
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
|
|
"""
|
|
创建并配置Functor实例
|
|
|
|
参数:
|
|
funtor_name (str): Functor名称
|
|
config (dict): 配置信息
|
|
models (dict): 模型信息
|
|
|
|
返回:
|
|
BaseFunctor: 创建的Functor实例
|
|
"""
|
|
if functor_name in cls.factory_dict:
|
|
return cls.factory_dict[functor_name](config=config, models=models)
|
|
else:
|
|
raise ValueError(f"不支持的Functor类型: {functor_name}")
|