[说话人认证]编写SPKFunctor,使用cam++提取embs,用余弦相似度判断得分。
This commit is contained in:
parent
5dac718dee
commit
4e9dd83d55
3
src/core/__init__.py
Normal file
3
src/core/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .model_loader import ModelLoader
|
||||
|
||||
__all__ = ["ModelLoader"]
|
@ -4,16 +4,41 @@ SpkFunctor
|
||||
"""
|
||||
|
||||
from src.functor.base import BaseFunctor
|
||||
from src.models import AudioBinary_Config, VAD_Functor_result
|
||||
from typing import Callable, List
|
||||
from src.models import AudioBinary_Config, VAD_Functor_result, SpeakerCreate
|
||||
from typing import Callable, List, Dict
|
||||
from queue import Queue, Empty
|
||||
import json
|
||||
import torch
|
||||
import threading
|
||||
import numpy
|
||||
|
||||
# 日志
|
||||
from src.utils.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
# sv_pipeline = pipeline(
|
||||
# task='speaker-verification',
|
||||
# model='iic/speech_campplus_sv_zh-cn_16k-common',
|
||||
# model_revision='v1.0.0'
|
||||
# )
|
||||
# speaker1_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_a_cn_16k.wav'
|
||||
# speaker1_b_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_b_cn_16k.wav'
|
||||
# speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
|
||||
# # 相同说话人语音
|
||||
# result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
|
||||
# print(result)
|
||||
# # 不同说话人语音
|
||||
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
|
||||
# print(result)
|
||||
# # 可以自定义得分阈值来进行识别,阈值越高,判定为同一人的条件越严格
|
||||
# result = sv_pipeline([speaker1_a_wav, speaker1_a_wav], thr=0.6)
|
||||
# print(result)
|
||||
# # 可以传入output_emb参数,输出结果中就会包含提取到的说话人embedding
|
||||
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
|
||||
# print(result['embs'], result['outputs'])
|
||||
# # 可以传入save_dir参数,提取到的说话人embedding会存储在save_dir目录中
|
||||
# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
|
||||
|
||||
class SPKFunctor(BaseFunctor):
|
||||
"""
|
||||
@ -26,15 +51,84 @@ class SPKFunctor(BaseFunctor):
|
||||
|
||||
使用stop()停止线程, 但需要等待input_queue为空
|
||||
"""
|
||||
class speaker_verify:
|
||||
"""
|
||||
说话人验证
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._spk_data: List[SpeakerCreate] = []
|
||||
|
||||
def add_speaker(self, speaker: SpeakerCreate) -> None:
|
||||
self._spk_data.append(speaker)
|
||||
|
||||
def verify(self, emb: numpy.ndarray) -> Dict:
|
||||
# 将输入的numpy embedding转换为tensor
|
||||
input_emb_tensor = torch.from_numpy(emb).unsqueeze(0)
|
||||
|
||||
results = []
|
||||
for speaker in self._spk_data:
|
||||
if not speaker.speaker_embs:
|
||||
continue
|
||||
# 将列表中存储的embedding转换为tensor
|
||||
speaker_emb_tensor = torch.tensor(speaker.speaker_embs).unsqueeze(0)
|
||||
|
||||
# 计算余弦相似度
|
||||
score = torch.nn.functional.cosine_similarity(input_emb_tensor, speaker_emb_tensor).item()
|
||||
|
||||
result = {
|
||||
"score": score,
|
||||
"speaker_id": speaker.speaker_id,
|
||||
"speaker_name": speaker.speaker_name,
|
||||
"speaker_description": speaker.speaker_description
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
if not results:
|
||||
return {
|
||||
"speaker_id": "unknown",
|
||||
"speaker_name": "Unknown",
|
||||
"speaker_description": "No registered speakers to verify against.",
|
||||
"score": 0.0,
|
||||
"results": []
|
||||
}
|
||||
|
||||
results.sort(key=lambda x: x['score'], reverse=True)
|
||||
best_match = results[0]
|
||||
return {
|
||||
"speaker_id": best_match['speaker_id'],
|
||||
"speaker_name": best_match['speaker_name'],
|
||||
"speaker_description": best_match['speaker_description'],
|
||||
"score": best_match['score'],
|
||||
"results": results
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# 资源与配置
|
||||
self._spk_verify = self.speaker_verify()
|
||||
self._sv_pipeline = pipeline(
|
||||
task='speaker-verification',
|
||||
model='iic/speech_campplus_sv_zh-cn_16k-common',
|
||||
model_revision='v1.0.0'
|
||||
)
|
||||
self._model: dict = {} # 模型
|
||||
self._callback: List[Callable] = [] # 回调函数
|
||||
self._input_queue: Queue = None # 输入队列
|
||||
self._audio_config: AudioBinary_Config = None # 音频配置
|
||||
|
||||
def load_spk_data_local(
|
||||
self,
|
||||
spk_data_path: str = 'data/speakers.json',
|
||||
) -> None:
|
||||
"""
|
||||
加载本地说话人数据
|
||||
"""
|
||||
with open(spk_data_path, 'r') as f:
|
||||
spk_data = json.load(f)
|
||||
for spk in spk_data:
|
||||
self._spk_verify.add_speaker(SpeakerCreate(**spk))
|
||||
|
||||
|
||||
def reset_cache(self) -> None:
|
||||
"""
|
||||
重置缓存, 用于任务完成后清理缓存数据, 准备下次任务
|
||||
@ -60,6 +154,7 @@ class SPKFunctor(BaseFunctor):
|
||||
self._audio_config = audio_config
|
||||
logger.debug("SpkFunctor设置音频配置: %s", self._audio_config)
|
||||
|
||||
|
||||
def add_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
向自身的_callback: List[Callable]回调函数列表中添加回调函数
|
||||
@ -84,7 +179,11 @@ class SPKFunctor(BaseFunctor):
|
||||
# input=binary_data,
|
||||
# chunk_size=self._audio_config.chunk_size,
|
||||
# )
|
||||
result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
|
||||
sv_result = self._sv_pipeline([binary_data], output_emb=True)
|
||||
embs = sv_result['embs'][0]
|
||||
|
||||
result = self._spk_verify.verify(embs)
|
||||
|
||||
self._do_callback(result)
|
||||
|
||||
def _run(self) -> None:
|
||||
|
@ -1,6 +1,8 @@
|
||||
from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
|
||||
from .vad import VAD_Functor_result
|
||||
|
||||
from .spk import SpeakerCreate, SpeakerResponse
|
||||
|
||||
__all__ = [
|
||||
"AudioBinary_Config",
|
||||
"AudioBinary_data_list",
|
||||
|
68
src/models/spk.py
Normal file
68
src/models/spk.py
Normal file
@ -0,0 +1,68 @@
|
||||
"""
|
||||
src/schemas/speaker.py
|
||||
------------------------
|
||||
此模块定义与说话人(speakers)表对应的 Pydantic 模型,用于 API 数据验证和序列化。
|
||||
|
||||
模型说明:
|
||||
- SpeakerBase:定义说话人的基础字段,包括姓名与描述。
|
||||
- SpeakerCreate:用于创建说话人时的数据验证,直接继承 SpeakerBase。
|
||||
- SpeakerUpdate:用于更新说话人信息时,所有字段均为可选。
|
||||
- SpeakerResponse:返回给客户端时使用,包含数据库生成的字段(如 speaker_id、created_at 等)。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field
|
||||
from .base import BaseSchema
|
||||
|
||||
# 基础模型,定义说话人的核心属性
|
||||
class SpeakerBase(BaseSchema):
|
||||
speaker_id: UUID = Field(
|
||||
...,
|
||||
description="说话人唯一标识符"
|
||||
)
|
||||
speaker_name: str = Field(
|
||||
...,
|
||||
max_length=255,
|
||||
description="说话人姓名,必填,最多255字符"
|
||||
)
|
||||
speaker_description: Optional[str] = Field(
|
||||
None,
|
||||
description="说话人描述信息,可选"
|
||||
)
|
||||
speaker_embs: Optional[List[float]] = Field(
|
||||
None,
|
||||
description="说话人embedding,可选"
|
||||
)
|
||||
|
||||
# 用于创建说话人时的数据模型,直接继承 SpeakerBase
|
||||
class SpeakerCreate(SpeakerBase):
|
||||
pass
|
||||
|
||||
# 用于更新说话人时的数据模型,所有字段都是可选的
|
||||
class SpeakerUpdate(BaseModel):
|
||||
speaker_name: Optional[str] = Field(
|
||||
None,
|
||||
max_length=255,
|
||||
description="说话人姓名"
|
||||
)
|
||||
speaker_description: Optional[str] = Field(
|
||||
None,
|
||||
description="说话人描述信息"
|
||||
)
|
||||
|
||||
# 用于 API 返回的说话人响应模型,扩展基础模型以包含数据库生成的字段
|
||||
class SpeakerResponse(SpeakerBase):
|
||||
speaker_id: UUID = Field(
|
||||
...,
|
||||
description="说话人的唯一标识符(UUID)"
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
...,
|
||||
description="记录创建时间"
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
...,
|
||||
description="最近更新时间"
|
||||
)
|
3
src/websockets/endpoint/__init__.py
Normal file
3
src/websockets/endpoint/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .asr_endpoint import router as asr_router
|
||||
|
||||
__all__ = ["asr_router"]
|
32
src/websockets/endpoint/asr_endpoint.py
Normal file
32
src/websockets/endpoint/asr_endpoint.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""
|
||||
-*- coding: utf-8 -*-
|
||||
此模块是ASR的websocket端点, 使用FastAPI的websocket端点
|
||||
"""
|
||||
|
||||
from fastapi import WebSocket, APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
from src.runner.ASRRunner import ASRRunner
|
||||
|
||||
ASRRunner_instance = ASRRunner()
|
||||
|
||||
from src.core import ModelLoader
|
||||
model_loader = ModelLoader()
|
||||
args = {
|
||||
"asr_model": "paraformer-zh",
|
||||
"asr_model_revision": "v2.0.4",
|
||||
"vad_model": "fsmn-vad",
|
||||
"vad_model_revision": "v2.0.4",
|
||||
"spk_model": "cam++",
|
||||
"spk_model_revision": "v2.0.2",
|
||||
"audio_update": False,
|
||||
}
|
||||
models = model_loader.load_models(args)
|
||||
|
||||
@router.websocket("/asr_full")
|
||||
async def asr_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
print(data)
|
3
src/websockets/router.py
Normal file
3
src/websockets/router.py
Normal file
@ -0,0 +1,3 @@
|
||||
from endpoint import asr_router
|
||||
|
||||
__all__ = ["asr_router"]
|
@ -5,7 +5,7 @@
|
||||
|
||||
from tests.pipeline.asr_test import test_asr_pipeline
|
||||
from src.utils.logger import get_module_logger, setup_root_logger
|
||||
from tests.runner.stt_runner_test import test_asr_runner
|
||||
from tests.runner.asr_runner_test import test_asr_runner
|
||||
|
||||
setup_root_logger(level="INFO", log_file="logs/test_main.log")
|
||||
logger = get_module_logger(__name__)
|
||||
|
@ -19,5 +19,7 @@ print(result)
|
||||
# 可以传入output_emb参数,输出结果中就会包含提取到的说话人embedding
|
||||
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
|
||||
print(result['embs'], result['outputs'])
|
||||
# result1 = sv_pipeline([speaker1_a_wav], output_emb=True)
|
||||
# print(result1['embs'], result1['outputs'])
|
||||
# 可以传入save_dir参数,提取到的说话人embedding会存储在save_dir目录中
|
||||
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
|
Loading…
x
Reference in New Issue
Block a user