69 lines
1.9 KiB
Python
69 lines
1.9 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
配置模块测试
|
|
"""
|
|
|
|
import pytest
|
|
import sys
|
|
import os
|
|
from unittest.mock import patch
|
|
|
|
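# Put the parent of this test file's directory (the directory that contains the `src` package) on sys.path so `src.config` can be imported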
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
|
from src.config import parse_args
|
|
|
|
|
|
def test_default_args():
    """Check that parse_args() yields its default values when no CLI flags are given.

    ``sys.argv`` is replaced with a bare program name, so only the parser's
    built-in defaults are in effect for the resulting namespace.
    """
    with patch.object(sys, 'argv', ['script.py']):
        parsed = parse_args()

    # Options whose default must match exactly: server endpoint, SSL files,
    # model revisions and hardware settings.
    exact_defaults = {
        'host': "0.0.0.0",
        'port': 10095,
        'certfile': "",
        'keyfile': "",
        'asr_model_revision': "v2.0.4",
        'asr_model_online_revision': "v2.0.4",
        'vad_model_revision': "v2.0.4",
        'punc_model_revision': "v2.0.4",
        'ngpu': 1,
        'device': "cuda",
        'ncpu': 4,
    }
    for option, expected in exact_defaults.items():
        assert getattr(parsed, option) == expected, option

    # Default model identifiers are only required to contain a
    # characteristic substring; their full names are not pinned here.
    model_markers = {
        'asr_model': "paraformer",
        'asr_model_online': "paraformer",
        'vad_model': "vad",
        'punc_model': "punc",
    }
    for option, marker in model_markers.items():
        assert marker in getattr(parsed, option), option
|
|
|
|
|
|
def test_custom_args():
    """Check that options supplied on the command line override the defaults.

    Each flag is passed with a non-default value. The parsed namespace must
    then hold that value, with ``--port`` and ``--ngpu`` expected as ints
    (e.g. '8080' -> 8080).
    """
    cli_options = [
        ('--host', 'localhost'),
        ('--port', '8080'),
        ('--certfile', 'cert.pem'),
        ('--keyfile', 'key.pem'),
        ('--asr_model', 'custom_model'),
        ('--ngpu', '0'),
        ('--device', 'cpu'),
    ]
    # Flatten the (flag, value) pairs after the program name, producing the
    # same argv list a shell would hand to the script.
    argv = ['script.py']
    for flag, value in cli_options:
        argv += [flag, value]

    with patch('sys.argv', argv):
        parsed = parse_args()

    # Every overridden option must reflect the value given on the CLI.
    assert parsed.host == "localhost"
    assert parsed.port == 8080
    assert parsed.certfile == "cert.pem"
    assert parsed.keyfile == "key.pem"
    assert parsed.asr_model == "custom_model"
    assert parsed.ngpu == 0
    assert parsed.device == "cpu"