STT_Server/src/functor/audiochunk.py

178 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
音频数据块管理类 - 用于存储和处理16KHz音频数据
"""
import numpy as np
from src.utils.logger import get_module_logger
from typing import List, Optional, Union
from src.models import AudioBinary_Config
# Module-level logger (INFO level) obtained from the project's logging helper.
logger = get_module_logger(__name__, level="INFO")
class AudioChunk:
    """Append-only buffer of raw PCM audio with time-based slicing.

    Appended data is kept as a list of byte chunks. Positions and sizes are
    measured in bytes; time <-> size conversions use the configured sample
    rate, sample width (bytes per sample) and channel count, which default to
    16 kHz / 2 bytes / mono when no audio config is given.
    """

    def __init__(
        self,
        max_duration_ms: int = 1000 * 60 * 60 * 10,
        audio_config: Optional["AudioBinary_Config"] = None,
    ):
        """
        Initialise the audio chunk manager.

        Args:
            max_duration_ms: Maximum amount of audio to retain, in ms
                (default: 10 hours). When an append would exceed this limit,
                all buffered audio is dropped before appending.
            audio_config: Optional format description providing
                ``sample_rate``, ``sample_width`` and ``channels``; when
                omitted, 16000 Hz, 2-byte samples and 1 channel are used.
        """
        # Audio format parameters.
        self.sample_rate = audio_config.sample_rate if audio_config is not None else 16000  # 16 kHz
        self.sample_width = audio_config.sample_width if audio_config is not None else 2  # 16-bit
        self.channels = audio_config.channels if audio_config is not None else 1  # mono
        # Storage.
        self._max_duration_ms = max_duration_ms
        self._max_chunk_size = self._time2size(max_duration_ms)  # byte capacity
        self._chunk: List[bytes] = []  # appended chunks, oldest first
        self._chunk_size = 0  # total bytes across self._chunk
        # Offset in ms added by __call__(use_offset=True); reset to 0 by clear_chunk().
        self._offset = 0
        logger.info(f"初始化AudioChunk: 最大时长={max_duration_ms}ms, 最大数据大小={self._max_chunk_size}字节")

    def add_chunk(self, chunk: Union[bytes, np.ndarray]) -> bool:
        """
        Append a chunk of audio data.

        Args:
            chunk: Raw PCM bytes, or a numpy array which is cast to int16 and
                serialised with ``tobytes()``.

        Returns:
            bool: True if the chunk was appended; False if its length is not a
            whole number of frames or an unexpected error occurred.
        """
        try:
            if isinstance(chunk, np.ndarray):
                # Arrays are stored as 16-bit integer PCM.
                # NOTE(review): astype() does not rescale, so float audio in
                # [-1.0, 1.0] would truncate to (near) zero; presumably callers
                # pass int16-range values -- confirm.
                if chunk.dtype != np.int16:
                    chunk = chunk.astype(np.int16)
                chunk = chunk.tobytes()
            # Reject partial frames so the stored stream stays frame-aligned.
            frame_size = self.sample_width * self.channels
            if len(chunk) % frame_size != 0:
                logger.warning(f"音频数据大小不是{frame_size}的倍数: {len(chunk)}")
                return False
            # Over capacity: everything buffered so far is dropped (not only the
            # oldest part). A single chunk larger than the capacity is still
            # appended after the clear.
            if self._chunk_size + len(chunk) > self._max_chunk_size:
                logger.warning("音频数据超过最大限制,将自动清除旧数据")
                self.clear_chunk()
            self._chunk.append(chunk)
            self._chunk_size += len(chunk)
            return True
        except Exception as e:
            logger.error(f"添加音频数据块时出错: {e}")
            return False

    def get_chunk_binary(self, start: int = 0, end: Optional[int] = None) -> Optional[bytes]:
        """
        Return the concatenated bytes of stored chunks ``start`` .. ``end``.

        Args:
            start: Index of the first chunk (into the list of appended chunks).
            end: Exclusive index of the last chunk; None (or a value past the
                end) means "through the last chunk".

        Returns:
            Optional[bytes]: The joined data, or None when ``start`` is not
            before the last chunk index (including when the buffer is empty).
        """
        logger.debug(f"[AudioChunk] get_chunk_binary start={start} end={end}")
        if start >= len(self._chunk):
            return None
        if end is None or end > len(self._chunk):
            end = len(self._chunk)
        # Bounds above are chunk indices, so slice the chunk list (not the
        # joined bytes) and join only the selected chunks.
        return b"".join(self._chunk[start:end])

    def get_chunk(self, start_ms: int = 0, end_ms: Optional[int] = None) -> Optional[bytes]:
        """
        Return the audio data for a time range.

        Args:
            start_ms: Start time in ms.
            end_ms: End time in ms; None means "through the end of the buffer".

        Returns:
            Optional[bytes]: The audio bytes, or None when the buffer is empty,
            the range is empty / starts past the end, or an error occurred.
        """
        try:
            if not self._chunk:
                return None
            # Convert times to (frame-aligned) byte offsets.
            start_byte = self._time2size(start_ms)
            end_byte = self._time2size(end_ms) if end_ms is not None else self._chunk_size
            if start_byte >= self._chunk_size or start_byte >= end_byte:
                return None
            return self._slice_bytes(start_byte, end_byte)
        except Exception as e:
            logger.error(f"获取音频数据块时出错: {e}")
            return None

    def _slice_bytes(self, start_byte: int, end_byte: int) -> bytes:
        """Return ``b"".join(self._chunk)[start_byte:end_byte]`` without the full join.

        Python slice semantics (negative indices, clamping) are applied against
        the buffered byte count, then only the chunks overlapping the range are
        copied.
        """
        start, stop, _ = slice(start_byte, end_byte).indices(self._chunk_size)
        if start >= stop:
            return b""
        pieces = []
        pos = 0  # byte offset of the current chunk's first byte
        for piece in self._chunk:
            piece_end = pos + len(piece)
            if piece_end > start:
                pieces.append(piece[max(start - pos, 0):stop - pos])
            if piece_end >= stop:
                break
            pos = piece_end
        return b"".join(pieces)

    def get_duration(self) -> int:
        """
        Return the total buffered audio duration.

        Returns:
            int: Duration in ms.
        """
        return self._size2time(self._chunk_size)

    def clear_chunk(self) -> None:
        """Drop all buffered audio and reset the size and offset counters."""
        self._chunk = []
        self._chunk_size = 0
        self._offset = 0
        logger.info("已清除所有音频数据")

    def _time2size(self, time_ms: int) -> int:
        """
        Convert a time in ms to a byte count, rounded toward zero to a whole frame.

        Rounding at frame granularity (not byte granularity) keeps offsets on
        sample boundaries for rates where one ms is not a whole number of frames.

        Args:
            time_ms: Time in ms.

        Returns:
            int: Size in bytes (a multiple of sample_width * channels).
        """
        frames = int(time_ms * self.sample_rate / 1000)
        return frames * self.sample_width * self.channels

    def _size2time(self, size: int) -> int:
        """
        Convert a byte count to a time in ms (truncated).

        Args:
            size: Size in bytes.

        Returns:
            int: Time in ms.
        """
        return int(size * 1000 / (self.sample_rate * self.sample_width * self.channels))

    def __call__(self, start_ms: int = 0, end_ms: Optional[int] = None, use_offset: bool = True) -> Optional[bytes]:
        """
        Shorthand for get_chunk(): ``instance(start_ms, end_ms, use_offset=True)``.

        When ``use_offset`` is True, ``self._offset`` (ms) is added to both
        bounds; an open end (``end_ms=None``) stays open.
        """
        if use_offset:
            start_ms += self._offset
            if end_ms is not None:
                end_ms += self._offset
        return self.get_chunk(start_ms, end_ms)

    def __len__(self) -> int:
        """Return the number of buffered bytes."""
        return self._chunk_size