STT_Server/src/functor/vad_functor.py

271 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from funasr import AutoModel
from typing import List, Dict, Any, Callable
from src.models import VAD_Functor_result, _AudioBinary_data, AudioBinary_Config, AudioBinary_data_list
from typing import Callable
from src.functor.base import BaseFunctor
import threading
from queue import Empty, Queue
import numpy
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class VADFunctor(BaseFunctor):
    """Streaming voice-activity-detection (VAD) stage.

    A worker thread pulls audio chunks from the input queue, feeds them to
    the model stored under ``self._model['vad']`` and hands every finished
    speech segment to the registered callbacks as a ``VAD_Functor_result``.

    Typical wiring: ``set_model`` / ``set_audio_config`` /
    ``set_audio_binary_data_list`` / ``set_input_queue`` / ``add_callback``,
    then ``run()`` to start the worker and ``stop()`` to shut it down.
    """

    def __init__(
        self
    ):
        super().__init__()
        # Resources and configuration
        self._model: dict = {}  # models, looked up by key ('vad' is used here)
        self._callback: List[Callable] = []  # callbacks receiving each VAD segment
        self._input_queue: Queue = None  # input queue of audio chunks
        self._audio_config: AudioBinary_Config = None  # audio configuration
        self._audio_binary_data_list: AudioBinary_data_list = None  # audio data list
        # Flags
        # Two flags are used here, both only to interrupt the _run thread;
        # consider simplifying this later.
        self._is_running: bool = False
        self._stop_event: bool = False
        # Lock guarding the two status flags above
        self._status_lock: threading.Lock = threading.Lock()
        # Caches
        self._audio_cache: numpy.ndarray = None  # audio not yet fed to the model
        self._audio_cache_preindex: int = 0  # frames already cut off the front of _audiobinary_cache
        self._model_cache: dict = {}  # cache dict passed to the model on every call
        self._cache_result_list = []  # pending [start_ms, end_ms] segments
        self._audiobinary_cache = None  # audio not yet emitted as a segment

    def reset_cache(self):
        """
        Reset all caches; used after a task completes to prepare for the next one.
        """
        self._audio_cache = None
        self._audio_cache_preindex = 0
        self._model_cache = {}
        self._cache_result_list = []
        self._audiobinary_cache = None

    def set_input_queue(self, queue: Queue):
        """Set the queue the worker thread reads audio chunks from."""
        self._input_queue = queue

    def set_model(self, model: dict):
        """Set the model dict; the VAD model is read from its ``'vad'`` key."""
        self._model = model

    def set_audio_config(self, audio_config: AudioBinary_Config):
        """Set the audio configuration and log it."""
        self._audio_config = audio_config
        logger.info(f"VADFunctor设置音频配置: {self._audio_config}")

    def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list):
        """Set the audio data list passed along when building segment results."""
        self._audio_binary_data_list = audio_binary_data_list

    def add_callback(self, callback: Callable):
        """Register a callback invoked with every finished ``VAD_Functor_result``."""
        if not isinstance(self._callback, list):
            self._callback = []
        self._callback.append(callback)

    def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list):
        """
        Merge new VAD endpoints into the pending list and emit finished segments.

        Args:
            result: list of ``[start, end]`` endpoint pairs from the VAD model.
                ``start == -1`` means the pair has no leading endpoint (it
                closes the previous pending segment); ``end == -1`` means it
                has no trailing endpoint yet.
            audio_cache: accepted for interface compatibility; not used here.

        A segment is emitted only once a later segment has been queued behind
        it, so the newest pending segment is always held back.
        """
        # Merge the new endpoints into the persistent pending list.
        for start, end in result:
            if start == -1:
                # No leading endpoint: this pair closes the last pending segment.
                if not self._cache_result_list:
                    # Nothing to close; skip instead of raising IndexError,
                    # which would kill the worker thread.
                    logger.warning(
                        f"VADFunctor received end point {end} without a pending start; ignored"
                    )
                    continue
                self._cache_result_list[-1][1] = end
            else:
                # Store a copy so later merging never mutates the model's output.
                self._cache_result_list.append([start, end])

        while len(self._cache_result_list) > 1:
            segment_start, segment_end = self._cache_result_list[0]
            # Translate the millisecond endpoints to frame offsets relative to
            # the current front of _audiobinary_cache.
            start_frame = self._audio_config.ms2frame(segment_start)
            start_frame -= self._audio_cache_preindex
            end_frame = self._audio_config.ms2frame(segment_end)
            end_frame -= self._audio_cache_preindex
            # Build the VAD segment result.
            vad_result = VAD_Functor_result.create_from_push_data(
                audiobinary_data_list=self._audio_binary_data_list,
                data=self._audiobinary_cache[start_frame:end_frame],
                start_time=segment_start,
                end_time=segment_end
            )
            # Drop everything up to the end of the emitted segment and remember
            # how many frames have been discarded in total.
            self._audio_cache_preindex += end_frame
            self._audiobinary_cache = self._audiobinary_cache[end_frame:]
            for callback in self._callback:
                callback(vad_result)
            self._cache_result_list.pop(0)

    @staticmethod
    def _append_chunk(cache: Any, data: Any) -> Any:
        """Return ``cache`` extended by ``data``; ``cache`` may be ``None``."""
        if cache is None:
            # Copy lists so the two caches never share (and double-mutate)
            # the same list object.
            return list(data) if isinstance(data, list) else data
        if isinstance(cache, numpy.ndarray):
            return numpy.concatenate((cache, data))
        if isinstance(cache, list):
            # Keep the cache a flat sample sequence: it is measured with len()
            # and sliced by frame index elsewhere in this class.
            cache.extend(data)
            return cache
        logger.warning(
            f"VADFunctor cannot append audio chunk to cache of type {type(cache).__name__}; chunk dropped"
        )
        return cache

    def _predeal_data(self, data: Any):
        """
        Pre-process an incoming chunk: append it to both audio caches.

        ``_audio_cache`` holds audio not yet fed to the model;
        ``_audiobinary_cache`` holds audio not yet emitted as a segment.
        """
        self._audio_cache = self._append_chunk(self._audio_cache, data)
        self._audiobinary_cache = self._append_chunk(self._audiobinary_cache, data)

    def _process(self, data: Any):
        """
        Process one chunk: buffer it and run the VAD model once at least
        ``chunk_stride`` items have accumulated.
        """
        self._predeal_data(data)
        if len(self._audio_cache) >= self._audio_config.chunk_stride:
            result = self._model['vad'].generate(
                input=self._audio_cache,
                cache=self._model_cache,
                chunk_size=self._audio_config.chunk_size,
                is_final=False,
            )
            if (len(result[0]['value']) > 0):
                self._do_callback(result[0]['value'], self._audio_cache)
                logger.debug(f"VADFunctor结果: {result[0]['value']}")
            # The buffered audio has been consumed by the model.
            self._audio_cache = None

    def _run(self):
        """
        Worker-thread loop.

        Waits on the input queue and processes each chunk as it arrives.
        When the queue stays empty for 1s, checks whether a stop was requested.

        Raises:
            Exception: any error from queue access or processing is logged
                and re-raised, which ends the thread.
        """
        # Refresh the running state.
        with self._status_lock:
            self._is_running = True
            self._stop_event = False
        try:
            while self._is_running:
                try:
                    data = self._input_queue.get(True, timeout=1)
                except Empty:
                    # Queue idle for 1s: leave only if a stop was requested.
                    if self._stop_event:
                        break
                    continue
                try:
                    self._process(data)
                finally:
                    # Always balance get() so Queue.join() callers cannot hang
                    # on an item whose processing failed.
                    self._input_queue.task_done()
        except Exception as e:
            logger.error(f"VADFunctor运行时发生错误: {e}")
            raise
        finally:
            # Never leave the flag claiming the thread is alive after it exits.
            with self._status_lock:
                self._is_running = False

    def run(self):
        """
        Start the ``_run`` worker thread and return the thread object.

        Raises:
            ValueError: if a required resource has not been set.
        """
        self._pre_check()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self._thread

    def _pre_check(self) -> bool:
        """
        Check that all mandatory resources have been set.

        Returns:
            True when every check passes.

        Raises:
            ValueError: if the VAD model, audio config, audio data list,
                input queue or callback list is missing.
        """
        # _model starts as {}, so a plain None check can never fire; require
        # the 'vad' entry that _process actually uses.
        if not self._model or 'vad' not in self._model:
            raise ValueError("模型未设置")
        if self._audio_config is None:
            raise ValueError("音频配置未设置")
        if self._audio_binary_data_list is None:
            raise ValueError("音频数据列表未设置")
        if self._input_queue is None:
            raise ValueError("输入队列未设置")
        if self._callback is None:
            raise ValueError("回调函数未设置")
        if not self._callback:
            # Callbacks may still be added after start, so only warn here.
            logger.warning("VADFunctor started without callbacks; VAD results will be discarded")
        return True

    def stop(self):
        """
        Request the worker thread to stop and wait for it to finish.

        The worker exits once its input queue has been empty for one poll
        interval. Safe to call even if ``run()`` was never called.

        Returns:
            True once the worker (if any) has finished.
        """
        with self._status_lock:
            self._stop_event = True
        worker = getattr(self, "_thread", None)
        if worker is not None:
            worker.join()
        with self._status_lock:
            self._is_running = False
        return True
# class VAD:
# def __init__(
# self,
# VAD_model=None,
# audio_config: AudioBinary_Config = None,
# callback: Callable = None,
# ):
# # vad model
# self.VAD_model = VAD_model
# if self.VAD_model is None:
# self.VAD_model = AutoModel(
# model="fsmn-vad", model_revision="v2.0.4", disable_update=True
# )
# # audio config
# self.audio_config = audio_config
# # vad result
# self.vad_result = VADResponse(time_chunk_index_callback=callback)
# # audio binary poll
# self.audio_chunk = AudioChunk(audio_config=self.audio_config)
# self.cache = {}
# def push_binary_data(
# self,
# binary_data: bytes,
# ):
# # 压入二进制数据
# self.audio_chunk.add_chunk(binary_data)
# # 处理音频块
# res = self.VAD_model.generate(
# input=binary_data,
# cache=self.cache,
# chunk_size=self.audio_config.chunk_size,
# is_final=False,
# )
# # print("VAD generate", res)
# if len(res[0]["value"]):
# self.vad_result += VADResponse.from_raw(res)
# def set_callback(
# self,
# callback: Callable,
# ):
# self.vad_result.time_chunk_index_callback = callback
# def process_vad_result(self, callback: Callable = None):
# # 处理VAD结果
# callback = (
# callback
# if callback is not None
# else self.vad_result.time_chunk_index_callback
# )
# self.vad_result.process_time_chunk(
# lambda x: callback(
# AudioBinary_Chunk(
# start_time=x["start_time"],
# end_time=x["end_time"],
# chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]),
# )
# )
# )