// FunASR/runtime/onnxruntime/third_party/kaldi/lat/sausages.h
//
// (source-viewer metadata: 271 lines, 12 KiB, C; Raw | Normal View | History;
// 2024-05-18 15:50:56 +08:00)
// lat/sausages.h
// Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
// 2015 Guoguo Chen
// 2019 Dogan Can
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_SAUSAGES_H_
#define KALDI_LAT_SAUSAGES_H_
#include <vector>
#include <map>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
namespace kaldi {
/// The implementation of the Minimum Bayes Risk decoding method described in
/// "Minimum Bayes Risk decoding and system combination based on a recursion for
/// edit distance", Haihua Xu, Daniel Povey, Lidia Mangu and Jie Zhu, Computer
/// Speech and Language, 2011
/// This is a slightly more principled way to do Minimum Bayes Risk (MBR) decoding
/// than the standard "Confusion Network" method. Note: MBR decoding aims to
/// minimize the expected word error rate, assuming the lattice encodes the
/// true uncertainty about what was spoken; standard Viterbi decoding gives the
/// most likely utterance, which corresponds to minimizing the expected sentence
/// error rate.
///
/// In addition to giving the MBR output, we also provide a way to get a
/// "Confusion Network" or informally "sausage"-like structure. This is a
/// linear sequence of bins, and in each bin, there is a distribution over
/// words (or epsilon, meaning no word). This is useful for estimating
/// confidence. Note: due to the way these sausages are made, typically there
/// will be, between each bin representing a high-confidence word, a bin
/// in which epsilon (no word) is the most likely word. Inside these bins
/// is where we put possible insertions.
struct MinimumBayesRiskOptions {
  /// Defaults: MBR decoding enabled (decode_mbr == true) and
  /// print_silence == false.
  MinimumBayesRiskOptions()
      : decode_mbr(true),
        print_silence(false) { }

  /// Registers decode_mbr under the name "decode-mbr" and print_silence
  /// under the name "print-silence" with *opts, together with their help text.
  void Register(OptionsItf *opts) {
    opts->Register("decode-mbr", &decode_mbr,
                   "If true, do Minimum Bayes Risk decoding "
                   "(else, Maximum a Posteriori)");
    opts->Register("print-silence", &print_silence,
                   "Keep the inter-word '<eps>' bins in the 1-best output "
                   "(ctm, <eps> can be a 'silence' or a 'deleted' word)");
  }

  /// If true, the hypothesis is actually updated by MBR decoding. If false,
  /// the output is the MAP decoded output, but the statistics (i.e. the
  /// confidences) are still output.
  bool decode_mbr;

  /// If true, the 1-best output keeps the inter-word <eps> bins. In a ctm
  /// such an <eps> entry can stand for a 'silence' or for a 'deleted' word.
  bool print_silence;
};
/// This class does the word-level Minimum Bayes Risk computation, and gives you
/// either the 1-best MBR output together with the expected Bayes Risk,
/// or a sausage-like structure.
class MinimumBayesRisk {
public:
/// Initialize with compact lattice-- any acoustic scaling etc., is assumed
/// to have been done already.
/// This does the whole computation. You get the output with
/// GetOneBest(), GetBayesRisk(), and GetSausageStats().
MinimumBayesRisk(const CompactLattice &clat,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions());
// Uses the provided <words> as <R_> instead of using the lattice best path.
// Note that the default value of opts.decode_mbr is true. If you provide 1-best
// hypothesis from MAP decoding, the output ctm from MBR decoding may be
// mismatched with the provided <words> (<words> would be used as the starting
// point of optimization).
MinimumBayesRisk(const CompactLattice &clat,
const std::vector<int32> &words,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions());
// Uses the provided <words> as <R_> and <times> of bins instead of using the lattice best path.
// Note that the default value of opts.decode_mbr is true. If you provide 1-best
// hypothesis from MAP decoding, the output ctm from MBR decoding may be
// mismatched with the provided <words> (<words> would be used as the starting
// point of optimization).
MinimumBayesRisk(const CompactLattice &clat,
const std::vector<int32> &words,
const std::vector<std::pair<BaseFloat,BaseFloat> > &times,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions());
const std::vector<int32> &GetOneBest() const { // gets one-best (with no epsilons)
return R_;
}
const std::vector<std::vector<std::pair<BaseFloat, BaseFloat> > > GetTimes() const {
return times_; // returns average (start,end) times for each word in each
// bin. These are raw averages without any processing, i.e. time intervals
// from different bins can overlap.
}
const std::vector<std::pair<BaseFloat, BaseFloat> > GetSausageTimes() const {
return sausage_times_; // returns average (start,end) times for each bin.
// This is typically the weighted average of the times in GetTimes() but can
// be slightly different if the times for the bins overlap, in which case
// the times returned by this method do not overlap unlike the times
// returned by GetTimes().
}
const std::vector<std::pair<BaseFloat, BaseFloat> > &GetOneBestTimes() const {
return one_best_times_; // returns average (start,end) times for each word
// corresponding to an entry in the one-best output. This is typically the
// appropriate subset of the times in GetTimes() but can be slightly
// different if the times for the one-best words overlap, in which case
// the times returned by this method do not overlap unlike the times
// returned by GetTimes().
}
/// Outputs the confidences for the one-best transcript.
const std::vector<BaseFloat> &GetOneBestConfidences() const {
return one_best_confidences_;
}
/// Returns the expected WER over this sentence (assuming model correctness).
BaseFloat GetBayesRisk() const { return L_; }
const std::vector<std::vector<std::pair<int32, BaseFloat> > > &GetSausageStats() const {
return gamma_;
}
private:
void PrepareLatticeAndInitStats(CompactLattice *clat);
/// Minimum-Bayes-Risk Decode. Top-level algorithm. Figure 6 of the paper.
void MbrDecode();
/// Without the 'penalize' argument this gives us the basic edit-distance
/// function l(a,b), as in the paper.
/// With the 'penalize' argument it can be interpreted as the edit distance
/// plus the 'delta' from the paper, except that we make a kind of conceptual
/// bug-fix and only apply the delta if the edit-distance was not already
/// zero. This bug-fix was necessary in order to force all the stats to show
/// up, that should show up, and applying the bug-fix makes the sausage stats
/// significantly less sparse.
inline double l(int32 a, int32 b, bool penalize = false) {
if (a == b) return 0.0;
else return (penalize ? 1.0 + delta() : 1.0);
}
/// returns r_q, in one-based indexing, as in the paper.
inline int32 r(int32 q) { return R_[q-1]; }
/// Figure 4 of the paper; called from AccStats (Fig. 5)
double EditDistance(int32 N, int32 Q,
Vector<double> &alpha,
Matrix<double> &alpha_dash,
Vector<double> &alpha_dash_arc);
/// Figure 5 of the paper. Outputs to gamma_ and L_.
void AccStats();
/// Removes epsilons (symbol 0) from a vector
static void RemoveEps(std::vector<int32> *vec);
// Ensures that between each word in "vec" and at the beginning and end, is
// epsilon (0). (But if no words in vec, just one epsilon)
static void NormalizeEps(std::vector<int32> *vec);
// delta() is a constant used in the algorithm, which penalizes
// the use of certain epsilon transitions in the edit-distance which would cause
// words not to show up in the accumulated edit-distance statistics.
// There has been a conceptual bug-fix versus the way it was presented in
// the paper: we now add delta only if the edit-distance was not already
// zero.
static inline BaseFloat delta() { return 1.0e-05; }
/// Function used to increment map.
static inline void AddToMap(int32 i, double d, std::map<int32, double> *gamma) {
if (d == 0) return;
std::pair<const int32, double> pr(i, d);
std::pair<std::map<int32, double>::iterator, bool> ret = gamma->insert(pr);
if (!ret.second) // not inserted, so add to contents.
ret.first->second += d;
}
struct Arc {
int32 word;
int32 start_node;
int32 end_node;
BaseFloat loglike;
};
MinimumBayesRiskOptions opts_;
/// Arcs in the topologically sorted acceptor form of the word-level lattice,
/// with one final-state. Contains (word-symbol, log-likelihood on arc ==
/// negated cost). Indexed from zero.
std::vector<Arc> arcs_;
/// For each node in the lattice, a list of arcs entering that node. Indexed
/// from 1 (first node == 1).
std::vector<std::vector<int32> > pre_;
std::vector<int32> state_times_; // time of each state in the word lattice,
// indexed from 1 (same index as into pre_)
std::vector<int32> R_; // current 1-best word sequence, normalized to have
// epsilons between each word and at the beginning and end. R in paper...
// caution: indexed from zero, not from 1 as in paper.
double L_; // current averaged edit-distance between lattice and R_.
// \hat{L} in paper.
std::vector<std::vector<std::pair<int32, BaseFloat> > > gamma_;
// The stats we accumulate; these are pairs of (posterior, word-id), and note
// that word-id may be epsilon. Caution: indexed from zero, not from 1 as in
// paper. We sort in reverse order on the second member (posterior), so more
// likely word is first.
std::vector<std::vector<std::pair<BaseFloat, BaseFloat> > > times_;
// The average start and end times for words in each confusion-network bin.
// This is like an average over arcs, of the tau_b and tau_e quantities in
// Appendix C of the paper. Indexed from zero, like gamma_ and R_.
std::vector<std::pair<BaseFloat, BaseFloat> > sausage_times_;
// The average start and end times for each confusion-network bin. This
// is like an average over words, of the tau_b and tau_e quantities in
// Appendix C of the paper. Indexed from zero, like gamma_ and R_.
std::vector<std::pair<BaseFloat, BaseFloat> > one_best_times_;
// The average start and end times for words in the one best output. This
// is like an average over the arcs, of the tau_b and tau_e quantities in
// Appendix C of the paper. Indexed from zero, like gamma_ and R_.
std::vector<BaseFloat> one_best_confidences_;
// vector of confidences for the 1-best output (which could be
// the MAP output if opts_.decode_mbr == false, or the MBR output otherwise).
// Indexed by the same index as one_best_times_.
struct GammaCompare{
// should be like operator <. But we want reverse order
// on the 2nd element (posterior), so it'll be like operator
// > that looks first at the posterior.
bool operator () (const std::pair<int32, BaseFloat> &a,
const std::pair<int32, BaseFloat> &b) const {
if (a.second > b.second) return true;
else if (a.second < b.second) return false;
else return a.first > b.first;
}
};
};
} // namespace kaldi
#endif // KALDI_LAT_SAUSAGES_H_