[代码结构]black . 对所有文件格式调整,无功能变化。

This commit is contained in:
Ziyang.Zhang 2025-06-12 15:49:43 +08:00
parent 5b94c40016
commit 5a820b49e4
28 changed files with 543 additions and 421 deletions

22
main.py
View File

@ -1,11 +1,7 @@
from funasr import AutoModel
chunk_size = 200 # ms
model = AutoModel(
model="fsmn-vad",
model_revision="v2.0.4",
disable_update=True
)
chunk_size = 200 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
import soundfile
@ -14,16 +10,16 @@ speech, sample_rate = soundfile.read(wav_file)
chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(
input=speech_chunk,
cache=cache,
is_final=is_final,
input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
disable_pbar=True
disable_pbar=True,
)
if len(res[0]["value"]):
print(res)
@ -31,4 +27,4 @@ for i in range(total_chunk_num):
print(f"len(speech): {len(speech)}")
print(f"len(speech_chunk): {len(speech_chunk)}")
print(f"total_chunk_num: {total_chunk_num}")
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")

View File

@ -11,4 +11,4 @@ FunASR WebSocket服务
- 支持多种识别模式(2pass/online/offline)
"""
__version__ = "0.1.0"
__version__ = "0.1.0"

View File

@ -46,7 +46,7 @@ class AudioBinary:
else:
raise ValueError("参数类型错误")
def add_slice_listener(self, slice_listener: callable):
def add_slice_listener(self, slice_listener: callable) -> None:
"""
添加切片监听器
参数:
@ -98,10 +98,10 @@ class AudioBinary:
self._binary_data_list.rewrite(target_index, binary_data)
def get_binary_data(
self,
start: int = 0,
end: Optional[int] = None,
) -> Optional[bytes]:
self,
start: int = 0,
end: Optional[int] = None,
) -> Optional[bytes]:
"""
获取指定索引的音频数据块
参数:
@ -128,7 +128,7 @@ class AudioChunk:
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑
"""
_instance = None
_instance: Optional[AudioChunk] = None
def __new__(cls, *args, **kwargs):
"""
@ -138,7 +138,7 @@ class AudioChunk:
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
def __init__(self) -> None:
"""
初始化AudioChunk实例
"""
@ -146,10 +146,10 @@ class AudioChunk:
self._slice_listener: List[callable] = []
def get_audio_binary(
self,
binary_name: Optional[str] = None,
audio_config: Optional[AudioBinary_Config] = None,
) -> AudioBinary:
self,
binary_name: Optional[str] = None,
audio_config: Optional[AudioBinary_Config] = None,
) -> AudioBinary:
"""
获取音频数据块
参数:

View File

@ -10,116 +10,79 @@ import argparse
def parse_args(argv=None):
    """
    Parse the command-line options of the FunASR WebSocket server.

    Args:
        argv: Optional list of argument strings to parse. When left as
            ``None`` argparse reads ``sys.argv[1:]``, which is what every
            existing zero-argument call gets.

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")

    # Server endpoint
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="服务器主机地址例如localhost, 0.0.0.0",
    )
    parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")

    # TLS material; both default to the empty string
    parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
    parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")

    # Offline ASR model
    parser.add_argument(
        "--asr_model",
        type=str,
        default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        help="离线ASR模型从ModelScope获取",
    )
    parser.add_argument(
        "--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
    )

    # Online (streaming) ASR model
    parser.add_argument(
        "--asr_model_online",
        type=str,
        default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
        help="在线ASR模型从ModelScope获取",
    )
    parser.add_argument(
        "--asr_model_online_revision",
        type=str,
        default="v2.0.4",
        help="在线ASR模型版本",
    )

    # VAD model
    parser.add_argument(
        "--vad_model",
        type=str,
        default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        help="VAD语音活动检测模型从ModelScope获取",
    )
    parser.add_argument(
        "--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
    )

    # Punctuation model
    parser.add_argument(
        "--punc_model",
        type=str,
        default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
        help="标点符号模型从ModelScope获取",
    )
    parser.add_argument(
        "--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
    )

    # Hardware
    parser.add_argument("--ngpu", type=int, default=1, help="GPU数量0表示仅使用CPU")
    parser.add_argument(
        "--device", type=str, default="cuda", help="设备类型cuda或cpu"
    )
    parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")

    return parser.parse_args(argv)
@ -127,4 +90,4 @@ if __name__ == "__main__":
args = parse_args()
print("配置参数:")
for arg in vars(args):
print(f" {arg}: {getattr(args, arg)}")
print(f" {arg}: {getattr(args, arg)}")

View File

@ -1,4 +1,4 @@
from .vad_functor import VADFunctor
from .base import FunctorFactory
__all__ = ["VADFunctor", "FunctorFactory"]
__all__ = ["VADFunctor", "FunctorFactory"]

View File

@ -2,8 +2,9 @@
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
from queue import Queue, Empty
import threading
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ASRFunctor(BaseFunctor):
"""
ASRFunctor
@ -51,26 +53,26 @@ class ASRFunctor(BaseFunctor):
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
pass
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
"""
self._audio_config = audio_config
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
def add_callback(self, callback: Callable) -> None:
"""
向自身的_callback: List[Callable]回调函数列表中添加回调函数
@ -78,12 +80,12 @@ class ASRFunctor(BaseFunctor):
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None:
"""
回调函数
"""
text = result[0]['text'].replace(" ", "")
text = result[0]["text"].replace(" ", "")
for callback in self._callback:
callback(text)
@ -98,7 +100,7 @@ class ASRFunctor(BaseFunctor):
hotwords=self._hotwords,
)
self._do_callback(result)
def _run(self) -> None:
"""
线程运行逻辑
@ -132,7 +134,7 @@ class ASRFunctor(BaseFunctor):
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
预检查
@ -146,7 +148,7 @@ class ASRFunctor(BaseFunctor):
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self) -> bool:
"""
停止线程
@ -157,5 +159,3 @@ class ASRFunctor(BaseFunctor):
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

View File

@ -4,18 +4,20 @@ Functor基础模块
该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类
基类提供了数据处理的基本框架,包括:
- 回调函数管理
- 模型配置管理
- 模型配置管理
- 线程运行控制
主要类:
BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类
"""
from abc import ABC, abstractmethod
from typing import Callable, List
from queue import Queue
import threading
class BaseFunctor(ABC):
"""
Functor抽象类
@ -27,9 +29,7 @@ class BaseFunctor(ABC):
_model (dict): 存储模型相关的配置和实例
"""
def __init__(
self
):
def __init__(self):
"""
初始化函数器
@ -38,8 +38,8 @@ class BaseFunctor(ABC):
model (dict): 模型相关的配置和实例
"""
self._callback: List[Callable] = []
self._model: dict = {}
# flag
self._model: dict = {}
# flag
self._is_running: bool = False
self._stop_event: bool = False
# 状态锁
@ -91,7 +91,7 @@ class BaseFunctor(ABC):
返回:
线程实例
"""
@abstractmethod
def _pre_check(self):
"""
@ -111,13 +111,12 @@ class BaseFunctor(ABC):
"""
class FunctorFactory:
"""
Functor工厂类
该工厂类负责创建和配置Functor实例
主要方法:
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
创建并配置Functor实例
@ -138,58 +137,55 @@ class FunctorFactory:
"""
if functor_name == "vad":
return cls._make_vadfunctor(config = config,models = models)
return cls._make_vadfunctor(config=config, models=models)
elif functor_name == "asr":
return cls._make_asrfunctor(config = config,models = models)
return cls._make_asrfunctor(config=config, models=models)
elif functor_name == "spk":
return cls._make_spkfunctor(config = config,models = models)
return cls._make_spkfunctor(config=config, models=models)
else:
raise ValueError(f"不支持的Functor类型: {functor_name}")
@staticmethod
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build a VADFunctor and hand it its audio configuration and model.

    Declared as a static method because it takes no ``self``/``cls``;
    without the decorator it only works when reached through the class.

    Args:
        config: mapping that must contain the key ``"audio_config"``.
        models: mapping that must contain the key ``"vad"``.

    Returns:
        The configured VADFunctor instance.

    Raises:
        KeyError: if either required key is missing.
    """
    # Imported at call time, as in the original code.
    from src.functor.vad_functor import VADFunctor

    vad_functor = VADFunctor()
    vad_functor.set_audio_config(config["audio_config"])
    # The functor receives only its own model, wrapped under the same key.
    vad_functor.set_model({"vad": models["vad"]})
    return vad_functor
@staticmethod
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build an ASRFunctor and hand it its audio configuration and model.

    Declared as a static method because it takes no ``self``/``cls``;
    without the decorator it only works when reached through the class.

    Args:
        config: mapping that must contain the key ``"audio_config"``.
        models: mapping that must contain the key ``"asr"``.

    Returns:
        The configured ASRFunctor instance.

    Raises:
        KeyError: if either required key is missing.
    """
    # Imported at call time, as in the original code.
    from src.functor.asr_functor import ASRFunctor

    asr_functor = ASRFunctor()
    asr_functor.set_audio_config(config["audio_config"])
    # The functor receives only its own model, wrapped under the same key.
    asr_functor.set_model({"asr": models["asr"]})
    return asr_functor
@staticmethod
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
    """
    Build an SPKFunctor and hand it its audio configuration and model.

    Declared as a static method because it takes no ``self``/``cls``;
    without the decorator it only works when reached through the class.

    Args:
        config: mapping that must contain the key ``"audio_config"``.
        models: mapping that must contain the key ``"spk"``.

    Returns:
        The configured SPKFunctor instance.

    Raises:
        KeyError: if either required key is missing.
    """
    # Imported at call time, as in the original code.
    from src.functor.spk_functor import SPKFunctor

    spk_functor = SPKFunctor()
    spk_functor.set_audio_config(config["audio_config"])
    # The functor receives only its own model, wrapped under the same key.
    spk_functor.set_model({"spk": models["spk"]})
    return spk_functor

View File

@ -2,6 +2,7 @@
SpkFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class SPKFunctor(BaseFunctor):
"""
SPKFunctor
@ -33,25 +35,24 @@ class SPKFunctor(BaseFunctor):
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
"""
pass
def set_input_queue(self, queue: Queue) -> None:
"""
设置监听的输入消息队列
"""
self._input_queue = queue
def set_model(self, model: dict) -> None:
"""
设置推理模型
"""
self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
"""
设置音频配置
@ -66,7 +67,7 @@ class SPKFunctor(BaseFunctor):
if not isinstance(self._callback, list):
self._callback = []
self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None:
"""
回调函数
@ -83,9 +84,9 @@ class SPKFunctor(BaseFunctor):
# input=binary_data,
# chunk_size=self._audio_config.chunk_size,
# )
result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}]
result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
self._do_callback(result)
def _run(self) -> None:
"""
线程运行逻辑
@ -108,7 +109,7 @@ class SPKFunctor(BaseFunctor):
except Exception as e:
logger.error("SpkFunctor运行时发生错误: %s", e)
raise e
def run(self) -> threading.Thread:
"""
启动线程
@ -119,7 +120,7 @@ class SPKFunctor(BaseFunctor):
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
return self._thread
def _pre_check(self) -> bool:
"""
预检查
@ -131,7 +132,7 @@ class SPKFunctor(BaseFunctor):
if self._callback is None:
raise ValueError("回调函数未设置")
return True
def stop(self) -> bool:
"""
停止线程
@ -142,5 +143,3 @@ class SPKFunctor(BaseFunctor):
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

View File

@ -2,6 +2,7 @@
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
"""
import threading
from queue import Empty, Queue
from typing import List, Any, Callable
@ -105,9 +106,7 @@ class VADFunctor(BaseFunctor):
self._callback = []
self._callback.append(callback)
def _do_callback(
self, result: List[List[int]]
) -> None:
def _do_callback(self, result: List[List[int]]) -> None:
"""
回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice

View File

@ -6,34 +6,36 @@
from src.utils.logger import get_module_logger
from typing import Any, Dict, Type, Callable
# 配置日志
logger = get_module_logger(__name__, level="INFO")
class AutoAfterMeta(type):
"""
自动调用__after__函数的元类
实现单例模式
"""
_instances: Dict[Type, Any] = {} # 存储单例实例
def __new__(cls, name, bases, attrs):
# 遍历所有属性
for attr_name, attr_value in attrs.items():
# 如果是函数且不是以_开头
if callable(attr_value) and not attr_name.startswith('__'):
if callable(attr_value) and not attr_name.startswith("__"):
# 获取原函数
original_func = attr_value
# 创建包装函数
def make_wrapper(func):
def wrapper(self, *args, **kwargs):
# 执行原函数
result = func(self, *args, **kwargs)
# 构建_after_函数名
after_func_name = f"__after__{func.__name__}"
# 检查是否存在对应的_after_函数
if hasattr(self, after_func_name):
after_func = getattr(self, after_func_name)
@ -43,17 +45,18 @@ class AutoAfterMeta(type):
after_func()
except Exception as e:
logger.error(f"调用{after_func_name}时出错: {e}")
return result
return wrapper
# 替换原函数
attrs[attr_name] = make_wrapper(original_func)
# 创建类
new_class = super().__new__(cls, name, bases, attrs)
return new_class
def __call__(cls, *args, **kwargs):
"""
重写__call__方法实现单例模式
@ -65,9 +68,10 @@ class AutoAfterMeta(type):
logger.info(f"创建{cls.__name__}的新实例")
else:
logger.debug(f"返回{cls.__name__}的现有实例")
return cls._instances[cls]
"""
整体识别的处理逻辑
1.压入二进制音频信息
@ -88,10 +92,12 @@ from src.models import AudioBinary_Config
from src.models import AudioBinary_Chunk
from typing import List
class LogicTrager(metaclass=AutoAfterMeta):
"""逻辑触发器类"""
def __init__(self,
def __init__(
self,
audio_chunk_max_size: int = 1024 * 1024 * 10,
audio_config: AudioBinary_Config = None,
result_callback: Callable = None,
@ -99,46 +105,51 @@ class LogicTrager(metaclass=AutoAfterMeta):
):
"""初始化"""
# 存储音频块
self._audio_chunk : List[AudioBinary_Chunk] = []
self._audio_chunk: List[AudioBinary_Chunk] = []
# 存储二进制数据
self._audio_chunk_binary = b''
self._audio_chunk_binary = b""
self._audio_chunk_max_size = audio_chunk_max_size
# 音频参数
self._audio_config = audio_config if audio_config is not None else AudioBinary_Config()
self._audio_config = (
audio_config if audio_config is not None else AudioBinary_Config()
)
# 结果队列
self._result_queue = []
# 聚合结果回调函数
self._aggregate_result_callback = result_callback
# 组件
self._vad = VAD(VAD_model = models.get("vad"), audio_config = self._audio_config)
self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
self._vad.set_callback(self.push_audio_chunk)
logger.info("初始化LogicTrager")
def push_binary_data(self, chunk: bytes) -> None:
"""
压入音频块至VAD模块
参数:
chunk: 音频数据块
"""
# print("LogicTrager push_binary_data", len(chunk))
self._vad.push_binary_data(chunk)
self.__after__push_binary_data()
def __after__push_binary_data(self) -> None:
"""
添加音频块后处理
"""
# print("LogicTrager __after__push_binary_data")
self._vad.process_vad_result()
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
"""
音频处理
"""
logger.info("LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(chunk.start_time, chunk.end_time, len(chunk.chunk)))
logger.info(
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
chunk.start_time, chunk.end_time, len(chunk.chunk)
)
)
self._audio_chunk.append(chunk)
def __after__push_audio_chunk(self) -> None:
@ -162,4 +173,4 @@ class LogicTrager(metaclass=AutoAfterMeta):
def __call__(self):
"""调用函数"""
pass
pass

View File

@ -115,7 +115,7 @@ class ModelLoader:
self.models = {}
# 加载离线ASR模型
# 检查对应键是否存在
model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk']
model_list = ["asr", "asr_online", "vad", "punc", "spk"]
for model_name in model_list:
name_model = f"{model_name}_model"
name_model_revision = f"{model_name}_model_revision"

View File

@ -1,3 +1,9 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"]
__all__ = [
"AudioBinary_Config",
"AudioBinary_data_list",
"_AudioBinary_data",
"VAD_Functor_result",
]

View File

@ -8,8 +8,10 @@ logger = get_module_logger(__name__)
binary_data_types = (bytes, numpy.ndarray)
class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息"""
class Config:
arbitrary_types_allowed = True
@ -37,14 +39,16 @@ class AudioBinary_Config(BaseModel):
"""
return int(frame * 1000 / self.sample_rate)
class _AudioBinary_data(BaseModel):
"""音频数据"""
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
class Config:
arbitrary_types_allowed = True
@validator('binary_data')
@validator("binary_data")
def validate_binary_data(cls, v):
"""
验证音频数据
@ -54,7 +58,11 @@ class _AudioBinary_data(BaseModel):
binary_data_types: 音频数据
"""
if not isinstance(v, (bytes, numpy.ndarray)):
logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v))
logger.warning(
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
cls.__class__.__name__,
type(v),
)
return v
def __len__(self):
@ -64,14 +72,18 @@ class _AudioBinary_data(BaseModel):
int: 音频数据长度
"""
return len(self.binary_data)
def __init__(self, binary_data: binary_data_types):
"""
初始化音频数据
Args:
binary_data: 音频数据
"""
logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data))
logger.debug(
"[%s]初始化音频数据, 数据类型为%s",
self.__class__.__name__,
type(binary_data),
)
super().__init__(binary_data=binary_data)
def __getitem__(self):
@ -82,9 +94,13 @@ class _AudioBinary_data(BaseModel):
"""
return self.binary_data
class AudioBinary_data_list(BaseModel):
"""音频数据列表"""
binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[])
binary_data_list: List[_AudioBinary_data] = Field(
description="音频数据列表", default=[]
)
class Config:
arbitrary_types_allowed = True
@ -118,6 +134,7 @@ class AudioBinary_data_list(BaseModel):
"""
return len(self.binary_data_list)
# class AudioBinary_Slice(BaseModel):
# """音频块切片"""
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
@ -138,4 +155,4 @@ class AudioBinary_data_list(BaseModel):
# return v
# def __call__(self):
# return self.target_Binary(self.start_index, self.end_index)
# return self.target_Binary(self.start_index, self.end_index)

View File

@ -2,23 +2,27 @@ from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable, Any
from .audio import AudioBinary_data_list, _AudioBinary_data
class VAD_Functor_result(BaseModel):
"""
VADFunctor结果
"""
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
audiobinary_index: int = Field(description="音频数据索引")
audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data")
audiobinary_data: _AudioBinary_data = Field(
description="音频数据, 指向AudioBinary_data"
)
start_time: int = Field(description="开始时间", is_required=True)
end_time: int = Field(description="结束时间", is_required=True)
@validator('audiobinary_data_list')
@validator("audiobinary_data_list")
def validate_audiobinary_data_list(cls, v):
if not isinstance(v, AudioBinary_data_list):
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
return v
@validator('audiobinary_index')
@validator("audiobinary_index")
def validate_audiobinary_index(cls, v):
if not isinstance(v, int):
raise ValueError("audiobinary_index必须是int类型")
@ -26,35 +30,35 @@ class VAD_Functor_result(BaseModel):
raise ValueError("audiobinary_index必须大于0")
return v
@validator('audiobinary_data')
@validator("audiobinary_data")
def validate_audiobinary_data(cls, v):
if not isinstance(v, _AudioBinary_data):
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
return v
@validator('start_time')
@validator("start_time")
def validate_start_time(cls, v):
if not isinstance(v, int):
raise ValueError("start_time必须是int类型")
if v < 0:
raise ValueError("start_time必须大于0")
return v
@validator('end_time')
@validator("end_time")
def validate_end_time(cls, v, values):
if not isinstance(v, int):
raise ValueError("end_time必须是int类型")
if 'start_time' in values and v <= values['start_time']:
if "start_time" in values and v <= values["start_time"]:
raise ValueError("end_time必须大于start_time")
return v
@classmethod
def create_from_push_data(
cls,
audiobinary_data_list: AudioBinary_data_list,
data: Any,
start_time: int,
end_time: int
end_time: int,
):
"""
创建VAD片段
@ -66,7 +70,8 @@ class VAD_Functor_result(BaseModel):
audiobinary_index=index,
audiobinary_data=audiobinary_data_list[index],
start_time=start_time,
end_time=end_time)
end_time=end_time,
)
def __len__(self):
"""
@ -78,11 +83,9 @@ class VAD_Functor_result(BaseModel):
"""
字符串展示内容
"""
tostr = f'audiobinary_data_index: {self.audiobinary_index}\n'
tostr += f'start_time: {self.start_time}\n'
tostr += f'end_time: {self.end_time}\n'
tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n'
tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n'
tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
tostr += f"start_time: {self.start_time}\n"
tostr += f"end_time: {self.end_time}\n"
tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
return tostr

View File

@ -7,11 +7,13 @@ import threading
logger = get_module_logger(__name__)
class ASRPipeline(PipelineBase):
"""
管道类
实现具体的处理逻辑
"""
def __init__(self, *args, **kwargs):
"""
初始化管道
@ -91,27 +93,24 @@ class ASRPipeline(PipelineBase):
"""
try:
from src.functor import FunctorFactory
# 加载VAD、asr、spk functor
self._functor_dict["vad"] = FunctorFactory.make_functor(
functor_name = "vad",
config = self._config,
models = self._models
functor_name="vad", config=self._config, models=self._models
)
self._functor_dict["asr"] = FunctorFactory.make_functor(
functor_name = "asr",
config = self._config,
models = self._models
functor_name="asr", config=self._config, models=self._models
)
self._functor_dict["spk"] = FunctorFactory.make_functor(
functor_name = "spk",
config = self._config,
models = self._models
functor_name="spk", config=self._config, models=self._models
)
# 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list()
self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list)
self._functor_dict["vad"].set_audio_binary_data_list(
self._audio_binary_data_list
)
# 初始化子队列
self._subqueue_dict["original"] = Queue()
@ -121,7 +120,7 @@ class ASRPipeline(PipelineBase):
self._subqueue_dict["spkend"] = Queue()
# 设置子队列的输入队列
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
@ -134,30 +133,40 @@ class ASRPipeline(PipelineBase):
"""
带回调函数的put
"""
def put_with_check(data: Any) -> None:
queue.put(data)
callback(data)
return put_with_check
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
self._functor_dict["asr"].add_callback(
put_with_checkcallback(
self._subqueue_dict["asrend"], self._check_result
)
)
self._functor_dict["spk"].add_callback(
put_with_checkcallback(
self._subqueue_dict["spkend"], self._check_result
)
)
except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
def _check_result(self, result: Any) -> None:
"""
检查结果
"""
# 若asr和spk队列中都有数据则合并数据
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize():
if (
self._subqueue_dict["asrend"].qsize()
& self._subqueue_dict["spkend"].qsize()
):
asr_data = self._subqueue_dict["asrend"].get()
spk_data = self._subqueue_dict["spkend"].get()
# 合并数据
result = {
"asr_data": asr_data,
"spk_data": spk_data
}
result = {"asr_data": asr_data, "spk_data": spk_data}
# 通知回调函数
self._notify_callbacks(result)
@ -211,13 +220,13 @@ class ASRPipeline(PipelineBase):
logger.info("收到结束信号,管道准备停止")
self._input_queue.task_done() # 标记结束信号已处理
break
# 处理数据
self._process(data)
# 标记任务完成
self._input_queue.task_done()
except Empty:
# 队列获取超时,继续等待
continue
@ -237,7 +246,7 @@ class ASRPipeline(PipelineBase):
"""
# 子类实现具体的处理逻辑
self._subqueue_dict["original"].put(data)
def stop(self) -> None:
"""
停止管道

View File

@ -8,11 +8,13 @@ import time
# 配置日志
logger = logging.getLogger(__name__)
class PipelineBase(ABC):
"""
管道基类
定义了管道的基本接口和通用功能
"""
def __init__(self, input_queue: Optional[Queue] = None):
"""
初始化管道
@ -93,7 +95,7 @@ class PipelineBase(ABC):
if self._thread and self._thread.is_alive():
timeout = timeout if timeout is not None else self._stop_timeout
self._thread.join(timeout=timeout)
# 检查是否成功停止
if self._thread.is_alive():
logger.warning(f"管道停止超时({timeout}秒),强制终止")
@ -101,7 +103,7 @@ class PipelineBase(ABC):
else:
logger.info("管道已成功停止")
return True
return True
def force_stop(self) -> None:
@ -115,14 +117,35 @@ class PipelineBase(ABC):
# 注意:Python的线程无法被强制终止,这里只是设置标志
# 实际终止需要依赖操作系统的进程管理
class PipelineFactory:
    """
    Factory that creates pipeline instances by name.
    """

    @staticmethod
    def _create_pipeline_ASRpipeline(*args, **kwargs) -> "ASRPipeline":
        """
        Create an ASRPipeline, configure it and bake it.

        Positional arguments are accepted but not used. The keyword
        arguments ``config``, ``models``, ``audio_binary``,
        ``input_queue`` and ``callback`` are all required; a missing one
        raises ``KeyError``.

        Returns:
            The baked ASRPipeline instance.
        """
        # Imported at call time only. The return annotation is a string,
        # so defining this class no longer imports the ASRpipeline module.
        from src.pipeline.ASRpipeline import ASRPipeline

        pipeline = ASRPipeline()
        pipeline.set_config(kwargs["config"])
        pipeline.set_models(kwargs["models"])
        pipeline.set_audio_binary(kwargs["audio_binary"])
        pipeline.set_input_queue(kwargs["input_queue"])
        pipeline.add_callback(kwargs["callback"])
        pipeline.bake()
        return pipeline

    @classmethod
    def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
        """
        Create the pipeline registered under ``pipeline_name``.

        Args:
            pipeline_name: currently only ``"ASRpipeline"`` is supported.
            *args, **kwargs: forwarded unchanged to the matching builder.

        Returns:
            The pipeline instance produced by the builder.

        Raises:
            ValueError: if ``pipeline_name`` is not a supported name.
        """
        if pipeline_name == "ASRpipeline":
            return cls._create_pipeline_ASRpipeline(*args, **kwargs)
        raise ValueError(f"不支持的管道类型: {pipeline_name}")

View File

@ -172,7 +172,7 @@ class STTRunner(RunnerBase):
logger.warning(
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
self._stop_timeout,
self._input_queue.qsize()
self._input_queue.qsize(),
)
success = False
break
@ -188,7 +188,7 @@ class STTRunner(RunnerBase):
"错误堆栈:\n%s",
error_type,
error_msg,
error_traceback
error_traceback,
)
success = False
@ -198,7 +198,7 @@ class STTRunner(RunnerBase):
logger.warning(
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(),
self._input_queue.empty()
self._input_queue.empty(),
)
return success

View File

@ -43,7 +43,7 @@ async def clear_websocket():
async def ws_serve(websocket, path):
"""
WebSocket服务主函数处理客户端连接和消息
参数:
websocket: WebSocket连接对象
path: 连接路径
@ -51,13 +51,13 @@ async def ws_serve(websocket, path):
frames = [] # 存储所有音频帧
frames_asr = [] # 存储用于离线ASR的音频帧
frames_asr_online = [] # 存储用于在线ASR的音频帧
global websocket_users
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
# 添加到用户集合
websocket_users.add(websocket)
# 初始化连接状态
websocket.status_dict_asr = {}
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
@ -66,15 +66,15 @@ async def ws_serve(websocket, path):
websocket.chunk_interval = 10
websocket.vad_pre_idx = 0
websocket.is_speaking = True # 默认用户正在说话
# 语音检测状态
speech_start = False
speech_end_i = -1
# 初始化配置
websocket.wav_name = "microphone"
websocket.mode = "2pass" # 默认使用两阶段识别模式
print("新用户已连接", flush=True)
try:
@ -84,11 +84,13 @@ async def ws_serve(websocket, path):
if isinstance(message, str):
try:
messagejson = json.loads(message)
# 更新各种配置参数
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
websocket.status_dict_asr_online["is_final"] = (
not websocket.is_speaking
)
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
websocket.status_dict_asr_online["chunk_size"] = [
int(x) for x in chunk_size
]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
messagejson["encoder_chunk_look_back"]
)
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
messagejson["decoder_chunk_look_back"]
)
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
@ -111,11 +119,17 @@ async def ws_serve(websocket, path):
# 根据chunk_interval更新VAD的chunk_size
websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
* 60
/ websocket.chunk_interval
)
# 处理音频数据
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
if (
len(frames_asr_online) > 0
or len(frames_asr) >= 0
or not isinstance(message, str)
):
if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区
frames.append(message)
@ -125,10 +139,12 @@ async def ws_serve(websocket, path):
# 处理在线ASR
frames_asr_online.append(message)
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
# 达到chunk_interval或最终帧时处理在线ASR
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
websocket.status_dict_asr_online["is_final"]):
if (
len(frames_asr_online) % websocket.chunk_interval == 0
or websocket.status_dict_asr_online["is_final"]
):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
@ -136,26 +152,32 @@ async def ws_serve(websocket, path):
except Exception as e:
print(f"在线ASR处理错误: {e}")
frames_asr_online = []
# 如果检测到语音开始收集帧用于离线ASR
if speech_start:
frames_asr.append(message)
# VAD处理 - 语音活动检测
try:
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
speech_start_i, speech_end_i = await asr_service.async_vad(
websocket, message
)
except Exception as e:
print(f"VAD处理错误: {e}")
# 检测到语音开始
if speech_start_i != -1:
speech_start = True
# 计算开始偏移并收集前面的帧
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
beg_bias = (
websocket.vad_pre_idx - speech_start_i
) // duration_ms
frames_pre = (
frames[-beg_bias:] if beg_bias < len(frames) else frames
)
frames_asr = []
frames_asr.extend(frames_pre)
# 处理离线ASR (语音结束或用户停止说话)
if speech_end_i != -1 or not websocket.is_speaking:
if websocket.mode == "2pass" or websocket.mode == "offline":
@ -164,13 +186,13 @@ async def ws_serve(websocket, path):
await asr_service.async_asr(websocket, audio_in)
except Exception as e:
print(f"离线ASR处理错误: {e}")
# 重置状态
frames_asr = []
speech_start = False
frames_asr_online = []
websocket.status_dict_asr_online["cache"] = {}
# 如果用户停止说话,完全重置
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
@ -193,34 +215,34 @@ async def ws_serve(websocket, path):
def start_server(args, asr_service_instance):
"""
启动WebSocket服务器
参数:
args: 命令行参数
asr_service_instance: ASR服务实例
"""
global asr_service
asr_service = asr_service_instance
# 配置SSL (如果提供了证书)
if args.certfile and len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context
ws_serve,
args.host,
args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context,
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
)
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
# 启动事件循环
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
@ -229,14 +251,14 @@ def start_server(args, asr_service_instance):
if __name__ == "__main__":
# 解析命令行参数
args = parse_args()
# 加载模型
print("正在加载模型...")
models = load_models(args)
print("模型加载完成!当前仅支持单个客户端同时连接!")
# 创建ASR服务
asr_service = ASRService(models)
# 启动服务器
start_server(args, asr_service)
start_server(args, asr_service)

View File

@ -9,11 +9,11 @@ import json
class ASRService:
"""ASR服务类封装各种语音识别相关功能"""
def __init__(self, models):
"""
初始化ASR服务
参数:
models: 包含各种预加载模型的字典
"""
@ -21,42 +21,41 @@ class ASRService:
self.model_asr_streaming = models["asr_streaming"]
self.model_vad = models["vad"]
self.model_punc = models["punc"]
async def async_vad(self, websocket, audio_in):
"""
语音活动检测
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
返回:
tuple: (speech_start, speech_end) 语音开始和结束位置
"""
# 使用VAD模型分析音频段
segments_result = self.model_vad.generate(
input=audio_in,
**websocket.status_dict_vad
input=audio_in, **websocket.status_dict_vad
)[0]["value"]
speech_start = -1
speech_end = -1
# 解析VAD结果
if len(segments_result) == 0 or len(segments_result) > 1:
return speech_start, speech_end
if segments_result[0][0] != -1:
speech_start = segments_result[0][0]
if segments_result[0][1] != -1:
speech_end = segments_result[0][1]
return speech_start, speech_end
async def async_asr(self, websocket, audio_in):
"""
离线ASR处理
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
@ -64,42 +63,44 @@ class ASRService:
if len(audio_in) > 0:
# 使用离线ASR模型处理音频
rec_result = self.model_asr.generate(
input=audio_in,
**websocket.status_dict_asr
input=audio_in, **websocket.status_dict_asr
)[0]
# 如果有标点符号模型且识别出文本,则添加标点
if self.model_punc is not None and len(rec_result["text"]) > 0:
rec_result = self.model_punc.generate(
input=rec_result["text"],
**websocket.status_dict_punc
input=rec_result["text"], **websocket.status_dict_punc
)[0]
# 如果识别出文本,发送到客户端
if len(rec_result["text"]) > 0:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
else:
# 如果没有音频数据,发送空文本
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
message = json.dumps(
{
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
async def async_asr_online(self, websocket, audio_in):
"""
在线ASR处理
参数:
websocket: WebSocket连接
audio_in: 二进制音频数据
@ -107,21 +108,24 @@ class ASRService:
if len(audio_in) > 0:
# 使用在线ASR模型处理音频
rec_result = self.model_asr_streaming.generate(
input=audio_in,
**websocket.status_dict_asr_online
input=audio_in, **websocket.status_dict_asr_online
)[0]
# 在2pass模式下如果是最终帧则跳过(留给离线ASR处理)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
"is_final", False
):
return
# 如果识别出文本,发送到客户端
if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
await websocket.send(message)
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)

View File

@ -1,3 +1,3 @@
from .logger import get_module_logger, setup_root_logger
__all__ = ["get_module_logger", "setup_root_logger"]
__all__ = ["get_module_logger", "setup_root_logger"]

View File

@ -1,10 +1,12 @@
"""
处理各类音频数据与bytes的转换
"""
import wave
from pydub import AudioSegment
import io
def wav_to_bytes(wav_path: str) -> bytes:
"""
将WAV文件读取为bytes数据
@ -14,24 +16,26 @@ def wav_to_bytes(wav_path: str) -> bytes:
返回:
bytes: WAV文件的原始字节数据
异常:
FileNotFoundError: 如果WAV文件不存在
wave.Error: 如果文件不是有效的WAV文件
"""
try:
with wave.open(wav_path, 'rb') as wf:
with wave.open(wav_path, "rb") as wf:
# 读取所有帧
frames = wf.readframes(wf.getnframes())
return frames
except FileNotFoundError:
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
raise FileNotFoundError(f"错误未找到WAV文件 '{wav_path}'")
raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
except wave.Error as e:
raise wave.Error(f"错误打开或读取WAV文件 '{wav_path}' 失败 - {e}")
raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int):
def bytes_to_wav(
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
):
"""
将bytes数据写入为WAV文件
@ -41,22 +45,23 @@ def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: in
nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)
sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)
framerate (int): 采样率 (例如 44100, 16000)
异常:
wave.Error: 如果写入WAV文件失败
"""
try:
with wave.open(wav_path, 'wb') as wf:
with wave.open(wav_path, "wb") as wf:
wf.setnchannels(nchannels)
wf.setsampwidth(sampwidth)
wf.setframerate(framerate)
wf.writeframes(bytes_data)
except wave.Error as e:
raise wave.Error(f"错误写入WAV文件 '{wav_path}' 失败 - {e}")
raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
except Exception as e:
# 捕获其他可能的写入错误
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
def mp3_to_bytes(mp3_path: str) -> bytes:
"""
将MP3文件转换为bytes数据 (原始PCM数据)
@ -66,7 +71,7 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
返回:
bytes: MP3文件解码后的原始PCM字节数据
异常:
FileNotFoundError: 如果MP3文件不存在
pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码
@ -76,12 +81,19 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
# 获取原始PCM数据
return audio.raw_data
except FileNotFoundError:
raise FileNotFoundError(f"错误未找到MP3文件 '{mp3_path}'")
except Exception as e: # pydub 可能抛出多种解码相关的错误
raise Exception(f"错误处理MP3文件 '{mp3_path}' 失败 - {e}")
raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
except Exception as e: # pydub 可能抛出多种解码相关的错误
raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"):
def bytes_to_mp3(
bytes_data: bytes,
mp3_path: str,
frame_rate: int,
channels: int,
sample_width: int,
bitrate: str = "192k",
):
"""
将原始PCM bytes数据转换为MP3文件
@ -102,9 +114,9 @@ def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: in
data=bytes_data,
sample_width=sample_width,
frame_rate=frame_rate,
channels=channels
channels=channels,
)
# 导出为MP3
audio.export(mp3_path, format="mp3", bitrate=bitrate)
except Exception as e:
raise Exception(f"错误转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")

View File

@ -3,6 +3,7 @@ import sys
from pathlib import Path
from typing import Optional
def setup_logger(
name: str = None,
level: str = "INFO",
@ -12,80 +13,79 @@ def setup_logger(
) -> logging.Logger:
"""
设置并返回一个配置好的logger实例
Args:
name: logger的名称默认为None使用root logger
level: 日志级别默认为"INFO"
log_file: 日志文件路径默认为None仅控制台输出
log_format: 日志格式
date_format: 日期格式
Returns:
logging.Logger: 配置好的logger实例
"""
# 获取logger实例
logger = logging.getLogger(name)
# 设置日志级别
level = getattr(logging, level.upper())
logger.setLevel(level)
print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
# 创建格式器
formatter = logging.Formatter(log_format, date_format)
# 添加控制台处理器
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 如果指定了日志文件,添加文件处理器
if log_file:
# 确保日志目录存在
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 注意:移除了 propagate = False允许日志传递
return logger
def setup_root_logger(
level: str = "INFO",
log_file: Optional[str] = None
) -> None:
def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
"""
配置根日志器
Args:
level: 日志级别
log_file: 日志文件路径
"""
setup_logger(None, level, log_file)
def get_module_logger(
module_name: str,
level: Optional[str] = None, # 改为可选参数
log_file: Optional[str] = None # 一般不需要单独指定
log_file: Optional[str] = None, # 一般不需要单独指定
) -> logging.Logger:
"""
获取模块级别的logger
Args:
module_name: 模块名称通常传入__name__
level: 可选的日志级别如果不指定则继承父级配置
log_file: 可选的日志文件路径一般不需要指定
"""
logger = logging.getLogger(module_name)
# 只有显式指定了level才设置
if level:
logger.setLevel(getattr(logging, level.upper()))
# 只有显式指定了log_file才添加文件处理器
if log_file:
setup_logger(module_name, level or "INFO", log_file)
return logger
return logger

View File

@ -1,10 +1,15 @@
from tests.functor.vad_test import test_vad_functor
"""
测试主函数
请在tests目录下创建测试文件, 并在此文件中调用
"""
from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger
setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)
# from tests.functor.vad_test import test_vad_functor
# logger.info("开始测试VAD函数器")
# test_vad_functor()

View File

@ -1 +1 @@
"""FunASR WebSocket服务测试模块"""
"""FunASR WebSocket服务测试模块"""

View File

@ -2,6 +2,7 @@
Functor测试
VAD测试
"""
from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor
@ -21,6 +22,7 @@ logger = get_module_logger(__name__)
model_loader = ModelLoader()
def test_vad_functor():
# 加载模型
args = {
@ -38,9 +40,9 @@ def test_vad_functor():
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1
channels=1,
)
chunk_stride = int(audio_config.chunk_size*sample_rate/1000)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
# 创建输入队列
input_queue = Queue()
@ -62,9 +64,7 @@ def test_vad_functor():
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
# 设置模型
vad_functor.set_model({
'vad': model_loader.models['vad']
})
vad_functor.set_model({"vad": model_loader.models["vad"]})
# 启动VAD函数器
vad_functor.run()
@ -77,9 +77,7 @@ def test_vad_functor():
# 设置回调函数
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
# 设置模型
asr_functor.set_model({
'asr': model_loader.models['asr']
})
asr_functor.set_model({"asr": model_loader.models["asr"]})
# 启动ASR函数器
asr_functor.run()
@ -92,23 +90,25 @@ def test_vad_functor():
# 设置回调函数
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
# 设置模型
spk_functor.set_model({
# 'spk': model_loader.models['spk']
'spk': 'fake_spk'
})
spk_functor.set_model(
{
# 'spk': model_loader.models['spk']
"spk": "fake_spk"
}
)
# 启动SPK函数器
spk_functor.run()
f_binary = f_data
audio_clip_len = 200
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
print(
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
)
for i in range(0, len(f_binary), audio_clip_len):
binary_data = f_binary[i:i+audio_clip_len]
binary_data = f_binary[i : i + audio_clip_len]
input_queue.put(binary_data)
# 等待VAD函数器结束
vad_functor.stop()
print("[vad_test] VAD函数器结束")
@ -119,4 +119,6 @@ def test_vad_functor():
if OVERWATCH:
for index in range(len(audio_binary_data_list)):
save_path = f"tests/vad_test_output_{index}.wav"
soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate)
soundfile.write(
save_path, audio_binary_data_list[index].binary_data, sample_rate
)

View File

@ -1,10 +1,21 @@
"""
模型使用测试
此处主要用于各类调用模型的处理数据与输出格式
请在主目录下test_main.py中调用
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配
"""
from funasr import AutoModel
from typing import List, Dict, Any
from src.models import VADResponse
import time
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
chunk_size = 100 # ms
"""
在线VAD模型使用
"""
chunk_size = 100 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
vad_result = VADResponse()
@ -16,12 +27,14 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
time.sleep(0.1)
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
res = model.generate(
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
)
if len(res[0]["value"]):
vad_result += VADResponse.from_raw(res)
for item in res[0]["value"]:
@ -32,44 +45,64 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
# print(item)
return vad_result
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
测试LogicTrager
在Rebuild版本后LogicTrager中已弃用
"""
from src.logic_trager import LogicTrager
import soundfile
from src.config import parse_args
args = parse_args()
# from src.functor.model_loader import load_models
# models = load_models(args)
from src.model_loader import ModelLoader
models = ModelLoader(args)
chunk_size = 200 # ms
chunk_size = 200 # ms
from src.models import AudioBinary_Config
import soundfile
speech, sample_rate = soundfile.read(file_path)
chunk_stride = int(chunk_size * sample_rate / 1000)
audio_config = AudioBinary_Config(sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size)
audio_config = AudioBinary_Config(
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
)
logic_trager = LogicTrager(models=models, audio_config=audio_config)
for i in range(len(speech)//chunk_stride+1):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
for i in range(len(speech) // chunk_stride + 1):
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
logic_trager.push_binary_data(speech_chunk)
# for item in items:
# print(item)
return None
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
"""
ASR模型使用
离线ASR模型使用
"""
from funasr import AutoModel
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
vad_model="fsmn-vad", vad_model_revision="v2.0.4",
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
spk_model="cam++", spk_model_revision="v2.0.2",
spk_mode="vad_segment",
auto_update=False,
)
model = AutoModel(
model="paraformer-zh",
model_revision="v2.0.4",
vad_model="fsmn-vad",
vad_model_revision="v2.0.4",
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
spk_model="cam++",
spk_model_revision="v2.0.2",
spk_mode="vad_segment",
auto_update=False,
)
import soundfile
@ -80,7 +113,9 @@ def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
result = model.generate(speech)
return result
if __name__ == "__main__":
# vad_result = vad_model_use_online("tests/vad_example.wav")
vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# print(vad_result)
# if __name__ == "__main__":
# 请在主目录下调用test_main.py文件进行测试
# vad_result = vad_model_use_online("tests/vad_example.wav")
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# print(vad_result)

View File

@ -2,7 +2,9 @@
Pipeline测试
VAD+ASR+SPK(FAKE)
"""
from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader
from queue import Queue
@ -18,6 +20,7 @@ OVAERWATCH = False
model_loader = ModelLoader()
def test_asr_pipeline():
# 加载模型
args = {
@ -36,9 +39,9 @@ def test_asr_pipeline():
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1
channels=1,
)
chunk_stride = int(audio_config.chunk_size*sample_rate/1000)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
# 创建参数Dict
@ -52,29 +55,39 @@ def test_asr_pipeline():
input_queue = Queue()
# 创建Pipeline
asr_pipeline = ASRPipeline()
asr_pipeline.set_models(models)
asr_pipeline.set_config(config)
asr_pipeline.set_audio_binary(audio_binary_data_list)
asr_pipeline.set_input_queue(input_queue)
asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
asr_pipeline.bake()
# asr_pipeline = ASRPipeline()
# asr_pipeline.set_models(models)
# asr_pipeline.set_config(config)
# asr_pipeline.set_audio_binary(audio_binary_data_list)
# asr_pipeline.set_input_queue(input_queue)
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
# asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name = "ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)
# 运行Pipeline
asr_instance = asr_pipeline.run()
audio_clip_len = 200
print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}")
print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i:i+audio_clip_len])
input_queue.put(audio_data[i : i + audio_clip_len])
# time.sleep(10)
# input_queue.put(None)
# 等待Pipeline结束
# asr_instance.join()
time.sleep(5)
asr_pipeline.stop()
# asr_pipeline.stop()

View File

@ -10,23 +10,23 @@ import os
from unittest.mock import patch
# 将src目录添加到路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.config import parse_args
def test_default_args():
"""测试默认参数值"""
with patch('sys.argv', ['script.py']):
with patch("sys.argv", ["script.py"]):
args = parse_args()
# 检查服务器参数
assert args.host == "0.0.0.0"
assert args.port == 10095
# 检查SSL参数
assert args.certfile == ""
assert args.keyfile == ""
# 检查模型参数
assert "paraformer" in args.asr_model
assert args.asr_model_revision == "v2.0.4"
@ -36,7 +36,7 @@ def test_default_args():
assert args.vad_model_revision == "v2.0.4"
assert "punc" in args.punc_model
assert args.punc_model_revision == "v2.0.4"
# 检查硬件配置
assert args.ngpu == 1
assert args.device == "cuda"
@ -46,19 +46,26 @@ def test_default_args():
def test_custom_args():
"""测试自定义参数值"""
test_args = [
'script.py',
'--host', 'localhost',
'--port', '8080',
'--certfile', 'cert.pem',
'--keyfile', 'key.pem',
'--asr_model', 'custom_model',
'--ngpu', '0',
'--device', 'cpu'
"script.py",
"--host",
"localhost",
"--port",
"8080",
"--certfile",
"cert.pem",
"--keyfile",
"key.pem",
"--asr_model",
"custom_model",
"--ngpu",
"0",
"--device",
"cpu",
]
with patch('sys.argv', test_args):
with patch("sys.argv", test_args):
args = parse_args()
# 检查自定义参数
assert args.host == "localhost"
assert args.port == 8080
@ -66,4 +73,4 @@ def test_custom_args():
assert args.keyfile == "key.pem"
assert args.asr_model == "custom_model"
assert args.ngpu == 0
assert args.device == "cpu"
assert args.device == "cpu"