Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jni interface and kotlin API examples for TTS. #381

Merged
merged 3 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions kotlin-api-examples/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
hs_err*
main.jar
vits-zh-aishell3
22 changes: 22 additions & 0 deletions kotlin-api-examples/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager

// Entry point of the examples: runs the TTS example first, then the ASR example.
fun main() {
testTts()
testAsr()
}

// TTS example: loads the VITS model files from ./vits-zh-aishell3, synthesizes
// one sentence with speaker id 99 at speed 1.2 and saves the audio to 99.wav.
fun testTts() {
    val vitsConfig = OfflineTtsVitsModelConfig(
        model = "./vits-zh-aishell3/vits-aishell3.onnx",
        lexicon = "./vits-zh-aishell3/lexicon.txt",
        tokens = "./vits-zh-aishell3/tokens.txt",
    )
    val modelConfig = OfflineTtsModelConfig(
        vits = vitsConfig,
        numThreads = 1,
        debug = true,
    )

    // No AssetManager is passed, so the engine is created from plain files.
    val synthesizer = OfflineTts(config = OfflineTtsConfig(model = modelConfig))

    val generated = synthesizer.generate(text = "林美丽最美丽!", sid = 99, speed = 1.2f)
    generated.save(filename = "99.wav")
}

fun testAsr() {
var featConfig = FeatureConfig(
sampleRate = 16000,
featureDim = 80,
Expand Down
112 changes: 112 additions & 0 deletions kotlin-api-examples/Tts.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

// Settings for a VITS text-to-speech model.
//  - model, lexicon, tokens: locations of the ONNX model, the lexicon file and
//    the tokens file (the example in Main.kt passes file paths).
//  - noiseScale, noiseScaleW, lengthScale: numeric model parameters with the
//    defaults below. NOTE(review): their exact effect is defined by the native
//    side, which is not visible here.
data class OfflineTtsVitsModelConfig(
var model: String,
var lexicon: String,
var tokens: String,
var noiseScale: Float = 0.667f,
var noiseScaleW: Float = 0.8f,
var lengthScale: Float = 1.0f,
)

// Model-level TTS settings: the VITS model configuration plus runtime options.
//  - numThreads: defaults to 1. NOTE(review): presumably the number of threads
//    used for inference; confirm against the native side.
//  - debug: defaults to false.
//  - provider: defaults to "cpu". NOTE(review): presumably the execution
//    provider name; confirm against the native side.
data class OfflineTtsModelConfig(
var vits: OfflineTtsVitsModelConfig,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)

// Top-level configuration handed to OfflineTts. It currently only wraps the
// model configuration.
data class OfflineTtsConfig(
var model: OfflineTtsModelConfig,
)

// Audio returned by OfflineTts.generate(): the samples and their sample rate.
class GeneratedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    // Hands the samples and sample rate to the native saveImpl() together with
    // `filename` and returns the Boolean that the native call reports.
    fun save(filename: String): Boolean {
        return saveImpl(filename, samples, sampleRate)
    }

    // Implemented in native code.
    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int,
    ): Boolean
}

// Kotlin wrapper around the native offline text-to-speech engine.
//
// The native object is created in the constructor and referenced through `ptr`.
// When `assetManager` is non-null the engine is created via new(assetManager,
// config); otherwise it is created via newFromFile(config).
//
// Call free() to release the native object explicitly; allocate() creates it
// again from the current `config`. finalize() releases it as a last resort.
class OfflineTts(
    assetManager: AssetManager? = null,
    var config: OfflineTtsConfig,
) {
    // Handle of the native object; 0 means "no native object is allocated".
    private var ptr: Long

    init {
        ptr = newNativeHandle(assetManager)
    }

    // Synthesizes `text` and returns the audio together with its sample rate.
    //
    // `sid` and `speed` are forwarded unchanged to the native generateImpl().
    // Throws IllegalStateException if the native object has been released with
    // free() and not re-created with allocate(); previously a 0 handle was
    // passed straight to native code.
    fun generate(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio {
        check(ptr != 0L) { "OfflineTts has been freed; call allocate() before generate()" }

        val objArray = generateImpl(ptr, text = text, sid = sid, speed = speed)
        return GeneratedAudio(
            samples = objArray[0] as FloatArray,
            sampleRate = objArray[1] as Int,
        )
    }

    // Re-creates the native object from the current `config` if it has been
    // released. Does nothing when a native object is already allocated.
    fun allocate(assetManager: AssetManager? = null) {
        if (ptr == 0L) {
            ptr = newNativeHandle(assetManager)
        }
    }

    // Releases the native object. Safe to call more than once.
    fun free() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0L
        }
    }

    // Goes through free() so that the native delete is skipped when the object
    // has already been released explicitly.
    protected fun finalize() {
        free()
    }

    // Creates the native object, from assets when an AssetManager is given and
    // from regular files otherwise, and returns its handle.
    private fun newNativeHandle(assetManager: AssetManager?): Long =
        if (assetManager != null) {
            new(assetManager, config)
        } else {
            newFromFile(config)
        }

    private external fun new(
        assetManager: AssetManager,
        config: OfflineTtsConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineTtsConfig,
    ): Long

    private external fun delete(ptr: Long)

    // The returned array has two entries:
    //  - the first entry is a 1-D float array containing audio samples.
    //    Each sample is normalized to the range [-1, 1]
    //  - the second entry is the sample rate
    external fun generateImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): Array<Any>

    companion object {
        init {
            // Load the JNI library once, when this class is first used.
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
27 changes: 15 additions & 12 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,24 @@

set -e


cd ..
mkdir -p build
cd build

cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..

make -j4
ls -lh lib
# Configure and build the JNI shared library only if it has not been built yet.
# NOTE(review): the check looks for the .dylib file name only, so on platforms
# whose shared libraries use another suffix the build presumably runs every time.
if [ ! -f ../build/lib/libsherpa-onnx-jni.dylib ]; then
cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..

make -j4
ls -lh lib
fi

export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH

Expand All @@ -31,7 +34,7 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
fi

kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt
kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt

ls -lh main.jar

Expand Down
56 changes: 45 additions & 11 deletions sherpa-onnx/csrc/lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
#include <sstream>
#include <utility>

#if __ANDROID_API__ >= 9
#include <strstream>

#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {
Expand All @@ -22,11 +30,9 @@ static void ToLowerCase(std::string *in_out) {

// Note: We don't use SymbolTable here since tokens may contain a blank
// in the first column
static std::unordered_map<std::string, int32_t> ReadTokens(
const std::string &tokens) {
static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
std::unordered_map<std::string, int32_t> token2id;

std::ifstream is(tokens);
std::string line;

std::string sym;
Expand Down Expand Up @@ -80,11 +86,43 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
bool debug /*= false*/)
: debug_(debug) {
InitLanguage(language);
InitTokens(tokens);
InitLexicon(lexicon);

{
std::ifstream is(tokens);
InitTokens(is);
}

{
std::ifstream is(lexicon);
InitLexicon(is);
}

InitPunctuations(punctuations);
}

#if __ANDROID_API__ >= 9
// Android-only constructor. It performs the same steps as the file-based
// constructor (language, tokens, lexicon, punctuations), but obtains the
// tokens and lexicon contents through ReadFile(mgr, ...) instead of opening
// them with std::ifstream.
Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
                 const std::string &tokens, const std::string &punctuations,
                 const std::string &language, bool debug /*= false*/)
    : debug_(debug) {
  InitLanguage(language);

  {
    // The buffer lives only inside this scope, so it is released before the
    // lexicon is loaded.
    auto token_data = ReadFile(mgr, tokens);
    std::istrstream token_stream(token_data.data(), token_data.size());
    InitTokens(token_stream);
  }

  {
    auto lexicon_data = ReadFile(mgr, lexicon);
    std::istrstream lexicon_stream(lexicon_data.data(), lexicon_data.size());
    InitLexicon(lexicon_stream);
  }

  InitPunctuations(punctuations);
}
#endif

std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &text) const {
switch (language_) {
Expand Down Expand Up @@ -192,9 +230,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
return ans;
}

void Lexicon::InitTokens(const std::string &tokens) {
token2id_ = ReadTokens(tokens);
}
// Fills token2id_ with the mapping that ReadTokens() parses from `is`.
void Lexicon::InitTokens(std::istream &is) {
  token2id_ = ReadTokens(is);
}

void Lexicon::InitLanguage(const std::string &_lang) {
std::string lang(_lang);
Expand All @@ -209,9 +245,7 @@ void Lexicon::InitLanguage(const std::string &_lang) {
}
}

void Lexicon::InitLexicon(const std::string &lexicon) {
std::ifstream is(lexicon);

void Lexicon::InitLexicon(std::istream &is) {
std::string word;
std::vector<std::string> token_list;
std::string line;
Expand Down
16 changes: 14 additions & 2 deletions sherpa-onnx/csrc/lexicon.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
#define SHERPA_ONNX_CSRC_LEXICON_H_

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

namespace sherpa_onnx {

// TODO(fangjun): Refactor it to an abstract class
Expand All @@ -20,6 +26,12 @@ class Lexicon {
const std::string &punctuations, const std::string &language,
bool debug = false);

#if __ANDROID_API__ >= 9
Lexicon(AAssetManager *mgr, const std::string &lexicon,
const std::string &tokens, const std::string &punctuations,
const std::string &language, bool debug = false);
#endif

std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;

private:
Expand All @@ -30,8 +42,8 @@ class Lexicon {
const std::string &text) const;

void InitLanguage(const std::string &lang);
void InitTokens(const std::string &tokens);
void InitLexicon(const std::string &lexicon);
void InitTokens(std::istream &is);
void InitLexicon(std::istream &is);
void InitPunctuations(const std::string &punctuations);

private:
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/offline-tts-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,12 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
return std::make_unique<OfflineTtsVitsImpl>(config);
}

#if __ANDROID_API__ >= 9
// Android variant of Create(): forwards the asset manager and the config to
// the implementation's constructor. For now it always returns the VITS
// implementation.
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
AAssetManager *mgr, const OfflineTtsConfig &config) {
// TODO(fangjun): Support other types
return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
}
#endif

} // namespace sherpa_onnx
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-tts-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include <memory>
#include <string>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/offline-tts.h"

namespace sherpa_onnx {
Expand All @@ -18,6 +23,11 @@ class OfflineTtsImpl {

static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);

#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineTtsImpl> Create(AAssetManager *mgr,
const OfflineTtsConfig &config);
#endif

virtual GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0) const = 0;
};
Expand Down
13 changes: 13 additions & 0 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
Expand All @@ -24,6 +29,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
model_->Punctuations(), model_->Language(),
config.model.debug) {}

#if __ANDROID_API__ >= 9
// Android constructor: creates the VITS model through `mgr`, then builds the
// lexicon through `mgr` using the punctuations and language reported by the
// model.
// NOTE(review): lexicon_'s initializer dereferences model_, so model_ must be
// declared before lexicon_ in the class for this to be valid; confirm against
// the member declarations.
OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations(), model_->Language(),
config.model.debug) {}
#endif

GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0) const override {
int32_t num_speakers = model_->NumSpeakers();
Expand Down
Loading