From 703a40e9554ba471626f1f32a1cad57944d8a847 Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Wed, 28 May 2025 10:35:35 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BB=A3=E7=A0=81=E9=87=8D=E6=9E=84=E4=B8=AD]?= =?UTF-8?q?=E9=87=8D=E6=9E=84model=5Floader=E4=B8=8Eaudio=5Fchunk=EF=BC=8C?= =?UTF-8?q?=E5=85=A8=E5=B1=80=E5=8D=95=E4=BE=8B=E6=A8=A1=E5=BC=8F=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E6=A8=A1=E5=9E=8B=E5=8A=A0=E8=BD=BD=E4=B8=8Eaudiobina?= =?UTF-8?q?ry=E6=95=B0=E6=8D=AE=E5=AD=98=E5=82=A8=E5=8D=95=E5=85=83?= =?UTF-8?q?=E7=B1=BB=E3=80=82=E5=88=A0=E9=99=A4readme=E4=B8=AD=E4=B8=8D?= =?UTF-8?q?=E9=9C=80=E8=A6=81=E7=9A=84MIT=E8=AE=B8=E5=8F=AF=E8=AF=81?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 3 - src/audio_chunk.py | 194 ++++++++++++++++++++++++++++++++++++ src/functor/audiochunk.py | 178 --------------------------------- src/functor/model_loader.py | 103 ------------------- src/model_loader.py | 124 +++++++++++++++++++++++ src/models/__init__.py | 4 +- src/models/audio.py | 36 +++++++ src/models/audiobinary.py | 16 --- src/runner.py | 0 src/utils/__init__.py | 3 + tests/modelsuse.py | 6 +- 11 files changed, 363 insertions(+), 304 deletions(-) create mode 100644 src/audio_chunk.py delete mode 100644 src/functor/audiochunk.py delete mode 100644 src/functor/model_loader.py create mode 100644 src/model_loader.py create mode 100644 src/models/audio.py delete mode 100644 src/models/audiobinary.py create mode 100644 src/runner.py create mode 100644 src/utils/__init__.py diff --git a/README.md b/README.md index bb3eb83..79de1d2 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,3 @@ docker-compose up -d "is_final": false // 是否是最终结果 } ``` - -## 许可证 -[MIT](LICENSE) \ No newline at end of file diff --git a/src/audio_chunk.py b/src/audio_chunk.py new file mode 100644 index 0000000..c96bef8 --- /dev/null +++ b/src/audio_chunk.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +音频数据块管理类 - 用于存储和处理16KHz音频数据 +""" + +from typing import List, Optional, Dict +from src.models import AudioBinary_Config, AudioBinary_data_list + +# 配置日志 +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__, level="INFO") + + +class AudioBinary: + """ + 音频数据存储单元 + 用于存储二进制数据 + 面向Slice, 向Slice提供数据与接口 + + self._audio_config: AudioBinary_Config -- 音频参数配置 + self._binary_data_list: AudioBinary_data_list -- 音频数据列表 + self._slice_listener: List[callable] -- 切片监听器 + + AudioBinary_Config: Dict -- 音频参数配置 + AudioBinary_data_list: List[bytes] -- 音频数据列表 + """ + + def __init__(self, *args): + """ + 初始化音频数据块 + 参数: + *args: 可变参数 + """ + # 音频参数配置 + self._audio_config = AudioBinary_Config() + # 音频片段 + self._binary_data_list: AudioBinary_data_list = AudioBinary_data_list() + # 切片监听器 + self._slice_listener: List = [] + if isinstance(args, Dict): + self._audio_config = AudioBinary_Config.AudioBinary_Config_from_dict(args) + elif isinstance(args, AudioBinary_Config): + self._audio_config = args + else: + raise ValueError("参数类型错误") + + def add_slice_listener(self, slice_listener: callable): + """ + 添加切片监听器 + 参数: + slice_listener: callable -- 切片监听器 + """ + self._slice_listener.append(slice_listener) + + def __add__(self, other: bytes): + """ + __add__ 是 "+" 运算符的重载, + 使用方法: + audio_binary = audio_binary + bytes + 添加音频数据块 与 add_binary_data 等效, + 但可以链式调用, 方便使用 + 参数: + other: bytes --音频数据块 + """ + self._binary_data_list.append(other) + return self + + def __iadd__(self, other: bytes): + """ + __iadd__ 是 "+=" 运算符的重载, + 使用方法: + audio_binary += bytes + 添加音频数据块 与 add_binary_data 等效, + 但可以链式调用, 方便使用 + 参数: + other: bytes --音频数据块 + """ + self._binary_data_list.append(other) + return self + + def add_binary_data(self, binary_data: bytes): + """ + 添加音频数据块 + 参数: + binary_data: bytes --音频数据块 + """ + self._binary_data_list.append(binary_data) + + def rewrite_binary_data(self, target_index: int, binary_data: bytes): + """ + 重写音频数据块 + 参数: + target_index: int -- 目标索引 + binary_data: bytes --音频数据块 + """ + self._binary_data_list.rewrite(target_index, binary_data) + + def get_binary_data( + self, + start: int = 0, + end: Optional[int] = None, + ) -> Optional[bytes]: + """ + 获取指定索引的音频数据块 + 参数: + start: 开始索引 + end: 结束索引 + 返回: + List[bytes]: 音频数据块 + """ + if start >= len(self._binary_data_list): + return None + if end is None: + end = start + 1 + end = min(end, len(self._binary_data_list)) + return self._binary_data_list[start:end] + + +class AudioChunk: + """ + 音频数据块管理类 + 管理两部分内容, AudioBinary和Slice。 + AudioBinary用于内部存储字节数据。 + Slice是AudioBinary的切片,用于外部接口。 + + 此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。 + """ + + _instance = None + + def __new__(cls, *args, **kwargs): + """ + 单例模式 + """ + if cls._instance is None: + cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + """ + 初始化AudioChunk实例 + """ + self._audio_binary_list: Dict[str, AudioBinary] = {} + self._slice_listener: List[callable] = [] + + def get_audio_binary( + self, + binary_name: Optional[str] = None, + audio_config: Optional[AudioBinary_Config] = None, + ) -> AudioBinary: + """ + 获取音频数据块 + 参数: + binary_name: str -- 音频数据块名称 + 返回: + AudioBinary: 音频数据块 + """ + if binary_name is None: + binary_name = "default" + if binary_name not in self._audio_binary_list: + self._audio_binary_list[binary_name] = AudioBinary(audio_config) + return self._audio_binary_list[binary_name] + + @staticmethod + def _time2size(time_ms: int, audio_config: AudioBinary_Config) -> int: + """ + 将时间(ms)转换为数据大小(字节) + 参数: + time_ms: int -- 时间(ms) + audio_config: AudioBinary_Config -- 音频参数配置 + 返回: + int: 数据大小(字节) + """ + # 时间(ms)到字节(bytes)计算方法为: 时间(ms) * 采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24) / 1000 + time_s = time_ms / 1000 + bytes_per_sample = audio_config.sample_width * audio_config.channel + return int(time_s * audio_config.sample_rate * bytes_per_sample) + + @staticmethod + def _size2time(size: int, audio_config: AudioBinary_Config) -> int: + """ + 将数据大小(字节)转换为时间(ms) + 参数: + size: int -- 数据大小(字节) + audio_config: AudioBinary_Config -- 音频参数配置 + 返回: + int: 时间(ms) + """ + # 字节(bytes)到时间(ms)计算方法为: 字节(bytes) * 1000 / (采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24)) + bytes_per_sample = audio_config.sample_width * audio_config.channel + time_ms = size * 1000 // (audio_config.sample_rate * bytes_per_sample) + return time_ms diff --git a/src/functor/audiochunk.py b/src/functor/audiochunk.py deleted file mode 100644 index ad1fbae..0000000 --- a/src/functor/audiochunk.py +++ /dev/null @@ -1,178 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -音频数据块管理类 - 用于存储和处理16KHz音频数据 -""" - -import numpy as np -from src.utils.logger import get_module_logger -from typing import List, Optional, Union -from src.models import AudioBinary_Config - -# 配置日志 -logger = get_module_logger(__name__, level="INFO") - -class AudioChunk: - """音频数据块管理类,用于存储和处理16KHz音频数据""" - - def __init__(self, - max_duration_ms: int = 1000*60*60*10, - audio_config : AudioBinary_Config = None, - ): - """ - 初始化音频数据块管理器 - - 参数: - max_duration_ms: 音频池最大留存时间(ms),默认10小时 - audio_config: 音频配置信息 - """ - # 音频参数 - self.sample_rate = audio_config.sample_rate if audio_config is not None else 16000 # 采样率:16KHz - self.sample_width = audio_config.sample_width if audio_config is not None else 2 # 采样位宽:16bit - self.channels = audio_config.channels if audio_config is not None else 1 # 通道数:单声道 - - # 数据存储 - self._max_duration_ms = max_duration_ms - self._max_chunk_size = self._time2size(max_duration_ms) # 最大数据大小 - self._chunk = [] # 当前音频数据块列表 - self._chunk_size = 0 # 当前数据总大小 - self._offset = 0 # 当前偏移量 - - logger.info(f"初始化AudioChunk: 最大时长={max_duration_ms}ms, 最大数据大小={self._max_chunk_size}字节") - - def add_chunk(self, chunk: Union[bytes, np.ndarray]) -> bool: - """ - 添加音频数据块 - - 参数: - chunk: 音频数据块,可以是bytes或numpy数组 - - 返回: - bool: 是否添加成功 - """ - try: - # 检查数据格式 - if isinstance(chunk, np.ndarray): - # 确保是16bit整数格式 - if chunk.dtype != np.int16: - chunk = chunk.astype(np.int16) - # 转换为bytes - chunk = chunk.tobytes() - - # 检查数据大小 - if len(chunk) % (self.sample_width * self.channels) != 0: - logger.warning(f"音频数据大小不是{self.sample_width * self.channels}的倍数: {len(chunk)}") - return False - - # 检查是否超过最大限制 - if self._chunk_size + len(chunk) > self._max_chunk_size: - logger.warning("音频数据超过最大限制,将自动清除旧数据") - self.clear_chunk() - - # 添加数据 - self._chunk.append(chunk) - self._chunk_size += len(chunk) - return True - - except Exception as e: - logger.error(f"添加音频数据块时出错: {e}") - return False - - def get_chunk_binary(self, start: int = 0, end: Optional[int] = None) -> Optional[bytes]: - """ - 获取指定索引的音频数据块 - """ - print("[AudioChunk] get_chunk_binary", start, end) - if start >= len(self._chunk): - return None - if end is None or end > len(self._chunk): - end = len(self._chunk) - data = b''.join(self._chunk) - return data[start:end] - - def get_chunk(self, start_ms: int = 0, end_ms: Optional[int] = None) -> Optional[bytes]: - """ - 获取指定时间范围的音频数据 - - 参数: - start_ms: 开始时间(ms) - end_ms: 结束时间(ms),None表示到末尾 - - 返回: - Optional[bytes]: 音频数据,如果获取失败则返回None - """ - try: - if not self._chunk: - return None - - # 计算字节偏移 - start_byte = self._time2size(start_ms) - end_byte = self._time2size(end_ms) if end_ms is not None else self._chunk_size - - # 检查范围是否有效 - if start_byte >= self._chunk_size or start_byte >= end_byte: - return None - - # 获取数据 - data = b''.join(self._chunk) - return data[start_byte:end_byte] - - except Exception as e: - logger.error(f"获取音频数据块时出错: {e}") - return None - - def get_duration(self) -> int: - """ - 获取当前音频总时长(ms) - - 返回: - int: 音频时长(ms) - """ - return self._size2time(self._chunk_size) - - def clear_chunk(self) -> None: - """清除所有音频数据""" - self._chunk = [] - self._chunk_size = 0 - self._offset = 0 - logger.info("已清除所有音频数据") - - def _time2size(self, time_ms: int) -> int: - """ - 将时间(ms)转换为数据大小(字节) - - 参数: - time_ms: 时间(ms) - - 返回: - int: 数据大小(字节) - """ - return int(time_ms * self.sample_rate * self.sample_width * self.channels / 1000) - - def _size2time(self, size: int) -> int: - """ - 将数据大小(字节)转换为时间(ms) - - 参数: - size: 数据大小(字节) - - 返回: - int: 时间(ms) - """ - return int(size * 1000 / (self.sample_rate * self.sample_width * self.channels)) - - # instance(start_ms, end_ms, use_offset=True) - def __call__(self, start_ms: int = 0, end_ms: Optional[int] = None, use_offset: bool = True) -> Optional[bytes]: - """ - 获取指定时间范围的音频数据 - """ - if use_offset: - start_ms += self._offset - end_ms += self._offset - return self.get_chunk(start_ms, end_ms) - - def __len__(self) -> int: - """ - 获取当前音频数据块大小 - """ - return self._chunk_size \ No newline at end of file diff --git a/src/functor/model_loader.py b/src/functor/model_loader.py deleted file mode 100644 index f1a18f7..0000000 --- a/src/functor/model_loader.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -模型加载模块 - 负责加载各种语音识别相关模型 -""" - -from typing import List, Optional - -class ModelLoader: - def __init__(self): - pass - - def __call__(self, args): - return self.load_models(args) - - def load_models(self, args): - """ - 加载所有需要的模型 - - 参数: - args: 命令行参数,包含模型配置 - - 返回: - dict: 包含所有加载的模型的字典 - """ - -def load_models(args): - """ - 加载所有需要的模型 - - 参数: - args: 命令行参数,包含模型配置 - - 返回: - dict: 包含所有加载的模型的字典 - """ - try: - # 导入FunASR库 - from funasr import AutoModel - except ImportError: - raise ImportError("未找到funasr库,请先安装: pip install funasr") - - # 初始化模型字典 - models = {} - - # 1. 加载离线ASR模型 - print(f"正在加载ASR离线模型: {args.asr_model}") - models["asr"] = AutoModel( - model=args.asr_model, - model_revision=args.asr_model_revision, - ngpu=args.ngpu, - ncpu=args.ncpu, - device=args.device, - disable_pbar=True, - disable_log=True, - disable_update=True, - ) - - # 2. 加载在线ASR模型 - print(f"正在加载ASR在线模型: {args.asr_model_online}") - models["asr_streaming"] = AutoModel( - model=args.asr_model_online, - model_revision=args.asr_model_online_revision, - ngpu=args.ngpu, - ncpu=args.ncpu, - device=args.device, - disable_pbar=True, - disable_log=True, - disable_update=True, - ) - - # 3. 加载VAD模型 - print(f"正在加载VAD模型: {args.vad_model}") - models["vad"] = AutoModel( - model=args.vad_model, - model_revision=args.vad_model_revision, - ngpu=args.ngpu, - ncpu=args.ncpu, - device=args.device, - disable_pbar=True, - disable_log=True, - disable_update=True, - ) - - # 4. 加载标点符号模型(如果指定) - if args.punc_model: - print(f"正在加载标点符号模型: {args.punc_model}") - models["punc"] = AutoModel( - model=args.punc_model, - model_revision=args.punc_model_revision, - ngpu=args.ngpu, - ncpu=args.ncpu, - device=args.device, - disable_pbar=True, - disable_log=True, - disable_update=True, - ) - else: - models["punc"] = None - print("未指定标点符号模型,将不使用标点符号") - - print("所有模型加载完成") - return models \ No newline at end of file diff --git a/src/model_loader.py b/src/model_loader.py new file mode 100644 index 0000000..aafb93c --- /dev/null +++ b/src/model_loader.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +模型加载模块 - 负责加载各种语音识别相关模型 +""" +try: + # 导入FunASR库 + from funasr import AutoModel +except ImportError as exc: + raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc + +# 日志模块 +from src.utils import get_module_logger + +logger = get_module_logger(__name__) + + +# 单例模式 +class ModelLoader: + """ + ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中。 + 一般的, 可以直接call ModelLoader()来获取加载的模型。 + 也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型。 + """ + + _instance = None + + def __new__(cls, *args, **kwargs): + """ + 单例模式 + """ + if cls._instance is None: + cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self, args=None): + """ + 初始化ModelLoader实例 + """ + self.models = {} + logger.info("初始化ModelLoader") + if args is not None: + self.__call__(args) + + def __call__(self, args=None): + """ + 调用ModelLoader实例时, 如果模型字典为空, 则加载模型 + """ + # 如果模型字典为空, 则加载模型 + if self.models == {} or self.models is None: + if args.asr_model is not None: + self.models = self.load_models(args) + # 直接调用等于调用self.models + return self.models + + def _load_model(self, args, model_type): + """ + 加载单个模型 + + 参数: + args: 命令行参数, 包含模型配置 + model_type: 模型类型, 用于确定使用哪个模型参数 + + 返回: + AutoModel: 加载的模型实例 + """ + # 默认配置 + default_config = { + "model": None, + "model_revision": None, + "ngpu": 0, + "ncpu": 1, + "device": "cpu", + "disable_pbar": True, + "disable_log": True, + "disable_update": True, + } + # 从args中获取配置, 如果存在则覆盖默认值 + model_args = default_config.copy() + for key, value in default_config.items(): + if key in ["model", "model_revision"]: + # 特殊处理model和model_revision, 因为它们需要model_type前缀 + if key == "model": + value = getattr(args, f"{model_type}_model", None) + else: + value = getattr(args, f"{model_type}_model_revision", None) + else: + value = getattr(args, key, None) + if value is not None: + model_args[key] = value + # 验证必要参数 + if not model_args["model"]: + raise ValueError(f"未指定{model_type}模型路径") + try: + # 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销 + logger.info("正在加载%s模型: %s", model_type, model_args["model"]) + model = AutoModel(**model_args) + return model + except Exception as e: + logger.error("加载%s模型失败: %s", model_type, str(e)) + raise + + def load_models(self, args): + """ + 加载所有需要的模型 + 参数: + args: 命令行参数, 包含模型配置 + + 返回: + dict: 包含所有加载的模型的字典 + """ + logger.info("ModelLoader加载模型") + # 初始化模型字典 + self.models = {} + # 加载离线ASR模型 + self.models["asr"] = self._load_model(args, "asr") + # 2. 加载在线ASR模型 + self.models["asr_streaming"] = self._load_model(args, "asr_online") + # 3. 加载VAD模型 + self.models["vad"] = self._load_model(args, "vad") + # 4. 加载标点符号模型(如果指定) + self.models["punc"] = self._load_model(args, "punc") + logger.info("所有模型加载完成") + return self.models diff --git a/src/models/__init__.py b/src/models/__init__.py index d64a40a..fdbc7ec 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,3 +1,3 @@ -from .audiobinary import AudioBinary_Config, AudioBinary_Chunk +from .audio import AudioBinary_Config, AudioBinary_data_list, AudioBinary_Slice from .vad import VADResponse -__all__ = ["AudioBinary_Config", "AudioBinary_Chunk", "VADResponse"] \ No newline at end of file +__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "AudioBinary_Slice", "VADResponse"] \ No newline at end of file diff --git a/src/models/audio.py b/src/models/audio.py new file mode 100644 index 0000000..ac63a06 --- /dev/null +++ b/src/models/audio.py @@ -0,0 +1,36 @@ +from pydantic import BaseModel, Field +from src.audiochunk import AudioBinary + +class AudioBinary_Config(BaseModel): + """二进制音频块配置信息""" + audio_data: bytes = Field(description="音频数据", default=None) + chunk_size: int = Field(description="块大小", default=100) + chunk_stride: int = Field(description="块步长", default=1600) + sample_rate: int = Field(description="采样率", default=16000) + sample_width: int = Field(description="采样位宽", default=2) + channels: int = Field(description="通道数", default=1) + + # 从Dict中加载 + @classmethod + def AudioBinary_Config_from_dict(cls, data: dict): + return cls(**data) + +class AudioBinary_Slice(BaseModel): + """音频块切片""" + target_Binary: AudioBinary = Field(description="目标音频块", default=None) + start_index: int = Field(description="开始位置", default=0) + end_index: int = Field(description="结束位置", default=0) + + def __call__(self): + return self.target_Binary(self.start_index, self.end_index) + +class _AudioBinary_data(BaseModel): + """音频二进制数据""" + binary_data: bytes = Field(description="音频二进制数据", default=None) + +class AudioBinary_data_list(BaseModel): + """音频二进制数据列表""" + binary_data_list: List[_AudioBinary_data] = Field(description="音频二进制数据列表", default=[]) + + def __call__(self): + return self.binary_data_list diff --git a/src/models/audiobinary.py b/src/models/audiobinary.py deleted file mode 100644 index 3ad0d26..0000000 --- a/src/models/audiobinary.py +++ /dev/null @@ -1,16 +0,0 @@ -from pydantic import BaseModel, Field - -class AudioBinary_Config(BaseModel): - """二进制音频块配置信息""" - audio_data: bytes = Field(description="音频数据", default=None) - chunk_size: int = Field(description="块大小", default=100) - chunk_stride: int = Field(description="块步长", default=1600) - sample_rate: int = Field(description="采样率", default=16000) - sample_width: int = Field(description="采样位宽", default=2) - channels: int = Field(description="通道数", default=1) - -class AudioBinary_Chunk(BaseModel): - """音频块""" - start_time: int = Field(description="开始时间", default=0) - end_time: int = Field(description="结束时间", default=0) - chunk: bytes = Field(description="音频块", default=None) diff --git a/src/runner.py b/src/runner.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..0e567b2 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,3 @@ +from .logger import get_module_logger, setup_root_logger + +__all__ = ["get_module_logger", "setup_root_logger"] \ No newline at end of file diff --git a/tests/modelsuse.py b/tests/modelsuse.py index 0554e0c..30d035d 100644 --- a/tests/modelsuse.py +++ b/tests/modelsuse.py @@ -39,8 +39,10 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]: from src.config import parse_args args = parse_args() - from src.functor.model_loader import load_models - models = load_models(args) + # from src.functor.model_loader import load_models + # models = load_models(args) + from src.model_loader import ModelLoader + models = ModelLoader(args) chunk_size = 200 # ms from src.models import AudioBinary_Config