diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml new file mode 100644 index 0000000..2f483b7 --- /dev/null +++ b/.github/workflows/publish_to_pypi.yml @@ -0,0 +1,41 @@ +name: Publish Python 🐍 distribution 📦 to PyPI + +on: push + +jobs: + build: + name: Build distribution 📦 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v3 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: >- + Publish Python 🐍 distribution 📦 to PyPI + if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes + needs: + - build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/oraqle + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml new file mode 100644 index 0000000..0110efc --- /dev/null +++ b/.github/workflows/python-app.yml @@ -0,0 +1,40 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Python application + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install
--upgrade pip + pip install ruff pytest + pip install -r requirements.txt + pip install -e . + - name: Lint with ruff + run: | + ruff check + - name: Test with pytest + run: | + pytest diff --git a/.gitignore b/.gitignore index f9606a3..c8cf315 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,13 @@ /venv +*.dot +.idea/** +__pycache__/** +/build +.DS_Store +*.egg-info* +instructions.txt +.ipynb_checkpoints/ +*.pdf +*.pkl +.sphinx_build/ +/dist diff --git a/README.md b/README.md index 1ef3cd1..a88d6f8 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,11 @@ # Oraqle -The first version of the oraqle compiler will be released in April 2024. +The oraqle compiler lets you generate arithmetic circuits from high-level Python code. It also lets you generate code using HElib. + +This repository uses a fork of fhegen as a dependency and adapts some of the code from [fhegen](https://github.com/Crypto-TII/fhegen), which was written by Johannes Mono, Chiara Marcolla, Georg Land, Tim Güneysu, and Najwa Aaraj. You can read their theoretical work at: https://eprint.iacr.org/2022/706. + +## Setting up +The best way to get things up and running is using a virtual environment: +- Set up a virtualenv using `python3 -m venv venv` in the directory. +- Enter the virtual environment using `source venv/bin/activate`. +- Install the requirements using `pip install -r requirements.txt`. +- *To overcome import problems*, run `pip install -e .`, which will create links to your files (so you do not need to re-install after every change).
diff --git a/addchain_cache.db b/addchain_cache.db new file mode 100644 index 0000000..e211620 Binary files /dev/null and b/addchain_cache.db differ diff --git a/compiler/hello_world.py b/compiler/hello_world.py deleted file mode 100644 index eac0111..0000000 --- a/compiler/hello_world.py +++ /dev/null @@ -1,2 +0,0 @@ -if __name__ == "__main__": - print("Hello world!") diff --git a/docs/api/abstract_nodes_api.md b/docs/api/abstract_nodes_api.md new file mode 100644 index 0000000..78621bf --- /dev/null +++ b/docs/api/abstract_nodes_api.md @@ -0,0 +1,5 @@ +# Abstract nodes API +!!! warning + In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. + +If you want to extend the oraqle compiler, or implement your own high-level nodes, it is easiest to extend one of the existing abstract node classes. diff --git a/docs/api/addition_chains_api.md b/docs/api/addition_chains_api.md new file mode 100644 index 0000000..54ef612 --- /dev/null +++ b/docs/api/addition_chains_api.md @@ -0,0 +1,11 @@ +# Addition chains API +!!! warning + In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. + +The `add_chains` module contains tools for generating addition chains. + +::: oraqle.add_chains + options: + heading_level: 2 + show_submodules: true + show_if_no_docstring: false diff --git a/docs/api/circuits_api.md b/docs/api/circuits_api.md new file mode 100644 index 0000000..e4bce56 --- /dev/null +++ b/docs/api/circuits_api.md @@ -0,0 +1,16 @@ +# Circuits API +!!! warning + In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. 
+ + +## High-level circuits +::: oraqle.compiler.circuit.Circuit + options: + heading_level: 3 + + +## Arithmetic circuits +::: oraqle.compiler.circuit.ArithmeticCircuit + options: + heading_level: 3 + \ No newline at end of file diff --git a/docs/api/code_generation_api.md b/docs/api/code_generation_api.md new file mode 100644 index 0000000..9ef6dfe --- /dev/null +++ b/docs/api/code_generation_api.md @@ -0,0 +1,56 @@ +# Code generation API +!!! warning + In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. + +The easiest way is using: +```python3 +arithmetic_circuit.generate_code() +``` + +## Arithmetic instructions +If you want to extend the oraqle compiler, or implement your own code generation, you can use the following instructions to do so. + +??? info "Abstract instruction" + ::: oraqle.compiler.instructions.ArithmeticInstruction + options: + heading_level: 3 + +??? info "InputInstruction" + ::: oraqle.compiler.instructions.InputInstruction + options: + heading_level: 3 + +??? info "AdditionInstruction" + ::: oraqle.compiler.instructions.AdditionInstruction + options: + heading_level: 3 + +??? info "MultiplicationInstruction" + ::: oraqle.compiler.instructions.MultiplicationInstruction + options: + heading_level: 3 + +??? info "ConstantAdditionInstruction" + ::: oraqle.compiler.instructions.ConstantAdditionInstruction + options: + heading_level: 3 + +??? info "ConstantMultiplicationInstruction" + ::: oraqle.compiler.instructions.ConstantMultiplicationInstruction + options: + heading_level: 3 + +??? info "OutputInstruction" + ::: oraqle.compiler.instructions.OutputInstruction + options: + heading_level: 3 + + +## Generating arithmetic programs +::: oraqle.compiler.instructions.ArithmeticProgram + options: + heading_level: 3 + + +## Generating code for HElib +... 
diff --git a/docs/api/nodes_api.md b/docs/api/nodes_api.md new file mode 100644 index 0000000..e575383 --- /dev/null +++ b/docs/api/nodes_api.md @@ -0,0 +1,53 @@ +# Nodes API +!!! warning + In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. + +## Boolean operations + +??? info "AND operation" + ::: oraqle.compiler.boolean.bool_and.And + options: + heading_level: 3 + +??? info "OR operation" + ::: oraqle.compiler.boolean.bool_or.Or + options: + heading_level: 3 + +??? info "NEG operation" + ::: oraqle.compiler.boolean.bool_neg.Neg + options: + heading_level: 3 + + +## Arithmetic operations +These operations are fundamental arithmetic operations, so they will stay the same when they are arithmetized. + + +## High-level arithmetic operations + +??? info "Subtraction" + ::: oraqle.compiler.arithmetic.subtraction.Subtraction + options: + heading_level: 3 + +??? info "Exponentiation" + ::: oraqle.compiler.arithmetic.exponentiation.Power + options: + heading_level: 3 + + +## Polynomial evaluation + +??? info "Univariate polynomial evaluation" + ::: oraqle.compiler.polynomials.univariate.UnivariatePoly + options: + heading_level: 3 + + +## Control flow + +??? info "If-else statement" + ::: oraqle.compiler.control_flow.conditional.IfElse + options: + heading_level: 3 diff --git a/docs/api/pareto_fronts_api.md b/docs/api/pareto_fronts_api.md new file mode 100644 index 0000000..5376661 --- /dev/null +++ b/docs/api/pareto_fronts_api.md @@ -0,0 +1,18 @@ +# Pareto fronts API +!!! warning + In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. + +If you are using depth-aware arithmetization, you will find that the compiler does not output one arithmetic circuit. +Instead, it outputs a Pareto front, which represents the best circuits that it could generate trading off two metrics: +The *multiplicative depth* and the *multiplicative size/cost*. 
+This page briefly explains the API for interfacing with these Pareto fronts. + +## The abstract base class + +??? info "Abstract ParetoFront" + ::: oraqle.compiler.nodes.abstract.ParetoFront + options: + heading_level: 3 + +## Depth-size and depth-cost fronts + diff --git a/docs/config.md b/docs/config.md new file mode 100644 index 0000000..c2c653d --- /dev/null +++ b/docs/config.md @@ -0,0 +1,7 @@ +# Configuration parameters + +::: oraqle.config + options: + heading_level: 2 + show_submodules: true + show_if_no_docstring: false diff --git a/docs/example_circuits.md b/docs/example_circuits.md new file mode 100644 index 0000000..e4eabe9 --- /dev/null +++ b/docs/example_circuits.md @@ -0,0 +1,7 @@ +!!! warning + Some of these example circuits are untested and may be incorrect. + +::: oraqle.circuits + options: + heading_level: 3 + show_submodules: true diff --git a/docs/getting_started.md b/docs/getting_started.md new file mode 100644 index 0000000..a10f59e --- /dev/null +++ b/docs/getting_started.md @@ -0,0 +1,85 @@ +# Getting started +In 5 minutes, this page will guide you through how to install oraqle, how to specify high-level programs, and how to arithmetize your first circuit! + +## Installation +Simply install the most recent version of the Oraqle compiler using: +``` +pip install oraqle +``` + +We use continuous integration to test every build of the Oraqle compiler on Windows, MacOS, and Unix systems. +If you do run into problems, feel free to [open an issue on GitHub]()! + +## Specifying high-level programs +Let's start with importing `galois`, which represents our plaintext algebra. +We will also immediately import the relevant oraqle classes for our little example: +```python3 +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.leafs import Input +``` + +For this example, we will use 31 as our plaintext modulus. This algebra is denoted by `GF(31)`. 
+Let's create a few inputs that represent elements in this algebra: +```python3 +gf = GF(31) + +x = Input("x", gf) +y = Input("y", gf) +z = Input("z", gf) +``` + +We can now perform some operations on these elements, and they do not have to be arithmetic operations! +For example, we can perform equality checks or comparisons: +``` +comparison = x < y +equality = y == z +both = comparison & equality +``` + +While we have specified some operations, we have not yet established this as a circuit. We will do so now: +```python3 +circuit = Circuit([both]) +``` + +And that's it! We are done specifying our first high-level circuit. +As you can see this is all very similar to writing a regular Python program. +If you want to visualize this high-level circuit before we continue with arithmetizing it, you can run the following (if you have graphviz installed): +```python3 +circuit.to_pdf("high_level_circuit.pdf") +``` + +!!! tip + If you do not have graphviz installed, you can instead call: + ```python3 + circuit.to_dot("high_level_circuit.dot") + ``` + After that, you can copy the file contents to [an online graphviz viewer](https://dreampuf.github.io/GraphvizOnline)! + +## Arithmetizing your first circuit +At this point, arithmetization is a breeze, because the oraqle compiler takes care of these steps. +We can create an arithmetic circuit and visualize it using the following snippet: +```python3 +arithmetic_circuit = circuit.arithmetize() +arithmetic_circuit.to_pdf("arithmetic_circuit.pdf") +``` + +You will notice that it's quite a large circuit. But how large is it exactly? 
+This is a question that we can ask to the oraqle compiler: +```python3 +print("Depth:", arithmetic_circuit.multiplicative_depth()) +print("Size:", arithmetic_circuit.multiplicative_size()) +print("Cost:", arithmetic_circuit.multiplicative_cost(0.7)) +``` + +In the last line, we asked the compiler to output the multiplicative cost, considering that squaring operations are cheaper than regular multiplications. +We weighed this cost with a factor 0.7. + +Now that we have an arithmetic circuit, we can use homomorphic encryption to evaluate it! +If you are curious about executing these circuits for real, consider reading [the code generation tutorial](tutorial_running_exps.md). + +!!! warning + There are many homomorphic encryption libraries that do not support plaintext moduli that are not NTT-friendly. The plaintext modulus we chose (31) is not NTT-friendly. + In fact, only very few primes are NTT-friendly, and they are somewhat large. This is why, right now, the oraqle compiler only implements code generation for HElib. + HElib is (as far as we are aware) the only library that supports plaintext moduli that are not NTT-friendly. diff --git a/docs/images/oraqle_logo_cropped.svg b/docs/images/oraqle_logo_cropped.svg new file mode 100644 index 0000000..f306d73 --- /dev/null +++ b/docs/images/oraqle_logo_cropped.svg @@ -0,0 +1,119 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..1dd7f78 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,18 @@ +# Welcome to oraqle +
+ Oraqle logo
+ A secure computation compiler +
+ +Simply install the most recent version of the Oraqle compiler using: +``` +pip install oraqle==0.1.0 +``` + +Consider checking out our [getting started page](getting_started.md) to help you get up to speed with arithmetizing circuits! + +## API reference +!!! warning + In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. + +For an API reference, you can check out the pages for [circuits](api/circuits_api.md) and for [nodes](api/nodes_api.md). diff --git a/docs/tutorial_running_exps.md b/docs/tutorial_running_exps.md new file mode 100644 index 0000000..593bd22 --- /dev/null +++ b/docs/tutorial_running_exps.md @@ -0,0 +1,3 @@ +# Tutorial: Running experiments +!!! failure + This section is currently missing. Please see the [code generation API](api/code_generation_api.md) for some documentation for now. diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..3549e0f --- /dev/null +++ b/main.cpp @@ -0,0 +1,59 @@ + +#include +#include +#include +#include + +#include + +typedef helib::Ptxt ptxt_t; +typedef helib::Ctxt ctxt_t; + +std::map input_map; + +void parse_arguments(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + std::string argument(argv[i]); + size_t pos = argument.find('='); + if (pos != std::string::npos) { + std::string key = argument.substr(0, pos); + int value = std::stoi(argument.substr(pos + 1)); + input_map[key] = value; + } + } +} + +int extract_input(const std::string& name) { + if (input_map.find(name) != input_map.end()) { + return input_map[name]; + } else { + std::cerr << "Error: " << name << " not found" << std::endl; + return -1; + } +} + +int main(int argc, char* argv[]) { + // Parse the inputs + parse_arguments(argc, argv); + + // Set up the HE parameters + unsigned long p = 257; + unsigned long m = 65536; + unsigned long r = 1; + unsigned long bits = 449; + unsigned long c = 3; + helib::Context context = helib::ContextBuilder() + .m(m) + .p(p) + .r(r) + 
.bits(bits) + .c(c) + .build(); + + + // Generate keys + helib::SecKey secret_key(context); + secret_key.GenSecKey(); + helib::addSome1DMatrices(secret_key); + const helib::PubKey& public_key = secret_key; + diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..417bff0 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,57 @@ +site_name: Oraqle + +nav: + - index.md + - getting_started.md + - tutorial_running_exps.md + - API reference: + - api/circuits_api.md + - api/nodes_api.md + - api/code_generation_api.md + - api/pareto_fronts_api.md + - api/abstract_nodes_api.md + - api/addition_chains_api.md + - example_circuits.md + - config.md + +plugins: +- search +- mkdocstrings: + handlers: + python: + options: + show_root_heading: true + allow_inspection: false + show_submodules: false + show_root_full_path: false + show_symbol_type_heading: true + # show_symbol_type_toc: true This currently causes a bug + docstring_style: google + follow_wrapped_lines: true + crosslink_types: true # Makes types clickable + crosslink_types_style: 'sphinx' # Default or sphinx style + annotations_path: brief + inherited_members: true + members_order: source + show_if_no_docstring: true + separate_signature: false + show_source: false + docstring_section_style: list + +theme: + name: material + highlightjs: true + +markdown_extensions: + - admonition + - pymdownx.superfences + - pymdownx.inlinehilite + - pymdownx.critic + - pymdownx.details + - pymdownx.tasklist + - pymdownx.tabbed + - pymdownx.magiclink + - pymdownx.tilde + - toc: + permalink: true + toc_depth: 3 diff --git a/oraqle/__init__.py b/oraqle/__init__.py new file mode 100644 index 0000000..7b975e3 --- /dev/null +++ b/oraqle/__init__.py @@ -0,0 +1 @@ +"""This module contains the oraqle compiler, tools, and example circuits.""" diff --git a/oraqle/add_chains/__init__.py b/oraqle/add_chains/__init__.py new file mode 100644 index 0000000..e1a4ca8 --- /dev/null +++ b/oraqle/add_chains/__init__.py @@ -0,0 +1 @@ +"""Tools for 
def thurber_bounds(target: int, max_size: int) -> List[Tuple[int, int]]:
    """Returns the Thurber bounds for a given target and a maximum size of the addition chain.

    The target is first split into an odd part and `t` factors of two. The result is a list of
    `(lower, upper)` pairs, ordered by step: `upper` is `min(2**step, target)` and `lower` is a
    rounded-up quotient of `target` whose denominator depends on the step.

    Parameters:
        target: The target integer. It must be nonzero, because stripping the factors of two from zero never terminates.
        max_size: The maximum size of the addition chain.

    Returns:
        The list of `(lower, upper)` pairs; it is empty when `max_size - t` is at most 1.
    """
    # Strip the factors of two from the target, counting how many there are.
    twos = 0
    odd_part = target
    while odd_part % 2 == 0:
        odd_part >>= 1
        twos += 1

    # Steps that remain once the doublings for the factors of two are set aside.
    free_steps = max_size - twos

    def bound_pair(denominator: int, step: int) -> Tuple[int, int]:
        return int(math.ceil(target / denominator)), min(1 << step, target)

    bounds = []
    for step in range(free_steps - 2):
        exponent = free_steps - step - 1
        # NOTE(review): the original condition was written `1 << (max_size - t - step - 2) + 1`.
        # Since `+` binds tighter than `<<`, that is `1 << exponent` and not `2**(exponent - 1) + 1`.
        # The behaviour is kept as-is here; confirm against the intended divisibility test.
        if (1 << exponent) % target == 0:
            denominator = (1 << (twos + 1)) * ((1 << (exponent - 1)) + 1)
        else:
            denominator = (1 << twos) * ((1 << exponent) + 1)
        bounds.append(bound_pair(denominator, step))

    # The step right after the loop always uses the denominator 3 * 2**twos.
    if free_steps - 2 > 0:
        bounds.append(bound_pair(3 * (1 << twos), free_steps - 2))

    # The final steps simply halve the denominator until it reaches 1.
    if free_steps - 1 > 0:
        bounds.extend(
            bound_pair(1 << (max_size - step), step)
            for step in range(free_steps - 1, max_size + 1)
        )

    return bounds
precomputed_values: Optional[Tuple[Tuple[int, int], ...]], +) -> Optional[List[Tuple[int, int]]]: + """Generates a minimum-cost addition chain for a given target, abiding to the constraints. + + Parameters: + target: The target integer. + max_depth: The maximum depth of the addition chain + strict_cost_max: A strict upper bound on the cost of the addition chain. I.e., cost(chain) < strict_cost_max. + squaring_cost: The cost of doubling (squaring), compared to other additions (multiplications), which cost 1.0. + solver: Name of the SAT solver, e.g. "glucose421" for glucose 4.2.1. See: https://pysathq.github.io/docs/html/api/solvers.html. + encoding: The encoding to use for cardinality constraints. See: https://pysathq.github.io/docs/html/api/card.html#pysat.card.EncType. + thurber: Whether to use the Thurber bounds, which provide lower bounds for the elements in the chain. The bounds are ignored when `precomputed_values = True`. + min_size: The minimum size of the chain. It is always possible to use `math.ceil(math.log2(target))`. + precomputed_values: If there are any precomputed values that can be used for free, they can be specified as a tuple of pairs (value, chain_depth). + + Raises: # noqa: DOC502 + TimeoutError: If the global MAXSAT_TIMEOUT is not None, and it is reached before a maxsat instance could be solved. + + Returns: + A minimum-cost addition chain, if it exists. 
+ """ + # TODO: Maybe precomputed_values should not be optional, but should be ignored if it is empty + assert target != 0 + + if target == 1: + return [] + + def x(i) -> int: + return i + + if precomputed_values is not None: + + def z(i: int) -> int: + offset = target + 1 + return i + offset + + def y(i, j) -> int: + # TODO: We can make the offset tighter + offset = (target + 1) if precomputed_values is None else 2 * (target + 1) + assert i <= j + return j * (j + 1) // 2 + i + offset + + def y_inv(n: int) -> Tuple[int, int]: + offset = (target + 1) if precomputed_values is None else 2 * (target + 1) + assert n >= offset + n -= offset + j = math.floor((math.sqrt(1 + 8 * n) - 1) // 2) + i = n - j * (j + 1) // 2 + + return i, j # minus 1 so that 1 -> 0 + + if max_depth is not None: + + def d(i, depth) -> int: + offset = y(target, target) + assert depth <= max_depth + 1 + return offset + 1 + (i - 1) * (max_depth + 1) + depth + + wcnf = WCNF() + + # x_i for i = 1,...,target represents the computed additions + # y_i,j for i,j = 2,...,target s.t. 
i <= j represents that i+j is computed + + # Add constraints + big_disjunctions = {k: [] for k in range(1, target + 1)} + for j in range(1, target + 1): + x_j = x(j) + + for i in range(1, min(j + 1, target + 1 - j)): + x_i = x(i) + y_ij = y(i, j) + + k = i + j + + # y_ij requires that x_i is set + wcnf.append([-y_ij, x_i]) + if i != j: + # y_ij requires that x_j is set + wcnf.append([-y_ij, x_j]) + + # x_k is set when y_ij is set + big_disjunctions[k].append(y_ij) + + # Add objective + wcnf.append([-y(i, j)], weight=(squaring_cost if i == j else 1)) + + if max_depth is not None: + for depth in range(max_depth + 1): + # d_k,depth+1 is set when d_i,depth and y_ij are set + wcnf.append([d(k, depth + 1), -d(i, depth), -y_ij]) + if i != j: + # d_k,depth+1 is set when d_j,depth and y_ij are set + wcnf.append([d(k, depth + 1), -d(j, depth), -y_ij]) + + if precomputed_values is not None: + for k, k_depth in precomputed_values: + if k == 0 or k > target: + continue + + if max_depth is not None and k_depth > max_depth: + continue + + # x_k is set when z_k is set + big_disjunctions[k].append(z(k)) + + if max_depth is not None: + wcnf.append([d(k, k_depth), -z(k)]) + + wcnf.append([x(target)]) + + if max_depth is not None: + wcnf.append([d(1, 0)]) + + for k in range(2, target + 1): + big_disjunctions[k].append(-x(k)) + wcnf.append(big_disjunctions[k]) + + # Cut some potential additions + if precomputed_values is None: + # We do not use these bounds when precomputed_values is not None + wcnf.append([x(m) for m in range((k + 1) // 2, k)]) # type: ignore + + if max_depth is not None: + # May not exceed max_depth + wcnf.append([-d(k, max_depth + 1)]) + + # Add generalized Thurber bounds (for each step in the chain, the number must be between lower_bound and 2^step) + # We do not use the Thurber bounds when precomputed_values is not None + if thurber and precomputed_values is None: + max_size = math.floor(strict_cost_max / squaring_cost) + for lb, ub in thurber_bounds(target, 
max_size): + # FIXME: These bounds seem not to help for target ~ hundreds + wcnf.append([x(i) for i in range(lb, ub + 1)]) + + # Bound the number of x that are true from below + if max_depth is None: + top_id = y(target, target) + else: + top_id = y(target, target) + 1 + (target - 1) * (max_depth + 1) + max_depth + 1 + at_least_cnf = CardEnc.atleast( + [x(k) for k in range(2, target + 1)], bound=min_size, top_id=top_id, encoding=encoding + ) + wcnf.extend(at_least_cnf) + + # Solve + if MAXSAT_TIMEOUT is None: + model = solve(wcnf, solver, strict_cost_max) + else: + model = solve_with_time_limit(wcnf, solver, strict_cost_max, MAXSAT_TIMEOUT) + + if model is None: + return None + + offset = (target + 1) if precomputed_values is None else 2 * (target + 1) + return [y_inv(n) for n in model if offset <= n <= y(target, target)] + + +def test_addition_chain(): # noqa: D103 + chain = add_chain( + 8, + 3, + 2.0, + 0.5, + solver="glucose42", + encoding=1, + thurber=True, + min_size=2, + precomputed_values=None, + ) + assert chain == [(1, 1), (2, 2), (4, 4)] + + +def test_addition_chain_precomputed_no_depth(): # noqa: D103 + chain = add_chain( + 8, + None, + 2.0, + 0.5, + solver="glucose42", + encoding=1, + thurber=True, + min_size=1, + precomputed_values=((7, 2),), + ) + assert chain == [(1, 7)] + + +def test_addition_chain_precomputed_depth(): # noqa: D103 + chain = add_chain( + 8, + 3, + 2.0, + 0.5, + solver="glucose42", + encoding=1, + thurber=True, + min_size=1, + precomputed_values=((7, 2),), + ) + assert chain == [(1, 7)] + + +def test_addition_chain_precomputed_depth_too_large(): # noqa: D103 + chain = add_chain( + 8, + 3, + 2.0, + 0.5, + solver="glucose42", + encoding=1, + thurber=True, + min_size=1, + precomputed_values=((7, 3),), + ) + assert chain == [(1, 1), (2, 2), (4, 4)] + + +def test_addition_chain_precomputed_no_depth_squaring(): # noqa: D103 + chain = add_chain( + 18, + None, + 2.0, + 0.5, + solver="glucose42", + encoding=1, + thurber=True, + min_size=1, + 
def chain_depth(
    chain: List[Tuple[int, int]],
    precomputed_values: Optional[Tuple[Tuple[int, int], ...]] = None,
    modulus: Optional[int] = None,
) -> int:
    """Return the depth of the addition chain.

    The element 1 has depth 0 and every precomputed value starts at the depth it was given.
    Each step `(x, y)` assigns `x + y` a depth that is one more than the deeper of its two
    operands. If a modulus is provided, the operands and the sum are reduced modulo `modulus`
    before they are looked up or stored.

    Parameters:
        chain: The addition chain as a list of steps `(x, y)`, ordered such that both operands of a step are already known.
        precomputed_values: Optional pairs (value, chain_depth) that are available from the start.
        modulus: Optional modulus to reduce the elements of the chain with.

    Raises:
        KeyError: If a step uses an operand that is neither 1, precomputed, nor the result of an earlier step.

    Returns:
        The largest depth that was recorded. This also counts the depths of precomputed values that the chain never uses.
    """
    known_depths = {1: 0}
    if precomputed_values is not None:
        known_depths.update(precomputed_values)

    def reduce(value: int) -> int:
        # Without a modulus the elements are used as they are
        return value if modulus is None else value % modulus

    for left, right in chain:
        deepest_operand = max(known_depths[reduce(left)], known_depths[reduce(right)])
        known_depths[reduce(left + right)] = deepest_operand + 1

    return max(known_depths.values())
+ current_cost = ( + math.ceil(math.log2(current_target)) * squaring_cost + hw(current_target) - 1 + ) + if current_cost < sam_cost: + sam_cost = current_cost + sam_target = target + current_target += modulus - 1 + + # Find the cheapest chain (i.e. no depth constraints) + min_size = size_lower_bound(target) if precomputed_values is None else 1 + if modulus is None: + cheapest_chain = add_chain( + target, + None, + sam_cost, + squaring_cost, + solver, + encoding, + thurber, + min_size, + precomputed_values, + ) + else: + cheapest_chain = add_chain_modp( + target, + modulus, + None, + sam_cost, + squaring_cost, + solver, + encoding, + thurber, + min_size, + precomputed_values, + ) + + # If no cheapest chain is found that satisfies these bounds, then square and multiply had the same cost + if cheapest_chain is None: + sam_chain = [] + for i in range(math.ceil(math.log2(sam_target))): + sam_chain.append((2**i, 2**i)) + previous = 1 + for i in range(math.ceil(math.log2(sam_target))): + if (sam_target >> i) & 1: + sam_chain.append((previous, 2**i)) + previous += 2**i + return [(sam_depth, sam_chain)] + + add_size = len(cheapest_chain) # TODO: Check that this is indeed a valid bound + add_cost = sum(squaring_cost if x == y else 1.0 for x, y in cheapest_chain) + add_depth = chain_depth(cheapest_chain, precomputed_values, modulus=modulus) + + # Go through increasing depth and decrease the previous size, until we reach the cost of square and multiply + pareto_front = [] + current_depth = sam_depth + current_cost = sam_cost + while current_cost > add_cost and current_depth < add_depth: + if modulus is None: + chain = add_chain( + target, + current_depth, + current_cost, + squaring_cost, + solver, + encoding, + thurber, + add_size, + precomputed_values, + ) + else: + chain = add_chain_modp( + target, + modulus, + current_depth, + current_cost, + squaring_cost, + solver, + encoding, + thurber, + add_size, + precomputed_values, + ) + + if chain is not None: + # Add to the Pareto 
def _mul(current_chain: List[Tuple[int, int]], other_chain: List[Tuple[int, int]]) -> None:
    """Compose two index-based chains by appending `other_chain` to `current_chain` in place.

    A chain is a list of steps `(a, b)`, where `a` and `b` are positions in the sequence
    `[1, result of step 0, result of step 1, ...]`. Shifting all positions of `other_chain` by
    the length of `current_chain` makes it start from the last element of `current_chain`, so
    the composed chain ends in the product of the two targets.
    """
    offset = len(current_chain)
    current_chain.extend([(a + offset, b + offset) for a, b in other_chain])


def _chain_with_index(n: int, k: int) -> Tuple[List[Tuple[int, int]], int]:
    """Build a chain for `n` that contains `k`, and also return the position of `k` in it."""
    q, r = divmod(n, k)

    if r in {0, 1}:
        # n = k * q (+ 1): first reach k, then multiply by q, and add 1 if needed
        chain_k = _minchain(k)
        index_k = len(chain_k)
        _mul(chain_k, _minchain(q))
        if r == 1:
            chain_k.append((0, len(chain_k)))
        return chain_k, index_k

    # n = k * q + r: reach k through r, multiply by q, and finally add r.
    # The position of r comes from the recursive call: len(chain_k) is the position of k,
    # so using it for the final addition would compute k * q + k instead of n.
    chain_k, index_r = _chain_with_index(k, r)
    index_k = len(chain_k)
    _mul(chain_k, _minchain(q))
    chain_k.append((index_r, len(chain_k)))
    return chain_k, index_k


def _chain(n, k) -> List[Tuple[int, int]]:
    """Return an index-based addition chain for `n` that contains `k`."""
    return _chain_with_index(n, k)[0]


def _minchain(n: int) -> List[Tuple[int, int]]:
    """Return an index-based addition chain for `n` using a continued-fractions heuristic.

    Powers of two are reached by repeated doubling and 3 is hard-coded. Any other target is
    split around `k = n // 2**(log_n // 2)`, where `log_n` is the index of the highest set bit.
    The result is not guaranteed to be a shortest chain.
    """
    log_n = n.bit_length() - 1
    if n == 1 << log_n:
        return [(i, i) for i in range(log_n)]
    elif n == 3:
        return [(0, 0), (0, 1)]
    else:
        k = n // (1 << (log_n // 2))
        return _chain(n, k)
Optional[int], + squaring_cost: float, + solver: str = "glucose421", + encoding: int = 1, + thurber: bool = True, + precomputed_values: Optional[Tuple[Tuple[int, int], ...]] = None, +) -> List[Tuple[int, int]]: + """Always generates an addition chain for a given target, which is suboptimal if the inputs are too large. + + In some cases, the result is not necessarily optimal. These are the cases where we resort to a heuristic. + This currently happens if: + - The target exceeds 1000. + - The modulus (if provided) exceeds 200. + - MAXSAT_TIMEOUT is not None and a MaxSAT instance timed out + + !!! note + This function is useful for preventing long computation, but the result is not guaranteed to be (close to) optimal. + Unlike `add_chain`, this function will always return an addition chain. + + Parameters: + target: The target integer. + modulus: Modulus to take into account. In an exponentiation chain, this is the modulus in the exponent, i.e. x^target mod p corresponds to `modulus = p - 1`. + squaring_cost: The cost of doubling (squaring), compared to other additions (multiplications), which cost 1.0. + solver: Name of the SAT solver, e.g. "glucose421" for glucose 4.2.1. See: https://pysathq.github.io/docs/html/api/solvers.html. + encoding: The encoding to use for cardinality constraints. See: https://pysathq.github.io/docs/html/api/card.html#pysat.card.EncType. + thurber: Whether to use the Thurber bounds, which provide lower bounds for the elements in the chain. The bounds are ignored when `precomputed_values = True`. + precomputed_values: If there are any precomputed values that can be used for free, they can be specified as a tuple of pairs (value, chain_depth). + + Raises: # noqa: DOC502 + TimeoutError: If the global MAXSAT_TIMEOUT is not None, and it is reached before a maxsat instance could be solved. + + Returns: + An addition chain. 
+ """ + # We want to do better than square and multiply, so we find an upper bound + sam_cost = math.ceil(math.log2(target)) * squaring_cost + hw(target) - 1 + + # Apply CSE to the square & mutliply chain + if precomputed_values is not None: + for exp, depth in precomputed_values: + if (exp & (exp - 1)) == 0 and depth == math.log2(exp): + sam_cost -= squaring_cost + + try: + addition_chain = None + if modulus is not None and modulus <= 200: + addition_chain = add_chain_modp( + target, + modulus, + None, + sam_cost, + squaring_cost, + solver, + encoding, + thurber, + min_size=math.ceil(math.log2(target)) if precomputed_values is None else 1, + precomputed_values=precomputed_values, + ) + elif target <= 1000: + addition_chain = add_chain( + target, + None, + sam_cost, + squaring_cost, + solver, + encoding, + thurber, + min_size=math.ceil(math.log2(target)) if precomputed_values is None else 1, + precomputed_values=precomputed_values, + ) + + if addition_chain is not None: + addition_chain = extract_indices( + addition_chain, precomputed_values=None if precomputed_values is None else list(k for k, _ in precomputed_values), modulus=modulus + ) + except TimeoutError: + # The MaxSAT solver timed out, so we resort to a heuristic + pass + + if addition_chain is None: + # If no other addition chain algorithm has been called or if we could not do better than square and multiply + + # Uses the minchain algorithm from ["Addition chains using continued fractions."][BBBD1989] + # The implementation was adapted from the `addchain` Rust crate (https://github.com/str4d/addchain). + # This algorithm is not optimal: Below 1000 it requires one too many multiplication in 29 cases. 
def hw(n: int) -> int:
    """Returns the Hamming weight of n, i.e. the number of bits that are set."""
    weight = 0
    while n:
        # Inspect the lowest bit and shift it out
        weight += n & 1
        n >>= 1

    return weight


def size_lower_bound(target: int) -> int:
    """Returns a lower bound on the size of the addition chain for this target.

    The bound is the largest of three candidate bounds, which all depend on the
    binary logarithm and the Hamming weight of the target, rounded up.
    """
    log_target = math.log2(target)
    weight = hw(target)

    candidate_bounds = (
        log_target + math.log2(weight) - 2.13,
        log_target,
        log_target + math.log(weight, 3) - 1,
    )
    return math.ceil(max(candidate_bounds))
def add_chain_modp(  # noqa: PLR0913, PLR0917
    target: int,
    modulus: int,
    max_depth: Optional[int],
    strict_cost_max: float,
    squaring_cost: float,
    solver,
    encoding,
    thurber,
    min_size: int,
    precomputed_values: Optional[Tuple[Tuple[int, int], ...]] = None,
) -> Optional[List[Tuple[int, int]]]:
    """Searches for the cheapest addition chain for any of the targets `target + k * modulus` (k = 0, 1, ...).

    Every candidate target is handed to `add_chain`. Whenever a chain is found, its cost becomes the new
    strict upper bound for the remaining candidates, so later chains are only accepted if they are cheaper.

    The `precomputed_values` are optional pairs (value, depth) that were computed before and can be reused
    for free. When they are given, shifted copies `value + i * modulus` are offered to `add_chain` as well.

    Returns:
        If it exists, a minimal addition chain meeting the given constraints and optimization parameters.
    """
    if precomputed_values is None:
        cheapest = None
        cheapest_cost = None
        candidate = target

        while True:
            # Stop once even the (monotonic) lower bound on the cost is too expensive
            if not cost_lower_bound_monotonic(candidate, squaring_cost) < strict_cost_max:
                break
            # Stop once the candidate cannot be reached within the allowed depth
            if max_depth is not None and not math.ceil(math.log2(candidate)) <= max_depth:
                break

            size_floor = max(size_lower_bound(candidate), min_size)
            cost_cap = strict_cost_max if cheapest_cost is None else min(strict_cost_max, cheapest_cost)
            too_expensive = (size_floor * squaring_cost) >= cost_cap

            if not too_expensive:
                found = add_chain(
                    candidate,
                    max_depth,
                    strict_cost_max,
                    squaring_cost,
                    solver,
                    encoding,
                    thurber,
                    size_floor,
                    precomputed_values,
                )

                if found is not None:
                    found_cost = chain_cost(found, squaring_cost)
                    if cheapest_cost is None or found_cost < cheapest_cost:
                        cheapest_cost = found_cost
                        cheapest = found
                        strict_cost_max = min(cheapest_cost, strict_cost_max)

            candidate += modulus

        return cheapest

    # The shortest chain in (t + (k-1)p, t + kp] will have length at least k
    # The cheapest chain in (t + (k-1)p, t + kp] will have cost at least k / sqr_cost
    cheapest = None
    multiple = 0
    while (multiple / squaring_cost) < strict_cost_max:
        # Offer every precomputed value plus 0, 1, ..., `multiple` times the modulus
        shifted_values = tuple(
            (value + shift * modulus, depth)
            for value, depth in precomputed_values
            for shift in range(multiple + 1)
        )

        found = add_chain(
            target + multiple * modulus,
            max_depth,
            strict_cost_max,
            squaring_cost,
            solver,
            encoding,
            thurber,
            min_size=max(min_size, multiple),
            precomputed_values=shifted_values,
        )

        if found is not None:
            strict_cost_max = min(strict_cost_max, chain_cost(found, squaring_cost))
            cheapest = found

        multiple += 1

    return cheapest
# Adapted from: https://stackoverflow.com/questions/16463582/memoize-to-disk-python-persistent-memoization
def cache_to_disk(file_name, ignore_args: Set[str]):
    """This decorator caches the calls to this function in a file on disk, ignoring the arguments listed in `ignore_args`.

    The cache key is a SHA3-256 hash of the string representation of all non-ignored arguments, listed in
    the order of the function's signature. Arguments that the caller omits are filled in with their
    defaults, so `f(1)` and `f(1, b=<default>)` share a single cache entry.

    The shelf at `file_name` is opened when the decorator is created and is never explicitly closed here.
    """
    from functools import wraps

    d = shelve.open(file_name)  # noqa: SIM115

    def decorator(func):
        signature = inspect.signature(func)
        signature_args = list(signature.parameters.keys())
        assert all(arg in signature_args for arg in ignore_args), "ignore_args must be parameters of the function"

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            # Bind the call to the signature so that omitted parameters fall back to their defaults
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            relevant_args = [bound.arguments[name] for name in signature_args if name not in ignore_args]

            h = sha3_256()
            # UTF-8 yields the same bytes as ASCII for ASCII-only representations, so existing keys remain valid
            h.update(str(relevant_args).encode("utf-8"))
            hashed_args = h.hexdigest()

            if hashed_args not in d:
                d[hashed_args] = func(*args, **kwargs)
            return d[hashed_args]

        return wrapped_func

    return decorator  # noqa: DOC201
primes p + primes = list(sieve.primerange(300))[:30] + for sqr_cost in [0.5, 0.75, 1.0]: + print(f"Computing for {sqr_cost}") + + for p in primes: + gen_pareto_front( + p - 1, + modulus=p - 1, + squaring_cost=sqr_cost, + solver="glucose42", + encoding=1, + thurber=True, + ) diff --git a/oraqle/add_chains/solving.py b/oraqle/add_chains/solving.py new file mode 100644 index 0000000..2cb29cd --- /dev/null +++ b/oraqle/add_chains/solving.py @@ -0,0 +1,139 @@ +"""Tools for solving SAT formulations.""" +import math +import signal +from typing import List, Optional, Sequence, Tuple + +from pysat.examples.rc2 import RC2 +from pysat.formula import WCNF + + +def solve(wcnf: WCNF, solver: str, strict_cost_max: Optional[float]) -> Optional[List[int]]: + """This code is adapted from pysat's internal code to stop when we have reached a maximum cost. + + Returns: + A list containing the assignment (where 3 indicates that 3=True and -3 indicates that 3=False), or None if the wcnf is unsatisfiable. + """ + rc2 = RC2(wcnf, solver) + + if strict_cost_max is None: + strict_cost_max = float("inf") + + while not rc2.oracle.solve(assumptions=rc2.sels + rc2.sums): # type: ignore + rc2.get_core() + + if not rc2.core: + # core is empty, i.e. 
def extract_indices(
    sequence: List[Tuple[int, int]],
    precomputed_values: Optional[Sequence[int]] = None,
    modulus: Optional[int] = None,
) -> List[Tuple[int, int]]:
    """Translates an addition chain given as pairs of values into pairs of indices.

    The input x (value 1) has index 0. If n precomputed values are provided, they take the indices 1, ..., n
    in the order in which they are given. Step number s of the chain then receives index s + n + 1.
    When a `modulus` is provided, all values of the chain are reduced modulo `modulus` before they are looked up
    or stored.
    """
    extras = [] if precomputed_values is None else list(precomputed_values)

    position_of = {1: 0}
    position_of.update(zip(extras, range(1, len(extras) + 1)))
    first_step_index = len(extras) + 1

    def reduce(value: int) -> int:
        return value if modulus is None else value % modulus

    steps = []
    for step_index, (left, right) in enumerate(sequence, start=first_step_index):
        # Both operands must have been produced earlier in the chain
        steps.append((position_of[reduce(left)], position_of[reduce(right)]))
        position_of[reduce(left + right)] = step_index

    return steps
+ + Raises: # noqa: DOC502 + TimeoutError: When a timeout occurs (after `timeout_secs` seconds) + + Returns: + A list containing the assignment (where 3 indicates that 3=True and -3 indicates that 3=False), or None if the wcnf is unsatisfiable. + """ + def timeout_handler(s, f): + raise TimeoutError + + # Set the timeout + signal.signal(signal.SIGALRM, timeout_handler) + signal.setitimer(signal.ITIMER_REAL, timeout_secs) + + try: + # TODO: Reduce code duplication: we only changed solve to solve_limited + rc2 = RC2(wcnf, solver) + + if strict_cost_max is None: + strict_cost_max = float("inf") + + while not rc2.oracle.solve_limited(assumptions=rc2.sels + rc2.sums, expect_interrupt=True): # type: ignore + rc2.get_core() + + if not rc2.core: + # core is empty, i.e. hard part is unsatisfiable + signal.setitimer(signal.ITIMER_REAL, 0) + return None + + rc2.process_core() + + if rc2.cost >= strict_cost_max: + signal.setitimer(signal.ITIMER_REAL, 0) + return None + + signal.setitimer(signal.ITIMER_REAL, 0) + rc2.model = rc2.oracle.get_model() # type: ignore + + # Return None if the model could not be solved + if rc2.model is None: + return None + + # Extract the model + if rc2.model is None and rc2.pool.top == 0: + # we seem to have been given an empty formula + # so let's transform the None model returned to [] + rc2.model = [] + + rc2.model = filter(lambda inp: abs(inp) in rc2.vmap.i2e, rc2.model) # type: ignore + rc2.model = map(lambda inp: int(math.copysign(rc2.vmap.i2e[abs(inp)], inp)), rc2.model) + rc2.model = sorted(rc2.model, key=abs) + + return rc2.model + except TimeoutError as err: + raise TimeoutError from err # noqa: DOC501 diff --git a/oraqle/circuits/__init__.py b/oraqle/circuits/__init__.py new file mode 100644 index 0000000..d0582c2 --- /dev/null +++ b/oraqle/circuits/__init__.py @@ -0,0 +1 @@ +"""This package contains example circuits and tools for generating them.""" diff --git a/oraqle/circuits/aes.py b/oraqle/circuits/aes.py new file mode 100644 index 
def encrypt(plaintext: List[Node], key: bytes) -> List[Node]:
    """Returns an AES encryption circuit for a constant `key`.

    The circuit adds the first round key to the plaintext and then performs 9 rounds consisting of
    SubBytes, ShiftRows, MixColumns and AddRoundKey. The state is kept as a list `b` of 16 nodes.

    Parameters:
        plaintext: The nodes representing the plaintext bytes (zipped with the 16-byte round key).
        key: The constant key, which is expanded using `key_schedule`.

    Returns:
        The list of 16 nodes representing the state after the last round.
    """
    # Coefficients used by MixColumns
    mix = [Constant(gf(2)), Constant(gf(3)), Constant(gf(1)), Constant(gf(1))]

    # One list of constant bytes per round key
    round_keys = [[Constant(gf(byte)) for byte in round_key] for round_key in key_schedule(key)]

    def additions(nodes: List[Node]) -> Node:
        # Sums all nodes; this requires at least two nodes because the first two are added up front
        node_iter = iter(nodes)
        out = next(node_iter) + next(node_iter)
        for node in node_iter:
            out += node
        return out

    def sbox(node: Node, method="minchain") -> Node:
        # Computes node^254 (the inverse in GF(2^8)), either through a hardcoded chain or through `Power`
        # NOTE(review): the standard AES S-box also applies an affine map after inversion - confirm that omitting it is intentional
        if method == "hardcoded":
            x2 = node.mul(node, flatten=False)
            x3 = node.mul(x2, flatten=False)
            x6 = x3.mul(x3, flatten=False)
            x12 = x6.mul(x6, flatten=False)
            x15 = x12.mul(x3, flatten=False)
            x30 = x15.mul(x15, flatten=False)
            x60 = x30.mul(x30, flatten=False)
            x63 = x60.mul(x3, flatten=False)
            x126 = x63.mul(x63, flatten=False)
            x127 = node.mul(x126, flatten=False)
            x254 = x127.mul(x127, flatten=False)
            return x254
        elif method == "minchain":
            return Power(node, 254, gf)
        else:
            raise Exception(f"Invalid method: {method}.")

    # AddRoundKey
    b = [round_key + plaintext_byte for round_key, plaintext_byte in zip(round_keys[0], plaintext)]

    # Only round keys 0 up to and including 9 are used
    # NOTE(review): AES-128 normally ends with a 10th round without MixColumns - confirm that it is left out on purpose
    for round in range(9):
        # SubBytes (modular inverse)
        b = [sbox(b[j], method="hardcoded") for j in range(16)]

        # ShiftRows
        b[1], b[5], b[9], b[13] = b[5], b[9], b[13], b[1]
        b[2], b[6], b[10], b[14] = b[10], b[14], b[2], b[6]
        b[3], b[7], b[11], b[15] = b[15], b[3], b[7], b[11]

        # MixColumns
        # NOTE(review): `b[j // 4 + i]` only ever reads b[0] up to b[6], so b[7..15] do not influence the output
        # of this step; a column-wise MixColumns would presumably read b[4 * (j // 4) + i] - verify
        b = [additions([mix[(j + i) % 4] * b[j // 4 + i] for i in range(4)]) for j in range(16)]

        # AddRoundKey
        b = [round_key + b[j] for j, round_key in zip(range(16), round_keys[round + 1])]
        b: List[Node]

    return b
def construct_cardio_elevated_risk_circuit(gf: Type[FieldArray]) -> Node:
    """Returns a variant of the cardio circuit that returns a Boolean indicating whether any risk factor returned true."""
    input_names = (
        "man",
        "woman",
        "smoking",
        "age",
        "diabetic",
        "hbp",
        "cholesterol",
        "weight",
        "height",
        "activity",
        "alcohol",
    )
    # One input node per attribute of the patient
    inputs = {name: Input(name, gf) for name in input_names}

    man = inputs["man"]
    woman = inputs["woman"]
    age = inputs["age"]
    alcohol = inputs["alcohol"]

    risk_factors = [
        man & (age > 50),
        woman & (age > 60),
        inputs["smoking"],
        inputs["diabetic"],
        inputs["hbp"],
        inputs["cholesterol"] < 40,
        inputs["weight"] > (inputs["height"] - 90),  # This might underflow if the modulus is too small
        inputs["activity"] < 30,
        man & (alcohol > 3),
        Neg(man, gf) & (alcohol > 2),
    ]

    return any_(*risk_factors)
def gen_median_circuit(inputs: Sequence[int], gf: Type[FieldArray]):
    """Returns a naive circuit for finding the median value of `inputs`.

    The inputs are sorted using a bubble sort built from conditional swaps (`cswp`), after which the middle
    element is selected. For an even number of inputs, the upper median is returned, i.e. the element at
    index `len(inputs) // 2` of the sorted order.

    Parameters:
        inputs: The identifiers of the inputs; one input node named `Input {v}` is created for each of them.
        gf: The field over which the inputs are defined.

    Raises:
        ValueError: If `inputs` is empty.

    Returns:
        A circuit with the median as its only output.
    """
    outputs = [Input(f"Input {v}", gf) for v in inputs]
    if not outputs:
        raise ValueError("Cannot compute the median of an empty sequence of inputs.")

    # Bubble sort: after each pass, the largest remaining element has moved to position i
    for i in range(len(outputs) - 1, -1, -1):
        for j in range(i):
            outputs[j], outputs[j + 1] = cswp(outputs[j], outputs[j + 1])  # type: ignore

    # For an odd number of inputs this is the middle element, otherwise it is the upper median.
    # (Index `len // 2 + 1` would run out of range for two inputs.)
    return Circuit([outputs[len(outputs) // 2]])
def cswp(lhs: Node, rhs: Node) -> Tuple[Node, Node]:
    """Conditionally swap inputs `lhs` and `rhs` such that `lhs <= rhs`.

    Returns:
        A tuple representing (lower, higher)
    """
    is_smaller = lhs < rhs

    # Selects lhs when lhs < rhs and rhs otherwise
    lower = is_smaller * (lhs - rhs) + rhs
    # The other element is what remains of the sum
    higher = lhs + rhs - lower

    return lower, higher
def gen_veto_voting_circuit(participants: int, gf: Type[FieldArray]):
    """Returns a veto voting circuit between the number of `participants`.

    Every participant contributes one input node, and the single output is the OR over all of them.
    """
    votes = set(Input(f"Input {i}", gf) for i in range(participants))
    any_veto = any_(*votes)
    return Circuit([any_veto])
import Multiplication +from oraqle.compiler.nodes.leafs import Input +from oraqle.compiler.nodes.univariate import UnivariateNode + + +# TODO: Think about the role of Power when there are also Products +class Power(UnivariateNode): + """Represents an exponentiation: x ** constant.""" + + @property + def _node_shape(self) -> str: + return "box" + + @property + def _hash_name(self) -> str: + return f"pow_{self._exponent}" + + @property + def _node_label(self) -> str: + return f"Pow: {self._exponent}" + + def __init__(self, node: Node, exponent: int, gf: Type[FieldArray]): + """Initialize a `Power` node that exponentiates `node` with `exponent`.""" + self._exponent = exponent + super().__init__(node, gf) + + def _operation_inner(self, input: FieldArray, gf: Type[FieldArray]) -> FieldArray: + return input**self._exponent # type: ignore + + def _arithmetize_inner(self, strategy: str) -> "Node": + if strategy == "naive": + # Square & multiply + nodes = [self._node.arithmetize(strategy)] + + for i in range(math.ceil(math.log2(self._exponent))): + nodes.append(nodes[i].mul(nodes[i], flatten=False)) + previous = None + for i in range(math.ceil(math.log2(self._exponent))): + if (self._exponent >> i) & 1: + if previous is None: + previous = nodes[i] + else: + nodes.append(nodes[i].mul(previous, flatten=False)) + previous = nodes[-1] + + assert previous is not None + return previous + + assert strategy == "best-effort" + + addition_chain = add_chain_guaranteed(self._exponent, self._gf.characteristic - 1, squaring_cost=1.0) + + nodes = [self._node.arithmetize(strategy).to_arithmetic()] + + for i, j in addition_chain: + nodes.append(Multiplication(nodes[i], nodes[j], self._gf)) + + return nodes[-1] + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + # TODO: While generating the front, we can take into account the maximum cost etc. 
implied by the depth-aware arithmetization of the operand + if self._gf.characteristic <= 257: + front = gen_pareto_front(self._exponent, self._gf.characteristic, cost_of_squaring) + else: + front = gen_pareto_front(self._exponent, None, cost_of_squaring) + + final_front = CostParetoFront(cost_of_squaring) + + for depth1, _, node in self._node.arithmetize_depth_aware(cost_of_squaring): + for depth2, chain in front: + c = extract_indices( + chain, + modulus=self._gf.characteristic - 1 if self._gf.characteristic <= 257 else None, + ) + + nodes = [node] + + for i, j in c: + nodes.append(Multiplication(nodes[i], nodes[j], self._gf)) + + final_front.add(nodes[-1], depth=depth1 + depth2) + + return final_front + + +def test_depth_aware_arithmetization(): # noqa: D103 + gf = GF(31) + + x = Input("x", gf) + node = Power(x, 30, gf) + front = node.arithmetize_depth_aware(cost_of_squaring=1.0) + node.clear_cache(set()) + + for _, _, n in front: + assert n.evaluate({"x": gf(0)}) == 0 + n.clear_cache(set()) + + for xx in range(1, 31): + assert n.evaluate({"x": gf(xx)}) == 1 diff --git a/oraqle/compiler/arithmetic/subtraction.py b/oraqle/compiler/arithmetic/subtraction.py new file mode 100644 index 0000000..bda0437 --- /dev/null +++ b/oraqle/compiler/arithmetic/subtraction.py @@ -0,0 +1,68 @@ +"""This module contains classes for representing subtraction: x - y.""" +from galois import GF, FieldArray + +from oraqle.compiler.nodes.abstract import CostParetoFront, Node +from oraqle.compiler.nodes.leafs import Constant, Input +from oraqle.compiler.nodes.non_commutative import NonCommutativeBinaryNode + + +class Subtraction(NonCommutativeBinaryNode): + """Represents a subtraction, which can be arithmetized using addition and constant-multiplication.""" + + @property + def _overriden_graphviz_attributes(self) -> dict: + return {"shape": "square", "style": "rounded,filled", "fillcolor": "cornsilk"} + + @property + def _hash_name(self) -> str: + return "sub" + + @property + def 
_node_label(self) -> str: + return "-" + + def _operation_inner(self, x, y) -> FieldArray: + return x - y + + def _arithmetize_inner(self, strategy: str) -> Node: + # TODO: Reorganize the files: let the arithmetic folder only contain pure arithmetic (including add and mul) and move exponentiation elsewhere. + # TODO: For schemes that support subtraction we do not need to do this. We should only do this transformation during the compiler stage. + return (self._left.arithmetize(strategy) + (Constant(-self._gf(1)) * self._right.arithmetize(strategy))).arithmetize(strategy) # type: ignore # TODO: Should we always perform a final arithmetization in every node for constant folding? E.g. in Node? + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + result = self._left + (Constant(-self._gf(1)) * self._right) + front = result.arithmetize_depth_aware(cost_of_squaring) + return front + + +def test_evaluate_mod5(): # noqa: D103 + gf = GF(5) + + a = Input("a", gf) + b = Input("b", gf) + node = Subtraction(a, b, gf) + + assert node.evaluate({"a": gf(3), "b": gf(2)}) == gf(1) + node.clear_cache(set()) + assert node.evaluate({"a": gf(4), "b": gf(1)}) == gf(3) + node.clear_cache(set()) + assert node.evaluate({"a": gf(1), "b": gf(3)}) == gf(3) + node.clear_cache(set()) + assert node.evaluate({"a": gf(0), "b": gf(4)}) == gf(1) + + +def test_evaluate_arithmetized_mod5(): # noqa: D103 + gf = GF(5) + + a = Input("a", gf) + b = Input("b", gf) + node = Subtraction(a, b, gf).arithmetize("best-effort") + node.clear_cache(set()) + + assert node.evaluate({"a": gf(3), "b": gf(2)}) == gf(1) + node.clear_cache(set()) + assert node.evaluate({"a": gf(4), "b": gf(1)}) == gf(3) + node.clear_cache(set()) + assert node.evaluate({"a": gf(1), "b": gf(3)}) == gf(3) + node.clear_cache(set()) + assert node.evaluate({"a": gf(0), "b": gf(4)}) == gf(1) diff --git a/oraqle/compiler/boolean/__init__.py b/oraqle/compiler/boolean/__init__.py new file mode 100644 index 
0000000..e68ad67 --- /dev/null +++ b/oraqle/compiler/boolean/__init__.py @@ -0,0 +1 @@ +"""This package contains nodes for expressing common Boolean operations.""" diff --git a/oraqle/compiler/boolean/bool_and.py b/oraqle/compiler/boolean/bool_and.py new file mode 100644 index 0000000..c172567 --- /dev/null +++ b/oraqle/compiler/boolean/bool_and.py @@ -0,0 +1,743 @@ +"""This module contains tools for evaluating AND operations between many inputs.""" +import itertools +import math +from abc import ABC, abstractmethod +from collections import Counter +from heapq import heapify, heappop, heappush +from typing import Iterable, List, Optional, Sequence, Set, Tuple, Type + +from galois import GF, FieldArray + +from oraqle.add_chains.addition_chains_front import gen_pareto_front +from oraqle.add_chains.addition_chains_mod import chain_cost +from oraqle.add_chains.solving import extract_indices +from oraqle.compiler.boolean.bool_neg import Neg +from oraqle.compiler.comparison.equality import IsNonZero +from oraqle.compiler.nodes.abstract import ( + ArithmeticNode, + CostParetoFront, + Node, + UnoverloadedWrapper, +) +from oraqle.compiler.nodes.arbitrary_arithmetic import ( + _PrioritizedItem, + Product, + Sum, + _generate_multiplication_tree, +) +from oraqle.compiler.nodes.binary_arithmetic import Multiplication +from oraqle.compiler.nodes.flexible import CommutativeUniqueReducibleNode +from oraqle.compiler.nodes.leafs import Constant, Input + + +class And(CommutativeUniqueReducibleNode): + """Performs an AND operation over several operands. 
The user must ensure that the operands are Booleans.""" + + @property + def _hash_name(self) -> str: + return "and" + + @property + def _node_label(self) -> str: + return "AND" + + def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray: + return self._gf(bool(a) & bool(b)) + + def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912 + new_operands: Set[UnoverloadedWrapper] = set() + for operand in self._operands: + new_operand = operand.node.arithmetize(strategy) + + if isinstance(new_operand, Constant): + if not bool(new_operand._value): + return Constant(self._gf(0)) + continue + + new_operands.add(UnoverloadedWrapper(new_operand)) + + if len(new_operands) == 0: + return Constant(self._gf(1)) + elif len(new_operands) == 1: + return next(iter(new_operands)).node + + if strategy == "naive": + return Product(Counter({operand: 1 for operand in new_operands}), self._gf).arithmetize( + strategy + ) + + # TODO: Calling to_arithmetic here should not be necessary if we can decide the predicted depth + queue = [ + ( + _PrioritizedItem( + 0, operand.node + ) # TODO: We should just maybe make a breadth method on Node + if isinstance(operand.node, Constant) + else _PrioritizedItem( + operand.node.to_arithmetic().multiplicative_depth(), operand.node + ) + ) + for operand in new_operands + ] + heapify(queue) + + while len(queue) > (self._gf._characteristic - 1): + total_sum = None + max_depth = None + for _ in range(self._gf._characteristic - 1): + if len(queue) == 0: + break + + popped = heappop(queue) + if max_depth is None or max_depth < popped.priority: + max_depth = popped.priority + + if total_sum is None: + total_sum = Neg(popped.item, self._gf) + else: + total_sum += Neg(popped.item, self._gf) + + assert total_sum is not None + final_result = Neg(IsNonZero(total_sum, self._gf), self._gf).arithmetize(strategy) + + assert max_depth is not None + heappush(queue, _PrioritizedItem(max_depth, final_result)) + + if len(queue) == 1: + return 
heappop(queue).item + + dummy_node = Input("dummy_node", self._gf) + is_non_zero = IsNonZero(dummy_node, self._gf).arithmetize(strategy).to_arithmetic() + cost = is_non_zero.multiplicative_cost( + 1.0 + ) # FIXME: This needs to be the actual squaring cost + + if len(queue) - 1 < cost: + return Product( + Counter({UnoverloadedWrapper(operand.item): 1 for operand in queue}), self._gf + ).arithmetize(strategy) + + return Neg( + IsNonZero( + Sum( + Counter({UnoverloadedWrapper(Neg(node.item, self._gf)): 1 for node in queue}), + self._gf, + ), + self._gf, + ), + self._gf, + ).arithmetize(strategy) + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + new_operands: Set[CostParetoFront] = set() + for operand in self._operands: + new_operand = operand.node.arithmetize_depth_aware(cost_of_squaring) + new_operands.add(new_operand) + + if len(new_operands) == 0: + return CostParetoFront.from_leaf(Constant(self._gf(1)), cost_of_squaring) + elif len(new_operands) == 1: + return next(iter(new_operands)) + + front = CostParetoFront(cost_of_squaring) + + # TODO: This is brute force composition + for operands in itertools.product(*(iter(new_operand) for new_operand in new_operands)): + checked_operands = [] + for depth, cost, node in operands: + if isinstance(node, Constant): + assert int(node._value) in {0, 1} + if node._value == 0: + return CostParetoFront.from_leaf(Constant(self._gf(0)), cost_of_squaring) + else: + checked_operands.append((depth, cost, node)) + + if len(checked_operands) == 0: + return CostParetoFront.from_leaf(Constant(self._gf(1)), cost_of_squaring) + + if len(checked_operands) == 1: + depth, cost, node = checked_operands[0] + front.add(node, depth, cost) + continue + + this_front = _find_depth_cost_front( + checked_operands, + self._gf, + float("inf"), + squaring_cost=cost_of_squaring, + is_and=True, + ) + front.add_front(this_front) + + return front + + def and_flatten(self, other: Node) -> Node: + """Performs an AND 
operation with `other`, flattening the `And` node if either of the two is also an `And` and absorbing `Constant`s. + + Returns: + An `And` node containing the flattened AND operation, or a `Constant` node. + """ + if isinstance(other, Constant): + if bool(other._value): + return self + else: + return Constant(self._gf(0)) + + if isinstance(other, And): + return And(self._operands | other._operands, self._gf) + + new_operands = self._operands.copy() + new_operands.add(UnoverloadedWrapper(other)) + return And(new_operands, self._gf) + + +def test_evaluate_mod3(): # noqa: D103 + gf = GF(3) + + a = Input("a", gf) + b = Input("b", gf) + node = (a & b).arithmetize("best-effort") + + assert node.evaluate({"a": gf(0), "b": gf(0)}) == gf(0) + node.clear_cache(set()) + assert node.evaluate({"a": gf(0), "b": gf(1)}) == gf(0) + node.clear_cache(set()) + assert node.evaluate({"a": gf(1), "b": gf(0)}) == gf(0) + node.clear_cache(set()) + assert node.evaluate({"a": gf(1), "b": gf(1)}) == gf(1) + + +def test_evaluate_arithmetized_mod3(): # noqa: D103 + gf = GF(3) + + a = Input("a", gf) + b = Input("b", gf) + node = (a & b).arithmetize("best-effort") + + node.clear_cache(set()) + assert node.evaluate({"a": gf(0), "b": gf(0)}) == gf(0) + node.clear_cache(set()) + assert node.evaluate({"a": gf(0), "b": gf(1)}) == gf(0) + node.clear_cache(set()) + assert node.evaluate({"a": gf(1), "b": gf(0)}) == gf(0) + node.clear_cache(set()) + assert node.evaluate({"a": gf(1), "b": gf(1)}) == gf(1) + + +def test_evaluate_arithmetized_depth_aware_mod2(): # noqa: D103 + gf = GF(2) + + a = Input("a", gf) + b = Input("b", gf) + node = a & b + front = node.arithmetize_depth_aware(cost_of_squaring=1.0) + + for _, _, n in front: + n.clear_cache(set()) + assert n.evaluate({"a": gf(0), "b": gf(0)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({"a": gf(0), "b": gf(1)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({"a": gf(1), "b": gf(0)}) == gf(0) + n.clear_cache(set()) + assert 
n.evaluate({"a": gf(1), "b": gf(1)}) == gf(1) + + +def test_evaluate_arithmetized_depth_aware_mod3(): # noqa: D103 + gf = GF(3) + + a = Input("a", gf) + b = Input("b", gf) + node = a & b + front = node.arithmetize_depth_aware(cost_of_squaring=1.0) + + for _, _, n in front: + n.clear_cache(set()) + assert n.evaluate({"a": gf(0), "b": gf(0)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({"a": gf(0), "b": gf(1)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({"a": gf(1), "b": gf(0)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({"a": gf(1), "b": gf(1)}) == gf(1) + + +def test_evaluate_arithmetized_depth_aware_7_mod5(): # noqa: D103 + gf = GF(5) + + xs = {Input(f"x{i}", gf) for i in range(7)} + node = And({UnoverloadedWrapper(x) for x in xs}, gf) # type: ignore + front = node.arithmetize_depth_aware(cost_of_squaring=1.0) + + for _, _, n in front: + n.clear_cache(set()) + assert n.evaluate({f"x{i}": gf(0) for i in range(50)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({f"x{i}": gf(i % 2) for i in range(50)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({f"x{i}": gf(1) for i in range(50)}) == gf(1) + + +def test_evaluate_arithmetized_depth_aware_50_mod31(): # noqa: D103 + gf = GF(31) + + xs = {Input(f"x{i}", gf) for i in range(50)} + node = And({UnoverloadedWrapper(x) for x in xs}, gf) # type: ignore + front = node.arithmetize_depth_aware(cost_of_squaring=1.0) + + for _, _, n in front: + n.clear_cache(set()) + assert n.evaluate({f"x{i}": gf(0) for i in range(50)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({f"x{i}": gf(i % 2) for i in range(50)}) == gf(0) + n.clear_cache(set()) + assert n.evaluate({f"x{i}": gf(1) for i in range(50)}) == gf(1) + + +class NaryLogicNode(ABC): + """Represents a (sub)circuit for computing and AND or OR operation.""" + + def __init__(self, breadth: int, cost: float) -> None: + """Initialize this logic node with the given `breadth` and `cost` (which are not checked).""" + self.breadth = breadth + 
self.cost = cost + + @abstractmethod + def local_cost(self) -> float: + """Compute the local multiplicative cost, so ignoring the cost of the inputs.""" + + @abstractmethod + def print(self, level: int = 0): + """Prints this subcircuit for debugging purposes.""" + + @abstractmethod + def to_arithmetic_node(self, is_and: bool, gf: Type[FieldArray]) -> ArithmeticNode: + """Returns an `ArithmeticNode` representing this logic node (AND if `is_and = True` else OR).""" + + +class InputNaryLogicNode(NaryLogicNode): + """An input logic node.""" + + def __init__(self, node: ArithmeticNode, breadth: int) -> None: + """Initialize the input node with the given `breadth`.""" + self._node = node + super().__init__(breadth, 0.0) + + def local_cost(self) -> float: # noqa: D102 + return 0.0 + + def print(self, level: int = 0): # noqa: D102 + print(" " * level + "x") + + def to_arithmetic_node(self, is_and: bool, gf: Type[FieldArray]) -> ArithmeticNode: # noqa: D102 + return self._node + + +class ProductNaryLogicNode(NaryLogicNode): + """A `ProductNaryLogicNode` represents an OR/AND (sub)circuit in which all inputs are multiplied (and flattened).""" + + def __init__(self, operands: List[NaryLogicNode], breadth: int) -> None: + """Initialize a product subcircuit with the given `operands` and `breadth`.""" + # Merge subproducts into this product + self._operands = list( + itertools.chain.from_iterable( + operand._operands if isinstance(operand, ProductNaryLogicNode) else [operand] + for operand in operands + ) + ) + self._arithmetic_node = None + self._is_and = None + super().__init__(breadth, self._compute_cost()) + + def _compute_cost(self) -> float: + return sum(op.cost for op in self._operands) + len(self._operands) - 1 + + def local_cost(self) -> float: # noqa: D102 + return len(self._operands) - 1 + + def print(self, level: int = 0): # noqa: D102 + print(" " * level + "prod:") + for op in self._operands: + op.print(level + 1) + + def to_arithmetic_node(self, is_and: bool, gf: 
Type[FieldArray]) -> ArithmeticNode: # noqa: D102 + if self._is_and is not None and self._is_and != is_and: + self._arithmetic_node = None + + if self._arithmetic_node is None: + _, result = _generate_multiplication_tree(((math.ceil(math.log2(operand.breadth)), operand.to_arithmetic_node(is_and, gf) if is_and else Neg(operand.to_arithmetic_node(is_and, gf), gf).arithmetize("best-effort").to_arithmetic()) for operand in self._operands), (1 for _ in range(len(self._operands)))) # type: ignore + + if not is_and: + result = Neg(result, gf) + + self._arithmetic_node = result.arithmetize( + "best-effort" + ).to_arithmetic() # TODO: This could be more elegant + self._is_and = is_and + + assert math.ceil(math.log2(self.breadth)) == self._arithmetic_node.multiplicative_depth() # type: ignore + return self._arithmetic_node + + +class SumReduceNaryLogicNode(NaryLogicNode): + """A `SumReduceNaryLogicNode` represents an OR/AND (sub)circuit in which all inputs are summed and then reduced to a Boolean.""" + + def __init__( + self, + operands: List[NaryLogicNode], + exponentiation_depth: int, + exponentiation_cost: float, + exponentiation_chain: List[Tuple[int, int]], + breadth: int, + ) -> None: + """Initialize a sum-reduce subcircuit with the given exponentiation chain (and properties), over the given `operands`.""" + self._operands = operands + self._exponentiation_depth = exponentiation_depth + self._exponentiation_cost = exponentiation_cost + self._exponentiation_chain = exponentiation_chain + self._arithmetic_node = None + self._is_and = None + super().__init__(breadth, self._compute_cost()) + + def _compute_cost(self) -> float: + return sum(op.cost for op in self._operands) + self._exponentiation_cost + + def local_cost(self) -> float: # noqa: D102 + return self._exponentiation_cost + + def print(self, level: int = 0): # noqa: D102 + print(" " * level + f"sumred({self._exponentiation_depth}, {self._exponentiation_cost}):") + for op in self._operands: + op.print(level + 1) + 
+ def to_arithmetic_node(self, is_and: bool, gf: Type[FieldArray]) -> ArithmeticNode: # noqa: D102 + if self._is_and is not None and self._is_and != is_and: + self._arithmetic_node = None + + if self._arithmetic_node is None: + # TODO: This should be replaced by augmented circuit nodes + if is_and: + result = ( + Sum( + Counter( + { + UnoverloadedWrapper( + Neg(operand.to_arithmetic_node(is_and, gf), gf) + ): 1 + for operand in self._operands + } + ), + gf, + ) + .arithmetize("best-effort") + .to_arithmetic() + ) + else: + result = ( + Sum( + Counter( + { + UnoverloadedWrapper(operand.to_arithmetic_node(is_and, gf)): 1 + for operand in self._operands + } + ), + gf, + ) + .arithmetize("best-effort") + .to_arithmetic() + ) + + # Exponentiation + chain = extract_indices(self._exponentiation_chain, modulus=gf.characteristic - 1) + nodes = [result] + for i, j in chain: + nodes.append(Multiplication(nodes[i], nodes[j], gf)) # type: ignore + result = nodes[-1] + + if is_and: + result = Neg(result, gf).arithmetize("best-effort") + + self._arithmetic_node = result.to_arithmetic() # TODO: This could be more elegant + self._is_and = is_and + + assert math.ceil(math.log2(self.breadth)) == self._arithmetic_node.multiplicative_depth() # type: ignore + return self._arithmetic_node + + +def _minimum_cost(operand_count: int, exponentiation_cost: float, p: int) -> float: + r = math.ceil((p - 1 - operand_count) / (2 - p)) + return r * exponentiation_cost + min(exponentiation_cost, operand_count + r * (2 - p) - 1) + + +def _find_depth_cost_front( + operands: Sequence[Tuple[int, float, ArithmeticNode]], + gf: Type[FieldArray], + strict_cost_upper: float, + squaring_cost: float, + is_and: bool, +) -> CostParetoFront: + new_operands: List[NaryLogicNode] = [ + InputNaryLogicNode(node, 0 if isinstance(node, Constant) else 2**depth) + for depth, _, node in operands + ] + + circuits = minimize_depth_cost( + new_operands, gf.characteristic, strict_cost_upper, squaring_cost + ) + + front = 
CostParetoFront(squaring_cost) + for depth, _, node in circuits: + front.add(node.to_arithmetic_node(is_and, gf), depth) + + return front + + +# TODO: This is copied from arbitrary_arithmetic.py +def _generate_sumred_tree( + operands: Iterable[Tuple[int, InputNaryLogicNode]], + squaring_cost: float, +) -> Tuple[int, SumReduceNaryLogicNode]: + queue = [_PrioritizedItem(*operand) for operand in operands] + heapify(queue) + + while len(queue) > 1: + a = heappop(queue) + b = heappop(queue) + + depth = max(a.priority, b.priority) + 1 + heappush( + queue, + _PrioritizedItem( + depth, + SumReduceNaryLogicNode([a.item, b.item], 2, squaring_cost, [(1, 1)], 2**depth), + ), + ) + + return (queue[0].priority, queue[0].item) + + +def minimize_depth_cost( + operands: List[NaryLogicNode], p: int, strict_cost_upper: float, squaring_cost: float +) -> List[Tuple[int, float, NaryLogicNode]]: + """Finds the depth-cost Pareto front. + + Returns: + A front in the form of a list of tuples containing (depth, cost, node). 
+ """ + assert len(operands) >= 2 + + if p == 2: + result = ProductNaryLogicNode( + operands, breadth=sum(operand.breadth for operand in operands) + ) + return [(math.ceil(math.log2(result.breadth)), result.cost, result)] + + if p == 3: + depth, result = _generate_sumred_tree([(math.ceil(math.log2(operand.breadth)), operand) for operand in operands], squaring_cost) # type: ignore + return [(depth, result.cost, result)] + + sorted_operands = sorted(operands, key=lambda op: op.breadth, reverse=True) + depth_limit = math.ceil(math.log2(sorted_operands[0].breadth)) # + 1 + + front = gen_pareto_front(p - 1, p, squaring_cost) + exponentiation_specs = [ + (depth, chain_cost(chain, squaring_cost), chain) for depth, chain in front + ] + _, cheapest_exponentiation_cost, _ = exponentiation_specs[-1] + + mincost = _minimum_cost(len(sorted_operands), cheapest_exponentiation_cost, p) + + circuits = [] + while True: + breadth_limit = 2**depth_limit + result = minimize_depth_cost_recursive( + sorted_operands, breadth_limit, exponentiation_specs, p, strict_cost_upper + ) + + if result is None: + depth_limit += 1 + continue + + assert result.cost >= mincost + assert result.cost < strict_cost_upper, f"{result.cost} >= {strict_cost_upper}" + + if result.cost == mincost: + circuits.append((depth_limit, result.cost, result)) + return circuits + + circuits.append((depth_limit, result.cost, result)) + strict_cost_upper = result.cost + + # TODO: If we want to return the minimum breadth we have to increment at a higher resolution + depth_limit += 1 + + +def _find_index_breadth(sorted_operands: List[NaryLogicNode], greater_or_equal_to: int) -> int: + for i in range(len(sorted_operands)): + if sorted_operands[i].breadth < greater_or_equal_to: + return i + + return len(sorted_operands) + + +def _insert(sorted_operands: List[NaryLogicNode], node: NaryLogicNode): + for i in range(len(sorted_operands)): + if sorted_operands[i].breadth < node.breadth: + sorted_operands.insert(i, node) + return + + 
sorted_operands.append(node) + + +def minimize_depth_cost_recursive( # noqa: PLR0912, PLR0914, PLR0915 + sorted_operands: List[NaryLogicNode], + breadth_limit: int, + exponentiation_specs: List[Tuple[int, float, List[Tuple[int, int]]]], + p: int, + strict_cost_upper: float, +) -> Optional[NaryLogicNode]: + """Find a minimum-depth circuit for the given `breadth_limit` and `strict_cost_upper` bound. + + Operands must be sorted from deep to shallow. + Returns the lowest-cost circuit for the given depth. + The exponentiation_specs must be sorted from high-cost to low-cost. + + Returns: + A minimum-depth circuit in the form of an `NaryLogicNode` satisfying the constraints, or None if the constraints cannot be satisfied. + """ + if len(sorted_operands) == 1: + if breadth_limit >= sorted_operands[0].breadth and strict_cost_upper > 0: + assert sorted_operands[0].cost < strict_cost_upper + return sorted_operands[0] + return None + + # If the breadth limit is exceeded, stop + if breadth_limit < 1: + return None + + # If the cost limit is exceeded, stop + if strict_cost_upper <= 0: + return None + + # If the lower bound for the cost exceeds the limit, also stop + _, cheapest_exponentiation_cost, _ = exponentiation_specs[-1] + lower_bound_cost = _minimum_cost(len(sorted_operands), cheapest_exponentiation_cost, p) + if lower_bound_cost >= strict_cost_upper: + return None + + output = None + + for exponentiation_depth, exponentiation_cost, exponentiation_chain in exponentiation_specs: + # We do not call .cost() in this algorithm because we only consider the cost of the AND/OR subcircuit + + type_2_limit = 2 ** (math.ceil(math.log2(breadth_limit)) - exponentiation_depth) + if len(sorted_operands) < p: + if ( + all(operand.breadth <= type_2_limit for operand in sorted_operands) + and exponentiation_cost < strict_cost_upper + ): + # Use a type-2 arithmetization + depth = math.ceil(math.log2(sorted_operands[0].breadth)) + exponentiation_depth + output = SumReduceNaryLogicNode( + 
sorted_operands, + exponentiation_depth, + exponentiation_cost, + exponentiation_chain, + breadth=2**depth, + ) + strict_cost_upper = exponentiation_cost + + if (tot := sum(op.breadth for op in sorted_operands)) <= breadth_limit and len( + sorted_operands + ) - 1 < strict_cost_upper: + output = ProductNaryLogicNode(sorted_operands, breadth=tot) + strict_cost_upper = len(sorted_operands) - 1 + + continue + + # At this point, we know that len(sorted_operands) >= p + + # Try a type-1 arithmetization, so no type-2 at all + if (tot := sum(op.breadth for op in sorted_operands)) <= breadth_limit and len( + sorted_operands + ) - 1 < strict_cost_upper: + output = ProductNaryLogicNode(sorted_operands, breadth=tot) + strict_cost_upper = len(sorted_operands) - 1 + + reduced = all(operand.breadth <= type_2_limit for operand in sorted_operands) + if reduced: + if exponentiation_cost >= strict_cost_upper: + continue + + # Use a type-2 arithmetization on operands of decreasing depth + cache = set() + for i in range(len(sorted_operands) - 1): + selected_operands = sorted_operands[i : (i + p - 1)] + breadths = tuple(operand.breadth for operand in selected_operands) + if breadths in cache: + continue + cache.add(breadths) + + depth = math.ceil(math.log2(selected_operands[0].breadth)) + exponentiation_depth + + new_operands = sorted_operands[:i] + if i + p - 1 < len(sorted_operands): + new_operands += sorted_operands[i + p - 1 :] + + breadth = 2**depth + sum_red = SumReduceNaryLogicNode( + sorted_operands[i : i + p - 1], + exponentiation_depth, + exponentiation_cost, + exponentiation_chain, + breadth=breadth, + ) + _insert(new_operands, sum_red) + + potential_output = minimize_depth_cost_recursive( + new_operands, + breadth_limit, + exponentiation_specs, + p, + strict_cost_upper - exponentiation_cost, + ) + if potential_output is not None: + output = potential_output + strict_cost_upper -= potential_output.local_cost() + else: + # Isolate all the operands that cannot use a type-2 
arithmetization + first_small_index = _find_index_breadth(sorted_operands, type_2_limit) + large_operands = sorted_operands[:first_small_index] + small_operands = sorted_operands[first_small_index:] + + # If there are no small operands, then this arithmetization is not possible + if len(small_operands) == 0: + continue + + # Use a type-1 arithmetization for large_operands + assert len(large_operands) > 0 + cost = len( + large_operands + ) # Not -1 because we also need a multiplication with the AND/OR of small_operands + if cost >= strict_cost_upper: + continue + + breadth = sum(operand.breadth for operand in large_operands) + new_breadth_limit = breadth_limit - breadth + + sub_output = minimize_depth_cost_recursive( + small_operands, new_breadth_limit, exponentiation_specs, p, strict_cost_upper - cost + ) + if sub_output is not None: + output = ProductNaryLogicNode( + [*large_operands, sub_output], breadth=breadth + sub_output.breadth + ) + strict_cost_upper -= output.local_cost() + + return output + + +def all_(*operands: Node) -> And: + """Returns an `And` node that evaluates to true if any of the given `operands` evaluates to true.""" + assert len(operands) > 0 + return And(set(UnoverloadedWrapper(operand) for operand in operands), operands[0]._gf) diff --git a/oraqle/compiler/boolean/bool_neg.py b/oraqle/compiler/boolean/bool_neg.py new file mode 100644 index 0000000..ac9ede5 --- /dev/null +++ b/oraqle/compiler/boolean/bool_neg.py @@ -0,0 +1,37 @@ +"""Classes for describing Boolean negation.""" +from galois import FieldArray + +from oraqle.compiler.arithmetic.subtraction import Subtraction +from oraqle.compiler.nodes.abstract import CostParetoFront, Node +from oraqle.compiler.nodes.leafs import Constant +from oraqle.compiler.nodes.univariate import UnivariateNode + + +class Neg(UnivariateNode): + """A node that negates a Boolean input.""" + + @property + def _node_shape(self) -> str: + return "box" + + @property + def _hash_name(self) -> str: + return "neg" + + 
class Or(CommutativeUniqueReducibleNode):
    """Performs an OR operation over several operands. The user must ensure that the operands are Booleans."""

    @property
    def _hash_name(self) -> str:
        return "or"

    @property
    def _node_label(self) -> str:
        return "OR"

    def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
        return self._gf(bool(a) | bool(b))

    def _arithmetize_inner(self, strategy: str) -> Node:
        # De Morgan: OR(x_1, ..., x_n) = NOT(AND(NOT x_1, ..., NOT x_n))
        # FIXME: Handle what happens when arithmetize outputs a constant!
        # TODO: Also consider the arithmetization using randomness
        return Neg(
            And(
                {
                    UnoverloadedWrapper(Neg(operand.node.arithmetize(strategy), self._gf))
                    for operand in self._operands
                },
                self._gf,
            ),
            self._gf,
        ).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        """Build the depth/cost Pareto front of this OR from the fronts of its operands.

        Constant operands are folded using the OR rules (the same rules that `or_flatten` applies):
        a constant 1 makes the whole OR equal to 1, and a constant 0 is dropped.
        """
        # TODO: This is mostly copied from AND
        new_operands: Set[CostParetoFront] = {
            operand.node.arithmetize_depth_aware(cost_of_squaring) for operand in self._operands
        }

        if len(new_operands) == 0:
            # An OR over no operands is false (0 is the identity of OR)
            return CostParetoFront.from_leaf(Constant(self._gf(0)), cost_of_squaring)
        elif len(new_operands) == 1:
            return next(iter(new_operands))

        # TODO: We can check if any of the element in new_operands are constants and return early

        front = CostParetoFront(cost_of_squaring)

        # TODO: This is brute force composition
        for operands in itertools.product(*(iter(new_operand) for new_operand in new_operands)):
            checked_operands = []
            for depth, cost, node in operands:
                if isinstance(node, Constant):
                    assert node._value in {0, 1}
                    if node._value == 1:
                        # A true operand absorbs the entire OR
                        return CostParetoFront.from_leaf(Constant(self._gf(1)), cost_of_squaring)
                    # A false operand does not influence the OR, so it is dropped
                else:
                    checked_operands.append((depth, cost, node))

            if len(checked_operands) == 0:
                # Every operand was the constant 0
                return CostParetoFront.from_leaf(Constant(self._gf(0)), cost_of_squaring)

            if len(checked_operands) == 1:
                depth, cost, node = checked_operands[0]
                front.add(node, depth, cost)
                continue

            this_front = _find_depth_cost_front(
                checked_operands,
                self._gf,
                float("inf"),
                squaring_cost=cost_of_squaring,
                is_and=False,
            )
            front.add_front(this_front)

        return front

    def or_flatten(self, other: Node) -> Node:
        """Performs an OR operation with `other`, flattening the `Or` node if either of the two is also an `Or` and absorbing `Constant`s.

        Returns:
            An `Or` node containing the flattened OR operation, or a `Constant` node.
        """
        if isinstance(other, Constant):
            if bool(other._value):
                return Constant(self._gf(1))
            else:
                return self

        if isinstance(other, Or):
            return Or(self._operands | other._operands, self._gf)

        new_operands = self._operands.copy()
        new_operands.add(UnoverloadedWrapper(other))
        return Or(new_operands, self._gf)
class Circuit:
    """Represents a circuit over a fixed finite field that can be turned into an arithmetic circuit. Behind the scenes this is a directed acyclic graph (DAG). The circuit only has references to the outputs."""

    def __init__(self, outputs: List[Node]):
        """Initialize a circuit with the given `outputs`."""
        assert len(outputs) > 0
        self._outputs = outputs
        self._gf = outputs[0]._gf

    def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> List[FieldArray]:
        """Evaluates the circuit with the given named inputs.

        This function does not error if it is given more inputs than necessary, but it will error if one is missing.

        Returns:
            Evaluated output in plain text.
        """
        assert all(isinstance(value, self._gf) for value in actual_inputs.values())

        actual_outputs = [output.evaluate(actual_inputs) for output in self._outputs]
        self._clear_cache()

        return actual_outputs

    def to_graph(self, file_name: str):
        """Saves a DOT file representing the circuit as a graph at the given `file_name`."""
        graph_builder = DotFile()

        for output in self._outputs:
            graph_builder.add_link(
                output.to_graph(graph_builder),
                graph_builder.add_node(label="Output", shape="plain"),
            )
        self._clear_cache()

        graph_builder.to_file(file_name)

    def to_pdf(self, file_name: str):
        """Saves a PDF file representing the circuit as a graph at the given `file_name`.

        The intermediate DOT file lives in a temporary directory that is removed afterwards.
        This invokes the `dot` executable and raises `subprocess.CalledProcessError` if it fails.
        """
        import os

        # A temporary directory (instead of an open NamedTemporaryFile) avoids re-opening an
        # already open file by name and guarantees that the DOT file is cleaned up.
        with tempfile.TemporaryDirectory() as temp_dir:
            dot_path = os.path.join(temp_dir, "circuit.dot")
            self.to_graph(dot_path)

            subprocess.run(["dot", "-Tpdf", dot_path, "-o", file_name], check=True)

    def display_graph(self, metadata: Optional[dict] = None):
        """Displays the circuit in a Python notebook."""
        import os

        with tempfile.TemporaryDirectory() as temp_dir:
            dot_path = os.path.join(temp_dir, "circuit.dot")
            self.to_graph(dot_path)

            with open(dot_path, encoding="utf8") as file:
                file_content = file.read()

        import graphviz
        from IPython.display import display_png

        src = graphviz.Source(file_content)
        display_png(src, metadata=metadata)

    def eliminate_subexpressions(self):
        """Perform semantic common subexpression elimination on all outputs."""
        for output in self._outputs:
            output.eliminate_common_subexpressions({})

    def is_equivalent(self, other: object) -> bool:
        """Returns whether the two circuits are semantically equivalent.

        False positives do not occur but false negatives do.
        """
        if not isinstance(other, self.__class__):
            return False

        # zip() stops at the shortest list, so differing output counts must be rejected explicitly
        if len(self._outputs) != len(other._outputs):
            return False

        return all(out1.is_equivalent(out2) for out1, out2 in zip(self._outputs, other._outputs))

    def arithmetize(self, strategy: str = "best-effort") -> "ArithmeticCircuit":
        """Arithmetizes this circuit by calling arithmetize on all outputs.

        This replaces all high-level operations with arithmetic operations (constants, additions, and multiplications).
        The current implementation only aims at reducing the total number of multiplications.

        Returns:
            An equivalent arithmetic circuit with low multiplicative size.
        """
        arithmetic_circuit = ArithmeticCircuit(
            [output.arithmetize(strategy).to_arithmetic() for output in self._outputs]
        )
        arithmetic_circuit._clear_cache()

        return arithmetic_circuit

    def arithmetize_depth_aware(
        self, cost_of_squaring: float = 1.0
    ) -> List[Tuple[int, int, "ArithmeticCircuit"]]:
        """Perform depth-aware arithmetization on this circuit.

        !!! failure
            The current implementation only supports circuits with a single output.

        This function replaces high-level nodes with arithmetic operations (constants, additions, and multiplications).

        Returns:
            A list with tuples containing the multiplicative depth, the multiplicative cost, and the generated arithmetization from low to high depth.
        """
        assert len(self._outputs) == 1
        assert cost_of_squaring <= 1.0

        front = []
        for depth, size, node in self._outputs[0].arithmetize_depth_aware(cost_of_squaring):
            arithmetic_circuit = ArithmeticCircuit([node])
            # Each generated circuit is cleared right away, so no clean-up is needed after the loop
            arithmetic_circuit._clear_cache()
            front.append((depth, size, arithmetic_circuit))

        return front

    def _clear_cache(self):
        # The set is shared between outputs so that shared subcircuits are only visited once
        already_cleared = set()
        for output in self._outputs:
            output.clear_cache(already_cleared)
of multiplications) of the circuit.""" + multiplications = set().union(*(output.multiplications() for output in self._outputs)) + size = len(multiplications) + + return size + + def multiplicative_cost(self, cost_of_squaring: float) -> float: + """Returns the multiplicative cost of the circuit.""" + multiplications = set().union(*(output.multiplications() for output in self._outputs)) + squarings = set().union(*(output.squarings() for output in self._outputs)) + cost = len(multiplications) - len(squarings) + cost_of_squaring * len(squarings) + + return cost + + def generate_program(self) -> ArithmeticProgram: + """Returns an arithmetic program for this arithmetic circuit.""" + # Reset the parent counts + for output in self._outputs: + output.reset_parent_count() + + # Count the parents + for output in self._outputs: + output.count_parents() + + # Reset the cache for instruction writing + self._clear_cache() + + # Write the instructions + instructions = [] + stack_occupied = [] + + stack_counter = 0 + for output in self._outputs: + output_index, stack_counter = output.create_instructions( + instructions, stack_counter, stack_occupied + ) + instructions.append(OutputInstruction(output_index)) + + # Reset the cache for future operations + self._clear_cache() + + return ArithmeticProgram(instructions, len(stack_occupied), self._gf) + + def summands_between_multiplications(self) -> int: + """Computes the maximum number of summands between two consecutive multiplications in this circuit. + + !!! 
failure + This currently returns the hardcoded value 10 + + Returns: + The highest number of summands between two consecutive multiplications + """ + # FIXME: This is currently hardcoded + return 10 + + def _generate_helib_params(self) -> Tuple[str, Tuple[int, int, int, int]]: + # Returns the code, along with (m, r, bits, c) + multiplicative_depth = self.multiplicative_depth() + summands_between_mults = self.summands_between_multiplications() + + # This code is adapted from fhegen: https://github.com/Crypto-TII/fhegen + # It was written by Johannes Mono, Chiara Marcolla, Georg Land, Tim Gรผneysu, and Najwa Aaraj + + ops = { + "model": "OpenFHE", + "muls": multiplicative_depth + 1, + "const": True, + "rots": 0, + "sums": summands_between_mults, + } + + sdist = "Ternary" + sigma = 3.19 + ve = sigma * sigma + vs = {"Ternary": 2 / 3, "Error": ve}[sdist] + b_args = { + "m": 4, + "t": self._gf.characteristic, + "D": 6, + "Vs": vs, + "Ve": ve, + } # We will loop over increasing m to find a suitable value + kswargs = {"method": "Hybrid-RNS", "L": multiplicative_depth + 1, "beta": 2**10, "omega": 3} + + while True: + logq, logp = logqP(ops, b_args, kswargs, sdist) + log = sum(logq) + logp if logp else sum(logq) + if logp and estsecurity(b_args["m"], log, sdist) >= 128: + break + + b_args["m"] <<= 1 + + # TODO: This is a workaround + if self._gf.characteristic == 2: + b_args["m"] -= 1 + + sec = estsecurity(b_args["m"], sum(logq) + logp, sdist) + assert sec >= 128 + + return f""" + // Set up the HE parameters + unsigned long p = {self._gf.characteristic}; + unsigned long m = {b_args["m"]}; + unsigned long r = 1; + unsigned long bits = {sum(logq)}; + unsigned long c = 3; + helib::Context context = helib::ContextBuilder() + .m(m) + .p(p) + .r(r) + .bits(bits) + .c(c) + .build(); +""", (b_args["m"], 1, sum(logq), 3) + + def generate_code( + self, + filename: str, + iterations: int = 1, + measure_time: bool = False, + decrypt_outputs: bool = False, + ) -> Tuple[int, int, int, 
int]: + """Generates an HElib implementation of the circuit. + + If decrypt_outputs is True, prints the decrypted output. + Otherwise, it prints whether the ciphertext has noise budget remaining (i.e. it is correct with high probability). + + !!! note + Decryption is part of the measured run time. + + Args: + filename: Test + iterations: Number of times to run the circuit + measure_time: Whether to output a measurement of the total run time + decrypt_outputs: Whether to print the decrypted outputs, or to simply check if there is noise budget remaining + + Returns: + Parameters that were chosen: (ring dimension m, Hensel lifting = 1, bits in the modchain, columns in key switching = 3). + """ + from oraqle.compiler.instructions import InputInstruction + + # Generate HElib code + with open(filename, "w", encoding="utf8") as file: + # Write start of file and parameters + file.write(helib_preamble) + param_code, params = self._generate_helib_params() + file.write(param_code) + file.write("\n") + file.write(helib_keygen) + file.write("\n") + + # Encrypt the inputs + program = self.generate_program() + inputs = [ + instruction._name + for instruction in program._instructions + if isinstance(instruction, InputInstruction) + ] + file.write("\t// Encrypt the inputs\n") + for input in inputs: + file.write( + f'\tstd::vector vec_{input}(1, extract_input("{input}"));\n\tptxt_t ptxt_{input}(context, vec_{input});\n\tctxt_t ciph_{input}(public_key);\n\tpublic_key.Encrypt(ciph_{input}, ptxt_{input});\n' + ) + file.write("\n") + + # If timing is enabled, start the timer + if measure_time: + file.write("\tauto start = std::chrono::high_resolution_clock::now();\n") + file.write("\n") + + # If we perform multiple iterations, wrap in a for loop + if iterations > 1: + file.write(f"\tfor (int i = 0; i < {iterations}; i++) {{\n") + + # Write the actual instructions + file.write("\t// Perform the actual circuit\n") + file.write( + "\n".join( + f"\t{line}" for line in 
program.generate_code(decrypt_outputs).splitlines() + ) + ) + file.write("\n") + + # If we perform multiple iterations, close the for loop + if iterations > 1: + file.write("\t}\n") + + # If timing is enabled, stop the timer + if measure_time: + file.write("\n") + file.write("\tauto end = std::chrono::high_resolution_clock::now();\n") + file.write("\tstd::chrono::duration elapsed = end - start;\n") + file.write("\tstd::cout << elapsed.count() << std::endl;") + file.write("\n") + + # Finish the file + file.write(helib_postamble) + + return params + + +if __name__ == "__main__": + from galois import GF + + from oraqle.compiler.circuit import Circuit + from oraqle.compiler.nodes.leafs import Input + + gf = GF(7) + + x = Input("x", gf) + y = Input("y", gf) + + arithmetic_circuit = Circuit([x < y]).arithmetize() + arithmetic_circuit.generate_code("main.cpp", iterations=10, measure_time=True) diff --git a/oraqle/compiler/comparison/__init__.py b/oraqle/compiler/comparison/__init__.py new file mode 100644 index 0000000..837510d --- /dev/null +++ b/oraqle/compiler/comparison/__init__.py @@ -0,0 +1 @@ +"""Package containing tools for expressing equality and comparison operations.""" diff --git a/oraqle/compiler/comparison/comparison.py b/oraqle/compiler/comparison/comparison.py new file mode 100644 index 0000000..18f3156 --- /dev/null +++ b/oraqle/compiler/comparison/comparison.py @@ -0,0 +1,693 @@ +"""Classes for representing comparisons such as x < y, x >= y, semi-comparisons etc.""" +from typing import Type + +from galois import GF, FieldArray + +from oraqle.compiler.arithmetic.subtraction import Subtraction +from oraqle.compiler.boolean.bool_neg import Neg +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.comparison.in_upper_half import IliashenkoZuccaInUpperHalf, InUpperHalf +from oraqle.compiler.nodes.abstract import CostParetoFront, Node, iterate_increasing_depth +from oraqle.compiler.nodes.leafs import Constant, Input +from 
class AbstractComparison(NonCommutativeBinaryNode):
    """Base class for comparison nodes whose direction can be flipped, since x > y holds exactly when y < x."""

    def __init__(self, left, right, less_than: bool, gf: Type[FieldArray]):
        """Create a comparison between `left` and `right`, where `less_than` selects the direction of the comparison."""
        # The direction is stored before the parent initializer runs
        self._less_than = less_than
        super().__init__(left, right, gf)

    def __hash__(self) -> int:
        if self._hash is not None:
            return self._hash

        # A flipped comparison hashes its operands in swapped order, so that x > y and y < x collide
        operand_hashes = (hash(self._left), hash(self._right))
        if not self._less_than:
            operand_hashes = operand_hashes[::-1]

        self._hash = hash((self._hash_name, operand_hashes))
        return self._hash

    def is_equivalent(self, other: Node) -> bool:
        """Return whether `other` is the same kind of comparison over equivalent operands, up to flipping the direction."""
        if not isinstance(other, self.__class__) or hash(self) != hash(other):
            return False

        if self._less_than == other._less_than:
            # Same direction: operands are compared position by position
            return self._left.is_equivalent(other._left) and self._right.is_equivalent(other._right)

        # Opposite directions: operands are compared crosswise
        return self._left.is_equivalent(other._right) and self._right.is_equivalent(other._left)
class StrictComparison(AbstractComparison):
    """A node representing a comparison x < y or x > y."""

    @property
    def _hash_name(self) -> str:
        return "strict_comparison"

    @property
    def _node_label(self) -> str:
        # The label follows the direction of the comparison, like the other comparison nodes
        return "<" if self._less_than else ">"

    def _operation_inner(self, x, y) -> FieldArray:
        if self._less_than:
            return self._gf(int(int(x) < int(y)))
        else:
            return self._gf(int(int(x) > int(y)))

    def _ordered_operands(self):
        """Return the operands as (left, right) such that this comparison reads left < right."""
        if self._less_than:
            return self._left, self._right
        return self._right, self._left

    def _decompose(self, left: Node, right: Node) -> Node:
        """Express left < right using semi-comparisons, which are only valid for inputs at most p // 2 apart."""
        p = self._gf.characteristic

        left_is_small = SemiStrictComparison(
            left, Constant(self._gf(p // 2)), less_than=True, gf=self._gf
        )
        right_is_small = SemiStrictComparison(
            right, Constant(self._gf(p // 2)), less_than=True, gf=self._gf
        )

        # Test whether left and right are in the same range
        same_range = (left_is_small & right_is_small) + (
            Neg(left_is_small, self._gf) & Neg(right_is_small, self._gf)
        )

        # Performs left < right on the reduced inputs, note that if both are in the upper half the difference is still small enough for a semi-comparison
        comparison = SemiStrictComparison(left, right, less_than=True, gf=self._gf)
        result = same_range * comparison

        # Performs left < right when one is small and the other is large
        right_is_larger = left_is_small & Neg(right_is_small, self._gf)
        result += right_is_larger

        return result

    def _arithmetize_inner(self, strategy: str) -> Node:
        left, right = self._ordered_operands()

        left = left.arithmetize(strategy)
        right = right.arithmetize(strategy)

        return self._decompose(left, right).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        left, right = self._ordered_operands()

        left_front = left.arithmetize_depth_aware(cost_of_squaring)
        right_front = right.arithmetize_depth_aware(cost_of_squaring)

        # TODO: This is just exhaustive: every pair of operand arithmetizations is tried
        front = CostParetoFront(cost_of_squaring)

        for _, _, left_node in left_front:
            for _, _, right_node in right_front:
                result = self._decompose(left_node, right_node)
                front.add_front(result.arithmetize_depth_aware(cost_of_squaring))

        return front
class IliashenkoZuccaSemiLessThan(NonCommutativeBinaryNode):
    """Implementation of the [Iliashenko-Zucca algorithm](https://eprint.iacr.org/2021/315) for performing x < y."""

    @property
    def _hash_name(self) -> str:
        # Must differ from the hash name of T2SemiLessThan, which is a different node type
        return "less_than_iz"

    @property
    def _node_label(self) -> str:
        return "< [iz]"

    def _operation_inner(self, x, y) -> FieldArray:
        return self._gf(int(int(x) < int(y)))

    def _arithmetize_inner(self, strategy: str) -> Node:
        # x < y is decided by testing whether x - y lies in the upper half of the field
        return IliashenkoZuccaInUpperHalf(
            Subtraction(
                self._left.arithmetize(strategy), self._right.arithmetize(strategy), self._gf
            ),
            self._gf,
        ).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        # Depth-aware arithmetization is not supported for this node
        raise NotImplementedError()
def test_evaluate_arithmetized_depth_aware_mod11_lt():
    """Every circuit on the depth-aware front of a < b over GF(11) must agree with integer comparison."""
    gf = GF(11)

    comparison = StrictComparison(Input("a", gf), Input("b", gf), less_than=True, gf=gf)

    for _, _, arithmetization in comparison.arithmetize_depth_aware(1.0):
        for x in range(11):
            for y in range(11):
                actual = arithmetization.evaluate({"a": gf(x), "b": gf(y)})
                assert actual == gf(int(x < y))
                # Cached results must be dropped before evaluating on the next input pair
                arithmetization.clear_cache(set())
def test_evaluate_semi_arithmetized_mod5_gt():
    """Check the best-effort arithmetization of the semi-strict `a > b` comparison in GF(5) on inputs 0..2."""
    field = GF(5)
    lhs, rhs = Input("a", field), Input("b", field)

    circuit = SemiStrictComparison(lhs, rhs, less_than=False, gf=field).arithmetize("best-effort")
    circuit.clear_cache(set())

    for left_value in range(3):
        for right_value in range(3):
            assignment = {"a": field(left_value), "b": field(right_value)}
            expected = field(int(left_value > right_value))
            assert circuit.evaluate(assignment) == expected
            # Clear the cache before evaluating the next assignment.
            circuit.clear_cache(set())
def test_evaluate_arithmetized_depth_aware_mod5_ge():
    """Check every circuit on the depth-aware front of `a >= b` in GF(5) against all input pairs."""
    field = GF(5)
    lhs, rhs = Input("a", field), Input("b", field)

    comparison = Comparison(lhs, rhs, less_than=False, gf=field)
    pareto_front = comparison.arithmetize_depth_aware(0.75)

    # Only the third element of each front entry (the arithmetized node) is used here.
    for _, _, candidate in pareto_front:
        for left_value in range(5):
            for right_value in range(5):
                assignment = {"a": field(left_value), "b": field(right_value)}
                expected = field(int(left_value >= right_value))
                assert candidate.evaluate(assignment) == expected
                # Clear the cache before evaluating the next assignment.
                candidate.clear_cache(set())
def test_lessthan_mod101():
    """Check that every depth-aware arithmetization of `x < 30` over GF(101) outputs 0 for x = 90."""
    field = GF(101)
    variable = Input("x", field)
    circuit = Circuit([variable < 30])

    assignment = {"x": field(90)}
    for _, _, arithmetization in circuit.arithmetize_depth_aware():
        outputs = arithmetization.evaluate(assignment)
        # The circuit has a single output; 90 < 30 is false, so it must be 0.
        assert outputs[0] == 0
class IsNonZero(UnivariateNode):
    """This node represents a nonzero check: x != 0.

    It evaluates to 1 when the input is nonzero and to 0 when the input is zero.
    """

    @property
    def _node_shape(self) -> str:
        return "box"

    @property
    def _hash_name(self) -> str:
        return "is_nonzero"

    @property
    def _node_label(self) -> str:
        return "!= 0"

    def _operation_inner(self, input: FieldArray) -> FieldArray:
        """Return the field element 1 if `input` is nonzero and the field element 0 otherwise."""
        # Wrap the comparison in a field element, as `Equals` does, so that the result
        # matches the annotated return type instead of being a bare boolean.
        return self._gf(int(input != 0))

    def _arithmetize_inner(self, strategy: str) -> Node:
        """Arithmetize the check as the power x^(q - 1), where q is the order of the field."""
        # In a finite field of order q, x^(q - 1) is 1 for every nonzero x and 0 for x = 0.
        return Power(self._node, self._gf.order - 1, self._gf).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        """Depth-aware counterpart of `_arithmetize_inner`, using the same power x^(q - 1)."""
        return Power(self._node, self._gf.order - 1, self._gf).arithmetize_depth_aware(
            cost_of_squaring
        )
class InUpperHalf(UnivariateNode):
    """Returns 1 when the input is contained in the upper half of the field, which are considered the negative numbers.

    Specifically, it returns 0 in the range [0, (p - 1) / 2] and 1 in the range ((p - 1) / 2, p - 1].
    """

    @property
    def _node_shape(self) -> str:
        return "box"

    @property
    def _hash_name(self) -> str:
        return "in_upper_half"

    @property
    def _node_label(self) -> str:
        return "> (p-1)/2"

    def _operation_inner(self, input: FieldArray) -> FieldArray:
        """Evaluate the check on a plaintext value: 0 if `int(input)` lies in [0, p // 2], and 1 otherwise."""
        p = self._gf.characteristic
        if 0 <= int(input) <= p // 2:
            return self._gf(0)

        return self._gf(1)

    def _arithmetize_inner(self, strategy: str) -> Node:
        """Arithmetize the check as a polynomial in the input.

        The odd-power terms are evaluated as a polynomial in the squared input that is then multiplied
        by the input; the remaining term, ((p + 1) // 2) times a power built from an addition chain for
        the target p - 1, is added at the end.
        """
        coefficients = []

        # From: Faster homomorphic comparison operations for BGV and BFV, Ilia Iliashenko & Vincent Zucca, 2021
        p = self._gf.characteristic
        for i in range(p - 1):
            if i % 2 == 0:
                # Ignore every even power, we take care of this by squaring the input node.
                continue

            # Coefficient for odd i: the sum of a^(p - 1 - i) over a = 1, ..., p // 2.
            coefficient = self._gf(0)
            for a in range(1, p // 2 + 1):
                coefficient += self._gf(a) ** (p - 1 - i)
            coefficients.append(coefficient)

        # We do not add the final coefficient, which will be computed later, so we do not do coefficients.append(gf((p + 1) // 2))

        input_node = self._node.arithmetize(strategy).to_arithmetic()
        input_node_squared = input_node * input_node
        arithmetization, precomputed_powers = UnivariatePoly(
            input_node_squared, coefficients, self._gf
        ).arithmetize_custom(strategy)

        # Since we skip the first coefficient, we manually multiply the output by the input node.
        result = Multiplication(input_node, arithmetization, self._gf)

        # Compute the final coefficient using an exponentiation
        # Each entry pairs an exponent (doubled and reduced modulo p - 1) with the depth of that power
        # relative to the input node.
        # NOTE(review): presumably the keys of precomputed_powers are exponents of the squared input,
        # hence the factor 2 - confirm against UnivariatePoly.arithmetize_custom.
        precomputed_values = tuple(
            (
                (2 * exp) % (p - 1),
                power_node.multiplicative_depth() - input_node.multiplicative_depth(),
            )
            for exp, power_node in precomputed_powers.items()
        )

        addition_chain = add_chain_guaranteed(p - 1, p - 1, squaring_cost=1.0, precomputed_values=precomputed_values)

        # The chain steps index into this list: the input node first, then the precomputed powers in the
        # same iteration order as precomputed_values.
        # NOTE(review): assumes this is the index convention of add_chain_guaranteed - confirm.
        nodes = [input_node]
        nodes.extend(power_node for _, power_node in precomputed_powers.items())

        for i, j in addition_chain:
            nodes.append(Multiplication(nodes[i], nodes[j], self._gf))

        # The last node produced by the chain is scaled by the constant (p + 1) // 2.
        final_term = ConstantMultiplication(nodes[-1], self._gf((p + 1) // 2))

        return (Addition(result, final_term, self._gf)).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        """Depth-aware counterpart of `_arithmetize_inner`.

        For every arithmetization of the operand and every arithmetization of the odd-power polynomial,
        the final power term is generated from a front of addition chains, and each resulting sum is
        added to the returned front.
        """
        # TODO: Handle p = 2 and p = 3 separately

        # TODO: Reduce code duplication
        final_front = CostParetoFront(cost_of_squaring)

        for node_depth, _, node in self._node.arithmetize_depth_aware(cost_of_squaring):
            coefficients = []

            # From: Faster homomorphic comparison operations for BGV and BFV, Ilia Iliashenko & Vincent Zucca, 2021
            p = self._gf.characteristic
            for i in range(p - 1):
                if i % 2 == 0:
                    # Ignore every even power, we take care of this by squaring the input node.
                    continue

                # Coefficient for odd i: the sum of a^(p - 1 - i) over a = 1, ..., p // 2.
                coefficient = self._gf(0)
                for a in range(1, p // 2 + 1):
                    coefficient += self._gf(a) ** (p - 1 - i)
                coefficients.append(coefficient)

            # We do not add the final coefficient, which will be computed later, so we do not do coefficients.append(gf((p + 1) // 2))

            input_node_squared = Multiplication(node, node, self._gf)
            arithmetizations, precomputed_powers = UnivariatePoly(
                input_node_squared, coefficients, self._gf
            ).arithmetize_depth_aware_custom(cost_of_squaring)

            assert not arithmetizations.is_empty()

            # precomputed_powers is indexed by the depth of the polynomial arithmetization it belongs to.
            for depth, _, poly_arith in arithmetizations:
                # Since we skip the first coefficient, we manually multiply the output by the input node.
                result = Multiplication(node, poly_arith, self._gf)

                # Compute the final coefficient using an exponentiation
                precomputed_values = tuple(
                    ((2 * exp) % (p - 1), power_node.multiplicative_depth() - node_depth)
                    for exp, power_node in precomputed_powers[depth].items()
                )
                # TODO: This is copied from Power, but in the future we can probably remove this if we have augmented circuits
                # The second argument is p - 1 for p <= 200 and None otherwise.
                # NOTE(review): its meaning is defined by gen_pareto_front, which is not visible here.
                if p <= 200:
                    front = gen_pareto_front(
                        p - 1,
                        self._gf.characteristic - 1,
                        cost_of_squaring,
                        precomputed_values=precomputed_values,
                    )
                else:
                    front = gen_pareto_front(
                        p - 1, None, cost_of_squaring, precomputed_values=precomputed_values
                    )

                final_power_front = CostParetoFront(cost_of_squaring)

                for depth2, chain in front:
                    # Turn the chain into (i, j) index pairs into the node list built below.
                    c = extract_indices(
                        chain,
                        precomputed_values=list(k for k, _ in precomputed_values),
                        modulus=p - 1,
                    )

                    nodes = [node]
                    nodes.extend(power_node for _, power_node in precomputed_powers[depth].items())

                    for i, j in c:
                        nodes.append(Multiplication(nodes[i], nodes[j], self._gf))

                    # The depth of the final power is the operand's depth plus the depth of the chain.
                    final_power_front.add(nodes[-1], depth=node_depth + depth2)

                for _, _, final_power in final_power_front:
                    final_term = ConstantMultiplication(final_power, self._gf((p + 1) // 2))
                    final_front.add(Addition(result, final_term, self._gf))

        assert not final_front.is_empty()
        return final_front
def test_evaluate_mod7():
    """Check `InUpperHalf` on every element of GF(7), including the boundary value (p - 1) / 2 = 3."""
    gf = GF(7)

    x = Input("x", gf)
    node = InUpperHalf(x, gf)

    # The lower half [0, (p - 1) / 2] = {0, 1, 2, 3} must map to 0; the boundary 3 is included.
    for i in range(4):
        assert node.evaluate({"x": gf(i)}) == gf(0)
        node.clear_cache(set())
    # The upper half {4, 5, 6} must map to 1.
    for i in range(4, 7):
        assert node.evaluate({"x": gf(i)}) == gf(1)
        node.clear_cache(set())
control flow functions.""" diff --git a/oraqle/compiler/control_flow/conditional.py b/oraqle/compiler/control_flow/conditional.py new file mode 100644 index 0000000..2a8ba03 --- /dev/null +++ b/oraqle/compiler/control_flow/conditional.py @@ -0,0 +1,110 @@ +"""This module contains tools for evaluating conditional statements.""" +from typing import List, Type + +from galois import GF, FieldArray + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.abstract import CostParetoFront, Node +from oraqle.compiler.nodes.fixed import FixedNode +from oraqle.compiler.nodes.leafs import Constant, Input + + +class IfElse(FixedNode): + """A node representing an if-else clause.""" + + @property + def _node_label(self): + return "If" + + @property + def _hash_name(self): + return "if_else" + + def __init__(self, condition: Node, positive: Node, negative: Node, gf: Type[FieldArray]): + """Initialize an if-else node: If condition evaluates to true, then it outputs positive, otherwise it outputs negative.""" + self._condition = condition + self._positive = positive + self._negative = negative + super().__init__(gf) + + def __hash__(self) -> int: + return hash((self._hash_name, self._condition, self._positive, self._negative)) + + def is_equivalent(self, other: Node) -> bool: # noqa: D102 + if not isinstance(other, self.__class__): + return False + + return ( + self._condition.is_equivalent(other._condition) + and self._positive.is_equivalent(other._positive) + and self._negative.is_equivalent(other._negative) + ) + + def operands(self) -> List[Node]: # noqa: D102 + return [self._condition, self._positive, self._negative] + + def set_operands(self, operands: List[Node]): # noqa: D102 + self._condition = operands[0] + self._positive = operands[1] + self._negative = operands[2] + + def operation(self, operands: List[FieldArray]) -> FieldArray: # noqa: D102 + assert operands[0] == 0 or operands[0] == 1 + return operands[1] if operands[0] == 1 else operands[2] + + def 
def principal_character(x, prime_modulus):
    """Computes the principal character of `x` as the power x**(p-1), where p is `prime_modulus`.

    Modulo a prime p, this expression is 0 when x = 0 and 1 for every nonzero x (Fermat's little theorem).
    Only works for prime moduli. The power is returned as-is: reducing it modulo `prime_modulus` is left
    to the caller.

    Returns:
        The principal character x**(p-1).
    """
    return x ** (prime_modulus - 1)
class ArithmeticInstruction(ABC):
    """An abstract arithmetic instruction that computes an operation in an arithmetic circuit using a stack.

    Subclasses implement `evaluate` for plaintext execution and `generate_code` for code generation.
    """

    def __init__(self, stack_index: int) -> None:
        """Initialize an instruction that writes its output to the stack at `stack_index`."""
        self._stack_index = stack_index

    @abstractmethod
    def evaluate(
        self, stack: List[Optional[FieldArray]], inputs: Dict[str, FieldArray]
    ) -> Optional[FieldArray]:
        """Executes the instruction on plaintext inputs without using encryption, keeping track of the plaintext values in the stack.

        Args:
            stack: The plaintext stack; entries may be `None`.
            inputs: The plaintext input values, keyed by input name.

        Returns:
            A field element, or `None`.
        """

    @abstractmethod
    def generate_code(self, stack_initialized: List[bool], decrypt_outputs: bool) -> str:
        """Generates code for this instruction, keeping track of which places of the stack are already initialized.

        Args:
            stack_initialized: One flag per stack index, telling whether that index is already initialized.
            decrypt_outputs: Flag passed on to every instruction during code generation.

        Returns:
            The generated code as a string.
        """
class MultiplicationInstruction(ArithmeticInstruction):
    """Multiplies two stack elements and stores the product on the stack."""

    def __init__(self, stack_index: int, left_stack_index: int, right_stack_index: int) -> None:
        """Create an instruction that stores the product of the elements at `left_stack_index` and `right_stack_index` at `stack_index`."""
        self._left_stack_index = left_stack_index
        self._right_stack_index = right_stack_index
        super().__init__(stack_index)

    def evaluate(self, stack: List[Optional[FieldArray]], _inputs: Dict[str, FieldArray]) -> None:
        """Multiply the two plaintext operands on the stack and write the product to the output index."""
        operands = (stack[self._left_stack_index], stack[self._right_stack_index])
        left, right = operands
        # Both operands must have been written by an earlier instruction.
        assert left is not None
        assert right is not None
        stack[self._stack_index] = left * right

    def generate_code(self, stack_initialized: List[bool], _decrypt_outputs: bool) -> str:
        """Generate the statements that compute the product into the output stack variable."""
        target = self._stack_index
        left = self._left_stack_index
        right = self._right_stack_index

        # When the output index equals one of the operand indices, multiply in place.
        if left == target:
            return f"stack_{target} *= stack_{right};\n"
        if right == target:
            return f"stack_{target} *= stack_{left};\n"

        # Otherwise copy the left operand, declaring the variable on first use, and multiply by the right.
        declaration = "" if stack_initialized[target] else "ctxt_t "
        stack_initialized[target] = True
        return f"{declaration}stack_{target} = stack_{left};\nstack_{target} *= stack_{right};\n"
class ConstantMultiplicationInstruction(ArithmeticInstruction):
    """Multiplies a stack element with a constant and stores the product on the stack."""

    def __init__(self, stack_index: int, input_stack_index: int, constant: FieldArray) -> None:
        """Create an instruction that stores `constant` times the element at `input_stack_index` at `stack_index`."""
        self._input_stack_index = input_stack_index
        self._constant = constant
        super().__init__(stack_index)

    def evaluate(self, stack: List[Optional[FieldArray]], _inputs: Dict[str, FieldArray]) -> None:
        """Multiply the plaintext operand on the stack with the constant and write the product to the output index."""
        value = stack[self._input_stack_index]
        # The operand must have been written by an earlier instruction.
        assert value is not None
        stack[self._stack_index] = value * self._constant

    def generate_code(self, stack_initialized: List[bool], _decrypt_outputs: bool) -> str:
        """Generate the statements that compute the scaled value into the output stack variable."""
        target = self._stack_index
        source = self._input_stack_index

        # When input and output share an index, scale in place.
        if target == source:
            return f"stack_{source} *= {self._constant}l;\n"

        # Otherwise copy the operand, declaring the variable on first use, and scale the copy.
        declaration = "" if stack_initialized[target] else "ctxt_t "
        stack_initialized[target] = True
        return f"{declaration}stack_{target} = stack_{source};\nstack_{target} *= {self._constant}l;\n"
"""Writes an input to the stack.""" + + def __init__(self, stack_index: int, name: str) -> None: + """Initialize an `InputInstruction` that places the input with the given `name` in the stack at index `stack_index`.""" + self._name = name + super().__init__(stack_index) + + def evaluate(self, stack: List[Optional[FieldArray]], inputs: Dict[str, FieldArray]) -> None: # noqa: D102 + stack[self._stack_index] = inputs[self._name] + + def generate_code(self, stack_initialized: List[bool], _decrypt_outputs: bool) -> str: # noqa: D102 + code = "" + if not stack_initialized[self._stack_index]: + code += "ctxt_t " + code += f"stack_{self._stack_index} = ciph_{self._name};\n" + stack_initialized[self._stack_index] = True + return code + + +class OutputInstruction(ArithmeticInstruction): + """Outputs an element from the stack.""" + + def evaluate(self, stack: List[FieldArray], _inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 + return stack[self._stack_index] + + def generate_code(self, stack_initialized: List[bool], decrypt_outputs: bool) -> str: # noqa: D102 + if decrypt_outputs: + return f"ptxt_t decrypted(context);\nsecret_key.Decrypt(decrypted, stack_{self._stack_index});\nstd::cout << decrypted << std::endl;\n" + else: + return f'std::cout << "Output correctness: " << stack_{self._stack_index}.isCorrect() << std::endl;\n' + + +class ArithmeticProgram: + """An ArithmeticProgram represents an ordered set of arithmetic operations that compute an arithmetic circuit. + + The easiest way to obtain an `ArithmeticProgram` of an `ArithmeticCircuit` is to call `ArithmeticCircuit.generate_program()`. + """ + + def __init__( + self, instructions: List[ArithmeticInstruction], stack_size: int, gf: Type[FieldArray] + ) -> None: + """Initialize an `ArithmeticProgram` from a list of `instructions`. + + The user must specify an upper bound on the `stack_size` required. 
+ """ + self._instructions = instructions + self._stack_size = stack_size + self._gf = gf + + def execute(self, inputs: Dict[str, FieldArray]) -> FieldArray: + """Executes the arithmetic program on plaintext inputs without using encryption. + + Raises: + Exception: If there were no outputs in this program. + + Returns: + The first output in this program. + """ + # FIXME: Currently only supports a single output + for input in inputs.values(): + assert isinstance(input, self._gf) + + stack: List[Optional[FieldArray]] = [None for _ in range(self._stack_size)] + + for instruction in self._instructions: + if (output := instruction.evaluate(stack, inputs)) is not None: + return output + + raise Exception("The program did not output anything") + + def generate_code(self, decrypt_outputs: bool) -> str: + """Generates HElib code for this program. + + If `decrypt_outputs` is true, then the generated code will decrypt the outputs at the end of the circuit. + + Returns: + The generated code as a string. 
def select_stack_index(stack_occupied: List[bool]) -> int:
    """Claim a free slot in the stack and return its index.

    The lowest unoccupied index is marked as occupied in place. If every slot is
    occupied, `stack_occupied` is grown by one occupied slot.

    Returns:
        The index that was claimed.
    """
    free_slot = next(
        (position for position, taken in enumerate(stack_occupied) if not taken), None
    )

    if free_slot is None:
        # No gap available: extend the stack by one slot.
        stack_occupied.append(True)
        return len(stack_occupied) - 1

    stack_occupied[free_slot] = True
    return free_slot
+ """ + if depth is None: + depth = node.multiplicative_depth() + + if value is None: + value = self._get_value(node) + + return self._add(depth, value, node) + + def _add(self, depth: int, value: Union[int, float], node: "ArithmeticNode") -> bool: + """Returns True if and only if the node was inserted into the ParetoFront.""" + for d in range(depth + 1): + if d in self._nodes_by_depth and self._nodes_by_depth[d][0] <= value: + return False + + self._nodes_by_depth[depth] = (value, node) + self._highest_depth = max(depth, self._highest_depth) + + for d in range(depth + 1, self._highest_depth + 1): + if d in self._nodes_by_depth and self._nodes_by_depth[d][0] >= value: + del self._nodes_by_depth[d] + + return True + + def add_leaf(self, leaf): + """Add a leaf node to this `ParetoFront`.""" + self._add(0, 0, leaf) # type: ignore + + def add_front(self, front: "ParetoFront"): + """Add all elements from `front` to `self`.""" + # TODO: This can be optimized + for d, s, n in front: + self.add(n, d, s) + + def __iter__(self) -> Iterator[Tuple[int, Union[int, float], "ArithmeticNode"]]: + for depth in range(self._highest_depth + 1): + if depth in self._nodes_by_depth: + yield depth, self._nodes_by_depth[depth][0], self._nodes_by_depth[depth][1] + + def get_smallest_at_depth( + self, max_depth: int + ) -> Optional[Tuple[int, Union[int, float], "ArithmeticNode"]]: + """Returns the circuit with the smallest value that has at most depth `max_depth`.""" + for depth in reversed(range(max_depth + 1)): + if depth in self._nodes_by_depth: + return depth, self._nodes_by_depth[depth][0], self._nodes_by_depth[depth][1] + + def is_empty(self) -> bool: + """Returns whether the front is empty.""" + return len(self._nodes_by_depth) == 0 + + def get_lowest_value(self) -> Optional["ArithmeticNode"]: + """Returns the value (size or cost) of the Node with the highest depth, and therefore the lowest value.""" + if self._highest_depth == -1: + return None + + return 
def iterate_increasing_depth(front1: "ParetoFront", front2: "ParetoFront") -> Iterator[
    Tuple[
        Tuple[int, Union[int, float], "ArithmeticNode"],
        Tuple[int, Union[int, float], "ArithmeticNode"],
    ]
]:
    """Iterates over two ParetoFronts, returning pairs of ArithmeticNodes such that the multiplicative depth grows monotonically.

    For every depth bound from 0 up to the highest depth of either front, the best entry
    of each front within that bound is looked up. A pair is only yielded when at least
    one of its two entries is deeper than anything yielded before, so each distinct pair
    is produced exactly once. Depth bounds for which either front has no entry are skipped.

    Yields:
        Pairs of tuples, containing the multiplicative depth, the multiplicative size/cost, and the arithmetization, in that order.
    """
    highest_depth = max(front1._highest_depth, front2._highest_depth)
    last_depth: Optional[int] = None

    # TODO: This is quite inefficient because we constantly loop over the same parts of the fronts, we could instead iterate over both fronts in sequence
    for depth in range(highest_depth + 1):
        res1 = front1.get_smallest_at_depth(depth)
        res2 = front2.get_smallest_at_depth(depth)

        if res1 is None or res2 is None:
            continue

        d1, d2 = res1[0], res2[0]

        if last_depth is None or d1 > last_depth or d2 > last_depth:
            yield res1, res2
            # Remember how deep we got; without this the guard above is always true
            # and the same pair is yielded again for every larger depth bound.
            last_depth = max(d1, d2)
+ """ + return super().add(node, depth, value=size) + + +class CostParetoFront(ParetoFront): + """A `ParetoFront` that trades off multiplicative depth with multiplicative cost.""" + + def __init__(self, cost_of_squaring: float) -> None: + """Initialize an empty `CostParetoFront` with the given `cost_of_squaring`.""" + self._cost_of_squaring = cost_of_squaring + super().__init__() + + @classmethod + def from_node( + cls, + node: "ArithmeticNode", + cost_of_squaring: float, + depth: Optional[int] = None, + cost: Optional[float] = None, + ) -> "CostParetoFront": + """Initialize a `CostParetoFront` with one node in it. + + Returns: + New `CostParetoFront`. + """ + self = cls(cost_of_squaring) + self.add(node, depth, cost) + return self + + @classmethod + def from_leaf(cls, leaf, cost_of_squaring: float) -> "CostParetoFront": + """Initialize a `CostParetoFront` with one leaf node in it. + + Returns: + New `CostParetoFront`. + """ + self = cls(cost_of_squaring) + self.add_leaf(leaf) + return self + + def _get_value(self, node: "ArithmeticNode") -> float: + return node.multiplicative_cost(self._cost_of_squaring) + + def _default_value(self) -> float: + return 0.0 + + def add( + self, node: "ArithmeticNode", depth: Optional[int] = None, cost: Optional[float] = None + ) -> bool: + """Adds the given `Node` to the `CostParetoFront` by computing its multiplicative depth and cost. + + Alternatively, the user can supply an unchecked `depth` and `cost` so that these values do not have to be (re)computed. + + Returns: + `True` if and only if the node was inserted into the ParetoFront (so it was in some way better than the current `Nodes`). 
def _to_node(obj: Union["Node", int, bool], gf: Type[FieldArray]) -> "Node":
    """Cast `obj` to a `Node`.

    A `Node` is returned unchanged. An `int` (which includes `bool`) is converted to a
    field element with `gf` and wrapped in a `Constant`. Any other object yields `None`,
    which `try_to_node` relies on to signal that the object is not castable.
    """
    if isinstance(obj, Node):
        return obj

    if not isinstance(obj, int):
        # Not castable; callers such as `try_to_node` treat None as "no node".
        return None  # type: ignore

    # Imported here rather than at module level (the leafs module depends on this one).
    from oraqle.compiler.nodes.leafs import Constant

    field_element = gf(obj)
    return Constant(field_element)
+ @abstractmethod + def replace_operands_using_function(self, function: Callable[["Node"], "Node"]): + """Replaces each operand of this node with the node generated by calling function on said operand.""" + + @abstractmethod + def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: + """Evaluates the node in the arithmetic circuit. The output should always be reduced modulo the modulus.""" + + def clear_cache(self, already_cleared: Set[int]): + """Clears any cached values of the node and any of its operands.""" + # FIXME: The cache should not be cleared twice for the same node, but there is no way to check this. + if id(self) not in already_cleared: + self.apply_function_to_operands(lambda operand: operand.clear_cache(already_cleared)) + + self._evaluate_cache: Optional[FieldArray] = None + self._to_graph_cache: Optional[int] = None + self._arithmetize_cache: Optional[Node] = None + self._arithmetize_depth_cache: Optional[CostParetoFront] = None + self._instruction_cache: Optional[int] = None + self._arithmetic_cache: Optional[ArithmeticNode] = None + self._parent_count_cache: Optional[int] = None + + self._hash = None + + already_cleared.add(id(self)) + + def to_graph(self, graph_builder: DotFile) -> int: + """Adds this node to the graph as well as its edges. + + Returns: + The identifier of this `Node` in the `DotFile`. 
+ """ + if self._to_graph_cache is None: + attributes = {"shape": "box"} + attributes.update(self._overriden_graphviz_attributes) + + self._to_graph_cache = graph_builder.add_node( + label=self._node_label, + **attributes, + ) + + # FIXME: This does not take multiplicity into account; add option to apply_function_to_operands to take multiplicity into account + self.apply_function_to_operands(lambda operand: graph_builder.add_link(operand.to_graph(graph_builder), self._to_graph_cache)) # type: ignore + + return self._to_graph_cache + + @abstractmethod + def __hash__(self) -> int: + raise NotImplementedError( + "The abstract class does not provide a default implementation of __hash__" + ) + + # TODO: We can add a strategy to this method, e.g. to exhaustively check equivalence. + @abstractmethod + def is_equivalent(self, other: "Node") -> bool: + """Checks whether two nodes are semantically equivalent. + + This method will always return `False` if they are not. + This method will maybe return True if they are indeed equivalent. + In other words, this method may produce false negatives, but it will never produce false positives. + """ + + # TODO: Rework CSE. In an arithmetic circuit, it should only return arithmetic nodes. + def eliminate_common_subexpressions(self, terms: Dict[int, "Node"]) -> "Node": + """Eliminates duplicate subexpressions that are equivalent (as defined by a node's `__eq__` and `__hash__` method). + + Returns: + A `Node` that must replace the previous expression. + """ + # TODO: What if we try breadth-first search? It will be more expensive but it will save the lowest depth solution first. + # FIXME: Handle conflicts (duplicate hashes) using a list instead of a single node. + # TODO: For performance reasons, maybe we should only save terms of a certain maximum depth. 
+ h = hash(self) + if h in terms and self.is_equivalent(terms[h]): + return terms[h] + + self.replace_operands_using_function( + lambda operand: operand.eliminate_common_subexpressions(terms) + ) + + terms[h] = self + return self + + def count_parents(self): + """Counts the total number of nodes in this subcircuit.""" + self._parent_count += 1 + + if self._parent_count_cache is None: + self._parent_count_cache = True + self.apply_function_to_operands(lambda operand: operand.count_parents()) + + def reset_parent_count(self): + """Resets the cached number of nodes in this subcircuit to 0.""" + self._parent_count = 0 + self.apply_function_to_operands(lambda operand: operand.reset_parent_count()) + + @abstractmethod + def arithmetize(self, strategy: str) -> "Node": + """Arithmetizes this node, replacing it with only arithmetic operations (constants, additions, and multiplications). + + The current implementation only aims at reducing the total number of multiplications. + """ + + @abstractmethod + def arithmetize_depth_aware( + self, cost_of_squaring: float + ) -> "CostParetoFront": + """Arithmetizes this node in a depth-aware fashion, replacing high-level nodes with only arithmetic operations (constants, additions, and multiplications). + + Returns: + `CostParetoFront` containing a front that trades off multiplicative depth and multiplicative cost. + """ + + def to_arithmetic(self) -> "ArithmeticNode": + """Outputs this node's equivalent ArithmeticNode. Errors if this node does not have a direct arithmetic equivalent. + + Raises: + Exception: If there is no direct arithmetic equivalent. + """ + # TODO: Make this a non-generic exception + raise Exception( + f"This node does not have a direct arithmetic equivalent: {self}. Consider first calling `arithmetize`." + ) + + def add(self, other: "Node", flatten=True) -> "Node": + """Performs a summation between `self` and `other`, possibly flattening any sums. 
+ + It is possible to disable flattening by setting `flatten=False`. + + Returns: + A possibly flattened `Sum` node or a `Constant` representing self & other. + """ + from oraqle.compiler.nodes.arbitrary_arithmetic import Sum + from oraqle.compiler.nodes.leafs import Constant + + if flatten and isinstance(self, Sum): + return self.add_flatten(other) + + if flatten and isinstance(other, Sum): + return other.add_flatten(self) + + if isinstance(other, Constant): + if int(other._value) == 0: + return self + return Sum(Counter({UnoverloadedWrapper(self): 1}), self._gf, constant=other._value) + + if id(self) == id(other): + return Sum(Counter({UnoverloadedWrapper(self): 2}), self._gf) + else: + return Sum( + Counter({UnoverloadedWrapper(self): 1, UnoverloadedWrapper(other): 1}), self._gf + ) + + def __add__(self, other) -> "Node": + other_node = try_to_node(other, self._gf) + if other_node is None: + raise Exception(f"The RHS of this + cannot be made into a Node: {self} - {other}") + + return self.add(other_node) + + def __radd__(self, other) -> "Node": + other_node = try_to_node(other, self._gf) + if other_node is None: + raise Exception(f"The LHS of this + cannot be made into a Node: {other} - {self}") + + return self.add(other_node) + + def mul(self, other: "Node", flatten=True) -> "Node": # noqa: PLR0911 + """Performs a multiplication between `self` and `other`, possibly flattening any products. + + It is possible to disable flattening by setting `flatten=False`. + + Returns: + A possibly flattened `Product` node or a `Constant` representing self & other. 
+ """ + from oraqle.compiler.nodes.arbitrary_arithmetic import Product + from oraqle.compiler.nodes.leafs import Constant + + if flatten and isinstance(self, Product): + return self.mul_flatten(other) + + if flatten and isinstance(other, Product): + return other.mul_flatten(self) + + if isinstance(other, Constant): + if int(other._value) == 0: + return other + if int(other._value) == 1: + return self + return Product(Counter({UnoverloadedWrapper(self): 1}), self._gf, constant=other._value) + + if id(self) == id(other): + return Product(Counter({UnoverloadedWrapper(self): 2}), self._gf) + else: + return Product( + Counter({UnoverloadedWrapper(self): 1, UnoverloadedWrapper(other): 1}), self._gf + ) + + def __mul__(self, other) -> "Node": + if not isinstance(other, Node): + raise Exception(f"The RHS of this multiplication is not a Node: {self} * {other}") + + return self.mul(other) + + def bool_or(self, other: "Node", flatten=True) -> "Node": + """Performs an OR operation between `self` and `other`, possibly flattening the result into an OR operation between many operands. + + It is possible to disable flattening by setting `flatten=False`. + + Returns: + A possibly flattened `Or` node or a `Constant` representing self & other. 
+ """ + from oraqle.compiler.boolean.bool_or import Or + from oraqle.compiler.nodes.leafs import Constant + + if flatten and isinstance(other, Or): + return other.or_flatten(self) + + if isinstance(other, Constant): + if bool(other._value): + return Constant(self._gf(1)) + else: + return self + + if self.is_equivalent(other): + return self + else: + return Or({UnoverloadedWrapper(self), UnoverloadedWrapper(other)}, self._gf) + + def __or__(self, other) -> "Node": + if not isinstance(other, Node): + raise Exception(f"The RHS of this OR is not a Node: {self} | {other}") + + return self.bool_or(other) + + def bool_and(self, other: "Node", flatten=True) -> "Node": + """Performs an AND operation between `self` and `other`, possibly flattening the result into an AND operation between many operands. + + It is possible to disable flattening by setting `flatten=False`. + + Returns: + A possibly flattened `And` node or a `Constant` representing self & other. + """ + from oraqle.compiler.boolean.bool_and import And + from oraqle.compiler.nodes.leafs import Constant + + if flatten and isinstance(other, And): + return other.and_flatten(self) + + if isinstance(other, Constant): + if bool(other._value): + return self + else: + return Constant(self._gf(0)) + + if self.is_equivalent(other): + return self + else: + return And({UnoverloadedWrapper(self), UnoverloadedWrapper(other)}, self._gf) + + def __and__(self, other) -> "Node": + if not isinstance(other, Node): + raise Exception(f"The RHS of this AND is not a Node: {self} & {other}") + + return self.bool_and(other) + + def __lt__(self, other) -> "Node": + other_node = try_to_node(other, self._gf) + if other_node is None: + raise Exception(f"The RHS of this < cannot be made into a Node: {self} < {other}") + + from oraqle.compiler.comparison.comparison import StrictComparison + + return StrictComparison(self, other_node, less_than=True, gf=self._gf) + + def __gt__(self, other) -> "Node": + other_node = try_to_node(other, self._gf) 
+ if other_node is None: + raise Exception(f"The RHS of this > cannot be made into a Node: {self} > {other}") + + from oraqle.compiler.comparison.comparison import StrictComparison + + return StrictComparison(self, other_node, less_than=False, gf=self._gf) + + def __le__(self, other) -> "Node": + other_node = try_to_node(other, self._gf) + if other_node is None: + raise Exception(f"The RHS of this <= cannot be made into a Node: {self} <= {other}") + + from oraqle.compiler.comparison.comparison import Comparison + + return Comparison(self, other_node, less_than=True, gf=self._gf) + + def __ge__(self, other) -> "Node": + other_node = try_to_node(other, self._gf) + if other_node is None: + raise Exception(f"The RHS of this >= cannot be made into a Node: {self} >= {other}") + + from oraqle.compiler.comparison.comparison import Comparison + + return Comparison(self, other_node, less_than=False, gf=self._gf) + + def __neg__(self) -> "Node": + from oraqle.compiler.nodes.leafs import Constant + + return Constant(-self._gf(1)) * self + + def __invert__(self) -> "Node": + from oraqle.compiler.boolean.bool_neg import Neg + + return Neg(self, self._gf) + + def __pow__(self, other) -> "Node": + if not isinstance(other, int): + raise Exception(f"The exponent must be an integer: {self}**{other}") + + from oraqle.compiler.arithmetic.exponentiation import Power + + return Power(self, other, self._gf) + + def __sub__(self, other) -> "Node": + other_node = try_to_node(other, self._gf) + if other_node is None: + raise Exception(f"The RHS of this - cannot be made into a Node: {self} - {other}") + + from oraqle.compiler.arithmetic.subtraction import Subtraction + + return Subtraction(self, other_node, self._gf) + + def __rsub__(self, other) -> "Node": + other_node = try_to_node(other, self._gf) + if other_node is None: + raise Exception(f"The LHS of this - cannot be made into a Node: {other} - {self}") + + from oraqle.compiler.arithmetic.subtraction import Subtraction + + return 
class UnoverloadedWrapper:
    """Wrapper around a `Node` for which `hash(.)` and `x == y` behave like ordinary Python operations.

    `Node.__eq__` builds a new node instead of returning a `bool`, so nodes are wrapped
    in this class when a real equality test is needed.

    !!! note
        The equality operator performs semantic equality (via `Node.is_equivalent`)!
    """

    def __init__(self, node: Node) -> None:
        """Wrap `node`."""
        self.node = node

    def __hash__(self) -> int:
        """Return the hash of the wrapped node."""
        return hash(self.node)

    def __eq__(self, other) -> bool:
        """Return whether `other` wraps a node that is equivalent to this one.

        Anything that is not an `UnoverloadedWrapper`, or whose hash differs, compares
        unequal without calling `is_equivalent`.
        """
        if isinstance(other, UnoverloadedWrapper) and hash(self) == hash(other):
            return self.node.is_equivalent(other.node)

        return False
+ if id(self) not in already_cleared: + for node in self.operands(): + node.clear_cache(already_cleared) + + self._evaluate_cache: Optional[FieldArray] = None + self._to_graph_cache: Optional[int] = None + self._arithmetize_cache: Optional[Node] = None + self._arithmetize_depth_cache: Optional[ParetoFront] = None + self._instruction_cache: Optional[int] = None + self._arithmetic_cache: Optional[ArithmeticNode] = None + self._parent_count_cache: Optional[int] = None + + self._hash = None + + already_cleared.add(id(self)) + + @abstractmethod + def operands(self) -> List["ArithmeticNode"]: + """Returns the operands (children) of this node. The list can be empty. The nodes MUST be arithmetic nodes.""" + + @abstractmethod + def set_operands(self, operands: List["ArithmeticNode"]): + """Overwrites the operands of this node. The nodes MUST be arithmetic nodes.""" + + @abstractmethod + def multiplicative_depth(self) -> int: + """Computes the multiplicative depth of this node and its children recursively. + + Returns: + The largest number of multiplications from the output of this node to the leafs of this subcircuit. + """ + + def multiplicative_size(self) -> int: + """Computes the multiplicative size (number of multiplications) by counting the size of the set returned by self.multiplications(). + + Returns: + The number of multiplications in this subcircuit. + """ + return len(self.multiplications()) + + def multiplicative_cost(self, cost_of_squaring: float) -> float: + """Computes the multiplicative cost (number of general multiplications + cost_of_squaring * squarings). + + It does so by counting the size of the sets returned by self.multiplications() and self.squarings(). + + Returns: + The number of proper multiplications + the cost of squaring * the number of squarings. 
+ """ + return ( + len(self.multiplications()) + - len(self.squarings()) + + cost_of_squaring * len(self.squarings()) + ) + + @abstractmethod + def multiplications(self) -> Set[int]: + """Returns a set of all the multiplications in this tree of descendants, including itself. + + This includes any squarings. + """ + + @abstractmethod + def squarings(self) -> Set[int]: + """Returns a set of all the squarings in this tree of descendants, including itself.""" + + def arithmetize(self, strategy: str) -> "ArithmeticNode": # noqa: D102 + if self._arithmetize_cache2 is None: + self.set_operands([operand.arithmetize(strategy) for operand in self.operands()]) + self._arithmetize_cache2 = self + + return self._arithmetize_cache2 + + @abstractmethod + def create_instructions( + self, + instructions: List[ArithmeticInstruction], + stack_counter: int, + stack_occupied: List[bool], + ) -> Tuple[int, int]: + """Creates a set of instructions of this node to the given file. Returns the index in the stack of the output and the stack_counter. + + !!! note + This method assumes that the _parent_count of each node is up to date. 
# TODO: This is mostly copied from generate_multiplication_tree (depth is different)
def _generate_addition_tree(
    summands: Iterable[Tuple[int, ArithmeticNode]], counts: Iterable[int]
) -> Tuple[int, Addition]:
    """Combine the summands into one tree by repeatedly adding the two lowest-priority items.

    Each `(priority, node)` summand is inserted `count` times into a heap. While more
    than one item is left, the two smallest items are popped and combined:
    two constants are added directly, a zero constant is dropped, a non-zero constant
    becomes a `ConstantAddition`, and two non-constants become an `Addition`. The
    combined item gets the larger of the two priorities.

    Returns:
        The priority and the node of the single remaining item.
    """
    heap = []
    for summand, multiplicity in zip(summands, counts):
        for _ in range(multiplicity):
            heap.append(_PrioritizedItem(*summand))
    heapify(heap)

    while len(heap) > 1:
        first = heappop(heap)
        second = heappop(heap)
        lhs, rhs = first.item, second.item

        lhs_is_constant = isinstance(lhs, Constant)
        rhs_is_constant = isinstance(rhs, Constant)

        # TODO: This should move to Node
        if lhs_is_constant and rhs_is_constant:
            combined = lhs + rhs
        elif lhs_is_constant:
            combined = rhs if lhs._value == 0 else ConstantAddition(rhs, lhs._value)
        elif rhs_is_constant:
            combined = lhs if rhs._value == 0 else ConstantAddition(lhs, rhs._value)
        else:
            combined = Addition(lhs, rhs, lhs._gf)

        heappush(heap, _PrioritizedItem(max(first.priority, second.priority), combined))

    root = heap[0]
    return (root.priority, root.item)
front.is_empty() + return front + + def to_arithmetic(self) -> ArithmeticNode: # noqa: D102 + if self._arithmetic_cache is None: + # FIXME: Perform actual rebalancing + operands = iter(self._operands.elements()) + + # TODO: There is a lot of duplication between this and multiplications + if self._constant == self._identity: + self._arithmetic_cache = Addition( + next(operands).node.to_arithmetic(), + next(operands).node.to_arithmetic(), + self._gf, + ) + else: + self._arithmetic_cache = ConstantAddition( + next(operands).node.to_arithmetic(), self._constant + ) + + for operand in operands: + self._arithmetic_cache = Addition( + self._arithmetic_cache, operand.node.to_arithmetic(), self._gf + ) + + return self._arithmetic_cache + + def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 + if self._evaluate_cache is None: + self._evaluate_cache = reduce( + lambda a, b: a + b, + ( + operand.node.evaluate(actual_inputs) * count + for operand, count in self._operands.items() + ), + ) + self._evaluate_cache += self._constant + + return self._evaluate_cache # type: ignore + + def add_flatten(self, other: Node) -> Node: + """Adds this node to `other`, flattening the summation if either of the two is also a `Sum` and absorbing `Constant`s. + + Returns: + A `Sum` node containing the flattened summation, or a `Constant` node. + """ + order = self._gf.order + # TODO: Consider already assigning values to e.g. 
result._depth + if isinstance(other, Sum): + counter = self._operands + other._operands + counter_dict = { + el: count % order for el, count in counter.items() if count % order != 0 + } + constant = self._constant + other._constant + if len(counter_dict) == 0: + return Constant(constant) # type: ignore + return Sum(Counter(counter_dict), self._gf, constant) # type: ignore + elif isinstance(other, Constant): + if sum(self._operands.values()) == 1 and int(self._constant + other._value) == 0: + return next(iter(self._operands)).node + return Sum(self._operands, self._gf, self._constant + other._value) # type: ignore + + counter = self._operands.copy() + unoverloaded_other = UnoverloadedWrapper(other) + counter[unoverloaded_other] = (counter[unoverloaded_other] + 1) % order + if counter[unoverloaded_other] == 0: + counter.pop(unoverloaded_other) + + # FIXME: If empty, return Constant(0) + + return Sum(counter, self._gf, self._constant) + + +@dataclass(order=True) +class _PrioritizedItem: + priority: int + item: Any = field(compare=False) + + +def _generate_multiplication_tree( + multiplicands: Iterable[Tuple[int, ArithmeticNode]], counts: Iterable[int] +) -> Tuple[int, Multiplication]: + queue = [ + _PrioritizedItem(*multiplicand) + for multiplicand, count in zip(multiplicands, counts) + for _ in range(count) + ] + heapify(queue) + + while len(queue) > 1: + a = heappop(queue) + b = heappop(queue) + + a_const = isinstance(a.item, Constant) + b_const = isinstance(b.item, Constant) + + # TODO: This should move to Node + if a_const: + if b_const: + new = a.item * b.item + elif a.item._value == 1: + new = b.item + else: + new = ConstantMultiplication(b.item, a.item._value) + elif b_const: + new = a.item if b.item._value == 1 else ConstantMultiplication(a.item, b.item._value) + else: + new = Multiplication(a.item, b.item, a.item._gf) + + heappush( + queue, + _PrioritizedItem(max(a.priority, b.priority) + (not a_const and not b_const), new), + ) + + return (queue[0].priority, 
queue[0].item) + + +class Product(CommutativeMultiplicityReducibleNode): + """This node represents a product between two or more operands, or at least one operand and a constant.""" + + def __init__( + self, + operands: CounterType[UnoverloadedWrapper], + gf: Type[FieldArray], + constant: Optional[FieldArray] = None, + ): + """Initialize a `Product` with the given `Counter` as operands and an optional `constant`.""" + super().__init__(operands, gf, constant) + assert constant != 0 + + @property + def _hash_name(self) -> str: + return "product" + + @property + def _node_label(self) -> str: + return "ร—" # noqa: RUF001 + + @property + def _identity(self) -> FieldArray: + return self._gf(1) + + def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray: + return a * b # type: ignore + + def _arithmetize_inner(self, strategy: str) -> Node: + # TODO: Wrap exponents + new_operands = Counter() + new_constant = self._constant + for operand, count in self._operands.items(): + new_operand = operand.node.arithmetize(strategy) + + if isinstance(new_operand, Constant): + new_constant *= new_operand._value**count + else: + new_operands[UnoverloadedWrapper(new_operand)] += count + + if len(new_operands) == 0: + return Constant(new_constant) # type: ignore + elif sum(new_operands.values()) == 1 and new_constant == self._identity: + return next(iter(new_operands)).node + + if new_constant == 0: + return Constant(self._gf(0)) + + return Product(new_operands, self._gf, new_constant) # type: ignore + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + # TODO: This could be done more efficiently by going breadth-wise + front = CostParetoFront(cost_of_squaring) + + for operands in itertools.product( + *( + operand.node.arithmetize_depth_aware(cost_of_squaring) + for operand in self._operands + ) + ): + multiplication_tree = _generate_multiplication_tree( + ((d, operand) for d, _, operand in operands), self._operands.values() + ) + if 
self._constant != self._identity: + if isinstance(multiplication_tree[1], Constant): + return CostParetoFront.from_leaf( + Constant(multiplication_tree[1]._value * self._constant), cost_of_squaring + ) + + multiplication_tree = ( + multiplication_tree[0], + ConstantMultiplication(multiplication_tree[1], self._constant), + ) + front.add(multiplication_tree[1], depth=multiplication_tree[0]) + + assert not front.is_empty() + return front + + def to_arithmetic(self) -> ArithmeticNode: # noqa: D102 + if self._arithmetic_cache is None: + # FIXME: Perform actual rebalancing + operands = iter(self._operands.elements()) + + if self._constant == self._identity: + self._arithmetic_cache = Multiplication( + next(operands).node.to_arithmetic(), + next(operands).node.to_arithmetic(), + self._gf, + ) + else: + self._arithmetic_cache = ConstantMultiplication( + next(operands).node.to_arithmetic(), self._constant + ) + + for operand in operands: + self._arithmetic_cache = Multiplication( + self._arithmetic_cache, operand.node.to_arithmetic(), self._gf + ) + + return self._arithmetic_cache + + def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 + if self._evaluate_cache is None: + self._evaluate_cache = reduce(lambda a, b: a * b, (operand.node.evaluate(actual_inputs) ** count for operand, count in self._operands.items())) # type: ignore + self._evaluate_cache *= self._constant # type: ignore + + return self._evaluate_cache # type: ignore + + def mul_flatten(self, other: Node) -> Node: + """Multiplies this node with `other`, flattening the product if either of the two is also a `Product` and absorbing `Constant`s. + + Returns: + A `Product` node containing the flattened product, or a `Constant` node. + """ + # TODO: Consider already assigning values to e.g. 
result._depth + if isinstance(other, Product): + # TODO: Wrap powers (due to modulo arithmetic) + return Product(self._operands + other._operands, self._gf, self._constant * other._constant) # type: ignore + elif isinstance(other, Constant): + if other._value == 0: + return Constant(self._gf(0)) + return Product(self._operands, self._gf, self._constant * other._value) # type: ignore + + counter = self._operands.copy() + counter[UnoverloadedWrapper(other)] += 1 # type: ignore + return Product(counter, self._gf, self._constant) + + +def _first_gf(*operands: Union[Node, int, bool]) -> Optional[Type[FieldArray]]: + for operand in operands: + if isinstance(operand, Node): + return operand._gf + + +def sum_(*operands: Union[Node, int, bool]) -> Sum: + """Performs a sum between any number of nodes (or operands such as integers). + + Returns: + A `Sum` between all operands. + """ + assert len(operands) > 0 + gf = _first_gf(*operands) + assert gf is not None + return Sum(Counter(UnoverloadedWrapper(_to_node(operand, gf)) for operand in operands), gf) + + +def product_(*operands: Node) -> Product: + """Performs a product between any number of nodes (or operands such as integers). + + Returns: + A `Product` between all operands. 
+ """ + assert len(operands) > 0 + gf = _first_gf(*operands) + assert gf is not None + return Product(Counter(UnoverloadedWrapper(_to_node(operand, gf)) for operand in operands), gf) diff --git a/oraqle/compiler/nodes/binary_arithmetic.py b/oraqle/compiler/nodes/binary_arithmetic.py new file mode 100644 index 0000000..ec5daac --- /dev/null +++ b/oraqle/compiler/nodes/binary_arithmetic.py @@ -0,0 +1,264 @@ +"""Module containing binary arithmetic nodes: additions and multiplications between non-constant nodes.""" +from abc import abstractmethod +from typing import List, Optional, Set, Tuple, Type + +from galois import FieldArray + +from oraqle.compiler.instructions import ( + AdditionInstruction, + ArithmeticInstruction, + MultiplicationInstruction, +) +from oraqle.compiler.nodes.abstract import ( + ArithmeticNode, + CostParetoFront, + Node, + iterate_increasing_depth, + select_stack_index, +) +from oraqle.compiler.nodes.fixed import BinaryNode +from oraqle.compiler.nodes.leafs import Constant + + +class CommutativeBinaryNode(BinaryNode): + """This node has two operands and implements a commutative operation between arithmetic nodes.""" + + def __init__( + self, + left: Node, + right: Node, + gf: Type[FieldArray], + ): + """Initialize the binary node with operands `left` and `right`.""" + self._left = left + self._right = right + super().__init__(gf) + + @abstractmethod + def _operation_inner(self, x: FieldArray, y: FieldArray) -> FieldArray: + """Applies the binary operation on x and y.""" + + def operation(self, operands: List[FieldArray]) -> FieldArray: # noqa: D102 + return self._operation_inner(operands[0], operands[1]) + + def operands(self) -> List[Node]: # noqa: D102 + return [self._left, self._right] + + def set_operands(self, operands: List[ArithmeticNode]): # noqa: D102 + self._left = operands[0] + self._right = operands[1] + + def __hash__(self) -> int: + if self._hash is None: + left_hash = hash(self._left) + right_hash = hash(self._right) + + # Make the 
hash commutative + if left_hash < right_hash: + self._hash = hash((self._hash_name, (left_hash, right_hash))) + else: + self._hash = hash((self._hash_name, (right_hash, left_hash))) + + return self._hash + + def is_equivalent(self, other: Node) -> bool: # noqa: D102 + if not isinstance(other, self.__class__): + return False + + if hash(self) != hash(other): + return False + + # Equivalence by commutative equality + return ( + self._left.is_equivalent(other._left) and self._right.is_equivalent(other._right) + ) or (self._left.is_equivalent(other._right) and self._right.is_equivalent(other._left)) + + +class CommutativeArithmeticBinaryNode(CommutativeBinaryNode): + """This node has two operands and implements a commutative operation between arithmetic nodes.""" + + def __init__( + self, + left: ArithmeticNode, + right: ArithmeticNode, + gf: Type[FieldArray], + ): + """Initialize this binary node with the given `left` and `right` operands. + + Raises: + Exception: Neither `left` nor `right` is allowed to be a `Constant`. 
+ """ + super().__init__(left, right, gf) + + self._multiplications: Optional[Set[int]] = None + self._squarings: Optional[Set[int]] = None + self._depth_cache: Optional[int] = None + + if isinstance(left, Constant) or isinstance(right, Constant): + self._is_multiplication = False + raise Exception("This should be a constant.") + + def multiplicative_depth(self) -> int: # noqa: D102 + if self._depth_cache is None: + self._depth_cache = self._is_multiplication + max( + self._left.multiplicative_depth(), self._right.multiplicative_depth() + ) + + return self._depth_cache + + def multiplications(self) -> Set[int]: # noqa: D102 + if self._multiplications is None: + self._multiplications = set().union( + *(operand.multiplications() for operand in self.operands()) # type: ignore + ) + if self._is_multiplication: + self._multiplications.add(id(self)) + + return self._multiplications + + # TODO: Squaring should probably be a UniveriateNode + def squarings(self) -> Set[int]: # noqa: D102 + if self._squarings is None: + self._squarings = set().union(*(operand.squarings() for operand in self.operands())) # type: ignore + if self._is_multiplication and id(self._left) == id(self._right): + self._squarings.add(id(self)) + + return self._squarings + + def create_instructions( # noqa: D102 + self, + instructions: List[ArithmeticInstruction], + stack_counter: int, + stack_occupied: List[bool], + ) -> Tuple[int, int]: + self._left: ArithmeticNode + self._right: ArithmeticNode + + if self._instruction_cache is None: + left_index, stack_counter = self._left.create_instructions( + instructions, stack_counter, stack_occupied + ) + right_index, stack_counter = self._right.create_instructions( + instructions, stack_counter, stack_occupied + ) + + # FIXME: Is it possible for e.g. self._left._instruction_cache to be None? 
+ + self._left._parent_count -= 1 + if self._left._parent_count == 0: + stack_occupied[self._left._instruction_cache] = False # type: ignore + + self._right._parent_count -= 1 + if self._right._parent_count == 0: + stack_occupied[self._right._instruction_cache] = False # type: ignore + + self._instruction_cache = select_stack_index(stack_occupied) + + if self._is_multiplication: + instructions.append( + MultiplicationInstruction(self._instruction_cache, left_index, right_index) + ) + else: + instructions.append( + AdditionInstruction(self._instruction_cache, left_index, right_index) + ) + + return self._instruction_cache, stack_counter + + +# FIXME: This order should probably change +class Addition(CommutativeArithmeticBinaryNode, ArithmeticNode): + """Performs modular addition of two previous nodes in an arithmetic circuit.""" + + @property + def _overriden_graphviz_attributes(self) -> dict: + return {"shape": "square", "style": "rounded,filled", "fillcolor": "grey80"} + + @property + def _hash_name(self) -> str: + return "add" + + @property + def _node_label(self) -> str: + return "+" + + def __init__( + self, + left: ArithmeticNode, + right: ArithmeticNode, + gf: Type[FieldArray], + ): + """Initialize a modular addition between `left` and `right`.""" + self._is_multiplication = False + super().__init__(left, right, gf) + + def _operation_inner(self, x, y): + return x + y + + def arithmetize(self, strategy: str) -> Node: # noqa: D102 + self._left = self._left.arithmetize(strategy) + self._right = self._right.arithmetize(strategy) + return self + + def _arithmetize_inner(self, strategy: str) -> Node: + raise NotImplementedError() + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + front = CostParetoFront(cost_of_squaring) + + for res1, res2 in iterate_increasing_depth( + self._left.arithmetize_depth_aware(cost_of_squaring), + self._right.arithmetize_depth_aware(cost_of_squaring), + ): + d1, _, e1 = res1 + d2, _, e2 = res2 + 
+ # TODO: Do we use + here for flattening? + front.add(Addition(e1, e2, self._gf), depth=max(d1, d2)) + + assert not front.is_empty() + return front + + +class Multiplication(CommutativeArithmeticBinaryNode, ArithmeticNode): + """Performs modular multiplication of two previous nodes in an arithmetic circuit.""" + + @property + def _overriden_graphviz_attributes(self) -> dict: + return {"shape": "square", "style": "rounded,filled", "fillcolor": "lightpink"} + + @property + def _hash_name(self) -> str: + return "mul" + + @property + def _node_label(self) -> str: + return "ร—" # noqa: RUF001 + + def __init__( + self, + left: ArithmeticNode, + right: ArithmeticNode, + gf: Type[FieldArray], + ): + """Initialize a modular multiplication between `left` and `right`.""" + assert isinstance(left, ArithmeticNode) + assert isinstance(right, ArithmeticNode) + + self._is_multiplication = True + super().__init__(left, right, gf) + + def _operation_inner(self, x, y): + return x * y + + # TODO: This is very hacky! Arithmetic nodes should simply not have to be arithmetized... 
+ def arithmetize(self, strategy: str) -> Node: # noqa: D102 + self._left = self._left.arithmetize(strategy) + self._right = self._right.arithmetize(strategy) + return self + + def _arithmetize_inner(self, strategy: str) -> Node: + raise NotImplementedError() + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + return CostParetoFront.from_node(self, cost_of_squaring) diff --git a/oraqle/compiler/nodes/fixed.py b/oraqle/compiler/nodes/fixed.py new file mode 100644 index 0000000..767d403 --- /dev/null +++ b/oraqle/compiler/nodes/fixed.py @@ -0,0 +1,100 @@ +"""Module containing fixed nodes: nodes with a fixed number of inputs.""" +from abc import abstractmethod +from typing import Callable, Dict, List + +from galois import FieldArray + +from oraqle.compiler.nodes.abstract import CostParetoFront, Node + + +class FixedNode(Node): + """A node with a fixed number of operands.""" + + @abstractmethod + def operands(self) -> List["Node"]: + """Returns the operands (children) of this node. The list can be empty.""" + + @abstractmethod + def set_operands(self, operands: List["Node"]): + """Overwrites the operands of this node.""" + # TODO: Consider replacing this method with a graph traversal method that applies a function on all operands and replaces them. + + + def apply_function_to_operands(self, function: Callable[[Node], None]): # noqa: D102 + for operand in self.operands(): + function(operand) + + + def replace_operands_using_function(self, function: Callable[[Node], Node]): # noqa: D102 + self.set_operands([function(operand) for operand in self.operands()]) + # TODO: These caches should only be cleared if this is an ArithmeticNode + self._multiplications = None + self._squarings = None + self._depth_cache = None + + + def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 + # TODO: Remove modulus in this method and store it in each node instead. 
Alternatively, add `modulus` to methods such as `flatten` as well. + if self._evaluate_cache is None: + self._evaluate_cache = self.operation( + [operand.evaluate(actual_inputs) for operand in self.operands()] + ) + + return self._evaluate_cache + + @abstractmethod + def operation(self, operands: List[FieldArray]) -> FieldArray: + """Evaluates this node on the specified operands.""" + + def arithmetize(self, strategy: str) -> "Node": # noqa: D102 + if self._arithmetize_cache is None: + if self._arithmetize_depth_cache is not None: + return self._arithmetize_depth_cache.get_lowest_value() # type: ignore + + # If we know all operands we can simply evaluate this node + operands = self.operands() + if len(operands) > 0 and all( + hasattr(operand, "_value") for operand in operands + ): # This is a hacky way of checking whether the operands are all constant + from oraqle.compiler.nodes.leafs import Constant + + self._arithmetize_cache = Constant(self.operation([operand._value for operand in self.operands()])) # type: ignore + else: + self._arithmetize_cache = self._arithmetize_inner(strategy) + + return self._arithmetize_cache + + @abstractmethod + def _arithmetize_inner(self, strategy: str) -> "Node": + pass + + # TODO: Reduce code duplication + + def arithmetize_depth_aware(self, cost_of_squaring: float) -> CostParetoFront: # noqa: D102 + if self._arithmetize_depth_cache is None: + if self._arithmetize_cache is not None: + raise Exception("This should not happen") + + # If we know all operands we can simply evaluate this node + operands = self.operands() + if len(operands) > 0 and all( + hasattr(operand, "_value") for operand in operands + ): # This is a hacky way of checking whether the operands are all constant + from oraqle.compiler.nodes.leafs import Constant + + self._arithmetize_depth_cache = CostParetoFront.from_leaf(Constant(self.operation([operand._value for operand in self.operands()])), cost_of_squaring) # type: ignore + else: + self._arithmetize_depth_cache 
= self._arithmetize_depth_aware_inner( + cost_of_squaring + ) + + assert self._arithmetize_depth_cache is not None + return self._arithmetize_depth_cache + + @abstractmethod + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + pass + + +class BinaryNode(FixedNode): + """A node with two operands.""" diff --git a/oraqle/compiler/nodes/flexible.py b/oraqle/compiler/nodes/flexible.py new file mode 100644 index 0000000..b5cb0be --- /dev/null +++ b/oraqle/compiler/nodes/flexible.py @@ -0,0 +1,164 @@ +"""Module containing nodes with a flexible number of operands.""" +from abc import abstractmethod +from collections import Counter +from functools import reduce +from typing import Callable +from typing import Counter as CounterType +from typing import Dict, Optional, Set, Type + +from galois import FieldArray + +from oraqle.compiler.graphviz import DotFile +from oraqle.compiler.nodes.abstract import CostParetoFront, Node, UnoverloadedWrapper +from oraqle.compiler.nodes.leafs import Constant + + +class FlexibleNode(Node): + """A node with an arbitrary number of operands. 
The operation must be reducible using a binary associative operation.""" + + # TODO: Ensure that when all inputs are constants, the node is replaced with its evaluation + + def arithmetize(self, strategy: str) -> Node: # noqa: D102 + if self._arithmetize_cache is None: + self._arithmetize_cache = self._arithmetize_inner(strategy) + + return self._arithmetize_cache + + @abstractmethod + def _arithmetize_inner(self, strategy: str) -> "Node": + pass + + def arithmetize_depth_aware(self, cost_of_squaring: float) -> CostParetoFront: # noqa: D102 + if self._arithmetize_depth_cache is None: + self._arithmetize_depth_cache = self._arithmetize_depth_aware_inner(cost_of_squaring) + + return self._arithmetize_depth_cache + + @abstractmethod + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + pass + + +class CommutativeUniqueReducibleNode(FlexibleNode): + """A node with an operation that is reducible without taking order into account: i.e. it has a binary operation that is associative and commutative. + + The operands are unique, i.e. the same operand will never appear twice. + """ + + def __init__( + self, + operands: Set[UnoverloadedWrapper], + gf: Type[FieldArray], + ): + """Initialize a node with the given set as the operands. 
None of the operands can be a constant.""" + self._operands = operands + assert not any(isinstance(operand.node, Constant) for operand in self._operands) + assert len(operands) > 1 + super().__init__(gf) + + def apply_function_to_operands(self, function: Callable[[Node], None]): # noqa: D102 + for operand in self._operands: + function(operand.node) + + def replace_operands_using_function(self, function: Callable[[Node], Node]): # noqa: D102 + self._operands = {UnoverloadedWrapper(function(operand.node)) for operand in self._operands} + + def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 + if self._evaluate_cache is None: + self._evaluate_cache = reduce( + self._inner_operation, + (operand.node.evaluate(actual_inputs) for operand in self._operands), + ) + + return self._evaluate_cache + + @abstractmethod + def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray: + """Perform the reducible operation performed by this node (order should not matter).""" + + def __hash__(self) -> int: + if self._hash is None: + # The hash is commutative + hashes = sorted([hash(operand) for operand in self._operands]) + self._hash = hash((self._hash_name, tuple(hashes))) + + return self._hash + + def is_equivalent(self, other: Node) -> bool: # noqa: D102 + if not isinstance(other, self.__class__): + return False + + if hash(self) != hash(other): + return False + + return self._operands == other._operands + + +class CommutativeMultiplicityReducibleNode(FlexibleNode): + """A node with an operation that is reducible without taking order into account: i.e. 
it has a binary operation that is associative and commutative.""" + + def __init__( + self, + operands: CounterType[UnoverloadedWrapper], + gf: Type[FieldArray], + constant: Optional[FieldArray] = None, + ): + """Initialize a reducible node with the given `Counter` representing the operands, none of which is allowed to be a constant.""" + super().__init__(gf) + self._constant = self._identity if constant is None else constant + self._operands = operands + assert not any(isinstance(operand, Constant) for operand in self._operands) + assert (sum(operands.values()) + (self._constant != self._identity)) > 1 + assert isinstance(next(iter(self._operands)), UnoverloadedWrapper) + + @property + @abstractmethod + def _identity(self) -> FieldArray: + pass + + def apply_function_to_operands(self, function: Callable[[Node], None]): # noqa: D102 + for operand in self._operands: + function(operand.node) + + def replace_operands_using_function(self, function: Callable[[Node], Node]): # noqa: D102 + # FIXME: What if there is only one operand remaining? 
+ self._operands = Counter( + { + UnoverloadedWrapper(function(operand.node)): count + for operand, count in self._operands.items() + } + ) + assert not any(isinstance(operand.node, Constant) for operand in self._operands) + assert (sum(self._operands.values()) + (self._constant != self._identity)) > 1 + + def __hash__(self) -> int: + if self._hash is None: + # The hash is commutative + hashes = sorted( + [(hash(operand.node), count) for operand, count in self._operands.items()] + ) + self._hash = hash((self._hash_name, tuple(hashes), int(self._constant))) + + return self._hash + + def is_equivalent(self, other: Node) -> bool: # noqa: D102 + if not isinstance(other, self.__class__): + return False + + if hash(self) != hash(other): + return False + + return self._operands == other._operands and self._constant == other._constant + + def to_graph(self, graph_builder: DotFile) -> int: # noqa: D102 + if self._to_graph_cache is None: + super().to_graph(graph_builder) + self._to_graph_cache: int + + if self._constant != self._identity: + # TODO: Add known_by + graph_builder.add_link( + graph_builder.add_node(label=str(self._constant)), self._to_graph_cache + ) + + return self._to_graph_cache diff --git a/oraqle/compiler/nodes/leafs.py b/oraqle/compiler/nodes/leafs.py new file mode 100644 index 0000000..35a6725 --- /dev/null +++ b/oraqle/compiler/nodes/leafs.py @@ -0,0 +1,192 @@ +"""Module containing leaf nodes: i.e. 
nodes without an input.""" +from typing import Any, Dict, List, Set, Tuple, Type + +from galois import FieldArray + +from oraqle.compiler.graphviz import DotFile +from oraqle.compiler.instructions import ArithmeticInstruction, InputInstruction +from oraqle.compiler.nodes.abstract import ArithmeticNode, CostParetoFront, Node, select_stack_index +from oraqle.compiler.nodes.fixed import FixedNode + + +class ArithmeticLeafNode(FixedNode, ArithmeticNode): + """An ArithmeticLeafNode is an ArithmeticNode with no inputs.""" + + def operands(self) -> List[Node]: # noqa: D102 + return [] + + def set_operands(self, operands: List["Node"]): # noqa: D102 + pass + + def _arithmetize_inner(self, strategy: str) -> Node: + return self + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + return CostParetoFront.from_leaf(self, cost_of_squaring) + + def multiplicative_depth(self) -> int: # noqa: D102 + return 0 + + def multiplicative_size(self) -> int: # noqa: D102 + return 0 + + def multiplications(self) -> Set[int]: # noqa: D102 + return set() + + def squarings(self) -> Set[int]: # noqa: D102 + return set() + + +# TODO: Merge ArithmeticInput and Input using multiple inheritance +class Input(ArithmeticLeafNode): + """Represents a named input to the arithmetic circuit.""" + + @property + def _overriden_graphviz_attributes(self) -> dict: + return {"shape": "circle", "style": "filled", "fillcolor": "lightsteelblue1"} + + @property + def _hash_name(self) -> str: + return "input" + + @property + def _node_label(self) -> str: + return self._name + + def __init__(self, name: str, gf: Type[FieldArray]) -> None: + """Initialize an input with the given `name`.""" + super().__init__(gf) + self._name = name + + + def operation(self, operands: List[FieldArray]) -> FieldArray: # noqa: D102 + raise Exception() + + + def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 + return actual_inputs[self._name] + + + def to_graph(self, 
class Constant(ArithmeticLeafNode):
    """Represents a Node with a constant value.

    Binary operations between two constants are folded immediately into a new
    `Constant`. Operations with any other kind of node are delegated to that
    node with the operands swapped.
    """

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return {"style": "filled", "fillcolor": "red", "shape": "circle"}

    @property
    def _hash_name(self) -> str:
        return "constant"

    @property
    def _node_label(self) -> str:
        return str(self._value)

    def __init__(self, value: FieldArray):
        """Initialize a Node with the given `value`."""
        super().__init__(value.__class__)
        self._value = value

    def operation(self, operands: List[FieldArray]) -> FieldArray:  # noqa: D102
        # The stored value is returned regardless of `operands`.
        return self._value

    def to_graph(self, graph_builder: DotFile) -> Any:  # noqa: D102
        # The graph node is created once; later calls return the cached identifier.
        if self._to_graph_cache is None:
            label = str(self._value)

            self._to_graph_cache = graph_builder.add_node(
                label=label, **self._overriden_graphviz_attributes
            )

        return self._to_graph_cache

    def __hash__(self) -> int:
        return hash(int(self._value))

    def is_equivalent(self, other: Node) -> bool:  # noqa: D102
        if not isinstance(other, self.__class__):
            return False

        # The element comparison may yield an array-typed truth value; collapse
        # it so that callers receive the annotated plain `bool`.
        return bool(self._value == other._value)

    def add(self, other: "Node", flatten=True) -> "Node":  # noqa: D102
        if isinstance(other, Constant):
            return Constant(self._value + other._value)

        return other.add(self, flatten)

    def mul(self, other: "Node", flatten=True) -> "Node":  # noqa: D102
        if isinstance(other, Constant):
            return Constant(self._value * other._value)

        return other.mul(self, flatten)

    def bool_or(self, other: "Node", flatten=True) -> Node:  # noqa: D102
        if isinstance(other, Constant):
            return Constant(self._gf(bool(self._value) | bool(other._value)))

        return other.bool_or(self, flatten)

    def bool_and(self, other: "Node", flatten=True) -> Node:  # noqa: D102
        if isinstance(other, Constant):
            return Constant(self._gf(bool(self._value) & bool(other._value)))

        return other.bool_and(self, flatten)

    def create_instructions(  # noqa: D102
        self,
        instructions: List[ArithmeticInstruction],
        stack_counter: int,
        stack_occupied: List[bool],
    ) -> Tuple[int, int]:
        # Always raises; the annotation matches the signature used by the other nodes.
        raise NotImplementedError("The circuit is a constant.")
def to_graph(self, graph_builder: DotFile) -> int:  # noqa: D102
    # Rendered on an earlier call: hand back the stored node identifier.
    if self._to_graph_cache is not None:
        return self._to_graph_cache

    # Subclass overrides take precedence over the default box shape.
    node_attributes = {"shape": "box", **self._overriden_graphviz_attributes}
    node_id = graph_builder.add_node(label=self._node_label, **node_attributes)
    self._to_graph_cache = node_id

    left_id = self._left.to_graph(graph_builder)
    right_id = self._right.to_graph(graph_builder)

    # The head ports keep the left and right operands visually distinguishable.
    graph_builder.add_link(left_id, node_id, headport="nw")
    graph_builder.add_link(right_id, node_id, headport="ne")

    return node_id
class ConstantAddition(UnivariateNode, ArithmeticNode):
    """This node represents an addition of another node with a constant."""

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return {"style": "rounded,filled", "fillcolor": "grey80"}

    @property
    def _node_shape(self) -> str:
        return "square"

    @property
    def _hash_name(self) -> str:
        # The constant is part of the hash name, so additions of different constants hash differently.
        return f"constant_add_{self._constant}"

    @property
    def _node_label(self) -> str:
        return "+"

    def __init__(self, node: ArithmeticNode, constant: FieldArray):
        """Represents the operation `constant + node`."""
        super().__init__(node, constant.__class__)
        self._constant = constant
        # Adding zero would be a no-op, so such a node must not be constructed.
        assert constant != 0

        self._depth_cache: Optional[int] = None


    def _operation_inner(self, input: FieldArray) -> FieldArray:
        return input + self._constant


    def multiplicative_depth(self) -> int:  # noqa: D102
        # The depth equals that of the operand; it is computed once and cached.
        if self._depth_cache is None:
            self._depth_cache = self._node.multiplicative_depth()

        return self._depth_cache


    def multiplications(self) -> Set[int]:  # noqa: D102
        # This node contributes no entries of its own; the operand's set is returned as is.
        return self._node.multiplications()


    def squarings(self) -> Set[int]:  # noqa: D102
        # This node contributes no entries of its own; the operand's set is returned as is.
        return self._node.squarings()


    def create_instructions(  # noqa: D102
        self,
        instructions: List[ArithmeticInstruction],
        stack_counter: int,
        stack_occupied: List[bool],
    ) -> Tuple[int, int]:
        self._node: ArithmeticNode

        # Instructions are generated only on the first call; later calls reuse the cached stack index.
        if self._instruction_cache is None:
            operand_index, stack_counter = self._node.create_instructions(
                instructions, stack_counter, stack_occupied
            )

            # This node has now consumed the operand. When no parents remain, the
            # operand's stack slot is marked free *before* the output slot is selected.
            self._node._parent_count -= 1
            if self._node._parent_count == 0:
                stack_occupied[self._node._instruction_cache] = False  # type: ignore

            self._instruction_cache = select_stack_index(stack_occupied)

            instructions.append(
                ConstantAdditionInstruction(self._instruction_cache, operand_index, self._constant)
            )

        return self._instruction_cache, stack_counter


    def _arithmetize_inner(self, strategy: str) -> Node:
        return self


    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        # Wrap every candidate arithmetization of the operand in the same constant addition.
        front = CostParetoFront(cost_of_squaring)
        for _, _, node in self._node.arithmetize_depth_aware(cost_of_squaring):
            front.add(ConstantAddition(node, self._constant))
        return front


    def to_graph(self, graph_builder: DotFile) -> int:  # noqa: D102
        if self._to_graph_cache is None:
            super().to_graph(graph_builder)
            self._to_graph_cache: int

            # The constant is drawn as its own circle node linked into this node.
            # TODO: Add known_by
            graph_builder.add_link(
                graph_builder.add_node(
                    label=str(self._constant), shape="circle", style="filled", fillcolor="grey92"
                ),
                self._to_graph_cache,
            )

        return self._to_graph_cache
def create_instructions(  # noqa: D102
    self,
    instructions: List[ArithmeticInstruction],
    stack_counter: int,
    stack_occupied: List[bool],
) -> Tuple[int, int]:
    # Instructions were already emitted on an earlier visit: reuse the stored index.
    if self._instruction_cache is not None:
        return self._instruction_cache, stack_counter

    operand: ArithmeticNode = self._node
    operand_index, stack_counter = operand.create_instructions(
        instructions, stack_counter, stack_occupied
    )

    # This node has consumed the operand. When no parents remain, mark the
    # operand's stack slot as free before choosing the output slot.
    operand._parent_count -= 1
    if operand._parent_count == 0:
        stack_occupied[operand._instruction_cache] = False  # type: ignore

    destination = select_stack_index(stack_occupied)
    self._instruction_cache = destination

    instructions.append(
        ConstantMultiplicationInstruction(destination, operand_index, self._constant)
    )

    return destination, stack_counter
class UnivariateNode(FixedNode):
    """Abstract base for nodes that take exactly one operand."""

    @property
    @abstractmethod
    def _node_shape(self) -> str:
        """Graphviz node shape."""

    def __init__(self, node: Node, gf: Type[FieldArray]):
        """Create a univariate node whose single operand is `node`."""
        self._node = node
        assert not isinstance(node, Constant)
        super().__init__(gf)

    def operands(self) -> List["Node"]:  # noqa: D102
        return [self._node]

    def set_operands(self, operands: List["Node"]):  # noqa: D102
        # Only the first entry is used.
        self._node = operands[0]

    @abstractmethod
    def _operation_inner(self, input: FieldArray) -> FieldArray:
        """Evaluate the operation on the input. This method does not have to cache."""

    def operation(self, operands: List[FieldArray]) -> FieldArray:  # noqa: D102
        (value, *_ignored) = operands[:1] or [operands[0]]
        return self._operation_inner(value)

    def to_graph(self, graph_builder: DotFile) -> int:  # noqa: D102
        # Rendered on an earlier call: hand back the stored node identifier.
        if self._to_graph_cache is not None:
            return self._to_graph_cache

        extra_attributes = dict(self._overriden_graphviz_attributes)
        own_id = graph_builder.add_node(
            label=self._node_label, shape=self._node_shape, **extra_attributes
        )
        self._to_graph_cache = own_id

        operand_id = self._node.to_graph(graph_builder)
        graph_builder.add_link(operand_id, own_id)

        return own_id

    def __hash__(self) -> int:
        # The hash combines the node kind with the operand and is computed once.
        cached = self._hash
        if cached is None:
            cached = self._hash = hash((self._hash_name, self._node))

        return cached

    def is_equivalent(self, other: Node) -> bool:
        """Check whether `self` is semantically equivalent to `other`.

        This function may have false negatives but it should never return false positives.

        Returns:
        -------
        `True` if `self` is semantically equivalent to `other`, `False` if they are not or that they cannot be shown to be equivalent.

        """
        same_kind = isinstance(other, self.__class__)
        if not same_kind or hash(self) != hash(other):
            return False

        return self._node.is_equivalent(other._node)
def construct_subcircuit(expression, gf, modulus: int, inputs: Dict[str, Input]) -> Node:  # noqa: PLR0912
    """Translate a Sympy expression made of simple arithmetic into a single-output subcircuit.

    Sums, products, integer powers and symbols are supported. Symbols are turned into
    `Input` nodes, which are shared through (and added to) the `inputs` dictionary.

    Raises:
    ------
    Exception: If an exponent is not an integer, or if the expression contains an unsupported operation.

    Returns:
    -------
    A subcircuit (Node) computing the given sympy expression.

    """

    def build(subexpression) -> Node:
        return construct_subcircuit(subexpression, gf, modulus, inputs)

    def fold(combine) -> Node:
        operands = iter(expression.args)

        # The first argument can be a scalar.
        head = next(operands)
        if head.func == Integer:
            accumulated = Constant(gf(int(head) % modulus))
        elif head.func == NegativeOne:
            accumulated = Constant(-gf(1))
        else:
            # TODO: Replace this entire part with a sum / product
            accumulated = build(head)

        accumulated = combine(accumulated, build(next(operands)))

        # Every further operand is combined on the left of the running result.
        for operand in operands:
            accumulated = combine(build(operand), accumulated)

        return accumulated

    operation = expression.func

    if operation == Add:
        return fold(lambda left, right: left + right)

    if operation == Mul:
        return fold(lambda left, right: left * right)

    if operation == Pow:
        exponent = expression.args[1]
        if exponent.func != Integer:
            raise Exception("There was an exponent with a non-integer exponent")
        # Change powers to series of multiplications
        base = build(expression.args[0])
        # TODO: This is not the most efficient way; we can use re-balancing.
        return Product(
            Counter({UnoverloadedWrapper(base): int(exponent)}), gf
        )  # FIXME: This could be flattened

    if operation == Symbol:
        assert len(expression.args) == 0
        name = str(expression)
        # Reuse the existing input for a known variable, otherwise register a new one.
        if name not in inputs:
            inputs[name] = Input(name, gf)
        return inputs[name]

    raise Exception(
        f"The expression contained an invalid operation (not one implemented in arithmetic circuits): {expression.func}."
    )
if __name__ == "__main__":
    # Demo: interpolate max(x, y) modulo 7, build a circuit for it, and inspect it.
    # Use function max(x, y)
    function = max
    modulus = 7

    # Create a polynomial and then a circuit that evaluates this expression
    poly = interpolate_polynomial(function, modulus, ["x", "y"])
    circuit, gf = construct_circuit([poly], modulus)

    # Output a DOT file for this high-level circuit (you can visualize it using https://dreampuf.github.io/GraphvizOnline/)
    circuit.to_graph("max_7_hl.dot")

    # Arithmetize the high-level circuit, afterwards it will only contain arithmetic operations
    circuit = circuit.arithmetize()
    # NOTE(review): this reuses the file name of the high-level circuit written above,
    # so that file presumably gets overwritten - confirm this is intended.
    circuit.to_graph("max_7_hl.dot")

    # Print the initial metrics of the circuit
    print("depth", circuit.multiplicative_depth())
    print("size", circuit.multiplicative_size())

    # Apply common subexpression elimination (CSE) to remove duplicate operations from the circuit
    circuit.eliminate_subexpressions()

    # Output a DOT file for this arithmetic circuit (you can visualize it using https://dreampuf.github.io/GraphvizOnline/)
    circuit.to_graph("max_7.dot")

    # Print the resulting metrics of the circuit
    print("depth", circuit.multiplicative_depth())
    print("size", circuit.multiplicative_size())

    # Test that given x=4 and y=2 indeed max(x, y) = 4
    assert circuit.evaluate({"x": gf(4), "y": gf(2)}) == [4]

    # Output a DOT file for this arithmetic circuit (you can visualize it using https://dreampuf.github.io/GraphvizOnline/)
    circuit.to_graph("max_7.dot")
+So, you can perform any function by interpolating a polynomial. +""" diff --git a/oraqle/compiler/polynomials/univariate.py b/oraqle/compiler/polynomials/univariate.py new file mode 100644 index 0000000..cd306c8 --- /dev/null +++ b/oraqle/compiler/polynomials/univariate.py @@ -0,0 +1,620 @@ +"""Evaluation of univariate polynomials.""" + +import math +from typing import Callable, Dict, List, Optional, Tuple, Type + +from galois import GF, FieldArray + +from oraqle.add_chains.addition_chains_heuristic import add_chain_guaranteed +from oraqle.compiler.arithmetic.subtraction import Subtraction +from oraqle.compiler.func2poly import interpolate_polynomial +from oraqle.compiler.nodes.abstract import ArithmeticNode, CostParetoFront, Node +from oraqle.compiler.nodes.binary_arithmetic import Multiplication +from oraqle.compiler.nodes.leafs import Constant, Input +from oraqle.compiler.nodes.unary_arithmetic import ConstantMultiplication +from oraqle.compiler.nodes.univariate import UnivariateNode +from oraqle.config import PS_METHOD_FACTOR_K + + +def _format_polynomial(coefficients: List[FieldArray]) -> str: + degree = len(coefficients) - 1 + if degree == 0: + return str(coefficients[0]) + + terms = [] + for i, coef in enumerate(coefficients): + if coef == 0: + # Skip zero coefficients + continue + + term = str(coef) if i == 0 or coef > 1 else "" + + if i > 0: + term += "x" + + if i > 1: + term += f"^{i}" + + if term != "": + terms.append(term) + + polynomial = " + ".join(terms) + return polynomial + + +class UnivariatePoly(UnivariateNode): + """Evaluation of a univariate polynomial.""" + + @property + def _node_shape(self) -> str: + return "box" + + @property + def _hash_name(self) -> str: + return "univariate_poly" + + @property + def _node_label(self) -> str: + return _format_polynomial(self._coefficients) + + def __init__( + self, + node: Node, + coefficients: List[FieldArray], + gf: Type[FieldArray], + ): + """Initialize a univariate polynomial with the given 
coefficients from least to highest order.""" + self._coefficients = coefficients + # TODO: We can reduce this polynomial if its degree is too high + super().__init__(node, gf) + + self._custom_arithmetize_cache = None + + @classmethod + def from_function( + cls, node: Node, gf: Type[FieldArray], function: Callable[[int], int] + ) -> "UnivariatePoly": + """Interpolate a univariate polynomial for the given function. + + Returns: + ------- + A UnivariatePoly whose coefficients compute the `function` on all inputs. + + """ + coefficients = [ + gf(int(coeff) % gf.characteristic) + for coeff in reversed( + interpolate_polynomial(function, gf.characteristic, ["x"]).as_list() + ) + ] + return cls(node, coefficients, gf) + + def _operation_inner(self, input: FieldArray) -> FieldArray: + coefficient_iter = iter(self._coefficients) + result = next(coefficient_iter).copy() + + x_pow = input.copy() + for coefficient in coefficient_iter: + result += coefficient * x_pow + x_pow *= input + + return result # type: ignore + + def _arithmetize_inner(self, strategy: str) -> Node: + return self.arithmetize_custom(strategy)[0] + + def arithmetize_custom(self, strategy: str) -> Tuple[ArithmeticNode, Dict[int, ArithmeticNode]]: + """Compute an arithmetization along with a dictionary of precomputed powers. + + Returns: + ------- + An arithmetization and a dictionary of previously computed powers. 
def arithmetize_depth_aware_custom(
    self, cost_of_squaring: float
) -> Tuple[CostParetoFront, Dict[int, Dict[int, ArithmeticNode]]]:
    """Compute a depth-aware arithmetization together with the powers that were precomputed for it.

    Returns:
    -------
    A CostParetoFront with the depth-aware arithmetization, and a dictionary that maps the depth of each node in the front to a dictionary of previously computed powers.

    """
    # TODO: Perhaps this should be cached
    coefficient_count = len(self._coefficients)

    if coefficient_count == 0:
        return CostParetoFront.from_leaf(Constant(self._gf(0)), cost_of_squaring), {0: {}}

    if coefficient_count == 1:
        return CostParetoFront.from_leaf(Constant(self._coefficients[0]), cost_of_squaring), {
            0: {}
        }

    front = CostParetoFront(cost_of_squaring)
    powers_by_depth = {}

    def offer(candidate, powers):
        # Record the powers only for candidates that the front accepts.
        candidate = candidate.to_arithmetic()
        assert isinstance(candidate, ArithmeticNode)

        if front.add(candidate):
            powers_by_depth[candidate.multiplicative_depth()] = powers

    # Upper bound on k for the first evaluation method; it only depends on the coefficient count.
    optimal_k = math.sqrt(2 * coefficient_count)
    bound = min(int(math.ceil(PS_METHOD_FACTOR_K * optimal_k)), coefficient_count)

    for _, _, x in self._node.arithmetize_depth_aware(cost_of_squaring):
        for k in range(1, bound):
            offer(*_eval_poly(x, self._coefficients, k, self._gf, cost_of_squaring))

        for k in range(1, coefficient_count):
            offer(
                *_eval_poly_divide_conquer(x, self._coefficients, k, self._gf, cost_of_squaring)
            )

        for k in range(1, coefficient_count):
            offer(*_eval_poly_alternative(x, self._coefficients, k, self._gf))

    # Keep only the entries for depths that are still present in the front.
    return front, {depth: powers_by_depth[depth] for depth, _, _ in front}
def _eval_poly_using_precomputed_ks(
    coefficients: List[FieldArray], precomputed_ks: List[ArithmeticNode], gf
) -> ArithmeticNode:
    """Build the sum of `coefficients[i]` times `precomputed_ks[i - 1]`, plus the constant term.

    Zero coefficients are skipped, and a coefficient of one adds the precomputed
    node directly without a constant multiplication.
    """
    if len(coefficients) == 0:
        return Constant(gf(0))

    # TODO: What if the constant is 0? Do we want to rely on no-op removal later or do it here already?
    output = Constant(coefficients[0])

    for position, coefficient in enumerate(coefficients):
        if position == 0 or coefficient == 0:
            continue

        power = precomputed_ks[position - 1]
        if coefficient == 1:
            term = power
        else:
            term = Constant(coefficient).mul(power, flatten=False)  # FIXME: Consider just using *

        output += term

    return output.arithmetize("best-effort").to_arithmetic()
<= (len(precomputed_ks) - 1) + + monomial = precomputed_pow2s[int(math.log2(p))] + + c_output = _eval_poly_using_precomputed_ks(c, precomputed_ks, gf) + + left = monomial.add(c_output, flatten=False) + right = _eval_monic_poly_specific(q, precomputed_ks, precomputed_pow2s, gf, p // 2) + + s.append(gf(1)) # This adds the monomial + assert (len(s) - 1) == k * (p - 1) + remainder = _eval_monic_poly_specific(s, precomputed_ks, precomputed_pow2s, gf, p // 2) + + final_product = left.mul(right, flatten=False) + return ( + final_product.add(remainder, flatten=False).arithmetize("best-effort").to_arithmetic() + ) # TODO: Strategy + + +def _precompute_ks(x: ArithmeticNode, k: int) -> List[ArithmeticNode]: + # TODO: We can use an addition sequence for this to reduce the multiplicative cost + ks = [x] + for _ in range(math.ceil(math.log2(k))): + last = ks[-1] + new_ks = [] + for pre in ks: + new_ks.append(Multiplication(pre, last, pre._gf)) + ks.extend(new_ks) + + return ks[:k] + + +def _compute_extended_monomial( + x: ArithmeticNode, + precomputed_powers: Dict[int, ArithmeticNode], + target: int, + gf: Type[FieldArray], + squaring_cost: float, +) -> ArithmeticNode: + if target == 0: + return Constant(gf(1)) + + # TODO: Use squaring_cost + p = gf.characteristic + precomputed_values = tuple( + ( + exp % (p - 1), + power_node.multiplicative_depth() - x.multiplicative_depth(), + ) + for exp, power_node in precomputed_powers.items() + ) + # TODO: This is copied from Power, but in the future we can probably remove this if we have augmented circuits + addition_chain = add_chain_guaranteed(target, modulus=p - 1, squaring_cost=squaring_cost, precomputed_values=precomputed_values) + + nodes = [x] + nodes.extend(power_node for _, power_node in precomputed_powers.items()) + + for i, j in addition_chain: + nodes.append(Multiplication(nodes[i], nodes[j], gf)) + + return nodes[-1] + + +def _eval_poly( + x: ArithmeticNode, + coefficients: List[FieldArray], + k: int, + gf: Type[FieldArray], 
+ squaring_cost: float, +) -> Tuple[ArithmeticNode, Dict[int, ArithmeticNode]]: + # Paterson & Stockmeyer's algorithm + degree = len(coefficients) - 1 + precomputed_ks = _precompute_ks(x, k) + precomputed_powers = { + i % (gf.characteristic - 1): node for i, node in zip(range(1, k + 1), precomputed_ks) + } + + # Find the largest p such that k(2^p-1) >= degree + p = 0 + while True: + p += 1 + if (2**p - 1) * k >= degree: + break + + new_degree = (2**p - 1) * k + precomputed_pow2s = [precomputed_ks[-1]] + for j in range(p - 1): # TODO: Check if p - 1 is enough + precomputed_pow2s.append( + Multiplication(precomputed_pow2s[-1], precomputed_pow2s[-1], precomputed_pow2s[-1]._gf) + ) + precomputed_powers[(k * (2 ** (j + 1))) % (gf.characteristic - 1)] = precomputed_pow2s[-1] + + # Pad to the next degree k * (2^p - 1) monic polynomial + new_coefficients = [gf(0) for _ in range(new_degree + 1)] + for j, c in enumerate(coefficients): + new_coefficients[j] = c.copy() + + extended = new_coefficients[-1] == 0 + factor = gf(1) + if int(new_coefficients[-1]) > 1: + # The polynomial is not monic + inverse = coefficients[-1] ** -1 + new_coefficients = [inverse * c for c in coefficients] + factor = coefficients[-1] + + new_coefficients[-1] = gf(1) + + monomial_index = new_degree % (gf.characteristic - 1) + if monomial_index == 0: + monomial_index = gf.characteristic - 1 + if extended and monomial_index <= degree: + # In some cases we can eliminate the added monomial by changing the coefficients + new_coefficients[monomial_index] -= gf(1) + extended = False + + evaluation = _eval_monic_poly_specific( + new_coefficients, precomputed_ks, precomputed_pow2s, gf, 2**p // 2 + ) + + if extended: + monomial = _compute_extended_monomial( + x, precomputed_powers, new_degree % (gf.characteristic - 1), gf, squaring_cost + ) + precomputed_powers[new_degree % (gf.characteristic - 1)] = monomial + evaluation = ( + Subtraction(evaluation, monomial, gf).arithmetize("best-effort").to_arithmetic() + ) 
def _eval_poly_alternative(
    x: ArithmeticNode, coefficients: List[FieldArray], k: int, gf: Type[FieldArray]
) -> Tuple[Node, Dict[int, ArithmeticNode]]:
    """Evaluate the polynomial in chunks, using the precomputed powers x, ..., x^k.

    The coefficients are processed from the highest order downwards. Each chunk is
    evaluated with the precomputed powers and the running result is multiplied by
    the power of x that matches the size of the next chunk.

    Returns:
    -------
    The evaluation node and a dictionary of the precomputed powers, keyed by exponent modulo `gf.characteristic - 1`.

    """
    # Baby-step giant-step algorithm
    assert len(coefficients) > 0

    # Trim trailing zero coefficients but always keep the constant term. Without the
    # lower bound an all-zero polynomial walks past index 0 and ends in an IndexError.
    highest = len(coefficients) - 1
    while highest > 0 and coefficients[highest] == 0:
        highest -= 1
    coefficients = [coefficients[j].copy() for j in range(highest + 1)]  # Copies and trims the coefficients

    # Precompute x, x^2, ..., x^k
    precomputed_ks = _precompute_ks(x, k)
    precomputed_powers = {
        i % (gf.characteristic - 1): node for i, node in zip(range(1, k + 1), precomputed_ks)
    }

    # Process the first chunk
    chunk = coefficients[-(k + 1) :]
    aggregator = _eval_poly_using_precomputed_ks(chunk, precomputed_ks, gf)
    coefficients = coefficients[: -(k + 1)]

    # Go through the coefficients, chunk by chunk
    while len(coefficients) >= k:
        chunk = coefficients[-k:]
        aggregator = aggregator * precomputed_ks[-1] + _eval_poly_using_precomputed_ks(
            chunk, precomputed_ks, gf
        )
        coefficients = coefficients[:-k]

    # If there is a small chunk remaining
    if len(coefficients) > 0:
        aggregator = aggregator * precomputed_ks[
            len(coefficients) - 1
        ] + _eval_poly_using_precomputed_ks(coefficients, precomputed_ks, gf)

    return aggregator, precomputed_powers
* p) + + subdegree = p * len(precomputed_ks) + r = coefficients[:subdegree] + q = coefficients[subdegree:] + + r_eval = _eval_poly_divide_conquer_specific(r, precomputed_ks, precomputed_pow2s, gf, p // 2) + q_eval = _eval_poly_divide_conquer_specific(q, precomputed_ks, precomputed_pow2s, gf, p // 2) + + final_product = q_eval.mul(precomputed_pow2s[int(math.log2(p))], flatten=False) + return ( + final_product.add(r_eval, flatten=False).arithmetize("best-effort").to_arithmetic() + ) # TODO: Strategy + + +def _eval_poly_divide_conquer( + x: ArithmeticNode, + coefficients: List[FieldArray], + k: int, + gf: Type[FieldArray], + _squaring_cost: float, +) -> Tuple[ArithmeticNode, Dict[int, ArithmeticNode]]: + # Divide-and-conquer algorithm + # TODO: Reduce code duplication with poly_eval + degree = len(coefficients) - 1 + precomputed_ks = _precompute_ks(x, k) + precomputed_powers = { + i % (gf.characteristic - 1): node for i, node in zip(range(1, k + 1), precomputed_ks) + } + + # Find the smallest p >= 1 such that k * 2^p >= degree + p = 0 + while True: + p += 1 + if 2**p * k >= degree: + break + + precomputed_pow2s = [precomputed_ks[-1]] + for j in range(p - 1): # TODO: Check if p - 1 is enough + precomputed_pow2s.append( + Multiplication(precomputed_pow2s[-1], precomputed_pow2s[-1], precomputed_pow2s[-1]._gf) + ) + precomputed_powers[(k * (2 ** (j + 1))) % (gf.characteristic - 1)] = precomputed_pow2s[-1] + + evaluation = _eval_poly_divide_conquer_specific( + coefficients, precomputed_ks, precomputed_pow2s, gf, 2 ** (p - 1) + ) + + return evaluation, precomputed_powers + + +def _eval_coefficients(x: FieldArray, coefficients: List[FieldArray]) -> FieldArray: + x_pow = x.copy() + result = coefficients[0].copy() + + for coeff in coefficients[1:]: + result += x_pow * coeff + x_pow *= x + + return result + + +def test_ps_method(): # noqa: D103 + gf = GF(31) + coefficients = [gf(i) for i in range(31)] + + x = Input("x", gf) + + for k in range(1, len(coefficients)): + ( + 
arithmetization, + _, + ) = _eval_poly(x, coefficients, k, gf, squaring_cost=1.0) + arithmetization.clear_cache(set()) + + for xx in range(31): + assert arithmetization.evaluate({"x": gf(xx)}) == _eval_coefficients(gf(xx), coefficients) + arithmetization.clear_cache(set()) + + assert all(coefficients[i] == i for i in range(31)) + + +def test_divide_conquer_method(): # noqa: D103 + gf = GF(31) + coefficients = [gf(i) for i in range(31)] + + x = Input("x", gf) + + for k in range(1, len(coefficients)): + ( + arithmetization, + _, + ) = _eval_poly_divide_conquer(x, coefficients, k, gf, _squaring_cost=1.0) + arithmetization.clear_cache(set()) + + for xx in range(31): + assert arithmetization.evaluate({"x": gf(xx)}) == _eval_coefficients(gf(xx), coefficients) + arithmetization.clear_cache(set()) + + assert all(coefficients[i] == i for i in range(31)) + + +def test_babystep_giantstep_method(): # noqa: D103 + gf = GF(31) + coefficients = [gf(i) for i in range(31)] + + x = Input("x", gf) + + for k in range(1, len(coefficients)): + ( + arithmetization, + _, + ) = _eval_poly_alternative(x, coefficients, k, gf) + arithmetization.clear_cache(set()) + + for xx in range(31): + assert arithmetization.evaluate({"x": gf(xx)}) == _eval_coefficients(gf(xx), coefficients) + arithmetization.clear_cache(set()) + + assert all(coefficients[i] == i for i in range(31)) diff --git a/oraqle/config.py b/oraqle/config.py new file mode 100644 index 0000000..e47979c --- /dev/null +++ b/oraqle/config.py @@ -0,0 +1,27 @@ +"""This module contains global configuration options. + +!!! warning + This is almost certainly going to be removed in the future. +We do not want oraqle to have a global configuration, but this is currently an intentional evil to prevent large refactors in the initial versions. +""" +from typing import Annotated, Optional + + +Seconds = Annotated[float, "seconds"] +MAXSAT_TIMEOUT: Optional[Seconds] = None +"""Time-out for individual calls to the MaxSAT solver. + +!!! 
danger + This causes non-deterministic behavior! + +!!! bug + There is currently a chance to get `AttributeError`s, which is a bug caused by pysat trying to delete an oracle that does not exist. + There is no current workaround for this.""" + + +PS_METHOD_FACTOR_K: float = 2.0 +"""Approximation factor for the PS-method, higher is better. + +The Paterson-Stockmeyer method takes a value k, that is theoretically optimal when k = sqrt(2 * degree). +However, sometimes it is better to try other values of k (e.g. due to rounding and to trade off depth and cost). +This factor, let's call it f, is used to limit the candidate values of k that we try: [1, f * sqrt(2 * degree)).""" diff --git a/oraqle/demo/depth_aware_equality.ipynb b/oraqle/demo/depth_aware_equality.ipynb new file mode 100644 index 0000000..9a7a6f6 --- /dev/null +++ b/oraqle/demo/depth_aware_equality.ipynb @@ -0,0 +1,151 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3580d661-a471-4131-a3e3-a88161d38209", + "metadata": {}, + "source": [ + "# A new paradigm in arithmetization: Depth-aware arithmetization" + ] + }, + { + "cell_type": "markdown", + "id": "d97c8705-c1d1-4b3b-848e-0a68dc7a703b", + "metadata": {}, + "source": [ + "#### An equality circuit is easy to define but somewhat hard to optimize!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b615a377-1777-4aec-acb6-79f202dec6ac", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from galois import GF\n", + "\n", + "from circuit_compiler.compiler.nodes.leafs import Input\n", + "from circuit_compiler.compiler.comparison.equality import Equals\n", + "from circuit_compiler.compiler.circuit import Circuit\n", + "\n", + "gf = GF(467)\n", + "\n", + "a = Input(\"a\", gf)\n", + "b = Input(\"b\", gf)\n", + "\n", + "output = Equals(a, b, gf)\n", + "\n", + "circuit = Circuit(outputs=[output], gf=gf)\n", + "circuit.display_graph()" + ] + }, + { + "cell_type": "markdown", + "id": "e466a3dd-b3ef-4449-979b-c443bf860f79", + "metadata": {}, + "source": [ + "#### Naive methods find only one arithmetization for this high-level circuit." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3828eb5d-6da2-421f-b5e0-9cc0e399d434", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "naive_arithmetic_circuit = circuit.arithmetize(\"naive\")\n", + "naive_arithmetic_circuit.display_graph()" + ] + }, + { + "cell_type": "markdown", + "id": "7a721fc6-ce6c-441e-8ce2-67fee11af6cc", + "metadata": {}, + "source": [ + "#### Our depth-aware arithmetization method finds a second circuit that is potentially faster!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "18d08a7b-a97d-4a74-856a-79dcf7197f07", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Depth: 9 , Cost: 8.0\n" + ] + }, + { + "data": { + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Depth: 11 , Cost: 7.5\n" + ] + }, + { + "data": { + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "depth_aware_circuits = circuit.arithmetize_depth_aware(0.5)\n", + "\n", + "for depth, cost, depth_aware_circuit in depth_aware_circuits:\n", + " print(\"Depth: \", depth, \", \", \"Cost: \", cost)\n", + " depth_aware_circuit.display_graph()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/oraqle/demo/playground.ipynb b/oraqle/demo/playground.ipynb new file mode 100644 index 0000000..73fed45 --- /dev/null +++ b/oraqle/demo/playground.ipynb @@ -0,0 +1,79 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b5e62be9-cada-42d2-a2bb-7f3ee38aec51", + "metadata": {}, + "source": [ + "# Playground" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df1eaaad-6ad2-4601-8d12-3b02d9254bfa", + "metadata": {}, + "outputs": [], + "source": [ + "from galois import GF\n", + "\n", + "from circuit_compiler.compiler.boolean.bool_and import And\n", + "from circuit_compiler.compiler.circuit import Circuit\n", + "from circuit_compiler.compiler.nodes.leafs import Input\n", + "\n", + "gf = GF(5)\n", + "\n", + "xs = 
[Input(f\"x{i}\", gf) for i in range(11)]\n", + "\n", + "output = And(set(xs), gf)\n", + "\n", + "circuit = Circuit(outputs=[output], gf=gf)\n", + "circuit.display_graph()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce27d985-b20c-4f7e-a303-929598c61c17", + "metadata": {}, + "outputs": [], + "source": [ + "naive_arithmetic_circuit = circuit.arithmetize(\"naive\")\n", + "naive_arithmetic_circuit.display_graph()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5c0531f-f50b-4852-b81c-833a813eb235", + "metadata": {}, + "outputs": [], + "source": [ + "circuit._clear_cache()\n", + "better_arithmetic_circuit = circuit.arithmetize(\"best-effort\")\n", + "better_arithmetic_circuit.display_graph()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/oraqle/demo/small_comparison_bgv.ipynb b/oraqle/demo/small_comparison_bgv.ipynb new file mode 100644 index 0000000..516a58e --- /dev/null +++ b/oraqle/demo/small_comparison_bgv.ipynb @@ -0,0 +1,162 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0f2abd68-5065-49c2-aefa-65ca3c8be8f8", + "metadata": {}, + "source": [ + "# Compiling homomorphic encryption circuits made easy" + ] + }, + { + "cell_type": "markdown", + "id": "1f425b04-35ab-4a1c-8ed4-20ecdc7d2901", + "metadata": {}, + "source": [ + "#### The only boilerplate consists of defining the plaintext space and the inputs of the program." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18d03a72-d22a-4f54-9a68-ab31507d1e34", + "metadata": {}, + "outputs": [], + "source": [ + "from galois import GF\n", + "\n", + "from circuit_compiler.compiler.nodes.leafs import Input\n", + "\n", + "gf = GF(11)\n", + "\n", + "a = Input(\"a\", gf)\n", + "b = Input(\"b\", gf)" + ] + }, + { + "cell_type": "markdown", + "id": "7a7890f4-c770-4699-acba-ec2e6796a5bb", + "metadata": {}, + "source": [ + "#### Programmers can use the primitives that they are used to." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dd02769-50cc-4eb1-a9e0-896b944d9b28", + "metadata": {}, + "outputs": [], + "source": [ + "output = a < b" + ] + }, + { + "cell_type": "markdown", + "id": "8a26b9ca-2441-48e1-8aad-4b626755485e", + "metadata": {}, + "source": [ + "#### A circuit can have an arbitrary number of outputs; here we only have one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d00fa605-4510-4393-bdb0-4dd54a21f5f8", + "metadata": {}, + "outputs": [], + "source": [ + "from circuit_compiler.compiler.circuit import Circuit\n", + "\n", + "circuit = Circuit(outputs=[output], gf=gf)\n", + "circuit.display_graph()" + ] + }, + { + "cell_type": "markdown", + "id": "fc7c6e33-a7ad-4e2f-a742-40653160a0ca", + "metadata": {}, + "source": [ + "#### Turning high-level circuits into arithmetic circuits is a fully automatic process that improves on the state of the art in multiple ways." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a441c9f5-de63-4253-bbb6-4b63511acc67", + "metadata": {}, + "outputs": [], + "source": [ + "arithmetic_circuit = circuit.arithmetize()\n", + "arithmetic_circuit.display_graph()" + ] + }, + { + "cell_type": "markdown", + "id": "33a64549-4081-4fb8-9631-1f007b368dfa", + "metadata": {}, + "source": [ + "#### The compiler implements a form of semantic subexpression elimination that significantly optimizes large circuits." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cd8bfaf-8113-444a-812c-b2a4fe124cec", + "metadata": {}, + "outputs": [], + "source": [ + "arithmetic_circuit.eliminate_subexpressions()\n", + "arithmetic_circuit.display_graph()" + ] + }, + { + "cell_type": "markdown", + "id": "a89d7c56-ef33-4ac6-b06a-0f88d45aff91", + "metadata": {}, + "source": [ + "#### This much smaller circuit is still correct!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d50141a-b84a-4ac4-93e7-a7a1cf484688", + "metadata": {}, + "outputs": [], + "source": [ + "import tabulate\n", + "\n", + "for val_a in range(11):\n", + " for val_b in range(11):\n", + " assert arithmetic_circuit.evaluate({\"a\": gf(val_a), \"b\": gf(val_b)}) == gf(val_a < val_b)\n", + "\n", + "data = [[arithmetic_circuit.evaluate({\"a\": gf(val_a), \"b\": gf(val_b)})[0] for val_a in range(11)] for val_b in range(11)]\n", + "\n", + "table = tabulate.tabulate(data, tablefmt='html')\n", + "table" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/oraqle/examples/depth_aware_comparison.py b/oraqle/examples/depth_aware_comparison.py new file mode 100644 index 0000000..7b7ee72 --- /dev/null +++ b/oraqle/examples/depth_aware_comparison.py @@ -0,0 +1,33 @@ +"""Depth-aware arithmetization of a comparison modulo 101.""" + +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.leafs import Input + +gf = GF(101) +cost_of_squaring = 1.0 + +a = Input("a", gf) +b = Input("b", gf) + +output = a < b + +circuit = Circuit(outputs=[output]) 
+circuit.to_graph("high_level_circuit.dot") + +arithmetic_circuits = circuit.arithmetize_depth_aware(cost_of_squaring) + +for depth, cost, arithmetic_circuit in arithmetic_circuits: + assert arithmetic_circuit.multiplicative_depth() == depth + assert arithmetic_circuit.multiplicative_cost(cost_of_squaring) == cost + + print("pre CSE", depth, cost) + + arithmetic_circuit.eliminate_subexpressions() + + print( + "post CSE", + arithmetic_circuit.multiplicative_depth(), + arithmetic_circuit.multiplicative_cost(cost_of_squaring), + ) diff --git a/oraqle/examples/depth_aware_equality.py b/oraqle/examples/depth_aware_equality.py new file mode 100644 index 0000000..8d42b23 --- /dev/null +++ b/oraqle/examples/depth_aware_equality.py @@ -0,0 +1,23 @@ +"""Depth-aware arithmetization for an equality operation modulo 31.""" + +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.comparison.equality import Equals +from oraqle.compiler.nodes.leafs import Input + +gf = GF(31) + +a = Input("a", gf) +b = Input("b", gf) + +output = Equals(a, b, gf) + +circuit = Circuit(outputs=[output]) + +arithmetic_circuits = circuit.arithmetize_depth_aware(cost_of_squaring=1.0) + +if __name__ == "__main__": + circuit.to_pdf("high_level_circuit.pdf") + for depth, size, arithmetic_circuit in arithmetic_circuits: + arithmetic_circuit.to_pdf(f"arithmetic_circuit_d{depth}_s{size}.pdf") diff --git a/oraqle/examples/long_and.py b/oraqle/examples/long_and.py new file mode 100644 index 0000000..0123972 --- /dev/null +++ b/oraqle/examples/long_and.py @@ -0,0 +1,20 @@ +"""Arithmetization of an AND operation between 15 inputs.""" + +from galois import GF + +from oraqle.compiler.boolean.bool_and import And +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.abstract import UnoverloadedWrapper +from oraqle.compiler.nodes.leafs import Input + +gf = GF(5) + +xs = [Input(f"x{i}", gf) for i in range(15)] + +output = And(set(UnoverloadedWrapper(x) for x in 
xs), gf) + +circuit = Circuit(outputs=[output]) +circuit.to_graph("high_level_circuit.dot") + +arithmetic_circuit = circuit.arithmetize() +arithmetic_circuit.to_graph("arithmetic_circuit.dot") diff --git a/oraqle/examples/small_comparison.py b/oraqle/examples/small_comparison.py new file mode 100644 index 0000000..f28c9d8 --- /dev/null +++ b/oraqle/examples/small_comparison.py @@ -0,0 +1,19 @@ +"""Arithmetizes a comparison modulo 11 with a constant.""" + +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.leafs import Constant, Input + +gf = GF(11) + +a = Input("a", gf) +b = Constant(gf(3)) # Input("b") + +output = a < b + +circuit = Circuit(outputs=[output]) +circuit.to_graph("high_level_circuit.dot") + +arithmetic_circuit = circuit.arithmetize() +arithmetic_circuit.to_graph("arithmetic_circuit.dot") diff --git a/oraqle/examples/small_polynomial.py b/oraqle/examples/small_polynomial.py new file mode 100644 index 0000000..9a5cc41 --- /dev/null +++ b/oraqle/examples/small_polynomial.py @@ -0,0 +1,19 @@ +"""Creates graphs for the arithmetization of a small polynomial evaluation.""" + +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.leafs import Input +from oraqle.compiler.polynomials.univariate import UnivariatePoly + +gf = GF(11) + +x = Input("x", gf) + +output = UnivariatePoly(x, [gf(1), gf(2), gf(3), gf(4), gf(5), gf(6), gf(1)], gf) + +circuit = Circuit(outputs=[output]) +circuit.to_graph("high_level_circuit.dot") + +arithmetic_circuit = circuit.arithmetize() +arithmetic_circuit.to_graph("arithmetic_circuit.dot") diff --git a/oraqle/examples/visualize_circuits.py b/oraqle/examples/visualize_circuits.py new file mode 100644 index 0000000..c96e9d5 --- /dev/null +++ b/oraqle/examples/visualize_circuits.py @@ -0,0 +1,80 @@ +"""Visualization of three circuits computing an OR operation on 7 inputs.""" + +from galois import GF + +from oraqle.compiler.arithmetic.exponentiation 
import Power +from oraqle.compiler.boolean.bool_neg import Neg +from oraqle.compiler.circuit import ArithmeticCircuit, Circuit +from oraqle.compiler.nodes.binary_arithmetic import Multiplication +from oraqle.compiler.nodes.leafs import Input + +gf = GF(5) + +x1 = Input("x1", gf) +x2 = Input("x2", gf) +x3 = Input("x3", gf) +x4 = Input("x4", gf) +x5 = Input("x5", gf) +x6 = Input("x6", gf) +x7 = Input("x7", gf) + +sum1 = x1 + x2 + x3 + x4 +exp1 = Power(sum1, 4, gf) + +sum2 = x5 + x6 + x7 + exp1 +exp2 = Power(sum2, 4, gf) + +circuit = Circuit([exp2]) +arithmetic_circuit = circuit.arithmetize() +arithmetic_circuit.to_graph("arithmetic_circuit1.dot") + + +inv1 = Neg(x1, gf) +inv2 = Neg(x2, gf) +inv3 = Neg(x3, gf) +inv4 = Neg(x4, gf) +inv5 = Neg(x5, gf) +inv6 = Neg(x6, gf) + +mul1 = inv1 * inv2 +invmul1 = Neg(mul1, gf) + +mul2 = inv3 * inv4 +invmul2 = Neg(mul2, gf) + +mul3 = inv5 * inv6 +invmul3 = Neg(mul3, gf) + +add1 = mul1 + mul2 +add2 = mul3 + add1 + +add3 = add2 + x7 + +exp = Power(add3, 4, gf) + +circuit = Circuit([exp]) +arithmetic_circuit = circuit.arithmetize() +arithmetic_circuit.to_graph("arithmetic_circuit2.dot") + + +inv1 = Neg(x1, gf).arithmetize("best-effort").to_arithmetic() +inv2 = Neg(x2, gf).arithmetize("best-effort").to_arithmetic() +inv3 = Neg(x3, gf).arithmetize("best-effort").to_arithmetic() +inv4 = Neg(x4, gf).arithmetize("best-effort").to_arithmetic() +inv5 = Neg(x5, gf).arithmetize("best-effort").to_arithmetic() +inv6 = Neg(x6, gf).arithmetize("best-effort").to_arithmetic() +inv7 = Neg(x7, gf).arithmetize("best-effort").to_arithmetic() + +mul1 = Multiplication(inv1, inv2, gf) +mul2 = Multiplication(inv3, inv4, gf) +mul3 = Multiplication(inv5, inv6, gf) + +mul4 = Multiplication(mul1, mul2, gf) +mul5 = Multiplication(mul3, inv7, gf) + +mul6 = Multiplication(mul4, mul5, gf) + +inv = Neg(mul6, gf).arithmetize("best-effort").to_arithmetic() + +arithmetic_circuit = ArithmeticCircuit([inv]) +arithmetic_circuit.to_graph("arithmetic_circuit3.dot") diff 
--git a/oraqle/experiments/depth_aware_arithmetization/execution/cardio_circuits.py b/oraqle/experiments/depth_aware_arithmetization/execution/cardio_circuits.py new file mode 100644 index 0000000..d53bc56 --- /dev/null +++ b/oraqle/experiments/depth_aware_arithmetization/execution/cardio_circuits.py @@ -0,0 +1,35 @@ +import time + +from galois import GF + +from oraqle.circuits.cardio import ( + construct_cardio_elevated_risk_circuit, + construct_cardio_risk_circuit, +) +from oraqle.compiler.circuit import Circuit + +if __name__ == "__main__": + gf = GF(257) + + for cost_of_squaring in [0.5, 0.75, 1.0]: + print(f"--- Cardio risk assessment ({cost_of_squaring}) ---") + circuit = Circuit([construct_cardio_risk_circuit(gf)]) + + start = time.monotonic() + front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring) + print("Run time:", time.monotonic() - start, "s") + + for depth, cost, arithmetic_circuit in front: + print(depth, cost) + arithmetic_circuit.to_graph(f"cardio_arith_d{depth}_c{cost}.dot") + + print(f"--- Cardio elevated risk assessment ({cost_of_squaring}) ---") + circuit = Circuit([construct_cardio_elevated_risk_circuit(gf)]) + + start = time.monotonic() + front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring) + print("Run time:", time.monotonic() - start, "s") + + for depth, cost, arithmetic_circuit in front: + print(depth, cost) + arithmetic_circuit.to_graph(f"cardio_elevated_arith_d{depth}_c{cost}.dot") diff --git a/oraqle/experiments/depth_aware_arithmetization/execution/comparisons.py b/oraqle/experiments/depth_aware_arithmetization/execution/comparisons.py new file mode 100644 index 0000000..295ec2e --- /dev/null +++ b/oraqle/experiments/depth_aware_arithmetization/execution/comparisons.py @@ -0,0 +1,53 @@ +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.comparison.comparison import ( + IliashenkoZuccaSemiLessThan, + SemiStrictComparison, + T2SemiLessThan, +) +from 
oraqle.compiler.nodes.leafs import Input + +if __name__ == "__main__": + for p in [29, 43, 61, 101, 131]: + gf = GF(p) + + x = Input("x", gf) + y = Input("y", gf) + + print(f"-------- p = {p}: ---------") + our_circuit = Circuit([SemiStrictComparison(x, y, less_than=True, gf=gf)]) + our_front = our_circuit.arithmetize_depth_aware() + print("Our circuits:", our_front) + + our_front[0][2].to_graph(f"comp_{p}_ours.dot") + + t2_circuit = Circuit([T2SemiLessThan(x, y, gf)]) + t2_arithmetization = t2_circuit.arithmetize() + print( + "T2 circuit:", + t2_arithmetization.multiplicative_depth(), + t2_arithmetization.multiplicative_size(), + ) + t2_arithmetization.eliminate_subexpressions() + print( + "T2 circuit CSE:", + t2_arithmetization.multiplicative_depth(), + t2_arithmetization.multiplicative_size(), + ) + + iz21_circuit = Circuit([IliashenkoZuccaSemiLessThan(x, y, gf)]) + iz21_arithmetization = iz21_circuit.arithmetize() + iz21_arithmetization.to_graph(f"comp_{p}_iz21.dot") + print( + "IZ21 circuits:", + iz21_arithmetization.multiplicative_depth(), + iz21_arithmetization.multiplicative_size(), + ) + iz21_arithmetization.eliminate_subexpressions() + iz21_arithmetization.to_graph(f"comp_{p}_iz21_cse.dot") + print( + "IZ21 circuit CSE:", + iz21_arithmetization.multiplicative_depth(), + iz21_arithmetization.multiplicative_size(), + ) diff --git a/oraqle/experiments/depth_aware_arithmetization/execution/equality_first_prime_mods_exec.py b/oraqle/experiments/depth_aware_arithmetization/execution/equality_first_prime_mods_exec.py new file mode 100644 index 0000000..2366ac9 --- /dev/null +++ b/oraqle/experiments/depth_aware_arithmetization/execution/equality_first_prime_mods_exec.py @@ -0,0 +1,191 @@ +import math +import multiprocessing +import pickle +import time +from functools import partial +from typing import List, Tuple + +from matplotlib import pyplot as plt +from sympy import sieve + +from oraqle.add_chains.addition_chains_front import chain_depth, gen_pareto_front 
+from oraqle.add_chains.addition_chains_mod import chain_cost, hw + + +def experiment( + t: int, squaring_cost: float +) -> Tuple[List[Tuple[int, float, List[Tuple[int, int]]]], float]: + start = time.monotonic() + chains = gen_pareto_front( + t - 1, + modulus=t - 1, + squaring_cost=squaring_cost, + solver="glucose42", + encoding=1, + thurber=True, + ) + duration = time.monotonic() - start + + return [ + (chain_depth(chain, modulus=t - 1), chain_cost(chain, squaring_cost), chain) + for _, chain in chains + ], duration + + +def experiment2( + t: int, squaring_cost: float +) -> Tuple[List[Tuple[int, float, List[Tuple[int, int]]]], float]: + start = time.monotonic() + chains = gen_pareto_front( + t - 1, + modulus=None, + squaring_cost=squaring_cost, + solver="glucose42", + encoding=1, + thurber=True, + ) + duration = time.monotonic() - start + + return [ + (chain_depth(chain), chain_cost(chain, squaring_cost), chain) for _, chain in chains + ], duration + + +def plot_specific_outputs(specific_outputs, specific_outputs_nomod, primes, squaring_cost: float): + plt.figure(figsize=(9, 2.8)) + plt.grid(axis="y", zorder=-1000, alpha=0.5) + + for x, p in enumerate(primes): + label = "Square & multiply" if p == 2 else None + t = p - 1 + plt.scatter( + x, + math.ceil(math.log2(t)) * squaring_cost + hw(t) - 1, + color="black", + label=label, + zorder=100, + marker="_", + ) + + for x, outputs in enumerate(specific_outputs): + chains, _ = outputs + for depth, cost, _ in chains: + plt.scatter( + x, + cost, + color="black", + zorder=100, + s=50, + label="Optimal circuit" if x == 0 else None, + ) + if len(chains) > 1: + plt.text( + x, + cost - 0.05, + str(depth), + fontsize=6, + ha="center", + va="center", + color="white", + zorder=200, + fontweight="bold", + ) + + plt.xticks(range(len(primes)), primes, rotation=50) + plt.yticks(range(2 * math.ceil(math.log2(primes[-1])))) + + plt.xlabel("Modulus") + plt.ylabel("Multiplicative cost") + + ax1 = plt.gca() + ax2 = ax1.twinx() + for x, 
outputs in enumerate(specific_outputs): + _, duration = outputs + ax2.bar(x, duration, color="tab:cyan", zorder=0, alpha=0.3, label="Considering modulus" if x == 0 else None) # type: ignore + for x, outputs in enumerate(specific_outputs_nomod): + _, duration = outputs + ax2.bar(x, duration, color="tab:cyan", zorder=0, alpha=1.0, label="Ignoring modulus" if x == 0 else None) # type: ignore + ax2.set_ylabel("Generation time [s]", color="tab:cyan", alpha=1.0) + + ax1.step( + range(len(primes)), + [squaring_cost * math.ceil(math.log2(p - 1)) for p in primes], + zorder=10, + color="black", + where="mid", + label="Lower bound", + linestyle=":", + ) + + # Combine legends from both axes + lines, labels = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() # type: ignore + ax1.legend(lines + lines2, labels + labels2, loc="upper left", fontsize="small") + + plt.savefig(f"equality_first_prime_mods_{squaring_cost}.pdf", bbox_inches="tight") + plt.show() + + +if __name__ == "__main__": + run_experiments = False + + if run_experiments: + multiprocessing.set_start_method("fork") + threads = 4 + pool = multiprocessing.Pool(threads) + + primes = list(sieve.primerange(300))[:30] # [:50] + + for sqr_cost in [0.5, 0.75, 1.0]: + print(f"Computing for {sqr_cost}") + experiment_sqr_cost = partial(experiment, squaring_cost=sqr_cost) + outs = list(pool.map(experiment_sqr_cost, primes)) + + with open(f"equality_experiment_{sqr_cost}_mod.pkl", mode="wb") as file: + pickle.dump((primes, outs), file) + + for sqr_cost in [0.5, 0.75, 1.0]: + print(f"Computing for {sqr_cost}") + experiment_sqr_cost = partial(experiment2, squaring_cost=sqr_cost) + outs = list(pool.map(experiment_sqr_cost, primes)) + + with open(f"equality_experiment_{sqr_cost}_nomod.pkl", mode="wb") as file: + pickle.dump((primes, outs), file) + + # Visualize + with open("equality_experiment_0.5_mod.pkl", "rb") as file: + primes_05_mod, outputs_05_mod = pickle.load(file) + with 
open("equality_experiment_0.75_mod.pkl", "rb") as file: + primes_075_mod, outputs_075_mod = pickle.load(file) + with open("equality_experiment_1.0_mod.pkl", "rb") as file: + primes_10_mod, outputs_10_mod = pickle.load(file) + + with open("equality_experiment_0.5_nomod.pkl", "rb") as file: + primes_05_nomod, outputs_05_nomod = pickle.load(file) + with open("equality_experiment_0.75_nomod.pkl", "rb") as file: + primes_075_nomod, outputs_075_nomod = pickle.load(file) + with open("equality_experiment_1.0_nomod.pkl", "rb") as file: + primes_10_nomod, outputs_10_nomod = pickle.load(file) + + # All the primes should match + primes = primes_10_mod + assert primes == primes_05_mod + assert primes == primes_075_mod + assert primes == primes_05_nomod + assert primes == primes_075_nomod + assert primes == primes_10_nomod + + # All the chains should match (not in theory, but for this visualization they should) + assert all( + all(x == y for x, y in zip(a[0], b[0])) for a, b in zip(outputs_05_mod, outputs_05_nomod) + ) + assert all( + all(x == y for x, y in zip(a[0], b[0])) for a, b in zip(outputs_075_mod, outputs_075_nomod) + ) + assert all( + all(x == y for x, y in zip(a[0], b[0])) for a, b in zip(outputs_10_mod, outputs_10_nomod) + ) + + plot_specific_outputs(outputs_05_mod, outputs_05_nomod, primes, squaring_cost=0.5) + plot_specific_outputs(outputs_075_mod, outputs_075_nomod, primes, squaring_cost=0.75) + plot_specific_outputs(outputs_10_mod, outputs_10_nomod, primes, squaring_cost=1.0) diff --git a/oraqle/experiments/depth_aware_arithmetization/execution/poly_evaluation_pareto_front.py b/oraqle/experiments/depth_aware_arithmetization/execution/poly_evaluation_pareto_front.py new file mode 100644 index 0000000..8e43792 --- /dev/null +++ b/oraqle/experiments/depth_aware_arithmetization/execution/poly_evaluation_pareto_front.py @@ -0,0 +1,180 @@ +import math +import sys + +from galois import GF +from matplotlib import pyplot as plt +from matplotlib.ticker import 
MultipleLocator + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.abstract import SizeParetoFront +from oraqle.compiler.nodes.leafs import Input +from oraqle.compiler.polynomials.univariate import ( + UnivariatePoly, + _eval_poly, + _eval_poly_alternative, + _eval_poly_divide_conquer, +) + +if __name__ == "__main__": + sys.setrecursionlimit(15000) + + shape_size = 150 + + plt.figure(figsize=(3.5, 4.4)) + + marker1 = (3, 2, 0) + marker2 = (3, 2, 40) + marker3 = (3, 2, 80) + o_marker = "o" + linewidth = 2.5 + + squaring_cost = 1.0 + + p = 127 # 31 + gf = GF(p) + for d in [p - 1]: + x = Input("x", gf) + + poly = UnivariatePoly.from_function(x, gf, lambda x: x % 7) + coefficients = poly._coefficients + + # Generate points + print("Paterson & Stockmeyer") + depths = [] + sizes = [] + + front = SizeParetoFront() + + for k in range(1, len(coefficients)): + res, pows = _eval_poly(x, coefficients, k, gf, squaring_cost) + circ = Circuit([res]).arithmetize() + depths.append(circ.multiplicative_depth()) + sizes.append(circ.multiplicative_size()) + front.add(res, circ.multiplicative_depth(), circ.multiplicative_size()) # type: ignore + print(k, circ.multiplicative_depth(), circ.multiplicative_size()) + + data = {(d, s) for d, s in zip(depths, sizes)} + plt.scatter( + [d for d, _ in data], + [s for _, s in data], + marker=marker2, # type: ignore + zorder=10, + alpha=0.4, + s=shape_size, + linewidth=linewidth, + ) + + print("Baby-step giant-step") + depths2 = [] + sizes2 = [] + for k in range(1, len(coefficients)): + res, pows = _eval_poly_alternative(x, coefficients, k, gf) + circ = Circuit([res]).arithmetize() + depths2.append(circ.multiplicative_depth()) + sizes2.append(circ.multiplicative_size()) + front.add(res, circ.multiplicative_depth(), circ.multiplicative_size()) # type: ignore + + data2 = {(d, s) for d, s in zip(depths2, sizes2)} + plt.scatter( + [d for d, _ in data2], + [s for _, s in data2], + marker=marker1, # type: ignore + zorder=11, + 
alpha=0.45, + s=shape_size, + linewidth=linewidth, + ) + + print("Divide and conquer") + depths3 = [] + sizes3 = [] + for k in range(1, len(coefficients)): + res, pows = _eval_poly_divide_conquer(x, coefficients, k, gf, squaring_cost) + circ = Circuit([res]).arithmetize() + depths3.append(circ.multiplicative_depth()) + sizes3.append(circ.multiplicative_size()) + front.add(res, circ.multiplicative_depth(), circ.multiplicative_size()) # type: ignore + + data3 = {(d, s) for d, s in zip(depths3, sizes3)} + plt.scatter( + [d for d, _ in data3], + [s for _, s in data3], + marker=marker3, # type: ignore + zorder=11, + alpha=0.45, + s=shape_size, + linewidth=linewidth, + ) + + # Plot the front + front_initial = [(d, s) for d, s in data2 if d in front._nodes_by_depth and front._nodes_by_depth[d][0] == s] # type: ignore + front_advanced = [(d, s) for d, s in data if d in front._nodes_by_depth and front._nodes_by_depth[d][0] == s] # type: ignore + front_divconq = [(d, s) for d, s in data3 if d in front._nodes_by_depth and front._nodes_by_depth[d][0] == s] # type: ignore + + plt.scatter( + [d for d, _ in front_initial], + [s for _, s in front_initial], + marker=marker1, # type: ignore + zorder=10, + color="tab:orange", + s=shape_size, + label="Baby-step giant-step", + linewidth=linewidth, + ) + plt.scatter( + [d for d, _ in front_advanced], + [s for _, s in front_advanced], + marker=marker2, # type: ignore + zorder=10, + color="tab:blue", + s=shape_size, + label="Paterson & Stockmeyer", + linewidth=linewidth, + ) + plt.scatter( + [d for d, _ in front_divconq], + [s for _, s in front_divconq], + marker=marker3, # type: ignore + zorder=10, + color="tab:green", + s=shape_size, + label="Divide & Conquer", + linewidth=linewidth, + ) + + k = round(math.sqrt(d / 2)) + res, pows = _eval_poly(x, coefficients, k, gf, squaring_cost) + circ = Circuit([res]).arithmetize() + plt.scatter( + circ.multiplicative_depth(), + circ.multiplicative_size(), + marker=o_marker, + s=shape_size + 50, + 
facecolors="none", + edgecolors="black", + ) + plt.text( + circ.multiplicative_depth(), + circ.multiplicative_size() + 0.4, + f"k = {k}", + ha="center", + fontsize=8, + ) + + plt.xlim((5, 15)) + plt.ylim((15, 30)) + + plt.gca().set_aspect("equal") + + plt.gca().xaxis.set_minor_locator(MultipleLocator(1)) + plt.gca().yaxis.set_minor_locator(MultipleLocator(1)) + + plt.grid(True, which="both", zorder=1, alpha=0.5) + + plt.xlabel("Multiplicative depth") + plt.ylabel("Multiplicative size") + + plt.legend(fontsize="small") + + plt.savefig("poly_eval_front_2.pdf", bbox_inches="tight") + plt.show() diff --git a/oraqle/experiments/depth_aware_arithmetization/execution/run_all.sh b/oraqle/experiments/depth_aware_arithmetization/execution/run_all.sh new file mode 100755 index 0000000..8230af4 --- /dev/null +++ b/oraqle/experiments/depth_aware_arithmetization/execution/run_all.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Get the directory where the script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Change to the script's directory +cd "$SCRIPT_DIR" + +# Loop through all Python files in the script's directory +for file in *.py +do + # Check if there are any Python files + if [ -e "$file" ]; then + echo "Running $file" + python3 "$file" + else + echo "No Python files found in the script's directory." 
+ break + fi +done diff --git a/oraqle/experiments/depth_aware_arithmetization/execution/veto_voting_per_mod.py b/oraqle/experiments/depth_aware_arithmetization/execution/veto_voting_per_mod.py new file mode 100644 index 0000000..238cc22 --- /dev/null +++ b/oraqle/experiments/depth_aware_arithmetization/execution/veto_voting_per_mod.py @@ -0,0 +1,99 @@ +from typing import List + +from galois import GF +from matplotlib import pyplot as plt +from sympy import sieve + +from oraqle.compiler.boolean.bool_and import _minimum_cost +from oraqle.compiler.boolean.bool_or import Or +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.abstract import CostParetoFront, UnoverloadedWrapper +from oraqle.compiler.nodes.leafs import Input +from oraqle.experiments.oraqle_spotlight.experiments.veto_voting_minimal_cost import ( + exponentiation_results, +) + + +def generate_all_fronts(): + results = {} + + for p in [7, 11, 13, 17]: + fronts = [] + + print(f"------ p = {p} ------") + for k in range(2, 51): + gf = GF(p) + xs = [Input(f"x{i}", gf) for i in range(k)] + + circuit = Circuit([Or(set(UnoverloadedWrapper(x) for x in xs), gf)]) + front = circuit.arithmetize_depth_aware(cost_of_squaring=1.0) + + print(f"{k}.", end=" ") + for f in front: + print(f[0], f[1], end=" ") + + print() + fronts.append(front) + + results[p] = fronts + + return results + + +def plot_fronts(fronts: List[CostParetoFront], color, label, **kwargs): + plt.scatter([], [], color=color, label=label, **kwargs) + for k, front in zip(range(2, 51), fronts): + for depth, cost, _ in front: + kwargs["marker"] = (depth, 2, 0) + kwargs["s"] = 16 + kwargs["linewidth"] = 0.5 + plt.scatter(k, cost, color=color, **kwargs) + + +if __name__ == "__main__": + fronts_by_p = generate_all_fronts() + max_k = 50 + + plt.figure(figsize=(4, 4)) + + plt.plot( + range(2, max_k + 1), + [k - 1 for k in range(2, max_k + 1)], + color="gray", + linestyle="solid", + label="Naive", + linewidth=0.7, + ) + + 
plot_fronts(fronts_by_p[7], "tab:purple", "Modulus p = 7", zorder=100) + plot_fronts(fronts_by_p[13], "tab:green", "Modulus p = 13", zorder=100) + + best_costs = [100000000.0] * (max_k + 1) + best_ps = [None] * (max_k + 1) + # This is for sqr = 0.75 mul + primes = list(sieve.primerange(300))[1:50] + for p in primes: + for k in range(2, max_k + 1): + cost = _minimum_cost(k, exponentiation_results[p][0][0][1], p) + if cost < best_costs[k - 2]: + best_costs[k - 2] = cost + best_ps[k - 2] = p + + plt.step( + range(2, max_k + 1), + best_costs[:-2], + zorder=10, + color="gray", + where="mid", + label="Lowest for any p", + linestyle="solid", + linewidth=0.7, + ) + + plt.legend() + + plt.xlabel("Number of operands") + plt.ylabel("Multiplicative size") + + plt.savefig("veto_voting.pdf", bbox_inches="tight") + plt.show() diff --git a/oraqle/experiments/oraqle_spotlight/examples/and_16.py b/oraqle/experiments/oraqle_spotlight/examples/and_16.py new file mode 100644 index 0000000..d1efb56 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/examples/and_16.py @@ -0,0 +1,17 @@ +from galois import GF + +from oraqle.compiler.boolean.bool_and import all_ +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.leafs import Input + +if __name__ == "__main__": + gf = GF(17) + + xs = (Input(f"x{i + 1}", gf) for i in range(16)) + + conjunction = all_(*xs) + + circuit = Circuit([conjunction]) + arithmetic_circuit = circuit.arithmetize() + + arithmetic_circuit.to_pdf("conjunction.pdf") diff --git a/oraqle/experiments/oraqle_spotlight/examples/common_expressions.py b/oraqle/experiments/oraqle_spotlight/examples/common_expressions.py new file mode 100644 index 0000000..0942584 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/examples/common_expressions.py @@ -0,0 +1,44 @@ +from typing import Tuple + +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.abstract import Node +from 
oraqle.compiler.nodes.arbitrary_arithmetic import sum_ +from oraqle.compiler.nodes.leafs import Input + + +def generate_nodes() -> Tuple[Node, Node]: + gf = GF(31) + + x = Input("x", gf) + y = Input("y", gf) + z1 = Input("z1", gf) + z2 = Input("z2", gf) + z3 = Input("z3", gf) + z4 = Input("z4", gf) + + comparison = x < y + sum = sum_(z1, z2, z3, z4) + cse1 = comparison & sum + + comparison = y > x + sum = sum_(z3, z2, z4) + z1 + cse2 = sum & comparison + + return cse1, cse2 + + +def test_cse_equivalence(): + cse1, cse2 = generate_nodes() + assert cse1.is_equivalent(cse2) + + +if __name__ == "__main__": + cse1, cse2 = generate_nodes() + + cse1 = Circuit([cse1]) + cse2 = Circuit([cse2]) + + cse1.to_pdf("cse1.pdf") + cse2.to_pdf("cse2.pdf") diff --git a/oraqle/experiments/oraqle_spotlight/examples/equality_31.py b/oraqle/experiments/oraqle_spotlight/examples/equality_31.py new file mode 100644 index 0000000..a5cb051 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/examples/equality_31.py @@ -0,0 +1,18 @@ +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.leafs import Input + +if __name__ == "__main__": + gf = GF(31) + + x = Input("x", gf) + y = Input("y", gf) + + equality = x == y + + circuit = Circuit([equality]) + arithmetic_circuits = circuit.arithmetize_depth_aware(cost_of_squaring=1.0) + + for d, _, arithmetic_circuit in arithmetic_circuits: + arithmetic_circuit.to_pdf(f"equality_{d}.pdf") diff --git a/oraqle/experiments/oraqle_spotlight/examples/equality_and_comparison.py b/oraqle/experiments/oraqle_spotlight/examples/equality_and_comparison.py new file mode 100644 index 0000000..cb31753 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/examples/equality_and_comparison.py @@ -0,0 +1,19 @@ +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.leafs import Input + +if __name__ == "__main__": + gf = GF(31) + + x = Input("x", gf) + y = Input("y", gf) + z = 
Input("z", gf) + + comparison = x < y + equality = y == z + both = comparison & equality + + circuit = Circuit([both]) + + circuit.to_pdf("example.pdf") diff --git a/oraqle/experiments/oraqle_spotlight/examples/t2_comparison.py b/oraqle/experiments/oraqle_spotlight/examples/t2_comparison.py new file mode 100644 index 0000000..1f43ec8 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/examples/t2_comparison.py @@ -0,0 +1,21 @@ +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.leafs import Input + +p = 7 +gf = GF(p) + +x = Input("x", gf) +y = Input("y", gf) + +comparison = 0 + +for a in range((p + 1) // 2, p): + comparison += 1 - (x - y - a) ** (p - 1) + +circuit = Circuit([comparison]) # type: ignore + +if __name__ == "__main__": + circuit.to_graph("t2.dot") + circuit.to_pdf("t2.pdf") diff --git a/oraqle/experiments/oraqle_spotlight/experiments/comparisons/comparisons_bench.py b/oraqle/experiments/oraqle_spotlight/experiments/comparisons/comparisons_bench.py new file mode 100644 index 0000000..70a8889 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/experiments/comparisons/comparisons_bench.py @@ -0,0 +1,115 @@ +import random +import subprocess + +from galois import GF +from matplotlib import pyplot as plt +from sympy import sieve + +from oraqle.compiler.circuit import ArithmeticCircuit, Circuit +from oraqle.compiler.comparison.comparison import SemiStrictComparison, T2SemiLessThan +from oraqle.compiler.nodes.leafs import Input + + +def run_benchmark(arithmetic_circuit: ArithmeticCircuit) -> float: + # Prepare the benchmark + arithmetic_circuit.generate_code("main.cpp", iterations=10, measure_time=True) + subprocess.run("make", capture_output=True, check=True) + + # Run the benchmark + command = ["./main"] + p = arithmetic_circuit._gf.characteristic + command.append(f"x={random.randint(0, p - 1)}") + command.append(f"y={random.randint(0, p - 1)}") + print("Running:", " ".join(command)) + result = 
subprocess.run(command, capture_output=True, text=True, check=False) + + if result.returncode != 0: + print("stderr:") + print(result.stderr) + print() + print("stdout:") + print(result.stdout) + + # Check if the noise was not too large + print(result.stdout) + lines = result.stdout.splitlines() + for line in lines[:-1]: + assert line.endswith("1") + + run_time = float(lines[-1]) / 10 + print(p, run_time) + + return run_time + + +if __name__ == "__main__": + run_benchmarks = False + gen_plots = True + + if run_benchmarks: + primes = list(sieve.primerange(300))[2:20] + + our_times = [] + t2_times = [] + + for p in primes: + gf = GF(p) + + x = Input("x", gf) + y = Input("y", gf) + + print(f"-------- p = {p}: ---------") + our_circuit = Circuit([SemiStrictComparison(x, y, less_than=True, gf=gf)]) + our_front = our_circuit.arithmetize_depth_aware() + print("Our circuits:", our_front) + + ts = [] + for _, _, arithmetic_circuit in our_front: + ts.append(run_benchmark(arithmetic_circuit)) + our_times.append(tuple(ts)) + + t2_circuit = Circuit([T2SemiLessThan(x, y, gf)]) + t2_arithmetization = t2_circuit.arithmetize() + print( + "T2 circuit:", + t2_arithmetization.multiplicative_depth(), + t2_arithmetization.multiplicative_size(), + ) + + t2_times.append(run_benchmark(t2_arithmetization)) + + print(primes) + print(our_times) + print(t2_times) + + if gen_plots: + primes = [5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71] + our_times = [(0.0156603,), (0.0523416,), (0.0954489,), (0.0936497,), (0.111959,), (0.128402,), (0.288951,), (0.42076, 0.368583), (0.416362,), (0.40343,), (0.385652,), (0.437486,), (0.481356,), (0.522607, 0.504944), (0.526451,), (0.5904119999999999, 0.5146740000000001), (0.592896,), (0.621265, 0.598357)] + t2_times = [0.0156379, 0.0938689, 0.23473899999999998, 0.319668, 0.366707, 0.6632450000000001, 1.8380299999999998, 1.14859, 2.9022200000000002, 3.2060299999999997, 3.5419899999999997, 4.53918, 5.02624, 5.4439, 8.64118, 
6.6267499999999995, 6.99609, 9.21295] + + plt.figure(figsize=(4, 2)) + plt.grid(axis="y", zorder=-1000, alpha=0.5) + + plt.scatter( + range(len(primes)), t2_times, marker="_", label="T2's Circuit", color="tab:orange" + ) + + for x, ts in enumerate(our_times): + for t in ts: + plt.scatter( + x, + t, + marker="_", + label="Oraqle's circuits" if x == 0 else None, + color="tab:cyan", + ) + + plt.xticks(range(len(primes)), primes, fontsize=8) # type: ignore + + plt.xlabel("Modulus") + plt.ylabel("Run time (s)") + + plt.legend() + + plt.savefig("t2_comparison.pdf", bbox_inches="tight") + plt.show() diff --git a/oraqle/experiments/oraqle_spotlight/experiments/large_equality/.gitignore b/oraqle/experiments/oraqle_spotlight/experiments/large_equality/.gitignore new file mode 100644 index 0000000..adf89f2 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/experiments/large_equality/.gitignore @@ -0,0 +1,6 @@ +/CMakeFiles +CMakeCache.txt +cmake_install.cmake +helib.log +Makefile +main diff --git a/oraqle/experiments/oraqle_spotlight/experiments/large_equality/CMakeLists.txt b/oraqle/experiments/oraqle_spotlight/experiments/large_equality/CMakeLists.txt new file mode 100644 index 0000000..a172dbd --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/experiments/large_equality/CMakeLists.txt @@ -0,0 +1,9 @@ +cmake_minimum_required(VERSION 3.10.2 FATAL_ERROR) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(helib) +add_executable(main main.cpp) +target_link_libraries(main helib) diff --git a/oraqle/experiments/oraqle_spotlight/experiments/large_equality/large_equality.py b/oraqle/experiments/oraqle_spotlight/experiments/large_equality/large_equality.py new file mode 100644 index 0000000..e77a453 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/experiments/large_equality/large_equality.py @@ -0,0 +1,104 @@ +import math +import random +import subprocess +import time +from typing import List, Tuple + 
+from galois import GF +from sympy import sieve + +from oraqle.compiler.boolean.bool_and import all_ +from oraqle.compiler.circuit import ArithmeticCircuit, Circuit +from oraqle.compiler.nodes.leafs import Input + + +def generate_circuits(bits: int) -> List[Tuple[int, ArithmeticCircuit, int, float]]: + circuits = [] + + primes = list(sieve.primerange(300))[:10] # [:55] # p <= 257 + start = time.monotonic() + times = [] + for p in primes: + # (6, 63.0): p=2 + # (7, 58.0): p=5 + # (8, 51.0): p=17 + + limbs = math.ceil(bits / math.log2(p)) + + gf = GF(p) + + xs = [Input(f"x{i}", gf) for i in range(limbs)] + ys = [Input(f"y{i}", gf) for i in range(limbs)] + circuit = Circuit([all_(*(xs[i] == ys[i] for i in range(limbs)))]) + + inbetween = time.monotonic() + front = circuit.arithmetize_depth_aware(0.75) + + print(f"{p}.", end=" ") + + for f in front: + circuits.append((p, f[2], f[0], f[1])) + print(f[0], f[1], end=" ") + + inbetween_time = time.monotonic() - inbetween + print(inbetween_time) + times.append((p, inbetween_time)) + + print(times) + print("Total time", time.monotonic() - start) + + return circuits + + +if __name__ == "__main__": + bits = 64 + benchmark_circuits = False + generate_table = True + + # Run a benchmark for all circuits in the front + if benchmark_circuits: + # Generate all circuits per p + circuits = generate_circuits(bits) + + results = [] + for p, arithmetic_circuit, d, c in circuits: + # Prepare the benchmark + params = arithmetic_circuit.generate_code("main.cpp", iterations=10, measure_time=True) + subprocess.run("make", check=True) + + # Run the benchmark + command = ["./main"] + limbs = math.ceil(bits / math.log2(p)) + for i in range(limbs): + command.append(f"x{i}={random.randint(0, p - 1)}") + command.append(f"y{i}={random.randint(0, p - 1)}") + print("Running:", " ".join(command)) + result = subprocess.run(command, capture_output=True, text=True, check=False) + + if result.returncode != 0: + print("stderr:") + print(result.stderr) + 
print() + print("stdout:") + print(result.stdout) + + # Check if the noise was not too large + print(result.stdout) + lines = result.stdout.splitlines() + for line in lines[:-1]: + assert line.endswith("1") + + run_time = float(lines[-1]) / 10 + print(p, run_time, d, c, params) + results.append((p, d, c, params, run_time)) + + print(results) + + if generate_table: + gen_times = [(2, 0.007554411888122559), (3, 0.06264467351138592), (5, 8.457202550023794), (7, 0.05447225831449032), (11, 0.0478445328772068), (13, 0.052152080461382866), (17, 0.04349260404706001), (19, 0.04553743451833725), (23, 0.05198719538748264), (29, 0.046183058992028236)] + results = [(2, 6, 63.0, (16383, 1, 142, 3), 3.27577), (3, 7, 60.75, (32768, 1, 170, 3), 1.51993), (5, 7, 58.0, (32768, 1, 178, 3), 1.7679099999999999), (5, 8, 55.5, (32768, 1, 197, 3), 1.93994), (7, 8, 74.0, (32768, 1, 206, 3), 2.90913), (7, 9, 70.0, (32768, 1, 226, 3), 2.6624600000000003), (7, 10, 69.5, (32768, 1, 246, 3), 3.00814), (11, 9, 69.25, (32768, 1, 228, 3), 2.50603), (11, 12, 68.25, (32768, 1, 300, 3), 3.25469), (13, 9, 68.75, (32768, 1, 237, 3), 2.67845), (13, 10, 67.75, (32768, 1, 237, 3), 2.7718), (13, 11, 66.0, (32768, 1, 237, 3), 2.56386), (13, 12, 65.0, (32768, 1, 301, 3), 3.10959), (17, 8, 51.0, (32768, 1, 217, 3), 1.8792300000000002), (19, 9, 79.0, (32768, 1, 238, 3), 2.85011), (19, 10, 68.0, (32768, 1, 259, 3), 2.8636500000000003), (23, 9, 89.0, (32768, 1, 248, 3), 4.135730000000001), (23, 10, 80.0, (32768, 1, 270, 3), 3.75128), (29, 9, 83.0, (32768, 1, 249, 3), 3.7119), (29, 10, 75.0, (32768, 1, 271, 3), 3.46666)] + + gen_times = {p: t for p, t in gen_times} + + for p, d, c, params, run_time in results: + print(f"{p} & {d} & {c} & {params[0]} & {params[1]} & {params[2]} & {params[3]} & {round(gen_times[p], 2)} & {round(run_time, 2)} \\\\") diff --git a/oraqle/experiments/oraqle_spotlight/experiments/veto_voting_minimal_cost.py b/oraqle/experiments/oraqle_spotlight/experiments/veto_voting_minimal_cost.py new 
file mode 100644 index 0000000..04fd8e7 --- /dev/null +++ b/oraqle/experiments/oraqle_spotlight/experiments/veto_voting_minimal_cost.py @@ -0,0 +1,82 @@ +"""Finds the minimum cost for veto voting circuits for different prime moduli.""" + +from sympy import sieve + +from oraqle.compiler.boolean.bool_and import _minimum_cost + +exponentiation_results = { + 2: ([(0, 0.0)], 8.633400000002123e-05), + 3: ([(1, 0.75)], 4.6670000000137435e-06), + 5: ([(2, 1.5)], 7.695799999996034e-05), + 7: ([(3, 2.5)], 0.0053472920000000035), + 11: ([(4, 3.25)], 0.007671625000000015), + 13: ([(4, 3.25)], 0.002812749999999975), + 17: ([(4, 3.0)], 7.891700000001167e-05), + 19: ([(5, 4.0)], 0.012155541999999964), + 23: ([(5, 5.0)], 0.03937258299999996), + 29: ([(5, 5.0)], 0.018942542000000007), + 31: ([(5, 6.0), (6, 5.0)], 0.064326), + 37: ([(6, 4.75)], 0.019883207999999986), + 41: ([(6, 4.75)], 0.02284237499999997), + 43: ([(6, 5.75)], 0.03223737499999996), + 47: ([(6, 6.75), (7, 6.0)], 0.607119292), + 53: ([(6, 5.75)], 0.03940958299999997), + 59: ([(6, 6.75)], 1.243811584), + 61: ([(6, 6.75), (7, 5.75)], 0.446000167), + 67: ([(7, 5.5)], 0.051902208000000005), + 71: ([(7, 6.5)], 0.18221370799999997), + 73: ([(7, 5.5)], 0.044685417000000005), + 79: ([(7, 6.75)], 0.362901958), + 83: ([(7, 6.5)], 0.121000375), + 89: ([(7, 6.5)], 0.182695375), + 97: ([(7, 5.5)], 0.06858350000000002), + 101: ([(7, 6.5)], 0.38408749999999997), + 103: ([(7, 7.5), (8, 6.5)], 3.3626029170000002), + 107: ([(7, 7.5)], 8.891771667), + 109: ([(7, 7.5), (8, 6.5)], 4.596561917), + 113: ([(7, 6.5)], 0.1859389579999995), + 127: ([(7, 9.5), (8, 7.5)], 1619.89318625), + 131: ([(8, 6.25)], 0.05858354099996177), + 137: ([(8, 6.25)], 0.10623299999999991), + 139: ([(8, 7.25)], 1.2351711669999998), + 149: ([(8, 7.25)], 0.48292875), + 151: ([(8, 7.5)], 4.641820375), + 157: ([(8, 7.5)], 2.49218775), + 163: ([(8, 7.25)], 0.5001321249999999), + 167: ([(8, 8.25), (9, 7.5)], 48.444338791), + 173: ([(8, 8.25), (9, 7.5)], 37.677076833), + 
179: ([(8, 8.25)], 132.232723375), + 181: ([(8, 8.25), (9, 7.25)], 53.822612083999985), + 191: ([(8, 9.25), (9, 8.25)], 907.7980847910001), + 193: ([(8, 6.25)], 0.12370429100008096), + 197: ([(8, 7.25)], 0.6496936670000002), + 199: ([(8, 8.25), (9, 7.25)], 50.102889333), + 211: ([(8, 8.25)], 83.20584475), + 223: ([(8, 10.0), (9, 8.25)], 6772.927301542), + 227: ([(8, 8.25)], 50.801469917), + 229: ([(8, 8.25)], 39.942074416000004), +} + + +def run_experiments(): + """Run the experiments and print the results.""" + max_k = 50 + best_costs = [100000000.0] * (max_k + 1) + best_ps = [None] * (max_k + 1) + + # This is for sqr = 0.75 mul + primes = list(sieve.primerange(300))[1:50] + for p in primes: + print(f"------ p = {p} ------") + for k in range(2, max_k + 1): + cost = _minimum_cost(k, exponentiation_results[p][0][0][1], p) + if cost < best_costs[k - 2]: + best_costs[k - 2] = cost + best_ps[k - 2] = p + + for k, cost, p in zip(range(2, max_k + 1), best_costs, best_ps): + print(k, cost, p) + + +if __name__ == "__main__": + run_experiments() diff --git a/oraqle/helib_template/.gitignore b/oraqle/helib_template/.gitignore new file mode 100644 index 0000000..adf89f2 --- /dev/null +++ b/oraqle/helib_template/.gitignore @@ -0,0 +1,6 @@ +/CMakeFiles +CMakeCache.txt +cmake_install.cmake +helib.log +Makefile +main diff --git a/oraqle/helib_template/CMakeLists.txt b/oraqle/helib_template/CMakeLists.txt new file mode 100644 index 0000000..a172dbd --- /dev/null +++ b/oraqle/helib_template/CMakeLists.txt @@ -0,0 +1,9 @@ +cmake_minimum_required(VERSION 3.10.2 FATAL_ERROR) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(helib) +add_executable(main main.cpp) +target_link_libraries(main helib) diff --git a/oraqle/helib_template/main.cpp b/oraqle/helib_template/main.cpp new file mode 100644 index 0000000..ad77e16 --- /dev/null +++ b/oraqle/helib_template/main.cpp @@ -0,0 +1,83 @@ + +#include <iostream> +#include <map> +#include <string> + 
+#include <helib/helib.h> + +typedef helib::Ptxt<helib::BGV> ptxt_t; +typedef helib::Ctxt ctxt_t; + +std::map<std::string, int> input_map; + +void parse_arguments(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + std::string argument(argv[i]); + size_t pos = argument.find('='); + if (pos != std::string::npos) { + std::string key = argument.substr(0, pos); + int value = std::stoi(argument.substr(pos + 1)); + input_map[key] = value; + } + } +} + +int extract_input(const std::string& name) { + if (input_map.find(name) != input_map.end()) { + return input_map[name]; + } else { + std::cerr << "Error: " << name << " not found" << std::endl; + return -1; + } +} + +int main(int argc, char* argv[]) { + // Parse the inputs + parse_arguments(argc, argv); + + // Set up the HE parameters + unsigned long p = 5; + unsigned long m = 8192; + unsigned long r = 1; + unsigned long bits = 72; + unsigned long c = 3; + helib::Context context = helib::ContextBuilder<helib::BGV>() + .m(m) + .p(p) + .r(r) + .bits(bits) + .c(c) + .build(); + + + // Generate keys + helib::SecKey secret_key(context); + secret_key.GenSecKey(); + helib::addSome1DMatrices(secret_key); + const helib::PubKey& public_key = secret_key; + + // Encrypt the inputs + std::vector<long> vec_x(1, extract_input("x")); + ptxt_t ptxt_x(context, vec_x); + ctxt_t ciph_x(public_key); + public_key.Encrypt(ciph_x, ptxt_x); + std::vector<long> vec_y(1, extract_input("y")); + ptxt_t ptxt_y(context, vec_y); + ctxt_t ciph_y(public_key); + public_key.Encrypt(ciph_y, ptxt_y); + + // Perform the actual circuit + ctxt_t stack_0 = ciph_x; + ctxt_t stack_1 = ciph_y; + stack_1 *= 4l; + stack_0 += stack_1; + stack_0 *= stack_0; + stack_0 *= stack_0; + stack_0 *= 4l; + stack_0 += 1l; + ptxt_t decrypted(context); + secret_key.Decrypt(decrypted, stack_0); + std::cout << decrypted << std::endl; + + return 0; +} diff --git a/pyproject.toml b/pyproject.toml index 9f69c2c..a29cc77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "oraqle" -description = "Secure computation 
compiler for MPC, FHE, and arithmetic circuits in general" -version = "0.0.1" +description = "Secure computation compiler for homomorphic encryption and arithmetic circuits in general" +version = "0.1.0" requires-python = ">= 3.8" authors = [ {name = "Jelle Vos", email = "J.V.Vos@tudelft.nl"}, @@ -10,16 +10,3 @@ maintainers = [ {name = "Jelle Vos", email = "J.V.Vos@tudelft.nl"} ] readme = "README.md" - -[tool.isort] -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true -ensure_newline_before_comments = true -line_length = 100 -profile = "black" -skip_gitignore = true - -[tool.black] -line_length = 100 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f7aedc4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +sympy +six +galois>=0.3.8 +aeskeyschedule +python-sat +git+https://github.com/jellevos/fhegen.git +matplotlib diff --git a/requirements_dev.txt b/requirements_dev.txt index e8f9035..c901bea 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,5 +1,9 @@ -flake8 pytest -black -isort -pep8-naming +gensafeprime +graphviz +tabulate +ruff +mkdocs +mkdocstrings[python] +mkautodoc +pymdown-extensions diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..1eb6ad8 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,55 @@ +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".ipynb_checkpoints", + ".mypy_cache", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +line-length = 100 +indent-width = 4 +target-version = "py38" + +[lint] +preview = true +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. 
+select = ["W", "E4", "E7", "E9", "F", "ERA001", "B", "D", "DOC", "PLW", "B", "SIM", "UP", "PLR", "RUF", "PIE"] +ignore = ["E203", "E501", "E731", "D105", "W293", "PLR2004", "PLR6301"] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[lint.per-file-ignores] +"oraqle/experiments/*" = ["D", "DOC"] + +[lint.pydocstyle] +# Use Google-style docstrings. +convention = "google" + +[format] +preview = true +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = true +docstring-code-line-length = "dynamic" diff --git a/setup.cfg b/setup.cfg index fe6d2d8..ad0bae2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,3 @@ -[flake8] -max-line-length = 100 -extend-ignore = E203, E501, W503, E731 -exclude = venv build -per-file-ignores = __init__.py:F401 - [tool:pytest] python_files = *.py norecursedirs = venv build diff --git a/tests/test_circuit_sizes_costs.py b/tests/test_circuit_sizes_costs.py new file mode 100644 index 0000000..7699f42 --- /dev/null +++ b/tests/test_circuit_sizes_costs.py @@ -0,0 +1,93 @@ +"""Test file for testing circuits sizes.""" + +from collections import Counter + +from galois import GF + +from oraqle.compiler.nodes.abstract import ArithmeticNode, UnoverloadedWrapper +from oraqle.compiler.nodes.arbitrary_arithmetic import Sum +from oraqle.compiler.nodes.leafs import Constant, Input + + +def test_size_exponentiation_chain(): + """Test.""" + gf = GF(101) + + x = Input("x", gf) + + x = x.mul(x, flatten=False) + x = x.mul(x, flatten=False) + x = x.mul(x, flatten=False) + + x = x.to_arithmetic() + assert isinstance(x, ArithmeticNode) + assert ( + x.multiplicative_size() == 3 + ), f"((x^2)^2)^2 should be 3 multiplications, but counted {x.multiplicative_size()}" + assert x.multiplicative_cost(0.5) == 1.5 + + +def 
test_size_sum_of_products(): + """Tests that a * b + c * d has multiplicative size 2 and multiplicative cost 2.""" + gf = GF(101) + + a = Input("a", gf) + b = Input("b", gf) + c = Input("c", gf) + d = Input("d", gf) + + ab = a * b + cd = c * d + + out = ab + cd + out = out.to_arithmetic() + + assert isinstance(out, ArithmeticNode) + assert ( + out.multiplicative_size() == 2 + ), f"a * b + c * d should be 2 multiplications, but counted {out.multiplicative_size()}" + assert out.multiplicative_cost(0.7) == 2 + + +def test_size_linear_function(): + """Tests that an arithmetized Sum node has multiplicative size 0 and multiplicative cost 0.""" + gf = GF(101) + + a = Input("a", gf) + b = Input("b", gf) + c = Input("c", gf) + + out = Sum( + Counter({UnoverloadedWrapper(a): 1, UnoverloadedWrapper(b): 3, UnoverloadedWrapper(c): 1}), + gf, + gf(2), + ) + + out = out.to_arithmetic() + assert out.multiplicative_size() == 0 + assert out.multiplicative_cost(0.5) == 0 + + +def test_size_duplicate_nodes(): + """Tests the multiplicative size and cost of a circuit that contains duplicate nodes.""" + gf = GF(101) + + x = Input("x", gf) + + add1 = x.add(Constant(gf(1))) + add2 = x.add(Constant(gf(1))) + + mul1 = x.mul(x, flatten=False) + mul2 = x.mul(x, flatten=False) + + add3 = mul2.add(add2, flatten=False) + + mul3 = mul1.mul(add3, flatten=False) + + out = add1.add(mul3, flatten=False) + + out = out.to_arithmetic() + + assert isinstance(out, ArithmeticNode) + assert out.multiplicative_size() == 3 + assert out.multiplicative_cost(0.7) == 2.4 diff --git a/tests/test_poly2circuit.py b/tests/test_poly2circuit.py new file mode 100644 index 0000000..5b18db8 --- /dev/null +++ b/tests/test_poly2circuit.py @@ -0,0 +1,63 @@ +"""Test file for generating circuits using polynomial interpolation.""" + +import itertools + +from oraqle.compiler.func2poly import interpolate_polynomial +from oraqle.compiler.poly2circuit import construct_circuit + + +def _construct_and_test_circuit_from_bivariate_lambda(function, modulus: int, cse=False): + poly = interpolate_polynomial(function, modulus, ["x", "y"]) + circuit, gf = construct_circuit([poly], modulus) + circuit = circuit.arithmetize() + + if cse: + 
circuit.eliminate_subexpressions() + + for x, y in itertools.product(range(modulus), repeat=2): + print(function, x, y) + assert circuit.evaluate({"x": gf(x), "y": gf(y)}) == [function(x, y)] + + +def test_inequality_mod7(): + """Tests x != y (mod 7).""" + _construct_and_test_circuit_from_bivariate_lambda(lambda x, y: int(x != y), modulus=7) + + +def test_inequality_mod13(): + """Tests x != y (mod 13).""" + _construct_and_test_circuit_from_bivariate_lambda(lambda x, y: int(x != y), modulus=13) + + +def test_max_mod7(): + """Tests max(x, y) (mod 7).""" + _construct_and_test_circuit_from_bivariate_lambda(max, modulus=7) + + +def test_max_mod13(): + """Tests max(x, y) (mod 13).""" + _construct_and_test_circuit_from_bivariate_lambda(max, modulus=13) + + +def test_xor_mod11(): + """Tests x ^ y (mod 11).""" + _construct_and_test_circuit_from_bivariate_lambda(lambda x, y: (x ^ y) % 11, modulus=11) + + +def test_inequality_mod11_cse(): + """Tests x != y (mod 11) with CSE.""" + _construct_and_test_circuit_from_bivariate_lambda( + lambda x, y: int(x != y), modulus=11, cse=True + ) + + +def test_max_mod7_cse(): + """Tests max(x, y) (mod 7) with CSE.""" + _construct_and_test_circuit_from_bivariate_lambda(max, modulus=7, cse=True) + + +def test_xor_mod13_cse(): + """Tests x ^ y (mod 13) with CSE.""" + _construct_and_test_circuit_from_bivariate_lambda( + lambda x, y: (x ^ y) % 13, modulus=13, cse=True + ) diff --git a/tests/test_sugar_expressions.py b/tests/test_sugar_expressions.py new file mode 100644 index 0000000..f778d0b --- /dev/null +++ b/tests/test_sugar_expressions.py @@ -0,0 +1,22 @@ +"""Test file for sugar expressions.""" + +from galois import GF + +from oraqle.compiler.circuit import Circuit +from oraqle.compiler.nodes.arbitrary_arithmetic import sum_ +from oraqle.compiler.nodes.leafs import Input + + +def test_sum(): + """Tests the sum_ function.""" + gf = GF(127) + + a = Input("a", gf) + b = Input("b", gf) + + arithmetic_circuit = Circuit([sum_(a, 4, b, 
3)]).arithmetize() + + for val_a in range(127): + for val_b in range(127): + expected = gf(val_a) + gf(val_b) + gf(7) + assert arithmetic_circuit.evaluate({"a": gf(val_a), "b": gf(val_b)}) == expected