# Copyright 2019 Shigeki Karita # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """Transformer encoder definition.""" from typing import List from typing import Optional from typing import Tuple import torch from torch import nn import logging from funasr.models.transformer.attention import MultiHeadedAttention from funasr.models.transformer.embedding import PositionalEncoding from funasr.models.transformer.layer_norm import LayerNorm from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d from funasr.models.transformer.utils.nets_utils import make_pad_mask from funasr.models.transformer.positionwise_feed_forward import ( PositionwiseFeedForward, # noqa: H301 ) from funasr.models.transformer.utils.repeat import repeat from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D from funasr.models.transformer.utils.lightconv import LightweightConvolution from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D from funasr.models.transformer.utils.subsampling import Conv2dSubsampling from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2 from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6 from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8 from funasr.models.transformer.utils.subsampling import TooShortUttError from funasr.models.transformer.utils.subsampling import check_short_utt class EncoderLayer(nn.Module): """Encoder layer module. Args: size (int): Input dimension. self_attn (torch.nn.Module): Self-attention module instance. `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance can be used as the argument. feed_forward (torch.nn.Module): Feed-forward module instance. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance can be used as the argument. dropout_rate (float): Dropout rate. normalize_before (bool): Whether to use layer_norm before the first block. concat_after (bool): Whether to concat attention layer's input and output. if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) if False, no additional linear will be applied. i.e. x -> x + att(x) stochastic_depth_rate (float): Proability to skip this layer. During training, the layer may skip residual computation and return input as-is with given probability. """ def __init__( self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False, stochastic_depth_rate=0.0, ): """Construct an EncoderLayer object.""" super(EncoderLayer, self).__init__() self.self_attn = self_attn self.feed_forward = feed_forward self.norm1 = LayerNorm(size) self.norm2 = LayerNorm(size) self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before self.concat_after = concat_after if self.concat_after: self.concat_linear = nn.Linear(size + size, size) self.stochastic_depth_rate = stochastic_depth_rate def forward(self, x, mask, cache=None): """Compute encoded features. Args: x_input (torch.Tensor): Input tensor (#batch, time, size). mask (torch.Tensor): Mask tensor for the input (#batch, time). cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). Returns: torch.Tensor: Output tensor (#batch, time, size). torch.Tensor: Mask tensor (#batch, time). """ skip_layer = False # with stochastic depth, residual connection `x + f(x)` becomes # `x <- x + 1 / (1 - p) * f(x)` at training time. stoch_layer_coeff = 1.0 if self.training and self.stochastic_depth_rate > 0: skip_layer = torch.rand(1).item() < self.stochastic_depth_rate stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) if skip_layer: if cache is not None: x = torch.cat([cache, x], dim=1) return x, mask residual = x if self.normalize_before: x = self.norm1(x) if cache is None: x_q = x else: assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) x_q = x[:, -1:, :] residual = residual[:, -1:, :] mask = None if mask is None else mask[:, -1:, :] if self.concat_after: x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) x = residual + stoch_layer_coeff * self.concat_linear(x_concat) else: x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask)) if not self.normalize_before: x = self.norm1(x) residual = x if self.normalize_before: x = self.norm2(x) x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) if not self.normalize_before: x = self.norm2(x) if cache is not None: x = torch.cat([cache, x], dim=1) return x, mask class TransformerEncoder_lm(nn.Module): """Transformer encoder module. Args: idim (int): Input dimension. attention_dim (int): Dimension of attention. attention_heads (int): The number of heads of multi head attention. conv_wshare (int): The number of kernel of convolution. Only used in selfattention_layer_type == "lightconv*" or "dynamiconv*". conv_kernel_length (Union[int, str]): Kernel size str of convolution (e.g. 71_71_71_71_71_71). Only used in selfattention_layer_type == "lightconv*" or "dynamiconv*". conv_usebias (bool): Whether to use bias in convolution. Only used in selfattention_layer_type == "lightconv*" or "dynamiconv*". linear_units (int): The number of units of position-wise feed forward. num_blocks (int): The number of decoder blocks. dropout_rate (float): Dropout rate. positional_dropout_rate (float): Dropout rate after adding positional encoding. attention_dropout_rate (float): Dropout rate in attention. input_layer (Union[str, torch.nn.Module]): Input layer type. pos_enc_class (torch.nn.Module): Positional encoding module class. `PositionalEncoding `or `ScaledPositionalEncoding` normalize_before (bool): Whether to use layer_norm before the first block. concat_after (bool): Whether to concat attention layer's input and output. if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) if False, no additional linear will be applied. i.e. x -> x + att(x) positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. selfattention_layer_type (str): Encoder attention layer type. padding_idx (int): Padding idx for input_layer=embed. stochastic_depth_rate (float): Maximum probability to skip the encoder layer. intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer. indices start from 1. if not None, intermediate outputs are returned (which changes return type signature.) """ def __init__( self, idim, attention_dim=256, attention_heads=4, conv_wshare=4, conv_kernel_length="11", conv_usebias=False, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, attention_dropout_rate=0.0, input_layer="conv2d", pos_enc_class=PositionalEncoding, normalize_before=True, concat_after=False, positionwise_layer_type="linear", positionwise_conv_kernel_size=1, selfattention_layer_type="selfattn", padding_idx=-1, stochastic_depth_rate=0.0, intermediate_layers=None, ctc_softmax=None, conditioning_layer_dim=None, ): """Construct an Encoder object.""" super().__init__() self.conv_subsampling_factor = 1 if input_layer == "linear": self.embed = torch.nn.Sequential( torch.nn.Linear(idim, attention_dim), torch.nn.LayerNorm(attention_dim), torch.nn.Dropout(dropout_rate), torch.nn.ReLU(), pos_enc_class(attention_dim, positional_dropout_rate), ) elif input_layer == "conv2d": self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate) self.conv_subsampling_factor = 4 elif input_layer == "conv2d-scaled-pos-enc": self.embed = Conv2dSubsampling( idim, attention_dim, dropout_rate, pos_enc_class(attention_dim, positional_dropout_rate), ) self.conv_subsampling_factor = 4 elif input_layer == "conv2d6": self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate) self.conv_subsampling_factor = 6 elif input_layer == "conv2d8": self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate) self.conv_subsampling_factor = 8 elif input_layer == "embed": self.embed = torch.nn.Sequential( torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), pos_enc_class(attention_dim, positional_dropout_rate), ) elif isinstance(input_layer, torch.nn.Module): self.embed = torch.nn.Sequential( input_layer, pos_enc_class(attention_dim, positional_dropout_rate), ) elif input_layer is None: self.embed = torch.nn.Sequential(pos_enc_class(attention_dim, positional_dropout_rate)) else: raise ValueError("unknown input_layer: " + input_layer) self.normalize_before = normalize_before positionwise_layer, positionwise_layer_args = self.get_positionwise_layer( positionwise_layer_type, attention_dim, linear_units, dropout_rate, positionwise_conv_kernel_size, ) if selfattention_layer_type in [ "selfattn", "rel_selfattn", "legacy_rel_selfattn", ]: logging.info("encoder self-attention layer type = self-attention") encoder_selfattn_layer = MultiHeadedAttention encoder_selfattn_layer_args = [ ( attention_heads, attention_dim, attention_dropout_rate, ) ] * num_blocks elif selfattention_layer_type == "lightconv": logging.info("encoder self-attention layer type = lightweight convolution") encoder_selfattn_layer = LightweightConvolution encoder_selfattn_layer_args = [ ( conv_wshare, attention_dim, attention_dropout_rate, int(conv_kernel_length.split("_")[lnum]), False, conv_usebias, ) for lnum in range(num_blocks) ] elif selfattention_layer_type == "lightconv2d": logging.info( "encoder self-attention layer " "type = lightweight convolution 2-dimensional" ) encoder_selfattn_layer = LightweightConvolution2D encoder_selfattn_layer_args = [ ( conv_wshare, attention_dim, attention_dropout_rate, int(conv_kernel_length.split("_")[lnum]), False, conv_usebias, ) for lnum in range(num_blocks) ] elif selfattention_layer_type == "dynamicconv": logging.info("encoder self-attention layer type = dynamic convolution") encoder_selfattn_layer = DynamicConvolution encoder_selfattn_layer_args = [ ( conv_wshare, attention_dim, attention_dropout_rate, int(conv_kernel_length.split("_")[lnum]), False, conv_usebias, ) for lnum in range(num_blocks) ] elif selfattention_layer_type == "dynamicconv2d": logging.info("encoder self-attention layer type = dynamic convolution 2-dimensional") encoder_selfattn_layer = DynamicConvolution2D encoder_selfattn_layer_args = [ ( conv_wshare, attention_dim, attention_dropout_rate, int(conv_kernel_length.split("_")[lnum]), False, conv_usebias, ) for lnum in range(num_blocks) ] else: raise NotImplementedError(selfattention_layer_type) self.encoders = repeat( num_blocks, lambda lnum: EncoderLayer( attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]), positionwise_layer(*positionwise_layer_args), dropout_rate, normalize_before, concat_after, stochastic_depth_rate * float(1 + lnum) / num_blocks, ), ) if self.normalize_before: self.after_norm = LayerNorm(attention_dim) self.intermediate_layers = intermediate_layers self.use_conditioning = True if ctc_softmax is not None else False if self.use_conditioning: self.ctc_softmax = ctc_softmax self.conditioning_layer = torch.nn.Linear(conditioning_layer_dim, attention_dim) def get_positionwise_layer( self, positionwise_layer_type="linear", attention_dim=256, linear_units=2048, dropout_rate=0.1, positionwise_conv_kernel_size=1, ): """Define positionwise layer.""" if positionwise_layer_type == "linear": positionwise_layer = PositionwiseFeedForward positionwise_layer_args = (attention_dim, linear_units, dropout_rate) elif positionwise_layer_type == "conv1d": positionwise_layer = MultiLayeredConv1d positionwise_layer_args = ( attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate, ) elif positionwise_layer_type == "conv1d-linear": positionwise_layer = Conv1dLinear positionwise_layer_args = ( attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate, ) else: raise NotImplementedError("Support only linear or conv1d.") return positionwise_layer, positionwise_layer_args def forward(self, xs, masks): """Encode input sequence. Args: xs (torch.Tensor): Input tensor (#batch, time, idim). masks (torch.Tensor): Mask tensor (#batch, time). Returns: torch.Tensor: Output tensor (#batch, time, attention_dim). torch.Tensor: Mask tensor (#batch, time). """ if isinstance( self.embed, (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8), ): xs, masks = self.embed(xs, masks) else: xs = self.embed(xs) if self.intermediate_layers is None: xs, masks = self.encoders(xs, masks) else: intermediate_outputs = [] for layer_idx, encoder_layer in enumerate(self.encoders): xs, masks = encoder_layer(xs, masks) if ( self.intermediate_layers is not None and layer_idx + 1 in self.intermediate_layers ): encoder_output = xs # intermediate branches also require normalization. if self.normalize_before: encoder_output = self.after_norm(encoder_output) intermediate_outputs.append(encoder_output) if self.use_conditioning: intermediate_result = self.ctc_softmax(encoder_output) xs = xs + self.conditioning_layer(intermediate_result) if self.normalize_before: xs = self.after_norm(xs) if self.intermediate_layers is not None: return xs, masks, intermediate_outputs return xs, masks def forward_one_step(self, xs, masks, cache=None): """Encode input frame. Args: xs (torch.Tensor): Input tensor. masks (torch.Tensor): Mask tensor. cache (List[torch.Tensor]): List of cache tensors. Returns: torch.Tensor: Output tensor. torch.Tensor: Mask tensor. List[torch.Tensor]: List of new cache tensors. """ if isinstance(self.embed, Conv2dSubsampling): xs, masks = self.embed(xs, masks) else: xs = self.embed(xs) if cache is None: cache = [None for _ in range(len(self.encoders))] new_cache = [] for c, e in zip(cache, self.encoders): xs, masks = e(xs, masks, cache=c) new_cache.append(xs) if self.normalize_before: xs = self.after_norm(xs) return xs, masks, new_cache