From 5a820b49e43567825566f2012f846cb33fafd7e9 Mon Sep 17 00:00:00 2001 From: "Ziyang.Zhang" Date: Thu, 12 Jun 2025 15:49:43 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BB=A3=E7=A0=81=E7=BB=93=E6=9E=84]black=20.?= =?UTF-8?q?=20=E5=AF=B9=E6=89=80=E6=9C=89=E6=96=87=E4=BB=B6=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E8=B0=83=E6=95=B4=EF=BC=8C=E6=97=A0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E5=8F=98=E5=8C=96=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 22 +++----- src/__init__.py | 2 +- src/audio_chunk.py | 22 ++++---- src/config.py | 99 +++++++++++---------------------- src/functor/__init__.py | 2 +- src/functor/asr_functor.py | 24 ++++---- src/functor/base.py | 46 +++++++-------- src/functor/spk_functor.py | 23 ++++---- src/functor/vad_functor.py | 5 +- src/logic_trager.py | 59 ++++++++++++-------- src/model_loader.py | 2 +- src/models/__init__.py | 8 ++- src/models/audio.py | 29 ++++++++-- src/models/vad.py | 39 +++++++------ src/pipeline/ASRpipeline.py | 55 ++++++++++-------- src/pipeline/base.py | 33 +++++++++-- src/runner.py | 6 +- src/server.py | 108 ++++++++++++++++++++++-------------- src/service.py | 94 ++++++++++++++++--------------- src/utils/__init__.py | 2 +- src/utils/data_format.py | 42 +++++++++----- src/utils/logger.py | 40 ++++++------- test_main.py | 7 ++- tests/__init__.py | 2 +- tests/functor/vad_test.py | 36 ++++++------ tests/modelsuse.py | 73 +++++++++++++++++------- tests/pipeline/asr_test.py | 41 +++++++++----- tests/test_config.py | 43 ++++++++------ 28 files changed, 543 insertions(+), 421 deletions(-) diff --git a/main.py b/main.py index 03accbd..49697c4 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,7 @@ from funasr import AutoModel -chunk_size = 200 # ms -model = AutoModel( - model="fsmn-vad", - model_revision="v2.0.4", - disable_update=True -) +chunk_size = 200 # ms +model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True) import soundfile @@ -14,16 +10,16 @@ speech, sample_rate = soundfile.read(wav_file) chunk_stride = int(chunk_size * sample_rate / 1000) cache = {} -total_chunk_num = int(len((speech)-1)/chunk_stride+1) +total_chunk_num = int(len((speech) - 1) / chunk_stride + 1) for i in range(total_chunk_num): - speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] + speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride] is_final = i == total_chunk_num - 1 res = model.generate( - input=speech_chunk, - cache=cache, - is_final=is_final, + input=speech_chunk, + cache=cache, + is_final=is_final, chunk_size=chunk_size, - disable_pbar=True + disable_pbar=True, ) if len(res[0]["value"]): print(res) @@ -31,4 +27,4 @@ for i in range(total_chunk_num): print(f"len(speech): {len(speech)}") print(f"len(speech_chunk): {len(speech_chunk)}") print(f"total_chunk_num: {total_chunk_num}") -print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}") \ No newline at end of file +print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}") diff --git a/src/__init__.py b/src/__init__.py index e13b69f..530b6f3 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -11,4 +11,4 @@ FunASR WebSocket服务 - 支持多种识别模式(2pass/online/offline) """ -__version__ = "0.1.0" \ No newline at end of file +__version__ = "0.1.0" diff --git a/src/audio_chunk.py b/src/audio_chunk.py index c96bef8..72b6241 100644 --- a/src/audio_chunk.py +++ b/src/audio_chunk.py @@ -46,7 +46,7 @@ class AudioBinary: else: raise ValueError("参数类型错误") - def add_slice_listener(self, slice_listener: callable): + def add_slice_listener(self, slice_listener: callable) -> None: """ 添加切片监听器 参数: @@ -98,10 +98,10 @@ class AudioBinary: self._binary_data_list.rewrite(target_index, binary_data) def get_binary_data( - self, - start: int = 0, - end: Optional[int] = None, - ) -> Optional[bytes]: + self, + start: int = 0, + end: Optional[int] = None, + ) -> Optional[bytes]: """ 获取指定索引的音频数据块 参数: @@ -128,7 +128,7 @@ class AudioChunk: 此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑。 """ - _instance = None + _instance: Optional[AudioChunk] = None def __new__(cls, *args, **kwargs): """ @@ -138,7 +138,7 @@ class AudioChunk: cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs) return cls._instance - def __init__(self): + def __init__(self) -> None: """ 初始化AudioChunk实例 """ @@ -146,10 +146,10 @@ class AudioChunk: self._slice_listener: List[callable] = [] def get_audio_binary( - self, - binary_name: Optional[str] = None, - audio_config: Optional[AudioBinary_Config] = None, - ) -> AudioBinary: + self, + binary_name: Optional[str] = None, + audio_config: Optional[AudioBinary_Config] = None, + ) -> AudioBinary: """ 获取音频数据块 参数: diff --git a/src/config.py b/src/config.py index feea99f..7b9a34f 100644 --- a/src/config.py +++ b/src/config.py @@ -10,116 +10,79 @@ import argparse def parse_args(): """ 解析命令行参数 - + 返回: argparse.Namespace: 解析后的参数对象 """ parser = argparse.ArgumentParser(description="FunASR WebSocket服务器") - + # 服务器配置 parser.add_argument( - "--host", - type=str, - default="0.0.0.0", - help="服务器主机地址,例如:localhost, 0.0.0.0" + "--host", + type=str, + default="0.0.0.0", + help="服务器主机地址,例如:localhost, 0.0.0.0", ) - parser.add_argument( - "--port", - type=int, - default=10095, - help="WebSocket服务器端口" - ) - + parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口") + # SSL配置 - parser.add_argument( - "--certfile", - type=str, - default="", - help="SSL证书文件路径" - ) - parser.add_argument( - "--keyfile", - type=str, - default="", - help="SSL密钥文件路径" - ) - + parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径") + parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径") + # ASR模型配置 parser.add_argument( "--asr_model", type=str, default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - help="离线ASR模型(从ModelScope获取)" + help="离线ASR模型(从ModelScope获取)", ) parser.add_argument( - "--asr_model_revision", - type=str, - default="v2.0.4", - help="离线ASR模型版本" + "--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本" ) - + # 在线ASR模型配置 parser.add_argument( "--asr_model_online", type=str, default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", - help="在线ASR模型(从ModelScope获取)" + help="在线ASR模型(从ModelScope获取)", ) parser.add_argument( - "--asr_model_online_revision", - type=str, - default="v2.0.4", - help="在线ASR模型版本" + "--asr_model_online_revision", + type=str, + default="v2.0.4", + help="在线ASR模型版本", ) - + # VAD模型配置 parser.add_argument( "--vad_model", type=str, default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", - help="VAD语音活动检测模型(从ModelScope获取)" + help="VAD语音活动检测模型(从ModelScope获取)", ) parser.add_argument( - "--vad_model_revision", - type=str, - default="v2.0.4", - help="VAD模型版本" + "--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本" ) - + # 标点符号模型配置 parser.add_argument( "--punc_model", type=str, default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", - help="标点符号模型(从ModelScope获取)" + help="标点符号模型(从ModelScope获取)", ) parser.add_argument( - "--punc_model_revision", - type=str, - default="v2.0.4", - help="标点符号模型版本" + "--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本" ) - + # 硬件配置 + parser.add_argument("--ngpu", type=int, default=1, help="GPU数量,0表示仅使用CPU") parser.add_argument( - "--ngpu", - type=int, - default=1, - help="GPU数量,0表示仅使用CPU" + "--device", type=str, default="cuda", help="设备类型:cuda或cpu" ) - parser.add_argument( - "--device", - type=str, - default="cuda", - help="设备类型:cuda或cpu" - ) - parser.add_argument( - "--ncpu", - type=int, - default=4, - help="CPU核心数" - ) - + parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数") + return parser.parse_args() @@ -127,4 +90,4 @@ if __name__ == "__main__": args = parse_args() print("配置参数:") for arg in vars(args): - print(f" {arg}: {getattr(args, arg)}") \ No newline at end of file + print(f" {arg}: {getattr(args, arg)}") diff --git a/src/functor/__init__.py b/src/functor/__init__.py index a18b658..dd5bb75 100644 --- a/src/functor/__init__.py +++ b/src/functor/__init__.py @@ -1,4 +1,4 @@ from .vad_functor import VADFunctor from .base import FunctorFactory -__all__ = ["VADFunctor", "FunctorFactory"] \ No newline at end of file +__all__ = ["VADFunctor", "FunctorFactory"] diff --git a/src/functor/asr_functor.py b/src/functor/asr_functor.py index 9f766a4..8a82ed9 100644 --- a/src/functor/asr_functor.py +++ b/src/functor/asr_functor.py @@ -2,8 +2,9 @@ ASRFunctor 负责对音频片段进行ASR处理, 以ASR_Result进行callback """ + from src.functor.base import BaseFunctor -from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result +from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result from typing import Callable, List from queue import Queue, Empty import threading @@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger logger = get_module_logger(__name__) + class ASRFunctor(BaseFunctor): """ ASRFunctor @@ -51,26 +53,26 @@ class ASRFunctor(BaseFunctor): 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 """ pass - + def set_input_queue(self, queue: Queue) -> None: """ 设置监听的输入消息队列 """ self._input_queue = queue - + def set_model(self, model: dict) -> None: """ 设置推理模型 """ self._model = model - + def set_audio_config(self, audio_config: AudioBinary_Config) -> None: """ 设置音频配置 """ self._audio_config = audio_config logger.debug("ASRFunctor设置音频配置: %s", self._audio_config) - + def add_callback(self, callback: Callable) -> None: """ 向自身的_callback: List[Callable]回调函数列表中添加回调函数 @@ -78,12 +80,12 @@ class ASRFunctor(BaseFunctor): if not isinstance(self._callback, list): self._callback = [] self._callback.append(callback) - + def _do_callback(self, result: List[str]) -> None: """ 回调函数 """ - text = result[0]['text'].replace(" ", "") + text = result[0]["text"].replace(" ", "") for callback in self._callback: callback(text) @@ -98,7 +100,7 @@ class ASRFunctor(BaseFunctor): hotwords=self._hotwords, ) self._do_callback(result) - + def _run(self) -> None: """ 线程运行逻辑 @@ -132,7 +134,7 @@ class ASRFunctor(BaseFunctor): self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() return self._thread - + def _pre_check(self) -> bool: """ 预检查 @@ -146,7 +148,7 @@ class ASRFunctor(BaseFunctor): if self._callback is None: raise ValueError("回调函数未设置") return True - + def stop(self) -> bool: """ 停止线程 @@ -157,5 +159,3 @@ class ASRFunctor(BaseFunctor): with self._status_lock: self._is_running = False return not self._thread.is_alive() - - \ No newline at end of file diff --git a/src/functor/base.py b/src/functor/base.py index a6bd904..ba0b13f 100644 --- a/src/functor/base.py +++ b/src/functor/base.py @@ -4,18 +4,20 @@ Functor基础模块 该模块定义了Functor的基类,所有功能性的类(如VAD、PUNC、ASR、SPK等)都应继承自这个基类。 基类提供了数据处理的基本框架,包括: - 回调函数管理 -- 模型配置管理 +- 模型配置管理 - 线程运行控制 主要类: BaseFunctor: Functor抽象类 FunctorFactory: Functor工厂类 """ + from abc import ABC, abstractmethod from typing import Callable, List from queue import Queue import threading + class BaseFunctor(ABC): """ Functor抽象类 @@ -27,9 +29,7 @@ class BaseFunctor(ABC): _model (dict): 存储模型相关的配置和实例 """ - def __init__( - self - ): + def __init__(self): """ 初始化函数器 @@ -38,8 +38,8 @@ class BaseFunctor(ABC): model (dict): 模型相关的配置和实例 """ self._callback: List[Callable] = [] - self._model: dict = {} - # flag + self._model: dict = {} + # flag self._is_running: bool = False self._stop_event: bool = False # 状态锁 @@ -91,7 +91,7 @@ class BaseFunctor(ABC): 返回: 线程实例 """ - + @abstractmethod def _pre_check(self): """ @@ -111,13 +111,12 @@ class BaseFunctor(ABC): """ - class FunctorFactory: """ Functor工厂类 该工厂类负责创建和配置Functor实例 - + 主要方法: make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor: 创建并配置Functor实例 @@ -138,58 +137,55 @@ class FunctorFactory: """ if functor_name == "vad": - return cls._make_vadfunctor(config = config,models = models) + return cls._make_vadfunctor(config=config, models=models) elif functor_name == "asr": - return cls._make_asrfunctor(config = config,models = models) + return cls._make_asrfunctor(config=config, models=models) elif functor_name == "spk": - return cls._make_spkfunctor(config = config,models = models) + return cls._make_spkfunctor(config=config, models=models) else: raise ValueError(f"不支持的Functor类型: {functor_name}") - + def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor: """ 创建VAD Functor实例 """ from src.functor.vad_functor import VADFunctor + audio_config = config["audio_config"] - model = { - "vad": models["vad"] - } + model = {"vad": models["vad"]} vad_functor = VADFunctor() vad_functor.set_audio_config(audio_config) vad_functor.set_model(model) return vad_functor - + def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor: """ 创建ASR Functor实例 """ from src.functor.asr_functor import ASRFunctor + audio_config = config["audio_config"] - model = { - "asr": models["asr"] - } + model = {"asr": models["asr"]} asr_functor = ASRFunctor() asr_functor.set_audio_config(audio_config) asr_functor.set_model(model) return asr_functor - + def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor: """ 创建SPK Functor实例 """ from src.functor.spk_functor import SPKFunctor + audio_config = config["audio_config"] - model = { - "spk": models["spk"] - } + model = {"spk": models["spk"]} spk_functor = SPKFunctor() spk_functor.set_audio_config(audio_config) spk_functor.set_model(model) - return spk_functor \ No newline at end of file + return spk_functor diff --git a/src/functor/spk_functor.py b/src/functor/spk_functor.py index c72bb98..2a10d86 100644 --- a/src/functor/spk_functor.py +++ b/src/functor/spk_functor.py @@ -2,6 +2,7 @@ SpkFunctor 负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback """ + from src.functor.base import BaseFunctor from src.models import AudioBinary_Config, VAD_Functor_result from typing import Callable, List @@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger logger = get_module_logger(__name__) + class SPKFunctor(BaseFunctor): """ SPKFunctor @@ -33,25 +35,24 @@ class SPKFunctor(BaseFunctor): self._input_queue: Queue = None # 输入队列 self._audio_config: AudioBinary_Config = None # 音频配置 - def reset_cache(self) -> None: """ 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 """ pass - + def set_input_queue(self, queue: Queue) -> None: """ 设置监听的输入消息队列 """ self._input_queue = queue - + def set_model(self, model: dict) -> None: """ 设置推理模型 """ self._model = model - + def set_audio_config(self, audio_config: AudioBinary_Config) -> None: """ 设置音频配置 @@ -66,7 +67,7 @@ class SPKFunctor(BaseFunctor): if not isinstance(self._callback, list): self._callback = [] self._callback.append(callback) - + def _do_callback(self, result: List[str]) -> None: """ 回调函数 @@ -83,9 +84,9 @@ class SPKFunctor(BaseFunctor): # input=binary_data, # chunk_size=self._audio_config.chunk_size, # ) - result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}] + result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}] self._do_callback(result) - + def _run(self) -> None: """ 线程运行逻辑 @@ -108,7 +109,7 @@ class SPKFunctor(BaseFunctor): except Exception as e: logger.error("SpkFunctor运行时发生错误: %s", e) raise e - + def run(self) -> threading.Thread: """ 启动线程 @@ -119,7 +120,7 @@ class SPKFunctor(BaseFunctor): self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() return self._thread - + def _pre_check(self) -> bool: """ 预检查 @@ -131,7 +132,7 @@ class SPKFunctor(BaseFunctor): if self._callback is None: raise ValueError("回调函数未设置") return True - + def stop(self) -> bool: """ 停止线程 @@ -142,5 +143,3 @@ class SPKFunctor(BaseFunctor): with self._status_lock: self._is_running = False return not self._thread.is_alive() - - \ No newline at end of file diff --git a/src/functor/vad_functor.py b/src/functor/vad_functor.py index 4f4cfca..fb869c5 100644 --- a/src/functor/vad_functor.py +++ b/src/functor/vad_functor.py @@ -2,6 +2,7 @@ VADFunctor 负责对音频片段进行VAD处理, 以VAD_Result进行callback """ + import threading from queue import Empty, Queue from typing import List, Any, Callable @@ -105,9 +106,7 @@ class VADFunctor(BaseFunctor): self._callback = [] self._callback.append(callback) - def _do_callback( - self, result: List[List[int]] - ) -> None: + def _do_callback(self, result: List[List[int]]) -> None: """ 回调函数 VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice diff --git a/src/logic_trager.py b/src/logic_trager.py index 18a3210..bf92bdb 100644 --- a/src/logic_trager.py +++ b/src/logic_trager.py @@ -6,34 +6,36 @@ from src.utils.logger import get_module_logger from typing import Any, Dict, Type, Callable + # 配置日志 logger = get_module_logger(__name__, level="INFO") + class AutoAfterMeta(type): """ 自动调用__after__函数的元类 实现单例模式 """ - + _instances: Dict[Type, Any] = {} # 存储单例实例 - + def __new__(cls, name, bases, attrs): # 遍历所有属性 for attr_name, attr_value in attrs.items(): # 如果是函数且不是以_开头 - if callable(attr_value) and not attr_name.startswith('__'): + if callable(attr_value) and not attr_name.startswith("__"): # 获取原函数 original_func = attr_value - + # 创建包装函数 def make_wrapper(func): def wrapper(self, *args, **kwargs): # 执行原函数 result = func(self, *args, **kwargs) - + # 构建_after_函数名 after_func_name = f"__after__{func.__name__}" - + # 检查是否存在对应的_after_函数 if hasattr(self, after_func_name): after_func = getattr(self, after_func_name) @@ -43,17 +45,18 @@ class AutoAfterMeta(type): after_func() except Exception as e: logger.error(f"调用{after_func_name}时出错: {e}") - + return result + return wrapper - + # 替换原函数 attrs[attr_name] = make_wrapper(original_func) - + # 创建类 new_class = super().__new__(cls, name, bases, attrs) return new_class - + def __call__(cls, *args, **kwargs): """ 重写__call__方法实现单例模式 @@ -65,9 +68,10 @@ class AutoAfterMeta(type): logger.info(f"创建{cls.__name__}的新实例") else: logger.debug(f"返回{cls.__name__}的现有实例") - + return cls._instances[cls] + """ 整体识别的处理逻辑: 1.压入二进制音频信息 @@ -88,10 +92,12 @@ from src.models import AudioBinary_Config from src.models import AudioBinary_Chunk from typing import List + class LogicTrager(metaclass=AutoAfterMeta): """逻辑触发器类""" - - def __init__(self, + + def __init__( + self, audio_chunk_max_size: int = 1024 * 1024 * 10, audio_config: AudioBinary_Config = None, result_callback: Callable = None, @@ -99,46 +105,51 @@ class LogicTrager(metaclass=AutoAfterMeta): ): """初始化""" # 存储音频块 - self._audio_chunk : List[AudioBinary_Chunk] = [] + self._audio_chunk: List[AudioBinary_Chunk] = [] # 存储二进制数据 - self._audio_chunk_binary = b'' + self._audio_chunk_binary = b"" self._audio_chunk_max_size = audio_chunk_max_size # 音频参数 - self._audio_config = audio_config if audio_config is not None else AudioBinary_Config() + self._audio_config = ( + audio_config if audio_config is not None else AudioBinary_Config() + ) # 结果队列 self._result_queue = [] # 聚合结果回调函数 self._aggregate_result_callback = result_callback # 组件 - self._vad = VAD(VAD_model = models.get("vad"), audio_config = self._audio_config) + self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config) self._vad.set_callback(self.push_audio_chunk) - logger.info("初始化LogicTrager") - + def push_binary_data(self, chunk: bytes) -> None: """ 压入音频块至VAD模块 - + 参数: chunk: 音频数据块 """ # print("LogicTrager push_binary_data", len(chunk)) self._vad.push_binary_data(chunk) self.__after__push_binary_data() - + def __after__push_binary_data(self) -> None: """ 添加音频块后处理 """ # print("LogicTrager __after__push_binary_data") self._vad.process_vad_result() - + def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None: """ 音频处理 """ - logger.info("LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(chunk.start_time, chunk.end_time, len(chunk.chunk))) + logger.info( + "LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format( + chunk.start_time, chunk.end_time, len(chunk.chunk) + ) + ) self._audio_chunk.append(chunk) def __after__push_audio_chunk(self) -> None: @@ -162,4 +173,4 @@ class LogicTrager(metaclass=AutoAfterMeta): def __call__(self): """调用函数""" - pass \ No newline at end of file + pass diff --git a/src/model_loader.py b/src/model_loader.py index 50717b2..80471e2 100644 --- a/src/model_loader.py +++ b/src/model_loader.py @@ -115,7 +115,7 @@ class ModelLoader: self.models = {} # 加载离线ASR模型 # 检查对应键是否存在 - model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk'] + model_list = ["asr", "asr_online", "vad", "punc", "spk"] for model_name in model_list: name_model = f"{model_name}_model" name_model_revision = f"{model_name}_model_revision" diff --git a/src/models/__init__.py b/src/models/__init__.py index 28004e7..d55041f 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,3 +1,9 @@ from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data from .vad import VAD_Functor_result -__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"] \ No newline at end of file + +__all__ = [ + "AudioBinary_Config", + "AudioBinary_data_list", + "_AudioBinary_data", + "VAD_Functor_result", +] diff --git a/src/models/audio.py b/src/models/audio.py index 558e3f3..a8bbef6 100644 --- a/src/models/audio.py +++ b/src/models/audio.py @@ -8,8 +8,10 @@ logger = get_module_logger(__name__) binary_data_types = (bytes, numpy.ndarray) + class AudioBinary_Config(BaseModel): """二进制音频块配置信息""" + class Config: arbitrary_types_allowed = True @@ -37,14 +39,16 @@ class AudioBinary_Config(BaseModel): """ return int(frame * 1000 / self.sample_rate) + class _AudioBinary_data(BaseModel): """音频数据""" + binary_data: binary_data_types = Field(description="音频二进制数据", default=None) class Config: arbitrary_types_allowed = True - @validator('binary_data') + @validator("binary_data") def validate_binary_data(cls, v): """ 验证音频数据 @@ -54,7 +58,11 @@ class _AudioBinary_data(BaseModel): binary_data_types: 音频数据 """ if not isinstance(v, (bytes, numpy.ndarray)): - logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v)) + logger.warning( + "[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", + cls.__class__.__name__, + type(v), + ) return v def __len__(self): @@ -64,14 +72,18 @@ class _AudioBinary_data(BaseModel): int: 音频数据长度 """ return len(self.binary_data) - + def __init__(self, binary_data: binary_data_types): """ 初始化音频数据 Args: binary_data: 音频数据 """ - logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data)) + logger.debug( + "[%s]初始化音频数据, 数据类型为%s", + self.__class__.__name__, + type(binary_data), + ) super().__init__(binary_data=binary_data) def __getitem__(self): @@ -82,9 +94,13 @@ class _AudioBinary_data(BaseModel): """ return self.binary_data + class AudioBinary_data_list(BaseModel): """音频数据列表""" - binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[]) + + binary_data_list: List[_AudioBinary_data] = Field( + description="音频数据列表", default=[] + ) class Config: arbitrary_types_allowed = True @@ -118,6 +134,7 @@ class AudioBinary_data_list(BaseModel): """ return len(self.binary_data_list) + # class AudioBinary_Slice(BaseModel): # """音频块切片""" # target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None) @@ -138,4 +155,4 @@ class AudioBinary_data_list(BaseModel): # return v # def __call__(self): -# return self.target_Binary(self.start_index, self.end_index) \ No newline at end of file +# return self.target_Binary(self.start_index, self.end_index) diff --git a/src/models/vad.py b/src/models/vad.py index 7eafca7..30be717 100644 --- a/src/models/vad.py +++ b/src/models/vad.py @@ -2,23 +2,27 @@ from pydantic import BaseModel, Field, validator from typing import List, Optional, Callable, Any from .audio import AudioBinary_data_list, _AudioBinary_data + class VAD_Functor_result(BaseModel): """ VADFunctor结果 """ + audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表") audiobinary_index: int = Field(description="音频数据索引") - audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data") + audiobinary_data: _AudioBinary_data = Field( + description="音频数据, 指向AudioBinary_data" + ) start_time: int = Field(description="开始时间", is_required=True) end_time: int = Field(description="结束时间", is_required=True) - @validator('audiobinary_data_list') + @validator("audiobinary_data_list") def validate_audiobinary_data_list(cls, v): if not isinstance(v, AudioBinary_data_list): raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型") return v - @validator('audiobinary_index') + @validator("audiobinary_index") def validate_audiobinary_index(cls, v): if not isinstance(v, int): raise ValueError("audiobinary_index必须是int类型") @@ -26,35 +30,35 @@ class VAD_Functor_result(BaseModel): raise ValueError("audiobinary_index必须大于0") return v - @validator('audiobinary_data') + @validator("audiobinary_data") def validate_audiobinary_data(cls, v): if not isinstance(v, _AudioBinary_data): raise ValueError("audiobinary_data必须是AudioBinary_data类型") return v - @validator('start_time') + @validator("start_time") def validate_start_time(cls, v): if not isinstance(v, int): raise ValueError("start_time必须是int类型") if v < 0: raise ValueError("start_time必须大于0") return v - - @validator('end_time') + + @validator("end_time") def validate_end_time(cls, v, values): if not isinstance(v, int): raise ValueError("end_time必须是int类型") - if 'start_time' in values and v <= values['start_time']: + if "start_time" in values and v <= values["start_time"]: raise ValueError("end_time必须大于start_time") return v - + @classmethod def create_from_push_data( cls, audiobinary_data_list: AudioBinary_data_list, data: Any, start_time: int, - end_time: int + end_time: int, ): """ 创建VAD片段 @@ -66,7 +70,8 @@ class VAD_Functor_result(BaseModel): audiobinary_index=index, audiobinary_data=audiobinary_data_list[index], start_time=start_time, - end_time=end_time) + end_time=end_time, + ) def __len__(self): """ @@ -78,11 +83,9 @@ class VAD_Functor_result(BaseModel): """ 字符串展示内容 """ - tostr = f'audiobinary_data_index: {self.audiobinary_index}\n' - tostr += f'start_time: {self.start_time}\n' - tostr += f'end_time: {self.end_time}\n' - tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n' - tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n' + tostr = f"audiobinary_data_index: {self.audiobinary_index}\n" + tostr += f"start_time: {self.start_time}\n" + tostr += f"end_time: {self.end_time}\n" + tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n" + tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n" return tostr - - \ No newline at end of file diff --git a/src/pipeline/ASRpipeline.py b/src/pipeline/ASRpipeline.py index a50e2b7..ab44de0 100644 --- a/src/pipeline/ASRpipeline.py +++ b/src/pipeline/ASRpipeline.py @@ -7,11 +7,13 @@ import threading logger = get_module_logger(__name__) + class ASRPipeline(PipelineBase): """ 管道类 实现具体的处理逻辑 """ + def __init__(self, *args, **kwargs): """ 初始化管道 @@ -91,27 +93,24 @@ class ASRPipeline(PipelineBase): """ try: from src.functor import FunctorFactory + # 加载VAD、asr、spk functor self._functor_dict["vad"] = FunctorFactory.make_functor( - functor_name = "vad", - config = self._config, - models = self._models + functor_name="vad", config=self._config, models=self._models ) self._functor_dict["asr"] = FunctorFactory.make_functor( - functor_name = "asr", - config = self._config, - models = self._models + functor_name="asr", config=self._config, models=self._models ) self._functor_dict["spk"] = FunctorFactory.make_functor( - functor_name = "spk", - config = self._config, - models = self._models + functor_name="spk", config=self._config, models=self._models ) # 创建音频数据存储单元 self._audio_binary_data_list = AudioBinary_data_list() - self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list) + self._functor_dict["vad"].set_audio_binary_data_list( + self._audio_binary_data_list + ) # 初始化子队列 self._subqueue_dict["original"] = Queue() @@ -121,7 +120,7 @@ class ASRPipeline(PipelineBase): self._subqueue_dict["spkend"] = Queue() # 设置子队列的输入队列 - self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"]) + self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"]) self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) @@ -134,30 +133,40 @@ class ASRPipeline(PipelineBase): """ 带回调函数的put """ + def put_with_check(data: Any) -> None: queue.put(data) callback(data) + return put_with_check - self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) - self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result)) + self._functor_dict["asr"].add_callback( + put_with_checkcallback( + self._subqueue_dict["asrend"], self._check_result + ) + ) + self._functor_dict["spk"].add_callback( + put_with_checkcallback( + self._subqueue_dict["spkend"], self._check_result + ) + ) except ImportError: - raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") - + raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") + def _check_result(self, result: Any) -> None: """ 检查结果 """ # 若asr和spk队列中都有数据,则合并数据 - if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize(): + if ( + self._subqueue_dict["asrend"].qsize() + & self._subqueue_dict["spkend"].qsize() + ): asr_data = self._subqueue_dict["asrend"].get() spk_data = self._subqueue_dict["spkend"].get() # 合并数据 - result = { - "asr_data": asr_data, - "spk_data": spk_data - } + result = {"asr_data": asr_data, "spk_data": spk_data} # 通知回调函数 self._notify_callbacks(result) @@ -211,13 +220,13 @@ class ASRPipeline(PipelineBase): logger.info("收到结束信号,管道准备停止") self._input_queue.task_done() # 标记结束信号已处理 break - + # 处理数据 self._process(data) # 标记任务完成 self._input_queue.task_done() - + except Empty: # 队列获取超时,继续等待 continue @@ -237,7 +246,7 @@ class ASRPipeline(PipelineBase): """ # 子类实现具体的处理逻辑 self._subqueue_dict["original"].put(data) - + def stop(self) -> None: """ 停止管道 diff --git a/src/pipeline/base.py b/src/pipeline/base.py index 6d7968d..7d5ef13 100644 --- a/src/pipeline/base.py +++ b/src/pipeline/base.py @@ -8,11 +8,13 @@ import time # 配置日志 logger = logging.getLogger(__name__) + class PipelineBase(ABC): """ 管道基类 定义了管道的基本接口和通用功能 """ + def __init__(self, input_queue: Optional[Queue] = None): """ 初始化管道 @@ -93,7 +95,7 @@ class PipelineBase(ABC): if self._thread and self._thread.is_alive(): timeout = timeout if timeout is not None else self._stop_timeout self._thread.join(timeout=timeout) - + # 检查是否成功停止 if self._thread.is_alive(): logger.warning(f"管道停止超时({timeout}秒),强制终止") @@ -101,7 +103,7 @@ class PipelineBase(ABC): else: logger.info("管道已成功停止") return True - + return True def force_stop(self) -> None: @@ -115,14 +117,35 @@ class PipelineBase(ABC): # 注意:Python的线程无法被强制终止,这里只是设置标志 # 实际终止需要依赖操作系统的进程管理 + class PipelineFactory: """ 管道工厂类 用于创建管道实例 """ - @staticmethod - def create_pipeline(pipeline_name: str) -> Any: + + from src.pipeline.ASRpipeline import ASRPipeline + def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline: + """ + 创建ASR管道实例 + """ + from src.pipeline.ASRpipeline import ASRPipeline + pipeline = ASRPipeline() + pipeline.set_config(kwargs["config"]) + pipeline.set_models(kwargs["models"]) + pipeline.set_audio_binary(kwargs["audio_binary"]) + pipeline.set_input_queue(kwargs["input_queue"]) + pipeline.add_callback(kwargs["callback"]) + pipeline.bake() + return pipeline + + @classmethod + def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any: """ 创建管道实例 """ - pass + if pipeline_name == "ASRpipeline": + return cls._create_pipeline_ASRpipeline(*args, **kwargs) + else: + raise ValueError(f"不支持的管道类型: {pipeline_name}") + diff --git a/src/runner.py b/src/runner.py index 24dabde..9d908cc 100644 --- a/src/runner.py +++ b/src/runner.py @@ -172,7 +172,7 @@ class STTRunner(RunnerBase): logger.warning( "等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理", self._stop_timeout, - self._input_queue.qsize() + self._input_queue.qsize(), ) success = False break @@ -188,7 +188,7 @@ class STTRunner(RunnerBase): "错误堆栈:\n%s", error_type, error_msg, - error_traceback + error_traceback, ) success = False @@ -198,7 +198,7 @@ class STTRunner(RunnerBase): logger.warning( "部分管道停止失败,队列状态: 大小=%d, 是否为空=%s", self._input_queue.qsize(), - self._input_queue.empty() + self._input_queue.empty(), ) return success diff --git a/src/server.py b/src/server.py index 9b27763..8d48693 100644 --- a/src/server.py +++ b/src/server.py @@ -43,7 +43,7 @@ async def clear_websocket(): async def ws_serve(websocket, path): """ WebSocket服务主函数,处理客户端连接和消息 - + 参数: websocket: WebSocket连接对象 path: 连接路径 @@ -51,13 +51,13 @@ async def ws_serve(websocket, path): frames = [] # 存储所有音频帧 frames_asr = [] # 存储用于离线ASR的音频帧 frames_asr_online = [] # 存储用于在线ASR的音频帧 - + global websocket_users # await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端) - + # 添加到用户集合 websocket_users.add(websocket) - + # 初始化连接状态 websocket.status_dict_asr = {} websocket.status_dict_asr_online = {"cache": {}, "is_final": False} @@ -66,15 +66,15 @@ async def ws_serve(websocket, path): websocket.chunk_interval = 10 websocket.vad_pre_idx = 0 websocket.is_speaking = True # 默认用户正在说话 - + # 语音检测状态 speech_start = False speech_end_i = -1 - + # 初始化配置 websocket.wav_name = "microphone" websocket.mode = "2pass" # 默认使用两阶段识别模式 - + print("新用户已连接", flush=True) try: @@ -84,11 +84,13 @@ async def ws_serve(websocket, path): if isinstance(message, str): try: messagejson = json.loads(message) - + # 更新各种配置参数 if "is_speaking" in messagejson: websocket.is_speaking = messagejson["is_speaking"] - websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking + websocket.status_dict_asr_online["is_final"] = ( + not websocket.is_speaking + ) if "chunk_interval" in messagejson: websocket.chunk_interval = messagejson["chunk_interval"] if "wav_name" in messagejson: @@ -97,11 +99,17 @@ async def ws_serve(websocket, path): chunk_size = messagejson["chunk_size"] if isinstance(chunk_size, str): chunk_size = chunk_size.split(",") - websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size] + websocket.status_dict_asr_online["chunk_size"] = [ + int(x) for x in chunk_size + ] if "encoder_chunk_look_back" in messagejson: - websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"] + websocket.status_dict_asr_online["encoder_chunk_look_back"] = ( + messagejson["encoder_chunk_look_back"] + ) if "decoder_chunk_look_back" in messagejson: - websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"] + websocket.status_dict_asr_online["decoder_chunk_look_back"] = ( + messagejson["decoder_chunk_look_back"] + ) if "hotword" in messagejson: websocket.status_dict_asr["hotword"] = messagejson["hotwords"] if "mode" in messagejson: @@ -111,11 +119,17 @@ async def ws_serve(websocket, path): # 根据chunk_interval更新VAD的chunk_size websocket.status_dict_vad["chunk_size"] = int( - websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval + websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] + * 60 + / websocket.chunk_interval ) - + # 处理音频数据 - if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str): + if ( + len(frames_asr_online) > 0 + or len(frames_asr) >= 0 + or not isinstance(message, str) + ): if not isinstance(message, str): # 二进制音频数据 # 添加到帧缓冲区 frames.append(message) @@ -125,10 +139,12 @@ async def ws_serve(websocket, path): # 处理在线ASR frames_asr_online.append(message) websocket.status_dict_asr_online["is_final"] = speech_end_i != -1 - + # 达到chunk_interval或最终帧时处理在线ASR - if (len(frames_asr_online) % websocket.chunk_interval == 0 or - websocket.status_dict_asr_online["is_final"]): + if ( + len(frames_asr_online) % websocket.chunk_interval == 0 + or websocket.status_dict_asr_online["is_final"] + ): if websocket.mode == "2pass" or websocket.mode == "online": audio_in = b"".join(frames_asr_online) try: @@ -136,26 +152,32 @@ async def ws_serve(websocket, path): except Exception as e: print(f"在线ASR处理错误: {e}") frames_asr_online = [] - + # 如果检测到语音开始,收集帧用于离线ASR if speech_start: frames_asr.append(message) - + # VAD处理 - 语音活动检测 try: - speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message) + speech_start_i, speech_end_i = await asr_service.async_vad( + websocket, message + ) except Exception as e: print(f"VAD处理错误: {e}") - + # 检测到语音开始 if speech_start_i != -1: speech_start = True # 计算开始偏移并收集前面的帧 - beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms - frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames + beg_bias = ( + websocket.vad_pre_idx - speech_start_i + ) // duration_ms + frames_pre = ( + frames[-beg_bias:] if beg_bias < len(frames) else frames + ) frames_asr = [] frames_asr.extend(frames_pre) - + # 处理离线ASR (语音结束或用户停止说话) if speech_end_i != -1 or not websocket.is_speaking: if websocket.mode == "2pass" or websocket.mode == "offline": @@ -164,13 +186,13 @@ async def ws_serve(websocket, path): await asr_service.async_asr(websocket, audio_in) except Exception as e: print(f"离线ASR处理错误: {e}") - + # 重置状态 frames_asr = [] speech_start = False frames_asr_online = [] websocket.status_dict_asr_online["cache"] = {} - + # 如果用户停止说话,完全重置 if not websocket.is_speaking: websocket.vad_pre_idx = 0 @@ -193,34 +215,34 @@ async def ws_serve(websocket, path): def start_server(args, asr_service_instance): """ 启动WebSocket服务器 - + 参数: args: 命令行参数 asr_service_instance: ASR服务实例 """ global asr_service asr_service = asr_service_instance - + # 配置SSL (如果提供了证书) if args.certfile and len(args.certfile) > 0: ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile) - + start_server = websockets.serve( - ws_serve, args.host, args.port, - subprotocols=["binary"], - ping_interval=None, - ssl=ssl_context + ws_serve, + args.host, + args.port, + subprotocols=["binary"], + ping_interval=None, + ssl=ssl_context, ) else: start_server = websockets.serve( - ws_serve, args.host, args.port, - subprotocols=["binary"], - ping_interval=None + ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None ) - + print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}") - + # 启动事件循环 asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() @@ -229,14 +251,14 @@ def start_server(args, asr_service_instance): if __name__ == "__main__": # 解析命令行参数 args = parse_args() - + # 加载模型 print("正在加载模型...") models = load_models(args) print("模型加载完成!当前仅支持单个客户端同时连接!") - + # 创建ASR服务 asr_service = ASRService(models) - + # 启动服务器 - start_server(args, asr_service) \ No newline at end of file + start_server(args, asr_service) diff --git a/src/service.py b/src/service.py index 131cfa2..e8a6dfb 100644 --- a/src/service.py +++ b/src/service.py @@ -9,11 +9,11 @@ import json class ASRService: """ASR服务类,封装各种语音识别相关功能""" - + def __init__(self, models): """ 初始化ASR服务 - + 参数: models: 包含各种预加载模型的字典 """ @@ -21,42 +21,41 @@ class ASRService: self.model_asr_streaming = models["asr_streaming"] self.model_vad = models["vad"] self.model_punc = models["punc"] - + async def async_vad(self, websocket, audio_in): """ 语音活动检测 - + 参数: websocket: WebSocket连接 audio_in: 二进制音频数据 - + 返回: tuple: (speech_start, speech_end) 语音开始和结束位置 """ # 使用VAD模型分析音频段 segments_result = self.model_vad.generate( - input=audio_in, - **websocket.status_dict_vad + input=audio_in, **websocket.status_dict_vad )[0]["value"] - + speech_start = -1 speech_end = -1 - + # 解析VAD结果 if len(segments_result) == 0 or len(segments_result) > 1: return speech_start, speech_end - + if segments_result[0][0] != -1: speech_start = segments_result[0][0] if segments_result[0][1] != -1: speech_end = segments_result[0][1] - + return speech_start, speech_end - + async def async_asr(self, websocket, audio_in): """ 离线ASR处理 - + 参数: websocket: WebSocket连接 audio_in: 二进制音频数据 @@ -64,42 +63,44 @@ class ASRService: if len(audio_in) > 0: # 使用离线ASR模型处理音频 rec_result = self.model_asr.generate( - input=audio_in, - **websocket.status_dict_asr + input=audio_in, **websocket.status_dict_asr )[0] - + # 如果有标点符号模型且识别出文本,则添加标点 if self.model_punc is not None and len(rec_result["text"]) > 0: rec_result = self.model_punc.generate( - input=rec_result["text"], - **websocket.status_dict_punc + input=rec_result["text"], **websocket.status_dict_punc )[0] - + # 如果识别出文本,发送到客户端 if len(rec_result["text"]) > 0: mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode - message = json.dumps({ - "mode": mode, - "text": rec_result["text"], - "wav_name": websocket.wav_name, - "is_final": websocket.is_speaking, - }) + message = json.dumps( + { + "mode": mode, + "text": rec_result["text"], + "wav_name": websocket.wav_name, + "is_final": websocket.is_speaking, + } + ) await websocket.send(message) else: # 如果没有音频数据,发送空文本 mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode - message = json.dumps({ - "mode": mode, - "text": "", - "wav_name": websocket.wav_name, - "is_final": websocket.is_speaking, - }) + message = json.dumps( + { + "mode": mode, + "text": "", + "wav_name": websocket.wav_name, + "is_final": websocket.is_speaking, + } + ) await websocket.send(message) - + async def async_asr_online(self, websocket, audio_in): """ 在线ASR处理 - + 参数: websocket: WebSocket连接 audio_in: 二进制音频数据 @@ -107,21 +108,24 @@ class ASRService: if len(audio_in) > 0: # 使用在线ASR模型处理音频 rec_result = self.model_asr_streaming.generate( - input=audio_in, - **websocket.status_dict_asr_online + input=audio_in, **websocket.status_dict_asr_online )[0] - + # 在2pass模式下,如果是最终帧则跳过(留给离线ASR处理) - if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False): + if websocket.mode == "2pass" and websocket.status_dict_asr_online.get( + "is_final", False + ): return - + # 如果识别出文本,发送到客户端 if len(rec_result["text"]): mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode - message = json.dumps({ - "mode": mode, - "text": rec_result["text"], - "wav_name": websocket.wav_name, - "is_final": websocket.is_speaking, - }) - await websocket.send(message) \ No newline at end of file + message = json.dumps( + { + "mode": mode, + "text": rec_result["text"], + "wav_name": websocket.wav_name, + "is_final": websocket.is_speaking, + } + ) + await websocket.send(message) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 0e567b2..19cb0bd 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,3 +1,3 @@ from .logger import get_module_logger, setup_root_logger -__all__ = ["get_module_logger", "setup_root_logger"] \ No newline at end of file +__all__ = ["get_module_logger", "setup_root_logger"] diff --git a/src/utils/data_format.py b/src/utils/data_format.py index bb845f9..110074f 100644 --- a/src/utils/data_format.py +++ b/src/utils/data_format.py @@ -1,10 +1,12 @@ """ 处理各类音频数据与bytes的转换 """ + import wave from pydub import AudioSegment import io + def wav_to_bytes(wav_path: str) -> bytes: """ 将WAV文件读取为bytes数据。 @@ -14,24 +16,26 @@ def wav_to_bytes(wav_path: str) -> bytes: 返回: bytes: WAV文件的原始字节数据。 - + 异常: FileNotFoundError: 如果WAV文件不存在。 wave.Error: 如果文件不是有效的WAV文件。 """ try: - with wave.open(wav_path, 'rb') as wf: + with wave.open(wav_path, "rb") as wf: # 读取所有帧 frames = wf.readframes(wf.getnframes()) return frames except FileNotFoundError: # 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出 - raise FileNotFoundError(f"错误:未找到WAV文件 '{wav_path}'") + raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'") except wave.Error as e: - raise wave.Error(f"错误:打开或读取WAV文件 '{wav_path}' 失败 - {e}") + raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}") -def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int): +def bytes_to_wav( + bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int +): """ 将bytes数据写入为WAV文件。 @@ -41,22 +45,23 @@ def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: in nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)。 sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)。 framerate (int): 采样率 (例如 44100, 16000)。 - + 异常: wave.Error: 如果写入WAV文件失败。 """ try: - with wave.open(wav_path, 'wb') as wf: + with wave.open(wav_path, "wb") as wf: wf.setnchannels(nchannels) wf.setsampwidth(sampwidth) wf.setframerate(framerate) wf.writeframes(bytes_data) except wave.Error as e: - raise wave.Error(f"错误:写入WAV文件 '{wav_path}' 失败 - {e}") + raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}") except Exception as e: # 捕获其他可能的写入错误 raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}") + def mp3_to_bytes(mp3_path: str) -> bytes: """ 将MP3文件转换为bytes数据 (原始PCM数据)。 @@ -66,7 +71,7 @@ def mp3_to_bytes(mp3_path: str) -> bytes: 返回: bytes: MP3文件解码后的原始PCM字节数据。 - + 异常: FileNotFoundError: 如果MP3文件不存在。 pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码。 @@ -76,12 +81,19 @@ def mp3_to_bytes(mp3_path: str) -> bytes: # 获取原始PCM数据 return audio.raw_data except FileNotFoundError: - raise FileNotFoundError(f"错误:未找到MP3文件 '{mp3_path}'") - except Exception as e: # pydub 可能抛出多种解码相关的错误 - raise Exception(f"错误:处理MP3文件 '{mp3_path}' 失败 - {e}") + raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'") + except Exception as e: # pydub 可能抛出多种解码相关的错误 + raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}") -def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"): +def bytes_to_mp3( + bytes_data: bytes, + mp3_path: str, + frame_rate: int, + channels: int, + sample_width: int, + bitrate: str = "192k", +): """ 将原始PCM bytes数据转换为MP3文件。 @@ -102,9 +114,9 @@ def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: in data=bytes_data, sample_width=sample_width, frame_rate=frame_rate, - channels=channels + channels=channels, ) # 导出为MP3 audio.export(mp3_path, format="mp3", bitrate=bitrate) except Exception as e: - raise Exception(f"错误:转换或写入MP3文件 '{mp3_path}' 失败 - {e}") + raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}") diff --git a/src/utils/logger.py b/src/utils/logger.py index b8f5f12..444a399 100644 --- a/src/utils/logger.py +++ b/src/utils/logger.py @@ -3,6 +3,7 @@ import sys from pathlib import Path from typing import Optional + def setup_logger( name: str = None, level: str = "INFO", @@ -12,80 +13,79 @@ def setup_logger( ) -> logging.Logger: """ 设置并返回一个配置好的logger实例 - + Args: name: logger的名称,默认为None(使用root logger) level: 日志级别,默认为"INFO" log_file: 日志文件路径,默认为None(仅控制台输出) log_format: 日志格式 date_format: 日期格式 - + Returns: logging.Logger: 配置好的logger实例 """ # 获取logger实例 logger = logging.getLogger(name) - + # 设置日志级别 level = getattr(logging, level.upper()) logger.setLevel(level) - + print(f"添加处理器 {name} {log_file} {log_format} {date_format}") # 创建格式器 formatter = logging.Formatter(log_format, date_format) - + # 添加控制台处理器 console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(formatter) logger.addHandler(console_handler) - + # 如果指定了日志文件,添加文件处理器 if log_file: # 确保日志目录存在 log_path = Path(log_file) log_path.parent.mkdir(parents=True, exist_ok=True) - - file_handler = logging.FileHandler(log_file, encoding='utf-8') + + file_handler = logging.FileHandler(log_file, encoding="utf-8") file_handler.setFormatter(formatter) logger.addHandler(file_handler) - + # 注意:移除了 propagate = False,允许日志传递 return logger -def setup_root_logger( - level: str = "INFO", - log_file: Optional[str] = None -) -> None: + +def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None: """ 配置根日志器 - + Args: level: 日志级别 log_file: 日志文件路径 """ setup_logger(None, level, log_file) + def get_module_logger( module_name: str, level: Optional[str] = None, # 改为可选参数 - log_file: Optional[str] = None # 一般不需要单独指定 + log_file: Optional[str] = None, # 一般不需要单独指定 ) -> logging.Logger: """ 获取模块级别的logger - + Args: module_name: 模块名称,通常传入__name__ level: 可选的日志级别,如果不指定则继承父级配置 log_file: 可选的日志文件路径,一般不需要指定 """ logger = logging.getLogger(module_name) - + # 只有显式指定了level才设置 if level: logger.setLevel(getattr(logging, level.upper())) - + # 只有显式指定了log_file才添加文件处理器 if log_file: setup_logger(module_name, level or "INFO", log_file) - - return logger \ No newline at end of file + + return logger diff --git a/test_main.py b/test_main.py index 4fea328..caaa40e 100644 --- a/test_main.py +++ b/test_main.py @@ -1,10 +1,15 @@ -from tests.functor.vad_test import test_vad_functor +""" +测试主函数 +请在tests目录下创建测试文件, 并在此文件中调用 +""" + from tests.pipeline.asr_test import test_asr_pipeline from src.utils.logger import get_module_logger, setup_root_logger setup_root_logger(level="INFO", log_file="logs/test_main.log") logger = get_module_logger(__name__) +# from tests.functor.vad_test import test_vad_functor # logger.info("开始测试VAD函数器") # test_vad_functor() diff --git a/tests/__init__.py b/tests/__init__.py index 4ffb53d..259625c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -"""FunASR WebSocket服务测试模块""" \ No newline at end of file +"""FunASR WebSocket服务测试模块""" diff --git a/tests/functor/vad_test.py b/tests/functor/vad_test.py index ab3abc7..97016dd 100644 --- a/tests/functor/vad_test.py +++ b/tests/functor/vad_test.py @@ -2,6 +2,7 @@ Functor测试 VAD测试 """ + from src.functor.vad_functor import VADFunctor from src.functor.asr_functor import ASRFunctor from src.functor.spk_functor import SPKFunctor @@ -21,6 +22,7 @@ logger = get_module_logger(__name__) model_loader = ModelLoader() + def test_vad_functor(): # 加载模型 args = { @@ -38,9 +40,9 @@ def test_vad_functor(): chunk_stride=1600, sample_rate=sample_rate, sample_width=16, - channels=1 + channels=1, ) - chunk_stride = int(audio_config.chunk_size*sample_rate/1000) + chunk_stride = int(audio_config.chunk_size * sample_rate / 1000) audio_config.chunk_stride = chunk_stride # 创建输入队列 input_queue = Queue() @@ -62,9 +64,7 @@ def test_vad_functor(): vad_functor.add_callback(lambda x: vad2asr_queue.put(x)) vad_functor.add_callback(lambda x: vad2spk_queue.put(x)) # 设置模型 - vad_functor.set_model({ - 'vad': model_loader.models['vad'] - }) + vad_functor.set_model({"vad": model_loader.models["vad"]}) # 启动VAD函数器 vad_functor.run() @@ -77,9 +77,7 @@ def test_vad_functor(): # 设置回调函数 asr_functor.add_callback(lambda x: print(f"asr callback: {x}")) # 设置模型 - asr_functor.set_model({ - 'asr': model_loader.models['asr'] - }) + asr_functor.set_model({"asr": model_loader.models["asr"]}) # 启动ASR函数器 asr_functor.run() @@ -92,23 +90,25 @@ def test_vad_functor(): # 设置回调函数 spk_functor.add_callback(lambda x: print(f"spk callback: {x}")) # 设置模型 - spk_functor.set_model({ - # 'spk': model_loader.models['spk'] - 'spk': 'fake_spk' - }) + spk_functor.set_model( + { + # 'spk': model_loader.models['spk'] + "spk": "fake_spk" + } + ) # 启动SPK函数器 spk_functor.run() - f_binary = f_data audio_clip_len = 200 - print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}") + print( + f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}" + ) for i in range(0, len(f_binary), audio_clip_len): - binary_data = f_binary[i:i+audio_clip_len] + binary_data = f_binary[i : i + audio_clip_len] input_queue.put(binary_data) # 等待VAD函数器结束 - vad_functor.stop() print("[vad_test] VAD函数器结束") @@ -119,4 +119,6 @@ def test_vad_functor(): if OVERWATCH: for index in range(len(audio_binary_data_list)): save_path = f"tests/vad_test_output_{index}.wav" - soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate) + soundfile.write( + save_path, audio_binary_data_list[index].binary_data, sample_rate + ) diff --git a/tests/modelsuse.py b/tests/modelsuse.py index 30d035d..8123d3f 100644 --- a/tests/modelsuse.py +++ b/tests/modelsuse.py @@ -1,10 +1,21 @@ +""" +模型使用测试 +此处主要用于各类调用模型的处理数据与输出格式 +请在主目录下test_main.py中调用 +将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配。 +""" + from funasr import AutoModel from typing import List, Dict, Any from src.models import VADResponse import time + def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]: - chunk_size = 100 # ms + """ + 在线VAD模型使用 + """ + chunk_size = 100 # ms model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True) vad_result = VADResponse() @@ -16,12 +27,14 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]: chunk_stride = int(chunk_size * sample_rate / 1000) cache = {} - total_chunk_num = int(len((speech)-1)/chunk_stride+1) + total_chunk_num = int(len((speech) - 1) / chunk_stride + 1) for i in range(total_chunk_num): time.sleep(0.1) - speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] + speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride] is_final = i == total_chunk_num - 1 - res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size) + res = model.generate( + input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size + ) if len(res[0]["value"]): vad_result += VADResponse.from_raw(res) for item in res[0]["value"]: @@ -32,44 +45,64 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]: # print(item) return vad_result + def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]: + """ + 在线VAD模型使用 + 测试LogicTrager + 在Rebuild版本后LogicTrager中已弃用 + """ from src.logic_trager import LogicTrager import soundfile from src.config import parse_args + args = parse_args() # from src.functor.model_loader import load_models # models = load_models(args) from src.model_loader import ModelLoader + models = ModelLoader(args) - chunk_size = 200 # ms + chunk_size = 200 # ms from src.models import AudioBinary_Config import soundfile speech, sample_rate = soundfile.read(file_path) chunk_stride = int(chunk_size * sample_rate / 1000) - audio_config = AudioBinary_Config(sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size) + audio_config = AudioBinary_Config( + sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size + ) logic_trager = LogicTrager(models=models, audio_config=audio_config) - for i in range(len(speech)//chunk_stride+1): - speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] + for i in range(len(speech) // chunk_stride + 1): + speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride] logic_trager.push_binary_data(speech_chunk) # for item in items: # print(item) return None + def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]: + """ + ASR模型使用 + 离线ASR模型使用 + """ from funasr import AutoModel - model = AutoModel(model="paraformer-zh", model_revision="v2.0.4", - vad_model="fsmn-vad", vad_model_revision="v2.0.4", - # punc_model="ct-punc-c", punc_model_revision="v2.0.4", - spk_model="cam++", spk_model_revision="v2.0.2", - spk_mode="vad_segment", - auto_update=False, - ) + + model = AutoModel( + model="paraformer-zh", + model_revision="v2.0.4", + vad_model="fsmn-vad", + vad_model_revision="v2.0.4", + # punc_model="ct-punc-c", punc_model_revision="v2.0.4", + spk_model="cam++", + spk_model_revision="v2.0.2", + spk_mode="vad_segment", + auto_update=False, + ) import soundfile @@ -80,7 +113,9 @@ def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]: result = model.generate(speech) return result -if __name__ == "__main__": - # vad_result = vad_model_use_online("tests/vad_example.wav") - vad_result = vad_model_use_online_logic("tests/vad_example.wav") - # print(vad_result) \ No newline at end of file + +# if __name__ == "__main__": +# 请在主目录下调用test_main.py文件进行测试 +# vad_result = vad_model_use_online("tests/vad_example.wav") +# vad_result = vad_model_use_online_logic("tests/vad_example.wav") +# print(vad_result) diff --git a/tests/pipeline/asr_test.py b/tests/pipeline/asr_test.py index 388009d..631383c 100644 --- a/tests/pipeline/asr_test.py +++ b/tests/pipeline/asr_test.py @@ -2,7 +2,9 @@ Pipeline测试 VAD+ASR+SPK(FAKE) """ + from src.pipeline.ASRpipeline import ASRPipeline +from src.pipeline import PipelineFactory from src.models import AudioBinary_data_list, AudioBinary_Config from src.model_loader import ModelLoader from queue import Queue @@ -18,6 +20,7 @@ OVAERWATCH = False model_loader = ModelLoader() + def test_asr_pipeline(): # 加载模型 args = { @@ -36,9 +39,9 @@ def test_asr_pipeline(): chunk_stride=1600, sample_rate=sample_rate, sample_width=16, - channels=1 + channels=1, ) - chunk_stride = int(audio_config.chunk_size*sample_rate/1000) + chunk_stride = int(audio_config.chunk_size * sample_rate / 1000) audio_config.chunk_stride = chunk_stride # 创建参数Dict @@ -52,29 +55,39 @@ def test_asr_pipeline(): input_queue = Queue() # 创建Pipeline - asr_pipeline = ASRPipeline() - asr_pipeline.set_models(models) - asr_pipeline.set_config(config) - asr_pipeline.set_audio_binary(audio_binary_data_list) - asr_pipeline.set_input_queue(input_queue) - asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}")) - asr_pipeline.bake() - + # asr_pipeline = ASRPipeline() + # asr_pipeline.set_models(models) + # asr_pipeline.set_config(config) + # asr_pipeline.set_audio_binary(audio_binary_data_list) + # asr_pipeline.set_input_queue(input_queue) + # asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}")) + # asr_pipeline.bake() + asr_pipeline = PipelineFactory.create_pipeline( + pipeline_name = "ASRpipeline", + models=models, + config=config, + audio_binary=audio_binary_data_list, + input_queue=input_queue, + callback=lambda x: print(f"pipeline callback: {x}") + ) + # 运行Pipeline asr_instance = asr_pipeline.run() + audio_clip_len = 200 - print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}") + print( + f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}" + ) for i in range(0, len(audio_data), audio_clip_len): - input_queue.put(audio_data[i:i+audio_clip_len]) + input_queue.put(audio_data[i : i + audio_clip_len]) # time.sleep(10) # input_queue.put(None) # 等待Pipeline结束 # asr_instance.join() - + time.sleep(5) asr_pipeline.stop() # asr_pipeline.stop() - diff --git a/tests/test_config.py b/tests/test_config.py index 6710998..4540910 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -10,23 +10,23 @@ import os from unittest.mock import patch # 将src目录添加到路径 -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.config import parse_args def test_default_args(): """测试默认参数值""" - with patch('sys.argv', ['script.py']): + with patch("sys.argv", ["script.py"]): args = parse_args() - + # 检查服务器参数 assert args.host == "0.0.0.0" assert args.port == 10095 - + # 检查SSL参数 assert args.certfile == "" assert args.keyfile == "" - + # 检查模型参数 assert "paraformer" in args.asr_model assert args.asr_model_revision == "v2.0.4" @@ -36,7 +36,7 @@ def test_default_args(): assert args.vad_model_revision == "v2.0.4" assert "punc" in args.punc_model assert args.punc_model_revision == "v2.0.4" - + # 检查硬件配置 assert args.ngpu == 1 assert args.device == "cuda" @@ -46,19 +46,26 @@ def test_default_args(): def test_custom_args(): """测试自定义参数值""" test_args = [ - 'script.py', - '--host', 'localhost', - '--port', '8080', - '--certfile', 'cert.pem', - '--keyfile', 'key.pem', - '--asr_model', 'custom_model', - '--ngpu', '0', - '--device', 'cpu' + "script.py", + "--host", + "localhost", + "--port", + "8080", + "--certfile", + "cert.pem", + "--keyfile", + "key.pem", + "--asr_model", + "custom_model", + "--ngpu", + "0", + "--device", + "cpu", ] - - with patch('sys.argv', test_args): + + with patch("sys.argv", test_args): args = parse_args() - + # 检查自定义参数 assert args.host == "localhost" assert args.port == 8080 @@ -66,4 +73,4 @@ def test_custom_args(): assert args.keyfile == "key.pem" assert args.asr_model == "custom_model" assert args.ngpu == 0 - assert args.device == "cpu" \ No newline at end of file + assert args.device == "cpu"