Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
1392168126 | ||
![]() |
eff22cb33e | ||
![]() |
66c9477e4b | ||
9d522fa137 | |||
f7138dcb39 | |||
8b69ff195f |
196
src/client.py
196
src/client.py
@ -1,196 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
WebSocket客户端示例 - 用于测试语音识别服务
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import websockets
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
|
||||||
import wave
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
"""解析命令行参数"""
|
|
||||||
parser = argparse.ArgumentParser(description="FunASR WebSocket客户端")
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
type=str,
|
|
||||||
default="localhost",
|
|
||||||
help="服务器主机地址"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
default=10095,
|
|
||||||
help="服务器端口"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--audio_file",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="要识别的音频文件路径"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--mode",
|
|
||||||
type=str,
|
|
||||||
default="2pass",
|
|
||||||
choices=["2pass", "online", "offline"],
|
|
||||||
help="识别模式: 2pass(默认), online, offline"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--chunk_size",
|
|
||||||
type=str,
|
|
||||||
default="5,10",
|
|
||||||
help="分块大小, 格式为'encoder_size,decoder_size'"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_ssl",
|
|
||||||
action="store_true",
|
|
||||||
help="是否使用SSL连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
async def send_audio(websocket, audio_file, mode, chunk_size):
|
|
||||||
"""
|
|
||||||
发送音频文件到服务器进行识别
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
audio_file: 音频文件路径
|
|
||||||
mode: 识别模式
|
|
||||||
chunk_size: 分块大小
|
|
||||||
"""
|
|
||||||
# 打开并读取WAV文件
|
|
||||||
with wave.open(audio_file, "rb") as wav_file:
|
|
||||||
params = wav_file.getparams()
|
|
||||||
frames = wav_file.readframes(wav_file.getnframes())
|
|
||||||
|
|
||||||
# 音频文件信息
|
|
||||||
print(f"音频文件: {os.path.basename(audio_file)}")
|
|
||||||
print(f"采样率: {params.framerate}Hz, 通道数: {params.nchannels}")
|
|
||||||
print(f"采样位深: {params.sampwidth * 8}位, 总帧数: {params.nframes}")
|
|
||||||
|
|
||||||
# 设置配置参数
|
|
||||||
config = {
|
|
||||||
"mode": mode,
|
|
||||||
"chunk_size": chunk_size,
|
|
||||||
"wav_name": os.path.basename(audio_file),
|
|
||||||
"is_speaking": True
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送配置
|
|
||||||
await websocket.send(json.dumps(config))
|
|
||||||
|
|
||||||
# 模拟实时发送音频数据
|
|
||||||
chunk_size_bytes = 3200 # 每次发送100ms的16kHz音频
|
|
||||||
total_chunks = len(frames) // chunk_size_bytes
|
|
||||||
|
|
||||||
print(f"开始发送音频数据,共 {total_chunks} 个数据块...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
for i in range(0, len(frames), chunk_size_bytes):
|
|
||||||
chunk = frames[i:i+chunk_size_bytes]
|
|
||||||
await websocket.send(chunk)
|
|
||||||
|
|
||||||
# 模拟实时,每100ms发送一次
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
# 显示进度
|
|
||||||
if (i // chunk_size_bytes) % 10 == 0:
|
|
||||||
print(f"已发送 {i // chunk_size_bytes}/{total_chunks} 数据块")
|
|
||||||
|
|
||||||
# 发送结束信号
|
|
||||||
await websocket.send(json.dumps({"is_speaking": False}))
|
|
||||||
print("音频数据发送完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"发送音频时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def receive_results(websocket):
|
|
||||||
"""
|
|
||||||
接收并显示识别结果
|
|
||||||
|
|
||||||
参数:
|
|
||||||
websocket: WebSocket连接
|
|
||||||
"""
|
|
||||||
online_text = ""
|
|
||||||
offline_text = ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for message in websocket:
|
|
||||||
# 解析服务器返回的JSON消息
|
|
||||||
result = json.loads(message)
|
|
||||||
|
|
||||||
mode = result.get("mode", "")
|
|
||||||
text = result.get("text", "")
|
|
||||||
is_final = result.get("is_final", False)
|
|
||||||
|
|
||||||
# 根据模式更新文本
|
|
||||||
if "online" in mode:
|
|
||||||
online_text = text
|
|
||||||
print(f"\r[在线识别] {online_text}", end="", flush=True)
|
|
||||||
elif "offline" in mode:
|
|
||||||
offline_text = text
|
|
||||||
print(f"\n[离线识别] {offline_text}")
|
|
||||||
|
|
||||||
# 如果是最终结果,打印完整信息
|
|
||||||
if is_final and offline_text:
|
|
||||||
print("\n最终识别结果:")
|
|
||||||
print(f"[离线识别] {offline_text}")
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"接收结果时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""主函数"""
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
# WebSocket URI
|
|
||||||
protocol = "wss" if args.use_ssl else "ws"
|
|
||||||
uri = f"{protocol}://{args.host}:{args.port}"
|
|
||||||
|
|
||||||
print(f"连接到服务器: {uri}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建WebSocket连接
|
|
||||||
async with websockets.connect(
|
|
||||||
uri,
|
|
||||||
subprotocols=["binary"]
|
|
||||||
) as websocket:
|
|
||||||
|
|
||||||
print("连接成功")
|
|
||||||
|
|
||||||
# 创建两个任务: 发送音频和接收结果
|
|
||||||
send_task = asyncio.create_task(
|
|
||||||
send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
receive_task = asyncio.create_task(
|
|
||||||
receive_results(websocket)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 等待任务完成
|
|
||||||
await asyncio.gather(send_task, receive_task)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"连接服务器失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 运行主函数
|
|
||||||
asyncio.run(main())
|
|
4
src/functor/__init__.py
Normal file
4
src/functor/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .vad_functor import VAD
|
||||||
|
from .model_loader import load_models
|
||||||
|
|
||||||
|
__all__ = ["VAD", "load_models"]
|
178
src/functor/audiochunk.py
Normal file
178
src/functor/audiochunk.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
音频数据块管理类 - 用于存储和处理16KHz音频数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
|
class AudioChunk:
|
||||||
|
"""音频数据块管理类,用于存储和处理16KHz音频数据"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
max_duration_ms: int = 1000*60*60*10,
|
||||||
|
audio_config : AudioBinary_Config = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化音频数据块管理器
|
||||||
|
|
||||||
|
参数:
|
||||||
|
max_duration_ms: 音频池最大留存时间(ms),默认10小时
|
||||||
|
audio_config: 音频配置信息
|
||||||
|
"""
|
||||||
|
# 音频参数
|
||||||
|
self.sample_rate = audio_config.sample_rate if audio_config is not None else 16000 # 采样率:16KHz
|
||||||
|
self.sample_width = audio_config.sample_width if audio_config is not None else 2 # 采样位宽:16bit
|
||||||
|
self.channels = audio_config.channels if audio_config is not None else 1 # 通道数:单声道
|
||||||
|
|
||||||
|
# 数据存储
|
||||||
|
self._max_duration_ms = max_duration_ms
|
||||||
|
self._max_chunk_size = self._time2size(max_duration_ms) # 最大数据大小
|
||||||
|
self._chunk = [] # 当前音频数据块列表
|
||||||
|
self._chunk_size = 0 # 当前数据总大小
|
||||||
|
self._offset = 0 # 当前偏移量
|
||||||
|
|
||||||
|
logger.info(f"初始化AudioChunk: 最大时长={max_duration_ms}ms, 最大数据大小={self._max_chunk_size}字节")
|
||||||
|
|
||||||
|
def add_chunk(self, chunk: Union[bytes, np.ndarray]) -> bool:
|
||||||
|
"""
|
||||||
|
添加音频数据块
|
||||||
|
|
||||||
|
参数:
|
||||||
|
chunk: 音频数据块,可以是bytes或numpy数组
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否添加成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查数据格式
|
||||||
|
if isinstance(chunk, np.ndarray):
|
||||||
|
# 确保是16bit整数格式
|
||||||
|
if chunk.dtype != np.int16:
|
||||||
|
chunk = chunk.astype(np.int16)
|
||||||
|
# 转换为bytes
|
||||||
|
chunk = chunk.tobytes()
|
||||||
|
|
||||||
|
# 检查数据大小
|
||||||
|
if len(chunk) % (self.sample_width * self.channels) != 0:
|
||||||
|
logger.warning(f"音频数据大小不是{self.sample_width * self.channels}的倍数: {len(chunk)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否超过最大限制
|
||||||
|
if self._chunk_size + len(chunk) > self._max_chunk_size:
|
||||||
|
logger.warning("音频数据超过最大限制,将自动清除旧数据")
|
||||||
|
self.clear_chunk()
|
||||||
|
|
||||||
|
# 添加数据
|
||||||
|
self._chunk.append(chunk)
|
||||||
|
self._chunk_size += len(chunk)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加音频数据块时出错: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_chunk_binary(self, start: int = 0, end: Optional[int] = None) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
获取指定索引的音频数据块
|
||||||
|
"""
|
||||||
|
print("[AudioChunk] get_chunk_binary", start, end)
|
||||||
|
if start >= len(self._chunk):
|
||||||
|
return None
|
||||||
|
if end is None or end > len(self._chunk):
|
||||||
|
end = len(self._chunk)
|
||||||
|
data = b''.join(self._chunk)
|
||||||
|
return data[start:end]
|
||||||
|
|
||||||
|
def get_chunk(self, start_ms: int = 0, end_ms: Optional[int] = None) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
获取指定时间范围的音频数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
start_ms: 开始时间(ms)
|
||||||
|
end_ms: 结束时间(ms),None表示到末尾
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Optional[bytes]: 音频数据,如果获取失败则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self._chunk:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算字节偏移
|
||||||
|
start_byte = self._time2size(start_ms)
|
||||||
|
end_byte = self._time2size(end_ms) if end_ms is not None else self._chunk_size
|
||||||
|
|
||||||
|
# 检查范围是否有效
|
||||||
|
if start_byte >= self._chunk_size or start_byte >= end_byte:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取数据
|
||||||
|
data = b''.join(self._chunk)
|
||||||
|
return data[start_byte:end_byte]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取音频数据块时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_duration(self) -> int:
|
||||||
|
"""
|
||||||
|
获取当前音频总时长(ms)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 音频时长(ms)
|
||||||
|
"""
|
||||||
|
return self._size2time(self._chunk_size)
|
||||||
|
|
||||||
|
def clear_chunk(self) -> None:
|
||||||
|
"""清除所有音频数据"""
|
||||||
|
self._chunk = []
|
||||||
|
self._chunk_size = 0
|
||||||
|
self._offset = 0
|
||||||
|
logger.info("已清除所有音频数据")
|
||||||
|
|
||||||
|
def _time2size(self, time_ms: int) -> int:
|
||||||
|
"""
|
||||||
|
将时间(ms)转换为数据大小(字节)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
time_ms: 时间(ms)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 数据大小(字节)
|
||||||
|
"""
|
||||||
|
return int(time_ms * self.sample_rate * self.sample_width * self.channels / 1000)
|
||||||
|
|
||||||
|
def _size2time(self, size: int) -> int:
|
||||||
|
"""
|
||||||
|
将数据大小(字节)转换为时间(ms)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
size: 数据大小(字节)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 时间(ms)
|
||||||
|
"""
|
||||||
|
return int(size * 1000 / (self.sample_rate * self.sample_width * self.channels))
|
||||||
|
|
||||||
|
# instance(start_ms, end_ms, use_offset=True)
|
||||||
|
def __call__(self, start_ms: int = 0, end_ms: Optional[int] = None, use_offset: bool = True) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
获取指定时间范围的音频数据
|
||||||
|
"""
|
||||||
|
if use_offset:
|
||||||
|
start_ms += self._offset
|
||||||
|
end_ms += self._offset
|
||||||
|
return self.get_chunk(start_ms, end_ms)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
获取当前音频数据块大小
|
||||||
|
"""
|
||||||
|
return self._chunk_size
|
@ -4,6 +4,8 @@
|
|||||||
模型加载模块 - 负责加载各种语音识别相关模型
|
模型加载模块 - 负责加载各种语音识别相关模型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
def load_models(args):
|
def load_models(args):
|
||||||
"""
|
"""
|
||||||
加载所有需要的模型
|
加载所有需要的模型
|
||||||
@ -33,6 +35,7 @@ def load_models(args):
|
|||||||
device=args.device,
|
device=args.device,
|
||||||
disable_pbar=True,
|
disable_pbar=True,
|
||||||
disable_log=True,
|
disable_log=True,
|
||||||
|
disable_update=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 加载在线ASR模型
|
# 2. 加载在线ASR模型
|
||||||
@ -45,6 +48,7 @@ def load_models(args):
|
|||||||
device=args.device,
|
device=args.device,
|
||||||
disable_pbar=True,
|
disable_pbar=True,
|
||||||
disable_log=True,
|
disable_log=True,
|
||||||
|
disable_update=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 加载VAD模型
|
# 3. 加载VAD模型
|
||||||
@ -57,6 +61,7 @@ def load_models(args):
|
|||||||
device=args.device,
|
device=args.device,
|
||||||
disable_pbar=True,
|
disable_pbar=True,
|
||||||
disable_log=True,
|
disable_log=True,
|
||||||
|
disable_update=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 加载标点符号模型(如果指定)
|
# 4. 加载标点符号模型(如果指定)
|
||||||
@ -70,6 +75,7 @@ def load_models(args):
|
|||||||
device=args.device,
|
device=args.device,
|
||||||
disable_pbar=True,
|
disable_pbar=True,
|
||||||
disable_log=True,
|
disable_log=True,
|
||||||
|
disable_update=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
models["punc"] = None
|
models["punc"] = None
|
60
src/functor/vad_functor.py
Normal file
60
src/functor/vad_functor.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from funasr import AutoModel
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from src.models import VADResponse
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
from src.functor.audiochunk import AudioChunk
|
||||||
|
from src.models import AudioBinary_Chunk
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
class VAD:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
VAD_model = None,
|
||||||
|
audio_config : AudioBinary_Config = None,
|
||||||
|
callback: Callable = None,
|
||||||
|
):
|
||||||
|
# vad model
|
||||||
|
self.VAD_model = VAD_model
|
||||||
|
if self.VAD_model is None:
|
||||||
|
self.VAD_model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||||
|
# audio config
|
||||||
|
self.audio_config = audio_config
|
||||||
|
# vad result
|
||||||
|
self.vad_result = VADResponse(time_chunk_index_callback=callback)
|
||||||
|
# audio binary poll
|
||||||
|
self.audio_chunk = AudioChunk(
|
||||||
|
audio_config=self.audio_config
|
||||||
|
)
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
def push_binary_data(self,
|
||||||
|
binary_data: bytes,
|
||||||
|
):
|
||||||
|
# 压入二进制数据
|
||||||
|
self.audio_chunk.add_chunk(binary_data)
|
||||||
|
# 处理音频块
|
||||||
|
res = self.VAD_model.generate(input=binary_data,
|
||||||
|
cache=self.cache,
|
||||||
|
chunk_size=self.audio_config.chunk_size,
|
||||||
|
is_final=False)
|
||||||
|
# print("VAD generate", res)
|
||||||
|
if len(res[0]["value"]):
|
||||||
|
self.vad_result += VADResponse.from_raw(res)
|
||||||
|
|
||||||
|
def set_callback(self,
|
||||||
|
callback: Callable,
|
||||||
|
):
|
||||||
|
self.vad_result.time_chunk_index_callback = callback
|
||||||
|
|
||||||
|
def process_vad_result(self, callback: Callable = None):
|
||||||
|
# 处理VAD结果
|
||||||
|
callback = callback if callback is not None else self.vad_result.time_chunk_index_callback
|
||||||
|
self.vad_result.process_time_chunk(
|
||||||
|
lambda x : callback(
|
||||||
|
AudioBinary_Chunk(
|
||||||
|
start_time=x["start_time"],
|
||||||
|
end_time=x["end_time"],
|
||||||
|
chunk=self.audio_chunk.get_chunk(x["start_time"], x["end_time"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
165
src/logic_trager.py
Normal file
165
src/logic_trager.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
逻辑触发器类 - 用于处理音频数据并触发相应的处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
from typing import Any, Dict, Type, Callable
|
||||||
|
# 配置日志
|
||||||
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
|
class AutoAfterMeta(type):
|
||||||
|
"""
|
||||||
|
自动调用__after__函数的元类
|
||||||
|
实现单例模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instances: Dict[Type, Any] = {} # 存储单例实例
|
||||||
|
|
||||||
|
def __new__(cls, name, bases, attrs):
|
||||||
|
# 遍历所有属性
|
||||||
|
for attr_name, attr_value in attrs.items():
|
||||||
|
# 如果是函数且不是以_开头
|
||||||
|
if callable(attr_value) and not attr_name.startswith('__'):
|
||||||
|
# 获取原函数
|
||||||
|
original_func = attr_value
|
||||||
|
|
||||||
|
# 创建包装函数
|
||||||
|
def make_wrapper(func):
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
# 执行原函数
|
||||||
|
result = func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
# 构建_after_函数名
|
||||||
|
after_func_name = f"__after__{func.__name__}"
|
||||||
|
|
||||||
|
# 检查是否存在对应的_after_函数
|
||||||
|
if hasattr(self, after_func_name):
|
||||||
|
after_func = getattr(self, after_func_name)
|
||||||
|
if callable(after_func):
|
||||||
|
try:
|
||||||
|
# 调用_after_函数
|
||||||
|
after_func()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"调用{after_func_name}时出错: {e}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
# 替换原函数
|
||||||
|
attrs[attr_name] = make_wrapper(original_func)
|
||||||
|
|
||||||
|
# 创建类
|
||||||
|
new_class = super().__new__(cls, name, bases, attrs)
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
def __call__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
重写__call__方法实现单例模式
|
||||||
|
当类被调用时(即创建实例时)执行
|
||||||
|
"""
|
||||||
|
if cls not in cls._instances:
|
||||||
|
# 如果实例不存在,创建新实例
|
||||||
|
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||||
|
logger.info(f"创建{cls.__name__}的新实例")
|
||||||
|
else:
|
||||||
|
logger.debug(f"返回{cls.__name__}的现有实例")
|
||||||
|
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
"""
|
||||||
|
整体识别的处理逻辑:
|
||||||
|
1.压入二进制音频信息
|
||||||
|
2.不断检测VAD
|
||||||
|
3.当检测到完整VAD时,将VAD的音频信息压入音频块,并清除对应二进制信息
|
||||||
|
4.对音频块进行语音转文字offline,时间戳预测,说话人识别
|
||||||
|
5.将识别结果整合压入结果队列
|
||||||
|
6.结果队列被压入时调用回调函数
|
||||||
|
|
||||||
|
1->2 __after__push_binary_data 外部压入二进制信息
|
||||||
|
2,3->4 __after__push_audio_chunk 内部压入音频块
|
||||||
|
4->5 push_result_queue 压入结果队列
|
||||||
|
5->6 __after__push_result_queue 调用回调函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.functor import VAD
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
from src.models import AudioBinary_Chunk
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class LogicTrager(metaclass=AutoAfterMeta):
|
||||||
|
"""逻辑触发器类"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
audio_chunk_max_size: int = 1024 * 1024 * 10,
|
||||||
|
audio_config: AudioBinary_Config = None,
|
||||||
|
result_callback: Callable = None,
|
||||||
|
models: Dict[str, Any] = None,
|
||||||
|
):
|
||||||
|
"""初始化"""
|
||||||
|
# 存储音频块
|
||||||
|
self._audio_chunk : List[AudioBinary_Chunk] = []
|
||||||
|
# 存储二进制数据
|
||||||
|
self._audio_chunk_binary = b''
|
||||||
|
self._audio_chunk_max_size = audio_chunk_max_size
|
||||||
|
# 音频参数
|
||||||
|
self._audio_config = audio_config if audio_config is not None else AudioBinary_Config()
|
||||||
|
# 结果队列
|
||||||
|
self._result_queue = []
|
||||||
|
# 聚合结果回调函数
|
||||||
|
self._aggregate_result_callback = result_callback
|
||||||
|
# 组件
|
||||||
|
self._vad = VAD(VAD_model = models.get("vad"), audio_config = self._audio_config)
|
||||||
|
self._vad.set_callback(self.push_audio_chunk)
|
||||||
|
|
||||||
|
|
||||||
|
logger.info("初始化LogicTrager")
|
||||||
|
|
||||||
|
def push_binary_data(self, chunk: bytes) -> None:
|
||||||
|
"""
|
||||||
|
压入音频块至VAD模块
|
||||||
|
|
||||||
|
参数:
|
||||||
|
chunk: 音频数据块
|
||||||
|
"""
|
||||||
|
# print("LogicTrager push_binary_data", len(chunk))
|
||||||
|
self._vad.push_binary_data(chunk)
|
||||||
|
self.__after__push_binary_data()
|
||||||
|
|
||||||
|
def __after__push_binary_data(self) -> None:
|
||||||
|
"""
|
||||||
|
添加音频块后处理
|
||||||
|
"""
|
||||||
|
# print("LogicTrager __after__push_binary_data")
|
||||||
|
self._vad.process_vad_result()
|
||||||
|
|
||||||
|
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
|
||||||
|
"""
|
||||||
|
音频处理
|
||||||
|
"""
|
||||||
|
logger.info("LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(chunk.start_time, chunk.end_time, len(chunk.chunk)))
|
||||||
|
self._audio_chunk.append(chunk)
|
||||||
|
|
||||||
|
def __after__push_audio_chunk(self) -> None:
|
||||||
|
"""
|
||||||
|
压入音频块后处理
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def push_result_queue(self, result: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
压入结果队列
|
||||||
|
"""
|
||||||
|
self._result_queue.append(result)
|
||||||
|
|
||||||
|
def __after__push_result_queue(self) -> None:
|
||||||
|
"""
|
||||||
|
压入结果队列后处理
|
||||||
|
"""
|
||||||
|
logger.info("FINISH Result=")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
"""调用函数"""
|
||||||
|
pass
|
3
src/models/__init__.py
Normal file
3
src/models/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .audiobinary import AudioBinary_Config, AudioBinary_Chunk
|
||||||
|
from .vad import VADResponse
|
||||||
|
__all__ = ["AudioBinary_Config", "AudioBinary_Chunk", "VADResponse"]
|
16
src/models/audiobinary.py
Normal file
16
src/models/audiobinary.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class AudioBinary_Config(BaseModel):
|
||||||
|
"""二进制音频块配置信息"""
|
||||||
|
audio_data: bytes = Field(description="音频数据", default=None)
|
||||||
|
chunk_size: int = Field(description="块大小", default=100)
|
||||||
|
chunk_stride: int = Field(description="块步长", default=1600)
|
||||||
|
sample_rate: int = Field(description="采样率", default=16000)
|
||||||
|
sample_width: int = Field(description="采样位宽", default=2)
|
||||||
|
channels: int = Field(description="通道数", default=1)
|
||||||
|
|
||||||
|
class AudioBinary_Chunk(BaseModel):
|
||||||
|
"""音频块"""
|
||||||
|
start_time: int = Field(description="开始时间", default=0)
|
||||||
|
end_time: int = Field(description="结束时间", default=0)
|
||||||
|
chunk: bytes = Field(description="音频块", default=None)
|
143
src/models/vad.py
Normal file
143
src/models/vad.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from typing import List, Optional, Callable
|
||||||
|
|
||||||
|
class VADSegment(BaseModel):
|
||||||
|
"""VAD片段"""
|
||||||
|
start: int = Field(description="开始时间(ms)")
|
||||||
|
end: int = Field(description="结束时间(ms)")
|
||||||
|
|
||||||
|
class VADResult(BaseModel):
|
||||||
|
"""VAD结果"""
|
||||||
|
key: str = Field(description="音频标识")
|
||||||
|
value: List[VADSegment] = Field(description="VAD片段列表")
|
||||||
|
|
||||||
|
class VADResponse(BaseModel):
|
||||||
|
"""VAD响应"""
|
||||||
|
results: List[VADResult] = Field(description="VAD结果列表", default_factory=list)
|
||||||
|
time_chunk: List[VADSegment] = Field(description="时间块", default_factory=list)
|
||||||
|
time_chunk_index: int = Field(description="当前处理时间块索引", default=0)
|
||||||
|
time_chunk_index_callback: Optional[Callable[[int], None]] = Field(
|
||||||
|
description="时间块索引回调函数",
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator('time_chunk')
|
||||||
|
def validate_time_chunk(cls, v):
|
||||||
|
"""验证时间块的有效性"""
|
||||||
|
if not v:
|
||||||
|
return v
|
||||||
|
|
||||||
|
# 检查时间顺序
|
||||||
|
for i in range(len(v) - 1):
|
||||||
|
if v[i].end >= v[i + 1].start:
|
||||||
|
raise ValueError(f"时间块{i}的结束时间({v[i].end})大于等于下一个时间块的开始时间({v[i + 1].start})")
|
||||||
|
return v
|
||||||
|
|
||||||
|
# 回调未处理的时间块
|
||||||
|
def process_time_chunk(self, callback: Callable[[int], None] = None) -> None:
|
||||||
|
"""处理时间块"""
|
||||||
|
# print("Enter process_time_chunk", self.time_chunk_index, len(self.time_chunk))
|
||||||
|
while self.time_chunk_index < len(self.time_chunk) - 1:
|
||||||
|
index = self.time_chunk_index
|
||||||
|
if self.time_chunk[index].end != -1:
|
||||||
|
x = {
|
||||||
|
"start_time": self.time_chunk[index].start,
|
||||||
|
"end_time": self.time_chunk[index].end
|
||||||
|
}
|
||||||
|
if callback is not None:
|
||||||
|
callback(x)
|
||||||
|
elif self.time_chunk_index_callback is not None:
|
||||||
|
self.time_chunk_index_callback(x)
|
||||||
|
else:
|
||||||
|
print("[Warning] No callback available")
|
||||||
|
self.time_chunk_index += 1
|
||||||
|
|
||||||
|
def __add__(self, other: 'VADResponse') -> 'VADResponse':
|
||||||
|
"""合并两个VADResponse"""
|
||||||
|
if not self.results:
|
||||||
|
self.results = other.results
|
||||||
|
self.time_chunk = other.time_chunk
|
||||||
|
return self
|
||||||
|
|
||||||
|
# 检查是否可以合并最后一个结果
|
||||||
|
last_result = self.results[-1]
|
||||||
|
first_other = other.results[0]
|
||||||
|
|
||||||
|
if last_result.value[-1].end == first_other.value[0].start:
|
||||||
|
# 合并相邻的时间段
|
||||||
|
last_result.value[-1].end = first_other.value[0].end
|
||||||
|
first_other.value.pop(0)
|
||||||
|
|
||||||
|
# 更新time_chunk
|
||||||
|
self.time_chunk[-1].end = other.time_chunk[0].end
|
||||||
|
other.time_chunk.pop(0)
|
||||||
|
|
||||||
|
# 添加剩余的结果
|
||||||
|
if first_other.value:
|
||||||
|
self.results.extend(other.results)
|
||||||
|
self.time_chunk.extend(other.time_chunk)
|
||||||
|
else:
|
||||||
|
# 直接添加所有结果
|
||||||
|
self.results.extend(other.results)
|
||||||
|
self.time_chunk.extend(other.time_chunk)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_raw(cls, raw_data: List[dict]) -> "VADResponse":
|
||||||
|
"""
|
||||||
|
从原始数据创建VADResponse
|
||||||
|
|
||||||
|
参数:
|
||||||
|
raw_data: 原始数据,格式如 [{'key': 'xxx', 'value': [[-1, 59540], [59820, -1]]}]
|
||||||
|
|
||||||
|
返回:
|
||||||
|
VADResponse: 解析后的VAD响应
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
time_chunk = []
|
||||||
|
for item in raw_data:
|
||||||
|
segments = [
|
||||||
|
VADSegment(start=seg[0], end=seg[1])
|
||||||
|
for seg in item['value']
|
||||||
|
]
|
||||||
|
results.append(VADResult(
|
||||||
|
key=item['key'],
|
||||||
|
value=segments
|
||||||
|
))
|
||||||
|
time_chunk.extend(segments)
|
||||||
|
return cls(results=results, time_chunk=time_chunk)
|
||||||
|
|
||||||
|
def to_raw(self) -> List[dict]:
|
||||||
|
"""
|
||||||
|
转换为原始数据格式
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[dict]: 原始数据格式
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'key': result.key,
|
||||||
|
'value': [[seg.start, seg.end] for seg in result.value]
|
||||||
|
}
|
||||||
|
for result in self.results
|
||||||
|
]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
result_str = "VADResponse:\n"
|
||||||
|
for result in self.results:
|
||||||
|
for value_item in result.value:
|
||||||
|
result_str += f"[{value_item.start}:{value_item.end}]\n"
|
||||||
|
return result_str
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.time_chunk)
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
return next(self.time_chunk)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.time_chunk)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.time_chunk[index]
|
91
src/utils/logger.py
Normal file
91
src/utils/logger.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
def setup_logger(
|
||||||
|
name: str = None,
|
||||||
|
level: str = "INFO",
|
||||||
|
log_file: Optional[str] = None,
|
||||||
|
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
date_format: str = "%Y-%m-%d %H:%M:%S",
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
设置并返回一个配置好的logger实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: logger的名称,默认为None(使用root logger)
|
||||||
|
level: 日志级别,默认为"INFO"
|
||||||
|
log_file: 日志文件路径,默认为None(仅控制台输出)
|
||||||
|
log_format: 日志格式
|
||||||
|
date_format: 日期格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logging.Logger: 配置好的logger实例
|
||||||
|
"""
|
||||||
|
# 获取logger实例
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
|
||||||
|
# 设置日志级别
|
||||||
|
level = getattr(logging, level.upper())
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
|
||||||
|
# 创建格式器
|
||||||
|
formatter = logging.Formatter(log_format, date_format)
|
||||||
|
|
||||||
|
# 添加控制台处理器
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# 如果指定了日志文件,添加文件处理器
|
||||||
|
if log_file:
|
||||||
|
# 确保日志目录存在
|
||||||
|
log_path = Path(log_file)
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
# 注意:移除了 propagate = False,允许日志传递
|
||||||
|
return logger
|
||||||
|
|
||||||
|
def setup_root_logger(
|
||||||
|
level: str = "INFO",
|
||||||
|
log_file: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
配置根日志器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: 日志级别
|
||||||
|
log_file: 日志文件路径
|
||||||
|
"""
|
||||||
|
setup_logger(None, level, log_file)
|
||||||
|
|
||||||
|
def get_module_logger(
|
||||||
|
module_name: str,
|
||||||
|
level: Optional[str] = None, # 改为可选参数
|
||||||
|
log_file: Optional[str] = None # 一般不需要单独指定
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
获取模块级别的logger
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_name: 模块名称,通常传入__name__
|
||||||
|
level: 可选的日志级别,如果不指定则继承父级配置
|
||||||
|
log_file: 可选的日志文件路径,一般不需要指定
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(module_name)
|
||||||
|
|
||||||
|
# 只有显式指定了level才设置
|
||||||
|
if level:
|
||||||
|
logger.setLevel(getattr(logging, level.upper()))
|
||||||
|
|
||||||
|
# 只有显式指定了log_file才添加文件处理器
|
||||||
|
if log_file:
|
||||||
|
setup_logger(module_name, level or "INFO", log_file)
|
||||||
|
|
||||||
|
return logger
|
22
test_main.py
Normal file
22
test_main.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from src.utils.logger import get_module_logger, setup_root_logger
|
||||||
|
from tests.modelsuse import vad_model_use_online_logic, asr_model_use_offline
|
||||||
|
import json
|
||||||
|
setup_root_logger(level="INFO",log_file="logs/test_main.log")
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
logger.info("开始测试")
|
||||||
|
vad_result = vad_model_use_online_logic("tests/vad_example.wav")
|
||||||
|
logger.info("测试结束")
|
||||||
|
if vad_result is None:
|
||||||
|
logger.warning("VAD结果为空")
|
||||||
|
else:
|
||||||
|
logger.info(f"VAD结果: {vad_result}")
|
||||||
|
|
||||||
|
asr_result = asr_model_use_offline("tests/vad_example.wav")
|
||||||
|
# asr_result str->dict
|
||||||
|
setup_root_logger(level="INFO",log_file="logs/test_main.log")
|
||||||
|
result = asr_result[0]['sentence_info']
|
||||||
|
for item in result:
|
||||||
|
#[{'start': 70, 'end': 2340, 'sentence': '试 错 的 过 程 很 简 单', 'timestamp': [[380, 620], [640, 740], [740, 940], [940, 1020], [1020, 1260], [1500, 1740], [1740, 1840], [1840, 2135]], 'spk': 0}
|
||||||
|
logger.info(f"spk[{item['spk']}] [{item['start']}ms:{item['end']}ms] {item['sentence'].replace(' ', '')}")
|
84
tests/modelsuse.py
Normal file
84
tests/modelsuse.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
from funasr import AutoModel
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from src.models import VADResponse
|
||||||
|
import time
|
||||||
|
|
||||||
|
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
chunk_size = 100 # ms
|
||||||
|
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||||
|
|
||||||
|
vad_result = VADResponse()
|
||||||
|
vad_result.time_chunk_index_callback = lambda index: print(f"回调: {index}")
|
||||||
|
items = []
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
|
||||||
|
for i in range(total_chunk_num):
|
||||||
|
time.sleep(0.1)
|
||||||
|
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
|
||||||
|
is_final = i == total_chunk_num - 1
|
||||||
|
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
|
||||||
|
if len(res[0]["value"]):
|
||||||
|
vad_result += VADResponse.from_raw(res)
|
||||||
|
for item in res[0]["value"]:
|
||||||
|
items.append(item)
|
||||||
|
vad_result.process_time_chunk()
|
||||||
|
|
||||||
|
# for item in items:
|
||||||
|
# print(item)
|
||||||
|
return vad_result
|
||||||
|
|
||||||
|
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
from src.logic_trager import LogicTrager
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from src.config import parse_args
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
from src.functor.model_loader import load_models
|
||||||
|
models = load_models(args)
|
||||||
|
|
||||||
|
chunk_size = 200 # ms
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
audio_config = AudioBinary_Config(sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size)
|
||||||
|
|
||||||
|
logic_trager = LogicTrager(models=models, audio_config=audio_config)
|
||||||
|
for i in range(len(speech)//chunk_stride+1):
|
||||||
|
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
|
||||||
|
logic_trager.push_binary_data(speech_chunk)
|
||||||
|
|
||||||
|
# for item in items:
|
||||||
|
# print(item)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
from funasr import AutoModel
|
||||||
|
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
|
||||||
|
vad_model="fsmn-vad", vad_model_revision="v2.0.4",
|
||||||
|
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
|
||||||
|
spk_model="cam++", spk_model_revision="v2.0.2",
|
||||||
|
spk_mode="vad_segment",
|
||||||
|
auto_update=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
|
result = model.generate(speech)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# vad_result = vad_model_use_online("tests/vad_example.wav")
|
||||||
|
vad_result = vad_model_use_online_logic("tests/vad_example.wav")
|
||||||
|
# print(vad_result)
|
BIN
tests/vad_example.wav
Normal file
BIN
tests/vad_example.wav
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user