diff --git a/src/functor/base.py b/src/functor/base.py index a5b345c..221401b 100644 --- a/src/functor/base.py +++ b/src/functor/base.py @@ -2,9 +2,9 @@ from typing import Callable class BaseFunctor: """ - 基础函数器类,提供数据处理的基本框架 + 基础函数器类, 提供数据处理的基本框架 - 该类实现了数据处理的基本接口,包括数据推送、处理和回调机制。 + 该类实现了数据处理的基本接口, 包括数据推送、处理和回调机制。 所有具体的功能实现类都应该继承这个基类。 属性: @@ -22,7 +22,7 @@ class BaseFunctor: 初始化函数器 参数: - data (dict or bytes): 初始数据,可以是字典或字节数据 + data (dict or bytes): 初始数据, 可以是字典或字节数据 callback (Callable): 处理完成后的回调函数 model (dict): 模型相关的配置和实例 """ @@ -34,33 +34,33 @@ class BaseFunctor: def __call__(self, data = None): """ - 使类实例可调用,处理数据并触发回调 + 使类实例可调用, 处理数据并触发回调 参数: - data: 要处理的数据,如果为None则处理已存储的数据 + data: 要处理的数据, 如果为None则处理已存储的数据 返回: 处理结果 """ - # 如果传入数据,则压入数据 + # 如果传入数据, 则压入数据 if data is not None: self.push_data(data) # 处理数据 result = self.process() - # 如果回调函数存在,则触发回调 + # 如果回调函数存在, 则触发回调 if self._callback is not None and callable(self._callback): self._callback(result) return result def __add__(self, other): """ - 重载加法运算符,用于合并数据 + 重载加法运算符, 用于合并数据 参数: other: 要合并的数据 返回: - self: 返回当前实例,支持链式调用 + self: 返回当前实例, 支持链式调用 """ self.push_data(other) return self diff --git a/src/pipeline/ASRpipeline.py b/src/pipeline/ASRpipeline.py new file mode 100644 index 0000000..7d215a3 --- /dev/null +++ b/src/pipeline/ASRpipeline.py @@ -0,0 +1,190 @@ +from src.pipeline.base import PipelineBase +from typing import Dict, Any +from queue import Queue +from src.utils import get_module_logger + +logger = get_module_logger(__name__) + +class ASRPipeline(PipelineBase): + """ + 管道类 + 实现具体的处理逻辑 + """ + def __init__(self, *args, **kwargs): + """ + 初始化管道 + """ + super().__init__(*args, **kwargs) + self._config: Dict[str, Any] = {} + self._funtor_dict: Dict[str, Any] = {} + self._subqueue_dict: Dict[str, Any] = {} + + self._is_baked: bool = False + + def set_config(self, config: Dict[str, Any]) -> None: + """ + 设置配置 + 参数: + config: Dict[str, Any] 配置 + """ + self._config = config + + def set_audio_binary(self, audio_binary: AudioBinary) -> None: + """ + 设置音频二进制存储单元 + 参数: + audio_binary: 音频二进制 + """ + self._audio_binary = audio_binary + + def set_models(self, models: Dict[str, Any]) -> None: + """ + 设置模型 + """ + self._models = models + + def bake(self) -> None: + """ + 烘焙管道 + """ + self._init_funtor() + self._is_baked = True + + def _init_funtor(self) -> None: + """ + 初始化函数 + """ + try: + from src.funtor import FuntorFactory + # 加载VAD、asr、spk funtor + self._funtor_dict["vad"] = FuntorFactory.get_funtor( + funtor_name = "vad", + config = self._config, + models = self._models + ) + self._funtor_dict["asr"] = FuntorFactory.get_funtor( + funtor_name = "asr", + config = self._config, + models = self._models + ) + self._funtor_dict["spk"] = FuntorFactory.get_funtor( + funtor_name = "spk", + config = self._config, + models = self._models + ) + + # 初始化子队列 + self._subqueue_dict["vad2asr"] = Queue() + self._subqueue_dict["vad2spk"] = Queue() + self._subqueue_dict["asrend"] = Queue() + self._subqueue_dict["spkend"] = Queue() + + # 设置子队列的输入队列 + self._funtor_dict["vad"].set_input_queue(self._input_queue) + self._funtor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) + self._funtor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) + + # 设置回调函数——放置到对应队列中 + self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put) + self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put) + + # 构造带回调函数的put + def put_with_checkcallback(queue: Queue, callback: Callable) -> None: + """ + 带回调函数的put + """ + def put_with_check(data: Any) -> None: + queue.put(data) + callback() + return put_with_check + + self._funtor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) + self._funtor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result)) + + except ImportError: + raise ImportError("FuntorFactory引入失败,ASRPipeline无法完成初始化") + + def get_config(self) -> Dict[str, Any]: + """ + 获取配置 + 返回: + Dict[str, Any] 配置 + """ + return self._config + + def process(self, data: Any) -> Any: + """ + 处理数据 + 参数: + data: 输入数据 + 返回: + 处理结果 + """ + # 子类实现具体的处理逻辑 + self._input_queue.put(data) + + def run(self) -> None: + """ + 运行管道 + """ + if not self._is_baked: + raise RuntimeError("管道未烘焙,无法运行") + + # 运行所有funtor + for funtor_name, funtor in self._funtor_dict.items(): + logger.info(f"运行{funtor_name}funtor") + funtor.run() + + # 运行管道 + if not self._input_queue: + raise RuntimeError("输入队列未设置") + + # 设置管道运行状态 + self._is_running = True + self._stop_event = False + self._thread = threading.current_thread() + logger.info("ASR管道开始运行") + + while self._is_running and not self._stop_event: + try: + # 从队列获取数据 + try: + data = self._input_queue.get(timeout=self._queue_timeout) + # 检查是否是结束信号 + if data is None: + logger.info("收到结束信号,管道准备停止") + self._input_queue.task_done() # 标记结束信号已处理 + break + + # 处理数据 + self.process(data) + + # 标记任务完成 + self._input_queue.task_done() + + except Empty: + # 队列获取超时,继续等待 + continue + + except Exception as e: + logger.error(f"管道处理数据出错: {str(e)}") + continue + + logger.info("管道停止运行") + + def _check_result(self, result: Any) -> None: + """ + 检查结果 + """ + # 若asr和spk队列中都有数据,则合并数据 + if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize(): + asr_data = self._subqueue_dict["asrend"].get() + spk_data = self._subqueue_dict["spkend"].get() + # 合并数据 + result = { + "asr_data": asr_data, + "spk_data": spk_data + } + # 通知回调函数 + self._notify_callbacks(result) + diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py index c2f90bb..dd78c84 100644 --- a/src/pipeline/__init__.py +++ b/src/pipeline/__init__.py @@ -1,3 +1,3 @@ -from src.pipeline.base import PipelineBase, Pipeline +from src.pipeline.base import PipelineBase, Pipeline, PipelineFactory -__all__ = ["PipelineBase", "Pipeline"] +__all__ = ["PipelineBase", "Pipeline", "PipelineFactory"] diff --git a/src/pipeline/base.py b/src/pipeline/base.py index 72b508f..e348f1b 100644 --- a/src/pipeline/base.py +++ b/src/pipeline/base.py @@ -1,21 +1,128 @@ from abc import ABC, abstractmethod +from queue import Queue, Empty +from typing import List, Callable, Any, Optional +import logging +import threading +import time +# 配置日志 +logger = logging.getLogger(__name__) class PipelineBase(ABC): """ 管道基类 + 定义了管道的基本接口和通用功能 """ + def __init__(self, input_queue: Optional[Queue] = None): + """ + 初始化管道 + 参数: + input_queue: 输入队列,用于接收数据 + """ + self._input_queue = input_queue + self._callbacks: List[Callable] = [] + self._is_running = False + self._stop_event = False + self._thread: Optional[threading.Thread] = None + self._stop_timeout = 5 # 默认停止超时时间(秒) + self._queue_timeout = 1 # 队列获取超时时间(秒) + + def set_input_queue(self, queue: Queue) -> None: + """ + 设置输入队列 + 参数: + queue: 输入队列 + """ + self._input_queue = queue + + def add_callback(self, callback: Callable) -> None: + """ + 添加回调函数 + 参数: + callback: 回调函数,接收处理结果 + """ + self._callbacks.append(callback) + + def _notify_callbacks(self, result: Any) -> None: + """ + 通知所有回调函数 + 参数: + result: 处理结果 + """ + for callback in self._callbacks: + try: + callback(result) + except Exception as e: + logger.error(f"回调函数执行出错: {str(e)}") + @abstractmethod - def run(self, *args, **kwargs): + def process(self, data: Any) -> Any: + """ + 处理数据 + 参数: + data: 输入数据 + 返回: + 处理结果 + """ + pass + + @abstractmethod + def run(self) -> None: """ 运行管道 + 从输入队列获取数据并处理 """ + pass -class Pipeline(PipelineBase): - """ - 管道类 - """ - def __init__(self, *args, **kwargs): + def stop(self, timeout: Optional[float] = None) -> bool: """ + 停止管道 + 参数: + timeout: 停止超时时间(秒),None表示使用默认超时时间 + 返回: + bool: 是否成功停止 """ - pass \ No newline at end of file + if not self._is_running: + return True + + logger.info("正在停止管道...") + self._stop_event = True + self._is_running = False + + # 等待线程结束 + if self._thread and self._thread.is_alive(): + timeout = timeout if timeout is not None else self._stop_timeout + self._thread.join(timeout=timeout) + + # 检查是否成功停止 + if self._thread.is_alive(): + logger.warning(f"管道停止超时({timeout}秒),强制终止") + return False + else: + logger.info("管道已成功停止") + return True + + return True + + def force_stop(self) -> None: + """ + 强制停止管道 + 注意:这可能会导致资源未正确释放 + """ + logger.warning("强制停止管道") + self._stop_event = True + self._is_running = False + # 注意:Python的线程无法被强制终止,这里只是设置标志 + # 实际终止需要依赖操作系统的进程管理 + +class PipelineFactory: + """ + 管道工厂类 + 用于创建管道实例 + """ + @staticmethod + def create_pipeline(pipeline_name: str) -> Pipeline: + """ + 创建管道实例 + """ + pass diff --git a/src/pipeline/test.py b/src/pipeline/test.py new file mode 100644 index 0000000..e69de29 diff --git a/src/runner.py b/src/runner.py index 7eb76af..24dabde 100644 --- a/src/runner.py +++ b/src/runner.py @@ -6,129 +6,227 @@ - Runner: 运行器类,工厂模式实现 - RunnerFactory: 运行器工厂类,用于创建运行器 """ + from abc import ABC, abstractmethod -from typing import Dict, Any, List, Queue +from typing import Dict, Any, List +from threading import Thread, Lock +from queue import Queue +import traceback +import time + from src.audio_chunk import AudioChunk, AudioBinary -from src.pipeline import Pipeline +from src.pipeline import Pipeline, PipelineFactory from src.model_loader import ModelLoader +from src.utils.logger import get_module_logger + +logger = get_module_logger(__name__, level="INFO") audio_chunk = AudioChunk() models_loaded = ModelLoader() -pipelines_loaded = PipelineLoader() + class RunnerBase(ABC): """ 运行器基类 + 定义了运行器的基本接口 """ - # 计算资源 - _audio_binary: AudioBinary = None - _models: Dict[str, Any] = {} - _pipeline: Pipeline = None - - # IO交互 - _receivers: List[callable] = [] - - # 异步交互消息队列 - _input_queue: Queue = None - _output_queue: Queue = None @abstractmethod - def adder(self, *args, **kwargs): + def adder(self, data: Any) -> None: """ 添加数据 + 参数: + data: 要添加的数据 """ + pass @abstractmethod - def add_recevier(self, *args, **kwargs): + def add_recevier(self, receiver: callable) -> None: """ - 接收数据 + 添加数据接收者 + 参数: + receiver: 接收数据的回调函数 """ + pass -class STT_Runner(RunnerBase): + +class STTRunner(RunnerBase): """ 运行器类 - 工厂模式 + 负责管理资源和协调Pipeline的运行 """ - def __init__( self, *, audio_binary_list: List[AudioBinary], models: Dict[str, Any], pipeline_list: List[Pipeline], - input_queue: Queue, ): """ - 初始化 + 初始化运行器 + 参数: + audio_binary_list: 音频二进制列表 + models: 模型字典 + pipeline_list: 管道列表 + queue_size: 队列大小 + stop_timeout: 停止超时时间(秒) """ # 接收资源 self._audio_binary_list = audio_binary_list self._models = models self._pipeline_list = pipeline_list + # 线程控制 + self._lock = Lock() + + # 消息队列 + self._input_queue = Queue(maxsize=1000) + + # 停止控制 + self._stop_timeout = 10.0 + self._is_stopping = False + # 配置资源 for pipeline in self._pipeline_list: - # 配置 - if pipeline.get_config('audio_binary_name') is not None: - pipeline.set_audio_binary(self._audio_binary_list[pipeline.get_config('audio_binary_name')]) - if pipeline.get_config('model_name_list') is not None: - pipeline.set_models(self._models) + # 设置输入队列 + pipeline.set_input_queue(self._input_queue) - def adder(self, *args, **kwargs): + # 配置资源 + pipeline.set_audio_binary( + self._audio_binary_list[pipeline.get_config("audio_binary_name")] + ) + pipeline.set_models(self._models) + + def adder(self, data: Any) -> None: """ - 添加数据 + 添加数据到输入队列 + 参数: + data: 要添加的数据 """ - if self._pipeline_thread is None: - raise RuntimeError("Pipeline thread not started") - self._input_queue.put(args["data"]) - - def add_recevier(self, recevier: callable): + if not self._pipeline_list: + raise RuntimeError("没有可用的管道") + if self._is_stopping: + raise RuntimeError("运行器正在停止,无法添加数据") + self._input_queue.put(data) + + def add_recevier(self, receiver: callable) -> None: """ 添加数据接收者 + 参数: + receiver: 接收数据的回调函数 """ - if self._pipeline_thread is None: - raise RuntimeError("Pipeline thread not started") - self._receivers.append(recevier) + with self._lock: + for pipeline in self._pipeline_list: + pipeline.add_callback(receiver) - def run(self): + def run(self) -> None: """ - 运行pipeline子线程 + 启动所有管道 """ - # 创建pipeline子线程 - self._pipeline_thread = threading.Thread( - target=self._pipeline.run, - args=( - input_queue=self._input_queue, - output_queue=self._output_queue - ) - ) - self._pipeline_thread.start() + logger.info("[%s] 启动所有管道", self.__class__.__name__) + if not self._pipeline_list: + raise RuntimeError("没有可用的管道") - def stop(self): - """ - 停止pipeline子线程 - """ - # 结束pipeline子线程 - self._pipeline_thread.join() - - def __del__(self): - """ - 析构 - """ - self.stop() + # 启动所有管道 + for pipeline in self._pipeline_list: + thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}") + thread.daemon = True + thread.start() + logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline)) -class STT_RunnerFactory: + def stop(self, force: bool = False) -> bool: + """ + 停止所有管道 + 参数: + force: 是否强制停止 + 返回: + bool: 是否成功停止 + """ + if self._is_stopping: + logger.warning("运行器已经在停止中") + return False + + self._is_stopping = True + logger.info("正在停止运行器...") + + try: + # 发送结束信号 + self._input_queue.put(None) + + # 停止所有管道 + success = True + for pipeline in self._pipeline_list: + if force: + pipeline.force_stop() + else: + if not pipeline.stop(timeout=self._stop_timeout): + logger.warning("管道 %s 停止超时", id(pipeline)) + success = False + + # 等待队列处理完成 + try: + start_time = time.time() + while not self._input_queue.empty(): + if time.time() - start_time > self._stop_timeout: + logger.warning( + "等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理", + self._stop_timeout, + self._input_queue.qsize() + ) + success = False + break + time.sleep(0.1) # 避免过度消耗CPU + except Exception as e: + error_type = type(e).__name__ + error_msg = str(e) + error_traceback = traceback.format_exc() + logger.error( + "等待队列处理完成时发生错误:\n" + "错误类型: %s\n" + "错误信息: %s\n" + "错误堆栈:\n%s", + error_type, + error_msg, + error_traceback + ) + success = False + + if success: + logger.info("所有管道已成功停止") + else: + logger.warning( + "部分管道停止失败,队列状态: 大小=%d, 是否为空=%s", + self._input_queue.qsize(), + self._input_queue.empty() + ) + + return success + + finally: + self._is_stopping = False + + def __del__(self) -> None: + """ + 析构函数 + """ + self.stop(force=True) + + +class STTRunnerFactory: """ STT Runner工厂类 + 用于创建运行器实例 """ + + @staticmethod def _create_runner( audio_binary_name: str, model_name_list: List[str], pipeline_name_list: List[str], - ): + ) -> STTRunner: """ - 全参数创建Runner + 创建运行器 参数: audio_binary_name: 音频二进制名称 model_name_list: 模型名称列表 @@ -142,26 +240,42 @@ class STT_RunnerFactory: for model_name in model_name_list } pipelines: List[Pipeline] = [ - pipelines_loaded.pipelines[pipeline_name] + PipelineFactory.create_pipeline(pipeline_name) for pipeline_name in pipeline_name_list ] - return Runner(audio_binary, models, pipelines) + return STTRunner( + audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines + ) @classmethod def create_runner_from_config( cls, config: Dict[str, Any], - ): + ) -> STTRunner: + """ + 从配置创建运行器 + 参数: + config: 配置字典 + 返回: + Runner实例 + """ audio_binary_name = config["audio_binary_name"] model_name_list = config["model_name_list"] pipeline_name_list = config["pipeline_name_list"] - return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list) + return cls._create_runner( + audio_binary_name, model_name_list, pipeline_name_list + ) @classmethod - def create_runner_normal( - cls - ) + def create_runner_normal(cls) -> STTRunner: + """ + 创建默认运行器 + 返回: + Runner实例 + """ audio_binary_name = None - model_name_list = models_loaded.models.keys() + model_name_list = list(models_loaded.models.keys()) pipeline_name_list = None - return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list) + return cls._create_runner( + audio_binary_name, model_name_list, pipeline_name_list + )