Skip to content

Commit

Permalink
chore: add load tfhers params from dict
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Jan 14, 2025
1 parent 57eef1c commit 4454f32
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
18 changes: 18 additions & 0 deletions frontends/concrete-python/concrete/fhe/tfhers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
from math import log2
from typing import Dict

from .bridge import new_bridge
from .dtypes import (
Expand Down Expand Up @@ -41,6 +42,23 @@ def get_type_from_params(
with open(path_to_params_json) as f:
crypto_param_dict = json.load(f)

return get_type_from_params_dict(crypto_param_dict, is_signed, precision)


def get_type_from_params_dict(
crypto_param_dict: Dict, is_signed: bool, precision: int
) -> TFHERSIntegerType:
"""Get a TFHE-rs integer type from TFHE-rs parameters in JSON format.
Args:
crypto_param_dict (Dict): dictionary of TFHE-rs parameters
is_signed (bool): sign of the result type
precision (int): precision of the result type
Returns:
TFHERSIntegerType: constructed type from the loaded parameters
"""

lwe_dim = crypto_param_dict["lwe_dimension"]
glwe_dim = crypto_param_dict["glwe_dimension"]
poly_size = crypto_param_dict["polynomial_size"]
Expand Down
22 changes: 20 additions & 2 deletions frontends/concrete-python/examples/tfhers-ml/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,26 @@
IS_SIGNED = True
#######################################

tfhers_type = tfhers.get_type_from_params(
TFHERS_PARAMS_FILE,
PARAMS_DICT = {
"lwe_dimension": 902,
"glwe_dimension": 1,
"polynomial_size": 4096,
"lwe_noise_distribution": {"Gaussian": {"std": 1.0994794733558207e-6, "mean": 0.0}},
"glwe_noise_distribution": {"Gaussian": {"std": 2.168404344971009e-19, "mean": 0.0}},
"pbs_base_log": 15,
"pbs_level": 2,
"ks_base_log": 3,
"ks_level": 6,
"message_modulus": 4,
"carry_modulus": 8,
"max_noise_level": 10,
"log2_p_fail": -64.084,
"ciphertext_modulus": {"modulus": 0, "scalar_bits": 64},
"encryption_key_choice": "Big",
}

tfhers_type = tfhers.get_type_from_params_dict(
PARAMS_DICT,
is_signed=IS_SIGNED,
precision=FHEUINT_PRECISION,
)
Expand Down
35 changes: 35 additions & 0 deletions frontends/concrete-python/tests/dtypes/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
Tests of `TFHERSIntegerType` data type.
"""

import json
import tempfile

import numpy as np
import pytest

Expand All @@ -18,6 +21,24 @@
tfhers.EncryptionKeyChoice.BIG,
)

DEFAULT_TFHERS_PARAMS_DICT = {
"lwe_dimension": 902,
"glwe_dimension": 1,
"polynomial_size": 4096,
"lwe_noise_distribution": {"Gaussian": {"std": 1.0994794733558207e-6, "mean": 0.0}},
"glwe_noise_distribution": {"Gaussian": {"std": 2.168404344971009e-19, "mean": 0.0}},
"pbs_base_log": 15,
"pbs_level": 2,
"ks_base_log": 3,
"ks_level": 6,
"message_modulus": 4,
"carry_modulus": 8,
"max_noise_level": 10,
"log2_p_fail": -64.084,
"ciphertext_modulus": {"modulus": 0, "scalar_bits": 64},
"encryption_key_choice": "Big",
}


def parameterize_partial_dtype(partial_dtype) -> tfhers.TFHERSIntegerType:
"""Create a tfhers type from a partial func missing tfhers params.
Expand Down Expand Up @@ -150,3 +171,17 @@ def test_tfhers_encryption_variance(crypto_params: tfhers.CryptoParams):
return
assert crypto_params.encryption_key_choice == tfhers.EncryptionKeyChoice.SMALL
assert crypto_params.encryption_variance() == crypto_params.lwe_noise_distribution**2


@pytest.mark.parametrize("params_dict", (DEFAULT_TFHERS_PARAMS_DICT,))
def test_load_tfhers_params_dict(params_dict):
tfhers.get_type_from_params_dict(params_dict, True, 8)


@pytest.mark.parametrize("params_dict", (DEFAULT_TFHERS_PARAMS_DICT,))
def test_load_tfhers_params_file(params_dict):
fpath = tempfile.mktemp()

Check failure

Code scanning / CodeQL

Insecure temporary file High test

Call to deprecated function tempfile.mktemp may be insecure.
with open(fpath, "wt") as f:
f.write(json.dumps(DEFAULT_TFHERS_PARAMS_DICT))

tfhers.get_type_from_params(fpath, True, 8)

0 comments on commit 4454f32

Please sign in to comment.