[代码重构中]完善VADFunctor,测试持久化保存VAD片段的音频数据成功。
This commit is contained in:
parent
b569b7e63d
commit
4e9e94d8dc
34
main.py
Normal file
34
main.py
Normal file
@ -0,0 +1,34 @@
|
||||
from funasr import AutoModel
|
||||
|
||||
chunk_size = 200 # ms
|
||||
model = AutoModel(
|
||||
model="fsmn-vad",
|
||||
model_revision="v2.0.4",
|
||||
disable_update=True
|
||||
)
|
||||
|
||||
import soundfile
|
||||
|
||||
wav_file = "tests/vad_example.wav"
|
||||
speech, sample_rate = soundfile.read(wav_file)
|
||||
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||
|
||||
cache = {}
|
||||
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
|
||||
for i in range(total_chunk_num):
|
||||
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
|
||||
is_final = i == total_chunk_num - 1
|
||||
res = model.generate(
|
||||
input=speech_chunk,
|
||||
cache=cache,
|
||||
is_final=is_final,
|
||||
chunk_size=chunk_size,
|
||||
disable_pbar=True
|
||||
)
|
||||
if len(res[0]["value"]):
|
||||
print(res)
|
||||
|
||||
print(f"len(speech): {len(speech)}")
|
||||
print(f"len(speech_chunk): {len(speech_chunk)}")
|
||||
print(f"total_chunk_num: {total_chunk_num}")
|
||||
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")
|
@ -1,13 +1,11 @@
|
||||
from funasr import AutoModel
|
||||
from typing import List, Dict, Any
|
||||
from src.models import VADResponse
|
||||
from src.models import AudioBinary_Config
|
||||
from src.models import AudioBinary_data_list
|
||||
from src.models import AudioBinary_Slice
|
||||
from typing import List, Dict, Any, Callable
|
||||
from src.models import VAD_Functor_result, _AudioBinary_data, AudioBinary_Config, AudioBinary_data_list
|
||||
from typing import Callable
|
||||
from src.functor.base import BaseFunctor
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
import numpy
|
||||
|
||||
# 日志
|
||||
from src.utils.logger import get_module_logger
|
||||
@ -20,15 +18,37 @@ class VADFunctor(BaseFunctor):
|
||||
self
|
||||
):
|
||||
super().__init__()
|
||||
self._model: dict = {}
|
||||
self._callback: List[Callable] = []
|
||||
self._status_lock: threading.Lock = threading.Lock()
|
||||
self._input_queue: Queue = None
|
||||
self._audio_config: AudioBinary_Config = None
|
||||
# 资源与配置
|
||||
self._model: dict = {} # 模型
|
||||
self._callback: List[Callable] = [] # 回调函数
|
||||
self._input_queue: Queue = None # 输入队列
|
||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
|
||||
|
||||
# flag
|
||||
# 此处用到两个锁,但都是为了截断_run线程,考虑后续优化
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
self._audio_cache: bytes = b''
|
||||
self._cache: dict = {}
|
||||
|
||||
# 状态锁
|
||||
self._status_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# 缓存
|
||||
self._audio_cache: numpy.ndarray = None
|
||||
self._audio_cache_preindex: int = 0
|
||||
self._model_cache: dict = {}
|
||||
self._cache_result_list = []
|
||||
self._audiobinary_cache = None
|
||||
|
||||
def reset_cache(self):
|
||||
"""
|
||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||
"""
|
||||
self._audio_cache = None
|
||||
self._audio_cache_preindex = 0
|
||||
self._model_cache = {}
|
||||
self._cache_result_list = []
|
||||
self._audiobinary_cache = None
|
||||
|
||||
def set_input_queue(self, queue: Queue):
|
||||
self._input_queue = queue
|
||||
@ -38,32 +58,99 @@ class VADFunctor(BaseFunctor):
|
||||
|
||||
def set_audio_config(self, audio_config: AudioBinary_Config):
|
||||
self._audio_config = audio_config
|
||||
logger.info(f"VADFunctor设置音频配置: {self._audio_config}")
|
||||
|
||||
def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list):
|
||||
self._audio_binary_data_list = audio_binary_data_list
|
||||
|
||||
def add_callback(self, callback: Callable):
|
||||
if not isinstance(self._callback, list):
|
||||
self._callback = []
|
||||
self._callback.append(callback)
|
||||
|
||||
def _process(self, data: bytes):
|
||||
def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list):
|
||||
"""
|
||||
回调函数
|
||||
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||
|
||||
输入:
|
||||
result: List[[start,end]] 处理所得VAD端点
|
||||
其中若start==-1, 则表示前无端点, 若end==-1, 则表示后无端点
|
||||
当处理得到一个完成片段时, 存入AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||
输出:
|
||||
None
|
||||
"""
|
||||
# 持久化缓存结果队列
|
||||
for pair in result:
|
||||
[start, end] = pair
|
||||
# 若无前端点, 则向缓存队列中合并
|
||||
if start == -1:
|
||||
self._cache_result_list[-1][1] = end
|
||||
else:
|
||||
self._cache_result_list.append(pair)
|
||||
while len(self._cache_result_list) > 1:
|
||||
# 创建VAD片段
|
||||
# 计算开始帧
|
||||
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
|
||||
start_frame -= self._audio_cache_preindex
|
||||
# 计算结束帧
|
||||
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
|
||||
end_frame -= self._audio_cache_preindex
|
||||
# 计算开始时间
|
||||
vad_result = VAD_Functor_result.create_from_push_data(
|
||||
audiobinary_data_list=self._audio_binary_data_list,
|
||||
data=self._audiobinary_cache[start_frame:end_frame],
|
||||
start_time=self._cache_result_list[0][0],
|
||||
end_time=self._cache_result_list[0][1]
|
||||
)
|
||||
self._audio_cache_preindex += end_frame
|
||||
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
|
||||
for callback in self._callback:
|
||||
callback(vad_result)
|
||||
self._cache_result_list.pop(0)
|
||||
|
||||
def _predeal_data(self, data: Any):
|
||||
"""
|
||||
预处理数据
|
||||
"""
|
||||
if self._audio_cache is None:
|
||||
self._audio_cache = data
|
||||
else:
|
||||
# 拼接音频数据
|
||||
if isinstance(self._audio_cache, numpy.ndarray):
|
||||
self._audio_cache = numpy.concatenate((self._audio_cache, data))
|
||||
elif isinstance(self._audio_cache, list):
|
||||
self._audio_cache.append(data)
|
||||
if self._audiobinary_cache is None:
|
||||
self._audiobinary_cache = data
|
||||
else:
|
||||
# 拼接音频数据
|
||||
if isinstance(self._audiobinary_cache, numpy.ndarray):
|
||||
self._audiobinary_cache = numpy.concatenate((self._audiobinary_cache, data))
|
||||
elif isinstance(self._audiobinary_cache, list):
|
||||
self._audiobinary_cache.append(data)
|
||||
def _process(self, data: Any):
|
||||
"""
|
||||
处理数据
|
||||
"""
|
||||
self._audio_cache += data
|
||||
if len(self._audio_cache) >= self._audio_config.chunk_size*100:
|
||||
self._predeal_data(data)
|
||||
if len(self._audio_cache) >= self._audio_config.chunk_stride:
|
||||
result = self._model['vad'].generate(
|
||||
input=self._audio_cache,
|
||||
cache=self._cache,
|
||||
cache=self._model_cache,
|
||||
chunk_size=self._audio_config.chunk_size,
|
||||
is_final=False,
|
||||
)
|
||||
logger.info(f"VADFunctor处理数据: {len(self._audio_cache)}, {result}")
|
||||
self._audio_cache = b''
|
||||
if (len(result[0]['value']) > 0):
|
||||
self._do_callback(result[0]['value'], self._audio_cache)
|
||||
logger.debug(f"VADFunctor结果: {result[0]['value']}")
|
||||
self._audio_cache = None
|
||||
|
||||
|
||||
def _run(self):
|
||||
"""
|
||||
线程运行逻辑
|
||||
监听输入队列,当有数据时,处理数据
|
||||
监听输入队列, 当有数据时, 处理数据
|
||||
当输入队列为空时, 间隔1s检测是否进入停止事件。
|
||||
"""
|
||||
# 刷新运行状态
|
||||
@ -76,7 +163,7 @@ class VADFunctor(BaseFunctor):
|
||||
data = self._input_queue.get(True, timeout=1)
|
||||
self._process(data)
|
||||
self._input_queue.task_done()
|
||||
# 当队列为空时,间隔1s检测是否进入停止事件。
|
||||
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||
except Empty:
|
||||
if self._stop_event:
|
||||
break
|
||||
@ -90,12 +177,26 @@ class VADFunctor(BaseFunctor):
|
||||
"""
|
||||
启动 _run 线程, 并返回线程对象
|
||||
"""
|
||||
self._pre_check()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
return self._thread
|
||||
|
||||
def _pre_check(self):
|
||||
pass
|
||||
def _pre_check(self) -> bool:
|
||||
"""
|
||||
检测硬性资源是否都已设置
|
||||
"""
|
||||
if self._model is None:
|
||||
raise ValueError("模型未设置")
|
||||
if self._audio_config is None:
|
||||
raise ValueError("音频配置未设置")
|
||||
if self._audio_binary_data_list is None:
|
||||
raise ValueError("音频数据列表未设置")
|
||||
if self._input_queue is None:
|
||||
raise ValueError("输入队列未设置")
|
||||
if self._callback is None:
|
||||
raise ValueError("回调函数未设置")
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
with self._status_lock:
|
||||
|
@ -1,3 +1,3 @@
|
||||
from .audio import AudioBinary_Config, AudioBinary_data_list, AudioBinary_Slice
|
||||
from .vad import VADResponse
|
||||
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "AudioBinary_Slice", "VADResponse"]
|
||||
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
|
||||
from .vad import VAD_Functor_result
|
||||
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"]
|
@ -1,9 +1,19 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import List, Any
|
||||
import numpy
|
||||
|
||||
from src.utils import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
binary_data_types = (bytes, numpy.ndarray)
|
||||
|
||||
class AudioBinary_Config(BaseModel):
|
||||
"""二进制音频块配置信息"""
|
||||
audio_data: bytes = Field(description="音频数据", default=None)
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
audio_data: binary_data_types = Field(description="音频数据", default=None)
|
||||
chunk_size: int = Field(description="块大小", default=100)
|
||||
chunk_stride: int = Field(description="块步长", default=1600)
|
||||
sample_rate: int = Field(description="采样率", default=16000)
|
||||
@ -15,23 +25,117 @@ class AudioBinary_Config(BaseModel):
|
||||
def AudioBinary_Config_from_dict(cls, data: dict):
|
||||
return cls(**data)
|
||||
|
||||
def ms2frame(self, ms: int) -> int:
|
||||
"""
|
||||
将毫秒转换为帧
|
||||
"""
|
||||
return int(ms * self.sample_rate / 1000)
|
||||
|
||||
def frame2ms(self, frame: int) -> int:
|
||||
"""
|
||||
将帧转换为毫秒
|
||||
"""
|
||||
return int(frame * 1000 / self.sample_rate)
|
||||
|
||||
class _AudioBinary_data(BaseModel):
|
||||
"""音频二进制数据"""
|
||||
binary_data: bytes = Field(description="音频二进制数据", default=None)
|
||||
"""音频数据"""
|
||||
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator('binary_data')
|
||||
def validate_binary_data(cls, v):
|
||||
"""
|
||||
验证音频数据
|
||||
Args:
|
||||
v: 音频数据
|
||||
Returns:
|
||||
binary_data_types: 音频数据
|
||||
"""
|
||||
if not isinstance(v, (bytes, numpy.ndarray)):
|
||||
logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v))
|
||||
return v
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
获取音频数据长度
|
||||
Returns:
|
||||
int: 音频数据长度
|
||||
"""
|
||||
return len(self.binary_data)
|
||||
|
||||
def __init__(self, binary_data: binary_data_types):
|
||||
"""
|
||||
初始化音频数据
|
||||
Args:
|
||||
binary_data: 音频数据
|
||||
"""
|
||||
logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data))
|
||||
super().__init__(binary_data=binary_data)
|
||||
|
||||
def __getitem__(self):
|
||||
"""
|
||||
当获取数据时, 直接返回binary_data
|
||||
Returns:
|
||||
binary_data_types: 音频数据
|
||||
"""
|
||||
return self.binary_data
|
||||
|
||||
class AudioBinary_data_list(BaseModel):
|
||||
"""音频二进制数据列表"""
|
||||
binary_data_list: List[_AudioBinary_data] = Field(description="音频二进制数据列表", default=[])
|
||||
"""音频数据列表"""
|
||||
binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[])
|
||||
|
||||
def __call__(self):
|
||||
return self.binary_data_list
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
class AudioBinary_Slice(BaseModel):
|
||||
"""音频块切片"""
|
||||
target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
|
||||
start_index: int = Field(description="开始位置", default=0)
|
||||
end_index: int = Field(description="结束位置", default=0)
|
||||
def push_data(self, data: binary_data_types) -> int:
|
||||
"""
|
||||
添加音频数据
|
||||
Args:
|
||||
data: 音频数据
|
||||
Returns:
|
||||
int: 数据在binary_data_list中的索引
|
||||
"""
|
||||
self.binary_data_list.append(_AudioBinary_data(binary_data=data))
|
||||
return len(self.binary_data_list) - 1
|
||||
|
||||
def __call__(self):
|
||||
return self.target_Binary(self.start_index, self.end_index)
|
||||
def __getitem__(self, index: int):
|
||||
"""
|
||||
获取音频数据
|
||||
Args:
|
||||
index: 音频数据在binary_data_list中的索引
|
||||
Returns:
|
||||
_AudioBinary_data: 音频数据
|
||||
"""
|
||||
return self.binary_data_list[index]
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
获取音频数据列表长度
|
||||
Returns:
|
||||
int: 音频数据列表长度
|
||||
"""
|
||||
return len(self.binary_data_list)
|
||||
|
||||
# class AudioBinary_Slice(BaseModel):
|
||||
# """音频块切片"""
|
||||
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
|
||||
# start_index: int = Field(description="开始位置", default=0)
|
||||
# end_index: int = Field(description="结束位置", default=-1)
|
||||
|
||||
# @validator('start_index')
|
||||
# def validate_start_index(cls, v):
|
||||
# if v < 0:
|
||||
# raise ValueError("start_index必须大于0")
|
||||
# return v
|
||||
|
||||
# @validator('end_index')
|
||||
# def validate_end_index(cls, v):
|
||||
# if v < cls.start_index:
|
||||
# logger.debug("[%s]end_index小于start_index, 将end_index设置为start_index", cls.__class__.__name__)
|
||||
# v = cls.start_index
|
||||
# return v
|
||||
|
||||
# def __call__(self):
|
||||
# return self.target_Binary(self.start_index, self.end_index)
|
@ -1,143 +1,88 @@
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import List, Optional, Callable
|
||||
from typing import List, Optional, Callable, Any
|
||||
from .audio import AudioBinary_data_list, _AudioBinary_data
|
||||
|
||||
class VADSegment(BaseModel):
|
||||
"""VAD片段"""
|
||||
start: int = Field(description="开始时间(ms)")
|
||||
end: int = Field(description="结束时间(ms)")
|
||||
class VAD_Functor_result(BaseModel):
|
||||
"""
|
||||
VADFunctor结果
|
||||
"""
|
||||
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
|
||||
audiobinary_index: int = Field(description="音频数据索引")
|
||||
audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data")
|
||||
start_time: int = Field(description="开始时间", is_required=True)
|
||||
end_time: int = Field(description="结束时间", is_required=True)
|
||||
|
||||
class VADResult(BaseModel):
|
||||
"""VAD结果"""
|
||||
key: str = Field(description="音频标识")
|
||||
value: List[VADSegment] = Field(description="VAD片段列表")
|
||||
@validator('audiobinary_data_list')
|
||||
def validate_audiobinary_data_list(cls, v):
|
||||
if not isinstance(v, AudioBinary_data_list):
|
||||
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
|
||||
return v
|
||||
|
||||
class VADResponse(BaseModel):
|
||||
"""VAD响应"""
|
||||
results: List[VADResult] = Field(description="VAD结果列表", default_factory=list)
|
||||
time_chunk: List[VADSegment] = Field(description="时间块", default_factory=list)
|
||||
time_chunk_index: int = Field(description="当前处理时间块索引", default=0)
|
||||
time_chunk_index_callback: Optional[Callable[[int], None]] = Field(
|
||||
description="时间块索引回调函数",
|
||||
default=None
|
||||
)
|
||||
|
||||
@validator('time_chunk')
|
||||
def validate_time_chunk(cls, v):
|
||||
"""验证时间块的有效性"""
|
||||
if not v:
|
||||
return v
|
||||
|
||||
# 检查时间顺序
|
||||
for i in range(len(v) - 1):
|
||||
if v[i].end >= v[i + 1].start:
|
||||
raise ValueError(f"时间块{i}的结束时间({v[i].end})大于等于下一个时间块的开始时间({v[i + 1].start})")
|
||||
@validator('audiobinary_index')
|
||||
def validate_audiobinary_index(cls, v):
|
||||
if not isinstance(v, int):
|
||||
raise ValueError("audiobinary_index必须是int类型")
|
||||
if v < 0:
|
||||
raise ValueError("audiobinary_index必须大于0")
|
||||
return v
|
||||
|
||||
@validator('audiobinary_data')
|
||||
def validate_audiobinary_data(cls, v):
|
||||
if not isinstance(v, _AudioBinary_data):
|
||||
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
|
||||
return v
|
||||
|
||||
@validator('start_time')
|
||||
def validate_start_time(cls, v):
|
||||
if not isinstance(v, int):
|
||||
raise ValueError("start_time必须是int类型")
|
||||
if v < 0:
|
||||
raise ValueError("start_time必须大于0")
|
||||
return v
|
||||
|
||||
# 回调未处理的时间块
|
||||
def process_time_chunk(self, callback: Callable[[int], None] = None) -> None:
|
||||
"""处理时间块"""
|
||||
# print("Enter process_time_chunk", self.time_chunk_index, len(self.time_chunk))
|
||||
while self.time_chunk_index < len(self.time_chunk) - 1:
|
||||
index = self.time_chunk_index
|
||||
if self.time_chunk[index].end != -1:
|
||||
x = {
|
||||
"start_time": self.time_chunk[index].start,
|
||||
"end_time": self.time_chunk[index].end
|
||||
}
|
||||
if callback is not None:
|
||||
callback(x)
|
||||
elif self.time_chunk_index_callback is not None:
|
||||
self.time_chunk_index_callback(x)
|
||||
else:
|
||||
print("[Warning] No callback available")
|
||||
self.time_chunk_index += 1
|
||||
@validator('end_time')
|
||||
def validate_end_time(cls, v, values):
|
||||
if not isinstance(v, int):
|
||||
raise ValueError("end_time必须是int类型")
|
||||
if 'start_time' in values and v <= values['start_time']:
|
||||
raise ValueError("end_time必须大于start_time")
|
||||
return v
|
||||
|
||||
def __add__(self, other: 'VADResponse') -> 'VADResponse':
|
||||
"""合并两个VADResponse"""
|
||||
if not self.results:
|
||||
self.results = other.results
|
||||
self.time_chunk = other.time_chunk
|
||||
return self
|
||||
|
||||
# 检查是否可以合并最后一个结果
|
||||
last_result = self.results[-1]
|
||||
first_other = other.results[0]
|
||||
|
||||
if last_result.value[-1].end == first_other.value[0].start:
|
||||
# 合并相邻的时间段
|
||||
last_result.value[-1].end = first_other.value[0].end
|
||||
first_other.value.pop(0)
|
||||
|
||||
# 更新time_chunk
|
||||
self.time_chunk[-1].end = other.time_chunk[0].end
|
||||
other.time_chunk.pop(0)
|
||||
|
||||
# 添加剩余的结果
|
||||
if first_other.value:
|
||||
self.results.extend(other.results)
|
||||
self.time_chunk.extend(other.time_chunk)
|
||||
else:
|
||||
# 直接添加所有结果
|
||||
self.results.extend(other.results)
|
||||
self.time_chunk.extend(other.time_chunk)
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, raw_data: List[dict]) -> "VADResponse":
|
||||
def create_from_push_data(
|
||||
cls,
|
||||
audiobinary_data_list: AudioBinary_data_list,
|
||||
data: Any,
|
||||
start_time: int,
|
||||
end_time: int
|
||||
):
|
||||
"""
|
||||
从原始数据创建VADResponse
|
||||
|
||||
参数:
|
||||
raw_data: 原始数据,格式如 [{'key': 'xxx', 'value': [[-1, 59540], [59820, -1]]}]
|
||||
|
||||
返回:
|
||||
VADResponse: 解析后的VAD响应
|
||||
创建VAD片段
|
||||
"""
|
||||
results = []
|
||||
time_chunk = []
|
||||
for item in raw_data:
|
||||
segments = [
|
||||
VADSegment(start=seg[0], end=seg[1])
|
||||
for seg in item['value']
|
||||
]
|
||||
results.append(VADResult(
|
||||
key=item['key'],
|
||||
value=segments
|
||||
))
|
||||
time_chunk.extend(segments)
|
||||
return cls(results=results, time_chunk=time_chunk)
|
||||
|
||||
def to_raw(self) -> List[dict]:
|
||||
index = audiobinary_data_list.push_data(data)
|
||||
|
||||
return cls(
|
||||
audiobinary_data_list=audiobinary_data_list,
|
||||
audiobinary_index=index,
|
||||
audiobinary_data=audiobinary_data_list[index],
|
||||
start_time=start_time,
|
||||
end_time=end_time)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
转换为原始数据格式
|
||||
|
||||
返回:
|
||||
List[dict]: 原始数据格式
|
||||
获取音频数据长度
|
||||
"""
|
||||
return [
|
||||
{
|
||||
'key': result.key,
|
||||
'value': [[seg.start, seg.end] for seg in result.value]
|
||||
}
|
||||
for result in self.results
|
||||
]
|
||||
return len(self.audiobinary_data.binary_data)
|
||||
|
||||
def __str__(self):
|
||||
result_str = "VADResponse:\n"
|
||||
for result in self.results:
|
||||
for value_item in result.value:
|
||||
result_str += f"[{value_item.start}:{value_item.end}]\n"
|
||||
return result_str
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.time_chunk)
|
||||
"""
|
||||
字符串展示内容
|
||||
"""
|
||||
tostr = f'audiobinary_data_index: {self.audiobinary_index}\n'
|
||||
tostr += f'start_time: {self.start_time}\n'
|
||||
tostr += f'end_time: {self.end_time}\n'
|
||||
tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n'
|
||||
tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n'
|
||||
return tostr
|
||||
|
||||
def __next__(self):
|
||||
return next(self.time_chunk)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.time_chunk)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.time_chunk[index]
|
||||
|
@ -16,7 +16,7 @@ class ASRPipeline(PipelineBase):
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._config: Dict[str, Any] = {}
|
||||
self._funtor_dict: Dict[str, Any] = {}
|
||||
self._functor_dict: Dict[str, Any] = {}
|
||||
self._subqueue_dict: Dict[str, Any] = {}
|
||||
self._is_baked: bool = False
|
||||
|
||||
@ -46,28 +46,28 @@ class ASRPipeline(PipelineBase):
|
||||
"""
|
||||
烘焙管道
|
||||
"""
|
||||
self._init_funtor()
|
||||
self._init_functor()
|
||||
self._is_baked = True
|
||||
|
||||
def _init_funtor(self) -> None:
|
||||
def _init_functor(self) -> None:
|
||||
"""
|
||||
初始化函数
|
||||
"""
|
||||
try:
|
||||
from src.funtor import FuntorFactory
|
||||
# 加载VAD、asr、spk funtor
|
||||
self._funtor_dict["vad"] = FuntorFactory.make_funtor(
|
||||
funtor_name = "vad",
|
||||
from src.functor import functorFactory
|
||||
# 加载VAD、asr、spk functor
|
||||
self._functor_dict["vad"] = functorFactory.make_functor(
|
||||
functor_name = "vad",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
self._funtor_dict["asr"] = FuntorFactory.make_funtor(
|
||||
funtor_name = "asr",
|
||||
self._functor_dict["asr"] = functorFactory.make_functor(
|
||||
functor_name = "asr",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
self._funtor_dict["spk"] = FuntorFactory.make_funtor(
|
||||
funtor_name = "spk",
|
||||
self._functor_dict["spk"] = functorFactory.make_functor(
|
||||
functor_name = "spk",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
@ -79,13 +79,13 @@ class ASRPipeline(PipelineBase):
|
||||
self._subqueue_dict["spkend"] = Queue()
|
||||
|
||||
# 设置子队列的输入队列
|
||||
self._funtor_dict["vad"].set_input_queue(self._input_queue)
|
||||
self._funtor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||
self._funtor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||
self._functor_dict["vad"].set_input_queue(self._input_queue)
|
||||
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||
|
||||
# 设置回调函数——放置到对应队列中
|
||||
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
|
||||
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
|
||||
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
|
||||
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
|
||||
|
||||
# 构造带回调函数的put
|
||||
def put_with_checkcallback(queue: Queue, callback: Callable) -> None:
|
||||
@ -97,11 +97,11 @@ class ASRPipeline(PipelineBase):
|
||||
callback()
|
||||
return put_with_check
|
||||
|
||||
self._funtor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
|
||||
self._funtor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
|
||||
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
|
||||
self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
|
||||
|
||||
except ImportError:
|
||||
raise ImportError("FuntorFactory引入失败,ASRPipeline无法完成初始化")
|
||||
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -129,10 +129,10 @@ class ASRPipeline(PipelineBase):
|
||||
if not self._is_baked:
|
||||
raise RuntimeError("管道未烘焙,无法运行")
|
||||
|
||||
# 运行所有funtor
|
||||
for funtor_name, funtor in self._funtor_dict.items():
|
||||
logger.info(f"运行{funtor_name}funtor")
|
||||
funtor.run()
|
||||
# 运行所有functor
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
logger.info(f"运行{functor_name}functor")
|
||||
self._functor_dict[functor_name].run()
|
||||
|
||||
# 运行管道
|
||||
if not self._input_queue:
|
||||
@ -152,6 +152,7 @@ class ASRPipeline(PipelineBase):
|
||||
# 检查是否是结束信号
|
||||
if data is None:
|
||||
logger.info("收到结束信号,管道准备停止")
|
||||
self._stop()
|
||||
self._input_queue.task_done() # 标记结束信号已处理
|
||||
break
|
||||
|
||||
@ -187,3 +188,13 @@ class ASRPipeline(PipelineBase):
|
||||
# 通知回调函数
|
||||
self._notify_callbacks(result)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
停止管道
|
||||
"""
|
||||
self._is_running = False
|
||||
self._stop_event = True
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
logger.info(f"停止{functor_name}functor")
|
||||
functor.stop()
|
||||
logger.info("子Functor停止")
|
||||
|
@ -5,10 +5,15 @@ VAD测试
|
||||
from src.functor.vad_functor import VADFunctor
|
||||
from queue import Queue, Empty
|
||||
from src.model_loader import ModelLoader
|
||||
from src.models import AudioBinary_Config
|
||||
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||
from src.utils.data_format import wav_to_bytes
|
||||
import time
|
||||
from src.utils.logger import get_module_logger
|
||||
from pydub import AudioSegment
|
||||
import soundfile
|
||||
|
||||
# 观察参数
|
||||
OVERWATCH = False
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
@ -22,37 +27,54 @@ def test_vad_functor():
|
||||
"auto_update": False,
|
||||
}
|
||||
model_loader.load_models(args)
|
||||
# 创建VAD函数器
|
||||
vad_functor = VADFunctor()
|
||||
# 加载数据
|
||||
f_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||
audio_config = AudioBinary_Config(
|
||||
chunk_size=200,
|
||||
chunk_stride=1600,
|
||||
sample_rate=sample_rate,
|
||||
sample_width=16,
|
||||
channels=1
|
||||
)
|
||||
chunk_stride = int(audio_config.chunk_size*sample_rate/1000)
|
||||
audio_config.chunk_stride = chunk_stride
|
||||
# 创建输入队列
|
||||
input_queue = Queue()
|
||||
# 创建音频数据列表
|
||||
audio_binary_data_list = AudioBinary_data_list()
|
||||
|
||||
# 创建VAD函数器
|
||||
vad_functor = VADFunctor()
|
||||
# 设置输入队列
|
||||
vad_functor.set_input_queue(input_queue)
|
||||
# 设置音频配置
|
||||
vad_functor.set_audio_config(AudioBinary_Config(
|
||||
chunk_size=960,
|
||||
chunk_stride=480,
|
||||
sample_rate=16000,
|
||||
sample_width=2,
|
||||
channels=1
|
||||
))
|
||||
vad_functor.set_audio_config(audio_config)
|
||||
# 设置音频数据列表
|
||||
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
|
||||
# 设置回调函数
|
||||
vad_functor.add_callback(lambda x: print(x))
|
||||
vad_functor.add_callback(lambda x: print(f"callback: {x}"))
|
||||
# 设置模型
|
||||
vad_functor.set_model({
|
||||
'vad': model_loader.models['vad']
|
||||
})
|
||||
|
||||
# 启动VAD函数器
|
||||
vad_functor.run()
|
||||
|
||||
# 加载数据
|
||||
f_binary = wav_to_bytes("tests/vad_example.wav")
|
||||
chunk_size = 960000
|
||||
# chunk_size = len(f_binary)
|
||||
print(f"f_binary: {len(f_binary)}, chunk_size: {chunk_size}, clip_num: {len(f_binary) // chunk_size}")
|
||||
for i in range(0, len(f_binary), chunk_size):
|
||||
binary_data = f_binary[i:i+chunk_size]
|
||||
f_binary = f_data
|
||||
audio_clip_len = 200
|
||||
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
|
||||
for i in range(0, len(f_binary), audio_clip_len):
|
||||
binary_data = f_binary[i:i+audio_clip_len]
|
||||
input_queue.put(binary_data)
|
||||
# 等待VAD函数器结束
|
||||
time.sleep(10)
|
||||
vad_functor.stop()
|
||||
print("[vad_test] 等待input_queue为空")
|
||||
input_queue.join()
|
||||
print("[vad_test] input_queue为空")
|
||||
vad_functor.stop()
|
||||
print("[vad_test] VAD函数器结束")
|
||||
|
||||
# 保存音频数据
|
||||
if OVERWATCH:
|
||||
for index in range(len(audio_binary_data_list)):
|
||||
save_path = f"tests/vad_test_output_{index}.wav"
|
||||
soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate)
|
||||
|
Loading…
x
Reference in New Issue
Block a user