Skip to content

Commit

Permalink
support Chinese vits models
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 17, 2023
1 parent 9efe697 commit b6da21e
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 49 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
include(asio)
endif()

include(utfcpp)

add_subdirectory(sherpa-onnx)

if(SHERPA_ONNX_ENABLE_C_API)
Expand Down
2 changes: 1 addition & 1 deletion cmake/kaldi-decoder.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function(download_kaldi_decoder)
set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")

set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

# If you don't have access to the Internet,
Expand Down
45 changes: 45 additions & 0 deletions cmake/utfcpp.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
function(download_utfcpp)
include(FetchContent)

set(utfcpp_URL "https://github.com/nemtrif/utfcpp/archive/refs/tags/v3.2.5.tar.gz")
set(utfcpp_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/utfcpp-3.2.5.tar.gz")
set(utfcpp_HASH "SHA256=14fd1b3c466814cb4c40771b7f207b61d2c7a0aa6a5e620ca05c00df27f25afd")

# If you don't have access to the Internet,
# please pre-download utfcpp
set(possible_file_locations
$ENV{HOME}/Downloads/utfcpp-3.2.5.tar.gz
${PROJECT_SOURCE_DIR}/utfcpp-3.2.5.tar.gz
${PROJECT_BINARY_DIR}/utfcpp-3.2.5.tar.gz
/tmp/utfcpp-3.2.5.tar.gz
/star-fj/fangjun/download/github/utfcpp-3.2.5.tar.gz
)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(utfcpp_URL "${f}")
file(TO_CMAKE_PATH "${utfcpp_URL}" utfcpp_URL)
message(STATUS "Found local downloaded utfcpp: ${utfcpp_URL}")
set(utfcpp_URL2)
break()
endif()
endforeach()

FetchContent_Declare(utfcpp
URL
${utfcpp_URL}
${utfcpp_URL2}
URL_HASH ${utfcpp_HASH}
)

FetchContent_GetProperties(utfcpp)
if(NOT utfcpp_POPULATED)
message(STATUS "Downloading utfcpp from ${utfcpp_URL}")
FetchContent_Populate(utfcpp)
endif()
message(STATUS "utfcpp is downloaded to ${utfcpp_SOURCE_DIR}")
# add_subdirectory(${utfcpp_SOURCE_DIR} ${utfcpp_BINARY_DIR} EXCLUDE_FROM_ALL)
include_directories(${utfcpp_SOURCE_DIR})
endfunction()

download_utfcpp()
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
stack-test.cc
transpose-test.cc
unbind-test.cc
utfcpp-test.cc
)

function(sherpa_onnx_add_test source)
Expand Down
145 changes: 100 additions & 45 deletions sherpa-onnx/csrc/lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,105 @@ static std::vector<int32_t> ConvertTokensToIds(
}

Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations) {
const std::string &punctuations, const std::string &language) {
InitLanguage(language);
InitTokens(tokens);
InitLexicon(lexicon);
InitPunctuations(punctuations);
}

std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &text) const {
switch (language_) {
case Language::kEnglish:
return ConvertTextToTokenIdsEnglish(text);
case Language::kChinese:
return ConvertTextToTokenIdsChinese(text);
default:
SHERPA_ONNX_LOGE("Unknonw language: %d", static_cast<int32_t>(language_));
exit(-1);
}

return {};
}

std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
const std::string &text) const {
std::vector<std::string> words = SplitUtf8(text);

std::vector<int64_t> ans;

ans.push_back(token2id_.at("sil"));

for (const auto &w : words) {
if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}

const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
}
ans.push_back(token2id_.at("sil"));
ans.push_back(token2id_.at("eos"));
return ans;
}

std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);

std::vector<std::string> words = SplitUtf8(text);

std::vector<int64_t> ans;
for (const auto &w : words) {
if (punctuations_.count(w)) {
ans.push_back(token2id_.at(w));
continue;
}

if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}

const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
if (blank_ != -1) {
ans.push_back(blank_);
}
}

if (blank_ != -1 && !ans.empty()) {
// remove the last blank
ans.resize(ans.size() - 1);
}

return ans;
}

void Lexicon::InitTokens(const std::string &tokens) {
token2id_ = ReadTokens(tokens);
blank_ = token2id_.at(" ");
if (token2id_.count(" ")) {
blank_ = token2id_.at(" ");
}
}

void Lexicon::InitLanguage(const std::string &_lang) {
std::string lang(_lang);
ToLowerCase(&lang);
if (lang == "english") {
language_ = Language::kEnglish;
} else if (lang == "chinese") {
language_ = Language::kChinese;
} else {
SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
exit(-1);
}
}

void Lexicon::InitLexicon(const std::string &lexicon) {
std::ifstream is(lexicon);

std::string word;
Expand Down Expand Up @@ -109,55 +205,14 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
}
word2ids_.insert({std::move(word), std::move(ids)});
}
}

// process punctuations
void Lexicon::InitPunctuations(const std::string &punctuations) {
std::vector<std::string> punctuation_list;
SplitStringToVector(punctuations, " ", false, &punctuation_list);
for (auto &s : punctuation_list) {
punctuations_.insert(std::move(s));
}
}

std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);

std::vector<std::string> words;
SplitStringToVector(text, " ", false, &words);

std::vector<int64_t> ans;
for (auto w : words) {
std::vector<int64_t> prefix;
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
// if w begins with a punctuation
prefix.push_back(token2id_.at(std::string(1, w[0])));
w = std::string(w.begin() + 1, w.end());
}

std::vector<int64_t> suffix;
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
suffix.push_back(token2id_.at(std::string(1, w.back())));
w = std::string(w.begin(), w.end() - 1);
}

if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}

const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), prefix.begin(), prefix.end());
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
ans.push_back(blank_);
}

if (!ans.empty()) {
ans.resize(ans.size() - 1);
}

return ans;
}

} // namespace sherpa_onnx
26 changes: 24 additions & 2 deletions sherpa-onnx/csrc/lexicon.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,40 @@

namespace sherpa_onnx {

// TODO(fangjun): Refactor it to an abstract class
class Lexicon {
public:
Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations);
const std::string &punctuations, const std::string &language);

std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;

private:
std::vector<int64_t> ConvertTextToTokenIdsEnglish(
const std::string &text) const;

std::vector<int64_t> ConvertTextToTokenIdsChinese(
const std::string &text) const;

void InitLanguage(const std::string &lang);
void InitTokens(const std::string &tokens);
void InitLexicon(const std::string &lexicon);
void InitPunctuations(const std::string &punctuations);

private:
enum class Language {
kEnglish,
kChinese,
kUnknown,
};

private:
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
std::unordered_set<std::string> punctuations_;
std::unordered_map<std::string, int32_t> token2id_;
int32_t blank_; // ID for the blank token
int32_t blank_ = -1; // ID for the blank token
Language language_;
//
};

} // namespace sherpa_onnx
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {}
model_->Punctuations(), model_->Language()) {}

GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const override {
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/offline-tts-vits-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class OfflineTtsVitsModel::Impl {
bool AddBlank() const { return add_blank_; }

std::string Punctuations() const { return punctuations_; }
std::string Language() const { return language_; }

private:
void Init(void *model_data, size_t model_data_length) {
Expand All @@ -108,6 +109,7 @@ class OfflineTtsVitsModel::Impl {
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
}

private:
Expand All @@ -128,6 +130,7 @@ class OfflineTtsVitsModel::Impl {
int32_t add_blank_;
int32_t n_speakers_;
std::string punctuations_;
std::string language_;
};

OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
Expand All @@ -147,4 +150,6 @@ std::string OfflineTtsVitsModel::Punctuations() const {
return impl_->Punctuations();
}

std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }

} // namespace sherpa_onnx
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-tts-vits-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class OfflineTtsVitsModel {
bool AddBlank() const;

std::string Punctuations() const;
std::string Language() const;

private:
class Impl;
Expand Down
Loading

0 comments on commit b6da21e

Please sign in to comment.