From 9efe69720df875fde43df3c6467ab4345e043a20 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 16 Oct 2023 17:22:30 +0800 Subject: [PATCH] Support VITS VCTK models (#367) * Support VITS VCTK models * Release v1.8.1 --- CMakeLists.txt | 2 +- python-api-examples/offline-tts.py | 12 +- scripts/vits/.gitignore | 1 + scripts/vits/export-onnx-ljs.py | 1 + scripts/vits/export-onnx-vctk.py | 222 ++++++++++++++++++ sherpa-onnx/csrc/offline-tts-impl.h | 3 +- sherpa-onnx/csrc/offline-tts-vits-impl.h | 5 +- .../csrc/offline-tts-vits-model-config.cc | 10 +- .../csrc/offline-tts-vits-model-config.h | 18 +- sherpa-onnx/csrc/offline-tts-vits-model.cc | 33 ++- sherpa-onnx/csrc/offline-tts-vits-model.h | 6 +- sherpa-onnx/csrc/offline-tts.cc | 5 +- sherpa-onnx/csrc/offline-tts.h | 6 +- sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc | 28 ++- .../csrc/offline-tts-vits-model-config.cc | 9 +- sherpa-onnx/python/csrc/offline-tts.cc | 2 +- 16 files changed, 332 insertions(+), 31 deletions(-) create mode 100755 scripts/vits/export-onnx-vctk.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 2726998fd..70a014e66 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.8.0") +set(SHERPA_ONNX_VERSION "1.8.1") # Disable warning about # diff --git a/python-api-examples/offline-tts.py b/python-api-examples/offline-tts.py index e264a0926..85e588040 100755 --- a/python-api-examples/offline-tts.py +++ b/python-api-examples/offline-tts.py @@ -58,6 +58,16 @@ def get_args(): help="Path to save generated wave", ) + parser.add_argument( + "--sid", + type=int, + default=0, + help="""Speaker ID. Used only for multi-speaker models, e.g. + models trained using the VCTK dataset. Not used for single-speaker + models, e.g., models trained using the LJ speech dataset. + """, + ) + parser.add_argument( "--debug", type=bool, @@ -105,7 +115,7 @@ def main(): ) ) tts = sherpa_onnx.OfflineTts(tts_config) - audio = tts.generate(args.text) + audio = tts.generate(args.text, sid=args.sid) sf.write( args.output_filename, audio.samples, diff --git a/scripts/vits/.gitignore b/scripts/vits/.gitignore index 91f2b86ad..73f338ee7 100644 --- a/scripts/vits/.gitignore +++ b/scripts/vits/.gitignore @@ -1 +1,2 @@ tokens-ljs.txt +tokens-vctk.txt diff --git a/scripts/vits/export-onnx-ljs.py b/scripts/vits/export-onnx-ljs.py index 84402d1c9..285f93fb5 100755 --- a/scripts/vits/export-onnx-ljs.py +++ b/scripts/vits/export-onnx-ljs.py @@ -191,6 +191,7 @@ def main(): "comment": "ljspeech", "language": "English", "add_blank": int(hps.data.add_blank), + "n_speakers": int(hps.data.n_speakers), "sample_rate": hps.data.sampling_rate, "punctuation": " ".join(list(_punctuation)), } diff --git a/scripts/vits/export-onnx-vctk.py b/scripts/vits/export-onnx-vctk.py new file mode 100755 index 000000000..1a8dc135d --- /dev/null +++ b/scripts/vits/export-onnx-vctk.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script converts vits models trained using the VCTK dataset. + +Usage: + +(1) Download vits + +cd /Users/fangjun/open-source +git clone https://github.com/jaywalnut310/vits + +(2) Download pre-trained models from +https://huggingface.co/csukuangfj/vits-vctk/tree/main + +wget https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth + +(3) Run this file + +./export-onnx-vctk.py \ + --config ~/open-source//vits/configs/vctk_base.json \ + --checkpoint ~/open-source/icefall-models/vits-vctk/pretrained_vctk.pth + +It will generate the following two files: + +$ ls -lh *.onnx +-rw-r--r-- 1 fangjun staff 37M Oct 16 10:57 vits-vctk.int8.onnx +-rw-r--r-- 1 fangjun staff 116M Oct 16 10:57 vits-vctk.onnx +""" +import sys + +# Please change this line to point to the vits directory. +# You can download vits from +# https://github.com/jaywalnut310/vits +sys.path.insert(0, "/Users/fangjun/open-source/vits") # noqa + +import argparse +from pathlib import Path +from typing import Dict, Any + +import commons +import onnx +import torch +import utils +from models import SynthesizerTrn +from onnxruntime.quantization import QuantType, quantize_dynamic +from text import text_to_sequence +from text.symbols import symbols +from text.symbols import _punctuation + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + type=str, + required=True, + help="""Path to vctk_base.json. + You can find it at + https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vctk_base.json + """, + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="""Path to the checkpoint file. + You can find it at + https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth + """, + ) + + return parser.parse_args() + + +class OnnxModel(torch.nn.Module): + def __init__(self, model: SynthesizerTrn): + super().__init__() + self.model = model + + def forward( + self, + x, + x_lengths, + noise_scale=1, + length_scale=1, + noise_scale_w=1.0, + sid=0, + max_len=None, + ): + return self.model.infer( + x=x, + x_lengths=x_lengths, + sid=sid, + noise_scale=noise_scale, + length_scale=length_scale, + noise_scale_w=noise_scale_w, + max_len=max_len, + )[0] + + +def get_text(text, hps): + text_norm = text_to_sequence(text, hps.data.text_cleaners) + if hps.data.add_blank: + text_norm = commons.intersperse(text_norm, 0) + text_norm = torch.LongTensor(text_norm) + return text_norm + + +def check_args(args): + assert Path(args.config).is_file(), args.config + assert Path(args.checkpoint).is_file(), args.checkpoint + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +def generate_tokens(): + with open("tokens-vctk.txt", "w", encoding="utf-8") as f: + for i, s in enumerate(symbols): + f.write(f"{s} {i}\n") + print("Generated tokens-vctk.txt") + + +@torch.no_grad() +def main(): + args = get_args() + check_args(args) + + generate_tokens() + + hps = utils.get_hparams_from_file(args.config) + + net_g = SynthesizerTrn( + len(symbols), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + _ = net_g.eval() + + _ = utils.load_checkpoint(args.checkpoint, net_g, None) + + x = get_text("Liliana is the most beautiful assistant", hps) + x = x.unsqueeze(0) + + x_length = torch.tensor([x.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + length_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_w = torch.tensor([1], dtype=torch.float32) + sid = torch.tensor([0], dtype=torch.int64) + + model = OnnxModel(net_g) + + opset_version = 13 + + filename = "vits-vctk.onnx" + + torch.onnx.export( + model, + (x, x_length, noise_scale, length_scale, noise_scale_w, sid), + filename, + opset_version=opset_version, + input_names=[ + "x", + "x_length", + "noise_scale", + "length_scale", + "noise_scale_w", + "sid", + ], + output_names=["y"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, # n_audio is also known as batch_size + "x_length": {0: "N"}, + "y": {0: "N", 2: "L"}, + }, + ) + meta_data = { + "model_type": "vits", + "comment": "vctk", + "language": "English", + "add_blank": int(hps.data.add_blank), + "n_speakers": int(hps.data.n_speakers), + "sample_rate": hps.data.sampling_rate, + "punctuation": " ".join(list(_punctuation)), + } + print("meta_data", meta_data) + add_meta_data(filename=filename, meta_data=meta_data) + + print("Generate int8 quantization models") + + filename_int8 = "vits-vctk.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + weight_type=QuantType.QUInt8, + ) + + print(f"Saved to {filename} and {filename_int8}") + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/offline-tts-impl.h b/sherpa-onnx/csrc/offline-tts-impl.h index 877dd11a1..1de5590ea 100644 --- a/sherpa-onnx/csrc/offline-tts-impl.h +++ b/sherpa-onnx/csrc/offline-tts-impl.h @@ -18,7 +18,8 @@ class OfflineTtsImpl { static std::unique_ptr Create(const OfflineTtsConfig &config); - virtual GeneratedAudio Generate(const std::string &text) const = 0; + virtual GeneratedAudio Generate(const std::string &text, + int64_t sid = 0) const = 0; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index abbf28193..59651ab21 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { lexicon_(config.model.vits.lexicon, config.model.vits.tokens, model_->Punctuations()) {} - GeneratedAudio Generate(const std::string &text) const override { + GeneratedAudio Generate(const std::string &text, + int64_t sid = 0) const override { std::vector x = lexicon_.ConvertTextToTokenIds(text); if (x.empty()) { SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); @@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { Ort::Value x_tensor = Ort::Value::CreateTensor( memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); - Ort::Value audio = model_->Run(std::move(x_tensor)); + Ort::Value audio = model_->Run(std::move(x_tensor), sid); std::vector audio_shape = audio.GetTensorTypeAndShapeInfo().GetShape(); diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc index 5bcb7f0bf..dcb90079a 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc @@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) { po->Register("vits-model", &model, "Path to VITS model"); po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models"); po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models"); + po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models"); + po->Register("vits-noise-scale-w", &noise_scale_w, + "noise_scale_w for VITS models"); + po->Register("vits-length-scale", &length_scale, + "length_scale for VITS models"); } bool OfflineTtsVitsModelConfig::Validate() const { @@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const { os << "OfflineTtsVitsModelConfig("; os << "model=\"" << model << "\", "; os << "lexicon=\"" << lexicon << "\", "; - os << "tokens=\"" << tokens << "\")"; + os << "tokens=\"" << tokens << "\", "; + os << "noise_scale=" << noise_scale << ", "; + os << "noise_scale_w=" << noise_scale_w << ", "; + os << "length_scale=" << length_scale << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.h b/sherpa-onnx/csrc/offline-tts-vits-model-config.h index c8f097598..62bc566be 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-config.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.h @@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig { std::string lexicon; std::string tokens; + float noise_scale = 0.667; + float noise_scale_w = 0.8; + float length_scale = 1; + + // used only for multi-speaker models, e.g, vctk speech dataset. + // Not applicable for single-speaker models, e.g., ljspeech dataset + OfflineTtsVitsModelConfig() = default; OfflineTtsVitsModelConfig(const std::string &model, const std::string &lexicon, - const std::string &tokens) - : model(model), lexicon(lexicon), tokens(tokens) {} + const std::string &tokens, + float noise_scale = 0.667, + float noise_scale_w = 0.8, float length_scale = 1) + : model(model), + lexicon(lexicon), + tokens(tokens), + noise_scale(noise_scale), + noise_scale_w(noise_scale_w), + length_scale(length_scale) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index 2f6365132..2d6792941 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -26,7 +26,7 @@ class OfflineTtsVitsModel::Impl { Init(buf.data(), buf.size()); } - Ort::Value Run(Ort::Value x) { + Ort::Value Run(Ort::Value x, int64_t sid) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); @@ -44,20 +44,33 @@ class OfflineTtsVitsModel::Impl { Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); int64_t scale_shape = 1; - float noise_scale = 1; - float length_scale = 1; - float noise_scale_w = 1; + float noise_scale = config_.vits.noise_scale; + float length_scale = config_.vits.length_scale; + float noise_scale_w = config_.vits.noise_scale_w; Ort::Value noise_scale_tensor = Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); + Ort::Value length_scale_tensor = Ort::Value::CreateTensor( memory_info, &length_scale, 1, &scale_shape, 1); + Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor( memory_info, &noise_scale_w, 1, &scale_shape, 1); - std::array inputs = { - std::move(x), std::move(x_length), std::move(noise_scale_tensor), - std::move(length_scale_tensor), std::move(noise_scale_w_tensor)}; + Ort::Value sid_tensor = + Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1); + + std::vector inputs; + inputs.reserve(6); + inputs.push_back(std::move(x)); + inputs.push_back(std::move(x_length)); + inputs.push_back(std::move(noise_scale_tensor)); + inputs.push_back(std::move(length_scale_tensor)); + inputs.push_back(std::move(noise_scale_w_tensor)); + + if (input_names_.size() == 6 && input_names_.back() == "sid") { + inputs.push_back(std::move(sid_tensor)); + } auto out = sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), @@ -93,6 +106,7 @@ class OfflineTtsVitsModel::Impl { Ort::AllocatorWithDefaultOptions allocator; // used in the macro below SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); + SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers"); SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); } @@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl { int32_t sample_rate_; int32_t add_blank_; + int32_t n_speakers_; std::string punctuations_; }; @@ -120,8 +135,8 @@ OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) OfflineTtsVitsModel::~OfflineTtsVitsModel() = default; -Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) { - return impl_->Run(std::move(x)); +Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/) { + return impl_->Run(std::move(x), sid); } int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.h b/sherpa-onnx/csrc/offline-tts-vits-model.h index ca2c1c6be..de3927f73 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model.h @@ -22,10 +22,14 @@ class OfflineTtsVitsModel { /** Run the model. * * @param x A int64 tensor of shape (1, num_tokens) + // @param sid Speaker ID. Used only for multi-speaker models, e.g., models + // trained using the VCTK dataset. It is not used for + // single-speaker models, e.g., models trained using the ljspeech + // dataset. * @return Return a float32 tensor containing audio samples. You can flatten * it to a 1-D tensor. */ - Ort::Value Run(Ort::Value x); + Ort::Value Run(Ort::Value x, int64_t sid = 0); // Sample rate of the generated audio int32_t SampleRate() const; diff --git a/sherpa-onnx/csrc/offline-tts.cc b/sherpa-onnx/csrc/offline-tts.cc index 1154f2e4e..36ec9beee 100644 --- a/sherpa-onnx/csrc/offline-tts.cc +++ b/sherpa-onnx/csrc/offline-tts.cc @@ -28,8 +28,9 @@ OfflineTts::OfflineTts(const OfflineTtsConfig &config) OfflineTts::~OfflineTts() = default; -GeneratedAudio OfflineTts::Generate(const std::string &text) const { - return impl_->Generate(text); +GeneratedAudio OfflineTts::Generate(const std::string &text, + int64_t sid /*=0*/) const { + return impl_->Generate(text, sid); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts.h b/sherpa-onnx/csrc/offline-tts.h index 0d6ce6687..6e0b1402d 100644 --- a/sherpa-onnx/csrc/offline-tts.h +++ b/sherpa-onnx/csrc/offline-tts.h @@ -39,7 +39,11 @@ class OfflineTts { ~OfflineTts(); explicit OfflineTts(const OfflineTtsConfig &config); // @param text A string containing words separated by spaces - GeneratedAudio Generate(const std::string &text) const; + // @param sid Speaker ID. Used only for multi-speaker models, e.g., models + // trained using the VCTK dataset. It is not used for + // single-speaker models, e.g., models trained using the ljspeech + // dataset. + GeneratedAudio Generate(const std::string &text, int64_t sid = 0) const; private: std::unique_ptr impl_; diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc index 0354218b2..bd4fc4e00 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc @@ -13,11 +13,12 @@ int main(int32_t argc, char *argv[]) { Offline text-to-speech with sherpa-onnx ./bin/sherpa-onnx-offline-tts \ - --vits-model /path/to/model.onnx \ - --vits-lexicon /path/to/lexicon.txt \ - --vits-tokens /path/to/tokens.txt - --output-filename ./generated.wav \ - 'some text within single quotes' + --vits-model=/path/to/model.onnx \ + --vits-lexicon=/path/to/lexicon.txt \ + --vits-tokens=/path/to/tokens.txt \ + --sid=0 \ + --output-filename=./generated.wav \ + 'some text within single quotes on linux/macos or use double quotes on windows' It will generate a file ./generated.wav as specified by --output-filename. @@ -33,15 +34,27 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt --vits-model=./vits-ljs.onnx \ --vits-lexicon=./lexicon.txt \ --vits-tokens=./tokens.txt \ + --sid=0 \ --output-filename=./generated.wav \ 'liliana, the most beautiful and lovely assistant of our team!' + +Please see +https://k2-fsa.github.io/sherpa/onnx/tts/index.html +or detailes. )usage"; sherpa_onnx::ParseOptions po(kUsageMessage); std::string output_filename = "./generated.wav"; + int32_t sid = 0; + po.Register("output-filename", &output_filename, "Path to save the generated audio"); + po.Register("sid", &sid, + "Speaker ID. Used only for multi-speaker models, e.g., models " + "trained using the VCTK dataset. Not used for single-speaker " + "models, e.g., models trained using the LJSpeech dataset"); + sherpa_onnx::OfflineTtsConfig config; config.Register(&po); @@ -67,7 +80,7 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt } sherpa_onnx::OfflineTts tts(config); - auto audio = tts.Generate(po.GetArg(1)); + auto audio = tts.Generate(po.GetArg(1), sid); bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate, audio.samples.data(), audio.samples.size()); @@ -76,7 +89,8 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt exit(EXIT_FAILURE); } - fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str()); + fprintf(stderr, "The text is: %s. Speaker ID: %d\n", po.GetArg(1).c_str(), + sid); fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str()); return 0; diff --git a/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc index 2471c3f5e..60521ef9a 100644 --- a/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc @@ -16,11 +16,16 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) { py::class_(*m, "OfflineTtsVitsModelConfig") .def(py::init<>()) .def(py::init(), - py::arg("model"), py::arg("lexicon"), py::arg("tokens")) + const std::string &, float, float, float>(), + py::arg("model"), py::arg("lexicon"), py::arg("tokens"), + py::arg("noise_scale") = 0.667, py::arg("noise_scale_w") = 0.8, + py::arg("length_scale") = 1.0) .def_readwrite("model", &PyClass::model) .def_readwrite("lexicon", &PyClass::lexicon) .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("noise_scale", &PyClass::noise_scale) + .def_readwrite("noise_scale_w", &PyClass::noise_scale_w) + .def_readwrite("length_scale", &PyClass::length_scale) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/offline-tts.cc b/sherpa-onnx/python/csrc/offline-tts.cc index 199e4df2d..81af16917 100644 --- a/sherpa-onnx/python/csrc/offline-tts.cc +++ b/sherpa-onnx/python/csrc/offline-tts.cc @@ -40,7 +40,7 @@ void PybindOfflineTts(py::module *m) { using PyClass = OfflineTts; py::class_(*m, "OfflineTts") .def(py::init(), py::arg("config")) - .def("generate", &PyClass::Generate); + .def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0); } } // namespace sherpa_onnx