[代码重构中]初步构建ASRFunctor,与VADFunctor在vad_test.py中进行联调无问题,数据衔接正常。

This commit is contained in:
Ziyang.Zhang 2025-06-05 15:57:11 +08:00
parent 4e9e94d8dc
commit ff9bd70039
3 changed files with 267 additions and 36 deletions

159
src/functor/asr_functor.py Normal file
View File

@ -0,0 +1,159 @@
"""
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result
from typing import Callable, List
from queue import Queue, Empty
import threading
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ASRFunctor(BaseFunctor):
"""
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config
否则无法run()启动线程
运行中, 使用reset_cache()重置缓存, 准备下次任务
使用stop()停止线程, 但需要等待input_queue为空
"""
def __init__(self) -> None:
super().__init__()
# 资源与配置
self._model: dict = {} # 模型
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
# flag
self._is_running: bool = False
self._stop_event: bool = False
# 线程资源
self._thread: threading.Thread = None
# 状态锁
self._status_lock: threading.Lock = threading.Lock()
# 缓存
self._hotwords: List[str] = []
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
pass
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None:
"""
回调函数
"""
text = result[0]['text'].replace(" ", "")
for callback in self._callback:
callback(text)
def _process(self, data: VAD_Functor_result) -> None:
"""
处理数据
"""
binary_data = data.audiobinary_data.binary_data
result = self._model["asr"].generate(
input=binary_data,
chunk_size=self._audio_config.chunk_size,
hotwords=self._hotwords,
)
self._do_callback(result)
def _run(self):
"""
线程运行逻辑
"""
with self._status_lock:
self._is_running = True
self._stop_event = False
# 运行逻辑
while self._is_running:
try:
data = self._input_queue.get(True, timeout=1)
self._process(data)
self._input_queue.task_done()
# 当队列为空时, 间隔1s检测是否进入停止事件。
except Empty:
if self._stop_event:
break
continue
# 其他异常
except Exception as e:
logger.error("ASRFunctor运行时发生错误: %s", e)
raise e
def run(self):
"""
启动线程
"""
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
预检查
"""
if self._model is None:
raise ValueError("模型未设置")
if self._audio_config is None:
raise ValueError("音频配置未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self):
"""
停止线程
"""
with self._status_lock:
self._stop_event = True
self._thread.join()
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

View File

@ -1,11 +1,17 @@
from funasr import AutoModel """
from typing import List, Dict, Any, Callable VADFunctor
from src.models import VAD_Functor_result, _AudioBinary_data, AudioBinary_Config, AudioBinary_data_list 负责对音频片段进行VAD处理, 以VAD_Result进行callback
from typing import Callable """
from src.functor.base import BaseFunctor
import threading import threading
from queue import Empty, Queue from queue import Empty, Queue
from typing import List, Any, Callable
import numpy import numpy
from src.models import (
VAD_Functor_result,
AudioBinary_Config,
AudioBinary_data_list,
)
from src.functor.base import BaseFunctor
# 日志 # 日志
from src.utils.logger import get_module_logger from src.utils.logger import get_module_logger
@ -14,9 +20,18 @@ logger = get_module_logger(__name__)
class VADFunctor(BaseFunctor): class VADFunctor(BaseFunctor):
def __init__( """
self VADFunctor
): 负责对音频片段进行VAD处理, 以VAD_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
否则无法run()启动线程
运行中, 使用reset_cache()重置缓存, 准备下次任务
使用stop()停止线程, 但需要等待input_queue为空
"""
def __init__(self) -> None:
super().__init__() super().__init__()
# 资源与配置 # 资源与配置
self._model: dict = {} # 模型 self._model: dict = {} # 模型
@ -30,6 +45,9 @@ class VADFunctor(BaseFunctor):
self._is_running: bool = False self._is_running: bool = False
self._stop_event: bool = False self._stop_event: bool = False
# 线程资源
self._thread: threading.Thread = None
# 状态锁 # 状态锁
self._status_lock: threading.Lock = threading.Lock() self._status_lock: threading.Lock = threading.Lock()
@ -40,7 +58,7 @@ class VADFunctor(BaseFunctor):
self._cache_result_list = [] self._cache_result_list = []
self._audiobinary_cache = None self._audiobinary_cache = None
def reset_cache(self): def reset_cache(self) -> None:
""" """
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
""" """
@ -50,25 +68,46 @@ class VADFunctor(BaseFunctor):
self._cache_result_list = [] self._cache_result_list = []
self._audiobinary_cache = None self._audiobinary_cache = None
def set_input_queue(self, queue: Queue): def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue self._input_queue = queue
def set_model(self, model: dict): def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config): def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config self._audio_config = audio_config
logger.info(f"VADFunctor设置音频配置: {self._audio_config}") logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list): def set_audio_binary_data_list(
self, audio_binary_data_list: AudioBinary_data_list
) -> None:
"""
设置音频数据列表, 为Class AudioBinary_data_list类型
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
"""
self._audio_binary_data_list = audio_binary_data_list self._audio_binary_data_list = audio_binary_data_list
def add_callback(self, callback: Callable): def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list): if not isinstance(self._callback, list):
self._callback = [] self._callback = []
self._callback.append(callback) self._callback.append(callback)
def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list): def _do_callback(
self, result: List[List[int]]
) -> None:
""" """
回调函数 回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
@ -101,7 +140,7 @@ class VADFunctor(BaseFunctor):
audiobinary_data_list=self._audio_binary_data_list, audiobinary_data_list=self._audio_binary_data_list,
data=self._audiobinary_cache[start_frame:end_frame], data=self._audiobinary_cache[start_frame:end_frame],
start_time=self._cache_result_list[0][0], start_time=self._cache_result_list[0][0],
end_time=self._cache_result_list[0][1] end_time=self._cache_result_list[0][1],
) )
self._audio_cache_preindex += end_frame self._audio_cache_preindex += end_frame
self._audiobinary_cache = self._audiobinary_cache[end_frame:] self._audiobinary_cache = self._audiobinary_cache[end_frame:]
@ -109,9 +148,9 @@ class VADFunctor(BaseFunctor):
callback(vad_result) callback(vad_result)
self._cache_result_list.pop(0) self._cache_result_list.pop(0)
def _predeal_data(self, data: Any): def _predeal_data(self, data: Any) -> None:
""" """
预处理数据 预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
""" """
if self._audio_cache is None: if self._audio_cache is None:
self._audio_cache = data self._audio_cache = data
@ -126,27 +165,30 @@ class VADFunctor(BaseFunctor):
else: else:
# 拼接音频数据 # 拼接音频数据
if isinstance(self._audiobinary_cache, numpy.ndarray): if isinstance(self._audiobinary_cache, numpy.ndarray):
self._audiobinary_cache = numpy.concatenate((self._audiobinary_cache, data)) self._audiobinary_cache = numpy.concatenate(
(self._audiobinary_cache, data)
)
elif isinstance(self._audiobinary_cache, list): elif isinstance(self._audiobinary_cache, list):
self._audiobinary_cache.append(data) self._audiobinary_cache.append(data)
def _process(self, data: Any): def _process(self, data: Any):
""" """
处理数据 处理数据
使用model进行生成, 并使用_do_callback进行回调
""" """
self._predeal_data(data) self._predeal_data(data)
if len(self._audio_cache) >= self._audio_config.chunk_stride: if len(self._audio_cache) >= self._audio_config.chunk_stride:
result = self._model['vad'].generate( result = self._model["vad"].generate(
input=self._audio_cache, input=self._audio_cache,
cache=self._model_cache, cache=self._model_cache,
chunk_size=self._audio_config.chunk_size, chunk_size=self._audio_config.chunk_size,
is_final=False, is_final=False,
) )
if (len(result[0]['value']) > 0): if len(result[0]["value"]) > 0:
self._do_callback(result[0]['value'], self._audio_cache) self._do_callback(result[0]["value"])
logger.debug(f"VADFunctor结果: {result[0]['value']}") # logger.debug(f"VADFunctor结果: {result[0]['value']}")
self._audio_cache = None self._audio_cache = None
def _run(self): def _run(self):
""" """
线程运行逻辑 线程运行逻辑
@ -170,7 +212,7 @@ class VADFunctor(BaseFunctor):
continue continue
# 其他异常 # 其他异常
except Exception as e: except Exception as e:
logger.error(f"VADFunctor运行时发生错误: {e}") logger.error("VADFunctor运行时发生错误: %s", e)
raise e raise e
def run(self): def run(self):
@ -199,12 +241,16 @@ class VADFunctor(BaseFunctor):
return True return True
def stop(self): def stop(self):
"""
停止线程
通过设置_stop_event为True, 来在input_queue.get()循环为空时退出
"""
with self._status_lock: with self._status_lock:
self._stop_event = True self._stop_event = True
self._thread.join() self._thread.join()
with self._status_lock: with self._status_lock:
self._is_running = False self._is_running = False
return True return not self._thread.is_alive()
# class VAD: # class VAD:

View File

@ -3,6 +3,7 @@ Functor测试
VAD测试 VAD测试
""" """
from src.functor.vad_functor import VADFunctor from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from queue import Queue, Empty from queue import Queue, Empty
from src.model_loader import ModelLoader from src.model_loader import ModelLoader
from src.models import AudioBinary_Config, AudioBinary_data_list from src.models import AudioBinary_Config, AudioBinary_data_list
@ -22,6 +23,8 @@ model_loader = ModelLoader()
def test_vad_functor(): def test_vad_functor():
# 加载模型 # 加载模型
args = { args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad", "vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4", "vad_model_revision": "v2.0.4",
"auto_update": False, "auto_update": False,
@ -40,6 +43,7 @@ def test_vad_functor():
audio_config.chunk_stride = chunk_stride audio_config.chunk_stride = chunk_stride
# 创建输入队列 # 创建输入队列
input_queue = Queue() input_queue = Queue()
vad2asr_queue = Queue()
# 创建音频数据列表 # 创建音频数据列表
audio_binary_data_list = AudioBinary_data_list() audio_binary_data_list = AudioBinary_data_list()
@ -52,14 +56,30 @@ def test_vad_functor():
# 设置音频数据列表 # 设置音频数据列表
vad_functor.set_audio_binary_data_list(audio_binary_data_list) vad_functor.set_audio_binary_data_list(audio_binary_data_list)
# 设置回调函数 # 设置回调函数
vad_functor.add_callback(lambda x: print(f"callback: {x}")) vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
# 设置模型 # 设置模型
vad_functor.set_model({ vad_functor.set_model({
'vad': model_loader.models['vad'] 'vad': model_loader.models['vad']
}) })
# 启动VAD函数器 # 启动VAD函数器
vad_functor.run() vad_functor.run()
# 创建ASR函数器
asr_functor = ASRFunctor()
# 设置输入队列
asr_functor.set_input_queue(vad2asr_queue)
# 设置音频配置
asr_functor.set_audio_config(audio_config)
# 设置回调函数
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
# 设置模型
asr_functor.set_model({
'asr': model_loader.models['asr']
})
# 启动ASR函数器
asr_functor.run()
f_binary = f_data f_binary = f_data
audio_clip_len = 200 audio_clip_len = 200
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}") print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
@ -73,6 +93,12 @@ def test_vad_functor():
vad_functor.stop() vad_functor.stop()
print("[vad_test] VAD函数器结束") print("[vad_test] VAD函数器结束")
print("[vad_test] 等待vad2asr_queue为空")
vad2asr_queue.join()
print("[vad_test] vad2asr_queue为空")
asr_functor.stop()
print("[vad_test] ASR函数器结束")
# 保存音频数据 # 保存音频数据
if OVERWATCH: if OVERWATCH:
for index in range(len(audio_binary_data_list)): for index in range(len(audio_binary_data_list)):