94 lines
2.4 KiB
Python
94 lines
2.4 KiB
Python
"""
Pipeline test (Pipeline测试).

VAD + ASR + SPK (FAKE)
"""
|
|
|
|
from src.pipeline.ASRpipeline import ASRPipeline
|
|
from src.pipeline import PipelineFactory
|
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
|
from src.model_loader import ModelLoader
|
|
from queue import Queue
|
|
import soundfile
|
|
import time
|
|
|
|
from src.utils.logger import get_module_logger
|
|
|
|
# Flag carried over from the original module; nothing in the visible code
# reads it. The spelling "OVAERWATCH" is kept in case other modules import it.
OVAERWATCH = False

# Logger named after this module.
logger = get_module_logger(__name__)

# Shared model loader, instantiated once at import time and used by
# test_asr_pipeline to load the VAD / ASR / speaker models.
model_loader = ModelLoader()
|
|
|
|
|
|
def test_asr_pipeline():
    """Smoke-test the ASR pipeline (VAD + ASR + speaker model) end to end.

    Steps:
    1. Load the models through the module-level ``model_loader``.
    2. Read ``tests/vad_example.wav``.
    3. Build an ``ASRpipeline`` via ``PipelineFactory`` and start it.
    4. Stream the audio into the pipeline's input queue in fixed-size clips.
    5. Wait a fixed 5 seconds, then stop the pipeline.

    No assertions are made on the recognition output. The test only checks
    that the pipeline can be built, run, fed and stopped without raising.
    Results are reported through the pipeline callback.
    """
    # Load models (names/revisions this test was written against).
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
        "audio_update": False,
    }
    models = model_loader.load_models(args)

    # NOTE(review): relative path - assumes pytest runs from the repo root.
    audio_data, sample_rate = soundfile.read("tests/vad_example.wav")

    # The stride is derived as chunk_size * sample_rate / 1000, i.e. chunk_size
    # is treated as milliseconds and the stride as a sample count. Computing it
    # up front replaces the placeholder stride the config was previously built
    # with and then overwritten.
    chunk_size = 200
    audio_config = AudioBinary_Config(
        chunk_size=chunk_size,
        chunk_stride=int(chunk_size * sample_rate / 1000),
        sample_rate=sample_rate,
        sample_width=16,
        channels=1,
    )

    # Pipeline parameters.
    config = {
        "audio_config": audio_config,
    }

    # Audio data list and the queue the pipeline consumes from.
    audio_binary_data_list = AudioBinary_data_list()
    input_queue = Queue()

    asr_pipeline = PipelineFactory.create_pipeline(
        pipeline_name="ASRpipeline",
        models=models,
        config=config,
        audio_binary=audio_binary_data_list,
        input_queue=input_queue,
        callback=lambda x: logger.info("pipeline callback: %s", x),
    )

    # Start the pipeline; its return value was never used by this test.
    asr_pipeline.run()
    try:
        # Number of samples pushed per queue item.
        audio_clip_len = 200
        logger.info(
            "audio_data: %d, audio_clip_len: %d, clip_num: %d",
            len(audio_data),
            audio_clip_len,
            len(audio_data) // audio_clip_len,
        )
        for start in range(0, len(audio_data), audio_clip_len):
            input_queue.put(audio_data[start : start + audio_clip_len])

        # No end-of-stream marker is sent. Give the pipeline a fixed window
        # to drain the queue before shutting it down.
        time.sleep(5)
    finally:
        # Always stop the pipeline, even if feeding fails, so a running
        # pipeline is not left behind to hang the test session.
        asr_pipeline.stop()
|