diff --git a/src/model.cc b/src/model.cc index 035ffee6..d99d95c9 100644 --- a/src/model.cc +++ b/src/model.cc @@ -12,18 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. - // -// For details of possible model layout see doc/models.md section model-structure +// For details of possible model layout see doc/models.md section +// model-structure #include "model.h" -#include +#include #include -#include #include -#include - +#include +#include #ifdef HAVE_MKL // We need to set num threads @@ -32,18 +31,19 @@ namespace fst { -static FstRegisterer OLabelLookAheadFst_StdArc_registerer; +static FstRegisterer + OLabelLookAheadFst_StdArc_registerer; static FstRegisterer> NGramFst_StdArc_registerer; -} // namespace fst +} // namespace fst #ifdef __ANDROID__ #include -static void KaldiLogHandler(const LogMessageEnvelope &env, const char *message) -{ +static void KaldiLogHandler(const LogMessageEnvelope &env, + const char *message) { int priority; if (env.severity > GetVerboseLevel()) - return; + return; if (env.severity > LogMessageEnvelope::kInfo) { priority = ANDROID_LOG_VERBOSE; @@ -66,16 +66,16 @@ static void KaldiLogHandler(const LogMessageEnvelope &env, const char *message) } std::stringstream full_message; - full_message << env.func << "():" << env.file << ':' - << env.line << ") " << message; + full_message << env.func << "():" << env.file << ':' << env.line << ") " + << message; __android_log_print(priority, "VoskAPI", "%s", full_message.str().c_str()); } #else -static void KaldiLogHandler(const LogMessageEnvelope &env, const char *message) -{ +static void KaldiLogHandler(const LogMessageEnvelope &env, + const char *message) { if (env.severity > GetVerboseLevel()) - return; + return; // Modified default Kaldi logging so we can disable LOG messages. std::stringstream full_message; @@ -99,8 +99,7 @@ static void KaldiLogHandler(const LogMessageEnvelope &env, const char *message) } } // Add other info from the envelope and the message text. - full_message << "VoskAPI" << ':' - << env.func << "():" << env.file << ':' + full_message << "VoskAPI" << ':' << env.func << "():" << env.file << ':' << env.line << ") " << message; // Print the complete message to stderr. @@ -111,283 +110,314 @@ static void KaldiLogHandler(const LogMessageEnvelope &env, const char *message) Model::Model(const char *model_path) : model_path_str_(model_path) { - SetLogHandler(KaldiLogHandler); + SetLogHandler(KaldiLogHandler); #ifdef HAVE_MKL - mkl_set_num_threads(1); + mkl_set_num_threads(1); #endif - struct stat buffer; - string am_v2_path = model_path_str_ + "/am/final.mdl"; - string model_conf_v2_path = model_path_str_ + "/conf/model.conf"; - string am_v1_path = model_path_str_ + "/final.mdl"; - string mfcc_v1_path = model_path_str_ + "/mfcc.conf"; - if (stat(am_v2_path.c_str(), &buffer) == 0 && stat(model_conf_v2_path.c_str(), &buffer) == 0) { - ConfigureV2(); - ReadDataFiles(); - } else if (stat(am_v1_path.c_str(), &buffer) == 0 && stat(mfcc_v1_path.c_str(), &buffer) == 0) { - ConfigureV1(); - ReadDataFiles(); - } else { - KALDI_ERR << "Folder '" << model_path_str_ << "' does not contain model files. " << - "Make sure you specified the model path properly in Model constructor. " << - "If you are not sure about relative path, use absolute path specification."; - } + struct stat buffer; + string am_v2_path = model_path_str_ + "/am/final.mdl"; + string model_conf_v2_path = model_path_str_ + "/conf/model.conf"; + string am_v1_path = model_path_str_ + "/final.mdl"; + string mfcc_v1_path = model_path_str_ + "/mfcc.conf"; + if (stat(am_v2_path.c_str(), &buffer) == 0 && + stat(model_conf_v2_path.c_str(), &buffer) == 0) { + ConfigureV2(); + ReadDataFiles(); + } else if (stat(am_v1_path.c_str(), &buffer) == 0 && + stat(mfcc_v1_path.c_str(), &buffer) == 0) { + ConfigureV1(); + ReadDataFiles(); + } else { + KALDI_ERR << "Folder '" << model_path_str_ + << "' does not contain model files. " + << "Make sure you specified the model path properly in Model " + "constructor. " + << "If you are not sure about relative path, use absolute path " + "specification."; + } - ref_cnt_ = 1; + ref_cnt_ = 1; } // Old model layout without model configuration file -void Model::ConfigureV1() -{ - const char *extra_args[] = { - "--max-active=7000", - "--beam=13.0", - "--lattice-beam=6.0", - "--acoustic-scale=1.0", - - "--frame-subsampling-factor=3", - - "--endpoint.silence-phones=1:2:3:4:5:6:7:8:9:10", - "--endpoint.rule2.min-trailing-silence=0.5", - "--endpoint.rule3.min-trailing-silence=1.0", - "--endpoint.rule4.min-trailing-silence=2.0", - - "--print-args=false", - }; - - kaldi::ParseOptions po(""); - nnet3_decoding_config_.Register(&po); - endpoint_config_.Register(&po); - decodable_opts_.Register(&po); - - vector args; - args.push_back("vosk"); - args.insert(args.end(), extra_args, extra_args + sizeof(extra_args) / sizeof(extra_args[0])); - po.Read(args.size(), args.data()); - - nnet3_rxfilename_ = model_path_str_ + "/final.mdl"; - hclg_fst_rxfilename_ = model_path_str_ + "/HCLG.fst"; - hcl_fst_rxfilename_ = model_path_str_ + "/HCLr.fst"; - g_fst_rxfilename_ = model_path_str_ + "/Gr.fst"; - disambig_rxfilename_ = model_path_str_ + "/disambig_tid.int"; - word_syms_rxfilename_ = model_path_str_ + "/words.txt"; - winfo_rxfilename_ = model_path_str_ + "/word_boundary.int"; - carpa_rxfilename_ = model_path_str_ + "/rescore/G.carpa"; - std_fst_rxfilename_ = model_path_str_ + "/rescore/G.fst"; - final_ie_rxfilename_ = model_path_str_ + "/ivector/final.ie"; - mfcc_conf_rxfilename_ = model_path_str_ + "/mfcc.conf"; - fbank_conf_rxfilename_ = model_path_str_ + "/fbank.conf"; - global_cmvn_stats_rxfilename_ = model_path_str_ + "/global_cmvn.stats"; - pitch_conf_rxfilename_ = model_path_str_ + "/pitch.conf"; - rnnlm_word_feats_rxfilename_ = model_path_str_ + "/rnnlm/word_feats.txt"; - rnnlm_feat_embedding_rxfilename_ = model_path_str_ + "/rnnlm/feat_embedding.final.mat"; - rnnlm_config_rxfilename_ = model_path_str_ + "/rnnlm/special_symbol_opts.conf"; - rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw"; +void Model::ConfigureV1() { + const char *extra_args[] = { + "--max-active=7000", + "--beam=13.0", + "--lattice-beam=6.0", + "--acoustic-scale=1.0", + + "--frame-subsampling-factor=3", + + "--endpoint.silence-phones=1:2:3:4:5:6:7:8:9:10", + "--endpoint.rule2.min-trailing-silence=0.5", + "--endpoint.rule3.min-trailing-silence=1.0", + "--endpoint.rule4.min-trailing-silence=2.0", + + "--print-args=false", + }; + + kaldi::ParseOptions po(""); + nnet3_decoding_config_.Register(&po); + endpoint_config_.Register(&po); + decodable_opts_.Register(&po); + + vector args; + args.push_back("vosk"); + args.insert(args.end(), extra_args, + extra_args + sizeof(extra_args) / sizeof(extra_args[0])); + po.Read(args.size(), args.data()); + + nnet3_rxfilename_ = model_path_str_ + "/final.mdl"; + ctx_dep_rxfilename_ = model_path_str_ + "/tree"; + hclg_fst_rxfilename_ = model_path_str_ + "/HCLG.fst"; + hcl_fst_rxfilename_ = model_path_str_ + "/HCLr.fst"; + g_fst_rxfilename_ = model_path_str_ + "/Gr.fst"; + disambig_rxfilename_ = model_path_str_ + "/disambig_tid.int"; + word_syms_rxfilename_ = model_path_str_ + "/words.txt"; + winfo_rxfilename_ = model_path_str_ + "/word_boundary.int"; + phone_syms_rxfilename_ = model_path_str_ + "/phones.txt"; + carpa_rxfilename_ = model_path_str_ + "/rescore/G.carpa"; + std_fst_rxfilename_ = model_path_str_ + "/rescore/G.fst"; + final_ie_rxfilename_ = model_path_str_ + "/ivector/final.ie"; + mfcc_conf_rxfilename_ = model_path_str_ + "/mfcc.conf"; + fbank_conf_rxfilename_ = model_path_str_ + "/fbank.conf"; + global_cmvn_stats_rxfilename_ = model_path_str_ + "/global_cmvn.stats"; + pitch_conf_rxfilename_ = model_path_str_ + "/pitch.conf"; + rnnlm_word_feats_rxfilename_ = model_path_str_ + "/rnnlm/word_feats.txt"; + rnnlm_feat_embedding_rxfilename_ = + model_path_str_ + "/rnnlm/feat_embedding.final.mat"; + rnnlm_config_rxfilename_ = + model_path_str_ + "/rnnlm/special_symbol_opts.conf"; + rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw"; } -void Model::ConfigureV2() -{ - kaldi::ParseOptions po("something"); - nnet3_decoding_config_.Register(&po); - endpoint_config_.Register(&po); - decodable_opts_.Register(&po); - po.ReadConfigFile(model_path_str_ + "/conf/model.conf"); - - - nnet3_rxfilename_ = model_path_str_ + "/am/final.mdl"; - hclg_fst_rxfilename_ = model_path_str_ + "/graph/HCLG.fst"; - hcl_fst_rxfilename_ = model_path_str_ + "/graph/HCLr.fst"; - g_fst_rxfilename_ = model_path_str_ + "/graph/Gr.fst"; - disambig_rxfilename_ = model_path_str_ + "/graph/disambig_tid.int"; - word_syms_rxfilename_ = model_path_str_ + "/graph/words.txt"; - winfo_rxfilename_ = model_path_str_ + "/graph/phones/word_boundary.int"; - carpa_rxfilename_ = model_path_str_ + "/rescore/G.carpa"; - std_fst_rxfilename_ = model_path_str_ + "/rescore/G.fst"; - final_ie_rxfilename_ = model_path_str_ + "/ivector/final.ie"; - mfcc_conf_rxfilename_ = model_path_str_ + "/conf/mfcc.conf"; - fbank_conf_rxfilename_ = model_path_str_ + "/conf/fbank.conf"; - global_cmvn_stats_rxfilename_ = model_path_str_ + "/am/global_cmvn.stats"; - pitch_conf_rxfilename_ = model_path_str_ + "/conf/pitch.conf"; - rnnlm_word_feats_rxfilename_ = model_path_str_ + "/rnnlm/word_feats.txt"; - rnnlm_feat_embedding_rxfilename_ = model_path_str_ + "/rnnlm/feat_embedding.final.mat"; - rnnlm_config_rxfilename_ = model_path_str_ + "/rnnlm/special_symbol_opts.conf"; - rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw"; +void Model::ConfigureV2() { + kaldi::ParseOptions po("something"); + nnet3_decoding_config_.Register(&po); + endpoint_config_.Register(&po); + decodable_opts_.Register(&po); + po.ReadConfigFile(model_path_str_ + "/conf/model.conf"); + + nnet3_rxfilename_ = model_path_str_ + "/am/final.mdl"; + ctx_dep_rxfilename_ = model_path_str_ + "/am/tree"; + hclg_fst_rxfilename_ = model_path_str_ + "/graph/HCLG.fst"; + hcl_fst_rxfilename_ = model_path_str_ + "/graph/HCLr.fst"; + g_fst_rxfilename_ = model_path_str_ + "/graph/Gr.fst"; + disambig_rxfilename_ = model_path_str_ + "/graph/disambig_tid.int"; + word_syms_rxfilename_ = model_path_str_ + "/graph/words.txt"; + winfo_rxfilename_ = model_path_str_ + "/graph/phones/word_boundary.int"; + phone_syms_rxfilename_ = model_path_str_ + "/graph/phones.txt"; + carpa_rxfilename_ = model_path_str_ + "/rescore/G.carpa"; + std_fst_rxfilename_ = model_path_str_ + "/rescore/G.fst"; + final_ie_rxfilename_ = model_path_str_ + "/ivector/final.ie"; + mfcc_conf_rxfilename_ = model_path_str_ + "/conf/mfcc.conf"; + fbank_conf_rxfilename_ = model_path_str_ + "/conf/fbank.conf"; + global_cmvn_stats_rxfilename_ = model_path_str_ + "/am/global_cmvn.stats"; + pitch_conf_rxfilename_ = model_path_str_ + "/conf/pitch.conf"; + rnnlm_word_feats_rxfilename_ = model_path_str_ + "/rnnlm/word_feats.txt"; + rnnlm_feat_embedding_rxfilename_ = + model_path_str_ + "/rnnlm/feat_embedding.final.mat"; + rnnlm_config_rxfilename_ = + model_path_str_ + "/rnnlm/special_symbol_opts.conf"; + rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw"; } -void Model::ReadDataFiles() -{ - struct stat buffer; - - KALDI_LOG << "Decoding params beam=" << nnet3_decoding_config_.beam << - " max-active=" << nnet3_decoding_config_.max_active << - " lattice-beam=" << nnet3_decoding_config_.lattice_beam; - KALDI_LOG << "Silence phones " << endpoint_config_.silence_phones; - - if (stat(mfcc_conf_rxfilename_.c_str(), &buffer) == 0) { - feature_info_.feature_type = "mfcc"; - ReadConfigFromFile(mfcc_conf_rxfilename_, &feature_info_.mfcc_opts); - feature_info_.mfcc_opts.frame_opts.allow_downsample = true; // It is safe to downsample - } else if (stat(fbank_conf_rxfilename_.c_str(), &buffer) == 0) { - feature_info_.feature_type = "fbank"; - ReadConfigFromFile(fbank_conf_rxfilename_, &feature_info_.fbank_opts); - feature_info_.fbank_opts.frame_opts.allow_downsample = true; // It is safe to downsample - } else { - KALDI_ERR << "Failed to find feature config file"; - } +void Model::ReadDataFiles() { + struct stat buffer; + + KALDI_LOG << "Decoding params beam=" << nnet3_decoding_config_.beam + << " max-active=" << nnet3_decoding_config_.max_active + << " lattice-beam=" << nnet3_decoding_config_.lattice_beam; + KALDI_LOG << "Silence phones " << endpoint_config_.silence_phones; + + if (stat(mfcc_conf_rxfilename_.c_str(), &buffer) == 0) { + feature_info_.feature_type = "mfcc"; + ReadConfigFromFile(mfcc_conf_rxfilename_, &feature_info_.mfcc_opts); + feature_info_.mfcc_opts.frame_opts.allow_downsample = + true; // It is safe to downsample + } else if (stat(fbank_conf_rxfilename_.c_str(), &buffer) == 0) { + feature_info_.feature_type = "fbank"; + ReadConfigFromFile(fbank_conf_rxfilename_, &feature_info_.fbank_opts); + feature_info_.fbank_opts.frame_opts.allow_downsample = + true; // It is safe to downsample + } else { + KALDI_ERR << "Failed to find feature config file"; + } - feature_info_.silence_weighting_config.silence_weight = 1e-3; - feature_info_.silence_weighting_config.silence_phones_str = endpoint_config_.silence_phones; + feature_info_.silence_weighting_config.silence_weight = 1e-3; + feature_info_.silence_weighting_config.silence_phones_str = + endpoint_config_.silence_phones; + + trans_model_ = new kaldi::TransitionModel(); + nnet_ = new kaldi::nnet3::AmNnetSimple(); + { + bool binary; + kaldi::Input ki(nnet3_rxfilename_, &binary); + trans_model_->Read(ki.Stream(), binary); + nnet_->Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(nnet_->GetNnet())); + SetDropoutTestMode(true, &(nnet_->GetNnet())); + nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(nnet_->GetNnet())); + } - trans_model_ = new kaldi::TransitionModel(); - nnet_ = new kaldi::nnet3::AmNnetSimple(); - { - bool binary; - kaldi::Input ki(nnet3_rxfilename_, &binary); - trans_model_->Read(ki.Stream(), binary); - nnet_->Read(ki.Stream(), binary); - SetBatchnormTestMode(true, &(nnet_->GetNnet())); - SetDropoutTestMode(true, &(nnet_->GetNnet())); - nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(nnet_->GetNnet())); - } + decodable_info_ = + new nnet3::DecodableNnetSimpleLoopedInfo(decodable_opts_, nnet_); + + if (stat(final_ie_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading i-vector extractor from " << final_ie_rxfilename_; + + OnlineIvectorExtractionConfig ivector_extraction_opts; + ivector_extraction_opts.splice_config_rxfilename = + model_path_str_ + "/ivector/splice.conf"; + ivector_extraction_opts.cmvn_config_rxfilename = + model_path_str_ + "/ivector/online_cmvn.conf"; + ivector_extraction_opts.lda_mat_rxfilename = + model_path_str_ + "/ivector/final.mat"; + ivector_extraction_opts.global_cmvn_stats_rxfilename = + model_path_str_ + "/ivector/global_cmvn.stats"; + ivector_extraction_opts.diag_ubm_rxfilename = + model_path_str_ + "/ivector/final.dubm"; + ivector_extraction_opts.ivector_extractor_rxfilename = + model_path_str_ + "/ivector/final.ie"; + ivector_extraction_opts.max_count = 100; + + feature_info_.use_ivectors = true; + feature_info_.ivector_extractor_info.Init(ivector_extraction_opts); + } else if (nnet_->IvectorDim() > 0) { + KALDI_ERR << "Can't find required ivector extractor"; + } else { + feature_info_.use_ivectors = false; + } - decodable_info_ = new nnet3::DecodableNnetSimpleLoopedInfo(decodable_opts_, - nnet_); - - if (stat(final_ie_rxfilename_.c_str(), &buffer) == 0) { - KALDI_LOG << "Loading i-vector extractor from " << final_ie_rxfilename_; - - OnlineIvectorExtractionConfig ivector_extraction_opts; - ivector_extraction_opts.splice_config_rxfilename = model_path_str_ + "/ivector/splice.conf"; - ivector_extraction_opts.cmvn_config_rxfilename = model_path_str_ + "/ivector/online_cmvn.conf"; - ivector_extraction_opts.lda_mat_rxfilename = model_path_str_ + "/ivector/final.mat"; - ivector_extraction_opts.global_cmvn_stats_rxfilename = model_path_str_ + "/ivector/global_cmvn.stats"; - ivector_extraction_opts.diag_ubm_rxfilename = model_path_str_ + "/ivector/final.dubm"; - ivector_extraction_opts.ivector_extractor_rxfilename = model_path_str_ + "/ivector/final.ie"; - ivector_extraction_opts.max_count = 100; - - feature_info_.use_ivectors = true; - feature_info_.ivector_extractor_info.Init(ivector_extraction_opts); - } else if (nnet_->IvectorDim() > 0) { - KALDI_ERR << "Can't find required ivector extractor"; - } else { - feature_info_.use_ivectors = false; - } + if (stat(global_cmvn_stats_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Reading CMVN stats from " << global_cmvn_stats_rxfilename_; + feature_info_.use_cmvn = true; + ReadKaldiObject(global_cmvn_stats_rxfilename_, + &feature_info_.global_cmvn_stats); + } - if (stat(global_cmvn_stats_rxfilename_.c_str(), &buffer) == 0) { - KALDI_LOG << "Reading CMVN stats from " << global_cmvn_stats_rxfilename_; - feature_info_.use_cmvn = true; - ReadKaldiObject(global_cmvn_stats_rxfilename_, &feature_info_.global_cmvn_stats); - } + if (stat(pitch_conf_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Using pitch in feature pipeline"; + feature_info_.add_pitch = true; + ReadConfigsFromFile(pitch_conf_rxfilename_, &feature_info_.pitch_opts, + &feature_info_.pitch_process_opts); + } - if (stat(pitch_conf_rxfilename_.c_str(), &buffer) == 0) { - KALDI_LOG << "Using pitch in feature pipeline"; - feature_info_.add_pitch = true; - ReadConfigsFromFile(pitch_conf_rxfilename_, - &feature_info_.pitch_opts, &feature_info_.pitch_process_opts); + if (stat(hclg_fst_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading HCLG from " << hclg_fst_rxfilename_; + hclg_fst_ = fst::ReadFstKaldiGeneric(hclg_fst_rxfilename_); + } else { + KALDI_LOG << "Loading HCL and G from " << hcl_fst_rxfilename_ << " " + << g_fst_rxfilename_; + hcl_fst_ = fst::StdFst::Read(hcl_fst_rxfilename_); + g_fst_ = fst::StdFst::Read(g_fst_rxfilename_); + if (!ReadIntegerVectorSimple(disambig_rxfilename_, &disambig_)) { + KALDI_ERR << "Could not read disambig symbol table from file " + << disambig_rxfilename_; } + } - if (stat(hclg_fst_rxfilename_.c_str(), &buffer) == 0) { - KALDI_LOG << "Loading HCLG from " << hclg_fst_rxfilename_; - hclg_fst_ = fst::ReadFstKaldiGeneric(hclg_fst_rxfilename_); - } else { - KALDI_LOG << "Loading HCL and G from " << hcl_fst_rxfilename_ << " " << g_fst_rxfilename_; - hcl_fst_ = fst::StdFst::Read(hcl_fst_rxfilename_); - g_fst_ = fst::StdFst::Read(g_fst_rxfilename_); - if (!ReadIntegerVectorSimple(disambig_rxfilename_, &disambig_)) { - KALDI_ERR << "Could not read disambig symbol table from file " - << disambig_rxfilename_; - } - } + if (stat(ctx_dep_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading context dependency from " << ctx_dep_rxfilename_; + ctx_dep_ = new kaldi::ContextDependency(); + kaldi::ReadKaldiObject(ctx_dep_rxfilename_, ctx_dep_); + } - if (hclg_fst_ && hclg_fst_->OutputSymbols()) { - word_syms_ = hclg_fst_->OutputSymbols(); - } else if (g_fst_ && g_fst_->OutputSymbols()) { - word_syms_ = g_fst_->OutputSymbols(); - } - if (!word_syms_) { - KALDI_LOG << "Loading words from " << word_syms_rxfilename_; - if (!(word_syms_ = fst::SymbolTable::ReadText(word_syms_rxfilename_))) - KALDI_ERR << "Could not read symbol table from file " - << word_syms_rxfilename_; - word_syms_loaded_ = word_syms_; - } - if (!word_syms_) { - KALDI_ERR << "Word symbol table empty"; - } + if (hclg_fst_ && hclg_fst_->OutputSymbols()) { + word_syms_ = hclg_fst_->OutputSymbols(); + } else if (g_fst_ && g_fst_->OutputSymbols()) { + word_syms_ = g_fst_->OutputSymbols(); + } + if (!word_syms_) { + KALDI_LOG << "Loading words from " << word_syms_rxfilename_; + if (!(word_syms_ = fst::SymbolTable::ReadText(word_syms_rxfilename_))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_rxfilename_; + word_syms_loaded_ = word_syms_; + } + if (!word_syms_) { + KALDI_ERR << "Word symbol table empty"; + } - if (stat(winfo_rxfilename_.c_str(), &buffer) == 0) { - KALDI_LOG << "Loading winfo " << winfo_rxfilename_; - kaldi::WordBoundaryInfoNewOpts opts; - winfo_ = new kaldi::WordBoundaryInfo(opts, winfo_rxfilename_); - } + if (stat(winfo_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading winfo " << winfo_rxfilename_; + kaldi::WordBoundaryInfoNewOpts opts; + winfo_ = new kaldi::WordBoundaryInfo(opts, winfo_rxfilename_); + } - if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) { + if (stat(phone_syms_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading phones from " << phone_syms_rxfilename_; + phone_syms_ = fst::SymbolTable::ReadText(phone_syms_rxfilename_); + } - KALDI_LOG << "Loading subtract G.fst model from " << std_fst_rxfilename_; - graph_lm_fst_ = fst::ReadAndPrepareLmFst(std_fst_rxfilename_); - KALDI_LOG << "Loading CARPA model from " << carpa_rxfilename_; - ReadKaldiObject(carpa_rxfilename_, &const_arpa_); - } + if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) { + + KALDI_LOG << "Loading subtract G.fst model from " << std_fst_rxfilename_; + graph_lm_fst_ = fst::ReadAndPrepareLmFst(std_fst_rxfilename_); + KALDI_LOG << "Loading CARPA model from " << carpa_rxfilename_; + ReadKaldiObject(carpa_rxfilename_, &const_arpa_); + } + + // RNNLM Rescoring + if (stat(rnnlm_lm_rxfilename_.c_str(), &buffer) == 0) { + KALDI_LOG << "Loading RNNLM model from " << rnnlm_lm_rxfilename_; - // RNNLM Rescoring - if (stat(rnnlm_lm_rxfilename_.c_str(), &buffer) == 0) { - KALDI_LOG << "Loading RNNLM model from " << rnnlm_lm_rxfilename_; - - ReadKaldiObject(rnnlm_lm_rxfilename_, &rnnlm); - Matrix feature_embedding_mat; - ReadKaldiObject(rnnlm_feat_embedding_rxfilename_, &feature_embedding_mat); - SparseMatrix word_feature_mat; - { - Input input(rnnlm_word_feats_rxfilename_); - int32 feature_dim = feature_embedding_mat.NumRows(); - rnnlm::ReadSparseWordFeatures(input.Stream(), feature_dim, - &word_feature_mat); - } - Matrix wm(word_feature_mat.NumRows(), feature_embedding_mat.NumCols()); - wm.AddSmatMat(1.0, word_feature_mat, kNoTrans, - feature_embedding_mat, 0.0); - word_embedding_mat.Resize(wm.NumRows(), wm.NumCols(), kUndefined); - word_embedding_mat.CopyFromMat(wm); - - ReadConfigFromFile(rnnlm_config_rxfilename_, &rnnlm_compute_opts); - - rnnlm_enabled_ = true; + ReadKaldiObject(rnnlm_lm_rxfilename_, &rnnlm); + Matrix feature_embedding_mat; + ReadKaldiObject(rnnlm_feat_embedding_rxfilename_, &feature_embedding_mat); + SparseMatrix word_feature_mat; + { + Input input(rnnlm_word_feats_rxfilename_); + int32 feature_dim = feature_embedding_mat.NumRows(); + rnnlm::ReadSparseWordFeatures(input.Stream(), feature_dim, + &word_feature_mat); } + Matrix wm(word_feature_mat.NumRows(), + feature_embedding_mat.NumCols()); + wm.AddSmatMat(1.0, word_feature_mat, kNoTrans, feature_embedding_mat, 0.0); + word_embedding_mat.Resize(wm.NumRows(), wm.NumCols(), kUndefined); + word_embedding_mat.CopyFromMat(wm); + + ReadConfigFromFile(rnnlm_config_rxfilename_, &rnnlm_compute_opts); + rnnlm_enabled_ = true; + } } -void Model::Ref() -{ - std::atomic_fetch_add_explicit(&ref_cnt_, 1, std::memory_order_relaxed); +void Model::Ref() { + std::atomic_fetch_add_explicit(&ref_cnt_, 1, std::memory_order_relaxed); } -void Model::Unref() -{ - if (std::atomic_fetch_sub_explicit(&ref_cnt_, 1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - delete this; - } +void Model::Unref() { + if (std::atomic_fetch_sub_explicit(&ref_cnt_, 1, std::memory_order_release) == + 1) { + std::atomic_thread_fence(std::memory_order_acquire); + delete this; + } } -int Model::FindWord(const char *word) -{ - if (!word_syms_) - return -1; +int Model::FindWord(const char *word) { + if (!word_syms_) + return -1; - return word_syms_->Find(word); + return word_syms_->Find(word); } Model::~Model() { - delete decodable_info_; - delete trans_model_; - delete nnet_; - if (word_syms_loaded_) - delete word_syms_; - delete winfo_; - delete hclg_fst_; - delete hcl_fst_; - delete g_fst_; - delete graph_lm_fst_; + delete decodable_info_; + delete trans_model_; + delete nnet_; + if (word_syms_loaded_) + delete word_syms_; + delete phone_syms_; + delete winfo_; + delete hclg_fst_; + delete hcl_fst_; + delete g_fst_; + delete ctx_dep_; + delete graph_lm_fst_; } diff --git a/src/model.h b/src/model.h index 7fc09df6..edcf2b7a 100644 --- a/src/model.h +++ b/src/model.h @@ -18,19 +18,20 @@ #include "base/kaldi-common.h" #include "fstext/fstext-lib.h" #include "fstext/fstext-utils.h" -#include "online2/onlinebin-util.h" -#include "online2/online-timing.h" -#include "online2/online-endpoint.h" -#include "online2/online-nnet3-incremental-decoding.h" -#include "online2/online-feature-pipeline.h" #include "lat/lattice-functions.h" #include "lat/sausages.h" #include "lat/word-align-lattice.h" #include "lm/const-arpa-lm.h" -#include "util/parse-options.h" #include "nnet3/nnet-utils.h" -#include "rnnlm/rnnlm-utils.h" +#include "online2/online-endpoint.h" +#include "online2/online-feature-pipeline.h" +#include "online2/online-nnet3-incremental-decoding.h" +#include "online2/online-timing.h" +#include "online2/onlinebin-util.h" #include "rnnlm/rnnlm-lattice-rescoring.h" +#include "rnnlm/rnnlm-utils.h" +#include "tree/context-dep.h" +#include "util/parse-options.h" #include using namespace kaldi; @@ -41,66 +42,71 @@ class Recognizer; class Model { public: - Model(const char *model_path); - void Ref(); - void Unref(); - int FindWord(const char *word); + Model(const char *model_path); + void Ref(); + void Unref(); + int FindWord(const char *word); protected: - ~Model(); - void ConfigureV1(); - void ConfigureV2(); - void ReadDataFiles(); - - friend class Recognizer; - - string model_path_str_; - string nnet3_rxfilename_; - string hclg_fst_rxfilename_; - string hcl_fst_rxfilename_; - string g_fst_rxfilename_; - string disambig_rxfilename_; - string word_syms_rxfilename_; - string winfo_rxfilename_; - string carpa_rxfilename_; - string std_fst_rxfilename_; - string final_ie_rxfilename_; - string mfcc_conf_rxfilename_; - string fbank_conf_rxfilename_; - string global_cmvn_stats_rxfilename_; - string pitch_conf_rxfilename_; - - string rnnlm_word_feats_rxfilename_; - string rnnlm_feat_embedding_rxfilename_; - string rnnlm_config_rxfilename_; - string rnnlm_lm_rxfilename_; - - kaldi::OnlineEndpointConfig endpoint_config_; - kaldi::LatticeIncrementalDecoderConfig nnet3_decoding_config_; - kaldi::nnet3::NnetSimpleLoopedComputationOptions decodable_opts_; - kaldi::OnlineNnet2FeaturePipelineInfo feature_info_; - - kaldi::nnet3::DecodableNnetSimpleLoopedInfo *decodable_info_ = nullptr; - kaldi::TransitionModel *trans_model_ = nullptr; - kaldi::nnet3::AmNnetSimple *nnet_ = nullptr; - const fst::SymbolTable *word_syms_ = nullptr; - bool word_syms_loaded_ = false; - kaldi::WordBoundaryInfo *winfo_ = nullptr; - vector disambig_; - - fst::Fst *hclg_fst_ = nullptr; - fst::Fst *hcl_fst_ = nullptr; - fst::Fst *g_fst_ = nullptr; - - fst::VectorFst *graph_lm_fst_ = nullptr; - kaldi::ConstArpaLm const_arpa_; - - kaldi::rnnlm::RnnlmComputeStateComputationOptions rnnlm_compute_opts; - CuMatrix word_embedding_mat; - kaldi::nnet3::Nnet rnnlm; - bool rnnlm_enabled_ = false; - - std::atomic ref_cnt_; + ~Model(); + void ConfigureV1(); + void ConfigureV2(); + void ReadDataFiles(); + + friend class Recognizer; + + string model_path_str_; + string nnet3_rxfilename_; + string hclg_fst_rxfilename_; + string hcl_fst_rxfilename_; + string g_fst_rxfilename_; + string ctx_dep_rxfilename_; + string disambig_rxfilename_; + string word_syms_rxfilename_; + string winfo_rxfilename_; + string carpa_rxfilename_; + string std_fst_rxfilename_; + string final_ie_rxfilename_; + string mfcc_conf_rxfilename_; + string fbank_conf_rxfilename_; + string global_cmvn_stats_rxfilename_; + string pitch_conf_rxfilename_; + string phone_syms_rxfilename_; + + string rnnlm_word_feats_rxfilename_; + string rnnlm_feat_embedding_rxfilename_; + string rnnlm_config_rxfilename_; + string rnnlm_lm_rxfilename_; + + kaldi::OnlineEndpointConfig endpoint_config_; + kaldi::LatticeIncrementalDecoderConfig nnet3_decoding_config_; + kaldi::nnet3::NnetSimpleLoopedComputationOptions decodable_opts_; + kaldi::OnlineNnet2FeaturePipelineInfo feature_info_; + + kaldi::nnet3::DecodableNnetSimpleLoopedInfo *decodable_info_ = nullptr; + kaldi::TransitionModel *trans_model_ = nullptr; + kaldi::nnet3::AmNnetSimple *nnet_ = nullptr; + const fst::SymbolTable *word_syms_ = nullptr; + bool word_syms_loaded_ = false; + kaldi::WordBoundaryInfo *winfo_ = nullptr; + vector disambig_; + const fst::SymbolTable *phone_syms_ = nullptr; + + fst::Fst *hclg_fst_ = nullptr; + fst::Fst *hcl_fst_ = nullptr; + fst::Fst *g_fst_ = nullptr; + + ContextDependency *ctx_dep_ = nullptr; + + fst::VectorFst *graph_lm_fst_ = nullptr; + kaldi::ConstArpaLm const_arpa_; + + kaldi::rnnlm::RnnlmComputeStateComputationOptions rnnlm_compute_opts; + CuMatrix word_embedding_mat; + kaldi::nnet3::Nnet rnnlm; + bool rnnlm_enabled_ = false; + + std::atomic ref_cnt_; }; #endif /* VOSK_MODEL_H */ diff --git a/src/recognizer.cc b/src/recognizer.cc index f75c86bc..8c71d22c 100644 --- a/src/recognizer.cc +++ b/src/recognizer.cc @@ -13,532 +13,553 @@ // limitations under the License. #include "recognizer.h" -#include "json.h" #include "fstext/fstext-utils.h" -#include "lat/sausages.h" +#include "json.h" #include "language_model.h" +#include "lat/sausages.h" using namespace fst; using namespace kaldi::nnet3; -Recognizer::Recognizer(Model *model, float sample_frequency) : model_(model), spk_model_(0), sample_frequency_(sample_frequency) { +Recognizer::Recognizer(Model *model, float sample_frequency) + : model_(model), spk_model_(0), sample_frequency_(sample_frequency) { - model_->Ref(); + model_->Ref(); - feature_pipeline_ = new kaldi::OnlineNnet2FeaturePipeline (model_->feature_info_); - silence_weighting_ = new kaldi::OnlineSilenceWeighting(*model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); + feature_pipeline_ = + new kaldi::OnlineNnet2FeaturePipeline(model_->feature_info_); + silence_weighting_ = new kaldi::OnlineSilenceWeighting( + *model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); - if (!model_->hclg_fst_) { - if (model_->hcl_fst_ && model_->g_fst_) { - decode_fst_ = LookaheadComposeFst(*model_->hcl_fst_, *model_->g_fst_, model_->disambig_); - } else { - KALDI_ERR << "Can't create decoding graph"; - } + if (!model_->hclg_fst_) { + if (GetHclFst() && model_->g_fst_) { + decode_fst_ = + LookaheadComposeFst(*GetHclFst(), *model_->g_fst_, *GetDisambig()); + } else { + KALDI_ERR << "Can't create decoding graph"; } + } - decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder(model_->nnet3_decoding_config_, - *model_->trans_model_, - *model_->decodable_info_, - model_->hclg_fst_ ? *model_->hclg_fst_ : *decode_fst_, - feature_pipeline_); + decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder( + model_->nnet3_decoding_config_, *model_->trans_model_, + *model_->decodable_info_, + model_->hclg_fst_ ? *model_->hclg_fst_ : *decode_fst_, feature_pipeline_); - InitState(); - InitRescoring(); + InitState(); + InitRescoring(); } -Recognizer::Recognizer(Model *model, float sample_frequency, char const *grammar) : model_(model), spk_model_(0), sample_frequency_(sample_frequency) -{ - model_->Ref(); +Recognizer::Recognizer(Model *model, float sample_frequency, + char const *grammar) + : model_(model), spk_model_(0), sample_frequency_(sample_frequency) { + model_->Ref(); - feature_pipeline_ = new kaldi::OnlineNnet2FeaturePipeline (model_->feature_info_); - silence_weighting_ = new kaldi::OnlineSilenceWeighting(*model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); + feature_pipeline_ = + new kaldi::OnlineNnet2FeaturePipeline(model_->feature_info_); + silence_weighting_ = new kaldi::OnlineSilenceWeighting( + *model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); - if (model_->hcl_fst_) { - UpdateGrammarFst(grammar); - } else { - KALDI_WARN << "Runtime graphs are not supported by this model"; - } + if (model_->hcl_fst_) { + UpdateGrammarFst(grammar); + } else { + KALDI_WARN << "Runtime graphs are not supported by this model"; + } - decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder(model_->nnet3_decoding_config_, - *model_->trans_model_, - *model_->decodable_info_, - model_->hclg_fst_ ? *model_->hclg_fst_ : *decode_fst_, - feature_pipeline_); + decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder( + model_->nnet3_decoding_config_, *model_->trans_model_, + *model_->decodable_info_, + model_->hclg_fst_ ? *model_->hclg_fst_ : *decode_fst_, feature_pipeline_); - InitState(); - InitRescoring(); + InitState(); + InitRescoring(); } -Recognizer::Recognizer(Model *model, float sample_frequency, SpkModel *spk_model) : model_(model), spk_model_(spk_model), sample_frequency_(sample_frequency) { +Recognizer::Recognizer(Model *model, float sample_frequency, + SpkModel *spk_model) + : model_(model), spk_model_(spk_model), + sample_frequency_(sample_frequency) { - model_->Ref(); - spk_model->Ref(); + model_->Ref(); + spk_model->Ref(); - feature_pipeline_ = new kaldi::OnlineNnet2FeaturePipeline (model_->feature_info_); - silence_weighting_ = new kaldi::OnlineSilenceWeighting(*model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); + feature_pipeline_ = + new kaldi::OnlineNnet2FeaturePipeline(model_->feature_info_); + silence_weighting_ = new kaldi::OnlineSilenceWeighting( + *model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); - if (!model_->hclg_fst_) { - if (model_->hcl_fst_ && model_->g_fst_) { - decode_fst_ = LookaheadComposeFst(*model_->hcl_fst_, *model_->g_fst_, model_->disambig_); - } else { - KALDI_ERR << "Can't create decoding graph"; - } + if (!model_->hclg_fst_) { + if (model_->hcl_fst_ && model_->g_fst_) { + decode_fst_ = + LookaheadComposeFst(*GetHclFst(), *model_->g_fst_, *GetDisambig()); + } else { + KALDI_ERR << "Can't create decoding graph"; } + } - decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder(model_->nnet3_decoding_config_, - *model_->trans_model_, - *model_->decodable_info_, - model_->hclg_fst_ ? *model_->hclg_fst_ : *decode_fst_, - feature_pipeline_); + decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder( + model_->nnet3_decoding_config_, *model_->trans_model_, + *model_->decodable_info_, + model_->hclg_fst_ ? *model_->hclg_fst_ : *decode_fst_, feature_pipeline_); - spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); + spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); - InitState(); - InitRescoring(); + InitState(); + InitRescoring(); } Recognizer::~Recognizer() { - delete decoder_; - delete feature_pipeline_; - delete silence_weighting_; - delete g_fst_; - delete decode_fst_; - delete spk_feature_; - - delete lm_to_subtract_; - delete carpa_to_add_; - delete carpa_to_add_scale_; - delete rnnlm_info_; - delete rnnlm_to_add_; - delete rnnlm_to_add_scale_; - - model_->Unref(); - if (spk_model_) - spk_model_->Unref(); + delete decoder_; + delete feature_pipeline_; + delete silence_weighting_; + delete g_fst_; + delete decode_fst_; + delete spk_feature_; + + delete lm_to_subtract_; + delete carpa_to_add_; + delete carpa_to_add_scale_; + delete rnnlm_info_; + delete rnnlm_to_add_; + delete rnnlm_to_add_scale_; + + model_->Unref(); + if (spk_model_) + spk_model_->Unref(); } -void Recognizer::InitState() -{ - frame_offset_ = 0; - samples_processed_ = 0; - samples_round_start_ = 0; +void Recognizer::InitState() { + frame_offset_ = 0; + samples_processed_ = 0; + samples_round_start_ = 0; - state_ = RECOGNIZER_INITIALIZED; + state_ = RECOGNIZER_INITIALIZED; } -void Recognizer::InitRescoring() -{ - if (model_->graph_lm_fst_) { - - fst::CacheOptions cache_opts(true, -1); - fst::ArcMapFstOptions mapfst_opts(cache_opts); - fst::StdToLatticeMapper mapper; - - lm_to_subtract_ = new fst::ArcMapFst >(*model_->graph_lm_fst_, mapper, mapfst_opts); - carpa_to_add_ = new ConstArpaLmDeterministicFst(model_->const_arpa_); - - if (model_->rnnlm_enabled_) { - int lm_order = 4; - rnnlm_info_ = new kaldi::rnnlm::RnnlmComputeStateInfo(model_->rnnlm_compute_opts, model_->rnnlm, model_->word_embedding_mat); - rnnlm_to_add_ = new kaldi::rnnlm::KaldiRnnlmDeterministicFst(lm_order, *rnnlm_info_); - rnnlm_to_add_scale_ = new fst::ScaleDeterministicOnDemandFst(0.5, rnnlm_to_add_); - carpa_to_add_scale_ = new fst::ScaleDeterministicOnDemandFst(-0.5, carpa_to_add_); - } +void Recognizer::InitRescoring() { + if (model_->graph_lm_fst_) { + + fst::CacheOptions cache_opts(true, -1); + fst::ArcMapFstOptions mapfst_opts(cache_opts); + fst::StdToLatticeMapper mapper; + + lm_to_subtract_ = new fst::ArcMapFst>( + *model_->graph_lm_fst_, mapper, mapfst_opts); + carpa_to_add_ = new ConstArpaLmDeterministicFst(model_->const_arpa_); + + if (model_->rnnlm_enabled_) { + int lm_order = 4; + rnnlm_info_ = new kaldi::rnnlm::RnnlmComputeStateInfo( + model_->rnnlm_compute_opts, model_->rnnlm, + model_->word_embedding_mat); + rnnlm_to_add_ = + new kaldi::rnnlm::KaldiRnnlmDeterministicFst(lm_order, *rnnlm_info_); + rnnlm_to_add_scale_ = + new fst::ScaleDeterministicOnDemandFst(0.5, rnnlm_to_add_); + carpa_to_add_scale_ = + new fst::ScaleDeterministicOnDemandFst(-0.5, carpa_to_add_); } + } } -void Recognizer::CleanUp() -{ - delete silence_weighting_; - silence_weighting_ = new kaldi::OnlineSilenceWeighting(*model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); +void Recognizer::CleanUp() { + delete silence_weighting_; + silence_weighting_ = new kaldi::OnlineSilenceWeighting( + *model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); - if (decoder_) - frame_offset_ += decoder_->NumFramesDecoded(); + if (decoder_) + frame_offset_ += decoder_->NumFramesDecoded(); - // Each 10 minutes we drop the pipeline to save frontend memory in continuous processing - // here we drop few frames remaining in the feature pipeline but hope it will not - // cause a huge accuracy drop since it happens not very frequently. + // Each 10 minutes we drop the pipeline to save frontend memory in continuous + // processing here we drop few frames remaining in the feature pipeline but + // hope it will not cause a huge accuracy drop since it happens not very + // frequently. - // Also restart if we retrieved final result already + // Also restart if we retrieved final result already - if (decoder_ == nullptr || state_ == RECOGNIZER_FINALIZED || frame_offset_ > 20000) { - samples_round_start_ += samples_processed_; - samples_processed_ = 0; - frame_offset_ = 0; + if (decoder_ == nullptr || state_ == RECOGNIZER_FINALIZED || + frame_offset_ > 20000) { + samples_round_start_ += samples_processed_; + samples_processed_ = 0; + frame_offset_ = 0; - delete decoder_; - delete feature_pipeline_; + delete decoder_; + delete feature_pipeline_; - feature_pipeline_ = new kaldi::OnlineNnet2FeaturePipeline (model_->feature_info_); - decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder(model_->nnet3_decoding_config_, - *model_->trans_model_, - *model_->decodable_info_, - model_->hclg_fst_ ? *model_->hclg_fst_ : *decode_fst_, - feature_pipeline_); + feature_pipeline_ = + new kaldi::OnlineNnet2FeaturePipeline(model_->feature_info_); + decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder( + model_->nnet3_decoding_config_, *model_->trans_model_, + *model_->decodable_info_, + model_->hclg_fst_ ? *model_->hclg_fst_ : *decode_fst_, + feature_pipeline_); - if (spk_model_) { - delete spk_feature_; - spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); - } - } else { - decoder_->InitDecoding(frame_offset_); + if (spk_model_) { + delete spk_feature_; + spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); } + } else { + decoder_->InitDecoding(frame_offset_); + } } -void Recognizer::UpdateSilenceWeights() -{ - if (silence_weighting_->Active() && feature_pipeline_->NumFramesReady() > 0 && - feature_pipeline_->IvectorFeature() != nullptr) { - vector > delta_weights; - silence_weighting_->ComputeCurrentTraceback(decoder_->Decoder()); - silence_weighting_->GetDeltaWeights(feature_pipeline_->NumFramesReady(), - frame_offset_ * 3, - &delta_weights); - feature_pipeline_->UpdateFrameWeights(delta_weights); - } +void Recognizer::UpdateSilenceWeights() { + if (silence_weighting_->Active() && feature_pipeline_->NumFramesReady() > 0 && + feature_pipeline_->IvectorFeature() != nullptr) { + vector> delta_weights; + silence_weighting_->ComputeCurrentTraceback(decoder_->Decoder()); + silence_weighting_->GetDeltaWeights(feature_pipeline_->NumFramesReady(), + frame_offset_ * 3, &delta_weights); + feature_pipeline_->UpdateFrameWeights(delta_weights); + } } -void Recognizer::SetMaxAlternatives(int max_alternatives) -{ - max_alternatives_ = max_alternatives; +void Recognizer::SetMaxAlternatives(int max_alternatives) { + max_alternatives_ = max_alternatives; } -void Recognizer::SetWords(bool words) -{ - words_ = words; -} +void Recognizer::SetWords(bool words) { words_ = words; } -void Recognizer::SetPartialWords(bool partial_words) -{ - partial_words_ = partial_words; +void Recognizer::SetPartialWords(bool partial_words) { + partial_words_ = partial_words; } -void Recognizer::SetNLSML(bool nlsml) -{ - nlsml_ = nlsml; -} +void Recognizer::SetNLSML(bool nlsml) { nlsml_ = nlsml; } -void Recognizer::SetSpkModel(SpkModel *spk_model) -{ - if (state_ == RECOGNIZER_RUNNING) { - KALDI_ERR << "Can't add speaker model to already running recognizer"; - return; - } - spk_model_ = spk_model; - spk_model_->Ref(); - spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); +void Recognizer::SetSpkModel(SpkModel *spk_model) { + if (state_ == RECOGNIZER_RUNNING) { + KALDI_ERR << "Can't add speaker model to already running recognizer"; + return; + } + spk_model_ = spk_model; + spk_model_->Ref(); + spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); } -void Recognizer::SetGrm(char const *grammar) -{ - if (state_ == RECOGNIZER_RUNNING) { - KALDI_ERR << "Can't add speaker model to already running recognizer"; - return; - } +void Recognizer::SetGrm(char const *grammar, const char *const *words, + const char *const *pronunciations, int num_words) { + if (state_ == RECOGNIZER_RUNNING) { + KALDI_ERR << "Can't add speaker model to already running recognizer"; + return; + } - if (!model_->hcl_fst_) { - KALDI_WARN << "Runtime graphs are not supported by this model"; - return; - } + if (!model_->hcl_fst_) { + KALDI_WARN << "Runtime graphs are not supported by this model"; + return; + } + if (!strcmp(grammar, "[]")) { + delete hcl_fst_; + delete disambig_; delete decode_fst_; - - if (!strcmp(grammar, "[]")) { - decode_fst_ = LookaheadComposeFst(*model_->hcl_fst_, *model_->g_fst_, model_->disambig_); - } else { - UpdateGrammarFst(grammar); + decode_fst_ = + LookaheadComposeFst(*GetHclFst(), *model_->g_fst_, *GetDisambig()); + } else { + // Update HCLr fst if needed + if (num_words > 0 && words != nullptr && pronunciations != nullptr) { + KALDI_LOG << "Rebuilding lexicon with " << num_words << " words"; + vector words_vec(words, words + num_words); + vector pronunciations_vec(pronunciations, + pronunciations + num_words); + auto t0 = chrono::high_resolution_clock::now(); + RebuildLexicon(words_vec, pronunciations_vec); + if (GetHclFst() == nullptr) { + KALDI_ERR << "Failed to rebuild lexicon"; + return; + } + auto t1 = chrono::high_resolution_clock::now(); + auto duration = + chrono::duration_cast(t1 - t0).count(); + KALDI_LOG << "Rebuilding lexicon done in " << duration << "ms"; } + // Update grammar fst + delete decode_fst_; + UpdateGrammarFst(grammar); + } - samples_round_start_ += samples_processed_; - samples_processed_ = 0; - frame_offset_ = 0; + samples_round_start_ += samples_processed_; + samples_processed_ = 0; + frame_offset_ = 0; - delete decoder_; - delete feature_pipeline_; - delete silence_weighting_; + delete decoder_; + delete feature_pipeline_; + delete silence_weighting_; - silence_weighting_ = new kaldi::OnlineSilenceWeighting(*model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); - feature_pipeline_ = new kaldi::OnlineNnet2FeaturePipeline (model_->feature_info_); - decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder(model_->nnet3_decoding_config_, - *model_->trans_model_, - *model_->decodable_info_, - *decode_fst_, - feature_pipeline_); + silence_weighting_ = new kaldi::OnlineSilenceWeighting( + *model_->trans_model_, model_->feature_info_.silence_weighting_config, 3); + feature_pipeline_ = + new kaldi::OnlineNnet2FeaturePipeline(model_->feature_info_); + decoder_ = new kaldi::SingleUtteranceNnet3IncrementalDecoder( + model_->nnet3_decoding_config_, *model_->trans_model_, + *model_->decodable_info_, *decode_fst_, feature_pipeline_); - if (spk_model_) { - delete spk_feature_; - spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); - } + if (spk_model_) { + delete spk_feature_; + spk_feature_ = new OnlineMfcc(spk_model_->spkvector_mfcc_opts); + } - state_ = RECOGNIZER_INITIALIZED; + state_ = RECOGNIZER_INITIALIZED; } +void Recognizer::UpdateGrammarFst(char const *grammar) { + json::JSON obj; + obj = json::JSON::Load(grammar); -void Recognizer::UpdateGrammarFst(char const *grammar) -{ - json::JSON obj; - obj = json::JSON::Load(grammar); + if (obj.length() <= 0) { + KALDI_WARN << "Expecting array of strings, got: '" << grammar << "'"; + return; + } - if (obj.length() <= 0) { - KALDI_WARN << "Expecting array of strings, got: '" << grammar << "'"; - return; + KALDI_LOG << obj; + + LanguageModelOptions opts; + + opts.ngram_order = 2; + opts.discount = 0.5; + + LanguageModelEstimator estimator(opts); + for (int i = 0; i < obj.length(); i++) { + bool ok; + string line = obj[i].ToString(ok); + if (!ok) { + KALDI_ERR << "Expecting array of strings, got: '" << obj << "'"; } - KALDI_LOG << obj; - - LanguageModelOptions opts; - - opts.ngram_order = 2; - opts.discount = 0.5; - - LanguageModelEstimator estimator(opts); - for (int i = 0; i < obj.length(); i++) { - bool ok; - string line = obj[i].ToString(ok); - if (!ok) { - KALDI_ERR << "Expecting array of strings, got: '" << obj << "'"; - } - - std::vector sentence; - stringstream ss(line); - string token; - while (getline(ss, token, ' ')) { - int32 id = model_->word_syms_->Find(token); - if (id == kNoSymbol) { - KALDI_WARN << "Ignoring word missing in vocabulary: '" << token << "'"; - } else { - sentence.push_back(id); - } - } - estimator.AddCounts(sentence); + std::vector sentence; + stringstream ss(line); + string token; + while (getline(ss, token, ' ')) { + int32 id = model_->word_syms_->Find(token); + if (id == kNoSymbol) { + KALDI_WARN << "Ignoring word missing in vocabulary: '" << token << "'"; + } else { + sentence.push_back(id); + } } - g_fst_ = new StdVectorFst(); - estimator.Estimate(g_fst_); + estimator.AddCounts(sentence); + } + g_fst_ = new StdVectorFst(); + estimator.Estimate(g_fst_); - decode_fst_ = LookaheadComposeFst(*model_->hcl_fst_, *g_fst_, model_->disambig_); + decode_fst_ = LookaheadComposeFst(*GetHclFst(), *g_fst_, *GetDisambig()); } - -bool Recognizer::AcceptWaveform(const char *data, int len) -{ - Vector wave; - wave.Resize(len / 2, kUndefined); - for (int i = 0; i < len / 2; i++) - wave(i) = *(((short *)data) + i); - return AcceptWaveform(wave); +bool Recognizer::AcceptWaveform(const char *data, int len) { + Vector wave; + wave.Resize(len / 2, kUndefined); + for (int i = 0; i < len / 2; i++) + wave(i) = *(((short *)data) + i); + return AcceptWaveform(wave); } -bool Recognizer::AcceptWaveform(const short *sdata, int len) -{ - Vector wave; - wave.Resize(len, kUndefined); - for (int i = 0; i < len; i++) - wave(i) = sdata[i]; - return AcceptWaveform(wave); +bool Recognizer::AcceptWaveform(const short *sdata, int len) { + Vector wave; + wave.Resize(len, kUndefined); + for (int i = 0; i < len; i++) + wave(i) = sdata[i]; + return AcceptWaveform(wave); } -bool Recognizer::AcceptWaveform(const float *fdata, int len) -{ - Vector wave; - wave.Resize(len, kUndefined); - for (int i = 0; i < len; i++) - wave(i) = fdata[i]; - return AcceptWaveform(wave); +bool Recognizer::AcceptWaveform(const float *fdata, int len) { + Vector wave; + wave.Resize(len, kUndefined); + for (int i = 0; i < len; i++) + wave(i) = fdata[i]; + return AcceptWaveform(wave); } -bool Recognizer::AcceptWaveform(Vector &wdata) -{ - // Cleanup if we finalized previous utterance or the whole feature pipeline - if (!(state_ == RECOGNIZER_RUNNING || state_ == RECOGNIZER_INITIALIZED)) { - CleanUp(); - } - state_ = RECOGNIZER_RUNNING; - - int step = static_cast(sample_frequency_ * 0.2); - for (int i = 0; i < wdata.Dim(); i+= step) { - SubVector r = wdata.Range(i, std::min(step, wdata.Dim() - i)); - feature_pipeline_->AcceptWaveform(sample_frequency_, r); - UpdateSilenceWeights(); - decoder_->AdvanceDecoding(); - } - samples_processed_ += wdata.Dim(); +bool Recognizer::AcceptWaveform(Vector &wdata) { + // Cleanup if we finalized previous utterance or the whole feature pipeline + if (!(state_ == RECOGNIZER_RUNNING || state_ == RECOGNIZER_INITIALIZED)) { + CleanUp(); + } + state_ = RECOGNIZER_RUNNING; - if (spk_feature_) { - spk_feature_->AcceptWaveform(sample_frequency_, wdata); - } + int step = static_cast(sample_frequency_ * 0.2); + for (int i = 0; i < wdata.Dim(); i += step) { + SubVector r = wdata.Range(i, std::min(step, wdata.Dim() - i)); + feature_pipeline_->AcceptWaveform(sample_frequency_, r); + UpdateSilenceWeights(); + decoder_->AdvanceDecoding(); + } + samples_processed_ += wdata.Dim(); - if (decoder_->EndpointDetected(model_->endpoint_config_)) { - return true; - } + if (spk_feature_) { + spk_feature_->AcceptWaveform(sample_frequency_, wdata); + } - return false; + if (decoder_->EndpointDetected(model_->endpoint_config_)) { + return true; + } + + return false; } // Computes an xvector from a chunk of speech features. static void RunNnetComputation(const MatrixBase &features, - const nnet3::Nnet &nnet, nnet3::CachingOptimizingCompiler *compiler, - Vector *xvector) -{ - nnet3::ComputationRequest request; - request.need_model_derivative = false; - request.store_component_stats = false; - request.inputs.push_back( - nnet3::IoSpecification("input", 0, features.NumRows())); - nnet3::IoSpecification output_spec; - output_spec.name = "output"; - output_spec.has_deriv = false; - output_spec.indexes.resize(1); - request.outputs.resize(1); - request.outputs[0].Swap(&output_spec); - shared_ptr computation = compiler->Compile(request); - nnet3::Nnet *nnet_to_update = nullptr; // we're not doing any update. - nnet3::NnetComputer computer(nnet3::NnetComputeOptions(), *computation, - nnet, nnet_to_update); - CuMatrix input_feats_cu(features); - computer.AcceptInput("input", &input_feats_cu); - computer.Run(); - CuMatrix cu_output; - computer.GetOutputDestructive("output", &cu_output); - xvector->Resize(cu_output.NumCols()); - xvector->CopyFromVec(cu_output.Row(0)); + const nnet3::Nnet &nnet, + nnet3::CachingOptimizingCompiler *compiler, + Vector *xvector) { + nnet3::ComputationRequest request; + request.need_model_derivative = false; + request.store_component_stats = false; + request.inputs.push_back( + nnet3::IoSpecification("input", 0, features.NumRows())); + nnet3::IoSpecification output_spec; + output_spec.name = "output"; + output_spec.has_deriv = false; + output_spec.indexes.resize(1); + request.outputs.resize(1); + request.outputs[0].Swap(&output_spec); + shared_ptr computation = + compiler->Compile(request); + nnet3::Nnet *nnet_to_update = nullptr; // we're not doing any update. + nnet3::NnetComputer computer(nnet3::NnetComputeOptions(), *computation, nnet, + nnet_to_update); + CuMatrix input_feats_cu(features); + computer.AcceptInput("input", &input_feats_cu); + computer.Run(); + CuMatrix cu_output; + computer.GetOutputDestructive("output", &cu_output); + xvector->Resize(cu_output.NumCols()); + xvector->CopyFromVec(cu_output.Row(0)); } #define MIN_SPK_FEATS 50 -bool Recognizer::GetSpkVector(Vector &out_xvector, int *num_spk_frames) -{ - vector nonsilence_frames; - if (silence_weighting_->Active() && feature_pipeline_->NumFramesReady() > 0) { - silence_weighting_->ComputeCurrentTraceback(decoder_->Decoder(), true); - silence_weighting_->GetNonsilenceFrames(feature_pipeline_->NumFramesReady(), - frame_offset_ * 3, - &nonsilence_frames); - } - - int num_frames = spk_feature_->NumFramesReady() - frame_offset_ * 3; - Matrix mfcc(num_frames, spk_feature_->Dim()); +bool Recognizer::GetSpkVector(Vector &out_xvector, + int *num_spk_frames) { + vector nonsilence_frames; + if (silence_weighting_->Active() && feature_pipeline_->NumFramesReady() > 0) { + silence_weighting_->ComputeCurrentTraceback(decoder_->Decoder(), true); + silence_weighting_->GetNonsilenceFrames(feature_pipeline_->NumFramesReady(), + frame_offset_ * 3, + &nonsilence_frames); + } - // Not very efficient, would be nice to have faster search - int num_nonsilence_frames = 0; - Vector feat(spk_feature_->Dim()); + int num_frames = spk_feature_->NumFramesReady() - frame_offset_ * 3; + Matrix mfcc(num_frames, spk_feature_->Dim()); - for (int i = 0; i < num_frames; ++i) { - if (std::find(nonsilence_frames.begin(), - nonsilence_frames.end(), i / 3) == nonsilence_frames.end()) { - continue; - } + // Not very efficient, would be nice to have faster search + int num_nonsilence_frames = 0; + Vector feat(spk_feature_->Dim()); - spk_feature_->GetFrame(i + frame_offset_ * 3, &feat); - mfcc.CopyRowFromVec(feat, num_nonsilence_frames); - num_nonsilence_frames++; + for (int i = 0; i < num_frames; ++i) { + if (std::find(nonsilence_frames.begin(), nonsilence_frames.end(), i / 3) == + nonsilence_frames.end()) { + continue; } - *num_spk_frames = num_nonsilence_frames; + spk_feature_->GetFrame(i + frame_offset_ * 3, &feat); + mfcc.CopyRowFromVec(feat, num_nonsilence_frames); + num_nonsilence_frames++; + } - // Don't extract vector if not enough data - if (num_nonsilence_frames < MIN_SPK_FEATS) { - return false; - } + *num_spk_frames = num_nonsilence_frames; + + // Don't extract vector if not enough data + if (num_nonsilence_frames < MIN_SPK_FEATS) { + return false; + } - mfcc.Resize(num_nonsilence_frames, spk_feature_->Dim(), kCopyData); + mfcc.Resize(num_nonsilence_frames, spk_feature_->Dim(), kCopyData); - SlidingWindowCmnOptions cmvn_opts; - cmvn_opts.center = true; - cmvn_opts.cmn_window = 300; - Matrix features(mfcc.NumRows(), mfcc.NumCols(), kUndefined); - SlidingWindowCmn(cmvn_opts, mfcc, &features); + SlidingWindowCmnOptions cmvn_opts; + cmvn_opts.center = true; + cmvn_opts.cmn_window = 300; + Matrix features(mfcc.NumRows(), mfcc.NumCols(), kUndefined); + SlidingWindowCmn(cmvn_opts, mfcc, &features); - nnet3::NnetSimpleComputationOptions opts; - nnet3::CachingOptimizingCompilerOptions compiler_config; - nnet3::CachingOptimizingCompiler compiler(spk_model_->speaker_nnet, opts.optimize_config, compiler_config); + nnet3::NnetSimpleComputationOptions opts; + nnet3::CachingOptimizingCompilerOptions compiler_config; + nnet3::CachingOptimizingCompiler compiler( + spk_model_->speaker_nnet, opts.optimize_config, compiler_config); - Vector xvector; - RunNnetComputation(features, spk_model_->speaker_nnet, &compiler, &xvector); + Vector xvector; + RunNnetComputation(features, spk_model_->speaker_nnet, &compiler, &xvector); - // Whiten the vector with global mean and transform and normalize mean - xvector.AddVec(-1.0, spk_model_->mean); + // Whiten the vector with global mean and transform and normalize mean + xvector.AddVec(-1.0, spk_model_->mean); - out_xvector.Resize(spk_model_->transform.NumRows(), kSetZero); - out_xvector.AddMatVec(1.0, spk_model_->transform, kNoTrans, xvector, 0.0); + out_xvector.Resize(spk_model_->transform.NumRows(), kSetZero); + out_xvector.AddMatVec(1.0, spk_model_->transform, kNoTrans, xvector, 0.0); - BaseFloat norm = out_xvector.Norm(2.0); - BaseFloat ratio = norm / sqrt(out_xvector.Dim()); // how much larger it is - // than it would be, in - // expectation, if normally - out_xvector.Scale(1.0 / ratio); + BaseFloat norm = out_xvector.Norm(2.0); + BaseFloat ratio = norm / sqrt(out_xvector.Dim()); // how much larger it is + // than it would be, in + // expectation, if normally + out_xvector.Scale(1.0 / ratio); - return true; + return true; } // If we can't align, we still need to prepare for MBR -static void CopyLatticeForMbr(CompactLattice &lat, CompactLattice *lat_out) -{ - *lat_out = lat; - RmEpsilon(lat_out, true); - fst::CreateSuperFinal(lat_out); - TopSortCompactLatticeIfNeeded(lat_out); +static void CopyLatticeForMbr(CompactLattice &lat, CompactLattice *lat_out) { + *lat_out = lat; + RmEpsilon(lat_out, true); + fst::CreateSuperFinal(lat_out); + TopSortCompactLatticeIfNeeded(lat_out); } -const char *Recognizer::MbrResult(CompactLattice &rlat) -{ +const char *Recognizer::MbrResult(CompactLattice &rlat) { - CompactLattice aligned_lat; - if (model_->winfo_) { - WordAlignLattice(rlat, *model_->trans_model_, *model_->winfo_, 0, &aligned_lat); - } else { - CopyLatticeForMbr(rlat, &aligned_lat); - } + CompactLattice aligned_lat; + if (model_->winfo_) { + WordAlignLattice(rlat, *model_->trans_model_, *model_->winfo_, 0, + &aligned_lat); + } else { + CopyLatticeForMbr(rlat, &aligned_lat); + } - MinimumBayesRisk mbr(aligned_lat); - const vector &conf = mbr.GetOneBestConfidences(); - const vector &words = mbr.GetOneBest(); - const vector > × = - mbr.GetOneBestTimes(); + MinimumBayesRisk mbr(aligned_lat); + const vector &conf = mbr.GetOneBestConfidences(); + const vector &words = mbr.GetOneBest(); + const vector> × = mbr.GetOneBestTimes(); - int size = words.size(); + int size = words.size(); - json::JSON obj; - stringstream text; + json::JSON obj; + stringstream text; - // Create JSON object - for (int i = 0; i < size; i++) { - json::JSON word; - - if (words_) { - word["word"] = model_->word_syms_->Find(words[i]); - word["start"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].first) * 0.03; - word["end"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].second) * 0.03; - word["conf"] = conf[i]; - obj["result"].append(word); - } - - if (i) { - text << " "; - } - text << model_->word_syms_->Find(words[i]); + // Create JSON object + for (int i = 0; i < size; i++) { + json::JSON word; + + if (words_) { + word["word"] = model_->word_syms_->Find(words[i]); + word["start"] = samples_round_start_ / sample_frequency_ + + (frame_offset_ + times[i].first) * 0.03; + word["end"] = samples_round_start_ / sample_frequency_ + + (frame_offset_ + times[i].second) * 0.03; + word["conf"] = conf[i]; + obj["result"].append(word); } - obj["text"] = text.str(); - if (spk_model_) { - Vector xvector; - int num_spk_frames; - if (GetSpkVector(xvector, &num_spk_frames)) { - for (int i = 0; i < xvector.Dim(); i++) { - obj["spk"].append(xvector(i)); - } - obj["spk_frames"] = num_spk_frames; - } + if (i) { + text << " "; } + text << model_->word_syms_->Find(words[i]); + } + obj["text"] = text.str(); - return StoreReturn(obj.dump()); + if (spk_model_) { + Vector xvector; + int num_spk_frames; + if (GetSpkVector(xvector, &num_spk_frames)) { + for (int i = 0; i < xvector.Dim(); i++) { + obj["spk"].append(xvector(i)); + } + obj["spk_frames"] = num_spk_frames; + } + } + + return StoreReturn(obj.dump()); } -static bool CompactLatticeToWordAlignmentWeight(const CompactLattice &clat, - std::vector *words, - std::vector *begin_times, - std::vector *lengths, - CompactLattice::Weight *tot_weight_out) -{ +static bool CompactLatticeToWordAlignmentWeight( + const CompactLattice &clat, std::vector *words, + std::vector *begin_times, std::vector *lengths, + CompactLattice::Weight *tot_weight_out) { typedef CompactLattice::Arc Arc; typedef Arc::Label Label; typedef CompactLattice::StateId StateId; @@ -568,7 +589,7 @@ static bool CompactLatticeToWordAlignmentWeight(const CompactLattice &clat, } if (!final.String().empty()) { KALDI_WARN << "Lattice has alignments on final-weight: probably " - "was not word-aligned (alignments will be approximate)"; + "was not word-aligned (alignments will be approximate)"; } tot_weight = Times(final, tot_weight); *tot_weight_out = tot_weight; @@ -593,337 +614,596 @@ static bool CompactLatticeToWordAlignmentWeight(const CompactLattice &clat, } } +const char *Recognizer::NbestResult(CompactLattice &clat) { + Lattice lat; + Lattice nbest_lat; + std::vector nbest_lats; -const char *Recognizer::NbestResult(CompactLattice &clat) -{ - Lattice lat; - Lattice nbest_lat; - std::vector nbest_lats; + ConvertLattice(clat, &lat); + fst::ShortestPath(lat, &nbest_lat, max_alternatives_); + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); - ConvertLattice (clat, &lat); - fst::ShortestPath(lat, &nbest_lat, max_alternatives_); - fst::ConvertNbestToVector(nbest_lat, &nbest_lats); + json::JSON obj; + for (int k = 0; k < nbest_lats.size(); k++) { - json::JSON obj; - for (int k = 0; k < nbest_lats.size(); k++) { + Lattice nlat = nbest_lats[k]; - Lattice nlat = nbest_lats[k]; + CompactLattice nclat; + fst::Invert(&nlat); + DeterminizeLattice(nlat, &nclat); - CompactLattice nclat; - fst::Invert(&nlat); - DeterminizeLattice(nlat, &nclat); + CompactLattice aligned_nclat; + if (model_->winfo_) { + WordAlignLattice(nclat, *model_->trans_model_, *model_->winfo_, 0, + &aligned_nclat); + } else { + aligned_nclat = nclat; + } - CompactLattice aligned_nclat; - if (model_->winfo_) { - WordAlignLattice(nclat, *model_->trans_model_, *model_->winfo_, 0, &aligned_nclat); - } else { - aligned_nclat = nclat; - } + std::vector words; + std::vector begin_times; + std::vector lengths; + CompactLattice::Weight weight; + + CompactLatticeToWordAlignmentWeight(aligned_nclat, &words, &begin_times, + &lengths, &weight); + float likelihood = -(weight.Weight().Value1() + weight.Weight().Value2()); - std::vector words; - std::vector begin_times; - std::vector lengths; - CompactLattice::Weight weight; - - CompactLatticeToWordAlignmentWeight(aligned_nclat, &words, &begin_times, &lengths, &weight); - float likelihood = -(weight.Weight().Value1() + weight.Weight().Value2()); - - stringstream text; - json::JSON entry; - - for (int i = 0, first = 1; i < words.size(); i++) { - json::JSON word; - if (words[i] == 0) - continue; - if (words_) { - word["word"] = model_->word_syms_->Find(words[i]); - word["start"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + begin_times[i]) * 0.03; - word["end"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + begin_times[i] + lengths[i]) * 0.03; - entry["result"].append(word); - } - - if (first) - first = 0; - else - text << " "; - - text << model_->word_syms_->Find(words[i]); + stringstream text; + json::JSON entry; + + for (int i = 0, first = 1; i < words.size(); i++) { + json::JSON word; + if (words[i] == 0) + continue; + if (words_) { + word["word"] = model_->word_syms_->Find(words[i]); + word["start"] = samples_round_start_ / sample_frequency_ + + (frame_offset_ + begin_times[i]) * 0.03; + word["end"] = samples_round_start_ / sample_frequency_ + + (frame_offset_ + begin_times[i] + lengths[i]) * 0.03; + entry["result"].append(word); } - entry["text"] = text.str(); - entry["confidence"]= likelihood; - obj["alternatives"].append(entry); + if (first) + first = 0; + else + text << " "; + + text << model_->word_syms_->Find(words[i]); } - return StoreReturn(obj.dump()); + entry["text"] = text.str(); + entry["confidence"] = likelihood; + obj["alternatives"].append(entry); + } + + return StoreReturn(obj.dump()); } -const char *Recognizer::NlsmlResult(CompactLattice &clat) -{ - Lattice lat; - Lattice nbest_lat; - std::vector nbest_lats; +const char *Recognizer::NlsmlResult(CompactLattice &clat) { + Lattice lat; + Lattice nbest_lat; + std::vector nbest_lats; - ConvertLattice (clat, &lat); - fst::ShortestPath(lat, &nbest_lat, max_alternatives_); - fst::ConvertNbestToVector(nbest_lat, &nbest_lats); + ConvertLattice(clat, &lat); + fst::ShortestPath(lat, &nbest_lat, max_alternatives_); + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); - std::stringstream ss; - ss << "\n"; - ss << "\n"; + std::stringstream ss; + ss << "\n"; + ss << "\n"; - for (int k = 0; k < nbest_lats.size(); k++) { + for (int k = 0; k < nbest_lats.size(); k++) { - Lattice nlat = nbest_lats[k]; + Lattice nlat = nbest_lats[k]; - CompactLattice nclat; - fst::Invert(&nlat); - DeterminizeLattice(nlat, &nclat); + CompactLattice nclat; + fst::Invert(&nlat); + DeterminizeLattice(nlat, &nclat); - CompactLattice aligned_nclat; - if (model_->winfo_) { - WordAlignLattice(nclat, *model_->trans_model_, *model_->winfo_, 0, &aligned_nclat); - } else { - aligned_nclat = nclat; - } + CompactLattice aligned_nclat; + if (model_->winfo_) { + WordAlignLattice(nclat, *model_->trans_model_, *model_->winfo_, 0, + &aligned_nclat); + } else { + aligned_nclat = nclat; + } - std::vector words; - std::vector begin_times; - std::vector lengths; - CompactLattice::Weight weight; + std::vector words; + std::vector begin_times; + std::vector lengths; + CompactLattice::Weight weight; - CompactLatticeToWordAlignmentWeight(aligned_nclat, &words, &begin_times, &lengths, &weight); - float likelihood = -(weight.Weight().Value1() + weight.Weight().Value2()); + CompactLatticeToWordAlignmentWeight(aligned_nclat, &words, &begin_times, + &lengths, &weight); + float likelihood = -(weight.Weight().Value1() + weight.Weight().Value2()); - stringstream text; - for (int i = 0, first = 1; i < words.size(); i++) { - if (words[i] == 0) - continue; + stringstream text; + for (int i = 0, first = 1; i < words.size(); i++) { + if (words[i] == 0) + continue; - if (first) - first = 0; - else - text << " "; + if (first) + first = 0; + else + text << " "; - text << model_->word_syms_->Find(words[i]); - } + text << model_->word_syms_->Find(words[i]); + } + + ss << "\n"; + ss << "" << text.str() << "\n"; + ss << "" << text.str() << "\n"; + ss << "\n"; + } + ss << "\n"; + + return StoreReturn(ss.str()); +} + +const char *Recognizer::GetResult() { + if (decoder_->NumFramesDecoded() == 0) { + return StoreEmptyReturn(); + } - ss << "\n"; - ss << "" << text.str() << "\n"; - ss << "" << text.str() << "\n"; - ss << "\n"; + // Original from decoder, subtracted graph weight, rescored with carpa, + // rescored with rnnlm + CompactLattice clat, slat, tlat, rlat; + + clat = decoder_->GetLattice(decoder_->NumFramesDecoded(), true); + + if (lm_to_subtract_ && carpa_to_add_) { + Lattice lat, composed_lat; + + // Delete old score + ConvertLattice(clat, &lat); + fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &lat); + fst::Compose(lat, *lm_to_subtract_, &composed_lat); + fst::Invert(&composed_lat); + DeterminizeLattice(composed_lat, &slat); + fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &slat); + + // Add CARPA score + TopSortCompactLatticeIfNeeded(&slat); + ComposeCompactLatticeDeterministic(slat, carpa_to_add_, &tlat); + + // Rescore with RNNLM score on top if needed + if (rnnlm_to_add_scale_) { + ComposeLatticePrunedOptions compose_opts; + compose_opts.lattice_compose_beam = 3.0; + compose_opts.max_arcs = 3000; + fst::ComposeDeterministicOnDemandFst combined_rnnlm( + carpa_to_add_scale_, rnnlm_to_add_scale_); + + TopSortCompactLatticeIfNeeded(&tlat); + ComposeCompactLatticePruned(compose_opts, tlat, &combined_rnnlm, &rlat); + rnnlm_to_add_->Clear(); + } else { + rlat = tlat; } - ss << "\n"; + } else { + rlat = clat; + } + + // Pruned composition can return empty lattice. It should be rare + if (rlat.Start() != 0) { + return StoreEmptyReturn(); + } - return StoreReturn(ss.str()); + // Apply rescoring weight + fst::ScaleLattice(fst::GraphLatticeScale(0.9), &rlat); + + if (max_alternatives_ == 0) { + return MbrResult(rlat); + } else if (nlsml_) { + return NlsmlResult(rlat); + } else { + return NbestResult(rlat); + } } -const char* Recognizer::GetResult() -{ - if (decoder_->NumFramesDecoded() == 0) { - return StoreEmptyReturn(); +const char *Recognizer::PartialResult() { + if (state_ != RECOGNIZER_RUNNING) { + return StoreEmptyReturn(); + } + + json::JSON res; + + if (partial_words_) { + + if (decoder_->NumFramesInLattice() == 0) { + res["partial"] = ""; + return StoreReturn(res.dump()); } - // Original from decoder, subtracted graph weight, rescored with carpa, rescored with rnnlm - CompactLattice clat, slat, tlat, rlat; - - clat = decoder_->GetLattice(decoder_->NumFramesDecoded(), true); - - if (lm_to_subtract_ && carpa_to_add_) { - Lattice lat, composed_lat; - - // Delete old score - ConvertLattice(clat, &lat); - fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &lat); - fst::Compose(lat, *lm_to_subtract_, &composed_lat); - fst::Invert(&composed_lat); - DeterminizeLattice(composed_lat, &slat); - fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &slat); - - // Add CARPA score - TopSortCompactLatticeIfNeeded(&slat); - ComposeCompactLatticeDeterministic(slat, carpa_to_add_, &tlat); - - // Rescore with RNNLM score on top if needed - if (rnnlm_to_add_scale_) { - ComposeLatticePrunedOptions compose_opts; - compose_opts.lattice_compose_beam = 3.0; - compose_opts.max_arcs = 3000; - fst::ComposeDeterministicOnDemandFst combined_rnnlm(carpa_to_add_scale_, rnnlm_to_add_scale_); - - TopSortCompactLatticeIfNeeded(&tlat); - ComposeCompactLatticePruned(compose_opts, tlat, - &combined_rnnlm, &rlat); - rnnlm_to_add_->Clear(); - } else { - rlat = tlat; - } + CompactLattice clat; + CompactLattice aligned_lat; + + clat = decoder_->GetLattice(decoder_->NumFramesInLattice(), false); + if (model_->winfo_) { + WordAlignLatticePartial(clat, *model_->trans_model_, *model_->winfo_, 0, + &aligned_lat); } else { - rlat = clat; + CopyLatticeForMbr(clat, &aligned_lat); } - // Pruned composition can return empty lattice. It should be rare - if (rlat.Start() != 0) { - return StoreEmptyReturn(); + MinimumBayesRisk mbr(aligned_lat); + const vector &conf = mbr.GetOneBestConfidences(); + const vector &words = mbr.GetOneBest(); + const vector> × = mbr.GetOneBestTimes(); + + int size = words.size(); + + stringstream text; + + // Create JSON object + for (int i = 0; i < size; i++) { + json::JSON word; + + word["word"] = model_->word_syms_->Find(words[i]); + word["start"] = samples_round_start_ / sample_frequency_ + + (frame_offset_ + times[i].first) * 0.03; + word["end"] = samples_round_start_ / sample_frequency_ + + (frame_offset_ + times[i].second) * 0.03; + word["conf"] = conf[i]; + res["partial_result"].append(word); + + if (i) { + text << " "; + } + text << model_->word_syms_->Find(words[i]); } + res["partial"] = text.str(); - // Apply rescoring weight - fst::ScaleLattice(fst::GraphLatticeScale(0.9), &rlat); + } else { - if (max_alternatives_ == 0) { - return MbrResult(rlat); - } else if (nlsml_) { - return NlsmlResult(rlat); - } else { - return NbestResult(rlat); + if (decoder_->NumFramesDecoded() == 0) { + res["partial"] = ""; + return StoreReturn(res.dump()); + } + Lattice lat; + decoder_->GetBestPath(false, &lat); + vector alignment, words; + LatticeWeight weight; + GetLinearSymbolSequence(lat, &alignment, &words, &weight); + + ostringstream text; + for (size_t i = 0; i < words.size(); i++) { + if (i) { + text << " "; + } + text << model_->word_syms_->Find(words[i]); } + res["partial"] = text.str(); + } + return StoreReturn(res.dump()); } +const char *Recognizer::Result() { + if (state_ != RECOGNIZER_RUNNING) { + return StoreEmptyReturn(); + } + decoder_->FinalizeDecoding(); + state_ = RECOGNIZER_ENDPOINT; + return GetResult(); +} -const char* Recognizer::PartialResult() -{ - if (state_ != RECOGNIZER_RUNNING) { - return StoreEmptyReturn(); - } +const char *Recognizer::FinalResult() { + if (state_ != RECOGNIZER_RUNNING) { + return StoreEmptyReturn(); + } - json::JSON res; + feature_pipeline_->InputFinished(); + UpdateSilenceWeights(); + decoder_->AdvanceDecoding(); + decoder_->FinalizeDecoding(); + state_ = RECOGNIZER_FINALIZED; + GetResult(); + + // Free some memory while we are finalized, next + // iteration will reinitialize them anyway + delete decoder_; + delete feature_pipeline_; + delete silence_weighting_; + delete spk_feature_; + + feature_pipeline_ = nullptr; + silence_weighting_ = nullptr; + decoder_ = nullptr; + spk_feature_ = nullptr; + + return last_result_.c_str(); +} - if (partial_words_) { +void Recognizer::Reset() { + if (state_ == RECOGNIZER_RUNNING) { + decoder_->FinalizeDecoding(); + } + StoreEmptyReturn(); + state_ = RECOGNIZER_ENDPOINT; +} - if (decoder_->NumFramesInLattice() == 0) { - res["partial"] = ""; - return StoreReturn(res.dump()); - } +const char *Recognizer::StoreEmptyReturn() { + if (!max_alternatives_) { + return StoreReturn("{\"text\": \"\"}"); + } else if (nlsml_) { + return StoreReturn("\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n"); + } else { + return StoreReturn( + "{\"alternatives\" : [{\"text\": \"\", \"confidence\" : 1.0}] }"); + } +} - CompactLattice clat; - CompactLattice aligned_lat; +// Store result in recognizer and return as const string +const char *Recognizer::StoreReturn(const string &res) { + last_result_ = res; + return last_result_.c_str(); +} - clat = decoder_->GetLattice(decoder_->NumFramesInLattice(), false); - if (model_->winfo_) { - WordAlignLatticePartial(clat, *model_->trans_model_, *model_->winfo_, 0, &aligned_lat); - } else { - CopyLatticeForMbr(clat, &aligned_lat); - } +void Recognizer::RebuildLexicon(std::vector &words, + std::vector &pronunciations) { + using namespace fst; + using namespace std; + using StateId = StdVectorFst::StateId; + using Weight = StdArc::Weight; + using Label = StdArc::Label; + + if (words.size() != pronunciations.size()) { + KALDI_ERR << "Number of words and pronunciations must be equal"; + return; + } - MinimumBayesRisk mbr(aligned_lat); - const vector &conf = mbr.GetOneBestConfidences(); - const vector &words = mbr.GetOneBest(); - const vector > × = mbr.GetOneBestTimes(); + if (state_ == RECOGNIZER_RUNNING) { + KALDI_ERR << "Can't add speaker model to already running recognizer"; + return; + } - int size = words.size(); + if (model_->ctx_dep_ == nullptr) { + KALDI_ERR << "Can't rebuild lexicon without phone symbols and ctx dep tree"; + return; + } - stringstream text; + // Maybe make this adjustable?: + + string silence_phone = "SIL"; + // At the beginning of sentence and after each word, we output silence with + // probability 0.5; + // the probability mass assigned to having no silence is 1.0 - 0.5 = 0.5. + float silence_prob = 0.5; + // In mkgraph.sh = 1.0, in training = 0.1, in compile-graph.cc = 0.1 + float self_loop_scale = 1.0; + // In our current training scripts, this scale is 1.0. This scale only affects + // the parts of the transitions that do not relate to self-loop probabilities, + // and in the normal topology (Bakis model) it has no effect at all + float transition_scale = 1.0; + + Label silence_phone_id = model_->phone_syms_->Find(silence_phone); + if (silence_phone_id == kNoSymbol) { + KALDI_ERR << "Silence phone not found in the phone symbol table"; + return; + } - // Create JSON object - for (int i = 0; i < size; i++) { - json::JSON word; + // Create a new word symbol table for the new words + SymbolTable word_syms("words"); - word["word"] = model_->word_syms_->Find(words[i]); - word["start"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].first) * 0.03; - word["end"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].second) * 0.03; - word["conf"] = conf[i]; - res["partial_result"].append(word); + VectorFst l_fst; + StateId start_state = l_fst.AddState(); + StateId loop_state = l_fst.AddState(); + StateId silence_state = l_fst.AddState(); + l_fst.SetStart(start_state); - if (i) { - text << " "; - } - text << model_->word_syms_->Find(words[i]); - } - res["partial"] = text.str(); + // Add transitions + float nosil_cost = -log(1.0 - silence_prob); + float sil_cost = -log(silence_prob); + l_fst.AddArc(start_state, StdArc(0, 0, Weight(nosil_cost), loop_state)); + l_fst.AddArc(start_state, + StdArc(silence_phone_id, 0, Weight(sil_cost), silence_state)); + l_fst.AddArc(silence_state, + StdArc(silence_phone_id, 0, Weight::One(), loop_state)); + + l_fst.SetFinal(loop_state, Weight::One()); + + // Insert the epsilon symbol at the begining of words and pronunciations + // In the loop we skip any further ` SIL` pairs + words.insert(words.begin(), ""); + pronunciations.insert(pronunciations.begin(), silence_phone); + + // Add a map to store existing pronunciations + SymbolTable disambiguation_syms("disambiguation"); + unordered_map last_disambiguation_symbol; + + for (size_t i = 0; i < words.size(); ++i) { + const string &word = words[i]; + const string &pronunciation = pronunciations[i]; + + // Skip any manually added epsion entries + if (i != 0 && word == "" && pronunciation == silence_phone) { + continue; + } + + if (word.empty() || pronunciation.empty()) { + KALDI_WARN << "Skipping word with empty word or pronunciation in line " + << i + 1; + continue; + } + + Label word_id = word_syms.AddSymbol(word); + + Label disambiguation_symbol = kNoLabel; + // Check if pronunciation exists in the map + if (last_disambiguation_symbol.find(pronunciation) != + last_disambiguation_symbol.end()) { + // Increment the disambiguation symbol counter + disambiguation_symbol = last_disambiguation_symbol[pronunciation]; + int64 new_disambiguation_number = disambiguation_symbol + 1; + disambiguation_symbol = disambiguation_syms.AddSymbol( + "#" + to_string(new_disambiguation_number)); + last_disambiguation_symbol[pronunciation] = new_disambiguation_number; } else { + // Add the pronunciation to the map + last_disambiguation_symbol[pronunciation] = -1; + } - if (decoder_->NumFramesDecoded() == 0) { - res["partial"] = ""; - return StoreReturn(res.dump()); - } - Lattice lat; - decoder_->GetBestPath(false, &lat); - vector alignment, words; - LatticeWeight weight; - GetLinearSymbolSequence(lat, &alignment, &words, &weight); - - ostringstream text; - for (size_t i = 0; i < words.size(); i++) { - if (i) { - text << " "; - } - text << model_->word_syms_->Find(words[i]); - } - res["partial"] = text.str(); + istringstream iss(pronunciation); + string phone; + StateId current_state = loop_state; + bool first_phone = true; + while (iss >> phone) { + Label phone_id = model_->phone_syms_->Find(phone); + if (phone_id == kNoSymbol) { + KALDI_WARN << "Ignoring phone missing in vocabulary: '" << phone << "'"; + continue; + } + + StateId next_state_temp = l_fst.AddState(); + Label olabel = first_phone ? word_id : 0; + + if (first_phone && disambiguation_symbol != kNoLabel) { + current_state = next_state_temp; + next_state_temp = l_fst.AddState(); + l_fst.AddArc(current_state, StdArc(0, disambiguation_symbol, + Weight::One(), next_state_temp)); + } + + l_fst.AddArc(current_state, + StdArc(phone_id, olabel, Weight::One(), next_state_temp)); + current_state = next_state_temp; + first_phone = false; } - return StoreReturn(res.dump()); -} + if (current_state != loop_state) { + if (silence_phone_id != model_->phone_syms_->Find(pronunciation)) { + l_fst.AddArc(current_state, + StdArc(0, 0, Weight(nosil_cost), loop_state)); + l_fst.AddArc(current_state, StdArc(silence_phone_id, 0, + Weight(sil_cost), silence_state)); + } else { + l_fst.AddArc(current_state, StdArc(0, 0, Weight::One(), loop_state)); + } + } + } + + DeterminizeStarInLog(&l_fst); + ArcSort(&l_fst, StdILabelCompare()); -const char* Recognizer::Result() -{ - if (state_ != RECOGNIZER_RUNNING) { - return StoreEmptyReturn(); + // Extract phone disambiguation symbols + // by looking for symbols starting with '#' + vector disambig_syms; + for (int i = 0; i < model_->phone_syms_->NumSymbols(); ++i) { + const string &symbol = model_->phone_syms_->Find(i); + if (!symbol.empty() && symbol[0] == '#') { + disambig_syms.push_back(i); } - decoder_->FinalizeDecoding(); - state_ = RECOGNIZER_ENDPOINT; - return GetResult(); -} + } -const char* Recognizer::FinalResult() -{ - if (state_ != RECOGNIZER_RUNNING) { - return StoreEmptyReturn(); + int32 context_width = model_->ctx_dep_->ContextWidth(); + int32 central_position = model_->ctx_dep_->CentralPosition(); + + vector> ilabels; + // TODO: Add nonterm stuff + VectorFst cl_fst; + ComposeContext(disambig_syms, context_width, central_position, &l_fst, + &cl_fst, &ilabels); + ArcSort(&cl_fst, StdILabelCompare()); + + // Create H transducer + HTransducerConfig h_cfg; + h_cfg.transition_scale = transition_scale; + // Must be >= 0 for grammar fst + h_cfg.nonterm_phones_offset = -1; + // disambiguation symbols on the input side of H + vector *disambig_syms_h = new vector(); + VectorFst *h_fst = + GetHTransducer(ilabels, *model_->ctx_dep_, *model_->trans_model_, h_cfg, + disambig_syms_h); + + ArcSort(h_fst, StdOLabelCompare()); + + // Compose HCL transducer + VectorFst composed_fst; + // TableCompose(*h_fst, cl_fst, &composed_fst); + Compose(*h_fst, cl_fst, &composed_fst); + delete h_fst; + + // Epsilon-removal and determinization combined. + // This will fail if not determinizable. + DeterminizeStarInLog(&composed_fst); + + if (!disambig_syms_h->empty()) { + RemoveSomeInputSymbols(*disambig_syms_h, &composed_fst); + RemoveEpsLocal(&composed_fst); + } + + bool check_no_self_loops = true, reorder = true; + AddSelfLoops(*model_->trans_model_, *disambig_syms_h, self_loop_scale, + reorder, check_no_self_loops, &composed_fst); + + ArcSort(&composed_fst, StdOLabelCompare()); + + // Create the olabel lookahead matcher + vector> relabel; + StdOLabelLookAheadFst lcomposed_fst(composed_fst); + + // Get the relabel pairs + LabelLookAheadRelabeler::RelabelPairs(lcomposed_fst, &relabel); + + // Print the relabel pairs + SymbolTable *relabeled_word_syms = new SymbolTable("words"); + // Go through word_syms_ and relabel the words + for (int i = 0; i < word_syms.NumSymbols(); ++i) { + string word = word_syms.Find(i); + // Check if the word is in the relabel map + Label wid = i; + for (const auto &pair : relabel) { + if (pair.first == i) { + wid = pair.second; + break; + } } + relabeled_word_syms->AddSymbol(word, wid); + } - feature_pipeline_->InputFinished(); - UpdateSilenceWeights(); - decoder_->AdvanceDecoding(); - decoder_->FinalizeDecoding(); - state_ = RECOGNIZER_FINALIZED; - GetResult(); + // Switch HCLr, word_syms_ and disambig_ with new variables + delete hcl_fst_; + hcl_fst_ = lcomposed_fst.Copy(false); - // Free some memory while we are finalized, next - // iteration will reinitialize them anyway - delete decoder_; - delete feature_pipeline_; - delete silence_weighting_; - delete spk_feature_; + delete word_syms_; + word_syms_ = relabeled_word_syms; - feature_pipeline_ = nullptr; - silence_weighting_ = nullptr; - decoder_ = nullptr; - spk_feature_ = nullptr; + delete disambig_; + disambig_ = disambig_syms_h; +} - return last_result_.c_str(); +string Recognizer::FindWord(int64 word_id) { + string word = word_syms_ ? word_syms_->Find(word_id) + : model_->word_syms_->Find(word_id); + return word; } -void Recognizer::Reset() -{ - if (state_ == RECOGNIZER_RUNNING) { - decoder_->FinalizeDecoding(); - } - StoreEmptyReturn(); - state_ = RECOGNIZER_ENDPOINT; -} - -const char *Recognizer::StoreEmptyReturn() -{ - if (!max_alternatives_) { - return StoreReturn("{\"text\": \"\"}"); - } else if (nlsml_) { - return StoreReturn("\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "\n"); - } else { - return StoreReturn("{\"alternatives\" : [{\"text\": \"\", \"confidence\" : 1.0}] }"); - } +int64 Recognizer::FindWordId(const string &word) { + return word_syms_ ? word_syms_->Find(word) : model_->word_syms_->Find(word); } -// Store result in recognizer and return as const string -const char *Recognizer::StoreReturn(const string &res) -{ - last_result_ = res; - return last_result_.c_str(); +fst::Fst *Recognizer::GetHclFst() { + if (hcl_fst_ == nullptr) { + return model_->hcl_fst_; + } + return hcl_fst_; } + +std::vector *Recognizer::GetDisambig() { + if (disambig_ == nullptr) { + return &model_->disambig_; + } + return disambig_; +} \ No newline at end of file diff --git a/src/recognizer.h b/src/recognizer.h index 6fa26710..b4db048f 100644 --- a/src/recognizer.h +++ b/src/recognizer.h @@ -15,18 +15,21 @@ #ifndef VOSK_KALDI_RECOGNIZER_H #define VOSK_KALDI_RECOGNIZER_H +#include + #include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "fstext/fstext-lib.h" -#include "fstext/fstext-utils.h" #include "decoder/lattice-faster-decoder.h" #include "feat/feature-mfcc.h" +#include "fstext/fstext-lib.h" +#include "fstext/fstext-utils.h" +#include "lat/compose-lattice-pruned.h" #include "lat/kaldi-lattice.h" +#include "lat/lattice-functions-transition-model.h" #include "lat/word-align-lattice.h" -#include "lat/compose-lattice-pruned.h" #include "nnet3/am-nnet-simple.h" #include "nnet3/nnet-am-decodable-simple.h" #include "nnet3/nnet-utils.h" +#include "util/common-utils.h" #include "model.h" #include "spk_model.h" @@ -34,82 +37,96 @@ using namespace kaldi; enum RecognizerState { - RECOGNIZER_INITIALIZED, - RECOGNIZER_RUNNING, - RECOGNIZER_ENDPOINT, - RECOGNIZER_FINALIZED + RECOGNIZER_INITIALIZED, + RECOGNIZER_RUNNING, + RECOGNIZER_ENDPOINT, + RECOGNIZER_FINALIZED }; class Recognizer { - public: - Recognizer(Model *model, float sample_frequency); - Recognizer(Model *model, float sample_frequency, SpkModel *spk_model); - Recognizer(Model *model, float sample_frequency, char const *grammar); - ~Recognizer(); - void SetMaxAlternatives(int max_alternatives); - void SetSpkModel(SpkModel *spk_model); - void SetGrm(char const *grammar); - void SetWords(bool words); - void SetPartialWords(bool partial_words); - void SetNLSML(bool nlsml); - bool AcceptWaveform(const char *data, int len); - bool AcceptWaveform(const short *sdata, int len); - bool AcceptWaveform(const float *fdata, int len); - const char* Result(); - const char* FinalResult(); - const char* PartialResult(); - void Reset(); - - private: - void InitState(); - void InitRescoring(); - void CleanUp(); - void UpdateSilenceWeights(); - void UpdateGrammarFst(char const *grammar); - bool AcceptWaveform(Vector &wdata); - bool GetSpkVector(Vector &out_xvector, int *frames); - const char *GetResult(); - const char *StoreEmptyReturn(); - const char *StoreReturn(const string &res); - const char *MbrResult(CompactLattice &clat); - const char *NbestResult(CompactLattice &clat); - const char *NlsmlResult(CompactLattice &clat); - - Model *model_ = nullptr; - SingleUtteranceNnet3IncrementalDecoder *decoder_ = nullptr; - fst::LookaheadFst *decode_fst_ = nullptr; - fst::StdVectorFst *g_fst_ = nullptr; // dynamically constructed grammar - OnlineNnet2FeaturePipeline *feature_pipeline_ = nullptr; - OnlineSilenceWeighting *silence_weighting_ = nullptr; - - // Speaker identification - SpkModel *spk_model_ = nullptr; - OnlineBaseFeature *spk_feature_ = nullptr; - - // Rescoring - fst::ArcMapFst > *lm_to_subtract_ = nullptr; - kaldi::ConstArpaLmDeterministicFst *carpa_to_add_ = nullptr; - fst::ScaleDeterministicOnDemandFst *carpa_to_add_scale_ = nullptr; - // RNNLM rescoring - kaldi::rnnlm::KaldiRnnlmDeterministicFst* rnnlm_to_add_ = nullptr; - fst::DeterministicOnDemandFst *rnnlm_to_add_scale_ = nullptr; - kaldi::rnnlm::RnnlmComputeStateInfo *rnnlm_info_ = nullptr; - - - // Other - int max_alternatives_ = 0; // Disable alternatives by default - bool words_ = false; - bool partial_words_ = false; - bool nlsml_ = false; - - float sample_frequency_; - int32 frame_offset_; - - int64 samples_processed_; - int64 samples_round_start_; - - RecognizerState state_; - string last_result_; +public: + Recognizer(Model *model, float sample_frequency); + Recognizer(Model *model, float sample_frequency, SpkModel *spk_model); + Recognizer(Model *model, float sample_frequency, char const *grammar); + ~Recognizer(); + void SetMaxAlternatives(int max_alternatives); + void SetSpkModel(SpkModel *spk_model); + void SetGrm(char const *grammar, const char *const *words, + const char *const *pronunciations, int num_words); + void SetWords(bool words); + void SetPartialWords(bool partial_words); + void SetNLSML(bool nlsml); + bool AcceptWaveform(const char *data, int len); + bool AcceptWaveform(const short *sdata, int len); + bool AcceptWaveform(const float *fdata, int len); + const char *Result(); + const char *FinalResult(); + const char *PartialResult(); + void Reset(); + +private: + void InitState(); + void InitRescoring(); + void CleanUp(); + void UpdateSilenceWeights(); + void UpdateGrammarFst(char const *grammar); + bool AcceptWaveform(Vector &wdata); + bool GetSpkVector(Vector &out_xvector, int *frames); + const char *GetResult(); + const char *StoreEmptyReturn(); + const char *StoreReturn(const string &res); + const char *MbrResult(CompactLattice &clat); + const char *NbestResult(CompactLattice &clat); + const char *NlsmlResult(CompactLattice &clat); + + string FindWord(int64 word_id); + int64 FindWordId(const string &word); + void RebuildLexicon(std::vector &words, + std::vector &pronunciations); + fst::Fst *GetHclFst(); + std::vector *GetDisambig(); + + Model *model_ = nullptr; + SingleUtteranceNnet3IncrementalDecoder *decoder_ = nullptr; + fst::LookaheadFst *decode_fst_ = nullptr; + fst::StdVectorFst *g_fst_ = nullptr; // dynamically constructed grammar + OnlineNnet2FeaturePipeline *feature_pipeline_ = nullptr; + OnlineSilenceWeighting *silence_weighting_ = nullptr; + + // Speaker identification + SpkModel *spk_model_ = nullptr; + OnlineBaseFeature *spk_feature_ = nullptr; + + // Rescoring + fst::ArcMapFst> + *lm_to_subtract_ = nullptr; + kaldi::ConstArpaLmDeterministicFst *carpa_to_add_ = nullptr; + fst::ScaleDeterministicOnDemandFst *carpa_to_add_scale_ = nullptr; + // RNNLM rescoring + kaldi::rnnlm::KaldiRnnlmDeterministicFst *rnnlm_to_add_ = nullptr; + fst::DeterministicOnDemandFst *rnnlm_to_add_scale_ = nullptr; + kaldi::rnnlm::RnnlmComputeStateInfo *rnnlm_info_ = nullptr; + + // Other + int max_alternatives_ = 0; // Disable alternatives by default + bool words_ = false; + bool partial_words_ = false; + bool nlsml_ = false; + + float sample_frequency_; + int32 frame_offset_; + + int64 samples_processed_; + int64 samples_round_start_; + + RecognizerState state_; + string last_result_; + + // To be able to add words to the lexicon on the fly we need + // to create a copy of model_->hcl_fst_ and model_->word_syms_ + fst::Fst *hcl_fst_ = nullptr; + fst::SymbolTable *word_syms_ = nullptr; + std::vector *disambig_ = nullptr; }; #endif /* VOSK_KALDI_RECOGNIZER_H */ diff --git a/src/vosk_api.cc b/src/vosk_api.cc index f146b22c..f17ef26e 100644 --- a/src/vosk_api.cc +++ b/src/vosk_api.cc @@ -14,277 +14,264 @@ #include "vosk_api.h" -#include "recognizer.h" #include "model.h" +#include "recognizer.h" #include "spk_model.h" #if HAVE_CUDA -#include "cudamatrix/cu-device.h" #include "batch_recognizer.h" +#include "cudamatrix/cu-device.h" #endif #include using namespace kaldi; -VoskModel *vosk_model_new(const char *model_path) -{ - try { - return (VoskModel *)new Model(model_path); - } catch (...) { - return nullptr; - } +VoskModel *vosk_model_new(const char *model_path) { + try { + return (VoskModel *)new Model(model_path); + } catch (...) { + return nullptr; + } } -void vosk_model_free(VoskModel *model) -{ - if (model == nullptr) { - return; - } - ((Model *)model)->Unref(); +void vosk_model_free(VoskModel *model) { + if (model == nullptr) { + return; + } + ((Model *)model)->Unref(); } -int vosk_model_find_word(VoskModel *model, const char *word) -{ - return (int) ((Model *)model)->FindWord(word); +int vosk_model_find_word(VoskModel *model, const char *word) { + return (int)((Model *)model)->FindWord(word); } -VoskSpkModel *vosk_spk_model_new(const char *model_path) -{ - try { - return (VoskSpkModel *)new SpkModel(model_path); - } catch (...) { - return nullptr; - } +VoskSpkModel *vosk_spk_model_new(const char *model_path) { + try { + return (VoskSpkModel *)new SpkModel(model_path); + } catch (...) { + return nullptr; + } } -void vosk_spk_model_free(VoskSpkModel *model) -{ - if (model == nullptr) { - return; - } - ((SpkModel *)model)->Unref(); +void vosk_spk_model_free(VoskSpkModel *model) { + if (model == nullptr) { + return; + } + ((SpkModel *)model)->Unref(); } -VoskRecognizer *vosk_recognizer_new(VoskModel *model, float sample_rate) -{ - try { - return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate); - } catch (...) { - return nullptr; - } +VoskRecognizer *vosk_recognizer_new(VoskModel *model, float sample_rate) { + try { + return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate); + } catch (...) { + return nullptr; + } } -VoskRecognizer *vosk_recognizer_new_spk(VoskModel *model, float sample_rate, VoskSpkModel *spk_model) -{ - try { - return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate, (SpkModel *)spk_model); - } catch (...) { - return nullptr; - } +VoskRecognizer *vosk_recognizer_new_spk(VoskModel *model, float sample_rate, + VoskSpkModel *spk_model) { + try { + return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate, + (SpkModel *)spk_model); + } catch (...) { + return nullptr; + } } -VoskRecognizer *vosk_recognizer_new_grm(VoskModel *model, float sample_rate, const char *grammar) -{ - try { - return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate, grammar); - } catch (...) { - return nullptr; - } +VoskRecognizer *vosk_recognizer_new_grm(VoskModel *model, float sample_rate, + const char *grammar) { + try { + return (VoskRecognizer *)new Recognizer((Model *)model, sample_rate, + grammar); + } catch (...) { + return nullptr; + } } -void vosk_recognizer_set_max_alternatives(VoskRecognizer *recognizer, int max_alternatives) -{ - ((Recognizer *)recognizer)->SetMaxAlternatives(max_alternatives); +void vosk_recognizer_set_max_alternatives(VoskRecognizer *recognizer, + int max_alternatives) { + ((Recognizer *)recognizer)->SetMaxAlternatives(max_alternatives); } -void vosk_recognizer_set_words(VoskRecognizer *recognizer, int words) -{ - ((Recognizer *)recognizer)->SetWords((bool)words); +void vosk_recognizer_set_words(VoskRecognizer *recognizer, int words) { + ((Recognizer *)recognizer)->SetWords((bool)words); } -void vosk_recognizer_set_partial_words(VoskRecognizer *recognizer, int partial_words) -{ - ((Recognizer *)recognizer)->SetPartialWords((bool)partial_words); +void vosk_recognizer_set_partial_words(VoskRecognizer *recognizer, + int partial_words) { + ((Recognizer *)recognizer)->SetPartialWords((bool)partial_words); } -void vosk_recognizer_set_nlsml(VoskRecognizer *recognizer, int nlsml) -{ - ((Recognizer *)recognizer)->SetNLSML((bool)nlsml); +void vosk_recognizer_set_nlsml(VoskRecognizer *recognizer, int nlsml) { + ((Recognizer *)recognizer)->SetNLSML((bool)nlsml); } -void vosk_recognizer_set_spk_model(VoskRecognizer *recognizer, VoskSpkModel *spk_model) -{ - if (recognizer == nullptr || spk_model == nullptr) { - return; - } - ((Recognizer *)recognizer)->SetSpkModel((SpkModel *)spk_model); +void vosk_recognizer_set_spk_model(VoskRecognizer *recognizer, + VoskSpkModel *spk_model) { + if (recognizer == nullptr || spk_model == nullptr) { + return; + } + ((Recognizer *)recognizer)->SetSpkModel((SpkModel *)spk_model); } -void vosk_recognizer_set_grm(VoskRecognizer *recognizer, char const *grammar) -{ - if (recognizer == nullptr) { - return; - } - ((Recognizer *)recognizer)->SetGrm(grammar); +void vosk_recognizer_set_grm(VoskRecognizer *recognizer, char const *grammar) { + if (recognizer == nullptr) { + return; + } + ((Recognizer *)recognizer)->SetGrm(grammar, nullptr, nullptr, 0); } -int vosk_recognizer_accept_waveform(VoskRecognizer *recognizer, const char *data, int length) -{ - try { - return ((Recognizer *)(recognizer))->AcceptWaveform(data, length); - } catch (...) { - return -1; - } +void vosk_recognizer_set_grm_with_lexicon(VoskRecognizer *recognizer, + char const *grammar, + const char *const *words, + const char *const *pronunciations, + int num_words) { + if (recognizer == nullptr) { + return; + } + ((Recognizer *)recognizer)->SetGrm(grammar, words, pronunciations, num_words); } -int vosk_recognizer_accept_waveform_s(VoskRecognizer *recognizer, const short *data, int length) -{ - try { - return ((Recognizer *)(recognizer))->AcceptWaveform(data, length); - } catch (...) { - return -1; - } +int vosk_recognizer_accept_waveform(VoskRecognizer *recognizer, + const char *data, int length) { + try { + return ((Recognizer *)(recognizer))->AcceptWaveform(data, length); + } catch (...) { + return -1; + } } -int vosk_recognizer_accept_waveform_f(VoskRecognizer *recognizer, const float *data, int length) -{ - try { - return ((Recognizer *)(recognizer))->AcceptWaveform(data, length); - } catch (...) { - return -1; - } +int vosk_recognizer_accept_waveform_s(VoskRecognizer *recognizer, + const short *data, int length) { + try { + return ((Recognizer *)(recognizer))->AcceptWaveform(data, length); + } catch (...) { + return -1; + } } -const char *vosk_recognizer_result(VoskRecognizer *recognizer) -{ - return ((Recognizer *)recognizer)->Result(); +int vosk_recognizer_accept_waveform_f(VoskRecognizer *recognizer, + const float *data, int length) { + try { + return ((Recognizer *)(recognizer))->AcceptWaveform(data, length); + } catch (...) { + return -1; + } } -const char *vosk_recognizer_partial_result(VoskRecognizer *recognizer) -{ - return ((Recognizer *)recognizer)->PartialResult(); +const char *vosk_recognizer_result(VoskRecognizer *recognizer) { + return ((Recognizer *)recognizer)->Result(); } -const char *vosk_recognizer_final_result(VoskRecognizer *recognizer) -{ - return ((Recognizer *)recognizer)->FinalResult(); +const char *vosk_recognizer_partial_result(VoskRecognizer *recognizer) { + return ((Recognizer *)recognizer)->PartialResult(); } -void vosk_recognizer_reset(VoskRecognizer *recognizer) -{ - ((Recognizer *)recognizer)->Reset(); +const char *vosk_recognizer_final_result(VoskRecognizer *recognizer) { + return ((Recognizer *)recognizer)->FinalResult(); } -void vosk_recognizer_free(VoskRecognizer *recognizer) -{ - delete (Recognizer *)(recognizer); +void vosk_recognizer_reset(VoskRecognizer *recognizer) { + ((Recognizer *)recognizer)->Reset(); } -void vosk_set_log_level(int log_level) -{ - SetVerboseLevel(log_level); +void vosk_recognizer_free(VoskRecognizer *recognizer) { + delete (Recognizer *)(recognizer); } -void vosk_gpu_init() -{ +void vosk_set_log_level(int log_level) { SetVerboseLevel(log_level); } + +void vosk_gpu_init() { #if HAVE_CUDA -// kaldi::CuDevice::EnableTensorCores(true); -// kaldi::CuDevice::EnableTf32Compute(true); - kaldi::CuDevice::Instantiate().SelectGpuId("yes"); - kaldi::CuDevice::Instantiate().AllowMultithreading(); + // kaldi::CuDevice::EnableTensorCores(true); + // kaldi::CuDevice::EnableTf32Compute(true); + kaldi::CuDevice::Instantiate().SelectGpuId("yes"); + kaldi::CuDevice::Instantiate().AllowMultithreading(); #endif } -void vosk_gpu_thread_init() -{ +void vosk_gpu_thread_init() { #if HAVE_CUDA - kaldi::CuDevice::Instantiate(); + kaldi::CuDevice::Instantiate(); #endif } -VoskBatchModel *vosk_batch_model_new(const char *model_path) -{ +VoskBatchModel *vosk_batch_model_new(const char *model_path) { #if HAVE_CUDA - return (VoskBatchModel *)(new BatchModel(model_path)); + return (VoskBatchModel *)(new BatchModel(model_path)); #else - return NULL; + return NULL; #endif } -void vosk_batch_model_free(VoskBatchModel *model) -{ +void vosk_batch_model_free(VoskBatchModel *model) { #if HAVE_CUDA - delete ((BatchModel *)model); + delete ((BatchModel *)model); #endif } -void vosk_batch_model_wait(VoskBatchModel *model) -{ +void vosk_batch_model_wait(VoskBatchModel *model) { #if HAVE_CUDA - ((BatchModel *)model)->WaitForCompletion(); + ((BatchModel *)model)->WaitForCompletion(); #endif } -VoskBatchRecognizer *vosk_batch_recognizer_new(VoskBatchModel *model, float sample_rate) -{ +VoskBatchRecognizer *vosk_batch_recognizer_new(VoskBatchModel *model, + float sample_rate) { #if HAVE_CUDA - return (VoskBatchRecognizer *)(new BatchRecognizer((BatchModel *)model, sample_rate)); + return (VoskBatchRecognizer *)(new BatchRecognizer((BatchModel *)model, + sample_rate)); #else - return NULL; + return NULL; #endif } -void vosk_batch_recognizer_free(VoskBatchRecognizer *recognizer) -{ +void vosk_batch_recognizer_free(VoskBatchRecognizer *recognizer) { #if HAVE_CUDA - delete ((BatchRecognizer *)recognizer); + delete ((BatchRecognizer *)recognizer); #endif } -void vosk_batch_recognizer_accept_waveform(VoskBatchRecognizer *recognizer, const char *data, int length) -{ +void vosk_batch_recognizer_accept_waveform(VoskBatchRecognizer *recognizer, + const char *data, int length) { #if HAVE_CUDA - ((BatchRecognizer *)recognizer)->AcceptWaveform(data, length); + ((BatchRecognizer *)recognizer)->AcceptWaveform(data, length); #endif } -void vosk_batch_recognizer_set_nlsml(VoskBatchRecognizer *recognizer, int nlsml) -{ +void vosk_batch_recognizer_set_nlsml(VoskBatchRecognizer *recognizer, + int nlsml) { #if HAVE_CUDA - ((BatchRecognizer *)recognizer)->SetNLSML((bool)nlsml); + ((BatchRecognizer *)recognizer)->SetNLSML((bool)nlsml); #endif } -void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer) -{ +void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer) { #if HAVE_CUDA - ((BatchRecognizer *)recognizer)->FinishStream(); + ((BatchRecognizer *)recognizer)->FinishStream(); #endif } -const char *vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer) -{ +const char * +vosk_batch_recognizer_front_result(VoskBatchRecognizer *recognizer) { #if HAVE_CUDA - return ((BatchRecognizer *)recognizer)->FrontResult(); + return ((BatchRecognizer *)recognizer)->FrontResult(); #else - return NULL; + return NULL; #endif } -void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer) -{ +void vosk_batch_recognizer_pop(VoskBatchRecognizer *recognizer) { #if HAVE_CUDA - ((BatchRecognizer *)recognizer)->Pop(); + ((BatchRecognizer *)recognizer)->Pop(); #endif } - -int vosk_batch_recognizer_get_pending_chunks(VoskBatchRecognizer *recognizer) -{ +int vosk_batch_recognizer_get_pending_chunks(VoskBatchRecognizer *recognizer) { #if HAVE_CUDA - return ((BatchRecognizer *)recognizer)->GetNumPendingChunks(); + return ((BatchRecognizer *)recognizer)->GetNumPendingChunks(); #else - return 0; + return 0; #endif } diff --git a/src/vosk_api.h b/src/vosk_api.h index f0cfa163..ae5b29bd 100644 --- a/src/vosk_api.h +++ b/src/vosk_api.h @@ -26,20 +26,17 @@ extern "C" { * threads. */ typedef struct VoskModel VoskModel; - /** Speaker model is the same as model but contains the data * for speaker identification. */ typedef struct VoskSpkModel VoskSpkModel; - /** Recognizer object is the main object which processes data. * Each recognizer usually runs in own thread and takes audio as input. * Once audio is processed recognizer returns JSON object as a string - * which represent decoded information - words, confidences, times, n-best lists, - * speaker information and so on */ + * which represent decoded information - words, confidences, times, n-best + * lists, speaker information and so on */ typedef struct VoskRecognizer VoskRecognizer; - /** * Batch model object */ @@ -50,14 +47,12 @@ typedef struct VoskBatchModel VoskBatchModel; */ typedef struct VoskBatchRecognizer VoskBatchRecognizer; - /** Loads model data from the file and returns the model object * * @param model_path: the path of the model on the filesystem * @returns model object or NULL if problem occured */ VoskModel *vosk_model_new(const char *model_path); - /** Releases the model memory * * The model object is reference-counted so if some recognizer @@ -65,7 +60,6 @@ VoskModel *vosk_model_new(const char *model_path); * last recognizer is released, model will be released too. */ void vosk_model_free(VoskModel *model); - /** Check if a word can be recognized by the model * @param word: the word * @returns the word symbol if @param word exists inside the model @@ -73,14 +67,12 @@ void vosk_model_free(VoskModel *model); * Reminding that word symbol 0 is for */ int vosk_model_find_word(VoskModel *model, const char *word); - /** Loads speaker model data from the file and returns the model object * * @param model_path: the path of the model on the filesystem * @returns model object or NULL if problem occurred */ VoskSpkModel *vosk_spk_model_new(const char *model_path); - /** Releases the model memory * * The model object is reference-counted so if some recognizer @@ -91,69 +83,104 @@ void vosk_spk_model_free(VoskSpkModel *model); /** Creates the recognizer object * * The recognizers process the speech and return text using shared model data - * @param model VoskModel containing static data for recognizer. Model can be - * shared across recognizers, even running in different threads. - * @param sample_rate The sample rate of the audio you going to feed into the recognizer. - * Make sure this rate matches the audio content, it is a common + * @param model VoskModel containing static data for recognizer. Model + * can be shared across recognizers, even running in different threads. + * @param sample_rate The sample rate of the audio you going to feed into the + * recognizer. Make sure this rate matches the audio content, it is a common * issue causing accuracy problems. * @returns recognizer object or NULL if problem occured */ VoskRecognizer *vosk_recognizer_new(VoskModel *model, float sample_rate); - /** Creates the recognizer object with speaker recognition * * With the speaker recognition mode the recognizer not just recognize * text but also return speaker vectors one can use for speaker identification * - * @param model VoskModel containing static data for recognizer. Model can be - * shared across recognizers, even running in different threads. - * @param sample_rate The sample rate of the audio you going to feed into the recognizer. - * Make sure this rate matches the audio content, it is a common + * @param model VoskModel containing static data for recognizer. Model + * can be shared across recognizers, even running in different threads. + * @param sample_rate The sample rate of the audio you going to feed into the + * recognizer. Make sure this rate matches the audio content, it is a common * issue causing accuracy problems. * @param spk_model speaker model for speaker identification * @returns recognizer object or NULL if problem occured */ -VoskRecognizer *vosk_recognizer_new_spk(VoskModel *model, float sample_rate, VoskSpkModel *spk_model); - +VoskRecognizer *vosk_recognizer_new_spk(VoskModel *model, float sample_rate, + VoskSpkModel *spk_model); /** Creates the recognizer object with the phrase list * - * Sometimes when you want to improve recognition accuracy and when you don't need - * to recognize large vocabulary you can specify a list of phrases to recognize. This - * will improve recognizer speed and accuracy but might return [unk] if user said - * something different. + * Sometimes when you want to improve recognition accuracy and when you don't + * need to recognize large vocabulary you can specify a list of phrases to + * recognize. This will improve recognizer speed and accuracy but might return + * [unk] if user said something different. * - * Only recognizers with lookahead models support this type of quick configuration. - * Precompiled HCLG graph models are not supported. + * Only recognizers with lookahead models support this type of quick + * configuration. Precompiled HCLG graph models are not supported. * - * @param model VoskModel containing static data for recognizer. Model can be - * shared across recognizers, even running in different threads. - * @param sample_rate The sample rate of the audio you going to feed into the recognizer. - * Make sure this rate matches the audio content, it is a common + * @param model VoskModel containing static data for recognizer. Model + * can be shared across recognizers, even running in different threads. + * @param sample_rate The sample rate of the audio you going to feed into the + * recognizer. Make sure this rate matches the audio content, it is a common * issue causing accuracy problems. - * @param grammar The string with the list of phrases to recognize as JSON array of strings, - * for example "["one two three four five", "[unk]"]". + * @param grammar The string with the list of phrases to recognize as JSON + * array of strings, for example "["one two three four five", "[unk]"]". * * @returns recognizer object or NULL if problem occured */ -VoskRecognizer *vosk_recognizer_new_grm(VoskModel *model, float sample_rate, const char *grammar); - +VoskRecognizer *vosk_recognizer_new_grm(VoskModel *model, float sample_rate, + const char *grammar); /** Adds speaker model to already initialized recognizer * - * Can add speaker recognition model to already created recognizer. Helps to initialize - * speaker recognition for grammar-based recognizer. + * Can add speaker recognition model to already created recognizer. Helps to + * initialize speaker recognition for grammar-based recognizer. * * @param spk_model Speaker recognition model */ -void vosk_recognizer_set_spk_model(VoskRecognizer *recognizer, VoskSpkModel *spk_model); - +void vosk_recognizer_set_spk_model(VoskRecognizer *recognizer, + VoskSpkModel *spk_model); /** Reconfigures recognizer to use grammar * * @param recognizer Already running VoskRecognizer - * @param grammar Set of phrases in JSON array of strings or "[]" to use default model graph. - * See also vosk_recognizer_new_grm + * @param grammar Set of phrases in JSON array of strings or "[]" to use + * default model graph. See also vosk_recognizer_new_grm */ void vosk_recognizer_set_grm(VoskRecognizer *recognizer, char const *grammar); +/** + * Reconfigures recognizer to use grammar with a custom pronunciation lexicon. + * + * Note: This function is only supported by lookahead models that + * include the `tree` file (at `/am/tree` or `/tree`) and phone symbol table + * (`/graph/phones.txt` or `/phones.txt`) and is only useful for small lexicons + * (e.g. 100 words). For larger lexicons, consider rebuilding the model with the + * desired lexicon. + * + * The phones must be white-space separated and each phone must be out of the + * phone symbol table of the model. If there are multiple versions of the phones + * starting with `_B`, `_I`, `_E` or `_S`, these phones will be used as position + * markers and must be used correctly: + * + * - `_S` is used for pronunciations that are only one phone long, otherwise: + * - `_B` is used for the first phone in a word + * - `_I` is used for intermediate phones in a word + * - `_E` is used for the last phone in a word + * + * The lexicon must also include the `` entry mapped to the silence phone + * (e.g. word: ``, pronunciation: `SIL`), which is used for epsilon (empty) + * transitions. + * + * @param recognizer Already running VoskRecognizer + * @param grammar Set of phrases in JSON array of strings or "[]" to use + * @param words The array of words to use in the grammar (e.g. "one", "two") + * @param pronunciations The array of pronunciations for the words (e.g. "HH_B + * W_I AH_I N_E", "T_B UW_E"). + * @param num_words The number of words / pronunciations in the grammar + * default model graph. See also vosk_recognizer_new_grm + */ +void vosk_recognizer_set_grm_with_lexicon(VoskRecognizer *recognizer, + char const *grammar, + const char *const *words, + const char *const *pronunciations, + int num_words); /** Configures recognizer to output n-best results * @@ -166,10 +193,11 @@ void vosk_recognizer_set_grm(VoskRecognizer *recognizer, char const *grammar); * } * * - * @param max_alternatives - maximum alternatives to return from recognition results + * @param max_alternatives - maximum alternatives to return from recognition + * results */ -void vosk_recognizer_set_max_alternatives(VoskRecognizer *recognizer, int max_alternatives); - +void vosk_recognizer_set_max_alternatives(VoskRecognizer *recognizer, + int max_alternatives); /** Enables words with times in the output * @@ -210,35 +238,34 @@ void vosk_recognizer_set_words(VoskRecognizer *recognizer, int words); * * @param partial_words - boolean value */ -void vosk_recognizer_set_partial_words(VoskRecognizer *recognizer, int partial_words); +void vosk_recognizer_set_partial_words(VoskRecognizer *recognizer, + int partial_words); /** Set NLSML output * @param nlsml - boolean value */ void vosk_recognizer_set_nlsml(VoskRecognizer *recognizer, int nlsml); - /** Accept voice data * * accept and process new chunk of voice data * * @param data - audio data in PCM 16-bit mono format * @param length - length of the audio data - * @returns 1 if silence is occured and you can retrieve a new utterance with result method - * 0 if decoding continues - * -1 if exception occured */ -int vosk_recognizer_accept_waveform(VoskRecognizer *recognizer, const char *data, int length); - + * @returns 1 if silence is occured and you can retrieve a new utterance with + * result method 0 if decoding continues -1 if exception occured */ +int vosk_recognizer_accept_waveform(VoskRecognizer *recognizer, + const char *data, int length); -/** Same as above but the version with the short data for language bindings where you have - * audio as array of shorts */ -int vosk_recognizer_accept_waveform_s(VoskRecognizer *recognizer, const short *data, int length); - - -/** Same as above but the version with the float data for language bindings where you have - * audio as array of floats */ -int vosk_recognizer_accept_waveform_f(VoskRecognizer *recognizer, const float *data, int length); +/** Same as above but the version with the short data for language bindings + * where you have audio as array of shorts */ +int vosk_recognizer_accept_waveform_s(VoskRecognizer *recognizer, + const short *data, int length); +/** Same as above but the version with the float data for language bindings + * where you have audio as array of floats */ +int vosk_recognizer_accept_waveform_f(VoskRecognizer *recognizer, + const float *data, int length); /** Returns speech recognition result * @@ -252,13 +279,14 @@ int vosk_recognizer_accept_waveform_f(VoskRecognizer *recognizer, const float *d * } * * - * If alternatives enabled it returns result with alternatives, see also vosk_recognizer_set_max_alternatives(). + * If alternatives enabled it returns result with alternatives, see also + * vosk_recognizer_set_max_alternatives(). * - * If word times enabled returns word time, see also vosk_recognizer_set_word_times(). + * If word times enabled returns word time, see also + * vosk_recognizer_set_word_times(). */ const char *vosk_recognizer_result(VoskRecognizer *recognizer); - /** Returns partial speech recognition * * @returns partial speech recognition text which is not yet finalized. @@ -272,22 +300,20 @@ const char *vosk_recognizer_result(VoskRecognizer *recognizer); */ const char *vosk_recognizer_partial_result(VoskRecognizer *recognizer); - -/** Returns speech recognition result. Same as result, but doesn't wait for silence - * You usually call it in the end of the stream to get final bits of audio. It - * flushes the feature pipeline, so all remaining audio chunks got processed. +/** Returns speech recognition result. Same as result, but doesn't wait for + * silence You usually call it in the end of the stream to get final bits of + * audio. It flushes the feature pipeline, so all remaining audio chunks got + * processed. * * @returns speech result in JSON format. */ const char *vosk_recognizer_final_result(VoskRecognizer *recognizer); - /** Resets the recognizer * * Resets current results so the recognition can continue from scratch */ void vosk_recognizer_reset(VoskRecognizer *recognizer); - /** Releases recognizer object * * Underlying model is also unreferenced and if needed released */ @@ -329,18 +355,21 @@ void vosk_batch_model_wait(VoskBatchModel *model); /** Creates batch recognizer object * @returns recognizer object or NULL if problem occured */ -VoskBatchRecognizer *vosk_batch_recognizer_new(VoskBatchModel *model, float sample_rate); - +VoskBatchRecognizer *vosk_batch_recognizer_new(VoskBatchModel *model, + float sample_rate); + /** Releases batch recognizer object */ void vosk_batch_recognizer_free(VoskBatchRecognizer *recognizer); /** Accept batch voice data */ -void vosk_batch_recognizer_accept_waveform(VoskBatchRecognizer *recognizer, const char *data, int length); +void vosk_batch_recognizer_accept_waveform(VoskBatchRecognizer *recognizer, + const char *data, int length); /** Set NLSML output * @param nlsml - boolean value */ -void vosk_batch_recognizer_set_nlsml(VoskBatchRecognizer *recognizer, int nlsml); +void vosk_batch_recognizer_set_nlsml(VoskBatchRecognizer *recognizer, + int nlsml); /** Closes the stream */ void vosk_batch_recognizer_finish_stream(VoskBatchRecognizer *recognizer);