Skip to content

Commit

Permalink
Support VITS VCTK models (#367)
Browse files Browse the repository at this point in the history
* Support VITS VCTK models

* Release v1.8.1
  • Loading branch information
csukuangfj authored Oct 16, 2023
1 parent d01682d commit 9efe697
Show file tree
Hide file tree
Showing 16 changed files with 332 additions and 31 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.8.0")
set(SHERPA_ONNX_VERSION "1.8.1")

# Disable warning about
#
Expand Down
12 changes: 11 additions & 1 deletion python-api-examples/offline-tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def get_args():
help="Path to save generated wave",
)

parser.add_argument(
"--sid",
type=int,
default=0,
help="""Speaker ID. Used only for multi-speaker models, e.g.
models trained using the VCTK dataset. Not used for single-speaker
models, e.g., models trained using the LJ speech dataset.
""",
)

parser.add_argument(
"--debug",
type=bool,
Expand Down Expand Up @@ -105,7 +115,7 @@ def main():
)
)
tts = sherpa_onnx.OfflineTts(tts_config)
audio = tts.generate(args.text)
audio = tts.generate(args.text, sid=args.sid)
sf.write(
args.output_filename,
audio.samples,
Expand Down
1 change: 1 addition & 0 deletions scripts/vits/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
tokens-ljs.txt
tokens-vctk.txt
1 change: 1 addition & 0 deletions scripts/vits/export-onnx-ljs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def main():
"comment": "ljspeech",
"language": "English",
"add_blank": int(hps.data.add_blank),
"n_speakers": int(hps.data.n_speakers),
"sample_rate": hps.data.sampling_rate,
"punctuation": " ".join(list(_punctuation)),
}
Expand Down
222 changes: 222 additions & 0 deletions scripts/vits/export-onnx-vctk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This script converts vits models trained using the VCTK dataset.
Usage:
(1) Download vits
cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits
(2) Download pre-trained models from
https://huggingface.co/csukuangfj/vits-vctk/tree/main
wget https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
(3) Run this file
./export-onnx-vctk.py \
--config ~/open-source//vits/configs/vctk_base.json \
--checkpoint ~/open-source/icefall-models/vits-vctk/pretrained_vctk.pth
It will generate the following two files:
$ ls -lh *.onnx
-rw-r--r-- 1 fangjun staff 37M Oct 16 10:57 vits-vctk.int8.onnx
-rw-r--r-- 1 fangjun staff 116M Oct 16 10:57 vits-vctk.onnx
"""
import sys

# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys.path.insert(0, "/Users/fangjun/open-source/vits") # noqa

import argparse
from pathlib import Path
from typing import Dict, Any

import commons
import onnx
import torch
import utils
from models import SynthesizerTrn
from onnxruntime.quantization import QuantType, quantize_dynamic
from text import text_to_sequence
from text.symbols import symbols
from text.symbols import _punctuation


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
type=str,
required=True,
help="""Path to vctk_base.json.
You can find it at
https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vctk_base.json
""",
)

parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="""Path to the checkpoint file.
You can find it at
https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
""",
)

return parser.parse_args()


class OnnxModel(torch.nn.Module):
def __init__(self, model: SynthesizerTrn):
super().__init__()
self.model = model

def forward(
self,
x,
x_lengths,
noise_scale=1,
length_scale=1,
noise_scale_w=1.0,
sid=0,
max_len=None,
):
return self.model.infer(
x=x,
x_lengths=x_lengths,
sid=sid,
noise_scale=noise_scale,
length_scale=length_scale,
noise_scale_w=noise_scale_w,
max_len=max_len,
)[0]


def get_text(text, hps):
text_norm = text_to_sequence(text, hps.data.text_cleaners)
if hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = torch.LongTensor(text_norm)
return text_norm


def check_args(args):
assert Path(args.config).is_file(), args.config
assert Path(args.checkpoint).is_file(), args.checkpoint


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)

onnx.save(model, filename)


def generate_tokens():
with open("tokens-vctk.txt", "w", encoding="utf-8") as f:
for i, s in enumerate(symbols):
f.write(f"{s} {i}\n")
print("Generated tokens-vctk.txt")


@torch.no_grad()
def main():
args = get_args()
check_args(args)

generate_tokens()

hps = utils.get_hparams_from_file(args.config)

net_g = SynthesizerTrn(
len(symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
)
_ = net_g.eval()

_ = utils.load_checkpoint(args.checkpoint, net_g, None)

x = get_text("Liliana is the most beautiful assistant", hps)
x = x.unsqueeze(0)

x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
length_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_w = torch.tensor([1], dtype=torch.float32)
sid = torch.tensor([0], dtype=torch.int64)

model = OnnxModel(net_g)

opset_version = 13

filename = "vits-vctk.onnx"

torch.onnx.export(
model,
(x, x_length, noise_scale, length_scale, noise_scale_w, sid),
filename,
opset_version=opset_version,
input_names=[
"x",
"x_length",
"noise_scale",
"length_scale",
"noise_scale_w",
"sid",
],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "L"}, # n_audio is also known as batch_size
"x_length": {0: "N"},
"y": {0: "N", 2: "L"},
},
)
meta_data = {
"model_type": "vits",
"comment": "vctk",
"language": "English",
"add_blank": int(hps.data.add_blank),
"n_speakers": int(hps.data.n_speakers),
"sample_rate": hps.data.sampling_rate,
"punctuation": " ".join(list(_punctuation)),
}
print("meta_data", meta_data)
add_meta_data(filename=filename, meta_data=meta_data)

print("Generate int8 quantization models")

filename_int8 = "vits-vctk.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
weight_type=QuantType.QUInt8,
)

print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-tts-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class OfflineTtsImpl {

static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);

virtual GeneratedAudio Generate(const std::string &text) const = 0;
virtual GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const = 0;
};

} // namespace sherpa_onnx
Expand Down
5 changes: 3 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {}

GeneratedAudio Generate(const std::string &text) const override {
GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const override {
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
if (x.empty()) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
Expand All @@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());

Ort::Value audio = model_->Run(std::move(x_tensor));
Ort::Value audio = model_->Run(std::move(x_tensor), sid);

std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();
Expand Down
10 changes: 9 additions & 1 deletion sherpa-onnx/csrc/offline-tts-vits-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
po->Register("vits-model", &model, "Path to VITS model");
po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models");
po->Register("vits-noise-scale-w", &noise_scale_w,
"noise_scale_w for VITS models");
po->Register("vits-length-scale", &length_scale,
"length_scale for VITS models");
}

bool OfflineTtsVitsModelConfig::Validate() const {
Expand Down Expand Up @@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const {
os << "OfflineTtsVitsModelConfig(";
os << "model=\"" << model << "\", ";
os << "lexicon=\"" << lexicon << "\", ";
os << "tokens=\"" << tokens << "\")";
os << "tokens=\"" << tokens << "\", ";
os << "noise_scale=" << noise_scale << ", ";
os << "noise_scale_w=" << noise_scale_w << ", ";
os << "length_scale=" << length_scale << ")";

return os.str();
}
Expand Down
18 changes: 16 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig {
std::string lexicon;
std::string tokens;

float noise_scale = 0.667;
float noise_scale_w = 0.8;
float length_scale = 1;

// used only for multi-speaker models, e.g, vctk speech dataset.
// Not applicable for single-speaker models, e.g., ljspeech dataset

OfflineTtsVitsModelConfig() = default;

OfflineTtsVitsModelConfig(const std::string &model,
const std::string &lexicon,
const std::string &tokens)
: model(model), lexicon(lexicon), tokens(tokens) {}
const std::string &tokens,
float noise_scale = 0.667,
float noise_scale_w = 0.8, float length_scale = 1)
: model(model),
lexicon(lexicon),
tokens(tokens),
noise_scale(noise_scale),
noise_scale_w(noise_scale_w),
length_scale(length_scale) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
Loading

0 comments on commit 9efe697

Please sign in to comment.