143 lines
4.8 KiB
Python
143 lines
4.8 KiB
Python
from pydantic import BaseModel, Field, validator
|
|
from typing import List, Optional, Callable
|
|
|
|
class VADSegment(BaseModel):
    """A single voice-activity segment described by its two boundaries."""

    # Both boundaries are required; the field descriptions state they are in ms.
    start: int = Field(..., description="开始时间(ms)")
    end: int = Field(..., description="结束时间(ms)")
|
|
|
|
class VADResult(BaseModel):
    """The VAD segments reported for one audio key."""

    # Identifier of the audio the segments belong to (required).
    key: str = Field(..., description="音频标识")
    # Segments detected for that audio (required).
    value: List[VADSegment] = Field(..., description="VAD片段列表")
|
|
|
|
class VADResponse(BaseModel):
    """Accumulated VAD output plus a cursor over its flat list of time chunks.

    ``results`` groups segments per audio key; ``time_chunk`` is the flat,
    ordered list of segments.  ``time_chunk_index`` is the position of the
    next chunk that has not been handed out yet (used by
    ``process_time_chunk`` and ``__next__``).

    NOTE(review): a value of ``-1`` in ``start``/``end`` appears to mark a
    boundary that is not known yet (see the ``from_raw`` example and the
    ``end != -1`` check in ``process_time_chunk``) -- confirm with the
    producer of the raw data.
    """

    results: List[VADResult] = Field(description="VAD结果列表", default_factory=list)
    time_chunk: List[VADSegment] = Field(description="时间块", default_factory=list)
    time_chunk_index: int = Field(description="当前处理时间块索引", default=0)
    # Fallback callback for process_time_chunk.  It is invoked with a dict
    # ``{"start_time": ..., "end_time": ...}``, hence ``Callable[[dict], None]``.
    time_chunk_index_callback: Optional[Callable[[dict], None]] = Field(
        description="时间块索引回调函数",
        default=None
    )

    @validator('time_chunk')
    def validate_time_chunk(cls, v):
        """Check that consecutive time chunks are strictly ordered.

        Args:
            v: The list of chunks being validated.

        Returns:
            ``v`` unchanged when it is empty or correctly ordered.

        Raises:
            ValueError: If a chunk's ``end`` is greater than or equal to the
                ``start`` of the chunk that follows it.
        """
        # zip() over an empty or single-element list yields nothing, so those
        # inputs are accepted without a separate guard.
        for i, (current, following) in enumerate(zip(v, v[1:])):
            if current.end >= following.start:
                raise ValueError(f"时间块{i}的结束时间({current.end})大于等于下一个时间块的开始时间({following.start})")
        return v

    def process_time_chunk(self, callback: Optional[Callable[[dict], None]] = None) -> None:
        """Report every not-yet-processed chunk except the last one.

        Starting at ``time_chunk_index``, each chunk before the final entry
        of ``time_chunk`` is visited once.  Chunks whose ``end`` is ``-1``
        are skipped without a callback.  The final chunk is never reported
        here, so it stays available for a later merge via ``+``.

        Args:
            callback: Called with ``{"start_time": ..., "end_time": ...}``
                for each reported chunk.  When ``None``, the instance's
                ``time_chunk_index_callback`` is used instead; when that is
                also ``None`` a warning is printed for each chunk.
        """
        sink = callback if callback is not None else self.time_chunk_index_callback
        while self.time_chunk_index < len(self.time_chunk) - 1:
            chunk = self.time_chunk[self.time_chunk_index]
            if chunk.end != -1:
                payload = {
                    "start_time": chunk.start,
                    "end_time": chunk.end
                }
                if sink is not None:
                    sink(payload)
                else:
                    print("[Warning] No callback available")
            # Advance only after the callback returned, so a chunk whose
            # callback raised is offered again on the next call.
            self.time_chunk_index += 1

    def __add__(self, other: 'VADResponse') -> 'VADResponse':
        """Merge ``other`` into this response in place and return ``self``.

        If the last segment of ``self`` ends exactly where the first segment
        of ``other`` starts, the two are fused into one segment (the same is
        done for the last/first entries of ``time_chunk``).  Everything that
        remains in ``other`` is appended.

        Note that both operands are modified: ``self`` grows, and the fused
        first segment/chunk is removed from ``other``.

        Args:
            other: The response whose results and chunks are appended.

        Returns:
            ``self`` after the merge.
        """
        if not self.results:
            # Nothing accumulated yet: simply adopt the other side's lists.
            self.results = other.results
            self.time_chunk = other.time_chunk
            return self

        if not other.results:
            # No results to merge (indexing other.results[0] would fail);
            # still keep whatever chunks the other side carries.
            self.time_chunk.extend(other.time_chunk)
            return self

        # Check whether the boundary segments can be fused.
        last_result = self.results[-1]
        first_other = other.results[0]
        touching = (
            bool(last_result.value)
            and bool(first_other.value)
            and last_result.value[-1].end == first_other.value[0].start
        )

        if touching:
            # Fuse the adjacent segments.
            last_result.value[-1].end = first_other.value[0].end
            first_other.value.pop(0)

            # Keep time_chunk in step with the fused segment.
            if self.time_chunk and other.time_chunk:
                self.time_chunk[-1].end = other.time_chunk[0].end
                other.time_chunk.pop(0)

            # If the fusion emptied the first result, drop only that result;
            # any further results of ``other`` must still be appended.
            if first_other.value:
                self.results.extend(other.results)
            else:
                self.results.extend(other.results[1:])
        else:
            # No shared boundary: append everything as is.
            self.results.extend(other.results)

        self.time_chunk.extend(other.time_chunk)
        return self

    @classmethod
    def from_raw(cls, raw_data: List[dict]) -> "VADResponse":
        """Build a VADResponse from raw data.

        Args:
            raw_data: Raw data, shaped like
                ``[{'key': 'xxx', 'value': [[-1, 59540], [59820, -1]]}]``.

        Returns:
            VADResponse: The parsed response, with one ``VADResult`` per
            item and all segments also listed, in order, in ``time_chunk``.
        """
        results = []
        time_chunk = []
        for item in raw_data:
            segments = [
                VADSegment(start=seg[0], end=seg[1])
                for seg in item['value']
            ]
            results.append(VADResult(
                key=item['key'],
                value=segments
            ))
            time_chunk.extend(segments)
        return cls(results=results, time_chunk=time_chunk)

    def to_raw(self) -> List[dict]:
        """Convert back to the raw data format.

        Returns:
            List[dict]: One ``{'key': ..., 'value': [[start, end], ...]}``
            entry per result.
        """
        return [
            {
                'key': result.key,
                'value': [[seg.start, seg.end] for seg in result.value]
            }
            for result in self.results
        ]

    def __str__(self):
        """Return a header line followed by one ``[start:end]`` line per segment."""
        parts = ["VADResponse:\n"]
        parts.extend(
            f"[{segment.start}:{segment.end}]\n"
            for result in self.results
            for segment in result.value
        )
        return "".join(parts)

    def __iter__(self):
        """Iterate over the chunks in ``time_chunk`` from the beginning."""
        return iter(self.time_chunk)

    def __next__(self):
        """Return the chunk at ``time_chunk_index`` and advance the cursor.

        ``time_chunk`` is a list, so calling ``next()`` on it directly raises
        ``TypeError``; the existing cursor field is used instead.  The cursor
        is shared with ``process_time_chunk``, so a chunk obtained here is
        not reported again by that method.

        Raises:
            StopIteration: When no chunk is left at or after the cursor.
        """
        if self.time_chunk_index >= len(self.time_chunk):
            raise StopIteration
        chunk = self.time_chunk[self.time_chunk_index]
        self.time_chunk_index += 1
        return chunk

    def __len__(self) -> int:
        """Return the number of chunks in ``time_chunk``."""
        return len(self.time_chunk)

    def __getitem__(self, index):
        """Return ``time_chunk[index]`` (int or slice, as for a list)."""
        return self.time_chunk[index]