Compare commits

...

16 Commits

Author SHA1 Message Date
Ziyang.Zhang
5a820b49e4 [代码结构]black . 对所有文件格式调整,无功能变化。 2025-06-12 15:49:43 +08:00
Ziyang.Zhang
5b94c40016 [代码重构中]编写融合VAD,ASR,SPK(FAKE)的ASRPipeline并完成测试,正常运行。 2025-06-06 17:26:08 +08:00
Ziyang.Zhang
3d8bf9de25 [代码重构中]创建假的SPKFunctor以测试消息队列流程是否正确,无问题,待进一步实现说话人识别,此外,考虑将一些共有内容写入BaseFunctor中。 2025-06-05 17:08:42 +08:00
Ziyang.Zhang
ff9bd70039 [代码重构中]初步构建ASRFunctor,与VADFunctor在vad_test.py中进行联调无问题,数据衔接正常。 2025-06-05 15:57:11 +08:00
Ziyang.Zhang
4e9e94d8dc [代码重构中]完善VADFunctor,测试持久化保存VAD片段的音频数据成功。 2025-06-05 13:43:23 +08:00
Ziyang.Zhang
b569b7e63d [代码重构中]测试VADFuntor中,发现字节流推理问题,待进一步研究 2025-06-03 17:41:59 +08:00
Ziyang.Zhang
f245c6e9df [代码重构中]编写ASRpipeline,管理funtor的线程启动,管理funtor间消息队列queue 2025-06-03 09:19:15 +08:00
Ziyang.Zhang
49cb428c23 [代码重构中]编写class STT_Runner中,将设计为线程启动。作为异步IO与资源管理模块。 2025-05-28 18:00:54 +08:00
Ziyang.Zhang
703a40e955 [代码重构中]重构model_loader与audio_chunk,全局单例模式管理模型加载与audiobinary数据存储单元类。删除readme中不需要的MIT许可证。 2025-05-28 10:35:35 +08:00
Ziyang.Zhang
040fc57e02 [代码重构中]重构Functor下的函数定义,修改为一个BaseFunctor+ModelLoader+DataCache进行基底构建。 2025-05-21 11:49:28 +08:00
Ziyang.Zhang
1392168126 Merge branch 'feature_logger' into dev
[Feature] 添加了logger用于管理日志,同时测试了ASR、PUNC、SPK模型效果;
[BUG] 发现BUG:使用funasr的一些模块会导致logger被更改,这一点需要进一步讨论解决方案
2025-04-16 14:30:40 +08:00
Ziyang.Zhang
eff22cb33e [Feature] 测试了后续的ASR、punc、spk效果; BUG:在调用funasr后,logger信息会被改变,导致格式变化,重复输出。 2025-04-16 14:30:11 +08:00
Ziyang.Zhang
66c9477e4b [Feature] 添加src/utils/logger文件控制程序日志输出,包括一个root配置器和logger生成器。 2025-04-16 10:46:09 +08:00
9d522fa137 Merge branch 'feature_vad' into dev
[项目结构变动] 分离了模型加载、功能实现、整体工作流等内容
[功能开发] 使用pydantic规范数据格式;开发VAD声音端点检测functor;
[测试] 完成了本地流式(online)的VAD检测,完成了 logic_traher(仅包含VAD与VAD检测结果)的工作流程测试
[未来内容] 1.完成ASR、时间戳、说话人识别;2.接入websocket服务。
2025-04-15 17:18:48 +08:00
f7138dcb39 [Feature] 调整VAD工作流程,规范VAD产出数据规范为 models/audiobinary中的AudioBinary_Chunk;完整测试LogicTrager VAD online流程。 2025-04-15 17:15:13 +08:00
8b69ff195f [Feature] Add /tests/modelsuse 测试实时VAD检测。 2025-04-15 13:53:06 +08:00
36 changed files with 3184 additions and 454 deletions

19
.cursorrules Normal file
View File

@ -0,0 +1,19 @@
You are an AI assistant specialized in Python development. Your approach emphasizes:
1. Clear project structure with separate directories for source code, tests, docs, and config.
2. Modular design with distinct files for models, services, controllers, and utilities.
3. Configuration management using environment variables.
4. Robust error handling and logging, including context capture.
5. Comprehensive testing with pytest.
6. Detailed documentation using docstrings and README files.
7. Dependency management via https://github.com/astral-sh/rye and virtual environments.
8. Code style consistency using Ruff.
9. CI/CD implementation with GitHub Actions or GitLab CI.
10. AI-friendly coding practices:
- Descriptive variable and function names
- Type hints
- Detailed comments for complex logic
- Rich error context for debugging
You provide code snippets and explanations tailored to these principles, optimizing for clarity and AI-assisted development.

View File

@ -108,6 +108,3 @@ docker-compose up -d
"is_final": false // 是否是最终结果
}
```
## 许可证
[MIT](LICENSE)

30
main.py Normal file
View File

@ -0,0 +1,30 @@
"""Streaming VAD demo: feed a WAV file to the FunASR fsmn-vad model chunk by chunk."""
import soundfile
from funasr import AutoModel

chunk_size = 200  # length of one streamed chunk, in ms

model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)

wav_file = "tests/vad_example.wav"
speech, sample_rate = soundfile.read(wav_file)
# Number of samples contained in one chunk of chunk_size milliseconds.
chunk_stride = int(chunk_size * sample_rate / 1000)

# The same dict is handed to every generate() call so state can carry over
# from one chunk to the next.
cache = {}
# Ceiling division: the original `len((speech) - 1)` had misplaced parentheses
# and produced one extra, empty chunk whenever len(speech) was an exact
# multiple of chunk_stride. At least one call is always made so that the
# is_final=True chunk is sent even for empty input.
total_chunk_num = max(1, (len(speech) + chunk_stride - 1) // chunk_stride)

for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1
    res = model.generate(
        input=speech_chunk,
        cache=cache,
        is_final=is_final,
        chunk_size=chunk_size,
        disable_pbar=True,
    )
    # Only print chunks for which the model reported at least one endpoint.
    if len(res[0]["value"]):
        print(res)

# Summary of the run; speech_chunk is the last chunk that was fed to the model.
print(f"len(speech): {len(speech)}")
print(f"len(speech_chunk): {len(speech_chunk)}")
print(f"total_chunk_num: {total_chunk_num}")
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")

View File

@ -11,4 +11,4 @@ FunASR WebSocket服务
- 支持多种识别模式(2pass/online/offline)
"""
__version__ = "0.1.0"
__version__ = "0.1.0"

194
src/audio_chunk.py Normal file
View File

@ -0,0 +1,194 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
音频数据块管理类 - 用于存储和处理16KHz音频数据
"""
from typing import List, Optional, Dict
from src.models import AudioBinary_Config, AudioBinary_data_list
# 配置日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__, level="INFO")
class AudioBinary:
    """
    Storage unit for the binary audio data of one stream.

    It serves slices: it stores the chunks and offers the accessors that
    slice consumers use.

    Attributes:
        _audio_config: AudioBinary_Config -- audio parameter configuration
        _binary_data_list: AudioBinary_data_list -- stored audio chunks
        _slice_listener: List[callable] -- registered slice listeners
    """

    def __init__(self, *args):
        """
        Create the storage unit with an optional audio configuration.

        Args:
            *args: only the first positional argument is inspected. It may be
                omitted or None (a default AudioBinary_Config is created),
                a dict (converted with
                AudioBinary_Config.AudioBinary_Config_from_dict) or an
                AudioBinary_Config instance (used as is).

        Raises:
            ValueError: if the first argument has any other type.
        """
        # Stored audio chunks.
        self._binary_data_list: AudioBinary_data_list = AudioBinary_data_list()
        # Slice listeners.
        self._slice_listener: List = []

        # `args` itself is always a tuple, so the configuration has to be
        # taken from its first element; testing the tuple made every
        # construction fail.
        config = args[0] if args else None
        if config is None:
            self._audio_config = AudioBinary_Config()
        elif isinstance(config, dict):
            self._audio_config = AudioBinary_Config.AudioBinary_Config_from_dict(
                config
            )
        elif isinstance(config, AudioBinary_Config):
            self._audio_config = config
        else:
            raise ValueError("参数类型错误")

    def add_slice_listener(self, slice_listener: callable) -> None:
        """
        Register a slice listener.

        Args:
            slice_listener: callable appended to the listener list.
        """
        self._slice_listener.append(slice_listener)

    def __add__(self, other: bytes):
        """
        Overload of "+": append a chunk, equivalent to add_binary_data().

        Usage:
            audio_binary = audio_binary + chunk

        Note that this mutates and returns the same object (no copy is
        made), which is what allows chained calls.

        Args:
            other: audio chunk to append.
        """
        self._binary_data_list.append(other)
        return self

    def __iadd__(self, other: bytes):
        """
        Overload of "+=": append a chunk, equivalent to add_binary_data().

        Usage:
            audio_binary += chunk

        Args:
            other: audio chunk to append.
        """
        self._binary_data_list.append(other)
        return self

    def add_binary_data(self, binary_data: bytes):
        """
        Append an audio chunk.

        Args:
            binary_data: audio chunk to append.
        """
        self._binary_data_list.append(binary_data)

    def rewrite_binary_data(self, target_index: int, binary_data: bytes):
        """
        Overwrite a stored audio chunk.

        Delegates to the rewrite() method of the underlying data list.

        Args:
            target_index: index of the chunk to overwrite.
            binary_data: replacement chunk.
        """
        self._binary_data_list.rewrite(target_index, binary_data)

    def get_binary_data(
        self,
        start: int = 0,
        end: Optional[int] = None,
    ) -> Optional[bytes]:
        """
        Return the stored chunks in the index range [start, end).

        Args:
            start: first chunk index.
            end: index one past the last chunk; defaults to start + 1 and is
                clamped to the number of stored chunks.

        Returns:
            None when start is past the end of the stored data, otherwise
            the result of slicing the underlying data list with start:end.
            NOTE(review): the annotation says bytes, but the value is
            whatever slicing the data list yields -- confirm with callers.
        """
        stored = len(self._binary_data_list)
        if start >= stored:
            return None
        if end is None:
            end = start + 1
        end = min(end, stored)
        return self._binary_data_list[start:end]
class AudioChunk:
    """
    Singleton registry of named AudioBinary stores.

    It manages two things: AudioBinary (internal byte storage) and slices
    of it (the external interface). The class only mediates between
    AudioBinary and the functors and carries no other logic.
    """

    # Quoted: the class name does not exist yet while the class body runs,
    # so the unquoted annotation raised NameError at import time.
    _instance: Optional["AudioChunk"] = None

    def __new__(cls, *args, **kwargs):
        """
        Return the single shared instance, creating it on first use.
        """
        if cls._instance is None:
            # object.__new__ rejects extra arguments, so none are forwarded.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """
        Initialise the registry once.

        Python calls __init__ on every AudioChunk() expression, including
        the ones that return the already existing singleton; the guard keeps
        those calls from discarding the stored AudioBinary objects.
        """
        if getattr(self, "_initialized", False):
            return
        self._audio_binary_list: Dict[str, "AudioBinary"] = {}
        self._slice_listener: List[callable] = []
        self._initialized = True

    def get_audio_binary(
        self,
        binary_name: Optional[str] = None,
        audio_config: Optional["AudioBinary_Config"] = None,
    ) -> "AudioBinary":
        """
        Return the AudioBinary registered under a name, creating it if needed.

        Args:
            binary_name: registry key; None selects "default".
            audio_config: passed to AudioBinary() only when a new store is
                created; ignored when the name already exists.

        Returns:
            The AudioBinary stored under binary_name.
        """
        if binary_name is None:
            binary_name = "default"
        if binary_name not in self._audio_binary_list:
            self._audio_binary_list[binary_name] = AudioBinary(audio_config)
        return self._audio_binary_list[binary_name]

    @staticmethod
    def _time2size(time_ms: int, audio_config: "AudioBinary_Config") -> int:
        """
        Convert a duration in milliseconds to a data size in bytes.

        size = time_ms * sample_rate * sample_width * channel / 1000

        Args:
            time_ms: duration in milliseconds.
            audio_config: provides sample_rate, sample_width and channel.
                NOTE(review): sample_width is treated as bytes per sample
                here -- confirm against AudioBinary_Config.

        Returns:
            Data size in bytes, rounded down.
        """
        bytes_per_sample = audio_config.sample_width * audio_config.channel
        # Multiply first and divide last so that no float rounding error can
        # push the truncated result one byte too low.
        return int(time_ms * audio_config.sample_rate * bytes_per_sample // 1000)

    @staticmethod
    def _size2time(size: int, audio_config: "AudioBinary_Config") -> int:
        """
        Convert a data size in bytes to a duration in milliseconds.

        time_ms = size * 1000 / (sample_rate * sample_width * channel)

        Args:
            size: data size in bytes.
            audio_config: provides sample_rate, sample_width and channel.

        Returns:
            Duration in milliseconds, rounded down.
        """
        bytes_per_sample = audio_config.sample_width * audio_config.channel
        time_ms = size * 1000 // (audio_config.sample_rate * bytes_per_sample)
        return time_ms

View File

@ -1,196 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WebSocket客户端示例 - 用于测试语音识别服务
"""
import asyncio
import json
import websockets
import argparse
import numpy as np
import wave
import os
def parse_args():
    """Parse the command-line options of the WebSocket test client.

    Returns:
        argparse.Namespace with host, port, audio_file, mode, chunk_size
        and use_ssl.
    """
    cli = argparse.ArgumentParser(description="FunASR WebSocket客户端")

    # One (flag, add_argument keyword arguments) entry per option, in the
    # order they appear in the help output.
    option_specs = (
        ("--host", {"type": str, "default": "localhost", "help": "服务器主机地址"}),
        ("--port", {"type": int, "default": 10095, "help": "服务器端口"}),
        (
            "--audio_file",
            {"type": str, "required": True, "help": "要识别的音频文件路径"},
        ),
        (
            "--mode",
            {
                "type": str,
                "default": "2pass",
                "choices": ["2pass", "online", "offline"],
                "help": "识别模式: 2pass(默认), online, offline",
            },
        ),
        (
            "--chunk_size",
            {
                "type": str,
                "default": "5,10",
                "help": "分块大小, 格式为'encoder_size,decoder_size'",
            },
        ),
        ("--use_ssl", {"action": "store_true", "help": "是否使用SSL连接"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
async def send_audio(websocket, audio_file, mode, chunk_size):
    """
    Stream a WAV file to the server in real-time sized pieces.

    A JSON configuration message is sent first, then the raw frames in
    pieces of 100 ms with a 100 ms pause after each one, and finally a JSON
    message with is_speaking set to False. Errors raised while sending are
    printed, not propagated.

    Args:
        websocket: connection object offering an awaitable send().
        audio_file: path of the WAV file to send.
        mode: recognition mode placed in the configuration message.
        chunk_size: chunk-size setting placed in the configuration message.
    """
    # Read the whole WAV file up front.
    with wave.open(audio_file, "rb") as wav_file:
        params = wav_file.getparams()
        frames = wav_file.readframes(wav_file.getnframes())

    # Audio file information.
    print(f"音频文件: {os.path.basename(audio_file)}")
    print(f"采样率: {params.framerate}Hz, 通道数: {params.nchannels}")
    print(f"采样位深: {params.sampwidth * 8}位, 总帧数: {params.nframes}")

    # Session configuration, sent before any audio.
    config = {
        "mode": mode,
        "chunk_size": chunk_size,
        "wav_name": os.path.basename(audio_file),
        "is_speaking": True,
    }
    await websocket.send(json.dumps(config))

    # 100 ms of audio, derived from the file's own format so that the pacing
    # below stays real-time for any sample rate / width / channel count
    # (3200 bytes for 16 kHz 16-bit mono, the previously hard-coded value).
    bytes_per_frame = params.sampwidth * params.nchannels
    chunk_size_bytes = max(1, params.framerate // 10) * max(1, bytes_per_frame)
    # Ceiling division: a trailing partial chunk is sent too and must count.
    total_chunks = (len(frames) + chunk_size_bytes - 1) // chunk_size_bytes

    print(f"开始发送音频数据,共 {total_chunks} 个数据块...")

    try:
        for i in range(0, len(frames), chunk_size_bytes):
            chunk = frames[i : i + chunk_size_bytes]
            await websocket.send(chunk)

            # Simulate real time: one 100 ms chunk every 100 ms.
            await asyncio.sleep(0.1)

            # Report progress every tenth chunk.
            if (i // chunk_size_bytes) % 10 == 0:
                print(f"已发送 {i // chunk_size_bytes}/{total_chunks} 数据块")

        # Tell the server that the speaker has finished.
        await websocket.send(json.dumps({"is_speaking": False}))
        print("音频数据发送完成")

    except Exception as e:
        print(f"发送音频时出错: {e}")
async def receive_results(websocket):
    """
    Print recognition results as they arrive from the server.

    Each message is parsed as JSON. Messages whose mode contains "online"
    overwrite the current console line; messages whose mode contains
    "offline" are printed on a new line. The coroutine returns once a
    message flagged is_final arrives while an offline text is available.
    Errors are printed, not propagated.

    Args:
        websocket: connection object that can be iterated asynchronously.
    """
    latest_online = ""
    latest_offline = ""

    try:
        async for raw in websocket:
            payload = json.loads(raw)
            reported_mode = payload.get("mode", "")
            recognised = payload.get("text", "")
            finished = payload.get("is_final", False)

            # Keep the newest text of whichever pass produced this message.
            if "online" in reported_mode:
                latest_online = recognised
                print(f"\r[在线识别] {latest_online}", end="", flush=True)
            elif "offline" in reported_mode:
                latest_offline = recognised
                print(f"\n[离线识别] {latest_offline}")

            # Keep listening until a final message comes with offline text.
            if not (finished and latest_offline):
                continue
            print("\n最终识别结果:")
            print(f"[离线识别] {latest_offline}")
            return

    except Exception as e:
        print(f"接收结果时出错: {e}")
async def main():
    """Connect to the server, then send audio and receive results concurrently."""
    options = parse_args()

    # Build the WebSocket URI; --use_ssl switches to the secure scheme.
    scheme = "wss" if options.use_ssl else "ws"
    uri = f"{scheme}://{options.host}:{options.port}"

    print(f"连接到服务器: {uri}")

    try:
        async with websockets.connect(uri, subprotocols=["binary"]) as websocket:
            print("连接成功")

            # Sender and receiver run side by side on the same connection.
            sender = asyncio.create_task(
                send_audio(
                    websocket, options.audio_file, options.mode, options.chunk_size
                )
            )
            receiver = asyncio.create_task(receive_results(websocket))

            # Wait until both directions are done.
            await asyncio.gather(sender, receiver)

    except Exception as e:
        print(f"连接服务器失败: {e}")


if __name__ == "__main__":
    # Script entry point.
    asyncio.run(main())

View File

@ -10,116 +10,79 @@ import argparse
def parse_args():
"""
解析命令行参数
返回:
argparse.Namespace: 解析后的参数对象
"""
parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")
# 服务器配置
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="服务器主机地址例如localhost, 0.0.0.0"
"--host",
type=str,
default="0.0.0.0",
help="服务器主机地址例如localhost, 0.0.0.0",
)
parser.add_argument(
"--port",
type=int,
default=10095,
help="WebSocket服务器端口"
)
parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")
# SSL配置
parser.add_argument(
"--certfile",
type=str,
default="",
help="SSL证书文件路径"
)
parser.add_argument(
"--keyfile",
type=str,
default="",
help="SSL密钥文件路径"
)
parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")
# ASR模型配置
parser.add_argument(
"--asr_model",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="离线ASR模型从ModelScope获取"
help="离线ASR模型从ModelScope获取",
)
parser.add_argument(
"--asr_model_revision",
type=str,
default="v2.0.4",
help="离线ASR模型版本"
"--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
)
# 在线ASR模型配置
parser.add_argument(
"--asr_model_online",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="在线ASR模型从ModelScope获取"
help="在线ASR模型从ModelScope获取",
)
parser.add_argument(
"--asr_model_online_revision",
type=str,
default="v2.0.4",
help="在线ASR模型版本"
"--asr_model_online_revision",
type=str,
default="v2.0.4",
help="在线ASR模型版本",
)
# VAD模型配置
parser.add_argument(
"--vad_model",
type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="VAD语音活动检测模型从ModelScope获取"
help="VAD语音活动检测模型从ModelScope获取",
)
parser.add_argument(
"--vad_model_revision",
type=str,
default="v2.0.4",
help="VAD模型版本"
"--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
)
# 标点符号模型配置
parser.add_argument(
"--punc_model",
type=str,
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="标点符号模型从ModelScope获取"
help="标点符号模型从ModelScope获取",
)
parser.add_argument(
"--punc_model_revision",
type=str,
default="v2.0.4",
help="标点符号模型版本"
"--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
)
# 硬件配置
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量0表示仅使用CPU")
parser.add_argument(
"--ngpu",
type=int,
default=1,
help="GPU数量0表示仅使用CPU"
"--device", type=str, default="cuda", help="设备类型cuda或cpu"
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="设备类型cuda或cpu"
)
parser.add_argument(
"--ncpu",
type=int,
default=4,
help="CPU核心数"
)
parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")
return parser.parse_args()
@ -127,4 +90,4 @@ if __name__ == "__main__":
args = parse_args()
print("配置参数:")
for arg in vars(args):
print(f" {arg}: {getattr(args, arg)}")
print(f" {arg}: {getattr(args, arg)}")

4
src/functor/__init__.py Normal file
View File

@ -0,0 +1,4 @@
from .vad_functor import VADFunctor
from .base import FunctorFactory
__all__ = ["VADFunctor", "FunctorFactory"]

161
src/functor/asr_functor.py Normal file
View File

@ -0,0 +1,161 @@
"""
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
from queue import Queue, Empty
import threading
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ASRFunctor(BaseFunctor):
    """
    Functor that runs ASR on queued audio segments and reports the text.

    _model, _input_queue and _audio_config must be configured before run()
    can start the worker thread. stop() ends the thread once the input
    queue has been drained.
    """

    def __init__(self) -> None:
        super().__init__()
        # Resources and configuration.
        self._model: dict = {}  # models, looked up under the key "asr"
        self._callback: List[Callable] = []  # result callbacks
        self._input_queue: Queue = None  # queue the worker listens on
        self._audio_config: AudioBinary_Config = None  # audio configuration
        # Flags.
        self._is_running: bool = False
        self._stop_event: bool = False
        # Worker thread.
        self._thread: threading.Thread = None
        # Lock guarding the two flags above.
        self._status_lock: threading.Lock = threading.Lock()
        # Cache.
        self._hotwords: List[str] = []

    def reset_cache(self) -> None:
        """
        Reset cached data between tasks. Currently there is nothing to reset.
        """
        pass

    def set_input_queue(self, queue: Queue) -> None:
        """
        Set the input queue the worker thread listens on.
        """
        self._input_queue = queue

    def set_model(self, model: dict) -> None:
        """
        Set the inference models; the ASR model is read from the key "asr".
        """
        self._model = model

    def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
        """
        Set the audio configuration.
        """
        self._audio_config = audio_config
        logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)

    def add_callback(self, callback: Callable) -> None:
        """
        Append a callback to the callback list.
        """
        if not isinstance(self._callback, list):
            self._callback = []
        self._callback.append(callback)

    def _do_callback(self, result: List[str]) -> None:
        """
        Pass the recognised text to every callback.

        The text is taken from result[0]["text"] with all spaces removed.
        An empty result is skipped instead of raising IndexError, which
        would otherwise end the worker thread.
        """
        if not result:
            logger.warning("ASRFunctor推理结果为空, 跳过回调")
            return
        text = result[0]["text"].replace(" ", "")
        for callback in self._callback:
            callback(text)

    def _process(self, data: VAD_Functor_result) -> None:
        """
        Run the ASR model on one segment and dispatch the result.
        """
        binary_data = data.audiobinary_data.binary_data
        result = self._model["asr"].generate(
            input=binary_data,
            chunk_size=self._audio_config.chunk_size,
            hotwords=self._hotwords,
        )
        self._do_callback(result)

    def _run(self) -> None:
        """
        Worker loop: take segments from the input queue and process them.

        When the queue stays empty for one second the stop flag is checked
        and the loop ends if stop() has been requested. Any other error is
        logged and re-raised, which ends the thread.
        """
        # _stop_event is deliberately not reset here: run() does that before
        # the thread starts, so a stop() issued right after run() is not lost.
        with self._status_lock:
            self._is_running = True
        try:
            while self._is_running:
                try:
                    data = self._input_queue.get(True, timeout=1)
                except Empty:
                    # Queue idle for 1 s: honour a pending stop request.
                    if self._stop_event:
                        break
                    continue
                try:
                    self._process(data)
                except Exception as e:
                    logger.error("ASRFunctor运行时发生错误: %s", e)
                    raise
                finally:
                    # Always balance the get(), otherwise Queue.join() in a
                    # producer would block forever after a failure.
                    self._input_queue.task_done()
        finally:
            with self._status_lock:
                self._is_running = False

    def run(self) -> threading.Thread:
        """
        Start the worker thread.

        Returns:
            threading.Thread: the started daemon thread.

        Raises:
            ValueError: if _pre_check() finds a missing resource.
        """
        self._pre_check()
        with self._status_lock:
            self._is_running = True
            self._stop_event = False
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self._thread

    def _pre_check(self) -> bool:
        """
        Verify that every required resource has been configured.

        Returns:
            True when all checks pass.

        Raises:
            ValueError: if the ASR model, audio configuration, input queue
                or callback list is missing.
        """
        # _model starts as {} and is never None, so test for the key that
        # _process() actually reads.
        if not self._model or "asr" not in self._model:
            raise ValueError("模型未设置")
        if self._audio_config is None:
            raise ValueError("音频配置未设置")
        if self._input_queue is None:
            raise ValueError("输入队列未设置")
        if self._callback is None:
            raise ValueError("回调函数未设置")
        if not self._callback:
            # Not fatal: callbacks may still be added while running.
            logger.warning("ASRFunctor未注册回调函数, 识别结果将被丢弃")
        return True

    def stop(self) -> bool:
        """
        Request the worker to stop and wait for it to finish.

        The worker exits the next time its input queue stays empty for one
        second, so queued segments are still processed.

        Returns:
            True when no worker thread is alive afterwards.
        """
        with self._status_lock:
            self._stop_event = True
            thread = self._thread
        if thread is None:
            # run() was never called; nothing to join.
            return True
        thread.join()
        with self._status_lock:
            self._is_running = False
        return not thread.is_alive()

191
src/functor/base.py Normal file
View File

@ -0,0 +1,191 @@
"""
Functor基础模块
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类
基类提供了数据处理的基本框架,包括:
- 回调函数管理
- 模型配置管理
- 线程运行控制
主要类:
BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类
"""
from abc import ABC, abstractmethod
from typing import Callable, List
from queue import Queue
import threading
class BaseFunctor(ABC):
    """
    Abstract base class of all functors.

    Every subclass must implement run() to start its own worker thread,
    together with _run(), _pre_check() and stop().

    Attributes:
        _callback (List[Callable]): callbacks invoked with processing results
        _model (dict): model configuration and instances
        _input_queue (Queue): queue the worker listens on, None until set
        _is_running (bool): worker running flag
        _stop_event (bool): stop request flag
        _status_lock (threading.Lock): lock guarding the flags
        _thread (threading.Thread): worker thread, None until started
    """

    def __init__(self):
        """
        Initialise the shared state with empty defaults.
        """
        self._callback: List[Callable] = []
        self._model: dict = {}
        # Defined here so that subclasses can test it in _pre_check() before
        # set_input_queue() has been called, instead of hitting AttributeError.
        self._input_queue: Queue = None
        # Flags.
        self._is_running: bool = False
        self._stop_event: bool = False
        # Lock guarding the flags.
        self._status_lock: threading.Lock = threading.Lock()
        # Worker thread.
        self._thread: threading.Thread = None

    def add_callback(self, callback: Callable):
        """
        Add a callback.

        Args:
            callback (Callable): callback to append to the list.
        """
        self._callback.append(callback)

    def set_model(self, model: dict):
        """
        Set the model configuration.

        Args:
            model (dict): new model configuration.
        """
        self._model = model

    def set_input_queue(self, queue: Queue):
        """
        Set the input queue.

        Args:
            queue (Queue): new input queue.
        """
        self._input_queue = queue

    @abstractmethod
    def _run(self):
        """
        Worker thread logic.

        Implementations trigger the callbacks when their condition is met.
        """

    @abstractmethod
    def run(self):
        """
        Start a thread running _run().

        Returns:
            The thread instance.
        """

    @abstractmethod
    def _pre_check(self):
        """
        Check preconditions before starting.

        Returns:
            The result of the check.
        """

    @abstractmethod
    def stop(self):
        """
        Stop the worker thread.

        Returns:
            The result of the stop request.
        """
class FunctorFactory:
    """
    Factory that creates and configures functor instances.

    Main method:
        make_functor(functor_name: str, config: dict, models: dict) -> BaseFunctor
    """

    @classmethod
    def make_functor(cls, functor_name: str, config: dict, models: dict) -> BaseFunctor:
        """
        Create and configure a functor.

        Args:
            functor_name (str): one of "vad", "asr" or "spk".
            config (dict): configuration; must contain "audio_config".
            models (dict): loaded models; must contain the key named by
                functor_name.

        Returns:
            BaseFunctor: the configured functor.

        Raises:
            ValueError: if functor_name is not supported.
            KeyError: if config or models lacks a required key.
        """
        builders = {
            "vad": cls._make_vadfunctor,
            "asr": cls._make_asrfunctor,
            "spk": cls._make_spkfunctor,
        }
        builder = builders.get(functor_name)
        if builder is None:
            raise ValueError(f"不支持的Functor类型: {functor_name}")
        return builder(config=config, models=models)

    @staticmethod
    def _configure(functor_cls, model_key: str, config: dict, models: dict):
        """
        Instantiate functor_cls and give it the audio config and its model.
        """
        # Look both keys up first so a missing one fails before any
        # functor object is created.
        audio_config = config["audio_config"]
        model = {model_key: models[model_key]}

        functor = functor_cls()
        functor.set_audio_config(audio_config)
        functor.set_model(model)
        return functor

    # The helpers below take no self/cls, so they are declared static; the
    # functor modules are imported lazily inside them because those modules
    # import this one.
    @staticmethod
    def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
        """
        Create a VAD functor.
        """
        from src.functor.vad_functor import VADFunctor

        return FunctorFactory._configure(VADFunctor, "vad", config, models)

    @staticmethod
    def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
        """
        Create an ASR functor.
        """
        from src.functor.asr_functor import ASRFunctor

        return FunctorFactory._configure(ASRFunctor, "asr", config, models)

    @staticmethod
    def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
        """
        Create an SPK functor.
        """
        from src.functor.spk_functor import SPKFunctor

        return FunctorFactory._configure(SPKFunctor, "spk", config, models)

122
src/functor/readme.md Normal file
View File

@ -0,0 +1,122 @@
# 对于Functor的解释
## Functor 文件夹作用
Functor文件夹用于存放所有功能性的类包括VAD、PUNC、ASR、SPK等。
## Functor 类的定义
所有类应继承于**基类**`BaseFunctor`
为了方便使用,我们对于**基类**的定义如下:
1. 函数内部使用的变量以单下划线开头,基类中包含:
* _model: Dict 存放模型相关的配置和实例
* _input_queue: Queue 监听的输入消息队列
* _thread: threading.Thread 运行的线程实例
* _callback: List[Callable] 回调函数列表
* _is_running: bool 线程运行状态标志
* _stop_event: bool 停止事件标志
* _status_lock: threading.Lock 状态锁,用于线程同步
2. 对于使用的模型,请从统一的 **模型管理类`ModelLoader`** 中获取,由模型管理类统一进行加载、缓存和释放,`_model`存放类型为`dict`
3. 基类定义的核心方法:
* `add_callback(callback: Callable)`: 添加结果处理的回调函数
* `set_model(model: dict)`: 设置模型配置和实例
* `set_input_queue(queue: Queue)`: 设置输入数据队列
* `run()`: 启动处理线程(抽象方法)
* `stop()`: 停止处理线程(抽象方法)
* `_run()`: 线程运行的具体逻辑(抽象方法)
* `_pre_check()`: 运行前的预检查(抽象方法)
## 派生类实现要求
1. 必须实现的抽象方法:
* `_pre_check()`:
- 检查必要的配置是否完整(如模型、队列等)
- 检查运行环境是否满足要求
- 返回检查结果
* `_run()`:
- 实现具体的数据处理逻辑
- 从 _input_queue 获取输入数据
- 使用 _model 进行处理
- 通过 _callback 返回处理结果
* `run()`:
- 调用 _pre_check() 进行预检查
- 创建并启动处理线程
- 设置相关状态标志
* `stop()`:
- 安全停止处理线程
- 清理资源
- 重置状态标志
2. 建议实现的方法:
* `__str__`: 返回当前实例的状态信息
* 错误处理方法:处理运行过程中的异常情况
## 使用示例
```python
import threading
from queue import Empty

from src.functor.base import BaseFunctor
from src.utils.logger import get_module_logger

logger = get_module_logger(__name__)


class MyFunctor(BaseFunctor):
    """Minimal functor: process queued items with a model and report results."""

    def _pre_check(self):
        # getattr keeps the check safe even if no queue was ever set.
        if not self._model or getattr(self, "_input_queue", None) is None:
            return False
        return True

    def _run(self):
        while not self._stop_event:
            try:
                data = self._input_queue.get(timeout=1.0)
                result = self._model['my_model'].process(data)
                for callback in self._callback:
                    callback(result)
            # The exception lives in the queue module, not on the Queue class.
            except Empty:
                continue
            except Exception as e:
                logger.error(f"处理错误: {e}")

    def run(self):
        if not self._pre_check():
            raise RuntimeError("预检查失败")

        with self._status_lock:
            if self._is_running:
                return
            self._is_running = True
            self._stop_event = False

        self._thread = threading.Thread(target=self._run)
        self._thread.start()

    def stop(self):
        with self._status_lock:
            if not self._is_running:
                return
            self._stop_event = True

        # Join outside the lock so the worker is never blocked on it.
        if self._thread:
            self._thread.join()
        with self._status_lock:
            self._is_running = False
```
## 注意事项
1. 线程安全:
* 使用 _status_lock 保护状态变更
* 注意共享资源的访问控制
2. 错误处理:
* 在 _run() 中妥善处理异常
* 提供详细的错误日志
3. 资源管理:
* 确保在 stop() 中正确清理资源
* 避免资源泄露
4. 回调函数:
* 回调函数应该是非阻塞的
* 处理回调函数抛出的异常

145
src/functor/spk_functor.py Normal file
View File

@ -0,0 +1,145 @@
"""
SpkFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
from queue import Queue, Empty
import threading
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class SPKFunctor(BaseFunctor):
    """
    Functor for speaker recognition on queued audio segments.

    The recognition itself is still a placeholder: _process() reports a
    fixed result for every segment. _input_queue must be configured before
    run() can start the worker thread. stop() ends the thread once the
    input queue has been drained.
    """

    def __init__(self) -> None:
        super().__init__()
        # Resources and configuration.
        self._model: dict = {}  # models
        self._callback: List[Callable] = []  # result callbacks
        self._input_queue: Queue = None  # queue the worker listens on
        self._audio_config: AudioBinary_Config = None  # audio configuration

    def reset_cache(self) -> None:
        """
        Reset cached data between tasks. Currently there is nothing to reset.
        """
        pass

    def set_input_queue(self, queue: Queue) -> None:
        """
        Set the input queue the worker thread listens on.
        """
        self._input_queue = queue

    def set_model(self, model: dict) -> None:
        """
        Set the inference models.
        """
        self._model = model

    def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
        """
        Set the audio configuration.
        """
        self._audio_config = audio_config
        logger.debug("SpkFunctor设置音频配置: %s", self._audio_config)

    def add_callback(self, callback: Callable) -> None:
        """
        Append a callback to the callback list.
        """
        if not isinstance(self._callback, list):
            self._callback = []
        self._callback.append(callback)

    def _do_callback(self, result: List[str]) -> None:
        """
        Pass the result unchanged to every callback.
        """
        for callback in self._callback:
            callback(result)

    def _process(self, data: VAD_Functor_result) -> None:
        """
        Handle one segment.

        The audio is read from the segment but not used yet; a fixed
        placeholder result is reported until the model call is enabled.
        """
        binary_data = data.audiobinary_data.binary_data
        # TODO: replace the placeholder below with the real model call:
        # result = self._model["spk"].generate(
        #     input=binary_data,
        #     chunk_size=self._audio_config.chunk_size,
        # )
        result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
        self._do_callback(result)

    def _run(self) -> None:
        """
        Worker loop: take segments from the input queue and process them.

        When the queue stays empty for one second the stop flag is checked
        and the loop ends if stop() has been requested. Any other error is
        logged and re-raised, which ends the thread.
        """
        # _stop_event is deliberately not reset here: run() does that before
        # the thread starts, so a stop() issued right after run() is not lost.
        with self._status_lock:
            self._is_running = True
        try:
            while self._is_running:
                try:
                    data = self._input_queue.get(True, timeout=1)
                except Empty:
                    # Queue idle for 1 s: honour a pending stop request.
                    if self._stop_event:
                        break
                    continue
                try:
                    self._process(data)
                except Exception as e:
                    logger.error("SpkFunctor运行时发生错误: %s", e)
                    raise
                finally:
                    # Always balance the get(), otherwise Queue.join() in a
                    # producer would block forever after a failure.
                    self._input_queue.task_done()
        finally:
            with self._status_lock:
                self._is_running = False

    def run(self) -> threading.Thread:
        """
        Start the worker thread.

        Returns:
            threading.Thread: the started daemon thread.

        Raises:
            ValueError: if _pre_check() finds a missing resource.
        """
        self._pre_check()
        with self._status_lock:
            self._is_running = True
            self._stop_event = False
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self._thread

    def _pre_check(self) -> bool:
        """
        Verify that the required resources have been configured.

        The model dict is only checked for None because the placeholder
        _process() does not read a model yet.

        Returns:
            True when all checks pass.

        Raises:
            ValueError: if the model dict, input queue or callback list is
                missing.
        """
        if self._model is None:
            raise ValueError("模型未设置")
        if self._input_queue is None:
            raise ValueError("输入队列未设置")
        if self._callback is None:
            raise ValueError("回调函数未设置")
        return True

    def stop(self) -> bool:
        """
        Request the worker to stop and wait for it to finish.

        The worker exits the next time its input queue stays empty for one
        second, so queued segments are still processed.

        Returns:
            True when no worker thread is alive afterwards.
        """
        with self._status_lock:
            self._stop_event = True
            thread = self._thread
        if thread is None:
            # run() was never called; nothing to join.
            return True
        thread.join()
        with self._status_lock:
            self._is_running = False
        return not thread.is_alive()

315
src/functor/vad_functor.py Normal file
View File

@ -0,0 +1,315 @@
"""
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
"""
import threading
from queue import Empty, Queue
from typing import List, Any, Callable
import numpy
from src.models import (
VAD_Functor_result,
AudioBinary_Config,
AudioBinary_data_list,
)
from src.functor.base import BaseFunctor
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class VADFunctor(BaseFunctor):
"""
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
需要配置好 _model, _callback, _input_queue, _audio_config, _audio_binary_data_list
否则无法run()启动线程
运行中, 使用reset_cache()重置缓存, 准备下次任务
使用stop()停止线程, 但需要等待input_queue为空
"""
def __init__(self) -> None:
super().__init__()
# 资源与配置
self._model: dict = {} # 模型
self._callback: List[Callable] = [] # 回调函数
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
self._audio_binary_data_list: AudioBinary_data_list = None # 音频数据列表
# flag
# 此处用到两个锁但都是为了截断_run线程考虑后续优化
self._is_running: bool = False
self._stop_event: bool = False
# 线程资源
self._thread: threading.Thread = None
# 状态锁
self._status_lock: threading.Lock = threading.Lock()
# 缓存
self._audio_cache: numpy.ndarray = None
self._audio_cache_preindex: int = 0
self._model_cache: dict = {}
self._cache_result_list = []
self._audiobinary_cache = None
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
self._audio_cache = None
self._audio_cache_preindex = 0
self._model_cache = {}
self._cache_result_list = []
self._audiobinary_cache = None
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config
logger.debug("VADFunctor设置音频配置: %s", self._audio_config)
def set_audio_binary_data_list(
self, audio_binary_data_list: AudioBinary_data_list
) -> None:
"""
设置音频数据列表, 为Class AudioBinary_data_list类型
AudioBinary_data_list包含binary_data_list, 为list[_AudioBinary_data]类型
_AudioBinary_data包含binary_data, 为bytes/numpy.ndarray类型
"""
self._audio_binary_data_list = audio_binary_data_list
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
"""
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[List[int]]) -> None:
"""
回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
输入:
result: List[[start,end]] 处理所得VAD端点
其中若start==-1, 则表示前无端点, 若end==-1, 则表示后无端点
当处理得到一个完成片段时, 存入AudioBinary中, 并向队列中添加AudioBinary_Slice
输出:
None
"""
# 持久化缓存结果队列
for pair in result:
[start, end] = pair
# 若无前端点, 则向缓存队列中合并
if start == -1:
self._cache_result_list[-1][1] = end
else:
self._cache_result_list.append(pair)
while len(self._cache_result_list) > 1:
# 创建VAD片段
# 计算开始帧
start_frame = self._audio_config.ms2frame(self._cache_result_list[0][0])
start_frame -= self._audio_cache_preindex
# 计算结束帧
end_frame = self._audio_config.ms2frame(self._cache_result_list[0][1])
end_frame -= self._audio_cache_preindex
# 计算开始时间
vad_result = VAD_Functor_result.create_from_push_data(
audiobinary_data_list=self._audio_binary_data_list,
data=self._audiobinary_cache[start_frame:end_frame],
start_time=self._cache_result_list[0][0],
end_time=self._cache_result_list[0][1],
)
self._audio_cache_preindex += end_frame
self._audiobinary_cache = self._audiobinary_cache[end_frame:]
for callback in self._callback:
callback(vad_result)
self._cache_result_list.pop(0)
def _predeal_data(self, data: Any) -> None:
"""
预处理数据, 将数据缓存到_audio_cache和_audiobinary_cache中
"""
if self._audio_cache is None:
self._audio_cache = data
else:
# 拼接音频数据
if isinstance(self._audio_cache, numpy.ndarray):
self._audio_cache = numpy.concatenate((self._audio_cache, data))
elif isinstance(self._audio_cache, list):
self._audio_cache.append(data)
if self._audiobinary_cache is None:
self._audiobinary_cache = data
else:
# 拼接音频数据
if isinstance(self._audiobinary_cache, numpy.ndarray):
self._audiobinary_cache = numpy.concatenate(
(self._audiobinary_cache, data)
)
elif isinstance(self._audiobinary_cache, list):
self._audiobinary_cache.append(data)
def _process(self, data: Any):
"""
处理数据
使用model进行生成, 并使用_do_callback进行回调
"""
self._predeal_data(data)
if len(self._audio_cache) >= self._audio_config.chunk_stride:
result = self._model["vad"].generate(
input=self._audio_cache,
cache=self._model_cache,
chunk_size=self._audio_config.chunk_size,
is_final=False,
)
if len(result[0]["value"]) > 0:
self._do_callback(result[0]["value"])
# logger.debug(f"VADFunctor结果: {result[0]['value']}")
self._audio_cache = None
def _run(self):
"""
线程运行逻辑
监听输入队列, 当有数据时, 处理数据
当输入队列为空时, 间隔1s检测是否进入停止事件
"""
# 刷新运行状态
with self._status_lock:
self._is_running = True
self._stop_event = False
# 运行逻辑
while self._is_running:
try:
data = self._input_queue.get(True, timeout=1)
self._process(data)
self._input_queue.task_done()
# 当队列为空时, 间隔1s检测是否进入停止事件。
except Empty:
if self._stop_event:
break
continue
# 其他异常
except Exception as e:
logger.error("VADFunctor运行时发生错误: %s", e)
raise e
def run(self):
"""
启动 _run 线程, 并返回线程对象
"""
self._pre_check()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
检测硬性资源是否都已设置
"""
if self._model is None:
raise ValueError("模型未设置")
if self._audio_config is None:
raise ValueError("音频配置未设置")
if self._audio_binary_data_list is None:
raise ValueError("音频数据列表未设置")
if self._input_queue is None:
raise ValueError("输入队列未设置")
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self):
    """
    Request the worker thread to stop and wait for it to finish.

    Setting ``_stop_event`` makes the ``_run`` loop exit the next time its
    queue read times out on an empty queue.

    Returns:
        True when no worker thread is alive afterwards. This includes the
        case where ``run()`` was never called, which used to crash on
        ``self._thread.join()``.
    """
    with self._status_lock:
        self._stop_event = True
    # The thread only exists after run(); tolerate stop() before that.
    worker = getattr(self, "_thread", None)
    if worker is not None:
        worker.join()
    with self._status_lock:
        self._is_running = False
    return worker is None or not worker.is_alive()
# class VAD:
# def __init__(
# self,
# VAD_model=None,
# audio_config: AudioBinary_Config = None,
# callback: Callable = None,
# ):
# # vad model
# self.VAD_model = VAD_model
# if self.VAD_model is None:
# self.VAD_model = AutoModel(
# model="fsmn-vad", model_revision="v2.0.4", disable_update=True
# )
# # audio config
# self.audio_config = audio_config
# # vad result
# self.vad_result = VADResponse(time_chunk_index_callback=callback)
# # audio binary poll
# self.audio_chunk = AudioChunk(audio_config=self.audio_config)
# self.cache = {}
# def push_binary_data(
# self,
# binary_data: bytes,
# ):
# # 压入二进制数据
# self.audio_chunk.add_chunk(binary_data)
# # 处理音频块
# res = self.VAD_model.generate(
# input=binary_data,
# cache=self.cache,
# chunk_size=self.audio_config.chunk_size,
# is_final=False,
# )
# # print("VAD generate", res)
# if len(res[0]["value"]):
# self.vad_result += VADResponse.from_raw(res)
# def set_callback(
# self,
# callback: Callable,
# ):
# self.vad_result.time_chunk_index_callback = callback
# def process_vad_result(self, callback: Callable = None):
# # 处理VAD结果
# callback = (
# callback
# if callback is not None
# else self.vad_result.time_chunk_index_callback
# )
# self.vad_result.process_time_chunk(
# lambda x: callback(
# AudioBinary_Chunk(
# start_time=x["start_time"],
# end_time=x["end_time"],
# chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"]),
# )
# )
# )

176
src/logic_trager.py Normal file
View File

@ -0,0 +1,176 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
逻辑触发器类 - 用于处理音频数据并触发相应的处理逻辑
"""
from src.utils.logger import get_module_logger
from typing import Any, Dict, Type, Callable
# 配置日志
logger = get_module_logger(__name__, level="INFO")
class AutoAfterMeta(type):
    """
    Metaclass that chains ``__after__<name>`` hooks onto methods and turns
    every class using it into a singleton.

    At class-creation time each plain function in the class body whose name
    does not start with ``__`` is wrapped. After the original function
    returns, the wrapper looks up ``__after__<function name>`` on the
    instance and, if it is callable, invokes it without arguments. Errors
    raised by the hook are logged and do not affect the wrapped result.

    NOTE(review): a hook written in a class body as ``def __after__x(self)``
    is subject to Python private-name mangling (stored as
    ``_<Class>__after__x``), so the unmangled lookup here presumably never
    finds such hooks. Confirm before relying on automatic hook dispatch.
    """

    # Singleton registry: class -> its single instance.
    _instances: Dict[Type, Any] = {}

    def __new__(cls, name, bases, attrs):
        """Create the class, wrapping eligible functions with hook dispatch."""
        import functools
        import types

        def make_wrapper(func):
            after_func_name = f"__after__{func.__name__}"

            # Keep the wrapped method's name and docstring for introspection.
            @functools.wraps(func)
            def wrapper(self, *args, **kwargs):
                result = func(self, *args, **kwargs)
                after_func = getattr(self, after_func_name, None)
                if callable(after_func):
                    try:
                        after_func()
                    except Exception as e:
                        logger.error(f"调用{after_func_name}时出错: {e}")
                return result

            return wrapper

        for attr_name, attr_value in list(attrs.items()):
            if attr_name.startswith("__"):
                continue
            # Only plain functions are wrapped. Nested classes, staticmethod
            # objects and other callables would break if they were replaced
            # by a wrapper that expects ``self``.
            if isinstance(attr_value, types.FunctionType):
                attrs[attr_name] = make_wrapper(attr_value)
        return super().__new__(cls, name, bases, attrs)

    def __call__(cls, *args, **kwargs):
        """
        Return the single instance of ``cls``, creating it on first use.

        Constructor arguments are only used for the first call; later calls
        return the existing instance unchanged.
        """
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
            logger.info(f"创建{cls.__name__}的新实例")
        else:
            logger.debug(f"返回{cls.__name__}的现有实例")
        return cls._instances[cls]
"""
整体识别的处理逻辑
1.压入二进制音频信息
2.不断检测VAD
3.当检测到完整VAD时,将VAD的音频信息压入音频块,并清除对应二进制信息
4.对音频块进行语音转文字offline,时间戳预测,说话人识别
5.将识别结果整合压入结果队列
6.结果队列被压入时调用回调函数
1->2 __after__push_binary_data 外部压入二进制信息
2,3->4 __after__push_audio_chunk 内部压入音频块
4->5 push_result_queue 压入结果队列
5->6 __after__push_result_queue 调用回调函数
"""
from src.functor import VAD
from src.models import AudioBinary_Config
from src.models import AudioBinary_Chunk
from typing import List
class LogicTrager(metaclass=AutoAfterMeta):
    """
    Logic trigger: feeds raw audio to the VAD component and collects the
    audio chunks and results produced downstream.

    Flow implemented here:
      * ``push_binary_data`` hands bytes to the VAD component and then asks
        it to process its pending results.
      * ``push_audio_chunk`` is registered as the VAD callback and stores
        each chunk it receives.
      * ``push_result_queue`` appends a result to the result list.
    """

    def __init__(
        self,
        audio_chunk_max_size: int = 1024 * 1024 * 10,
        audio_config: AudioBinary_Config = None,
        result_callback: Callable = None,
        models: Dict[str, Any] = None,
    ):
        """
        Args:
            audio_chunk_max_size: Stored as ``_audio_chunk_max_size``; not
                otherwise used in this class.
            audio_config: Audio parameters; a default ``AudioBinary_Config``
                is created when omitted.
            result_callback: Stored as ``_aggregate_result_callback``; not
                invoked anywhere in this class.
            models: Mapping of model name to model. Only the ``"vad"`` entry
                is read. May be omitted, in which case ``None`` is passed to
                ``VAD`` as the model.
        """
        # Collected audio chunks.
        self._audio_chunk: List[AudioBinary_Chunk] = []
        # Raw binary buffer.
        self._audio_chunk_binary = b""
        self._audio_chunk_max_size = audio_chunk_max_size
        # Audio parameters.
        self._audio_config = (
            audio_config if audio_config is not None else AudioBinary_Config()
        )
        # Result list.
        self._result_queue = []
        # Callback for aggregated results.
        self._aggregate_result_callback = result_callback
        # The default for ``models`` is None; guard it so the ``.get`` below
        # does not raise AttributeError.
        models = models if models is not None else {}
        self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
        self._vad.set_callback(self.push_audio_chunk)
        logger.info("初始化LogicTrager")

    def push_binary_data(self, chunk: bytes) -> None:
        """
        Push one raw audio chunk into the VAD component.

        Args:
            chunk: Raw audio bytes.
        """
        self._vad.push_binary_data(chunk)
        # Explicit call: see the hook below.
        self.__after__push_binary_data()

    def __after__push_binary_data(self) -> None:
        """Ask the VAD component to process its pending results."""
        self._vad.process_vad_result()

    def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
        """
        Store one audio chunk and log its time range and length.

        Args:
            chunk: Chunk exposing ``start_time``, ``end_time`` and ``chunk``.
        """
        logger.info(
            "LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
                chunk.start_time, chunk.end_time, len(chunk.chunk)
            )
        )
        self._audio_chunk.append(chunk)

    def __after__push_audio_chunk(self) -> None:
        """Hook after an audio chunk is stored; currently does nothing."""
        pass

    def push_result_queue(self, result: Dict[str, Any]) -> None:
        """Append one result to the result list."""
        self._result_queue.append(result)

    def __after__push_result_queue(self) -> None:
        """Hook after a result is stored; currently only logs."""
        logger.info("FINISH Result=")

    def __call__(self):
        """Placeholder; calling the instance does nothing."""
        pass

126
src/model_loader.py Normal file
View File

@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
try:
# 导入FunASR库
from funasr import AutoModel
except ImportError as exc:
raise ImportError("未找到funasr库, 请先安装: pip install funasr") from exc
# 日志模块
from src.utils import get_module_logger
logger = get_module_logger(__name__)
# 单例模式
class ModelLoader:
    """
    Process-wide singleton that loads FunASR models into a dictionary.

    Usage:
      * ``ModelLoader()`` returns the singleton; calling the instance
        returns the dictionary of loaded models.
      * ``ModelLoader(args)`` or ``instance.load_models(args)`` loads the
        models described by ``args``.

    ``args`` may be a dict or an attribute container such as
    ``argparse.Namespace``; both are normalised to a dict internally.
    """

    _instance = None
    # Becomes True on the instance once its state exists, so that later
    # ``ModelLoader()`` calls do not wipe models that are already loaded.
    _initialized = False

    def __new__(cls, *args, **kwargs):
        """Return the single instance, creating it on first use."""
        if cls._instance is None:
            # Constructor arguments must not be forwarded: object.__new__
            # raises TypeError on extra arguments when __new__ is overridden.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, args=None):
        """
        Initialise the singleton state once and optionally load models.

        Args:
            args: Optional model configuration; when given it is passed to
                ``__call__``.
        """
        if not self._initialized:
            self.models = {}
            self._initialized = True
            logger.debug("初始化ModelLoader")
        if args is not None:
            self.__call__(args)

    def __call__(self, args=None):
        """
        Return the loaded models, loading them first when none are loaded.

        Loading only happens when the model dictionary is empty, ``args`` is
        given and it specifies ``asr_model``.

        Args:
            args: Optional model configuration (dict or namespace).

        Returns:
            dict: The dictionary of loaded models (possibly empty).
        """
        if not self.models and args is not None:
            if self._as_dict(args).get("asr_model") is not None:
                self.models = self.load_models(args)
        return self.models

    @staticmethod
    def _as_dict(args) -> dict:
        """Normalise a dict / namespace / None configuration to a dict."""
        if args is None:
            return {}
        if isinstance(args, dict):
            return args
        return vars(args)

    def _load_model(self, input_model_args: dict, model_type: str):
        """
        Load a single model.

        Args:
            input_model_args: Model configuration (dict or namespace).
            model_type: Model kind, used as the prefix of the
                ``<model_type>_model`` and ``<model_type>_model_revision``
                configuration keys.

        Returns:
            The ``AutoModel`` instance.

        Raises:
            ValueError: If no model is configured for ``model_type``.
            Exception: Whatever ``AutoModel`` raises; logged and re-raised.
        """
        input_model_args = self._as_dict(input_model_args)
        default_config = {
            "model": None,
            "model_revision": None,
            "ngpu": 0,
            "ncpu": 1,
            "device": "cpu",
            "disable_pbar": True,
            "disable_log": True,
            "disable_update": True,
        }
        # "model" and "model_revision" are looked up under keys prefixed
        # with the model type; all other options use their own name.
        source_keys = {
            "model": f"{model_type}_model",
            "model_revision": f"{model_type}_model_revision",
        }
        model_args = default_config.copy()
        for key in default_config:
            value = input_model_args.get(source_keys.get(key, key))
            if value is not None:
                logger.debug("替换%s模型参数: %s = %s", model_type, key, value)
                model_args[key] = value
        if not model_args["model"]:
            raise ValueError(f"未指定{model_type}模型路径")
        try:
            logger.debug("正在加载%s模型: %s", model_type, model_args["model"])
            return AutoModel(**model_args)
        except Exception as e:
            logger.error("加载%s模型失败: %s", model_type, str(e))
            raise

    def load_models(self, args):
        """
        Load every configured model into ``self.models``.

        A model kind is loaded when ``<kind>_model`` is present with a
        non-empty value. Kinds that are absent or set to ``None`` are
        skipped.

        Args:
            args: Model configuration (dict or namespace).

        Returns:
            dict: Mapping of model kind to loaded model.
        """
        logger.info("ModelLoader加载模型")
        args = self._as_dict(args)
        self.models = {}
        for model_name in ("asr", "asr_online", "vad", "punc", "spk"):
            if not args.get(f"{model_name}_model"):
                continue
            logger.debug("加载%s模型", model_name)
            self.models[model_name] = self._load_model(args, model_name)
        logger.info("所有模型加载完成")
        return self.models

View File

@ -1,79 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型加载模块 - 负责加载各种语音识别相关模型
"""
def load_models(args):
    """
    Load all models required by the service.

    Args:
        args: Command-line arguments carrying the model configuration
            (``asr_model``, ``asr_model_online``, ``vad_model``,
            ``punc_model`` with their ``*_revision`` counterparts, plus
            ``ngpu``, ``ncpu`` and ``device``).

    Returns:
        dict: Loaded models under the keys ``"asr"``, ``"asr_streaming"``,
        ``"vad"`` and ``"punc"`` (``None`` when no punctuation model is
        configured).

    Raises:
        ImportError: If the ``funasr`` package is not installed.
    """
    try:
        from funasr import AutoModel
    except ImportError:
        raise ImportError("未找到funasr库请先安装: pip install funasr")

    def build(model, revision):
        # All models share the same device/runtime options.
        return AutoModel(
            model=model,
            model_revision=revision,
            ngpu=args.ngpu,
            ncpu=args.ncpu,
            device=args.device,
            disable_pbar=True,
            disable_log=True,
        )

    loaded = {}

    # 1. Offline ASR model.
    print(f"正在加载ASR离线模型: {args.asr_model}")
    loaded["asr"] = build(args.asr_model, args.asr_model_revision)

    # 2. Streaming (online) ASR model.
    print(f"正在加载ASR在线模型: {args.asr_model_online}")
    loaded["asr_streaming"] = build(
        args.asr_model_online, args.asr_model_online_revision
    )

    # 3. VAD model.
    print(f"正在加载VAD模型: {args.vad_model}")
    loaded["vad"] = build(args.vad_model, args.vad_model_revision)

    # 4. Punctuation model, only when one is configured.
    if not args.punc_model:
        loaded["punc"] = None
        print("未指定标点符号模型,将不使用标点符号")
    else:
        print(f"正在加载标点符号模型: {args.punc_model}")
        loaded["punc"] = build(args.punc_model, args.punc_model_revision)

    print("所有模型加载完成")
    return loaded

9
src/models/__init__.py Normal file
View File

@ -0,0 +1,9 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result
__all__ = [
"AudioBinary_Config",
"AudioBinary_data_list",
"_AudioBinary_data",
"VAD_Functor_result",
]

158
src/models/audio.py Normal file
View File

@ -0,0 +1,158 @@
from pydantic import BaseModel, Field, validator
from typing import List, Any
import numpy
from src.utils import get_module_logger
logger = get_module_logger(__name__)
binary_data_types = (bytes, numpy.ndarray)
class AudioBinary_Config(BaseModel):
    """Configuration describing a stream of binary audio."""

    class Config:
        arbitrary_types_allowed = True

    audio_data: binary_data_types = Field(default=None, description="音频数据")
    chunk_size: int = Field(default=100, description="块大小")
    chunk_stride: int = Field(default=1600, description="块步长")
    sample_rate: int = Field(default=16000, description="采样率")
    sample_width: int = Field(default=2, description="采样位宽")
    channels: int = Field(default=1, description="通道数")

    @classmethod
    def AudioBinary_Config_from_dict(cls, data: dict):
        """Build a configuration from a dict of field values."""
        return cls(**data)

    def ms2frame(self, ms: int) -> int:
        """Convert milliseconds to a frame count at ``sample_rate``."""
        frames = ms * self.sample_rate / 1000
        return int(frames)

    def frame2ms(self, frame: int) -> int:
        """Convert a frame count at ``sample_rate`` to milliseconds."""
        milliseconds = frame * 1000 / self.sample_rate
        return int(milliseconds)
class _AudioBinary_data(BaseModel):
    """Wrapper around one piece of binary audio (``bytes`` or ndarray)."""

    binary_data: binary_data_types = Field(description="音频二进制数据", default=None)

    class Config:
        arbitrary_types_allowed = True

    @validator("binary_data")
    def validate_binary_data(cls, v):
        """
        Warn when the audio payload has an unexpected type.

        The value is never rejected; a warning is logged and the value is
        returned unchanged.
        """
        if not isinstance(v, (bytes, numpy.ndarray)):
            logger.warning(
                "[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
                # ``cls`` is already the model class; ``cls.__class__`` would
                # name its metaclass instead.
                cls.__name__,
                type(v),
            )
        return v

    def __len__(self):
        """Return the payload length, or 0 when no payload is set."""
        if self.binary_data is None:
            return 0
        return len(self.binary_data)

    def __init__(self, binary_data: binary_data_types):
        """
        Args:
            binary_data: The audio payload.
        """
        logger.debug(
            "[%s]初始化音频数据, 数据类型为%s",
            self.__class__.__name__,
            type(binary_data),
        )
        super().__init__(binary_data=binary_data)

    def __getitem__(self, index=None):
        """
        Access the audio payload.

        The original signature accepted no key, so ``obj[i]`` always raised
        TypeError. The key is now optional.

        Args:
            index: Optional index or slice into the payload.

        Returns:
            The whole payload when ``index`` is None (as before for direct
            ``__getitem__()`` calls), otherwise ``binary_data[index]``.
        """
        if index is None:
            return self.binary_data
        return self.binary_data[index]
class AudioBinary_data_list(BaseModel):
    """Append-only list of ``_AudioBinary_data`` items."""

    binary_data_list: List[_AudioBinary_data] = Field(
        default=[], description="音频数据列表"
    )

    class Config:
        arbitrary_types_allowed = True

    def push_data(self, data: binary_data_types) -> int:
        """
        Wrap ``data`` and append it to the list.

        Args:
            data: The audio payload to store.

        Returns:
            int: Index of the stored item in ``binary_data_list``.
        """
        wrapped = _AudioBinary_data(binary_data=data)
        self.binary_data_list.append(wrapped)
        return len(self.binary_data_list) - 1

    def __getitem__(self, index: int):
        """Return the stored item at ``index``."""
        items = self.binary_data_list
        return items[index]

    def __len__(self):
        """Return the number of stored items."""
        items = self.binary_data_list
        return len(items)
# class AudioBinary_Slice(BaseModel):
# """音频块切片"""
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
# start_index: int = Field(description="开始位置", default=0)
# end_index: int = Field(description="结束位置", default=-1)
# @validator('start_index')
# def validate_start_index(cls, v):
# if v < 0:
# raise ValueError("start_index必须大于0")
# return v
# @validator('end_index')
# def validate_end_index(cls, v):
# if v < cls.start_index:
# logger.debug("[%s]end_index小于start_index, 将end_index设置为start_index", cls.__class__.__name__)
# v = cls.start_index
# return v
# def __call__(self):
# return self.target_Binary(self.start_index, self.end_index)

91
src/models/vad.py Normal file
View File

@ -0,0 +1,91 @@
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable, Any
from .audio import AudioBinary_data_list, _AudioBinary_data
class VAD_Functor_result(BaseModel):
    """
    One VAD segment: a reference to the stored audio plus its time range.
    """

    audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
    audiobinary_index: int = Field(description="音频数据索引")
    audiobinary_data: _AudioBinary_data = Field(
        description="音频数据, 指向AudioBinary_data"
    )
    start_time: int = Field(description="开始时间", is_required=True)
    end_time: int = Field(description="结束时间", is_required=True)

    @validator("audiobinary_data_list")
    def validate_audiobinary_data_list(cls, v):
        """Require an ``AudioBinary_data_list`` instance."""
        if isinstance(v, AudioBinary_data_list):
            return v
        raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")

    @validator("audiobinary_index")
    def validate_audiobinary_index(cls, v):
        """Require a non-negative integer index."""
        if not isinstance(v, int):
            raise ValueError("audiobinary_index必须是int类型")
        if v < 0:
            raise ValueError("audiobinary_index必须大于0")
        return v

    @validator("audiobinary_data")
    def validate_audiobinary_data(cls, v):
        """Require an ``_AudioBinary_data`` instance."""
        if isinstance(v, _AudioBinary_data):
            return v
        raise ValueError("audiobinary_data必须是AudioBinary_data类型")

    @validator("start_time")
    def validate_start_time(cls, v):
        """Require a non-negative integer start time."""
        if not isinstance(v, int):
            raise ValueError("start_time必须是int类型")
        if v < 0:
            raise ValueError("start_time必须大于0")
        return v

    @validator("end_time")
    def validate_end_time(cls, v, values):
        """Require an integer end time strictly after ``start_time``."""
        if not isinstance(v, int):
            raise ValueError("end_time必须是int类型")
        # ``start_time`` is only present when its own validation succeeded.
        if "start_time" in values and v <= values["start_time"]:
            raise ValueError("end_time必须大于start_time")
        return v

    @classmethod
    def create_from_push_data(
        cls,
        audiobinary_data_list: AudioBinary_data_list,
        data: Any,
        start_time: int,
        end_time: int,
    ):
        """
        Store ``data`` in ``audiobinary_data_list`` and build the segment
        that refers to it.
        """
        stored_at = audiobinary_data_list.push_data(data)
        stored_item = audiobinary_data_list[stored_at]
        return cls(
            audiobinary_data_list=audiobinary_data_list,
            audiobinary_index=stored_at,
            audiobinary_data=stored_item,
            start_time=start_time,
            end_time=end_time,
        )

    def __len__(self):
        """Return the length of the referenced audio payload."""
        payload = self.audiobinary_data.binary_data
        return len(payload)

    def __str__(self):
        """Return a multi-line, human-readable summary of the segment."""
        payload = self.audiobinary_data.binary_data
        rows = (
            f"audiobinary_data_index: {self.audiobinary_index}",
            f"start_time: {self.start_time}",
            f"end_time: {self.end_time}",
            f"data_length: {len(payload)}",
            f"data_type: {type(payload)}",
        )
        return "".join(f"{row}\n" for row in rows)

265
src/pipeline/ASRpipeline.py Normal file
View File

@ -0,0 +1,265 @@
from src.pipeline.base import PipelineBase
from typing import Dict, Any, Callable
from queue import Queue, Empty
from src.utils import get_module_logger
from src.models import AudioBinary_data_list
import threading
logger = get_module_logger(__name__)
class ASRPipeline(PipelineBase):
    """
    Pipeline that chains the VAD, ASR and SPK functors through queues.

    Data taken from the input queue is forwarded to the VAD functor. VAD
    output is fanned out to the ASR and SPK functors. Their outputs land in
    the ``asrend`` / ``spkend`` queues, and whenever both hold data one item
    of each is merged and delivered to the registered callbacks.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise pipeline state.

        Positional and keyword arguments are forwarded to ``PipelineBase``.
        An ``input_queue`` passed there is preserved (it used to be reset to
        None here).
        """
        super().__init__(*args, **kwargs)
        self._config: Dict[str, Any] = {}
        self._models: Dict[str, Any] = None
        self._functor_dict: Dict[str, Any] = {}
        self._subqueue_dict: Dict[str, Any] = {}
        self._is_baked: bool = False
        # Initialised so the resource pre-check reports a clear error
        # instead of raising AttributeError.
        self._audio_binary: AudioBinary_data_list = None
        self._audio_binary_data_list: AudioBinary_data_list = None
        self._status_lock = threading.Lock()
        # Serialises result pairing, which is triggered from both the ASR
        # and the SPK callback paths.
        self._result_lock = threading.Lock()

    def set_config(self, config: Dict[str, Any]) -> None:
        """Set the configuration dict passed to the functor factory."""
        self._config = config

    def get_config(self) -> Dict[str, Any]:
        """Return the configuration dict."""
        return self._config

    def set_audio_binary(self, audio_binary: AudioBinary_data_list) -> None:
        """Set the audio storage unit."""
        self._audio_binary = audio_binary

    def set_models(self, models: Dict[str, Any]) -> None:
        """Set the model dictionary passed to the functor factory."""
        self._models = models

    def set_input_queue(self, input_queue: Queue) -> None:
        """Set the queue from which the pipeline reads its input."""
        self._input_queue = input_queue

    def bake(self) -> None:
        """
        Check resources and build functors and queues.

        Raises:
            RuntimeError: If a required resource has not been set.
            ImportError: If the functor factory cannot be imported.
        """
        self._pre_check_resource()
        self._init_functor()
        self._is_baked = True

    def _pre_check_resource(self) -> None:
        """
        Verify that all resources needed by ``bake`` are set.

        Raises:
            RuntimeError: For the first missing resource.
        """
        if self._input_queue is None:
            raise RuntimeError("[ASRpipeline]输入队列未设置")
        if self._functor_dict is None:
            raise RuntimeError("[ASRpipeline]functor字典未设置")
        if self._subqueue_dict is None:
            raise RuntimeError("[ASRpipeline]子队列字典未设置")
        if self._audio_binary is None:
            raise RuntimeError("[ASRpipeline]音频存储单元未设置")
        if self._models is None:
            raise RuntimeError("[ASRpipeline]模型未设置")

    def _init_functor(self) -> None:
        """
        Create the functors, the queues between them and their callbacks.

        Raises:
            ImportError: If ``FunctorFactory`` cannot be imported.
        """
        # Only the import itself is guarded, so an ImportError raised while
        # building a functor keeps its own message.
        try:
            from src.functor import FunctorFactory
        except ImportError as exc:
            raise ImportError(
                "functorFactory引入失败,ASRPipeline无法完成初始化"
            ) from exc

        for functor_name in ("vad", "asr", "spk"):
            self._functor_dict[functor_name] = FunctorFactory.make_functor(
                functor_name=functor_name, config=self._config, models=self._models
            )

        # NOTE(review): a fresh storage unit is created here and handed to
        # the VAD functor; the one given to set_audio_binary() is only
        # checked for presence. Confirm which one is intended.
        self._audio_binary_data_list = AudioBinary_data_list()
        self._functor_dict["vad"].set_audio_binary_data_list(
            self._audio_binary_data_list
        )

        for queue_name in ("original", "vad2asr", "vad2spk", "asrend", "spkend"):
            self._subqueue_dict[queue_name] = Queue()

        # Input queue of each functor.
        self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
        self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
        self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])

        # VAD output is fanned out to both downstream functors.
        self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
        self._functor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)

        def make_sink(queue: Queue) -> Callable:
            """Return a callback that enqueues data, then tries to pair."""

            def sink(data: Any) -> None:
                queue.put(data)
                self._check_result(data)

            return sink

        self._functor_dict["asr"].add_callback(make_sink(self._subqueue_dict["asrend"]))
        self._functor_dict["spk"].add_callback(make_sink(self._subqueue_dict["spkend"]))

    def _check_result(self, result: Any) -> None:
        """
        Merge one ASR result with one SPK result when both are available.

        The sizes used to be combined with bitwise ``&``, which is 0 for
        sizes such as 1 and 2 even though both queues hold data. The
        check-and-take is done under a lock because this method is called
        from both the ASR and the SPK callback.

        Args:
            result: The item that was just enqueued (unused).
        """
        asr_queue = self._subqueue_dict["asrend"]
        spk_queue = self._subqueue_dict["spkend"]
        with self._result_lock:
            if asr_queue.empty() or spk_queue.empty():
                return
            asr_data = asr_queue.get_nowait()
            spk_data = spk_queue.get_nowait()
        # Notify outside the lock so slow callbacks do not block pairing.
        self._notify_callbacks({"asr_data": asr_data, "spk_data": spk_data})

    def run(self) -> threading.Thread:
        """
        Start the pipeline on a daemon thread.

        Returns:
            threading.Thread: The started thread.

        Raises:
            RuntimeError: If the pipeline has not been baked or a functor or
                queue is missing.
        """
        self._pre_check()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        logger.info("[ASRpipeline]管道开始运行")
        return self._thread

    def _pre_check(self) -> None:
        """
        Verify that the pipeline is ready to run.

        Raises:
            RuntimeError: If not baked, or a functor or sub-queue is None.
        """
        if self._is_baked is False:
            raise RuntimeError("[ASRpipeline]管道未烘焙,无法运行")
        for functor_name, functor in self._functor_dict.items():
            if functor is None:
                raise RuntimeError(f"[ASRpipeline]functor{functor_name}异常")
        for subqueue_name, subqueue in self._subqueue_dict.items():
            if subqueue is None:
                raise RuntimeError(f"[ASRpipeline]子队列{subqueue_name}异常")

    def _run(self) -> None:
        """
        Thread body: start all functors, then pump the input queue.

        The loop ends when a stop is requested, when ``None`` is received
        as an end signal, or when processing raises.
        """
        for functor_name, functor in self._functor_dict.items():
            logger.info(f"[ASRpipeline]运行{functor_name}functor")
            functor.run()
        with self._status_lock:
            self._is_running = True
            self._stop_event = False
        while self._is_running and not self._stop_event:
            try:
                data = self._input_queue.get(timeout=self._queue_timeout)
                if data is None:
                    # End signal.
                    logger.info("收到结束信号,管道准备停止")
                    self._input_queue.task_done()
                    break
                self._process(data)
                self._input_queue.task_done()
            except Empty:
                # Timed out waiting for input; poll the flags again.
                continue
            except Exception as e:
                logger.error(f"[ASRpipeline]管道处理数据出错: {str(e)}")
                break
        with self._status_lock:
            self._is_running = False
        logger.info("[ASRpipeline]管道停止运行")

    def _process(self, data: Any) -> Any:
        """
        Forward one input item to the VAD functor's input queue.

        Args:
            data: The input item.
        """
        self._subqueue_dict["original"].put(data)

    def stop(self, timeout: float = None) -> bool:
        """
        Stop all functors and the pipeline thread.

        Args:
            timeout: Maximum seconds to wait for the pipeline thread; None
                waits indefinitely. Accepted so that callers using the
                ``PipelineBase.stop(timeout=...)`` signature work.

        Returns:
            bool: True when every functor reported a successful stop and
            the pipeline thread is no longer alive.
        """
        with self._status_lock:
            self._is_running = False
            self._stop_event = True
        stopped = True
        for functor_name, functor in self._functor_dict.items():
            if functor.stop():
                logger.info(f"[ASRpipeline]子Functor[{functor_name}]停止")
            else:
                logger.error(f"[ASRpipeline]子Functor[{functor_name}]停止失败")
                stopped = False
        # The thread only exists after run().
        if self._thread is not None:
            self._thread.join(timeout=timeout)
            stopped = stopped and not self._thread.is_alive()
        logger.info("[ASRpipeline]管道停止")
        return stopped

3
src/pipeline/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from src.pipeline.base import PipelineBase, PipelineFactory
__all__ = ["PipelineBase", "PipelineFactory"]

151
src/pipeline/base.py Normal file
View File

@ -0,0 +1,151 @@
from abc import ABC, abstractmethod
from queue import Queue, Empty
from typing import List, Callable, Any, Optional
import logging
import threading
import time
# 配置日志
logger = logging.getLogger(__name__)
class PipelineBase(ABC):
    """
    Abstract base class for pipelines.

    Holds the input queue, the result callbacks and the run/stop flags.
    Subclasses provide ``_process`` and ``_run``.
    """

    def __init__(self, input_queue: Optional[Queue] = None):
        """
        Args:
            input_queue: Queue from which the pipeline receives data.
        """
        self._input_queue = input_queue
        self._callbacks: List[Callable] = []
        self._is_running = False
        self._stop_event = False
        self._thread: Optional[threading.Thread] = None
        # Default seconds to wait for the thread when stopping.
        self._stop_timeout = 5
        # Seconds to wait for one queue read.
        self._queue_timeout = 1

    def set_input_queue(self, queue: Queue) -> None:
        """Replace the input queue."""
        self._input_queue = queue

    def add_callback(self, callback: Callable) -> None:
        """Register a callback that receives each result."""
        self._callbacks.append(callback)

    def _notify_callbacks(self, result: Any) -> None:
        """
        Call every registered callback with ``result``.

        A callback that raises is logged and does not stop the others.
        """
        for notify in self._callbacks:
            try:
                notify(result)
            except Exception as e:
                logger.error(f"回调函数执行出错: {str(e)}")

    @abstractmethod
    def _process(self, data: Any) -> Any:
        """Process one input item and return the result."""

    @abstractmethod
    def _run(self) -> None:
        """Take data from the input queue and process it."""

    def stop(self, timeout: Optional[float] = None) -> bool:
        """
        Request the pipeline to stop and wait for its thread.

        Args:
            timeout: Seconds to wait; None uses ``_stop_timeout``.

        Returns:
            bool: True when the pipeline was not running or its thread has
            ended; False when the thread is still alive after the timeout.
        """
        if not self._is_running:
            return True
        logger.info("正在停止管道...")
        self._stop_event = True
        self._is_running = False
        if not (self._thread and self._thread.is_alive()):
            # No live thread to wait for.
            return True
        if timeout is None:
            timeout = self._stop_timeout
        self._thread.join(timeout=timeout)
        if self._thread.is_alive():
            logger.warning(f"管道停止超时({timeout}秒),强制终止")
            return False
        logger.info("管道已成功停止")
        return True

    def force_stop(self) -> None:
        """
        Set the stop flags without waiting for the thread.

        The thread itself is not terminated here; only the flags change.
        """
        logger.warning("强制停止管道")
        self._stop_event = True
        self._is_running = False
class PipelineFactory:
    """
    Factory that builds configured pipeline instances by name.

    ``ASRPipeline`` is imported lazily inside the builder. Importing it in
    the class body would run while this module is still being imported by
    ``ASRpipeline`` itself (which imports ``PipelineBase`` from here),
    creating a circular import.
    """

    @staticmethod
    def _create_pipeline_ASRpipeline(*args, **kwargs) -> Any:
        """
        Build and bake an ``ASRPipeline``.

        Keyword Args:
            config: Configuration passed to ``set_config``.
            models: Models passed to ``set_models``.
            audio_binary: Storage unit passed to ``set_audio_binary``.
            input_queue: Queue passed to ``set_input_queue``.
            callback: Result callback passed to ``add_callback``.

        Returns:
            The baked ``ASRPipeline`` instance.

        Raises:
            KeyError: If any of the keyword arguments above is missing.
        """
        required = ("config", "models", "audio_binary", "input_queue", "callback")
        missing = [key for key in required if key not in kwargs]
        if missing:
            raise KeyError(f"创建ASRpipeline缺少必要参数: {', '.join(missing)}")

        from src.pipeline.ASRpipeline import ASRPipeline

        pipeline = ASRPipeline()
        pipeline.set_config(kwargs["config"])
        pipeline.set_models(kwargs["models"])
        pipeline.set_audio_binary(kwargs["audio_binary"])
        pipeline.set_input_queue(kwargs["input_queue"])
        pipeline.add_callback(kwargs["callback"])
        pipeline.bake()
        return pipeline

    @classmethod
    def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
        """
        Create a pipeline by name.

        Args:
            pipeline_name: Currently only ``"ASRpipeline"`` is supported.
            *args, **kwargs: Forwarded to the specific builder.

        Raises:
            ValueError: If ``pipeline_name`` is not supported.
        """
        if pipeline_name == "ASRpipeline":
            return cls._create_pipeline_ASRpipeline(*args, **kwargs)
        raise ValueError(f"不支持的管道类型: {pipeline_name}")

0
src/pipeline/test.py Normal file
View File

281
src/runner.py Normal file
View File

@ -0,0 +1,281 @@
"""
运行器模块
提供运行器基类和运行器类用于管理音频数据和模型的交互
主要包含:
- RunnerBase: 运行器基类,定义了基本接口
- Runner: 运行器类,工厂模式实现
- RunnerFactory: 运行器工厂类,用于创建运行器
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from threading import Thread, Lock
from queue import Queue
import traceback
import time
from src.audio_chunk import AudioChunk, AudioBinary
from src.pipeline import Pipeline, PipelineFactory
from src.model_loader import ModelLoader
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__, level="INFO")
audio_chunk = AudioChunk()
models_loaded = ModelLoader()
class RunnerBase(ABC):
    """Abstract interface every runner must implement."""

    @abstractmethod
    def adder(self, data: Any) -> None:
        """
        Accept one piece of input data.

        Args:
            data: The data to add.
        """

    @abstractmethod
    def add_recevier(self, receiver: callable) -> None:
        """
        Register a consumer of the runner's output.

        Args:
            receiver: Callback that receives the data.
        """
class STTRunner(RunnerBase):
    """
    Runner that owns the shared input queue and drives a list of pipelines.

    All pipelines are given the same input queue, the same model dictionary
    and an audio storage unit when the runner is constructed.
    """

    def __init__(
        self,
        *,
        audio_binary_list: List[AudioBinary],
        models: Dict[str, Any],
        pipeline_list: List[Pipeline],
    ):
        """
        Store the resources and wire them into every pipeline.

        Args:
            audio_binary_list: Audio storage units.
            models: Model dictionary handed to each pipeline.
            pipeline_list: Pipelines managed by this runner.
        """
        # Resources handed in by the caller.
        self._audio_binary_list = audio_binary_list
        self._models = models
        self._pipeline_list = pipeline_list
        # Guards callback registration.
        self._lock = Lock()
        # Input queue shared by all pipelines (bounded to 1000 items).
        self._input_queue = Queue(maxsize=1000)
        # Stop control.
        self._stop_timeout = 10.0
        self._is_stopping = False
        # Wire resources into each pipeline.
        for pipeline in self._pipeline_list:
            # Shared input queue.
            pipeline.set_input_queue(self._input_queue)
            # NOTE(review): get_config is called with a key and its result
            # is used to index _audio_binary_list, which is annotated as a
            # list. ASRPipeline.get_config in this codebase takes no
            # arguments. Confirm the intended pipeline interface.
            pipeline.set_audio_binary(
                self._audio_binary_list[pipeline.get_config("audio_binary_name")]
            )
            pipeline.set_models(self._models)

    def adder(self, data: Any) -> None:
        """
        Put one item on the shared input queue.

        Args:
            data: The data to add.

        Raises:
            RuntimeError: If there are no pipelines or the runner is
                currently stopping.
        """
        if not self._pipeline_list:
            raise RuntimeError("没有可用的管道")
        if self._is_stopping:
            raise RuntimeError("运行器正在停止,无法添加数据")
        self._input_queue.put(data)

    def add_recevier(self, receiver: callable) -> None:
        """
        Register ``receiver`` as a callback on every pipeline.

        Args:
            receiver: Callback that receives pipeline results.
        """
        with self._lock:
            for pipeline in self._pipeline_list:
                pipeline.add_callback(receiver)

    def run(self) -> None:
        """
        Start every pipeline, each from its own daemon thread.

        Raises:
            RuntimeError: If there are no pipelines.
        """
        logger.info("[%s] 启动所有管道", self.__class__.__name__)
        if not self._pipeline_list:
            raise RuntimeError("没有可用的管道")
        # One launcher thread per pipeline; each thread calls pipeline.run.
        for pipeline in self._pipeline_list:
            thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
            thread.daemon = True
            thread.start()
            logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))

    def stop(self, force: bool = False) -> bool:
        """
        Stop all pipelines and wait for the input queue to drain.

        Args:
            force: When True, ``force_stop`` is called on each pipeline
                instead of ``stop(timeout=...)``.

        Returns:
            bool: True when every pipeline stopped and the queue emptied
            within ``_stop_timeout``; False otherwise, including when a
            stop is already in progress.
        """
        if self._is_stopping:
            logger.warning("运行器已经在停止中")
            return False
        self._is_stopping = True
        logger.info("正在停止运行器...")
        try:
            # Send the end signal. A single None is enqueued on the shared
            # queue. NOTE(review): put() blocks while the bounded queue is
            # full — confirm this is acceptable during shutdown.
            self._input_queue.put(None)
            # Stop every pipeline.
            success = True
            for pipeline in self._pipeline_list:
                if force:
                    pipeline.force_stop()
                else:
                    if not pipeline.stop(timeout=self._stop_timeout):
                        logger.warning("管道 %s 停止超时", id(pipeline))
                        success = False
            # Wait until the queue is empty or the timeout elapses.
            try:
                start_time = time.time()
                while not self._input_queue.empty():
                    if time.time() - start_time > self._stop_timeout:
                        logger.warning(
                            "等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
                            self._stop_timeout,
                            self._input_queue.qsize(),
                        )
                        success = False
                        break
                    time.sleep(0.1)  # Avoid busy-waiting.
            except Exception as e:
                error_type = type(e).__name__
                error_msg = str(e)
                error_traceback = traceback.format_exc()
                logger.error(
                    "等待队列处理完成时发生错误:\n"
                    "错误类型: %s\n"
                    "错误信息: %s\n"
                    "错误堆栈:\n%s",
                    error_type,
                    error_msg,
                    error_traceback,
                )
                success = False
            if success:
                logger.info("所有管道已成功停止")
            else:
                logger.warning(
                    "部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
                    self._input_queue.qsize(),
                    self._input_queue.empty(),
                )
            return success
        finally:
            # Always clear the flag so a later stop() can run again.
            self._is_stopping = False

    def __del__(self) -> None:
        """Force-stop the pipelines when the runner is garbage-collected."""
        self.stop(force=True)
class STTRunnerFactory:
    """Factory that builds ``STTRunner`` instances."""

    @staticmethod
    def _create_runner(
        audio_binary_name: str,
        model_name_list: List[str],
        pipeline_name_list: List[str],
    ) -> STTRunner:
        """
        Build a runner from resource names.

        Args:
            audio_binary_name: Name passed to
                ``audio_chunk.get_audio_binary``.
            model_name_list: Names of already-loaded models to hand to the
                runner; ``None`` is treated as an empty list.
            pipeline_name_list: Names of pipelines to create; ``None`` is
                treated as an empty list. ``create_runner_normal`` passes
                ``None`` here, which used to fail with TypeError on
                iteration.

        Returns:
            The constructed ``STTRunner``.

        Raises:
            KeyError: If a requested model has not been loaded.
        """
        audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
        models: Dict[str, Any] = {
            model_name: models_loaded.models[model_name]
            for model_name in (model_name_list or [])
        }
        pipelines: List[Pipeline] = [
            PipelineFactory.create_pipeline(pipeline_name)
            for pipeline_name in (pipeline_name_list or [])
        ]
        return STTRunner(
            audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
        )

    @classmethod
    def create_runner_from_config(
        cls,
        config: Dict[str, Any],
    ) -> STTRunner:
        """
        Build a runner from a configuration dict.

        Args:
            config: Must contain ``audio_binary_name``, ``model_name_list``
                and ``pipeline_name_list``.

        Returns:
            The constructed ``STTRunner``.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        return cls._create_runner(
            config["audio_binary_name"],
            config["model_name_list"],
            config["pipeline_name_list"],
        )

    @classmethod
    def create_runner_normal(cls) -> STTRunner:
        """
        Build a default runner using every loaded model.

        No audio-binary name and no pipelines are specified, so the
        returned runner has an empty pipeline list.
        """
        audio_binary_name = None
        model_name_list = list(models_loaded.models.keys())
        pipeline_name_list = None
        return cls._create_runner(
            audio_binary_name, model_name_list, pipeline_name_list
        )

View File

@ -43,7 +43,7 @@ async def clear_websocket():
async def ws_serve(websocket, path):
"""
WebSocket服务主函数处理客户端连接和消息
参数:
websocket: WebSocket连接对象
path: 连接路径
@ -51,13 +51,13 @@ async def ws_serve(websocket, path):
frames = [] # 存储所有音频帧
frames_asr = [] # 存储用于离线ASR的音频帧
frames_asr_online = [] # 存储用于在线ASR的音频帧
global websocket_users
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
# 添加到用户集合
websocket_users.add(websocket)
# 初始化连接状态
websocket.status_dict_asr = {}
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
@ -66,15 +66,15 @@ async def ws_serve(websocket, path):
websocket.chunk_interval = 10
websocket.vad_pre_idx = 0
websocket.is_speaking = True # 默认用户正在说话
# 语音检测状态
speech_start = False
speech_end_i = -1
# 初始化配置
websocket.wav_name = "microphone"
websocket.mode = "2pass" # 默认使用两阶段识别模式
print("新用户已连接", flush=True)
try:
@ -84,11 +84,13 @@ async def ws_serve(websocket, path):
if isinstance(message, str):
try:
messagejson = json.loads(message)
# 更新各种配置参数
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
websocket.status_dict_asr_online["is_final"] = (
not websocket.is_speaking
)
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
websocket.status_dict_asr_online["chunk_size"] = [
int(x) for x in chunk_size
]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
messagejson["encoder_chunk_look_back"]
)
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
messagejson["decoder_chunk_look_back"]
)
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
@ -111,11 +119,17 @@ async def ws_serve(websocket, path):
# 根据chunk_interval更新VAD的chunk_size
websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
* 60
/ websocket.chunk_interval
)
# 处理音频数据
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
if (
len(frames_asr_online) > 0
or len(frames_asr) >= 0
or not isinstance(message, str)
):
if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区
frames.append(message)
@ -125,10 +139,12 @@ async def ws_serve(websocket, path):
# 处理在线ASR
frames_asr_online.append(message)
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
# 达到chunk_interval或最终帧时处理在线ASR
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
websocket.status_dict_asr_online["is_final"]):
if (
len(frames_asr_online) % websocket.chunk_interval == 0
or websocket.status_dict_asr_online["is_final"]
):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
@ -136,26 +152,32 @@ async def ws_serve(websocket, path):
except Exception as e:
print(f"在线ASR处理错误: {e}")
frames_asr_online = []
# 如果检测到语音开始收集帧用于离线ASR
if speech_start:
frames_asr.append(message)
# VAD处理 - 语音活动检测
try:
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
speech_start_i, speech_end_i = await asr_service.async_vad(
websocket, message
)
except Exception as e:
print(f"VAD处理错误: {e}")
# 检测到语音开始
if speech_start_i != -1:
speech_start = True
# 计算开始偏移并收集前面的帧
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
beg_bias = (
websocket.vad_pre_idx - speech_start_i
) // duration_ms
frames_pre = (
frames[-beg_bias:] if beg_bias < len(frames) else frames
)
frames_asr = []
frames_asr.extend(frames_pre)
# 处理离线ASR (语音结束或用户停止说话)
if speech_end_i != -1 or not websocket.is_speaking:
if websocket.mode == "2pass" or websocket.mode == "offline":
@ -164,13 +186,13 @@ async def ws_serve(websocket, path):
await asr_service.async_asr(websocket, audio_in)
except Exception as e:
print(f"离线ASR处理错误: {e}")
# 重置状态
frames_asr = []
speech_start = False
frames_asr_online = []
websocket.status_dict_asr_online["cache"] = {}
# 如果用户停止说话,完全重置
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
@ -193,34 +215,34 @@ async def ws_serve(websocket, path):
def start_server(args, asr_service_instance):
"""
启动WebSocket服务器
参数:
args: 命令行参数
asr_service_instance: ASR服务实例
"""
global asr_service
asr_service = asr_service_instance
# 配置SSL (如果提供了证书)
if args.certfile and len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context
ws_serve,
args.host,
args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context,
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
)
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
# 启动事件循环
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
@ -229,14 +251,14 @@ def start_server(args, asr_service_instance):
if __name__ == "__main__":
# 解析命令行参数
args = parse_args()
# 加载模型
print("正在加载模型...")
models = load_models(args)
print("模型加载完成!当前仅支持单个客户端同时连接!")
# 创建ASR服务
asr_service = ASRService(models)
# 启动服务器
start_server(args, asr_service)
start_server(args, asr_service)

View File

@ -9,11 +9,11 @@ import json
class ASRService:
"""ASR服务类封装各种语音识别相关功能"""
def __init__(self, models):
"""
初始化ASR服务
参数:
models: 包含各种预加载模型的字典
"""
@ -21,42 +21,41 @@ class ASRService:
self.model_asr_streaming = models["asr_streaming"]
self.model_vad = models["vad"]
self.model_punc = models["punc"]
async def async_vad(self, websocket, audio_in):
"""
语音活动检测
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
返回:
tuple: (speech_start, speech_end) 语音开始和结束位置
"""
# 使用VAD模型分析音频段
segments_result = self.model_vad.generate(
input=audio_in,
**websocket.status_dict_vad
input=audio_in, **websocket.status_dict_vad
)[0]["value"]
speech_start = -1
speech_end = -1
# 解析VAD结果
if len(segments_result) == 0 or len(segments_result) > 1:
return speech_start, speech_end
if segments_result[0][0] != -1:
speech_start = segments_result[0][0]
if segments_result[0][1] != -1:
speech_end = segments_result[0][1]
return speech_start, speech_end
async def async_asr(self, websocket, audio_in):
"""
离线ASR处理
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
@ -64,42 +63,44 @@ class ASRService:
if len(audio_in) > 0:
# 使用离线ASR模型处理音频
rec_result = self.model_asr.generate(
input=audio_in,
**websocket.status_dict_asr
input=audio_in, **websocket.status_dict_asr
)[0]
# 如果有标点符号模型且识别出文本,则添加标点
if self.model_punc is not None and len(rec_result["text"]) > 0:
rec_result = self.model_punc.generate(
input=rec_result["text"],
**websocket.status_dict_punc
input=rec_result["text"], **websocket.status_dict_punc
)[0]
# 如果识别出文本,发送到客户端
if len(rec_result["text"]) > 0:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
else:
# 如果没有音频数据,发送空文本
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
message = json.dumps(
{
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
async def async_asr_online(self, websocket, audio_in):
"""
在线ASR处理
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
@ -107,21 +108,24 @@ class ASRService:
if len(audio_in) > 0:
# 使用在线ASR模型处理音频
rec_result = self.model_asr_streaming.generate(
input=audio_in,
**websocket.status_dict_asr_online
input=audio_in, **websocket.status_dict_asr_online
)[0]
# 在2pass模式下如果是最终帧则跳过(留给离线ASR处理)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
"is_final", False
):
return
# 如果识别出文本,发送到客户端
if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
await websocket.send(message)
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)

3
src/utils/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .logger import get_module_logger, setup_root_logger
__all__ = ["get_module_logger", "setup_root_logger"]

122
src/utils/data_format.py Normal file
View File

@ -0,0 +1,122 @@
"""
处理各类音频数据与bytes的转换
"""
import wave
from pydub import AudioSegment
import io
def wav_to_bytes(wav_path: str) -> bytes:
    """Read every audio frame of a WAV file and return the frames as bytes.

    Only the frame payload produced by ``readframes`` is returned; the
    channel count, sample width and frame rate stored in the file are not
    part of the result.

    Args:
        wav_path: Path of the WAV file to read.

    Returns:
        The bytes of all frames contained in the file.

    Raises:
        FileNotFoundError: If ``wav_path`` does not exist.
        wave.Error: If the file cannot be opened or read as a WAV file.
    """
    try:
        with wave.open(wav_path, "rb") as wf:
            # Read all frames in one call.
            return wf.readframes(wf.getnframes())
    except FileNotFoundError as e:
        # Re-raise with a path-specific message, keeping the original
        # exception attached as the explicit cause.
        raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'") from e
    except wave.Error as e:
        raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}") from e
def bytes_to_wav(
    bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
):
    """Write audio frame bytes to a WAV file.

    Args:
        bytes_data: Audio frame bytes to write.
        wav_path: Destination path of the WAV file.
        nchannels: Number of channels (e.g. 1 for mono, 2 for stereo).
        sampwidth: Sample width in bytes (e.g. 2 for 16-bit audio).
        framerate: Sampling rate (e.g. 44100, 16000).

    Raises:
        wave.Error: If the ``wave`` module rejects the parameters or the write.
        Exception: For any other failure while writing; the original
            exception is attached as the cause.
    """
    try:
        with wave.open(wav_path, "wb") as wf:
            wf.setnchannels(nchannels)
            wf.setsampwidth(sampwidth)
            wf.setframerate(framerate)
            wf.writeframes(bytes_data)
    except wave.Error as e:
        raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}") from e
    except Exception as e:
        # Any other write failure is wrapped with a path-specific message;
        # the explicit cause keeps the original traceback reachable.
        raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}") from e
def mp3_to_bytes(mp3_path: str) -> bytes:
    """Decode an MP3 file and return its raw audio data as bytes.

    Args:
        mp3_path: Path of the MP3 file to decode.

    Returns:
        The ``raw_data`` of the ``AudioSegment`` loaded from the file.

    Raises:
        FileNotFoundError: If ``mp3_path`` does not exist.
        Exception: If loading or decoding fails for any other reason; the
            original exception is attached as the cause.
    """
    try:
        audio = AudioSegment.from_mp3(mp3_path)
        # Return the decoded raw data held by the segment.
        return audio.raw_data
    except FileNotFoundError as e:
        raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'") from e
    except Exception as e:  # pydub may raise several decode-related errors
        raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}") from e
def bytes_to_mp3(
    bytes_data: bytes,
    mp3_path: str,
    frame_rate: int,
    channels: int,
    sample_width: int,
    bitrate: str = "192k",
):
    """Encode raw PCM bytes to an MP3 file.

    Args:
        bytes_data: Raw PCM audio bytes.
        mp3_path: Destination path of the MP3 file.
        frame_rate: Sampling rate of the PCM data (e.g. 44100).
        channels: Number of channels of the PCM data (1 mono, 2 stereo).
        sample_width: Sample width of the PCM data in bytes (2 for 16-bit).
        bitrate: MP3 encoding bitrate (e.g. "128k", "192k", "320k").

    Raises:
        Exception: If building the segment or exporting the MP3 fails; the
            original exception is attached as the cause.
    """
    try:
        # Wrap the raw data in an AudioSegment so pydub can export it.
        audio = AudioSegment(
            data=bytes_data,
            sample_width=sample_width,
            frame_rate=frame_rate,
            channels=channels,
        )
        audio.export(mp3_path, format="mp3", bitrate=bitrate)
    except Exception as e:
        raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}") from e

91
src/utils/logger.py Normal file
View File

@ -0,0 +1,91 @@
import logging
import sys
from pathlib import Path
from typing import Optional
def setup_logger(
    name: Optional[str] = None,
    level: str = "INFO",
    log_file: Optional[str] = None,
    log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    date_format: str = "%Y-%m-%d %H:%M:%S",
) -> logging.Logger:
    """Configure and return a logger with console (and optional file) output.

    The function is safe to call repeatedly for the same logger: an existing
    stdout handler, or an existing file handler for the same path, is reused
    (and given the new formatter) instead of a second one being attached, so
    records are not emitted more than once per destination.

    Args:
        name: Logger name; ``None`` selects the root logger.
        level: Level name (case-insensitive), e.g. "INFO" or "debug".
        log_file: Optional log file path; parent directories are created.
        log_format: Record format string.
        date_format: Timestamp format string.

    Returns:
        The configured ``logging.Logger``. Propagation is left unchanged.

    Raises:
        AttributeError: If ``level`` is not an attribute of ``logging``.
    """
    import os  # local import: used only to normalise the log file path

    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, level.upper()))

    print(f"添加处理器 {name} {log_file} {log_format} {date_format}")

    formatter = logging.Formatter(log_format, date_format)

    # Console handler: reuse an existing stdout stream handler if present.
    # FileHandler subclasses StreamHandler, so it is excluded explicitly.
    console_handler = next(
        (
            h
            for h in logger.handlers
            if isinstance(h, logging.StreamHandler)
            and not isinstance(h, logging.FileHandler)
            and getattr(h, "stream", None) is sys.stdout
        ),
        None,
    )
    if console_handler is None:
        console_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(console_handler)
    console_handler.setFormatter(formatter)

    if log_file:
        # Make sure the log directory exists before opening the file.
        Path(log_file).parent.mkdir(parents=True, exist_ok=True)

        # FileHandler stores the absolute path in ``baseFilename``.
        target = os.path.abspath(log_file)
        file_handler = next(
            (
                h
                for h in logger.handlers
                if isinstance(h, logging.FileHandler) and h.baseFilename == target
            ),
            None,
        )
        if file_handler is None:
            file_handler = logging.FileHandler(log_file, encoding="utf-8")
            logger.addHandler(file_handler)
        file_handler.setFormatter(formatter)

    return logger
def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
    """Configure the root logger.

    Thin convenience wrapper that forwards to ``setup_logger`` with no
    logger name, which selects the root logger.

    Args:
        level: Log level name.
        log_file: Optional path of a log file.
    """
    setup_logger(name=None, level=level, log_file=log_file)
def get_module_logger(
    module_name: str,
    level: Optional[str] = None,  # optional: inherit when omitted
    log_file: Optional[str] = None,  # usually not needed per module
) -> logging.Logger:
    """Return the logger for a module, optionally adding a file handler.

    The logger's level is only changed when ``level`` is given; otherwise it
    keeps inheriting from its parent. When ``log_file`` is given, exactly one
    file handler for that path is attached (repeated calls do not stack
    handlers). No console handler is added here: records still propagate to
    ancestor loggers, so adding one would print every record twice once the
    root logger has its own console handler.

    Args:
        module_name: Logger name, normally the caller's ``__name__``.
        level: Optional level name (case-insensitive).
        log_file: Optional path of a module-specific log file; parent
            directories are created.

    Returns:
        The ``logging.Logger`` registered under ``module_name``.
    """
    import os  # local import: used only to normalise the log file path

    logger = logging.getLogger(module_name)

    # Only touch the level when it was requested explicitly.
    if level:
        logger.setLevel(getattr(logging, level.upper()))

    # Only attach a file handler when a file was requested explicitly.
    if log_file:
        target = os.path.abspath(log_file)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers
        )
        if not already_attached:
            Path(log_file).parent.mkdir(parents=True, exist_ok=True)
            file_handler = logging.FileHandler(log_file, encoding="utf-8")
            # Same record/date format as the defaults of setup_logger.
            file_handler.setFormatter(
                logging.Formatter(
                    "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                    "%Y-%m-%d %H:%M:%S",
                )
            )
            logger.addHandler(file_handler)

    return logger

17
test_main.py Normal file
View File

@ -0,0 +1,17 @@
"""
Test entry point.
Create test files under the tests directory and call them from this file.
"""

from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger

# Configure the root logger once (console plus logs/test_main.log), then get
# this module's logger.
setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)

# VAD functor test, currently disabled:
# from tests.functor.vad_test import test_vad_functor
# logger.info("开始测试VAD函数器")
# test_vad_functor()

# Run the ASR pipeline test; this executes whenever the module is run/imported.
logger.info("开始测试ASR管道")
test_asr_pipeline()

View File

@ -1 +1 @@
"""FunASR WebSocket服务测试模块"""
"""FunASR WebSocket服务测试模块"""

124
tests/functor/vad_test.py Normal file
View File

@ -0,0 +1,124 @@
"""
Functor测试
VAD测试
"""
from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor
from queue import Queue, Empty
from src.model_loader import ModelLoader
from src.models import AudioBinary_Config, AudioBinary_data_list
from src.utils.data_format import wav_to_bytes
import time
from src.utils.logger import get_module_logger
from pydub import AudioSegment
import soundfile
# Inspection switch: when True, test_vad_functor writes each collected audio
# entry to tests/vad_test_output_<index>.wav.
OVERWATCH = False
logger = get_module_logger(__name__)
# Module-level ModelLoader instance used by test_vad_functor.
model_loader = ModelLoader()
def test_vad_functor():
    """Run VADFunctor, ASRFunctor and SPKFunctor together on the example WAV.

    The VAD functor reads slices from ``input_queue``; its callbacks print
    each result and forward it to two queues that feed the ASR and SPK
    functors. The SPK functor is given a placeholder model ("fake_spk").
    """
    # Load models
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "auto_update": False,
    }
    model_loader.load_models(args)
    # Load audio data
    f_data, sample_rate = soundfile.read("tests/vad_example.wav")
    audio_config = AudioBinary_Config(
        chunk_size=200,
        chunk_stride=1600,
        sample_rate=sample_rate,
        sample_width=16,
        channels=1,
    )
    # Recompute chunk_stride from chunk_size and the file's sample rate,
    # overriding the 1600 passed to the constructor above.
    # presumably chunk_size is in milliseconds (division by 1000) — TODO confirm
    chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
    audio_config.chunk_stride = chunk_stride
    # Create the input queue and the VAD -> ASR / VAD -> SPK queues
    input_queue = Queue()
    vad2asr_queue = Queue()
    vad2spk_queue = Queue()
    # Create the audio data list
    audio_binary_data_list = AudioBinary_data_list()
    # Create the VAD functor
    vad_functor = VADFunctor()
    # Set its input queue
    vad_functor.set_input_queue(input_queue)
    # Set the audio config
    vad_functor.set_audio_config(audio_config)
    # Set the audio data list
    vad_functor.set_audio_binary_data_list(audio_binary_data_list)
    # Register callbacks: print the result and fan it out to ASR and SPK
    vad_functor.add_callback(lambda x: print(f"vad callback: {x}"))
    vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
    vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
    # Set the model
    vad_functor.set_model({"vad": model_loader.models["vad"]})
    # Start the VAD functor
    vad_functor.run()
    # Create the ASR functor
    asr_functor = ASRFunctor()
    # Set its input queue (fed by the VAD callback)
    asr_functor.set_input_queue(vad2asr_queue)
    # Set the audio config
    asr_functor.set_audio_config(audio_config)
    # Register callback
    asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
    # Set the model
    asr_functor.set_model({"asr": model_loader.models["asr"]})
    # Start the ASR functor
    asr_functor.run()
    # Create the SPK functor
    spk_functor = SPKFunctor()
    # Set its input queue (fed by the VAD callback)
    spk_functor.set_input_queue(vad2spk_queue)
    # Set the audio config
    spk_functor.set_audio_config(audio_config)
    # Register callback
    spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
    # Set the model (placeholder string instead of a loaded model)
    spk_functor.set_model(
        {
            # 'spk': model_loader.models['spk']
            "spk": "fake_spk"
        }
    )
    # Start the SPK functor
    spk_functor.run()
    f_binary = f_data
    audio_clip_len = 200
    print(
        f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
    )
    # Feed the audio to the VAD functor in slices of audio_clip_len items
    for i in range(0, len(f_binary), audio_clip_len):
        binary_data = f_binary[i : i + audio_clip_len]
        input_queue.put(binary_data)
    # Stop the VAD functor (original note: wait for the VAD functor to finish)
    vad_functor.stop()
    print("[vad_test] VAD函数器结束")
    asr_functor.stop()
    print("[vad_test] ASR函数器结束")
    # NOTE(review): spk_functor.run() is called above but spk_functor.stop()
    # is never called in this function — confirm whether that is intended.
    # Save the collected audio data
    if OVERWATCH:
        for index in range(len(audio_binary_data_list)):
            save_path = f"tests/vad_test_output_{index}.wav"
            soundfile.write(
                save_path, audio_binary_data_list[index].binary_data, sample_rate
            )

121
tests/modelsuse.py Normal file
View File

@ -0,0 +1,121 @@
"""
模型使用测试
此处主要用于各类调用模型的处理数据与输出格式
请在主目录下test_main.py中调用
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配
"""
from funasr import AutoModel
from typing import List, Dict, Any
from src.models import VADResponse
import time
def vad_model_use_online(file_path: str) -> "VADResponse":
    """Run the fsmn-vad model in streaming mode over an audio file.

    The audio is split into ``chunk_size``-millisecond chunks that are fed to
    the model one at a time with a shared ``cache``; the last chunk is
    flagged ``is_final``. Every non-empty model result is merged into one
    ``VADResponse``, which is returned after ``process_time_chunk()`` has
    been called on it.

    Args:
        file_path: Path of the audio file to analyse.

    Returns:
        The accumulated ``VADResponse``.
    """
    import soundfile

    chunk_size = 100  # ms
    model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)

    vad_result = VADResponse()
    vad_result.time_chunk_index_callback = lambda index: print(f"回调: {index}")

    speech, sample_rate = soundfile.read(file_path)
    chunk_stride = int(chunk_size * sample_rate / 1000)

    cache = {}
    # Ceiling division: number of chunk_stride-sized chunks covering speech.
    # (The previous ``len((speech) - 1)`` subtracted 1 from the samples, not
    # from the length, which added an empty trailing chunk whenever the
    # length was an exact multiple of chunk_stride.)
    total_chunk_num = (len(speech) - 1) // chunk_stride + 1
    for i in range(total_chunk_num):
        time.sleep(0.1)  # pace the chunks, one every 0.1 s
        speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
        is_final = i == total_chunk_num - 1
        res = model.generate(
            input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
        )
        if len(res[0]["value"]):
            vad_result += VADResponse.from_raw(res)

    vad_result.process_time_chunk()
    return vad_result
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
    """
    Streaming VAD usage driven through LogicTrager.
    Per the original note, LogicTrager is deprecated after the Rebuild version.

    Reads the audio file, slices it into 200 ms chunks and pushes each slice
    into a LogicTrager instance. Always returns None.
    """
    from src.logic_trager import LogicTrager
    import soundfile
    from src.config import parse_args

    args = parse_args()
    # from src.functor.model_loader import load_models
    # models = load_models(args)
    from src.model_loader import ModelLoader

    # NOTE(review): other call sites in this project construct ModelLoader()
    # without arguments and then call load_models(args) — confirm this call.
    models = ModelLoader(args)
    chunk_size = 200  # ms
    from src.models import AudioBinary_Config
    import soundfile

    speech, sample_rate = soundfile.read(file_path)
    chunk_stride = int(chunk_size * sample_rate / 1000)
    audio_config = AudioBinary_Config(
        sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
    )
    logic_trager = LogicTrager(models=models, audio_config=audio_config)
    # The "+ 1" covers the trailing partial chunk; when len(speech) is an
    # exact multiple of chunk_stride the last slice pushed is empty.
    for i in range(len(speech) // chunk_stride + 1):
        speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
        logic_trager.push_binary_data(speech_chunk)
    # for item in items:
    # print(item)
    return None
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
    """Run offline ASR (with VAD and speaker models attached) on a file.

    Builds a FunASR ``AutoModel`` combining paraformer-zh, fsmn-vad and
    cam++, reads the audio file and passes the loaded audio to
    ``model.generate``.

    Args:
        file_path: Path of the audio file to transcribe.

    Returns:
        Whatever ``model.generate`` returns for the loaded audio.
    """
    from funasr import AutoModel
    import soundfile

    model = AutoModel(
        model="paraformer-zh",
        model_revision="v2.0.4",
        vad_model="fsmn-vad",
        vad_model_revision="v2.0.4",
        # punc_model="ct-punc-c", punc_model_revision="v2.0.4",
        spk_model="cam++",
        spk_model_revision="v2.0.2",
        spk_mode="vad_segment",
        auto_update=False,
    )

    # The sample rate returned by soundfile is not used here.
    speech, _ = soundfile.read(file_path)
    return model.generate(speech)
# if __name__ == "__main__":
# 请在主目录下调用test_main.py文件进行测试
# vad_result = vad_model_use_online("tests/vad_example.wav")
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# print(vad_result)

View File

@ -0,0 +1,93 @@
"""
Pipeline测试
VAD+ASR+SPK(FAKE)
"""
from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader
from queue import Queue
import soundfile
import time
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
# NOTE(review): the name looks like a misspelling of OVERWATCH (the spelling
# used in tests/functor/vad_test.py); it is not read anywhere in this module.
OVAERWATCH = False
# Module-level ModelLoader instance used by test_asr_pipeline.
model_loader = ModelLoader()
def test_asr_pipeline():
    """Build the "ASRpipeline" via PipelineFactory and feed it the example WAV.

    Audio is pushed into the pipeline's input queue in slices of 200 items;
    results are printed by the registered callback. The function then sleeps
    for 5 seconds before stopping the pipeline.
    """
    # Load models
    # NOTE(review): the key "audio_update" differs from "auto_update" used in
    # tests/functor/vad_test.py — possibly a typo; confirm against ModelLoader.
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
        "audio_update": False,
    }
    models = model_loader.load_models(args)
    audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
    audio_config = AudioBinary_Config(
        chunk_size=200,
        chunk_stride=1600,
        sample_rate=sample_rate,
        sample_width=16,
        channels=1,
    )
    # Recompute chunk_stride from chunk_size and the file's sample rate,
    # overriding the 1600 passed to the constructor above.
    chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
    audio_config.chunk_stride = chunk_stride
    # Build the config dict
    config = {
        "audio_config": audio_config,
    }
    # Create the audio data list
    audio_binary_data_list = AudioBinary_data_list()
    input_queue = Queue()
    # Create the pipeline (manual construction kept below for reference)
    # asr_pipeline = ASRPipeline()
    # asr_pipeline.set_models(models)
    # asr_pipeline.set_config(config)
    # asr_pipeline.set_audio_binary(audio_binary_data_list)
    # asr_pipeline.set_input_queue(input_queue)
    # asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
    # asr_pipeline.bake()
    asr_pipeline = PipelineFactory.create_pipeline(
        pipeline_name = "ASRpipeline",
        models=models,
        config=config,
        audio_binary=audio_binary_data_list,
        input_queue=input_queue,
        callback=lambda x: print(f"pipeline callback: {x}")
    )
    # Run the pipeline; the returned value is only used by the disabled
    # join() call below.
    asr_instance = asr_pipeline.run()
    audio_clip_len = 200
    print(
        f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
    )
    # Feed the audio in slices of audio_clip_len items
    for i in range(0, len(audio_data), audio_clip_len):
        input_queue.put(audio_data[i : i + audio_clip_len])
    # time.sleep(10)
    # input_queue.put(None)
    # Wait for the pipeline to finish
    # asr_instance.join()
    # Fixed 5-second wait before stopping, used instead of the join() above.
    time.sleep(5)
    asr_pipeline.stop()
    # asr_pipeline.stop()

View File

@ -10,23 +10,23 @@ import os
from unittest.mock import patch
# 将src目录添加到路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.config import parse_args
def test_default_args():
"""测试默认参数值"""
with patch('sys.argv', ['script.py']):
with patch("sys.argv", ["script.py"]):
args = parse_args()
# 检查服务器参数
assert args.host == "0.0.0.0"
assert args.port == 10095
# 检查SSL参数
assert args.certfile == ""
assert args.keyfile == ""
# 检查模型参数
assert "paraformer" in args.asr_model
assert args.asr_model_revision == "v2.0.4"
@ -36,7 +36,7 @@ def test_default_args():
assert args.vad_model_revision == "v2.0.4"
assert "punc" in args.punc_model
assert args.punc_model_revision == "v2.0.4"
# 检查硬件配置
assert args.ngpu == 1
assert args.device == "cuda"
@ -46,19 +46,26 @@ def test_default_args():
def test_custom_args():
"""测试自定义参数值"""
test_args = [
'script.py',
'--host', 'localhost',
'--port', '8080',
'--certfile', 'cert.pem',
'--keyfile', 'key.pem',
'--asr_model', 'custom_model',
'--ngpu', '0',
'--device', 'cpu'
"script.py",
"--host",
"localhost",
"--port",
"8080",
"--certfile",
"cert.pem",
"--keyfile",
"key.pem",
"--asr_model",
"custom_model",
"--ngpu",
"0",
"--device",
"cpu",
]
with patch('sys.argv', test_args):
with patch("sys.argv", test_args):
args = parse_args()
# 检查自定义参数
assert args.host == "localhost"
assert args.port == 8080
@ -66,4 +73,4 @@ def test_custom_args():
assert args.keyfile == "key.pem"
assert args.asr_model == "custom_model"
assert args.ngpu == 0
assert args.device == "cpu"
assert args.device == "cpu"

BIN
tests/vad_example.wav Normal file

Binary file not shown.