import numpy as np
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment


def standard_loss(ys, ts):
    # Frame-weighted BCE: each utterance's mean BCE is scaled back up by its
    # number of frames, so every frame in the batch contributes equally.
    losses = [
        F.binary_cross_entropy(torch.sigmoid(y), t) * len(y)
        for y, t in zip(ys, ts)
    ]
    loss = torch.sum(torch.stack(losses))
    # Total frame count across the batch, as a float tensor on the model device.
    n_frames = (
        torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts])))
        .to(torch.float32)
        .to(ys[0].device)
    )
    loss = loss / n_frames
    return loss


def fast_batch_pit_n_speaker_loss(ys, ts):
    # Permutation-invariant training (PIT): for each utterance, find the
    # permutation of the reference speakers that minimizes BCE against the
    # predictions, via the Hungarian algorithm (linear_sum_assignment).
    with torch.no_grad():
        bs = len(ys)
        indices = []
        for b in range(bs):
            y = ys[b].transpose(0, 1)  # (C, T): one row per predicted speaker
            t = ts[b].transpose(0, 1)  # (C, T): one row per reference speaker
            n_speakers, _ = t.shape
            # Broadcast to all (prediction, reference) speaker pairs.
            y = y[:, None, :].repeat(1, n_speakers, 1)
            t = t[None, :, :].repeat(n_speakers, 1, 1)
            # Cost matrix of shape (C, C): mean BCE for every speaker pair.
            bce_loss = F.binary_cross_entropy(
                torch.sigmoid(y), t, reduction="none"
            ).mean(-1)
            cost = bce_loss.cpu()
            indices.append(linear_sum_assignment(cost))
        # Reorder each utterance's reference columns into the optimal permutation.
        labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)]

    return labels_perm


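# Illustrative sketch, not part of the original module: one plausible way to
# combine the two functions above, permuting the reference labels with PIT and
# then scoring the permuted labels with standard_loss. The wrapper name
# `pit_loss` and the assumption that `ys` holds per-utterance (T, C) logits
# are hypothetical.
def pit_loss(ys, ts):
    labels_perm = fast_batch_pit_n_speaker_loss(ys, ts)
    return standard_loss(ys, labels_perm)

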
def cal_power_loss(logits, power_ts):
    # Frame-weighted cross-entropy over the `power_ts` class targets,
    # normalized by the total number of frames in the batch.
    losses = [
        F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit)
        for logit, power_t in zip(logits, power_ts)
    ]
    loss = torch.sum(torch.stack(losses))
    n_frames = (
        torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts])))
        .to(torch.float32)
        .to(power_ts[0].device)
    )
    loss = loss / n_frames
    return loss
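

if __name__ == "__main__":
    # Minimal smoke test with random tensors, not part of the original file.
    # The shapes (T frames, C speakers, K classes for the power targets) and
    # the batch size of 2 are assumptions for illustration only.
    torch.manual_seed(0)
    T, C, K = 100, 3, 8
    ys = [torch.randn(T, C) for _ in range(2)]  # per-utterance logits
    ts = [torch.randint(0, 2, (T, C)).float() for _ in range(2)]  # binary labels
    labels_perm = fast_batch_pit_n_speaker_loss(ys, ts)
    print("PIT loss:", standard_loss(ys, labels_perm).item())

    logits = [torch.randn(T, K) for _ in range(2)]
    power_ts = [torch.randint(0, K, (T,)) for _ in range(2)]
    print("power loss:", cal_power_loss(logits, power_ts).item())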