Skip to content

Commit

Permalink
Add a tts example
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 13, 2023
1 parent 0f1c9d9 commit 017e471
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 33 deletions.
35 changes: 29 additions & 6 deletions sherpa-onnx/csrc/lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ static std::vector<int32_t> ConvertTokensToIds(
return ids;
}

Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens) {
std::unordered_map<std::string, int32_t> token2id = ReadTokens(tokens);
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations) {
token2id_ = ReadTokens(tokens);
std::ifstream is(lexicon);

std::string word;
Expand All @@ -101,31 +102,53 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens) {
token_list.push_back(std::move(phone));
}

std::vector<int32_t> ids = ConvertTokensToIds(token2id, token_list);
std::vector<int32_t> ids = ConvertTokensToIds(token2id_, token_list);
if (ids.empty()) {
continue;
}
word2ids_.insert({std::move(word), std::move(ids)});
}

// process punctuations
std::vector<std::string> punctuation_list;
SplitStringToVector(punctuations, " ", false, &punctuation_list);
for (auto &s : punctuation_list) {
punctuations_.insert(std::move(s));
}
}

std::vector<int32_t> Lexicon::ConvertTextToTokenIds(
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);

std::vector<std::string> words;
SplitStringToVector(text, " ", false, &words);

std::vector<int32_t> ans;
for (const auto &w : words) {
std::vector<int64_t> ans;
for (auto w : words) {
std::vector<int64_t> prefix;
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
// if w begins with a punctuation
prefix.push_back(token2id_.at(std::string(1, w[0])));
w = std::string(w.begin() + 1, w.end());
}

std::vector<int64_t> suffix;
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
suffix.push_back(token2id_.at(std::string(1, w.back())));
w = std::string(w.begin(), w.end() - 1);
}

if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}

const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), prefix.begin(), prefix.end());
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
}

return ans;
Expand Down
8 changes: 6 additions & 2 deletions sherpa-onnx/csrc/lexicon.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace sherpa_onnx {

class Lexicon {
public:
Lexicon(const std::string &lexicon, const std::string &tokens);
Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations);

std::vector<int32_t> ConvertTextToTokenIds(const std::string &text) const;
std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;

private:
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
std::unordered_set<std::string> punctuations_;
std::unordered_map<std::string, int32_t> token2id_;
};

} // namespace sherpa_onnx
Expand Down
51 changes: 29 additions & 22 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_

#include <fstream>
#include <memory>
#include <string>
#include <utility>
Expand All @@ -21,44 +20,52 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
public:
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
lexicon_(config.model.vits.lexicon, config.model.vits.tokens) {
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {
SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str());
}

GeneratedAudio Generate(const std::string &text) const override {
SHERPA_ONNX_LOGE("txt: %s", text.c_str());
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
if (x.empty()) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
return {};
}

auto t = lexicon_.ConvertTextToTokenIds("liliana");
for (auto i : t) {
fprintf(stderr, "%d ", i);
if (model_->AddBlank()) {
std::vector<int64_t> buffer(x.size() * 2 + 1);
int32_t i = 1;
for (auto k : x) {
buffer[i] = k;
i += 2;
}
x = std::move(buffer);
}
fprintf(stderr, "\n");

std::vector<int64_t> x = {
0, 54, 0, 157, 0, 102, 0, 54, 0, 51, 0, 158, 0, 156, 0, 72,
0, 56, 0, 83, 0, 3, 0, 16, 0, 157, 0, 43, 0, 135, 0, 85,
0, 16, 0, 55, 0, 156, 0, 57, 0, 135, 0, 61, 0, 62, 0, 16,
0, 44, 0, 52, 0, 156, 0, 63, 0, 158, 0, 125, 0, 102, 0, 48,
0, 83, 0, 54, 0, 16, 0, 72, 0, 56, 0, 46, 0, 16, 0, 54,
0, 156, 0, 138, 0, 64, 0, 54, 0, 51, 0, 16, 0, 70, 0, 61,
0, 156, 0, 102, 0, 61, 0, 62, 0, 83, 0, 56, 0, 62, 0};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

std::array<int64_t, 2> x_shape = {1, static_cast<int32_t>(x.size())};
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());

Ort::Value audio = model_->Run(std::move(x_tensor));

std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();

const float *p = audio.GetTensorData<float>();
std::ofstream os("t.pcm", std::ios::binary);
os.write(reinterpret_cast<const char *>(p), sizeof(float) * audio_shape[2]);
int64_t total = 1;
// The output shape may be (1, 1, total) or (1, total) or (total,)
for (auto i : audio_shape) {
total *= i;
}

// sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav
const float *p = audio.GetTensorData<float>();

return {};
GeneratedAudio ans;
ans.sample_rate = model_->SampleRate();
ans.samples = std::vector<float>(p, p + total);
return ans;
}

private:
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/offline-tts-vits-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class OfflineTtsVitsModel::Impl {

bool AddBlank() const { return add_blank_; }

std::string Punctuations() const { return punctuations_; }

private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
Expand All @@ -91,6 +93,7 @@ class OfflineTtsVitsModel::Impl {
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
}

private:
Expand All @@ -109,6 +112,7 @@ class OfflineTtsVitsModel::Impl {

int32_t sample_rate_;
int32_t add_blank_;
std::string punctuations_;
};

OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
Expand All @@ -124,4 +128,8 @@ int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }

bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); }

std::string OfflineTtsVitsModel::Punctuations() const {
return impl_->Punctuations();
}

} // namespace sherpa_onnx
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/offline-tts-vits-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_

#include <memory>
#include <string>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
Expand All @@ -32,6 +33,8 @@ class OfflineTtsVitsModel {
// true to insert a blank between each token
bool AddBlank() const;

std::string Punctuations() const;

private:
class Impl;
std::unique_ptr<Impl> impl_;
Expand Down
22 changes: 19 additions & 3 deletions sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//
// Copyright (c) 2023 Xiaomi Corporation

#include <fstream>

#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/parse-options.h"

Expand All @@ -13,7 +15,7 @@ Offline text-to-speech with sherpa-onnx
--vits-model /path/to/model.onnx \
--vits-lexicon /path/to/lexicon.txt \
--vits-tokens /path/to/tokens.txt
"some text within double quotes"
'some text within single quotes'
It will generate a file test.wav.
)usage";
Expand All @@ -23,19 +25,33 @@ It will generate a file test.wav.
config.Register(&po);
po.Read(argc, argv);

if (po.NumArgs() != 1) {
if (po.NumArgs() == 0) {
fprintf(stderr, "Error: Please provide the text to generate audio.\n\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}

if (po.NumArgs() > 1) {
fprintf(stderr,
"Error: Accept only one positional argument. Please use single "
"quotes to wrap your text\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}

if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
exit(EXIT_FAILURE);
}

sherpa_onnx::OfflineTts tts(config);
tts.Generate("hello world\n");
auto audio = tts.Generate(po.GetArg(1));

std::ofstream os("t.pcm", std::ios::binary);
os.write(reinterpret_cast<const char *>(audio.samples.data()),
sizeof(float) * audio.samples.size());

// sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav

return 0;
}

0 comments on commit 017e471

Please sign in to comment.