[代码重构中]测试VADFunctor中,发现字节流推理问题,待进一步研究
This commit is contained in:
parent
f245c6e9df
commit
b569b7e63d
@ -1,4 +1,4 @@
|
||||
from .vad_functor import VAD
|
||||
from .model_loader import load_models
|
||||
from .vad_functor import VADFunctor
|
||||
from .base import FunctorFactory
|
||||
|
||||
__all__ = ["VAD", "load_models"]
|
||||
__all__ = ["VADFunctor", "FunctorFactory"]
|
@ -1,78 +1,52 @@
|
||||
from typing import Callable
|
||||
"""
|
||||
Functor基础模块
|
||||
|
||||
class BaseFunctor:
|
||||
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。
|
||||
基类提供了数据处理的基本框架,包括:
|
||||
- 回调函数管理
|
||||
- 模型配置管理
|
||||
- 线程运行控制
|
||||
|
||||
主要类:
|
||||
BaseFunctor: Functor抽象类
|
||||
FunctorFactory: Functor工厂类
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List
|
||||
from queue import Queue
|
||||
|
||||
class BaseFunctor(ABC):
|
||||
"""
|
||||
基础函数器类, 提供数据处理的基本框架
|
||||
Functor抽象类
|
||||
|
||||
该类实现了数据处理的基本接口, 包括数据推送、处理和回调机制。
|
||||
所有具体的功能实现类都应该继承这个基类。
|
||||
该抽象类规定了所有的Functor类必须实现run()方法启动自身线程
|
||||
|
||||
属性:
|
||||
_data (dict): 存储处理数据的字典
|
||||
_callback (Callable): 处理完成后的回调函数
|
||||
_model (dict): 存储模型相关的配置和实例
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data: dict or bytes = {},
|
||||
callback: Callable = None,
|
||||
model: dict = {},
|
||||
def __init__(
|
||||
self
|
||||
):
|
||||
"""
|
||||
初始化函数器
|
||||
|
||||
参数:
|
||||
data (dict or bytes): 初始数据, 可以是字典或字节数据
|
||||
callback (Callable): 处理完成后的回调函数
|
||||
model (dict): 模型相关的配置和实例
|
||||
"""
|
||||
self._data: dict = {}
|
||||
self.push_data(data)
|
||||
self._callback: Callable = callback
|
||||
self._model: dict = model
|
||||
pass
|
||||
self._callback: List[Callable] = []
|
||||
self._model: dict = {}
|
||||
|
||||
def __call__(self, data = None):
|
||||
def add_callback(self, callback: Callable):
|
||||
"""
|
||||
使类实例可调用, 处理数据并触发回调
|
||||
|
||||
参数:
|
||||
data: 要处理的数据, 如果为None则处理已存储的数据
|
||||
|
||||
返回:
|
||||
处理结果
|
||||
"""
|
||||
# 如果传入数据, 则压入数据
|
||||
if data is not None:
|
||||
self.push_data(data)
|
||||
# 处理数据
|
||||
result = self.process()
|
||||
# 如果回调函数存在, 则触发回调
|
||||
if self._callback is not None and callable(self._callback):
|
||||
self._callback(result)
|
||||
return result
|
||||
|
||||
def __add__(self, other):
|
||||
"""
|
||||
重载加法运算符, 用于合并数据
|
||||
|
||||
参数:
|
||||
other: 要合并的数据
|
||||
|
||||
返回:
|
||||
self: 返回当前实例, 支持链式调用
|
||||
"""
|
||||
self.push_data(other)
|
||||
return self
|
||||
|
||||
def set_callback(self, callback: Callable):
|
||||
"""
|
||||
设置回调函数
|
||||
添加回调函数
|
||||
|
||||
参数:
|
||||
callback (Callable): 新的回调函数
|
||||
"""
|
||||
self._callback = callback
|
||||
self._callback.append(callback)
|
||||
|
||||
def set_model(self, model: dict):
|
||||
"""
|
||||
@ -83,20 +57,73 @@ class BaseFunctor:
|
||||
"""
|
||||
self._model = model
|
||||
|
||||
def push_data(self, data):
|
||||
def set_input_queue(self, queue: Queue):
|
||||
"""
|
||||
推送数据到处理器
|
||||
设置输入队列
|
||||
|
||||
参数:
|
||||
data: 要处理的数据
|
||||
queue (Queue): 新的输入队列
|
||||
"""
|
||||
pass
|
||||
self._input_queue = queue
|
||||
|
||||
def process(self):
|
||||
@abstractmethod
|
||||
def _run(self):
|
||||
"""
|
||||
处理数据的核心方法
|
||||
线程运行逻辑
|
||||
|
||||
返回:
|
||||
处理结果
|
||||
当达到条件时触发callback
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
"""
|
||||
启动_run方法线程
|
||||
|
||||
返回:
|
||||
线程实例
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _pre_check(self):
|
||||
"""
|
||||
预检查
|
||||
|
||||
返回:
|
||||
预检查结果
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def stop(self):
|
||||
"""
|
||||
停止线程
|
||||
|
||||
返回:
|
||||
停止结果
|
||||
"""
|
||||
|
||||
class FunctorFactory:
    """
    Factory for Functor instances.

    Intended as the single place where Functor objects are created and
    configured.

    Methods:
        make_funtor(funtor_name, config, models): create and configure a
            Functor instance.
    """

    @staticmethod
    def make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
        """
        Create and configure a Functor instance.

        NOTE(review): as visible here the method has no body besides this
        docstring, so calling it returns None — TODO implement the lookup.

        Args:
            funtor_name (str): Name of the Functor to build.
            config (dict): Configuration values.
            models (dict): Model information.

        Returns:
            BaseFunctor: The created Functor instance (once implemented).
        """
|
||||
|
@ -2,59 +2,168 @@ from funasr import AutoModel
|
||||
from typing import List, Dict, Any
|
||||
from src.models import VADResponse
|
||||
from src.models import AudioBinary_Config
|
||||
from src.functor.audiochunk import AudioChunk
|
||||
from src.models import AudioBinary_Chunk
|
||||
from src.models import AudioBinary_data_list
|
||||
from src.models import AudioBinary_Slice
|
||||
from typing import Callable
|
||||
from src.functor.base import BaseFunctor
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
class VAD:
|
||||
# 日志
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
def __init__(self,
|
||||
VAD_model = None,
|
||||
audio_config : AudioBinary_Config = None,
|
||||
callback: Callable = None,
|
||||
):
|
||||
# vad model
|
||||
self.VAD_model = VAD_model
|
||||
if self.VAD_model is None:
|
||||
self.VAD_model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||
# audio config
|
||||
self.audio_config = audio_config
|
||||
# vad result
|
||||
self.vad_result = VADResponse(time_chunk_index_callback=callback)
|
||||
# audio binary poll
|
||||
self.audio_chunk = AudioChunk(
|
||||
audio_config=self.audio_config
|
||||
)
|
||||
self.cache = {}
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
def push_binary_data(self,
|
||||
binary_data: bytes,
|
||||
):
|
||||
# 压入二进制数据
|
||||
self.audio_chunk.add_chunk(binary_data)
|
||||
# 处理音频块
|
||||
res = self.VAD_model.generate(input=binary_data,
|
||||
cache=self.cache,
|
||||
chunk_size=self.audio_config.chunk_size,
|
||||
is_final=False)
|
||||
# print("VAD generate", res)
|
||||
if len(res[0]["value"]):
|
||||
self.vad_result += VADResponse.from_raw(res)
|
||||
|
||||
def set_callback(self,
|
||||
callback: Callable,
|
||||
):
|
||||
self.vad_result.time_chunk_index_callback = callback
|
||||
class VADFunctor(BaseFunctor):
|
||||
def __init__(self):
    """
    Create an idle VAD functor.

    Nothing is loaded or started here; the queue, model and audio
    configuration are injected afterwards through the setter methods.
    """
    super().__init__()
    # Model registry; _process() looks up the 'vad' entry.
    self._model: dict = {}
    # Callbacks registered through add_callback().
    self._callback: List[Callable] = []
    # Guards the _is_running / _stop_event flags.
    self._status_lock: threading.Lock = threading.Lock()
    self._input_queue: Queue = None
    self._audio_config: AudioBinary_Config = None
    self._is_running: bool = False
    # Plain boolean flag (not a threading.Event) polled by the worker loop.
    self._stop_event: bool = False
    # Worker thread handle. Initialised here so that code inspecting it
    # before run() sees None rather than a missing attribute.
    self._thread = None
    # Bytes accumulated until enough audio is buffered for one inference.
    self._audio_cache: bytes = b''
    # Dict handed to the model's generate() call as `cache`.
    self._cache: dict = {}
|
||||
|
||||
def process_vad_result(self, callback: Callable = None):
|
||||
# 处理VAD结果
|
||||
callback = callback if callback is not None else self.vad_result.time_chunk_index_callback
|
||||
self.vad_result.process_time_chunk(
|
||||
lambda x : callback(
|
||||
AudioBinary_Chunk(
|
||||
start_time=x["start_time"],
|
||||
end_time=x["end_time"],
|
||||
chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"])
|
||||
)
|
||||
def set_input_queue(self, queue: Queue):
    """Attach the queue that the worker loop reads its input from."""
    self._input_queue = queue
|
||||
|
||||
def set_model(self, model: dict):
    """Replace the model registry; _process() reads its 'vad' entry."""
    self._model = model
|
||||
|
||||
def set_audio_config(self, audio_config: AudioBinary_Config):
    """Store the audio configuration; _process() reads its chunk_size."""
    self._audio_config = audio_config
|
||||
|
||||
def add_callback(self, callback: Callable):
    """
    Register a callback.

    If the current callback container is not a list it is replaced by a
    fresh list before the callback is appended.
    """
    registered = self._callback
    if not isinstance(registered, list):
        registered = []
        self._callback = registered
    registered.append(callback)
|
||||
|
||||
def _process(self, data: bytes):
    """
    Buffer incoming audio bytes and run VAD once enough has accumulated.

    The bytes are appended to the internal buffer. When the buffer length
    reaches ``chunk_size * 100`` the whole buffer is passed to the 'vad'
    model's ``generate`` call, the result is logged, and the buffer is
    cleared. The result is only logged here; no registered callback is
    invoked.

    NOTE(review): the threshold compares a byte count against
    ``chunk_size * 100``; the unit of ``chunk_size`` is not visible
    here — confirm the intended buffer size.
    """
    self._audio_cache += data
    threshold = self._audio_config.chunk_size*100
    if len(self._audio_cache) < threshold:
        # Not enough audio buffered yet; wait for more data.
        return
    vad_model = self._model['vad']
    result = vad_model.generate(
        input=self._audio_cache,
        cache=self._cache,
        chunk_size=self._audio_config.chunk_size,
        is_final=False,
    )
    logger.info(f"VADFunctor处理数据: {len(self._audio_cache)}, {result}")
    self._audio_cache = b''
|
||||
|
||||
|
||||
def _run(self):
    """
    Worker loop: consume the input queue until asked to stop.

    Each item taken from the queue is handed to ``_process``. When the
    queue stays empty for one second the stop flag is checked, and the
    loop ends if it is set.

    Every dequeued item is acknowledged with ``task_done()`` even when
    processing fails, so a ``Queue.join()`` on the input queue cannot
    block forever. The running flag is cleared whenever the loop exits,
    normally or through an exception.

    Raises:
        Exception: Any error other than an empty-queue timeout is logged
            and re-raised.
    """
    # Refresh the run state.
    with self._status_lock:
        self._is_running = True
        self._stop_event = False
    try:
        while self._is_running:
            try:
                data = self._input_queue.get(True, timeout=1)
            except Empty:
                # Idle: this is the only point where a stop request is honoured.
                if self._stop_event:
                    break
                continue
            try:
                self._process(data)
            finally:
                # Acknowledge the item whether or not processing succeeded.
                self._input_queue.task_done()
    except Exception as e:
        logger.error(f"VADFunctor运行时发生错误: {e}")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        with self._status_lock:
            self._is_running = False
|
||||
|
||||
def run(self):
    """
    Start the ``_run`` worker in a daemon thread and return the thread.

    If a previously started worker is still alive it is returned as-is
    instead of starting a second one, so two workers never consume the
    same queue at the same time.

    Returns:
        threading.Thread: The worker thread.
    """
    existing = getattr(self, "_thread", None)
    if existing is not None and existing.is_alive():
        return existing
    self._thread = threading.Thread(target=self._run, daemon=True)
    self._thread.start()
    return self._thread
|
||||
|
||||
def _pre_check(self):
|
||||
pass
|
||||
|
||||
def stop(self):
    """
    Request the worker to stop and wait for it to finish.

    Sets the stop flag, joins the worker thread if there is one, then
    clears the running flag. Safe to call when ``run()`` was never
    called. If called from the worker thread itself the join is skipped,
    because a thread cannot join itself.

    Returns:
        bool: Always True.
    """
    with self._status_lock:
        self._stop_event = True
    worker = getattr(self, "_thread", None)
    if worker is not None and worker is not threading.current_thread():
        worker.join()
    with self._status_lock:
        self._is_running = False
    return True
|
||||
|
||||
|
||||
# class VAD:
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# VAD_model=None,
|
||||
# audio_config: AudioBinary_Config = None,
|
||||
# callback: Callable = None,
|
||||
# ):
|
||||
# # vad model
|
||||
# self.VAD_model = VAD_model
|
||||
# if self.VAD_model is None:
|
||||
# self.VAD_model = AutoModel(
|
||||
# model="fsmn-vad", model_revision="v2.0.4", disable_update=True
|
||||
# )
|
||||
# # audio config
|
||||
# self.audio_config = audio_config
|
||||
# # vad result
|
||||
# self.vad_result = VADResponse(time_chunk_index_callback=callback)
|
||||
# # audio binary poll
|
||||
# self.audio_chunk = AudioChunk(audio_config=self.audio_config)
|
||||
# self.cache = {}
|
||||
|
||||
# def push_binary_data(
|
||||
# self,
|
||||
# binary_data: bytes,
|
||||
# ):
|
||||
# # 压入二进制数据
|
||||
# self.audio_chunk.add_chunk(binary_data)
|
||||
# # 处理音频块
|
||||
# res = self.VAD_model.generate(
|
||||
# input=binary_data,
|
||||
# cache=self.cache,
|
||||
# chunk_size=self.audio_config.chunk_size,
|
||||
# is_final=False,
|
||||
# )
|
||||
# # print("VAD generate", res)
|
||||
# if len(res[0]["value"]):
|
||||
# self.vad_result += VADResponse.from_raw(res)
|
||||
|
||||
# def set_callback(
|
||||
# self,
|
||||
# callback: Callable,
|
||||
# ):
|
||||
# self.vad_result.time_chunk_index_callback = callback
|
||||
|
||||
# def process_vad_result(self, callback: Callable = None):
|
||||
# # 处理VAD结果
|
||||
# callback = (
|
||||
# callback
|
||||
# if callback is not None
|
||||
# else self.vad_result.time_chunk_index_callback
|
||||
# )
|
||||
# self.vad_result.process_time_chunk(
|
||||
# lambda x: callback(
|
||||
# AudioBinary_Chunk(
|
||||
# start_time=x["start_time"],
|
||||
# end_time=x["end_time"],
|
||||
# chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]),
|
||||
# )
|
||||
# )
|
||||
# )
|
||||
|
@ -53,12 +53,12 @@ class ModelLoader:
|
||||
# 直接调用等于调用self.models
|
||||
return self.models
|
||||
|
||||
def _load_model(self, args, model_type):
|
||||
def _load_model(self, input_model_args: dict, model_type: str):
|
||||
"""
|
||||
加载单个模型
|
||||
|
||||
参数:
|
||||
args: 命令行参数, 包含模型配置
|
||||
model_args: 模型加载字典
|
||||
model_type: 模型类型, 用于确定使用哪个模型参数
|
||||
|
||||
返回:
|
||||
@ -81,12 +81,13 @@ class ModelLoader:
|
||||
if key in ["model", "model_revision"]:
|
||||
# 特殊处理model和model_revision, 因为它们需要model_type前缀
|
||||
if key == "model":
|
||||
value = getattr(args, f"{model_type}_model", None)
|
||||
value = input_model_args.get(f"{model_type}_model", None)
|
||||
else:
|
||||
value = getattr(args, f"{model_type}_model_revision", None)
|
||||
value = input_model_args.get(f"{model_type}_model_revision", None)
|
||||
else:
|
||||
value = getattr(args, key, None)
|
||||
value = input_model_args.get(key, None)
|
||||
if value is not None:
|
||||
logger.info("替换%s模型参数: %s = %s", model_type, key, value)
|
||||
model_args[key] = value
|
||||
# 验证必要参数
|
||||
if not model_args["model"]:
|
||||
@ -113,12 +114,13 @@ class ModelLoader:
|
||||
# 初始化模型字典
|
||||
self.models = {}
|
||||
# 加载离线ASR模型
|
||||
self.models["asr"] = self._load_model(args, "asr")
|
||||
# 2. 加载在线ASR模型
|
||||
self.models["asr_streaming"] = self._load_model(args, "asr_online")
|
||||
# 3. 加载VAD模型
|
||||
self.models["vad"] = self._load_model(args, "vad")
|
||||
# 4. 加载标点符号模型(如果指定)
|
||||
self.models["punc"] = self._load_model(args, "punc")
|
||||
# 检查对应键是否存在
|
||||
model_list = ['asr', 'asr_online', 'vad', 'punc']
|
||||
for model_name in model_list:
|
||||
name_model = f"{model_name}_model"
|
||||
name_model_revision = f"{model_name}_model_revision"
|
||||
if name_model in args:
|
||||
logger.info("加载%s模型", model_name)
|
||||
self.models[model_name] = self._load_model(args, model_name)
|
||||
logger.info("所有模型加载完成")
|
||||
return self.models
|
||||
|
@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from src.audiochunk import AudioBinary
|
||||
from typing import List
|
||||
|
||||
class AudioBinary_Config(BaseModel):
|
||||
"""二进制音频块配置信息"""
|
||||
@ -15,14 +15,6 @@ class AudioBinary_Config(BaseModel):
|
||||
def AudioBinary_Config_from_dict(cls, data: dict):
|
||||
return cls(**data)
|
||||
|
||||
class AudioBinary_Slice(BaseModel):
|
||||
"""音频块切片"""
|
||||
target_Binary: AudioBinary = Field(description="目标音频块", default=None)
|
||||
start_index: int = Field(description="开始位置", default=0)
|
||||
end_index: int = Field(description="结束位置", default=0)
|
||||
|
||||
def __call__(self):
|
||||
return self.target_Binary(self.start_index, self.end_index)
|
||||
|
||||
class _AudioBinary_data(BaseModel):
|
||||
"""音频二进制数据"""
|
||||
@ -34,3 +26,12 @@ class AudioBinary_data_list(BaseModel):
|
||||
|
||||
def __call__(self):
|
||||
return self.binary_data_list
|
||||
|
||||
class AudioBinary_Slice(BaseModel):
    """A slice (start_index to end_index) over an AudioBinary_data_list."""
    # Container the slice refers to.
    target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
    # Start position of the slice.
    start_index: int = Field(description="开始位置", default=0)
    # End position of the slice.
    end_index: int = Field(description="结束位置", default=0)

    def __call__(self):
        """
        Return the referenced portion of the target data.

        AudioBinary_data_list.__call__ takes no arguments and returns its
        ``binary_data_list``, so the range is applied by slicing that
        result instead of being passed as call arguments.

        NOTE(review): assumes ``binary_data_list`` is a sliceable sequence
        and that end_index is exclusive — TODO confirm against callers.
        """
        return self.target_Binary()[self.start_index:self.end_index]
|
@ -18,7 +18,6 @@ class ASRPipeline(PipelineBase):
|
||||
self._config: Dict[str, Any] = {}
|
||||
self._funtor_dict: Dict[str, Any] = {}
|
||||
self._subqueue_dict: Dict[str, Any] = {}
|
||||
|
||||
self._is_baked: bool = False
|
||||
|
||||
def set_config(self, config: Dict[str, Any]) -> None:
|
||||
@ -57,17 +56,17 @@ class ASRPipeline(PipelineBase):
|
||||
try:
|
||||
from src.funtor import FuntorFactory
|
||||
# 加载VAD、asr、spk funtor
|
||||
self._funtor_dict["vad"] = FuntorFactory.get_funtor(
|
||||
self._funtor_dict["vad"] = FuntorFactory.make_funtor(
|
||||
funtor_name = "vad",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
self._funtor_dict["asr"] = FuntorFactory.get_funtor(
|
||||
self._funtor_dict["asr"] = FuntorFactory.make_funtor(
|
||||
funtor_name = "asr",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
self._funtor_dict["spk"] = FuntorFactory.get_funtor(
|
||||
self._funtor_dict["spk"] = FuntorFactory.make_funtor(
|
||||
funtor_name = "spk",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
|
110
src/utils/data_format.py
Normal file
110
src/utils/data_format.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""
|
||||
处理各类音频数据与bytes的转换
|
||||
"""
|
||||
import wave
|
||||
from pydub import AudioSegment
|
||||
import io
|
||||
|
||||
def wav_to_bytes(wav_path: str) -> bytes:
    """
    Read every audio frame of a WAV file and return the frame bytes.

    Only the frame data is returned; the WAV header is not part of the
    result.

    Args:
        wav_path (str): Path of the WAV file.

    Returns:
        bytes: The frame data of the whole file.

    Raises:
        FileNotFoundError: If the WAV file does not exist.
        wave.Error: If the file is not a valid WAV file.
    """
    try:
        with wave.open(wav_path, 'rb') as wf:
            # Read all frames in one call.
            return wf.readframes(wf.getnframes())
    except FileNotFoundError as err:
        # Re-raise with a path-specific message, keeping the original as the cause.
        raise FileNotFoundError(f"错误:未找到WAV文件 '{wav_path}'") from err
    except wave.Error as err:
        raise wave.Error(f"错误:打开或读取WAV文件 '{wav_path}' 失败 - {err}") from err
|
||||
|
||||
|
||||
def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int):
    """
    Write raw audio bytes to a WAV file.

    Args:
        bytes_data (bytes): Audio frame data.
        wav_path (str): Destination path of the WAV file.
        nchannels (int): Number of channels (e.g. 1 for mono, 2 for stereo).
        sampwidth (int): Sample width in bytes (e.g. 2 for 16-bit audio).
        framerate (int): Sample rate (e.g. 44100, 16000).

    Raises:
        wave.Error: If the wave module rejects the parameters or the write.
        Exception: For any other failure while writing; the original
            error is attached as the cause.
    """
    try:
        with wave.open(wav_path, 'wb') as wf:
            wf.setnchannels(nchannels)
            wf.setsampwidth(sampwidth)
            wf.setframerate(framerate)
            wf.writeframes(bytes_data)
    except wave.Error as err:
        raise wave.Error(f"错误:写入WAV文件 '{wav_path}' 失败 - {err}") from err
    except Exception as err:
        # Any other write failure; chained so the real error type stays inspectable.
        raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {err}") from err
|
||||
|
||||
def mp3_to_bytes(mp3_path: str) -> bytes:
    """
    Decode an MP3 file and return its raw data bytes.

    The file is loaded with ``AudioSegment.from_mp3`` and the segment's
    ``raw_data`` is returned.

    Args:
        mp3_path (str): Path of the MP3 file.

    Returns:
        bytes: The decoded raw audio data.

    Raises:
        FileNotFoundError: If a FileNotFoundError occurs while loading.
        Exception: For any other loading/decoding failure; the original
            error is attached as the cause.
    """
    try:
        audio = AudioSegment.from_mp3(mp3_path)
        # Raw decoded data of the segment.
        return audio.raw_data
    except FileNotFoundError as err:
        raise FileNotFoundError(f"错误:未找到MP3文件 '{mp3_path}'") from err
    except Exception as err:  # pydub may raise several decode-related errors
        raise Exception(f"错误:处理MP3文件 '{mp3_path}' 失败 - {err}") from err
|
||||
|
||||
|
||||
def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"):
    """
    Encode raw PCM bytes as an MP3 file.

    Args:
        bytes_data (bytes): Raw PCM data.
        mp3_path (str): Destination path of the MP3 file.
        frame_rate (int): Sample rate of the PCM data (e.g. 44100).
        channels (int): Channel count of the PCM data (1 mono, 2 stereo).
        sample_width (int): Sample width in bytes (e.g. 2 for 16-bit).
        bitrate (str): MP3 bitrate (e.g. "128k", "192k", "320k").

    Raises:
        Exception: If building the segment or exporting fails; the
            original error is attached as the cause.
    """
    try:
        # Build an AudioSegment from the raw data.
        audio = AudioSegment(
            data=bytes_data,
            sample_width=sample_width,
            frame_rate=frame_rate,
            channels=channels,
        )
        # Export as MP3.
        audio.export(mp3_path, format="mp3", bitrate=bitrate)
    except Exception as err:
        raise Exception(f"错误:转换或写入MP3文件 '{mp3_path}' 失败 - {err}") from err
|
22
test_main.py
22
test_main.py
@ -1,22 +1,8 @@
|
||||
from tests.functor.vad_test import test_vad_functor
|
||||
from src.utils.logger import get_module_logger, setup_root_logger
|
||||
from tests.modelsuse import vad_model_use_online_logic, asr_model_use_offline
|
||||
import json
|
||||
setup_root_logger(level="INFO",log_file="logs/test_main.log")
|
||||
|
||||
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
logger.info("开始测试")
|
||||
vad_result = vad_model_use_online_logic("tests/vad_example.wav")
|
||||
logger.info("测试结束")
|
||||
if vad_result is None:
|
||||
logger.warning("VAD结果为空")
|
||||
else:
|
||||
logger.info(f"VAD结果: {vad_result}")
|
||||
|
||||
asr_result = asr_model_use_offline("tests/vad_example.wav")
|
||||
# asr_result str->dict
|
||||
setup_root_logger(level="INFO",log_file="logs/test_main.log")
|
||||
result = asr_result[0]['sentence_info']
|
||||
for item in result:
|
||||
#[{'start': 70, 'end': 2340, 'sentence': '试 错 的 过 程 很 简 单', 'timestamp': [[380, 620], [640, 740], [740, 940], [940, 1020], [1020, 1260], [1500, 1740], [1740, 1840], [1840, 2135]], 'spk': 0}
|
||||
logger.info(f"spk[{item['spk']}] [{item['start']}ms:{item['end']}ms] {item['sentence'].replace(' ', '')}")
|
||||
logger.info("开始测试VAD函数器")
|
||||
test_vad_functor()
|
58
tests/functor/vad_test.py
Normal file
58
tests/functor/vad_test.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""
|
||||
Functor测试
|
||||
VAD测试
|
||||
"""
|
||||
from src.functor.vad_functor import VADFunctor
|
||||
from queue import Queue, Empty
|
||||
from src.model_loader import ModelLoader
|
||||
from src.models import AudioBinary_Config
|
||||
from src.utils.data_format import wav_to_bytes
|
||||
import time
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
model_loader = ModelLoader()
|
||||
|
||||
def test_vad_functor():
    """
    Smoke-test VADFunctor end to end.

    Loads the VAD model, wires a functor to a queue, starts it, feeds a
    WAV file's frame bytes through the queue in fixed-size pieces, waits
    a fixed time, then stops the functor.
    """
    # Load the model.
    model_loader.load_models({
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "auto_update": False,
    })

    # Build the functor and the queue it reads from.
    vad_functor = VADFunctor()
    input_queue = Queue()
    vad_functor.set_input_queue(input_queue)

    # Audio configuration.
    audio_config = AudioBinary_Config(
        chunk_size=960,
        chunk_stride=480,
        sample_rate=16000,
        sample_width=2,
        channels=1
    )
    vad_functor.set_audio_config(audio_config)

    # Callback and model.
    vad_functor.add_callback(lambda x: print(x))
    vad_functor.set_model({'vad': model_loader.models['vad']})

    # Start the worker thread.
    vad_functor.run()

    # Load the audio and push it in pieces of `step` bytes.
    pcm_bytes = wav_to_bytes("tests/vad_example.wav")
    step = 960000
    print(f"f_binary: {len(pcm_bytes)}, chunk_size: {step}, clip_num: {len(pcm_bytes) // step}")
    for offset in range(0, len(pcm_bytes), step):
        input_queue.put(pcm_bytes[offset:offset+step])

    # Give the worker a fixed amount of time, then stop it.
    time.sleep(10)
    vad_functor.stop()
|
Loading…
x
Reference in New Issue
Block a user