diff --git a/frontends/concrete-python/.pylintrc b/frontends/concrete-python/.pylintrc index 6b83703deb..e0ef676713 100644 --- a/frontends/concrete-python/.pylintrc +++ b/frontends/concrete-python/.pylintrc @@ -450,6 +450,7 @@ disable=raw-checker-failed, wrong-import-order, unsubscriptable-object, no-else-continue, + no-else-return, unnecessary-comprehension # Enable the message, report, category or checker with the given id(s). You can diff --git a/frontends/concrete-python/.ruff.toml b/frontends/concrete-python/.ruff.toml index b520e6e7dc..0897538aa4 100644 --- a/frontends/concrete-python/.ruff.toml +++ b/frontends/concrete-python/.ruff.toml @@ -7,9 +7,10 @@ select = [ ] ignore = [ "A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001", "SIM105", - "RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003", "PLW2901", + "RET504", "RET505", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003", "PLW2901", "PLR0915", "C416", "PLR0911", "PLR0912", "PLR0913", "RUF005", "PLR2004", "S110", "PLC1901", - "E731", "RET507", "SIM102" + "E731", "RET507", "SIM102", "SIM108", + "Q000", ] [per-file-ignores] diff --git a/frontends/concrete-python/concrete/fhe/compilation/__init__.py b/frontends/concrete-python/concrete/fhe/compilation/__init__.py index 8a2ff857d0..97e4abc15e 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/compilation/__init__.py @@ -19,6 +19,7 @@ MultiParameterStrategy, MultivariateStrategy, ParameterSelectionStrategy, + SynthesisConfig, ) from .keys import Keys from .module import FheFunction, FheModule diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 90c856c9a9..f3974c4d34 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -122,6 +122,31 @@ class ApproximateRoundingConfig: """ +@dataclass +class SynthesisConfig: + """ + Controls the behavior of synthesis. + """ + + start_tlu_at_precision: int = 7 + """ + Starting synthesis at the given TLU input precision, but keep the original TLU when it's faster. + Used to make high precision TLU faster by rewritting them with several lower precisions TLU. + """ + + force_tlu_at_precision: int = 17 + """ + Force synthesis at the given TLU input precision, even if it's slower than the original TLU. + Used to replace any high precision TLU by several lower precisions TLU. + """ + + maximal_tlu_input_bit_width: int = 8 + """ + Maximal bit_width for TLU generated by synthesis. + Used if you want guarantees on the maximum input bit_width of TLU after synthesis. + """ + + class ComparisonStrategy(str, Enum): """ ComparisonStrategy, to specify implementation preference for comparisons. @@ -994,6 +1019,7 @@ class Configuration: dynamic_assignment_check_out_of_bounds: bool simulate_encrypt_run_decrypt: bool composable: bool + synthesis_config: SynthesisConfig def __init__( self, @@ -1063,6 +1089,7 @@ def __init__( dynamic_indexing_check_out_of_bounds: bool = True, dynamic_assignment_check_out_of_bounds: bool = True, simulate_encrypt_run_decrypt: bool = False, + synthesis_config: Optional[SynthesisConfig] = None, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1170,6 +1197,8 @@ def __init__( self.simulate_encrypt_run_decrypt = simulate_encrypt_run_decrypt + self.synthesis_config = synthesis_config or SynthesisConfig() + self._validate() class Keep: @@ -1245,6 +1274,7 @@ def fork( dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP, dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP, simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP, + synthesis_config: Union[Keep, Optional[SynthesisConfig]] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/__init__.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/__init__.py new file mode 100644 index 0000000000..b9fbfa200a --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/__init__.py @@ -0,0 +1,3 @@ +"""Provide synthesis main entry points.""" + +from .fhe_function import lut, verilog_expression diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/eval_context.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/eval_context.py new file mode 100644 index 0000000000..3afe278f61 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/eval_context.py @@ -0,0 +1,129 @@ +# pylint: disable=missing-module-docstring,missing-function-docstring + +from dataclasses import dataclass + +import numpy as np + +from concrete.fhe.extensions.synthesis.verilog_source import Ty + + +class EvalContext: + """ + This is a reduced context with similar method as `concrete.fhe.mlir.Context`. + + It provides a clear evaluation backend for tlu_circuit_to_mlir. + Until the all internal_api are used directly by concrete-python, + this helps to keep all tests for previous backend and api. + For now only synthesis of TLU is supported by concrete-python. + """ + + # For all `EvalContext` method look at `Context` documentation. + + @dataclass + class Ty: + """Equivalent for `ConversionType`.""" + + bit_width: int + is_tensor: bool = False + shape: tuple = () # not used + + @dataclass + class Val: + """Equivalent for `Conversion`. Contains the evaluation result.""" + + value: 'int | np.ndarray' + type: Ty + + def __init__(self, value, type_): + try: + value = int(value) + except TypeError: + pass + self.value = value + self.type = type_ + + def fork_type(self, type_, bit_width=None, shape=None): + return self.Ty(bit_width=bit_width or type_.bit_width, shape=shape or type_.shape) + + def i(self, size): + return self.Ty(size) + + def constant(self, type_: Ty, value: int): + return self.Val(value, type_) + + def mul(self, type_: Ty, a: Val, b: Val): + assert isinstance(b.value, int) + assert a.type == type_ + return self.Val(a.value * b.value, type_) + + def add(self, type_: Ty, a: Val, b: Val): + assert a.type == b.type == type_ + return self.Val(a.value + b.value, type_) + + def sub(self, type_: Ty, a: Val, b: Val): + assert isinstance(b.value, int) + assert a.type == type_ + return self.Val(a.value - b.value, type_) + + def tlu(self, type_: Ty, arg: Val, tlu_content, **_kwargs): + if isinstance(arg, int): + v = self.Val(tlu_content[arg.value], type_) + else: + v = np.vectorize(lambda v: int(tlu_content[v]))(arg.value) + return self.Val(v, type_) + + def extract_bits(self, type_: Ty, arg: Val, bit_index, **_kwargs): + return self.Val((arg.value >> bit_index) & 1, type_) + + def to_unsigned(self, arg: Val): + def aux(value): + if value < 0: + return 2**arg.type.bit_width + value + return value + + if isinstance(arg.value, int): + v = aux(arg.value) + else: + v = np.vectorize(aux)(arg.value) + return self.Val(v, arg.type) + + def to_signed(self, arg: Val): + def aux(value): + assert value >= 0 + negative = value >= 2 ** (arg.type.bit_width - 1) + if negative: + return -(2**arg.type.bit_width - arg.value) + return value + + if isinstance(arg.value, int): + v = aux(arg.value) + else: + v = np.vectorize(aux)(arg.value) + return self.Val(v, arg.type) + + def index(self, type_: Ty, tensor: Val, index): + assert isinstance(tensor.value, list), type(tensor.value) + assert len(index) == 1 + (index,) = index + return self.Val(tensor.value[index], self.Ty(type_.bit_width, is_tensor=False)) + + def reinterpret(self, arg, bit_width=None): + arg_bit_width = arg.type.bit_width + if bit_width is None: + bit_width = arg_bit_width + if bit_width == arg_bit_width: + return arg + shift = 2 ** (bit_width - arg_bit_width) + if isinstance(arg, int): + v = arg.value * shift + else: + v = np.vectorize(lambda v: v * shift)(arg.value) + return self.Val(v, self.Ty(bit_width=bit_width)) + + def safe_reduce_precision(self, arg, bit_width): + if arg.type.bit_width == bit_width: + return arg + assert arg.type.bit_width > bit_width + shift = arg.type.bit_width - bit_width + shifted = self.mul(arg.type, arg, self.constant(self.i(bit_width + 1), 2**shift)) + return self.reinterpret(shifted, bit_width) diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/fhe_function.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/fhe_function.py new file mode 100644 index 0000000000..3752b24cb3 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/fhe_function.py @@ -0,0 +1,162 @@ +""" +INTERNAL extension to synthesize a fhe compatible function from verilog code. +""" + +from collections import Counter +from typing import Optional + +import concrete.fhe.dtypes as fhe_dtypes +import concrete.fhe.tracing.typing as fhe_typing +from concrete.fhe.dtypes.integer import Integer +from concrete.fhe.extensions.synthesis.eval_context import EvalContext +from concrete.fhe.extensions.synthesis.luts_to_fhe import tlu_circuit_to_mlir +from concrete.fhe.extensions.synthesis.luts_to_graph import to_graph +from concrete.fhe.extensions.synthesis.verilog_source import ( + Ty, + verilog_from_expression, + verilog_from_tlu, +) +from concrete.fhe.extensions.synthesis.verilog_to_luts import yosys_lut_synthesis +from concrete.fhe.values.value_description import ValueDescription + + +class FheFunction: + """Main class to synthesize verilog.""" + + def __init__( + self, + *, + verilog, + name, + params=None, + result_name="result", + yosys_dot_file=False, + verbose=False, + ): + assert params + self.name = name + self.verilog = verilog + if verbose: + print() + print(f"Verilog, {name}:") + print(verilog) + print() + if verbose: + print("Synthesis") + self.circuit = yosys_lut_synthesis( + verilog, yosys_dot_file=yosys_dot_file, circuit_name=name + ) + if verbose: + print() + print(f"TLUs counts, {self.tlu_counts()}:") + print() + self.params = params + self.result_name = result_name + + self.mlir = tlu_circuit_to_mlir(self.circuit, self.params, result_name, verbose) + + def __call__(self, **kwargs): + """ + Evaluate using mlir generation with a direct evaluation context. + + This is useful for testing purpose. + """ + args = [] + for name, type_ in self.params.items(): + if name == "result": + continue + if isinstance(type_, list): + val = EvalContext.Val( + kwargs[name], EvalContext.Ty(type_[0].dtype.bit_width, is_tensor=True) + ) + else: + val = EvalContext.Val(kwargs[name], EvalContext.Ty(type_.dtype.bit_width)) + args.append(val) + result_ty = self.params["result"] + if isinstance(result_ty, list): + eval_ty = EvalContext.Ty(result_ty, is_tensor=True) + else: + eval_ty = EvalContext.Ty(result_ty, is_tensor=False) + result = self.mlir(EvalContext(), eval_ty, args) + if isinstance(result_ty, list): + return [r.value for r in result] + else: + return result.value + + def tlu_counts(self): + """Count the number of tlus in the synthesized tracer keyed by input precision.""" + counter = Counter() + for node in self.circuit.nodes: + if len(node.arguments) == 1: + print(node) + counter.update({len(node.arguments): 1}) + return dict(sorted(counter.items())) + + def is_faster_than_1_tlu(self, reference_costs): + """Verify that synthesis is faster than the original tlu.""" + costs = 0 + for node in self.circuit.nodes: + zero_cost = len(node.arguments) <= 1 + if zero_cost: + # constant or inversion (converted to substraction) + continue + else: + costs += reference_costs[len(node.arguments)] + try: + return costs <= reference_costs[self.params["a"].dtype.bit_width] + except KeyError: + return True + + def graph(self, *, filename=None, view=True, **kwargs): + """Render the synthesized tracer as a graph.""" + graph = to_graph(self.name, self.circuit.nodes) + graph.render(filename=filename, view=view, cleanup=filename is None, **kwargs) + + +def lut(table: 'list[int]', out_type: Optional[ValueDescription] = None, **kwargs): + """Synthesize a lookup function from a table.""" + # assert not signed # TODO signed case + if isinstance(out_type, list): + msg = "Multi-message output is not supported" + raise TypeError(msg) + if out_type: + assert isinstance(out_type.dtype, Integer) + v_out_type = Ty( + bit_width=out_type.dtype.bit_width, + is_signed=out_type.dtype.is_signed, + ) + verilog, v_out_type = verilog_from_tlu(table, signed_input=False, out_type=v_out_type) + if "name" not in kwargs: + kwargs.setdefault("name", "lut") + if "params" not in kwargs: + dtype = fhe_dtypes.Integer.that_can_represent(len(table) - 1) + a_ty = getattr(fhe_typing, f"uint{dtype.bit_width}") + assert a_ty + kwargs["params"] = {"a": a_ty, "result": out_type} + return FheFunction(verilog=verilog, **kwargs) + + +def _uniformize_as_list(v): + return v if isinstance(v, (list, tuple)) else [v] + + +def verilog_expression( + in_params: 'dict[str, ValueDescription]', expression: str, out_type: ValueDescription, **kwargs +): + """Synthesize a lookup function from a verilog function.""" + result_name = "result" + if result_name in in_params: + result_name = f"{result_name}_{hash(expression)}" + in_params = dict(in_params) + in_params[result_name] = out_type + verilog_params = { + name: Ty( + bit_width=sum(ty.dtype.bit_width for ty in _uniformize_as_list(type_list)), + is_signed=any(ty.dtype.is_signed for ty in _uniformize_as_list(type_list)), + ) + for name, type_list in in_params.items() + } + verilog = verilog_from_expression(verilog_params, expression, result_name) + if "name" not in kwargs: + kwargs.setdefault("name", expression) + return FheFunction(verilog=verilog, params=in_params, result_name=result_name, **kwargs) diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_fhe.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_fhe.py new file mode 100644 index 0000000000..4ea4fe78d0 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_fhe.py @@ -0,0 +1,480 @@ +""" +Convert the simple Tlu Dag to a concrete-python Tracer function. +""" + +from collections import Counter +from dataclasses import dataclass +from typing import Iterable, cast + +import numpy as np + +from mlir.dialects import arith + +from concrete.fhe.dtypes.integer import Integer +from concrete.fhe.extensions.synthesis.verilog_to_luts import TluCircuit, TluNode +from concrete.fhe.mlir import context as mlir_context +from concrete.fhe.mlir.conversion import Conversion +from concrete.fhe.values.value_description import ValueDescription + +WEIGHT_TO_TLU = True +RESCALE_BEFORE_TLU = True + + +def power_of_2_scale(v): + """Compute the exact log2 of v.""" + assert v > 0 + if v % 2 == 0: + return 1 + power_of_2_scale(v // 2) + else: + assert v == 1 + return 0 + + +def compute_max_arity_use( + circuit: TluCircuit, + params: 'dict[str, ValueDescription | list[ValueDescription]]', + scheduled_nodes, + bit_location, + used_count, +): + """Collect max arity use for all values to later compute all bitwidth.""" + max_arity_use = {} + for bits in circuit.parameters.values(): + for value in bits: + max_arity_use[value.name] = 1 + for tlu in scheduled_nodes: + for value in tlu.arguments: + max_arity_use[value.name] = max(max_arity_use[value.name], tlu.arity) + for value in tlu.results: + assert value.name not in max_arity_use + max_arity_use[value.name] = 1 + + # Result can be either precision correct on single use or < 7 or converted + for result in circuit.results: + for res_bit in result: + base_name = result[0].origin.base_name + if isinstance(params[base_name], list): + word_i = bit_location[base_name][res_bit.name].word + dtype = params[base_name][word_i].dtype # type: ignore + else: + dtype = params[base_name].dtype # type: ignore + bit_width = dtype.bit_width # type: ignore + use_bit_width = ( + # direct to result = exact + used_count.get(res_bit.name, 1) == 1 + # force result precision everywhere if small + or max_arity_use[res_bit.name] < bit_width < 7 + ) + if use_bit_width: + max_arity_use[res_bit.name] = bit_width + else: + # do not let big precision results for big precision everywhere + pass + return max_arity_use + + +def compute_inferred_bit_width( + circuit: TluCircuit, + params: 'dict[str, ValueDescription | list[ValueDescription]]', + scheduled_nodes, + bit_location, + used_count, +): + """Compute the bitwidth of all values.""" + max_arity_use = compute_max_arity_use( + circuit, params, scheduled_nodes, bit_location, used_count + ) + bit_width_unify = set() + unified = [] + for tlu in scheduled_nodes: + constant = len(tlu.arguments) == 0 + if constant: + continue + no_tlu = len(tlu.arguments) == 1 and tlu.content in ([0, 1], [1, 0]) + if no_tlu: + input_name = tlu.arguments[0].name + output_name = tlu.results[0].name + unified.append((input_name, output_name)) + bit_width_unify.add(tuple(sorted((input_name, output_name)))) + else: + bit_width_unify.add(tuple(sorted([bit.name for bit in tlu.arguments]))) + + changed = True + inferred_bit_width = dict(max_arity_use) + while changed: + changed = False + for bits_to_unify in bit_width_unify: + max_arity_uses = [inferred_bit_width[bit] for bit in bits_to_unify] + bit_width_max = max(max_arity_uses) + for bit, v in zip(bits_to_unify, max_arity_uses): + if v != bit_width_max: + inferred_bit_width[bit] = bit_width_max + changed = True + for input_name, output_name in unified: + max_max = max(inferred_bit_width[input_name], inferred_bit_width[output_name]) + inferred_bit_width[input_name] = max_max + inferred_bit_width[output_name] = max_max + return inferred_bit_width, unified + + +@dataclass +class BitLocation: + """Express a bit position in a multi-word circuit input/output.""" + + word: int + """Word position.""" + local_bit_index: int + """Bit position in the word.""" + global_bit_index: int + """Bit position in the global value.""" + + +def compute_bit_location(circuit: TluCircuit, params) -> 'dict[str, dict[str, BitLocation]]': + """Compute the bit location of all inputs/outputs.""" + bit_location: dict[str, dict[str, BitLocation]] = {} + for name in circuit.parameters: + ty_l: list = params[name] # type: ignore + if not isinstance(params[name], list): + ty_l = [params[name]] # type: ignore + bit_location[name] = {} + g_index = 0 + for i_word, type_ in enumerate(ty_l): + for bit_index in range(type_.dtype.bit_width): # type: ignore + bit = circuit.parameters[name][g_index] + assert bit.origin.bit_index == g_index + bit_location[name][bit.name] = BitLocation( + word=i_word, + local_bit_index=bit_index, + global_bit_index=g_index, + ) + g_index += 1 + for result in circuit.results: + name = result[0].origin.base_name + ty_l = params[name] + if not isinstance(ty_l, list): + ty_l = [ty_l] + bit_location[name] = {} + g_index = 0 + for i_word, type_ in enumerate(ty_l): + for bit_index in range(type_.dtype.bit_width): + bit = result[g_index] + assert bit.origin.bit_index == g_index + bit_location[name][bit.name] = BitLocation( + word=i_word, + local_bit_index=bit_index, + global_bit_index=g_index, + ) + g_index += 1 + + return bit_location + + +def compute_min_scale( + circuit: TluCircuit, + scheduled_nodes, + bit_location: 'dict[str, dict[str, BitLocation]]', + unified: 'list[tuple[str, str]]', + inferred_bit_width: 'dict[str, int]', +): + """Compute the minimum weight for all value.""" + min_scale = {} + for result in circuit.results: + for res_bit in result: + base_name = res_bit.origin.base_name + local_bit_index = bit_location[base_name][res_bit.name].local_bit_index + min_scale[res_bit.name] = 2**local_bit_index + for tlu in scheduled_nodes: + for i, value in enumerate(tlu.arguments): + if all(value.name != name for name, _ in unified) or len(tlu.arguments) > 1: + min_scale[value.name] = min(min_scale.get(value.name, 2**i), 2**i) + for value in tlu.results: + assert value.name not in min_scale or value.origin.is_result + if not value.origin.is_result: + min_scale[value.name] = 2 ** (inferred_bit_width[value.name] - 1) + if not WEIGHT_TO_TLU: + for n in min_scale: + min_scale[n] = 1 + + for input_name, output_name in unified: + min1 = min_scale.get(input_name) + min2 = min_scale[output_name] + if min1 is None: + min1 = min2 + min_min = min(min1, min2) + min_scale[input_name] = min_min + min_scale[output_name] = min_min + + return min_scale + + +def repack_scaled_bits( + context: 'mlir_context.Context', + scaled_bit_values: 'Iterable[tuple[int, Conversion]]', + before_tlu: bool = True, +): + """Recombine all binary value to an integer value.""" + max_weight = 0 + scaled_bit_values = list(scaled_bit_values) + assert scaled_bit_values + bit_width = len(scaled_bit_values) + assert bit_width > 0 + repacked_bits = None + arg_max_bit_width = 0 + constants_sum = 0 + for i, (scale, value) in enumerate(scaled_bit_values): + if i == 0: + assert scale == 1 + if value is None: + assert scale == 0 + continue + assert scale >= 1 + assert scale <= 2**i + assert scale & (scale - 1) == 0 # is power of 2 + weight = 2**i // scale + max_weight = max(max_weight, weight) + if isinstance(value.result, arith.ConstantOp): + # clear mul-mul is not supported + constants_sum += int(str(value.result.attributes["value"]).split(":")[0]) * weight + continue + if weight == 1: + add = value + else: + weight = context.constant(context.i(value.type.bit_width + 1), weight) + add = context.mul(value.type, value, weight) + if add.type.bit_width > bit_width: + add = context.safe_reduce_precision(add, bit_width) + elif add.type.bit_width > bit_width: + add_type = context.fork_type(add.type, bit_width=bit_width) + add = context.tlu( + add_type, + add, + [min(v, 2**bit_width - 1) for v in range(2**add.type.bit_width)], + ) + assert add.type.bit_width == bit_width + if repacked_bits is None: + repacked_bits = add + else: + assert repacked_bits.type.bit_width == add.type.bit_width, ( + repacked_bits.type.bit_width, + add.type.bit_width, + scaled_bit_values, + ) + repacked_bits = context.add(add.type, repacked_bits, add) + assert repacked_bits is not None + if constants_sum != 0: + constants_sum = context.constant(context.i(repacked_bits.type.bit_width + 1), constants_sum) + repacked_bits = context.add(repacked_bits.type, repacked_bits, constants_sum) + extra_bits = arg_max_bit_width - bit_width + if extra_bits > 0 and before_tlu and RESCALE_BEFORE_TLU: + repacked_bits = context.safe_reduce_precision(repacked_bits, bit_width) + return repacked_bits + + +def layered_nodes(nodes: 'list[TluNode]'): + """ + Group nodes in layers by readyness (ready to be computed). + """ + waiting: 'dict[str, list[int]]' = {} + for i, tlu_node in enumerate(nodes): + for arg in tlu_node.arguments: + waiting.setdefault(arg.name, []).append(i) + readyness = [ + sum(not (node.origin.is_parameter) for node in tlu_node.arguments) for tlu_node in nodes + ] + nb_nodes = len(nodes) + ready_nodes = [i for i, count in enumerate(readyness) if count == 0] + layers = [] + while nb_nodes > 0: + assert ready_nodes, (nb_nodes, readyness) + new_ready_nodes = [] + for i_node in ready_nodes: + for result in nodes[i_node].results: + for index in waiting.get(result.name, ()): + readyness[index] -= 1 + assert readyness[index] >= 0 + if readyness[index] == 0: + new_ready_nodes.append(index) + layers += [ready_nodes] + nb_nodes -= len(ready_nodes) + ready_nodes = new_ready_nodes + return layers + + +def tlu_circuit_to_mlir( + circuit: TluCircuit, + params: 'dict[str, ValueDescription | list[ValueDescription]]', + result_name: str, + verbose: bool, +): + """Convert the simple TLU dag to a Tracer function.""" + layers = layered_nodes(circuit.nodes) + scheduled_nodes = [circuit.nodes[index_node] for layer in layers for index_node in layer] + + if verbose: + print("Layers") + for i, layer in enumerate(layers): + arities = [len(circuit.nodes[node_index].arguments) for node_index in layer] + print(f"Layer {i}") + print(f" {arities}") + print(f" nb luts: {len(layer)}") + + # positions of bits, useful for multi-word inputs/outputs + bit_location = compute_bit_location(circuit, params) + + # detect single use case + used_count: Counter = Counter() + used_count.update(tlu_arg.name for tlu in scheduled_nodes for tlu_arg in tlu.arguments) + used_count.update(res_bit.name for result in circuit.results for res_bit in result) + + # inferred bit width + inferred_bit_width, unified = compute_inferred_bit_width( + circuit, params, scheduled_nodes, bit_location, used_count + ) + + # collect min scale use for all values + min_scale = compute_min_scale( + circuit, scheduled_nodes, bit_location, unified, inferred_bit_width + ) + + def mlir(context: mlir_context.Context, _resulting_type, args: 'list[Conversion]'): + if len(circuit.parameters) != len(args): + msg = "Invalid number of args" + raise ValueError(msg) + + kwargs = {name: arg for name, arg in zip(circuit.parameters, args)} + + for name, value in kwargs.items(): + if name in params: + if isinstance(params[name], list): + if not value.type.is_tensor: + msg = f"`{name}` should be a tensor" + raise TypeError(msg) + + # decompose parameters into bits + parameters = {} + for name, value in kwargs.items(): + unsigned = context.to_unsigned(value) + for bit, bit_loc in zip(circuit.parameters[name], bit_location[name].values()): + word_i = bit_loc.word + word_bit_i = bit_loc.local_bit_index + assert bit_loc.global_bit_index == bit.origin.bit_index + if isinstance(params[name], list): + word_value_type = context.fork_type( + value.type, shape=tuple(value.type.shape)[:-1] + ) + word_value = context.index(word_value_type, value, [word_i]) + else: + assert word_i == 0 + word_value = unsigned + min_scale_bit = min_scale.get(bit.name) + if min_scale_bit is None: + # bit is unused + continue + shift_by = power_of_2_scale(min_scale[bit.name]) + # Take into account word_bit_i + bit_width = inferred_bit_width[bit.name] - shift_by + # shift_by + bit_value_type = context.fork_type(word_value.type, bit_width=bit_width) + bit_value = context.extract_bits( + bit_value_type, + word_value, + word_bit_i, + assume_many_extract=True, + refresh_all_bits=True, + ) + if shift_by > 0: + bit_value_type = context.fork_type( + bit_value_type, bit_width=inferred_bit_width[bit.name] + ) + bit_value = context.reinterpret( + bit_value, bit_width=inferred_bit_width[bit.name] + ) + parameters[bit.name] = bit_value + + # will contains all intermediate ciphertext + intermediate_values = dict(parameters) + + # handle special case first + # constant and identity tlu + for tlu_node in scheduled_nodes: + output_name = tlu_node.results[0].name + assert len(tlu_node.results) == 1 + rescale = min_scale[output_name] + if len(tlu_node.results) == 1 and tlu_node.content == [0, 1]: + assert len(tlu_node.arguments) == 1 + assert any(tlu_node.arguments[0].name == name for name, _ in unified) + assert any(tlu_node.results[0].name == name for _, name in unified) + intermediate_values[output_name] = intermediate_values[tlu_node.arguments[0].name] + assert ( + intermediate_values[tlu_node.arguments[0].name].type.bit_width + == inferred_bit_width[output_name] + ) + elif len(tlu_node.results) == 1 and tlu_node.content == [1, 0]: + conv_arg = intermediate_values[tlu_node.arguments[0].name] + c_1_type = context.i(conv_arg.type.bit_width) + c_1 = context.constant(c_1_type, 1) + rev_conv_arg = context.sub(conv_arg.type, c_1, conv_arg) + intermediate_values[output_name] = rev_conv_arg + elif len(tlu_node.arguments) == 0: + bit_type = context.i(inferred_bit_width[output_name]) # TODO: tensor input + if tlu_node.content == ["0"]: + intermediate_values[output_name] = context.constant(bit_type, 0) + elif tlu_node.content == ["1"]: + intermediate_values[output_name] = context.constant(bit_type, rescale) + else: + msg = "Unknown Constant TLU content" + raise ValueError(msg) + else: + continue + + # apply all tlus + for tlu_node in scheduled_nodes: + output_name = tlu_node.results[0].name + if output_name in intermediate_values: + continue + assert len(tlu_node.arguments) > 1 + repacked_bits = repack_scaled_bits( + context, + ( + (min_scale[arg.name], intermediate_values[arg.name]) + for arg in tlu_node.arguments + ), + ) + rescale = min_scale[output_name] + flat_content = np.array(tlu_node.content).reshape(-1) + rescaled_content = [v * rescale for v in flat_content] + max_precision = inferred_bit_width[output_name] + assert max_precision + result_type = context.fork_type(repacked_bits.type, bit_width=max_precision) + result = context.tlu(result_type, repacked_bits, rescaled_content, no_synth=True) + intermediate_values[output_name] = result + + # recompose bits into result + results = [] + for result in circuit.results: + bits = [ + (min_scale[res_bit.name], intermediate_values[res_bit.name]) for res_bit in result + ] + ty_result = params[result_name] + if ty_result and isinstance(ty_result, list): + words = [] + for word_ty in ty_result: + word_bits = bits[: word_ty.dtype.bit_width] # type: ignore + bits = bits[word_ty.dtype.bit_width :] # type: ignore + word_result = repack_scaled_bits(context, word_bits, before_tlu=False) + if word_ty.dtype.is_signed: # type: ignore + word_result = context.to_signed(word_result) + words += [word_result] + results += [words] + else: + ty_result = cast(ValueDescription, ty_result) + result = repack_scaled_bits(context, bits, before_tlu=False) + assert isinstance(ty_result.dtype, Integer) + if ty_result.dtype.is_signed: # type: ignore + result = context.to_signed(result) + results += [result] + if len(results) == 1: + return results[0] + return tuple(results) + + return mlir diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_graph.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_graph.py new file mode 100644 index 0000000000..f4c0b4d671 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/luts_to_graph.py @@ -0,0 +1,135 @@ +"""Construct a graphviz graph from tlu nodes.""" + + +def detect_std_lut_label(node): + """Give friendly name to some known tlus.""" + inputs = node.arguments + content = node.content + outputs = node.results + assert len(outputs) == 1 + if len(inputs) == 1: + return "l1-not" if content == [1, 0] else "l1-id" + if len(inputs) == 2: + if content == [[0, 0], [0, 1]]: + return "l2-carry" + if content == [[1, 1], [1, 0]]: + return "l2-nand" + if content == [[0, 1], [1, 1]]: + return "l2-or" + if content == [[1, 0], [0, 0]]: + return "l2-nor" + if content == [[0, 1], [1, 0]]: + return "l2-add" + if content == [[1, 0], [0, 1]]: + return "l2-nadd" + elif len(inputs) == 3: + if content == [[[0, 1], [1, 0]], [[1, 0], [0, 1]]]: + return "l3-add" + if content == [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]: + return "l3-add(n,_,_)" + if content == [[[1, 0], [0, 0]], [[1, 1], [1, 0]]]: + return "l3-ncarry(n,_,_)" + if content == [[[1, 1], [1, 0]], [[1, 0], [0, 0]]]: + return "l3-ncarry(_,_,_)" + elif len(inputs) == 4: + if content == [[[[1, 0], [0, 1]], [[0, 1], [0, 1]]], [[[1, 0], [1, 0]], [[1, 0], [0, 1]]]]: + return "l4-add(n,_,n,_)" + return None + + +# def _detect_std_lut_label(node): +# inputs = node.arguments +# content = node.content +# outputs = node.results +# assert len(outputs) == 1 +# if len(inputs) == 0: +# return f"c-{content[0]}" + +# if len(inputs) == 1: +# if content == [1, 0]: +# return "l1-not" +# else: +# return "l1-id" +# if len(inputs) == 2: +# if content == [[0, 0], [0, 1]]: +# return "l2-carry" +# if content == [[1, 1], [1, 0]]: +# return "l2-nand" +# if content == [[0, 1], [1, 1]]: +# return "l2-or" +# if content == [[1, 0], [0, 0]]: +# return "l2-nor" +# if content == [[0, 1], [1, 0]]: +# return "l2-add" +# if content == [[1, 0], [0, 1]]: +# return "l2-add" +# elif len(inputs) == 3: +# if content == [[[0, 1], [1, 0]], [[1, 0], [0, 1]]]: +# return "l3-add" +# elif content == [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]: +# return "l3-add" +# elif content == [[[1, 0], [0, 0]], [[1, 1], [1, 0]]]: +# return "l3-carry" +# elif content == [[[1, 1], [1, 0]], [[1, 0], [0, 0]]]: +# return "l3-carry" +# elif len(inputs) == 4: +# if [[[[1, 0], [0, 1]], [[0, 1], [0, 1]]], [[[1, 0], [1, 0]], [[1, 0], [0, 1]]]]: +# return "l4-add" +# return None + + +def to_graph(name, nodes): + """Construct a graphviz graph from tlu nodes.""" + try: + import graphviz # pylint: disable=import-outside-toplevel + except ImportError as exc: + msg = "You must install concrete-python with graphviz support or install graphviz manually" + raise ValueError(msg) from exc + + graph = graphviz.Digraph(name=name) + declated_node = set() + with graph.subgraph(name="cluster_output") as cluster: + for tlu in nodes: + outputs = tlu.results + for result in outputs: + if result.is_interface: + cluster.node(result.name, result.name, color="magenta", shape="box") + declated_node.add(result.name) + with graph.subgraph(name="cluster_input") as cluster: + for tlu in nodes: + inputs = tlu.arguments + for argument in sorted(inputs, key=lambda n: n.name): + if argument.is_interface and argument.name not in declated_node: + cluster.node(argument.name, argument.name, color="blue", shape="box") + declated_node.add(argument.name) + + def node_name_label(node, i): + label = detect_std_lut_label(node) or f"l-{len(node.arguments)}" + name = f"{label}-{i}" + return name, node.name + " = " + label + + with graph.subgraph(name="cluster_inner") as cluster: + for i, tlu in enumerate(nodes): + inputs = tlu.arguments + outputs = tlu.results + name, label = node_name_label(tlu, i) + cluster.node(name, label, shape="octagon") + for argument in inputs: + if argument.name not in declated_node: + cluster.node(argument.name, argument.name, color="black", shape="point") + declated_node.add(argument.name) + for result in outputs: + if result.name not in declated_node: + cluster.node(result.name, result.name, color="black", shape="point") + declated_node.add(result.name) + for i, tlu in enumerate(nodes): + inputs = tlu.arguments + outputs = tlu.results + name, label = node_name_label(tlu, i) + for j, argument in enumerate(inputs): + assert argument.name in declated_node + graph.edge(argument.name, name, headlabel=str(j)) + for result in outputs: + assert result.name in declated_node + graph.edge(name, result.name) + return graph diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_source.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_source.py new file mode 100644 index 0000000000..471b21d63c --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_source.py @@ -0,0 +1,118 @@ +"""Provide helper function to generate verilog source code.""" + +import math +from collections import Counter +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np + + +@dataclass +class Ty: + """Type of arguments and result of a verilog module.""" + + bit_width: int + is_signed: bool + + +def signed(b): + """Textual signed attribute.""" + return "signed" if b else "" + + +def verilog_from_expression(params: Dict[str, Ty], operation: str, result_name: str) -> str: + """Create a verilog module source from params specification (including result), operation.""" + + out_bit_width = params[result_name].bit_width + out_signed = params[result_name].is_signed + params = ", ".join( + f"input {signed(ty.is_signed)} [0:{ty.bit_width - 1}] {name}" + for name, ty in params.items() + if name != result_name + ) + return f"""\ +module main({params}, output {signed(out_signed)} [0:{out_bit_width-1}] {result_name}); + assign {result_name} = {operation}; +endmodule +""" + + +def verilog_from_tlu(table: List[int], signed_input=False, out_type=None) -> Tuple[str, Ty]: + """Create a verilog module source doing the table lookup in table.""" + # note: verilog indexing is done with upper:lower (vs lower:upper) to avoid a bug in yosys + assert not signed_input + table = list(table) + max_table = max(table) + precision_a = math.ceil(math.log2(len(table))) + if out_type is None: + out_type = Ty( + bit_width=max(1, math.ceil(math.log2(max(1, max_table)))), + is_signed=False, + ) + + branching_bits = 2 + max_block_bits = 2 + + def gen_radix_tree(table, depth=1, remaining_bits=None, bits_len=None): + assert table + if bits_len is None: + bits_len = int(math.ceil(math.log2(len(table)))) + if remaining_bits is None: + remaining_bits = bits_len + if all(v == table[0] for v in table): + return str(table[0]) + start_str = " " * (4 * depth) + join_str = ":\n" + start_str + if remaining_bits > max_block_bits: + block_count = 2**branching_bits + block_size = 2 ** (remaining_bits - branching_bits) + blocks = [] + bits_checked = f"a[{remaining_bits-1}:{remaining_bits-branching_bits}]" + for i_block in range(block_count): + sub_table = table[i_block * block_size : (i_block + 1) * block_size] + blocks.append( + gen_radix_tree( + sub_table, + depth=depth + 1, + remaining_bits=remaining_bits - branching_bits, + bits_len=bits_len, + ) + ) + return start_str + join_str.join( + [ + f"({bits_checked} == {bits_cond}) ? \n({block})\n" + for bits_cond, block in enumerate(blocks[:-1]) + ] + + ["\n" + blocks[-1]] + ) + + # TODO: could compress a bit here, like simple linear cases + count = Counter() + count.update(table) + # the most common result is put in a last else so no conditions need to be checked + most_common_result = sorted(count.items(), key=lambda t: (t[1] << branching_bits) + t[0])[0][ + 0 + ] + bits_checked = f"a[{remaining_bits-1}:0]" + return start_str + join_str.join( + [ + f"({bits_checked} == {bits_cond}) ? {value}" + for bits_cond, value in enumerate(table) + if value != most_common_result + ] + + [str(most_common_result)] + ) + + blocks = gen_radix_tree(table) + return ( + f"""\ +module main(a, result); + input[{precision_a-1}:0] a; + output {signed(out_type.is_signed)} [{out_type.bit_width-1}:0] result; + assign result = (\n{blocks} + ); +endmodule\ +""", + out_type, + ) diff --git a/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_to_luts.py b/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_to_luts.py new file mode 100644 index 0000000000..4ad8ef3aee --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/synthesis/verilog_to_luts.py @@ -0,0 +1,363 @@ +""" +Rewrite yosys json output as a simple Tlu Dag. +""" + +import json +import os +import shutil +import subprocess +import sys +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np + +VERBOSE = False + + +def log(*args): + """Log function statically activated.""" + if VERBOSE: + print(*args) + + +## Assuming same arity everywhere, i.e. dag with uniform precision TLUs +LUT_COSTS = { + 1: 29, + 2: 33, + 3: 45, + 4: 74, + 5: 101, + 6: 231, + 7: 535, + 8: 1721, +} + + +def luts_spec_abc(): + """Generate the costs table (area, delay) for `abc`.""" + return "\n".join( + # arity area delay + f"{arity}\t{cost}\t{cost}" + for arity, cost in LUT_COSTS.items() + ) + + +YOSYS_EXE_NAME = "yowasp-yosys" +_yosys_exe = None # pylint: disable=invalid-name + + +def detect_yosys_exe(): + """Detect yosys executable.""" + global _yosys_exe # noqa: PLW0603 pylint: disable=global-statement + if _yosys_exe: + return _yosys_exe + result = ( + shutil.which(YOSYS_EXE_NAME) + or shutil.which(YOSYS_EXE_NAME, path=os.path.dirname(sys.executable)) + or shutil.which(YOSYS_EXE_NAME, path=os.path.dirname(shutil.which("python3") or "")) + ) + if result is None: + msg = f"{YOSYS_EXE_NAME} cannot be found." + raise RuntimeError(msg) + _yosys_exe = result + return result + + +def yosys_script(abc_path, verilog_path, json_path, dot_file, no_clean_up=False): + """Generate `yosys` scripts.""" + no_cleanup = "-nocleanup -showtmp" if no_clean_up else "" + return f""" +echo on +read -sv {verilog_path}; +prep +techmap +log Synthesis with ABC: {abc_path} +abc {no_cleanup} -script {abc_path} +write_json {json_path} +""" + ( + "" if not dot_file else "show -stretch" + ) + + +def abc_script(lut_cost_path): + """Generate `abc` scripts.""" + return f""" +# & avoid a bug when cleaning tmp +read_lut {lut_cost_path} +print_lut +strash +&get -n +&fraig -x +&put +scorr +dc2 +dretime +strash +dch -f +if +mfs2 +lutpack +""" + + +def bstr(bytes_str): + """Binary str to str.""" + return str(bytes_str, encoding="utf-8") + + +def _yosys_run_script( + abc_file, lut_costs_file, yosys_file, verilog_file, verilog_content, json_file, dot_file=True +): + """Run the yosys script using a subprocess based on the inputs/outpus files.""" + tmpdir_prefix = Path.home() / ".cache" / "concrete-python" / "synthesis" + os.makedirs(tmpdir_prefix, exist_ok=True) + new_files = [ + (abc_file, abc_script(lut_costs_file.name)), + (lut_costs_file, luts_spec_abc()), + (verilog_file, verilog_content), + (yosys_file, yosys_script(abc_file.name, verilog_file.name, json_file.name, dot_file)), + ] + for new_file, content in new_files: + new_file.write(content) + new_file.flush() + yosys_call = [detect_yosys_exe(), "-s", yosys_file.name] + try: + completed = subprocess.run(yosys_call, check=True, capture_output=True) + log(completed.stdout) + log(completed.stderr) + except subprocess.CalledProcessError as exc: + log(exc.output) + log(exc.stderr) + raise_verilog_warnings_and_error( + exc.stdout + exc.stderr, verilog_file.name, verilog_content + ) + print(bstr(exc.output)) + print(bstr(exc.stderr)) + raise + + if b"Warning" in completed.stdout: + raise_verilog_warnings_and_error(completed.stdout, verilog_file.name, verilog_content) + try: + return json.load(json_file) + except json.decoder.JSONDecodeError: + if not json_file.read(): + print(completed.stdout) + print(completed.stderr) + else: + print("JSON:", json_file.read()) + raise + + +def raise_verilog_warnings_and_error(output, verilog_path, verilog_content): + """Raise a tailored exception to provide user information to the detected error.""" + if isinstance(output, bytes): + output = str(output, encoding="utf8") + fatal = None + msgs = [] + for line in output.splitlines(): + location = verilog_path + ":" + if line.startswith(location): + msgs.append(line) + line_nb = int(line.split(location)[1].split(":")[0]) + lines = verilog_content.splitlines() + context_lines = "\n".join(verilog_content.splitlines()[line_nb - 2 : line_nb]) + underline = "^" * len(lines[line_nb - 1]) + fatal = f"{line}\nError at line {line_nb}:\n{context_lines}\n{underline}" + elif "Warning" in line: + msgs.append(line) + if msgs and (len(msgs) != 1 and fatal): + print() + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print("The following warnings/errors need to be checked") + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print() + print("\n".join(msgs)) + print() + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print() + if fatal: + raise ValueError(fatal) from None + + +def yosys_run_script(verilog, yosys_dot_file): + """Run the yosys script with the adequate generated files.""" + tmpdir_prefix = Path.home() / ".cache" / "concrete-python" / "synthesis" + os.makedirs(tmpdir_prefix, exist_ok=True) + + def tmpfile(mode, suffix, **kwargs): + return tempfile.NamedTemporaryFile( + mode=mode, dir=tmpdir_prefix, suffix=f"-{suffix}", **kwargs + ) + + # fmt: off + with \ + tmpfile("w+", "script.abc") as abc_file, \ + tmpfile("w+", "lut-costs.txt") as lut_costs_file, \ + tmpfile("w+", "source.verilog") as verilog_file, \ + tmpfile("w+", "yosys.script") as yosys_file, \ + tmpfile("r", "luts.json") as json_file: + return _yosys_run_script( + abc_file, lut_costs_file, yosys_file, + verilog_file, verilog, + json_file, dot_file=yosys_dot_file, + ) + # fmt: on + + +@dataclass(frozen=True) +class ValueOrigin: + """Original verilog input/output information for a value.""" + + is_parameter: bool = False + is_result: bool = False + bit_index: int = 0 + base_name: str = "" + + +@dataclass(frozen=True, order=True) +class ValueNode: + """An intermediate named value.""" + + name: str + origin: ValueOrigin = ValueOrigin() + + @classmethod + def parameter(cls, base_name, i): + """Construct a parameter (i.e. verilog input) derived value from the base input name.""" + # E.g. `ValueNode.parameter("a", 3)` is "a[3]", the 4th bit of verilog input `a`. + return ValueNode( + f"{base_name}[{i}]", ValueOrigin(is_parameter=True, bit_index=i, base_name=base_name) + ) + + @classmethod + def result(cls, base_name, i): + """Construct a result (i.e. verilog input) derived value from the base output name.""" + # E.g. `ValueNode.result("r", 3)` is "r[3]", the 4th bit of verilog output `r`. + return ValueNode( + f"{base_name}[{i}]", ValueOrigin(is_result=True, bit_index=i, base_name=base_name) + ) + + @classmethod + def port(cls, base_name, port, i): + """Construct a verilog port, i.e. either parameter/input or result/output.""" + if port["direction"] == "input": + return cls.parameter(base_name, i) + assert port["direction"] == "output" + return cls.result(base_name, i) + + @property + def is_interface(self): + """Check if a value is part of the verilog circuit interface.""" + return self.origin.is_parameter or self.origin.is_result + + +@dataclass +class TluNode: + """A TLU operation node.""" + + arguments: 'list[ValueNode]' + results: 'list[ValueNode]' + content: 'list[Any]' + + @property + def arity(self): + """Number of single bit parameters of the TLU.""" + return len(self.arguments) + + @property + def name(self): + """Name of the result.""" + assert len(self.results) == 1 + return self.results[0].name + + def __str__(self): + name = ", ".join(r.name for r in self.results) + args = ", ".join(a.name for a in self.arguments) + return f"{name} = tlu ({args}) ({self.content})" + + +@dataclass +class TluCircuit: + """A full circuit composent of parameters, results and intermediates nodes.""" + + name: str + parameters: 'dict[str, list[ValueNode]]' + results: 'list[list[ValueNode]]' + nodes: 'list[TluNode]' + + +def convert_yosys_json_to_circuit(json_data, circuit_name="noname"): + """Create a Circuit object from yosys json output.""" + modules = json_data["modules"] + assert len(modules) == 1 + (module,) = modules.values() + assert set(module.keys()) == {"attributes", "ports", "cells", "netnames"}, module.keys() + symbolic_name = {0: "0", 1: "1"} + nodes = [] + + symbol_set = set() + parameters = {} + results = [] + for name, port in module["ports"].items(): + elements = [] + for i, bit in enumerate(port["bits"]): + bit_node = ValueNode.port(name, port, i) + elements.append(bit_node) + symbol_set.add(bit_node.name) + if bit in ("0", "1"): + # output wired to constant + nodes += [ + TluNode( + arguments=[], + results=[bit_node], + content=[bit], + ) + ] + elif bit in symbolic_name: + # input wired to output + nodes += [ + TluNode( + arguments=[symbolic_name[bit]], + results=[bit_node], + content=[0, 1], + ) + ] + log("Equiv: ", symbolic_name[bit], bit_node) + else: + symbolic_name[bit] = bit_node + if elements[0].origin.is_parameter: + parameters[name] = elements + else: + results += [elements] + + log("Interface Names:", symbolic_name) + + for cell_value in module["cells"].values(): + assert cell_value["type"] == "$lut", cell_value + content = list(reversed(list(map(int, cell_value["parameters"]["LUT"])))) + assert set(cell_value["connections"].keys()) == {"A", "Y"} + arguments = [ + symbolic_name.get(n, ValueNode(f"n{n}")) for n in cell_value["connections"]["A"] + ] + # the first input is the last index # missing transpose ? + structured_content = np.array(content).reshape((2,) * (len(arguments))).tolist() + + cell_results = [ + symbolic_name.get(n, ValueNode(f"n{n}")) for n in cell_value["connections"]["Y"] + ] + nodes += [TluNode(arguments=arguments, results=cell_results, content=structured_content)] + + for node in nodes: + log(len(node.arguments), node.arguments, "---", node.content, "-->", node.results) + + log("Nodes:", nodes) + return TluCircuit(circuit_name, parameters, results, nodes) + + +def yosys_lut_synthesis(verilog: str, yosys_dot_file=False, circuit_name="noname"): + """Create a Circuit object from a verilog module.""" + json_data = yosys_run_script(verilog, yosys_dot_file) + return convert_yosys_json_to_circuit(json_data, circuit_name) diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index bfabfa4fc0..00d2f65d7c 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -4,6 +4,7 @@ # pylint: disable=import-error,no-name-in-module +from copy import deepcopy from random import randint from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union @@ -39,6 +40,7 @@ ) from ..dtypes import Integer from ..extensions.bits import MAX_EXTRACTABLE_BIT, MIN_EXTRACTABLE_BIT +from ..extensions.synthesis import fhe_function as synth from ..representation import Graph, GraphProcessor, Node from ..tfhers.dtypes import TFHERSIntegerType from ..values import ValueDescription @@ -2403,8 +2405,6 @@ def shift_left_at_constant_precision(self, x: Conversion, rank: int) -> Conversi def safe_reduce_precision(self, x: Conversion, bit_width: int) -> Conversion: assert bit_width > 0 assert bit_width <= x.bit_width - if x.bit_width != x.original_bit_width: - assert bit_width >= x.original_bit_width if bit_width == x.bit_width: return x scaled_x = self.shift_left_at_constant_precision(x, x.type.bit_width - bit_width) @@ -2416,6 +2416,8 @@ def extract_bits( resulting_type: ConversionType, x: Conversion, bits: Union[int, np.integer, slice], + assume_many_extract=True, + refresh_all_bits=True, ) -> Conversion: if x.is_clear: highlights: Dict[Node, Union[str, List[str]]] = { @@ -2450,7 +2452,7 @@ def extract_bits( # we optimize bulk extract in low precision, used for identity cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf")) cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1) - if cost_one_tlu < cost_many_lsbs: + if cost_one_tlu < cost_many_lsbs and not assume_many_extract: def tlu_cell_with_positive_value(i): return x.type.is_unsigned or i < 2 ** (x.bit_width - 1) @@ -2485,6 +2487,7 @@ def tlu_cell_input_value(i): bit == (max_bit - 1) and x.bit_width == resulting_type.bit_width == 1 and x.is_unsigned + and not refresh_all_bits ): lsb_bit_witdh = 1 lsb = x @@ -2512,8 +2515,9 @@ def tlu_cell_input_value(i): break clearing_bit = self.safe_reduce_precision(lsb, x.bit_width) - cleared = self.sub(x.type, x, clearing_bit) - x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) + if bit != (current_bit - 1) or bit != bits_and_their_positions[-1][0]: + cleared = self.sub(x.type, x, clearing_bit) + x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) assert lsb is not None bit_value = self.to_signedness(lsb, of=resulting_type) @@ -3732,7 +3736,13 @@ def tensorize(self, x: Conversion) -> Conversion: original_bit_width=x.original_bit_width, ) - def tlu(self, resulting_type: ConversionType, on: Conversion, table: Sequence[int]): + def tlu( + self, + resulting_type: ConversionType, + on: Conversion, + table: Sequence[int], + no_synth: bool = True, + ): if on.is_clear: highlights = { on.origin: "this clear value is used as an input to a table lookup", @@ -3741,7 +3751,6 @@ def tlu(self, resulting_type: ConversionType, on: Conversion, table: Sequence[in self.error(highlights) assert resulting_type.is_encrypted - offset_before_tlu = on.origin.properties.get("offset_before_tlu") if offset_before_tlu is not None: @@ -3796,6 +3805,17 @@ def tlu(self, resulting_type: ConversionType, on: Conversion, table: Sequence[in else: table = table[: len(table) // 2] + padding + table[-len(table) // 2 :] + synth_config = self.configuration.synthesis_config + synth_try = synth_config.start_tlu_at_precision <= on.bit_width and not no_synth + if synth_try: + out_type = deepcopy(on.origin.output) + assert isinstance(out_type.dtype, Integer) + out_type.dtype.bit_width = resulting_type.bit_width + fhe_function = synth.lut(table, out_type=out_type) + synth_force = synth_config.force_tlu_at_precision <= on.bit_width + if synth_force or fhe_function.is_faster_than_1_tlu(LUT_COSTS_V0_NORM2_0): + return fhe_function.mlir(self, resulting_type, [on]) + dialect = fhe if on.is_scalar else fhelinalg operation = dialect.ApplyLookupTableEintOp diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 9809a3ea1d..3d3db29f01 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -830,7 +830,9 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: variable_input_index = variable_input_indices[0] variable_input = preds[variable_input_index] + no_synth = False if variable_input.origin.properties.get("name") == "truncate_bit_pattern": + no_synth = True original_bit_width = variable_input.origin.properties["original_bit_width"] lsbs_to_remove = variable_input.origin.properties["kwargs"]["lsbs_to_remove"] truncated_bit_width = original_bit_width - lsbs_to_remove @@ -845,6 +847,7 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: variable_input = ctx.reinterpret(variable_input, bit_width=truncated_bit_width) elif variable_input.origin.properties.get("name") == "round_bit_pattern": + no_synth = True exactness = ( variable_input.origin.properties["exactness"] or ctx.configuration.rounding_exactness @@ -877,7 +880,9 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: ) if len(tables) == 1: - return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist()) + return ctx.tlu( + ctx.typeof(node), on=variable_input, table=lut_values.tolist(), no_synth=no_synth + ) assert map_values is not None return ctx.multi_tlu( diff --git a/frontends/concrete-python/requirements.txt b/frontends/concrete-python/requirements.txt index 298354dd94..74b822848a 100644 --- a/frontends/concrete-python/requirements.txt +++ b/frontends/concrete-python/requirements.txt @@ -4,4 +4,6 @@ networkx>=2.6 numpy>=1.23,<2.0 scipy>=1.10 torch>=1.13 +yowasp-runtime==1.53 +yowasp-yosys==0.44.0.0.post760 z3-solver>=4.12 diff --git a/frontends/concrete-python/scripts/format/formatter.sh b/frontends/concrete-python/scripts/format/formatter.sh index 73be71883c..7e1a4e629b 100755 --- a/frontends/concrete-python/scripts/format/formatter.sh +++ b/frontends/concrete-python/scripts/format/formatter.sh @@ -39,7 +39,7 @@ done for SRC_DIR in "${DIRS[@]}"; do isort -l 100 --profile black ${CHECK:+"$CHECK"} "${SRC_DIR}" ((FAILURES+=$?)) - black -l 100 ${CHECK:+"$CHECK"} "${SRC_DIR}" + black --skip-string-normalization -l 100 ${CHECK:+"$CHECK"} "${SRC_DIR}" ((FAILURES+=$?)) done diff --git a/frontends/concrete-python/tests/execution/test_synthesis.py b/frontends/concrete-python/tests/execution/test_synthesis.py new file mode 100644 index 0000000000..2f97cedafd --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_synthesis.py @@ -0,0 +1,350 @@ +""" +Tests of 'synthesis' extension. +""" + +import itertools +import time + +import pytest + +import concrete.fhe.extensions.synthesis as synth +from concrete import fhe +from concrete.fhe.compilation.configuration import SynthesisConfig + + +def signed_range(p: int): + """Range of value for a signed of precision `p`.""" + return list(range(0, 2 ** (p - 1))) + list(range(-(2 ** (p - 1)), 0)) + + +def signed_modulo(v, p): + """Modulo for signed value.""" + if v >= 2 ** (p - 1): + return signed_modulo(v - 2**p, p) + elif v < -(2 ** (p - 1)): + return signed_modulo(v + 2**p, p) + return v + + +def test_relu_correctness(): + """ + Check relu. + """ + ty_in_out = fhe.int5 + params = {"a": ty_in_out} + expression = "(a >= 0) ? a : 0" + + relu = synth.verilog_expression(params, expression, ty_in_out) + + def ref_relu(v): + return v if v >= 0 else 0 + + for a in signed_range(5): + assert relu(a=a) == ref_relu(a), a + + +def test_signed_add_correctness(): + """ + Check signed add. + """ + ty_in_out = fhe.int5 + params = {"a": ty_in_out, "b": ty_in_out} + expression = "a + b" + + oper = synth.verilog_expression(params, expression, ty_in_out) + r = signed_range(5) + for a in r: + for b in r: + expected = signed_modulo(a + b, 5) + assert oper(a=a, b=b) == expected + + +def test_unsigned_add_correctness(): + """ + Check unsigned add. + """ + out = fhe.uint5 + params = {"a": out, "b": out} + expression = "a + b" + + oper = synth.verilog_expression(params, expression, out) + + r = list(range(0, 2**out.dtype.bit_width)) + for a in r: + for b in r: + assert oper(a=a, b=b) == (a + b) % 2**5 + + +def test_signed_mul_correctness(): + """ + Check signed mul. + """ + out = fhe.int5 + params = {"a": out, "b": out} + expression = "a * b" + + oper = synth.verilog_expression(params, expression, out) + + r = signed_range(5) + for a in r: + for b in r: + expected = signed_modulo(a * b, 5) + assert oper(a=a, b=b) == expected + + +def test_unsigned_mul_correctness(): + """ + Check unsigned mul. + """ + out = fhe.uint5 + params = {"a": out, "b": out} + expression = "a * b" + + oper = synth.verilog_expression(params, expression, out) + + r = list(range(0, 2**out.dtype.bit_width)) + for a in r: + for b in r: + expected = (a * b) % 2**out.dtype.bit_width + assert oper(a=a, b=b) == expected + + +def test_unsigned_div(): + """ + Check unsigned div. + """ + ty_in_out = fhe.uint4 + params = {"a": ty_in_out, "b": ty_in_out} + expression = "a / b" + + inputset = list(itertools.product(list(range(2**ty_in_out.dtype.bit_width)), repeat=2)) + + oper = synth.verilog_expression(params, expression, ty_in_out) + + for a, b in inputset: + if b == 0: + expected = 15 # infinity + else: + expected = a // b + assert oper(a=a, b=b) == expected, (a, b) + + +def convert_to_radix(v, tys): + """Convert integer to multi-word integer.""" + if not (isinstance(tys, list)): + return v + shift_left = 0 + signed = any(ty.dtype.is_signed for ty in tys) + words = [] + for i, type_ in enumerate(tys): + last = (i + 1) == len(tys) + v_word = v >> shift_left + if signed and last: + assert -(2 ** (type_.dtype.bit_width - 1)) <= v_word < 2 ** (type_.dtype.bit_width - 1) + else: + if last: + assert v_word < 2**type_.dtype.bit_width + v_word = v_word % (2**type_.dtype.bit_width) + shift_left += type_.dtype.bit_width + words.append(v_word) + return words + + +def test_input_radix_relu_correctness(): + """ + Check relu with input in radix encoding. + """ + ty_out = fhe.int5 + ty_in = [fhe.uint3, fhe.int2] + params = {"a": ty_in} + expression = "(a >= 0) ? a : 0" + + relu = synth.verilog_expression(params, expression, ty_out) + + def ref_relu(v): + return v if v >= 0 else 0 + + r = signed_range(5) + for a in r: + a_words = convert_to_radix(a, ty_in) + assert relu(a=a_words) == convert_to_radix(ref_relu(a), ty_out), a + + +def test_radix_signed_mul_correctness(): + """ + Check signed mul with radix encoding. + """ + ty_out = [fhe.uint1, fhe.uint1, fhe.int3] + a_ty_in = [fhe.uint3, fhe.int2] + b_ty_in = [fhe.uint2, fhe.int3] + params = {"a": a_ty_in, "b": b_ty_in} + expression = "a * b" + + oper = synth.verilog_expression(params, expression, ty_out) + + r = signed_range(5) + for a in r: + for b in r: + a_words = convert_to_radix(a, a_ty_in) + b_words = convert_to_radix(b, b_ty_in) + expected = convert_to_radix(signed_modulo(a * b, 5), ty_out) + assert list(oper(a=a_words, b=b_words)) == expected + # assert list(circuit.simulate(*a_words, *b_words)) == expected + + +def to_bits(v, size): + """Integer to list of bits.""" + return [v >> i & 1 for i in range(size)] + + +def from_bits(v): + """List of bits to integer.""" + return sum(b << i for i, b in enumerate(v)) + + +@pytest.mark.parametrize( + "bit_width,reverse_bits", + [ + pytest.param(4, False), # generate a ternary sequence + pytest.param(8, True), # generate a ternary sequence + pytest.param(10, False), # generate a ternary tree + pytest.param(10, True), # generate a ternary tree + ], +) +def test_identity_tlu(bit_width, reverse_bits): + """ + Check the simplest TLU, this gives synthesize to a wire circuit. + + Bits are reversed to ensure we workaround correctly a bug in yosys. + """ + inputset = list(range(2**bit_width)) + + tlu_content = list(range(2**bit_width)) + if reverse_bits: + tlu_content = [from_bits(reversed(to_bits(v, bit_width))) for v in tlu_content] + + @fhe.compiler({"a": "encrypted"}) + def tlu(a): + return fhe.LookupTable(tlu_content)[a] + + time_0 = time.time() + conf = fhe.Configuration(synthesis_config=SynthesisConfig(start_tlu_at_precision=0)) + circuit = tlu.compile(inputset, conf) + time_1 = time.time() + assert time_1 - time_0 < 20 + + for a in inputset: + expected = tlu_content[a] + assert circuit.simulate(a) == expected + + if bit_width < 7: + assert circuit.mlir.count("lsb") == bit_width, circuit.mlir + assert circuit.mlir.count("table") == 0 + + +def test_bit_const_tlu4(): + """ + Check a more complex TLU, 4bits. + """ + + bit_width = 4 + inputset = list(range(2**bit_width)) + + @fhe.compiler({"a": "encrypted"}) + def tlu(a): + return fhe.LookupTable([v - v % 2 for v in range(2**bit_width)])[a] + + conf = fhe.Configuration(synthesis_config=SynthesisConfig(start_tlu_at_precision=0)) + circuit = tlu.compile(inputset) + + # it's not fast enough so synthesis is not kept + assert circuit.mlir.count("lsb") == 0 + assert circuit.mlir.count("table") == 1 + + # forcing to keep the synthesis + conf = fhe.Configuration( + synthesis_config=SynthesisConfig(start_tlu_at_precision=0, force_tlu_at_precision=0) + ) + circuit = tlu.compile(inputset, conf) + + assert circuit.mlir.count("lsb") >= 4 + assert circuit.mlir.count("table") == 0 + + for a in inputset: + expected = a - (a % 2) + assert circuit.simulate(a) == expected + + +def test_add_1_tlu4(): + """ + Check a more complex TLU, 4bits. + """ + + bit_width = 4 + inputset = list(range(2**bit_width)) + + @fhe.compiler({"a": "encrypted"}) + def tlu(a): + return fhe.LookupTable([v + 1 for v in range(2**bit_width)])[a] + + conf = fhe.Configuration(synthesis_config=SynthesisConfig(start_tlu_at_precision=0)) + circuit = tlu.compile(inputset) + + # it's not fast enough so synthesis is not kept + assert circuit.mlir.count("lsb") == 0 + assert circuit.mlir.count("table") == 1 + + # forcing to keep the synthesis + conf = fhe.Configuration( + synthesis_config=SynthesisConfig(start_tlu_at_precision=0, force_tlu_at_precision=0) + ) + circuit = tlu.compile(inputset, conf) + + assert circuit.mlir.count("lsb") >= 4 + assert circuit.mlir.count("table") == 4 + + for a in inputset: + expected = a + 1 + assert circuit.simulate(a) == expected + # assert circuit.encrypt_run_decrypt(a) == expected + + +def test_div_tlu10(): + """ + Check a more complex tlu, 10bits. + """ + + bit_width = 5 + inputset = list(itertools.product(list(range(2**bit_width)), repeat=2)) + + def div(a, b): + return a // b if a and b else 0 + + div_table = [div(a, b) for a in range(2**bit_width) for b in range(2**bit_width)] + + @fhe.compiler({"a": "encrypted", "b": "encrypted"}) + def tlu(a, b): + return fhe.LookupTable(div_table)[a * 2**bit_width + b] + + # forcing to keep the synthesis + conf = fhe.Configuration( + synthesis_config=SynthesisConfig(start_tlu_at_precision=0, force_tlu_at_precision=0) + ) + time_0 = time.time() + circuit = tlu.compile(inputset, conf) + time_1 = time.time() + assert time_1 - time_0 < 60 + + assert circuit.mlir.count("lsb") >= 2 * bit_width + assert circuit.mlir.count("table") > 1 + for line in circuit.mlir.splitlines(): + if "table" not in line: + continue + tlu_bit_width = line.rsplit(":")[1].split("->")[0].split(",")[0].split("<")[1].strip(">") + assert 1 < int(tlu_bit_width) <= 7 + + testset = inputset + for a, b in testset: + expected = div(a, b) + simu = circuit.simulate(a, b) + assert simu == expected, (a, b, expected, simu)