[代码重构中]编写ASRpipeline,管理funtor的线程启动,管理funtor间消息队列queue

This commit is contained in:
Ziyang.Zhang 2025-06-03 09:19:15 +08:00
parent 49cb428c23
commit f245c6e9df
6 changed files with 501 additions and 90 deletions

View File

@ -2,9 +2,9 @@ from typing import Callable
class BaseFunctor: class BaseFunctor:
""" """
基础函数器类提供数据处理的基本框架 基础函数器类, 提供数据处理的基本框架
该类实现了数据处理的基本接口包括数据推送处理和回调机制 该类实现了数据处理的基本接口, 包括数据推送处理和回调机制
所有具体的功能实现类都应该继承这个基类 所有具体的功能实现类都应该继承这个基类
属性: 属性:
@ -22,7 +22,7 @@ class BaseFunctor:
初始化函数器 初始化函数器
参数: 参数:
data (dict or bytes): 初始数据可以是字典或字节数据 data (dict or bytes): 初始数据, 可以是字典或字节数据
callback (Callable): 处理完成后的回调函数 callback (Callable): 处理完成后的回调函数
model (dict): 模型相关的配置和实例 model (dict): 模型相关的配置和实例
""" """
@ -34,33 +34,33 @@ class BaseFunctor:
def __call__(self, data = None): def __call__(self, data = None):
""" """
使类实例可调用处理数据并触发回调 使类实例可调用, 处理数据并触发回调
参数: 参数:
data: 要处理的数据如果为None则处理已存储的数据 data: 要处理的数据, 如果为None则处理已存储的数据
返回: 返回:
处理结果 处理结果
""" """
# 如果传入数据则压入数据 # 如果传入数据, 则压入数据
if data is not None: if data is not None:
self.push_data(data) self.push_data(data)
# 处理数据 # 处理数据
result = self.process() result = self.process()
# 如果回调函数存在则触发回调 # 如果回调函数存在, 则触发回调
if self._callback is not None and callable(self._callback): if self._callback is not None and callable(self._callback):
self._callback(result) self._callback(result)
return result return result
def __add__(self, other): def __add__(self, other):
""" """
重载加法运算符用于合并数据 重载加法运算符, 用于合并数据
参数: 参数:
other: 要合并的数据 other: 要合并的数据
返回: 返回:
self: 返回当前实例支持链式调用 self: 返回当前实例, 支持链式调用
""" """
self.push_data(other) self.push_data(other)
return self return self

190
src/pipeline/ASRpipeline.py Normal file
View File

@ -0,0 +1,190 @@
from src.pipeline.base import PipelineBase
from typing import Dict, Any
from queue import Queue
from src.utils import get_module_logger
logger = get_module_logger(__name__)
class ASRPipeline(PipelineBase):
"""
管道类
实现具体的处理逻辑
"""
def __init__(self, *args, **kwargs):
"""
初始化管道
"""
super().__init__(*args, **kwargs)
self._config: Dict[str, Any] = {}
self._funtor_dict: Dict[str, Any] = {}
self._subqueue_dict: Dict[str, Any] = {}
self._is_baked: bool = False
def set_config(self, config: Dict[str, Any]) -> None:
"""
设置配置
参数:
config: Dict[str, Any] 配置
"""
self._config = config
def set_audio_binary(self, audio_binary: AudioBinary) -> None:
"""
设置音频二进制存储单元
参数:
audio_binary: 音频二进制
"""
self._audio_binary = audio_binary
def set_models(self, models: Dict[str, Any]) -> None:
"""
设置模型
"""
self._models = models
def bake(self) -> None:
"""
烘焙管道
"""
self._init_funtor()
self._is_baked = True
def _init_funtor(self) -> None:
"""
初始化函数
"""
try:
from src.funtor import FuntorFactory
# 加载VAD、asr、spk funtor
self._funtor_dict["vad"] = FuntorFactory.get_funtor(
funtor_name = "vad",
config = self._config,
models = self._models
)
self._funtor_dict["asr"] = FuntorFactory.get_funtor(
funtor_name = "asr",
config = self._config,
models = self._models
)
self._funtor_dict["spk"] = FuntorFactory.get_funtor(
funtor_name = "spk",
config = self._config,
models = self._models
)
# 初始化子队列
self._subqueue_dict["vad2asr"] = Queue()
self._subqueue_dict["vad2spk"] = Queue()
self._subqueue_dict["asrend"] = Queue()
self._subqueue_dict["spkend"] = Queue()
# 设置子队列的输入队列
self._funtor_dict["vad"].set_input_queue(self._input_queue)
self._funtor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._funtor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
# 设置回调函数——放置到对应队列中
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2asr"].put)
self._funtor_dict["vad"].add_callback(self._subqueue_dict["vad2spk"].put)
# 构造带回调函数的put
def put_with_checkcallback(queue: Queue, callback: Callable) -> None:
"""
带回调函数的put
"""
def put_with_check(data: Any) -> None:
queue.put(data)
callback()
return put_with_check
self._funtor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
self._funtor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
except ImportError:
raise ImportError("FuntorFactory引入失败,ASRPipeline无法完成初始化")
def get_config(self) -> Dict[str, Any]:
"""
获取配置
返回:
Dict[str, Any] 配置
"""
return self._config
def process(self, data: Any) -> Any:
"""
处理数据
参数:
data: 输入数据
返回:
处理结果
"""
# 子类实现具体的处理逻辑
self._input_queue.put(data)
def run(self) -> None:
"""
运行管道
"""
if not self._is_baked:
raise RuntimeError("管道未烘焙,无法运行")
# 运行所有funtor
for funtor_name, funtor in self._funtor_dict.items():
logger.info(f"运行{funtor_name}funtor")
funtor.run()
# 运行管道
if not self._input_queue:
raise RuntimeError("输入队列未设置")
# 设置管道运行状态
self._is_running = True
self._stop_event = False
self._thread = threading.current_thread()
logger.info("ASR管道开始运行")
while self._is_running and not self._stop_event:
try:
# 从队列获取数据
try:
data = self._input_queue.get(timeout=self._queue_timeout)
# 检查是否是结束信号
if data is None:
logger.info("收到结束信号,管道准备停止")
self._input_queue.task_done() # 标记结束信号已处理
break
# 处理数据
self.process(data)
# 标记任务完成
self._input_queue.task_done()
except Empty:
# 队列获取超时,继续等待
continue
except Exception as e:
logger.error(f"管道处理数据出错: {str(e)}")
continue
logger.info("管道停止运行")
def _check_result(self, result: Any) -> None:
"""
检查结果
"""
# 若asr和spk队列中都有数据则合并数据
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize():
asr_data = self._subqueue_dict["asrend"].get()
spk_data = self._subqueue_dict["spkend"].get()
# 合并数据
result = {
"asr_data": asr_data,
"spk_data": spk_data
}
# 通知回调函数
self._notify_callbacks(result)

View File

@ -1,3 +1,3 @@
from src.pipeline.base import PipelineBase, Pipeline from src.pipeline.base import PipelineBase, Pipeline, PipelineFactory
__all__ = ["PipelineBase", "Pipeline"] __all__ = ["PipelineBase", "Pipeline", "PipelineFactory"]

View File

@ -1,21 +1,128 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from queue import Queue, Empty
from typing import List, Callable, Any, Optional
import logging
import threading
import time
# 配置日志
logger = logging.getLogger(__name__)
class PipelineBase(ABC): class PipelineBase(ABC):
""" """
管道基类 管道基类
定义了管道的基本接口和通用功能
""" """
def __init__(self, input_queue: Optional[Queue] = None):
"""
初始化管道
参数:
input_queue: 输入队列用于接收数据
"""
self._input_queue = input_queue
self._callbacks: List[Callable] = []
self._is_running = False
self._stop_event = False
self._thread: Optional[threading.Thread] = None
self._stop_timeout = 5 # 默认停止超时时间(秒)
self._queue_timeout = 1 # 队列获取超时时间(秒)
def set_input_queue(self, queue: Queue) -> None:
"""
设置输入队列
参数:
queue: 输入队列
"""
self._input_queue = queue
def add_callback(self, callback: Callable) -> None:
"""
添加回调函数
参数:
callback: 回调函数接收处理结果
"""
self._callbacks.append(callback)
def _notify_callbacks(self, result: Any) -> None:
"""
通知所有回调函数
参数:
result: 处理结果
"""
for callback in self._callbacks:
try:
callback(result)
except Exception as e:
logger.error(f"回调函数执行出错: {str(e)}")
@abstractmethod @abstractmethod
def run(self, *args, **kwargs): def process(self, data: Any) -> Any:
"""
处理数据
参数:
data: 输入数据
返回:
处理结果
"""
pass
@abstractmethod
def run(self) -> None:
""" """
运行管道 运行管道
从输入队列获取数据并处理
""" """
pass
class Pipeline(PipelineBase): def stop(self, timeout: Optional[float] = None) -> bool:
"""
管道类
"""
def __init__(self, *args, **kwargs):
""" """
停止管道
参数:
timeout: 停止超时时间None表示使用默认超时时间
返回:
bool: 是否成功停止
""" """
pass if not self._is_running:
return True
logger.info("正在停止管道...")
self._stop_event = True
self._is_running = False
# 等待线程结束
if self._thread and self._thread.is_alive():
timeout = timeout if timeout is not None else self._stop_timeout
self._thread.join(timeout=timeout)
# 检查是否成功停止
if self._thread.is_alive():
logger.warning(f"管道停止超时({timeout}秒),强制终止")
return False
else:
logger.info("管道已成功停止")
return True
return True
def force_stop(self) -> None:
"""
强制停止管道
注意这可能会导致资源未正确释放
"""
logger.warning("强制停止管道")
self._stop_event = True
self._is_running = False
# 注意Python的线程无法被强制终止这里只是设置标志
# 实际终止需要依赖操作系统的进程管理
class PipelineFactory:
"""
管道工厂类
用于创建管道实例
"""
@staticmethod
def create_pipeline(pipeline_name: str) -> Pipeline:
"""
创建管道实例
"""
pass

0
src/pipeline/test.py Normal file
View File

View File

@ -6,129 +6,227 @@
- Runner: 运行器类,工厂模式实现 - Runner: 运行器类,工厂模式实现
- RunnerFactory: 运行器工厂类,用于创建运行器 - RunnerFactory: 运行器工厂类,用于创建运行器
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Any, List, Queue from typing import Dict, Any, List
from threading import Thread, Lock
from queue import Queue
import traceback
import time
from src.audio_chunk import AudioChunk, AudioBinary from src.audio_chunk import AudioChunk, AudioBinary
from src.pipeline import Pipeline from src.pipeline import Pipeline, PipelineFactory
from src.model_loader import ModelLoader from src.model_loader import ModelLoader
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__, level="INFO")
audio_chunk = AudioChunk() audio_chunk = AudioChunk()
models_loaded = ModelLoader() models_loaded = ModelLoader()
pipelines_loaded = PipelineLoader()
class RunnerBase(ABC): class RunnerBase(ABC):
""" """
运行器基类 运行器基类
定义了运行器的基本接口
""" """
# 计算资源
_audio_binary: AudioBinary = None
_models: Dict[str, Any] = {}
_pipeline: Pipeline = None
# IO交互
_receivers: List[callable] = []
# 异步交互消息队列
_input_queue: Queue = None
_output_queue: Queue = None
@abstractmethod @abstractmethod
def adder(self, *args, **kwargs): def adder(self, data: Any) -> None:
""" """
添加数据 添加数据
参数:
data: 要添加的数据
""" """
pass
@abstractmethod @abstractmethod
def add_recevier(self, *args, **kwargs): def add_recevier(self, receiver: callable) -> None:
""" """
接收数据 添加数据接收者
参数:
receiver: 接收数据的回调函数
""" """
pass
class STT_Runner(RunnerBase):
class STTRunner(RunnerBase):
""" """
运行器类 运行器类
工厂模式 负责管理资源和协调Pipeline的运行
""" """
def __init__( def __init__(
self, self,
*, *,
audio_binary_list: List[AudioBinary], audio_binary_list: List[AudioBinary],
models: Dict[str, Any], models: Dict[str, Any],
pipeline_list: List[Pipeline], pipeline_list: List[Pipeline],
input_queue: Queue,
): ):
""" """
初始化 初始化运行器
参数:
audio_binary_list: 音频二进制列表
models: 模型字典
pipeline_list: 管道列表
queue_size: 队列大小
stop_timeout: 停止超时时间
""" """
# 接收资源 # 接收资源
self._audio_binary_list = audio_binary_list self._audio_binary_list = audio_binary_list
self._models = models self._models = models
self._pipeline_list = pipeline_list self._pipeline_list = pipeline_list
# 线程控制
self._lock = Lock()
# 消息队列
self._input_queue = Queue(maxsize=1000)
# 停止控制
self._stop_timeout = 10.0
self._is_stopping = False
# 配置资源 # 配置资源
for pipeline in self._pipeline_list: for pipeline in self._pipeline_list:
# 配置 # 设置输入队列
if pipeline.get_config('audio_binary_name') is not None: pipeline.set_input_queue(self._input_queue)
pipeline.set_audio_binary(self._audio_binary_list[pipeline.get_config('audio_binary_name')])
if pipeline.get_config('model_name_list') is not None:
pipeline.set_models(self._models)
def adder(self, *args, **kwargs): # 配置资源
pipeline.set_audio_binary(
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
)
pipeline.set_models(self._models)
def adder(self, data: Any) -> None:
""" """
添加数据 添加数据到输入队列
参数:
data: 要添加的数据
""" """
if self._pipeline_thread is None: if not self._pipeline_list:
raise RuntimeError("Pipeline thread not started") raise RuntimeError("没有可用的管道")
self._input_queue.put(args["data"]) if self._is_stopping:
raise RuntimeError("运行器正在停止,无法添加数据")
def add_recevier(self, recevier: callable): self._input_queue.put(data)
def add_recevier(self, receiver: callable) -> None:
""" """
添加数据接收者 添加数据接收者
参数:
receiver: 接收数据的回调函数
""" """
if self._pipeline_thread is None: with self._lock:
raise RuntimeError("Pipeline thread not started") for pipeline in self._pipeline_list:
self._receivers.append(recevier) pipeline.add_callback(receiver)
def run(self): def run(self) -> None:
""" """
运行pipeline子线程 启动所有管道
""" """
# 创建pipeline子线程 logger.info("[%s] 启动所有管道", self.__class__.__name__)
self._pipeline_thread = threading.Thread( if not self._pipeline_list:
target=self._pipeline.run, raise RuntimeError("没有可用的管道")
args=(
input_queue=self._input_queue,
output_queue=self._output_queue
)
)
self._pipeline_thread.start()
def stop(self): # 启动所有管道
""" for pipeline in self._pipeline_list:
停止pipeline子线程 thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
""" thread.daemon = True
# 结束pipeline子线程 thread.start()
self._pipeline_thread.join() logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
def __del__(self):
"""
析构
"""
self.stop()
class STT_RunnerFactory: def stop(self, force: bool = False) -> bool:
"""
停止所有管道
参数:
force: 是否强制停止
返回:
bool: 是否成功停止
"""
if self._is_stopping:
logger.warning("运行器已经在停止中")
return False
self._is_stopping = True
logger.info("正在停止运行器...")
try:
# 发送结束信号
self._input_queue.put(None)
# 停止所有管道
success = True
for pipeline in self._pipeline_list:
if force:
pipeline.force_stop()
else:
if not pipeline.stop(timeout=self._stop_timeout):
logger.warning("管道 %s 停止超时", id(pipeline))
success = False
# 等待队列处理完成
try:
start_time = time.time()
while not self._input_queue.empty():
if time.time() - start_time > self._stop_timeout:
logger.warning(
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
self._stop_timeout,
self._input_queue.qsize()
)
success = False
break
time.sleep(0.1) # 避免过度消耗CPU
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
error_traceback = traceback.format_exc()
logger.error(
"等待队列处理完成时发生错误:\n"
"错误类型: %s\n"
"错误信息: %s\n"
"错误堆栈:\n%s",
error_type,
error_msg,
error_traceback
)
success = False
if success:
logger.info("所有管道已成功停止")
else:
logger.warning(
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(),
self._input_queue.empty()
)
return success
finally:
self._is_stopping = False
def __del__(self) -> None:
"""
析构函数
"""
self.stop(force=True)
class STTRunnerFactory:
""" """
STT Runner工厂类 STT Runner工厂类
用于创建运行器实例
""" """
@staticmethod
def _create_runner( def _create_runner(
audio_binary_name: str, audio_binary_name: str,
model_name_list: List[str], model_name_list: List[str],
pipeline_name_list: List[str], pipeline_name_list: List[str],
): ) -> STTRunner:
""" """
全参数创建Runner 创建运行器
参数: 参数:
audio_binary_name: 音频二进制名称 audio_binary_name: 音频二进制名称
model_name_list: 模型名称列表 model_name_list: 模型名称列表
@ -142,26 +240,42 @@ class STT_RunnerFactory:
for model_name in model_name_list for model_name in model_name_list
} }
pipelines: List[Pipeline] = [ pipelines: List[Pipeline] = [
pipelines_loaded.pipelines[pipeline_name] PipelineFactory.create_pipeline(pipeline_name)
for pipeline_name in pipeline_name_list for pipeline_name in pipeline_name_list
] ]
return Runner(audio_binary, models, pipelines) return STTRunner(
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
)
@classmethod @classmethod
def create_runner_from_config( def create_runner_from_config(
cls, cls,
config: Dict[str, Any], config: Dict[str, Any],
): ) -> STTRunner:
"""
从配置创建运行器
参数:
config: 配置字典
返回:
Runner实例
"""
audio_binary_name = config["audio_binary_name"] audio_binary_name = config["audio_binary_name"]
model_name_list = config["model_name_list"] model_name_list = config["model_name_list"]
pipeline_name_list = config["pipeline_name_list"] pipeline_name_list = config["pipeline_name_list"]
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list) return cls._create_runner(
audio_binary_name, model_name_list, pipeline_name_list
)
@classmethod @classmethod
def create_runner_normal( def create_runner_normal(cls) -> STTRunner:
cls """
) 创建默认运行器
返回:
Runner实例
"""
audio_binary_name = None audio_binary_name = None
model_name_list = models_loaded.models.keys() model_name_list = list(models_loaded.models.keys())
pipeline_name_list = None pipeline_name_list = None
return cls._create_runner(audio_binary_name, model_name_list, pipeline_name_list) return cls._create_runner(
audio_binary_name, model_name_list, pipeline_name_list
)