diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh
index 87865464e..dec34101e 100755
--- a/.github/scripts/test-python.sh
+++ b/.github/scripts/test-python.sh
@@ -91,6 +91,18 @@ python3 ./python-api-examples/add-punctuation.py
 
 rm -rf $repo
 
+log "test online punctuation"
+
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
+tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
+rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
+repo=sherpa-onnx-online-punct-en-2024-08-06
+ls -lh $repo
+
+python3 ./python-api-examples/add-punctuation-online.py
+
+rm -rf $repo
+
 log "test audio tagging"
 
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
diff --git a/.gitignore b/.gitignore
index 727d42c32..bf8ca193f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -117,3 +117,4 @@ sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
 vits-melo-tts-zh_en
 *.o
 *.ppu
+sherpa-onnx-online-punct-en-2024-08-06
diff --git a/python-api-examples/add-punctuation-online.py b/python-api-examples/add-punctuation-online.py
new file mode 100755
index 000000000..a883b284c
--- /dev/null
+++ b/python-api-examples/add-punctuation-online.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+
+"""
+This script shows how to add punctuation (with case) to text using the
+online punctuation model through the sherpa-onnx Python API.
+
+Please download the model from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
+
+The following is an example
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
+tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
+rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
+"""
+
+from pathlib import Path
+
+import sherpa_onnx
+
+
+def main():
+    """Punctuate a few sample sentences and print each input/output pair.
+
+    Raises:
+        ValueError: if the model file or the BPE vocabulary file is missing.
+    """
+    model = "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx"
+    bpe = "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab"
+    # Fail early with a clear message when the model files are not present.
+    if not Path(model).is_file():
+        raise ValueError(f"{model} does not exist")
+    if not Path(bpe).is_file():
+        raise ValueError(f"{bpe} does not exist")
+
+    model_config = sherpa_onnx.OnlinePunctuationModelConfig(
+        cnn_bilstm=model, bpe_vocab=bpe
+    )
+    config = sherpa_onnx.OnlinePunctuationConfig(model_config=model_config)
+    punct = sherpa_onnx.OnlinePunctuation(config)
+
+    texts = [
+        "how are you i am fine thank you",
+        "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry",
+    ]
+    for text in texts:
+        text_with_punct = punct.add_punctuation_with_case(text)
+        print("----------")
+        print(f"input : {text}")
+        print(f"output: {text_with_punct}")
+        print("----------")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index fa5d32aff..a6edb5139 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -27,6 +27,7 @@ set(srcs
   online-model-config.cc
   online-nemo-ctc-model-config.cc
   online-paraformer-model-config.cc
+  online-punctuation.cc
   online-recognizer.cc
   online-stream.cc
   online-transducer-model-config.cc
diff --git a/sherpa-onnx/python/csrc/online-punctuation.cc b/sherpa-onnx/python/csrc/online-punctuation.cc
new file mode 100644
index 000000000..13aa66b64
--- /dev/null
+++ 
b/sherpa-onnx/python/csrc/online-punctuation.cc
@@ -0,0 +1,56 @@
+// sherpa-onnx/python/csrc/online-punctuation.cc
+//
+// Copyright (c) 2024
+
+#include "sherpa-onnx/python/csrc/online-punctuation.h"
+
+#include "sherpa-onnx/csrc/online-punctuation.h"
+
+namespace sherpa_onnx {
+
+// Exposes OnlinePunctuationModelConfig to Python: default and keyword
+// constructors, read/write fields, validate() and __str__.
+static void PybindOnlinePunctuationModelConfig(py::module *m) {
+  using PyClass = OnlinePunctuationModelConfig;
+  py::class_<PyClass>(*m, "OnlinePunctuationModelConfig")
+      .def(py::init<>())
+      .def(py::init<const std::string &, const std::string &, int32_t, bool, const std::string &>(),
+           py::arg("cnn_bilstm"), py::arg("bpe_vocab"), py::arg("num_threads") = 1,
+           py::arg("debug") = false, py::arg("provider") = "cpu")
+      .def_readwrite("cnn_bilstm", &PyClass::cnn_bilstm)
+      .def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
+      .def_readwrite("num_threads", &PyClass::num_threads)
+      .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
+      .def("validate", &PyClass::Validate)
+      .def("__str__", &PyClass::ToString);
+}
+
+// Exposes OnlinePunctuationConfig to Python. The model config class it holds
+// is registered first; the Python attribute "model_config" maps to `model`.
+static void PybindOnlinePunctuationConfig(py::module *m) {
+  PybindOnlinePunctuationModelConfig(m);
+  using PyClass = OnlinePunctuationConfig;
+
+  py::class_<PyClass>(*m, "OnlinePunctuationConfig")
+      .def(py::init<>())
+      .def(py::init<const OnlinePunctuationModelConfig &>(), py::arg("model_config"))
+      .def_readwrite("model_config", &PyClass::model)
+      .def("validate", &PyClass::Validate)
+      .def("__str__", &PyClass::ToString);
+}
+
+// Registers OnlinePunctuation and its config classes on module *m. Both the
+// constructor and add_punctuation_with_case release the GIL while they run.
+void PybindOnlinePunctuation(py::module *m) {
+  PybindOnlinePunctuationConfig(m);
+  using PyClass = OnlinePunctuation;
+
+  py::class_<PyClass>(*m, "OnlinePunctuation")
+      .def(py::init<const OnlinePunctuationConfig &>(), py::arg("config"),
+           py::call_guard<py::gil_scoped_release>())
+      .def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, py::arg("text"),
+           py::call_guard<py::gil_scoped_release>());
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/online-punctuation.h b/sherpa-onnx/python/csrc/online-punctuation.h
new file mode 100644
index 000000000..a6dbd1805
--- /dev/null
+++ b/sherpa-onnx/python/csrc/online-punctuation.h
@@ -0,0 +1,17 @@
+// sherpa-onnx/python/csrc/online-punctuation.h
+//
+// Copyright (c) 2024
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_
+#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+// Registers the online punctuation classes on the given Python module.
+void PybindOnlinePunctuation(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_
diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc
index 5b369ed84..0f04d4cad 100644
--- a/sherpa-onnx/python/csrc/sherpa-onnx.cc
+++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc
@@ -20,6 +20,7 @@
 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
 #include "sherpa-onnx/python/csrc/online-lm-config.h"
 #include "sherpa-onnx/python/csrc/online-model-config.h"
+#include "sherpa-onnx/python/csrc/online-punctuation.h"
 #include "sherpa-onnx/python/csrc/online-recognizer.h"
 #include "sherpa-onnx/python/csrc/online-stream.h"
 #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
@@ -42,6 +43,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
   PybindWaveWriter(&m);
   PybindAudioTagging(&m);
   PybindOfflinePunctuation(&m);
+  PybindOnlinePunctuation(&m);
 
   PybindFeatures(&m);
   PybindOnlineCtcFstDecoderConfig(&m);
diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py
index 7a832ba06..72560c42e 100644
--- a/sherpa-onnx/python/sherpa_onnx/__init__.py
+++ b/sherpa-onnx/python/sherpa_onnx/__init__.py
@@ -15,6 +15,9 @@
     OfflineTtsModelConfig,
     OfflineTtsVitsModelConfig,
     OfflineZipformerAudioTaggingModelConfig,
+    OnlinePunctuation,
+    OnlinePunctuationConfig,
+    OnlinePunctuationModelConfig,
     OnlineStream,
     SileroVadModelConfig,
     SpeakerEmbeddingExtractor,