class ASRFunctor(BaseFunctor):
    """
    Run ASR on audio segments taken from a queue and pass the text to callbacks.

    Before ``run()`` can start the worker thread the following must be set:
    the model dict (with an ``"asr"`` entry), at least one callback, the
    input queue and the audio configuration.

    ``reset_cache()`` is the hook for clearing per-task state between tasks.
    ``stop()`` asks the worker to exit; the worker only honours the request
    once the input queue has been found empty.
    """

    def __init__(self) -> None:
        super().__init__()
        # Resources and configuration
        self._model: dict = {}  # model registry; the "asr" entry is used for inference
        self._callback: List[Callable] = []  # callbacks that receive the recognised text
        self._input_queue: Queue = None  # queue the worker thread consumes from
        self._audio_config: AudioBinary_Config = None  # audio configuration

        # Flags
        self._is_running: bool = False  # True while the worker loop is active
        self._stop_event: bool = False  # set by stop(); checked when the queue is empty

        # Worker thread handle (None until run() has been called)
        self._thread: threading.Thread = None

        # Guards the flags and the thread handle
        self._status_lock: threading.Lock = threading.Lock()

        # Cache
        self._hotwords: List[str] = []

    def reset_cache(self) -> None:
        """
        Reset cached state after a task so the functor is ready for the next one.

        Currently a no-op.
        """
        pass

    def set_input_queue(self, queue: Queue) -> None:
        """
        Set the input queue the worker thread listens on.
        """
        self._input_queue = queue

    def set_model(self, model: dict) -> None:
        """
        Set the model dict; ``_process`` reads its ``"asr"`` entry.
        """
        self._model = model

    def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
        """
        Set the audio configuration (its ``chunk_size`` is passed to the model).
        """
        self._audio_config = audio_config
        logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)

    def add_callback(self, callback: Callable) -> None:
        """
        Append ``callback`` to the list of callbacks invoked with each result.
        """
        # Recover if the attribute was overwritten with something that is not a list.
        if not isinstance(self._callback, list):
            self._callback = []
        self._callback.append(callback)

    def _do_callback(self, result: List[dict]) -> None:
        """
        Extract the text of the first result entry and pass it to every callback.

        Spaces are stripped from the text before it is handed over. An empty
        ``result`` is ignored instead of raising ``IndexError``.
        """
        if not result:
            logger.debug("ASRFunctor received an empty result; callbacks skipped")
            return
        text = result[0]["text"].replace(" ", "")
        for callback in self._callback:
            callback(text)

    def _process(self, data: VAD_Functor_result) -> None:
        """
        Run the ASR model on one VAD segment and dispatch the result.
        """
        binary_data = data.audiobinary_data.binary_data
        # NOTE(review): FunASR examples pass `hotword` as a single string;
        # confirm that the `hotwords` keyword with a list is what the model expects.
        result = self._model["asr"].generate(
            input=binary_data,
            chunk_size=self._audio_config.chunk_size,
            hotwords=self._hotwords,
        )
        self._do_callback(result)

    def _run(self):
        """
        Worker loop: consume segments from the input queue until asked to stop.

        The stop flag is only checked when the queue has been empty for one
        second, so everything already queued is processed before exiting.
        An exception raised while processing is logged and re-raised, which
        terminates the worker thread.
        """
        # The stop flag is deliberately NOT reset here: run() resets it before
        # the thread starts, so a stop() issued right after run() is not lost.
        with self._status_lock:
            self._is_running = True
        try:
            while self._is_running:
                try:
                    data = self._input_queue.get(True, timeout=1)
                except Empty:
                    # Queue idle for 1s: exit if a stop was requested.
                    if self._stop_event:
                        break
                    continue
                try:
                    self._process(data)
                except Exception as e:
                    logger.error("ASRFunctor运行时发生错误: %s", e)
                    raise
                finally:
                    # Always mark the item done, otherwise Queue.join() on the
                    # producer side would block forever after a failure.
                    self._input_queue.task_done()
        finally:
            # Keep the flag truthful even when the loop dies from an exception.
            with self._status_lock:
                self._is_running = False

    def run(self):
        """
        Validate the configuration and start the worker thread.

        Returns:
            The worker ``threading.Thread``. If a worker is already alive it is
            returned as-is instead of starting a second consumer.

        Raises:
            ValueError: if a required resource has not been configured.
        """
        self._pre_check()
        with self._status_lock:
            if self._thread is not None and self._thread.is_alive():
                return self._thread
            # Reset the flags before the thread exists so that a later stop()
            # can never be overwritten by the worker's own start-up.
            self._stop_event = False
            self._is_running = True
            self._thread = threading.Thread(target=self._run, daemon=True)
            self._thread.start()
        return self._thread

    def _pre_check(self) -> bool:
        """
        Check that everything needed by the worker has been configured.

        Returns:
            True when all checks pass.

        Raises:
            ValueError: if the model (including its ``"asr"`` entry), the audio
                configuration, the input queue or the callbacks are missing.
        """
        # The defaults are an empty dict / empty list, so test for emptiness
        # rather than only for None.
        if not self._model or "asr" not in self._model:
            raise ValueError("模型未设置")
        if self._audio_config is None:
            raise ValueError("音频配置未设置")
        if self._input_queue is None:
            raise ValueError("输入队列未设置")
        if not self._callback:
            raise ValueError("回调函数未设置")
        return True

    def stop(self):
        """
        Request the worker to stop and wait for it to finish.

        The worker exits the next time it finds the input queue empty.

        Returns:
            True if no worker thread is alive afterwards (also when the
            functor was never started).
        """
        with self._status_lock:
            self._stop_event = True
            thread = self._thread
        if thread is None:
            # stop() before run(): nothing to join.
            with self._status_lock:
                self._is_running = False
            return True
        thread.join()
        with self._status_lock:
            self._is_running = False
        return not thread.is_alive()
+负责对音频片段进行VAD处理, 以VAD_Result进行callback +""" import threading from queue import Empty, Queue +from typing import List, Any, Callable import numpy +from src.models import ( + VAD_Functor_result, + AudioBinary_Config, + AudioBinary_data_list, +) +from src.functor.base import BaseFunctor # 日志 from src.utils.logger import get_module_logger @@ -14,22 +20,34 @@ logger = get_module_logger(__name__) class VADFunctor(BaseFunctor): - def __init__( - self - ): + """ + VADFunctor + 负责对音频片段进行VAD处理, 以VAD_Result进行callback + 需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list + 否则无法run()启动线程 + + 运行中, 使用reset_cache()重置缓存, 准备下次任务 + + 使用stop()停止线程, 但需要等待input_queue为空 + """ + + def __init__(self) -> None: super().__init__() # 资源与配置 - self._model: dict = {} # 模型 - self._callback: List[Callable] = [] # 回调函数 - self._input_queue: Queue = None # 输入队列 - self._audio_config: AudioBinary_Config = None # 音频配置 - self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表 + self._model: dict = {} # 模型 + self._callback: List[Callable] = [] # 回调函数 + self._input_queue: Queue = None # 输入队列 + self._audio_config: AudioBinary_Config = None # 音频配置 + self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表 # flag # 此处用到两个锁,但都是为了截断_run线程,考虑后续优化 self._is_running: bool = False self._stop_event: bool = False + # 线程资源 + self._thread: threading.Thread = None + # 状态锁 self._status_lock: threading.Lock = threading.Lock() @@ -39,8 +57,8 @@ class VADFunctor(BaseFunctor): self._model_cache: dict = {} self._cache_result_list = [] self._audiobinary_cache = None - - def reset_cache(self): + + def reset_cache(self) -> None: """ 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 """ @@ -50,25 +68,46 @@ class VADFunctor(BaseFunctor): self._cache_result_list = [] self._audiobinary_cache = None - def set_input_queue(self, queue: Queue): + def set_input_queue(self, queue: Queue) -> None: + """ + 设置监听的输入消息队列 + """ self._input_queue = queue - def set_model(self, model: dict): + def set_model(self, model: 
dict) -> None: + """ + 设置推理模型 + """ self._model = model - def set_audio_config(self, audio_config: AudioBinary_Config): + def set_audio_config(self, audio_config: AudioBinary_Config) -> None: + """ + 设置音频配置 + """ self._audio_config = audio_config - logger.info(f"VADFunctor设置音频配置: {self._audio_config}") + logger.debug("VADFunctor设置音频配置: %s", self._audio_config) - def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list): + def set_audio_binary_data_list( + self, audio_binary_data_list: AudioBinary_data_list + ) -> None: + """ + 设置音频数据列表, 为Class AudioBinary_data_list类型 + AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型 + _AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型 + """ self._audio_binary_data_list = audio_binary_data_list - def add_callback(self, callback: Callable): + def add_callback(self, callback: Callable) -> None: + """ + 向自身的_callback: List[Callable]回调函数列表中添加回调函数 + """ if not isinstance(self._callback, list): self._callback = [] self._callback.append(callback) - def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list): + def _do_callback( + self, result: List[List[int]] + ) -> None: """ 回调函数 VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice @@ -91,7 +130,7 @@ class VADFunctor(BaseFunctor): while len(self._cache_result_list) > 1: # 创建VAD片段 # 计算开始帧 - start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0]) + start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0]) start_frame -= self._audio_cache_preindex # 计算结束帧 end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1]) @@ -101,7 +140,7 @@ class VADFunctor(BaseFunctor): audiobinary_data_list=self._audio_binary_data_list, data=self._audiobinary_cache[start_frame:end_frame], start_time=self._cache_result_list[0][0], - end_time=self._cache_result_list[0][1] + end_time=self._cache_result_list[0][1], ) self._audio_cache_preindex += end_frame self._audiobinary_cache = 
self._audiobinary_cache[end_frame:] @@ -109,9 +148,9 @@ class VADFunctor(BaseFunctor): callback(vad_result) self._cache_result_list.pop(0) - def _predeal_data(self, data: Any): + def _predeal_data(self, data: Any) -> None: """ - 预处理数据 + 预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中 """ if self._audio_cache is None: self._audio_cache = data @@ -126,26 +165,29 @@ class VADFunctor(BaseFunctor): else: # 拼接音频数据 if isinstance(self._audiobinary_cache, numpy.ndarray): - self._audiobinary_cache = numpy.concatenate((self._audiobinary_cache, data)) + self._audiobinary_cache = numpy.concatenate( + (self._audiobinary_cache, data) + ) elif isinstance(self._audiobinary_cache, list): self._audiobinary_cache.append(data) + def _process(self, data: Any): """ 处理数据 + 使用model进行生成, 并使用_do_callback进行回调 """ self._predeal_data(data) if len(self._audio_cache) >= self._audio_config.chunk_stride: - result = self._model['vad'].generate( + result = self._model["vad"].generate( input=self._audio_cache, cache=self._model_cache, chunk_size=self._audio_config.chunk_size, is_final=False, ) - if (len(result[0]['value']) > 0): - self._do_callback(result[0]['value'], self._audio_cache) - logger.debug(f"VADFunctor结果: {result[0]['value']}") + if len(result[0]["value"]) > 0: + self._do_callback(result[0]["value"]) + # logger.debug(f"VADFunctor结果: {result[0]['value']}") self._audio_cache = None - def _run(self): """ @@ -170,7 +212,7 @@ class VADFunctor(BaseFunctor): continue # 其他异常 except Exception as e: - logger.error(f"VADFunctor运行时发生错误: {e}") + logger.error("VADFunctor运行时发生错误: %s", e) raise e def run(self): @@ -199,12 +241,16 @@ class VADFunctor(BaseFunctor): return True def stop(self): + """ + 停止线程 + 通过设置_stop_event为True, 来在input_queue.get()循环为空时退出 + """ with self._status_lock: self._stop_event = True self._thread.join() with self._status_lock: self._is_running = False - return True + return not self._thread.is_alive() # class VAD: diff --git a/tests/functor/vad_test.py b/tests/functor/vad_test.py 
index f39e325..7a7df37 100644 --- a/tests/functor/vad_test.py +++ b/tests/functor/vad_test.py @@ -3,6 +3,7 @@ Functor测试 VAD测试 """ from src.functor.vad_functor import VADFunctor +from src.functor.asr_functor import ASRFunctor from queue import Queue, Empty from src.model_loader import ModelLoader from src.models import AudioBinary_Config, AudioBinary_data_list @@ -22,6 +23,8 @@ model_loader = ModelLoader() def test_vad_functor(): # 加载模型 args = { + "asr_model": "paraformer-zh", + "asr_model_revision": "v2.0.4", "vad_model": "fsmn-vad", "vad_model_revision": "v2.0.4", "auto_update": False, @@ -40,6 +43,7 @@ def test_vad_functor(): audio_config.chunk_stride = chunk_stride # 创建输入队列 input_queue = Queue() + vad2asr_queue = Queue() # 创建音频数据列表 audio_binary_data_list = AudioBinary_data_list() @@ -52,14 +56,30 @@ def test_vad_functor(): # 设置音频数据列表 vad_functor.set_audio_binary_data_list(audio_binary_data_list) # 设置回调函数 - vad_functor.add_callback(lambda x: print(f"callback: {x}")) + vad_functor.add_callback(lambda x: print(f"vad callback: {x}")) + vad_functor.add_callback(lambda x: vad2asr_queue.put(x)) # 设置模型 vad_functor.set_model({ 'vad': model_loader.models['vad'] }) - # 启动VAD函数器 vad_functor.run() + + # 创建ASR函数器 + asr_functor = ASRFunctor() + # 设置输入队列 + asr_functor.set_input_queue(vad2asr_queue) + # 设置音频配置 + asr_functor.set_audio_config(audio_config) + # 设置回调函数 + asr_functor.add_callback(lambda x: print(f"asr callback: {x}")) + # 设置模型 + asr_functor.set_model({ + 'asr': model_loader.models['asr'] + }) + # 启动ASR函数器 + asr_functor.run() + f_binary = f_data audio_clip_len = 200 print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}") @@ -73,6 +93,12 @@ def test_vad_functor(): vad_functor.stop() print("[vad_test] VAD函数器结束") + print("[vad_test] 等待vad2asr_queue为空") + vad2asr_queue.join() + print("[vad_test] vad2asr_queue为空") + asr_functor.stop() + print("[vad_test] ASR函数器结束") + # 保存音频数据 if OVERWATCH: for index in 
range(len(audio_binary_data_list)):