[代码重构中]编写融合VAD,ASR,SPK(FAKE)的ASRPipeline并完成测试,正常运行。

This commit is contained in:
Ziyang.Zhang 2025-06-06 17:26:08 +08:00
parent 3d8bf9de25
commit 5b94c40016
10 changed files with 405 additions and 128 deletions

View File

@ -110,6 +110,8 @@ class BaseFunctor(ABC):
停止结果
"""
class FunctorFactory:
"""
Functor工厂类
@ -121,8 +123,8 @@ class FunctorFactory:
创建并配置Functor实例
"""
@staticmethod
def make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
@classmethod
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
"""
创建并配置Functor实例
@ -135,3 +137,59 @@ class FunctorFactory:
BaseFunctor: 创建的Functor实例
"""
if functor_name == "vad":
return cls._make_vadfunctor(config = config,models = models)
elif functor_name == "asr":
return cls._make_asrfunctor(config = config,models = models)
elif functor_name == "spk":
return cls._make_spkfunctor(config = config,models = models)
else:
raise ValueError(f"不支持的Functor类型: {functor_name}")
@staticmethod
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build a VAD functor and wire in its audio configuration and model.

    Declared as a static method because it takes neither ``self`` nor
    ``cls``. Calls through the class (``cls._make_vadfunctor(...)``) keep
    working, and a call through an instance no longer injects the instance
    as the first argument.

    Args:
        config: Pipeline configuration; the key "audio_config" is read.
        models: Loaded models keyed by name; the key "vad" is read.

    Returns:
        BaseFunctor: The configured VADFunctor instance.

    Raises:
        KeyError: If "audio_config" or "vad" is missing.
    """
    # Imported inside the function, exactly as the original code does.
    from src.functor.vad_functor import VADFunctor

    audio_config = config["audio_config"]
    # Pass on only the model this functor uses, not the whole model table.
    model = {"vad": models["vad"]}

    vad_functor = VADFunctor()
    vad_functor.set_audio_config(audio_config)
    vad_functor.set_model(model)
    return vad_functor
@staticmethod
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build an ASR functor and wire in its audio configuration and model.

    Declared as a static method because it takes neither ``self`` nor
    ``cls``. Calls through the class (``cls._make_asrfunctor(...)``) keep
    working, and a call through an instance no longer injects the instance
    as the first argument.

    Args:
        config: Pipeline configuration; the key "audio_config" is read.
        models: Loaded models keyed by name; the key "asr" is read.

    Returns:
        BaseFunctor: The configured ASRFunctor instance.

    Raises:
        KeyError: If "audio_config" or "asr" is missing.
    """
    # Imported inside the function, exactly as the original code does.
    from src.functor.asr_functor import ASRFunctor

    audio_config = config["audio_config"]
    # Pass on only the model this functor uses, not the whole model table.
    model = {"asr": models["asr"]}

    asr_functor = ASRFunctor()
    asr_functor.set_audio_config(audio_config)
    asr_functor.set_model(model)
    return asr_functor
@staticmethod
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build an SPK functor and wire in its audio configuration and model.

    Declared as a static method because it takes neither ``self`` nor
    ``cls``. Calls through the class (``cls._make_spkfunctor(...)``) keep
    working, and a call through an instance no longer injects the instance
    as the first argument.

    Args:
        config: Pipeline configuration; the key "audio_config" is read.
        models: Loaded models keyed by name; the key "spk" is read.

    Returns:
        BaseFunctor: The configured SPKFunctor instance.

    Raises:
        KeyError: If "audio_config" or "spk" is missing.
    """
    # Imported inside the function, exactly as the original code does.
    from src.functor.spk_functor import SPKFunctor

    audio_config = config["audio_config"]
    # Pass on only the model this functor uses, not the whole model table.
    model = {"spk": models["spk"]}

    spk_functor = SPKFunctor()
    spk_functor.set_audio_config(audio_config)
    spk_functor.set_model(model)
    return spk_functor

View File

@ -6,38 +6,117 @@ Functor文件夹用于存放所有功能性的类包括VAD、PUNC、ASR、SPK
## Functor 类的定义
所有类应继承于 **基类** `BaseFunctor` ,应遵从 *压入数据**数据处理* 解绑
所有类应继承于**基类**`BaseFunctor`
为了方便使用,我们对于 **基类** 的定义如下:
为了方便使用,我们对于**基类**的定义如下:
1. 函数内部使用的变量以单下划线开头,预定有 `_data`, `_callback`, `_model`
1. 函数内部使用的变量以单下划线开头,基类中包含:
* _model: Dict 存放模型相关的配置和实例
* _input_queue: Queue 监听的输入消息队列
* _thread: threading.Thread 运行的线程实例
* _callback: List[Callable] 回调函数列表
* _is_running: bool 线程运行状态标志
* _stop_event: bool 停止事件标志
* _status_lock: threading.Lock 状态锁,用于线程同步
2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`
3. 定义了
3. 基类定义的核心方法:
`__call__`:可传入`data`,默认调用`push_data`,随后默认调用`process`
* `add_callback(callback: Callable)`: 添加结果处理的回调函数
* `set_model(model: dict)`: 设置模型配置和实例
* `set_input_queue(queue: Queue)`: 设置输入数据队列
* `run()`: 启动处理线程(抽象方法)
* `stop()`: 停止处理线程(抽象方法)
* `_run()`: 线程运行的具体逻辑(抽象方法)
* `_pre_check()`: 运行前的预检查(抽象方法)
`__add__`
## 派生类实现要求
1. 必须实现的抽象方法:
* `_pre_check()`:
- 检查必要的配置是否完整(如模型、队列等)
- 检查运行环境是否满足要求
- 返回检查结果
* `_run()`:
- 实现具体的数据处理逻辑
- 从 _input_queue 获取输入数据
- 使用 _model 进行处理
- 通过 _callback 返回处理结果
* `run()`:
- 调用 _pre_check() 进行预检查
- 创建并启动处理线程
- 设置相关状态标志
* `stop()`:
- 安全停止处理线程
- 清理资源
- 重置状态标志
2. 建议实现的方法:
* `__str__`: 返回当前实例的状态信息
* 错误处理方法:处理运行过程中的异常情况
## 使用示例
```python
class BaseFunctor:
def __init__(self):
self._data: dict = {}
self._callback: function = null
pass
class MyFunctor(BaseFunctor):
def _pre_check(self):
if not self._model or not self._input_queue:
return False
return True
def __call__(self, data):
result =
self._callback(process(data))
return self.process(data)
def _run(self):
while not self._stop_event:
try:
data = self._input_queue.get(timeout=1.0)
result = self._model['my_model'].process(data)
for callback in self._callback:
callback(result)
except Empty:  # from queue import Empty
continue
except Exception as e:
logger.error(f"处理错误: {e}")
def set_callback(self, callback: Callable):
self._callback = callback
def run(self):
if not self._pre_check():
raise RuntimeError("预检查失败")
def push_data():
pass
with self._status_lock:
if self._is_running:
return
self._is_running = True
self._stop_event = False
self._thread = threading.Thread(target=self._run)
self._thread.start()
def process(self, data):
pass
def stop(self):
with self._status_lock:
if not self._is_running:
return
self._stop_event = True
if self._thread:
self._thread.join()
self._is_running = False
```
## 注意事项
1. 线程安全:
* 使用 _status_lock 保护状态变更
* 注意共享资源的访问控制
2. 错误处理:
* 在 _run() 中妥善处理异常
* 提供详细的错误日志
3. 资源管理:
* 确保在 stop() 中正确清理资源
* 避免资源泄露
4. 回调函数:
* 回调函数应该是非阻塞的
* 处理回调函数抛出的异常

View File

@ -13,9 +13,9 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class SpkFunctor(BaseFunctor):
class SPKFunctor(BaseFunctor):
"""
SpkFunctor
SPKFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config
否则无法run()启动线程

View File

@ -38,7 +38,7 @@ class ModelLoader:
初始化ModelLoader实例
"""
self.models = {}
logger.info("初始化ModelLoader")
logger.debug("初始化ModelLoader")
if args is not None:
self.__call__(args)
@ -87,14 +87,14 @@ class ModelLoader:
else:
value = input_model_args.get(key, None)
if value is not None:
logger.info("替换%s模型参数: %s = %s", model_type, key, value)
logger.debug("替换%s模型参数: %s = %s", model_type, key, value)
model_args[key] = value
# 验证必要参数
if not model_args["model"]:
raise ValueError(f"未指定{model_type}模型路径")
try:
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
logger.info("正在加载%s模型: %s", model_type, model_args["model"])
logger.debug("正在加载%s模型: %s", model_type, model_args["model"])
model = AutoModel(**model_args)
return model
except Exception as e:
@ -115,12 +115,12 @@ class ModelLoader:
self.models = {}
# 加载离线ASR模型
# 检查对应键是否存在
model_list = ['asr', 'asr_online', 'vad', 'punc']
model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk']
for model_name in model_list:
name_model = f"{model_name}_model"
name_model_revision = f"{model_name}_model_revision"
if name_model in args:
logger.info("加载%s模型", model_name)
logger.debug("加载%s模型", model_name)
self.models[model_name] = self._load_model(args, model_name)
logger.info("所有模型加载完成")
return self.models

View File

@ -1,7 +1,9 @@
from src.pipeline.base import PipelineBase
from typing import Dict, Any
from queue import Queue
from typing import Dict, Any, Callable
from queue import Queue, Empty
from src.utils import get_module_logger
from src.models import AudioBinary_data_list
import threading
logger = get_module_logger(__name__)
@ -19,6 +21,12 @@ class ASRPipeline(PipelineBase):
self._functor_dict: Dict[str, Any] = {}
self._subqueue_dict: Dict[str, Any] = {}
self._is_baked: bool = False
self._input_queue: Queue = None
self._audio_binary_data_list: AudioBinary_data_list = None
self._status_lock = threading.Lock()
self._is_running: bool = False
self._stop_event: bool = False
def set_config(self, config: Dict[str, Any]) -> None:
"""
@ -28,7 +36,15 @@ class ASRPipeline(PipelineBase):
"""
self._config = config
def set_audio_binary(self, audio_binary: AudioBinary) -> None:
def get_config(self) -> Dict[str, Any]:
    """
    Return the configuration currently stored on this pipeline.

    Returns:
        Dict[str, Any]: The object held in ``self._config`` (no copy is made).
    """
    current_config = self._config
    return current_config
def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None:
"""
设置音频二进制存储单元
参数:
@ -42,44 +58,70 @@ class ASRPipeline(PipelineBase):
"""
self._models = models
def set_input_queue(self, input_queue: Queue) -> None:
    """
    Store the queue this pipeline reads its input from.

    Args:
        input_queue: Queue to keep in ``self._input_queue``.
    """
    setattr(self, "_input_queue", input_queue)
def bake(self) -> None:
    """
    Prepare the pipeline for running.

    Runs the resource pre-check, then initialises the functors, and only
    after both have returned marks the pipeline as baked. An exception
    from either step propagates and leaves ``_is_baked`` unchanged.
    """
    # Step 1: make sure the required resources were provided.
    self._pre_check_resource()

    # Step 2: create and wire the functors.
    self._init_functor()

    # Reached only when both steps succeeded.
    self._is_baked = True
def _pre_check_resource(self) -> None:
"""
预检查资源
"""
if self._input_queue is None:
raise RuntimeError("[ASRpipeline]输入队列未设置")
if self._functor_dict is None:
raise RuntimeError("[ASRpipeline]functor字典未设置")
if self._subqueue_dict is None:
raise RuntimeError("[ASRpipeline]子队列字典未设置")
if self._audio_binary is None:
raise RuntimeError("[ASRpipeline]音频存储单元未设置")
def _init_functor(self) -> None:
"""
初始化函数
"""
try:
from src.functor import functorFactory
from src.functor import FunctorFactory
# 加载VAD、asr、spk functor
self._functor_dict["vad"] = functorFactory.make_functor(
self._functor_dict["vad"] = FunctorFactory.make_functor(
functor_name = "vad",
config = self._config,
models = self._models
)
self._functor_dict["asr"] = functorFactory.make_functor(
self._functor_dict["asr"] = FunctorFactory.make_functor(
functor_name = "asr",
config = self._config,
models = self._models
)
self._functor_dict["spk"] = functorFactory.make_functor(
self._functor_dict["spk"] = FunctorFactory.make_functor(
functor_name = "spk",
config = self._config,
models = self._models
)
# 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list()
self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list)
# 初始化子队列
self._subqueue_dict["original"] = Queue()
self._subqueue_dict["vad2asr"] = Queue()
self._subqueue_dict["vad2spk"] = Queue()
self._subqueue_dict["asrend"] = Queue()
self._subqueue_dict["spkend"] = Queue()
# 设置子队列的输入队列
self._functor_dict["vad"].set_input_queue(self._input_queue)
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
@ -94,7 +136,7 @@ class ASRPipeline(PipelineBase):
"""
def put_with_check(data: Any) -> None:
queue.put(data)
callback()
callback(data)
return put_with_check
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
@ -103,75 +145,6 @@ class ASRPipeline(PipelineBase):
except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
def get_config(self) -> Dict[str, Any]:
"""
获取配置
返回:
Dict[str, Any] 配置
"""
return self._config
def process(self, data: Any) -> Any:
"""
处理数据
参数:
data: 输入数据
返回:
处理结果
"""
# 子类实现具体的处理逻辑
self._input_queue.put(data)
def run(self) -> None:
"""
运行管道
"""
if not self._is_baked:
raise RuntimeError("管道未烘焙,无法运行")
# 运行所有functor
for functor_name, functor in self._functor_dict.items():
logger.info(f"运行{functor_name}functor")
self._functor_dict[functor_name].run()
# 运行管道
if not self._input_queue:
raise RuntimeError("输入队列未设置")
# 设置管道运行状态
self._is_running = True
self._stop_event = False
self._thread = threading.current_thread()
logger.info("ASR管道开始运行")
while self._is_running and not self._stop_event:
try:
# 从队列获取数据
try:
data = self._input_queue.get(timeout=self._queue_timeout)
# 检查是否是结束信号
if data is None:
logger.info("收到结束信号,管道准备停止")
self._stop()
self._input_queue.task_done() # 标记结束信号已处理
break
# 处理数据
self.process(data)
# 标记任务完成
self._input_queue.task_done()
except Empty:
# 队列获取超时,继续等待
continue
except Exception as e:
logger.error(f"管道处理数据出错: {str(e)}")
continue
logger.info("管道停止运行")
def _check_result(self, result: Any) -> None:
"""
检查结果
@ -188,13 +161,96 @@ class ASRPipeline(PipelineBase):
# 通知回调函数
self._notify_callbacks(result)
def run(self) -> threading.Thread:
    """
    Start the pipeline on a background daemon thread.

    The pre-check runs first; any exception it raises propagates and no
    thread is created.

    Returns:
        threading.Thread: The already-started thread executing ``_run``;
        it is also stored in ``self._thread``.
    """
    # Check that everything needed to run is ready.
    self._pre_check()

    worker = threading.Thread(target=self._run, daemon=True)
    self._thread = worker
    worker.start()
    logger.info("[ASRpipeline]管道开始运行")
    return worker
def _pre_check(self) -> None:
"""
预检查
"""
if self._is_baked is False:
raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行")
for functor_name, functor in self._functor_dict.items():
if functor is None:
raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常")
for subqueue_name, subqueue in self._subqueue_dict.items():
if subqueue is None:
raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常")
def _run(self) -> None:
    """
    Thread body of the pipeline: start the functors, then forward input.

    Every functor in ``self._functor_dict`` is started first. The running
    flags are then set under ``self._status_lock`` and items are pulled
    from ``self._input_queue`` until one of the following happens:

    * ``None`` is received (end-of-stream signal),
    * ``_is_running`` becomes false or ``_stop_event`` becomes true,
    * an unexpected exception occurs (it is logged and the loop ends).

    Every item taken from the queue is acknowledged with ``task_done()``,
    including when ``_process`` raises, so the queue's unfinished-task
    count still drains after a processing error.
    """
    # Start all functors.
    for functor_name, functor in self._functor_dict.items():
        logger.info(f"[ASRpipeline]运行{functor_name}functor")
        functor.run()

    # Mark the pipeline as running.
    with self._status_lock:
        self._is_running = True
        self._stop_event = False

    while self._is_running and not self._stop_event:
        try:
            try:
                data = self._input_queue.get(timeout=self._queue_timeout)
            except Empty:
                # Timed out waiting for input; re-check the flags and wait again.
                continue

            try:
                # None is the end-of-stream signal.
                if data is None:
                    logger.info("收到结束信号,管道准备停止")
                    break
                self._process(data)
            finally:
                # Acknowledge the item whether processing succeeded, failed,
                # or the item was the end signal.
                self._input_queue.task_done()
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"[ASRpipeline]管道处理数据出错: {str(e)}")
            break

    logger.info("[ASRpipeline]管道停止运行")
def _process(self, data: Any) -> Any:
"""
处理数据
参数:
data: 输入数据
返回:
处理结果
"""
# 子类实现具体的处理逻辑
self._subqueue_dict["original"].put(data)
def stop(self) -> None:
"""
停止管道
"""
with self._status_lock:
self._is_running = False
self._stop_event = True
for functor_name, functor in self._functor_dict.items():
logger.info(f"停止{functor_name}functor")
functor.stop()
logger.info("子Functor停止")
# logger.info(f"停止{functor_name}functor")
if functor.stop():
logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止")
else:
logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败")
self._thread.join()
logger.info("[ASRpipeline]管道停止")
return True

View File

@ -1,3 +1,3 @@
from src.pipeline.base import PipelineBase, Pipeline, PipelineFactory
from src.pipeline.base import PipelineBase, PipelineFactory
__all__ = ["PipelineBase", "Pipeline", "PipelineFactory"]
__all__ = ["PipelineBase", "PipelineFactory"]

View File

@ -56,7 +56,7 @@ class PipelineBase(ABC):
logger.error(f"回调函数执行出错: {str(e)}")
@abstractmethod
def process(self, data: Any) -> Any:
def _process(self, data: Any) -> Any:
"""
处理数据
参数:
@ -67,7 +67,7 @@ class PipelineBase(ABC):
pass
@abstractmethod
def run(self) -> None:
def _run(self) -> None:
"""
运行管道
从输入队列获取数据并处理
@ -121,7 +121,7 @@ class PipelineFactory:
用于创建管道实例
"""
@staticmethod
def create_pipeline(pipeline_name: str) -> Pipeline:
def create_pipeline(pipeline_name: str) -> Any:
"""
创建管道实例
"""

View File

@ -1,8 +1,12 @@
from tests.functor.vad_test import test_vad_functor
from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger
setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)
logger.info("开始测试VAD函数器")
test_vad_functor()
# logger.info("开始测试VAD函数器")
# test_vad_functor()
logger.info("开始测试ASR管道")
test_asr_pipeline()

View File

@ -4,7 +4,7 @@ VAD测试
"""
from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SpkFunctor
from src.functor.spk_functor import SPKFunctor
from queue import Queue, Empty
from src.model_loader import ModelLoader
from src.models import AudioBinary_Config, AudioBinary_data_list
@ -84,7 +84,7 @@ def test_vad_functor():
asr_functor.run()
# 创建SPK函数器
spk_functor = SpkFunctor()
spk_functor = SPKFunctor()
# 设置输入队列
spk_functor.set_input_queue(vad2spk_queue)
# 设置音频配置

View File

@ -0,0 +1,80 @@
"""
Pipeline测试
VAD+ASR+SPK(FAKE)
"""
from src.pipeline.ASRpipeline import ASRPipeline
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader
from queue import Queue
import soundfile
import time
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
OVAERWATCH = False
model_loader = ModelLoader()
def test_asr_pipeline():
    """
    Run the ASR pipeline (VAD + ASR + SPK) on the bundled example recording.

    Loads the models, builds the audio configuration from the sample rate of
    "tests/vad_example.wav", bakes and starts the pipeline, feeds the audio
    in fixed-size clips, waits five seconds and then stops the pipeline.
    Results are printed by the registered callback; nothing is asserted.
    """
    # Load models.
    model_args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
        "audio_update": False,
    }
    models = model_loader.load_models(model_args)

    audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
    audio_config = AudioBinary_Config(
        chunk_size=200,
        chunk_stride=1600,
        sample_rate=sample_rate,
        sample_width=16,
        channels=1,
    )
    # Overwrite the placeholder stride with chunk_size * sample_rate / 1000.
    audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)

    # Configuration dict handed to the pipeline.
    config = {"audio_config": audio_config}

    # Audio data list and the queue the pipeline reads from.
    audio_binary_data_list = AudioBinary_data_list()
    input_queue = Queue()

    # Create and configure the pipeline.
    asr_pipeline = ASRPipeline()
    asr_pipeline.set_models(models)
    asr_pipeline.set_config(config)
    asr_pipeline.set_audio_binary(audio_binary_data_list)
    asr_pipeline.set_input_queue(input_queue)
    asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
    asr_pipeline.bake()

    # Start the pipeline; the returned thread is not needed because the test
    # stops the pipeline explicitly instead of joining on it.
    asr_pipeline.run()

    audio_clip_len = 200
    print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}")

    for start in range(0, len(audio_data), audio_clip_len):
        input_queue.put(audio_data[start:start + audio_clip_len])

    # NOTE: an alternative shutdown (put None on the queue, then join the
    # returned thread) was left disabled in the original test.
    time.sleep(5)
    asr_pipeline.stop()