[代码重构中]初步构建ASRFunctor,与VADFunctor在vad_test.py中进行联调无问题,数据衔接正常。
This commit is contained in:
parent
4e9e94d8dc
commit
ff9bd70039
159
src/functor/asr_functor.py
Normal file
159
src/functor/asr_functor.py
Normal file
@ -0,0 +1,159 @@
|
||||
"""
|
||||
ASRFunctor
|
||||
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||
"""
|
||||
from src.functor.base import BaseFunctor
|
||||
from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result
|
||||
from typing import Callable, List
|
||||
from queue import Queue, Empty
|
||||
import threading
|
||||
|
||||
# 日志
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
class ASRFunctor(BaseFunctor):
|
||||
"""
|
||||
ASRFunctor
|
||||
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||
否则无法run()启动线程
|
||||
|
||||
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||
|
||||
使用stop()停止线程, 但需要等待input_queue为空
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# 资源与配置
|
||||
self._model: dict = {} # 模型
|
||||
self._callback: List[Callable] = [] # 回调函数
|
||||
self._input_queue: Queue = None # 输入队列
|
||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||
|
||||
# flag
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
|
||||
# 线程资源
|
||||
self._thread: threading.Thread = None
|
||||
|
||||
# 状态锁
|
||||
self._status_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# 缓存
|
||||
self._hotwords: List[str] = []
|
||||
|
||||
def reset_cache(self) -> None:
|
||||
"""
|
||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_input_queue(self, queue: Queue) -> None:
|
||||
"""
|
||||
设置监听的输入消息队列
|
||||
"""
|
||||
self._input_queue = queue
|
||||
|
||||
def set_model(self, model: dict) -> None:
|
||||
"""
|
||||
设置推理模型
|
||||
"""
|
||||
self._model = model
|
||||
|
||||
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||
"""
|
||||
设置音频配置
|
||||
"""
|
||||
self._audio_config = audio_config
|
||||
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
|
||||
|
||||
def add_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||
"""
|
||||
if not isinstance(self._callback, list):
|
||||
self._callback = []
|
||||
self._callback.append(callback)
|
||||
|
||||
def _do_callback(self, result: List[str]) -> None:
|
||||
"""
|
||||
回调函数
|
||||
"""
|
||||
text = result[0]['text'].replace(" ", "")
|
||||
for callback in self._callback:
|
||||
callback(text)
|
||||
|
||||
def _process(self, data: VAD_Functor_result) -> None:
|
||||
"""
|
||||
处理数据
|
||||
"""
|
||||
binary_data = data.audiobinary_data.binary_data
|
||||
result = self._model["asr"].generate(
|
||||
input=binary_data,
|
||||
chunk_size=self._audio_config.chunk_size,
|
||||
hotwords=self._hotwords,
|
||||
)
|
||||
self._do_callback(result)
|
||||
|
||||
def _run(self):
|
||||
"""
|
||||
线程运行逻辑
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
# 运行逻辑
|
||||
while self._is_running:
|
||||
try:
|
||||
data = self._input_queue.get(True, timeout=1)
|
||||
self._process(data)
|
||||
self._input_queue.task_done()
|
||||
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||
except Empty:
|
||||
if self._stop_event:
|
||||
break
|
||||
continue
|
||||
# 其他异常
|
||||
except Exception as e:
|
||||
logger.error("ASRFunctor运行时发生错误: %s", e)
|
||||
raise e
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
启动线程
|
||||
"""
|
||||
self._pre_check()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
return self._thread
|
||||
|
||||
def _pre_check(self) -> bool:
|
||||
"""
|
||||
预检查
|
||||
"""
|
||||
if self._model is None:
|
||||
raise ValueError("模型未设置")
|
||||
if self._audio_config is None:
|
||||
raise ValueError("音频配置未设置")
|
||||
if self._input_queue is None:
|
||||
raise ValueError("输入队列未设置")
|
||||
if self._callback is None:
|
||||
raise ValueError("回调函数未设置")
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止线程
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._stop_event = True
|
||||
self._thread.join()
|
||||
with self._status_lock:
|
||||
self._is_running = False
|
||||
return not self._thread.is_alive()
|
||||
|
||||
|
@ -1,11 +1,17 @@
|
||||
from funasr import AutoModel
|
||||
from typing import List, Dict, Any, Callable
|
||||
from src.models import VAD_Functor_result, _AudioBinary_data, AudioBinary_Config, AudioBinary_data_list
|
||||
from typing import Callable
|
||||
from src.functor.base import BaseFunctor
|
||||
"""
|
||||
VADFunctor
|
||||
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||
"""
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
from typing import List, Any, Callable
|
||||
import numpy
|
||||
from src.models import (
|
||||
VAD_Functor_result,
|
||||
AudioBinary_Config,
|
||||
AudioBinary_data_list,
|
||||
)
|
||||
from src.functor.base import BaseFunctor
|
||||
|
||||
# 日志
|
||||
from src.utils.logger import get_module_logger
|
||||
@ -14,22 +20,34 @@ logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class VADFunctor(BaseFunctor):
|
||||
def __init__(
|
||||
self
|
||||
):
|
||||
"""
|
||||
VADFunctor
|
||||
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
|
||||
否则无法run()启动线程
|
||||
|
||||
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||
|
||||
使用stop()停止线程, 但需要等待input_queue为空
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# 资源与配置
|
||||
self._model: dict = {} # 模型
|
||||
self._callback: List[Callable] = [] # 回调函数
|
||||
self._input_queue: Queue = None # 输入队列
|
||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
|
||||
self._model: dict = {} # 模型
|
||||
self._callback: List[Callable] = [] # 回调函数
|
||||
self._input_queue: Queue = None # 输入队列
|
||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
|
||||
|
||||
# flag
|
||||
# 此处用到两个锁,但都是为了截断_run线程,考虑后续优化
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
|
||||
# 线程资源
|
||||
self._thread: threading.Thread = None
|
||||
|
||||
# 状态锁
|
||||
self._status_lock: threading.Lock = threading.Lock()
|
||||
|
||||
@ -39,8 +57,8 @@ class VADFunctor(BaseFunctor):
|
||||
self._model_cache: dict = {}
|
||||
self._cache_result_list = []
|
||||
self._audiobinary_cache = None
|
||||
|
||||
def reset_cache(self):
|
||||
|
||||
def reset_cache(self) -> None:
|
||||
"""
|
||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||
"""
|
||||
@ -50,25 +68,46 @@ class VADFunctor(BaseFunctor):
|
||||
self._cache_result_list = []
|
||||
self._audiobinary_cache = None
|
||||
|
||||
def set_input_queue(self, queue: Queue):
|
||||
def set_input_queue(self, queue: Queue) -> None:
|
||||
"""
|
||||
设置监听的输入消息队列
|
||||
"""
|
||||
self._input_queue = queue
|
||||
|
||||
def set_model(self, model: dict):
|
||||
def set_model(self, model: dict) -> None:
|
||||
"""
|
||||
设置推理模型
|
||||
"""
|
||||
self._model = model
|
||||
|
||||
def set_audio_config(self, audio_config: AudioBinary_Config):
|
||||
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||
"""
|
||||
设置音频配置
|
||||
"""
|
||||
self._audio_config = audio_config
|
||||
logger.info(f"VADFunctor设置音频配置: {self._audio_config}")
|
||||
logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
|
||||
|
||||
def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list):
|
||||
def set_audio_binary_data_list(
|
||||
self, audio_binary_data_list: AudioBinary_data_list
|
||||
) -> None:
|
||||
"""
|
||||
设置音频数据列表, 为Class AudioBinary_data_list类型
|
||||
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
|
||||
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
|
||||
"""
|
||||
self._audio_binary_data_list = audio_binary_data_list
|
||||
|
||||
def add_callback(self, callback: Callable):
|
||||
def add_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||
"""
|
||||
if not isinstance(self._callback, list):
|
||||
self._callback = []
|
||||
self._callback.append(callback)
|
||||
|
||||
def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list):
|
||||
def _do_callback(
|
||||
self, result: List[List[int]]
|
||||
) -> None:
|
||||
"""
|
||||
回调函数
|
||||
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||
@ -91,7 +130,7 @@ class VADFunctor(BaseFunctor):
|
||||
while len(self._cache_result_list) > 1:
|
||||
# 创建VAD片段
|
||||
# 计算开始帧
|
||||
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
|
||||
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
|
||||
start_frame -= self._audio_cache_preindex
|
||||
# 计算结束帧
|
||||
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
|
||||
@ -101,7 +140,7 @@ class VADFunctor(BaseFunctor):
|
||||
audiobinary_data_list=self._audio_binary_data_list,
|
||||
data=self._audiobinary_cache[start_frame:end_frame],
|
||||
start_time=self._cache_result_list[0][0],
|
||||
end_time=self._cache_result_list[0][1]
|
||||
end_time=self._cache_result_list[0][1],
|
||||
)
|
||||
self._audio_cache_preindex += end_frame
|
||||
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
|
||||
@ -109,9 +148,9 @@ class VADFunctor(BaseFunctor):
|
||||
callback(vad_result)
|
||||
self._cache_result_list.pop(0)
|
||||
|
||||
def _predeal_data(self, data: Any):
|
||||
def _predeal_data(self, data: Any) -> None:
|
||||
"""
|
||||
预处理数据
|
||||
预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
|
||||
"""
|
||||
if self._audio_cache is None:
|
||||
self._audio_cache = data
|
||||
@ -126,26 +165,29 @@ class VADFunctor(BaseFunctor):
|
||||
else:
|
||||
# 拼接音频数据
|
||||
if isinstance(self._audiobinary_cache, numpy.ndarray):
|
||||
self._audiobinary_cache = numpy.concatenate((self._audiobinary_cache, data))
|
||||
self._audiobinary_cache = numpy.concatenate(
|
||||
(self._audiobinary_cache, data)
|
||||
)
|
||||
elif isinstance(self._audiobinary_cache, list):
|
||||
self._audiobinary_cache.append(data)
|
||||
|
||||
def _process(self, data: Any):
|
||||
"""
|
||||
处理数据
|
||||
使用model进行生成, 并使用_do_callback进行回调
|
||||
"""
|
||||
self._predeal_data(data)
|
||||
if len(self._audio_cache) >= self._audio_config.chunk_stride:
|
||||
result = self._model['vad'].generate(
|
||||
result = self._model["vad"].generate(
|
||||
input=self._audio_cache,
|
||||
cache=self._model_cache,
|
||||
chunk_size=self._audio_config.chunk_size,
|
||||
is_final=False,
|
||||
)
|
||||
if (len(result[0]['value']) > 0):
|
||||
self._do_callback(result[0]['value'], self._audio_cache)
|
||||
logger.debug(f"VADFunctor结果: {result[0]['value']}")
|
||||
if len(result[0]["value"]) > 0:
|
||||
self._do_callback(result[0]["value"])
|
||||
# logger.debug(f"VADFunctor结果: {result[0]['value']}")
|
||||
self._audio_cache = None
|
||||
|
||||
|
||||
def _run(self):
|
||||
"""
|
||||
@ -170,7 +212,7 @@ class VADFunctor(BaseFunctor):
|
||||
continue
|
||||
# 其他异常
|
||||
except Exception as e:
|
||||
logger.error(f"VADFunctor运行时发生错误: {e}")
|
||||
logger.error("VADFunctor运行时发生错误: %s", e)
|
||||
raise e
|
||||
|
||||
def run(self):
|
||||
@ -199,12 +241,16 @@ class VADFunctor(BaseFunctor):
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止线程
|
||||
通过设置_stop_event为True, 来在input_queue.get()循环为空时退出
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._stop_event = True
|
||||
self._thread.join()
|
||||
with self._status_lock:
|
||||
self._is_running = False
|
||||
return True
|
||||
return not self._thread.is_alive()
|
||||
|
||||
|
||||
# class VAD:
|
||||
|
@ -3,6 +3,7 @@ Functor测试
|
||||
VAD测试
|
||||
"""
|
||||
from src.functor.vad_functor import VADFunctor
|
||||
from src.functor.asr_functor import ASRFunctor
|
||||
from queue import Queue, Empty
|
||||
from src.model_loader import ModelLoader
|
||||
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||
@ -22,6 +23,8 @@ model_loader = ModelLoader()
|
||||
def test_vad_functor():
|
||||
# 加载模型
|
||||
args = {
|
||||
"asr_model": "paraformer-zh",
|
||||
"asr_model_revision": "v2.0.4",
|
||||
"vad_model": "fsmn-vad",
|
||||
"vad_model_revision": "v2.0.4",
|
||||
"auto_update": False,
|
||||
@ -40,6 +43,7 @@ def test_vad_functor():
|
||||
audio_config.chunk_stride = chunk_stride
|
||||
# 创建输入队列
|
||||
input_queue = Queue()
|
||||
vad2asr_queue = Queue()
|
||||
# 创建音频数据列表
|
||||
audio_binary_data_list = AudioBinary_data_list()
|
||||
|
||||
@ -52,14 +56,30 @@ def test_vad_functor():
|
||||
# 设置音频数据列表
|
||||
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
|
||||
# 设置回调函数
|
||||
vad_functor.add_callback(lambda x: print(f"callback: {x}"))
|
||||
vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
|
||||
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
|
||||
# 设置模型
|
||||
vad_functor.set_model({
|
||||
'vad': model_loader.models['vad']
|
||||
})
|
||||
|
||||
# 启动VAD函数器
|
||||
vad_functor.run()
|
||||
|
||||
# 创建ASR函数器
|
||||
asr_functor = ASRFunctor()
|
||||
# 设置输入队列
|
||||
asr_functor.set_input_queue(vad2asr_queue)
|
||||
# 设置音频配置
|
||||
asr_functor.set_audio_config(audio_config)
|
||||
# 设置回调函数
|
||||
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
|
||||
# 设置模型
|
||||
asr_functor.set_model({
|
||||
'asr': model_loader.models['asr']
|
||||
})
|
||||
# 启动ASR函数器
|
||||
asr_functor.run()
|
||||
|
||||
f_binary = f_data
|
||||
audio_clip_len = 200
|
||||
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
|
||||
@ -73,6 +93,12 @@ def test_vad_functor():
|
||||
vad_functor.stop()
|
||||
print("[vad_test] VAD函数器结束")
|
||||
|
||||
print("[vad_test] 等待vad2asr_queue为空")
|
||||
vad2asr_queue.join()
|
||||
print("[vad_test] vad2asr_queue为空")
|
||||
asr_functor.stop()
|
||||
print("[vad_test] ASR函数器结束")
|
||||
|
||||
# 保存音频数据
|
||||
if OVERWATCH:
|
||||
for index in range(len(audio_binary_data_list)):
|
||||
|
Loading…
x
Reference in New Issue
Block a user