[代码重构中]编写class STT_Runner中,将设计为线程启动。作为异步IO与资源管理模块。
This commit is contained in:
parent
703a40e955
commit
49cb428c23
3
src/pipeline/__init__.py
Normal file
3
src/pipeline/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from src.pipeline.base import PipelineBase, Pipeline
|
||||||
|
|
||||||
|
__all__ = ["PipelineBase", "Pipeline"]
|
21
src/pipeline/base.py
Normal file
21
src/pipeline/base.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineBase(ABC):
|
||||||
|
"""
|
||||||
|
管道基类
|
||||||
|
"""
|
||||||
|
@abstractmethod
|
||||||
|
def run(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
运行管道
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Pipeline(PipelineBase):
|
||||||
|
"""
|
||||||
|
管道类
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
pass
|
167
src/runner.py
167
src/runner.py
@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
运行器模块
|
||||||
|
提供运行器基类和运行器类,用于管理音频数据和模型的交互。
|
||||||
|
主要包含:
|
||||||
|
- RunnerBase: 运行器基类,定义了基本接口
|
||||||
|
- Runner: 运行器类,工厂模式实现
|
||||||
|
- RunnerFactory: 运行器工厂类,用于创建运行器
|
||||||
|
"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, List, Queue
|
||||||
|
from src.audio_chunk import AudioChunk, AudioBinary
|
||||||
|
from src.pipeline import Pipeline
|
||||||
|
from src.model_loader import ModelLoader
|
||||||
|
|
||||||
|
audio_chunk = AudioChunk()
|
||||||
|
models_loaded = ModelLoader()
|
||||||
|
pipelines_loaded = PipelineLoader()
|
||||||
|
|
||||||
|
class RunnerBase(ABC):
|
||||||
|
"""
|
||||||
|
运行器基类
|
||||||
|
"""
|
||||||
|
# 计算资源
|
||||||
|
_audio_binary: AudioBinary = None
|
||||||
|
_models: Dict[str, Any] = {}
|
||||||
|
_pipeline: Pipeline = None
|
||||||
|
|
||||||
|
# IO交互
|
||||||
|
_receivers: List[callable] = []
|
||||||
|
|
||||||
|
# 异步交互消息队列
|
||||||
|
_input_queue: Queue = None
|
||||||
|
_output_queue: Queue = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def adder(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
添加数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_recevier(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
接收数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
class STT_Runner(RunnerBase):
|
||||||
|
"""
|
||||||
|
运行器类
|
||||||
|
工厂模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
audio_binary_list: List[AudioBinary],
|
||||||
|
models: Dict[str, Any],
|
||||||
|
pipeline_list: List[Pipeline],
|
||||||
|
input_queue: Queue,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化
|
||||||
|
"""
|
||||||
|
# 接收资源
|
||||||
|
self._audio_binary_list = audio_binary_list
|
||||||
|
self._models = models
|
||||||
|
self._pipeline_list = pipeline_list
|
||||||
|
|
||||||
|
# 配置资源
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
# 配置
|
||||||
|
if pipeline.get_config('audio_binary_name') is not None:
|
||||||
|
pipeline.set_audio_binary(self._audio_binary_list[pipeline.get_config('audio_binary_name')])
|
||||||
|
if pipeline.get_config('model_name_list') is not None:
|
||||||
|
pipeline.set_models(self._models)
|
||||||
|
|
||||||
|
def adder(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
添加数据
|
||||||
|
"""
|
||||||
|
if self._pipeline_thread is None:
|
||||||
|
raise RuntimeError("Pipeline thread not started")
|
||||||
|
self._input_queue.put(args["data"])
|
||||||
|
|
||||||
|
def add_recevier(self, recevier: callable):
|
||||||
|
"""
|
||||||
|
添加数据接收者
|
||||||
|
"""
|
||||||
|
if self._pipeline_thread is None:
|
||||||
|
raise RuntimeError("Pipeline thread not started")
|
||||||
|
self._receivers.append(recevier)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
运行pipeline子线程
|
||||||
|
"""
|
||||||
|
# 创建pipeline子线程
|
||||||
|
self._pipeline_thread = threading.Thread(
|
||||||
|
target=self._pipeline.run,
|
||||||
|
args=(
|
||||||
|
input_queue=self._input_queue,
|
||||||
|
output_queue=self._output_queue
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._pipeline_thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
停止pipeline子线程
|
||||||
|
"""
|
||||||
|
# 结束pipeline子线程
|
||||||
|
self._pipeline_thread.join()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""
|
||||||
|
析构
|
||||||
|
"""
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
class STT_RunnerFactory:
|
||||||
|
"""
|
||||||
|
STT Runner工厂类
|
||||||
|
"""
|
||||||
|
def _create_runner(
|
||||||
|
audio_binary_name: str,
|
||||||
|
model_name_list: List[str],
|
||||||
|
pipeline_name_list: List[str],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
全参数创建Runner
|
||||||
|
参数:
|
||||||
|
audio_binary_name: 音频二进制名称
|
||||||
|
model_name_list: 模型名称列表
|
||||||
|
pipeline_name_list: 管道名称列表
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
|
||||||
|
models: Dict[str, Any] = {
|
||||||
|
model_name: models_loaded.models[model_name]
|
||||||
|
for model_name in model_name_list
|
||||||
|
}
|
||||||
|
pipelines: List[Pipeline] = [
|
||||||
|
pipelines_loaded.pipelines[pipeline_name]
|
||||||
|
for pipeline_name in pipeline_name_list
|
||||||
|
]
|
||||||
|
return Runner(audio_binary, models, pipelines)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_runner_from_config(
|
||||||
|
cls,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
):
|
||||||
|
audio_binary_name = config["audio_binary_name"]
|
||||||
|
model_name_list = config["model_name_list"]
|
||||||
|
pipeline_name_list = config["pipeline_name_list"]
|
||||||
|
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_runner_normal(
|
||||||
|
cls
|
||||||
|
)
|
||||||
|
audio_binary_name = None
|
||||||
|
model_name_list = models_loaded.models.keys()
|
||||||
|
pipeline_name_list = None
|
||||||
|
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
|
Loading…
x
Reference in New Issue
Block a user