Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support VITS VCTK models #367

Merged
merged 2 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.8.0")
set(SHERPA_ONNX_VERSION "1.8.1")

# Disable warning about
#
Expand Down
12 changes: 11 additions & 1 deletion python-api-examples/offline-tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def get_args():
help="Path to save generated wave",
)

parser.add_argument(
"--sid",
type=int,
default=0,
help="""Speaker ID. Used only for multi-speaker models, e.g.
models trained using the VCTK dataset. Not used for single-speaker
models, e.g., models trained using the LJ speech dataset.
""",
)

parser.add_argument(
"--debug",
type=bool,
Expand Down Expand Up @@ -105,7 +115,7 @@ def main():
)
)
tts = sherpa_onnx.OfflineTts(tts_config)
audio = tts.generate(args.text)
audio = tts.generate(args.text, sid=args.sid)
sf.write(
args.output_filename,
audio.samples,
Expand Down
1 change: 1 addition & 0 deletions scripts/vits/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
tokens-ljs.txt
tokens-vctk.txt
1 change: 1 addition & 0 deletions scripts/vits/export-onnx-ljs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def main():
"comment": "ljspeech",
"language": "English",
"add_blank": int(hps.data.add_blank),
"n_speakers": int(hps.data.n_speakers),
"sample_rate": hps.data.sampling_rate,
"punctuation": " ".join(list(_punctuation)),
}
Expand Down
222 changes: 222 additions & 0 deletions scripts/vits/export-onnx-vctk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This script converts vits models trained using the VCTK dataset.

Usage:

(1) Download vits

cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits

(2) Download pre-trained models from
https://huggingface.co/csukuangfj/vits-vctk/tree/main

wget https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth

(3) Run this file

./export-onnx-vctk.py \
--config ~/open-source/vits/configs/vctk_base.json \
--checkpoint ~/open-source/icefall-models/vits-vctk/pretrained_vctk.pth

It will generate the following two files:

$ ls -lh *.onnx
-rw-r--r-- 1 fangjun staff 37M Oct 16 10:57 vits-vctk.int8.onnx
-rw-r--r-- 1 fangjun staff 116M Oct 16 10:57 vits-vctk.onnx
"""
import sys

# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys.path.insert(0, "/Users/fangjun/open-source/vits") # noqa

import argparse
from pathlib import Path
from typing import Dict, Any

import commons
import onnx
import torch
import utils
from models import SynthesizerTrn
from onnxruntime.quantization import QuantType, quantize_dynamic
from text import text_to_sequence
from text.symbols import symbols
from text.symbols import _punctuation


def get_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="""Path to vctk_base.json.
You can find it at
https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vctk_base.json
""",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="""Path to the checkpoint file.
You can find it at
https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
""",
    )

    return parser.parse_args()


class OnnxModel(torch.nn.Module):
    """Adapter that exposes SynthesizerTrn.infer() as a plain forward()
    so the model can be traced by torch.onnx.export.
    """

    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        # Keep a reference to the wrapped VITS generator.
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        sid=0,
        max_len=None,
    ):
        # infer() returns a tuple; element 0 is the generated audio tensor.
        outputs = self.model.infer(
            x=x,
            x_lengths=x_lengths,
            sid=sid,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            max_len=max_len,
        )
        return outputs[0]


def get_text(text, hps):
    """Convert a text string into a 1-D LongTensor of token IDs.

    The text cleaners listed in the config are applied first; when
    hps.data.add_blank is set, a blank token (0) is interspersed
    between every pair of tokens.
    """
    token_ids = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        token_ids = commons.intersperse(token_ids, 0)
    return torch.LongTensor(token_ids)


def check_args(args):
    """Assert that the --config and --checkpoint paths are existing files."""
    for path in (args.config, args.checkpoint):
        assert Path(path).is_file(), path


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Attach key/value metadata entries to an ONNX model file in place.

    Args:
      filename:
        Path of the ONNX model to modify; it is overwritten.
      meta_data:
        Mapping of metadata keys to values; every value is converted
        to str before being stored.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)

    onnx.save(model, filename)


def generate_tokens():
    """Write the symbol table to tokens-vctk.txt, one "<symbol> <id>" per line."""
    with open("tokens-vctk.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{sym} {idx}\n" for idx, sym in enumerate(symbols))
    print("Generated tokens-vctk.txt")


@torch.no_grad()
def main():
    """Export the VCTK VITS checkpoint to ONNX.

    Produces three files in the current directory:
      - tokens-vctk.txt     (symbol table)
      - vits-vctk.onnx      (fp32 model with metadata)
      - vits-vctk.int8.onnx (dynamically quantized model)
    """
    args = get_args()
    check_args(args)

    generate_tokens()

    hps = utils.get_hparams_from_file(args.config)

    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.checkpoint, net_g, None)

    # Dummy inputs used only to trace the graph during export.
    x = get_text("Liliana is the most beautiful assistant", hps)
    x = x.unsqueeze(0)  # add a batch dimension: (1, L)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)
    sid = torch.tensor([0], dtype=torch.int64)  # speaker 0, tracing only

    model = OnnxModel(net_g)

    opset_version = 13

    filename = "vits-vctk.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale, noise_scale_w, sid),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_length",
            "noise_scale",
            "length_scale",
            "noise_scale_w",
            "sid",
        ],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # N is also known as batch_size
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},
        },
    )
    meta_data = {
        "model_type": "vits",
        "comment": "vctk",
        "language": "English",
        "add_blank": int(hps.data.add_blank),
        "n_speakers": int(hps.data.n_speakers),
        "sample_rate": hps.data.sampling_rate,
        "punctuation": " ".join(list(_punctuation)),
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "vits-vctk.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    # Fixed: the message previously printed a literal placeholder instead
    # of interpolating the fp32 output filename.
    print(f"Saved to {filename} and {filename_int8}")


# Run the export only when invoked as a script, not when imported.
if __name__ == "__main__":
    main()
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-tts-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class OfflineTtsImpl {

static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);

virtual GeneratedAudio Generate(const std::string &text) const = 0;
virtual GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const = 0;
};

} // namespace sherpa_onnx
Expand Down
5 changes: 3 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {}

GeneratedAudio Generate(const std::string &text) const override {
GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const override {
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
if (x.empty()) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
Expand All @@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());

Ort::Value audio = model_->Run(std::move(x_tensor));
Ort::Value audio = model_->Run(std::move(x_tensor), sid);

std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();
Expand Down
10 changes: 9 additions & 1 deletion sherpa-onnx/csrc/offline-tts-vits-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
po->Register("vits-model", &model, "Path to VITS model");
po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models");
po->Register("vits-noise-scale-w", &noise_scale_w,
"noise_scale_w for VITS models");
po->Register("vits-length-scale", &length_scale,
"length_scale for VITS models");
}

bool OfflineTtsVitsModelConfig::Validate() const {
Expand Down Expand Up @@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const {
os << "OfflineTtsVitsModelConfig(";
os << "model=\"" << model << "\", ";
os << "lexicon=\"" << lexicon << "\", ";
os << "tokens=\"" << tokens << "\")";
os << "tokens=\"" << tokens << "\", ";
os << "noise_scale=" << noise_scale << ", ";
os << "noise_scale_w=" << noise_scale_w << ", ";
os << "length_scale=" << length_scale << ")";

return os.str();
}
Expand Down
18 changes: 16 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig {
std::string lexicon;
std::string tokens;

float noise_scale = 0.667;
float noise_scale_w = 0.8;
float length_scale = 1;

// used only for multi-speaker models, e.g, vctk speech dataset.
// Not applicable for single-speaker models, e.g., ljspeech dataset

OfflineTtsVitsModelConfig() = default;

OfflineTtsVitsModelConfig(const std::string &model,
const std::string &lexicon,
const std::string &tokens)
: model(model), lexicon(lexicon), tokens(tokens) {}
const std::string &tokens,
float noise_scale = 0.667,
float noise_scale_w = 0.8, float length_scale = 1)
: model(model),
lexicon(lexicon),
tokens(tokens),
noise_scale(noise_scale),
noise_scale_w(noise_scale_w),
length_scale(length_scale) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
Loading