125 lines
4.1 KiB
Python
125 lines
4.1 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
模型加载模块 - 负责加载各种语音识别相关模型
|
|
"""
|
|
try:
|
|
# 导入FunASR库
|
|
from funasr import AutoModel
|
|
except ImportError as exc:
|
|
raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc
|
|
|
|
# 日志模块
|
|
from src.utils import get_module_logger
|
|
|
|
logger = get_module_logger(__name__)
|
|
|
|
|
|
# 单例模式
|
|
class ModelLoader:
|
|
"""
|
|
ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中。
|
|
一般的, 可以直接call ModelLoader()来获取加载的模型。
|
|
也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型。
|
|
"""
|
|
|
|
_instance = None
|
|
|
|
def __new__(cls, *args, **kwargs):
|
|
"""
|
|
单例模式
|
|
"""
|
|
if cls._instance is None:
|
|
cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs)
|
|
return cls._instance
|
|
|
|
def __init__(self, args=None):
|
|
"""
|
|
初始化ModelLoader实例
|
|
"""
|
|
self.models = {}
|
|
logger.info("初始化ModelLoader")
|
|
if args is not None:
|
|
self.__call__(args)
|
|
|
|
def __call__(self, args=None):
|
|
"""
|
|
调用ModelLoader实例时, 如果模型字典为空, 则加载模型
|
|
"""
|
|
# 如果模型字典为空, 则加载模型
|
|
if self.models == {} or self.models is None:
|
|
if args.asr_model is not None:
|
|
self.models = self.load_models(args)
|
|
# 直接调用等于调用self.models
|
|
return self.models
|
|
|
|
def _load_model(self, args, model_type):
|
|
"""
|
|
加载单个模型
|
|
|
|
参数:
|
|
args: 命令行参数, 包含模型配置
|
|
model_type: 模型类型, 用于确定使用哪个模型参数
|
|
|
|
返回:
|
|
AutoModel: 加载的模型实例
|
|
"""
|
|
# 默认配置
|
|
default_config = {
|
|
"model": None,
|
|
"model_revision": None,
|
|
"ngpu": 0,
|
|
"ncpu": 1,
|
|
"device": "cpu",
|
|
"disable_pbar": True,
|
|
"disable_log": True,
|
|
"disable_update": True,
|
|
}
|
|
# 从args中获取配置, 如果存在则覆盖默认值
|
|
model_args = default_config.copy()
|
|
for key, value in default_config.items():
|
|
if key in ["model", "model_revision"]:
|
|
# 特殊处理model和model_revision, 因为它们需要model_type前缀
|
|
if key == "model":
|
|
value = getattr(args, f"{model_type}_model", None)
|
|
else:
|
|
value = getattr(args, f"{model_type}_model_revision", None)
|
|
else:
|
|
value = getattr(args, key, None)
|
|
if value is not None:
|
|
model_args[key] = value
|
|
# 验证必要参数
|
|
if not model_args["model"]:
|
|
raise ValueError(f"未指定{model_type}模型路径")
|
|
try:
|
|
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
|
|
logger.info("正在加载%s模型: %s", model_type, model_args["model"])
|
|
model = AutoModel(**model_args)
|
|
return model
|
|
except Exception as e:
|
|
logger.error("加载%s模型失败: %s", model_type, str(e))
|
|
raise
|
|
|
|
def load_models(self, args):
|
|
"""
|
|
加载所有需要的模型
|
|
参数:
|
|
args: 命令行参数, 包含模型配置
|
|
|
|
返回:
|
|
dict: 包含所有加载的模型的字典
|
|
"""
|
|
logger.info("ModelLoader加载模型")
|
|
# 初始化模型字典
|
|
self.models = {}
|
|
# 加载离线ASR模型
|
|
self.models["asr"] = self._load_model(args, "asr")
|
|
# 2. 加载在线ASR模型
|
|
self.models["asr_streaming"] = self._load_model(args, "asr_online")
|
|
# 3. 加载VAD模型
|
|
self.models["vad"] = self._load_model(args, "vad")
|
|
# 4. 加载标点符号模型(如果指定)
|
|
self.models["punc"] = self._load_model(args, "punc")
|
|
logger.info("所有模型加载完成")
|
|
return self.models
|