# STT_Server/tests/functor/vad_test.py
#
# 123 lines
# 3.6 KiB
# Python

"""
Functor测试
VAD测试
"""
from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor
from queue import Queue, Empty
from src.model_loader import ModelLoader
from src.models import AudioBinary_Config, AudioBinary_data_list
from src.utils.data_format import wav_to_bytes
import time
from src.utils.logger import get_module_logger
from pydub import AudioSegment
import soundfile
# Observation switch: when True, test_vad_functor writes every entry of its
# audio-binary data list to tests/vad_test_output_<index>.wav after the run
# (presumably for manual inspection of the VAD segments).
OVERWATCH: bool = False
# Module logger (not used by the test code in this file as shown).
logger = get_module_logger(__name__)
# Shared model loader, constructed at import time; test_vad_functor calls
# load_models() on it and reads the loaded models from its `models` mapping.
model_loader = ModelLoader()
def _start_functor(functor, input_queue, audio_config, callbacks, model,
                   audio_binary_data_list=None):
    """Wire up a functor and start it; return the same functor.

    Setters are applied in this order: input queue, audio config, optional
    audio-binary data list, callbacks, model, then ``run()``.

    Args:
        functor: A VAD/ASR/SPK functor instance.
        input_queue: Queue the functor consumes from.
        audio_config: Shared ``AudioBinary_Config``.
        callbacks: Iterable of callables registered via ``add_callback``.
        model: Mapping passed to ``set_model``.
        audio_binary_data_list: Passed to ``set_audio_binary_data_list``
            when not None (only the VAD functor uses it here).
    """
    functor.set_input_queue(input_queue)
    functor.set_audio_config(audio_config)
    # `is not None` rather than truthiness: the list object has a __len__
    # and may legitimately be empty when it is handed over.
    if audio_binary_data_list is not None:
        functor.set_audio_binary_data_list(audio_binary_data_list)
    for callback in callbacks:
        functor.add_callback(callback)
    functor.set_model(model)
    functor.run()
    return functor


def test_vad_functor():
    """End-to-end smoke test of the VAD -> ASR / SPK functor pipeline.

    Loads the paraformer-zh ASR and fsmn-vad models through the module-level
    ``model_loader``, streams ``tests/vad_example.wav`` into a VADFunctor in
    fixed-size clips, and fans every VAD callback result out to an ASRFunctor
    and an SPKFunctor (the latter with a placeholder model). Results are only
    printed; nothing is asserted. All started functors are stopped even if a
    step fails. When ``OVERWATCH`` is True, every entry collected in the
    audio-binary data list is written to ``tests/vad_test_output_<index>.wav``.
    """
    # Load the models.
    model_loader.load_models({
        "asr_model": "paraformer-zh",
        "asr_model_revision": "v2.0.4",
        "vad_model": "fsmn-vad",
        "vad_model_revision": "v2.0.4",
        "auto_update": False,
    })

    # Load the test audio (path is relative to the working directory).
    # NOTE(review): soundfile.read returns a sample array (float64 by default),
    # not raw bytes -- confirm this is the input format VADFunctor expects.
    f_data, sample_rate = soundfile.read("tests/vad_example.wav")

    # chunk_size is treated as milliseconds; chunk_stride is the same window
    # converted to samples at the file's sample rate.
    chunk_size_ms = 200
    audio_config = AudioBinary_Config(
        chunk_size=chunk_size_ms,
        chunk_stride=int(chunk_size_ms * sample_rate / 1000),
        sample_rate=sample_rate,
        sample_width=16,
        channels=1,
    )

    # Raw audio goes into VAD; VAD results are fanned out to ASR and SPK.
    input_queue = Queue()
    vad2asr_queue = Queue()
    vad2spk_queue = Queue()
    # Handed to the VAD functor; its entries' binary_data are dumped below
    # when OVERWATCH is set (presumably the VAD segments -- confirm).
    audio_binary_data_list = AudioBinary_data_list()

    # (label, functor) pairs in pipeline order, so everything that was
    # actually started gets stopped even if a later step raises.
    started = []
    try:
        started.append(("VAD", _start_functor(
            VADFunctor(),
            input_queue,
            audio_config,
            callbacks=[
                lambda x: print(f"vad callback: {x}"),
                lambda x: vad2asr_queue.put(x),
                lambda x: vad2spk_queue.put(x),
            ],
            model={'vad': model_loader.models['vad']},
            audio_binary_data_list=audio_binary_data_list,
        )))
        started.append(("ASR", _start_functor(
            ASRFunctor(),
            vad2asr_queue,
            audio_config,
            callbacks=[lambda x: print(f"asr callback: {x}")],
            model={'asr': model_loader.models['asr']},
        )))
        started.append(("SPK", _start_functor(
            SPKFunctor(),
            vad2spk_queue,
            audio_config,
            callbacks=[lambda x: print(f"spk callback: {x}")],
            # Placeholder until the real speaker model is wired in
            # (model_loader.models['spk']).
            model={'spk': 'fake_spk'},
        )))

        # Stream the audio into VAD in fixed-size clips.
        # NOTE(review): clips are 200 *samples*, while chunk_size above is
        # 200 ms (= chunk_stride samples); confirm which unit is intended.
        audio_clip_len = 200
        print(
            f"f_binary: {len(f_data)}, audio_clip_len: {audio_clip_len}, "
            f"clip_num: {len(f_data) // audio_clip_len}"
        )
        for start in range(0, len(f_data), audio_clip_len):
            input_queue.put(f_data[start:start + audio_clip_len])
    finally:
        # Stop in pipeline order (upstream first), as the original did for
        # VAD then ASR; SPK was previously never stopped. Whether stop()
        # drains already-queued input depends on the functor -- TODO confirm.
        for label, functor in started:
            functor.stop()
            print(f"[vad_test] {label}函数器结束")

    # Optionally dump each collected entry for manual inspection.
    if OVERWATCH:
        for index in range(len(audio_binary_data_list)):
            save_path = f"tests/vad_test_output_{index}.wav"
            soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate)