[代码重构中]编写ASRpipeline,管理funtor的线程启动,管理funtor间消息队列queue
This commit is contained in:
parent
49cb428c23
commit
f245c6e9df
@ -2,9 +2,9 @@ from typing import Callable
|
||||
|
||||
class BaseFunctor:
|
||||
"""
|
||||
基础函数器类,提供数据处理的基本框架
|
||||
基础函数器类, 提供数据处理的基本框架
|
||||
|
||||
该类实现了数据处理的基本接口,包括数据推送、处理和回调机制。
|
||||
该类实现了数据处理的基本接口, 包括数据推送、处理和回调机制。
|
||||
所有具体的功能实现类都应该继承这个基类。
|
||||
|
||||
属性:
|
||||
@ -22,7 +22,7 @@ class BaseFunctor:
|
||||
初始化函数器
|
||||
|
||||
参数:
|
||||
data (dict or bytes): 初始数据,可以是字典或字节数据
|
||||
data (dict or bytes): 初始数据, 可以是字典或字节数据
|
||||
callback (Callable): 处理完成后的回调函数
|
||||
model (dict): 模型相关的配置和实例
|
||||
"""
|
||||
@ -34,33 +34,33 @@ class BaseFunctor:
|
||||
|
||||
def __call__(self, data = None):
|
||||
"""
|
||||
使类实例可调用,处理数据并触发回调
|
||||
使类实例可调用, 处理数据并触发回调
|
||||
|
||||
参数:
|
||||
data: 要处理的数据,如果为None则处理已存储的数据
|
||||
data: 要处理的数据, 如果为None则处理已存储的数据
|
||||
|
||||
返回:
|
||||
处理结果
|
||||
"""
|
||||
# 如果传入数据,则压入数据
|
||||
# 如果传入数据, 则压入数据
|
||||
if data is not None:
|
||||
self.push_data(data)
|
||||
# 处理数据
|
||||
result = self.process()
|
||||
# 如果回调函数存在,则触发回调
|
||||
# 如果回调函数存在, 则触发回调
|
||||
if self._callback is not None and callable(self._callback):
|
||||
self._callback(result)
|
||||
return result
|
||||
|
||||
def __add__(self, other):
|
||||
"""
|
||||
重载加法运算符,用于合并数据
|
||||
重载加法运算符, 用于合并数据
|
||||
|
||||
参数:
|
||||
other: 要合并的数据
|
||||
|
||||
返回:
|
||||
self: 返回当前实例,支持链式调用
|
||||
self: 返回当前实例, 支持链式调用
|
||||
"""
|
||||
self.push_data(other)
|
||||
return self
|
||||
|
190
src/pipeline/ASRpipeline.py
Normal file
190
src/pipeline/ASRpipeline.py
Normal file
@ -0,0 +1,190 @@
|
||||
from src.pipeline.base import PipelineBase
|
||||
from typing import Dict, Any
|
||||
from queue import Queue
|
||||
from src.utils import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
class ASRPipeline(PipelineBase):
|
||||
"""
|
||||
管道类
|
||||
实现具体的处理逻辑
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
初始化管道
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._config: Dict[str, Any] = {}
|
||||
self._funtor_dict: Dict[str, Any] = {}
|
||||
self._subqueue_dict: Dict[str, Any] = {}
|
||||
|
||||
self._is_baked: bool = False
|
||||
|
||||
def set_config(self, config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
设置配置
|
||||
参数:
|
||||
config: Dict[str, Any] 配置
|
||||
"""
|
||||
self._config = config
|
||||
|
||||
def set_audio_binary(self, audio_binary: AudioBinary) -> None:
|
||||
"""
|
||||
设置音频二进制存储单元
|
||||
参数:
|
||||
audio_binary: 音频二进制
|
||||
"""
|
||||
self._audio_binary = audio_binary
|
||||
|
||||
def set_models(self, models: Dict[str, Any]) -> None:
|
||||
"""
|
||||
设置模型
|
||||
"""
|
||||
self._models = models
|
||||
|
||||
def bake(self) -> None:
|
||||
"""
|
||||
烘焙管道
|
||||
"""
|
||||
self._init_funtor()
|
||||
self._is_baked = True
|
||||
|
||||
def _init_funtor(self) -> None:
|
||||
"""
|
||||
初始化函数
|
||||
"""
|
||||
try:
|
||||
from src.funtor import FuntorFactory
|
||||
# 加载VAD、asr、spk funtor
|
||||
self._funtor_dict["vad"] = FuntorFactory.get_funtor(
|
||||
funtor_name = "vad",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
self._funtor_dict["asr"] = FuntorFactory.get_funtor(
|
||||
funtor_name = "asr",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
self._funtor_dict["spk"] = FuntorFactory.get_funtor(
|
||||
funtor_name = "spk",
|
||||
config = self._config,
|
||||
models = self._models
|
||||
)
|
||||
|
||||
# 初始化子队列
|
||||
self._subqueue_dict["vad2asr"] = Queue()
|
||||
self._subqueue_dict["vad2spk"] = Queue()
|
||||
self._subqueue_dict["asrend"] = Queue()
|
||||
self._subqueue_dict["spkend"] = Queue()
|
||||
|
||||
# 设置子队列的输入队列
|
||||
self._funtor_dict["vad"].set_input_queue(self._input_queue)
|
||||
self._funtor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||
self._funtor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||
|
||||
# 设置回调函数——放置到对应队列中
|
||||
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
|
||||
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
|
||||
|
||||
# 构造带回调函数的put
|
||||
def put_with_checkcallback(queue: Queue, callback: Callable) -> None:
|
||||
"""
|
||||
带回调函数的put
|
||||
"""
|
||||
def put_with_check(data: Any) -> None:
|
||||
queue.put(data)
|
||||
callback()
|
||||
return put_with_check
|
||||
|
||||
self._funtor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
|
||||
self._funtor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
|
||||
|
||||
except ImportError:
|
||||
raise ImportError("FuntorFactory引入失败,ASRPipeline无法完成初始化")
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取配置
|
||||
返回:
|
||||
Dict[str, Any] 配置
|
||||
"""
|
||||
return self._config
|
||||
|
||||
def process(self, data: Any) -> Any:
|
||||
"""
|
||||
处理数据
|
||||
参数:
|
||||
data: 输入数据
|
||||
返回:
|
||||
处理结果
|
||||
"""
|
||||
# 子类实现具体的处理逻辑
|
||||
self._input_queue.put(data)
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
运行管道
|
||||
"""
|
||||
if not self._is_baked:
|
||||
raise RuntimeError("管道未烘焙,无法运行")
|
||||
|
||||
# 运行所有funtor
|
||||
for funtor_name, funtor in self._funtor_dict.items():
|
||||
logger.info(f"运行{funtor_name}funtor")
|
||||
funtor.run()
|
||||
|
||||
# 运行管道
|
||||
if not self._input_queue:
|
||||
raise RuntimeError("输入队列未设置")
|
||||
|
||||
# 设置管道运行状态
|
||||
self._is_running = True
|
||||
self._stop_event = False
|
||||
self._thread = threading.current_thread()
|
||||
logger.info("ASR管道开始运行")
|
||||
|
||||
while self._is_running and not self._stop_event:
|
||||
try:
|
||||
# 从队列获取数据
|
||||
try:
|
||||
data = self._input_queue.get(timeout=self._queue_timeout)
|
||||
# 检查是否是结束信号
|
||||
if data is None:
|
||||
logger.info("收到结束信号,管道准备停止")
|
||||
self._input_queue.task_done() # 标记结束信号已处理
|
||||
break
|
||||
|
||||
# 处理数据
|
||||
self.process(data)
|
||||
|
||||
# 标记任务完成
|
||||
self._input_queue.task_done()
|
||||
|
||||
except Empty:
|
||||
# 队列获取超时,继续等待
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"管道处理数据出错: {str(e)}")
|
||||
continue
|
||||
|
||||
logger.info("管道停止运行")
|
||||
|
||||
def _check_result(self, result: Any) -> None:
|
||||
"""
|
||||
检查结果
|
||||
"""
|
||||
# 若asr和spk队列中都有数据,则合并数据
|
||||
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize():
|
||||
asr_data = self._subqueue_dict["asrend"].get()
|
||||
spk_data = self._subqueue_dict["spkend"].get()
|
||||
# 合并数据
|
||||
result = {
|
||||
"asr_data": asr_data,
|
||||
"spk_data": spk_data
|
||||
}
|
||||
# 通知回调函数
|
||||
self._notify_callbacks(result)
|
||||
|
@ -1,3 +1,3 @@
|
||||
from src.pipeline.base import PipelineBase, Pipeline
|
||||
from src.pipeline.base import PipelineBase, Pipeline, PipelineFactory
|
||||
|
||||
__all__ = ["PipelineBase", "Pipeline"]
|
||||
__all__ = ["PipelineBase", "Pipeline", "PipelineFactory"]
|
||||
|
@ -1,21 +1,128 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue, Empty
|
||||
from typing import List, Callable, Any, Optional
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PipelineBase(ABC):
|
||||
"""
|
||||
管道基类
|
||||
定义了管道的基本接口和通用功能
|
||||
"""
|
||||
@abstractmethod
|
||||
def run(self, *args, **kwargs):
|
||||
def __init__(self, input_queue: Optional[Queue] = None):
|
||||
"""
|
||||
运行管道
|
||||
初始化管道
|
||||
参数:
|
||||
input_queue: 输入队列,用于接收数据
|
||||
"""
|
||||
self._input_queue = input_queue
|
||||
self._callbacks: List[Callable] = []
|
||||
self._is_running = False
|
||||
self._stop_event = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._stop_timeout = 5 # 默认停止超时时间(秒)
|
||||
self._queue_timeout = 1 # 队列获取超时时间(秒)
|
||||
|
||||
class Pipeline(PipelineBase):
|
||||
def set_input_queue(self, queue: Queue) -> None:
|
||||
"""
|
||||
管道类
|
||||
设置输入队列
|
||||
参数:
|
||||
queue: 输入队列
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._input_queue = queue
|
||||
|
||||
def add_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
添加回调函数
|
||||
参数:
|
||||
callback: 回调函数,接收处理结果
|
||||
"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def _notify_callbacks(self, result: Any) -> None:
|
||||
"""
|
||||
通知所有回调函数
|
||||
参数:
|
||||
result: 处理结果
|
||||
"""
|
||||
for callback in self._callbacks:
|
||||
try:
|
||||
callback(result)
|
||||
except Exception as e:
|
||||
logger.error(f"回调函数执行出错: {str(e)}")
|
||||
|
||||
@abstractmethod
|
||||
def process(self, data: Any) -> Any:
|
||||
"""
|
||||
处理数据
|
||||
参数:
|
||||
data: 输入数据
|
||||
返回:
|
||||
处理结果
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(self) -> None:
|
||||
"""
|
||||
运行管道
|
||||
从输入队列获取数据并处理
|
||||
"""
|
||||
pass
|
||||
|
||||
def stop(self, timeout: Optional[float] = None) -> bool:
|
||||
"""
|
||||
停止管道
|
||||
参数:
|
||||
timeout: 停止超时时间(秒),None表示使用默认超时时间
|
||||
返回:
|
||||
bool: 是否成功停止
|
||||
"""
|
||||
if not self._is_running:
|
||||
return True
|
||||
|
||||
logger.info("正在停止管道...")
|
||||
self._stop_event = True
|
||||
self._is_running = False
|
||||
|
||||
# 等待线程结束
|
||||
if self._thread and self._thread.is_alive():
|
||||
timeout = timeout if timeout is not None else self._stop_timeout
|
||||
self._thread.join(timeout=timeout)
|
||||
|
||||
# 检查是否成功停止
|
||||
if self._thread.is_alive():
|
||||
logger.warning(f"管道停止超时({timeout}秒),强制终止")
|
||||
return False
|
||||
else:
|
||||
logger.info("管道已成功停止")
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
def force_stop(self) -> None:
|
||||
"""
|
||||
强制停止管道
|
||||
注意:这可能会导致资源未正确释放
|
||||
"""
|
||||
logger.warning("强制停止管道")
|
||||
self._stop_event = True
|
||||
self._is_running = False
|
||||
# 注意:Python的线程无法被强制终止,这里只是设置标志
|
||||
# 实际终止需要依赖操作系统的进程管理
|
||||
|
||||
class PipelineFactory:
|
||||
"""
|
||||
管道工厂类
|
||||
用于创建管道实例
|
||||
"""
|
||||
@staticmethod
|
||||
def create_pipeline(pipeline_name: str) -> Pipeline:
|
||||
"""
|
||||
创建管道实例
|
||||
"""
|
||||
pass
|
0
src/pipeline/test.py
Normal file
0
src/pipeline/test.py
Normal file
248
src/runner.py
248
src/runner.py
@ -6,129 +6,227 @@
|
||||
- Runner: 运行器类,工厂模式实现
|
||||
- RunnerFactory: 运行器工厂类,用于创建运行器
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Queue
|
||||
from typing import Dict, Any, List
|
||||
from threading import Thread, Lock
|
||||
from queue import Queue
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from src.audio_chunk import AudioChunk, AudioBinary
|
||||
from src.pipeline import Pipeline
|
||||
from src.pipeline import Pipeline, PipelineFactory
|
||||
from src.model_loader import ModelLoader
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__, level="INFO")
|
||||
|
||||
audio_chunk = AudioChunk()
|
||||
models_loaded = ModelLoader()
|
||||
pipelines_loaded = PipelineLoader()
|
||||
|
||||
|
||||
class RunnerBase(ABC):
|
||||
"""
|
||||
运行器基类
|
||||
定义了运行器的基本接口
|
||||
"""
|
||||
# 计算资源
|
||||
_audio_binary: AudioBinary = None
|
||||
_models: Dict[str, Any] = {}
|
||||
_pipeline: Pipeline = None
|
||||
|
||||
# IO交互
|
||||
_receivers: List[callable] = []
|
||||
|
||||
# 异步交互消息队列
|
||||
_input_queue: Queue = None
|
||||
_output_queue: Queue = None
|
||||
|
||||
@abstractmethod
|
||||
def adder(self, *args, **kwargs):
|
||||
def adder(self, data: Any) -> None:
|
||||
"""
|
||||
添加数据
|
||||
参数:
|
||||
data: 要添加的数据
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_recevier(self, *args, **kwargs):
|
||||
def add_recevier(self, receiver: callable) -> None:
|
||||
"""
|
||||
接收数据
|
||||
添加数据接收者
|
||||
参数:
|
||||
receiver: 接收数据的回调函数
|
||||
"""
|
||||
pass
|
||||
|
||||
class STT_Runner(RunnerBase):
|
||||
|
||||
class STTRunner(RunnerBase):
|
||||
"""
|
||||
运行器类
|
||||
工厂模式
|
||||
负责管理资源和协调Pipeline的运行
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
audio_binary_list: List[AudioBinary],
|
||||
models: Dict[str, Any],
|
||||
pipeline_list: List[Pipeline],
|
||||
input_queue: Queue,
|
||||
):
|
||||
"""
|
||||
初始化
|
||||
初始化运行器
|
||||
参数:
|
||||
audio_binary_list: 音频二进制列表
|
||||
models: 模型字典
|
||||
pipeline_list: 管道列表
|
||||
queue_size: 队列大小
|
||||
stop_timeout: 停止超时时间(秒)
|
||||
"""
|
||||
# 接收资源
|
||||
self._audio_binary_list = audio_binary_list
|
||||
self._models = models
|
||||
self._pipeline_list = pipeline_list
|
||||
|
||||
# 线程控制
|
||||
self._lock = Lock()
|
||||
|
||||
# 消息队列
|
||||
self._input_queue = Queue(maxsize=1000)
|
||||
|
||||
# 停止控制
|
||||
self._stop_timeout = 10.0
|
||||
self._is_stopping = False
|
||||
|
||||
# 配置资源
|
||||
for pipeline in self._pipeline_list:
|
||||
# 配置
|
||||
if pipeline.get_config('audio_binary_name') is not None:
|
||||
pipeline.set_audio_binary(self._audio_binary_list[pipeline.get_config('audio_binary_name')])
|
||||
if pipeline.get_config('model_name_list') is not None:
|
||||
# 设置输入队列
|
||||
pipeline.set_input_queue(self._input_queue)
|
||||
|
||||
# 配置资源
|
||||
pipeline.set_audio_binary(
|
||||
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
|
||||
)
|
||||
pipeline.set_models(self._models)
|
||||
|
||||
def adder(self, *args, **kwargs):
|
||||
def adder(self, data: Any) -> None:
|
||||
"""
|
||||
添加数据
|
||||
添加数据到输入队列
|
||||
参数:
|
||||
data: 要添加的数据
|
||||
"""
|
||||
if self._pipeline_thread is None:
|
||||
raise RuntimeError("Pipeline thread not started")
|
||||
self._input_queue.put(args["data"])
|
||||
if not self._pipeline_list:
|
||||
raise RuntimeError("没有可用的管道")
|
||||
if self._is_stopping:
|
||||
raise RuntimeError("运行器正在停止,无法添加数据")
|
||||
self._input_queue.put(data)
|
||||
|
||||
def add_recevier(self, recevier: callable):
|
||||
def add_recevier(self, receiver: callable) -> None:
|
||||
"""
|
||||
添加数据接收者
|
||||
参数:
|
||||
receiver: 接收数据的回调函数
|
||||
"""
|
||||
if self._pipeline_thread is None:
|
||||
raise RuntimeError("Pipeline thread not started")
|
||||
self._receivers.append(recevier)
|
||||
with self._lock:
|
||||
for pipeline in self._pipeline_list:
|
||||
pipeline.add_callback(receiver)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
"""
|
||||
运行pipeline子线程
|
||||
启动所有管道
|
||||
"""
|
||||
# 创建pipeline子线程
|
||||
self._pipeline_thread = threading.Thread(
|
||||
target=self._pipeline.run,
|
||||
args=(
|
||||
input_queue=self._input_queue,
|
||||
output_queue=self._output_queue
|
||||
logger.info("[%s] 启动所有管道", self.__class__.__name__)
|
||||
if not self._pipeline_list:
|
||||
raise RuntimeError("没有可用的管道")
|
||||
|
||||
# 启动所有管道
|
||||
for pipeline in self._pipeline_list:
|
||||
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
|
||||
|
||||
def stop(self, force: bool = False) -> bool:
|
||||
"""
|
||||
停止所有管道
|
||||
参数:
|
||||
force: 是否强制停止
|
||||
返回:
|
||||
bool: 是否成功停止
|
||||
"""
|
||||
if self._is_stopping:
|
||||
logger.warning("运行器已经在停止中")
|
||||
return False
|
||||
|
||||
self._is_stopping = True
|
||||
logger.info("正在停止运行器...")
|
||||
|
||||
try:
|
||||
# 发送结束信号
|
||||
self._input_queue.put(None)
|
||||
|
||||
# 停止所有管道
|
||||
success = True
|
||||
for pipeline in self._pipeline_list:
|
||||
if force:
|
||||
pipeline.force_stop()
|
||||
else:
|
||||
if not pipeline.stop(timeout=self._stop_timeout):
|
||||
logger.warning("管道 %s 停止超时", id(pipeline))
|
||||
success = False
|
||||
|
||||
# 等待队列处理完成
|
||||
try:
|
||||
start_time = time.time()
|
||||
while not self._input_queue.empty():
|
||||
if time.time() - start_time > self._stop_timeout:
|
||||
logger.warning(
|
||||
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
|
||||
self._stop_timeout,
|
||||
self._input_queue.qsize()
|
||||
)
|
||||
success = False
|
||||
break
|
||||
time.sleep(0.1) # 避免过度消耗CPU
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
error_msg = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
logger.error(
|
||||
"等待队列处理完成时发生错误:\n"
|
||||
"错误类型: %s\n"
|
||||
"错误信息: %s\n"
|
||||
"错误堆栈:\n%s",
|
||||
error_type,
|
||||
error_msg,
|
||||
error_traceback
|
||||
)
|
||||
self._pipeline_thread.start()
|
||||
success = False
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止pipeline子线程
|
||||
"""
|
||||
# 结束pipeline子线程
|
||||
self._pipeline_thread.join()
|
||||
if success:
|
||||
logger.info("所有管道已成功停止")
|
||||
else:
|
||||
logger.warning(
|
||||
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
|
||||
self._input_queue.qsize(),
|
||||
self._input_queue.empty()
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
析构
|
||||
"""
|
||||
self.stop()
|
||||
return success
|
||||
|
||||
class STT_RunnerFactory:
|
||||
finally:
|
||||
self._is_stopping = False
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
析构函数
|
||||
"""
|
||||
self.stop(force=True)
|
||||
|
||||
|
||||
class STTRunnerFactory:
|
||||
"""
|
||||
STT Runner工厂类
|
||||
用于创建运行器实例
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _create_runner(
|
||||
audio_binary_name: str,
|
||||
model_name_list: List[str],
|
||||
pipeline_name_list: List[str],
|
||||
):
|
||||
) -> STTRunner:
|
||||
"""
|
||||
全参数创建Runner
|
||||
创建运行器
|
||||
参数:
|
||||
audio_binary_name: 音频二进制名称
|
||||
model_name_list: 模型名称列表
|
||||
@ -142,26 +240,42 @@ class STT_RunnerFactory:
|
||||
for model_name in model_name_list
|
||||
}
|
||||
pipelines: List[Pipeline] = [
|
||||
pipelines_loaded.pipelines[pipeline_name]
|
||||
PipelineFactory.create_pipeline(pipeline_name)
|
||||
for pipeline_name in pipeline_name_list
|
||||
]
|
||||
return Runner(audio_binary, models, pipelines)
|
||||
return STTRunner(
|
||||
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_runner_from_config(
|
||||
cls,
|
||||
config: Dict[str, Any],
|
||||
):
|
||||
) -> STTRunner:
|
||||
"""
|
||||
从配置创建运行器
|
||||
参数:
|
||||
config: 配置字典
|
||||
返回:
|
||||
Runner实例
|
||||
"""
|
||||
audio_binary_name = config["audio_binary_name"]
|
||||
model_name_list = config["model_name_list"]
|
||||
pipeline_name_list = config["pipeline_name_list"]
|
||||
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
|
||||
return cls._create_runner(
|
||||
audio_binary_name, model_name_list, pipeline_name_list
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_runner_normal(
|
||||
cls
|
||||
)
|
||||
def create_runner_normal(cls) -> STTRunner:
|
||||
"""
|
||||
创建默认运行器
|
||||
返回:
|
||||
Runner实例
|
||||
"""
|
||||
audio_binary_name = None
|
||||
model_name_list = models_loaded.models.keys()
|
||||
model_name_list = list(models_loaded.models.keys())
|
||||
pipeline_name_list = None
|
||||
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
|
||||
return cls._create_runner(
|
||||
audio_binary_name, model_name_list, pipeline_name_list
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user