[Speaker verification] Implement SPKFunctor: extract embeddings with CAM++ and score speakers by cosine similarity.
parent 5dac718dee
commit 4e9dd83d55

src/core/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .model_loader import ModelLoader
+
+__all__ = ["ModelLoader"]

@@ -4,16 +4,41 @@ SpkFunctor
 """
 
 from src.functor.base import BaseFunctor
-from src.models import AudioBinary_Config, VAD_Functor_result
-from typing import Callable, List
+from src.models import AudioBinary_Config, VAD_Functor_result, SpeakerCreate
+from typing import Callable, List, Dict
 from queue import Queue, Empty
+import json
+import torch
 import threading
+import numpy
 
 # Logging
 from src.utils.logger import get_module_logger
 
 logger = get_module_logger(__name__)
+from modelscope.pipelines import pipeline
+
+# sv_pipeline = pipeline(
+#     task='speaker-verification',
+#     model='iic/speech_campplus_sv_zh-cn_16k-common',
+#     model_revision='v1.0.0'
+# )
+# speaker1_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_a_cn_16k.wav'
+# speaker1_b_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker1_b_cn_16k.wav'
+# speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
+# # Same-speaker audio
+# result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
+# print(result)
+# # Different-speaker audio
+# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
+# print(result)
+# # A custom score threshold can be set; the higher the threshold, the stricter the same-speaker decision
+# result = sv_pipeline([speaker1_a_wav, speaker1_a_wav], thr=0.6)
+# print(result)
+# # With output_emb=True, the result also contains the extracted speaker embeddings
+# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
+# print(result['embs'], result['outputs'])
+# # With save_dir, the extracted speaker embeddings are stored in that directory
+# result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
 
 class SPKFunctor(BaseFunctor):
     """

@@ -26,15 +51,84 @@ class SPKFunctor(BaseFunctor):
 
     Use stop() to stop the thread; it waits until input_queue is empty
     """
+
+    class speaker_verify:
+        """
+        Speaker verification
+        """
+        def __init__(self, *args, **kwargs):
+            self._spk_data: List[SpeakerCreate] = []
+
+        def add_speaker(self, speaker: SpeakerCreate) -> None:
+            self._spk_data.append(speaker)
+
+        def verify(self, emb: numpy.ndarray) -> Dict:
+            # Convert the incoming numpy embedding to a tensor
+            input_emb_tensor = torch.from_numpy(emb).unsqueeze(0)
+
+            results = []
+            for speaker in self._spk_data:
+                if not speaker.speaker_embs:
+                    continue
+                # Convert the embedding stored as a list to a tensor
+                speaker_emb_tensor = torch.tensor(speaker.speaker_embs).unsqueeze(0)
+
+                # Compute cosine similarity
+                score = torch.nn.functional.cosine_similarity(input_emb_tensor, speaker_emb_tensor).item()
+
+                result = {
+                    "score": score,
+                    "speaker_id": speaker.speaker_id,
+                    "speaker_name": speaker.speaker_name,
+                    "speaker_description": speaker.speaker_description
+                }
+                results.append(result)
+
+            if not results:
+                return {
+                    "speaker_id": "unknown",
+                    "speaker_name": "Unknown",
+                    "speaker_description": "No registered speakers to verify against.",
+                    "score": 0.0,
+                    "results": []
+                }
+
+            results.sort(key=lambda x: x['score'], reverse=True)
+            best_match = results[0]
+            return {
+                "speaker_id": best_match['speaker_id'],
+                "speaker_name": best_match['speaker_name'],
+                "speaker_description": best_match['speaker_description'],
+                "score": best_match['score'],
+                "results": results
+            }
+
     def __init__(self) -> None:
         super().__init__()
         # Resources and configuration
+        self._spk_verify = self.speaker_verify()
+        self._sv_pipeline = pipeline(
+            task='speaker-verification',
+            model='iic/speech_campplus_sv_zh-cn_16k-common',
+            model_revision='v1.0.0'
+        )
         self._model: dict = {}  # models
         self._callback: List[Callable] = []  # callback functions
         self._input_queue: Queue = None  # input queue
         self._audio_config: AudioBinary_Config = None  # audio config
+
+    def load_spk_data_local(
+        self,
+        spk_data_path: str = 'data/speakers.json',
+    ) -> None:
+        """
+        Load local speaker data
+        """
+        with open(spk_data_path, 'r') as f:
+            spk_data = json.load(f)
+        for spk in spk_data:
+            self._spk_verify.add_speaker(SpeakerCreate(**spk))
+
+
     def reset_cache(self) -> None:
         """
         Reset the cache; clears cached data after a task completes, ready for the next task

@@ -60,6 +154,7 @@ class SPKFunctor(BaseFunctor):
         self._audio_config = audio_config
         logger.debug("SpkFunctor audio config set: %s", self._audio_config)
+
 
     def add_callback(self, callback: Callable) -> None:
         """
         Append a callback to this instance's _callback: List[Callable] list

@@ -84,7 +179,11 @@ class SPKFunctor(BaseFunctor):
         #     input=binary_data,
         #     chunk_size=self._audio_config.chunk_size,
         # )
-        result = [{"result": "spk1", "score": {"spk1": 0.9, "spk2": 0.3}}]
+        sv_result = self._sv_pipeline([binary_data], output_emb=True)
+        embs = sv_result['embs'][0]
+
+        result = self._spk_verify.verify(embs)
+
         self._do_callback(result)
 
     def _run(self) -> None:
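
For orientation, a minimal sketch (not part of the commit) of the scoring step that speaker_verify.verify() performs: CAM++ embeddings are compared by cosine similarity, and the best-scoring registered speaker wins. The embedding values below are made-up placeholders; real CAM++ embeddings are much longer.

    import numpy
    import torch

    def cosine_score(emb_a: numpy.ndarray, emb_b: numpy.ndarray) -> float:
        # Same computation as verify(): unsqueeze to shape (1, D), then
        # cosine similarity along the feature dimension.
        a = torch.from_numpy(emb_a).unsqueeze(0)
        b = torch.from_numpy(emb_b).unsqueeze(0)
        return torch.nn.functional.cosine_similarity(a, b).item()

    emb_a = numpy.asarray([0.1, 0.8, -0.3], dtype=numpy.float32)  # placeholder
    emb_b = numpy.asarray([0.2, 0.7, -0.2], dtype=numpy.float32)  # placeholder
    print(cosine_score(emb_a, emb_b))  # near 1.0 for same-speaker embeddings

A threshold (the commented ModelScope example above uses thr=0.6) then turns the best score into an accept/reject decision; the higher the threshold, the stricter the match.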

@@ -1,6 +1,8 @@
 from .audio import AudioBinary_Config, AudioBinary_data_list, _AudioBinary_data
 from .vad import VAD_Functor_result
 
+from .spk import SpeakerCreate, SpeakerResponse
+
 __all__ = [
     "AudioBinary_Config",
     "AudioBinary_data_list",

src/models/spk.py (new file)
@@ -0,0 +1,68 @@
+"""
+src/models/spk.py
+------------------------
+This module defines the Pydantic models that mirror the speakers table, used for API data validation and serialization.
+
+Models:
+- SpeakerBase: base fields for a speaker, including name and description.
+- SpeakerCreate: validation for creating a speaker; inherits SpeakerBase directly.
+- SpeakerUpdate: for updating a speaker; every field is optional.
+- SpeakerResponse: returned to clients; adds database-generated fields (speaker_id, created_at, and so on).
+"""
+
+from datetime import datetime
+from typing import Optional, List
+from uuid import UUID
+from pydantic import BaseModel, Field
+from .base import BaseSchema
+
+
+# Base model defining a speaker's core attributes
+class SpeakerBase(BaseSchema):
+    speaker_id: UUID = Field(
+        ...,
+        description="Unique speaker identifier"
+    )
+    speaker_name: str = Field(
+        ...,
+        max_length=255,
+        description="Speaker name; required, at most 255 characters"
+    )
+    speaker_description: Optional[str] = Field(
+        None,
+        description="Speaker description; optional"
+    )
+    speaker_embs: Optional[List[float]] = Field(
+        None,
+        description="Speaker embedding; optional"
+    )
+
+
+# Model for creating a speaker; inherits SpeakerBase directly
+class SpeakerCreate(SpeakerBase):
+    pass
+
+
+# Model for updating a speaker; all fields are optional
+class SpeakerUpdate(BaseModel):
+    speaker_name: Optional[str] = Field(
+        None,
+        max_length=255,
+        description="Speaker name"
+    )
+    speaker_description: Optional[str] = Field(
+        None,
+        description="Speaker description"
+    )
+
+
+# Response model for the API; extends the base model with database-generated fields
+class SpeakerResponse(SpeakerBase):
+    speaker_id: UUID = Field(
+        ...,
+        description="Unique speaker identifier (UUID)"
+    )
+    created_at: datetime = Field(
+        ...,
+        description="Record creation time"
+    )
+    updated_at: datetime = Field(
+        ...,
+        description="Last update time"
+    )
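
load_spk_data_local() above validates each entry of data/speakers.json as a SpeakerCreate. A hedged sketch of building one such record in code; the UUID, name, and three-float embedding are hypothetical placeholders.

    from uuid import uuid4
    from src.models import SpeakerCreate

    spk = SpeakerCreate(
        speaker_id=uuid4(),                # hypothetical UUID
        speaker_name="Alice",              # hypothetical name
        speaker_description="Demo speaker",
        speaker_embs=[0.12, -0.03, 0.88],  # placeholder; real CAM++ embeddings are longer
    )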

src/websockets/endpoint/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .asr_endpoint import router as asr_router
+
+__all__ = ["asr_router"]

src/websockets/endpoint/asr_endpoint.py (new file)
@@ -0,0 +1,32 @@
+"""
+-*- coding: utf-8 -*-
+This module is the websocket endpoint for ASR, built on FastAPI's websocket support
+"""
+
+from fastapi import WebSocket, APIRouter
+
+router = APIRouter()
+
+from src.runner.ASRRunner import ASRRunner
+
+ASRRunner_instance = ASRRunner()
+
+from src.core import ModelLoader
+model_loader = ModelLoader()
+args = {
+    "asr_model": "paraformer-zh",
+    "asr_model_revision": "v2.0.4",
+    "vad_model": "fsmn-vad",
+    "vad_model_revision": "v2.0.4",
+    "spk_model": "cam++",
+    "spk_model_revision": "v2.0.2",
+    "audio_update": False,
+}
+models = model_loader.load_models(args)
+
+
+@router.websocket("/asr_full")
+async def asr_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    while True:
+        data = await websocket.receive_text()
+        print(data)
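
For reference, a minimal client sketch (not part of the commit) for the /asr_full endpoint, assuming the FastAPI app listens on localhost:8000 and the third-party websockets package is available; the endpoint currently just prints whatever text it receives.

    import asyncio
    import websockets  # assumption: pip install websockets

    async def main():
        async with websockets.connect("ws://localhost:8000/asr_full") as ws:
            await ws.send("hello")  # the server side prints the received text

    asyncio.run(main())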

src/websockets/router.py (new file)
@@ -0,0 +1,3 @@
+from .endpoint import asr_router
+
+__all__ = ["asr_router"]

@@ -5,7 +5,7 @@
 
 from tests.pipeline.asr_test import test_asr_pipeline
 from src.utils.logger import get_module_logger, setup_root_logger
-from tests.runner.stt_runner_test import test_asr_runner
+from tests.runner.asr_runner_test import test_asr_runner
 
 setup_root_logger(level="INFO", log_file="logs/test_main.log")
 logger = get_module_logger(__name__)

@@ -19,5 +19,7 @@ print(result)
 # With output_emb=True, the result also contains the extracted speaker embeddings
 result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
 print(result['embs'], result['outputs'])
+# result1 = sv_pipeline([speaker1_a_wav], output_emb=True)
+# print(result1['embs'], result1['outputs'])
 # With save_dir, the extracted speaker embeddings are stored in that directory
 result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir='savePath/')
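
Since output_emb=True returns numpy embeddings, a short hedged sketch of converting one into the List[float] that SpeakerCreate.speaker_embs stores; the array below stands in for result['embs'][0].

    import numpy

    emb = numpy.asarray([0.12, -0.03, 0.88], dtype=numpy.float32)  # stand-in for result['embs'][0]
    speaker_embs = emb.tolist()  # plain List[float], as SpeakerCreate expects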