[代码重构中]初步构建ASRFunctor,与VADFunctor在vad_test.py中进行联调无问题,数据衔接正常。
This commit is contained in:
parent
4e9e94d8dc
commit
ff9bd70039
159
src/functor/asr_functor.py
Normal file
159
src/functor/asr_functor.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
"""
|
||||||
|
ASRFunctor
|
||||||
|
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||||
|
"""
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result
|
||||||
|
from typing import Callable, List
|
||||||
|
from queue import Queue, Empty
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
class ASRFunctor(BaseFunctor):
|
||||||
|
"""
|
||||||
|
ASRFunctor
|
||||||
|
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||||
|
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||||
|
否则无法run()启动线程
|
||||||
|
|
||||||
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||||
|
|
||||||
|
使用stop()停止线程, 但需要等待input_queue为空
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# 资源与配置
|
||||||
|
self._model: dict = {} # 模型
|
||||||
|
self._callback: List[Callable] = [] # 回调函数
|
||||||
|
self._input_queue: Queue = None # 输入队列
|
||||||
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
|
||||||
|
# flag
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
|
||||||
|
# 线程资源
|
||||||
|
self._thread: threading.Thread = None
|
||||||
|
|
||||||
|
# 状态锁
|
||||||
|
self._status_lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
|
# 缓存
|
||||||
|
self._hotwords: List[str] = []
|
||||||
|
|
||||||
|
def reset_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
"""
|
||||||
|
self._audio_config = audio_config
|
||||||
|
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
|
if not isinstance(self._callback, list):
|
||||||
|
self._callback = []
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def _do_callback(self, result: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
回调函数
|
||||||
|
"""
|
||||||
|
text = result[0]['text'].replace(" ", "")
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(text)
|
||||||
|
|
||||||
|
def _process(self, data: VAD_Functor_result) -> None:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
"""
|
||||||
|
binary_data = data.audiobinary_data.binary_data
|
||||||
|
result = self._model["asr"].generate(
|
||||||
|
input=binary_data,
|
||||||
|
chunk_size=self._audio_config.chunk_size,
|
||||||
|
hotwords=self._hotwords,
|
||||||
|
)
|
||||||
|
self._do_callback(result)
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
# 运行逻辑
|
||||||
|
while self._is_running:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(True, timeout=1)
|
||||||
|
self._process(data)
|
||||||
|
self._input_queue.task_done()
|
||||||
|
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
except Empty:
|
||||||
|
if self._stop_event:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# 其他异常
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("ASRFunctor运行时发生错误: %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
启动线程
|
||||||
|
"""
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
raise ValueError("模型未设置")
|
||||||
|
if self._audio_config is None:
|
||||||
|
raise ValueError("音频配置未设置")
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise ValueError("输入队列未设置")
|
||||||
|
if self._callback is None:
|
||||||
|
raise ValueError("回调函数未设置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._stop_event = True
|
||||||
|
self._thread.join()
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
return not self._thread.is_alive()
|
||||||
|
|
||||||
|
|
@ -1,11 +1,17 @@
|
|||||||
from funasr import AutoModel
|
"""
|
||||||
from typing import List, Dict, Any, Callable
|
VADFunctor
|
||||||
from src.models import VAD_Functor_result, _AudioBinary_data, AudioBinary_Config, AudioBinary_data_list
|
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||||
from typing import Callable
|
"""
|
||||||
from src.functor.base import BaseFunctor
|
|
||||||
import threading
|
import threading
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
|
from typing import List, Any, Callable
|
||||||
import numpy
|
import numpy
|
||||||
|
from src.models import (
|
||||||
|
VAD_Functor_result,
|
||||||
|
AudioBinary_Config,
|
||||||
|
AudioBinary_data_list,
|
||||||
|
)
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
|
||||||
# 日志
|
# 日志
|
||||||
from src.utils.logger import get_module_logger
|
from src.utils.logger import get_module_logger
|
||||||
@ -14,9 +20,18 @@ logger = get_module_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class VADFunctor(BaseFunctor):
|
class VADFunctor(BaseFunctor):
|
||||||
def __init__(
|
"""
|
||||||
self
|
VADFunctor
|
||||||
):
|
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||||
|
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
|
||||||
|
否则无法run()启动线程
|
||||||
|
|
||||||
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||||
|
|
||||||
|
使用stop()停止线程, 但需要等待input_queue为空
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 资源与配置
|
# 资源与配置
|
||||||
self._model: dict = {} # 模型
|
self._model: dict = {} # 模型
|
||||||
@ -30,6 +45,9 @@ class VADFunctor(BaseFunctor):
|
|||||||
self._is_running: bool = False
|
self._is_running: bool = False
|
||||||
self._stop_event: bool = False
|
self._stop_event: bool = False
|
||||||
|
|
||||||
|
# 线程资源
|
||||||
|
self._thread: threading.Thread = None
|
||||||
|
|
||||||
# 状态锁
|
# 状态锁
|
||||||
self._status_lock: threading.Lock = threading.Lock()
|
self._status_lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
@ -40,7 +58,7 @@ class VADFunctor(BaseFunctor):
|
|||||||
self._cache_result_list = []
|
self._cache_result_list = []
|
||||||
self._audiobinary_cache = None
|
self._audiobinary_cache = None
|
||||||
|
|
||||||
def reset_cache(self):
|
def reset_cache(self) -> None:
|
||||||
"""
|
"""
|
||||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
"""
|
"""
|
||||||
@ -50,25 +68,46 @@ class VADFunctor(BaseFunctor):
|
|||||||
self._cache_result_list = []
|
self._cache_result_list = []
|
||||||
self._audiobinary_cache = None
|
self._audiobinary_cache = None
|
||||||
|
|
||||||
def set_input_queue(self, queue: Queue):
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
self._input_queue = queue
|
self._input_queue = queue
|
||||||
|
|
||||||
def set_model(self, model: dict):
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
"""
|
||||||
self._model = model
|
self._model = model
|
||||||
|
|
||||||
def set_audio_config(self, audio_config: AudioBinary_Config):
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
"""
|
||||||
self._audio_config = audio_config
|
self._audio_config = audio_config
|
||||||
logger.info(f"VADFunctor设置音频配置: {self._audio_config}")
|
logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list):
|
def set_audio_binary_data_list(
|
||||||
|
self, audio_binary_data_list: AudioBinary_data_list
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
设置音频数据列表, 为Class AudioBinary_data_list类型
|
||||||
|
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
|
||||||
|
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
|
||||||
|
"""
|
||||||
self._audio_binary_data_list = audio_binary_data_list
|
self._audio_binary_data_list = audio_binary_data_list
|
||||||
|
|
||||||
def add_callback(self, callback: Callable):
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
if not isinstance(self._callback, list):
|
if not isinstance(self._callback, list):
|
||||||
self._callback = []
|
self._callback = []
|
||||||
self._callback.append(callback)
|
self._callback.append(callback)
|
||||||
|
|
||||||
def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list):
|
def _do_callback(
|
||||||
|
self, result: List[List[int]]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
回调函数
|
回调函数
|
||||||
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||||
@ -101,7 +140,7 @@ class VADFunctor(BaseFunctor):
|
|||||||
audiobinary_data_list=self._audio_binary_data_list,
|
audiobinary_data_list=self._audio_binary_data_list,
|
||||||
data=self._audiobinary_cache[start_frame:end_frame],
|
data=self._audiobinary_cache[start_frame:end_frame],
|
||||||
start_time=self._cache_result_list[0][0],
|
start_time=self._cache_result_list[0][0],
|
||||||
end_time=self._cache_result_list[0][1]
|
end_time=self._cache_result_list[0][1],
|
||||||
)
|
)
|
||||||
self._audio_cache_preindex += end_frame
|
self._audio_cache_preindex += end_frame
|
||||||
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
|
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
|
||||||
@ -109,9 +148,9 @@ class VADFunctor(BaseFunctor):
|
|||||||
callback(vad_result)
|
callback(vad_result)
|
||||||
self._cache_result_list.pop(0)
|
self._cache_result_list.pop(0)
|
||||||
|
|
||||||
def _predeal_data(self, data: Any):
|
def _predeal_data(self, data: Any) -> None:
|
||||||
"""
|
"""
|
||||||
预处理数据
|
预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
|
||||||
"""
|
"""
|
||||||
if self._audio_cache is None:
|
if self._audio_cache is None:
|
||||||
self._audio_cache = data
|
self._audio_cache = data
|
||||||
@ -126,27 +165,30 @@ class VADFunctor(BaseFunctor):
|
|||||||
else:
|
else:
|
||||||
# 拼接音频数据
|
# 拼接音频数据
|
||||||
if isinstance(self._audiobinary_cache, numpy.ndarray):
|
if isinstance(self._audiobinary_cache, numpy.ndarray):
|
||||||
self._audiobinary_cache = numpy.concatenate((self._audiobinary_cache, data))
|
self._audiobinary_cache = numpy.concatenate(
|
||||||
|
(self._audiobinary_cache, data)
|
||||||
|
)
|
||||||
elif isinstance(self._audiobinary_cache, list):
|
elif isinstance(self._audiobinary_cache, list):
|
||||||
self._audiobinary_cache.append(data)
|
self._audiobinary_cache.append(data)
|
||||||
|
|
||||||
def _process(self, data: Any):
|
def _process(self, data: Any):
|
||||||
"""
|
"""
|
||||||
处理数据
|
处理数据
|
||||||
|
使用model进行生成, 并使用_do_callback进行回调
|
||||||
"""
|
"""
|
||||||
self._predeal_data(data)
|
self._predeal_data(data)
|
||||||
if len(self._audio_cache) >= self._audio_config.chunk_stride:
|
if len(self._audio_cache) >= self._audio_config.chunk_stride:
|
||||||
result = self._model['vad'].generate(
|
result = self._model["vad"].generate(
|
||||||
input=self._audio_cache,
|
input=self._audio_cache,
|
||||||
cache=self._model_cache,
|
cache=self._model_cache,
|
||||||
chunk_size=self._audio_config.chunk_size,
|
chunk_size=self._audio_config.chunk_size,
|
||||||
is_final=False,
|
is_final=False,
|
||||||
)
|
)
|
||||||
if (len(result[0]['value']) > 0):
|
if len(result[0]["value"]) > 0:
|
||||||
self._do_callback(result[0]['value'], self._audio_cache)
|
self._do_callback(result[0]["value"])
|
||||||
logger.debug(f"VADFunctor结果: {result[0]['value']}")
|
# logger.debug(f"VADFunctor结果: {result[0]['value']}")
|
||||||
self._audio_cache = None
|
self._audio_cache = None
|
||||||
|
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
"""
|
"""
|
||||||
线程运行逻辑
|
线程运行逻辑
|
||||||
@ -170,7 +212,7 @@ class VADFunctor(BaseFunctor):
|
|||||||
continue
|
continue
|
||||||
# 其他异常
|
# 其他异常
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"VADFunctor运行时发生错误: {e}")
|
logger.error("VADFunctor运行时发生错误: %s", e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
@ -199,12 +241,16 @@ class VADFunctor(BaseFunctor):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
通过设置_stop_event为True, 来在input_queue.get()循环为空时退出
|
||||||
|
"""
|
||||||
with self._status_lock:
|
with self._status_lock:
|
||||||
self._stop_event = True
|
self._stop_event = True
|
||||||
self._thread.join()
|
self._thread.join()
|
||||||
with self._status_lock:
|
with self._status_lock:
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
return True
|
return not self._thread.is_alive()
|
||||||
|
|
||||||
|
|
||||||
# class VAD:
|
# class VAD:
|
||||||
|
@ -3,6 +3,7 @@ Functor测试
|
|||||||
VAD测试
|
VAD测试
|
||||||
"""
|
"""
|
||||||
from src.functor.vad_functor import VADFunctor
|
from src.functor.vad_functor import VADFunctor
|
||||||
|
from src.functor.asr_functor import ASRFunctor
|
||||||
from queue import Queue, Empty
|
from queue import Queue, Empty
|
||||||
from src.model_loader import ModelLoader
|
from src.model_loader import ModelLoader
|
||||||
from src.models import AudioBinary_Config, AudioBinary_data_list
|
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||||
@ -22,6 +23,8 @@ model_loader = ModelLoader()
|
|||||||
def test_vad_functor():
|
def test_vad_functor():
|
||||||
# 加载模型
|
# 加载模型
|
||||||
args = {
|
args = {
|
||||||
|
"asr_model": "paraformer-zh",
|
||||||
|
"asr_model_revision": "v2.0.4",
|
||||||
"vad_model": "fsmn-vad",
|
"vad_model": "fsmn-vad",
|
||||||
"vad_model_revision": "v2.0.4",
|
"vad_model_revision": "v2.0.4",
|
||||||
"auto_update": False,
|
"auto_update": False,
|
||||||
@ -40,6 +43,7 @@ def test_vad_functor():
|
|||||||
audio_config.chunk_stride = chunk_stride
|
audio_config.chunk_stride = chunk_stride
|
||||||
# 创建输入队列
|
# 创建输入队列
|
||||||
input_queue = Queue()
|
input_queue = Queue()
|
||||||
|
vad2asr_queue = Queue()
|
||||||
# 创建音频数据列表
|
# 创建音频数据列表
|
||||||
audio_binary_data_list = AudioBinary_data_list()
|
audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
@ -52,14 +56,30 @@ def test_vad_functor():
|
|||||||
# 设置音频数据列表
|
# 设置音频数据列表
|
||||||
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
|
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
|
||||||
# 设置回调函数
|
# 设置回调函数
|
||||||
vad_functor.add_callback(lambda x: print(f"callback: {x}"))
|
vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
|
||||||
|
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
|
||||||
# 设置模型
|
# 设置模型
|
||||||
vad_functor.set_model({
|
vad_functor.set_model({
|
||||||
'vad': model_loader.models['vad']
|
'vad': model_loader.models['vad']
|
||||||
})
|
})
|
||||||
|
|
||||||
# 启动VAD函数器
|
# 启动VAD函数器
|
||||||
vad_functor.run()
|
vad_functor.run()
|
||||||
|
|
||||||
|
# 创建ASR函数器
|
||||||
|
asr_functor = ASRFunctor()
|
||||||
|
# 设置输入队列
|
||||||
|
asr_functor.set_input_queue(vad2asr_queue)
|
||||||
|
# 设置音频配置
|
||||||
|
asr_functor.set_audio_config(audio_config)
|
||||||
|
# 设置回调函数
|
||||||
|
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
|
||||||
|
# 设置模型
|
||||||
|
asr_functor.set_model({
|
||||||
|
'asr': model_loader.models['asr']
|
||||||
|
})
|
||||||
|
# 启动ASR函数器
|
||||||
|
asr_functor.run()
|
||||||
|
|
||||||
f_binary = f_data
|
f_binary = f_data
|
||||||
audio_clip_len = 200
|
audio_clip_len = 200
|
||||||
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
|
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
|
||||||
@ -73,6 +93,12 @@ def test_vad_functor():
|
|||||||
vad_functor.stop()
|
vad_functor.stop()
|
||||||
print("[vad_test] VAD函数器结束")
|
print("[vad_test] VAD函数器结束")
|
||||||
|
|
||||||
|
print("[vad_test] 等待vad2asr_queue为空")
|
||||||
|
vad2asr_queue.join()
|
||||||
|
print("[vad_test] vad2asr_queue为空")
|
||||||
|
asr_functor.stop()
|
||||||
|
print("[vad_test] ASR函数器结束")
|
||||||
|
|
||||||
# 保存音频数据
|
# 保存音频数据
|
||||||
if OVERWATCH:
|
if OVERWATCH:
|
||||||
for index in range(len(audio_binary_data_list)):
|
for index in range(len(audio_binary_data_list)):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user