FunASR/runtime/onnxruntime/third_party/kaldi/decoder/lattice-simple-decoder.h

319 lines
14 KiB
C
Raw Normal View History

2024-05-18 15:50:56 +08:00
// decoder/lattice-simple-decoder.h
// Copyright 2009-2012 Microsoft Corporation
// 2012-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_DECODER_LATTICE_SIMPLE_DECODER_H_
#define KALDI_DECODER_LATTICE_SIMPLE_DECODER_H_
#include "util/stl-utils.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include <algorithm>
namespace kaldi {
/// Options controlling LatticeSimpleDecoder. Every member has a well-defined
/// default after construction; Register() exposes a subset of them to the
/// command-line option parser and Check() validates the ones it asserts on.
struct LatticeSimpleDecoderConfig {
  BaseFloat beam;            // Decoding beam.
  BaseFloat lattice_beam;    // Lattice generation beam.
  int32 prune_interval;      // Interval (in frames) at which to prune tokens.
  bool determinize_lattice;  // not inspected by this class... used in
                             // command-line program.
  bool prune_lattice;        // Not registered as a command-line option below.
                             // NOTE(review): previously left uninitialized;
                             // the default of 'true' is an assumption --
                             // confirm against any code that reads it.
  BaseFloat beam_ratio;
  BaseFloat prune_scale;     // Note: we don't make this configurable on the
                             // command line, it's not a very important
                             // parameter.  It affects the algorithm that
                             // prunes the tokens as we go.
  fst::DeterminizeLatticePhonePrunedOptions det_opts;

  // Initializers are listed in member-declaration order.  prune_lattice is
  // now included so that a default-constructed config never contains an
  // indeterminate bool (reading or copying one is undefined behaviour).
  LatticeSimpleDecoderConfig(): beam(16.0),
                                lattice_beam(10.0),
                                prune_interval(25),
                                determinize_lattice(true),
                                prune_lattice(true),
                                beam_ratio(0.9),
                                prune_scale(0.1) { }

  /// Registers the determinization options and the command-line-visible
  /// decoder options with 'opts'.
  void Register(OptionsItf *opts) {
    det_opts.Register(opts);
    opts->Register("beam", &beam, "Decoding beam.");
    opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam");
    opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
                   "which to prune tokens");
    opts->Register("determinize-lattice", &determinize_lattice, "If true, "
                   "determinize the lattice (in a special sense, keeping only "
                   "best pdf-sequence for each word-sequence).");
  }

  /// Asserts that both beams and the pruning interval are strictly positive.
  void Check() const {
    KALDI_ASSERT(beam > 0.0 && lattice_beam > 0.0 && prune_interval > 0);
  }
};
/** Simplest possible decoder, included largely for didactic purposes and as a
means to debug more highly optimized decoders. See \ref decoders_simple
for more information.
*/
class LatticeSimpleDecoder {
 public:
  typedef fst::StdArc Arc;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;

  // instantiate this class once for each thing you have to decode.
  //
  // Every scalar member is given a defined value here (previously warned_,
  // decoding_finalized_ and the two cached final costs were left
  // indeterminate until some later call assigned them).  This keeps the
  // public const queries below free of uninitialized reads when they are
  // invoked on a freshly constructed decoder.  The cached final costs are
  // documented as meaningful only once decoding_finalized_ is true, so
  // infinity is used purely as a placeholder.
  LatticeSimpleDecoder(const fst::Fst<fst::StdArc> &fst,
                       const LatticeSimpleDecoderConfig &config):
      fst_(fst), config_(config), num_toks_(0), warned_(false),
      decoding_finalized_(false),
      final_relative_cost_(std::numeric_limits<BaseFloat>::infinity()),
      final_best_cost_(std::numeric_limits<BaseFloat>::infinity()) {
    config.Check();
  }

  ~LatticeSimpleDecoder() { ClearActiveTokens(); }

  const LatticeSimpleDecoderConfig &GetOptions() const {
    return config_;
  }

  // Returns true if any kind of traceback is available (not necessarily from
  // a final state).
  bool Decode(DecodableInterface *decodable);

  /// says whether a final-state was active on the last frame.  If it was not,
  /// the lattice (or traceback) will end with states that are not
  /// final-states.
  bool ReachedFinal() const {
    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
  }

  /// InitDecoding initializes the decoding, and should only be used if you
  /// intend to call AdvanceDecoding().  If you call Decode(), you don't need
  /// to call this.  You can call InitDecoding if you have already decoded an
  /// utterance and want to start with a new utterance.
  void InitDecoding();

  /// This function may be optionally called after AdvanceDecoding(), when you
  /// do not plan to decode any further.  It does an extra pruning step that
  /// will help to prune the lattices output by GetLattice and (particularly)
  /// GetRawLattice more accurately, particularly toward the end of the
  /// utterance.  It does this by using the final-probs in pruning (if any
  /// final-state survived); it also does a final pruning step that visits all
  /// states (the pruning that is done during decoding may fail to prune states
  /// that are within kPruningScale = 0.1 outside of the beam).  If you call
  /// this, you cannot call AdvanceDecoding again (it will fail), and you
  /// cannot call GetLattice() and related functions with use_final_probs =
  /// false.
  /// Used to be called PruneActiveTokensFinal().
  void FinalizeDecoding();

  /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
  /// more information.  It returns the difference between the best (final-cost
  /// plus cost) of any token on the final frame, and the best cost of any
  /// token on the final frame.  If it is infinity it means no final-states
  /// were present on the final frame.  It will usually be nonnegative.  If it
  /// is not too positive (e.g. < 5 is my first guess, but this is not tested)
  /// you can take it as a good indication that we reached the final-state
  /// with reasonable likelihood.
  BaseFloat FinalRelativeCost() const;

  // Outputs an FST corresponding to the single best path
  // through the lattice.  Returns true if result is nonempty
  // (using the return status is deprecated, it will become void).
  // If "use_final_probs" is true AND we reached the final-state
  // of the graph then it will include those as final-probs, else
  // it will treat all final-probs as one.
  bool GetBestPath(Lattice *lat,
                   bool use_final_probs = true) const;

  // Outputs an FST corresponding to the raw, state-level
  // tracebacks.  Returns true if result is nonempty
  // (using the return status is deprecated, it will become void).
  // If "use_final_probs" is true AND we reached the final-state
  // of the graph then it will include those as final-probs, else
  // it will treat all final-probs as one.
  bool GetRawLattice(Lattice *lat,
                     bool use_final_probs = true) const;

  // This function is now deprecated, since determinization is now done from
  // outside the decoder class.
  // Outputs an FST corresponding to the lattice-determinized
  // lattice (one path per word sequence).  [will become deprecated,
  // users should determinize themselves.]
  bool GetLattice(CompactLattice *clat,
                  bool use_final_probs = true) const;

  // Number of entries in active_toks_ minus one (active_toks_ is indexed by
  // frame_plus_one).  The subtraction is done in signed arithmetic so that an
  // empty active_toks_ yields -1 directly instead of relying on unsigned
  // wrap-around followed by a narrowing conversion.
  inline int32 NumFramesDecoded() const {
    return static_cast<int32>(active_toks_.size()) - 1;
  }

 private:
  struct Token;

  // ForwardLinks are the links from a token to a token on the next frame.
  // or sometimes on the current frame (for input-epsilon links).
  struct ForwardLink {
    Token *next_tok;  // the next token [or NULL if represents final-state]
    Label ilabel;  // ilabel on link.
    Label olabel;  // olabel on link.
    BaseFloat graph_cost;  // graph cost of traversing link (contains LM, etc.)
    BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing link
    ForwardLink *next;  // next in singly-linked list of forward links from a
                        // token.
    ForwardLink(Token *next_tok, Label ilabel, Label olabel,
                BaseFloat graph_cost, BaseFloat acoustic_cost,
                ForwardLink *next):
        next_tok(next_tok), ilabel(ilabel), olabel(olabel),
        graph_cost(graph_cost), acoustic_cost(acoustic_cost),
        next(next) { }
  };

  // Token is what's resident in a particular state at a particular time.
  // In this decoder a Token actually contains *forward* links.
  // When first created, a Token just has the (total) cost.  We add forward
  // links from it when we process the next frame.
  struct Token {
    BaseFloat tot_cost;  // would equal weight.Value()... cost up to this point.
    BaseFloat extra_cost;  // >= 0.  After calling PruneForwardLinks, this
    // equals the minimum difference between the cost of the best path this is
    // on, and the cost of the absolute best path, under the assumption
    // that any of the currently active states at the decoding front may
    // eventually succeed (e.g. if you were to take the currently active states
    // one by one and compute this difference, and then take the minimum).

    ForwardLink *links;  // Head of singly linked list of ForwardLinks

    Token *next;  // Next in list of tokens for this frame.

    Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links,
          Token *next): tot_cost(tot_cost), extra_cost(extra_cost),
                        links(links), next(next) { }

    // A default-constructed token is empty: zero costs, no links and no
    // successor.  (Previously the fields were left indeterminate, so calling
    // DeleteForwardLinks() on such a token would have followed a garbage
    // pointer.)
    Token(): tot_cost(0.0), extra_cost(0.0), links(NULL), next(NULL) { }

    // Frees every ForwardLink in the list headed by 'links' and leaves the
    // list empty.
    void DeleteForwardLinks() {
      ForwardLink *l = links, *m;
      while (l != NULL) {
        m = l->next;  // save the successor before freeing the current node.
        delete l;
        l = m;
      }
      links = NULL;
    }
  };

  // head and tail of per-frame list of Tokens (list is in topological order),
  // and something saying whether we ever pruned it using PruneForwardLinks.
  struct TokenList {
    Token *toks;
    bool must_prune_forward_links;
    bool must_prune_tokens;
    TokenList(): toks(NULL), must_prune_forward_links(true),
                 must_prune_tokens(true) { }
  };

  // FindOrAddToken either locates a token in cur_toks_, or if necessary
  // inserts a new, empty token (i.e. with no forward links) for the current
  // frame.  [note: it's inserted if necessary into cur_toks_ and also into the
  // singly linked list of tokens active on this frame (whose head is at
  // active_toks_[frame]).]
  //
  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true
  // if the token was newly created or the cost changed.
  inline Token *FindOrAddToken(StateId state, int32 frame_plus_one,
                               BaseFloat tot_cost, bool emitting,
                               bool *changed);

  // delta is the amount by which the extra_costs must
  // change before it sets "extra_costs_changed" to true.  If delta is larger,
  // we'll tend to go back less far toward the beginning of the file.
  void PruneForwardLinks(int32 frame, bool *extra_costs_changed,
                         bool *links_pruned,
                         BaseFloat delta);

  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
  // on the final frame.  If there are final tokens active, it uses the
  // final-probs for pruning, otherwise it treats all tokens as final.
  void PruneForwardLinksFinal();

  // Prune away any tokens on this frame that have no forward links.  [we
  // don't do this in PruneForwardLinks because it would give us a problem
  // with dangling pointers].
  void PruneTokensForFrame(int32 frame);

  // Go backwards through still-alive tokens, pruning them if the
  // forward+backward cost is more than lat_beam away from the best path.  It's
  // possible to prove that this is "correct" in the sense that we won't lose
  // anything outside of lat_beam, regardless of what happens in the future.
  // delta controls when it considers a cost to have changed enough to continue
  // going backward and propagating the change.  larger delta -> will recurse
  // less far.
  void PruneActiveTokens(BaseFloat delta);

  void ProcessEmitting(DecodableInterface *decodable);

  void ProcessNonemitting();

  void ClearActiveTokens();  // a cleanup routine, at utt end/begin

  // This function computes the final-costs for tokens active on the final
  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
  // pointer to the final-prob of the corresponding state, or zero for all
  // states if none were final.  It outputs to final_relative_cost, if
  // non-NULL, the difference between the best forward-cost including the
  // final-prob cost, and the best forward-cost without including the
  // final-prob cost (this will usually be positive), or infinity if there
  // were no final-probs.  It outputs to final_best_cost, if non-NULL, the
  // lowest for any token t active on the final frame, of t + final-cost[t],
  // where final-cost[t] is the final-cost in the graph of the state
  // corresponding to token t, or zero if there were no final-probs active on
  // the final frame.
  // You cannot call this after FinalizeDecoding() has been called; in that
  // case you should get the answer from class-member variables.
  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
                         BaseFloat *final_relative_cost,
                         BaseFloat *final_best_cost) const;

  // PruneCurrentTokens deletes the tokens from the "toks" map, but not
  // from the active_toks_ list, which could cause dangling forward pointers
  // (will delete it during regular pruning operation).
  void PruneCurrentTokens(BaseFloat beam,
                          unordered_map<StateId, Token*> *toks);

  unordered_map<StateId, Token*> cur_toks_;
  unordered_map<StateId, Token*> prev_toks_;
  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
                                        // frame_plus_one
  const fst::Fst<fst::StdArc> &fst_;
  LatticeSimpleDecoderConfig config_;
  int32 num_toks_;  // current total #toks allocated...
  bool warned_;

  /// decoding_finalized_ is true if someone called FinalizeDecoding().  [note,
  /// calling this is optional].  If true, it's forbidden to decode more.
  /// Also, if this is set, then the output of ComputeFinalCosts() is in the
  /// next three variables.  The reason we need to do this is that after
  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
  /// of the tokens on the last frame are freed, so we free the list from
  /// cur_toks_ to avoid having dangling pointers hanging around.
  bool decoding_finalized_;

  /// For the meaning of the next 3 variables, see the comment for
  /// decoding_finalized_ above, and ComputeFinalCosts().
  unordered_map<Token*, BaseFloat> final_costs_;
  BaseFloat final_relative_cost_;
  BaseFloat final_best_cost_;
};
} // end namespace kaldi.
#endif