[代码重构中] 将ModelLoader移至core目录,更新相关测试文件的导入路径。创建ASRRunner,初步搭建框架。
This commit is contained in:
parent
5a820b49e4
commit
7b9a79942d
281
src/runner.py
281
src/runner.py
@ -1,281 +0,0 @@
|
|||||||
"""
|
|
||||||
运行器模块
|
|
||||||
提供运行器基类和运行器类,用于管理音频数据和模型的交互。
|
|
||||||
主要包含:
|
|
||||||
- RunnerBase: 运行器基类,定义了基本接口
|
|
||||||
- Runner: 运行器类,工厂模式实现
|
|
||||||
- RunnerFactory: 运行器工厂类,用于创建运行器
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, Any, List
|
|
||||||
from threading import Thread, Lock
|
|
||||||
from queue import Queue
|
|
||||||
import traceback
|
|
||||||
import time
|
|
||||||
|
|
||||||
from src.audio_chunk import AudioChunk, AudioBinary
|
|
||||||
from src.pipeline import Pipeline, PipelineFactory
|
|
||||||
from src.model_loader import ModelLoader
|
|
||||||
from src.utils.logger import get_module_logger
|
|
||||||
|
|
||||||
logger = get_module_logger(__name__, level="INFO")
|
|
||||||
|
|
||||||
audio_chunk = AudioChunk()
|
|
||||||
models_loaded = ModelLoader()
|
|
||||||
|
|
||||||
|
|
||||||
class RunnerBase(ABC):
|
|
||||||
"""
|
|
||||||
运行器基类
|
|
||||||
定义了运行器的基本接口
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def adder(self, data: Any) -> None:
|
|
||||||
"""
|
|
||||||
添加数据
|
|
||||||
参数:
|
|
||||||
data: 要添加的数据
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_recevier(self, receiver: callable) -> None:
|
|
||||||
"""
|
|
||||||
添加数据接收者
|
|
||||||
参数:
|
|
||||||
receiver: 接收数据的回调函数
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class STTRunner(RunnerBase):
|
|
||||||
"""
|
|
||||||
运行器类
|
|
||||||
负责管理资源和协调Pipeline的运行
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
audio_binary_list: List[AudioBinary],
|
|
||||||
models: Dict[str, Any],
|
|
||||||
pipeline_list: List[Pipeline],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
初始化运行器
|
|
||||||
参数:
|
|
||||||
audio_binary_list: 音频二进制列表
|
|
||||||
models: 模型字典
|
|
||||||
pipeline_list: 管道列表
|
|
||||||
queue_size: 队列大小
|
|
||||||
stop_timeout: 停止超时时间(秒)
|
|
||||||
"""
|
|
||||||
# 接收资源
|
|
||||||
self._audio_binary_list = audio_binary_list
|
|
||||||
self._models = models
|
|
||||||
self._pipeline_list = pipeline_list
|
|
||||||
|
|
||||||
# 线程控制
|
|
||||||
self._lock = Lock()
|
|
||||||
|
|
||||||
# 消息队列
|
|
||||||
self._input_queue = Queue(maxsize=1000)
|
|
||||||
|
|
||||||
# 停止控制
|
|
||||||
self._stop_timeout = 10.0
|
|
||||||
self._is_stopping = False
|
|
||||||
|
|
||||||
# 配置资源
|
|
||||||
for pipeline in self._pipeline_list:
|
|
||||||
# 设置输入队列
|
|
||||||
pipeline.set_input_queue(self._input_queue)
|
|
||||||
|
|
||||||
# 配置资源
|
|
||||||
pipeline.set_audio_binary(
|
|
||||||
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
|
|
||||||
)
|
|
||||||
pipeline.set_models(self._models)
|
|
||||||
|
|
||||||
def adder(self, data: Any) -> None:
|
|
||||||
"""
|
|
||||||
添加数据到输入队列
|
|
||||||
参数:
|
|
||||||
data: 要添加的数据
|
|
||||||
"""
|
|
||||||
if not self._pipeline_list:
|
|
||||||
raise RuntimeError("没有可用的管道")
|
|
||||||
if self._is_stopping:
|
|
||||||
raise RuntimeError("运行器正在停止,无法添加数据")
|
|
||||||
self._input_queue.put(data)
|
|
||||||
|
|
||||||
def add_recevier(self, receiver: callable) -> None:
|
|
||||||
"""
|
|
||||||
添加数据接收者
|
|
||||||
参数:
|
|
||||||
receiver: 接收数据的回调函数
|
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
for pipeline in self._pipeline_list:
|
|
||||||
pipeline.add_callback(receiver)
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
"""
|
|
||||||
启动所有管道
|
|
||||||
"""
|
|
||||||
logger.info("[%s] 启动所有管道", self.__class__.__name__)
|
|
||||||
if not self._pipeline_list:
|
|
||||||
raise RuntimeError("没有可用的管道")
|
|
||||||
|
|
||||||
# 启动所有管道
|
|
||||||
for pipeline in self._pipeline_list:
|
|
||||||
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
|
|
||||||
thread.daemon = True
|
|
||||||
thread.start()
|
|
||||||
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
|
|
||||||
|
|
||||||
def stop(self, force: bool = False) -> bool:
|
|
||||||
"""
|
|
||||||
停止所有管道
|
|
||||||
参数:
|
|
||||||
force: 是否强制停止
|
|
||||||
返回:
|
|
||||||
bool: 是否成功停止
|
|
||||||
"""
|
|
||||||
if self._is_stopping:
|
|
||||||
logger.warning("运行器已经在停止中")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._is_stopping = True
|
|
||||||
logger.info("正在停止运行器...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 发送结束信号
|
|
||||||
self._input_queue.put(None)
|
|
||||||
|
|
||||||
# 停止所有管道
|
|
||||||
success = True
|
|
||||||
for pipeline in self._pipeline_list:
|
|
||||||
if force:
|
|
||||||
pipeline.force_stop()
|
|
||||||
else:
|
|
||||||
if not pipeline.stop(timeout=self._stop_timeout):
|
|
||||||
logger.warning("管道 %s 停止超时", id(pipeline))
|
|
||||||
success = False
|
|
||||||
|
|
||||||
# 等待队列处理完成
|
|
||||||
try:
|
|
||||||
start_time = time.time()
|
|
||||||
while not self._input_queue.empty():
|
|
||||||
if time.time() - start_time > self._stop_timeout:
|
|
||||||
logger.warning(
|
|
||||||
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
|
|
||||||
self._stop_timeout,
|
|
||||||
self._input_queue.qsize(),
|
|
||||||
)
|
|
||||||
success = False
|
|
||||||
break
|
|
||||||
time.sleep(0.1) # 避免过度消耗CPU
|
|
||||||
except Exception as e:
|
|
||||||
error_type = type(e).__name__
|
|
||||||
error_msg = str(e)
|
|
||||||
error_traceback = traceback.format_exc()
|
|
||||||
logger.error(
|
|
||||||
"等待队列处理完成时发生错误:\n"
|
|
||||||
"错误类型: %s\n"
|
|
||||||
"错误信息: %s\n"
|
|
||||||
"错误堆栈:\n%s",
|
|
||||||
error_type,
|
|
||||||
error_msg,
|
|
||||||
error_traceback,
|
|
||||||
)
|
|
||||||
success = False
|
|
||||||
|
|
||||||
if success:
|
|
||||||
logger.info("所有管道已成功停止")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
|
|
||||||
self._input_queue.qsize(),
|
|
||||||
self._input_queue.empty(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return success
|
|
||||||
|
|
||||||
finally:
|
|
||||||
self._is_stopping = False
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
"""
|
|
||||||
析构函数
|
|
||||||
"""
|
|
||||||
self.stop(force=True)
|
|
||||||
|
|
||||||
|
|
||||||
class STTRunnerFactory:
|
|
||||||
"""
|
|
||||||
STT Runner工厂类
|
|
||||||
用于创建运行器实例
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _create_runner(
|
|
||||||
audio_binary_name: str,
|
|
||||||
model_name_list: List[str],
|
|
||||||
pipeline_name_list: List[str],
|
|
||||||
) -> STTRunner:
|
|
||||||
"""
|
|
||||||
创建运行器
|
|
||||||
参数:
|
|
||||||
audio_binary_name: 音频二进制名称
|
|
||||||
model_name_list: 模型名称列表
|
|
||||||
pipeline_name_list: 管道名称列表
|
|
||||||
返回:
|
|
||||||
Runner实例
|
|
||||||
"""
|
|
||||||
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
|
|
||||||
models: Dict[str, Any] = {
|
|
||||||
model_name: models_loaded.models[model_name]
|
|
||||||
for model_name in model_name_list
|
|
||||||
}
|
|
||||||
pipelines: List[Pipeline] = [
|
|
||||||
PipelineFactory.create_pipeline(pipeline_name)
|
|
||||||
for pipeline_name in pipeline_name_list
|
|
||||||
]
|
|
||||||
return STTRunner(
|
|
||||||
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_runner_from_config(
|
|
||||||
cls,
|
|
||||||
config: Dict[str, Any],
|
|
||||||
) -> STTRunner:
|
|
||||||
"""
|
|
||||||
从配置创建运行器
|
|
||||||
参数:
|
|
||||||
config: 配置字典
|
|
||||||
返回:
|
|
||||||
Runner实例
|
|
||||||
"""
|
|
||||||
audio_binary_name = config["audio_binary_name"]
|
|
||||||
model_name_list = config["model_name_list"]
|
|
||||||
pipeline_name_list = config["pipeline_name_list"]
|
|
||||||
return cls._create_runner(
|
|
||||||
audio_binary_name, model_name_list, pipeline_name_list
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_runner_normal(cls) -> STTRunner:
|
|
||||||
"""
|
|
||||||
创建默认运行器
|
|
||||||
返回:
|
|
||||||
Runner实例
|
|
||||||
"""
|
|
||||||
audio_binary_name = None
|
|
||||||
model_name_list = list(models_loaded.models.keys())
|
|
||||||
pipeline_name_list = None
|
|
||||||
return cls._create_runner(
|
|
||||||
audio_binary_name, model_name_list, pipeline_name_list
|
|
||||||
)
|
|
227
src/runner/ASRRunner.py
Normal file
227
src/runner/ASRRunner.py
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
"""
|
||||||
|
-*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
ASRRunner
|
||||||
|
继承RunnerBase
|
||||||
|
专属pipeline为ASRPipeline
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
from src.pipeline import PipelineFactory
|
||||||
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
from queue import Queue
|
||||||
|
import soundfile
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
OVAERWATCH = False
|
||||||
|
|
||||||
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
|
def test_asr_pipeline():
|
||||||
|
# 加载模型
|
||||||
|
args = {
|
||||||
|
"asr_model": "paraformer-zh",
|
||||||
|
"asr_model_revision": "v2.0.4",
|
||||||
|
"vad_model": "fsmn-vad",
|
||||||
|
"vad_model_revision": "v2.0.4",
|
||||||
|
"spk_model": "cam++",
|
||||||
|
"spk_model_revision": "v2.0.2",
|
||||||
|
"audio_update": False,
|
||||||
|
}
|
||||||
|
models = model_loader.load_models(args)
|
||||||
|
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
||||||
|
audio_config = AudioBinary_Config(
|
||||||
|
chunk_size=200,
|
||||||
|
chunk_stride=1600,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_width=16,
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||||
|
audio_config.chunk_stride = chunk_stride
|
||||||
|
|
||||||
|
# 创建参数Dict
|
||||||
|
config = {
|
||||||
|
"audio_config": audio_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建音频数据列表
|
||||||
|
audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
|
input_queue = Queue()
|
||||||
|
|
||||||
|
# 创建Pipeline
|
||||||
|
# asr_pipeline = ASRPipeline()
|
||||||
|
# asr_pipeline.set_models(models)
|
||||||
|
# asr_pipeline.set_config(config)
|
||||||
|
# asr_pipeline.set_audio_binary(audio_binary_data_list)
|
||||||
|
# asr_pipeline.set_input_queue(input_queue)
|
||||||
|
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
|
||||||
|
# asr_pipeline.bake()
|
||||||
|
asr_pipeline = PipelineFactory.create_pipeline(
|
||||||
|
pipeline_name = "ASRpipeline",
|
||||||
|
models=models,
|
||||||
|
config=config,
|
||||||
|
audio_binary=audio_binary_data_list,
|
||||||
|
input_queue=input_queue,
|
||||||
|
callback=lambda x: print(f"pipeline callback: {x}")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 运行Pipeline
|
||||||
|
asr_instance = asr_pipeline.run()
|
||||||
|
|
||||||
|
|
||||||
|
audio_clip_len = 200
|
||||||
|
print(
|
||||||
|
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
|
||||||
|
)
|
||||||
|
for i in range(0, len(audio_data), audio_clip_len):
|
||||||
|
input_queue.put(audio_data[i : i + audio_clip_len])
|
||||||
|
|
||||||
|
# time.sleep(10)
|
||||||
|
# input_queue.put(None)
|
||||||
|
|
||||||
|
# 等待Pipeline结束
|
||||||
|
# asr_instance.join()
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
asr_pipeline.stop()
|
||||||
|
# asr_pipeline.stop()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class STTRunner(RunnerBase):
|
||||||
|
"""
|
||||||
|
运行器类
|
||||||
|
负责管理资源和协调Pipeline的运行
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# ws资源
|
||||||
|
self._ws_pool: Dict[str,List[WebSocketClient]] = {}
|
||||||
|
# 接收资源
|
||||||
|
self._audio_binary_list = audio_binary_list
|
||||||
|
self._models = models
|
||||||
|
self._pipeline_list = pipeline_list
|
||||||
|
|
||||||
|
# 线程控制
|
||||||
|
self._lock = Lock()
|
||||||
|
# 停止控制
|
||||||
|
self._stop_timeout = 10.0
|
||||||
|
self._is_stopping = False
|
||||||
|
|
||||||
|
# 配置资源
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
# 设置输入队列
|
||||||
|
pipeline.set_input_queue(self._input_queue)
|
||||||
|
|
||||||
|
# 配置资源
|
||||||
|
pipeline.set_audio_binary(
|
||||||
|
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
|
||||||
|
)
|
||||||
|
pipeline.set_models(self._models)
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""
|
||||||
|
启动所有管道
|
||||||
|
"""
|
||||||
|
logger.info("[%s] 启动所有管道", self.__class__.__name__)
|
||||||
|
if not self._pipeline_list:
|
||||||
|
raise RuntimeError("没有可用的管道")
|
||||||
|
|
||||||
|
# 启动所有管道
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
|
||||||
|
|
||||||
|
def stop(self, force: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
停止所有管道
|
||||||
|
参数:
|
||||||
|
force: 是否强制停止
|
||||||
|
返回:
|
||||||
|
bool: 是否成功停止
|
||||||
|
"""
|
||||||
|
if self._is_stopping:
|
||||||
|
logger.warning("运行器已经在停止中")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._is_stopping = True
|
||||||
|
logger.info("正在停止运行器...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 发送结束信号
|
||||||
|
self._input_queue.put(None)
|
||||||
|
|
||||||
|
# 停止所有管道
|
||||||
|
success = True
|
||||||
|
for pipeline in self._pipeline_list:
|
||||||
|
if force:
|
||||||
|
pipeline.force_stop()
|
||||||
|
else:
|
||||||
|
if not pipeline.stop(timeout=self._stop_timeout):
|
||||||
|
logger.warning("管道 %s 停止超时", id(pipeline))
|
||||||
|
success = False
|
||||||
|
|
||||||
|
# 等待队列处理完成
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
while not self._input_queue.empty():
|
||||||
|
if time.time() - start_time > self._stop_timeout:
|
||||||
|
logger.warning(
|
||||||
|
"等待队列处理完成超时(%s秒), 队列中还有 %d 个任务未处理",
|
||||||
|
self._stop_timeout,
|
||||||
|
self._input_queue.qsize(),
|
||||||
|
)
|
||||||
|
success = False
|
||||||
|
break
|
||||||
|
time.sleep(0.1) # 避免过度消耗CPU
|
||||||
|
except Exception as e:
|
||||||
|
error_type = type(e).__name__
|
||||||
|
error_msg = str(e)
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
logger.error(
|
||||||
|
"等待队列处理完成时发生错误:\n"
|
||||||
|
"错误类型: %s\n"
|
||||||
|
"错误信息: %s\n"
|
||||||
|
"错误堆栈:\n%s",
|
||||||
|
error_type,
|
||||||
|
error_msg,
|
||||||
|
error_traceback,
|
||||||
|
)
|
||||||
|
success = False
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info("所有管道已成功停止")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"部分管道停止失败, 队列状态: 大小=%d, 是否为空=%s",
|
||||||
|
self._input_queue.qsize(),
|
||||||
|
self._input_queue.empty(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._is_stopping = False
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
"""
|
||||||
|
析构函数
|
||||||
|
"""
|
||||||
|
self.stop(force=True)
|
105
src/runner/runner.py
Normal file
105
src/runner/runner.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
"""
|
||||||
|
-*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
Runner类
|
||||||
|
所有的Runner都对应一个fastapi的endpoint,
|
||||||
|
Runner需要处理:
|
||||||
|
1.新的websocket 进来后放到 unknow_websocket_pool中
|
||||||
|
2.收到特定消息后, 将消息转发给特定的pipeline处理
|
||||||
|
3.管理pipeline与websocket对应关系, 管理pipeline的ID
|
||||||
|
4.管理pipeline的启动和停止
|
||||||
|
5.管理所有pipeline用到的资源, 管理pipeline的存活时间。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from threading import Thread, Lock
|
||||||
|
from queue import Queue
|
||||||
|
import traceback
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.audio_chunk import AudioChunk, AudioBinary
|
||||||
|
from src.pipeline import Pipeline, PipelineFactory
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
audio_chunk = AudioChunk()
|
||||||
|
models_loaded = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerBase(ABC):
|
||||||
|
"""
|
||||||
|
运行器基类
|
||||||
|
定义了运行器的基本接口
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class STTRunnerFactory:
|
||||||
|
"""
|
||||||
|
STT Runner工厂类
|
||||||
|
用于创建运行器实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_runner(
|
||||||
|
audio_binary_name: str,
|
||||||
|
model_name_list: List[str],
|
||||||
|
pipeline_name_list: List[str],
|
||||||
|
) -> STTRunner:
|
||||||
|
"""
|
||||||
|
创建运行器
|
||||||
|
参数:
|
||||||
|
audio_binary_name: 音频二进制名称
|
||||||
|
model_name_list: 模型名称列表
|
||||||
|
pipeline_name_list: 管道名称列表
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary = audio_chunk.get_audio_binary(audio_binary_name)
|
||||||
|
models: Dict[str, Any] = {
|
||||||
|
model_name: models_loaded.models[model_name]
|
||||||
|
for model_name in model_name_list
|
||||||
|
}
|
||||||
|
pipelines: List[Pipeline] = [
|
||||||
|
PipelineFactory.create_pipeline(pipeline_name)
|
||||||
|
for pipeline_name in pipeline_name_list
|
||||||
|
]
|
||||||
|
return STTRunner(
|
||||||
|
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_runner_from_config(
|
||||||
|
cls,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> STTRunner:
|
||||||
|
"""
|
||||||
|
从配置创建运行器
|
||||||
|
参数:
|
||||||
|
config: 配置字典
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary_name = config["audio_binary_name"]
|
||||||
|
model_name_list = config["model_name_list"]
|
||||||
|
pipeline_name_list = config["pipeline_name_list"]
|
||||||
|
return cls._create_runner(
|
||||||
|
audio_binary_name, model_name_list, pipeline_name_list
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_runner_normal(cls) -> STTRunner:
|
||||||
|
"""
|
||||||
|
创建默认运行器
|
||||||
|
返回:
|
||||||
|
Runner实例
|
||||||
|
"""
|
||||||
|
audio_binary_name = None
|
||||||
|
model_name_list = list(models_loaded.models.keys())
|
||||||
|
pipeline_name_list = None
|
||||||
|
return cls._create_runner(
|
||||||
|
audio_binary_name, model_name_list, pipeline_name_list
|
||||||
|
)
|
@ -7,7 +7,7 @@ from src.functor.vad_functor import VADFunctor
|
|||||||
from src.functor.asr_functor import ASRFunctor
|
from src.functor.asr_functor import ASRFunctor
|
||||||
from src.functor.spk_functor import SPKFunctor
|
from src.functor.spk_functor import SPKFunctor
|
||||||
from queue import Queue, Empty
|
from queue import Queue, Empty
|
||||||
from src.model_loader import ModelLoader
|
from src.core.model_loader import ModelLoader
|
||||||
from src.models import AudioBinary_Config, AudioBinary_data_list
|
from src.models import AudioBinary_Config, AudioBinary_data_list
|
||||||
from src.utils.data_format import wav_to_bytes
|
from src.utils.data_format import wav_to_bytes
|
||||||
import time
|
import time
|
||||||
|
@ -61,7 +61,7 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
# from src.functor.model_loader import load_models
|
# from src.functor.model_loader import load_models
|
||||||
# models = load_models(args)
|
# models = load_models(args)
|
||||||
from src.model_loader import ModelLoader
|
from src.core.model_loader import ModelLoader
|
||||||
|
|
||||||
models = ModelLoader(args)
|
models = ModelLoader(args)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ VAD+ASR+SPK(FAKE)
|
|||||||
from src.pipeline.ASRpipeline import ASRPipeline
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
from src.pipeline import PipelineFactory
|
from src.pipeline import PipelineFactory
|
||||||
from src.models import AudioBinary_data_list, AudioBinary_Config
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||||
from src.model_loader import ModelLoader
|
from src.core.model_loader import ModelLoader
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
import soundfile
|
import soundfile
|
||||||
import time
|
import time
|
||||||
|
0
tests/runner/stt_runner.py
Normal file
0
tests/runner/stt_runner.py
Normal file
Loading…
x
Reference in New Issue
Block a user