FunASR/runtime/onnxruntime/third_party/kaldi/lm/arpa-lm-compiler-test.cc

251 lines
7.9 KiB
C++
Raw Normal View History

2024-05-18 15:50:56 +08:00
// lm/arpa-lm-compiler-test.cc
// Copyright 2009-2011 Gilles Boulianne
// Copyright 2016 Smart Action LLC (kkm)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "base/kaldi-error.h"
#include "base/kaldi-math.h"
#include "lm/arpa-lm-compiler.h"
#include "util/kaldi-io.h"
namespace kaldi {
// Fixed ids for the special symbols used throughout these tests. The concrete
// values are arbitrary (any integer is as good as any other); they only need
// to be distinct and known in advance.
enum {
  kEps = 0,       // Label used for " <eps>".
  kDisambig = 1,  // Label used for " #0".
  kBos = 2,       // Label used for "<s>".
  kEos = 3        // Label used for "</s>".
};
// How many random sentences each coverage test generates and composes.
static constexpr int kRandomSentences = 50;
// Builds a heap-allocated FST that accepts any sequence of the ordinary
// (non-special) symbols of *pst, and attaches pst as both its input and output
// symbol table. When seps is false, every accepted sequence is additionally
// wrapped in a leading kBos arc and a trailing kEos arc; when seps is true the
// single looping state is both the start and the final state.
// The caller owns the returned FST and must delete it.
static fst::StdVectorFst* CreateGenFst(bool seps, const fst::SymbolTable* pst) {
  typedef fst::StdArc::StateId StateId;

  fst::StdVectorFst* genFst = new fst::StdVectorFst;
  genFst->SetInputSymbols(pst);
  genFst->SetOutputSymbols(pst);

  // State that carries one self-loop per ordinary symbol.
  const StateId loop_state = genFst->AddState();

  if (seps) {
    genFst->SetStart(loop_state);
    genFst->SetFinal(loop_state, fst::StdArc::Weight::One());
  } else {
    const StateId start_state = genFst->AddState();
    const StateId end_state = genFst->AddState();
    genFst->SetStart(start_state);
    genFst->SetFinal(end_state, fst::StdArc::Weight::One());
    genFst->AddArc(start_state, fst::StdArc(kBos, kBos, 0, loop_state));
    genFst->AddArc(loop_state, fst::StdArc(kEos, kEos, 0, end_state));
  }

  // Add a loop for each symbol in the table except the four special ones.
  fst::SymbolTableIterator sym_iter(*pst);
  sym_iter.Reset();
  while (!sym_iter.Done()) {
    const auto label = sym_iter.Value();
    const bool is_special = (label == kBos) || (label == kEos) ||
                            (label == kEps) || (label == kDisambig);
    if (!is_special) {
      genFst->AddArc(loop_state,
                     fst::StdArc(label, label, fst::StdArc::Weight::One(),
                                 loop_state));
    }
    sym_iter.Next();
  }
  return genFst;
}
// Compiles the ARPA file `infile` and returns a newly allocated
// ArpaLmCompiler holding the result; the caller owns the returned object and
// must delete it.
//
// seps selects the compiler's second constructor argument: kDisambig when
// true, 0 when false.
//
// Any exception raised while opening or reading the file propagates to the
// caller; the partially built compiler is destroyed first, so nothing leaks.
//
// NOTE(review): the compiler is constructed with a pointer to the local
// `symbols` table, which goes out of scope on return. Callers in this file
// only use Fst()/MutableFst() afterwards -- confirm the compiler never
// dereferences that pointer after Read().
ArpaLmCompiler* Compile(bool seps, const std::string &infile) {
  ArpaParseOptions options;
  fst::SymbolTable symbols;
  // Use spaces on special symbols, so we rather fail than read them by mistake.
  symbols.AddSymbol(" <eps>", kEps);
  symbols.AddSymbol(" #0", kDisambig);
  options.bos_symbol = symbols.AddSymbol("<s>", kBos);
  options.eos_symbol = symbols.AddSymbol("</s>", kEos);
  options.oov_handling = ArpaParseOptions::kAddToSymbols;

  // Tests in this form cannot be run with epsilon substitution, unless every
  // random path is also fitted with a #0-transducing self-loop.
  //
  // Hold the compiler in a unique_ptr while reading: Read() is expected to
  // throw on malformed input (see ThrowsExceptionTest), and a raw pointer
  // would be leaked on that path.
  std::unique_ptr<ArpaLmCompiler> lm_compiler(
      new ArpaLmCompiler(options, seps ? kDisambig : 0, &symbols));
  {
    Input ki(infile);
    lm_compiler->Read(ki.Stream());
  }
  // Success: hand ownership over to the caller.
  return lm_compiler.release();
}
// Extends a chain-shaped FSA by one step: creates a new state, adds an arc
// labelled `symbol` (on both input and output, weight 0) from last_state to
// the new state, and returns the new state.
fst::StdArc::StateId AddToChainFsa(fst::StdMutableFst* fst,
                                   fst::StdArc::StateId last_state,
                                   int64 symbol) {
  const fst::StdArc::StateId appended = fst->AddState();
  const fst::StdArc link(symbol, symbol, 0, appended);
  fst->AddArc(last_state, link);
  return appended;
}
// Gives every state of *fst a self-loop arc built from (kEps, kDisambig) with
// weight 0, i.e. a loop that produces the disambiguation symbol.
void AddSelfLoops(fst::StdMutableFst* fst) {
  fst::StateIterator<fst::StdMutableFst> states(*fst);
  while (!states.Done()) {
    const fst::StdArc::StateId s = states.Value();
    const fst::StdArc loop(kEps, kDisambig, 0, s);
    fst->AddArc(s, loop);
    states.Next();
  }
}
// Compiles infile and then runs kRandomSentences random coverage tests on the
// compiled FST.
bool CoverageTest(bool seps, const std::string &infile) {
// Compile ARPA model.
ArpaLmCompiler* lm_compiler = Compile(seps, infile);
// Create an FST that generates any sequence of symbols taken from the model
// output.
fst::StdVectorFst* genFst =
CreateGenFst(seps, lm_compiler->Fst().OutputSymbols());
int num_successes = 0;
for (int32 i = 0; i < kRandomSentences; ++i) {
// Generate a random sentence FST.
fst::StdVectorFst sentence;
RandGen(*genFst, &sentence);
if (seps)
AddSelfLoops(&sentence);
fst::ArcSort(lm_compiler->MutableFst(), fst::StdOLabelCompare());
// The past must successfully compose with the LM FST.
fst::StdVectorFst composition;
Compose(sentence, lm_compiler->Fst(), &composition);
if (composition.Start() != fst::kNoStateId)
++num_successes;
}
delete genFst;
delete lm_compiler;
bool ok = num_successes == kRandomSentences;
if (!ok) {
KALDI_WARN << "Coverage test failed on " << infile << ": composed "
<< num_successes << "/" << kRandomSentences;
}
return ok;
}
bool ScoringTest(bool seps, const std::string &infile, const std::string& sentence,
float expected) {
ArpaLmCompiler* lm_compiler = Compile(seps, infile);
const fst::SymbolTable* symbols = lm_compiler->Fst().InputSymbols();
// Create a sentence FST for scoring.
fst::StdVectorFst sentFst;
fst::StdArc::StateId state = sentFst.AddState();
sentFst.SetStart(state);
if (!seps) {
state = AddToChainFsa(&sentFst, state, kBos);
}
std::stringstream ss(sentence);
std::string word;
while (ss >> word) {
int64 word_sym = symbols->Find(word);
KALDI_ASSERT(word_sym != -1);
state = AddToChainFsa(&sentFst, state, word_sym);
}
if (!seps) {
state = AddToChainFsa(&sentFst, state, kEos);
}
if (seps) {
AddSelfLoops(&sentFst);
}
sentFst.SetFinal(state, 0);
sentFst.SetOutputSymbols(symbols);
// Do the composition and extract final weight.
fst::StdVectorFst composed;
fst::Compose(sentFst, lm_compiler->Fst(), &composed);
delete lm_compiler;
if (composed.Start() == fst::kNoStateId) {
KALDI_WARN << "Test sentence " << sentence << " did not compose "
<< "with the language model FST\n";
return false;
}
std::vector<fst::StdArc::Weight> shortest;
fst::ShortestDistance(composed, &shortest, true);
float actual = shortest[composed.Start()].Value();
bool ok = ApproxEqual(expected, actual);
if (!ok) {
KALDI_WARN << "Scored " << sentence << " in " << infile
<< ": Expected=" << expected << " actual=" << actual;
}
return ok;
}
// Checks that compiling infile fails: returns true if Compile() raises
// KaldiFatalError, and false if compilation completes normally. Any other
// exception type is not caught here.
bool ThrowsExceptionTest(bool seps, const std::string &infile) {
  bool threw = false;
  try {
    // Take ownership of the result so it is destroyed when the try block
    // ends, should compilation unexpectedly succeed.
    std::unique_ptr<ArpaLmCompiler> compiler(Compile(seps, infile));
  } catch (const KaldiFatalError&) {
    threw = true;
  }
  return threw;
}
} // namespace kaldi
// Runs the complete test suite once for the given seps setting. Every test is
// executed even if an earlier one has failed. Returns true only if all of
// them passed; otherwise logs a warning naming the failing mode.
bool RunAllTests(bool seps) {
  const bool missing_backoffs_ok =
      kaldi::CoverageTest(seps, "test_data/missing_backoffs.arpa");
  const bool unused_backoffs_ok =
      kaldi::CoverageTest(seps, "test_data/unused_backoffs.arpa");
  const bool input_coverage_ok =
      kaldi::CoverageTest(seps, "test_data/input.arpa");
  const bool long_score_ok =
      kaldi::ScoringTest(seps, "test_data/input.arpa", "b b b a", 59.2649);
  const bool short_score_ok =
      kaldi::ScoringTest(seps, "test_data/input.arpa", "a b", 4.36082);
  const bool missing_bos_ok =
      kaldi::ThrowsExceptionTest(seps, "test_data/missing_bos.arpa");

  // All results are already computed above, so combining them with && does
  // not skip any test.
  const bool all_ok = missing_backoffs_ok && unused_backoffs_ok &&
                      input_coverage_ok && long_score_ok && short_score_ok &&
                      missing_bos_ok;
  if (all_ok)
    return true;

  KALDI_WARN << "Tests " << (seps ? "with" : "without")
             << " epsilon substitution FAILED";
  return false;
}
// Test driver: runs the suite in both modes (always both, regardless of the
// first outcome) and maps the combined result to the process exit code:
// 0 when everything passed, 1 otherwise.
int main(int argc, char *argv[]) {
  const bool plain_ok = RunAllTests(false);  // Without disambiguators (old behavior).
  const bool seps_ok = RunAllTests(true);    // With epsilon substitution (new behavior).

  if (!(plain_ok && seps_ok)) {
    KALDI_WARN << "Test FAILED";
    return 1;
  }
  KALDI_LOG << "All tests passed";
  return 0;
}