"""MLP with convolutional gating (cgMLP) definition.
|
|
|
|
References:
|
|
https://openreview.net/forum?id=RA-zVvZLYIy
|
|
https://arxiv.org/abs/2105.08050
|
|
|
|
"""
|
|
|
|
import torch
|
|
|
|
from funasr.models.transformer.utils.nets_utils import get_activation
|
|
from funasr.models.transformer.layer_norm import LayerNorm
|
|
|
|
|
|
class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU)."""

    def __init__(
        self,
        size: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        n_channels = size // 2  # split input channels
        self.norm = LayerNorm(n_channels)
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            kernel_size,
            1,
            (kernel_size - 1) // 2,
            groups=n_channels,
        )
        if use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None

        if gate_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = get_activation(gate_activation)

        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        torch.nn.init.normal_(self.conv.weight, std=1e-6)
        torch.nn.init.ones_(self.conv.bias)
        if self.linear is not None:
            torch.nn.init.normal_(self.linear.weight, std=1e-6)
            torch.nn.init.ones_(self.linear.bias)

    def forward(self, x, gate_add=None):
        """Forward method.

        Args:
            x (torch.Tensor): (N, T, D)
            gate_add (torch.Tensor): (N, T, D/2)

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """

        x_r, x_g = x.chunk(2, dim=-1)

        x_g = self.norm(x_g)  # (N, T, D/2)
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        if gate_add is not None:
            x_g = x_g + gate_add

        x_g = self.act(x_g)
        out = x_r * x_g  # (N, T, D/2)
        out = self.dropout(out)
        return out
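
# Illustrative usage sketch for the CSGU above (the hyperparameter values here are
# arbitrary examples, not defaults taken from this repository). The unit halves the
# channel dimension: the first half of the input acts as the residual stream, while
# the second half is layer-normalized, passed through a depthwise Conv1d over time,
# and used as the multiplicative gate.
#
#     csgu = ConvolutionalSpatialGatingUnit(
#         size=1024,                     # D; split into two halves of D/2 = 512
#         kernel_size=31,
#         dropout_rate=0.1,
#         use_linear_after_conv=False,
#         gate_activation="identity",
#     )
#     x = torch.randn(2, 50, 1024)       # (N, T, D)
#     y = csgu(x)                        # (N, T, D/2) -> torch.Size([2, 50, 512])
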
class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP)."""

    def __init__(
        self,
        size: int,
        linear_units: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(size, linear_units), torch.nn.GELU()
        )
        self.csgu = ConvolutionalSpatialGatingUnit(
            size=linear_units,
            kernel_size=kernel_size,
            dropout_rate=dropout_rate,
            use_linear_after_conv=use_linear_after_conv,
            gate_activation=gate_activation,
        )
        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)

    def forward(self, x, mask):
        if isinstance(x, tuple):
            xs_pad, pos_emb = x
        else:
            xs_pad, pos_emb = x, None

        xs_pad = self.channel_proj1(xs_pad)  # size -> linear_units
        xs_pad = self.csgu(xs_pad)  # linear_units -> linear_units/2
        xs_pad = self.channel_proj2(xs_pad)  # linear_units/2 -> size

        if pos_emb is not None:
            out = (xs_pad, pos_emb)
        else:
            out = xs_pad
        return out
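

# Minimal smoke-test sketch (illustrative only; the sizes below are arbitrary example
# values, not configuration defaults from FunASR). A cgMLP block maps
# (N, T, size) -> (N, T, size) via size -> linear_units -> linear_units/2 -> size,
# and passes a positional-embedding tuple through unchanged when one is given.
if __name__ == "__main__":
    cgmlp = ConvolutionalGatingMLP(
        size=256,
        linear_units=1024,
        kernel_size=31,
        dropout_rate=0.1,
        use_linear_after_conv=False,
        gate_activation="identity",
    )
    xs = torch.randn(2, 50, 256)  # (N, T, size)
    ys = cgmlp(xs, mask=None)  # mask is accepted for interface compatibility but unused
    print(ys.shape)  # torch.Size([2, 50, 256])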