175 lines
5.3 KiB
Python
175 lines
5.3 KiB
Python
from typing import List
|
|
from typing import Optional
|
|
from typing import Sequence
|
|
from typing import Tuple
|
|
from typing import Union
|
|
import logging
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
import numpy as np
|
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
|
from funasr.models.transformer.layer_norm import LayerNorm
|
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
import math
|
|
from funasr.models.transformer.utils.repeat import repeat
|
|
|
|
|
|
class EncoderLayer(nn.Module):
|
|
def __init__(
|
|
self,
|
|
input_units,
|
|
num_units,
|
|
kernel_size=3,
|
|
activation="tanh",
|
|
stride=1,
|
|
include_batch_norm=False,
|
|
residual=False,
|
|
):
|
|
super().__init__()
|
|
left_padding = math.ceil((kernel_size - stride) / 2)
|
|
right_padding = kernel_size - stride - left_padding
|
|
self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
|
|
self.conv1d = nn.Conv1d(
|
|
input_units,
|
|
num_units,
|
|
kernel_size,
|
|
stride,
|
|
)
|
|
self.activation = self.get_activation(activation)
|
|
if include_batch_norm:
|
|
self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
|
|
self.residual = residual
|
|
self.include_batch_norm = include_batch_norm
|
|
self.input_units = input_units
|
|
self.num_units = num_units
|
|
self.stride = stride
|
|
|
|
@staticmethod
|
|
def get_activation(activation):
|
|
if activation == "tanh":
|
|
return nn.Tanh()
|
|
else:
|
|
return nn.ReLU()
|
|
|
|
def forward(self, xs_pad, ilens=None):
|
|
outputs = self.conv1d(self.conv_padding(xs_pad))
|
|
if self.residual and self.stride == 1 and self.input_units == self.num_units:
|
|
outputs = outputs + xs_pad
|
|
|
|
if self.include_batch_norm:
|
|
outputs = self.bn(outputs)
|
|
|
|
# add parenthesis for repeat module
|
|
return self.activation(outputs), ilens
|
|
|
|
|
|
class ConvEncoder(AbsEncoder):
|
|
"""
|
|
Author: Speech Lab of DAMO Academy, Alibaba Group
|
|
Convolution encoder in OpenNMT framework
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_layers,
|
|
input_units,
|
|
num_units,
|
|
kernel_size=3,
|
|
dropout_rate=0.3,
|
|
position_encoder=None,
|
|
activation="tanh",
|
|
auxiliary_states=True,
|
|
out_units=None,
|
|
out_norm=False,
|
|
out_residual=False,
|
|
include_batchnorm=False,
|
|
regularization_weight=0.0,
|
|
stride=1,
|
|
tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
|
|
tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
|
|
):
|
|
super().__init__()
|
|
self._output_size = num_units
|
|
|
|
self.num_layers = num_layers
|
|
self.input_units = input_units
|
|
self.num_units = num_units
|
|
self.kernel_size = kernel_size
|
|
self.dropout_rate = dropout_rate
|
|
self.position_encoder = position_encoder
|
|
self.out_units = out_units
|
|
self.auxiliary_states = auxiliary_states
|
|
self.out_norm = out_norm
|
|
self.activation = activation
|
|
self.out_residual = out_residual
|
|
self.include_batch_norm = include_batchnorm
|
|
self.regularization_weight = regularization_weight
|
|
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
|
|
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
|
|
if isinstance(stride, int):
|
|
self.stride = [stride] * self.num_layers
|
|
else:
|
|
self.stride = stride
|
|
self.downsample_rate = 1
|
|
for s in self.stride:
|
|
self.downsample_rate *= s
|
|
|
|
self.dropout = nn.Dropout(dropout_rate)
|
|
self.cnn_a = repeat(
|
|
self.num_layers,
|
|
lambda lnum: EncoderLayer(
|
|
input_units if lnum == 0 else num_units,
|
|
num_units,
|
|
kernel_size,
|
|
activation,
|
|
self.stride[lnum],
|
|
include_batchnorm,
|
|
residual=True if lnum > 0 else False,
|
|
),
|
|
)
|
|
|
|
if self.out_units is not None:
|
|
left_padding = math.ceil((kernel_size - stride) / 2)
|
|
right_padding = kernel_size - stride - left_padding
|
|
self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
|
|
self.conv_out = nn.Conv1d(
|
|
num_units,
|
|
out_units,
|
|
kernel_size,
|
|
)
|
|
|
|
if self.out_norm:
|
|
self.after_norm = LayerNorm(out_units)
|
|
|
|
def output_size(self) -> int:
|
|
return self.num_units
|
|
|
|
def forward(
|
|
self,
|
|
xs_pad: torch.Tensor,
|
|
ilens: torch.Tensor,
|
|
prev_states: torch.Tensor = None,
|
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
|
|
|
inputs = xs_pad
|
|
if self.position_encoder is not None:
|
|
inputs = self.position_encoder(inputs)
|
|
|
|
if self.dropout_rate > 0:
|
|
inputs = self.dropout(inputs)
|
|
|
|
outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)
|
|
|
|
if self.out_units is not None:
|
|
outputs = self.conv_out(self.out_padding(outputs))
|
|
|
|
outputs = outputs.transpose(1, 2)
|
|
if self.out_norm:
|
|
outputs = self.after_norm(outputs)
|
|
|
|
if self.out_residual:
|
|
outputs = outputs + inputs
|
|
|
|
return outputs, ilens, None
|