33 lines
725 B
Python
33 lines
725 B
Python
|
import torch
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
|
||
|
class DotScorer(torch.nn.Module):
|
||
|
def __init__(self):
|
||
|
super().__init__()
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
xs_pad: torch.Tensor,
|
||
|
spk_emb: torch.Tensor,
|
||
|
):
|
||
|
# xs_pad: B, T, D
|
||
|
# spk_emb: B, N, D
|
||
|
scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
|
||
|
return scores
|
||
|
|
||
|
|
||
|
class CosScorer(torch.nn.Module):
|
||
|
def __init__(self):
|
||
|
super().__init__()
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
xs_pad: torch.Tensor,
|
||
|
spk_emb: torch.Tensor,
|
||
|
):
|
||
|
# xs_pad: B, T, D
|
||
|
# spk_emb: B, N, D
|
||
|
scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
|
||
|
return scores
|