diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index d3538480f8..82f53bca61 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -93,6 +93,7 @@ list(APPEND sources offline-tts-vits-model-config.cc offline-tts-vits-model.cc offline-tts.cc + wave-writer.cc ) if(SHERPA_ONNX_ENABLE_CHECK) diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index 310fefa785..a2e4af68c4 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -53,7 +53,7 @@ static std::unordered_map ReadTokens( exit(-1); } #endif - token2id.insert({sym, id}); + token2id.insert({std::move(sym), id}); } return token2id; @@ -78,6 +78,7 @@ static std::vector ConvertTokensToIds( Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, const std::string &punctuations) { token2id_ = ReadTokens(tokens); + blank_ = token2id_.at(" "); std::ifstream is(lexicon); std::string word; @@ -149,6 +150,11 @@ std::vector Lexicon::ConvertTextToTokenIds( ans.insert(ans.end(), prefix.begin(), prefix.end()); ans.insert(ans.end(), token_ids.begin(), token_ids.end()); ans.insert(ans.end(), suffix.rbegin(), suffix.rend()); + ans.push_back(blank_); + } + + if (!ans.empty()) { + ans.resize(ans.size() - 1); } return ans; diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h index 2e746ada91..73d6c8a8d4 100644 --- a/sherpa-onnx/csrc/lexicon.h +++ b/sherpa-onnx/csrc/lexicon.h @@ -24,6 +24,7 @@ class Lexicon { std::unordered_map> word2ids_; std::unordered_set punctuations_; std::unordered_map token2id_; + int32_t blank_; // ID for the blank token }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc index 18b520f482..0354218b27 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc @@ -6,6 +6,7 @@ #include "sherpa-onnx/csrc/offline-tts.h" #include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/wave-writer.h" int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( @@ -15,13 +16,34 @@ Offline text-to-speech with sherpa-onnx --vits-model /path/to/model.onnx \ --vits-lexicon /path/to/lexicon.txt \ --vits-tokens /path/to/tokens.txt + --output-filename ./generated.wav \ 'some text within single quotes' -It will generate a file test.wav. +It will generate a file ./generated.wav as specified by --output-filename. + +You can download a test model from +https://huggingface.co/csukuangfj/vits-ljs + +For instance, you can use: +wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx +wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt +wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt + +./bin/sherpa-onnx-offline-tts \ + --vits-model=./vits-ljs.onnx \ + --vits-lexicon=./lexicon.txt \ + --vits-tokens=./tokens.txt \ + --output-filename=./generated.wav \ + 'liliana, the most beautiful and lovely assistant of our team!' )usage"; sherpa_onnx::ParseOptions po(kUsageMessage); + std::string output_filename = "./generated.wav"; + po.Register("output-filename", &output_filename, + "Path to save the generated audio"); + sherpa_onnx::OfflineTtsConfig config; + config.Register(&po); po.Read(argc, argv); @@ -47,11 +69,15 @@ It will generate a file test.wav. sherpa_onnx::OfflineTts tts(config); auto audio = tts.Generate(po.GetArg(1)); - std::ofstream os("t.pcm", std::ios::binary); - os.write(reinterpret_cast(audio.samples.data()), - sizeof(float) * audio.samples.size()); + bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate, + audio.samples.data(), audio.samples.size()); + if (!ok) { + fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str()); + exit(EXIT_FAILURE); + } - // sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav + fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str()); + fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str()); return 0; } diff --git a/sherpa-onnx/csrc/wave-writer.cc b/sherpa-onnx/csrc/wave-writer.cc new file mode 100644 index 0000000000..f20af4b13c --- /dev/null +++ b/sherpa-onnx/csrc/wave-writer.cc @@ -0,0 +1,82 @@ +// sherpa-onnx/csrc/wave-writer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/wave-writer.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { +namespace { + +// see http://soundfile.sapp.org/doc/WaveFormat/ +// +// Note: We assume little endian here +// TODO(fangjun): Support big endian +struct WaveHeader { + int32_t chunk_id; + int32_t chunk_size; + int32_t format; + int32_t subchunk1_id; + int32_t subchunk1_size; + int16_t audio_format; + int16_t num_channels; + int32_t sample_rate; + int32_t byte_rate; + int16_t block_align; + int16_t bits_per_sample; + int32_t subchunk2_id; // a tag of this chunk + int32_t subchunk2_size; // size of subchunk2 +}; + +} // namespace + +bool WriteWave(const std::string &filename, int32_t sampling_rate, + const float *samples, int32_t n) { + WaveHeader header; + header.chunk_id = 0x46464952; // FFIR + header.format = 0x45564157; // EVAW + header.subchunk1_id = 0x20746d66; // "fmt " + header.subchunk1_size = 16; // 16 for PCM + header.audio_format = 1; // PCM =1 + + int32_t num_channels = 1; + int32_t bits_per_sample = 16; // int16_t + header.num_channels = num_channels; + header.sample_rate = sampling_rate; + header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8; + header.block_align = num_channels * bits_per_sample / 8; + header.bits_per_sample = bits_per_sample; + header.subchunk2_id = 0x61746164; // atad + header.subchunk2_size = n * num_channels * bits_per_sample / 8; + + header.chunk_size = 36 + header.subchunk2_size; + + std::vector samples_int16(n); + for (int32_t i = 0; i != n; ++i) { + samples_int16[i] = samples[i] * 32676; + } + + std::ofstream os(filename, std::ios::binary); + if (!os) { + SHERPA_ONNX_LOGE("Failed to create %s", filename.c_str()); + return false; + } + + os.write(reinterpret_cast(&header), sizeof(header)); + os.write(reinterpret_cast(samples_int16.data()), + samples_int16.size() * sizeof(int16_t)); + + if (!os) { + SHERPA_ONNX_LOGE("Write %s failed", filename.c_str()); + return false; + } + + return true; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/wave-writer.h b/sherpa-onnx/csrc/wave-writer.h new file mode 100644 index 0000000000..bae4c504b0 --- /dev/null +++ b/sherpa-onnx/csrc/wave-writer.h @@ -0,0 +1,27 @@ +// sherpa-onnx/csrc/wave-writer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_WAVE_WRITER_H_ +#define SHERPA_ONNX_CSRC_WAVE_WRITER_H_ + +#include +#include + +namespace sherpa_onnx { + +// Write a single channel wave file. +// Note that the input samples are in the range [-1, 1]. It will be multiplied +// by 32767 and saved in int16_t format in the wave file. +// +// @param filename Path to save the samples. +// @param sampling_rate Sample rate of the samples. +// @param samples Pointer to the samples +// @param n Number of samples +// @return Return true if the write succeeds; return false otherwise. +bool WriteWave(const std::string &filename, int32_t sampling_rate, + const float *samples, int32_t n); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_WAVE_WRITER_H_