[代码结构]black . 对所有文件格式调整,无功能变化。

This commit is contained in:
Ziyang.Zhang 2025-06-12 15:49:43 +08:00
parent 5b94c40016
commit 5a820b49e4
28 changed files with 543 additions and 421 deletions

22
main.py
View File

@ -1,11 +1,7 @@
from funasr import AutoModel from funasr import AutoModel
chunk_size = 200 # ms chunk_size = 200 # ms
model = AutoModel( model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
model="fsmn-vad",
model_revision="v2.0.4",
disable_update=True
)
import soundfile import soundfile
@ -14,16 +10,16 @@ speech, sample_rate = soundfile.read(wav_file)
chunk_stride = int(chunk_size * sample_rate / 1000) chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {} cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1) total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num): for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1 is_final = i == total_chunk_num - 1
res = model.generate( res = model.generate(
input=speech_chunk, input=speech_chunk,
cache=cache, cache=cache,
is_final=is_final, is_final=is_final,
chunk_size=chunk_size, chunk_size=chunk_size,
disable_pbar=True disable_pbar=True,
) )
if len(res[0]["value"]): if len(res[0]["value"]):
print(res) print(res)
@ -31,4 +27,4 @@ for i in range(total_chunk_num):
print(f"len(speech): {len(speech)}") print(f"len(speech): {len(speech)}")
print(f"len(speech_chunk): {len(speech_chunk)}") print(f"len(speech_chunk): {len(speech_chunk)}")
print(f"total_chunk_num: {total_chunk_num}") print(f"total_chunk_num: {total_chunk_num}")
print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}") print(f"generateconfig: chunk_size: {chunk_size}, chunk_stride: {chunk_stride}")

View File

@ -11,4 +11,4 @@ FunASR WebSocket服务
- 支持多种识别模式(2pass/online/offline) - 支持多种识别模式(2pass/online/offline)
""" """
__version__ = "0.1.0" __version__ = "0.1.0"

View File

@ -46,7 +46,7 @@ class AudioBinary:
else: else:
raise ValueError("参数类型错误") raise ValueError("参数类型错误")
def add_slice_listener(self, slice_listener: callable): def add_slice_listener(self, slice_listener: callable) -> None:
""" """
添加切片监听器 添加切片监听器
参数: 参数:
@ -98,10 +98,10 @@ class AudioBinary:
self._binary_data_list.rewrite(target_index, binary_data) self._binary_data_list.rewrite(target_index, binary_data)
def get_binary_data( def get_binary_data(
self, self,
start: int = 0, start: int = 0,
end: Optional[int] = None, end: Optional[int] = None,
) -> Optional[bytes]: ) -> Optional[bytes]:
""" """
获取指定索引的音频数据块 获取指定索引的音频数据块
参数: 参数:
@ -128,7 +128,7 @@ class AudioChunk:
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑 此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑
""" """
_instance = None _instance: Optional[AudioChunk] = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
@ -138,7 +138,7 @@ class AudioChunk:
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs) cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
return cls._instance return cls._instance
def __init__(self): def __init__(self) -> None:
""" """
初始化AudioChunk实例 初始化AudioChunk实例
""" """
@ -146,10 +146,10 @@ class AudioChunk:
self._slice_listener: List[callable] = [] self._slice_listener: List[callable] = []
def get_audio_binary( def get_audio_binary(
self, self,
binary_name: Optional[str] = None, binary_name: Optional[str] = None,
audio_config: Optional[AudioBinary_Config] = None, audio_config: Optional[AudioBinary_Config] = None,
) -> AudioBinary: ) -> AudioBinary:
""" """
获取音频数据块 获取音频数据块
参数: 参数:

View File

@ -10,116 +10,79 @@ import argparse
def parse_args(): def parse_args():
""" """
解析命令行参数 解析命令行参数
返回: 返回:
argparse.Namespace: 解析后的参数对象 argparse.Namespace: 解析后的参数对象
""" """
parser = argparse.ArgumentParser(description="FunASR WebSocket服务器") parser = argparse.ArgumentParser(description="FunASR WebSocket服务器")
# 服务器配置 # 服务器配置
parser.add_argument( parser.add_argument(
"--host", "--host",
type=str, type=str,
default="0.0.0.0", default="0.0.0.0",
help="服务器主机地址例如localhost, 0.0.0.0" help="服务器主机地址例如localhost, 0.0.0.0",
) )
parser.add_argument( parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")
"--port",
type=int,
default=10095,
help="WebSocket服务器端口"
)
# SSL配置 # SSL配置
parser.add_argument( parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
"--certfile", parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")
type=str,
default="",
help="SSL证书文件路径"
)
parser.add_argument(
"--keyfile",
type=str,
default="",
help="SSL密钥文件路径"
)
# ASR模型配置 # ASR模型配置
parser.add_argument( parser.add_argument(
"--asr_model", "--asr_model",
type=str, type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="离线ASR模型从ModelScope获取" help="离线ASR模型从ModelScope获取",
) )
parser.add_argument( parser.add_argument(
"--asr_model_revision", "--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
type=str,
default="v2.0.4",
help="离线ASR模型版本"
) )
# 在线ASR模型配置 # 在线ASR模型配置
parser.add_argument( parser.add_argument(
"--asr_model_online", "--asr_model_online",
type=str, type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="在线ASR模型从ModelScope获取" help="在线ASR模型从ModelScope获取",
) )
parser.add_argument( parser.add_argument(
"--asr_model_online_revision", "--asr_model_online_revision",
type=str, type=str,
default="v2.0.4", default="v2.0.4",
help="在线ASR模型版本" help="在线ASR模型版本",
) )
# VAD模型配置 # VAD模型配置
parser.add_argument( parser.add_argument(
"--vad_model", "--vad_model",
type=str, type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="VAD语音活动检测模型从ModelScope获取" help="VAD语音活动检测模型从ModelScope获取",
) )
parser.add_argument( parser.add_argument(
"--vad_model_revision", "--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
type=str,
default="v2.0.4",
help="VAD模型版本"
) )
# 标点符号模型配置 # 标点符号模型配置
parser.add_argument( parser.add_argument(
"--punc_model", "--punc_model",
type=str, type=str,
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="标点符号模型从ModelScope获取" help="标点符号模型从ModelScope获取",
) )
parser.add_argument( parser.add_argument(
"--punc_model_revision", "--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
type=str,
default="v2.0.4",
help="标点符号模型版本"
) )
# 硬件配置 # 硬件配置
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量0表示仅使用CPU")
parser.add_argument( parser.add_argument(
"--ngpu", "--device", type=str, default="cuda", help="设备类型cuda或cpu"
type=int,
default=1,
help="GPU数量0表示仅使用CPU"
) )
parser.add_argument( parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")
"--device",
type=str,
default="cuda",
help="设备类型cuda或cpu"
)
parser.add_argument(
"--ncpu",
type=int,
default=4,
help="CPU核心数"
)
return parser.parse_args() return parser.parse_args()
@ -127,4 +90,4 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
print("配置参数:") print("配置参数:")
for arg in vars(args): for arg in vars(args):
print(f" {arg}: {getattr(args, arg)}") print(f" {arg}: {getattr(args, arg)}")

View File

@ -1,4 +1,4 @@
from .vad_functor import VADFunctor from .vad_functor import VADFunctor
from .base import FunctorFactory from .base import FunctorFactory
__all__ = ["VADFunctor", "FunctorFactory"] __all__ = ["VADFunctor", "FunctorFactory"]

View File

@ -2,8 +2,9 @@
ASRFunctor ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback 负责对音频片段进行ASR处理, 以ASR_Result进行callback
""" """
from src.functor.base import BaseFunctor from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config,VAD_Functor_result from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
from typing import Callable, List from typing import Callable, List
from queue import Queue, Empty from queue import Queue, Empty
import threading import threading
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
class ASRFunctor(BaseFunctor): class ASRFunctor(BaseFunctor):
""" """
ASRFunctor ASRFunctor
@ -51,26 +53,26 @@ class ASRFunctor(BaseFunctor):
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
""" """
pass pass
def set_input_queue(self, queue: Queue) -> None: def set_input_queue(self, queue: Queue) -> None:
""" """
设置监听的输入消息队列 设置监听的输入消息队列
""" """
self._input_queue = queue self._input_queue = queue
def set_model(self, model: dict) -> None: def set_model(self, model: dict) -> None:
""" """
设置推理模型 设置推理模型
""" """
self._model = model self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None: def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
""" """
设置音频配置 设置音频配置
""" """
self._audio_config = audio_config self._audio_config = audio_config
logger.debug("ASRFunctor设置音频配置: %s", self._audio_config) logger.debug("ASRFunctor设置音频配置: %s", self._audio_config)
def add_callback(self, callback: Callable) -> None: def add_callback(self, callback: Callable) -> None:
""" """
向自身的_callback: List[Callable]回调函数列表中添加回调函数 向自身的_callback: List[Callable]回调函数列表中添加回调函数
@ -78,12 +80,12 @@ class ASRFunctor(BaseFunctor):
if not isinstance(self._callback, list): if not isinstance(self._callback, list):
self._callback = [] self._callback = []
self._callback.append(callback) self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None: def _do_callback(self, result: List[str]) -> None:
""" """
回调函数 回调函数
""" """
text = result[0]['text'].replace(" ", "") text = result[0]["text"].replace(" ", "")
for callback in self._callback: for callback in self._callback:
callback(text) callback(text)
@ -98,7 +100,7 @@ class ASRFunctor(BaseFunctor):
hotwords=self._hotwords, hotwords=self._hotwords,
) )
self._do_callback(result) self._do_callback(result)
def _run(self) -> None: def _run(self) -> None:
""" """
线程运行逻辑 线程运行逻辑
@ -132,7 +134,7 @@ class ASRFunctor(BaseFunctor):
self._thread = threading.Thread(target=self._run, daemon=True) self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start() self._thread.start()
return self._thread return self._thread
def _pre_check(self) -> bool: def _pre_check(self) -> bool:
""" """
预检查 预检查
@ -146,7 +148,7 @@ class ASRFunctor(BaseFunctor):
if self._callback is None: if self._callback is None:
raise ValueError("回调函数未设置") raise ValueError("回调函数未设置")
return True return True
def stop(self) -> bool: def stop(self) -> bool:
""" """
停止线程 停止线程
@ -157,5 +159,3 @@ class ASRFunctor(BaseFunctor):
with self._status_lock: with self._status_lock:
self._is_running = False self._is_running = False
return not self._thread.is_alive() return not self._thread.is_alive()

View File

@ -4,18 +4,20 @@ Functor基础模块
该模块定义了Functor的基类,所有功能性的类(如VADPUNCASRSPK等)都应继承自这个基类 该模块定义了Functor的基类,所有功能性的类(如VADPUNCASRSPK等)都应继承自这个基类
基类提供了数据处理的基本框架,包括: 基类提供了数据处理的基本框架,包括:
- 回调函数管理 - 回调函数管理
- 模型配置管理 - 模型配置管理
- 线程运行控制 - 线程运行控制
主要类: 主要类:
BaseFunctor: Functor抽象类 BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类 FunctorFactory: Functor工厂类
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List from typing import Callable, List
from queue import Queue from queue import Queue
import threading import threading
class BaseFunctor(ABC): class BaseFunctor(ABC):
""" """
Functor抽象类 Functor抽象类
@ -27,9 +29,7 @@ class BaseFunctor(ABC):
_model (dict): 存储模型相关的配置和实例 _model (dict): 存储模型相关的配置和实例
""" """
def __init__( def __init__(self):
self
):
""" """
初始化函数器 初始化函数器
@ -38,8 +38,8 @@ class BaseFunctor(ABC):
model (dict): 模型相关的配置和实例 model (dict): 模型相关的配置和实例
""" """
self._callback: List[Callable] = [] self._callback: List[Callable] = []
self._model: dict = {} self._model: dict = {}
# flag # flag
self._is_running: bool = False self._is_running: bool = False
self._stop_event: bool = False self._stop_event: bool = False
# 状态锁 # 状态锁
@ -91,7 +91,7 @@ class BaseFunctor(ABC):
返回: 返回:
线程实例 线程实例
""" """
@abstractmethod @abstractmethod
def _pre_check(self): def _pre_check(self):
""" """
@ -111,13 +111,12 @@ class BaseFunctor(ABC):
""" """
class FunctorFactory: class FunctorFactory:
""" """
Functor工厂类 Functor工厂类
该工厂类负责创建和配置Functor实例 该工厂类负责创建和配置Functor实例
主要方法: 主要方法:
make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor: make_funtor(funtor_name: str, config: dict, models: dict) -> BaseFunctor:
创建并配置Functor实例 创建并配置Functor实例
@ -138,58 +137,55 @@ class FunctorFactory:
""" """
if functor_name == "vad": if functor_name == "vad":
return cls._make_vadfunctor(config = config,models = models) return cls._make_vadfunctor(config=config, models=models)
elif functor_name == "asr": elif functor_name == "asr":
return cls._make_asrfunctor(config = config,models = models) return cls._make_asrfunctor(config=config, models=models)
elif functor_name == "spk": elif functor_name == "spk":
return cls._make_spkfunctor(config = config,models = models) return cls._make_spkfunctor(config=config, models=models)
else: else:
raise ValueError(f"不支持的Functor类型: {functor_name}") raise ValueError(f"不支持的Functor类型: {functor_name}")
def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor: def _make_vadfunctor(config: dict, models: dict) -> BaseFunctor:
""" """
创建VAD Functor实例 创建VAD Functor实例
""" """
from src.functor.vad_functor import VADFunctor from src.functor.vad_functor import VADFunctor
audio_config = config["audio_config"] audio_config = config["audio_config"]
model = { model = {"vad": models["vad"]}
"vad": models["vad"]
}
vad_functor = VADFunctor() vad_functor = VADFunctor()
vad_functor.set_audio_config(audio_config) vad_functor.set_audio_config(audio_config)
vad_functor.set_model(model) vad_functor.set_model(model)
return vad_functor return vad_functor
def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor: def _make_asrfunctor(config: dict, models: dict) -> BaseFunctor:
""" """
创建ASR Functor实例 创建ASR Functor实例
""" """
from src.functor.asr_functor import ASRFunctor from src.functor.asr_functor import ASRFunctor
audio_config = config["audio_config"] audio_config = config["audio_config"]
model = { model = {"asr": models["asr"]}
"asr": models["asr"]
}
asr_functor = ASRFunctor() asr_functor = ASRFunctor()
asr_functor.set_audio_config(audio_config) asr_functor.set_audio_config(audio_config)
asr_functor.set_model(model) asr_functor.set_model(model)
return asr_functor return asr_functor
def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor: def _make_spkfunctor(config: dict, models: dict) -> BaseFunctor:
""" """
创建SPK Functor实例 创建SPK Functor实例
""" """
from src.functor.spk_functor import SPKFunctor from src.functor.spk_functor import SPKFunctor
audio_config = config["audio_config"] audio_config = config["audio_config"]
model = { model = {"spk": models["spk"]}
"spk": models["spk"]
}
spk_functor = SPKFunctor() spk_functor = SPKFunctor()
spk_functor.set_audio_config(audio_config) spk_functor.set_audio_config(audio_config)
spk_functor.set_model(model) spk_functor.set_model(model)
return spk_functor return spk_functor

View File

@ -2,6 +2,7 @@
SpkFunctor SpkFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback 负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
""" """
from src.functor.base import BaseFunctor from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List from typing import Callable, List
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
class SPKFunctor(BaseFunctor): class SPKFunctor(BaseFunctor):
""" """
SPKFunctor SPKFunctor
@ -33,25 +35,24 @@ class SPKFunctor(BaseFunctor):
self._input_queue: Queue = None # 输入队列 self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置 self._audio_config: AudioBinary_Config = None # 音频配置
def reset_cache(self) -> None: def reset_cache(self) -> None:
""" """
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务 重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
""" """
pass pass
def set_input_queue(self, queue: Queue) -> None: def set_input_queue(self, queue: Queue) -> None:
""" """
设置监听的输入消息队列 设置监听的输入消息队列
""" """
self._input_queue = queue self._input_queue = queue
def set_model(self, model: dict) -> None: def set_model(self, model: dict) -> None:
""" """
设置推理模型 设置推理模型
""" """
self._model = model self._model = model
def set_audio_config(self, audio_config: AudioBinary_Config) -> None: def set_audio_config(self, audio_config: AudioBinary_Config) -> None:
""" """
设置音频配置 设置音频配置
@ -66,7 +67,7 @@ class SPKFunctor(BaseFunctor):
if not isinstance(self._callback, list): if not isinstance(self._callback, list):
self._callback = [] self._callback = []
self._callback.append(callback) self._callback.append(callback)
def _do_callback(self, result: List[str]) -> None: def _do_callback(self, result: List[str]) -> None:
""" """
回调函数 回调函数
@ -83,9 +84,9 @@ class SPKFunctor(BaseFunctor):
# input=binary_data, # input=binary_data,
# chunk_size=self._audio_config.chunk_size, # chunk_size=self._audio_config.chunk_size,
# ) # )
result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}] result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
self._do_callback(result) self._do_callback(result)
def _run(self) -> None: def _run(self) -> None:
""" """
线程运行逻辑 线程运行逻辑
@ -108,7 +109,7 @@ class SPKFunctor(BaseFunctor):
except Exception as e: except Exception as e:
logger.error("SpkFunctor运行时发生错误: %s", e) logger.error("SpkFunctor运行时发生错误: %s", e)
raise e raise e
def run(self) -> threading.Thread: def run(self) -> threading.Thread:
""" """
启动线程 启动线程
@ -119,7 +120,7 @@ class SPKFunctor(BaseFunctor):
self._thread = threading.Thread(target=self._run, daemon=True) self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start() self._thread.start()
return self._thread return self._thread
def _pre_check(self) -> bool: def _pre_check(self) -> bool:
""" """
预检查 预检查
@ -131,7 +132,7 @@ class SPKFunctor(BaseFunctor):
if self._callback is None: if self._callback is None:
raise ValueError("回调函数未设置") raise ValueError("回调函数未设置")
return True return True
def stop(self) -> bool: def stop(self) -> bool:
""" """
停止线程 停止线程
@ -142,5 +143,3 @@ class SPKFunctor(BaseFunctor):
with self._status_lock: with self._status_lock:
self._is_running = False self._is_running = False
return not self._thread.is_alive() return not self._thread.is_alive()

View File

@ -2,6 +2,7 @@
VADFunctor VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback 负责对音频片段进行VAD处理, 以VAD_Result进行callback
""" """
import threading import threading
from queue import Empty, Queue from queue import Empty, Queue
from typing import List, Any, Callable from typing import List, Any, Callable
@ -105,9 +106,7 @@ class VADFunctor(BaseFunctor):
self._callback = [] self._callback = []
self._callback.append(callback) self._callback.append(callback)
def _do_callback( def _do_callback(self, result: List[List[int]]) -> None:
self, result: List[List[int]]
) -> None:
""" """
回调函数 回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice

View File

@ -6,34 +6,36 @@
from src.utils.logger import get_module_logger from src.utils.logger import get_module_logger
from typing import Any, Dict, Type, Callable from typing import Any, Dict, Type, Callable
# 配置日志 # 配置日志
logger = get_module_logger(__name__, level="INFO") logger = get_module_logger(__name__, level="INFO")
class AutoAfterMeta(type): class AutoAfterMeta(type):
""" """
自动调用__after__函数的元类 自动调用__after__函数的元类
实现单例模式 实现单例模式
""" """
_instances: Dict[Type, Any] = {} # 存储单例实例 _instances: Dict[Type, Any] = {} # 存储单例实例
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
# 遍历所有属性 # 遍历所有属性
for attr_name, attr_value in attrs.items(): for attr_name, attr_value in attrs.items():
# 如果是函数且不是以_开头 # 如果是函数且不是以_开头
if callable(attr_value) and not attr_name.startswith('__'): if callable(attr_value) and not attr_name.startswith("__"):
# 获取原函数 # 获取原函数
original_func = attr_value original_func = attr_value
# 创建包装函数 # 创建包装函数
def make_wrapper(func): def make_wrapper(func):
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
# 执行原函数 # 执行原函数
result = func(self, *args, **kwargs) result = func(self, *args, **kwargs)
# 构建_after_函数名 # 构建_after_函数名
after_func_name = f"__after__{func.__name__}" after_func_name = f"__after__{func.__name__}"
# 检查是否存在对应的_after_函数 # 检查是否存在对应的_after_函数
if hasattr(self, after_func_name): if hasattr(self, after_func_name):
after_func = getattr(self, after_func_name) after_func = getattr(self, after_func_name)
@ -43,17 +45,18 @@ class AutoAfterMeta(type):
after_func() after_func()
except Exception as e: except Exception as e:
logger.error(f"调用{after_func_name}时出错: {e}") logger.error(f"调用{after_func_name}时出错: {e}")
return result return result
return wrapper return wrapper
# 替换原函数 # 替换原函数
attrs[attr_name] = make_wrapper(original_func) attrs[attr_name] = make_wrapper(original_func)
# 创建类 # 创建类
new_class = super().__new__(cls, name, bases, attrs) new_class = super().__new__(cls, name, bases, attrs)
return new_class return new_class
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
""" """
重写__call__方法实现单例模式 重写__call__方法实现单例模式
@ -65,9 +68,10 @@ class AutoAfterMeta(type):
logger.info(f"创建{cls.__name__}的新实例") logger.info(f"创建{cls.__name__}的新实例")
else: else:
logger.debug(f"返回{cls.__name__}的现有实例") logger.debug(f"返回{cls.__name__}的现有实例")
return cls._instances[cls] return cls._instances[cls]
""" """
整体识别的处理逻辑 整体识别的处理逻辑
1.压入二进制音频信息 1.压入二进制音频信息
@ -88,10 +92,12 @@ from src.models import AudioBinary_Config
from src.models import AudioBinary_Chunk from src.models import AudioBinary_Chunk
from typing import List from typing import List
class LogicTrager(metaclass=AutoAfterMeta): class LogicTrager(metaclass=AutoAfterMeta):
"""逻辑触发器类""" """逻辑触发器类"""
def __init__(self, def __init__(
self,
audio_chunk_max_size: int = 1024 * 1024 * 10, audio_chunk_max_size: int = 1024 * 1024 * 10,
audio_config: AudioBinary_Config = None, audio_config: AudioBinary_Config = None,
result_callback: Callable = None, result_callback: Callable = None,
@ -99,46 +105,51 @@ class LogicTrager(metaclass=AutoAfterMeta):
): ):
"""初始化""" """初始化"""
# 存储音频块 # 存储音频块
self._audio_chunk : List[AudioBinary_Chunk] = [] self._audio_chunk: List[AudioBinary_Chunk] = []
# 存储二进制数据 # 存储二进制数据
self._audio_chunk_binary = b'' self._audio_chunk_binary = b""
self._audio_chunk_max_size = audio_chunk_max_size self._audio_chunk_max_size = audio_chunk_max_size
# 音频参数 # 音频参数
self._audio_config = audio_config if audio_config is not None else AudioBinary_Config() self._audio_config = (
audio_config if audio_config is not None else AudioBinary_Config()
)
# 结果队列 # 结果队列
self._result_queue = [] self._result_queue = []
# 聚合结果回调函数 # 聚合结果回调函数
self._aggregate_result_callback = result_callback self._aggregate_result_callback = result_callback
# 组件 # 组件
self._vad = VAD(VAD_model = models.get("vad"), audio_config = self._audio_config) self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
self._vad.set_callback(self.push_audio_chunk) self._vad.set_callback(self.push_audio_chunk)
logger.info("初始化LogicTrager") logger.info("初始化LogicTrager")
def push_binary_data(self, chunk: bytes) -> None: def push_binary_data(self, chunk: bytes) -> None:
""" """
压入音频块至VAD模块 压入音频块至VAD模块
参数: 参数:
chunk: 音频数据块 chunk: 音频数据块
""" """
# print("LogicTrager push_binary_data", len(chunk)) # print("LogicTrager push_binary_data", len(chunk))
self._vad.push_binary_data(chunk) self._vad.push_binary_data(chunk)
self.__after__push_binary_data() self.__after__push_binary_data()
def __after__push_binary_data(self) -> None: def __after__push_binary_data(self) -> None:
""" """
添加音频块后处理 添加音频块后处理
""" """
# print("LogicTrager __after__push_binary_data") # print("LogicTrager __after__push_binary_data")
self._vad.process_vad_result() self._vad.process_vad_result()
def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None: def push_audio_chunk(self, chunk: AudioBinary_Chunk) -> None:
""" """
音频处理 音频处理
""" """
logger.info("LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(chunk.start_time, chunk.end_time, len(chunk.chunk))) logger.info(
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
chunk.start_time, chunk.end_time, len(chunk.chunk)
)
)
self._audio_chunk.append(chunk) self._audio_chunk.append(chunk)
def __after__push_audio_chunk(self) -> None: def __after__push_audio_chunk(self) -> None:
@ -162,4 +173,4 @@ class LogicTrager(metaclass=AutoAfterMeta):
def __call__(self): def __call__(self):
"""调用函数""" """调用函数"""
pass pass

View File

@ -115,7 +115,7 @@ class ModelLoader:
self.models = {} self.models = {}
# 加载离线ASR模型 # 加载离线ASR模型
# 检查对应键是否存在 # 检查对应键是否存在
model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk'] model_list = ["asr", "asr_online", "vad", "punc", "spk"]
for model_name in model_list: for model_name in model_list:
name_model = f"{model_name}_model" name_model = f"{model_name}_model"
name_model_revision = f"{model_name}_model_revision" name_model_revision = f"{model_name}_model_revision"

View File

@ -1,3 +1,9 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result from .vad import VAD_Functor_result
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"]
__all__ = [
"AudioBinary_Config",
"AudioBinary_data_list",
"_AudioBinary_data",
"VAD_Functor_result",
]

View File

@ -8,8 +8,10 @@ logger = get_module_logger(__name__)
binary_data_types = (bytes, numpy.ndarray) binary_data_types = (bytes, numpy.ndarray)
class AudioBinary_Config(BaseModel): class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息""" """二进制音频块配置信息"""
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -37,14 +39,16 @@ class AudioBinary_Config(BaseModel):
""" """
return int(frame * 1000 / self.sample_rate) return int(frame * 1000 / self.sample_rate)
class _AudioBinary_data(BaseModel): class _AudioBinary_data(BaseModel):
"""音频数据""" """音频数据"""
binary_data: binary_data_types = Field(description="音频二进制数据", default=None) binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@validator('binary_data') @validator("binary_data")
def validate_binary_data(cls, v): def validate_binary_data(cls, v):
""" """
验证音频数据 验证音频数据
@ -54,7 +58,11 @@ class _AudioBinary_data(BaseModel):
binary_data_types: 音频数据 binary_data_types: 音频数据
""" """
if not isinstance(v, (bytes, numpy.ndarray)): if not isinstance(v, (bytes, numpy.ndarray)):
logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v)) logger.warning(
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
cls.__class__.__name__,
type(v),
)
return v return v
def __len__(self): def __len__(self):
@ -64,14 +72,18 @@ class _AudioBinary_data(BaseModel):
int: 音频数据长度 int: 音频数据长度
""" """
return len(self.binary_data) return len(self.binary_data)
def __init__(self, binary_data: binary_data_types): def __init__(self, binary_data: binary_data_types):
""" """
初始化音频数据 初始化音频数据
Args: Args:
binary_data: 音频数据 binary_data: 音频数据
""" """
logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data)) logger.debug(
"[%s]初始化音频数据, 数据类型为%s",
self.__class__.__name__,
type(binary_data),
)
super().__init__(binary_data=binary_data) super().__init__(binary_data=binary_data)
def __getitem__(self): def __getitem__(self):
@ -82,9 +94,13 @@ class _AudioBinary_data(BaseModel):
""" """
return self.binary_data return self.binary_data
class AudioBinary_data_list(BaseModel): class AudioBinary_data_list(BaseModel):
"""音频数据列表""" """音频数据列表"""
binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[])
binary_data_list: List[_AudioBinary_data] = Field(
description="音频数据列表", default=[]
)
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -118,6 +134,7 @@ class AudioBinary_data_list(BaseModel):
""" """
return len(self.binary_data_list) return len(self.binary_data_list)
# class AudioBinary_Slice(BaseModel): # class AudioBinary_Slice(BaseModel):
# """音频块切片""" # """音频块切片"""
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None) # target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)
@ -138,4 +155,4 @@ class AudioBinary_data_list(BaseModel):
# return v # return v
# def __call__(self): # def __call__(self):
# return self.target_Binary(self.start_index, self.end_index) # return self.target_Binary(self.start_index, self.end_index)

View File

@ -2,23 +2,27 @@ from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable, Any from typing import List, Optional, Callable, Any
from .audio import AudioBinary_data_list, _AudioBinary_data from .audio import AudioBinary_data_list, _AudioBinary_data
class VAD_Functor_result(BaseModel): class VAD_Functor_result(BaseModel):
""" """
VADFunctor结果 VADFunctor结果
""" """
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表") audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
audiobinary_index: int = Field(description="音频数据索引") audiobinary_index: int = Field(description="音频数据索引")
audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data") audiobinary_data: _AudioBinary_data = Field(
description="音频数据, 指向AudioBinary_data"
)
start_time: int = Field(description="开始时间", is_required=True) start_time: int = Field(description="开始时间", is_required=True)
end_time: int = Field(description="结束时间", is_required=True) end_time: int = Field(description="结束时间", is_required=True)
@validator('audiobinary_data_list') @validator("audiobinary_data_list")
def validate_audiobinary_data_list(cls, v): def validate_audiobinary_data_list(cls, v):
if not isinstance(v, AudioBinary_data_list): if not isinstance(v, AudioBinary_data_list):
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型") raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
return v return v
@validator('audiobinary_index') @validator("audiobinary_index")
def validate_audiobinary_index(cls, v): def validate_audiobinary_index(cls, v):
if not isinstance(v, int): if not isinstance(v, int):
raise ValueError("audiobinary_index必须是int类型") raise ValueError("audiobinary_index必须是int类型")
@ -26,35 +30,35 @@ class VAD_Functor_result(BaseModel):
raise ValueError("audiobinary_index必须大于0") raise ValueError("audiobinary_index必须大于0")
return v return v
@validator('audiobinary_data') @validator("audiobinary_data")
def validate_audiobinary_data(cls, v): def validate_audiobinary_data(cls, v):
if not isinstance(v, _AudioBinary_data): if not isinstance(v, _AudioBinary_data):
raise ValueError("audiobinary_data必须是AudioBinary_data类型") raise ValueError("audiobinary_data必须是AudioBinary_data类型")
return v return v
@validator('start_time') @validator("start_time")
def validate_start_time(cls, v): def validate_start_time(cls, v):
if not isinstance(v, int): if not isinstance(v, int):
raise ValueError("start_time必须是int类型") raise ValueError("start_time必须是int类型")
if v < 0: if v < 0:
raise ValueError("start_time必须大于0") raise ValueError("start_time必须大于0")
return v return v
@validator('end_time') @validator("end_time")
def validate_end_time(cls, v, values): def validate_end_time(cls, v, values):
if not isinstance(v, int): if not isinstance(v, int):
raise ValueError("end_time必须是int类型") raise ValueError("end_time必须是int类型")
if 'start_time' in values and v <= values['start_time']: if "start_time" in values and v <= values["start_time"]:
raise ValueError("end_time必须大于start_time") raise ValueError("end_time必须大于start_time")
return v return v
@classmethod @classmethod
def create_from_push_data( def create_from_push_data(
cls, cls,
audiobinary_data_list: AudioBinary_data_list, audiobinary_data_list: AudioBinary_data_list,
data: Any, data: Any,
start_time: int, start_time: int,
end_time: int end_time: int,
): ):
""" """
创建VAD片段 创建VAD片段
@ -66,7 +70,8 @@ class VAD_Functor_result(BaseModel):
audiobinary_index=index, audiobinary_index=index,
audiobinary_data=audiobinary_data_list[index], audiobinary_data=audiobinary_data_list[index],
start_time=start_time, start_time=start_time,
end_time=end_time) end_time=end_time,
)
def __len__(self): def __len__(self):
""" """
@ -78,11 +83,9 @@ class VAD_Functor_result(BaseModel):
""" """
字符串展示内容 字符串展示内容
""" """
tostr = f'audiobinary_data_index: {self.audiobinary_index}\n' tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
tostr += f'start_time: {self.start_time}\n' tostr += f"start_time: {self.start_time}\n"
tostr += f'end_time: {self.end_time}\n' tostr += f"end_time: {self.end_time}\n"
tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n' tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n' tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
return tostr return tostr

View File

@ -7,11 +7,13 @@ import threading
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
class ASRPipeline(PipelineBase): class ASRPipeline(PipelineBase):
""" """
管道类 管道类
实现具体的处理逻辑 实现具体的处理逻辑
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """
初始化管道 初始化管道
@ -91,27 +93,24 @@ class ASRPipeline(PipelineBase):
""" """
try: try:
from src.functor import FunctorFactory from src.functor import FunctorFactory
# 加载VAD、asr、spk functor # 加载VAD、asr、spk functor
self._functor_dict["vad"] = FunctorFactory.make_functor( self._functor_dict["vad"] = FunctorFactory.make_functor(
functor_name = "vad", functor_name="vad", config=self._config, models=self._models
config = self._config,
models = self._models
) )
self._functor_dict["asr"] = FunctorFactory.make_functor( self._functor_dict["asr"] = FunctorFactory.make_functor(
functor_name = "asr", functor_name="asr", config=self._config, models=self._models
config = self._config,
models = self._models
) )
self._functor_dict["spk"] = FunctorFactory.make_functor( self._functor_dict["spk"] = FunctorFactory.make_functor(
functor_name = "spk", functor_name="spk", config=self._config, models=self._models
config = self._config,
models = self._models
) )
# 创建音频数据存储单元 # 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list() self._audio_binary_data_list = AudioBinary_data_list()
self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list) self._functor_dict["vad"].set_audio_binary_data_list(
self._audio_binary_data_list
)
# 初始化子队列 # 初始化子队列
self._subqueue_dict["original"] = Queue() self._subqueue_dict["original"] = Queue()
@ -121,7 +120,7 @@ class ASRPipeline(PipelineBase):
self._subqueue_dict["spkend"] = Queue() self._subqueue_dict["spkend"] = Queue()
# 设置子队列的输入队列 # 设置子队列的输入队列
self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"]) self._functor_dict["vad"].set_input_queue(self._subqueue_dict["original"])
self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"]) self._functor_dict["asr"].set_input_queue(self._subqueue_dict["vad2asr"])
self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"]) self._functor_dict["spk"].set_input_queue(self._subqueue_dict["vad2spk"])
@ -134,30 +133,40 @@ class ASRPipeline(PipelineBase):
""" """
带回调函数的put 带回调函数的put
""" """
def put_with_check(data: Any) -> None: def put_with_check(data: Any) -> None:
queue.put(data) queue.put(data)
callback(data) callback(data)
return put_with_check return put_with_check
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result)) self._functor_dict["asr"].add_callback(
self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result)) put_with_checkcallback(
self._subqueue_dict["asrend"], self._check_result
)
)
self._functor_dict["spk"].add_callback(
put_with_checkcallback(
self._subqueue_dict["spkend"], self._check_result
)
)
except ImportError: except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化") raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
def _check_result(self, result: Any) -> None: def _check_result(self, result: Any) -> None:
""" """
检查结果 检查结果
""" """
# 若asr和spk队列中都有数据则合并数据 # 若asr和spk队列中都有数据则合并数据
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize(): if (
self._subqueue_dict["asrend"].qsize()
& self._subqueue_dict["spkend"].qsize()
):
asr_data = self._subqueue_dict["asrend"].get() asr_data = self._subqueue_dict["asrend"].get()
spk_data = self._subqueue_dict["spkend"].get() spk_data = self._subqueue_dict["spkend"].get()
# 合并数据 # 合并数据
result = { result = {"asr_data": asr_data, "spk_data": spk_data}
"asr_data": asr_data,
"spk_data": spk_data
}
# 通知回调函数 # 通知回调函数
self._notify_callbacks(result) self._notify_callbacks(result)
@ -211,13 +220,13 @@ class ASRPipeline(PipelineBase):
logger.info("收到结束信号,管道准备停止") logger.info("收到结束信号,管道准备停止")
self._input_queue.task_done() # 标记结束信号已处理 self._input_queue.task_done() # 标记结束信号已处理
break break
# 处理数据 # 处理数据
self._process(data) self._process(data)
# 标记任务完成 # 标记任务完成
self._input_queue.task_done() self._input_queue.task_done()
except Empty: except Empty:
# 队列获取超时,继续等待 # 队列获取超时,继续等待
continue continue
@ -237,7 +246,7 @@ class ASRPipeline(PipelineBase):
""" """
# 子类实现具体的处理逻辑 # 子类实现具体的处理逻辑
self._subqueue_dict["original"].put(data) self._subqueue_dict["original"].put(data)
def stop(self) -> None: def stop(self) -> None:
""" """
停止管道 停止管道

View File

@ -8,11 +8,13 @@ import time
# 配置日志 # 配置日志
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PipelineBase(ABC): class PipelineBase(ABC):
""" """
管道基类 管道基类
定义了管道的基本接口和通用功能 定义了管道的基本接口和通用功能
""" """
def __init__(self, input_queue: Optional[Queue] = None): def __init__(self, input_queue: Optional[Queue] = None):
""" """
初始化管道 初始化管道
@ -93,7 +95,7 @@ class PipelineBase(ABC):
if self._thread and self._thread.is_alive(): if self._thread and self._thread.is_alive():
timeout = timeout if timeout is not None else self._stop_timeout timeout = timeout if timeout is not None else self._stop_timeout
self._thread.join(timeout=timeout) self._thread.join(timeout=timeout)
# 检查是否成功停止 # 检查是否成功停止
if self._thread.is_alive(): if self._thread.is_alive():
logger.warning(f"管道停止超时({timeout}秒),强制终止") logger.warning(f"管道停止超时({timeout}秒),强制终止")
@ -101,7 +103,7 @@ class PipelineBase(ABC):
else: else:
logger.info("管道已成功停止") logger.info("管道已成功停止")
return True return True
return True return True
def force_stop(self) -> None: def force_stop(self) -> None:
@ -115,14 +117,35 @@ class PipelineBase(ABC):
# 注意Python的线程无法被强制终止这里只是设置标志 # 注意Python的线程无法被强制终止这里只是设置标志
# 实际终止需要依赖操作系统的进程管理 # 实际终止需要依赖操作系统的进程管理
class PipelineFactory: class PipelineFactory:
""" """
管道工厂类 管道工厂类
用于创建管道实例 用于创建管道实例
""" """
@staticmethod
def create_pipeline(pipeline_name: str) -> Any: from src.pipeline.ASRpipeline import ASRPipeline
def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline:
"""
创建ASR管道实例
"""
from src.pipeline.ASRpipeline import ASRPipeline
pipeline = ASRPipeline()
pipeline.set_config(kwargs["config"])
pipeline.set_models(kwargs["models"])
pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"])
pipeline.add_callback(kwargs["callback"])
pipeline.bake()
return pipeline
@classmethod
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
""" """
创建管道实例 创建管道实例
""" """
pass if pipeline_name == "ASRpipeline":
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
else:
raise ValueError(f"不支持的管道类型: {pipeline_name}")

View File

@ -172,7 +172,7 @@ class STTRunner(RunnerBase):
logger.warning( logger.warning(
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理", "等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
self._stop_timeout, self._stop_timeout,
self._input_queue.qsize() self._input_queue.qsize(),
) )
success = False success = False
break break
@ -188,7 +188,7 @@ class STTRunner(RunnerBase):
"错误堆栈:\n%s", "错误堆栈:\n%s",
error_type, error_type,
error_msg, error_msg,
error_traceback error_traceback,
) )
success = False success = False
@ -198,7 +198,7 @@ class STTRunner(RunnerBase):
logger.warning( logger.warning(
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s", "部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(), self._input_queue.qsize(),
self._input_queue.empty() self._input_queue.empty(),
) )
return success return success

View File

@ -43,7 +43,7 @@ async def clear_websocket():
async def ws_serve(websocket, path): async def ws_serve(websocket, path):
""" """
WebSocket服务主函数处理客户端连接和消息 WebSocket服务主函数处理客户端连接和消息
参数: 参数:
websocket: WebSocket连接对象 websocket: WebSocket连接对象
path: 连接路径 path: 连接路径
@ -51,13 +51,13 @@ async def ws_serve(websocket, path):
frames = [] # 存储所有音频帧 frames = [] # 存储所有音频帧
frames_asr = [] # 存储用于离线ASR的音频帧 frames_asr = [] # 存储用于离线ASR的音频帧
frames_asr_online = [] # 存储用于在线ASR的音频帧 frames_asr_online = [] # 存储用于在线ASR的音频帧
global websocket_users global websocket_users
# await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端) # await clear_websocket() # 清理现有连接(目前注释掉,允许多客户端)
# 添加到用户集合 # 添加到用户集合
websocket_users.add(websocket) websocket_users.add(websocket)
# 初始化连接状态 # 初始化连接状态
websocket.status_dict_asr = {} websocket.status_dict_asr = {}
websocket.status_dict_asr_online = {"cache": {}, "is_final": False} websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
@ -66,15 +66,15 @@ async def ws_serve(websocket, path):
websocket.chunk_interval = 10 websocket.chunk_interval = 10
websocket.vad_pre_idx = 0 websocket.vad_pre_idx = 0
websocket.is_speaking = True # 默认用户正在说话 websocket.is_speaking = True # 默认用户正在说话
# 语音检测状态 # 语音检测状态
speech_start = False speech_start = False
speech_end_i = -1 speech_end_i = -1
# 初始化配置 # 初始化配置
websocket.wav_name = "microphone" websocket.wav_name = "microphone"
websocket.mode = "2pass" # 默认使用两阶段识别模式 websocket.mode = "2pass" # 默认使用两阶段识别模式
print("新用户已连接", flush=True) print("新用户已连接", flush=True)
try: try:
@ -84,11 +84,13 @@ async def ws_serve(websocket, path):
if isinstance(message, str): if isinstance(message, str):
try: try:
messagejson = json.loads(message) messagejson = json.loads(message)
# 更新各种配置参数 # 更新各种配置参数
if "is_speaking" in messagejson: if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"] websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking websocket.status_dict_asr_online["is_final"] = (
not websocket.is_speaking
)
if "chunk_interval" in messagejson: if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"] websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson: if "wav_name" in messagejson:
@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
chunk_size = messagejson["chunk_size"] chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str): if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",") chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size] websocket.status_dict_asr_online["chunk_size"] = [
int(x) for x in chunk_size
]
if "encoder_chunk_look_back" in messagejson: if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"] websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
messagejson["encoder_chunk_look_back"]
)
if "decoder_chunk_look_back" in messagejson: if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"] websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
messagejson["decoder_chunk_look_back"]
)
if "hotword" in messagejson: if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"] websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson: if "mode" in messagejson:
@ -111,11 +119,17 @@ async def ws_serve(websocket, path):
# 根据chunk_interval更新VAD的chunk_size # 根据chunk_interval更新VAD的chunk_size
websocket.status_dict_vad["chunk_size"] = int( websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
* 60
/ websocket.chunk_interval
) )
# 处理音频数据 # 处理音频数据
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str): if (
len(frames_asr_online) > 0
or len(frames_asr) >= 0
or not isinstance(message, str)
):
if not isinstance(message, str): # 二进制音频数据 if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区 # 添加到帧缓冲区
frames.append(message) frames.append(message)
@ -125,10 +139,12 @@ async def ws_serve(websocket, path):
# 处理在线ASR # 处理在线ASR
frames_asr_online.append(message) frames_asr_online.append(message)
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1 websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
# 达到chunk_interval或最终帧时处理在线ASR # 达到chunk_interval或最终帧时处理在线ASR
if (len(frames_asr_online) % websocket.chunk_interval == 0 or if (
websocket.status_dict_asr_online["is_final"]): len(frames_asr_online) % websocket.chunk_interval == 0
or websocket.status_dict_asr_online["is_final"]
):
if websocket.mode == "2pass" or websocket.mode == "online": if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online) audio_in = b"".join(frames_asr_online)
try: try:
@ -136,26 +152,32 @@ async def ws_serve(websocket, path):
except Exception as e: except Exception as e:
print(f"在线ASR处理错误: {e}") print(f"在线ASR处理错误: {e}")
frames_asr_online = [] frames_asr_online = []
# 如果检测到语音开始收集帧用于离线ASR # 如果检测到语音开始收集帧用于离线ASR
if speech_start: if speech_start:
frames_asr.append(message) frames_asr.append(message)
# VAD处理 - 语音活动检测 # VAD处理 - 语音活动检测
try: try:
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message) speech_start_i, speech_end_i = await asr_service.async_vad(
websocket, message
)
except Exception as e: except Exception as e:
print(f"VAD处理错误: {e}") print(f"VAD处理错误: {e}")
# 检测到语音开始 # 检测到语音开始
if speech_start_i != -1: if speech_start_i != -1:
speech_start = True speech_start = True
# 计算开始偏移并收集前面的帧 # 计算开始偏移并收集前面的帧
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms beg_bias = (
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames websocket.vad_pre_idx - speech_start_i
) // duration_ms
frames_pre = (
frames[-beg_bias:] if beg_bias < len(frames) else frames
)
frames_asr = [] frames_asr = []
frames_asr.extend(frames_pre) frames_asr.extend(frames_pre)
# 处理离线ASR (语音结束或用户停止说话) # 处理离线ASR (语音结束或用户停止说话)
if speech_end_i != -1 or not websocket.is_speaking: if speech_end_i != -1 or not websocket.is_speaking:
if websocket.mode == "2pass" or websocket.mode == "offline": if websocket.mode == "2pass" or websocket.mode == "offline":
@ -164,13 +186,13 @@ async def ws_serve(websocket, path):
await asr_service.async_asr(websocket, audio_in) await asr_service.async_asr(websocket, audio_in)
except Exception as e: except Exception as e:
print(f"离线ASR处理错误: {e}") print(f"离线ASR处理错误: {e}")
# 重置状态 # 重置状态
frames_asr = [] frames_asr = []
speech_start = False speech_start = False
frames_asr_online = [] frames_asr_online = []
websocket.status_dict_asr_online["cache"] = {} websocket.status_dict_asr_online["cache"] = {}
# 如果用户停止说话,完全重置 # 如果用户停止说话,完全重置
if not websocket.is_speaking: if not websocket.is_speaking:
websocket.vad_pre_idx = 0 websocket.vad_pre_idx = 0
@ -193,34 +215,34 @@ async def ws_serve(websocket, path):
def start_server(args, asr_service_instance): def start_server(args, asr_service_instance):
""" """
启动WebSocket服务器 启动WebSocket服务器
参数: 参数:
args: 命令行参数 args: 命令行参数
asr_service_instance: ASR服务实例 asr_service_instance: ASR服务实例
""" """
global asr_service global asr_service
asr_service = asr_service_instance asr_service = asr_service_instance
# 配置SSL (如果提供了证书) # 配置SSL (如果提供了证书)
if args.certfile and len(args.certfile) > 0: if args.certfile and len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile) ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
start_server = websockets.serve( start_server = websockets.serve(
ws_serve, args.host, args.port, ws_serve,
subprotocols=["binary"], args.host,
ping_interval=None, args.port,
ssl=ssl_context subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context,
) )
else: else:
start_server = websockets.serve( start_server = websockets.serve(
ws_serve, args.host, args.port, ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
subprotocols=["binary"],
ping_interval=None
) )
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}") print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")
# 启动事件循环 # 启动事件循环
asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever() asyncio.get_event_loop().run_forever()
@ -229,14 +251,14 @@ def start_server(args, asr_service_instance):
if __name__ == "__main__": if __name__ == "__main__":
# 解析命令行参数 # 解析命令行参数
args = parse_args() args = parse_args()
# 加载模型 # 加载模型
print("正在加载模型...") print("正在加载模型...")
models = load_models(args) models = load_models(args)
print("模型加载完成!当前仅支持单个客户端同时连接!") print("模型加载完成!当前仅支持单个客户端同时连接!")
# 创建ASR服务 # 创建ASR服务
asr_service = ASRService(models) asr_service = ASRService(models)
# 启动服务器 # 启动服务器
start_server(args, asr_service) start_server(args, asr_service)

View File

@ -9,11 +9,11 @@ import json
class ASRService: class ASRService:
"""ASR服务类封装各种语音识别相关功能""" """ASR服务类封装各种语音识别相关功能"""
def __init__(self, models): def __init__(self, models):
""" """
初始化ASR服务 初始化ASR服务
参数: 参数:
models: 包含各种预加载模型的字典 models: 包含各种预加载模型的字典
""" """
@ -21,42 +21,41 @@ class ASRService:
self.model_asr_streaming = models["asr_streaming"] self.model_asr_streaming = models["asr_streaming"]
self.model_vad = models["vad"] self.model_vad = models["vad"]
self.model_punc = models["punc"] self.model_punc = models["punc"]
async def async_vad(self, websocket, audio_in): async def async_vad(self, websocket, audio_in):
""" """
语音活动检测 语音活动检测
参数: 参数:
websocket: WebSocket连接 websocket: WebSocket连接
audio_in: 二进制音频数据 audio_in: 二进制音频数据
返回: 返回:
tuple: (speech_start, speech_end) 语音开始和结束位置 tuple: (speech_start, speech_end) 语音开始和结束位置
""" """
# 使用VAD模型分析音频段 # 使用VAD模型分析音频段
segments_result = self.model_vad.generate( segments_result = self.model_vad.generate(
input=audio_in, input=audio_in, **websocket.status_dict_vad
**websocket.status_dict_vad
)[0]["value"] )[0]["value"]
speech_start = -1 speech_start = -1
speech_end = -1 speech_end = -1
# 解析VAD结果 # 解析VAD结果
if len(segments_result) == 0 or len(segments_result) > 1: if len(segments_result) == 0 or len(segments_result) > 1:
return speech_start, speech_end return speech_start, speech_end
if segments_result[0][0] != -1: if segments_result[0][0] != -1:
speech_start = segments_result[0][0] speech_start = segments_result[0][0]
if segments_result[0][1] != -1: if segments_result[0][1] != -1:
speech_end = segments_result[0][1] speech_end = segments_result[0][1]
return speech_start, speech_end return speech_start, speech_end
async def async_asr(self, websocket, audio_in): async def async_asr(self, websocket, audio_in):
""" """
离线ASR处理 离线ASR处理
参数: 参数:
websocket: WebSocket连接 websocket: WebSocket连接
audio_in: 二进制音频数据 audio_in: 二进制音频数据
@ -64,42 +63,44 @@ class ASRService:
if len(audio_in) > 0: if len(audio_in) > 0:
# 使用离线ASR模型处理音频 # 使用离线ASR模型处理音频
rec_result = self.model_asr.generate( rec_result = self.model_asr.generate(
input=audio_in, input=audio_in, **websocket.status_dict_asr
**websocket.status_dict_asr
)[0] )[0]
# 如果有标点符号模型且识别出文本,则添加标点 # 如果有标点符号模型且识别出文本,则添加标点
if self.model_punc is not None and len(rec_result["text"]) > 0: if self.model_punc is not None and len(rec_result["text"]) > 0:
rec_result = self.model_punc.generate( rec_result = self.model_punc.generate(
input=rec_result["text"], input=rec_result["text"], **websocket.status_dict_punc
**websocket.status_dict_punc
)[0] )[0]
# 如果识别出文本,发送到客户端 # 如果识别出文本,发送到客户端
if len(rec_result["text"]) > 0: if len(rec_result["text"]) > 0:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({ message = json.dumps(
"mode": mode, {
"text": rec_result["text"], "mode": mode,
"wav_name": websocket.wav_name, "text": rec_result["text"],
"is_final": websocket.is_speaking, "wav_name": websocket.wav_name,
}) "is_final": websocket.is_speaking,
}
)
await websocket.send(message) await websocket.send(message)
else: else:
# 如果没有音频数据,发送空文本 # 如果没有音频数据,发送空文本
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({ message = json.dumps(
"mode": mode, {
"text": "", "mode": mode,
"wav_name": websocket.wav_name, "text": "",
"is_final": websocket.is_speaking, "wav_name": websocket.wav_name,
}) "is_final": websocket.is_speaking,
}
)
await websocket.send(message) await websocket.send(message)
async def async_asr_online(self, websocket, audio_in): async def async_asr_online(self, websocket, audio_in):
""" """
在线ASR处理 在线ASR处理
参数: 参数:
websocket: WebSocket连接 websocket: WebSocket连接
audio_in: 二进制音频数据 audio_in: 二进制音频数据
@ -107,21 +108,24 @@ class ASRService:
if len(audio_in) > 0: if len(audio_in) > 0:
# 使用在线ASR模型处理音频 # 使用在线ASR模型处理音频
rec_result = self.model_asr_streaming.generate( rec_result = self.model_asr_streaming.generate(
input=audio_in, input=audio_in, **websocket.status_dict_asr_online
**websocket.status_dict_asr_online
)[0] )[0]
# 在2pass模式下如果是最终帧则跳过(留给离线ASR处理) # 在2pass模式下如果是最终帧则跳过(留给离线ASR处理)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False): if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
"is_final", False
):
return return
# 如果识别出文本,发送到客户端 # 如果识别出文本,发送到客户端
if len(rec_result["text"]): if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({ message = json.dumps(
"mode": mode, {
"text": rec_result["text"], "mode": mode,
"wav_name": websocket.wav_name, "text": rec_result["text"],
"is_final": websocket.is_speaking, "wav_name": websocket.wav_name,
}) "is_final": websocket.is_speaking,
await websocket.send(message) }
)
await websocket.send(message)

View File

@ -1,3 +1,3 @@
from .logger import get_module_logger, setup_root_logger from .logger import get_module_logger, setup_root_logger
__all__ = ["get_module_logger", "setup_root_logger"] __all__ = ["get_module_logger", "setup_root_logger"]

View File

@ -1,10 +1,12 @@
""" """
处理各类音频数据与bytes的转换 处理各类音频数据与bytes的转换
""" """
import wave import wave
from pydub import AudioSegment from pydub import AudioSegment
import io import io
def wav_to_bytes(wav_path: str) -> bytes: def wav_to_bytes(wav_path: str) -> bytes:
""" """
将WAV文件读取为bytes数据 将WAV文件读取为bytes数据
@ -14,24 +16,26 @@ def wav_to_bytes(wav_path: str) -> bytes:
返回: 返回:
bytes: WAV文件的原始字节数据 bytes: WAV文件的原始字节数据
异常: 异常:
FileNotFoundError: 如果WAV文件不存在 FileNotFoundError: 如果WAV文件不存在
wave.Error: 如果文件不是有效的WAV文件 wave.Error: 如果文件不是有效的WAV文件
""" """
try: try:
with wave.open(wav_path, 'rb') as wf: with wave.open(wav_path, "rb") as wf:
# 读取所有帧 # 读取所有帧
frames = wf.readframes(wf.getnframes()) frames = wf.readframes(wf.getnframes())
return frames return frames
except FileNotFoundError: except FileNotFoundError:
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出 # 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
raise FileNotFoundError(f"错误未找到WAV文件 '{wav_path}'") raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
except wave.Error as e: except wave.Error as e:
raise wave.Error(f"错误打开或读取WAV文件 '{wav_path}' 失败 - {e}") raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int): def bytes_to_wav(
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
):
""" """
将bytes数据写入为WAV文件 将bytes数据写入为WAV文件
@ -41,22 +45,23 @@ def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: in
nchannels (int): 声道数 (例如 1 for mono, 2 for stereo) nchannels (int): 声道数 (例如 1 for mono, 2 for stereo)
sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio) sampwidth (int): 采样宽度 (字节数, 例如 2 for 16-bit audio)
framerate (int): 采样率 (例如 44100, 16000) framerate (int): 采样率 (例如 44100, 16000)
异常: 异常:
wave.Error: 如果写入WAV文件失败 wave.Error: 如果写入WAV文件失败
""" """
try: try:
with wave.open(wav_path, 'wb') as wf: with wave.open(wav_path, "wb") as wf:
wf.setnchannels(nchannels) wf.setnchannels(nchannels)
wf.setsampwidth(sampwidth) wf.setsampwidth(sampwidth)
wf.setframerate(framerate) wf.setframerate(framerate)
wf.writeframes(bytes_data) wf.writeframes(bytes_data)
except wave.Error as e: except wave.Error as e:
raise wave.Error(f"错误写入WAV文件 '{wav_path}' 失败 - {e}") raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
except Exception as e: except Exception as e:
# 捕获其他可能的写入错误 # 捕获其他可能的写入错误
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}") raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
def mp3_to_bytes(mp3_path: str) -> bytes: def mp3_to_bytes(mp3_path: str) -> bytes:
""" """
将MP3文件转换为bytes数据 (原始PCM数据) 将MP3文件转换为bytes数据 (原始PCM数据)
@ -66,7 +71,7 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
返回: 返回:
bytes: MP3文件解码后的原始PCM字节数据 bytes: MP3文件解码后的原始PCM字节数据
异常: 异常:
FileNotFoundError: 如果MP3文件不存在 FileNotFoundError: 如果MP3文件不存在
pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码 pydub.exceptions.CouldntDecodeError: 如果MP3文件无法解码
@ -76,12 +81,19 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
# 获取原始PCM数据 # 获取原始PCM数据
return audio.raw_data return audio.raw_data
except FileNotFoundError: except FileNotFoundError:
raise FileNotFoundError(f"错误未找到MP3文件 '{mp3_path}'") raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
except Exception as e: # pydub 可能抛出多种解码相关的错误 except Exception as e: # pydub 可能抛出多种解码相关的错误
raise Exception(f"错误处理MP3文件 '{mp3_path}' 失败 - {e}") raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"): def bytes_to_mp3(
bytes_data: bytes,
mp3_path: str,
frame_rate: int,
channels: int,
sample_width: int,
bitrate: str = "192k",
):
""" """
将原始PCM bytes数据转换为MP3文件 将原始PCM bytes数据转换为MP3文件
@ -102,9 +114,9 @@ def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: in
data=bytes_data, data=bytes_data,
sample_width=sample_width, sample_width=sample_width,
frame_rate=frame_rate, frame_rate=frame_rate,
channels=channels channels=channels,
) )
# 导出为MP3 # 导出为MP3
audio.export(mp3_path, format="mp3", bitrate=bitrate) audio.export(mp3_path, format="mp3", bitrate=bitrate)
except Exception as e: except Exception as e:
raise Exception(f"错误转换或写入MP3文件 '{mp3_path}' 失败 - {e}") raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")

View File

@ -3,6 +3,7 @@ import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
def setup_logger( def setup_logger(
name: str = None, name: str = None,
level: str = "INFO", level: str = "INFO",
@ -12,80 +13,79 @@ def setup_logger(
) -> logging.Logger: ) -> logging.Logger:
""" """
设置并返回一个配置好的logger实例 设置并返回一个配置好的logger实例
Args: Args:
name: logger的名称默认为None使用root logger name: logger的名称默认为None使用root logger
level: 日志级别默认为"INFO" level: 日志级别默认为"INFO"
log_file: 日志文件路径默认为None仅控制台输出 log_file: 日志文件路径默认为None仅控制台输出
log_format: 日志格式 log_format: 日志格式
date_format: 日期格式 date_format: 日期格式
Returns: Returns:
logging.Logger: 配置好的logger实例 logging.Logger: 配置好的logger实例
""" """
# 获取logger实例 # 获取logger实例
logger = logging.getLogger(name) logger = logging.getLogger(name)
# 设置日志级别 # 设置日志级别
level = getattr(logging, level.upper()) level = getattr(logging, level.upper())
logger.setLevel(level) logger.setLevel(level)
print(f"添加处理器 {name} {log_file} {log_format} {date_format}") print(f"添加处理器 {name} {log_file} {log_format} {date_format}")
# 创建格式器 # 创建格式器
formatter = logging.Formatter(log_format, date_format) formatter = logging.Formatter(log_format, date_format)
# 添加控制台处理器 # 添加控制台处理器
console_handler = logging.StreamHandler(sys.stdout) console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
logger.addHandler(console_handler) logger.addHandler(console_handler)
# 如果指定了日志文件,添加文件处理器 # 如果指定了日志文件,添加文件处理器
if log_file: if log_file:
# 确保日志目录存在 # 确保日志目录存在
log_path = Path(log_file) log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True) log_path.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
logger.addHandler(file_handler) logger.addHandler(file_handler)
# 注意:移除了 propagate = False允许日志传递 # 注意:移除了 propagate = False允许日志传递
return logger return logger
def setup_root_logger(
level: str = "INFO", def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
log_file: Optional[str] = None
) -> None:
""" """
配置根日志器 配置根日志器
Args: Args:
level: 日志级别 level: 日志级别
log_file: 日志文件路径 log_file: 日志文件路径
""" """
setup_logger(None, level, log_file) setup_logger(None, level, log_file)
def get_module_logger( def get_module_logger(
module_name: str, module_name: str,
level: Optional[str] = None, # 改为可选参数 level: Optional[str] = None, # 改为可选参数
log_file: Optional[str] = None # 一般不需要单独指定 log_file: Optional[str] = None, # 一般不需要单独指定
) -> logging.Logger: ) -> logging.Logger:
""" """
获取模块级别的logger 获取模块级别的logger
Args: Args:
module_name: 模块名称通常传入__name__ module_name: 模块名称通常传入__name__
level: 可选的日志级别如果不指定则继承父级配置 level: 可选的日志级别如果不指定则继承父级配置
log_file: 可选的日志文件路径一般不需要指定 log_file: 可选的日志文件路径一般不需要指定
""" """
logger = logging.getLogger(module_name) logger = logging.getLogger(module_name)
# 只有显式指定了level才设置 # 只有显式指定了level才设置
if level: if level:
logger.setLevel(getattr(logging, level.upper())) logger.setLevel(getattr(logging, level.upper()))
# 只有显式指定了log_file才添加文件处理器 # 只有显式指定了log_file才添加文件处理器
if log_file: if log_file:
setup_logger(module_name, level or "INFO", log_file) setup_logger(module_name, level or "INFO", log_file)
return logger return logger

View File

@ -1,10 +1,15 @@
from tests.functor.vad_test import test_vad_functor """
测试主函数
请在tests目录下创建测试文件, 并在此文件中调用
"""
from tests.pipeline.asr_test import test_asr_pipeline from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger from src.utils.logger import get_module_logger, setup_root_logger
setup_root_logger(level="INFO", log_file="logs/test_main.log") setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
# from tests.functor.vad_test import test_vad_functor
# logger.info("开始测试VAD函数器") # logger.info("开始测试VAD函数器")
# test_vad_functor() # test_vad_functor()

View File

@ -1 +1 @@
"""FunASR WebSocket服务测试模块""" """FunASR WebSocket服务测试模块"""

View File

@ -2,6 +2,7 @@
Functor测试 Functor测试
VAD测试 VAD测试
""" """
from src.functor.vad_functor import VADFunctor from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor from src.functor.spk_functor import SPKFunctor
@ -21,6 +22,7 @@ logger = get_module_logger(__name__)
model_loader = ModelLoader() model_loader = ModelLoader()
def test_vad_functor(): def test_vad_functor():
# 加载模型 # 加载模型
args = { args = {
@ -38,9 +40,9 @@ def test_vad_functor():
chunk_stride=1600, chunk_stride=1600,
sample_rate=sample_rate, sample_rate=sample_rate,
sample_width=16, sample_width=16,
channels=1 channels=1,
) )
chunk_stride = int(audio_config.chunk_size*sample_rate/1000) chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride audio_config.chunk_stride = chunk_stride
# 创建输入队列 # 创建输入队列
input_queue = Queue() input_queue = Queue()
@ -62,9 +64,7 @@ def test_vad_functor():
vad_functor.add_callback(lambda x: vad2asr_queue.put(x)) vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
vad_functor.add_callback(lambda x: vad2spk_queue.put(x)) vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
# 设置模型 # 设置模型
vad_functor.set_model({ vad_functor.set_model({"vad": model_loader.models["vad"]})
'vad': model_loader.models['vad']
})
# 启动VAD函数器 # 启动VAD函数器
vad_functor.run() vad_functor.run()
@ -77,9 +77,7 @@ def test_vad_functor():
# 设置回调函数 # 设置回调函数
asr_functor.add_callback(lambda x: print(f"asr callback: {x}")) asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
# 设置模型 # 设置模型
asr_functor.set_model({ asr_functor.set_model({"asr": model_loader.models["asr"]})
'asr': model_loader.models['asr']
})
# 启动ASR函数器 # 启动ASR函数器
asr_functor.run() asr_functor.run()
@ -92,23 +90,25 @@ def test_vad_functor():
# 设置回调函数 # 设置回调函数
spk_functor.add_callback(lambda x: print(f"spk callback: {x}")) spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
# 设置模型 # 设置模型
spk_functor.set_model({ spk_functor.set_model(
# 'spk': model_loader.models['spk'] {
'spk': 'fake_spk' # 'spk': model_loader.models['spk']
}) "spk": "fake_spk"
}
)
# 启动SPK函数器 # 启动SPK函数器
spk_functor.run() spk_functor.run()
f_binary = f_data f_binary = f_data
audio_clip_len = 200 audio_clip_len = 200
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}") print(
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
)
for i in range(0, len(f_binary), audio_clip_len): for i in range(0, len(f_binary), audio_clip_len):
binary_data = f_binary[i:i+audio_clip_len] binary_data = f_binary[i : i + audio_clip_len]
input_queue.put(binary_data) input_queue.put(binary_data)
# 等待VAD函数器结束 # 等待VAD函数器结束
vad_functor.stop() vad_functor.stop()
print("[vad_test] VAD函数器结束") print("[vad_test] VAD函数器结束")
@ -119,4 +119,6 @@ def test_vad_functor():
if OVERWATCH: if OVERWATCH:
for index in range(len(audio_binary_data_list)): for index in range(len(audio_binary_data_list)):
save_path = f"tests/vad_test_output_{index}.wav" save_path = f"tests/vad_test_output_{index}.wav"
soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate) soundfile.write(
save_path, audio_binary_data_list[index].binary_data, sample_rate
)

View File

@ -1,10 +1,21 @@
"""
模型使用测试
此处主要用于各类调用模型的处理数据与输出格式
请在主目录下test_main.py中调用
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配
"""
from funasr import AutoModel from funasr import AutoModel
from typing import List, Dict, Any from typing import List, Dict, Any
from src.models import VADResponse from src.models import VADResponse
import time import time
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]: def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
chunk_size = 100 # ms """
在线VAD模型使用
"""
chunk_size = 100 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True) model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
vad_result = VADResponse() vad_result = VADResponse()
@ -16,12 +27,14 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
chunk_stride = int(chunk_size * sample_rate / 1000) chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {} cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1) total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num): for i in range(total_chunk_num):
time.sleep(0.1) time.sleep(0.1)
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1 is_final = i == total_chunk_num - 1
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size) res = model.generate(
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
)
if len(res[0]["value"]): if len(res[0]["value"]):
vad_result += VADResponse.from_raw(res) vad_result += VADResponse.from_raw(res)
for item in res[0]["value"]: for item in res[0]["value"]:
@ -32,44 +45,64 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
# print(item) # print(item)
return vad_result return vad_result
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]: def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
测试LogicTrager
在Rebuild版本后LogicTrager中已弃用
"""
from src.logic_trager import LogicTrager from src.logic_trager import LogicTrager
import soundfile import soundfile
from src.config import parse_args from src.config import parse_args
args = parse_args() args = parse_args()
# from src.functor.model_loader import load_models # from src.functor.model_loader import load_models
# models = load_models(args) # models = load_models(args)
from src.model_loader import ModelLoader from src.model_loader import ModelLoader
models = ModelLoader(args) models = ModelLoader(args)
chunk_size = 200 # ms chunk_size = 200 # ms
from src.models import AudioBinary_Config from src.models import AudioBinary_Config
import soundfile import soundfile
speech, sample_rate = soundfile.read(file_path) speech, sample_rate = soundfile.read(file_path)
chunk_stride = int(chunk_size * sample_rate / 1000) chunk_stride = int(chunk_size * sample_rate / 1000)
audio_config = AudioBinary_Config(sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size) audio_config = AudioBinary_Config(
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
)
logic_trager = LogicTrager(models=models, audio_config=audio_config) logic_trager = LogicTrager(models=models, audio_config=audio_config)
for i in range(len(speech)//chunk_stride+1): for i in range(len(speech) // chunk_stride + 1):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
logic_trager.push_binary_data(speech_chunk) logic_trager.push_binary_data(speech_chunk)
# for item in items: # for item in items:
# print(item) # print(item)
return None return None
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]: def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
"""
ASR模型使用
离线ASR模型使用
"""
from funasr import AutoModel from funasr import AutoModel
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
vad_model="fsmn-vad", vad_model_revision="v2.0.4", model = AutoModel(
# punc_model="ct-punc-c", punc_model_revision="v2.0.4", model="paraformer-zh",
spk_model="cam++", spk_model_revision="v2.0.2", model_revision="v2.0.4",
spk_mode="vad_segment", vad_model="fsmn-vad",
auto_update=False, vad_model_revision="v2.0.4",
) # punc_model="ct-punc-c", punc_model_revision="v2.0.4",
spk_model="cam++",
spk_model_revision="v2.0.2",
spk_mode="vad_segment",
auto_update=False,
)
import soundfile import soundfile
@ -80,7 +113,9 @@ def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
result = model.generate(speech) result = model.generate(speech)
return result return result
if __name__ == "__main__":
# vad_result = vad_model_use_online("tests/vad_example.wav") # if __name__ == "__main__":
vad_result = vad_model_use_online_logic("tests/vad_example.wav") # 请在主目录下调用test_main.py文件进行测试
# print(vad_result) # vad_result = vad_model_use_online("tests/vad_example.wav")
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# print(vad_result)

View File

@ -2,7 +2,9 @@
Pipeline测试 Pipeline测试
VAD+ASR+SPK(FAKE) VAD+ASR+SPK(FAKE)
""" """
from src.pipeline.ASRpipeline import ASRPipeline from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader from src.model_loader import ModelLoader
from queue import Queue from queue import Queue
@ -18,6 +20,7 @@ OVAERWATCH = False
model_loader = ModelLoader() model_loader = ModelLoader()
def test_asr_pipeline(): def test_asr_pipeline():
# 加载模型 # 加载模型
args = { args = {
@ -36,9 +39,9 @@ def test_asr_pipeline():
chunk_stride=1600, chunk_stride=1600,
sample_rate=sample_rate, sample_rate=sample_rate,
sample_width=16, sample_width=16,
channels=1 channels=1,
) )
chunk_stride = int(audio_config.chunk_size*sample_rate/1000) chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride audio_config.chunk_stride = chunk_stride
# 创建参数Dict # 创建参数Dict
@ -52,29 +55,39 @@ def test_asr_pipeline():
input_queue = Queue() input_queue = Queue()
# 创建Pipeline # 创建Pipeline
asr_pipeline = ASRPipeline() # asr_pipeline = ASRPipeline()
asr_pipeline.set_models(models) # asr_pipeline.set_models(models)
asr_pipeline.set_config(config) # asr_pipeline.set_config(config)
asr_pipeline.set_audio_binary(audio_binary_data_list) # asr_pipeline.set_audio_binary(audio_binary_data_list)
asr_pipeline.set_input_queue(input_queue) # asr_pipeline.set_input_queue(input_queue)
asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}")) # asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
asr_pipeline.bake() # asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name = "ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)
# 运行Pipeline # 运行Pipeline
asr_instance = asr_pipeline.run() asr_instance = asr_pipeline.run()
audio_clip_len = 200 audio_clip_len = 200
print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}") print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len): for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i:i+audio_clip_len]) input_queue.put(audio_data[i : i + audio_clip_len])
# time.sleep(10) # time.sleep(10)
# input_queue.put(None) # input_queue.put(None)
# 等待Pipeline结束 # 等待Pipeline结束
# asr_instance.join() # asr_instance.join()
time.sleep(5) time.sleep(5)
asr_pipeline.stop() asr_pipeline.stop()
# asr_pipeline.stop() # asr_pipeline.stop()

View File

@ -10,23 +10,23 @@ import os
from unittest.mock import patch from unittest.mock import patch
# 将src目录添加到路径 # 将src目录添加到路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.config import parse_args from src.config import parse_args
def test_default_args(): def test_default_args():
"""测试默认参数值""" """测试默认参数值"""
with patch('sys.argv', ['script.py']): with patch("sys.argv", ["script.py"]):
args = parse_args() args = parse_args()
# 检查服务器参数 # 检查服务器参数
assert args.host == "0.0.0.0" assert args.host == "0.0.0.0"
assert args.port == 10095 assert args.port == 10095
# 检查SSL参数 # 检查SSL参数
assert args.certfile == "" assert args.certfile == ""
assert args.keyfile == "" assert args.keyfile == ""
# 检查模型参数 # 检查模型参数
assert "paraformer" in args.asr_model assert "paraformer" in args.asr_model
assert args.asr_model_revision == "v2.0.4" assert args.asr_model_revision == "v2.0.4"
@ -36,7 +36,7 @@ def test_default_args():
assert args.vad_model_revision == "v2.0.4" assert args.vad_model_revision == "v2.0.4"
assert "punc" in args.punc_model assert "punc" in args.punc_model
assert args.punc_model_revision == "v2.0.4" assert args.punc_model_revision == "v2.0.4"
# 检查硬件配置 # 检查硬件配置
assert args.ngpu == 1 assert args.ngpu == 1
assert args.device == "cuda" assert args.device == "cuda"
@ -46,19 +46,26 @@ def test_default_args():
def test_custom_args(): def test_custom_args():
"""测试自定义参数值""" """测试自定义参数值"""
test_args = [ test_args = [
'script.py', "script.py",
'--host', 'localhost', "--host",
'--port', '8080', "localhost",
'--certfile', 'cert.pem', "--port",
'--keyfile', 'key.pem', "8080",
'--asr_model', 'custom_model', "--certfile",
'--ngpu', '0', "cert.pem",
'--device', 'cpu' "--keyfile",
"key.pem",
"--asr_model",
"custom_model",
"--ngpu",
"0",
"--device",
"cpu",
] ]
with patch('sys.argv', test_args): with patch("sys.argv", test_args):
args = parse_args() args = parse_args()
# 检查自定义参数 # 检查自定义参数
assert args.host == "localhost" assert args.host == "localhost"
assert args.port == 8080 assert args.port == 8080
@ -66,4 +73,4 @@ def test_custom_args():
assert args.keyfile == "key.pem" assert args.keyfile == "key.pem"
assert args.asr_model == "custom_model" assert args.asr_model == "custom_model"
assert args.ngpu == 0 assert args.ngpu == 0
assert args.device == "cpu" assert args.device == "cpu"