STT_Server/tests/test_config.py

69 lines
1.9 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Tests for the configuration module (``src.config``).

Each test patches ``sys.argv``, calls ``parse_args()`` and checks the
attributes of the object it returns: first with no flags (defaults), then
with a set of explicit overrides.
"""
import pytest
import sys
import os
from unittest.mock import patch
# Prepend the parent of this file's directory (expected to be the project root
# that contains the ``src/`` package) to sys.path, so that ``src`` can be
# imported without installing the project. This must run before the
# ``from src.config import ...`` line below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.config import parse_args
def test_default_args():
    """With no command-line flags, parse_args() must yield the built-in defaults."""
    # Only the program name on the command line -> every option falls back
    # to its default value.
    with patch("sys.argv", ["script.py"]):
        parsed = parse_args()

    # Options whose default is pinned to an exact value.
    exact_defaults = {
        # server binding
        "host": "0.0.0.0",
        "port": 10095,
        # SSL certificate / key paths
        "certfile": "",
        "keyfile": "",
        # model revisions
        "asr_model_revision": "v2.0.4",
        "asr_model_online_revision": "v2.0.4",
        "vad_model_revision": "v2.0.4",
        "punc_model_revision": "v2.0.4",
        # hardware
        "ngpu": 1,
        "device": "cuda",
        "ncpu": 4,
    }
    for attr, expected in exact_defaults.items():
        assert getattr(parsed, attr) == expected

    # Model identifiers are checked loosely: each default only has to
    # contain the keyword naming its model family.
    model_keywords = {
        "asr_model": "paraformer",
        "asr_model_online": "paraformer",
        "vad_model": "vad",
        "punc_model": "punc",
    }
    for attr, keyword in model_keywords.items():
        assert keyword in getattr(parsed, attr)
def test_custom_args():
    """Explicit command-line flags must override the defaults they name."""
    # (flag, raw string value) pairs exactly as a user would type them.
    cli_pairs = [
        ("--host", "localhost"),
        ("--port", "8080"),
        ("--certfile", "cert.pem"),
        ("--keyfile", "key.pem"),
        ("--asr_model", "custom_model"),
        ("--ngpu", "0"),
        ("--device", "cpu"),
    ]
    argv = ["script.py"] + [token for pair in cli_pairs for token in pair]

    with patch("sys.argv", argv):
        parsed = parse_args()

    # Expected parsed values; port and ngpu are compared as ints, so the
    # parser must convert them from their string form.
    expected_values = {
        "host": "localhost",
        "port": 8080,
        "certfile": "cert.pem",
        "keyfile": "key.pem",
        "asr_model": "custom_model",
        "ngpu": 0,
        "device": "cpu",
    }
    for attr, expected in expected_values.items():
        assert getattr(parsed, attr) == expected