[代码重构中]编写融合VAD,ASR,SPK(FAKE)的ASRPipeline并完成测试,正常运行。

This commit is contained in:
Ziyang.Zhang 2025-06-06 17:26:08 +08:00
parent 3d8bf9de25
commit 5b94c40016
10 changed files with 405 additions and 128 deletions

View File

@ -110,6 +110,8 @@ class BaseFunctor(ABC):
停止结果 停止结果
""" """
class FunctorFactory: class FunctorFactory:
""" """
Functor工厂类 Functor工厂类
@ -121,8 +123,8 @@ class FunctorFactory:
创建并配置Functor实例 创建并配置Functor实例
""" """
@staticmethod @classmethod
def make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor: def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
""" """
创建并配置Functor实例 创建并配置Functor实例
@ -134,4 +136,60 @@ class FunctorFactory:
返回: 返回:
BaseFunctor: 创建的Functor实例 BaseFunctor: 创建的Functor实例
""" """
if functor_name == "vad":
return cls._make_vadfunctor(config = config,models = models)
elif functor_name == "asr":
return cls._make_asrfunctor(config = config,models = models)
elif functor_name == "spk":
return cls._make_spkfunctor(config = config,models = models)
else:
raise ValueError(f"不支持的Functor类型: {functor_name}")
@staticmethod
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build a VAD functor wired with the audio settings and the VAD model.

    Declared as a static method because it takes neither ``self`` nor
    ``cls``. Without the decorator the helper only works when reached
    through the class object and fails when called on an instance.

    Args:
        config: Settings mapping; only the ``"audio_config"`` entry is read.
        models: Loaded models; only the ``"vad"`` entry is read.

    Returns:
        BaseFunctor: The configured ``VADFunctor`` instance.

    Raises:
        KeyError: If ``config`` has no ``"audio_config"`` key or ``models``
            has no ``"vad"`` key (for plain dicts).
    """
    # Function-scope import kept from the original
    # (presumably to avoid a circular import; confirm before hoisting).
    from src.functor.vad_functor import VADFunctor

    audio_config = config["audio_config"]
    # Hand the functor only the model it needs, keyed by its own name.
    model = {"vad": models["vad"]}

    vad_functor = VADFunctor()
    vad_functor.set_audio_config(audio_config)
    vad_functor.set_model(model)
    return vad_functor
@staticmethod
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build an ASR functor wired with the audio settings and the ASR model.

    Declared as a static method because it takes neither ``self`` nor
    ``cls``. Without the decorator the helper only works when reached
    through the class object and fails when called on an instance.

    Args:
        config: Settings mapping; only the ``"audio_config"`` entry is read.
        models: Loaded models; only the ``"asr"`` entry is read.

    Returns:
        BaseFunctor: The configured ``ASRFunctor`` instance.

    Raises:
        KeyError: If ``config`` has no ``"audio_config"`` key or ``models``
            has no ``"asr"`` key (for plain dicts).
    """
    # Function-scope import kept from the original
    # (presumably to avoid a circular import; confirm before hoisting).
    from src.functor.asr_functor import ASRFunctor

    audio_config = config["audio_config"]
    # Hand the functor only the model it needs, keyed by its own name.
    model = {"asr": models["asr"]}

    asr_functor = ASRFunctor()
    asr_functor.set_audio_config(audio_config)
    asr_functor.set_model(model)
    return asr_functor
@staticmethod
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build an SPK functor wired with the audio settings and the SPK model.

    Declared as a static method because it takes neither ``self`` nor
    ``cls``. Without the decorator the helper only works when reached
    through the class object and fails when called on an instance.

    Args:
        config: Settings mapping; only the ``"audio_config"`` entry is read.
        models: Loaded models; only the ``"spk"`` entry is read.

    Returns:
        BaseFunctor: The configured ``SPKFunctor`` instance.

    Raises:
        KeyError: If ``config`` has no ``"audio_config"`` key or ``models``
            has no ``"spk"`` key (for plain dicts).
    """
    # Function-scope import kept from the original
    # (presumably to avoid a circular import; confirm before hoisting).
    from src.functor.spk_functor import SPKFunctor

    audio_config = config["audio_config"]
    # Hand the functor only the model it needs, keyed by its own name.
    model = {"spk": models["spk"]}

    spk_functor = SPKFunctor()
    spk_functor.set_audio_config(audio_config)
    spk_functor.set_model(model)
    return spk_functor

View File

@ -6,38 +6,117 @@ Functor文件夹用于存放所有功能性的类包括VAD、PUNC、ASR、SPK
## Functor 类的定义 ## Functor 类的定义
所有类应继承于 **基类** `BaseFunctor` ,应遵从 *压入数据**数据处理* 解绑 所有类应继承于**基类**`BaseFunctor`
为了方便使用,我们对于 **基类** 的定义如下: 为了方便使用,我们对于**基类**的定义如下:
1. 函数内部使用的变量以单下划线开头,预定有 `_data`, `_callback`, `_model` 1. 函数内部使用的变量以单下划线开头,基类中包含:
* _model: Dict 存放模型相关的配置和实例
* _input_queue: Queue 监听的输入消息队列
* _thread: Threading.Thread 运行的线程实例
* _callback: List[Callable] 回调函数列表
* _is_running: bool 线程运行状态标志
* _stop_event: bool 停止事件标志
* _status_lock: threading.Lock 状态锁,用于线程同步
2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict` 2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`
3. 定义了 3. 基类定义的核心方法:
`__call__`:可传入`data`,默认调用`push_data`,随后默认调用`process` * `add_callback(callback: Callable)`: 添加结果处理的回调函数
* `set_model(model: dict)`: 设置模型配置和实例
`__add__` * `set_input_queue(queue: Queue)`: 设置输入数据队列
* `run()`: 启动处理线程(抽象方法)
* `stop()`: 停止处理线程(抽象方法)
* `_run()`: 线程运行的具体逻辑(抽象方法)
* `_pre_check()`: 运行前的预检查(抽象方法)
## 派生类实现要求
1. 必须实现的抽象方法:
* `_pre_check()`:
- 检查必要的配置是否完整(如模型、队列等)
- 检查运行环境是否满足要求
- 返回检查结果
* `_run()`:
- 实现具体的数据处理逻辑
- 从 _input_queue 获取输入数据
- 使用 _model 进行处理
- 通过 _callback 返回处理结果
* `run()`:
- 调用 _pre_check() 进行预检查
- 创建并启动处理线程
- 设置相关状态标志
* `stop()`:
- 安全停止处理线程
- 清理资源
- 重置状态标志
2. 建议实现的方法:
* `__str__`: 返回当前实例的状态信息
* 错误处理方法:处理运行过程中的异常情况
## 使用示例
```python ```python
class BaseFunctor: class MyFunctor(BaseFunctor):
def __init__(self): def _pre_check(self):
self._data: dict = {} if not self._model or not self._input_queue:
self._callback: function = null return False
pass return True
def __call__(self, data): def _run(self):
result = while not self._stop_event:
self._callback(process(data)) try:
return self.process(data) data = self._input_queue.get(timeout=1.0)
result = self._model['my_model'].process(data)
for callback in self._callback:
callback(result)
except Empty:
continue
except Exception as e:
logger.error(f"处理错误: {e}")
def set_callback(self, callback: Callable): def run(self):
self._callback = callback if not self._pre_check():
raise RuntimeError("预检查失败")
with self._status_lock:
if self._is_running:
return
self._is_running = True
self._stop_event = False
self._thread = threading.Thread(target=self._run)
self._thread.start()
def push_data(): def stop(self):
pass with self._status_lock:
if not self._is_running:
return
self._stop_event = True
if self._thread:
self._thread.join()
self._is_running = False
```
def process(self, data): ## 注意事项
pass
``` 1. 线程安全:
* 使用 _status_lock 保护状态变更
* 注意共享资源的访问控制
2. 错误处理:
* 在 _run() 中妥善处理异常
* 提供详细的错误日志
3. 资源管理:
* 确保在 stop() 中正确清理资源
* 避免资源泄露
4. 回调函数:
* 回调函数应该是非阻塞的
* 处理回调函数抛出的异常

View File

@ -13,9 +13,9 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
class SpkFunctor(BaseFunctor): class SPKFunctor(BaseFunctor):
""" """
SpkFunctor SPKFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback 负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config 需要配置好 _model, _callback, _input_queue, _audio_config
否则无法run()启动线程 否则无法run()启动线程

View File

@ -38,7 +38,7 @@ class ModelLoader:
初始化ModelLoader实例 初始化ModelLoader实例
""" """
self.models = {} self.models = {}
logger.info("初始化ModelLoader") logger.debug("初始化ModelLoader")
if args is not None: if args is not None:
self.__call__(args) self.__call__(args)
@ -87,14 +87,14 @@ class ModelLoader:
else: else:
value = input_model_args.get(key, None) value = input_model_args.get(key, None)
if value is not None: if value is not None:
logger.info("替换%s模型参数: %s = %s", model_type, key, value) logger.debug("替换%s模型参数: %s = %s", model_type, key, value)
model_args[key] = value model_args[key] = value
# 验证必要参数 # 验证必要参数
if not model_args["model"]: if not model_args["model"]:
raise ValueError(f"未指定{model_type}模型路径") raise ValueError(f"未指定{model_type}模型路径")
try: try:
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销 # 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
logger.info("正在加载%s模型: %s", model_type, model_args["model"]) logger.debug("正在加载%s模型: %s", model_type, model_args["model"])
model = AutoModel(**model_args) model = AutoModel(**model_args)
return model return model
except Exception as e: except Exception as e:
@ -115,12 +115,12 @@ class ModelLoader:
self.models = {} self.models = {}
# 加载离线ASR模型 # 加载离线ASR模型
# 检查对应键是否存在 # 检查对应键是否存在
model_list = ['asr', 'asr_online', 'vad', 'punc'] model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk']
for model_name in model_list: for model_name in model_list:
name_model = f"{model_name}_model" name_model = f"{model_name}_model"
name_model_revision = f"{model_name}_model_revision" name_model_revision = f"{model_name}_model_revision"
if name_model in args: if name_model in args:
logger.info("加载%s模型", model_name) logger.debug("加载%s模型", model_name)
self.models[model_name] = self._load_model(args, model_name) self.models[model_name] = self._load_model(args, model_name)
logger.info("所有模型加载完成") logger.info("所有模型加载完成")
return self.models return self.models

View File

@ -1,7 +1,9 @@
from src.pipeline.base import PipelineBase from src.pipeline.base import PipelineBase
from typing import Dict, Any from typing import Dict, Any, Callable
from queue import Queue from queue import Queue, Empty
from src.utils import get_module_logger from src.utils import get_module_logger
from src.models import AudioBinary_data_list
import threading
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
@ -19,6 +21,12 @@ class ASRPipeline(PipelineBase):
self._functor_dict: Dict[str, Any] = {} self._functor_dict: Dict[str, Any] = {}
self._subqueue_dict: Dict[str, Any] = {} self._subqueue_dict: Dict[str, Any] = {}
self._is_baked: bool = False self._is_baked: bool = False
self._input_queue: Queue = None
self._audio_binary_data_list: AudioBinary_data_list = None
self._status_lock = threading.Lock()
self._is_running: bool = False
self._stop_event: bool = False
def set_config(self, config: Dict[str, Any]) -> None: def set_config(self, config: Dict[str, Any]) -> None:
""" """
@ -28,7 +36,15 @@ class ASRPipeline(PipelineBase):
""" """
self._config = config self._config = config
def set_audio_binary(self, audio_binary: AudioBinary) -> None: def get_config(self) -> Dict[str, Any]:
"""
获取配置
返回:
Dict[str, Any] 配置
"""
return self._config
def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None:
""" """
设置音频二进制存储单元 设置音频二进制存储单元
参数: 参数:
@ -42,44 +58,70 @@ class ASRPipeline(PipelineBase):
""" """
self._models = models self._models = models
def set_input_queue(self, input_queue: Queue) -> None:
"""
设置输入队列
"""
self._input_queue = input_queue
def bake(self) -> None: def bake(self) -> None:
""" """
烘焙管道 烘焙管道
""" """
self._pre_check_resource()
self._init_functor() self._init_functor()
self._is_baked = True self._is_baked = True
def _pre_check_resource(self) -> None:
    """
    Verify that the resources bake() relies on have been provided.

    The attributes are inspected in a fixed order and the first one that
    is ``None`` aborts the check.

    Raises:
        RuntimeError: If the input queue, the functor dict, the sub-queue
            dict or the audio storage unit is ``None``.
    """
    # (attribute name, error message) pairs, checked top to bottom.
    required = (
        ("_input_queue", "[ASRpipeline]输入队列未设置"),
        ("_functor_dict", "[ASRpipeline]functor字典未设置"),
        ("_subqueue_dict", "[ASRpipeline]子队列字典未设置"),
        ("_audio_binary", "[ASRpipeline]音频存储单元未设置"),
    )
    for attr_name, message in required:
        if getattr(self, attr_name) is None:
            raise RuntimeError(message)
def _init_functor(self) -> None: def _init_functor(self) -> None:
""" """
初始化函数 初始化函数
""" """
try: try:
from src.functor import functorFactory from src.functor import FunctorFactory
# 加载VAD、asr、spk functor # 加载VAD、asr、spk functor
self._functor_dict["vad"] = functorFactory.make_functor( self._functor_dict["vad"] = FunctorFactory.make_functor(
functor_name = "vad", functor_name = "vad",
config = self._config, config = self._config,
models = self._models models = self._models
) )
self._functor_dict["asr"] = functorFactory.make_functor( self._functor_dict["asr"] = FunctorFactory.make_functor(
functor_name = "asr", functor_name = "asr",
config = self._config, config = self._config,
models = self._models models = self._models
) )
self._functor_dict["spk"] = functorFactory.make_functor( self._functor_dict["spk"] = FunctorFactory.make_functor(
functor_name = "spk", functor_name = "spk",
config = self._config, config = self._config,
models = self._models models = self._models
) )
# 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list()
self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list)
# 初始化子队列 # 初始化子队列
self._subqueue_dict["original"] = Queue()
self._subqueue_dict["vad2asr"] = Queue() self._subqueue_dict["vad2asr"] = Queue()
self._subqueue_dict["vad2spk"] = Queue() self._subqueue_dict["vad2spk"] = Queue()
self._subqueue_dict["asrend"] = Queue() self._subqueue_dict["asrend"] = Queue()
self._subqueue_dict["spkend"] = Queue() self._subqueue_dict["spkend"] = Queue()
# 设置子队列的输入队列 # 设置子队列的输入队列
self._functor_dict["vad"].set_input_queue(self._input_queue) self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
@ -94,7 +136,7 @@ class ASRPipeline(PipelineBase):
""" """
def put_with_check(data: Any) -> None: def put_with_check(data: Any) -> None:
queue.put(data) queue.put(data)
callback() callback(data)
return put_with_check return put_with_check
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
@ -102,76 +144,7 @@ class ASRPipeline(PipelineBase):
except ImportError: except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
def get_config(self) -> Dict[str, Any]:
"""
获取配置
返回:
Dict[str, Any] 配置
"""
return self._config
def process(self, data: Any) -> Any:
"""
处理数据
参数:
data: 输入数据
返回:
处理结果
"""
# 子类实现具体的处理逻辑
self._input_queue.put(data)
def run(self) -> None:
"""
运行管道
"""
if not self._is_baked:
raise RuntimeError("管道未烘焙,无法运行")
# 运行所有functor
for functor_name, functor in self._functor_dict.items():
logger.info(f"运行{functor_name}functor")
self._functor_dict[functor_name].run()
# 运行管道
if not self._input_queue:
raise RuntimeError("输入队列未设置")
# 设置管道运行状态
self._is_running = True
self._stop_event = False
self._thread = threading.current_thread()
logger.info("ASR管道开始运行")
while self._is_running and not self._stop_event:
try:
# 从队列获取数据
try:
data = self._input_queue.get(timeout=self._queue_timeout)
# 检查是否是结束信号
if data is None:
logger.info("收到结束信号,管道准备停止")
self._stop()
self._input_queue.task_done() # 标记结束信号已处理
break
# 处理数据
self.process(data)
# 标记任务完成
self._input_queue.task_done()
except Empty:
# 队列获取超时,继续等待
continue
except Exception as e:
logger.error(f"管道处理数据出错: {str(e)}")
continue
logger.info("管道停止运行")
def _check_result(self, result: Any) -> None: def _check_result(self, result: Any) -> None:
""" """
检查结果 检查结果
@ -188,13 +161,96 @@ class ASRPipeline(PipelineBase):
# 通知回调函数 # 通知回调函数
self._notify_callbacks(result) self._notify_callbacks(result)
def run(self) -> threading.Thread:
    """
    Launch the pipeline loop on a background daemon thread.

    Returns:
        threading.Thread: The worker thread, already started.

    Raises:
        RuntimeError: Propagated from ``_pre_check()`` when the pipeline
            is not ready to run.
    """
    # Refuse to start unless the pipeline resources are ready.
    self._pre_check()

    worker = threading.Thread(target=self._run, daemon=True)
    self._thread = worker
    worker.start()

    logger.info("[ASRpipeline]管道开始运行")
    return worker
def _pre_check(self) -> None:
    """
    Check that the pipeline is baked and that no component is missing.

    Raises:
        RuntimeError: If ``_is_baked`` is ``False``, or if any entry of
            the functor dict or of the sub-queue dict is ``None``.
    """
    if self._is_baked is False:
        raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行")

    # Every registered functor must be a real object.
    for functor_name, component in self._functor_dict.items():
        if component is not None:
            continue
        raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常")

    # Every registered sub-queue must be a real object.
    for subqueue_name, channel in self._subqueue_dict.items():
        if channel is not None:
            continue
        raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常")
def _run(self) -> None:
    """
    Worker-thread body of the pipeline.

    Starts every functor, marks the pipeline as running, then keeps
    pulling items from the input queue and passing them to ``_process()``.
    The loop ends when the running/stop flags say so, when a ``None``
    item (the end signal) is received, or when an unexpected exception
    is raised while handling an item.
    """
    # Start every registered functor before consuming any input.
    for functor_name in self._functor_dict:
        logger.info(f"[ASRpipeline]运行{functor_name}functor")
        self._functor_dict[functor_name].run()

    # Publish the running state under the status lock.
    with self._status_lock:
        self._is_running, self._stop_event = True, False

    while self._is_running and not self._stop_event:
        try:
            # NOTE(review): _queue_timeout is not set in the visible part of
            # this class; presumably provided by the base class. Confirm.
            item = self._input_queue.get(timeout=self._queue_timeout)
            if item is not None:
                # Regular item: handle it, then mark it as done.
                self._process(item)
                self._input_queue.task_done()
                continue
            # None is the end signal: acknowledge it and leave the loop.
            logger.info("收到结束信号,管道准备停止")
            self._input_queue.task_done()
            break
        except Empty:
            # Nothing arrived within the timeout; re-check the flags.
            continue
        except Exception as e:
            logger.error(f"[ASRpipeline]管道处理数据出错: {str(e)}")
            break

    logger.info("[ASRpipeline]管道停止运行")
def _process(self, data: Any) -> Any:
    """
    Forward one input item to the "original" sub-queue.

    Args:
        data: The item taken from the pipeline input queue.

    Returns:
        None; the item is only enqueued.
    """
    original_queue = self._subqueue_dict["original"]
    original_queue.put(data)
def stop(self) -> None: def stop(self) -> None:
""" """
停止管道 停止管道
""" """
self._is_running = False with self._status_lock:
self._stop_event = True self._is_running = False
for functor_name, functor in self._functor_dict.items(): self._stop_event = True
logger.info(f"停止{functor_name}functor") for functor_name, functor in self._functor_dict.items():
functor.stop() # logger.info(f"停止{functor_name}functor")
logger.info("子Functor停止") if functor.stop():
logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止")
else:
logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败")
self._thread.join()
logger.info("[ASRpipeline]管道停止")
return True

View File

@ -1,3 +1,3 @@
from src.pipeline.base import PipelineBase, Pipeline, PipelineFactory from src.pipeline.base import PipelineBase, PipelineFactory
__all__ = ["PipelineBase", "Pipeline", "PipelineFactory"] __all__ = ["PipelineBase", "PipelineFactory"]

View File

@ -56,7 +56,7 @@ class PipelineBase(ABC):
logger.error(f"回调函数执行出错: {str(e)}") logger.error(f"回调函数执行出错: {str(e)}")
@abstractmethod @abstractmethod
def process(self, data: Any) -> Any: def _process(self, data: Any) -> Any:
""" """
处理数据 处理数据
参数: 参数:
@ -67,7 +67,7 @@ class PipelineBase(ABC):
pass pass
@abstractmethod @abstractmethod
def run(self) -> None: def _run(self) -> None:
""" """
运行管道 运行管道
从输入队列获取数据并处理 从输入队列获取数据并处理
@ -121,7 +121,7 @@ class PipelineFactory:
用于创建管道实例 用于创建管道实例
""" """
@staticmethod @staticmethod
def create_pipeline(pipeline_name: str) -> Pipeline: def create_pipeline(pipeline_name: str) -> Any:
""" """
创建管道实例 创建管道实例
""" """

View File

@ -1,8 +1,12 @@
from tests.functor.vad_test import test_vad_functor from tests.functor.vad_test import test_vad_functor
from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger from src.utils.logger import get_module_logger, setup_root_logger
setup_root_logger(level="INFO", log_file="logs/test_main.log") setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
logger.info("开始测试VAD函数器") # logger.info("开始测试VAD函数器")
test_vad_functor() # test_vad_functor()
logger.info("开始测试ASR管道")
test_asr_pipeline()

View File

@ -4,7 +4,7 @@ VAD测试
""" """
from src.functor.vad_functor import VADFunctor from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SpkFunctor from src.functor.spk_functor import SPKFunctor
from queue import Queue, Empty from queue import Queue, Empty
from src.model_loader import ModelLoader from src.model_loader import ModelLoader
from src.models import AudioBinary_Config, AudioBinary_data_list from src.models import AudioBinary_Config, AudioBinary_data_list
@ -84,7 +84,7 @@ def test_vad_functor():
asr_functor.run() asr_functor.run()
# 创建SPK函数器 # 创建SPK函数器
spk_functor = SpkFunctor() spk_functor = SPKFunctor()
# 设置输入队列 # 设置输入队列
spk_functor.set_input_queue(vad2spk_queue) spk_functor.set_input_queue(vad2spk_queue)
# 设置音频配置 # 设置音频配置

View File

@ -0,0 +1,80 @@
"""
Pipeline测试
VAD+ASR+SPK(FAKE)
"""
from src.pipeline.ASRpipeline import ASRPipeline
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader
from queue import Queue
import soundfile
import time
from src.utils.logger import get_module_logger
# Module-level logger for this test module.
logger = get_module_logger(__name__)
# Module-level flag; it is not read by the code visible in this test module.
# NOTE(review): the name looks like a misspelling of "OVERWATCH" - confirm before renaming.
OVAERWATCH = False
# Shared loader instance; test_asr_pipeline() uses it to load the models.
model_loader = ModelLoader()
def test_asr_pipeline():
    """
    Smoke-test the ASR pipeline (VAD + ASR + SPK) on a sample wav file.

    Loads the models, configures and bakes an ``ASRPipeline``, starts it,
    feeds the audio in fixed-size clips through the input queue, waits a
    few seconds and then stops the pipeline. Results are only printed by
    the registered callback; nothing is asserted.
    """
    # Load the models.
    models = model_loader.load_models(
        {
            "asr_model": "paraformer-zh",
            "asr_model_revision": "v2.0.4",
            "vad_model": "fsmn-vad",
            "vad_model_revision": "v2.0.4",
            "spk_model": "cam++",
            "spk_model_revision": "v2.0.2",
            "audio_update": False,
        }
    )

    audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
    audio_config = AudioBinary_Config(
        chunk_size=200,
        chunk_stride=1600,
        sample_rate=sample_rate,
        sample_width=16,
        channels=1,
    )
    # Replace the initial stride with one derived from chunk_size and the
    # sample rate of the file that was actually read.
    audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)

    # Settings dict handed to the pipeline.
    config = {"audio_config": audio_config}

    # Audio storage unit and the queue used to feed the pipeline.
    audio_binary_data_list = AudioBinary_data_list()
    input_queue = Queue()

    # Create and configure the pipeline; bake() comes after all setters.
    asr_pipeline = ASRPipeline()
    asr_pipeline.set_models(models)
    asr_pipeline.set_config(config)
    asr_pipeline.set_audio_binary(audio_binary_data_list)
    asr_pipeline.set_input_queue(input_queue)
    asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
    asr_pipeline.bake()

    # Start the pipeline; the returned thread is kept for optional joining.
    asr_instance = asr_pipeline.run()

    audio_clip_len = 200
    print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}")
    for start in range(0, len(audio_data), audio_clip_len):
        clip = audio_data[start:start + audio_clip_len]
        input_queue.put(clip)

    # Alternative shutdown (disabled in the original): put None on the
    # input queue as the end signal and join the returned thread.
    # Here the test waits a fixed time and then stops the pipeline.
    time.sleep(5)
    asr_pipeline.stop()