From 3d8bf9de25e5957ebeb9e90683bc746370bca4fc Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Thu, 5 Jun 2025 17:08:42 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BB=A3=E7=A0=81=E9=87=8D=E6=9E=84=E4=B8=AD]?= =?UTF-8?q?=E5=88=9B=E5=BB=BA=E5=81=87=E7=9A=84SPKFunctor=E4=BB=A5?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E6=B6=88=E6=81=AF=E9=98=9F=E5=88=97=E6=B5=81?= =?UTF-8?q?=E7=A8=8B=E6=98=AF=E5=90=A6=E6=AD=A3=E7=A1=AE=EF=BC=8C=E6=97=A0?= =?UTF-8?q?=E9=97=AE=E9=A2=98=EF=BC=8C=E5=BE=85=E8=BF=9B=E4=B8=80=E6=AD=A5?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E8=AF=B4=E8=AF=9D=E4=BA=BA=E8=AF=86=E5=88=AB?= =?UTF-8?q?=EF=BC=8C=E6=AD=A4=E5=A4=96=EF=BC=8C=E8=80=83=E8=99=91=E5=B0=86?= =?UTF-8?q?=E4=B8=80=E4=BA=9B=E5=85=B1=E6=9C=89=E5=86=85=E5=AE=B9=E5=86=99?= =?UTF-8?q?=E5=85=A5BaseFunctor=E4=B8=AD=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/functor/asr_functor.py | 8 +- src/functor/base.py | 10 ++- src/functor/spk_functor.py | 146 +++++++++++++++++++++++++++++++++++++ tests/functor/vad_test.py | 28 +++++-- 4 files changed, 182 insertions(+), 10 deletions(-) create mode 100644 src/functor/spk_functor.py diff --git a/src/functor/asr_functor.py b/src/functor/asr_functor.py index e5a93be..9f766a4 100644 --- a/src/functor/asr_functor.py +++ b/src/functor/asr_functor.py @@ -99,7 +99,7 @@ class ASRFunctor(BaseFunctor): ) self._do_callback(result) - def _run(self): + def _run(self) -> None: """ 线程运行逻辑 """ @@ -122,9 +122,11 @@ class ASRFunctor(BaseFunctor): logger.error("ASRFunctor运行时发生错误: %s", e) raise e - def run(self): + def run(self) -> threading.Thread: """ 启动线程 + Returns: + threading.Thread: 返回已运行线程实例 """ self._pre_check() self._thread = threading.Thread(target=self._run, daemon=True) @@ -145,7 +147,7 @@ class ASRFunctor(BaseFunctor): raise ValueError("回调函数未设置") return True - def stop(self): + def stop(self) -> bool: """ 停止线程 """ diff --git a/src/functor/base.py b/src/functor/base.py index 3754154..b0f8eb4 100644 --- a/src/functor/base.py +++ b/src/functor/base.py @@ -14,6 +14,7 @@ Functor基础模块 from abc import ABC, abstractmethod from typing import Callable, List from queue import Queue +import threading class BaseFunctor(ABC): """ @@ -37,7 +38,14 @@ class BaseFunctor(ABC): model (dict): 模型相关的配置和实例 """ self._callback: List[Callable] = [] - self._model: dict = {} + self._model: dict = {} + # flag + self._is_running: bool = False + self._stop_event: bool = False + # 状态锁 + self._status_lock: threading.Lock = threading.Lock() + # 线程资源 + self._thread: threading.Thread = None def add_callback(self, callback: Callable): """ diff --git a/src/functor/spk_functor.py b/src/functor/spk_functor.py new file mode 100644 index 0000000..8d0fa7a --- /dev/null +++ b/src/functor/spk_functor.py @@ -0,0 +1,146 @@ +""" +SpkFunctor +负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback +""" +from src.functor.base import BaseFunctor +from src.models import AudioBinary_Config, VAD_Functor_result +from typing import Callable, List +from queue import Queue, Empty +import threading + +# 日志 +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__) + +class SpkFunctor(BaseFunctor): + """ + SpkFunctor + 负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback + 需要配置好 _model, _callback, _input_queue, _audio_config + 否则无法run()启动线程 + + 运行中, 使用reset_cache()重置缓存, 准备下次任务 + + 使用stop()停止线程, 但需要等待input_queue为空 + """ + + def __init__(self) -> None: + super().__init__() + # 资源与配置 + self._model: dict = {} # 模型 + self._callback: List[Callable] = [] # 回调函数 + self._input_queue: Queue = None # 输入队列 + self._audio_config: AudioBinary_Config = None # 音频配置 + + + def reset_cache(self) -> None: + """ + 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 + """ + pass + + def set_input_queue(self, queue: Queue) -> None: + """ + 设置监听的输入消息队列 + """ + self._input_queue = queue + + def set_model(self, model: dict) -> None: + """ + 设置推理模型 + """ + self._model = model + + def set_audio_config(self, audio_config: AudioBinary_Config) -> None: + """ + 设置音频配置 + """ + self._audio_config = audio_config + logger.debug("SpkFunctor设置音频配置: %s", self._audio_config) + + def add_callback(self, callback: Callable) -> None: + """ + 向自身的_callback: List[Callable]回调函数列表中添加回调函数 + """ + if not isinstance(self._callback, list): + self._callback = [] + self._callback.append(callback) + + def _do_callback(self, result: List[str]) -> None: + """ + 回调函数 + """ + for callback in self._callback: + callback(result) + + def _process(self, data: VAD_Functor_result) -> None: + """ + 处理数据 + """ + binary_data = data.audiobinary_data.binary_data + # result = self._model["spk"].generate( + # input=binary_data, + # chunk_size=self._audio_config.chunk_size, + # ) + result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}] + self._do_callback(result) + + def _run(self) -> None: + """ + 线程运行逻辑 + """ + with self._status_lock: + self._is_running = True + self._stop_event = False + # 运行逻辑 + while self._is_running: + try: + data = self._input_queue.get(True, timeout=1) + self._process(data) + self._input_queue.task_done() + # 当队列为空时, 间隔1s检测是否进入停止事件。 + except Empty: + if self._stop_event: + break + continue + # 其他异常 + except Exception as e: + logger.error("SpkFunctor运行时发生错误: %s", e) + raise e + + def run(self) -> threading.Thread: + """ + 启动线程 + Returns: + threading.Thread: 返回已运行线程实例 + """ + self._pre_check() + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + return self._thread + + def _pre_check(self) -> bool: + """ + 预检查 + """ + if self._model is None: + raise ValueError("模型未设置") + if self._input_queue is None: + raise ValueError("输入队列未设置") + if self._callback is None: + raise ValueError("回调函数未设置") + return True + + def stop(self) -> bool: + """ + 停止线程 + """ + with self._status_lock: + self._stop_event = True + self._thread.join() + with self._status_lock: + self._is_running = False + return not self._thread.is_alive() + + \ No newline at end of file diff --git a/tests/functor/vad_test.py b/tests/functor/vad_test.py index 7a7df37..ccd2b02 100644 --- a/tests/functor/vad_test.py +++ b/tests/functor/vad_test.py @@ -4,6 +4,7 @@ VAD测试 """ from src.functor.vad_functor import VADFunctor from src.functor.asr_functor import ASRFunctor +from src.functor.spk_functor import SpkFunctor from queue import Queue, Empty from src.model_loader import ModelLoader from src.models import AudioBinary_Config, AudioBinary_data_list @@ -44,6 +45,7 @@ def test_vad_functor(): # 创建输入队列 input_queue = Queue() vad2asr_queue = Queue() + vad2spk_queue = Queue() # 创建音频数据列表 audio_binary_data_list = AudioBinary_data_list() @@ -58,6 +60,7 @@ def test_vad_functor(): # 设置回调函数 vad_functor.add_callback(lambda x: print(f"vad callback: {x}")) vad_functor.add_callback(lambda x: vad2asr_queue.put(x)) + vad_functor.add_callback(lambda x: vad2spk_queue.put(x)) # 设置模型 vad_functor.set_model({ 'vad': model_loader.models['vad'] @@ -80,6 +83,23 @@ def test_vad_functor(): # 启动ASR函数器 asr_functor.run() + # 创建SPK函数器 + spk_functor = SpkFunctor() + # 设置输入队列 + spk_functor.set_input_queue(vad2spk_queue) + # 设置音频配置 + spk_functor.set_audio_config(audio_config) + # 设置回调函数 + spk_functor.add_callback(lambda x: print(f"spk callback: {x}")) + # 设置模型 + spk_functor.set_model({ + # 'spk': model_loader.models['spk'] + 'spk': 'fake_spk' + }) + # 启动SPK函数器 + spk_functor.run() + + f_binary = f_data audio_clip_len = 200 print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}") @@ -87,15 +107,11 @@ def test_vad_functor(): binary_data = f_binary[i:i+audio_clip_len] input_queue.put(binary_data) # 等待VAD函数器结束 - print("[vad_test] 等待input_queue为空") - input_queue.join() - print("[vad_test] input_queue为空") + + vad_functor.stop() print("[vad_test] VAD函数器结束") - print("[vad_test] 等待vad2asr_queue为空") - vad2asr_queue.join() - print("[vad_test] vad2asr_queue为空") asr_functor.stop() print("[vad_test] ASR函数器结束")