Compare commits
16 Commits
master
...
feature_re
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5a820b49e4 | ||
![]() |
5b94c40016 | ||
![]() |
3d8bf9de25 | ||
![]() |
ff9bd70039 | ||
![]() |
4e9e94d8dc | ||
![]() |
b569b7e63d | ||
![]() |
f245c6e9df | ||
![]() |
49cb428c23 | ||
![]() |
703a40e955 | ||
![]() |
040fc57e02 | ||
![]() |
1392168126 | ||
![]() |
eff22cb33e | ||
![]() |
66c9477e4b | ||
9d522fa137 | |||
f7138dcb39 | |||
8b69ff195f |
19
.cursorrules
Normal file
19
.cursorrules
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
|
||||||
|
You are an AI assistant specialized in Python development. Your approach emphasizes:
|
||||||
|
|
||||||
|
1. Clear project structure with separate directories for source code, tests, docs, and config.
|
||||||
|
2. Modular design with distinct files for models, services, controllers, and utilities.
|
||||||
|
3. Configuration management using environment variables.
|
||||||
|
4. Robust error handling and logging, including context capture.
|
||||||
|
5. Comprehensive testing with pytest.
|
||||||
|
6. Detailed documentation using docstrings and README files.
|
||||||
|
7. Dependency management via https://github.com/astral-sh/rye and virtual environments.
|
||||||
|
8. Code style consistency using Ruff.
|
||||||
|
9. CI/CD implementation with GitHub Actions or GitLab CI.
|
||||||
|
10. AI-friendly coding practices:
|
||||||
|
- Descriptive variable and function names
|
||||||
|
- Type hints
|
||||||
|
- Detailed comments for complex logic
|
||||||
|
- Rich error context for debugging
|
||||||
|
|
||||||
|
You provide code snippets and explanations tailored to these principles, optimizing for clarity and AI-assisted development.
|
@ -108,6 +108,3 @@ docker-compose up -d
|
|||||||
"is_final": false // 是否是最终结果
|
"is_final": false // 是否是最终结果
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 许可证
|
|
||||||
[MIT](LICENSE)
|
|
30
main.py
Normal file
30
main.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from funasr import AutoModel
|
||||||
|
|
||||||
|
chunk_size = 200 # ms
|
||||||
|
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
wav_file = "tests/vad_example.wav"
|
||||||
|
speech, sample_rate = soundfile.read(wav_file)
|
||||||
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
|
||||||
|
for i in range(total_chunk_num):
|
||||||
|
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||||
|
is_final = i == total_chunk_num - 1
|
||||||
|
res = model.generate(
|
||||||
|
input=speech_chunk,
|
||||||
|
cache=cache,
|
||||||
|
is_final=is_final,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
disable_pbar=True,
|
||||||
|
)
|
||||||
|
if len(res[0]["value"]):
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
print(f"len(speech): {len(speech)}")
|
||||||
|
print(f"len(speech_chunk): {len(speech_chunk)}")
|
||||||
|
print(f"total_chunk_num: {total_chunk_num}")
|
||||||
|
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")
|
@ -11,4 +11,4 @@ FunASR WebSocket服务
|
|||||||
- 支持多种识别模式(2pass/online/offline)
|
- 支持多种识别模式(2pass/online/offline)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
194
src/audio_chunk.py
Normal file
194
src/audio_chunk.py
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
音频数据块管理类 - 用于存储和处理16KHz音频数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Dict
|
||||||
|
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBinary:
|
||||||
|
"""
|
||||||
|
音频数据存储单元
|
||||||
|
用于存储二进制数据
|
||||||
|
面向Slice, 向Slice提供数据与接口
|
||||||
|
|
||||||
|
self._audio_config: AudioBinary_Config -- 音频参数配置
|
||||||
|
self._binary_data_list: AudioBinary_data_list -- 音频数据列表
|
||||||
|
self._slice_listener: List[callable] -- 切片监听器
|
||||||
|
|
||||||
|
AudioBinary_Config: Dict -- 音频参数配置
|
||||||
|
AudioBinary_data_list: List[bytes] -- 音频数据列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args):
|
||||||
|
"""
|
||||||
|
初始化音频数据块
|
||||||
|
参数:
|
||||||
|
*args: 可变参数
|
||||||
|
"""
|
||||||
|
# 音频参数配置
|
||||||
|
self._audio_config = AudioBinary_Config()
|
||||||
|
# 音频片段
|
||||||
|
self._binary_data_list: AudioBinary_data_list = AudioBinary_data_list()
|
||||||
|
# 切片监听器
|
||||||
|
self._slice_listener: List = []
|
||||||
|
if isinstance(args, Dict):
|
||||||
|
self._audio_config = AudioBinary_Config.AudioBinary_Config_from_dict(args)
|
||||||
|
elif isinstance(args, AudioBinary_Config):
|
||||||
|
self._audio_config = args
|
||||||
|
else:
|
||||||
|
raise ValueError("参数类型错误")
|
||||||
|
|
||||||
|
def add_slice_listener(self, slice_listener: callable) -> None:
|
||||||
|
"""
|
||||||
|
添加切片监听器
|
||||||
|
参数:
|
||||||
|
slice_listener: callable -- 切片监听器
|
||||||
|
"""
|
||||||
|
self._slice_listener.append(slice_listener)
|
||||||
|
|
||||||
|
def __add__(self, other: bytes):
|
||||||
|
"""
|
||||||
|
__add__ 是 "+" 运算符的重载,
|
||||||
|
使用方法:
|
||||||
|
audio_binary = audio_binary + bytes
|
||||||
|
添加音频数据块 与 add_binary_data 等效,
|
||||||
|
但可以链式调用, 方便使用
|
||||||
|
参数:
|
||||||
|
other: bytes --音频数据块
|
||||||
|
"""
|
||||||
|
self._binary_data_list.append(other)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __iadd__(self, other: bytes):
|
||||||
|
"""
|
||||||
|
__iadd__ 是 "+=" 运算符的重载,
|
||||||
|
使用方法:
|
||||||
|
audio_binary += bytes
|
||||||
|
添加音频数据块 与 add_binary_data 等效,
|
||||||
|
但可以链式调用, 方便使用
|
||||||
|
参数:
|
||||||
|
other: bytes --音频数据块
|
||||||
|
"""
|
||||||
|
self._binary_data_list.append(other)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_binary_data(self, binary_data: bytes):
|
||||||
|
"""
|
||||||
|
添加音频数据块
|
||||||
|
参数:
|
||||||
|
binary_data: bytes --音频数据块
|
||||||
|
"""
|
||||||
|
self._binary_data_list.append(binary_data)
|
||||||
|
|
||||||
|
def rewrite_binary_data(self, target_index: int, binary_data: bytes):
|
||||||
|
"""
|
||||||
|
重写音频数据块
|
||||||
|
参数:
|
||||||
|
target_index: int -- 目标索引
|
||||||
|
binary_data: bytes --音频数据块
|
||||||
|
"""
|
||||||
|
self._binary_data_list.rewrite(target_index, binary_data)
|
||||||
|
|
||||||
|
def get_binary_data(
|
||||||
|
self,
|
||||||
|
start: int = 0,
|
||||||
|
end: Optional[int] = None,
|
||||||
|
) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
获取指定索引的音频数据块
|
||||||
|
参数:
|
||||||
|
start: 开始索引
|
||||||
|
end: 结束索引
|
||||||
|
返回:
|
||||||
|
List[bytes]: 音频数据块
|
||||||
|
"""
|
||||||
|
if start >= len(self._binary_data_list):
|
||||||
|
return None
|
||||||
|
if end is None:
|
||||||
|
end = start + 1
|
||||||
|
end = min(end, len(self._binary_data_list))
|
||||||
|
return self._binary_data_list[start:end]
|
||||||
|
|
||||||
|
|
||||||
|
class AudioChunk:
|
||||||
|
"""
|
||||||
|
音频数据块管理类
|
||||||
|
管理两部分内容, AudioBinary和Slice。
|
||||||
|
AudioBinary用于内部存储字节数据。
|
||||||
|
Slice是AudioBinary的切片,用于外部接口。
|
||||||
|
|
||||||
|
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional[AudioChunk] = None
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
单例模式
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""
|
||||||
|
初始化AudioChunk实例
|
||||||
|
"""
|
||||||
|
self._audio_binary_list: Dict[str, AudioBinary] = {}
|
||||||
|
self._slice_listener: List[callable] = []
|
||||||
|
|
||||||
|
def get_audio_binary(
|
||||||
|
self,
|
||||||
|
binary_name: Optional[str] = None,
|
||||||
|
audio_config: Optional[AudioBinary_Config] = None,
|
||||||
|
) -> AudioBinary:
|
||||||
|
"""
|
||||||
|
获取音频数据块
|
||||||
|
参数:
|
||||||
|
binary_name: str -- 音频数据块名称
|
||||||
|
返回:
|
||||||
|
AudioBinary: 音频数据块
|
||||||
|
"""
|
||||||
|
if binary_name is None:
|
||||||
|
binary_name = "default"
|
||||||
|
if binary_name not in self._audio_binary_list:
|
||||||
|
self._audio_binary_list[binary_name] = AudioBinary(audio_config)
|
||||||
|
return self._audio_binary_list[binary_name]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _time2size(time_ms: int, audio_config: AudioBinary_Config) -> int:
|
||||||
|
"""
|
||||||
|
将时间(ms)转换为数据大小(字节)
|
||||||
|
参数:
|
||||||
|
time_ms: int -- 时间(ms)
|
||||||
|
audio_config: AudioBinary_Config -- 音频参数配置
|
||||||
|
返回:
|
||||||
|
int: 数据大小(字节)
|
||||||
|
"""
|
||||||
|
# 时间(ms)到字节(bytes)计算方法为: 时间(ms) * 采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24) / 1000
|
||||||
|
time_s = time_ms / 1000
|
||||||
|
bytes_per_sample = audio_config.sample_width * audio_config.channel
|
||||||
|
return int(time_s * audio_config.sample_rate * bytes_per_sample)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _size2time(size: int, audio_config: AudioBinary_Config) -> int:
|
||||||
|
"""
|
||||||
|
将数据大小(字节)转换为时间(ms)
|
||||||
|
参数:
|
||||||
|
size: int -- 数据大小(字节)
|
||||||
|
audio_config: AudioBinary_Config -- 音频参数配置
|
||||||
|
返回:
|
||||||
|
int: 时间(ms)
|
||||||
|
"""
|
||||||
|
# 字节(bytes)到时间(ms)计算方法为: 字节(bytes) * 1000 / (采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24))
|
||||||
|
bytes_per_sample = audio_config.sample_width * audio_config.channel
|
||||||
|
time_ms = size * 1000 // (audio_config.sample_rate * bytes_per_sample)
|
||||||
|
return time_ms
|
196
src/client.py
196
src/client.py
@ -1,196 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
WebSocket客户端示例 - 用于测试语音识别服务
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import websockets
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
|
||||||
import wave
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
"""解析命令行参数"""
|
|
||||||
parser = argparse.ArgumentParser(description="FunASR WebSocket客户端")
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
type=str,
|
|
||||||
default="localhost",
|
|
||||||
help="服务器主机地址"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
default=10095,
|
|
||||||
help="服务器端口"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--audio_file",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="要识别的音频文件路径"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--mode",
|
|
||||||
type=str,
|
|
||||||
default="2pass",
|
|
||||||
choices=["2pass", "online", "offline"],
|
|
||||||
help="识别模式: 2pass(默认), online, offline"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--chunk_size",
|
|
||||||
type=str,
|
|
||||||
default="5,10",
|
|
||||||
help="分块大小, 格式为'encoder_size,decoder_size'"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_ssl",
|
|
||||||
action="store_true",
|
|
||||||
help="是否使用SSL连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
async def send_audio(websocket, audio_file, mode, chunk_size):
|
|
||||||
"""
|
|
||||||
发送音频文件到服务器进行识别
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
audio_file: 音频文件路径
|
|
||||||
mode: 识别模式
|
|
||||||
chunk_size: 分块大小
|
|
||||||
"""
|
|
||||||
# 打开并读取WAV文件
|
|
||||||
with wave.open(audio_file, "rb") as wav_file:
|
|
||||||
params = wav_file.getparams()
|
|
||||||
frames = wav_file.readframes(wav_file.getnframes())
|
|
||||||
|
|
||||||
# 音频文件信息
|
|
||||||
print(f"音频文件: {os.path.basename(audio_file)}")
|
|
||||||
print(f"采样率: {params.framerate}Hz, 通道数: {params.nchannels}")
|
|
||||||
print(f"采样位深: {params.sampwidth * 8}位, 总帧数: {params.nframes}")
|
|
||||||
|
|
||||||
# 设置配置参数
|
|
||||||
config = {
|
|
||||||
"mode": mode,
|
|
||||||
"chunk_size": chunk_size,
|
|
||||||
"wav_name": os.path.basename(audio_file),
|
|
||||||
"is_speaking": True
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送配置
|
|
||||||
await websocket.send(json.dumps(config))
|
|
||||||
|
|
||||||
# 模拟实时发送音频数据
|
|
||||||
chunk_size_bytes = 3200 # 每次发送100ms的16kHz音频
|
|
||||||
total_chunks = len(frames) // chunk_size_bytes
|
|
||||||
|
|
||||||
print(f"开始发送音频数据,共 {total_chunks} 个数据块...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
for i in range(0, len(frames), chunk_size_bytes):
|
|
||||||
chunk = frames[i:i+chunk_size_bytes]
|
|
||||||
await websocket.send(chunk)
|
|
||||||
|
|
||||||
# 模拟实时,每100ms发送一次
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
# 显示进度
|
|
||||||
if (i // chunk_size_bytes) % 10 == 0:
|
|
||||||
print(f"已发送 {i // chunk_size_bytes}/{total_chunks} 数据块")
|
|
||||||
|
|
||||||
# 发送结束信号
|
|
||||||
await websocket.send(json.dumps({"is_speaking": False}))
|
|
||||||
print("音频数据发送完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"发送音频时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def receive_results(websocket):
|
|
||||||
"""
|
|
||||||
接收并显示识别结果
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
"""
|
|
||||||
online_text = ""
|
|
||||||
offline_text = ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for message in websocket:
|
|
||||||
# 解析服务器返回的JSON消息
|
|
||||||
result = json.loads(message)
|
|
||||||
|
|
||||||
mode = result.get("mode", "")
|
|
||||||
text = result.get("text", "")
|
|
||||||
is_final = result.get("is_final", False)
|
|
||||||
|
|
||||||
# 根据模式更新文本
|
|
||||||
if "online" in mode:
|
|
||||||
online_text = text
|
|
||||||
print(f"\r[在线识别] {online_text}", end="", flush=True)
|
|
||||||
elif "offline" in mode:
|
|
||||||
offline_text = text
|
|
||||||
print(f"\n[离线识别] {offline_text}")
|
|
||||||
|
|
||||||
# 如果是最终结果,打印完整信息
|
|
||||||
if is_final and offline_text:
|
|
||||||
print("\n最终识别结果:")
|
|
||||||
print(f"[离线识别] {offline_text}")
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"接收结果时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""主函数"""
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
# WebSocket URI
|
|
||||||
protocol = "wss" if args.use_ssl else "ws"
|
|
||||||
uri = f"{protocol}://{args.host}:{args.port}"
|
|
||||||
|
|
||||||
print(f"连接到服务器: {uri}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建WebSocket连接
|
|
||||||
async with websockets.connect(
|
|
||||||
uri,
|
|
||||||
subprotocols=["binary"]
|
|
||||||
) as websocket:
|
|
||||||
|
|
||||||
print("连接成功")
|
|
||||||
|
|
||||||
# 创建两个任务: 发送音频和接收结果
|
|
||||||
send_task = asyncio.create_task(
|
|
||||||
send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
receive_task = asyncio.create_task(
|
|
||||||
receive_results(websocket)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 等待任务完成
|
|
||||||
await asyncio.gather(send_task, receive_task)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"连接服务器失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 运行主函数
|
|
||||||
asyncio.run(main())
|
|
@ -10,116 +10,79 @@ import argparse
|
|||||||
def parse_args():
|
def parse_args():
|
||||||
"""
|
"""
|
||||||
解析命令行参数
|
解析命令行参数
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
argparse.Namespace: 解析后的参数对象
|
argparse.Namespace: 解析后的参数对象
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")
|
parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")
|
||||||
|
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host",
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
default="0.0.0.0",
|
default="0.0.0.0",
|
||||||
help="服务器主机地址,例如:localhost, 0.0.0.0"
|
help="服务器主机地址,例如:localhost, 0.0.0.0",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
default=10095,
|
|
||||||
help="WebSocket服务器端口"
|
|
||||||
)
|
|
||||||
|
|
||||||
# SSL配置
|
# SSL配置
|
||||||
parser.add_argument(
|
parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
|
||||||
"--certfile",
|
parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="SSL证书文件路径"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--keyfile",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="SSL密钥文件路径"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ASR模型配置
|
# ASR模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model",
|
"--asr_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
help="离线ASR模型(从ModelScope获取)"
|
help="离线ASR模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model_revision",
|
"--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="离线ASR模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 在线ASR模型配置
|
# 在线ASR模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model_online",
|
"--asr_model_online",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
||||||
help="在线ASR模型(从ModelScope获取)"
|
help="在线ASR模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model_online_revision",
|
"--asr_model_online_revision",
|
||||||
type=str,
|
type=str,
|
||||||
default="v2.0.4",
|
default="v2.0.4",
|
||||||
help="在线ASR模型版本"
|
help="在线ASR模型版本",
|
||||||
)
|
)
|
||||||
|
|
||||||
# VAD模型配置
|
# VAD模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad_model",
|
"--vad_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
help="VAD语音活动检测模型(从ModelScope获取)"
|
help="VAD语音活动检测模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad_model_revision",
|
"--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="VAD模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 标点符号模型配置
|
# 标点符号模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--punc_model",
|
"--punc_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
|
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
|
||||||
help="标点符号模型(从ModelScope获取)"
|
help="标点符号模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--punc_model_revision",
|
"--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="标点符号模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 硬件配置
|
# 硬件配置
|
||||||
|
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量,0表示仅使用CPU")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ngpu",
|
"--device", type=str, default="cuda", help="设备类型:cuda或cpu"
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="GPU数量,0表示仅使用CPU"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")
|
||||||
"--device",
|
|
||||||
type=str,
|
|
||||||
default="cuda",
|
|
||||||
help="设备类型:cuda或cpu"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ncpu",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="CPU核心数"
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -127,4 +90,4 @@ if __name__ == "__main__":
|
|||||||
args = parse_args()
|
args = parse_args()
|
||||||
print("配置参数:")
|
print("配置参数:")
|
||||||
for arg in vars(args):
|
for arg in vars(args):
|
||||||
print(f" {arg}: {getattr(args, arg)}")
|
print(f" {arg}: {getattr(args, arg)}")
|
||||||
|
4
src/functor/__init__.py
Normal file
4
src/functor/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .vad_functor import VADFunctor
|
||||||
|
from .base import FunctorFactory
|
||||||
|
|
||||||
|
__all__ = ["VADFunctor", "FunctorFactory"]
|
161
src/functor/asr_functor.py
Normal file
161
src/functor/asr_functor.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
"""
|
||||||
|
ASRFunctor
|
||||||
|
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
|
||||||
|
from typing import Callable, List
|
||||||
|
from queue import Queue, Empty
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ASRFunctor(BaseFunctor):
|
||||||
|
"""
|
||||||
|
ASRFunctor
|
||||||
|
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||||
|
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||||
|
否则无法run()启动线程
|
||||||
|
|
||||||
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||||
|
|
||||||
|
使用stop()停止线程, 但需要等待input_queue为空
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# 资源与配置
|
||||||
|
self._model: dict = {} # 模型
|
||||||
|
self._callback: List[Callable] = [] # 回调函数
|
||||||
|
self._input_queue: Queue = None # 输入队列
|
||||||
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
|
||||||
|
# flag
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
|
||||||
|
# 线程资源
|
||||||
|
self._thread: threading.Thread = None
|
||||||
|
|
||||||
|
# 状态锁
|
||||||
|
self._status_lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
|
# 缓存
|
||||||
|
self._hotwords: List[str] = []
|
||||||
|
|
||||||
|
def reset_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
"""
|
||||||
|
self._audio_config = audio_config
|
||||||
|
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
|
if not isinstance(self._callback, list):
|
||||||
|
self._callback = []
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def _do_callback(self, result: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
回调函数
|
||||||
|
"""
|
||||||
|
text = result[0]["text"].replace(" ", "")
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(text)
|
||||||
|
|
||||||
|
def _process(self, data: VAD_Functor_result) -> None:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
"""
|
||||||
|
binary_data = data.audiobinary_data.binary_data
|
||||||
|
result = self._model["asr"].generate(
|
||||||
|
input=binary_data,
|
||||||
|
chunk_size=self._audio_config.chunk_size,
|
||||||
|
hotwords=self._hotwords,
|
||||||
|
)
|
||||||
|
self._do_callback(result)
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
# 运行逻辑
|
||||||
|
while self._is_running:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(True, timeout=1)
|
||||||
|
self._process(data)
|
||||||
|
self._input_queue.task_done()
|
||||||
|
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
except Empty:
|
||||||
|
if self._stop_event:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# 其他异常
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("ASRFunctor运行时发生错误: %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run(self) -> threading.Thread:
|
||||||
|
"""
|
||||||
|
启动线程
|
||||||
|
Returns:
|
||||||
|
threading.Thread: 返回已运行线程实例
|
||||||
|
"""
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
raise ValueError("模型未设置")
|
||||||
|
if self._audio_config is None:
|
||||||
|
raise ValueError("音频配置未设置")
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise ValueError("输入队列未设置")
|
||||||
|
if self._callback is None:
|
||||||
|
raise ValueError("回调函数未设置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self) -> bool:
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._stop_event = True
|
||||||
|
self._thread.join()
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
return not self._thread.is_alive()
|
191
src/functor/base.py
Normal file
191
src/functor/base.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
"""
|
||||||
|
Functor基础模块
|
||||||
|
|
||||||
|
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。
|
||||||
|
基类提供了数据处理的基本框架,包括:
|
||||||
|
- 回调函数管理
|
||||||
|
- 模型配置管理
|
||||||
|
- 线程运行控制
|
||||||
|
|
||||||
|
主要类:
|
||||||
|
BaseFunctor: Functor抽象类
|
||||||
|
FunctorFactory: Functor工厂类
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Callable, List
|
||||||
|
from queue import Queue
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFunctor(ABC):
|
||||||
|
"""
|
||||||
|
Functor抽象类
|
||||||
|
|
||||||
|
该抽象类规定了所有的Functor类必须实现run()方法启动自身线程
|
||||||
|
|
||||||
|
属性:
|
||||||
|
_callback (Callable): 处理完成后的回调函数
|
||||||
|
_model (dict): 存储模型相关的配置和实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
初始化函数器
|
||||||
|
|
||||||
|
参数:
|
||||||
|
callback (Callable): 处理完成后的回调函数
|
||||||
|
model (dict): 模型相关的配置和实例
|
||||||
|
"""
|
||||||
|
self._callback: List[Callable] = []
|
||||||
|
self._model: dict = {}
|
||||||
|
# flag
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
# 状态锁
|
||||||
|
self._status_lock: threading.Lock = threading.Lock()
|
||||||
|
# 线程资源
|
||||||
|
self._thread: threading.Thread = None
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable):
|
||||||
|
"""
|
||||||
|
添加回调函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
callback (Callable): 新的回调函数
|
||||||
|
"""
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def set_model(self, model: dict):
|
||||||
|
"""
|
||||||
|
设置模型配置
|
||||||
|
|
||||||
|
参数:
|
||||||
|
model (dict): 新的模型配置
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue):
|
||||||
|
"""
|
||||||
|
设置输入队列
|
||||||
|
|
||||||
|
参数:
|
||||||
|
queue (Queue): 新的输入队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _run(self):
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
|
||||||
|
返回:
|
||||||
|
当达到条件时触发callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
启动_run方法线程
|
||||||
|
|
||||||
|
返回:
|
||||||
|
线程实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _pre_check(self):
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
|
||||||
|
返回:
|
||||||
|
预检查结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
|
||||||
|
返回:
|
||||||
|
停止结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class FunctorFactory:
|
||||||
|
"""
|
||||||
|
Functor工厂类
|
||||||
|
|
||||||
|
该工厂类负责创建和配置Functor实例
|
||||||
|
|
||||||
|
主要方法:
|
||||||
|
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||||
|
创建并配置Functor实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建并配置Functor实例
|
||||||
|
|
||||||
|
参数:
|
||||||
|
funtor_name (str): Functor名称
|
||||||
|
config (dict): 配置信息
|
||||||
|
models (dict): 模型信息
|
||||||
|
|
||||||
|
返回:
|
||||||
|
BaseFunctor: 创建的Functor实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
if functor_name == "vad":
|
||||||
|
return cls._make_vadfunctor(config=config, models=models)
|
||||||
|
elif functor_name == "asr":
|
||||||
|
return cls._make_asrfunctor(config=config, models=models)
|
||||||
|
elif functor_name == "spk":
|
||||||
|
return cls._make_spkfunctor(config=config, models=models)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的Functor类型: {functor_name}")
|
||||||
|
|
||||||
|
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建VAD Functor实例
|
||||||
|
"""
|
||||||
|
from src.functor.vad_functor import VADFunctor
|
||||||
|
|
||||||
|
audio_config = config["audio_config"]
|
||||||
|
model = {"vad": models["vad"]}
|
||||||
|
|
||||||
|
vad_functor = VADFunctor()
|
||||||
|
vad_functor.set_audio_config(audio_config)
|
||||||
|
vad_functor.set_model(model)
|
||||||
|
|
||||||
|
return vad_functor
|
||||||
|
|
||||||
|
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建ASR Functor实例
|
||||||
|
"""
|
||||||
|
from src.functor.asr_functor import ASRFunctor
|
||||||
|
|
||||||
|
audio_config = config["audio_config"]
|
||||||
|
model = {"asr": models["asr"]}
|
||||||
|
|
||||||
|
asr_functor = ASRFunctor()
|
||||||
|
asr_functor.set_audio_config(audio_config)
|
||||||
|
asr_functor.set_model(model)
|
||||||
|
|
||||||
|
return asr_functor
|
||||||
|
|
||||||
|
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
|
"""
|
||||||
|
创建SPK Functor实例
|
||||||
|
"""
|
||||||
|
from src.functor.spk_functor import SPKFunctor
|
||||||
|
|
||||||
|
audio_config = config["audio_config"]
|
||||||
|
model = {"spk": models["spk"]}
|
||||||
|
|
||||||
|
spk_functor = SPKFunctor()
|
||||||
|
spk_functor.set_audio_config(audio_config)
|
||||||
|
spk_functor.set_model(model)
|
||||||
|
|
||||||
|
return spk_functor
|
122
src/functor/readme.md
Normal file
122
src/functor/readme.md
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# 对于Functor的解释
|
||||||
|
|
||||||
|
## Functor 文件夹作用
|
||||||
|
|
||||||
|
Functor文件夹用于存放所有功能性的类,包括VAD、PUNC、ASR、SPK等。
|
||||||
|
|
||||||
|
## Functor 类的定义
|
||||||
|
|
||||||
|
所有类应继承于**基类**`BaseFunctor`。
|
||||||
|
|
||||||
|
为了方便使用,我们对于**基类**的定义如下:
|
||||||
|
|
||||||
|
1. 函数内部使用的变量以单下划线开头,基类中包含:
|
||||||
|
|
||||||
|
* _model: Dict 存放模型相关的配置和实例
|
||||||
|
* _input_queue: Queue 监听的输入消息队列
|
||||||
|
* _thread: Threading.Thread 运行的线程实例
|
||||||
|
* _callback: List[Callable] 回调函数列表
|
||||||
|
* _is_running: bool 线程运行状态标志
|
||||||
|
* _stop_event: bool 停止事件标志
|
||||||
|
* _status_lock: threading.Lock 状态锁,用于线程同步
|
||||||
|
|
||||||
|
2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`。
|
||||||
|
|
||||||
|
3. 基类定义的核心方法:
|
||||||
|
|
||||||
|
* `add_callback(callback: Callable)`: 添加结果处理的回调函数
|
||||||
|
* `set_model(model: dict)`: 设置模型配置和实例
|
||||||
|
* `set_input_queue(queue: Queue)`: 设置输入数据队列
|
||||||
|
* `run()`: 启动处理线程(抽象方法)
|
||||||
|
* `stop()`: 停止处理线程(抽象方法)
|
||||||
|
* `_run()`: 线程运行的具体逻辑(抽象方法)
|
||||||
|
* `_pre_check()`: 运行前的预检查(抽象方法)
|
||||||
|
|
||||||
|
## 派生类实现要求
|
||||||
|
|
||||||
|
1. 必须实现的抽象方法:
|
||||||
|
* `_pre_check()`:
|
||||||
|
- 检查必要的配置是否完整(如模型、队列等)
|
||||||
|
- 检查运行环境是否满足要求
|
||||||
|
- 返回检查结果
|
||||||
|
|
||||||
|
* `_run()`:
|
||||||
|
- 实现具体的数据处理逻辑
|
||||||
|
- 从 _input_queue 获取输入数据
|
||||||
|
- 使用 _model 进行处理
|
||||||
|
- 通过 _callback 返回处理结果
|
||||||
|
|
||||||
|
* `run()`:
|
||||||
|
- 调用 _pre_check() 进行预检查
|
||||||
|
- 创建并启动处理线程
|
||||||
|
- 设置相关状态标志
|
||||||
|
|
||||||
|
* `stop()`:
|
||||||
|
- 安全停止处理线程
|
||||||
|
- 清理资源
|
||||||
|
- 重置状态标志
|
||||||
|
|
||||||
|
2. 建议实现的方法:
|
||||||
|
* `__str__`: 返回当前实例的状态信息
|
||||||
|
* 错误处理方法:处理运行过程中的异常情况
|
||||||
|
|
||||||
|
## 使用示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyFunctor(BaseFunctor):
|
||||||
|
def _pre_check(self):
|
||||||
|
if not self._model or not self._input_queue:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
while not self._stop_event:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(timeout=1.0)
|
||||||
|
result = self._model['my_model'].process(data)
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(result)
|
||||||
|
except Queue.Empty:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理错误: {e}")
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
if not self._pre_check():
|
||||||
|
raise RuntimeError("预检查失败")
|
||||||
|
|
||||||
|
with self._status_lock:
|
||||||
|
if self._is_running:
|
||||||
|
return
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
self._thread = threading.Thread(target=self._run)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
with self._status_lock:
|
||||||
|
if not self._is_running:
|
||||||
|
return
|
||||||
|
self._stop_event = True
|
||||||
|
if self._thread:
|
||||||
|
self._thread.join()
|
||||||
|
self._is_running = False
|
||||||
|
```
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. 线程安全:
|
||||||
|
* 使用 _status_lock 保护状态变更
|
||||||
|
* 注意共享资源的访问控制
|
||||||
|
|
||||||
|
2. 错误处理:
|
||||||
|
* 在 _run() 中妥善处理异常
|
||||||
|
* 提供详细的错误日志
|
||||||
|
|
||||||
|
3. 资源管理:
|
||||||
|
* 确保在 stop() 中正确清理资源
|
||||||
|
* 避免资源泄露
|
||||||
|
|
||||||
|
4. 回调函数:
|
||||||
|
* 回调函数应该是非阻塞的
|
||||||
|
* 处理回调函数抛出的异常
|
145
src/functor/spk_functor.py
Normal file
145
src/functor/spk_functor.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
"""
|
||||||
|
SpkFunctor
|
||||||
|
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
from src.models import AudioBinary_Config, VAD_Functor_result
|
||||||
|
from typing import Callable, List
|
||||||
|
from queue import Queue, Empty
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SPKFunctor(BaseFunctor):
|
||||||
|
"""
|
||||||
|
SPKFunctor
|
||||||
|
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
||||||
|
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||||
|
否则无法run()启动线程
|
||||||
|
|
||||||
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||||
|
|
||||||
|
使用stop()停止线程, 但需要等待input_queue为空
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# 资源与配置
|
||||||
|
self._model: dict = {} # 模型
|
||||||
|
self._callback: List[Callable] = [] # 回调函数
|
||||||
|
self._input_queue: Queue = None # 输入队列
|
||||||
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
|
||||||
|
def reset_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
"""
|
||||||
|
self._audio_config = audio_config
|
||||||
|
logger.debug("SpkFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
|
if not isinstance(self._callback, list):
|
||||||
|
self._callback = []
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def _do_callback(self, result: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
回调函数
|
||||||
|
"""
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(result)
|
||||||
|
|
||||||
|
def _process(self, data: VAD_Functor_result) -> None:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
"""
|
||||||
|
binary_data = data.audiobinary_data.binary_data
|
||||||
|
# result = self._model["spk"].generate(
|
||||||
|
# input=binary_data,
|
||||||
|
# chunk_size=self._audio_config.chunk_size,
|
||||||
|
# )
|
||||||
|
result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
|
||||||
|
self._do_callback(result)
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
# 运行逻辑
|
||||||
|
while self._is_running:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(True, timeout=1)
|
||||||
|
self._process(data)
|
||||||
|
self._input_queue.task_done()
|
||||||
|
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
except Empty:
|
||||||
|
if self._stop_event:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# 其他异常
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("SpkFunctor运行时发生错误: %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run(self) -> threading.Thread:
|
||||||
|
"""
|
||||||
|
启动线程
|
||||||
|
Returns:
|
||||||
|
threading.Thread: 返回已运行线程实例
|
||||||
|
"""
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
raise ValueError("模型未设置")
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise ValueError("输入队列未设置")
|
||||||
|
if self._callback is None:
|
||||||
|
raise ValueError("回调函数未设置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self) -> bool:
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._stop_event = True
|
||||||
|
self._thread.join()
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
return not self._thread.is_alive()
|
315
src/functor/vad_functor.py
Normal file
315
src/functor/vad_functor.py
Normal file
@ -0,0 +1,315 @@
|
|||||||
|
"""
|
||||||
|
VADFunctor
|
||||||
|
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from queue import Empty, Queue
|
||||||
|
from typing import List, Any, Callable
|
||||||
|
import numpy
|
||||||
|
from src.models import (
|
||||||
|
VAD_Functor_result,
|
||||||
|
AudioBinary_Config,
|
||||||
|
AudioBinary_data_list,
|
||||||
|
)
|
||||||
|
from src.functor.base import BaseFunctor
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VADFunctor(BaseFunctor):
|
||||||
|
"""
|
||||||
|
VADFunctor
|
||||||
|
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||||
|
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
|
||||||
|
否则无法run()启动线程
|
||||||
|
|
||||||
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||||
|
|
||||||
|
使用stop()停止线程, 但需要等待input_queue为空
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# 资源与配置
|
||||||
|
self._model: dict = {} # 模型
|
||||||
|
self._callback: List[Callable] = [] # 回调函数
|
||||||
|
self._input_queue: Queue = None # 输入队列
|
||||||
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
|
||||||
|
|
||||||
|
# flag
|
||||||
|
# 此处用到两个锁,但都是为了截断_run线程,考虑后续优化
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
|
||||||
|
# 线程资源
|
||||||
|
self._thread: threading.Thread = None
|
||||||
|
|
||||||
|
# 状态锁
|
||||||
|
self._status_lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
|
# 缓存
|
||||||
|
self._audio_cache: numpy.ndarray = None
|
||||||
|
self._audio_cache_preindex: int = 0
|
||||||
|
self._model_cache: dict = {}
|
||||||
|
self._cache_result_list = []
|
||||||
|
self._audiobinary_cache = None
|
||||||
|
|
||||||
|
def reset_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
|
"""
|
||||||
|
self._audio_cache = None
|
||||||
|
self._audio_cache_preindex = 0
|
||||||
|
self._model_cache = {}
|
||||||
|
self._cache_result_list = []
|
||||||
|
self._audiobinary_cache = None
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置监听的输入消息队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def set_model(self, model: dict) -> None:
|
||||||
|
"""
|
||||||
|
设置推理模型
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
|
"""
|
||||||
|
设置音频配置
|
||||||
|
"""
|
||||||
|
self._audio_config = audio_config
|
||||||
|
logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
|
def set_audio_binary_data_list(
|
||||||
|
self, audio_binary_data_list: AudioBinary_data_list
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
设置音频数据列表, 为Class AudioBinary_data_list类型
|
||||||
|
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
|
||||||
|
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
|
||||||
|
"""
|
||||||
|
self._audio_binary_data_list = audio_binary_data_list
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
|
"""
|
||||||
|
if not isinstance(self._callback, list):
|
||||||
|
self._callback = []
|
||||||
|
self._callback.append(callback)
|
||||||
|
|
||||||
|
def _do_callback(self, result: List[List[int]]) -> None:
|
||||||
|
"""
|
||||||
|
回调函数
|
||||||
|
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||||
|
|
||||||
|
输入:
|
||||||
|
result: List[[start,end]] 处理所得VAD端点
|
||||||
|
其中若start==-1, 则表示前无端点, 若end==-1, 则表示后无端点
|
||||||
|
当处理得到一个完成片段时, 存入AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||||
|
输出:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# 持久化缓存结果队列
|
||||||
|
for pair in result:
|
||||||
|
[start, end] = pair
|
||||||
|
# 若无前端点, 则向缓存队列中合并
|
||||||
|
if start == -1:
|
||||||
|
self._cache_result_list[-1][1] = end
|
||||||
|
else:
|
||||||
|
self._cache_result_list.append(pair)
|
||||||
|
while len(self._cache_result_list) > 1:
|
||||||
|
# 创建VAD片段
|
||||||
|
# 计算开始帧
|
||||||
|
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
|
||||||
|
start_frame -= self._audio_cache_preindex
|
||||||
|
# 计算结束帧
|
||||||
|
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
|
||||||
|
end_frame -= self._audio_cache_preindex
|
||||||
|
# 计算开始时间
|
||||||
|
vad_result = VAD_Functor_result.create_from_push_data(
|
||||||
|
audiobinary_data_list=self._audio_binary_data_list,
|
||||||
|
data=self._audiobinary_cache[start_frame:end_frame],
|
||||||
|
start_time=self._cache_result_list[0][0],
|
||||||
|
end_time=self._cache_result_list[0][1],
|
||||||
|
)
|
||||||
|
self._audio_cache_preindex += end_frame
|
||||||
|
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
|
||||||
|
for callback in self._callback:
|
||||||
|
callback(vad_result)
|
||||||
|
self._cache_result_list.pop(0)
|
||||||
|
|
||||||
|
def _predeal_data(self, data: Any) -> None:
|
||||||
|
"""
|
||||||
|
预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
|
||||||
|
"""
|
||||||
|
if self._audio_cache is None:
|
||||||
|
self._audio_cache = data
|
||||||
|
else:
|
||||||
|
# 拼接音频数据
|
||||||
|
if isinstance(self._audio_cache, numpy.ndarray):
|
||||||
|
self._audio_cache = numpy.concatenate((self._audio_cache, data))
|
||||||
|
elif isinstance(self._audio_cache, list):
|
||||||
|
self._audio_cache.append(data)
|
||||||
|
if self._audiobinary_cache is None:
|
||||||
|
self._audiobinary_cache = data
|
||||||
|
else:
|
||||||
|
# 拼接音频数据
|
||||||
|
if isinstance(self._audiobinary_cache, numpy.ndarray):
|
||||||
|
self._audiobinary_cache = numpy.concatenate(
|
||||||
|
(self._audiobinary_cache, data)
|
||||||
|
)
|
||||||
|
elif isinstance(self._audiobinary_cache, list):
|
||||||
|
self._audiobinary_cache.append(data)
|
||||||
|
|
||||||
|
def _process(self, data: Any):
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
使用model进行生成, 并使用_do_callback进行回调
|
||||||
|
"""
|
||||||
|
self._predeal_data(data)
|
||||||
|
if len(self._audio_cache) >= self._audio_config.chunk_stride:
|
||||||
|
result = self._model["vad"].generate(
|
||||||
|
input=self._audio_cache,
|
||||||
|
cache=self._model_cache,
|
||||||
|
chunk_size=self._audio_config.chunk_size,
|
||||||
|
is_final=False,
|
||||||
|
)
|
||||||
|
if len(result[0]["value"]) > 0:
|
||||||
|
self._do_callback(result[0]["value"])
|
||||||
|
# logger.debug(f"VADFunctor结果: {result[0]['value']}")
|
||||||
|
self._audio_cache = None
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
"""
|
||||||
|
线程运行逻辑
|
||||||
|
监听输入队列, 当有数据时, 处理数据
|
||||||
|
当输入队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
"""
|
||||||
|
# 刷新运行状态
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
# 运行逻辑
|
||||||
|
while self._is_running:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(True, timeout=1)
|
||||||
|
self._process(data)
|
||||||
|
self._input_queue.task_done()
|
||||||
|
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||||
|
except Empty:
|
||||||
|
if self._stop_event:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# 其他异常
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("VADFunctor运行时发生错误: %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
启动 _run 线程, 并返回线程对象
|
||||||
|
"""
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
检测硬性资源是否都已设置
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
raise ValueError("模型未设置")
|
||||||
|
if self._audio_config is None:
|
||||||
|
raise ValueError("音频配置未设置")
|
||||||
|
if self._audio_binary_data_list is None:
|
||||||
|
raise ValueError("音频数据列表未设置")
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise ValueError("输入队列未设置")
|
||||||
|
if self._callback is None:
|
||||||
|
raise ValueError("回调函数未设置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
停止线程
|
||||||
|
通过设置_stop_event为True, 来在input_queue.get()循环为空时退出
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._stop_event = True
|
||||||
|
self._thread.join()
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
return not self._thread.is_alive()
|
||||||
|
|
||||||
|
|
||||||
|
# class VAD:
|
||||||
|
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# VAD_model=None,
|
||||||
|
# audio_config: AudioBinary_Config = None,
|
||||||
|
# callback: Callable = None,
|
||||||
|
# ):
|
||||||
|
# # vad model
|
||||||
|
# self.VAD_model = VAD_model
|
||||||
|
# if self.VAD_model is None:
|
||||||
|
# self.VAD_model = AutoModel(
|
||||||
|
# model="fsmn-vad", model_revision="v2.0.4", disable_update=True
|
||||||
|
# )
|
||||||
|
# # audio config
|
||||||
|
# self.audio_config = audio_config
|
||||||
|
# # vad result
|
||||||
|
# self.vad_result = VADResponse(time_chunk_index_callback=callback)
|
||||||
|
# # audio binary poll
|
||||||
|
# self.audio_chunk = AudioChunk(audio_config=self.audio_config)
|
||||||
|
# self.cache = {}
|
||||||
|
|
||||||
|
# def push_binary_data(
|
||||||
|
# self,
|
||||||
|
# binary_data: bytes,
|
||||||
|
# ):
|
||||||
|
# # 压入二进制数据
|
||||||
|
# self.audio_chunk.add_chunk(binary_data)
|
||||||
|
# # 处理音频块
|
||||||
|
# res = self.VAD_model.generate(
|
||||||
|
# input=binary_data,
|
||||||
|
# cache=self.cache,
|
||||||
|
# chunk_size=self.audio_config.chunk_size,
|
||||||
|
# is_final=False,
|
||||||
|
# )
|
||||||
|
# # print("VAD generate", res)
|
||||||
|
# if len(res[0]["value"]):
|
||||||
|
# self.vad_result += VADResponse.from_raw(res)
|
||||||
|
|
||||||
|
# def set_callback(
|
||||||
|
# self,
|
||||||
|
# callback: Callable,
|
||||||
|
# ):
|
||||||
|
# self.vad_result.time_chunk_index_callback = callback
|
||||||
|
|
||||||
|
# def process_vad_result(self, callback: Callable = None):
|
||||||
|
# # 处理VAD结果
|
||||||
|
# callback = (
|
||||||
|
# callback
|
||||||
|
# if callback is not None
|
||||||
|
# else self.vad_result.time_chunk_index_callback
|
||||||
|
# )
|
||||||
|
# self.vad_result.process_time_chunk(
|
||||||
|
# lambda x: callback(
|
||||||
|
# AudioBinary_Chunk(
|
||||||
|
# start_time=x["start_time"],
|
||||||
|
# end_time=x["end_time"],
|
||||||
|
# chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]),
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# )
|
176
src/logic_trager.py
Normal file
176
src/logic_trager.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
逻辑触发器类 - 用于处理音频数据并触发相应的处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
from typing import Any, Dict, Type, Callable
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoAfterMeta(type):
|
||||||
|
"""
|
||||||
|
自动调用__after__函数的元类
|
||||||
|
实现单例模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instances: Dict[Type, Any] = {} # 存储单例实例
|
||||||
|
|
||||||
|
def __new__(cls, name, bases, attrs):
|
||||||
|
# 遍历所有属性
|
||||||
|
for attr_name, attr_value in attrs.items():
|
||||||
|
# 如果是函数且不是以_开头
|
||||||
|
if callable(attr_value) and not attr_name.startswith("__"):
|
||||||
|
# 获取原函数
|
||||||
|
original_func = attr_value
|
||||||
|
|
||||||
|
# 创建包装函数
|
||||||
|
def make_wrapper(func):
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
# 执行原函数
|
||||||
|
result = func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
# 构建_after_函数名
|
||||||
|
after_func_name = f"__after__{func.__name__}"
|
||||||
|
|
||||||
|
# 检查是否存在对应的_after_函数
|
||||||
|
if hasattr(self, after_func_name):
|
||||||
|
after_func = getattr(self, after_func_name)
|
||||||
|
if callable(after_func):
|
||||||
|
try:
|
||||||
|
# 调用_after_函数
|
||||||
|
after_func()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"调用{after_func_name}时出错: {e}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
# 替换原函数
|
||||||
|
attrs[attr_name] = make_wrapper(original_func)
|
||||||
|
|
||||||
|
# 创建类
|
||||||
|
new_class = super().__new__(cls, name, bases, attrs)
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
def __call__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
重写__call__方法实现单例模式
|
||||||
|
当类被调用时(即创建实例时)执行
|
||||||
|
"""
|
||||||
|
if cls not in cls._instances:
|
||||||
|
# 如果实例不存在,创建新实例
|
||||||
|
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||||
|
logger.info(f"创建{cls.__name__}的新实例")
|
||||||
|
else:
|
||||||
|
logger.debug(f"返回{cls.__name__}的现有实例")
|
||||||
|
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
整体识别的处理逻辑:
|
||||||
|
1.压入二进制音频信息
|
||||||
|
2.不断检测VAD
|
||||||
|
3.当检测到完整VAD时,将VAD的音频信息压入音频块,并清除对应二进制信息
|
||||||
|
4.对音频块进行语音转文字offline,时间戳预测,说话人识别
|
||||||
|
5.将识别结果整合压入结果队列
|
||||||
|
6.结果队列被压入时调用回调函数
|
||||||
|
|
||||||
|
1->2 __after__push_binary_data 外部压入二进制信息
|
||||||
|
2,3->4 __after__push_audio_chunk 内部压入音频块
|
||||||
|
4->5 push_result_queue 压入结果队列
|
||||||
|
5->6 __after__push_result_queue 调用回调函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor import VAD
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
from src.models import AudioBinary_Chunk
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class LogicTrager(metaclass=AutoAfterMeta):
|
||||||
|
"""逻辑触发器类"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
audio_chunk_max_size: int = 1024 * 1024 * 10,
|
||||||
|
audio_config: AudioBinary_Config = None,
|
||||||
|
result_callback: Callable = None,
|
||||||
|
models: Dict[str, Any] = None,
|
||||||
|
):
|
||||||
|
"""初始化"""
|
||||||
|
# 存储音频块
|
||||||
|
self._audio_chunk: List[AudioBinary_Chunk] = []
|
||||||
|
# 存储二进制数据
|
||||||
|
self._audio_chunk_binary = b""
|
||||||
|
self._audio_chunk_max_size = audio_chunk_max_size
|
||||||
|
# 音频参数
|
||||||
|
self._audio_config = (
|
||||||
|
audio_config if audio_config is not None else AudioBinary_Config()
|
||||||
|
)
|
||||||
|
# 结果队列
|
||||||
|
self._result_queue = []
|
||||||
|
# 聚合结果回调函数
|
||||||
|
self._aggregate_result_callback = result_callback
|
||||||
|
# 组件
|
||||||
|
self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
|
||||||
|
self._vad.set_callback(self.push_audio_chunk)
|
||||||
|
|
||||||
|
logger.info("初始化LogicTrager")
|
||||||
|
|
||||||
|
def push_binary_data(self, chunk: bytes) -> None:
|
||||||
|
"""
|
||||||
|
压入音频块至VAD模块
|
||||||
|
|
||||||
|
参数:
|
||||||
|
chunk: 音频数据块
|
||||||
|
"""
|
||||||
|
# print("LogicTrager push_binary_data", len(chunk))
|
||||||
|
self._vad.push_binary_data(chunk)
|
||||||
|
self.__after__push_binary_data()
|
||||||
|
|
||||||
|
def __after__push_binary_data(self) -> None:
|
||||||
|
"""
|
||||||
|
添加音频块后处理
|
||||||
|
"""
|
||||||
|
# print("LogicTrager __after__push_binary_data")
|
||||||
|
self._vad.process_vad_result()
|
||||||
|
|
||||||
|
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
|
||||||
|
"""
|
||||||
|
音频处理
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
|
||||||
|
chunk.start_time, chunk.end_time, len(chunk.chunk)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._audio_chunk.append(chunk)
|
||||||
|
|
||||||
|
def __after__push_audio_chunk(self) -> None:
|
||||||
|
"""
|
||||||
|
压入音频块后处理
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def push_result_queue(self, result: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
压入结果队列
|
||||||
|
"""
|
||||||
|
self._result_queue.append(result)
|
||||||
|
|
||||||
|
def __after__push_result_queue(self) -> None:
|
||||||
|
"""
|
||||||
|
压入结果队列后处理
|
||||||
|
"""
|
||||||
|
logger.info("FINISH Result=")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
"""调用函数"""
|
||||||
|
pass
|
126
src/model_loader.py
Normal file
126
src/model_loader.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
模型加载模块 - 负责加载各种语音识别相关模型
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 导入FunASR库
|
||||||
|
from funasr import AutoModel
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc
|
||||||
|
|
||||||
|
# 日志模块
|
||||||
|
from src.utils import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# 单例模式
|
||||||
|
class ModelLoader:
|
||||||
|
"""
|
||||||
|
ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中。
|
||||||
|
一般的, 可以直接call ModelLoader()来获取加载的模型。
|
||||||
|
也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
单例模式
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self, args=None):
|
||||||
|
"""
|
||||||
|
初始化ModelLoader实例
|
||||||
|
"""
|
||||||
|
self.models = {}
|
||||||
|
logger.debug("初始化ModelLoader")
|
||||||
|
if args is not None:
|
||||||
|
self.__call__(args)
|
||||||
|
|
||||||
|
def __call__(self, args=None):
|
||||||
|
"""
|
||||||
|
调用ModelLoader实例时, 如果模型字典为空, 则加载模型
|
||||||
|
"""
|
||||||
|
# 如果模型字典为空, 则加载模型
|
||||||
|
if self.models == {} or self.models is None:
|
||||||
|
if args.asr_model is not None:
|
||||||
|
self.models = self.load_models(args)
|
||||||
|
# 直接调用等于调用self.models
|
||||||
|
return self.models
|
||||||
|
|
||||||
|
def _load_model(self, input_model_args: dict, model_type: str):
|
||||||
|
"""
|
||||||
|
加载单个模型
|
||||||
|
|
||||||
|
参数:
|
||||||
|
model_args: 模型加载字典
|
||||||
|
model_type: 模型类型, 用于确定使用哪个模型参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
AutoModel: 加载的模型实例
|
||||||
|
"""
|
||||||
|
# 默认配置
|
||||||
|
default_config = {
|
||||||
|
"model": None,
|
||||||
|
"model_revision": None,
|
||||||
|
"ngpu": 0,
|
||||||
|
"ncpu": 1,
|
||||||
|
"device": "cpu",
|
||||||
|
"disable_pbar": True,
|
||||||
|
"disable_log": True,
|
||||||
|
"disable_update": True,
|
||||||
|
}
|
||||||
|
# 从args中获取配置, 如果存在则覆盖默认值
|
||||||
|
model_args = default_config.copy()
|
||||||
|
for key, value in default_config.items():
|
||||||
|
if key in ["model", "model_revision"]:
|
||||||
|
# 特殊处理model和model_revision, 因为它们需要model_type前缀
|
||||||
|
if key == "model":
|
||||||
|
value = input_model_args.get(f"{model_type}_model", None)
|
||||||
|
else:
|
||||||
|
value = input_model_args.get(f"{model_type}_model_revision", None)
|
||||||
|
else:
|
||||||
|
value = input_model_args.get(key, None)
|
||||||
|
if value is not None:
|
||||||
|
logger.debug("替换%s模型参数: %s = %s", model_type, key, value)
|
||||||
|
model_args[key] = value
|
||||||
|
# 验证必要参数
|
||||||
|
if not model_args["model"]:
|
||||||
|
raise ValueError(f"未指定{model_type}模型路径")
|
||||||
|
try:
|
||||||
|
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
|
||||||
|
logger.debug("正在加载%s模型: %s", model_type, model_args["model"])
|
||||||
|
model = AutoModel(**model_args)
|
||||||
|
return model
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("加载%s模型失败: %s", model_type, str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def load_models(self, args):
|
||||||
|
"""
|
||||||
|
加载所有需要的模型
|
||||||
|
参数:
|
||||||
|
args: 命令行参数, 包含模型配置
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict: 包含所有加载的模型的字典
|
||||||
|
"""
|
||||||
|
logger.info("ModelLoader加载模型")
|
||||||
|
# 初始化模型字典
|
||||||
|
self.models = {}
|
||||||
|
# 加载离线ASR模型
|
||||||
|
# 检查对应键是否存在
|
||||||
|
model_list = ["asr", "asr_online", "vad", "punc", "spk"]
|
||||||
|
for model_name in model_list:
|
||||||
|
name_model = f"{model_name}_model"
|
||||||
|
name_model_revision = f"{model_name}_model_revision"
|
||||||
|
if name_model in args:
|
||||||
|
logger.debug("加载%s模型", model_name)
|
||||||
|
self.models[model_name] = self._load_model(args, model_name)
|
||||||
|
logger.info("所有模型加载完成")
|
||||||
|
return self.models
|
@ -1,79 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
模型加载模块 - 负责加载各种语音识别相关模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
def load_models(args):
|
|
||||||
"""
|
|
||||||
加载所有需要的模型
|
|
||||||
|
|
||||||
参数:
|
|
||||||
args: 命令行参数,包含模型配置
|
|
||||||
|
|
||||||
返回:
|
|
||||||
dict: 包含所有加载的模型的字典
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 导入FunASR库
|
|
||||||
from funasr import AutoModel
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("未找到funasr库,请先安装: pip install funasr")
|
|
||||||
|
|
||||||
# 初始化模型字典
|
|
||||||
models = {}
|
|
||||||
|
|
||||||
# 1. 加载离线ASR模型
|
|
||||||
print(f"正在加载ASR离线模型: {args.asr_model}")
|
|
||||||
models["asr"] = AutoModel(
|
|
||||||
model=args.asr_model,
|
|
||||||
model_revision=args.asr_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 加载在线ASR模型
|
|
||||||
print(f"正在加载ASR在线模型: {args.asr_model_online}")
|
|
||||||
models["asr_streaming"] = AutoModel(
|
|
||||||
model=args.asr_model_online,
|
|
||||||
model_revision=args.asr_model_online_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 加载VAD模型
|
|
||||||
print(f"正在加载VAD模型: {args.vad_model}")
|
|
||||||
models["vad"] = AutoModel(
|
|
||||||
model=args.vad_model,
|
|
||||||
model_revision=args.vad_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. 加载标点符号模型(如果指定)
|
|
||||||
if args.punc_model:
|
|
||||||
print(f"正在加载标点符号模型: {args.punc_model}")
|
|
||||||
models["punc"] = AutoModel(
|
|
||||||
model=args.punc_model,
|
|
||||||
model_revision=args.punc_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
models["punc"] = None
|
|
||||||
print("未指定标点符号模型,将不使用标点符号")
|
|
||||||
|
|
||||||
print("所有模型加载完成")
|
|
||||||
return models
|
|
9
src/models/__init__.py
Normal file
9
src/models/__init__.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
|
||||||
|
from .vad import VAD_Functor_result
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AudioBinary_Config",
|
||||||
|
"AudioBinary_data_list",
|
||||||
|
"_AudioBinary_data",
|
||||||
|
"VAD_Functor_result",
|
||||||
|
]
|
158
src/models/audio.py
Normal file
158
src/models/audio.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from typing import List, Any
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
from src.utils import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
binary_data_types = (bytes, numpy.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBinary_Config(BaseModel):
|
||||||
|
"""二进制音频块配置信息"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
audio_data: binary_data_types = Field(description="音频数据", default=None)
|
||||||
|
chunk_size: int = Field(description="块大小", default=100)
|
||||||
|
chunk_stride: int = Field(description="块步长", default=1600)
|
||||||
|
sample_rate: int = Field(description="采样率", default=16000)
|
||||||
|
sample_width: int = Field(description="采样位宽", default=2)
|
||||||
|
channels: int = Field(description="通道数", default=1)
|
||||||
|
|
||||||
|
# 从Dict中加载
|
||||||
|
@classmethod
|
||||||
|
def AudioBinary_Config_from_dict(cls, data: dict):
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
def ms2frame(self, ms: int) -> int:
|
||||||
|
"""
|
||||||
|
将毫秒转换为帧
|
||||||
|
"""
|
||||||
|
return int(ms * self.sample_rate / 1000)
|
||||||
|
|
||||||
|
def frame2ms(self, frame: int) -> int:
|
||||||
|
"""
|
||||||
|
将帧转换为毫秒
|
||||||
|
"""
|
||||||
|
return int(frame * 1000 / self.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
class _AudioBinary_data(BaseModel):
|
||||||
|
"""音频数据"""
|
||||||
|
|
||||||
|
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@validator("binary_data")
|
||||||
|
def validate_binary_data(cls, v):
|
||||||
|
"""
|
||||||
|
验证音频数据
|
||||||
|
Args:
|
||||||
|
v: 音频数据
|
||||||
|
Returns:
|
||||||
|
binary_data_types: 音频数据
|
||||||
|
"""
|
||||||
|
if not isinstance(v, (bytes, numpy.ndarray)):
|
||||||
|
logger.warning(
|
||||||
|
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
|
||||||
|
cls.__class__.__name__,
|
||||||
|
type(v),
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
获取音频数据长度
|
||||||
|
Returns:
|
||||||
|
int: 音频数据长度
|
||||||
|
"""
|
||||||
|
return len(self.binary_data)
|
||||||
|
|
||||||
|
def __init__(self, binary_data: binary_data_types):
|
||||||
|
"""
|
||||||
|
初始化音频数据
|
||||||
|
Args:
|
||||||
|
binary_data: 音频数据
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
"[%s]初始化音频数据, 数据类型为%s",
|
||||||
|
self.__class__.__name__,
|
||||||
|
type(binary_data),
|
||||||
|
)
|
||||||
|
super().__init__(binary_data=binary_data)
|
||||||
|
|
||||||
|
def __getitem__(self):
|
||||||
|
"""
|
||||||
|
当获取数据时, 直接返回binary_data
|
||||||
|
Returns:
|
||||||
|
binary_data_types: 音频数据
|
||||||
|
"""
|
||||||
|
return self.binary_data
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBinary_data_list(BaseModel):
|
||||||
|
"""音频数据列表"""
|
||||||
|
|
||||||
|
binary_data_list: List[_AudioBinary_data] = Field(
|
||||||
|
description="音频数据列表", default=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def push_data(self, data: binary_data_types) -> int:
|
||||||
|
"""
|
||||||
|
添加音频数据
|
||||||
|
Args:
|
||||||
|
data: 音频数据
|
||||||
|
Returns:
|
||||||
|
int: 数据在binary_data_list中的索引
|
||||||
|
"""
|
||||||
|
self.binary_data_list.append(_AudioBinary_data(binary_data=data))
|
||||||
|
return len(self.binary_data_list) - 1
|
||||||
|
|
||||||
|
def __getitem__(self, index: int):
|
||||||
|
"""
|
||||||
|
获取音频数据
|
||||||
|
Args:
|
||||||
|
index: 音频数据在binary_data_list中的索引
|
||||||
|
Returns:
|
||||||
|
_AudioBinary_data: 音频数据
|
||||||
|
"""
|
||||||
|
return self.binary_data_list[index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
获取音频数据列表长度
|
||||||
|
Returns:
|
||||||
|
int: 音频数据列表长度
|
||||||
|
"""
|
||||||
|
return len(self.binary_data_list)
|
||||||
|
|
||||||
|
|
||||||
|
# class AudioBinary_Slice(BaseModel):
|
||||||
|
# """音频块切片"""
|
||||||
|
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
|
||||||
|
# start_index: int = Field(description="开始位置", default=0)
|
||||||
|
# end_index: int = Field(description="结束位置", default=-1)
|
||||||
|
|
||||||
|
# @validator('start_index')
|
||||||
|
# def validate_start_index(cls, v):
|
||||||
|
# if v < 0:
|
||||||
|
# raise ValueError("start_index必须大于0")
|
||||||
|
# return v
|
||||||
|
|
||||||
|
# @validator('end_index')
|
||||||
|
# def validate_end_index(cls, v):
|
||||||
|
# if v < cls.start_index:
|
||||||
|
# logger.debug("[%s]end_index小于start_index, 将end_index设置为start_index", cls.__class__.__name__)
|
||||||
|
# v = cls.start_index
|
||||||
|
# return v
|
||||||
|
|
||||||
|
# def __call__(self):
|
||||||
|
# return self.target_Binary(self.start_index, self.end_index)
|
91
src/models/vad.py
Normal file
91
src/models/vad.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from typing import List, Optional, Callable, Any
|
||||||
|
from .audio import AudioBinary_data_list, _AudioBinary_data
|
||||||
|
|
||||||
|
|
||||||
|
class VAD_Functor_result(BaseModel):
|
||||||
|
"""
|
||||||
|
VADFunctor结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
|
||||||
|
audiobinary_index: int = Field(description="音频数据索引")
|
||||||
|
audiobinary_data: _AudioBinary_data = Field(
|
||||||
|
description="音频数据, 指向AudioBinary_data"
|
||||||
|
)
|
||||||
|
start_time: int = Field(description="开始时间", is_required=True)
|
||||||
|
end_time: int = Field(description="结束时间", is_required=True)
|
||||||
|
|
||||||
|
@validator("audiobinary_data_list")
|
||||||
|
def validate_audiobinary_data_list(cls, v):
|
||||||
|
if not isinstance(v, AudioBinary_data_list):
|
||||||
|
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("audiobinary_index")
|
||||||
|
def validate_audiobinary_index(cls, v):
|
||||||
|
if not isinstance(v, int):
|
||||||
|
raise ValueError("audiobinary_index必须是int类型")
|
||||||
|
if v < 0:
|
||||||
|
raise ValueError("audiobinary_index必须大于0")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("audiobinary_data")
|
||||||
|
def validate_audiobinary_data(cls, v):
|
||||||
|
if not isinstance(v, _AudioBinary_data):
|
||||||
|
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("start_time")
|
||||||
|
def validate_start_time(cls, v):
|
||||||
|
if not isinstance(v, int):
|
||||||
|
raise ValueError("start_time必须是int类型")
|
||||||
|
if v < 0:
|
||||||
|
raise ValueError("start_time必须大于0")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("end_time")
|
||||||
|
def validate_end_time(cls, v, values):
|
||||||
|
if not isinstance(v, int):
|
||||||
|
raise ValueError("end_time必须是int类型")
|
||||||
|
if "start_time" in values and v <= values["start_time"]:
|
||||||
|
raise ValueError("end_time必须大于start_time")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_push_data(
|
||||||
|
cls,
|
||||||
|
audiobinary_data_list: AudioBinary_data_list,
|
||||||
|
data: Any,
|
||||||
|
start_time: int,
|
||||||
|
end_time: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建VAD片段
|
||||||
|
"""
|
||||||
|
index = audiobinary_data_list.push_data(data)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
audiobinary_data_list=audiobinary_data_list,
|
||||||
|
audiobinary_index=index,
|
||||||
|
audiobinary_data=audiobinary_data_list[index],
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
获取音频数据长度
|
||||||
|
"""
|
||||||
|
return len(self.audiobinary_data.binary_data)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""
|
||||||
|
字符串展示内容
|
||||||
|
"""
|
||||||
|
tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
|
||||||
|
tostr += f"start_time: {self.start_time}\n"
|
||||||
|
tostr += f"end_time: {self.end_time}\n"
|
||||||
|
tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
|
||||||
|
tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
|
||||||
|
return tostr
|
265
src/pipeline/ASRpipeline.py
Normal file
265
src/pipeline/ASRpipeline.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
from src.pipeline.base import PipelineBase
|
||||||
|
from typing import Dict, Any, Callable
|
||||||
|
from queue import Queue, Empty
|
||||||
|
from src.utils import get_module_logger
|
||||||
|
from src.models import AudioBinary_data_list
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ASRPipeline(PipelineBase):
|
||||||
|
"""
|
||||||
|
管道类
|
||||||
|
实现具体的处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
初始化管道
|
||||||
|
"""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._config: Dict[str, Any] = {}
|
||||||
|
self._functor_dict: Dict[str, Any] = {}
|
||||||
|
self._subqueue_dict: Dict[str, Any] = {}
|
||||||
|
self._is_baked: bool = False
|
||||||
|
self._input_queue: Queue = None
|
||||||
|
self._audio_binary_data_list: AudioBinary_data_list = None
|
||||||
|
|
||||||
|
self._status_lock = threading.Lock()
|
||||||
|
self._is_running: bool = False
|
||||||
|
self._stop_event: bool = False
|
||||||
|
|
||||||
|
def set_config(self, config: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
设置配置
|
||||||
|
参数:
|
||||||
|
config: Dict[str, Any] 配置
|
||||||
|
"""
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
def get_config(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取配置
|
||||||
|
返回:
|
||||||
|
Dict[str, Any] 配置
|
||||||
|
"""
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None:
|
||||||
|
"""
|
||||||
|
设置音频二进制存储单元
|
||||||
|
参数:
|
||||||
|
audio_binary: 音频二进制
|
||||||
|
"""
|
||||||
|
self._audio_binary = audio_binary
|
||||||
|
|
||||||
|
def set_models(self, models: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
设置模型
|
||||||
|
"""
|
||||||
|
self._models = models
|
||||||
|
|
||||||
|
def set_input_queue(self, input_queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置输入队列
|
||||||
|
"""
|
||||||
|
self._input_queue = input_queue
|
||||||
|
|
||||||
|
def bake(self) -> None:
|
||||||
|
"""
|
||||||
|
烘焙管道
|
||||||
|
"""
|
||||||
|
self._pre_check_resource()
|
||||||
|
self._init_functor()
|
||||||
|
self._is_baked = True
|
||||||
|
|
||||||
|
def _pre_check_resource(self) -> None:
|
||||||
|
"""
|
||||||
|
预检查资源
|
||||||
|
"""
|
||||||
|
if self._input_queue is None:
|
||||||
|
raise RuntimeError("[ASRpipeline]输入队列未设置")
|
||||||
|
if self._functor_dict is None:
|
||||||
|
raise RuntimeError("[ASRpipeline]functor字典未设置")
|
||||||
|
if self._subqueue_dict is None:
|
||||||
|
raise RuntimeError("[ASRpipeline]子队列字典未设置")
|
||||||
|
if self._audio_binary is None:
|
||||||
|
raise RuntimeError("[ASRpipeline]音频存储单元未设置")
|
||||||
|
|
||||||
|
def _init_functor(self) -> None:
|
||||||
|
"""
|
||||||
|
初始化函数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.functor import FunctorFactory
|
||||||
|
|
||||||
|
# 加载VAD、asr、spk functor
|
||||||
|
self._functor_dict["vad"] = FunctorFactory.make_functor(
|
||||||
|
functor_name="vad", config=self._config, models=self._models
|
||||||
|
)
|
||||||
|
self._functor_dict["asr"] = FunctorFactory.make_functor(
|
||||||
|
functor_name="asr", config=self._config, models=self._models
|
||||||
|
)
|
||||||
|
self._functor_dict["spk"] = FunctorFactory.make_functor(
|
||||||
|
functor_name="spk", config=self._config, models=self._models
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建音频数据存储单元
|
||||||
|
self._audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
|
self._functor_dict["vad"].set_audio_binary_data_list(
|
||||||
|
self._audio_binary_data_list
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化子队列
|
||||||
|
self._subqueue_dict["original"] = Queue()
|
||||||
|
self._subqueue_dict["vad2asr"] = Queue()
|
||||||
|
self._subqueue_dict["vad2spk"] = Queue()
|
||||||
|
self._subqueue_dict["asrend"] = Queue()
|
||||||
|
self._subqueue_dict["spkend"] = Queue()
|
||||||
|
|
||||||
|
# 设置子队列的输入队列
|
||||||
|
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
|
||||||
|
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||||
|
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||||
|
|
||||||
|
# 设置回调函数——放置到对应队列中
|
||||||
|
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
|
||||||
|
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
|
||||||
|
|
||||||
|
# 构造带回调函数的put
|
||||||
|
def put_with_checkcallback(queue: Queue, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
带回调函数的put
|
||||||
|
"""
|
||||||
|
|
||||||
|
def put_with_check(data: Any) -> None:
|
||||||
|
queue.put(data)
|
||||||
|
callback(data)
|
||||||
|
|
||||||
|
return put_with_check
|
||||||
|
|
||||||
|
self._functor_dict["asr"].add_callback(
|
||||||
|
put_with_checkcallback(
|
||||||
|
self._subqueue_dict["asrend"], self._check_result
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._functor_dict["spk"].add_callback(
|
||||||
|
put_with_checkcallback(
|
||||||
|
self._subqueue_dict["spkend"], self._check_result
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
|
||||||
|
|
||||||
|
def _check_result(self, result: Any) -> None:
|
||||||
|
"""
|
||||||
|
检查结果
|
||||||
|
"""
|
||||||
|
# 若asr和spk队列中都有数据,则合并数据
|
||||||
|
if (
|
||||||
|
self._subqueue_dict["asrend"].qsize()
|
||||||
|
& self._subqueue_dict["spkend"].qsize()
|
||||||
|
):
|
||||||
|
asr_data = self._subqueue_dict["asrend"].get()
|
||||||
|
spk_data = self._subqueue_dict["spkend"].get()
|
||||||
|
# 合并数据
|
||||||
|
result = {"asr_data": asr_data, "spk_data": spk_data}
|
||||||
|
# 通知回调函数
|
||||||
|
self._notify_callbacks(result)
|
||||||
|
|
||||||
|
def run(self) -> threading.Thread:
|
||||||
|
"""
|
||||||
|
运行管道
|
||||||
|
Returns:
|
||||||
|
threading.Thread: 返回已运行线程实例
|
||||||
|
"""
|
||||||
|
# 检查运行资源是否准备完毕
|
||||||
|
self._pre_check()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
logger.info("[ASRpipeline]管道开始运行")
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
def _pre_check(self) -> None:
|
||||||
|
"""
|
||||||
|
预检查
|
||||||
|
"""
|
||||||
|
if self._is_baked is False:
|
||||||
|
raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行")
|
||||||
|
|
||||||
|
for functor_name, functor in self._functor_dict.items():
|
||||||
|
if functor is None:
|
||||||
|
raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常")
|
||||||
|
|
||||||
|
for subqueue_name, subqueue in self._subqueue_dict.items():
|
||||||
|
if subqueue is None:
|
||||||
|
raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常")
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
真实的运行逻辑
|
||||||
|
"""
|
||||||
|
# 运行所有functor
|
||||||
|
for functor_name, functor in self._functor_dict.items():
|
||||||
|
logger.info(f"[ASRpipeline]运行{functor_name}functor")
|
||||||
|
self._functor_dict[functor_name].run()
|
||||||
|
|
||||||
|
# 设置管道运行状态
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
|
||||||
|
while self._is_running and not self._stop_event:
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(timeout=self._queue_timeout)
|
||||||
|
# 检查是否是结束信号
|
||||||
|
if data is None:
|
||||||
|
logger.info("收到结束信号,管道准备停止")
|
||||||
|
self._input_queue.task_done() # 标记结束信号已处理
|
||||||
|
break
|
||||||
|
|
||||||
|
# 处理数据
|
||||||
|
self._process(data)
|
||||||
|
|
||||||
|
# 标记任务完成
|
||||||
|
self._input_queue.task_done()
|
||||||
|
|
||||||
|
except Empty:
|
||||||
|
# 队列获取超时,继续等待
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ASRpipeline]管道处理数据出错: {str(e)}")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("[ASRpipeline]管道停止运行")
|
||||||
|
|
||||||
|
def _process(self, data: Any) -> Any:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
参数:
|
||||||
|
data: 输入数据
|
||||||
|
返回:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
# 子类实现具体的处理逻辑
|
||||||
|
self._subqueue_dict["original"].put(data)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""
|
||||||
|
停止管道
|
||||||
|
"""
|
||||||
|
with self._status_lock:
|
||||||
|
self._is_running = False
|
||||||
|
self._stop_event = True
|
||||||
|
for functor_name, functor in self._functor_dict.items():
|
||||||
|
# logger.info(f"停止{functor_name}functor")
|
||||||
|
if functor.stop():
|
||||||
|
logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止")
|
||||||
|
else:
|
||||||
|
logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败")
|
||||||
|
self._thread.join()
|
||||||
|
logger.info("[ASRpipeline]管道停止")
|
||||||
|
return True
|
3
src/pipeline/__init__.py
Normal file
3
src/pipeline/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from src.pipeline.base import PipelineBase, PipelineFactory
|
||||||
|
|
||||||
|
__all__ = ["PipelineBase", "PipelineFactory"]
|
151
src/pipeline/base.py
Normal file
151
src/pipeline/base.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from queue import Queue, Empty
|
||||||
|
from typing import List, Callable, Any, Optional
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineBase(ABC):
|
||||||
|
"""
|
||||||
|
管道基类
|
||||||
|
定义了管道的基本接口和通用功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_queue: Optional[Queue] = None):
|
||||||
|
"""
|
||||||
|
初始化管道
|
||||||
|
参数:
|
||||||
|
input_queue: 输入队列,用于接收数据
|
||||||
|
"""
|
||||||
|
self._input_queue = input_queue
|
||||||
|
self._callbacks: List[Callable] = []
|
||||||
|
self._is_running = False
|
||||||
|
self._stop_event = False
|
||||||
|
self._thread: Optional[threading.Thread] = None
|
||||||
|
self._stop_timeout = 5 # 默认停止超时时间(秒)
|
||||||
|
self._queue_timeout = 1 # 队列获取超时时间(秒)
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置输入队列
|
||||||
|
参数:
|
||||||
|
queue: 输入队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
添加回调函数
|
||||||
|
参数:
|
||||||
|
callback: 回调函数,接收处理结果
|
||||||
|
"""
|
||||||
|
self._callbacks.append(callback)
|
||||||
|
|
||||||
|
def _notify_callbacks(self, result: Any) -> None:
|
||||||
|
"""
|
||||||
|
通知所有回调函数
|
||||||
|
参数:
|
||||||
|
result: 处理结果
|
||||||
|
"""
|
||||||
|
for callback in self._callbacks:
|
||||||
|
try:
|
||||||
|
callback(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"回调函数执行出错: {str(e)}")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process(self, data: Any) -> Any:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
参数:
|
||||||
|
data: 输入数据
|
||||||
|
返回:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _run(self) -> None:
|
||||||
|
"""
|
||||||
|
运行管道
|
||||||
|
从输入队列获取数据并处理
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop(self, timeout: Optional[float] = None) -> bool:
|
||||||
|
"""
|
||||||
|
停止管道
|
||||||
|
参数:
|
||||||
|
timeout: 停止超时时间(秒),None表示使用默认超时时间
|
||||||
|
返回:
|
||||||
|
bool: 是否成功停止
|
||||||
|
"""
|
||||||
|
if not self._is_running:
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info("正在停止管道...")
|
||||||
|
self._stop_event = True
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
|
# 等待线程结束
|
||||||
|
if self._thread and self._thread.is_alive():
|
||||||
|
timeout = timeout if timeout is not None else self._stop_timeout
|
||||||
|
self._thread.join(timeout=timeout)
|
||||||
|
|
||||||
|
# 检查是否成功停止
|
||||||
|
if self._thread.is_alive():
|
||||||
|
logger.warning(f"管道停止超时({timeout}秒),强制终止")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info("管道已成功停止")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def force_stop(self) -> None:
|
||||||
|
"""
|
||||||
|
强制停止管道
|
||||||
|
注意:这可能会导致资源未正确释放
|
||||||
|
"""
|
||||||
|
logger.warning("强制停止管道")
|
||||||
|
self._stop_event = True
|
||||||
|
self._is_running = False
|
||||||
|
# 注意:Python的线程无法被强制终止,这里只是设置标志
|
||||||
|
# 实际终止需要依赖操作系统的进程管理
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineFactory:
|
||||||
|
"""
|
||||||
|
管道工厂类
|
||||||
|
用于创建管道实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline:
|
||||||
|
"""
|
||||||
|
创建ASR管道实例
|
||||||
|
"""
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
pipeline = ASRPipeline()
|
||||||
|
pipeline.set_config(kwargs["config"])
|
||||||
|
pipeline.set_models(kwargs["models"])
|
||||||
|
pipeline.set_audio_binary(kwargs["audio_binary"])
|
||||||
|
pipeline.set_input_queue(kwargs["input_queue"])
|
||||||
|
pipeline.add_callback(kwargs["callback"])
|
||||||
|
pipeline.bake()
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
创建管道实例
|
||||||
|
"""
|
||||||
|
if pipeline_name == "ASRpipeline":
|
||||||
|
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的管道类型: {pipeline_name}")
|
||||||
|
|
0
src/pipeline/test.py
Normal file
0
src/pipeline/test.py
Normal file
281
src/runner.py
Normal file
281
src/runner.py
Normal file
@ -0,0 +1,281 @@
|
|||||||
|
"""
|
||||||
|
运行器模块
|
||||||
|
提供运行器基类和运行器类,用于管理音频数据和模型的交互。
|
||||||
|
主要包含:
|
||||||
|
- RunnerBase: 运行器基类,定义了基本接口
|
||||||
|
- Runner: 运行器类,工厂模式实现
|
||||||
|
- RunnerFactory: 运行器工厂类,用于创建运行器
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from threading import Thread, Lock
|
||||||
|
from queue import Queue
|
||||||
|
import traceback
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.audio_chunk import AudioChunk, AudioBinary
|
||||||
|
from src.pipeline import Pipeline, PipelineFactory
|
||||||
|
from src.model_loader import ModelLoader
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
|
audio_chunk = AudioChunk()
|
||||||
|
models_loaded = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerBase(ABC):
|
||||||
|
"""
|
||||||
|
运行器基类
|
||||||
|
定义了运行器的基本接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def adder(self, data: Any) -> None:
|
||||||
|
"""
|
||||||
|
添加数据
|
||||||
|
参数:
|
||||||
|
data: 要添加的数据
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_recevier(self, receiver: callable) -> None:
|
||||||
|
"""
|
||||||
|
添加数据接收者
|
||||||
|
参数:
|
||||||
|
receiver: 接收数据的回调函数
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class STTRunner(RunnerBase):
|
||||||
|
"""
|
||||||
|
运行器类
|
||||||
|
负责管理资源和协调Pipeline的运行
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
audio_binary_list: List[AudioBinary],
|
||||||
|
models: Dict[str, Any],
|
||||||
|
pipeline_list: List[Pipeline],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化运行器
|
||||||
|
参数:
|
||||||
|
audio_binary_list: 音频二进制列表
|
||||||
|
models: 模型字典
|
||||||
|
pipeline_list: 管道列表
|
||||||
|
queue_size: 队列大小
|
||||||
|
stop_timeout: 停止超时时间(秒)
|
||||||
|
"""
|
||||||
|
# 接收资源
|
||||||
|
self._audio_binary_list = audio_binary_list
|
||||||
|
self._models = models
|
||||||
|
self._pipeline_list = pipeline_list
|
||||||
|
|
||||||
|
# 线程控制
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
# 消息队列
|
||||||
|
self._input_queue = Queue(maxsize=1000)
|
||||||
|
|
||||||
|
# 停止控制
|
||||||
|
self._stop_timeout = 10.0
|
||||||
|
self._is_stopping = False
|
||||||
|
|
||||||
|
# 配置资源
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
# 设置输入队列
|
||||||
|
pipeline.set_input_queue(self._input_queue)
|
||||||
|
|
||||||
|
# 配置资源
|
||||||
|
pipeline.set_audio_binary(
|
||||||
|
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
|
||||||
|
)
|
||||||
|
pipeline.set_models(self._models)
|
||||||
|
|
||||||
|
def adder(self, data: Any) -> None:
|
||||||
|
"""
|
||||||
|
添加数据到输入队列
|
||||||
|
参数:
|
||||||
|
data: 要添加的数据
|
||||||
|
"""
|
||||||
|
if not self._pipeline_list:
|
||||||
|
raise RuntimeError("没有可用的管道")
|
||||||
|
if self._is_stopping:
|
||||||
|
raise RuntimeError("运行器正在停止,无法添加数据")
|
||||||
|
self._input_queue.put(data)
|
||||||
|
|
||||||
|
def add_recevier(self, receiver: callable) -> None:
|
||||||
|
"""
|
||||||
|
添加数据接收者
|
||||||
|
参数:
|
||||||
|
receiver: 接收数据的回调函数
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
pipeline.add_callback(receiver)
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""
|
||||||
|
启动所有管道
|
||||||
|
"""
|
||||||
|
logger.info("[%s] 启动所有管道", self.__class__.__name__)
|
||||||
|
if not self._pipeline_list:
|
||||||
|
raise RuntimeError("没有可用的管道")
|
||||||
|
|
||||||
|
# 启动所有管道
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
|
||||||
|
|
||||||
|
def stop(self, force: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
停止所有管道
|
||||||
|
参数:
|
||||||
|
force: 是否强制停止
|
||||||
|
返回:
|
||||||
|
bool: 是否成功停止
|
||||||
|
"""
|
||||||
|
if self._is_stopping:
|
||||||
|
logger.warning("运行器已经在停止中")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._is_stopping = True
|
||||||
|
logger.info("正在停止运行器...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 发送结束信号
|
||||||
|
self._input_queue.put(None)
|
||||||
|
|
||||||
|
# 停止所有管道
|
||||||
|
success = True
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
if force:
|
||||||
|
pipeline.force_stop()
|
||||||
|
else:
|
||||||
|
if not pipeline.stop(timeout=self._stop_timeout):
|
||||||
|
logger.warning("管道 %s 停止超时", id(pipeline))
|
||||||
|
success = False
|
||||||
|
|
||||||
|
# 等待队列处理完成
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
while not self._input_queue.empty():
|
||||||
|
if time.time() - start_time > self._stop_timeout:
|
||||||
|
logger.warning(
|
||||||
|
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
|
||||||
|
self._stop_timeout,
|
||||||
|
self._input_queue.qsize(),
|
||||||
|
)
|
||||||
|
success = False
|
||||||
|
break
|
||||||
|
time.sleep(0.1) # 避免过度消耗CPU
|
||||||
|
except Exception as e:
|
||||||
|
error_type = type(e).__name__
|
||||||
|
error_msg = str(e)
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
logger.error(
|
||||||
|
"等待队列处理完成时发生错误:\n"
|
||||||
|
"错误类型: %s\n"
|
||||||
|
"错误信息: %s\n"
|
||||||
|
"错误堆栈:\n%s",
|
||||||
|
error_type,
|
||||||
|
error_msg,
|
||||||
|
error_traceback,
|
||||||
|
)
|
||||||
|
success = False
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info("所有管道已成功停止")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
|
||||||
|
self._input_queue.qsize(),
|
||||||
|
self._input_queue.empty(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._is_stopping = False
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
"""
|
||||||
|
析构函数
|
||||||
|
"""
|
||||||
|
self.stop(force=True)
|
||||||
|
|
||||||
|
|
||||||
|
class STTRunnerFactory:
|
||||||
|
"""
|
||||||
|
STT Runner工厂类
|
||||||
|
用于创建运行器实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_runner(
|
||||||
|
audio_binary_name: str,
|
||||||
|
model_name_list: List[str],
|
||||||
|
pipeline_name_list: List[str],
|
||||||
|
) -> STTRunner:
|
||||||
|
"""
|
||||||
|
创建运行器
|
||||||
|
参数:
|
||||||
|
audio_binary_name: 音频二进制名称
|
||||||
|
model_name_list: 模型名称列表
|
||||||
|
pipeline_name_list: 管道名称列表
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
|
||||||
|
models: Dict[str, Any] = {
|
||||||
|
model_name: models_loaded.models[model_name]
|
||||||
|
for model_name in model_name_list
|
||||||
|
}
|
||||||
|
pipelines: List[Pipeline] = [
|
||||||
|
PipelineFactory.create_pipeline(pipeline_name)
|
||||||
|
for pipeline_name in pipeline_name_list
|
||||||
|
]
|
||||||
|
return STTRunner(
|
||||||
|
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_runner_from_config(
|
||||||
|
cls,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> STTRunner:
|
||||||
|
"""
|
||||||
|
从配置创建运行器
|
||||||
|
参数:
|
||||||
|
config: 配置字典
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary_name = config["audio_binary_name"]
|
||||||
|
model_name_list = config["model_name_list"]
|
||||||
|
pipeline_name_list = config["pipeline_name_list"]
|
||||||
|
return cls._create_runner(
|
||||||
|
audio_binary_name, model_name_list, pipeline_name_list
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_runner_normal(cls) -> STTRunner:
|
||||||
|
"""
|
||||||
|
创建默认运行器
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary_name = None
|
||||||
|
model_name_list = list(models_loaded.models.keys())
|
||||||
|
pipeline_name_list = None
|
||||||
|
return cls._create_runner(
|
||||||
|
audio_binary_name, model_name_list, pipeline_name_list
|
||||||
|
)
|
108
src/server.py
108
src/server.py
@ -43,7 +43,7 @@ async def clear_websocket():
|
|||||||
async def ws_serve(websocket, path):
|
async def ws_serve(websocket, path):
|
||||||
"""
|
"""
|
||||||
WebSocket服务主函数,处理客户端连接和消息
|
WebSocket服务主函数,处理客户端连接和消息
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
websocket: WebSocket连接对象
|
websocket: WebSocket连接对象
|
||||||
path: 连接路径
|
path: 连接路径
|
||||||
@ -51,13 +51,13 @@ async def ws_serve(websocket, path):
|
|||||||
frames = [] # 存储所有音频帧
|
frames = [] # 存储所有音频帧
|
||||||
frames_asr = [] # 存储用于离线ASR的音频帧
|
frames_asr = [] # 存储用于离线ASR的音频帧
|
||||||
frames_asr_online = [] # 存储用于在线ASR的音频帧
|
frames_asr_online = [] # 存储用于在线ASR的音频帧
|
||||||
|
|
||||||
global websocket_users
|
global websocket_users
|
||||||
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
|
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
|
||||||
|
|
||||||
# 添加到用户集合
|
# 添加到用户集合
|
||||||
websocket_users.add(websocket)
|
websocket_users.add(websocket)
|
||||||
|
|
||||||
# 初始化连接状态
|
# 初始化连接状态
|
||||||
websocket.status_dict_asr = {}
|
websocket.status_dict_asr = {}
|
||||||
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
|
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
|
||||||
@ -66,15 +66,15 @@ async def ws_serve(websocket, path):
|
|||||||
websocket.chunk_interval = 10
|
websocket.chunk_interval = 10
|
||||||
websocket.vad_pre_idx = 0
|
websocket.vad_pre_idx = 0
|
||||||
websocket.is_speaking = True # 默认用户正在说话
|
websocket.is_speaking = True # 默认用户正在说话
|
||||||
|
|
||||||
# 语音检测状态
|
# 语音检测状态
|
||||||
speech_start = False
|
speech_start = False
|
||||||
speech_end_i = -1
|
speech_end_i = -1
|
||||||
|
|
||||||
# 初始化配置
|
# 初始化配置
|
||||||
websocket.wav_name = "microphone"
|
websocket.wav_name = "microphone"
|
||||||
websocket.mode = "2pass" # 默认使用两阶段识别模式
|
websocket.mode = "2pass" # 默认使用两阶段识别模式
|
||||||
|
|
||||||
print("新用户已连接", flush=True)
|
print("新用户已连接", flush=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -84,11 +84,13 @@ async def ws_serve(websocket, path):
|
|||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
try:
|
try:
|
||||||
messagejson = json.loads(message)
|
messagejson = json.loads(message)
|
||||||
|
|
||||||
# 更新各种配置参数
|
# 更新各种配置参数
|
||||||
if "is_speaking" in messagejson:
|
if "is_speaking" in messagejson:
|
||||||
websocket.is_speaking = messagejson["is_speaking"]
|
websocket.is_speaking = messagejson["is_speaking"]
|
||||||
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
|
websocket.status_dict_asr_online["is_final"] = (
|
||||||
|
not websocket.is_speaking
|
||||||
|
)
|
||||||
if "chunk_interval" in messagejson:
|
if "chunk_interval" in messagejson:
|
||||||
websocket.chunk_interval = messagejson["chunk_interval"]
|
websocket.chunk_interval = messagejson["chunk_interval"]
|
||||||
if "wav_name" in messagejson:
|
if "wav_name" in messagejson:
|
||||||
@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
|
|||||||
chunk_size = messagejson["chunk_size"]
|
chunk_size = messagejson["chunk_size"]
|
||||||
if isinstance(chunk_size, str):
|
if isinstance(chunk_size, str):
|
||||||
chunk_size = chunk_size.split(",")
|
chunk_size = chunk_size.split(",")
|
||||||
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
|
websocket.status_dict_asr_online["chunk_size"] = [
|
||||||
|
int(x) for x in chunk_size
|
||||||
|
]
|
||||||
if "encoder_chunk_look_back" in messagejson:
|
if "encoder_chunk_look_back" in messagejson:
|
||||||
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
|
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
|
||||||
|
messagejson["encoder_chunk_look_back"]
|
||||||
|
)
|
||||||
if "decoder_chunk_look_back" in messagejson:
|
if "decoder_chunk_look_back" in messagejson:
|
||||||
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
|
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
|
||||||
|
messagejson["decoder_chunk_look_back"]
|
||||||
|
)
|
||||||
if "hotword" in messagejson:
|
if "hotword" in messagejson:
|
||||||
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
|
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
|
||||||
if "mode" in messagejson:
|
if "mode" in messagejson:
|
||||||
@ -111,11 +119,17 @@ async def ws_serve(websocket, path):
|
|||||||
|
|
||||||
# 根据chunk_interval更新VAD的chunk_size
|
# 根据chunk_interval更新VAD的chunk_size
|
||||||
websocket.status_dict_vad["chunk_size"] = int(
|
websocket.status_dict_vad["chunk_size"] = int(
|
||||||
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
|
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
|
||||||
|
* 60
|
||||||
|
/ websocket.chunk_interval
|
||||||
)
|
)
|
||||||
|
|
||||||
# 处理音频数据
|
# 处理音频数据
|
||||||
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
|
if (
|
||||||
|
len(frames_asr_online) > 0
|
||||||
|
or len(frames_asr) >= 0
|
||||||
|
or not isinstance(message, str)
|
||||||
|
):
|
||||||
if not isinstance(message, str): # 二进制音频数据
|
if not isinstance(message, str): # 二进制音频数据
|
||||||
# 添加到帧缓冲区
|
# 添加到帧缓冲区
|
||||||
frames.append(message)
|
frames.append(message)
|
||||||
@ -125,10 +139,12 @@ async def ws_serve(websocket, path):
|
|||||||
# 处理在线ASR
|
# 处理在线ASR
|
||||||
frames_asr_online.append(message)
|
frames_asr_online.append(message)
|
||||||
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
|
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
|
||||||
|
|
||||||
# 达到chunk_interval或最终帧时处理在线ASR
|
# 达到chunk_interval或最终帧时处理在线ASR
|
||||||
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
|
if (
|
||||||
websocket.status_dict_asr_online["is_final"]):
|
len(frames_asr_online) % websocket.chunk_interval == 0
|
||||||
|
or websocket.status_dict_asr_online["is_final"]
|
||||||
|
):
|
||||||
if websocket.mode == "2pass" or websocket.mode == "online":
|
if websocket.mode == "2pass" or websocket.mode == "online":
|
||||||
audio_in = b"".join(frames_asr_online)
|
audio_in = b"".join(frames_asr_online)
|
||||||
try:
|
try:
|
||||||
@ -136,26 +152,32 @@ async def ws_serve(websocket, path):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"在线ASR处理错误: {e}")
|
print(f"在线ASR处理错误: {e}")
|
||||||
frames_asr_online = []
|
frames_asr_online = []
|
||||||
|
|
||||||
# 如果检测到语音开始,收集帧用于离线ASR
|
# 如果检测到语音开始,收集帧用于离线ASR
|
||||||
if speech_start:
|
if speech_start:
|
||||||
frames_asr.append(message)
|
frames_asr.append(message)
|
||||||
|
|
||||||
# VAD处理 - 语音活动检测
|
# VAD处理 - 语音活动检测
|
||||||
try:
|
try:
|
||||||
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
|
speech_start_i, speech_end_i = await asr_service.async_vad(
|
||||||
|
websocket, message
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"VAD处理错误: {e}")
|
print(f"VAD处理错误: {e}")
|
||||||
|
|
||||||
# 检测到语音开始
|
# 检测到语音开始
|
||||||
if speech_start_i != -1:
|
if speech_start_i != -1:
|
||||||
speech_start = True
|
speech_start = True
|
||||||
# 计算开始偏移并收集前面的帧
|
# 计算开始偏移并收集前面的帧
|
||||||
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
|
beg_bias = (
|
||||||
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
|
websocket.vad_pre_idx - speech_start_i
|
||||||
|
) // duration_ms
|
||||||
|
frames_pre = (
|
||||||
|
frames[-beg_bias:] if beg_bias < len(frames) else frames
|
||||||
|
)
|
||||||
frames_asr = []
|
frames_asr = []
|
||||||
frames_asr.extend(frames_pre)
|
frames_asr.extend(frames_pre)
|
||||||
|
|
||||||
# 处理离线ASR (语音结束或用户停止说话)
|
# 处理离线ASR (语音结束或用户停止说话)
|
||||||
if speech_end_i != -1 or not websocket.is_speaking:
|
if speech_end_i != -1 or not websocket.is_speaking:
|
||||||
if websocket.mode == "2pass" or websocket.mode == "offline":
|
if websocket.mode == "2pass" or websocket.mode == "offline":
|
||||||
@ -164,13 +186,13 @@ async def ws_serve(websocket, path):
|
|||||||
await asr_service.async_asr(websocket, audio_in)
|
await asr_service.async_asr(websocket, audio_in)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"离线ASR处理错误: {e}")
|
print(f"离线ASR处理错误: {e}")
|
||||||
|
|
||||||
# 重置状态
|
# 重置状态
|
||||||
frames_asr = []
|
frames_asr = []
|
||||||
speech_start = False
|
speech_start = False
|
||||||
frames_asr_online = []
|
frames_asr_online = []
|
||||||
websocket.status_dict_asr_online["cache"] = {}
|
websocket.status_dict_asr_online["cache"] = {}
|
||||||
|
|
||||||
# 如果用户停止说话,完全重置
|
# 如果用户停止说话,完全重置
|
||||||
if not websocket.is_speaking:
|
if not websocket.is_speaking:
|
||||||
websocket.vad_pre_idx = 0
|
websocket.vad_pre_idx = 0
|
||||||
@ -193,34 +215,34 @@ async def ws_serve(websocket, path):
|
|||||||
def start_server(args, asr_service_instance):
|
def start_server(args, asr_service_instance):
|
||||||
"""
|
"""
|
||||||
启动WebSocket服务器
|
启动WebSocket服务器
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
args: 命令行参数
|
args: 命令行参数
|
||||||
asr_service_instance: ASR服务实例
|
asr_service_instance: ASR服务实例
|
||||||
"""
|
"""
|
||||||
global asr_service
|
global asr_service
|
||||||
asr_service = asr_service_instance
|
asr_service = asr_service_instance
|
||||||
|
|
||||||
# 配置SSL (如果提供了证书)
|
# 配置SSL (如果提供了证书)
|
||||||
if args.certfile and len(args.certfile) > 0:
|
if args.certfile and len(args.certfile) > 0:
|
||||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
|
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
|
||||||
|
|
||||||
start_server = websockets.serve(
|
start_server = websockets.serve(
|
||||||
ws_serve, args.host, args.port,
|
ws_serve,
|
||||||
subprotocols=["binary"],
|
args.host,
|
||||||
ping_interval=None,
|
args.port,
|
||||||
ssl=ssl_context
|
subprotocols=["binary"],
|
||||||
|
ping_interval=None,
|
||||||
|
ssl=ssl_context,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
start_server = websockets.serve(
|
start_server = websockets.serve(
|
||||||
ws_serve, args.host, args.port,
|
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
|
||||||
subprotocols=["binary"],
|
|
||||||
ping_interval=None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
|
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
|
||||||
|
|
||||||
# 启动事件循环
|
# 启动事件循环
|
||||||
asyncio.get_event_loop().run_until_complete(start_server)
|
asyncio.get_event_loop().run_until_complete(start_server)
|
||||||
asyncio.get_event_loop().run_forever()
|
asyncio.get_event_loop().run_forever()
|
||||||
@ -229,14 +251,14 @@ def start_server(args, asr_service_instance):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 解析命令行参数
|
# 解析命令行参数
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 加载模型
|
# 加载模型
|
||||||
print("正在加载模型...")
|
print("正在加载模型...")
|
||||||
models = load_models(args)
|
models = load_models(args)
|
||||||
print("模型加载完成!当前仅支持单个客户端同时连接!")
|
print("模型加载完成!当前仅支持单个客户端同时连接!")
|
||||||
|
|
||||||
# 创建ASR服务
|
# 创建ASR服务
|
||||||
asr_service = ASRService(models)
|
asr_service = ASRService(models)
|
||||||
|
|
||||||
# 启动服务器
|
# 启动服务器
|
||||||
start_server(args, asr_service)
|
start_server(args, asr_service)
|
||||||
|
@ -9,11 +9,11 @@ import json
|
|||||||
|
|
||||||
class ASRService:
|
class ASRService:
|
||||||
"""ASR服务类,封装各种语音识别相关功能"""
|
"""ASR服务类,封装各种语音识别相关功能"""
|
||||||
|
|
||||||
def __init__(self, models):
|
def __init__(self, models):
|
||||||
"""
|
"""
|
||||||
初始化ASR服务
|
初始化ASR服务
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
models: 包含各种预加载模型的字典
|
models: 包含各种预加载模型的字典
|
||||||
"""
|
"""
|
||||||
@ -21,42 +21,41 @@ class ASRService:
|
|||||||
self.model_asr_streaming = models["asr_streaming"]
|
self.model_asr_streaming = models["asr_streaming"]
|
||||||
self.model_vad = models["vad"]
|
self.model_vad = models["vad"]
|
||||||
self.model_punc = models["punc"]
|
self.model_punc = models["punc"]
|
||||||
|
|
||||||
async def async_vad(self, websocket, audio_in):
|
async def async_vad(self, websocket, audio_in):
|
||||||
"""
|
"""
|
||||||
语音活动检测
|
语音活动检测
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
websocket: WebSocket连接
|
websocket: WebSocket连接
|
||||||
audio_in: 二进制音频数据
|
audio_in: 二进制音频数据
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
tuple: (speech_start, speech_end) 语音开始和结束位置
|
tuple: (speech_start, speech_end) 语音开始和结束位置
|
||||||
"""
|
"""
|
||||||
# 使用VAD模型分析音频段
|
# 使用VAD模型分析音频段
|
||||||
segments_result = self.model_vad.generate(
|
segments_result = self.model_vad.generate(
|
||||||
input=audio_in,
|
input=audio_in, **websocket.status_dict_vad
|
||||||
**websocket.status_dict_vad
|
|
||||||
)[0]["value"]
|
)[0]["value"]
|
||||||
|
|
||||||
speech_start = -1
|
speech_start = -1
|
||||||
speech_end = -1
|
speech_end = -1
|
||||||
|
|
||||||
# 解析VAD结果
|
# 解析VAD结果
|
||||||
if len(segments_result) == 0 or len(segments_result) > 1:
|
if len(segments_result) == 0 or len(segments_result) > 1:
|
||||||
return speech_start, speech_end
|
return speech_start, speech_end
|
||||||
|
|
||||||
if segments_result[0][0] != -1:
|
if segments_result[0][0] != -1:
|
||||||
speech_start = segments_result[0][0]
|
speech_start = segments_result[0][0]
|
||||||
if segments_result[0][1] != -1:
|
if segments_result[0][1] != -1:
|
||||||
speech_end = segments_result[0][1]
|
speech_end = segments_result[0][1]
|
||||||
|
|
||||||
return speech_start, speech_end
|
return speech_start, speech_end
|
||||||
|
|
||||||
async def async_asr(self, websocket, audio_in):
|
async def async_asr(self, websocket, audio_in):
|
||||||
"""
|
"""
|
||||||
离线ASR处理
|
离线ASR处理
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
websocket: WebSocket连接
|
websocket: WebSocket连接
|
||||||
audio_in: 二进制音频数据
|
audio_in: 二进制音频数据
|
||||||
@ -64,42 +63,44 @@ class ASRService:
|
|||||||
if len(audio_in) > 0:
|
if len(audio_in) > 0:
|
||||||
# 使用离线ASR模型处理音频
|
# 使用离线ASR模型处理音频
|
||||||
rec_result = self.model_asr.generate(
|
rec_result = self.model_asr.generate(
|
||||||
input=audio_in,
|
input=audio_in, **websocket.status_dict_asr
|
||||||
**websocket.status_dict_asr
|
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# 如果有标点符号模型且识别出文本,则添加标点
|
# 如果有标点符号模型且识别出文本,则添加标点
|
||||||
if self.model_punc is not None and len(rec_result["text"]) > 0:
|
if self.model_punc is not None and len(rec_result["text"]) > 0:
|
||||||
rec_result = self.model_punc.generate(
|
rec_result = self.model_punc.generate(
|
||||||
input=rec_result["text"],
|
input=rec_result["text"], **websocket.status_dict_punc
|
||||||
**websocket.status_dict_punc
|
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# 如果识别出文本,发送到客户端
|
# 如果识别出文本,发送到客户端
|
||||||
if len(rec_result["text"]) > 0:
|
if len(rec_result["text"]) > 0:
|
||||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||||
message = json.dumps({
|
message = json.dumps(
|
||||||
"mode": mode,
|
{
|
||||||
"text": rec_result["text"],
|
"mode": mode,
|
||||||
"wav_name": websocket.wav_name,
|
"text": rec_result["text"],
|
||||||
"is_final": websocket.is_speaking,
|
"wav_name": websocket.wav_name,
|
||||||
})
|
"is_final": websocket.is_speaking,
|
||||||
|
}
|
||||||
|
)
|
||||||
await websocket.send(message)
|
await websocket.send(message)
|
||||||
else:
|
else:
|
||||||
# 如果没有音频数据,发送空文本
|
# 如果没有音频数据,发送空文本
|
||||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||||
message = json.dumps({
|
message = json.dumps(
|
||||||
"mode": mode,
|
{
|
||||||
"text": "",
|
"mode": mode,
|
||||||
"wav_name": websocket.wav_name,
|
"text": "",
|
||||||
"is_final": websocket.is_speaking,
|
"wav_name": websocket.wav_name,
|
||||||
})
|
"is_final": websocket.is_speaking,
|
||||||
|
}
|
||||||
|
)
|
||||||
await websocket.send(message)
|
await websocket.send(message)
|
||||||
|
|
||||||
async def async_asr_online(self, websocket, audio_in):
|
async def async_asr_online(self, websocket, audio_in):
|
||||||
"""
|
"""
|
||||||
在线ASR处理
|
在线ASR处理
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
websocket: WebSocket连接
|
websocket: WebSocket连接
|
||||||
audio_in: 二进制音频数据
|
audio_in: 二进制音频数据
|
||||||
@ -107,21 +108,24 @@ class ASRService:
|
|||||||
if len(audio_in) > 0:
|
if len(audio_in) > 0:
|
||||||
# 使用在线ASR模型处理音频
|
# 使用在线ASR模型处理音频
|
||||||
rec_result = self.model_asr_streaming.generate(
|
rec_result = self.model_asr_streaming.generate(
|
||||||
input=audio_in,
|
input=audio_in, **websocket.status_dict_asr_online
|
||||||
**websocket.status_dict_asr_online
|
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理)
|
# 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理)
|
||||||
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
|
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
|
||||||
|
"is_final", False
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
# 如果识别出文本,发送到客户端
|
# 如果识别出文本,发送到客户端
|
||||||
if len(rec_result["text"]):
|
if len(rec_result["text"]):
|
||||||
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
|
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
|
||||||
message = json.dumps({
|
message = json.dumps(
|
||||||
"mode": mode,
|
{
|
||||||
"text": rec_result["text"],
|
"mode": mode,
|
||||||
"wav_name": websocket.wav_name,
|
"text": rec_result["text"],
|
||||||
"is_final": websocket.is_speaking,
|
"wav_name": websocket.wav_name,
|
||||||
})
|
"is_final": websocket.is_speaking,
|
||||||
await websocket.send(message)
|
}
|
||||||
|
)
|
||||||
|
await websocket.send(message)
|
||||||
|
3
src/utils/__init__.py
Normal file
3
src/utils/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .logger import get_module_logger, setup_root_logger
|
||||||
|
|
||||||
|
__all__ = ["get_module_logger", "setup_root_logger"]
|
122
src/utils/data_format.py
Normal file
122
src/utils/data_format.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
处理各类音频数据与bytes的转换
|
||||||
|
"""
|
||||||
|
|
||||||
|
import wave
|
||||||
|
from pydub import AudioSegment
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
def wav_to_bytes(wav_path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
将WAV文件读取为bytes数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
wav_path (str): WAV文件的路径。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bytes: WAV文件的原始字节数据。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 如果WAV文件不存在。
|
||||||
|
wave.Error: 如果文件不是有效的WAV文件。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with wave.open(wav_path, "rb") as wf:
|
||||||
|
# 读取所有帧
|
||||||
|
frames = wf.readframes(wf.getnframes())
|
||||||
|
return frames
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
|
||||||
|
raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
|
||||||
|
except wave.Error as e:
|
||||||
|
raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_wav(
|
||||||
|
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
将bytes数据写入为WAV文件。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
bytes_data (bytes): 音频的字节数据。
|
||||||
|
wav_path (str): 保存WAV文件的路径。
|
||||||
|
nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)。
|
||||||
|
sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)。
|
||||||
|
framerate (int): 采样率 (例如 44100, 16000)。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
wave.Error: 如果写入WAV文件失败。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with wave.open(wav_path, "wb") as wf:
|
||||||
|
wf.setnchannels(nchannels)
|
||||||
|
wf.setsampwidth(sampwidth)
|
||||||
|
wf.setframerate(framerate)
|
||||||
|
wf.writeframes(bytes_data)
|
||||||
|
except wave.Error as e:
|
||||||
|
raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
|
||||||
|
except Exception as e:
|
||||||
|
# 捕获其他可能的写入错误
|
||||||
|
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def mp3_to_bytes(mp3_path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
将MP3文件转换为bytes数据 (原始PCM数据)。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
mp3_path (str): MP3文件的路径。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bytes: MP3文件解码后的原始PCM字节数据。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 如果MP3文件不存在。
|
||||||
|
pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
audio = AudioSegment.from_mp3(mp3_path)
|
||||||
|
# 获取原始PCM数据
|
||||||
|
return audio.raw_data
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
|
||||||
|
except Exception as e: # pydub 可能抛出多种解码相关的错误
|
||||||
|
raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_mp3(
|
||||||
|
bytes_data: bytes,
|
||||||
|
mp3_path: str,
|
||||||
|
frame_rate: int,
|
||||||
|
channels: int,
|
||||||
|
sample_width: int,
|
||||||
|
bitrate: str = "192k",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
将原始PCM bytes数据转换为MP3文件。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
bytes_data (bytes): 原始PCM字节数据。
|
||||||
|
mp3_path (str): 保存MP3文件的路径。
|
||||||
|
frame_rate (int): 原始PCM数据的采样率 (例如 44100)。
|
||||||
|
channels (int): 原始PCM数据的声道数 (例如 1 for mono, 2 for stereo)。
|
||||||
|
sample_width (int): 原始PCM数据的采样宽度 (字节数, 例如 2 for 16-bit)。
|
||||||
|
bitrate (str): MP3编码的比特率 (例如 "128k", "192k", "320k")。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
Exception: 如果转换或写入MP3文件失败。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从原始数据创建AudioSegment对象
|
||||||
|
audio = AudioSegment(
|
||||||
|
data=bytes_data,
|
||||||
|
sample_width=sample_width,
|
||||||
|
frame_rate=frame_rate,
|
||||||
|
channels=channels,
|
||||||
|
)
|
||||||
|
# 导出为MP3
|
||||||
|
audio.export(mp3_path, format="mp3", bitrate=bitrate)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
|
91
src/utils/logger.py
Normal file
91
src/utils/logger.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(
|
||||||
|
name: str = None,
|
||||||
|
level: str = "INFO",
|
||||||
|
log_file: Optional[str] = None,
|
||||||
|
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
date_format: str = "%Y-%m-%d %H:%M:%S",
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
设置并返回一个配置好的logger实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: logger的名称,默认为None(使用root logger)
|
||||||
|
level: 日志级别,默认为"INFO"
|
||||||
|
log_file: 日志文件路径,默认为None(仅控制台输出)
|
||||||
|
log_format: 日志格式
|
||||||
|
date_format: 日期格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logging.Logger: 配置好的logger实例
|
||||||
|
"""
|
||||||
|
# 获取logger实例
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
|
||||||
|
# 设置日志级别
|
||||||
|
level = getattr(logging, level.upper())
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
|
||||||
|
# 创建格式器
|
||||||
|
formatter = logging.Formatter(log_format, date_format)
|
||||||
|
|
||||||
|
# 添加控制台处理器
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# 如果指定了日志文件,添加文件处理器
|
||||||
|
if log_file:
|
||||||
|
# 确保日志目录存在
|
||||||
|
log_path = Path(log_file)
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
# 注意:移除了 propagate = False,允许日志传递
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
配置根日志器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: 日志级别
|
||||||
|
log_file: 日志文件路径
|
||||||
|
"""
|
||||||
|
setup_logger(None, level, log_file)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_logger(
|
||||||
|
module_name: str,
|
||||||
|
level: Optional[str] = None, # 改为可选参数
|
||||||
|
log_file: Optional[str] = None, # 一般不需要单独指定
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
获取模块级别的logger
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_name: 模块名称,通常传入__name__
|
||||||
|
level: 可选的日志级别,如果不指定则继承父级配置
|
||||||
|
log_file: 可选的日志文件路径,一般不需要指定
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(module_name)
|
||||||
|
|
||||||
|
# 只有显式指定了level才设置
|
||||||
|
if level:
|
||||||
|
logger.setLevel(getattr(logging, level.upper()))
|
||||||
|
|
||||||
|
# 只有显式指定了log_file才添加文件处理器
|
||||||
|
if log_file:
|
||||||
|
setup_logger(module_name, level or "INFO", log_file)
|
||||||
|
|
||||||
|
return logger
|
17
test_main.py
Normal file
17
test_main.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
测试主函数
|
||||||
|
请在tests目录下创建测试文件, 并在此文件中调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
from tests.pipeline.asr_test import test_asr_pipeline
|
||||||
|
from src.utils.logger import get_module_logger, setup_root_logger
|
||||||
|
|
||||||
|
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
# from tests.functor.vad_test import test_vad_functor
|
||||||
|
# logger.info("开始测试VAD函数器")
|
||||||
|
# test_vad_functor()
|
||||||
|
|
||||||
|
logger.info("开始测试ASR管道")
|
||||||
|
test_asr_pipeline()
|
@ -1 +1 @@
|
|||||||
"""FunASR WebSocket服务测试模块"""
|
"""FunASR WebSocket服务测试模块"""
|
||||||
|
124
tests/functor/vad_test.py
Normal file
124
tests/functor/vad_test.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
Functor测试
|
||||||
|
VAD测试
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor.vad_functor import VADFunctor
|
||||||
|
from src.functor.asr_functor import ASRFunctor
|
||||||
|
from src.functor.spk_functor import SPKFunctor
|
||||||
|
from queue import Queue, Empty
|
||||||
|
from src.model_loader import ModelLoader
|
||||||
|
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||||
|
from src.utils.data_format import wav_to_bytes
|
||||||
|
import time
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
from pydub import AudioSegment
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
# 观察参数
|
||||||
|
OVERWATCH = False
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
|
def test_vad_functor():
|
||||||
|
# 加载模型
|
||||||
|
args = {
|
||||||
|
"asr_model": "paraformer-zh",
|
||||||
|
"asr_model_revision": "v2.0.4",
|
||||||
|
"vad_model": "fsmn-vad",
|
||||||
|
"vad_model_revision": "v2.0.4",
|
||||||
|
"auto_update": False,
|
||||||
|
}
|
||||||
|
model_loader.load_models(args)
|
||||||
|
# 加载数据
|
||||||
|
f_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
chunk_size=200,
|
||||||
|
chunk_stride=1600,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_width=16,
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||||
|
audio_config.chunk_stride = chunk_stride
|
||||||
|
# 创建输入队列
|
||||||
|
input_queue = Queue()
|
||||||
|
vad2asr_queue = Queue()
|
||||||
|
vad2spk_queue = Queue()
|
||||||
|
# 创建音频数据列表
|
||||||
|
audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
|
# 创建VAD函数器
|
||||||
|
vad_functor = VADFunctor()
|
||||||
|
# 设置输入队列
|
||||||
|
vad_functor.set_input_queue(input_queue)
|
||||||
|
# 设置音频配置
|
||||||
|
vad_functor.set_audio_config(audio_config)
|
||||||
|
# 设置音频数据列表
|
||||||
|
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
|
||||||
|
# 设置回调函数
|
||||||
|
vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
|
||||||
|
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
|
||||||
|
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
|
||||||
|
# 设置模型
|
||||||
|
vad_functor.set_model({"vad": model_loader.models["vad"]})
|
||||||
|
# 启动VAD函数器
|
||||||
|
vad_functor.run()
|
||||||
|
|
||||||
|
# 创建ASR函数器
|
||||||
|
asr_functor = ASRFunctor()
|
||||||
|
# 设置输入队列
|
||||||
|
asr_functor.set_input_queue(vad2asr_queue)
|
||||||
|
# 设置音频配置
|
||||||
|
asr_functor.set_audio_config(audio_config)
|
||||||
|
# 设置回调函数
|
||||||
|
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
|
||||||
|
# 设置模型
|
||||||
|
asr_functor.set_model({"asr": model_loader.models["asr"]})
|
||||||
|
# 启动ASR函数器
|
||||||
|
asr_functor.run()
|
||||||
|
|
||||||
|
# 创建SPK函数器
|
||||||
|
spk_functor = SPKFunctor()
|
||||||
|
# 设置输入队列
|
||||||
|
spk_functor.set_input_queue(vad2spk_queue)
|
||||||
|
# 设置音频配置
|
||||||
|
spk_functor.set_audio_config(audio_config)
|
||||||
|
# 设置回调函数
|
||||||
|
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
|
||||||
|
# 设置模型
|
||||||
|
spk_functor.set_model(
|
||||||
|
{
|
||||||
|
# 'spk': model_loader.models['spk']
|
||||||
|
"spk": "fake_spk"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# 启动SPK函数器
|
||||||
|
spk_functor.run()
|
||||||
|
|
||||||
|
f_binary = f_data
|
||||||
|
audio_clip_len = 200
|
||||||
|
print(
|
||||||
|
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
|
||||||
|
)
|
||||||
|
for i in range(0, len(f_binary), audio_clip_len):
|
||||||
|
binary_data = f_binary[i : i + audio_clip_len]
|
||||||
|
input_queue.put(binary_data)
|
||||||
|
# 等待VAD函数器结束
|
||||||
|
|
||||||
|
vad_functor.stop()
|
||||||
|
print("[vad_test] VAD函数器结束")
|
||||||
|
|
||||||
|
asr_functor.stop()
|
||||||
|
print("[vad_test] ASR函数器结束")
|
||||||
|
|
||||||
|
# 保存音频数据
|
||||||
|
if OVERWATCH:
|
||||||
|
for index in range(len(audio_binary_data_list)):
|
||||||
|
save_path = f"tests/vad_test_output_{index}.wav"
|
||||||
|
soundfile.write(
|
||||||
|
save_path, audio_binary_data_list[index].binary_data, sample_rate
|
||||||
|
)
|
121
tests/modelsuse.py
Normal file
121
tests/modelsuse.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
模型使用测试
|
||||||
|
此处主要用于各类调用模型的处理数据与输出格式
|
||||||
|
请在主目录下test_main.py中调用
|
||||||
|
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from funasr import AutoModel
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from src.models import VADResponse
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
在线VAD模型使用
|
||||||
|
"""
|
||||||
|
chunk_size = 100 # ms
|
||||||
|
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||||
|
|
||||||
|
vad_result = VADResponse()
|
||||||
|
vad_result.time_chunk_index_callback = lambda index: print(f"回调: {index}")
|
||||||
|
items = []
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
|
||||||
|
for i in range(total_chunk_num):
|
||||||
|
time.sleep(0.1)
|
||||||
|
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||||
|
is_final = i == total_chunk_num - 1
|
||||||
|
res = model.generate(
|
||||||
|
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
if len(res[0]["value"]):
|
||||||
|
vad_result += VADResponse.from_raw(res)
|
||||||
|
for item in res[0]["value"]:
|
||||||
|
items.append(item)
|
||||||
|
vad_result.process_time_chunk()
|
||||||
|
|
||||||
|
# for item in items:
|
||||||
|
# print(item)
|
||||||
|
return vad_result
|
||||||
|
|
||||||
|
|
||||||
|
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
在线VAD模型使用
|
||||||
|
测试LogicTrager
|
||||||
|
在Rebuild版本后LogicTrager中已弃用
|
||||||
|
"""
|
||||||
|
from src.logic_trager import LogicTrager
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from src.config import parse_args
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# from src.functor.model_loader import load_models
|
||||||
|
# models = load_models(args)
|
||||||
|
from src.model_loader import ModelLoader
|
||||||
|
|
||||||
|
models = ModelLoader(args)
|
||||||
|
|
||||||
|
chunk_size = 200 # ms
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
logic_trager = LogicTrager(models=models, audio_config=audio_config)
|
||||||
|
for i in range(len(speech) // chunk_stride + 1):
|
||||||
|
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||||
|
logic_trager.push_binary_data(speech_chunk)
|
||||||
|
|
||||||
|
# for item in items:
|
||||||
|
# print(item)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
ASR模型使用
|
||||||
|
离线ASR模型使用
|
||||||
|
"""
|
||||||
|
from funasr import AutoModel
|
||||||
|
|
||||||
|
model = AutoModel(
|
||||||
|
model="paraformer-zh",
|
||||||
|
model_revision="v2.0.4",
|
||||||
|
vad_model="fsmn-vad",
|
||||||
|
vad_model_revision="v2.0.4",
|
||||||
|
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
|
||||||
|
spk_model="cam++",
|
||||||
|
spk_model_revision="v2.0.2",
|
||||||
|
spk_mode="vad_segment",
|
||||||
|
auto_update=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
result = model.generate(speech)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# 请在主目录下调用test_main.py文件进行测试
|
||||||
|
# vad_result = vad_model_use_online("tests/vad_example.wav")
|
||||||
|
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
|
||||||
|
# print(vad_result)
|
93
tests/pipeline/asr_test.py
Normal file
93
tests/pipeline/asr_test.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
"""
|
||||||
|
Pipeline测试
|
||||||
|
VAD+ASR+SPK(FAKE)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
from src.pipeline import PipelineFactory
|
||||||
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||||
|
from src.model_loader import ModelLoader
|
||||||
|
from queue import Queue
|
||||||
|
import soundfile
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
OVAERWATCH = False
|
||||||
|
|
||||||
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
|
def test_asr_pipeline():
|
||||||
|
# 加载模型
|
||||||
|
args = {
|
||||||
|
"asr_model": "paraformer-zh",
|
||||||
|
"asr_model_revision": "v2.0.4",
|
||||||
|
"vad_model": "fsmn-vad",
|
||||||
|
"vad_model_revision": "v2.0.4",
|
||||||
|
"spk_model": "cam++",
|
||||||
|
"spk_model_revision": "v2.0.2",
|
||||||
|
"audio_update": False,
|
||||||
|
}
|
||||||
|
models = model_loader.load_models(args)
|
||||||
|
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
chunk_size=200,
|
||||||
|
chunk_stride=1600,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_width=16,
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||||
|
audio_config.chunk_stride = chunk_stride
|
||||||
|
|
||||||
|
# 创建参数Dict
|
||||||
|
config = {
|
||||||
|
"audio_config": audio_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建音频数据列表
|
||||||
|
audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
|
input_queue = Queue()
|
||||||
|
|
||||||
|
# 创建Pipeline
|
||||||
|
# asr_pipeline = ASRPipeline()
|
||||||
|
# asr_pipeline.set_models(models)
|
||||||
|
# asr_pipeline.set_config(config)
|
||||||
|
# asr_pipeline.set_audio_binary(audio_binary_data_list)
|
||||||
|
# asr_pipeline.set_input_queue(input_queue)
|
||||||
|
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
|
||||||
|
# asr_pipeline.bake()
|
||||||
|
asr_pipeline = PipelineFactory.create_pipeline(
|
||||||
|
pipeline_name = "ASRpipeline",
|
||||||
|
models=models,
|
||||||
|
config=config,
|
||||||
|
audio_binary=audio_binary_data_list,
|
||||||
|
input_queue=input_queue,
|
||||||
|
callback=lambda x: print(f"pipeline callback: {x}")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 运行Pipeline
|
||||||
|
asr_instance = asr_pipeline.run()
|
||||||
|
|
||||||
|
|
||||||
|
audio_clip_len = 200
|
||||||
|
print(
|
||||||
|
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
|
||||||
|
)
|
||||||
|
for i in range(0, len(audio_data), audio_clip_len):
|
||||||
|
input_queue.put(audio_data[i : i + audio_clip_len])
|
||||||
|
|
||||||
|
# time.sleep(10)
|
||||||
|
# input_queue.put(None)
|
||||||
|
|
||||||
|
# 等待Pipeline结束
|
||||||
|
# asr_instance.join()
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
asr_pipeline.stop()
|
||||||
|
# asr_pipeline.stop()
|
@ -10,23 +10,23 @@ import os
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
# 将src目录添加到路径
|
# 将src目录添加到路径
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from src.config import parse_args
|
from src.config import parse_args
|
||||||
|
|
||||||
|
|
||||||
def test_default_args():
|
def test_default_args():
|
||||||
"""测试默认参数值"""
|
"""测试默认参数值"""
|
||||||
with patch('sys.argv', ['script.py']):
|
with patch("sys.argv", ["script.py"]):
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 检查服务器参数
|
# 检查服务器参数
|
||||||
assert args.host == "0.0.0.0"
|
assert args.host == "0.0.0.0"
|
||||||
assert args.port == 10095
|
assert args.port == 10095
|
||||||
|
|
||||||
# 检查SSL参数
|
# 检查SSL参数
|
||||||
assert args.certfile == ""
|
assert args.certfile == ""
|
||||||
assert args.keyfile == ""
|
assert args.keyfile == ""
|
||||||
|
|
||||||
# 检查模型参数
|
# 检查模型参数
|
||||||
assert "paraformer" in args.asr_model
|
assert "paraformer" in args.asr_model
|
||||||
assert args.asr_model_revision == "v2.0.4"
|
assert args.asr_model_revision == "v2.0.4"
|
||||||
@ -36,7 +36,7 @@ def test_default_args():
|
|||||||
assert args.vad_model_revision == "v2.0.4"
|
assert args.vad_model_revision == "v2.0.4"
|
||||||
assert "punc" in args.punc_model
|
assert "punc" in args.punc_model
|
||||||
assert args.punc_model_revision == "v2.0.4"
|
assert args.punc_model_revision == "v2.0.4"
|
||||||
|
|
||||||
# 检查硬件配置
|
# 检查硬件配置
|
||||||
assert args.ngpu == 1
|
assert args.ngpu == 1
|
||||||
assert args.device == "cuda"
|
assert args.device == "cuda"
|
||||||
@ -46,19 +46,26 @@ def test_default_args():
|
|||||||
def test_custom_args():
|
def test_custom_args():
|
||||||
"""测试自定义参数值"""
|
"""测试自定义参数值"""
|
||||||
test_args = [
|
test_args = [
|
||||||
'script.py',
|
"script.py",
|
||||||
'--host', 'localhost',
|
"--host",
|
||||||
'--port', '8080',
|
"localhost",
|
||||||
'--certfile', 'cert.pem',
|
"--port",
|
||||||
'--keyfile', 'key.pem',
|
"8080",
|
||||||
'--asr_model', 'custom_model',
|
"--certfile",
|
||||||
'--ngpu', '0',
|
"cert.pem",
|
||||||
'--device', 'cpu'
|
"--keyfile",
|
||||||
|
"key.pem",
|
||||||
|
"--asr_model",
|
||||||
|
"custom_model",
|
||||||
|
"--ngpu",
|
||||||
|
"0",
|
||||||
|
"--device",
|
||||||
|
"cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch('sys.argv', test_args):
|
with patch("sys.argv", test_args):
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 检查自定义参数
|
# 检查自定义参数
|
||||||
assert args.host == "localhost"
|
assert args.host == "localhost"
|
||||||
assert args.port == 8080
|
assert args.port == 8080
|
||||||
@ -66,4 +73,4 @@ def test_custom_args():
|
|||||||
assert args.keyfile == "key.pem"
|
assert args.keyfile == "key.pem"
|
||||||
assert args.asr_model == "custom_model"
|
assert args.asr_model == "custom_model"
|
||||||
assert args.ngpu == 0
|
assert args.ngpu == 0
|
||||||
assert args.device == "cpu"
|
assert args.device == "cpu"
|
||||||
|
BIN
tests/vad_example.wav
Normal file
BIN
tests/vad_example.wav
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user