81 lines
2.1 KiB
Python
81 lines
2.1 KiB
Python
"""
|
|
Pipeline测试
|
|
VAD+ASR+SPK(FAKE)
|
|
"""
|
|
from src.pipeline.ASRpipeline import ASRPipeline
|
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
|
from src.model_loader import ModelLoader
|
|
from queue import Queue
|
|
import soundfile
|
|
import time
|
|
|
|
from src.utils.logger import get_module_logger
|
|
|
|
logger = get_module_logger(__name__)

# Boolean switch that nothing in this file reads. Presumably an
# "overwatch"/monitoring toggle -- NOTE(review): confirm intended use or remove.
OVERWATCH = False
# Original, misspelled name kept as an alias so any external reference still resolves.
OVAERWATCH = OVERWATCH

# Shared loader, instantiated once at import time; test_asr_pipeline calls
# model_loader.load_models(...) on it.
model_loader = ModelLoader()
|
|
|
|
def test_asr_pipeline():
    """Run the ASR pipeline end to end on ``tests/vad_example.wav``.

    Loads the VAD/ASR/SPK models through the module-level ``model_loader``,
    wires them into an ``ASRPipeline`` together with an input queue, starts
    the pipeline, feeds the audio into the queue in fixed-size slices, waits
    a fixed time for processing and finally stops the pipeline (even if
    feeding fails). Results are only printed by the registered callback; the
    test makes no assertions on them.
    """
    # Model names/revisions handed to ModelLoader.load_models.
    args = {
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "spk_model": "cam++",
        "spk_model_revision": "v2.0.2",
        "audio_update": False,
    }
    models = model_loader.load_models(args)

    audio_data, sample_rate = soundfile.read("tests/vad_example.wav")

    # chunk_size is treated as milliseconds: the stride is the same duration
    # expressed in samples. Compute it up front instead of constructing the
    # config with a placeholder stride and overwriting it afterwards.
    chunk_size_ms = 200
    chunk_stride = int(chunk_size_ms * sample_rate // 1000)

    audio_config = AudioBinary_Config(
        chunk_size=chunk_size_ms,
        chunk_stride=chunk_stride,
        sample_rate=sample_rate,
        # NOTE(review): 16 looks like bits per sample -- confirm the unit
        # AudioBinary_Config expects (bytes would be 2).
        sample_width=16,
        channels=1,
    )

    # Parameter dict consumed by ASRPipeline.set_config.
    config = {
        "audio_config": audio_config,
    }

    audio_binary_data_list = AudioBinary_data_list()
    input_queue = Queue()

    # Build the pipeline.
    asr_pipeline = ASRPipeline()
    asr_pipeline.set_models(models)
    asr_pipeline.set_config(config)
    asr_pipeline.set_audio_binary(audio_binary_data_list)
    asr_pipeline.set_input_queue(input_queue)
    asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
    asr_pipeline.bake()

    asr_pipeline.run()
    # Once the pipeline is running, always stop it -- otherwise a failure
    # while feeding the queue would leave its workers alive.
    try:
        # Slice length in samples (it indexes audio_data directly), unrelated
        # to chunk_size_ms above.
        audio_clip_len = 200
        # Ceiling division: the final partial slice is also enqueued.
        clip_num = (len(audio_data) + audio_clip_len - 1) // audio_clip_len
        print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {clip_num}")
        for start in range(0, len(audio_data), audio_clip_len):
            input_queue.put(audio_data[start:start + audio_clip_len])

        # No end-of-stream sentinel is sent, so there is nothing to join on;
        # give the pipeline a fixed amount of time to drain the queue.
        time.sleep(5)
    finally:
        asr_pipeline.stop()
|
|
|