diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc
index f5f84b8d0..310fefa78 100644
--- a/sherpa-onnx/csrc/lexicon.cc
+++ b/sherpa-onnx/csrc/lexicon.cc
@@ -75,8 +75,9 @@ static std::vector<int32_t> ConvertTokensToIds(
   return ids;
 }
 
-Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens) {
-  std::unordered_map<std::string, int32_t> token2id = ReadTokens(tokens);
+Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
+                 const std::string &punctuations) {
+  token2id_ = ReadTokens(tokens);
 
   std::ifstream is(lexicon);
   std::string word;
@@ -101,15 +102,22 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens) {
       token_list.push_back(std::move(phone));
     }
 
-    std::vector<int32_t> ids = ConvertTokensToIds(token2id, token_list);
+    std::vector<int32_t> ids = ConvertTokensToIds(token2id_, token_list);
     if (ids.empty()) {
       continue;
     }
     word2ids_.insert({std::move(word), std::move(ids)});
   }
+
+  // process punctuations
+  std::vector<std::string> punctuation_list;
+  SplitStringToVector(punctuations, " ", false, &punctuation_list);
+  for (auto &s : punctuation_list) {
+    punctuations_.insert(std::move(s));
+  }
 }
 
-std::vector<int32_t> Lexicon::ConvertTextToTokenIds(
+std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
     const std::string &_text) const {
   std::string text(_text);
   ToLowerCase(&text);
@@ -117,15 +125,30 @@ std::vector<int32_t> Lexicon::ConvertTextToTokenIds(
   std::vector<std::string> words;
   SplitStringToVector(text, " ", false, &words);
 
-  std::vector<int32_t> ans;
-  for (const auto &w : words) {
+  std::vector<int64_t> ans;
+  for (auto w : words) {
+    std::vector<int32_t> prefix;
+    while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
+      // if w begins with a punctuation
+      prefix.push_back(token2id_.at(std::string(1, w[0])));
+      w = std::string(w.begin() + 1, w.end());
+    }
+
+    std::vector<int32_t> suffix;
+    while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
+      suffix.push_back(token2id_.at(std::string(1, w.back())));
+      w = std::string(w.begin(), w.end() - 1);
+    }
+
     if (!word2ids_.count(w)) {
       SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
Ignore it!", w.c_str()); continue; } const auto &token_ids = word2ids_.at(w); + ans.insert(ans.end(), prefix.begin(), prefix.end()); ans.insert(ans.end(), token_ids.begin(), token_ids.end()); + ans.insert(ans.end(), suffix.rbegin(), suffix.rend()); } return ans; diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h index eefce319b..2e746ada9 100644 --- a/sherpa-onnx/csrc/lexicon.h +++ b/sherpa-onnx/csrc/lexicon.h @@ -8,18 +8,22 @@ #include #include #include +#include #include namespace sherpa_onnx { class Lexicon { public: - Lexicon(const std::string &lexicon, const std::string &tokens); + Lexicon(const std::string &lexicon, const std::string &tokens, + const std::string &punctuations); - std::vector ConvertTextToTokenIds(const std::string &text) const; + std::vector ConvertTextToTokenIds(const std::string &text) const; private: std::unordered_map> word2ids_; + std::unordered_set punctuations_; + std::unordered_map token2id_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index 61456794e..e9b94064c 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -4,7 +4,6 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ -#include #include #include #include @@ -21,44 +20,52 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { public: explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) : model_(std::make_unique(config.model)), - lexicon_(config.model.vits.lexicon, config.model.vits.tokens) { + lexicon_(config.model.vits.lexicon, config.model.vits.tokens, + model_->Punctuations()) { SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str()); } GeneratedAudio Generate(const std::string &text) const override { - SHERPA_ONNX_LOGE("txt: %s", text.c_str()); - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + std::vector x = lexicon_.ConvertTextToTokenIds(text); + if (x.empty()) { + SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); + return {}; + } - auto t = lexicon_.ConvertTextToTokenIds("liliana"); - for (auto i : t) { - fprintf(stderr, "%d ", i); + if (model_->AddBlank()) { + std::vector buffer(x.size() * 2 + 1); + int32_t i = 1; + for (auto k : x) { + buffer[i] = k; + i += 2; + } + x = std::move(buffer); } - fprintf(stderr, "\n"); - std::vector x = { - 0, 54, 0, 157, 0, 102, 0, 54, 0, 51, 0, 158, 0, 156, 0, 72, - 0, 56, 0, 83, 0, 3, 0, 16, 0, 157, 0, 43, 0, 135, 0, 85, - 0, 16, 0, 55, 0, 156, 0, 57, 0, 135, 0, 61, 0, 62, 0, 16, - 0, 44, 0, 52, 0, 156, 0, 63, 0, 158, 0, 125, 0, 102, 0, 48, - 0, 83, 0, 54, 0, 16, 0, 72, 0, 56, 0, 46, 0, 16, 0, 54, - 0, 156, 0, 138, 0, 64, 0, 54, 0, 51, 0, 16, 0, 70, 0, 61, - 0, 156, 0, 102, 0, 61, 0, 62, 0, 83, 0, 56, 0, 62, 0}; + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + std::array x_shape = {1, static_cast(x.size())}; Ort::Value x_tensor = Ort::Value::CreateTensor( memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); + Ort::Value audio = model_->Run(std::move(x_tensor)); std::vector audio_shape = audio.GetTensorTypeAndShapeInfo().GetShape(); - const float *p = audio.GetTensorData(); - std::ofstream os("t.pcm", std::ios::binary); - os.write(reinterpret_cast(p), sizeof(float) * audio_shape[2]); + int64_t total = 1; + // The output shape may be (1, 1, total) or (1, total) or (total,) + for (auto i : audio_shape) { + total *= i; + } - // sox -t raw -r 
+    const float *p = audio.GetTensorData<float>();
 
-    return {};
+    GeneratedAudio ans;
+    ans.sample_rate = model_->SampleRate();
+    ans.samples = std::vector<float>(p, p + total);
+    return ans;
   }
 
  private:
diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc
index c8d7e8655..2f6365132 100644
--- a/sherpa-onnx/csrc/offline-tts-vits-model.cc
+++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc
@@ -70,6 +70,8 @@ class OfflineTtsVitsModel::Impl {
   bool AddBlank() const { return add_blank_; }
 
+  std::string Punctuations() const { return punctuations_; }
+
  private:
   void Init(void *model_data, size_t model_data_length) {
     sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
@@ -91,6 +93,7 @@ class OfflineTtsVitsModel::Impl {
     Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
     SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
     SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
+    SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
   }
 
  private:
@@ -109,6 +112,7 @@ class OfflineTtsVitsModel::Impl {
 
   int32_t sample_rate_;
   int32_t add_blank_;
+  std::string punctuations_;
 };
 
 OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
@@ -124,4 +128,8 @@ int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }
 
 bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); }
 
+std::string OfflineTtsVitsModel::Punctuations() const {
+  return impl_->Punctuations();
+}
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.h b/sherpa-onnx/csrc/offline-tts-vits-model.h
index 9825c89bc..ca2c1c6be 100644
--- a/sherpa-onnx/csrc/offline-tts-vits-model.h
+++ b/sherpa-onnx/csrc/offline-tts-vits-model.h
@@ -6,6 +6,7 @@
 #define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_
 
 #include <memory>
+#include <string>
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/offline-tts-model-config.h"
@@ -32,6 +33,8 @@ class OfflineTtsVitsModel {
   // true to insert a blank between each token
   bool AddBlank() const;
 
+  std::string Punctuations() const;
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
index 1d5d6cd57..18b520f48 100644
--- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
@@ -2,6 +2,8 @@
 //
 // Copyright (c) 2023 Xiaomi Corporation
 
+#include <fstream>
+
 #include "sherpa-onnx/csrc/offline-tts.h"
 #include "sherpa-onnx/csrc/parse-options.h"
 
@@ -13,7 +15,7 @@ Offline text-to-speech with sherpa-onnx
   --vits-model /path/to/model.onnx \
   --vits-lexicon /path/to/lexicon.txt \
   --vits-tokens /path/to/tokens.txt
-  "some text within double quotes"
+  'some text within single quotes'
 
 It will generate a file test.wav.
 )usage";
@@ -23,19 +25,33 @@ It will generate a file test.wav.
   config.Register(&po);
   po.Read(argc, argv);
 
-  if (po.NumArgs() != 1) {
+  if (po.NumArgs() == 0) {
     fprintf(stderr, "Error: Please provide the text to generate audio.\n\n");
     po.PrintUsage();
     exit(EXIT_FAILURE);
   }
 
+  if (po.NumArgs() > 1) {
+    fprintf(stderr,
+            "Error: Accept only one positional argument. Please use single "
+            "quotes to wrap your text\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
   if (!config.Validate()) {
     fprintf(stderr, "Errors in config!\n");
     exit(EXIT_FAILURE);
   }
 
   sherpa_onnx::OfflineTts tts(config);
-  tts.Generate("hello world\n");
+  auto audio = tts.Generate(po.GetArg(1));
+
+  std::ofstream os("t.pcm", std::ios::binary);
+  os.write(reinterpret_cast<const char *>(audio.samples.data()),
+           sizeof(float) * audio.samples.size());
+
+  // sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav
 
   return 0;
 }
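
For reference, here is a minimal standalone sketch of the per-word punctuation handling that the patched Lexicon::ConvertTextToTokenIds() performs. The token table and punctuation set below are toy stand-ins invented for illustration (the real IDs come from tokens.txt and the model's punctuation metadata); only the two stripping loops mirror the patch.

// Toy demonstration of the prefix/suffix punctuation stripping added to
// Lexicon::ConvertTextToTokenIds(). The maps are illustrative stand-ins,
// not the real token table.
#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  std::unordered_map<std::string, int32_t> token2id = {{",", 1}, {"!", 2}, {"\"", 3}};
  std::unordered_set<std::string> punctuations = {",", "!", "\""};

  std::string w = "\"hello!\"";  // a word with leading and trailing punctuation

  // Strip punctuation from the front; remember its IDs in order.
  std::vector<int32_t> prefix;
  while (!w.empty() && punctuations.count(std::string(1, w[0]))) {
    prefix.push_back(token2id.at(std::string(1, w[0])));
    w = std::string(w.begin() + 1, w.end());
  }

  // Strip punctuation from the back; IDs are collected in reverse order.
  std::vector<int32_t> suffix;
  while (!w.empty() && punctuations.count(std::string(1, w.back()))) {
    suffix.push_back(token2id.at(std::string(1, w.back())));
    w = std::string(w.begin(), w.end() - 1);
  }

  // In the patch, the bare word "hello" is then looked up in word2ids_;
  // prefix IDs go before the word's phone IDs and suffix IDs are appended
  // via rbegin()/rend() so the original order is preserved.
  std::printf("word: %s, prefix ids: %zu, suffix ids: %zu\n", w.c_str(),
              prefix.size(), suffix.size());
  return 0;
}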
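Similarly, a minimal sketch of the blank interleaving that Generate() applies when the model metadata sets add_blank. The helper name InterleaveWithBlank is invented for illustration (the patch does this inline), and token ID 0 is assumed to be the blank symbol.

// Toy demonstration of the add_blank interleaving in Generate(): a sequence
// of n tokens becomes 2n + 1 entries with blanks (ID 0) in between.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int64_t> InterleaveWithBlank(const std::vector<int64_t> &x) {
  std::vector<int64_t> buffer(x.size() * 2 + 1);  // value-initialized to 0 (blanks)
  int32_t i = 1;
  for (auto k : x) {
    buffer[i] = k;  // tokens land on the odd positions
    i += 2;
  }
  return buffer;
}

int main() {
  std::vector<int64_t> x = {54, 157, 102};
  for (auto k : InterleaveWithBlank(x)) {
    std::printf("%d ", static_cast<int>(k));  // prints: 0 54 0 157 0 102 0
  }
  std::printf("\n");
  return 0;
}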