103 lines
2.7 KiB
Python
103 lines
2.7 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
模型加载模块 - 负责加载各种语音识别相关模型
|
||
"""
|
||
|
||
from typing import List, Optional
|
||
|
||
class ModelLoader:
|
||
def __init__(self):
|
||
pass
|
||
|
||
def __call__(self, args):
|
||
return self.load_models(args)
|
||
|
||
def load_models(self, args):
|
||
"""
|
||
加载所有需要的模型
|
||
|
||
参数:
|
||
args: 命令行参数,包含模型配置
|
||
|
||
返回:
|
||
dict: 包含所有加载的模型的字典
|
||
"""
|
||
|
||
def load_models(args):
|
||
"""
|
||
加载所有需要的模型
|
||
|
||
参数:
|
||
args: 命令行参数,包含模型配置
|
||
|
||
返回:
|
||
dict: 包含所有加载的模型的字典
|
||
"""
|
||
try:
|
||
# 导入FunASR库
|
||
from funasr import AutoModel
|
||
except ImportError:
|
||
raise ImportError("未找到funasr库,请先安装: pip install funasr")
|
||
|
||
# 初始化模型字典
|
||
models = {}
|
||
|
||
# 1. 加载离线ASR模型
|
||
print(f"正在加载ASR离线模型: {args.asr_model}")
|
||
models["asr"] = AutoModel(
|
||
model=args.asr_model,
|
||
model_revision=args.asr_model_revision,
|
||
ngpu=args.ngpu,
|
||
ncpu=args.ncpu,
|
||
device=args.device,
|
||
disable_pbar=True,
|
||
disable_log=True,
|
||
disable_update=True,
|
||
)
|
||
|
||
# 2. 加载在线ASR模型
|
||
print(f"正在加载ASR在线模型: {args.asr_model_online}")
|
||
models["asr_streaming"] = AutoModel(
|
||
model=args.asr_model_online,
|
||
model_revision=args.asr_model_online_revision,
|
||
ngpu=args.ngpu,
|
||
ncpu=args.ncpu,
|
||
device=args.device,
|
||
disable_pbar=True,
|
||
disable_log=True,
|
||
disable_update=True,
|
||
)
|
||
|
||
# 3. 加载VAD模型
|
||
print(f"正在加载VAD模型: {args.vad_model}")
|
||
models["vad"] = AutoModel(
|
||
model=args.vad_model,
|
||
model_revision=args.vad_model_revision,
|
||
ngpu=args.ngpu,
|
||
ncpu=args.ncpu,
|
||
device=args.device,
|
||
disable_pbar=True,
|
||
disable_log=True,
|
||
disable_update=True,
|
||
)
|
||
|
||
# 4. 加载标点符号模型(如果指定)
|
||
if args.punc_model:
|
||
print(f"正在加载标点符号模型: {args.punc_model}")
|
||
models["punc"] = AutoModel(
|
||
model=args.punc_model,
|
||
model_revision=args.punc_model_revision,
|
||
ngpu=args.ngpu,
|
||
ncpu=args.ncpu,
|
||
device=args.device,
|
||
disable_pbar=True,
|
||
disable_log=True,
|
||
disable_update=True,
|
||
)
|
||
else:
|
||
models["punc"] = None
|
||
print("未指定标点符号模型,将不使用标点符号")
|
||
|
||
print("所有模型加载完成")
|
||
return models |