From b569b7e63db8d14e93e2f1a94b768fa67da86969 Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Tue, 3 Jun 2025 17:41:59 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BB=A3=E7=A0=81=E9=87=8D=E6=9E=84=E4=B8=AD]?= =?UTF-8?q?=E6=B5=8B=E8=AF=95VADFuntor=E4=B8=AD=EF=BC=8C=E5=8F=91=E7=8E=B0?= =?UTF-8?q?=E5=AD=97=E8=8A=82=E6=B5=81=E6=8E=A8=E7=90=86=E9=97=AE=E9=A2=98?= =?UTF-8?q?=EF=BC=8C=E5=BE=85=E8=BF=9B=E4=B8=80=E6=AD=A5=E7=A0=94=E7=A9=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/functor/__init__.py | 6 +- src/functor/base.py | 163 ++++++++++++++++------------ src/functor/vad_functor.py | 211 +++++++++++++++++++++++++++--------- src/model_loader.py | 26 +++-- src/models/audio.py | 19 ++-- src/pipeline/ASRpipeline.py | 7 +- src/utils/data_format.py | 110 +++++++++++++++++++ test_main.py | 22 +--- tests/functor/vad_test.py | 58 ++++++++++ 9 files changed, 457 insertions(+), 165 deletions(-) create mode 100644 src/utils/data_format.py create mode 100644 tests/functor/vad_test.py diff --git a/src/functor/__init__.py b/src/functor/__init__.py index 5ecc5b1..a18b658 100644 --- a/src/functor/__init__.py +++ b/src/functor/__init__.py @@ -1,4 +1,4 @@ -from .vad_functor import VAD -from .model_loader import load_models +from .vad_functor import VADFunctor +from .base import FunctorFactory -__all__ = ["VAD", "load_models"] \ No newline at end of file +__all__ = ["VADFunctor", "FunctorFactory"] \ No newline at end of file diff --git a/src/functor/base.py b/src/functor/base.py index 221401b..3754154 100644 --- a/src/functor/base.py +++ b/src/functor/base.py @@ -1,102 +1,129 @@ -from typing import Callable +""" +Functor基础模块 -class BaseFunctor: +该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。 +基类提供了数据处理的基本框架,包括: +- 回调函数管理 +- 模型配置管理 +- 线程运行控制 + +主要类: + BaseFunctor: Functor抽象类 + FunctorFactory: Functor工厂类 +""" +from abc import ABC, abstractmethod +from typing import Callable, List +from queue import Queue + +class BaseFunctor(ABC): """ - 基础函数器类, 提供数据处理的基本框架 - - 该类实现了数据处理的基本接口, 包括数据推送、处理和回调机制。 - 所有具体的功能实现类都应该继承这个基类。 - + Functor抽象类 + + 该抽象类规定了所有的Functor类必须实现run()方法启动自身线程 + 属性: - _data (dict): 存储处理数据的字典 _callback (Callable): 处理完成后的回调函数 _model (dict): 存储模型相关的配置和实例 """ - - def __init__(self, - data: dict or bytes = {}, - callback: Callable = None, - model: dict = {}, + + def __init__( + self ): """ 初始化函数器 - + 参数: - data (dict or bytes): 初始数据, 可以是字典或字节数据 callback (Callable): 处理完成后的回调函数 model (dict): 模型相关的配置和实例 """ - self._data: dict = {} - self.push_data(data) - self._callback: Callable = callback - self._model: dict = model - pass + self._callback: List[Callable] = [] + self._model: dict = {} - def __call__(self, data = None): + def add_callback(self, callback: Callable): """ - 使类实例可调用, 处理数据并触发回调 - - 参数: - data: 要处理的数据, 如果为None则处理已存储的数据 - - 返回: - 处理结果 - """ - # 如果传入数据, 则压入数据 - if data is not None: - self.push_data(data) - # 处理数据 - result = self.process() - # 如果回调函数存在, 则触发回调 - if self._callback is not None and callable(self._callback): - self._callback(result) - return result + 添加回调函数 - def __add__(self, other): - """ - 重载加法运算符, 用于合并数据 - - 参数: - other: 要合并的数据 - - 返回: - self: 返回当前实例, 支持链式调用 - """ - self.push_data(other) - return self - - def set_callback(self, callback: Callable): - """ - 设置回调函数 - 参数: callback (Callable): 新的回调函数 """ - self._callback = callback + self._callback.append(callback) def set_model(self, model: dict): """ 设置模型配置 - + 参数: model (dict): 新的模型配置 """ self._model = model - def push_data(self, data): + def set_input_queue(self, queue: Queue): """ - 推送数据到处理器 - + 设置输入队列 + 参数: - data: 要处理的数据 + queue (Queue): 新的输入队列 """ - pass - - def process(self): + self._input_queue = queue + + @abstractmethod + def _run(self): """ - 处理数据的核心方法 - + 线程运行逻辑 + 返回: - 处理结果 + 当达到条件时触发callback """ - pass + + @abstractmethod + def run(self): + """ + 启动_run方法线程 + + 返回: + 线程实例 + """ + + @abstractmethod + def _pre_check(self): + """ + 预检查 + + 返回: + 预检查结果 + """ + + @abstractmethod + def stop(self): + """ + 停止线程 + + 返回: + 停止结果 + """ + +class FunctorFactory: + """ + Functor工厂类 + + 该工厂类负责创建和配置Functor实例 + + 主要方法: + make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor: + 创建并配置Functor实例 + """ + + @staticmethod + def make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor: + """ + 创建并配置Functor实例 + + 参数: + funtor_name (str): Functor名称 + config (dict): 配置信息 + models (dict): 模型信息 + + 返回: + BaseFunctor: 创建的Functor实例 + """ + \ No newline at end of file diff --git a/src/functor/vad_functor.py b/src/functor/vad_functor.py index cc87040..f2dc868 100644 --- a/src/functor/vad_functor.py +++ b/src/functor/vad_functor.py @@ -2,59 +2,168 @@ from funasr import AutoModel from typing import List, Dict, Any from src.models import VADResponse from src.models import AudioBinary_Config -from src.functor.audiochunk import AudioChunk -from src.models import AudioBinary_Chunk +from src.models import AudioBinary_data_list +from src.models import AudioBinary_Slice from typing import Callable +from src.functor.base import BaseFunctor +import threading +from queue import Empty, Queue -class VAD: +# 日志 +from src.utils.logger import get_module_logger - def __init__(self, - VAD_model = None, - audio_config : AudioBinary_Config = None, - callback: Callable = None, - ): - # vad model - self.VAD_model = VAD_model - if self.VAD_model is None: - self.VAD_model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True) - # audio config - self.audio_config = audio_config - # vad result - self.vad_result = VADResponse(time_chunk_index_callback=callback) - # audio binary poll - self.audio_chunk = AudioChunk( - audio_config=self.audio_config - ) - self.cache = {} - - def push_binary_data(self, - binary_data: bytes, - ): - # 压入二进制数据 - self.audio_chunk.add_chunk(binary_data) - # 处理音频块 - res = self.VAD_model.generate(input=binary_data, - cache=self.cache, - chunk_size=self.audio_config.chunk_size, - is_final=False) - # print("VAD generate", res) - if len(res[0]["value"]): - self.vad_result += VADResponse.from_raw(res) - - def set_callback(self, - callback: Callable, - ): - self.vad_result.time_chunk_index_callback = callback +logger = get_module_logger(__name__) - def process_vad_result(self, callback: Callable = None): - # 处理VAD结果 - callback = callback if callback is not None else self.vad_result.time_chunk_index_callback - self.vad_result.process_time_chunk( - lambda x : callback( - AudioBinary_Chunk( - start_time=x["start_time"], - end_time=x["end_time"], - chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]) - ) + +class VADFunctor(BaseFunctor): + def __init__( + self + ): + super().__init__() + self._model: dict = {} + self._callback: List[Callable] = [] + self._status_lock: threading.Lock = threading.Lock() + self._input_queue: Queue = None + self._audio_config: AudioBinary_Config = None + self._is_running: bool = False + self._stop_event: bool = False + self._audio_cache: bytes = b'' + self._cache: dict = {} + + def set_input_queue(self, queue: Queue): + self._input_queue = queue + + def set_model(self, model: dict): + self._model = model + + def set_audio_config(self, audio_config: AudioBinary_Config): + self._audio_config = audio_config + + def add_callback(self, callback: Callable): + if not isinstance(self._callback, list): + self._callback = [] + self._callback.append(callback) + + def _process(self, data: bytes): + """ + 处理数据 + """ + self._audio_cache += data + if len(self._audio_cache) >= self._audio_config.chunk_size*100: + result = self._model['vad'].generate( + input=self._audio_cache, + cache=self._cache, + chunk_size=self._audio_config.chunk_size, + is_final=False, ) - ) \ No newline at end of file + logger.info(f"VADFunctor处理数据: {len(self._audio_cache)}, {result}") + self._audio_cache = b'' + + + def _run(self): + """ + 线程运行逻辑 + 监听输入队列,当有数据时,处理数据 + 当输入队列为空时, 间隔1s检测是否进入停止事件。 + """ + # 刷新运行状态 + with self._status_lock: + self._is_running = True + self._stop_event = False + # 运行逻辑 + while self._is_running: + try: + data = self._input_queue.get(True, timeout=1) + self._process(data) + self._input_queue.task_done() + # 当队列为空时,间隔1s检测是否进入停止事件。 + except Empty: + if self._stop_event: + break + continue + # 其他异常 + except Exception as e: + logger.error(f"VADFunctor运行时发生错误: {e}") + raise e + + def run(self): + """ + 启动 _run 线程, 并返回线程对象 + """ + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + return self._thread + + def _pre_check(self): + pass + + def stop(self): + with self._status_lock: + self._stop_event = True + self._thread.join() + with self._status_lock: + self._is_running = False + return True + + +# class VAD: + +# def __init__( +# self, +# VAD_model=None, +# audio_config: AudioBinary_Config = None, +# callback: Callable = None, +# ): +# # vad model +# self.VAD_model = VAD_model +# if self.VAD_model is None: +# self.VAD_model = AutoModel( +# model="fsmn-vad", model_revision="v2.0.4", disable_update=True +# ) +# # audio config +# self.audio_config = audio_config +# # vad result +# self.vad_result = VADResponse(time_chunk_index_callback=callback) +# # audio binary poll +# self.audio_chunk = AudioChunk(audio_config=self.audio_config) +# self.cache = {} + +# def push_binary_data( +# self, +# binary_data: bytes, +# ): +# # 压入二进制数据 +# self.audio_chunk.add_chunk(binary_data) +# # 处理音频块 +# res = self.VAD_model.generate( +# input=binary_data, +# cache=self.cache, +# chunk_size=self.audio_config.chunk_size, +# is_final=False, +# ) +# # print("VAD generate", res) +# if len(res[0]["value"]): +# self.vad_result += VADResponse.from_raw(res) + +# def set_callback( +# self, +# callback: Callable, +# ): +# self.vad_result.time_chunk_index_callback = callback + +# def process_vad_result(self, callback: Callable = None): +# # 处理VAD结果 +# callback = ( +# callback +# if callback is not None +# else self.vad_result.time_chunk_index_callback +# ) +# self.vad_result.process_time_chunk( +# lambda x: callback( +# AudioBinary_Chunk( +# start_time=x["start_time"], +# end_time=x["end_time"], +# chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]), +# ) +# ) +# ) diff --git a/src/model_loader.py b/src/model_loader.py index aafb93c..f4a116c 100644 --- a/src/model_loader.py +++ b/src/model_loader.py @@ -53,12 +53,12 @@ class ModelLoader: # 直接调用等于调用self.models return self.models - def _load_model(self, args, model_type): + def _load_model(self, input_model_args: dict, model_type: str): """ 加载单个模型 参数: - args: 命令行参数, 包含模型配置 + model_args: 模型加载字典 model_type: 模型类型, 用于确定使用哪个模型参数 返回: @@ -81,12 +81,13 @@ class ModelLoader: if key in ["model", "model_revision"]: # 特殊处理model和model_revision, 因为它们需要model_type前缀 if key == "model": - value = getattr(args, f"{model_type}_model", None) + value = input_model_args.get(f"{model_type}_model", None) else: - value = getattr(args, f"{model_type}_model_revision", None) + value = input_model_args.get(f"{model_type}_model_revision", None) else: - value = getattr(args, key, None) + value = input_model_args.get(key, None) if value is not None: + logger.info("替换%s模型参数: %s = %s", model_type, key, value) model_args[key] = value # 验证必要参数 if not model_args["model"]: @@ -113,12 +114,13 @@ class ModelLoader: # 初始化模型字典 self.models = {} # 加载离线ASR模型 - self.models["asr"] = self._load_model(args, "asr") - # 2. 加载在线ASR模型 - self.models["asr_streaming"] = self._load_model(args, "asr_online") - # 3. 加载VAD模型 - self.models["vad"] = self._load_model(args, "vad") - # 4. 加载标点符号模型(如果指定) - self.models["punc"] = self._load_model(args, "punc") + # 检查对应键是否存在 + model_list = ['asr', 'asr_online', 'vad', 'punc'] + for model_name in model_list: + name_model = f"{model_name}_model" + name_model_revision = f"{model_name}_model_revision" + if name_model in args: + logger.info("加载%s模型", model_name) + self.models[model_name] = self._load_model(args, model_name) logger.info("所有模型加载完成") return self.models diff --git a/src/models/audio.py b/src/models/audio.py index ac63a06..83b187b 100644 --- a/src/models/audio.py +++ b/src/models/audio.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from src.audiochunk import AudioBinary +from typing import List class AudioBinary_Config(BaseModel): """二进制音频块配置信息""" @@ -15,14 +15,6 @@ class AudioBinary_Config(BaseModel): def AudioBinary_Config_from_dict(cls, data: dict): return cls(**data) -class AudioBinary_Slice(BaseModel): - """音频块切片""" - target_Binary: AudioBinary = Field(description="目标音频块", default=None) - start_index: int = Field(description="开始位置", default=0) - end_index: int = Field(description="结束位置", default=0) - - def __call__(self): - return self.target_Binary(self.start_index, self.end_index) class _AudioBinary_data(BaseModel): """音频二进制数据""" @@ -34,3 +26,12 @@ class AudioBinary_data_list(BaseModel): def __call__(self): return self.binary_data_list + +class AudioBinary_Slice(BaseModel): + """音频块切片""" + target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None) + start_index: int = Field(description="开始位置", default=0) + end_index: int = Field(description="结束位置", default=0) + + def __call__(self): + return self.target_Binary(self.start_index, self.end_index) \ No newline at end of file diff --git a/src/pipeline/ASRpipeline.py b/src/pipeline/ASRpipeline.py index 7d215a3..369f65f 100644 --- a/src/pipeline/ASRpipeline.py +++ b/src/pipeline/ASRpipeline.py @@ -18,7 +18,6 @@ class ASRPipeline(PipelineBase): self._config: Dict[str, Any] = {} self._funtor_dict: Dict[str, Any] = {} self._subqueue_dict: Dict[str, Any] = {} - self._is_baked: bool = False def set_config(self, config: Dict[str, Any]) -> None: @@ -57,17 +56,17 @@ class ASRPipeline(PipelineBase): try: from src.funtor import FuntorFactory # 加载VAD、asr、spk funtor - self._funtor_dict["vad"] = FuntorFactory.get_funtor( + self._funtor_dict["vad"] = FuntorFactory.make_funtor( funtor_name = "vad", config = self._config, models = self._models ) - self._funtor_dict["asr"] = FuntorFactory.get_funtor( + self._funtor_dict["asr"] = FuntorFactory.make_funtor( funtor_name = "asr", config = self._config, models = self._models ) - self._funtor_dict["spk"] = FuntorFactory.get_funtor( + self._funtor_dict["spk"] = FuntorFactory.make_funtor( funtor_name = "spk", config = self._config, models = self._models diff --git a/src/utils/data_format.py b/src/utils/data_format.py new file mode 100644 index 0000000..bb845f9 --- /dev/null +++ b/src/utils/data_format.py @@ -0,0 +1,110 @@ +""" +处理各类音频数据与bytes的转换 +""" +import wave +from pydub import AudioSegment +import io + +def wav_to_bytes(wav_path: str) -> bytes: + """ + 将WAV文件读取为bytes数据。 + + 参数: + wav_path (str): WAV文件的路径。 + + 返回: + bytes: WAV文件的原始字节数据。 + + 异常: + FileNotFoundError: 如果WAV文件不存在。 + wave.Error: 如果文件不是有效的WAV文件。 + """ + try: + with wave.open(wav_path, 'rb') as wf: + # 读取所有帧 + frames = wf.readframes(wf.getnframes()) + return frames + except FileNotFoundError: + # 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出 + raise FileNotFoundError(f"错误:未找到WAV文件 '{wav_path}'") + except wave.Error as e: + raise wave.Error(f"错误:打开或读取WAV文件 '{wav_path}' 失败 - {e}") + + +def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int): + """ + 将bytes数据写入为WAV文件。 + + 参数: + bytes_data (bytes): 音频的字节数据。 + wav_path (str): 保存WAV文件的路径。 + nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)。 + sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)。 + framerate (int): 采样率 (例如 44100, 16000)。 + + 异常: + wave.Error: 如果写入WAV文件失败。 + """ + try: + with wave.open(wav_path, 'wb') as wf: + wf.setnchannels(nchannels) + wf.setsampwidth(sampwidth) + wf.setframerate(framerate) + wf.writeframes(bytes_data) + except wave.Error as e: + raise wave.Error(f"错误:写入WAV文件 '{wav_path}' 失败 - {e}") + except Exception as e: + # 捕获其他可能的写入错误 + raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}") + +def mp3_to_bytes(mp3_path: str) -> bytes: + """ + 将MP3文件转换为bytes数据 (原始PCM数据)。 + + 参数: + mp3_path (str): MP3文件的路径。 + + 返回: + bytes: MP3文件解码后的原始PCM字节数据。 + + 异常: + FileNotFoundError: 如果MP3文件不存在。 + pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码。 + """ + try: + audio = AudioSegment.from_mp3(mp3_path) + # 获取原始PCM数据 + return audio.raw_data + except FileNotFoundError: + raise FileNotFoundError(f"错误:未找到MP3文件 '{mp3_path}'") + except Exception as e: # pydub 可能抛出多种解码相关的错误 + raise Exception(f"错误:处理MP3文件 '{mp3_path}' 失败 - {e}") + + +def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"): + """ + 将原始PCM bytes数据转换为MP3文件。 + + 参数: + bytes_data (bytes): 原始PCM字节数据。 + mp3_path (str): 保存MP3文件的路径。 + frame_rate (int): 原始PCM数据的采样率 (例如 44100)。 + channels (int): 原始PCM数据的声道数 (例如 1 for mono, 2 for stereo)。 + sample_width (int): 原始PCM数据的采样宽度 (字节数, 例如 2 for 16-bit)。 + bitrate (str): MP3编码的比特率 (例如 "128k", "192k", "320k")。 + + 异常: + Exception: 如果转换或写入MP3文件失败。 + """ + try: + # 从原始数据创建AudioSegment对象 + audio = AudioSegment( + data=bytes_data, + sample_width=sample_width, + frame_rate=frame_rate, + channels=channels + ) + # 导出为MP3 + audio.export(mp3_path, format="mp3", bitrate=bitrate) + except Exception as e: + raise Exception(f"错误:转换或写入MP3文件 '{mp3_path}' 失败 - {e}") diff --git a/test_main.py b/test_main.py index 596385f..d5c834b 100644 --- a/test_main.py +++ b/test_main.py @@ -1,22 +1,8 @@ +from tests.functor.vad_test import test_vad_functor from src.utils.logger import get_module_logger, setup_root_logger -from tests.modelsuse import vad_model_use_online_logic, asr_model_use_offline -import json -setup_root_logger(level="INFO",log_file="logs/test_main.log") +setup_root_logger(level="INFO", log_file="logs/test_main.log") logger = get_module_logger(__name__) -logger.info("开始测试") -vad_result = vad_model_use_online_logic("tests/vad_example.wav") -logger.info("测试结束") -if vad_result is None: - logger.warning("VAD结果为空") -else: - logger.info(f"VAD结果: {vad_result}") - -asr_result = asr_model_use_offline("tests/vad_example.wav") -# asr_result str->dict -setup_root_logger(level="INFO",log_file="logs/test_main.log") -result = asr_result[0]['sentence_info'] -for item in result: - #[{'start': 70, 'end': 2340, 'sentence': '试 错 的 过 程 很 简 单', 'timestamp': [[380, 620], [640, 740], [740, 940], [940, 1020], [1020, 1260], [1500, 1740], [1740, 1840], [1840, 2135]], 'spk': 0} - logger.info(f"spk[{item['spk']}] [{item['start']}ms:{item['end']}ms] {item['sentence'].replace(' ', '')}") \ No newline at end of file +logger.info("开始测试VAD函数器") +test_vad_functor() \ No newline at end of file diff --git a/tests/functor/vad_test.py b/tests/functor/vad_test.py new file mode 100644 index 0000000..2770dc3 --- /dev/null +++ b/tests/functor/vad_test.py @@ -0,0 +1,58 @@ +""" +Functor测试 +VAD测试 +""" +from src.functor.vad_functor import VADFunctor +from queue import Queue, Empty +from src.model_loader import ModelLoader +from src.models import AudioBinary_Config +from src.utils.data_format import wav_to_bytes +import time +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) + +model_loader = ModelLoader() + +def test_vad_functor(): + # 加载模型 + args = { + "vad_model": "fsmn-vad", + "vad_model_revision": "v2.0.4", + "auto_update": False, + } + model_loader.load_models(args) + # 创建VAD函数器 + vad_functor = VADFunctor() + # 创建输入队列 + input_queue = Queue() + # 设置输入队列 + vad_functor.set_input_queue(input_queue) + # 设置音频配置 + vad_functor.set_audio_config(AudioBinary_Config( + chunk_size=960, + chunk_stride=480, + sample_rate=16000, + sample_width=2, + channels=1 + )) + # 设置回调函数 + vad_functor.add_callback(lambda x: print(x)) + # 设置模型 + vad_functor.set_model({ + 'vad': model_loader.models['vad'] + }) + # 启动VAD函数器 + vad_functor.run() + + # 加载数据 + f_binary = wav_to_bytes("tests/vad_example.wav") + chunk_size = 960000 + # chunk_size = len(f_binary) + print(f"f_binary: {len(f_binary)}, chunk_size: {chunk_size}, clip_num: {len(f_binary) // chunk_size}") + for i in range(0, len(f_binary), chunk_size): + binary_data = f_binary[i:i+chunk_size] + input_queue.put(binary_data) + # 等待VAD函数器结束 + time.sleep(10) + vad_functor.stop() \ No newline at end of file