[Init] Initialize project: real-time speech recognition based on FunASR

Keeeer 2025-04-14 11:04:36 +08:00
commit 86e5425787
13 changed files with 1085 additions and 0 deletions

.gitignore vendored Normal file
@@ -0,0 +1,53 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
ENV/

# Testing
.coverage
htmlcov/
.pytest_cache/

# Editors
.idea/
.vscode/
*.swp
*.swo
*~

# OS files
.DS_Store
Thumbs.db

# Logs
*.log
logs/

# Environment variables
.env
.env.local

# Cursor rules
.cursor

Dockerfile Normal file
@@ -0,0 +1,27 @@
FROM python:3.9-slim

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    libsndfile1 \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1

# Expose the WebSocket port
EXPOSE 10095

# Start the service
CMD ["python", "src/server.py"]

README.md Normal file
@@ -0,0 +1,113 @@
# FunASR WebSocket Service

## Introduction

This project implements a WebSocket speech recognition service based on FunASR, supporting both online and offline recognition of real-time audio streams. Built on open-source ModelScope speech models, it delivers high-accuracy Chinese speech recognition with voice activity detection (VAD) and automatic punctuation.

## Project Structure

```
.
├── src/                  # Source code
│   ├── __init__.py       # Package initialization
│   ├── server.py         # WebSocket server implementation
│   ├── config.py         # Configuration handling
│   ├── models.py         # Model loading
│   ├── service.py        # ASR service implementation
│   └── client.py         # Test client
├── tests/                # Tests
│   ├── __init__.py       # Test package initialization
│   └── test_config.py    # Config module tests
├── requirements.txt      # Python dependencies
├── Dockerfile            # Docker configuration
├── docker-compose.yml    # Docker Compose configuration
├── .gitignore            # Git ignore rules
└── README.md             # Project documentation
```

## Features

- **Multiple recognition modes**: offline, online, and two-pass (2pass) recognition
- **Voice activity detection**: automatically detects where speech starts and ends
- **Punctuation**: automatic punctuation restoration
- **WebSocket API**: real-time speech recognition over binary WebSocket
- **Docker support**: containerized deployment

## Installation and Usage

### Requirements

- Python 3.8+
- CUDA (for GPU acceleration)
- RAM >= 8 GB

### Install dependencies

```bash
pip install -r requirements.txt
```

### Run the server

```bash
python src/server.py
```

Common server options:

- `--host`: address to listen on, default 0.0.0.0
- `--port`: server port, default 10095
- `--device`: device type (cuda or cpu), default cuda
- `--ngpu`: number of GPUs (0 means CPU only), default 1

### Test client

```bash
python src/client.py --audio_file path/to/audio.wav
```

Common client options:

- `--audio_file`: path to the audio file to recognize
- `--mode`: recognition mode, one of 2pass/online/offline, default 2pass
- `--host`: server address, default localhost
- `--port`: server port, default 10095

## Docker Deployment

### Build the image

```bash
docker build -t funasr-websocket .
```

### Start with Docker Compose

```bash
docker-compose up -d
```

## API

### WebSocket message format

1. **Client configuration message**:

```json
{
  "mode": "2pass",        // one of "2pass", "online", "offline"
  "chunk_size": "5,10",   // chunk size, "encoder_size,decoder_size"; with the default "5,10", one streaming chunk covers 10 x 60 ms = 600 ms of audio
  "wav_name": "audio1",   // audio identifier
  "is_speaking": true     // whether the client is still speaking
}
```

2. **Client audio data**:

Binary audio stream: 16 kHz sample rate, 16-bit PCM.

3. **Server recognition result**:

```json
{
  "mode": "2pass-online", // recognition mode
  "text": "recognized text", // recognition result
  "wav_name": "audio1",   // audio identifier
  "is_final": false       // whether this is the final result
}
```
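For reference, a minimal sketch of this exchange in Python (assumptions: a server running on `localhost:10095`, the `websockets` package, and raw 16 kHz 16-bit mono PCM in `pcm_bytes`; see `src/client.py` for the full client):

```python
import asyncio
import json
import websockets

async def recognize(pcm_bytes: bytes):
    # Connect with the "binary" subprotocol used by the server
    async with websockets.connect("ws://localhost:10095", subprotocols=["binary"]) as ws:
        # 1. Send the configuration message
        await ws.send(json.dumps({
            "mode": "2pass", "chunk_size": "5,10",
            "wav_name": "demo", "is_speaking": True,
        }))
        # 2. Stream the audio in ~100 ms chunks (3200 bytes at 16 kHz, 16-bit),
        #    then signal end of speech (a live client would pace the sends)
        for i in range(0, len(pcm_bytes), 3200):
            await ws.send(pcm_bytes[i:i + 3200])
        await ws.send(json.dumps({"is_speaking": False}))
        # 3. Print results; in 2pass mode the offline result is the final pass
        async for message in ws:
            result = json.loads(message)
            print(result["mode"], result["text"])
            if result["mode"] == "2pass-offline":
                break

# asyncio.run(recognize(open("audio.pcm", "rb").read()))
```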
## License

[MIT](LICENSE)

docker-compose.yml Normal file
@@ -0,0 +1,23 @@
version: '3.8'

services:
  funasr:
    build: .
    container_name: funasr-websocket
    volumes:
      - .:/app
      # Mount the local ModelScope model cache if available
      - ~/.cache/modelscope:/root/.cache/modelscope
    ports:
      - "10095:10095"
    environment:
      - PYTHONUNBUFFERED=1
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    restart: unless-stopped
    command: python src/server.py --device cuda --ngpu 1

requirements.txt Normal file
@@ -0,0 +1,11 @@
pytest==7.3.1
pytest-cov==4.1.0
flake8==6.0.0
black==23.3.0
isort==5.12.0
flask==2.3.2
requests==2.31.0
websockets==11.0.3
numpy==1.24.3
funasr==0.10.0
modelscope==1.9.5

src/__init__.py Normal file
@@ -0,0 +1,14 @@
"""
FunASR WebSocket服务
====================
提供基于WebSocket的实时语音识别服务支持在线和离线两种识别模式
主要特性:
- 支持实时语音流识别
- 支持VAD语音活动检测
- 支持标点符号自动添加
- 支持多种识别模式(2pass/online/offline)
"""
__version__ = "0.1.0"

src/client.py Normal file
@@ -0,0 +1,196 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WebSocket客户端示例 - 用于测试语音识别服务
"""
import asyncio
import json
import websockets
import argparse
import numpy as np
import wave
import os
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description="FunASR WebSocket客户端")
parser.add_argument(
"--host",
type=str,
default="localhost",
help="服务器主机地址"
)
parser.add_argument(
"--port",
type=int,
default=10095,
help="服务器端口"
)
parser.add_argument(
"--audio_file",
type=str,
required=True,
help="要识别的音频文件路径"
)
parser.add_argument(
"--mode",
type=str,
default="2pass",
choices=["2pass", "online", "offline"],
help="识别模式: 2pass(默认), online, offline"
)
parser.add_argument(
"--chunk_size",
type=str,
default="5,10",
help="分块大小, 格式为'encoder_size,decoder_size'"
)
parser.add_argument(
"--use_ssl",
action="store_true",
help="是否使用SSL连接"
)
return parser.parse_args()
async def send_audio(websocket, audio_file, mode, chunk_size):
"""
发送音频文件到服务器进行识别
参数:
websocket: WebSocket连接
audio_file: 音频文件路径
mode: 识别模式
chunk_size: 分块大小
"""
# 打开并读取WAV文件
with wave.open(audio_file, "rb") as wav_file:
params = wav_file.getparams()
frames = wav_file.readframes(wav_file.getnframes())
# 音频文件信息
print(f"音频文件: {os.path.basename(audio_file)}")
print(f"采样率: {params.framerate}Hz, 通道数: {params.nchannels}")
print(f"采样位深: {params.sampwidth * 8}位, 总帧数: {params.nframes}")
# 设置配置参数
config = {
"mode": mode,
"chunk_size": chunk_size,
"wav_name": os.path.basename(audio_file),
"is_speaking": True
}
# 发送配置
await websocket.send(json.dumps(config))
# 模拟实时发送音频数据
chunk_size_bytes = 3200 # 每次发送100ms的16kHz音频
total_chunks = len(frames) // chunk_size_bytes
print(f"开始发送音频数据,共 {total_chunks} 个数据块...")
try:
for i in range(0, len(frames), chunk_size_bytes):
chunk = frames[i:i+chunk_size_bytes]
await websocket.send(chunk)
# 模拟实时每100ms发送一次
await asyncio.sleep(0.1)
# 显示进度
if (i // chunk_size_bytes) % 10 == 0:
print(f"已发送 {i // chunk_size_bytes}/{total_chunks} 数据块")
# 发送结束信号
await websocket.send(json.dumps({"is_speaking": False}))
print("音频数据发送完成")
except Exception as e:
print(f"发送音频时出错: {e}")
async def receive_results(websocket):
"""
接收并显示识别结果
参数:
websocket: WebSocket连接
"""
online_text = ""
offline_text = ""
try:
async for message in websocket:
# 解析服务器返回的JSON消息
result = json.loads(message)
mode = result.get("mode", "")
text = result.get("text", "")
is_final = result.get("is_final", False)
# 根据模式更新文本
if "online" in mode:
online_text = text
print(f"\r[在线识别] {online_text}", end="", flush=True)
elif "offline" in mode:
offline_text = text
print(f"\n[离线识别] {offline_text}")
# 如果是最终结果,打印完整信息
if is_final and offline_text:
print("\n最终识别结果:")
print(f"[离线识别] {offline_text}")
return
except Exception as e:
print(f"接收结果时出错: {e}")
async def main():
"""主函数"""
args = parse_args()
# WebSocket URI
protocol = "wss" if args.use_ssl else "ws"
uri = f"{protocol}://{args.host}:{args.port}"
print(f"连接到服务器: {uri}")
try:
# 创建WebSocket连接
async with websockets.connect(
uri,
subprotocols=["binary"]
) as websocket:
print("连接成功")
# 创建两个任务: 发送音频和接收结果
send_task = asyncio.create_task(
send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
)
receive_task = asyncio.create_task(
receive_results(websocket)
)
# 等待任务完成
await asyncio.gather(send_task, receive_task)
except Exception as e:
print(f"连接服务器失败: {e}")
if __name__ == "__main__":
# 运行主函数
asyncio.run(main())

src/config.py Normal file
@@ -0,0 +1,130 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
配置模块 - 处理命令行参数和配置项
"""
import argparse
def parse_args():
"""
解析命令行参数
返回:
argparse.Namespace: 解析后的参数对象
"""
parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")
# 服务器配置
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="服务器主机地址例如localhost, 0.0.0.0"
)
parser.add_argument(
"--port",
type=int,
default=10095,
help="WebSocket服务器端口"
)
# SSL配置
parser.add_argument(
"--certfile",
type=str,
default="",
help="SSL证书文件路径"
)
parser.add_argument(
"--keyfile",
type=str,
default="",
help="SSL密钥文件路径"
)
# ASR模型配置
parser.add_argument(
"--asr_model",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="离线ASR模型从ModelScope获取"
)
parser.add_argument(
"--asr_model_revision",
type=str,
default="v2.0.4",
help="离线ASR模型版本"
)
# 在线ASR模型配置
parser.add_argument(
"--asr_model_online",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="在线ASR模型从ModelScope获取"
)
parser.add_argument(
"--asr_model_online_revision",
type=str,
default="v2.0.4",
help="在线ASR模型版本"
)
# VAD模型配置
parser.add_argument(
"--vad_model",
type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="VAD语音活动检测模型从ModelScope获取"
)
parser.add_argument(
"--vad_model_revision",
type=str,
default="v2.0.4",
help="VAD模型版本"
)
# 标点符号模型配置
parser.add_argument(
"--punc_model",
type=str,
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="标点符号模型从ModelScope获取"
)
parser.add_argument(
"--punc_model_revision",
type=str,
default="v2.0.4",
help="标点符号模型版本"
)
# 硬件配置
parser.add_argument(
"--ngpu",
type=int,
default=1,
help="GPU数量0表示仅使用CPU"
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="设备类型cuda或cpu"
)
parser.add_argument(
"--ncpu",
type=int,
default=4,
help="CPU核心数"
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
print("配置参数:")
for arg in vars(args):
print(f" {arg}: {getattr(args, arg)}")

src/models.py Normal file
@@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
def load_models(args):
"""
加载所有需要的模型
参数:
args: 命令行参数包含模型配置
返回:
dict: 包含所有加载的模型的字典
"""
try:
# 导入FunASR库
from funasr import AutoModel
except ImportError:
raise ImportError("未找到funasr库请先安装: pip install funasr")
# 初始化模型字典
models = {}
# 1. 加载离线ASR模型
print(f"正在加载ASR离线模型: {args.asr_model}")
models["asr"] = AutoModel(
model=args.asr_model,
model_revision=args.asr_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# 2. 加载在线ASR模型
print(f"正在加载ASR在线模型: {args.asr_model_online}")
models["asr_streaming"] = AutoModel(
model=args.asr_model_online,
model_revision=args.asr_model_online_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# 3. 加载VAD模型
print(f"正在加载VAD模型: {args.vad_model}")
models["vad"] = AutoModel(
model=args.vad_model,
model_revision=args.vad_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# 4. 加载标点符号模型(如果指定)
if args.punc_model:
print(f"正在加载标点符号模型: {args.punc_model}")
models["punc"] = AutoModel(
model=args.punc_model,
model_revision=args.punc_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
else:
models["punc"] = None
print("未指定标点符号模型,将不使用标点符号")
print("所有模型加载完成")
return models
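# Usage sketch (illustrative, with a hypothetical audio path): combine with
# config.parse_args() and run offline recognition directly, e.g.
#
#   from config import parse_args
#   models = load_models(parse_args())
#   print(models["asr"].generate(input="path/to/audio.wav")[0]["text"])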

src/server.py Normal file
@@ -0,0 +1,242 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
语音识别WebSocket服务器入口文件
提供基于WebSocket的实时语音识别服务支持离线和在线两种模式
"""
import asyncio
import json
import websockets
import ssl
import argparse
from config import parse_args
from models import load_models
from service import ASRService
# 全局变量存储当前连接的WebSocket客户端
websocket_users = set()
async def ws_reset(websocket):
"""重置WebSocket连接状态并关闭连接"""
print(f"重置WebSocket连接当前连接数: {len(websocket_users)}")
# 重置状态字典
websocket.status_dict_asr_online["cache"] = {}
websocket.status_dict_asr_online["is_final"] = True
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
websocket.status_dict_punc["cache"] = {}
# 关闭连接
await websocket.close()
async def clear_websocket():
"""清理所有WebSocket连接"""
for websocket in websocket_users:
await ws_reset(websocket)
websocket_users.clear()
async def ws_serve(websocket, path):
"""
WebSocket服务主函数处理客户端连接和消息
参数:
websocket: WebSocket连接对象
path: 连接路径
"""
frames = [] # 存储所有音频帧
frames_asr = [] # 存储用于离线ASR的音频帧
frames_asr_online = [] # 存储用于在线ASR的音频帧
global websocket_users
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
# 添加到用户集合
websocket_users.add(websocket)
# 初始化连接状态
websocket.status_dict_asr = {}
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
websocket.status_dict_vad = {"cache": {}, "is_final": False}
websocket.status_dict_punc = {"cache": {}}
websocket.chunk_interval = 10
websocket.vad_pre_idx = 0
websocket.is_speaking = True # 默认用户正在说话
# 语音检测状态
speech_start = False
speech_end_i = -1
# 初始化配置
websocket.wav_name = "microphone"
websocket.mode = "2pass" # 默认使用两阶段识别模式
print("新用户已连接", flush=True)
try:
# 持续接收客户端消息
async for message in websocket:
# 处理JSON配置消息
if isinstance(message, str):
try:
messagejson = json.loads(message)
# 更新各种配置参数
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
websocket.wav_name = messagejson.get("wav_name")
if "chunk_size" in messagejson:
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
websocket.mode = messagejson["mode"]
except json.JSONDecodeError:
print(f"无效的JSON消息: {message}")
# 根据chunk_interval更新VAD的chunk_size
websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
)
# 处理音频数据
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区
frames.append(message)
duration_ms = len(message) // 32 # 计算音频时长
websocket.vad_pre_idx += duration_ms
# 处理在线ASR
frames_asr_online.append(message)
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
# 达到chunk_interval或最终帧时处理在线ASR
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
websocket.status_dict_asr_online["is_final"]):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
await asr_service.async_asr_online(websocket, audio_in)
except Exception as e:
print(f"在线ASR处理错误: {e}")
frames_asr_online = []
# 如果检测到语音开始收集帧用于离线ASR
if speech_start:
frames_asr.append(message)
# VAD处理 - 语音活动检测
try:
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
except Exception as e:
print(f"VAD处理错误: {e}")
# 检测到语音开始
if speech_start_i != -1:
speech_start = True
# 计算开始偏移并收集前面的帧
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
frames_asr = []
frames_asr.extend(frames_pre)
# 处理离线ASR (语音结束或用户停止说话)
if speech_end_i != -1 or not websocket.is_speaking:
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
try:
await asr_service.async_asr(websocket, audio_in)
except Exception as e:
print(f"离线ASR处理错误: {e}")
# 重置状态
frames_asr = []
speech_start = False
frames_asr_online = []
websocket.status_dict_asr_online["cache"] = {}
# 如果用户停止说话,完全重置
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
websocket.status_dict_vad["cache"] = {}
else:
# 保留最近的帧用于下一轮处理
frames = frames[-20:]
except websockets.ConnectionClosed:
print(f"连接已关闭...", flush=True)
await ws_reset(websocket)
websocket_users.remove(websocket)
except websockets.InvalidState:
print("无效的WebSocket状态...")
except Exception as e:
print(f"发生异常: {e}")
def start_server(args, asr_service_instance):
"""
启动WebSocket服务器
参数:
args: 命令行参数
asr_service_instance: ASR服务实例
"""
global asr_service
asr_service = asr_service_instance
# 配置SSL (如果提供了证书)
if args.certfile and len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None
)
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
# 启动事件循环
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
if __name__ == "__main__":
# 解析命令行参数
args = parse_args()
# 加载模型
print("正在加载模型...")
models = load_models(args)
print("模型加载完成!当前仅支持单个客户端同时连接!")
# 创建ASR服务
asr_service = ASRService(models)
# 启动服务器
start_server(args, asr_service)

src/service.py Normal file
@@ -0,0 +1,127 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
ASR服务模块 - 提供语音识别相关的核心功能
"""
import json
class ASRService:
"""ASR服务类封装各种语音识别相关功能"""
def __init__(self, models):
"""
初始化ASR服务
参数:
models: 包含各种预加载模型的字典
"""
self.model_asr = models["asr"]
self.model_asr_streaming = models["asr_streaming"]
self.model_vad = models["vad"]
self.model_punc = models["punc"]
async def async_vad(self, websocket, audio_in):
"""
语音活动检测
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
返回:
tuple: (speech_start, speech_end) 语音开始和结束位置
"""
# 使用VAD模型分析音频段
segments_result = self.model_vad.generate(
input=audio_in,
**websocket.status_dict_vad
)[0]["value"]
speech_start = -1
speech_end = -1
# 解析VAD结果
if len(segments_result) == 0 or len(segments_result) > 1:
return speech_start, speech_end
if segments_result[0][0] != -1:
speech_start = segments_result[0][0]
if segments_result[0][1] != -1:
speech_end = segments_result[0][1]
return speech_start, speech_end
async def async_asr(self, websocket, audio_in):
"""
离线ASR处理
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
"""
if len(audio_in) > 0:
# 使用离线ASR模型处理音频
rec_result = self.model_asr.generate(
input=audio_in,
**websocket.status_dict_asr
)[0]
# 如果有标点符号模型且识别出文本,则添加标点
if self.model_punc is not None and len(rec_result["text"]) > 0:
rec_result = self.model_punc.generate(
input=rec_result["text"],
**websocket.status_dict_punc
)[0]
# 如果识别出文本,发送到客户端
if len(rec_result["text"]) > 0:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
await websocket.send(message)
else:
# 如果没有音频数据,发送空文本
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
await websocket.send(message)
async def async_asr_online(self, websocket, audio_in):
"""
在线ASR处理
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
"""
if len(audio_in) > 0:
# 使用在线ASR模型处理音频
rec_result = self.model_asr_streaming.generate(
input=audio_in,
**websocket.status_dict_asr_online
)[0]
# 在2pass模式下如果是最终帧则跳过(留给离线ASR处理)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
return
# 如果识别出文本,发送到客户端
if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
await websocket.send(message)

tests/__init__.py Normal file
@@ -0,0 +1 @@
"""FunASR WebSocket服务测试模块"""

tests/test_config.py Normal file
@@ -0,0 +1,69 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
配置模块测试
"""
import pytest
import sys
import os
from unittest.mock import patch
# 将src目录添加到路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.config import parse_args
def test_default_args():
"""测试默认参数值"""
with patch('sys.argv', ['script.py']):
args = parse_args()
# 检查服务器参数
assert args.host == "0.0.0.0"
assert args.port == 10095
# 检查SSL参数
assert args.certfile == ""
assert args.keyfile == ""
# 检查模型参数
assert "paraformer" in args.asr_model
assert args.asr_model_revision == "v2.0.4"
assert "paraformer" in args.asr_model_online
assert args.asr_model_online_revision == "v2.0.4"
assert "vad" in args.vad_model
assert args.vad_model_revision == "v2.0.4"
assert "punc" in args.punc_model
assert args.punc_model_revision == "v2.0.4"
# 检查硬件配置
assert args.ngpu == 1
assert args.device == "cuda"
assert args.ncpu == 4
def test_custom_args():
"""测试自定义参数值"""
test_args = [
'script.py',
'--host', 'localhost',
'--port', '8080',
'--certfile', 'cert.pem',
'--keyfile', 'key.pem',
'--asr_model', 'custom_model',
'--ngpu', '0',
'--device', 'cpu'
]
with patch('sys.argv', test_args):
args = parse_args()
# 检查自定义参数
assert args.host == "localhost"
assert args.port == 8080
assert args.certfile == "cert.pem"
assert args.keyfile == "key.pem"
assert args.asr_model == "custom_model"
assert args.ngpu == 0
assert args.device == "cpu"