# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.data2vec.multihead_attention import MultiheadAttention


class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that runs the normalization in float32 and casts the result
    back to the input dtype, for numerical stability with fp16/bf16 inputs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


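# Usage sketch (illustrative, not part of the original module): with half-precision
# inputs the statistics and affine parameters stay in fp32 and the result is cast
# back to the input dtype.
#
#   ln = Fp32LayerNorm(512)
#   x = torch.randn(8, 100, 512, dtype=torch.float16)
#   y = ln(x)                      # normalized in fp32
#   assert y.dtype == torch.float16

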
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm that runs the normalization in float32 and casts the result
    back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class TransposeLast(nn.Module):
    """Swaps the last two dimensions of the input. If `deconstruct_idx` is set,
    the input is first indexed with it (e.g. to pick one element of a tuple)."""

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)


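# Illustrative sketch (not part of the original module): TransposeLast is handy
# inside nn.Sequential when a channels-first tensor must pass through a module
# that normalizes over the last dimension.
#
#   norm = nn.Sequential(TransposeLast(), Fp32LayerNorm(512), TransposeLast())
#   x = torch.randn(4, 512, 100)   # (batch, channels, time)
#   assert norm(x).shape == x.shape

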
class SamePad(nn.Module):
    """Trims trailing frames introduced by convolution padding so the output
    length matches the input length: one frame for even kernels with symmetric
    padding, or `kernel_size - 1` frames in causal mode."""

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


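# Illustrative sketch (not part of the original module): with an even kernel and
# padding=kernel_size // 2, nn.Conv1d emits one extra frame; SamePad trims it.
#
#   conv = nn.Sequential(nn.Conv1d(16, 16, kernel_size=4, padding=2), SamePad(4))
#   x = torch.randn(2, 16, 50)
#   assert conv(x).shape[-1] == 50

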
def pad_to_multiple(x, multiple, dim=-1, value=0):
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    # Right-pads `dim` of `x` with `value` until its size is a multiple of `multiple`;
    # returns the (possibly) padded tensor and the number of elements added.
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    pad_offset = (0,) * (-1 - dim) * 2

    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder


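# Illustrative sketch (not part of the original module): pad the time axis of a
# (batch, time, channels) tensor up to a multiple of 4.
#
#   x = torch.randn(2, 10, 8)
#   padded, extra = pad_to_multiple(x, multiple=4, dim=-2)
#   assert padded.shape == (2, 12, 8) and extra == 2

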
def gelu_accurate(x):
    # tanh approximation of the GELU activation; the constant sqrt(2 / pi) is
    # cached on the function object after the first call.
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


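# Note (illustrative, not part of the original module): on recent PyTorch versions
# this should closely match the built-in tanh approximation, e.g.
#
#   x = torch.randn(10)
#   assert torch.allclose(gelu_accurate(x), F.gelu(x, approximate="tanh"), atol=1e-6)

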
def gelu(x: torch.Tensor) -> torch.Tensor:
    # Exact GELU computed in float32, cast back to the input dtype.
    return torch.nn.functional.gelu(x.float()).type_as(x)


def get_available_activation_fns():
    return [
        "relu",
        "gelu",
        "gelu_fast",  # deprecated
        "gelu_accurate",
        "tanh",
        "linear",
    ]


def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""

    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        # deprecated name kept for backward compatibility; alias of gelu_accurate
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # instantiate so the returned object maps tensor -> tensor, like the other branches
        return torch.nn.SiLU()
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))


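# Usage sketch (illustrative, not part of the original module):
#
#   act_fn = get_activation_fn("gelu_accurate")
#   y = act_fn(torch.randn(4, 8))

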
def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiheadAttention will be initialized using
       the normal distribution (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)


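# Usage sketch (illustrative, not part of the original module): apply recursively
# to every submodule via nn.Module.apply.
#
#   model = nn.Sequential(nn.Embedding(100, 64), nn.Linear(64, 64))
#   model.apply(init_bert_params)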