127 lines
4.2 KiB
Python
127 lines
4.2 KiB
Python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
ASR service module - provides the core speech-recognition functionality.
"""

import json


class ASRService:
    """Speech-recognition service wrapping the server's preloaded models.

    Bundles a VAD model, an offline (non-streaming) ASR model, a streaming ASR
    model and an optional punctuation model. It exposes coroutines that run
    them on audio received over a WebSocket connection and send the results
    back to the client as JSON messages.
    """

    def __init__(self, models):
        """Initialize the service from a dict of preloaded models.

        Args:
            models: Dict with the required keys "asr" (offline ASR model),
                "asr_streaming" (streaming ASR model) and "vad" (voice
                activity detection model); a missing required key raises
                KeyError. The "punc" (punctuation model) key is optional:
                when it is absent or None, offline transcripts are sent
                without punctuation restoration.
        """
        self.model_asr = models["asr"]
        self.model_asr_streaming = models["asr_streaming"]
        self.model_vad = models["vad"]
        # async_asr already skips punctuation when this is None, so a missing
        # "punc" key is treated the same way instead of raising KeyError.
        self.model_punc = models.get("punc")

    async def async_vad(self, websocket, audio_in):
        """Run voice activity detection on one chunk of audio.

        Args:
            websocket: WebSocket connection; its ``status_dict_vad`` is passed
                to the VAD model as keyword arguments.
            audio_in: Binary audio data.

        Returns:
            tuple: ``(speech_start, speech_end)``. Each element is -1 when
            that boundary was not reported. Only a result containing exactly
            one segment is interpreted; zero or several segments yield
            ``(-1, -1)``.
        """
        # NOTE(review): generate() is not awaited, so inference runs
        # synchronously and blocks the event loop while it executes.
        segments = self.model_vad.generate(
            input=audio_in,
            **websocket.status_dict_vad
        )[0]["value"]

        speech_start = -1
        speech_end = -1
        if len(segments) != 1:
            return speech_start, speech_end

        # -1 marks a boundary the model did not report for this chunk.
        segment_start, segment_end = segments[0][0], segments[0][1]
        if segment_start != -1:
            speech_start = segment_start
        if segment_end != -1:
            speech_end = segment_end
        return speech_start, speech_end

    async def async_asr(self, websocket, audio_in):
        """Run offline ASR on buffered audio and send the transcript.

        When ``audio_in`` is non-empty:
            - the offline model transcribes it;
            - if a punctuation model is configured and the transcript is
              non-empty, punctuation is restored;
            - a message is sent only when the resulting text is non-empty.

        When ``audio_in`` is empty, a message with empty text is sent.

        In any mode containing "2pass", the reported mode is
        "2pass-offline"; otherwise it is ``websocket.mode``.

        Args:
            websocket: WebSocket connection. It supplies ``mode``,
                ``wav_name`` and ``is_speaking``, plus the
                ``status_dict_asr`` / ``status_dict_punc`` keyword arguments
                passed to the models.
            audio_in: Binary audio data.
        """
        if len(audio_in) == 0:
            await self._send_text(websocket, "", "2pass-offline")
            return

        rec_result = self.model_asr.generate(
            input=audio_in,
            **websocket.status_dict_asr
        )[0]

        # Restore punctuation only when there is text to punctuate.
        if self.model_punc is not None and len(rec_result["text"]) > 0:
            rec_result = self.model_punc.generate(
                input=rec_result["text"],
                **websocket.status_dict_punc
            )[0]

        if len(rec_result["text"]) > 0:
            await self._send_text(websocket, rec_result["text"], "2pass-offline")

    async def async_asr_online(self, websocket, audio_in):
        """Run streaming ASR on one audio chunk and send any partial text.

        Nothing is sent for empty audio or an empty transcript. In mode
        "2pass", nothing is sent for the final chunk
        (``status_dict_asr_online["is_final"]`` truthy): the final result is
        left to the offline pass.

        In any mode containing "2pass", the reported mode is "2pass-online";
        otherwise it is ``websocket.mode``.

        Args:
            websocket: WebSocket connection. It supplies ``mode``,
                ``wav_name`` and ``is_speaking``, plus the
                ``status_dict_asr_online`` keyword arguments passed to the
                streaming model.
            audio_in: Binary audio data.
        """
        if len(audio_in) == 0:
            return

        # The streaming model is run even on the final chunk; only sending
        # its output is skipped below.
        rec_result = self.model_asr_streaming.generate(
            input=audio_in,
            **websocket.status_dict_asr_online
        )[0]

        if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
            return

        if len(rec_result["text"]) > 0:
            await self._send_text(websocket, rec_result["text"], "2pass-online")

    @staticmethod
    async def _send_text(websocket, text, two_pass_mode):
        """Serialize one recognition result as JSON and send it to the client.

        Args:
            websocket: Connection providing ``mode``, ``wav_name``,
                ``is_speaking`` and an awaitable ``send``.
            text: Recognized text (may be empty).
            two_pass_mode: Mode label reported instead of ``websocket.mode``
                when that mode contains "2pass".
        """
        mode = two_pass_mode if "2pass" in websocket.mode else websocket.mode
        message = json.dumps({
            "mode": mode,
            "text": text,
            "wav_name": websocket.wav_name,
            # NOTE(review): "is_final" mirrors websocket.is_speaking, which
            # reads inverted; kept as-is because clients consume this field
            # -- confirm against the client protocol.
            "is_final": websocket.is_speaking,
        })
        await websocket.send(message)