FunASR/runtime/onnxruntime/third_party/kaldi/fstext/trivial-factor-weight.h

402 lines
13 KiB
C++

// fstext/trivial-factor-weight.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
//
//
// This is a modified file from the OpenFST Library v1.2.7 available at
// http://www.openfst.org and released under the Apache License Version 2.0.
//
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
#ifndef KALDI_FSTEXT_TRIVIAL_FACTOR_WEIGHT_H_
#define KALDI_FSTEXT_TRIVIAL_FACTOR_WEIGHT_H_
// trivial-factor-weight.h: this is an extension to factor-weight.h in the
// OpenFst code. It is a version of FactorWeight that creates separate states
// (with input epsilons) rather than pushing the factors forward. This is for
// converting from Gallic FSTs, where you want the result to be a bit more
// trivial, with input epsilons inserted where there are multiple output
// symbols. This has the advantage that it always works, for any input (also I
// just prefer this approach).
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <fst/cache.h>
#include <fst/test-properties.h>
using std::unordered_map;
namespace fst {
// Options for TrivialFactorWeightFst.  Extends CacheOptions with the
// quantization delta and the labels placed on the extra arcs that the
// factoring inserts.
template <class Arc>
struct TrivialFactorWeightOptions : CacheOptions {
  using Label = typename Arc::Label;

  float delta;         // passed to Weight::Quantize() for residual weights
  Label extra_ilabel;  // input label of extra arcs
  Label extra_olabel;  // output label of extra arcs

  // Default cache options, delta = kDelta, both extra labels 0.
  TrivialFactorWeightOptions()
      : delta(kDelta), extra_ilabel(0), extra_olabel(0) {}

  // Default cache options with an explicit delta (and optional labels).
  explicit TrivialFactorWeightOptions(float d, Label il = 0, Label ol = 0)
      : delta(d), extra_ilabel(il), extra_olabel(ol) {}

  // Explicit cache options and delta (and optional labels).
  TrivialFactorWeightOptions(const CacheOptions &opts, float d,
                             Label il = 0, Label ol = 0)
      : CacheOptions(opts), delta(d), extra_ilabel(il), extra_olabel(ol) {}
};
namespace internal {

// Implementation class for TrivialFactorWeightFst.
//
// Every output state corresponds to an Element (input state, residual weight):
//   * weight == Weight::One() and state != kNoStateId: a "real" state, i.e. a
//     copy of the input state `state`;
//   * weight != Weight::One() and state != kNoStateId: an intermediate state
//     that still has to emit `weight` before reaching the real state `state`;
//   * state == kNoStateId: an extra state inserted to represent (what is left
//     of) a final weight.
// Arcs and final weights are computed on demand (see Expand() and Final()) and
// stored in the cache provided by CacheImpl<A>.
//
// Template parameters:
//   A  the arc type.
//   F  the factor iterator type; constructed from a Weight, it provides Done()
//      and Value(), the latter returning a std::pair<Weight, Weight>.
template <class A, class F>
class TrivialFactorWeightFstImpl
    : public CacheImpl<A> {
 public:
  using CacheImpl<A>::PushArc;

  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::Properties;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;

  using CacheBaseImpl< CacheState<A> >::HasStart;
  using CacheBaseImpl< CacheState<A> >::HasFinal;
  using CacheBaseImpl< CacheState<A> >::HasArcs;

  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef F FactorIterator;

  typedef DefaultCacheStore<A> Store;
  typedef typename Store::State State;

  struct Element {
    Element() {}

    Element(StateId s, Weight w) : state(s), weight(w) {}

    StateId state;     // Input state Id
    Weight weight;     // Residual weight
  };

  // Builds the implementation on top of a copy of `fst`, taking delta and the
  // extra arc labels from `opts`.
  TrivialFactorWeightFstImpl(const Fst<A> &fst, const TrivialFactorWeightOptions<A> &opts)
      : CacheImpl<A>(opts),
        fst_(fst.Copy()),
        delta_(opts.delta),
        extra_ilabel_(opts.extra_ilabel),
        extra_olabel_(opts.extra_olabel) {
    SetType("factor-weight");
    uint64 props = fst.Properties(kFstProperties, false);
    SetProperties(FactorWeightProperties(props), kCopyProperties);

    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  // Copy constructor.  The element table (elements_ / element_map_) is not
  // copied; it starts out empty in the new object.
  TrivialFactorWeightFstImpl(const TrivialFactorWeightFstImpl<A, F> &impl)
      : CacheImpl<A>(impl),
        fst_(impl.fst_->Copy(true)),
        delta_(impl.delta_),
        extra_ilabel_(impl.extra_ilabel_),
        extra_olabel_(impl.extra_olabel_) {
    SetType("factor-weight");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Returns the start state, creating it on first use.  Returns kNoStateId
  // (without caching anything) if the input FST has no start state.
  StateId Start() {
    if (!HasStart()) {
      const StateId s = fst_->Start();
      if (s == kNoStateId)
        return kNoStateId;
      this->SetStart(FindState(Element(s, Weight::One())));
    }
    return CacheImpl<A>::Start();
  }

  // Returns the final weight of output state `s`, computing and caching it on
  // first use.  A state is final only if its remaining weight cannot be
  // factored any further; otherwise the weight is emitted on arcs created by
  // Expand() and this returns Weight::Zero().
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      const Element &e = elements_[s];
      Weight w;
      if (e.state == kNoStateId) {  // extra state inserted to represent final weights.
        FactorIterator fit(e.weight);
        if (fit.Done()) {  // cannot be factored.
          w = e.weight;  // so it's final
        } else {
          w = Weight::Zero();  // need another transition.
        }
      } else {
        if (e.weight != Weight::One()) {  // Not a real state.
          w = Weight::Zero();
        } else {  // corresponds to a "real" state.
          w = fst_->Final(e.state);
          FactorIterator fit(w);
          if (!fit.Done())  // we would have intermediate states representing this final state.
            w = Weight::Zero();
        }
      }
      this->SetFinal(s, w);
      return w;
    } else {
      return CacheImpl<A>::Final(s);
    }
  }

  // Number of arcs leaving `s`; expands `s` first if necessary.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumArcs(s);
  }

  // Number of input-epsilon arcs leaving `s`; expands `s` first if necessary.
  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumInputEpsilons(s);
  }

  // Number of output-epsilon arcs leaving `s`; expands `s` first if necessary.
  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumOutputEpsilons(s);
  }

  // Initializes `data` for iterating over the arcs of `s`; expands `s` first
  // if necessary.
  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    if (!HasArcs(s))
      Expand(s);
    CacheImpl<A>::InitArcIterator(s, data);
  }

  // Find state corresponding to an element. Create new state
  // if element not found.
  StateId FindState(const Element &e) {
    // Tentatively insert with the id the element would get; if the element is
    // already present the map is left unchanged and the existing id is
    // returned.  This needs a single hash lookup instead of find + insert.
    const StateId new_id = static_cast<StateId>(elements_.size());
    const std::pair<typename ElementMap::iterator, bool> result =
        element_map_.insert(typename ElementMap::value_type(e, new_id));
    if (result.second)
      elements_.push_back(e);
    return result.first->second;
  }

  // Computes the outgoing transitions from a state, creating new destination
  // states as needed.
  void Expand(StateId s) {
    CHECK(static_cast<size_t>(s) < elements_.size());
    // Deliberately a copy: FindState() below may append to elements_, which
    // could invalidate a reference into the vector.
    Element e = elements_[s];
    if (e.weight != Weight::One()) {
      // Intermediate or extra state: emit (part of) the residual weight on a
      // single arc labelled with the extra labels.
      FactorIterator fit(e.weight);
      if (fit.Done()) {  // Cannot be factored-> create a link to dest state directly
        if (e.state != kNoStateId) {
          StateId dest = FindState(Element(e.state, Weight::One()));
          PushArc(s, Arc(extra_ilabel_, extra_olabel_, e.weight, dest));
        }  // else we're done.  This is a final state.
      } else {  // Can be factored.
        // p.first goes on the arc; p.second (quantized) is the residual
        // carried by the destination state.
        const std::pair<Weight, Weight> &p = fit.Value();
        StateId dest = FindState(Element(e.state, p.second.Quantize(delta_)));
        PushArc(s, Arc(extra_ilabel_, extra_olabel_, p.first, dest));
      }
    } else {  // Unit weight.  This corresponds to a "real" state.
      CHECK(e.state != kNoStateId);
      for (ArcIterator< Fst<A> > ait(*fst_, e.state);
           !ait.Done();
           ait.Next()) {
        const A &arc = ait.Value();
        FactorIterator fit(arc.weight);
        if (fit.Done()) {  // cannot be factored->just link directly to dest.
          StateId dest = FindState(Element(arc.nextstate, Weight::One()));
          PushArc(s, Arc(arc.ilabel, arc.olabel, arc.weight, dest));
        } else {
          // Keep the original labels on the first arc; the residual weight is
          // emitted later from the intermediate destination state.
          const std::pair<Weight, Weight> &p = fit.Value();
          StateId dest = FindState(Element(arc.nextstate, p.second.Quantize(delta_)));
          PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, dest));
        }
      }
      // See if we have to add arcs for final-states [only if final-weight is factorable].
      Weight final_w = fst_->Final(e.state);
      if (final_w != Weight::Zero()) {
        FactorIterator fit(final_w);
        if (!fit.Done()) {
          const std::pair<Weight, Weight> &p = fit.Value();
          StateId dest = FindState(Element(kNoStateId, p.second.Quantize(delta_)));
          PushArc(s, Arc(extra_ilabel_, extra_olabel_, p.first, dest));
        }
      }
    }
    this->SetArcs(s);
  }

 private:
  // Equality function for Elements, assume weights have been quantized.
  class ElementEqual {
   public:
    bool operator()(const Element &x, const Element &y) const {
      return x.state == y.state && x.weight == y.weight;
    }
  };

  // Hash function for Elements to Fst states.
  class ElementKey {
   public:
    size_t operator()(const Element &x) const {
      // Convert the state id to size_t *before* multiplying: doing the
      // multiplication in (signed) StateId arithmetic overflows, which is
      // undefined behaviour, once state ids get large.  Unsigned arithmetic
      // wraps and gives the same value wherever the signed product was
      // representable (including for kNoStateId).
      return static_cast<size_t>(x.state) * kPrime + x.weight.Hash();
    }
   private:
    static const size_t kPrime = 7853;
  };

  typedef std::unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;

  std::unique_ptr<const Fst<A>> fst_;  // private copy of the input FST
  float delta_;                        // passed to Quantize() for residual weights
  Label extra_ilabel_;                 // ilabel of arc created when factoring final w's
  Label extra_olabel_;                 // olabel of arc created when factoring final w's
  std::vector<Element> elements_;      // mapping Fst state to Elements
  ElementMap element_map_;             // mapping Elements to Fst state
};

}  // namespace internal
/// TrivialFactorWeightFst takes a FactorIterator as its second template
/// parameter (the FactorIterator concept is not defined in this file;
/// presumably it is the one from OpenFst's factor-weight.h, which this header
/// extends).  The result of weight factoring is a transducer equivalent to
/// the input whose path weights have been factored according to the
/// FactorIterator.  States and transitions will be added as necessary.
///
/// This algorithm differs from the one implemented in FactorWeightFst in that
/// it does not attempt to push the extra weight forward to the next state: it
/// uses a sequence of "extra" intermediate states, and outputs the remaining
/// weight right away.  This ensures that it will always succeed, even for
/// Gallic representations of FSTs that have cycles with more output than
/// input symbols.
///
/// Note that the code below was modified from factor-weight.h by just
/// search-and-replacing "FactorWeight" by "TrivialFactorWeight".
template <class A, class F>
class TrivialFactorWeightFst
    : public ImplToFst<internal::TrivialFactorWeightFstImpl<A, F>> {
 public:
  using Arc = A;
  using Weight = typename A::Weight;
  using StateId = typename A::StateId;
  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;
  using Impl = internal::TrivialFactorWeightFstImpl<A, F>;

  // The iterator specializations need access to GetImpl()/GetMutableImpl().
  friend class ArcIterator<TrivialFactorWeightFst<A, F>>;
  friend class StateIterator<TrivialFactorWeightFst<A, F>>;

  /// Wraps `fst` using default-constructed TrivialFactorWeightOptions.
  explicit TrivialFactorWeightFst(const Fst<A> &fst)
      : ImplToFst<Impl>(
            std::make_shared<Impl>(fst, TrivialFactorWeightOptions<A>())) {}

  /// Wraps `fst` using the supplied options.
  TrivialFactorWeightFst(const Fst<A> &fst,
                         const TrivialFactorWeightOptions<A> &opts)
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}

  /// Copy constructor; see Fst<>::Copy() for the meaning of `copy`.
  TrivialFactorWeightFst(const TrivialFactorWeightFst<A, F> &fst, bool copy)
      : ImplToFst<Impl>(fst, copy) {}

  /// Get a copy of this TrivialFactorWeightFst.  The object is allocated with
  /// `new`.  See Fst<>::Copy() for further doc.
  TrivialFactorWeightFst<A, F> *Copy(bool copy = false) const override {
    return new TrivialFactorWeightFst<A, F>(*this, copy);
  }

  /// Defined out of line, after the StateIterator specialization.
  inline void InitStateIterator(StateIteratorData<A> *data) const override;

  /// Forwards to the implementation's InitArcIterator().
  void InitArcIterator(StateId s, ArcIteratorData<A> *data) const override {
    GetMutableImpl()->InitArcIterator(s, data);
  }

 private:
  using ImplToFst<Impl>::GetImpl;
  using ImplToFst<Impl>::GetMutableImpl;

  // Assignment is disabled.
  TrivialFactorWeightFst &operator=(const TrivialFactorWeightFst &fst) = delete;
};
// StateIterator specialization for TrivialFactorWeightFst: a thin wrapper
// that hands the FST and its mutable implementation to CacheStateIterator.
template <class A, class F>
class StateIterator<TrivialFactorWeightFst<A, F>>
    : public CacheStateIterator<TrivialFactorWeightFst<A, F>> {
 public:
  explicit StateIterator(const TrivialFactorWeightFst<A, F> &fst)
      : CacheStateIterator<TrivialFactorWeightFst<A, F>>(
            fst, fst.GetMutableImpl()) {}
};
// ArcIterator specialization for TrivialFactorWeightFst.  Construction makes
// sure the arcs of the requested state have been computed.
template <class A, class F>
class ArcIterator<TrivialFactorWeightFst<A, F>>
    : public CacheArcIterator<TrivialFactorWeightFst<A, F>> {
 public:
  using StateId = typename A::StateId;

  ArcIterator(const TrivialFactorWeightFst<A, F> &fst, StateId s)
      : CacheArcIterator<TrivialFactorWeightFst<A, F>>(fst.GetMutableImpl(),
                                                       s) {
    auto *impl = fst.GetMutableImpl();
    // Nothing to do if the arcs of `s` are already available.
    if (impl->HasArcs(s)) return;
    impl->Expand(s);
  }
};
// Stores a newly allocated StateIterator over this FST in data->base.
// NOTE(review): presumably the StateIteratorData machinery takes ownership of
// the pointer -- confirm against the OpenFst version in use.
template <class A, class F>
inline void TrivialFactorWeightFst<A, F>::InitStateIterator(
    StateIteratorData<A> *data) const {
  using Iterator = StateIterator<TrivialFactorWeightFst<A, F>>;
  data->base = new Iterator(*this);
}
} // namespace fst
#endif