Skip to content

Commit

Permalink
Add TTS with VITS (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 13, 2023
1 parent 4771c92 commit 536d580
Show file tree
Hide file tree
Showing 17 changed files with 839 additions and 0 deletions.
18 changes: 18 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ set(sources
wave-reader.cc
)

# Sources added for TTS (text-to-speech) support with the VITS model.
list(APPEND sources
  lexicon.cc
  offline-tts-impl.cc
  offline-tts-model-config.cc
  offline-tts-vits-model-config.cc
  offline-tts-vits-model.cc
  offline-tts.cc
)

# log.cc is only compiled in when runtime checks are enabled.
if(SHERPA_ONNX_ENABLE_CHECK)
  list(APPEND sources log.cc)
endif()
Expand Down Expand Up @@ -135,23 +144,31 @@ endif()
# One binary per frontend; each is built from a single <name>.cc and links
# against the shared core library. Using foreach removes the 4-way
# copy-paste: adding a new binary now means adding its name to the two
# lists below instead of editing four separate places.
foreach(exe IN ITEMS
    sherpa-onnx
    sherpa-onnx-offline
    sherpa-onnx-offline-parallel
    sherpa-onnx-offline-tts
)
  add_executable(${exe} ${exe}.cc)
  target_link_libraries(${exe} sherpa-onnx-core)
endforeach()

if(NOT WIN32)
  foreach(exe IN ITEMS
      sherpa-onnx
      sherpa-onnx-offline
      sherpa-onnx-offline-parallel
      sherpa-onnx-offline-tts
  )
    # Embed rpaths so installed binaries can find the core library
    # relative to their own location. The ../../../sherpa_onnx/lib and
    # site-packages entries presumably cover the Python-package layout --
    # confirm against the packaging scripts.
    target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
    target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")

    if(SHERPA_ONNX_ENABLE_PYTHON)
      target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
    endif()
  endforeach()
endif()

Expand All @@ -170,6 +187,7 @@ install(
sherpa-onnx
sherpa-onnx-offline
sherpa-onnx-offline-parallel
sherpa-onnx-offline-tts
DESTINATION
bin
)
Expand Down
157 changes: 157 additions & 0 deletions sherpa-onnx/csrc/lexicon.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// sherpa-onnx/csrc/lexicon.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/lexicon.h"

#include <algorithm>
#include <cctype>
#include <fstream>
#include <sstream>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

// Lowercases the given string in place. Each byte is passed through
// std::tolower as an unsigned char to avoid UB on negative char values.
static void ToLowerCase(std::string *in_out) {
  for (auto &ch : *in_out) {
    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
  }
}

// Note: We don't use SymbolTable here since tokens may contain a blank
// in the first column
//
// Reads a tokens file where each line is "<symbol> <id>". When the symbol
// column is a blank (so the line appears to contain only one field), the
// symbol is taken to be a single space " " and the sole field is its ID.
// On any malformed line the process logs the line and exits.
static std::unordered_map<std::string, int32_t> ReadTokens(
    const std::string &tokens) {
  std::unordered_map<std::string, int32_t> token2id;

  std::ifstream is(tokens);
  std::string line;

  std::string sym;
  int32_t id;
  while (std::getline(is, line)) {
    std::istringstream iss(line);
    iss >> sym;
    if (iss.eof()) {
      // Only one field on this line: the symbol must be the blank " "
      // and the field we just read is actually the ID.
      id = atoi(sym.c_str());
      sym = " ";
    } else {
      iss >> id;
    }

    // Anything left after the ID means the line has extra fields.
    // NOTE(review): a line with trailing whitespace after the ID may not
    // set eofbit and would also be rejected here -- confirm the tokens
    // files are free of trailing spaces.
    if (!iss.eof()) {
      SHERPA_ONNX_LOGE("Error: %s", line.c_str());
      exit(-1);
    }

    // Duplicate detection is disabled; the first occurrence wins because
    // unordered_map::insert does not overwrite existing keys.
#if 0
    if (token2id.count(sym)) {
      SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
                       sym.c_str(), line.c_str(), token2id.at(sym));
      exit(-1);
    }
#endif
    token2id.insert({sym, id});
  }

  return token2id;
}

// Maps every token in `tokens` to its ID via `token2id`.
//
// Returns an empty vector if any token is missing from the map, so the
// caller can treat the whole sequence as unconvertible.
static std::vector<int32_t> ConvertTokensToIds(
    const std::unordered_map<std::string, int32_t> &token2id,
    const std::vector<std::string> &tokens) {
  std::vector<int32_t> result;
  result.reserve(tokens.size());

  for (const auto &token : tokens) {
    auto it = token2id.find(token);
    if (it == token2id.end()) {
      return {};
    }
    result.push_back(it->second);
  }

  return result;
}

// Loads the token table and the pronunciation lexicon.
//
// @param lexicon Path to the lexicon file; each line is
//                "word phone1 phone2 ...". Words are lowercased on load.
// @param tokens Path to the tokens file mapping each symbol to an ID.
// @param punctuations Space-separated punctuation symbols; each one is
//                     later emitted as its own token.
//
// Fix: a duplicated word previously triggered `return`, which silently
// dropped the rest of the lexicon AND skipped punctuation processing,
// leaving the object half-initialized. Duplicates are now logged and
// skipped (first entry wins) and loading continues.
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
                 const std::string &punctuations) {
  token2id_ = ReadTokens(tokens);
  std::ifstream is(lexicon);

  std::string word;
  std::vector<std::string> token_list;
  std::string line;
  std::string phone;

  while (std::getline(is, line)) {
    std::istringstream iss(line);

    token_list.clear();

    iss >> word;
    ToLowerCase(&word);

    if (word2ids_.count(word)) {
      SHERPA_ONNX_LOGE("Duplicated word: %s", word.c_str());
      continue;  // keep the first pronunciation; skip this entry
    }

    while (iss >> phone) {
      token_list.push_back(std::move(phone));
    }

    // Skip entries containing phones that are absent from the token table.
    std::vector<int32_t> ids = ConvertTokensToIds(token2id_, token_list);
    if (ids.empty()) {
      continue;
    }
    word2ids_.insert({std::move(word), std::move(ids)});
  }

  // process punctuations
  std::vector<std::string> punctuation_list;
  SplitStringToVector(punctuations, " ", false, &punctuation_list);
  for (auto &s : punctuation_list) {
    punctuations_.insert(std::move(s));
  }
}

// Converts a sentence to a flat sequence of token IDs.
//
// The text is lowercased and split on spaces. For each word, leading and
// trailing punctuation symbols are peeled off and mapped to their own IDs;
// the remaining core word is looked up in the lexicon. Words missing from
// the lexicon are logged and skipped entirely (their punctuation included).
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
    const std::string &_text) const {
  std::string text(_text);
  ToLowerCase(&text);

  std::vector<std::string> words;
  SplitStringToVector(text, " ", false, &words);

  std::vector<int64_t> ans;
  for (auto word : words) {
    std::vector<int64_t> head_ids;
    std::vector<int64_t> tail_ids;

    // Peel punctuation off the front of the word.
    while (!word.empty()) {
      std::string first(1, word.front());
      if (!punctuations_.count(first)) break;
      head_ids.push_back(token2id_.at(first));
      word.erase(word.begin());
    }

    // Peel punctuation off the back; tail_ids is filled outside-in.
    while (!word.empty()) {
      std::string last(1, word.back());
      if (!punctuations_.count(last)) break;
      tail_ids.push_back(token2id_.at(last));
      word.pop_back();
    }

    auto it = word2ids_.find(word);
    if (it == word2ids_.end()) {
      SHERPA_ONNX_LOGE("OOV %s. Ignore it!", word.c_str());
      continue;
    }

    ans.insert(ans.end(), head_ids.begin(), head_ids.end());
    ans.insert(ans.end(), it->second.begin(), it->second.end());
    // tail_ids was collected outside-in, so append it reversed to restore
    // left-to-right order.
    ans.insert(ans.end(), tail_ids.rbegin(), tail_ids.rend());
  }

  return ans;
}

} // namespace sherpa_onnx
31 changes: 31 additions & 0 deletions sherpa-onnx/csrc/lexicon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// sherpa-onnx/csrc/lexicon.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_LEXICON_H_
#define SHERPA_ONNX_CSRC_LEXICON_H_

#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace sherpa_onnx {

/// Maps words to sequences of token IDs using a pronunciation lexicon.
class Lexicon {
 public:
  /// @param lexicon Path to the lexicon file; each line is
  ///                "word phone1 phone2 ...".
  /// @param tokens Path to the tokens file mapping each symbol to an
  ///               integer ID.
  /// @param punctuations Space-separated punctuation symbols; each one is
  ///                     emitted as its own token ID.
  Lexicon(const std::string &lexicon, const std::string &tokens,
          const std::string &punctuations);

  /// Converts a sentence to token IDs. The text is lowercased and split
  /// on spaces; leading/trailing punctuation of each word is mapped
  /// separately; out-of-vocabulary words are logged and skipped.
  std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;

 private:
  // word -> token-ID sequence, loaded from the lexicon file.
  std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
  // Punctuation symbols that are emitted as standalone tokens.
  std::unordered_set<std::string> punctuations_;
  // symbol -> ID, loaded from the tokens file.
  std::unordered_map<std::string, int32_t> token2id_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_LEXICON_H_
19 changes: 19 additions & 0 deletions sherpa-onnx/csrc/offline-tts-impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// sherpa-onnx/csrc/offline-tts-impl.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tts-impl.h"

#include <memory>

#include "sherpa-onnx/csrc/offline-tts-vits-impl.h"

namespace sherpa_onnx {

// Factory for offline TTS implementations. Only a VITS-based
// implementation exists at the moment, so the config is forwarded to it
// unconditionally.
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
    const OfflineTtsConfig &config) {
  // TODO(fangjun): Support other types
  return std::make_unique<OfflineTtsVitsImpl>(config);
}

} // namespace sherpa_onnx
26 changes: 26 additions & 0 deletions sherpa-onnx/csrc/offline-tts-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// sherpa-onnx/csrc/offline-tts-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_

#include <memory>
#include <string>

#include "sherpa-onnx/csrc/offline-tts.h"

namespace sherpa_onnx {

/// Base class for offline (non-streaming) TTS implementations.
class OfflineTtsImpl {
 public:
  virtual ~OfflineTtsImpl() = default;

  /// Factory: creates a concrete implementation from the config.
  /// Currently it always creates a VITS-based implementation.
  static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);

  /// Converts the given text to audio.
  virtual GeneratedAudio Generate(const std::string &text) const = 0;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
45 changes: 45 additions & 0 deletions sherpa-onnx/csrc/offline-tts-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// sherpa-onnx/csrc/offline-tts-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tts-model-config.h"

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// Registers the command-line options of this config (and of the nested
// VITS config) with the given option parser.
void OfflineTtsModelConfig::Register(ParseOptions *po) {
  // Let the VITS sub-config register its own options first.
  vits.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}

// Returns true if the configuration is usable: num_threads must be
// positive and the nested VITS config must itself validate.
bool OfflineTtsModelConfig::Validate() const {
  if (num_threads >= 1) {
    return vits.Validate();
  }

  SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
  return false;
}

// Returns a human-readable one-line representation of the config for
// logging, e.g. OfflineTtsModelConfig(vits=..., num_threads=..., ...).
std::string OfflineTtsModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineTtsModelConfig("
     << "vits=" << vits.ToString() << ", "
     << "num_threads=" << num_threads << ", "
     << "debug=" << (debug ? "True" : "False") << ", "
     << "provider=\"" << provider << "\")";

  return os.str();
}

} // namespace sherpa_onnx
40 changes: 40 additions & 0 deletions sherpa-onnx/csrc/offline-tts-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// sherpa-onnx/csrc/offline-tts-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

/// Model-related options for offline TTS.
struct OfflineTtsModelConfig {
  // Config for the VITS model (currently the only supported TTS model).
  OfflineTtsVitsModelConfig vits;

  // Number of threads to run the neural network. Must be >= 1; see
  // Validate().
  int32_t num_threads = 1;
  // true to print model information while loading it.
  bool debug = false;
  // Execution provider to use: cpu, cuda, or coreml.
  std::string provider = "cpu";

  OfflineTtsModelConfig() = default;

  OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits,
                        int32_t num_threads, bool debug,
                        const std::string &provider)
      : vits(vits),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  // Registers command-line options for all fields with the parser.
  void Register(ParseOptions *po);
  // Returns true if the config is valid (num_threads >= 1 and the VITS
  // sub-config validates).
  bool Validate() const;

  // Returns a human-readable representation for logging.
  std::string ToString() const;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
Loading

0 comments on commit 536d580

Please sign in to comment.