[Code structure] Ran `black .` across the repository; formatting adjustments to all files, no functional changes.
This commit is contained in:
parent 5b94c40016
commit 5a820b49e4

main.py (8 changed lines)
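The diff below was produced purely by the Black formatter. A minimal sketch of how such a commit is typically created and verified (assuming Black is installed in the project environment; the commit does not record a version pin) is:

    pip install black          # install the formatter (version not pinned in this commit)
    black .                    # rewrite every Python file in place
    black --check --diff .     # re-run in check mode: exit code 0 means nothing left to reformat

Since the commit is formatting-only, re-running the existing test entry point (test_main.py) is the quickest way to confirm that behaviour is unchanged.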
@@ -1,11 +1,7 @@
from funasr import AutoModel

chunk_size = 200 # ms
model = AutoModel(
model="fsmn-vad",
model_revision="v2.0.4",
disable_update=True
)
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)

import soundfile

@@ -23,7 +19,7 @@ for i in range(total_chunk_num):
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
disable_pbar=True
disable_pbar=True,
)
if len(res[0]["value"]):
print(res)

@@ -46,7 +46,7 @@ class AudioBinary:
else:
raise ValueError("参数类型错误")

def add_slice_listener(self, slice_listener: callable):
def add_slice_listener(self, slice_listener: callable) -> None:
"""
添加切片监听器
参数:
@@ -128,7 +128,7 @@ class AudioChunk:
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。
"""

_instance = None
_instance: Optional[AudioChunk] = None

def __new__(cls, *args, **kwargs):
"""
@@ -138,7 +138,7 @@ class AudioChunk:
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
return cls._instance

def __init__(self):
def __init__(self) -> None:
"""
初始化AudioChunk实例
"""

@@ -21,41 +21,23 @@ def parse_args():
"--host",
type=str,
default="0.0.0.0",
help="服务器主机地址,例如:localhost, 0.0.0.0"
)
parser.add_argument(
"--port",
type=int,
default=10095,
help="WebSocket服务器端口"
help="服务器主机地址,例如:localhost, 0.0.0.0",
)
parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")

# SSL配置
parser.add_argument(
"--certfile",
type=str,
default="",
help="SSL证书文件路径"
)
parser.add_argument(
"--keyfile",
type=str,
default="",
help="SSL密钥文件路径"
)
parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")

# ASR模型配置
parser.add_argument(
"--asr_model",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="离线ASR模型(从ModelScope获取)"
help="离线ASR模型(从ModelScope获取)",
)
parser.add_argument(
"--asr_model_revision",
type=str,
default="v2.0.4",
help="离线ASR模型版本"
"--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
)

# 在线ASR模型配置
@@ -63,13 +45,13 @@ def parse_args():
"--asr_model_online",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="在线ASR模型(从ModelScope获取)"
help="在线ASR模型(从ModelScope获取)",
)
parser.add_argument(
"--asr_model_online_revision",
type=str,
default="v2.0.4",
help="在线ASR模型版本"
help="在线ASR模型版本",
)

# VAD模型配置
@@ -77,13 +59,10 @@ def parse_args():
"--vad_model",
type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="VAD语音活动检测模型(从ModelScope获取)"
help="VAD语音活动检测模型(从ModelScope获取)",
)
parser.add_argument(
"--vad_model_revision",
type=str,
default="v2.0.4",
help="VAD模型版本"
"--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
)

# 标点符号模型配置
@@ -91,34 +70,18 @@ def parse_args():
"--punc_model",
type=str,
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="标点符号模型(从ModelScope获取)"
help="标点符号模型(从ModelScope获取)",
)
parser.add_argument(
"--punc_model_revision",
type=str,
default="v2.0.4",
help="标点符号模型版本"
"--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
)

# 硬件配置
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量,0表示仅使用CPU")
parser.add_argument(
"--ngpu",
type=int,
default=1,
help="GPU数量,0表示仅使用CPU"
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="设备类型:cuda或cpu"
)
parser.add_argument(
"--ncpu",
type=int,
default=4,
help="CPU核心数"
"--device", type=str, default="cuda", help="设备类型:cuda或cpu"
)
parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")

return parser.parse_args()
@@ -2,6 +2,7 @@
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
"""

from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
@@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger

logger = get_module_logger(__name__)

class ASRFunctor(BaseFunctor):
"""
ASRFunctor
@@ -83,7 +85,7 @@ class ASRFunctor(BaseFunctor):
"""
回调函数
"""
text = result[0]['text'].replace(" ", "")
text = result[0]["text"].replace(" ", "")
for callback in self._callback:
callback(text)

@@ -157,5 +159,3 @@ class ASRFunctor(BaseFunctor):
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

@@ -11,11 +11,13 @@ Functor基础模块
BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类
"""

from abc import ABC, abstractmethod
from typing import Callable, List
from queue import Queue
import threading

class BaseFunctor(ABC):
"""
Functor抽象类
@@ -27,9 +29,7 @@ class BaseFunctor(ABC):
_model (dict): 存储模型相关的配置和实例
"""

def __init__(
self
):
def __init__(self):
"""
初始化函数器

@@ -111,7 +111,6 @@ class BaseFunctor(ABC):
"""

class FunctorFactory:
"""
Functor工厂类
@@ -151,10 +150,9 @@ class FunctorFactory:
创建VAD Functor实例
"""
from src.functor.vad_functor import VADFunctor

audio_config = config["audio_config"]
model = {
"vad": models["vad"]
}
model = {"vad": models["vad"]}

vad_functor = VADFunctor()
vad_functor.set_audio_config(audio_config)
@@ -167,10 +165,9 @@ class FunctorFactory:
创建ASR Functor实例
"""
from src.functor.asr_functor import ASRFunctor

audio_config = config["audio_config"]
model = {
"asr": models["asr"]
}
model = {"asr": models["asr"]}

asr_functor = ASRFunctor()
asr_functor.set_audio_config(audio_config)
@@ -183,10 +180,9 @@ class FunctorFactory:
创建SPK Functor实例
"""
from src.functor.spk_functor import SPKFunctor

audio_config = config["audio_config"]
model = {
"spk": models["spk"]
}
model = {"spk": models["spk"]}

spk_functor = SPKFunctor()
spk_functor.set_audio_config(audio_config)

@@ -2,6 +2,7 @@
SpkFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
"""

from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
@@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger

logger = get_module_logger(__name__)

class SPKFunctor(BaseFunctor):
"""
SPKFunctor
@@ -33,7 +35,6 @@ class SPKFunctor(BaseFunctor):
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置

def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
@@ -83,7 +84,7 @@ class SPKFunctor(BaseFunctor):
# input=binary_data,
# chunk_size=self._audio_config.chunk_size,
# )
result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}]
result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
self._do_callback(result)

def _run(self) -> None:
@@ -142,5 +143,3 @@ class SPKFunctor(BaseFunctor):
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

@@ -2,6 +2,7 @@
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
"""

import threading
from queue import Empty, Queue
from typing import List, Any, Callable
@@ -105,9 +106,7 @@ class VADFunctor(BaseFunctor):
self._callback = []
self._callback.append(callback)

def _do_callback(
self, result: List[List[int]]
) -> None:
def _do_callback(self, result: List[List[int]]) -> None:
"""
回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice

@@ -6,9 +6,11 @@

from src.utils.logger import get_module_logger
from typing import Any, Dict, Type, Callable

# 配置日志
logger = get_module_logger(__name__, level="INFO")

class AutoAfterMeta(type):
"""
自动调用__after__函数的元类
@@ -21,7 +23,7 @@ class AutoAfterMeta(type):
# 遍历所有属性
for attr_name, attr_value in attrs.items():
# 如果是函数且不是以_开头
if callable(attr_value) and not attr_name.startswith('__'):
if callable(attr_value) and not attr_name.startswith("__"):
# 获取原函数
original_func = attr_value

@@ -45,6 +47,7 @@ class AutoAfterMeta(type):
logger.error(f"调用{after_func_name}时出错: {e}")

return result

return wrapper

# 替换原函数
@@ -68,6 +71,7 @@ class AutoAfterMeta(type):

return cls._instances[cls]

"""
整体识别的处理逻辑:
1.压入二进制音频信息
@@ -88,10 +92,12 @@ from src.models import AudioBinary_Config
from src.models import AudioBinary_Chunk
from typing import List

class LogicTrager(metaclass=AutoAfterMeta):
"""逻辑触发器类"""

def __init__(self,
def __init__(
self,
audio_chunk_max_size: int = 1024 * 1024 * 10,
audio_config: AudioBinary_Config = None,
result_callback: Callable = None,
@@ -101,10 +107,12 @@ class LogicTrager(metaclass=AutoAfterMeta):
# 存储音频块
self._audio_chunk: List[AudioBinary_Chunk] = []
# 存储二进制数据
self._audio_chunk_binary = b''
self._audio_chunk_binary = b""
self._audio_chunk_max_size = audio_chunk_max_size
# 音频参数
self._audio_config = audio_config if audio_config is not None else AudioBinary_Config()
self._audio_config = (
audio_config if audio_config is not None else AudioBinary_Config()
)
# 结果队列
self._result_queue = []
# 聚合结果回调函数
@@ -113,7 +121,6 @@ class LogicTrager(metaclass=AutoAfterMeta):
self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
self._vad.set_callback(self.push_audio_chunk)

logger.info("初始化LogicTrager")

def push_binary_data(self, chunk: bytes) -> None:
@@ -138,7 +145,11 @@ class LogicTrager(metaclass=AutoAfterMeta):
"""
音频处理
"""
logger.info("LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(chunk.start_time, chunk.end_time, len(chunk.chunk)))
logger.info(
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
chunk.start_time, chunk.end_time, len(chunk.chunk)
)
)
self._audio_chunk.append(chunk)

def __after__push_audio_chunk(self) -> None:
@@ -115,7 +115,7 @@ class ModelLoader:
self.models = {}
# 加载离线ASR模型
# 检查对应键是否存在
model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk']
model_list = ["asr", "asr_online", "vad", "punc", "spk"]
for model_name in model_list:
name_model = f"{model_name}_model"
name_model_revision = f"{model_name}_model_revision"

@@ -1,3 +1,9 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"]

__all__ = [
"AudioBinary_Config",
"AudioBinary_data_list",
"_AudioBinary_data",
"VAD_Functor_result",
]

@@ -8,8 +8,10 @@ logger = get_module_logger(__name__)

binary_data_types = (bytes, numpy.ndarray)

class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息"""

class Config:
arbitrary_types_allowed = True

@@ -37,14 +39,16 @@ class AudioBinary_Config(BaseModel):
"""
return int(frame * 1000 / self.sample_rate)

class _AudioBinary_data(BaseModel):
"""音频数据"""

binary_data: binary_data_types = Field(description="音频二进制数据", default=None)

class Config:
arbitrary_types_allowed = True

@validator('binary_data')
@validator("binary_data")
def validate_binary_data(cls, v):
"""
验证音频数据
@@ -54,7 +58,11 @@ class _AudioBinary_data(BaseModel):
binary_data_types: 音频数据
"""
if not isinstance(v, (bytes, numpy.ndarray)):
logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v))
logger.warning(
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
cls.__class__.__name__,
type(v),
)
return v

def __len__(self):
@@ -71,7 +79,11 @@ class _AudioBinary_data(BaseModel):
Args:
binary_data: 音频数据
"""
logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data))
logger.debug(
"[%s]初始化音频数据, 数据类型为%s",
self.__class__.__name__,
type(binary_data),
)
super().__init__(binary_data=binary_data)

def __getitem__(self):
@@ -82,9 +94,13 @@ class _AudioBinary_data(BaseModel):
"""
return self.binary_data

class AudioBinary_data_list(BaseModel):
"""音频数据列表"""
binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[])

binary_data_list: List[_AudioBinary_data] = Field(
description="音频数据列表", default=[]
)

class Config:
arbitrary_types_allowed = True
@@ -118,6 +134,7 @@ class AudioBinary_data_list(BaseModel):
"""
return len(self.binary_data_list)

# class AudioBinary_Slice(BaseModel):
#     """音频块切片"""
#     target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)

@@ -2,23 +2,27 @@ from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable, Any
from .audio import AudioBinary_data_list, _AudioBinary_data

class VAD_Functor_result(BaseModel):
"""
VADFunctor结果
"""

audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
audiobinary_index: int = Field(description="音频数据索引")
audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data")
audiobinary_data: _AudioBinary_data = Field(
description="音频数据, 指向AudioBinary_data"
)
start_time: int = Field(description="开始时间", is_required=True)
end_time: int = Field(description="结束时间", is_required=True)

@validator('audiobinary_data_list')
@validator("audiobinary_data_list")
def validate_audiobinary_data_list(cls, v):
if not isinstance(v, AudioBinary_data_list):
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
return v

@validator('audiobinary_index')
@validator("audiobinary_index")
def validate_audiobinary_index(cls, v):
if not isinstance(v, int):
raise ValueError("audiobinary_index必须是int类型")
@@ -26,13 +30,13 @@ class VAD_Functor_result(BaseModel):
raise ValueError("audiobinary_index必须大于0")
return v

@validator('audiobinary_data')
@validator("audiobinary_data")
def validate_audiobinary_data(cls, v):
if not isinstance(v, _AudioBinary_data):
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
return v

@validator('start_time')
@validator("start_time")
def validate_start_time(cls, v):
if not isinstance(v, int):
raise ValueError("start_time必须是int类型")
@@ -40,11 +44,11 @@ class VAD_Functor_result(BaseModel):
raise ValueError("start_time必须大于0")
return v

@validator('end_time')
@validator("end_time")
def validate_end_time(cls, v, values):
if not isinstance(v, int):
raise ValueError("end_time必须是int类型")
if 'start_time' in values and v <= values['start_time']:
if "start_time" in values and v <= values["start_time"]:
raise ValueError("end_time必须大于start_time")
return v

@@ -54,7 +58,7 @@ class VAD_Functor_result(BaseModel):
audiobinary_data_list: AudioBinary_data_list,
data: Any,
start_time: int,
end_time: int
end_time: int,
):
"""
创建VAD片段
@@ -66,7 +70,8 @@ class VAD_Functor_result(BaseModel):
audiobinary_index=index,
audiobinary_data=audiobinary_data_list[index],
start_time=start_time,
end_time=end_time)
end_time=end_time,
)

def __len__(self):
"""
@@ -78,11 +83,9 @@ class VAD_Functor_result(BaseModel):
"""
字符串展示内容
"""
tostr = f'audiobinary_data_index: {self.audiobinary_index}\n'
tostr += f'start_time: {self.start_time}\n'
tostr += f'end_time: {self.end_time}\n'
tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n'
tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n'
tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
tostr += f"start_time: {self.start_time}\n"
tostr += f"end_time: {self.end_time}\n"
tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
return tostr
@@ -7,11 +7,13 @@ import threading

logger = get_module_logger(__name__)

class ASRPipeline(PipelineBase):
"""
管道类
实现具体的处理逻辑
"""

def __init__(self, *args, **kwargs):
"""
初始化管道
@@ -91,27 +93,24 @@ class ASRPipeline(PipelineBase):
"""
try:
from src.functor import FunctorFactory

# 加载VAD、asr、spk functor
self._functor_dict["vad"] = FunctorFactory.make_functor(
functor_name = "vad",
config = self._config,
models = self._models
functor_name="vad", config=self._config, models=self._models
)
self._functor_dict["asr"] = FunctorFactory.make_functor(
functor_name = "asr",
config = self._config,
models = self._models
functor_name="asr", config=self._config, models=self._models
)
self._functor_dict["spk"] = FunctorFactory.make_functor(
functor_name = "spk",
config = self._config,
models = self._models
functor_name="spk", config=self._config, models=self._models
)

# 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list()

self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list)
self._functor_dict["vad"].set_audio_binary_data_list(
self._audio_binary_data_list
)

# 初始化子队列
self._subqueue_dict["original"] = Queue()
@@ -134,13 +133,23 @@ class ASRPipeline(PipelineBase):
"""
带回调函数的put
"""

def put_with_check(data: Any) -> None:
queue.put(data)
callback(data)

return put_with_check

self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
self._functor_dict["asr"].add_callback(
put_with_checkcallback(
self._subqueue_dict["asrend"], self._check_result
)
)
self._functor_dict["spk"].add_callback(
put_with_checkcallback(
self._subqueue_dict["spkend"], self._check_result
)
)

except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
@@ -150,14 +159,14 @@ class ASRPipeline(PipelineBase):
检查结果
"""
# 若asr和spk队列中都有数据,则合并数据
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize():
if (
self._subqueue_dict["asrend"].qsize()
& self._subqueue_dict["spkend"].qsize()
):
asr_data = self._subqueue_dict["asrend"].get()
spk_data = self._subqueue_dict["spkend"].get()
# 合并数据
result = {
"asr_data": asr_data,
"spk_data": spk_data
}
result = {"asr_data": asr_data, "spk_data": spk_data}
# 通知回调函数
self._notify_callbacks(result)

@@ -8,11 +8,13 @@ import time
# 配置日志
logger = logging.getLogger(__name__)

class PipelineBase(ABC):
"""
管道基类
定义了管道的基本接口和通用功能
"""

def __init__(self, input_queue: Optional[Queue] = None):
"""
初始化管道
@@ -115,14 +117,35 @@ class PipelineBase(ABC):
# 注意:Python的线程无法被强制终止,这里只是设置标志
# 实际终止需要依赖操作系统的进程管理

class PipelineFactory:
"""
管道工厂类
用于创建管道实例
"""
@staticmethod
def create_pipeline(pipeline_name: str) -> Any:

from src.pipeline.ASRpipeline import ASRPipeline
def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline:
"""
创建ASR管道实例
"""
from src.pipeline.ASRpipeline import ASRPipeline
pipeline = ASRPipeline()
pipeline.set_config(kwargs["config"])
pipeline.set_models(kwargs["models"])
pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"])
pipeline.add_callback(kwargs["callback"])
pipeline.bake()
return pipeline

@classmethod
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
"""
创建管道实例
"""
pass
if pipeline_name == "ASRpipeline":
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
else:
raise ValueError(f"不支持的管道类型: {pipeline_name}")

@@ -172,7 +172,7 @@ class STTRunner(RunnerBase):
logger.warning(
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
self._stop_timeout,
self._input_queue.qsize()
self._input_queue.qsize(),
)
success = False
break
@@ -188,7 +188,7 @@ class STTRunner(RunnerBase):
"错误堆栈:\n%s",
error_type,
error_msg,
error_traceback
error_traceback,
)
success = False

@@ -198,7 +198,7 @@ class STTRunner(RunnerBase):
logger.warning(
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(),
self._input_queue.empty()
self._input_queue.empty(),
)

return success
@@ -88,7 +88,9 @@ async def ws_serve(websocket, path):
# 更新各种配置参数
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
websocket.status_dict_asr_online["is_final"] = (
not websocket.is_speaking
)
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
@@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
websocket.status_dict_asr_online["chunk_size"] = [
int(x) for x in chunk_size
]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
messagejson["encoder_chunk_look_back"]
)
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
messagejson["decoder_chunk_look_back"]
)
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
@@ -111,11 +119,17 @@ async def ws_serve(websocket, path):

# 根据chunk_interval更新VAD的chunk_size
websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
* 60
/ websocket.chunk_interval
)

# 处理音频数据
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
if (
len(frames_asr_online) > 0
or len(frames_asr) >= 0
or not isinstance(message, str)
):
if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区
frames.append(message)
@@ -127,8 +141,10 @@ async def ws_serve(websocket, path):
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1

# 达到chunk_interval或最终帧时处理在线ASR
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
websocket.status_dict_asr_online["is_final"]):
if (
len(frames_asr_online) % websocket.chunk_interval == 0
or websocket.status_dict_asr_online["is_final"]
):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
@@ -143,7 +159,9 @@ async def ws_serve(websocket, path):

# VAD处理 - 语音活动检测
try:
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
speech_start_i, speech_end_i = await asr_service.async_vad(
websocket, message
)
except Exception as e:
print(f"VAD处理错误: {e}")

@@ -151,8 +169,12 @@ async def ws_serve(websocket, path):
if speech_start_i != -1:
speech_start = True
# 计算开始偏移并收集前面的帧
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
beg_bias = (
websocket.vad_pre_idx - speech_start_i
) // duration_ms
frames_pre = (
frames[-beg_bias:] if beg_bias < len(frames) else frames
)
frames_asr = []
frames_asr.extend(frames_pre)

@@ -207,16 +229,16 @@ def start_server(args, asr_service_instance):
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)

start_server = websockets.serve(
ws_serve, args.host, args.port,
ws_serve,
args.host,
args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context
ssl=ssl_context,
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
)

print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")

@@ -35,8 +35,7 @@ class ASRService:
"""
# 使用VAD模型分析音频段
segments_result = self.model_vad.generate(
input=audio_in,
**websocket.status_dict_vad
input=audio_in, **websocket.status_dict_vad
)[0]["value"]

speech_start = -1
@@ -64,36 +63,38 @@ class ASRService:
if len(audio_in) > 0:
# 使用离线ASR模型处理音频
rec_result = self.model_asr.generate(
input=audio_in,
**websocket.status_dict_asr
input=audio_in, **websocket.status_dict_asr
)[0]

# 如果有标点符号模型且识别出文本,则添加标点
if self.model_punc is not None and len(rec_result["text"]) > 0:
rec_result = self.model_punc.generate(
input=rec_result["text"],
**websocket.status_dict_punc
input=rec_result["text"], **websocket.status_dict_punc
)[0]

# 如果识别出文本,发送到客户端
if len(rec_result["text"]) > 0:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
}
)
await websocket.send(message)
else:
# 如果没有音频数据,发送空文本
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
message = json.dumps(
{
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
}
)
await websocket.send(message)

async def async_asr_online(self, websocket, audio_in):
@@ -107,21 +108,24 @@ class ASRService:
if len(audio_in) > 0:
# 使用在线ASR模型处理音频
rec_result = self.model_asr_streaming.generate(
input=audio_in,
**websocket.status_dict_asr_online
input=audio_in, **websocket.status_dict_asr_online
)[0]

# 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
"is_final", False
):
return

# 如果识别出文本,发送到客户端
if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
}
)
await websocket.send(message)
@@ -1,10 +1,12 @@
"""
处理各类音频数据与bytes的转换
"""

import wave
from pydub import AudioSegment
import io

def wav_to_bytes(wav_path: str) -> bytes:
"""
将WAV文件读取为bytes数据。
@@ -20,18 +22,20 @@ def wav_to_bytes(wav_path: str) -> bytes:
wave.Error: 如果文件不是有效的WAV文件。
"""
try:
with wave.open(wav_path, 'rb') as wf:
with wave.open(wav_path, "rb") as wf:
# 读取所有帧
frames = wf.readframes(wf.getnframes())
return frames
except FileNotFoundError:
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
raise FileNotFoundError(f"错误:未找到WAV文件 '{wav_path}'")
raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
except wave.Error as e:
raise wave.Error(f"错误:打开或读取WAV文件 '{wav_path}' 失败 - {e}")
raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")

def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int):
def bytes_to_wav(
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
):
"""
将bytes数据写入为WAV文件。

@@ -46,17 +50,18 @@ def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: in
wave.Error: 如果写入WAV文件失败。
"""
try:
with wave.open(wav_path, 'wb') as wf:
with wave.open(wav_path, "wb") as wf:
wf.setnchannels(nchannels)
wf.setsampwidth(sampwidth)
wf.setframerate(framerate)
wf.writeframes(bytes_data)
except wave.Error as e:
raise wave.Error(f"错误:写入WAV文件 '{wav_path}' 失败 - {e}")
raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
except Exception as e:
# 捕获其他可能的写入错误
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")

def mp3_to_bytes(mp3_path: str) -> bytes:
"""
将MP3文件转换为bytes数据 (原始PCM数据)。
@@ -76,12 +81,19 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
# 获取原始PCM数据
return audio.raw_data
except FileNotFoundError:
raise FileNotFoundError(f"错误:未找到MP3文件 '{mp3_path}'")
raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
except Exception as e: # pydub 可能抛出多种解码相关的错误
raise Exception(f"错误:处理MP3文件 '{mp3_path}' 失败 - {e}")
raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")

def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"):
def bytes_to_mp3(
bytes_data: bytes,
mp3_path: str,
frame_rate: int,
channels: int,
sample_width: int,
bitrate: str = "192k",
):
"""
将原始PCM bytes数据转换为MP3文件。

@@ -102,9 +114,9 @@ def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: in
data=bytes_data,
sample_width=sample_width,
frame_rate=frame_rate,
channels=channels
channels=channels,
)
# 导出为MP3
audio.export(mp3_path, format="mp3", bitrate=bitrate)
except Exception as e:
raise Exception(f"错误:转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
@@ -3,6 +3,7 @@ import sys
from pathlib import Path
from typing import Optional

def setup_logger(
name: str = None,
level: str = "INFO",
@@ -45,17 +46,15 @@ def setup_logger(
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)

file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# 注意:移除了 propagate = False,允许日志传递
return logger

def setup_root_logger(
level: str = "INFO",
log_file: Optional[str] = None
) -> None:

def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
"""
配置根日志器

@@ -65,10 +64,11 @@ def setup_root_logger(
"""
setup_logger(None, level, log_file)

def get_module_logger(
module_name: str,
level: Optional[str] = None, # 改为可选参数
log_file: Optional[str] = None # 一般不需要单独指定
log_file: Optional[str] = None, # 一般不需要单独指定
) -> logging.Logger:
"""
获取模块级别的logger
@@ -1,10 +1,15 @@
from tests.functor.vad_test import test_vad_functor
"""
测试主函数
请在tests目录下创建测试文件, 并在此文件中调用
"""

from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger

setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)

# from tests.functor.vad_test import test_vad_functor
# logger.info("开始测试VAD函数器")
# test_vad_functor()

@@ -2,6 +2,7 @@
Functor测试
VAD测试
"""

from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor
@@ -21,6 +22,7 @@ logger = get_module_logger(__name__)

model_loader = ModelLoader()

def test_vad_functor():
# 加载模型
args = {
@@ -38,7 +40,7 @@ def test_vad_functor():
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
@@ -62,9 +64,7 @@ def test_vad_functor():
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
# 设置模型
vad_functor.set_model({
'vad': model_loader.models['vad']
})
vad_functor.set_model({"vad": model_loader.models["vad"]})
# 启动VAD函数器
vad_functor.run()

@@ -77,9 +77,7 @@ def test_vad_functor():
# 设置回调函数
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
# 设置模型
asr_functor.set_model({
'asr': model_loader.models['asr']
})
asr_functor.set_model({"asr": model_loader.models["asr"]})
# 启动ASR函数器
asr_functor.run()

@@ -92,23 +90,25 @@ def test_vad_functor():
# 设置回调函数
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
# 设置模型
spk_functor.set_model({
spk_functor.set_model(
{
# 'spk': model_loader.models['spk']
'spk': 'fake_spk'
})
"spk": "fake_spk"
}
)
# 启动SPK函数器
spk_functor.run()

f_binary = f_data
audio_clip_len = 200
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
print(
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
)
for i in range(0, len(f_binary), audio_clip_len):
binary_data = f_binary[i : i + audio_clip_len]
input_queue.put(binary_data)
# 等待VAD函数器结束

vad_functor.stop()
print("[vad_test] VAD函数器结束")

@@ -119,4 +119,6 @@ def test_vad_functor():
if OVERWATCH:
for index in range(len(audio_binary_data_list)):
save_path = f"tests/vad_test_output_{index}.wav"
soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate)
soundfile.write(
save_path, audio_binary_data_list[index].binary_data, sample_rate
)
@@ -1,9 +1,20 @@
"""
模型使用测试
此处主要用于各类调用模型的处理数据与输出格式
请在主目录下test_main.py中调用
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配。
"""

from funasr import AutoModel
from typing import List, Dict, Any
from src.models import VADResponse
import time

def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
"""
chunk_size = 100 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)

@@ -21,7 +32,9 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
time.sleep(0.1)
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
res = model.generate(
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
)
if len(res[0]["value"]):
vad_result += VADResponse.from_raw(res)
for item in res[0]["value"]:
@@ -32,16 +45,24 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
# print(item)
return vad_result

def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
测试LogicTrager
在Rebuild版本后LogicTrager中已弃用
"""
from src.logic_trager import LogicTrager
import soundfile

from src.config import parse_args

args = parse_args()

# from src.functor.model_loader import load_models
# models = load_models(args)
from src.model_loader import ModelLoader

models = ModelLoader(args)

chunk_size = 200 # ms
@@ -50,7 +71,9 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:

speech, sample_rate = soundfile.read(file_path)
chunk_stride = int(chunk_size * sample_rate / 1000)
audio_config = AudioBinary_Config(sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size)
audio_config = AudioBinary_Config(
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
)

logic_trager = LogicTrager(models=models, audio_config=audio_config)
for i in range(len(speech) // chunk_stride + 1):
@@ -61,12 +84,22 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
# print(item)
return None

def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
"""
ASR模型使用
离线ASR模型使用
"""
from funasr import AutoModel
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
vad_model="fsmn-vad", vad_model_revision="v2.0.4",

model = AutoModel(
model="paraformer-zh",
model_revision="v2.0.4",
vad_model="fsmn-vad",
vad_model_revision="v2.0.4",
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
spk_model="cam++", spk_model_revision="v2.0.2",
spk_model="cam++",
spk_model_revision="v2.0.2",
spk_mode="vad_segment",
auto_update=False,
)
@@ -80,7 +113,9 @@ def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
result = model.generate(speech)
return result

if __name__ == "__main__":

# if __name__ == "__main__":
# 请在主目录下调用test_main.py文件进行测试
# vad_result = vad_model_use_online("tests/vad_example.wav")
vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# print(vad_result)
@@ -2,7 +2,9 @@
Pipeline测试
VAD+ASR+SPK(FAKE)
"""

from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader
from queue import Queue
@@ -18,6 +20,7 @@ OVAERWATCH = False

model_loader = ModelLoader()

def test_asr_pipeline():
# 加载模型
args = {
@@ -36,7 +39,7 @@ def test_asr_pipeline():
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
@@ -52,19 +55,30 @@ def test_asr_pipeline():
input_queue = Queue()

# 创建Pipeline
asr_pipeline = ASRPipeline()
asr_pipeline.set_models(models)
asr_pipeline.set_config(config)
asr_pipeline.set_audio_binary(audio_binary_data_list)
asr_pipeline.set_input_queue(input_queue)
asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
asr_pipeline.bake()
# asr_pipeline = ASRPipeline()
# asr_pipeline.set_models(models)
# asr_pipeline.set_config(config)
# asr_pipeline.set_audio_binary(audio_binary_data_list)
# asr_pipeline.set_input_queue(input_queue)
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
# asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name = "ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)

# 运行Pipeline
asr_instance = asr_pipeline.run()

audio_clip_len = 200
print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}")
print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i : i + audio_clip_len])

@@ -77,4 +91,3 @@ def test_asr_pipeline():
time.sleep(5)
asr_pipeline.stop()
# asr_pipeline.stop()
@@ -10,13 +10,13 @@ import os
from unittest.mock import patch

# 将src目录添加到路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.config import parse_args

def test_default_args():
"""测试默认参数值"""
with patch('sys.argv', ['script.py']):
with patch("sys.argv", ["script.py"]):
args = parse_args()

# 检查服务器参数
@@ -46,17 +46,24 @@ def test_default_args():
def test_custom_args():
"""测试自定义参数值"""
test_args = [
'script.py',
'--host', 'localhost',
'--port', '8080',
'--certfile', 'cert.pem',
'--keyfile', 'key.pem',
'--asr_model', 'custom_model',
'--ngpu', '0',
'--device', 'cpu'
"script.py",
"--host",
"localhost",
"--port",
"8080",
"--certfile",
"cert.pem",
"--keyfile",
"key.pem",
"--asr_model",
"custom_model",
"--ngpu",
"0",
"--device",
"cpu",
]

with patch('sys.argv', test_args):
with patch("sys.argv", test_args):
args = parse_args()

# 检查自定义参数