[Init] Initialize the project: real-time speech recognition based on FunASR

commit 86e5425787
.gitignore (vendored, new file, 53 lines)
@@ -0,0 +1,53 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
ENV/
env/

# Testing
.coverage
htmlcov/
.pytest_cache/

# Editors
.idea/
.vscode/
*.swp
*.swo
*~

# OS files
.DS_Store
Thumbs.db

# Logs
*.log
logs/

# Environment variables
.env
.env.local

# Cursor rules
.cursor
Dockerfile (new file, 27 lines)
@@ -0,0 +1,27 @@
FROM python:3.9-slim

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    libsndfile1 \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1

# Expose the WebSocket port
EXPOSE 10095

# Start the service
CMD ["python", "src/server.py"]
README.md (new file, 113 lines)
@@ -0,0 +1,113 @@
# FunASR WebSocket Service

## Introduction
This project implements a WebSocket speech recognition service on top of FunASR, supporting both online and offline recognition of real-time audio streams. Built on open-source ModelScope speech models, it provides high-accuracy Chinese speech recognition with voice activity detection (VAD) and automatic punctuation restoration.

## Project Structure
```
.
├── src/                  # Source code
│   ├── __init__.py       # Package init
│   ├── server.py         # WebSocket server implementation
│   ├── config.py         # Configuration handling
│   ├── models.py         # Model loading
│   ├── service.py        # ASR service implementation
│   └── client.py         # Test client
├── tests/                # Tests
│   ├── __init__.py       # Test package init
│   └── test_config.py    # Config module tests
├── requirements.txt      # Python dependencies
├── Dockerfile            # Docker configuration
├── docker-compose.yml    # Docker Compose configuration
├── .gitignore            # Git ignore rules
└── README.md             # Project documentation
```

## Features

- **Multiple recognition modes**: offline, online, and two-pass (2pass) recognition
- **Voice activity detection**: automatically detects the start and end of speech
- **Punctuation**: automatic punctuation restoration
- **WebSocket API**: real-time speech recognition over binary WebSocket
- **Docker support**: containerized deployment

## Installation and Usage

### Requirements
- Python 3.8+
- CUDA (for GPU acceleration)
- RAM >= 8 GB

### Install dependencies

```bash
pip install -r requirements.txt
```

### Run the server

```bash
python src/server.py
```

Common server options:
- `--host`: address to listen on, default 0.0.0.0
- `--port`: server port, default 10095
- `--device`: device type (cuda or cpu), default cuda
- `--ngpu`: number of GPUs, 0 means CPU only, default 1

### Test client

```bash
python src/client.py --audio_file path/to/audio.wav
```

Common client options:
- `--audio_file`: path to the audio file to recognize
- `--mode`: recognition mode, one of 2pass/online/offline, default 2pass
- `--host`: server address, default localhost
- `--port`: server port, default 10095

## Docker Deployment

### Build the image

```bash
docker build -t funasr-websocket .
```

### Start with Docker Compose

```bash
docker-compose up -d
```

## API

### WebSocket message format

1. **Client configuration message**:
```json
{
  "mode": "2pass",        // one of "2pass", "online", "offline"
  "chunk_size": "5,10",   // chunk size, "encoder_size,decoder_size"
  "wav_name": "audio1",   // audio identifier
  "is_speaking": true     // whether the client is still speaking
}
```

2. **Client audio data**:
Binary audio stream, 16 kHz sample rate, 16-bit PCM

3. **Server recognition result**:
```json
{
  "mode": "2pass-online",    // recognition mode
  "text": "recognized text", // recognition result
  "wav_name": "audio1",      // audio identifier
  "is_final": false          // whether this is the final result
}
```

## License
[MIT](LICENSE)
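The message formats above can be checked without a running server. A minimal sketch (field names exactly as documented above) of serializing a client configuration and decoding a server result:

```python
import json

# Client configuration message: the first text frame sent after connecting
config = {
    "mode": "2pass",        # "2pass", "online", or "offline"
    "chunk_size": "5,10",   # "encoder_size,decoder_size"
    "wav_name": "audio1",
    "is_speaking": True,
}
payload = json.dumps(config)

# A server result frame, decoded back into a dict
raw = '{"mode": "2pass-online", "text": "hello", "wav_name": "audio1", "is_final": false}'
result = json.loads(raw)

print(result["mode"])      # 2pass-online
print(result["is_final"])  # False
```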
docker-compose.yml (new file, 23 lines)
@@ -0,0 +1,23 @@
version: '3.8'

services:
  funasr:
    build: .
    container_name: funasr-websocket
    volumes:
      - .:/app
      # Mount the local model cache if needed
      - ~/.cache/modelscope:/root/.cache/modelscope
    ports:
      - "10095:10095"
    environment:
      - PYTHONUNBUFFERED=1
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    restart: unless-stopped
    command: python src/server.py --device cuda --ngpu 1
requirements.txt (new file, 11 lines)
@@ -0,0 +1,11 @@
pytest==7.3.1
pytest-cov==4.1.0
flake8==6.0.0
black==23.3.0
isort==5.12.0
flask==2.3.2
requests==2.31.0
websockets==11.0.3
numpy==1.24.3
funasr==0.10.0
modelscope==1.9.5
src/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""
FunASR WebSocket Service
========================

Real-time speech recognition service over WebSocket, supporting both online
and offline recognition modes.

Key features:
- Real-time streaming speech recognition
- VAD (voice activity detection)
- Automatic punctuation restoration
- Multiple recognition modes (2pass/online/offline)
"""

__version__ = "0.1.0"
src/client.py (new file, 196 lines)
@@ -0,0 +1,196 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WebSocket client example for testing the speech recognition service
"""

import asyncio
import json
import websockets
import argparse
import wave
import os


def parse_args():
    """Parse command-line arguments"""
    parser = argparse.ArgumentParser(description="FunASR WebSocket client")

    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server host address"
    )

    parser.add_argument(
        "--port",
        type=int,
        default=10095,
        help="Server port"
    )

    parser.add_argument(
        "--audio_file",
        type=str,
        required=True,
        help="Path to the audio file to recognize"
    )

    parser.add_argument(
        "--mode",
        type=str,
        default="2pass",
        choices=["2pass", "online", "offline"],
        help="Recognition mode: 2pass (default), online, offline"
    )

    parser.add_argument(
        "--chunk_size",
        type=str,
        default="5,10",
        help="Chunk size in the form 'encoder_size,decoder_size'"
    )

    parser.add_argument(
        "--use_ssl",
        action="store_true",
        help="Use an SSL connection"
    )

    return parser.parse_args()


async def send_audio(websocket, audio_file, mode, chunk_size):
    """
    Send an audio file to the server for recognition

    Args:
        websocket: WebSocket connection
        audio_file: path to the audio file
        mode: recognition mode
        chunk_size: chunk size
    """
    # Open and read the WAV file
    with wave.open(audio_file, "rb") as wav_file:
        params = wav_file.getparams()
        frames = wav_file.readframes(wav_file.getnframes())

    # Audio file info
    print(f"Audio file: {os.path.basename(audio_file)}")
    print(f"Sample rate: {params.framerate}Hz, channels: {params.nchannels}")
    print(f"Sample width: {params.sampwidth * 8} bit, total frames: {params.nframes}")

    # Build the configuration message
    config = {
        "mode": mode,
        "chunk_size": chunk_size,
        "wav_name": os.path.basename(audio_file),
        "is_speaking": True
    }

    # Send the configuration
    await websocket.send(json.dumps(config))

    # Simulate real-time streaming
    chunk_size_bytes = 3200  # 100 ms of 16 kHz 16-bit mono audio per send
    total_chunks = len(frames) // chunk_size_bytes

    print(f"Sending audio data, {total_chunks} chunks in total...")

    try:
        for i in range(0, len(frames), chunk_size_bytes):
            chunk = frames[i:i+chunk_size_bytes]
            await websocket.send(chunk)

            # Simulate real time: send one chunk every 100 ms
            await asyncio.sleep(0.1)

            # Show progress
            if (i // chunk_size_bytes) % 10 == 0:
                print(f"Sent {i // chunk_size_bytes}/{total_chunks} chunks")

        # Send the end-of-speech signal
        await websocket.send(json.dumps({"is_speaking": False}))
        print("Finished sending audio data")

    except Exception as e:
        print(f"Error while sending audio: {e}")


async def receive_results(websocket):
    """
    Receive and display recognition results

    Args:
        websocket: WebSocket connection
    """
    online_text = ""
    offline_text = ""

    try:
        async for message in websocket:
            # Parse the JSON message returned by the server
            result = json.loads(message)

            mode = result.get("mode", "")
            text = result.get("text", "")
            is_final = result.get("is_final", False)

            # Update the text according to the mode
            if "online" in mode:
                online_text = text
                print(f"\r[online] {online_text}", end="", flush=True)
            elif "offline" in mode:
                offline_text = text
                print(f"\n[offline] {offline_text}")

            # On the final result, print the full transcript
            if is_final and offline_text:
                print("\nFinal recognition result:")
                print(f"[offline] {offline_text}")
                return

    except Exception as e:
        print(f"Error while receiving results: {e}")


async def main():
    """Entry point"""
    args = parse_args()

    # WebSocket URI
    protocol = "wss" if args.use_ssl else "ws"
    uri = f"{protocol}://{args.host}:{args.port}"

    print(f"Connecting to server: {uri}")

    try:
        # Open the WebSocket connection
        async with websockets.connect(
            uri,
            subprotocols=["binary"]
        ) as websocket:

            print("Connected")

            # Run two tasks: send audio and receive results
            send_task = asyncio.create_task(
                send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
            )

            receive_task = asyncio.create_task(
                receive_results(websocket)
            )

            # Wait for both tasks
            await asyncio.gather(send_task, receive_task)

    except Exception as e:
        print(f"Failed to connect to server: {e}")


if __name__ == "__main__":
    asyncio.run(main())
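The client's `chunk_size_bytes = 3200` follows directly from the audio format documented in the README (16 kHz, 16-bit PCM, mono). A quick sketch of the arithmetic, including the inverse `bytes // 32` duration estimate used on the server side:

```python
# 100 ms of 16 kHz, 16-bit (2-byte) mono PCM
sample_rate_hz = 16000
bytes_per_sample = 2
chunk_duration_s = 0.1

chunk_size_bytes = int(sample_rate_hz * bytes_per_sample * chunk_duration_s)
print(chunk_size_bytes)  # 3200

# The server inverts this when estimating duration: 16 kHz * 2 bytes
# = 32 bytes per millisecond, so bytes // 32 gives milliseconds.
print(chunk_size_bytes // 32)  # 100
```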
src/config.py (new file, 130 lines)
@@ -0,0 +1,130 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Configuration module - handles command-line arguments and settings
"""

import argparse


def parse_args():
    """
    Parse command-line arguments

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = argparse.ArgumentParser(description="FunASR WebSocket server")

    # Server configuration
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Server host address, e.g. localhost, 0.0.0.0"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=10095,
        help="WebSocket server port"
    )

    # SSL configuration
    parser.add_argument(
        "--certfile",
        type=str,
        default="",
        help="Path to the SSL certificate file"
    )
    parser.add_argument(
        "--keyfile",
        type=str,
        default="",
        help="Path to the SSL key file"
    )

    # Offline ASR model configuration
    parser.add_argument(
        "--asr_model",
        type=str,
        default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        help="Offline ASR model (from ModelScope)"
    )
    parser.add_argument(
        "--asr_model_revision",
        type=str,
        default="v2.0.4",
        help="Offline ASR model revision"
    )

    # Online ASR model configuration
    parser.add_argument(
        "--asr_model_online",
        type=str,
        default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
        help="Online ASR model (from ModelScope)"
    )
    parser.add_argument(
        "--asr_model_online_revision",
        type=str,
        default="v2.0.4",
        help="Online ASR model revision"
    )

    # VAD model configuration
    parser.add_argument(
        "--vad_model",
        type=str,
        default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        help="VAD (voice activity detection) model (from ModelScope)"
    )
    parser.add_argument(
        "--vad_model_revision",
        type=str,
        default="v2.0.4",
        help="VAD model revision"
    )

    # Punctuation model configuration
    parser.add_argument(
        "--punc_model",
        type=str,
        default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
        help="Punctuation model (from ModelScope)"
    )
    parser.add_argument(
        "--punc_model_revision",
        type=str,
        default="v2.0.4",
        help="Punctuation model revision"
    )

    # Hardware configuration
    parser.add_argument(
        "--ngpu",
        type=int,
        default=1,
        help="Number of GPUs; 0 means CPU only"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device type: cuda or cpu"
    )
    parser.add_argument(
        "--ncpu",
        type=int,
        default=4,
        help="Number of CPU cores"
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print("Configuration:")
    for arg in vars(args):
        print(f"  {arg}: {getattr(args, arg)}")
src/models.py (new file, 79 lines)
@@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Model loading module - loads the speech recognition models
"""


def load_models(args):
    """
    Load all required models

    Args:
        args: command-line arguments containing the model configuration

    Returns:
        dict: the loaded models
    """
    try:
        # Import the FunASR library
        from funasr import AutoModel
    except ImportError:
        raise ImportError("funasr not found, please install it first: pip install funasr")

    # Model registry
    models = {}

    # 1. Offline ASR model
    print(f"Loading offline ASR model: {args.asr_model}")
    models["asr"] = AutoModel(
        model=args.asr_model,
        model_revision=args.asr_model_revision,
        ngpu=args.ngpu,
        ncpu=args.ncpu,
        device=args.device,
        disable_pbar=True,
        disable_log=True,
    )

    # 2. Online (streaming) ASR model
    print(f"Loading online ASR model: {args.asr_model_online}")
    models["asr_streaming"] = AutoModel(
        model=args.asr_model_online,
        model_revision=args.asr_model_online_revision,
        ngpu=args.ngpu,
        ncpu=args.ncpu,
        device=args.device,
        disable_pbar=True,
        disable_log=True,
    )

    # 3. VAD model
    print(f"Loading VAD model: {args.vad_model}")
    models["vad"] = AutoModel(
        model=args.vad_model,
        model_revision=args.vad_model_revision,
        ngpu=args.ngpu,
        ncpu=args.ncpu,
        device=args.device,
        disable_pbar=True,
        disable_log=True,
    )

    # 4. Punctuation model (if configured)
    if args.punc_model:
        print(f"Loading punctuation model: {args.punc_model}")
        models["punc"] = AutoModel(
            model=args.punc_model,
            model_revision=args.punc_model_revision,
            ngpu=args.ngpu,
            ncpu=args.ncpu,
            device=args.device,
            disable_pbar=True,
            disable_log=True,
        )
    else:
        models["punc"] = None
        print("No punctuation model configured; punctuation will be disabled")

    print("All models loaded")
    return models
src/server.py (new file, 242 lines)
@@ -0,0 +1,242 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Entry point for the speech recognition WebSocket server.
Provides real-time speech recognition over WebSocket, supporting both
offline and online recognition modes.
"""

import asyncio
import json
import websockets
import ssl

from config import parse_args
from models import load_models
from service import ASRService

# Global set of currently connected WebSocket clients
websocket_users = set()


async def ws_reset(websocket):
    """Reset the connection state and close the WebSocket"""
    print(f"Resetting WebSocket connection, current connections: {len(websocket_users)}")

    # Reset the per-connection state dictionaries
    websocket.status_dict_asr_online["cache"] = {}
    websocket.status_dict_asr_online["is_final"] = True
    websocket.status_dict_vad["cache"] = {}
    websocket.status_dict_vad["is_final"] = True
    websocket.status_dict_punc["cache"] = {}

    # Close the connection
    await websocket.close()


async def clear_websocket():
    """Reset and remove all WebSocket connections"""
    for websocket in websocket_users:
        await ws_reset(websocket)
    websocket_users.clear()


async def ws_serve(websocket, path):
    """
    Main WebSocket handler; processes client connections and messages

    Args:
        websocket: the WebSocket connection
        path: the connection path
    """
    frames = []             # all audio frames
    frames_asr = []         # frames buffered for offline ASR
    frames_asr_online = []  # frames buffered for online ASR

    global websocket_users
    # await clear_websocket()  # reset existing connections (disabled to allow multiple clients)

    # Register the client
    websocket_users.add(websocket)

    # Initialize per-connection state
    websocket.status_dict_asr = {}
    websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
    websocket.status_dict_vad = {"cache": {}, "is_final": False}
    websocket.status_dict_punc = {"cache": {}}
    websocket.chunk_interval = 10
    websocket.vad_pre_idx = 0
    websocket.is_speaking = True  # assume the user is speaking by default

    # Voice activity state
    speech_start = False
    speech_end_i = -1

    # Default configuration
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"  # two-pass recognition by default

    print("New client connected", flush=True)

    try:
        # Receive client messages continuously
        async for message in websocket:
            # Handle JSON configuration messages
            if isinstance(message, str):
                try:
                    messagejson = json.loads(message)

                    # Update configuration parameters
                    if "is_speaking" in messagejson:
                        websocket.is_speaking = messagejson["is_speaking"]
                        websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
                    if "chunk_interval" in messagejson:
                        websocket.chunk_interval = messagejson["chunk_interval"]
                    if "wav_name" in messagejson:
                        websocket.wav_name = messagejson.get("wav_name")
                    if "chunk_size" in messagejson:
                        chunk_size = messagejson["chunk_size"]
                        if isinstance(chunk_size, str):
                            chunk_size = chunk_size.split(",")
                        websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
                    if "encoder_chunk_look_back" in messagejson:
                        websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
                    if "decoder_chunk_look_back" in messagejson:
                        websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
                    if "hotword" in messagejson:
                        websocket.status_dict_asr["hotword"] = messagejson["hotword"]
                    if "mode" in messagejson:
                        websocket.mode = messagejson["mode"]
                except json.JSONDecodeError:
                    print(f"Invalid JSON message: {message}")

            # Derive the VAD chunk size from chunk_interval
            websocket.status_dict_vad["chunk_size"] = int(
                websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
            )

            # Handle audio data. The offline flush below also runs after control
            # messages, so that it can fire when the client stops speaking.
            if not isinstance(message, str):  # binary audio data
                # Append to the frame buffer
                frames.append(message)
                duration_ms = len(message) // 32  # 16 kHz, 16-bit mono: 32 bytes per ms
                websocket.vad_pre_idx += duration_ms

                # Online ASR
                frames_asr_online.append(message)
                websocket.status_dict_asr_online["is_final"] = speech_end_i != -1

                # Run online ASR every chunk_interval frames, or on the final frame
                if (len(frames_asr_online) % websocket.chunk_interval == 0 or
                        websocket.status_dict_asr_online["is_final"]):
                    if websocket.mode == "2pass" or websocket.mode == "online":
                        audio_in = b"".join(frames_asr_online)
                        try:
                            await asr_service.async_asr_online(websocket, audio_in)
                        except Exception as e:
                            print(f"Online ASR error: {e}")
                    frames_asr_online = []

                # While speech is active, collect frames for offline ASR
                if speech_start:
                    frames_asr.append(message)

                # VAD - voice activity detection
                try:
                    speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
                except Exception as e:
                    print(f"VAD error: {e}")

                # Speech onset detected
                if speech_start_i != -1:
                    speech_start = True
                    # Compute the onset offset and back-fill the preceding frames
                    beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
                    frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
                    frames_asr = []
                    frames_asr.extend(frames_pre)

            # Offline ASR (speech ended, or the user stopped speaking)
            if speech_end_i != -1 or not websocket.is_speaking:
                if websocket.mode == "2pass" or websocket.mode == "offline":
                    audio_in = b"".join(frames_asr)
                    try:
                        await asr_service.async_asr(websocket, audio_in)
                    except Exception as e:
                        print(f"Offline ASR error: {e}")

                # Reset state
                frames_asr = []
                speech_start = False
                frames_asr_online = []
                websocket.status_dict_asr_online["cache"] = {}

                # If the user stopped speaking, reset everything
                if not websocket.is_speaking:
                    websocket.vad_pre_idx = 0
                    frames = []
                    websocket.status_dict_vad["cache"] = {}
                else:
                    # Keep the most recent frames for the next round
                    frames = frames[-20:]

    except websockets.ConnectionClosed:
        print("Connection closed...", flush=True)
        await ws_reset(websocket)
        websocket_users.remove(websocket)
    except websockets.InvalidState:
        print("Invalid WebSocket state...")
    except Exception as e:
        print(f"Unexpected exception: {e}")


def start_server(args, asr_service_instance):
    """
    Start the WebSocket server

    Args:
        args: command-line arguments
        asr_service_instance: the ASR service instance
    """
    global asr_service
    asr_service = asr_service_instance

    # Configure SSL if a certificate was provided
    if args.certfile and len(args.certfile) > 0:
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)

        server = websockets.serve(
            ws_serve, args.host, args.port,
            subprotocols=["binary"],
            ping_interval=None,
            ssl=ssl_context
        )
    else:
        server = websockets.serve(
            ws_serve, args.host, args.port,
            subprotocols=["binary"],
            ping_interval=None
        )

    print(f"WebSocket server started - listening on {args.host}:{args.port}")

    # Run the event loop
    asyncio.get_event_loop().run_until_complete(server)
    asyncio.get_event_loop().run_forever()


if __name__ == "__main__":
    # Parse command-line arguments
    args = parse_args()

    # Load models
    print("Loading models...")
    models = load_models(args)
    print("Models loaded. Note: only one concurrent client is currently supported!")

    # Create the ASR service
    asr_service = ASRService(models)

    # Start the server
    start_server(args, asr_service)
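The VAD chunk size computed in `ws_serve` couples three client-tunable values: the decoder half of `chunk_size`, a 60 ms frame unit, and `chunk_interval`. A standalone sketch of that formula (the helper name is mine; the server computes this inline):

```python
def vad_chunk_ms(asr_chunk_size, chunk_interval):
    # chunk_size[1] is the decoder chunk size; each unit corresponds to 60 ms
    # of audio, split across chunk_interval incoming frames.
    return int(asr_chunk_size[1] * 60 / chunk_interval)

# With the defaults ("5,10" parsed to [5, 10], chunk_interval = 10):
print(vad_chunk_ms([5, 10], 10))  # 60

# The fallback used when no chunk_size was configured, [0, 10], gives the same:
print(vad_chunk_ms([0, 10], 10))  # 60
```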
src/service.py (new file, 127 lines)
@@ -0,0 +1,127 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
ASR service module - core speech recognition functionality
"""

import json


class ASRService:
    """ASR service class wrapping the speech recognition functionality"""

    def __init__(self, models):
        """
        Initialize the ASR service

        Args:
            models: dictionary of preloaded models
        """
        self.model_asr = models["asr"]
        self.model_asr_streaming = models["asr_streaming"]
        self.model_vad = models["vad"]
        self.model_punc = models["punc"]

    async def async_vad(self, websocket, audio_in):
        """
        Voice activity detection

        Args:
            websocket: WebSocket connection
            audio_in: binary audio data

        Returns:
            tuple: (speech_start, speech_end) positions of the detected
            speech onset and offset, or -1 when not detected
        """
        # Run the VAD model on the audio segment
        segments_result = self.model_vad.generate(
            input=audio_in,
            **websocket.status_dict_vad
        )[0]["value"]

        speech_start = -1
        speech_end = -1

        # Parse the VAD result; only a single segment is interpretable here
        if len(segments_result) == 0 or len(segments_result) > 1:
            return speech_start, speech_end

        if segments_result[0][0] != -1:
            speech_start = segments_result[0][0]
        if segments_result[0][1] != -1:
            speech_end = segments_result[0][1]

        return speech_start, speech_end

    async def async_asr(self, websocket, audio_in):
        """
        Offline ASR

        Args:
            websocket: WebSocket connection
            audio_in: binary audio data
        """
        if len(audio_in) > 0:
            # Run the offline ASR model on the audio
            rec_result = self.model_asr.generate(
                input=audio_in,
                **websocket.status_dict_asr
            )[0]

            # If a punctuation model is loaded and text was recognized, add punctuation
            if self.model_punc is not None and len(rec_result["text"]) > 0:
                rec_result = self.model_punc.generate(
                    input=rec_result["text"],
                    **websocket.status_dict_punc
                )[0]

            # If text was recognized, send it to the client
            if len(rec_result["text"]) > 0:
                mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
                message = json.dumps({
                    "mode": mode,
                    "text": rec_result["text"],
                    "wav_name": websocket.wav_name,
                    "is_final": websocket.is_speaking,
                })
                await websocket.send(message)
        else:
            # No audio data: send an empty result
            mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
            message = json.dumps({
                "mode": mode,
                "text": "",
                "wav_name": websocket.wav_name,
                "is_final": websocket.is_speaking,
            })
            await websocket.send(message)

    async def async_asr_online(self, websocket, audio_in):
        """
        Online (streaming) ASR

        Args:
            websocket: WebSocket connection
            audio_in: binary audio data
        """
        if len(audio_in) > 0:
            # Run the streaming ASR model on the audio
            rec_result = self.model_asr_streaming.generate(
                input=audio_in,
                **websocket.status_dict_asr_online
            )[0]

            # In 2pass mode, skip the final frame (it is handled by offline ASR)
            if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
                return

            # If text was recognized, send it to the client
            if len(rec_result["text"]):
                mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
                message = json.dumps({
                    "mode": mode,
                    "text": rec_result["text"],
                    "wav_name": websocket.wav_name,
                    "is_final": websocket.is_speaking,
                })
                await websocket.send(message)
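`async_vad` relies on the convention that the streaming VAD result is a list of `[start_ms, end_ms]` pairs in which `-1` marks a boundary not yet seen in this chunk. A standalone sketch of that parsing step (the function name is mine; the service does this inline):

```python
def parse_vad_segments(segments_result):
    """Interpret a streaming VAD result of the form [[start_ms, end_ms]];
    -1 marks a boundary that has not been detected yet."""
    speech_start = -1
    speech_end = -1
    # Anything other than exactly one segment is not interpretable here
    if len(segments_result) != 1:
        return speech_start, speech_end
    if segments_result[0][0] != -1:
        speech_start = segments_result[0][0]
    if segments_result[0][1] != -1:
        speech_end = segments_result[0][1]
    return speech_start, speech_end

print(parse_vad_segments([[120, -1]]))  # (120, -1): speech began at 120 ms
print(parse_vad_segments([[-1, 870]]))  # (-1, 870): speech ended at 870 ms
print(parse_vad_segments([]))           # (-1, -1): no boundary in this chunk
```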
tests/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Tests for the FunASR WebSocket service"""
tests/test_config.py (new file, 69 lines)
@@ -0,0 +1,69 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Tests for the configuration module
"""

import sys
import os
from unittest.mock import patch

# Add the project root to the path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.config import parse_args


def test_default_args():
    """Test the default argument values"""
    with patch('sys.argv', ['script.py']):
        args = parse_args()

        # Server parameters
        assert args.host == "0.0.0.0"
        assert args.port == 10095

        # SSL parameters
        assert args.certfile == ""
        assert args.keyfile == ""

        # Model parameters
        assert "paraformer" in args.asr_model
        assert args.asr_model_revision == "v2.0.4"
        assert "paraformer" in args.asr_model_online
        assert args.asr_model_online_revision == "v2.0.4"
        assert "vad" in args.vad_model
        assert args.vad_model_revision == "v2.0.4"
        assert "punc" in args.punc_model
        assert args.punc_model_revision == "v2.0.4"

        # Hardware parameters
        assert args.ngpu == 1
        assert args.device == "cuda"
        assert args.ncpu == 4


def test_custom_args():
    """Test custom argument values"""
    test_args = [
        'script.py',
        '--host', 'localhost',
        '--port', '8080',
        '--certfile', 'cert.pem',
        '--keyfile', 'key.pem',
        '--asr_model', 'custom_model',
        '--ngpu', '0',
        '--device', 'cpu'
    ]

    with patch('sys.argv', test_args):
        args = parse_args()

        # Custom parameters
        assert args.host == "localhost"
        assert args.port == 8080
        assert args.certfile == "cert.pem"
        assert args.keyfile == "key.pem"
        assert args.asr_model == "custom_model"
        assert args.ngpu == 0
        assert args.device == "cpu"