Skip to content

Commit

Permalink
add python API and examples for TTS (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 14, 2023
1 parent 1ac2232 commit 655e0fa
Show file tree
Hide file tree
Showing 16 changed files with 320 additions and 6 deletions.
20 changes: 20 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "Offline TTS test"

wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt

python3 ./python-api-examples/offline-tts.py \
--vits-model=./vits-ljs.onnx \
--vits-lexicon=./lexicon.txt \
--vits-tokens=./tokens.txt \
--output-filename=./tts.wav \
'liliana, the most beautiful and lovely assistant of our team!'

ls -lh ./tts.wav
file ./tts.wav

rm -v vits-ljs.onnx ./lexicon.txt ./tokens.txt

mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models

Expand Down Expand Up @@ -171,3 +189,5 @@ rm -rf $repo
git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data

python3 sherpa-onnx/python/tests/test_text2token.py --verbose

rm -rf /tmp/sherpa-test-data
9 changes: 7 additions & 2 deletions .github/workflows/run-python-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
python-version: "3.10"

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
with:
fetch-depth: 0

Expand All @@ -54,7 +54,7 @@ jobs:
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96
python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96 soundfile
- name: Install sherpa-onnx
shell: bash
Expand All @@ -65,3 +65,8 @@ jobs:
shell: bash
run: |
.github/scripts/test-python.sh
- uses: actions/upload-artifact@v3
with:
name: tts-generated-test-files
path: tts.wav
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.7.21")
set(SHERPA_ONNX_VERSION "1.8.0")

# Disable warning about
#
Expand Down
4 changes: 4 additions & 0 deletions cmake/cmake_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,17 @@ def build_extension(self, ext: setuptools.extension.Extension):
binaries += ["sherpa-onnx-offline-websocket-server"]
binaries += ["sherpa-onnx-online-websocket-client"]
binaries += ["sherpa-onnx-vad-microphone"]
binaries += ["sherpa-onnx-offline-tts"]

if is_windows():
binaries += ["kaldi-native-fbank-core.dll"]
binaries += ["sherpa-onnx-c-api.dll"]
binaries += ["sherpa-onnx-core.dll"]
binaries += ["sherpa-onnx-portaudio.dll"]
binaries += ["onnxruntime.dll"]
binaries += ["kaldi-decoder-core.dll"]
binaries += ["sherpa-onnx-fst.dll"]
binaries += ["sherpa-onnx-kaldifst-core.dll"]

for f in binaries:
suffix = "" if "dll" in f else suffix
Expand Down
120 changes: 120 additions & 0 deletions python-api-examples/offline-tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
#
# Copyright (c) 2023 Xiaomi Corporation

"""
This file demonstrates how to use sherpa-onnx Python API to generate audio
from text, i.e., text-to-speech.
Usage:
1. Download a model
wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
python3 ./python-api-examples/offline-tts.py \
--vits-model=./vits-ljs.onnx \
--vits-lexicon=./lexicon.txt \
--vits-tokens=./tokens.txt \
--output-filename=./generated.wav \
'liliana, the most beautiful and lovely assistant of our team!'
"""

import argparse

import sherpa_onnx
import soundfile as sf


def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--vits-model",
type=str,
help="Path to vits model.onnx",
)

parser.add_argument(
"--vits-lexicon",
type=str,
help="Path to lexicon.txt",
)

parser.add_argument(
"--vits-tokens",
type=str,
help="Path to tokens.txt",
)

parser.add_argument(
"--output-filename",
type=str,
default="./generated.wav",
help="Path to save generated wave",
)

parser.add_argument(
"--debug",
type=bool,
default=False,
help="True to show debug messages",
)

parser.add_argument(
"--provider",
type=str,
default="cpu",
help="valid values: cpu, cuda, coreml",
)

parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)

parser.add_argument(
"text",
type=str,
help="The input text to generate audio for",
)

return parser.parse_args()


def main():
args = get_args()
print(args)

tts_config = sherpa_onnx.OfflineTtsConfig(
model=sherpa_onnx.OfflineTtsModelConfig(
vits=sherpa_onnx.OfflineTtsVitsModelConfig(
model=args.vits_model,
lexicon=args.vits_lexicon,
tokens=args.vits_tokens,
),
provider=args.provider,
debug=args.debug,
num_threads=args.num_threads,
)
)
tts = sherpa_onnx.OfflineTts(tts_config)
audio = tts.generate(args.text)
sf.write(
args.output_filename,
audio.samples,
samplerate=audio.sample_rate,
subtype="PCM_16",
)
print(f"Saved to {args.output_filename}")
print(f"The text is '{args.text}'")


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,16 @@ def get_binaries_to_install():
binaries += ["sherpa-onnx-offline-websocket-server"]
binaries += ["sherpa-onnx-online-websocket-client"]
binaries += ["sherpa-onnx-vad-microphone"]
binaries += ["sherpa-onnx-offline-tts"]
if is_windows():
binaries += ["kaldi-native-fbank-core.dll"]
binaries += ["sherpa-onnx-c-api.dll"]
binaries += ["sherpa-onnx-core.dll"]
binaries += ["sherpa-onnx-portaudio.dll"]
binaries += ["onnxruntime.dll"]
binaries += ["kaldi-decoder-core.dll"]
binaries += ["sherpa-onnx-fst.dll"]
binaries += ["sherpa-onnx-kaldifst-core.dll"]

exe = []
for f in binaries:
Expand Down
4 changes: 1 addition & 3 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {
SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str());
}
model_->Punctuations()) {}

GeneratedAudio Generate(const std::string &text) const override {
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ pybind11_add_module(_sherpa_onnx
offline-stream.cc
offline-tdnn-model-config.cc
offline-transducer-model-config.cc
offline-tts-model-config.cc
offline-tts-vits-model-config.cc
offline-tts.cc
offline-whisper-model-config.cc
offline-zipformer-ctc-model-config.cc
online-lm-config.cc
Expand Down
32 changes: 32 additions & 0 deletions sherpa-onnx/python/csrc/offline-tts-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// sherpa-onnx/python/csrc/offline-tts-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/offline-tts-model-config.h"

#include <string>

#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h"

namespace sherpa_onnx {

void PybindOfflineTtsModelConfig(py::module *m) {
PybindOfflineTtsVitsModelConfig(m);

using PyClass = OfflineTtsModelConfig;

py::class_<PyClass>(*m, "OfflineTtsModelConfig")
.def(py::init<>())
.def(py::init<const OfflineTtsVitsModelConfig &, int32_t, bool,
const std::string &>(),
py::arg("vits"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("vits", &PyClass::vits)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def("__str__", &PyClass::ToString);
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/offline-tts-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-tts-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOfflineTtsModelConfig(py::module *m);

}

#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
27 changes: 27 additions & 0 deletions sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h"

#include <string>

#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"

namespace sherpa_onnx {

void PybindOfflineTtsVitsModelConfig(py::module *m) {
using PyClass = OfflineTtsVitsModelConfig;

py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &,
const std::string &>(),
py::arg("model"), py::arg("lexicon"), py::arg("tokens"))
.def_readwrite("model", &PyClass::model)
.def_readwrite("lexicon", &PyClass::lexicon)
.def_readwrite("tokens", &PyClass::tokens)
.def("__str__", &PyClass::ToString);
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/offline-tts-vits-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-tts-vits-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOfflineTtsVitsModelConfig(py::module *m);

}

#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
46 changes: 46 additions & 0 deletions sherpa-onnx/python/csrc/offline-tts.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// sherpa-onnx/python/csrc/offline-tts.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-tts.h"

#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/python/csrc/offline-tts-model-config.h"

namespace sherpa_onnx {

static void PybindGeneratedAudio(py::module *m) {
using PyClass = GeneratedAudio;
py::class_<PyClass>(*m, "GeneratedAudio")
.def(py::init<>())
.def_readwrite("samples", &PyClass::samples)
.def_readwrite("sample_rate", &PyClass::sample_rate)
.def("__str__", [](PyClass &self) {
std::ostringstream os;
os << "GeneratedAudio(sample_rate=" << self.sample_rate << ", ";
os << "num_samples=" << self.samples.size() << ")";
return os.str();
});
}

static void PybindOfflineTtsConfig(py::module *m) {
PybindOfflineTtsModelConfig(m);

using PyClass = OfflineTtsConfig;
py::class_<PyClass>(*m, "OfflineTtsConfig")
.def(py::init<>())
.def(py::init<const OfflineTtsModelConfig &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
}

void PybindOfflineTts(py::module *m) {
PybindOfflineTtsConfig(m);
PybindGeneratedAudio(m);

using PyClass = OfflineTts;
py::class_<PyClass>(*m, "OfflineTts")
.def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
.def("generate", &PyClass::Generate);
}

} // namespace sherpa_onnx
Loading

0 comments on commit 655e0fa

Please sign in to comment.