[代码重构中]初步构建ASRFunctor,与VADFunctor在vad_test.py中进行联调无问题,数据衔接正常。

This commit is contained in:
Ziyang.Zhang 2025-06-05 15:57:11 +08:00
parent 4e9e94d8dc
commit ff9bd70039
3 changed files with 267 additions and 36 deletions

159
src/functor/asr_functor.py Normal file
View File

@ -0,0 +1,159 @@
"""
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result
from typing import Callable, List
from queue import Queue, Empty
import threading
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ASRFunctor(BaseFunctor):
"""
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config
否则无法run()启动线程
运行中, 使用reset_cache()重置缓存, 准备下次任务
使用stop()停止线程, 但需要等待input_queue为空
"""
def __init__(self) -> None:
super().__init__()
# 资源与配置
self._model: dict = {} # 模型
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
# flag
self._is_running: bool = False
self._stop_event: bool = False
# 线程资源
self._thread: threading.Thread = None
# 状态锁
self._status_lock: threading.Lock = threading.Lock()
# 缓存
self._hotwords: List[str] = []
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
pass
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None:
"""
回调函数
"""
text = result[0]['text'].replace(" ", "")
for callback in self._callback:
callback(text)
def _process(self, data: VAD_Functor_result) -> None:
"""
处理数据
"""
binary_data = data.audiobinary_data.binary_data
result = self._model["asr"].generate(
input=binary_data,
chunk_size=self._audio_config.chunk_size,
hotwords=self._hotwords,
)
self._do_callback(result)
def _run(self):
"""
线程运行逻辑
"""
with self._status_lock:
self._is_running = True
self._stop_event = False
# 运行逻辑
while self._is_running:
try:
data = self._input_queue.get(True, timeout=1)
self._process(data)
self._input_queue.task_done()
# 当队列为空时, 间隔1s检测是否进入停止事件。
except Empty:
if self._stop_event:
break
continue
# 其他异常
except Exception as e:
logger.error("ASRFunctor运行时发生错误: %s", e)
raise e
def run(self):
"""
启动线程
"""
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
预检查
"""
if self._model is None:
raise ValueError("模型未设置")
if self._audio_config is None:
raise ValueError("音频配置未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self):
"""
停止线程
"""
with self._status_lock:
self._stop_event = True
self._thread.join()
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

View File

@ -1,11 +1,17 @@
from funasr import AutoModel
from typing import List, Dict, Any, Callable
from src.models import VAD_Functor_result, _AudioBinary_data, AudioBinary_Config, AudioBinary_data_list
from typing import Callable
from src.functor.base import BaseFunctor
"""
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
"""
import threading
from queue import Empty, Queue
from typing import List, Any, Callable
import numpy
from src.models import (
VAD_Functor_result,
AudioBinary_Config,
AudioBinary_data_list,
)
from src.functor.base import BaseFunctor
# 日志
from src.utils.logger import get_module_logger
@ -14,22 +20,34 @@ logger = get_module_logger(__name__)
class VADFunctor(BaseFunctor):
def __init__(
self
):
"""
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
否则无法run()启动线程
运行中, 使用reset_cache()重置缓存, 准备下次任务
使用stop()停止线程, 但需要等待input_queue为空
"""
def __init__(self) -> None:
super().__init__()
# 资源与配置
self._model: dict = {} # 模型
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
self._model: dict = {} # 模型
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
# flag
# 此处用到两个锁但都是为了截断_run线程考虑后续优化
self._is_running: bool = False
self._stop_event: bool = False
# 线程资源
self._thread: threading.Thread = None
# 状态锁
self._status_lock: threading.Lock = threading.Lock()
@ -39,8 +57,8 @@ class VADFunctor(BaseFunctor):
self._model_cache: dict = {}
self._cache_result_list = []
self._audiobinary_cache = None
def reset_cache(self):
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
@ -50,25 +68,46 @@ class VADFunctor(BaseFunctor):
self._cache_result_list = []
self._audiobinary_cache = None
def set_input_queue(self, queue: Queue):
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict):
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config):
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config
logger.info(f"VADFunctor设置音频配置: {self._audio_config}")
logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list):
def set_audio_binary_data_list(
self, audio_binary_data_list: AudioBinary_data_list
) -> None:
"""
设置音频数据列表, 为Class AudioBinary_data_list类型
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
"""
self._audio_binary_data_list = audio_binary_data_list
def add_callback(self, callback: Callable):
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list):
def _do_callback(
self, result: List[List[int]]
) -> None:
"""
回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
@ -91,7 +130,7 @@ class VADFunctor(BaseFunctor):
while len(self._cache_result_list) > 1:
# 创建VAD片段
# 计算开始帧
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
start_frame -= self._audio_cache_preindex
# 计算结束帧
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
@ -101,7 +140,7 @@ class VADFunctor(BaseFunctor):
audiobinary_data_list=self._audio_binary_data_list,
data=self._audiobinary_cache[start_frame:end_frame],
start_time=self._cache_result_list[0][0],
end_time=self._cache_result_list[0][1]
end_time=self._cache_result_list[0][1],
)
self._audio_cache_preindex += end_frame
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
@ -109,9 +148,9 @@ class VADFunctor(BaseFunctor):
callback(vad_result)
self._cache_result_list.pop(0)
def _predeal_data(self, data: Any):
def _predeal_data(self, data: Any) -> None:
"""
预处理数据
预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
"""
if self._audio_cache is None:
self._audio_cache = data
@ -126,26 +165,29 @@ class VADFunctor(BaseFunctor):
else:
# 拼接音频数据
if isinstance(self._audiobinary_cache, numpy.ndarray):
self._audiobinary_cache = numpy.concatenate((self._audiobinary_cache, data))
self._audiobinary_cache = numpy.concatenate(
(self._audiobinary_cache, data)
)
elif isinstance(self._audiobinary_cache, list):
self._audiobinary_cache.append(data)
def _process(self, data: Any):
"""
处理数据
使用model进行生成, 并使用_do_callback进行回调
"""
self._predeal_data(data)
if len(self._audio_cache) >= self._audio_config.chunk_stride:
result = self._model['vad'].generate(
result = self._model["vad"].generate(
input=self._audio_cache,
cache=self._model_cache,
chunk_size=self._audio_config.chunk_size,
is_final=False,
)
if (len(result[0]['value']) > 0):
self._do_callback(result[0]['value'], self._audio_cache)
logger.debug(f"VADFunctor结果: {result[0]['value']}")
if len(result[0]["value"]) > 0:
self._do_callback(result[0]["value"])
# logger.debug(f"VADFunctor结果: {result[0]['value']}")
self._audio_cache = None
def _run(self):
"""
@ -170,7 +212,7 @@ class VADFunctor(BaseFunctor):
continue
# 其他异常
except Exception as e:
logger.error(f"VADFunctor运行时发生错误: {e}")
logger.error("VADFunctor运行时发生错误: %s", e)
raise e
def run(self):
@ -199,12 +241,16 @@ class VADFunctor(BaseFunctor):
return True
def stop(self):
"""
停止线程
通过设置_stop_event为True, 来在input_queue.get()循环为空时退出
"""
with self._status_lock:
self._stop_event = True
self._thread.join()
with self._status_lock:
self._is_running = False
return True
return not self._thread.is_alive()
# class VAD:

View File

@ -3,6 +3,7 @@ Functor测试
VAD测试
"""
from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from queue import Queue, Empty
from src.model_loader import ModelLoader
from src.models import AudioBinary_Config, AudioBinary_data_list
@ -22,6 +23,8 @@ model_loader = ModelLoader()
def test_vad_functor():
# 加载模型
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"auto_update": False,
@ -40,6 +43,7 @@ def test_vad_functor():
audio_config.chunk_stride = chunk_stride
# 创建输入队列
input_queue = Queue()
vad2asr_queue = Queue()
# 创建音频数据列表
audio_binary_data_list = AudioBinary_data_list()
@ -52,14 +56,30 @@ def test_vad_functor():
# 设置音频数据列表
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
# 设置回调函数
vad_functor.add_callback(lambda x: print(f"callback: {x}"))
vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
# 设置模型
vad_functor.set_model({
'vad': model_loader.models['vad']
})
# 启动VAD函数器
vad_functor.run()
# 创建ASR函数器
asr_functor = ASRFunctor()
# 设置输入队列
asr_functor.set_input_queue(vad2asr_queue)
# 设置音频配置
asr_functor.set_audio_config(audio_config)
# 设置回调函数
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
# 设置模型
asr_functor.set_model({
'asr': model_loader.models['asr']
})
# 启动ASR函数器
asr_functor.run()
f_binary = f_data
audio_clip_len = 200
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
@ -73,6 +93,12 @@ def test_vad_functor():
vad_functor.stop()
print("[vad_test] VAD函数器结束")
print("[vad_test] 等待vad2asr_queue为空")
vad2asr_queue.join()
print("[vad_test] vad2asr_queue为空")
asr_functor.stop()
print("[vad_test] ASR函数器结束")
# 保存音频数据
if OVERWATCH:
for index in range(len(audio_binary_data_list)):