From 3422b9388dfec1824d0970ef0d753e963ab0b76e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 31 Dec 2024 19:20:52 +0800 Subject: [PATCH] Add Kotlin API for Matcha-TTS models. (#1668) --- .github/workflows/jni.yaml | 5 ++++ .gitignore | 1 + kotlin-api-examples/run.sh | 10 +++++++ kotlin-api-examples/test_tts.kt | 29 ++++++++++++++++-- sherpa-onnx/c-api/c-api.cc | 12 +++++--- sherpa-onnx/c-api/c-api.h | 6 ++-- sherpa-onnx/csrc/jieba-lexicon.cc | 2 +- sherpa-onnx/jni/offline-tts.cc | 49 +++++++++++++++++++++++++++++++ sherpa-onnx/kotlin-api/Tts.kt | 12 ++++++++ 9 files changed, 117 insertions(+), 9 deletions(-) diff --git a/.github/workflows/jni.yaml b/.github/workflows/jni.yaml index a0f769393..0dc775f4c 100644 --- a/.github/workflows/jni.yaml +++ b/.github/workflows/jni.yaml @@ -75,3 +75,8 @@ jobs: cd ./kotlin-api-examples ./run.sh + + - uses: actions/upload-artifact@v4 + with: + name: tts-files-${{ matrix.os }} + path: kotlin-api-examples/test-*.wav diff --git a/.gitignore b/.gitignore index cfb6fa57c..eeec52d9c 100644 --- a/.gitignore +++ b/.gitignore @@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8 sherpa-onnx-moonshine-base-en-int8 harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md +matcha-icefall-zh-baker diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index 3b3d15938..63ea224d1 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -105,6 +105,16 @@ function testTts() { rm vits-piper-en_US-amy-low.tar.bz2 fi + if [ ! -f ./matcha-icefall-zh-baker/model-steps-3.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 + tar xvf matcha-icefall-zh-baker.tar.bz2 + rm matcha-icefall-zh-baker.tar.bz2 + fi + + if [ ! -f ./hifigan_v2.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + fi + out_filename=test_tts.jar kotlinc-jvm -include-runtime -d $out_filename \ test_tts.kt \ diff --git a/kotlin-api-examples/test_tts.kt b/kotlin-api-examples/test_tts.kt index 8602a4461..3865c33e3 100644 --- a/kotlin-api-examples/test_tts.kt +++ b/kotlin-api-examples/test_tts.kt @@ -1,10 +1,35 @@ package com.k2fsa.sherpa.onnx fun main() { - testTts() + testVits() + testMatcha() } -fun testTts() { +fun testMatcha() { + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models + // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 + var config = OfflineTtsConfig( + model=OfflineTtsModelConfig( + matcha=OfflineTtsMatchaModelConfig( + acousticModel="./matcha-icefall-zh-baker/model-steps-3.onnx", + vocoder="./hifigan_v2.onnx", + tokens="./matcha-icefall-zh-baker/tokens.txt", + lexicon="./matcha-icefall-zh-baker/lexicon.txt", + dictDir="./matcha-icefall-zh-baker/dict", + ), + numThreads=1, + debug=true, + ), + ruleFsts="./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst", + ) + val tts = OfflineTts(config=config) + val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback) + audio.save(filename="test-zh.wav") + tts.release() + println("Saved to test-zh.wav") +} + +fun testVits() { // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 var config = OfflineTtsConfig( diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 4d1bb6625..703380730 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation( auto p = new SherpaOnnxOnlinePunctuation; try { sherpa_onnx::OnlinePunctuationConfig punctuation_config; - punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, ""); - punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, ""); - punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); + punctuation_config.model.cnn_bilstm = + SHERPA_ONNX_OR(config->model.cnn_bilstm, ""); + punctuation_config.model.bpe_vocab = + SHERPA_ONNX_OR(config->model.bpe_vocab, ""); + punctuation_config.model.num_threads = + SHERPA_ONNX_OR(config->model.num_threads, 1); punctuation_config.model.debug = config->model.debug; - punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); + punctuation_config.model.provider = + SHERPA_ONNX_OR(config->model.provider, "cpu"); p->impl = std::make_unique(punctuation_config); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 990c94cb3..167051cd9 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig { SherpaOnnxOnlinePunctuationModelConfig model; } SherpaOnnxOnlinePunctuationConfig; -SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation; +SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation + SherpaOnnxOnlinePunctuation; // Create an online punctuation processor. The user has to invoke // SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer // to avoid memory leak -SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation( +SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation * +SherpaOnnxCreateOnlinePunctuation( const SherpaOnnxOnlinePunctuationConfig *config); // Free a pointer returned by SherpaOnnxCreateOnlinePunctuation() diff --git a/sherpa-onnx/csrc/jieba-lexicon.cc b/sherpa-onnx/csrc/jieba-lexicon.cc index 11bd1f20f..57b77666b 100644 --- a/sherpa-onnx/csrc/jieba-lexicon.cc +++ b/sherpa-onnx/csrc/jieba-lexicon.cc @@ -155,7 +155,7 @@ class JiebaLexicon::Impl { this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); - if (w == "。" || w == "!" || w == "?" || w == ",") { + if (IsPunct(w)) { ans.emplace_back(std::move(this_sentence)); this_sentence = {}; } diff --git a/sherpa-onnx/jni/offline-tts.cc b/sherpa-onnx/jni/offline-tts.cc index 4d67afc27..985d581ef 100644 --- a/sherpa-onnx/jni/offline-tts.cc +++ b/sherpa-onnx/jni/offline-tts.cc @@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { jobject model = env->GetObjectField(config, fid); jclass model_config_cls = env->GetObjectClass(model); + // vits fid = env->GetFieldID(model_config_cls, "vits", "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); jobject vits = env->GetObjectField(model, fid); @@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(vits_cls, "lengthScale", "F"); ans.model.vits.length_scale = env->GetFloatField(vits, fid); + // matcha + fid = env->GetFieldID(model_config_cls, "matcha", + "Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;"); + jobject matcha = env->GetObjectField(model, fid); + jclass matcha_cls = env->GetObjectClass(matcha); + + fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.acoustic_model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.vocoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.lexicon = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.data_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.dict_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "noiseScale", "F"); + ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid); + + fid = env->GetFieldID(matcha_cls, "lengthScale", "F"); + ans.model.matcha.length_scale = env->GetFloatField(matcha, fid); + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); ans.model.num_threads = env->GetIntField(model, fid); diff --git a/sherpa-onnx/kotlin-api/Tts.kt b/sherpa-onnx/kotlin-api/Tts.kt index 6152cd914..231b87d81 100644 --- a/sherpa-onnx/kotlin-api/Tts.kt +++ b/sherpa-onnx/kotlin-api/Tts.kt @@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig( var lengthScale: Float = 1.0f, ) +data class OfflineTtsMatchaModelConfig( + var acousticModel: String = "", + var vocoder: String = "", + var lexicon: String = "", + var tokens: String = "", + var dataDir: String = "", + var dictDir: String = "", + var noiseScale: Float = 1.0f, + var lengthScale: Float = 1.0f, +) + data class OfflineTtsModelConfig( var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(), + var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(), var numThreads: Int = 1, var debug: Boolean = false, var provider: String = "cpu",