[代码重构中]编写融合VAD,ASR,SPK(FAKE)的ASRPipeline并完成测试,正常运行。
This commit is contained in:
parent
3d8bf9de25
commit
5b94c40016
@ -110,6 +110,8 @@ class BaseFunctor(ABC):
|
||||
停止结果
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class FunctorFactory:
|
||||
"""
|
||||
Functor工厂类
|
||||
@ -121,8 +123,8 @@ class FunctorFactory:
|
||||
创建并配置Functor实例
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||
@classmethod
|
||||
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建并配置Functor实例
|
||||
|
||||
@ -135,3 +137,59 @@ class FunctorFactory:
|
||||
BaseFunctor: 创建的Functor实例
|
||||
"""
|
||||
|
||||
if functor_name == "vad":
|
||||
return cls._make_vadfunctor(config = config,models = models)
|
||||
elif functor_name == "asr":
|
||||
return cls._make_asrfunctor(config = config,models = models)
|
||||
elif functor_name == "spk":
|
||||
return cls._make_spkfunctor(config = config,models = models)
|
||||
else:
|
||||
raise ValueError(f"不支持的Functor类型: {functor_name}")
|
||||
|
||||
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建VAD Functor实例
|
||||
"""
|
||||
from src.functor.vad_functor import VADFunctor
|
||||
audio_config = config["audio_config"]
|
||||
model = {
|
||||
"vad": models["vad"]
|
||||
}
|
||||
|
||||
vad_functor = VADFunctor()
|
||||
vad_functor.set_audio_config(audio_config)
|
||||
vad_functor.set_model(model)
|
||||
|
||||
return vad_functor
|
||||
|
||||
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建ASR Functor实例
|
||||
"""
|
||||
from src.functor.asr_functor import ASRFunctor
|
||||
audio_config = config["audio_config"]
|
||||
model = {
|
||||
"asr": models["asr"]
|
||||
}
|
||||
|
||||
asr_functor = ASRFunctor()
|
||||
asr_functor.set_audio_config(audio_config)
|
||||
asr_functor.set_model(model)
|
||||
|
||||
return asr_functor
|
||||
|
||||
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建SPK Functor实例
|
||||
"""
|
||||
from src.functor.spk_functor import SPKFunctor
|
||||
audio_config = config["audio_config"]
|
||||
model = {
|
||||
"spk": models["spk"]
|
||||
}
|
||||
|
||||
spk_functor = SPKFunctor()
|
||||
spk_functor.set_audio_config(audio_config)
|
||||
spk_functor.set_model(model)
|
||||
|
||||
return spk_functor
|
@ -6,38 +6,117 @@ Functor文件夹用于存放所有功能性的类,包括VAD、PUNC、ASR、SPK
|
||||
|
||||
## Functor 类的定义
|
||||
|
||||
所有类应继承于 **基类** `BaseFunctor` ,应遵从 *压入数据* 与 *数据处理* 解绑。
|
||||
所有类应继承于**基类**`BaseFunctor`。
|
||||
|
||||
为了方便使用,我们对于 **基类** 的定义如下:
|
||||
为了方便使用,我们对于**基类**的定义如下:
|
||||
|
||||
1. 函数内部使用的变量以单下划线开头,预定有 `_data`, `_callback`, `_model`等
|
||||
1. 函数内部使用的变量以单下划线开头,基类中包含:
|
||||
|
||||
* _model: Dict 存放模型相关的配置和实例
|
||||
* _input_queue: Queue 监听的输入消息队列
|
||||
* _thread: Threading.Thread 运行的线程实例
|
||||
* _callback: List[Callable] 回调函数列表
|
||||
* _is_running: bool 线程运行状态标志
|
||||
* _stop_event: bool 停止事件标志
|
||||
* _status_lock: threading.Lock 状态锁,用于线程同步
|
||||
|
||||
2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`。
|
||||
|
||||
3. 定义了
|
||||
3. 基类定义的核心方法:
|
||||
|
||||
`__call__`:可传入`data`,默认调用`push_data`,随后默认调用`process`
|
||||
* `add_callback(callback: Callable)`: 添加结果处理的回调函数
|
||||
* `set_model(model: dict)`: 设置模型配置和实例
|
||||
* `set_input_queue(queue: Queue)`: 设置输入数据队列
|
||||
* `run()`: 启动处理线程(抽象方法)
|
||||
* `stop()`: 停止处理线程(抽象方法)
|
||||
* `_run()`: 线程运行的具体逻辑(抽象方法)
|
||||
* `_pre_check()`: 运行前的预检查(抽象方法)
|
||||
|
||||
`__add__`
|
||||
## 派生类实现要求
|
||||
|
||||
1. 必须实现的抽象方法:
|
||||
* `_pre_check()`:
|
||||
- 检查必要的配置是否完整(如模型、队列等)
|
||||
- 检查运行环境是否满足要求
|
||||
- 返回检查结果
|
||||
|
||||
* `_run()`:
|
||||
- 实现具体的数据处理逻辑
|
||||
- 从 _input_queue 获取输入数据
|
||||
- 使用 _model 进行处理
|
||||
- 通过 _callback 返回处理结果
|
||||
|
||||
* `run()`:
|
||||
- 调用 _pre_check() 进行预检查
|
||||
- 创建并启动处理线程
|
||||
- 设置相关状态标志
|
||||
|
||||
* `stop()`:
|
||||
- 安全停止处理线程
|
||||
- 清理资源
|
||||
- 重置状态标志
|
||||
|
||||
2. 建议实现的方法:
|
||||
* `__str__`: 返回当前实例的状态信息
|
||||
* 错误处理方法:处理运行过程中的异常情况
|
||||
|
||||
## 使用示例
|
||||
|
||||
```python
|
||||
class BaseFunctor:
|
||||
def __init__(self):
|
||||
self._data: dict = {}
|
||||
self._callback: function = null
|
||||
pass
|
||||
class MyFunctor(BaseFunctor):
|
||||
def _pre_check(self):
|
||||
if not self._model or not self._input_queue:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __call__(self, data):
|
||||
result =
|
||||
self._callback(process(data))
|
||||
return self.process(data)
|
||||
def _run(self):
|
||||
while not self._stop_event:
|
||||
try:
|
||||
data = self._input_queue.get(timeout=1.0)
|
||||
result = self._model['my_model'].process(data)
|
||||
for callback in self._callback:
|
||||
callback(result)
|
||||
except Queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"处理错误: {e}")
|
||||
|
||||
def set_callback(self, callback: Callable):
|
||||
self._callback = callback
|
||||
def run(self):
|
||||
if not self._pre_check():
|
||||
raise RuntimeError("预检查失败")
|
||||
|
||||
def push_data():
|
||||
pass
|
||||
with self._status_lock:
|
||||
if self._is_running:
|
||||
return
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
self._thread = threading.Thread(target=self._run)
|
||||
self._thread.start()
|
||||
|
||||
def process(self, data):
|
||||
pass
|
||||
def stop(self):
|
||||
with self._status_lock:
|
||||
if not self._is_running:
|
||||
return
|
||||
self._stop_event = True
|
||||
if self._thread:
|
||||
self._thread.join()
|
||||
self._is_running = False
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 线程安全:
|
||||
* 使用 _status_lock 保护状态变更
|
||||
* 注意共享资源的访问控制
|
||||
|
||||
2. 错误处理:
|
||||
* 在 _run() 中妥善处理异常
|
||||
* 提供详细的错误日志
|
||||
|
||||
3. 资源管理:
|
||||
* 确保在 stop() 中正确清理资源
|
||||
* 避免资源泄露
|
||||
|
||||
4. 回调函数:
|
||||
* 回调函数应该是非阻塞的
|
||||
* 处理回调函数抛出的异常
|
@ -13,9 +13,9 @@ from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
class SpkFunctor(BaseFunctor):
|
||||
class SPKFunctor(BaseFunctor):
|
||||
"""
|
||||
SpkFunctor
|
||||
SPKFunctor
|
||||
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
||||
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||
否则无法run()启动线程
|
||||
|
@ -38,7 +38,7 @@ class ModelLoader:
|
||||
初始化ModelLoader实例
|
||||
"""
|
||||
self.models = {}
|
||||
logger.info("初始化ModelLoader")
|
||||
logger.debug("初始化ModelLoader")
|
||||
if args is not None:
|
||||
self.__call__(args)
|
||||
|
||||
@ -87,14 +87,14 @@ class ModelLoader:
|
||||
else:
|
||||
value = input_model_args.get(key, None)
|
||||
if value is not None:
|
||||
logger.info("替换%s模型参数: %s = %s", model_type, key, value)
|
||||
logger.debug("替换%s模型参数: %s = %s", model_type, key, value)
|
||||
model_args[key] = value
|
||||
# 验证必要参数
|
||||
if not model_args["model"]:
|
||||
raise ValueError(f"未指定{model_type}模型路径")
|
||||
try:
|
||||
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
|
||||
logger.info("正在加载%s模型: %s", model_type, model_args["model"])
|
||||
logger.debug("正在加载%s模型: %s", model_type, model_args["model"])
|
||||
model = AutoModel(**model_args)
|
||||
return model
|
||||
except Exception as e:
|
||||
@ -115,12 +115,12 @@ class ModelLoader:
|
||||
self.models = {}
|
||||
# 加载离线ASR模型
|
||||
# 检查对应键是否存在
|
||||
model_list = ['asr', 'asr_online', 'vad', 'punc']
|
||||
model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk']
|
||||
for model_name in model_list:
|
||||
name_model = f"{model_name}_model"
|
||||
name_model_revision = f"{model_name}_model_revision"
|
||||
if name_model in args:
|
||||
logger.info("加载%s模型", model_name)
|
||||
logger.debug("加载%s模型", model_name)
|
||||
self.models[model_name] = self._load_model(args, model_name)
|
||||
logger.info("所有模型加载完成")
|
||||
return self.models
|
||||
|
@ -1,7 +1,9 @@
|
||||
from src.pipeline.base import PipelineBase
|
||||
from typing import Dict, Any
|
||||
from queue import Queue
|
||||
from typing import Dict, Any, Callable
|
||||
from queue import Queue, Empty
|
||||
from src.utils import get_module_logger
|
||||
from src.models import AudioBinary_data_list
|
||||
import threading
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
@ -19,6 +21,12 @@ class ASRPipeline(PipelineBase):
|
||||
self._functor_dict: Dict[str, Any] = {}
|
||||
self._subqueue_dict: Dict[str, Any] = {}
|
||||
self._is_baked: bool = False
|
||||
self._input_queue: Queue = None
|
||||
self._audio_binary_data_list: AudioBinary_data_list = None
|
||||
|
||||
self._status_lock = threading.Lock()
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
|
||||
def set_config(self, config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
@ -28,7 +36,15 @@ class ASRPipeline(PipelineBase):
|
||||
"""
|
||||
self._config = config
|
||||
|
||||
def set_audio_binary(self, audio_binary: AudioBinary) -> None:
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取配置
|
||||
返回:
|
||||
Dict[str, Any] 配置
|
||||
"""
|
||||
return self._config
|
||||
|
||||
def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None:
|
||||
"""
|
||||
设置音频二进制存储单元
|
||||
参数:
|
||||
@ -42,44 +58,70 @@ class ASRPipeline(PipelineBase):
|
||||
"""
|
||||
self._models = models
|
||||
|
||||
def set_input_queue(self, input_queue: Queue) -> None:
|
||||
"""
|
||||
设置输入队列
|
||||
"""
|
||||
self._input_queue = input_queue
|
||||
|
||||
def bake(self) -> None:
|
||||
"""
|
||||
烘焙管道
|
||||
"""
|
||||
self._pre_check_resource()
|
||||
self._init_functor()
|
||||
self._is_baked = True
|
||||
|
||||
def _pre_check_resource(self) -> None:
|
||||
"""
|
||||
预检查资源
|
||||
"""
|
||||
if self._input_queue is None:
|
||||
raise RuntimeError("[ASRpipeline]输入队列未设置")
|
||||
if self._functor_dict is None:
|
||||
raise RuntimeError("[ASRpipeline]functor字典未设置")
|
||||
if self._subqueue_dict is None:
|
||||
raise RuntimeError("[ASRpipeline]子队列字典未设置")
|
||||
if self._audio_binary is None:
|
||||
raise RuntimeError("[ASRpipeline]音频存储单元未设置")
|
||||
|
||||
def _init_functor(self) -> None:
|
||||
"""
|
||||
初始化函数
|
||||
"""
|
||||
try:
|
||||
from src.functor import functorFactory
|
||||
from src.functor import FunctorFactory
|
||||
# 加载VAD、asr、spk functor
|
||||
self._functor_dict["vad"] = functorFactory.make_functor(
|
||||
self._functor_dict["vad"] = FunctorFactory.make_functor(
|
||||
functor_name = "vad",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
self._functor_dict["asr"] = functorFactory.make_functor(
|
||||
self._functor_dict["asr"] = FunctorFactory.make_functor(
|
||||
functor_name = "asr",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
self._functor_dict["spk"] = functorFactory.make_functor(
|
||||
self._functor_dict["spk"] = FunctorFactory.make_functor(
|
||||
functor_name = "spk",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
|
||||
# 创建音频数据存储单元
|
||||
self._audio_binary_data_list = AudioBinary_data_list()
|
||||
|
||||
self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list)
|
||||
|
||||
# 初始化子队列
|
||||
self._subqueue_dict["original"] = Queue()
|
||||
self._subqueue_dict["vad2asr"] = Queue()
|
||||
self._subqueue_dict["vad2spk"] = Queue()
|
||||
self._subqueue_dict["asrend"] = Queue()
|
||||
self._subqueue_dict["spkend"] = Queue()
|
||||
|
||||
# 设置子队列的输入队列
|
||||
self._functor_dict["vad"].set_input_queue(self._input_queue)
|
||||
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
|
||||
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||
|
||||
@ -94,7 +136,7 @@ class ASRPipeline(PipelineBase):
|
||||
"""
|
||||
def put_with_check(data: Any) -> None:
|
||||
queue.put(data)
|
||||
callback()
|
||||
callback(data)
|
||||
return put_with_check
|
||||
|
||||
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
|
||||
@ -103,75 +145,6 @@ class ASRPipeline(PipelineBase):
|
||||
except ImportError:
|
||||
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取配置
|
||||
返回:
|
||||
Dict[str, Any] 配置
|
||||
"""
|
||||
return self._config
|
||||
|
||||
def process(self, data: Any) -> Any:
|
||||
"""
|
||||
处理数据
|
||||
参数:
|
||||
data: 输入数据
|
||||
返回:
|
||||
处理结果
|
||||
"""
|
||||
# 子类实现具体的处理逻辑
|
||||
self._input_queue.put(data)
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
运行管道
|
||||
"""
|
||||
if not self._is_baked:
|
||||
raise RuntimeError("管道未烘焙,无法运行")
|
||||
|
||||
# 运行所有functor
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
logger.info(f"运行{functor_name}functor")
|
||||
self._functor_dict[functor_name].run()
|
||||
|
||||
# 运行管道
|
||||
if not self._input_queue:
|
||||
raise RuntimeError("输入队列未设置")
|
||||
|
||||
# 设置管道运行状态
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
self._thread = threading.current_thread()
|
||||
logger.info("ASR管道开始运行")
|
||||
|
||||
while self._is_running and not self._stop_event:
|
||||
try:
|
||||
# 从队列获取数据
|
||||
try:
|
||||
data = self._input_queue.get(timeout=self._queue_timeout)
|
||||
# 检查是否是结束信号
|
||||
if data is None:
|
||||
logger.info("收到结束信号,管道准备停止")
|
||||
self._stop()
|
||||
self._input_queue.task_done() # 标记结束信号已处理
|
||||
break
|
||||
|
||||
# 处理数据
|
||||
self.process(data)
|
||||
|
||||
# 标记任务完成
|
||||
self._input_queue.task_done()
|
||||
|
||||
except Empty:
|
||||
# 队列获取超时,继续等待
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"管道处理数据出错: {str(e)}")
|
||||
continue
|
||||
|
||||
logger.info("管道停止运行")
|
||||
|
||||
def _check_result(self, result: Any) -> None:
|
||||
"""
|
||||
检查结果
|
||||
@ -188,13 +161,96 @@ class ASRPipeline(PipelineBase):
|
||||
# 通知回调函数
|
||||
self._notify_callbacks(result)
|
||||
|
||||
def run(self) -> threading.Thread:
|
||||
"""
|
||||
运行管道
|
||||
Returns:
|
||||
threading.Thread: 返回已运行线程实例
|
||||
"""
|
||||
# 检查运行资源是否准备完毕
|
||||
self._pre_check()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("[ASRpipeline]管道开始运行")
|
||||
return self._thread
|
||||
|
||||
def _pre_check(self) -> None:
|
||||
"""
|
||||
预检查
|
||||
"""
|
||||
if self._is_baked is False:
|
||||
raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行")
|
||||
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
if functor is None:
|
||||
raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常")
|
||||
|
||||
for subqueue_name, subqueue in self._subqueue_dict.items():
|
||||
if subqueue is None:
|
||||
raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常")
|
||||
|
||||
def _run(self) -> None:
|
||||
"""
|
||||
真实的运行逻辑
|
||||
"""
|
||||
# 运行所有functor
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
logger.info(f"[ASRpipeline]运行{functor_name}functor")
|
||||
self._functor_dict[functor_name].run()
|
||||
|
||||
# 设置管道运行状态
|
||||
with self._status_lock:
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
|
||||
while self._is_running and not self._stop_event:
|
||||
try:
|
||||
data = self._input_queue.get(timeout=self._queue_timeout)
|
||||
# 检查是否是结束信号
|
||||
if data is None:
|
||||
logger.info("收到结束信号,管道准备停止")
|
||||
self._input_queue.task_done() # 标记结束信号已处理
|
||||
break
|
||||
|
||||
# 处理数据
|
||||
self._process(data)
|
||||
|
||||
# 标记任务完成
|
||||
self._input_queue.task_done()
|
||||
|
||||
except Empty:
|
||||
# 队列获取超时,继续等待
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"[ASRpipeline]管道处理数据出错: {str(e)}")
|
||||
break
|
||||
|
||||
logger.info("[ASRpipeline]管道停止运行")
|
||||
|
||||
def _process(self, data: Any) -> Any:
|
||||
"""
|
||||
处理数据
|
||||
参数:
|
||||
data: 输入数据
|
||||
返回:
|
||||
处理结果
|
||||
"""
|
||||
# 子类实现具体的处理逻辑
|
||||
self._subqueue_dict["original"].put(data)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
停止管道
|
||||
"""
|
||||
self._is_running = False
|
||||
self._stop_event = True
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
logger.info(f"停止{functor_name}functor")
|
||||
functor.stop()
|
||||
logger.info("子Functor停止")
|
||||
with self._status_lock:
|
||||
self._is_running = False
|
||||
self._stop_event = True
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
# logger.info(f"停止{functor_name}functor")
|
||||
if functor.stop():
|
||||
logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止")
|
||||
else:
|
||||
logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败")
|
||||
self._thread.join()
|
||||
logger.info("[ASRpipeline]管道停止")
|
||||
return True
|
||||
|
@ -1,3 +1,3 @@
|
||||
from src.pipeline.base import PipelineBase, Pipeline, PipelineFactory
|
||||
from src.pipeline.base import PipelineBase, PipelineFactory
|
||||
|
||||
__all__ = ["PipelineBase", "Pipeline", "PipelineFactory"]
|
||||
__all__ = ["PipelineBase", "PipelineFactory"]
|
||||
|
@ -56,7 +56,7 @@ class PipelineBase(ABC):
|
||||
logger.error(f"回调函数执行出错: {str(e)}")
|
||||
|
||||
@abstractmethod
|
||||
def process(self, data: Any) -> Any:
|
||||
def _process(self, data: Any) -> Any:
|
||||
"""
|
||||
处理数据
|
||||
参数:
|
||||
@ -67,7 +67,7 @@ class PipelineBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(self) -> None:
|
||||
def _run(self) -> None:
|
||||
"""
|
||||
运行管道
|
||||
从输入队列获取数据并处理
|
||||
@ -121,7 +121,7 @@ class PipelineFactory:
|
||||
用于创建管道实例
|
||||
"""
|
||||
@staticmethod
|
||||
def create_pipeline(pipeline_name: str) -> Pipeline:
|
||||
def create_pipeline(pipeline_name: str) -> Any:
|
||||
"""
|
||||
创建管道实例
|
||||
"""
|
||||
|
@ -1,8 +1,12 @@
|
||||
from tests.functor.vad_test import test_vad_functor
|
||||
from tests.pipeline.asr_test import test_asr_pipeline
|
||||
from src.utils.logger import get_module_logger, setup_root_logger
|
||||
|
||||
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
logger.info("开始测试VAD函数器")
|
||||
test_vad_functor()
|
||||
# logger.info("开始测试VAD函数器")
|
||||
# test_vad_functor()
|
||||
|
||||
logger.info("开始测试ASR管道")
|
||||
test_asr_pipeline()
|
||||
|
@ -4,7 +4,7 @@ VAD测试
|
||||
"""
|
||||
from src.functor.vad_functor import VADFunctor
|
||||
from src.functor.asr_functor import ASRFunctor
|
||||
from src.functor.spk_functor import SpkFunctor
|
||||
from src.functor.spk_functor import SPKFunctor
|
||||
from queue import Queue, Empty
|
||||
from src.model_loader import ModelLoader
|
||||
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||
@ -84,7 +84,7 @@ def test_vad_functor():
|
||||
asr_functor.run()
|
||||
|
||||
# 创建SPK函数器
|
||||
spk_functor = SpkFunctor()
|
||||
spk_functor = SPKFunctor()
|
||||
# 设置输入队列
|
||||
spk_functor.set_input_queue(vad2spk_queue)
|
||||
# 设置音频配置
|
||||
|
80
tests/pipeline/asr_test.py
Normal file
80
tests/pipeline/asr_test.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
Pipeline测试
|
||||
VAD+ASR+SPK(FAKE)
|
||||
"""
|
||||
from src.pipeline.ASRpipeline import ASRPipeline
|
||||
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||
from src.model_loader import ModelLoader
|
||||
from queue import Queue
|
||||
import soundfile
|
||||
import time
|
||||
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
OVAERWATCH = False
|
||||
|
||||
model_loader = ModelLoader()
|
||||
|
||||
def test_asr_pipeline():
|
||||
# 加载模型
|
||||
args = {
|
||||
"asr_model": "paraformer-zh",
|
||||
"asr_model_revision": "v2.0.4",
|
||||
"vad_model": "fsmn-vad",
|
||||
"vad_model_revision": "v2.0.4",
|
||||
"spk_model": "cam++",
|
||||
"spk_model_revision": "v2.0.2",
|
||||
"audio_update": False,
|
||||
}
|
||||
models = model_loader.load_models(args)
|
||||
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||
audio_config = AudioBinary_Config(
|
||||
chunk_size=200,
|
||||
chunk_stride=1600,
|
||||
sample_rate=sample_rate,
|
||||
sample_width=16,
|
||||
channels=1
|
||||
)
|
||||
chunk_stride = int(audio_config.chunk_size*sample_rate/1000)
|
||||
audio_config.chunk_stride = chunk_stride
|
||||
|
||||
# 创建参数Dict
|
||||
config = {
|
||||
"audio_config": audio_config,
|
||||
}
|
||||
|
||||
# 创建音频数据列表
|
||||
audio_binary_data_list = AudioBinary_data_list()
|
||||
|
||||
input_queue = Queue()
|
||||
|
||||
# 创建Pipeline
|
||||
asr_pipeline = ASRPipeline()
|
||||
asr_pipeline.set_models(models)
|
||||
asr_pipeline.set_config(config)
|
||||
asr_pipeline.set_audio_binary(audio_binary_data_list)
|
||||
asr_pipeline.set_input_queue(input_queue)
|
||||
asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
|
||||
asr_pipeline.bake()
|
||||
|
||||
# 运行Pipeline
|
||||
asr_instance = asr_pipeline.run()
|
||||
|
||||
audio_clip_len = 200
|
||||
print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}")
|
||||
for i in range(0, len(audio_data), audio_clip_len):
|
||||
input_queue.put(audio_data[i:i+audio_clip_len])
|
||||
|
||||
# time.sleep(10)
|
||||
# input_queue.put(None)
|
||||
|
||||
# 等待Pipeline结束
|
||||
# asr_instance.join()
|
||||
|
||||
time.sleep(5)
|
||||
asr_pipeline.stop()
|
||||
# asr_pipeline.stop()
|
||||
|
Loading…
x
Reference in New Issue
Block a user