STT_Server/src/model_loader.py

125 lines
4.1 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
try:
# 导入FunASR库
from funasr import AutoModel
except ImportError as exc:
raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc
# 日志模块
from src.utils import get_module_logger
logger = get_module_logger(__name__)
# 单例模式
class ModelLoader:
"""
ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中。
一般的, 可以直接call ModelLoader()来获取加载的模型。
也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型。
"""
_instance = None
def __new__(cls, *args, **kwargs):
"""
单例模式
"""
if cls._instance is None:
cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self, args=None):
"""
初始化ModelLoader实例
"""
self.models = {}
logger.info("初始化ModelLoader")
if args is not None:
self.__call__(args)
def __call__(self, args=None):
"""
调用ModelLoader实例时, 如果模型字典为空, 则加载模型
"""
# 如果模型字典为空, 则加载模型
if self.models == {} or self.models is None:
if args.asr_model is not None:
self.models = self.load_models(args)
# 直接调用等于调用self.models
return self.models
def _load_model(self, args, model_type):
"""
加载单个模型
参数:
args: 命令行参数, 包含模型配置
model_type: 模型类型, 用于确定使用哪个模型参数
返回:
AutoModel: 加载的模型实例
"""
# 默认配置
default_config = {
"model": None,
"model_revision": None,
"ngpu": 0,
"ncpu": 1,
"device": "cpu",
"disable_pbar": True,
"disable_log": True,
"disable_update": True,
}
# 从args中获取配置, 如果存在则覆盖默认值
model_args = default_config.copy()
for key, value in default_config.items():
if key in ["model", "model_revision"]:
# 特殊处理model和model_revision, 因为它们需要model_type前缀
if key == "model":
value = getattr(args, f"{model_type}_model", None)
else:
value = getattr(args, f"{model_type}_model_revision", None)
else:
value = getattr(args, key, None)
if value is not None:
model_args[key] = value
# 验证必要参数
if not model_args["model"]:
raise ValueError(f"未指定{model_type}模型路径")
try:
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
logger.info("正在加载%s模型: %s", model_type, model_args["model"])
model = AutoModel(**model_args)
return model
except Exception as e:
logger.error("加载%s模型失败: %s", model_type, str(e))
raise
def load_models(self, args):
"""
加载所有需要的模型
参数:
args: 命令行参数, 包含模型配置
返回:
dict: 包含所有加载的模型的字典
"""
logger.info("ModelLoader加载模型")
# 初始化模型字典
self.models = {}
# 加载离线ASR模型
self.models["asr"] = self._load_model(args, "asr")
# 2. 加载在线ASR模型
self.models["asr_streaming"] = self._load_model(args, "asr_online")
# 3. 加载VAD模型
self.models["vad"] = self._load_model(args, "vad")
# 4. 加载标点符号模型(如果指定)
self.models["punc"] = self._load_model(args, "punc")
logger.info("所有模型加载完成")
return self.models