Skip to content

Commit

Permalink
Add openfhe sub plain operation
Browse files Browse the repository at this point in the history
Fixes #1200

Seems like openfhe does expose an overload to EvalSub that allows ct-pt.

PiperOrigin-RevId: 719367130
  • Loading branch information
asraa authored and copybara-github committed Jan 24, 2025
1 parent 68edaa0 commit dd91785
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 6 deletions.
4 changes: 3 additions & 1 deletion lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase<LWEToOpenfhe> {
ConvertCiphertextPlaintextOp<bgv::AddPlainOp, openfhe::AddPlainOp>,
ConvertCiphertextPlaintextOp<ckks::AddPlainOp, openfhe::AddPlainOp>,

// TODO (#1200): SubPlain support for OpenFHE
// SubPlain
ConvertCiphertextPlaintextOp<bgv::SubPlainOp, openfhe::SubPlainOp>,
ConvertCiphertextPlaintextOp<ckks::SubPlainOp, openfhe::SubPlainOp>,

// MulPlain
ConvertCiphertextPlaintextOp<bgv::MulPlainOp, openfhe::MulPlainOp>,
Expand Down
13 changes: 13 additions & 0 deletions lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,19 @@ def AddPlainOp : Openfhe_Op<"add_plain",[
let results = (outs NewLWECiphertext:$output);
}

def SubPlainOp : Openfhe_Op<"sub_plain",[
Pure,
AllTypesMatch<["ciphertext", "output"]>
]> {
let summary = "OpenFHE sub operation of a ciphertext and a plaintext.";
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
NewLWECiphertext:$ciphertext,
NewLWEPlaintext:$plaintext
);
let results = (outs NewLWECiphertext:$output);
}

def MulOp : Openfhe_BinaryOp<"mul"> { let summary = "OpenFHE mul operation of two ciphertexts with relinearization."; }

def MulNoRelinOp : Openfhe_Op<"mul_no_relin", [Pure, AllTypesMatch<["lhs", "rhs"]>]> {
Expand Down
17 changes: 12 additions & 5 deletions lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ LogicalResult OpenFhePkeEmitter::translate(Operation &op) {
.Case<lwe::RLWEDecodeOp, lwe::ReinterpretUnderlyingTypeOp>(
[&](auto op) { return printOperation(op); })
// OpenFHE ops
.Case<AddOp, AddPlainOp, SubOp, MulNoRelinOp, MulOp, MulPlainOp,
SquareOp, NegateOp, MulConstOp, RelinOp, ModReduceOp,
LevelReduceOp, RotOp, AutomorphOp, KeySwitchOp, EncryptOp,
DecryptOp, GenParamsOp, GenContextOp, GenMulKeyOp, GenRotKeyOp,
GenBootstrapKeyOp, MakePackedPlaintextOp,
.Case<AddOp, AddPlainOp, SubOp, SubPlainOp, MulNoRelinOp, MulOp,
MulPlainOp, SquareOp, NegateOp, MulConstOp, RelinOp,
ModReduceOp, LevelReduceOp, RotOp, AutomorphOp, KeySwitchOp,
EncryptOp, DecryptOp, GenParamsOp, GenContextOp, GenMulKeyOp,
GenRotKeyOp, GenBootstrapKeyOp, MakePackedPlaintextOp,
MakeCKKSPackedPlaintextOp, SetupBootstrapOp, BootstrapOp>(
[&](auto op) { return printOperation(op); })
.Default([&](Operation &) {
Expand Down Expand Up @@ -235,6 +235,13 @@ LogicalResult OpenFhePkeEmitter::printOperation(SubOp op) {
{op.getLhs(), op.getRhs()}, "EvalSub");
}

LogicalResult OpenFhePkeEmitter::printOperation(SubPlainOp op) {
// OpenFHE defines an overload for EvalSub to work on both plaintext and
// ciphertext inputs.
return printEvalMethod(op.getResult(), op.getCryptoContext(),
{op.getCiphertext(), op.getPlaintext()}, "EvalSub");
}

LogicalResult OpenFhePkeEmitter::printOperation(MulNoRelinOp op) {
return printEvalMethod(op.getResult(), op.getCryptoContext(),
{op.getLhs(), op.getRhs()}, "EvalMultNoRelin");
Expand Down
1 change: 1 addition & 0 deletions lib/Target/OpenFhePke/OpenFhePkeEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class OpenFhePkeEmitter {
LogicalResult printOperation(SetupBootstrapOp op);
LogicalResult printOperation(SquareOp op);
LogicalResult printOperation(SubOp op);
LogicalResult printOperation(SubPlainOp op);

// Helpers for above
LogicalResult printEvalMethod(::mlir::Value result,
Expand Down
11 changes: 11 additions & 0 deletions tests/Dialect/Openfhe/Emitters/emit_openfhe_pke.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ module attributes {scheme.ckks} {
%1 = lwe.rlwe_decode %0 {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x32_} : !scalar_pt_ty -> i16
return %1 : i16
}
// CHECK-LABEL: CiphertextT test_sub_plain(
// CHECK-SAME: CryptoContextT [[CC:[^,]*]],
// CHECK-SAME: Plaintext [[ARG1:[^,]*]],
// CHECK-SAME: CiphertextT [[ARG2:[^,]*]]) {
// CHECK-NEXT: const auto& [[v0:.*]] = [[CC]]->EvalSub([[ARG2]], [[ARG1]]);
// CHECK-NEXT: return [[v0]];
// CHECK-NEXT: }
func.func @test_sub_plain(%cc: !openfhe.crypto_context, %pt :!tensor_pt_ty, %ct : !tensor_ct_ty) -> !tensor_ct_ty {
%0 = openfhe.sub_plain %cc, %ct, %pt: (!openfhe.crypto_context, !tensor_ct_ty, !tensor_pt_ty) -> !tensor_ct_ty
return %0 : !tensor_ct_ty
}
}

// -----
Expand Down
6 changes: 6 additions & 0 deletions tests/Dialect/Openfhe/IR/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ module {
return
}

// CHECK-LABEL: func @test_sub_plain
func.func @test_sub_plain(%cc : !cc, %pt : !pt, %ct: !ct) {
%out = openfhe.sub_plain %cc, %ct, %pt: (!cc, !ct, !pt) -> !ct
return
}

// CHECK-LABEL: func @test_mul
func.func @test_mul(%cc : !cc, %pt : !pt, %pk: !pk) {
%c1 = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct
Expand Down
8 changes: 8 additions & 0 deletions tests/Examples/openfhe/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ openfhe_end_to_end_test(
test_src = "binops_test.cpp",
)

openfhe_end_to_end_test(
name = "ciphertext_plaintext_ops_test",
generated_lib_header = "ciphertext_plaintext_ops_lib.h",
mlir_src = "ciphertext_plaintext_ops.mlir",
tags = ["notap"],
test_src = "ciphertext_plaintext_ops_test.cpp",
)

openfhe_end_to_end_test(
name = "simple_sum_test",
generated_lib_header = "simple_sum_lib.h",
Expand Down
26 changes: 26 additions & 0 deletions tests/Examples/openfhe/ciphertext_plaintext_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s

!cc = !openfhe.crypto_context

!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64>
!Z65537_i64_ = !mod_arith.int<65537 : i64>
#key = #lwe.key<>
#modulus_chain_L5_C0_ = #lwe.modulus_chain<elements = <1095233372161 : i64, 1032955396097 : i64, 1005037682689 : i64, 998595133441 : i64, 972824936449 : i64, 959939837953 : i64>, current = 0>
!rns_L0_ = !rns.rns<!Z1095233372161_i64_>
#ring_rns_L0_1_x8_ = #polynomial.ring<coefficientType = !rns_L0_, polynomialModulus = <1 + x**8>>
#ring_Z65537_i64_1_x8_ = #polynomial.ring<coefficientType = !Z65537_i64_, polynomialModulus = <1 + x**8>>
#inverse_canonical_encoding = #lwe.inverse_canonical_encoding<scaling_factor = 1024>
#plaintext_space = #lwe.plaintext_space<ring = #ring_Z65537_i64_1_x8_, encoding = #inverse_canonical_encoding>
#ciphertext_space_L0_ = #lwe.ciphertext_space<ring = #ring_rns_L0_1_x8_, encryption_type = lsb>
!pt = !lwe.new_lwe_plaintext<application_data = <message_type = f16>, plaintext_space = #plaintext_space>
!ct = !lwe.new_lwe_ciphertext<application_data = <message_type = f16>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_>

// [(input1 + input2) - input3] * input4
module attributes {scheme.bgv} {
func.func @test_ciphertext_plaintext_ops(%cc : !cc, %input1 : !ct, %input2 : !pt, %input3 : !pt, %input4 : !pt) -> !ct {
%add_res = openfhe.add_plain %cc, %input1, %input2 : (!cc, !ct, !pt) -> !ct
%sub_res = openfhe.sub_plain %cc, %add_res, %input3 : (!cc, !ct, !pt) -> !ct
%mul_res = openfhe.mul_plain %cc, %sub_res, %input4 : (!cc, !ct, !pt) -> !ct
return %mul_res : !ct
}
}
57 changes: 57 additions & 0 deletions tests/Examples/openfhe/ciphertext_plaintext_ops_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include <cstdint>
#include <vector>

#include "gmock/gmock.h" // from @googletest
#include "gtest/gtest.h" // from @googletest
#include "src/pke/include/openfhe.h" // from @openfhe

// Generated headers (block clang-format from messing up order)
#include "tests/Examples/openfhe/ciphertext_plaintext_ops_lib.h"

using namespace lbcrypto;
using ::testing::ElementsAre;

namespace mlir {
namespace heir {
namespace openfhe {

TEST(CiphertextPlaintextOpsTest, TestInput1) {
CCParams<CryptoContextBGVRNS> parameters;
parameters.SetMultiplicativeDepth(2);
parameters.SetPlaintextModulus(65537);
CryptoContext<DCRTPoly> cryptoContext = GenCryptoContext(parameters);
cryptoContext->Enable(PKE);
cryptoContext->Enable(KEYSWITCH);
cryptoContext->Enable(LEVELEDSHE);

KeyPair<DCRTPoly> keyPair;
keyPair = cryptoContext->KeyGen();
cryptoContext->EvalMultKeyGen(keyPair.secretKey);
cryptoContext->EvalRotateKeyGen(keyPair.secretKey, {1, 2, -1, -2});

std::vector<int64_t> vectorOfInts1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
Plaintext plaintext1 = cryptoContext->MakePackedPlaintext(vectorOfInts1);
std::vector<int64_t> vectorOfInts2 = {3, 2, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12};
Plaintext plaintext2 = cryptoContext->MakePackedPlaintext(vectorOfInts2);
std::vector<int64_t> vectorOfInts3 = {3, 2, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12};
Plaintext plaintext3 = cryptoContext->MakePackedPlaintext(vectorOfInts3);
std::vector<int64_t> vectorOfInts4 = {1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2};
Plaintext plaintext4 = cryptoContext->MakePackedPlaintext(vectorOfInts4);

auto ciphertext1 = cryptoContext->Encrypt(keyPair.publicKey, plaintext1);

// Computes [(ciphertext1 + plaintext2) - plaintext3] * plaintext4
auto ciphertextActual = test_ciphertext_plaintext_ops(
cryptoContext, ciphertext1, plaintext2, plaintext3, plaintext4);

Plaintext plaintextActual;
cryptoContext->Decrypt(keyPair.secretKey, ciphertextActual, &plaintextActual);
auto actual = plaintextActual->GetPackedValue();
actual.resize(12);

EXPECT_THAT(actual, ElementsAre(1, 4, 3, 8, 5, 12, 7, 16, 9, 20, 11, 24));
}

} // namespace openfhe
} // namespace heir
} // namespace mlir

0 comments on commit dd91785

Please sign in to comment.