41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
|
import torch
|
||
|
|
||
|
|
||
|
def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate position-level accuracy from per-position output scores.

    Args:
        pad_outputs (Tensor): Prediction scores, shape (B * Lmax, D). The
            index of the highest score along the last axis is taken as the
            predicted label.
        pad_targets (LongTensor): Target label ids, shape (B, Lmax).
        ignore_label (int): Label id excluded from both the numerator and
            the denominator.

    Returns:
        float: Accuracy value (0.0 - 1.0). Returns 0.0 when every target
        equals ``ignore_label`` (nothing to score) instead of raising
        ``ZeroDivisionError``.

    """
    # Restore the (B, Lmax, D) layout implied by pad_targets, then take the
    # arg-max label at each position.
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    # With no valid positions the numerator is also 0, so clamping the
    # denominator to 1 yields 0.0 rather than a division by zero.
    return float(numerator) / max(float(denominator), 1.0)
|
||
|
|
||
|
|
||
|
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy from already-decoded label ids.

    Args:
        pad_outputs (LongTensor): Predicted label ids, shape (B, Lmax).
        pad_targets (LongTensor): Target label ids, shape (B, Lmax).
        ignore_label (int): Label id excluded from both the numerator and
            the denominator.

    Returns:
        Tensor: 0-dim float32 tensor holding the accuracy (0.0 - 1.0). It is
        returned as a tensor rather than a Python float. Equals 0.0 when
        every target equals ``ignore_label``, instead of NaN.

    """
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
    )
    # Clamp to 1 so an all-ignored batch gives 0/1 = 0.0 instead of
    # 0/0 = NaN; the numerator is necessarily 0 in that case.
    denominator = torch.sum(mask).clamp(min=1)
    # Intentionally kept as a torch.Tensor (no float()/item() conversion).
    return numerator.float() / denominator.float()
|