#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Tests for the configuration module's command-line argument parsing."""

import pytest
import sys
import os
from unittest.mock import patch

# Put the parent of this file's directory on the import path so that the
# ``src`` package can be imported below.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
)

from src.config import parse_args


def test_default_args():
    """Check the values ``parse_args`` yields when no options are passed.

    ``sys.argv`` is replaced with a list holding only the program name for
    the duration of the ``parse_args`` call.
    """
    with patch.object(sys, "argv", ["script.py"]):
        parsed = parse_args()

    # Server options.
    assert parsed.host == "0.0.0.0"
    assert parsed.port == 10095

    # SSL options.
    assert parsed.certfile == ""
    assert parsed.keyfile == ""

    # Model options: each default model name must contain its keyword, and
    # the companion ``<option>_revision`` default must be "v2.0.4".
    model_keywords = (
        ("asr_model", "paraformer"),
        ("asr_model_online", "paraformer"),
        ("vad_model", "vad"),
        ("punc_model", "punc"),
    )
    for option, keyword in model_keywords:
        assert keyword in getattr(parsed, option)
        assert getattr(parsed, option + "_revision") == "v2.0.4"

    # Hardware options.
    assert parsed.ngpu == 1
    assert parsed.device == "cuda"
    assert parsed.ncpu == 4


def test_custom_args():
    """Check that explicitly supplied options show up in the parsed result.

    ``sys.argv`` is replaced with a command line that sets the server, SSL,
    ASR model, and hardware options before ``parse_args`` is called.
    """
    argv = [
        "script.py",
        "--host", "localhost",
        "--port", "8080",
        "--certfile", "cert.pem",
        "--keyfile", "key.pem",
        "--asr_model", "custom_model",
        "--ngpu", "0",
        "--device", "cpu",
    ]

    with patch.object(sys, "argv", argv):
        parsed = parse_args()

    # Expected parsed value per option; ``port`` and ``ngpu`` are compared
    # against integers, the rest against strings.
    expected = (
        ("host", "localhost"),
        ("port", 8080),
        ("certfile", "cert.pem"),
        ("keyfile", "key.pem"),
        ("asr_model", "custom_model"),
        ("ngpu", 0),
        ("device", "cpu"),
    )
    for option, value in expected:
        assert getattr(parsed, option) == value