Skip to content

Commit

Permalink
Fix casting issues for Booleans
Browse files Browse the repository at this point in the history
  • Loading branch information
jellevos committed Jan 8, 2025
1 parent 29683c7 commit 52f62a0
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 26 deletions.
7 changes: 4 additions & 3 deletions oraqle/compiler/boolean/bool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import abstractmethod
import types
from typing import Dict, Optional, Type
from galois import GF, FieldArray
from oraqle.compiler.circuit import Circuit
Expand Down Expand Up @@ -166,16 +167,16 @@ def transform_to_inv_unreduced_boolean(self) -> InvUnreducedBoolean:

_class_cache = {}

def _get_dynamic_class(name, bases, attrs):
def _get_dynamic_class(name, bases):
"""Tracks dynamic classes so that cast_to_reduced_boolean on a specific class always returns the same dynamic Boolean class."""
key = (name, bases)
if key not in _class_cache:
_class_cache[key] = type(name, bases, attrs)
_class_cache[key] = types.new_class(name, bases)
return _class_cache[key]


def _cast_to[N: Node](node: Node, to: Type[N]) -> N:
CastedNode = _get_dynamic_class(f'{node.__class__.__name__}_{N.__name__}', (node.__class__, N), dict(node.__class__.__dict__)) # type: ignore
CastedNode = _get_dynamic_class(f'{node.__class__.__name__}_{to.__name__}', (node.__class__, to)) # type: ignore
node.__class__ = CastedNode
return node # type: ignore

Expand Down
36 changes: 18 additions & 18 deletions oraqle/compiler/boolean/bool_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from oraqle.add_chains.addition_chains_front import gen_pareto_front
from oraqle.add_chains.addition_chains_mod import chain_cost
from oraqle.add_chains.solving import extract_indices
from oraqle.compiler.boolean.bool import Boolean, BooleanConstant, InvUnreducedBoolean, ReducedBoolean, ReducedBooleanInput, UnreducedBoolean
from oraqle.compiler.boolean.bool_neg import Neg, ReducedNeg
from oraqle.compiler.boolean.bool import Boolean, BooleanConstant, InvUnreducedBoolean, ReducedBoolean, ReducedBooleanInput, UnreducedBoolean, cast_to_inv_unreduced_boolean, cast_to_reduced_boolean
from oraqle.compiler.boolean.bool_neg import ReducedNeg
from oraqle.compiler.circuit import Circuit
from oraqle.compiler.comparison.equality import IsNonZero
from oraqle.compiler.comparison.equality import ReducedIsNonZero
from oraqle.compiler.nodes.abstract import (
ArithmeticNode,
CostParetoFront,
Expand Down Expand Up @@ -105,7 +105,7 @@ def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
def _arithmetize_inner(self, strategy: str) -> Node:
# TODO: We need to randomize (i.e. make it a Sum with random multiplicities)
# TODO: Consider not supporting additions between Booleans unless they are cast to field elements
return sum_(*self._operands).arithmetize(strategy)
return cast_to_inv_unreduced_boolean(sum_(*self._operands)).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("TODO!")
Expand Down Expand Up @@ -138,14 +138,14 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912
new_operands.add(UnoverloadedWrapper(new_operand))

if len(new_operands) == 0:
return Constant(self._gf(1))
return BooleanConstant(self._gf(1))
elif len(new_operands) == 1:
return next(iter(new_operands)).node

if strategy == "naive":
return Product(Counter({operand: 1 for operand in new_operands}), self._gf).arithmetize(
return cast_to_reduced_boolean(Product(Counter({operand: 1 for operand in new_operands}), self._gf).arithmetize(
strategy
)
))

# TODO: Calling to_arithmetic here should not be necessary if we can decide the predicted depth
queue = [
Expand Down Expand Up @@ -174,34 +174,34 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912
max_depth = popped.priority

if total_sum is None:
total_sum = Neg(popped.item, self._gf)
total_sum = ReducedNeg(popped.item, self._gf)
else:
total_sum += Neg(popped.item, self._gf)
total_sum += ReducedNeg(popped.item, self._gf)

assert total_sum is not None
final_result = Neg(IsNonZero(total_sum, self._gf), self._gf).arithmetize(strategy)
final_result = ReducedNeg(ReducedIsNonZero(total_sum, self._gf), self._gf).arithmetize(strategy)

assert max_depth is not None
heappush(queue, _PrioritizedItem(max_depth, final_result))

if len(queue) == 1:
return heappop(queue).item

dummy_node = Input("dummy_node", self._gf)
is_non_zero = IsNonZero(dummy_node, self._gf).arithmetize(strategy).to_arithmetic()
dummy_node = ReducedBooleanInput("dummy_node", self._gf)
is_non_zero = ReducedIsNonZero(dummy_node, self._gf).arithmetize(strategy).to_arithmetic()
cost = is_non_zero.multiplicative_cost(
1.0
) # FIXME: This needs to be the actual squaring cost

if len(queue) - 1 < cost:
return Product(
return cast_to_reduced_boolean(Product(
Counter({UnoverloadedWrapper(operand.item): 1 for operand in queue}), self._gf
).arithmetize(strategy)
)).arithmetize(strategy)

return Neg(
IsNonZero(
return ReducedNeg(
ReducedIsNonZero(
Sum(
Counter({UnoverloadedWrapper(Neg(node.item, self._gf)): 1 for node in queue}),
Counter({UnoverloadedWrapper(ReducedNeg(node.item, self._gf)): 1 for node in queue}),
self._gf,
),
self._gf,
Expand Down Expand Up @@ -516,7 +516,7 @@ def to_arithmetic_node(self, is_and: bool, gf: Type[FieldArray]) -> ArithmeticNo
nodes = [result]
for i, j in chain:
nodes.append(Multiplication(nodes[i], nodes[j], gf)) # type: ignore
result = nodes[-1]
result = cast_to_reduced_boolean(nodes[-1])

if is_and:
result = ReducedNeg(result, gf).arithmetize("best-effort") # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions oraqle/compiler/boolean/bool_neg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from oraqle.compiler.arithmetic.subtraction import Subtraction
from oraqle.compiler.boolean.bool import Boolean, InvUnreducedBoolean, ReducedBoolean, UnreducedBoolean
from oraqle.compiler.circuit import Circuit
from oraqle.compiler.nodes.abstract import CostParetoFront, Node
from oraqle.compiler.nodes.leafs import Constant
from oraqle.compiler.nodes.univariate import UnivariateNode
Expand Down Expand Up @@ -34,6 +35,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
raise NotImplementedError("TODO!")

def transform_to_reduced_boolean(self) -> ReducedBoolean:
Circuit([self._node]).to_pdf("debug.pdf")
return ReducedNeg(self._node.transform_to_reduced_boolean(), self._gf)

def transform_to_unreduced_boolean(self) -> UnreducedBoolean:
Expand Down
4 changes: 2 additions & 2 deletions oraqle/compiler/boolean/bool_or.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from galois import GF, FieldArray

from oraqle.compiler.boolean.bool import Boolean, BooleanConstant, InvUnreducedBoolean, ReducedBoolean, ReducedBooleanInput, UnreducedBoolean
from oraqle.compiler.boolean.bool import Boolean, BooleanConstant, InvUnreducedBoolean, ReducedBoolean, ReducedBooleanInput, UnreducedBoolean, cast_to_unreduced_boolean
from oraqle.compiler.boolean.bool_and import ReducedAnd, _find_depth_cost_front
from oraqle.compiler.boolean.bool_neg import ReducedNeg
from oraqle.compiler.nodes.abstract import CostParetoFront, Node, UnoverloadedWrapper
Expand Down Expand Up @@ -86,7 +86,7 @@ def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:

def _arithmetize_inner(self, strategy: str) -> Node:
# TODO: We need to randomize (i.e. make it a Sum with random multiplicities)
return sum_(*self._operands).arithmetize(strategy)
return cast_to_unreduced_boolean(sum_(*self._operands)).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("TODO!")
Expand Down
6 changes: 3 additions & 3 deletions oraqle/compiler/comparison/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from oraqle.compiler.arithmetic.exponentiation import Power
from oraqle.compiler.arithmetic.subtraction import Subtraction
from oraqle.compiler.boolean.bool import Boolean, InvUnreducedBoolean, ReducedBoolean, UnreducedBoolean
from oraqle.compiler.boolean.bool import Boolean, InvUnreducedBoolean, ReducedBoolean, UnreducedBoolean, cast_to_reduced_boolean
from oraqle.compiler.boolean.bool_neg import Neg, ReducedNeg
from oraqle.compiler.nodes.abstract import CostParetoFront, Node
from oraqle.compiler.nodes.binary_arithmetic import CommutativeBinaryNode
Expand Down Expand Up @@ -64,10 +64,10 @@ def _operation_inner(self, input: FieldArray) -> FieldArray:
return input != 0

def _arithmetize_inner(self, strategy: str) -> Node:
return Power(self._node, self._gf.order - 1, self._gf).arithmetize(strategy)
return cast_to_reduced_boolean(Power(self._node, self._gf.order - 1, self._gf)).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
return Power(self._node, self._gf.order - 1, self._gf).arithmetize_depth_aware(
return cast_to_reduced_boolean(Power(self._node, self._gf.order - 1, self._gf)).arithmetize_depth_aware(
cost_of_squaring
)

Expand Down

0 comments on commit 52f62a0

Please sign in to comment.