[Runner]完成ASRRunner的编写和测试,使用MockWebSocket完成虚拟网络连接。
This commit is contained in:
parent
d5b9953905
commit
5dac718dee
@ -128,7 +128,7 @@ class AudioChunk:
|
|||||||
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。
|
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance: Optional[AudioChunk] = None
|
_instance: Optional["AudioChunk"] = None
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -1,11 +1,25 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
|
默认配置DefaultConfig
|
||||||
|
- audio_config: 音频配置
|
||||||
配置模块 - 处理命令行参数和配置项
|
配置模块 - 处理命令行参数和配置项
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
|
||||||
|
class DefaultConfig:
    """
    Default configuration values shared across the project.

    Attributes are defined at class level, so they can be read without
    instantiating the class (e.g. ``DefaultConfig.audio_config``).
    """

    # Audio settings used when a caller does not supply its own
    # AudioBinary_Config.
    # NOTE(review): units of chunk_size / chunk_stride / sample_width are
    # defined by AudioBinary_Config, which is not visible here -- confirm.
    audio_config = AudioBinary_Config(
        chunk_size=200,
        chunk_stride=1600,
        sample_rate=16000,
        sample_width=16,
        channels=1,
    )
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""
|
"""
|
||||||
|
@ -136,7 +136,7 @@ class PipelineFactory:
|
|||||||
pipeline.set_audio_binary(kwargs["audio_binary"])
|
pipeline.set_audio_binary(kwargs["audio_binary"])
|
||||||
pipeline.set_input_queue(kwargs["input_queue"])
|
pipeline.set_input_queue(kwargs["input_queue"])
|
||||||
pipeline.set_callback(kwargs["callback"])
|
pipeline.set_callback(kwargs["callback"])
|
||||||
pipeline.bake()
|
# pipeline.bake()
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -10,9 +10,15 @@ from src.pipeline.ASRpipeline import ASRPipeline
|
|||||||
from src.pipeline import PipelineFactory
|
from src.pipeline import PipelineFactory
|
||||||
from src.models import AudioBinary_data_list, AudioBinary_Config
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||||
from src.core.model_loader import ModelLoader
|
from src.core.model_loader import ModelLoader
|
||||||
|
from src.config import DefaultConfig
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
import soundfile
|
import soundfile
|
||||||
import time
|
import time
|
||||||
|
from typing import List, Optional
|
||||||
|
import uuid
|
||||||
|
from threading import Thread
|
||||||
|
from src.utils.mock_websocket import MockWebSocketClient as WebSocketClient
|
||||||
|
from .runner import RunnerBase
|
||||||
|
|
||||||
from src.utils.logger import get_module_logger
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
@ -23,205 +29,175 @@ OVAERWATCH = False
|
|||||||
|
|
||||||
model_loader = ModelLoader()
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
class ASRRunner(RunnerBase):
|
||||||
def test_asr_pipeline():
|
|
||||||
# 加载模型
|
|
||||||
args = {
|
|
||||||
"asr_model": "paraformer-zh",
|
|
||||||
"asr_model_revision": "v2.0.4",
|
|
||||||
"vad_model": "fsmn-vad",
|
|
||||||
"vad_model_revision": "v2.0.4",
|
|
||||||
"spk_model": "cam++",
|
|
||||||
"spk_model_revision": "v2.0.2",
|
|
||||||
"audio_update": False,
|
|
||||||
}
|
|
||||||
models = model_loader.load_models(args)
|
|
||||||
audio_data, sample_rate = soundfile.read("tests/vad_example.wav")
|
|
||||||
audio_config = AudioBinary_Config(
|
|
||||||
chunk_size=200,
|
|
||||||
chunk_stride=1600,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
sample_width=16,
|
|
||||||
channels=1,
|
|
||||||
)
|
|
||||||
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
|
||||||
audio_config.chunk_stride = chunk_stride
|
|
||||||
|
|
||||||
# 创建参数Dict
|
|
||||||
config = {
|
|
||||||
"audio_config": audio_config,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建音频数据列表
|
|
||||||
audio_binary_data_list = AudioBinary_data_list()
|
|
||||||
|
|
||||||
input_queue = Queue()
|
|
||||||
|
|
||||||
# 创建Pipeline
|
|
||||||
# asr_pipeline = ASRPipeline()
|
|
||||||
# asr_pipeline.set_models(models)
|
|
||||||
# asr_pipeline.set_config(config)
|
|
||||||
# asr_pipeline.set_audio_binary(audio_binary_data_list)
|
|
||||||
# asr_pipeline.set_input_queue(input_queue)
|
|
||||||
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
|
|
||||||
# asr_pipeline.bake()
|
|
||||||
asr_pipeline = PipelineFactory.create_pipeline(
|
|
||||||
pipeline_name = "ASRpipeline",
|
|
||||||
models=models,
|
|
||||||
config=config,
|
|
||||||
audio_binary=audio_binary_data_list,
|
|
||||||
input_queue=input_queue,
|
|
||||||
callback=lambda x: print(f"pipeline callback: {x}")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 运行Pipeline
|
|
||||||
asr_instance = asr_pipeline.run()
|
|
||||||
|
|
||||||
|
|
||||||
audio_clip_len = 200
|
|
||||||
print(
|
|
||||||
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
|
|
||||||
)
|
|
||||||
for i in range(0, len(audio_data), audio_clip_len):
|
|
||||||
input_queue.put(audio_data[i : i + audio_clip_len])
|
|
||||||
|
|
||||||
# time.sleep(10)
|
|
||||||
# input_queue.put(None)
|
|
||||||
|
|
||||||
# 等待Pipeline结束
|
|
||||||
# asr_instance.join()
|
|
||||||
|
|
||||||
time.sleep(5)
|
|
||||||
asr_pipeline.stop()
|
|
||||||
# asr_pipeline.stop()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class STTRunner(RunnerBase):
|
|
||||||
"""
|
"""
|
||||||
运行器类
|
运行器类
|
||||||
负责管理资源和协调Pipeline的运行
|
负责管理资源和协调Pipeline的运行
|
||||||
"""
|
"""
|
||||||
|
class SenderAndReceiver:
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
|
对于单个pipeline的管理
|
||||||
|
包含 发送者 和 接收者
|
||||||
|
_sender: 发送者 唯一
|
||||||
|
_receiver: 接收者 可以有多个
|
||||||
|
_pipeline: 对应管道 唯一
|
||||||
"""
|
"""
|
||||||
# ws资源
|
def __init__(self, *args, **kwargs):
|
||||||
self._ws_pool: Dict[str,List[WebSocketClient]] = {}
|
# 可选传入参数,
|
||||||
# 接收资源
|
self._name: str = kwargs.get("name", "")
|
||||||
self._audio_binary_list = audio_binary_list
|
self._sender: Optional[WebSocketClient] = kwargs.get("sender", None)
|
||||||
self._models = models
|
self._receiver: List[WebSocketClient] = kwargs.get("receiver", [])
|
||||||
self._pipeline_list = pipeline_list
|
|
||||||
|
|
||||||
# 线程控制
|
# 资源
|
||||||
self._lock = Lock()
|
self._audio_config: AudioBinary_Config = kwargs.get("audio_config", DefaultConfig.audio_config)
|
||||||
# 停止控制
|
self._models: dict = kwargs.get("models", None)
|
||||||
self._stop_timeout = 10.0
|
self._audio_binary: AudioBinary_data_list = AudioBinary_data_list()
|
||||||
self._is_stopping = False
|
# id唯一标识
|
||||||
|
self._id: str = str(uuid.uuid4())
|
||||||
|
# 输入队列
|
||||||
|
self._input_queue: Queue = Queue()
|
||||||
|
self._pipeline: Optional[ASRPipeline] = None
|
||||||
|
|
||||||
# 配置资源
|
def set_name(self, name: str):
|
||||||
for pipeline in self._pipeline_list:
|
self._name = name
|
||||||
# 设置输入队列
|
|
||||||
pipeline.set_input_queue(self._input_queue)
|
|
||||||
|
|
||||||
# 配置资源
|
def set_id(self, id: str):
|
||||||
pipeline.set_audio_binary(
|
self._id = id
|
||||||
self._audio_binary_list[pipeline.get_config("audio_binary_name")]
|
|
||||||
)
|
|
||||||
pipeline.set_models(self._models)
|
|
||||||
|
|
||||||
def run(self) -> None:
|
def set_sender(self, sender: WebSocketClient):
|
||||||
|
self._sender = sender
|
||||||
|
|
||||||
|
def set_pipeline(self, pipeline: ASRPipeline):
|
||||||
|
self._pipeline = pipeline
|
||||||
|
config = {
|
||||||
|
"audio_config": self._audio_config,
|
||||||
|
}
|
||||||
|
self._pipeline.set_config(config)
|
||||||
|
self._pipeline.set_models(self._models)
|
||||||
|
self._pipeline.set_audio_binary(self._audio_binary)
|
||||||
|
self._pipeline.set_input_queue(self._input_queue)
|
||||||
|
self._pipeline.set_callback(self.deal_message)
|
||||||
|
self._pipeline.bake()
|
||||||
|
|
||||||
|
def append_receiver(self, receiver: WebSocketClient):
|
||||||
|
self._receiver.append(receiver)
|
||||||
|
|
||||||
|
def delete_receiver(self, receiver: WebSocketClient):
|
||||||
|
self._receiver.remove(receiver)
|
||||||
|
|
||||||
|
def deal_message(self, message: str):
|
||||||
|
self.broadcast(message)
|
||||||
|
|
||||||
|
def broadcast(self, message: str):
|
||||||
"""
|
"""
|
||||||
启动所有管道
|
广播发送给所有接收者
|
||||||
"""
|
"""
|
||||||
logger.info("[%s] 启动所有管道", self.__class__.__name__)
|
logger.info("[ASRRunner][SAR-%s]广播发送给所有接收者: %s", self._name, message)
|
||||||
if not self._pipeline_list:
|
for receiver in self._receiver:
|
||||||
raise RuntimeError("没有可用的管道")
|
receiver.send(message)
|
||||||
|
|
||||||
# 启动所有管道
|
def _run(self):
|
||||||
for pipeline in self._pipeline_list:
|
|
||||||
thread = Thread(target=pipeline.run, name=f"Pipeline-{id(pipeline)}")
|
|
||||||
thread.daemon = True
|
|
||||||
thread.start()
|
|
||||||
logger.info("[%s] 管道 %s 已启动", self.__class__.__name__, id(pipeline))
|
|
||||||
|
|
||||||
def stop(self, force: bool = False) -> bool:
|
|
||||||
"""
|
"""
|
||||||
停止所有管道
|
运行SAR
|
||||||
参数:
|
|
||||||
force: 是否强制停止
|
|
||||||
返回:
|
|
||||||
bool: 是否成功停止
|
|
||||||
"""
|
"""
|
||||||
if self._is_stopping:
|
self._pipeline.run()
|
||||||
logger.warning("运行器已经在停止中")
|
while True:
|
||||||
return False
|
data = self._sender.recv()
|
||||||
|
if data is None:
|
||||||
self._is_stopping = True
|
|
||||||
logger.info("正在停止运行器...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 发送结束信号
|
|
||||||
self._input_queue.put(None)
|
|
||||||
|
|
||||||
# 停止所有管道
|
|
||||||
success = True
|
|
||||||
for pipeline in self._pipeline_list:
|
|
||||||
if force:
|
|
||||||
pipeline.force_stop()
|
|
||||||
else:
|
|
||||||
if not pipeline.stop(timeout=self._stop_timeout):
|
|
||||||
logger.warning("管道 %s 停止超时", id(pipeline))
|
|
||||||
success = False
|
|
||||||
|
|
||||||
# 等待队列处理完成
|
|
||||||
try:
|
|
||||||
start_time = time.time()
|
|
||||||
while not self._input_queue.empty():
|
|
||||||
if time.time() - start_time > self._stop_timeout:
|
|
||||||
logger.warning(
|
|
||||||
"等待队列处理完成超时(%s秒), 队列中还有 %d 个任务未处理",
|
|
||||||
self._stop_timeout,
|
|
||||||
self._input_queue.qsize(),
|
|
||||||
)
|
|
||||||
success = False
|
|
||||||
break
|
break
|
||||||
time.sleep(0.1) # 避免过度消耗CPU
|
logger.debug("[ASRRunner][SAR-%s]接收到的数据length: %s", self._name, len(data))
|
||||||
|
self._input_queue.put(data)
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
运行SAR
|
||||||
|
"""
|
||||||
|
self._thread = Thread(target=self._run, name=f"[ASRRunner]SAR-{self._name}")
|
||||||
|
self._thread.daemon = True
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
停止SAR
|
||||||
|
"""
|
||||||
|
self._pipeline.stop()
|
||||||
|
for ws in self._receiver:
|
||||||
|
ws.close()
|
||||||
|
self._sender.close()
|
||||||
|
|
||||||
|
def __init__(self,*args,**kwargs):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# 接收资源
|
||||||
|
self._default_audio_config = kwargs.get("audio_config", DefaultConfig.audio_config)
|
||||||
|
# self._audio_binary_list = args.get("audio_binary_list", None)
|
||||||
|
self._default_models = kwargs.get("models", None)
|
||||||
|
self._SAR_list: List[self.SenderAndReceiver] = []
|
||||||
|
|
||||||
|
def set_default_config(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
设置配置
|
||||||
|
"""
|
||||||
|
self._default_audio_config = kwargs.get("audio_config", self._default_audio_config)
|
||||||
|
self._default_models = kwargs.get("models", self._default_models)
|
||||||
|
|
||||||
|
def new_SAR(
|
||||||
|
self,
|
||||||
|
ws: "WebSocketClient",
|
||||||
|
name: str = "",
|
||||||
|
audio_config: "AudioBinary_Config" = None,
|
||||||
|
models: dict = None
|
||||||
|
) -> uuid.UUID:
|
||||||
|
"""
|
||||||
|
创建新的SAR SenderAndReceiver
|
||||||
|
"""
|
||||||
|
if audio_config is None:
|
||||||
|
audio_config = self._default_audio_config
|
||||||
|
if models is None:
|
||||||
|
models = self._default_models
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_SAR = self.SenderAndReceiver(
|
||||||
|
name=name,
|
||||||
|
audio_config=audio_config,
|
||||||
|
models=models
|
||||||
|
)
|
||||||
|
new_pipeline = ASRPipeline()
|
||||||
|
new_SAR.set_pipeline(new_pipeline)
|
||||||
|
# new_SAR.set_pipeline()
|
||||||
|
logger.info("创建新的SAR: name %s, id %s", new_SAR._name, new_SAR._id)
|
||||||
|
new_SAR.set_sender(ws)
|
||||||
|
new_SAR.append_receiver(ws)
|
||||||
|
new_SAR.run()
|
||||||
|
self._SAR_list.append(new_SAR)
|
||||||
|
return new_SAR._id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_type = type(e).__name__
|
logger.error("创建管道失败: %s", e)
|
||||||
error_msg = str(e)
|
return None
|
||||||
error_traceback = traceback.format_exc()
|
|
||||||
logger.error(
|
|
||||||
"等待队列处理完成时发生错误:\n"
|
|
||||||
"错误类型: %s\n"
|
|
||||||
"错误信息: %s\n"
|
|
||||||
"错误堆栈:\n%s",
|
|
||||||
error_type,
|
|
||||||
error_msg,
|
|
||||||
error_traceback,
|
|
||||||
)
|
|
||||||
success = False
|
|
||||||
|
|
||||||
if success:
|
def join_SAR(
|
||||||
logger.info("所有管道已成功停止")
|
self,
|
||||||
else:
|
ws: "WebSocketClient",
|
||||||
logger.warning(
|
name: Optional[str] = None,
|
||||||
"部分管道停止失败, 队列状态: 大小=%d, 是否为空=%s",
|
id: Optional[str] = None,
|
||||||
self._input_queue.qsize(),
|
) -> bool:
|
||||||
self._input_queue.empty(),
|
"""
|
||||||
)
|
加入SAR的Receiver
|
||||||
|
"""
|
||||||
return success
|
# 使用next获取迭代器下一个元素,生成pipeline_list迭代器,按id停止
|
||||||
|
if id:
|
||||||
finally:
|
exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._id == id), None)
|
||||||
self._is_stopping = False
|
if name:
|
||||||
|
exist_pipeline = next((pipeline for pipeline in self._SAR_list if pipeline._name == name), None)
|
||||||
|
if exist_pipeline:
|
||||||
|
exist_pipeline.append_receiver(ws)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
"""
|
"""
|
||||||
析构函数
|
析构函数
|
||||||
"""
|
"""
|
||||||
self.stop(force=True)
|
for sar in self._SAR_list:
|
||||||
|
sar.stop()
|
||||||
|
@ -18,8 +18,8 @@ from queue import Queue
|
|||||||
import traceback
|
import traceback
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from src.audio_chunk import AudioChunk, AudioBinary
|
from src.audio_chunk import AudioChunk
|
||||||
from src.pipeline import Pipeline, PipelineFactory
|
from src.pipeline import PipelineFactory
|
||||||
from src.core.model_loader import ModelLoader
|
from src.core.model_loader import ModelLoader
|
||||||
from src.utils.logger import get_module_logger
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ class STTRunnerFactory:
|
|||||||
audio_binary_name: str,
|
audio_binary_name: str,
|
||||||
model_name_list: List[str],
|
model_name_list: List[str],
|
||||||
pipeline_name_list: List[str],
|
pipeline_name_list: List[str],
|
||||||
) -> STTRunner:
|
) -> RunnerBase:
|
||||||
"""
|
"""
|
||||||
创建运行器
|
创建运行器
|
||||||
参数:
|
参数:
|
||||||
@ -67,7 +67,7 @@ class STTRunnerFactory:
|
|||||||
PipelineFactory.create_pipeline(pipeline_name)
|
PipelineFactory.create_pipeline(pipeline_name)
|
||||||
for pipeline_name in pipeline_name_list
|
for pipeline_name in pipeline_name_list
|
||||||
]
|
]
|
||||||
return STTRunner(
|
return RunnerBase(
|
||||||
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
audio_binary_list=[audio_binary], models=models, pipeline_list=pipelines
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ class STTRunnerFactory:
|
|||||||
def create_runner_from_config(
|
def create_runner_from_config(
|
||||||
cls,
|
cls,
|
||||||
config: Dict[str, Any],
|
config: Dict[str, Any],
|
||||||
) -> STTRunner:
|
) -> RunnerBase:
|
||||||
"""
|
"""
|
||||||
从配置创建运行器
|
从配置创建运行器
|
||||||
参数:
|
参数:
|
||||||
@ -91,7 +91,7 @@ class STTRunnerFactory:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_runner_normal(cls) -> STTRunner:
|
def create_runner_normal(cls) -> RunnerBase:
|
||||||
"""
|
"""
|
||||||
创建默认运行器
|
创建默认运行器
|
||||||
返回:
|
返回:
|
||||||
|
55
src/utils/mock_websocket.py
Normal file
55
src/utils/mock_websocket.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import queue
|
||||||
|
|
||||||
|
from src.utils.logger import get_module_logger
|
||||||
|
|
||||||
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
class MockWebSocketClient:
    """An in-memory stand-in for a WebSocket client, used for testing.

    Outgoing messages passed to ``send`` are recorded in ``sent_messages``.
    Incoming data is staged with ``put_for_recv`` and consumed by ``recv``
    through ``receive_queue``.

    Args:
        recv_timeout: Default number of seconds ``recv`` waits for data
            before giving up and marking the connection closed. Defaults
            to 10, the value that used to be hard-coded.
    """

    def __init__(self, recv_timeout: float = 10):
        self.sent_messages = []
        self._is_closed = False
        self.receive_queue = queue.Queue()
        self._recv_timeout = recv_timeout

    def send(self, message: object):
        """Record ``message`` as sent; it is stored as-is, not serialized.

        Sending on a closed connection only prints a warning and records
        nothing.
        """
        if self._is_closed:
            print("Warning: sending message on a closed websocket")
            return
        self.sent_messages.append(message)
        print(f"Mock WS received: {message}")

    def recv(self, timeout=None):
        """Return the next staged item, or ``None`` when the stream is over.

        Args:
            timeout: Seconds to wait for data. ``None`` (the default) means
                "use the ``recv_timeout`` given to the constructor".

        Returns:
            The next item from ``receive_queue``. ``None`` is returned when
            the connection is already closed, when the ``None`` sentinel
            queued by ``close`` is read, or when the wait times out; in the
            last two cases the connection is marked closed.
        """
        if self._is_closed:
            return None
        wait = self._recv_timeout if timeout is None else timeout
        try:
            # Block until data is available, bounded so a test cannot hang.
            data = self.receive_queue.get(timeout=wait)
        except queue.Empty:
            print("Mock WS recv timeout")
            self._is_closed = True
            return None
        if data is None:
            # Sentinel placed by close(): treat as end of stream.
            self._is_closed = True
        return data

    def close(self):
        """Close the connection; calling it again is a no-op."""
        if self._is_closed:
            return
        # Put None to unblock any recv call that is currently waiting.
        self.receive_queue.put(None)
        self._is_closed = True
        print("Mock WS closed")

    def put_for_recv(self, data):
        """Stage ``data`` for a later ``recv`` call.

        ``None`` is ignored here because it is reserved as the end-of-stream
        sentinel; call ``close`` to end the stream instead.
        """
        if data is None:
            return
        logger.debug("Mock WS put_for_recv length: %s", len(data))
        self.receive_queue.put(data)

    @property
    def is_closed(self):
        """Whether the connection has been closed or has timed out."""
        return self._is_closed
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
from tests.pipeline.asr_test import test_asr_pipeline
|
from tests.pipeline.asr_test import test_asr_pipeline
|
||||||
from src.utils.logger import get_module_logger, setup_root_logger
|
from src.utils.logger import get_module_logger, setup_root_logger
|
||||||
|
from tests.runner.stt_runner_test import test_asr_runner
|
||||||
|
|
||||||
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||||
logger = get_module_logger(__name__)
|
logger = get_module_logger(__name__)
|
||||||
@ -13,5 +14,8 @@ logger = get_module_logger(__name__)
|
|||||||
# logger.info("开始测试VAD函数器")
|
# logger.info("开始测试VAD函数器")
|
||||||
# test_vad_functor()
|
# test_vad_functor()
|
||||||
|
|
||||||
logger.info("开始测试ASR管道")
|
# logger.info("开始测试ASR管道")
|
||||||
test_asr_pipeline()
|
# test_asr_pipeline()
|
||||||
|
|
||||||
|
logger.info("开始测试ASRRunner")
|
||||||
|
test_asr_runner()
|
||||||
|
@ -71,6 +71,7 @@ def test_asr_pipeline():
|
|||||||
callback=lambda x: print(f"pipeline callback: {x}")
|
callback=lambda x: print(f"pipeline callback: {x}")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
asr_pipeline.bake()
|
||||||
# 运行Pipeline
|
# 运行Pipeline
|
||||||
asr_instance = asr_pipeline.run()
|
asr_instance = asr_pipeline.run()
|
||||||
|
|
||||||
|
75
tests/runner/asr_runner_test.py
Normal file
75
tests/runner/asr_runner_test.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
ASRRunner test
|
||||||
|
"""
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
import soundfile
|
||||||
|
import numpy as np
|
||||||
|
from src.runner.ASRRunner import ASRRunner
|
||||||
|
from src.core.model_loader import ModelLoader
|
||||||
|
from src.models import AudioBinary_Config
|
||||||
|
from src.utils.mock_websocket import MockWebSocketClient
|
||||||
|
|
||||||
|
def test_asr_runner():
    """
    End-to-end smoke test for ASRRunner driven through a mock WebSocket.

    1. Loads models.
    2. Builds the audio configuration from the sample file's sample rate.
    3. Configures and initializes ASRRunner.
    4. Creates a mock WebSocket client and starts a new SenderAndReceiver
       (SAR) instance in the runner.
    5. Streams audio data via the mock WebSocket.
    6. Waits for processing, then closes the mock WebSocket to end the
       stream.

    The only assertion is that the SAR was created. Transcription results
    are recorded in ``mock_ws.sent_messages`` but are not compared against
    an expected text here.
    """
    # 1. Load models
    model_loader = ModelLoader()
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
        "audio_update": False,
    }
    models = model_loader.load_models(args)
    audio_data, sample_rate = soundfile.read("tests/vad_example.wav")

    # 2. Configure audio
    audio_config = AudioBinary_Config(
        chunk_size=200,  # ms
        chunk_stride=1600,  # placeholder, recomputed from chunk_size below
        sample_rate=sample_rate,
        sample_width=2,  # 16-bit
        channels=1,
    )
    # Convert the chunk length from milliseconds to samples. This must be
    # derived from chunk_size (ms); scaling chunk_stride itself would treat
    # 1600 as milliseconds.
    audio_config.chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)

    # 3. Setup ASRRunner
    asr_runner = ASRRunner()
    asr_runner.set_default_config(
        audio_config=audio_config,
        models=models,
    )

    # 4. Create Mock WebSocket and start SAR
    mock_ws = MockWebSocketClient()
    sar_id = asr_runner.new_SAR(
        ws=mock_ws,
        name="test_sar",
    )
    assert sar_id is not None, "Failed to create a new SAR instance"

    # 5. Simulate streaming audio
    print(f"Sending audio data of length {len(audio_data)} samples.")
    audio_clip_len = 200
    for i in range(0, len(audio_data), audio_clip_len):
        chunk = audio_data[i : i + audio_clip_len]
        if not isinstance(chunk, np.ndarray) or chunk.size == 0:
            break
        # Simulate receiving binary data over WebSocket
        mock_ws.put_for_recv(chunk)

    # 6. Give the pipeline time to process the queued audio
    time.sleep(10)
    # Signal end of audio stream. put_for_recv() ignores None, so close()
    # is used: it queues the sentinel that makes recv() return None.
    mock_ws.close()
|
Loading…
x
Reference in New Issue
Block a user