Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support writing generated audio samples to wave files #363

Merged
merged 2 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ list(APPEND sources
offline-tts-vits-model-config.cc
offline-tts-vits-model.cc
offline-tts.cc
wave-writer.cc
)

if(SHERPA_ONNX_ENABLE_CHECK)
Expand Down
8 changes: 7 additions & 1 deletion sherpa-onnx/csrc/lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ static std::unordered_map<std::string, int32_t> ReadTokens(
exit(-1);
}
#endif
token2id.insert({sym, id});
token2id.insert({std::move(sym), id});
}

return token2id;
Expand All @@ -78,6 +78,7 @@ static std::vector<int32_t> ConvertTokensToIds(
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations) {
token2id_ = ReadTokens(tokens);
blank_ = token2id_.at(" ");
std::ifstream is(lexicon);

std::string word;
Expand Down Expand Up @@ -149,6 +150,11 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
ans.insert(ans.end(), prefix.begin(), prefix.end());
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
ans.push_back(blank_);
}

if (!ans.empty()) {
ans.resize(ans.size() - 1);
}

return ans;
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/lexicon.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Lexicon {
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
std::unordered_set<std::string> punctuations_;
std::unordered_map<std::string, int32_t> token2id_;
int32_t blank_; // ID for the blank token
};

} // namespace sherpa_onnx
Expand Down
36 changes: 31 additions & 5 deletions sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-writer.h"

int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Expand All @@ -15,13 +16,34 @@ Offline text-to-speech with sherpa-onnx
--vits-model /path/to/model.onnx \
--vits-lexicon /path/to/lexicon.txt \
--vits-tokens /path/to/tokens.txt
--output-filename ./generated.wav \
'some text within single quotes'

It will generate a file test.wav.
It will generate a file ./generated.wav as specified by --output-filename.

You can download a test model from
https://huggingface.co/csukuangfj/vits-ljs

For instance, you can use:
wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt

./bin/sherpa-onnx-offline-tts \
--vits-model=./vits-ljs.onnx \
--vits-lexicon=./lexicon.txt \
--vits-tokens=./tokens.txt \
--output-filename=./generated.wav \
'liliana, the most beautiful and lovely assistant of our team!'
)usage";

sherpa_onnx::ParseOptions po(kUsageMessage);
std::string output_filename = "./generated.wav";
po.Register("output-filename", &output_filename,
"Path to save the generated audio");

sherpa_onnx::OfflineTtsConfig config;

config.Register(&po);
po.Read(argc, argv);

Expand All @@ -47,11 +69,15 @@ It will generate a file test.wav.
sherpa_onnx::OfflineTts tts(config);
auto audio = tts.Generate(po.GetArg(1));

std::ofstream os("t.pcm", std::ios::binary);
os.write(reinterpret_cast<const char *>(audio.samples.data()),
sizeof(float) * audio.samples.size());
bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
audio.samples.data(), audio.samples.size());
if (!ok) {
fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str());
exit(EXIT_FAILURE);
}

// sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav
fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str());
fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());

return 0;
}
82 changes: 82 additions & 0 deletions sherpa-onnx/csrc/wave-writer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// sherpa-onnx/csrc/wave-writer.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/wave-writer.h"

#include <fstream>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {
namespace {

// see http://soundfile.sapp.org/doc/WaveFormat/
//
// Note: We assume little endian here
// TODO(fangjun): Support big endian
struct WaveHeader {
int32_t chunk_id;
int32_t chunk_size;
int32_t format;
int32_t subchunk1_id;
int32_t subchunk1_size;
int16_t audio_format;
int16_t num_channels;
int32_t sample_rate;
int32_t byte_rate;
int16_t block_align;
int16_t bits_per_sample;
int32_t subchunk2_id; // a tag of this chunk
int32_t subchunk2_size; // size of subchunk2
};

} // namespace

bool WriteWave(const std::string &filename, int32_t sampling_rate,
const float *samples, int32_t n) {
WaveHeader header;
header.chunk_id = 0x46464952; // FFIR
header.format = 0x45564157; // EVAW
header.subchunk1_id = 0x20746d66; // "fmt "
header.subchunk1_size = 16; // 16 for PCM
header.audio_format = 1; // PCM =1

int32_t num_channels = 1;
int32_t bits_per_sample = 16; // int16_t
header.num_channels = num_channels;
header.sample_rate = sampling_rate;
header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8;
header.block_align = num_channels * bits_per_sample / 8;
header.bits_per_sample = bits_per_sample;
header.subchunk2_id = 0x61746164; // atad
header.subchunk2_size = n * num_channels * bits_per_sample / 8;

header.chunk_size = 36 + header.subchunk2_size;

std::vector<int16_t> samples_int16(n);
for (int32_t i = 0; i != n; ++i) {
samples_int16[i] = samples[i] * 32676;
}

std::ofstream os(filename, std::ios::binary);
if (!os) {
SHERPA_ONNX_LOGE("Failed to create %s", filename.c_str());
return false;
}

os.write(reinterpret_cast<const char *>(&header), sizeof(header));
os.write(reinterpret_cast<const char *>(samples_int16.data()),
samples_int16.size() * sizeof(int16_t));

if (!os) {
SHERPA_ONNX_LOGE("Write %s failed", filename.c_str());
return false;
}

return true;
}

} // namespace sherpa_onnx
27 changes: 27 additions & 0 deletions sherpa-onnx/csrc/wave-writer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// sherpa-onnx/csrc/wave-writer.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_WAVE_WRITER_H_
#define SHERPA_ONNX_CSRC_WAVE_WRITER_H_

#include <cstdint>
#include <string>

namespace sherpa_onnx {

// Write a single channel wave file.
// Note that the input samples are in the range [-1, 1]. It will be multiplied
// by 32767 and saved in int16_t format in the wave file.
//
// @param filename Path to save the samples.
// @param sampling_rate Sample rate of the samples.
// @param samples Pointer to the samples
// @param n Number of samples
// @return Return true if the write succeeds; return false otherwise.
bool WriteWave(const std::string &filename, int32_t sampling_rate,
const float *samples, int32_t n);

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_WAVE_WRITER_H_