# Copyright 2019 Shigeki Karita
# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Transformer encoder definition."""

from typing import List
from typing import Optional
from typing import Tuple

import torch
from torch import nn
import logging

from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables


class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear layer is applied,
            i.e. x -> x + att(x).
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip the residual computation and
            return the input as-is with the given probability.

    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
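        # For example, with stochastic_depth_rate = 0.1 the layer is skipped on
        # roughly 10% of training steps; when it does run, its residual branch is
        # scaled by 1 / (1 - 0.1) ≈ 1.11. At inference time the layer always runs
        # and the coefficient stays 1.0.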
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask


@tables.register("encoder_classes", "TransformerTextEncoder")
class TransformerTextEncoder(nn.Module):
    """Transformer text encoder module.

    Args:
        input_size: input dimension (vocabulary size of the embedding layer)
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        positional_dropout_rate: dropout rate after adding positional encoding
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear layer is applied,
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        super().__init__()
        self._output_size = output_size

        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(input_size, output_size),
            pos_enc_class(output_size, positional_dropout_rate),
        )

        self.normalize_before = normalize_before

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.
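
        Example (illustrative sketch; the vocabulary size and shapes below are
        assumed, not taken from a real configuration)::

            encoder = TransformerTextEncoder(input_size=5000, output_size=256)
            tokens = torch.randint(0, 5000, (2, 10))   # (B, L) token ids
            lengths = torch.tensor([10, 7])            # (B,)
            out, olens, _ = encoder(tokens, lengths)   # out: (2, 10, 256), olens: [10, 7]
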
        Args:
            xs_pad: input token ids (B, L)
            ilens: input lengths (B)

        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None


@tables.register("encoder_classes", "FusionSANEncoder")
class SelfSrcAttention(nn.Module):
    """Decoder-style fusion layer module with self-attention and source-attention.

    Args:
        size (int): Input dimension.
        attention_heads (int): Number of attention heads.
        attention_dim (int): Dimension of the self- and source-attention modules.
        linear_units (int): Number of units of the position-wise feed forward.
        self_attention_dropout_rate (float): Dropout rate in self-attention.
        src_attention_dropout_rate (float): Dropout rate in source-attention.
        positional_dropout_rate (float): Dropout rate inside the feed-forward module.
        dropout_rate (float): Dropout rate applied to the residual branches.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear layer is applied,
            i.e. x -> x + att(x).

    """

    def __init__(
        self,
        size,
        attention_heads,
        attention_dim,
        linear_units,
        self_attention_dropout_rate,
        src_attention_dropout_rate,
        positional_dropout_rate,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SelfSrcAttention object."""
        super(SelfSrcAttention, self).__init__()
        self.size = size
        self.self_attn = MultiHeadedAttention(
            attention_heads, attention_dim, self_attention_dropout_rate
        )
        self.src_attn = MultiHeadedAttentionReturnWeight(
            attention_heads, attention_dim, src_attention_dropout_rate
        )
        self.feed_forward = PositionwiseFeedForward(
            attention_dim, linear_units, positional_dropout_rate
        )
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).

        """
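        # Three residual stages follow: (1) self-attention over `tgt`,
        # (2) source-attention from `tgt` to `memory`, whose attention weights are
        # returned as `score`, and (3) the position-wise feed-forward module.
        # Layer norm is applied before each stage when `normalize_before` is True,
        # after it otherwise. With `cache` set, only the last query frame is
        # computed and the result is concatenated onto the cache before returning.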
""" residual = tgt if self.normalize_before: tgt = self.norm1(tgt) if cache is None: tgt_q = tgt tgt_q_mask = tgt_mask else: # compute only the last frame query keeping dim: max_time_out -> 1 assert cache.shape == ( tgt.shape[0], tgt.shape[1] - 1, self.size, ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" tgt_q = tgt[:, -1:, :] residual = residual[:, -1:, :] tgt_q_mask = None if tgt_mask is not None: tgt_q_mask = tgt_mask[:, -1:, :] if self.concat_after: tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1) x = residual + self.concat_linear1(tgt_concat) else: x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) if not self.normalize_before: x = self.norm1(x) residual = x if self.normalize_before: x = self.norm2(x) if self.concat_after: x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1) x = residual + self.concat_linear2(x_concat) else: x, score = self.src_attn(x, memory, memory, memory_mask) x = residual + self.dropout(x) if not self.normalize_before: x = self.norm2(x) residual = x if self.normalize_before: x = self.norm3(x) x = residual + self.dropout(self.feed_forward(x)) if not self.normalize_before: x = self.norm3(x) if cache is not None: x = torch.cat([cache, x], dim=1) return x, tgt_mask, memory, memory_mask @tables.register("encoder_classes", "ConvBiasPredictor") class ConvPredictor(nn.Module): def __init__( self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048, ): super().__init__() self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate) self.norm1 = LayerNorm(size) self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate) self.norm2 = LayerNorm(size) self.pad = nn.ConstantPad1d((l_order, r_order), 0) self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size) self.output_linear = nn.Linear(size, 1) def forward(self, text_enc, asr_enc): # stage1 cross-attention residual = text_enc text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None) # stage2 FFN residual = text_enc text_enc = self.norm1(text_enc) text_enc = residual + self.feed_forward(text_enc) # stage Conv predictor text_enc = self.norm2(text_enc) context = text_enc.transpose(1, 2) queries = self.pad(context) memory = self.conv1d(queries) output = memory + context output = output.transpose(1, 2) output = torch.relu(output) output = self.output_linear(output) if output.dim() == 3: output = output.squeeze(2) return output