From 380a523f4765600eb2b7042394d99e4eb092b02d Mon Sep 17 00:00:00 2001 From: Bourgerie Quentin Date: Fri, 5 Jul 2024 17:02:01 +0200 Subject: [PATCH] docs(frontend): Proposal with constant --- .../levenshtein_distance/levenshtein_distance.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py b/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py index 59f01a4426..d62cc94e63 100644 --- a/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py +++ b/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py @@ -56,6 +56,10 @@ def mix(is_equal, if_equal, case_1, case_2, case_3): return fhe.if_then_else(is_equal, if_equal, 1 + min_123) + @fhe.function({"x": "clear"}) + def constant(x): + return fhe.zero() + x + # There is a single output in mix: it can go to # - input 1 of mix # - input 2 of mix @@ -71,6 +75,10 @@ def mix(is_equal, if_equal, case_1, case_2, case_3): fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 2)), fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 3)), fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 4)), + fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 1)), + fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 2)), + fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 3)), + fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 4)), ] ) @@ -116,9 +124,9 @@ def levenshtein_simulate(my_module, x, y): def levenshtein_fhe(my_module, x, y): """Compute the distance in FHE.""" if len(x) == 0: - return my_module.mix.encrypt(None, len(y), None, None, None)[1] + return my_module.constant.run(my_module.constant.encrypt(len(y))) if len(y) == 0: - return my_module.mix.encrypt(None, len(x), None, None, None)[1] + return my_module.constant.run(my_module.constant.encrypt(len(x))) if_equal = levenshtein_fhe(my_module, x[1:], y[1:]) case_1 = levenshtein_fhe(my_module, x[1:], y) @@ -208,7 +216,7 @@ def compile_module(mapping_to_int, args): ] my_module = LevenshsteinModule.compile( - {"equal": inputset_equal, "mix": inputset_mix}, + {"equal": inputset_equal, "mix": inputset_mix, "constant": [i for i in range(len(mapping_to_int))]}, show_mlir=args.show_mlir, p_error=10**-20, show_optimizer=args.show_optimizer,