Skip to content

Commit

Permalink
Validate input sid (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 18, 2023
1 parent 1ee79e3 commit 8545c3b
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is ON on Linux" ON)
option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON)

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
Expand Down
5 changes: 5 additions & 0 deletions python-api-examples/offline-tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ def main():
start = time.time()
audio = tts.generate(args.text, sid=args.sid)
end = time.time()

if len(audio.samples) == 0:
print("Error in generating audios. Please read previous error messages.")
return

elapsed_seconds = end - start
audio_duration = len(audio.samples) / audio.sample_rate
real_time_factor = elapsed_seconds / audio_duration
Expand Down
24 changes: 14 additions & 10 deletions sherpa-onnx/csrc/lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(

std::vector<int64_t> ans;

ans.push_back(token2id_.at("sil"));
auto sil = token2id_.at("sil");
auto eos = token2id_.at("eos");

ans.push_back(sil);

for (const auto &w : words) {
if (punctuations_.count(w)) {
ans.push_back(sil);
continue;
}

if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
Expand All @@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
}
ans.push_back(token2id_.at("sil"));
ans.push_back(token2id_.at("eos"));
ans.push_back(sil);
ans.push_back(eos);
return ans;
}

Expand All @@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
ToLowerCase(&text);

std::vector<std::string> words = SplitUtf8(text);
int32_t blank = token2id_.at(" ");

std::vector<int64_t> ans;
for (const auto &w : words) {
Expand All @@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(

const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
if (blank_ != -1) {
ans.push_back(blank_);
}
ans.push_back(blank);
}

if (blank_ != -1 && !ans.empty()) {
if (!ans.empty()) {
// remove the last blank
ans.resize(ans.size() - 1);
}
Expand All @@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(

void Lexicon::InitTokens(const std::string &tokens) {
token2id_ = ReadTokens(tokens);
if (token2id_.count(" ")) {
blank_ = token2id_.at(" ");
}
}

void Lexicon::InitLanguage(const std::string &_lang) {
Expand Down
1 change: 0 additions & 1 deletion sherpa-onnx/csrc/lexicon.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class Lexicon {
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
std::unordered_set<std::string> punctuations_;
std::unordered_map<std::string, int32_t> token2id_;
int32_t blank_ = -1; // ID for the blank token
Language language_;
//
};
Expand Down
17 changes: 17 additions & 0 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,23 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {

GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const override {
int32_t num_speakers = model_->NumSpeakers();
if (num_speakers == 0 && sid != 0) {
SHERPA_ONNX_LOGE(
"This is a single-speaker model and supports only sid 0. Given sid: "
"%d",
sid);
return {};
}

if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) {
SHERPA_ONNX_LOGE(
"This model contains only %d speakers. sid should be in the range "
"[%d, %d]. Given: %d",
num_speakers, 0, num_speakers - 1, sid);
return {};
}

std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
if (x.empty()) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
Expand Down
9 changes: 7 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class OfflineTtsVitsModel::Impl {

std::string Punctuations() const { return punctuations_; }
std::string Language() const { return language_; }
int32_t NumSpeakers() const { return num_speakers_; }

private:
void Init(void *model_data, size_t model_data_length) {
Expand All @@ -107,7 +108,7 @@ class OfflineTtsVitsModel::Impl {
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
}
Expand All @@ -128,7 +129,7 @@ class OfflineTtsVitsModel::Impl {

int32_t sample_rate_;
int32_t add_blank_;
int32_t n_speakers_;
int32_t num_speakers_;
std::string punctuations_;
std::string language_;
};
Expand All @@ -152,4 +153,8 @@ std::string OfflineTtsVitsModel::Punctuations() const {

std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }

int32_t OfflineTtsVitsModel::NumSpeakers() const {
return impl_->NumSpeakers();
}

} // namespace sherpa_onnx
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-tts-vits-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class OfflineTtsVitsModel {

std::string Punctuations() const;
std::string Language() const;
int32_t NumSpeakers() const;

private:
class Impl;
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ or detailes.

sherpa_onnx::OfflineTts tts(config);
auto audio = tts.Generate(po.GetArg(1), sid);
if (audio.samples.empty()) {
fprintf(
stderr,
"Error in generating audios. Please read previous error messages.\n");
exit(EXIT_FAILURE);
}

bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
audio.samples.data(), audio.samples.size());
Expand Down

0 comments on commit 8545c3b

Please sign in to comment.