Skip to content

Commit

Permalink
Feat: Add android example of MNIST inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Scramjet911 committed Sep 3, 2024
1 parent 96a2340 commit 3a235e6
Show file tree
Hide file tree
Showing 51 changed files with 1,410 additions and 2 deletions.
126 changes: 126 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ members = [
"crates/burn-import/onnx-tests",
"examples/*",
"examples/pytorch-import/model",
"examples/mnist-inference-android/app/src/main/rust",
"xtask",
]

exclude = [
"examples/notebook",
"examples/raspberry-pi-pico", # will cause dependency building issues otherwise
"examples/mnist-inference-android",
"examples/raspberry-pi-pico", # will cause dependency building issues otherwise
# "crates/burn-cuda", # comment this line to work on burn-cuda
]

Expand Down
3 changes: 2 additions & 1 deletion _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ extend-exclude = [
"*.onnx",
"assets/ModuleSerialization.xml",
"examples/image-classification-web/src/model/label.txt",
"examples/mnist-inference-android/gradle/*",
]

[default.extend-words]
# Don't correct "arange" which is intentional
arange = "arange"
arange = "arange"
15 changes: 15 additions & 0 deletions examples/mnist-inference-android/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
*.iml
.gradle
/local.properties
/.idea/caches
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
/.idea/navEditor.xml
/.idea/assetWizardSettings.xml
.DS_Store
/build
/captures
.externalNativeBuild
.cxx
local.properties
105 changes: 105 additions & 0 deletions examples/mnist-inference-android/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# MNIST number detector Android App

This project is a sample Android application that demonstrates how to integrate a `Burn` into an android app using the JNI (Java Native Interface).

## Table of Contents

- [Workflow](#workflow)
- [Prerequisites](#prerequisites)
- [Setup](#setup)
- [How To make your own](#how-to-make-your-own)
- [License](#license)

## Workflow
1. **Image Input:** The user provides an image input through the app's interface.
2. **Image Processing:** The image is converted to a grayscale `byteArray` in Kotlin.
3. **JNI Bridge:** The grayscale `byteArray` is passed to a Rust function via JNI.
4. **Rust Processing:** The Rust function calls the `forward` method from the `burn` library, using a pretrained MNIST ONNX model to perform inference.
5. **Result Handling:** The result, an integer representing the predicted digit, is logged to the android console and returned from Rust to Kotlin.
6. **Output Display:** The predicted digit is displayed on the screen

## Prerequisites
- Android Studio (latest version recommended)
- Rust (installed and configured)
- Android NDK (Native Development Kit)

## Setup
1. **Install Rust dependencies:**

Ensure Rust is installed and the `cargo` command is available:

```bash
rustup update
```
And that you have installed all the rustup toolchains required:
```bash
rustup target add \
aarch64-linux-android \
armv7-linux-androideabi \
i686-linux-android \
x86_64-linux-android
```

2. **Configure the Android NDK:**

Ensure that the Android NDK is installed. You can install it via Android Studio's SDK Manager.
3. **Build the android app:**
Running the android app should automatically build the rust libraries due to the gradle tasks configured at the app level. (More on that later)
## How To make your own
1. There are a few ways to compile a rust library for android -
- Add targets in `.cargo/config.toml` and build with them. Then we can add the `.so` files generated to the jni directory in `app/src/main/jniLibs`
- Add gradle plugins (like [rust-android-gradle](https://github.com/mozilla/rust-android-gradle) or [cargo-ndk-android](https://github.com/willir/cargo-ndk-android-gradle) using `rust-android-gradle` in this project) to do the work for you, so that the rust library is built on each app build. (Might want to change for expensive library builds)
2. To interface with Kotlin(Java) you can either use an interface generator (like [flapigen-rs](https://github.com/Dushistov/flapigen-rs)) or make them by yourself. This sample function doesn't use flapigen.
3. Now the function to be called from android (`infer()` here) needs to follow the [JNI naming conventions](https://docs.oracle.com/javase/1.5.0/docs/guide/jni/spec/design.html) (The correct name is also shown in the call error if it doesn't exist).
4. **Important** The first 2 arguments of the jni interfacing function will be the `env` variable (for interface functions) and the `this` object. The data you pass will start from the 3rd argument.
5. Next for converting the data from java to rust data types, there are multiple functions in the env variable passed to the function. Use as required...
6. Then in the app's `build.gradle` we add the part to run the cargo build before building the app and the also the cargo build details:
```kotlin
// Cargo build details
cargo {
module = "./src/main/rust" // Or whatever directory contains your Cargo.toml
libname = "mnist_inference_android" // Or whatever matches Cargo.toml's [package] name.
targets = listOf(
"arm", "arm64",
"x86",
"x86_64"
)
prebuiltToolchains = true
}
// Used to build cargo before the android build task is run
// See more options here: https://github.com/mozilla/rust-android-gradle/issues/133
project.afterEvaluate {
tasks.withType(com.nishtahir.CargoBuildTask::class)
.forEach { buildTask ->
tasks.withType(com.android.build.gradle.tasks.MergeSourceSetFolders::class)
.configureEach {
this.inputs.dir(
layout.buildDirectory.dir("rustJniLibs" + File.separatorChar + buildTask.toolchain!!.folder)
)
this.dependsOn(buildTask)
}
}
}
```
(In the example we have also added the target directory in `config.toml` since otherwise it will build into the workspace target, which we do not want)
7. Here the library's name is `mnist-android` so we will initialize it in our app:
```kotlin
class MainActivity : ComponentActivity() {
init {
System.loadLibrary("mnist_android") // Note: '-' is changed to '_'
}
...
}
```
8. Finally use it by declaring it as an external function first
```kotlin
external fun infer(inputImage: ByteArray): Int;
...
infer(byteArray)
```
1 change: 1 addition & 0 deletions examples/mnist-inference-android/app/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/build
Loading

0 comments on commit 3a235e6

Please sign in to comment.