diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc
index 88f95f1f6..26b215b43 100644
--- a/sherpa-onnx/csrc/lexicon.cc
+++ b/sherpa-onnx/csrc/lexicon.cc
@@ -131,6 +131,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
       return ConvertTextToTokenIdsEnglish(text);
     case Language::kGerman:
       return ConvertTextToTokenIdsGerman(text);
+    case Language::kSpanish:
+      return ConvertTextToTokenIdsSpanish(text);
     case Language::kChinese:
       return ConvertTextToTokenIdsChinese(text);
     default:
@@ -250,6 +252,8 @@ void Lexicon::InitLanguage(const std::string &_lang) {
     language_ = Language::kEnglish;
   } else if (lang == "german") {
     language_ = Language::kGerman;
+  } else if (lang == "spanish") {
+    language_ = Language::kSpanish;
   } else if (lang == "chinese") {
     language_ = Language::kChinese;
   } else {
diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h
index fcf791422..ecbf26513 100644
--- a/sherpa-onnx/csrc/lexicon.h
+++ b/sherpa-onnx/csrc/lexicon.h
@@ -41,6 +41,11 @@ class Lexicon {
     return ConvertTextToTokenIdsEnglish(text);
   }
 
+  std::vector<int64_t> ConvertTextToTokenIdsSpanish(
+      const std::string &text) const {
+    return ConvertTextToTokenIdsEnglish(text);
+  }
+
   std::vector<int64_t> ConvertTextToTokenIdsEnglish(
       const std::string &text) const;
 
@@ -56,6 +61,7 @@ class Lexicon {
   enum class Language {
     kEnglish,
     kGerman,
+    kSpanish,
     kChinese,
     kUnknown,
   };
diff --git a/sherpa-onnx/csrc/text-utils.cc b/sherpa-onnx/csrc/text-utils.cc
index 06cf2eef0..e3196b6ab 100644
--- a/sherpa-onnx/csrc/text-utils.cc
+++ b/sherpa-onnx/csrc/text-utils.cc
@@ -164,7 +164,7 @@ template bool SplitStringToFloats(const std::string &full, const char *delim,
                                   std::vector<double> *out);
 
 static bool IsPunct(char c) { return c != '\'' && std::ispunct(c); }
-static bool IsGermanUmlauts(const std::string &words) {
+static bool IsGermanUmlauts(const std::string &word) {
   // ä 0xC3 0xA4
   // ö 0xC3 0xB6
   // ü 0xC3 0xBC
@@ -173,12 +173,12 @@ static bool IsGermanUmlauts(const std::string &words) {
   // Ü 0xC3 0x9C
   // ß 0xC3 0x9F
 
-  if (words.size() != 2 || static_cast<uint8_t>(words[0]) != 0xc3) {
+  if (word.size() != 2 || static_cast<uint8_t>(word[0]) != 0xc3) {
     return false;
   }
 
-  auto c = static_cast<uint8_t>(words[1]);
-  if (c == 0xa4 || c == 0xb6 || c == 0xbC || c == 0x84 || c == 0x96 ||
+  auto c = static_cast<uint8_t>(word[1]);
+  if (c == 0xa4 || c == 0xb6 || c == 0xbc || c == 0x84 || c == 0x96 ||
       c == 0x9c || c == 0x9f) {
     return true;
   }
@@ -186,6 +186,49 @@ static bool IsGermanUmlauts(const std::string &words) {
   return false;
 }
 
+// Returns true if `word` is exactly one UTF-8 encoded Spanish letter that
+// carries a diacritic (lowercase or uppercase).
+//
+// see https://www.tandem.net/blog/spanish-accents
+static bool IsSpanishDiacritic(const std::string &word) {
+  // lowercase
+  // á 0xC3 0xA1
+  // é 0xC3 0xA9
+  // í 0xC3 0xAD
+  // ó 0xC3 0xB3
+  // ú 0xC3 0xBA
+  // ü 0xC3 0xBC
+  // ñ 0xC3 0xB1
+  //
+  // uppercase
+  // Á 0xC3 0x81
+  // É 0xC3 0x89
+  // Í 0xC3 0x8D
+  // Ó 0xC3 0x93
+  // Ú 0xC3 0x9A
+  // Ü 0xC3 0x9C
+  // Ñ 0xC3 0x91
+
+  if (word.size() != 2 || static_cast<uint8_t>(word[0]) != 0xc3) {
+    return false;
+  }
+
+  auto c = static_cast<uint8_t>(word[1]);
+  if (c == 0xa1 || c == 0xa9 || c == 0xad || c == 0xb3 || c == 0xba ||
+      c == 0xbc || c == 0xb1 || c == 0x81 || c == 0x89 || c == 0x8d ||
+      c == 0x93 || c == 0x9a || c == 0x9c || c == 0x91) {
+    return true;
+  }
+
+  return false;
+}
+
+// Returns true if `w` is a German umlaut/eszett or a Spanish diacritic letter;
+// MergeCharactersIntoWords handles such two-byte tokens like one-byte ones.
+static bool IsSpecial(const std::string &w) {
+  return IsGermanUmlauts(w) || IsSpanishDiacritic(w);
+}
+
 static std::vector<std::string> MergeCharactersIntoWords(
     const std::vector<std::string> &words) {
   std::vector<std::string> ans;
@@ -196,7 +239,7 @@ static std::vector<std::string> MergeCharactersIntoWords(
 
   while (i < n) {
     const auto &w = words[i];
-    if (w.size() >= 3 || (w.size() == 2 && !IsGermanUmlauts(w)) ||
+    if (w.size() >= 3 || (w.size() == 2 && !IsSpecial(w)) ||
         (w.size() == 1 && (IsPunct(w[0]) || std::isspace(w[0])))) {
       if (prev != -1) {
         std::string t;
@@ -215,7 +258,7 @@ static std::vector<std::string> MergeCharactersIntoWords(
     }
 
     // e.g., öffnen
-    if (w.size() == 1 || (w.size() == 2 && IsGermanUmlauts(w))) {
+    if (w.size() == 1 || (w.size() == 2 && IsSpecial(w))) {
       if (prev == -1) {
         prev = i;
       }