# Source: FunASR/funasr/models/lcbnet/encoder.py (399 lines, 14 KiB, Python)

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple
import torch
from torch import nn
import logging
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
class EncoderLayer(nn.Module):
    """Encoder layer module: self-attention block followed by a feed-forward block.

    Each block has a residual connection and a LayerNorm, applied either before
    (``normalize_before=True``) or after the block.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module, called as
            ``self_attn(query, key, value, mask)`` (e.g. `MultiHeadedAttention`).
        feed_forward (torch.nn.Module): Feed-forward module, called as
            ``feed_forward(x)`` (e.g. `PositionwiseFeedForward`).
        dropout_rate (float): Dropout rate applied to each block output before
            the residual add (not applied to the ``concat_after`` projection).
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input, or None. When ``cache``
                is given it is sliced as ``mask[:, -1:, :]``, so it is expected to
                be 3-D (e.g. (#batch, 1, time)) — TODO confirm against callers.
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
                When given, only the last frame of ``x`` is used as the query.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (sliced to the last frame when ``cache`` is given).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            # NOTE(review): with a cache this returns cat(cache, full x), i.e. length
            # 2 * time - 1, unlike the regular cached path (length time). Skipping
            # only happens in training mode, where a cache is unusual.
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            # Incremental mode: only the newest frame queries the full input.
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            # Concatenate the *query* with the attention output: both share the
            # query's time length, so this also works when a cache restricts the
            # query to a single frame (concatenating the full `x` would not).
            x_concat = torch.cat((x_q, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, mask
@tables.register("encoder_classes", "TransformerTextEncoder")
class TransformerTextEncoder(nn.Module):
    """Transformer encoder over padded token-id sequences.

    Token ids go through ``torch.nn.Embedding`` and ``pos_enc_class``, then
    ``num_blocks`` stacked :class:`EncoderLayer` modules, and finally a LayerNorm
    when ``normalize_before`` is True.

    Args:
        input_size: Number of embedding rows (vocabulary size).
        output_size: Model / attention dimension.
        attention_heads: Number of heads in each self-attention module.
        linear_units: Hidden units of each position-wise feed-forward network.
        num_blocks: Number of encoder layers.
        dropout_rate: Dropout rate for the feed-forward networks and the
            residual branches of every encoder layer.
        positional_dropout_rate: Dropout rate given to ``pos_enc_class``.
        attention_dropout_rate: Dropout rate inside self-attention.
        pos_enc_class: Positional-encoding class, constructed as
            ``pos_enc_class(output_size, positional_dropout_rate)``.
        normalize_before: Whether every layer uses pre-LayerNorm; if True, a final
            LayerNorm is also applied to the encoder output.
        concat_after: Whether every layer concatenates the attention input and
            output and projects back with a linear layer instead of a plain add.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        super().__init__()
        self._output_size = output_size
        token_embedding = torch.nn.Embedding(input_size, output_size)
        positional = pos_enc_class(output_size, positional_dropout_rate)
        self.embed = torch.nn.Sequential(token_embedding, positional)
        self.normalize_before = normalize_before

        def build_layer(layer_index):
            # Attention is built before the feed-forward network, matching the
            # parameter-initialisation order of the stacked layers.
            attention = MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate)
            ffn = PositionwiseFeedForward(output_size, linear_units, dropout_rate)
            return EncoderLayer(
                output_size,
                attention,
                ffn,
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.encoders = repeat(num_blocks, build_layer)
        if normalize_before:
            self.after_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        """Return the model dimension of the encoder output."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded batch of token ids.

        Args:
            xs_pad: Padded input fed straight into ``torch.nn.Embedding``, so
                presumably integer token ids of shape (B, L) — TODO confirm.
            ilens: Valid length of each sequence (B).

        Returns:
            Encoded features, output lengths (B) counted from the mask returned by
            the encoder stack, and ``None``.
        """
        keep = ~make_pad_mask(ilens)[:, None, :]
        masks = keep.to(xs_pad.device)
        hidden, masks = self.encoders(self.embed(xs_pad), masks)
        if self.normalize_before:
            hidden = self.after_norm(hidden)
        return hidden, masks.squeeze(1).sum(1), None
@tables.register("encoder_classes", "FusionSANEncoder")
class SelfSrcAttention(nn.Module):
    """Fusion layer: self-attention, then source attention, then feed-forward.

    Structured like a transformer decoder layer. ``tgt`` first attends to itself,
    then to ``memory``, then passes through a position-wise feed-forward network.
    Each sub-block has a residual connection and a LayerNorm, applied before or
    after it depending on ``normalize_before``. The layer builds its own
    sub-modules.

    Args:
        size (int): Model dimension used by the LayerNorms and the
            ``concat_after`` linears.
        attention_heads (int): Number of heads for both attention modules.
        attention_dim (int): Dimension of both attention modules and of the
            feed-forward input; must match ``size`` for the residual additions
            and LayerNorms to line up.
        linear_units (int): Hidden units of the position-wise feed-forward network.
        self_attention_dropout_rate (float): Dropout rate inside the self-attention.
        src_attention_dropout_rate (float): Dropout rate inside the source attention.
        positional_dropout_rate (float): Passed to `PositionwiseFeedForward` as its
            dropout rate (despite the name).
        dropout_rate (float): Dropout applied to each sub-block output before the
            residual add (not applied to the ``concat_after`` projections).
        normalize_before (bool): Whether to use layer_norm before each block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        attention_heads,
        attention_dim,
        linear_units,
        self_attention_dropout_rate,
        src_attention_dropout_rate,
        positional_dropout_rate,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an SelfSrcAttention object."""
        super(SelfSrcAttention, self).__init__()
        self.size = size
        self.self_attn = MultiHeadedAttention(
            attention_heads, attention_dim, self_attention_dropout_rate
        )
        # Returns a 2-tuple (output, attention score) rather than a single tensor.
        self.src_attn = MultiHeadedAttentionReturnWeight(
            attention_heads, attention_dim, src_attention_dropout_rate
        )
        # NOTE(review): the FFN dropout uses `positional_dropout_rate`; kept as-is
        # for compatibility with existing configs/checkpoints.
        self.feed_forward = PositionwiseFeedForward(
            attention_dim, linear_units, positional_dropout_rate
        )
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute fused features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Self-attention mask for ``tgt``, or None. It is
                sliced as ``tgt_mask[:, -1:, :]`` when ``cache`` is given, so it is
                expected to be 3-D (presumably (#batch, maxlen_out, maxlen_out)) —
                TODO confirm against callers.
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Mask for ``memory``, passed unchanged to the
                source attention (presumably (#batch, 1, maxlen_in)).
            cache (torch.Tensor): Cached output of previous steps
                (#batch, maxlen_out - 1, size). When given, only the last frame of
                ``tgt`` is used as the query.

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: ``tgt_mask`` as given (not sliced).
            torch.Tensor: ``memory`` as given.
            torch.Tensor: ``memory_mask`` as given.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        # Self-attention block.
        if self.concat_after:
            tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # Source-attention block. `src_attn` returns (output, score); unpack once so
        # the concat_after path concatenates a tensor rather than the whole tuple.
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        src_out, _ = self.src_attn(x, memory, memory, memory_mask)
        if self.concat_after:
            x_concat = torch.cat((x, src_out), dim=-1)
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(src_out)
        if not self.normalize_before:
            x = self.norm2(x)

        # Feed-forward block.
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask
@tables.register("encoder_classes", "ConvBiasPredictor")
class ConvPredictor(nn.Module):
    """Per-position scalar predictor over text encodings conditioned on ASR encodings.

    The pipeline has three stages:

    1. Cross-attention from ``text_enc`` to ``asr_enc`` with a residual add and no
       mask.
    2. A pre-LayerNorm feed-forward network with a residual add.
    3. A LayerNorm, then a depthwise 1-D convolution over time with ``l_order`` /
       ``r_order`` frames of left / right context and a residual add.

    A ReLU and a linear projection then produce one value per position.

    Args:
        size: Feature dimension of the encodings (attention dim and conv channels).
        l_order: Left context (zero padding) of the depthwise convolution.
        r_order: Right context (zero padding) of the depthwise convolution.
        attention_heads: Number of cross-attention heads.
        attention_dropout_rate: Dropout rate of both the cross-attention and the
            feed-forward network.
        linear_units: Hidden units of the feed-forward network.
    """

    def __init__(
        self,
        size=256,
        l_order=3,
        r_order=3,
        attention_heads=4,
        attention_dropout_rate=0.1,
        linear_units=2048,
    ):
        super().__init__()
        # Sub-modules are created in a fixed order; attribute names double as
        # state_dict keys, so they must not change.
        self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
        self.norm1 = LayerNorm(size)
        self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
        self.norm2 = LayerNorm(size)
        # Padding l_order + r_order frames and using a kernel of width
        # l_order + r_order + 1 keeps the time length unchanged.
        kernel_width = l_order + r_order + 1
        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.conv1d = nn.Conv1d(size, size, kernel_width, groups=size)
        self.output_linear = nn.Linear(size, 1)

    def forward(self, text_enc, asr_enc):
        """Predict one value per text position.

        Args:
            text_enc: Text encodings; transposed to channels-first for the
                ``size``-channel convolution, so assumed (B, T_text, size).
            asr_enc: ASR encodings used as attention keys/values; assumed
                (B, T_asr, size) — TODO confirm.

        Returns:
            Per-position outputs; the trailing singleton dimension of the linear
            projection is squeezed, giving presumably (B, T_text).
        """
        # Stage 1: text queries attend over ASR features (unmasked), residual add.
        hidden = text_enc + self.atten(text_enc, asr_enc, asr_enc, None)
        # Stage 2: pre-norm feed-forward, residual add.
        hidden = hidden + self.feed_forward(self.norm1(hidden))
        # Stage 3: depthwise conv over time (channels-first), residual add.
        channels_first = self.norm2(hidden).transpose(1, 2)
        conv_out = self.conv1d(self.pad(channels_first))
        fused = (conv_out + channels_first).transpose(1, 2)
        logits = self.output_linear(torch.relu(fused))
        if logits.dim() == 3:
            logits = logits.squeeze(2)
        return logits