#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer normalization module."""

import torch
import torch.nn as nn


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.

    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.

        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )


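# Usage sketch (values assumed for illustration): with dim=1 the tensor is
# transposed so that the channel dimension is the one normalized, then
# transposed back, e.g. for a [batch, channels, time] input:
#
#     norm = LayerNorm(256, dim=1)
#     y = norm(torch.randn(8, 256, 100))  # [8, 256, 100] in, same shape out

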
class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Arguments
    ---------
    dim : int or list or torch.Size
        Size of the channel dimension of the expected input.
    shape : int
        Number of dimensions of the expected input: 3 for [N, C, L]
        or 4 for [N, C, K, S].
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            if shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x = N x C x K x S or N x C x L
        # gln: mean, var are N x 1 x 1 (reduced over all non-batch dims);
        # cln, by contrast, would give mean, var of size N x 1 x K x S
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x


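# Sketch of the 4-D case (shapes assumed for illustration): with shape=4 the
# affine parameters broadcast as (C, 1, 1) over a [N, C, K, S] input, and the
# statistics are taken over all non-batch dimensions:
#
#     gln = GlobalLayerNorm(64, 4)
#     y = gln(torch.randn(2, 64, 25, 30))  # same shape, globally normalized

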
class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x: N x C x K x S or N x C x L
        if x.dim() == 4:
            # N x K x S x C: channels last, so only the channel dim is normed
            x = x.permute(0, 2, 3, 1).contiguous()
            x = super().forward(x)
            # back to N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()
        if x.dim() == 3:
            # N x L x C: channels last, so only the channel dim is normed
            x = torch.transpose(x, 1, 2)
            x = super().forward(x)
            # back to N x C x L
            x = torch.transpose(x, 1, 2)
        return x


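# Sketch contrasting cLN with gLN (shapes assumed for illustration): here the
# statistics are taken over the channel dimension only, independently at each
# position, rather than over all non-batch dimensions:
#
#     cln = CumulativeLayerNorm(64)
#     y = cln(torch.randn(2, 64, 25, 30))  # [N, C, K, S] in, same shape out

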
class ScaleNorm(nn.Module):
    """Scale normalization: divide the last dimension by its RMS value
    (L2 norm times 1/sqrt(dim)) and multiply by a single learned gain.

    Arguments
    ---------
    dim : int
        Size of the last dimension of the expected input.
    eps : float
        Lower bound on the norm, for numerical stability.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
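

# Usage sketch (shapes assumed for illustration): unlike LayerNorm, ScaleNorm
# keeps the direction of each feature vector and only rescales its length:
#
#     sn = ScaleNorm(512)
#     y = sn(torch.randn(8, 100, 512))  # last dim rescaled, same shape out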