Write-up for MLP inference on MNIST using HEIR CKKS pipeline #1232
Comments
This is really a fantastic start! Thank you for being so thorough. I want to get this work checked in, so I'm trying to figure out how to tease apart the overall thing into different work items we can make progress on. At a high level, our main goal should be to ensure the file sizes are small, runtimes are fast, and that things can be compressed appropriately. So the goal of putting the weights in a separate file should mainly be to enable ease of compression. For MLIR, I think we could get most of those benefits by having the values be globals and storing the files as MLIR bytecode (not plaintext), and we could easily check in the bytecode files as test inputs. Adding bytecode is a slightly more complicated task, in that it requires versioning our IR, so I'd like to avoid that as long as possible, and maybe that means having large-ish plaintext checked-in files for a while. So long as we avoid the compile-time and stack-usage issues above, we can find a good solution for having large test files.
I fully support this, I'm just trying to think through how we could make this easier to do than the current method, which involves relatively complex …
Another subtask: we should support emitting hex values in the emitter. I would bet using …
Thinking about this some more, could we get the input model saved in an independent format, say ONNX or StableHLO, and use that as a starting point?
MLP is one of the simplest forms of neural network, and supporting it might be a starting point for other, more complex networks (e.g. the widely used ResNet benchmark in the HE literature).
Indeed, such support could be imported from a frontend like TOSA/StableHLO, but as #1079 suggests, HEIR currently does not support some linalg ops, so we have to write the network manually. Also, there are still some limitations in the HEIR CKKS pipeline, so manual intervention is a must.
The essential code is in ZenithalHourlyRate@617e0d1, including the MLIR MLP implementation and the corresponding LLVM/OpenFHE drivers.
The weights/test data needed are at https://cloud.tsinghua.edu.cn/f/b2da50f8bdbc4aa1859f/
Network Design
MNIST is a dataset of handwritten digits (28x28 images), and an MLP can be used to classify them into 10 labels; the accuracy is often good enough, around 95%.
The typical MLP design involves the following parts: a stack of fully connected (matmul + bias) layers, with a ReLU activation between them.
In HEIR, we have implemented Halevi-Shoup matrix multiplication, but only square matrices are supported for now; also, ReLU can hardly be expressed in HE primitives, so a polynomial approximation is needed (see #658, #665 and #1217).
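For a purely illustrative example of what such an approximation looks like (not the coefficients used in the linked issues), a degree-2 least-squares fit of ReLU on $[-1, 1]$ is

$$\mathrm{ReLU}(x) \approx \tfrac{3}{32} + \tfrac{1}{2}x + \tfrac{15}{32}x^2 \approx 0.094 + 0.5x + 0.469x^2,$$

which costs only a single ciphertext-ciphertext multiplication to evaluate; the issues above track the approximations actually used in HEIR.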
So the specialized network design becomes square, padded matrix multiplications with a polynomial approximation in place of ReLU. Padding is added accordingly.
Training
Training of this specialized network was done by @yushijinhun using PyTorch, achieving an accuracy of 96%.
Inference
There are two versions of the inference MLIR implementation in the code above: `mlp.mlir` shows that the cleartext computation itself is correct (lowered to LLVM and run from a C driver), while `mlp_inline.mlir` is the one accepted by the HEIR pipeline to produce OpenFHE code.
Cleartext computation (for verifying we are correct)
`mlp.mlir` contains the MLP function itself.
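The code itself is in the linked commit; as a rough sketch of its shape (hypothetical layer sizes, splat constants standing in for the real weights), one fully connected layer plus a polynomial activation might look like:

```mlir
// Sketch only: shapes and names are hypothetical; the real mlp.mlir is in
// the linked commit and carries full dense weight literals.
func.func @mlp(%input: tensor<1x1024xf32>) -> tensor<1x1024xf32> {
  %w = arith.constant dense<0.0> : tensor<1024x1024xf32>  // weights (elided)
  %b = arith.constant dense<0.0> : tensor<1x1024xf32>     // bias (elided)
  %cst0 = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<1x1024xf32>
  %init = linalg.fill ins(%cst0 : f32) outs(%empty : tensor<1x1024xf32>) -> tensor<1x1024xf32>
  // Fully connected layer: a square matmul (Halevi-Shoup friendly) plus bias.
  %mm = linalg.matmul ins(%input, %w : tensor<1x1024xf32>, tensor<1024x1024xf32>)
                      outs(%init : tensor<1x1024xf32>) -> tensor<1x1024xf32>
  %z = arith.addf %mm, %b : tensor<1x1024xf32>
  // Degree-2 polynomial activation standing in for ReLU (coefficients elided).
  %sq = arith.mulf %z, %z : tensor<1x1024xf32>
  %act = arith.addf %sq, %z : tensor<1x1024xf32>
  return %act : tensor<1x1024xf32>
}
```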
The corresponding C interface is sketched below. The reason we use such a function signature is that, when lowering to LLVM (and then interfacing with C), we have to deal with `memref.alloc`, and ambiguous ownership can lead to memory bugs. The lowering and compiling then follow the usual `mlir-opt` → `mlir-translate --mlir-to-llvmir` → `clang` flow (the exact invocations are in the linked commit).
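The exact signature is in the commit; as a sketch, under MLIR's default memref calling convention each `memref<1x1024xf32>` expands into pointer/offset/size/stride fields, so a caller-allocates-output declaration on the C side could look like:

```c
#include <stdint.h>

/* Hypothetical sketch of the lowered C interface; the real declaration is in
 * the linked commit. Each memref<1x1024xf32> expands to (allocated ptr,
 * aligned ptr, offset, sizes..., strides...). Passing the output buffer in
 * from the C side sidesteps the memref.alloc ownership problem. */
void mlp(float *in_allocated, float *in_aligned, int64_t in_offset,
         int64_t in_size0, int64_t in_size1,
         int64_t in_stride0, int64_t in_stride1,
         float *out_allocated, float *out_aligned, int64_t out_offset,
         int64_t out_size0, int64_t out_size1,
         int64_t out_stride0, int64_t out_stride1);
```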
Running `./mlp_main`, we get an accuracy of `9634/10000`.
Homomorphic computation
To convert things to the homomorphic domain, especially for the HEIR CKKS pipeline, the following transformations are done.
The tricky part of the current matmul implementation is that its weights/buffers must all be `arith.constant` instead of arguments, so we make the following change to `mlp_inline.mlir`. Note that the inlined weights are extracted from the trained weights above, so we end up with a 24MB `mlp_inline.mlir`. We also need to change the function signature so that the data input is the only remaining argument; a sketch of the result follows.
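A minimal sketch of the result (hypothetical shapes, a splat constant standing in for the 24MB dense literal, and HEIR's `{secret.secret}` argument attribute assumed; check the commit for the real annotation):

```mlir
// Sketch only: the weights are now an arith.constant rather than an argument,
// and the data input is the sole (secret) argument.
func.func @mlp(%input: tensor<1x1024xf32> {secret.secret}) -> tensor<1x1024xf32> {
  // The real file carries a full dense<[...]> literal here, hence the 24MB.
  %w = arith.constant dense<0.0> : tensor<1x1024xf32>
  // Elementwise stand-in for the real matmul/activation layers.
  %0 = arith.mulf %input, %w : tensor<1x1024xf32>
  return %0 : tensor<1x1024xf32>
}
```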
Then we run the HEIR CKKS pipeline over it (`heir-opt` followed by `heir-translate` to emit the OpenFHE code; the exact invocations are in the linked commit), producing `mlp_openfhe.cpp`.
We need the following changes in `mlp__generate_crypto_context()` in `mlp_openfhe.cpp` (be careful: this file is 30MB, and opening it in an IDE may hang it) to make things work; a sketch of the result follows the list.

- Remove `.SetPlaintextModulus()`; CKKS does not need it.
- Add `v14431.SetSecurityLevel(HEStd_NotSet); v14431.SetRingDim(1 << 11);` (change `v14431` to your own variable name). Because our ciphertexts are of size 1024, full packing needs RingDim 2048; meeting the default security parameter would push RingDim to `32768`, and things get too slow.
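For orientation, a minimal sketch of such a CKKS context setup against the OpenFHE API (the multiplicative depth and scaling size here are made-up placeholders; the generated file has its own values):

```cpp
#include "openfhe.h"
using namespace lbcrypto;

CryptoContext<DCRTPoly> mlp__generate_crypto_context() {
  CCParams<CryptoContextCKKSRNS> params;
  params.SetMultiplicativeDepth(8);       // placeholder; use the generated value
  params.SetScalingModSize(45);           // placeholder; use the generated value
  // Note: no SetPlaintextModulus() -- CKKS does not use one.
  params.SetSecurityLevel(HEStd_NotSet);  // skip the security check so that...
  params.SetRingDim(1 << 11);             // ...RingDim 2048 (full packing for
                                          // size-1024 ciphertexts) is accepted.
  CryptoContext<DCRTPoly> cc = GenCryptoContext(params);
  cc->Enable(PKE);
  cc->Enable(KEYSWITCH);
  cc->Enable(LEVELEDSHE);
  return cc;
}
```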
Then compile it against `mlp_openfhe_main.cpp`.
Running it, each inference takes ~30 seconds and memory usage is ~3 GB. With RingDim 32768 instead, one inference takes ~6 minutes and memory usage is ~26 GB.
Discussion
- Stack usage is heavy because of the huge inlined constants; running currently needs `ulimit -s unlimited`.
- We should support loading weights at runtime instead of hard-encoding them in the IR. The technical issue is then issuing many tensor ops to correctly pack the plaintext weight matrix.
- The code currently lives in `tests/Example/benchmark`; I want to discuss how to integrate it as a benchmark example of HEIR.