Skip to content

Commit

Permalink
Support GigaAM CTC models for Russian ASR (#1464)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 25, 2024
1 parent 2b40079 commit b41f6d2
Show file tree
Hide file tree
Showing 24 changed files with 641 additions and 160 deletions.
15 changes: 15 additions & 0 deletions .github/scripts/test-offline-ctc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ echo "PATH: $PATH"

which $EXE

# NeMo GigaAM (Russian) CTC model: download the release archive, decode the
# bundled sample wave with the int8 model, then remove the extracted files.
log "------------------------------------------------------------"
log "Run NeMo GigaAM Russian models"
log "------------------------------------------------------------"
giga_am_dir=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/${giga_am_dir}.tar.bz2
tar xvf ${giga_am_dir}.tar.bz2
rm ${giga_am_dir}.tar.bz2

$EXE \
  --nemo-ctc-model=./${giga_am_dir}/model.int8.onnx \
  --tokens=./${giga_am_dir}/tokens.txt \
  --debug=1 \
  ./${giga_am_dir}/test_wavs/example.wav

rm -rf ${giga_am_dir}

log "------------------------------------------------------------"
log "Run SenseVoice models"
log "------------------------------------------------------------"
Expand Down
88 changes: 88 additions & 0 deletions .github/workflows/export-nemo-giga-am-to-onnx.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Exports the GigaAM CTC model to ONNX, uploads the packaged model to the
# `asr-models` release and publishes it to Hugging Face.
name: export-nemo-giga-am-to-onnx

# Started manually only.
on:
  workflow_dispatch:

concurrency:
  group: export-nemo-giga-am-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-am-giga-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export nemo GigaAM models to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        # Quoted so that it is not read as the float 3.1.
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run CTC
        shell: bash
        run: |
          pushd scripts/nemo/GigaAM
          ./run-ctc.sh
          popd

          # Collect everything except the unquantized model.onnx into the
          # directory that gets packaged.
          d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24
          mkdir $d
          mkdir $d/test_wavs
          rm scripts/nemo/GigaAM/model.onnx
          mv -v scripts/nemo/GigaAM/*.int8.onnx $d/
          mv -v scripts/nemo/GigaAM/*.md $d/
          mv -v scripts/nemo/GigaAM/*.pdf $d/
          mv -v scripts/nemo/GigaAM/tokens.txt $d/
          mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
          mv -v scripts/nemo/GigaAM/run-ctc.sh $d/
          mv -v scripts/nemo/GigaAM/*-ctc.py $d/

          ls -lh scripts/nemo/GigaAM/
          ls -lh $d

          tar cjvf ${d}.tar.bz2 $d

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models

      - name: Publish to huggingface (CTC)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            # NOTE(review): the address below looks like an e-mail obfuscation
            # placeholder left by the page capture; confirm the real committer
            # address before using this file.
            git config --global user.email "[email protected]"
            git config --global user.name "Fangjun Kuang"

            d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24

            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            # This command may run several times (see max_attempts), so start
            # every attempt from a clean checkout and copy instead of move so
            # that the source files are still there on the next attempt.
            rm -rf huggingface
            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
            cp -av $d/* ./huggingface

            cd huggingface
            git lfs track "*.onnx"
            git lfs track "*.wav"
            git status
            git add .
            git status
            git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
20 changes: 10 additions & 10 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*

- name: Test offline CTC
  shell: bash
  run: |
    # Print disk usage before and after the test script.
    du -h -d1 .

    # The test script locates the binary through $EXE on PATH.
    export PATH=$PWD/build/bin:$PATH
    export EXE=sherpa-onnx-offline

    .github/scripts/test-offline-ctc.sh

    du -h -d1 .

- name: Test C++ API
shell: bash
run: |
Expand Down Expand Up @@ -180,16 +190,6 @@ jobs:
.github/scripts/test-offline-transducer.sh
du -h -d1 .
- name: Test offline CTC
  shell: bash
  run: |
    # Print disk usage before and after the test script.
    du -h -d1 .

    # The test script locates the binary through $EXE on PATH.
    export PATH=$PWD/build/bin:$PATH
    export EXE=sherpa-onnx-offline

    .github/scripts/test-offline-ctc.sh

    du -h -d1 .

- name: Test online punctuation
shell: bash
run: |
Expand Down
18 changes: 18 additions & 0 deletions scripts/apk/generate-vad-asr-apk-script.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,24 @@ def get_models():
ls -lh
popd
""",
),
Model(
model_name="sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24",
idx=19,
lang="ru",
short_name="nemo_ctc_giga_am",
cmd="""
pushd $model_name
rm -rfv test_wavs
rm -fv *.sh
rm -fv *.py
ls -lh
popd
""",
),
Expand Down
10 changes: 10 additions & 0 deletions scripts/nemo/GigaAM/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Introduction

This folder contains scripts for converting models from
https://github.com/salute-developers/GigaAM
to sherpa-onnx.

The ASR models in this folder are for Russian speech recognition.

Please see the license of the models at
https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf
114 changes: 114 additions & 0 deletions scripts/nemo/GigaAM/export-onnx-ctc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
from typing import Dict

import onnx
import torch
import torchaudio
from nemo.collections.asr.models import EncDecCTCModel
from nemo.collections.asr.modules.audio_preprocessing import (
AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor,
)
from nemo.collections.asr.parts.preprocessing.features import (
FilterbankFeaturesTA as NeMoFilterbankFeaturesTA,
)
from onnxruntime.quantization import QuantType, quantize_dynamic


class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA):
    """NeMo torchaudio filterbank featurizer with a configurable mel scale.

    After the parent constructor has run, ``self._mel_spec_extractor`` is
    assigned a ``torchaudio.transforms.MelSpectrogram`` built from the values
    the parent stored on ``self`` plus ``mel_scale`` and ``wkwargs``.

    Args:
      mel_scale:
        Forwarded to ``MelSpectrogram`` as ``mel_scale``.
      wkwargs:
        Forwarded to ``MelSpectrogram`` as ``wkwargs``.
      **kwargs:
        Options for the parent featurizer. ``window_size`` and
        ``window_stride`` are removed before they are passed on. ``nfilt``,
        ``window``, ``mel_norm`` and ``n_fft`` must be present.
    """

    def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs):
        # The parent constructor is called without these two options.
        for dropped in ("window_size", "window_stride"):
            kwargs.pop(dropped, None)

        super().__init__(**kwargs)

        extractor_options = dict(
            sample_rate=self._sample_rate,
            win_length=self.win_length,
            hop_length=self.hop_length,
            n_mels=kwargs["nfilt"],
            window_fn=self.torch_windows[kwargs["window"]],
            mel_scale=mel_scale,
            norm=kwargs["mel_norm"],
            n_fft=kwargs["n_fft"],
            f_max=kwargs.get("highfreq", None),
            f_min=kwargs.get("lowfreq", 0),
            wkwargs=wkwargs,
        )
        self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = (
            torchaudio.transforms.MelSpectrogram(**extractor_options)
        )


class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor):
    """NeMo mel-spectrogram preprocessor that uses ``FilterbankFeaturesTA``.

    Args:
      mel_scale:
        Forwarded to ``FilterbankFeaturesTA``.
      **kwargs:
        Options for the parent preprocessor. ``features`` is required; it is
        handed to the featurizer under the name ``nfilt``.
    """

    def __init__(self, mel_scale: str = "htk", **kwargs):
        super().__init__(**kwargs)

        # The featurizer takes the same options, except that the value given
        # as "features" is passed to it as "nfilt".
        kwargs["nfilt"] = kwargs.pop("features")

        # Deprecated arguments; kept for config compatibility
        featurizer = FilterbankFeaturesTA(mel_scale=mel_scale, **kwargs)
        self.featurizer = featurizer


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Replace the metadata of an ONNX model file. The file is rewritten in-place.

    Every metadata entry already stored in the model is removed, then one
    entry is added per item of ``meta_data`` and the model is saved back to
    ``filename``.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs to store. Each value is converted with ``str()``.
    """
    model = onnx.load(filename)
    props = model.metadata_props

    # Drop whatever metadata the model already carries.
    while len(props) > 0:
        props.pop()

    for key in meta_data:
        entry = props.add()
        entry.key = key
        entry.value = str(meta_data[key])

    onnx.save(model, filename)


def main():
    """Export the GigaAM CTC model to ONNX.

    Reads ``./ctc_model_config.yaml`` and ``./ctc_model_weights.ckpt`` from the
    current directory and writes, also to the current directory:

      - ``tokens.txt``: one ``<token> <id>`` line per label, followed by a
        ``<blk>`` line whose id is the number of labels
      - ``model.onnx``: the exported model with metadata attached
      - ``model.int8.onnx``: the dynamically quantized model
    """
    model = EncDecCTCModel.from_config_file("./ctc_model_config.yaml")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only run this on a checkpoint obtained from a trusted source.
    ckpt = torch.load("./ctc_model_weights.ckpt", map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    model.eval()

    labels = model.cfg.labels
    # <blk> takes the id right after the last label. It is computed from the
    # label count instead of the loop variable so that it is defined even when
    # the label list is empty.
    blank_id = len(labels)

    with open("tokens.txt", "w", encoding="utf-8") as f:
        for i, t in enumerate(labels):
            f.write(f"{t} {i}\n")
        f.write(f"<blk> {blank_id}\n")

    filename = "model.onnx"
    model.export(filename)

    meta_data = {
        # All labels plus <blk>.
        "vocab_size": blank_id + 1,
        "normalize_type": "",
        "subsampling_factor": 4,
        "model_type": "EncDecCTCModel",
        "version": "1",
        "model_author": "https://github.com/salute-developers/GigaAM",
        "license": "https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf",
        "language": "Russian",
        "is_giga_am": 1,
    }
    add_meta_data(filename, meta_data)

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )


if __name__ == "__main__":
    main()
36 changes: 36 additions & 0 deletions scripts/nemo/GigaAM/run-ctc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex

# Install pip, PyTorch/torchaudio, the Python packages used by the export and
# test scripts, and NeMo (ASR extras) from its git repository.
function install_nemo() {
  curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
  python3 get-pip.py

  pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html

  # "matplotlib>=3.3.2" must be quoted. Unquoted, the shell parses ">" as an
  # output redirection: it would create a file named "=3.3.2" and pass only
  # "matplotlib" (no version constraint) to pip.
  pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa
  pip install -qq ipython

  # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython

  BRANCH='main'
  # Quoted so that "[asr]" reaches pip literally instead of being treated as a
  # glob pattern by the shell.
  python3 -m pip install "git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]"

  # NOTE(review): presumably pinned last to override whatever numpy version
  # the installs above selected -- confirm.
  pip install numpy==1.26.4
}

# Download the GigaAM CTC checkpoint, its config, two example waves and the
# model license into the current directory.
function download_files() {
  local giga_am_url=https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM
  local name

  # -f makes curl exit non-zero on an HTTP error. Without it, an error page
  # would be saved under the expected file name and curl would report success,
  # so "set -e" could not stop the script at the failed download.
  for name in ctc_model_weights.ckpt ctc_model_config.yaml example.wav long_example.wav; do
    curl -fSL -O "$giga_am_url/$name"
  done

  curl -fSL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM%20License_NC.pdf
}

# Install the Python dependencies, then download the checkpoint, config,
# example waves and license.
install_nemo
download_files

# Run the export script, list the files in the current directory, then run
# test-onnx-ctc.py.
python3 ./export-onnx-ctc.py
ls -lh
python3 ./test-onnx-ctc.py
Loading

0 comments on commit b41f6d2

Please sign in to comment.