94 lines
2.8 KiB
Python
94 lines
2.8 KiB
Python
|
from typing import Tuple
|
||
|
|
||
|
from pytorch_wpe import wpe_one_iteration
|
||
|
import torch
|
||
|
from torch_complex.tensor import ComplexTensor
|
||
|
|
||
|
from funasr.frontends.utils.mask_estimator import MaskEstimator
|
||
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||
|
|
||
|
|
||
|
class DNN_WPE(torch.nn.Module):
    """Iterative WPE front-end whose power estimate can be weighted by a DNN mask.

    Each pass computes the per-bin power of the current estimate, optionally
    multiplies it by a mask produced by ``MaskEstimator`` (first pass only),
    averages it over channels and feeds it to ``wpe_one_iteration`` together
    with the unprocessed observation.
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        """Store the WPE settings and, if requested, build the mask estimator.

        Args:
            wtype: Forwarded positionally to ``MaskEstimator``.
            widim: Forwarded positionally to ``MaskEstimator``.
            wlayers: Forwarded positionally to ``MaskEstimator``.
            wunits: Forwarded positionally to ``MaskEstimator``.
            wprojs: Forwarded positionally to ``MaskEstimator``.
            dropout_rate: Forwarded positionally to ``MaskEstimator``.
            taps: ``taps`` argument given to ``wpe_one_iteration``.
            delay: ``delay`` argument given to ``wpe_one_iteration``.
            use_dnn_mask: When true, a single-mask ``MaskEstimator`` is created
                and its output weights the power on the first pass.
            iterations: Number of ``wpe_one_iteration`` passes in ``forward``.
            normalization: When true, the estimated mask is divided by its sum
                over the last axis before it is applied.
        """
        super().__init__()

        # Mask-related switches.
        self.use_dnn_mask = use_dnn_mask
        self.normalization = normalization

        # Settings handed to ``wpe_one_iteration`` on every pass.
        self.iterations = iterations
        self.taps = taps
        self.delay = delay
        self.inverse_power = True

        # The estimator only exists when masking is enabled.
        if use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """Run ``self.iterations`` WPE passes over a batch.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F).
                NOTE(review): axis order inferred from the permutation below,
                which moves axis 1 to the trailing position that is later
                padded according to ``ilens`` -- confirm against callers.
            ilens: (B,)

        Returns:
            enhanced: same axis order as ``data``; with zero iterations this is
                simply ``data`` permuted there and back.
            ilens: (B,), returned unchanged.
            mask: the estimated mask with its first and last non-batch axes
                swapped back to match ``enhanced``, or ``None`` when no mask was
                estimated (``use_dnn_mask`` is false or no pass was run), even
                though the annotation does not mention ``None``.
        """
        # (B, T, C, F) -> (B, F, C, T)
        observed = data.permute(0, 3, 2, 1)
        estimate = observed
        mask = None

        for step in range(self.iterations):
            # Per-bin power of the current estimate: (..., C, T)
            real_part, imag_part = estimate.real, estimate.imag
            weight = real_part**2 + imag_part**2

            if self.use_dnn_mask and step == 0:
                # The estimator yields a 1-tuple of masks; mask: (B, F, C, T)
                masks, _ = self.mask_est(estimate, ilens)
                (mask,) = masks
                if self.normalization:
                    # Scale so the mask sums to one along T.
                    total = mask.sum(dim=-1)[..., None]
                    mask = mask / total
                # (..., C, T) * (..., C, T) -> (..., C, T)
                weight = weight * mask

            # Collapse the channel axis: (..., C, T) -> (..., T)
            weight = weight.mean(dim=-2)

            # Every pass filters the original observation, not the previous
            # estimate: (..., C, T) -> (..., C, T)
            estimate = wpe_one_iteration(
                observed.contiguous(),
                weight,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            # Zero the padded positions in place.
            pad = make_pad_mask(ilens, estimate.real)
            estimate.masked_fill_(pad, 0)

        # (B, F, C, T) -> (B, T, C, F)
        estimate = estimate.permute(0, 3, 2, 1)
        if mask is None:
            return estimate, ilens, None
        return estimate, ilens, mask.transpose(-1, -3)
|