[Speaker verification] Implement SPKFunctor: extract embeddings with CAM++ and score them with cosine similarity.

Ziyang.Zhang 2025-06-30 14:13:06 +08:00
parent 5dac718dee
commit 4e9dd83d55
9 changed files with 217 additions and 5 deletions
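
The heart of the change: CAM++ maps each utterance to a fixed-dimensional speaker embedding, and two utterances are compared by the cosine similarity of their embeddings. A minimal sketch of that comparison (the function name and shapes are illustrative, not part of the commit):

import numpy
import torch

def cosine_score(emb_a: numpy.ndarray, emb_b: numpy.ndarray) -> float:
    # Both embeddings are 1-D vectors of the same dimension D.
    a = torch.from_numpy(emb_a).unsqueeze(0)  # shape (1, D)
    b = torch.from_numpy(emb_b).unsqueeze(0)
    # Cosine similarity lies in [-1, 1]; a higher score means "more likely the same speaker".
    return torch.nn.functional.cosine_similarity(a, b).item()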

src/core/__init__.py Normal file

@@ -0,0 +1,3 @@
+from .model_loader import ModelLoader
+
+__all__ = ["ModelLoader"]


@@ -4,16 +4,41 @@ SpkFunctor
"""
from src.functor.base import BaseFunctor
-from src.models import AudioBinary_Config, VAD_Functor_result
-from typing import Callable, List
+from src.models import AudioBinary_Config, VAD_Functor_result, SpeakerCreate
+from typing import Callable, List, Dict
from queue import Queue, Empty
+import json
+import torch
import threading
+import numpy
# Logging
from src.utils.logger import get_module_logger
logger = get_module_logger(__name__)
+from modelscope.pipelines import pipeline
+# sv_pipeline = pipeline(
+#     task='speaker-verification',
+#     model='iic/speech_campplus_sv_zh-cn_16k-common',
+#     model_revision='v1.0.0'
+# )
+# speaker1_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_a_cn_16k.wav'
+# speaker1_b_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_b_cn_16k.wav'
+# speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
+# # Same speaker
+# result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
+# print(result)
+# # Different speakers
+# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
+# print(result)
+# # A custom score threshold can be passed via thr; the higher it is, the stricter the same-speaker decision
+# result = sv_pipeline([speaker1_a_wav, speaker1_a_wav], thr=0.6)
+# print(result)
+# # With output_emb=True the result also contains the extracted speaker embeddings
+# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
+# print(result['embs'], result['outputs'])
+# # With save_dir the extracted speaker embeddings are stored in that directory
+# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
class SPKFunctor(BaseFunctor):
    """
@@ -26,15 +51,84 @@ class SPKFunctor(BaseFunctor):
    Call stop() to stop the thread, but it waits until input_queue is empty
    """
+    class speaker_verify:
+        """
+        Speaker verification against registered speakers
+        """
+        def __init__(self, *args, **kwargs):
+            self._spk_data: List[SpeakerCreate] = []
+
+        def add_speaker(self, speaker: SpeakerCreate) -> None:
+            self._spk_data.append(speaker)
+
+        def verify(self, emb: numpy.ndarray) -> Dict:
+            # Convert the incoming numpy embedding to a tensor
+            input_emb_tensor = torch.from_numpy(emb).unsqueeze(0)
+            results = []
+            for speaker in self._spk_data:
+                if not speaker.speaker_embs:
+                    continue
+                # Convert the stored embedding list to a tensor
+                speaker_emb_tensor = torch.tensor(speaker.speaker_embs).unsqueeze(0)
+                # Cosine similarity between the two embeddings
+                score = torch.nn.functional.cosine_similarity(input_emb_tensor, speaker_emb_tensor).item()
+                result = {
+                    "score": score,
+                    "speaker_id": speaker.speaker_id,
+                    "speaker_name": speaker.speaker_name,
+                    "speaker_description": speaker.speaker_description
+                }
+                results.append(result)
+            if not results:
+                return {
+                    "speaker_id": "unknown",
+                    "speaker_name": "Unknown",
+                    "speaker_description": "No registered speakers to verify against.",
+                    "score": 0.0,
+                    "results": []
+                }
+            # Highest score wins
+            results.sort(key=lambda x: x['score'], reverse=True)
+            best_match = results[0]
+            return {
+                "speaker_id": best_match['speaker_id'],
+                "speaker_name": best_match['speaker_name'],
+                "speaker_description": best_match['speaker_description'],
+                "score": best_match['score'],
+                "results": results
+            }
    def __init__(self) -> None:
        super().__init__()
        # Resources and configuration
+        self._spk_verify = self.speaker_verify()
+        self._sv_pipeline = pipeline(
+            task='speaker-verification',
+            model='iic/speech_campplus_sv_zh-cn_16k-common',
+            model_revision='v1.0.0'
+        )
        self._model: dict = {}  # Models
        self._callback: List[Callable] = []  # Callback functions
        self._input_queue: Queue = None  # Input queue
        self._audio_config: AudioBinary_Config = None  # Audio configuration
+
+    def load_spk_data_local(
+        self,
+        spk_data_path: str = 'data/speakers.json',
+    ) -> None:
+        """
+        Load local speaker data
+        """
+        with open(spk_data_path, 'r') as f:
+            spk_data = json.load(f)
+        for spk in spk_data:
+            self._spk_verify.add_speaker(SpeakerCreate(**spk))
    def reset_cache(self) -> None:
        """
        Reset caches after a task completes, ready for the next task
@@ -60,6 +154,7 @@ class SPKFunctor(BaseFunctor):
        self._audio_config = audio_config
        logger.debug("SpkFunctor audio config set: %s", self._audio_config)
+
    def add_callback(self, callback: Callable) -> None:
        """
        Append a callback to this instance's _callback list
@ -84,7 +179,11 @@ class SPKFunctor(BaseFunctor):
# input=binary_data, # input=binary_data,
# chunk_size=self._audio_config.chunk_size, # chunk_size=self._audio_config.chunk_size,
# ) # )
result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}] sv_result = self._sv_pipeline([binary_data], output_emb=True)
embs = sv_result['embs'][0]
result = self._spk_verify.verify(embs)
self._do_callback(result) self._do_callback(result)
def _run(self) -> None: def _run(self) -> None:
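
load_spk_data_local expects spk_data_path to hold a JSON array whose objects match the SpeakerCreate fields. A hypothetical data/speakers.json (all values invented; a real speaker_embs vector has the model's full embedding dimension, truncated here to three numbers):

[
  {
    "speaker_id": "7b6f0a2e-1c3d-4e5f-8a9b-0c1d2e3f4a5b",
    "speaker_name": "Alice",
    "speaker_description": "Demo speaker",
    "speaker_embs": [0.013, -0.047, 0.101]
  }
]

verify() then scores an incoming embedding against every registered speaker and returns the best match, with the full ranked list under "results".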

src/models/__init__.py

@@ -1,6 +1,8 @@
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
from .vad import VAD_Functor_result
+from .spk import SpeakerCreate, SpeakerResponse

__all__ = [
    "AudioBinary_Config",
    "AudioBinary_data_list",

src/models/spk.py Normal file

@@ -0,0 +1,68 @@
+"""
+src/models/spk.py
+------------------------
+This module defines the Pydantic models that mirror the speakers table, used for API data validation and serialization.
+Models:
+- SpeakerBase: base speaker fields, including name and description
+- SpeakerCreate: validation model for creating a speaker; inherits SpeakerBase directly
+- SpeakerUpdate: model for updating a speaker; every field is optional
+- SpeakerResponse: model returned to clients; adds database-generated fields (speaker_id, created_at)
+"""
+from datetime import datetime
+from typing import Optional, List
+from uuid import UUID
+
+from pydantic import BaseModel, Field
+
+from .base import BaseSchema
+
+
+# Base model defining the core speaker attributes
+class SpeakerBase(BaseSchema):
+    speaker_id: UUID = Field(
+        ...,
+        description="Unique speaker identifier"
+    )
+    speaker_name: str = Field(
+        ...,
+        max_length=255,
+        description="Speaker name (required, at most 255 characters)"
+    )
+    speaker_description: Optional[str] = Field(
+        None,
+        description="Optional speaker description"
+    )
+    speaker_embs: Optional[List[float]] = Field(
+        None,
+        description="Optional speaker embedding"
+    )
+
+
+# Model used when creating a speaker; inherits SpeakerBase directly
+class SpeakerCreate(SpeakerBase):
+    pass
+
+
+# Model used when updating a speaker; all fields are optional
+class SpeakerUpdate(BaseModel):
+    speaker_name: Optional[str] = Field(
+        None,
+        max_length=255,
+        description="Speaker name"
+    )
+    speaker_description: Optional[str] = Field(
+        None,
+        description="Speaker description"
+    )
+
+
+# Response model returned by the API; extends the base model with database-generated fields
+class SpeakerResponse(SpeakerBase):
+    speaker_id: UUID = Field(
+        ...,
+        description="Unique speaker identifier (UUID)"
+    )
+    created_at: datetime = Field(
+        ...,
+        description="Record creation time"
+    )
+    updated_at: datetime = Field(
+        ...,
+        description="Last update time"
+    )
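
A quick sketch of how one record passes through SpeakerCreate validation (values invented; a dict loaded from speakers.json goes through the same path via SpeakerCreate(**spk)):

from uuid import uuid4

spk = SpeakerCreate(
    speaker_id=uuid4(),
    speaker_name="Alice",
    speaker_description="Demo speaker",
    speaker_embs=[0.013, -0.047, 0.101],
)
# Pydantic validates types and constraints; a name longer than 255
# characters or a malformed UUID raises a ValidationError.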

src/websockets/endpoint/__init__.py

@@ -0,0 +1,3 @@
+from .asr_endpoint import router as asr_router
+
+__all__ = ["asr_router"]

src/websockets/endpoint/asr_endpoint.py

@@ -0,0 +1,32 @@
+"""
+-*- coding: utf-8 -*-
+This module is the ASR websocket endpoint, built on FastAPI's websocket support
+"""
+from fastapi import WebSocket, APIRouter
+
+router = APIRouter()
+
+from src.runner.ASRRunner import ASRRunner
+ASRRunner_instance = ASRRunner()
+
+from src.core import ModelLoader
+model_loader = ModelLoader()
+
+args = {
+    "asr_model": "paraformer-zh",
+    "asr_model_revision": "v2.0.4",
+    "vad_model": "fsmn-vad",
+    "vad_model_revision": "v2.0.4",
+    "spk_model": "cam++",
+    "spk_model_revision": "v2.0.2",
+    "audio_update": False,
+}
+models = model_loader.load_models(args)
+
+@router.websocket("/asr_full")
+async def asr_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    while True:
+        data = await websocket.receive_text()
+        print(data)
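
The endpoint so far only accepts a connection and prints incoming text; wiring it to ASRRunner_instance is still to come. A minimal client sketch for manual testing (assumes the third-party websockets package and a server on localhost:8000, both assumptions outside this commit):

import asyncio
import websockets  # third-party client library, assumed installed

async def main():
    # /asr_full is the route registered above
    async with websockets.connect("ws://localhost:8000/asr_full") as ws:
        await ws.send("hello")  # the server currently just prints this

asyncio.run(main())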

src/websockets/router.py Normal file

@@ -0,0 +1,3 @@
+from .endpoint import asr_router
+
+__all__ = ["asr_router"]


@@ -5,7 +5,7 @@
from tests.pipeline.asr_test import test_asr_pipeline
from src.utils.logger import get_module_logger, setup_root_logger
-from tests.runner.stt_runner_test import test_asr_runner
+from tests.runner.asr_runner_test import test_asr_runner

setup_root_logger(level="INFO", log_file="logs/test_main.log")
logger = get_module_logger(__name__)


@@ -19,5 +19,7 @@ print(result)
# With output_emb=True the result also contains the extracted speaker embeddings
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
print(result['embs'], result['outputs'])
+# result1 = sv_pipeline([speaker1_a_wav], output_emb=True)
+# print(result1['embs'], result1['outputs'])
# With save_dir the extracted speaker embeddings are stored in that directory
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')