-
Notifications
You must be signed in to change notification settings - Fork 151
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs(frontend): adding a use-case for Levenshtein distance
closes #zama-ai/concrete-internal#750
- Loading branch information
1 parent
f299adf
commit 122de8c
Showing
1 changed file
with
190 additions
and
0 deletions.
There are no files selected for viewing
190 changes: 190 additions & 0 deletions
190
frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
from concrete import fhe | ||
import numpy | ||
import time | ||
|
||
# User-tunable parameter: longest string (in characters) handled by the example.
# It bounds the test patterns and sizes the input set used to compile `mix`.
max_string_length = 6
|
||
# FHE module: the two encrypted primitives used by the Levenshtein recursion
@fhe.module()
class MyModule:
    # `compare`: encrypted equality test between two encrypted values.
    @fhe.function({"x": "encrypted", "y": "encrypted"})
    def compare(x, y):
        return x == y

    # `mix`: one step of the recursion. Selects `if_equal` when `is_equal` is 1,
    # otherwise 1 + min(case_1, case_2, case_3), without branching on the
    # encrypted flag.
    @fhe.function(
        {
            "is_equal": "encrypted",
            "if_equal": "encrypted",
            "case_1": "encrypted",
            "case_2": "encrypted",
            "case_3": "encrypted",
        }
    )
    def mix(is_equal, if_equal, case_1, case_2, case_3):
        smallest = numpy.minimum(numpy.minimum(case_1, case_2), case_3)

        # FIXME: the following should be faster, but semantic breaks more often
        # return fhe.if_then_else(is_equal, if_equal, 1 + smallest)

        # Arithmetic selection: is_equal is 0 or 1, so exactly one term survives
        return is_equal * if_equal + (1 - is_equal) * (1 + smallest)
|
||
|
||
# For now, only lowercase letters ('a'..'z') are generated
def random_letter_as_int():
    """Return the code point of a randomly drawn lowercase letter (97..122)."""
    offset = numpy.random.randint(26)
    return ord("a") + offset
|
||
|
||
def random_letter():
    """Return a randomly drawn lowercase letter as a one-character string."""
    code_point = random_letter_as_int()
    return chr(code_point)
|
||
|
||
def random_string(l):
    """Return a string made of `l` randomly drawn lowercase letters."""
    letters = (random_letter() for _ in range(l))
    return "".join(letters)
|
||
|
||
def map_string_to_int(s):
    """Return the list of code points of the characters of `s`, in order."""
    return list(map(ord, s))
|
||
|
||
# Compilation

# Representative inputs for `compare`: pairs of lowercase-letter code points
inputset_compare = [(random_letter_as_int(), random_letter_as_int()) for _ in range(1000)]

# Representative inputs for `mix`: one 0/1 equality flag followed by four
# distances. A distance between strings of at most `max_string_length`
# characters lies in 0..max_string_length INCLUSIVE (e.g. empty string versus a
# full-length one), hence the `+ 1`: randint's upper bound is exclusive, and
# leaving the maximum out of the input set can under-size the compiled integers.
inputset_mix = [
    (
        numpy.random.randint(2),
        numpy.random.randint(max_string_length + 1),
        numpy.random.randint(max_string_length + 1),
        numpy.random.randint(max_string_length + 1),
        numpy.random.randint(max_string_length + 1),
    )
    for _ in range(100)
]

my_module = MyModule.compile(
    {"compare": inputset_compare, "mix": inputset_mix}, show_mlir=True, p_error=10**-8
)
|
||
# Function in clear, for reference and comparison
def levenshtein_clear(x, y):
    """Return the Levenshtein distance between the sequences `x` and `y`.

    Plain recursive definition on the first elements; works on any sliceable
    sequences whose elements support `==` (strings, lists of ints, ...).
    """
    # If one sequence is empty, the distance is the length of the other one
    if len(x) == 0 or len(y) == 0:
        return len(x) + len(y)

    x_tail, y_tail = x[1:], y[1:]

    # Same first element: nothing to pay, move on in both sequences
    if x[0] == y[0]:
        return levenshtein_clear(x_tail, y_tail)

    # Otherwise pay one edit and keep the cheapest of the three continuations
    candidates = (
        levenshtein_clear(x_tail, y),
        levenshtein_clear(x, y_tail),
        levenshtein_clear(x_tail, y_tail),
    )
    return 1 + min(candidates)
|
||
|
||
# Function in FHE-simulate, to debug
def levenshtein_simulate(x, y):
    """Return the Levenshtein distance of `x` and `y` using the module functions.

    Same recursion as the clear version, except that the equality test and the
    selection of the result go through `my_module.compare` and `my_module.mix`.
    `x` and `y` are sequences of integers (letters mapped to their code points).

    The distance of each pair of suffixes (x[i:], y[j:]) is computed once and
    cached: the naive recursion revisits the same pairs exponentially often.
    """
    len_x, len_y = len(x), len(y)
    cache = {}

    def distance(i, j):
        # Distance between the suffixes x[i:] and y[j:]
        if i == len_x:
            return len_y - j
        if j == len_y:
            return len_x - i

        key = (i, j)
        if key not in cache:
            if_equal = distance(i + 1, j + 1)
            case_1 = distance(i + 1, j)
            case_2 = distance(i, j + 1)
            case_3 = if_equal

            is_equal = my_module.compare(x[i], y[j])
            cache[key] = my_module.mix(is_equal, if_equal, case_1, case_2, case_3)

        return cache[key]

    return distance(0, 0)
|
||
|
||
# Function in FHE
def levenshtein_fhe(x, y):
    """Return the encrypted Levenshtein distance of two encrypted strings.

    `x` and `y` are sequences of encrypted letters, as produced by
    `my_module.compare.encrypt`. The result is an encrypted value.

    The distance of each pair of suffixes (x[i:], y[j:]) is computed once and
    cached, so `compare.run` and `mix.run` are each called at most
    len(x) * len(y) times instead of exponentially many.
    """
    len_x, len_y = len(x), len(y)
    cache = {}

    def distance(i, j):
        # Encrypted distance between the suffixes x[i:] and y[j:]
        key = (i, j)
        if key in cache:
            return cache[key]

        if i == len_x:
            # In clear, that's return len(y[j:])
            result = my_module.compare.encrypt(len_y - j, None)[0]
        elif j == len_y:
            # In clear, that's return len(x[i:])
            result = my_module.compare.encrypt(len_x - i, None)[0]
        else:
            if_equal = distance(i + 1, j + 1)
            case_1 = distance(i + 1, j)
            case_2 = distance(i, j + 1)
            case_3 = if_equal

            # In FHE
            is_equal = my_module.compare.run(x[i], y[j])
            result = my_module.mix.run(is_equal, if_equal, case_1, case_2, case_3)

        cache[key] = result
        return result

    return distance(0, 0)
|
||
|
||
# Test patterns: a few hand-picked corner cases first...
list_patterns = [("", ""), ("", "a"), ("b", ""), ("a", "a"), ("a", "b")]

# ...then one random pair of strings for every combination of lengths
for length_1 in range(max_string_length + 1):
    for length_2 in range(max_string_length + 1):
        first = random_string(length_1)
        second = random_string(length_2)
        list_patterns.append((first, second))
|
||
# Checks in simulation: the module-based recursion must agree with the clear one
print("Computations in simulation\n")

for a, b in list_patterns:

    print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="")

    # The module was compiled for strings of at most max_string_length letters
    assert max(len(a), len(b)) <= max_string_length

    # Both implementations are fed the letters as integers
    a_as_int, b_as_int = map_string_to_int(a), map_string_to_int(b)

    l1_simulate = levenshtein_simulate(a_as_int, b_as_int)
    l1_clear = levenshtein_clear(a_as_int, b_as_int)

    assert l1_simulate == l1_clear, f" {l1_simulate=} and {l1_clear=} are different"
    print(" - OK")
|
||
# Key generation: FIXME, to be done calling a keygen function, see https://github.com/zama-ai/concrete-internal/issues/751
# Done here, ahead of the loop, so that it is not part of the timings below.
my_module.compare.encrypt(None, None)

# Checks in FHE: encrypt the letters, run the recursion homomorphically,
# decrypt the result and compare it with the clear computation
print("\nComputations in FHE\n")

for a, b in list_patterns:

    print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="")

    # The module was compiled for strings of at most max_string_length letters
    assert len(a) <= max_string_length
    assert len(b) <= max_string_length

    a_as_int = map_string_to_int(a)
    b_as_int = map_string_to_int(b)

    # Letters of `a` are encrypted as the first input of `compare`, letters of
    # `b` as its second input
    a_enc = [my_module.compare.encrypt(ai, None)[0] for ai in a_as_int]
    b_enc = [my_module.compare.encrypt(None, bi)[1] for bi in b_as_int]

    # perf_counter is monotonic and high-resolution; time.time() is a wall
    # clock that may jump while the (long) FHE computation is running
    time_begin = time.perf_counter()
    l1_fhe_enc = levenshtein_fhe(a_enc, b_enc)
    time_end = time.perf_counter()

    l1_fhe = my_module.mix.decrypt(l1_fhe_enc)

    l1_clear = levenshtein_clear(a, b)

    assert l1_fhe == l1_clear, f" {l1_fhe=} and {l1_clear=} are different"
    print(f" - OK in {time_end - time_begin:.2f} seconds")