105 lines
3.3 KiB
Python
105 lines
3.3 KiB
Python
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.multiprocessing
|
||
|
import torch.nn.functional as F
|
||
|
from itertools import combinations
|
||
|
from itertools import permutations
|
||
|
|
||
|
|
||
|
def generate_mapping_dict(max_speaker_num=6, max_olp_speaker_num=3):
    """Build the power-set label vocabulary for multi-speaker activity patterns.

    A pattern over ``max_speaker_num`` positions is encoded as a bitmask
    integer: bit ``i`` is set when position ``i`` is active (the same encoding
    ``raw_dec_trans`` produces for 0/1 rows). The vocabulary contains the
    all-inactive pattern ``0`` plus every pattern with between 1 and
    ``max_olp_speaker_num`` active positions. The bitmasks are sorted and
    assigned consecutive class labels ``0 .. K-1``.

    Args:
        max_speaker_num: number of bit positions (speaker slots).
        max_olp_speaker_num: maximum number of simultaneously active
            positions kept in the vocabulary (presumably "overlapping
            speakers").

    Returns:
        dict with keys:
            "dec2label": ``{bitmask (int): label (int)}``
            "label2dec": ``{label (int): bitmask (int)}``
            "oov": ``K``, the label reserved for bitmasks not in the vocabulary.
    """
    all_kinds = [0]
    for selected_num in range(1, max_olp_speaker_num + 1):
        for com in combinations(range(max_speaker_num), selected_num):
            # Set one bit per chosen position; equivalent to packing a 0/1 row
            # with weights 2**i, but exact and without a float round-trip.
            all_kinds.append(sum(1 << spk for spk in com))
    all_kinds_order = sorted(all_kinds)

    mapping_dict = {"dec2label": {}, "label2dec": {}}
    for label, dec in enumerate(all_kinds_order):
        mapping_dict["dec2label"][dec] = label
        mapping_dict["label2dec"][label] = dec
    # Any bitmask outside the vocabulary maps to this extra class.
    mapping_dict["oov"] = len(all_kinds_order)
    return mapping_dict
|
||
|
|
||
|
|
||
|
def raw_dec_trans(x, max_speaker_num):
    """Pack the leading columns of ``x`` into one number per row.

    Column ``i`` (for ``i < max_speaker_num``) is weighted by ``2 ** i`` and
    the weighted columns are summed row by row. A row of 0/1 flags therefore
    becomes the integer whose binary digits are those flags, with column 0 as
    the least-significant bit.

    Args:
        x: array indexed as ``x[row, column]``; it must have at least
            ``max_speaker_num`` columns, otherwise indexing raises.
        max_speaker_num: how many leading columns to combine.

    Returns:
        ``np.ndarray`` of dtype float64 with ``x.shape[0]`` entries.
    """
    columns = [x[:, col] for col in range(max_speaker_num)]
    total = np.zeros((x.shape[0]))
    for col, column in enumerate(columns):
        total += column * (2 ** col)
    return total
|
||
|
|
||
|
|
||
|
def mapping_func(num, mapping_dict):
    """Translate one bitmask value into its power-set class label.

    Args:
        num: bitmask value. A float or NumPy scalar that compares and hashes
            equal to an int key also matches.
        mapping_dict: vocabulary dict (as built by ``generate_mapping_dict``)
            providing "dec2label" and "oov".

    Returns:
        ``mapping_dict["dec2label"][num]`` if ``num`` is a known bitmask,
        otherwise ``mapping_dict["oov"]``.
    """
    # Look the table up once; membership on the dict itself, not a .keys() view.
    dec2label = mapping_dict["dec2label"]
    if num in dec2label:
        return dec2label[num]
    # Unknown patterns all share the single out-of-vocabulary label.
    return mapping_dict["oov"]
|
||
|
|
||
|
|
||
|
def dec_trans(x, max_speaker_num, mapping_dict):
    """Convert each row of ``x`` into a power-set class label.

    The first ``max_speaker_num`` columns of every row are packed into a
    bitmask value by ``raw_dec_trans`` (column ``i`` weighted by ``2 ** i``).
    Each value is then mapped through ``mapping_func``, so values absent from
    ``mapping_dict["dec2label"]`` become ``mapping_dict["oov"]``.

    Args:
        x: array indexed as ``x[row, column]`` with at least
            ``max_speaker_num`` columns.
        max_speaker_num: number of leading columns to pack.
        mapping_dict: vocabulary dict with "dec2label" and "oov" entries.

    Returns:
        1-D ``np.ndarray`` with one label per row of ``x``.
    """
    # Reuse the shared encoder rather than duplicating its weighted-sum loop.
    bitmasks = raw_dec_trans(x, max_speaker_num)
    return np.array([mapping_func(value, mapping_dict) for value in bitmasks])
|
||
|
|
||
|
|
||
|
def create_powerlabel(label, mapping_dict, max_speaker_num=6, max_olp_speaker_num=3):
    """Turn a ``(T, C)`` label matrix into a length-``T`` tensor of class labels.

    The label is zero-padded on the right to ``max_speaker_num`` columns.
    Each row is then encoded with ``dec_trans``: bitmask of the columns, then
    lookup in ``mapping_dict``.

    Args:
        label: 2-D array of shape ``(T, C)`` with ``C <= max_speaker_num``
            (assumed 0/1 activity flags — TODO confirm against callers).
        mapping_dict: vocabulary dict with "dec2label" and "oov" entries.
        max_speaker_num: width the label is padded to before encoding.
        max_olp_speaker_num: unused here; accepted for signature
            compatibility (the vocabulary is fixed by ``mapping_dict``).

    Returns:
        1-D ``torch.Tensor`` with one class label per row of ``label``.

    Raises:
        ValueError: if ``label`` has more columns than ``max_speaker_num``.
    """
    T, C = label.shape
    if C > max_speaker_num:
        # Same exception type numpy would raise on the padded assignment,
        # but with a message that names the actual problem.
        raise ValueError(
            f"label has {C} columns but max_speaker_num is {max_speaker_num}"
        )
    padding_label = np.zeros((T, max_speaker_num))
    padding_label[:, :C] = label
    out_label = dec_trans(padding_label, max_speaker_num, mapping_dict)
    return torch.from_numpy(out_label)
|
||
|
|
||
|
|
||
|
def generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=3):
    """Enumerate every column permutation of ``label`` with its power-set labels.

    Args:
        label: 2-D torch tensor indexed as ``label[frame, column]`` with at
            least ``n_speaker`` columns.
        n_speaker: number of leading columns to permute; ``n_speaker!``
            permutations are produced.
        mapping_dict: vocabulary dict passed to ``create_powerlabel``.
        max_speaker_num: width each permuted label is padded to in
            ``create_powerlabel``.
        max_olp_speaker_num: forwarded to ``create_powerlabel``.

    Returns:
        ``(perm_labels, perm_pse_labels)``, two lists of length ``n_speaker!``
        in ``itertools.permutations`` order:
            ``perm_labels[k]`` is ``label[:, perm_k]``, on ``label.device``.
            ``perm_pse_labels[k]`` is ``create_powerlabel`` of that permuted
            label, moved to ``label.device``.
    """
    perm_list = list(permutations(range(n_speaker)))
    # Build the index table as int64 directly (no float round-trip). The
    # reshape keeps n_speaker == 0 as a single empty permutation of shape (1, 0).
    perm_index = np.array(perm_list, dtype=np.int64).reshape(len(perm_list), n_speaker)
    perms = torch.from_numpy(perm_index).to(label.device)
    perm_labels = [label[:, perm] for perm in perms]

    # Copy the label to host memory once instead of once per permutation;
    # indexing the host copy yields the same values as perm_label.cpu().numpy().
    label_np = label.detach().cpu().numpy()
    perm_pse_labels = [
        create_powerlabel(
            label_np[:, perm], mapping_dict, max_speaker_num, max_olp_speaker_num
        ).to(label.device, non_blocking=True)
        for perm in perm_index
    ]
    return perm_labels, perm_pse_labels
|
||
|
|
||
|
|
||
|
def generate_min_pse(
    label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3
):
    """Select the label permutation whose power-set target best fits ``pse_logit``.

    For every permutation from ``generate_perm_pse``, the score is the mean
    ``F.cross_entropy`` of ``pse_logit`` against that permutation's
    power-set labels, multiplied by ``len(pse_logit)``. The permutation with
    the smallest score is returned (first one on ties, per ``torch.argmin``).

    Args:
        label: 2-D torch tensor passed to ``generate_perm_pse``.
        n_speaker: number of leading label columns to permute.
        mapping_dict: vocabulary dict passed to ``generate_perm_pse``.
        max_speaker_num: padding width passed to ``generate_perm_pse``.
        pse_logit: logits accepted by ``F.cross_entropy`` against a 1-D
            target with one entry per label row (presumably
            ``(T, num_classes)`` — TODO confirm).
        max_olp_speaker_num: forwarded to ``generate_perm_pse``.

    Returns:
        ``(selected_perm_label, selected_pse_label)``: the best permuted label
        and its power-set label tensor.
    """
    perm_labels, perm_pse_labels = generate_perm_pse(
        label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=max_olp_speaker_num
    )
    # Only the argmin index is used, so skip recording an autograd graph for
    # all n_speaker! loss evaluations; the forward arithmetic is unchanged.
    with torch.no_grad():
        losses = torch.stack(
            [
                F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long))
                * len(pse_logit)
                for perm_pse_label in perm_pse_labels
            ]
        )
        min_index = int(torch.argmin(losses))
    return perm_labels[min_index], perm_pse_labels[min_index]
|