[代码重构中] 将ModelLoader移至core目录,更新相关测试文件的导入路径。创建ASRRunner,初步搭建框架。
This commit is contained in:
parent
5a820b49e4
commit
7b9a79942d
281
src/runner.py
281
src/runner.py
@ -1,281 +0,0 @@
|
||||
"""
|
||||
运行器模块
|
||||
提供运行器基类和运行器类,用于管理音频数据和模型的交互。
|
||||
主要包含:
|
||||
- RunnerBase: 运行器基类,定义了基本接口
|
||||
- STTRunner: 运行器类,负责管理资源和协调Pipeline的运行
|
||||
- STTRunnerFactory: 运行器工厂类,用于创建运行器
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
from threading import Thread, Lock
|
||||
from queue import Queue
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from src.audio_chunk import AudioChunk, AudioBinary
|
||||
from src.pipeline import Pipeline, PipelineFactory
|
||||
from src.model_loader import ModelLoader
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
# Module-level logger for this file, requested at level "INFO".
logger = get_module_logger(__name__, level="INFO")
|
||||
|
||||
# Instantiated once at import time; STTRunnerFactory._create_runner calls
# audio_chunk.get_audio_binary(...) on it.
audio_chunk = AudioChunk()
|
||||
# Instantiated once at import time; STTRunnerFactory reads its `.models`
# mapping to pick the models handed to a runner.
models_loaded = ModelLoader()
|
||||
|
||||
|
||||
class RunnerBase(ABC):
    """Abstract base class that defines the minimal runner interface.

    Concrete runners must implement :meth:`adder` (feed data in) and
    :meth:`add_recevier` (register a result callback).
    """

    @abstractmethod
    def adder(self, data: Any) -> None:
        """Feed one item of input data to the runner.

        Args:
            data: The data to add.
        """

    @abstractmethod
    def add_recevier(self, receiver: callable) -> None:
        """Register a callback that receives the runner's output.

        The method name is misspelled; it is kept unchanged because
        subclasses and callers already depend on it. New code should call
        :meth:`add_receiver`.

        Args:
            receiver: Callback function that receives the data.
        """

    def add_receiver(self, receiver: callable) -> None:
        """Correctly spelled alias of :meth:`add_recevier`.

        Args:
            receiver: Callback function that receives the data.
        """
        self.add_recevier(receiver)
|
||||
|
||||
|
||||
class STTRunner(RunnerBase):
    """Runner that owns the shared resources and drives a set of pipelines.

    Every pipeline handed to the runner is wired to one shared input queue
    and given its audio binary and the model dictionary. :meth:`run` starts
    each pipeline in its own daemon thread; :meth:`stop` shuts them down.
    """

    def __init__(
        self,
        *,
        audio_binary_list: List[AudioBinary],
        models: Dict[str, Any],
        pipeline_list: List[Pipeline],
        queue_size: int = 1000,
        stop_timeout: float = 10.0,
    ):
        """Store the resources and configure every pipeline.

        Args:
            audio_binary_list: Audio binaries available to the pipelines.
            models: Model dictionary passed to every pipeline.
            pipeline_list: Pipelines managed by this runner.
            queue_size: Capacity of the shared input queue.
            stop_timeout: Seconds :meth:`stop` waits for each pipeline and
                for the input queue to drain.
        """
        # Control state first, so stop()/__del__ never see a half-built object.
        self._lock = Lock()
        self._stop_timeout = stop_timeout
        self._is_stopping = False
        self._input_queue = Queue(maxsize=queue_size)

        # Resources received from the caller.
        self._audio_binary_list = audio_binary_list
        self._models = models
        self._pipeline_list = pipeline_list

        for pipeline in self._pipeline_list:
            pipeline.set_input_queue(self._input_queue)
            # NOTE(review): a list is indexed with the pipeline's
            # "audio_binary_name" config value - confirm this value is an
            # integer index; a string name would raise TypeError here.
            pipeline.set_audio_binary(
                self._audio_binary_list[pipeline.get_config("audio_binary_name")]
            )
            pipeline.set_models(self._models)

    def adder(self, data: Any) -> None:
        """Put one item on the shared input queue.

        Args:
            data: The data to add.

        Raises:
            RuntimeError: If the runner has no pipelines or is stopping.
        """
        if not self._pipeline_list:
            raise RuntimeError("没有可用的管道")
        if self._is_stopping:
            raise RuntimeError("运行器正在停止,无法添加数据")
        self._input_queue.put(data)

    def add_recevier(self, receiver: callable) -> None:
        """Register ``receiver`` as a callback on every pipeline.

        Args:
            receiver: Callback function that receives the data.
        """
        with self._lock:
            for pipeline in self._pipeline_list:
                pipeline.add_callback(receiver)

    def run(self) -> None:
        """Start every pipeline in its own daemon thread.

        Raises:
            RuntimeError: If the runner has no pipelines.
        """
        logger.info("[%s] 启动所有管道", self.__class__.__name__)
        if not self._pipeline_list:
            raise RuntimeError("没有可用的管道")

        for pipeline in self._pipeline_list:
            thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
            thread.daemon = True
            thread.start()
            logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))

    def stop(self, force: bool = False) -> bool:
        """Stop all pipelines.

        Args:
            force: When true, call ``force_stop()`` on each pipeline instead
                of ``stop(timeout=...)`` and never block on a full queue.

        Returns:
            True if everything stopped and the queue drained in time, False
            otherwise (including when a stop is already in progress).
        """
        if self._is_stopping:
            logger.warning("运行器已经在停止中")
            return False

        self._is_stopping = True
        logger.info("正在停止运行器...")

        try:
            success = self._send_stop_signal(block=not force)

            for pipeline in self._pipeline_list:
                if force:
                    pipeline.force_stop()
                elif not pipeline.stop(timeout=self._stop_timeout):
                    logger.warning("管道 %s 停止超时", id(pipeline))
                    success = False

            if not self._wait_for_queue_drain():
                success = False

            if success:
                logger.info("所有管道已成功停止")
            else:
                logger.warning(
                    "部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
                    self._input_queue.qsize(),
                    self._input_queue.empty(),
                )
            return success
        finally:
            self._is_stopping = False

    def _send_stop_signal(self, block: bool) -> bool:
        """Put the ``None`` end-of-stream marker on the input queue.

        The queue is bounded, so an unconditional ``put`` could block forever
        once the consumers are gone; the wait is limited to the stop timeout
        and skipped entirely when ``block`` is false.

        Returns:
            True if the marker was queued, False if the queue stayed full.
        """
        from queue import Full

        try:
            self._input_queue.put(None, block=block, timeout=self._stop_timeout)
        except Full:
            logger.warning("输入队列已满,无法发送结束信号")
            return False
        return True

    def _wait_for_queue_drain(self) -> bool:
        """Poll until the input queue is empty or the stop timeout elapses.

        Returns:
            True if the queue emptied in time, False on timeout or error.
        """
        try:
            # A monotonic clock keeps the timeout immune to wall-clock jumps.
            deadline = time.monotonic() + self._stop_timeout
            while not self._input_queue.empty():
                if time.monotonic() > deadline:
                    logger.warning(
                        "等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
                        self._stop_timeout,
                        self._input_queue.qsize(),
                    )
                    return False
                time.sleep(0.1)  # avoid a busy loop
        except Exception as e:
            logger.error(
                "等待队列处理完成时发生错误:\n"
                "错误类型: %s\n"
                "错误信息: %s\n"
                "错误堆栈:\n%s",
                type(e).__name__,
                str(e),
                traceback.format_exc(),
            )
            return False
        return True

    def __del__(self) -> None:
        """Force-stop the pipelines when the runner is garbage-collected."""
        # __init__ may have raised before the attributes were assigned.
        if not hasattr(self, "_pipeline_list"):
            return
        self.stop(force=True)
|
||||
|
||||
|
||||
class STTRunnerFactory:
    """Factory that builds :class:`STTRunner` instances.

    Resources are looked up in the module-level ``audio_chunk`` and
    ``models_loaded`` objects; pipelines come from ``PipelineFactory``.
    """

    @staticmethod
    def _create_runner(
        audio_binary_name: str,
        model_name_list: List[str],
        pipeline_name_list: List[str],
    ) -> STTRunner:
        """Create a runner from resource names.

        Args:
            audio_binary_name: Name passed to ``audio_chunk.get_audio_binary``.
            model_name_list: Names of the models to pick from
                ``models_loaded.models``; ``None`` is treated as empty.
            pipeline_name_list: Names of the pipelines to create; ``None`` is
                treated as empty.

        Returns:
            The configured runner.

        Raises:
            KeyError: If a model name is not present in ``models_loaded.models``.
        """
        audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
        # `or []` keeps a None name list from raising TypeError on iteration.
        models: Dict[str, Any] = {
            model_name: models_loaded.models[model_name]
            for model_name in (model_name_list or [])
        }
        pipelines: List[Pipeline] = [
            PipelineFactory.create_pipeline(pipeline_name)
            for pipeline_name in (pipeline_name_list or [])
        ]
        return STTRunner(
            audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
        )

    @classmethod
    def create_runner_from_config(
        cls,
        config: Dict[str, Any],
    ) -> STTRunner:
        """Create a runner from a configuration dictionary.

        Args:
            config: Must contain the keys ``"audio_binary_name"``,
                ``"model_name_list"`` and ``"pipeline_name_list"``.

        Returns:
            The configured runner.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        return cls._create_runner(
            config["audio_binary_name"],
            config["model_name_list"],
            config["pipeline_name_list"],
        )

    @classmethod
    def create_runner_normal(cls) -> STTRunner:
        """Create the default runner.

        Uses every model currently in ``models_loaded.models``, no pipelines,
        and ``None`` as the audio binary name.

        Returns:
            The configured runner.
        """
        # NOTE(review): None is forwarded to audio_chunk.get_audio_binary();
        # confirm that it accepts None as "default".
        audio_binary_name = None
        model_name_list = list(models_loaded.models.keys())
        # An empty list instead of None: _create_runner iterates this value.
        pipeline_name_list: List[str] = []
        return cls._create_runner(
            audio_binary_name, model_name_list, pipeline_name_list
        )
|
227
src/runner/ASRRunner.py
Normal file
227
src/runner/ASRRunner.py
Normal file
@ -0,0 +1,227 @@
|
||||
"""
|
||||
-*- encoding: utf-8 -*-
|
||||
|
||||
ASRRunner
|
||||
继承RunnerBase
|
||||
专属pipeline为ASRPipeline
|
||||
"""
|
||||
|
||||
import time
import traceback
from queue import Queue
from threading import Lock, Thread
from typing import Any, Dict, List

import soundfile

from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.core.model_loader import ModelLoader
from src.runner.runner import RunnerBase
from src.utils.logger import get_module_logger
|
||||
|
||||
# Module-level logger for this file.
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
# NOTE(review): not referenced anywhere in the visible code; the name looks
# like a misspelling of OVERWATCH - confirm before renaming or removing.
OVAERWATCH = False
|
||||
|
||||
# Instantiated once at import time; test_asr_pipeline() calls
# model_loader.load_models(...) on it.
model_loader = ModelLoader()
|
||||
|
||||
|
||||
def test_asr_pipeline():
    """Smoke-run the ASR pipeline on ``tests/vad_example.wav``.

    Loads the models, builds an "ASRpipeline" through ``PipelineFactory``,
    feeds the audio into the input queue in 200-sample clips, waits five
    seconds and stops the pipeline. Results are printed by the callback.
    """
    # Load the models.
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
        "audio_update": False,
    }
    models = model_loader.load_models(args)
    audio_data, sample_rate = soundfile.read("tests/vad_example.wav")

    audio_config = AudioBinary_Config(
        chunk_size=200,
        chunk_stride=1600,  # placeholder, recomputed from the sample rate below
        sample_rate=sample_rate,
        sample_width=16,
        channels=1,
    )
    # presumably chunk_size is in milliseconds, making the stride a sample
    # count - TODO confirm against AudioBinary_Config
    audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)

    config = {
        "audio_config": audio_config,
    }
    audio_binary_data_list = AudioBinary_data_list()
    input_queue = Queue()

    asr_pipeline = PipelineFactory.create_pipeline(
        pipeline_name="ASRpipeline",
        models=models,
        config=config,
        audio_binary=audio_binary_data_list,
        input_queue=input_queue,
        callback=lambda x: print(f"pipeline callback: {x}"),
    )

    # Run the pipeline.
    asr_pipeline.run()
    try:
        audio_clip_len = 200
        clip_offsets = range(0, len(audio_data), audio_clip_len)
        # len(clip_offsets) also counts the final partial clip, which plain
        # floor division would miss.
        print(
            f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(clip_offsets)}"
        )
        for offset in clip_offsets:
            input_queue.put(audio_data[offset : offset + audio_clip_len])

        # Give the pipeline time to work through the queued clips.
        time.sleep(5)
    finally:
        # Always stop the pipeline, even if feeding the queue failed.
        asr_pipeline.stop()
|
||||
|
||||
|
||||
|
||||
class STTRunner(RunnerBase):
    """Runner that owns the shared resources and drives a set of pipelines.

    Every pipeline handed to the runner is wired to one shared input queue
    and given its audio binary and the model dictionary. :meth:`run` starts
    each pipeline in its own daemon thread; :meth:`stop` shuts them down.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Build the runner from keyword arguments.

        Keyword Args:
            audio_binary_list: Audio binaries available to the pipelines
                (required).
            models: Model dictionary passed to every pipeline (required).
            pipeline_list: Pipelines managed by this runner (required).
            queue_size: Capacity of the shared input queue (default 1000).
            stop_timeout: Seconds :meth:`stop` waits for each pipeline and
                for the input queue to drain (default 10.0).

        Raises:
            TypeError: If a required keyword argument is missing.
        """
        super().__init__()

        try:
            audio_binary_list = kwargs["audio_binary_list"]
            models = kwargs["models"]
            pipeline_list = kwargs["pipeline_list"]
        except KeyError as missing:
            raise TypeError(
                f"missing required keyword argument: {missing}"
            ) from None

        # WebSocket resources.
        # TODO: narrow List[Any] to the WebSocket client type once that type
        # is importable from this module.
        self._ws_pool: Dict[str, List[Any]] = {}

        # Control state first, so stop()/__del__ never see a half-built object.
        self._lock = Lock()
        self._stop_timeout = kwargs.get("stop_timeout", 10.0)
        self._is_stopping = False
        # Shared input queue handed to every pipeline and used by stop().
        self._input_queue = Queue(maxsize=kwargs.get("queue_size", 1000))

        # Resources received from the caller.
        self._audio_binary_list = audio_binary_list
        self._models = models
        self._pipeline_list = pipeline_list

        for pipeline in self._pipeline_list:
            pipeline.set_input_queue(self._input_queue)
            # NOTE(review): a list is indexed with the pipeline's
            # "audio_binary_name" config value - confirm this value is an
            # integer index; a string name would raise TypeError here.
            pipeline.set_audio_binary(
                self._audio_binary_list[pipeline.get_config("audio_binary_name")]
            )
            pipeline.set_models(self._models)

    def run(self) -> None:
        """Start every pipeline in its own daemon thread.

        Raises:
            RuntimeError: If the runner has no pipelines.
        """
        logger.info("[%s] 启动所有管道", self.__class__.__name__)
        if not self._pipeline_list:
            raise RuntimeError("没有可用的管道")

        for pipeline in self._pipeline_list:
            thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
            thread.daemon = True
            thread.start()
            logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))

    def stop(self, force: bool = False) -> bool:
        """Stop all pipelines.

        Args:
            force: When true, call ``force_stop()`` on each pipeline instead
                of ``stop(timeout=...)`` and never block on a full queue.

        Returns:
            True if everything stopped and the queue drained in time, False
            otherwise (including when a stop is already in progress).
        """
        if self._is_stopping:
            logger.warning("运行器已经在停止中")
            return False

        self._is_stopping = True
        logger.info("正在停止运行器...")

        try:
            success = self._send_stop_signal(block=not force)

            for pipeline in self._pipeline_list:
                if force:
                    pipeline.force_stop()
                elif not pipeline.stop(timeout=self._stop_timeout):
                    logger.warning("管道 %s 停止超时", id(pipeline))
                    success = False

            if not self._wait_for_queue_drain():
                success = False

            if success:
                logger.info("所有管道已成功停止")
            else:
                logger.warning(
                    "部分管道停止失败, 队列状态: 大小=%d, 是否为空=%s",
                    self._input_queue.qsize(),
                    self._input_queue.empty(),
                )
            return success
        finally:
            self._is_stopping = False

    def _send_stop_signal(self, block: bool) -> bool:
        """Put the ``None`` end-of-stream marker on the input queue.

        The queue is bounded, so an unconditional ``put`` could block forever
        once the consumers are gone; the wait is limited to the stop timeout
        and skipped entirely when ``block`` is false.

        Returns:
            True if the marker was queued, False if the queue stayed full.
        """
        from queue import Full

        try:
            self._input_queue.put(None, block=block, timeout=self._stop_timeout)
        except Full:
            logger.warning("输入队列已满, 无法发送结束信号")
            return False
        return True

    def _wait_for_queue_drain(self) -> bool:
        """Poll until the input queue is empty or the stop timeout elapses.

        Returns:
            True if the queue emptied in time, False on timeout or error.
        """
        try:
            # A monotonic clock keeps the timeout immune to wall-clock jumps.
            deadline = time.monotonic() + self._stop_timeout
            while not self._input_queue.empty():
                if time.monotonic() > deadline:
                    logger.warning(
                        "等待队列处理完成超时(%s秒), 队列中还有 %d 个任务未处理",
                        self._stop_timeout,
                        self._input_queue.qsize(),
                    )
                    return False
                time.sleep(0.1)  # avoid a busy loop
        except Exception as e:
            logger.error(
                "等待队列处理完成时发生错误:\n"
                "错误类型: %s\n"
                "错误信息: %s\n"
                "错误堆栈:\n%s",
                type(e).__name__,
                str(e),
                traceback.format_exc(),
            )
            return False
        return True

    def __del__(self) -> None:
        """Force-stop the pipelines when the runner is garbage-collected."""
        # __init__ may have raised before the attributes were assigned.
        if not hasattr(self, "_pipeline_list"):
            return
        self.stop(force=True)
|
105
src/runner/runner.py
Normal file
105
src/runner/runner.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""
|
||||
-*- encoding: utf-8 -*-
|
||||
|
||||
Runner类
|
||||
所有的Runner都对应一个fastapi的endpoint,
|
||||
Runner需要处理:
|
||||
1.新的websocket 进来后放到 unknown_websocket_pool中
|
||||
2.收到特定消息后, 将消息转发给特定的pipeline处理
|
||||
3.管理pipeline与websocket对应关系, 管理pipeline的ID
|
||||
4.管理pipeline的启动和停止
|
||||
5.管理所有pipeline用到的资源, 管理pipeline的存活时间。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
from threading import Thread, Lock
|
||||
from queue import Queue
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from src.audio_chunk import AudioChunk, AudioBinary
|
||||
from src.pipeline import Pipeline, PipelineFactory
|
||||
from src.core.model_loader import ModelLoader
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
# Module-level logger for this file.
logger = get_module_logger(__name__)
|
||||
|
||||
# Instantiated once at import time; STTRunnerFactory._create_runner calls
# audio_chunk.get_audio_binary(...) on it.
audio_chunk = AudioChunk()
|
||||
# Instantiated once at import time; STTRunnerFactory reads its `.models`
# mapping to pick the models handed to a runner.
models_loaded = ModelLoader()
|
||||
|
||||
|
||||
class RunnerBase(ABC):
    """Common base class for all runners.

    It declares no abstract methods yet; the constructor accepts and ignores
    any arguments so subclasses are free to define their own signatures.
    """

    def __init__(self, *args, **kwargs):
        """Accept arbitrary positional and keyword arguments and do nothing."""
        return None
|
||||
|
||||
class STTRunnerFactory:
    """Factory that builds ``STTRunner`` instances.

    Resources are looked up in the module-level ``audio_chunk`` and
    ``models_loaded`` objects; pipelines come from ``PipelineFactory``.

    ``STTRunner`` is defined in ``src/runner/ASRRunner.py``, not in this
    module, so it is named in string annotations and imported lazily inside
    :meth:`_create_runner`. Evaluating the bare name at definition time
    would raise ``NameError`` when this module is imported.
    """

    @staticmethod
    def _create_runner(
        audio_binary_name: str,
        model_name_list: List[str],
        pipeline_name_list: List[str],
    ) -> "STTRunner":
        """Create a runner from resource names.

        Args:
            audio_binary_name: Name passed to ``audio_chunk.get_audio_binary``.
            model_name_list: Names of the models to pick from
                ``models_loaded.models``; ``None`` is treated as empty.
            pipeline_name_list: Names of the pipelines to create; ``None`` is
                treated as empty.

        Returns:
            The configured runner.

        Raises:
            KeyError: If a model name is not present in ``models_loaded.models``.
        """
        # Imported here rather than at module level so that a module which
        # imports RunnerBase from this file does not create an import cycle.
        from src.runner.ASRRunner import STTRunner

        audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
        # `or []` keeps a None name list from raising TypeError on iteration.
        models: Dict[str, Any] = {
            model_name: models_loaded.models[model_name]
            for model_name in (model_name_list or [])
        }
        pipelines: List[Pipeline] = [
            PipelineFactory.create_pipeline(pipeline_name)
            for pipeline_name in (pipeline_name_list or [])
        ]
        return STTRunner(
            audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
        )

    @classmethod
    def create_runner_from_config(
        cls,
        config: Dict[str, Any],
    ) -> "STTRunner":
        """Create a runner from a configuration dictionary.

        Args:
            config: Must contain the keys ``"audio_binary_name"``,
                ``"model_name_list"`` and ``"pipeline_name_list"``.

        Returns:
            The configured runner.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        return cls._create_runner(
            config["audio_binary_name"],
            config["model_name_list"],
            config["pipeline_name_list"],
        )

    @classmethod
    def create_runner_normal(cls) -> "STTRunner":
        """Create the default runner.

        Uses every model currently in ``models_loaded.models``, no pipelines,
        and ``None`` as the audio binary name.

        Returns:
            The configured runner.
        """
        # NOTE(review): None is forwarded to audio_chunk.get_audio_binary();
        # confirm that it accepts None as "default".
        audio_binary_name = None
        model_name_list = list(models_loaded.models.keys())
        # An empty list instead of None: _create_runner iterates this value.
        pipeline_name_list: List[str] = []
        return cls._create_runner(
            audio_binary_name, model_name_list, pipeline_name_list
        )
|
@ -7,7 +7,7 @@ from src.functor.vad_functor import VADFunctor
|
||||
from src.functor.asr_functor import ASRFunctor
|
||||
from src.functor.spk_functor import SPKFunctor
|
||||
from queue import Queue, Empty
|
||||
from src.model_loader import ModelLoader
|
||||
from src.core.model_loader import ModelLoader
|
||||
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||
from src.utils.data_format import wav_to_bytes
|
||||
import time
|
||||
|
@ -61,7 +61,7 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
|
||||
|
||||
# from src.functor.model_loader import load_models
|
||||
# models = load_models(args)
|
||||
from src.model_loader import ModelLoader
|
||||
from src.core.model_loader import ModelLoader
|
||||
|
||||
models = ModelLoader(args)
|
||||
|
||||
|
@ -6,7 +6,7 @@ VAD+ASR+SPK(FAKE)
|
||||
from src.pipeline.ASRpipeline import ASRPipeline
|
||||
from src.pipeline import PipelineFactory
|
||||
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||
from src.model_loader import ModelLoader
|
||||
from src.core.model_loader import ModelLoader
|
||||
from queue import Queue
|
||||
import soundfile
|
||||
import time
|
||||
|
0
tests/runner/stt_runner.py
Normal file
0
tests/runner/stt_runner.py
Normal file
Loading…
x
Reference in New Issue
Block a user