From a16c0f9a99cf5e703505b3f3200caf5dbcc9ead0 Mon Sep 17 00:00:00 2001 From: Tamir Hemo Date: Fri, 15 Mar 2024 14:07:40 -0700 Subject: [PATCH] feat: array and symbolic evaluation (#390) --- Cargo.lock | 1 + core/src/stark/folder.rs | 222 ++++++-- core/src/stark/verifier.rs | 1 + recursion/compiler/Cargo.toml | 1 + recursion/compiler/examples/verifier.rs | 33 ++ recursion/compiler/src/asm/compiler.rs | 30 +- recursion/compiler/src/asm/heap.rs | 4 - recursion/compiler/src/asm/instruction.rs | 146 +---- recursion/compiler/src/asm/mod.rs | 2 - recursion/compiler/src/ir/collections.rs | 65 ++- recursion/compiler/src/ir/instructions.rs | 26 +- recursion/compiler/src/ir/ptr.rs | 201 ++++++- recursion/compiler/src/ir/symbolic.rs | 637 ++++++++++++++++------ recursion/compiler/src/ir/types.rs | 39 +- recursion/compiler/src/ir/var.rs | 7 +- recursion/compiler/tests/arithmetic.rs | 25 +- recursion/compiler/tests/array.rs | 79 +++ recursion/core/src/lib.rs | 34 +- recursion/core/src/runtime/instruction.rs | 55 +- recursion/core/src/runtime/mod.rs | 110 ++-- recursion/core/src/runtime/opcode.rs | 22 +- 21 files changed, 1208 insertions(+), 532 deletions(-) create mode 100644 recursion/compiler/examples/verifier.rs delete mode 100644 recursion/compiler/src/asm/heap.rs create mode 100644 recursion/compiler/tests/array.rs diff --git a/Cargo.lock b/Cargo.lock index f428e1cc84..72de36de4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2590,6 +2590,7 @@ dependencies = [ name = "sp1-recursion-compiler" version = "0.1.0" dependencies = [ + "p3-air", "p3-baby-bear", "p3-field", "rand", diff --git a/core/src/stark/folder.rs b/core/src/stark/folder.rs index df61abc977..47d4217e30 100644 --- a/core/src/stark/folder.rs +++ b/core/src/stark/folder.rs @@ -1,7 +1,12 @@ -use super::{PackedChallenge, PackedVal, StarkGenericConfig}; +use std::{ + marker::PhantomData, + ops::{Add, Mul, MulAssign, Sub}, +}; + +use super::{Challenge, PackedChallenge, PackedVal, StarkGenericConfig, Val}; use crate::air::{EmptyMessageBuilder, MultiTableAirBuilder}; use p3_air::{AirBuilder, ExtensionBuilder, PairBuilder, PermutationAirBuilder, TwoRowMatrixView}; -use p3_field::AbstractField; +use p3_field::{AbstractField, ExtensionField, Field}; /// A folder for prover constraints. pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { @@ -95,57 +100,110 @@ impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> impl<'a, SC: StarkGenericConfig> EmptyMessageBuilder for ProverConstraintFolder<'a, SC> {} +pub type VerifierConstraintFolder<'a, SC> = + GenericVerifierConstraintFolder<'a, Val, Challenge, Challenge, Challenge>; + /// A folder for verifier constraints. -pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> { - pub preprocessed: TwoRowMatrixView<'a, SC::Challenge>, - pub main: TwoRowMatrixView<'a, SC::Challenge>, - pub perm: TwoRowMatrixView<'a, SC::Challenge>, - pub perm_challenges: &'a [SC::Challenge], - pub cumulative_sum: SC::Challenge, - pub is_first_row: SC::Challenge, - pub is_last_row: SC::Challenge, - pub is_transition: SC::Challenge, - pub alpha: SC::Challenge, - pub accumulator: SC::Challenge, +pub struct GenericVerifierConstraintFolder<'a, F, EF, Var, Expr> { + pub preprocessed: TwoRowMatrixView<'a, Var>, + pub main: TwoRowMatrixView<'a, Var>, + pub perm: TwoRowMatrixView<'a, Var>, + pub perm_challenges: &'a [EF], + pub cumulative_sum: Var, + pub is_first_row: Var, + pub is_last_row: Var, + pub is_transition: Var, + pub alpha: EF, + pub accumulator: Expr, + pub _marker: PhantomData, } -impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> { - type F = SC::Val; - type Expr = SC::Challenge; - type Var = SC::Challenge; - type M = TwoRowMatrixView<'a, SC::Challenge>; +impl<'a, F, EF, Var, Expr> AirBuilder for GenericVerifierConstraintFolder<'a, F, EF, Var, Expr> +where + F: Field, + EF: ExtensionField, + Expr: AbstractField + + From + + Add + + Add + + Sub + + Sub + + Mul + + Mul + + MulAssign, + Var: Into + + Copy + + Add + + Add + + Add + + Sub + + Sub + + Sub + + Mul + + Mul + + Mul, +{ + type F = F; + type Expr = Expr; + type Var = Var; + type M = TwoRowMatrixView<'a, Var>; fn main(&self) -> Self::M { self.main } fn is_first_row(&self) -> Self::Expr { - self.is_first_row + self.is_first_row.into() } fn is_last_row(&self) -> Self::Expr { - self.is_last_row + self.is_last_row.into() } fn is_transition_window(&self, size: usize) -> Self::Expr { if size == 2 { - self.is_transition + self.is_transition.into() } else { panic!("uni-stark only supports a window size of 2") } } fn assert_zero>(&mut self, x: I) { - let x: SC::Challenge = x.into(); + let x: Expr = x.into(); self.accumulator *= self.alpha; self.accumulator += x; } } -impl<'a, SC: StarkGenericConfig> ExtensionBuilder for VerifierConstraintFolder<'a, SC> { - type EF = SC::Challenge; - type ExprEF = SC::Challenge; - type VarEF = SC::Challenge; +impl<'a, F, EF, Var, Expr> ExtensionBuilder + for GenericVerifierConstraintFolder<'a, F, EF, Var, Expr> +where + F: Field, + EF: ExtensionField, + Expr: AbstractField + + From + + Add + + Add + + Sub + + Sub + + Mul + + Mul + + MulAssign, + Var: Into + + Copy + + Add + + Add + + Add + + Sub + + Sub + + Sub + + Mul + + Mul + + Mul, +{ + type EF = EF; + type ExprEF = Expr; + type VarEF = Var; fn assert_zero_ext(&mut self, x: I) where @@ -155,8 +213,33 @@ impl<'a, SC: StarkGenericConfig> ExtensionBuilder for VerifierConstraintFolder<' } } -impl<'a, SC: StarkGenericConfig> PermutationAirBuilder for VerifierConstraintFolder<'a, SC> { - type MP = TwoRowMatrixView<'a, SC::Challenge>; +impl<'a, F, EF, Var, Expr> PermutationAirBuilder + for GenericVerifierConstraintFolder<'a, F, EF, Var, Expr> +where + F: Field, + EF: ExtensionField, + Expr: AbstractField + + From + + Add + + Add + + Sub + + Sub + + Mul + + Mul + + MulAssign, + Var: Into + + Copy + + Add + + Add + + Add + + Sub + + Sub + + Sub + + Mul + + Mul + + Mul, +{ + type MP = TwoRowMatrixView<'a, Var>; fn permutation(&self) -> Self::MP { self.perm @@ -167,18 +250,93 @@ impl<'a, SC: StarkGenericConfig> PermutationAirBuilder for VerifierConstraintFol } } -impl<'a, SC: StarkGenericConfig> MultiTableAirBuilder for VerifierConstraintFolder<'a, SC> { - type Sum = SC::Challenge; +impl<'a, F, EF, Var, Expr> MultiTableAirBuilder + for GenericVerifierConstraintFolder<'a, F, EF, Var, Expr> +where + F: Field, + EF: ExtensionField, + Expr: AbstractField + + From + + Add + + Add + + Sub + + Sub + + Mul + + Mul + + MulAssign, + Var: Into + + Copy + + Add + + Add + + Add + + Sub + + Sub + + Sub + + Mul + + Mul + + Mul, +{ + type Sum = Var; fn cumulative_sum(&self) -> Self::Sum { self.cumulative_sum } } -impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> { +impl<'a, F, EF, Var, Expr> PairBuilder for GenericVerifierConstraintFolder<'a, F, EF, Var, Expr> +where + F: Field, + EF: ExtensionField, + Expr: AbstractField + + From + + Add + + Add + + Sub + + Sub + + Mul + + Mul + + MulAssign, + Var: Into + + Copy + + Add + + Add + + Add + + Sub + + Sub + + Sub + + Mul + + Mul + + Mul, +{ fn preprocessed(&self) -> Self::M { self.preprocessed } } -impl<'a, SC: StarkGenericConfig> EmptyMessageBuilder for VerifierConstraintFolder<'a, SC> {} +impl<'a, F, EF, Var, Expr> EmptyMessageBuilder + for GenericVerifierConstraintFolder<'a, F, EF, Var, Expr> +where + F: Field, + EF: ExtensionField, + Expr: AbstractField + + From + + Add + + Add + + Sub + + Sub + + Mul + + Mul + + MulAssign, + Var: Into + + Copy + + Add + + Add + + Add + + Sub + + Sub + + Sub + + Mul + + Mul + + Mul, +{ +} diff --git a/core/src/stark/verifier.rs b/core/src/stark/verifier.rs index 31cd63c6db..46f52b1aff 100644 --- a/core/src/stark/verifier.rs +++ b/core/src/stark/verifier.rs @@ -216,6 +216,7 @@ impl> Verifier { is_transition, alpha, accumulator: SC::Challenge::zero(), + _marker: PhantomData, }; chip.eval(&mut folder); diff --git a/recursion/compiler/Cargo.toml b/recursion/compiler/Cargo.toml index 14d5ca3bab..2fbff1a255 100644 --- a/recursion/compiler/Cargo.toml +++ b/recursion/compiler/Cargo.toml @@ -11,5 +11,6 @@ sp1-recursion-core = { path = "../core" } [dev-dependencies] p3-baby-bear = { workspace = true } +p3-air = { workspace = true } sp1-core = { path = "../../core" } rand = "0.8.4" diff --git a/recursion/compiler/examples/verifier.rs b/recursion/compiler/examples/verifier.rs new file mode 100644 index 0000000000..36d340ae3a --- /dev/null +++ b/recursion/compiler/examples/verifier.rs @@ -0,0 +1,33 @@ +use p3_air::Air; + +use sp1_core::air::MachineAir; +use sp1_core::stark::{GenericVerifierConstraintFolder, MachineChip, StarkGenericConfig}; +use sp1_recursion_compiler::ir::{Ext, SymbolicExt}; + +#[allow(clippy::type_complexity)] +#[allow(dead_code)] +fn verify_constraints>( + chip: MachineChip, + folder: &mut GenericVerifierConstraintFolder< + SC::Val, + SC::Challenge, + Ext, + SymbolicExt, + >, +) where + A: for<'a> Air< + GenericVerifierConstraintFolder< + 'a, + SC::Val, + SC::Challenge, + Ext, + SymbolicExt, + >, + >, +{ + chip.eval(folder); +} + +fn main() { + println!("Hello, world!"); +} diff --git a/recursion/compiler/src/asm/compiler.rs b/recursion/compiler/src/asm/compiler.rs index 9c979e918e..edd8f25d1f 100644 --- a/recursion/compiler/src/asm/compiler.rs +++ b/recursion/compiler/src/asm/compiler.rs @@ -9,11 +9,12 @@ use alloc::vec::Vec; use p3_field::ExtensionField; use p3_field::PrimeField32; use sp1_recursion_core::runtime::Program; +use sp1_recursion_core::runtime::STACK_SIZE; use crate::asm::AsmInstruction; use crate::ir::Builder; use crate::ir::Usize; -use crate::ir::{Config, DslIR, Ext, Felt, Var}; +use crate::ir::{Config, DslIR, Ext, Felt, Ptr, Var}; use p3_field::Field; pub(crate) const ZERO: i32 = 0; @@ -63,6 +64,12 @@ impl Felt { } } +impl Ptr { + fn fp(&self) -> i32 { + self.address.fp() + } +} + impl Ext { pub fn fp(&self) -> i32 { -((self.0 as i32) * 3 + 8) @@ -79,6 +86,9 @@ impl> AsmCompiler { } pub fn build(&mut self, operations: Vec>>) { + // Set the heap pointer value according to stack size + let stack_size = F::from_canonical_usize(STACK_SIZE + 4); + self.push(AsmInstruction::IMM(HEAP_PTR, stack_size)); for op in operations { match op { DslIR::Imm(dst, src) => { @@ -432,22 +442,30 @@ impl> AsmCompiler { }; if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); } - _ => todo!(), + DslIR::Alloc(ptr, len) => { + self.alloc(ptr, len); + } + DslIR::LoadV(var, ptr) => self.push(AsmInstruction::LW(var.fp(), ptr.fp())), + DslIR::LoadF(var, ptr) => self.push(AsmInstruction::LW(var.fp(), ptr.fp())), + DslIR::LoadE(var, ptr) => self.push(AsmInstruction::LE(var.fp(), ptr.fp())), + DslIR::StoreV(ptr, var) => self.push(AsmInstruction::SW(ptr.fp(), var.fp())), + DslIR::StoreF(ptr, var) => self.push(AsmInstruction::SW(ptr.fp(), var.fp())), + DslIR::StoreE(ptr, var) => self.push(AsmInstruction::SE(ptr.fp(), var.fp())), } } } - pub fn alloc(&mut self, ptr: Var, len: Usize) { + pub fn alloc(&mut self, ptr: Ptr, len: Usize) { // Load the current heap ptr address to the stack value and advance the heap ptr. match len { Usize::Const(len) => { let len = F::from_canonical_usize(len); - self.push(AsmInstruction::IMM(ptr.fp(), len)); + self.push(AsmInstruction::ADDI(ptr.fp(), HEAP_PTR, F::zero())); self.push(AsmInstruction::ADDI(HEAP_PTR, HEAP_PTR, len)); } Usize::Var(len) => { - self.push(AsmInstruction::ADDI(ptr.fp(), len.fp(), F::zero())); - self.push(AsmInstruction::ADDI(HEAP_PTR, HEAP_PTR, F::one())); + self.push(AsmInstruction::ADDI(ptr.fp(), HEAP_PTR, F::zero())); + self.push(AsmInstruction::ADD(HEAP_PTR, HEAP_PTR, len.fp())); } } } diff --git a/recursion/compiler/src/asm/heap.rs b/recursion/compiler/src/asm/heap.rs deleted file mode 100644 index b2d977a17c..0000000000 --- a/recursion/compiler/src/asm/heap.rs +++ /dev/null @@ -1,4 +0,0 @@ -#[allow(dead_code)] -pub struct Ptr { - fp: i32, -} diff --git a/recursion/compiler/src/asm/instruction.rs b/recursion/compiler/src/asm/instruction.rs index e95e3ff15f..15faa68833 100644 --- a/recursion/compiler/src/asm/instruction.rs +++ b/recursion/compiler/src/asm/instruction.rs @@ -12,9 +12,9 @@ use super::ZERO; #[derive(Debug, Clone)] pub enum AsmInstruction { - /// Load work (src, dst) : load a value from the address stored at dest(fp) into src(fp). + /// Load work (dst, src) : load a value from the address stored at src(fp) into dstfp). LW(i32, i32), - /// Store word (src, dst) : store a value from src(fp) into the address stored at dest(fp). + /// Store word (dst, src) : store a value from src(fp) into the address stored at dest(fp). SW(i32, i32), // Get immediate (dst, value) : load a value into the dest(fp). IMM(i32, F), @@ -109,36 +109,15 @@ impl> AsmInstruction { let f_u32 = |x: F| [x, F::zero(), F::zero(), F::zero()]; let zero = [F::zero(), F::zero(), F::zero(), F::zero()]; match self { - AsmInstruction::LW(dst, src) => Instruction::new( - Opcode::LW, - i32_f(dst), - i32_f_arr(src), - zero, - false, - false, - false, - false, - ), - AsmInstruction::SW(dst, src) => Instruction::new( - Opcode::SW, - i32_f(dst), - i32_f_arr(src), - zero, - false, - false, - false, - false, - ), - AsmInstruction::IMM(dst, value) => Instruction::new( - Opcode::LW, - i32_f(dst), - f_u32(value), - zero, - true, - false, - false, - false, - ), + AsmInstruction::LW(dst, src) => { + Instruction::new(Opcode::LW, i32_f(dst), i32_f_arr(src), zero, false, false) + } + AsmInstruction::SW(dst, src) => { + Instruction::new(Opcode::SW, i32_f(dst), i32_f_arr(src), zero, false, false) + } + AsmInstruction::IMM(dst, value) => { + Instruction::new(Opcode::LW, i32_f(dst), f_u32(value), zero, true, false) + } AsmInstruction::ADD(dst, lhs, rhs) => Instruction::new( Opcode::ADD, i32_f(dst), @@ -146,8 +125,6 @@ impl> AsmInstruction { i32_f_arr(rhs), false, false, - false, - false, ), AsmInstruction::ADDI(dst, lhs, rhs) => Instruction::new( Opcode::ADD, @@ -155,9 +132,7 @@ impl> AsmInstruction { i32_f_arr(lhs), f_u32(rhs), false, - false, true, - false, ), AsmInstruction::SUB(dst, lhs, rhs) => Instruction::new( Opcode::SUB, @@ -166,8 +141,6 @@ impl> AsmInstruction { i32_f_arr(rhs), false, false, - false, - false, ), AsmInstruction::SUBI(dst, lhs, rhs) => Instruction::new( Opcode::SUB, @@ -175,9 +148,7 @@ impl> AsmInstruction { i32_f_arr(lhs), f_u32(rhs), false, - false, true, - false, ), AsmInstruction::SUBIN(dst, lhs, rhs) => Instruction::new( Opcode::SUB, @@ -186,8 +157,6 @@ impl> AsmInstruction { i32_f_arr(rhs), true, false, - false, - false, ), AsmInstruction::MUL(dst, lhs, rhs) => Instruction::new( Opcode::MUL, @@ -196,8 +165,6 @@ impl> AsmInstruction { i32_f_arr(rhs), false, false, - false, - false, ), AsmInstruction::MULI(dst, lhs, rhs) => Instruction::new( Opcode::MUL, @@ -205,9 +172,7 @@ impl> AsmInstruction { i32_f_arr(lhs), f_u32(rhs), false, - false, true, - false, ), AsmInstruction::DIV(dst, lhs, rhs) => Instruction::new( Opcode::DIV, @@ -216,8 +181,6 @@ impl> AsmInstruction { i32_f_arr(rhs), false, false, - false, - false, ), AsmInstruction::DIVI(dst, lhs, rhs) => Instruction::new( Opcode::DIV, @@ -225,9 +188,7 @@ impl> AsmInstruction { i32_f_arr(lhs), f_u32(rhs), false, - false, true, - false, ), AsmInstruction::DIVIN(dst, lhs, rhs) => Instruction::new( Opcode::DIV, @@ -236,38 +197,20 @@ impl> AsmInstruction { i32_f_arr(rhs), true, false, - false, - false, - ), - AsmInstruction::LE(dst, src) => Instruction::new( - Opcode::LW, - i32_f(dst), - i32_f_arr(src), - zero, - false, - false, - false, - false, - ), - AsmInstruction::SE(dst, src) => Instruction::new( - Opcode::SW, - i32_f(dst), - i32_f_arr(src), - zero, - false, - false, - false, - false, ), + AsmInstruction::LE(dst, src) => { + Instruction::new(Opcode::LE, i32_f(dst), i32_f_arr(src), zero, false, false) + } + AsmInstruction::SE(dst, src) => { + Instruction::new(Opcode::SE, i32_f(dst), i32_f_arr(src), zero, false, false) + } AsmInstruction::EIMM(dst, value) => Instruction::new( - Opcode::LW, + Opcode::LE, i32_f(dst), value.as_base_slice().try_into().unwrap(), zero, - false, true, false, - false, ), AsmInstruction::EADD(dst, lhs, rhs) => Instruction::new( Opcode::EADD, @@ -276,8 +219,6 @@ impl> AsmInstruction { i32_f_arr(rhs), false, false, - false, - false, ), AsmInstruction::EADDI(dst, lhs, rhs) => Instruction::new( Opcode::EADD, @@ -285,8 +226,6 @@ impl> AsmInstruction { i32_f_arr(lhs), rhs.as_base_slice().try_into().unwrap(), false, - false, - false, true, ), AsmInstruction::ESUB(dst, lhs, rhs) => Instruction::new( @@ -296,8 +235,6 @@ impl> AsmInstruction { i32_f_arr(rhs), false, false, - false, - false, ), AsmInstruction::ESUBI(dst, lhs, rhs) => Instruction::new( Opcode::ESUB, @@ -305,8 +242,6 @@ impl> AsmInstruction { i32_f_arr(lhs), rhs.as_base_slice().try_into().unwrap(), false, - false, - false, true, ), AsmInstruction::ESUBIN(dst, lhs, rhs) => Instruction::new( @@ -314,10 +249,8 @@ impl> AsmInstruction { i32_f(dst), lhs.as_base_slice().try_into().unwrap(), i32_f_arr(rhs), - false, true, false, - false, ), AsmInstruction::EMUL(dst, lhs, rhs) => Instruction::new( Opcode::EMUL, @@ -326,8 +259,6 @@ impl> AsmInstruction { i32_f_arr(rhs), false, false, - false, - false, ), AsmInstruction::EMULI(dst, lhs, rhs) => Instruction::new( Opcode::EMUL, @@ -335,8 +266,6 @@ impl> AsmInstruction { i32_f_arr(lhs), rhs.as_base_slice().try_into().unwrap(), false, - false, - false, true, ), AsmInstruction::EDIV(dst, lhs, rhs) => Instruction::new( @@ -346,8 +275,6 @@ impl> AsmInstruction { i32_f_arr(rhs), false, false, - false, - false, ), AsmInstruction::EDIVI(dst, lhs, rhs) => Instruction::new( Opcode::EDIV, @@ -355,8 +282,6 @@ impl> AsmInstruction { i32_f_arr(lhs), rhs.as_base_slice().try_into().unwrap(), false, - false, - false, true, ), AsmInstruction::EDIVIN(dst, lhs, rhs) => Instruction::new( @@ -364,10 +289,8 @@ impl> AsmInstruction { i32_f(dst), lhs.as_base_slice().try_into().unwrap(), i32_f_arr(rhs), - false, true, false, - false, ), AsmInstruction::BEQ(label, lhs, rhs) => { let offset = @@ -378,8 +301,6 @@ impl> AsmInstruction { i32_f_arr(rhs), f_u32(offset), false, - false, - false, true, ) } @@ -392,9 +313,7 @@ impl> AsmInstruction { f_u32(rhs), f_u32(offset), true, - false, true, - false, ) } AsmInstruction::BNE(label, lhs, rhs) => { @@ -406,9 +325,7 @@ impl> AsmInstruction { i32_f_arr(rhs), f_u32(offset), false, - false, true, - false, ) } AsmInstruction::BNEI(label, lhs, rhs) => { @@ -420,9 +337,7 @@ impl> AsmInstruction { f_u32(rhs), f_u32(offset), true, - false, true, - false, ) } AsmInstruction::EBNE(label, lhs, rhs) => { @@ -434,9 +349,7 @@ impl> AsmInstruction { i32_f_arr(rhs), f_u32(offset), false, - false, true, - false, ) } AsmInstruction::EBNEI(label, lhs, rhs) => { @@ -447,10 +360,8 @@ impl> AsmInstruction { i32_f(lhs), rhs.as_base_slice().try_into().unwrap(), f_u32(offset), - false, true, true, - false, ) } AsmInstruction::EBEQ(label, lhs, rhs) => { @@ -462,9 +373,7 @@ impl> AsmInstruction { i32_f_arr(rhs), f_u32(offset), false, - false, true, - false, ) } AsmInstruction::EBEQI(label, lhs, rhs) => { @@ -475,10 +384,8 @@ impl> AsmInstruction { i32_f(lhs), rhs.as_base_slice().try_into().unwrap(), f_u32(offset), - false, true, true, - false, ) } AsmInstruction::JAL(dst, label, offset) => { @@ -490,9 +397,7 @@ impl> AsmInstruction { f_u32(pc_offset), f_u32(offset), false, - false, true, - false, ) } AsmInstruction::JALR(dst, label, offset) => Instruction::new( @@ -502,19 +407,10 @@ impl> AsmInstruction { i32_f_arr(offset), false, false, - false, - false, - ), - AsmInstruction::TRAP => Instruction::new( - Opcode::TRAP, - F::zero(), - zero, - zero, - false, - false, - false, - false, ), + AsmInstruction::TRAP => { + Instruction::new(Opcode::TRAP, F::zero(), zero, zero, false, false) + } } } diff --git a/recursion/compiler/src/asm/mod.rs b/recursion/compiler/src/asm/mod.rs index 07253e689a..9f17670170 100644 --- a/recursion/compiler/src/asm/mod.rs +++ b/recursion/compiler/src/asm/mod.rs @@ -1,9 +1,7 @@ mod code; mod compiler; -mod heap; mod instruction; pub use code::*; pub use compiler::*; -pub use heap::*; pub use instruction::*; diff --git a/recursion/compiler/src/ir/collections.rs b/recursion/compiler/src/ir/collections.rs index f16f953306..f35c6c419b 100644 --- a/recursion/compiler/src/ir/collections.rs +++ b/recursion/compiler/src/ir/collections.rs @@ -1,78 +1,77 @@ -use std::marker::PhantomData; +use super::{Builder, Config, MemVariable, Ptr, Usize, Var}; +use p3_field::AbstractField; -use super::{Builder, Config, MemVariable, Ptr, Usize}; - -pub enum Slice { +pub enum Array { Fixed(Vec), - Vec(Vector), -} - -#[allow(dead_code)] -pub struct Vector { - ptr: Ptr, - len: Usize, - cap: Usize, - _marker: PhantomData, + Dyn(Ptr, Usize), } -impl> Vector { +impl> Array { pub fn len(&self) -> Usize { - self.len + match self { + Self::Fixed(vec) => Usize::from(vec.len()), + Self::Dyn(_, len) => *len, + } } } impl Builder { - pub fn vec, I: Into>>(&mut self, cap: I) -> Vector { - let cap = cap.into(); - Vector { - ptr: self.alloc(cap, V::size_of()), - len: Usize::from(0), - cap, - _marker: PhantomData, + /// Initialize an array of fixed length `len`. The entries will be uninitialized. + pub fn array, I: Into>>(&mut self, len: I) -> Array { + let len = len.into(); + match len { + Usize::Const(len) => Array::Fixed(vec![self.uninit::(); len]), + Usize::Var(len) => { + let len: Var = self.eval(len * C::N::from_canonical_usize(V::size_of())); + let len = Usize::Var(len); + let ptr = self.alloc(len); + Array::Dyn(ptr, len) + } } } pub fn get, I: Into>>( &mut self, - slice: &Slice, + slice: &Array, index: I, ) -> V { let index = index.into(); match slice { - Slice::Fixed(slice) => { + Array::Fixed(slice) => { if let Usize::Const(idx) = index { slice[idx] } else { panic!("Cannot index into a fixed slice with a variable size") } } - Slice::Vec(slice) => { + Array::Dyn(ptr, _) => { let var = self.uninit(); - self.load(var, slice.ptr, index); + self.load(var, *ptr + index * V::size_of()); var } } } - pub fn set, I: Into>>( + pub fn set, I: Into>, Expr: Into>( &mut self, - slice: &mut Slice, + slice: &mut Array, index: I, - value: V, + value: Expr, ) { let index = index.into(); match slice { - Slice::Fixed(slice) => { + Array::Fixed(slice) => { if let Usize::Const(idx) = index { - slice[idx] = value; + self.assign(slice[idx], value); } else { panic!("Cannot index into a fixed slice with a variable size") } } - Slice::Vec(slice) => { - self.store(slice.ptr, index, value); + Array::Dyn(ptr, _) => { + let value: V = self.eval(value); + self.store(*ptr + index * V::size_of(), value); } } } diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index 72b40be446..19485a94c6 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -71,18 +71,18 @@ pub enum DslIR { AssertEqEI(Ext, C::EF), AssertNeEI(Ext, C::EF), // Memory instructions. - /// Allocate (ptr, len, size) allocated a memory slice of length `len * size` - Alloc(Ptr, Usize, usize), - /// Load variable (var, ptr, offset) - LoadV(Var, Ptr, Usize), - /// Load field element (var, ptr, offset) - LoadF(Felt, Ptr, Usize), + /// Allocate (ptr, len) a memory slice of length len + Alloc(Ptr, Usize), + /// Load variable (var, ptr) + LoadV(Var, Ptr), + /// Load field element (var, ptr) + LoadF(Felt, Ptr), /// Load extension field - LoadE(Ext, Ptr, Usize), - /// Store variable - StoreV(Var, Ptr, Usize), - /// Store field element - StoreF(Felt, Ptr, Usize), - /// Store extension field - StoreE(Ext, Ptr, Usize), + LoadE(Ext, Ptr), + /// Store variable at address + StoreV(Ptr, Var), + /// Store field element at adress + StoreF(Ptr, Felt), + /// Store extension field at adress + StoreE(Ptr, Ext), } diff --git a/recursion/compiler/src/ir/ptr.rs b/recursion/compiler/src/ir/ptr.rs index 9b3af72adc..4500e391fa 100644 --- a/recursion/compiler/src/ir/ptr.rs +++ b/recursion/compiler/src/ir/ptr.rs @@ -1,38 +1,37 @@ +use p3_field::Field; + use super::{Builder, Config, DslIR, MemVariable, SymbolicVar, Usize, Var, Variable}; +use core::ops::{Add, Sub}; #[derive(Debug, Clone, Copy)] pub struct Ptr { - address: Var, + pub address: Var, +} + +pub struct SymbolicPtr { + pub address: SymbolicVar, } impl Builder { - pub(crate) fn alloc(&mut self, len: Usize, size: usize) -> Ptr { + pub(crate) fn alloc(&mut self, len: Usize) -> Ptr { let ptr = Ptr::uninit(self); - self.push(DslIR::Alloc(ptr, len, size)); + self.push(DslIR::Alloc(ptr, len)); ptr } - pub fn load, I: Into>>( - &mut self, - var: V, - ptr: Ptr, - offset: I, - ) { - var.load(ptr, offset.into(), self); + pub fn load, P: Into>>(&mut self, var: V, ptr: P) { + let load_ptr = self.eval(ptr); + var.load(load_ptr, self); } - pub fn store, I: Into>>( - &mut self, - ptr: Ptr, - offset: I, - value: V, - ) { - value.store(ptr, offset.into(), self); + pub fn store, P: Into>>(&mut self, ptr: P, value: V) { + let store_ptr = self.eval(ptr); + value.store(store_ptr, self); } } impl Variable for Ptr { - type Expression = SymbolicVar; + type Expression = SymbolicPtr; fn uninit(builder: &mut Builder) -> Self { Ptr { @@ -41,7 +40,7 @@ impl Variable for Ptr { } fn assign(&self, src: Self::Expression, builder: &mut Builder) { - self.address.assign(src, builder); + self.address.assign(src.address, builder); } fn assert_eq( @@ -49,7 +48,7 @@ impl Variable for Ptr { rhs: impl Into, builder: &mut Builder, ) { - Var::assert_eq(lhs, rhs, builder); + Var::assert_eq(lhs.into().address, rhs.into().address, builder); } fn assert_ne( @@ -57,6 +56,166 @@ impl Variable for Ptr { rhs: impl Into, builder: &mut Builder, ) { - Var::assert_ne(lhs, rhs, builder); + Var::assert_ne(lhs.into().address, rhs.into().address, builder); + } +} + +impl Add for Ptr { + type Output = SymbolicPtr; + + fn add(self, rhs: Self) -> Self::Output { + SymbolicPtr { + address: self.address + rhs.address, + } + } +} + +impl Sub for Ptr { + type Output = SymbolicPtr; + + fn sub(self, rhs: Self) -> Self::Output { + SymbolicPtr { + address: self.address - rhs.address, + } + } +} + +impl Add for SymbolicPtr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self { + address: self.address + rhs.address, + } + } +} + +impl Sub for SymbolicPtr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Self { + address: self.address - rhs.address, + } + } +} + +impl Add> for SymbolicPtr { + type Output = Self; + + fn add(self, rhs: Ptr) -> Self { + Self { + address: self.address + rhs.address, + } + } +} + +impl Sub> for SymbolicPtr { + type Output = Self; + + fn sub(self, rhs: Ptr) -> Self { + Self { + address: self.address - rhs.address, + } + } +} + +impl Add> for Ptr { + type Output = SymbolicPtr; + + fn add(self, rhs: SymbolicPtr) -> SymbolicPtr { + SymbolicPtr { + address: self.address + rhs.address, + } + } +} + +impl Add> for Ptr { + type Output = SymbolicPtr; + + fn add(self, rhs: SymbolicVar) -> SymbolicPtr { + SymbolicPtr { + address: self.address + rhs, + } + } +} + +impl Sub> for Ptr { + type Output = SymbolicPtr; + + fn sub(self, rhs: SymbolicVar) -> SymbolicPtr { + SymbolicPtr { + address: self.address - rhs, + } + } +} + +impl Sub> for Ptr { + type Output = SymbolicPtr; + + fn sub(self, rhs: SymbolicPtr) -> SymbolicPtr { + SymbolicPtr { + address: self.address - rhs.address, + } + } +} + +impl Add> for Ptr { + type Output = SymbolicPtr; + + fn add(self, rhs: Usize) -> SymbolicPtr { + match rhs { + Usize::Const(rhs) => SymbolicPtr { + address: self.address + N::from_canonical_usize(rhs), + }, + Usize::Var(rhs) => SymbolicPtr { + address: self.address + rhs, + }, + } + } +} + +impl Add> for SymbolicPtr { + type Output = SymbolicPtr; + + fn add(self, rhs: Usize) -> SymbolicPtr { + match rhs { + Usize::Const(rhs) => SymbolicPtr { + address: self.address + N::from_canonical_usize(rhs), + }, + Usize::Var(rhs) => SymbolicPtr { + address: self.address + rhs, + }, + } + } +} + +impl Sub> for Ptr { + type Output = SymbolicPtr; + + fn sub(self, rhs: Usize) -> SymbolicPtr { + match rhs { + Usize::Const(rhs) => SymbolicPtr { + address: self.address - N::from_canonical_usize(rhs), + }, + Usize::Var(rhs) => SymbolicPtr { + address: self.address - rhs, + }, + } + } +} + +impl Sub> for SymbolicPtr { + type Output = SymbolicPtr; + + fn sub(self, rhs: Usize) -> SymbolicPtr { + match rhs { + Usize::Const(rhs) => SymbolicPtr { + address: self.address - N::from_canonical_usize(rhs), + }, + Usize::Var(rhs) => SymbolicPtr { + address: self.address - rhs, + }, + } } } diff --git a/recursion/compiler/src/ir/symbolic.rs b/recursion/compiler/src/ir/symbolic.rs index 2546bd93a3..26794d8a79 100644 --- a/recursion/compiler/src/ir/symbolic.rs +++ b/recursion/compiler/src/ir/symbolic.rs @@ -1,8 +1,14 @@ +use super::Usize; use super::{Ext, Felt, Var}; use alloc::rc::Rc; +use core::any::Any; use core::ops::{Add, Div, Mul, Neg, Sub}; -use p3_field::ExtensionField; use p3_field::Field; +use p3_field::{AbstractField, ExtensionField}; +use std::any::TypeId; +use std::iter::{Product, Sum}; +use std::mem; +use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign}; #[derive(Debug, Clone)] pub enum SymbolicVar { @@ -37,6 +43,192 @@ pub enum SymbolicExt { Neg(Rc>), } +#[derive(Debug, Clone)] +pub enum ExtOperand { + Base(F), + Const(EF), + Felt(Felt), + Ext(Ext), + SymFelt(SymbolicFelt), + Sym(SymbolicExt), +} + +pub trait ExtConst> { + fn cons(self) -> SymbolicExt; +} + +impl> ExtConst for EF { + fn cons(self) -> SymbolicExt { + SymbolicExt::Const(self) + } +} + +pub trait ExtensionOperand> { + fn to_operand(self) -> ExtOperand; +} + +impl AbstractField for SymbolicVar { + type F = N; + + fn zero() -> Self { + SymbolicVar::Const(N::zero()) + } + + fn one() -> Self { + SymbolicVar::Const(N::one()) + } + + fn two() -> Self { + SymbolicVar::Const(N::two()) + } + + fn neg_one() -> Self { + SymbolicVar::Const(N::neg_one()) + } + + fn from_f(f: Self::F) -> Self { + SymbolicVar::Const(f) + } + fn from_bool(b: bool) -> Self { + SymbolicVar::Const(N::from_bool(b)) + } + fn from_canonical_u8(n: u8) -> Self { + SymbolicVar::Const(N::from_canonical_u8(n)) + } + fn from_canonical_u16(n: u16) -> Self { + SymbolicVar::Const(N::from_canonical_u16(n)) + } + fn from_canonical_u32(n: u32) -> Self { + SymbolicVar::Const(N::from_canonical_u32(n)) + } + fn from_canonical_u64(n: u64) -> Self { + SymbolicVar::Const(N::from_canonical_u64(n)) + } + fn from_canonical_usize(n: usize) -> Self { + SymbolicVar::Const(N::from_canonical_usize(n)) + } + + fn from_wrapped_u32(n: u32) -> Self { + SymbolicVar::Const(N::from_wrapped_u32(n)) + } + fn from_wrapped_u64(n: u64) -> Self { + SymbolicVar::Const(N::from_wrapped_u64(n)) + } + + /// A generator of this field's entire multiplicative group. + fn generator() -> Self { + SymbolicVar::Const(N::generator()) + } +} + +impl AbstractField for SymbolicFelt { + type F = F; + + fn zero() -> Self { + SymbolicFelt::Const(F::zero()) + } + + fn one() -> Self { + SymbolicFelt::Const(F::one()) + } + + fn two() -> Self { + SymbolicFelt::Const(F::two()) + } + + fn neg_one() -> Self { + SymbolicFelt::Const(F::neg_one()) + } + + fn from_f(f: Self::F) -> Self { + SymbolicFelt::Const(f) + } + fn from_bool(b: bool) -> Self { + SymbolicFelt::Const(F::from_bool(b)) + } + fn from_canonical_u8(n: u8) -> Self { + SymbolicFelt::Const(F::from_canonical_u8(n)) + } + fn from_canonical_u16(n: u16) -> Self { + SymbolicFelt::Const(F::from_canonical_u16(n)) + } + fn from_canonical_u32(n: u32) -> Self { + SymbolicFelt::Const(F::from_canonical_u32(n)) + } + fn from_canonical_u64(n: u64) -> Self { + SymbolicFelt::Const(F::from_canonical_u64(n)) + } + fn from_canonical_usize(n: usize) -> Self { + SymbolicFelt::Const(F::from_canonical_usize(n)) + } + + fn from_wrapped_u32(n: u32) -> Self { + SymbolicFelt::Const(F::from_wrapped_u32(n)) + } + fn from_wrapped_u64(n: u64) -> Self { + SymbolicFelt::Const(F::from_wrapped_u64(n)) + } + + /// A generator of this field's entire multiplicative group. + fn generator() -> Self { + SymbolicFelt::Const(F::generator()) + } +} + +impl> AbstractField for SymbolicExt { + type F = EF; + + fn zero() -> Self { + SymbolicExt::Const(EF::zero()) + } + + fn one() -> Self { + SymbolicExt::Const(EF::one()) + } + + fn two() -> Self { + SymbolicExt::Const(EF::two()) + } + + fn neg_one() -> Self { + SymbolicExt::Const(EF::neg_one()) + } + + fn from_f(f: Self::F) -> Self { + SymbolicExt::Const(f) + } + fn from_bool(b: bool) -> Self { + SymbolicExt::Const(EF::from_bool(b)) + } + fn from_canonical_u8(n: u8) -> Self { + SymbolicExt::Const(EF::from_canonical_u8(n)) + } + fn from_canonical_u16(n: u16) -> Self { + SymbolicExt::Const(EF::from_canonical_u16(n)) + } + fn from_canonical_u32(n: u32) -> Self { + SymbolicExt::Const(EF::from_canonical_u32(n)) + } + fn from_canonical_u64(n: u64) -> Self { + SymbolicExt::Const(EF::from_canonical_u64(n)) + } + fn from_canonical_usize(n: usize) -> Self { + SymbolicExt::Const(EF::from_canonical_usize(n)) + } + + fn from_wrapped_u32(n: u32) -> Self { + SymbolicExt::Const(EF::from_wrapped_u32(n)) + } + fn from_wrapped_u64(n: u64) -> Self { + SymbolicExt::Const(EF::from_wrapped_u64(n)) + } + + /// A generator of this field's entire multiplicative group. + fn generator() -> Self { + SymbolicExt::Const(EF::generator()) + } +} + // Implement all conversions from constants N, F, EF, to the corresponding symbolic types impl From for SymbolicVar { @@ -51,9 +243,9 @@ impl From for SymbolicFelt { } } -impl From for SymbolicExt { - fn from(ef: EF) -> Self { - SymbolicExt::Const(ef) +impl> From for SymbolicExt { + fn from(f: F) -> Self { + SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f))) } } @@ -71,7 +263,7 @@ impl From> for SymbolicFelt { } } -impl From> for SymbolicExt { +impl> From> for SymbolicExt { fn from(e: Ext) -> Self { SymbolicExt::Val(e) } @@ -95,11 +287,29 @@ impl Add for SymbolicFelt { } } -impl Add for SymbolicExt { +impl, E: ExtensionOperand> Add for SymbolicExt { type Output = Self; - fn add(self, rhs: Self) -> Self::Output { - SymbolicExt::Add(Rc::new(self), Rc::new(rhs)) + fn add(self, rhs: E) -> Self::Output { + let rhs = rhs.to_operand(); + match rhs { + ExtOperand::Base(f) => SymbolicExt::Add( + Rc::new(self), + Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f)))), + ), + ExtOperand::Const(ef) => { + SymbolicExt::Add(Rc::new(self), Rc::new(SymbolicExt::Const(ef))) + } + ExtOperand::Felt(f) => SymbolicExt::Add( + Rc::new(self), + Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Val(f)))), + ), + ExtOperand::Ext(e) => SymbolicExt::Add(Rc::new(self), Rc::new(SymbolicExt::Val(e))), + ExtOperand::SymFelt(f) => { + SymbolicExt::Add(Rc::new(self), Rc::new(SymbolicExt::Base(Rc::new(f)))) + } + ExtOperand::Sym(e) => SymbolicExt::Add(Rc::new(self), Rc::new(e)), + } } } @@ -119,11 +329,29 @@ impl Mul for SymbolicFelt { } } -impl Mul for SymbolicExt { +impl, E: Any> Mul for SymbolicExt { type Output = Self; - fn mul(self, rhs: Self) -> Self::Output { - SymbolicExt::Mul(Rc::new(self), Rc::new(rhs)) + fn mul(self, rhs: E) -> Self::Output { + let rhs = rhs.to_operand(); + match rhs { + ExtOperand::Base(f) => SymbolicExt::Mul( + Rc::new(self), + Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f)))), + ), + ExtOperand::Const(ef) => { + SymbolicExt::Mul(Rc::new(self), Rc::new(SymbolicExt::Const(ef))) + } + ExtOperand::Felt(f) => SymbolicExt::Mul( + Rc::new(self), + Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Val(f)))), + ), + ExtOperand::Ext(e) => SymbolicExt::Mul(Rc::new(self), Rc::new(SymbolicExt::Val(e))), + ExtOperand::SymFelt(f) => { + SymbolicExt::Mul(Rc::new(self), Rc::new(SymbolicExt::Base(Rc::new(f)))) + } + ExtOperand::Sym(e) => SymbolicExt::Mul(Rc::new(self), Rc::new(e)), + } } } @@ -143,11 +371,29 @@ impl Sub for SymbolicFelt { } } -impl Sub for SymbolicExt { +impl, E: Any> Sub for SymbolicExt { type Output = Self; - fn sub(self, rhs: Self) -> Self::Output { - SymbolicExt::Sub(Rc::new(self), Rc::new(rhs)) + fn sub(self, rhs: E) -> Self::Output { + let rhs = rhs.to_operand(); + match rhs { + ExtOperand::Base(f) => SymbolicExt::Sub( + Rc::new(self), + Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f)))), + ), + ExtOperand::Const(ef) => { + SymbolicExt::Sub(Rc::new(self), Rc::new(SymbolicExt::Const(ef))) + } + ExtOperand::Felt(f) => SymbolicExt::Sub( + Rc::new(self), + Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Val(f)))), + ), + ExtOperand::Ext(e) => SymbolicExt::Sub(Rc::new(self), Rc::new(SymbolicExt::Val(e))), + ExtOperand::SymFelt(f) => { + SymbolicExt::Sub(Rc::new(self), Rc::new(SymbolicExt::Base(Rc::new(f)))) + } + ExtOperand::Sym(e) => SymbolicExt::Sub(Rc::new(self), Rc::new(e)), + } } } @@ -159,11 +405,29 @@ impl Div for SymbolicFelt { } } -impl Div for SymbolicExt { +impl, E: Any> Div for SymbolicExt { type Output = Self; - fn div(self, rhs: Self) -> Self::Output { - SymbolicExt::Div(Rc::new(self), Rc::new(rhs)) + fn div(self, rhs: E) -> Self::Output { + let rhs = rhs.to_operand(); + match rhs { + ExtOperand::Base(f) => SymbolicExt::Div( + Rc::new(self), + Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Const(f)))), + ), + ExtOperand::Const(ef) => { + SymbolicExt::Div(Rc::new(self), Rc::new(SymbolicExt::Const(ef))) + } + ExtOperand::Felt(f) => SymbolicExt::Div( + Rc::new(self), + Rc::new(SymbolicExt::Base(Rc::new(SymbolicFelt::Val(f)))), + ), + ExtOperand::Ext(e) => SymbolicExt::Div(Rc::new(self), Rc::new(SymbolicExt::Val(e))), + ExtOperand::SymFelt(f) => { + SymbolicExt::Div(Rc::new(self), Rc::new(SymbolicExt::Base(Rc::new(f)))) + } + ExtOperand::Sym(e) => SymbolicExt::Div(Rc::new(self), Rc::new(e)), + } } } @@ -183,7 +447,7 @@ impl Neg for SymbolicFelt { } } -impl Neg for SymbolicExt { +impl> Neg for SymbolicExt { type Output = Self; fn neg(self) -> Self::Output { @@ -209,14 +473,6 @@ impl Add for SymbolicFelt { } } -impl Add for SymbolicExt { - type Output = Self; - - fn add(self, rhs: EF) -> Self::Output { - SymbolicExt::Add(Rc::new(self), Rc::new(SymbolicExt::Const(rhs))) - } -} - impl Mul for SymbolicVar { type Output = Self; @@ -233,14 +489,6 @@ impl Mul for SymbolicFelt { } } -impl Mul for SymbolicExt { - type Output = Self; - - fn mul(self, rhs: EF) -> Self::Output { - SymbolicExt::Mul(Rc::new(self), Rc::new(SymbolicExt::Const(rhs))) - } -} - impl Sub for SymbolicVar { type Output = Self; @@ -257,30 +505,6 @@ impl Sub for SymbolicFelt { } } -impl Sub for SymbolicExt { - type Output = Self; - - fn sub(self, rhs: EF) -> Self::Output { - SymbolicExt::Sub(Rc::new(self), Rc::new(SymbolicExt::Const(rhs))) - } -} - -impl Div for SymbolicFelt { - type Output = Self; - - fn div(self, rhs: F) -> Self::Output { - SymbolicFelt::Div(Rc::new(self), Rc::new(SymbolicFelt::Const(rhs))) - } -} - -impl Div for SymbolicExt { - type Output = Self; - - fn div(self, rhs: EF) -> Self::Output { - SymbolicExt::Div(Rc::new(self), Rc::new(SymbolicExt::Const(rhs))) - } -} - // Implement all operations between SymbolicVar, SymbolicFelt, SymbolicExt, and Var, // Felt, Ext. @@ -300,14 +524,6 @@ impl Add> for SymbolicFelt { } } -impl Add> for SymbolicExt { - type Output = SymbolicExt; - - fn add(self, rhs: Ext) -> Self::Output { - self + SymbolicExt::from(rhs) - } -} - impl Mul> for SymbolicVar { type Output = SymbolicVar; @@ -324,14 +540,6 @@ impl Mul> for SymbolicFelt { } } -impl Mul> for SymbolicExt { - type Output = SymbolicExt; - - fn mul(self, rhs: Ext) -> Self::Output { - self * SymbolicExt::from(rhs) - } -} - impl Sub> for SymbolicVar { type Output = SymbolicVar; @@ -348,14 +556,6 @@ impl Sub> for SymbolicFelt { } } -impl Sub> for SymbolicExt { - type Output = SymbolicExt; - - fn sub(self, rhs: Ext) -> Self::Output { - self - SymbolicExt::from(rhs) - } -} - impl Div> for Felt { type Output = SymbolicFelt; @@ -364,14 +564,6 @@ impl Div> for Felt { } } -impl Div> for Ext { - type Output = SymbolicExt; - - fn div(self, rhs: SymbolicExt) -> Self::Output { - SymbolicExt::::from(self) / rhs - } -} - // Implement operations between constants N, F, EF, and Var, Felt, Ext. impl Add for Var { @@ -406,22 +598,6 @@ impl Add for Felt { } } -impl> Add for Ext { - type Output = SymbolicExt; - - fn add(self, rhs: Self) -> Self::Output { - SymbolicExt::::from(self) + rhs - } -} - -impl> Add for Ext { - type Output = SymbolicExt; - - fn add(self, rhs: EF) -> Self::Output { - SymbolicExt::from(self) + rhs - } -} - impl Mul for Var { type Output = SymbolicVar; @@ -454,22 +630,6 @@ impl Mul for Felt { } } -impl> Mul for Ext { - type Output = SymbolicExt; - - fn mul(self, rhs: Self) -> Self::Output { - SymbolicExt::::from(self) * rhs - } -} - -impl Mul for Ext { - type Output = SymbolicExt; - - fn mul(self, rhs: EF) -> Self::Output { - SymbolicExt::from(self) * rhs - } -} - impl Sub for Var { type Output = SymbolicVar; @@ -502,47 +662,43 @@ impl Sub for Felt { } } -impl> Sub for Ext { +impl, E: Any> Add for Ext { type Output = SymbolicExt; - fn sub(self, rhs: Self) -> Self::Output { - SymbolicExt::::from(self) - rhs + fn add(self, rhs: E) -> Self::Output { + let rhs: ExtOperand = rhs.to_operand(); + SymbolicExt::::from(self) + rhs } } -impl Sub for Ext { +impl, E: Any> Mul for Ext { type Output = SymbolicExt; - fn sub(self, rhs: EF) -> Self::Output { - SymbolicExt::from(self) - rhs + fn mul(self, rhs: E) -> Self::Output { + let rhs: ExtOperand = rhs.to_operand(); + SymbolicExt::::from(self) * rhs } } -impl Sub> for Ext { +impl, E: Any> Sub for Ext { type Output = SymbolicExt; - fn sub(self, rhs: SymbolicExt) -> Self::Output { + fn sub(self, rhs: E) -> Self::Output { + let rhs: ExtOperand = rhs.to_operand(); SymbolicExt::::from(self) - rhs } } -impl Add> for Ext { - type Output = SymbolicExt; - - fn add(self, rhs: SymbolicExt) -> Self::Output { - SymbolicExt::::from(self) + rhs - } -} - -impl Mul> for Ext { +impl, E: Any> Div for Ext { type Output = SymbolicExt; - fn mul(self, rhs: SymbolicExt) -> Self::Output { - SymbolicExt::::from(self) * rhs + fn div(self, rhs: E) -> Self::Output { + let rhs: ExtOperand = rhs.to_operand(); + SymbolicExt::::from(self) / rhs } } -impl Add> for Felt { +impl> Add> for Felt { type Output = SymbolicExt; fn add(self, rhs: SymbolicExt) -> Self::Output { @@ -550,7 +706,7 @@ impl Add> for Felt { } } -impl Mul> for Felt { +impl> Mul> for Felt { type Output = SymbolicExt; fn mul(self, rhs: SymbolicExt) -> Self::Output { @@ -558,7 +714,7 @@ impl Mul> for Felt { } } -impl Sub> for Felt { +impl> Sub> for Felt { type Output = SymbolicExt; fn sub(self, rhs: SymbolicExt) -> Self::Output { @@ -566,7 +722,7 @@ impl Sub> for Felt { } } -impl Div> for Felt { +impl> Div> for Felt { type Output = SymbolicExt; fn div(self, rhs: SymbolicExt) -> Self::Output { @@ -582,17 +738,6 @@ impl Div for Felt { } } -impl Div for Ext { - type Output = SymbolicExt; - - fn div(self, rhs: Self) -> Self::Output { - SymbolicExt::Div( - Rc::new(SymbolicExt::from(self)), - Rc::new(SymbolicExt::from(rhs)), - ) - } -} - impl Div> for SymbolicFelt { type Output = SymbolicFelt; @@ -600,3 +745,183 @@ impl Div> for SymbolicFelt { SymbolicFelt::Div(Rc::new(self), Rc::new(SymbolicFelt::Val(rhs))) } } + +impl Sub> for Var { + type Output = SymbolicVar; + + fn sub(self, rhs: SymbolicVar) -> Self::Output { + SymbolicVar::::from(self) - rhs + } +} + +impl Add> for Var { + type Output = SymbolicVar; + + fn add(self, rhs: SymbolicVar) -> Self::Output { + SymbolicVar::::from(self) + rhs + } +} + +impl Mul for Usize { + type Output = SymbolicVar; + + fn mul(self, rhs: usize) -> Self::Output { + match self { + Usize::Const(n) => SymbolicVar::Const(N::from_canonical_usize(n * rhs)), + Usize::Var(n) => SymbolicVar::Val(n) * N::from_canonical_usize(rhs), + } + } +} + +impl Product for SymbolicVar { + fn product>(iter: I) -> Self { + iter.fold(SymbolicVar::one(), |acc, x| acc * x) + } +} + +impl Sum for SymbolicVar { + fn sum>(iter: I) -> Self { + iter.fold(SymbolicVar::zero(), |acc, x| acc + x) + } +} + +impl AddAssign for SymbolicVar { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs; + } +} + +impl SubAssign for SymbolicVar { + fn sub_assign(&mut self, rhs: Self) { + *self = self.clone() - rhs; + } +} + +impl MulAssign for SymbolicVar { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs; + } +} + +impl Default for SymbolicVar { + fn default() -> Self { + SymbolicVar::zero() + } +} + +impl Sum for SymbolicFelt { + fn sum>(iter: I) -> Self { + iter.fold(SymbolicFelt::zero(), |acc, x| acc + x) + } +} + +impl Product for SymbolicFelt { + fn product>(iter: I) -> Self { + iter.fold(SymbolicFelt::one(), |acc, x| acc * x) + } +} + +impl AddAssign for SymbolicFelt { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs; + } +} + +impl SubAssign for SymbolicFelt { + fn sub_assign(&mut self, rhs: Self) { + *self = self.clone() - rhs; + } +} + +impl MulAssign for SymbolicFelt { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs; + } +} + +impl Default for SymbolicFelt { + fn default() -> Self { + SymbolicFelt::zero() + } +} + +impl> Sum for SymbolicExt { + fn sum>(iter: I) -> Self { + iter.fold(SymbolicExt::zero(), |acc, x| acc + x) + } +} + +impl> Product for SymbolicExt { + fn product>(iter: I) -> Self { + iter.fold(SymbolicExt::one(), |acc, x| acc * x) + } +} + +impl> Default for SymbolicExt { + fn default() -> Self { + SymbolicExt::zero() + } +} + +impl, E: Any> AddAssign for SymbolicExt { + fn add_assign(&mut self, rhs: E) { + *self = self.clone() + rhs; + } +} + +impl, E: Any> SubAssign for SymbolicExt { + fn sub_assign(&mut self, rhs: E) { + *self = self.clone() - rhs; + } +} + +impl, E: Any> MulAssign for SymbolicExt { + fn mul_assign(&mut self, rhs: E) { + *self = self.clone() * rhs; + } +} + +impl, E: Any> DivAssign for SymbolicExt { + fn div_assign(&mut self, rhs: E) { + *self = self.clone() / rhs; + } +} + +impl, E: Any> ExtensionOperand for E { + fn to_operand(self) -> ExtOperand { + match self.type_id() { + ty if ty == TypeId::of::() => { + // *Saftey*: We know that E is a F and we can transmute it to F which implements + // the Copy trait. + let value = unsafe { mem::transmute_copy::(&self) }; + ExtOperand::::Base(value) + } + ty if ty == TypeId::of::() => { + // *Saftey*: We know that E is a EF and we can transmute it to EF which implements + // the Copy trait. + let value = unsafe { mem::transmute_copy::(&self) }; + ExtOperand::::Const(value) + } + ty if ty == TypeId::of::>() => { + // *Saftey*: We know that E is a Felt and we can transmute it to Felt which + // implements the Copy trait. + let value = unsafe { mem::transmute_copy::>(&self) }; + ExtOperand::::Felt(value) + } + ty if ty == TypeId::of::>() => { + // *Saftey*: We know that E is a Ext and we can transmute it to Ext + // which implements the Copy trait. + let value = unsafe { mem::transmute_copy::>(&self) }; + ExtOperand::::Ext(value) + } + ty if ty == TypeId::of::>() => { + // *Saftey*: We know that E is a SymbolicExt and we can transmute it to + // SymbolicExt but we need to clone the pointer. + let value_ref = unsafe { mem::transmute::<&E, &SymbolicExt>(&self) }; + let value = value_ref.clone(); + ExtOperand::::Sym(value) + } + _ => unimplemented!("Unsupported type"), + } + } +} diff --git a/recursion/compiler/src/ir/types.rs b/recursion/compiler/src/ir/types.rs index 6ab91162ac..0d83203466 100644 --- a/recursion/compiler/src/ir/types.rs +++ b/recursion/compiler/src/ir/types.rs @@ -326,17 +326,12 @@ impl MemVariable for Var { 1 } - fn load(&self, ptr: Ptr, offset: Usize, builder: &mut Builder) { - builder.push(DslIR::LoadV(*self, ptr, offset)); + fn load(&self, ptr: Ptr, builder: &mut Builder) { + builder.push(DslIR::LoadV(*self, ptr)); } - fn store( - &self, - ptr: Ptr<::N>, - offset: Usize<::N>, - builder: &mut Builder, - ) { - builder.push(DslIR::StoreV(*self, ptr, offset)); + fn store(&self, ptr: Ptr<::N>, builder: &mut Builder) { + builder.push(DslIR::StoreV(ptr, *self)); } } @@ -632,17 +627,12 @@ impl MemVariable for Felt { 1 } - fn load(&self, ptr: Ptr, offset: Usize, builder: &mut Builder) { - builder.push(DslIR::LoadF(*self, ptr, offset)); + fn load(&self, ptr: Ptr, builder: &mut Builder) { + builder.push(DslIR::LoadF(*self, ptr)); } - fn store( - &self, - ptr: Ptr<::N>, - offset: Usize<::N>, - builder: &mut Builder, - ) { - builder.push(DslIR::StoreF(*self, ptr, offset)); + fn store(&self, ptr: Ptr<::N>, builder: &mut Builder) { + builder.push(DslIR::StoreF(ptr, *self)); } } @@ -955,16 +945,11 @@ impl MemVariable for Ext { 4 } - fn load(&self, ptr: Ptr, offset: Usize, builder: &mut Builder) { - builder.push(DslIR::LoadE(*self, ptr, offset)); + fn load(&self, ptr: Ptr, builder: &mut Builder) { + builder.push(DslIR::LoadE(*self, ptr)); } - fn store( - &self, - ptr: Ptr<::N>, - offset: Usize<::N>, - builder: &mut Builder, - ) { - builder.push(DslIR::StoreE(*self, ptr, offset)); + fn store(&self, ptr: Ptr<::N>, builder: &mut Builder) { + builder.push(DslIR::StoreE(ptr, *self)); } } diff --git a/recursion/compiler/src/ir/var.rs b/recursion/compiler/src/ir/var.rs index 97b2965c44..93aabb8782 100644 --- a/recursion/compiler/src/ir/var.rs +++ b/recursion/compiler/src/ir/var.rs @@ -1,4 +1,5 @@ -use super::{Builder, Config, Ptr, Usize}; +use super::{Builder, Config, Ptr}; + pub trait Variable: Copy { type Expression; @@ -21,6 +22,6 @@ pub trait Variable: Copy { pub trait MemVariable: Variable { fn size_of() -> usize; - fn load(&self, ptr: Ptr, offset: Usize, builder: &mut Builder); - fn store(&self, ptr: Ptr, offset: Usize, builder: &mut Builder); + fn load(&self, ptr: Ptr, builder: &mut Builder); + fn store(&self, ptr: Ptr, builder: &mut Builder); } diff --git a/recursion/compiler/tests/arithmetic.rs b/recursion/compiler/tests/arithmetic.rs index d19e99e1c6..205f9765a3 100644 --- a/recursion/compiler/tests/arithmetic.rs +++ b/recursion/compiler/tests/arithmetic.rs @@ -22,13 +22,13 @@ fn test_compiler_arithmetic() { builder.assert_felt_eq(one * one, F::one()); builder.assert_felt_eq(one + one, F::two()); - let zero_ext: Ext<_, _> = builder.eval(EF::zero()); - let one_ext: Ext<_, _> = builder.eval(EF::one()); + let zero_ext: Ext<_, _> = builder.eval(EF::zero().cons()); + let one_ext: Ext<_, _> = builder.eval(EF::one().cons()); - builder.assert_ext_eq(zero_ext * one_ext, EF::zero()); - builder.assert_ext_eq(one_ext * one_ext, EF::one()); - builder.assert_ext_eq(one_ext + one_ext, EF::two()); - builder.assert_ext_eq(one_ext - one_ext, EF::zero()); + builder.assert_ext_eq(zero_ext * one_ext, EF::zero().cons()); + builder.assert_ext_eq(one_ext * one_ext, EF::one().cons()); + builder.assert_ext_eq(one_ext + one_ext, EF::two().cons()); + builder.assert_ext_eq(one_ext - one_ext, EF::zero().cons()); for _ in 0..num_tests { let a_val = rng.gen::(); @@ -36,18 +36,19 @@ fn test_compiler_arithmetic() { let a: Felt<_> = builder.eval(a_val); let b: Felt<_> = builder.eval(b_val); builder.assert_felt_eq(a + b, a_val + b_val); + builder.assert_felt_eq(a + b, a + b_val); builder.assert_felt_eq(a * b, a_val * b_val); builder.assert_felt_eq(a - b, a_val - b_val); builder.assert_felt_eq(a / b, a_val / b_val); let a_ext_val = rng.gen::(); let b_ext_val = rng.gen::(); - let a_ext: Ext<_, _> = builder.eval(a_ext_val); - let b_ext: Ext<_, _> = builder.eval(b_ext_val); - builder.assert_ext_eq(a_ext + b_ext, a_ext_val + b_ext_val); - builder.assert_ext_eq(a_ext * b_ext, a_ext_val * b_ext_val); - builder.assert_ext_eq(a_ext - b_ext, a_ext_val - b_ext_val); - builder.assert_ext_eq(a_ext / b_ext, a_ext_val / b_ext_val); + let a_ext: Ext<_, _> = builder.eval(a_ext_val.cons()); + let b_ext: Ext<_, _> = builder.eval(b_ext_val.cons()); + builder.assert_ext_eq(a_ext + b_ext, (a_ext_val + b_ext_val).cons()); + builder.assert_ext_eq(a_ext * b_ext, (a_ext_val * b_ext_val).cons()); + builder.assert_ext_eq(a_ext - b_ext, (a_ext_val - b_ext_val).cons()); + builder.assert_ext_eq(a_ext / b_ext, (a_ext_val / b_ext_val).cons()); } let program = builder.compile(); diff --git a/recursion/compiler/tests/array.rs b/recursion/compiler/tests/array.rs new file mode 100644 index 0000000000..170eeff968 --- /dev/null +++ b/recursion/compiler/tests/array.rs @@ -0,0 +1,79 @@ +use p3_field::AbstractField; +use rand::{thread_rng, Rng}; +use sp1_core::stark::StarkGenericConfig; +use sp1_core::utils::BabyBearPoseidon2; +use sp1_recursion_compiler::asm::VmBuilder; +use sp1_recursion_compiler::prelude::*; +use sp1_recursion_core::runtime::Runtime; + +#[test] +fn test_compiler_array() { + type SC = BabyBearPoseidon2; + type F = ::Val; + type EF = ::Challenge; + let mut builder = VmBuilder::::default(); + + // Sum all the values of an array. + let len: usize = 1000; + let mut rng = thread_rng(); + + let mut static_array = builder.array::, _>(len); + + // Put values statically + for i in 0..len { + builder.set(&mut static_array, i, F::one()); + } + // Assert values set. + for i in 0..len { + let value = builder.get(&static_array, i); + builder.assert_var_eq(value, F::one()); + } + + let dyn_len: Var<_> = builder.eval(F::from_canonical_usize(len)); + let mut var_array = builder.array::, _>(dyn_len); + let mut felt_array = builder.array::, _>(dyn_len); + let mut ext_array = builder.array::, _>(dyn_len); + // Put values statically + let var_vals = (0..len).map(|_| rng.gen::()).collect::>(); + let felt_vals = (0..len).map(|_| rng.gen::()).collect::>(); + let ext_vals = (0..len).map(|_| rng.gen::()).collect::>(); + for i in 0..len { + builder.set(&mut var_array, i, var_vals[i]); + builder.set(&mut felt_array, i, felt_vals[i]); + builder.set(&mut ext_array, i, ext_vals[i].cons()); + } + // Assert values set. + for i in 0..len { + let var_value = builder.get(&var_array, i); + builder.assert_var_eq(var_value, var_vals[i]); + let felt_value = builder.get(&felt_array, i); + builder.assert_felt_eq(felt_value, felt_vals[i]); + let ext_value = builder.get(&ext_array, i); + builder.assert_ext_eq(ext_value, ext_vals[i].cons()); + } + + // Put values dynamically + builder.range(0, dyn_len).for_each(|i, builder| { + builder.set(&mut var_array, i, i * F::two()); + builder.set(&mut felt_array, i, F::from_canonical_u32(3)); + builder.set(&mut ext_array, i, (EF::from_canonical_u32(4)).cons()); + }); + + // Assert values set. + builder.range(0, dyn_len).for_each(|i, builder| { + let var_value = builder.get(&var_array, i); + builder.assert_var_eq(var_value, i * F::two()); + let felt_value = builder.get(&felt_array, i); + builder.assert_felt_eq(felt_value, F::from_canonical_u32(3)); + let ext_value = builder.get(&ext_array, i); + builder.assert_ext_eq(ext_value, EF::from_canonical_u32(4).cons()); + }); + + let code = builder.compile_to_asm(); + println!("{code}"); + + let program = code.machine_code(); + + let mut runtime = Runtime::::new(&program); + runtime.run(); +} diff --git a/recursion/core/src/lib.rs b/recursion/core/src/lib.rs index d5488c3f57..e19fb3dd34 100644 --- a/recursion/core/src/lib.rs +++ b/recursion/core/src/lib.rs @@ -39,26 +39,15 @@ pub mod tests { Program:: { instructions: vec![ // .main - Instruction::new(Opcode::SW, F::zero(), one, zero, true, false, true, false), - Instruction::new( - Opcode::SW, - F::from_canonical_u32(1), - one, - zero, - true, - false, - true, - false, - ), + Instruction::new(Opcode::SW, F::zero(), one, zero, true, true), + Instruction::new(Opcode::SW, F::from_canonical_u32(1), one, zero, true, true), Instruction::new( Opcode::SW, F::from_canonical_u32(2), [F::from_canonical_u32(10), F::zero(), F::zero(), F::zero()], zero, true, - false, true, - false, ), // .body: Instruction::new( @@ -67,29 +56,16 @@ pub mod tests { zero, one, false, - false, true, - false, - ), - Instruction::new( - Opcode::SW, - F::from_canonical_u32(0), - one, - zero, - false, - false, - true, - false, ), + Instruction::new(Opcode::SW, F::from_canonical_u32(0), one, zero, false, true), Instruction::new( Opcode::SW, F::from_canonical_u32(1), [F::two() + F::one(), F::zero(), F::zero(), F::zero()], zero, false, - false, true, - false, ), Instruction::new( Opcode::SUB, @@ -97,9 +73,7 @@ pub mod tests { [F::two(), F::zero(), F::zero(), F::zero()], one, false, - false, true, - false, ), Instruction::new( Opcode::BNE, @@ -112,9 +86,7 @@ pub mod tests { F::zero(), ], true, - false, true, - false, ), ], } diff --git a/recursion/core/src/runtime/instruction.rs b/recursion/core/src/runtime/instruction.rs index 16fe89167e..c4b9c6bb1f 100644 --- a/recursion/core/src/runtime/instruction.rs +++ b/recursion/core/src/runtime/instruction.rs @@ -18,17 +18,11 @@ pub struct Instruction { /// The third operand. pub op_c: Block, - /// Whether the second operand is an immediate field value. + /// Whether the second operand is an immediate value. pub imm_b: bool, - /// Whether the second operand is an immediate extension value. - pub imm_ext_b: bool, - /// Whether the third operand is an immediate value. pub imm_c: bool, - - /// Whether the third operand is an immediate extension value. - pub imm_ext_c: bool, } impl Instruction { @@ -39,9 +33,7 @@ impl Instruction { op_b: [F; D], op_c: [F; D], imm_b: bool, - imm_ext_b: bool, imm_c: bool, - imm_ext_c: bool, ) -> Self { Self { opcode, @@ -49,9 +41,50 @@ impl Instruction { op_b: Block::from(op_b), op_c: Block::from(op_c), imm_b, - imm_ext_b, imm_c, - imm_ext_c, } } + + pub(crate) fn is_b_ext(&self) -> bool { + matches!( + self.opcode, + Opcode::LE + | Opcode::SE + | Opcode::EADD + | Opcode::ESUB + | Opcode::EMUL + | Opcode::EFADD + | Opcode::EFSUB + | Opcode::EFMUL + | Opcode::EDIV + | Opcode::EBNE + | Opcode::EBEQ + ) + } + + pub(crate) fn is_c_ext(&self) -> bool { + matches!( + self.opcode, + Opcode::LE + | Opcode::SE + | Opcode::EADD + | Opcode::EMUL + | Opcode::ESUB + | Opcode::EDIV + | Opcode::EFADD + | Opcode::EFSUB + | Opcode::EFMUL + | Opcode::EFDIV + | Opcode::EBNE + | Opcode::EBEQ + ) + } + + pub(crate) fn imm_b_base(&self) -> bool { + self.imm_b && !self.is_b_ext() + } + + pub(crate) fn imm_c_base(&self) -> bool { + self.imm_c && !self.is_c_ext() + } } diff --git a/recursion/core/src/runtime/mod.rs b/recursion/core/src/runtime/mod.rs index 3a92c19ad9..258bc614a7 100644 --- a/recursion/core/src/runtime/mod.rs +++ b/recursion/core/src/runtime/mod.rs @@ -17,8 +17,8 @@ use crate::memory::MemoryRecord; use p3_field::{ExtensionField, PrimeField32}; use sp1_core::runtime::MemoryAccessPosition; -pub(crate) const STACK_SIZE: usize = 1024; -pub(crate) const MEMORY_SIZE: usize = 1024 * 1024; +pub const STACK_SIZE: usize = 1 << 20; +pub const MEMORY_SIZE: usize = 1 << 26; pub const D: usize = 4; @@ -127,72 +127,70 @@ impl> Runtime { self.clk + F::from_canonical_u32(*position as u32) } - /// Fetch the destination address and input operand values for an ALU instruction. - fn alu_rr(&mut self, instruction: &Instruction) -> (F, Block, Block) { - let a_ptr = self.fp + instruction.op_a; + fn get_b(&mut self, instruction: &Instruction) -> Block { + if instruction.imm_b_base() { + Block::from(instruction.op_b[0]) + } else if instruction.imm_b { + instruction.op_b + } else { + self.mr(self.fp + instruction.op_b[0], MemoryAccessPosition::B) + } + } - let c_val = if instruction.imm_c { + fn get_c(&mut self, instruction: &Instruction) -> Block { + if instruction.imm_c_base() { Block::from(instruction.op_c[0]) - } else if instruction.imm_ext_c { + } else if instruction.imm_c { instruction.op_c } else { self.mr(self.fp + instruction.op_c[0], MemoryAccessPosition::C) - }; + } + } - let b_val = if instruction.imm_b { - Block::from(instruction.op_b[0]) - } else if instruction.imm_ext_b { - instruction.op_b - } else { - self.mr(self.fp + instruction.op_b[0], MemoryAccessPosition::B) - }; + /// Fetch the destination address and input operand values for an ALU instruction. + fn alu_rr(&mut self, instruction: &Instruction) -> (F, Block, Block) { + let a_ptr = self.fp + instruction.op_a; + let c_val = self.get_c(instruction); + let b_val = self.get_b(instruction); (a_ptr, b_val, c_val) } /// Fetch the destination address input operand values for a load instruction (from heap). fn load_rr(&mut self, instruction: &Instruction) -> (F, Block) { - if instruction.imm_b { - let a_ptr = self.fp + instruction.op_a; - let b = Block::from(instruction.op_b[0]); - (a_ptr, b) - } else if instruction.imm_ext_b { - let a_ptr = self.fp + instruction.op_a; - let b = instruction.op_b; - (a_ptr, b) + let a_ptr = self.fp + instruction.op_a; + let b = if instruction.imm_b_base() { + Block::from(instruction.op_b[0]) + } else if instruction.imm_b { + instruction.op_b } else { - let a_ptr = self.fp + instruction.op_a; - let b = self.mr(self.fp + instruction.op_b[0], MemoryAccessPosition::B); - (a_ptr, b) - } + let address = self.mr(self.fp + instruction.op_b[0], MemoryAccessPosition::B); + self.mr(address[0], MemoryAccessPosition::A) + }; + (a_ptr, b) } /// Fetch the destination address input operand values for a store instruction (from stack). fn store_rr(&mut self, instruction: &Instruction) -> (F, Block) { - if instruction.imm_b { - let a_ptr = self.fp + instruction.op_a; - let b_val = Block::from(instruction.op_b[0]); - (a_ptr, b_val) - } else if instruction.imm_ext_b { - let a_ptr = self.fp + instruction.op_a; - (a_ptr, instruction.op_b) + let a_ptr = if instruction.imm_b { + self.fp + instruction.op_a } else { - let a_ptr = self.fp + instruction.op_a; - let b = self.mr(self.fp + instruction.op_b[0], MemoryAccessPosition::B); - (a_ptr, b) - } - } - - /// Fetch the input operand values for a branch instruction. - fn branch_rr(&mut self, instruction: &Instruction) -> (Block, Block, F) { - let a = self.mr(self.fp + instruction.op_a, MemoryAccessPosition::A); - let b = if instruction.imm_b { + self.mr(self.fp + instruction.op_a, MemoryAccessPosition::A)[0] + }; + let b = if instruction.imm_b_base() { Block::from(instruction.op_b[0]) - } else if instruction.imm_ext_b { + } else if instruction.imm_b { instruction.op_b } else { self.mr(self.fp + instruction.op_b[0], MemoryAccessPosition::B) }; + (a_ptr, b) + } + + /// Fetch the input operand values for a branch instruction. + fn branch_rr(&mut self, instruction: &Instruction) -> (Block, Block, F) { + let a = self.mr(self.fp + instruction.op_a, MemoryAccessPosition::A); + let b = self.get_b(instruction); let c = instruction.op_c[0]; (a, b, c) @@ -233,28 +231,28 @@ impl> Runtime { self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, c_val); } - Opcode::EADD => { + Opcode::EADD | Opcode::EFADD => { let (a_ptr, b_val, c_val) = self.alu_rr(&instruction); let sum = EF::from_base_slice(&b_val.0) + EF::from_base_slice(&c_val.0); let a_val = Block::from(sum.as_base_slice()); self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, c_val); } - Opcode::EMUL => { + Opcode::EMUL | Opcode::EFMUL => { let (a_ptr, b_val, c_val) = self.alu_rr(&instruction); let product = EF::from_base_slice(&b_val.0) * EF::from_base_slice(&c_val.0); let a_val = Block::from(product.as_base_slice()); self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, c_val); } - Opcode::ESUB => { + Opcode::ESUB | Opcode::EFSUB => { let (a_ptr, b_val, c_val) = self.alu_rr(&instruction); let diff = EF::from_base_slice(&b_val.0) - EF::from_base_slice(&c_val.0); let a_val = Block::from(diff.as_base_slice()); self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, c_val); } - Opcode::EDIV => { + Opcode::EDIV | Opcode::EFDIV => { let (a_ptr, b_val, c_val) = self.alu_rr(&instruction); let quotient = EF::from_base_slice(&b_val.0) / EF::from_base_slice(&c_val.0); let a_val = Block::from(quotient.as_base_slice()); @@ -262,12 +260,26 @@ impl> Runtime { (a, b, c) = (a_val, b_val, c_val); } Opcode::LW => { + let (a_ptr, b_val) = self.load_rr(&instruction); + let prev_a = self.mr(a_ptr, MemoryAccessPosition::A); + let a_val = Block::from([b_val[0], prev_a[1], prev_a[2], prev_a[3]]); + self.mw(a_ptr, a_val, MemoryAccessPosition::A); + (a, b, c) = (a_val, b_val, Block::default()); + } + Opcode::LE => { let (a_ptr, b_val) = self.load_rr(&instruction); let a_val = b_val; self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, Block::default()); } Opcode::SW => { + let (a_ptr, b_val) = self.store_rr(&instruction); + let prev_a = self.mr(a_ptr, MemoryAccessPosition::A); + let a_val = Block::from([b_val[0], prev_a[1], prev_a[2], prev_a[3]]); + self.mw(a_ptr, a_val, MemoryAccessPosition::A); + (a, b, c) = (a_val, b_val, Block::default()); + } + Opcode::SE => { let (a_ptr, b_val) = self.store_rr(&instruction); let a_val = b_val; self.mw(a_ptr, a_val, MemoryAccessPosition::A); diff --git a/recursion/core/src/runtime/opcode.rs b/recursion/core/src/runtime/opcode.rs index 727602cdde..9b3c100f9a 100644 --- a/recursion/core/src/runtime/opcode.rs +++ b/recursion/core/src/runtime/opcode.rs @@ -8,25 +8,33 @@ pub enum Opcode { DIV = 3, // Arithmetic field extension operations. - EADD = 11, - ESUB = 12, - EMUL = 13, - EDIV = 14, + EADD = 10, + ESUB = 11, + EMUL = 12, + EDIV = 13, + + // Mixed arithmetic operations. + EFADD = 20, + EFSUB = 21, + EFMUL = 22, + EFDIV = 23, // Memory instructions. LW = 4, SW = 5, + LE = 14, + SE = 15, // Branch instructions. BEQ = 6, BNE = 7, - EBEQ = 15, - EBNE = 16, + EBEQ = 16, + EBNE = 17, // Jump instructions. JAL = 8, JALR = 9, // System instructions. - TRAP = 10, + TRAP = 30, }