diff --git a/src/core/__init__.py b/src/core/__init__.py
new file mode 100644
index 0000000..32e9b9e
--- /dev/null
+++ b/src/core/__init__.py
@@ -0,0 +1,3 @@
+from .model_loader import ModelLoader
+
+__all__ = ["ModelLoader"]
\ No newline at end of file
diff --git a/src/functor/spk_functor.py b/src/functor/spk_functor.py
index d004543..1f0a800 100644
--- a/src/functor/spk_functor.py
+++ b/src/functor/spk_functor.py
@@ -4,16 +4,41 @@ SpkFunctor
 """
 from src.functor.base import BaseFunctor
-from src.models import AudioBinary_Config, VAD_Functor_result
-from typing import Callable, List
+from src.models import AudioBinary_Config, VAD_Functor_result, SpeakerCreate
+from typing import Callable, List, Dict
 from queue import Queue, Empty
+import json
+import torch
 import threading
+import numpy
 
 # Logging
 from src.utils.logger import get_module_logger
 
 logger = get_module_logger(__name__)
-
+
+from modelscope.pipelines import pipeline
+# See tests/spkverify_use.py for worked examples of this pipeline
+# (same/different-speaker scoring, thr=, output_emb=, save_dir=).
 
 class SPKFunctor(BaseFunctor):
     """
@@ -26,15 +51,84 @@ class SPKFunctor(BaseFunctor):
         Call stop() to stop the thread; it waits until input_queue is empty.
     """
 
+    class speaker_verify:
+        """
+        Speaker verification against locally registered speakers.
+        """
+        def __init__(self, *args, **kwargs):
+            self._spk_data: List[SpeakerCreate] = []
+
+        def add_speaker(self, speaker: SpeakerCreate) -> None:
+            self._spk_data.append(speaker)
+
+        def verify(self, emb: numpy.ndarray) -> Dict:
+            # Convert the input numpy embedding to a tensor.
+            input_emb_tensor = torch.from_numpy(emb).unsqueeze(0)
+
+            results = []
+            for speaker in self._spk_data:
+                if not speaker.speaker_embs:
+                    continue
+                # Convert the stored embedding list to a tensor.
+                speaker_emb_tensor = torch.tensor(speaker.speaker_embs).unsqueeze(0)
+
+                # Cosine similarity between the query and the stored embedding.
+                score = torch.nn.functional.cosine_similarity(
+                    input_emb_tensor, speaker_emb_tensor
+                ).item()
+
+                results.append({
+                    "score": score,
+                    "speaker_id": speaker.speaker_id,
+                    "speaker_name": speaker.speaker_name,
+                    "speaker_description": speaker.speaker_description,
+                })
+
+            if not results:
+                return {
+                    "speaker_id": "unknown",
+                    "speaker_name": "Unknown",
+                    "speaker_description": "No registered speakers to verify against.",
+                    "score": 0.0,
+                    "results": [],
+                }
+
+            results.sort(key=lambda x: x["score"], reverse=True)
+            best_match = results[0]
+            return {
+                "speaker_id": best_match["speaker_id"],
+                "speaker_name": best_match["speaker_name"],
+                "speaker_description": best_match["speaker_description"],
+                "score": best_match["score"],
+                "results": results,
+            }
+
     def __init__(self) -> None:
         super().__init__()
         # Resources and configuration
+        self._spk_verify = self.speaker_verify()
+        self._sv_pipeline = pipeline(
+            task='speaker-verification',
+            model='iic/speech_campplus_sv_zh-cn_16k-common',
+            model_revision='v1.0.0'
+        )
         self._model: dict = {}  # model
         self._callback: List[Callable] = []  # callback functions
         self._input_queue: Queue = None  # input queue
         self._audio_config: AudioBinary_Config = None  # audio config
 
+    def load_spk_data_local(
+        self,
+        spk_data_path: str = 'data/speakers.json',
+    ) -> None:
+        """
+        Load local speaker data.
+        """
+        with open(spk_data_path, 'r', encoding='utf-8') as f:
+            spk_data = json.load(f)
+        for spk in spk_data:
+            self._spk_verify.add_speaker(SpeakerCreate(**spk))
+
     def reset_cache(self) -> None:
         """
         Reset the cache after a task completes, ready for the next task.
         """
@@ -60,6 +154,7 @@ class SPKFunctor(BaseFunctor):
         self._audio_config = audio_config
         logger.debug("SpkFunctor audio config set: %s", self._audio_config)
 
+
     def add_callback(self, callback: Callable) -> None:
         """
         Append a callback function to self._callback (List[Callable]).
@@ -84,7 +179,11 @@
         #     input=binary_data,
         #     chunk_size=self._audio_config.chunk_size,
         # )
-        result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
+        # Extract the speaker embedding, then rank it against registered speakers.
+        sv_result = self._sv_pipeline([binary_data], output_emb=True)
+        embs = sv_result['embs'][0]
+
+        result = self._spk_verify.verify(embs)
+
         self._do_callback(result)
 
     def _run(self) -> None:
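Reviewer note: a minimal sketch of how the new `speaker_verify` helper ranks a query embedding, assuming this PR is applied and the snippet is run from the repo root. The speaker data and the 192-dim embedding values are made up, not real CAM++ vectors.

```python
# Sketch only: field names mirror SpeakerCreate; all values are placeholders.
import numpy as np
from uuid import uuid4
from src.functor.spk_functor import SPKFunctor
from src.models import SpeakerCreate

sv = SPKFunctor.speaker_verify()
sv.add_speaker(SpeakerCreate(
    speaker_id=uuid4(),
    speaker_name="Alice",
    speaker_description="demo speaker",
    speaker_embs=[0.1] * 192,  # placeholder embedding
))

# A query pointing in the same direction scores cosine similarity 1.0.
match = sv.verify(np.asarray([0.1] * 192, dtype=np.float32))
print(match["speaker_name"], round(match["score"], 3))  # Alice 1.0
```

Instantiating the inner class directly avoids constructing the modelscope pipeline, so this runs without downloading the model.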
diff --git a/src/models/__init__.py b/src/models/__init__.py
index d55041f..b460b17 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -1,6 +1,8 @@
 from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
 from .vad import VAD_Functor_result
+from .spk import SpeakerCreate, SpeakerResponse
+
 __all__ = [
     "AudioBinary_Config",
     "AudioBinary_data_list",
diff --git a/src/models/spk.py b/src/models/spk.py
new file mode 100644
index 0000000..9e583d7
--- /dev/null
+++ b/src/models/spk.py
@@ -0,0 +1,68 @@
+"""
+src/models/spk.py
+------------------------
+Pydantic models mirroring the speakers table, used for API validation and serialization.
+
+Models:
+- SpeakerBase: core speaker fields (id, name, description, embedding).
+- SpeakerCreate: payload validation when creating a speaker; inherits SpeakerBase unchanged.
+- SpeakerUpdate: payload for updating a speaker; every field is optional.
+- SpeakerResponse: returned to clients; adds database-generated fields (created_at, updated_at).
+"""
+
+from datetime import datetime
+from typing import Optional, List
+from uuid import UUID
+from pydantic import BaseModel, Field
+from .base import BaseSchema
+
+# Base model defining the core speaker attributes
+class SpeakerBase(BaseSchema):
+    speaker_id: UUID = Field(
+        ...,
+        description="Unique speaker identifier"
+    )
+    speaker_name: str = Field(
+        ...,
+        max_length=255,
+        description="Speaker name, required, at most 255 characters"
+    )
+    speaker_description: Optional[str] = Field(
+        None,
+        description="Optional speaker description"
+    )
+    speaker_embs: Optional[List[float]] = Field(
+        None,
+        description="Optional speaker embedding"
+    )
+
+# Creation payload; inherits SpeakerBase unchanged
+class SpeakerCreate(SpeakerBase):
+    pass
+
+# Update payload; all fields optional
+class SpeakerUpdate(BaseModel):
+    speaker_name: Optional[str] = Field(
+        None,
+        max_length=255,
+        description="Speaker name"
+    )
+    speaker_description: Optional[str] = Field(
+        None,
+        description="Speaker description"
+    )
+
+# API response model; speaker_id is inherited from SpeakerBase,
+# so only the database-generated timestamps are added here.
+class SpeakerResponse(SpeakerBase):
+    created_at: datetime = Field(
+        ...,
+        description="Record creation time"
+    )
+    updated_at: datetime = Field(
+        ...,
+        description="Last update time"
+    )
\ No newline at end of file
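`load_spk_data_local()` implies that `data/speakers.json` holds a JSON array of objects matching `SpeakerCreate`. A plausible entry and its validation, sketched below; only the keys are fixed by the model, while the UUID and embedding values are placeholders:

```python
# Hypothetical data/speakers.json contents, validated the same way
# load_spk_data_local() does it.
import json
from src.models import SpeakerCreate

raw = """[{
    "speaker_id": "00000000-0000-0000-0000-000000000001",
    "speaker_name": "Alice",
    "speaker_description": "demo speaker",
    "speaker_embs": [0.1, 0.2, 0.3]
}]"""

speakers = [SpeakerCreate(**item) for item in json.loads(raw)]
print(speakers[0].speaker_name)  # Alice
```

Pydantic coerces the string `speaker_id` into a `UUID`, so plain JSON files round-trip cleanly.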
diff --git a/src/websockets/endpoint/__init__.py b/src/websockets/endpoint/__init__.py
new file mode 100644
index 0000000..5e3a2bf
--- /dev/null
+++ b/src/websockets/endpoint/__init__.py
@@ -0,0 +1,3 @@
+from .asr_endpoint import router as asr_router
+
+__all__ = ["asr_router"]
\ No newline at end of file
diff --git a/src/websockets/endpoint/asr_endpoint.py b/src/websockets/endpoint/asr_endpoint.py
new file mode 100644
index 0000000..25adfbb
--- /dev/null
+++ b/src/websockets/endpoint/asr_endpoint.py
@@ -0,0 +1,32 @@
+"""
+ -*- coding: utf-8 -*-
+ ASR websocket endpoint, built on FastAPI's websocket support.
+"""
+
+from fastapi import WebSocket, WebSocketDisconnect, APIRouter
+
+from src.core import ModelLoader
+from src.runner.ASRRunner import ASRRunner
+
+router = APIRouter()
+
+ASRRunner_instance = ASRRunner()
+
+# Note: models are loaded at import time, which makes importing this module expensive.
+model_loader = ModelLoader()
+args = {
+    "asr_model": "paraformer-zh",
+    "asr_model_revision": "v2.0.4",
+    "vad_model": "fsmn-vad",
+    "vad_model_revision": "v2.0.4",
+    "spk_model": "cam++",
+    "spk_model_revision": "v2.0.2",
+    "audio_update": False,
+}
+models = model_loader.load_models(args)
+
+@router.websocket("/asr_full")
+async def asr_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    try:
+        while True:
+            data = await websocket.receive_text()
+            print(data)
+    except WebSocketDisconnect:
+        pass
diff --git a/src/websockets/router.py b/src/websockets/router.py
new file mode 100644
index 0000000..2694a6a
--- /dev/null
+++ b/src/websockets/router.py
@@ -0,0 +1,3 @@
+from .endpoint import asr_router
+
+__all__ = ["asr_router"]
\ No newline at end of file
diff --git a/test_main.py b/test_main.py
index 850efe4..4af9475 100644
--- a/test_main.py
+++ b/test_main.py
@@ -5,7 +5,7 @@ from tests.pipeline.asr_test import test_asr_pipeline
 
 from src.utils.logger import get_module_logger, setup_root_logger
 
-from tests.runner.stt_runner_test import test_asr_runner
+from tests.runner.asr_runner_test import test_asr_runner
 
 setup_root_logger(level="INFO", log_file="logs/test_main.log")
 logger = get_module_logger(__name__)
diff --git a/tests/spkverify_use.py b/tests/spkverify_use.py
index cca1f15..f1a457e 100644
--- a/tests/spkverify_use.py
+++ b/tests/spkverify_use.py
@@ -19,5 +19,7 @@ print(result)
 # Pass output_emb to include the extracted speaker embeddings in the result
 result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
 print(result['embs'], result['outputs'])
+# result1 = sv_pipeline([speaker1_a_wav], output_emb=True)
+# print(result1['embs'], result1['outputs'])
 # Pass save_dir to store the extracted speaker embeddings in that directory
 result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
\ No newline at end of file
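Reviewer note: the new `/asr_full` endpoint currently only echoes received text to stdout. A minimal way to exercise it is FastAPI's `TestClient`; the wrapper app below is hypothetical, and importing the endpoint module triggers the model loading noted above.

```python
# Sketch: mount asr_router on a throwaway app and open a websocket to it.
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.websockets.endpoint import asr_router

app = FastAPI()
app.include_router(asr_router)

client = TestClient(app)
with client.websocket_connect("/asr_full") as ws:
    ws.send_text("hello")  # the endpoint currently just prints incoming text
```

With the `WebSocketDisconnect` handler added in the diff, closing the context manager ends the receive loop cleanly instead of raising in the server task.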