120 lines
4.2 KiB
Python
120 lines
4.2 KiB
Python
|
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||
|
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||
|
|
||
|
""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||
|
|
||
|
|
||
|
class TAP(nn.Module):
    """Temporal average pooling.

    Reduces the last axis of the input to its first-order statistic (the
    mean) and collapses every remaining non-batch axis into one.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted and ignored; the layer has no
        # parameters or configuration of its own.
        super().__init__()

    def forward(self, x):
        """Return the mean of ``x`` over its last axis, flattened from axis 1.

        Args:
            x: input tensor; the last axis is the one being averaged.

        Returns:
            Tensor whose axis 0 is unchanged and whose remaining axes (after
            the reduction) are merged into a single axis.
        """
        averaged = torch.mean(x, dim=-1)
        # Merge all non-batch axes so inputs with extra leading feature axes
        # (e.g. 2D feature maps) yield one vector per batch item.
        return torch.flatten(averaged, start_dim=1)
|
||
|
|
||
|
|
||
|
class TSDP(nn.Module):
    """Temporal standard-deviation pooling.

    Only the second-order statistic (standard deviation) over the last axis
    is produced; the mean is not part of the output.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted and ignored; the layer is stateless.
        super().__init__()

    def forward(self, x):
        """Return the standard deviation of ``x`` over its last axis.

        Args:
            x: input tensor; the last axis is treated as the temporal axis.

        Returns:
            Tensor whose axis 0 is unchanged and whose remaining axes (after
            the reduction) are merged into a single axis.
        """
        variance = x.var(dim=-1)
        # The small constant keeps the square root away from exactly zero.
        deviation = (variance + 1e-8).sqrt()
        return torch.flatten(deviation, start_dim=1)
|
||
|
|
||
|
|
||
|
class TSTP(nn.Module):
    """Temporal statistics pooling, as used in x-vector.

    The mean and the standard deviation over the temporal (last) axis are
    computed on the non-padded frames only and concatenated.
    Comment: simple concatenation cannot make full use of both statistics.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted and ignored; the layer is stateless.
        super(TSTP, self).__init__()

    def forward(self, x, olens):
        """Pool ``x`` over its last axis, ignoring padded frames.

        Args:
            x: input tensor of rank >= 3 with batch on axis 0 and time on the
                last axis, e.g. (B, F, T) or (B, C, F, T).
            olens: per-item valid lengths, forwarded to ``make_pad_mask``.

        Returns:
            Tensor of shape (B, 2 * D) where D is the product of the
            non-batch, non-time axes of ``x``: flattened means followed by
            flattened standard deviations.
        """
        # The last dimension is the temporal axis.
        # NOTE(review): make_pad_mask is assumed to return a (B, T) boolean
        # mask that is True on padded frames -- confirm against funasr.
        valid = ~make_pad_mask(olens, maxlen=x.shape[-1])
        # Insert one singleton axis per feature axis of x so the mask
        # broadcasts correctly. For 4-D input this is exactly
        # ``[:, None, None, :]``; a fixed pair of axes would silently
        # mis-broadcast a 3-D (B, F, T) input to (B, B, F, T).
        for _ in range(x.dim() - 2):
            valid = valid.unsqueeze(1)
        valid = valid.to(x.device)

        x_masked = x * valid
        frame_sum = torch.sum(x_masked, dim=-1)
        # Clamp so an item with zero valid frames yields finite statistics
        # instead of a 0/0 NaN that would poison the whole batch downstream.
        frame_count = torch.sum(valid, dim=-1).clamp(min=1)
        mean_without_padding = frame_sum / frame_count

        # Re-apply the mask after centering: padded positions would otherwise
        # contribute (0 - mean) ** 2 to the variance.
        centered_sq = (x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * valid
        var_without_padding = centered_sq.sum(-1) / frame_count

        pooling_mean = mean_without_padding.flatten(start_dim=1)
        # The small constant keeps the square root away from exactly zero.
        pooling_std = torch.sqrt(var_without_padding + 1e-8).flatten(start_dim=1)

        return torch.cat((pooling_mean, pooling_std), 1)
|
||
|
|
||
|
|
||
|
class ASTP(nn.Module):
    """Attentive statistics pooling.

    Channel- and context-dependent statistics pooling, first used in
    ECAPA_TDNN. Per-channel attention weights over time are used to compute
    a weighted mean and a weighted standard deviation.
    """

    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # With global context, the input is concatenated with its mean and
        # std along the channel axis, tripling the channel count.
        attention_in_dim = in_dim * 3 if global_context_att else in_dim

        # Conv1d with kernel_size 1 acts as a per-frame Linear layer while
        # avoiding the need to transpose the (B, F, T) input.
        # linear1 equals W and b in the paper.
        self.linear1 = nn.Conv1d(attention_in_dim, bottleneck_dim, kernel_size=1)
        # linear2 equals V and k in the paper.
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)

    def forward(self, x):
        """Pool ``x`` over time with learned attention weights.

        Args:
            x: a 3-dimensional tensor in tdnn-based architectures (B, F, T)
                or a 4-dimensional tensor in resnet architectures
                (B, C, F, T). Axis 0 is the batch axis and the last axis is
                the time (frame) axis.

        Returns:
            Tensor of shape (B, 2 * F'): attention-weighted means followed by
            attention-weighted standard deviations, where F' is the channel
            count after merging C and F.
        """
        if x.dim() == 4:
            batch, channels, freq, frames = x.shape
            # Merge channel and frequency axes into a single feature axis.
            x = x.reshape(batch, channels * freq, frames)
        assert x.dim() == 3

        attention_input = x
        if self.global_context_att:
            global_mean = x.mean(dim=-1, keepdim=True).expand_as(x)
            global_std = (x.var(dim=-1, keepdim=True) + 1e-10).sqrt().expand_as(x)
            attention_input = torch.cat((x, global_mean, global_std), dim=1)

        # DON'T use ReLU here! ReLU may be hard to converge, so tanh is used
        # as the bottleneck non-linearity.
        hidden = torch.tanh(self.linear1(attention_input))
        # Normalise the scores over the time axis for every channel.
        weights = torch.softmax(self.linear2(hidden), dim=2)

        weighted_mean = (weights * x).sum(dim=2)
        weighted_var = (weights * (x**2)).sum(dim=2) - weighted_mean**2
        # Clamp before the square root: rounding can push the variance
        # slightly below zero.
        weighted_std = torch.sqrt(torch.clamp(weighted_var, min=1e-10))
        return torch.cat([weighted_mean, weighted_std], dim=1)
|