[Feature] Add /tests/modelsuse to test real-time VAD detection.

This commit is contained in:
Keeeer 2025-04-15 13:53:06 +08:00
parent 86e5425787
commit 8b69ff195f
8 changed files with 506 additions and 196 deletions

172
src/audiochunk.py Normal file
View File

@@ -0,0 +1,172 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Audio chunk manager - stores and processes 16 kHz audio data.
"""
import numpy as np
import logging
from typing import Optional, Union

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('AudioChunk')


class AudioChunk:
    """Audio chunk manager for storing and processing 16 kHz audio data."""

    def __init__(self,
                 max_duration_ms: int = 1000*60*60*10,
                 sample_rate: int = 16000,
                 sample_width: int = 2,
                 channels: int = 1):
        """
        Initialize the audio chunk manager.

        Args:
            max_duration_ms: maximum retention time of the audio pool (ms); default 10 hours
            sample_rate: sample rate; default 16 kHz
            sample_width: bytes per sample; default 2 (16-bit)
            channels: number of channels; default 1 (mono)
        """
        # Audio parameters
        self.sample_rate = sample_rate    # sample rate: 16 kHz
        self.sample_width = sample_width  # sample width: 16-bit
        self.channels = channels          # channels: mono
        # Data storage
        self._max_duration_ms = max_duration_ms
        self._max_chunk_size = self._time2size(max_duration_ms)  # maximum pool size in bytes
        self._chunk = []        # list of stored audio chunks
        self._chunk_size = 0    # total stored size in bytes
        self._offset = 0        # current read offset (ms)
        logger.info(f"Initialized AudioChunk: max duration={max_duration_ms}ms, max size={self._max_chunk_size} bytes")
    def add_chunk(self, chunk: Union[bytes, np.ndarray]) -> bool:
        """
        Add an audio chunk.

        Args:
            chunk: audio data, either bytes or a numpy array

        Returns:
            bool: whether the chunk was added successfully
        """
        try:
            # Normalize the input format
            if isinstance(chunk, np.ndarray):
                # Ensure 16-bit integer samples
                if chunk.dtype != np.int16:
                    chunk = chunk.astype(np.int16)
                # Convert to bytes
                chunk = chunk.tobytes()
            # The data must contain a whole number of frames
            if len(chunk) % (self.sample_width * self.channels) != 0:
                logger.warning(f"Audio chunk size {len(chunk)} is not a multiple of {self.sample_width * self.channels}")
                return False
            # Evict old data when the pool would exceed its maximum size
            if self._chunk_size + len(chunk) > self._max_chunk_size:
                logger.warning("Audio pool exceeds its maximum size; clearing old data")
                self.clear_chunk()
            # Store the chunk
            self._chunk.append(chunk)
            self._chunk_size += len(chunk)
            return True
        except Exception as e:
            logger.error(f"Error adding audio chunk: {e}")
            return False
    def get_chunk(self, start_ms: int = 0, end_ms: Optional[int] = None) -> Optional[bytes]:
        """
        Get the audio data in a given time range.

        Args:
            start_ms: start time (ms)
            end_ms: end time (ms); None means up to the end

        Returns:
            Optional[bytes]: the audio data, or None on failure
        """
        try:
            if not self._chunk:
                return None
            # Convert times to byte offsets
            start_byte = self._time2size(start_ms)
            end_byte = self._time2size(end_ms) if end_ms is not None else self._chunk_size
            # Validate the range
            if start_byte >= self._chunk_size or start_byte >= end_byte:
                return None
            # Extract the data
            data = b''.join(self._chunk)
            return data[start_byte:end_byte]
        except Exception as e:
            logger.error(f"Error getting audio chunk: {e}")
            return None
    def get_duration(self) -> int:
        """
        Get the total duration (ms) of the stored audio.

        Returns:
            int: duration in ms
        """
        return self._size2time(self._chunk_size)

    def clear_chunk(self) -> None:
        """Clear all stored audio data."""
        self._chunk = []
        self._chunk_size = 0
        self._offset = 0
        logger.info("Cleared all audio data")

    def _time2size(self, time_ms: int) -> int:
        """
        Convert a duration (ms) to a data size (bytes).
        With the defaults (16 kHz, 16-bit, mono) this is 32 bytes per ms,
        e.g. 100 ms corresponds to 3200 bytes.

        Args:
            time_ms: duration (ms)

        Returns:
            int: size in bytes
        """
        return int(time_ms * self.sample_rate * self.sample_width * self.channels / 1000)

    def _size2time(self, size: int) -> int:
        """
        Convert a data size (bytes) to a duration (ms).

        Args:
            size: size in bytes

        Returns:
            int: duration (ms)
        """
        return int(size * 1000 / (self.sample_rate * self.sample_width * self.channels))
    # instance(start_ms, end_ms, use_offset=True)
    def __call__(self, start_ms: int = 0, end_ms: Optional[int] = None, use_offset: bool = True) -> Optional[bytes]:
        """
        Get the audio data in a given time range, offset-adjusted when use_offset is True.
        """
        if use_offset:
            start_ms += self._offset
            # end_ms may be None (read to the end); only shift it when set
            if end_ms is not None:
                end_ms += self._offset
        return self.get_chunk(start_ms, end_ms)

    def __len__(self) -> int:
        """
        Get the current total size of the stored audio data.
        """
        return self._chunk_size
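
A minimal usage sketch for AudioChunk (assuming the 16 kHz/16-bit mono defaults above; the random buffer is purely illustrative):

import numpy as np
from src.audiochunk import AudioChunk

pool = AudioChunk()
pcm = (np.random.randn(16000) * 1000).astype(np.int16)  # 1 s of synthetic noise
pool.add_chunk(pcm)                  # accepts np.int16 arrays or raw bytes
first_half = pool.get_chunk(0, 500)  # bytes for 0-500 ms (16000 bytes at 32 bytes/ms)
print(pool.get_duration())           # -> 1000 (ms)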

View File

@@ -1,196 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WebSocket client example - for testing the speech recognition service.
"""
import asyncio
import json
import websockets
import argparse
import numpy as np
import wave
import os


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="FunASR WebSocket client")
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="server host address"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=10095,
        help="server port"
    )
    parser.add_argument(
        "--audio_file",
        type=str,
        required=True,
        help="path of the audio file to recognize"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="2pass",
        choices=["2pass", "online", "offline"],
        help="recognition mode: 2pass (default), online, offline"
    )
    parser.add_argument(
        "--chunk_size",
        type=str,
        default="5,10",
        help="chunk size, in the form 'encoder_size,decoder_size'"
    )
    parser.add_argument(
        "--use_ssl",
        action="store_true",
        help="whether to use an SSL connection"
    )
    return parser.parse_args()
async def send_audio(websocket, audio_file, mode, chunk_size):
    """
    Send an audio file to the server for recognition.

    Args:
        websocket: the WebSocket connection
        audio_file: path of the audio file
        mode: recognition mode
        chunk_size: chunk size
    """
    # Open and read the WAV file
    with wave.open(audio_file, "rb") as wav_file:
        params = wav_file.getparams()
        frames = wav_file.readframes(wav_file.getnframes())

    # Audio file info
    print(f"Audio file: {os.path.basename(audio_file)}")
    print(f"Sample rate: {params.framerate}Hz, channels: {params.nchannels}")
    print(f"Sample width: {params.sampwidth * 8} bit, total frames: {params.nframes}")

    # Build the configuration
    config = {
        "mode": mode,
        "chunk_size": chunk_size,
        "wav_name": os.path.basename(audio_file),
        "is_speaking": True
    }
    # Send the configuration
    await websocket.send(json.dumps(config))

    # Simulate real-time streaming of the audio data
    chunk_size_bytes = 3200  # 100 ms of 16 kHz/16-bit audio per send
    total_chunks = len(frames) // chunk_size_bytes
    print(f"Sending audio data, {total_chunks} chunks in total...")
    try:
        for i in range(0, len(frames), chunk_size_bytes):
            chunk = frames[i:i+chunk_size_bytes]
            await websocket.send(chunk)
            # Send once every 100 ms to simulate real time
            await asyncio.sleep(0.1)
            # Report progress
            if (i // chunk_size_bytes) % 10 == 0:
                print(f"Sent {i // chunk_size_bytes}/{total_chunks} chunks")
        # Send the end-of-speech signal
        await websocket.send(json.dumps({"is_speaking": False}))
        print("Finished sending audio data")
    except Exception as e:
        print(f"Error sending audio: {e}")
async def receive_results(websocket):
    """
    Receive and display recognition results.

    Args:
        websocket: the WebSocket connection
    """
    online_text = ""
    offline_text = ""
    try:
        async for message in websocket:
            # Parse the JSON message returned by the server
            result = json.loads(message)
            mode = result.get("mode", "")
            text = result.get("text", "")
            is_final = result.get("is_final", False)
            # Update the text according to the mode
            if "online" in mode:
                online_text = text
                print(f"\r[online] {online_text}", end="", flush=True)
            elif "offline" in mode:
                offline_text = text
                print(f"\n[offline] {offline_text}")
            # On the final result, print the complete text
            if is_final and offline_text:
                print("\nFinal recognition result:")
                print(f"[offline] {offline_text}")
                return
    except Exception as e:
        print(f"Error receiving results: {e}")
async def main():
    """Main entry point."""
    args = parse_args()
    # WebSocket URI
    protocol = "wss" if args.use_ssl else "ws"
    uri = f"{protocol}://{args.host}:{args.port}"
    print(f"Connecting to server: {uri}")
    try:
        # Open the WebSocket connection
        async with websockets.connect(
            uri,
            subprotocols=["binary"]
        ) as websocket:
            print("Connected")
            # Run two tasks: send audio and receive results
            send_task = asyncio.create_task(
                send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
            )
            receive_task = asyncio.create_task(
                receive_results(websocket)
            )
            # Wait for both tasks to finish
            await asyncio.gather(send_task, receive_task)
    except Exception as e:
        print(f"Failed to connect to server: {e}")


if __name__ == "__main__":
    # Run the main function
    asyncio.run(main())

164
src/logic_trager.py Normal file
View File

@@ -0,0 +1,164 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Logic trigger class - consumes audio data and triggers the corresponding processing logic.
"""
import logging
from typing import Any, Callable, Dict, Optional, Type

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('LogicTrager')
class AutoAfterMeta(type):
    """
    Metaclass that automatically calls the matching __after__ hook
    after each public method, and implements the singleton pattern.
    """
    _instances: Dict[Type, Any] = {}  # singleton instances

    def __new__(cls, name, bases, attrs):
        # Walk the attributes of the class being created
        for attr_name, attr_value in list(attrs.items()):
            # Wrap only public methods; skip _private and __dunder__ names
            # (this also skips the name-mangled __after__ hooks themselves)
            if callable(attr_value) and not attr_name.startswith('_'):
                # Keep a reference to the original function
                original_func = attr_value

                # Create the wrapper
                def make_wrapper(func):
                    def wrapper(self, *args, **kwargs):
                        # Run the original method
                        result = func(self, *args, **kwargs)
                        # Build the hook name; __after__* attributes are
                        # name-mangled inside the class body, so look up
                        # the mangled form as well
                        after_func_name = f"__after__{func.__name__}"
                        mangled_name = f"_{type(self).__name__}{after_func_name}"
                        after_func = getattr(self, after_func_name, None) or \
                            getattr(self, mangled_name, None)
                        if callable(after_func):
                            try:
                                # Invoke the hook
                                after_func()
                            except Exception as e:
                                logger.error(f"Error calling {after_func_name}: {e}")
                        return result
                    return wrapper

                # Replace the original method with the wrapped one
                attrs[attr_name] = make_wrapper(original_func)

        # Create the class
        new_class = super().__new__(cls, name, bases, attrs)
        return new_class
    def __call__(cls, *args, **kwargs):
        """
        Override __call__ to implement the singleton pattern.
        Runs when the class is called, i.e. when an instance is created.
        """
        if cls not in cls._instances:
            # No instance yet: create one
            cls._instances[cls] = super().__call__(*args, **kwargs)
            logger.info(f"Created new instance of {cls.__name__}")
        else:
            logger.debug(f"Returning existing instance of {cls.__name__}")
        return cls._instances[cls]
"""
整体识别的处理逻辑
1.压入二进制音频信息
2.不断检测VAD
3.当检测到完整VAD时,将VAD的音频信息压入音频块,并清除对应二进制信息
4.对音频块进行语音转文字offline,时间戳预测,说话人识别
5.将识别结果整合压入结果队列
6.结果队列被压入时调用回调函数
1->2 __after__push_binary_data 外部压入二进制信息
2,3->4 __after__push_audio_chunk 内部压入音频块
4->5 push_result_queue 压入结果队列
5->6 __after__push_result_queue 调用回调函数
"""
class LogicTrager(metaclass=AutoAfterMeta):
    """Logic trigger class."""

    def __init__(self,
                 audio_chunk_max_size: int = 1024 * 1024 * 10,
                 sample_rate: int = 16000,
                 channels: int = 1,
                 on_result_callback: Optional[Callable] = None,
                 ):
        """Initialize."""
        # Stored audio chunks
        self._audio_chunk = []
        # Stored binary data
        self._audio_chunk_binary = b''
        self._audio_chunk_max_size = audio_chunk_max_size
        # Audio parameters
        self._sample_rate = sample_rate
        self._channels = channels
        # Result queue
        self._result_queue = []
        # Callback
        self._on_result_callback = on_result_callback
        logger.info("Initialized LogicTrager")

    def push_binary_data(self, chunk: bytes) -> None:
        """
        Append binary audio data.

        Args:
            chunk: the audio data chunk
        """
        if self._audio_chunk is None:
            logger.error("AudioChunk is not initialized")
            return
        self._audio_chunk_binary += chunk
        logger.debug(f"Appended audio chunk of {len(chunk)} bytes")

    def __after__push_binary_data(self) -> None:
        """
        Post-hook for push_binary_data:
        run VAD detection and push completed VAD segments as audio chunks.
        """
        # VAD detection (not implemented yet)
        pass

    # Push an audio chunk: push_audio_chunk
    def push_audio_chunk(self, chunk: bytes) -> None:
        """
        Push an audio chunk.
        """
        self._audio_chunk.append(chunk)

    def __after__push_audio_chunk(self) -> None:
        """
        Post-hook for push_audio_chunk.
        """
        pass

    def push_result_queue(self, result: Dict[str, Any]) -> None:
        """
        Push a result onto the result queue.
        """
        self._result_queue.append(result)

    def __after__push_result_queue(self) -> None:
        """
        Post-hook for push_result_queue.
        """
        pass

    def __call__(self):
        """Call handler."""
        pass
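
A short usage sketch of the hook mechanism (the VAD hook body is still a stub in this commit, so only the wiring is shown):

from src.logic_trager import LogicTrager

trager = LogicTrager()                   # singleton: calling it again returns the same instance
trager.push_binary_data(b'\x00' * 3200)  # 100 ms of 16 kHz/16-bit silence
# AutoAfterMeta wraps push_binary_data so that __after__push_binary_data
# (the VAD step, currently a pass-through) runs automatically afterwards.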

View File

@@ -4,6 +4,8 @@
Model loading module - loads the various speech-recognition related models
"""
from typing import List, Optional
def load_models(args):
    """
    Load all required models

127
src/pydantic_models.py Normal file
View File

@@ -0,0 +1,127 @@
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable


class VADSegment(BaseModel):
    """A single VAD segment."""
    start: int = Field(description="start time (ms)")
    end: int = Field(description="end time (ms)")


class VADResult(BaseModel):
    """VAD result for one audio stream."""
    key: str = Field(description="audio identifier")
    value: List[VADSegment] = Field(description="list of VAD segments")


class VADResponse(BaseModel):
    """VAD response."""
    results: List[VADResult] = Field(description="list of VAD results", default_factory=list)
    time_chunk: List[VADSegment] = Field(description="time chunks", default_factory=list)
    time_chunk_index: int = Field(description="index of the time chunk currently being processed", default=0)
    time_chunk_index_callback: Optional[Callable[[int], None]] = Field(
        description="callback invoked per processed time-chunk index",
        default=None
    )

    @validator('time_chunk')
    def validate_time_chunk(cls, v):
        """Validate that the time chunks are ordered."""
        if not v:
            return v
        # Check temporal order
        for i in range(len(v) - 1):
            if v[i].end >= v[i + 1].start:
                raise ValueError(f"time chunk {i} ends at {v[i].end}, which is not before the next chunk's start {v[i + 1].start}")
        return v
    # Invoke the callback for any time chunks not yet processed
    def process_time_chunk(self, callback: Callable[[int], None] = None) -> None:
        """Process pending time chunks."""
        # print("Enter process_time_chunk", self.time_chunk_index, len(self.time_chunk))
        while self.time_chunk_index < len(self.time_chunk) - 1:
            if self.time_chunk[self.time_chunk_index].end != -1:
                if callback is not None:
                    callback(self.time_chunk_index)
                elif self.time_chunk_index_callback is not None:
                    self.time_chunk_index_callback(self.time_chunk_index)
                else:
                    print("[Warning] No callback available")
            self.time_chunk_index += 1

    def __add__(self, other: 'VADResponse') -> 'VADResponse':
        """Merge another VADResponse into this one (mutates and returns self)."""
        if not self.results:
            self.results = other.results
            self.time_chunk = other.time_chunk
            return self
        # Check whether the last result can be merged with the first incoming one
        last_result = self.results[-1]
        first_other = other.results[0]
        if last_result.value[-1].end == first_other.value[0].start:
            # Merge the two adjacent segments
            last_result.value[-1].end = first_other.value[0].end
            first_other.value.pop(0)
            # Update time_chunk accordingly
            self.time_chunk[-1].end = other.time_chunk[0].end
            other.time_chunk.pop(0)
            # Append the remaining results; skip first_other if it was fully
            # merged away, so the results after it are not silently dropped
            if first_other.value:
                self.results.extend(other.results)
            else:
                self.results.extend(other.results[1:])
            self.time_chunk.extend(other.time_chunk)
        else:
            # No adjacency: append everything
            self.results.extend(other.results)
            self.time_chunk.extend(other.time_chunk)
        return self
    @classmethod
    def from_raw(cls, raw_data: List[dict]) -> "VADResponse":
        """
        Build a VADResponse from raw model output.

        Args:
            raw_data: raw data, e.g. [{'key': 'xxx', 'value': [[-1, 59540], [59820, -1]]}]

        Returns:
            VADResponse: the parsed VAD response
        """
        results = []
        time_chunk = []
        for item in raw_data:
            segments = [
                VADSegment(start=seg[0], end=seg[1])
                for seg in item['value']
            ]
            results.append(VADResult(
                key=item['key'],
                value=segments
            ))
            time_chunk.extend(segments)
        return cls(results=results, time_chunk=time_chunk)

    def to_raw(self) -> List[dict]:
        """
        Convert back to the raw data format.

        Returns:
            List[dict]: the raw data
        """
        return [
            {
                'key': result.key,
                'value': [[seg.start, seg.end] for seg in result.value]
            }
            for result in self.results
        ]

    def __str__(self):
        result_str = "VADResponse:\n"
        for result in self.results:
            for value_item in result.value:
                result_str += f"[{value_item.start}:{value_item.end}]\n"
        return result_str
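
A hedged round-trip example for these models (the raw format follows the from_raw docstring; -1 marks a still-open boundary in streaming output):

from src.pydantic_models import VADResponse

raw = [{'key': 'utt1', 'value': [[0, 1500], [2000, 3200]]}]
resp = VADResponse.from_raw(raw)
resp.process_time_chunk(lambda i: print(f"segment {i} closed"))  # -> "segment 0 closed"
print(resp.to_raw())  # -> [{'key': 'utt1', 'value': [[0, 1500], [2000, 3200]]}]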

4
test_main.py Normal file
View File

@@ -0,0 +1,4 @@
from tests.modelsuse import vad_model_use_online
vad_result = vad_model_use_online("tests/vad_example.wav")
print(vad_result)
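
Assuming the repository root is the working directory (so that tests.modelsuse and src.pydantic_models resolve as packages), this can be run as:

python test_main.py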

37
tests/modelsuse.py Normal file
View File

@@ -0,0 +1,37 @@
from funasr import AutoModel
import soundfile
import time

from src.pydantic_models import VADResponse


def vad_model_use_online(file_path: str) -> VADResponse:
    chunk_size = 100  # ms per streaming step
    model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
    vad_result = VADResponse()
    vad_result.time_chunk_index_callback = lambda index: print(f"Callback: {index}")
    items = []
    speech, sample_rate = soundfile.read(file_path)
    chunk_stride = int(chunk_size * sample_rate / 1000)
    cache = {}
    # Number of streaming chunks needed to cover the whole file
    total_chunk_num = int((len(speech) - 1) / chunk_stride) + 1
    for i in range(total_chunk_num):
        time.sleep(0.1)  # simulate real-time arrival of 100 ms chunks
        speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
        is_final = i == total_chunk_num - 1
        res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
        if len(res[0]["value"]):
            vad_result += VADResponse.from_raw(res)
            for item in res[0]["value"]:
                items.append(item)
            vad_result.process_time_chunk()
    # for item in items:
    #     print(item)
    return vad_result


if __name__ == "__main__":
    vad_result = vad_model_use_online("tests/vad_example.wav")
    # print(vad_result)
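
A sketch of how these pieces could compose (hypothetical glue code, not part of this commit): the VAD segment boundaries are used to slice audio back out of an AudioChunk pool.

import soundfile
from src.audiochunk import AudioChunk
from tests.modelsuse import vad_model_use_online

pool = AudioChunk()
speech, sr = soundfile.read("tests/vad_example.wav", dtype="int16")  # 16 kHz mono assumed
pool.add_chunk(speech)

vad_result = vad_model_use_online("tests/vad_example.wav")
for seg in vad_result.time_chunk:
    if seg.start != -1 and seg.end != -1:  # skip still-open boundaries
        segment = pool.get_chunk(seg.start, seg.end)
        print(seg.start, seg.end, len(segment))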

BIN
tests/vad_example.wav Normal file

Binary file not shown.