[代码结构]black . 对所有文件格式调整,无功能变化。

This commit is contained in:
Ziyang.Zhang 2025-06-12 15:49:43 +08:00
parent 5b94c40016
commit 5a820b49e4
28 changed files with 543 additions and 421 deletions

14
main.py
View File

@ -1,11 +1,7 @@
from funasr import AutoModel from funasr import AutoModel
chunk_size = 200 # ms chunk_size = 200 # ms
model = AutoModel( model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
model="fsmn-vad",
model_revision="v2.0.4",
disable_update=True
)
import soundfile import soundfile
@ -14,16 +10,16 @@ speech, sample_rate = soundfile.read(wav_file)
chunk_stride = int(chunk_size * sample_rate / 1000) chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {} cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1) total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num): for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1 is_final = i == total_chunk_num - 1
res = model.generate( res = model.generate(
input=speech_chunk, input=speech_chunk,
cache=cache, cache=cache,
is_final=is_final, is_final=is_final,
chunk_size=chunk_size, chunk_size=chunk_size,
disable_pbar=True disable_pbar=True,
) )
if len(res[0]["value"]): if len(res[0]["value"]):
print(res) print(res)

View File

@ -46,7 +46,7 @@ class AudioBinary:
else: else:
raise ValueError("参数类型错误") raise ValueError("参数类型错误")
def add_slice_listener(self, slice_listener: callable): def add_slice_listener(self, slice_listener: callable) -> None:
""" """
添加切片监听器 添加切片监听器
参数: 参数:
@ -98,10 +98,10 @@ class AudioBinary:
self._binary_data_list.rewrite(target_index, binary_data) self._binary_data_list.rewrite(target_index, binary_data)
def get_binary_data( def get_binary_data(
self, self,
start: int = 0, start: int = 0,
end: Optional[int] = None, end: Optional[int] = None,
) -> Optional[bytes]: ) -> Optional[bytes]:
""" """
获取指定索引的音频数据块 获取指定索引的音频数据块
参数: 参数:
@ -128,7 +128,7 @@ class AudioChunk:
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑 此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑
""" """
_instance = None _instance: Optional[AudioChunk] = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
@ -138,7 +138,7 @@ class AudioChunk:
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs) cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
return cls._instance return cls._instance
def __init__(self): def __init__(self) -> None:
""" """
初始化AudioChunk实例 初始化AudioChunk实例
""" """
@ -146,10 +146,10 @@ class AudioChunk:
self._slice_listener: List[callable] = [] self._slice_listener: List[callable] = []
def get_audio_binary( def get_audio_binary(
self, self,
binary_name: Optional[str] = None, binary_name: Optional[str] = None,
audio_config: Optional[AudioBinary_Config] = None, audio_config: Optional[AudioBinary_Config] = None,
) -> AudioBinary: ) -> AudioBinary:
""" """
获取音频数据块 获取音频数据块
参数: 参数:

View File

@ -21,41 +21,23 @@ def parse_args():
"--host", "--host",
type=str, type=str,
default="0.0.0.0", default="0.0.0.0",
help="服务器主机地址例如localhost, 0.0.0.0" help="服务器主机地址例如localhost, 0.0.0.0",
)
parser.add_argument(
"--port",
type=int,
default=10095,
help="WebSocket服务器端口"
) )
parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")
# SSL配置 # SSL配置
parser.add_argument( parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
"--certfile", parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")
type=str,
default="",
help="SSL证书文件路径"
)
parser.add_argument(
"--keyfile",
type=str,
default="",
help="SSL密钥文件路径"
)
# ASR模型配置 # ASR模型配置
parser.add_argument( parser.add_argument(
"--asr_model", "--asr_model",
type=str, type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="离线ASR模型从ModelScope获取" help="离线ASR模型从ModelScope获取",
) )
parser.add_argument( parser.add_argument(
"--asr_model_revision", "--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
type=str,
default="v2.0.4",
help="离线ASR模型版本"
) )
# 在线ASR模型配置 # 在线ASR模型配置
@ -63,13 +45,13 @@ def parse_args():
"--asr_model_online", "--asr_model_online",
type=str, type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="在线ASR模型从ModelScope获取" help="在线ASR模型从ModelScope获取",
) )
parser.add_argument( parser.add_argument(
"--asr_model_online_revision", "--asr_model_online_revision",
type=str, type=str,
default="v2.0.4", default="v2.0.4",
help="在线ASR模型版本" help="在线ASR模型版本",
) )
# VAD模型配置 # VAD模型配置
@ -77,13 +59,10 @@ def parse_args():
"--vad_model", "--vad_model",
type=str, type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="VAD语音活动检测模型从ModelScope获取" help="VAD语音活动检测模型从ModelScope获取",
) )
parser.add_argument( parser.add_argument(
"--vad_model_revision", "--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
type=str,
default="v2.0.4",
help="VAD模型版本"
) )
# 标点符号模型配置 # 标点符号模型配置
@ -91,34 +70,18 @@ def parse_args():
"--punc_model", "--punc_model",
type=str, type=str,
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="标点符号模型从ModelScope获取" help="标点符号模型从ModelScope获取",
) )
parser.add_argument( parser.add_argument(
"--punc_model_revision", "--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
type=str,
default="v2.0.4",
help="标点符号模型版本"
) )
# 硬件配置 # 硬件配置
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量0表示仅使用CPU")
parser.add_argument( parser.add_argument(
"--ngpu", "--device", type=str, default="cuda", help="设备类型cuda或cpu"
type=int,
default=1,
help="GPU数量0表示仅使用CPU"
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="设备类型cuda或cpu"
)
parser.add_argument(
"--ncpu",
type=int,
default=4,
help="CPU核心数"
) )
parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")
return parser.parse_args() return parser.parse_args()

View File

@ -2,8 +2,9 @@
ASRFunctor ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback 负责对音频片段进行ASR处理, 以ASR_Result进行callback
""" """
from src.functor.base import BaseFunctor from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
from typing import Callable, List from typing import Callable, List
from queue import Queue, Empty from queue import Queue, Empty
import threading import threading
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
class ASRFunctor(BaseFunctor): class ASRFunctor(BaseFunctor):
""" """
ASRFunctor ASRFunctor
@ -83,7 +85,7 @@ class ASRFunctor(BaseFunctor):
""" """
回调函数 回调函数
""" """
text = result[0]['text'].replace(" ", "") text = result[0]["text"].replace(" ", "")
for callback in self._callback: for callback in self._callback:
callback(text) callback(text)
@ -157,5 +159,3 @@ class ASRFunctor(BaseFunctor):
with self._status_lock: with self._status_lock:
self._is_running = False self._is_running = False
return not self._thread.is_alive() return not self._thread.is_alive()

View File

@ -11,11 +11,13 @@ Functor基础模块
BaseFunctor: Functor抽象类 BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类 FunctorFactory: Functor工厂类
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List from typing import Callable, List
from queue import Queue from queue import Queue
import threading import threading
class BaseFunctor(ABC): class BaseFunctor(ABC):
""" """
Functor抽象类 Functor抽象类
@ -27,9 +29,7 @@ class BaseFunctor(ABC):
_model (dict): 存储模型相关的配置和实例 _model (dict): 存储模型相关的配置和实例
""" """
def __init__( def __init__(self):
self
):
""" """
初始化函数器 初始化函数器
@ -111,7 +111,6 @@ class BaseFunctor(ABC):
""" """
class FunctorFactory: class FunctorFactory:
""" """
Functor工厂类 Functor工厂类
@ -138,11 +137,11 @@ class FunctorFactory:
""" """
if functor_name == "vad": if functor_name == "vad":
return cls._make_vadfunctor(config = config,models = models) return cls._make_vadfunctor(config=config, models=models)
elif functor_name == "asr": elif functor_name == "asr":
return cls._make_asrfunctor(config = config,models = models) return cls._make_asrfunctor(config=config, models=models)
elif functor_name == "spk": elif functor_name == "spk":
return cls._make_spkfunctor(config = config,models = models) return cls._make_spkfunctor(config=config, models=models)
else: else:
raise ValueError(f"不支持的Functor类型: {functor_name}") raise ValueError(f"不支持的Functor类型: {functor_name}")
@ -151,10 +150,9 @@ class FunctorFactory:
创建VAD Functor实例 创建VAD Functor实例
""" """
from src.functor.vad_functor import VADFunctor from src.functor.vad_functor import VADFunctor
audio_config = config["audio_config"] audio_config = config["audio_config"]
model = { model = {"vad": models["vad"]}
"vad": models["vad"]
}
vad_functor = VADFunctor() vad_functor = VADFunctor()
vad_functor.set_audio_config(audio_config) vad_functor.set_audio_config(audio_config)
@ -167,10 +165,9 @@ class FunctorFactory:
创建ASR Functor实例 创建ASR Functor实例
""" """
from src.functor.asr_functor import ASRFunctor from src.functor.asr_functor import ASRFunctor
audio_config = config["audio_config"] audio_config = config["audio_config"]
model = { model = {"asr": models["asr"]}
"asr": models["asr"]
}
asr_functor = ASRFunctor() asr_functor = ASRFunctor()
asr_functor.set_audio_config(audio_config) asr_functor.set_audio_config(audio_config)
@ -183,10 +180,9 @@ class FunctorFactory:
创建SPK Functor实例 创建SPK Functor实例
""" """
from src.functor.spk_functor import SPKFunctor from src.functor.spk_functor import SPKFunctor
audio_config = config["audio_config"] audio_config = config["audio_config"]
model = { model = {"spk": models["spk"]}
"spk": models["spk"]
}
spk_functor = SPKFunctor() spk_functor = SPKFunctor()
spk_functor.set_audio_config(audio_config) spk_functor.set_audio_config(audio_config)

View File

@ -2,6 +2,7 @@
SpkFunctor SpkFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback 负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
""" """
from src.functor.base import BaseFunctor from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List from typing import Callable, List
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
class SPKFunctor(BaseFunctor): class SPKFunctor(BaseFunctor):
""" """
SPKFunctor SPKFunctor
@ -33,7 +35,6 @@ class SPKFunctor(BaseFunctor):
self._input_queue: Queue = None # 输入队列 self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置 self._audio_config: AudioBinary_Config = None # 音频配置
def reset_cache(self) -> None: def reset_cache(self) -> None:
""" """
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
@ -83,7 +84,7 @@ class SPKFunctor(BaseFunctor):
# input=binary_data, # input=binary_data,
# chunk_size=self._audio_config.chunk_size, # chunk_size=self._audio_config.chunk_size,
# ) # )
result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}] result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
self._do_callback(result) self._do_callback(result)
def _run(self) -> None: def _run(self) -> None:
@ -142,5 +143,3 @@ class SPKFunctor(BaseFunctor):
with self._status_lock: with self._status_lock:
self._is_running = False self._is_running = False
return not self._thread.is_alive() return not self._thread.is_alive()

View File

@ -2,6 +2,7 @@
VADFunctor VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback 负责对音频片段进行VAD处理, 以VAD_Result进行callback
""" """
import threading import threading
from queue import Empty, Queue from queue import Empty, Queue
from typing import List, Any, Callable from typing import List, Any, Callable
@ -105,9 +106,7 @@ class VADFunctor(BaseFunctor):
self._callback = [] self._callback = []
self._callback.append(callback) self._callback.append(callback)
def _do_callback( def _do_callback(self, result: List[List[int]]) -> None:
self, result: List[List[int]]
) -> None:
""" """
回调函数 回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice

View File

@ -6,9 +6,11 @@
from src.utils.logger import get_module_logger from src.utils.logger import get_module_logger
from typing import Any, Dict, Type, Callable from typing import Any, Dict, Type, Callable
# 配置日志 # 配置日志
logger = get_module_logger(__name__, level="INFO") logger = get_module_logger(__name__, level="INFO")
class AutoAfterMeta(type): class AutoAfterMeta(type):
""" """
自动调用__after__函数的元类 自动调用__after__函数的元类
@ -21,7 +23,7 @@ class AutoAfterMeta(type):
# 遍历所有属性 # 遍历所有属性
for attr_name, attr_value in attrs.items(): for attr_name, attr_value in attrs.items():
# 如果是函数且不是以_开头 # 如果是函数且不是以_开头
if callable(attr_value) and not attr_name.startswith('__'): if callable(attr_value) and not attr_name.startswith("__"):
# 获取原函数 # 获取原函数
original_func = attr_value original_func = attr_value
@ -45,6 +47,7 @@ class AutoAfterMeta(type):
logger.error(f"调用{after_func_name}时出错: {e}") logger.error(f"调用{after_func_name}时出错: {e}")
return result return result
return wrapper return wrapper
# 替换原函数 # 替换原函数
@ -68,6 +71,7 @@ class AutoAfterMeta(type):
return cls._instances[cls] return cls._instances[cls]
""" """
整体识别的处理逻辑 整体识别的处理逻辑
1.压入二进制音频信息 1.压入二进制音频信息
@ -88,10 +92,12 @@ from src.models import AudioBinary_Config
from src.models import AudioBinary_Chunk from src.models import AudioBinary_Chunk
from typing import List from typing import List
class LogicTrager(metaclass=AutoAfterMeta): class LogicTrager(metaclass=AutoAfterMeta):
"""逻辑触发器类""" """逻辑触发器类"""
def __init__(self, def __init__(
self,
audio_chunk_max_size: int = 1024 * 1024 * 10, audio_chunk_max_size: int = 1024 * 1024 * 10,
audio_config: AudioBinary_Config = None, audio_config: AudioBinary_Config = None,
result_callback: Callable = None, result_callback: Callable = None,
@ -99,21 +105,22 @@ class LogicTrager(metaclass=AutoAfterMeta):
): ):
"""初始化""" """初始化"""
# 存储音频块 # 存储音频块
self._audio_chunk : List[AudioBinary_Chunk] = [] self._audio_chunk: List[AudioBinary_Chunk] = []
# 存储二进制数据 # 存储二进制数据
self._audio_chunk_binary = b'' self._audio_chunk_binary = b""
self._audio_chunk_max_size = audio_chunk_max_size self._audio_chunk_max_size = audio_chunk_max_size
# 音频参数 # 音频参数
self._audio_config = audio_config if audio_config is not None else AudioBinary_Config() self._audio_config = (
audio_config if audio_config is not None else AudioBinary_Config()
)
# 结果队列 # 结果队列
self._result_queue = [] self._result_queue = []
# 聚合结果回调函数 # 聚合结果回调函数
self._aggregate_result_callback = result_callback self._aggregate_result_callback = result_callback
# 组件 # 组件
self._vad = VAD(VAD_model = models.get("vad"), audio_config = self._audio_config) self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
self._vad.set_callback(self.push_audio_chunk) self._vad.set_callback(self.push_audio_chunk)
logger.info("初始化LogicTrager") logger.info("初始化LogicTrager")
def push_binary_data(self, chunk: bytes) -> None: def push_binary_data(self, chunk: bytes) -> None:
@ -138,7 +145,11 @@ class LogicTrager(metaclass=AutoAfterMeta):
""" """
音频处理 音频处理
""" """
logger.info("LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(chunk.start_time, chunk.end_time, len(chunk.chunk))) logger.info(
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
chunk.start_time, chunk.end_time, len(chunk.chunk)
)
)
self._audio_chunk.append(chunk) self._audio_chunk.append(chunk)
def __after__push_audio_chunk(self) -> None: def __after__push_audio_chunk(self) -> None:

View File

@ -115,7 +115,7 @@ class ModelLoader:
self.models = {} self.models = {}
# 加载离线ASR模型 # 加载离线ASR模型
# 检查对应键是否存在 # 检查对应键是否存在
model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk'] model_list = ["asr", "asr_online", "vad", "punc", "spk"]
for model_name in model_list: for model_name in model_list:
name_model = f"{model_name}_model" name_model = f"{model_name}_model"
name_model_revision = f"{model_name}_model_revision" name_model_revision = f"{model_name}_model_revision"

View File

@ -1,3 +1,9 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result from .vad import VAD_Functor_result
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"]
__all__ = [
"AudioBinary_Config",
"AudioBinary_data_list",
"_AudioBinary_data",
"VAD_Functor_result",
]

View File

@ -8,8 +8,10 @@ logger = get_module_logger(__name__)
binary_data_types = (bytes, numpy.ndarray) binary_data_types = (bytes, numpy.ndarray)
class AudioBinary_Config(BaseModel): class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息""" """二进制音频块配置信息"""
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -37,14 +39,16 @@ class AudioBinary_Config(BaseModel):
""" """
return int(frame * 1000 / self.sample_rate) return int(frame * 1000 / self.sample_rate)
class _AudioBinary_data(BaseModel): class _AudioBinary_data(BaseModel):
"""音频数据""" """音频数据"""
binary_data: binary_data_types = Field(description="音频二进制数据", default=None) binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@validator('binary_data') @validator("binary_data")
def validate_binary_data(cls, v): def validate_binary_data(cls, v):
""" """
验证音频数据 验证音频数据
@ -54,7 +58,11 @@ class _AudioBinary_data(BaseModel):
binary_data_types: 音频数据 binary_data_types: 音频数据
""" """
if not isinstance(v, (bytes, numpy.ndarray)): if not isinstance(v, (bytes, numpy.ndarray)):
logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v)) logger.warning(
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
cls.__class__.__name__,
type(v),
)
return v return v
def __len__(self): def __len__(self):
@ -71,7 +79,11 @@ class _AudioBinary_data(BaseModel):
Args: Args:
binary_data: 音频数据 binary_data: 音频数据
""" """
logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data)) logger.debug(
"[%s]初始化音频数据, 数据类型为%s",
self.__class__.__name__,
type(binary_data),
)
super().__init__(binary_data=binary_data) super().__init__(binary_data=binary_data)
def __getitem__(self): def __getitem__(self):
@ -82,9 +94,13 @@ class _AudioBinary_data(BaseModel):
""" """
return self.binary_data return self.binary_data
class AudioBinary_data_list(BaseModel): class AudioBinary_data_list(BaseModel):
"""音频数据列表""" """音频数据列表"""
binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[])
binary_data_list: List[_AudioBinary_data] = Field(
description="音频数据列表", default=[]
)
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -118,6 +134,7 @@ class AudioBinary_data_list(BaseModel):
""" """
return len(self.binary_data_list) return len(self.binary_data_list)
# class AudioBinary_Slice(BaseModel): # class AudioBinary_Slice(BaseModel):
# """音频块切片""" # """音频块切片"""
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None) # target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)

View File

@ -2,23 +2,27 @@ from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable, Any from typing import List, Optional, Callable, Any
from .audio import AudioBinary_data_list, _AudioBinary_data from .audio import AudioBinary_data_list, _AudioBinary_data
class VAD_Functor_result(BaseModel): class VAD_Functor_result(BaseModel):
""" """
VADFunctor结果 VADFunctor结果
""" """
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表") audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
audiobinary_index: int = Field(description="音频数据索引") audiobinary_index: int = Field(description="音频数据索引")
audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data") audiobinary_data: _AudioBinary_data = Field(
description="音频数据, 指向AudioBinary_data"
)
start_time: int = Field(description="开始时间", is_required=True) start_time: int = Field(description="开始时间", is_required=True)
end_time: int = Field(description="结束时间", is_required=True) end_time: int = Field(description="结束时间", is_required=True)
@validator('audiobinary_data_list') @validator("audiobinary_data_list")
def validate_audiobinary_data_list(cls, v): def validate_audiobinary_data_list(cls, v):
if not isinstance(v, AudioBinary_data_list): if not isinstance(v, AudioBinary_data_list):
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型") raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
return v return v
@validator('audiobinary_index') @validator("audiobinary_index")
def validate_audiobinary_index(cls, v): def validate_audiobinary_index(cls, v):
if not isinstance(v, int): if not isinstance(v, int):
raise ValueError("audiobinary_index必须是int类型") raise ValueError("audiobinary_index必须是int类型")
@ -26,13 +30,13 @@ class VAD_Functor_result(BaseModel):
raise ValueError("audiobinary_index必须大于0") raise ValueError("audiobinary_index必须大于0")
return v return v
@validator('audiobinary_data') @validator("audiobinary_data")
def validate_audiobinary_data(cls, v): def validate_audiobinary_data(cls, v):
if not isinstance(v, _AudioBinary_data): if not isinstance(v, _AudioBinary_data):
raise ValueError("audiobinary_data必须是AudioBinary_data类型") raise ValueError("audiobinary_data必须是AudioBinary_data类型")
return v return v
@validator('start_time') @validator("start_time")
def validate_start_time(cls, v): def validate_start_time(cls, v):
if not isinstance(v, int): if not isinstance(v, int):
raise ValueError("start_time必须是int类型") raise ValueError("start_time必须是int类型")
@ -40,11 +44,11 @@ class VAD_Functor_result(BaseModel):
raise ValueError("start_time必须大于0") raise ValueError("start_time必须大于0")
return v return v
@validator('end_time') @validator("end_time")
def validate_end_time(cls, v, values): def validate_end_time(cls, v, values):
if not isinstance(v, int): if not isinstance(v, int):
raise ValueError("end_time必须是int类型") raise ValueError("end_time必须是int类型")
if 'start_time' in values and v <= values['start_time']: if "start_time" in values and v <= values["start_time"]:
raise ValueError("end_time必须大于start_time") raise ValueError("end_time必须大于start_time")
return v return v
@ -54,7 +58,7 @@ class VAD_Functor_result(BaseModel):
audiobinary_data_list: AudioBinary_data_list, audiobinary_data_list: AudioBinary_data_list,
data: Any, data: Any,
start_time: int, start_time: int,
end_time: int end_time: int,
): ):
""" """
创建VAD片段 创建VAD片段
@ -66,7 +70,8 @@ class VAD_Functor_result(BaseModel):
audiobinary_index=index, audiobinary_index=index,
audiobinary_data=audiobinary_data_list[index], audiobinary_data=audiobinary_data_list[index],
start_time=start_time, start_time=start_time,
end_time=end_time) end_time=end_time,
)
def __len__(self): def __len__(self):
""" """
@ -78,11 +83,9 @@ class VAD_Functor_result(BaseModel):
""" """
字符串展示内容 字符串展示内容
""" """
tostr = f'audiobinary_data_index: {self.audiobinary_index}\n' tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
tostr += f'start_time: {self.start_time}\n' tostr += f"start_time: {self.start_time}\n"
tostr += f'end_time: {self.end_time}\n' tostr += f"end_time: {self.end_time}\n"
tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n' tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n' tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
return tostr return tostr

View File

@ -7,11 +7,13 @@ import threading
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
class ASRPipeline(PipelineBase): class ASRPipeline(PipelineBase):
""" """
管道类 管道类
实现具体的处理逻辑 实现具体的处理逻辑
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """
初始化管道 初始化管道
@ -91,27 +93,24 @@ class ASRPipeline(PipelineBase):
""" """
try: try:
from src.functor import FunctorFactory from src.functor import FunctorFactory
# 加载VAD、asr、spk functor # 加载VAD、asr、spk functor
self._functor_dict["vad"] = FunctorFactory.make_functor( self._functor_dict["vad"] = FunctorFactory.make_functor(
functor_name = "vad", functor_name="vad", config=self._config, models=self._models
config = self._config,
models = self._models
) )
self._functor_dict["asr"] = FunctorFactory.make_functor( self._functor_dict["asr"] = FunctorFactory.make_functor(
functor_name = "asr", functor_name="asr", config=self._config, models=self._models
config = self._config,
models = self._models
) )
self._functor_dict["spk"] = FunctorFactory.make_functor( self._functor_dict["spk"] = FunctorFactory.make_functor(
functor_name = "spk", functor_name="spk", config=self._config, models=self._models
config = self._config,
models = self._models
) )
# 创建音频数据存储单元 # 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list() self._audio_binary_data_list = AudioBinary_data_list()
self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list) self._functor_dict["vad"].set_audio_binary_data_list(
self._audio_binary_data_list
)
# 初始化子队列 # 初始化子队列
self._subqueue_dict["original"] = Queue() self._subqueue_dict["original"] = Queue()
@ -134,13 +133,23 @@ class ASRPipeline(PipelineBase):
""" """
带回调函数的put 带回调函数的put
""" """
def put_with_check(data: Any) -> None: def put_with_check(data: Any) -> None:
queue.put(data) queue.put(data)
callback(data) callback(data)
return put_with_check return put_with_check
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) self._functor_dict["asr"].add_callback(
self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result)) put_with_checkcallback(
self._subqueue_dict["asrend"], self._check_result
)
)
self._functor_dict["spk"].add_callback(
put_with_checkcallback(
self._subqueue_dict["spkend"], self._check_result
)
)
except ImportError: except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
@ -150,14 +159,14 @@ class ASRPipeline(PipelineBase):
检查结果 检查结果
""" """
# 若asr和spk队列中都有数据则合并数据 # 若asr和spk队列中都有数据则合并数据
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize(): if (
self._subqueue_dict["asrend"].qsize()
& self._subqueue_dict["spkend"].qsize()
):
asr_data = self._subqueue_dict["asrend"].get() asr_data = self._subqueue_dict["asrend"].get()
spk_data = self._subqueue_dict["spkend"].get() spk_data = self._subqueue_dict["spkend"].get()
# 合并数据 # 合并数据
result = { result = {"asr_data": asr_data, "spk_data": spk_data}
"asr_data": asr_data,
"spk_data": spk_data
}
# 通知回调函数 # 通知回调函数
self._notify_callbacks(result) self._notify_callbacks(result)

View File

@ -8,11 +8,13 @@ import time
# 配置日志 # 配置日志
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PipelineBase(ABC): class PipelineBase(ABC):
""" """
管道基类 管道基类
定义了管道的基本接口和通用功能 定义了管道的基本接口和通用功能
""" """
def __init__(self, input_queue: Optional[Queue] = None): def __init__(self, input_queue: Optional[Queue] = None):
""" """
初始化管道 初始化管道
@ -115,14 +117,35 @@ class PipelineBase(ABC):
# 注意Python的线程无法被强制终止这里只是设置标志 # 注意Python的线程无法被强制终止这里只是设置标志
# 实际终止需要依赖操作系统的进程管理 # 实际终止需要依赖操作系统的进程管理
class PipelineFactory: class PipelineFactory:
""" """
管道工厂类 管道工厂类
用于创建管道实例 用于创建管道实例
""" """
@staticmethod
def create_pipeline(pipeline_name: str) -> Any: from src.pipeline.ASRpipeline import ASRPipeline
def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline:
"""
创建ASR管道实例
"""
from src.pipeline.ASRpipeline import ASRPipeline
pipeline = ASRPipeline()
pipeline.set_config(kwargs["config"])
pipeline.set_models(kwargs["models"])
pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"])
pipeline.add_callback(kwargs["callback"])
pipeline.bake()
return pipeline
@classmethod
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
""" """
创建管道实例 创建管道实例
""" """
pass if pipeline_name == "ASRpipeline":
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
else:
raise ValueError(f"不支持的管道类型: {pipeline_name}")

View File

@ -172,7 +172,7 @@ class STTRunner(RunnerBase):
logger.warning( logger.warning(
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理", "等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
self._stop_timeout, self._stop_timeout,
self._input_queue.qsize() self._input_queue.qsize(),
) )
success = False success = False
break break
@ -188,7 +188,7 @@ class STTRunner(RunnerBase):
"错误堆栈:\n%s", "错误堆栈:\n%s",
error_type, error_type,
error_msg, error_msg,
error_traceback error_traceback,
) )
success = False success = False
@ -198,7 +198,7 @@ class STTRunner(RunnerBase):
logger.warning( logger.warning(
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s", "部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(), self._input_queue.qsize(),
self._input_queue.empty() self._input_queue.empty(),
) )
return success return success

View File

@ -88,7 +88,9 @@ async def ws_serve(websocket, path):
# 更新各种配置参数 # 更新各种配置参数
if "is_speaking" in messagejson: if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"] websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking websocket.status_dict_asr_online["is_final"] = (
not websocket.is_speaking
)
if "chunk_interval" in messagejson: if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"] websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson: if "wav_name" in messagejson:
@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
chunk_size = messagejson["chunk_size"] chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str): if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",") chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size] websocket.status_dict_asr_online["chunk_size"] = [
int(x) for x in chunk_size
]
if "encoder_chunk_look_back" in messagejson: if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"] websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
messagejson["encoder_chunk_look_back"]
)
if "decoder_chunk_look_back" in messagejson: if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"] websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
messagejson["decoder_chunk_look_back"]
)
if "hotword" in messagejson: if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"] websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson: if "mode" in messagejson:
@ -111,11 +119,17 @@ async def ws_serve(websocket, path):
# 根据chunk_interval更新VAD的chunk_size # 根据chunk_interval更新VAD的chunk_size
websocket.status_dict_vad["chunk_size"] = int( websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
* 60
/ websocket.chunk_interval
) )
# 处理音频数据 # 处理音频数据
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str): if (
len(frames_asr_online) > 0
or len(frames_asr) >= 0
or not isinstance(message, str)
):
if not isinstance(message, str): # 二进制音频数据 if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区 # 添加到帧缓冲区
frames.append(message) frames.append(message)
@ -127,8 +141,10 @@ async def ws_serve(websocket, path):
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1 websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
# 达到chunk_interval或最终帧时处理在线ASR # 达到chunk_interval或最终帧时处理在线ASR
if (len(frames_asr_online) % websocket.chunk_interval == 0 or if (
websocket.status_dict_asr_online["is_final"]): len(frames_asr_online) % websocket.chunk_interval == 0
or websocket.status_dict_asr_online["is_final"]
):
if websocket.mode == "2pass" or websocket.mode == "online": if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online) audio_in = b"".join(frames_asr_online)
try: try:
@ -143,7 +159,9 @@ async def ws_serve(websocket, path):
# VAD处理 - 语音活动检测 # VAD处理 - 语音活动检测
try: try:
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message) speech_start_i, speech_end_i = await asr_service.async_vad(
websocket, message
)
except Exception as e: except Exception as e:
print(f"VAD处理错误: {e}") print(f"VAD处理错误: {e}")
@ -151,8 +169,12 @@ async def ws_serve(websocket, path):
if speech_start_i != -1: if speech_start_i != -1:
speech_start = True speech_start = True
# 计算开始偏移并收集前面的帧 # 计算开始偏移并收集前面的帧
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms beg_bias = (
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames websocket.vad_pre_idx - speech_start_i
) // duration_ms
frames_pre = (
frames[-beg_bias:] if beg_bias < len(frames) else frames
)
frames_asr = [] frames_asr = []
frames_asr.extend(frames_pre) frames_asr.extend(frames_pre)
@ -207,16 +229,16 @@ def start_server(args, asr_service_instance):
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile) ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
start_server = websockets.serve( start_server = websockets.serve(
ws_serve, args.host, args.port, ws_serve,
args.host,
args.port,
subprotocols=["binary"], subprotocols=["binary"],
ping_interval=None, ping_interval=None,
ssl=ssl_context ssl=ssl_context,
) )
else: else:
start_server = websockets.serve( start_server = websockets.serve(
ws_serve, args.host, args.port, ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
subprotocols=["binary"],
ping_interval=None
) )
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}") print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")

View File

@ -35,8 +35,7 @@ class ASRService:
""" """
# 使用VAD模型分析音频段 # 使用VAD模型分析音频段
segments_result = self.model_vad.generate( segments_result = self.model_vad.generate(
input=audio_in, input=audio_in, **websocket.status_dict_vad
**websocket.status_dict_vad
)[0]["value"] )[0]["value"]
speech_start = -1 speech_start = -1
@ -64,36 +63,38 @@ class ASRService:
if len(audio_in) > 0: if len(audio_in) > 0:
# 使用离线ASR模型处理音频 # 使用离线ASR模型处理音频
rec_result = self.model_asr.generate( rec_result = self.model_asr.generate(
input=audio_in, input=audio_in, **websocket.status_dict_asr
**websocket.status_dict_asr
)[0] )[0]
# 如果有标点符号模型且识别出文本,则添加标点 # 如果有标点符号模型且识别出文本,则添加标点
if self.model_punc is not None and len(rec_result["text"]) > 0: if self.model_punc is not None and len(rec_result["text"]) > 0:
rec_result = self.model_punc.generate( rec_result = self.model_punc.generate(
input=rec_result["text"], input=rec_result["text"], **websocket.status_dict_punc
**websocket.status_dict_punc
)[0] )[0]
# 如果识别出文本,发送到客户端 # 如果识别出文本,发送到客户端
if len(rec_result["text"]) > 0: if len(rec_result["text"]) > 0:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({ message = json.dumps(
"mode": mode, {
"text": rec_result["text"], "mode": mode,
"wav_name": websocket.wav_name, "text": rec_result["text"],
"is_final": websocket.is_speaking, "wav_name": websocket.wav_name,
}) "is_final": websocket.is_speaking,
}
)
await websocket.send(message) await websocket.send(message)
else: else:
# 如果没有音频数据,发送空文本 # 如果没有音频数据,发送空文本
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({ message = json.dumps(
"mode": mode, {
"text": "", "mode": mode,
"wav_name": websocket.wav_name, "text": "",
"is_final": websocket.is_speaking, "wav_name": websocket.wav_name,
}) "is_final": websocket.is_speaking,
}
)
await websocket.send(message) await websocket.send(message)
async def async_asr_online(self, websocket, audio_in): async def async_asr_online(self, websocket, audio_in):
@ -107,21 +108,24 @@ class ASRService:
if len(audio_in) > 0: if len(audio_in) > 0:
# 使用在线ASR模型处理音频 # 使用在线ASR模型处理音频
rec_result = self.model_asr_streaming.generate( rec_result = self.model_asr_streaming.generate(
input=audio_in, input=audio_in, **websocket.status_dict_asr_online
**websocket.status_dict_asr_online
)[0] )[0]
# 在2pass模式下如果是最终帧则跳过(留给离线ASR处理) # 在2pass模式下如果是最终帧则跳过(留给离线ASR处理)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False): if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
"is_final", False
):
return return
# 如果识别出文本,发送到客户端 # 如果识别出文本,发送到客户端
if len(rec_result["text"]): if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({ message = json.dumps(
"mode": mode, {
"text": rec_result["text"], "mode": mode,
"wav_name": websocket.wav_name, "text": rec_result["text"],
"is_final": websocket.is_speaking, "wav_name": websocket.wav_name,
}) "is_final": websocket.is_speaking,
}
)
await websocket.send(message) await websocket.send(message)

View File

@ -1,10 +1,12 @@
""" """
处理各类音频数据与bytes的转换 处理各类音频数据与bytes的转换
""" """
import wave import wave
from pydub import AudioSegment from pydub import AudioSegment
import io import io
def wav_to_bytes(wav_path: str) -> bytes: def wav_to_bytes(wav_path: str) -> bytes:
""" """
将WAV文件读取为bytes数据 将WAV文件读取为bytes数据
@ -20,18 +22,20 @@ def wav_to_bytes(wav_path: str) -> bytes:
wave.Error: 如果文件不是有效的WAV文件 wave.Error: 如果文件不是有效的WAV文件
""" """
try: try:
with wave.open(wav_path, 'rb') as wf: with wave.open(wav_path, "rb") as wf:
# 读取所有帧 # 读取所有帧
frames = wf.readframes(wf.getnframes()) frames = wf.readframes(wf.getnframes())
return frames return frames
except FileNotFoundError: except FileNotFoundError:
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出 # 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
raise FileNotFoundError(f"错误未找到WAV文件 '{wav_path}'") raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
except wave.Error as e: except wave.Error as e:
raise wave.Error(f"错误打开或读取WAV文件 '{wav_path}' 失败 - {e}") raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int): def bytes_to_wav(
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
):
""" """
将bytes数据写入为WAV文件 将bytes数据写入为WAV文件
@ -46,17 +50,18 @@ def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: in
wave.Error: 如果写入WAV文件失败 wave.Error: 如果写入WAV文件失败
""" """
try: try:
with wave.open(wav_path, 'wb') as wf: with wave.open(wav_path, "wb") as wf:
wf.setnchannels(nchannels) wf.setnchannels(nchannels)
wf.setsampwidth(sampwidth) wf.setsampwidth(sampwidth)
wf.setframerate(framerate) wf.setframerate(framerate)
wf.writeframes(bytes_data) wf.writeframes(bytes_data)
except wave.Error as e: except wave.Error as e:
raise wave.Error(f"错误写入WAV文件 '{wav_path}' 失败 - {e}") raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
except Exception as e: except Exception as e:
# 捕获其他可能的写入错误 # 捕获其他可能的写入错误
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}") raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
def mp3_to_bytes(mp3_path: str) -> bytes: def mp3_to_bytes(mp3_path: str) -> bytes:
""" """
将MP3文件转换为bytes数据 (原始PCM数据) 将MP3文件转换为bytes数据 (原始PCM数据)
@ -76,12 +81,19 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
# 获取原始PCM数据 # 获取原始PCM数据
return audio.raw_data return audio.raw_data
except FileNotFoundError: except FileNotFoundError:
raise FileNotFoundError(f"错误未找到MP3文件 '{mp3_path}'") raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
except Exception as e: # pydub 可能抛出多种解码相关的错误 except Exception as e: # pydub 可能抛出多种解码相关的错误
raise Exception(f"错误处理MP3文件 '{mp3_path}' 失败 - {e}") raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"): def bytes_to_mp3(
bytes_data: bytes,
mp3_path: str,
frame_rate: int,
channels: int,
sample_width: int,
bitrate: str = "192k",
):
""" """
将原始PCM bytes数据转换为MP3文件 将原始PCM bytes数据转换为MP3文件
@ -102,9 +114,9 @@ def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: in
data=bytes_data, data=bytes_data,
sample_width=sample_width, sample_width=sample_width,
frame_rate=frame_rate, frame_rate=frame_rate,
channels=channels channels=channels,
) )
# 导出为MP3 # 导出为MP3
audio.export(mp3_path, format="mp3", bitrate=bitrate) audio.export(mp3_path, format="mp3", bitrate=bitrate)
except Exception as e: except Exception as e:
raise Exception(f"错误转换或写入MP3文件 '{mp3_path}' 失败 - {e}") raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")

View File

@ -3,6 +3,7 @@ import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
def setup_logger( def setup_logger(
name: str = None, name: str = None,
level: str = "INFO", level: str = "INFO",
@ -45,17 +46,15 @@ def setup_logger(
log_path = Path(log_file) log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True) log_path.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
logger.addHandler(file_handler) logger.addHandler(file_handler)
# 注意:移除了 propagate = False允许日志传递 # 注意:移除了 propagate = False允许日志传递
return logger return logger
def setup_root_logger(
level: str = "INFO", def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
log_file: Optional[str] = None
) -> None:
""" """
配置根日志器 配置根日志器
@ -65,10 +64,11 @@ def setup_root_logger(
""" """
setup_logger(None, level, log_file) setup_logger(None, level, log_file)
def get_module_logger( def get_module_logger(
module_name: str, module_name: str,
level: Optional[str] = None, # 改为可选参数 level: Optional[str] = None, # 改为可选参数
log_file: Optional[str] = None # 一般不需要单独指定 log_file: Optional[str] = None, # 一般不需要单独指定
) -> logging.Logger: ) -> logging.Logger:
""" """
获取模块级别的logger 获取模块级别的logger

View File

@ -1,10 +1,15 @@
from tests.functor.vad_test import test_vad_functor """
测试主函数
请在tests目录下创建测试文件, 并在此文件中调用
"""
from tests.pipeline.asr_test import test_asr_pipeline from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger from src.utils.logger import get_module_logger, setup_root_logger
setup_root_logger(level="INFO", log_file="logs/test_main.log") setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
# from tests.functor.vad_test import test_vad_functor
# logger.info("开始测试VAD函数器") # logger.info("开始测试VAD函数器")
# test_vad_functor() # test_vad_functor()

View File

@ -2,6 +2,7 @@
Functor测试 Functor测试
VAD测试 VAD测试
""" """
from src.functor.vad_functor import VADFunctor from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor from src.functor.spk_functor import SPKFunctor
@ -21,6 +22,7 @@ logger = get_module_logger(__name__)
model_loader = ModelLoader() model_loader = ModelLoader()
def test_vad_functor(): def test_vad_functor():
# 加载模型 # 加载模型
args = { args = {
@ -38,9 +40,9 @@ def test_vad_functor():
chunk_stride=1600, chunk_stride=1600,
sample_rate=sample_rate, sample_rate=sample_rate,
sample_width=16, sample_width=16,
channels=1 channels=1,
) )
chunk_stride = int(audio_config.chunk_size*sample_rate/1000) chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride audio_config.chunk_stride = chunk_stride
# 创建输入队列 # 创建输入队列
input_queue = Queue() input_queue = Queue()
@ -62,9 +64,7 @@ def test_vad_functor():
vad_functor.add_callback(lambda x: vad2asr_queue.put(x)) vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
vad_functor.add_callback(lambda x: vad2spk_queue.put(x)) vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
# 设置模型 # 设置模型
vad_functor.set_model({ vad_functor.set_model({"vad": model_loader.models["vad"]})
'vad': model_loader.models['vad']
})
# 启动VAD函数器 # 启动VAD函数器
vad_functor.run() vad_functor.run()
@ -77,9 +77,7 @@ def test_vad_functor():
# 设置回调函数 # 设置回调函数
asr_functor.add_callback(lambda x: print(f"asr callback: {x}")) asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
# 设置模型 # 设置模型
asr_functor.set_model({ asr_functor.set_model({"asr": model_loader.models["asr"]})
'asr': model_loader.models['asr']
})
# 启动ASR函数器 # 启动ASR函数器
asr_functor.run() asr_functor.run()
@ -92,23 +90,25 @@ def test_vad_functor():
# 设置回调函数 # 设置回调函数
spk_functor.add_callback(lambda x: print(f"spk callback: {x}")) spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
# 设置模型 # 设置模型
spk_functor.set_model({ spk_functor.set_model(
# 'spk': model_loader.models['spk'] {
'spk': 'fake_spk' # 'spk': model_loader.models['spk']
}) "spk": "fake_spk"
}
)
# 启动SPK函数器 # 启动SPK函数器
spk_functor.run() spk_functor.run()
f_binary = f_data f_binary = f_data
audio_clip_len = 200 audio_clip_len = 200
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}") print(
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
)
for i in range(0, len(f_binary), audio_clip_len): for i in range(0, len(f_binary), audio_clip_len):
binary_data = f_binary[i:i+audio_clip_len] binary_data = f_binary[i : i + audio_clip_len]
input_queue.put(binary_data) input_queue.put(binary_data)
# 等待VAD函数器结束 # 等待VAD函数器结束
vad_functor.stop() vad_functor.stop()
print("[vad_test] VAD函数器结束") print("[vad_test] VAD函数器结束")
@ -119,4 +119,6 @@ def test_vad_functor():
if OVERWATCH: if OVERWATCH:
for index in range(len(audio_binary_data_list)): for index in range(len(audio_binary_data_list)):
save_path = f"tests/vad_test_output_{index}.wav" save_path = f"tests/vad_test_output_{index}.wav"
soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate) soundfile.write(
save_path, audio_binary_data_list[index].binary_data, sample_rate
)

View File

@ -1,10 +1,21 @@
"""
模型使用测试
此处主要用于各类调用模型的处理数据与输出格式
请在主目录下test_main.py中调用
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配
"""
from funasr import AutoModel from funasr import AutoModel
from typing import List, Dict, Any from typing import List, Dict, Any
from src.models import VADResponse from src.models import VADResponse
import time import time
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]: def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
chunk_size = 100 # ms """
在线VAD模型使用
"""
chunk_size = 100 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True) model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
vad_result = VADResponse() vad_result = VADResponse()
@ -16,12 +27,14 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
chunk_stride = int(chunk_size * sample_rate / 1000) chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {} cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1) total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num): for i in range(total_chunk_num):
time.sleep(0.1) time.sleep(0.1)
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1 is_final = i == total_chunk_num - 1
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size) res = model.generate(
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
)
if len(res[0]["value"]): if len(res[0]["value"]):
vad_result += VADResponse.from_raw(res) vad_result += VADResponse.from_raw(res)
for item in res[0]["value"]: for item in res[0]["value"]:
@ -32,44 +45,64 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
# print(item) # print(item)
return vad_result return vad_result
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]: def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
测试LogicTrager
在Rebuild版本后LogicTrager中已弃用
"""
from src.logic_trager import LogicTrager from src.logic_trager import LogicTrager
import soundfile import soundfile
from src.config import parse_args from src.config import parse_args
args = parse_args() args = parse_args()
# from src.functor.model_loader import load_models # from src.functor.model_loader import load_models
# models = load_models(args) # models = load_models(args)
from src.model_loader import ModelLoader from src.model_loader import ModelLoader
models = ModelLoader(args) models = ModelLoader(args)
chunk_size = 200 # ms chunk_size = 200 # ms
from src.models import AudioBinary_Config from src.models import AudioBinary_Config
import soundfile import soundfile
speech, sample_rate = soundfile.read(file_path) speech, sample_rate = soundfile.read(file_path)
chunk_stride = int(chunk_size * sample_rate / 1000) chunk_stride = int(chunk_size * sample_rate / 1000)
audio_config = AudioBinary_Config(sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size) audio_config = AudioBinary_Config(
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
)
logic_trager = LogicTrager(models=models, audio_config=audio_config) logic_trager = LogicTrager(models=models, audio_config=audio_config)
for i in range(len(speech)//chunk_stride+1): for i in range(len(speech) // chunk_stride + 1):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
logic_trager.push_binary_data(speech_chunk) logic_trager.push_binary_data(speech_chunk)
# for item in items: # for item in items:
# print(item) # print(item)
return None return None
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]: def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
"""
ASR模型使用
离线ASR模型使用
"""
from funasr import AutoModel from funasr import AutoModel
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
vad_model="fsmn-vad", vad_model_revision="v2.0.4", model = AutoModel(
# punc_model="ct-punc-c", punc_model_revision="v2.0.4", model="paraformer-zh",
spk_model="cam++", spk_model_revision="v2.0.2", model_revision="v2.0.4",
spk_mode="vad_segment", vad_model="fsmn-vad",
auto_update=False, vad_model_revision="v2.0.4",
) # punc_model="ct-punc-c", punc_model_revision="v2.0.4",
spk_model="cam++",
spk_model_revision="v2.0.2",
spk_mode="vad_segment",
auto_update=False,
)
import soundfile import soundfile
@ -80,7 +113,9 @@ def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
result = model.generate(speech) result = model.generate(speech)
return result return result
if __name__ == "__main__":
# vad_result = vad_model_use_online("tests/vad_example.wav") # if __name__ == "__main__":
vad_result = vad_model_use_online_logic("tests/vad_example.wav") # 请在主目录下调用test_main.py文件进行测试
# print(vad_result) # vad_result = vad_model_use_online("tests/vad_example.wav")
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# print(vad_result)

View File

@ -2,7 +2,9 @@
Pipeline测试 Pipeline测试
VAD+ASR+SPK(FAKE) VAD+ASR+SPK(FAKE)
""" """
from src.pipeline.ASRpipeline import ASRPipeline from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader from src.model_loader import ModelLoader
from queue import Queue from queue import Queue
@ -18,6 +20,7 @@ OVAERWATCH = False
model_loader = ModelLoader() model_loader = ModelLoader()
def test_asr_pipeline(): def test_asr_pipeline():
# 加载模型 # 加载模型
args = { args = {
@ -36,9 +39,9 @@ def test_asr_pipeline():
chunk_stride=1600, chunk_stride=1600,
sample_rate=sample_rate, sample_rate=sample_rate,
sample_width=16, sample_width=16,
channels=1 channels=1,
) )
chunk_stride = int(audio_config.chunk_size*sample_rate/1000) chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride audio_config.chunk_stride = chunk_stride
# 创建参数Dict # 创建参数Dict
@ -52,21 +55,32 @@ def test_asr_pipeline():
input_queue = Queue() input_queue = Queue()
# 创建Pipeline # 创建Pipeline
asr_pipeline = ASRPipeline() # asr_pipeline = ASRPipeline()
asr_pipeline.set_models(models) # asr_pipeline.set_models(models)
asr_pipeline.set_config(config) # asr_pipeline.set_config(config)
asr_pipeline.set_audio_binary(audio_binary_data_list) # asr_pipeline.set_audio_binary(audio_binary_data_list)
asr_pipeline.set_input_queue(input_queue) # asr_pipeline.set_input_queue(input_queue)
asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}")) # asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
asr_pipeline.bake() # asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name = "ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)
# 运行Pipeline # 运行Pipeline
asr_instance = asr_pipeline.run() asr_instance = asr_pipeline.run()
audio_clip_len = 200 audio_clip_len = 200
print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}") print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len): for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i:i+audio_clip_len]) input_queue.put(audio_data[i : i + audio_clip_len])
# time.sleep(10) # time.sleep(10)
# input_queue.put(None) # input_queue.put(None)
@ -77,4 +91,3 @@ def test_asr_pipeline():
time.sleep(5) time.sleep(5)
asr_pipeline.stop() asr_pipeline.stop()
# asr_pipeline.stop() # asr_pipeline.stop()

View File

@ -10,13 +10,13 @@ import os
from unittest.mock import patch from unittest.mock import patch
# 将src目录添加到路径 # 将src目录添加到路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.config import parse_args from src.config import parse_args
def test_default_args(): def test_default_args():
"""测试默认参数值""" """测试默认参数值"""
with patch('sys.argv', ['script.py']): with patch("sys.argv", ["script.py"]):
args = parse_args() args = parse_args()
# 检查服务器参数 # 检查服务器参数
@ -46,17 +46,24 @@ def test_default_args():
def test_custom_args(): def test_custom_args():
"""测试自定义参数值""" """测试自定义参数值"""
test_args = [ test_args = [
'script.py', "script.py",
'--host', 'localhost', "--host",
'--port', '8080', "localhost",
'--certfile', 'cert.pem', "--port",
'--keyfile', 'key.pem', "8080",
'--asr_model', 'custom_model', "--certfile",
'--ngpu', '0', "cert.pem",
'--device', 'cpu' "--keyfile",
"key.pem",
"--asr_model",
"custom_model",
"--ngpu",
"0",
"--device",
"cpu",
] ]
with patch('sys.argv', test_args): with patch("sys.argv", test_args):
args = parse_args() args = parse_args()
# 检查自定义参数 # 检查自定义参数