[Speaker verification] Implement SPKFunctor: extract embeddings with CAM++ and score speakers by cosine similarity.

Ziyang.Zhang 2025-06-30 14:13:06 +08:00
parent 5dac718dee
commit 4e9dd83d55
9 changed files with 217 additions and 5 deletions
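
For reference, the score this commit produces is plain cosine similarity between two CAM++ embeddings a and b:

$$\mathrm{score}(a, b) = \frac{a \cdot b}{\lVert a \rVert \, \lVert b \rVert} \in [-1, 1]$$

A score near 1 means the two embeddings are nearly collinear and likely belong to the same speaker; the commented-out modelscope example below passes thr=0.6 as the same-speaker threshold.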

src/core/__init__.py Normal file

@@ -0,0 +1,3 @@
from .model_loader import ModelLoader
__all__ = ["ModelLoader"]


@@ -4,16 +4,41 @@ SpkFunctor
"""
from src.functor.base import BaseFunctor
-from src.models import AudioBinary_Config, VAD_Functor_result
-from typing import Callable, List
+from src.models import AudioBinary_Config, VAD_Functor_result, SpeakerCreate
+from typing import Callable, List, Dict
from queue import Queue, Empty
import json
import torch
import threading
import numpy
# Logging
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
from modelscope.pipelines import pipeline
# sv_pipeline = pipeline(
# task='speaker-verification',
# model='iic/speech_campplus_sv_zh-cn_16k-common',
# model_revision='v1.0.0'
# )
# speaker1_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_a_cn_16k.wav'
# speaker1_b_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_b_cn_16k.wav'
# speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
# # Same-speaker audio
# result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
# print(result)
# # Different-speaker audio
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
# print(result)
# # A custom score threshold can be set; the higher the threshold, the stricter the same-speaker decision
# result = sv_pipeline([speaker1_a_wav, speaker1_a_wav], thr=0.6)
# print(result)
# # Passing output_emb=True includes the extracted speaker embeddings in the result
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
# print(result['embs'], result['outputs'])
# # Passing save_dir stores the extracted speaker embeddings in that directory
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
class SPKFunctor(BaseFunctor):
"""
@@ -26,15 +51,84 @@ class SPKFunctor(BaseFunctor):
Call stop() to stop the thread; it waits for input_queue to drain first
"""
class speaker_verify:
"""
Speaker verification
"""
def __init__(self, *args, **kwargs):
self._spk_data: List[SpeakerCreate] = []
def add_speaker(self, speaker: SpeakerCreate) -> None:
self._spk_data.append(speaker)
def verify(self, emb: numpy.ndarray) -> Dict:
# Convert the input numpy embedding to a tensor
input_emb_tensor = torch.from_numpy(emb).unsqueeze(0)
results = []
for speaker in self._spk_data:
if not speaker.speaker_embs:
continue
# Convert the embedding stored as a list to a tensor
speaker_emb_tensor = torch.tensor(speaker.speaker_embs).unsqueeze(0)
# Compute cosine similarity
score = torch.nn.functional.cosine_similarity(input_emb_tensor, speaker_emb_tensor).item()
result = {
"score": score,
"speaker_id": speaker.speaker_id,
"speaker_name": speaker.speaker_name,
"speaker_description": speaker.speaker_description
}
results.append(result)
if not results:
return {
"speaker_id": "unknown",
"speaker_name": "Unknown",
"speaker_description": "No registered speakers to verify against.",
"score": 0.0,
"results": []
}
results.sort(key=lambda x: x['score'], reverse=True)
best_match = results[0]
return {
"speaker_id": best_match['speaker_id'],
"speaker_name": best_match['speaker_name'],
"speaker_description": best_match['speaker_description'],
"score": best_match['score'],
"results": results
}
def __init__(self) -> None:
super().__init__()
# Resources and configuration
self._spk_verify = self.speaker_verify()
self._sv_pipeline = pipeline(
task='speaker-verification',
model='iic/speech_campplus_sv_zh-cn_16k-common',
model_revision='v1.0.0'
)
self._model: dict = {} # models
self._callback: List[Callable] = [] # callback functions
self._input_queue: Queue = None # input queue
self._audio_config: AudioBinary_Config = None # audio config
def load_spk_data_local(
self,
spk_data_path: str = 'data/speakers.json',
) -> None:
"""
Load speaker data from a local file
"""
with open(spk_data_path, 'r') as f:
spk_data = json.load(f)
for spk in spk_data:
self._spk_verify.add_speaker(SpeakerCreate(**spk))
def reset_cache(self) -> None:
"""
Reset the cache: clear cached data after a task finishes, ready for the next task
@@ -60,6 +154,7 @@ class SPKFunctor(BaseFunctor):
self._audio_config = audio_config
logger.debug("SpkFunctor audio config set: %s", self._audio_config)
def add_callback(self, callback: Callable) -> None:
"""
Append a callback function to this instance's _callback: List[Callable] list
@@ -84,7 +179,11 @@
# input=binary_data,
# chunk_size=self._audio_config.chunk_size,
# )
-result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
+sv_result = self._sv_pipeline([binary_data], output_emb=True)
+embs = sv_result['embs'][0]
+result = self._spk_verify.verify(embs)
self._do_callback(result)
def _run(self) -> None:

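A minimal usage sketch of the speaker_verify helper above, together with the data/speakers.json shape that load_spk_data_local() expects (a JSON list of SpeakerCreate-compatible objects). The speaker data and embedding values are made up, and the SPKFunctor import path is an assumption since the file path is not shown in this diff:

import json
import os
import uuid
import numpy

from src.models import SpeakerCreate
from src.functor.spk_functor import SPKFunctor  # module path assumed

# Hypothetical registry in the shape load_spk_data_local() reads.
# Embeddings from this CAM++ pipeline are 192-dimensional; a constant
# vector stands in for a real one purely for illustration.
speakers = [{
    "speaker_id": str(uuid.uuid4()),
    "speaker_name": "Alice",
    "speaker_description": "Demo speaker",
    "speaker_embs": [0.1] * 192,
}]
os.makedirs("data", exist_ok=True)
with open("data/speakers.json", "w") as f:
    json.dump(speakers, f)

# The inner class can be exercised without loading the CAM++ pipeline.
sv = SPKFunctor.speaker_verify()
for spk in speakers:
    sv.add_speaker(SpeakerCreate(**spk))

query = numpy.array([0.1] * 192, dtype=numpy.float32)
best = sv.verify(query)
print(best["speaker_name"], best["score"])  # identical direction scores 1.0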

@@ -1,6 +1,8 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result
from .spk import SpeakerCreate, SpeakerResponse
__all__ = [
"AudioBinary_Config",
"AudioBinary_data_list",

src/models/spk.py Normal file

@@ -0,0 +1,68 @@
"""
src/models/spk.py
------------------------
This module defines the Pydantic models that map to the speakers table, used for API data validation and serialization.
Model overview:
- SpeakerBase: base speaker fields (name, description, optional embedding)
- SpeakerCreate: data validation when creating a speaker; inherits SpeakerBase directly
- SpeakerUpdate: for updating speaker info; all fields optional
- SpeakerResponse: returned to clients; includes database-generated fields (speaker_id, created_at)
"""
from datetime import datetime
from typing import Optional, List
from uuid import UUID
from pydantic import BaseModel, Field
from .base import BaseSchema
# Base model defining a speaker's core attributes
class SpeakerBase(BaseSchema):
speaker_id: UUID = Field(
...,
description="Unique speaker identifier"
)
speaker_name: str = Field(
...,
max_length=255,
description="Speaker name (required, max 255 characters)"
)
speaker_description: Optional[str] = Field(
None,
description="Speaker description (optional)"
)
speaker_embs: Optional[List[float]] = Field(
None,
description="Speaker embedding (optional)"
)
# Data model for creating a speaker; inherits SpeakerBase directly
class SpeakerCreate(SpeakerBase):
pass
# Data model for updating a speaker; all fields are optional
class SpeakerUpdate(BaseModel):
speaker_name: Optional[str] = Field(
None,
max_length=255,
description="Speaker name"
)
speaker_description: Optional[str] = Field(
None,
description="Speaker description"
)
# Response model returned by the API; extends the base model with database-generated fields
class SpeakerResponse(SpeakerBase):
speaker_id: UUID = Field(
...,
description="Unique speaker identifier (UUID)"
)
created_at: datetime = Field(
...,
description="Record creation time"
)
updated_at: datetime = Field(
...,
description="Last update time"
)
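
A quick sketch of these schemas in use; the values are made up, and SpeakerUpdate is imported from src.models.spk directly since src/models/__init__.py only re-exports SpeakerCreate and SpeakerResponse:

from uuid import uuid4
from src.models.spk import SpeakerCreate, SpeakerUpdate

# speaker_description and speaker_embs are optional; speaker_id and
# speaker_name are required by SpeakerBase.
spk = SpeakerCreate(speaker_id=uuid4(), speaker_name="Alice")

# Partial update: only the fields being changed are supplied.
patch = SpeakerUpdate(speaker_description="Updated description")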


@@ -0,0 +1,3 @@
from .asr_endpoint import router as asr_router
__all__ = ["asr_router"]


@@ -0,0 +1,32 @@
"""
-*- coding: utf-8 -*-
This module is the ASR websocket endpoint, implemented with FastAPI's websocket support
"""
from fastapi import WebSocket, APIRouter
router = APIRouter()
from src.runner.ASRRunner import ASRRunner
ASRRunner_instance = ASRRunner()
from src.core import ModelLoader
model_loader = ModelLoader()
args = {
"asr_model": "paraformer-zh",
"asr_model_revision": "v2.0.4",
"vad_model": "fsmn-vad",
"vad_model_revision": "v2.0.4",
"spk_model": "cam++",
"spk_model_revision": "v2.0.2",
"audio_update": False,
}
models = model_loader.load_models(args)
@router.websocket("/asr_full")
async def asr_endpoint(websocket: WebSocket):
await websocket.accept()
while True:
data = await websocket.receive_text()
print(data)
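
As written, the bare receive loop raises WebSocketDisconnect the moment the client hangs up. A sketch of the same endpoint with the disconnect handled; the print is still a placeholder for handing the data to ASRRunner_instance:

from fastapi import WebSocket, WebSocketDisconnect

@router.websocket("/asr_full")
async def asr_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            print(data)  # placeholder: forward to ASRRunner_instance
    except WebSocketDisconnect:
        pass  # client closed the connection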

src/websockets/router.py Normal file

@@ -0,0 +1,3 @@
from .endpoint import asr_router
__all__ = ["asr_router"]


@@ -5,7 +5,7 @@
from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger
-from tests.runner.stt_runner_test import test_asr_runner
+from tests.runner.asr_runner_test import test_asr_runner
setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)


@@ -19,5 +19,7 @@ print(result)
# Passing output_emb=True includes the extracted speaker embeddings in the result
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
print(result['embs'], result['outputs'])
# result1 = sv_pipeline([speaker1_a_wav], output_emb=True)
# print(result1['embs'], result1['outputs'])
# Passing save_dir stores the extracted speaker embeddings in that directory
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')