diff --git a/.github/workflows/export-nemo-giga-am-to-onnx.yaml b/.github/workflows/export-nemo-giga-am-to-onnx.yaml index f48c344e6..1af754d0b 100644 --- a/.github/workflows/export-nemo-giga-am-to-onnx.yaml +++ b/.github/workflows/export-nemo-giga-am-to-onnx.yaml @@ -38,7 +38,7 @@ jobs: mkdir $d/test_wavs rm scripts/nemo/GigaAM/model.onnx mv -v scripts/nemo/GigaAM/*.int8.onnx $d/ - mv -v scripts/nemo/GigaAM/*.md $d/ + cp -v scripts/nemo/GigaAM/*.md $d/ mv -v scripts/nemo/GigaAM/*.pdf $d/ mv -v scripts/nemo/GigaAM/tokens.txt $d/ mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ @@ -51,6 +51,34 @@ jobs: tar cjvf ${d}.tar.bz2 $d + - name: Run Transducer + shell: bash + run: | + pushd scripts/nemo/GigaAM + ./run-rnnt.sh + popd + + d=sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24 + mkdir $d + mkdir $d/test_wavs + + mv -v scripts/nemo/GigaAM/encoder.int8.onnx $d/ + mv -v scripts/nemo/GigaAM/decoder.onnx $d/ + mv -v scripts/nemo/GigaAM/joiner.onnx $d/ + + cp -v scripts/nemo/GigaAM/*.md $d/ + mv -v scripts/nemo/GigaAM/*.pdf $d/ + mv -v scripts/nemo/GigaAM/tokens.txt $d/ + mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/ + mv -v scripts/nemo/GigaAM/run-rnnt.sh $d/ + mv -v scripts/nemo/GigaAM/*-rnnt.py $d/ + + ls -lh scripts/nemo/GigaAM/ + + ls -lh $d + + tar cjvf ${d}.tar.bz2 $d + - name: Release uses: svenstaro/upload-release-action@v2 with: @@ -61,7 +89,7 @@ jobs: repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} tag: asr-models - - name: Publish to huggingface (CTC) + - name: Publish to huggingface (Transducer) env: HF_TOKEN: ${{ secrets.HF_TOKEN }} uses: nick-fields/retry@v3 @@ -73,7 +101,7 @@ jobs: git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" - d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24 + d=sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24/ export GIT_LFS_SKIP_SMUDGE=1 export GIT_CLONE_PROTECTION_ACTIVE=false git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface diff --git a/scripts/apk/generate-vad-asr-apk-script.py b/scripts/apk/generate-vad-asr-apk-script.py index 8217e6ea0..7671e975d 100755 --- a/scripts/apk/generate-vad-asr-apk-script.py +++ b/scripts/apk/generate-vad-asr-apk-script.py @@ -351,6 +351,24 @@ def get_models(): ls -lh + popd + """, + ), + Model( + model_name="sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24", + idx=20, + lang="ru", + short_name="nemo_transducer_giga_am", + cmd=""" + pushd $model_name + + rm -rfv test_wavs + + rm -fv *.sh + rm -fv *.py + + ls -lh + popd """, ), diff --git a/scripts/nemo/GigaAM/export-onnx-ctc.py b/scripts/nemo/GigaAM/export-onnx-ctc.py index fbcec518e..81feb3b78 100755 --- a/scripts/nemo/GigaAM/export-onnx-ctc.py +++ b/scripts/nemo/GigaAM/export-onnx-ctc.py @@ -75,6 +75,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): onnx.save(model, filename) +@torch.no_grad() def main(): model = EncDecCTCModel.from_config_file("./ctc_model_config.yaml") ckpt = torch.load("./ctc_model_weights.ckpt", map_location="cpu") diff --git a/scripts/nemo/GigaAM/export-onnx-rnnt.py b/scripts/nemo/GigaAM/export-onnx-rnnt.py new file mode 100644 index 000000000..1ac05ff7f --- /dev/null +++ b/scripts/nemo/GigaAM/export-onnx-rnnt.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +from typing import Dict + +import onnx +import torch +import torchaudio +from nemo.collections.asr.models import EncDecRNNTBPEModel +from nemo.collections.asr.modules.audio_preprocessing import ( + AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor, +) +from nemo.collections.asr.parts.preprocessing.features import ( + FilterbankFeaturesTA as NeMoFilterbankFeaturesTA, +) +from onnxruntime.quantization import QuantType, quantize_dynamic + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA): + def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs): + if "window_size" in kwargs: + del kwargs["window_size"] + if "window_stride" in kwargs: + del kwargs["window_stride"] + + super().__init__(**kwargs) + + self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = ( + torchaudio.transforms.MelSpectrogram( + sample_rate=self._sample_rate, + win_length=self.win_length, + hop_length=self.hop_length, + n_mels=kwargs["nfilt"], + window_fn=self.torch_windows[kwargs["window"]], + mel_scale=mel_scale, + norm=kwargs["mel_norm"], + n_fft=kwargs["n_fft"], + f_max=kwargs.get("highfreq", None), + f_min=kwargs.get("lowfreq", 0), + wkwargs=wkwargs, + ) + ) + + +class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor): + def __init__(self, mel_scale: str = "htk", **kwargs): + super().__init__(**kwargs) + kwargs["nfilt"] = kwargs["features"] + del kwargs["features"] + self.featurizer = ( + FilterbankFeaturesTA( # Deprecated arguments; kept for config compatibility + mel_scale=mel_scale, + **kwargs, + ) + ) + + +@torch.no_grad() +def main(): + model = EncDecRNNTBPEModel.from_config_file("./rnnt_model_config.yaml") + ckpt = torch.load("./rnnt_model_weights.ckpt", map_location="cpu") + model.load_state_dict(ckpt, strict=False) + model.eval() + + with open("./tokens.txt", "w", encoding="utf-8") as f: + for i, s in enumerate(model.joint.vocabulary): + f.write(f"{s} {i}\n") + f.write(f" {i+1}\n") + print("Saved to tokens.txt") + + model.encoder.export("encoder.onnx") + model.decoder.export("decoder.onnx") + model.joint.export("joiner.onnx") + + meta_data = { + "vocab_size": model.decoder.vocab_size, # not including the blank + "pred_rnn_layers": model.decoder.pred_rnn_layers, + "pred_hidden": model.decoder.pred_hidden, + "normalize_type": "", + "subsampling_factor": 4, + "model_type": "EncDecRNNTBPEModel", + "version": "1", + "model_author": "https://github.com/salute-developers/GigaAM", + "license": "https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf", + "language": "Russian", + "is_giga_am": 1, + } + add_meta_data("encoder.onnx", meta_data) + + quantize_dynamic( + model_input="encoder.onnx", + model_output="encoder.int8.onnx", + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/nemo/GigaAM/run-ctc.sh b/scripts/nemo/GigaAM/run-ctc.sh index 499bc452b..03acc88e2 100755 --- a/scripts/nemo/GigaAM/run-ctc.sh +++ b/scripts/nemo/GigaAM/run-ctc.sh @@ -21,11 +21,15 @@ function install_nemo() { } function download_files() { - curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_weights.ckpt - curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_config.yaml - curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav - curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav - curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM%20License_NC.pdf + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_weights.ckpt + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/ctc_model_config.yaml + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/ctc/ctc_model_weights.ckpt + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/ctc/ctc_model_config.yaml + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/example.wav + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/long_example.wav + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/GigaAM%20License_NC.pdf } install_nemo diff --git a/scripts/nemo/GigaAM/run-rnnt.sh b/scripts/nemo/GigaAM/run-rnnt.sh new file mode 100755 index 000000000..209f4f15d --- /dev/null +++ b/scripts/nemo/GigaAM/run-rnnt.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +set -ex + +function install_nemo() { + curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py + python3 get-pip.py + + pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html + + pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa + pip install -qq ipython + + # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython + + BRANCH='main' + python3 -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr] + + pip install numpy==1.26.4 +} + +function download_files() { + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/rnnt_model_weights.ckpt + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/rnnt_model_config.yaml + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav + # curl -SL -O https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/tokenizer_all_sets.tar + + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/rnnt/rnnt_model_weights.ckpt + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/rnnt/rnnt_model_config.yaml + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/example.wav + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/long_example.wav + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/GigaAM%20License_NC.pdf + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/rnnt/tokenizer_all_sets.tar + tar -xf tokenizer_all_sets.tar && rm tokenizer_all_sets.tar + ls -lh + echo "---" + ls -lh tokenizer_all_sets + echo "---" +} + +install_nemo +download_files + +python3 ./export-onnx-rnnt.py +ls -lh +python3 ./test-onnx-rnnt.py +rm -v encoder.onnx +ls -lh diff --git a/scripts/nemo/GigaAM/test-onnx-rnnt.py b/scripts/nemo/GigaAM/test-onnx-rnnt.py new file mode 100755 index 000000000..85c6a5e94 --- /dev/null +++ b/scripts/nemo/GigaAM/test-onnx-rnnt.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import argparse +from pathlib import Path + +import kaldi_native_fbank as knf +import librosa +import numpy as np +import onnxruntime as ort +import soundfile as sf +import torch + + +def create_fbank(): + opts = knf.FbankOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.remove_dc_offset = False + opts.frame_opts.preemph_coeff = 0 + opts.frame_opts.window_type = "hann" + + # Even though GigaAM uses 400 for fft, here we use 512 + # since kaldi-native-fbank only support fft for power of 2. + opts.frame_opts.round_to_power_of_two = True + + opts.mel_opts.low_freq = 0 + opts.mel_opts.high_freq = 8000 + opts.mel_opts.num_bins = 64 + + fbank = knf.OnlineFbank(opts) + return fbank + + +def compute_features(audio, fbank): + assert len(audio.shape) == 1, audio.shape + fbank.accept_waveform(16000, audio) + ans = [] + processed = 0 + while processed < fbank.num_frames_ready: + ans.append(np.array(fbank.get_frame(processed))) + processed += 1 + ans = np.stack(ans) + return ans + + +def display(sess): + print("==========Input==========") + for i in sess.get_inputs(): + print(i) + print("==========Output==========") + for i in sess.get_outputs(): + print(i) + + +""" +==========Input========== +NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 64, 'audio_signal_dynamic_axes_2']) +NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1']) +==========Output========== +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 768, 'outputs_dynamic_axes_2']) +NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1']) +==========Input========== +NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2']) +NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1']) +NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 320]) +NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 320]) +==========Output========== +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 320, 'outputs_dynamic_axes_2']) +NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1']) +NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 320]) +NodeArg(name='74', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 320]) +==========Input========== +NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 768, 'encoder_outputs_dynamic_axes_2']) +NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 320, 'decoder_outputs_dynamic_axes_2']) +==========Output========== +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 513]) +""" + + +class OnnxModel: + def __init__( + self, + encoder: str, + decoder: str, + joiner: str, + ): + self.init_encoder(encoder) + display(self.encoder) + self.init_decoder(decoder) + display(self.decoder) + self.init_joiner(joiner) + display(self.joiner) + + def init_encoder(self, encoder): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.encoder = ort.InferenceSession( + encoder, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + meta = self.encoder.get_modelmeta().custom_metadata_map + self.normalize_type = meta["normalize_type"] + print(meta) + + self.pred_rnn_layers = int(meta["pred_rnn_layers"]) + self.pred_hidden = int(meta["pred_hidden"]) + + def init_decoder(self, decoder): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.decoder = ort.InferenceSession( + decoder, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + def init_joiner(self, joiner): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.joiner = ort.InferenceSession( + joiner, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + def get_decoder_state(self): + batch_size = 1 + state0 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy() + state1 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy() + return state0, state1 + + def run_encoder(self, x: np.ndarray): + # x: (T, C) + x = torch.from_numpy(x) + x = x.t().unsqueeze(0) + # x: [1, C, T] + x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64) + + (encoder_out, out_len) = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + # [batch_size, dim, T] + return encoder_out + + def run_decoder( + self, + token: int, + state0: np.ndarray, + state1: np.ndarray, + ): + target = torch.tensor([[token]], dtype=torch.int32).numpy() + target_len = torch.tensor([1], dtype=torch.int32).numpy() + + ( + decoder_out, + decoder_out_length, + state0_next, + state1_next, + ) = self.decoder.run( + [ + self.decoder.get_outputs()[0].name, + self.decoder.get_outputs()[1].name, + self.decoder.get_outputs()[2].name, + self.decoder.get_outputs()[3].name, + ], + { + self.decoder.get_inputs()[0].name: target, + self.decoder.get_inputs()[1].name: target_len, + self.decoder.get_inputs()[2].name: state0, + self.decoder.get_inputs()[3].name: state1, + }, + ) + return decoder_out, state0_next, state1_next + + def run_joiner( + self, + encoder_out: np.ndarray, + decoder_out: np.ndarray, + ): + # encoder_out: [batch_size, dim, 1] + # decoder_out: [batch_size, dim, 1] + logit = self.joiner.run( + [ + self.joiner.get_outputs()[0].name, + ], + { + self.joiner.get_inputs()[0].name: encoder_out, + self.joiner.get_inputs()[1].name: decoder_out, + }, + )[0] + # logit: [batch_size, 1, 1, vocab_size] + return logit + + +def main(): + model = OnnxModel("encoder.int8.onnx", "decoder.onnx", "joiner.onnx") + + id2token = dict() + with open("./tokens.txt", encoding="utf-8") as f: + for line in f: + t, idx = line.split() + id2token[int(idx)] = t + + fbank = create_fbank() + audio, sample_rate = sf.read("./example.wav", dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + if sample_rate != 16000: + audio = librosa.resample( + audio, + orig_sr=sample_rate, + target_sr=16000, + ) + sample_rate = 16000 + + tail_padding = np.zeros(sample_rate * 2) + + audio = np.concatenate([audio, tail_padding]) + + blank = len(id2token) - 1 + ans = [blank] + state0, state1 = model.get_decoder_state() + decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1) + + features = compute_features(audio, fbank) + print("audio.shape", audio.shape) + print("features.shape", features.shape) + + encoder_out = model.run_encoder(features) + # encoder_out:[batch_size, dim, T) + for t in range(encoder_out.shape[2]): + encoder_out_t = encoder_out[:, :, t : t + 1] + logits = model.run_joiner(encoder_out_t, decoder_out) + logits = torch.from_numpy(logits) + logits = logits.squeeze() + idx = torch.argmax(logits, dim=-1).item() + if idx != blank: + ans.append(idx) + state0 = state0_next + state1 = state1_next + decoder_out, state0_next, state1_next = model.run_decoder( + ans[-1], state0, state1 + ) + + ans = ans[1:] # remove the first blank + print(ans) + tokens = [id2token[i] for i in ans] + underline = "▁" + # underline = b"\xe2\x96\x81".decode() + text = "".join(tokens).replace(underline, " ").strip() + print("./example.wav") + print(text) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index a80301ebf..f6c6e247c 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -166,7 +166,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } - if (model_type == "EncDecHybridRNNTCTCBPEModel" && + if ((model_type == "EncDecHybridRNNTCTCBPEModel" || + model_type == "EncDecRNNTBPEModel") && !config.model_config.transducer.decoder_filename.empty() && !config.model_config.transducer.joiner_filename.empty()) { return std::make_unique(config); @@ -191,6 +192,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - EncDecCTCModelBPE models from NeMo\n" " - EncDecCTCModel models from NeMo\n" " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" + " - EncDecRNNTBPEModel models from NeMO" " - Whisper models\n" " - Tdnn models\n" " - Zipformer CTC models\n" @@ -338,7 +340,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(mgr, config); } - if (model_type == "EncDecHybridRNNTCTCBPEModel" && + if ((model_type == "EncDecHybridRNNTCTCBPEModel" || + model_type == "EncDecRNNTBPEModel") && !config.model_config.transducer.decoder_filename.empty() && !config.model_config.transducer.joiner_filename.empty()) { return std::make_unique(mgr, config); @@ -363,6 +366,7 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - EncDecCTCModelBPE models from NeMo\n" " - EncDecCTCModel models from NeMo\n" " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" + " - EncDecRNNTBPEModel models from NeMo\n" " - Whisper models\n" " - Tdnn models\n" " - Zipformer CTC models\n" diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h index 2f5b9e2a2..6727b0983 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h @@ -139,23 +139,29 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { } } - OfflineRecognizerConfig GetConfig() const override { - return config_; - } + OfflineRecognizerConfig GetConfig() const override { return config_; } private: void PostInit() { config_.feat_config.nemo_normalize_type = model_->FeatureNormalizationMethod(); - config_.feat_config.low_freq = 0; - // config_.feat_config.high_freq = 8000; - config_.feat_config.is_librosa = true; - config_.feat_config.remove_dc_offset = false; - // config_.feat_config.window_type = "hann"; config_.feat_config.dither = 0; - config_.feat_config.nemo_normalize_type = - model_->FeatureNormalizationMethod(); + + if (model_->IsGigaAM()) { + config_.feat_config.low_freq = 0; + config_.feat_config.high_freq = 8000; + config_.feat_config.remove_dc_offset = false; + config_.feat_config.preemph_coeff = 0; + config_.feat_config.window_type = "hann"; + config_.feat_config.feature_dim = 64; + } else { + config_.feat_config.low_freq = 0; + // config_.feat_config.high_freq = 8000; + config_.feat_config.is_librosa = true; + config_.feat_config.remove_dc_offset = false; + // config_.feat_config.window_type = "hann"; + } int32_t vocab_size = model_->VocabSize(); diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc index f18e57da9..5332a835e 100644 --- a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc @@ -153,6 +153,8 @@ class OfflineTransducerNeMoModel::Impl { std::string FeatureNormalizationMethod() const { return normalize_type_; } + bool IsGigaAM() const { return is_giga_am_; } + private: void InitEncoder(void *model_data, size_t model_data_length) { encoder_sess_ = std::make_unique( @@ -181,9 +183,11 @@ class OfflineTransducerNeMoModel::Impl { vocab_size_ += 1; SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); - SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_, + "normalize_type"); SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); if (normalize_type_ == "NA") { normalize_type_ = ""; @@ -245,6 +249,7 @@ class OfflineTransducerNeMoModel::Impl { std::string normalize_type_; int32_t pred_rnn_layers_ = -1; int32_t pred_hidden_ = -1; + int32_t is_giga_am_ = 0; }; OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( @@ -298,4 +303,6 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { return impl_->FeatureNormalizationMethod(); } +bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.h b/sherpa-onnx/csrc/offline-transducer-nemo-model.h index 9ac135916..e4017a4c4 100644 --- a/sherpa-onnx/csrc/offline-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/offline-transducer-nemo-model.h @@ -93,6 +93,8 @@ class OfflineTransducerNeMoModel { // for details std::string FeatureNormalizationMethod() const; + bool IsGigaAM() const; + private: class Impl; std::unique_ptr impl_; diff --git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt index 10cdc5179..b82436356 100644 --- a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -404,6 +404,19 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? { tokens = "$modelDir/tokens.txt", ) } + + 20 -> { + val modelDir = "sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder.int8.onnx", + decoder = "$modelDir/decoder.onnx", + joiner = "$modelDir/joiner.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "nemo_transducer", + ) + } } return null }