294 lines
10 KiB
C++
294 lines
10 KiB
C++
// decoder/simple-decoder.cc
|
|
|
|
// Copyright 2009-2011 Microsoft Corporation
|
|
// 2012-2013 Johns Hopkins University (author: Daniel Povey)
|
|
|
|
// See ../../COPYING for clarification regarding multiple authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
|
// See the Apache 2 License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "decoder/simple-decoder.h"
|
|
#include "fstext/remove-eps-local.h"
|
|
#include <algorithm>
|
|
|
|
namespace kaldi {
|
|
|
|
SimpleDecoder::~SimpleDecoder() {
|
|
ClearToks(cur_toks_);
|
|
ClearToks(prev_toks_);
|
|
}
|
|
|
|
|
|
bool SimpleDecoder::Decode(DecodableInterface *decodable) {
|
|
InitDecoding();
|
|
AdvanceDecoding(decodable);
|
|
return (!cur_toks_.empty());
|
|
}
|
|
|
|
void SimpleDecoder::InitDecoding() {
|
|
// clean up from last time:
|
|
ClearToks(cur_toks_);
|
|
ClearToks(prev_toks_);
|
|
// initialize decoding:
|
|
StateId start_state = fst_.Start();
|
|
KALDI_ASSERT(start_state != fst::kNoStateId);
|
|
StdArc dummy_arc(0, 0, StdWeight::One(), start_state);
|
|
cur_toks_[start_state] = new Token(dummy_arc, 0.0, NULL);
|
|
num_frames_decoded_ = 0;
|
|
ProcessNonemitting();
|
|
}
|
|
|
|
void SimpleDecoder::AdvanceDecoding(DecodableInterface *decodable,
|
|
int32 max_num_frames) {
|
|
KALDI_ASSERT(num_frames_decoded_ >= 0 &&
|
|
"You must call InitDecoding() before AdvanceDecoding()");
|
|
int32 num_frames_ready = decodable->NumFramesReady();
|
|
// num_frames_ready must be >= num_frames_decoded, or else
|
|
// the number of frames ready must have decreased (which doesn't
|
|
// make sense) or the decodable object changed between calls
|
|
// (which isn't allowed).
|
|
KALDI_ASSERT(num_frames_ready >= num_frames_decoded_);
|
|
int32 target_frames_decoded = num_frames_ready;
|
|
if (max_num_frames >= 0)
|
|
target_frames_decoded = std::min(target_frames_decoded,
|
|
num_frames_decoded_ + max_num_frames);
|
|
while (num_frames_decoded_ < target_frames_decoded) {
|
|
// note: ProcessEmitting() increments num_frames_decoded_
|
|
ClearToks(prev_toks_);
|
|
cur_toks_.swap(prev_toks_);
|
|
ProcessEmitting(decodable);
|
|
ProcessNonemitting();
|
|
PruneToks(beam_, &cur_toks_);
|
|
}
|
|
}
|
|
|
|
bool SimpleDecoder::ReachedFinal() const {
|
|
for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
|
|
iter != cur_toks_.end();
|
|
++iter) {
|
|
if (iter->second->cost_ != std::numeric_limits<BaseFloat>::infinity() &&
|
|
fst_.Final(iter->first) != StdWeight::Zero())
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
BaseFloat SimpleDecoder::FinalRelativeCost() const {
|
|
// as a special case, if there are no active tokens at all (e.g. some kind of
|
|
// pruning failure), return infinity.
|
|
double infinity = std::numeric_limits<double>::infinity();
|
|
if (cur_toks_.empty())
|
|
return infinity;
|
|
double best_cost = infinity,
|
|
best_cost_with_final = infinity;
|
|
for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
|
|
iter != cur_toks_.end();
|
|
++iter) {
|
|
// Note: Plus is taking the minimum cost, since we're in the tropical
|
|
// semiring.
|
|
best_cost = std::min(best_cost, iter->second->cost_);
|
|
best_cost_with_final = std::min(best_cost_with_final,
|
|
iter->second->cost_ +
|
|
fst_.Final(iter->first).Value());
|
|
}
|
|
BaseFloat extra_cost = best_cost_with_final - best_cost;
|
|
if (extra_cost != extra_cost) { // NaN. This shouldn't happen; it indicates some
|
|
// kind of error, most likely.
|
|
KALDI_WARN << "Found NaN (likely search failure in decoding)";
|
|
return infinity;
|
|
}
|
|
// Note: extra_cost will be infinity if no states were final.
|
|
return extra_cost;
|
|
}
|
|
|
|
// Outputs an FST corresponding to the single best path
|
|
// through the lattice.
|
|
bool SimpleDecoder::GetBestPath(Lattice *fst_out, bool use_final_probs) const {
|
|
fst_out->DeleteStates();
|
|
Token *best_tok = NULL;
|
|
bool is_final = ReachedFinal();
|
|
if (!is_final) {
|
|
for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
|
|
iter != cur_toks_.end();
|
|
++iter)
|
|
if (best_tok == NULL || *best_tok < *(iter->second) )
|
|
best_tok = iter->second;
|
|
} else {
|
|
double infinity =std::numeric_limits<double>::infinity(),
|
|
best_cost = infinity;
|
|
for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
|
|
iter != cur_toks_.end();
|
|
++iter) {
|
|
double this_cost = iter->second->cost_ + fst_.Final(iter->first).Value();
|
|
if (this_cost != infinity && this_cost < best_cost) {
|
|
best_cost = this_cost;
|
|
best_tok = iter->second;
|
|
}
|
|
}
|
|
}
|
|
if (best_tok == NULL) return false; // No output.
|
|
|
|
std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
|
|
for (Token *tok = best_tok; tok != NULL; tok = tok->prev_)
|
|
arcs_reverse.push_back(tok->arc_);
|
|
KALDI_ASSERT(arcs_reverse.back().nextstate == fst_.Start());
|
|
arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
|
|
|
|
StateId cur_state = fst_out->AddState();
|
|
fst_out->SetStart(cur_state);
|
|
for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
|
|
LatticeArc arc = arcs_reverse[i];
|
|
arc.nextstate = fst_out->AddState();
|
|
fst_out->AddArc(cur_state, arc);
|
|
cur_state = arc.nextstate;
|
|
}
|
|
if (is_final && use_final_probs)
|
|
fst_out->SetFinal(cur_state,
|
|
LatticeWeight(fst_.Final(best_tok->arc_.nextstate).Value(),
|
|
0.0));
|
|
else
|
|
fst_out->SetFinal(cur_state, LatticeWeight::One());
|
|
fst::RemoveEpsLocal(fst_out);
|
|
return true;
|
|
}
|
|
|
|
|
|
void SimpleDecoder::ProcessEmitting(DecodableInterface *decodable) {
|
|
int32 frame = num_frames_decoded_;
|
|
// Processes emitting arcs for one frame. Propagates from
|
|
// prev_toks_ to cur_toks_.
|
|
double cutoff = std::numeric_limits<BaseFloat>::infinity();
|
|
for (unordered_map<StateId, Token*>::iterator iter = prev_toks_.begin();
|
|
iter != prev_toks_.end();
|
|
++iter) {
|
|
StateId state = iter->first;
|
|
Token *tok = iter->second;
|
|
KALDI_ASSERT(state == tok->arc_.nextstate);
|
|
for (fst::ArcIterator<fst::Fst<StdArc> > aiter(fst_, state);
|
|
!aiter.Done();
|
|
aiter.Next()) {
|
|
const StdArc &arc = aiter.Value();
|
|
if (arc.ilabel != 0) { // propagate..
|
|
BaseFloat acoustic_cost = -decodable->LogLikelihood(frame, arc.ilabel);
|
|
double total_cost = tok->cost_ + arc.weight.Value() + acoustic_cost;
|
|
|
|
if (total_cost >= cutoff) continue;
|
|
if (total_cost + beam_ < cutoff)
|
|
cutoff = total_cost + beam_;
|
|
Token *new_tok = new Token(arc, acoustic_cost, tok);
|
|
unordered_map<StateId, Token*>::iterator find_iter
|
|
= cur_toks_.find(arc.nextstate);
|
|
if (find_iter == cur_toks_.end()) {
|
|
cur_toks_[arc.nextstate] = new_tok;
|
|
} else {
|
|
if ( *(find_iter->second) < *new_tok ) {
|
|
Token::TokenDelete(find_iter->second);
|
|
find_iter->second = new_tok;
|
|
} else {
|
|
Token::TokenDelete(new_tok);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
num_frames_decoded_++;
|
|
}
|
|
|
|
void SimpleDecoder::ProcessNonemitting() {
|
|
// Processes nonemitting arcs for one frame. Propagates within
|
|
// cur_toks_.
|
|
std::vector<StateId> queue;
|
|
double infinity = std::numeric_limits<double>::infinity();
|
|
double best_cost = infinity;
|
|
for (unordered_map<StateId, Token*>::iterator iter = cur_toks_.begin();
|
|
iter != cur_toks_.end();
|
|
++iter) {
|
|
queue.push_back(iter->first);
|
|
best_cost = std::min(best_cost, iter->second->cost_);
|
|
}
|
|
double cutoff = best_cost + beam_;
|
|
|
|
while (!queue.empty()) {
|
|
StateId state = queue.back();
|
|
queue.pop_back();
|
|
Token *tok = cur_toks_[state];
|
|
KALDI_ASSERT(tok != NULL && state == tok->arc_.nextstate);
|
|
for (fst::ArcIterator<fst::Fst<StdArc> > aiter(fst_, state);
|
|
!aiter.Done();
|
|
aiter.Next()) {
|
|
const StdArc &arc = aiter.Value();
|
|
if (arc.ilabel == 0) { // propagate nonemitting only...
|
|
const BaseFloat acoustic_cost = 0.0;
|
|
Token *new_tok = new Token(arc, acoustic_cost, tok);
|
|
if (new_tok->cost_ > cutoff) {
|
|
Token::TokenDelete(new_tok);
|
|
} else {
|
|
unordered_map<StateId, Token*>::iterator find_iter
|
|
= cur_toks_.find(arc.nextstate);
|
|
if (find_iter == cur_toks_.end()) {
|
|
cur_toks_[arc.nextstate] = new_tok;
|
|
queue.push_back(arc.nextstate);
|
|
} else {
|
|
if ( *(find_iter->second) < *new_tok ) {
|
|
Token::TokenDelete(find_iter->second);
|
|
find_iter->second = new_tok;
|
|
queue.push_back(arc.nextstate);
|
|
} else {
|
|
Token::TokenDelete(new_tok);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// static
|
|
void SimpleDecoder::ClearToks(unordered_map<StateId, Token*> &toks) {
|
|
for (unordered_map<StateId, Token*>::iterator iter = toks.begin();
|
|
iter != toks.end(); ++iter) {
|
|
Token::TokenDelete(iter->second);
|
|
}
|
|
toks.clear();
|
|
}
|
|
|
|
// static
|
|
void SimpleDecoder::PruneToks(BaseFloat beam, unordered_map<StateId, Token*> *toks) {
|
|
if (toks->empty()) {
|
|
KALDI_VLOG(2) << "No tokens to prune.\n";
|
|
return;
|
|
}
|
|
double best_cost = std::numeric_limits<double>::infinity();
|
|
for (unordered_map<StateId, Token*>::iterator iter = toks->begin();
|
|
iter != toks->end(); ++iter)
|
|
best_cost = std::min(best_cost, iter->second->cost_);
|
|
std::vector<StateId> retained;
|
|
double cutoff = best_cost + beam;
|
|
for (unordered_map<StateId, Token*>::iterator iter = toks->begin();
|
|
iter != toks->end(); ++iter) {
|
|
if (iter->second->cost_ < cutoff)
|
|
retained.push_back(iter->first);
|
|
else
|
|
Token::TokenDelete(iter->second);
|
|
}
|
|
unordered_map<StateId, Token*> tmp;
|
|
for (size_t i = 0; i < retained.size(); i++) {
|
|
tmp[retained[i]] = (*toks)[retained[i]];
|
|
}
|
|
KALDI_VLOG(2) << "Pruned to " << (retained.size()) << " toks.\n";
|
|
tmp.swap(*toks);
|
|
}
|
|
|
|
} // end namespace kaldi.
|