143 lines
4.8 KiB
Python

from typing import Callable, List, Optional

from pydantic import BaseModel, Field, PrivateAttr, validator
class VADSegment(BaseModel):
    """A single voice-activity segment expressed as a start/end pair.

    Both bounds are integers in milliseconds (per the field descriptions).
    Other code in this file compares ``end`` against ``-1`` before using a
    segment; presumably ``-1`` marks a boundary that is not known yet --
    confirm against the VAD producer.
    """

    # Both fields are required; ``...`` spells that out explicitly.
    start: int = Field(..., description="开始时间(ms)")
    end: int = Field(..., description="结束时间(ms)")
class VADResult(BaseModel):
    """The VAD segments detected for one audio item.

    Attributes:
        key: Identifier of the audio the segments belong to.
        value: The segments reported for that audio, in the order received.
    """

    # Both fields are required; ``...`` spells that out explicitly.
    key: str = Field(..., description="音频标识")
    value: List[VADSegment] = Field(..., description="VAD片段列表")
class VADResponse(BaseModel):
    """Accumulated VAD results plus a flat, ordered list of time chunks.

    ``results`` keeps the per-audio grouping, while ``time_chunk`` is the
    flattened list of segments that ``process_time_chunk`` walks through.
    ``time_chunk_index`` remembers how far that walk has progressed, so
    repeated calls only deliver chunks that were not handled before.

    The object also behaves like a read-only sequence over ``time_chunk``
    (``len()``, indexing, iteration and ``next()``).
    """

    results: List[VADResult] = Field(description="VAD结果列表", default_factory=list)
    time_chunk: List[VADSegment] = Field(description="时间块", default_factory=list)
    time_chunk_index: int = Field(description="当前处理时间块索引", default=0)
    # Despite the field name, the callback is invoked with a dict of the form
    # {"start_time": ..., "end_time": ...} (see process_time_chunk).
    time_chunk_index_callback: Optional[Callable[[dict], None]] = Field(
        description="时间块索引回调函数",
        default=None,
    )
    # Cursor used only by __next__; kept private so it is not a model field.
    _next_index: int = PrivateAttr(default=0)

    @validator('time_chunk')
    def validate_time_chunk(cls, v):
        """Check that consecutive time chunks are strictly ordered.

        Raises:
            ValueError: If a chunk's ``end`` is greater than or equal to the
                ``start`` of the chunk that follows it.
        """
        # zip(v, v[1:]) yields each adjacent pair; it is empty for 0 or 1 chunks.
        for i, (current, following) in enumerate(zip(v, v[1:])):
            if current.end >= following.start:
                raise ValueError(f"时间块{i}的结束时间({current.end})大于等于下一个时间块的开始时间({following.start})")
        return v

    def process_time_chunk(self, callback: Optional[Callable[[dict], None]] = None) -> None:
        """Deliver every pending time chunk except the last one to a callback.

        Starting at ``time_chunk_index``, each chunk whose ``end`` is not
        ``-1`` is reported as ``{"start_time": ..., "end_time": ...}``.
        The final chunk in ``time_chunk`` is never delivered by this method
        (presumably because it may still be extended by a later merge --
        confirm with the caller).

        Args:
            callback: Receiver for each chunk dict. When omitted, the
                model's ``time_chunk_index_callback`` is used; if that is
                also unset a warning is printed and the chunk is skipped.
        """
        while self.time_chunk_index < len(self.time_chunk) - 1:
            chunk = self.time_chunk[self.time_chunk_index]
            if chunk.end != -1:
                payload = {
                    "start_time": chunk.start,
                    "end_time": chunk.end,
                }
                if callback is not None:
                    callback(payload)
                elif self.time_chunk_index_callback is not None:
                    self.time_chunk_index_callback(payload)
                else:
                    print("[Warning] No callback available")
            # Always advance, so a chunk with end == -1 cannot stall the loop.
            self.time_chunk_index += 1

    def __add__(self, other: 'VADResponse') -> 'VADResponse':
        """Merge ``other`` into this response and return ``self``.

        NOTE: unlike a conventional ``+``, this mutates the left operand in
        place, and also removes the merged head segment from ``other``.

        If the last segment of this response ends exactly where the first
        segment of ``other`` starts, the two are fused into one segment
        (in both ``results`` and ``time_chunk``). Everything that remains
        in ``other`` is then appended.

        Args:
            other: The response whose results and time chunks are appended.

        Returns:
            This same instance, after merging.
        """
        if not self.results:
            # Nothing accumulated yet: simply adopt the other side's data.
            self.results = other.results
            self.time_chunk = other.time_chunk
            return self

        incoming = other.results
        tail = self.results[-1].value
        # Empty lists on either side simply mean "nothing to fuse".
        head = incoming[0].value if incoming else []

        if tail and head and tail[-1].end == head[0].start:
            # Fuse the adjacent segments into one.
            tail[-1].end = head[0].end
            head.pop(0)
            # Mirror the fusion in the flat chunk lists when both have entries.
            if self.time_chunk and other.time_chunk:
                self.time_chunk[-1].end = other.time_chunk[0].end
                other.time_chunk.pop(0)
            # Skip the first incoming result if fusing emptied it, but keep
            # any further results that follow it.
            if not head:
                incoming = incoming[1:]

        self.results.extend(incoming)
        self.time_chunk.extend(other.time_chunk)
        return self

    @classmethod
    def from_raw(cls, raw_data: List[dict]) -> "VADResponse":
        """Build a VADResponse from raw VAD output.

        Args:
            raw_data: A list of dicts, each with a ``'key'`` and a
                ``'value'`` holding ``[start, end]`` pairs, e.g.
                ``[{'key': 'xxx', 'value': [[-1, 59540], [59820, -1]]}]``.

        Returns:
            A response whose ``results`` mirror ``raw_data`` and whose
            ``time_chunk`` is the concatenation of all segments in order.
        """
        results = []
        time_chunk = []
        for item in raw_data:
            segments = [
                VADSegment(start=seg[0], end=seg[1])
                for seg in item['value']
            ]
            results.append(VADResult(key=item['key'], value=segments))
            time_chunk.extend(segments)
        return cls(results=results, time_chunk=time_chunk)

    def to_raw(self) -> List[dict]:
        """Convert ``results`` back to the raw list-of-dicts format.

        Returns:
            One ``{'key': ..., 'value': [[start, end], ...]}`` dict per
            result, in order.
        """
        return [
            {
                'key': result.key,
                'value': [[seg.start, seg.end] for seg in result.value],
            }
            for result in self.results
        ]

    def __str__(self):
        """Return a header line followed by one ``[start:end]`` line per segment."""
        lines = [
            f"[{segment.start}:{segment.end}]\n"
            for result in self.results
            for segment in result.value
        ]
        return "VADResponse:\n" + "".join(lines)

    def __iter__(self):
        """Return a fresh iterator over ``time_chunk``."""
        return iter(self.time_chunk)

    def __next__(self):
        """Return the next time chunk not yet handed out by ``next()``.

        The position is tracked separately from ``__iter__`` and from
        ``time_chunk_index``.

        Raises:
            StopIteration: When no further chunk is currently available.
        """
        if self._next_index >= len(self.time_chunk):
            raise StopIteration
        chunk = self.time_chunk[self._next_index]
        self._next_index += 1
        return chunk

    def __len__(self):
        """Return the number of time chunks."""
        return len(self.time_chunk)

    def __getitem__(self, index):
        """Return the time chunk (or slice of chunks) at ``index``."""
        return self.time_chunk[index]