[代码重构中]测试VADFunctor中,发现字节流推理问题,待进一步研究

This commit is contained in:
Ziyang.Zhang 2025-06-03 17:41:59 +08:00
parent f245c6e9df
commit b569b7e63d
9 changed files with 457 additions and 165 deletions

View File

@ -1,4 +1,4 @@
from .vad_functor import VAD from .vad_functor import VADFunctor
from .model_loader import load_models from .base import FunctorFactory
__all__ = ["VAD", "load_models"] __all__ = ["VADFunctor", "FunctorFactory"]

View File

@ -1,102 +1,129 @@
from typing import Callable """
Functor基础模块
class BaseFunctor: 该模块定义了Functor的基类,所有功能性的类(如VADPUNCASRSPK等)都应继承自这个基类
基类提供了数据处理的基本框架,包括:
- 回调函数管理
- 模型配置管理
- 线程运行控制
主要类:
BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类
"""
from abc import ABC, abstractmethod
from typing import Callable, List
from queue import Queue
class BaseFunctor(ABC):
""" """
基础函数器类, 提供数据处理的基本框架 Functor抽象类
该类实现了数据处理的基本接口, 包括数据推送处理和回调机制 该抽象类规定了所有的Functor类必须实现run()方法启动自身线程
所有具体的功能实现类都应该继承这个基类
属性: 属性:
_data (dict): 存储处理数据的字典
_callback (Callable): 处理完成后的回调函数 _callback (Callable): 处理完成后的回调函数
_model (dict): 存储模型相关的配置和实例 _model (dict): 存储模型相关的配置和实例
""" """
def __init__(self, def __init__(
data: dict or bytes = {}, self
callback: Callable = None,
model: dict = {},
): ):
""" """
初始化函数器 初始化函数器
参数: 参数:
data (dict or bytes): 初始数据, 可以是字典或字节数据
callback (Callable): 处理完成后的回调函数 callback (Callable): 处理完成后的回调函数
model (dict): 模型相关的配置和实例 model (dict): 模型相关的配置和实例
""" """
self._data: dict = {} self._callback: List[Callable] = []
self.push_data(data) self._model: dict = {}
self._callback: Callable = callback
self._model: dict = model
pass
def __call__(self, data = None): def add_callback(self, callback: Callable):
""" """
使类实例可调用, 处理数据并触发回调 添加回调函数
参数:
data: 要处理的数据, 如果为None则处理已存储的数据
返回:
处理结果
"""
# 如果传入数据, 则压入数据
if data is not None:
self.push_data(data)
# 处理数据
result = self.process()
# 如果回调函数存在, 则触发回调
if self._callback is not None and callable(self._callback):
self._callback(result)
return result
def __add__(self, other):
"""
重载加法运算符, 用于合并数据
参数:
other: 要合并的数据
返回:
self: 返回当前实例, 支持链式调用
"""
self.push_data(other)
return self
def set_callback(self, callback: Callable):
"""
设置回调函数
参数: 参数:
callback (Callable): 新的回调函数 callback (Callable): 新的回调函数
""" """
self._callback = callback self._callback.append(callback)
def set_model(self, model: dict): def set_model(self, model: dict):
""" """
设置模型配置 设置模型配置
参数: 参数:
model (dict): 新的模型配置 model (dict): 新的模型配置
""" """
self._model = model self._model = model
def push_data(self, data): def set_input_queue(self, queue: Queue):
""" """
推送数据到处理器 设置输入队列
参数: 参数:
data: 要处理的数据 queue (Queue): 新的输入队列
""" """
pass self._input_queue = queue
def process(self): @abstractmethod
def _run(self):
""" """
处理数据的核心方法 线程运行逻辑
返回: 返回:
处理结果 当达到条件时触发callback
""" """
pass
@abstractmethod
def run(self):
    """
    Start the thread that executes ``_run``.

    Returns:
        The thread instance.
    """
@abstractmethod
def _pre_check(self):
    """
    Perform the pre-check.

    Returns:
        The result of the pre-check.
    """
@abstractmethod
def stop(self):
    """
    Stop the thread.

    Returns:
        The result of the stop operation.
    """
class FunctorFactory:
    """
    Factory class for Functors.

    This factory is responsible for creating and configuring Functor
    instances.

    Main methods:
        make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
            Create and configure a Functor instance.
    """

    @staticmethod
    def make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
        """
        Create and configure a Functor instance.

        Args:
            funtor_name (str): Name of the Functor.
            config (dict): Configuration information.
            models (dict): Model information.

        Returns:
            BaseFunctor: The created Functor instance.

        NOTE(review): as shown here the method body consists of this
        docstring only, so a call returns None instead of a BaseFunctor.
        Confirm whether the implementation is still pending.
        """

View File

@ -2,59 +2,168 @@ from funasr import AutoModel
from typing import List, Dict, Any from typing import List, Dict, Any
from src.models import VADResponse from src.models import VADResponse
from src.models import AudioBinary_Config from src.models import AudioBinary_Config
from src.functor.audiochunk import AudioChunk from src.models import AudioBinary_data_list
from src.models import AudioBinary_Chunk from src.models import AudioBinary_Slice
from typing import Callable from typing import Callable
from src.functor.base import BaseFunctor
import threading
from queue import Empty, Queue
class VAD: # 日志
from src.utils.logger import get_module_logger
def __init__(self, logger = get_module_logger(__name__)
VAD_model = None,
audio_config : AudioBinary_Config = None,
callback: Callable = None,
):
# vad model
self.VAD_model = VAD_model
if self.VAD_model is None:
self.VAD_model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
# audio config
self.audio_config = audio_config
# vad result
self.vad_result = VADResponse(time_chunk_index_callback=callback)
# audio binary poll
self.audio_chunk = AudioChunk(
audio_config=self.audio_config
)
self.cache = {}
def push_binary_data(self,
binary_data: bytes,
):
# 压入二进制数据
self.audio_chunk.add_chunk(binary_data)
# 处理音频块
res = self.VAD_model.generate(input=binary_data,
cache=self.cache,
chunk_size=self.audio_config.chunk_size,
is_final=False)
# print("VAD generate", res)
if len(res[0]["value"]):
self.vad_result += VADResponse.from_raw(res)
def set_callback(self,
callback: Callable,
):
self.vad_result.time_chunk_index_callback = callback
def process_vad_result(self, callback: Callable = None):
# 处理VAD结果 class VADFunctor(BaseFunctor):
callback = callback if callback is not None else self.vad_result.time_chunk_index_callback def __init__(
self.vad_result.process_time_chunk( self
lambda x : callback( ):
AudioBinary_Chunk( super().__init__()
start_time=x["start_time"], self._model: dict = {}
end_time=x["end_time"], self._callback: List[Callable] = []
chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]) self._status_lock: threading.Lock = threading.Lock()
) self._input_queue: Queue = None
self._audio_config: AudioBinary_Config = None
self._is_running: bool = False
self._stop_event: bool = False
self._audio_cache: bytes = b''
self._cache: dict = {}
def set_input_queue(self, queue: Queue):
    """Attach the queue this functor takes its input from.

    Args:
        queue (Queue): Queue stored as ``_input_queue``; it replaces any
            queue set earlier.
    """
    self._input_queue = queue
def set_model(self, model: dict):
    """Replace the model mapping held by this functor.

    Args:
        model (dict): Mapping stored as ``_model``. ``_process`` looks up
            the ``'vad'`` entry of this mapping.
    """
    self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config):
    """Store the audio configuration used by this functor.

    Args:
        audio_config (AudioBinary_Config): Configuration stored as
            ``_audio_config``; ``_process`` reads its ``chunk_size``.
    """
    self._audio_config = audio_config
def add_callback(self, callback: Callable):
    """Register one more callback on this functor.

    If ``_callback`` is not currently a list it is replaced by a fresh
    empty list before the callback is appended.

    Args:
        callback (Callable): Callable appended to the ``_callback`` list.
    """
    registered = self._callback
    if not isinstance(registered, list):
        # Recover from a non-list value by starting a new registry.
        registered = self._callback = []
    registered.append(callback)
def _process(self, data: bytes):
    """Buffer incoming audio bytes and run the VAD model on full buffers.

    ``data`` is appended to ``_audio_cache``. Once the cache holds at least
    ``chunk_size * 100`` bytes, the whole cache is passed to
    ``self._model['vad'].generate`` together with the streaming ``_cache``
    dict, the result is logged, and the byte cache is emptied.

    The registered callbacks are not invoked here; the result is only
    logged.

    Args:
        data (bytes): Audio bytes taken from the input queue.
    """
    self._audio_cache += data
    # Keep accumulating until the buffer reaches the inference threshold.
    if len(self._audio_cache) < self._audio_config.chunk_size*100:
        return
    vad_model = self._model['vad']
    result = vad_model.generate(
        input=self._audio_cache,
        cache=self._cache,
        chunk_size=self._audio_config.chunk_size,
        is_final=False,
    )
    logger.info(f"VADFunctor处理数据: {len(self._audio_cache)}, {result}")
    # Start a new buffer for the next batch of audio.
    self._audio_cache = b''
def _run(self):
    """Worker loop: consume the input queue until asked to stop.

    Each item taken from ``_input_queue`` is passed to ``_process`` and then
    acknowledged with ``task_done``. When the queue stays empty for the
    1-second timeout, the stop flag is checked and the loop ends if it is
    set; queued items are therefore drained before the loop exits.

    The stop flag is deliberately NOT cleared here. It is cleared by
    ``run()`` before the thread starts, so that a ``stop()`` issued between
    ``run()`` and the first iteration of this loop is not lost.

    ``_is_running`` is set to True on entry and back to False whenever the
    loop exits, including when an exception ends it.

    Raises:
        Exception: Any error other than ``queue.Empty`` raised while
            fetching or processing an item is logged and re-raised.
    """
    with self._status_lock:
        self._is_running = True
    try:
        while self._is_running:
            try:
                data = self._input_queue.get(True, timeout=1)
                self._process(data)
                self._input_queue.task_done()
            except Empty:
                # Queue idle for 1s: leave only if a stop was requested.
                if self._stop_event:
                    break
            except Exception as e:
                logger.error(f"VADFunctor运行时发生错误: {e}")
                raise
    finally:
        # Keep the status flag truthful even if the loop died on an error.
        with self._status_lock:
            self._is_running = False


def run(self):
    """Start the ``_run`` worker thread and return the thread object.

    If a worker started by an earlier call is still alive, that thread is
    returned and no second consumer is started. Otherwise the stop flag is
    cleared and a new daemon thread is started.

    Returns:
        threading.Thread: The running worker thread.
    """
    current = getattr(self, "_thread", None)
    if current is not None and current.is_alive():
        return current
    # Clear the stop request before the thread exists, so a later stop()
    # can never be overwritten by the worker's own start-up.
    with self._status_lock:
        self._stop_event = False
    self._thread = threading.Thread(target=self._run, daemon=True)
    self._thread.start()
    return self._thread
def _pre_check(self):
    """Pre-check hook; currently a no-op that returns None."""
    pass
def stop(self):
    """Request the worker to stop and wait until it has finished.

    The stop flag is set under the status lock, then the worker thread (if
    one was ever started by ``run()``) is joined. The worker only observes
    the flag after its input queue has been empty for its polling timeout.
    Calling ``stop()`` before ``run()`` is safe: the flags are updated and
    no join is attempted.

    Returns:
        bool: Always True.
    """
    with self._status_lock:
        self._stop_event = True
    # ``_thread`` only exists after run(); tolerate stop() being called first.
    worker = getattr(self, "_thread", None)
    if worker is not None:
        worker.join()
    with self._status_lock:
        self._is_running = False
    return True
# class VAD:
# def __init__(
# self,
# VAD_model=None,
# audio_config: AudioBinary_Config = None,
# callback: Callable = None,
# ):
# # vad model
# self.VAD_model = VAD_model
# if self.VAD_model is None:
# self.VAD_model = AutoModel(
# model="fsmn-vad", model_revision="v2.0.4", disable_update=True
# )
# # audio config
# self.audio_config = audio_config
# # vad result
# self.vad_result = VADResponse(time_chunk_index_callback=callback)
# # audio binary poll
# self.audio_chunk = AudioChunk(audio_config=self.audio_config)
# self.cache = {}
# def push_binary_data(
# self,
# binary_data: bytes,
# ):
# # 压入二进制数据
# self.audio_chunk.add_chunk(binary_data)
# # 处理音频块
# res = self.VAD_model.generate(
# input=binary_data,
# cache=self.cache,
# chunk_size=self.audio_config.chunk_size,
# is_final=False,
# )
# # print("VAD generate", res)
# if len(res[0]["value"]):
# self.vad_result += VADResponse.from_raw(res)
# def set_callback(
# self,
# callback: Callable,
# ):
# self.vad_result.time_chunk_index_callback = callback
# def process_vad_result(self, callback: Callable = None):
# # 处理VAD结果
# callback = (
# callback
# if callback is not None
# else self.vad_result.time_chunk_index_callback
# )
# self.vad_result.process_time_chunk(
# lambda x: callback(
# AudioBinary_Chunk(
# start_time=x["start_time"],
# end_time=x["end_time"],
# chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]),
# )
# )
# )

View File

@ -53,12 +53,12 @@ class ModelLoader:
# 直接调用等于调用self.models # 直接调用等于调用self.models
return self.models return self.models
def _load_model(self, args, model_type): def _load_model(self, input_model_args: dict, model_type: str):
""" """
加载单个模型 加载单个模型
参数: 参数:
args: 命令行参数, 包含模型配置 model_args: 模型加载字典
model_type: 模型类型, 用于确定使用哪个模型参数 model_type: 模型类型, 用于确定使用哪个模型参数
返回: 返回:
@ -81,12 +81,13 @@ class ModelLoader:
if key in ["model", "model_revision"]: if key in ["model", "model_revision"]:
# 特殊处理model和model_revision, 因为它们需要model_type前缀 # 特殊处理model和model_revision, 因为它们需要model_type前缀
if key == "model": if key == "model":
value = getattr(args, f"{model_type}_model", None) value = input_model_args.get(f"{model_type}_model", None)
else: else:
value = getattr(args, f"{model_type}_model_revision", None) value = input_model_args.get(f"{model_type}_model_revision", None)
else: else:
value = getattr(args, key, None) value = input_model_args.get(key, None)
if value is not None: if value is not None:
logger.info("替换%s模型参数: %s = %s", model_type, key, value)
model_args[key] = value model_args[key] = value
# 验证必要参数 # 验证必要参数
if not model_args["model"]: if not model_args["model"]:
@ -113,12 +114,13 @@ class ModelLoader:
# 初始化模型字典 # 初始化模型字典
self.models = {} self.models = {}
# 加载离线ASR模型 # 加载离线ASR模型
self.models["asr"] = self._load_model(args, "asr") # 检查对应键是否存在
# 2. 加载在线ASR模型 model_list = ['asr', 'asr_online', 'vad', 'punc']
self.models["asr_streaming"] = self._load_model(args, "asr_online") for model_name in model_list:
# 3. 加载VAD模型 name_model = f"{model_name}_model"
self.models["vad"] = self._load_model(args, "vad") name_model_revision = f"{model_name}_model_revision"
# 4. 加载标点符号模型(如果指定) if name_model in args:
self.models["punc"] = self._load_model(args, "punc") logger.info("加载%s模型", model_name)
self.models[model_name] = self._load_model(args, model_name)
logger.info("所有模型加载完成") logger.info("所有模型加载完成")
return self.models return self.models

View File

@ -1,5 +1,5 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from src.audiochunk import AudioBinary from typing import List
class AudioBinary_Config(BaseModel): class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息""" """二进制音频块配置信息"""
@ -15,14 +15,6 @@ class AudioBinary_Config(BaseModel):
def AudioBinary_Config_from_dict(cls, data: dict): def AudioBinary_Config_from_dict(cls, data: dict):
return cls(**data) return cls(**data)
class AudioBinary_Slice(BaseModel):
"""音频块切片"""
target_Binary: AudioBinary = Field(description="目标音频块", default=None)
start_index: int = Field(description="开始位置", default=0)
end_index: int = Field(description="结束位置", default=0)
def __call__(self):
return self.target_Binary(self.start_index, self.end_index)
class _AudioBinary_data(BaseModel): class _AudioBinary_data(BaseModel):
"""音频二进制数据""" """音频二进制数据"""
@ -34,3 +26,12 @@ class AudioBinary_data_list(BaseModel):
def __call__(self): def __call__(self):
return self.binary_data_list return self.binary_data_list
class AudioBinary_Slice(BaseModel):
    """Slice of an audio block: a target object plus a start and end index."""

    # Object the slice refers to; defaults to None.
    target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
    # Start position of the slice; defaults to 0.
    start_index: int = Field(description="开始位置", default=0)
    # End position of the slice; defaults to 0.
    end_index: int = Field(description="结束位置", default=0)

    def __call__(self):
        """Call ``target_Binary`` with ``(start_index, end_index)`` and return the result."""
        # NOTE(review): AudioBinary_data_list.__call__ as shown in this file
        # accepts no arguments besides self, so this two-argument call looks
        # like it would raise TypeError. Confirm the intended slicing API.
        return self.target_Binary(self.start_index, self.end_index)

View File

@ -18,7 +18,6 @@ class ASRPipeline(PipelineBase):
self._config: Dict[str, Any] = {} self._config: Dict[str, Any] = {}
self._funtor_dict: Dict[str, Any] = {} self._funtor_dict: Dict[str, Any] = {}
self._subqueue_dict: Dict[str, Any] = {} self._subqueue_dict: Dict[str, Any] = {}
self._is_baked: bool = False self._is_baked: bool = False
def set_config(self, config: Dict[str, Any]) -> None: def set_config(self, config: Dict[str, Any]) -> None:
@ -57,17 +56,17 @@ class ASRPipeline(PipelineBase):
try: try:
from src.funtor import FuntorFactory from src.funtor import FuntorFactory
# 加载VAD、asr、spk funtor # 加载VAD、asr、spk funtor
self._funtor_dict["vad"] = FuntorFactory.get_funtor( self._funtor_dict["vad"] = FuntorFactory.make_funtor(
funtor_name = "vad", funtor_name = "vad",
config = self._config, config = self._config,
models = self._models models = self._models
) )
self._funtor_dict["asr"] = FuntorFactory.get_funtor( self._funtor_dict["asr"] = FuntorFactory.make_funtor(
funtor_name = "asr", funtor_name = "asr",
config = self._config, config = self._config,
models = self._models models = self._models
) )
self._funtor_dict["spk"] = FuntorFactory.get_funtor( self._funtor_dict["spk"] = FuntorFactory.make_funtor(
funtor_name = "spk", funtor_name = "spk",
config = self._config, config = self._config,
models = self._models models = self._models

110
src/utils/data_format.py Normal file
View File

@ -0,0 +1,110 @@
"""
处理各类音频数据与bytes的转换
"""
import wave
from pydub import AudioSegment
import io
def wav_to_bytes(wav_path: str) -> bytes:
    """Read all audio frames of a WAV file and return them as bytes.

    The returned value is what ``readframes`` yields for every frame in the
    file, i.e. the sample data only; the WAV header is not part of it.

    Args:
        wav_path (str): Path of the WAV file to read.

    Returns:
        bytes: The frames of the file.

    Raises:
        FileNotFoundError: If ``wav_path`` does not exist. The original
            error is attached as ``__cause__``.
        wave.Error: If the file cannot be opened or read as WAV. The
            original error is attached as ``__cause__``.
    """
    try:
        with wave.open(wav_path, 'rb') as wf:
            # Read every frame in one call.
            return wf.readframes(wf.getnframes())
    except FileNotFoundError as err:
        # Re-raise with a message naming the path, keeping the cause chained.
        raise FileNotFoundError(f"错误未找到WAV文件 '{wav_path}'") from err
    except wave.Error as err:
        raise wave.Error(f"错误打开或读取WAV文件 '{wav_path}' 失败 - {err}") from err
def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int):
    """Write audio bytes to a WAV file with the given format parameters.

    Args:
        bytes_data (bytes): Audio frame data to write.
        wav_path (str): Destination path of the WAV file.
        nchannels (int): Number of channels (e.g. 1 for mono, 2 for stereo).
        sampwidth (int): Sample width in bytes (e.g. 2 for 16-bit audio).
        framerate (int): Sample rate (e.g. 44100, 16000).

    Raises:
        wave.Error: If the ``wave`` module rejects the parameters or the
            write fails. The original error is attached as ``__cause__``.
        Exception: For any other failure while writing (for example an
            unwritable path). The original error is attached as
            ``__cause__``.
    """
    try:
        with wave.open(wav_path, 'wb') as wf:
            wf.setnchannels(nchannels)
            wf.setsampwidth(sampwidth)
            wf.setframerate(framerate)
            wf.writeframes(bytes_data)
    except wave.Error as err:
        raise wave.Error(f"错误写入WAV文件 '{wav_path}' 失败 - {err}") from err
    except Exception as err:
        # Any other write failure is reported as a generic Exception; the
        # underlying error stays reachable through ``__cause__``.
        raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {err}") from err
def mp3_to_bytes(mp3_path: str) -> bytes:
    """Decode an MP3 file and return its raw audio data as bytes.

    The file is loaded with ``AudioSegment.from_mp3`` and the segment's
    ``raw_data`` is returned.

    Args:
        mp3_path (str): Path of the MP3 file.

    Returns:
        bytes: The ``raw_data`` of the decoded segment.

    Raises:
        FileNotFoundError: If ``mp3_path`` does not exist. The original
            error is attached as ``__cause__``.
        Exception: If loading or decoding fails for any other reason. The
            original error is attached as ``__cause__``.
    """
    try:
        segment = AudioSegment.from_mp3(mp3_path)
        return segment.raw_data
    except FileNotFoundError as err:
        raise FileNotFoundError(f"错误未找到MP3文件 '{mp3_path}'") from err
    except Exception as err:
        # pydub can raise several decode-related error types; they are all
        # reported as a generic Exception with the cause chained.
        raise Exception(f"错误处理MP3文件 '{mp3_path}' 失败 - {err}") from err
def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"):
    """Encode raw PCM bytes as an MP3 file.

    An ``AudioSegment`` is built from the raw data and format parameters,
    then exported to ``mp3_path`` in MP3 format.

    Args:
        bytes_data (bytes): Raw PCM audio data.
        mp3_path (str): Destination path of the MP3 file.
        frame_rate (int): Sample rate of the PCM data (e.g. 44100).
        channels (int): Channel count of the PCM data (e.g. 1 or 2).
        sample_width (int): Sample width in bytes (e.g. 2 for 16-bit).
        bitrate (str): MP3 bitrate passed to the exporter
            (e.g. "128k", "192k", "320k").

    Raises:
        Exception: If building the segment or exporting fails. The original
            error is attached as ``__cause__``.
    """
    try:
        segment = AudioSegment(
            data=bytes_data,
            sample_width=sample_width,
            frame_rate=frame_rate,
            channels=channels,
        )
        segment.export(mp3_path, format="mp3", bitrate=bitrate)
    except Exception as err:
        raise Exception(f"错误转换或写入MP3文件 '{mp3_path}' 失败 - {err}") from err

View File

@ -1,22 +1,8 @@
from tests.functor.vad_test import test_vad_functor
from src.utils.logger import get_module_logger, setup_root_logger from src.utils.logger import get_module_logger, setup_root_logger
from tests.modelsuse import vad_model_use_online_logic, asr_model_use_offline
import json
setup_root_logger(level="INFO",log_file="logs/test_main.log")
setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
logger.info("开始测试") logger.info("开始测试VAD函数器")
vad_result = vad_model_use_online_logic("tests/vad_example.wav") test_vad_functor()
logger.info("测试结束")
if vad_result is None:
logger.warning("VAD结果为空")
else:
logger.info(f"VAD结果: {vad_result}")
asr_result = asr_model_use_offline("tests/vad_example.wav")
# asr_result str->dict
setup_root_logger(level="INFO",log_file="logs/test_main.log")
result = asr_result[0]['sentence_info']
for item in result:
#[{'start': 70, 'end': 2340, 'sentence': '试 错 的 过 程 很 简 单', 'timestamp': [[380, 620], [640, 740], [740, 940], [940, 1020], [1020, 1260], [1500, 1740], [1740, 1840], [1840, 2135]], 'spk': 0}
logger.info(f"spk[{item['spk']}] [{item['start']}ms:{item['end']}ms] {item['sentence'].replace(' ', '')}")

58
tests/functor/vad_test.py Normal file
View File

@ -0,0 +1,58 @@
"""
Functor测试
VAD测试
"""
from src.functor.vad_functor import VADFunctor
from queue import Queue, Empty
from src.model_loader import ModelLoader
from src.models import AudioBinary_Config
from src.utils.data_format import wav_to_bytes
import time
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
model_loader = ModelLoader()
def test_vad_functor():
    """Smoke-test VADFunctor by streaming a WAV file through its input queue.

    Loads the VAD model, wires a VADFunctor to a fresh queue, audio config,
    print callback and model, starts the worker, feeds
    ``tests/vad_example.wav`` in fixed-size byte slices, waits 10 seconds
    and stops the functor. Nothing is asserted; output goes to stdout and
    the log.
    """
    # Load only the VAD model.
    model_loader.load_models({
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "auto_update": False,
    })

    # Build the functor and connect its input queue.
    functor = VADFunctor()
    feed = Queue()
    functor.set_input_queue(feed)

    # Audio configuration for the functor.
    audio_cfg = AudioBinary_Config(
        chunk_size=960,
        chunk_stride=480,
        sample_rate=16000,
        sample_width=2,
        channels=1
    )
    functor.set_audio_config(audio_cfg)

    # Callback that prints whatever it is given.
    functor.add_callback(lambda x: print(x))

    # Hand over the loaded VAD model.
    functor.set_model({'vad': model_loader.models['vad']})

    # Start the worker thread.
    functor.run()

    # Read the sample audio and push it in slices of ``step`` bytes.
    pcm = wav_to_bytes("tests/vad_example.wav")
    step = 960000
    total = len(pcm)
    print(f"f_binary: {total}, chunk_size: {step}, clip_num: {total // step}")
    offset = 0
    while offset < total:
        feed.put(pcm[offset:offset + step])
        offset += step

    # Fixed wait so the worker can consume the queue before it is stopped.
    time.sleep(10)
    functor.stop()