STT_Server/src/functor/model_loader.py

103 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
from typing import List, Optional
class ModelLoader:
    """Callable object that loads the speech-recognition models.

    Calling an instance with ``args`` is the same as calling
    ``instance.load_models(args)``.
    """

    def __init__(self):
        pass

    def __call__(self, args):
        """Load all models for ``args``; shorthand for ``load_models``."""
        return self.load_models(args)

    def load_models(self, args):
        """Load every model the service needs.

        Args:
            args: Command-line arguments holding the model configuration.

        Returns:
            dict: The loaded models, as returned by the module-level
            ``load_models`` function.
        """
        # A bare name inside a method body is looked up in the module's
        # globals, so this calls the module-level ``load_models`` function,
        # not this method. Without this return the method (and ``__call__``)
        # would yield None instead of the documented dict.
        return load_models(args)
def _build_model(auto_model_cls, args, model, model_revision):
    """Instantiate one model with the runtime settings shared by all models.

    Args:
        auto_model_cls: The ``funasr.AutoModel`` class. It is passed in
            because ``funasr`` is imported lazily by the caller.
        args: Command-line arguments supplying ``ngpu``, ``ncpu`` and
            ``device``.
        model: Value forwarded as the ``model`` keyword.
        model_revision: Value forwarded as the ``model_revision`` keyword.

    Returns:
        Whatever ``auto_model_cls(...)`` returns.
    """
    return auto_model_cls(
        model=model,
        model_revision=model_revision,
        ngpu=args.ngpu,
        ncpu=args.ncpu,
        device=args.device,
        disable_pbar=True,
        disable_log=True,
        disable_update=True,
    )


def load_models(args):
    """Load every model the service needs.

    Args:
        args: Command-line arguments holding the model configuration. The
            attributes read are ``asr_model``, ``asr_model_online``,
            ``vad_model`` and ``punc_model`` (each with a matching
            ``*_revision`` attribute), plus ``ngpu``, ``ncpu`` and
            ``device``.

    Returns:
        dict: Loaded models under the keys ``"asr"``, ``"asr_streaming"``,
        ``"vad"`` and ``"punc"``. ``"punc"`` is ``None`` when
        ``args.punc_model`` is falsy.

    Raises:
        ImportError: If ``funasr`` cannot be imported. The underlying import
            failure is attached as ``__cause__``.
    """
    try:
        # Imported lazily so this module can be imported without funasr.
        from funasr import AutoModel
    except ImportError as err:
        # Chain the original error: the import can also fail because one of
        # funasr's own dependencies is missing, and that detail must not be
        # lost behind the generic message.
        raise ImportError("未找到funasr库请先安装: pip install funasr") from err

    models = {}

    # 1. Offline (non-streaming) ASR model.
    print(f"正在加载ASR离线模型: {args.asr_model}")
    models["asr"] = _build_model(
        AutoModel, args, args.asr_model, args.asr_model_revision
    )

    # 2. Online (streaming) ASR model.
    print(f"正在加载ASR在线模型: {args.asr_model_online}")
    models["asr_streaming"] = _build_model(
        AutoModel, args, args.asr_model_online, args.asr_model_online_revision
    )

    # 3. VAD model.
    print(f"正在加载VAD模型: {args.vad_model}")
    models["vad"] = _build_model(
        AutoModel, args, args.vad_model, args.vad_model_revision
    )

    # 4. Punctuation model, only when one is configured.
    if args.punc_model:
        print(f"正在加载标点符号模型: {args.punc_model}")
        models["punc"] = _build_model(
            AutoModel, args, args.punc_model, args.punc_model_revision
        )
    else:
        models["punc"] = None
        print("未指定标点符号模型,将不使用标点符号")

    print("所有模型加载完成")
    return models