From 4e9e94d8dc2fee1de9c95799e2282e8f058dfec9 Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Thu, 5 Jun 2025 13:43:23 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BB=A3=E7=A0=81=E9=87=8D=E6=9E=84=E4=B8=AD]?= =?UTF-8?q?=E5=AE=8C=E5=96=84VADFunctor=EF=BC=8C=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=8C=81=E4=B9=85=E5=8C=96=E4=BF=9D=E5=AD=98VAD=E7=89=87?= =?UTF-8?q?=E6=AE=B5=E7=9A=84=E9=9F=B3=E9=A2=91=E6=95=B0=E6=8D=AE=E6=88=90?= =?UTF-8?q?=E5=8A=9F=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 34 ++++++ src/functor/vad_functor.py | 145 ++++++++++++++++++++++---- src/models/__init__.py | 6 +- src/models/audio.py | 136 +++++++++++++++++++++--- src/models/vad.py | 201 +++++++++++++----------------------- src/pipeline/ASRpipeline.py | 57 +++++----- tests/functor/vad_test.py | 64 ++++++++---- 7 files changed, 430 insertions(+), 213 deletions(-) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..03accbd --- /dev/null +++ b/main.py @@ -0,0 +1,34 @@ +from funasr import AutoModel + +chunk_size = 200 # ms +model = AutoModel( + model="fsmn-vad", + model_revision="v2.0.4", + disable_update=True +) + +import soundfile + +wav_file = "tests/vad_example.wav" +speech, sample_rate = soundfile.read(wav_file) +chunk_stride = int(chunk_size * sample_rate / 1000) + +cache = {} +total_chunk_num = int(len((speech)-1)/chunk_stride+1) +for i in range(total_chunk_num): + speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] + is_final = i == total_chunk_num - 1 + res = model.generate( + input=speech_chunk, + cache=cache, + is_final=is_final, + chunk_size=chunk_size, + disable_pbar=True + ) + if len(res[0]["value"]): + print(res) + +print(f"len(speech): {len(speech)}") +print(f"len(speech_chunk): {len(speech_chunk)}") +print(f"total_chunk_num: {total_chunk_num}") +print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}") \ No newline at end of file diff --git a/src/functor/vad_functor.py b/src/functor/vad_functor.py index f2dc868..c946f63 100644 --- a/src/functor/vad_functor.py +++ b/src/functor/vad_functor.py @@ -1,13 +1,11 @@ from funasr import AutoModel -from typing import List, Dict, Any -from src.models import VADResponse -from src.models import AudioBinary_Config -from src.models import AudioBinary_data_list -from src.models import AudioBinary_Slice +from typing import List, Dict, Any, Callable +from src.models import VAD_Functor_result, _AudioBinary_data, AudioBinary_Config, AudioBinary_data_list from typing import Callable from src.functor.base import BaseFunctor import threading from queue import Empty, Queue +import numpy # 日志 from src.utils.logger import get_module_logger @@ -20,15 +18,37 @@ class VADFunctor(BaseFunctor): self ): super().__init__() - self._model: dict = {} - self._callback: List[Callable] = [] - self._status_lock: threading.Lock = threading.Lock() - self._input_queue: Queue = None - self._audio_config: AudioBinary_Config = None + # 资源与配置 + self._model: dict = {} # 模型 + self._callback: List[Callable] = [] # 回调函数 + self._input_queue: Queue = None # 输入队列 + self._audio_config: AudioBinary_Config = None # 音频配置 + self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表 + + # flag + # 此处用到两个锁,但都是为了截断_run线程,考虑后续优化 self._is_running: bool = False self._stop_event: bool = False - self._audio_cache: bytes = b'' - self._cache: dict = {} + + # 状态锁 + self._status_lock: threading.Lock = threading.Lock() + + # 缓存 + self._audio_cache: numpy.ndarray = None + self._audio_cache_preindex: int = 0 + self._model_cache: dict = {} + self._cache_result_list = [] + self._audiobinary_cache = None + + def reset_cache(self): + """ + 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 + """ + self._audio_cache = None + self._audio_cache_preindex = 0 + self._model_cache = {} + self._cache_result_list = [] + self._audiobinary_cache = None def set_input_queue(self, queue: Queue): self._input_queue = queue @@ -38,32 +58,99 @@ class VADFunctor(BaseFunctor): def set_audio_config(self, audio_config: AudioBinary_Config): self._audio_config = audio_config + logger.info(f"VADFunctor设置音频配置: {self._audio_config}") + + def set_audio_binary_data_list(self, audio_binary_data_list: AudioBinary_data_list): + self._audio_binary_data_list = audio_binary_data_list def add_callback(self, callback: Callable): if not isinstance(self._callback, list): self._callback = [] self._callback.append(callback) - def _process(self, data: bytes): + def _do_callback(self, result: List[List[int]], audio_cache: AudioBinary_data_list): + """ + 回调函数 + VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice + + 输入: + result: List[[start,end]] 处理所得VAD端点 + 其中若start==-1, 则表示前无端点, 若end==-1, 则表示后无端点 + 当处理得到一个完成片段时, 存入AudioBinary中, 并向队列中添加AudioBinary_Slice + 输出: + None + """ + # 持久化缓存结果队列 + for pair in result: + [start, end] = pair + # 若无前端点, 则向缓存队列中合并 + if start == -1: + self._cache_result_list[-1][1] = end + else: + self._cache_result_list.append(pair) + while len(self._cache_result_list) > 1: + # 创建VAD片段 + # 计算开始帧 + start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0]) + start_frame -= self._audio_cache_preindex + # 计算结束帧 + end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1]) + end_frame -= self._audio_cache_preindex + # 计算开始时间 + vad_result = VAD_Functor_result.create_from_push_data( + audiobinary_data_list=self._audio_binary_data_list, + data=self._audiobinary_cache[start_frame:end_frame], + start_time=self._cache_result_list[0][0], + end_time=self._cache_result_list[0][1] + ) + self._audio_cache_preindex += end_frame + self._audiobinary_cache = self._audiobinary_cache[end_frame:] + for callback in self._callback: + callback(vad_result) + self._cache_result_list.pop(0) + + def _predeal_data(self, data: Any): + """ + 预处理数据 + """ + if self._audio_cache is None: + self._audio_cache = data + else: + # 拼接音频数据 + if isinstance(self._audio_cache, numpy.ndarray): + self._audio_cache = numpy.concatenate((self._audio_cache, data)) + elif isinstance(self._audio_cache, list): + self._audio_cache.append(data) + if self._audiobinary_cache is None: + self._audiobinary_cache = data + else: + # 拼接音频数据 + if isinstance(self._audiobinary_cache, numpy.ndarray): + self._audiobinary_cache = numpy.concatenate((self._audiobinary_cache, data)) + elif isinstance(self._audiobinary_cache, list): + self._audiobinary_cache.append(data) + def _process(self, data: Any): """ 处理数据 """ - self._audio_cache += data - if len(self._audio_cache) >= self._audio_config.chunk_size*100: + self._predeal_data(data) + if len(self._audio_cache) >= self._audio_config.chunk_stride: result = self._model['vad'].generate( input=self._audio_cache, - cache=self._cache, + cache=self._model_cache, chunk_size=self._audio_config.chunk_size, is_final=False, ) - logger.info(f"VADFunctor处理数据: {len(self._audio_cache)}, {result}") - self._audio_cache = b'' + if (len(result[0]['value']) > 0): + self._do_callback(result[0]['value'], self._audio_cache) + logger.debug(f"VADFunctor结果: {result[0]['value']}") + self._audio_cache = None def _run(self): """ 线程运行逻辑 - 监听输入队列,当有数据时,处理数据 + 监听输入队列, 当有数据时, 处理数据 当输入队列为空时, 间隔1s检测是否进入停止事件。 """ # 刷新运行状态 @@ -76,7 +163,7 @@ class VADFunctor(BaseFunctor): data = self._input_queue.get(True, timeout=1) self._process(data) self._input_queue.task_done() - # 当队列为空时,间隔1s检测是否进入停止事件。 + # 当队列为空时, 间隔1s检测是否进入停止事件。 except Empty: if self._stop_event: break @@ -90,12 +177,26 @@ class VADFunctor(BaseFunctor): """ 启动 _run 线程, 并返回线程对象 """ + self._pre_check() self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() return self._thread - def _pre_check(self): - pass + def _pre_check(self) -> bool: + """ + 检测硬性资源是否都已设置 + """ + if self._model is None: + raise ValueError("模型未设置") + if self._audio_config is None: + raise ValueError("音频配置未设置") + if self._audio_binary_data_list is None: + raise ValueError("音频数据列表未设置") + if self._input_queue is None: + raise ValueError("输入队列未设置") + if self._callback is None: + raise ValueError("回调函数未设置") + return True def stop(self): with self._status_lock: diff --git a/src/models/__init__.py b/src/models/__init__.py index fdbc7ec..28004e7 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,3 +1,3 @@ -from .audio import AudioBinary_Config, AudioBinary_data_list, AudioBinary_Slice -from .vad import VADResponse -__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "AudioBinary_Slice", "VADResponse"] \ No newline at end of file +from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data +from .vad import VAD_Functor_result +__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"] \ No newline at end of file diff --git a/src/models/audio.py b/src/models/audio.py index 83b187b..558e3f3 100644 --- a/src/models/audio.py +++ b/src/models/audio.py @@ -1,9 +1,19 @@ -from pydantic import BaseModel, Field -from typing import List +from pydantic import BaseModel, Field, validator +from typing import List, Any +import numpy + +from src.utils import get_module_logger + +logger = get_module_logger(__name__) + +binary_data_types = (bytes, numpy.ndarray) class AudioBinary_Config(BaseModel): """二进制音频块配置信息""" - audio_data: bytes = Field(description="音频数据", default=None) + class Config: + arbitrary_types_allowed = True + + audio_data: binary_data_types = Field(description="音频数据", default=None) chunk_size: int = Field(description="块大小", default=100) chunk_stride: int = Field(description="块步长", default=1600) sample_rate: int = Field(description="采样率", default=16000) @@ -15,23 +25,117 @@ class AudioBinary_Config(BaseModel): def AudioBinary_Config_from_dict(cls, data: dict): return cls(**data) + def ms2frame(self, ms: int) -> int: + """ + 将毫秒转换为帧 + """ + return int(ms * self.sample_rate / 1000) + + def frame2ms(self, frame: int) -> int: + """ + 将帧转换为毫秒 + """ + return int(frame * 1000 / self.sample_rate) class _AudioBinary_data(BaseModel): - """音频二进制数据""" - binary_data: bytes = Field(description="音频二进制数据", default=None) + """音频数据""" + binary_data: binary_data_types = Field(description="音频二进制数据", default=None) + + class Config: + arbitrary_types_allowed = True + + @validator('binary_data') + def validate_binary_data(cls, v): + """ + 验证音频数据 + Args: + v: 音频数据 + Returns: + binary_data_types: 音频数据 + """ + if not isinstance(v, (bytes, numpy.ndarray)): + logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v)) + return v + + def __len__(self): + """ + 获取音频数据长度 + Returns: + int: 音频数据长度 + """ + return len(self.binary_data) + + def __init__(self, binary_data: binary_data_types): + """ + 初始化音频数据 + Args: + binary_data: 音频数据 + """ + logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data)) + super().__init__(binary_data=binary_data) + + def __getitem__(self): + """ + 当获取数据时, 直接返回binary_data + Returns: + binary_data_types: 音频数据 + """ + return self.binary_data class AudioBinary_data_list(BaseModel): - """音频二进制数据列表""" - binary_data_list: List[_AudioBinary_data] = Field(description="音频二进制数据列表", default=[]) + """音频数据列表""" + binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[]) - def __call__(self): - return self.binary_data_list + class Config: + arbitrary_types_allowed = True -class AudioBinary_Slice(BaseModel): - """音频块切片""" - target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None) - start_index: int = Field(description="开始位置", default=0) - end_index: int = Field(description="结束位置", default=0) + def push_data(self, data: binary_data_types) -> int: + """ + 添加音频数据 + Args: + data: 音频数据 + Returns: + int: 数据在binary_data_list中的索引 + """ + self.binary_data_list.append(_AudioBinary_data(binary_data=data)) + return len(self.binary_data_list) - 1 - def __call__(self): - return self.target_Binary(self.start_index, self.end_index) \ No newline at end of file + def __getitem__(self, index: int): + """ + 获取音频数据 + Args: + index: 音频数据在binary_data_list中的索引 + Returns: + _AudioBinary_data: 音频数据 + """ + return self.binary_data_list[index] + + def __len__(self): + """ + 获取音频数据列表长度 + Returns: + int: 音频数据列表长度 + """ + return len(self.binary_data_list) + +# class AudioBinary_Slice(BaseModel): +# """音频块切片""" +# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None) +# start_index: int = Field(description="开始位置", default=0) +# end_index: int = Field(description="结束位置", default=-1) + +# @validator('start_index') +# def validate_start_index(cls, v): +# if v < 0: +# raise ValueError("start_index必须大于0") +# return v + +# @validator('end_index') +# def validate_end_index(cls, v): +# if v < cls.start_index: +# logger.debug("[%s]end_index小于start_index, 将end_index设置为start_index", cls.__class__.__name__) +# v = cls.start_index +# return v + +# def __call__(self): +# return self.target_Binary(self.start_index, self.end_index) \ No newline at end of file diff --git a/src/models/vad.py b/src/models/vad.py index 46f087a..7eafca7 100644 --- a/src/models/vad.py +++ b/src/models/vad.py @@ -1,143 +1,88 @@ from pydantic import BaseModel, Field, validator -from typing import List, Optional, Callable +from typing import List, Optional, Callable, Any +from .audio import AudioBinary_data_list, _AudioBinary_data -class VADSegment(BaseModel): - """VAD片段""" - start: int = Field(description="开始时间(ms)") - end: int = Field(description="结束时间(ms)") +class VAD_Functor_result(BaseModel): + """ + VADFunctor结果 + """ + audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表") + audiobinary_index: int = Field(description="音频数据索引") + audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data") + start_time: int = Field(description="开始时间", is_required=True) + end_time: int = Field(description="结束时间", is_required=True) -class VADResult(BaseModel): - """VAD结果""" - key: str = Field(description="音频标识") - value: List[VADSegment] = Field(description="VAD片段列表") + @validator('audiobinary_data_list') + def validate_audiobinary_data_list(cls, v): + if not isinstance(v, AudioBinary_data_list): + raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型") + return v -class VADResponse(BaseModel): - """VAD响应""" - results: List[VADResult] = Field(description="VAD结果列表", default_factory=list) - time_chunk: List[VADSegment] = Field(description="时间块", default_factory=list) - time_chunk_index: int = Field(description="当前处理时间块索引", default=0) - time_chunk_index_callback: Optional[Callable[[int], None]] = Field( - description="时间块索引回调函数", - default=None - ) - - @validator('time_chunk') - def validate_time_chunk(cls, v): - """验证时间块的有效性""" - if not v: - return v - - # 检查时间顺序 - for i in range(len(v) - 1): - if v[i].end >= v[i + 1].start: - raise ValueError(f"时间块{i}的结束时间({v[i].end})大于等于下一个时间块的开始时间({v[i + 1].start})") + @validator('audiobinary_index') + def validate_audiobinary_index(cls, v): + if not isinstance(v, int): + raise ValueError("audiobinary_index必须是int类型") + if v < 0: + raise ValueError("audiobinary_index必须大于0") + return v + + @validator('audiobinary_data') + def validate_audiobinary_data(cls, v): + if not isinstance(v, _AudioBinary_data): + raise ValueError("audiobinary_data必须是AudioBinary_data类型") + return v + + @validator('start_time') + def validate_start_time(cls, v): + if not isinstance(v, int): + raise ValueError("start_time必须是int类型") + if v < 0: + raise ValueError("start_time必须大于0") return v - # 回调未处理的时间块 - def process_time_chunk(self, callback: Callable[[int], None] = None) -> None: - """处理时间块""" - # print("Enter process_time_chunk", self.time_chunk_index, len(self.time_chunk)) - while self.time_chunk_index < len(self.time_chunk) - 1: - index = self.time_chunk_index - if self.time_chunk[index].end != -1: - x = { - "start_time": self.time_chunk[index].start, - "end_time": self.time_chunk[index].end - } - if callback is not None: - callback(x) - elif self.time_chunk_index_callback is not None: - self.time_chunk_index_callback(x) - else: - print("[Warning] No callback available") - self.time_chunk_index += 1 + @validator('end_time') + def validate_end_time(cls, v, values): + if not isinstance(v, int): + raise ValueError("end_time必须是int类型") + if 'start_time' in values and v <= values['start_time']: + raise ValueError("end_time必须大于start_time") + return v - def __add__(self, other: 'VADResponse') -> 'VADResponse': - """合并两个VADResponse""" - if not self.results: - self.results = other.results - self.time_chunk = other.time_chunk - return self - - # 检查是否可以合并最后一个结果 - last_result = self.results[-1] - first_other = other.results[0] - - if last_result.value[-1].end == first_other.value[0].start: - # 合并相邻的时间段 - last_result.value[-1].end = first_other.value[0].end - first_other.value.pop(0) - - # 更新time_chunk - self.time_chunk[-1].end = other.time_chunk[0].end - other.time_chunk.pop(0) - - # 添加剩余的结果 - if first_other.value: - self.results.extend(other.results) - self.time_chunk.extend(other.time_chunk) - else: - # 直接添加所有结果 - self.results.extend(other.results) - self.time_chunk.extend(other.time_chunk) - - return self - @classmethod - def from_raw(cls, raw_data: List[dict]) -> "VADResponse": + def create_from_push_data( + cls, + audiobinary_data_list: AudioBinary_data_list, + data: Any, + start_time: int, + end_time: int + ): """ - 从原始数据创建VADResponse - - 参数: - raw_data: 原始数据,格式如 [{'key': 'xxx', 'value': [[-1, 59540], [59820, -1]]}] - - 返回: - VADResponse: 解析后的VAD响应 + 创建VAD片段 """ - results = [] - time_chunk = [] - for item in raw_data: - segments = [ - VADSegment(start=seg[0], end=seg[1]) - for seg in item['value'] - ] - results.append(VADResult( - key=item['key'], - value=segments - )) - time_chunk.extend(segments) - return cls(results=results, time_chunk=time_chunk) - - def to_raw(self) -> List[dict]: + index = audiobinary_data_list.push_data(data) + + return cls( + audiobinary_data_list=audiobinary_data_list, + audiobinary_index=index, + audiobinary_data=audiobinary_data_list[index], + start_time=start_time, + end_time=end_time) + + def __len__(self): """ - 转换为原始数据格式 - - 返回: - List[dict]: 原始数据格式 + 获取音频数据长度 """ - return [ - { - 'key': result.key, - 'value': [[seg.start, seg.end] for seg in result.value] - } - for result in self.results - ] + return len(self.audiobinary_data.binary_data) def __str__(self): - result_str = "VADResponse:\n" - for result in self.results: - for value_item in result.value: - result_str += f"[{value_item.start}:{value_item.end}]\n" - return result_str - - def __iter__(self): - return iter(self.time_chunk) + """ + 字符串展示内容 + """ + tostr = f'audiobinary_data_index: {self.audiobinary_index}\n' + tostr += f'start_time: {self.start_time}\n' + tostr += f'end_time: {self.end_time}\n' + tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n' + tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n' + return tostr - def __next__(self): - return next(self.time_chunk) - - def __len__(self): - return len(self.time_chunk) - - def __getitem__(self, index): - return self.time_chunk[index] \ No newline at end of file + \ No newline at end of file diff --git a/src/pipeline/ASRpipeline.py b/src/pipeline/ASRpipeline.py index 369f65f..85f1818 100644 --- a/src/pipeline/ASRpipeline.py +++ b/src/pipeline/ASRpipeline.py @@ -16,7 +16,7 @@ class ASRPipeline(PipelineBase): """ super().__init__(*args, **kwargs) self._config: Dict[str, Any] = {} - self._funtor_dict: Dict[str, Any] = {} + self._functor_dict: Dict[str, Any] = {} self._subqueue_dict: Dict[str, Any] = {} self._is_baked: bool = False @@ -46,28 +46,28 @@ class ASRPipeline(PipelineBase): """ 烘焙管道 """ - self._init_funtor() + self._init_functor() self._is_baked = True - def _init_funtor(self) -> None: + def _init_functor(self) -> None: """ 初始化函数 """ try: - from src.funtor import FuntorFactory - # 加载VAD、asr、spk funtor - self._funtor_dict["vad"] = FuntorFactory.make_funtor( - funtor_name = "vad", + from src.functor import functorFactory + # 加载VAD、asr、spk functor + self._functor_dict["vad"] = functorFactory.make_functor( + functor_name = "vad", config = self._config, models = self._models ) - self._funtor_dict["asr"] = FuntorFactory.make_funtor( - funtor_name = "asr", + self._functor_dict["asr"] = functorFactory.make_functor( + functor_name = "asr", config = self._config, models = self._models ) - self._funtor_dict["spk"] = FuntorFactory.make_funtor( - funtor_name = "spk", + self._functor_dict["spk"] = functorFactory.make_functor( + functor_name = "spk", config = self._config, models = self._models ) @@ -79,13 +79,13 @@ class ASRPipeline(PipelineBase): self._subqueue_dict["spkend"] = Queue() # 设置子队列的输入队列 - self._funtor_dict["vad"].set_input_queue(self._input_queue) - self._funtor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) - self._funtor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) + self._functor_dict["vad"].set_input_queue(self._input_queue) + self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) + self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) # 设置回调函数——放置到对应队列中 - self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put) - self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put) + self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put) + self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put) # 构造带回调函数的put def put_with_checkcallback(queue: Queue, callback: Callable) -> None: @@ -97,11 +97,11 @@ class ASRPipeline(PipelineBase): callback() return put_with_check - self._funtor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) - self._funtor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result)) + self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) + self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result)) except ImportError: - raise ImportError("FuntorFactory引入失败,ASRPipeline无法完成初始化") + raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") def get_config(self) -> Dict[str, Any]: """ @@ -129,10 +129,10 @@ class ASRPipeline(PipelineBase): if not self._is_baked: raise RuntimeError("管道未烘焙,无法运行") - # 运行所有funtor - for funtor_name, funtor in self._funtor_dict.items(): - logger.info(f"运行{funtor_name}funtor") - funtor.run() + # 运行所有functor + for functor_name, functor in self._functor_dict.items(): + logger.info(f"运行{functor_name}functor") + self._functor_dict[functor_name].run() # 运行管道 if not self._input_queue: @@ -152,6 +152,7 @@ class ASRPipeline(PipelineBase): # 检查是否是结束信号 if data is None: logger.info("收到结束信号,管道准备停止") + self._stop() self._input_queue.task_done() # 标记结束信号已处理 break @@ -187,3 +188,13 @@ class ASRPipeline(PipelineBase): # 通知回调函数 self._notify_callbacks(result) + def stop(self) -> None: + """ + 停止管道 + """ + self._is_running = False + self._stop_event = True + for functor_name, functor in self._functor_dict.items(): + logger.info(f"停止{functor_name}functor") + functor.stop() + logger.info("子Functor停止") diff --git a/tests/functor/vad_test.py b/tests/functor/vad_test.py index 2770dc3..f39e325 100644 --- a/tests/functor/vad_test.py +++ b/tests/functor/vad_test.py @@ -5,10 +5,15 @@ VAD测试 from src.functor.vad_functor import VADFunctor from queue import Queue, Empty from src.model_loader import ModelLoader -from src.models import AudioBinary_Config +from src.models import AudioBinary_Config, AudioBinary_data_list from src.utils.data_format import wav_to_bytes import time from src.utils.logger import get_module_logger +from pydub import AudioSegment +import soundfile + +# 观察参数 +OVERWATCH = False logger = get_module_logger(__name__) @@ -22,37 +27,54 @@ def test_vad_functor(): "auto_update": False, } model_loader.load_models(args) - # 创建VAD函数器 - vad_functor = VADFunctor() + # 加载数据 + f_data, sample_rate = soundfile.read("tests/vad_example.wav") + audio_config = AudioBinary_Config( + chunk_size=200, + chunk_stride=1600, + sample_rate=sample_rate, + sample_width=16, + channels=1 + ) + chunk_stride = int(audio_config.chunk_size*sample_rate/1000) + audio_config.chunk_stride = chunk_stride # 创建输入队列 input_queue = Queue() + # 创建音频数据列表 + audio_binary_data_list = AudioBinary_data_list() + + # 创建VAD函数器 + vad_functor = VADFunctor() # 设置输入队列 vad_functor.set_input_queue(input_queue) # 设置音频配置 - vad_functor.set_audio_config(AudioBinary_Config( - chunk_size=960, - chunk_stride=480, - sample_rate=16000, - sample_width=2, - channels=1 - )) + vad_functor.set_audio_config(audio_config) + # 设置音频数据列表 + vad_functor.set_audio_binary_data_list(audio_binary_data_list) # 设置回调函数 - vad_functor.add_callback(lambda x: print(x)) + vad_functor.add_callback(lambda x: print(f"callback: {x}")) # 设置模型 vad_functor.set_model({ 'vad': model_loader.models['vad'] }) + # 启动VAD函数器 vad_functor.run() - - # 加载数据 - f_binary = wav_to_bytes("tests/vad_example.wav") - chunk_size = 960000 - # chunk_size = len(f_binary) - print(f"f_binary: {len(f_binary)}, chunk_size: {chunk_size}, clip_num: {len(f_binary) // chunk_size}") - for i in range(0, len(f_binary), chunk_size): - binary_data = f_binary[i:i+chunk_size] + f_binary = f_data + audio_clip_len = 200 + print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}") + for i in range(0, len(f_binary), audio_clip_len): + binary_data = f_binary[i:i+audio_clip_len] input_queue.put(binary_data) # 等待VAD函数器结束 - time.sleep(10) - vad_functor.stop() \ No newline at end of file + print("[vad_test] 等待input_queue为空") + input_queue.join() + print("[vad_test] input_queue为空") + vad_functor.stop() + print("[vad_test] VAD函数器结束") + + # 保存音频数据 + if OVERWATCH: + for index in range(len(audio_binary_data_list)): + save_path = f"tests/vad_test_output_{index}.wav" + soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate)