208 lines
5.1 KiB
Python

"""
Functor基础模块
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。
基类提供了数据处理的基本框架,包括:
- 回调函数管理
- 模型配置管理
- 线程运行控制
主要类:
BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类
"""
from abc import ABC, abstractmethod
from typing import Callable, List, Dict
from queue import Queue
import threading
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class BaseFunctor(ABC):
"""
Functor抽象类
该抽象类规定了所有的Functor类必须实现run()方法启动自身线程
属性:
_callback (Callable): 处理完成后的回调函数
_model (dict): 存储模型相关的配置和实例
"""
def __init__(self):
"""
初始化函数器
参数:
callback (Callable): 处理完成后的回调函数
model (dict): 模型相关的配置和实例
"""
self._callback: List[Callable] = []
self._model: dict = {}
# flag
self._is_running: bool = False
self._stop_event: bool = False
# 状态锁
self._status_lock: threading.Lock = threading.Lock()
# 线程资源
self._thread: threading.Thread = None
def add_callback(self, callback: Callable):
"""
添加回调函数
参数:
callback (Callable): 新的回调函数
"""
self._callback.append(callback)
def set_model(self, model: dict):
"""
设置模型配置
参数:
model (dict): 新的模型配置
"""
self._model = model
def set_input_queue(self, queue: Queue):
"""
设置输入队列
参数:
queue (Queue): 新的输入队列
"""
self._input_queue = queue
@abstractmethod
def _run(self):
"""
线程运行逻辑
返回:
当达到条件时触发callback
"""
@abstractmethod
def run(self):
"""
启动_run方法线程
返回:
线程实例
"""
@abstractmethod
def _pre_check(self):
"""
预检查
返回:
预检查结果
"""
@abstractmethod
def stop(self):
"""
停止线程
返回:
停止结果
"""
class FunctorFactory:
"""
Functor工厂类
该工厂类负责创建和配置Functor实例
主要方法:
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
创建并配置Functor实例
"""
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建VAD Functor实例
"""
from src.functor.vad_functor import VADFunctor
audio_config = config["audio_config"]
model = {"vad": models["vad"]}
vad_functor = VADFunctor()
vad_functor.set_audio_config(audio_config)
vad_functor.set_model(model)
return vad_functor
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建ASR Functor实例
"""
from src.functor.asr_functor import ASRFunctor
audio_config = config["audio_config"]
model = {"asr": models["asr"]}
asr_functor = ASRFunctor()
asr_functor.set_audio_config(audio_config)
asr_functor.set_model(model)
return asr_functor
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建SPK Functor实例
"""
from src.functor.spk_functor import SPKFunctor
logger.debug(f"创建spk functor[开始]")
audio_config = config["audio_config"]
# model = {"spk": models["spk"]}
spk_functor = SPKFunctor(sv_pipeline=models["spk"])
spk_functor.set_audio_config(audio_config)
# spk_functor.set_model(model)
logger.debug(f"创建spk functor[完成]")
return spk_functor
def _make_resultbinderfunctor(config: dict, models: dict) -> BaseFunctor:
"""
创建ResultBinder Functor实例
"""
from src.functor.resultbinder_functor import ResultBinderFunctor
resultbinder_functor = ResultBinderFunctor()
return resultbinder_functor
factory_dict: Dict[str, Callable] = {
"vad": _make_vadfunctor,
"asr": _make_asrfunctor,
"spk": _make_spkfunctor,
"resultbinder": _make_resultbinderfunctor,
}
@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
"""
创建并配置Functor实例
参数:
funtor_name (str): Functor名称
config (dict): 配置信息
models (dict): 模型信息
返回:
BaseFunctor: 创建的Functor实例
"""
if functor_name in cls.factory_dict:
return cls.factory_dict[functor_name](config=config, models=models)
else:
raise ValueError(f"不支持的Functor类型: {functor_name}")