// lm/const-arpa-lm.cc
// Copyright 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <limits>
#include <sstream>
#include <utility>
#include "base/kaldi-math.h"
#include "lm/arpa-file-parser.h"
#include "lm/const-arpa-lm.h"
#include "util/stl-utils.h"
#include "util/text-utils.h"
namespace kaldi {
// Auxiliary struct for converting a ConstArpaLm format language model to the
// Arpa format.
struct ArpaLine {
std::vector<int32> words; // Sequence of words to be printed.
float logprob; // Log probability of the word sequence.
float backoff_logprob; // Backoff log probability of the word sequence.
// Comparison function for sorting.
bool operator < (const ArpaLine &other) const {
if (words.size() < other.words.size()) {
return true;
} else if (words.size() > other.words.size()) {
return false;
} else {
return words < other.words;
}
}
};
// Auxiliary class to build ConstArpaLm. We first use this class to figure out
// the relative address of different LmStates, and then put everything into one
// block in memory.
class LmState {
public:
union ChildType {
// If child is not the final order, we keep the pointer to its LmState.
LmState* state;
// If child is the final order, we only keep the log probability for it.
float prob;
};
struct ChildrenVectorLessThan {
bool operator()(
const std::pair<int32, union ChildType>& lhs,
const std::pair<int32, union ChildType>& rhs) const {
return lhs.first < rhs.first;
}
};
LmState(const bool is_unigram, const bool is_child_final_order,
const float logprob, const float backoff_logprob) :
is_unigram_(is_unigram), is_child_final_order_(is_child_final_order),
logprob_(logprob), backoff_logprob_(backoff_logprob) {}
void SetMyAddress(const int64 address) {
my_address_ = address;
}
void AddChild(const int32 word, LmState* child_state) {
KALDI_ASSERT(!is_child_final_order_);
ChildType child;
child.state = child_state;
children_.push_back(std::make_pair(word, child));
}
void AddChild(const int32 word, const float child_prob) {
KALDI_ASSERT(is_child_final_order_);
ChildType child;
child.prob = child_prob;
children_.push_back(std::make_pair(word, child));
}
int64 MyAddress() const {
return my_address_;
}
bool IsUnigram() const {
return is_unigram_;
}
bool IsChildFinalOrder() const {
return is_child_final_order_;
}
float Logprob() const {
return logprob_;
}
float BackoffLogprob() const {
return backoff_logprob_;
}
int32 NumChildren() const {
return children_.size();
}
std::pair<int32, union ChildType> GetChild(const int32 index) {
KALDI_ASSERT(index < children_.size());
KALDI_ASSERT(index >= 0);
return children_[index];
}
void SortChildren() {
std::sort(children_.begin(), children_.end(), ChildrenVectorLessThan());
}
// Checks if the current LmState is a leaf.
bool IsLeaf() const {
return (backoff_logprob_ == 0.0 && children_.empty());
}
// Computes the size of the memory that the current LmState would take in
// <lm_states> array. It's the number of 4-byte chunks.
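// For example, a non-leaf state with two children occupies
// 3 + 2 * 2 = 7 int32 slots:
// [logprob][backoff_logprob][num_children = 2][word0][info0][word1][info1].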
int32 MemSize() const {
if (IsLeaf() && !is_unigram_) {
// We don't create an entry in this case; the logprob will be stored in
// the same int32 that we would normally store the pointer in.
return 0;
} else {
// We store the following information:
// logprob, backoff_logprob, children.size() and children data.
return (3 + 2 * children_.size());
}
}
private:
// Unigram states will have LmStates even if they are leaves, so we need to
// record whether or not this state is a unigram.
bool is_unigram_;
// If the current LmState has an order of (final_order - 1), then its child
// must be the final order. We only keep the log probability for its child.
bool is_child_final_order_;
// Address of this LmState, computed as an offset into the <lm_states_>
// array. Note that this is just the offset, not an actual pointer.
int64 my_address_;
// Language model log probability of the current sequence. For example, if
// this state is "A B", then it would be the logprob of "A -> B".
float logprob_;
// Language model backoff log probability of the current sequence, e.g., state
// "A B -> X" backing off to "B -> X".
float backoff_logprob_;
// List of children.
std::vector<std::pair<int32, union ChildType> > children_;
};
// Class to build ConstArpaLm from Arpa format language model. It relies on the
// auxiliary class LmState above.
class ConstArpaLmBuilder : public ArpaFileParser {
public:
explicit ConstArpaLmBuilder(ArpaParseOptions options)
: ArpaFileParser(options, NULL) {
ngram_order_ = 0;
num_words_ = 0;
overflow_buffer_size_ = 0;
lm_states_size_ = 0;
max_address_offset_ = pow(2, 30) - 1;
is_built_ = false;
lm_states_ = NULL;
unigram_states_ = NULL;
overflow_buffer_ = NULL;
}
~ConstArpaLmBuilder() {
unordered_map<std::vector<int32>,
LmState*, VectorHasher<int32> >::iterator iter;
for (iter = seq_to_state_.begin(); iter != seq_to_state_.end(); ++iter) {
delete iter->second;
}
if (is_built_) {
delete[] lm_states_;
delete[] unigram_states_;
delete[] overflow_buffer_;
}
}
// Writes ConstArpaLm.
void Write(std::ostream &os, bool binary) const;
void SetMaxAddressOffset(const int32 max_address_offset) {
KALDI_WARN << "You are changing <max_address_offset_>; the default should "
<< "not be changed unless you are in testing mode.";
max_address_offset_ = max_address_offset;
}
protected:
// ArpaFileParser overrides.
virtual void HeaderAvailable();
virtual void ConsumeNGram(const NGram& ngram);
virtual void ReadComplete();
private:
struct WordsAndLmStatePairLessThan {
bool operator()(
const std::pair<std::vector<int32>*, LmState*>& lhs,
const std::pair<std::vector<int32>*, LmState*>& rhs) const {
return *(lhs.first) < *(rhs.first);
}
};
private:
// Indicates whether the ConstArpaLm has been built.
bool is_built_;
// Maximum relative address for a child; we keep it here just for testing.
// The default is 2^30 - 1 (30 bits) and should not be changed except for
// testing.
int32 max_address_offset_;
// N-gram order of the language model, read from the "\data\" section of the
// Arpa format language model.
int32 ngram_order_;
// The largest word-id plus one. It defines the size of the <unigram_states_>
// array.
int32 num_words_;
// Number of entries in the overflow buffer for pointers that couldn't be
// represented as a 30-bit relative index.
int32 overflow_buffer_size_;
// Size of the <lm_states_> array, which will be needed by I/O.
int64 lm_states_size_;
// Memory block for storing LmStates.
int32* lm_states_;
// Memory block for storing pointers of unigram LmStates.
int32** unigram_states_;
// Memory block for storing pointers of the LmStates whose relative address
// from their parents cannot be represented in 30 bits.
int32** overflow_buffer_;
// Hash table from word sequences to LmStates.
unordered_map<std::vector<int32>,
LmState*, VectorHasher<int32> > seq_to_state_;
};
void ConstArpaLmBuilder::HeaderAvailable() {
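// NgramCounts() holds one entry per "ngram N=count" line in the Arpa
// "\data\" section, so its size equals the model order. For example, a
// header containing
//   ngram 1=4000
//   ngram 2=35000
//   ngram 3=20000
// yields ngram_order_ = 3.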
ngram_order_ = NgramCounts().size();
}
void ConstArpaLmBuilder::ConsumeNGram(const NGram &ngram) {
int32 cur_order = ngram.words.size();
// If <ngram_order_> is larger than 1, then we do not create LmState for
// the final order entry. We only keep the log probability for it.
LmState *lm_state = NULL;
if (cur_order != ngram_order_ || ngram_order_ == 1) {
lm_state = new LmState(cur_order == 1,
cur_order == ngram_order_ - 1,
ngram.logprob, ngram.backoff);
if (seq_to_state_.find(ngram.words) != seq_to_state_.end()) {
std::ostringstream os;
os << "[ ";
for (size_t i = 0; i < ngram.words.size(); i++) {
os << ngram.words[i] << " ";
}
os <<"]";
KALDI_ERR << "N-gram " << os.str() << " appears twice in the arpa file";
}
seq_to_state_[ngram.words] = lm_state;
}
// If the n-gram order is larger than 1, we have to add the new entry as a
// child of an existing LmState. We rely on the following two assumptions:
// 1. N-grams are processed from low orders to high orders, i.e., from
// 1, 2, ... up to the highest order.
// 2. If an n-gram exists in the Arpa format language model, then its
// "history" n-gram also exists. For example, if "A B C" is a valid
// n-gram, then "A B" is also a valid n-gram.
int32 last_word = ngram.words[cur_order - 1];
if (cur_order > 1) {
std::vector<int32> hist(ngram.words.begin(), ngram.words.end() - 1);
unordered_map<std::vector<int32>,
LmState*, VectorHasher<int32> >::iterator hist_iter;
hist_iter = seq_to_state_.find(hist);
if (hist_iter == seq_to_state_.end()) {
std::ostringstream ss;
for (int i = 0; i < cur_order; ++i)
ss << (i == 0 ? '[' : ' ') << ngram.words[i];
KALDI_ERR << "In line " << LineNumber() << ": "
<< cur_order << "-gram " << ss.str() << "] does not have "
<< "a parent model " << cur_order << "-gram.";
}
if (cur_order != ngram_order_ || ngram_order_ == 1) {
KALDI_ASSERT(lm_state != NULL);
KALDI_ASSERT(!hist_iter->second->IsChildFinalOrder());
hist_iter->second->AddChild(last_word, lm_state);
} else {
KALDI_ASSERT(lm_state == NULL);
KALDI_ASSERT(hist_iter->second->IsChildFinalOrder());
hist_iter->second->AddChild(last_word, ngram.logprob);
}
} else {
// Keeps track of <num_words_>: the largest word-id plus one.
num_words_ = std::max(num_words_, last_word + 1);
}
}
// ConstArpaLm can be built in the following steps, assuming we have already
// created the LmStates in <seq_to_state_>:
// 1. Sort LmStates lexicographically.
// This enables us to compute relative addresses. When we say lexicographic,
// we treat the word-ids as letters. After sorting, the LmStates are in the
// following order:
// ...
// A B
// A B A
// A B B
// A B C
// ...
// where each line represents a LmState.
// 2. Update <my_address> in LmState, which is relative to the first element in
// <sorted_vec>.
// 3. Put the following structure into the memory block
// struct LmState {
// float logprob;
// float backoff_logprob;
// int32 num_children;
// std::pair<int32, int32> [] children;
// }
//
// At the same time, we will also create two special buffers:
// <unigram_states_>
// <overflow_buffer_>
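// For example (illustrative sizes): if the sorted LmStates have MemSizes
// 5, 7 and 5, their addresses are 0, 5 and 12 respectively (int32 offsets
// into <lm_states_>), and a child at address 12 whose parent is at address
// 5 gets the relative offset 12 - 5 = 7.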
void ConstArpaLmBuilder::ReadComplete() {
// STEP 1: sorting LmStates lexicographically.
// Vector for holding the sorted LmStates.
std::vector<std::pair<std::vector<int32>*, LmState*> > sorted_vec;
unordered_map<std::vector<int32>,
LmState*, VectorHasher<int32> >::iterator iter;
for (iter = seq_to_state_.begin(); iter != seq_to_state_.end(); ++iter) {
if (iter->second->MemSize() > 0) {
sorted_vec.push_back(
std::make_pair(const_cast<std::vector<int32>*>(&(iter->first)),
iter->second));
}
}
std::sort(sorted_vec.begin(), sorted_vec.end(),
WordsAndLmStatePairLessThan());
// STEP 2: updating <my_address> in LmState.
for (int32 i = 0; i < sorted_vec.size(); ++i) {
lm_states_size_ += sorted_vec[i].second->MemSize();
if (i == 0) {
sorted_vec[i].second->SetMyAddress(0);
} else {
sorted_vec[i].second->SetMyAddress(sorted_vec[i - 1].second->MyAddress()
+ sorted_vec[i - 1].second->MemSize());
}
}
// STEP 3: creating memory block to store LmStates.
// Reserves a memory block for LmStates.
int64 lm_states_index = 0;
try {
lm_states_ = new int32[lm_states_size_];
} catch(const std::exception &e) {
KALDI_ERR << e.what();
}
// Puts data into memory block.
unigram_states_ = new int32*[num_words_];
std::vector<int32*> overflow_buffer_vec;
for (int32 i = 0; i < num_words_; ++i) {
unigram_states_[i] = NULL;
}
for (int32 i = 0; i < sorted_vec.size(); ++i) {
// Current address.
int32* parent_address = lm_states_ + lm_states_index;
// Adds logprob.
Int32AndFloat logprob_f(sorted_vec[i].second->Logprob());
lm_states_[lm_states_index++] = logprob_f.i;
// Adds backoff_logprob.
Int32AndFloat backoff_logprob_f(sorted_vec[i].second->BackoffLogprob());
lm_states_[lm_states_index++] = backoff_logprob_f.i;
// Adds num_children.
lm_states_[lm_states_index++] = sorted_vec[i].second->NumChildren();
// Adds children; there are 3 cases:
// 1. Child is a leaf and not unigram
// 2. Child is not a leaf or is unigram
// 2.1 Relative address can be represented by 30 bits
// 2.2 Relative address cannot be represented by 30 bits
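// Illustrative encodings (values assumed, not from a real model):
//   - leaf child with logprob -1.5f: child_info holds the float bits
//     0xBFC00000 with the last bit cleared, so it is even;
//   - child at relative offset 5: child_info = 5 * 2 + 1 = 11 (odd,
//     positive);
//   - child in the overflow buffer at index 3: child_info =
//     -(3 * 2 + 1) = -7 (odd, negative).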
sorted_vec[i].second->SortChildren();
for (int32 j = 0; j < sorted_vec[i].second->NumChildren(); ++j) {
int32 child_info;
if (sorted_vec[i].second->IsChildFinalOrder() ||
sorted_vec[i].second->GetChild(j).second.state->MemSize() == 0) {
// Child is a leaf and not a unigram. In this case we will not create an
// entry in <lm_states_>; instead, we put the logprob in the place where
// we normally store the pointer.
Int32AndFloat child_logprob_f;
if (sorted_vec[i].second->IsChildFinalOrder()) {
child_logprob_f.f = sorted_vec[i].second->GetChild(j).second.prob;
} else {
child_logprob_f.f =
sorted_vec[i].second->GetChild(j).second.state->Logprob();
}
child_info = child_logprob_f.i;
child_info &= ~1; // Sets the last bit to 0 so <child_info> is even.
} else {
// Child is not a leaf or is unigram.
int64 offset =
sorted_vec[i].second->GetChild(j).second.state->MyAddress()
- sorted_vec[i].second->MyAddress();
KALDI_ASSERT(offset > 0);
if (offset <= max_address_offset_) {
// Relative address can be represented by 30 bits.
child_info = offset * 2;
child_info |= 1;
} else {
// Relative address cannot be represented by 30 bits, we have to put
// the child address into <overflow_buffer_>.
int32* abs_address = parent_address + offset;
overflow_buffer_vec.push_back(abs_address);
int32 overflow_buffer_index = overflow_buffer_vec.size() - 1;
child_info = overflow_buffer_index * 2;
child_info |= 1;
child_info *= -1;
}
}
// Child word.
lm_states_[lm_states_index++] = sorted_vec[i].second->GetChild(j).first;
// Child info.
lm_states_[lm_states_index++] = child_info;
}
// If the current state corresponds to a unigram, then create a separate
// lookup table to improve efficiency, since unigrams will be looked up
// pretty frequently.
if (sorted_vec[i].second->IsUnigram()) {
KALDI_ASSERT(sorted_vec[i].first->size() == 1);
unigram_states_[(*sorted_vec[i].first)[0]] = parent_address;
}
}
KALDI_ASSERT(lm_states_size_ == lm_states_index);
// Moves <overflow_buffer_> from the vector holder into an array.
overflow_buffer_size_ = overflow_buffer_vec.size();
overflow_buffer_ = new int32*[overflow_buffer_size_];
for (int32 i = 0; i < overflow_buffer_size_; ++i) {
overflow_buffer_[i] = overflow_buffer_vec[i];
}
is_built_ = true;
}
void ConstArpaLmBuilder::Write(std::ostream &os, bool binary) const {
if (!binary) {
KALDI_ERR << "text-mode writing is not implemented for ConstArpaLmBuilder.";
}
KALDI_ASSERT(is_built_);
// Creates ConstArpaLm.
ConstArpaLm const_arpa_lm(
Options().bos_symbol, Options().eos_symbol, Options().unk_symbol,
ngram_order_, num_words_, overflow_buffer_size_, lm_states_size_,
unigram_states_, overflow_buffer_, lm_states_);
const_arpa_lm.Write(os, binary);
}
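// For reference, the on-disk layout produced by ConstArpaLm::Write() below
// is:
//   <ConstArpaLm>
//     <LmInfo> bos eos unk ngram_order </LmInfo>
//     <LmStates> lm_states_size int32[lm_states_size] </LmStates>
//     <LmUnigram> num_words int64[num_words] </LmUnigram>
//     <LmOverflow> overflow_size int64[overflow_size] </LmOverflow>
//   </ConstArpaLm>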
void ConstArpaLm::Write(std::ostream &os, bool binary) const {
KALDI_ASSERT(initialized_);
if (!binary) {
KALDI_ERR << "text-mode writing is not implemented for ConstArpaLm.";
}
WriteToken(os, binary, "<ConstArpaLm>");
// Misc info.
WriteToken(os, binary, "<LmInfo>");
WriteBasicType(os, binary, bos_symbol_);
WriteBasicType(os, binary, eos_symbol_);
WriteBasicType(os, binary, unk_symbol_);
WriteBasicType(os, binary, ngram_order_);
WriteToken(os, binary, "</LmInfo>");
// LmStates section.
WriteToken(os, binary, "<LmStates>");
WriteBasicType(os, binary, lm_states_size_);
os.write(reinterpret_cast<char *>(lm_states_),
sizeof(int32) * lm_states_size_);
if (!os.good()) {
KALDI_ERR << "ConstArpaLm <LmStates> section writing failed.";
}
WriteToken(os, binary, "</LmStates>");
// Unigram section. We write memory offsets to disk instead of the absolute
// pointers.
WriteToken(os, binary, "<LmUnigram>");
WriteBasicType(os, binary, num_words_);
int64* tmp_unigram_address = new int64[num_words_];
for (int32 i = 0; i < num_words_; ++i) {
// The relative address here is a little bit tricky:
// 1. If the original address is NULL, then we set the relative address to
// zero.
// 2. If the original address is not NULL, we set it to the following:
// unigram_states_[i] - lm_states_ + 1
// We add 1 to ensure that the value is positive.
tmp_unigram_address[i] = (unigram_states_[i] == NULL) ? 0 :
unigram_states_[i] - lm_states_ + 1;
}
os.write(reinterpret_cast<char *>(tmp_unigram_address),
sizeof(int64) * num_words_);
if (!os.good()) {
KALDI_ERR << "ConstArpaLm <LmUnigram> section writing failed.";
}
delete[] tmp_unigram_address; // Releases the memory.
tmp_unigram_address = NULL;
WriteToken(os, binary, "</LmUnigram>");
// Overflow section. We write memory offsets to disk instead of the absolute
// pointers.
WriteToken(os, binary, "<LmOverflow>");
WriteBasicType(os, binary, overflow_buffer_size_);
int64* tmp_overflow_address = new int64[overflow_buffer_size_];
for (int32 i = 0; i < overflow_buffer_size_; ++i) {
// The relative address here is a little bit tricky:
// 1. If the original address is NULL, then we set the relative address to
// zero.
// 2. If the original address is not NULL, we set it to the following:
// overflow_buffer_[i] - lm_states_ + 1
// We add 1 to ensure that the value is positive.
tmp_overflow_address[i] = (overflow_buffer_[i] == NULL) ? 0 :
overflow_buffer_[i] - lm_states_ + 1;
}
os.write(reinterpret_cast<char *>(tmp_overflow_address),
sizeof(int64) * overflow_buffer_size_);
if (!os.good()) {
KALDI_ERR << "ConstArpaLm <LmOverflow> section writing failed.";
}
delete[] tmp_overflow_address;
tmp_overflow_address = NULL;
WriteToken(os, binary, "</LmOverflow>");
WriteToken(os, binary, "</ConstArpaLm>");
}
void ConstArpaLm::Read(std::istream &is, bool binary) {
KALDI_ASSERT(!initialized_);
if (!binary) {
KALDI_ERR << "text-mode reading is not implemented for ConstArpaLm.";
}
int first_char = is.peek();
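// In Kaldi's binary format, WriteBasicType(int32) first writes a single byte
// holding sizeof(int32) == 4, so a leading byte of 4 identifies the old
// headerless format; the new format begins with the token "<ConstArpaLm>",
// whose first byte is '<'.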
if (first_char == 4) { // Old on-disk format starts with length of int32.
ReadInternalOldFormat(is, binary);
} else { // New on-disk format starts with token <ConstArpaLm>.
ReadInternal(is, binary);
}
}
void ConstArpaLm::ReadInternal(std::istream &is, bool binary) {
KALDI_ASSERT(!initialized_);
if (!binary) {
KALDI_ERR << "text-mode reading is not implemented for ConstArpaLm.";
}
ExpectToken(is, binary, "<ConstArpaLm>");
// Misc info.
ExpectToken(is, binary, "<LmInfo>");
ReadBasicType(is, binary, &bos_symbol_);
ReadBasicType(is, binary, &eos_symbol_);
ReadBasicType(is, binary, &unk_symbol_);
ReadBasicType(is, binary, &ngram_order_);
ExpectToken(is, binary, "</LmInfo>");
// LmStates section.
ExpectToken(is, binary, "<LmStates>");
ReadBasicType(is, binary, &lm_states_size_);
lm_states_ = new int32[lm_states_size_];
is.read(reinterpret_cast<char *>(lm_states_),
sizeof(int32) * lm_states_size_);
if (!is.good()) {
KALDI_ERR << "ConstArpaLm <LmStates> section reading failed.";
}
ExpectToken(is, binary, "</LmStates>");
// Unigram section. Memory offsets were written to disk instead of the
// absolute pointers.
ExpectToken(is, binary, "<LmUnigram>");
ReadBasicType(is, binary, &num_words_);
unigram_states_ = new int32*[num_words_];
int64* tmp_unigram_address = new int64[num_words_];
is.read(reinterpret_cast<char *>(tmp_unigram_address),
sizeof(int64) * num_words_);
if (!is.good()) {
KALDI_ERR << "ConstArpaLm <LmUnigram> section reading failed.";
}
for (int32 i = 0; i < num_words_; ++i) {
// Check out how we compute the relative address in ConstArpaLm::Write().
unigram_states_[i] = (tmp_unigram_address[i] == 0) ? NULL
: lm_states_ + tmp_unigram_address[i] - 1;
}
delete[] tmp_unigram_address;
tmp_unigram_address = NULL;
ExpectToken(is, binary, "</LmUnigram>");
// Overflow section. Memory offsets were written to disk instead of the
// absolute pointers.
ExpectToken(is, binary, "<LmOverflow>");
ReadBasicType(is, binary, &overflow_buffer_size_);
overflow_buffer_ = new int32*[overflow_buffer_size_];
int64* tmp_overflow_address = new int64[overflow_buffer_size_];
is.read(reinterpret_cast<char *>(tmp_overflow_address),
sizeof(int64) * overflow_buffer_size_);
if (!is.good()) {
KALDI_ERR << "ConstArpaLm <LmOverflow> section reading failed.";
}
for (int32 i = 0; i < overflow_buffer_size_; ++i) {
// Check out how we compute the relative address in ConstArpaLm::Write().
overflow_buffer_[i] = (tmp_overflow_address[i] == 0) ? NULL
: lm_states_ + tmp_overflow_address[i] - 1;
}
delete[] tmp_overflow_address;
tmp_overflow_address = NULL;
ExpectToken(is, binary, "</LmOverflow>");
ExpectToken(is, binary, "</ConstArpaLm>");
KALDI_ASSERT(ngram_order_ > 0);
KALDI_ASSERT(bos_symbol_ < num_words_ && bos_symbol_ > 0);
KALDI_ASSERT(eos_symbol_ < num_words_ && eos_symbol_ > 0);
KALDI_ASSERT(unk_symbol_ < num_words_ &&
(unk_symbol_ > 0 || unk_symbol_ == -1));
lm_states_end_ = lm_states_ + lm_states_size_ - 1;
memory_assigned_ = true;
initialized_ = true;
}
void ConstArpaLm::ReadInternalOldFormat(std::istream &is, bool binary) {
KALDI_ASSERT(!initialized_);
if (!binary) {
KALDI_ERR << "text-mode reading is not implemented for ConstArpaLm.";
}
// Misc info.
ReadBasicType(is, binary, &bos_symbol_);
ReadBasicType(is, binary, &eos_symbol_);
ReadBasicType(is, binary, &unk_symbol_);
ReadBasicType(is, binary, &ngram_order_);
// LmStates section.
// In the deprecated version, <lm_states_size_> used to be of type int32,
// which was a bug. We therefore read it as int32 for back-compatibility.
int32 lm_states_size_int32;
ReadBasicType(is, binary, &lm_states_size_int32);
lm_states_size_ = static_cast<int64>(lm_states_size_int32);
lm_states_ = new int32[lm_states_size_];
for (int64 i = 0; i < lm_states_size_; ++i) {
ReadBasicType(is, binary, &lm_states_[i]);
}
// Unigram section. Memory offsets were written to disk instead of the
// absolute pointers.
ReadBasicType(is, binary, &num_words_);
unigram_states_ = new int32*[num_words_];
for (int32 i = 0; i < num_words_; ++i) {
int64 tmp_address;
ReadBasicType(is, binary, &tmp_address);
// Check out how we compute the relative address in ConstArpaLm::Write().
unigram_states_[i] =
(tmp_address == 0) ? NULL : lm_states_ + tmp_address - 1;
}
// Overflow section. Memory offsets were written to disk instead of the
// absolute pointers.
ReadBasicType(is, binary, &overflow_buffer_size_);
overflow_buffer_ = new int32*[overflow_buffer_size_];
for (int32 i = 0; i < overflow_buffer_size_; ++i) {
int64 tmp_address;
ReadBasicType(is, binary, &tmp_address);
// Check out how we compute the relative address in ConstArpaLm::Write().
overflow_buffer_[i] =
(tmp_address == 0) ? NULL : lm_states_ + tmp_address - 1;
}
KALDI_ASSERT(ngram_order_ > 0);
KALDI_ASSERT(bos_symbol_ < num_words_ && bos_symbol_ > 0);
KALDI_ASSERT(eos_symbol_ < num_words_ && eos_symbol_ > 0);
KALDI_ASSERT(unk_symbol_ < num_words_ &&
(unk_symbol_ > 0 || unk_symbol_ == -1));
lm_states_end_ = lm_states_ + lm_states_size_ - 1;
memory_assigned_ = true;
initialized_ = true;
}
bool ConstArpaLm::HistoryStateExists(const std::vector<int32>& hist) const {
// We do not create an LmState for the empty word sequence, but technically
// it is the history state of all unigrams.
if (hist.size() == 0) {
return true;
}
// Tries to locate the LmState of the given word sequence.
int32* lm_state = GetLmState(hist);
if (lm_state == NULL) {
// If <lm_state> does not exist, then <hist> has no children.
return false;
} else {
// Note that we always create an LmState for unigrams, so even if <lm_state>
// is not NULL, we still have to check whether it has children.
KALDI_ASSERT(lm_state >= lm_states_);
KALDI_ASSERT(lm_state + 2 <= lm_states_end_);
// <lm_state + 2> points to <num_children>.
if (*(lm_state + 2) > 0) {
return true;
} else {
return false;
}
}
}
float ConstArpaLm::GetNgramLogprob(const int32 word,
const std::vector<int32>& hist) const {
KALDI_ASSERT(initialized_);
// If the history size plus one is larger than <ngram_order_>, remove the
// oldest words.
std::vector<int32> mapped_hist(hist);
while (mapped_hist.size() >= ngram_order_) {
mapped_hist.erase(mapped_hist.begin(), mapped_hist.begin() + 1);
}
KALDI_ASSERT(mapped_hist.size() + 1 <= ngram_order_);
// TODO(guoguo): check with Dan if this is reasonable.
// Maps possible out-of-vocabulary words to <unk>. If a word does not have a
// corresponding LmState, we treat it as <unk>, provided <unk> has been
// specified.
int32 mapped_word = word;
if (unk_symbol_ != -1) {
KALDI_ASSERT(mapped_word >= 0);
if (mapped_word >= num_words_ || unigram_states_[mapped_word] == NULL) {
mapped_word = unk_symbol_;
}
for (int32 i = 0; i < mapped_hist.size(); ++i) {
KALDI_ASSERT(mapped_hist[i] >= 0);
if (mapped_hist[i] >= num_words_ ||
unigram_states_[mapped_hist[i]] == NULL) {
mapped_hist[i] = unk_symbol_;
}
}
}
// Looks up the n-gram probability.
return GetNgramLogprobRecurse(mapped_word, mapped_hist);
}
float ConstArpaLm::GetNgramLogprobRecurse(
const int32 word, const std::vector<int32>& hist) const {
KALDI_ASSERT(initialized_);
KALDI_ASSERT(hist.size() + 1 <= ngram_order_);
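// Worked example for a trigram model: to get P(C | A B) when the trigram
// "A B C" is absent but the state "A B" exists with backoff weight b(A B),
// we return b(A B) + GetNgramLogprobRecurse(C, {B}), backing off again to
// the unigram logprob of C if "B C" is also absent.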
// Unigram case.
if (hist.size() == 0) {
if (word >= num_words_ || unigram_states_[word] == NULL) {
// If <unk> is defined, then the word should already have been mapped to
// <unk> if necessary, so this branch handles the case where <unk> is not
// defined.
return -std::numeric_limits<float>::infinity();
} else {
Int32AndFloat logprob_i(*unigram_states_[word]);
return logprob_i.f;
}
}
// High n-gram orders.
float logprob = 0.0;
float backoff_logprob = 0.0;
int32* state;
if ((state = GetLmState(hist)) != NULL) {
int32 child_info;
int32* child_lm_state = NULL;
if (GetChildInfo(word, state, &child_info)) {
DecodeChildInfo(child_info, state, &child_lm_state, &logprob);
return logprob;
} else {
Int32AndFloat backoff_logprob_i(*(state + 1));
backoff_logprob = backoff_logprob_i.f;
}
}
std::vector<int32> new_hist(hist);
new_hist.erase(new_hist.begin(), new_hist.begin() + 1);
return backoff_logprob + GetNgramLogprobRecurse(word, new_hist);
}
int32* ConstArpaLm::GetLmState(const std::vector<int32>& seq) const {
KALDI_ASSERT(initialized_);
// No LmState exists for the empty word sequence.
if (seq.size() == 0) return NULL;
// If <unk> is defined, then the word sequence should already have been
// mapped to <unk> if necessary, so this check handles the case where <unk>
// is not defined.
if (seq[0] >= num_words_ || unigram_states_[seq[0]] == NULL) return NULL;
int32* parent = unigram_states_[seq[0]];
int32 child_info;
int32* child_lm_state = NULL;
float logprob;
for (int32 i = 1; i < seq.size(); ++i) {
if (!GetChildInfo(seq[i], parent, &child_info)) {
return NULL;
}
DecodeChildInfo(child_info, parent, &child_lm_state, &logprob);
if (child_lm_state == NULL) {
return NULL;
} else {
parent = child_lm_state;
}
}
return parent;
}
bool ConstArpaLm::GetChildInfo(const int32 word,
int32* parent, int32* child_info) const {
KALDI_ASSERT(initialized_);
KALDI_ASSERT(parent != NULL);
KALDI_ASSERT(parent >= lm_states_);
KALDI_ASSERT(child_info != NULL);
KALDI_ASSERT(parent + 2 <= lm_states_end_);
int32 num_children = *(parent + 2);
KALDI_ASSERT(parent + 2 + 2 * num_children <= lm_states_end_);
if (num_children == 0) return false;
// A binary search into the children memory block.
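// The (word, child_info) pairs start at parent + 3. With the 1-based
// <mid_index> below, child k's word sits at parent + 1 + 2 * k and its
// child_info at parent + 2 + 2 * k; e.g., the first child's word is at
// parent + 3 and its info at parent + 4.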
int32 start_index = 1;
int32 end_index = num_children;
while (start_index <= end_index) {
int32 mid_index = (start_index + end_index) / 2;
int32 mid_word = *(parent + 1 + 2 * mid_index);
if (mid_word == word) {
*child_info = *(parent + 2 + 2 * mid_index);
return true;
} else if (mid_word < word) {
start_index = mid_index + 1;
} else {
end_index = mid_index - 1;
}
}
return false;
}
void ConstArpaLm::DecodeChildInfo(const int32 child_info,
int32* parent,
int32** child_lm_state,
float* logprob) const {
KALDI_ASSERT(initialized_);
KALDI_ASSERT(logprob != NULL);
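// Illustrative decodings, mirroring the encoding in
// ConstArpaLmBuilder::ReadComplete():
//   - child_info = 11 (odd, positive): offset = 11 / 2 = 5, so the child
//     lives at parent + 5;
//   - child_info = -7 (odd, negative): -7 / 2 = -3 (truncation toward
//     zero), so the child lives at overflow_buffer_[3];
//   - child_info even: the child is a leaf and the bits are its float
//     logprob (the cleared last bit costs at most one ulp of precision).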
if (child_info % 2 == 0) {
// Child is a leaf; only the log probability is stored.
*child_lm_state = NULL;
Int32AndFloat logprob_i(child_info);
*logprob = logprob_i.f;
} else {
int32 child_offset = child_info / 2;
if (child_offset > 0) {
*child_lm_state = parent + child_offset;
Int32AndFloat logprob_i(**child_lm_state);
*logprob = logprob_i.f;
} else {
KALDI_ASSERT(-child_offset < overflow_buffer_size_);
*child_lm_state = overflow_buffer_[-child_offset];
Int32AndFloat logprob_i(**child_lm_state);
*logprob = logprob_i.f;
}
KALDI_ASSERT(*child_lm_state >= lm_states_);
KALDI_ASSERT(*child_lm_state <= lm_states_end_);
}
}
void ConstArpaLm::WriteArpaRecurse(int32* lm_state,
const std::vector<int32>& seq,
std::vector<ArpaLine> *output) const {
if (lm_state == NULL) return;
KALDI_ASSERT(lm_state >= lm_states_);
KALDI_ASSERT(lm_state + 2 <= lm_states_end_);
// Inserts the current LmState to <output>.
ArpaLine arpa_line;
arpa_line.words = seq;
Int32AndFloat logprob_i(*lm_state);
arpa_line.logprob = logprob_i.f;
Int32AndFloat backoff_logprob_i(*(lm_state + 1));
arpa_line.backoff_logprob = backoff_logprob_i.f;
output->push_back(arpa_line);
// Scans for possible children, and recursively adds each child to <output>.
int32 num_children = *(lm_state + 2);
KALDI_ASSERT(lm_state + 2 + 2 * num_children <= lm_states_end_);
for (int32 i = 0; i < num_children; ++i) {
std::vector<int32> new_seq(seq);
new_seq.push_back(*(lm_state + 3 + 2 * i));
int32 child_info = *(lm_state + 4 + 2 * i);
float logprob;
int32* child_lm_state = NULL;
DecodeChildInfo(child_info, lm_state, &child_lm_state, &logprob);
if (child_lm_state == NULL) {
// Leaf case.
ArpaLine child_arpa_line;
child_arpa_line.words = new_seq;
child_arpa_line.logprob = logprob;
child_arpa_line.backoff_logprob = 0.0;
output->push_back(child_arpa_line);
} else {
WriteArpaRecurse(child_lm_state, new_seq, output);
}
}
}
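// For reference, WriteArpa() below emits the standard Arpa layout, with
// integer word-ids in place of words, e.g. (illustrative values):
//
//   \data\
//   ngram 1=2
//   ngram 2=1
//
//   \1-grams:
//   -0.30  3  -0.18
//   -0.70  4
//
//   \2-grams:
//   -0.52  3 4
//
//   \end\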
void ConstArpaLm::WriteArpa(std::ostream &os) const {
KALDI_ASSERT(initialized_);
std::vector<ArpaLine> tmp_output;
for (int32 i = 0; i < num_words_; ++i) {
if (unigram_states_[i] != NULL) {
std::vector<int32> seq(1, i);
WriteArpaRecurse(unigram_states_[i], seq, &tmp_output);
}
}
// Sorts the ArpaLines and collects the header counts.
std::sort(tmp_output.begin(), tmp_output.end());
std::vector<int32> ngram_count(1, 0);
for (int32 i = 0; i < tmp_output.size(); ++i) {
if (tmp_output[i].words.size() >= ngram_count.size()) {
ngram_count.resize(tmp_output[i].words.size() + 1);
ngram_count[tmp_output[i].words.size()] = 1;
} else {
ngram_count[tmp_output[i].words.size()] += 1;
}
}
// Writes the header.
os << std::endl;
os << "\\data\\" << std::endl;
for (int32 i = 1; i < ngram_count.size(); ++i) {
os << "ngram " << i << "=" << ngram_count[i] << std::endl;
}
// Writes n-grams.
int32 current_order = 0;
for (int32 i = 0; i < tmp_output.size(); ++i) {
// Beginning of an n-gram section.
if (tmp_output[i].words.size() != current_order) {
current_order = tmp_output[i].words.size();
os << std::endl;
os << "\\" << current_order << "-grams:" << std::endl;
}
// Writes logprob.
os << tmp_output[i].logprob << '\t';
// Writes word sequence.
for (int32 j = 0; j < tmp_output[i].words.size(); ++j) {
os << tmp_output[i].words[j];
if (j != tmp_output[i].words.size() - 1) {
os << " ";
}
}
// Writes backoff_logprob if it is not zero.
if (tmp_output[i].backoff_logprob != 0.0) {
os << '\t' << tmp_output[i].backoff_logprob;
}
os << std::endl;
}
os << std::endl << "\\end\\" << std::endl;
}
ConstArpaLmDeterministicFst::ConstArpaLmDeterministicFst(
const ConstArpaLm& lm) : lm_(lm) {
// Creates a history state for <s>.
std::vector<Label> bos_state(1, lm_.BosSymbol());
state_to_wseq_.push_back(bos_state);
wseq_to_state_[bos_state] = 0;
start_state_ = 0;
}
fst::StdArc::Weight ConstArpaLmDeterministicFst::Final(StateId s) {
// At this point, we should have created the state.
KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
const std::vector<Label>& wseq = state_to_wseq_[s];
float logprob = lm_.GetNgramLogprob(lm_.EosSymbol(), wseq);
return Weight(-logprob);
}
bool ConstArpaLmDeterministicFst::GetArc(StateId s,
Label ilabel, fst::StdArc *oarc) {
// At this point, we should have created the state.
KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
std::vector<Label> wseq = state_to_wseq_[s];
float logprob = lm_.GetNgramLogprob(ilabel, wseq);
if (logprob == -std::numeric_limits<float>::infinity()) {
return false;
}
// Locates the next state in ConstArpaLm. Note that OOV and backoff have been
// taken care of in ConstArpaLm.
wseq.push_back(ilabel);
while (wseq.size() >= lm_.NgramOrder()) {
// A history state has at most lm_.NgramOrder() - 1 words.
wseq.erase(wseq.begin(), wseq.begin() + 1);
}
while (!lm_.HistoryStateExists(wseq)) {
KALDI_ASSERT(wseq.size() > 0);
wseq.erase(wseq.begin(), wseq.begin() + 1);
}
std::pair<const std::vector<Label>, StateId> wseq_state_pair(
wseq, static_cast<Label>(state_to_wseq_.size()));
// Attempts to insert the current <wseq_state_pair>. If the pair already
// exists, the insertion fails and <result.second> is false.
typedef MapType::iterator IterType;
std::pair<IterType, bool> result = wseq_to_state_.insert(wseq_state_pair);
// If the pair was just inserted, then also add it to <state_to_wseq_>.
if (result.second == true)
state_to_wseq_.push_back(wseq);
// Creates the arc.
oarc->ilabel = ilabel;
oarc->olabel = ilabel;
oarc->nextstate = result.first->second;
oarc->weight = Weight(-logprob);
return true;
}
bool BuildConstArpaLm(const ArpaParseOptions& options,
const std::string& arpa_rxfilename,
const std::string& const_arpa_wxfilename) {
ConstArpaLmBuilder lm_builder(options);
KALDI_LOG << "Reading " << arpa_rxfilename;
Input ki(arpa_rxfilename);
lm_builder.Read(ki.Stream());
WriteKaldiObject(lm_builder, const_arpa_wxfilename, true);
return true;
}
} // namespace kaldi
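// Usage sketch (illustrative; file names, symbol ids and <word_id> are
// assumptions, and only functions defined above plus Kaldi's generic
// ReadKaldiObject() are used):
//
//   // Build the binary ConstArpaLm file from an Arpa model:
//   kaldi::ArpaParseOptions options;
//   options.bos_symbol = 1;  // assumed symbol-table ids for <s>, </s>
//   options.eos_symbol = 2;
//   kaldi::BuildConstArpaLm(options, "lm.arpa", "lm.const_arpa");
//
//   // Load it back and query an n-gram logprob:
//   kaldi::ConstArpaLm lm;
//   kaldi::ReadKaldiObject("lm.const_arpa", &lm);
//   std::vector<kaldi::int32> hist(1, options.bos_symbol);
//   float logprob = lm.GetNgramLogprob(word_id, hist);  // word_id assumed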