FunASR/runtime/triton_gpu/client/utils.py

61 lines
1.8 KiB
Python
Raw Normal View History

2024-05-18 15:50:56 +08:00
import numpy as np
def _levenshtein_distance(ref, hyp):
"""Levenshtein distance is a string metric for measuring the difference
between two sequences. Informally, the levenshtein disctance is defined as
the minimum number of single-character edits (substitutions, insertions or
deletions) required to change one word into the other. We can naturally
extend the edits to word level when calculate levenshtein disctance for
two sentences.
"""
m = len(ref)
n = len(hyp)
# special case
if ref == hyp:
return 0
if m == 0:
return n
if n == 0:
return m
if m < n:
ref, hyp = hyp, ref
m, n = n, m
# use O(min(m, n)) space
distance = np.zeros((2, n + 1), dtype=np.int32)
# initialize distance matrix
for j in range(n + 1):
distance[0][j] = j
# calculate levenshtein distance
for i in range(1, m + 1):
prev_row_idx = (i - 1) % 2
cur_row_idx = i % 2
distance[cur_row_idx][0] = i
for j in range(1, n + 1):
if ref[i - 1] == hyp[j - 1]:
distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
else:
s_num = distance[prev_row_idx][j - 1] + 1
i_num = distance[cur_row_idx][j - 1] + 1
d_num = distance[prev_row_idx][j] + 1
distance[cur_row_idx][j] = min(s_num, i_num, d_num)
return distance[m % 2][n]
def cal_cer(references, predictions):
    """Compute the corpus-level character error rate (CER).

    Each reference/prediction pair is split into characters with ``list()``
    and scored with ``_levenshtein_distance``. The result is the total edit
    distance divided by the total number of reference characters, i.e. a
    micro-averaged CER over the whole corpus rather than a mean of
    per-utterance CERs.

    Args:
        references: Iterable of ground-truth sequences (typically strings).
        predictions: Iterable of hypothesis sequences, aligned one-to-one
            with ``references``.

    Returns:
        float: Total edit distance divided by total reference length.

    Raises:
        ValueError: If ``references`` and ``predictions`` yield a different
            number of items (instead of silently ignoring the surplus).
        ZeroDivisionError: If the references contain no characters at all,
            in which case the CER is undefined.
    """
    from itertools import zip_longest

    # Unique sentinel marking "one iterable ran out before the other".
    missing = object()
    errors = 0
    lengths = 0
    for ref, pred in zip_longest(references, predictions, fillvalue=missing):
        if ref is missing or pred is missing:
            raise ValueError(
                "references and predictions must contain the same number of items"
            )
        cur_ref = list(ref)
        cur_hyp = list(pred)
        errors += _levenshtein_distance(cur_ref, cur_hyp)
        lengths += len(cur_ref)
    if lengths == 0:
        raise ZeroDivisionError(
            "references contain no characters; CER is undefined"
        )
    return float(errors) / lengths