[代码重构中]完善VADFunctor,测试持久化保存VAD片段的音频数据成功。

This commit is contained in:
Ziyang.Zhang 2025-06-05 13:43:23 +08:00
parent b569b7e63d
commit 4e9e94d8dc
7 changed files with 430 additions and 213 deletions

34
main.py Normal file
View File

@ -0,0 +1,34 @@
from funasr import AutoModel
chunk_size = 200 # ms
model = AutoModel(
model="fsmn-vad",
model_revision="v2.0.4",
disable_update=True
)
import soundfile
wav_file = "tests/vad_example.wav"
speech, sample_rate = soundfile.read(wav_file)
chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(
input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
disable_pbar=True
)
if len(res[0]["value"]):
print(res)
print(f"len(speech): {len(speech)}")
print(f"len(speech_chunk): {len(speech_chunk)}")
print(f"total_chunk_num: {total_chunk_num}")
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")

View File

@ -1,13 +1,11 @@
from funasr import AutoModel from funasr import AutoModel
from typing import List, Dict, Any from typing import List, Dict, Any, Callable
from src.models import VADResponse from src.models import VAD_Functor_result, _AudioBinary_data, AudioBinary_Config, AudioBinary_data_list
from src.models import AudioBinary_Config
from src.models import AudioBinary_data_list
from src.models import AudioBinary_Slice
from typing import Callable from typing import Callable
from src.functor.base import BaseFunctor from src.functor.base import BaseFunctor
import threading import threading
from queue import Empty, Queue from queue import Empty, Queue
import numpy
# 日志 # 日志
from src.utils.logger import get_module_logger from src.utils.logger import get_module_logger
@ -20,15 +18,37 @@ class VADFunctor(BaseFunctor):
self self
): ):
super().__init__() super().__init__()
self._model: dict = {} # 资源与配置
self._callback: List[Callable] = [] self._model: dict = {} # 模型
self._status_lock: threading.Lock = threading.Lock() self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None self._audio_config: AudioBinary_Config = None # 音频配置
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
# flag
# 此处用到两个锁但都是为了截断_run线程考虑后续优化
self._is_running: bool = False self._is_running: bool = False
self._stop_event: bool = False self._stop_event: bool = False
self._audio_cache: bytes = b''
self._cache: dict = {} # 状态锁
self._status_lock: threading.Lock = threading.Lock()
# 缓存
self._audio_cache: numpy.ndarray = None
self._audio_cache_preindex: int = 0
self._model_cache: dict = {}
self._cache_result_list = []
self._audiobinary_cache = None
def reset_cache(self):
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
self._audio_cache = None
self._audio_cache_preindex = 0
self._model_cache = {}
self._cache_result_list = []
self._audiobinary_cache = None
def set_input_queue(self, queue: Queue): def set_input_queue(self, queue: Queue):
self._input_queue = queue self._input_queue = queue
@ -38,32 +58,99 @@ class VADFunctor(BaseFunctor):
def set_audio_config(self, audio_config: AudioBinary_Config): def set_audio_config(self, audio_config: AudioBinary_Config):
self._audio_config = audio_config self._audio_config = audio_config
logger.info(f"VADFunctor设置音频配置: {self._audio_config}")
def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list):
self._audio_binary_data_list = audio_binary_data_list
def add_callback(self, callback: Callable): def add_callback(self, callback: Callable):
if not isinstance(self._callback, list): if not isinstance(self._callback, list):
self._callback = [] self._callback = []
self._callback.append(callback) self._callback.append(callback)
def _process(self, data: bytes): def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list):
"""
回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
输入:
result: List[[start,end]] 处理所得VAD端点
其中若start==-1, 则表示前无端点, 若end==-1, 则表示后无端点
当处理得到一个完成片段时, 存入AudioBinary中, 并向队列中添加AudioBinary_Slice
输出:
None
"""
# 持久化缓存结果队列
for pair in result:
[start, end] = pair
# 若无前端点, 则向缓存队列中合并
if start == -1:
self._cache_result_list[-1][1] = end
else:
self._cache_result_list.append(pair)
while len(self._cache_result_list) > 1:
# 创建VAD片段
# 计算开始帧
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
start_frame -= self._audio_cache_preindex
# 计算结束帧
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
end_frame -= self._audio_cache_preindex
# 计算开始时间
vad_result = VAD_Functor_result.create_from_push_data(
audiobinary_data_list=self._audio_binary_data_list,
data=self._audiobinary_cache[start_frame:end_frame],
start_time=self._cache_result_list[0][0],
end_time=self._cache_result_list[0][1]
)
self._audio_cache_preindex += end_frame
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
for callback in self._callback:
callback(vad_result)
self._cache_result_list.pop(0)
def _predeal_data(self, data: Any):
"""
预处理数据
"""
if self._audio_cache is None:
self._audio_cache = data
else:
# 拼接音频数据
if isinstance(self._audio_cache, numpy.ndarray):
self._audio_cache = numpy.concatenate((self._audio_cache, data))
elif isinstance(self._audio_cache, list):
self._audio_cache.append(data)
if self._audiobinary_cache is None:
self._audiobinary_cache = data
else:
# 拼接音频数据
if isinstance(self._audiobinary_cache, numpy.ndarray):
self._audiobinary_cache = numpy.concatenate((self._audiobinary_cache, data))
elif isinstance(self._audiobinary_cache, list):
self._audiobinary_cache.append(data)
def _process(self, data: Any):
""" """
处理数据 处理数据
""" """
self._audio_cache += data self._predeal_data(data)
if len(self._audio_cache) >= self._audio_config.chunk_size*100: if len(self._audio_cache) >= self._audio_config.chunk_stride:
result = self._model['vad'].generate( result = self._model['vad'].generate(
input=self._audio_cache, input=self._audio_cache,
cache=self._cache, cache=self._model_cache,
chunk_size=self._audio_config.chunk_size, chunk_size=self._audio_config.chunk_size,
is_final=False, is_final=False,
) )
logger.info(f"VADFunctor处理数据: {len(self._audio_cache)}, {result}") if (len(result[0]['value']) > 0):
self._audio_cache = b'' self._do_callback(result[0]['value'], self._audio_cache)
logger.debug(f"VADFunctor结果: {result[0]['value']}")
self._audio_cache = None
def _run(self): def _run(self):
""" """
线程运行逻辑 线程运行逻辑
监听输入队列当有数据时处理数据 监听输入队列, 当有数据时, 处理数据
当输入队列为空时, 间隔1s检测是否进入停止事件 当输入队列为空时, 间隔1s检测是否进入停止事件
""" """
# 刷新运行状态 # 刷新运行状态
@ -76,7 +163,7 @@ class VADFunctor(BaseFunctor):
data = self._input_queue.get(True, timeout=1) data = self._input_queue.get(True, timeout=1)
self._process(data) self._process(data)
self._input_queue.task_done() self._input_queue.task_done()
# 当队列为空时间隔1s检测是否进入停止事件。 # 当队列为空时, 间隔1s检测是否进入停止事件。
except Empty: except Empty:
if self._stop_event: if self._stop_event:
break break
@ -90,12 +177,26 @@ class VADFunctor(BaseFunctor):
""" """
启动 _run 线程, 并返回线程对象 启动 _run 线程, 并返回线程对象
""" """
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True) self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start() self._thread.start()
return self._thread return self._thread
def _pre_check(self): def _pre_check(self) -> bool:
pass """
检测硬性资源是否都已设置
"""
if self._model is None:
raise ValueError("模型未设置")
if self._audio_config is None:
raise ValueError("音频配置未设置")
if self._audio_binary_data_list is None:
raise ValueError("音频数据列表未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self): def stop(self):
with self._status_lock: with self._status_lock:

View File

@ -1,3 +1,3 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, AudioBinary_Slice from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VADResponse from .vad import VAD_Functor_result
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "AudioBinary_Slice", "VADResponse"] __all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"]

View File

@ -1,9 +1,19 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, validator
from typing import List from typing import List, Any
import numpy
from src.utils import get_module_logger
logger = get_module_logger(__name__)
binary_data_types = (bytes, numpy.ndarray)
class AudioBinary_Config(BaseModel): class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息""" """二进制音频块配置信息"""
audio_data: bytes = Field(description="音频数据", default=None) class Config:
arbitrary_types_allowed = True
audio_data: binary_data_types = Field(description="音频数据", default=None)
chunk_size: int = Field(description="块大小", default=100) chunk_size: int = Field(description="块大小", default=100)
chunk_stride: int = Field(description="块步长", default=1600) chunk_stride: int = Field(description="块步长", default=1600)
sample_rate: int = Field(description="采样率", default=16000) sample_rate: int = Field(description="采样率", default=16000)
@ -15,23 +25,117 @@ class AudioBinary_Config(BaseModel):
def AudioBinary_Config_from_dict(cls, data: dict): def AudioBinary_Config_from_dict(cls, data: dict):
return cls(**data) return cls(**data)
def ms2frame(self, ms: int) -> int:
"""
将毫秒转换为帧
"""
return int(ms * self.sample_rate / 1000)
def frame2ms(self, frame: int) -> int:
"""
将帧转换为毫秒
"""
return int(frame * 1000 / self.sample_rate)
class _AudioBinary_data(BaseModel): class _AudioBinary_data(BaseModel):
"""音频二进制数据""" """音频数据"""
binary_data: bytes = Field(description="音频二进制数据", default=None) binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
class Config:
arbitrary_types_allowed = True
@validator('binary_data')
def validate_binary_data(cls, v):
"""
验证音频数据
Args:
v: 音频数据
Returns:
binary_data_types: 音频数据
"""
if not isinstance(v, (bytes, numpy.ndarray)):
logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v))
return v
def __len__(self):
"""
获取音频数据长度
Returns:
int: 音频数据长度
"""
return len(self.binary_data)
def __init__(self, binary_data: binary_data_types):
"""
初始化音频数据
Args:
binary_data: 音频数据
"""
logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data))
super().__init__(binary_data=binary_data)
def __getitem__(self):
"""
当获取数据时, 直接返回binary_data
Returns:
binary_data_types: 音频数据
"""
return self.binary_data
class AudioBinary_data_list(BaseModel): class AudioBinary_data_list(BaseModel):
"""音频二进制数据列表""" """音频数据列表"""
binary_data_list: List[_AudioBinary_data] = Field(description="音频二进制数据列表", default=[]) binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[])
def __call__(self): class Config:
return self.binary_data_list arbitrary_types_allowed = True
class AudioBinary_Slice(BaseModel): def push_data(self, data: binary_data_types) -> int:
"""音频块切片""" """
target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None) 添加音频数据
start_index: int = Field(description="开始位置", default=0) Args:
end_index: int = Field(description="结束位置", default=0) data: 音频数据
Returns:
int: 数据在binary_data_list中的索引
"""
self.binary_data_list.append(_AudioBinary_data(binary_data=data))
return len(self.binary_data_list) - 1
def __call__(self): def __getitem__(self, index: int):
return self.target_Binary(self.start_index, self.end_index) """
获取音频数据
Args:
index: 音频数据在binary_data_list中的索引
Returns:
_AudioBinary_data: 音频数据
"""
return self.binary_data_list[index]
def __len__(self):
"""
获取音频数据列表长度
Returns:
int: 音频数据列表长度
"""
return len(self.binary_data_list)
# class AudioBinary_Slice(BaseModel):
# """音频块切片"""
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
# start_index: int = Field(description="开始位置", default=0)
# end_index: int = Field(description="结束位置", default=-1)
# @validator('start_index')
# def validate_start_index(cls, v):
# if v < 0:
# raise ValueError("start_index必须大于0")
# return v
# @validator('end_index')
# def validate_end_index(cls, v):
# if v < cls.start_index:
# logger.debug("[%s]end_index小于start_index, 将end_index设置为start_index", cls.__class__.__name__)
# v = cls.start_index
# return v
# def __call__(self):
# return self.target_Binary(self.start_index, self.end_index)

View File

@ -1,143 +1,88 @@
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable from typing import List, Optional, Callable, Any
from .audio import AudioBinary_data_list, _AudioBinary_data
class VADSegment(BaseModel): class VAD_Functor_result(BaseModel):
"""VAD片段""" """
start: int = Field(description="开始时间(ms)") VADFunctor结果
end: int = Field(description="结束时间(ms)") """
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
audiobinary_index: int = Field(description="音频数据索引")
audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data")
start_time: int = Field(description="开始时间", is_required=True)
end_time: int = Field(description="结束时间", is_required=True)
class VADResult(BaseModel): @validator('audiobinary_data_list')
"""VAD结果""" def validate_audiobinary_data_list(cls, v):
key: str = Field(description="音频标识") if not isinstance(v, AudioBinary_data_list):
value: List[VADSegment] = Field(description="VAD片段列表") raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
return v
class VADResponse(BaseModel): @validator('audiobinary_index')
"""VAD响应""" def validate_audiobinary_index(cls, v):
results: List[VADResult] = Field(description="VAD结果列表", default_factory=list) if not isinstance(v, int):
time_chunk: List[VADSegment] = Field(description="时间块", default_factory=list) raise ValueError("audiobinary_index必须是int类型")
time_chunk_index: int = Field(description="当前处理时间块索引", default=0) if v < 0:
time_chunk_index_callback: Optional[Callable[[int], None]] = Field( raise ValueError("audiobinary_index必须大于0")
description="时间块索引回调函数", return v
default=None
) @validator('audiobinary_data')
def validate_audiobinary_data(cls, v):
@validator('time_chunk') if not isinstance(v, _AudioBinary_data):
def validate_time_chunk(cls, v): raise ValueError("audiobinary_data必须是AudioBinary_data类型")
"""验证时间块的有效性""" return v
if not v:
return v @validator('start_time')
def validate_start_time(cls, v):
# 检查时间顺序 if not isinstance(v, int):
for i in range(len(v) - 1): raise ValueError("start_time必须是int类型")
if v[i].end >= v[i + 1].start: if v < 0:
raise ValueError(f"时间块{i}的结束时间({v[i].end})大于等于下一个时间块的开始时间({v[i + 1].start})") raise ValueError("start_time必须大于0")
return v return v
# 回调未处理的时间块 @validator('end_time')
def process_time_chunk(self, callback: Callable[[int], None] = None) -> None: def validate_end_time(cls, v, values):
"""处理时间块""" if not isinstance(v, int):
# print("Enter process_time_chunk", self.time_chunk_index, len(self.time_chunk)) raise ValueError("end_time必须是int类型")
while self.time_chunk_index < len(self.time_chunk) - 1: if 'start_time' in values and v <= values['start_time']:
index = self.time_chunk_index raise ValueError("end_time必须大于start_time")
if self.time_chunk[index].end != -1: return v
x = {
"start_time": self.time_chunk[index].start,
"end_time": self.time_chunk[index].end
}
if callback is not None:
callback(x)
elif self.time_chunk_index_callback is not None:
self.time_chunk_index_callback(x)
else:
print("[Warning] No callback available")
self.time_chunk_index += 1
def __add__(self, other: 'VADResponse') -> 'VADResponse':
"""合并两个VADResponse"""
if not self.results:
self.results = other.results
self.time_chunk = other.time_chunk
return self
# 检查是否可以合并最后一个结果
last_result = self.results[-1]
first_other = other.results[0]
if last_result.value[-1].end == first_other.value[0].start:
# 合并相邻的时间段
last_result.value[-1].end = first_other.value[0].end
first_other.value.pop(0)
# 更新time_chunk
self.time_chunk[-1].end = other.time_chunk[0].end
other.time_chunk.pop(0)
# 添加剩余的结果
if first_other.value:
self.results.extend(other.results)
self.time_chunk.extend(other.time_chunk)
else:
# 直接添加所有结果
self.results.extend(other.results)
self.time_chunk.extend(other.time_chunk)
return self
@classmethod @classmethod
def from_raw(cls, raw_data: List[dict]) -> "VADResponse": def create_from_push_data(
cls,
audiobinary_data_list: AudioBinary_data_list,
data: Any,
start_time: int,
end_time: int
):
""" """
从原始数据创建VADResponse 创建VAD片段
参数:
raw_data: 原始数据格式如 [{'key': 'xxx', 'value': [[-1, 59540], [59820, -1]]}]
返回:
VADResponse: 解析后的VAD响应
""" """
results = [] index = audiobinary_data_list.push_data(data)
time_chunk = []
for item in raw_data: return cls(
segments = [ audiobinary_data_list=audiobinary_data_list,
VADSegment(start=seg[0], end=seg[1]) audiobinary_index=index,
for seg in item['value'] audiobinary_data=audiobinary_data_list[index],
] start_time=start_time,
results.append(VADResult( end_time=end_time)
key=item['key'],
value=segments def __len__(self):
))
time_chunk.extend(segments)
return cls(results=results, time_chunk=time_chunk)
def to_raw(self) -> List[dict]:
""" """
转换为原始数据格式 获取音频数据长度
返回:
List[dict]: 原始数据格式
""" """
return [ return len(self.audiobinary_data.binary_data)
{
'key': result.key,
'value': [[seg.start, seg.end] for seg in result.value]
}
for result in self.results
]
def __str__(self): def __str__(self):
result_str = "VADResponse:\n" """
for result in self.results: 字符串展示内容
for value_item in result.value: """
result_str += f"[{value_item.start}:{value_item.end}]\n" tostr = f'audiobinary_data_index: {self.audiobinary_index}\n'
return result_str tostr += f'start_time: {self.start_time}\n'
tostr += f'end_time: {self.end_time}\n'
def __iter__(self): tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n'
return iter(self.time_chunk) tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n'
return tostr
def __next__(self):
return next(self.time_chunk)
def __len__(self):
return len(self.time_chunk)
def __getitem__(self, index):
return self.time_chunk[index]

View File

@ -16,7 +16,7 @@ class ASRPipeline(PipelineBase):
""" """
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._config: Dict[str, Any] = {} self._config: Dict[str, Any] = {}
self._funtor_dict: Dict[str, Any] = {} self._functor_dict: Dict[str, Any] = {}
self._subqueue_dict: Dict[str, Any] = {} self._subqueue_dict: Dict[str, Any] = {}
self._is_baked: bool = False self._is_baked: bool = False
@ -46,28 +46,28 @@ class ASRPipeline(PipelineBase):
""" """
烘焙管道 烘焙管道
""" """
self._init_funtor() self._init_functor()
self._is_baked = True self._is_baked = True
def _init_funtor(self) -> None: def _init_functor(self) -> None:
""" """
初始化函数 初始化函数
""" """
try: try:
from src.funtor import FuntorFactory from src.functor import functorFactory
# 加载VAD、asr、spk funtor # 加载VAD、asr、spk functor
self._funtor_dict["vad"] = FuntorFactory.make_funtor( self._functor_dict["vad"] = functorFactory.make_functor(
funtor_name = "vad", functor_name = "vad",
config = self._config, config = self._config,
models = self._models models = self._models
) )
self._funtor_dict["asr"] = FuntorFactory.make_funtor( self._functor_dict["asr"] = functorFactory.make_functor(
funtor_name = "asr", functor_name = "asr",
config = self._config, config = self._config,
models = self._models models = self._models
) )
self._funtor_dict["spk"] = FuntorFactory.make_funtor( self._functor_dict["spk"] = functorFactory.make_functor(
funtor_name = "spk", functor_name = "spk",
config = self._config, config = self._config,
models = self._models models = self._models
) )
@ -79,13 +79,13 @@ class ASRPipeline(PipelineBase):
self._subqueue_dict["spkend"] = Queue() self._subqueue_dict["spkend"] = Queue()
# 设置子队列的输入队列 # 设置子队列的输入队列
self._funtor_dict["vad"].set_input_queue(self._input_queue) self._functor_dict["vad"].set_input_queue(self._input_queue)
self._funtor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._funtor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
# 设置回调函数——放置到对应队列中 # 设置回调函数——放置到对应队列中
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put) self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put) self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
# 构造带回调函数的put # 构造带回调函数的put
def put_with_checkcallback(queue: Queue, callback: Callable) -> None: def put_with_checkcallback(queue: Queue, callback: Callable) -> None:
@ -97,11 +97,11 @@ class ASRPipeline(PipelineBase):
callback() callback()
return put_with_check return put_with_check
self._funtor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
self._funtor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result)) self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
except ImportError: except ImportError:
raise ImportError("FuntorFactory引入失败,ASRPipeline无法完成初始化") raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
def get_config(self) -> Dict[str, Any]: def get_config(self) -> Dict[str, Any]:
""" """
@ -129,10 +129,10 @@ class ASRPipeline(PipelineBase):
if not self._is_baked: if not self._is_baked:
raise RuntimeError("管道未烘焙,无法运行") raise RuntimeError("管道未烘焙,无法运行")
# 运行所有funtor # 运行所有functor
for funtor_name, funtor in self._funtor_dict.items(): for functor_name, functor in self._functor_dict.items():
logger.info(f"运行{funtor_name}funtor") logger.info(f"运行{functor_name}functor")
funtor.run() self._functor_dict[functor_name].run()
# 运行管道 # 运行管道
if not self._input_queue: if not self._input_queue:
@ -152,6 +152,7 @@ class ASRPipeline(PipelineBase):
# 检查是否是结束信号 # 检查是否是结束信号
if data is None: if data is None:
logger.info("收到结束信号,管道准备停止") logger.info("收到结束信号,管道准备停止")
self._stop()
self._input_queue.task_done() # 标记结束信号已处理 self._input_queue.task_done() # 标记结束信号已处理
break break
@ -187,3 +188,13 @@ class ASRPipeline(PipelineBase):
# 通知回调函数 # 通知回调函数
self._notify_callbacks(result) self._notify_callbacks(result)
def stop(self) -> None:
"""
停止管道
"""
self._is_running = False
self._stop_event = True
for functor_name, functor in self._functor_dict.items():
logger.info(f"停止{functor_name}functor")
functor.stop()
logger.info("子Functor停止")

View File

@ -5,10 +5,15 @@ VAD测试
from src.functor.vad_functor import VADFunctor from src.functor.vad_functor import VADFunctor
from queue import Queue, Empty from queue import Queue, Empty
from src.model_loader import ModelLoader from src.model_loader import ModelLoader
from src.models import AudioBinary_Config from src.models import AudioBinary_Config, AudioBinary_data_list
from src.utils.data_format import wav_to_bytes from src.utils.data_format import wav_to_bytes
import time import time
from src.utils.logger import get_module_logger from src.utils.logger import get_module_logger
from pydub import AudioSegment
import soundfile
# 观察参数
OVERWATCH = False
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
@ -22,37 +27,54 @@ def test_vad_functor():
"auto_update": False, "auto_update": False,
} }
model_loader.load_models(args) model_loader.load_models(args)
# 创建VAD函数器 # 加载数据
vad_functor = VADFunctor() f_data, sample_rate = soundfile.read("tests/vad_example.wav")
audio_config = AudioBinary_Config(
chunk_size=200,
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1
)
chunk_stride = int(audio_config.chunk_size*sample_rate/1000)
audio_config.chunk_stride = chunk_stride
# 创建输入队列 # 创建输入队列
input_queue = Queue() input_queue = Queue()
# 创建音频数据列表
audio_binary_data_list = AudioBinary_data_list()
# 创建VAD函数器
vad_functor = VADFunctor()
# 设置输入队列 # 设置输入队列
vad_functor.set_input_queue(input_queue) vad_functor.set_input_queue(input_queue)
# 设置音频配置 # 设置音频配置
vad_functor.set_audio_config(AudioBinary_Config( vad_functor.set_audio_config(audio_config)
chunk_size=960, # 设置音频数据列表
chunk_stride=480, vad_functor.set_audio_binary_data_list(audio_binary_data_list)
sample_rate=16000,
sample_width=2,
channels=1
))
# 设置回调函数 # 设置回调函数
vad_functor.add_callback(lambda x: print(x)) vad_functor.add_callback(lambda x: print(f"callback: {x}"))
# 设置模型 # 设置模型
vad_functor.set_model({ vad_functor.set_model({
'vad': model_loader.models['vad'] 'vad': model_loader.models['vad']
}) })
# 启动VAD函数器 # 启动VAD函数器
vad_functor.run() vad_functor.run()
f_binary = f_data
# 加载数据 audio_clip_len = 200
f_binary = wav_to_bytes("tests/vad_example.wav") print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
chunk_size = 960000 for i in range(0, len(f_binary), audio_clip_len):
# chunk_size = len(f_binary) binary_data = f_binary[i:i+audio_clip_len]
print(f"f_binary: {len(f_binary)}, chunk_size: {chunk_size}, clip_num: {len(f_binary) // chunk_size}")
for i in range(0, len(f_binary), chunk_size):
binary_data = f_binary[i:i+chunk_size]
input_queue.put(binary_data) input_queue.put(binary_data)
# 等待VAD函数器结束 # 等待VAD函数器结束
time.sleep(10) print("[vad_test] 等待input_queue为空")
vad_functor.stop() input_queue.join()
print("[vad_test] input_queue为空")
vad_functor.stop()
print("[vad_test] VAD函数器结束")
# 保存音频数据
if OVERWATCH:
for index in range(len(audio_binary_data_list)):
save_path = f"tests/vad_test_output_{index}.wav"
soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate)