Support non-streaming WeNet CTC models. #426

Merged (5 commits, Nov 15, 2023)

Changes from all commits
41 changes: 41 additions & 0 deletions .github/scripts/test-offline-ctc.sh
@@ -13,6 +13,47 @@ echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run Wenet models"
log "------------------------------------------------------------"
wenet_models=(
sherpa-onnx-zh-wenet-aishell
sherpa-onnx-zh-wenet-aishell2
sherpa-onnx-zh-wenet-wenetspeech
sherpa-onnx-zh-wenet-multi-cn
sherpa-onnx-en-wenet-librispeech
sherpa-onnx-en-wenet-gigaspeech
)
for name in ${wenet_models[@]}; do
repo_url=https://huggingface.co/csukuangfj/$name
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd

log "test float32 models"
time $EXE \
--tokens=$repo/tokens.txt \
--wenet-ctc-model=$repo/model.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

log "test int8 models"
time $EXE \
--tokens=$repo/tokens.txt \
--wenet-ctc-model=$repo/model.int8.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

rm -rf $repo
done

log "------------------------------------------------------------"
log "Run tdnn yesno (Hebrew)"
log "------------------------------------------------------------"
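
The wenet_models loop above fetches each pretrained model with a GIT_LFS_SKIP_SMUDGE clone followed by git lfs pull. For reference, the same download can be scripted in Python; a minimal sketch assuming the huggingface_hub package, using one of the repo ids listed in wenet_models:

    # Sketch: download one WeNet test model from Hugging Face.
    # Equivalent to the git-lfs clone in the script above; assumes
    # `pip install huggingface_hub`.
    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(
        repo_id="csukuangfj/sherpa-onnx-zh-wenet-aishell",
        allow_patterns=["*.onnx", "tokens.txt", "test_wavs/*"],
    )
    print("model files are in", local_dir)
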
11 changes: 0 additions & 11 deletions .github/workflows/export-wenet-to-onnx.yaml
@@ -1,17 +1,6 @@
name: export-wenet-to-onnx

on:
push:
branches:
- master
paths:
- 'scripts/wenet/**'
- '.github/workflows/export-wenet-to-onnx.yaml'
pull_request:
paths:
- 'scripts/wenet/**'
- '.github/workflows/export-wenet-to-onnx.yaml'

workflow_dispatch:

concurrency:
16 changes: 8 additions & 8 deletions .github/workflows/linux.yaml
@@ -89,6 +89,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx

- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline

.github/scripts/test-offline-ctc.sh

- name: Test offline TTS
shell: bash
run: |
@@ -115,14 +123,6 @@ jobs:

.github/scripts/test-offline-whisper.sh

- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline

.github/scripts/test-offline-ctc.sh

- name: Test offline transducer
shell: bash
run: |
3 changes: 2 additions & 1 deletion scripts/wenet/export-onnx-streaming.py
@@ -172,7 +172,7 @@ def main():
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
url = os.environ.get("WENET_URL", "")
meta_data = {
"model_type": "wenet-ctc",
"model_type": "wenet_ctc",
"version": "1",
"model_author": "wenet",
"comment": "streaming",
@@ -185,6 +185,7 @@ def main():
"cnn_module_kernel": cnn_module_kernel,
"right_context": right_context,
"subsampling_factor": subsampling_factor,
"vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
}
add_meta_data(filename=filename, meta_data=meta_data)

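The meta_data dict above is written into the exported ONNX file by add_meta_data. A minimal sketch of what such a helper can look like, assuming the onnx Python package (the actual helper used by the WeNet export scripts may differ in detail):

    # Sketch: attach custom key/value metadata to an exported ONNX file.
    # metadata_props stores strings, so every value is stringified.
    import onnx

    def add_meta_data(filename: str, meta_data: dict) -> None:
        model = onnx.load(filename)
        for key, value in meta_data.items():
            entry = model.metadata_props.add()
            entry.key = key
            entry.value = str(value)
        onnx.save(model, filename)
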
4 changes: 3 additions & 1 deletion scripts/wenet/export-onnx.py
@@ -107,10 +107,12 @@ def main():
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
url = os.environ.get("WENET_URL", "")
meta_data = {
"model_type": "wenet-ctc",
"model_type": "wenet_ctc",
"version": "1",
"model_author": "wenet",
"comment": "non-streaming",
"subsampling_factor": torch_model.encoder.embed.subsampling_rate,
"vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
"url": url,
}
add_meta_data(filename=filename, meta_data=meta_data)
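
To confirm that an exported model actually carries this metadata (a model_type of wenet_ctc plus vocab_size and subsampling_factor) before handing it to sherpa-onnx, the file can be inspected with the onnx package; a short sketch with a placeholder file name:

    # Sketch: print the custom metadata of an exported model and verify
    # that model_type is set to "wenet_ctc".
    import onnx

    model = onnx.load("model.onnx")  # placeholder path
    meta = {p.key: p.value for p in model.metadata_props}
    print(meta)
    assert meta.get("model_type") == "wenet_ctc", meta
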
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
@@ -41,6 +41,8 @@ set(sources
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
offline-wenet-ctc-model-config.cc
offline-wenet-ctc-model.cc
offline-whisper-greedy-search-decoder.cc
offline-whisper-model-config.cc
offline-whisper-model.cc
19 changes: 18 additions & 1 deletion sherpa-onnx/csrc/offline-ctc-model.cc
@@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

@@ -21,10 +22,11 @@ enum class ModelType {
kEncDecCTCModelBPE,
kTdnn,
kZipformerCtc,
kWenetCtc,
kUnkown,
};

}
} // namespace

namespace sherpa_onnx {

@@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"If you are using models from NeMo, please refer to\n"
"https://huggingface.co/csukuangfj/"
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
"If you are using models from WeNet, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
"run.sh\n"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnkown;
@@ -63,6 +68,8 @@
return ModelType::kTdnn;
} else if (model_type.get() == std::string("zipformer2_ctc")) {
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
return ModelType::kWenetCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
@@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.tdnn.model;
} else if (!config.zipformer_ctc.model.empty()) {
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
@@ -101,6 +110,9 @@
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
@@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.tdnn.model;
} else if (!config.zipformer_ctc.model.empty()) {
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
@@ -143,6 +157,9 @@
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-model.h
@@ -63,6 +63,9 @@ class OfflineCtcModel {
* for the features.
*/
virtual std::string FeatureNormalizationMethod() const { return {}; }

// Return true if the model supports batch size > 1
virtual bool SupportBatchProcessing() const { return true; }
};

} // namespace sherpa_onnx
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-model-config.cc
@@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
whisper.Register(po);
tdnn.Register(po);
zipformer_ctc.Register(po);
wenet_ctc.Register(po);

po->Register("tokens", &tokens, "Path to tokens.txt");

@@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const {
return zipformer_ctc.Validate();
}

if (!wenet_ctc.model.empty()) {
return wenet_ctc.Validate();
}

return transducer.Validate();
}

@@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const {
os << "whisper=" << whisper.ToString() << ", ";
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-model-config.h
@@ -10,6 +10,7 @@
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"

@@ -22,6 +23,7 @@ struct OfflineModelConfig {
OfflineWhisperModelConfig whisper;
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;

std::string tokens;
int32_t num_threads = 2;
@@ -46,6 +48,7 @@ struct OfflineModelConfig {
const OfflineWhisperModelConfig &whisper,
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
Expand All @@ -54,6 +57,7 @@ struct OfflineModelConfig {
whisper(whisper),
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),
55 changes: 53 additions & 2 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
@@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#endif

void Init() {
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}

config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
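
The normalize_samples change encodes a convention: WeNet fbank features are computed on waveforms in the 16-bit PCM range, while most Python audio loaders return floats in [-1, 1]. A short sketch of the implied scaling, assuming the soundfile package, for anyone computing features outside sherpa-onnx:

    # Sketch: WeNet CTC models expect samples in [-32768, 32767]; loaders such
    # as soundfile return float32 in [-1, 1], so scale before computing fbank.
    import soundfile as sf

    samples, sample_rate = sf.read("test_wavs/0.wav", dtype="float32")  # [-1, 1]
    samples = samples * 32768.0  # 16-bit range expected by the model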

@@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
config_.ctc_fst_decoder_config);
} else if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
!symbol_table_.contains("<eps>") &&
!symbol_table_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> and its ID.");
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}

@@ -98,6 +105,9 @@
} else if (symbol_table_.contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = symbol_table_["<eps>"];
} else if (symbol_table_.contains("<blank>")) {
// for Wenet CTC models
blank_id = symbol_table_["<blank>"];
}

decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
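
The lookup above extends the existing <blk>/<eps> handling with the <blank> symbol used in WeNet tokens.txt files. The same selection logic in Python, assuming the usual one "symbol id" pair per line:

    # Sketch: find the CTC blank id in tokens.txt, trying <blk> (icefall),
    # <eps> (tdnn/yesno) and <blank> (WeNet) in that order.
    def find_blank_id(tokens_path: str) -> int:
        sym2id = {}
        with open(tokens_path, encoding="utf-8") as f:
            for line in f:
                sym, idx = line.rstrip("\n").rsplit(maxsplit=1)
                sym2id[sym] = int(idx)
        for sym in ("<blk>", "<eps>", "<blank>"):
            if sym in sym2id:
                return sym2id[sym]
        raise ValueError("tokens.txt must contain <blk>, <eps> or <blank>")
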
@@ -113,6 +123,15 @@
}

void DecodeStreams(OfflineStream **ss, int32_t n) const override {
if (!model_->SupportBatchProcessing()) {
// If the model does not support batch process,
// we process each stream independently.
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
return;
}

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

@@ -164,6 +183,38 @@
}
}

private:
// Decode a single stream.
// Some models do not support batch size > 1, e.g., WeNet CTC models.
void DecodeStream(OfflineStream *s) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

int32_t feat_dim = config_.feat_config.feature_dim;
std::vector<float> f = s->GetFrames();

int32_t num_frames = f.size() / feat_dim;

std::array<int64_t, 3> shape = {1, num_frames, feat_dim};

Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
shape.data(), shape.size());

int64_t x_length_scalar = num_frames;
std::array<int64_t, 1> x_length_shape = {1};
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
x_length_shape.data(), x_length_shape.size());

auto t = model_->Forward(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
int32_t frame_shift_ms = 10;

auto r = Convert(results[0], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
s->SetResult(r);
}

private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
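
DecodeStream is the single-utterance path taken when SupportBatchProcessing() returns false. For debugging an exported model outside sherpa-onnx, the same flow can be reproduced with onnxruntime; a sketch that assumes precomputed fbank features of shape (num_frames, feat_dim) and that the model's two outputs are (log_probs, log_probs_len), in that order:

    # Sketch: run one utterance through a non-streaming WeNet CTC model and
    # greedy-decode it, mirroring DecodeStream + OfflineCtcGreedySearchDecoder.
    import numpy as np
    import onnxruntime as ort

    def greedy_ctc_decode(model_path: str, feats: np.ndarray, blank_id: int) -> list:
        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        in_names = [i.name for i in sess.get_inputs()]
        x = feats[np.newaxis].astype(np.float32)             # (1, num_frames, feat_dim)
        x_len = np.array([feats.shape[0]], dtype=np.int64)   # (1,)
        log_probs, log_probs_len = sess.run(None, {in_names[0]: x, in_names[1]: x_len})

        best = log_probs[0, : int(log_probs_len[0])].argmax(axis=-1)
        tokens, prev = [], -1
        # Collapse repeated frames and drop blanks, as in CTC greedy search.
        for t in best:
            if t != prev and t != blank_id:
                tokens.append(int(t))
            prev = t
        return tokens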