[代码重构中]重构model_loader与audio_chunk,全局单例模式管理模型加载与audiobinary数据存储单元类。删除readme中不需要的MIT许可证。

This commit is contained in:
Ziyang.Zhang 2025-05-28 10:35:35 +08:00
parent 040fc57e02
commit 703a40e955
11 changed files with 363 additions and 304 deletions

View File

@ -108,6 +108,3 @@ docker-compose up -d
"is_final": false // 是否是最终结果
}
```
## 许可证
[MIT](LICENSE)

194
src/audio_chunk.py Normal file
View File

@ -0,0 +1,194 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
音频数据块管理类 - 用于存储和处理16KHz音频数据
"""
from typing import List, Optional, Dict
from src.models import AudioBinary_Config, AudioBinary_data_list
# 配置日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__, level="INFO")
class AudioBinary:
"""
音频数据存储单元
用于存储二进制数据
面向Slice, 向Slice提供数据与接口
self._audio_config: AudioBinary_Config -- 音频参数配置
self._binary_data_list: AudioBinary_data_list -- 音频数据列表
self._slice_listener: List[callable] -- 切片监听器
AudioBinary_Config: Dict -- 音频参数配置
AudioBinary_data_list: List[bytes] -- 音频数据列表
"""
def __init__(self, *args):
"""
初始化音频数据块
参数:
*args: 可变参数
"""
# 音频参数配置
self._audio_config = AudioBinary_Config()
# 音频片段
self._binary_data_list: AudioBinary_data_list = AudioBinary_data_list()
# 切片监听器
self._slice_listener: List = []
if isinstance(args, Dict):
self._audio_config = AudioBinary_Config.AudioBinary_Config_from_dict(args)
elif isinstance(args, AudioBinary_Config):
self._audio_config = args
else:
raise ValueError("参数类型错误")
def add_slice_listener(self, slice_listener: callable):
"""
添加切片监听器
参数:
slice_listener: callable -- 切片监听器
"""
self._slice_listener.append(slice_listener)
def __add__(self, other: bytes):
"""
__add__ "+" 运算符的重载
使用方法:
audio_binary = audio_binary + bytes
添加音频数据块 add_binary_data 等效
但可以链式调用, 方便使用
参数:
other: bytes --音频数据块
"""
self._binary_data_list.append(other)
return self
def __iadd__(self, other: bytes):
"""
__iadd__ "+=" 运算符的重载
使用方法:
audio_binary += bytes
添加音频数据块 add_binary_data 等效
但可以链式调用, 方便使用
参数:
other: bytes --音频数据块
"""
self._binary_data_list.append(other)
return self
def add_binary_data(self, binary_data: bytes):
"""
添加音频数据块
参数:
binary_data: bytes --音频数据块
"""
self._binary_data_list.append(binary_data)
def rewrite_binary_data(self, target_index: int, binary_data: bytes):
"""
重写音频数据块
参数:
target_index: int -- 目标索引
binary_data: bytes --音频数据块
"""
self._binary_data_list.rewrite(target_index, binary_data)
def get_binary_data(
self,
start: int = 0,
end: Optional[int] = None,
) -> Optional[bytes]:
"""
获取指定索引的音频数据块
参数:
start: 开始索引
end: 结束索引
返回:
List[bytes]: 音频数据块
"""
if start >= len(self._binary_data_list):
return None
if end is None:
end = start + 1
end = min(end, len(self._binary_data_list))
return self._binary_data_list[start:end]
class AudioChunk:
"""
音频数据块管理类
管理两部分内容, AudioBinary和Slice
AudioBinary用于内部存储字节数据
Slice是AudioBinary的切片,用于外部接口
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑
"""
_instance = None
def __new__(cls, *args, **kwargs):
"""
单例模式
"""
if cls._instance is None:
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
"""
初始化AudioChunk实例
"""
self._audio_binary_list: Dict[str, AudioBinary] = {}
self._slice_listener: List[callable] = []
def get_audio_binary(
self,
binary_name: Optional[str] = None,
audio_config: Optional[AudioBinary_Config] = None,
) -> AudioBinary:
"""
获取音频数据块
参数:
binary_name: str -- 音频数据块名称
返回:
AudioBinary: 音频数据块
"""
if binary_name is None:
binary_name = "default"
if binary_name not in self._audio_binary_list:
self._audio_binary_list[binary_name] = AudioBinary(audio_config)
return self._audio_binary_list[binary_name]
@staticmethod
def _time2size(time_ms: int, audio_config: AudioBinary_Config) -> int:
"""
将时间(ms)转换为数据大小(字节)
参数:
time_ms: int -- 时间(ms)
audio_config: AudioBinary_Config -- 音频参数配置
返回:
int: 数据大小(字节)
"""
# 时间(ms)到字节(bytes)计算方法为: 时间(ms) * 采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24) / 1000
time_s = time_ms / 1000
bytes_per_sample = audio_config.sample_width * audio_config.channel
return int(time_s * audio_config.sample_rate * bytes_per_sample)
@staticmethod
def _size2time(size: int, audio_config: AudioBinary_Config) -> int:
"""
将数据大小(字节)转换为时间(ms)
参数:
size: int -- 数据大小(字节)
audio_config: AudioBinary_Config -- 音频参数配置
返回:
int: 时间(ms)
"""
# 字节(bytes)到时间(ms)计算方法为: 字节(bytes) * 1000 / (采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24))
bytes_per_sample = audio_config.sample_width * audio_config.channel
time_ms = size * 1000 // (audio_config.sample_rate * bytes_per_sample)
return time_ms

View File

@ -1,178 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
音频数据块管理类 - 用于存储和处理16KHz音频数据
"""
import numpy as np
from src.utils.logger import get_module_logger
from typing import List, Optional, Union
from src.models import AudioBinary_Config
# 配置日志
logger = get_module_logger(__name__, level="INFO")
class AudioChunk:
"""音频数据块管理类用于存储和处理16KHz音频数据"""
def __init__(self,
max_duration_ms: int = 1000*60*60*10,
audio_config : AudioBinary_Config = None,
):
"""
初始化音频数据块管理器
参数:
max_duration_ms: 音频池最大留存时间(ms)默认10小时
audio_config: 音频配置信息
"""
# 音频参数
self.sample_rate = audio_config.sample_rate if audio_config is not None else 16000 # 采样率16KHz
self.sample_width = audio_config.sample_width if audio_config is not None else 2 # 采样位宽16bit
self.channels = audio_config.channels if audio_config is not None else 1 # 通道数:单声道
# 数据存储
self._max_duration_ms = max_duration_ms
self._max_chunk_size = self._time2size(max_duration_ms) # 最大数据大小
self._chunk = [] # 当前音频数据块列表
self._chunk_size = 0 # 当前数据总大小
self._offset = 0 # 当前偏移量
logger.info(f"初始化AudioChunk: 最大时长={max_duration_ms}ms, 最大数据大小={self._max_chunk_size}字节")
def add_chunk(self, chunk: Union[bytes, np.ndarray]) -> bool:
"""
添加音频数据块
参数:
chunk: 音频数据块可以是bytes或numpy数组
返回:
bool: 是否添加成功
"""
try:
# 检查数据格式
if isinstance(chunk, np.ndarray):
# 确保是16bit整数格式
if chunk.dtype != np.int16:
chunk = chunk.astype(np.int16)
# 转换为bytes
chunk = chunk.tobytes()
# 检查数据大小
if len(chunk) % (self.sample_width * self.channels) != 0:
logger.warning(f"音频数据大小不是{self.sample_width * self.channels}的倍数: {len(chunk)}")
return False
# 检查是否超过最大限制
if self._chunk_size + len(chunk) > self._max_chunk_size:
logger.warning("音频数据超过最大限制,将自动清除旧数据")
self.clear_chunk()
# 添加数据
self._chunk.append(chunk)
self._chunk_size += len(chunk)
return True
except Exception as e:
logger.error(f"添加音频数据块时出错: {e}")
return False
def get_chunk_binary(self, start: int = 0, end: Optional[int] = None) -> Optional[bytes]:
"""
获取指定索引的音频数据块
"""
print("[AudioChunk] get_chunk_binary", start, end)
if start >= len(self._chunk):
return None
if end is None or end > len(self._chunk):
end = len(self._chunk)
data = b''.join(self._chunk)
return data[start:end]
def get_chunk(self, start_ms: int = 0, end_ms: Optional[int] = None) -> Optional[bytes]:
"""
获取指定时间范围的音频数据
参数:
start_ms: 开始时间(ms)
end_ms: 结束时间(ms)None表示到末尾
返回:
Optional[bytes]: 音频数据如果获取失败则返回None
"""
try:
if not self._chunk:
return None
# 计算字节偏移
start_byte = self._time2size(start_ms)
end_byte = self._time2size(end_ms) if end_ms is not None else self._chunk_size
# 检查范围是否有效
if start_byte >= self._chunk_size or start_byte >= end_byte:
return None
# 获取数据
data = b''.join(self._chunk)
return data[start_byte:end_byte]
except Exception as e:
logger.error(f"获取音频数据块时出错: {e}")
return None
def get_duration(self) -> int:
"""
获取当前音频总时长(ms)
返回:
int: 音频时长(ms)
"""
return self._size2time(self._chunk_size)
def clear_chunk(self) -> None:
"""清除所有音频数据"""
self._chunk = []
self._chunk_size = 0
self._offset = 0
logger.info("已清除所有音频数据")
def _time2size(self, time_ms: int) -> int:
"""
将时间(ms)转换为数据大小(字节)
参数:
time_ms: 时间(ms)
返回:
int: 数据大小(字节)
"""
return int(time_ms * self.sample_rate * self.sample_width * self.channels / 1000)
def _size2time(self, size: int) -> int:
"""
将数据大小(字节)转换为时间(ms)
参数:
size: 数据大小(字节)
返回:
int: 时间(ms)
"""
return int(size * 1000 / (self.sample_rate * self.sample_width * self.channels))
# instance(start_ms, end_ms, use_offset=True)
def __call__(self, start_ms: int = 0, end_ms: Optional[int] = None, use_offset: bool = True) -> Optional[bytes]:
"""
获取指定时间范围的音频数据
"""
if use_offset:
start_ms += self._offset
end_ms += self._offset
return self.get_chunk(start_ms, end_ms)
def __len__(self) -> int:
"""
获取当前音频数据块大小
"""
return self._chunk_size

View File

@ -1,103 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
from typing import List, Optional
class ModelLoader:
def __init__(self):
pass
def __call__(self, args):
return self.load_models(args)
def load_models(self, args):
"""
加载所有需要的模型
参数:
args: 命令行参数包含模型配置
返回:
dict: 包含所有加载的模型的字典
"""
def load_models(args):
"""
加载所有需要的模型
参数:
args: 命令行参数包含模型配置
返回:
dict: 包含所有加载的模型的字典
"""
try:
# 导入FunASR库
from funasr import AutoModel
except ImportError:
raise ImportError("未找到funasr库请先安装: pip install funasr")
# 初始化模型字典
models = {}
# 1. 加载离线ASR模型
print(f"正在加载ASR离线模型: {args.asr_model}")
models["asr"] = AutoModel(
model=args.asr_model,
model_revision=args.asr_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
disable_update=True,
)
# 2. 加载在线ASR模型
print(f"正在加载ASR在线模型: {args.asr_model_online}")
models["asr_streaming"] = AutoModel(
model=args.asr_model_online,
model_revision=args.asr_model_online_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
disable_update=True,
)
# 3. 加载VAD模型
print(f"正在加载VAD模型: {args.vad_model}")
models["vad"] = AutoModel(
model=args.vad_model,
model_revision=args.vad_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
disable_update=True,
)
# 4. 加载标点符号模型(如果指定)
if args.punc_model:
print(f"正在加载标点符号模型: {args.punc_model}")
models["punc"] = AutoModel(
model=args.punc_model,
model_revision=args.punc_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
disable_update=True,
)
else:
models["punc"] = None
print("未指定标点符号模型,将不使用标点符号")
print("所有模型加载完成")
return models

124
src/model_loader.py Normal file
View File

@ -0,0 +1,124 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
try:
# 导入FunASR库
from funasr import AutoModel
except ImportError as exc:
raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc
# 日志模块
from src.utils import get_module_logger
logger = get_module_logger(__name__)
# 单例模式
class ModelLoader:
"""
ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中
一般的, 可以直接call ModelLoader()来获取加载的模型
也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型
"""
_instance = None
def __new__(cls, *args, **kwargs):
"""
单例模式
"""
if cls._instance is None:
cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self, args=None):
"""
初始化ModelLoader实例
"""
self.models = {}
logger.info("初始化ModelLoader")
if args is not None:
self.__call__(args)
def __call__(self, args=None):
"""
调用ModelLoader实例时, 如果模型字典为空, 则加载模型
"""
# 如果模型字典为空, 则加载模型
if self.models == {} or self.models is None:
if args.asr_model is not None:
self.models = self.load_models(args)
# 直接调用等于调用self.models
return self.models
def _load_model(self, args, model_type):
"""
加载单个模型
参数:
args: 命令行参数, 包含模型配置
model_type: 模型类型, 用于确定使用哪个模型参数
返回:
AutoModel: 加载的模型实例
"""
# 默认配置
default_config = {
"model": None,
"model_revision": None,
"ngpu": 0,
"ncpu": 1,
"device": "cpu",
"disable_pbar": True,
"disable_log": True,
"disable_update": True,
}
# 从args中获取配置, 如果存在则覆盖默认值
model_args = default_config.copy()
for key, value in default_config.items():
if key in ["model", "model_revision"]:
# 特殊处理model和model_revision, 因为它们需要model_type前缀
if key == "model":
value = getattr(args, f"{model_type}_model", None)
else:
value = getattr(args, f"{model_type}_model_revision", None)
else:
value = getattr(args, key, None)
if value is not None:
model_args[key] = value
# 验证必要参数
if not model_args["model"]:
raise ValueError(f"未指定{model_type}模型路径")
try:
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
logger.info("正在加载%s模型: %s", model_type, model_args["model"])
model = AutoModel(**model_args)
return model
except Exception as e:
logger.error("加载%s模型失败: %s", model_type, str(e))
raise
def load_models(self, args):
"""
加载所有需要的模型
参数:
args: 命令行参数, 包含模型配置
返回:
dict: 包含所有加载的模型的字典
"""
logger.info("ModelLoader加载模型")
# 初始化模型字典
self.models = {}
# 加载离线ASR模型
self.models["asr"] = self._load_model(args, "asr")
# 2. 加载在线ASR模型
self.models["asr_streaming"] = self._load_model(args, "asr_online")
# 3. 加载VAD模型
self.models["vad"] = self._load_model(args, "vad")
# 4. 加载标点符号模型(如果指定)
self.models["punc"] = self._load_model(args, "punc")
logger.info("所有模型加载完成")
return self.models

View File

@ -1,3 +1,3 @@
from .audiobinary import AudioBinary_Config, AudioBinary_Chunk
from .audio import AudioBinary_Config, AudioBinary_data_list, AudioBinary_Slice
from .vad import VADResponse
__all__ = ["AudioBinary_Config", "AudioBinary_Chunk", "VADResponse"]
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "AudioBinary_Slice", "VADResponse"]

36
src/models/audio.py Normal file
View File

@ -0,0 +1,36 @@
from pydantic import BaseModel, Field
from src.audiochunk import AudioBinary
class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息"""
audio_data: bytes = Field(description="音频数据", default=None)
chunk_size: int = Field(description="块大小", default=100)
chunk_stride: int = Field(description="块步长", default=1600)
sample_rate: int = Field(description="采样率", default=16000)
sample_width: int = Field(description="采样位宽", default=2)
channels: int = Field(description="通道数", default=1)
# 从Dict中加载
@classmethod
def AudioBinary_Config_from_dict(cls, data: dict):
return cls(**data)
class AudioBinary_Slice(BaseModel):
"""音频块切片"""
target_Binary: AudioBinary = Field(description="目标音频块", default=None)
start_index: int = Field(description="开始位置", default=0)
end_index: int = Field(description="结束位置", default=0)
def __call__(self):
return self.target_Binary(self.start_index, self.end_index)
class _AudioBinary_data(BaseModel):
"""音频二进制数据"""
binary_data: bytes = Field(description="音频二进制数据", default=None)
class AudioBinary_data_list(BaseModel):
"""音频二进制数据列表"""
binary_data_list: List[_AudioBinary_data] = Field(description="音频二进制数据列表", default=[])
def __call__(self):
return self.binary_data_list

View File

@ -1,16 +0,0 @@
from pydantic import BaseModel, Field
class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息"""
audio_data: bytes = Field(description="音频数据", default=None)
chunk_size: int = Field(description="块大小", default=100)
chunk_stride: int = Field(description="块步长", default=1600)
sample_rate: int = Field(description="采样率", default=16000)
sample_width: int = Field(description="采样位宽", default=2)
channels: int = Field(description="通道数", default=1)
class AudioBinary_Chunk(BaseModel):
"""音频块"""
start_time: int = Field(description="开始时间", default=0)
end_time: int = Field(description="结束时间", default=0)
chunk: bytes = Field(description="音频块", default=None)

0
src/runner.py Normal file
View File

3
src/utils/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .logger import get_module_logger, setup_root_logger
__all__ = ["get_module_logger", "setup_root_logger"]

View File

@ -39,8 +39,10 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
from src.config import parse_args
args = parse_args()
from src.functor.model_loader import load_models
models = load_models(args)
# from src.functor.model_loader import load_models
# models = load_models(args)
from src.model_loader import ModelLoader
models = ModelLoader(args)
chunk_size = 200 # ms
from src.models import AudioBinary_Config