Skip to content

Commit

Permalink
feat: arithmetic bug fix and add compiler to ci (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirhemo authored Mar 16, 2024
1 parent a16c0f9 commit 644e182
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 14 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
toolchain: nightly-2024-01-25
override: true

- name: Run cargo test
- name: Run cargo test on sp1-core
uses: actions-rs/cargo@v1
with:
command: test
Expand All @@ -56,6 +56,16 @@ jobs:
RUST_LOG: 1
RUST_BACKTRACE: 1

- name: Run cargo test on sp1-recursion-compiler
uses: actions-rs/cargo@v1
with:
command: test
args: -p sp1-recursion-compiler --release
env:
RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0
RUST_LOG: 1
RUST_BACKTRACE: 1

- name: Run cargo test with no default features
uses: actions-rs/cargo@v1
with:
Expand Down
18 changes: 9 additions & 9 deletions recursion/compiler/src/gnark/lib/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ type Circuit struct {

func (circuit *Circuit) Define(api frontend.API) error {
fieldChip := babybear.NewChip(api)

// Variables.
var felt0 *babybear.Variable
var felt2 *babybear.Variable
var backend1 frontend.Variable
var var0 frontend.Variable
var backend0 frontend.Variable
var felt1 *babybear.Variable
var var0 frontend.Variable
var felt2 *babybear.Variable

var felt0 *babybear.Variable

// Operations.
var0 = frontend.Variable(0)
felt0 = babybear.NewVariable(0)
Expand All @@ -33,10 +33,10 @@ func (circuit *Circuit) Define(api frontend.API) error {
}
fieldChip.AssertEq(felt0, babybear.NewVariable(144))
backend0 = api.IsZero(api.Sub(var0, var0))
felt0 = fieldChip.Select(backend0, fieldChip.Add(felt1, babybear.NewVariable(0)), felt0)
felt0 = fieldChip.Select(backend0, fieldChip.Add(felt0, felt1), felt0)
felt0 = fieldChip.Select(backend0, fieldChip.Add(felt1, babybear.NewVariable(0)), felt0)
felt0 = fieldChip.Select(backend0, fieldChip.Add(felt0, felt1), felt0)
backend1 = api.Sub(frontend.Variable(1), api.IsZero(api.Sub(var0, var0)))
felt0 = fieldChip.Select(backend1, fieldChip.Add(felt1, babybear.NewVariable(0)), felt0)
felt0 = fieldChip.Select(backend1, fieldChip.Add(felt0, felt1), felt0)
felt0 = fieldChip.Select(backend1, fieldChip.Add(felt1, babybear.NewVariable(0)), felt0)
felt0 = fieldChip.Select(backend1, fieldChip.Add(felt0, felt1), felt0)
return nil
}
89 changes: 85 additions & 4 deletions recursion/compiler/src/ir/symbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,14 @@ impl<F: Field, EF: ExtensionField<F>, E: Any> Add<E> for Ext<F, EF> {

fn add(self, rhs: E) -> Self::Output {
let rhs: ExtOperand<F, EF> = rhs.to_operand();
SymbolicExt::<F, EF>::from(self) + rhs
match rhs {
ExtOperand::Base(f) => SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f))) + self,
ExtOperand::Const(ef) => SymbolicExt::Const(ef) + self,
ExtOperand::Felt(f) => SymbolicExt::Base(Rc::new(SymbolicFelt::Val(f))) + self,
ExtOperand::Ext(e) => SymbolicExt::Val(e) + self,
ExtOperand::SymFelt(f) => SymbolicExt::Base(Rc::new(f)) + self,
ExtOperand::Sym(e) => e + self,
}
}
}

Expand All @@ -676,7 +683,14 @@ impl<F: Field, EF: ExtensionField<F>, E: Any> Mul<E> for Ext<F, EF> {

fn mul(self, rhs: E) -> Self::Output {
let rhs: ExtOperand<F, EF> = rhs.to_operand();
SymbolicExt::<F, EF>::from(self) * rhs
match rhs {
ExtOperand::Base(f) => SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f))) * self,
ExtOperand::Const(ef) => SymbolicExt::Const(ef) * self,
ExtOperand::Felt(f) => SymbolicExt::Base(Rc::new(SymbolicFelt::Val(f))) * self,
ExtOperand::Ext(e) => SymbolicExt::Val(e) * self,
ExtOperand::SymFelt(f) => SymbolicExt::Base(Rc::new(f)) * self,
ExtOperand::Sym(e) => e * self,
}
}
}

Expand All @@ -685,7 +699,29 @@ impl<F: Field, EF: ExtensionField<F>, E: Any> Sub<E> for Ext<F, EF> {

fn sub(self, rhs: E) -> Self::Output {
let rhs: ExtOperand<F, EF> = rhs.to_operand();
SymbolicExt::<F, EF>::from(self) - rhs
match rhs {
ExtOperand::Base(f) => SymbolicExt::Sub(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f)))),
),
ExtOperand::Const(ef) => SymbolicExt::Sub(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Const(ef)),
),
ExtOperand::Felt(f) => SymbolicExt::Sub(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Val(f)))),
),
ExtOperand::Ext(e) => SymbolicExt::Sub(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Val(e)),
),
ExtOperand::SymFelt(f) => SymbolicExt::Sub(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Base(Rc::new(f))),
),
ExtOperand::Sym(e) => SymbolicExt::Sub(Rc::new(SymbolicExt::Val(self)), Rc::new(e)),
}
}
}

Expand All @@ -694,7 +730,29 @@ impl<F: Field, EF: ExtensionField<F>, E: Any> Div<E> for Ext<F, EF> {

fn div(self, rhs: E) -> Self::Output {
let rhs: ExtOperand<F, EF> = rhs.to_operand();
SymbolicExt::<F, EF>::from(self) / rhs
match rhs {
ExtOperand::Base(f) => SymbolicExt::Div(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f)))),
),
ExtOperand::Const(ef) => SymbolicExt::Div(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Const(ef)),
),
ExtOperand::Felt(f) => SymbolicExt::Div(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Val(f)))),
),
ExtOperand::Ext(e) => SymbolicExt::Div(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Val(e)),
),
ExtOperand::SymFelt(f) => SymbolicExt::Div(
Rc::new(SymbolicExt::Val(self)),
Rc::new(SymbolicExt::Base(Rc::new(f))),
),
ExtOperand::Sym(e) => SymbolicExt::Div(Rc::new(SymbolicExt::Val(self)), Rc::new(e)),
}
}
}

Expand Down Expand Up @@ -738,6 +796,14 @@ impl<F> Div for Felt<F> {
}
}

impl<F> Div<F> for Felt<F> {
type Output = SymbolicFelt<F>;

fn div(self, rhs: F) -> Self::Output {
SymbolicFelt::from(self) / rhs
}
}

impl<F> Div<Felt<F>> for SymbolicFelt<F> {
type Output = SymbolicFelt<F>;

Expand All @@ -746,6 +812,14 @@ impl<F> Div<Felt<F>> for SymbolicFelt<F> {
}
}

impl<F> Div<F> for SymbolicFelt<F> {
type Output = SymbolicFelt<F>;

fn div(self, rhs: F) -> Self::Output {
SymbolicFelt::Div(Rc::new(self), Rc::new(SymbolicFelt::Const(rhs)))
}
}

impl<N> Sub<SymbolicVar<N>> for Var<N> {
type Output = SymbolicVar<N>;

Expand Down Expand Up @@ -914,6 +988,13 @@ impl<F: Field, EF: ExtensionField<F>, E: Any> ExtensionOperand<F, EF> for E {
let value = unsafe { mem::transmute_copy::<E, Ext<F, EF>>(&self) };
ExtOperand::<F, EF>::Ext(value)
}
ty if ty == TypeId::of::<SymbolicFelt<F>>() => {
// *Saftey*: We know that E is a Symbolic Felt<F> and we can transmute it to
// SymbolicFelt<F> but we need to clone the pointer.
let value_ref = unsafe { mem::transmute::<&E, &SymbolicFelt<F>>(&self) };
let value = value_ref.clone();
ExtOperand::<F, EF>::SymFelt(value)
}
ty if ty == TypeId::of::<SymbolicExt<F, EF>>() => {
// *Saftey*: We know that E is a SymbolicExt<F, EF> and we can transmute it to
// SymbolicExt<F, EF> but we need to clone the pointer.
Expand Down

0 comments on commit 644e182

Please sign in to comment.