Skip to content

Commit

Permalink
docs(frontend): reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
bcm-at-zama committed Jul 9, 2024
1 parent 0c2fc5f commit def0699
Showing 1 changed file with 106 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,79 +12,67 @@

class Alphabet:

letters = None
mapping_to_int = {}
module = None
letters: str = None
mapping_to_int: dict = {}

def set_lowercase(self):
@staticmethod
def lowercase():
"""Set lower case alphabet."""
self.letters = "".join([chr(97 + i) for i in range(26)])
return Alphabet("abcdefghijklmnopqrstuvwxyz")

def set_uppercase(self):
@staticmethod
def uppercase():
"""Set upper case alphabet."""
self.letters = "".join([chr(65 + i) for i in range(26)])
return Alphabet("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

def set_anycase(self):
@staticmethod
def anycase():
"""Set any-case alphabet."""
self.letters = "".join([chr(97 + i) for i in range(26)] + [chr(65 + i) for i in range(26)])
return Alphabet.lowercase() + Alphabet.uppercase()

def set_dna(self):
@staticmethod
def dna():
"""Set DNA alphabet."""
self.letters = "ACTG"
return Alphabet("ATGC")

def return_available_alphabets():
def __init__(self, letters: str):
self.letters = letters

for i, c in enumerate(self.letters):
self.mapping_to_int[c] = i

def __add__(self, other: "Alphabet") -> "Alphabet":
return Alphabet(self.letters + other.letters)

def return_available_alphabets() -> list:
"""Return available alphabets."""
return ["string", "STRING", "StRiNg", "ACTG"]

def check_alphabet(self, alphabet_name):
"""Check an alphabet is available."""
@staticmethod
def init_by_name(alphabet_name: str) -> "Alphabet":
"""Set the alphabet."""
assert (
alphabet_name in Alphabet.return_available_alphabets()
), f"Unknown alphabet {alphabet_name}"

def set_alphabet(self, alphabet_name, verbose=True):
"""Set the alphabet."""
self.check_alphabet(alphabet_name)

if alphabet_name == "string":
self.set_lowercase()
return Alphabet.lowercase()
if alphabet_name == "STRING":
self.set_uppercase()
return Alphabet.uppercase()
if alphabet_name == "StRiNg":
self.set_anycase()
return Alphabet.anycase()
if alphabet_name == "ACTG":
self.set_dna()

if verbose:
print(f"Making random tests with alphabet {alphabet_name}")
print(f"Letters are {self.letters}\n")
return Alphabet.dna()

for i, c in enumerate(self.letters):
self.mapping_to_int[c] = i

def check_string_is_in_alphabet(self, string):
"""Check a string is a valid string of an alphabet."""
assert len(self.mapping_to_int) > 0, "Mapping not defined"

for c in string:
if c not in self.mapping_to_int:
raise ValueError(
f"Char {c} of {string} is not in alphabet {list(self.mapping_to_int.keys())}, please choose the right --alphabet"
)

def _random_pick_in_values(self):
def random_pick_in_values(self) -> int:
"""Pick the integer-encoding of a random char in an alphabet."""
return numpy.random.randint(len(self.mapping_to_int))

def _random_pick_in_keys(self):
"""Pick a random char in an alphabet."""
return random.choice(list(self.mapping_to_int))

def _random_string(self, l):
def _random_string(self, length: int) -> str:
"""Pick a random string in the alphabet."""
return "".join([self._random_pick_in_keys() for _ in range(l)])
return "".join([random.choice(list(self.mapping_to_int)) for _ in range(length)])

def prepare_random_patterns(self, len_min, len_max, nb_strings):
def prepare_random_patterns(self, len_min: int, len_max: int, nb_strings: int) -> list:
"""Prepare random patterns of different lengths."""
assert len(self.mapping_to_int) > 0, "Mapping not defined"

Expand All @@ -102,28 +90,60 @@ def prepare_random_patterns(self, len_min, len_max, nb_strings):

return list_patterns

def encode_string(self, string):
def encode(self, string: str) -> tuple:
"""Encode a string, ie map it to integers using the alphabet."""

assert len(self.mapping_to_int) > 0, "Mapping not defined"

for si in string:
if si not in self.mapping_to_int:
raise ValueError(
f"Char {si} of {string} is not in alphabet {list(self.mapping_to_int.keys())}, please choose the right --alphabet"
)

return tuple([self.mapping_to_int[si] for si in string])

def encode_and_encrypt_strings(self, a, b):

class LevenshteinDistance:
alphabet: Alphabet
module: fhe.module

def __init__(self, alphabet: Alphabet, args):
self.alphabet = alphabet

self._compile_module(args)

def calculate(self, a: str, b: str, mode: str, show_distance: bool = False):
"""Compute a distance between two strings, either in fhe or in simulate."""
if mode == "simulate":
self._compute_in_simulation([(a, b)])
else:
assert mode == "fhe", "Only 'simulate' and 'fhe' mode are available"
self._compute_in_fhe([(a, b)], show_distance=show_distance)

def calculate_list(self, l: list, mode: str):
"""Compute a distance between strings of a list, either in fhe or in simulate."""
for (a, b) in l:
self.calculate(a, b, mode)

def _encode_and_encrypt_strings(self, a: str, b: str) -> tuple:
"""Encode a string, ie map it to integers using the alphabet, and then encrypt the integers."""
a_as_int = self.encode_string(a)
b_as_int = self.encode_string(b)
a_as_int = self.alphabet.encode(a)
b_as_int = self.alphabet.encode(b)

a_enc = tuple(self.module.equal.encrypt(ai, None)[0] for ai in a_as_int)
b_enc = tuple(self.module.equal.encrypt(None, bi)[1] for bi in b_as_int)

return a_enc, b_enc

def compile_module(self, args):
def _compile_module(self, args):
"""Compile the FHE module."""
assert len(self.mapping_to_int) > 0, "Mapping not defined"
assert len(self.alphabet.mapping_to_int) > 0, "Mapping not defined"

inputset_equal = [
(
self._random_pick_in_values(),
self._random_pick_in_values(),
self.alphabet.random_pick_in_values(),
self.alphabet.random_pick_in_values(),
)
for _ in range(1000)
]
Expand All @@ -135,14 +155,14 @@ def compile_module(self, args):
numpy.random.randint(args.max_string_length),
numpy.random.randint(args.max_string_length),
)
for _ in range(100)
for _ in range(1000)
]

self.module = LevenshsteinModule.compile(
{
"equal": inputset_equal,
"mix": inputset_mix,
"constant": [i for i in range(len(self.mapping_to_int))],
"constant": [i for i in range(len(self.alphabet.mapping_to_int))],
},
show_mlir=args.show_mlir,
p_error=10**-20,
Expand All @@ -151,36 +171,31 @@ def compile_module(self, args):
min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
)

def compute_in_simulation(self, list_patterns):
def _compute_in_simulation(self, list_patterns: list):
"""Check equality between distance in simulation and clear distance."""
print("Computations in simulation\n")

for a, b in list_patterns:

print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="")

a_as_int = self.encode_string(a)
b_as_int = self.encode_string(b)
a_as_int = self.alphabet.encode(a)
b_as_int = self.alphabet.encode(b)

l1_simulate = levenshtein_simulate(self.module, a_as_int, b_as_int)
l1_clear = levenshtein_clear(a_as_int, b_as_int)

assert l1_simulate == l1_clear, f" {l1_simulate=} and {l1_clear=} are different"
print(" - OK")

def compute_in_fhe(self, list_patterns, verbose=True, show_distance=False):
def _compute_in_fhe(self, list_patterns: list, show_distance: bool = False):
"""Check equality between distance in FHE and clear distance."""
self.module.keygen()

# Checks in FHE
if verbose:
print("\nComputations in FHE\n")

for a, b in list_patterns:

print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="")

a_enc, b_enc = self.encode_and_encrypt_strings(a, b)
a_enc, b_enc = self._encode_and_encrypt_strings(a, b)

time_begin = time.time()
l1_fhe_enc = levenshtein_fhe(self.module, a_enc, b_enc)
Expand All @@ -202,12 +217,12 @@ def compute_in_fhe(self, list_patterns, verbose=True, show_distance=False):
@fhe.module()
class LevenshsteinModule:
@fhe.function({"x": "encrypted", "y": "encrypted"})
def equal(x, y):
def equal(x: int, y: int):
"""Assert equality between two chars of the alphabet."""
return x == y

@fhe.function({"x": "clear"})
def constant(x):
def constant(x: int):
return fhe.zero() + x

@fhe.function(
Expand All @@ -219,7 +234,7 @@ def constant(x):
"case_3": "encrypted",
}
)
def mix(is_equal, if_equal, case_1, case_2, case_3):
def mix(is_equal: bool, if_equal: int, case_1: int, case_2: int, case_3: int):
"""Compute the min of (case_1, case_2, case_3), and then return `if_equal` if `is_equal` is
True, or the min in the other case."""
min_12 = numpy.minimum(case_1, case_2)
Expand Down Expand Up @@ -251,7 +266,7 @@ def mix(is_equal, if_equal, case_1, case_2, case_3):


@lru_cache
def levenshtein_clear(x, y):
def levenshtein_clear(x: str, y: str):
"""Compute the distance in clear, for reference and comparison."""
if len(x) == 0:
return len(y)
Expand All @@ -269,7 +284,7 @@ def levenshtein_clear(x, y):


@lru_cache
def levenshtein_simulate(module, x, y):
def levenshtein_simulate(module: fhe.module, x: str, y: str):
"""Compute the distance in simulation."""
if len(x) == 0:
return len(y)
Expand All @@ -288,7 +303,7 @@ def levenshtein_simulate(module, x, y):


@lru_cache
def levenshtein_fhe(module, x, y):
def levenshtein_fhe(module: fhe.module, x: str, y: str):
"""Compute the distance in FHE."""
if len(x) == 0:
return module.constant.run(module.constant.encrypt(len(y)))
Expand Down Expand Up @@ -374,30 +389,32 @@ def main():

# Do what the user requested
if args.autotest:
alphabet = Alphabet()
alphabet.set_alphabet(args.alphabet)

alphabet.compile_module(args)
alphabet = Alphabet.init_by_name(args.alphabet)
levenshtein_distance = LevenshteinDistance(alphabet, args)

print(f"Making random tests with alphabet {args.alphabet}")
print(f"Letters are {alphabet.letters}\n")

list_patterns = alphabet.prepare_random_patterns(0, args.max_string_length, 1)
alphabet.compute_in_simulation(list_patterns)
alphabet.compute_in_fhe(list_patterns)
print("Computations in simulation\n")
levenshtein_distance.calculate_list(list_patterns, mode="simulate")
print("\nComputations in FHE\n")
levenshtein_distance.calculate_list(list_patterns, mode="fhe")
print("")

if args.autoperf:
alphabet = Alphabet()

for alphabet_name in ["ACTG", "string", "STRING", "StRiNg"]:
print(
f"Typical performances for alphabet {alphabet_name}, with string of maximal length:\n"
)

alphabet.set_alphabet(alphabet_name, verbose=False)

alphabet.compile_module(args)
alphabet = Alphabet.init_by_name(alphabet_name)
levenshtein_distance = LevenshteinDistance(alphabet, args)
list_patterns = alphabet.prepare_random_patterns(
args.max_string_length, args.max_string_length, 3
)
alphabet.compute_in_fhe(list_patterns, verbose=False)
levenshtein_distance.calculate_list(list_patterns, mode="fhe")
print("")

if args.distance != None:
Expand All @@ -411,14 +428,11 @@ def main():
"Warning, --max_string_length was smaller than lengths of the input strings, fixing it"
)

alphabet = Alphabet()
alphabet.set_alphabet(args.alphabet, verbose=False)

alphabet.compile_module(args)
alphabet.check_string_is_in_alphabet(args.distance[0])
alphabet.check_string_is_in_alphabet(args.distance[1])
list_patterns = [args.distance]
alphabet.compute_in_fhe(list_patterns, verbose=False, show_distance=True)
alphabet = Alphabet.init_by_name(args.alphabet)
levenshtein_distance = LevenshteinDistance(alphabet, args)
levenshtein_distance.calculate(
args.distance[0], args.distance[1], mode="fhe", show_distance=True
)
print("")

print("Successful end\n")
Expand Down

0 comments on commit def0699

Please sign in to comment.