STT_Server/src/functor/resultbinder_functor.py

166 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
ResultBinderFunctor
负责聚合结果将所有input_queue中的结果进行聚合并进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List, Dict, Any
from queue import Queue, Empty
import threading
import time
# 日志
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ResultBinderFunctor(BaseFunctor):
    """
    Aggregate the results of several upstream queues and hand them to callbacks.

    One item is taken from every registered input queue, the items are merged
    into a single result dictionary, and that dictionary is passed to each
    registered callback.
    """

    def __init__(self) -> None:
        super().__init__()
        # Resources and configuration.
        self._callback: List[Callable] = []  # callbacks invoked with each merged result
        self._input_queue: Dict[str, Queue] = {}  # input queues keyed by name
        self._audio_config: AudioBinary_Config = None  # unused here; kept for interface parity

    def reset_cache(self) -> None:
        """
        Reset cached state after a task finishes.

        This functor keeps no cache, so there is nothing to do.
        """
        pass

    def add_input_queue(
        self,
        name: str,
        queue: Queue,
    ) -> None:
        """
        Register an input queue to listen on.

        Args:
            name: Key under which the item taken from ``queue`` appears in the
                dictionary passed to ``_process``. ``_process`` reads the keys
                ``"asr"`` and ``"spk"``.
            queue: Queue to consume from. A queue already registered under the
                same name is replaced.
        """
        self._input_queue[name] = queue

    def set_model(self, model: dict) -> None:
        """
        Store an inference model.

        This functor is not supposed to have a model; a warning is logged and
        the value is stored anyway.
        """
        logger.warning("ResultBinderFunctor不应设置模型")
        self._model = model

    def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
        """
        Store an audio configuration.

        This functor is not supposed to have an audio configuration; a warning
        is logged and the value is stored anyway.
        """
        logger.warning("ResultBinderFunctor不应设置音频配置")
        self._audio_config = audio_config

    def add_callback(self, callback: Callable) -> None:
        """
        Append ``callback`` to the list of callbacks.

        If ``_callback`` is not currently a list, it is replaced by a fresh
        list first.
        """
        if not isinstance(self._callback, list):
            self._callback = []
        self._callback.append(callback)

    def _do_callback(self, result: Dict[str, Any]) -> None:
        """
        Invoke every registered callback with ``result``, in registration order.

        An exception raised by a callback is not caught here.
        """
        for callback in self._callback:
            callback(result)

    def _process(self, data: Dict[str, Any]) -> None:
        """
        Merge one item per input queue into a result and dispatch it.

        Args:
            data: Mapping of input-queue name to the item taken from that
                queue. Must contain ``"asr"`` and ``"spk"``, and
                ``data["spk"]`` must contain ``"speaker_id"``.

        The dispatched result has the form::

            {
                "is_final": False,
                "mode": "2pass-offline",
                "text": <data["asr"]>,
                "wav_name": "h5",
                "speaker_id": <data["spk"]["speaker_id"]>,
            }

        Raises:
            KeyError: If one of the required keys is missing.
        """
        logger.debug("ResultBinderFunctor处理数据: %s", data)
        # "is_final", "mode" and "wav_name" are fixed placeholder values for now.
        results = {
            "is_final": False,
            "mode": "2pass-offline",
            "text": data["asr"],
            "wav_name": "h5",
            "speaker_id": data["spk"]["speaker_id"],
        }
        self._do_callback(results)

    def _run(self) -> None:
        """
        Worker-thread loop.

        A batch is processed only when every input queue has an item. If any
        queue is empty, the loop sleeps 0.1 s and retries, or exits if a stop
        was requested. Any other exception is logged and re-raised, which ends
        the thread.
        """
        with self._status_lock:
            self._is_running = True
            self._stop_event = False
        try:
            while self._is_running:
                try:
                    # Wait until every queue has something to contribute.
                    if any(queue.empty() for queue in self._input_queue.values()):
                        time.sleep(0.1)
                        raise Empty
                    data = {}
                    for name, queue in self._input_queue.items():
                        data[name] = queue.get(True, timeout=1)
                        queue.task_done()
                    self._process(data)
                except Empty:
                    # Nothing to do right now: leave only if a stop was requested.
                    if self._stop_event:
                        break
                    continue
                except Exception as e:
                    logger.error("ResultBinderFunctor运行时发生错误: %s", e)
                    raise
        finally:
            # Keep the running flag truthful even if the loop died on an error.
            with self._status_lock:
                self._is_running = False

    def run(self) -> threading.Thread:
        """
        Validate the configuration and start the worker thread.

        Returns:
            threading.Thread: The started daemon thread.

        Raises:
            ValueError: If no input queue or no callback has been registered.
        """
        self._pre_check()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self._thread

    def _pre_check(self) -> bool:
        """
        Check that the functor is ready to run.

        No model is required, because this functor is not supposed to have
        one. At least one input queue and one callback must be registered.

        Returns:
            bool: Always ``True`` when the checks pass.

        Raises:
            ValueError: If no input queue or no callback has been registered.
        """
        if not self._input_queue:
            raise ValueError("输入队列未设置")
        if not self._callback:
            raise ValueError("回调函数未设置")
        return True

    def stop(self) -> bool:
        """
        Request the worker thread to stop and wait for it to finish.

        The worker exits the next time it finds an input queue empty.

        Returns:
            bool: ``True`` if the worker is no longer alive, or was never
            started.
        """
        with self._status_lock:
            self._stop_event = True
        # run() may never have been called, so the thread may not exist.
        thread = getattr(self, "_thread", None)
        if thread is not None:
            thread.join()
        with self._status_lock:
            self._is_running = False
        return thread is None or not thread.is_alive()