From b41f6d2c94d377d8716a10d4a3c222ac608f542a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 25 Oct 2024 10:55:16 +0800 Subject: [PATCH] Support GigaAM CTC models for Russian ASR (#1464) See also https://github.com/salute-developers/GigaAM --- .github/scripts/test-offline-ctc.sh | 15 ++ .../export-nemo-giga-am-to-onnx.yaml | 88 ++++++++++ .github/workflows/linux.yaml | 20 +-- scripts/apk/generate-vad-asr-apk-script.py | 18 ++ scripts/nemo/GigaAM/README.md | 10 ++ scripts/nemo/GigaAM/export-onnx-ctc.py | 114 +++++++++++++ scripts/nemo/GigaAM/run-ctc.sh | 36 ++++ scripts/nemo/GigaAM/test-onnx-ctc.py | 157 ++++++++++++++++++ sherpa-onnx/csrc/features.cc | 2 + sherpa-onnx/csrc/features.h | 1 + sherpa-onnx/csrc/jieba-lexicon.cc | 2 +- sherpa-onnx/csrc/lexicon.cc | 40 +---- sherpa-onnx/csrc/lexicon.h | 6 - sherpa-onnx/csrc/macros.h | 140 +++++++++------- sherpa-onnx/csrc/melo-tts-lexicon.cc | 2 +- sherpa-onnx/csrc/offline-ctc-model.cc | 19 +-- sherpa-onnx/csrc/offline-ctc-model.h | 4 + .../csrc/offline-nemo-enc-dec-ctc-model.cc | 12 +- .../csrc/offline-nemo-enc-dec-ctc-model.h | 2 + .../csrc/offline-recognizer-ctc-impl.h | 19 ++- sherpa-onnx/csrc/offline-recognizer-impl.cc | 6 +- sherpa-onnx/csrc/symbol-table.cc | 66 +++++--- sherpa-onnx/csrc/symbol-table.h | 12 ++ sherpa-onnx/kotlin-api/OfflineRecognizer.kt | 10 ++ 24 files changed, 641 insertions(+), 160 deletions(-) create mode 100644 .github/workflows/export-nemo-giga-am-to-onnx.yaml create mode 100644 scripts/nemo/GigaAM/README.md create mode 100755 scripts/nemo/GigaAM/export-onnx-ctc.py create mode 100755 scripts/nemo/GigaAM/run-ctc.sh create mode 100755 scripts/nemo/GigaAM/test-onnx-ctc.py diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh index 57208e9da..f85b58539 100755 --- a/.github/scripts/test-offline-ctc.sh +++ b/.github/scripts/test-offline-ctc.sh @@ -15,6 +15,21 @@ echo "PATH: $PATH" which $EXE +log "------------------------------------------------------------" +log "Run NeMo GigaAM Russian models" +log "------------------------------------------------------------" +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 +tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 +rm sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 + +$EXE \ + --nemo-ctc-model=./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/model.int8.onnx \ + --tokens=./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/tokens.txt \ + --debug=1 \ + ./sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/test_wavs/example.wav + +rm -rf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24 + log "------------------------------------------------------------" log "Run SenseVoice models" log "------------------------------------------------------------" diff --git a/.github/workflows/export-nemo-giga-am-to-onnx.yaml b/.github/workflows/export-nemo-giga-am-to-onnx.yaml new file mode 100644 index 000000000..f48c344e6 --- /dev/null +++ b/.github/workflows/export-nemo-giga-am-to-onnx.yaml @@ -0,0 +1,88 @@ +name: export-nemo-giga-am-to-onnx + +on: + workflow_dispatch: + +concurrency: + group: export-nemo-giga-am-to-onnx-${{ github.ref }} + cancel-in-progress: true + +jobs: + export-nemo-am-giga-to-onnx: + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' + name: export nemo GigaAM models to ONNX + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest] + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v4 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Run CTC + shell: bash + run: | + pushd scripts/nemo/GigaAM + ./run-ctc.sh + popd + + d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24 + mkdir $d + mkdir $d/test_wavs + rm scripts/nemo/GigaAM/model.onnx + mv -v scripts/nemo/GigaAM/*.int8.onnx $d/ + mv -v scripts/nemo/GigaAM/*.md $d/ + mv -v scripts/nemo/GigaAM/*.pdf $d/ + mv -v scripts/nemo/GigaAM/tokens.txt $d/ + mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ + mv -v scripts/nemo/GigaAM/run-ctc.sh $d/ + mv -v scripts/nemo/GigaAM/*-ctc.py $d/ + + ls -lh scripts/nemo/GigaAM/ + + ls -lh $d + + tar cjvf ${d}.tar.bz2 $d + + - name: Release + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + file: ./*.tar.bz2 + overwrite: true + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: asr-models + + - name: Publish to huggingface (CTC) + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + uses: nick-fields/retry@v3 + with: + max_attempts: 20 + timeout_seconds: 200 + shell: bash + command: | + git config --global user.email "csukuangfj@gmail.com" + git config --global user.name "Fangjun Kuang" + + d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24 + export GIT_LFS_SKIP_SMUDGE=1 + export GIT_CLONE_PROTECTION_ACTIVE=false + git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface + mv -v $d/* ./huggingface + cd huggingface + git lfs track "*.onnx" + git lfs track "*.wav" + git status + git add . + git status + git commit -m "add models" + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index a37f48618..2ad40e215 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -149,6 +149,16 @@ jobs: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: install/* + - name: Test offline CTC + shell: bash + run: | + du -h -d1 . + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline + + .github/scripts/test-offline-ctc.sh + du -h -d1 . + - name: Test C++ API shell: bash run: | @@ -180,16 +190,6 @@ jobs: .github/scripts/test-offline-transducer.sh du -h -d1 . - - name: Test offline CTC - shell: bash - run: | - du -h -d1 . - export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx-offline - - .github/scripts/test-offline-ctc.sh - du -h -d1 . - - name: Test online punctuation shell: bash run: | diff --git a/scripts/apk/generate-vad-asr-apk-script.py b/scripts/apk/generate-vad-asr-apk-script.py index b6361427f..8217e6ea0 100755 --- a/scripts/apk/generate-vad-asr-apk-script.py +++ b/scripts/apk/generate-vad-asr-apk-script.py @@ -333,6 +333,24 @@ def get_models(): ls -lh + popd + """, + ), + Model( + model_name="sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24", + idx=19, + lang="ru", + short_name="nemo_ctc_giga_am", + cmd=""" + pushd $model_name + + rm -rfv test_wavs + + rm -fv *.sh + rm -fv *.py + + ls -lh + popd """, ), diff --git a/scripts/nemo/GigaAM/README.md b/scripts/nemo/GigaAM/README.md new file mode 100644 index 000000000..583d10a16 --- /dev/null +++ b/scripts/nemo/GigaAM/README.md @@ -0,0 +1,10 @@ +# Introduction + +This folder contains scripts for converting models from +https://github.com/salute-developers/GigaAM +to sherpa-onnx. + +The ASR models are for Russian speech recognition in this folder. + +Please see the license of the models at +https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf diff --git a/scripts/nemo/GigaAM/export-onnx-ctc.py b/scripts/nemo/GigaAM/export-onnx-ctc.py new file mode 100755 index 000000000..fbcec518e --- /dev/null +++ b/scripts/nemo/GigaAM/export-onnx-ctc.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +from typing import Dict + +import onnx +import torch +import torchaudio +from nemo.collections.asr.models import EncDecCTCModel +from nemo.collections.asr.modules.audio_preprocessing import ( + AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor, +) +from nemo.collections.asr.parts.preprocessing.features import ( + FilterbankFeaturesTA as NeMoFilterbankFeaturesTA, +) +from onnxruntime.quantization import QuantType, quantize_dynamic + + +class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA): + def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs): + if "window_size" in kwargs: + del kwargs["window_size"] + if "window_stride" in kwargs: + del kwargs["window_stride"] + + super().__init__(**kwargs) + + self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = ( + torchaudio.transforms.MelSpectrogram( + sample_rate=self._sample_rate, + win_length=self.win_length, + hop_length=self.hop_length, + n_mels=kwargs["nfilt"], + window_fn=self.torch_windows[kwargs["window"]], + mel_scale=mel_scale, + norm=kwargs["mel_norm"], + n_fft=kwargs["n_fft"], + f_max=kwargs.get("highfreq", None), + f_min=kwargs.get("lowfreq", 0), + wkwargs=wkwargs, + ) + ) + + +class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor): + def __init__(self, mel_scale: str = "htk", **kwargs): + super().__init__(**kwargs) + kwargs["nfilt"] = kwargs["features"] + del kwargs["features"] + self.featurizer = ( + FilterbankFeaturesTA( # Deprecated arguments; kept for config compatibility + mel_scale=mel_scale, + **kwargs, + ) + ) + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +def main(): + model = EncDecCTCModel.from_config_file("./ctc_model_config.yaml") + ckpt = torch.load("./ctc_model_weights.ckpt", map_location="cpu") + model.load_state_dict(ckpt, strict=False) + model.eval() + + with open("tokens.txt", "w", encoding="utf-8") as f: + for i, t in enumerate(model.cfg.labels): + f.write(f"{t} {i}\n") + f.write(f" {i+1}\n") + + filename = "model.onnx" + model.export(filename) + + meta_data = { + "vocab_size": len(model.cfg.labels) + 1, + "normalize_type": "", + "subsampling_factor": 4, + "model_type": "EncDecCTCModel", + "version": "1", + "model_author": "https://github.com/salute-developers/GigaAM", + "license": "https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf", + "language": "Russian", + "is_giga_am": 1, + } + add_meta_data(filename, meta_data) + + filename_int8 = "model.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/nemo/GigaAM/run-ctc.sh b/scripts/nemo/GigaAM/run-ctc.sh new file mode 100755 index 000000000..499bc452b --- /dev/null +++ b/scripts/nemo/GigaAM/run-ctc.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +set -ex + +function install_nemo() { + curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py + python3 get-pip.py + + pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html + + pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa + pip install -qq ipython + + # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython + + BRANCH='main' + python3 -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr] + + pip install numpy==1.26.4 +} + +function download_files() { + curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_weights.ckpt + curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_config.yaml + curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav + curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM%20License_NC.pdf +} + +install_nemo +download_files + +python3 ./export-onnx-ctc.py +ls -lh +python3 ./test-onnx-ctc.py diff --git a/scripts/nemo/GigaAM/test-onnx-ctc.py b/scripts/nemo/GigaAM/test-onnx-ctc.py new file mode 100755 index 000000000..5c181e49a --- /dev/null +++ b/scripts/nemo/GigaAM/test-onnx-ctc.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +# https://github.com/salute-developers/GigaAM + +import kaldi_native_fbank as knf +import librosa +import numpy as np +import onnxruntime as ort +import soundfile as sf +import torch + + +def create_fbank(): + opts = knf.FbankOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.remove_dc_offset = False + opts.frame_opts.preemph_coeff = 0 + opts.frame_opts.window_type = "hann" + + # Even though GigaAM uses 400 for fft, here we use 512 + # since kaldi-native-fbank only support fft for power of 2. + opts.frame_opts.round_to_power_of_two = True + + opts.mel_opts.low_freq = 0 + opts.mel_opts.high_freq = 8000 + opts.mel_opts.num_bins = 64 + + fbank = knf.OnlineFbank(opts) + return fbank + + +def compute_features(audio, fbank) -> np.ndarray: + """ + Args: + audio: (num_samples,), np.float32 + fbank: the fbank extractor + Returns: + features: (num_frames, feat_dim), np.float32 + """ + assert len(audio.shape) == 1, audio.shape + fbank.accept_waveform(16000, audio) + ans = [] + processed = 0 + while processed < fbank.num_frames_ready: + ans.append(np.array(fbank.get_frame(processed))) + processed += 1 + ans = np.stack(ans) + return ans + + +def display(sess): + print("==========Input==========") + for i in sess.get_inputs(): + print(i) + print("==========Output==========") + for i in sess.get_outputs(): + print(i) + + +""" +==========Input========== +NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 64, 'audio_signal_dynamic_axes_2']) +NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1']) +==========Output========== +NodeArg(name='logprobs', type='tensor(float)', shape=['logprobs_dynamic_axes_1', 'logprobs_dynamic_axes_2', 34]) +""" + + +class OnnxModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.model = ort.InferenceSession( + filename, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + display(self.model) + + def __call__(self, x: np.ndarray): + # x: (T, C) + x = torch.from_numpy(x) + x = x.t().unsqueeze(0) + # x: [1, C, T] + x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64) + + log_probs = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + )[0] + # [batch_size, T, dim] + return log_probs + + +def main(): + filename = "./model.int8.onnx" + tokens = "./tokens.txt" + wav = "./example.wav" + + model = OnnxModel(filename) + + id2token = dict() + with open(tokens, encoding="utf-8") as f: + for line in f: + fields = line.split() + if len(fields) == 1: + id2token[int(fields[0])] = " " + else: + t, idx = fields + id2token[int(idx)] = t + + fbank = create_fbank() + audio, sample_rate = sf.read(wav, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + if sample_rate != 16000: + audio = librosa.resample( + audio, + orig_sr=sample_rate, + target_sr=16000, + ) + sample_rate = 16000 + + features = compute_features(audio, fbank) + print("features.shape", features.shape) + + blank = len(id2token) - 1 + prev = -1 + ans = [] + log_probs = model(features) + print("log_probs", log_probs.shape) + log_probs = torch.from_numpy(log_probs)[0] + ids = torch.argmax(log_probs, dim=1).tolist() + for i in ids: + if i != blank and i != prev: + ans.append(i) + prev = i + + tokens = [id2token[i] for i in ans] + + text = "".join(tokens) + print(wav) + print(text) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index ed806f392..5b50d5f24 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -193,6 +193,7 @@ class FeatureExtractor::Impl { opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; opts_.frame_opts.frame_length_ms = config_.frame_length_ms; opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; + opts_.frame_opts.preemph_coeff = config_.preemph_coeff; opts_.frame_opts.window_type = config_.window_type; opts_.mel_opts.num_bins = config_.feature_dim; @@ -211,6 +212,7 @@ class FeatureExtractor::Impl { mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms; mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; + mfcc_opts_.frame_opts.preemph_coeff = config_.preemph_coeff; mfcc_opts_.frame_opts.window_type = config_.window_type; mfcc_opts_.mel_opts.num_bins = config_.feature_dim; diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index afbacd2ec..fb5ff2fe3 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -57,6 +57,7 @@ struct FeatureExtractorConfig { float frame_length_ms = 25.0f; // in milliseconds. bool is_librosa = false; bool remove_dc_offset = true; // Subtract mean of wave before FFT. + float preemph_coeff = 0.97f; // Preemphasis coefficient. std::string window_type = "povey"; // e.g. Hamming window // For models from NeMo diff --git a/sherpa-onnx/csrc/jieba-lexicon.cc b/sherpa-onnx/csrc/jieba-lexicon.cc index e995c8cb8..a53f057f0 100644 --- a/sherpa-onnx/csrc/jieba-lexicon.cc +++ b/sherpa-onnx/csrc/jieba-lexicon.cc @@ -10,8 +10,8 @@ #include "cppjieba/Jieba.hpp" #include "sherpa-onnx/csrc/file-utils.h" -#include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index 24184f180..499f17f68 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -21,6 +21,7 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -74,45 +75,6 @@ static std::vector ProcessHeteronyms( return ans; } -// Note: We don't use SymbolTable here since tokens may contain a blank -// in the first column -std::unordered_map ReadTokens(std::istream &is) { - std::unordered_map token2id; - - std::string line; - - std::string sym; - int32_t id = -1; - while (std::getline(is, line)) { - std::istringstream iss(line); - iss >> sym; - if (iss.eof()) { - id = atoi(sym.c_str()); - sym = " "; - } else { - iss >> id; - } - - // eat the trailing \r\n on windows - iss >> std::ws; - if (!iss.eof()) { - SHERPA_ONNX_LOGE("Error: %s", line.c_str()); - exit(-1); - } - -#if 0 - if (token2id.count(sym)) { - SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", - sym.c_str(), line.c_str(), token2id.at(sym)); - exit(-1); - } -#endif - token2id.insert({std::move(sym), id}); - } - - return token2id; -} - std::vector ConvertTokensToIds( const std::unordered_map &token2id, const std::vector &tokens) { diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h index 9c2d91ee9..fb60cdb7f 100644 --- a/sherpa-onnx/csrc/lexicon.h +++ b/sherpa-onnx/csrc/lexicon.h @@ -67,12 +67,6 @@ class Lexicon : public OfflineTtsFrontend { bool debug_ = false; }; -std::unordered_map ReadTokens(std::istream &is); - -std::vector ConvertTokensToIds( - const std::unordered_map &token2id, - const std::vector &tokens); - } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_LEXICON_H_ diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h index 6bd6f62a6..521739a89 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -41,13 +41,13 @@ auto value = \ meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (!value) { \ - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ exit(-1); \ } \ \ dst = atoi(value.get()); \ if (dst < 0) { \ - SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ exit(-1); \ } \ } while (0) @@ -61,80 +61,80 @@ } else { \ dst = atoi(value.get()); \ if (dst < 0) { \ - SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ exit(-1); \ } \ } \ } while (0) // read a vector of integers -#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - \ - bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ - if (!ret) { \ - SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ + if (!ret) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \ + exit(-1); \ + } \ } while (0) // read a vector of floats -#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - \ - bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \ - if (!ret) { \ - SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \ + if (!ret) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \ + exit(-1); \ + } \ } while (0) // read a vector of strings -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - SplitStringToVector(value.get(), ",", false, &dst); \ - \ - if (dst.empty()) { \ - SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \ - src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + SplitStringToVector(value.get(), ",", false, &dst); \ + \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ + value.get(), src_key); \ + exit(-1); \ + } \ } while (0) // read a vector of strings separated by sep -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - SplitStringToVector(value.get(), sep, false, &dst); \ - \ - if (dst.empty()) { \ - SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \ - src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + SplitStringToVector(value.get(), sep, false, &dst); \ + \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ + value.get(), src_key); \ + exit(-1); \ + } \ } while (0) // Read a string @@ -143,17 +143,29 @@ auto value = \ meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ if (!value) { \ - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ exit(-1); \ } \ \ dst = value.get(); \ if (dst.empty()) { \ - SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \ + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ exit(-1); \ } \ } while (0) +#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + dst = value.get(); \ + } while (0) + #define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ default_value) \ do { \ @@ -164,7 +176,7 @@ } else { \ dst = value.get(); \ if (dst.empty()) { \ - SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \ + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ exit(-1); \ } \ } \ diff --git a/sherpa-onnx/csrc/melo-tts-lexicon.cc b/sherpa-onnx/csrc/melo-tts-lexicon.cc index d7332b36c..a605e6c38 100644 --- a/sherpa-onnx/csrc/melo-tts-lexicon.cc +++ b/sherpa-onnx/csrc/melo-tts-lexicon.cc @@ -10,8 +10,8 @@ #include "cppjieba/Jieba.hpp" #include "sherpa-onnx/csrc/file-utils.h" -#include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc index 66e67ecf1..9d1e05d9b 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -21,6 +21,7 @@ namespace { enum class ModelType : std::uint8_t { kEncDecCTCModelBPE, + kEncDecCTCModel, kEncDecHybridRNNTCTCBPEModel, kTdnn, kZipformerCtc, @@ -75,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, if (model_type.get() == std::string("EncDecCTCModelBPE")) { return ModelType::kEncDecCTCModelBPE; + } else if (model_type.get() == std::string("EncDecCTCModel")) { + return ModelType::kEncDecCTCModel; } else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) { return ModelType::kEncDecHybridRNNTCTCBPEModel; } else if (model_type.get() == std::string("tdnn")) { @@ -121,22 +124,18 @@ std::unique_ptr OfflineCtcModel::Create( switch (model_type) { case ModelType::kEncDecCTCModelBPE: return std::make_unique(config); - break; + case ModelType::kEncDecCTCModel: + return std::make_unique(config); case ModelType::kEncDecHybridRNNTCTCBPEModel: return std::make_unique(config); - break; case ModelType::kTdnn: return std::make_unique(config); - break; case ModelType::kZipformerCtc: return std::make_unique(config); - break; case ModelType::kWenetCtc: return std::make_unique(config); - break; case ModelType::kTeleSpeechCtc: return std::make_unique(config); - break; case ModelType::kUnknown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; @@ -177,23 +176,19 @@ std::unique_ptr OfflineCtcModel::Create( switch (model_type) { case ModelType::kEncDecCTCModelBPE: return std::make_unique(mgr, config); - break; + case ModelType::kEncDecCTCModel: + return std::make_unique(mgr, config); case ModelType::kEncDecHybridRNNTCTCBPEModel: return std::make_unique(mgr, config); - break; case ModelType::kTdnn: return std::make_unique(mgr, config); - break; case ModelType::kZipformerCtc: return std::make_unique(mgr, config); - break; case ModelType::kWenetCtc: return std::make_unique(mgr, config); - break; case ModelType::kTeleSpeechCtc: return std::make_unique(mgr, config); - break; case ModelType::kUnknown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; diff --git a/sherpa-onnx/csrc/offline-ctc-model.h b/sherpa-onnx/csrc/offline-ctc-model.h index f4c7406f6..ead532b48 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.h +++ b/sherpa-onnx/csrc/offline-ctc-model.h @@ -66,6 +66,10 @@ class OfflineCtcModel { // Return true if the model supports batch size > 1 virtual bool SupportBatchProcessing() const { return true; } + + // return true for models from https://github.com/salute-developers/GigaAM + // return false otherwise + virtual bool IsGigaAM() const { return false; } }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc index 2d790954b..708cb4b4f 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc @@ -72,6 +72,8 @@ class OfflineNemoEncDecCtcModel::Impl { std::string FeatureNormalizationMethod() const { return normalize_type_; } + bool IsGigaAM() const { return is_giga_am_; } + private: void Init(void *model_data, size_t model_data_length) { sess_ = std::make_unique(env_, model_data, model_data_length, @@ -92,7 +94,9 @@ class OfflineNemoEncDecCtcModel::Impl { Ort::AllocatorWithDefaultOptions allocator; // used in the macro below SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); - SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_, + "normalize_type"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); } private: @@ -112,6 +116,10 @@ class OfflineNemoEncDecCtcModel::Impl { int32_t vocab_size_ = 0; int32_t subsampling_factor_ = 0; std::string normalize_type_; + + // it is 1 for models from + // https://github.com/salute-developers/GigaAM + int32_t is_giga_am_ = 0; }; OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( @@ -146,4 +154,6 @@ std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const { return impl_->FeatureNormalizationMethod(); } +bool OfflineNemoEncDecCtcModel::IsGigaAM() const { return impl_->IsGigaAM(); } + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h index 6e1ba5855..c961e694e 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h @@ -76,6 +76,8 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel { // for details std::string FeatureNormalizationMethod() const override; + bool IsGigaAM() const override; + private: class Impl; std::unique_ptr impl_; diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 7bbe6938c..1199ff109 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -104,11 +104,20 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { } if (!config_.model_config.nemo_ctc.model.empty()) { - config_.feat_config.low_freq = 0; - config_.feat_config.high_freq = 0; - config_.feat_config.is_librosa = true; - config_.feat_config.remove_dc_offset = false; - config_.feat_config.window_type = "hann"; + if (model_->IsGigaAM()) { + config_.feat_config.low_freq = 0; + config_.feat_config.high_freq = 8000; + config_.feat_config.remove_dc_offset = false; + config_.feat_config.preemph_coeff = 0; + config_.feat_config.window_type = "hann"; + config_.feat_config.feature_dim = 64; + } else { + config_.feat_config.low_freq = 0; + config_.feat_config.high_freq = 0; + config_.feat_config.is_librosa = true; + config_.feat_config.remove_dc_offset = false; + config_.feat_config.window_type = "hann"; + } } if (!config_.model_config.wenet_ctc.model.empty()) { diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 5062968cc..a80301ebf 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -172,7 +172,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } - if (model_type == "EncDecCTCModelBPE" || + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" || model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || model_type == "telespeech_ctc") { @@ -189,6 +189,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - Non-streaming transducer models from icefall\n" " - Non-streaming Paraformer models from FunASR\n" " - EncDecCTCModelBPE models from NeMo\n" + " - EncDecCTCModel models from NeMo\n" " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" " - Whisper models\n" " - Tdnn models\n" @@ -343,7 +344,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(mgr, config); } - if (model_type == "EncDecCTCModelBPE" || + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" || model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || model_type == "telespeech_ctc") { @@ -360,6 +361,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - Non-streaming transducer models from icefall\n" " - Non-streaming Paraformer models from FunASR\n" " - EncDecCTCModelBPE models from NeMo\n" + " - EncDecCTCModel models from NeMo\n" " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" " - Whisper models\n" " - Tdnn models\n" diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 5655c03a8..a71225c38 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -7,6 +7,8 @@ #include #include #include +#include +#include #if __ANDROID_API__ >= 9 #include @@ -16,10 +18,54 @@ #endif #include "sherpa-onnx/csrc/base64-decode.h" +#include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { +std::unordered_map ReadTokens( + std::istream &is, + std::unordered_map *id2token /*= nullptr*/) { + std::unordered_map token2id; + + std::string line; + + std::string sym; + int32_t id = -1; + while (std::getline(is, line)) { + std::istringstream iss(line); + iss >> sym; + if (iss.eof()) { + id = atoi(sym.c_str()); + sym = " "; + } else { + iss >> id; + } + + // eat the trailing \r\n on windows + iss >> std::ws; + if (!iss.eof()) { + SHERPA_ONNX_LOGE("Error: %s", line.c_str()); + exit(-1); + } + +#if 0 + if (token2id.count(sym)) { + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", + sym.c_str(), line.c_str(), token2id.at(sym)); + exit(-1); + } +#endif + if (id2token) { + id2token->insert({id, sym}); + } + + token2id.insert({std::move(sym), id}); + } + + return token2id; +} + SymbolTable::SymbolTable(const std::string &filename, bool is_file) { if (is_file) { std::ifstream is(filename); @@ -39,25 +85,7 @@ SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) { } #endif -void SymbolTable::Init(std::istream &is) { - std::string sym; - int32_t id = 0; - while (is >> sym >> id) { -#if 0 - // we disable the test here since for some multi-lingual BPE models - // from NeMo, the same symbol can appear multiple times with different IDs. - if (sym != " ") { - assert(sym2id_.count(sym) == 0); - } -#endif - - assert(id2sym_.count(id) == 0); - - sym2id_.insert({sym, id}); - id2sym_.insert({id, sym}); - } - assert(is.eof()); -} +void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); } std::string SymbolTable::ToString() const { std::ostringstream os; diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index 2c17b4d5e..75a96144e 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -5,8 +5,10 @@ #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ +#include #include #include +#include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" @@ -15,6 +17,16 @@ namespace sherpa_onnx { +// The same token can be mapped to different integer IDs, so +// we need an id2token argument here. +std::unordered_map ReadTokens( + std::istream &is, + std::unordered_map *id2token = nullptr); + +std::vector ConvertTokensToIds( + const std::unordered_map &token2id, + const std::vector &tokens); + /// It manages mapping between symbols and integer IDs. class SymbolTable { public: diff --git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt index 203278cb7..10cdc5179 100644 --- a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -394,6 +394,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? { modelType = "transducer", ) } + + 19 -> { + val modelDir = "sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24" + return OfflineModelConfig( + nemo = OfflineNemoEncDecCtcModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } } return null }