257 lines
9.4 KiB
C++
257 lines
9.4 KiB
C++
// fstext/context-fst-test.cc
|
|
|
|
// Copyright 2009-2011 Microsoft Corporation
|
|
|
|
// See ../../COPYING for clarification regarding multiple authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
|
// See the Apache 2 License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "fstext/context-fst.h"
|
|
#include "fstext/fst-test-utils.h"
|
|
#include "tree/context-dep.h"
|
|
#include "util/kaldi-io.h"
|
|
#include "base/kaldi-math.h"
|
|
|
|
namespace fst
|
|
{
|
|
using std::vector;
|
|
using std::cout;
|
|
|
|
// GenAcceptorFromSequence generates a linear acceptor (identical input+output symbols) that has this
|
|
// sequence of symbols, and
|
|
template<class Arc>
|
|
static VectorFst<Arc> *GenAcceptorFromSequence(const vector<typename Arc::Label> &symbols, float cost) {
|
|
typedef typename Arc::Weight Weight;
|
|
typedef typename Arc::StateId StateId;
|
|
|
|
vector<float> split_cost(symbols.size()+1, 0.0); // for #-arcs + end-state.
|
|
{ // compute split_cost. it must sum to "cost".
|
|
std::set<int32> indices;
|
|
size_t num_indices = 1 + (kaldi::Rand() % split_cost.size());
|
|
while (indices.size() < num_indices) indices.insert(kaldi::Rand() % split_cost.size());
|
|
for (std::set<int32>::iterator iter = indices.begin(); iter != indices.end(); ++iter) {
|
|
split_cost[*iter] = cost / num_indices;
|
|
}
|
|
}
|
|
|
|
VectorFst<Arc> *fst = new VectorFst<Arc>();
|
|
StateId cur_state = fst->AddState();
|
|
fst->SetStart(cur_state);
|
|
for (size_t i = 0; i < symbols.size(); i++) {
|
|
StateId next_state = fst->AddState();
|
|
Arc arc;
|
|
arc.ilabel = symbols[i];
|
|
arc.olabel = symbols[i];
|
|
arc.nextstate = next_state;
|
|
arc.weight = (Weight) split_cost[i];
|
|
fst->AddArc(cur_state, arc);
|
|
cur_state = next_state;
|
|
|
|
}
|
|
fst->SetFinal(cur_state, (Weight)split_cost[symbols.size()]);
|
|
return fst;
|
|
}
|
|
|
|
|
|
|
|
// CheckPhones is used to test the correctness of an FST that is the result of
|
|
// composition with a ContextFst.
|
|
template<class Arc>
|
|
static float CheckPhones(const VectorFst<Arc> &linear_fst,
|
|
const vector<typename Arc::Label> &phone_ids,
|
|
const vector<typename Arc::Label> &disambig_ids,
|
|
const vector<typename Arc::Label> &phone_seq,
|
|
const vector<vector<typename Arc::Label> > &ilabel_info,
|
|
int N, int P) {
|
|
typedef typename Arc::Label Label;
|
|
typedef typename Arc::StateId StateId;
|
|
typedef typename Arc::Weight Weight;
|
|
|
|
assert(kaldi::IsSorted(phone_ids)); // so we can do binary_search.
|
|
|
|
|
|
vector<int32> input_syms;
|
|
vector<int32> output_syms;
|
|
Weight tot_cost;
|
|
bool ans = GetLinearSymbolSequence(linear_fst, &input_syms,
|
|
&output_syms, &tot_cost);
|
|
assert(ans); // should be linear.
|
|
|
|
vector<int32> phone_seq_check;
|
|
for (size_t i = 0; i < output_syms.size(); i++)
|
|
if (std::binary_search(phone_ids.begin(), phone_ids.end(), output_syms[i]))
|
|
phone_seq_check.push_back(output_syms[i]);
|
|
|
|
assert(phone_seq_check == phone_seq);
|
|
|
|
vector<vector<int32> > input_syms_long;
|
|
for (size_t i = 0; i < input_syms.size(); i++) {
|
|
Label isym = input_syms[i];
|
|
if (ilabel_info[isym].size() == 0) continue; // epsilon.
|
|
if ( (ilabel_info[isym].size() == 1 &&
|
|
ilabel_info[isym][0] <= 0) ) continue; // disambig.
|
|
input_syms_long.push_back(ilabel_info[isym]);
|
|
}
|
|
|
|
for (size_t i = 0; i < input_syms_long.size(); i++) {
|
|
vector<int32> phone_context_window(N); // phone at pos i will be at pos P in this window.
|
|
int pos = ((int)i) - P; // pos of first phone in window [ may be out of range] .
|
|
for (int j = 0; j < N; j++, pos++) {
|
|
if (static_cast<size_t>(pos) < phone_seq.size()) phone_context_window[j] = phone_seq[pos];
|
|
else phone_context_window[j] = 0; // 0 is a special symbol that context-dep-itf expects to see
|
|
// when no phone is present due to out-of-window. context-fst knows about this too.
|
|
}
|
|
assert(input_syms_long[i] == phone_context_window);
|
|
}
|
|
return tot_cost.Value();
|
|
}
|
|
|
|
|
|
|
|
|
|
template<class Arc>
|
|
static VectorFst<Arc> *GenRandPhoneSeq(vector<typename Arc::Label> &phone_syms,
|
|
vector<typename Arc::Label> &disambig_syms,
|
|
typename Arc::Label subsequential_symbol,
|
|
int num_subseq_syms,
|
|
float seq_prob,
|
|
vector<typename Arc::Label> *phoneseq_out) {
|
|
KALDI_ASSERT(phoneseq_out != NULL);
|
|
typedef typename Arc::Label Label;
|
|
// Generate an FST that is a random phone sequence, ending
|
|
// with "num_subseq_syms" subsequential symbols. It will
|
|
// have disambiguation symbols randomly interspersed throughout.
|
|
// The number of phones is random (possibly zero).
|
|
size_t len = (kaldi::Rand() % 4) * (kaldi::Rand() % 3); // up to 3*2=6 phones.
|
|
float disambig_prob = 0.33;
|
|
phoneseq_out->clear();
|
|
vector<Label> syms; // the phones
|
|
for (size_t i = 0; i < len; i++) {
|
|
while (kaldi::RandUniform() < disambig_prob) syms.push_back(disambig_syms[kaldi::Rand() % disambig_syms.size()]);
|
|
Label phone_id = phone_syms[kaldi::Rand() % phone_syms.size()];
|
|
phoneseq_out->push_back(phone_id); // record in output the underlying phone sequence.
|
|
syms.push_back(phone_id);
|
|
}
|
|
for (size_t i = 0; static_cast<int32>(i) < num_subseq_syms; i++) {
|
|
while (kaldi::RandUniform() < disambig_prob) syms.push_back(disambig_syms[kaldi::Rand() % disambig_syms.size()]);
|
|
syms.push_back(subsequential_symbol);
|
|
}
|
|
while (kaldi::RandUniform() < disambig_prob) syms.push_back(disambig_syms[kaldi::Rand() % disambig_syms.size()]);
|
|
|
|
// OK, now have the symbols of the FST as a vector.
|
|
return GenAcceptorFromSequence<Arc>(syms, seq_prob);
|
|
}
|
|
|
|
// Don't instantiate with log semiring, as RandEquivalent may fail.
|
|
// TestContestFst also test ReadILabelInfo and WriteILabelInfo.
|
|
static void TestContextFst(bool verbose, bool use_matcher) {
|
|
typedef StdArc Arc;
|
|
typedef Arc::Label Label;
|
|
typedef Arc::StateId StateId;
|
|
typedef Arc::Weight Weight;
|
|
|
|
// Generate a random set of phones.
|
|
size_t num_phones = 1 + kaldi::Rand() % 10;
|
|
std::set<int32> phones_set;
|
|
while (phones_set.size() < num_phones) phones_set.insert(1 + kaldi::Rand() % (num_phones + 5)); // don't use 0 [== epsilon]
|
|
vector<int32> phones;
|
|
kaldi::CopySetToVector(phones_set, &phones);
|
|
|
|
int N = 1 + kaldi::Rand() % 4; // Context size, in range 1..4.
|
|
int P = kaldi::Rand() % N; // 1.. N-1.
|
|
if (verbose) std::cout << "N = "<< N << ", P = "<<P<<'\n';
|
|
|
|
Label subsequential_symbol = 1000;
|
|
vector<int32> disambig_syms;
|
|
for (size_t i =0; i < 5; i++) disambig_syms.push_back(500 + i);
|
|
vector<int32> phone_syms;
|
|
for (size_t i = 0; i < phones.size();i++) phone_syms.push_back(phones[i]);
|
|
|
|
|
|
InverseContextFst inv_cfst(subsequential_symbol,
|
|
phones, disambig_syms,
|
|
N, P);
|
|
|
|
|
|
/* Now create random phone-sequences and compose them with the context FST.
|
|
*/
|
|
|
|
for (size_t p = 0; p < 10; p++) {
|
|
vector<int32> phone_seq;
|
|
int num_subseq = N - P - 1; // zero if P == N-1, i.e. P is last element, i.e. left-context only.
|
|
float tot_cost = 20.0 * kaldi::RandUniform();
|
|
VectorFst<Arc> *f = GenRandPhoneSeq<Arc>(phone_syms, disambig_syms, subsequential_symbol, num_subseq, tot_cost, &phone_seq);
|
|
if (verbose) {
|
|
std::cout << "Sequence FST is:\n";
|
|
{ // Try to print the fst.
|
|
FstPrinter<Arc> fstprinter(*f, NULL, NULL, NULL, false, true, "\t");
|
|
fstprinter.Print(&std::cout, "standard output");
|
|
}
|
|
}
|
|
|
|
VectorFst<Arc> fst_composed;
|
|
|
|
ComposeDeterministicOnDemandInverse(*f, &inv_cfst, &fst_composed);
|
|
|
|
|
|
// Testing WriteILabelInfo and ReadILabelInfo.
|
|
{
|
|
bool binary = (kaldi::Rand() % 2 == 0);
|
|
WriteILabelInfo(kaldi::Output("tmpf", binary).Stream(),
|
|
binary, inv_cfst.IlabelInfo());
|
|
|
|
bool binary_in;
|
|
vector<vector<int32> > ilabel_info;
|
|
kaldi::Input ki("tmpf", &binary_in);
|
|
ReadILabelInfo(ki.Stream(),
|
|
binary_in, &ilabel_info);
|
|
assert(ilabel_info == inv_cfst.IlabelInfo());
|
|
}
|
|
|
|
|
|
if (verbose) {
|
|
std::cout << "Composed FST is:\n";
|
|
{ // Try to print the fst.
|
|
FstPrinter<Arc> fstprinter(fst_composed, NULL, NULL, NULL, false, true, "\t");
|
|
fstprinter.Print(&std::cout, "standard output");
|
|
}
|
|
}
|
|
|
|
// now check the composed FST.
|
|
float tot_cost_check = CheckPhones<Arc>(fst_composed,
|
|
phone_syms,
|
|
disambig_syms,
|
|
phone_seq,
|
|
inv_cfst.IlabelInfo(),
|
|
N, P);
|
|
kaldi::AssertEqual(tot_cost, tot_cost_check);
|
|
|
|
delete f;
|
|
}
|
|
|
|
unlink("tmpf");
|
|
}
|
|
|
|
|
|
} // namespace fst
|
|
|
|
int main() {
|
|
|
|
for (int i = 0;i < 16;i++) {
|
|
bool verbose = (i < 4);
|
|
bool use_matcher = ( (i/4) % 2 == 0);
|
|
fst::TestContextFst(verbose, use_matcher);
|
|
}
|
|
}
|