// Copyright 2020 Jiayu DU

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifdef HAVE_KENLM
#ifndef KALDI_LM_KENLM_H
#define KALDI_LM_KENLM_H

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// NOTE(review): the targets of the next four includes were missing from the
// header as received; they are restored from the names this file uses
// (KALDI_ASSERT / BaseFloat / int32, fst::DeterministicOnDemandFst).
// Confirm against the build.
#include <base/kaldi-common.h>
#include <fst/fstlib.h>
#include <fst/fst-decl.h>
#include <fstext/deterministic-fst.h>
#include "lm/model.hh"
#include "util/murmur_hash.hh"

namespace kaldi {

// KenLm class wraps kenlm model(supporting both "trie" or "probing" models):
// 1. provides interface for loading binary LM, and holds it with ownership
// 2. provides interface for ngram score query at runtime
// 3. handles the index mapping between kaldi's symbols & kenlm's words
// KenLm object is heavy, stateless and thread-safe,
// can be shared by Fst wrapper class(i.e. KenLmDeterministicOnDemandFst)
class KenLm {
 public:
  typedef lm::WordIndex WordIndex;
  typedef lm::ngram::State State;

 public:
  // Creates an empty wrapper; no model is held until Load() is called.
  // NOTE(review): the three sentence-symbol literals were empty in the header
  // as received; the conventional "<s>", "</s>", "<unk>" are restored here.
  // Confirm they match what Load() expects.
  KenLm()
      : model_(nullptr), vocab_(nullptr),
        bos_sym_("<s>"), eos_sym_("</s>"), unk_sym_("<unk>"),
        bos_symid_(0), eos_symid_(0), unk_symid_(0) { }

  // Releases the owned model. vocab_ is only a view into model_,
  // so it is reset rather than freed.
  ~KenLm() {
    delete model_;  // deleting a null pointer is a no-op
    model_ = nullptr;
    vocab_ = nullptr;
    symid_to_wid_.clear();
  }

  // The destructor deletes model_, so an implicit copy would free the same
  // model twice. Share a KenLm through a pointer instead of copying it.
  KenLm(const KenLm &) = delete;
  KenLm &operator=(const KenLm &) = delete;

  // Loads a kenlm model file together with a Kaldi symbol table file.
  // (The definition, and the meaning of the returned int, live in the
  // implementation file.)
  //
  // If you have big LM on SSD hard-drive,
  // you can set load_method to util::LoadMethod::LAZY,
  // which enables "on-demand" model reading(via POSIX mmap) at runtime.
  // Refer to tools/kenlm/util/mmap.hh for more load methods.
  int Load(std::string kenlm_filename,
           std::string kaldi_symbol_table_filename,
           util::LoadMethod load_method = util::LoadMethod::POPULATE_OR_READ);

  // Looks up the kenlm word index of a word string in the model vocabulary.
  inline WordIndex GetWordIndex(const std::string &word) const {
    return vocab_->Index(word.c_str());
  }

  // Maps a Kaldi symbol index to the kenlm word index via the table
  // stored in symid_to_wid_.
  inline WordIndex GetWordIndex(int32 symbol_id) const {
    // An id outside the table would read past the end of the vector.
    KALDI_ASSERT(symbol_id >= 0 &&
                 static_cast<size_t>(symbol_id) < symid_to_wid_.size());
    return symid_to_wid_[symbol_id];
  }

  // Writes the model's begin-of-sentence context into *s.
  void SetStateToBeginOfSentence(State *s) const {
    model_->BeginSentenceWrite(s);
  }

  // Writes the model's null context into *s.
  void SetStateToNull(State *s) const {
    model_->NullContextWrite(s);
  }

  // Kaldi symbol indices of the special lm symbols.
  int32 BosSymbolIndex() const { return bos_symid_; }
  int32 EosSymbolIndex() const { return eos_symid_; }
  int32 UnkSymbolIndex() const { return unk_symid_; }

  // Scores `word` given the context in *in_state, writes the resulting
  // context to *out_state, and returns the model's score
  // (KenLmDeterministicOnDemandFst::GetArc treats it as a log10 value).
  inline BaseFloat Score(const State *in_state,
                         WordIndex word,
                         State *out_state) const {
    return model_->BaseScore(in_state, word, out_state);
  }

  // This provides a fast state hashing,
  // KenLmDeterministicOnDemandFst needs this for Fst states managing.
  struct StateHasher {
    // Hashes the first s.Length() entries of s.words.
    inline size_t operator()(const State &s) const noexcept {
      return util::MurmurHashNative(s.words,
                                    sizeof(WordIndex) * s.Length());
    }
  };

 private:
  void ComputeSymbolToWordIndexMapping(std::string symbol_table);

 private:
  lm::base::Model *model_;  // with ownership

  // without ownership, points to internal vocabulary of model_
  const lm::base::Vocabulary *vocab_;

  // There are two integerized indexing systems here:
  // 1. Kaldi's fst output *symbol index*(defined in words.txt),
  // 2. KenLm's *word index*(defined by word string hashing).
  // In order to rescore kaldi hypotheses with kenlm ngrams,
  // we need to know the index mapping from symbol to word.
  // KenLm class precomputes (during model loading) and stores this mapping,
  // and apply the mapping at runtime.
  // This is slower, but at least we don't need
  // to modify/convert runtime resources.(e.g. HCLG/lattices or kenlm models)
  //
  // In the mapping, <eps> and #0 symbols are special:
  // They do not correspond to any word in KenLm,
  // so the mapping of these two symbols are logically undefined,
  // we just map them to KenLm's <unk> to avoid random invalid mapping.

  // symid_to_wid_[kaldi_symbol_index] -> kenlm word index
  std::vector<WordIndex> symid_to_wid_;

  // special lm symbols
  std::string bos_sym_;
  std::string eos_sym_;
  std::string unk_sym_;

  int32 bos_symid_;
  int32 eos_symid_;
  int32 unk_symid_;
};  // class KenLm


// DeterministicOnDemandFst wraps a KenLm object as a deterministic Fst.
// Internally, it manages dynamically expanded Fst states(so not thread-safe),
// different threads should create their own instances of this class.
// They are lightweight and can share the same KenLm object.
template<class Arc>
class KenLmDeterministicOnDemandFst
    : public fst::DeterministicOnDemandFst<Arc> {
 public:
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename KenLm::State State;
  typedef typename KenLm::WordIndex WordIndex;

  // Binds to `lm` (not owned) and registers the begin-of-sentence context
  // as the Fst start state.
  explicit KenLmDeterministicOnDemandFst(const KenLm *lm)
      : lm_(lm), num_states_(0), bos_state_id_(0) {
    KALDI_ASSERT(lm != nullptr);
    // create bos to be FST start state
    MapElem e;
    lm->SetStateToBeginOfSentence(&e.first);
    e.second = bos_state_id_;
    std::pair<IterType, bool> r = state_map_.insert(e);
    // bos successfully inserted into state map
    KALDI_ASSERT(r.second == true);
    state_vec_.push_back(&r.first->first);
    num_states_++;

    eos_symbol_id_ = lm_->EosSymbolIndex();
  }

  // state_vec_ stores addresses of keys that live inside state_map_, so an
  // implicit copy would point into the source object's map.
  KenLmDeterministicOnDemandFst(const KenLmDeterministicOnDemandFst &) = delete;
  KenLmDeterministicOnDemandFst &operator=(
      const KenLmDeterministicOnDemandFst &) = delete;

  virtual ~KenLmDeterministicOnDemandFst() { }

  // Returns the id of the begin-of-sentence state.
  virtual StateId Start() {
    return bos_state_id_;
  }

  // Fills *oarc with the arc leaving state `s` on `label`, creating the
  // destination state on first visit. Always returns true.
  virtual bool GetArc(StateId s, Label label, Arc *oarc) {
    KALDI_ASSERT(s >= 0 && s < static_cast<StateId>(state_vec_.size()));
    const State *istate = state_vec_[s];
    MapElem e;
    WordIndex word = lm_->GetWordIndex(label);
    BaseFloat log_10_prob = lm_->Score(istate, word, &e.first);
    // Candidate id; only kept if the destination context is new.
    e.second = num_states_;
    std::pair<IterType, bool> r = state_map_.insert(e);
    if (r.second == true) {  // new state
      state_vec_.push_back(&(r.first->first));
      num_states_++;
    }

    oarc->ilabel = label;
    oarc->olabel = oarc->ilabel;
    oarc->nextstate = r.first->second;
    // KenLm log10() -> Kaldi ln()
    oarc->weight = Weight(-log_10_prob * M_LN10);
    return true;
  }

  // Final weight of `s` is the weight of its end-of-sentence arc.
  // Goes through GetArc(), so it may add a state to the state map.
  virtual Weight Final(StateId s) {
    Arc oarc;
    GetArc(s, eos_symbol_id_, &oarc);
    return oarc.weight;
  }

 private:
  typedef std::pair<State, StateId> MapElem;
  typedef std::unordered_map<State, StateId, KenLm::StateHasher> MapType;
  typedef typename MapType::iterator IterType;

  const KenLm *lm_;  // no ownership
  MapType state_map_;                   // lm context -> Fst state id
  std::vector<const State*> state_vec_; // Fst state id -> key in state_map_
  StateId num_states_;    // state vector index range, [0, num_states_)
  StateId bos_state_id_;  // fst start state id
  Label eos_symbol_id_;
};  // class KenLmDeterministicOnDemandFst

}  // namespace kaldi

#endif  // KALDI_LM_KENLM_H
#endif  // HAVE_KENLM