FunASR/runtime/onnxruntime/third_party/kaldi/decoder/lattice-incremental-decoder.cc

// decoder/lattice-incremental-decoder.cc
// Copyright 2019 Zhehuai Chen, Daniel Povey
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "decoder/lattice-incremental-decoder.h"
#include "lat/lattice-functions.h"
#include "base/timer.h"
namespace kaldi {
// instantiate this class once for each thing you have to decode.
template <typename FST, typename Token>
LatticeIncrementalDecoderTpl<FST, Token>::LatticeIncrementalDecoderTpl(
const FST &fst, const TransitionInformation &trans_model,
const LatticeIncrementalDecoderConfig &config)
: fst_(&fst),
delete_fst_(false),
num_toks_(0),
config_(config),
determinizer_(trans_model, config) {
config.Check();
toks_.SetSize(1000); // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeIncrementalDecoderTpl<FST, Token>::LatticeIncrementalDecoderTpl(
const LatticeIncrementalDecoderConfig &config, FST *fst,
const TransitionInformation &trans_model)
: fst_(fst),
delete_fst_(true),
num_toks_(0),
config_(config),
determinizer_(trans_model, config) {
config.Check();
toks_.SetSize(1000); // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeIncrementalDecoderTpl<FST, Token>::~LatticeIncrementalDecoderTpl() {
DeleteElems(toks_.Clear());
ClearActiveTokens();
if (delete_fst_) delete fst_;
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::InitDecoding() {
// clean up from last time:
DeleteElems(toks_.Clear());
cost_offsets_.clear();
ClearActiveTokens();
warned_ = false;
num_toks_ = 0;
decoding_finalized_ = false;
final_costs_.clear();
StateId start_state = fst_->Start();
KALDI_ASSERT(start_state != fst::kNoStateId);
active_toks_.resize(1);
Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
active_toks_[0].toks = start_tok;
toks_.Insert(start_state, start_tok);
num_toks_++;
determinizer_.Init();
num_frames_in_lattice_ = 0;
token2label_map_.clear();
next_token_label_ = LatticeIncrementalDeterminizer::kTokenLabelOffset;
ProcessNonemitting(config_.beam);
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::UpdateLatticeDeterminization() {
if (NumFramesDecoded() - num_frames_in_lattice_ <
config_.determinize_max_delay)
return;
/* Make sure the token-pruning is active. Note: PruneActiveTokens() has
internal logic that prevents it from doing unnecessary work if you
call it and then immediately call it again. */
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
int32 first = num_frames_in_lattice_ + config_.determinize_min_chunk_size,
last = NumFramesDecoded(),
fewest_tokens = std::numeric_limits<int32>::max(),
best_frame = -1;
for (int32 t = last; t >= first; t--) {
/* Make sure PruneActiveTokens() has computed num_toks for all these
frames... */
KALDI_ASSERT(active_toks_[t].num_toks != -1);
if (active_toks_[t].num_toks < fewest_tokens) {
// Strict "<" plus the downward loop over t keeps the latest frame in case of ties.
fewest_tokens = active_toks_[t].num_toks;
best_frame = t;
}
}
/* Skip this update if we have too many tokens: determinization would take
too long, so postpone it to the next update. */
if (fewest_tokens > config_.determinize_max_active)
return;
/* OK, determinize the chunk that spans from num_frames_in_lattice_ to
best_frame. */
bool use_final_probs = false;
GetLattice(best_frame, use_final_probs);
return;
}
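/* Example with hypothetical config values: suppose determinize_max_delay =
60, determinize_min_chunk_size = 20, num_frames_in_lattice_ = 100 and
NumFramesDecoded() = 160. The delay has just reached 60, so we proceed; the
loop above scans frames 120..160 and picks as best_frame the latest frame
with the fewest surviving tokens, say t = 137. GetLattice(137, false) then
determinizes the chunk covering frames 100..137 and sets
num_frames_in_lattice_ = 137. */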
// Returns true if any kind of traceback is available (not necessarily from
// a final state). It should only very rarely return false; this indicates
// an unusual search error.
template <typename FST, typename Token>
bool LatticeIncrementalDecoderTpl<FST, Token>::Decode(DecodableInterface *decodable) {
InitDecoding();
// We use 1-based indexing for frames in this decoder (if you view it in
// terms of features), but note that the decodable object uses zero-based
// numbering, which we have to correct for when we call it.
while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
if (NumFramesDecoded() % config_.prune_interval == 0) {
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
}
UpdateLatticeDeterminization();
BaseFloat cost_cutoff = ProcessEmitting(decodable);
ProcessNonemitting(cost_cutoff);
}
Timer timer;
FinalizeDecoding();
bool use_final_probs = true;
GetLattice(NumFramesDecoded(), use_final_probs);
KALDI_VLOG(2) << "Delay time during and after FinalizeDecoding() "
<< "(secs): " << timer.Elapsed();
// Returns true if we have any kind of traceback available (not necessarily
// to the end state; query ReachedFinal() for that).
return !active_toks_.empty() && active_toks_.back().toks != NULL;
}
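/* A minimal usage sketch for the non-incremental API (illustrative only: the
construction of `decodable` is application-specific, and `hclg`,
`trans_model` and `config` are assumed to exist; LatticeIncrementalDecoder
is the typedef over this template declared in the header):

  LatticeIncrementalDecoder decoder(hclg, trans_model, config);
  decoder.Decode(&decodable);  // decodes the full utterance and finalizes.
  const CompactLattice &clat =
      decoder.GetLattice(decoder.NumFramesDecoded(), true);

For streaming use, prefer InitDecoding() + AdvanceDecoding(); see the sketch
after AdvanceDecoding() below. */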
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
size_t new_sz =
static_cast<size_t>(static_cast<BaseFloat>(num_toks) * config_.hash_ratio);
if (new_sz > toks_.Size()) {
toks_.SetSize(new_sz);
}
}
/*
A note on the definition of extra_cost.
extra_cost is used in pruning tokens, to save memory.
extra_cost can be thought of as a beta (backward) cost assuming
we had set the betas on currently-active tokens to all be the negative
of the alphas for those tokens. (So all currently active tokens would
be on (tied) best paths).
Define the 'forward cost' of a token as zero for any token on the frame
we're currently decoding; and for other frames, as the shortest-path cost
between that token and a token on the frame we're currently decoding.
(by "currently decoding" I mean the most recently processed frame).
Then define the extra_cost of a token (always >= 0) as the forward-cost of
the token minus the smallest forward-cost of any token on the same frame.
We can use the extra_cost to accurately prune away tokens that we know will
never appear in the lattice. If the extra_cost is greater than the desired
lattice beam, the token would provably never appear in the lattice, so we can
prune away the token.
The advantage of storing the extra_cost rather than the forward-cost, is that
it is less costly to keep the extra_cost up-to-date when we process new frames.
When we process a new frame, *all* the previous frames' forward-costs would change;
but in general the extra_cost will change only for a finite number of frames.
(Actually we don't update all the extra_costs every time we update a frame; we
only do it every 'config_.prune_interval' frames).
*/
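/* Worked example (hypothetical costs): suppose on some earlier frame token T
has forward-cost 3.0 while the best token on that same frame has
forward-cost 1.5. Then extra_cost(T) = 3.0 - 1.5 = 1.5. With lattice_beam =
1.0, token T provably cannot appear in the lattice and may be pruned; with
lattice_beam = 2.0 it must be kept. */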
// FindOrAddToken either locates a token in the hash toks_, or if necessary
// inserts a new, empty token (i.e. with no forward links) for the current
// frame. [Note: if inserted, it goes into both the hash toks_ and the singly
// linked list of tokens active on this frame (whose head is at
// active_toks_[frame]).]
template <typename FST, typename Token>
inline Token *LatticeIncrementalDecoderTpl<FST, Token>::FindOrAddToken(
StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer,
bool *changed) {
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
// if the token was newly created or the cost changed.
KALDI_ASSERT(frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks;
Elem *e_found = toks_.Find(state);
if (e_found == NULL) { // no such token presently.
const BaseFloat extra_cost = 0.0;
// tokens on the currently final frame have zero extra_cost
// as any of them could end up
// on the winning path.
Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer);
// NULL: no forward links yet
toks = new_tok;
num_toks_++;
toks_.Insert(state, new_tok);
if (changed) *changed = true;
return new_tok;
} else {
Token *tok = e_found->val; // There is an existing Token for this state.
if (tok->tot_cost > tot_cost) { // replace old token
tok->tot_cost = tot_cost;
// SetBackpointer() just does tok->backpointer = backpointer in
// the case where Token == BackpointerToken, else nothing.
tok->SetBackpointer(backpointer);
// we don't allocate a new token, the old stays linked in active_toks_
// we only replace the tot_cost
// in the current frame, there are no forward links (and no extra_cost)
// only in ProcessNonemitting we have to delete forward links
// in case we visit a state for the second time
// those forward links, that lead to this replaced token before:
// they remain and will hopefully be pruned later (PruneForwardLinks...)
if (changed) *changed = true;
} else {
if (changed) *changed = false;
}
return tok;
}
}
// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
// all links, that have link_extra_cost > lattice_beam are pruned
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PruneForwardLinks(
int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned,
BaseFloat delta) {
// delta is the amount by which the extra_costs must change
// If delta is larger, we'll tend to go back less far
// toward the beginning of the file.
// extra_costs_changed is set to true if extra_cost was changed for any token
// links_pruned is set to true if any link in any token was pruned
*extra_costs_changed = false;
*links_pruned = false;
KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen.
if (!warned_) {
KALDI_WARN << "No tokens alive [doing pruning].. warning first "
"time only for each utterance\n";
warned_ = true;
}
}
// We have to iterate until there is no more change, because the links
// are not guaranteed to be in topological order.
bool changed = true; // difference new minus old extra cost >= delta ?
while (changed) {
changed = false;
for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL;
tok = tok->next) {
ForwardLinkT *link, *prev_link = NULL;
// will recompute tok_extra_cost for tok.
BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
// tok_extra_cost is the best (min) of link_extra_cost of outgoing links
for (link = tok->links; link != NULL;) {
// See if we need to excise this link...
Token *next_tok = link->next_tok;
BaseFloat link_extra_cost =
next_tok->extra_cost +
((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
next_tok->tot_cost); // difference in brackets is >= 0
// link_extra_cost is the difference in score between the best paths
// through link source state and through link destination state
KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN
if (link_extra_cost > config_.lattice_beam) { // excise link
ForwardLinkT *next_link = link->next;
if (prev_link != NULL)
prev_link->next = next_link;
else
tok->links = next_link;
delete link;
link = next_link; // advance link but leave prev_link the same.
*links_pruned = true;
} else { // keep the link and update the tok_extra_cost if needed.
if (link_extra_cost < 0.0) { // this is just a precaution.
if (link_extra_cost < -0.01)
KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
link_extra_cost = 0.0;
}
if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost;
prev_link = link; // move to next link
link = link->next;
}
} // for all outgoing links
if (fabs(tok_extra_cost - tok->extra_cost) > delta)
changed = true; // difference new minus old is bigger than delta
tok->extra_cost = tok_extra_cost;
// will be +infinity or <= lattice_beam_.
// infinity indicates, that no forward link survived pruning
} // for all Token on active_toks_[frame]
if (changed) *extra_costs_changed = true;
// Note: it's theoretically possible that aggressive compiler
// optimizations could cause an infinite loop here for small delta and
// high-dynamic-range scores.
} // while changed
}
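/* Worked example of the link_extra_cost formula above (hypothetical
numbers): tok->tot_cost = 10.0, link->acoustic_cost + link->graph_cost =
2.5, next_tok->tot_cost = 12.0, next_tok->extra_cost = 0.3. The bracketed
difference is (10.0 + 2.5) - 12.0 = 0.5 >= 0, so link_extra_cost = 0.3 +
0.5 = 0.8. With lattice_beam = 0.7 this link would be excised; with
lattice_beam = 1.0 it survives and can lower tok_extra_cost to 0.8. */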
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame. If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PruneForwardLinksFinal() {
KALDI_ASSERT(!active_toks_.empty());
int32 frame_plus_one = active_toks_.size() - 1;
if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen.
KALDI_WARN << "No tokens alive at end of file";
typedef typename unordered_map<Token *, BaseFloat>::const_iterator IterType;
ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
decoding_finalized_ = true;
// We call DeleteElems() as a nicety, not because it's really necessary;
// otherwise there would be a time, after calling PruneTokensForFrame() on the
// final frame, when toks_.GetList() or toks_.Clear() would contain pointers
// to nonexistent tokens.
DeleteElems(toks_.Clear());
// Now go through tokens on this frame, pruning forward links... may have to
// iterate a few times until there is no more change, because the list is not
// in topological order. This is a modified version of the code in
// PruneForwardLinks, but here we also take account of the final-probs.
bool changed = true;
BaseFloat delta = 1.0e-05;
while (changed) {
changed = false;
for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL;
tok = tok->next) {
ForwardLinkT *link, *prev_link = NULL;
// will recompute tok_extra_cost. It has a term in it that corresponds
// to the "final-prob", so instead of initializing tok_extra_cost to infinity
// below we set it to the difference between the (score+final_prob) of this
// token,
// and the best such (score+final_prob).
BaseFloat final_cost;
if (final_costs_.empty()) {
final_cost = 0.0;
} else {
IterType iter = final_costs_.find(tok);
if (iter != final_costs_.end())
final_cost = iter->second;
else
final_cost = std::numeric_limits<BaseFloat>::infinity();
}
BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
// tok_extra_cost will be a "min" over either directly being final, or
// being indirectly final through other links, and the loop below may
// decrease its value:
for (link = tok->links; link != NULL;) {
// See if we need to excise this link...
Token *next_tok = link->next_tok;
BaseFloat link_extra_cost =
next_tok->extra_cost +
((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
next_tok->tot_cost);
if (link_extra_cost > config_.lattice_beam) { // excise link
ForwardLinkT *next_link = link->next;
if (prev_link != NULL)
prev_link->next = next_link;
else
tok->links = next_link;
delete link;
link = next_link; // advance link but leave prev_link the same.
} else { // keep the link and update the tok_extra_cost if needed.
if (link_extra_cost < 0.0) { // this is just a precaution.
if (link_extra_cost < -0.01)
KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
link_extra_cost = 0.0;
}
if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost;
prev_link = link;
link = link->next;
}
}
// prune away tokens worse than lattice_beam above best path. This step
// was not necessary in the non-final case because then, this case
// showed up as having no forward links. Here, the tok_extra_cost has
// an extra component relating to the final-prob.
if (tok_extra_cost > config_.lattice_beam)
tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
// to be pruned in PruneTokensForFrame
if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true;
tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_.
}
} // while changed
}
template <typename FST, typename Token>
BaseFloat LatticeIncrementalDecoderTpl<FST, Token>::FinalRelativeCost() const {
BaseFloat relative_cost;
ComputeFinalCosts(NULL, &relative_cost, NULL);
return relative_cost;
}
// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PruneTokensForFrame(
int32 frame_plus_one) {
KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks;
if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]";
Token *tok, *next_tok, *prev_tok = NULL;
int32 num_toks = 0;
for (tok = toks; tok != NULL; tok = next_tok, num_toks++) {
next_tok = tok->next;
if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
// token is unreachable from end of graph; (no forward links survived)
// excise tok from list and delete tok.
if (prev_tok != NULL)
prev_tok->next = tok->next;
else
toks = tok->next;
delete tok;
num_toks_--;
} else { // fetch next Token
prev_tok = tok;
}
}
active_toks_[frame_plus_one].num_toks = num_toks;
}
// Go backwards through still-alive tokens, pruning them, starting not from
// the current frame (where we want to keep all tokens) but from the frame before
// that. We go backwards through the frames and stop when we reach a point
// where the delta-costs are not changing (and the delta controls when we consider
// a cost to have "not changed").
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::PruneActiveTokens(BaseFloat delta) {
int32 cur_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
if (active_toks_[cur_frame_plus_one].num_toks == -1) {
// The current frame's tokens don't get pruned so they don't get counted
// (the count is needed by the incremental determinization code).
// Fix this.
int32 this_frame_num_toks = 0;
for (Token *t = active_toks_[cur_frame_plus_one].toks; t != NULL; t = t->next)
this_frame_num_toks++;
active_toks_[cur_frame_plus_one].num_toks = this_frame_num_toks;
}
// The index "f" below represents a "frame plus one", i.e. you'd have to subtract
// one to get the corresponding index for the decodable object.
for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
// Reason why we need to prune forward links in this situation:
// (1) we have never pruned them (new TokenList)
// (2) we have not yet pruned the forward links to the next f,
// after any of those tokens have changed their extra_cost.
if (active_toks_[f].must_prune_forward_links) {
bool extra_costs_changed = false, links_pruned = false;
PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
if (extra_costs_changed && f > 0) // any token has changed extra_cost
active_toks_[f - 1].must_prune_forward_links = true;
if (links_pruned) // any link was pruned
active_toks_[f].must_prune_tokens = true;
active_toks_[f].must_prune_forward_links = false; // job done
}
if (f + 1 < cur_frame_plus_one && // except for last f (no forward links)
active_toks_[f + 1].must_prune_tokens) {
PruneTokensForFrame(f + 1);
active_toks_[f + 1].must_prune_tokens = false;
}
}
KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
<< " to " << num_toks_;
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::ComputeFinalCosts(
unordered_map<Token *, BaseFloat> *final_costs, BaseFloat *final_relative_cost,
BaseFloat *final_best_cost) const {
if (decoding_finalized_) {
// If we finalized decoding, the list toks_ will no longer exist, so return
// something we already computed.
if (final_costs) *final_costs = final_costs_;
if (final_relative_cost) *final_relative_cost = final_relative_cost_;
if (final_best_cost) *final_best_cost = final_best_cost_;
return;
}
if (final_costs != NULL) final_costs->clear();
const Elem *final_toks = toks_.GetList();
BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
BaseFloat best_cost = infinity, best_cost_with_final = infinity;
while (final_toks != NULL) {
StateId state = final_toks->key;
Token *tok = final_toks->val;
const Elem *next = final_toks->tail;
BaseFloat final_cost = fst_->Final(state).Value();
BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost;
best_cost = std::min(cost, best_cost);
best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
if (final_costs != NULL && final_cost != infinity)
(*final_costs)[tok] = final_cost;
final_toks = next;
}
if (final_relative_cost != NULL) {
if (best_cost == infinity && best_cost_with_final == infinity) {
// Likely this will only happen if there are no tokens surviving.
// This seems the least bad way to handle it.
*final_relative_cost = infinity;
} else {
*final_relative_cost = best_cost_with_final - best_cost;
}
}
if (final_best_cost != NULL) {
if (best_cost_with_final != infinity) { // final-state exists.
*final_best_cost = best_cost_with_final;
} else { // no final-state exists.
*final_best_cost = best_cost;
}
}
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::AdvanceDecoding(
DecodableInterface *decodable, int32 max_num_frames) {
if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
// if the type 'FST' is the FST base-class, then see if the FST type of fst_
// is actually VectorFst or ConstFst. If so, call the AdvanceDecoding()
// function after casting *this to the more specific type.
if (fst_->Type() == "const") {
LatticeIncrementalDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
reinterpret_cast<
LatticeIncrementalDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *>(
this);
this_cast->AdvanceDecoding(decodable, max_num_frames);
return;
} else if (fst_->Type() == "vector") {
LatticeIncrementalDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
reinterpret_cast<
LatticeIncrementalDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *>(
this);
this_cast->AdvanceDecoding(decodable, max_num_frames);
return;
}
}
KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
"You must call InitDecoding() before AdvanceDecoding");
int32 num_frames_ready = decodable->NumFramesReady();
// num_frames_ready must be >= num_frames_decoded, or else
// the number of frames ready must have decreased (which doesn't
// make sense) or the decodable object changed between calls
// (which isn't allowed).
KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
int32 target_frames_decoded = num_frames_ready;
if (max_num_frames >= 0)
target_frames_decoded =
std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames);
while (NumFramesDecoded() < target_frames_decoded) {
if (NumFramesDecoded() % config_.prune_interval == 0) {
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
}
BaseFloat cost_cutoff = ProcessEmitting(decodable);
ProcessNonemitting(cost_cutoff);
}
UpdateLatticeDeterminization();
}
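/* A minimal streaming sketch, assuming an application-provided online
decodable object (the AcceptFeatures() helper below is hypothetical, not
part of this file):

  decoder.InitDecoding();
  while (audio_remains) {
    decodable.AcceptFeatures(...);        // hypothetical: make frames ready.
    decoder.AdvanceDecoding(&decodable);  // decodes all frames that are ready.
    if (want_partial_result) {
      const CompactLattice &partial =
          decoder.GetLattice(decoder.NumFramesDecoded(), false);
      // ... use the partial lattice ...
    }
  }
  decoder.FinalizeDecoding();
  const CompactLattice &final_clat =
      decoder.GetLattice(decoder.NumFramesDecoded(), true);
*/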
// FinalizeDecoding() is a version of PruneActiveTokens that we call
// (optionally) on the final frame. Takes into account the final-prob of
// tokens. This function used to be called PruneActiveTokensFinal().
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::FinalizeDecoding() {
int32 final_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
// PruneForwardLinksFinal() prunes the final frame (with final-probs), and
// sets decoding_finalized_.
PruneForwardLinksFinal();
for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
bool b1, b2; // values not used.
BaseFloat dontcare = 0.0; // delta of zero means we must always update
PruneForwardLinks(f, &b1, &b2, dontcare);
PruneTokensForFrame(f + 1);
}
PruneTokensForFrame(0);
KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to " << num_toks_;
}
/// Gets the weight cutoff. Also counts the active tokens.
template <typename FST, typename Token>
BaseFloat LatticeIncrementalDecoderTpl<FST, Token>::GetCutoff(
Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem) {
BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
// positive == high cost == bad.
size_t count = 0;
if (config_.max_active == std::numeric_limits<int32>::max() &&
config_.min_active == 0) {
for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
if (w < best_weight) {
best_weight = w;
if (best_elem) *best_elem = e;
}
}
if (tok_count != NULL) *tok_count = count;
if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
return best_weight + config_.beam;
} else {
tmp_array_.clear();
for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
BaseFloat w = e->val->tot_cost;
tmp_array_.push_back(w);
if (w < best_weight) {
best_weight = w;
if (best_elem) *best_elem = e;
}
}
if (tok_count != NULL) *tok_count = count;
BaseFloat beam_cutoff = best_weight + config_.beam,
min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
<< " is " << tmp_array_.size();
if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.max_active,
tmp_array_.end());
max_active_cutoff = tmp_array_[config_.max_active];
}
if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam.
if (adaptive_beam)
*adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
return max_active_cutoff;
}
if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
if (config_.min_active == 0)
min_active_cutoff = best_weight;
else {
std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.min_active,
tmp_array_.size() > static_cast<size_t>(config_.max_active)
? tmp_array_.begin() + config_.max_active
: tmp_array_.end());
min_active_cutoff = tmp_array_[config_.min_active];
}
}
if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
if (adaptive_beam)
*adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
return min_active_cutoff;
} else {
if (adaptive_beam) *adaptive_beam = config_.beam;
return beam_cutoff;
}
}
}
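/* Example (hypothetical values): best_weight = 100 and beam = 16 give
beam_cutoff = 116. If max_active = 7000 and the 7000th-best token cost is
110, then max_active_cutoff = 110 < 116, so max_active is the binding
constraint: we return 110 and set adaptive_beam = 110 - 100 + beam_delta,
slightly wider than the effective pruning width so that ProcessEmitting()
still has a margin around newly created best tokens. */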
template <typename FST, typename Token>
BaseFloat LatticeIncrementalDecoderTpl<FST, Token>::ProcessEmitting(
DecodableInterface *decodable) {
KALDI_ASSERT(active_toks_.size() > 0);
int32 frame = active_toks_.size() - 1; // frame is the frame-index
// (zero-based) used to get likelihoods
// from the decodable object.
active_toks_.resize(active_toks_.size() + 1);
Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_
// in simple-decoder.h. Removes the Elems from
// being indexed in the hash in toks_.
Elem *best_elem = NULL;
BaseFloat adaptive_beam;
size_t tok_cnt;
BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
<< adaptive_beam;
PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough.
BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
// pruning "online" before having seen all tokens
BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
// dynamic range.
// First process the best token to get a hopefully
// reasonably tight bound on the next cutoff. The only
// products of the next block are "next_cutoff" and "cost_offset".
if (best_elem) {
StateId state = best_elem->key;
Token *tok = best_elem->val;
cost_offset = -tok->tot_cost;
for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // propagate..
BaseFloat new_weight = arc.weight.Value() + cost_offset -
decodable->LogLikelihood(frame, arc.ilabel) +
tok->tot_cost;
if (new_weight + adaptive_beam < next_cutoff)
next_cutoff = new_weight + adaptive_beam;
}
}
}
// Store the offset on the acoustic likelihoods that we're applying.
// Could just do cost_offsets_.push_back(cost_offset), but we
// do it this way as it's more robust to future code changes.
cost_offsets_.resize(frame + 1, 0.0);
cost_offsets_[frame] = cost_offset;
// the tokens are now owned here, in final_toks, and the hash is empty.
// 'owned' is a complex thing here; the point is we need to call DeleteElem
// on each elem 'e' to let toks_ know we're done with them.
for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
// loop this way because we delete "e" as we go.
StateId state = e->key;
Token *tok = e->val;
if (tok->tot_cost <= cur_cutoff) {
for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // propagate..
BaseFloat ac_cost =
cost_offset - decodable->LogLikelihood(frame, arc.ilabel),
graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost,
tot_cost = cur_cost + ac_cost + graph_cost;
if (tot_cost >= next_cutoff)
continue;
else if (tot_cost + adaptive_beam < next_cutoff)
next_cutoff = tot_cost + adaptive_beam; // prune by best current token
// Note: the frame indexes into active_toks_ are one-based,
// hence the + 1.
Token *next_tok =
FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL);
// NULL: no change indicator needed
// Add ForwardLink from tok to next_tok (put on head of list tok->links)
tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost,
ac_cost, tok->links);
}
} // for all arcs
}
e_tail = e->tail;
toks_.Delete(e); // delete Elem
}
return next_cutoff;
}
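/* Note on cost_offset (illustrative): tot_costs grow roughly linearly with
the frame index, so if the best token entering this frame has tot_cost =
5000.0 we use cost_offset = -5000.0, which is folded into every acoustic
cost on this frame; newly created tokens then have tot_cost near zero. The
per-frame offsets are remembered in cost_offsets_ and removed again when
the raw lattice chunk is built (see `l->acoustic_cost - this_offset` in
GetLattice()). */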
// static inline
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
ForwardLinkT *l = tok->links, *m;
while (l != NULL) {
m = l->next;
delete l;
l = m;
}
tok->links = NULL;
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
KALDI_ASSERT(!active_toks_.empty());
int32 frame = static_cast<int32>(active_toks_.size()) - 2;
// Note: "frame" is the time-index we just processed, or -1 if
// we are processing the nonemitting transitions before the
// first frame (called from InitDecoding()).
// Processes nonemitting arcs for one frame. Propagates within toks_.
// Note-- this queue structure is not very optimal as
// it may cause us to process states unnecessarily (e.g. more than once),
// but in the baseline code, turning this vector into a set to fix this
// problem did not improve overall speed.
KALDI_ASSERT(queue_.empty());
if (toks_.GetList() == NULL) {
if (!warned_) {
KALDI_WARN << "Error, no surviving tokens: frame is " << frame;
warned_ = true;
}
}
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
StateId state = e->key;
if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(state);
}
while (!queue_.empty()) {
StateId state = queue_.back();
queue_.pop_back();
// Note: this would segfault if the state were not in toks_, but that can't happen.
Token *tok = toks_.Find(state)->val;
BaseFloat cur_cost = tok->tot_cost;
if (cur_cost >= cutoff) // Don't bother processing successors.
continue;
// If "tok" has any existing forward links, delete them,
// because we're about to regenerate them. This is a kind
// of non-optimality (remember, this is the simple decoder),
// but since most states are emitting it's not a huge issue.
DeleteForwardLinks(tok); // necessary when re-visiting
tok->links = NULL;
for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel == 0) { // propagate nonemitting only...
BaseFloat graph_cost = arc.weight.Value(), tot_cost = cur_cost + graph_cost;
if (tot_cost < cutoff) {
bool changed;
Token *new_tok =
FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed);
tok->links =
new ForwardLinkT(new_tok, 0, arc.olabel, graph_cost, 0, tok->links);
// "changed" tells us whether the new token has a different
// cost from before, or is new [if so, add into queue].
if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
queue_.push_back(arc.nextstate);
}
}
} // for all arcs
} // while queue not empty
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<FST, Token>::DeleteElems(Elem *list) {
for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
e_tail = e->tail;
toks_.Delete(e);
}
}
template <typename FST, typename Token>
void LatticeIncrementalDecoderTpl<
FST, Token>::ClearActiveTokens() { // a cleanup routine, at utt end/begin
for (size_t i = 0; i < active_toks_.size(); i++) {
// Delete all tokens alive on this frame, and any forward
// links they may have.
for (Token *tok = active_toks_[i].toks; tok != NULL;) {
DeleteForwardLinks(tok);
Token *next_tok = tok->next;
delete tok;
num_toks_--;
tok = next_tok;
}
}
active_toks_.clear();
KALDI_ASSERT(num_toks_ == 0);
}
template <typename FST, typename Token>
const CompactLattice& LatticeIncrementalDecoderTpl<FST, Token>::GetLattice(
int32 num_frames_to_include,
bool use_final_probs) {
KALDI_ASSERT(num_frames_to_include >= num_frames_in_lattice_ &&
num_frames_to_include <= NumFramesDecoded());
if (num_frames_in_lattice_ > 0 &&
determinizer_.GetLattice().NumStates() == 0) {
/* Something went wrong, lattice is empty and will continue to be empty.
User-level code should detect and deal with this.
*/
num_frames_in_lattice_ = num_frames_to_include;
return determinizer_.GetLattice();
}
if (decoding_finalized_ && !use_final_probs) {
// This is not supported
KALDI_ERR << "You cannot get the lattice without final-probs after "
"calling FinalizeDecoding().";
}
if (use_final_probs && num_frames_to_include != NumFramesDecoded()) {
/* This is because we only remember the relation between HCLG states and
Tokens for the current frame; the Token does not have a `state` field. */
KALDI_ERR << "use-final-probs may no be true if you are not "
"getting a lattice for all frames decoded so far.";
}
if (num_frames_to_include > num_frames_in_lattice_) {
/* Make sure the token-pruning is up to date. If we just pruned the tokens,
this will do very little work. */
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
if (determinizer_.GetLattice().NumStates() == 0 ||
determinizer_.GetLattice().Final(0) != CompactLatticeWeight::Zero()) {
num_frames_in_lattice_ = 0;
determinizer_.Init();
}
Lattice chunk_lat;
unordered_map<Label, LatticeArc::StateId> token_label2state;
if (num_frames_in_lattice_ != 0) {
determinizer_.InitializeRawLatticeChunk(&chunk_lat,
&token_label2state);
}
// tok2state_map will map from Token* to state-id in chunk_lat.
// The cur and prev versions alternate on different frames.
unordered_map<Token*, StateId> &tok2state_map(temp_token_map_);
tok2state_map.clear();
unordered_map<Token*, Label> &next_token2label_map(token2label_map_temp_);
next_token2label_map.clear();
{ // Deal with the last frame in the chunk, the one numbered `num_frames_to_include`.
// (Yes, this is backwards). We allocate token labels, and set tokens as
// final, but don't add any transitions. This may leave some states
// disconnected (e.g. due to chains of nonemitting arcs), but it's OK; we'll
// fix it when we generate the next chunk of lattice.
int32 frame = num_frames_to_include;
// Allocate state-ids for all tokens on this frame.
for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
/* If we included the final-costs at this stage, they will cause
non-final states to be pruned out from the end of the lattice. */
BaseFloat final_cost;
{ // This block computes final_cost
if (decoding_finalized_) {
if (final_costs_.empty()) {
final_cost = 0.0; /* No final-state survived, so treat all as final
* with probability One(). */
} else {
auto iter = final_costs_.find(tok);
if (iter == final_costs_.end())
final_cost = std::numeric_limits<BaseFloat>::infinity();
else
final_cost = iter->second;
}
} else {
/* this is a `fake` final-cost used to guide pruning. It's as if we
set the betas (backward-probs) on the final frame to the
negatives of the corresponding alphas, so all tokens on the last
frame will be on a best path; the extra_cost for each token
always corresponds to its alpha+beta on this assumption. We want
the final_cost here to correspond to the beta (backward-prob), so
we get that by final_cost = extra_cost - tot_cost.
[The tot_cost is the forward/alpha cost.]
*/
final_cost = tok->extra_cost - tok->tot_cost;
}
}
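/* Worked illustration of the `fake` final-cost above (hypothetical numbers):
if tok->tot_cost (the alpha) is 12.0 and tok->extra_cost is 0.4, then
final_cost = 0.4 - 12.0 = -11.6, which plays the role of the beta
(backward cost) for this token, so that alpha + beta = extra_cost. */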
StateId state = chunk_lat.AddState();
tok2state_map[tok] = state;
if (final_cost < std::numeric_limits<BaseFloat>::infinity()) {
next_token2label_map[tok] = AllocateNewTokenLabel();
StateId token_final_state = chunk_lat.AddState();
LatticeArc::Label ilabel = 0,
olabel = next_token2label_map[tok];
chunk_lat.AddArc(state,
LatticeArc(ilabel, olabel,
LatticeWeight::One(),
token_final_state));
chunk_lat.SetFinal(token_final_state, LatticeWeight(final_cost, 0.0));
}
}
}
// Go in reverse order over the remaining frames so we can create arcs as we
// go, and their destination-states will already be in the map.
for (int32 frame = num_frames_to_include;
frame >= num_frames_in_lattice_; frame--) {
// The conditional below is needed for the last frame of the utterance.
BaseFloat cost_offset = (frame < cost_offsets_.size() ?
cost_offsets_[frame] : 0.0);
// For the first frame of the chunk, we need to make sure the states are
// the ones created by InitializeRawLatticeChunk() (where not pruned away).
if (frame == num_frames_in_lattice_ && num_frames_in_lattice_ != 0) {
for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
auto iter = token2label_map_.find(tok);
KALDI_ASSERT(iter != token2label_map_.end());
Label token_label = iter->second;
auto iter2 = token_label2state.find(token_label);
if (iter2 != token_label2state.end()) {
StateId state = iter2->second;
tok2state_map[tok] = state;
} else {
// Some states may have been pruned out, but we should still allocate
// them. They might have been part of chains of nonemitting arcs
// where the state became disconnected because the last chunk didn't
// include arcs starting at this frame.
StateId state = chunk_lat.AddState();
tok2state_map[tok] = state;
}
}
} else if (frame != num_frames_to_include) { // We already created states
// for the last frame.
for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
StateId state = chunk_lat.AddState();
tok2state_map[tok] = state;
}
}
for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
auto iter = tok2state_map.find(tok);
KALDI_ASSERT(iter != tok2state_map.end());
StateId cur_state = iter->second;
for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) {
auto next_iter = tok2state_map.find(l->next_tok);
if (next_iter == tok2state_map.end()) {
// Emitting arcs from the last frame we're including -- ignore
// these.
KALDI_ASSERT(frame == num_frames_to_include);
continue;
}
StateId next_state = next_iter->second;
BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0);
LatticeArc arc(l->ilabel, l->olabel,
LatticeWeight(l->graph_cost, l->acoustic_cost - this_offset),
next_state);
// Note: the epsilons get redundantly included at the end and beginning
// of successive chunks. These will get removed in the determinization.
chunk_lat.AddArc(cur_state, arc);
}
}
}
if (num_frames_in_lattice_ == 0) {
// This block locates the start token. NOTE: we use the fact that in the
// linked list of tokens, things are added at the head, so the start state
// must be at the tail. If this data structure is changed in future, we
// might need to explicitly store the start token as a class member.
Token *tok = active_toks_[0].toks;
if (tok == NULL) {
KALDI_WARN << "No tokens exist on start frame";
return determinizer_.GetLattice(); // will be empty.
}
while (tok->next != NULL)
tok = tok->next;
Token *start_token = tok;
auto iter = tok2state_map.find(start_token);
KALDI_ASSERT(iter != tok2state_map.end());
StateId start_state = iter->second;
chunk_lat.SetStart(start_state);
}
token2label_map_.swap(next_token2label_map);
determinizer_.AcceptRawLatticeChunk(&chunk_lat);
// We ignore the return status, which says whether determinization finished
// before hitting the beam.
num_frames_in_lattice_ = num_frames_to_include;
if (determinizer_.GetLattice().NumStates() == 0)
return determinizer_.GetLattice(); // Something went wrong, lattice is empty.
}
unordered_map<Token*, BaseFloat> token2final_cost;
unordered_map<Label, BaseFloat> token_label2final_cost;
if (use_final_probs) {
ComputeFinalCosts(&token2final_cost, NULL, NULL);
for (const auto &p: token2final_cost) {
Token *tok = p.first;
BaseFloat cost = p.second;
auto iter = token2label_map_.find(tok);
if (iter != token2label_map_.end()) {
/* Some tokens may not have survived the pruned determinization. */
Label token_label = iter->second;
bool ret = token_label2final_cost.insert({token_label, cost}).second;
KALDI_ASSERT(ret); /* Make sure it was inserted. */
}
}
}
/* Note: these final-probs won't affect the next chunk, only the lattice
returned from GetLattice(). They are kind of temporaries. */
determinizer_.SetFinalCosts(token_label2final_cost.empty() ? NULL :
&token_label2final_cost);
return determinizer_.GetLattice();
}
template <typename FST, typename Token>
int32 LatticeIncrementalDecoderTpl<FST, Token>::GetNumToksForFrame(int32 frame) {
int32 r = 0;
for (Token *tok = active_toks_[frame].toks; tok; tok = tok->next) r++;
return r;
}
/* This utility function adds an arc to a Lattice, but where the source is a
CompactLatticeArc. If the CompactLatticeArc has a string with length greater
than 1, this will require adding extra states to `lat`.
*/
static void AddCompactLatticeArcToLattice(
const CompactLatticeArc &clat_arc,
LatticeArc::StateId src_state,
Lattice *lat) {
const std::vector<int32> &string = clat_arc.weight.String();
size_t N = string.size();
if (N == 0) {
LatticeArc arc;
arc.ilabel = 0;
arc.olabel = clat_arc.ilabel;
arc.nextstate = clat_arc.nextstate;
arc.weight = clat_arc.weight.Weight();
lat->AddArc(src_state, arc);
} else {
LatticeArc::StateId cur_state = src_state;
for (size_t i = 0; i < N; i++) {
LatticeArc arc;
arc.ilabel = string[i];
arc.olabel = (i == 0 ? clat_arc.ilabel : 0);
arc.nextstate = (i + 1 == N ? clat_arc.nextstate : lat->AddState());
arc.weight = (i == 0 ? clat_arc.weight.Weight() : LatticeWeight::One());
lat->AddArc(cur_state, arc);
cur_state = arc.nextstate;
}
}
}
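/* Example: a CompactLatticeArc from src_state with word-label w, weight
string {t1, t2, t3} and weight component (g, a) becomes three LatticeArcs:
src_state --(t1:w, (g,a))--> s1 --(t2:0, One())--> s2 --(t3:0, One())-->
clat_arc.nextstate, where s1 and s2 are freshly added states; the word
label and the whole weight ride on the first arc only. */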
void LatticeIncrementalDeterminizer::Init() {
non_final_redet_states_.clear();
clat_.DeleteStates();
final_arcs_.clear();
forward_costs_.clear();
arcs_in_.clear();
}
CompactLattice::StateId LatticeIncrementalDeterminizer::AddStateToClat() {
CompactLattice::StateId ans = clat_.AddState();
forward_costs_.push_back(std::numeric_limits<BaseFloat>::infinity());
KALDI_ASSERT(forward_costs_.size() == ans + 1);
arcs_in_.resize(ans + 1);
return ans;
}
void LatticeIncrementalDeterminizer::AddArcToClat(
CompactLattice::StateId state,
const CompactLatticeArc &arc) {
BaseFloat forward_cost = forward_costs_[state] +
ConvertToCost(arc.weight);
if (forward_cost == std::numeric_limits<BaseFloat>::infinity())
return;
int32 arc_idx = clat_.NumArcs(state);
clat_.AddArc(state, arc);
arcs_in_[arc.nextstate].push_back({state, arc_idx});
if (forward_cost < forward_costs_[arc.nextstate])
forward_costs_[arc.nextstate] = forward_cost;
}
// See documentation in header
void LatticeIncrementalDeterminizer::IdentifyTokenFinalStates(
const CompactLattice &chunk_clat,
std::unordered_map<CompactLattice::StateId, CompactLatticeArc::Label> *token_map) const {
token_map->clear();
using StateId = CompactLattice::StateId;
using Label = CompactLatticeArc::Label;
StateId num_states = chunk_clat.NumStates();
for (StateId state = 0; state < num_states; state++) {
for (fst::ArcIterator<CompactLattice> aiter(chunk_clat, state);
!aiter.Done(); aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value();
if (arc.olabel >= kTokenLabelOffset && arc.olabel < kMaxTokenLabel) {
StateId nextstate = arc.nextstate;
auto r = token_map->insert({nextstate, arc.olabel});
// Check consistency of labels on incoming arcs
KALDI_ASSERT(r.first->second == arc.olabel);
}
}
}
}
void LatticeIncrementalDeterminizer::GetNonFinalRedetStates() {
using StateId = CompactLattice::StateId;
non_final_redet_states_.clear();
non_final_redet_states_.reserve(final_arcs_.size());
std::vector<StateId> state_queue;
for (const CompactLatticeArc &arc: final_arcs_) {
// Note: we abuse the .nextstate field to store the state which is really
// the source of that arc.
StateId redet_state = arc.nextstate;
if (forward_costs_[redet_state] != std::numeric_limits<BaseFloat>::infinity()) {
// if it is accessible..
if (non_final_redet_states_.insert(redet_state).second) {
// it was not already there
state_queue.push_back(redet_state);
}
}
}
// Add any states that are reachable from the states above.
while (!state_queue.empty()) {
StateId s = state_queue.back();
state_queue.pop_back();
for (fst::ArcIterator<CompactLattice> aiter(clat_, s); !aiter.Done();
aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value();
StateId nextstate = arc.nextstate;
if (non_final_redet_states_.insert(nextstate).second)
state_queue.push_back(nextstate); // it was not already there
}
}
}
void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk(
Lattice *olat,
unordered_map<Label, LatticeArc::StateId> *token_label2state) {
using namespace fst;
olat->DeleteStates();
LatticeArc::StateId start_state = olat->AddState();
olat->SetStart(start_state);
token_label2state->clear();
// redet_state_map maps from state-ids in clat_ to state-ids in olat. This
// will be the set of states from which the arcs to final-states in the
// canonical appended lattice leave (physically, these are in the .nextstate
// elements of final_arcs_, since we use that field for the source state),
// plus any states reachable from those states.
unordered_map<CompactLattice::StateId, LatticeArc::StateId> redet_state_map;
for (CompactLattice::StateId redet_state: non_final_redet_states_)
redet_state_map[redet_state] = olat->AddState();
// First, process any arcs leaving the non-final redeterminized states that
// are not to final-states. (By "not to final-states" we mean: not to states
// that are final in the `canonical appended lattice`; they may actually be
// physically final in clat_, because we make clat_ what we want to return
// to the user.)
for (CompactLattice::StateId redet_state: non_final_redet_states_) {
LatticeArc::StateId lat_state = redet_state_map[redet_state];
for (ArcIterator<CompactLattice> aiter(clat_, redet_state);
!aiter.Done(); aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value();
CompactLattice::StateId nextstate = arc.nextstate;
LatticeArc::StateId lat_nextstate = olat->NumStates();
auto r = redet_state_map.insert({nextstate, lat_nextstate});
if (r.second) { // Was inserted.
LatticeArc::StateId s = olat->AddState();
KALDI_ASSERT(s == lat_nextstate);
} else {
// was not inserted -> was already there.
lat_nextstate = r.first->second;
}
CompactLatticeArc clat_arc(arc);
clat_arc.nextstate = lat_nextstate;
AddCompactLatticeArcToLattice(clat_arc, lat_state, olat);
}
clat_.DeleteArcs(redet_state);
clat_.SetFinal(redet_state, CompactLatticeWeight::Zero());
}
for (const CompactLatticeArc &arc: final_arcs_) {
// We abuse the `nextstate` field to store the source state.
CompactLattice::StateId src_state = arc.nextstate;
auto iter = redet_state_map.find(src_state);
if (forward_costs_[src_state] == std::numeric_limits<BaseFloat>::infinity())
continue; /* Unreachable state */
KALDI_ASSERT(iter != redet_state_map.end());
LatticeArc::StateId src_lat_state = iter->second;
Label token_label = arc.ilabel; // will be == arc.olabel.
KALDI_ASSERT(token_label >= kTokenLabelOffset &&
token_label < kMaxTokenLabel);
auto r = token_label2state->insert({token_label,
olat->NumStates()});
LatticeArc::StateId dest_lat_state = r.first->second;
if (r.second) { // was inserted
LatticeArc::StateId new_state = olat->AddState();
KALDI_ASSERT(new_state == dest_lat_state);
}
CompactLatticeArc new_arc;
new_arc.nextstate = dest_lat_state;
/* We convert the token-label to epsilon; it's not needed anymore. */
new_arc.ilabel = new_arc.olabel = 0;
new_arc.weight = arc.weight;
AddCompactLatticeArcToLattice(new_arc, src_lat_state, olat);
}
// Now deal with the initial-probs. Arcs from initial-states to
// redeterminized-states in the raw lattice have an olabel that identifies the
// id of that redeterminized-state in clat_, and a cost that is derived from
// its entry in forward_costs_. These forward-probs are used to get the
// pruned lattice determinization to behave correctly, and will be canceled
// out later on.
//
// In the paper this is the second-from-last bullet in Sec. 5.2. NOTE: in the
// paper we state that we only include such arcs for "each redeterminized
// state that is either initial in det(A) or that has an arc entering it from
// a state that is not a redeterminized state." In fact, we include these
// arcs for all redeterminized states. I realized that it won't make a
// difference to the outcome, and it's easier to do it this way.
for (CompactLattice::StateId state_id: non_final_redet_states_) {
BaseFloat forward_cost = forward_costs_[state_id];
LatticeArc arc;
arc.ilabel = 0;
// The olabel (which appears where the word-id would) is what
// we call a 'state-label'. It identifies a state in clat_.
arc.olabel = state_id + kStateLabelOffset;
// It doesn't matter which field we put forward_cost in (or whether we
// divide it between them both); the effect on pruning is the same, and
// we will cancel it out later anyway.
arc.weight = LatticeWeight(forward_cost, 0);
auto iter = redet_state_map.find(state_id);
KALDI_ASSERT(iter != redet_state_map.end());
arc.nextstate = iter->second;
olat->AddArc(start_state, arc);
}
}
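/* Illustration (hypothetical): if clat_ state 42 is a non-final
redeterminized state with forward_costs_[42] = 7.3, we add to `olat` the
arc start_state --(0 : 42 + kStateLabelOffset, (7.3, 0))--> s, where s is
the raw-lattice copy of state 42. The state-label lets
ProcessArcsFromChunkStartState() find state 42 again after determinization,
and the 7.3 is subtracted back out there (see extra_weight_in); it exists
only so that pruned determinization sees correct total path costs. */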
void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts(
const Lattice &raw_fst,
std::unordered_map<Label, BaseFloat> *old_final_costs) {
LatticeArc::StateId raw_fst_num_states = raw_fst.NumStates();
for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) {
for (fst::ArcIterator<Lattice> aiter(raw_fst, s); !aiter.Done();
aiter.Next()) {
const LatticeArc &value = aiter.Value();
if (value.olabel >= (Label)kTokenLabelOffset &&
value.olabel < (Label)kMaxTokenLabel) {
LatticeWeight final_weight = raw_fst.Final(value.nextstate);
if (final_weight != LatticeWeight::Zero() &&
final_weight.Value2() != 0) {
KALDI_ERR << "Label " << value.olabel << " from state " << s
<< " looks like a token-label but its next-state "
<< value.nextstate <<
" has unexpected final-weight " << final_weight.Value1() << ','
<< final_weight.Value2();
}
auto r = old_final_costs->insert({value.olabel,
final_weight.Value1()});
if (!r.second && r.first->second != final_weight.Value1()) {
// For any given token-label, all arcs in raw_fst with that
// olabel should go to the same state, so this should be
// impossible.
KALDI_ERR << "Unexpected mismatch in final-costs for tokens, "
<< r.first->second << " vs " << final_weight.Value1();
}
}
}
}
}
bool LatticeIncrementalDeterminizer::ProcessArcsFromChunkStartState(
const CompactLattice &chunk_clat,
std::unordered_map<CompactLattice::StateId, CompactLattice::StateId> *state_map) {
using StateId = CompactLattice::StateId;
StateId clat_num_states = clat_.NumStates();
// Process arcs leaving the start state of chunk_clat. These arcs will have
// state-labels on them (unless this is the first chunk).
// For destination-states of those arcs, work out which states in
// clat_ they correspond to and update their forward_costs.
for (fst::ArcIterator<CompactLattice> aiter(chunk_clat, chunk_clat.Start());
!aiter.Done(); aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value();
Label label = arc.ilabel; // ilabel == olabel; would be the olabel
// in a Lattice.
if (!(label >= kStateLabelOffset &&
label - kStateLabelOffset < clat_num_states)) {
// The label was not a state-label. This should only be possible on the
// first chunk.
KALDI_ASSERT(state_map->empty());
return true; // this is the first chunk.
}
StateId clat_state = label - kStateLabelOffset;
StateId chunk_state = arc.nextstate;
auto p = state_map->insert({chunk_state, clat_state});
StateId dest_clat_state = p.first->second;
// We deleted all its arcs in InitializeRawLatticeChunk
KALDI_ASSERT(clat_.NumArcs(clat_state) == 0);
/*
In almost all cases, dest_clat_state and clat_state will be the same state;
but there may be situations where two arcs with different state-labels
left the start state and entered the same next-state in chunk_clat; and in
these cases, they will be different.
We didn't address this issue in the paper (or actually realize it could be
a problem). What we do is pick one of the clat_states as the "canonical"
one, and redirect all incoming transitions of the others to enter the
"canonical" one. (Search below for new_in_arc.nextstate =
dest_clat_state).
*/
if (clat_state != dest_clat_state) {
// Check that the start state isn't getting merged with any other state.
// If this were possible, we'd need to deal with it specially, but it
// can't be, because to be merged, 2 states must have identical arcs
// leaving them with identical weights, so we'd need to have another state
// on frame 0 identical to the start state, which is not possible if the
// lattice is deterministic and epsilon-free.
KALDI_ASSERT(clat_state != 0 && dest_clat_state != 0);
}
// in_weight is an extra weight that we'll include on arcs entering this
// state from the previous chunk. We need to cancel out
// `forward_costs[clat_state]`, which was included in the corresponding arc
// in the raw lattice for pruning purposes; and we need to include the
// weight on the arc from the start-state of `chunk_clat` to this state.
CompactLatticeWeight extra_weight_in = arc.weight;
extra_weight_in.SetWeight(
fst::Times(extra_weight_in.Weight(),
LatticeWeight(-forward_costs_[clat_state], 0.0)));
// We don't allow state 0 to be a redeterminized-state; calling code assures
// this. Search for `determinizer_.GetLattice().Final(0) !=
// CompactLatticeWeight::Zero())` to find that calling code.
KALDI_ASSERT(clat_state != 0);
// Note: 0 is the start state of clat_. This was checked.
forward_costs_[clat_state] = (clat_state == 0 ? 0 :
std::numeric_limits<BaseFloat>::infinity());
std::vector<std::pair<StateId, int32> > arcs_in;
arcs_in.swap(arcs_in_[clat_state]);
for (auto p: arcs_in) {
// Note: we'll be doing `continue` below if this input arc came from
// another redeterminized-state, because we did DeleteArcs() for them in
// InitializeRawLatticeChunk(). Those arcs will be transferred
// from chunk_clat later on.
CompactLattice::StateId src_state = p.first;
int32 arc_pos = p.second;
if (arc_pos >= (int32)clat_.NumArcs(src_state))
continue;
fst::MutableArcIterator<CompactLattice> aiter(&clat_, src_state);
aiter.Seek(arc_pos);
if (aiter.Value().nextstate != clat_state)
continue; // This arc record has become invalidated.
CompactLatticeArc new_in_arc(aiter.Value());
// In most cases we will have dest_clat_state == clat_state, so the next
// line won't change the value of .nextstate
new_in_arc.nextstate = dest_clat_state;
new_in_arc.weight = fst::Times(new_in_arc.weight, extra_weight_in);
aiter.SetValue(new_in_arc);
BaseFloat new_forward_cost = forward_costs_[src_state] +
ConvertToCost(new_in_arc.weight);
if (new_forward_cost < forward_costs_[dest_clat_state])
forward_costs_[dest_clat_state] = new_forward_cost;
arcs_in_[dest_clat_state].push_back(p);
}
}
return false; // this is not the first chunk.
}
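/* Continuing the illustration from InitializeRawLatticeChunk() (hypothetical
numbers): if the chunk's start state has an arc with state-label 42 +
kStateLabelOffset and weight w entering some chunk state, then
extra_weight_in = Times(w, (-7.3, 0)), and every arc that previously
entered clat_ state 42 gets extra_weight_in multiplied on (fst::Times adds
the cost components). The +7.3 added for pruning purposes in
InitializeRawLatticeChunk() thus cancels, leaving only the genuine chunk
weight w on the path. */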
void LatticeIncrementalDeterminizer::TransferArcsToClat(
const CompactLattice &chunk_clat,
bool is_first_chunk,
const std::unordered_map<CompactLattice::StateId, CompactLattice::StateId> &state_map,
const std::unordered_map<CompactLattice::StateId, Label> &chunk_state_to_token,
const std::unordered_map<Label, BaseFloat> &old_final_costs) {
using StateId = CompactLattice::StateId;
StateId chunk_num_states = chunk_clat.NumStates();
// Now transfer arcs from chunk_clat to clat_.
for (StateId chunk_state = (is_first_chunk ? 0 : 1);
chunk_state < chunk_num_states; chunk_state++) {
auto iter = state_map.find(chunk_state);
if (iter == state_map.end()) {
KALDI_ASSERT(chunk_state_to_token.count(chunk_state) != 0);
// Don't process token-final states. Anyway they have no arcs leaving
// them.
continue;
}
StateId clat_state = iter->second;
// We know at this point that `clat_state` is not a token-final state
// (see glossary for definition), as if it were, we would have done
// `continue` above.
//
// Only in the last chunk of the lattice would there be a final-prob on
// states that are not `token-final states`; these final-probs would
// normally all be Zero() at this point. So in almost all cases the following
// call will do nothing.
clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state));
// Process arcs leaving this state.
for (fst::ArcIterator<CompactLattice> aiter(chunk_clat, chunk_state);
!aiter.Done(); aiter.Next()) {
CompactLatticeArc arc(aiter.Value());
auto next_iter = state_map.find(arc.nextstate);
if (next_iter != state_map.end()) {
// The normal case (when the .nextstate has a corresponding
// state in clat_) is very simple. Just copy the arc over.
arc.nextstate = next_iter->second;
KALDI_ASSERT(arc.ilabel < kTokenLabelOffset ||
arc.ilabel > kMaxTokenLabel);
AddArcToClat(clat_state, arc);
} else {
// This is the case when the arc is to a `token-final` state (see
// glossary).
// TODO: remove the following slightly excessive assertion?
KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() &&
arc.olabel >= (Label)kTokenLabelOffset &&
arc.olabel < (Label)kMaxTokenLabel &&
chunk_state_to_token.count(arc.nextstate) != 0 &&
old_final_costs.count(arc.olabel) != 0);
// Include the final-cost of the next state (which should be final)
// in arc.weight.
arc.weight = fst::Times(arc.weight,
chunk_clat.Final(arc.nextstate));
auto cost_iter = old_final_costs.find(arc.olabel);
KALDI_ASSERT(cost_iter != old_final_costs.end());
BaseFloat old_final_cost = cost_iter->second;
// `arc` is going to become an element of final_arcs_. These
// contain information about transitions from states in clat_ to
// `token-final` states (i.e. states that have a token-label on the arc
// to them and that are final in the canonical compact lattice).
// We subtract the old_final_cost as it was just a temporary cost
// introduced for pruning purposes.
arc.weight.SetWeight(fst::Times(arc.weight.Weight(),
LatticeWeight{-old_final_cost, 0.0}));
// In a slight abuse of the Arc data structure, the nextstate is set to
// the source state. The label (ilabel == olabel) indicates the
// token it is associated with.
arc.nextstate = clat_state;
final_arcs_.push_back(arc);
}
}
}
}
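// Determinizes one chunk of the raw lattice and appends the result to clat_.
// Returns true if determinization completed within the lattice beam (i.e.
// was not forced to terminate early). A minimal usage sketch, assuming the
// decoder-side call pattern used elsewhere in this file:
//
//   Lattice raw_fst;
//   // ... the decoder fills raw_fst with this chunk's raw lattice, with
//   // token-labels on arcs into final states and (except for the first
//   // chunk) state-labels on arcs leaving the start state ...
//   bool ok = determinizer_.AcceptRawLatticeChunk(&raw_fst);
//   // raw_fst is consumed (destructively determinized) by the call.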
bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk(
Lattice *raw_fst) {
using Label = CompactLatticeArc::Label;
using StateId = CompactLattice::StateId;
// old_final_costs is a map from a `token-label` (see glossary) to the
// associated final-prob in a final-state of `raw_fst`, that is associated
// with that Token. These are Tokens that were active at the end of the
// chunk. The final-probs may arise from beta (backward) costs, introduced
// for pruning purposes, and/or from final-probs in HCLG. Those costs will
// not be included in anything we store permanently in this class; they are
// used only to guide pruned determinization, and we will use
// `old_final_costs` later to cancel them out.
std::unordered_map<Label, BaseFloat> old_final_costs;
GetRawLatticeFinalCosts(*raw_fst, &old_final_costs);
CompactLattice chunk_clat;
bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper(
trans_model_, raw_fst, config_.lattice_beam, &chunk_clat,
config_.det_opts);
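// Note: DeterminizeLatticePhonePrunedWrapper() is destructive to `raw_fst`,
// and (as we understand it) returns false if it was forced to terminate
// early, e.g. on hitting a max-states/max-mem limit; we propagate that
// status through our own return value.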
TopSortCompactLatticeIfNeeded(&chunk_clat);
std::unordered_map<StateId, Label> chunk_state_to_token;
IdentifyTokenFinalStates(chunk_clat,
&chunk_state_to_token);
StateId chunk_num_states = chunk_clat.NumStates();
if (chunk_num_states == 0) {
// This will be an error but user-level calling code can detect it from the
// lattice being empty.
KALDI_WARN << "Empty lattice, something went wrong.";
clat_.DeleteStates();
return false;
}
StateId start_state = chunk_clat.Start(); // would be 0.
KALDI_ASSERT(start_state == 0);
// Process arcs leaving the start state of chunk_clat. Unless this is the
// first chunk in the lattice, all arcs leaving the start state of chunk_clat
// will have `state labels` on them (identifying redeterminized-states in
// clat_), and will transition to a state in `chunk_clat` that we can identify
// with that redeterminized-state.
// state_map maps from (non-initial, non-token-final state s in chunk_clat) to
// a state in clat_.
std::unordered_map<StateId, StateId> state_map;
bool is_first_chunk = ProcessArcsFromChunkStartState(chunk_clat, &state_map);
// Remove any existing arcs in clat_ that leave redeterminized-states, and
// make those states non-final. Below, we'll add arcs leaving those states
// (and possibly new final-probs.)
for (StateId clat_state: non_final_redet_states_) {
clat_.DeleteArcs(clat_state);
clat_.SetFinal(clat_state, CompactLatticeWeight::Zero());
}
// The previous final-arc info is no longer relevant; we'll recreate it below.
final_arcs_.clear();
// We asserted above that chunk_clat.Start() == 0. Allocate state-ids
// for all remaining states in chunk_clat, except for token-final states.
for (StateId state = (is_first_chunk ? 0 : 1);
state < chunk_num_states; state++) {
if (chunk_state_to_token.count(state) != 0)
continue; // these `token-final` states don't get a state allocated.
StateId new_clat_state = clat_.NumStates();
if (state_map.insert({state, new_clat_state}).second) {
// If it was inserted then we need to actually allocate that state
StateId s = AddStateToClat();
KALDI_ASSERT(s == new_clat_state);
} // else do nothing; it would have been a redeterminized-state, and no
// allocation is needed since such states already exist in clat_ and
// in state_map.
}
if (is_first_chunk) {
auto iter = state_map.find(start_state);
KALDI_ASSERT(iter != state_map.end());
CompactLattice::StateId clat_start_state = iter->second;
KALDI_ASSERT(clat_start_state == 0); // topological order.
clat_.SetStart(clat_start_state);
forward_costs_[clat_start_state] = 0.0;
}
TransferArcsToClat(chunk_clat, is_first_chunk,
state_map, chunk_state_to_token, old_final_costs);
GetNonFinalRedetStates();
return determinized_till_beam;
}
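// Sets final-probs in clat_ based on final_arcs_. Behavior inferred from
// the code below: if `token_label2final_cost` is NULL, every token gets a
// graph final-cost of 0.0; otherwise tokens whose labels are absent from
// the map get no final-prob at all (their entries in final_arcs_ are
// skipped).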
void LatticeIncrementalDeterminizer::SetFinalCosts(
const unordered_map<Label, BaseFloat> *token_label2final_cost) {
if (final_arcs_.empty()) {
KALDI_WARN << "SetFinalCosts() called when final_arcs_.empty()... possibly "
"means you are calling this after Finalize()? Not allowed: could "
"indicate a code error. Or possibly decoding failed somehow.";
}
/*
`prefinal states` is terminology that does not appear in the paper. It
means: the set of states, in the canonical appended lattice, that have an
arc with a Token-label leaving them.
*/
std::unordered_set<int32> &prefinal_states(temp_);
prefinal_states.clear();
for (const auto &arc: final_arcs_) {
/* Caution: `state` is actually the state the arc would
leave from in the canonical appended lattice; we just store
that in the .nextstate field. */
CompactLattice::StateId state = arc.nextstate;
prefinal_states.insert(state);
}
for (int32 state: prefinal_states)
clat_.SetFinal(state, CompactLatticeWeight::Zero());
for (const CompactLatticeArc &arc: final_arcs_) {
Label token_label = arc.ilabel;
/* Note: we store the source state in the .nextstate field. */
CompactLattice::StateId src_state = arc.nextstate;
BaseFloat graph_final_cost;
if (token_label2final_cost == NULL) {
graph_final_cost = 0.0;
} else {
auto iter = token_label2final_cost->find(token_label);
if (iter == token_label2final_cost->end())
continue;
else
graph_final_cost = iter->second;
}
/* It might seem odd to set a final-prob on the src-state of the arc...
the point is that the symbol on the arc is a token-label, which should
not appear in the lattice the user sees; once that token-label is
removed, the arc just becomes a final-prob.
*/
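/* Illustrative example (made-up numbers): if two entries of final_arcs_
leave src_state with total costs 3.0 and 4.0, and graph_final_cost is 0.5
for both tokens, then after both loop iterations Final(src_state) has cost
min(3.0 + 0.5, 4.0 + 0.5) = 3.5: Plus() keeps the better weight and
Times() adds costs.
*/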
clat_.SetFinal(src_state,
fst::Plus(clat_.Final(src_state),
fst::Times(arc.weight,
CompactLatticeWeight(
LatticeWeight(graph_final_cost, 0), {}))));
}
}
// Instantiate the template for the combination of token types and FST types
// that we'll need.
template class LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken>;
template class LatticeIncrementalDecoderTpl<fst::VectorFst<fst::StdArc>,
decoder::StdToken>;
template class LatticeIncrementalDecoderTpl<fst::ConstFst<fst::StdArc>,
decoder::StdToken>;
template class LatticeIncrementalDecoderTpl<fst::ConstGrammarFst,
decoder::StdToken>;
template class LatticeIncrementalDecoderTpl<fst::VectorGrammarFst,
decoder::StdToken>;
template class LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc>,
decoder::BackpointerToken>;
template class LatticeIncrementalDecoderTpl<fst::VectorFst<fst::StdArc>,
decoder::BackpointerToken>;
template class LatticeIncrementalDecoderTpl<fst::ConstFst<fst::StdArc>,
decoder::BackpointerToken>;
template class LatticeIncrementalDecoderTpl<fst::ConstGrammarFst,
decoder::BackpointerToken>;
template class LatticeIncrementalDecoderTpl<fst::VectorGrammarFst,
decoder::BackpointerToken>;
} // end namespace kaldi.