import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from rotary_embedding_torch import RotaryEmbedding
except ImportError:
    print(
        "If you want to use MossFormer, please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch"
    )

from funasr.models.transformer.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.models.transformer.embedding import ScaledSinuEmbedding
from funasr.models.transformer.mossformer import FLASH_ShareA_FFConvM


def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type."""
    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
    if norm == "cln":
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == "ln":
        return nn.GroupNorm(1, dim, eps=1e-8)
    else:
        return nn.BatchNorm1d(dim)


class MossformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        causal=False,
        attn_dropout=0.1,
        norm_type="scalenorm",
        shift_tokens=True,
    ):
        super().__init__()
        assert norm_type in (
            "scalenorm",
            "layernorm",
        ), "norm_type must be one of scalenorm or layernorm"

        if norm_type == "scalenorm":
            norm_klass = ScaleNorm
        elif norm_type == "layernorm":
            norm_klass = nn.LayerNorm

        self.group_size = group_size

        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
        # max rotary embedding dimension of 32, partial rotary embeddings, from Wang et al. (GPT-J)
        self.layers = nn.ModuleList(
            [
                FLASH_ShareA_FFConvM(
                    dim=dim,
                    group_size=group_size,
                    query_key_dim=query_key_dim,
                    expansion_factor=expansion_factor,
                    causal=causal,
                    dropout=attn_dropout,
                    rotary_pos_emb=rotary_pos_emb,
                    norm_klass=norm_klass,
                    shift_tokens=shift_tokens,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x, *, mask=None):
        for flash in self.layers:
            x = flash(x, mask=mask)
        return x
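
# A minimal usage sketch (not part of the original source; it assumes that
# rotary_embedding_torch and FLASH_ShareA_FFConvM import successfully).
# MossformerBlock keeps the [batch, time, dim] layout of its input:
#
#   block = MossformerBlock(dim=512, depth=4)
#   x = torch.randn(2, 1024, 512)   # [B, T, dim]
#   y = block(x)                    # [B, T, dim], same shape as x
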
class MossFormer_MaskNet(nn.Module):
    """The MossFormer module for computing output masks.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that are input to the intra and inter blocks.
    num_blocks : int
        Number of layers of the Dual Computation Block.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around the intra blocks.
    use_global_pos_enc : bool
        Whether to use global positional encodings.
    max_length : int
        Maximum sequence length.

    Example
    ---------
    >>> mossformer_masknet = MossFormer_MaskNet(64, 64, num_blocks=8, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_blocks=24,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super(MossFormer_MaskNet, self).__init__()
        self.num_spks = num_spks
        self.num_blocks = num_blocks
        self.norm = select_norm(norm, in_channels, 3)
        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc

        if self.use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)

        self.mdl = Computation_Block(
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )

        self.conv1d_out = nn.Conv1d(out_channels, out_channels * num_spks, kernel_size=1)
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer
        self.output = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
        self.output_gate = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Returns
        -------
        out : torch.Tensor
            Output tensor of dimension [spks, B, N, S],
            where spks = number of speakers,
            B = batch size,
            N = number of filters,
            S = number of time frames.
        """
        # The comment before each line gives the tensor shape after executing that line.

        # [B, N, L]
        x = self.norm(x)

        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            base = x
            x = x.transpose(1, -1)
            emb = self.pos_enc(x)
            emb = emb.transpose(0, -1)
            x = base + emb

        # [B, N, S]
        x = self.mdl(x)
        x = self.prelu(x)

        # [B, N*spks, S]
        x = self.conv1d_out(x)
        B, _, S = x.shape

        # [B*spks, N, S]
        x = x.view(B * self.num_spks, -1, S)

        # [B*spks, N, S]
        x = self.output(x) * self.output_gate(x)

        # [B*spks, N, S]
        x = self.conv1_decoder(x)

        # [B, spks, N, S]
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)

        # [spks, B, N, S]
        x = x.transpose(0, 1)

        return x
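
# A minimal usage sketch for the mask network (not part of the original source;
# the tensor sizes below are hypothetical and assume a 2-speaker setup):
#
#   masknet = MossFormer_MaskNet(in_channels=512, out_channels=512, num_blocks=24, num_spks=2)
#   mix_w = torch.randn(1, 512, 2000)       # encoder output, [B, N, S]
#   est_mask = masknet(mix_w)               # [spks, B, N, S] = [2, 1, 512, 2000]
#   sep_w = mix_w.unsqueeze(0) * est_mask   # the usual masking step in a separation pipeline
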
class MossFormerEncoder(nn.Module):
    """Convolutional Encoder Layer.

    Arguments
    ---------
    kernel_size : int
        Length of the filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    -------
    >>> x = torch.randn(2, 1000)
    >>> encoder = MossFormerEncoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    """

    def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
        super(MossFormerEncoder, self).__init__()
        self.conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            groups=1,
            bias=False,
        )
        self.in_channels = in_channels

    def forward(self, x):
        """Return the encoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, L].

        Return
        ------
        x : torch.Tensor
            Encoded tensor with dimensionality [B, N, T_out],
            where B = batch size,
            L = number of timepoints,
            N = number of filters,
            T_out = number of timepoints at the output of the encoder.
        """
        # B x L -> B x 1 x L
        if self.in_channels == 1:
            x = torch.unsqueeze(x, dim=1)
        # B x 1 x L -> B x N x T_out
        x = self.conv1d(x)
        x = F.relu(x)

        return x
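
# Note (not in the original source): with kernel_size=K and stride=K // 2, the
# number of output frames is T_out = floor((L - K) / (K // 2)) + 1. For the
# docstring example (L=1000, K=4, stride=2) this gives floor(996 / 2) + 1 = 499,
# matching the reported shape torch.Size([2, 64, 499]).
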
class MossFormerM(nn.Module):
    """This class implements the transformer encoder.

    Arguments
    ---------
    num_blocks : int
        Number of mossformer blocks to include.
    d_model : int
        The dimension of the input embedding.
    attn_dropout : float
        Dropout for the self-attention (optional).
    group_size : int
        The chunk size.
    query_key_dim : int
        The attention vector dimension.
    expansion_factor : float
        The expansion factor for the linear projection in the conv module.
    causal : bool
        True for causal / false for non-causal.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = MossFormerM(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        attn_dropout=0.1,
    ):
        super().__init__()

        self.mossformerM = MossformerBlock(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        src,
    ):
        """
        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder layer (required),
            of shape [B, L, N],
            where B = batch size,
            L = number of time points,
            N = number of filters.
        """
        output = self.mossformerM(src)
        output = self.norm(output)

        return output
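
# Note (not in the original source): MossFormerM is a thin wrapper that stacks
# MossformerBlock(dim=d_model, depth=num_blocks) followed by a final nn.LayerNorm.
# Although d_model defaults to None, a concrete value must be supplied in
# practice, since nn.LayerNorm(None) fails at construction time.
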
class Computation_Block(nn.Module):
    """Computation block for dual-path processing.

    Arguments
    ---------
    num_blocks : int
        Number of mossformer blocks in the intra model.
    out_channels : int
        Dimensionality of the inter/intra model.
    norm : str
        Normalization type.
    skip_around_intra : bool
        Skip connection around the intra layer.

    Example
    ---------
    >>> comp_block = Computation_Block(num_blocks=8, out_channels=64)
    >>> x = torch.randn(10, 64, 100)
    >>> x = comp_block(x)
    >>> x.shape
    torch.Size([10, 64, 100])
    """

    def __init__(
        self,
        num_blocks,
        out_channels,
        norm="ln",
        skip_around_intra=True,
    ):
        super(Computation_Block, self).__init__()

        # MossFormer2M: MossFormer with recurrence
        # self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
        # MossFormerM: the original MossFormer
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra

        # Norm
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 3)

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Return
        ---------
        out : torch.Tensor
            Output tensor of dimension [B, N, S],
            where B = batch size,
            N = number of filters,
            S = sequence time index.
        """
        B, N, S = x.shape

        # intra (within-sequence) processing
        # [B, S, N]
        intra = x.permute(0, 2, 1).contiguous()

        intra = self.intra_mdl(intra)

        # [B, N, S]
        intra = intra.permute(0, 2, 1).contiguous()
        if self.norm is not None:
            intra = self.intra_norm(intra)

        # [B, N, S]
        if self.skip_around_intra:
            intra = intra + x

        out = intra
        return out
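

# A minimal end-to-end sketch (not part of the original source). It wires the
# encoder and the mask network together with hypothetical hyper-parameters and
# assumes the optional rotary_embedding_torch dependency is installed; a full
# separation model would additionally apply the masks and run a decoder.
if __name__ == "__main__":
    encoder = MossFormerEncoder(kernel_size=16, out_channels=512, in_channels=1)
    masknet = MossFormer_MaskNet(in_channels=512, out_channels=512, num_blocks=4, num_spks=2)

    wav = torch.randn(1, 16000)   # [B, L] mono waveform
    mix_w = encoder(wav)          # [B, N, T_out]
    est_mask = masknet(mix_w)     # [spks, B, N, T_out]
    print(mix_w.shape, est_mask.shape)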