[代码重构中]编写ASRpipeline,管理funtor的线程启动,管理funtor间消息队列queue
This commit is contained in:
parent
49cb428c23
commit
f245c6e9df
@ -2,9 +2,9 @@ from typing import Callable
|
|||||||
|
|
||||||
class BaseFunctor:
|
class BaseFunctor:
|
||||||
"""
|
"""
|
||||||
基础函数器类,提供数据处理的基本框架
|
基础函数器类, 提供数据处理的基本框架
|
||||||
|
|
||||||
该类实现了数据处理的基本接口,包括数据推送、处理和回调机制。
|
该类实现了数据处理的基本接口, 包括数据推送、处理和回调机制。
|
||||||
所有具体的功能实现类都应该继承这个基类。
|
所有具体的功能实现类都应该继承这个基类。
|
||||||
|
|
||||||
属性:
|
属性:
|
||||||
@ -22,7 +22,7 @@ class BaseFunctor:
|
|||||||
初始化函数器
|
初始化函数器
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
data (dict or bytes): 初始数据,可以是字典或字节数据
|
data (dict or bytes): 初始数据, 可以是字典或字节数据
|
||||||
callback (Callable): 处理完成后的回调函数
|
callback (Callable): 处理完成后的回调函数
|
||||||
model (dict): 模型相关的配置和实例
|
model (dict): 模型相关的配置和实例
|
||||||
"""
|
"""
|
||||||
@ -34,33 +34,33 @@ class BaseFunctor:
|
|||||||
|
|
||||||
def __call__(self, data = None):
|
def __call__(self, data = None):
|
||||||
"""
|
"""
|
||||||
使类实例可调用,处理数据并触发回调
|
使类实例可调用, 处理数据并触发回调
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
data: 要处理的数据,如果为None则处理已存储的数据
|
data: 要处理的数据, 如果为None则处理已存储的数据
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
处理结果
|
处理结果
|
||||||
"""
|
"""
|
||||||
# 如果传入数据,则压入数据
|
# 如果传入数据, 则压入数据
|
||||||
if data is not None:
|
if data is not None:
|
||||||
self.push_data(data)
|
self.push_data(data)
|
||||||
# 处理数据
|
# 处理数据
|
||||||
result = self.process()
|
result = self.process()
|
||||||
# 如果回调函数存在,则触发回调
|
# 如果回调函数存在, 则触发回调
|
||||||
if self._callback is not None and callable(self._callback):
|
if self._callback is not None and callable(self._callback):
|
||||||
self._callback(result)
|
self._callback(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
"""
|
"""
|
||||||
重载加法运算符,用于合并数据
|
重载加法运算符, 用于合并数据
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
other: 要合并的数据
|
other: 要合并的数据
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
self: 返回当前实例,支持链式调用
|
self: 返回当前实例, 支持链式调用
|
||||||
"""
|
"""
|
||||||
self.push_data(other)
|
self.push_data(other)
|
||||||
return self
|
return self
|
||||||
|
190
src/pipeline/ASRpipeline.py
Normal file
190
src/pipeline/ASRpipeline.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
from src.pipeline.base import PipelineBase
|
||||||
|
from typing import Dict, Any
|
||||||
|
from queue import Queue
|
||||||
|
from src.utils import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
class ASRPipeline(PipelineBase):
|
||||||
|
"""
|
||||||
|
管道类
|
||||||
|
实现具体的处理逻辑
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
初始化管道
|
||||||
|
"""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._config: Dict[str, Any] = {}
|
||||||
|
self._funtor_dict: Dict[str, Any] = {}
|
||||||
|
self._subqueue_dict: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
self._is_baked: bool = False
|
||||||
|
|
||||||
|
def set_config(self, config: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
设置配置
|
||||||
|
参数:
|
||||||
|
config: Dict[str, Any] 配置
|
||||||
|
"""
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
def set_audio_binary(self, audio_binary: AudioBinary) -> None:
|
||||||
|
"""
|
||||||
|
设置音频二进制存储单元
|
||||||
|
参数:
|
||||||
|
audio_binary: 音频二进制
|
||||||
|
"""
|
||||||
|
self._audio_binary = audio_binary
|
||||||
|
|
||||||
|
def set_models(self, models: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
设置模型
|
||||||
|
"""
|
||||||
|
self._models = models
|
||||||
|
|
||||||
|
def bake(self) -> None:
|
||||||
|
"""
|
||||||
|
烘焙管道
|
||||||
|
"""
|
||||||
|
self._init_funtor()
|
||||||
|
self._is_baked = True
|
||||||
|
|
||||||
|
def _init_funtor(self) -> None:
|
||||||
|
"""
|
||||||
|
初始化函数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.funtor import FuntorFactory
|
||||||
|
# 加载VAD、asr、spk funtor
|
||||||
|
self._funtor_dict["vad"] = FuntorFactory.get_funtor(
|
||||||
|
funtor_name = "vad",
|
||||||
|
config = self._config,
|
||||||
|
models = self._models
|
||||||
|
)
|
||||||
|
self._funtor_dict["asr"] = FuntorFactory.get_funtor(
|
||||||
|
funtor_name = "asr",
|
||||||
|
config = self._config,
|
||||||
|
models = self._models
|
||||||
|
)
|
||||||
|
self._funtor_dict["spk"] = FuntorFactory.get_funtor(
|
||||||
|
funtor_name = "spk",
|
||||||
|
config = self._config,
|
||||||
|
models = self._models
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化子队列
|
||||||
|
self._subqueue_dict["vad2asr"] = Queue()
|
||||||
|
self._subqueue_dict["vad2spk"] = Queue()
|
||||||
|
self._subqueue_dict["asrend"] = Queue()
|
||||||
|
self._subqueue_dict["spkend"] = Queue()
|
||||||
|
|
||||||
|
# 设置子队列的输入队列
|
||||||
|
self._funtor_dict["vad"].set_input_queue(self._input_queue)
|
||||||
|
self._funtor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||||
|
self._funtor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||||
|
|
||||||
|
# 设置回调函数——放置到对应队列中
|
||||||
|
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
|
||||||
|
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
|
||||||
|
|
||||||
|
# 构造带回调函数的put
|
||||||
|
def put_with_checkcallback(queue: Queue, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
带回调函数的put
|
||||||
|
"""
|
||||||
|
def put_with_check(data: Any) -> None:
|
||||||
|
queue.put(data)
|
||||||
|
callback()
|
||||||
|
return put_with_check
|
||||||
|
|
||||||
|
self._funtor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
|
||||||
|
self._funtor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("FuntorFactory引入失败,ASRPipeline无法完成初始化")
|
||||||
|
|
||||||
|
def get_config(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取配置
|
||||||
|
返回:
|
||||||
|
Dict[str, Any] 配置
|
||||||
|
"""
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def process(self, data: Any) -> Any:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
参数:
|
||||||
|
data: 输入数据
|
||||||
|
返回:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
# 子类实现具体的处理逻辑
|
||||||
|
self._input_queue.put(data)
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""
|
||||||
|
运行管道
|
||||||
|
"""
|
||||||
|
if not self._is_baked:
|
||||||
|
raise RuntimeError("管道未烘焙,无法运行")
|
||||||
|
|
||||||
|
# 运行所有funtor
|
||||||
|
for funtor_name, funtor in self._funtor_dict.items():
|
||||||
|
logger.info(f"运行{funtor_name}funtor")
|
||||||
|
funtor.run()
|
||||||
|
|
||||||
|
# 运行管道
|
||||||
|
if not self._input_queue:
|
||||||
|
raise RuntimeError("输入队列未设置")
|
||||||
|
|
||||||
|
# 设置管道运行状态
|
||||||
|
self._is_running = True
|
||||||
|
self._stop_event = False
|
||||||
|
self._thread = threading.current_thread()
|
||||||
|
logger.info("ASR管道开始运行")
|
||||||
|
|
||||||
|
while self._is_running and not self._stop_event:
|
||||||
|
try:
|
||||||
|
# 从队列获取数据
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(timeout=self._queue_timeout)
|
||||||
|
# 检查是否是结束信号
|
||||||
|
if data is None:
|
||||||
|
logger.info("收到结束信号,管道准备停止")
|
||||||
|
self._input_queue.task_done() # 标记结束信号已处理
|
||||||
|
break
|
||||||
|
|
||||||
|
# 处理数据
|
||||||
|
self.process(data)
|
||||||
|
|
||||||
|
# 标记任务完成
|
||||||
|
self._input_queue.task_done()
|
||||||
|
|
||||||
|
except Empty:
|
||||||
|
# 队列获取超时,继续等待
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"管道处理数据出错: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info("管道停止运行")
|
||||||
|
|
||||||
|
def _check_result(self, result: Any) -> None:
|
||||||
|
"""
|
||||||
|
检查结果
|
||||||
|
"""
|
||||||
|
# 若asr和spk队列中都有数据,则合并数据
|
||||||
|
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize():
|
||||||
|
asr_data = self._subqueue_dict["asrend"].get()
|
||||||
|
spk_data = self._subqueue_dict["spkend"].get()
|
||||||
|
# 合并数据
|
||||||
|
result = {
|
||||||
|
"asr_data": asr_data,
|
||||||
|
"spk_data": spk_data
|
||||||
|
}
|
||||||
|
# 通知回调函数
|
||||||
|
self._notify_callbacks(result)
|
||||||
|
|
@ -1,3 +1,3 @@
|
|||||||
from src.pipeline.base import PipelineBase, Pipeline
|
from src.pipeline.base import PipelineBase, Pipeline, PipelineFactory
|
||||||
|
|
||||||
__all__ = ["PipelineBase", "Pipeline"]
|
__all__ = ["PipelineBase", "Pipeline", "PipelineFactory"]
|
||||||
|
@ -1,21 +1,128 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from queue import Queue, Empty
|
||||||
|
from typing import List, Callable, Any, Optional
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class PipelineBase(ABC):
|
class PipelineBase(ABC):
|
||||||
"""
|
"""
|
||||||
管道基类
|
管道基类
|
||||||
|
定义了管道的基本接口和通用功能
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, input_queue: Optional[Queue] = None):
|
||||||
|
"""
|
||||||
|
初始化管道
|
||||||
|
参数:
|
||||||
|
input_queue: 输入队列,用于接收数据
|
||||||
|
"""
|
||||||
|
self._input_queue = input_queue
|
||||||
|
self._callbacks: List[Callable] = []
|
||||||
|
self._is_running = False
|
||||||
|
self._stop_event = False
|
||||||
|
self._thread: Optional[threading.Thread] = None
|
||||||
|
self._stop_timeout = 5 # 默认停止超时时间(秒)
|
||||||
|
self._queue_timeout = 1 # 队列获取超时时间(秒)
|
||||||
|
|
||||||
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
|
"""
|
||||||
|
设置输入队列
|
||||||
|
参数:
|
||||||
|
queue: 输入队列
|
||||||
|
"""
|
||||||
|
self._input_queue = queue
|
||||||
|
|
||||||
|
def add_callback(self, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
添加回调函数
|
||||||
|
参数:
|
||||||
|
callback: 回调函数,接收处理结果
|
||||||
|
"""
|
||||||
|
self._callbacks.append(callback)
|
||||||
|
|
||||||
|
def _notify_callbacks(self, result: Any) -> None:
|
||||||
|
"""
|
||||||
|
通知所有回调函数
|
||||||
|
参数:
|
||||||
|
result: 处理结果
|
||||||
|
"""
|
||||||
|
for callback in self._callbacks:
|
||||||
|
try:
|
||||||
|
callback(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"回调函数执行出错: {str(e)}")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(self, *args, **kwargs):
|
def process(self, data: Any) -> Any:
|
||||||
|
"""
|
||||||
|
处理数据
|
||||||
|
参数:
|
||||||
|
data: 输入数据
|
||||||
|
返回:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self) -> None:
|
||||||
"""
|
"""
|
||||||
运行管道
|
运行管道
|
||||||
|
从输入队列获取数据并处理
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
class Pipeline(PipelineBase):
|
def stop(self, timeout: Optional[float] = None) -> bool:
|
||||||
"""
|
|
||||||
管道类
|
|
||||||
"""
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
"""
|
"""
|
||||||
|
停止管道
|
||||||
|
参数:
|
||||||
|
timeout: 停止超时时间(秒),None表示使用默认超时时间
|
||||||
|
返回:
|
||||||
|
bool: 是否成功停止
|
||||||
"""
|
"""
|
||||||
pass
|
if not self._is_running:
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info("正在停止管道...")
|
||||||
|
self._stop_event = True
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
|
# 等待线程结束
|
||||||
|
if self._thread and self._thread.is_alive():
|
||||||
|
timeout = timeout if timeout is not None else self._stop_timeout
|
||||||
|
self._thread.join(timeout=timeout)
|
||||||
|
|
||||||
|
# 检查是否成功停止
|
||||||
|
if self._thread.is_alive():
|
||||||
|
logger.warning(f"管道停止超时({timeout}秒),强制终止")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info("管道已成功停止")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def force_stop(self) -> None:
|
||||||
|
"""
|
||||||
|
强制停止管道
|
||||||
|
注意:这可能会导致资源未正确释放
|
||||||
|
"""
|
||||||
|
logger.warning("强制停止管道")
|
||||||
|
self._stop_event = True
|
||||||
|
self._is_running = False
|
||||||
|
# 注意:Python的线程无法被强制终止,这里只是设置标志
|
||||||
|
# 实际终止需要依赖操作系统的进程管理
|
||||||
|
|
||||||
|
class PipelineFactory:
|
||||||
|
"""
|
||||||
|
管道工厂类
|
||||||
|
用于创建管道实例
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def create_pipeline(pipeline_name: str) -> Pipeline:
|
||||||
|
"""
|
||||||
|
创建管道实例
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
0
src/pipeline/test.py
Normal file
0
src/pipeline/test.py
Normal file
258
src/runner.py
258
src/runner.py
@ -6,129 +6,227 @@
|
|||||||
- Runner: 运行器类,工厂模式实现
|
- Runner: 运行器类,工厂模式实现
|
||||||
- RunnerFactory: 运行器工厂类,用于创建运行器
|
- RunnerFactory: 运行器工厂类,用于创建运行器
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Any, List, Queue
|
from typing import Dict, Any, List
|
||||||
|
from threading import Thread, Lock
|
||||||
|
from queue import Queue
|
||||||
|
import traceback
|
||||||
|
import time
|
||||||
|
|
||||||
from src.audio_chunk import AudioChunk, AudioBinary
|
from src.audio_chunk import AudioChunk, AudioBinary
|
||||||
from src.pipeline import Pipeline
|
from src.pipeline import Pipeline, PipelineFactory
|
||||||
from src.model_loader import ModelLoader
|
from src.model_loader import ModelLoader
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
audio_chunk = AudioChunk()
|
audio_chunk = AudioChunk()
|
||||||
models_loaded = ModelLoader()
|
models_loaded = ModelLoader()
|
||||||
pipelines_loaded = PipelineLoader()
|
|
||||||
|
|
||||||
class RunnerBase(ABC):
|
class RunnerBase(ABC):
|
||||||
"""
|
"""
|
||||||
运行器基类
|
运行器基类
|
||||||
|
定义了运行器的基本接口
|
||||||
"""
|
"""
|
||||||
# 计算资源
|
|
||||||
_audio_binary: AudioBinary = None
|
|
||||||
_models: Dict[str, Any] = {}
|
|
||||||
_pipeline: Pipeline = None
|
|
||||||
|
|
||||||
# IO交互
|
|
||||||
_receivers: List[callable] = []
|
|
||||||
|
|
||||||
# 异步交互消息队列
|
|
||||||
_input_queue: Queue = None
|
|
||||||
_output_queue: Queue = None
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def adder(self, *args, **kwargs):
|
def adder(self, data: Any) -> None:
|
||||||
"""
|
"""
|
||||||
添加数据
|
添加数据
|
||||||
|
参数:
|
||||||
|
data: 要添加的数据
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_recevier(self, *args, **kwargs):
|
def add_recevier(self, receiver: callable) -> None:
|
||||||
"""
|
"""
|
||||||
接收数据
|
添加数据接收者
|
||||||
|
参数:
|
||||||
|
receiver: 接收数据的回调函数
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
class STT_Runner(RunnerBase):
|
|
||||||
|
class STTRunner(RunnerBase):
|
||||||
"""
|
"""
|
||||||
运行器类
|
运行器类
|
||||||
工厂模式
|
负责管理资源和协调Pipeline的运行
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
audio_binary_list: List[AudioBinary],
|
audio_binary_list: List[AudioBinary],
|
||||||
models: Dict[str, Any],
|
models: Dict[str, Any],
|
||||||
pipeline_list: List[Pipeline],
|
pipeline_list: List[Pipeline],
|
||||||
input_queue: Queue,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化
|
初始化运行器
|
||||||
|
参数:
|
||||||
|
audio_binary_list: 音频二进制列表
|
||||||
|
models: 模型字典
|
||||||
|
pipeline_list: 管道列表
|
||||||
|
queue_size: 队列大小
|
||||||
|
stop_timeout: 停止超时时间(秒)
|
||||||
"""
|
"""
|
||||||
# 接收资源
|
# 接收资源
|
||||||
self._audio_binary_list = audio_binary_list
|
self._audio_binary_list = audio_binary_list
|
||||||
self._models = models
|
self._models = models
|
||||||
self._pipeline_list = pipeline_list
|
self._pipeline_list = pipeline_list
|
||||||
|
|
||||||
|
# 线程控制
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
# 消息队列
|
||||||
|
self._input_queue = Queue(maxsize=1000)
|
||||||
|
|
||||||
|
# 停止控制
|
||||||
|
self._stop_timeout = 10.0
|
||||||
|
self._is_stopping = False
|
||||||
|
|
||||||
# 配置资源
|
# 配置资源
|
||||||
for pipeline in self._pipeline_list:
|
for pipeline in self._pipeline_list:
|
||||||
# 配置
|
# 设置输入队列
|
||||||
if pipeline.get_config('audio_binary_name') is not None:
|
pipeline.set_input_queue(self._input_queue)
|
||||||
pipeline.set_audio_binary(self._audio_binary_list[pipeline.get_config('audio_binary_name')])
|
|
||||||
if pipeline.get_config('model_name_list') is not None:
|
|
||||||
pipeline.set_models(self._models)
|
|
||||||
|
|
||||||
def adder(self, *args, **kwargs):
|
# 配置资源
|
||||||
|
pipeline.set_audio_binary(
|
||||||
|
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
|
||||||
|
)
|
||||||
|
pipeline.set_models(self._models)
|
||||||
|
|
||||||
|
def adder(self, data: Any) -> None:
|
||||||
"""
|
"""
|
||||||
添加数据
|
添加数据到输入队列
|
||||||
|
参数:
|
||||||
|
data: 要添加的数据
|
||||||
"""
|
"""
|
||||||
if self._pipeline_thread is None:
|
if not self._pipeline_list:
|
||||||
raise RuntimeError("Pipeline thread not started")
|
raise RuntimeError("没有可用的管道")
|
||||||
self._input_queue.put(args["data"])
|
if self._is_stopping:
|
||||||
|
raise RuntimeError("运行器正在停止,无法添加数据")
|
||||||
def add_recevier(self, recevier: callable):
|
self._input_queue.put(data)
|
||||||
|
|
||||||
|
def add_recevier(self, receiver: callable) -> None:
|
||||||
"""
|
"""
|
||||||
添加数据接收者
|
添加数据接收者
|
||||||
|
参数:
|
||||||
|
receiver: 接收数据的回调函数
|
||||||
"""
|
"""
|
||||||
if self._pipeline_thread is None:
|
with self._lock:
|
||||||
raise RuntimeError("Pipeline thread not started")
|
for pipeline in self._pipeline_list:
|
||||||
self._receivers.append(recevier)
|
pipeline.add_callback(receiver)
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> None:
|
||||||
"""
|
"""
|
||||||
运行pipeline子线程
|
启动所有管道
|
||||||
"""
|
"""
|
||||||
# 创建pipeline子线程
|
logger.info("[%s] 启动所有管道", self.__class__.__name__)
|
||||||
self._pipeline_thread = threading.Thread(
|
if not self._pipeline_list:
|
||||||
target=self._pipeline.run,
|
raise RuntimeError("没有可用的管道")
|
||||||
args=(
|
|
||||||
input_queue=self._input_queue,
|
|
||||||
output_queue=self._output_queue
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self._pipeline_thread.start()
|
|
||||||
|
|
||||||
def stop(self):
|
# 启动所有管道
|
||||||
"""
|
for pipeline in self._pipeline_list:
|
||||||
停止pipeline子线程
|
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
|
||||||
"""
|
thread.daemon = True
|
||||||
# 结束pipeline子线程
|
thread.start()
|
||||||
self._pipeline_thread.join()
|
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""
|
|
||||||
析构
|
|
||||||
"""
|
|
||||||
self.stop()
|
|
||||||
|
|
||||||
class STT_RunnerFactory:
|
def stop(self, force: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
停止所有管道
|
||||||
|
参数:
|
||||||
|
force: 是否强制停止
|
||||||
|
返回:
|
||||||
|
bool: 是否成功停止
|
||||||
|
"""
|
||||||
|
if self._is_stopping:
|
||||||
|
logger.warning("运行器已经在停止中")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._is_stopping = True
|
||||||
|
logger.info("正在停止运行器...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 发送结束信号
|
||||||
|
self._input_queue.put(None)
|
||||||
|
|
||||||
|
# 停止所有管道
|
||||||
|
success = True
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
if force:
|
||||||
|
pipeline.force_stop()
|
||||||
|
else:
|
||||||
|
if not pipeline.stop(timeout=self._stop_timeout):
|
||||||
|
logger.warning("管道 %s 停止超时", id(pipeline))
|
||||||
|
success = False
|
||||||
|
|
||||||
|
# 等待队列处理完成
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
while not self._input_queue.empty():
|
||||||
|
if time.time() - start_time > self._stop_timeout:
|
||||||
|
logger.warning(
|
||||||
|
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
|
||||||
|
self._stop_timeout,
|
||||||
|
self._input_queue.qsize()
|
||||||
|
)
|
||||||
|
success = False
|
||||||
|
break
|
||||||
|
time.sleep(0.1) # 避免过度消耗CPU
|
||||||
|
except Exception as e:
|
||||||
|
error_type = type(e).__name__
|
||||||
|
error_msg = str(e)
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
logger.error(
|
||||||
|
"等待队列处理完成时发生错误:\n"
|
||||||
|
"错误类型: %s\n"
|
||||||
|
"错误信息: %s\n"
|
||||||
|
"错误堆栈:\n%s",
|
||||||
|
error_type,
|
||||||
|
error_msg,
|
||||||
|
error_traceback
|
||||||
|
)
|
||||||
|
success = False
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info("所有管道已成功停止")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
|
||||||
|
self._input_queue.qsize(),
|
||||||
|
self._input_queue.empty()
|
||||||
|
)
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._is_stopping = False
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
"""
|
||||||
|
析构函数
|
||||||
|
"""
|
||||||
|
self.stop(force=True)
|
||||||
|
|
||||||
|
|
||||||
|
class STTRunnerFactory:
|
||||||
"""
|
"""
|
||||||
STT Runner工厂类
|
STT Runner工厂类
|
||||||
|
用于创建运行器实例
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _create_runner(
|
def _create_runner(
|
||||||
audio_binary_name: str,
|
audio_binary_name: str,
|
||||||
model_name_list: List[str],
|
model_name_list: List[str],
|
||||||
pipeline_name_list: List[str],
|
pipeline_name_list: List[str],
|
||||||
):
|
) -> STTRunner:
|
||||||
"""
|
"""
|
||||||
全参数创建Runner
|
创建运行器
|
||||||
参数:
|
参数:
|
||||||
audio_binary_name: 音频二进制名称
|
audio_binary_name: 音频二进制名称
|
||||||
model_name_list: 模型名称列表
|
model_name_list: 模型名称列表
|
||||||
@ -142,26 +240,42 @@ class STT_RunnerFactory:
|
|||||||
for model_name in model_name_list
|
for model_name in model_name_list
|
||||||
}
|
}
|
||||||
pipelines: List[Pipeline] = [
|
pipelines: List[Pipeline] = [
|
||||||
pipelines_loaded.pipelines[pipeline_name]
|
PipelineFactory.create_pipeline(pipeline_name)
|
||||||
for pipeline_name in pipeline_name_list
|
for pipeline_name in pipeline_name_list
|
||||||
]
|
]
|
||||||
return Runner(audio_binary, models, pipelines)
|
return STTRunner(
|
||||||
|
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_runner_from_config(
|
def create_runner_from_config(
|
||||||
cls,
|
cls,
|
||||||
config: Dict[str, Any],
|
config: Dict[str, Any],
|
||||||
):
|
) -> STTRunner:
|
||||||
|
"""
|
||||||
|
从配置创建运行器
|
||||||
|
参数:
|
||||||
|
config: 配置字典
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
audio_binary_name = config["audio_binary_name"]
|
audio_binary_name = config["audio_binary_name"]
|
||||||
model_name_list = config["model_name_list"]
|
model_name_list = config["model_name_list"]
|
||||||
pipeline_name_list = config["pipeline_name_list"]
|
pipeline_name_list = config["pipeline_name_list"]
|
||||||
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
|
return cls._create_runner(
|
||||||
|
audio_binary_name, model_name_list, pipeline_name_list
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_runner_normal(
|
def create_runner_normal(cls) -> STTRunner:
|
||||||
cls
|
"""
|
||||||
)
|
创建默认运行器
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
audio_binary_name = None
|
audio_binary_name = None
|
||||||
model_name_list = models_loaded.models.keys()
|
model_name_list = list(models_loaded.models.keys())
|
||||||
pipeline_name_list = None
|
pipeline_name_list = None
|
||||||
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list)
|
return cls._create_runner(
|
||||||
|
audio_binary_name, model_name_list, pipeline_name_list
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user