[代码重构中]编写class STT_Runner中,将设计为线程启动。作为异步IO与资源管理模块。
This commit is contained in:
parent
703a40e955
commit
49cb428c23
3
src/pipeline/__init__.py
Normal file
3
src/pipeline/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from src.pipeline.base import PipelineBase, Pipeline
|
||||
|
||||
__all__ = ["PipelineBase", "Pipeline"]
|
21
src/pipeline/base.py
Normal file
21
src/pipeline/base.py
Normal file
@ -0,0 +1,21 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class PipelineBase(ABC):
|
||||
"""
|
||||
管道基类
|
||||
"""
|
||||
@abstractmethod
|
||||
def run(self, *args, **kwargs):
|
||||
"""
|
||||
运行管道
|
||||
"""
|
||||
|
||||
class Pipeline(PipelineBase):
|
||||
"""
|
||||
管道类
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
"""
|
||||
pass
|
167
src/runner.py
167
src/runner.py
@ -0,0 +1,167 @@
|
||||
"""
|
||||
运行器模块
|
||||
提供运行器基类和运行器类,用于管理音频数据和模型的交互。
|
||||
主要包含:
|
||||
- RunnerBase: 运行器基类,定义了基本接口
|
||||
- Runner: 运行器类,工厂模式实现
|
||||
- RunnerFactory: 运行器工厂类,用于创建运行器
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Queue
|
||||
from src.audio_chunk import AudioChunk, AudioBinary
|
||||
from src.pipeline import Pipeline
|
||||
from src.model_loader import ModelLoader
|
||||
|
||||
audio_chunk = AudioChunk()
|
||||
models_loaded = ModelLoader()
|
||||
pipelines_loaded = PipelineLoader()
|
||||
|
||||
class RunnerBase(ABC):
|
||||
"""
|
||||
运行器基类
|
||||
"""
|
||||
# 计算资源
|
||||
_audio_binary: AudioBinary = None
|
||||
_models: Dict[str, Any] = {}
|
||||
_pipeline: Pipeline = None
|
||||
|
||||
# IO交互
|
||||
_receivers: List[callable] = []
|
||||
|
||||
# 异步交互消息队列
|
||||
_input_queue: Queue = None
|
||||
_output_queue: Queue = None
|
||||
|
||||
@abstractmethod
|
||||
def adder(self, *args, **kwargs):
|
||||
"""
|
||||
添加数据
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_recevier(self, *args, **kwargs):
|
||||
"""
|
||||
接收数据
|
||||
"""
|
||||
|
||||
class STT_Runner(RunnerBase):
|
||||
"""
|
||||
运行器类
|
||||
工厂模式
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
audio_binary_list: List[AudioBinary],
|
||||
models: Dict[str, Any],
|
||||
pipeline_list: List[Pipeline],
|
||||
input_queue: Queue,
|
||||
):
|
||||
"""
|
||||
初始化
|
||||
"""
|
||||
# 接收资源
|
||||
self._audio_binary_list = audio_binary_list
|
||||
self._models = models
|
||||
self._pipeline_list = pipeline_list
|
||||
|
||||
# 配置资源
|
||||
for pipeline in self._pipeline_list:
|
||||
# 配置
|
||||
if pipeline.get_config('audio_binary_name') is not None:
|
||||
pipeline.set_audio_binary(self._audio_binary_list[pipeline.get_config('audio_binary_name')])
|
||||
if pipeline.get_config('model_name_list') is not None:
|
||||
pipeline.set_models(self._models)
|
||||
|
||||
def adder(self, *args, **kwargs):
|
||||
"""
|
||||
添加数据
|
||||
"""
|
||||
if self._pipeline_thread is None:
|
||||
raise RuntimeError("Pipeline thread not started")
|
||||
self._input_queue.put(args["data"])
|
||||
|
||||
def add_recevier(self, recevier: callable):
|
||||
"""
|
||||
添加数据接收者
|
||||
"""
|
||||
if self._pipeline_thread is None:
|
||||
raise RuntimeError("Pipeline thread not started")
|
||||
self._receivers.append(recevier)
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
运行pipeline子线程
|
||||
"""
|
||||
# 创建pipeline子线程
|
||||
self._pipeline_thread = threading.Thread(
|
||||
target=self._pipeline.run,
|
||||
args=(
|
||||
input_queue=self._input_queue,
|
||||
output_queue=self._output_queue
|
||||
)
|
||||
)
|
||||
self._pipeline_thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止pipeline子线程
|
||||
"""
|
||||
# 结束pipeline子线程
|
||||
self._pipeline_thread.join()
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
析构
|
||||
"""
|
||||
self.stop()
|
||||
|
||||
class STT_RunnerFactory:
|
||||
"""
|
||||
STT Runner工厂类
|
||||
"""
|
||||
def _create_runner(
|
||||
audio_binary_name: str,
|
||||
model_name_list: List[str],
|
||||
pipeline_name_list: List[str],
|
||||
):
|
||||
"""
|
||||
全参数创建Runner
|
||||
参数:
|
||||
audio_binary_name: 音频二进制名称
|
||||
model_name_list: 模型名称列表
|
||||
pipeline_name_list: 管道名称列表
|
||||
返回:
|
||||
Runner实例
|
||||
"""
|
||||
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
|
||||
models: Dict[str, Any] = {
|
||||
model_name: models_loaded.models[model_name]
|
||||
for model_name in model_name_list
|
||||
}
|
||||
pipelines: List[Pipeline] = [
|
||||
pipelines_loaded.pipelines[pipeline_name]
|
||||
for pipeline_name in pipeline_name_list
|
||||
]
|
||||
return Runner(audio_binary, models, pipelines)
|
||||
|
||||
@classmethod
|
||||
def create_runner_from_config(
|
||||
cls,
|
||||
config: Dict[str, Any],
|
||||
):
|
||||
audio_binary_name = config["audio_binary_name"]
|
||||
model_name_list = config["model_name_list"]
|
||||
pipeline_name_list = config["pipeline_name_list"]
|
||||
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
|
||||
|
||||
@classmethod
|
||||
def create_runner_normal(
|
||||
cls
|
||||
)
|
||||
audio_binary_name = None
|
||||
model_name_list = models_loaded.models.keys()
|
||||
pipeline_name_list = None
|
||||
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
|
Loading…
x
Reference in New Issue
Block a user