[代码重构中]编写class STT_Runner中,将设计为线程启动。作为异步IO与资源管理模块。

This commit is contained in:
Ziyang.Zhang 2025-05-28 18:00:54 +08:00
parent 703a40e955
commit 49cb428c23
3 changed files with 191 additions and 0 deletions

3
src/pipeline/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from src.pipeline.base import PipelineBase, Pipeline
__all__ = ["PipelineBase", "Pipeline"]

21
src/pipeline/base.py Normal file
View File

@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
class PipelineBase(ABC):
"""
管道基类
"""
@abstractmethod
def run(self, *args, **kwargs):
"""
运行管道
"""
class Pipeline(PipelineBase):
"""
管道类
"""
def __init__(self, *args, **kwargs):
"""
"""
pass

View File

@ -0,0 +1,167 @@
"""
运行器模块
提供运行器基类和运行器类用于管理音频数据和模型的交互
主要包含:
- RunnerBase: 运行器基类,定义了基本接口
- Runner: 运行器类,工厂模式实现
- RunnerFactory: 运行器工厂类,用于创建运行器
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Queue
from src.audio_chunk import AudioChunk, AudioBinary
from src.pipeline import Pipeline
from src.model_loader import ModelLoader
audio_chunk = AudioChunk()
models_loaded = ModelLoader()
pipelines_loaded = PipelineLoader()
class RunnerBase(ABC):
"""
运行器基类
"""
# 计算资源
_audio_binary: AudioBinary = None
_models: Dict[str, Any] = {}
_pipeline: Pipeline = None
# IO交互
_receivers: List[callable] = []
# 异步交互消息队列
_input_queue: Queue = None
_output_queue: Queue = None
@abstractmethod
def adder(self, *args, **kwargs):
"""
添加数据
"""
@abstractmethod
def add_recevier(self, *args, **kwargs):
"""
接收数据
"""
class STT_Runner(RunnerBase):
"""
运行器类
工厂模式
"""
def __init__(
self,
*,
audio_binary_list: List[AudioBinary],
models: Dict[str, Any],
pipeline_list: List[Pipeline],
input_queue: Queue,
):
"""
初始化
"""
# 接收资源
self._audio_binary_list = audio_binary_list
self._models = models
self._pipeline_list = pipeline_list
# 配置资源
for pipeline in self._pipeline_list:
# 配置
if pipeline.get_config('audio_binary_name') is not None:
pipeline.set_audio_binary(self._audio_binary_list[pipeline.get_config('audio_binary_name')])
if pipeline.get_config('model_name_list') is not None:
pipeline.set_models(self._models)
def adder(self, *args, **kwargs):
"""
添加数据
"""
if self._pipeline_thread is None:
raise RuntimeError("Pipeline thread not started")
self._input_queue.put(args["data"])
def add_recevier(self, recevier: callable):
"""
添加数据接收者
"""
if self._pipeline_thread is None:
raise RuntimeError("Pipeline thread not started")
self._receivers.append(recevier)
def run(self):
"""
运行pipeline子线程
"""
# 创建pipeline子线程
self._pipeline_thread = threading.Thread(
target=self._pipeline.run,
args=(
input_queue=self._input_queue,
output_queue=self._output_queue
)
)
self._pipeline_thread.start()
def stop(self):
"""
停止pipeline子线程
"""
# 结束pipeline子线程
self._pipeline_thread.join()
def __del__(self):
"""
析构
"""
self.stop()
class STT_RunnerFactory:
"""
STT Runner工厂类
"""
def _create_runner(
audio_binary_name: str,
model_name_list: List[str],
pipeline_name_list: List[str],
):
"""
全参数创建Runner
参数:
audio_binary_name: 音频二进制名称
model_name_list: 模型名称列表
pipeline_name_list: 管道名称列表
返回:
Runner实例
"""
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
models: Dict[str, Any] = {
model_name: models_loaded.models[model_name]
for model_name in model_name_list
}
pipelines: List[Pipeline] = [
pipelines_loaded.pipelines[pipeline_name]
for pipeline_name in pipeline_name_list
]
return Runner(audio_binary, models, pipelines)
@classmethod
def create_runner_from_config(
cls,
config: Dict[str, Any],
):
audio_binary_name = config["audio_binary_name"]
model_name_list = config["model_name_list"]
pipeline_name_list = config["pipeline_name_list"]
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
@classmethod
def create_runner_normal(
cls
)
audio_binary_name = None
model_name_list = models_loaded.models.keys()
pipeline_name_list = None
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)