commit 86e5425787561ba74795296f9cf07ed48f1a94b5
Author: Keeeer
Date:   Mon Apr 14 11:04:36 2025 +0800

    [Init] Initialize the project: real-time speech recognition based on FunASR

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c3a38d8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,53 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+venv/
+ENV/
+env/
+
+# Testing
+.coverage
+htmlcov/
+.pytest_cache/
+
+# Editors
+.idea/
+.vscode/
+*.swp
+*.swo
+*~
+
+# System files
+.DS_Store
+Thumbs.db
+
+# Log files
+*.log
+logs/
+
+# Environment variables
+.env
+.env.local
+
+# Cursor rules
+.cursor
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..baf04ab
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,27 @@
+FROM python:3.9-slim
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    libsndfile1 \
+    ffmpeg \
+    && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+# Install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the application code
+COPY . .
+
+# Set environment variables
+ENV PYTHONPATH=/app
+ENV PYTHONUNBUFFERED=1
+
+# Expose the WebSocket port
+EXPOSE 10095
+
+# Start the service
+CMD ["python", "src/server.py"]
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..bb3eb83
--- /dev/null
+++ b/README.md
@@ -0,0 +1,113 @@
+# FunASR WebSocket Service
+
+## Overview
+This project implements a WebSocket speech recognition service on top of FunASR, supporting both online and offline recognition of real-time audio streams. Built on open-source ModelScope speech models, the service provides high-accuracy Chinese speech recognition with voice activity detection (VAD) and automatic punctuation restoration.
+
+## Project Layout
+```
+.
+├── src/                   # Source code
+│   ├── __init__.py        # Package initializer
+│   ├── server.py          # WebSocket server implementation
+│   ├── config.py          # Configuration handling
+│   ├── models.py          # Model loading
+│   ├── service.py         # ASR service implementation
+│   └── client.py          # Test client
+├── tests/                 # Tests
+│   ├── __init__.py        # Test package initializer
+│   └── test_config.py     # Configuration module tests
+├── requirements.txt       # Python dependencies
+├── Dockerfile             # Docker configuration
+├── docker-compose.yml     # Docker Compose configuration
+├── .gitignore             # Git ignore rules
+└── README.md              # Project documentation
+```
+
+## Features
+
+- **Multiple recognition modes**: offline, online, and two-pass (2pass) recognition
+- **Voice activity detection**: automatically detects the start and end of speech
+- **Punctuation**: automatic punctuation restoration
+- **WebSocket API**: real-time speech recognition over binary WebSocket
+- **Docker support**: containerized deployment
+
+## Installation and Usage
+
+### Requirements
+- Python 3.8+
+- CUDA (if GPU acceleration is needed)
+- RAM >= 8 GB
+
+### Install dependencies
+
+```bash
+pip install -r requirements.txt
+```
+
+### Run the server
+
+```bash
+python src/server.py
+```
+
+Common server options:
+- `--host`: address the server listens on, default 0.0.0.0
+- `--port`: server port, default 10095
+- `--device`: device type (cuda or cpu), default cuda
+- `--ngpu`: number of GPUs, 0 means CPU only, default 1
+
+### Test client
+
+```bash
+python src/client.py --audio_file path/to/audio.wav
+```
+
+Common client options:
+- `--audio_file`: path to the audio file to recognize
+- `--mode`: recognition mode, one of 2pass/online/offline, default 2pass
+- `--host`: server address, default localhost
+- `--port`: server port, default 10095
+
+## Docker Deployment
+
+### Build the image
+
+```bash
+docker build -t funasr-websocket .
+```
+
+### Start with Docker Compose
+
+```bash
+docker-compose up -d
+```
+
+## API Reference
+
+### WebSocket message format
+
+1. **Client configuration message**:
+```json
+{
+  "mode": "2pass",        // one of "2pass", "online", "offline"
+  "chunk_size": "5,10",   // chunk size, in the form "encoder_size,decoder_size"
+  "wav_name": "audio1",   // audio identifier
+  "is_speaking": true     // whether the client is still speaking
+}
+```
+
+2. **Client audio data**:
+A binary audio stream: 16 kHz sample rate, 16-bit PCM (see the example below for converting audio to this format).
+
+3. **Server recognition results**:
+```json
+{
+  "mode": "2pass-online",       // recognition mode
+  "text": "recognized text",    // recognition result
+  "wav_name": "audio1",         // audio identifier
+  "is_final": false             // whether this is the final result
+}
+```
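+
+### Minimal protocol example
+
+The protocol can be exercised without the full test client. The sketch below sends one WAV file in offline mode and prints results until the final one arrives; it assumes a server on the default port and a 16 kHz mono 16-bit WAV (convert first if needed, e.g. `ffmpeg -i input.mp3 -ar 16000 -ac 1 -acodec pcm_s16le audio.wav`; ffmpeg is already installed in the Docker image). Function and file names here are illustrative; see `src/client.py` for the complete client.
+
+```python
+import asyncio
+import json
+import wave
+
+import websockets
+
+
+async def recognize(path="audio.wav"):
+    """Send one WAV file in offline mode and print the transcript."""
+    async with websockets.connect(
+        "ws://localhost:10095", subprotocols=["binary"]
+    ) as ws:
+        # 1. Announce the session.
+        await ws.send(json.dumps({
+            "mode": "offline",
+            "chunk_size": "5,10",
+            "wav_name": path,
+            "is_speaking": True,
+        }))
+        # 2. Stream raw PCM in ~100 ms chunks (3200 bytes at 16 kHz/16-bit).
+        with wave.open(path, "rb") as wav:
+            pcm = wav.readframes(wav.getnframes())
+        for i in range(0, len(pcm), 3200):
+            await ws.send(pcm[i:i + 3200])
+        # 3. Signal end of speech so the server flushes the offline pass.
+        await ws.send(json.dumps({"is_speaking": False}))
+        # 4. Print results until the final one.
+        async for message in ws:
+            result = json.loads(message)
+            print(result.get("text", ""))
+            if result.get("is_final"):
+                break
+
+
+asyncio.run(recognize())
+```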
+
+## License
+[MIT](LICENSE)
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..d08c69c
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,23 @@
+version: '3.8'
+
+services:
+  funasr:
+    build: .
+    container_name: funasr-websocket
+    volumes:
+      - .:/app
+      # Mount the local model cache if available
+      - ~/.cache/modelscope:/root/.cache/modelscope
+    ports:
+      - "10095:10095"
+    environment:
+      - PYTHONUNBUFFERED=1
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+    restart: unless-stopped
+    command: python src/server.py --device cuda --ngpu 1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..66e9708
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+pytest==7.3.1
+pytest-cov==4.1.0
+flake8==6.0.0
+black==23.3.0
+isort==5.12.0
+flask==2.3.2
+requests==2.31.0
+websockets==11.0.3
+numpy==1.24.3
+funasr==0.10.0
+modelscope==1.9.5
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e13b69f
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,14 @@
+"""
+FunASR WebSocket Service
+========================
+
+A WebSocket-based real-time speech recognition service supporting both
+online and offline recognition modes.
+
+Features:
+- Real-time audio stream recognition
+- VAD voice activity detection
+- Automatic punctuation restoration
+- Multiple recognition modes (2pass/online/offline)
+"""
+
+__version__ = "0.1.0"
\ No newline at end of file
diff --git a/src/client.py b/src/client.py
new file mode 100644
index 0000000..86328f1
--- /dev/null
+++ b/src/client.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+WebSocket client example for testing the speech recognition service.
+"""
+
+import asyncio
+import json
+import websockets
+import argparse
+import wave
+import os
+
+
+def parse_args():
+    """Parse command-line arguments."""
+    parser = argparse.ArgumentParser(description="FunASR WebSocket client")
+
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="localhost",
+        help="Server host address"
+    )
+
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=10095,
+        help="Server port"
+    )
+
+    parser.add_argument(
+        "--audio_file",
+        type=str,
+        required=True,
+        help="Path to the audio file to recognize"
+    )
+
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="2pass",
+        choices=["2pass", "online", "offline"],
+        help="Recognition mode: 2pass (default), online, offline"
+    )
+
+    parser.add_argument(
+        "--chunk_size",
+        type=str,
+        default="5,10",
+        help="Chunk size, in the form 'encoder_size,decoder_size'"
+    )
+
+    parser.add_argument(
+        "--use_ssl",
+        action="store_true",
+        help="Use an SSL connection"
+    )
+
+    return parser.parse_args()
+
+
+async def send_audio(websocket, audio_file, mode, chunk_size):
+    """
+    Send an audio file to the server for recognition.
+
+    Args:
+        websocket: WebSocket connection
+        audio_file: path to the audio file
+        mode: recognition mode
+        chunk_size: chunk size
+    """
+    # Open and read the WAV file
+    with wave.open(audio_file, "rb") as wav_file:
+        params = wav_file.getparams()
+        frames = wav_file.readframes(wav_file.getnframes())
+
+    # Audio file info
+    print(f"Audio file: {os.path.basename(audio_file)}")
+    print(f"Sample rate: {params.framerate}Hz, channels: {params.nchannels}")
+    print(f"Sample width: {params.sampwidth * 8} bit, total frames: {params.nframes}")
+
+    # Build the configuration message
+    config = {
+        "mode": mode,
+        "chunk_size": chunk_size,
+        "wav_name": os.path.basename(audio_file),
+        "is_speaking": True
+    }
+
+    # Send the configuration
+    await websocket.send(json.dumps(config))
+
+    # Simulate a real-time audio stream
+    chunk_size_bytes = 3200  # 100 ms of 16 kHz audio per chunk
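+    # Where 3200 comes from (assuming 16 kHz mono 16-bit input, as the
+    # server expects): 16000 samples/s x 2 bytes/sample x 0.1 s = 3200 bytes.
+    # Audio with a different rate or channel count should be resampled first.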
+    total_chunks = len(frames) // chunk_size_bytes
+
+    print(f"Sending audio data in {total_chunks} chunks...")
+
+    try:
+        for i in range(0, len(frames), chunk_size_bytes):
+            chunk = frames[i:i+chunk_size_bytes]
+            await websocket.send(chunk)
+
+            # Simulate real time: send one chunk every 100 ms
+            await asyncio.sleep(0.1)
+
+            # Report progress
+            if (i // chunk_size_bytes) % 10 == 0:
+                print(f"Sent {i // chunk_size_bytes}/{total_chunks} chunks")
+
+        # Send the end-of-speech signal
+        await websocket.send(json.dumps({"is_speaking": False}))
+        print("Audio data sent")
+
+    except Exception as e:
+        print(f"Error while sending audio: {e}")
+
+
+async def receive_results(websocket):
+    """
+    Receive and display recognition results.
+
+    Args:
+        websocket: WebSocket connection
+    """
+    online_text = ""
+    offline_text = ""
+
+    try:
+        async for message in websocket:
+            # Parse the JSON message from the server
+            result = json.loads(message)
+
+            mode = result.get("mode", "")
+            text = result.get("text", "")
+            is_final = result.get("is_final", False)
+
+            # Update the text according to the mode
+            if "online" in mode:
+                online_text = text
+                print(f"\r[online] {online_text}", end="", flush=True)
+            elif "offline" in mode:
+                offline_text = text
+                print(f"\n[offline] {offline_text}")
+
+            # On the final result, print the full text and stop
+            if is_final and offline_text:
+                print("\nFinal recognition result:")
+                print(f"[offline] {offline_text}")
+                return
+
+    except Exception as e:
+        print(f"Error while receiving results: {e}")
+
+
+async def main():
+    """Entry point."""
+    args = parse_args()
+
+    # WebSocket URI
+    protocol = "wss" if args.use_ssl else "ws"
+    uri = f"{protocol}://{args.host}:{args.port}"
+
+    print(f"Connecting to server: {uri}")
+
+    try:
+        # Open the WebSocket connection
+        async with websockets.connect(
+            uri,
+            subprotocols=["binary"]
+        ) as websocket:
+
+            print("Connected")
+
+            # Run two tasks: send audio and receive results
+            send_task = asyncio.create_task(
+                send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
+            )
+
+            receive_task = asyncio.create_task(
+                receive_results(websocket)
+            )
+
+            # Wait for both tasks to finish
+            await asyncio.gather(send_task, receive_task)
+
+    except Exception as e:
+        print(f"Failed to connect to server: {e}")
+
+
+if __name__ == "__main__":
+    # Run the entry point
+    asyncio.run(main())
\ No newline at end of file
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..feea99f
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Configuration module: command-line arguments and settings.
+"""
+
+import argparse
+
+
+def parse_args():
+    """
+    Parse command-line arguments.
+
+    Returns:
+        argparse.Namespace: the parsed arguments
+    """
+    parser = argparse.ArgumentParser(description="FunASR WebSocket server")
+
+    # Server settings
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="0.0.0.0",
+        help="Server host address, e.g. localhost or 0.0.0.0"
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=10095,
+        help="WebSocket server port"
+    )
+
+    # SSL settings
+    parser.add_argument(
+        "--certfile",
+        type=str,
+        default="",
+        help="Path to the SSL certificate file"
+    )
+    parser.add_argument(
+        "--keyfile",
+        type=str,
+        default="",
+        help="Path to the SSL key file"
+    )
+
+    # ASR model settings
+    parser.add_argument(
+        "--asr_model",
+        type=str,
+        default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+        help="Offline ASR model (from ModelScope)"
+    )
+    parser.add_argument(
+        "--asr_model_revision",
+        type=str,
+        default="v2.0.4",
+        help="Offline ASR model revision"
+    )
+
+    # Online ASR model settings
+    parser.add_argument(
+        "--asr_model_online",
+        type=str,
+        default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+        help="Online ASR model (from ModelScope)"
+    )
+    parser.add_argument(
+        "--asr_model_online_revision",
+        type=str,
+        default="v2.0.4",
+        help="Online ASR model revision"
+    )
+
+    # VAD model settings
+    parser.add_argument(
+        "--vad_model",
+        type=str,
+        default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+        help="VAD voice activity detection model (from ModelScope)"
+    )
+    parser.add_argument(
+        "--vad_model_revision",
+        type=str,
+        default="v2.0.4",
+        help="VAD model revision"
+    )
+
+    # Punctuation model settings
+    parser.add_argument(
"--punc_model", + type=str, + default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", + help="标点符号模型(从ModelScope获取)" + ) + parser.add_argument( + "--punc_model_revision", + type=str, + default="v2.0.4", + help="标点符号模型版本" + ) + + # 硬件配置 + parser.add_argument( + "--ngpu", + type=int, + default=1, + help="GPU数量,0表示仅使用CPU" + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="设备类型:cuda或cpu" + ) + parser.add_argument( + "--ncpu", + type=int, + default=4, + help="CPU核心数" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + print("配置参数:") + for arg in vars(args): + print(f" {arg}: {getattr(args, arg)}") \ No newline at end of file diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000..7c9c7b6 --- /dev/null +++ b/src/models.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +模型加载模块 - 负责加载各种语音识别相关模型 +""" + +def load_models(args): + """ + 加载所有需要的模型 + + 参数: + args: 命令行参数,包含模型配置 + + 返回: + dict: 包含所有加载的模型的字典 + """ + try: + # 导入FunASR库 + from funasr import AutoModel + except ImportError: + raise ImportError("未找到funasr库,请先安装: pip install funasr") + + # 初始化模型字典 + models = {} + + # 1. 加载离线ASR模型 + print(f"正在加载ASR离线模型: {args.asr_model}") + models["asr"] = AutoModel( + model=args.asr_model, + model_revision=args.asr_model_revision, + ngpu=args.ngpu, + ncpu=args.ncpu, + device=args.device, + disable_pbar=True, + disable_log=True, + ) + + # 2. 加载在线ASR模型 + print(f"正在加载ASR在线模型: {args.asr_model_online}") + models["asr_streaming"] = AutoModel( + model=args.asr_model_online, + model_revision=args.asr_model_online_revision, + ngpu=args.ngpu, + ncpu=args.ncpu, + device=args.device, + disable_pbar=True, + disable_log=True, + ) + + # 3. 加载VAD模型 + print(f"正在加载VAD模型: {args.vad_model}") + models["vad"] = AutoModel( + model=args.vad_model, + model_revision=args.vad_model_revision, + ngpu=args.ngpu, + ncpu=args.ncpu, + device=args.device, + disable_pbar=True, + disable_log=True, + ) + + # 4. 
+    # 4. Load the punctuation model (if configured)
+    if args.punc_model:
+        print(f"Loading punctuation model: {args.punc_model}")
+        models["punc"] = AutoModel(
+            model=args.punc_model,
+            model_revision=args.punc_model_revision,
+            ngpu=args.ngpu,
+            ncpu=args.ncpu,
+            device=args.device,
+            disable_pbar=True,
+            disable_log=True,
+        )
+    else:
+        models["punc"] = None
+        print("No punctuation model configured; punctuation will be skipped")
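+
+    # How the streaming model loaded above is typically driven (a sketch of
+    # FunASR's cache-based AutoModel API; the exact kwargs in this project
+    # are threaded through websocket.status_dict_asr_online in service.py):
+    #   cache = {}
+    #   res = models["asr_streaming"].generate(
+    #       input=pcm_chunk, cache=cache, is_final=False, chunk_size=[5, 10])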
{message}") + + # 根据chunk_interval更新VAD的chunk_size + websocket.status_dict_vad["chunk_size"] = int( + websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval + ) + + # 处理音频数据 + if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str): + if not isinstance(message, str): # 二进制音频数据 + # 添加到帧缓冲区 + frames.append(message) + duration_ms = len(message) // 32 # 计算音频时长 + websocket.vad_pre_idx += duration_ms + + # 处理在线ASR + frames_asr_online.append(message) + websocket.status_dict_asr_online["is_final"] = speech_end_i != -1 + + # 达到chunk_interval或最终帧时处理在线ASR + if (len(frames_asr_online) % websocket.chunk_interval == 0 or + websocket.status_dict_asr_online["is_final"]): + if websocket.mode == "2pass" or websocket.mode == "online": + audio_in = b"".join(frames_asr_online) + try: + await asr_service.async_asr_online(websocket, audio_in) + except Exception as e: + print(f"在线ASR处理错误: {e}") + frames_asr_online = [] + + # 如果检测到语音开始,收集帧用于离线ASR + if speech_start: + frames_asr.append(message) + + # VAD处理 - 语音活动检测 + try: + speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message) + except Exception as e: + print(f"VAD处理错误: {e}") + + # 检测到语音开始 + if speech_start_i != -1: + speech_start = True + # 计算开始偏移并收集前面的帧 + beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms + frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames + frames_asr = [] + frames_asr.extend(frames_pre) + + # 处理离线ASR (语音结束或用户停止说话) + if speech_end_i != -1 or not websocket.is_speaking: + if websocket.mode == "2pass" or websocket.mode == "offline": + audio_in = b"".join(frames_asr) + try: + await asr_service.async_asr(websocket, audio_in) + except Exception as e: + print(f"离线ASR处理错误: {e}") + + # 重置状态 + frames_asr = [] + speech_start = False + frames_asr_online = [] + websocket.status_dict_asr_online["cache"] = {} + + # 如果用户停止说话,完全重置 + if not websocket.is_speaking: + websocket.vad_pre_idx = 0 + frames = [] + websocket.status_dict_vad["cache"] = {} + else: + # 保留最近的帧用于下一轮处理 + frames = frames[-20:] + + except websockets.ConnectionClosed: + print(f"连接已关闭...", flush=True) + await ws_reset(websocket) + websocket_users.remove(websocket) + except websockets.InvalidState: + print("无效的WebSocket状态...") + except Exception as e: + print(f"发生异常: {e}") + + +def start_server(args, asr_service_instance): + """ + 启动WebSocket服务器 + + 参数: + args: 命令行参数 + asr_service_instance: ASR服务实例 + """ + global asr_service + asr_service = asr_service_instance + + # 配置SSL (如果提供了证书) + if args.certfile and len(args.certfile) > 0: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile) + + start_server = websockets.serve( + ws_serve, args.host, args.port, + subprotocols=["binary"], + ping_interval=None, + ssl=ssl_context + ) + else: + start_server = websockets.serve( + ws_serve, args.host, args.port, + subprotocols=["binary"], + ping_interval=None + ) + + print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}") + + # 启动事件循环 + asyncio.get_event_loop().run_until_complete(start_server) + asyncio.get_event_loop().run_forever() + + +if __name__ == "__main__": + # 解析命令行参数 + args = parse_args() + + # 加载模型 + print("正在加载模型...") + models = load_models(args) + print("模型加载完成!当前仅支持单个客户端同时连接!") + + # 创建ASR服务 + asr_service = ASRService(models) + + # 启动服务器 + start_server(args, asr_service) \ No newline at end of file diff --git a/src/service.py b/src/service.py new file mode 100644 index 0000000..131cfa2 --- /dev/null +++ 
+
+            # Handle audio data
+            # (This condition is effectively always true; it is kept so that
+            # a JSON stop message can still flush the offline ASR below.)
+            if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
+                if not isinstance(message, str):  # binary audio data
+                    # Buffer the frame
+                    frames.append(message)
+                    duration_ms = len(message) // 32  # chunk duration at 16 kHz, 16-bit mono
+                    websocket.vad_pre_idx += duration_ms
+
+                    # Online ASR
+                    frames_asr_online.append(message)
+                    websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
+
+                    # Run online ASR every chunk_interval frames or on the final frame
+                    if (len(frames_asr_online) % websocket.chunk_interval == 0 or
+                            websocket.status_dict_asr_online["is_final"]):
+                        if websocket.mode == "2pass" or websocket.mode == "online":
+                            audio_in = b"".join(frames_asr_online)
+                            try:
+                                await asr_service.async_asr_online(websocket, audio_in)
+                            except Exception as e:
+                                print(f"Online ASR error: {e}")
+                        frames_asr_online = []
+
+                    # While speech is active, collect frames for offline ASR
+                    if speech_start:
+                        frames_asr.append(message)
+
+                    # VAD: voice activity detection
+                    try:
+                        speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
+                    except Exception as e:
+                        print(f"VAD error: {e}")
+
+                    # Speech onset detected
+                    if speech_start_i != -1:
+                        speech_start = True
+                        # Compute the onset offset and include the preceding frames
+                        beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
+                        frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
+                        frames_asr = []
+                        frames_asr.extend(frames_pre)
+
+                # Offline ASR (speech ended or the user stopped speaking)
+                if speech_end_i != -1 or not websocket.is_speaking:
+                    if websocket.mode == "2pass" or websocket.mode == "offline":
+                        audio_in = b"".join(frames_asr)
+                        try:
+                            await asr_service.async_asr(websocket, audio_in)
+                        except Exception as e:
+                            print(f"Offline ASR error: {e}")
+
+                    # Reset state
+                    frames_asr = []
+                    speech_start = False
+                    frames_asr_online = []
+                    websocket.status_dict_asr_online["cache"] = {}
+
+                    # Full reset once the user stops speaking
+                    if not websocket.is_speaking:
+                        websocket.vad_pre_idx = 0
+                        frames = []
+                        websocket.status_dict_vad["cache"] = {}
+                    else:
+                        # Keep the most recent frames for the next round
+                        frames = frames[-20:]
+
+    except websockets.ConnectionClosed:
+        print("Connection closed...", flush=True)
+        await ws_reset(websocket)
+        websocket_users.remove(websocket)
+    except websockets.InvalidState:
+        print("Invalid WebSocket state...")
+    except Exception as e:
+        print(f"Exception: {e}")
+
+
+def start_server(args, asr_service_instance):
+    """
+    Start the WebSocket server.
+
+    Args:
+        args: command-line arguments
+        asr_service_instance: the ASR service instance
+    """
+    global asr_service
+    asr_service = asr_service_instance
+
+    # Configure SSL if a certificate was provided
+    if args.certfile and len(args.certfile) > 0:
+        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
+
+        server = websockets.serve(
+            ws_serve, args.host, args.port,
+            subprotocols=["binary"],
+            ping_interval=None,
+            ssl=ssl_context
+        )
+    else:
+        server = websockets.serve(
+            ws_serve, args.host, args.port,
+            subprotocols=["binary"],
+            ping_interval=None
+        )
+
+    print(f"WebSocket server started - listening on {args.host}:{args.port}")
+
+    # Run the event loop
+    asyncio.get_event_loop().run_until_complete(server)
+    asyncio.get_event_loop().run_forever()
+
+
+if __name__ == "__main__":
+    # Parse command-line arguments
+    args = parse_args()
+
+    # Load models
+    print("Loading models...")
+    models = load_models(args)
+    print("Models loaded. Note: only one concurrent client is currently supported.")
+
+    # Create the ASR service
+    asr_service = ASRService(models)
+
+    # Start the server
+    start_server(args, asr_service)
\ No newline at end of file
diff --git a/src/service.py b/src/service.py
new file mode 100644
index 0000000..131cfa2
--- /dev/null
+++ b/src/service.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+ASR service module: core speech recognition functionality.
+"""
+
+import json
+
+
+class ASRService:
+    """ASR service class wrapping the speech recognition features."""
+
+    def __init__(self, models):
+        """
+        Initialize the ASR service.
+
+        Args:
+            models: dictionary of preloaded models
+        """
+        self.model_asr = models["asr"]
+        self.model_asr_streaming = models["asr_streaming"]
+        self.model_vad = models["vad"]
+        self.model_punc = models["punc"]
+
+    async def async_vad(self, websocket, audio_in):
+        """
+        Voice activity detection.
+
+        Args:
+            websocket: WebSocket connection
+            audio_in: binary audio data
+
+        Returns:
+            tuple: (speech_start, speech_end) speech onset and offset positions
+        """
+        # Analyze the audio segment with the VAD model
+        segments_result = self.model_vad.generate(
+            input=audio_in,
+            **websocket.status_dict_vad
+        )[0]["value"]
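+        # Expected shape of streaming VAD output (FunASR fsmn-vad): a list of
+        # [beg_ms, end_ms] pairs where -1 marks a boundary not yet observed,
+        # e.g. [[1200, -1]] means speech started at 1.2 s and is still running.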
+
+        speech_start = -1
+        speech_end = -1
+
+        # Parse the VAD result
+        if len(segments_result) == 0 or len(segments_result) > 1:
+            return speech_start, speech_end
+
+        if segments_result[0][0] != -1:
+            speech_start = segments_result[0][0]
+        if segments_result[0][1] != -1:
+            speech_end = segments_result[0][1]
+
+        return speech_start, speech_end
+
+    async def async_asr(self, websocket, audio_in):
+        """
+        Offline ASR processing.
+
+        Args:
+            websocket: WebSocket connection
+            audio_in: binary audio data
+        """
+        if len(audio_in) > 0:
+            # Run the offline ASR model
+            rec_result = self.model_asr.generate(
+                input=audio_in,
+                **websocket.status_dict_asr
+            )[0]
+
+            # Restore punctuation if a model is loaded and text was recognized
+            if self.model_punc is not None and len(rec_result["text"]) > 0:
+                rec_result = self.model_punc.generate(
+                    input=rec_result["text"],
+                    **websocket.status_dict_punc
+                )[0]
+
+            # Send the recognized text to the client
+            if len(rec_result["text"]) > 0:
+                mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
+                message = json.dumps({
+                    "mode": mode,
+                    "text": rec_result["text"],
+                    "wav_name": websocket.wav_name,
+                    # Final once the client has stopped speaking
+                    "is_final": not websocket.is_speaking,
+                })
+                await websocket.send(message)
+        else:
+            # No audio data: send an empty result
+            mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
+            message = json.dumps({
+                "mode": mode,
+                "text": "",
+                "wav_name": websocket.wav_name,
+                "is_final": not websocket.is_speaking,
+            })
+            await websocket.send(message)
+
+    async def async_asr_online(self, websocket, audio_in):
+        """
+        Online ASR processing.
+
+        Args:
+            websocket: WebSocket connection
+            audio_in: binary audio data
+        """
+        if len(audio_in) > 0:
+            # Run the streaming ASR model
+            rec_result = self.model_asr_streaming.generate(
+                input=audio_in,
+                **websocket.status_dict_asr_online
+            )[0]
+
+            # In 2pass mode, skip the final frame (it is handled by offline ASR)
+            if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
+                return
+
+            # Send the recognized text to the client
+            if len(rec_result["text"]):
+                mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
+                message = json.dumps({
+                    "mode": mode,
+                    "text": rec_result["text"],
+                    "wav_name": websocket.wav_name,
+                    "is_final": not websocket.is_speaking,
+                })
+                await websocket.send(message)
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..4ffb53d
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""Tests for the FunASR WebSocket service."""
\ No newline at end of file
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..6710998
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Tests for the configuration module.
+"""
+
+import sys
+import os
+from unittest.mock import patch
+
+# Add the project root to the import path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+from src.config import parse_args
+
+
+def test_default_args():
+    """Test the default argument values."""
+    with patch('sys.argv', ['script.py']):
+        args = parse_args()
+
+        # Server settings
+        assert args.host == "0.0.0.0"
+        assert args.port == 10095
+
+        # SSL settings
+        assert args.certfile == ""
+        assert args.keyfile == ""
+
+        # Model settings
+        assert "paraformer" in args.asr_model
+        assert args.asr_model_revision == "v2.0.4"
+        assert "paraformer" in args.asr_model_online
+        assert args.asr_model_online_revision == "v2.0.4"
+        assert "vad" in args.vad_model
+        assert args.vad_model_revision == "v2.0.4"
+        assert "punc" in args.punc_model
+        assert args.punc_model_revision == "v2.0.4"
+
+        # Hardware settings
+        assert args.ngpu == 1
+        assert args.device == "cuda"
+        assert args.ncpu == 4
+
+
+def test_custom_args():
+    """Test custom argument values."""
+    test_args = [
+        'script.py',
+        '--host', 'localhost',
+        '--port', '8080',
+        '--certfile', 'cert.pem',
+        '--keyfile', 'key.pem',
+        '--asr_model', 'custom_model',
+        '--ngpu', '0',
+        '--device', 'cpu'
+    ]
+
+    with patch('sys.argv', test_args):
+        args = parse_args()
+
+        # Custom values
+        assert args.host == "localhost"
+        assert args.port == 8080
+        assert args.certfile == "cert.pem"
+        assert args.keyfile == "key.pem"
+        assert args.asr_model == "custom_model"
+        assert args.ngpu == 0
+        assert args.device == "cpu"