117 lines
3.3 KiB
Python
117 lines
3.3 KiB
Python
|
from pathlib import Path
|
||
|
from typing import Tuple
|
||
|
from typing import Union
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||
|
from funasr.register import tables
|
||
|
|
||
|
|
||
|
@tables.register("normalize_classes", "GlobalMVN")
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization

    TODO(kamo): Make this class portable somehow

    The statistics are read once, at construction time, with ``np.load`` and
    turned into two buffers, ``mean`` and ``std``. Two layouts are accepted:

    * a plain array ("Kaldi like" stats): row 0 holds the per-dimension sums
      followed by the frame count as its last element, row 1 holds the
      per-dimension sums of squares (its last element is ignored);
    * an npz archive with the keys ``"count"``, ``"sum"`` and
      ``"sum_square"``.

    Args:
        stats_file: npy file (or npz archive) holding the accumulated stats
        norm_means: Apply mean normalization
        norm_vars: Apply var normalization
        eps: Lower bound applied to the variance before the square root is
            taken, so that ``std`` is never zero
    """

    def __init__(
        self,
        stats_file: Union[Path, str],
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps
        stats_file = Path(stats_file)

        self.stats_file = stats_file
        stats = np.load(stats_file)
        if isinstance(stats, np.ndarray):
            # Kaldi like stats
            count = stats[0].flatten()[-1]
            mean = stats[0, :-1] / count
            var = stats[1, :-1] / count - mean * mean
        else:
            # New style: Npz file
            # np.load returns a lazily-read archive that keeps the file open;
            # pull out the arrays and close it instead of leaking the handle.
            with stats:
                count = stats["count"]
                sum_v = stats["sum"]
                sum_square_v = stats["sum_square"]
            mean = sum_v / count
            var = sum_square_v / count - mean * mean
        # Floor the variance at eps so the later division by std is safe.
        std = np.sqrt(np.maximum(var, eps))

        self.register_buffer("mean", torch.from_numpy(mean))
        self.register_buffer("std", torch.from_numpy(std))

    def extra_repr(self):
        """Return the constructor settings shown in the module's repr."""
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function

        Subtracts ``mean`` (if ``norm_means``), zeroes the positions selected
        by the padding mask, then divides by ``std`` (if ``norm_vars``).

        When ``x`` does not require grad the subtraction and the masking are
        done in place, so the caller's tensor is modified. The ``mean`` and
        ``std`` buffers are re-assigned to ``x``'s device and dtype.

        Args:
            x: (B, L, ...)
            ilens: (B,) If omitted, every sequence is treated as having
                length ``x.size(1)``.

        Returns:
            The normalized tensor and ``ilens``.
        """
        if ilens is None:
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        # NOTE(review): presumably True at padded positions along dim 1 --
        # confirm against make_pad_mask.
        mask = make_pad_mask(ilens, x, 1)

        # feat: (B, T, D)
        if norm_means:
            # Out-of-place when autograd tracks x, in place otherwise.
            if x.requires_grad:
                x = x - self.mean
            else:
                x -= self.mean
        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x /= self.std

        return x, ilens

    def inverse(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the normalization applied by ``forward``

        Zeroes the masked positions, multiplies by ``std`` (if ``norm_vars``)
        and adds ``mean`` (if ``norm_means``), re-masking after the addition
        so the masked positions stay at zero.

        When ``x`` does not require grad the first masking is done in place,
        so the caller's tensor is modified. The ``mean`` and ``std`` buffers
        are re-assigned to ``x``'s device and dtype.

        Args:
            x: (B, L, ...)
            ilens: (B,) If omitted, every sequence is treated as having
                length ``x.size(1)``.

        Returns:
            The de-normalized tensor and ``ilens``.
        """
        if ilens is None:
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        # NOTE(review): presumably True at padded positions along dim 1 --
        # confirm against make_pad_mask.
        mask = make_pad_mask(ilens, x, 1)

        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x *= self.std

        # feat: (B, T, D)
        if norm_means:
            x += self.mean
            # Adding the mean made the masked positions non-zero again.
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
        return x, ilens
|