diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh index f53016866..16f7889df 100755 --- a/.github/scripts/test-offline-ctc.sh +++ b/.github/scripts/test-offline-ctc.sh @@ -13,6 +13,47 @@ echo "PATH: $PATH" which $EXE +log "------------------------------------------------------------" +log "Run Wenet models" +log "------------------------------------------------------------" +wenet_models=( +sherpa-onnx-zh-wenet-aishell +sherpa-onnx-zh-wenet-aishell2 +sherpa-onnx-zh-wenet-wenetspeech +sherpa-onnx-zh-wenet-multi-cn +sherpa-onnx-en-wenet-librispeech +sherpa-onnx-en-wenet-gigaspeech +) +for name in ${wenet_models[@]}; do + repo_url=https://huggingface.co/csukuangfj/$name + log "Start testing ${repo_url}" + repo=$(basename $repo_url) + log "Download pretrained model and test-data from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo + git lfs pull --include "*.onnx" + ls -lh *.onnx + popd + + log "test float32 models" + time $EXE \ + --tokens=$repo/tokens.txt \ + --wenet-ctc-model=$repo/model.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + + log "test int8 models" + time $EXE \ + --tokens=$repo/tokens.txt \ + --wenet-ctc-model=$repo/model.int8.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + + rm -rf $repo +done + log "------------------------------------------------------------" log "Run tdnn yesno (Hebrew)" log "------------------------------------------------------------" diff --git a/.github/workflows/export-wenet-to-onnx.yaml b/.github/workflows/export-wenet-to-onnx.yaml index 191e6a6b7..3ac14e8d1 100644 --- a/.github/workflows/export-wenet-to-onnx.yaml +++ b/.github/workflows/export-wenet-to-onnx.yaml @@ -1,17 +1,6 @@ name: export-wenet-to-onnx on: - push: - branches: - - master - paths: - - 'scripts/wenet/**' - - '.github/workflows/export-wenet-to-onnx.yaml' - pull_request: - paths: - - 'scripts/wenet/**' - - '.github/workflows/export-wenet-to-onnx.yaml' - workflow_dispatch: concurrency: diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 927e6bd44..88c82c3eb 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -89,6 +89,14 @@ jobs: file build/bin/sherpa-onnx readelf -d build/bin/sherpa-onnx + - name: Test offline CTC + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline + + .github/scripts/test-offline-ctc.sh + - name: Test offline TTS shell: bash run: | @@ -115,14 +123,6 @@ jobs: .github/scripts/test-offline-whisper.sh - - name: Test offline CTC - shell: bash - run: | - export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx-offline - - .github/scripts/test-offline-ctc.sh - - name: Test offline transducer shell: bash run: | diff --git a/scripts/wenet/export-onnx-streaming.py b/scripts/wenet/export-onnx-streaming.py index bc384d694..27d8afde5 100755 --- a/scripts/wenet/export-onnx-streaming.py +++ b/scripts/wenet/export-onnx-streaming.py @@ -172,7 +172,7 @@ def main(): # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz url = os.environ.get("WENET_URL", "") meta_data = { - "model_type": "wenet-ctc", + "model_type": "wenet_ctc", "version": "1", "model_author": "wenet", "comment": "streaming", @@ -185,6 +185,7 @@ def main(): "cnn_module_kernel": cnn_module_kernel, "right_context": right_context, "subsampling_factor": subsampling_factor, + "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0], } add_meta_data(filename=filename, meta_data=meta_data) diff --git a/scripts/wenet/export-onnx.py b/scripts/wenet/export-onnx.py index 791afbd5d..b6ac9874d 100755 --- a/scripts/wenet/export-onnx.py +++ b/scripts/wenet/export-onnx.py @@ -107,10 +107,12 @@ def main(): # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz url = os.environ.get("WENET_URL", "") meta_data = { - "model_type": "wenet-ctc", + "model_type": "wenet_ctc", "version": "1", "model_author": "wenet", "comment": "non-streaming", + "subsampling_factor": torch_model.encoder.embed.subsampling_rate, + "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0], "url": url, } add_meta_data(filename=filename, meta_data=meta_data) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index b44874d47..0e870b301 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -41,6 +41,8 @@ set(sources offline-transducer-model-config.cc offline-transducer-model.cc offline-transducer-modified-beam-search-decoder.cc + offline-wenet-ctc-model-config.cc + offline-wenet-ctc-model.cc offline-whisper-greedy-search-decoder.cc offline-whisper-model-config.cc offline-whisper-model.cc diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc index d8864404b..cf0649acb 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -12,6 +12,7 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" +#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h" #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" #include "sherpa-onnx/csrc/onnx-utils.h" @@ -21,10 +22,11 @@ enum class ModelType { kEncDecCTCModelBPE, kTdnn, kZipformerCtc, + kWenetCtc, kUnkown, }; -} +} // namespace namespace sherpa_onnx { @@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, "If you are using models from NeMo, please refer to\n" "https://huggingface.co/csukuangfj/" "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" + "If you are using models from WeNet, please refer to\n" + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/" + "run.sh\n" "\n" "for how to add metadta to model.onnx\n"); return ModelType::kUnkown; @@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kTdnn; } else if (model_type.get() == std::string("zipformer2_ctc")) { return ModelType::kZipformerCtc; + } else if (model_type.get() == std::string("wenet_ctc")) { + return ModelType::kWenetCtc; } else { SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); return ModelType::kUnkown; @@ -80,6 +87,8 @@ std::unique_ptr OfflineCtcModel::Create( filename = config.tdnn.model; } else if (!config.zipformer_ctc.model.empty()) { filename = config.zipformer_ctc.model; + } else if (!config.wenet_ctc.model.empty()) { + filename = config.wenet_ctc.model; } else { SHERPA_ONNX_LOGE("Please specify a CTC model"); exit(-1); @@ -101,6 +110,9 @@ std::unique_ptr OfflineCtcModel::Create( case ModelType::kZipformerCtc: return std::make_unique(config); break; + case ModelType::kWenetCtc: + return std::make_unique(config); + break; case ModelType::kUnkown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; @@ -122,6 +134,8 @@ std::unique_ptr OfflineCtcModel::Create( filename = config.tdnn.model; } else if (!config.zipformer_ctc.model.empty()) { filename = config.zipformer_ctc.model; + } else if (!config.wenet_ctc.model.empty()) { + filename = config.wenet_ctc.model; } else { SHERPA_ONNX_LOGE("Please specify a CTC model"); exit(-1); @@ -143,6 +157,9 @@ std::unique_ptr OfflineCtcModel::Create( case ModelType::kZipformerCtc: return std::make_unique(mgr, config); break; + case ModelType::kWenetCtc: + return std::make_unique(mgr, config); + break; case ModelType::kUnkown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; diff --git a/sherpa-onnx/csrc/offline-ctc-model.h b/sherpa-onnx/csrc/offline-ctc-model.h index e2947f957..f4c7406f6 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.h +++ b/sherpa-onnx/csrc/offline-ctc-model.h @@ -63,6 +63,9 @@ class OfflineCtcModel { * for the features. */ virtual std::string FeatureNormalizationMethod() const { return {}; } + + // Return true if the model supports batch size > 1 + virtual bool SupportBatchProcessing() const { return true; } }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 02c799d63..5f4a6770a 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { whisper.Register(po); tdnn.Register(po); zipformer_ctc.Register(po); + wenet_ctc.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const { return zipformer_ctc.Validate(); } + if (!wenet_ctc.model.empty()) { + return wenet_ctc.Validate(); + } + return transducer.Validate(); } @@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const { os << "whisper=" << whisper.ToString() << ", "; os << "tdnn=" << tdnn.ToString() << ", "; os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; + os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index 55a063f98..9750642f2 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -10,6 +10,7 @@ #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" #include "sherpa-onnx/csrc/offline-transducer-model-config.h" +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" #include "sherpa-onnx/csrc/offline-whisper-model-config.h" #include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" @@ -22,6 +23,7 @@ struct OfflineModelConfig { OfflineWhisperModelConfig whisper; OfflineTdnnModelConfig tdnn; OfflineZipformerCtcModelConfig zipformer_ctc; + OfflineWenetCtcModelConfig wenet_ctc; std::string tokens; int32_t num_threads = 2; @@ -46,6 +48,7 @@ struct OfflineModelConfig { const OfflineWhisperModelConfig &whisper, const OfflineTdnnModelConfig &tdnn, const OfflineZipformerCtcModelConfig &zipformer_ctc, + const OfflineWenetCtcModelConfig &wenet_ctc, const std::string &tokens, int32_t num_threads, bool debug, const std::string &provider, const std::string &model_type) : transducer(transducer), @@ -54,6 +57,7 @@ struct OfflineModelConfig { whisper(whisper), tdnn(tdnn), zipformer_ctc(zipformer_ctc), + wenet_ctc(wenet_ctc), tokens(tokens), num_threads(num_threads), debug(debug), diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 98d220ba5..2eb908b53 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { #endif void Init() { + if (!config_.model_config.wenet_ctc.model.empty()) { + // WeNet CTC models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + config_.feat_config.nemo_normalize_type = model_->FeatureNormalizationMethod(); @@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { config_.ctc_fst_decoder_config); } else if (config_.decoding_method == "greedy_search") { if (!symbol_table_.contains("") && - !symbol_table_.contains("")) { + !symbol_table_.contains("") && + !symbol_table_.contains("")) { SHERPA_ONNX_LOGE( "We expect that tokens.txt contains " - "the symbol or and its ID."); + "the symbol or or and its ID."); exit(-1); } @@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { } else if (symbol_table_.contains("")) { // for tdnn models of the yesno recipe from icefall blank_id = symbol_table_[""]; + } else if (symbol_table_.contains("")) { + // for Wenet CTC models + blank_id = symbol_table_[""]; } decoder_ = std::make_unique(blank_id); @@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { } void DecodeStreams(OfflineStream **ss, int32_t n) const override { + if (!model_->SupportBatchProcessing()) { + // If the model does not support batch process, + // we process each stream independently. + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + return; + } + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); @@ -164,6 +183,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { } } + private: + // Decode a single stream. + // Some models do not support batch size > 1, e.g., WeNet CTC models. + void DecodeStream(OfflineStream *s) const { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t feat_dim = config_.feat_config.feature_dim; + std::vector f = s->GetFrames(); + + int32_t num_frames = f.size() / feat_dim; + + std::array shape = {1, num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int64_t x_length_scalar = num_frames; + std::array x_length_shape = {1}; + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1, + x_length_shape.data(), x_length_shape.size()); + + auto t = model_->Forward(std::move(x), std::move(x_length)); + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + int32_t frame_shift_ms = 10; + + auto r = Convert(results[0], symbol_table_, frame_shift_ms, + model_->SubsamplingFactor()); + s->SetResult(r); + } + private: OfflineRecognizerConfig config_; SymbolTable symbol_table_; diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 31e16133b..488e7e7f3 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -26,7 +26,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( } else if (model_type == "paraformer") { return std::make_unique(config); } else if (model_type == "nemo_ctc" || model_type == "tdnn" || - model_type == "zipformer2_ctc") { + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { return std::make_unique(config); } else if (model_type == "whisper") { return std::make_unique(config); @@ -51,6 +51,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( model_filename = config.model_config.tdnn.model; } else if (!config.model_config.zipformer_ctc.model.empty()) { model_filename = config.model_config.zipformer_ctc.model; + } else if (!config.model_config.wenet_ctc.model.empty()) { + model_filename = config.model_config.wenet_ctc.model; } else if (!config.model_config.whisper.encoder.empty()) { model_filename = config.model_config.whisper.encoder; } else { @@ -99,6 +101,10 @@ std::unique_ptr OfflineRecognizerImpl::Create( "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" "zipformer/export-onnx-ctc.py" "\n" + "(6) CTC models from WeNet" + "\n " + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh" + "\n" "\n"); exit(-1); } @@ -114,7 +120,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( } if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || - model_type == "zipformer2_ctc") { + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { return std::make_unique(config); } @@ -130,7 +136,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - EncDecCTCModelBPE models from NeMo\n" " - Whisper models\n" " - Tdnn models\n" - " - Zipformer CTC models\n", + " - Zipformer CTC models\n" + " - WeNet CTC models\n", model_type.c_str()); exit(-1); @@ -146,7 +153,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( } else if (model_type == "paraformer") { return std::make_unique(mgr, config); } else if (model_type == "nemo_ctc" || model_type == "tdnn" || - model_type == "zipformer2_ctc") { + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { return std::make_unique(mgr, config); } else if (model_type == "whisper") { return std::make_unique(mgr, config); @@ -171,6 +178,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( model_filename = config.model_config.tdnn.model; } else if (!config.model_config.zipformer_ctc.model.empty()) { model_filename = config.model_config.zipformer_ctc.model; + } else if (!config.model_config.wenet_ctc.model.empty()) { + model_filename = config.model_config.wenet_ctc.model; } else if (!config.model_config.whisper.encoder.empty()) { model_filename = config.model_config.whisper.encoder; } else { @@ -219,6 +228,10 @@ std::unique_ptr OfflineRecognizerImpl::Create( "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" "zipformer/export-onnx-ctc.py" "\n" + "(6) CTC models from WeNet" + "\n " + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh" + "\n" "\n"); exit(-1); } @@ -234,7 +247,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( } if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || - model_type == "zipformer2_ctc") { + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { return std::make_unique(mgr, config); } @@ -250,7 +263,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - EncDecCTCModelBPE models from NeMo\n" " - Whisper models\n" " - Tdnn models\n" - " - Zipformer CTC models\n", + " - Zipformer CTC models\n" + " - WeNet CTC models\n", model_type.c_str()); exit(-1); diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 6eee0e545..26b890b60 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -40,7 +40,8 @@ struct OfflineFeatureExtractorConfig { // Feature dimension int32_t feature_dim = 80; - // Set internally by some models, e.g., paraformer sets it to false. + // Set internally by some models, e.g., paraformer and wenet CTC models set + // it to false. // This parameter is not exposed to users from the commandline // If true, the feature extractor expects inputs to be normalized to // the range [-1, 1]. diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc b/sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc new file mode 100644 index 000000000..f3543948e --- /dev/null +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc @@ -0,0 +1,37 @@ +// sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineWenetCtcModelConfig::Register(ParseOptions *po) { + po->Register( + "wenet-ctc-model", &model, + "Path to model.onnx from WeNet. Please see " + "https://github.com/k2-fsa/sherpa-onnx/pull/425 for available models"); +} + +bool OfflineWenetCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("WeNet model: %s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineWenetCtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineWenetCtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model-config.h b/sherpa-onnx/csrc/offline-wenet-ctc-model-config.h new file mode 100644 index 000000000..4a9b30b80 --- /dev/null +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model-config.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/offline-wenet-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineWenetCtcModelConfig { + std::string model; + + OfflineWenetCtcModelConfig() = default; + explicit OfflineWenetCtcModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model.cc b/sherpa-onnx/csrc/offline-wenet-ctc-model.cc new file mode 100644 index 000000000..93fdffab8 --- /dev/null +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model.cc @@ -0,0 +1,118 @@ +// sherpa-onnx/csrc/offline-wenet-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +class OfflineWenetCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.wenet_ctc.model); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.wenet_ctc.model); + Init(buf.data(), buf.size()); + } +#endif + + std::vector Forward(Ort::Value features, + Ort::Value features_length) { + std::array inputs = {std::move(features), + std::move(features_length)}; + + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + + OrtAllocator *Allocator() const { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + } + + private: + OfflineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 0; +}; + +OfflineWenetCtcModel::OfflineWenetCtcModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OfflineWenetCtcModel::OfflineWenetCtcModel(AAssetManager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OfflineWenetCtcModel::~OfflineWenetCtcModel() = default; + +std::vector OfflineWenetCtcModel::Forward( + Ort::Value features, Ort::Value features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OfflineWenetCtcModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +OrtAllocator *OfflineWenetCtcModel::Allocator() const { + return impl_->Allocator(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model.h b/sherpa-onnx/csrc/offline-wenet-ctc-model.h new file mode 100644 index 000000000..4eb78b73a --- /dev/null +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model.h @@ -0,0 +1,79 @@ +// sherpa-onnx/csrc/offline-wenet-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-ctc-model.h" +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +/** This class implements the CTC model from WeNet. + * + * See + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/export-onnx.py + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/test-onnx.py + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh + * + */ +class OfflineWenetCtcModel : public OfflineCtcModel { + public: + explicit OfflineWenetCtcModel(const OfflineModelConfig &config); + +#if __ANDROID_API__ >= 9 + OfflineWenetCtcModel(AAssetManager *mgr, const OfflineModelConfig &config); +#endif + + ~OfflineWenetCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t + */ + std::vector Forward(Ort::Value features, + Ort::Value features_length) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** SubsamplingFactor of the model + * + * For Citrinet, the subsampling factor is usually 4. + * For Conformer CTC, the subsampling factor is usually 8. + */ + int32_t SubsamplingFactor() const override; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const override; + + // WeNet CTC models do not support batch size > 1 + bool SupportBatchProcessing() const override { return false; } + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 62500bdc8..26c44db07 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -17,6 +17,7 @@ pybind11_add_module(_sherpa_onnx offline-tts-model-config.cc offline-tts-vits-model-config.cc offline-tts.cc + offline-wenet-ctc-model-config.cc offline-whisper-model-config.cc offline-zipformer-ctc-model-config.cc online-lm-config.cc diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc index 970d83b20..fa742490f 100644 --- a/sherpa-onnx/python/csrc/offline-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -12,6 +12,7 @@ #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" +#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h" #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" #include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h" @@ -24,6 +25,7 @@ void PybindOfflineModelConfig(py::module *m) { PybindOfflineWhisperModelConfig(m); PybindOfflineTdnnModelConfig(m); PybindOfflineZipformerCtcModelConfig(m); + PybindOfflineWenetCtcModelConfig(m); using PyClass = OfflineModelConfig; py::class_(*m, "OfflineModelConfig") @@ -32,7 +34,8 @@ void PybindOfflineModelConfig(py::module *m) { const OfflineNemoEncDecCtcModelConfig &, const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, - const OfflineZipformerCtcModelConfig &, const std::string &, + const OfflineZipformerCtcModelConfig &, + const OfflineWenetCtcModelConfig &, const std::string &, int32_t, bool, const std::string &, const std::string &>(), py::arg("transducer") = OfflineTransducerModelConfig(), py::arg("paraformer") = OfflineParaformerModelConfig(), @@ -40,6 +43,7 @@ void PybindOfflineModelConfig(py::module *m) { py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), + py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("model_type") = "") .def_readwrite("transducer", &PyClass::transducer) @@ -48,6 +52,7 @@ void PybindOfflineModelConfig(py::module *m) { .def_readwrite("whisper", &PyClass::whisper) .def_readwrite("tdnn", &PyClass::tdnn) .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) + .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) diff --git a/sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc b/sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc new file mode 100644 index 000000000..d8bce13a2 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-onnx/python/csrc/offline-wenet-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" + +#include +#include + +#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineWenetCtcModelConfig(py::module *m) { + using PyClass = OfflineWenetCtcModelConfig; + py::class_(*m, "OfflineWenetCtcModelConfig") + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h b/sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h new file mode 100644 index 000000000..ea92c46f1 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-wenet-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineWenetCtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_