FunASR/runtime/onnxruntime/third_party/kaldi/lat/lattice-functions-transitio...

263 lines
9.5 KiB
C++

// lat/lattice-functions-transition-model.cc
// Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng;
// Bagher BabaAli
// 2013 Cisco Systems (author: Neha Agrawal) [code modified
// from original code in ../gmmbin/gmm-rescore-lattice.cc]
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/lattice-functions-transition-model.h"
#include "hmm/hmm-utils.h"
#include "hmm/transition-model.h"
#include "lat/lattice-functions.h"
namespace kaldi {
// Builds the combined numerator/denominator posterior used for MMI-style
// training and writes it to *post.
//
//  - The denominator posterior comes from LatticeForwardBackward() on "lat"
//    and is scaled by -1.
//  - The numerator posterior comes from AlignmentToPosterior() on "num_ali".
//  - If convert_to_pdf_ids is true, both posteriors are passed through
//    ConvertPosteriorToPdfs() (using "tmodel") before being merged.
//  - "cancel" and "drop_frames" are forwarded unchanged to MergePosteriors().
//
// Returns the value returned by LatticeForwardBackward() for "lat".
BaseFloat LatticeForwardBackwardMmi(
    const TransitionModel &tmodel,
    const Lattice &lat,
    const std::vector<int32> &num_ali,
    bool drop_frames,
    bool convert_to_pdf_ids,
    bool cancel,
    Posterior *post) {
  // Denominator side: lattice posteriors (no acoustic-likelihood sum is
  // requested, hence the NULL third argument).
  Posterior denominator;
  const BaseFloat lattice_like =
      LatticeForwardBackward(lat, &denominator, NULL);

  // Numerator side: posteriors derived from the reference alignment.
  Posterior numerator;
  AlignmentToPosterior(num_ali, &numerator);

  // The denominator enters the merged posterior with a negative sign.
  ScalePosterior(-1.0, &denominator);

  if (convert_to_pdf_ids) {
    Posterior numerator_pdfs;
    ConvertPosteriorToPdfs(tmodel, numerator, &numerator_pdfs);
    numerator.swap(numerator_pdfs);

    Posterior denominator_pdfs;
    ConvertPosteriorToPdfs(tmodel, denominator, &denominator_pdfs);
    denominator.swap(denominator_pdfs);
  }

  MergePosteriors(numerator, denominator, cancel, drop_frames, post);
  return lattice_like;
}
bool CompactLatticeToWordProns(
const TransitionModel &tmodel,
const CompactLattice &clat,
std::vector<int32> *words,
std::vector<int32> *begin_times,
std::vector<int32> *lengths,
std::vector<std::vector<int32> > *prons,
std::vector<std::vector<int32> > *phone_lengths) {
words->clear();
begin_times->clear();
lengths->clear();
prons->clear();
phone_lengths->clear();
typedef CompactLattice::Arc Arc;
typedef Arc::Label Label;
typedef CompactLattice::StateId StateId;
typedef CompactLattice::Weight Weight;
using namespace fst;
StateId state = clat.Start();
int32 cur_time = 0;
if (state == kNoStateId) {
KALDI_WARN << "Empty lattice.";
return false;
}
while (1) {
Weight final = clat.Final(state);
size_t num_arcs = clat.NumArcs(state);
if (final != Weight::Zero()) {
if (num_arcs != 0) {
KALDI_WARN << "Lattice is not linear.";
return false;
}
if (! final.String().empty()) {
KALDI_WARN << "Lattice has alignments on final-weight: probably "
"was not word-aligned (alignments will be approximate)";
}
return true;
} else {
if (num_arcs != 1) {
KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
return false;
}
fst::ArcIterator<CompactLattice> aiter(clat, state);
const Arc &arc = aiter.Value();
Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
// Also note: word_id may be zero; we output it anyway.
int32 length = arc.weight.String().size();
words->push_back(word_id);
begin_times->push_back(cur_time);
lengths->push_back(length);
const std::vector<int32> &arc_alignment = arc.weight.String();
std::vector<std::vector<int32> > split_alignment;
SplitToPhones(tmodel, arc_alignment, &split_alignment);
std::vector<int32> phones(split_alignment.size());
std::vector<int32> plengths(split_alignment.size());
for (size_t i = 0; i < split_alignment.size(); i++) {
KALDI_ASSERT(!split_alignment[i].empty());
phones[i] = tmodel.TransitionIdToPhone(split_alignment[i][0]);
plengths[i] = split_alignment[i].size();
}
prons->push_back(phones);
phone_lengths->push_back(plengths);
cur_time += length;
state = arc.nextstate;
}
}
}
// Returns true if this vector of transition-ids could be a valid
// word. Note: for testing, we assume that the lexicon always
// has the same input-word and output-word. The other case is complex
// to test.
static bool IsPlausibleWord(const WordAlignLatticeLexiconInfo &lexicon_info,
const TransitionModel &tmodel,
int32 word_id,
const std::vector<int32> &transition_ids) {
std::vector<std::vector<int32> > split_alignment; // Split into phones.
if (!SplitToPhones(tmodel, transition_ids, &split_alignment)) {
KALDI_WARN << "Could not split word into phones correctly (forced-out?)";
}
std::vector<int32> phones(split_alignment.size());
for (size_t i = 0; i < split_alignment.size(); i++) {
KALDI_ASSERT(!split_alignment[i].empty());
phones[i] = tmodel.TransitionIdToPhone(split_alignment[i][0]);
}
std::vector<int32> lexicon_entry;
lexicon_entry.push_back(word_id);
lexicon_entry.insert(lexicon_entry.end(), phones.begin(), phones.end());
if (!lexicon_info.IsValidEntry(lexicon_entry)) {
std::ostringstream ostr;
for (size_t i = 0; i < lexicon_entry.size(); i++)
ostr << lexicon_entry[i] << ' ';
KALDI_WARN << "Invalid arc in aligned lattice (code error?) lexicon-entry is " << ostr.str();
return false;
} else {
return true;
}
}
/// Testing code; map word symbols in the lattice "lat" using the equivalence-classes
/// obtained from the lexicon, using the function EquivalenceClassOf in the lexicon_info
/// object.
static void MapSymbols(const WordAlignLatticeLexiconInfo &lexicon_info,
CompactLattice *lat) {
typedef CompactLattice::StateId StateId;
for (StateId s = 0; s < lat->NumStates(); s++) {
for (fst::MutableArcIterator<CompactLattice> aiter(lat, s);
!aiter.Done(); aiter.Next()) {
CompactLatticeArc arc (aiter.Value());
KALDI_ASSERT(arc.ilabel == arc.olabel);
arc.ilabel = lexicon_info.EquivalenceClassOf(arc.ilabel);
arc.olabel = arc.ilabel;
aiter.SetValue(arc);
}
}
}
bool TestWordAlignedLattice(const WordAlignLatticeLexiconInfo &lexicon_info,
const TransitionModel &tmodel,
CompactLattice clat,
CompactLattice aligned_clat,
bool allow_duplicate_paths) {
int32 max_err = 5, num_err = 0;
{ // We test whether the forward-backward likelihoods differ; this is intended
// to detect when we have duplicate paths in the aligned lattice, for some path
// in the input lattice (e.g. due to epsilon-sequencing problems).
Posterior post;
Lattice lat, aligned_lat;
ConvertLattice(clat, &lat);
ConvertLattice(aligned_clat, &aligned_lat);
TopSort(&lat);
TopSort(&aligned_lat);
BaseFloat like_before = LatticeForwardBackward(lat, &post),
like_after = LatticeForwardBackward(aligned_lat, &post);
if (fabs(like_before - like_after) >
1.0e-04 * (fabs(like_before) + fabs(like_after))) {
KALDI_WARN << "Forward-backward likelihoods differ in word-aligned lattice "
<< "testing, " << like_before << " != " << like_after;
if (!allow_duplicate_paths)
num_err++;
}
}
// Do a check on the arcs of the aligned lattice, that each arc corresponds
// to an entry in the lexicon.
for (CompactLattice::StateId s = 0; s < aligned_clat.NumStates(); s++) {
for (fst::ArcIterator<CompactLattice> aiter(aligned_clat, s);
!aiter.Done(); aiter.Next()) {
const CompactLatticeArc &arc (aiter.Value());
KALDI_ASSERT(arc.ilabel == arc.olabel);
int32 word_id = arc.ilabel;
const std::vector<int32> &tids = arc.weight.String();
if (word_id == 0 && tids.empty()) continue; // We allow epsilon arcs.
if (num_err < max_err)
if (!IsPlausibleWord(lexicon_info, tmodel, word_id, tids))
num_err++;
// Note: IsPlausibleWord will warn if there is an error.
}
if (!aligned_clat.Final(s).String().empty()) {
KALDI_WARN << "Aligned lattice has nonempty string on its final-prob.";
return false;
}
}
// Next we'll do an equivalence test.
// First map symbols into equivalence classes, so that we don't wrongly fail
// due to the capability of the framework to map words to other words.
// (e.g. mapping <eps> on silence arcs to SIL).
MapSymbols(lexicon_info, &clat);
MapSymbols(lexicon_info, &aligned_clat);
/// Check equivalence.
int32 num_paths = 5, seed = Rand(), max_path_length = -1;
BaseFloat delta = 0.2; // some lattices have large costs -> use large delta.
FLAGS_v = GetVerboseLevel(); // set the OpenFst verbose level to the Kaldi
// verbose level.
if (!RandEquivalent(clat, aligned_clat, num_paths, delta, seed, max_path_length)) {
KALDI_WARN << "Equivalence test failed during lattice alignment.";
return false;
}
FLAGS_v = 0;
return (num_err == 0);
}
} // namespace kaldi