// FunASR/runtime/onnxruntime/third_party/kaldi/fstext/remove-eps-local-test.cc
//
// 179 lines
// 5.6 KiB
// C++

// fstext/remove-eps-local-test.cc
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "fstext/remove-eps-local.h"

#include <cassert>
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>

#include "fstext/fstext-utils.h"
#include "fstext/fst-test-utils.h"
#include "base/kaldi-math.h"
namespace fst
{
using std::vector;
using std::cout;
// Don't instantiate with log semiring, as RandEquivalent may fail.
template<class Arc> static void TestRemoveEpsLocal() {
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
VectorFst<Arc> fst;
int n_syms = 2 + kaldi::Rand() % 5, n_arcs = 5 + kaldi::Rand() % 30, n_final = 1 + kaldi::Rand()%10;
SymbolTable symtab("my-symbol-table"), *sptr = &symtab;
vector<Label> all_syms; // including epsilon.
// Put symbols in the symbol table from 1..n_syms-1.
for (size_t i = 0;i < (size_t)n_syms;i++) {
std::stringstream ss;
if (i == 0) ss << "<eps>";
else ss<<i;
Label cur_lab = sptr->AddSymbol(ss.str());
assert(cur_lab == (Label)i);
all_syms.push_back(cur_lab);
}
assert(all_syms[0] == 0);
fst.AddState();
int cur_num_states = 1;
for (int i = 0; i < n_arcs; i++) {
StateId src_state = kaldi::Rand() % cur_num_states;
StateId dst_state;
if (kaldi::RandUniform() < 0.1) dst_state = kaldi::Rand() % cur_num_states;
else {
dst_state = cur_num_states++; fst.AddState();
}
Arc arc;
if (kaldi::RandUniform() < 0.3) arc.ilabel = all_syms[kaldi::Rand()%all_syms.size()];
else arc.ilabel = 0;
if (kaldi::RandUniform() < 0.3) arc.olabel = all_syms[kaldi::Rand()%all_syms.size()];
else arc.olabel = 0;
arc.weight = (Weight) (0 + 0.1*(kaldi::Rand() % 5));
arc.nextstate = dst_state;
fst.AddArc(src_state, arc);
}
for (int i = 0; i < n_final; i++) {
fst.SetFinal(kaldi::Rand() % cur_num_states, (Weight) (0 + 0.1*(kaldi::Rand() % 5)));
}
if (kaldi::RandUniform() < 0.8) fst.SetStart(0); // usually leads to nicer examples.
else fst.SetStart(kaldi::Rand() % cur_num_states);
Connect(&fst);
if (fst.Start() == kNoStateId) return; // "Connect" made it empty.
std::cout <<" printing after trimming\n";
{
FstPrinter<Arc> fstprinter(fst, sptr, sptr, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> fst_copy1(fst);
RemoveEpsLocal(&fst_copy1);
{
std::cout << "copy1 = \n";
FstPrinter<Arc> fstprinter(fst_copy1, sptr, sptr, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
int num_states_0 = fst.NumStates();
int num_states_1 = fst_copy1.NumStates();
std::cout << "Number of states 0 = "<<num_states_0<<", 1 = "<<num_states_1<<'\n';
assert(RandEquivalent(fst, fst_copy1, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
}
static void TestRemoveEpsLocalSpecial() {
// test that RemoveEpsLocalSpecial preserves equivalence in tropical while
// maintaining stochasticity in log.
typedef VectorFst<LogArc> Fst;
typedef LogArc::Weight Weight;
typedef LogArc::StateId StateId;
typedef LogArc Arc;
VectorFst<LogArc> *logfst = RandFst<LogArc>();
{ // Make the FST stochastic.
for (StateId s = 0; s < logfst->NumStates(); s++) {
Weight w = logfst->Final(s);
for (ArcIterator<Fst> aiter(*logfst, s); !aiter.Done(); aiter.Next()) {
w = Plus(w, aiter.Value().weight);
}
if (w != Weight::Zero()) {
logfst->SetFinal(s, Divide(logfst->Final(s), w, DIVIDE_ANY));
for (MutableArcIterator<Fst> aiter(logfst, s); !aiter.Done(); aiter.Next()) {
Arc a = aiter.Value();
a.weight = Divide(a.weight, w, DIVIDE_ANY);
aiter.SetValue(a);
}
}
}
}
#ifndef _MSC_VER
assert(IsStochasticFst(*logfst, kDelta*10));
#endif
{
std::cout << "logfst = \n";
FstPrinter<LogArc> fstprinter(*logfst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<StdArc> fst;
Cast(*logfst, &fst);
VectorFst<StdArc> fst_copy(fst);
RemoveEpsLocalSpecial(&fst); // removes eps in std-arc but keep stochastic in log-arc
// make sure equivalent.
assert(RandEquivalent(fst, fst_copy, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
VectorFst<LogArc> logfst2;
Cast(fst, &logfst2);
{
std::cout << "logfst2 = \n";
FstPrinter<LogArc> fstprinter(logfst2, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
if (ApproxEqual(ShortestDistance(*logfst), ShortestDistance(logfst2))) {
// make sure we preserved stochasticity in cases where doing so was
// possible... if the log-semiring total weight changed, then it is
// not possible so don't assert this.
assert(IsStochasticFst(logfst2, kDelta*10));
}
delete logfst;
}
} // namespace fst
// Runs both RemoveEpsLocal tests repeatedly; every call draws new random
// numbers, so repetition exercises a variety of FSTs.
int main() {
  const int kNumIterations = 10;
  for (int iter = 0; iter != kNumIterations; ++iter) {
    fst::TestRemoveEpsLocal<fst::StdArc>();
    fst::TestRemoveEpsLocalSpecial();
  }
}