From 514fe62364f77f03398b03daea728cfb9fc96aca Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 24 Oct 2024 11:43:04 +0100 Subject: [PATCH] refactor(frontend/compiler): single API for import/export of TFHErs int --- .../concretelang/ClientLib/ClientLib.h | 19 +-- .../lib/Bindings/Python/CompilerAPIModule.cpp | 35 +--- .../Python/concrete/compiler/tfhers_int.py | 70 +------- .../compiler/lib/ClientLib/ClientLib.cpp | 153 +++++++----------- .../concrete/fhe/tfhers/bridge.py | 34 +--- 5 files changed, 79 insertions(+), 232 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h index 99f39a248a..1aacf41efa 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h @@ -34,18 +34,13 @@ using concretelang::values::Value; namespace concretelang { namespace clientlib { -Result -importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance); -Result> exportTfhersFheUint8(TransportValue value, - TfhersFheIntDescription info); -Result -importTfhersFheInt8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance); -Result> exportTfhersFheInt8(TransportValue value, - TfhersFheIntDescription info); +Result importTfhersInteger(llvm::ArrayRef buffer, + TfhersFheIntDescription integerDesc, + uint32_t encryptionKeyId, + double encryptionVariance); + +Result> +exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc); class ClientCircuit { diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 9e90e1f2d6..bbd12d4067 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1881,14 +1881,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( "Return the `circuit` ClientCircuit.", arg("circuit")) .doc() = "Client-side / Encryption program"; - m.def("import_tfhers_fheuint8", + m.def("import_tfhers_int", [](const pybind11::bytes &serialized_fheuint, TfhersFheIntDescription info, uint32_t encryptionKeyId, double encryptionVariance) { const std::string &buffer_str = serialized_fheuint; std::vector buffer(buffer_str.begin(), buffer_str.end()); auto arrayRef = llvm::ArrayRef(buffer); - auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8( + auto valueOrError = ::concretelang::clientlib::importTfhersInteger( arrayRef, info, encryptionKeyId, encryptionVariance); if (valueOrError.has_error()) { throw std::runtime_error(valueOrError.error().mesg); @@ -1896,34 +1896,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return TransportValue{valueOrError.value()}; }); - m.def("export_tfhers_fheuint8", - [](TransportValue fheuint, TfhersFheIntDescription info) { - auto result = - ::concretelang::clientlib::exportTfhersFheUint8(fheuint, info); - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); - } - return result.value(); - }); - - m.def("import_tfhers_fheint8", - [](const pybind11::bytes &serialized_fheuint, - TfhersFheIntDescription info, uint32_t encryptionKeyId, - double encryptionVariance) { - const std::string &buffer_str = serialized_fheuint; - std::vector buffer(buffer_str.begin(), buffer_str.end()); - auto arrayRef = llvm::ArrayRef(buffer); - auto valueOrError = ::concretelang::clientlib::importTfhersFheInt8( - arrayRef, info, encryptionKeyId, encryptionVariance); - if (valueOrError.has_error()) { - throw std::runtime_error(valueOrError.error().mesg); - } - return TransportValue{valueOrError.value()}; - }); - - m.def("export_tfhers_fheint8", [](TransportValue fheuint, - TfhersFheIntDescription info) { - auto result = ::concretelang::clientlib::exportTfhersFheInt8(fheuint, info); + m.def("export_tfhers_int", [](TransportValue fheuint, + TfhersFheIntDescription info) { + auto result = ::concretelang::clientlib::exportTfhersInteger(fheuint, info); if (result.has_error()) { throw std::runtime_error(result.error().mesg); } diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py index 89cfd38c9b..9491743c2e 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py @@ -3,10 +3,8 @@ # pylint: disable=no-name-in-module,import-error, from mlir._mlir_libs._concretelang._compiler import ( - import_tfhers_fheuint8 as _import_tfhers_fheuint8, - export_tfhers_fheuint8 as _export_tfhers_fheuint8, - import_tfhers_fheint8 as _import_tfhers_fheint8, - export_tfhers_fheint8 as _export_tfhers_fheint8, + import_tfhers_int as _import_tfhers_int, + export_tfhers_int as _export_tfhers_int, TfhersFheIntDescription as _TfhersFheIntDescription, TransportValue, ) @@ -184,7 +182,7 @@ class TfhersExporter: """A helper class to import and export TFHErs big integers.""" @staticmethod - def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes: + def export_int(value: TransportValue, info: TfhersFheIntDescription) -> bytes: """Convert Concrete value to TFHErs and serialize it. Args: @@ -195,7 +193,7 @@ def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> byt TypeError: if wrong input types Returns: - bytes: converted and serialized fheuint8 + bytes: converted and serialized TFHErs integer """ if not isinstance(value, TransportValue): raise TypeError(f"value must be of type TransportValue, not {type(value)}") @@ -203,16 +201,16 @@ def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> byt raise TypeError( f"info must be of type TfhersFheIntDescription, not {type(info)}" ) - return bytes(_export_tfhers_fheuint8(value, info.cpp())) + return bytes(_export_tfhers_int(value, info.cpp())) @staticmethod - def import_fheuint8( + def import_int( buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float ) -> TransportValue: """Unserialize and convert from TFHErs to Concrete value. Args: - buffer (bytes): serialized fheuint8 + buffer (bytes): serialized TFHErs integer info (TfhersFheIntDescription): description of the TFHErs integer to import keyid (int): id of the key used for encryption variance (float): variance used for encryption @@ -233,56 +231,4 @@ def import_fheuint8( raise TypeError(f"keyid must be of type int, not {type(keyid)}") if not isinstance(variance, float): raise TypeError(f"variance must be of type float, not {type(variance)}") - return _import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance) - - @staticmethod - def export_fheint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes: - """Convert Concrete value to TFHErs and serialize it. - - Args: - value (Value): value to export - info (TfhersFheIntDescription): description of the TFHErs integer to export to - - Raises: - TypeError: if wrong input types - - Returns: - bytes: converted and serialized fheuint8 - """ - if not isinstance(value, TransportValue): - raise TypeError(f"value must be of type Value, not {type(value)}") - if not isinstance(info, TfhersFheIntDescription): - raise TypeError( - f"info must be of type TfhersFheIntDescription, not {type(info)}" - ) - return bytes(_export_tfhers_fheint8(value, info.cpp())) - - @staticmethod - def import_fheint8( - buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float - ) -> TransportValue: - """Unserialize and convert from TFHErs to Concrete value. - - Args: - buffer (bytes): serialized fheuint8 - info (TfhersFheIntDescription): description of the TFHErs integer to import - keyid (int): id of the key used for encryption - variance (float): variance used for encryption - - Raises: - TypeError: if wrong input types - - Returns: - Value: unserialized and converted value - """ - if not isinstance(buffer, bytes): - raise TypeError(f"buffer must be of type bytes, not {type(buffer)}") - if not isinstance(info, TfhersFheIntDescription): - raise TypeError( - f"info must be of type TfhersFheIntDescription, not {type(info)}" - ) - if not isinstance(keyid, int): - raise TypeError(f"keyid must be of type int, not {type(keyid)}") - if not isinstance(variance, float): - raise TypeError(f"variance must be of type float, not {type(variance)}") - return _import_tfhers_fheint8(buffer, info.cpp(), keyid, variance) + return _import_tfhers_int(buffer, info.cpp(), keyid, variance) diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp index 1ed0493a78..3fee8e9586 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp @@ -187,92 +187,34 @@ Result ClientProgram::getClientCircuit(std::string circuitName) { "`"); } -Result -importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance) { - if (desc.width != 8 || desc.is_signed == true) { - return StringError( - "trying to import FheUint8 but description doesn't match this type"); - } - - auto dims = std::vector({desc.n_cts, desc.lwe_size}); - auto outputTensor = Tensor::fromDimensions(dims); - auto err = concrete_cpu_tfhers_uint8_to_lwe_array( - serializedFheUint8.data(), serializedFheUint8.size(), - outputTensor.values.data(), desc); - if (err) { - return StringError("couldn't convert fheuint to lwe array: err()") - << err << ")"; - } - - auto value = Value{outputTensor}.intoRawTransportValue(); - auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext(); - lwe.setIntegerPrecision(64); - // dimensions - lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts}); - lwe.initConcreteShape().setDimensions( - {(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size}); - // encryption - auto encryption = lwe.initEncryption(); - encryption.setLweDimension((uint32_t)desc.lwe_size - 1); - encryption.initModulus().initMod().initNative(); - encryption.setKeyId(encryptionKeyId); - encryption.setVariance(encryptionVariance); - // Encoding - auto encoding = lwe.initEncoding(); - auto integer = encoding.initInteger(); - integer.setIsSigned(false); - integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus)); - integer.initMode().initNative(); - - return value; -} - -Result> -exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) { - if (desc.width != 8 || desc.is_signed == true) { - return StringError( - "trying to export FheUint8 but description doesn't match this type"); - } - - auto fheuint = Value::fromRawTransportValue(value); - if (fheuint.isScalar()) { - return StringError("expected a tensor, but value is a scalar"); - } - auto tensorOrError = fheuint.getTensor(); - if (!tensorOrError.has_value()) { - return StringError("couldn't get tensor from value"); - } - const size_t bufferSize = - concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts); - std::vector buffer(bufferSize, 0); - auto flatData = tensorOrError.value().values; - auto size = concrete_cpu_lwe_array_to_tfhers_uint8( - flatData.data(), buffer.data(), buffer.size(), desc); - if (size == 0) { - return StringError("couldn't convert lwe array to fheuint8"); - } - // we truncate to the serialized data - assert(size <= buffer.size()); - buffer.resize(size, 0); - return buffer; -} - -Result -importTfhersFheInt8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance) { - if (desc.width != 8 || desc.is_signed == false) { - return StringError( - "trying to import FheInt8 but description doesn't match this type"); +Result importTfhersInteger(llvm::ArrayRef buffer, + TfhersFheIntDescription integerDesc, + uint32_t encryptionKeyId, + double encryptionVariance) { + + // Select conversion function based on integer description + std::function + conversion_func; + if (integerDesc.width == 8) { + if (integerDesc.is_signed) { // fheint8 + conversion_func = concrete_cpu_tfhers_int8_to_lwe_array; + } else { // fheuint8 + conversion_func = concrete_cpu_tfhers_uint8_to_lwe_array; + } + } else { + std::ostringstream stringStream; + stringStream << "importTfhersInteger: no support for " << integerDesc.width + << "bits " << (integerDesc.is_signed ? "signed" : "unsigned") + << " integer"; + std::string errorMsg = stringStream.str(); + return StringError(errorMsg); } - auto dims = std::vector({desc.n_cts, desc.lwe_size}); + auto dims = std::vector({integerDesc.n_cts, integerDesc.lwe_size}); auto outputTensor = Tensor::fromDimensions(dims); - auto err = concrete_cpu_tfhers_int8_to_lwe_array( - serializedFheUint8.data(), serializedFheUint8.size(), - outputTensor.values.data(), desc); + auto err = conversion_func(buffer.data(), buffer.size(), + outputTensor.values.data(), integerDesc); if (err) { return StringError("couldn't convert fheint to lwe array"); } @@ -281,30 +223,47 @@ importTfhersFheInt8(llvm::ArrayRef serializedFheUint8, auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext(); lwe.setIntegerPrecision(64); // dimensions - lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts}); + lwe.initAbstractShape().setDimensions({(uint32_t)integerDesc.n_cts}); lwe.initConcreteShape().setDimensions( - {(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size}); + {(uint32_t)integerDesc.n_cts, (uint32_t)integerDesc.lwe_size}); // encryption auto encryption = lwe.initEncryption(); - encryption.setLweDimension((uint32_t)desc.lwe_size - 1); + encryption.setLweDimension((uint32_t)integerDesc.lwe_size - 1); encryption.initModulus().initMod().initNative(); encryption.setKeyId(encryptionKeyId); encryption.setVariance(encryptionVariance); // Encoding auto encoding = lwe.initEncoding(); auto integer = encoding.initInteger(); - integer.setIsSigned(false); - integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus)); + integer.setIsSigned( + false); // should always be unsigned as its for the radix encoded cts + integer.setWidth( + std::log2(integerDesc.message_modulus * integerDesc.carry_modulus)); integer.initMode().initNative(); return value; } -Result> exportTfhersFheInt8(TransportValue value, - TfhersFheIntDescription desc) { - if (desc.width != 8 || desc.is_signed == false) { - return StringError( - "trying to export FheInt8 but description doesn't match this type"); +Result> +exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc) { + // Select conversion function based on integer description + std::function + conversion_func; + std::function buffer_size_func; + if (integerDesc.width == 8) { + if (integerDesc.is_signed) { // fheint8 + conversion_func = concrete_cpu_lwe_array_to_tfhers_int8; + } else { // fheuint8 + conversion_func = concrete_cpu_lwe_array_to_tfhers_uint8; + } + } else { + std::ostringstream stringStream; + stringStream << "exportTfhersInteger: no support for " << integerDesc.width + << "bits " << (integerDesc.is_signed ? "signed" : "unsigned") + << " integer"; + std::string errorMsg = stringStream.str(); + return StringError(errorMsg); } auto fheuint = Value::fromRawTransportValue(value); @@ -315,12 +274,12 @@ Result> exportTfhersFheInt8(TransportValue value, if (!tensorOrError.has_value()) { return StringError("couldn't get tensor from value"); } - size_t buffer_size = - concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts); + size_t buffer_size = concrete_cpu_tfhers_fheint_buffer_size_u64( + integerDesc.lwe_size, integerDesc.n_cts); std::vector buffer(buffer_size, 0); auto flat_data = tensorOrError.value().values; - auto size = concrete_cpu_lwe_array_to_tfhers_int8( - flat_data.data(), buffer.data(), buffer.size(), desc); + auto size = conversion_func(flat_data.data(), buffer.data(), buffer.size(), + integerDesc); if (size == 0) { return StringError("couldn't convert lwe array to fheint8"); } diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index 27a5811b05..de08b36842 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -118,22 +118,9 @@ def import_value(self, buffer: bytes, input_idx: int) -> Value: raise ValueError(msg) fheint_desc = self._description_from_type(input_type) - - bit_width = input_type.bit_width - signed = input_type.is_signed keyid = self._input_keyid(input_idx) variance = self._input_variance(input_idx) - if bit_width == 8: - if not signed: - return Value(TfhersExporter.import_fheuint8(buffer, fheint_desc, keyid, variance)) - else: - return Value(TfhersExporter.import_fheint8(buffer, fheint_desc, keyid, variance)) - - msg = ( # pragma: no cover - f"importing {'signed' if signed else 'unsigned'} integers of {bit_width}bits is not" - " yet supported" - ) - raise NotImplementedError(msg) # pragma: no cover + return Value(TfhersExporter.import_int(buffer, fheint_desc, keyid, variance)) def export_value(self, value: Value, output_idx: int) -> bytes: """Export a value as a serialized TFHErs integer. @@ -151,24 +138,9 @@ def export_value(self, value: Value, output_idx: int) -> bytes: raise ValueError(msg) fheint_desc = self._description_from_type(output_type) - - bit_width = output_type.bit_width - signed = output_type.is_signed - if bit_width == 8: - if not signed: - return TfhersExporter.export_fheuint8( - value._inner, fheint_desc # pylint: disable=protected-access - ) - else: - return TfhersExporter.export_fheint8( - value._inner, fheint_desc # pylint: disable=protected-access - ) - - msg = ( # pragma: no cover - f"exporting value to {'signed' if signed else 'unsigned'} integers of {bit_width}bits" - " is not yet supported" + return TfhersExporter.export_int( + value._inner, fheint_desc # pylint: disable=protected-access ) - raise NotImplementedError(msg) # pragma: no cover def serialize_input_secret_key(self, input_idx: int) -> bytes: """Serialize secret key used for a specific input.