diff --git a/python-api-examples/offline-tts.py b/python-api-examples/offline-tts.py index f36ea6f49..630454d71 100755 --- a/python-api-examples/offline-tts.py +++ b/python-api-examples/offline-tts.py @@ -124,6 +124,11 @@ def main(): start = time.time() audio = tts.generate(args.text, sid=args.sid) end = time.time() + + if len(audio.samples) == 0: + print("Error in generating audios. Please read previous error messages.") + return + elapsed_seconds = end - start audio_duration = len(audio.samples) / audio.sample_rate real_time_factor = elapsed_seconds / audio_duration diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index 3707f1489..87203f4e1 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -104,9 +104,17 @@ std::vector Lexicon::ConvertTextToTokenIdsChinese( std::vector ans; - ans.push_back(token2id_.at("sil")); + auto sil = token2id_.at("sil"); + auto eos = token2id_.at("eos"); + + ans.push_back(sil); for (const auto &w : words) { + if (punctuations_.count(w)) { + ans.push_back(sil); + continue; + } + if (!word2ids_.count(w)) { SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); continue; @@ -115,8 +123,8 @@ std::vector Lexicon::ConvertTextToTokenIdsChinese( const auto &token_ids = word2ids_.at(w); ans.insert(ans.end(), token_ids.begin(), token_ids.end()); } - ans.push_back(token2id_.at("sil")); - ans.push_back(token2id_.at("eos")); + ans.push_back(sil); + ans.push_back(eos); return ans; } @@ -126,6 +134,7 @@ std::vector Lexicon::ConvertTextToTokenIdsEnglish( ToLowerCase(&text); std::vector words = SplitUtf8(text); + int32_t blank = token2id_.at(" "); std::vector ans; for (const auto &w : words) { @@ -141,12 +150,10 @@ std::vector Lexicon::ConvertTextToTokenIdsEnglish( const auto &token_ids = word2ids_.at(w); ans.insert(ans.end(), token_ids.begin(), token_ids.end()); - if (blank_ != -1) { - ans.push_back(blank_); - } + ans.push_back(blank); } - if (blank_ != -1 && !ans.empty()) { + if (!ans.empty()) { // remove the last blank ans.resize(ans.size() - 1); } @@ -156,9 +163,6 @@ std::vector Lexicon::ConvertTextToTokenIdsEnglish( void Lexicon::InitTokens(const std::string &tokens) { token2id_ = ReadTokens(tokens); - if (token2id_.count(" ")) { - blank_ = token2id_.at(" "); - } } void Lexicon::InitLanguage(const std::string &_lang) { diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h index 74e374ee1..35bace8f8 100644 --- a/sherpa-onnx/csrc/lexicon.h +++ b/sherpa-onnx/csrc/lexicon.h @@ -44,7 +44,6 @@ class Lexicon { std::unordered_map> word2ids_; std::unordered_set punctuations_; std::unordered_map token2id_; - int32_t blank_ = -1; // ID for the blank token Language language_; // }; diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index 7174e78cd..b553ebe4a 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -25,6 +25,23 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { GeneratedAudio Generate(const std::string &text, int64_t sid = 0) const override { + int32_t num_speakers = model_->NumSpeakers(); + if (num_speakers == 0 && sid != 0) { + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%d", + sid); + return {}; + } + + if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { + SHERPA_ONNX_LOGE( + "This model contains only %d speakers. sid should be in the range " + "[%d, %d]. Given: %d", + num_speakers, 0, num_speakers - 1, sid); + return {}; + } + std::vector x = lexicon_.ConvertTextToTokenIds(text); if (x.empty()) { SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index 06aab516f..4f2ae9ec0 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -85,6 +85,7 @@ class OfflineTtsVitsModel::Impl { std::string Punctuations() const { return punctuations_; } std::string Language() const { return language_; } + int32_t NumSpeakers() const { return num_speakers_; } private: void Init(void *model_data, size_t model_data_length) { @@ -107,7 +108,7 @@ class OfflineTtsVitsModel::Impl { Ort::AllocatorWithDefaultOptions allocator; // used in the macro below SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); - SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers"); + SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers"); SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); } @@ -128,7 +129,7 @@ class OfflineTtsVitsModel::Impl { int32_t sample_rate_; int32_t add_blank_; - int32_t n_speakers_; + int32_t num_speakers_; std::string punctuations_; std::string language_; }; @@ -152,4 +153,8 @@ std::string OfflineTtsVitsModel::Punctuations() const { std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } +int32_t OfflineTtsVitsModel::NumSpeakers() const { + return impl_->NumSpeakers(); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.h b/sherpa-onnx/csrc/offline-tts-vits-model.h index 0c8208d53..a3870fbd7 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model.h @@ -39,6 +39,7 @@ class OfflineTtsVitsModel { std::string Punctuations() const; std::string Language() const; + int32_t NumSpeakers() const; private: class Impl; diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc index bd4fc4e00..6097468a2 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc @@ -81,6 +81,12 @@ or detailes. sherpa_onnx::OfflineTts tts(config); auto audio = tts.Generate(po.GetArg(1), sid); + if (audio.samples.empty()) { + fprintf( + stderr, + "Error in generating audios. Please read previous error messages.\n"); + exit(EXIT_FAILURE); + } bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate, audio.samples.data(), audio.samples.size());