-
Notifications
You must be signed in to change notification settings - Fork 478
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4771c92
commit 536d580
Showing
17 changed files
with
839 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
// sherpa-onnx/csrc/lexicon.cc | ||
// | ||
// Copyright (c) 2022-2023 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/lexicon.h" | ||
|
||
#include <algorithm> | ||
#include <cctype> | ||
#include <fstream> | ||
#include <sstream> | ||
#include <utility> | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
#include "sherpa-onnx/csrc/text-utils.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Lower-cases the string pointed to by `in_out` in place.
//
// Each char is passed through std::tolower as an unsigned char, so bytes
// with the high bit set are never handed to std::tolower as negative values.
static void ToLowerCase(std::string *in_out) {
  for (char &ch : *in_out) {
    ch = std::tolower(static_cast<unsigned char>(ch));
  }
}
|
||
// Note: We don't use SymbolTable here since tokens may contain a blank | ||
// in the first column | ||
static std::unordered_map<std::string, int32_t> ReadTokens( | ||
const std::string &tokens) { | ||
std::unordered_map<std::string, int32_t> token2id; | ||
|
||
std::ifstream is(tokens); | ||
std::string line; | ||
|
||
std::string sym; | ||
int32_t id; | ||
while (std::getline(is, line)) { | ||
std::istringstream iss(line); | ||
iss >> sym; | ||
if (iss.eof()) { | ||
id = atoi(sym.c_str()); | ||
sym = " "; | ||
} else { | ||
iss >> id; | ||
} | ||
|
||
if (!iss.eof()) { | ||
SHERPA_ONNX_LOGE("Error: %s", line.c_str()); | ||
exit(-1); | ||
} | ||
|
||
#if 0 | ||
if (token2id.count(sym)) { | ||
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", | ||
sym.c_str(), line.c_str(), token2id.at(sym)); | ||
exit(-1); | ||
} | ||
#endif | ||
token2id.insert({sym, id}); | ||
} | ||
|
||
return token2id; | ||
} | ||
|
||
// Translates every entry of `tokens` into its integer id via `token2id`.
//
// The returned ids are in the same order as `tokens`. If at least one token
// has no entry in `token2id`, an empty vector is returned instead; an empty
// `tokens` also yields an empty vector.
static std::vector<int32_t> ConvertTokensToIds(
    const std::unordered_map<std::string, int32_t> &token2id,
    const std::vector<std::string> &tokens) {
  std::vector<int32_t> result;
  result.reserve(tokens.size());

  for (const auto &token : tokens) {
    const auto it = token2id.find(token);
    if (it == token2id.end()) {
      // Unknown token: give up on the whole sequence.
      return {};
    }
    result.push_back(it->second);
  }

  return result;
}
|
||
// Builds the lexicon.
//
// @param lexicon       Path to the lexicon file. Each line holds a word
//                      followed by the tokens it is made of, separated by
//                      whitespace. Words are lower-cased before being stored.
// @param tokens        Path to the token table, loaded with ReadTokens().
// @param punctuations  Punctuation marks separated by a single space.
//
// Lines are handled as follows:
//  - Blank lines are skipped.
//  - A word that was already seen is logged and skipped; the first entry is
//    kept and reading continues.
//  - A word with no tokens, or with a token that is missing from the token
//    table, is skipped silently.
//
// If the lexicon file cannot be opened, an error is logged and the lexicon
// stays empty.
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
                 const std::string &punctuations) {
  token2id_ = ReadTokens(tokens);

  std::ifstream is(lexicon);
  if (!is) {
    SHERPA_ONNX_LOGE("Failed to open lexicon file: %s", lexicon.c_str());
  }

  std::string word;
  std::vector<std::string> token_list;
  std::string line;
  std::string phone;

  while (std::getline(is, line)) {
    std::istringstream iss(line);

    token_list.clear();

    if (!(iss >> word)) {
      // Blank line. Do not rely on whatever `word` held before.
      continue;
    }
    ToLowerCase(&word);

    if (word2ids_.count(word)) {
      // Keep the first pronunciation and go on. Returning here would drop
      // the rest of the lexicon and leave punctuations_ empty.
      SHERPA_ONNX_LOGE("Duplicated word: %s", word.c_str());
      continue;
    }

    while (iss >> phone) {
      token_list.push_back(std::move(phone));
    }

    std::vector<int32_t> ids = ConvertTokensToIds(token2id_, token_list);
    if (ids.empty()) {
      // No tokens on the line, or at least one token is unknown.
      continue;
    }
    word2ids_.insert({std::move(word), std::move(ids)});
  }

  // process punctuations
  std::vector<std::string> punctuation_list;
  SplitStringToVector(punctuations, " ", false, &punctuation_list);
  for (auto &s : punctuation_list) {
    punctuations_.insert(std::move(s));
  }
}
|
||
std::vector<int64_t> Lexicon::ConvertTextToTokenIds( | ||
const std::string &_text) const { | ||
std::string text(_text); | ||
ToLowerCase(&text); | ||
|
||
std::vector<std::string> words; | ||
SplitStringToVector(text, " ", false, &words); | ||
|
||
std::vector<int64_t> ans; | ||
for (auto w : words) { | ||
std::vector<int64_t> prefix; | ||
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) { | ||
// if w begins with a punctuation | ||
prefix.push_back(token2id_.at(std::string(1, w[0]))); | ||
w = std::string(w.begin() + 1, w.end()); | ||
} | ||
|
||
std::vector<int64_t> suffix; | ||
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) { | ||
suffix.push_back(token2id_.at(std::string(1, w.back()))); | ||
w = std::string(w.begin(), w.end() - 1); | ||
} | ||
|
||
if (!word2ids_.count(w)) { | ||
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); | ||
continue; | ||
} | ||
|
||
const auto &token_ids = word2ids_.at(w); | ||
ans.insert(ans.end(), prefix.begin(), prefix.end()); | ||
ans.insert(ans.end(), token_ids.begin(), token_ids.end()); | ||
ans.insert(ans.end(), suffix.rbegin(), suffix.rend()); | ||
} | ||
|
||
return ans; | ||
} | ||
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// sherpa-onnx/csrc/lexicon.h | ||
// | ||
// Copyright (c) 2022-2023 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_LEXICON_H_ | ||
#define SHERPA_ONNX_CSRC_LEXICON_H_ | ||
|
||
#include <cstdint> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
#include <vector> | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Converts text into token ids with the help of a word lexicon, a token | ||
// table and a set of punctuation marks. | ||
class Lexicon { | ||
public: | ||
// @param lexicon Path to the lexicon file; each line holds a word | ||
// followed by its tokens. | ||
// @param tokens Path to the token table file (token symbol and id). | ||
// @param punctuations Punctuation marks, separated by a single space. | ||
Lexicon(const std::string &lexicon, const std::string &tokens, | ||
const std::string &punctuations); | ||
|
||
// Lower-cases `text`, splits it on blanks and returns the token ids of | ||
// the words found in the lexicon. Punctuation marks attached to the | ||
// beginning or end of a word are converted to their own token ids. | ||
// Words that are not in the lexicon are logged and skipped. | ||
std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const; | ||
|
||
private: | ||
// lower-cased word -> ids of the tokens it consists of | ||
std::unordered_map<std::string, std::vector<int32_t>> word2ids_; | ||
// punctuation marks given to the constructor | ||
std::unordered_set<std::string> punctuations_; | ||
// token symbol -> token id, as read from the token table | ||
std::unordered_map<std::string, int32_t> token2id_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_LEXICON_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// sherpa-onnx/csrc/offline-tts-impl.cc | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/offline-tts-impl.h" | ||
|
||
#include <memory> | ||
|
||
#include "sherpa-onnx/csrc/offline-tts-vits-impl.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | ||
const OfflineTtsConfig &config) { | ||
// TODO(fangjun): Support other types | ||
return std::make_unique<OfflineTtsVitsImpl>(config); | ||
} | ||
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// sherpa-onnx/csrc/offline-tts-impl.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ | ||
|
||
#include <memory> | ||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/offline-tts.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Abstract interface implemented by the concrete offline TTS back ends. | ||
class OfflineTtsImpl { | ||
public: | ||
// Virtual so that implementations can be destroyed through this interface. | ||
virtual ~OfflineTtsImpl() = default; | ||
|
||
// Creates an implementation for `config`. The caller owns the result. | ||
static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config); | ||
|
||
// Produces a GeneratedAudio for `text`. | ||
// NOTE(review): presumably the synthesized speech; the exact contract is | ||
// defined by GeneratedAudio and the implementations - confirm there. | ||
virtual GeneratedAudio Generate(const std::string &text) const = 0; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// sherpa-onnx/csrc/offline-tts-model-config.cc | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Registers the options of this config with `po`: first the options of the | ||
// nested VITS config, then --num-threads, --debug and --provider, each bound | ||
// to the member of the same name. | ||
void OfflineTtsModelConfig::Register(ParseOptions *po) { | ||
vits.Register(po); | ||
|
||
po->Register("num-threads", &num_threads, | ||
"Number of threads to run the neural network"); | ||
|
||
po->Register("debug", &debug, | ||
"true to print model information while loading it."); | ||
|
||
po->Register("provider", &provider, | ||
"Specify a provider to use: cpu, cuda, coreml"); | ||
} | ||
|
||
bool OfflineTtsModelConfig::Validate() const { | ||
if (num_threads < 1) { | ||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); | ||
return false; | ||
} | ||
|
||
return vits.Validate(); | ||
} | ||
|
||
// Returns a human-readable description of this config of the form
//   OfflineTtsModelConfig(vits=..., num_threads=N, debug=True|False,
//   provider="...")
// written on a single line.
std::string OfflineTtsModelConfig::ToString() const {
  const char *debug_str = debug ? "True" : "False";

  std::ostringstream os;
  os << "OfflineTtsModelConfig("
     << "vits=" << vits.ToString() << ", "
     << "num_threads=" << num_threads << ", "
     << "debug=" << debug_str << ", "
     << "provider=\"" << provider << "\")";

  return os.str();
}
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// sherpa-onnx/csrc/offline-tts-model-config.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ | ||
|
||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h" | ||
#include "sherpa-onnx/csrc/parse-options.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Model-related configuration for offline TTS: the VITS model config plus | ||
// the runtime options shared by all models. | ||
struct OfflineTtsModelConfig { | ||
// Configuration of the VITS model. | ||
OfflineTtsVitsModelConfig vits; | ||
|
||
// Number of threads to run the neural network; Validate() requires >= 1. | ||
int32_t num_threads = 1; | ||
// true to print model information while loading it. | ||
bool debug = false; | ||
// Provider to use, e.g. cpu, cuda, coreml. | ||
std::string provider = "cpu"; | ||
|
||
// Creates a config holding the default values above. | ||
OfflineTtsModelConfig() = default; | ||
|
||
// Creates a config from explicit values; each argument is copied into | ||
// the member of the same name. | ||
OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits, | ||
int32_t num_threads, bool debug, | ||
const std::string &provider) | ||
: vits(vits), | ||
num_threads(num_threads), | ||
debug(debug), | ||
provider(provider) {} | ||
|
||
// Registers the command-line options of this config with `po`. | ||
void Register(ParseOptions *po); | ||
// Returns true if the configuration is usable, false otherwise. | ||
bool Validate() const; | ||
|
||
// Returns a human-readable description of this config. | ||
std::string ToString() const; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ |
Oops, something went wrong.