80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
|
import librosa
|
||
|
import torch
|
||
|
from typing import Tuple
|
||
|
|
||
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||
|
|
||
|
|
||
|
class LogMel(torch.nn.Module):
    """Convert STFT features to log-mel filterbank (fbank) features.

    The mel filterbank arguments are the same as those of
    ``librosa.filters.mel``.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz).
            If `None`, use `fmin = 0`
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        log_base: base of the logarithm applied to the mel energies.
            If `None`, the natural logarithm is used; 2.0 and 10.0 use the
            dedicated torch ops; any other value uses the change-of-base
            formula ln(x) / ln(log_base).
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = None,
        fmax: float = None,
        htk: bool = False,
        log_base: float = None,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        # Kept so extra_repr() can report the filterbank configuration.
        self.mel_options = _mel_options
        self.log_base = log_base

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        """Return the mel filterbank options as ``key=value`` pairs."""
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``feat`` onto the mel filterbank and take the logarithm.

        Args:
            feat: STFT features of shape (B, T, D1), where D1 matches the
                first dimension of ``self.melmat``.
            ilens: optional per-utterance lengths. When given, positions
                flagged by ``make_pad_mask(ilens, logmel_feat, 1)`` are set to
                0.0 (presumably frames at or beyond each length along dim 1 —
                TODO confirm against make_pad_mask).

        Returns:
            Tuple of the log-mel features of shape (B, T, D2) and the lengths.
            If ``ilens`` is None, the returned lengths are a long tensor of
            size B filled with T.
        """
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        # Floor the energies so the logarithm below never sees 0 (-> -inf).
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            # Change of base: log_b(x) = ln(x) / ln(b). torch.log() rejects a
            # plain Python float, so the scalar base is wrapped in a tensor
            # that shares mel_feat's dtype and device.
            logmel_feat = mel_feat.log() / mel_feat.new_tensor(self.log_base).log()

        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        else:
            ilens = feat.new_full([feat.size(0)], fill_value=feat.size(1), dtype=torch.long)
        return logmel_feat, ilens
|