[代码结构]black . 对所有文件格式调整,无功能变化。
This commit is contained in:
parent
5b94c40016
commit
5a820b49e4
22
main.py
22
main.py
@ -1,11 +1,7 @@
|
|||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
|
|
||||||
chunk_size = 200 # ms
|
chunk_size = 200 # ms
|
||||||
model = AutoModel(
|
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||||
model="fsmn-vad",
|
|
||||||
model_revision="v2.0.4",
|
|
||||||
disable_update=True
|
|
||||||
)
|
|
||||||
|
|
||||||
import soundfile
|
import soundfile
|
||||||
|
|
||||||
@ -14,16 +10,16 @@ speech, sample_rate = soundfile.read(wav_file)
|
|||||||
chunk_stride = int(chunk_size * sample_rate / 1000)
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
|
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
|
||||||
for i in range(total_chunk_num):
|
for i in range(total_chunk_num):
|
||||||
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
|
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||||
is_final = i == total_chunk_num - 1
|
is_final = i == total_chunk_num - 1
|
||||||
res = model.generate(
|
res = model.generate(
|
||||||
input=speech_chunk,
|
input=speech_chunk,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
is_final=is_final,
|
is_final=is_final,
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
disable_pbar=True
|
disable_pbar=True,
|
||||||
)
|
)
|
||||||
if len(res[0]["value"]):
|
if len(res[0]["value"]):
|
||||||
print(res)
|
print(res)
|
||||||
@ -31,4 +27,4 @@ for i in range(total_chunk_num):
|
|||||||
print(f"len(speech): {len(speech)}")
|
print(f"len(speech): {len(speech)}")
|
||||||
print(f"len(speech_chunk): {len(speech_chunk)}")
|
print(f"len(speech_chunk): {len(speech_chunk)}")
|
||||||
print(f"total_chunk_num: {total_chunk_num}")
|
print(f"total_chunk_num: {total_chunk_num}")
|
||||||
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")
|
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")
|
||||||
|
@ -11,4 +11,4 @@ FunASR WebSocket服务
|
|||||||
- 支持多种识别模式(2pass/online/offline)
|
- 支持多种识别模式(2pass/online/offline)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
@ -46,7 +46,7 @@ class AudioBinary:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("参数类型错误")
|
raise ValueError("参数类型错误")
|
||||||
|
|
||||||
def add_slice_listener(self, slice_listener: callable):
|
def add_slice_listener(self, slice_listener: callable) -> None:
|
||||||
"""
|
"""
|
||||||
添加切片监听器
|
添加切片监听器
|
||||||
参数:
|
参数:
|
||||||
@ -98,10 +98,10 @@ class AudioBinary:
|
|||||||
self._binary_data_list.rewrite(target_index, binary_data)
|
self._binary_data_list.rewrite(target_index, binary_data)
|
||||||
|
|
||||||
def get_binary_data(
|
def get_binary_data(
|
||||||
self,
|
self,
|
||||||
start: int = 0,
|
start: int = 0,
|
||||||
end: Optional[int] = None,
|
end: Optional[int] = None,
|
||||||
) -> Optional[bytes]:
|
) -> Optional[bytes]:
|
||||||
"""
|
"""
|
||||||
获取指定索引的音频数据块
|
获取指定索引的音频数据块
|
||||||
参数:
|
参数:
|
||||||
@ -128,7 +128,7 @@ class AudioChunk:
|
|||||||
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。
|
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance = None
|
_instance: Optional[AudioChunk] = None
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -138,7 +138,7 @@ class AudioChunk:
|
|||||||
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
|
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
"""
|
"""
|
||||||
初始化AudioChunk实例
|
初始化AudioChunk实例
|
||||||
"""
|
"""
|
||||||
@ -146,10 +146,10 @@ class AudioChunk:
|
|||||||
self._slice_listener: List[callable] = []
|
self._slice_listener: List[callable] = []
|
||||||
|
|
||||||
def get_audio_binary(
|
def get_audio_binary(
|
||||||
self,
|
self,
|
||||||
binary_name: Optional[str] = None,
|
binary_name: Optional[str] = None,
|
||||||
audio_config: Optional[AudioBinary_Config] = None,
|
audio_config: Optional[AudioBinary_Config] = None,
|
||||||
) -> AudioBinary:
|
) -> AudioBinary:
|
||||||
"""
|
"""
|
||||||
获取音频数据块
|
获取音频数据块
|
||||||
参数:
|
参数:
|
||||||
|
@ -10,116 +10,79 @@ import argparse
|
|||||||
def parse_args():
|
def parse_args():
|
||||||
"""
|
"""
|
||||||
解析命令行参数
|
解析命令行参数
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
argparse.Namespace: 解析后的参数对象
|
argparse.Namespace: 解析后的参数对象
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")
|
parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")
|
||||||
|
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host",
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
default="0.0.0.0",
|
default="0.0.0.0",
|
||||||
help="服务器主机地址,例如:localhost, 0.0.0.0"
|
help="服务器主机地址,例如:localhost, 0.0.0.0",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
default=10095,
|
|
||||||
help="WebSocket服务器端口"
|
|
||||||
)
|
|
||||||
|
|
||||||
# SSL配置
|
# SSL配置
|
||||||
parser.add_argument(
|
parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
|
||||||
"--certfile",
|
parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="SSL证书文件路径"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--keyfile",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="SSL密钥文件路径"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ASR模型配置
|
# ASR模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model",
|
"--asr_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
help="离线ASR模型(从ModelScope获取)"
|
help="离线ASR模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model_revision",
|
"--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="离线ASR模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 在线ASR模型配置
|
# 在线ASR模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model_online",
|
"--asr_model_online",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
||||||
help="在线ASR模型(从ModelScope获取)"
|
help="在线ASR模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--asr_model_online_revision",
|
"--asr_model_online_revision",
|
||||||
type=str,
|
type=str,
|
||||||
default="v2.0.4",
|
default="v2.0.4",
|
||||||
help="在线ASR模型版本"
|
help="在线ASR模型版本",
|
||||||
)
|
)
|
||||||
|
|
||||||
# VAD模型配置
|
# VAD模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad_model",
|
"--vad_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
help="VAD语音活动检测模型(从ModelScope获取)"
|
help="VAD语音活动检测模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad_model_revision",
|
"--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="VAD模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 标点符号模型配置
|
# 标点符号模型配置
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--punc_model",
|
"--punc_model",
|
||||||
type=str,
|
type=str,
|
||||||
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
|
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
|
||||||
help="标点符号模型(从ModelScope获取)"
|
help="标点符号模型(从ModelScope获取)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--punc_model_revision",
|
"--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="标点符号模型版本"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 硬件配置
|
# 硬件配置
|
||||||
|
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量,0表示仅使用CPU")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ngpu",
|
"--device", type=str, default="cuda", help="设备类型:cuda或cpu"
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="GPU数量,0表示仅使用CPU"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")
|
||||||
"--device",
|
|
||||||
type=str,
|
|
||||||
default="cuda",
|
|
||||||
help="设备类型:cuda或cpu"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ncpu",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="CPU核心数"
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -127,4 +90,4 @@ if __name__ == "__main__":
|
|||||||
args = parse_args()
|
args = parse_args()
|
||||||
print("配置参数:")
|
print("配置参数:")
|
||||||
for arg in vars(args):
|
for arg in vars(args):
|
||||||
print(f" {arg}: {getattr(args, arg)}")
|
print(f" {arg}: {getattr(args, arg)}")
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .vad_functor import VADFunctor
|
from .vad_functor import VADFunctor
|
||||||
from .base import FunctorFactory
|
from .base import FunctorFactory
|
||||||
|
|
||||||
__all__ = ["VADFunctor", "FunctorFactory"]
|
__all__ = ["VADFunctor", "FunctorFactory"]
|
||||||
|
@ -2,8 +2,9 @@
|
|||||||
ASRFunctor
|
ASRFunctor
|
||||||
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
负责对音频片段进行ASR处理, 以ASR_Result进行callback
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.functor.base import BaseFunctor
|
from src.functor.base import BaseFunctor
|
||||||
from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result
|
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
from queue import Queue, Empty
|
from queue import Queue, Empty
|
||||||
import threading
|
import threading
|
||||||
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
|
|||||||
|
|
||||||
logger = get_module_logger(__name__)
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ASRFunctor(BaseFunctor):
|
class ASRFunctor(BaseFunctor):
|
||||||
"""
|
"""
|
||||||
ASRFunctor
|
ASRFunctor
|
||||||
@ -51,26 +53,26 @@ class ASRFunctor(BaseFunctor):
|
|||||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def set_input_queue(self, queue: Queue) -> None:
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
"""
|
"""
|
||||||
设置监听的输入消息队列
|
设置监听的输入消息队列
|
||||||
"""
|
"""
|
||||||
self._input_queue = queue
|
self._input_queue = queue
|
||||||
|
|
||||||
def set_model(self, model: dict) -> None:
|
def set_model(self, model: dict) -> None:
|
||||||
"""
|
"""
|
||||||
设置推理模型
|
设置推理模型
|
||||||
"""
|
"""
|
||||||
self._model = model
|
self._model = model
|
||||||
|
|
||||||
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
"""
|
"""
|
||||||
设置音频配置
|
设置音频配置
|
||||||
"""
|
"""
|
||||||
self._audio_config = audio_config
|
self._audio_config = audio_config
|
||||||
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
|
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
|
||||||
|
|
||||||
def add_callback(self, callback: Callable) -> None:
|
def add_callback(self, callback: Callable) -> None:
|
||||||
"""
|
"""
|
||||||
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||||
@ -78,12 +80,12 @@ class ASRFunctor(BaseFunctor):
|
|||||||
if not isinstance(self._callback, list):
|
if not isinstance(self._callback, list):
|
||||||
self._callback = []
|
self._callback = []
|
||||||
self._callback.append(callback)
|
self._callback.append(callback)
|
||||||
|
|
||||||
def _do_callback(self, result: List[str]) -> None:
|
def _do_callback(self, result: List[str]) -> None:
|
||||||
"""
|
"""
|
||||||
回调函数
|
回调函数
|
||||||
"""
|
"""
|
||||||
text = result[0]['text'].replace(" ", "")
|
text = result[0]["text"].replace(" ", "")
|
||||||
for callback in self._callback:
|
for callback in self._callback:
|
||||||
callback(text)
|
callback(text)
|
||||||
|
|
||||||
@ -98,7 +100,7 @@ class ASRFunctor(BaseFunctor):
|
|||||||
hotwords=self._hotwords,
|
hotwords=self._hotwords,
|
||||||
)
|
)
|
||||||
self._do_callback(result)
|
self._do_callback(result)
|
||||||
|
|
||||||
def _run(self) -> None:
|
def _run(self) -> None:
|
||||||
"""
|
"""
|
||||||
线程运行逻辑
|
线程运行逻辑
|
||||||
@ -132,7 +134,7 @@ class ASRFunctor(BaseFunctor):
|
|||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
return self._thread
|
return self._thread
|
||||||
|
|
||||||
def _pre_check(self) -> bool:
|
def _pre_check(self) -> bool:
|
||||||
"""
|
"""
|
||||||
预检查
|
预检查
|
||||||
@ -146,7 +148,7 @@ class ASRFunctor(BaseFunctor):
|
|||||||
if self._callback is None:
|
if self._callback is None:
|
||||||
raise ValueError("回调函数未设置")
|
raise ValueError("回调函数未设置")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def stop(self) -> bool:
|
def stop(self) -> bool:
|
||||||
"""
|
"""
|
||||||
停止线程
|
停止线程
|
||||||
@ -157,5 +159,3 @@ class ASRFunctor(BaseFunctor):
|
|||||||
with self._status_lock:
|
with self._status_lock:
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
return not self._thread.is_alive()
|
return not self._thread.is_alive()
|
||||||
|
|
||||||
|
|
@ -4,18 +4,20 @@ Functor基础模块
|
|||||||
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。
|
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。
|
||||||
基类提供了数据处理的基本框架,包括:
|
基类提供了数据处理的基本框架,包括:
|
||||||
- 回调函数管理
|
- 回调函数管理
|
||||||
- 模型配置管理
|
- 模型配置管理
|
||||||
- 线程运行控制
|
- 线程运行控制
|
||||||
|
|
||||||
主要类:
|
主要类:
|
||||||
BaseFunctor: Functor抽象类
|
BaseFunctor: Functor抽象类
|
||||||
FunctorFactory: Functor工厂类
|
FunctorFactory: Functor工厂类
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
|
||||||
class BaseFunctor(ABC):
|
class BaseFunctor(ABC):
|
||||||
"""
|
"""
|
||||||
Functor抽象类
|
Functor抽象类
|
||||||
@ -27,9 +29,7 @@ class BaseFunctor(ABC):
|
|||||||
_model (dict): 存储模型相关的配置和实例
|
_model (dict): 存储模型相关的配置和实例
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self):
|
||||||
self
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
初始化函数器
|
初始化函数器
|
||||||
|
|
||||||
@ -38,8 +38,8 @@ class BaseFunctor(ABC):
|
|||||||
model (dict): 模型相关的配置和实例
|
model (dict): 模型相关的配置和实例
|
||||||
"""
|
"""
|
||||||
self._callback: List[Callable] = []
|
self._callback: List[Callable] = []
|
||||||
self._model: dict = {}
|
self._model: dict = {}
|
||||||
# flag
|
# flag
|
||||||
self._is_running: bool = False
|
self._is_running: bool = False
|
||||||
self._stop_event: bool = False
|
self._stop_event: bool = False
|
||||||
# 状态锁
|
# 状态锁
|
||||||
@ -91,7 +91,7 @@ class BaseFunctor(ABC):
|
|||||||
返回:
|
返回:
|
||||||
线程实例
|
线程实例
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _pre_check(self):
|
def _pre_check(self):
|
||||||
"""
|
"""
|
||||||
@ -111,13 +111,12 @@ class BaseFunctor(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FunctorFactory:
|
class FunctorFactory:
|
||||||
"""
|
"""
|
||||||
Functor工厂类
|
Functor工厂类
|
||||||
|
|
||||||
该工厂类负责创建和配置Functor实例
|
该工厂类负责创建和配置Functor实例
|
||||||
|
|
||||||
主要方法:
|
主要方法:
|
||||||
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
|
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
|
||||||
创建并配置Functor实例
|
创建并配置Functor实例
|
||||||
@ -138,58 +137,55 @@ class FunctorFactory:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if functor_name == "vad":
|
if functor_name == "vad":
|
||||||
return cls._make_vadfunctor(config = config,models = models)
|
return cls._make_vadfunctor(config=config, models=models)
|
||||||
elif functor_name == "asr":
|
elif functor_name == "asr":
|
||||||
return cls._make_asrfunctor(config = config,models = models)
|
return cls._make_asrfunctor(config=config, models=models)
|
||||||
elif functor_name == "spk":
|
elif functor_name == "spk":
|
||||||
return cls._make_spkfunctor(config = config,models = models)
|
return cls._make_spkfunctor(config=config, models=models)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"不支持的Functor类型: {functor_name}")
|
raise ValueError(f"不支持的Functor类型: {functor_name}")
|
||||||
|
|
||||||
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
|
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
"""
|
"""
|
||||||
创建VAD Functor实例
|
创建VAD Functor实例
|
||||||
"""
|
"""
|
||||||
from src.functor.vad_functor import VADFunctor
|
from src.functor.vad_functor import VADFunctor
|
||||||
|
|
||||||
audio_config = config["audio_config"]
|
audio_config = config["audio_config"]
|
||||||
model = {
|
model = {"vad": models["vad"]}
|
||||||
"vad": models["vad"]
|
|
||||||
}
|
|
||||||
|
|
||||||
vad_functor = VADFunctor()
|
vad_functor = VADFunctor()
|
||||||
vad_functor.set_audio_config(audio_config)
|
vad_functor.set_audio_config(audio_config)
|
||||||
vad_functor.set_model(model)
|
vad_functor.set_model(model)
|
||||||
|
|
||||||
return vad_functor
|
return vad_functor
|
||||||
|
|
||||||
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
|
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
"""
|
"""
|
||||||
创建ASR Functor实例
|
创建ASR Functor实例
|
||||||
"""
|
"""
|
||||||
from src.functor.asr_functor import ASRFunctor
|
from src.functor.asr_functor import ASRFunctor
|
||||||
|
|
||||||
audio_config = config["audio_config"]
|
audio_config = config["audio_config"]
|
||||||
model = {
|
model = {"asr": models["asr"]}
|
||||||
"asr": models["asr"]
|
|
||||||
}
|
|
||||||
|
|
||||||
asr_functor = ASRFunctor()
|
asr_functor = ASRFunctor()
|
||||||
asr_functor.set_audio_config(audio_config)
|
asr_functor.set_audio_config(audio_config)
|
||||||
asr_functor.set_model(model)
|
asr_functor.set_model(model)
|
||||||
|
|
||||||
return asr_functor
|
return asr_functor
|
||||||
|
|
||||||
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
|
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
|
||||||
"""
|
"""
|
||||||
创建SPK Functor实例
|
创建SPK Functor实例
|
||||||
"""
|
"""
|
||||||
from src.functor.spk_functor import SPKFunctor
|
from src.functor.spk_functor import SPKFunctor
|
||||||
|
|
||||||
audio_config = config["audio_config"]
|
audio_config = config["audio_config"]
|
||||||
model = {
|
model = {"spk": models["spk"]}
|
||||||
"spk": models["spk"]
|
|
||||||
}
|
|
||||||
|
|
||||||
spk_functor = SPKFunctor()
|
spk_functor = SPKFunctor()
|
||||||
spk_functor.set_audio_config(audio_config)
|
spk_functor.set_audio_config(audio_config)
|
||||||
spk_functor.set_model(model)
|
spk_functor.set_model(model)
|
||||||
|
|
||||||
return spk_functor
|
return spk_functor
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
SpkFunctor
|
SpkFunctor
|
||||||
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.functor.base import BaseFunctor
|
from src.functor.base import BaseFunctor
|
||||||
from src.models import AudioBinary_Config, VAD_Functor_result
|
from src.models import AudioBinary_Config, VAD_Functor_result
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
|
|||||||
|
|
||||||
logger = get_module_logger(__name__)
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SPKFunctor(BaseFunctor):
|
class SPKFunctor(BaseFunctor):
|
||||||
"""
|
"""
|
||||||
SPKFunctor
|
SPKFunctor
|
||||||
@ -33,25 +35,24 @@ class SPKFunctor(BaseFunctor):
|
|||||||
self._input_queue: Queue = None # 输入队列
|
self._input_queue: Queue = None # 输入队列
|
||||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||||
|
|
||||||
|
|
||||||
def reset_cache(self) -> None:
|
def reset_cache(self) -> None:
|
||||||
"""
|
"""
|
||||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def set_input_queue(self, queue: Queue) -> None:
|
def set_input_queue(self, queue: Queue) -> None:
|
||||||
"""
|
"""
|
||||||
设置监听的输入消息队列
|
设置监听的输入消息队列
|
||||||
"""
|
"""
|
||||||
self._input_queue = queue
|
self._input_queue = queue
|
||||||
|
|
||||||
def set_model(self, model: dict) -> None:
|
def set_model(self, model: dict) -> None:
|
||||||
"""
|
"""
|
||||||
设置推理模型
|
设置推理模型
|
||||||
"""
|
"""
|
||||||
self._model = model
|
self._model = model
|
||||||
|
|
||||||
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
|
||||||
"""
|
"""
|
||||||
设置音频配置
|
设置音频配置
|
||||||
@ -66,7 +67,7 @@ class SPKFunctor(BaseFunctor):
|
|||||||
if not isinstance(self._callback, list):
|
if not isinstance(self._callback, list):
|
||||||
self._callback = []
|
self._callback = []
|
||||||
self._callback.append(callback)
|
self._callback.append(callback)
|
||||||
|
|
||||||
def _do_callback(self, result: List[str]) -> None:
|
def _do_callback(self, result: List[str]) -> None:
|
||||||
"""
|
"""
|
||||||
回调函数
|
回调函数
|
||||||
@ -83,9 +84,9 @@ class SPKFunctor(BaseFunctor):
|
|||||||
# input=binary_data,
|
# input=binary_data,
|
||||||
# chunk_size=self._audio_config.chunk_size,
|
# chunk_size=self._audio_config.chunk_size,
|
||||||
# )
|
# )
|
||||||
result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}]
|
result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
|
||||||
self._do_callback(result)
|
self._do_callback(result)
|
||||||
|
|
||||||
def _run(self) -> None:
|
def _run(self) -> None:
|
||||||
"""
|
"""
|
||||||
线程运行逻辑
|
线程运行逻辑
|
||||||
@ -108,7 +109,7 @@ class SPKFunctor(BaseFunctor):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("SpkFunctor运行时发生错误: %s", e)
|
logger.error("SpkFunctor运行时发生错误: %s", e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def run(self) -> threading.Thread:
|
def run(self) -> threading.Thread:
|
||||||
"""
|
"""
|
||||||
启动线程
|
启动线程
|
||||||
@ -119,7 +120,7 @@ class SPKFunctor(BaseFunctor):
|
|||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
return self._thread
|
return self._thread
|
||||||
|
|
||||||
def _pre_check(self) -> bool:
|
def _pre_check(self) -> bool:
|
||||||
"""
|
"""
|
||||||
预检查
|
预检查
|
||||||
@ -131,7 +132,7 @@ class SPKFunctor(BaseFunctor):
|
|||||||
if self._callback is None:
|
if self._callback is None:
|
||||||
raise ValueError("回调函数未设置")
|
raise ValueError("回调函数未设置")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def stop(self) -> bool:
|
def stop(self) -> bool:
|
||||||
"""
|
"""
|
||||||
停止线程
|
停止线程
|
||||||
@ -142,5 +143,3 @@ class SPKFunctor(BaseFunctor):
|
|||||||
with self._status_lock:
|
with self._status_lock:
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
return not self._thread.is_alive()
|
return not self._thread.is_alive()
|
||||||
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
|||||||
VADFunctor
|
VADFunctor
|
||||||
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
负责对音频片段进行VAD处理, 以VAD_Result进行callback
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from typing import List, Any, Callable
|
from typing import List, Any, Callable
|
||||||
@ -105,9 +106,7 @@ class VADFunctor(BaseFunctor):
|
|||||||
self._callback = []
|
self._callback = []
|
||||||
self._callback.append(callback)
|
self._callback.append(callback)
|
||||||
|
|
||||||
def _do_callback(
|
def _do_callback(self, result: List[List[int]]) -> None:
|
||||||
self, result: List[List[int]]
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
回调函数
|
回调函数
|
||||||
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice
|
||||||
|
@ -6,34 +6,36 @@
|
|||||||
|
|
||||||
from src.utils.logger import get_module_logger
|
from src.utils.logger import get_module_logger
|
||||||
from typing import Any, Dict, Type, Callable
|
from typing import Any, Dict, Type, Callable
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logger = get_module_logger(__name__, level="INFO")
|
logger = get_module_logger(__name__, level="INFO")
|
||||||
|
|
||||||
|
|
||||||
class AutoAfterMeta(type):
|
class AutoAfterMeta(type):
|
||||||
"""
|
"""
|
||||||
自动调用__after__函数的元类
|
自动调用__after__函数的元类
|
||||||
实现单例模式
|
实现单例模式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instances: Dict[Type, Any] = {} # 存储单例实例
|
_instances: Dict[Type, Any] = {} # 存储单例实例
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(cls, name, bases, attrs):
|
||||||
# 遍历所有属性
|
# 遍历所有属性
|
||||||
for attr_name, attr_value in attrs.items():
|
for attr_name, attr_value in attrs.items():
|
||||||
# 如果是函数且不是以_开头
|
# 如果是函数且不是以_开头
|
||||||
if callable(attr_value) and not attr_name.startswith('__'):
|
if callable(attr_value) and not attr_name.startswith("__"):
|
||||||
# 获取原函数
|
# 获取原函数
|
||||||
original_func = attr_value
|
original_func = attr_value
|
||||||
|
|
||||||
# 创建包装函数
|
# 创建包装函数
|
||||||
def make_wrapper(func):
|
def make_wrapper(func):
|
||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
# 执行原函数
|
# 执行原函数
|
||||||
result = func(self, *args, **kwargs)
|
result = func(self, *args, **kwargs)
|
||||||
|
|
||||||
# 构建_after_函数名
|
# 构建_after_函数名
|
||||||
after_func_name = f"__after__{func.__name__}"
|
after_func_name = f"__after__{func.__name__}"
|
||||||
|
|
||||||
# 检查是否存在对应的_after_函数
|
# 检查是否存在对应的_after_函数
|
||||||
if hasattr(self, after_func_name):
|
if hasattr(self, after_func_name):
|
||||||
after_func = getattr(self, after_func_name)
|
after_func = getattr(self, after_func_name)
|
||||||
@ -43,17 +45,18 @@ class AutoAfterMeta(type):
|
|||||||
after_func()
|
after_func()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"调用{after_func_name}时出错: {e}")
|
logger.error(f"调用{after_func_name}时出错: {e}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
# 替换原函数
|
# 替换原函数
|
||||||
attrs[attr_name] = make_wrapper(original_func)
|
attrs[attr_name] = make_wrapper(original_func)
|
||||||
|
|
||||||
# 创建类
|
# 创建类
|
||||||
new_class = super().__new__(cls, name, bases, attrs)
|
new_class = super().__new__(cls, name, bases, attrs)
|
||||||
return new_class
|
return new_class
|
||||||
|
|
||||||
def __call__(cls, *args, **kwargs):
|
def __call__(cls, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
重写__call__方法实现单例模式
|
重写__call__方法实现单例模式
|
||||||
@ -65,9 +68,10 @@ class AutoAfterMeta(type):
|
|||||||
logger.info(f"创建{cls.__name__}的新实例")
|
logger.info(f"创建{cls.__name__}的新实例")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"返回{cls.__name__}的现有实例")
|
logger.debug(f"返回{cls.__name__}的现有实例")
|
||||||
|
|
||||||
return cls._instances[cls]
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
整体识别的处理逻辑:
|
整体识别的处理逻辑:
|
||||||
1.压入二进制音频信息
|
1.压入二进制音频信息
|
||||||
@ -88,10 +92,12 @@ from src.models import AudioBinary_Config
|
|||||||
from src.models import AudioBinary_Chunk
|
from src.models import AudioBinary_Chunk
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class LogicTrager(metaclass=AutoAfterMeta):
|
class LogicTrager(metaclass=AutoAfterMeta):
|
||||||
"""逻辑触发器类"""
|
"""逻辑触发器类"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
audio_chunk_max_size: int = 1024 * 1024 * 10,
|
audio_chunk_max_size: int = 1024 * 1024 * 10,
|
||||||
audio_config: AudioBinary_Config = None,
|
audio_config: AudioBinary_Config = None,
|
||||||
result_callback: Callable = None,
|
result_callback: Callable = None,
|
||||||
@ -99,46 +105,51 @@ class LogicTrager(metaclass=AutoAfterMeta):
|
|||||||
):
|
):
|
||||||
"""初始化"""
|
"""初始化"""
|
||||||
# 存储音频块
|
# 存储音频块
|
||||||
self._audio_chunk : List[AudioBinary_Chunk] = []
|
self._audio_chunk: List[AudioBinary_Chunk] = []
|
||||||
# 存储二进制数据
|
# 存储二进制数据
|
||||||
self._audio_chunk_binary = b''
|
self._audio_chunk_binary = b""
|
||||||
self._audio_chunk_max_size = audio_chunk_max_size
|
self._audio_chunk_max_size = audio_chunk_max_size
|
||||||
# 音频参数
|
# 音频参数
|
||||||
self._audio_config = audio_config if audio_config is not None else AudioBinary_Config()
|
self._audio_config = (
|
||||||
|
audio_config if audio_config is not None else AudioBinary_Config()
|
||||||
|
)
|
||||||
# 结果队列
|
# 结果队列
|
||||||
self._result_queue = []
|
self._result_queue = []
|
||||||
# 聚合结果回调函数
|
# 聚合结果回调函数
|
||||||
self._aggregate_result_callback = result_callback
|
self._aggregate_result_callback = result_callback
|
||||||
# 组件
|
# 组件
|
||||||
self._vad = VAD(VAD_model = models.get("vad"), audio_config = self._audio_config)
|
self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
|
||||||
self._vad.set_callback(self.push_audio_chunk)
|
self._vad.set_callback(self.push_audio_chunk)
|
||||||
|
|
||||||
|
|
||||||
logger.info("初始化LogicTrager")
|
logger.info("初始化LogicTrager")
|
||||||
|
|
||||||
def push_binary_data(self, chunk: bytes) -> None:
|
def push_binary_data(self, chunk: bytes) -> None:
|
||||||
"""
|
"""
|
||||||
压入音频块至VAD模块
|
压入音频块至VAD模块
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
chunk: 音频数据块
|
chunk: 音频数据块
|
||||||
"""
|
"""
|
||||||
# print("LogicTrager push_binary_data", len(chunk))
|
# print("LogicTrager push_binary_data", len(chunk))
|
||||||
self._vad.push_binary_data(chunk)
|
self._vad.push_binary_data(chunk)
|
||||||
self.__after__push_binary_data()
|
self.__after__push_binary_data()
|
||||||
|
|
||||||
def __after__push_binary_data(self) -> None:
|
def __after__push_binary_data(self) -> None:
|
||||||
"""
|
"""
|
||||||
添加音频块后处理
|
添加音频块后处理
|
||||||
"""
|
"""
|
||||||
# print("LogicTrager __after__push_binary_data")
|
# print("LogicTrager __after__push_binary_data")
|
||||||
self._vad.process_vad_result()
|
self._vad.process_vad_result()
|
||||||
|
|
||||||
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
|
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
|
||||||
"""
|
"""
|
||||||
音频处理
|
音频处理
|
||||||
"""
|
"""
|
||||||
logger.info("LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(chunk.start_time, chunk.end_time, len(chunk.chunk)))
|
logger.info(
|
||||||
|
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
|
||||||
|
chunk.start_time, chunk.end_time, len(chunk.chunk)
|
||||||
|
)
|
||||||
|
)
|
||||||
self._audio_chunk.append(chunk)
|
self._audio_chunk.append(chunk)
|
||||||
|
|
||||||
def __after__push_audio_chunk(self) -> None:
|
def __after__push_audio_chunk(self) -> None:
|
||||||
@ -162,4 +173,4 @@ class LogicTrager(metaclass=AutoAfterMeta):
|
|||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
"""调用函数"""
|
"""调用函数"""
|
||||||
pass
|
pass
|
||||||
|
@ -115,7 +115,7 @@ class ModelLoader:
|
|||||||
self.models = {}
|
self.models = {}
|
||||||
# 加载离线ASR模型
|
# 加载离线ASR模型
|
||||||
# 检查对应键是否存在
|
# 检查对应键是否存在
|
||||||
model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk']
|
model_list = ["asr", "asr_online", "vad", "punc", "spk"]
|
||||||
for model_name in model_list:
|
for model_name in model_list:
|
||||||
name_model = f"{model_name}_model"
|
name_model = f"{model_name}_model"
|
||||||
name_model_revision = f"{model_name}_model_revision"
|
name_model_revision = f"{model_name}_model_revision"
|
||||||
|
@ -1,3 +1,9 @@
|
|||||||
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
|
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
|
||||||
from .vad import VAD_Functor_result
|
from .vad import VAD_Functor_result
|
||||||
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"]
|
|
||||||
|
__all__ = [
|
||||||
|
"AudioBinary_Config",
|
||||||
|
"AudioBinary_data_list",
|
||||||
|
"_AudioBinary_data",
|
||||||
|
"VAD_Functor_result",
|
||||||
|
]
|
||||||
|
@ -8,8 +8,10 @@ logger = get_module_logger(__name__)
|
|||||||
|
|
||||||
binary_data_types = (bytes, numpy.ndarray)
|
binary_data_types = (bytes, numpy.ndarray)
|
||||||
|
|
||||||
|
|
||||||
class AudioBinary_Config(BaseModel):
|
class AudioBinary_Config(BaseModel):
|
||||||
"""二进制音频块配置信息"""
|
"""二进制音频块配置信息"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@ -37,14 +39,16 @@ class AudioBinary_Config(BaseModel):
|
|||||||
"""
|
"""
|
||||||
return int(frame * 1000 / self.sample_rate)
|
return int(frame * 1000 / self.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
class _AudioBinary_data(BaseModel):
|
class _AudioBinary_data(BaseModel):
|
||||||
"""音频数据"""
|
"""音频数据"""
|
||||||
|
|
||||||
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
|
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@validator('binary_data')
|
@validator("binary_data")
|
||||||
def validate_binary_data(cls, v):
|
def validate_binary_data(cls, v):
|
||||||
"""
|
"""
|
||||||
验证音频数据
|
验证音频数据
|
||||||
@ -54,7 +58,11 @@ class _AudioBinary_data(BaseModel):
|
|||||||
binary_data_types: 音频数据
|
binary_data_types: 音频数据
|
||||||
"""
|
"""
|
||||||
if not isinstance(v, (bytes, numpy.ndarray)):
|
if not isinstance(v, (bytes, numpy.ndarray)):
|
||||||
logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v))
|
logger.warning(
|
||||||
|
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
|
||||||
|
cls.__class__.__name__,
|
||||||
|
type(v),
|
||||||
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@ -64,14 +72,18 @@ class _AudioBinary_data(BaseModel):
|
|||||||
int: 音频数据长度
|
int: 音频数据长度
|
||||||
"""
|
"""
|
||||||
return len(self.binary_data)
|
return len(self.binary_data)
|
||||||
|
|
||||||
def __init__(self, binary_data: binary_data_types):
|
def __init__(self, binary_data: binary_data_types):
|
||||||
"""
|
"""
|
||||||
初始化音频数据
|
初始化音频数据
|
||||||
Args:
|
Args:
|
||||||
binary_data: 音频数据
|
binary_data: 音频数据
|
||||||
"""
|
"""
|
||||||
logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data))
|
logger.debug(
|
||||||
|
"[%s]初始化音频数据, 数据类型为%s",
|
||||||
|
self.__class__.__name__,
|
||||||
|
type(binary_data),
|
||||||
|
)
|
||||||
super().__init__(binary_data=binary_data)
|
super().__init__(binary_data=binary_data)
|
||||||
|
|
||||||
def __getitem__(self):
|
def __getitem__(self):
|
||||||
@ -82,9 +94,13 @@ class _AudioBinary_data(BaseModel):
|
|||||||
"""
|
"""
|
||||||
return self.binary_data
|
return self.binary_data
|
||||||
|
|
||||||
|
|
||||||
class AudioBinary_data_list(BaseModel):
|
class AudioBinary_data_list(BaseModel):
|
||||||
"""音频数据列表"""
|
"""音频数据列表"""
|
||||||
binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[])
|
|
||||||
|
binary_data_list: List[_AudioBinary_data] = Field(
|
||||||
|
description="音频数据列表", default=[]
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
@ -118,6 +134,7 @@ class AudioBinary_data_list(BaseModel):
|
|||||||
"""
|
"""
|
||||||
return len(self.binary_data_list)
|
return len(self.binary_data_list)
|
||||||
|
|
||||||
|
|
||||||
# class AudioBinary_Slice(BaseModel):
|
# class AudioBinary_Slice(BaseModel):
|
||||||
# """音频块切片"""
|
# """音频块切片"""
|
||||||
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
|
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
|
||||||
@ -138,4 +155,4 @@ class AudioBinary_data_list(BaseModel):
|
|||||||
# return v
|
# return v
|
||||||
|
|
||||||
# def __call__(self):
|
# def __call__(self):
|
||||||
# return self.target_Binary(self.start_index, self.end_index)
|
# return self.target_Binary(self.start_index, self.end_index)
|
||||||
|
@ -2,23 +2,27 @@ from pydantic import BaseModel, Field, validator
|
|||||||
from typing import List, Optional, Callable, Any
|
from typing import List, Optional, Callable, Any
|
||||||
from .audio import AudioBinary_data_list, _AudioBinary_data
|
from .audio import AudioBinary_data_list, _AudioBinary_data
|
||||||
|
|
||||||
|
|
||||||
class VAD_Functor_result(BaseModel):
|
class VAD_Functor_result(BaseModel):
|
||||||
"""
|
"""
|
||||||
VADFunctor结果
|
VADFunctor结果
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
|
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
|
||||||
audiobinary_index: int = Field(description="音频数据索引")
|
audiobinary_index: int = Field(description="音频数据索引")
|
||||||
audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data")
|
audiobinary_data: _AudioBinary_data = Field(
|
||||||
|
description="音频数据, 指向AudioBinary_data"
|
||||||
|
)
|
||||||
start_time: int = Field(description="开始时间", is_required=True)
|
start_time: int = Field(description="开始时间", is_required=True)
|
||||||
end_time: int = Field(description="结束时间", is_required=True)
|
end_time: int = Field(description="结束时间", is_required=True)
|
||||||
|
|
||||||
@validator('audiobinary_data_list')
|
@validator("audiobinary_data_list")
|
||||||
def validate_audiobinary_data_list(cls, v):
|
def validate_audiobinary_data_list(cls, v):
|
||||||
if not isinstance(v, AudioBinary_data_list):
|
if not isinstance(v, AudioBinary_data_list):
|
||||||
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
|
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('audiobinary_index')
|
@validator("audiobinary_index")
|
||||||
def validate_audiobinary_index(cls, v):
|
def validate_audiobinary_index(cls, v):
|
||||||
if not isinstance(v, int):
|
if not isinstance(v, int):
|
||||||
raise ValueError("audiobinary_index必须是int类型")
|
raise ValueError("audiobinary_index必须是int类型")
|
||||||
@ -26,35 +30,35 @@ class VAD_Functor_result(BaseModel):
|
|||||||
raise ValueError("audiobinary_index必须大于0")
|
raise ValueError("audiobinary_index必须大于0")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('audiobinary_data')
|
@validator("audiobinary_data")
|
||||||
def validate_audiobinary_data(cls, v):
|
def validate_audiobinary_data(cls, v):
|
||||||
if not isinstance(v, _AudioBinary_data):
|
if not isinstance(v, _AudioBinary_data):
|
||||||
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
|
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('start_time')
|
@validator("start_time")
|
||||||
def validate_start_time(cls, v):
|
def validate_start_time(cls, v):
|
||||||
if not isinstance(v, int):
|
if not isinstance(v, int):
|
||||||
raise ValueError("start_time必须是int类型")
|
raise ValueError("start_time必须是int类型")
|
||||||
if v < 0:
|
if v < 0:
|
||||||
raise ValueError("start_time必须大于0")
|
raise ValueError("start_time必须大于0")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('end_time')
|
@validator("end_time")
|
||||||
def validate_end_time(cls, v, values):
|
def validate_end_time(cls, v, values):
|
||||||
if not isinstance(v, int):
|
if not isinstance(v, int):
|
||||||
raise ValueError("end_time必须是int类型")
|
raise ValueError("end_time必须是int类型")
|
||||||
if 'start_time' in values and v <= values['start_time']:
|
if "start_time" in values and v <= values["start_time"]:
|
||||||
raise ValueError("end_time必须大于start_time")
|
raise ValueError("end_time必须大于start_time")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_from_push_data(
|
def create_from_push_data(
|
||||||
cls,
|
cls,
|
||||||
audiobinary_data_list: AudioBinary_data_list,
|
audiobinary_data_list: AudioBinary_data_list,
|
||||||
data: Any,
|
data: Any,
|
||||||
start_time: int,
|
start_time: int,
|
||||||
end_time: int
|
end_time: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
创建VAD片段
|
创建VAD片段
|
||||||
@ -66,7 +70,8 @@ class VAD_Functor_result(BaseModel):
|
|||||||
audiobinary_index=index,
|
audiobinary_index=index,
|
||||||
audiobinary_data=audiobinary_data_list[index],
|
audiobinary_data=audiobinary_data_list[index],
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time)
|
end_time=end_time,
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
@ -78,11 +83,9 @@ class VAD_Functor_result(BaseModel):
|
|||||||
"""
|
"""
|
||||||
字符串展示内容
|
字符串展示内容
|
||||||
"""
|
"""
|
||||||
tostr = f'audiobinary_data_index: {self.audiobinary_index}\n'
|
tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
|
||||||
tostr += f'start_time: {self.start_time}\n'
|
tostr += f"start_time: {self.start_time}\n"
|
||||||
tostr += f'end_time: {self.end_time}\n'
|
tostr += f"end_time: {self.end_time}\n"
|
||||||
tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n'
|
tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
|
||||||
tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n'
|
tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
|
||||||
return tostr
|
return tostr
|
||||||
|
|
||||||
|
|
@ -7,11 +7,13 @@ import threading
|
|||||||
|
|
||||||
logger = get_module_logger(__name__)
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ASRPipeline(PipelineBase):
|
class ASRPipeline(PipelineBase):
|
||||||
"""
|
"""
|
||||||
管道类
|
管道类
|
||||||
实现具体的处理逻辑
|
实现具体的处理逻辑
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
初始化管道
|
初始化管道
|
||||||
@ -91,27 +93,24 @@ class ASRPipeline(PipelineBase):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from src.functor import FunctorFactory
|
from src.functor import FunctorFactory
|
||||||
|
|
||||||
# 加载VAD、asr、spk functor
|
# 加载VAD、asr、spk functor
|
||||||
self._functor_dict["vad"] = FunctorFactory.make_functor(
|
self._functor_dict["vad"] = FunctorFactory.make_functor(
|
||||||
functor_name = "vad",
|
functor_name="vad", config=self._config, models=self._models
|
||||||
config = self._config,
|
|
||||||
models = self._models
|
|
||||||
)
|
)
|
||||||
self._functor_dict["asr"] = FunctorFactory.make_functor(
|
self._functor_dict["asr"] = FunctorFactory.make_functor(
|
||||||
functor_name = "asr",
|
functor_name="asr", config=self._config, models=self._models
|
||||||
config = self._config,
|
|
||||||
models = self._models
|
|
||||||
)
|
)
|
||||||
self._functor_dict["spk"] = FunctorFactory.make_functor(
|
self._functor_dict["spk"] = FunctorFactory.make_functor(
|
||||||
functor_name = "spk",
|
functor_name="spk", config=self._config, models=self._models
|
||||||
config = self._config,
|
|
||||||
models = self._models
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建音频数据存储单元
|
# 创建音频数据存储单元
|
||||||
self._audio_binary_data_list = AudioBinary_data_list()
|
self._audio_binary_data_list = AudioBinary_data_list()
|
||||||
|
|
||||||
self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list)
|
self._functor_dict["vad"].set_audio_binary_data_list(
|
||||||
|
self._audio_binary_data_list
|
||||||
|
)
|
||||||
|
|
||||||
# 初始化子队列
|
# 初始化子队列
|
||||||
self._subqueue_dict["original"] = Queue()
|
self._subqueue_dict["original"] = Queue()
|
||||||
@ -121,7 +120,7 @@ class ASRPipeline(PipelineBase):
|
|||||||
self._subqueue_dict["spkend"] = Queue()
|
self._subqueue_dict["spkend"] = Queue()
|
||||||
|
|
||||||
# 设置子队列的输入队列
|
# 设置子队列的输入队列
|
||||||
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
|
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
|
||||||
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
|
||||||
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
|
||||||
|
|
||||||
@ -134,30 +133,40 @@ class ASRPipeline(PipelineBase):
|
|||||||
"""
|
"""
|
||||||
带回调函数的put
|
带回调函数的put
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def put_with_check(data: Any) -> None:
|
def put_with_check(data: Any) -> None:
|
||||||
queue.put(data)
|
queue.put(data)
|
||||||
callback(data)
|
callback(data)
|
||||||
|
|
||||||
return put_with_check
|
return put_with_check
|
||||||
|
|
||||||
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
|
self._functor_dict["asr"].add_callback(
|
||||||
self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
|
put_with_checkcallback(
|
||||||
|
self._subqueue_dict["asrend"], self._check_result
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._functor_dict["spk"].add_callback(
|
||||||
|
put_with_checkcallback(
|
||||||
|
self._subqueue_dict["spkend"], self._check_result
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
|
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
|
||||||
|
|
||||||
def _check_result(self, result: Any) -> None:
|
def _check_result(self, result: Any) -> None:
|
||||||
"""
|
"""
|
||||||
检查结果
|
检查结果
|
||||||
"""
|
"""
|
||||||
# 若asr和spk队列中都有数据,则合并数据
|
# 若asr和spk队列中都有数据,则合并数据
|
||||||
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize():
|
if (
|
||||||
|
self._subqueue_dict["asrend"].qsize()
|
||||||
|
& self._subqueue_dict["spkend"].qsize()
|
||||||
|
):
|
||||||
asr_data = self._subqueue_dict["asrend"].get()
|
asr_data = self._subqueue_dict["asrend"].get()
|
||||||
spk_data = self._subqueue_dict["spkend"].get()
|
spk_data = self._subqueue_dict["spkend"].get()
|
||||||
# 合并数据
|
# 合并数据
|
||||||
result = {
|
result = {"asr_data": asr_data, "spk_data": spk_data}
|
||||||
"asr_data": asr_data,
|
|
||||||
"spk_data": spk_data
|
|
||||||
}
|
|
||||||
# 通知回调函数
|
# 通知回调函数
|
||||||
self._notify_callbacks(result)
|
self._notify_callbacks(result)
|
||||||
|
|
||||||
@ -211,13 +220,13 @@ class ASRPipeline(PipelineBase):
|
|||||||
logger.info("收到结束信号,管道准备停止")
|
logger.info("收到结束信号,管道准备停止")
|
||||||
self._input_queue.task_done() # 标记结束信号已处理
|
self._input_queue.task_done() # 标记结束信号已处理
|
||||||
break
|
break
|
||||||
|
|
||||||
# 处理数据
|
# 处理数据
|
||||||
self._process(data)
|
self._process(data)
|
||||||
|
|
||||||
# 标记任务完成
|
# 标记任务完成
|
||||||
self._input_queue.task_done()
|
self._input_queue.task_done()
|
||||||
|
|
||||||
except Empty:
|
except Empty:
|
||||||
# 队列获取超时,继续等待
|
# 队列获取超时,继续等待
|
||||||
continue
|
continue
|
||||||
@ -237,7 +246,7 @@ class ASRPipeline(PipelineBase):
|
|||||||
"""
|
"""
|
||||||
# 子类实现具体的处理逻辑
|
# 子类实现具体的处理逻辑
|
||||||
self._subqueue_dict["original"].put(data)
|
self._subqueue_dict["original"].put(data)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""
|
"""
|
||||||
停止管道
|
停止管道
|
||||||
|
@ -8,11 +8,13 @@ import time
|
|||||||
# 配置日志
|
# 配置日志
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PipelineBase(ABC):
|
class PipelineBase(ABC):
|
||||||
"""
|
"""
|
||||||
管道基类
|
管道基类
|
||||||
定义了管道的基本接口和通用功能
|
定义了管道的基本接口和通用功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_queue: Optional[Queue] = None):
|
def __init__(self, input_queue: Optional[Queue] = None):
|
||||||
"""
|
"""
|
||||||
初始化管道
|
初始化管道
|
||||||
@ -93,7 +95,7 @@ class PipelineBase(ABC):
|
|||||||
if self._thread and self._thread.is_alive():
|
if self._thread and self._thread.is_alive():
|
||||||
timeout = timeout if timeout is not None else self._stop_timeout
|
timeout = timeout if timeout is not None else self._stop_timeout
|
||||||
self._thread.join(timeout=timeout)
|
self._thread.join(timeout=timeout)
|
||||||
|
|
||||||
# 检查是否成功停止
|
# 检查是否成功停止
|
||||||
if self._thread.is_alive():
|
if self._thread.is_alive():
|
||||||
logger.warning(f"管道停止超时({timeout}秒),强制终止")
|
logger.warning(f"管道停止超时({timeout}秒),强制终止")
|
||||||
@ -101,7 +103,7 @@ class PipelineBase(ABC):
|
|||||||
else:
|
else:
|
||||||
logger.info("管道已成功停止")
|
logger.info("管道已成功停止")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def force_stop(self) -> None:
|
def force_stop(self) -> None:
|
||||||
@ -115,14 +117,35 @@ class PipelineBase(ABC):
|
|||||||
# 注意:Python的线程无法被强制终止,这里只是设置标志
|
# 注意:Python的线程无法被强制终止,这里只是设置标志
|
||||||
# 实际终止需要依赖操作系统的进程管理
|
# 实际终止需要依赖操作系统的进程管理
|
||||||
|
|
||||||
|
|
||||||
class PipelineFactory:
|
class PipelineFactory:
|
||||||
"""
|
"""
|
||||||
管道工厂类
|
管道工厂类
|
||||||
用于创建管道实例
|
用于创建管道实例
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
|
||||||
def create_pipeline(pipeline_name: str) -> Any:
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline:
|
||||||
|
"""
|
||||||
|
创建ASR管道实例
|
||||||
|
"""
|
||||||
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
pipeline = ASRPipeline()
|
||||||
|
pipeline.set_config(kwargs["config"])
|
||||||
|
pipeline.set_models(kwargs["models"])
|
||||||
|
pipeline.set_audio_binary(kwargs["audio_binary"])
|
||||||
|
pipeline.set_input_queue(kwargs["input_queue"])
|
||||||
|
pipeline.add_callback(kwargs["callback"])
|
||||||
|
pipeline.bake()
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
创建管道实例
|
创建管道实例
|
||||||
"""
|
"""
|
||||||
pass
|
if pipeline_name == "ASRpipeline":
|
||||||
|
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的管道类型: {pipeline_name}")
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ class STTRunner(RunnerBase):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
|
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
|
||||||
self._stop_timeout,
|
self._stop_timeout,
|
||||||
self._input_queue.qsize()
|
self._input_queue.qsize(),
|
||||||
)
|
)
|
||||||
success = False
|
success = False
|
||||||
break
|
break
|
||||||
@ -188,7 +188,7 @@ class STTRunner(RunnerBase):
|
|||||||
"错误堆栈:\n%s",
|
"错误堆栈:\n%s",
|
||||||
error_type,
|
error_type,
|
||||||
error_msg,
|
error_msg,
|
||||||
error_traceback
|
error_traceback,
|
||||||
)
|
)
|
||||||
success = False
|
success = False
|
||||||
|
|
||||||
@ -198,7 +198,7 @@ class STTRunner(RunnerBase):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
|
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
|
||||||
self._input_queue.qsize(),
|
self._input_queue.qsize(),
|
||||||
self._input_queue.empty()
|
self._input_queue.empty(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
108
src/server.py
108
src/server.py
@ -43,7 +43,7 @@ async def clear_websocket():
|
|||||||
async def ws_serve(websocket, path):
|
async def ws_serve(websocket, path):
|
||||||
"""
|
"""
|
||||||
WebSocket服务主函数,处理客户端连接和消息
|
WebSocket服务主函数,处理客户端连接和消息
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
websocket: WebSocket连接对象
|
websocket: WebSocket连接对象
|
||||||
path: 连接路径
|
path: 连接路径
|
||||||
@ -51,13 +51,13 @@ async def ws_serve(websocket, path):
|
|||||||
frames = [] # 存储所有音频帧
|
frames = [] # 存储所有音频帧
|
||||||
frames_asr = [] # 存储用于离线ASR的音频帧
|
frames_asr = [] # 存储用于离线ASR的音频帧
|
||||||
frames_asr_online = [] # 存储用于在线ASR的音频帧
|
frames_asr_online = [] # 存储用于在线ASR的音频帧
|
||||||
|
|
||||||
global websocket_users
|
global websocket_users
|
||||||
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
|
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
|
||||||
|
|
||||||
# 添加到用户集合
|
# 添加到用户集合
|
||||||
websocket_users.add(websocket)
|
websocket_users.add(websocket)
|
||||||
|
|
||||||
# 初始化连接状态
|
# 初始化连接状态
|
||||||
websocket.status_dict_asr = {}
|
websocket.status_dict_asr = {}
|
||||||
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
|
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
|
||||||
@ -66,15 +66,15 @@ async def ws_serve(websocket, path):
|
|||||||
websocket.chunk_interval = 10
|
websocket.chunk_interval = 10
|
||||||
websocket.vad_pre_idx = 0
|
websocket.vad_pre_idx = 0
|
||||||
websocket.is_speaking = True # 默认用户正在说话
|
websocket.is_speaking = True # 默认用户正在说话
|
||||||
|
|
||||||
# 语音检测状态
|
# 语音检测状态
|
||||||
speech_start = False
|
speech_start = False
|
||||||
speech_end_i = -1
|
speech_end_i = -1
|
||||||
|
|
||||||
# 初始化配置
|
# 初始化配置
|
||||||
websocket.wav_name = "microphone"
|
websocket.wav_name = "microphone"
|
||||||
websocket.mode = "2pass" # 默认使用两阶段识别模式
|
websocket.mode = "2pass" # 默认使用两阶段识别模式
|
||||||
|
|
||||||
print("新用户已连接", flush=True)
|
print("新用户已连接", flush=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -84,11 +84,13 @@ async def ws_serve(websocket, path):
|
|||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
try:
|
try:
|
||||||
messagejson = json.loads(message)
|
messagejson = json.loads(message)
|
||||||
|
|
||||||
# 更新各种配置参数
|
# 更新各种配置参数
|
||||||
if "is_speaking" in messagejson:
|
if "is_speaking" in messagejson:
|
||||||
websocket.is_speaking = messagejson["is_speaking"]
|
websocket.is_speaking = messagejson["is_speaking"]
|
||||||
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
|
websocket.status_dict_asr_online["is_final"] = (
|
||||||
|
not websocket.is_speaking
|
||||||
|
)
|
||||||
if "chunk_interval" in messagejson:
|
if "chunk_interval" in messagejson:
|
||||||
websocket.chunk_interval = messagejson["chunk_interval"]
|
websocket.chunk_interval = messagejson["chunk_interval"]
|
||||||
if "wav_name" in messagejson:
|
if "wav_name" in messagejson:
|
||||||
@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
|
|||||||
chunk_size = messagejson["chunk_size"]
|
chunk_size = messagejson["chunk_size"]
|
||||||
if isinstance(chunk_size, str):
|
if isinstance(chunk_size, str):
|
||||||
chunk_size = chunk_size.split(",")
|
chunk_size = chunk_size.split(",")
|
||||||
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
|
websocket.status_dict_asr_online["chunk_size"] = [
|
||||||
|
int(x) for x in chunk_size
|
||||||
|
]
|
||||||
if "encoder_chunk_look_back" in messagejson:
|
if "encoder_chunk_look_back" in messagejson:
|
||||||
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
|
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
|
||||||
|
messagejson["encoder_chunk_look_back"]
|
||||||
|
)
|
||||||
if "decoder_chunk_look_back" in messagejson:
|
if "decoder_chunk_look_back" in messagejson:
|
||||||
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
|
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
|
||||||
|
messagejson["decoder_chunk_look_back"]
|
||||||
|
)
|
||||||
if "hotword" in messagejson:
|
if "hotword" in messagejson:
|
||||||
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
|
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
|
||||||
if "mode" in messagejson:
|
if "mode" in messagejson:
|
||||||
@ -111,11 +119,17 @@ async def ws_serve(websocket, path):
|
|||||||
|
|
||||||
# 根据chunk_interval更新VAD的chunk_size
|
# 根据chunk_interval更新VAD的chunk_size
|
||||||
websocket.status_dict_vad["chunk_size"] = int(
|
websocket.status_dict_vad["chunk_size"] = int(
|
||||||
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
|
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
|
||||||
|
* 60
|
||||||
|
/ websocket.chunk_interval
|
||||||
)
|
)
|
||||||
|
|
||||||
# 处理音频数据
|
# 处理音频数据
|
||||||
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
|
if (
|
||||||
|
len(frames_asr_online) > 0
|
||||||
|
or len(frames_asr) >= 0
|
||||||
|
or not isinstance(message, str)
|
||||||
|
):
|
||||||
if not isinstance(message, str): # 二进制音频数据
|
if not isinstance(message, str): # 二进制音频数据
|
||||||
# 添加到帧缓冲区
|
# 添加到帧缓冲区
|
||||||
frames.append(message)
|
frames.append(message)
|
||||||
@ -125,10 +139,12 @@ async def ws_serve(websocket, path):
|
|||||||
# 处理在线ASR
|
# 处理在线ASR
|
||||||
frames_asr_online.append(message)
|
frames_asr_online.append(message)
|
||||||
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
|
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
|
||||||
|
|
||||||
# 达到chunk_interval或最终帧时处理在线ASR
|
# 达到chunk_interval或最终帧时处理在线ASR
|
||||||
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
|
if (
|
||||||
websocket.status_dict_asr_online["is_final"]):
|
len(frames_asr_online) % websocket.chunk_interval == 0
|
||||||
|
or websocket.status_dict_asr_online["is_final"]
|
||||||
|
):
|
||||||
if websocket.mode == "2pass" or websocket.mode == "online":
|
if websocket.mode == "2pass" or websocket.mode == "online":
|
||||||
audio_in = b"".join(frames_asr_online)
|
audio_in = b"".join(frames_asr_online)
|
||||||
try:
|
try:
|
||||||
@ -136,26 +152,32 @@ async def ws_serve(websocket, path):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"在线ASR处理错误: {e}")
|
print(f"在线ASR处理错误: {e}")
|
||||||
frames_asr_online = []
|
frames_asr_online = []
|
||||||
|
|
||||||
# 如果检测到语音开始,收集帧用于离线ASR
|
# 如果检测到语音开始,收集帧用于离线ASR
|
||||||
if speech_start:
|
if speech_start:
|
||||||
frames_asr.append(message)
|
frames_asr.append(message)
|
||||||
|
|
||||||
# VAD处理 - 语音活动检测
|
# VAD处理 - 语音活动检测
|
||||||
try:
|
try:
|
||||||
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
|
speech_start_i, speech_end_i = await asr_service.async_vad(
|
||||||
|
websocket, message
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"VAD处理错误: {e}")
|
print(f"VAD处理错误: {e}")
|
||||||
|
|
||||||
# 检测到语音开始
|
# 检测到语音开始
|
||||||
if speech_start_i != -1:
|
if speech_start_i != -1:
|
||||||
speech_start = True
|
speech_start = True
|
||||||
# 计算开始偏移并收集前面的帧
|
# 计算开始偏移并收集前面的帧
|
||||||
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
|
beg_bias = (
|
||||||
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
|
websocket.vad_pre_idx - speech_start_i
|
||||||
|
) // duration_ms
|
||||||
|
frames_pre = (
|
||||||
|
frames[-beg_bias:] if beg_bias < len(frames) else frames
|
||||||
|
)
|
||||||
frames_asr = []
|
frames_asr = []
|
||||||
frames_asr.extend(frames_pre)
|
frames_asr.extend(frames_pre)
|
||||||
|
|
||||||
# 处理离线ASR (语音结束或用户停止说话)
|
# 处理离线ASR (语音结束或用户停止说话)
|
||||||
if speech_end_i != -1 or not websocket.is_speaking:
|
if speech_end_i != -1 or not websocket.is_speaking:
|
||||||
if websocket.mode == "2pass" or websocket.mode == "offline":
|
if websocket.mode == "2pass" or websocket.mode == "offline":
|
||||||
@ -164,13 +186,13 @@ async def ws_serve(websocket, path):
|
|||||||
await asr_service.async_asr(websocket, audio_in)
|
await asr_service.async_asr(websocket, audio_in)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"离线ASR处理错误: {e}")
|
print(f"离线ASR处理错误: {e}")
|
||||||
|
|
||||||
# 重置状态
|
# 重置状态
|
||||||
frames_asr = []
|
frames_asr = []
|
||||||
speech_start = False
|
speech_start = False
|
||||||
frames_asr_online = []
|
frames_asr_online = []
|
||||||
websocket.status_dict_asr_online["cache"] = {}
|
websocket.status_dict_asr_online["cache"] = {}
|
||||||
|
|
||||||
# 如果用户停止说话,完全重置
|
# 如果用户停止说话,完全重置
|
||||||
if not websocket.is_speaking:
|
if not websocket.is_speaking:
|
||||||
websocket.vad_pre_idx = 0
|
websocket.vad_pre_idx = 0
|
||||||
@ -193,34 +215,34 @@ async def ws_serve(websocket, path):
|
|||||||
def start_server(args, asr_service_instance):
|
def start_server(args, asr_service_instance):
|
||||||
"""
|
"""
|
||||||
启动WebSocket服务器
|
启动WebSocket服务器
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
args: 命令行参数
|
args: 命令行参数
|
||||||
asr_service_instance: ASR服务实例
|
asr_service_instance: ASR服务实例
|
||||||
"""
|
"""
|
||||||
global asr_service
|
global asr_service
|
||||||
asr_service = asr_service_instance
|
asr_service = asr_service_instance
|
||||||
|
|
||||||
# 配置SSL (如果提供了证书)
|
# 配置SSL (如果提供了证书)
|
||||||
if args.certfile and len(args.certfile) > 0:
|
if args.certfile and len(args.certfile) > 0:
|
||||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
|
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
|
||||||
|
|
||||||
start_server = websockets.serve(
|
start_server = websockets.serve(
|
||||||
ws_serve, args.host, args.port,
|
ws_serve,
|
||||||
subprotocols=["binary"],
|
args.host,
|
||||||
ping_interval=None,
|
args.port,
|
||||||
ssl=ssl_context
|
subprotocols=["binary"],
|
||||||
|
ping_interval=None,
|
||||||
|
ssl=ssl_context,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
start_server = websockets.serve(
|
start_server = websockets.serve(
|
||||||
ws_serve, args.host, args.port,
|
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
|
||||||
subprotocols=["binary"],
|
|
||||||
ping_interval=None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
|
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
|
||||||
|
|
||||||
# 启动事件循环
|
# 启动事件循环
|
||||||
asyncio.get_event_loop().run_until_complete(start_server)
|
asyncio.get_event_loop().run_until_complete(start_server)
|
||||||
asyncio.get_event_loop().run_forever()
|
asyncio.get_event_loop().run_forever()
|
||||||
@ -229,14 +251,14 @@ def start_server(args, asr_service_instance):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 解析命令行参数
|
# 解析命令行参数
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 加载模型
|
# 加载模型
|
||||||
print("正在加载模型...")
|
print("正在加载模型...")
|
||||||
models = load_models(args)
|
models = load_models(args)
|
||||||
print("模型加载完成!当前仅支持单个客户端同时连接!")
|
print("模型加载完成!当前仅支持单个客户端同时连接!")
|
||||||
|
|
||||||
# 创建ASR服务
|
# 创建ASR服务
|
||||||
asr_service = ASRService(models)
|
asr_service = ASRService(models)
|
||||||
|
|
||||||
# 启动服务器
|
# 启动服务器
|
||||||
start_server(args, asr_service)
|
start_server(args, asr_service)
|
||||||
|
@ -9,11 +9,11 @@ import json
|
|||||||
|
|
||||||
class ASRService:
|
class ASRService:
|
||||||
"""ASR服务类,封装各种语音识别相关功能"""
|
"""ASR服务类,封装各种语音识别相关功能"""
|
||||||
|
|
||||||
def __init__(self, models):
|
def __init__(self, models):
|
||||||
"""
|
"""
|
||||||
初始化ASR服务
|
初始化ASR服务
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
models: 包含各种预加载模型的字典
|
models: 包含各种预加载模型的字典
|
||||||
"""
|
"""
|
||||||
@ -21,42 +21,41 @@ class ASRService:
|
|||||||
self.model_asr_streaming = models["asr_streaming"]
|
self.model_asr_streaming = models["asr_streaming"]
|
||||||
self.model_vad = models["vad"]
|
self.model_vad = models["vad"]
|
||||||
self.model_punc = models["punc"]
|
self.model_punc = models["punc"]
|
||||||
|
|
||||||
async def async_vad(self, websocket, audio_in):
|
async def async_vad(self, websocket, audio_in):
|
||||||
"""
|
"""
|
||||||
语音活动检测
|
语音活动检测
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
websocket: WebSocket连接
|
websocket: WebSocket连接
|
||||||
audio_in: 二进制音频数据
|
audio_in: 二进制音频数据
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
tuple: (speech_start, speech_end) 语音开始和结束位置
|
tuple: (speech_start, speech_end) 语音开始和结束位置
|
||||||
"""
|
"""
|
||||||
# 使用VAD模型分析音频段
|
# 使用VAD模型分析音频段
|
||||||
segments_result = self.model_vad.generate(
|
segments_result = self.model_vad.generate(
|
||||||
input=audio_in,
|
input=audio_in, **websocket.status_dict_vad
|
||||||
**websocket.status_dict_vad
|
|
||||||
)[0]["value"]
|
)[0]["value"]
|
||||||
|
|
||||||
speech_start = -1
|
speech_start = -1
|
||||||
speech_end = -1
|
speech_end = -1
|
||||||
|
|
||||||
# 解析VAD结果
|
# 解析VAD结果
|
||||||
if len(segments_result) == 0 or len(segments_result) > 1:
|
if len(segments_result) == 0 or len(segments_result) > 1:
|
||||||
return speech_start, speech_end
|
return speech_start, speech_end
|
||||||
|
|
||||||
if segments_result[0][0] != -1:
|
if segments_result[0][0] != -1:
|
||||||
speech_start = segments_result[0][0]
|
speech_start = segments_result[0][0]
|
||||||
if segments_result[0][1] != -1:
|
if segments_result[0][1] != -1:
|
||||||
speech_end = segments_result[0][1]
|
speech_end = segments_result[0][1]
|
||||||
|
|
||||||
return speech_start, speech_end
|
return speech_start, speech_end
|
||||||
|
|
||||||
async def async_asr(self, websocket, audio_in):
|
async def async_asr(self, websocket, audio_in):
|
||||||
"""
|
"""
|
||||||
离线ASR处理
|
离线ASR处理
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
websocket: WebSocket连接
|
websocket: WebSocket连接
|
||||||
audio_in: 二进制音频数据
|
audio_in: 二进制音频数据
|
||||||
@ -64,42 +63,44 @@ class ASRService:
|
|||||||
if len(audio_in) > 0:
|
if len(audio_in) > 0:
|
||||||
# 使用离线ASR模型处理音频
|
# 使用离线ASR模型处理音频
|
||||||
rec_result = self.model_asr.generate(
|
rec_result = self.model_asr.generate(
|
||||||
input=audio_in,
|
input=audio_in, **websocket.status_dict_asr
|
||||||
**websocket.status_dict_asr
|
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# 如果有标点符号模型且识别出文本,则添加标点
|
# 如果有标点符号模型且识别出文本,则添加标点
|
||||||
if self.model_punc is not None and len(rec_result["text"]) > 0:
|
if self.model_punc is not None and len(rec_result["text"]) > 0:
|
||||||
rec_result = self.model_punc.generate(
|
rec_result = self.model_punc.generate(
|
||||||
input=rec_result["text"],
|
input=rec_result["text"], **websocket.status_dict_punc
|
||||||
**websocket.status_dict_punc
|
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# 如果识别出文本,发送到客户端
|
# 如果识别出文本,发送到客户端
|
||||||
if len(rec_result["text"]) > 0:
|
if len(rec_result["text"]) > 0:
|
||||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||||
message = json.dumps({
|
message = json.dumps(
|
||||||
"mode": mode,
|
{
|
||||||
"text": rec_result["text"],
|
"mode": mode,
|
||||||
"wav_name": websocket.wav_name,
|
"text": rec_result["text"],
|
||||||
"is_final": websocket.is_speaking,
|
"wav_name": websocket.wav_name,
|
||||||
})
|
"is_final": websocket.is_speaking,
|
||||||
|
}
|
||||||
|
)
|
||||||
await websocket.send(message)
|
await websocket.send(message)
|
||||||
else:
|
else:
|
||||||
# 如果没有音频数据,发送空文本
|
# 如果没有音频数据,发送空文本
|
||||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||||
message = json.dumps({
|
message = json.dumps(
|
||||||
"mode": mode,
|
{
|
||||||
"text": "",
|
"mode": mode,
|
||||||
"wav_name": websocket.wav_name,
|
"text": "",
|
||||||
"is_final": websocket.is_speaking,
|
"wav_name": websocket.wav_name,
|
||||||
})
|
"is_final": websocket.is_speaking,
|
||||||
|
}
|
||||||
|
)
|
||||||
await websocket.send(message)
|
await websocket.send(message)
|
||||||
|
|
||||||
async def async_asr_online(self, websocket, audio_in):
|
async def async_asr_online(self, websocket, audio_in):
|
||||||
"""
|
"""
|
||||||
在线ASR处理
|
在线ASR处理
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
websocket: WebSocket连接
|
websocket: WebSocket连接
|
||||||
audio_in: 二进制音频数据
|
audio_in: 二进制音频数据
|
||||||
@ -107,21 +108,24 @@ class ASRService:
|
|||||||
if len(audio_in) > 0:
|
if len(audio_in) > 0:
|
||||||
# 使用在线ASR模型处理音频
|
# 使用在线ASR模型处理音频
|
||||||
rec_result = self.model_asr_streaming.generate(
|
rec_result = self.model_asr_streaming.generate(
|
||||||
input=audio_in,
|
input=audio_in, **websocket.status_dict_asr_online
|
||||||
**websocket.status_dict_asr_online
|
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理)
|
# 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理)
|
||||||
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
|
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
|
||||||
|
"is_final", False
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
# 如果识别出文本,发送到客户端
|
# 如果识别出文本,发送到客户端
|
||||||
if len(rec_result["text"]):
|
if len(rec_result["text"]):
|
||||||
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
|
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
|
||||||
message = json.dumps({
|
message = json.dumps(
|
||||||
"mode": mode,
|
{
|
||||||
"text": rec_result["text"],
|
"mode": mode,
|
||||||
"wav_name": websocket.wav_name,
|
"text": rec_result["text"],
|
||||||
"is_final": websocket.is_speaking,
|
"wav_name": websocket.wav_name,
|
||||||
})
|
"is_final": websocket.is_speaking,
|
||||||
await websocket.send(message)
|
}
|
||||||
|
)
|
||||||
|
await websocket.send(message)
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
from .logger import get_module_logger, setup_root_logger
|
from .logger import get_module_logger, setup_root_logger
|
||||||
|
|
||||||
__all__ = ["get_module_logger", "setup_root_logger"]
|
__all__ = ["get_module_logger", "setup_root_logger"]
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
处理各类音频数据与bytes的转换
|
处理各类音频数据与bytes的转换
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import wave
|
import wave
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
|
||||||
def wav_to_bytes(wav_path: str) -> bytes:
|
def wav_to_bytes(wav_path: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
将WAV文件读取为bytes数据。
|
将WAV文件读取为bytes数据。
|
||||||
@ -14,24 +16,26 @@ def wav_to_bytes(wav_path: str) -> bytes:
|
|||||||
|
|
||||||
返回:
|
返回:
|
||||||
bytes: WAV文件的原始字节数据。
|
bytes: WAV文件的原始字节数据。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
FileNotFoundError: 如果WAV文件不存在。
|
FileNotFoundError: 如果WAV文件不存在。
|
||||||
wave.Error: 如果文件不是有效的WAV文件。
|
wave.Error: 如果文件不是有效的WAV文件。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with wave.open(wav_path, 'rb') as wf:
|
with wave.open(wav_path, "rb") as wf:
|
||||||
# 读取所有帧
|
# 读取所有帧
|
||||||
frames = wf.readframes(wf.getnframes())
|
frames = wf.readframes(wf.getnframes())
|
||||||
return frames
|
return frames
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
|
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
|
||||||
raise FileNotFoundError(f"错误:未找到WAV文件 '{wav_path}'")
|
raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
|
||||||
except wave.Error as e:
|
except wave.Error as e:
|
||||||
raise wave.Error(f"错误:打开或读取WAV文件 '{wav_path}' 失败 - {e}")
|
raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int):
|
def bytes_to_wav(
|
||||||
|
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
将bytes数据写入为WAV文件。
|
将bytes数据写入为WAV文件。
|
||||||
|
|
||||||
@ -41,22 +45,23 @@ def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: in
|
|||||||
nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)。
|
nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)。
|
||||||
sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)。
|
sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)。
|
||||||
framerate (int): 采样率 (例如 44100, 16000)。
|
framerate (int): 采样率 (例如 44100, 16000)。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
wave.Error: 如果写入WAV文件失败。
|
wave.Error: 如果写入WAV文件失败。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with wave.open(wav_path, 'wb') as wf:
|
with wave.open(wav_path, "wb") as wf:
|
||||||
wf.setnchannels(nchannels)
|
wf.setnchannels(nchannels)
|
||||||
wf.setsampwidth(sampwidth)
|
wf.setsampwidth(sampwidth)
|
||||||
wf.setframerate(framerate)
|
wf.setframerate(framerate)
|
||||||
wf.writeframes(bytes_data)
|
wf.writeframes(bytes_data)
|
||||||
except wave.Error as e:
|
except wave.Error as e:
|
||||||
raise wave.Error(f"错误:写入WAV文件 '{wav_path}' 失败 - {e}")
|
raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 捕获其他可能的写入错误
|
# 捕获其他可能的写入错误
|
||||||
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
|
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
|
||||||
|
|
||||||
|
|
||||||
def mp3_to_bytes(mp3_path: str) -> bytes:
|
def mp3_to_bytes(mp3_path: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
将MP3文件转换为bytes数据 (原始PCM数据)。
|
将MP3文件转换为bytes数据 (原始PCM数据)。
|
||||||
@ -66,7 +71,7 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
|
|||||||
|
|
||||||
返回:
|
返回:
|
||||||
bytes: MP3文件解码后的原始PCM字节数据。
|
bytes: MP3文件解码后的原始PCM字节数据。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
FileNotFoundError: 如果MP3文件不存在。
|
FileNotFoundError: 如果MP3文件不存在。
|
||||||
pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码。
|
pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码。
|
||||||
@ -76,12 +81,19 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
|
|||||||
# 获取原始PCM数据
|
# 获取原始PCM数据
|
||||||
return audio.raw_data
|
return audio.raw_data
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise FileNotFoundError(f"错误:未找到MP3文件 '{mp3_path}'")
|
raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
|
||||||
except Exception as e: # pydub 可能抛出多种解码相关的错误
|
except Exception as e: # pydub 可能抛出多种解码相关的错误
|
||||||
raise Exception(f"错误:处理MP3文件 '{mp3_path}' 失败 - {e}")
|
raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"):
|
def bytes_to_mp3(
|
||||||
|
bytes_data: bytes,
|
||||||
|
mp3_path: str,
|
||||||
|
frame_rate: int,
|
||||||
|
channels: int,
|
||||||
|
sample_width: int,
|
||||||
|
bitrate: str = "192k",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
将原始PCM bytes数据转换为MP3文件。
|
将原始PCM bytes数据转换为MP3文件。
|
||||||
|
|
||||||
@ -102,9 +114,9 @@ def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: in
|
|||||||
data=bytes_data,
|
data=bytes_data,
|
||||||
sample_width=sample_width,
|
sample_width=sample_width,
|
||||||
frame_rate=frame_rate,
|
frame_rate=frame_rate,
|
||||||
channels=channels
|
channels=channels,
|
||||||
)
|
)
|
||||||
# 导出为MP3
|
# 导出为MP3
|
||||||
audio.export(mp3_path, format="mp3", bitrate=bitrate)
|
audio.export(mp3_path, format="mp3", bitrate=bitrate)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"错误:转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
|
raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
|
||||||
|
@ -3,6 +3,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(
|
def setup_logger(
|
||||||
name: str = None,
|
name: str = None,
|
||||||
level: str = "INFO",
|
level: str = "INFO",
|
||||||
@ -12,80 +13,79 @@ def setup_logger(
|
|||||||
) -> logging.Logger:
|
) -> logging.Logger:
|
||||||
"""
|
"""
|
||||||
设置并返回一个配置好的logger实例
|
设置并返回一个配置好的logger实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: logger的名称,默认为None(使用root logger)
|
name: logger的名称,默认为None(使用root logger)
|
||||||
level: 日志级别,默认为"INFO"
|
level: 日志级别,默认为"INFO"
|
||||||
log_file: 日志文件路径,默认为None(仅控制台输出)
|
log_file: 日志文件路径,默认为None(仅控制台输出)
|
||||||
log_format: 日志格式
|
log_format: 日志格式
|
||||||
date_format: 日期格式
|
date_format: 日期格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logging.Logger: 配置好的logger实例
|
logging.Logger: 配置好的logger实例
|
||||||
"""
|
"""
|
||||||
# 获取logger实例
|
# 获取logger实例
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
|
|
||||||
# 设置日志级别
|
# 设置日志级别
|
||||||
level = getattr(logging, level.upper())
|
level = getattr(logging, level.upper())
|
||||||
logger.setLevel(level)
|
logger.setLevel(level)
|
||||||
|
|
||||||
print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
|
print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
|
||||||
# 创建格式器
|
# 创建格式器
|
||||||
formatter = logging.Formatter(log_format, date_format)
|
formatter = logging.Formatter(log_format, date_format)
|
||||||
|
|
||||||
# 添加控制台处理器
|
# 添加控制台处理器
|
||||||
console_handler = logging.StreamHandler(sys.stdout)
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
# 如果指定了日志文件,添加文件处理器
|
# 如果指定了日志文件,添加文件处理器
|
||||||
if log_file:
|
if log_file:
|
||||||
# 确保日志目录存在
|
# 确保日志目录存在
|
||||||
log_path = Path(log_file)
|
log_path = Path(log_file)
|
||||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
# 注意:移除了 propagate = False,允许日志传递
|
# 注意:移除了 propagate = False,允许日志传递
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
def setup_root_logger(
|
|
||||||
level: str = "INFO",
|
def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
|
||||||
log_file: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
配置根日志器
|
配置根日志器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
level: 日志级别
|
level: 日志级别
|
||||||
log_file: 日志文件路径
|
log_file: 日志文件路径
|
||||||
"""
|
"""
|
||||||
setup_logger(None, level, log_file)
|
setup_logger(None, level, log_file)
|
||||||
|
|
||||||
|
|
||||||
def get_module_logger(
|
def get_module_logger(
|
||||||
module_name: str,
|
module_name: str,
|
||||||
level: Optional[str] = None, # 改为可选参数
|
level: Optional[str] = None, # 改为可选参数
|
||||||
log_file: Optional[str] = None # 一般不需要单独指定
|
log_file: Optional[str] = None, # 一般不需要单独指定
|
||||||
) -> logging.Logger:
|
) -> logging.Logger:
|
||||||
"""
|
"""
|
||||||
获取模块级别的logger
|
获取模块级别的logger
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module_name: 模块名称,通常传入__name__
|
module_name: 模块名称,通常传入__name__
|
||||||
level: 可选的日志级别,如果不指定则继承父级配置
|
level: 可选的日志级别,如果不指定则继承父级配置
|
||||||
log_file: 可选的日志文件路径,一般不需要指定
|
log_file: 可选的日志文件路径,一般不需要指定
|
||||||
"""
|
"""
|
||||||
logger = logging.getLogger(module_name)
|
logger = logging.getLogger(module_name)
|
||||||
|
|
||||||
# 只有显式指定了level才设置
|
# 只有显式指定了level才设置
|
||||||
if level:
|
if level:
|
||||||
logger.setLevel(getattr(logging, level.upper()))
|
logger.setLevel(getattr(logging, level.upper()))
|
||||||
|
|
||||||
# 只有显式指定了log_file才添加文件处理器
|
# 只有显式指定了log_file才添加文件处理器
|
||||||
if log_file:
|
if log_file:
|
||||||
setup_logger(module_name, level or "INFO", log_file)
|
setup_logger(module_name, level or "INFO", log_file)
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
@ -1,10 +1,15 @@
|
|||||||
from tests.functor.vad_test import test_vad_functor
|
"""
|
||||||
|
测试主函数
|
||||||
|
请在tests目录下创建测试文件, 并在此文件中调用
|
||||||
|
"""
|
||||||
|
|
||||||
from tests.pipeline.asr_test import test_asr_pipeline
|
from tests.pipeline.asr_test import test_asr_pipeline
|
||||||
from src.utils.logger import get_module_logger, setup_root_logger
|
from src.utils.logger import get_module_logger, setup_root_logger
|
||||||
|
|
||||||
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||||
logger = get_module_logger(__name__)
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
# from tests.functor.vad_test import test_vad_functor
|
||||||
# logger.info("开始测试VAD函数器")
|
# logger.info("开始测试VAD函数器")
|
||||||
# test_vad_functor()
|
# test_vad_functor()
|
||||||
|
|
||||||
|
@ -1 +1 @@
|
|||||||
"""FunASR WebSocket服务测试模块"""
|
"""FunASR WebSocket服务测试模块"""
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
Functor测试
|
Functor测试
|
||||||
VAD测试
|
VAD测试
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.functor.vad_functor import VADFunctor
|
from src.functor.vad_functor import VADFunctor
|
||||||
from src.functor.asr_functor import ASRFunctor
|
from src.functor.asr_functor import ASRFunctor
|
||||||
from src.functor.spk_functor import SPKFunctor
|
from src.functor.spk_functor import SPKFunctor
|
||||||
@ -21,6 +22,7 @@ logger = get_module_logger(__name__)
|
|||||||
|
|
||||||
model_loader = ModelLoader()
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
def test_vad_functor():
|
def test_vad_functor():
|
||||||
# 加载模型
|
# 加载模型
|
||||||
args = {
|
args = {
|
||||||
@ -38,9 +40,9 @@ def test_vad_functor():
|
|||||||
chunk_stride=1600,
|
chunk_stride=1600,
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
sample_width=16,
|
sample_width=16,
|
||||||
channels=1
|
channels=1,
|
||||||
)
|
)
|
||||||
chunk_stride = int(audio_config.chunk_size*sample_rate/1000)
|
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||||
audio_config.chunk_stride = chunk_stride
|
audio_config.chunk_stride = chunk_stride
|
||||||
# 创建输入队列
|
# 创建输入队列
|
||||||
input_queue = Queue()
|
input_queue = Queue()
|
||||||
@ -62,9 +64,7 @@ def test_vad_functor():
|
|||||||
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
|
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
|
||||||
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
|
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
|
||||||
# 设置模型
|
# 设置模型
|
||||||
vad_functor.set_model({
|
vad_functor.set_model({"vad": model_loader.models["vad"]})
|
||||||
'vad': model_loader.models['vad']
|
|
||||||
})
|
|
||||||
# 启动VAD函数器
|
# 启动VAD函数器
|
||||||
vad_functor.run()
|
vad_functor.run()
|
||||||
|
|
||||||
@ -77,9 +77,7 @@ def test_vad_functor():
|
|||||||
# 设置回调函数
|
# 设置回调函数
|
||||||
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
|
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
|
||||||
# 设置模型
|
# 设置模型
|
||||||
asr_functor.set_model({
|
asr_functor.set_model({"asr": model_loader.models["asr"]})
|
||||||
'asr': model_loader.models['asr']
|
|
||||||
})
|
|
||||||
# 启动ASR函数器
|
# 启动ASR函数器
|
||||||
asr_functor.run()
|
asr_functor.run()
|
||||||
|
|
||||||
@ -92,23 +90,25 @@ def test_vad_functor():
|
|||||||
# 设置回调函数
|
# 设置回调函数
|
||||||
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
|
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
|
||||||
# 设置模型
|
# 设置模型
|
||||||
spk_functor.set_model({
|
spk_functor.set_model(
|
||||||
# 'spk': model_loader.models['spk']
|
{
|
||||||
'spk': 'fake_spk'
|
# 'spk': model_loader.models['spk']
|
||||||
})
|
"spk": "fake_spk"
|
||||||
|
}
|
||||||
|
)
|
||||||
# 启动SPK函数器
|
# 启动SPK函数器
|
||||||
spk_functor.run()
|
spk_functor.run()
|
||||||
|
|
||||||
|
|
||||||
f_binary = f_data
|
f_binary = f_data
|
||||||
audio_clip_len = 200
|
audio_clip_len = 200
|
||||||
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
|
print(
|
||||||
|
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
|
||||||
|
)
|
||||||
for i in range(0, len(f_binary), audio_clip_len):
|
for i in range(0, len(f_binary), audio_clip_len):
|
||||||
binary_data = f_binary[i:i+audio_clip_len]
|
binary_data = f_binary[i : i + audio_clip_len]
|
||||||
input_queue.put(binary_data)
|
input_queue.put(binary_data)
|
||||||
# 等待VAD函数器结束
|
# 等待VAD函数器结束
|
||||||
|
|
||||||
|
|
||||||
vad_functor.stop()
|
vad_functor.stop()
|
||||||
print("[vad_test] VAD函数器结束")
|
print("[vad_test] VAD函数器结束")
|
||||||
|
|
||||||
@ -119,4 +119,6 @@ def test_vad_functor():
|
|||||||
if OVERWATCH:
|
if OVERWATCH:
|
||||||
for index in range(len(audio_binary_data_list)):
|
for index in range(len(audio_binary_data_list)):
|
||||||
save_path = f"tests/vad_test_output_{index}.wav"
|
save_path = f"tests/vad_test_output_{index}.wav"
|
||||||
soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate)
|
soundfile.write(
|
||||||
|
save_path, audio_binary_data_list[index].binary_data, sample_rate
|
||||||
|
)
|
||||||
|
@ -1,10 +1,21 @@
|
|||||||
|
"""
|
||||||
|
模型使用测试
|
||||||
|
此处主要用于各类调用模型的处理数据与输出格式
|
||||||
|
请在主目录下test_main.py中调用
|
||||||
|
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配。
|
||||||
|
"""
|
||||||
|
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from src.models import VADResponse
|
from src.models import VADResponse
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
|
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
|
||||||
chunk_size = 100 # ms
|
"""
|
||||||
|
在线VAD模型使用
|
||||||
|
"""
|
||||||
|
chunk_size = 100 # ms
|
||||||
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
|
||||||
|
|
||||||
vad_result = VADResponse()
|
vad_result = VADResponse()
|
||||||
@ -16,12 +27,14 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
|
|||||||
chunk_stride = int(chunk_size * sample_rate / 1000)
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
|
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
|
||||||
for i in range(total_chunk_num):
|
for i in range(total_chunk_num):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
|
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||||
is_final = i == total_chunk_num - 1
|
is_final = i == total_chunk_num - 1
|
||||||
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
|
res = model.generate(
|
||||||
|
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
|
||||||
|
)
|
||||||
if len(res[0]["value"]):
|
if len(res[0]["value"]):
|
||||||
vad_result += VADResponse.from_raw(res)
|
vad_result += VADResponse.from_raw(res)
|
||||||
for item in res[0]["value"]:
|
for item in res[0]["value"]:
|
||||||
@ -32,44 +45,64 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
|
|||||||
# print(item)
|
# print(item)
|
||||||
return vad_result
|
return vad_result
|
||||||
|
|
||||||
|
|
||||||
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
|
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
在线VAD模型使用
|
||||||
|
测试LogicTrager
|
||||||
|
在Rebuild版本后LogicTrager中已弃用
|
||||||
|
"""
|
||||||
from src.logic_trager import LogicTrager
|
from src.logic_trager import LogicTrager
|
||||||
import soundfile
|
import soundfile
|
||||||
|
|
||||||
from src.config import parse_args
|
from src.config import parse_args
|
||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# from src.functor.model_loader import load_models
|
# from src.functor.model_loader import load_models
|
||||||
# models = load_models(args)
|
# models = load_models(args)
|
||||||
from src.model_loader import ModelLoader
|
from src.model_loader import ModelLoader
|
||||||
|
|
||||||
models = ModelLoader(args)
|
models = ModelLoader(args)
|
||||||
|
|
||||||
chunk_size = 200 # ms
|
chunk_size = 200 # ms
|
||||||
from src.models import AudioBinary_Config
|
from src.models import AudioBinary_Config
|
||||||
import soundfile
|
import soundfile
|
||||||
|
|
||||||
speech, sample_rate = soundfile.read(file_path)
|
speech, sample_rate = soundfile.read(file_path)
|
||||||
chunk_stride = int(chunk_size * sample_rate / 1000)
|
chunk_stride = int(chunk_size * sample_rate / 1000)
|
||||||
audio_config = AudioBinary_Config(sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size)
|
audio_config = AudioBinary_Config(
|
||||||
|
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
logic_trager = LogicTrager(models=models, audio_config=audio_config)
|
logic_trager = LogicTrager(models=models, audio_config=audio_config)
|
||||||
for i in range(len(speech)//chunk_stride+1):
|
for i in range(len(speech) // chunk_stride + 1):
|
||||||
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
|
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
|
||||||
logic_trager.push_binary_data(speech_chunk)
|
logic_trager.push_binary_data(speech_chunk)
|
||||||
|
|
||||||
# for item in items:
|
# for item in items:
|
||||||
# print(item)
|
# print(item)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
|
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
ASR模型使用
|
||||||
|
离线ASR模型使用
|
||||||
|
"""
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
|
|
||||||
vad_model="fsmn-vad", vad_model_revision="v2.0.4",
|
model = AutoModel(
|
||||||
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
|
model="paraformer-zh",
|
||||||
spk_model="cam++", spk_model_revision="v2.0.2",
|
model_revision="v2.0.4",
|
||||||
spk_mode="vad_segment",
|
vad_model="fsmn-vad",
|
||||||
auto_update=False,
|
vad_model_revision="v2.0.4",
|
||||||
)
|
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
|
||||||
|
spk_model="cam++",
|
||||||
|
spk_model_revision="v2.0.2",
|
||||||
|
spk_mode="vad_segment",
|
||||||
|
auto_update=False,
|
||||||
|
)
|
||||||
|
|
||||||
import soundfile
|
import soundfile
|
||||||
|
|
||||||
@ -80,7 +113,9 @@ def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
|
|||||||
result = model.generate(speech)
|
result = model.generate(speech)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# vad_result = vad_model_use_online("tests/vad_example.wav")
|
# if __name__ == "__main__":
|
||||||
vad_result = vad_model_use_online_logic("tests/vad_example.wav")
|
# 请在主目录下调用test_main.py文件进行测试
|
||||||
# print(vad_result)
|
# vad_result = vad_model_use_online("tests/vad_example.wav")
|
||||||
|
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
|
||||||
|
# print(vad_result)
|
||||||
|
@ -2,7 +2,9 @@
|
|||||||
Pipeline测试
|
Pipeline测试
|
||||||
VAD+ASR+SPK(FAKE)
|
VAD+ASR+SPK(FAKE)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.pipeline.ASRpipeline import ASRPipeline
|
from src.pipeline.ASRpipeline import ASRPipeline
|
||||||
|
from src.pipeline import PipelineFactory
|
||||||
from src.models import AudioBinary_data_list, AudioBinary_Config
|
from src.models import AudioBinary_data_list, AudioBinary_Config
|
||||||
from src.model_loader import ModelLoader
|
from src.model_loader import ModelLoader
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
@ -18,6 +20,7 @@ OVAERWATCH = False
|
|||||||
|
|
||||||
model_loader = ModelLoader()
|
model_loader = ModelLoader()
|
||||||
|
|
||||||
|
|
||||||
def test_asr_pipeline():
|
def test_asr_pipeline():
|
||||||
# 加载模型
|
# 加载模型
|
||||||
args = {
|
args = {
|
||||||
@ -36,9 +39,9 @@ def test_asr_pipeline():
|
|||||||
chunk_stride=1600,
|
chunk_stride=1600,
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
sample_width=16,
|
sample_width=16,
|
||||||
channels=1
|
channels=1,
|
||||||
)
|
)
|
||||||
chunk_stride = int(audio_config.chunk_size*sample_rate/1000)
|
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
|
||||||
audio_config.chunk_stride = chunk_stride
|
audio_config.chunk_stride = chunk_stride
|
||||||
|
|
||||||
# 创建参数Dict
|
# 创建参数Dict
|
||||||
@ -52,29 +55,39 @@ def test_asr_pipeline():
|
|||||||
input_queue = Queue()
|
input_queue = Queue()
|
||||||
|
|
||||||
# 创建Pipeline
|
# 创建Pipeline
|
||||||
asr_pipeline = ASRPipeline()
|
# asr_pipeline = ASRPipeline()
|
||||||
asr_pipeline.set_models(models)
|
# asr_pipeline.set_models(models)
|
||||||
asr_pipeline.set_config(config)
|
# asr_pipeline.set_config(config)
|
||||||
asr_pipeline.set_audio_binary(audio_binary_data_list)
|
# asr_pipeline.set_audio_binary(audio_binary_data_list)
|
||||||
asr_pipeline.set_input_queue(input_queue)
|
# asr_pipeline.set_input_queue(input_queue)
|
||||||
asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
|
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
|
||||||
asr_pipeline.bake()
|
# asr_pipeline.bake()
|
||||||
|
asr_pipeline = PipelineFactory.create_pipeline(
|
||||||
|
pipeline_name = "ASRpipeline",
|
||||||
|
models=models,
|
||||||
|
config=config,
|
||||||
|
audio_binary=audio_binary_data_list,
|
||||||
|
input_queue=input_queue,
|
||||||
|
callback=lambda x: print(f"pipeline callback: {x}")
|
||||||
|
)
|
||||||
|
|
||||||
# 运行Pipeline
|
# 运行Pipeline
|
||||||
asr_instance = asr_pipeline.run()
|
asr_instance = asr_pipeline.run()
|
||||||
|
|
||||||
|
|
||||||
audio_clip_len = 200
|
audio_clip_len = 200
|
||||||
print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}")
|
print(
|
||||||
|
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
|
||||||
|
)
|
||||||
for i in range(0, len(audio_data), audio_clip_len):
|
for i in range(0, len(audio_data), audio_clip_len):
|
||||||
input_queue.put(audio_data[i:i+audio_clip_len])
|
input_queue.put(audio_data[i : i + audio_clip_len])
|
||||||
|
|
||||||
# time.sleep(10)
|
# time.sleep(10)
|
||||||
# input_queue.put(None)
|
# input_queue.put(None)
|
||||||
|
|
||||||
# 等待Pipeline结束
|
# 等待Pipeline结束
|
||||||
# asr_instance.join()
|
# asr_instance.join()
|
||||||
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
asr_pipeline.stop()
|
asr_pipeline.stop()
|
||||||
# asr_pipeline.stop()
|
# asr_pipeline.stop()
|
||||||
|
|
||||||
|
@ -10,23 +10,23 @@ import os
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
# 将src目录添加到路径
|
# 将src目录添加到路径
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from src.config import parse_args
|
from src.config import parse_args
|
||||||
|
|
||||||
|
|
||||||
def test_default_args():
|
def test_default_args():
|
||||||
"""测试默认参数值"""
|
"""测试默认参数值"""
|
||||||
with patch('sys.argv', ['script.py']):
|
with patch("sys.argv", ["script.py"]):
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 检查服务器参数
|
# 检查服务器参数
|
||||||
assert args.host == "0.0.0.0"
|
assert args.host == "0.0.0.0"
|
||||||
assert args.port == 10095
|
assert args.port == 10095
|
||||||
|
|
||||||
# 检查SSL参数
|
# 检查SSL参数
|
||||||
assert args.certfile == ""
|
assert args.certfile == ""
|
||||||
assert args.keyfile == ""
|
assert args.keyfile == ""
|
||||||
|
|
||||||
# 检查模型参数
|
# 检查模型参数
|
||||||
assert "paraformer" in args.asr_model
|
assert "paraformer" in args.asr_model
|
||||||
assert args.asr_model_revision == "v2.0.4"
|
assert args.asr_model_revision == "v2.0.4"
|
||||||
@ -36,7 +36,7 @@ def test_default_args():
|
|||||||
assert args.vad_model_revision == "v2.0.4"
|
assert args.vad_model_revision == "v2.0.4"
|
||||||
assert "punc" in args.punc_model
|
assert "punc" in args.punc_model
|
||||||
assert args.punc_model_revision == "v2.0.4"
|
assert args.punc_model_revision == "v2.0.4"
|
||||||
|
|
||||||
# 检查硬件配置
|
# 检查硬件配置
|
||||||
assert args.ngpu == 1
|
assert args.ngpu == 1
|
||||||
assert args.device == "cuda"
|
assert args.device == "cuda"
|
||||||
@ -46,19 +46,26 @@ def test_default_args():
|
|||||||
def test_custom_args():
|
def test_custom_args():
|
||||||
"""测试自定义参数值"""
|
"""测试自定义参数值"""
|
||||||
test_args = [
|
test_args = [
|
||||||
'script.py',
|
"script.py",
|
||||||
'--host', 'localhost',
|
"--host",
|
||||||
'--port', '8080',
|
"localhost",
|
||||||
'--certfile', 'cert.pem',
|
"--port",
|
||||||
'--keyfile', 'key.pem',
|
"8080",
|
||||||
'--asr_model', 'custom_model',
|
"--certfile",
|
||||||
'--ngpu', '0',
|
"cert.pem",
|
||||||
'--device', 'cpu'
|
"--keyfile",
|
||||||
|
"key.pem",
|
||||||
|
"--asr_model",
|
||||||
|
"custom_model",
|
||||||
|
"--ngpu",
|
||||||
|
"0",
|
||||||
|
"--device",
|
||||||
|
"cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch('sys.argv', test_args):
|
with patch("sys.argv", test_args):
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 检查自定义参数
|
# 检查自定义参数
|
||||||
assert args.host == "localhost"
|
assert args.host == "localhost"
|
||||||
assert args.port == 8080
|
assert args.port == 8080
|
||||||
@ -66,4 +73,4 @@ def test_custom_args():
|
|||||||
assert args.keyfile == "key.pem"
|
assert args.keyfile == "key.pem"
|
||||||
assert args.asr_model == "custom_model"
|
assert args.asr_model == "custom_model"
|
||||||
assert args.ngpu == 0
|
assert args.ngpu == 0
|
||||||
assert args.device == "cpu"
|
assert args.device == "cpu"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user