Write-up for MLP inference on MNIST using HEIR CKKS pipeline #1232

ZenithalHourlyRate opened this issue Dec 24, 2024 · 4 comments

ZenithalHourlyRate commented Dec 24, 2024

An MLP is one of the simplest forms of neural network, and supporting it could be a starting point for more complex networks (e.g. the ResNet benchmark widely used in the HE literature).

Such a model could in principle be imported from a frontend like TOSA/StableHLO, but as #1079 suggests, HEIR currently does not support some linalg ops, so we have to write it manually. There are also still some limitations in the HEIR CKKS pipeline, so manual intervention is required.

The essential code is in ZenithalHourlyRate@617e0d1, including the MLIR MLP implementation and the corresponding LLVM/OpenFHE drivers.

The weights and test data needed are available at https://cloud.tsinghua.edu.cn/f/b2da50f8bdbc4aa1859f/

Network Design

MNIST is a dataset of handwritten digits (28x28 images), and an MLP can be used to classify them into 10 labels; the accuracy is usually good enough, around 95%.

The typical MLP design involves the following parts:

  • The first fully-connected layer, of size 784x512.
  • The activation layer, using ReLU as the activation function.
  • The second fully-connected layer, of size 512x10.

In HEIR we have implemented Halevi-Shoup matrix multiplication, but only square matrices are supported for now; also, ReLU can hardly be expressed in HE primitives, so a polynomial approximation is needed (see #658, #665 and #1217).

So the specialized network design becomes:

  • The first FC layer, of size 1024x1024
  • The activation layer, using Approx-ReLU based on a polynomial approximation (a cleartext sketch is given below)
  • The second FC layer, of size 1024x1024

Padding is added accordingly.
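
For intuition, here is a minimal cleartext sketch of what the approximation computes: ReLU(x) ≈ x · (1 + sign(x)) / 2, with sign replaced by a low-degree odd polynomial so the whole layer uses only additions and multiplications. The coefficients and degree below are illustrative placeholders, not the ones actually used by @approx_sign in mlp.mlir.

#include <vector>

// Illustrative low-degree odd polynomial standing in for sign(x) on a
// bounded input range; the real coefficients/degree in mlp.mlir differ.
static float approx_sign(float x) {
  const float a = 1.5f, b = -0.5f;  // placeholder degree-3 coefficients
  return a * x + b * x * x * x;
}

// ReLU(x) ~= x * (1 + sign(x)) / 2, expressed with add/mul only so it maps
// onto CKKS-friendly operations.
static void approx_relu(std::vector<float> &v) {
  for (float &x : v) x = x * (1.0f + approx_sign(x)) * 0.5f;
}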

Training

Training of this specialized network was done by @yushijinhun using PyTorch, achieving an accuracy of 96%.

Inference

There are two versions of the inference MLIR implementation in the code above: mlp.mlir shows that the cleartext computation itself is correct (lowered to LLVM and run from a C++ driver), and mlp_inline.mlir can be accepted by the HEIR pipeline to produce OpenFHE code.

Cleartext computation (for verifying we are correct)

mlp.mlir contains the following code

func.func @approx_sign(%x: tensor<1x1024xf32>) -> tensor<1x1024xf32>
// detailed impl

func.func @approx_relu(%x: tensor<1x1024xf32>) -> tensor<1x1024xf32>
// detailed impl

func.func @mlp(%input: tensor<1x1024xf32>, %fc1: tensor<1024x1024xf32>, %fc2: tensor<1024x1024xf32>, %fc1_buffer: tensor<1x1024xf32>, %fc2_buffer: tensor<1x1024xf32>) -> tensor<1x1024xf32> attributes {llvm.emit_c_interface} {
  %fc1_result = linalg.matmul ins(%input, %fc1 : tensor<1x1024xf32>, tensor<1024x1024xf32>) outs(%fc1_buffer : tensor<1x1024xf32>) -> tensor<1x1024xf32>
  %relu1 = call @approx_relu(%fc1_result) : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
  %fc2_result = linalg.matmul ins(%relu1, %fc2 : tensor<1x1024xf32>, tensor<1024x1024xf32>) outs(%fc2_buffer : tensor<1x1024xf32>) -> tensor<1x1024xf32>
  return %fc2_result : tensor<1x1024xf32>
}

The corresponding C interface is

extern "C" {
void *_mlir_ciface_mlp(MemRefDescriptor<float, 2> *output,
                       MemRefDescriptor<float, 2> *input,
                       MemRefDescriptor<float, 2> *fc1,
                       MemRefDescriptor<float, 2> *fc2,
                       MemRefDescriptor<float, 2> *fc1_buffer,
                       MemRefDescriptor<float, 2> *fc2_buffer);
}

The reason for this function signature is that, when lowered to LLVM (and then interfaced with C), we have to deal with memref.alloc, and ambiguous buffer ownership can lead to memory bugs.
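
For reference, here is a sketch of what the driver side can look like, assuming MemRefDescriptor mirrors MLIR's standard rank-2 memref C ABI (allocated pointer, aligned pointer, offset, sizes, strides); the actual mlp.cpp may differ in details.

#include <cstdint>
#include <vector>

// Assumed to mirror MLIR's C ABI for a rank-N memref.
template <typename T, int N>
struct MemRefDescriptor {
  T *allocated;
  T *aligned;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

// The caller owns the underlying buffer; the descriptor only points into it,
// which avoids any ambiguity about who frees the memory.
static MemRefDescriptor<float, 2> wrap(std::vector<float> &buf, int64_t rows,
                                       int64_t cols) {
  return {buf.data(), buf.data(), 0, {rows, cols}, {cols, 1}};
}

// After filling input/fc1/fc2 from the downloaded data:
//   MemRefDescriptor<float, 2> out{}, in = wrap(input, 1, 1024),
//       w1 = wrap(fc1, 1024, 1024), w2 = wrap(fc2, 1024, 1024),
//       b1 = wrap(fc1_buffer, 1, 1024), b2 = wrap(fc2_buffer, 1, 1024);
//   _mlir_ciface_mlp(&out, &in, &w1, &w2, &b1, &b2);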

Lowering and compiling are done with the following steps:

heir-opt --heir-polynomial-to-llvm mlp.mlir | mlir-translate --mlir-to-llvmir | llc --relocation-model=pic -o mlp.s
clang++ -o mlp_main mlp.cpp mlp.s libmlir_c_runner_utils.so -fPIE

Running ./mlp_main gives an accuracy of 9634/10000.

Homomorphic computation

To move things to the homomorphic domain, specifically the HEIR CKKS pipeline, the following transformations are done:

# HEIR does not support func.call now so inline everything
mlir-opt --inline mlp.mlir > mlp_inline.mlir
# then delete the @approx_sign and @approx_relu functions in it

The tricky part of the current matmul implementation is that its weights/buffers must all be arith.constant rather than function arguments, so we make the following change to mlp_inline.mlir. Note that the filled weights are extracted from the weight data above, so we end up with a 24MB mlp_inline.mlir:

%buffer0 = arith.constant dense<0.0> : tensor<1x1024xf32>
%weight0 = arith.constant dense<[[ /* fill weight here! */ ]]> : tensor<1024x1024xf32> 
linalg.matmul ins(%arg0, %weight0 : tensor<1x1024xf32>, tensor<1024x1024xf32>) outs(%buffer0 : tensor<1x1024xf32>) -> tensor<1x1024xf32>

// similarly for fc2
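
Writing the 1024x1024 dense literal by hand is impractical, so a throwaway generator along these lines can splice the weights in (assuming the downloaded weights can be loaded as a flat row-major float array; the loading step is omitted and the file format is an assumption here):

#include <cstdio>
#include <vector>

// Print a dense<...> : tensor<1024x1024xf32> literal suitable for pasting
// into mlp_inline.mlir; `w` is assumed to be row-major with 1024*1024 entries.
void emit_dense_literal(const std::vector<float> &w, FILE *out) {
  const int n = 1024;
  std::fprintf(out, "dense<[");
  for (int i = 0; i < n; ++i) {
    std::fprintf(out, "%s[", i ? ", " : "");
    for (int j = 0; j < n; ++j)
      std::fprintf(out, "%s%e", j ? ", " : "", w[i * n + j]);
    std::fprintf(out, "]");
  }
  std::fprintf(out, "]> : tensor<1024x1024xf32>\n");
}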

Then we need to change the function signature to

func.func @mlp(%arg0: tensor<1x1024xf32> {secret.secret}) -> tensor<1x1024xf32>

Then we run the following transforms:

# Note that due to the large input size, all of the following commands are slow
# this specific command takes ~10 minutes
heir-opt --mlir-to-openfhe-ckks="entry-function=mlp" mlp_inline.mlir > mlp_openfhe.mlir
## or detailed command below to inspect each step
## Halevi-Shoup matmul
#heir-opt --mlir-to-secret-arithmetic mlp_inline.mlir > mlp_secret_arithmetic.mlir
## Annotate RNS information; it shows that we need 9 levels to finish the computation.
#heir-opt --secret-insert-mgmt-ckks mlp_secret_arithmetic.mlir > mlp_mgmt.mlir
## lower to CKKS dialect
#heir-opt --secret-distribute-generic --secret-to-ckks mlp_mgmt.mlir > mlp_ckks.mlir
## lower to OpenFHE dialect
#heir-opt --lwe-add-client-interface="use-public-key=true one-value-per-helper-fn=true" --ckks-to-lwe --lwe-to-openfhe --openfhe-configure-crypto-context="entry-function=mlp" mlp_ckks.mlir > mlp_openfhe.mlir
# Translate to C++ openfhe function
# Note that the weights are so large that by default they are printed in hex form, which the current OpenFHE emitter does not handle
heir-translate '--mlir-print-elementsattrs-with-hex-if-larger=-1' --emit-openfhe-pke --openfhe-scheme=ckks mlp_openfhe.mlir > mlp_openfhe.cpp
heir-translate --emit-openfhe-pke-header --openfhe-scheme=ckks mlp_openfhe.mlir > mlp_openfhe.h

We need the following changes in mlp__generate_crypto_context() in mlp_openfhe.cpp to make it work (be careful: this file is ~30MB, and opening it in an IDE may make it hang):

  • Delete .SetPlaintextModulus(). CKKS does not need it.
  • Add v14431.SetSecurityLevel(HEStd_NotSet); v14431.SetRingDim(1 << 11); (change v14431 to the variable name in your own output). Because our ciphertext holds 1024 slots, full packing needs a ring dimension of 2048; meeting the default security level would push the ring dimension to 32768, which is too slow. A sketch of the edited code is shown below.
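
Roughly, the edited parameter setup ends up looking like the following (the variable name and the exact set of calls come from the generated code and will differ in your output; this is a hand-written sketch, not verbatim emitter output):

#include "openfhe.h"
using namespace lbcrypto;

CryptoContext<DCRTPoly> mlp__generate_crypto_context() {
  CCParams<CryptoContextCKKSRNS> v14431;
  v14431.SetMultiplicativeDepth(9);       // levels reported by secret-insert-mgmt-ckks
  // .SetPlaintextModulus(...) removed: CKKS does not use a plaintext modulus
  v14431.SetSecurityLevel(HEStd_NotSet);  // opt out of the default security target
  v14431.SetRingDim(1 << 11);             // 2048 is enough for 1024 packed slots
  CryptoContext<DCRTPoly> cc = GenCryptoContext(v14431);
  cc->Enable(PKE);
  cc->Enable(KEYSWITCH);
  cc->Enable(LEVELEDSHE);
  return cc;
}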

Compile it against mlp_openfhe_main.cpp

# you should add -I and -L for your specific OpenFHE installation
# takes around 30s to compile
clang++ -std=c++17 -o mlp_openfhe mlp_openfhe_main.cpp mlp_openfhe.cpp -I. -lOPENFHEcore -lOPENFHEpke -lOPENFHEbinfhe
# set unlimited stack size, as the weights all live on the stack now;
# otherwise you will observe a segfault
ulimit -s unlimited
# run the program
./mlp_openfhe

We get the following output. Each inference takes ~30 seconds and memory usage is ~3GB:

Element Parameters: ILDCRTParams [m=4096 n=2048 q=1867228318734141816999655779671310589399405959067135205574019961160007458457284606914047149436113163649889870755858538873363336506924688056138331269655941214209 ru=0 bigq=0 bigru=0]
  m_params:
    0: ILParams [m=4096 n=2048 q=1152921504606830593 ru=459811883340678 bigq=0 bigru=0]
    1: ILParams [m=4096 n=2048 q=1125899907260417 ru=479982368344 bigq=0 bigru=0]
    2: ILParams [m=4096 n=2048 q=1125899907063809 ru=809174752721 bigq=0 bigru=0]
    3: ILParams [m=4096 n=2048 q=1125899907219457 ru=105703250448 bigq=0 bigru=0]
    4: ILParams [m=4096 n=2048 q=1125899906977793 ru=996886817494 bigq=0 bigru=0]
    5: ILParams [m=4096 n=2048 q=1125899907145729 ru=842109450467 bigq=0 bigru=0]
    6: ILParams [m=4096 n=2048 q=1125899906990081 ru=25027901798 bigq=0 bigru=0]
    7: ILParams [m=4096 n=2048 q=1125899907096577 ru=203477575998 bigq=0 bigru=0]
    8: ILParams [m=4096 n=2048 q=1125899906826241 ru=1080667890455 bigq=0 bigru=0]
    9: ILParams [m=4096 n=2048 q=1125899906949121 ru=251751765212 bigq=0 bigru=0]
    10: ILParams [m=4096 n=2048 q=557057 ru=66 bigq=0 bigru=0]


Encoding Parameters: [p=50 rootP =0 bigP =0 rootBigP =0 g=0 L=1024]

max_id: 7, label: 7
max_id: 2, label: 2
max_id: 1, label: 1
max_id: 0, label: 0
max_id: 4, label: 4
max_id: 1, label: 1
max_id: 4, label: 4
max_id: 9, label: 9
max_id: 6, label: 5
max_id: 9, label: 9
accuracy: 9/10

If we use ring dimension 32768 instead, one inference takes ~6 minutes and memory usage is ~26GB.
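
For reference, each max_id above is just an argmax over the first 10 slots of the decrypted packed result (the remaining slots are padding). A minimal sketch of that post-processing step, with the HEIR-generated decryption call omitted:

#include <vector>

// The decrypted result is a length-1024 vector whose first 10 slots hold the
// class scores; pick the index of the largest score.
static int argmax10(const std::vector<double> &scores) {
  int max_id = 0;
  for (int i = 1; i < 10; ++i)
    if (scores[i] > scores[max_id]) max_id = i;
  return max_id;
}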

Discussion

  • The limitation of the current Halevi-Shoup matmul implementation is the most painful part: the IR is big (~20MB), the C++ output is big (~30MB), and stack usage is big (hence ulimit -s unlimited). We should support loading weights at runtime instead of hard-coding them in the IR; the technical issue is then emitting many tensor ops to correctly pack the plaintext weight matrix.
  • I think every complex example/benchmark should have a cleartext version (lowered to LLVM) and a homomorphic version (lowered to a backend) so we can ensure correctness (it is easier to debug whether the fault lies in the input program or in a compiler transformation), and we could further use the two lowerings to compare cleartext vs. homomorphic efficiency.
  • I think this example fits well under tests/Example/benchmark; I would like to discuss how to integrate it as a benchmark example for HEIR.

j2kun commented Jan 6, 2025

This is really a fantastic start! Thank you for being so thorough.

I want to get this work checked in, so I'm trying to figure out how to tease apart the overall thing into different work items we can make progress on.

Instead of an arith.constant, we should be able to use a memref.global and the get_global op to avoid putting all the data on the stack.

At a high level, our main goal should be to ensure the file sizes are small, runtimes are fast, and that things can be compressed appropriately. So the goal to put the weights in a separate file should mainly be to enable ease of compression. For MLIR, I think we could get most of those benefits by having the values be globals and storing the files as MLIR bytecode (not plaintext), and we could easily check in the bytecode files as test inputs. Adding bytecode is a slightly more complicated task, in that it requires versioning our IR, so I'd like to avoid that as long as possible, and maybe that means having large-ish plaintext checked-in files for a while. So long as we avoid the compile time and stack usage issues above, we can find a good solution for having large test files.


j2kun commented Jan 6, 2025

I think for all complex example/benchmark, there should be a cleartext version (lowered to LLVM) and a homomorphic version (lowered to backend) so we could ensure correctness

I fully support this; I'm just trying to think through how we could make this easier to do than the current method, which involves relatively complex RUN directives in lit files.


j2kun commented Jan 6, 2025

Another subtask: we should support emitting hex values in the emitter. I would bet using emitc instead of manual emission would get this for free, but in the meantime we might be able to port the code from emitc's codegen to get the same effect.


j2kun commented Jan 6, 2025

Thinking about this some more, could we get the input model saved in an independent format, say ONNX or stablehlo, and use that as a starting point?
