diff --git a/frontends/concrete-python/concrete/fhe/extensions/bits.py b/frontends/concrete-python/concrete/fhe/extensions/bits.py index 5b7581d0f5..b48cb7a87b 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/bits.py +++ b/frontends/concrete-python/concrete/fhe/extensions/bits.py @@ -105,7 +105,7 @@ def __getitem__(self, index: Union[int, np.integer, slice]) -> Tracer: def evaluator(x, bits): # pylint: disable=redefined-outer-name if isinstance(bits, (int, np.integer)): - return (x & (1 << bits)) >> bits + return (x >> bits) & 1 assert isinstance(bits, slice) @@ -126,16 +126,24 @@ def evaluator(x, bits): # pylint: disable=redefined-outer-name result = 0 for i, bit in enumerate(range(start, stop, step)): - value = (x & (1 << bit)) >> bit + value = (x >> bit) & 1 result += value << i return result if isinstance(self.value, Tracer): + output_value = deepcopy(self.value.output) + direct_single_bit = ( + Tracer._is_direct + and isinstance(index, int) + and isinstance(output_value.dtype, Integer) + ) + if direct_single_bit: + output_value.dtype.bit_width = 1 # type: ignore[attr-defined] computation = Node.generic( "extract_bit_pattern", [deepcopy(self.value.output)], - deepcopy(self.value.output), + output_value, evaluator, kwargs={"bits": index}, ) diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/README.md b/frontends/concrete-python/concrete/fhe/extensions/synthesis/README.md new file mode 100644 index 0000000000..bc806d6847 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/README.md @@ -0,0 +1 @@ +http://yowasp.org/ diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/__init__.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/__init__.py new file mode 100644 index 0000000000..2d1f3403c3 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/__init__.py @@ -0,0 +1,2 @@ +"""Provide synthesis main entry points.""" +from .api import lut, verilog_expression, verilog_module diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/api.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/api.py new file mode 100644 index 0000000000..f6cdc30c84 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/api.py @@ -0,0 +1,125 @@ +""" +EXPERIMENTAL extension to synthesize a fhe compatible function from verilog code. + +The resulting object can be used directly as a python function. +For instance you can write a relu function using: + + out = fhe.int5 + params = {"a": out} + expression = "(a >= 0) ? a : 0" + + relu = synth.verilog_expression(params, expression, out) + + @fhe.circuit({"a": "encrypted"}) + def circuit(a: out): + return relu(a=a) + +""" +from collections import Counter + +from concrete.fhe.extensions.synthesis.luts_to_fhe import tlu_circuit_to_fhe +from concrete.fhe.extensions.synthesis.luts_to_graph import to_graph +from concrete.fhe.extensions.synthesis.verilog_source import ( + Ty, + verilog_from_expression, + verilog_from_tlu, +) +from concrete.fhe.extensions.synthesis.verilog_to_luts import yosys_lut_synthesis + + +class FheFunction: + """Main class to synthesize verilog to tracer function.""" + + def __init__( + self, + *, + verilog, + name, + params=None, + yosys_dot_file=False, + verbose=False, + ): + self.name = name + self.verilog = verilog + if verbose: + print() + print(f"Verilog, {name}:") + print(verilog) + print() + if verbose: + print("Synthesis") + self.circuit = yosys_lut_synthesis( + verilog, yosys_dot_file=yosys_dot_file, circuit_name=name + ) + if verbose: + print() + print(f"TLUs counts, {self.tlu_counts()}:") + print() + self.params = params + self.tracer = tlu_circuit_to_fhe(self.circuit, self.params, verbose) + + def __call__(self, **kwargs): + """Call the tracer function.""" + return self.tracer(**kwargs) + + def tlu_counts(self): + """Count the number of tlus in the synthesized tracer keyed by input precision.""" + counter = Counter() + for node in self.circuit.nodes: + if len(node.arguments) == 1: + print(node) + counter.update({len(node.arguments): 1}) + return dict(sorted(counter.items())) + + def graph(self, *, filename=None, view=True, **kwargs): + """Render the synthesized tracer as a graph.""" + graph = to_graph(self.name, self.circuit.nodes) + graph.render(filename=filename, view=view, cleanup=filename is None, **kwargs) + + +def lut(table, out_type=None, **kwargs): + """Synthesize a lookup function from a table.""" + # assert not signed # TODO signed case + if isinstance(out_type, list): + msg = "Multi-message output is not supported" + raise TypeError(msg) + if out_type: + out_type = Ty( + bit_width=out_type.dtype.bit_width, + is_signed=out_type.dtype.is_signed, + ) + verilog, out_type = verilog_from_tlu(table, signed_input=False, out_type=out_type) + if "name" not in kwargs: + kwargs.setdefault("name", "lut") + return FheFunction(verilog=verilog, **kwargs) + + +def _uniformize_as_list(v): + return v if isinstance(v, (list, tuple)) else [v] + + +def verilog_expression(params, expression, out_type, **kwargs): + """Synthesize a lookup function from a verilog function.""" + result_name = "result" + if result_name in params: + result_name = f"{result_name}_{hash(expression)}" + params = dict(params) + params[result_name] = out_type + verilog_params = { + name: Ty( + bit_width=sum(ty.dtype.bit_width for ty in _uniformize_as_list(type_list)), + is_signed=_uniformize_as_list(type_list)[0].dtype.is_signed, + ) + for name, type_list in params.items() + } + verilog = verilog_from_expression(verilog_params, expression, result_name) + if "name" not in kwargs: + kwargs.setdefault("name", expression) + return FheFunction(verilog=verilog, params=params, **kwargs) + + +def verilog_module(verilog, **kwargs): + """Synthesize a lookup function from a verilog module.""" + if "name" not in kwargs: + kwargs.setdefault("name", "main") + return FheFunction(verilog=verilog, **kwargs) diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_fhe.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_fhe.py new file mode 100644 index 0000000000..dbd298f4a9 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_fhe.py @@ -0,0 +1,237 @@ +""" +Convert the simple Tlu Dag to a concrete-python Tracer function. +""" + +from copy import deepcopy +from typing import Dict, List + +import numpy as np + +from concrete import fhe +from concrete.fhe.extensions.synthesis.verilog_to_luts import TluNode +from concrete.fhe.tracing.tracer import Tracer + +WEIGHT_TO_TLU = True +ENFORCE_BITWDTH = True + + +def layered_nodes(nodes: List[TluNode]): + """ + Group nodes in layers by readyness (ready to be computed). + """ + waiting: Dict[str, List[int]] = {} + for i, tlu_node in enumerate(nodes): + for arg in tlu_node.arguments: + waiting.setdefault(arg.name, []).append(i) + readyness = [ + sum(not (node.origin.is_parameter) for node in tlu_node.arguments) for tlu_node in nodes + ] + nb_nodes = len(nodes) + ready_nodes = [i for i, count in enumerate(readyness) if count == 0] + layers = [] + while nb_nodes > 0: + assert ready_nodes, (nb_nodes, readyness) + new_ready_nodes = [] + for i_node in ready_nodes: + for result in nodes[i_node].results: + for index in waiting.get(result.name, ()): + readyness[index] -= 1 + assert readyness[index] >= 0 + if readyness[index] == 0: + new_ready_nodes.append(index) + layers += [ready_nodes] + nb_nodes -= len(ready_nodes) + ready_nodes = new_ready_nodes + return layers + + +def tlu_circuit_to_fhe(circuit, params, verbose): + """Convert the simple TLU dag to a Tracer function.""" + layers = layered_nodes(circuit.nodes) + scheduled_nodes = [circuit.nodes[index_node] for layer in layers for index_node in layer] + + if verbose: + print("Layers") + for i, layer in enumerate(layers): + arities = [len(circuit.nodes[node_index].arguments) for node_index in layer] + print(f"Layer {i}") + print(f" {arities}") + print(f" nb luts: {len(layer)}") + + results_type = [params[result[0].base_name] for result in circuit.results] + + # collect max arity use for all values + max_arity_use = {} + for bits in circuit.parameters.values(): + for value in bits: + max_arity_use[value.name] = 0 + for tlu in scheduled_nodes: + for value in tlu.arguments: + max_arity_use[value.name] = max(max_arity_use[value.name], tlu.arity) + for value in tlu.results: + assert value.name not in max_arity_use + max_arity_use[value.name] = 0 + for result in circuit.results: + for res_bit in result: + max_arity_use[res_bit.name] = len(result) + + # collect min scale use for all values + min_scale = {} + for bits in circuit.parameters.values(): + for value in bits: + min_scale[value.name] = 1 + for result in circuit.results: + for i, res_bit in enumerate(result): + min_scale[res_bit.name] = 2**i + for tlu in scheduled_nodes: + for i, value in enumerate(tlu.arguments): + min_scale[value.name] = min(min_scale[value.name], 2**i) + for value in tlu.results: + assert value.name not in min_scale or value.origin.is_result + if not value.origin.is_result: + min_scale[value.name] = 2 ** (max_arity_use[value.name] - 1) + if not WEIGHT_TO_TLU: + for n in min_scale: + min_scale[n] = 1 + + # pylint: disable=protected-access + skip_force_bit_witdth = not (Tracer._is_direct or ENFORCE_BITWDTH) + + def with_unsigned(tracer): + """Transform a tracer to an unsigned one. + + Note: only works because thanks to bit extraction which it's annotate. + """ + if not isinstance(tracer, Tracer): + return tracer + tracer.output = deepcopy(tracer.output) + tracer.output.dtype.is_signed = False + return tracer + + # stonger than hints + def with_bit_width(tracer, bit_width): + if not isinstance(tracer, Tracer): + return tracer + if skip_force_bit_witdth: + return tracer + tracer.output = deepcopy(tracer.output) + tracer.output.dtype.bit_width = bit_width + return tracer + + def with_result_type(tracer, out_type): + if not isinstance(tracer, Tracer): + return tracer + assert not tracer.output.dtype.is_signed + # a workaround for the missing to signed operator + if out_type.dtype.is_signed: + tracer = -(-tracer) + tracer.output.dtype.is_signed = out_type.dtype.is_signed + if skip_force_bit_witdth: + return tracer + tracer.output = deepcopy(tracer.output) + tracer.output.dtype.bit_width = out_type.dtype.bit_width + return tracer + + def repack_scaled_bits(scaled_bit_values): + scaled_bit_values = list(scaled_bit_values) + bit_width = len(scaled_bit_values) + assert bit_width > 0 + repacked_bits = None + for i, (scale, value) in enumerate(scaled_bit_values): + if value is None: + assert scale == 0 + continue + if isinstance(value, int) and value == 0: + continue + assert scale >= 1 + if repacked_bits is None: + assert scale == 1 + repacked_bits = value + else: + assert scale <= 2**i + assert scale & (scale - 1) == 0 # is power of 2 + weight = 2**i // scale + repacked_bits += value if weight == 1 else value * weight + return with_bit_width(repacked_bits, bit_width) + + def tracer(**kwargs): + for name in circuit.parameters: + if name not in kwargs: + msg = f"{circuit.name}() has a missing keyword argument '{name}'" + raise TypeError(msg) + for name in kwargs: + if name not in circuit.parameters: + msg = f"{circuit.name}() got an unexpected keyword argument '{name}'" + raise TypeError(msg) + + # decompose parameters into bits + with fhe.tag("bit_extractions"): + parameters = { + bit.name: with_bit_width( + fhe.bits(with_unsigned(value))[bit.origin.bit_index], + bit_width=max_arity_use[bit.name], + ) + for name, value in kwargs.items() + for bit in circuit.parameters[name] + } + # contains all intermediate tracer + intermediate_values = dict(parameters) + + # handle special case first + # constant tlu and identity + for tlu_node in scheduled_nodes: + assert len(tlu_node.results) == 1 + output_name = tlu_node.results[0].name + if len(tlu_node.results) == 1 and tlu_node.content == [0, 1]: + assert len(tlu_node.arguments) == 1 + intermediate_values[output_name] = intermediate_values[tlu_node.arguments[0].name] + elif len(tlu_node.arguments) == 0: + if tlu_node.content == ["0"]: + intermediate_values[output_name] = 0 + elif tlu_node.content == ["1"]: + intermediate_values[output_name] = 1 + else: + msg = "Unknown Constant TLU content" + raise ValueError(msg) + + with fhe.tag("synthesis"): + # apply all tlus + for tlu_node in scheduled_nodes: + output_name = tlu_node.results[0].name + if output_name in intermediate_values: + continue + assert tlu_node.arguments + repacked_bits = repack_scaled_bits( + (min_scale[arg.name], intermediate_values[arg.name]) + for arg in tlu_node.arguments + ) + rescale = min_scale[output_name] + flat_content = np.array(tlu_node.content).reshape(-1) + tlu = fhe.LookupTable([v * rescale for v in flat_content]) + max_precision = max_arity_use[output_name] + if max_precision: + result = with_bit_width(tlu[repacked_bits], bit_width=max_precision) + else: + result = tlu[repacked_bits] + intermediate_values[output_name] = result + + with fhe.tag("bit_assemble"): + # recompose bits into result + results = tuple( + repack_scaled_bits( + (min_scale[res_bit.name], intermediate_values[res_bit.name]) + for res_bit in result + ) + for result in circuit.results + ) + for r_type in results_type: + assert not isinstance(r_type, list) + # Provide the right result type + results = tuple( + with_result_type(result, r_type) for result, r_type in zip(results, results_type) + ) + if len(results) == 1: + return results[0] + return results + + return tracer diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_graph.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_graph.py new file mode 100644 index 0000000000..accf9e0200 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_graph.py @@ -0,0 +1,135 @@ +"""Construct a graphviz graph from tlu nodes.""" + + +def detect_std_lut_label(node): + """Give friendly name to some known tlus.""" + inputs = node.arguments + content = node.content + outputs = node.results + assert len(outputs) == 1 + if len(inputs) == 1: + return "l1-not" if content == [1, 0] else "l1-id" + if len(inputs) == 2: + if content == [[0, 0], [0, 1]]: + return "l2-carry" + if content == [[1, 1], [1, 0]]: + return "l2-nand" + if content == [[0, 1], [1, 1]]: + return "l2-or" + if content == [[1, 0], [0, 0]]: + return "l2-2-nor" + if content == [[0, 1], [1, 0]]: + return "l2-add" + if content == [[1, 0], [0, 1]]: + return "l2-nadd" + elif len(inputs) == 3: + if content == [[[0, 1], [1, 0]], [[1, 0], [0, 1]]]: + return "l3-add" + if content == [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]: + return "l3-add(n,_,_)" + if content == [[[1, 0], [0, 0]], [[1, 1], [1, 0]]]: + return "l3-ncarry(n,_,_)" + if content == [[[1, 1], [1, 0]], [[1, 0], [0, 0]]]: + return "l3-ncarry(_,_,_)" + elif len(inputs) == 4: + if content == [[[[1, 0], [0, 1]], [[0, 1], [0, 1]]], [[[1, 0], [1, 0]], [[1, 0], [0, 1]]]]: + return "l4-add(n,_,n,_)" + return None + + +# def _detect_std_lut_label(node): +# inputs = node.arguments +# content = node.content +# outputs = node.results +# assert len(outputs) == 1 +# if len(inputs) == 0: +# return f"c-{content[0]}" + +# if len(inputs) == 1: +# if content == [1, 0]: +# return "l1-not" +# else: +# return "l1-id" +# if len(inputs) == 2: +# if content == [[0, 0], [0, 1]]: +# return "l2-carry" +# if content == [[1, 1], [1, 0]]: +# return "l2-nand" +# if content == [[0, 1], [1, 1]]: +# return "l2-or" +# if content == [[1, 0], [0, 0]]: +# return "l2-nor" +# if content == [[0, 1], [1, 0]]: +# return "l2-add" +# if content == [[1, 0], [0, 1]]: +# return "l2-add" +# elif len(inputs) == 3: +# if content == [[[0, 1], [1, 0]], [[1, 0], [0, 1]]]: +# return "l3-add" +# elif content == [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]: +# return "l3-add" +# elif content == [[[1, 0], [0, 0]], [[1, 1], [1, 0]]]: +# return "l3-carry" +# elif content == [[[1, 1], [1, 0]], [[1, 0], [0, 0]]]: +# return "l3-carry" +# elif len(inputs) == 4: +# if [[[[1, 0], [0, 1]], [[0, 1], [0, 1]]], [[[1, 0], [1, 0]], [[1, 0], [0, 1]]]]: +# return "l4-add" +# return None + + +def to_graph(name, nodes): + """Construct a graphviz graph from tlu nodes.""" + try: + import graphviz # pylint: disable=import-outside-toplevel + except ImportError as exc: + msg = "You must install concrete-python with graphviz support or install graphviz manually" + raise ValueError(msg) from exc + + graph = graphviz.Digraph(name=name) + declated_node = set() + with graph.subgraph(name="cluster_output") as cluster: + for tlu in nodes: + outputs = tlu.results + for result in outputs: + if result.is_interface: + cluster.node(result.name, result.name, color="magenta", shape="box") + declated_node.add(result.name) + with graph.subgraph(name="cluster_input") as cluster: + for tlu in nodes: + inputs = tlu.arguments + for argument in sorted(inputs, key=lambda n: n.name): + if argument.is_interface and argument.name not in declated_node: + cluster.node(argument.name, argument.name, color="blue", shape="box") + declated_node.add(argument.name) + + def node_name_label(node, i): + label = detect_std_lut_label(node) or f"l-{len(node.arguments)}" + name = f"{label}-{i}" + return name, label + + with graph.subgraph(name="cluster_inner") as cluster: + for i, tlu in enumerate(nodes): + inputs = tlu.arguments + outputs = tlu.results + name, label = node_name_label(tlu, i) + cluster.node(name, label, shape="octagon") + for argument in inputs: + if argument.name not in declated_node: + cluster.node(argument.name, argument.name, color="black", shape="point") + declated_node.add(argument.name) + for result in outputs: + if result.name not in declated_node: + cluster.node(result.name, result.name, color="black", shape="point") + declated_node.add(result.name) + for i, tlu in enumerate(nodes): + inputs = tlu.arguments + outputs = tlu.results + name, label = node_name_label(tlu, i) + for j, argument in enumerate(inputs): + assert argument.name in declated_node + graph.edge(argument.name, name, headlabel=str(j)) + for result in outputs: + assert result.name in declated_node + graph.edge(name, result.name) + return graph diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_source.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_source.py new file mode 100644 index 0000000000..679670e928 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_source.py @@ -0,0 +1,64 @@ +"""Provide helper function to generate verilog source code.""" + +import math +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np + + +@dataclass +class Ty: + """Type of arguments and result of a verilog module.""" + + bit_width: int + is_signed: bool + + +def signed(b): + """Textual signed attribute.""" + return "signed" if b else "" + + +def verilog_from_expression(params: Dict[str, Ty], operation: str, result_name: str) -> str: + """Create a verilog module source from params specification (including result), operation.""" + + out_bit_width = params[result_name].bit_width + out_signed = params[result_name].is_signed + params = ", ".join( + f"input {signed(ty.is_signed)} [0:{ty.bit_width - 1}] {name}" + for name, ty in params.items() + if name != result_name + ) + return f"""\ +module main({params}, output {signed(out_signed)} [0:{out_bit_width-1}] {result_name}); + assign {result_name} = {operation}; +endmodule +""" + + +def verilog_from_tlu(table: List[int], signed_input=False, out_type=None) -> Tuple[str, Ty]: + """Create a verilog module source doing the table lookup in table.""" + assert not signed_input + table = list(np.array(table).reshape(-1)) + max_table = max(table) + precision_a = math.ceil(math.log2(len(table))) + if out_type is None: + out_type = Ty( + bit_width=max(1, math.ceil(math.log2(max(1, max_table)))), + is_signed=False, + ) + expr = " " + " :".join( + [f"(a == {i}) ? {v}\n" for i, v in enumerate(table)] + [f"{max_table}"] + ) + return ( + f"""\ +module main(a, result); + input[0:{precision_a-1}] a; + output {signed(out_type.is_signed)} [0:{out_type.bit_width-1}] result; + assign result = (\n{expr} + ); +endmodule\ +""", + out_type, + ) diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_to_luts.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_to_luts.py new file mode 100644 index 0000000000..1a058528cb --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_to_luts.py @@ -0,0 +1,346 @@ +""" +Rewrite yosys json output as a simple Tlu Dag. +""" +import json +import os +import shutil +import subprocess +import sys +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np + +VERBOSE = False + + +def log(*args): + """Log function statically activated.""" + if VERBOSE: + print(*args) + + +## Assuming same arity everywhere, i.e. dag with uniform precision TLUs +LUT_COSTS = { + 1: 29, + 2: 33, + 3: 45, + 4: 74, + 5: 101, + 6: 231, + 7: 535, + 8: 1721, +} + + +def luts_spec_abc(): + """Generate the costs table (area, delay) for `abc`.""" + return "\n".join( + # arity area delay + f"{arity}\t{cost}\t{cost}" + for arity, cost in LUT_COSTS.items() + ) + + +YOSIS_EXE_NAME = "yowasp-yosys" +_yosis_exe = None # pylint: disable=invalid-name + + +def detect_yosis_exe(): + """Detect yosis executable.""" + global _yosis_exe # noqa: PLW0603 pylint: disable=global-statement + if _yosis_exe: + return _yosis_exe + result = ( + shutil.which(YOSIS_EXE_NAME) + or shutil.which(YOSIS_EXE_NAME, path=os.path.dirname(sys.executable)) + or shutil.which(YOSIS_EXE_NAME, path=os.path.dirname(shutil.which("python3") or "")) + ) + if result is None: + msg = f"{YOSIS_EXE_NAME} cannot be found." + raise RuntimeError(msg) + _yosis_exe = result + return result + + +def yosys_script(abc_path, verilog_path, json_path, dot_file, no_clean_up=False): + """Generate `yosys` scripts.""" + no_cleanup = "-nocleanup -showtmp" if no_clean_up else "" + return f""" +echo on +read -sv {verilog_path}; +prep +techmap +log Synthesis with ABC: {abc_path} +abc {no_cleanup} -script {abc_path} +write_json {json_path} +""" + ( + "" if not dot_file else "show -stretch" + ) + + +def abc_script(lut_cost_path): + """Generate `abc` scripts.""" + return f""" +# & avoid a bug when cleaning tmp +read_lut {lut_cost_path} +print_lut +strash +&get -n +&fraig -x +&put +scorr +dc2 +dretime +strash +dch -f +if +mfs2 +#lutpack +""" + + +def bstr(bytes_str): + """Binary str to str.""" + return str(bytes_str, encoding="utf-8") + + +def _yosys_run_script( + abc_file, lut_costs_file, yosys_file, verilog_file, verilog_content, json_file, dot_file=True +): + """Run the yosys script using a subprocess based on the inputs/outpus files.""" + tmpdir_prefix = Path.home() / ".cache" / "concrete-python" / "synthesis" + os.makedirs(tmpdir_prefix, exist_ok=True) + abc_file.write(abc_script(lut_costs_file.name)) + abc_file.flush() + lut_costs_file.write(luts_spec_abc()) + lut_costs_file.flush() + verilog_file.write(verilog_content) + verilog_file.flush() + yosys_file.write(yosys_script(abc_file.name, verilog_file.name, json_file.name, dot_file)) + yosys_file.flush() + yosys_call = [detect_yosis_exe(), "-s", yosys_file.name] + try: + completed = subprocess.run(yosys_call, check=True, capture_output=True) + log(completed.stdout) + log(completed.stderr) + except subprocess.CalledProcessError as exc: + log(exc.output) + log(exc.stderr) + raise_verilog_warnings_and_error( + exc.stdout + exc.stderr, verilog_file.name, verilog_content + ) + print(bstr(exc.output)) + print(bstr(exc.stderr)) + raise + + if b"Warning" in completed.stdout: + raise_verilog_warnings_and_error(completed.stdout, verilog_file.name, verilog_content) + try: + return json.load(json_file) + except json.decoder.JSONDecodeError: + if not json_file.read(): + print(completed.stdout) + print(completed.stderr) + else: + print("JSON:", json_file.read()) + raise + + +def raise_verilog_warnings_and_error(output, verilog_path, verilog_content): + """Raise a tailored exception to provide user information to the detected error.""" + if isinstance(output, bytes): + output = str(output, encoding="utf8") + fatal = None + msgs = [] + for line in output.splitlines(): + location = verilog_path + ":" + if line.startswith(location): + msgs.append(line) + line_nb = int(line.split(location)[1].split(":")[0]) + fatal = f"{line}\nError at line {line_nb}:\n{verilog_content.splitlines()[line_nb-1]}" + elif "Warning" in line: + msgs.append(line) + if msgs and (len(msgs) != 1 and fatal): + print() + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print("The following warnings/errors need to be checked") + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print() + print("\n".join(msgs)) + print() + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print() + if fatal: + raise ValueError(fatal) from None + + +def yosys_run_script(verilog, yosys_dot_file): + """Run the yosys script with the adequate generated files.""" + tmpdir_prefix = Path.home() / ".cache" / "concrete-python" / "synthesis" + os.makedirs(tmpdir_prefix, exist_ok=True) + + def tmpfile(mode, suffix, **kwargs): + return tempfile.NamedTemporaryFile( + mode=mode, dir=tmpdir_prefix, suffix=f"-{suffix}", **kwargs + ) + + # fmt: off + with \ + tmpfile("w+", "script.abc") as abc_file, \ + tmpfile("w+", "lut-costs.txt") as lut_costs_file, \ + tmpfile("w+", "source.verilog") as verilog_file, \ + tmpfile("w+", "yosys.script") as yosys_file, \ + tmpfile("r", "luts.json", delete=False) as json_file: + return _yosys_run_script( + abc_file, lut_costs_file, yosys_file, + verilog_file, verilog, + json_file, dot_file=yosys_dot_file, + ) + # fmt: on + + +@dataclass(frozen=True) +class ValueOrigin: + """Original verilog input/output information for a value.""" + + is_parameter: bool = False + is_result: bool = False + bit_index: int = 0 + + +@dataclass +class ValueNode: + """An intermediate named value.""" + + name: str + origin: ValueOrigin = ValueOrigin() + base_name: Optional[str] = None + + @classmethod + def parameter(cls, base_name, i): + """Construct a parameter (i.e. verilog input) derived value from the base input name.""" + # E.g. `ValueNode.parameter("a", 3)` is "a[3]", the 4th bit of verilog input `a`. + return ValueNode( + f"{base_name}[{i}]", ValueOrigin(is_parameter=True, bit_index=i), base_name + ) + + @classmethod + def result(cls, base_name, i): + """Construct a result (i.e. verilog input) derived value from the base output name.""" + # E.g. `ValueNode.result("r", 3)` is "r[3]", the 4th bit of verilog output `r`. + return ValueNode(f"{base_name}[{i}]", ValueOrigin(is_result=True, bit_index=i), base_name) + + @classmethod + def port(cls, base_name, port, i): + """Construct a verilog port, i.e. either parameter/input or result/output.""" + if port["direction"] == "input": + return cls.parameter(base_name, i) + assert port["direction"] == "output" + return cls.result(base_name, i) + + @property + def is_interface(self): + """Check if a value is part of the verilog circuit interface.""" + return self.origin.is_parameter or self.origin.is_result + + +@dataclass +class TluNode: + """A TLU operation node.""" + + arguments: List[ValueNode] + results: List[ValueNode] + content: List[Any] + + @property + def arity(self): + """Number of single bit parameters of the TLU.""" + return len(self.arguments) + + +@dataclass +class Circuit: + """A full circuit composent of parameters, results and intermediates nodes.""" + + name: str + parameters: Dict[str, List[ValueNode]] + results: List[List[ValueNode]] + nodes: List[TluNode] + + +def convert_yosys_json_to_circuit(json_data, circuit_name="noname"): + """Create a Circuit object from yosys json output.""" + modules = json_data["modules"] + assert len(modules) == 1 + (module,) = modules.values() + assert set(module.keys()) == {"attributes", "ports", "cells", "netnames"}, module.keys() + symbolic_name = {0: "0", 1: "1"} + nodes = [] + + symbol_set = set() + parameters = {} + results = [] + for name, port in module["ports"].items(): + elements = [] + for i, bit in enumerate(port["bits"]): + bit_node = ValueNode.port(name, port, i) + elements.append(bit_node) + symbol_set.add(bit_node.name) + if bit in ("0", "1"): + # output wired to constant + nodes += [ + TluNode( + arguments=[], + results=[bit_node], + content=[bit], + ) + ] + elif bit in symbolic_name: + # input wired to output + nodes += [ + TluNode( + arguments=[symbolic_name[bit]], + results=[bit_node], + content=[0, 1], + ) + ] + log("Equiv: ", symbolic_name[bit], bit_node) + else: + symbolic_name[bit] = bit_node + if elements[0].origin.is_parameter: + parameters[name] = elements + else: + results += [elements] + + log("Interface Names:", symbolic_name) + + for cell_value in module["cells"].values(): + assert cell_value["type"] == "$lut", cell_value + content = list(reversed(list(map(int, cell_value["parameters"]["LUT"])))) + assert set(cell_value["connections"].keys()) == {"A", "Y"} + arguments = [ + symbolic_name.get(n, ValueNode(f"n{n}")) for n in cell_value["connections"]["A"] + ] + # the first input is the last index # missing transpose ? + structured_content = np.array(content).reshape((2,) * (len(arguments))).tolist() + + cell_results = [ + symbolic_name.get(n, ValueNode(f"n{n}")) for n in cell_value["connections"]["Y"] + ] + nodes += [TluNode(arguments=arguments, results=cell_results, content=structured_content)] + + for node in nodes: + log(len(node.arguments), node.arguments, "---", node.content, "-->", node.results) + + log("Names:", symbolic_name) + log("Nodes:", nodes) + return Circuit(circuit_name, parameters, results, nodes) + + +def yosys_lut_synthesis(verilog: str, yosys_dot_file=False, circuit_name="noname"): + """Create a Circuit object from a verilog module.""" + json_data = yosys_run_script(verilog, yosys_dot_file) + return convert_yosys_json_to_circuit(json_data, circuit_name) diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index e98250afcc..71fea10458 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -40,6 +40,22 @@ # pylint: enable=import-error,no-name-in-module +# See https://raw.githubusercontent.com/zama-ai/concrete/main/compilers/concrete-optimizer/v0-parameters/ref/v0_last_128 +# Provide a very coarse way to compare 2 alternative code generation +LUT_COSTS_V0_NORM2_0 = { + 1: 29, + 2: 33, + 3: 45, + 4: 74, + 5: 101, + 6: 231, + 7: 535, + 8: 1721, + 9: 3864, + 10: 8697, + 11: 19522, +} + class Context: """ @@ -125,6 +141,15 @@ def typeof(self, value: Union[ValueDescription, Node]) -> ConversionType: return result if value.is_scalar else self.tensor(result, value.shape) + def fork_type(self, type_, bit_width): + return self.typeof( + ValueDescription( + dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width), + shape=type_.shape, + is_encrypted=type_.is_encrypted, + ) + ) + # utilities def location(self) -> MlirLocation: @@ -2173,6 +2198,22 @@ def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion: def equal(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion: return self.comparison(resulting_type, x, y, accept={Comparison.EQUAL}) + def shift_left(self, x: Conversion, rank: int) -> Conversion: + assert rank >= 0 + assert rank < x.bit_width + shifter = 2**rank + shifter = self.constant(self.i(x.bit_width + 1), shifter) + return self.mul(x.type, x, shifter) + + def reduce_precision(self, x: Conversion, bit_width: int) -> Conversion: + assert bit_width > 0 + assert bit_width <= x.bit_width + if bit_width == x.bit_width: + return x + scaled_x = self.shift_left(x, x.type.bit_width - bit_width) + x = self.reinterpret(scaled_x, bit_width=bit_width) + return x + def extract_bits( self, resulting_type: ConversionType, @@ -2200,66 +2241,101 @@ def extract_bits( start = bits.start or MIN_EXTRACTABLE_BIT stop = bits.stop or (MAX_EXTRACTABLE_BIT if step > 0 else (MIN_EXTRACTABLE_BIT - 1)) - bits_and_their_positions = [] - for position, bit in enumerate(range(start, stop, step)): - bits_and_their_positions.append((bit, position)) + bits = list(range(start, stop, step)) + + bits_and_their_positions = ((bit, position) for position, bit in enumerate(bits)) bits_and_their_positions = sorted( bits_and_their_positions, key=lambda bit_and_its_position: bit_and_its_position[0], ) + # we optimize bulk extract in low precision, used for identity + cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf")) + cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1) + if cost_one_tlu < cost_many_lsbs: + + def tlu_cell_with_positive_value(i): + return x.type.is_unsigned or i < 2 ** (x.bit_width - 1) + + def tlu_cell_input_value(i): + if tlu_cell_with_positive_value(i): + return i + return -(2 ** (x.bit_width) - i) + + table = [ + sum( + ((tlu_cell_input_value(i) >> bit) & 1) << position + for bit, position in bits_and_their_positions + ) + + (0 if tlu_cell_with_positive_value(i) else 2 ** (resulting_type.bit_width + 1)) + for i in range(2**x.bit_width) + ] + tlu_result = self.tlu(resulting_type, x, table) + return self.to_signedness(tlu_result, of=resulting_type) + current_bit = 0 max_bit = x.original_bit_width lsb: Optional[Conversion] = None result: Optional[Conversion] = None - - for index, (bit, position) in enumerate(bits_and_their_positions): + for bit, position in bits_and_their_positions: if bit >= max_bit and x.is_unsigned: break - last = index == len(bits_and_their_positions) - 1 while bit != (current_bit - 1): - if bit == (max_bit - 1) and x.bit_width == 1 and x.is_unsigned: + if ( + bit == (max_bit - 1) + and x.bit_width == resulting_type.bit_width == 1 + and x.is_unsigned + ): + lsb_bit_witdh = 1 lsb = x - elif last and bit == current_bit: - lsb = self.lsb(resulting_type, x) else: - lsb = self.lsb(x.type, x) + lsb_bit_witdh = max(resulting_type.bit_width - position, x.bit_width) + lsb_type = self.fork_type(x.type, lsb_bit_witdh) + lsb = self.lsb(lsb_type, x) + + # check that we only need to shift to emulate the initial and final position + # position are expressed for the final bit_width + initial_position = resulting_type.bit_width - x.bit_width + actual_position = resulting_type.bit_width - lsb.type.bit_width + delta_precision = initial_position - actual_position + assert 0 <= delta_precision < resulting_type.bit_width + assert ( + actual_position <= initial_position + ), "extract_bits: Cannot get back to initial precision" + assert ( + actual_position <= position + ), "extract_bits: Cannot get back to final precision" current_bit += 1 if current_bit >= max_bit: break - if not last or bit != (current_bit - 1): - cleared = self.sub(x.type, x, lsb) - x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) + clearing_bit = self.reduce_precision(lsb, x.bit_width) + cleared = self.sub(x.type, x, clearing_bit) + x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) assert lsb is not None - lsb = self.to_signedness(lsb, of=resulting_type) - - if lsb.bit_width > resulting_type.bit_width: - difference = (lsb.bit_width - resulting_type.bit_width) + position - shifter = self.constant(self.i(lsb.bit_width + 1), 2**difference) - shifted = self.mul(lsb.type, lsb, shifter) - lsb = self.reinterpret(shifted, bit_width=resulting_type.bit_width) - - elif lsb.bit_width < resulting_type.bit_width: - shift = 2 ** (lsb.bit_width - 1) - if shift != 1: - shifter = self.constant(self.i(lsb.bit_width + 1), shift) - shifted = self.mul(lsb.type, lsb, shifter) - lsb = self.reinterpret(shifted, bit_width=1) - lsb = self.tlu(resulting_type, lsb, [0 << position, 1 << position]) - - elif position != 0: - shifter = self.constant(self.i(lsb.bit_width + 1), 2**position) - lsb = self.mul(lsb.type, lsb, shifter) + bit_value = self.to_signedness(lsb, of=resulting_type) + bit_value = self.reinterpret( + bit_value, bit_width=max(resulting_type.bit_width, max_bit) + ) - assert lsb is not None - result = lsb if result is None else self.add(resulting_type, result, lsb) + delta_precision = position - actual_position + assert actual_position < 0 or 0 <= delta_precision < resulting_type.bit_width, ( + position, + actual_position, + resulting_type.bit_width, + ) + if delta_precision: + bit_value = self.shift_left(bit_value, delta_precision) + + bit_value = self.reinterpret(bit_value, bit_width=resulting_type.bit_width) + + result = bit_value if result is None else self.add(resulting_type, result, bit_value) return result if result is not None else self.zeros(resulting_type) @@ -3209,8 +3285,8 @@ def round_bit_pattern( unskewed = x if approx_conf.symetrize_deltas: highest_supported_precision = 62 - delta_precision = highest_supported_precision - x.type.bit_width - full_precision = x.type.bit_width + delta_precision + delta_precision = highest_supported_precision - x.bit_width + full_precision = x.bit_width + delta_precision half_in_extra_precision = ( 1 << (delta_precision - 1) ) - 1 # slightly smaller then half @@ -3763,11 +3839,12 @@ def reinterpret( ) -> Conversion: assert x.is_encrypted - if x.bit_width == bit_width: + result_unsigned = x.is_unsigned if signed is None else not signed + + if x.bit_width == bit_width and x.is_unsigned == result_unsigned: return x - result_signed = x.is_unsigned if signed is None else signed - resulting_element_type = (self.eint if result_signed else self.esint)(bit_width) + resulting_element_type = (self.eint if result_unsigned else self.esint)(bit_width) resulting_type = self.tensor(resulting_element_type, shape=x.shape) operation = ( diff --git a/frontends/concrete-python/concrete/fhe/representation/node.py b/frontends/concrete-python/concrete/fhe/representation/node.py index 380fca1b78..027243c8a0 100644 --- a/frontends/concrete-python/concrete/fhe/representation/node.py +++ b/frontends/concrete-python/concrete/fhe/representation/node.py @@ -172,15 +172,15 @@ def __init__( fhe_directory = os.path.dirname(fhe.__file__) # pylint: enable=cyclic-import,import-outside-toplevel - for frame in reversed(traceback.extract_stack()): + self.location = f"{frame.filename}:{frame.lineno}" if frame.filename == "<__array_function__ internals>": continue if frame.filename.startswith(fhe_directory): + self.location = f"{frame.filename}:{frame.lineno}" continue - self.location = f"{frame.filename}:{frame.lineno}" break # pylint: disable=cyclic-import,import-outside-toplevel diff --git a/frontends/concrete-python/requirements.txt b/frontends/concrete-python/requirements.txt index 1d612353fd..a0dd708710 100644 --- a/frontends/concrete-python/requirements.txt +++ b/frontends/concrete-python/requirements.txt @@ -4,4 +4,5 @@ networkx>=2.6 numpy>=1.23 scipy>=1.10 torch>=1.13 +yowasp-yosys>=0.39.0 z3-solver>=4.12 diff --git a/frontends/concrete-python/tests/execution/test_bit_extraction.py b/frontends/concrete-python/tests/execution/test_bit_extraction.py index 3285170643..483fcd788f 100644 --- a/frontends/concrete-python/tests/execution/test_bit_extraction.py +++ b/frontends/concrete-python/tests/execution/test_bit_extraction.py @@ -152,6 +152,7 @@ def test_bad_plain_bit_extraction( "input_bit_width,input_is_signed,operation", [ # unsigned + pytest.param(3, False, lambda x: fhe.bits(x)[0:3], id="unsigned-3b[0:3]"), pytest.param(5, False, lambda x: fhe.bits(x)[0], id="unsigned-5b[0]"), pytest.param(5, False, lambda x: fhe.bits(x)[1], id="unsigned-5b[1]"), pytest.param(5, False, lambda x: fhe.bits(x)[2], id="unsigned-5b[2]"), @@ -166,6 +167,7 @@ def test_bad_plain_bit_extraction( pytest.param(5, False, lambda x: fhe.bits(x)[2::-1], id="unsigned-5b[2::-1]"), pytest.param(5, False, lambda x: fhe.bits(x)[1:30:10], id="unsigned-5b[1:30:10]"), # signed + pytest.param(3, True, lambda x: fhe.bits(x)[0:3], id="signed-3b[0:3]"), pytest.param(5, True, lambda x: fhe.bits(x)[0], id="signed-5b[0]"), pytest.param(5, True, lambda x: fhe.bits(x)[1], id="signed-5b[1]"), pytest.param(5, True, lambda x: fhe.bits(x)[2], id="signed-5b[2]"), @@ -179,9 +181,11 @@ def test_bad_plain_bit_extraction( pytest.param(5, True, lambda x: fhe.bits(x)[2::-1], id="signed-5b[2::-1]"), pytest.param(5, True, lambda x: fhe.bits(x)[1:30:10], id="signed-5b[1:30:10]"), # unsigned (result bit-width increased) + pytest.param(3, False, lambda x: fhe.bits(x)[0:3] + 100, id="unsigned-3b[0:3] + 100"), pytest.param(5, False, lambda x: fhe.bits(x)[0] + 100, id="unsigned-5b[0] + 100"), pytest.param(5, False, lambda x: fhe.bits(x)[1:3] + 100, id="unsigned-5b[1:3] + 100"), # signed (result bit-width increased) + pytest.param(3, True, lambda x: fhe.bits(x)[0:3], id="signed-3b[0:3] + 100"), pytest.param(5, True, lambda x: fhe.bits(x)[0] + 100, id="signed-5b[0] + 100"), pytest.param(5, True, lambda x: fhe.bits(x)[1:3] + 100, id="signed-5b[1:3] + 100"), ], @@ -203,7 +207,143 @@ def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers): compiler = fhe.Compiler(operation, {"x": "encrypted"}) circuit = compiler.compile(inputset, helpers.configuration()) - values = inputset if len(inputset) <= 8 else random.sample(inputset, 8) for value in values: helpers.check_execution(circuit, operation, value, retries=3) + + +def mlir_count_ops(mlir, operation): + """ + Count op in mlir. + """ + return sum(operation in line for line in mlir.splitlines()) + + +def test_highest_bit_extraction_mlir(helpers): + """ + Test bit extraction of the highest bit. Saves one lsb. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return fhe.bits(x)[precision - 1] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bits_extraction_to_same_bitwidth_mlir(helpers): + """ + Test bit extraction to same. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bits_extraction_to_bigger_bitwidth_mlir(helpers): + """ + Test bit extraction to bigger bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + (2**precision + 1) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + print(circuit.mlir) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_same_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to smaller bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + (2**precision - 2) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_smaller_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to smaller bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_bigger_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to bigger bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + 2 ** (precision + 1) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bit_extract_to_1_tlu(helpers): + """ + Test bit extract as 1 tlu for small precision. + """ + precision = 3 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return fhe.bits(x)[0:2] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == 0 + assert mlir_count_ops(circuit.mlir, "lookup") == 1 + + precision = 4 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): # pylint: disable=function-redefined + return fhe.bits(x)[0:2] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == 2 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 diff --git a/frontends/concrete-python/tests/extensions/test_synthesis.py b/frontends/concrete-python/tests/extensions/test_synthesis.py new file mode 100644 index 0000000000..6ef6c4bbde --- /dev/null +++ b/frontends/concrete-python/tests/extensions/test_synthesis.py @@ -0,0 +1,142 @@ +""" +Tests of 'synthesis' extension. +""" + +import pytest + +import concrete.fhe.extensions.synthesis as synth +from concrete import fhe + + +def test_relu_correctness(): + """ + Check relu. + """ + out = fhe.int5 + params = {"a": out} + expression = "(a >= 0) ? a : 0" + + relu = synth.verilog_expression(params, expression, out) + + @fhe.circuit({"a": "encrypted"}) + def circuit(a: out): + return relu(a=a) + + def ref_relu(v): + return v if v >= 0 else 0 + + r = list(range(0, 2**4)) + list(range(-(2**4), 0)) + for a in r: + assert circuit.simulate(a) == ref_relu(a) + + +def test_signed_add_correctness(): + """ + Check add. + """ + out = fhe.int5 + params = {"a": out, "b": out} + expression = "a + b" + + @fhe.circuit({"a": "encrypted", "b": "encrypted"}) + def circuit(a: out, b: out): + return synth.verilog_expression(params, expression, out)(a=a, b=b) + + print(circuit.mlir) + + r = list(range(0, 2**4)) + list(range(-(2**4), 0)) + for a in r: + for b in r: + if a + b in r: + assert circuit.simulate(a, b) == a + b + + +def test_unsigned_add_correctness(): + """ + Check add. + """ + out = fhe.uint5 + params = {"a": out, "b": out} + expression = "a + b" + + @fhe.circuit({"a": "encrypted", "b": "encrypted"}) + def circuit(a: out, b: out): + return synth.verilog_expression(params, expression, out)(a=a, b=b) + + print(circuit.mlir) + + r = list(range(0, 2**5)) + for a in r: + for b in r: + if a + b in r: + assert circuit.simulate(a, b) == a + b + + +def test_signed_mul_correctness(): + """ + Check add. + """ + out = fhe.int5 + params = {"a": out, "b": out} + expression = "a * b" + + @fhe.circuit({"a": "encrypted", "b": "encrypted"}) + def circuit(a: out, b: out): + return synth.verilog_expression(params, expression, out)(a=a, b=b) + + r = list(range(0, 2**4)) + list(range(-(2**4), 0)) + for a in r: + for b in r: + if a * b in r: + assert circuit.simulate(a, b) == a * b + + +def test_unsigned_mul_correctness(): + """ + Check add. + """ + out = fhe.uint5 + params = {"a": out, "b": out} + expression = "a * b" + + @fhe.circuit({"a": "encrypted", "b": "encrypted"}) + def circuit(a: out, b: out): + return synth.verilog_expression(params, expression, out)(a=a, b=b) + + r = list(range(0, 2**5)) + for a in r: + for b in r: + if a * b in r: + assert circuit.simulate(a, b) == a * b + + +def test_relu_limit(): + """ + Check the limit for relu precision. + 28bits is only attainable only with weight in tlu fuzing. + """ + out = fhe.int28 + params = {"a": out} + expression = "(a >= 0) ? a : 0" + + relu = synth.verilog_expression(params, expression, out) + + @fhe.circuit({"a": "encrypted", "b": "encrypted"}) + def _(a: out, b: out): + v = relu(a=a) - relu(a=b) + return relu(a=v) + + out = fhe.int29 + params = {"a": out} + expression = "(a >= 0) ? a : 0" + + relu = synth.verilog_expression(params, expression, out) + + with pytest.raises(RuntimeError) as err: + + @fhe.circuit({"a": "encrypted", "b": "encrypted"}) + def _(a: out, b: out): + v = relu(a=a) - relu(a=b) + return relu(a=v) + + assert str(err.value) == "NoParametersFound"