[代码重构中] 将ModelLoader移至core目录,更新相关测试文件的导入路径。创建ASRRunner,初步搭建框架。

This commit is contained in:
Ziyang.Zhang 2025-06-24 09:22:48 +08:00
parent 5a820b49e4
commit 7b9a79942d
8 changed files with 335 additions and 284 deletions

View File

@ -1,281 +0,0 @@
"""
运行器模块
提供运行器基类和运行器类用于管理音频数据和模型的交互
主要包含:
- RunnerBase: 运行器基类,定义了基本接口
- Runner: 运行器类,工厂模式实现
- RunnerFactory: 运行器工厂类,用于创建运行器
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from threading import Thread, Lock
from queue import Queue
import traceback
import time
from src.audio_chunk import AudioChunk, AudioBinary
from src.pipeline import Pipeline, PipelineFactory
from src.model_loader import ModelLoader
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__, level="INFO")
audio_chunk = AudioChunk()
models_loaded = ModelLoader()
class RunnerBase(ABC):
"""
运行器基类
定义了运行器的基本接口
"""
@abstractmethod
def adder(self, data: Any) -> None:
"""
添加数据
参数:
data: 要添加的数据
"""
pass
@abstractmethod
def add_recevier(self, receiver: callable) -> None:
"""
添加数据接收者
参数:
receiver: 接收数据的回调函数
"""
pass
class STTRunner(RunnerBase):
"""
运行器类
负责管理资源和协调Pipeline的运行
"""
def __init__(
self,
*,
audio_binary_list: List[AudioBinary],
models: Dict[str, Any],
pipeline_list: List[Pipeline],
):
"""
初始化运行器
参数:
audio_binary_list: 音频二进制列表
models: 模型字典
pipeline_list: 管道列表
queue_size: 队列大小
stop_timeout: 停止超时时间
"""
# 接收资源
self._audio_binary_list = audio_binary_list
self._models = models
self._pipeline_list = pipeline_list
# 线程控制
self._lock = Lock()
# 消息队列
self._input_queue = Queue(maxsize=1000)
# 停止控制
self._stop_timeout = 10.0
self._is_stopping = False
# 配置资源
for pipeline in self._pipeline_list:
# 设置输入队列
pipeline.set_input_queue(self._input_queue)
# 配置资源
pipeline.set_audio_binary(
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
)
pipeline.set_models(self._models)
def adder(self, data: Any) -> None:
"""
添加数据到输入队列
参数:
data: 要添加的数据
"""
if not self._pipeline_list:
raise RuntimeError("没有可用的管道")
if self._is_stopping:
raise RuntimeError("运行器正在停止,无法添加数据")
self._input_queue.put(data)
def add_recevier(self, receiver: callable) -> None:
"""
添加数据接收者
参数:
receiver: 接收数据的回调函数
"""
with self._lock:
for pipeline in self._pipeline_list:
pipeline.add_callback(receiver)
def run(self) -> None:
"""
启动所有管道
"""
logger.info("[%s] 启动所有管道", self.__class__.__name__)
if not self._pipeline_list:
raise RuntimeError("没有可用的管道")
# 启动所有管道
for pipeline in self._pipeline_list:
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
thread.daemon = True
thread.start()
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
def stop(self, force: bool = False) -> bool:
"""
停止所有管道
参数:
force: 是否强制停止
返回:
bool: 是否成功停止
"""
if self._is_stopping:
logger.warning("运行器已经在停止中")
return False
self._is_stopping = True
logger.info("正在停止运行器...")
try:
# 发送结束信号
self._input_queue.put(None)
# 停止所有管道
success = True
for pipeline in self._pipeline_list:
if force:
pipeline.force_stop()
else:
if not pipeline.stop(timeout=self._stop_timeout):
logger.warning("管道 %s 停止超时", id(pipeline))
success = False
# 等待队列处理完成
try:
start_time = time.time()
while not self._input_queue.empty():
if time.time() - start_time > self._stop_timeout:
logger.warning(
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
self._stop_timeout,
self._input_queue.qsize(),
)
success = False
break
time.sleep(0.1) # 避免过度消耗CPU
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
error_traceback = traceback.format_exc()
logger.error(
"等待队列处理完成时发生错误:\n"
"错误类型: %s\n"
"错误信息: %s\n"
"错误堆栈:\n%s",
error_type,
error_msg,
error_traceback,
)
success = False
if success:
logger.info("所有管道已成功停止")
else:
logger.warning(
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(),
self._input_queue.empty(),
)
return success
finally:
self._is_stopping = False
def __del__(self) -> None:
"""
析构函数
"""
self.stop(force=True)
class STTRunnerFactory:
"""
STT Runner工厂类
用于创建运行器实例
"""
@staticmethod
def _create_runner(
audio_binary_name: str,
model_name_list: List[str],
pipeline_name_list: List[str],
) -> STTRunner:
"""
创建运行器
参数:
audio_binary_name: 音频二进制名称
model_name_list: 模型名称列表
pipeline_name_list: 管道名称列表
返回:
Runner实例
"""
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
models: Dict[str, Any] = {
model_name: models_loaded.models[model_name]
for model_name in model_name_list
}
pipelines: List[Pipeline] = [
PipelineFactory.create_pipeline(pipeline_name)
for pipeline_name in pipeline_name_list
]
return STTRunner(
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
)
@classmethod
def create_runner_from_config(
cls,
config: Dict[str, Any],
) -> STTRunner:
"""
从配置创建运行器
参数:
config: 配置字典
返回:
Runner实例
"""
audio_binary_name = config["audio_binary_name"]
model_name_list = config["model_name_list"]
pipeline_name_list = config["pipeline_name_list"]
return cls._create_runner(
audio_binary_name, model_name_list, pipeline_name_list
)
@classmethod
def create_runner_normal(cls) -> STTRunner:
"""
创建默认运行器
返回:
Runner实例
"""
audio_binary_name = None
model_name_list = list(models_loaded.models.keys())
pipeline_name_list = None
return cls._create_runner(
audio_binary_name, model_name_list, pipeline_name_list
)

227
src/runner/ASRRunner.py Normal file
View File

@ -0,0 +1,227 @@
"""
-*- encoding: utf-8 -*-
ASRRunner
继承RunnerBase
专属pipeline为ASRPipeline
"""
from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.core.model_loader import ModelLoader
from queue import Queue
import soundfile
import time
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
OVAERWATCH = False
model_loader = ModelLoader()
def test_asr_pipeline():
# 加载模型
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
"audio_update": False,
}
models = model_loader.load_models(args)
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
audio_config = AudioBinary_Config(
chunk_size=200,
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
# 创建参数Dict
config = {
"audio_config": audio_config,
}
# 创建音频数据列表
audio_binary_data_list = AudioBinary_data_list()
input_queue = Queue()
# 创建Pipeline
# asr_pipeline = ASRPipeline()
# asr_pipeline.set_models(models)
# asr_pipeline.set_config(config)
# asr_pipeline.set_audio_binary(audio_binary_data_list)
# asr_pipeline.set_input_queue(input_queue)
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
# asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name = "ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)
# 运行Pipeline
asr_instance = asr_pipeline.run()
audio_clip_len = 200
print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i : i + audio_clip_len])
# time.sleep(10)
# input_queue.put(None)
# 等待Pipeline结束
# asr_instance.join()
time.sleep(5)
asr_pipeline.stop()
# asr_pipeline.stop()
class STTRunner(RunnerBase):
"""
运行器类
负责管理资源和协调Pipeline的运行
"""
def __init__(
self,
*args,
**kwargs,
):
"""
"""
# ws资源
self._ws_pool: Dict[str,List[WebSocketClient]] = {}
# 接收资源
self._audio_binary_list = audio_binary_list
self._models = models
self._pipeline_list = pipeline_list
# 线程控制
self._lock = Lock()
# 停止控制
self._stop_timeout = 10.0
self._is_stopping = False
# 配置资源
for pipeline in self._pipeline_list:
# 设置输入队列
pipeline.set_input_queue(self._input_queue)
# 配置资源
pipeline.set_audio_binary(
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
)
pipeline.set_models(self._models)
def run(self) -> None:
"""
启动所有管道
"""
logger.info("[%s] 启动所有管道", self.__class__.__name__)
if not self._pipeline_list:
raise RuntimeError("没有可用的管道")
# 启动所有管道
for pipeline in self._pipeline_list:
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
thread.daemon = True
thread.start()
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
def stop(self, force: bool = False) -> bool:
"""
停止所有管道
参数:
force: 是否强制停止
返回:
bool: 是否成功停止
"""
if self._is_stopping:
logger.warning("运行器已经在停止中")
return False
self._is_stopping = True
logger.info("正在停止运行器...")
try:
# 发送结束信号
self._input_queue.put(None)
# 停止所有管道
success = True
for pipeline in self._pipeline_list:
if force:
pipeline.force_stop()
else:
if not pipeline.stop(timeout=self._stop_timeout):
logger.warning("管道 %s 停止超时", id(pipeline))
success = False
# 等待队列处理完成
try:
start_time = time.time()
while not self._input_queue.empty():
if time.time() - start_time > self._stop_timeout:
logger.warning(
"等待队列处理完成超时(%s秒), 队列中还有 %d 个任务未处理",
self._stop_timeout,
self._input_queue.qsize(),
)
success = False
break
time.sleep(0.1) # 避免过度消耗CPU
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
error_traceback = traceback.format_exc()
logger.error(
"等待队列处理完成时发生错误:\n"
"错误类型: %s\n"
"错误信息: %s\n"
"错误堆栈:\n%s",
error_type,
error_msg,
error_traceback,
)
success = False
if success:
logger.info("所有管道已成功停止")
else:
logger.warning(
"部分管道停止失败, 队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(),
self._input_queue.empty(),
)
return success
finally:
self._is_stopping = False
def __del__(self) -> None:
"""
析构函数
"""
self.stop(force=True)

105
src/runner/runner.py Normal file
View File

@ -0,0 +1,105 @@
"""
-*- encoding: utf-8 -*-
Runner类
所有的Runner都对应一个fastapi的endpoint,
Runner需要处理:
1.新的websocket 进来后放到 unknow_websocket_pool中
2.收到特定消息后, 将消息转发给特定的pipeline处理
3.管理pipeline与websocket对应关系, 管理pipeline的ID
4.管理pipeline的启动和停止
5.管理所有pipeline用到的资源, 管理pipeline的存活时间
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from threading import Thread, Lock
from queue import Queue
import traceback
import time
from src.audio_chunk import AudioChunk, AudioBinary
from src.pipeline import Pipeline, PipelineFactory
from src.core.model_loader import ModelLoader
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
audio_chunk = AudioChunk()
models_loaded = ModelLoader()
class RunnerBase(ABC):
"""
运行器基类
定义了运行器的基本接口
"""
def __init__(self, *args, **kwargs):
pass
class STTRunnerFactory:
"""
STT Runner工厂类
用于创建运行器实例
"""
@staticmethod
def _create_runner(
audio_binary_name: str,
model_name_list: List[str],
pipeline_name_list: List[str],
) -> STTRunner:
"""
创建运行器
参数:
audio_binary_name: 音频二进制名称
model_name_list: 模型名称列表
pipeline_name_list: 管道名称列表
返回:
Runner实例
"""
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
models: Dict[str, Any] = {
model_name: models_loaded.models[model_name]
for model_name in model_name_list
}
pipelines: List[Pipeline] = [
PipelineFactory.create_pipeline(pipeline_name)
for pipeline_name in pipeline_name_list
]
return STTRunner(
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
)
@classmethod
def create_runner_from_config(
cls,
config: Dict[str, Any],
) -> STTRunner:
"""
从配置创建运行器
参数:
config: 配置字典
返回:
Runner实例
"""
audio_binary_name = config["audio_binary_name"]
model_name_list = config["model_name_list"]
pipeline_name_list = config["pipeline_name_list"]
return cls._create_runner(
audio_binary_name, model_name_list, pipeline_name_list
)
@classmethod
def create_runner_normal(cls) -> STTRunner:
"""
创建默认运行器
返回:
Runner实例
"""
audio_binary_name = None
model_name_list = list(models_loaded.models.keys())
pipeline_name_list = None
return cls._create_runner(
audio_binary_name, model_name_list, pipeline_name_list
)

View File

@ -7,7 +7,7 @@ from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor
from queue import Queue, Empty
from src.model_loader import ModelLoader
from src.core.model_loader import ModelLoader
from src.models import AudioBinary_Config, AudioBinary_data_list
from src.utils.data_format import wav_to_bytes
import time

View File

@ -61,7 +61,7 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
# from src.functor.model_loader import load_models
# models = load_models(args)
from src.model_loader import ModelLoader
from src.core.model_loader import ModelLoader
models = ModelLoader(args)

View File

@ -6,7 +6,7 @@ VAD+ASR+SPK(FAKE)
from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader
from src.core.model_loader import ModelLoader
from queue import Queue
import soundfile
import time

View File