diff --git a/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py b/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py index 7f9356761f..5a78ca4f4e 100644 --- a/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py +++ b/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py @@ -2,16 +2,13 @@ import time import argparse +import random from functools import lru_cache import numpy from concrete import fhe -# Parameters to be set by the user -max_string_length = 6 - - # Module FHE @fhe.module() class MyModule: @@ -53,6 +50,14 @@ def mix(is_equal, if_equal, case_1, case_2, case_3): ) +def random_int(mapping_to_int): + return numpy.random.randint(len(mapping_to_int)) + + +def random_pick_in_keys(mapping_to_int): + return random.choice(list(mapping_to_int)) + + # For now, we pick only small letters def random_letter_as_int(): return 97 + numpy.random.randint(26) @@ -62,8 +67,8 @@ def random_letter(): return chr(random_letter_as_int()) -def random_string(l): - return "".join([random_letter() for _ in range(l)]) +def random_string(mapping_to_int, l): + return "".join([random_pick_in_keys(mapping_to_int) for _ in range(l)]) def map_string_to_int(s): @@ -150,20 +155,34 @@ def manage_args(): action="store_true", help="Run random tests", ) + parser.add_argument( + "--alphabet", + dest="alphabet", + choices=["string", "STRING", "StRiNg", "ACTG"], + default="string", + help="Setting the alphabet", + ) + parser.add_argument( + "--max_string_length", + dest="max_string_length", + type=int, + default=4, + help="Setting the alphabet", + ) args = parser.parse_args() return args -def compile_module(args): +def compile_module(mapping_to_int, args): # Compilation - inputset_equal = [(random_letter_as_int(), random_letter_as_int()) for _ in range(1000)] + inputset_equal = [(random_int(mapping_to_int), random_int(mapping_to_int)) for _ in range(1000)] inputset_mix = [ ( numpy.random.randint(2), - numpy.random.randint(max_string_length), - numpy.random.randint(max_string_length), - numpy.random.randint(max_string_length), - numpy.random.randint(max_string_length), + numpy.random.randint(args.max_string_length), + numpy.random.randint(args.max_string_length), + numpy.random.randint(args.max_string_length), + numpy.random.randint(args.max_string_length), ) for _ in range(100) ] @@ -180,15 +199,37 @@ def compile_module(args): return my_module -def prepare_random_patterns(): +def prepare_alphabet_mapping(alphabet): + if alphabet == "string": + letters = "".join([chr(97 + i) for i in range(26)]) + elif alphabet == "STRING": + letters = "".join([chr(65 + i) for i in range(26)]) + elif alphabet == "StRiNg": + letters = "".join([chr(97 + i) for i in range(26)] + [chr(65 + i) for i in range(26)]) + elif alphabet == "ACTG": + letters = "ACTG" + else: + raise ValueError(f"Unknown alphabet {alphabet}") + + print(f"Alphabet is {letters}") + + mapping_to_int = {} + + for i, c in enumerate(letters): + mapping_to_int[c] = i + + return mapping_to_int + + +def prepare_random_patterns(mapping_to_int, args): # Random patterns of different lengths list_patterns = [] - for length_1 in range(max_string_length + 1): - for length_2 in range(max_string_length + 1): + for length_1 in range(args.max_string_length + 1): + for length_2 in range(args.max_string_length + 1): list_patterns += [ ( - random_string(length_1), - random_string(length_2), + random_string(mapping_to_int, length_1), + random_string(mapping_to_int, length_2), ) for _ in range(1) ] @@ -205,9 +246,6 @@ def compute_in_simulation(my_module, list_patterns): print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="") - assert len(a) <= max_string_length - assert len(b) <= max_string_length - a_as_int = map_string_to_int(a) b_as_int = map_string_to_int(b) @@ -218,7 +256,7 @@ def compute_in_simulation(my_module, list_patterns): print(" - OK") -def compute_in_fhe(my_module, list_patterns): +def compute_in_fhe(my_module, list_patterns, mapping_to_int): # Key generation my_module.keygen() @@ -229,11 +267,8 @@ def compute_in_fhe(my_module, list_patterns): print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="") - assert len(a) <= max_string_length - assert len(b) <= max_string_length - - a_as_int = map_string_to_int(a) - b_as_int = map_string_to_int(b) + a_as_int = [mapping_to_int[ai] for ai in a] + b_as_int = [mapping_to_int[bi] for bi in b] a_enc = tuple(my_module.equal.encrypt(ai, None)[0] for ai in a_as_int) b_enc = tuple(my_module.equal.encrypt(None, bi)[1] for bi in b_as_int) @@ -257,12 +292,13 @@ def main(): args = manage_args() # Do what the user requested - my_module = compile_module(args) - if args.autotest: - list_patterns = prepare_random_patterns() + print(f"Making random tests with alphabet {args.alphabet}\n") + mapping_to_int = prepare_alphabet_mapping(args.alphabet) + my_module = compile_module(mapping_to_int, args) + list_patterns = prepare_random_patterns(mapping_to_int, args) compute_in_simulation(my_module, list_patterns) - compute_in_fhe(my_module, list_patterns) + compute_in_fhe(my_module, list_patterns, mapping_to_int) if __name__ == "__main__":