diff --git a/src/audiochunk.py b/src/audiochunk.py
new file mode 100644
index 0000000..971a269
--- /dev/null
+++ b/src/audiochunk.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Audio chunk manager - stores and processes 16 kHz audio data
+"""
+
+import numpy as np
+import logging
+from typing import List, Optional, Union
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger('AudioChunk')
+
+class AudioChunk:
+    """Audio chunk manager for storing and processing 16 kHz audio data"""
+
+    def __init__(self,
+                 max_duration_ms: int = 1000*60*60*10,
+                 sample_rate: int = 16000,
+                 sample_width: int = 2,
+                 channels: int = 1):
+        """
+        Initialize the audio chunk manager
+
+        Args:
+            max_duration_ms: maximum retention time of the audio pool (ms), default 10 hours
+            sample_rate: sample rate, default 16 kHz
+            sample_width: sample width in bytes, default 2 (16 bit)
+            channels: number of channels, default 1
+        """
+        # Audio parameters
+        self.sample_rate = sample_rate    # sample rate: 16 kHz
+        self.sample_width = sample_width  # sample width: 16 bit
+        self.channels = channels          # channels: mono
+
+        # Data storage
+        self._max_duration_ms = max_duration_ms
+        self._max_chunk_size = self._time2size(max_duration_ms)  # maximum data size
+        self._chunk = []       # list of stored audio chunks
+        self._chunk_size = 0   # current total size in bytes
+        self._offset = 0       # current offset
+
+        logger.info(f"Initialized AudioChunk: max duration={max_duration_ms}ms, max size={self._max_chunk_size} bytes")
+
+    def add_chunk(self, chunk: Union[bytes, np.ndarray]) -> bool:
+        """
+        Add an audio chunk
+
+        Args:
+            chunk: audio data, either bytes or a numpy array
+
+        Returns:
+            bool: whether the chunk was added successfully
+        """
+        try:
+            # Check the data format
+            if isinstance(chunk, np.ndarray):
+                # Make sure it is 16-bit integer data
+                if chunk.dtype != np.int16:
+                    chunk = chunk.astype(np.int16)
+                # Convert to bytes
+                chunk = chunk.tobytes()
+
+            # Check the data size
+            if len(chunk) % (self.sample_width * self.channels) != 0:
+                logger.warning(f"Audio data size is not a multiple of {self.sample_width * self.channels}: {len(chunk)}")
+                return False
+
+            # Check whether the maximum limit would be exceeded
+            if self._chunk_size + len(chunk) > self._max_chunk_size:
+                logger.warning("Audio data exceeds the maximum limit; old data will be cleared")
+                self.clear_chunk()
+
+            # Append the data
+            self._chunk.append(chunk)
+            self._chunk_size += len(chunk)
+            return True
+
+        except Exception as e:
+            logger.error(f"Error while adding audio chunk: {e}")
+            return False
+
+    def get_chunk(self, start_ms: int = 0, end_ms: Optional[int] = None) -> Optional[bytes]:
+        """
+        Get the audio data in the given time range
+
+        Args:
+            start_ms: start time (ms)
+            end_ms: end time (ms); None means up to the end
+
+        Returns:
+            Optional[bytes]: audio data, or None on failure
+        """
+        try:
+            if not self._chunk:
+                return None
+
+            # Compute byte offsets
+            start_byte = self._time2size(start_ms)
+            end_byte = self._time2size(end_ms) if end_ms is not None else self._chunk_size
+
+            # Validate the range
+            if start_byte >= self._chunk_size or start_byte >= end_byte:
+                return None
+
+            # Extract the data
+            data = b''.join(self._chunk)
+            return data[start_byte:end_byte]
+
+        except Exception as e:
+            logger.error(f"Error while getting audio chunk: {e}")
+            return None
+
+    def get_duration(self) -> int:
+        """
+        Get the total duration of the stored audio (ms)
+
+        Returns:
+            int: audio duration (ms)
+        """
+        return self._size2time(self._chunk_size)
+
+    def clear_chunk(self) -> None:
+        """Clear all audio data"""
+        self._chunk = []
+        self._chunk_size = 0
+        self._offset = 0
+        logger.info("Cleared all audio data")
+
+    def _time2size(self, time_ms: int) -> int:
+        """
+        Convert a duration (ms) to a data size (bytes)
+
+        Args:
+            time_ms: duration (ms)
+
+        Returns:
+            int: data size (bytes)
+        """
+        return int(time_ms * self.sample_rate * self.sample_width * self.channels / 1000)
+
+    def _size2time(self, size: int) -> int:
+        """
+        Convert a data size (bytes) to a duration (ms)
+
+        Args:
+            size: data size (bytes)
+
+        Returns:
+            int: duration (ms)
+        """
+        return int(size * 1000 / (self.sample_rate * self.sample_width * self.channels))
+
+    # instance(start_ms, end_ms, use_offset=True)
+    def __call__(self, start_ms: int = 0, end_ms: Optional[int] = None, use_offset: bool = True) -> Optional[bytes]:
+        """
+        Get the audio data in the given time range
+        """
+        if use_offset:
+            start_ms += self._offset
+            if end_ms is not None:  # keep the to-the-end semantics of end_ms=None
+                end_ms += self._offset
+        return self.get_chunk(start_ms, end_ms)
+
+    def __len__(self) -> int:
+        """
+        Get the current total size of the audio data
+        """
+        return self._chunk_size
\ No newline at end of file
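For reference, a minimal sketch of how the new AudioChunk class is meant to be driven, assuming the 16 kHz / 16-bit / mono defaults; the 440 Hz test tone is purely illustrative and not part of this change:

    import numpy as np
    from src.audiochunk import AudioChunk

    pool = AudioChunk()  # 16 kHz, 16-bit, mono defaults

    # One second of a 440 Hz tone as int16 samples (illustrative data)
    t = np.linspace(0, 1, 16000, endpoint=False)
    tone = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)

    pool.add_chunk(tone)            # ndarray input is converted to bytes
    pool.add_chunk(tone.tobytes())  # bytes input is stored as-is

    print(pool.get_duration())           # 2000 (ms)
    first_half = pool.get_chunk(0, 500)  # bytes for the first 500 ms
    print(len(first_half))               # 16000 = 500 ms * 16 samples/ms * 2 bytes/sample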
diff --git a/src/client.py b/src/client.py
deleted file mode 100644
index 86328f1..0000000
--- a/src/client.py
+++ /dev/null
@@ -1,196 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-WebSocket client example - for testing the speech recognition service
-"""
-
-import asyncio
-import json
-import websockets
-import argparse
-import numpy as np
-import wave
-import os
-
-
-def parse_args():
-    """Parse command-line arguments"""
-    parser = argparse.ArgumentParser(description="FunASR WebSocket client")
-
-    parser.add_argument(
-        "--host",
-        type=str,
-        default="localhost",
-        help="Server host address"
-    )
-
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=10095,
-        help="Server port"
-    )
-
-    parser.add_argument(
-        "--audio_file",
-        type=str,
-        required=True,
-        help="Path of the audio file to recognize"
-    )
-
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="2pass",
-        choices=["2pass", "online", "offline"],
-        help="Recognition mode: 2pass (default), online, offline"
-    )
-
-    parser.add_argument(
-        "--chunk_size",
-        type=str,
-        default="5,10",
-        help="Chunk size, in the form 'encoder_size,decoder_size'"
-    )
-
-    parser.add_argument(
-        "--use_ssl",
-        action="store_true",
-        help="Whether to use an SSL connection"
-    )
-
-    return parser.parse_args()
-
-
-async def send_audio(websocket, audio_file, mode, chunk_size):
-    """
-    Send an audio file to the server for recognition
-
-    Args:
-        websocket: WebSocket connection
-        audio_file: audio file path
-        mode: recognition mode
-        chunk_size: chunk size
-    """
-    # Open and read the WAV file
-    with wave.open(audio_file, "rb") as wav_file:
-        params = wav_file.getparams()
-        frames = wav_file.readframes(wav_file.getnframes())
-
-    # Audio file information
-    print(f"Audio file: {os.path.basename(audio_file)}")
-    print(f"Sample rate: {params.framerate}Hz, channels: {params.nchannels}")
-    print(f"Sample width: {params.sampwidth * 8} bit, total frames: {params.nframes}")
-
-    # Recognition configuration
-    config = {
-        "mode": mode,
-        "chunk_size": chunk_size,
-        "wav_name": os.path.basename(audio_file),
-        "is_speaking": True
-    }
-
-    # Send the configuration
-    await websocket.send(json.dumps(config))
-
-    # Simulate real-time streaming of the audio data
-    chunk_size_bytes = 3200  # send 100 ms of 16 kHz audio at a time
-    total_chunks = len(frames) // chunk_size_bytes
-
-    print(f"Start sending audio data, {total_chunks} chunks in total...")
-
-    try:
-        for i in range(0, len(frames), chunk_size_bytes):
-            chunk = frames[i:i+chunk_size_bytes]
-            await websocket.send(chunk)
-
-            # Simulate real time: send every 100 ms
-            await asyncio.sleep(0.1)
-
-            # Show progress
-            if (i // chunk_size_bytes) % 10 == 0:
-                print(f"Sent {i // chunk_size_bytes}/{total_chunks} chunks")
-
-        # Send the end-of-stream signal
-        await websocket.send(json.dumps({"is_speaking": False}))
-        print("Audio data sent")
-
-    except Exception as e:
-        print(f"Error while sending audio: {e}")
-
-
-async def receive_results(websocket):
-    """
-    Receive and display recognition results
-
-    Args:
-        websocket: WebSocket connection
-    """
-    online_text = ""
-    offline_text = ""
-
-    try:
-        async for message in websocket:
-            # Parse the JSON message returned by the server
-            result = json.loads(message)
-
-            mode = result.get("mode", "")
-            text = result.get("text", "")
-            is_final = result.get("is_final", False)
-
-            # Update the text according to the mode
-            if "online" in mode:
-                online_text = text
-                print(f"\r[online] {online_text}", end="", flush=True)
-            elif "offline" in mode:
-                offline_text = text
-                print(f"\n[offline] {offline_text}")
-
-            # On the final result, print the complete text
-            if is_final and offline_text:
-                print("\nFinal recognition result:")
-                print(f"[offline] {offline_text}")
-                return
-
-    except Exception as e:
-        print(f"Error while receiving results: {e}")
-
-
-async def main():
-    """Main entry point"""
-    args = parse_args()
-
-    # WebSocket URI
-    protocol = "wss" if args.use_ssl else "ws"
-    uri = f"{protocol}://{args.host}:{args.port}"
-
-    print(f"Connecting to server: {uri}")
-
-    try:
-        # Open the WebSocket connection
-        async with websockets.connect(
-            uri,
-            subprotocols=["binary"]
-        ) as websocket:
-
-            print("Connected")
-
-            # Two tasks: send audio and receive results
-            send_task = asyncio.create_task(
-                send_audio(websocket, args.audio_file, args.mode, args.chunk_size)
-            )
-
-            receive_task = asyncio.create_task(
-                receive_results(websocket)
-            )
-
-            # Wait for both tasks to finish
-            await asyncio.gather(send_task, receive_task)
-
-    except Exception as e:
-        print(f"Failed to connect to server: {e}")
-
-
-if __name__ == "__main__":
-    # Run the main coroutine
-    asyncio.run(main())
\ No newline at end of file
f"{protocol}://{args.host}:{args.port}" - - print(f"连接到服务器: {uri}") - - try: - # 创建WebSocket连接 - async with websockets.connect( - uri, - subprotocols=["binary"] - ) as websocket: - - print("连接成功") - - # 创建两个任务: 发送音频和接收结果 - send_task = asyncio.create_task( - send_audio(websocket, args.audio_file, args.mode, args.chunk_size) - ) - - receive_task = asyncio.create_task( - receive_results(websocket) - ) - - # 等待任务完成 - await asyncio.gather(send_task, receive_task) - - except Exception as e: - print(f"连接服务器失败: {e}") - - -if __name__ == "__main__": - # 运行主函数 - asyncio.run(main()) \ No newline at end of file diff --git a/src/logic_trager.py b/src/logic_trager.py new file mode 100644 index 0000000..da0034b --- /dev/null +++ b/src/logic_trager.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +逻辑触发器类 - 用于处理音频数据并触发相应的处理逻辑 +""" + +import logging +from typing import Any, Dict, Type + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger('LogicTrager') + +class AutoAfterMeta(type): + """ + 自动调用__after__函数的元类 + 实现单例模式 + """ + + _instances: Dict[Type, Any] = {} # 存储单例实例 + + def __new__(cls, name, bases, attrs): + # 遍历所有属性 + for attr_name, attr_value in attrs.items(): + # 如果是函数且不是以_开头 + if callable(attr_value) and not attr_name.startswith('__'): + # 获取原函数 + original_func = attr_value + + # 创建包装函数 + def make_wrapper(func): + def wrapper(self, *args, **kwargs): + # 执行原函数 + result = func(self, *args, **kwargs) + + # 构建_after_函数名 + after_func_name = f"__after__{func.__name__}" + + # 检查是否存在对应的_after_函数 + if hasattr(self, after_func_name): + after_func = getattr(self, after_func_name) + if callable(after_func): + try: + # 调用_after_函数 + after_func() + except Exception as e: + logger.error(f"调用{after_func_name}时出错: {e}") + + return result + return wrapper + + # 替换原函数 + attrs[attr_name] = make_wrapper(original_func) + + # 创建类 + new_class = super().__new__(cls, name, bases, attrs) + return new_class + + def __call__(cls, *args, **kwargs): + """ + 重写__call__方法实现单例模式 + 当类被调用时(即创建实例时)执行 + """ + if cls not in cls._instances: + # 如果实例不存在,创建新实例 + cls._instances[cls] = super().__call__(*args, **kwargs) + logger.info(f"创建{cls.__name__}的新实例") + else: + logger.debug(f"返回{cls.__name__}的现有实例") + + return cls._instances[cls] + +""" +整体识别的处理逻辑: +1.压入二进制音频信息 +2.不断检测VAD +3.当检测到完整VAD时,将VAD的音频信息压入音频块,并清除对应二进制信息 +4.对音频块进行语音转文字offline,时间戳预测,说话人识别 +5.将识别结果整合压入结果队列 +6.结果队列被压入时调用回调函数 + +1->2 __after__push_binary_data 外部压入二进制信息 +2,3->4 __after__push_audio_chunk 内部压入音频块 +4->5 push_result_queue 压入结果队列 +5->6 __after__push_result_queue 调用回调函数 +""" + +class LogicTrager(metaclass=AutoAfterMeta): + """逻辑触发器类""" + + def __init__(self, + audio_chunk_max_size: int = 1024 * 1024 * 10, + sample_rate: int = 16000, + channels: int = 1, + on_result_callback: Callable = None, + ): + """初始化""" + # 存储音频块 + self._audio_chunk = [] + # 存储二进制数据 + self._audio_chunk_binary = b'' + self._audio_chunk_max_size = audio_chunk_max_size + # 音频参数 + self._sample_rate = sample_rate + self._channels = channels + # 结果队列 + self._result_queue = [] + # 回调函数 + self._on_result_callback = on_result_callback + logger.info("初始化LogicTrager") + + def push_binary_data(self, chunk: bytes) -> None: + """ + 添加音频块 + + 参数: + chunk: 音频数据块 + """ + if self._audio_chunk is None: + logger.error("AudioChunk未初始化") + return + + self._audio_chunk_binary += chunk + logger.debug(f"添加音频块,大小: {len(chunk)}字节") + + def __after__push_binary_data(self) -> None: + """ + 添加音频块后处理 + 
diff --git a/src/models.py b/src/models.py
index 7c9c7b6..511116a 100644
--- a/src/models.py
+++ b/src/models.py
@@ -4,6 +4,8 @@
 Model loading module - loads the various speech-recognition models
 """
 
+from typing import List, Optional
+
 def load_models(args):
     """
     Load all required models
diff --git a/src/pydantic_models.py b/src/pydantic_models.py
new file mode 100644
index 0000000..10c252c
--- /dev/null
+++ b/src/pydantic_models.py
@@ -0,0 +1,125 @@
+from pydantic import BaseModel, Field, validator
+from typing import List, Optional, Callable
+
+class VADSegment(BaseModel):
+    """A single VAD segment"""
+    start: int = Field(description="start time (ms)")
+    end: int = Field(description="end time (ms)")
+
+class VADResult(BaseModel):
+    """VAD result"""
+    key: str = Field(description="audio identifier")
+    value: List[VADSegment] = Field(description="list of VAD segments")
+
+class VADResponse(BaseModel):
+    """VAD response"""
+    results: List[VADResult] = Field(description="list of VAD results", default_factory=list)
+    time_chunk: List[VADSegment] = Field(description="time chunks", default_factory=list)
+    time_chunk_index: int = Field(description="index of the time chunk being processed", default=0)
+    time_chunk_index_callback: Optional[Callable[[int], None]] = Field(
+        description="callback invoked per time chunk index",
+        default=None
+    )
+
+    @validator('time_chunk')
+    def validate_time_chunk(cls, v):
+        """Validate the time chunks"""
+        if not v:
+            return v
+
+        # Check chronological order
+        for i in range(len(v) - 1):
+            if v[i].end >= v[i + 1].start:
+                raise ValueError(f"End time of chunk {i} ({v[i].end}) is not before the start time of the next chunk ({v[i + 1].start})")
+        return v
+
+    # Invoke the callback for time chunks that have not been processed yet
+    def process_time_chunk(self, callback: Optional[Callable[[int], None]] = None) -> None:
+        """Process pending time chunks"""
+        while self.time_chunk_index < len(self.time_chunk) - 1:
+            if self.time_chunk[self.time_chunk_index].end != -1:
+                if callback is not None:
+                    callback(self.time_chunk_index)
+                elif self.time_chunk_index_callback is not None:
+                    self.time_chunk_index_callback(self.time_chunk_index)
+                else:
+                    print("[Warning] No callback available")
+            self.time_chunk_index += 1
+
+    def __add__(self, other: 'VADResponse') -> 'VADResponse':
+        """Merge another VADResponse into this one (in place)"""
+        if not self.results:
+            self.results = other.results
+            self.time_chunk = other.time_chunk
+            return self
+
+        # Check whether the last result can be merged with the first incoming one
+        last_result = self.results[-1]
+        first_other = other.results[0]
+
+        if last_result.value[-1].end == first_other.value[0].start:
+            # Merge the adjacent segments
+            last_result.value[-1].end = first_other.value[0].end
+            first_other.value.pop(0)
+
+            # Update time_chunk accordingly
+            self.time_chunk[-1].end = other.time_chunk[0].end
+            other.time_chunk.pop(0)
+
+            # Append whatever is left
+            if first_other.value:
+                self.results.extend(other.results)
+            self.time_chunk.extend(other.time_chunk)
+        else:
+            # Append all results as they are
+            self.results.extend(other.results)
+            self.time_chunk.extend(other.time_chunk)
+
+        return self
+
+    @classmethod
+    def from_raw(cls, raw_data: List[dict]) -> "VADResponse":
+        """
+        Build a VADResponse from raw model output
+
+        Args:
+            raw_data: raw data, e.g. [{'key': 'xxx', 'value': [[-1, 59540], [59820, -1]]}]
+
+        Returns:
+            VADResponse: the parsed VAD response
+        """
+        results = []
+        time_chunk = []
+        for item in raw_data:
+            segments = [
+                VADSegment(start=seg[0], end=seg[1])
+                for seg in item['value']
+            ]
+            results.append(VADResult(
+                key=item['key'],
+                value=segments
+            ))
+            time_chunk.extend(segments)
+        return cls(results=results, time_chunk=time_chunk)
+
+    def to_raw(self) -> List[dict]:
+        """
+        Convert back to the raw data format
+
+        Returns:
+            List[dict]: raw data
+        """
+        return [
+            {
+                'key': result.key,
+                'value': [[seg.start, seg.end] for seg in result.value]
+            }
+            for result in self.results
+        ]
+
+    def __str__(self):
+        result_str = "VADResponse:\n"
+        for result in self.results:
+            for value_item in result.value:
+                result_str += f"[{value_item.start}:{value_item.end}]\n"
+        return result_str
+
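As a sketch of the intended VADResponse data flow (the millisecond values below are made up for illustration): from_raw parses FunASR-style raw output, + fuses adjacent segments across streaming responses, and process_time_chunk reports every chunk except the still-open last one:

    from src.pydantic_models import VADResponse

    # Two consecutive streaming results; [start, end] pairs are in ms
    first = VADResponse.from_raw([{'key': 'utt1', 'value': [[0, 59540]]}])
    second = VADResponse.from_raw([{'key': 'utt1', 'value': [[59540, 60200]]}])

    merged = first + second  # adjacent segments (end == next start) are fused
    print(merged.to_raw())   # [{'key': 'utt1', 'value': [[0, 60200]]}]

    # -1 marks an open end; the last chunk always stays pending because the
    # next streaming result may still extend it
    resp = VADResponse.from_raw([{'key': 'utt2', 'value': [[0, 1200], [1500, 2300], [2600, -1]]}])
    resp.process_time_chunk(lambda i: print(f"chunk {i} done"))  # chunk 0 done, chunk 1 done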
diff --git a/test_main.py b/test_main.py
new file mode 100644
index 0000000..919be7b
--- /dev/null
+++ b/test_main.py
@@ -0,0 +1,4 @@
+from tests.modelsuse import vad_model_use_online
+
+vad_result = vad_model_use_online("tests/vad_example.wav")
+print(vad_result)
\ No newline at end of file
diff --git a/tests/modelsuse.py b/tests/modelsuse.py
new file mode 100644
index 0000000..35faad9
--- /dev/null
+++ b/tests/modelsuse.py
@@ -0,0 +1,36 @@
+from funasr import AutoModel
+from src.pydantic_models import VADResponse
+import soundfile
+import time
+
+def vad_model_use_online(file_path: str) -> VADResponse:
+    chunk_size = 100  # ms
+    model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
+
+    vad_result = VADResponse()
+    vad_result.time_chunk_index_callback = lambda index: print(f"callback: {index}")
+    items = []
+
+    speech, sample_rate = soundfile.read(file_path)
+    chunk_stride = int(chunk_size * sample_rate / 1000)
+
+    cache = {}
+    total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
+    for i in range(total_chunk_num):
+        time.sleep(0.1)  # simulate real-time arrival of audio
+        speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+        is_final = i == total_chunk_num - 1
+        res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+        if len(res[0]["value"]):
+            vad_result += VADResponse.from_raw(res)
+            for item in res[0]["value"]:
+                items.append(item)
+        vad_result.process_time_chunk()
+
+    # for item in items:
+    #     print(item)
+    return vad_result
+
+if __name__ == "__main__":
+    vad_result = vad_model_use_online("tests/vad_example.wav")
+    # print(vad_result)
\ No newline at end of file
diff --git a/tests/vad_example.wav b/tests/vad_example.wav
new file mode 100644
index 0000000..2ebc8c7
Binary files /dev/null and b/tests/vad_example.wav differ