166 lines
4.9 KiB
Python
166 lines
4.9 KiB
Python
"""
|
||
ResultBinderFunctor
|
||
负责聚合结果,将所有input_queue中的结果进行聚合,并进行callback
|
||
"""
|
||
|
||
from src.functor.base import BaseFunctor
|
||
from src.models import AudioBinary_Config, VAD_Functor_result
|
||
from typing import Callable, List, Dict, Any
|
||
from queue import Queue, Empty
|
||
import threading
|
||
import time
|
||
# 日志
|
||
from src.utils.logger import get_module_logger
|
||
|
||
logger = get_module_logger(__name__)
|
||
|
||
|
||
class ResultBinderFunctor(BaseFunctor):
|
||
"""
|
||
ResultBinderFunctor
|
||
负责聚合结果,将所有input_queue中的结果进行聚合,并进行callback
|
||
"""
|
||
|
||
def __init__(self) -> None:
|
||
super().__init__()
|
||
# 资源与配置
|
||
self._callback: List[Callable] = [] # 回调函数
|
||
self._input_queue: Dict[str, Queue] = {} # 输入队列
|
||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||
|
||
def reset_cache(self) -> None:
|
||
"""
|
||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||
"""
|
||
pass
|
||
|
||
def add_input_queue(self,
|
||
name: str,
|
||
queue: Queue,
|
||
) -> None:
|
||
"""
|
||
设置监听的输入消息队列
|
||
"""
|
||
self._input_queue[name] = queue
|
||
|
||
def set_model(self, model: dict) -> None:
|
||
"""
|
||
设置推理模型
|
||
resultbinder_functor 不应设置模型
|
||
"""
|
||
logger.warning("ResultBinderFunctor不应设置模型")
|
||
self._model = model
|
||
|
||
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||
"""
|
||
设置音频配置
|
||
resultbinder_functor 不应设置音频配置
|
||
"""
|
||
logger.warning("ResultBinderFunctor不应设置音频配置")
|
||
self._audio_config = audio_config
|
||
|
||
def add_callback(self, callback: Callable) -> None:
|
||
"""
|
||
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||
"""
|
||
if not isinstance(self._callback, list):
|
||
self._callback = []
|
||
self._callback.append(callback)
|
||
|
||
def _do_callback(self, result: List[str]) -> None:
|
||
"""
|
||
回调函数
|
||
"""
|
||
for callback in self._callback:
|
||
callback(result)
|
||
|
||
def _process(self, data: Dict[str, Any]) -> None:
|
||
"""
|
||
处理数据
|
||
{
|
||
"is_final": false,
|
||
"mode": "2pass-offline",
|
||
"text": ",等一下我回一下,ok你看这里就有了",
|
||
"wav_name": "h5",
|
||
"speaker_id":
|
||
}
|
||
"""
|
||
logger.debug("ResultBinderFunctor处理数据: %s", data)
|
||
# 将data中的result进行聚合
|
||
# 此步暂时无意义,预留
|
||
results = {
|
||
"is_final": False,
|
||
"mode": "2pass-offline",
|
||
"text": data["asr"],
|
||
"wav_name": "h5",
|
||
"speaker_id": data["spk"]["speaker_id"]
|
||
}
|
||
|
||
# for name, result in data.items():
|
||
# results[name] = result
|
||
self._do_callback(results)
|
||
|
||
def _run(self) -> None:
|
||
"""
|
||
线程运行逻辑
|
||
"""
|
||
with self._status_lock:
|
||
self._is_running = True
|
||
self._stop_event = False
|
||
# 运行逻辑
|
||
while self._is_running:
|
||
try:
|
||
# 若有队列为空,则等待0.1s
|
||
for name, queue in self._input_queue.items():
|
||
if queue.empty():
|
||
time.sleep(0.1)
|
||
raise Empty
|
||
data = {}
|
||
for name, queue in self._input_queue.items():
|
||
data[name] = queue.get(True, timeout=1)
|
||
queue.task_done()
|
||
self._process(data)
|
||
# 当队列为空时, 检测是否进入停止事件。
|
||
except Empty:
|
||
if self._stop_event:
|
||
break
|
||
continue
|
||
# 其他异常
|
||
except Exception as e:
|
||
logger.error("SpkFunctor运行时发生错误: %s", e)
|
||
raise e
|
||
|
||
def run(self) -> threading.Thread:
|
||
"""
|
||
启动线程
|
||
Returns:
|
||
threading.Thread: 返回已运行线程实例
|
||
"""
|
||
self._pre_check()
|
||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||
self._thread.start()
|
||
return self._thread
|
||
|
||
def _pre_check(self) -> bool:
|
||
"""
|
||
预检查
|
||
"""
|
||
if self._model is None:
|
||
raise ValueError("模型未设置")
|
||
if self._input_queue is None:
|
||
raise ValueError("输入队列未设置")
|
||
if self._callback is None:
|
||
raise ValueError("回调函数未设置")
|
||
return True
|
||
|
||
def stop(self) -> bool:
|
||
"""
|
||
停止线程
|
||
"""
|
||
with self._status_lock:
|
||
self._stop_event = True
|
||
self._thread.join()
|
||
with self._status_lock:
|
||
self._is_running = False
|
||
return not self._thread.is_alive()
|