diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py
new file mode 100644
index 0000000..c2f90bb
--- /dev/null
+++ b/src/pipeline/__init__.py
@@ -0,0 +1,3 @@
+from src.pipeline.base import PipelineBase, Pipeline
+
+__all__ = ["PipelineBase", "Pipeline"]
diff --git a/src/pipeline/base.py b/src/pipeline/base.py
new file mode 100644
--- /dev/null
+++ b/src/pipeline/base.py
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+
+
+class PipelineBase(ABC):
+    """
+    Abstract base class for pipelines.
+    """
+
+    @abstractmethod
+    def run(self, *args, **kwargs):
+        """
+        Run the pipeline.
+        """
+
+
+class Pipeline(PipelineBase):
+    """
+    Pipeline class.
+
+    NOTE(review): ``run`` is not overridden here, so this class is still
+    abstract and cannot be instantiated until a subclass implements it.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize the pipeline; all arguments are currently ignored.
+        """
diff --git a/src/runner.py b/src/runner.py
--- a/src/runner.py
+++ b/src/runner.py
@@ -0,0 +1,261 @@
+"""
+Runner module.
+
+Provides the runner base class and the runner classes that manage the
+interaction between audio data and models.
+
+Contents:
+- RunnerBase: runner base class that defines the basic interface
+- STT_Runner: runner that drives a pipeline in a worker thread
+- STT_RunnerFactory: factory class used to create runners
+"""
+import threading
+from abc import ABC, abstractmethod
+from queue import Queue
+from typing import Any, Callable, Dict, List, Optional
+
+from src.audio_chunk import AudioChunk, AudioBinary
+from src.pipeline import Pipeline
+from src.model_loader import ModelLoader
+
+audio_chunk = AudioChunk()
+models_loaded = ModelLoader()
+# NOTE(review): PipelineLoader is not imported anywhere in this patch;
+# confirm which module provides it and add the import.
+pipelines_loaded = PipelineLoader()
+
+
+class RunnerBase(ABC):
+    """
+    Runner base class.
+
+    The class-level attributes are defaults only; subclasses should assign
+    per-instance values so that mutable state is not shared.
+    """
+    # Compute resources
+    _audio_binary: Optional[AudioBinary] = None
+    _models: Dict[str, Any] = {}
+    _pipeline: Optional[Pipeline] = None
+
+    # IO interaction
+    _receivers: List[Callable] = []
+
+    # Message queues for asynchronous interaction
+    _input_queue: Optional[Queue] = None
+    _output_queue: Optional[Queue] = None
+
+    @abstractmethod
+    def adder(self, *args, **kwargs):
+        """
+        Add data.
+        """
+
+    @abstractmethod
+    def add_recevier(self, *args, **kwargs):
+        """
+        Register a data receiver.
+        """
+
+
+class STT_Runner(RunnerBase):
+    """
+    STT runner class.
+
+    Instances can be created through STT_RunnerFactory.
+    """
+
+    def __init__(
+        self,
+        *,
+        audio_binary_list: List[AudioBinary],
+        models: Dict[str, Any],
+        pipeline_list: List[Pipeline],
+        input_queue: Queue,
+    ):
+        """
+        Initialize the runner and configure its pipelines.
+
+        Args:
+            audio_binary_list: Audio binaries available to the pipelines.
+            models: Mapping of model name to model.
+            pipeline_list: Pipelines to configure.
+            input_queue: Queue passed to the pipeline as its input.
+        """
+        # Received resources
+        self._audio_binary_list = audio_binary_list
+        self._models = models
+        self._pipeline_list = pipeline_list
+        # NOTE(review): run() drives a single pipeline, so the first entry
+        # is used; confirm how several pipelines should be scheduled.
+        self._pipeline = pipeline_list[0] if pipeline_list else None
+
+        # Per-instance state, so the mutable class-level defaults on
+        # RunnerBase are never shared between runners.
+        self._receivers = []
+        self._input_queue = input_queue
+        self._output_queue = Queue()
+        self._pipeline_thread = None
+
+        # Configure resources
+        for pipeline in self._pipeline_list:
+            audio_binary_name = pipeline.get_config('audio_binary_name')
+            if audio_binary_name is not None:
+                # NOTE(review): this indexes a list with the configured
+                # value; confirm it is an integer index, not a name.
+                pipeline.set_audio_binary(
+                    self._audio_binary_list[audio_binary_name]
+                )
+            if pipeline.get_config('model_name_list') is not None:
+                pipeline.set_models(self._models)
+
+    def adder(self, *args, **kwargs):
+        """
+        Put one item of data on the input queue.
+
+        The item is read from the ``data`` keyword argument, or from the
+        first positional argument when ``data`` is not given.
+
+        Raises:
+            RuntimeError: If the pipeline thread has not been started.
+            TypeError: If no data is supplied.
+        """
+        if self._pipeline_thread is None:
+            raise RuntimeError("Pipeline thread not started")
+        if "data" in kwargs:
+            data = kwargs["data"]
+        elif args:
+            data = args[0]
+        else:
+            raise TypeError("adder() requires a 'data' argument")
+        self._input_queue.put(data)
+
+    def add_recevier(self, recevier: Callable):
+        """
+        Register a data receiver.
+
+        Raises:
+            RuntimeError: If the pipeline thread has not been started.
+        """
+        if self._pipeline_thread is None:
+            raise RuntimeError("Pipeline thread not started")
+        self._receivers.append(recevier)
+
+    def run(self):
+        """
+        Start the pipeline worker thread.
+
+        The thread calls the pipeline's ``run`` with the input and output
+        queues passed as keyword arguments.
+
+        Raises:
+            RuntimeError: If no pipeline is configured.
+        """
+        if self._pipeline is None:
+            raise RuntimeError("No pipeline configured")
+        # Create the pipeline worker thread
+        self._pipeline_thread = threading.Thread(
+            target=self._pipeline.run,
+            kwargs={
+                "input_queue": self._input_queue,
+                "output_queue": self._output_queue,
+            },
+        )
+        self._pipeline_thread.start()
+
+    def stop(self):
+        """
+        Wait for the pipeline worker thread to finish.
+
+        Does nothing when the thread was never started. No stop signal is
+        sent, so this blocks until the pipeline's ``run`` returns.
+        """
+        # getattr: __del__ may call this on a partially built instance.
+        thread = getattr(self, "_pipeline_thread", None)
+        if thread is not None:
+            thread.join()
+            self._pipeline_thread = None
+
+    def __del__(self):
+        """
+        Join the worker thread when the runner is destroyed.
+        """
+        self.stop()
+
+
+class STT_RunnerFactory:
+    """
+    STT runner factory class.
+    """
+
+    @staticmethod
+    def _create_runner(
+        audio_binary_name: str,
+        model_name_list: List[str],
+        pipeline_name_list: List[str],
+    ):
+        """
+        Create a runner from the full set of arguments.
+
+        Args:
+            audio_binary_name: Name of the audio binary.
+            model_name_list: List of model names.
+            pipeline_name_list: List of pipeline names; ``None`` is
+                treated as an empty list.
+
+        Returns:
+            An STT_Runner instance.
+        """
+        audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
+        models: Dict[str, Any] = {
+            model_name: models_loaded.models[model_name]
+            for model_name in model_name_list
+        }
+        pipelines: List[Pipeline] = [
+            pipelines_loaded.pipelines[pipeline_name]
+            for pipeline_name in (pipeline_name_list or [])
+        ]
+        return STT_Runner(
+            audio_binary_list=[audio_binary],
+            models=models,
+            pipeline_list=pipelines,
+            input_queue=Queue(),
+        )
+
+    @classmethod
+    def create_runner_from_config(
+        cls,
+        config: Dict[str, Any],
+    ):
+        """
+        Create a runner from a configuration mapping.
+
+        Args:
+            config: Mapping with the keys ``audio_binary_name``,
+                ``model_name_list`` and ``pipeline_name_list``.
+
+        Returns:
+            An STT_Runner instance.
+        """
+        return cls._create_runner(
+            config["audio_binary_name"],
+            config["model_name_list"],
+            config["pipeline_name_list"],
+        )
+
+    @classmethod
+    def create_runner_normal(cls):
+        """
+        Create a runner that uses every loaded model.
+
+        No audio binary name and no pipeline names are passed on.
+
+        Returns:
+            An STT_Runner instance.
+        """
+        # NOTE(review): confirm that get_audio_binary() accepts None.
+        audio_binary_name = None
+        model_name_list = list(models_loaded.models)
+        pipeline_name_list = None
+        return cls._create_runner(
+            audio_binary_name, model_name_list, pipeline_name_list
+        )