#!/usr/bin/env python # -*- coding: utf-8 -*- """ 模型加载模块 - 负责加载各种语音识别相关模型 """ try: # 导入FunASR库 from funasr import AutoModel except ImportError as exc: raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc # 日志模块 from src.utils import get_module_logger logger = get_module_logger(__name__) # 单例模式 class ModelLoader: """ ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中。 一般的, 可以直接call ModelLoader()来获取加载的模型。 也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型。 """ _instance = None def __new__(cls, *args, **kwargs): """ 单例模式 """ if cls._instance is None: cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs) return cls._instance def __init__(self, args=None): """ 初始化ModelLoader实例 """ self.models = {} logger.info("初始化ModelLoader") if args is not None: self.__call__(args) def __call__(self, args=None): """ 调用ModelLoader实例时, 如果模型字典为空, 则加载模型 """ # 如果模型字典为空, 则加载模型 if self.models == {} or self.models is None: if args.asr_model is not None: self.models = self.load_models(args) # 直接调用等于调用self.models return self.models def _load_model(self, args, model_type): """ 加载单个模型 参数: args: 命令行参数, 包含模型配置 model_type: 模型类型, 用于确定使用哪个模型参数 返回: AutoModel: 加载的模型实例 """ # 默认配置 default_config = { "model": None, "model_revision": None, "ngpu": 0, "ncpu": 1, "device": "cpu", "disable_pbar": True, "disable_log": True, "disable_update": True, } # 从args中获取配置, 如果存在则覆盖默认值 model_args = default_config.copy() for key, value in default_config.items(): if key in ["model", "model_revision"]: # 特殊处理model和model_revision, 因为它们需要model_type前缀 if key == "model": value = getattr(args, f"{model_type}_model", None) else: value = getattr(args, f"{model_type}_model_revision", None) else: value = getattr(args, key, None) if value is not None: model_args[key] = value # 验证必要参数 if not model_args["model"]: raise ValueError(f"未指定{model_type}模型路径") try: # 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销 logger.info("正在加载%s模型: %s", model_type, model_args["model"]) model = AutoModel(**model_args) return model except Exception as e: logger.error("加载%s模型失败: %s", model_type, str(e)) raise def load_models(self, args): """ 加载所有需要的模型 参数: args: 命令行参数, 包含模型配置 返回: dict: 包含所有加载的模型的字典 """ logger.info("ModelLoader加载模型") # 初始化模型字典 self.models = {} # 加载离线ASR模型 self.models["asr"] = self._load_model(args, "asr") # 2. 加载在线ASR模型 self.models["asr_streaming"] = self._load_model(args, "asr_online") # 3. 加载VAD模型 self.models["vad"] = self._load_model(args, "vad") # 4. 加载标点符号模型(如果指定) self.models["punc"] = self._load_model(args, "punc") logger.info("所有模型加载完成") return self.models