From f2e53f84be38864d4a9f3ecb62fed90a54df9512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= <alexandre.pere@zama.ai> Date: Thu, 21 Nov 2024 14:45:50 +0100 Subject: [PATCH] fix --- .../include/concretelang/Common/Keysets.h | 4 +- .../include/concretelang/Common/Security.h | 36 ++--- .../lib/Bindings/Python/CompilerAPIModule.cpp | 129 +++++++++++++++++- .../compiler/lib/Common/CMakeLists.txt | 1 + .../compiler/lib/Common/Keysets.cpp | 128 ++++++++++------- .../compiler/lib/Common/Security.cpp | 37 +++++ .../compiler/lib/Runtime/simulation.cpp | 7 +- .../lib/Support/ProgramInfoGeneration.cpp | 17 ++- .../src/concrete-optimizer.rs | 21 ++- .../src/cpp/concrete-optimizer.cpp | 6 +- .../src/cpp/concrete-optimizer.hpp | 2 +- .../optimization/dag/multi_parameters/mod.rs | 2 +- ...neric_generation.rs => virtual_circuit.rs} | 38 ++++-- .../tests/compilation/test_restrictions.py | 2 +- 14 files changed, 312 insertions(+), 118 deletions(-) rename tools/parameter-curves/concrete-security-curves-cpp/include/concrete/curves.h => compilers/concrete-compiler/compiler/include/concretelang/Common/Security.h (69%) create mode 100644 compilers/concrete-compiler/compiler/lib/Common/Security.cpp rename compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/{generic_generation.rs => virtual_circuit.rs} (87%) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h index b6d47a20b6..0e84df64b4 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h @@ -93,9 +93,9 @@ class KeysetCache { KeysetCache() = default; }; -Message<concreteprotocol::KeysetInfo> generate_generic_keyset_info( +Message<concreteprotocol::KeysetInfo> keysetInfoFromVirtualCircuit( std::vector<concrete_optimizer::utils::PartitionDefinition> partitions, - bool generate_fks); + bool generate_fks, std::optional<concrete_optimizer::Options> options); } // namespace keysets } // namespace concretelang diff --git a/tools/parameter-curves/concrete-security-curves-cpp/include/concrete/curves.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Security.h similarity index 69% rename from tools/parameter-curves/concrete-security-curves-cpp/include/concrete/curves.h rename to compilers/concrete-compiler/compiler/include/concretelang/Common/Security.h index 839d0a0445..3fd25f98ce 100644 --- a/tools/parameter-curves/concrete-security-curves-cpp/include/concrete/curves.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Security.h @@ -2,16 +2,11 @@ // Exceptions. See // https://github.com/zama-ai/concrete/blob/main/LICENSE.txt // for license information. +#ifndef CONCRETELANG_COMMON_SECURITY_H +#define CONCRETELANG_COMMON_SECURITY_H -#ifndef CONCRETELANG_SUPPORT_V0CURVES_H_ -#define CONCRETELANG_SUPPORT_V0CURVES_H_ - -#include <algorithm> -#include <cmath> -#include <cstddef> -#include <vector> - -namespace concrete { +namespace concretelang { +namespace security { enum KeyFormat { BINARY, @@ -42,31 +37,16 @@ struct SecurityCurve { /// @param polynomialSize The size of the polynom of the glwe /// @param logQ The log of q /// @return The secure encryption variances - double getVariance(int glweDimension, int polynomialSize, int logQ) { - auto size = glweDimension * polynomialSize; - if (size < minimalLweDimension) { - return NAN; - } - auto a = std::pow(2, (slope * size + bias) * 2); - auto b = std::pow(2, -2 * (logQ - 2)); - return a > b ? a : b; - } + double getVariance(int glweDimension, int polynomialSize, int logQ); }; -#include "curves.gen.h" - /// @brief Return the security curve for a given level and a key format. /// @param bitsOfSecurity The number of bits of security /// @param keyFormat The format of the key /// @return The security curve or nullptr if the curve is not found. -SecurityCurve *getSecurityCurve(int bitsOfSecurity, KeyFormat keyFormat) { - for (size_t i = 0; i < curvesLen; i++) { - if (curves[i].bits == bitsOfSecurity && curves[i].keyFormat == keyFormat) - return &curves[i]; - } - return nullptr; -} +SecurityCurve *getSecurityCurve(int bitsOfSecurity, KeyFormat keyFormat); -} // namespace concrete +} // namespace security +} // namespace concretelang #endif diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 86ea598963..30377500d6 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -18,7 +18,6 @@ #include "concretelang/Support/Error.h" #include "concretelang/Support/V0Parameters.h" #include "concretelang/Support/logging.h" -#include <cstdint> #include <filesystem> #include <memory> #include <mlir-c/Bindings/Python/Interop.h> @@ -326,6 +325,122 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .doc() = "Allow to restrict the optimizer search space to be compatible " "with a keyset."; + // ------------------------------------------------------------------------------// + // OPTIMIZER OPTIONS // + // ------------------------------------------------------------------------------// + pybind11::class_<concrete_optimizer::Options>(m, "OptimizerOptions") + .def( + "set_security_level", + [](concrete_optimizer::Options &options, uint64_t security_level) { + options.security_level = security_level; + }, + "Set option for security level.", arg("security_level")) + .def( + "set_maximum_acceptable_error_probability", + [](concrete_optimizer::Options &options, + double maximum_acceptable_error_probability) { + options.maximum_acceptable_error_probability = + maximum_acceptable_error_probability; + }, + "Set option for maximum acceptable error probability.", + arg("maximum_acceptable_error_probability")) + .def( + "set_key_sharing", + [](concrete_optimizer::Options &options, bool key_sharing) { + options.key_sharing = key_sharing; + }, + "Set option for key sharing.", arg("key_sharing")) + .def( + "set_multi_param_strategy_to_by_precision", + [](concrete_optimizer::Options &options) { + options.multi_param_strategy = + concrete_optimizer::MultiParamStrategy::ByPrecision; + }, + "Set option for multi param strategy to by-precision.") + .def( + "set_multi_param_strategy_to_by_precision_and_norm_2", + [](concrete_optimizer::Options &options) { + options.multi_param_strategy = + concrete_optimizer::MultiParamStrategy::ByPrecisionAndNorm2; + }, + "Set option for multi param strategy to by-precision-and-norm2.") + .def( + "set_default_log_norm2_woppbs", + [](concrete_optimizer::Options &options, + double default_log_norm2_woppbs) { + options.default_log_norm2_woppbs = default_log_norm2_woppbs; + }, + "Set option for default log norm2 woppbs.", + arg("default_log_norm2_woppbs")) + .def( + "set_use_gpu_constraints", + [](concrete_optimizer::Options &options, bool use_gpu_constraints) { + options.use_gpu_constraints = use_gpu_constraints; + }, + "Set option for use gpu constrints.", arg("use_gpu_constraints")) + .def( + "set_encoding_to_auto", + [](concrete_optimizer::Options &options) { + options.encoding = concrete_optimizer::Encoding::Auto; + }, + "Set option for encoding to auto.") + .def( + "set_encoding_to_crt", + [](concrete_optimizer::Options &options) { + options.encoding = concrete_optimizer::Encoding::Crt; + }, + "Set option for encoding to crt.") + .def( + "set_encoding_to_native", + [](concrete_optimizer::Options &options) { + options.encoding = concrete_optimizer::Encoding::Native; + }, + "Set option for encoding to native.") + .def( + "set_cache_on_disk", + [](concrete_optimizer::Options &options, bool cache_on_disk) { + options.cache_on_disk = cache_on_disk; + }, + "Set option for cache on disk.", arg("cache_on_disk")) + .def( + "set_ciphertext_modulus_log", + [](concrete_optimizer::Options &options, + uint32_t ciphertext_modulus_log) { + options.ciphertext_modulus_log = ciphertext_modulus_log; + }, + "Set option for ciphertext modulus log.", + arg("ciphertext_modulus_log")) + .def( + "set_fft_precision", + [](concrete_optimizer::Options &options, uint32_t fft_precision) { + options.fft_precision = fft_precision; + }, + "Set option for fft precision.", arg("fft_precision")) + .def( + "set_fft_precision", + [](concrete_optimizer::Options &options, uint32_t fft_precision) { + options.fft_precision = fft_precision; + }, + "Set option for fft precision.", arg("fft_precision")) + .def( + "set_range_restriction", + [](concrete_optimizer::Options &options, + concrete_optimizer::restriction::RangeRestriction restriction) { + options.range_restriction = std::make_shared< + concrete_optimizer::restriction::RangeRestriction>(restriction); + }, + "Set option for range restriction", arg("restriction")) + .def( + "set_keyset_restriction", + [](concrete_optimizer::Options &options, + concrete_optimizer::restriction::KeysetRestriction restriction) { + options.keyset_restriction = std::make_shared< + concrete_optimizer::restriction::KeysetRestriction>( + restriction); + }, + "Set option for keyset restriction", arg("restriction")) + .doc() = "Options for the optimizer."; + // ------------------------------------------------------------------------------// // COMPILATION OPTIONS // // ------------------------------------------------------------------------------// @@ -972,18 +1087,20 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( typedef Message<concreteprotocol::KeysetInfo> KeysetInfo; pybind11::class_<KeysetInfo>(m, "KeysetInfo") .def_static( - "generate_generic", + "generate_virtual", [](std::vector<concrete_optimizer::utils::PartitionDefinition> partitions, - bool generateFks) -> KeysetInfo { + bool generateFks, + std::optional<concrete_optimizer::Options> options) -> KeysetInfo { if (partitions.size() < 2) { throw std::runtime_error("Need at least two partition defs to " - "generate a generic keyset info."); + "generate a virtual keyset info."); } - return ::concretelang::keysets::generate_generic_keyset_info( - partitions, generateFks); + return ::concretelang::keysets::keysetInfoFromVirtualCircuit( + partitions, generateFks, options); }, arg("partition_defs"), arg("generate_fks"), + arg("options") = std::nullopt, "Generate a generic keyset info for a set of partition definitions") .def( "secret_keys", diff --git a/compilers/concrete-compiler/compiler/lib/Common/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Common/CMakeLists.txt index 7c91bd3fb6..843defdb5f 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Common/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_library( Keys.cpp Keysets.cpp Transformers.cpp + Security.cpp Values.cpp DEPENDS concrete-protocol diff --git a/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp b/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp index dabdcd052e..3096e64a0d 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp @@ -11,6 +11,7 @@ #include "concretelang/Common/Csprng.h" #include "concretelang/Common/Error.h" #include "concretelang/Common/Keys.h" +#include "concretelang/Common/Security.h" #include "kj/common.h" #include "kj/io.h" #include "llvm/ADT/ScopeExit.h" @@ -418,98 +419,131 @@ KeysetCache::getKeyset(const Message<concreteprotocol::KeysetInfo> &keysetInfo, return std::move(keyset); } -Message<concreteprotocol::KeysetInfo> generate_generic_keyset_info( - std::vector<concrete_optimizer::utils::PartitionDefinition> partitionDefs, - bool generateFks) { +Message<concreteprotocol::KeysetInfo> +generateKeysetInfoFromParameters(CircuitKeys parameters, + concrete_optimizer::Options options) { auto output = Message<concreteprotocol::KeysetInfo>{}; - rust::Vec<concrete_optimizer::utils::PartitionDefinition> rustPartitionDefs{}; - for (auto def : partitionDefs) { - rustPartitionDefs.push_back(def); - } - auto parameters = concrete_optimizer::utils::generate_generic_keyset_info( - rustPartitionDefs, generateFks); + auto curve = ::concretelang::security::getSecurityCurve( + options.security_level, ::concretelang::security::BINARY); auto skLen = (int)parameters.secret_keys.size(); auto skBuilder = output.asBuilder().initLweSecretKeys(skLen); - for (int i = 0; i < skLen; i++) { + for (auto sk : llvm::enumerate(parameters.secret_keys)) { auto output = Message<concreteprotocol::LweSecretKeyInfo>(); - auto sk = parameters.secret_keys[i]; - output.asBuilder().setId(sk.identifier); + output.asBuilder().setId(sk.value().identifier); output.asBuilder().getParams().setIntegerPrecision(64); - output.asBuilder().getParams().setLweDimension(sk.polynomial_size * - sk.glwe_dimension); + output.asBuilder().getParams().setLweDimension(sk.value().polynomial_size * + sk.value().glwe_dimension); output.asBuilder().getParams().setKeyType( ::concreteprotocol::KeyType::BINARY); - skBuilder.setWithCaveats(i, output.asReader()); + skBuilder.setWithCaveats(sk.index(), output.asReader()); } auto bskLen = (int)parameters.bootstrap_keys.size(); auto bskBuilder = output.asBuilder().initLweBootstrapKeys(bskLen); - for (int i = 0; i < bskLen; i++) { + for (auto bsk : llvm::enumerate(parameters.bootstrap_keys)) { auto output = Message<concreteprotocol::LweBootstrapKeyInfo>(); - auto bsk = parameters.bootstrap_keys[i]; - output.asBuilder().setId(bsk.identifier); - output.asBuilder().setInputId(bsk.input_key.identifier); - output.asBuilder().setOutputId(bsk.output_key.identifier); + output.asBuilder().setId(bsk.value().identifier); + output.asBuilder().setInputId(bsk.value().input_key.identifier); + output.asBuilder().setOutputId(bsk.value().output_key.identifier); output.asBuilder().getParams().setLevelCount( - bsk.br_decomposition_parameter.level); + bsk.value().br_decomposition_parameter.level); output.asBuilder().getParams().setBaseLog( - bsk.br_decomposition_parameter.log2_base); + bsk.value().br_decomposition_parameter.log2_base); output.asBuilder().getParams().setGlweDimension( - bsk.output_key.glwe_dimension); + bsk.value().output_key.glwe_dimension); output.asBuilder().getParams().setPolynomialSize( - bsk.output_key.polynomial_size); + bsk.value().output_key.polynomial_size); output.asBuilder().getParams().setInputLweDimension( - bsk.input_key.polynomial_size); + bsk.value().input_key.polynomial_size); output.asBuilder().getParams().setIntegerPrecision(64); output.asBuilder().getParams().setKeyType( concreteprotocol::KeyType::BINARY); - bskBuilder.setWithCaveats(i, output.asReader()); + output.asBuilder().getParams().setVariance( + curve->getVariance(bsk.value().output_key.glwe_dimension, + bsk.value().output_key.polynomial_size, 64)); + bskBuilder.setWithCaveats(bsk.index(), output.asReader()); } auto kskLen = (int)parameters.keyswitch_keys.size(); auto ckskLen = (int)parameters.conversion_keyswitch_keys.size(); auto kskBuilder = output.asBuilder().initLweKeyswitchKeys(kskLen + ckskLen); - for (int i = 0; i < kskLen; i++) { + for (auto ksk : llvm::enumerate(parameters.keyswitch_keys)) { auto output = Message<concreteprotocol::LweKeyswitchKeyInfo>(); - auto ksk = parameters.keyswitch_keys[i]; - output.asBuilder().setId(ksk.identifier); - output.asBuilder().setInputId(ksk.input_key.identifier); - output.asBuilder().setOutputId(ksk.output_key.identifier); + output.asBuilder().setId(ksk.value().identifier); + output.asBuilder().setInputId(ksk.value().input_key.identifier); + output.asBuilder().setOutputId(ksk.value().output_key.identifier); output.asBuilder().getParams().setLevelCount( - ksk.ks_decomposition_parameter.level); + ksk.value().ks_decomposition_parameter.level); output.asBuilder().getParams().setBaseLog( - ksk.ks_decomposition_parameter.log2_base); + ksk.value().ks_decomposition_parameter.log2_base); output.asBuilder().getParams().setIntegerPrecision(64); output.asBuilder().getParams().setInputLweDimension( - ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size); + ksk.value().input_key.glwe_dimension * + ksk.value().input_key.polynomial_size); output.asBuilder().getParams().setOutputLweDimension( - ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size); + ksk.value().output_key.glwe_dimension * + ksk.value().output_key.polynomial_size); output.asBuilder().getParams().setKeyType( concreteprotocol::KeyType::BINARY); - kskBuilder.setWithCaveats(i, output.asReader()); - } - for (int i = 0; i < ckskLen; i++) { + output.asBuilder().getParams().setVariance( + curve->getVariance(1, + ksk.value().output_key.glwe_dimension * + ksk.value().output_key.polynomial_size, + 64)); + kskBuilder.setWithCaveats(ksk.index(), output.asReader()); + } + for (auto ksk : llvm::enumerate(parameters.conversion_keyswitch_keys)) { auto output = Message<concreteprotocol::LweKeyswitchKeyInfo>(); - auto ksk = parameters.conversion_keyswitch_keys[i]; - output.asBuilder().setId(ksk.identifier); - output.asBuilder().setInputId(ksk.input_key.identifier); - output.asBuilder().setOutputId(ksk.output_key.identifier); + output.asBuilder().setId(ksk.value().identifier); + output.asBuilder().setInputId(ksk.value().input_key.identifier); + output.asBuilder().setOutputId(ksk.value().output_key.identifier); output.asBuilder().getParams().setLevelCount( - ksk.ks_decomposition_parameter.level); + ksk.value().ks_decomposition_parameter.level); output.asBuilder().getParams().setBaseLog( - ksk.ks_decomposition_parameter.log2_base); + ksk.value().ks_decomposition_parameter.log2_base); output.asBuilder().getParams().setIntegerPrecision(64); output.asBuilder().getParams().setInputLweDimension( - ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size); + ksk.value().input_key.glwe_dimension * + ksk.value().input_key.polynomial_size); output.asBuilder().getParams().setOutputLweDimension( - ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size); + ksk.value().output_key.glwe_dimension * + ksk.value().output_key.polynomial_size); output.asBuilder().getParams().setKeyType( concreteprotocol::KeyType::BINARY); - kskBuilder.setWithCaveats(i + kskLen, output.asReader()); + output.asBuilder().getParams().setVariance( + curve->getVariance(1, + ksk.value().output_key.glwe_dimension * + ksk.value().output_key.polynomial_size, + 64)); + kskBuilder.setWithCaveats(ksk.index() + kskLen, output.asReader()); } return output; } +Message<concreteprotocol::KeysetInfo> keysetInfoFromVirtualCircuit( + std::vector<concrete_optimizer::utils::PartitionDefinition> partitionDefs, + bool generateFks, std::optional<concrete_optimizer::Options> options) { + + rust::Vec<concrete_optimizer::utils::PartitionDefinition> rustPartitionDefs{}; + for (auto def : partitionDefs) { + rustPartitionDefs.push_back(def); + } + + auto defaultOptions = concrete_optimizer::Options{}; + defaultOptions.security_level = 128; + defaultOptions.maximum_acceptable_error_probability = 0.000063342483999973; + defaultOptions.key_sharing = true; + defaultOptions.ciphertext_modulus_log = 64; + defaultOptions.fft_precision = 53; + + auto opts = options.value_or(defaultOptions); + + auto parameters = concrete_optimizer::utils::generate_virtual_keyset_info( + rustPartitionDefs, generateFks, opts); + + return generateKeysetInfoFromParameters(parameters, opts); +} + } // namespace keysets } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/lib/Common/Security.cpp b/compilers/concrete-compiler/compiler/lib/Common/Security.cpp new file mode 100644 index 0000000000..330b8bfb35 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Common/Security.cpp @@ -0,0 +1,37 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Common/Security.h" +#include <algorithm> +#include <cmath> +#include <cstddef> +#include <vector> + +namespace concretelang { +namespace security { + +double SecurityCurve::getVariance(int glweDimension, int polynomialSize, + int logQ) { + auto size = glweDimension * polynomialSize; + if (size < minimalLweDimension) { + return NAN; + } + auto a = std::pow(2, (slope * size + bias) * 2); + auto b = std::pow(2, -2 * (logQ - 2)); + return a > b ? a : b; +} + +#include "concrete/curves.gen.h" + +SecurityCurve *getSecurityCurve(int bitsOfSecurity, KeyFormat keyFormat) { + for (size_t i = 0; i < curvesLen; i++) { + if (curves[i].bits == bitsOfSecurity && curves[i].keyFormat == keyFormat) + return &curves[i]; + } + return nullptr; +} + +} // namespace security +} // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index 4a2a319000..825a06bd48 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -6,8 +6,8 @@ #include "concretelang/Runtime/simulation.h" #include "concrete-cpu-noise-model.h" #include "concrete-cpu.h" -#include "concrete/curves.h" #include "concretelang/Common/Csprng.h" +#include "concretelang/Common/Security.h" #include "concretelang/Runtime/wrappers.h" #include "concretelang/Support/V0Parameters.h" #include <assert.h> @@ -19,8 +19,9 @@ using concretelang::csprng::SoftCSPRNG; thread_local auto default_csprng = SoftCSPRNG(0); const uint64_t UINT63_MAX = UINT64_MAX >> 1; -inline concrete::SecurityCurve *security_curve() { - return concrete::getSecurityCurve(128, concrete::BINARY); +inline concretelang::security::SecurityCurve *security_curve() { + return concretelang::security::getSecurityCurve( + 128, concretelang::security::BINARY); } uint64_t from_torus(double torus) { diff --git a/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp b/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp index 5511cf5fc2..c2554b9abb 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/ProgramInfoGeneration.cpp @@ -12,8 +12,8 @@ #include "capnp/message.h" #include "concrete-protocol.capnp.h" -#include "concrete/curves.h" #include "concretelang/Common/Protocol.h" +#include "concretelang/Common/Security.h" #include "concretelang/Common/Values.h" #include "concretelang/Conversion/Utils/GlobalFHEContext.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" @@ -41,13 +41,13 @@ using concretelang::protocol::Message; namespace mlir { namespace concretelang { -const auto keyFormat = concrete::BINARY; +const auto keyFormat = ::concretelang::security::BINARY; typedef double Variance; llvm::Expected<Message<concreteprotocol::GateInfo>> generateGate(mlir::Type inputType, const Message<concreteprotocol::EncodingInfo> &inputEncodingInfo, - concrete::SecurityCurve curve, + ::concretelang::security::SecurityCurve curve, concreteprotocol::Compression compression) { auto inputEncoding = inputEncodingInfo.asReader().getEncoding(); @@ -181,7 +181,8 @@ generateGate(mlir::Type inputType, Message<concreteprotocol::KeysetInfo> extractKeysetInfo(TFHE::TFHECircuitKeys circuitKeys, - concrete::SecurityCurve curve, bool compressEvaluationKeys) { + ::concretelang::security::SecurityCurve curve, + bool compressEvaluationKeys) { auto output = Message<concreteprotocol::KeysetInfo>(); @@ -307,7 +308,7 @@ extractKeysetInfo(TFHE::TFHECircuitKeys circuitKeys, llvm::Expected<Message<concreteprotocol::CircuitInfo>> extractCircuitInfo(mlir::func::FuncOp funcOp, concreteprotocol::CircuitEncodingInfo::Reader encodings, - concrete::SecurityCurve curve, + ::concretelang::security::SecurityCurve curve, bool compressInputCiphertexts) { auto output = Message<concreteprotocol::CircuitInfo>(); @@ -348,7 +349,8 @@ extractCircuitInfo(mlir::func::FuncOp funcOp, llvm::Expected<Message<concreteprotocol::ProgramInfo>> extractProgramInfo( mlir::ModuleOp module, const Message<concreteprotocol::ProgramEncodingInfo> &encodings, - concrete::SecurityCurve curve, bool compressInputCiphertexts) { + ::concretelang::security::SecurityCurve curve, + bool compressInputCiphertexts) { auto output = Message<concreteprotocol::ProgramInfo>(); auto circuitsCount = encodings.asReader().getCircuits().size(); @@ -386,7 +388,8 @@ createProgramInfoFromTfheDialect( bool compressEvaluationKeys, bool compressInputCiphertexts) { // Check that security curves exist - const auto curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat); + const auto curve = + ::concretelang::security::getSecurityCurve(bitsOfSecurity, keyFormat); if (curve == nullptr) { return StreamStringError("Cannot find security curves for ") << bitsOfSecurity << "bits"; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index c76d31eab4..836c7ce31b 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -11,13 +11,13 @@ use concrete_optimizer::dag::operator::{ }; use concrete_optimizer::dag::unparametrized; use concrete_optimizer::optimization::config::{Config, SearchSpace}; -use concrete_optimizer::optimization::dag::multi_parameters::generic_generation::generate_generic_parameters; use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::CircuitSolution; use concrete_optimizer::optimization::dag::multi_parameters::optimize::{ KeysetRestriction, MacroParameters, NoSearchSpaceRestriction, RangeRestriction, SearchSpaceRestriction, }; use concrete_optimizer::optimization::dag::multi_parameters::partition_cut::PartitionCut; +use concrete_optimizer::optimization::dag::multi_parameters::virtual_circuit::generate_virtual_parameters; use concrete_optimizer::optimization::dag::multi_parameters::{keys_spec, PartitionIndex}; use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{ Encoding, Solution as DagSolution, @@ -914,18 +914,28 @@ fn location_from_string(string: &str) -> Box<Location> { } } -fn generate_generic_keyset_info( +fn generate_virtual_keyset_info( inputs: Vec<ffi::PartitionDefinition>, generate_fks: bool, + options: &ffi::Options, ) -> ffi::CircuitKeys { - generate_generic_parameters( + let config = Config { + security_level: options.security_level, + maximum_acceptable_error_probability: options.maximum_acceptable_error_probability, + key_sharing: options.key_sharing, + ciphertext_modulus_log: options.ciphertext_modulus_log, + fft_precision: options.fft_precision, + complexity_model: &CpuComplexity::default(), + }; + generate_virtual_parameters( inputs .into_iter() .map( - |ffi::PartitionDefinition { precision, norm2 }| concrete_optimizer::optimization::dag::multi_parameters::generic_generation::PartitionDefinition { precision, norm2 }, + |ffi::PartitionDefinition { precision, norm2 }| concrete_optimizer::optimization::dag::multi_parameters::virtual_circuit::PartitionDefinition { precision, norm2 }, ) .collect(), generate_fks, + config ) .into() } @@ -999,9 +1009,10 @@ mod ffi { fn location_from_string(string: &str) -> Box<Location>; #[namespace = "concrete_optimizer::utils"] - fn generate_generic_keyset_info( + fn generate_virtual_keyset_info( partitions: Vec<PartitionDefinition>, generate_fks: bool, + options: &Options, ) -> CircuitKeys; #[namespace = "concrete_optimizer::utils"] diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index fe34c2444c..9d00ce3875 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -1440,7 +1440,7 @@ ::concrete_optimizer::Location *concrete_optimizer$utils$cxxbridge1$location_unk ::concrete_optimizer::Location *concrete_optimizer$utils$cxxbridge1$location_from_string(::rust::Str string) noexcept; -void concrete_optimizer$utils$cxxbridge1$generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *partitions, bool generate_fks, ::CircuitKeys *return$) noexcept; +void concrete_optimizer$utils$cxxbridge1$generate_virtual_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *partitions, bool generate_fks, ::concrete_optimizer::Options const &options, ::CircuitKeys *return$) noexcept; ::concrete_optimizer::ExternalPartition *concrete_optimizer$utils$cxxbridge1$get_external_partition(::rust::String *name, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t internal_dim, double max_variance, double variance) noexcept; @@ -1584,10 +1584,10 @@ ::rust::Box<::concrete_optimizer::Location> location_from_string(::rust::Str str return ::rust::Box<::concrete_optimizer::Location>::from_raw(concrete_optimizer$utils$cxxbridge1$location_from_string(string)); } -::CircuitKeys generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> partitions, bool generate_fks) noexcept { +::CircuitKeys generate_virtual_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> partitions, bool generate_fks, ::concrete_optimizer::Options const &options) noexcept { ::rust::ManuallyDrop<::rust::Vec<::concrete_optimizer::utils::PartitionDefinition>> partitions$(::std::move(partitions)); ::rust::MaybeUninit<::CircuitKeys> return$; - concrete_optimizer$utils$cxxbridge1$generate_generic_keyset_info(&partitions$.value, generate_fks, &return$.value); + concrete_optimizer$utils$cxxbridge1$generate_virtual_keyset_info(&partitions$.value, generate_fks, options, &return$.value); return ::std::move(return$.value); } diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 662493a2c4..8cc0a87068 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -1396,7 +1396,7 @@ ::rust::Box<::concrete_optimizer::Location> location_unknown() noexcept; ::rust::Box<::concrete_optimizer::Location> location_from_string(::rust::Str string) noexcept; -::CircuitKeys generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> partitions, bool generate_fks) noexcept; +::CircuitKeys generate_virtual_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> partitions, bool generate_fks, ::concrete_optimizer::Options const &options) noexcept; ::rust::Box<::concrete_optimizer::ExternalPartition> get_external_partition(::rust::String name, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t internal_dim, double max_variance, double variance) noexcept; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index 7e88ea7bae..a9ca822f86 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -2,7 +2,6 @@ pub(crate) mod analyze; mod complexity; mod fast_keyswitch; mod feasible; -pub mod generic_generation; pub mod keys_spec; pub mod optimize; pub mod optimize_generic; @@ -11,6 +10,7 @@ mod partitionning; mod partitions; mod union_find; pub(crate) mod variance_constraint; +pub mod virtual_circuit; mod noise_expression; mod symbolic; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/virtual_circuit.rs similarity index 87% rename from compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs rename to compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/virtual_circuit.rs index 0e0c4fc087..a18a908db4 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/virtual_circuit.rs @@ -1,5 +1,4 @@ use crate::{ - computing_cost::cpu::CpuComplexity, config::ProcessingUnit, dag::{ operator::{FunctionTable, LevelledComplexity, Precision, Shape}, @@ -35,10 +34,10 @@ impl PartialOrd for PartitionDefinition { } } -pub fn generate_generic_parameters( - partitions: Vec<PartitionDefinition>, +fn generate_virtual_circuit( + partitions: &[PartitionDefinition], generate_fks: bool, -) -> CircuitKeys { +) -> unparametrized::Dag { let mut dag = unparametrized::Dag::new(); for def_a in partitions.iter() { @@ -99,18 +98,19 @@ pub fn generate_generic_parameters( } } } + dag +} + +pub fn generate_virtual_parameters( + partitions: Vec<PartitionDefinition>, + generate_fks: bool, + config: Config, +) -> CircuitKeys { + let dag = generate_virtual_circuit(partitions.as_slice(), generate_fks); let precisions: Vec<_> = partitions.iter().map(|def| def.precision).collect(); let n_partitions = precisions.len(); let p_cut = PartitionCut::maximal_partitionning(&dag); - let config = Config { - security_level: 128, - maximum_acceptable_error_probability: _4_SIGMA, - key_sharing: true, - ciphertext_modulus_log: 64, - fft_precision: 53, - complexity_model: &CpuComplexity::default(), - }; let search_space = SearchSpace::default_cpu(); let cache = decomposition::cache(128, ProcessingUnit::Cpu, None, true, 64, 53); let parameters = optimize( @@ -144,11 +144,20 @@ pub fn generate_generic_parameters( #[cfg(test)] mod test { - use super::{generate_generic_parameters, PartitionDefinition}; + use super::*; + use crate::computing_cost::cpu::CpuComplexity; #[test] fn test_generate_generic_parameters() { - let _ = generate_generic_parameters( + let config = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: true, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + }; + let a = generate_virtual_parameters( vec![ PartitionDefinition { precision: 3, @@ -164,6 +173,7 @@ mod test { }, ], true, + config, ); } } diff --git a/frontends/concrete-python/tests/compilation/test_restrictions.py b/frontends/concrete-python/tests/compilation/test_restrictions.py index ab42c5a520..4e06ef97f1 100644 --- a/frontends/concrete-python/tests/compilation/test_restrictions.py +++ b/frontends/concrete-python/tests/compilation/test_restrictions.py @@ -108,7 +108,7 @@ def test_generic_restriction(): Test that compiling a module works. """ - generic_keyset_info = KeysetInfo.generate_generic( + generic_keyset_info = KeysetInfo.generate_virtual( [PartitionDefinition(8, 10.0), PartitionDefinition(10, 10000.0)], True )