[代码结构] 使用 `black .` 对所有文件进行格式调整;以格式化为主,但同时包含少量功能性改动(PipelineFactory.create_pipeline 的实现、音频工具中的错误提示文案、测试入口的调整)。

This commit is contained in:
Ziyang.Zhang 2025-06-12 15:49:43 +08:00
parent 5b94c40016
commit 5a820b49e4
28 changed files with 543 additions and 421 deletions

View File

@ -1,11 +1,7 @@
from funasr import AutoModel
chunk_size = 200 # ms
model = AutoModel(
model="fsmn-vad",
model_revision="v2.0.4",
disable_update=True
)
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
import soundfile
@ -23,7 +19,7 @@ for i in range(total_chunk_num):
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
disable_pbar=True
disable_pbar=True,
)
if len(res[0]["value"]):
print(res)

View File

@ -46,7 +46,7 @@ class AudioBinary:
else:
raise ValueError("参数类型错误")
def add_slice_listener(self, slice_listener: callable):
def add_slice_listener(self, slice_listener: callable) -> None:
"""
添加切片监听器
参数:
@ -128,7 +128,7 @@ class AudioChunk:
此类仅用于AudioBinary与Funtor的交互, 不负责其它逻辑
"""
_instance = None
_instance: Optional[AudioChunk] = None
def __new__(cls, *args, **kwargs):
"""
@ -138,7 +138,7 @@ class AudioChunk:
cls._instance = super(AudioChunk, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
def __init__(self) -> None:
"""
初始化AudioChunk实例
"""

View File

@ -21,41 +21,23 @@ def parse_args():
"--host",
type=str,
default="0.0.0.0",
help="服务器主机地址例如localhost, 0.0.0.0"
)
parser.add_argument(
"--port",
type=int,
default=10095,
help="WebSocket服务器端口"
help="服务器主机地址例如localhost, 0.0.0.0",
)
parser.add_argument("--port", type=int, default=10095, help="WebSocket服务器端口")
# SSL配置
parser.add_argument(
"--certfile",
type=str,
default="",
help="SSL证书文件路径"
)
parser.add_argument(
"--keyfile",
type=str,
default="",
help="SSL密钥文件路径"
)
parser.add_argument("--certfile", type=str, default="", help="SSL证书文件路径")
parser.add_argument("--keyfile", type=str, default="", help="SSL密钥文件路径")
# ASR模型配置
parser.add_argument(
"--asr_model",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="离线ASR模型从ModelScope获取"
help="离线ASR模型从ModelScope获取",
)
parser.add_argument(
"--asr_model_revision",
type=str,
default="v2.0.4",
help="离线ASR模型版本"
"--asr_model_revision", type=str, default="v2.0.4", help="离线ASR模型版本"
)
# 在线ASR模型配置
@ -63,13 +45,13 @@ def parse_args():
"--asr_model_online",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="在线ASR模型从ModelScope获取"
help="在线ASR模型从ModelScope获取",
)
parser.add_argument(
"--asr_model_online_revision",
type=str,
default="v2.0.4",
help="在线ASR模型版本"
help="在线ASR模型版本",
)
# VAD模型配置
@ -77,13 +59,10 @@ def parse_args():
"--vad_model",
type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="VAD语音活动检测模型从ModelScope获取"
help="VAD语音活动检测模型从ModelScope获取",
)
parser.add_argument(
"--vad_model_revision",
type=str,
default="v2.0.4",
help="VAD模型版本"
"--vad_model_revision", type=str, default="v2.0.4", help="VAD模型版本"
)
# 标点符号模型配置
@ -91,34 +70,18 @@ def parse_args():
"--punc_model",
type=str,
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="标点符号模型从ModelScope获取"
help="标点符号模型从ModelScope获取",
)
parser.add_argument(
"--punc_model_revision",
type=str,
default="v2.0.4",
help="标点符号模型版本"
"--punc_model_revision", type=str, default="v2.0.4", help="标点符号模型版本"
)
# 硬件配置
parser.add_argument("--ngpu", type=int, default=1, help="GPU数量0表示仅使用CPU")
parser.add_argument(
"--ngpu",
type=int,
default=1,
help="GPU数量0表示仅使用CPU"
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="设备类型cuda或cpu"
)
parser.add_argument(
"--ncpu",
type=int,
default=4,
help="CPU核心数"
"--device", type=str, default="cuda", help="设备类型cuda或cpu"
)
parser.add_argument("--ncpu", type=int, default=4, help="CPU核心数")
return parser.parse_args()

View File

@ -2,6 +2,7 @@
ASRFunctor
负责对音频片段进行ASR处理, 以ASR_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_data_list, AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class ASRFunctor(BaseFunctor):
"""
ASRFunctor
@ -83,7 +85,7 @@ class ASRFunctor(BaseFunctor):
"""
回调函数
"""
text = result[0]['text'].replace(" ", "")
text = result[0]["text"].replace(" ", "")
for callback in self._callback:
callback(text)
@ -157,5 +159,3 @@ class ASRFunctor(BaseFunctor):
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

View File

@ -11,11 +11,13 @@ Functor基础模块
BaseFunctor: Functor抽象类
FunctorFactory: Functor工厂类
"""
from abc import ABC, abstractmethod
from typing import Callable, List
from queue import Queue
import threading
class BaseFunctor(ABC):
"""
Functor抽象类
@ -27,9 +29,7 @@ class BaseFunctor(ABC):
_model (dict): 存储模型相关的配置和实例
"""
def __init__(
self
):
def __init__(self):
"""
初始化函数器
@ -111,7 +111,6 @@ class BaseFunctor(ABC):
"""
class FunctorFactory:
"""
Functor工厂类
@ -151,10 +150,9 @@ class FunctorFactory:
创建VAD Functor实例
"""
from src.functor.vad_functor import VADFunctor
audio_config = config["audio_config"]
model = {
"vad": models["vad"]
}
model = {"vad": models["vad"]}
vad_functor = VADFunctor()
vad_functor.set_audio_config(audio_config)
@ -167,10 +165,9 @@ class FunctorFactory:
创建ASR Functor实例
"""
from src.functor.asr_functor import ASRFunctor
audio_config = config["audio_config"]
model = {
"asr": models["asr"]
}
model = {"asr": models["asr"]}
asr_functor = ASRFunctor()
asr_functor.set_audio_config(audio_config)
@ -183,10 +180,9 @@ class FunctorFactory:
创建SPK Functor实例
"""
from src.functor.spk_functor import SPKFunctor
audio_config = config["audio_config"]
model = {
"spk": models["spk"]
}
model = {"spk": models["spk"]}
spk_functor = SPKFunctor()
spk_functor.set_audio_config(audio_config)

View File

@ -2,6 +2,7 @@
SpkFunctor
负责对音频片段进行SPK说话人识别处理, 以SPK_Result进行callback
"""
from src.functor.base import BaseFunctor
from src.models import AudioBinary_Config, VAD_Functor_result
from typing import Callable, List
@ -13,6 +14,7 @@ from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
class SPKFunctor(BaseFunctor):
"""
SPKFunctor
@ -33,7 +35,6 @@ class SPKFunctor(BaseFunctor):
self._input_queue: Queue = None # 输入队列
self._audio_config: AudioBinary_Config = None # 音频配置
def reset_cache(self) -> None:
"""
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
@ -83,7 +84,7 @@ class SPKFunctor(BaseFunctor):
# input=binary_data,
# chunk_size=self._audio_config.chunk_size,
# )
result = [{'result': "spk1", 'score': {"spk1": 0.9, "spk2": 0.3}}]
result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
self._do_callback(result)
def _run(self) -> None:
@ -142,5 +143,3 @@ class SPKFunctor(BaseFunctor):
with self._status_lock:
self._is_running = False
return not self._thread.is_alive()

View File

@ -2,6 +2,7 @@
VADFunctor
负责对音频片段进行VAD处理, 以VAD_Result进行callback
"""
import threading
from queue import Empty, Queue
from typing import List, Any, Callable
@ -105,9 +106,7 @@ class VADFunctor(BaseFunctor):
self._callback = []
self._callback.append(callback)
def _do_callback(
self, result: List[List[int]]
) -> None:
def _do_callback(self, result: List[List[int]]) -> None:
"""
回调函数
VADFunctor包装结果, 存储到AudioBinary中, 并向队列中添加AudioBinary_Slice

View File

@ -6,9 +6,11 @@
from src.utils.logger import get_module_logger
from typing import Any, Dict, Type, Callable
# 配置日志
logger = get_module_logger(__name__, level="INFO")
class AutoAfterMeta(type):
"""
自动调用__after__函数的元类
@ -21,7 +23,7 @@ class AutoAfterMeta(type):
# 遍历所有属性
for attr_name, attr_value in attrs.items():
# 如果是函数且不是以_开头
if callable(attr_value) and not attr_name.startswith('__'):
if callable(attr_value) and not attr_name.startswith("__"):
# 获取原函数
original_func = attr_value
@ -45,6 +47,7 @@ class AutoAfterMeta(type):
logger.error(f"调用{after_func_name}时出错: {e}")
return result
return wrapper
# 替换原函数
@ -68,6 +71,7 @@ class AutoAfterMeta(type):
return cls._instances[cls]
"""
整体识别的处理逻辑
1.压入二进制音频信息
@ -88,10 +92,12 @@ from src.models import AudioBinary_Config
from src.models import AudioBinary_Chunk
from typing import List
class LogicTrager(metaclass=AutoAfterMeta):
"""逻辑触发器类"""
def __init__(self,
def __init__(
self,
audio_chunk_max_size: int = 1024 * 1024 * 10,
audio_config: AudioBinary_Config = None,
result_callback: Callable = None,
@ -101,10 +107,12 @@ class LogicTrager(metaclass=AutoAfterMeta):
# 存储音频块
self._audio_chunk: List[AudioBinary_Chunk] = []
# 存储二进制数据
self._audio_chunk_binary = b''
self._audio_chunk_binary = b""
self._audio_chunk_max_size = audio_chunk_max_size
# 音频参数
self._audio_config = audio_config if audio_config is not None else AudioBinary_Config()
self._audio_config = (
audio_config if audio_config is not None else AudioBinary_Config()
)
# 结果队列
self._result_queue = []
# 聚合结果回调函数
@ -113,7 +121,6 @@ class LogicTrager(metaclass=AutoAfterMeta):
self._vad = VAD(VAD_model=models.get("vad"), audio_config=self._audio_config)
self._vad.set_callback(self.push_audio_chunk)
logger.info("初始化LogicTrager")
def push_binary_data(self, chunk: bytes) -> None:
@ -138,7 +145,11 @@ class LogicTrager(metaclass=AutoAfterMeta):
"""
音频处理
"""
logger.info("LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(chunk.start_time, chunk.end_time, len(chunk.chunk)))
logger.info(
"LogicTrager push_audio_chunk [{}ms:{}ms] (len={})".format(
chunk.start_time, chunk.end_time, len(chunk.chunk)
)
)
self._audio_chunk.append(chunk)
def __after__push_audio_chunk(self) -> None:

View File

@ -115,7 +115,7 @@ class ModelLoader:
self.models = {}
# 加载离线ASR模型
# 检查对应键是否存在
model_list = ['asr', 'asr_online', 'vad', 'punc', 'spk']
model_list = ["asr", "asr_online", "vad", "punc", "spk"]
for model_name in model_list:
name_model = f"{model_name}_model"
name_model_revision = f"{model_name}_model_revision"

View File

@ -1,3 +1,9 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result
__all__ = ["AudioBinary_Config", "AudioBinary_data_list", "_AudioBinary_data", "VAD_Functor_result"]
__all__ = [
"AudioBinary_Config",
"AudioBinary_data_list",
"_AudioBinary_data",
"VAD_Functor_result",
]

View File

@ -8,8 +8,10 @@ logger = get_module_logger(__name__)
binary_data_types = (bytes, numpy.ndarray)
class AudioBinary_Config(BaseModel):
"""二进制音频块配置信息"""
class Config:
arbitrary_types_allowed = True
@ -37,14 +39,16 @@ class AudioBinary_Config(BaseModel):
"""
return int(frame * 1000 / self.sample_rate)
class _AudioBinary_data(BaseModel):
"""音频数据"""
binary_data: binary_data_types = Field(description="音频二进制数据", default=None)
class Config:
arbitrary_types_allowed = True
@validator('binary_data')
@validator("binary_data")
def validate_binary_data(cls, v):
"""
验证音频数据
@ -54,7 +58,11 @@ class _AudioBinary_data(BaseModel):
binary_data_types: 音频数据
"""
if not isinstance(v, (bytes, numpy.ndarray)):
logger.warning("[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查", cls.__class__.__name__, type(v))
logger.warning(
"[%s]binary_data不是bytes, numpy.ndarray类型, 而是%s类型, 请检查",
cls.__class__.__name__,
type(v),
)
return v
def __len__(self):
@ -71,7 +79,11 @@ class _AudioBinary_data(BaseModel):
Args:
binary_data: 音频数据
"""
logger.debug("[%s]初始化音频数据, 数据类型为%s", self.__class__.__name__, type(binary_data))
logger.debug(
"[%s]初始化音频数据, 数据类型为%s",
self.__class__.__name__,
type(binary_data),
)
super().__init__(binary_data=binary_data)
def __getitem__(self):
@ -82,9 +94,13 @@ class _AudioBinary_data(BaseModel):
"""
return self.binary_data
class AudioBinary_data_list(BaseModel):
"""音频数据列表"""
binary_data_list: List[_AudioBinary_data] = Field(description="音频数据列表", default=[])
binary_data_list: List[_AudioBinary_data] = Field(
description="音频数据列表", default=[]
)
class Config:
arbitrary_types_allowed = True
@ -118,6 +134,7 @@ class AudioBinary_data_list(BaseModel):
"""
return len(self.binary_data_list)
# class AudioBinary_Slice(BaseModel):
# """音频块切片"""
# target_Binary: AudioBinary_data_list = Field(description="目标音频块", default=None)

View File

@ -2,23 +2,27 @@ from pydantic import BaseModel, Field, validator
from typing import List, Optional, Callable, Any
from .audio import AudioBinary_data_list, _AudioBinary_data
class VAD_Functor_result(BaseModel):
"""
VADFunctor结果
"""
audiobinary_data_list: AudioBinary_data_list = Field(description="音频数据列表")
audiobinary_index: int = Field(description="音频数据索引")
audiobinary_data: _AudioBinary_data = Field(description="音频数据, 指向AudioBinary_data")
audiobinary_data: _AudioBinary_data = Field(
description="音频数据, 指向AudioBinary_data"
)
start_time: int = Field(description="开始时间", is_required=True)
end_time: int = Field(description="结束时间", is_required=True)
@validator('audiobinary_data_list')
@validator("audiobinary_data_list")
def validate_audiobinary_data_list(cls, v):
if not isinstance(v, AudioBinary_data_list):
raise ValueError("audiobinary_data_list必须是AudioBinary_data_list类型")
return v
@validator('audiobinary_index')
@validator("audiobinary_index")
def validate_audiobinary_index(cls, v):
if not isinstance(v, int):
raise ValueError("audiobinary_index必须是int类型")
@ -26,13 +30,13 @@ class VAD_Functor_result(BaseModel):
raise ValueError("audiobinary_index必须大于0")
return v
@validator('audiobinary_data')
@validator("audiobinary_data")
def validate_audiobinary_data(cls, v):
if not isinstance(v, _AudioBinary_data):
raise ValueError("audiobinary_data必须是AudioBinary_data类型")
return v
@validator('start_time')
@validator("start_time")
def validate_start_time(cls, v):
if not isinstance(v, int):
raise ValueError("start_time必须是int类型")
@ -40,11 +44,11 @@ class VAD_Functor_result(BaseModel):
raise ValueError("start_time必须大于0")
return v
@validator('end_time')
@validator("end_time")
def validate_end_time(cls, v, values):
if not isinstance(v, int):
raise ValueError("end_time必须是int类型")
if 'start_time' in values and v <= values['start_time']:
if "start_time" in values and v <= values["start_time"]:
raise ValueError("end_time必须大于start_time")
return v
@ -54,7 +58,7 @@ class VAD_Functor_result(BaseModel):
audiobinary_data_list: AudioBinary_data_list,
data: Any,
start_time: int,
end_time: int
end_time: int,
):
"""
创建VAD片段
@ -66,7 +70,8 @@ class VAD_Functor_result(BaseModel):
audiobinary_index=index,
audiobinary_data=audiobinary_data_list[index],
start_time=start_time,
end_time=end_time)
end_time=end_time,
)
def __len__(self):
"""
@ -78,11 +83,9 @@ class VAD_Functor_result(BaseModel):
"""
字符串展示内容
"""
tostr = f'audiobinary_data_index: {self.audiobinary_index}\n'
tostr += f'start_time: {self.start_time}\n'
tostr += f'end_time: {self.end_time}\n'
tostr += f'data_length: {len(self.audiobinary_data.binary_data)}\n'
tostr += f'data_type: {type(self.audiobinary_data.binary_data)}\n'
tostr = f"audiobinary_data_index: {self.audiobinary_index}\n"
tostr += f"start_time: {self.start_time}\n"
tostr += f"end_time: {self.end_time}\n"
tostr += f"data_length: {len(self.audiobinary_data.binary_data)}\n"
tostr += f"data_type: {type(self.audiobinary_data.binary_data)}\n"
return tostr

View File

@ -7,11 +7,13 @@ import threading
logger = get_module_logger(__name__)
class ASRPipeline(PipelineBase):
"""
管道类
实现具体的处理逻辑
"""
def __init__(self, *args, **kwargs):
"""
初始化管道
@ -91,27 +93,24 @@ class ASRPipeline(PipelineBase):
"""
try:
from src.functor import FunctorFactory
# 加载VAD、asr、spk functor
self._functor_dict["vad"] = FunctorFactory.make_functor(
functor_name = "vad",
config = self._config,
models = self._models
functor_name="vad", config=self._config, models=self._models
)
self._functor_dict["asr"] = FunctorFactory.make_functor(
functor_name = "asr",
config = self._config,
models = self._models
functor_name="asr", config=self._config, models=self._models
)
self._functor_dict["spk"] = FunctorFactory.make_functor(
functor_name = "spk",
config = self._config,
models = self._models
functor_name="spk", config=self._config, models=self._models
)
# 创建音频数据存储单元
self._audio_binary_data_list = AudioBinary_data_list()
self._functor_dict["vad"].set_audio_binary_data_list(self._audio_binary_data_list)
self._functor_dict["vad"].set_audio_binary_data_list(
self._audio_binary_data_list
)
# 初始化子队列
self._subqueue_dict["original"] = Queue()
@ -134,13 +133,23 @@ class ASRPipeline(PipelineBase):
"""
带回调函数的put
"""
def put_with_check(data: Any) -> None:
queue.put(data)
callback(data)
return put_with_check
self._functor_dict["asr"].add_callback(put_with_checkcallback(self._subqueue_dict["asrend"], self._check_result))
self._functor_dict["spk"].add_callback(put_with_checkcallback(self._subqueue_dict["spkend"], self._check_result))
self._functor_dict["asr"].add_callback(
put_with_checkcallback(
self._subqueue_dict["asrend"], self._check_result
)
)
self._functor_dict["spk"].add_callback(
put_with_checkcallback(
self._subqueue_dict["spkend"], self._check_result
)
)
except ImportError:
raise ImportError("functorFactory引入失败,ASRPipeline无法完成初始化")
@ -150,14 +159,14 @@ class ASRPipeline(PipelineBase):
检查结果
"""
# 若asr和spk队列中都有数据则合并数据
if self._subqueue_dict["asrend"].qsize() & self._subqueue_dict["spkend"].qsize():
if (
self._subqueue_dict["asrend"].qsize()
& self._subqueue_dict["spkend"].qsize()
):
asr_data = self._subqueue_dict["asrend"].get()
spk_data = self._subqueue_dict["spkend"].get()
# 合并数据
result = {
"asr_data": asr_data,
"spk_data": spk_data
}
result = {"asr_data": asr_data, "spk_data": spk_data}
# 通知回调函数
self._notify_callbacks(result)

View File

@ -8,11 +8,13 @@ import time
# 配置日志
logger = logging.getLogger(__name__)
class PipelineBase(ABC):
"""
管道基类
定义了管道的基本接口和通用功能
"""
def __init__(self, input_queue: Optional[Queue] = None):
"""
初始化管道
@ -115,14 +117,35 @@ class PipelineBase(ABC):
# 注意Python的线程无法被强制终止这里只是设置标志
# 实际终止需要依赖操作系统的进程管理
class PipelineFactory:
"""
管道工厂类
用于创建管道实例
"""
@staticmethod
def create_pipeline(pipeline_name: str) -> Any:
from src.pipeline.ASRpipeline import ASRPipeline
def _create_pipeline_ASRpipeline(*args, **kwargs) -> ASRPipeline:
"""
创建ASR管道实例
"""
from src.pipeline.ASRpipeline import ASRPipeline
pipeline = ASRPipeline()
pipeline.set_config(kwargs["config"])
pipeline.set_models(kwargs["models"])
pipeline.set_audio_binary(kwargs["audio_binary"])
pipeline.set_input_queue(kwargs["input_queue"])
pipeline.add_callback(kwargs["callback"])
pipeline.bake()
return pipeline
@classmethod
def create_pipeline(cls, pipeline_name: str, *args, **kwargs) -> Any:
"""
创建管道实例
"""
pass
if pipeline_name == "ASRpipeline":
return cls._create_pipeline_ASRpipeline(*args, **kwargs)
else:
raise ValueError(f"不支持的管道类型: {pipeline_name}")

View File

@ -172,7 +172,7 @@ class STTRunner(RunnerBase):
logger.warning(
"等待队列处理完成超时(%s秒),队列中还有 %d 个任务未处理",
self._stop_timeout,
self._input_queue.qsize()
self._input_queue.qsize(),
)
success = False
break
@ -188,7 +188,7 @@ class STTRunner(RunnerBase):
"错误堆栈:\n%s",
error_type,
error_msg,
error_traceback
error_traceback,
)
success = False
@ -198,7 +198,7 @@ class STTRunner(RunnerBase):
logger.warning(
"部分管道停止失败,队列状态: 大小=%d, 是否为空=%s",
self._input_queue.qsize(),
self._input_queue.empty()
self._input_queue.empty(),
)
return success

View File

@ -88,7 +88,9 @@ async def ws_serve(websocket, path):
# 更新各种配置参数
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
websocket.status_dict_asr_online["is_final"] = (
not websocket.is_speaking
)
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
@ -97,11 +99,17 @@ async def ws_serve(websocket, path):
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
websocket.status_dict_asr_online["chunk_size"] = [
int(x) for x in chunk_size
]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
websocket.status_dict_asr_online["encoder_chunk_look_back"] = (
messagejson["encoder_chunk_look_back"]
)
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
websocket.status_dict_asr_online["decoder_chunk_look_back"] = (
messagejson["decoder_chunk_look_back"]
)
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
@ -111,11 +119,17 @@ async def ws_serve(websocket, path):
# 根据chunk_interval更新VAD的chunk_size
websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1] * 60 / websocket.chunk_interval
websocket.status_dict_asr_online.get("chunk_size", [0, 10])[1]
* 60
/ websocket.chunk_interval
)
# 处理音频数据
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
if (
len(frames_asr_online) > 0
or len(frames_asr) >= 0
or not isinstance(message, str)
):
if not isinstance(message, str): # 二进制音频数据
# 添加到帧缓冲区
frames.append(message)
@ -127,8 +141,10 @@ async def ws_serve(websocket, path):
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
# 达到chunk_interval或最终帧时处理在线ASR
if (len(frames_asr_online) % websocket.chunk_interval == 0 or
websocket.status_dict_asr_online["is_final"]):
if (
len(frames_asr_online) % websocket.chunk_interval == 0
or websocket.status_dict_asr_online["is_final"]
):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
@ -143,7 +159,9 @@ async def ws_serve(websocket, path):
# VAD处理 - 语音活动检测
try:
speech_start_i, speech_end_i = await asr_service.async_vad(websocket, message)
speech_start_i, speech_end_i = await asr_service.async_vad(
websocket, message
)
except Exception as e:
print(f"VAD处理错误: {e}")
@ -151,8 +169,12 @@ async def ws_serve(websocket, path):
if speech_start_i != -1:
speech_start = True
# 计算开始偏移并收集前面的帧
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:] if beg_bias < len(frames) else frames
beg_bias = (
websocket.vad_pre_idx - speech_start_i
) // duration_ms
frames_pre = (
frames[-beg_bias:] if beg_bias < len(frames) else frames
)
frames_asr = []
frames_asr.extend(frames_pre)
@ -207,16 +229,16 @@ def start_server(args, asr_service_instance):
ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
start_server = websockets.serve(
ws_serve, args.host, args.port,
ws_serve,
args.host,
args.port,
subprotocols=["binary"],
ping_interval=None,
ssl=ssl_context
ssl=ssl_context,
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port,
subprotocols=["binary"],
ping_interval=None
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
)
print(f"WebSocket服务器已启动 - 监听 {args.host}:{args.port}")

View File

@ -35,8 +35,7 @@ class ASRService:
"""
# 使用VAD模型分析音频段
segments_result = self.model_vad.generate(
input=audio_in,
**websocket.status_dict_vad
input=audio_in, **websocket.status_dict_vad
)[0]["value"]
speech_start = -1
@ -64,36 +63,38 @@ class ASRService:
if len(audio_in) > 0:
# 使用离线ASR模型处理音频
rec_result = self.model_asr.generate(
input=audio_in,
**websocket.status_dict_asr
input=audio_in, **websocket.status_dict_asr
)[0]
# 如果有标点符号模型且识别出文本,则添加标点
if self.model_punc is not None and len(rec_result["text"]) > 0:
rec_result = self.model_punc.generate(
input=rec_result["text"],
**websocket.status_dict_punc
input=rec_result["text"], **websocket.status_dict_punc
)[0]
# 如果识别出文本,发送到客户端
if len(rec_result["text"]) > 0:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
}
)
await websocket.send(message)
else:
# 如果没有音频数据,发送空文本
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
message = json.dumps(
{
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
}
)
await websocket.send(message)
async def async_asr_online(self, websocket, audio_in):
@ -107,21 +108,24 @@ class ASRService:
if len(audio_in) > 0:
# 使用在线ASR模型处理音频
rec_result = self.model_asr_streaming.generate(
input=audio_in,
**websocket.status_dict_asr_online
input=audio_in, **websocket.status_dict_asr_online
)[0]
# 在2pass模式下如果是最终帧则跳过(留给离线ASR处理)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get(
"is_final", False
):
return
# 如果识别出文本,发送到客户端
if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
})
}
)
await websocket.send(message)

View File

@ -1,10 +1,12 @@
"""
处理各类音频数据与bytes的转换
"""
import wave
from pydub import AudioSegment
import io
def wav_to_bytes(wav_path: str) -> bytes:
"""
将WAV文件读取为bytes数据
@ -20,18 +22,20 @@ def wav_to_bytes(wav_path: str) -> bytes:
wave.Error: 如果文件不是有效的WAV文件
"""
try:
with wave.open(wav_path, 'rb') as wf:
with wave.open(wav_path, "rb") as wf:
# 读取所有帧
frames = wf.readframes(wf.getnframes())
return frames
except FileNotFoundError:
# 可以选择记录日志或重新抛出,这里为了清晰直接重新抛出
raise FileNotFoundError(f"错误未找到WAV文件 '{wav_path}'")
raise FileNotFoundError(f"错误: 未找到WAV文件 '{wav_path}'")
except wave.Error as e:
raise wave.Error(f"错误打开或读取WAV文件 '{wav_path}' 失败 - {e}")
raise wave.Error(f"错误: 打开或读取WAV文件 '{wav_path}' 失败 - {e}")
def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int):
def bytes_to_wav(
bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: int, framerate: int
):
"""
将bytes数据写入为WAV文件
@ -46,17 +50,18 @@ def bytes_to_wav(bytes_data: bytes, wav_path: str, nchannels: int, sampwidth: in
wave.Error: 如果写入WAV文件失败
"""
try:
with wave.open(wav_path, 'wb') as wf:
with wave.open(wav_path, "wb") as wf:
wf.setnchannels(nchannels)
wf.setsampwidth(sampwidth)
wf.setframerate(framerate)
wf.writeframes(bytes_data)
except wave.Error as e:
raise wave.Error(f"错误写入WAV文件 '{wav_path}' 失败 - {e}")
raise wave.Error(f"错误: 写入WAV文件 '{wav_path}' 失败 - {e}")
except Exception as e:
# 捕获其他可能的写入错误
raise Exception(f"写入WAV文件 '{wav_path}' 时发生未知错误 - {e}")
def mp3_to_bytes(mp3_path: str) -> bytes:
"""
将MP3文件转换为bytes数据 (原始PCM数据)
@ -76,12 +81,19 @@ def mp3_to_bytes(mp3_path: str) -> bytes:
# 获取原始PCM数据
return audio.raw_data
except FileNotFoundError:
raise FileNotFoundError(f"错误未找到MP3文件 '{mp3_path}'")
raise FileNotFoundError(f"错误: 未找到MP3文件 '{mp3_path}'")
except Exception as e: # pydub 可能抛出多种解码相关的错误
raise Exception(f"错误处理MP3文件 '{mp3_path}' 失败 - {e}")
raise Exception(f"错误: 处理MP3文件 '{mp3_path}' 失败 - {e}")
def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: int, sample_width: int, bitrate: str = "192k"):
def bytes_to_mp3(
bytes_data: bytes,
mp3_path: str,
frame_rate: int,
channels: int,
sample_width: int,
bitrate: str = "192k",
):
"""
将原始PCM bytes数据转换为MP3文件
@ -102,9 +114,9 @@ def bytes_to_mp3(bytes_data: bytes, mp3_path: str, frame_rate: int, channels: in
data=bytes_data,
sample_width=sample_width,
frame_rate=frame_rate,
channels=channels
channels=channels,
)
# 导出为MP3
audio.export(mp3_path, format="mp3", bitrate=bitrate)
except Exception as e:
raise Exception(f"错误转换或写入MP3文件 '{mp3_path}' 失败 - {e}")
raise Exception(f"错误: 转换或写入MP3文件 '{mp3_path}' 失败 - {e}")

View File

@ -3,6 +3,7 @@ import sys
from pathlib import Path
from typing import Optional
def setup_logger(
name: str = None,
level: str = "INFO",
@ -45,17 +46,15 @@ def setup_logger(
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 注意:移除了 propagate = False允许日志传递
return logger
def setup_root_logger(
level: str = "INFO",
log_file: Optional[str] = None
) -> None:
def setup_root_logger(level: str = "INFO", log_file: Optional[str] = None) -> None:
"""
配置根日志器
@ -65,10 +64,11 @@ def setup_root_logger(
"""
setup_logger(None, level, log_file)
def get_module_logger(
module_name: str,
level: Optional[str] = None, # 改为可选参数
log_file: Optional[str] = None # 一般不需要单独指定
log_file: Optional[str] = None, # 一般不需要单独指定
) -> logging.Logger:
"""
获取模块级别的logger

View File

@ -1,10 +1,15 @@
from tests.functor.vad_test import test_vad_functor
"""
测试主函数
请在tests目录下创建测试文件, 并在此文件中调用
"""
from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger
setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)
# from tests.functor.vad_test import test_vad_functor
# logger.info("开始测试VAD函数器")
# test_vad_functor()

View File

@ -2,6 +2,7 @@
Functor测试
VAD测试
"""
from src.functor.vad_functor import VADFunctor
from src.functor.asr_functor import ASRFunctor
from src.functor.spk_functor import SPKFunctor
@ -21,6 +22,7 @@ logger = get_module_logger(__name__)
model_loader = ModelLoader()
def test_vad_functor():
# 加载模型
args = {
@ -38,7 +40,7 @@ def test_vad_functor():
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
@ -62,9 +64,7 @@ def test_vad_functor():
vad_functor.add_callback(lambda x: vad2asr_queue.put(x))
vad_functor.add_callback(lambda x: vad2spk_queue.put(x))
# 设置模型
vad_functor.set_model({
'vad': model_loader.models['vad']
})
vad_functor.set_model({"vad": model_loader.models["vad"]})
# 启动VAD函数器
vad_functor.run()
@ -77,9 +77,7 @@ def test_vad_functor():
# 设置回调函数
asr_functor.add_callback(lambda x: print(f"asr callback: {x}"))
# 设置模型
asr_functor.set_model({
'asr': model_loader.models['asr']
})
asr_functor.set_model({"asr": model_loader.models["asr"]})
# 启动ASR函数器
asr_functor.run()
@ -92,23 +90,25 @@ def test_vad_functor():
# 设置回调函数
spk_functor.add_callback(lambda x: print(f"spk callback: {x}"))
# 设置模型
spk_functor.set_model({
spk_functor.set_model(
{
# 'spk': model_loader.models['spk']
'spk': 'fake_spk'
})
"spk": "fake_spk"
}
)
# 启动SPK函数器
spk_functor.run()
f_binary = f_data
audio_clip_len = 200
print(f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}")
print(
f"f_binary: {len(f_binary)}, audio_clip_len: {audio_clip_len}, clip_num: {len(f_binary) // audio_clip_len}"
)
for i in range(0, len(f_binary), audio_clip_len):
binary_data = f_binary[i : i + audio_clip_len]
input_queue.put(binary_data)
# 等待VAD函数器结束
vad_functor.stop()
print("[vad_test] VAD函数器结束")
@ -119,4 +119,6 @@ def test_vad_functor():
if OVERWATCH:
for index in range(len(audio_binary_data_list)):
save_path = f"tests/vad_test_output_{index}.wav"
soundfile.write(save_path, audio_binary_data_list[index].binary_data, sample_rate)
soundfile.write(
save_path, audio_binary_data_list[index].binary_data, sample_rate
)

View File

@ -1,9 +1,20 @@
"""
模型使用测试
此处主要用于各类调用模型的处理数据与输出格式
请在主目录下test_main.py中调用
将需要测试的模型定义在函数中进行测试, 函数名称需要与测试内容匹配
"""
from funasr import AutoModel
from typing import List, Dict, Any
from src.models import VADResponse
import time
def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
"""
chunk_size = 100 # ms
model = AutoModel(model="fsmn-vad", model_revision="v2.0.4", disable_update=True)
@ -21,7 +32,9 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
time.sleep(0.1)
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
res = model.generate(
input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size
)
if len(res[0]["value"]):
vad_result += VADResponse.from_raw(res)
for item in res[0]["value"]:
@ -32,16 +45,24 @@ def vad_model_use_online(file_path: str) -> List[Dict[str, Any]]:
# print(item)
return vad_result
def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
"""
在线VAD模型使用
测试LogicTrager
在Rebuild版本后LogicTrager中已弃用
"""
from src.logic_trager import LogicTrager
import soundfile
from src.config import parse_args
args = parse_args()
# from src.functor.model_loader import load_models
# models = load_models(args)
from src.model_loader import ModelLoader
models = ModelLoader(args)
chunk_size = 200 # ms
@ -50,7 +71,9 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
speech, sample_rate = soundfile.read(file_path)
chunk_stride = int(chunk_size * sample_rate / 1000)
audio_config = AudioBinary_Config(sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size)
audio_config = AudioBinary_Config(
sample_rate=sample_rate, sample_width=2, channels=1, chunk_size=chunk_size
)
logic_trager = LogicTrager(models=models, audio_config=audio_config)
for i in range(len(speech) // chunk_stride + 1):
@ -61,12 +84,22 @@ def vad_model_use_online_logic(file_path: str) -> List[Dict[str, Any]]:
# print(item)
return None
def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
"""
ASR模型使用
离线ASR模型使用
"""
from funasr import AutoModel
model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
vad_model="fsmn-vad", vad_model_revision="v2.0.4",
model = AutoModel(
model="paraformer-zh",
model_revision="v2.0.4",
vad_model="fsmn-vad",
vad_model_revision="v2.0.4",
# punc_model="ct-punc-c", punc_model_revision="v2.0.4",
spk_model="cam++", spk_model_revision="v2.0.2",
spk_model="cam++",
spk_model_revision="v2.0.2",
spk_mode="vad_segment",
auto_update=False,
)
@ -80,7 +113,9 @@ def asr_model_use_offline(file_path: str) -> List[Dict[str, Any]]:
result = model.generate(speech)
return result
if __name__ == "__main__":
# if __name__ == "__main__":
# 请在主目录下调用test_main.py文件进行测试
# vad_result = vad_model_use_online("tests/vad_example.wav")
vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# vad_result = vad_model_use_online_logic("tests/vad_example.wav")
# print(vad_result)

View File

@ -2,7 +2,9 @@
Pipeline测试
VAD+ASR+SPK(FAKE)
"""
from src.pipeline.ASRpipeline import ASRPipeline
from src.pipeline import PipelineFactory
from src.models import AudioBinary_data_list, AudioBinary_Config
from src.model_loader import ModelLoader
from queue import Queue
@ -18,6 +20,7 @@ OVAERWATCH = False
model_loader = ModelLoader()
def test_asr_pipeline():
# 加载模型
args = {
@ -36,7 +39,7 @@ def test_asr_pipeline():
chunk_stride=1600,
sample_rate=sample_rate,
sample_width=16,
channels=1
channels=1,
)
chunk_stride = int(audio_config.chunk_size * sample_rate / 1000)
audio_config.chunk_stride = chunk_stride
@ -52,19 +55,30 @@ def test_asr_pipeline():
input_queue = Queue()
# 创建Pipeline
asr_pipeline = ASRPipeline()
asr_pipeline.set_models(models)
asr_pipeline.set_config(config)
asr_pipeline.set_audio_binary(audio_binary_data_list)
asr_pipeline.set_input_queue(input_queue)
asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
asr_pipeline.bake()
# asr_pipeline = ASRPipeline()
# asr_pipeline.set_models(models)
# asr_pipeline.set_config(config)
# asr_pipeline.set_audio_binary(audio_binary_data_list)
# asr_pipeline.set_input_queue(input_queue)
# asr_pipeline.add_callback(lambda x: print(f"pipeline callback: {x}"))
# asr_pipeline.bake()
asr_pipeline = PipelineFactory.create_pipeline(
pipeline_name = "ASRpipeline",
models=models,
config=config,
audio_binary=audio_binary_data_list,
input_queue=input_queue,
callback=lambda x: print(f"pipeline callback: {x}")
)
# 运行Pipeline
asr_instance = asr_pipeline.run()
audio_clip_len = 200
print(f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}")
print(
f"audio_data: {len(audio_data)}, audio_clip_len: {audio_clip_len}, clip_num: {len(audio_data) // audio_clip_len}"
)
for i in range(0, len(audio_data), audio_clip_len):
input_queue.put(audio_data[i : i + audio_clip_len])
@ -77,4 +91,3 @@ def test_asr_pipeline():
time.sleep(5)
asr_pipeline.stop()
# asr_pipeline.stop()

View File

@ -10,13 +10,13 @@ import os
from unittest.mock import patch
# 将src目录添加到路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.config import parse_args
def test_default_args():
"""测试默认参数值"""
with patch('sys.argv', ['script.py']):
with patch("sys.argv", ["script.py"]):
args = parse_args()
# 检查服务器参数
@ -46,17 +46,24 @@ def test_default_args():
def test_custom_args():
"""测试自定义参数值"""
test_args = [
'script.py',
'--host', 'localhost',
'--port', '8080',
'--certfile', 'cert.pem',
'--keyfile', 'key.pem',
'--asr_model', 'custom_model',
'--ngpu', '0',
'--device', 'cpu'
"script.py",
"--host",
"localhost",
"--port",
"8080",
"--certfile",
"cert.pem",
"--keyfile",
"key.pem",
"--asr_model",
"custom_model",
"--ngpu",
"0",
"--device",
"cpu",
]
with patch('sys.argv', test_args):
with patch("sys.argv", test_args):
args = parse_args()
# 检查自定义参数