Skip to content

Commit

Permalink
Add Python binding for online punctuation models (#1312)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochie authored Sep 9, 2024
1 parent 857cb50 commit 3bffc24
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 0 deletions.
12 changes: 12 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ python3 ./python-api-examples/add-punctuation.py

rm -rf $repo

log "test online punctuation"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
repo=sherpa-onnx-online-punct-en-2024-08-06
ls -lh $repo

python3 ./python-api-examples/add-punctuation-online.py

rm -rf $repo

log "test audio tagging"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,4 @@ sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
vits-melo-tts-zh_en
*.o
*.ppu
sherpa-onnx-online-punct-en-2024-08-06
48 changes: 48 additions & 0 deletions python-api-examples/add-punctuation-online.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

"""
This script shows how to add punctuations to text using sherpa-onnx Python API.
Please download the model from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
"""

from pathlib import Path

import sherpa_onnx


def main():
model = "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx"
bpe = "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab"
if not Path(model).is_file():
raise ValueError(f"{model} does not exist")
if not Path(bpe).is_file():
raise ValueError(f"{bpe} does not exist")

model_config = sherpa_onnx.OnlinePunctuationModelConfig(
cnn_bilstm=model, bpe_vocab=bpe
)
config = sherpa_onnx.OnlinePunctuationConfig(model_config=model_config)
punct = sherpa_onnx.OnlinePunctuation(config)

texts = [
"how are you i am fine thank you",
"The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry",
]
for text in texts:
text_with_punct = punct.add_punctuation_with_case(text)
print("----------")
print(f"input : {text}")
print(f"output: {text_with_punct}")
print("----------")


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions sherpa-onnx/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ set(srcs
online-model-config.cc
online-nemo-ctc-model-config.cc
online-paraformer-model-config.cc
online-punctuation.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc
Expand Down
50 changes: 50 additions & 0 deletions sherpa-onnx/python/csrc/online-punctuation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// sherpa-onnx/python/csrc/online-punctuation.cc
//
// Copyright (c) 2024

#include "sherpa-onnx/python/csrc/online-punctuation.h"

#include "sherpa-onnx/csrc/online-punctuation.h"

namespace sherpa_onnx {

static void PybindOnlinePunctuationModelConfig(py::module *m) {
using PyClass = OnlinePunctuationModelConfig;
py::class_<PyClass>(*m, "OnlinePunctuationModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &, int32_t, bool, const std::string &>(),
py::arg("cnn_bilstm"), py::arg("bpe_vocab"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("cnn_bilstm", &PyClass::cnn_bilstm)
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}

static void PybindOnlinePunctuationConfig(py::module *m) {
PybindOnlinePunctuationModelConfig(m);
using PyClass = OnlinePunctuationConfig;

py::class_<PyClass>(*m, "OnlinePunctuationConfig")
.def(py::init<>())
.def(py::init<const OnlinePunctuationModelConfig &>(), py::arg("model_config"))
.def_readwrite("model_config", &PyClass::model)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}

void PybindOnlinePunctuation(py::module *m) {
PybindOnlinePunctuationConfig(m);
using PyClass = OnlinePunctuation;

py::class_<PyClass>(*m, "OnlinePunctuation")
.def(py::init<const OnlinePunctuationConfig &>(), py::arg("config"),
py::call_guard<py::gil_scoped_release>())
.def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, py::arg("text"),
py::call_guard<py::gil_scoped_release>());
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/online-punctuation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/online-punctuation.h
//
// Copyright (c) 2024

#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOnlinePunctuation(py::module *m);

}

#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_
2 changes: 2 additions & 0 deletions sherpa-onnx/python/csrc/sherpa-onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
#include "sherpa-onnx/python/csrc/online-lm-config.h"
#include "sherpa-onnx/python/csrc/online-model-config.h"
#include "sherpa-onnx/python/csrc/online-punctuation.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
Expand All @@ -42,6 +43,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindWaveWriter(&m);
PybindAudioTagging(&m);
PybindOfflinePunctuation(&m);
PybindOnlinePunctuation(&m);

PybindFeatures(&m);
PybindOnlineCtcFstDecoderConfig(&m);
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
OfflineTtsModelConfig,
OfflineTtsVitsModelConfig,
OfflineZipformerAudioTaggingModelConfig,
OnlinePunctuation,
OnlinePunctuationConfig,
OnlinePunctuationModelConfig,
OnlineStream,
SileroVadModelConfig,
SpeakerEmbeddingExtractor,
Expand Down

0 comments on commit 3bffc24

Please sign in to comment.