146 lines
4.3 KiB
Python
146 lines
4.3 KiB
Python
"""
|
|
SpkFunctor
|
|
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
|
"""
|
|
from src.functor.base import BaseFunctor
|
|
from src.models import AudioBinary_Config, VAD_Functor_result
|
|
from typing import Callable, List
|
|
from queue import Queue, Empty
|
|
import threading
|
|
|
|
# 日志
|
|
from src.utils.logger import get_module_logger
|
|
|
|
logger = get_module_logger(__name__)
|
|
|
|
class SPKFunctor(BaseFunctor):
|
|
"""
|
|
SPKFunctor
|
|
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
|
需要配置好 _model, _callback, _input_queue, _audio_config
|
|
否则无法run()启动线程
|
|
|
|
运行中, 使用reset_cache()重置缓存, 准备下次任务
|
|
|
|
使用stop()停止线程, 但需要等待input_queue为空
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
# 资源与配置
|
|
self._model: dict = {} # 模型
|
|
self._callback: List[Callable] = [] # 回调函数
|
|
self._input_queue: Queue = None # 输入队列
|
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
|
|
|
|
|
def reset_cache(self) -> None:
|
|
"""
|
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
|
"""
|
|
pass
|
|
|
|
def set_input_queue(self, queue: Queue) -> None:
|
|
"""
|
|
设置监听的输入消息队列
|
|
"""
|
|
self._input_queue = queue
|
|
|
|
def set_model(self, model: dict) -> None:
|
|
"""
|
|
设置推理模型
|
|
"""
|
|
self._model = model
|
|
|
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
|
"""
|
|
设置音频配置
|
|
"""
|
|
self._audio_config = audio_config
|
|
logger.debug("SpkFunctor设置音频配置: %s", self._audio_config)
|
|
|
|
def add_callback(self, callback: Callable) -> None:
|
|
"""
|
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
|
"""
|
|
if not isinstance(self._callback, list):
|
|
self._callback = []
|
|
self._callback.append(callback)
|
|
|
|
def _do_callback(self, result: List[str]) -> None:
|
|
"""
|
|
回调函数
|
|
"""
|
|
for callback in self._callback:
|
|
callback(result)
|
|
|
|
def _process(self, data: VAD_Functor_result) -> None:
|
|
"""
|
|
处理数据
|
|
"""
|
|
binary_data = data.audiobinary_data.binary_data
|
|
# result = self._model["spk"].generate(
|
|
# input=binary_data,
|
|
# chunk_size=self._audio_config.chunk_size,
|
|
# )
|
|
result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}]
|
|
self._do_callback(result)
|
|
|
|
def _run(self) -> None:
|
|
"""
|
|
线程运行逻辑
|
|
"""
|
|
with self._status_lock:
|
|
self._is_running = True
|
|
self._stop_event = False
|
|
# 运行逻辑
|
|
while self._is_running:
|
|
try:
|
|
data = self._input_queue.get(True, timeout=1)
|
|
self._process(data)
|
|
self._input_queue.task_done()
|
|
# 当队列为空时, 间隔1s检测是否进入停止事件。
|
|
except Empty:
|
|
if self._stop_event:
|
|
break
|
|
continue
|
|
# 其他异常
|
|
except Exception as e:
|
|
logger.error("SpkFunctor运行时发生错误: %s", e)
|
|
raise e
|
|
|
|
def run(self) -> threading.Thread:
|
|
"""
|
|
启动线程
|
|
Returns:
|
|
threading.Thread: 返回已运行线程实例
|
|
"""
|
|
self._pre_check()
|
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
self._thread.start()
|
|
return self._thread
|
|
|
|
def _pre_check(self) -> bool:
|
|
"""
|
|
预检查
|
|
"""
|
|
if self._model is None:
|
|
raise ValueError("模型未设置")
|
|
if self._input_queue is None:
|
|
raise ValueError("输入队列未设置")
|
|
if self._callback is None:
|
|
raise ValueError("回调函数未设置")
|
|
return True
|
|
|
|
def stop(self) -> bool:
|
|
"""
|
|
停止线程
|
|
"""
|
|
with self._status_lock:
|
|
self._stop_event = True
|
|
self._thread.join()
|
|
with self._status_lock:
|
|
self._is_running = False
|
|
return not self._thread.is_alive()
|
|
|
|
|