#!/usr/bin/env python3
# Copyright 2018 David Snyder
# Apache 2.0

# This script computes the minimum detection cost function (minDCF), a common
# error metric in speaker recognition. Unlike the equal error rate, which
# weights false negatives and false positives equally, this metric is usually
# used to assess performance in settings where achieving a low false positive
# rate is more important than achieving a low false negative rate. See the
# NIST 2016 Speaker Recognition Evaluation Plan at
# https://www.nist.gov/sites/default/files/documents/2016/10/07/sre16_eval_plan_v1.3.pdf
# for more details about the metric.
|
||
|
from __future__ import print_function
|
||
|
from operator import itemgetter
|
||
|
import sys, argparse, os
|
||
|
|
||
|
|
||
|
def GetArgs():
    """Build the command-line parser, echo the invocation, then parse and validate.

    The full command line is written to stderr before parsing, and the parsed
    namespace is passed through ``CheckArgs`` before being returned.

    Returns:
        The validated ``argparse.Namespace`` with attributes ``p_target``,
        ``c_miss``, ``c_fa``, ``scores_filename`` and ``trials_filename``.
    """
    description = (
        "Compute the minimum "
        "detection cost function along with the threshold at which it occurs. "
        "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
        "<trials-file> "
        "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
        "exp/scores/trials data/test/trials"
    )
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag, destination, default, help) for each float-valued option.
    # The cost defaults are the int 1, not 1.0. argparse does not apply `type`
    # to non-string defaults, so they stay ints when the flag is omitted.
    float_options = (
        ("--p-target", "p_target", 0.01,
         "The prior probability of the target speaker in a trial."),
        ("--c-miss", "c_miss", 1,
         "Cost of a missed detection. This is usually not changed."),
        ("--c-fa", "c_fa", 1,
         "Cost of a spurious detection. This is usually not changed."),
    )
    for flag, destination, default_value, help_text in float_options:
        parser.add_argument(
            flag, type=float, dest=destination, default=default_value, help=help_text
        )

    parser.add_argument(
        "scores_filename",
        help="Input scores file, with columns of the form <utt1> <utt2> <score>",
    )
    parser.add_argument(
        "trials_filename",
        help="Input trials file, with columns of the form <utt1> <utt2> <target/nontarget>",
    )

    # Echo the command line to stderr before parsing, so it is recorded even
    # if parsing fails.
    sys.stderr.write(" ".join(sys.argv) + "\n")
    return CheckArgs(parser.parse_args())
|
||
|
|
||
|
|
||
|
def CheckArgs(args):
    """Validate the cost and prior options parsed from the command line.

    Args:
        args: namespace with numeric attributes ``c_fa``, ``c_miss`` and
            ``p_target``.

    Returns:
        The same ``args`` object, unchanged.

    Raises:
        Exception: if ``c_fa`` or ``c_miss`` is not strictly positive, or if
            ``p_target`` is not strictly between 0 and 1. NaN is rejected too.
    """
    # Write the checks as "not <valid condition>" rather than "<invalid
    # condition>". Every comparison with NaN is False, so `nan <= 0` would let
    # NaN through, while `not nan > 0` rejects it.
    if not args.c_fa > 0:
        raise Exception("--c-fa must be greater than 0")
    if not args.c_miss > 0:
        raise Exception("--c-miss must be greater than 0")
    if not 0 < args.p_target < 1:
        raise Exception("--p-target must be greater than 0 and less than 1")
    return args
|
||
|
|
||
|
|
||
|
def ComputeErrorRates(scores, labels):
    """Compute false-negative and false-positive rates at every score threshold.

    The scores are sorted in ascending order and each sorted score is used as a
    decision threshold. At ``thresholds[i]``, every trial whose score is at or
    below that value (sorted positions ``0..i``) counts as rejected, and every
    trial after it counts as accepted.

    Args:
        scores: sequence of trial scores.
        labels: sequence of 0/1 labels aligned with ``scores``. 1 marks a
            target trial and 0 a nontarget trial.

    Returns:
        A tuple ``(fnrs, fprs, thresholds)``:
        - ``fnrs``: list of false-negative rates, one per operating point.
        - ``fprs``: list of false-positive rates, one per operating point.
        - ``thresholds``: tuple of the scores sorted ascending.
        All three have ``len(scores)`` entries.

    Raises:
        ValueError: if ``scores`` is empty, if ``scores`` and ``labels`` differ
            in length, or if the labels contain no targets or no nontargets
            (one of the rates would then be undefined).
    """
    if len(scores) != len(labels):
        raise ValueError(
            "scores and labels must have the same length "
            "({0} vs {1})".format(len(scores), len(labels))
        )
    if len(scores) == 0:
        raise ValueError("Cannot compute error rates from an empty list of scores")

    # Stable sort of positions by score. Tied scores keep their input order.
    # NOTE: with tied scores, the operating points between tied entries cannot
    # be reached by any single threshold. This matches the historical
    # behaviour and is kept on purpose.
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    thresholds = tuple(scores[i] for i in order)
    sorted_labels = [labels[i] for i in order]

    positives = sum(sorted_labels)
    negatives = len(sorted_labels) - positives
    if positives == 0 or negatives == 0:
        raise ValueError(
            "Need at least one target and one nontarget trial "
            "(got {0} targets, {1} nontargets)".format(positives, negatives)
        )

    # Running counts: after position i, `fn` is the number of targets and `tn`
    # the number of nontargets among the scores <= thresholds[i], i.e. the
    # trials rejected at that threshold.
    fnrs = []
    fprs = []
    fn = 0
    tn = 0
    for label in sorted_labels:
        fn += label
        tn += 1 - label
        # False-negative rate: rejected targets / all targets.
        fnrs.append(fn / float(positives))
        # False-positive rate: 1 - true-negative rate.
        fprs.append(1 - tn / float(negatives))
    return fnrs, fprs, thresholds
|
||
|
|
||
|
|
||
|
def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
    """Return the normalized minimum detection cost and the threshold where it occurs.

    Equation numbers refer to Section 3 of the NIST 2016 Speaker Recognition
    Evaluation Plan.

    Args:
        fnrs: false-negative (miss) rate at each operating point.
        fprs: false-positive (false-alarm) rate at each operating point.
        thresholds: decision threshold for each operating point.
        p_target: prior probability of a target trial.
        c_miss: cost of a missed detection.
        c_fa: cost of a false alarm.

    Returns:
        ``(min_dcf, threshold)``. The first operating point with the lowest
        cost wins ties.
    """
    lowest_cost = float("inf")
    best_threshold = thresholds[0]
    for idx, miss_rate in enumerate(fnrs):
        # Equation (2): weighted sum of the false-negative and false-positive
        # rates.
        cost = c_miss * miss_rate * p_target + c_fa * fprs[idx] * (1 - p_target)
        if cost < lowest_cost:
            lowest_cost = cost
            best_threshold = thresholds[idx]

    # Equations (3) and (4): normalize by
    # C_def = min(c_miss * p_target, c_fa * (1 - p_target)).
    normalizer = min(c_miss * p_target, c_fa * (1 - p_target))
    return lowest_cost / normalizer, best_threshold
|
||
|
|
||
|
|
||
|
def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
    """Compute the normalized minimum DCF of a scores file against a trials file.

    Args:
        scores_filename: path to a text file with lines ``<utt1> <utt2> <score>``.
        trials_filename: path to a text file with lines
            ``<utt1> <utt2> <target/nontarget>``.
        c_miss: cost of a missed detection.
        c_fa: cost of a false alarm.
        p_target: prior probability of a target trial.

    Returns:
        ``(min_dcf, threshold)`` as returned by ``ComputeMinDcf``.

    Raises:
        Exception: if a scored pair has no entry in the trials file.
        ValueError: if a non-blank line does not have exactly three columns,
            or if a score is not a valid float.

    Blank lines in either file are ignored. Any trial label other than the
    exact string ``target`` is treated as a nontarget.
    """
    # Read both files up front (scores file first) and close them promptly.
    with open(scores_filename, "r") as scores_file:
        score_lines = scores_file.readlines()
    with open(trials_filename, "r") as trials_file:
        trial_lines = trials_file.readlines()

    # Map "utt1 utt2" -> label string from the trials file.
    trials = {}
    for line in trial_lines:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines
        utt1, utt2, target = fields
        trials[utt1 + " " + utt2] = target

    # Collect scores in file order, with a 1/0 label per score.
    scores = []
    labels = []
    for line in score_lines:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines
        utt1, utt2, score = fields
        trial = utt1 + " " + utt2
        if trial not in trials:
            raise Exception(
                "Missing entry for " + utt1 + " and " + utt2 + " " + scores_filename
            )
        scores.append(float(score))
        labels.append(1 if trials[trial] == "target" else 0)

    fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
    return mindcf, threshold
|
||
|
|
||
|
|
||
|
def main():
    """Command-line entry point: parse options, compute minDCF, print the result."""
    args = GetArgs()
    min_dcf, best_threshold = compute_min_dcf(
        args.scores_filename,
        args.trials_filename,
        args.c_miss,
        args.c_fa,
        args.p_target,
    )
    # One summary line on stdout. The options are echoed exactly as parsed.
    report = (
        "minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
        "c-fa={4})\n"
    ).format(min_dcf, best_threshold, args.p_target, args.c_miss, args.c_fa)
    sys.stdout.write(report)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|