143 lines
6.1 KiB
Python
143 lines
6.1 KiB
Python
from typing import Tuple, Optional
|
|
import numpy as np
|
|
import torch
|
|
from torch.nn import functional as F
|
|
import torch.nn as nn
|
|
|
|
|
|
class ProfileAug(nn.Module):
    """Augment speaker profiles (and their frame-level labels) for training.

    Three augmentations are implemented:

    - Split aug: copy one active profile, perturbed by random noise, into a
      slot that has no labels. The labels stay on the original ("main")
      profile, so the copy acts as an inaccurate duplicate.
    - Merge aug: merge two distinct profiles into one. Their labels are
      merged (logical OR) onto the kept slot, and the other slot's profile
      and labels are set to zero.
    - Disturb aug: blend an active profile with a different profile to
      simulate inaccurate clustering centroids.

    Tensor layouts used by the indexing in this class:

    - ``profile``: ``(batch, num_speakers, dim)``
    - ``binary_labels``: ``(batch, num_frames, num_speakers)``
    - ``mask``: ``(batch, num_speakers)``; a non-zero entry means the slot
      has not been touched by an earlier augmentation and may still be used.

    The ``*_aug`` methods modify their arguments in place; ``forward`` clones
    its inputs first so that callers' tensors are left untouched.
    """

    def __init__(
        self,
        apply_split_aug: bool = True,
        split_aug_prob: float = 0.05,
        apply_merge_aug: bool = True,
        merge_aug_prob: float = 0.2,
        apply_disturb_aug: bool = True,
        disturb_aug_prob: float = 0.4,
        disturb_alpha: float = 0.2,
    ) -> None:
        """Store the augmentation switches and hyper-parameters.

        Args:
            apply_split_aug: Enable split augmentation in ``forward``.
            split_aug_prob: Per-sample probability of applying split aug.
            apply_merge_aug: Enable merge augmentation in ``forward``.
            merge_aug_prob: Per-sample probability of applying merge aug.
            apply_disturb_aug: Enable disturb augmentation in ``forward``.
            disturb_aug_prob: Per-sample probability of applying disturb aug.
            disturb_alpha: Scale of the noise added in split aug, and upper
                bound of the blending weight used in disturb aug.
        """
        super().__init__()
        self.apply_split_aug = apply_split_aug
        self.split_aug_prob = split_aug_prob
        self.apply_merge_aug = apply_merge_aug
        self.merge_aug_prob = merge_aug_prob
        self.apply_disturb_aug = apply_disturb_aug
        self.disturb_aug_prob = disturb_aug_prob
        self.disturb_alpha = disturb_alpha

    @staticmethod
    def _select_batch_items(batch_size: int, prob: float) -> list:
        """Return the batch indices whose independent uniform draw is below ``prob``."""
        return np.nonzero(np.random.rand(batch_size) < prob)[0].tolist()

    @staticmethod
    def _random_choice(candidates: torch.Tensor) -> int:
        """Pick one entry uniformly at random from a non-empty 1-D index tensor."""
        return int(candidates[torch.randint(len(candidates), ())])

    def split_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
        """Duplicate an active profile, with noise, into a label-free slot.

        For each selected batch item, one unmasked speaker with at least one
        positive label is chosen as the source and one unmasked speaker with
        no positive labels as the destination. The destination profile is
        overwritten with ``normalize(source + disturb_alpha * unit_noise)``.
        Labels are not changed. Both slots are then masked out.

        Args:
            profile: Speaker profiles, modified in place.
            binary_labels: Frame-level speaker labels (read only here).
            mask: Slot availability mask, modified in place.

        Returns:
            The same ``(profile, binary_labels, mask)`` objects.
        """
        bsz, dim = profile.shape[0], profile.shape[-1]
        spk_count = binary_labels.sum(dim=1)  # (batch, num_speakers)
        for idx in self._select_batch_items(bsz, self.split_aug_prob):
            active = torch.nonzero(spk_count[idx] * mask[idx]).flatten()
            empty = torch.nonzero((spk_count[idx] == 0) * mask[idx]).flatten()
            if len(active) == 0 or len(empty) == 0:
                continue
            src = self._random_choice(active)
            dst = self._random_choice(empty)
            noise = F.normalize(torch.randn((dim,)).to(profile), dim=-1)
            profile[idx, dst] = F.normalize(
                profile[idx, src] + self.disturb_alpha * noise, dim=-1
            )
            mask[idx, src] = 0
            mask[idx, dst] = 0
        return profile, binary_labels, mask

    def merge_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
        """Merge two distinct profiles (and their labels) into one slot.

        For each selected batch item, two different unmasked speakers with a
        non-zero profile are drawn. The first keeps the normalized sum of
        both profiles and the union of both label tracks; the second has its
        profile and labels zeroed. Both slots are then masked out. Items
        with fewer than two eligible speakers are left untouched.

        Args:
            profile: Speaker profiles, modified in place.
            binary_labels: Frame-level speaker labels, modified in place.
            mask: Slot availability mask, modified in place.

        Returns:
            The same ``(profile, binary_labels, mask)`` objects.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        for idx in self._select_batch_items(bsz, self.merge_aug_prob):
            valid = torch.nonzero(profile_norm[idx] * mask[idx]).flatten()
            # Two *different* speakers are required: merging a speaker with
            # itself would end up zeroing its own profile and labels below.
            if len(valid) < 2:
                continue
            first, second = torch.randperm(len(valid))[:2].tolist()
            keep, drop = int(valid[first]), int(valid[second])

            # merge profile
            profile[idx, keep] = F.normalize(
                profile[idx, keep] + profile[idx, drop], dim=-1
            )
            profile[idx, drop] = 0

            # merge binary labels (logical OR of the two tracks)
            merged = binary_labels[idx, :, keep] + binary_labels[idx, :, drop]
            binary_labels[idx, :, keep] = (merged > 0).to(binary_labels)
            binary_labels[idx, :, drop] = 0

            mask[idx, keep] = 0
            mask[idx, drop] = 0

        return profile, binary_labels, mask

    def disturb_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
        """Blend an active profile with another speaker's profile.

        For each selected batch item, one unmasked speaker with at least one
        positive label is chosen as the target, and a different unmasked
        speaker with a non-zero profile as the source. The target becomes
        ``normalize((1 - a) * target + a * source)`` with ``a`` drawn
        uniformly from ``[0, disturb_alpha)``. Labels are not changed. The
        target slot is then masked out. Items without a target or without a
        distinct source are left untouched.

        Args:
            profile: Speaker profiles, modified in place.
            binary_labels: Frame-level speaker labels (read only here).
            mask: Slot availability mask, modified in place.

        Returns:
            The same ``(profile, binary_labels, mask)`` objects.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        spk_count = binary_labels.sum(dim=1)  # (batch, num_speakers)
        for idx in self._select_batch_items(bsz, self.disturb_aug_prob):
            active = torch.nonzero(spk_count[idx] * mask[idx]).flatten()
            if len(active) == 0:
                continue
            target = self._random_choice(active)
            valid = torch.nonzero(profile_norm[idx] * mask[idx]).flatten()
            # Blending a profile with itself is a no-op that would still
            # consume the slot, so only other speakers qualify as source.
            others = valid[valid != target]
            if len(others) == 0:
                continue
            source = self._random_choice(others)
            alpha = self.disturb_alpha * torch.rand(()).item()
            profile[idx, target] = F.normalize(
                (1 - alpha) * profile[idx, target] + alpha * profile[idx, source],
                dim=-1,
            )
            mask[idx, target] = 0

        return profile, binary_labels, mask

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: Optional[torch.Tensor] = None,
        profile: Optional[torch.Tensor] = None,
        profile_lengths: Optional[torch.Tensor] = None,
        binary_labels: Optional[torch.Tensor] = None,
        labels_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply the enabled augmentations in the order disturb, split, merge.

        Args:
            speech: Input speech tensor; returned as a clone, otherwise unused.
            speech_lengths: Unused.
            profile: Speaker profiles. L2-normalized along the last axis
                before augmentation.
            profile_lengths: Unused.
            binary_labels: Frame-level speaker labels.
            labels_length: Unused.

        Returns:
            ``(speech, profile, binary_labels)``. When ``profile`` or
            ``binary_labels`` is ``None`` no augmentation is possible and
            both are returned unchanged.
        """
        # copy inputs to avoid inplace-operation
        speech = torch.clone(speech)
        if profile is None or binary_labels is None:
            return speech, profile, binary_labels

        profile = F.normalize(torch.clone(profile), dim=-1)
        binary_labels = torch.clone(binary_labels)

        # A slot is used by at most one augmentation; the mask tracks that.
        profile_mask = torch.ones(profile.shape[:2]).to(profile)
        if self.apply_disturb_aug:
            profile, binary_labels, profile_mask = self.disturb_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_split_aug:
            profile, binary_labels, profile_mask = self.split_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_merge_aug:
            profile, binary_labels, profile_mask = self.merge_aug(
                profile, binary_labels, profile_mask
            )

        return speech, profile, binary_labels
|