#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Tests for the configuration module.

Exercises ``parse_args`` from ``src.config`` by patching ``sys.argv``.
"""

import os
import sys
from unittest.mock import patch

import pytest

# Put the parent of the directory containing this file at the front of
# sys.path, so that the ``from src.config import ...`` line below can resolve
# ``src`` from there. Note: this adds that parent directory, not the ``src``
# directory itself.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.config import parse_args  # noqa: E402  (must follow the sys.path tweak)


def test_default_args():
    """Verify the values ``parse_args`` returns when no flags are given.

    ``sys.argv`` holds only the program name, so every option has to fall
    back to its default.
    """
    with patch("sys.argv", ["script.py"]):
        args = parse_args()

    # Options pinned to an exact default: server endpoint, SSL files,
    # model revisions and hardware settings.
    exact_defaults = {
        "host": "0.0.0.0",
        "port": 10095,
        "certfile": "",
        "keyfile": "",
        "asr_model_revision": "v2.0.4",
        "asr_model_online_revision": "v2.0.4",
        "vad_model_revision": "v2.0.4",
        "punc_model_revision": "v2.0.4",
        "ngpu": 1,
        "device": "cuda",
        "ncpu": 4,
    }
    observed = {name: getattr(args, name) for name in exact_defaults}
    assert observed == exact_defaults

    # Model identifiers are checked by substring only, not for an exact id.
    model_fragments = {
        "asr_model": "paraformer",
        "asr_model_online": "paraformer",
        "vad_model": "vad",
        "punc_model": "punc",
    }
    for option, fragment in model_fragments.items():
        assert fragment in getattr(args, option), option


def test_custom_args():
    """Verify that explicit command-line flags override the defaults."""
    # Each (flag, value) pair is flattened into argv in this exact order.
    cli_pairs = [
        ("--host", "localhost"),
        ("--port", "8080"),
        ("--certfile", "cert.pem"),
        ("--keyfile", "key.pem"),
        ("--asr_model", "custom_model"),
        ("--ngpu", "0"),
        ("--device", "cpu"),
    ]
    argv = ["script.py"] + [token for pair in cli_pairs for token in pair]

    with patch("sys.argv", argv):
        args = parse_args()

    # --port and --ngpu are passed as strings but compared against integers
    # here, so parse_args must convert them.
    expected = {
        "host": "localhost",
        "port": 8080,
        "certfile": "cert.pem",
        "keyfile": "key.pem",
        "asr_model": "custom_model",
        "ngpu": 0,
        "device": "cpu",
    }
    assert {name: getattr(args, name) for name in expected} == expected