Compare commits
16 Commits
master
...
feature_re
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5a820b49e4 | ||
![]() |
5b94c40016 | ||
![]() |
3d8bf9de25 | ||
![]() |
ff9bd70039 | ||
![]() |
4e9e94d8dc | ||
![]() |
b569b7e63d | ||
![]() |
f245c6e9df | ||
![]() |
49cb428c23 | ||
![]() |
703a40e955 | ||
![]() |
040fc57e02 | ||
![]() |
1392168126 | ||
![]() |
eff22cb33e | ||
![]() |
66c9477e4b | ||
9d522fa137 | |||
f7138dcb39 | |||
8b69ff195f |
19
.cursorrules
Normal file
19
.cursorrules
Normal file
@ -0,0 +1,19 @@
|
||||
|
||||
You are an AI assistant specialized in Python development. Your approach emphasizes:
|
||||
|
||||
1. Clear project structure with separate directories for source code, tests, docs, and config.
|
||||
2. Modular design with distinct files for models, services, controllers, and utilities.
|
||||
3. Configuration management using environment variables.
|
||||
4. Robust error handling and logging, including context capture.
|
||||
5. Comprehensive testing with pytest.
|
||||
6. Detailed documentation using docstrings and README files.
|
||||
7. Dependency management via https://github.com/astral-sh/rye and virtual environments.
|
||||
8. Code style consistency using Ruff.
|
||||
9. CI/CD implementation with GitHub Actions or GitLab CI.
|
||||
10. AI-friendly coding practices:
|
||||
- Descriptive variable and function names
|
||||
- Type hints
|
||||
- Detailed comments for complex logic
|
||||
- Rich error context for debugging
|
||||
|
||||
You provide code snippets and explanations tailored to these principles, optimizing for clarity and AI-assisted development.
|
@ -108,6 +108,3 @@ docker-compose up -d
|
||||
"is_final": false // 是否是最终结果
|
||||
}
|
||||
```
|
||||
|
||||
## 许可证
|
||||
[MIT](LICENSE)
|
30
main.py
Normal file
30
main.py
Normal file
@ -0,0 +1,30 @@
|
||||
from funasr import AutoModel
|
||||
|
||||
chunk_size = 200 # ms
|
||||
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||
|
||||
import soundfile
|
||||
|
||||
wav_file = "tests/vad_example.wav"
|
||||
speech, sample_rate = soundfile.read(wav_file)
|
||||
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||
|
||||
cache = {}
|
||||
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
|
||||
for i in range(total_chunk_num):
|
||||
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||
is_final = i == total_chunk_num - 1
|
||||
res = model.generate(
|
||||
input=speech_chunk,
|
||||
cache=cache,
|
||||
is_final=is_final,
|
||||
chunk_size=chunk_size,
|
||||
disable_pbar=True,
|
||||
)
|
||||
if len(res[0]["value"]):
|
||||
print(res)
|
||||
|
||||
print(f"len(speech): {len(speech)}")
|
||||
print(f"len(speech_chunk): {len(speech_chunk)}")
|
||||
print(f"total_chunk_num: {total_chunk_num}")
|
||||
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")
|
@ -11,4 +11,4 @@ FunASR WebSocket服务
|
||||
- 支持多种识别模式(2pass/online/offline)
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__version__ = "0.1.0"
|
||||
|
194
src/audio_chunk.py
Normal file
194
src/audio_chunk.py
Normal file
@ -0,0 +1,194 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
音频数据块管理类 - 用于存储和处理16KHz音频数据
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||
|
||||
# 配置日志
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__, level="INFO")
|
||||
|
||||
|
||||
class AudioBinary:
|
||||
"""
|
||||
音频数据存储单元
|
||||
用于存储二进制数据
|
||||
面向Slice, 向Slice提供数据与接口
|
||||
|
||||
self._audio_config: AudioBinary_Config -- 音频参数配置
|
||||
self._binary_data_list: AudioBinary_data_list -- 音频数据列表
|
||||
self._slice_listener: List[callable] -- 切片监听器
|
||||
|
||||
AudioBinary_Config: Dict -- 音频参数配置
|
||||
AudioBinary_data_list: List[bytes] -- 音频数据列表
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
"""
|
||||
初始化音频数据块
|
||||
参数:
|
||||
*args: 可变参数
|
||||
"""
|
||||
# 音频参数配置
|
||||
self._audio_config = AudioBinary_Config()
|
||||
# 音频片段
|
||||
self._binary_data_list: AudioBinary_data_list = AudioBinary_data_list()
|
||||
# 切片监听器
|
||||
self._slice_listener: List = []
|
||||
if isinstance(args, Dict):
|
||||
self._audio_config = AudioBinary_Config.AudioBinary_Config_from_dict(args)
|
||||
elif isinstance(args, AudioBinary_Config):
|
||||
self._audio_config = args
|
||||
else:
|
||||
raise ValueError("参数类型错误")
|
||||
|
||||
def add_slice_listener(self, slice_listener: callable) -> None:
|
||||
"""
|
||||
添加切片监听器
|
||||
参数:
|
||||
slice_listener: callable -- 切片监听器
|
||||
"""
|
||||
self._slice_listener.append(slice_listener)
|
||||
|
||||
def __add__(self, other: bytes):
|
||||
"""
|
||||
__add__ 是 "+" 运算符的重载,
|
||||
使用方法:
|
||||
audio_binary = audio_binary + bytes
|
||||
添加音频数据块 与 add_binary_data 等效,
|
||||
但可以链式调用, 方便使用
|
||||
参数:
|
||||
other: bytes --音频数据块
|
||||
"""
|
||||
self._binary_data_list.append(other)
|
||||
return self
|
||||
|
||||
def __iadd__(self, other: bytes):
|
||||
"""
|
||||
__iadd__ 是 "+=" 运算符的重载,
|
||||
使用方法:
|
||||
audio_binary += bytes
|
||||
添加音频数据块 与 add_binary_data 等效,
|
||||
但可以链式调用, 方便使用
|
||||
参数:
|
||||
other: bytes --音频数据块
|
||||
"""
|
||||
self._binary_data_list.append(other)
|
||||
return self
|
||||
|
||||
def add_binary_data(self, binary_data: bytes):
|
||||
"""
|
||||
添加音频数据块
|
||||
参数:
|
||||
binary_data: bytes --音频数据块
|
||||
"""
|
||||
self._binary_data_list.append(binary_data)
|
||||
|
||||
def rewrite_binary_data(self, target_index: int, binary_data: bytes):
|
||||
"""
|
||||
重写音频数据块
|
||||
参数:
|
||||
target_index: int -- 目标索引
|
||||
binary_data: bytes --音频数据块
|
||||
"""
|
||||
self._binary_data_list.rewrite(target_index, binary_data)
|
||||
|
||||
def get_binary_data(
|
||||
self,
|
||||
start: int = 0,
|
||||
end: Optional[int] = None,
|
||||
) -> Optional[bytes]:
|
||||
"""
|
||||
获取指定索引的音频数据块
|
||||
参数:
|
||||
start: 开始索引
|
||||
end: 结束索引
|
||||
返回:
|
||||
List[bytes]: 音频数据块
|
||||
"""
|
||||
if start >= len(self._binary_data_list):
|
||||
return None
|
||||
if end is None:
|
||||
end = start + 1
|
||||
end = min(end, len(self._binary_data_list))
|
||||
return self._binary_data_list[start:end]
|
||||
|
||||
|
||||
class AudioChunk:
|
||||
"""
|
||||
音频数据块管理类
|
||||
管理两部分内容, AudioBinary和Slice。
|
||||
AudioBinary用于内部存储字节数据。
|
||||
Slice是AudioBinary的切片,用于外部接口。
|
||||
|
||||
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。
|
||||
"""
|
||||
|
||||
_instance: Optional[AudioChunk] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
单例模式
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
初始化AudioChunk实例
|
||||
"""
|
||||
self._audio_binary_list: Dict[str, AudioBinary] = {}
|
||||
self._slice_listener: List[callable] = []
|
||||
|
||||
def get_audio_binary(
|
||||
self,
|
||||
binary_name: Optional[str] = None,
|
||||
audio_config: Optional[AudioBinary_Config] = None,
|
||||
) -> AudioBinary:
|
||||
"""
|
||||
获取音频数据块
|
||||
参数:
|
||||
binary_name: str -- 音频数据块名称
|
||||
返回:
|
||||
AudioBinary: 音频数据块
|
||||
"""
|
||||
if binary_name is None:
|
||||
binary_name = "default"
|
||||
if binary_name not in self._audio_binary_list:
|
||||
self._audio_binary_list[binary_name] = AudioBinary(audio_config)
|
||||
return self._audio_binary_list[binary_name]
|
||||
|
||||
@staticmethod
|
||||
def _time2size(time_ms: int, audio_config: AudioBinary_Config) -> int:
|
||||
"""
|
||||
将时间(ms)转换为数据大小(字节)
|
||||
参数:
|
||||
time_ms: int -- 时间(ms)
|
||||
audio_config: AudioBinary_Config -- 音频参数配置
|
||||
返回:
|
||||
int: 数据大小(字节)
|
||||
"""
|
||||
# 时间(ms)到字节(bytes)计算方法为: 时间(ms) * 采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24) / 1000
|
||||
time_s = time_ms / 1000
|
||||
bytes_per_sample = audio_config.sample_width * audio_config.channel
|
||||
return int(time_s * audio_config.sample_rate * bytes_per_sample)
|
||||
|
||||
@staticmethod
|
||||
def _size2time(size: int, audio_config: AudioBinary_Config) -> int:
|
||||
"""
|
||||
将数据大小(字节)转换为时间(ms)
|
||||
参数:
|
||||
size: int -- 数据大小(字节)
|
||||
audio_config: AudioBinary_Config -- 音频参数配置
|
||||
返回:
|
||||
int: 时间(ms)
|
||||
"""
|
||||
# 字节(bytes)到时间(ms)计算方法为: 字节(bytes) * 1000 / (采样率(Hz) * 通道数(1 or 2) * 采样位宽(16 or 24))
|
||||
bytes_per_sample = audio_config.sample_width * audio_config.channel
|
||||
time_ms = size * 1000 // (audio_config.sample_rate * bytes_per_sample)
|
||||
return time_ms
|
196
src/client.py
196
src/client.py
@ -1,196 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
WebSocket客户端示例 - 用于测试语音识别服务
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import websockets
|
||||
import argparse
|
||||
import numpy as np
|
||||
import wave
|
||||
import os
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""解析命令行参数"""
|
||||
parser = argparse.ArgumentParser(description="FunASR WebSocket客户端")
|
||||
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="服务器主机地址"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=10095,
|
||||
help="服务器端口"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--audio_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="要识别的音频文件路径"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="2pass",
|
||||
choices=["2pass", "online", "offline"],
|
||||
help="识别模式: 2pass(默认), online, offline"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--chunk_size",
|
||||
type=str,
|
||||
default="5,10",
|
||||
help="分块大小, 格式为'encoder_size,decoder_size'"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_ssl",
|
||||
action="store_true",
|
||||
help="是否使用SSL连接"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def send_audio(websocket, audio_file, mode, chunk_size):
|
||||
"""
|
||||
发送音频文件到服务器进行识别
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接
|
||||
audio_file: 音频文件路径
|
||||
mode: 识别模式
|
||||
chunk_size: 分块大小
|
||||
"""
|
||||
# 打开并读取WAV文件
|
||||
with wave.open(audio_file, "rb") as wav_file:
|
||||
params = wav_file.getparams()
|
||||
frames = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
# 音频文件信息
|
||||
print(f"音频文件: {os.path.basename(audio_file)}")
|
||||
print(f"采样率: {params.framerate}Hz, 通道数: {params.nchannels}")
|
||||
print(f"采样位深: {params.sampwidth * 8}位, 总帧数: {params.nframes}")
|
||||
|
||||
# 设置配置参数
|
||||
config = {
|
||||
"mode": mode,
|
||||
"chunk_size": chunk_size,
|
||||
"wav_name": os.path.basename(audio_file),
|
||||
"is_speaking": True
|
||||
}
|
||||
|
||||
# 发送配置
|
||||
await websocket.send(json.dumps(config))
|
||||
|
||||
# 模拟实时发送音频数据
|
||||
chunk_size_bytes = 3200 # 每次发送100ms的16kHz音频
|
||||
total_chunks = len(frames) // chunk_size_bytes
|
||||
|
||||
print(f"开始发送音频数据,共 {total_chunks} 个数据块...")
|
||||
|
||||
try:
|
||||
for i in range(0, len(frames), chunk_size_bytes):
|
||||
chunk = frames[i:i+chunk_size_bytes]
|
||||
await websocket.send(chunk)
|
||||
|
||||
# 模拟实时,每100ms发送一次
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 显示进度
|
||||
if (i // chunk_size_bytes) % 10 == 0:
|
||||
print(f"已发送 {i // chunk_size_bytes}/{total_chunks} 数据块")
|
||||
|
||||
# 发送结束信号
|
||||
await websocket.send(json.dumps({"is_speaking": False}))
|
||||
print("音频数据发送完成")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发送音频时出错: {e}")
|
||||
|
||||
|
||||
async def receive_results(websocket):
|
||||
"""
|
||||
接收并显示识别结果
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接
|
||||
"""
|
||||
online_text = ""
|
||||
offline_text = ""
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
# 解析服务器返回的JSON消息
|
||||
result = json.loads(message)
|
||||
|
||||
mode = result.get("mode", "")
|
||||
text = result.get("text", "")
|
||||
is_final = result.get("is_final", False)
|
||||
|
||||
# 根据模式更新文本
|
||||
if "online" in mode:
|
||||
online_text = text
|
||||
print(f"\r[在线识别] {online_text}", end="", flush=True)
|
||||
elif "offline" in mode:
|
||||
offline_text = text
|
||||
print(f"\n[离线识别] {offline_text}")
|
||||
|
||||
# 如果是最终结果,打印完整信息
|
||||
if is_final and offline_text:
|
||||
print("\n最终识别结果:")
|
||||
print(f"[离线识别] {offline_text}")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
print(f"接收结果时出错: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
args = parse_args()
|
||||
|
||||
# WebSocket URI
|
||||
protocol = "wss" if args.use_ssl else "ws"
|
||||
uri = f"{protocol}://{args.host}:{args.port}"
|
||||
|
||||
print(f"连接到服务器: {uri}")
|
||||
|
||||
try:
|
||||
# 创建WebSocket连接
|
||||
async with websockets.connect(
|
||||
uri,
|
||||
subprotocols=["binary"]
|
||||
) as websocket:
|
||||
|
||||
print("连接成功")
|
||||
|
||||
# 创建两个任务: 发送音频和接收结果
|
||||
send_task = asyncio.create_task(
|
||||
send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
|
||||
)
|
||||
|
||||
receive_task = asyncio.create_task(
|
||||
receive_results(websocket)
|
||||
)
|
||||
|
||||
# 等待任务完成
|
||||
await asyncio.gather(send_task, receive_task)
|
||||
|
||||
except Exception as e:
|
||||
print(f"连接服务器失败: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行主函数
|
||||
asyncio.run(main())
|
@ -10,116 +10,79 @@ import argparse
|
||||
def parse_args():
|
||||
"""
|
||||
解析命令行参数
|
||||
|
||||
|
||||
返回:
|
||||
argparse.Namespace: 解析后的参数对象
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")
|
||||
|
||||
|
||||
# 服务器配置
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="服务器主机地址,例如:localhost, 0.0.0.0"
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="服务器主机地址,例如:localhost, 0.0.0.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=10095,
|
||||
help="WebSocket服务器端口"
|
||||
)
|
||||
|
||||
parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")
|
||||
|
||||
# SSL配置
|
||||
parser.add_argument(
|
||||
"--certfile",
|
||||
type=str,
|
||||
default="",
|
||||
help="SSL证书文件路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
default="",
|
||||
help="SSL密钥文件路径"
|
||||
)
|
||||
|
||||
parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
|
||||
parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")
|
||||
|
||||
# ASR模型配置
|
||||
parser.add_argument(
|
||||
"--asr_model",
|
||||
type=str,
|
||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
help="离线ASR模型(从ModelScope获取)"
|
||||
help="离线ASR模型(从ModelScope获取)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--asr_model_revision",
|
||||
type=str,
|
||||
default="v2.0.4",
|
||||
help="离线ASR模型版本"
|
||||
"--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
|
||||
)
|
||||
|
||||
|
||||
# 在线ASR模型配置
|
||||
parser.add_argument(
|
||||
"--asr_model_online",
|
||||
type=str,
|
||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
||||
help="在线ASR模型(从ModelScope获取)"
|
||||
help="在线ASR模型(从ModelScope获取)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--asr_model_online_revision",
|
||||
type=str,
|
||||
default="v2.0.4",
|
||||
help="在线ASR模型版本"
|
||||
"--asr_model_online_revision",
|
||||
type=str,
|
||||
default="v2.0.4",
|
||||
help="在线ASR模型版本",
|
||||
)
|
||||
|
||||
|
||||
# VAD模型配置
|
||||
parser.add_argument(
|
||||
"--vad_model",
|
||||
type=str,
|
||||
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
help="VAD语音活动检测模型(从ModelScope获取)"
|
||||
help="VAD语音活动检测模型(从ModelScope获取)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vad_model_revision",
|
||||
type=str,
|
||||
default="v2.0.4",
|
||||
help="VAD模型版本"
|
||||
"--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
|
||||
)
|
||||
|
||||
|
||||
# 标点符号模型配置
|
||||
parser.add_argument(
|
||||
"--punc_model",
|
||||
type=str,
|
||||
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
|
||||
help="标点符号模型(从ModelScope获取)"
|
||||
help="标点符号模型(从ModelScope获取)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--punc_model_revision",
|
||||
type=str,
|
||||
default="v2.0.4",
|
||||
help="标点符号模型版本"
|
||||
"--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
|
||||
)
|
||||
|
||||
|
||||
# 硬件配置
|
||||
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量,0表示仅使用CPU")
|
||||
parser.add_argument(
|
||||
"--ngpu",
|
||||
type=int,
|
||||
default=1,
|
||||
help="GPU数量,0表示仅使用CPU"
|
||||
"--device", type=str, default="cuda", help="设备类型:cuda或cpu"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="设备类型:cuda或cpu"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ncpu",
|
||||
type=int,
|
||||
default=4,
|
||||
help="CPU核心数"
|
||||
)
|
||||
|
||||
parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -127,4 +90,4 @@ if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
print("配置参数:")
|
||||
for arg in vars(args):
|
||||
print(f" {arg}: {getattr(args, arg)}")
|
||||
print(f" {arg}: {getattr(args, arg)}")
|
||||
|
4
src/functor/__init__.py
Normal file
4
src/functor/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .vad_functor import VADFunctor
|
||||
from .base import FunctorFactory
|
||||
|
||||
__all__ = ["VADFunctor", "FunctorFactory"]
|
161
src/functor/asr_functor.py
Normal file
161
src/functor/asr_functor.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""
|
||||
ASRFunctor
|
||||
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||
"""
|
||||
|
||||
from src.functor.base import BaseFunctor
|
||||
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
|
||||
from typing import Callable, List
|
||||
from queue import Queue, Empty
|
||||
import threading
|
||||
|
||||
# 日志
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class ASRFunctor(BaseFunctor):
|
||||
"""
|
||||
ASRFunctor
|
||||
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||
否则无法run()启动线程
|
||||
|
||||
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||
|
||||
使用stop()停止线程, 但需要等待input_queue为空
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# 资源与配置
|
||||
self._model: dict = {} # 模型
|
||||
self._callback: List[Callable] = [] # 回调函数
|
||||
self._input_queue: Queue = None # 输入队列
|
||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||
|
||||
# flag
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
|
||||
# 线程资源
|
||||
self._thread: threading.Thread = None
|
||||
|
||||
# 状态锁
|
||||
self._status_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# 缓存
|
||||
self._hotwords: List[str] = []
|
||||
|
||||
def reset_cache(self) -> None:
|
||||
"""
|
||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_input_queue(self, queue: Queue) -> None:
|
||||
"""
|
||||
设置监听的输入消息队列
|
||||
"""
|
||||
self._input_queue = queue
|
||||
|
||||
def set_model(self, model: dict) -> None:
|
||||
"""
|
||||
设置推理模型
|
||||
"""
|
||||
self._model = model
|
||||
|
||||
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||
"""
|
||||
设置音频配置
|
||||
"""
|
||||
self._audio_config = audio_config
|
||||
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
|
||||
|
||||
def add_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||
"""
|
||||
if not isinstance(self._callback, list):
|
||||
self._callback = []
|
||||
self._callback.append(callback)
|
||||
|
||||
def _do_callback(self, result: List[str]) -> None:
|
||||
"""
|
||||
回调函数
|
||||
"""
|
||||
text = result[0]["text"].replace(" ", "")
|
||||
for callback in self._callback:
|
||||
callback(text)
|
||||
|
||||
def _process(self, data: VAD_Functor_result) -> None:
|
||||
"""
|
||||
处理数据
|
||||
"""
|
||||
binary_data = data.audiobinary_data.binary_data
|
||||
result = self._model["asr"].generate(
|
||||
input=binary_data,
|
||||
chunk_size=self._audio_config.chunk_size,
|
||||
hotwords=self._hotwords,
|
||||
)
|
||||
self._do_callback(result)
|
||||
|
||||
def _run(self) -> None:
|
||||
"""
|
||||
线程运行逻辑
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
# 运行逻辑
|
||||
while self._is_running:
|
||||
try:
|
||||
data = self._input_queue.get(True, timeout=1)
|
||||
self._process(data)
|
||||
self._input_queue.task_done()
|
||||
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||
except Empty:
|
||||
if self._stop_event:
|
||||
break
|
||||
continue
|
||||
# 其他异常
|
||||
except Exception as e:
|
||||
logger.error("ASRFunctor运行时发生错误: %s", e)
|
||||
raise e
|
||||
|
||||
def run(self) -> threading.Thread:
|
||||
"""
|
||||
启动线程
|
||||
Returns:
|
||||
threading.Thread: 返回已运行线程实例
|
||||
"""
|
||||
self._pre_check()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
return self._thread
|
||||
|
||||
def _pre_check(self) -> bool:
|
||||
"""
|
||||
预检查
|
||||
"""
|
||||
if self._model is None:
|
||||
raise ValueError("模型未设置")
|
||||
if self._audio_config is None:
|
||||
raise ValueError("音频配置未设置")
|
||||
if self._input_queue is None:
|
||||
raise ValueError("输入队列未设置")
|
||||
if self._callback is None:
|
||||
raise ValueError("回调函数未设置")
|
||||
return True
|
||||
|
||||
def stop(self) -> bool:
|
||||
"""
|
||||
停止线程
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._stop_event = True
|
||||
self._thread.join()
|
||||
with self._status_lock:
|
||||
self._is_running = False
|
||||
return not self._thread.is_alive()
|
191
src/functor/base.py
Normal file
191
src/functor/base.py
Normal file
@ -0,0 +1,191 @@
|
||||
"""
|
||||
Functor基础模块
|
||||
|
||||
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。
|
||||
基类提供了数据处理的基本框架,包括:
|
||||
- 回调函数管理
|
||||
- 模型配置管理
|
||||
- 线程运行控制
|
||||
|
||||
主要类:
|
||||
BaseFunctor: Functor抽象类
|
||||
FunctorFactory: Functor工厂类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List
|
||||
from queue import Queue
|
||||
import threading
|
||||
|
||||
|
||||
class BaseFunctor(ABC):
|
||||
"""
|
||||
Functor抽象类
|
||||
|
||||
该抽象类规定了所有的Functor类必须实现run()方法启动自身线程
|
||||
|
||||
属性:
|
||||
_callback (Callable): 处理完成后的回调函数
|
||||
_model (dict): 存储模型相关的配置和实例
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化函数器
|
||||
|
||||
参数:
|
||||
callback (Callable): 处理完成后的回调函数
|
||||
model (dict): 模型相关的配置和实例
|
||||
"""
|
||||
self._callback: List[Callable] = []
|
||||
self._model: dict = {}
|
||||
# flag
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
# 状态锁
|
||||
self._status_lock: threading.Lock = threading.Lock()
|
||||
# 线程资源
|
||||
self._thread: threading.Thread = None
|
||||
|
||||
def add_callback(self, callback: Callable):
|
||||
"""
|
||||
添加回调函数
|
||||
|
||||
参数:
|
||||
callback (Callable): 新的回调函数
|
||||
"""
|
||||
self._callback.append(callback)
|
||||
|
||||
def set_model(self, model: dict):
|
||||
"""
|
||||
设置模型配置
|
||||
|
||||
参数:
|
||||
model (dict): 新的模型配置
|
||||
"""
|
||||
self._model = model
|
||||
|
||||
def set_input_queue(self, queue: Queue):
|
||||
"""
|
||||
设置输入队列
|
||||
|
||||
参数:
|
||||
queue (Queue): 新的输入队列
|
||||
"""
|
||||
self._input_queue = queue
|
||||
|
||||
@abstractmethod
|
||||
def _run(self):
|
||||
"""
|
||||
线程运行逻辑
|
||||
|
||||
返回:
|
||||
当达到条件时触发callback
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
"""
|
||||
启动_run方法线程
|
||||
|
||||
返回:
|
||||
线程实例
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _pre_check(self):
|
||||
"""
|
||||
预检查
|
||||
|
||||
返回:
|
||||
预检查结果
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def stop(self):
|
||||
"""
|
||||
停止线程
|
||||
|
||||
返回:
|
||||
停止结果
|
||||
"""
|
||||
|
||||
|
||||
class FunctorFactory:
|
||||
"""
|
||||
Functor工厂类
|
||||
|
||||
该工厂类负责创建和配置Functor实例
|
||||
|
||||
主要方法:
|
||||
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||
创建并配置Functor实例
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建并配置Functor实例
|
||||
|
||||
参数:
|
||||
funtor_name (str): Functor名称
|
||||
config (dict): 配置信息
|
||||
models (dict): 模型信息
|
||||
|
||||
返回:
|
||||
BaseFunctor: 创建的Functor实例
|
||||
"""
|
||||
|
||||
if functor_name == "vad":
|
||||
return cls._make_vadfunctor(config=config, models=models)
|
||||
elif functor_name == "asr":
|
||||
return cls._make_asrfunctor(config=config, models=models)
|
||||
elif functor_name == "spk":
|
||||
return cls._make_spkfunctor(config=config, models=models)
|
||||
else:
|
||||
raise ValueError(f"不支持的Functor类型: {functor_name}")
|
||||
|
||||
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建VAD Functor实例
|
||||
"""
|
||||
from src.functor.vad_functor import VADFunctor
|
||||
|
||||
audio_config = config["audio_config"]
|
||||
model = {"vad": models["vad"]}
|
||||
|
||||
vad_functor = VADFunctor()
|
||||
vad_functor.set_audio_config(audio_config)
|
||||
vad_functor.set_model(model)
|
||||
|
||||
return vad_functor
|
||||
|
||||
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建ASR Functor实例
|
||||
"""
|
||||
from src.functor.asr_functor import ASRFunctor
|
||||
|
||||
audio_config = config["audio_config"]
|
||||
model = {"asr": models["asr"]}
|
||||
|
||||
asr_functor = ASRFunctor()
|
||||
asr_functor.set_audio_config(audio_config)
|
||||
asr_functor.set_model(model)
|
||||
|
||||
return asr_functor
|
||||
|
||||
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||
"""
|
||||
创建SPK Functor实例
|
||||
"""
|
||||
from src.functor.spk_functor import SPKFunctor
|
||||
|
||||
audio_config = config["audio_config"]
|
||||
model = {"spk": models["spk"]}
|
||||
|
||||
spk_functor = SPKFunctor()
|
||||
spk_functor.set_audio_config(audio_config)
|
||||
spk_functor.set_model(model)
|
||||
|
||||
return spk_functor
|
122
src/functor/readme.md
Normal file
122
src/functor/readme.md
Normal file
@ -0,0 +1,122 @@
|
||||
# 对于Functor的解释
|
||||
|
||||
## Functor 文件夹作用
|
||||
|
||||
Functor文件夹用于存放所有功能性的类,包括VAD、PUNC、ASR、SPK等。
|
||||
|
||||
## Functor 类的定义
|
||||
|
||||
所有类应继承于**基类**`BaseFunctor`。
|
||||
|
||||
为了方便使用,我们对于**基类**的定义如下:
|
||||
|
||||
1. 函数内部使用的变量以单下划线开头,基类中包含:
|
||||
|
||||
* _model: Dict 存放模型相关的配置和实例
|
||||
* _input_queue: Queue 监听的输入消息队列
|
||||
* _thread: Threading.Thread 运行的线程实例
|
||||
* _callback: List[Callable] 回调函数列表
|
||||
* _is_running: bool 线程运行状态标志
|
||||
* _stop_event: bool 停止事件标志
|
||||
* _status_lock: threading.Lock 状态锁,用于线程同步
|
||||
|
||||
2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`。
|
||||
|
||||
3. 基类定义的核心方法:
|
||||
|
||||
* `add_callback(callback: Callable)`: 添加结果处理的回调函数
|
||||
* `set_model(model: dict)`: 设置模型配置和实例
|
||||
* `set_input_queue(queue: Queue)`: 设置输入数据队列
|
||||
* `run()`: 启动处理线程(抽象方法)
|
||||
* `stop()`: 停止处理线程(抽象方法)
|
||||
* `_run()`: 线程运行的具体逻辑(抽象方法)
|
||||
* `_pre_check()`: 运行前的预检查(抽象方法)
|
||||
|
||||
## 派生类实现要求
|
||||
|
||||
1. 必须实现的抽象方法:
|
||||
* `_pre_check()`:
|
||||
- 检查必要的配置是否完整(如模型、队列等)
|
||||
- 检查运行环境是否满足要求
|
||||
- 返回检查结果
|
||||
|
||||
* `_run()`:
|
||||
- 实现具体的数据处理逻辑
|
||||
- 从 _input_queue 获取输入数据
|
||||
- 使用 _model 进行处理
|
||||
- 通过 _callback 返回处理结果
|
||||
|
||||
* `run()`:
|
||||
- 调用 _pre_check() 进行预检查
|
||||
- 创建并启动处理线程
|
||||
- 设置相关状态标志
|
||||
|
||||
* `stop()`:
|
||||
- 安全停止处理线程
|
||||
- 清理资源
|
||||
- 重置状态标志
|
||||
|
||||
2. 建议实现的方法:
|
||||
* `__str__`: 返回当前实例的状态信息
|
||||
* 错误处理方法:处理运行过程中的异常情况
|
||||
|
||||
## 使用示例
|
||||
|
||||
```python
|
||||
class MyFunctor(BaseFunctor):
|
||||
def _pre_check(self):
|
||||
if not self._model or not self._input_queue:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _run(self):
|
||||
while not self._stop_event:
|
||||
try:
|
||||
data = self._input_queue.get(timeout=1.0)
|
||||
result = self._model['my_model'].process(data)
|
||||
for callback in self._callback:
|
||||
callback(result)
|
||||
except Queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"处理错误: {e}")
|
||||
|
||||
def run(self):
|
||||
if not self._pre_check():
|
||||
raise RuntimeError("预检查失败")
|
||||
|
||||
with self._status_lock:
|
||||
if self._is_running:
|
||||
return
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
self._thread = threading.Thread(target=self._run)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
with self._status_lock:
|
||||
if not self._is_running:
|
||||
return
|
||||
self._stop_event = True
|
||||
if self._thread:
|
||||
self._thread.join()
|
||||
self._is_running = False
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 线程安全:
|
||||
* 使用 _status_lock 保护状态变更
|
||||
* 注意共享资源的访问控制
|
||||
|
||||
2. 错误处理:
|
||||
* 在 _run() 中妥善处理异常
|
||||
* 提供详细的错误日志
|
||||
|
||||
3. 资源管理:
|
||||
* 确保在 stop() 中正确清理资源
|
||||
* 避免资源泄露
|
||||
|
||||
4. 回调函数:
|
||||
* 回调函数应该是非阻塞的
|
||||
* 处理回调函数抛出的异常
|
145
src/functor/spk_functor.py
Normal file
145
src/functor/spk_functor.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""
|
||||
SpkFunctor
|
||||
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
||||
"""
|
||||
|
||||
from src.functor.base import BaseFunctor
|
||||
from src.models import AudioBinary_Config, VAD_Functor_result
|
||||
from typing import Callable, List
|
||||
from queue import Queue, Empty
|
||||
import threading
|
||||
|
||||
# 日志
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class SPKFunctor(BaseFunctor):
|
||||
"""
|
||||
SPKFunctor
|
||||
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
||||
需要配置好 _model, _callback, _input_queue, _audio_config
|
||||
否则无法run()启动线程
|
||||
|
||||
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||
|
||||
使用stop()停止线程, 但需要等待input_queue为空
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# 资源与配置
|
||||
self._model: dict = {} # 模型
|
||||
self._callback: List[Callable] = [] # 回调函数
|
||||
self._input_queue: Queue = None # 输入队列
|
||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||
|
||||
def reset_cache(self) -> None:
|
||||
"""
|
||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_input_queue(self, queue: Queue) -> None:
|
||||
"""
|
||||
设置监听的输入消息队列
|
||||
"""
|
||||
self._input_queue = queue
|
||||
|
||||
def set_model(self, model: dict) -> None:
|
||||
"""
|
||||
设置推理模型
|
||||
"""
|
||||
self._model = model
|
||||
|
||||
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||
"""
|
||||
设置音频配置
|
||||
"""
|
||||
self._audio_config = audio_config
|
||||
logger.debug("SpkFunctor设置音频配置: %s", self._audio_config)
|
||||
|
||||
def add_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||
"""
|
||||
if not isinstance(self._callback, list):
|
||||
self._callback = []
|
||||
self._callback.append(callback)
|
||||
|
||||
def _do_callback(self, result: List[str]) -> None:
|
||||
"""
|
||||
回调函数
|
||||
"""
|
||||
for callback in self._callback:
|
||||
callback(result)
|
||||
|
||||
def _process(self, data: VAD_Functor_result) -> None:
|
||||
"""
|
||||
处理数据
|
||||
"""
|
||||
binary_data = data.audiobinary_data.binary_data
|
||||
# result = self._model["spk"].generate(
|
||||
# input=binary_data,
|
||||
# chunk_size=self._audio_config.chunk_size,
|
||||
# )
|
||||
result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
|
||||
self._do_callback(result)
|
||||
|
||||
def _run(self) -> None:
|
||||
"""
|
||||
线程运行逻辑
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
# 运行逻辑
|
||||
while self._is_running:
|
||||
try:
|
||||
data = self._input_queue.get(True, timeout=1)
|
||||
self._process(data)
|
||||
self._input_queue.task_done()
|
||||
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||
except Empty:
|
||||
if self._stop_event:
|
||||
break
|
||||
continue
|
||||
# 其他异常
|
||||
except Exception as e:
|
||||
logger.error("SpkFunctor运行时发生错误: %s", e)
|
||||
raise e
|
||||
|
||||
def run(self) -> threading.Thread:
|
||||
"""
|
||||
启动线程
|
||||
Returns:
|
||||
threading.Thread: 返回已运行线程实例
|
||||
"""
|
||||
self._pre_check()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
return self._thread
|
||||
|
||||
def _pre_check(self) -> bool:
|
||||
"""
|
||||
预检查
|
||||
"""
|
||||
if self._model is None:
|
||||
raise ValueError("模型未设置")
|
||||
if self._input_queue is None:
|
||||
raise ValueError("输入队列未设置")
|
||||
if self._callback is None:
|
||||
raise ValueError("回调函数未设置")
|
||||
return True
|
||||
|
||||
def stop(self) -> bool:
|
||||
"""
|
||||
停止线程
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._stop_event = True
|
||||
self._thread.join()
|
||||
with self._status_lock:
|
||||
self._is_running = False
|
||||
return not self._thread.is_alive()
|
315
src/functor/vad_functor.py
Normal file
315
src/functor/vad_functor.py
Normal file
@ -0,0 +1,315 @@
|
||||
"""
|
||||
VADFunctor
|
||||
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||
"""
|
||||
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
from typing import List, Any, Callable
|
||||
import numpy
|
||||
from src.models import (
|
||||
VAD_Functor_result,
|
||||
AudioBinary_Config,
|
||||
AudioBinary_data_list,
|
||||
)
|
||||
from src.functor.base import BaseFunctor
|
||||
|
||||
# 日志
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class VADFunctor(BaseFunctor):
|
||||
"""
|
||||
VADFunctor
|
||||
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
|
||||
否则无法run()启动线程
|
||||
|
||||
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
||||
|
||||
使用stop()停止线程, 但需要等待input_queue为空
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# 资源与配置
|
||||
self._model: dict = {} # 模型
|
||||
self._callback: List[Callable] = [] # 回调函数
|
||||
self._input_queue: Queue = None # 输入队列
|
||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
|
||||
|
||||
# flag
|
||||
# 此处用到两个锁,但都是为了截断_run线程,考虑后续优化
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
|
||||
# 线程资源
|
||||
self._thread: threading.Thread = None
|
||||
|
||||
# 状态锁
|
||||
self._status_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# 缓存
|
||||
self._audio_cache: numpy.ndarray = None
|
||||
self._audio_cache_preindex: int = 0
|
||||
self._model_cache: dict = {}
|
||||
self._cache_result_list = []
|
||||
self._audiobinary_cache = None
|
||||
|
||||
def reset_cache(self) -> None:
|
||||
"""
|
||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||
"""
|
||||
self._audio_cache = None
|
||||
self._audio_cache_preindex = 0
|
||||
self._model_cache = {}
|
||||
self._cache_result_list = []
|
||||
self._audiobinary_cache = None
|
||||
|
||||
def set_input_queue(self, queue: Queue) -> None:
|
||||
"""
|
||||
设置监听的输入消息队列
|
||||
"""
|
||||
self._input_queue = queue
|
||||
|
||||
def set_model(self, model: dict) -> None:
|
||||
"""
|
||||
设置推理模型
|
||||
"""
|
||||
self._model = model
|
||||
|
||||
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||
"""
|
||||
设置音频配置
|
||||
"""
|
||||
self._audio_config = audio_config
|
||||
logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
|
||||
|
||||
def set_audio_binary_data_list(
|
||||
self, audio_binary_data_list: AudioBinary_data_list
|
||||
) -> None:
|
||||
"""
|
||||
设置音频数据列表, 为Class AudioBinary_data_list类型
|
||||
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
|
||||
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
|
||||
"""
|
||||
self._audio_binary_data_list = audio_binary_data_list
|
||||
|
||||
def add_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||
"""
|
||||
if not isinstance(self._callback, list):
|
||||
self._callback = []
|
||||
self._callback.append(callback)
|
||||
|
||||
def _do_callback(self, result: List[List[int]]) -> None:
|
||||
"""
|
||||
回调函数
|
||||
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||
|
||||
输入:
|
||||
result: List[[start,end]] 处理所得VAD端点
|
||||
其中若start==-1, 则表示前无端点, 若end==-1, 则表示后无端点
|
||||
当处理得到一个完成片段时, 存入AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||
输出:
|
||||
None
|
||||
"""
|
||||
# 持久化缓存结果队列
|
||||
for pair in result:
|
||||
[start, end] = pair
|
||||
# 若无前端点, 则向缓存队列中合并
|
||||
if start == -1:
|
||||
self._cache_result_list[-1][1] = end
|
||||
else:
|
||||
self._cache_result_list.append(pair)
|
||||
while len(self._cache_result_list) > 1:
|
||||
# 创建VAD片段
|
||||
# 计算开始帧
|
||||
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
|
||||
start_frame -= self._audio_cache_preindex
|
||||
# 计算结束帧
|
||||
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
|
||||
end_frame -= self._audio_cache_preindex
|
||||
# 计算开始时间
|
||||
vad_result = VAD_Functor_result.create_from_push_data(
|
||||
audiobinary_data_list=self._audio_binary_data_list,
|
||||
data=self._audiobinary_cache[start_frame:end_frame],
|
||||
start_time=self._cache_result_list[0][0],
|
||||
end_time=self._cache_result_list[0][1],
|
||||
)
|
||||
self._audio_cache_preindex += end_frame
|
||||
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
|
||||
for callback in self._callback:
|
||||
callback(vad_result)
|
||||
self._cache_result_list.pop(0)
|
||||
|
||||
def _predeal_data(self, data: Any) -> None:
|
||||
"""
|
||||
预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
|
||||
"""
|
||||
if self._audio_cache is None:
|
||||
self._audio_cache = data
|
||||
else:
|
||||
# 拼接音频数据
|
||||
if isinstance(self._audio_cache, numpy.ndarray):
|
||||
self._audio_cache = numpy.concatenate((self._audio_cache, data))
|
||||
elif isinstance(self._audio_cache, list):
|
||||
self._audio_cache.append(data)
|
||||
if self._audiobinary_cache is None:
|
||||
self._audiobinary_cache = data
|
||||
else:
|
||||
# 拼接音频数据
|
||||
if isinstance(self._audiobinary_cache, numpy.ndarray):
|
||||
self._audiobinary_cache = numpy.concatenate(
|
||||
(self._audiobinary_cache, data)
|
||||
)
|
||||
elif isinstance(self._audiobinary_cache, list):
|
||||
self._audiobinary_cache.append(data)
|
||||
|
||||
def _process(self, data: Any):
|
||||
"""
|
||||
处理数据
|
||||
使用model进行生成, 并使用_do_callback进行回调
|
||||
"""
|
||||
self._predeal_data(data)
|
||||
if len(self._audio_cache) >= self._audio_config.chunk_stride:
|
||||
result = self._model["vad"].generate(
|
||||
input=self._audio_cache,
|
||||
cache=self._model_cache,
|
||||
chunk_size=self._audio_config.chunk_size,
|
||||
is_final=False,
|
||||
)
|
||||
if len(result[0]["value"]) > 0:
|
||||
self._do_callback(result[0]["value"])
|
||||
# logger.debug(f"VADFunctor结果: {result[0]['value']}")
|
||||
self._audio_cache = None
|
||||
|
||||
def _run(self):
|
||||
"""
|
||||
线程运行逻辑
|
||||
监听输入队列, 当有数据时, 处理数据
|
||||
当输入队列为空时, 间隔1s检测是否进入停止事件。
|
||||
"""
|
||||
# 刷新运行状态
|
||||
with self._status_lock:
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
# 运行逻辑
|
||||
while self._is_running:
|
||||
try:
|
||||
data = self._input_queue.get(True, timeout=1)
|
||||
self._process(data)
|
||||
self._input_queue.task_done()
|
||||
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
||||
except Empty:
|
||||
if self._stop_event:
|
||||
break
|
||||
continue
|
||||
# 其他异常
|
||||
except Exception as e:
|
||||
logger.error("VADFunctor运行时发生错误: %s", e)
|
||||
raise e
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
启动 _run 线程, 并返回线程对象
|
||||
"""
|
||||
self._pre_check()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
return self._thread
|
||||
|
||||
def _pre_check(self) -> bool:
|
||||
"""
|
||||
检测硬性资源是否都已设置
|
||||
"""
|
||||
if self._model is None:
|
||||
raise ValueError("模型未设置")
|
||||
if self._audio_config is None:
|
||||
raise ValueError("音频配置未设置")
|
||||
if self._audio_binary_data_list is None:
|
||||
raise ValueError("音频数据列表未设置")
|
||||
if self._input_queue is None:
|
||||
raise ValueError("输入队列未设置")
|
||||
if self._callback is None:
|
||||
raise ValueError("回调函数未设置")
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止线程
|
||||
通过设置_stop_event为True, 来在input_queue.get()循环为空时退出
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._stop_event = True
|
||||
self._thread.join()
|
||||
with self._status_lock:
|
||||
self._is_running = False
|
||||
return not self._thread.is_alive()
|
||||
|
||||
|
||||
# class VAD:
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# VAD_model=None,
|
||||
# audio_config: AudioBinary_Config = None,
|
||||
# callback: Callable = None,
|
||||
# ):
|
||||
# # vad model
|
||||
# self.VAD_model = VAD_model
|
||||
# if self.VAD_model is None:
|
||||
# self.VAD_model = AutoModel(
|
||||
# model="fsmn-vad", model_revision="v2.0.4", disable_update=True
|
||||
# )
|
||||
# # audio config
|
||||
# self.audio_config = audio_config
|
||||
# # vad result
|
||||
# self.vad_result = VADResponse(time_chunk_index_callback=callback)
|
||||
# # audio binary poll
|
||||
# self.audio_chunk = AudioChunk(audio_config=self.audio_config)
|
||||
# self.cache = {}
|
||||
|
||||
# def push_binary_data(
|
||||
# self,
|
||||
# binary_data: bytes,
|
||||
# ):
|
||||
# # 压入二进制数据
|
||||
# self.audio_chunk.add_chunk(binary_data)
|
||||
# # 处理音频块
|
||||
# res = self.VAD_model.generate(
|
||||
# input=binary_data,
|
||||
# cache=self.cache,
|
||||
# chunk_size=self.audio_config.chunk_size,
|
||||
# is_final=False,
|
||||
# )
|
||||
# # print("VAD generate", res)
|
||||
# if len(res[0]["value"]):
|
||||
# self.vad_result += VADResponse.from_raw(res)
|
||||
|
||||
# def set_callback(
|
||||
# self,
|
||||
# callback: Callable,
|
||||
# ):
|
||||
# self.vad_result.time_chunk_index_callback = callback
|
||||
|
||||
# def process_vad_result(self, callback: Callable = None):
|
||||
# # 处理VAD结果
|
||||
# callback = (
|
||||
# callback
|
||||
# if callback is not None
|
||||
# else self.vad_result.time_chunk_index_callback
|
||||
# )
|
||||
# self.vad_result.process_time_chunk(
|
||||
# lambda x: callback(
|
||||
# AudioBinary_Chunk(
|
||||
# start_time=x["start_time"],
|
||||
# end_time=x["end_time"],
|
||||
# chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]),
|
||||
# )
|
||||
# )
|
||||
# )
|
176
src/logic_trager.py
Normal file
176
src/logic_trager.py
Normal file
@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
逻辑触发器类 - 用于处理音频数据并触发相应的处理逻辑
|
||||
"""
|
||||
|
||||
from src.utils.logger import get_module_logger
|
||||
from typing import Any, Dict, Type, Callable
|
||||
|
||||
# 配置日志
|
||||
logger = get_module_logger(__name__, level="INFO")
|
||||
|
||||
|
||||
class AutoAfterMeta(type):
|
||||
"""
|
||||
自动调用__after__函数的元类
|
||||
实现单例模式
|
||||
"""
|
||||
|
||||
_instances: Dict[Type, Any] = {} # 存储单例实例
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
# 遍历所有属性
|
||||
for attr_name, attr_value in attrs.items():
|
||||
# 如果是函数且不是以_开头
|
||||
if callable(attr_value) and not attr_name.startswith("__"):
|
||||
# 获取原函数
|
||||
original_func = attr_value
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapper(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
# 执行原函数
|
||||
result = func(self, *args, **kwargs)
|
||||
|
||||
# 构建_after_函数名
|
||||
after_func_name = f"__after__{func.__name__}"
|
||||
|
||||
# 检查是否存在对应的_after_函数
|
||||
if hasattr(self, after_func_name):
|
||||
after_func = getattr(self, after_func_name)
|
||||
if callable(after_func):
|
||||
try:
|
||||
# 调用_after_函数
|
||||
after_func()
|
||||
except Exception as e:
|
||||
logger.error(f"调用{after_func_name}时出错: {e}")
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
# 替换原函数
|
||||
attrs[attr_name] = make_wrapper(original_func)
|
||||
|
||||
# 创建类
|
||||
new_class = super().__new__(cls, name, bases, attrs)
|
||||
return new_class
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
"""
|
||||
重写__call__方法实现单例模式
|
||||
当类被调用时(即创建实例时)执行
|
||||
"""
|
||||
if cls not in cls._instances:
|
||||
# 如果实例不存在,创建新实例
|
||||
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||
logger.info(f"创建{cls.__name__}的新实例")
|
||||
else:
|
||||
logger.debug(f"返回{cls.__name__}的现有实例")
|
||||
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
"""
|
||||
整体识别的处理逻辑:
|
||||
1.压入二进制音频信息
|
||||
2.不断检测VAD
|
||||
3.当检测到完整VAD时,将VAD的音频信息压入音频块,并清除对应二进制信息
|
||||
4.对音频块进行语音转文字offline,时间戳预测,说话人识别
|
||||
5.将识别结果整合压入结果队列
|
||||
6.结果队列被压入时调用回调函数
|
||||
|
||||
1->2 __after__push_binary_data 外部压入二进制信息
|
||||
2,3->4 __after__push_audio_chunk 内部压入音频块
|
||||
4->5 push_result_queue 压入结果队列
|
||||
5->6 __after__push_result_queue 调用回调函数
|
||||
"""
|
||||
|
||||
from src.functor import VAD
|
||||
from src.models import AudioBinary_Config
|
||||
from src.models import AudioBinary_Chunk
|
||||
from typing import List
|
||||
|
||||
|
||||
class LogicTrager(metaclass=AutoAfterMeta):
|
||||
"""逻辑触发器类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_chunk_max_size: int = 1024 * 1024 * 10,
|
||||
audio_config: AudioBinary_Config = None,
|
||||
result_callback: Callable = None,
|
||||
models: Dict[str, Any] = None,
|
||||
):
|
||||
"""初始化"""
|
||||
# 存储音频块
|
||||
self._audio_chunk: List[AudioBinary_Chunk] = []
|
||||
# 存储二进制数据
|
||||
self._audio_chunk_binary = b""
|
||||
self._audio_chunk_max_size = audio_chunk_max_size
|
||||
# 音频参数
|
||||
self._audio_config = (
|
||||
audio_config if audio_config is not None else AudioBinary_Config()
|
||||
)
|
||||
# 结果队列
|
||||
self._result_queue = []
|
||||
# 聚合结果回调函数
|
||||
self._aggregate_result_callback = result_callback
|
||||
# 组件
|
||||
self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
|
||||
self._vad.set_callback(self.push_audio_chunk)
|
||||
|
||||
logger.info("初始化LogicTrager")
|
||||
|
||||
def push_binary_data(self, chunk: bytes) -> None:
|
||||
"""
|
||||
压入音频块至VAD模块
|
||||
|
||||
参数:
|
||||
chunk: 音频数据块
|
||||
"""
|
||||
# print("LogicTrager push_binary_data", len(chunk))
|
||||
self._vad.push_binary_data(chunk)
|
||||
self.__after__push_binary_data()
|
||||
|
||||
def __after__push_binary_data(self) -> None:
|
||||
"""
|
||||
添加音频块后处理
|
||||
"""
|
||||
# print("LogicTrager __after__push_binary_data")
|
||||
self._vad.process_vad_result()
|
||||
|
||||
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
|
||||
"""
|
||||
音频处理
|
||||
"""
|
||||
logger.info(
|
||||
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
|
||||
chunk.start_time, chunk.end_time, len(chunk.chunk)
|
||||
)
|
||||
)
|
||||
self._audio_chunk.append(chunk)
|
||||
|
||||
def __after__push_audio_chunk(self) -> None:
|
||||
"""
|
||||
压入音频块后处理
|
||||
"""
|
||||
pass
|
||||
|
||||
def push_result_queue(self, result: Dict[str, Any]) -> None:
|
||||
"""
|
||||
压入结果队列
|
||||
"""
|
||||
self._result_queue.append(result)
|
||||
|
||||
def __after__push_result_queue(self) -> None:
|
||||
"""
|
||||
压入结果队列后处理
|
||||
"""
|
||||
logger.info("FINISH Result=")
|
||||
pass
|
||||
|
||||
def __call__(self):
|
||||
"""调用函数"""
|
||||
pass
|
126
src/model_loader.py
Normal file
126
src/model_loader.py
Normal file
@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
模型加载模块 - 负责加载各种语音识别相关模型
|
||||
"""
|
||||
try:
|
||||
# 导入FunASR库
|
||||
from funasr import AutoModel
|
||||
except ImportError as exc:
|
||||
raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc
|
||||
|
||||
# 日志模块
|
||||
from src.utils import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
# 单例模式
|
||||
class ModelLoader:
|
||||
"""
|
||||
ModelLoader类是单例模式, 程序生命周期全局唯一, 负责加载模型到字典中。
|
||||
一般的, 可以直接call ModelLoader()来获取加载的模型。
|
||||
也可以通过ModelLoader实例(args)或ModelloaderInstance.load_models(args)来初始化, 并加载模型。
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
单例模式
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ModelLoader, cls).__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, args=None):
|
||||
"""
|
||||
初始化ModelLoader实例
|
||||
"""
|
||||
self.models = {}
|
||||
logger.debug("初始化ModelLoader")
|
||||
if args is not None:
|
||||
self.__call__(args)
|
||||
|
||||
def __call__(self, args=None):
|
||||
"""
|
||||
调用ModelLoader实例时, 如果模型字典为空, 则加载模型
|
||||
"""
|
||||
# 如果模型字典为空, 则加载模型
|
||||
if self.models == {} or self.models is None:
|
||||
if args.asr_model is not None:
|
||||
self.models = self.load_models(args)
|
||||
# 直接调用等于调用self.models
|
||||
return self.models
|
||||
|
||||
def _load_model(self, input_model_args: dict, model_type: str):
|
||||
"""
|
||||
加载单个模型
|
||||
|
||||
参数:
|
||||
model_args: 模型加载字典
|
||||
model_type: 模型类型, 用于确定使用哪个模型参数
|
||||
|
||||
返回:
|
||||
AutoModel: 加载的模型实例
|
||||
"""
|
||||
# 默认配置
|
||||
default_config = {
|
||||
"model": None,
|
||||
"model_revision": None,
|
||||
"ngpu": 0,
|
||||
"ncpu": 1,
|
||||
"device": "cpu",
|
||||
"disable_pbar": True,
|
||||
"disable_log": True,
|
||||
"disable_update": True,
|
||||
}
|
||||
# 从args中获取配置, 如果存在则覆盖默认值
|
||||
model_args = default_config.copy()
|
||||
for key, value in default_config.items():
|
||||
if key in ["model", "model_revision"]:
|
||||
# 特殊处理model和model_revision, 因为它们需要model_type前缀
|
||||
if key == "model":
|
||||
value = input_model_args.get(f"{model_type}_model", None)
|
||||
else:
|
||||
value = input_model_args.get(f"{model_type}_model_revision", None)
|
||||
else:
|
||||
value = input_model_args.get(key, None)
|
||||
if value is not None:
|
||||
logger.debug("替换%s模型参数: %s = %s", model_type, key, value)
|
||||
model_args[key] = value
|
||||
# 验证必要参数
|
||||
if not model_args["model"]:
|
||||
raise ValueError(f"未指定{model_type}模型路径")
|
||||
try:
|
||||
# 使用 % 格式化替代 f-string,避免不必要的字符串格式化开销
|
||||
logger.debug("正在加载%s模型: %s", model_type, model_args["model"])
|
||||
model = AutoModel(**model_args)
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error("加载%s模型失败: %s", model_type, str(e))
|
||||
raise
|
||||
|
||||
def load_models(self, args):
|
||||
"""
|
||||
加载所有需要的模型
|
||||
参数:
|
||||
args: 命令行参数, 包含模型配置
|
||||
|
||||
返回:
|
||||
dict: 包含所有加载的模型的字典
|
||||
"""
|
||||
logger.info("ModelLoader加载模型")
|
||||
# 初始化模型字典
|
||||
self.models = {}
|
||||
# 加载离线ASR模型
|
||||
# 检查对应键是否存在
|
||||
model_list = ["asr", "asr_online", "vad", "punc", "spk"]
|
||||
for model_name in model_list:
|
||||
name_model = f"{model_name}_model"
|
||||
name_model_revision = f"{model_name}_model_revision"
|
||||
if name_model in args:
|
||||
logger.debug("加载%s模型", model_name)
|
||||
self.models[model_name] = self._load_model(args, model_name)
|
||||
logger.info("所有模型加载完成")
|
||||
return self.models
|
@ -1,79 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
模型加载模块 - 负责加载各种语音识别相关模型
|
||||
"""
|
||||
|
||||
def load_models(args):
|
||||
"""
|
||||
加载所有需要的模型
|
||||
|
||||
参数:
|
||||
args: 命令行参数,包含模型配置
|
||||
|
||||
返回:
|
||||
dict: 包含所有加载的模型的字典
|
||||
"""
|
||||
try:
|
||||
# 导入FunASR库
|
||||
from funasr import AutoModel
|
||||
except ImportError:
|
||||
raise ImportError("未找到funasr库,请先安装: pip install funasr")
|
||||
|
||||
# 初始化模型字典
|
||||
models = {}
|
||||
|
||||
# 1. 加载离线ASR模型
|
||||
print(f"正在加载ASR离线模型: {args.asr_model}")
|
||||
models["asr"] = AutoModel(
|
||||
model=args.asr_model,
|
||||
model_revision=args.asr_model_revision,
|
||||
ngpu=args.ngpu,
|
||||
ncpu=args.ncpu,
|
||||
device=args.device,
|
||||
disable_pbar=True,
|
||||
disable_log=True,
|
||||
)
|
||||
|
||||
# 2. 加载在线ASR模型
|
||||
print(f"正在加载ASR在线模型: {args.asr_model_online}")
|
||||
models["asr_streaming"] = AutoModel(
|
||||
model=args.asr_model_online,
|
||||
model_revision=args.asr_model_online_revision,
|
||||
ngpu=args.ngpu,
|
||||
ncpu=args.ncpu,
|
||||
device=args.device,
|
||||
disable_pbar=True,
|
||||
disable_log=True,
|
||||
)
|
||||
|
||||
# 3. 加载VAD模型
|
||||
print(f"正在加载VAD模型: {args.vad_model}")
|
||||
models["vad"] = AutoModel(
|
||||
model=args.vad_model,
|
||||
model_revision=args.vad_model_revision,
|
||||
ngpu=args.ngpu,
|
||||
ncpu=args.ncpu,
|
||||
device=args.device,
|
||||
disable_pbar=True,
|
||||
disable_log=True,
|
||||
)
|
||||
|
||||
# 4. 加载标点符号模型(如果指定)
|
||||
if args.punc_model:
|
||||
print(f"正在加载标点符号模型: {args.punc_model}")
|
||||
models["punc"] = AutoModel(
|
||||
model=args.punc_model,
|
||||
model_revision=args.punc_model_revision,
|
||||
ngpu=args.ngpu,
|
||||
ncpu=args.ncpu,
|
||||
device=args.device,
|
||||
disable_pbar=True,
|
||||
disable_log=True,
|
||||
)
|
||||
else:
|
||||
models["punc"] = None
|
||||
print("未指定标点符号模型,将不使用标点符号")
|
||||
|
||||
print("所有模型加载完成")
|
||||
return models
|
9
src/models/__init__.py
Normal file
9
src/models/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
|
||||
from .vad import VAD_Functor_result
|
||||
|
||||
__all__ = [
|
||||
"AudioBinary_Config",
|
||||
"AudioBinary_data_list",
|
||||
"_AudioBinary_data",
|
||||
"VAD_Functor_result",
|
||||
]
|
158
src/models/audio.py
Normal file
158
src/models/audio.py
Normal file
@ -0,0 +1,158 @@
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import List, Any
|
||||
import numpy
|
||||
|
||||
from src.utils import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
binary_data_types = (bytes, numpy.ndarray)
|
||||
|
||||
|
||||
class AudioBinary_Config(BaseModel):
|
||||
"""二进制音频块配置信息"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
audio_data: binary_data_types = Field(description="音频数据", default=None)
|
||||
chunk_size: int = Field(description="块大小", default=100)
|
||||
chunk_stride: int = Field(description="块步长", default=1600)
|
||||
sample_rate: int = Field(description="采样率", default=16000)
|
||||
sample_width: int = Field(description="采样位宽", default=2)
|
||||
channels: int = Field(description="通道数", default=1)
|
||||
|
||||
# 从Dict中加载
|
||||
@classmethod
|
||||
def AudioBinary_Config_from_dict(cls, data: dict):
|
||||
return cls(**data)
|
||||
|
||||
def ms2frame(self, ms: int) -> int:
|
||||
"""
|
||||
将毫秒转换为帧
|
||||
"""
|
||||
return int(ms * self.sample_rate / 1000)
|
||||
|
||||
def frame2ms(self, frame: int) -> int:
|
||||
"""
|
||||
将帧转换为毫秒
|
||||
"""
|
||||
return int(frame * 1000 / self.sample_rate)
|
||||
|
||||
|
||||
class _AudioBinary_data(BaseModel):
|
||||
"""音频数据"""
|
||||
|
||||
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("binary_data")
|
||||
def validate_binary_data(cls, v):
|
||||
"""
|
||||
验证音频数据
|
||||
Args:
|
||||
v: 音频数据
|
||||
Returns:
|
||||
binary_data_types: 音频数据
|
||||
"""
|
||||
if not isinstance(v, (bytes, numpy.ndarray)):
|
||||
logger.warning(
|
||||
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
|
||||
cls.__class__.__name__,
|
||||
type(v),
|
||||
)
|
||||
return v
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
获取音频数据长度
|
||||
Returns:
|
||||
int: 音频数据长度
|
||||
"""
|
||||
return len(self.binary_data)
|
||||
|
||||
def __init__(self, binary_data: binary_data_types):
|
||||
"""
|
||||
初始化音频数据
|
||||
Args:
|
||||
binary_data: 音频数据
|
||||
"""
|
||||
logger.debug(
|
||||
"[%s]初始化音频数据, 数据类型为%s",
|
||||
self.__class__.__name__,
|
||||
type(binary_data),
|
||||
)
|
||||
super().__init__(binary_data=binary_data)
|
||||
|
||||
def __getitem__(self):
|
||||
"""
|
||||
当获取数据时, 直接返回binary_data
|
||||
Returns:
|
||||
binary_data_types: 音频数据
|
||||
"""
|
||||
return self.binary_data
|
||||
|
||||
|
||||
class AudioBinary_data_list(BaseModel):
|
||||
"""音频数据列表"""
|
||||
|
||||
binary_data_list: List[_AudioBinary_data] = Field(
|
||||
description="音频数据列表", default=[]
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def push_data(self, data: binary_data_types) -> int:
|
||||
"""
|
||||
添加音频数据
|
||||
Args:
|
||||
data: 音频数据
|
||||
Returns:
|
||||
int: 数据在binary_data_list中的索引
|
||||
"""
|
||||
self.binary_data_list.append(_AudioBinary_data(binary_data=data))
|
||||
return len(self.binary_data_list) - 1
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
"""
|
||||
获取音频数据
|
||||
Args:
|
||||
index: 音频数据在binary_data_list中的索引
|
||||
Returns:
|
||||
_AudioBinary_data: 音频数据
|
||||
"""
|
||||
return self.binary_data_list[index]
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
获取音频数据列表长度
|
||||
Returns:
|
||||
int: 音频数据列表长度
|
||||
"""
|
||||
return len(self.binary_data_list)
|
||||
|
||||
|
||||
# class AudioBinary_Slice(BaseModel):
|
||||
# """音频块切片"""
|
||||
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
|
||||
# start_index: int = Field(description="开始位置", default=0)
|
||||
# end_index: int = Field(description="结束位置", default=-1)
|
||||
|
||||
# @validator('start_index')
|
||||
# def validate_start_index(cls, v):
|
||||
# if v < 0:
|
||||
# raise ValueError("start_index必须大于0")
|
||||
# return v
|
||||
|
||||
# @validator('end_index')
|
||||
# def validate_end_index(cls, v):
|
||||
# if v < cls.start_index:
|
||||
# logger.debug("[%s]end_index小于start_index, 将end_index设置为start_index", cls.__class__.__name__)
|
||||
# v = cls.start_index
|
||||
# return v
|
||||
|
||||
# def __call__(self):
|
||||
# return self.target_Binary(self.start_index, self.end_index)
|
91
src/models/vad.py
Normal file
91
src/models/vad.py
Normal file
@ -0,0 +1,91 @@
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import List, Optional, Callable, Any
|
||||
from .audio import AudioBinary_data_list, _AudioBinary_data
|
||||
|
||||
|
||||
class VAD_Functor_result(BaseModel):
|
||||
"""
|
||||
VADFunctor结果
|
||||
"""
|
||||
|
||||
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
|
||||
audiobinary_index: int = Field(description="音频数据索引")
|
||||
audiobinary_data: _AudioBinary_data = Field(
|
||||
description="音频数据, 指向AudioBinary_data"
|
||||
)
|
||||
start_time: int = Field(description="开始时间", is_required=True)
|
||||
end_time: int = Field(description="结束时间", is_required=True)
|
||||
|
||||
@validator("audiobinary_data_list")
|
||||
def validate_audiobinary_data_list(cls, v):
|
||||
if not isinstance(v, AudioBinary_data_list):
|
||||
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
|
||||
return v
|
||||
|
||||
@validator("audiobinary_index")
|
||||
def validate_audiobinary_index(cls, v):
|
||||
if not isinstance(v, int):
|
||||
raise ValueError("audiobinary_index必须是int类型")
|
||||
if v < 0:
|
||||
raise ValueError("audiobinary_index必须大于0")
|
||||
return v
|
||||
|
||||
@validator("audiobinary_data")
|
||||
def validate_audiobinary_data(cls, v):
|
||||
if not isinstance(v, _AudioBinary_data):
|
||||
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
|
||||
return v
|
||||
|
||||
@validator("start_time")
|
||||
def validate_start_time(cls, v):
|
||||
if not isinstance(v, int):
|
||||
raise ValueError("start_time必须是int类型")
|
||||
if v < 0:
|
||||
raise ValueError("start_time必须大于0")
|
||||
return v
|
||||
|
||||
@validator("end_time")
|
||||
def validate_end_time(cls, v, values):
|
||||
if not isinstance(v, int):
|
||||
raise ValueError("end_time必须是int类型")
|
||||
if "start_time" in values and v <= values["start_time"]:
|
||||
raise ValueError("end_time必须大于start_time")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def create_from_push_data(
|
||||
cls,
|
||||
audiobinary_data_list: AudioBinary_data_list,
|
||||
data: Any,
|
||||
start_time: int,
|
||||
end_time: int,
|
||||
):
|
||||
"""
|
||||
创建VAD片段
|
||||
"""
|
||||
index = audiobinary_data_list.push_data(data)
|
||||
|
||||
return cls(
|
||||
audiobinary_data_list=audiobinary_data_list,
|
||||
audiobinary_index=index,
|
||||
audiobinary_data=audiobinary_data_list[index],
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
获取音频数据长度
|
||||
"""
|
||||
return len(self.audiobinary_data.binary_data)
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
字符串展示内容
|
||||
"""
|
||||
tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
|
||||
tostr += f"start_time: {self.start_time}\n"
|
||||
tostr += f"end_time: {self.end_time}\n"
|
||||
tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
|
||||
tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
|
||||
return tostr
|
265
src/pipeline/ASRpipeline.py
Normal file
265
src/pipeline/ASRpipeline.py
Normal file
@ -0,0 +1,265 @@
|
||||
from src.pipeline.base import PipelineBase
|
||||
from typing import Dict, Any, Callable
|
||||
from queue import Queue, Empty
|
||||
from src.utils import get_module_logger
|
||||
from src.models import AudioBinary_data_list
|
||||
import threading
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class ASRPipeline(PipelineBase):
|
||||
"""
|
||||
管道类
|
||||
实现具体的处理逻辑
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
初始化管道
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._config: Dict[str, Any] = {}
|
||||
self._functor_dict: Dict[str, Any] = {}
|
||||
self._subqueue_dict: Dict[str, Any] = {}
|
||||
self._is_baked: bool = False
|
||||
self._input_queue: Queue = None
|
||||
self._audio_binary_data_list: AudioBinary_data_list = None
|
||||
|
||||
self._status_lock = threading.Lock()
|
||||
self._is_running: bool = False
|
||||
self._stop_event: bool = False
|
||||
|
||||
def set_config(self, config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
设置配置
|
||||
参数:
|
||||
config: Dict[str, Any] 配置
|
||||
"""
|
||||
self._config = config
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取配置
|
||||
返回:
|
||||
Dict[str, Any] 配置
|
||||
"""
|
||||
return self._config
|
||||
|
||||
def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None:
|
||||
"""
|
||||
设置音频二进制存储单元
|
||||
参数:
|
||||
audio_binary: 音频二进制
|
||||
"""
|
||||
self._audio_binary = audio_binary
|
||||
|
||||
def set_models(self, models: Dict[str, Any]) -> None:
|
||||
"""
|
||||
设置模型
|
||||
"""
|
||||
self._models = models
|
||||
|
||||
def set_input_queue(self, input_queue: Queue) -> None:
|
||||
"""
|
||||
设置输入队列
|
||||
"""
|
||||
self._input_queue = input_queue
|
||||
|
||||
def bake(self) -> None:
|
||||
"""
|
||||
烘焙管道
|
||||
"""
|
||||
self._pre_check_resource()
|
||||
self._init_functor()
|
||||
self._is_baked = True
|
||||
|
||||
def _pre_check_resource(self) -> None:
|
||||
"""
|
||||
预检查资源
|
||||
"""
|
||||
if self._input_queue is None:
|
||||
raise RuntimeError("[ASRpipeline]输入队列未设置")
|
||||
if self._functor_dict is None:
|
||||
raise RuntimeError("[ASRpipeline]functor字典未设置")
|
||||
if self._subqueue_dict is None:
|
||||
raise RuntimeError("[ASRpipeline]子队列字典未设置")
|
||||
if self._audio_binary is None:
|
||||
raise RuntimeError("[ASRpipeline]音频存储单元未设置")
|
||||
|
||||
def _init_functor(self) -> None:
|
||||
"""
|
||||
初始化函数
|
||||
"""
|
||||
try:
|
||||
from src.functor import FunctorFactory
|
||||
|
||||
# 加载VAD、asr、spk functor
|
||||
self._functor_dict["vad"] = FunctorFactory.make_functor(
|
||||
functor_name="vad", config=self._config, models=self._models
|
||||
)
|
||||
self._functor_dict["asr"] = FunctorFactory.make_functor(
|
||||
functor_name="asr", config=self._config, models=self._models
|
||||
)
|
||||
self._functor_dict["spk"] = FunctorFactory.make_functor(
|
||||
functor_name="spk", config=self._config, models=self._models
|
||||
)
|
||||
|
||||
# 创建音频数据存储单元
|
||||
self._audio_binary_data_list = AudioBinary_data_list()
|
||||
|
||||
self._functor_dict["vad"].set_audio_binary_data_list(
|
||||
self._audio_binary_data_list
|
||||
)
|
||||
|
||||
# 初始化子队列
|
||||
self._subqueue_dict["original"] = Queue()
|
||||
self._subqueue_dict["vad2asr"] = Queue()
|
||||
self._subqueue_dict["vad2spk"] = Queue()
|
||||
self._subqueue_dict["asrend"] = Queue()
|
||||
self._subqueue_dict["spkend"] = Queue()
|
||||
|
||||
# 设置子队列的输入队列
|
||||
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
|
||||
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||
|
||||
# 设置回调函数——放置到对应队列中
|
||||
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
|
||||
self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
|
||||
|
||||
# 构造带回调函数的put
|
||||
def put_with_checkcallback(queue: Queue, callback: Callable) -> None:
|
||||
"""
|
||||
带回调函数的put
|
||||
"""
|
||||
|
||||
def put_with_check(data: Any) -> None:
|
||||
queue.put(data)
|
||||
callback(data)
|
||||
|
||||
return put_with_check
|
||||
|
||||
self._functor_dict["asr"].add_callback(
|
||||
put_with_checkcallback(
|
||||
self._subqueue_dict["asrend"], self._check_result
|
||||
)
|
||||
)
|
||||
self._functor_dict["spk"].add_callback(
|
||||
put_with_checkcallback(
|
||||
self._subqueue_dict["spkend"], self._check_result
|
||||
)
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
|
||||
|
||||
def _check_result(self, result: Any) -> None:
|
||||
"""
|
||||
检查结果
|
||||
"""
|
||||
# 若asr和spk队列中都有数据,则合并数据
|
||||
if (
|
||||
self._subqueue_dict["asrend"].qsize()
|
||||
& self._subqueue_dict["spkend"].qsize()
|
||||
):
|
||||
asr_data = self._subqueue_dict["asrend"].get()
|
||||
spk_data = self._subqueue_dict["spkend"].get()
|
||||
# 合并数据
|
||||
result = {"asr_data": asr_data, "spk_data": spk_data}
|
||||
# 通知回调函数
|
||||
self._notify_callbacks(result)
|
||||
|
||||
def run(self) -> threading.Thread:
|
||||
"""
|
||||
运行管道
|
||||
Returns:
|
||||
threading.Thread: 返回已运行线程实例
|
||||
"""
|
||||
# 检查运行资源是否准备完毕
|
||||
self._pre_check()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("[ASRpipeline]管道开始运行")
|
||||
return self._thread
|
||||
|
||||
def _pre_check(self) -> None:
|
||||
"""
|
||||
预检查
|
||||
"""
|
||||
if self._is_baked is False:
|
||||
raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行")
|
||||
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
if functor is None:
|
||||
raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常")
|
||||
|
||||
for subqueue_name, subqueue in self._subqueue_dict.items():
|
||||
if subqueue is None:
|
||||
raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常")
|
||||
|
||||
def _run(self) -> None:
|
||||
"""
|
||||
真实的运行逻辑
|
||||
"""
|
||||
# 运行所有functor
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
logger.info(f"[ASRpipeline]运行{functor_name}functor")
|
||||
self._functor_dict[functor_name].run()
|
||||
|
||||
# 设置管道运行状态
|
||||
with self._status_lock:
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
|
||||
while self._is_running and not self._stop_event:
|
||||
try:
|
||||
data = self._input_queue.get(timeout=self._queue_timeout)
|
||||
# 检查是否是结束信号
|
||||
if data is None:
|
||||
logger.info("收到结束信号,管道准备停止")
|
||||
self._input_queue.task_done() # 标记结束信号已处理
|
||||
break
|
||||
|
||||
# 处理数据
|
||||
self._process(data)
|
||||
|
||||
# 标记任务完成
|
||||
self._input_queue.task_done()
|
||||
|
||||
except Empty:
|
||||
# 队列获取超时,继续等待
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"[ASRpipeline]管道处理数据出错: {str(e)}")
|
||||
break
|
||||
|
||||
logger.info("[ASRpipeline]管道停止运行")
|
||||
|
||||
def _process(self, data: Any) -> Any:
|
||||
"""
|
||||
处理数据
|
||||
参数:
|
||||
data: 输入数据
|
||||
返回:
|
||||
处理结果
|
||||
"""
|
||||
# 子类实现具体的处理逻辑
|
||||
self._subqueue_dict["original"].put(data)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
停止管道
|
||||
"""
|
||||
with self._status_lock:
|
||||
self._is_running = False
|
||||
self._stop_event = True
|
||||
for functor_name, functor in self._functor_dict.items():
|
||||
# logger.info(f"停止{functor_name}functor")
|
||||
if functor.stop():
|
||||
logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止")
|
||||
else:
|
||||
logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败")
|
||||
self._thread.join()
|
||||
logger.info("[ASRpipeline]管道停止")
|
||||
return True
|
3
src/pipeline/__init__.py
Normal file
3
src/pipeline/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from src.pipeline.base import PipelineBase, PipelineFactory
|
||||
|
||||
__all__ = ["PipelineBase", "PipelineFactory"]
|
151
src/pipeline/base.py
Normal file
151
src/pipeline/base.py
Normal file
@ -0,0 +1,151 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue, Empty
|
||||
from typing import List, Callable, Any, Optional
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineBase(ABC):
|
||||
"""
|
||||
管道基类
|
||||
定义了管道的基本接口和通用功能
|
||||
"""
|
||||
|
||||
def __init__(self, input_queue: Optional[Queue] = None):
|
||||
"""
|
||||
初始化管道
|
||||
参数:
|
||||
input_queue: 输入队列,用于接收数据
|
||||
"""
|
||||
self._input_queue = input_queue
|
||||
self._callbacks: List[Callable] = []
|
||||
self._is_running = False
|
||||
self._stop_event = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._stop_timeout = 5 # 默认停止超时时间(秒)
|
||||
self._queue_timeout = 1 # 队列获取超时时间(秒)
|
||||
|
||||
def set_input_queue(self, queue: Queue) -> None:
|
||||
"""
|
||||
设置输入队列
|
||||
参数:
|
||||
queue: 输入队列
|
||||
"""
|
||||
self._input_queue = queue
|
||||
|
||||
def add_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
添加回调函数
|
||||
参数:
|
||||
callback: 回调函数,接收处理结果
|
||||
"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def _notify_callbacks(self, result: Any) -> None:
|
||||
"""
|
||||
通知所有回调函数
|
||||
参数:
|
||||
result: 处理结果
|
||||
"""
|
||||
for callback in self._callbacks:
|
||||
try:
|
||||
callback(result)
|
||||
except Exception as e:
|
||||
logger.error(f"回调函数执行出错: {str(e)}")
|
||||
|
||||
@abstractmethod
|
||||
def _process(self, data: Any) -> Any:
|
||||
"""
|
||||
处理数据
|
||||
参数:
|
||||
data: 输入数据
|
||||
返回:
|
||||
处理结果
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> None:
|
||||
"""
|
||||
运行管道
|
||||
从输入队列获取数据并处理
|
||||
"""
|
||||
pass
|
||||
|
||||
def stop(self, timeout: Optional[float] = None) -> bool:
|
||||
"""
|
||||
停止管道
|
||||
参数:
|
||||
timeout: 停止超时时间(秒),None表示使用默认超时时间
|
||||
返回:
|
||||
bool: 是否成功停止
|
||||
"""
|
||||
if not self._is_running:
|
||||
return True
|
||||
|
||||
logger.info("正在停止管道...")
|
||||
self._stop_event = True
|
||||
self._is_running = False
|
||||
|
||||
# 等待线程结束
|
||||
if self._thread and self._thread.is_alive():
|
||||
timeout = timeout if timeout is not None else self._stop_timeout
|
||||
self._thread.join(timeout=timeout)
|
||||
|
||||
# 检查是否成功停止
|
||||
if self._thread.is_alive():
|
||||
logger.warning(f"管道停止超时({timeout}秒),强制终止")
|
||||
return False
|
||||
else:
|
||||
logger.info("管道已成功停止")
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
def force_stop(self) -> None:
|
||||
"""
|
||||
强制停止管道
|
||||
注意:这可能会导致资源未正确释放
|
||||
"""
|
||||
logger.warning("强制停止管道")
|
||||
self._stop_event = True
|
||||
self._is_running = False
|
||||
# 注意:Python的线程无法被强制终止,这里只是设置标志
|
||||
# 实际终止需要依赖操作系统的进程管理
|
||||
|
||||
|
||||
class PipelineFactory:
|
||||
"""
|
||||
管道工厂类
|
||||
用于创建管道实例
|
||||
"""
|
||||
|
||||
from src.pipeline.ASRpipeline import ASRPipeline
|
||||
def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline:
|
||||
"""
|
||||
创建ASR管道实例
|
||||
"""
|
||||
from src.pipeline.ASRpipeline import ASRPipeline
|
||||
pipeline = ASRPipeline()
|
||||
pipeline.set_config(kwargs["config"])
|
||||
pipeline.set_models(kwargs["models"])
|
||||
pipeline.set_audio_binary(kwargs["audio_binary"])
|
||||
pipeline.set_input_queue(kwargs["input_queue"])
|
||||
pipeline.add_callback(kwargs["callback"])
|
||||
pipeline.bake()
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
|
||||
"""
|
||||
创建管道实例
|
||||
"""
|
||||
if pipeline_name == "ASRpipeline":
|
||||
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"不支持的管道类型: {pipeline_name}")
|
||||
|
0
src/pipeline/test.py
Normal file
0
src/pipeline/test.py
Normal file
281
src/runner.py
Normal file
281
src/runner.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""
|
||||
运行器模块
|
||||
提供运行器基类和运行器类,用于管理音频数据和模型的交互。
|
||||
主要包含:
|
||||
- RunnerBase: 运行器基类,定义了基本接口
|
||||
- Runner: 运行器类,工厂模式实现
|
||||
- RunnerFactory: 运行器工厂类,用于创建运行器
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
from threading import Thread, Lock
|
||||
from queue import Queue
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from src.audio_chunk import AudioChunk, AudioBinary
|
||||
from src.pipeline import Pipeline, PipelineFactory
|
||||
from src.model_loader import ModelLoader
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__, level="INFO")
|
||||
|
||||
audio_chunk = AudioChunk()
|
||||
models_loaded = ModelLoader()
|
||||
|
||||
|
||||
class RunnerBase(ABC):
|
||||
"""
|
||||
运行器基类
|
||||
定义了运行器的基本接口
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def adder(self, data: Any) -> None:
|
||||
"""
|
||||
添加数据
|
||||
参数:
|
||||
data: 要添加的数据
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_recevier(self, receiver: callable) -> None:
|
||||
"""
|
||||
添加数据接收者
|
||||
参数:
|
||||
receiver: 接收数据的回调函数
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class STTRunner(RunnerBase):
|
||||
"""
|
||||
运行器类
|
||||
负责管理资源和协调Pipeline的运行
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
audio_binary_list: List[AudioBinary],
|
||||
models: Dict[str, Any],
|
||||
pipeline_list: List[Pipeline],
|
||||
):
|
||||
"""
|
||||
初始化运行器
|
||||
参数:
|
||||
audio_binary_list: 音频二进制列表
|
||||
models: 模型字典
|
||||
pipeline_list: 管道列表
|
||||
queue_size: 队列大小
|
||||
stop_timeout: 停止超时时间(秒)
|
||||
"""
|
||||
# 接收资源
|
||||
self._audio_binary_list = audio_binary_list
|
||||
self._models = models
|
||||
self._pipeline_list = pipeline_list
|
||||
|
||||
# 线程控制
|
||||
self._lock = Lock()
|
||||
|
||||
# 消息队列
|
||||
self._input_queue = Queue(maxsize=1000)
|
||||
|
||||
# 停止控制
|
||||
self._stop_timeout = 10.0
|
||||
self._is_stopping = False
|
||||
|
||||
# 配置资源
|
||||
for pipeline in self._pipeline_list:
|
||||
# 设置输入队列
|
||||
pipeline.set_input_queue(self._input_queue)
|
||||
|
||||
# 配置资源
|
||||
pipeline.set_audio_binary(
|
||||
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
|
||||
)
|
||||
pipeline.set_models(self._models)
|
||||
|
||||
def adder(self, data: Any) -> None:
|
||||
"""
|
||||
添加数据到输入队列
|
||||
参数:
|
||||
data: 要添加的数据
|
||||
"""
|
||||
if not self._pipeline_list:
|
||||
raise RuntimeError("没有可用的管道")
|
||||
if self._is_stopping:
|
||||
raise RuntimeError("运行器正在停止,无法添加数据")
|
||||
self._input_queue.put(data)
|
||||
|
||||
def add_recevier(self, receiver: callable) -> None:
|
||||
"""
|
||||
添加数据接收者
|
||||
参数:
|
||||
receiver: 接收数据的回调函数
|
||||
"""
|
||||
with self._lock:
|
||||
for pipeline in self._pipeline_list:
|
||||
pipeline.add_callback(receiver)
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
启动所有管道
|
||||
"""
|
||||
logger.info("[%s] 启动所有管道", self.__class__.__name__)
|
||||
if not self._pipeline_list:
|
||||
raise RuntimeError("没有可用的管道")
|
||||
|
||||
# 启动所有管道
|
||||
for pipeline in self._pipeline_list:
|
||||
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
|
||||
|
||||
def stop(self, force: bool = False) -> bool:
|
||||
"""
|
||||
停止所有管道
|
||||
参数:
|
||||
force: 是否强制停止
|
||||
返回:
|
||||
bool: 是否成功停止
|
||||
"""
|
||||
if self._is_stopping:
|
||||
logger.warning("运行器已经在停止中")
|
||||
return False
|
||||
|
||||
self._is_stopping = True
|
||||
logger.info("正在停止运行器...")
|
||||
|
||||
try:
|
||||
# 发送结束信号
|
||||
self._input_queue.put(None)
|
||||
|
||||
# 停止所有管道
|
||||
success = True
|
||||
for pipeline in self._pipeline_list:
|
||||
if force:
|
||||
pipeline.force_stop()
|
||||
else:
|
||||
if not pipeline.stop(timeout=self._stop_timeout):
|
||||
logger.warning("管道 %s 停止超时", id(pipeline))
|
||||
success = False
|
||||
|
||||
# 等待队列处理完成
|
||||
try:
|
||||
start_time = time.time()
|
||||
while not self._input_queue.empty():
|
||||
if time.time() - start_time > self._stop_timeout:
|
||||
logger.warning(
|
||||
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
|
||||
self._stop_timeout,
|
||||
self._input_queue.qsize(),
|
||||
)
|
||||
success = False
|
||||
break
|
||||
time.sleep(0.1) # 避免过度消耗CPU
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
error_msg = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
logger.error(
|
||||
"等待队列处理完成时发生错误:\n"
|
||||
"错误类型: %s\n"
|
||||
"错误信息: %s\n"
|
||||
"错误堆栈:\n%s",
|
||||
error_type,
|
||||
error_msg,
|
||||
error_traceback,
|
||||
)
|
||||
success = False
|
||||
|
||||
if success:
|
||||
logger.info("所有管道已成功停止")
|
||||
else:
|
||||
logger.warning(
|
||||
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
|
||||
self._input_queue.qsize(),
|
||||
self._input_queue.empty(),
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
self._is_stopping = False
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
析构函数
|
||||
"""
|
||||
self.stop(force=True)
|
||||
|
||||
|
||||
class STTRunnerFactory:
|
||||
"""
|
||||
STT Runner工厂类
|
||||
用于创建运行器实例
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _create_runner(
|
||||
audio_binary_name: str,
|
||||
model_name_list: List[str],
|
||||
pipeline_name_list: List[str],
|
||||
) -> STTRunner:
|
||||
"""
|
||||
创建运行器
|
||||
参数:
|
||||
audio_binary_name: 音频二进制名称
|
||||
model_name_list: 模型名称列表
|
||||
pipeline_name_list: 管道名称列表
|
||||
返回:
|
||||
Runner实例
|
||||
"""
|
||||
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
|
||||
models: Dict[str, Any] = {
|
||||
model_name: models_loaded.models[model_name]
|
||||
for model_name in model_name_list
|
||||
}
|
||||
pipelines: List[Pipeline] = [
|
||||
PipelineFactory.create_pipeline(pipeline_name)
|
||||
for pipeline_name in pipeline_name_list
|
||||
]
|
||||
return STTRunner(
|
||||
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_runner_from_config(
|
||||
cls,
|
||||
config: Dict[str, Any],
|
||||
) -> STTRunner:
|
||||
"""
|
||||
从配置创建运行器
|
||||
参数:
|
||||
config: 配置字典
|
||||
返回:
|
||||
Runner实例
|
||||
"""
|
||||
audio_binary_name = config["audio_binary_name"]
|
||||
model_name_list = config["model_name_list"]
|
||||
pipeline_name_list = config["pipeline_name_list"]
|
||||
return cls._create_runner(
|
||||
audio_binary_name, model_name_list, pipeline_name_list
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_runner_normal(cls) -> STTRunner:
|
||||
"""
|
||||
创建默认运行器
|
||||
返回:
|
||||
Runner实例
|
||||
"""
|
||||
audio_binary_name = None
|
||||
model_name_list = list(models_loaded.models.keys())
|
||||
pipeline_name_list = None
|
||||
return cls._create_runner(
|
||||
audio_binary_name, model_name_list, pipeline_name_list
|
||||
)
|
108
src/server.py
108
src/server.py
@ -43,7 +43,7 @@ async def clear_websocket():
|
||||
async def ws_serve(websocket, path):
|
||||
"""
|
||||
WebSocket服务主函数,处理客户端连接和消息
|
||||
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接对象
|
||||
path: 连接路径
|
||||
@ -51,13 +51,13 @@ async def ws_serve(websocket, path):
|
||||
frames = [] # 存储所有音频帧
|
||||
frames_asr = [] # 存储用于离线ASR的音频帧
|
||||
frames_asr_online = [] # 存储用于在线ASR的音频帧
|
||||
|
||||
|
||||
global websocket_users
|
||||
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
|
||||
|
||||
|
||||
# 添加到用户集合
|
||||
websocket_users.add(websocket)
|
||||
|
||||
|
||||
# 初始化连接状态
|
||||
websocket.status_dict_asr = {}
|
||||
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
|
||||
@ -66,15 +66,15 @@ async def ws_serve(websocket, path):
|
||||
websocket.chunk_interval = 10
|
||||
websocket.vad_pre_idx = 0
|
||||
websocket.is_speaking = True # 默认用户正在说话
|
||||
|
||||
|
||||
# 语音检测状态
|
||||
speech_start = False
|
||||
speech_end_i = -1
|
||||
|
||||
|
||||
# 初始化配置
|
||||
websocket.wav_name = "microphone"
|
||||
websocket.mode = "2pass" # 默认使用两阶段识别模式
|
||||
|
||||
|
||||
print("新用户已连接", flush=True)
|
||||
|
||||
try:
|
||||
@ -84,11 +84,13 @@ async def ws_serve(websocket, path):
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
messagejson = json.loads(message)
|
||||
|
||||
|
||||
# 更新各种配置参数
|
||||
if "is_speaking" in messagejson:
|
||||
websocket.is_speaking = messagejson["is_speaking"]
|
||||
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
|
||||
websocket.status_dict_asr_online["is_final"] = (
|
||||
not websocket.is_speaking
|
||||
)
|
||||
if "chunk_interval" in messagejson:
|
||||
websocket.chunk_interval = messagejson["chunk_interval"]
|
||||
if "wav_name" in messagejson:
|
||||
@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
|
||||
chunk_size = messagejson["chunk_size"]
|
||||
if isinstance(chunk_size, str):
|
||||
chunk_size = chunk_size.split(",")
|
||||
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
|
||||
websocket.status_dict_asr_online["chunk_size"] = [
|
||||
int(x) for x in chunk_size
|
||||
]
|
||||
if "encoder_chunk_look_back" in messagejson:
|
||||
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
|
||||
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
|
||||
messagejson["encoder_chunk_look_back"]
|
||||
)
|
||||
if "decoder_chunk_look_back" in messagejson:
|
||||
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
|
||||
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
|
||||
messagejson["decoder_chunk_look_back"]
|
||||
)
|
||||
if "hotword" in messagejson:
|
||||
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
|
||||
if "mode" in messagejson:
|
||||
@ -111,11 +119,17 @@ async def ws_serve(websocket, path):
|
||||
|
||||
# 根据chunk_interval更新VAD的chunk_size
|
||||
websocket.status_dict_vad["chunk_size"] = int(
|
||||
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
|
||||
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
|
||||
* 60
|
||||
/ websocket.chunk_interval
|
||||
)
|
||||
|
||||
|
||||
# 处理音频数据
|
||||
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
|
||||
if (
|
||||
len(frames_asr_online) > 0
|
||||
or len(frames_asr) >= 0
|
||||
or not isinstance(message, str)
|
||||
):
|
||||
if not isinstance(message, str): # 二进制音频数据
|
||||
# 添加到帧缓冲区
|
||||
frames.append(message)
|
||||
@ -125,10 +139,12 @@ async def ws_serve(websocket, path):
|
||||
# 处理在线ASR
|
||||
frames_asr_online.append(message)
|
||||
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
|
||||
|
||||
|
||||
# 达到chunk_interval或最终帧时处理在线ASR
|
||||
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
|
||||
websocket.status_dict_asr_online["is_final"]):
|
||||
if (
|
||||
len(frames_asr_online) % websocket.chunk_interval == 0
|
||||
or websocket.status_dict_asr_online["is_final"]
|
||||
):
|
||||
if websocket.mode == "2pass" or websocket.mode == "online":
|
||||
audio_in = b"".join(frames_asr_online)
|
||||
try:
|
||||
@ -136,26 +152,32 @@ async def ws_serve(websocket, path):
|
||||
except Exception as e:
|
||||
print(f"在线ASR处理错误: {e}")
|
||||
frames_asr_online = []
|
||||
|
||||
|
||||
# 如果检测到语音开始,收集帧用于离线ASR
|
||||
if speech_start:
|
||||
frames_asr.append(message)
|
||||
|
||||
|
||||
# VAD处理 - 语音活动检测
|
||||
try:
|
||||
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
|
||||
speech_start_i, speech_end_i = await asr_service.async_vad(
|
||||
websocket, message
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"VAD处理错误: {e}")
|
||||
|
||||
|
||||
# 检测到语音开始
|
||||
if speech_start_i != -1:
|
||||
speech_start = True
|
||||
# 计算开始偏移并收集前面的帧
|
||||
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
|
||||
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
|
||||
beg_bias = (
|
||||
websocket.vad_pre_idx - speech_start_i
|
||||
) // duration_ms
|
||||
frames_pre = (
|
||||
frames[-beg_bias:] if beg_bias < len(frames) else frames
|
||||
)
|
||||
frames_asr = []
|
||||
frames_asr.extend(frames_pre)
|
||||
|
||||
|
||||
# 处理离线ASR (语音结束或用户停止说话)
|
||||
if speech_end_i != -1 or not websocket.is_speaking:
|
||||
if websocket.mode == "2pass" or websocket.mode == "offline":
|
||||
@ -164,13 +186,13 @@ async def ws_serve(websocket, path):
|
||||
await asr_service.async_asr(websocket, audio_in)
|
||||
except Exception as e:
|
||||
print(f"离线ASR处理错误: {e}")
|
||||
|
||||
|
||||
# 重置状态
|
||||
frames_asr = []
|
||||
speech_start = False
|
||||
frames_asr_online = []
|
||||
websocket.status_dict_asr_online["cache"] = {}
|
||||
|
||||
|
||||
# 如果用户停止说话,完全重置
|
||||
if not websocket.is_speaking:
|
||||
websocket.vad_pre_idx = 0
|
||||
@ -193,34 +215,34 @@ async def ws_serve(websocket, path):
|
||||
def start_server(args, asr_service_instance):
|
||||
"""
|
||||
启动WebSocket服务器
|
||||
|
||||
|
||||
参数:
|
||||
args: 命令行参数
|
||||
asr_service_instance: ASR服务实例
|
||||
"""
|
||||
global asr_service
|
||||
asr_service = asr_service_instance
|
||||
|
||||
|
||||
# 配置SSL (如果提供了证书)
|
||||
if args.certfile and len(args.certfile) > 0:
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
|
||||
|
||||
|
||||
start_server = websockets.serve(
|
||||
ws_serve, args.host, args.port,
|
||||
subprotocols=["binary"],
|
||||
ping_interval=None,
|
||||
ssl=ssl_context
|
||||
ws_serve,
|
||||
args.host,
|
||||
args.port,
|
||||
subprotocols=["binary"],
|
||||
ping_interval=None,
|
||||
ssl=ssl_context,
|
||||
)
|
||||
else:
|
||||
start_server = websockets.serve(
|
||||
ws_serve, args.host, args.port,
|
||||
subprotocols=["binary"],
|
||||
ping_interval=None
|
||||
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
|
||||
)
|
||||
|
||||
|
||||
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
|
||||
|
||||
|
||||
# 启动事件循环
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
asyncio.get_event_loop().run_forever()
|
||||
@ -229,14 +251,14 @@ def start_server(args, asr_service_instance):
|
||||
if __name__ == "__main__":
|
||||
# 解析命令行参数
|
||||
args = parse_args()
|
||||
|
||||
|
||||
# 加载模型
|
||||
print("正在加载模型...")
|
||||
models = load_models(args)
|
||||
print("模型加载完成!当前仅支持单个客户端同时连接!")
|
||||
|
||||
|
||||
# 创建ASR服务
|
||||
asr_service = ASRService(models)
|
||||
|
||||
|
||||
# 启动服务器
|
||||
start_server(args, asr_service)
|
||||
start_server(args, asr_service)
|
||||
|
@ -9,11 +9,11 @@ import json
|
||||
|
||||
class ASRService:
|
||||
"""ASR服务类,封装各种语音识别相关功能"""
|
||||
|
||||
|
||||
def __init__(self, models):
|
||||
"""
|
||||
初始化ASR服务
|
||||
|
||||
|
||||
参数:
|
||||
models: 包含各种预加载模型的字典
|
||||
"""
|
||||
@ -21,42 +21,41 @@ class ASRService:
|
||||
self.model_asr_streaming = models["asr_streaming"]
|
||||
self.model_vad = models["vad"]
|
||||
self.model_punc = models["punc"]
|
||||
|
||||
|
||||
async def async_vad(self, websocket, audio_in):
|
||||
"""
|
||||
语音活动检测
|
||||
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接
|
||||
audio_in: 二进制音频数据
|
||||
|
||||
|
||||
返回:
|
||||
tuple: (speech_start, speech_end) 语音开始和结束位置
|
||||
"""
|
||||
# 使用VAD模型分析音频段
|
||||
segments_result = self.model_vad.generate(
|
||||
input=audio_in,
|
||||
**websocket.status_dict_vad
|
||||
input=audio_in, **websocket.status_dict_vad
|
||||
)[0]["value"]
|
||||
|
||||
|
||||
speech_start = -1
|
||||
speech_end = -1
|
||||
|
||||
|
||||
# 解析VAD结果
|
||||
if len(segments_result) == 0 or len(segments_result) > 1:
|
||||
return speech_start, speech_end
|
||||
|
||||
|
||||
if segments_result[0][0] != -1:
|
||||
speech_start = segments_result[0][0]
|
||||
if segments_result[0][1] != -1:
|
||||
speech_end = segments_result[0][1]
|
||||
|
||||
|
||||
return speech_start, speech_end
|
||||
|
||||
|
||||
async def async_asr(self, websocket, audio_in):
|
||||
"""
|
||||
离线ASR处理
|
||||
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接
|
||||
audio_in: 二进制音频数据
|
||||
@ -64,42 +63,44 @@ class ASRService:
|
||||
if len(audio_in) > 0:
|
||||
# 使用离线ASR模型处理音频
|
||||
rec_result = self.model_asr.generate(
|
||||
input=audio_in,
|
||||
**websocket.status_dict_asr
|
||||
input=audio_in, **websocket.status_dict_asr
|
||||
)[0]
|
||||
|
||||
|
||||
# 如果有标点符号模型且识别出文本,则添加标点
|
||||
if self.model_punc is not None and len(rec_result["text"]) > 0:
|
||||
rec_result = self.model_punc.generate(
|
||||
input=rec_result["text"],
|
||||
**websocket.status_dict_punc
|
||||
input=rec_result["text"], **websocket.status_dict_punc
|
||||
)[0]
|
||||
|
||||
|
||||
# 如果识别出文本,发送到客户端
|
||||
if len(rec_result["text"]) > 0:
|
||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||
message = json.dumps({
|
||||
"mode": mode,
|
||||
"text": rec_result["text"],
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
})
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": mode,
|
||||
"text": rec_result["text"],
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
||||
else:
|
||||
# 如果没有音频数据,发送空文本
|
||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||
message = json.dumps({
|
||||
"mode": mode,
|
||||
"text": "",
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
})
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": mode,
|
||||
"text": "",
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
||||
|
||||
|
||||
async def async_asr_online(self, websocket, audio_in):
|
||||
"""
|
||||
在线ASR处理
|
||||
|
||||
|
||||
参数:
|
||||
websocket: WebSocket连接
|
||||
audio_in: 二进制音频数据
|
||||
@ -107,21 +108,24 @@ class ASRService:
|
||||
if len(audio_in) > 0:
|
||||
# 使用在线ASR模型处理音频
|
||||
rec_result = self.model_asr_streaming.generate(
|
||||
input=audio_in,
|
||||
**websocket.status_dict_asr_online
|
||||
input=audio_in, **websocket.status_dict_asr_online
|
||||
)[0]
|
||||
|
||||
|
||||
# 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理)
|
||||
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
|
||||
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
|
||||
"is_final", False
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
# 如果识别出文本,发送到客户端
|
||||
if len(rec_result["text"]):
|
||||
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
|
||||
message = json.dumps({
|
||||
"mode": mode,
|
||||
"text": rec_result["text"],
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
})
|
||||
await websocket.send(message)
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": mode,
|
||||
"text": rec_result["text"],
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
||||
|
3
src/utils/__init__.py
Normal file
3
src/utils/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .logger import get_module_logger, setup_root_logger
|
||||
|
||||
__all__ = ["get_module_logger", "setup_root_logger"]
|
122
src/utils/data_format.py
Normal file
122
src/utils/data_format.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""
|
||||
处理各类音频数据与bytes的转换
|
||||
"""
|
||||
|
||||
import wave
|
||||
from pydub import AudioSegment
|
||||
import io
|
||||
|
||||
|
||||
def wav_to_bytes(wav_path: str) -> bytes:
|
||||
"""
|
||||
将WAV文件读取为bytes数据。
|
||||
|
||||
参数:
|
||||
wav_path (str): WAV文件的路径。
|
||||
|
||||
返回:
|
||||
bytes: WAV文件的原始字节数据。
|
||||
|
||||
异常:
|
||||
FileNotFoundError: 如果WAV文件不存在。
|
||||
wave.Error: 如果文件不是有效的WAV文件。
|
||||
"""
|
||||
try:
|
||||
with wave.open(wav_path, "rb") as wf:
|
||||
# 读取所有帧
|
||||
frames = wf.readframes(wf.getnframes())
|
||||
return frames
|
||||
except FileNotFoundError:
|
||||
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
|
||||
raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
|
||||
except wave.Error as e:
|
||||
raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
|
||||
|
||||
|
||||
def bytes_to_wav(
|
||||
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
|
||||
):
|
||||
"""
|
||||
将bytes数据写入为WAV文件。
|
||||
|
||||
参数:
|
||||
bytes_data (bytes): 音频的字节数据。
|
||||
wav_path (str): 保存WAV文件的路径。
|
||||
nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)。
|
||||
sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)。
|
||||
framerate (int): 采样率 (例如 44100, 16000)。
|
||||
|
||||
异常:
|
||||
wave.Error: 如果写入WAV文件失败。
|
||||
"""
|
||||
try:
|
||||
with wave.open(wav_path, "wb") as wf:
|
||||
wf.setnchannels(nchannels)
|
||||
wf.setsampwidth(sampwidth)
|
||||
wf.setframerate(framerate)
|
||||
wf.writeframes(bytes_data)
|
||||
except wave.Error as e:
|
||||
raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
|
||||
except Exception as e:
|
||||
# 捕获其他可能的写入错误
|
||||
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
|
||||
|
||||
|
||||
def mp3_to_bytes(mp3_path: str) -> bytes:
|
||||
"""
|
||||
将MP3文件转换为bytes数据 (原始PCM数据)。
|
||||
|
||||
参数:
|
||||
mp3_path (str): MP3文件的路径。
|
||||
|
||||
返回:
|
||||
bytes: MP3文件解码后的原始PCM字节数据。
|
||||
|
||||
异常:
|
||||
FileNotFoundError: 如果MP3文件不存在。
|
||||
pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码。
|
||||
"""
|
||||
try:
|
||||
audio = AudioSegment.from_mp3(mp3_path)
|
||||
# 获取原始PCM数据
|
||||
return audio.raw_data
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
|
||||
except Exception as e: # pydub 可能抛出多种解码相关的错误
|
||||
raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
|
||||
|
||||
|
||||
def bytes_to_mp3(
|
||||
bytes_data: bytes,
|
||||
mp3_path: str,
|
||||
frame_rate: int,
|
||||
channels: int,
|
||||
sample_width: int,
|
||||
bitrate: str = "192k",
|
||||
):
|
||||
"""
|
||||
将原始PCM bytes数据转换为MP3文件。
|
||||
|
||||
参数:
|
||||
bytes_data (bytes): 原始PCM字节数据。
|
||||
mp3_path (str): 保存MP3文件的路径。
|
||||
frame_rate (int): 原始PCM数据的采样率 (例如 44100)。
|
||||
channels (int): 原始PCM数据的声道数 (例如 1 for mono, 2 for stereo)。
|
||||
sample_width (int): 原始PCM数据的采样宽度 (字节数, 例如 2 for 16-bit)。
|
||||
bitrate (str): MP3编码的比特率 (例如 "128k", "192k", "320k")。
|
||||
|
||||
异常:
|
||||
Exception: 如果转换或写入MP3文件失败。
|
||||
"""
|
||||
try:
|
||||
# 从原始数据创建AudioSegment对象
|
||||
audio = AudioSegment(
|
||||
data=bytes_data,
|
||||
sample_width=sample_width,
|
||||
frame_rate=frame_rate,
|
||||
channels=channels,
|
||||
)
|
||||
# 导出为MP3
|
||||
audio.export(mp3_path, format="mp3", bitrate=bitrate)
|
||||
except Exception as e:
|
||||
raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
|
91
src/utils/logger.py
Normal file
91
src/utils/logger.py
Normal file
@ -0,0 +1,91 @@
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def setup_logger(
|
||||
name: str = None,
|
||||
level: str = "INFO",
|
||||
log_file: Optional[str] = None,
|
||||
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
date_format: str = "%Y-%m-%d %H:%M:%S",
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
设置并返回一个配置好的logger实例
|
||||
|
||||
Args:
|
||||
name: logger的名称,默认为None(使用root logger)
|
||||
level: 日志级别,默认为"INFO"
|
||||
log_file: 日志文件路径,默认为None(仅控制台输出)
|
||||
log_format: 日志格式
|
||||
date_format: 日期格式
|
||||
|
||||
Returns:
|
||||
logging.Logger: 配置好的logger实例
|
||||
"""
|
||||
# 获取logger实例
|
||||
logger = logging.getLogger(name)
|
||||
|
||||
# 设置日志级别
|
||||
level = getattr(logging, level.upper())
|
||||
logger.setLevel(level)
|
||||
|
||||
print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
|
||||
# 创建格式器
|
||||
formatter = logging.Formatter(log_format, date_format)
|
||||
|
||||
# 添加控制台处理器
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# 如果指定了日志文件,添加文件处理器
|
||||
if log_file:
|
||||
# 确保日志目录存在
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 注意:移除了 propagate = False,允许日志传递
|
||||
return logger
|
||||
|
||||
|
||||
def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
|
||||
"""
|
||||
配置根日志器
|
||||
|
||||
Args:
|
||||
level: 日志级别
|
||||
log_file: 日志文件路径
|
||||
"""
|
||||
setup_logger(None, level, log_file)
|
||||
|
||||
|
||||
def get_module_logger(
|
||||
module_name: str,
|
||||
level: Optional[str] = None, # 改为可选参数
|
||||
log_file: Optional[str] = None, # 一般不需要单独指定
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
获取模块级别的logger
|
||||
|
||||
Args:
|
||||
module_name: 模块名称,通常传入__name__
|
||||
level: 可选的日志级别,如果不指定则继承父级配置
|
||||
log_file: 可选的日志文件路径,一般不需要指定
|
||||
"""
|
||||
logger = logging.getLogger(module_name)
|
||||
|
||||
# 只有显式指定了level才设置
|
||||
if level:
|
||||
logger.setLevel(getattr(logging, level.upper()))
|
||||
|
||||
# 只有显式指定了log_file才添加文件处理器
|
||||
if log_file:
|
||||
setup_logger(module_name, level or "INFO", log_file)
|
||||
|
||||
return logger
|
17
test_main.py
Normal file
17
test_main.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""
|
||||
测试主函数
|
||||
请在tests目录下创建测试文件, 并在此文件中调用
|
||||
"""
|
||||
|
||||
from tests.pipeline.asr_test import test_asr_pipeline
|
||||
from src.utils.logger import get_module_logger, setup_root_logger
|
||||
|
||||
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
# from tests.functor.vad_test import test_vad_functor
|
||||
# logger.info("开始测试VAD函数器")
|
||||
# test_vad_functor()
|
||||
|
||||
logger.info("开始测试ASR管道")
|
||||
test_asr_pipeline()
|
@ -1 +1 @@
|
||||
"""FunASR WebSocket服务测试模块"""
|
||||
"""FunASR WebSocket服务测试模块"""
|
||||
|
124
tests/functor/vad_test.py
Normal file
124
tests/functor/vad_test.py
Normal file
@ -0,0 +1,124 @@
|
||||
"""
|
||||
Functor测试
|
||||
VAD测试
|
||||
"""
|
||||
|
||||
from src.functor.vad_functor import VADFunctor
|
||||
from src.functor.asr_functor import ASRFunctor
|
||||
from src.functor.spk_functor import SPKFunctor
|
||||
from queue import Queue, Empty
|
||||
from src.model_loader import ModelLoader
|
||||
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||
from src.utils.data_format import wav_to_bytes
|
||||
import time
|
||||
from src.utils.logger import get_module_logger
|
||||
from pydub import AudioSegment
|
||||
import soundfile
|
||||
|
||||
# 观察参数
|
||||
OVERWATCH = False
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
model_loader = ModelLoader()
|
||||
|
||||
|
||||
def test_vad_functor():
|
||||
# 加载模型
|
||||
args = {
|
||||
"asr_model": "paraformer-zh",
|
||||
"asr_model_revision": "v2.0.4",
|
||||
"vad_model": "fsmn-vad",
|
||||
"vad_model_revision": "v2.0.4",
|
||||
"auto_update": False,
|
||||
}
|
||||
model_loader.load_models(args)
|
||||
# 加载数据
|
||||
f_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||
audio_config = AudioBinary_Config(
|
||||
chunk_size=200,
|
||||
chunk_stride=1600,
|
||||
sample_rate=sample_rate,
|
||||
sample_width=16,
|
||||
channels=1,
|
||||
)
|
||||
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||
audio_config.chunk_stride = chunk_stride
|
||||
# 创建输入队列
|
||||
input_queue = Queue()
|
||||
vad2asr_queue = Queue()
|
||||
vad2spk_queue = Queue()
|
||||
# 创建音频数据列表
|
||||
audio_binary_data_list = AudioBinary_data_list()
|
||||
|
||||
# 创建VAD函数器
|
||||
vad_functor = VADFunctor()
|
||||
# 设置输入队列
|
||||
vad_functor.set_input_queue(input_queue)
|
||||
# 设置音频配置
|
||||
vad_functor.set_audio_config(audio_config)
|
||||
# 设置音频数据列表
|
||||
vad_functor.set_audio_binary_data_list(audio_binary_data_list)
|
||||
# 设置回调函数
|
||||
vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
|
||||
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
|
||||
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
|
||||
# 设置模型
|
||||
vad_functor.set_model({"vad": model_loader.models["vad"]})
|
||||
# 启动VAD函数器
|
||||
vad_functor.run()
|
||||
|
||||
# 创建ASR函数器
|
||||
asr_functor = ASRFunctor()
|
||||
# 设置输入队列
|
||||
asr_functor.set_input_queue(vad2asr_queue)
|
||||
# 设置音频配置
|
||||
asr_functor.set_audio_config(audio_config)
|
||||
# 设置回调函数
|
||||
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
|
||||
# 设置模型
|
||||
asr_functor.set_model({"asr": model_loader.models["asr"]})
|
||||
# 启动ASR函数器
|
||||
asr_functor.run()
|
||||
|
||||
# 创建SPK函数器
|
||||
spk_functor = SPKFunctor()
|
||||
# 设置输入队列
|
||||
spk_functor.set_input_queue(vad2spk_queue)
|
||||
# 设置音频配置
|
||||
spk_functor.set_audio_config(audio_config)
|
||||
# 设置回调函数
|
||||
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
|
||||
# 设置模型
|
||||
spk_functor.set_model(
|
||||
{
|
||||
# 'spk': model_loader.models['spk']
|
||||
"spk": "fake_spk"
|
||||
}
|
||||
)
|
||||
# 启动SPK函数器
|
||||
spk_functor.run()
|
||||
|
||||
f_binary = f_data
|
||||
audio_clip_len = 200
|
||||
print(
|
||||
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
|
||||
)
|
||||
for i in range(0, len(f_binary), audio_clip_len):
|
||||
binary_data = f_binary[i : i + audio_clip_len]
|
||||
input_queue.put(binary_data)
|
||||
# 等待VAD函数器结束
|
||||
|
||||
vad_functor.stop()
|
||||
print("[vad_test] VAD函数器结束")
|
||||
|
||||
asr_functor.stop()
|
||||
print("[vad_test] ASR函数器结束")
|
||||
|
||||
# 保存音频数据
|
||||
if OVERWATCH:
|
||||
for index in range(len(audio_binary_data_list)):
|
||||
save_path = f"tests/vad_test_output_{index}.wav"
|
||||
soundfile.write(
|
||||
save_path, audio_binary_data_list[index].binary_data, sample_rate
|
||||
)
|
121
tests/modelsuse.py
Normal file
121
tests/modelsuse.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""
|
||||
模型使用测试
|
||||
此处主要用于各类调用模型的处理数据与输出格式
|
||||
请在主目录下test_main.py中调用
|
||||
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配。
|
||||
"""
|
||||
|
||||
from funasr import AutoModel
|
||||
from typing import List, Dict, Any
|
||||
from src.models import VADResponse
|
||||
import time
|
||||
|
||||
|
||||
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
在线VAD模型使用
|
||||
"""
|
||||
chunk_size = 100 # ms
|
||||
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||
|
||||
vad_result = VADResponse()
|
||||
vad_result.time_chunk_index_callback = lambda index: print(f"回调: {index}")
|
||||
items = []
|
||||
import soundfile
|
||||
|
||||
speech, sample_rate = soundfile.read(file_path)
|
||||
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||
|
||||
cache = {}
|
||||
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
|
||||
for i in range(total_chunk_num):
|
||||
time.sleep(0.1)
|
||||
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||
is_final = i == total_chunk_num - 1
|
||||
res = model.generate(
|
||||
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
|
||||
)
|
||||
if len(res[0]["value"]):
|
||||
vad_result += VADResponse.from_raw(res)
|
||||
for item in res[0]["value"]:
|
||||
items.append(item)
|
||||
vad_result.process_time_chunk()
|
||||
|
||||
# for item in items:
|
||||
# print(item)
|
||||
return vad_result
|
||||
|
||||
|
||||
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
在线VAD模型使用
|
||||
测试LogicTrager
|
||||
在Rebuild版本后LogicTrager中已弃用
|
||||
"""
|
||||
from src.logic_trager import LogicTrager
|
||||
import soundfile
|
||||
|
||||
from src.config import parse_args
|
||||
|
||||
args = parse_args()
|
||||
|
||||
# from src.functor.model_loader import load_models
|
||||
# models = load_models(args)
|
||||
from src.model_loader import ModelLoader
|
||||
|
||||
models = ModelLoader(args)
|
||||
|
||||
chunk_size = 200 # ms
|
||||
from src.models import AudioBinary_Config
|
||||
import soundfile
|
||||
|
||||
speech, sample_rate = soundfile.read(file_path)
|
||||
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||
audio_config = AudioBinary_Config(
|
||||
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
logic_trager = LogicTrager(models=models, audio_config=audio_config)
|
||||
for i in range(len(speech) // chunk_stride + 1):
|
||||
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||
logic_trager.push_binary_data(speech_chunk)
|
||||
|
||||
# for item in items:
|
||||
# print(item)
|
||||
return None
|
||||
|
||||
|
||||
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
ASR模型使用
|
||||
离线ASR模型使用
|
||||
"""
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(
|
||||
model="paraformer-zh",
|
||||
model_revision="v2.0.4",
|
||||
vad_model="fsmn-vad",
|
||||
vad_model_revision="v2.0.4",
|
||||
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
|
||||
spk_model="cam++",
|
||||
spk_model_revision="v2.0.2",
|
||||
spk_mode="vad_segment",
|
||||
auto_update=False,
|
||||
)
|
||||
|
||||
import soundfile
|
||||
|
||||
from src.models import AudioBinary_Config
|
||||
import soundfile
|
||||
|
||||
speech, sample_rate = soundfile.read(file_path)
|
||||
result = model.generate(speech)
|
||||
return result
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# 请在主目录下调用test_main.py文件进行测试
|
||||
# vad_result = vad_model_use_online("tests/vad_example.wav")
|
||||
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
|
||||
# print(vad_result)
|
93
tests/pipeline/asr_test.py
Normal file
93
tests/pipeline/asr_test.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""
|
||||
Pipeline测试
|
||||
VAD+ASR+SPK(FAKE)
|
||||
"""
|
||||
|
||||
from src.pipeline.ASRpipeline import ASRPipeline
|
||||
from src.pipeline import PipelineFactory
|
||||
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||
from src.model_loader import ModelLoader
|
||||
from queue import Queue
|
||||
import soundfile
|
||||
import time
|
||||
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
OVAERWATCH = False
|
||||
|
||||
model_loader = ModelLoader()
|
||||
|
||||
|
||||
def test_asr_pipeline():
|
||||
# 加载模型
|
||||
args = {
|
||||
"asr_model": "paraformer-zh",
|
||||
"asr_model_revision": "v2.0.4",
|
||||
"vad_model": "fsmn-vad",
|
||||
"vad_model_revision": "v2.0.4",
|
||||
"spk_model": "cam++",
|
||||
"spk_model_revision": "v2.0.2",
|
||||
"audio_update": False,
|
||||
}
|
||||
models = model_loader.load_models(args)
|
||||
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||
audio_config = AudioBinary_Config(
|
||||
chunk_size=200,
|
||||
chunk_stride=1600,
|
||||
sample_rate=sample_rate,
|
||||
sample_width=16,
|
||||
channels=1,
|
||||
)
|
||||
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||
audio_config.chunk_stride = chunk_stride
|
||||
|
||||
# 创建参数Dict
|
||||
config = {
|
||||
"audio_config": audio_config,
|
||||
}
|
||||
|
||||
# 创建音频数据列表
|
||||
audio_binary_data_list = AudioBinary_data_list()
|
||||
|
||||
input_queue = Queue()
|
||||
|
||||
# 创建Pipeline
|
||||
# asr_pipeline = ASRPipeline()
|
||||
# asr_pipeline.set_models(models)
|
||||
# asr_pipeline.set_config(config)
|
||||
# asr_pipeline.set_audio_binary(audio_binary_data_list)
|
||||
# asr_pipeline.set_input_queue(input_queue)
|
||||
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
|
||||
# asr_pipeline.bake()
|
||||
asr_pipeline = PipelineFactory.create_pipeline(
|
||||
pipeline_name = "ASRpipeline",
|
||||
models=models,
|
||||
config=config,
|
||||
audio_binary=audio_binary_data_list,
|
||||
input_queue=input_queue,
|
||||
callback=lambda x: print(f"pipeline callback: {x}")
|
||||
)
|
||||
|
||||
# 运行Pipeline
|
||||
asr_instance = asr_pipeline.run()
|
||||
|
||||
|
||||
audio_clip_len = 200
|
||||
print(
|
||||
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
|
||||
)
|
||||
for i in range(0, len(audio_data), audio_clip_len):
|
||||
input_queue.put(audio_data[i : i + audio_clip_len])
|
||||
|
||||
# time.sleep(10)
|
||||
# input_queue.put(None)
|
||||
|
||||
# 等待Pipeline结束
|
||||
# asr_instance.join()
|
||||
|
||||
time.sleep(5)
|
||||
asr_pipeline.stop()
|
||||
# asr_pipeline.stop()
|
@ -10,23 +10,23 @@ import os
|
||||
from unittest.mock import patch
|
||||
|
||||
# 将src目录添加到路径
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.config import parse_args
|
||||
|
||||
|
||||
def test_default_args():
|
||||
"""测试默认参数值"""
|
||||
with patch('sys.argv', ['script.py']):
|
||||
with patch("sys.argv", ["script.py"]):
|
||||
args = parse_args()
|
||||
|
||||
|
||||
# 检查服务器参数
|
||||
assert args.host == "0.0.0.0"
|
||||
assert args.port == 10095
|
||||
|
||||
|
||||
# 检查SSL参数
|
||||
assert args.certfile == ""
|
||||
assert args.keyfile == ""
|
||||
|
||||
|
||||
# 检查模型参数
|
||||
assert "paraformer" in args.asr_model
|
||||
assert args.asr_model_revision == "v2.0.4"
|
||||
@ -36,7 +36,7 @@ def test_default_args():
|
||||
assert args.vad_model_revision == "v2.0.4"
|
||||
assert "punc" in args.punc_model
|
||||
assert args.punc_model_revision == "v2.0.4"
|
||||
|
||||
|
||||
# 检查硬件配置
|
||||
assert args.ngpu == 1
|
||||
assert args.device == "cuda"
|
||||
@ -46,19 +46,26 @@ def test_default_args():
|
||||
def test_custom_args():
|
||||
"""测试自定义参数值"""
|
||||
test_args = [
|
||||
'script.py',
|
||||
'--host', 'localhost',
|
||||
'--port', '8080',
|
||||
'--certfile', 'cert.pem',
|
||||
'--keyfile', 'key.pem',
|
||||
'--asr_model', 'custom_model',
|
||||
'--ngpu', '0',
|
||||
'--device', 'cpu'
|
||||
"script.py",
|
||||
"--host",
|
||||
"localhost",
|
||||
"--port",
|
||||
"8080",
|
||||
"--certfile",
|
||||
"cert.pem",
|
||||
"--keyfile",
|
||||
"key.pem",
|
||||
"--asr_model",
|
||||
"custom_model",
|
||||
"--ngpu",
|
||||
"0",
|
||||
"--device",
|
||||
"cpu",
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args):
|
||||
|
||||
with patch("sys.argv", test_args):
|
||||
args = parse_args()
|
||||
|
||||
|
||||
# 检查自定义参数
|
||||
assert args.host == "localhost"
|
||||
assert args.port == 8080
|
||||
@ -66,4 +73,4 @@ def test_custom_args():
|
||||
assert args.keyfile == "key.pem"
|
||||
assert args.asr_model == "custom_model"
|
||||
assert args.ngpu == 0
|
||||
assert args.device == "cpu"
|
||||
assert args.device == "cpu"
|
||||
|
BIN
tests/vad_example.wav
Normal file
BIN
tests/vad_example.wav
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user