diff --git a/core/src/air/machine.rs b/core/src/air/machine.rs index 8446dbc493..977abf4ab3 100644 --- a/core/src/air/machine.rs +++ b/core/src/air/machine.rs @@ -11,7 +11,7 @@ pub trait MachineAir: BaseAir { /// The execution record containing events for producing the air trace. type Record: MachineRecord; - type Program; + type Program: Send + Sync; /// A unique identifier for this AIR as part of a machine. fn name(&self) -> String; diff --git a/core/src/lookup/debug.rs b/core/src/lookup/debug.rs index f9c7816875..6fa7fc2ba4 100644 --- a/core/src/lookup/debug.rs +++ b/core/src/lookup/debug.rs @@ -48,6 +48,7 @@ fn field_to_int(x: F) -> i32 { pub fn debug_interactions>>( chip: &MachineChip, + program: &A::Program, record: &A::Record, interaction_kinds: Vec, ) -> ( @@ -58,6 +59,7 @@ pub fn debug_interactions>>( let mut key_to_count = BTreeMap::new(); let trace = chip.generate_trace(record, &mut A::Record::default()); + let mut preprocessed_trace = chip.generate_preprocessed_trace(program); let mut main = trace.clone(); let height = trace.clone().height(); @@ -72,13 +74,21 @@ pub fn debug_interactions>>( if !interaction_kinds.contains(&interaction.kind) { continue; } + let mut empty = vec![]; + let preprocessed_row = preprocessed_trace + .as_mut() + .map(|t| t.row_mut(row)) + .or_else(|| Some(&mut empty)) + .unwrap(); let is_send = m < nb_send_interactions; - let multiplicity_eval: Val = interaction.multiplicity.apply(&[], main.row_mut(row)); + let multiplicity_eval: Val = interaction + .multiplicity + .apply(preprocessed_row, main.row_mut(row)); if !multiplicity_eval.is_zero() { let mut values = vec![]; for value in &interaction.values { - let expr: Val = value.apply(&[], main.row_mut(row)); + let expr: Val = value.apply(preprocessed_row, main.row_mut(row)); values.push(expr); } let key = format!( @@ -114,6 +124,7 @@ pub fn debug_interactions>>( /// and print out the ones for which the set of sends and receives don't match. pub fn debug_interactions_with_all_chips( chips: &[MachineChip], + program: &A::Program, segment: &A::Record, interaction_kinds: Vec, ) -> bool @@ -127,7 +138,8 @@ where let mut total = SC::Val::zero(); for chip in chips.iter() { - let (_, count) = debug_interactions::(chip, segment, interaction_kinds.clone()); + let (_, count) = + debug_interactions::(chip, program, segment, interaction_kinds.clone()); tracing::info!("{} chip has {} distinct events", chip.name(), count.len()); for (key, value) in count.iter() { diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index b0fcfd12bb..4b02c8231c 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -229,6 +229,7 @@ mod tests { fn test_memory_lookup_interactions() { setup_logger(); let program = sha_extend_program(); + let program_clone = program.clone(); let mut runtime = Runtime::new(program); runtime.run(); @@ -236,6 +237,7 @@ mod tests { RiscvAir::machine(BabyBearPoseidon2::new()); debug_interactions_with_all_chips::>( machine.chips(), + &program_clone, &runtime.record, vec![InteractionKind::Memory], ); @@ -245,12 +247,14 @@ mod tests { fn test_byte_lookup_interactions() { setup_logger(); let program = sha_extend_program(); + let program_clone = program.clone(); let mut runtime = Runtime::new(program); runtime.run(); let machine = RiscvAir::machine(BabyBearPoseidon2::new()); debug_interactions_with_all_chips::>( machine.chips(), + &program_clone, &runtime.record, vec![InteractionKind::Byte], ); diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index 005a2c8cc3..64ccd7567a 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -274,6 +274,7 @@ impl>> MachineStark { pub fn debug_constraints( &self, + program: &A::Program, pk: &ProvingKey, record: A::Record, challenger: &mut SC::Challenger, @@ -294,7 +295,12 @@ impl>> MachineStark { // Generate the main trace for each chip. let traces = chips .par_iter() - .map(|chip| chip.generate_trace(shard, &mut A::Record::default())) + .map(|chip| { + ( + chip.generate_trace(shard, &mut A::Record::default()), + chip.generate_preprocessed_trace(program), + ) + }) .collect::>(); // Get a permutation challenge. @@ -311,9 +317,9 @@ impl>> MachineStark { chips .par_iter() .zip(traces.par_iter()) - .map(|(chip, main_trace)| { + .map(|(chip, (main_trace, pre_trace))| { let perm_trace = chip.generate_permutation_trace( - None, + pre_trace.as_ref(), main_trace, &permutation_challenges, ); @@ -331,15 +337,15 @@ impl>> MachineStark { // Compute some statistics. for i in 0..chips.len() { - let trace_width = traces[i].width(); + let trace_width = traces[i].0.width(); let permutation_width = permutation_traces[i].width(); let total_width = trace_width + permutation_width; tracing::debug!( "{:<11} | Cols = {:<5} | Rows = {:<5} | Cells = {:<10} | Main Cols = {:.2}% | Perm Cols = {:.2}%", chips[i].name(), total_width, - traces[i].height(), - total_width * traces[i].height(), + traces[i].0.height(), + total_width * traces[i].0.height(), (100f32 * trace_width as f32) / total_width as f32, (100f32 * permutation_width as f32) / total_width as f32); } @@ -353,7 +359,7 @@ impl>> MachineStark { debug_constraints::( chips[i], permutation_trace, - &traces[i], + &traces[i].0, &permutation_traces[i], &permutation_challenges, PublicValuesDigest::>>::new(shard.public_values_digest()), @@ -371,6 +377,7 @@ impl>> MachineStark { } debug_interactions_with_all_chips::( self.chips(), + program, &record, InteractionKind::all_kinds(), ); diff --git a/core/src/utils/prove.rs b/core/src/utils/prove.rs index e298d502db..ac4778ca7f 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -81,7 +81,7 @@ pub fn run_test_core( let nb_bytes = bincode::serialize(&proof).unwrap().len(); #[cfg(feature = "debug")] - machine.debug_constraints(&pk, record_clone, &mut challenger); + machine.debug_constraints(&runtime.program, &pk, record_clone, &mut challenger); let mut challenger = machine.config().challenger(); machine.verify(&vk, &proof, &mut challenger)?; @@ -137,7 +137,7 @@ where #[cfg(feature = "debug")] { let record_clone = runtime.record.clone(); - machine.debug_constraints(&pk, record_clone, &mut challenger); + machine.debug_constraints(&program, &pk, record_clone, &mut challenger); } let public_values = std::mem::take(&mut runtime.state.public_values_stream); let proof = prove_core(machine.config().clone(), runtime); diff --git a/recursion/core/src/cpu/air.rs b/recursion/core/src/cpu/air.rs index afde26b459..cfefd80a85 100644 --- a/recursion/core/src/cpu/air.rs +++ b/recursion/core/src/cpu/air.rs @@ -1,3 +1,4 @@ +use crate::air::BinomialExtensionUtils; use crate::air::BlockBuilder; use crate::cpu::CpuChip; use crate::runtime::Program; @@ -11,10 +12,12 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use p3_matrix::MatrixRowSlices; use sp1_core::air::AirInteraction; +use sp1_core::air::BinomialExtension; +use sp1_core::air::MachineAir; use sp1_core::lookup::InteractionKind; use sp1_core::stark::SP1AirBuilder; use sp1_core::utils::indices_arr; -use sp1_core::{air::MachineAir, utils::pad_to_power_of_two}; +use sp1_core::utils::pad_rows; use std::borrow::Borrow; use std::borrow::BorrowMut; use std::mem::transmute; @@ -44,7 +47,7 @@ impl MachineAir for CpuChip { input: &ExecutionRecord, _: &mut ExecutionRecord, ) -> RowMajorMatrix { - let rows = input + let mut rows = input .cpu_events .iter() .map(|event| { @@ -55,6 +58,8 @@ impl MachineAir for CpuChip { cols.pc = event.pc; cols.fp = event.fp; + cols.selectors.populate(&event.instruction); + cols.instruction.opcode = F::from_canonical_u32(event.instruction.opcode as u32); cols.instruction.op_a = event.instruction.op_a; cols.instruction.op_b = event.instruction.op_b; @@ -76,18 +81,22 @@ impl MachineAir for CpuChip { cols.c.value = event.instruction.op_c; } - // cols.add_scratch = cols.b.value.0[0] + cols.c.value.0[0]; - // cols.sub_scratch = cols.b.value.0[0] - cols.c.value.0[0]; - // cols.mul_scratch = cols.b.value.0[0] * cols.c.value.0[0]; - // cols.add_ext_scratch = (BinomialExtension::from_block(cols.b.value) - // + BinomialExtension::from_block(cols.c.value)) - // .as_block(); - // cols.sub_ext_scratch = (BinomialExtension::from_block(cols.b.value) - // - BinomialExtension::from_block(cols.c.value)) - // .as_block(); - // cols.mul_ext_scratch = (BinomialExtension::from_block(cols.b.value) - // * BinomialExtension::from_block(cols.c.value)) - // .as_block(); + let alu_cols = cols.opcode_specific.alu_mut(); + if cols.selectors.is_add.is_one() { + alu_cols.add_scratch.0[0] = cols.b.value.0[0] + cols.c.value.0[0]; + alu_cols.sub_scratch.0[0] = cols.b.value.0[0] - cols.c.value.0[0]; + alu_cols.mul_scratch.0[0] = cols.b.value.0[0] * cols.c.value.0[0]; + } else if cols.selectors.is_eadd.is_one() || cols.selectors.is_efadd.is_one() { + alu_cols.add_scratch = (BinomialExtension::from_block(cols.b.value) + + BinomialExtension::from_block(cols.c.value)) + .as_block(); + alu_cols.sub_scratch = (BinomialExtension::from_block(cols.b.value) + - BinomialExtension::from_block(cols.c.value)) + .as_block(); + alu_cols.mul_scratch = (BinomialExtension::from_block(cols.b.value) + * BinomialExtension::from_block(cols.c.value)) + .as_block(); + } // cols.a_eq_b // .populate((cols.a.value.0[0] - cols.b.value.0[0]).as_canonical_u32()); @@ -101,12 +110,16 @@ impl MachineAir for CpuChip { }) .collect::>(); + pad_rows(&mut rows, || { + let mut row = [F::zero(); NUM_CPU_COLS]; + let cols: &mut CpuCols = row.as_mut_slice().borrow_mut(); + cols.selectors.is_noop = F::one(); + row + }); + let mut trace = RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_CPU_COLS); - // Pad the trace to a power of two. - pad_to_power_of_two::(&mut trace.values); - for i in input.cpu_events.len()..trace.height() { trace.values[i * NUM_CPU_COLS + CPU_COL_MAP.clk] = F::from_canonical_u32(4) * F::from_canonical_usize(i); @@ -148,7 +161,7 @@ where builder .when_transition() .when(next.is_real) - .assert_eq(local.clk + AB::F::from_canonical_u32(4), next.clk); + .assert_eq(local.clk.into() + AB::F::from_canonical_u32(4), next.clk); // // Increment pc by 1 every cycle unless it is a branch instruction that is satisfied. // builder @@ -168,9 +181,24 @@ where .assert_block_eq::(local.c.value, local.instruction.op_c); // Compute ALU. - // builder.assert_eq(local.b.value.0[0] + local.c.value.0[0], local.add_scratch); - // builder.assert_eq(local.b.value.0[0] - local.c.value.0[0], local.sub_scratch); - // builder.assert_eq(local.b.value.0[0] * local.c.value.0[0], local.mul_scratch); + let alu_cols = local.opcode_specific.alu(); + builder.when(local.selectors.is_add).assert_eq( + local.b.value.0[0] + local.c.value.0[0], + alu_cols.add_scratch[0], + ); + builder.when(local.selectors.is_add).assert_eq( + local.b.value.0[0] - local.c.value.0[0], + alu_cols.sub_scratch[0], + ); + builder.when(local.selectors.is_add).assert_eq( + local.b.value.0[0] * local.c.value.0[0], + alu_cols.mul_scratch[0], + ); + + builder.assert_eq( + local.is_real * local.is_real * local.is_real, + local.is_real * local.is_real * local.is_real, + ); // // Compute extension ALU. // builder.assert_ext_eq( @@ -318,6 +346,13 @@ where prog_interaction_vals.extend_from_slice(&local.instruction.op_c.map(|x| x.into()).0); prog_interaction_vals.push(local.instruction.imm_b.into()); prog_interaction_vals.push(local.instruction.imm_c.into()); + prog_interaction_vals.extend_from_slice( + &local + .selectors + .into_iter() + .map(|x| x.into()) + .collect::>(), + ); builder.send(AirInteraction::new( prog_interaction_vals, local.is_real.into(), diff --git a/recursion/core/src/cpu/columns/alu.rs b/recursion/core/src/cpu/columns/alu.rs index c45603f56e..edeb88d466 100644 --- a/recursion/core/src/cpu/columns/alu.rs +++ b/recursion/core/src/cpu/columns/alu.rs @@ -9,20 +9,11 @@ pub struct AluCols { pub ext_b: Block, // c = a + b; - pub add_scratch: T, + pub add_scratch: Block, // c = a - b; - pub sub_scratch: T, + pub sub_scratch: Block, // c = a * b; - pub mul_scratch: T, - - // ext(c) = ext(a) + ext(b); - pub add_ext_scratch: Block, - - // ext(c) = ext(a) - ext(b); - pub sub_ext_scratch: Block, - - // ext(c) = ext(a) * ext(b); - pub mul_ext_scratch: Block, + pub mul_scratch: Block, } diff --git a/recursion/core/src/cpu/columns/instruction.rs b/recursion/core/src/cpu/columns/instruction.rs index 2220c7b1ff..83dfb1caa8 100644 --- a/recursion/core/src/cpu/columns/instruction.rs +++ b/recursion/core/src/cpu/columns/instruction.rs @@ -20,6 +20,8 @@ impl InstructionCols { self.op_a = instruction.op_a; self.op_b = instruction.op_b; self.op_c = instruction.op_c; + self.imm_b = F::from_bool(instruction.imm_b); + self.imm_c = F::from_bool(instruction.imm_c) } } diff --git a/recursion/core/src/cpu/columns/mod.rs b/recursion/core/src/cpu/columns/mod.rs index da14d923b8..954df26260 100644 --- a/recursion/core/src/cpu/columns/mod.rs +++ b/recursion/core/src/cpu/columns/mod.rs @@ -12,10 +12,12 @@ pub use alu::*; pub use instruction::*; pub use opcode::*; +use self::opcode_specific::OpcodeSpecificCols; + /// The column layout for the chip. #[derive(AlignedBorrow, Default, Clone, Debug)] #[repr(C)] -pub struct CpuCols { +pub struct CpuCols { pub clk: T, pub pc: T, pub fp: T, @@ -27,7 +29,7 @@ pub struct CpuCols { pub b: MemoryReadWriteCols, pub c: MemoryReadWriteCols, - pub alu: AluCols, + pub opcode_specific: OpcodeSpecificCols, // result = operand_1 == operand_2; pub eq_1_2: IsExtZeroOperation, diff --git a/recursion/core/src/cpu/columns/opcode.rs b/recursion/core/src/cpu/columns/opcode.rs index 9277ea1865..0988599a53 100644 --- a/recursion/core/src/cpu/columns/opcode.rs +++ b/recursion/core/src/cpu/columns/opcode.rs @@ -59,7 +59,7 @@ impl OpcodeSelectorCols { /// The opcode flag should be set to 1 for the relevant opcode and 0 for the rest. We already /// assume that the state of the columns is set to zero at the start of the function, so we only /// need to set the relevant opcode column to 1. - pub fn populate(&mut self, instruction: Instruction) { + pub fn populate(&mut self, instruction: &Instruction) { match instruction.opcode { Opcode::ADD => self.is_add = F::one(), Opcode::SUB => self.is_sub = F::one(), @@ -86,12 +86,14 @@ impl OpcodeSelectorCols { Opcode::JAL => self.is_jal = F::one(), Opcode::JALR => self.is_jalr = F::one(), Opcode::TRAP => self.is_trap = F::one(), - _ => unimplemented!(), + Opcode::PrintF => self.is_noop = F::one(), + Opcode::PrintE => self.is_noop = F::one(), + _ => unimplemented!("opcode {:?} not supported", instruction.opcode), } } } -impl IntoIterator for OpcodeSelectorCols { +impl IntoIterator for &OpcodeSelectorCols { type Item = T; type IntoIter = std::array::IntoIter; diff --git a/recursion/core/src/cpu/columns/opcode_specific.rs b/recursion/core/src/cpu/columns/opcode_specific.rs index 8b13789179..ee423e308f 100644 --- a/recursion/core/src/cpu/columns/opcode_specific.rs +++ b/recursion/core/src/cpu/columns/opcode_specific.rs @@ -1 +1,38 @@ +use crate::cpu::columns::AluCols; +use std::fmt::{Debug, Formatter}; +use std::mem::{size_of, transmute}; +pub const NUM_OPCODE_SPECIFIC_COLS: usize = size_of::>(); + +/// Shared columns whose interpretation depends on the instruction being executed. +#[derive(Clone, Copy)] +#[repr(C)] +pub union OpcodeSpecificCols { + alu: AluCols, +} + +impl Default for OpcodeSpecificCols { + fn default() -> Self { + OpcodeSpecificCols { + alu: AluCols::::default(), + } + } +} + +impl Debug for OpcodeSpecificCols { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: repr(C) ensures uniform fields are in declaration order with no padding. + let self_arr: &[T; NUM_OPCODE_SPECIFIC_COLS] = unsafe { transmute(self) }; + Debug::fmt(self_arr, f) + } +} + +// SAFETY: Each view is a valid interpretation of the underlying array. +impl OpcodeSpecificCols { + pub fn alu(&self) -> &AluCols { + unsafe { &self.alu } + } + pub fn alu_mut(&mut self) -> &mut AluCols { + unsafe { &mut self.alu } + } +} diff --git a/recursion/core/src/program/mod.rs b/recursion/core/src/program/mod.rs index 3455d57f95..b79e26f908 100644 --- a/recursion/core/src/program/mod.rs +++ b/recursion/core/src/program/mod.rs @@ -1,33 +1,49 @@ -use crate::runtime::Program; -use crate::{cpu::InstructionCols, runtime::ExecutionRecord}; +use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use p3_air::{Air, BaseAir}; +use p3_air::{Air, BaseAir, PairBuilder}; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::MatrixRowSlices; +use sp1_core::air::{AirInteraction, MachineAir, SP1AirBuilder}; use sp1_core::lookup::InteractionKind; -use sp1_core::{ - air::{AirInteraction, MachineAir, SP1AirBuilder}, - utils::pad_to_power_of_two, -}; -use sp1_derive::AlignedBorrow; -use std::borrow::Borrow; -use std::borrow::BorrowMut; +use sp1_core::utils::pad_to_power_of_two; use std::collections::HashMap; -pub const NUM_PROGRAM_COLS: usize = size_of::>(); +use sp1_derive::AlignedBorrow; -#[derive(Default)] -pub struct ProgramChip; +use crate::cpu::columns::InstructionCols; +use crate::cpu::columns::OpcodeSelectorCols; +use crate::runtime::{ExecutionRecord, Program}; + +pub const NUM_PROGRAM_PREPROCESSED_COLS: usize = size_of::>(); +pub const NUM_PROGRAM_MULT_COLS: usize = size_of::>(); -#[derive(AlignedBorrow, Clone, Copy, Debug, Default)] +/// The column layout for the chip. +#[derive(AlignedBorrow, Clone, Copy, Default)] #[repr(C)] -pub struct ProgramCols { +pub struct ProgramPreprocessedCols { pub pc: T, pub instruction: InstructionCols, + pub selectors: OpcodeSelectorCols, +} + +/// The column layout for the chip. +#[derive(AlignedBorrow, Clone, Copy, Default)] +#[repr(C)] +pub struct ProgramMultiplicityCols { pub multiplicity: T, } +/// A chip that implements addition for the opcodes ADD and ADDI. +#[derive(Default)] +pub struct ProgramChip; + +impl ProgramChip { + pub fn new() -> Self { + Self {} + } +} + impl MachineAir for ProgramChip { type Record = ExecutionRecord; @@ -37,38 +53,70 @@ impl MachineAir for ProgramChip { "Program".to_string() } + fn preprocessed_width(&self) -> usize { + NUM_PROGRAM_PREPROCESSED_COLS + } + + fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option> { + let rows = program + .instructions + .clone() + .into_iter() + .enumerate() + .map(|(i, instruction)| { + let pc = i as u32 * 4; + let mut row = [F::zero(); NUM_PROGRAM_PREPROCESSED_COLS]; + let cols: &mut ProgramPreprocessedCols = row.as_mut_slice().borrow_mut(); + cols.pc = F::from_canonical_u32(pc); + cols.selectors.populate(&instruction); + cols.instruction.populate(instruction); + + row + }) + .collect::>(); + + // Convert the trace to a row major matrix. + let mut trace = RowMajorMatrix::new( + rows.into_iter().flatten().collect::>(), + NUM_PROGRAM_PREPROCESSED_COLS, + ); + + // Pad the trace to a power of two. + pad_to_power_of_two::(&mut trace.values); + + Some(trace) + } + fn generate_trace( &self, input: &ExecutionRecord, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { + // Generate the trace rows for each event. + + // Collect the number of times each instruction is called from the cpu events. + // Store it as a map of PC -> count. let mut instruction_counts = HashMap::new(); input.cpu_events.iter().for_each(|event| { let pc = event.pc; instruction_counts - .entry(pc) + .entry(pc.as_canonical_u32()) .and_modify(|count| *count += 1) .or_insert(1); }); + let rows = input .program .instructions .clone() .into_iter() .enumerate() - .map(|(i, instruction)| { - let pc = F::from_canonical_u32(i as u32); - let mut row = [F::zero(); NUM_PROGRAM_COLS]; - let cols: &mut ProgramCols = row.as_mut_slice().borrow_mut(); - cols.pc = pc; - cols.instruction.opcode = F::from_canonical_u32(instruction.opcode as u32); - cols.instruction.op_a = instruction.op_a; - cols.instruction.op_b = instruction.op_b; - cols.instruction.op_c = instruction.op_c; - cols.instruction.imm_b = F::from_bool(instruction.imm_b); - cols.instruction.imm_c = F::from_bool(instruction.imm_c); + .map(|(i, _)| { + let pc = i as u32; + let mut row = [F::zero(); NUM_PROGRAM_MULT_COLS]; + let cols: &mut ProgramMultiplicityCols = row.as_mut_slice().borrow_mut(); cols.multiplicity = - F::from_canonical_usize(*instruction_counts.get(&cols.pc).unwrap_or(&0)); + F::from_canonical_usize(*instruction_counts.get(&pc).unwrap_or(&0)); row }) .collect::>(); @@ -76,11 +124,11 @@ impl MachineAir for ProgramChip { // Convert the trace to a row major matrix. let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), - NUM_PROGRAM_COLS, + NUM_PROGRAM_MULT_COLS, ); // Pad the trace to a power of two. - pad_to_power_of_two::(&mut trace.values); + pad_to_power_of_two::(&mut trace.values); trace } @@ -92,33 +140,43 @@ impl MachineAir for ProgramChip { impl BaseAir for ProgramChip { fn width(&self) -> usize { - NUM_PROGRAM_COLS + NUM_PROGRAM_MULT_COLS } } impl Air for ProgramChip where - AB: SP1AirBuilder, + AB: SP1AirBuilder + PairBuilder, { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let local: &ProgramCols = main.row_slice(0).borrow(); + let preprocessed = builder.preprocessed(); + + let prep_local: &ProgramPreprocessedCols = preprocessed.row_slice(0).borrow(); + let mult_local: &ProgramMultiplicityCols = main.row_slice(0).borrow(); // Dummy constraint of degree 3. builder.assert_eq( - local.pc * local.pc * local.pc, - local.pc * local.pc * local.pc, + prep_local.pc * prep_local.pc * prep_local.pc, + prep_local.pc * prep_local.pc * prep_local.pc, ); - let mut interaction_vals: Vec = vec![local.instruction.opcode.into()]; - interaction_vals.push(local.instruction.op_a.into()); - interaction_vals.extend_from_slice(&local.instruction.op_b.map(|x| x.into()).0); - interaction_vals.extend_from_slice(&local.instruction.op_c.map(|x| x.into()).0); - interaction_vals.push(local.instruction.imm_b.into()); - interaction_vals.push(local.instruction.imm_c.into()); + let mut interaction_vals: Vec = vec![prep_local.instruction.opcode.into()]; + interaction_vals.push(prep_local.instruction.op_a.into()); + interaction_vals.extend_from_slice(&prep_local.instruction.op_b.map(|x| x.into()).0); + interaction_vals.extend_from_slice(&prep_local.instruction.op_c.map(|x| x.into()).0); + interaction_vals.push(prep_local.instruction.imm_b.into()); + interaction_vals.push(prep_local.instruction.imm_c.into()); + interaction_vals.extend_from_slice( + &prep_local + .selectors + .into_iter() + .map(|x| x.into()) + .collect::>(), + ); builder.receive(AirInteraction::new( interaction_vals, - local.multiplicity.into(), + mult_local.multiplicity.into(), InteractionKind::Program, )); } diff --git a/recursion/core/src/runtime/mod.rs b/recursion/core/src/runtime/mod.rs index 96b0cbec0d..1dcfad44a8 100644 --- a/recursion/core/src/runtime/mod.rs +++ b/recursion/core/src/runtime/mod.rs @@ -184,6 +184,10 @@ where }; } + fn get_memory_entry(&self, addr: F) -> &MemoryEntry { + &self.memory[addr.as_canonical_u32() as usize] + } + fn timestamp(&self, position: &MemoryAccessPosition) -> F { self.clk + F::from_canonical_u32(*position as u32) } @@ -257,7 +261,8 @@ where // If b is an immediate, then we store the value at the address in a. self.fp + instruction.op_a } else { - self.mr(self.fp + instruction.op_a, MemoryAccessPosition::A)[0] + index * size + offset + // Load without touching access. This assumes that the caller will call mw on a_ptr. + self.get_memory_entry(self.fp + instruction.op_a).value[0] + index * size + offset }; let b = if instruction.imm_b_base() { @@ -284,6 +289,7 @@ where while self.pc < F::from_canonical_u32(self.program.instructions.len() as u32) { let idx = self.pc.as_canonical_u32() as usize; let instruction = self.program.instructions[idx].clone(); + let mut next_pc = self.pc + F::one(); let (a, b, c): (Block, Block, Block); match instruction.opcode { @@ -368,7 +374,7 @@ where Opcode::LW => { self.nb_memory_ops += 1; let (a_ptr, b_val) = self.load_rr(&instruction); - let prev_a = self.mr(a_ptr, MemoryAccessPosition::A); + let prev_a = self.get_memory_entry(a_ptr).value; let a_val = Block::from([b_val[0], prev_a[1], prev_a[2], prev_a[3]]); self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, Block::default()); @@ -383,7 +389,7 @@ where Opcode::SW => { self.nb_memory_ops += 1; let (a_ptr, b_val) = self.store_rr(&instruction); - let prev_a = self.mr(a_ptr, MemoryAccessPosition::A); + let prev_a = self.get_memory_entry(a_ptr).value; let a_val = Block::from([b_val[0], prev_a[1], prev_a[2], prev_a[3]]); self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, Block::default()); diff --git a/recursion/core/src/runtime/opcode.rs b/recursion/core/src/runtime/opcode.rs index 5776732eb7..1dc88d8f14 100644 --- a/recursion/core/src/runtime/opcode.rs +++ b/recursion/core/src/runtime/opcode.rs @@ -1,7 +1,7 @@ use p3_field::AbstractField; #[allow(clippy::upper_case_acronyms)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Opcode { // Arithmetic field instructions. ADD = 0, diff --git a/recursion/core/src/runtime/record.rs b/recursion/core/src/runtime/record.rs index 119d545c0f..1a9ece21d4 100644 --- a/recursion/core/src/runtime/record.rs +++ b/recursion/core/src/runtime/record.rs @@ -34,7 +34,13 @@ impl MachineRecord for ExecutionRecord { HashMap::new() } - fn append(&mut self, _: &mut Self) {} + fn append(&mut self, other: &mut Self) { + self.cpu_events.append(&mut other.cpu_events); + self.first_memory_record + .append(&mut other.first_memory_record); + self.last_memory_record + .append(&mut other.last_memory_record); + } fn shard(self, _: &Self::Config) -> Vec { vec![self] diff --git a/recursion/program/src/lib.rs b/recursion/program/src/lib.rs index c845773780..0a62ff2f00 100644 --- a/recursion/program/src/lib.rs +++ b/recursion/program/src/lib.rs @@ -1,3 +1,4 @@ +#![allow(incomplete_features)] #![feature(generic_const_exprs)] pub mod challenger; pub mod commit; diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index 70fb98d9b1..9a2a59c36e 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -261,8 +261,10 @@ pub(crate) mod tests { use crate::challenger::CanObserveVariable; use crate::challenger::FeltChallenger; + use crate::stark::Ext; use p3_challenger::{CanObserve, FieldChallenger}; use p3_field::AbstractField; + use rand::Rng; use sp1_core::air::PublicValuesDigest; use sp1_core::runtime::Program; use sp1_core::{ @@ -292,6 +294,10 @@ pub(crate) mod tests { }, }; + use sp1_core::stark::LocalProver; + use sp1_recursion_core::stark::RecursionAir; + use sp1_sdk::utils::setup_logger; + type SC = BabyBearPoseidon2; type F = ::Val; type EF = ::Challenge; @@ -518,6 +524,25 @@ pub(crate) mod tests { let elapsed = time.elapsed(); runtime.print_stats(); println!("Execution took: {:?}", elapsed); + + // let config = BabyBearPoseidon2::new(); + // let machine = RecursionAir::machine(config); + // let (pk, vk) = machine.setup(&program); + // let mut challenger = machine.config().challenger(); + + // // debug_interactions_with_all_chips::>( + // // machine.chips(), + // // &runtime.record, + // // vec![InteractionKind::Memory], + // // ); + + // let start = Instant::now(); + // let proof = machine.prove::>(&pk, runtime.record, &mut challenger); + // let duration = start.elapsed().as_secs(); + + // let mut challenger = machine.config().challenger(); + // machine.verify(&vk, &proof, &mut challenger).unwrap(); + // println!("proving duration = {}", duration); } #[test] @@ -592,4 +617,55 @@ pub(crate) mod tests { runtime.run(); } + + #[test] + #[ignore] + fn test_kitchen_sink() { + setup_logger(); + + let time = Instant::now(); + let mut builder = VmBuilder::::default(); + + let a: Felt<_> = builder.eval(F::from_canonical_u32(23)); + let b: Felt<_> = builder.eval(F::from_canonical_u32(17)); + let a_plus_b = builder.eval(a + b); + let mut rng = rand::thread_rng(); + let a_ext_val = rng.gen::(); + let b_ext_val = rng.gen::(); + let a_ext: Ext<_, _> = builder.eval(a_ext_val.cons()); + let b_ext: Ext<_, _> = builder.eval(b_ext_val.cons()); + let a_plus_b_ext = builder.eval(a_ext + b_ext); + builder.print_f(a_plus_b); + builder.print_e(a_plus_b_ext); + + let program = builder.compile(); + let elapsed = time.elapsed(); + println!("Building took: {:?}", elapsed); + + let machine = A::machine(SC::default()); + let mut runtime = Runtime::::new(&program, machine.config().perm.clone()); + + let time = Instant::now(); + runtime.run(); + let elapsed = time.elapsed(); + runtime.print_stats(); + println!("Execution took: {:?}", elapsed); + + let config = BabyBearPoseidon2::new(); + let machine = RecursionAir::machine(config); + let (pk, vk) = machine.setup(&program); + let mut challenger = machine.config().challenger(); + + let record_clone = runtime.record.clone(); + machine.debug_constraints(&program, &pk, record_clone, &mut challenger); + + let start = Instant::now(); + let mut challenger = machine.config().challenger(); + let proof = machine.prove::>(&pk, runtime.record, &mut challenger); + let duration = start.elapsed().as_secs(); + + let mut challenger = machine.config().challenger(); + machine.verify(&vk, &proof, &mut challenger).unwrap(); + println!("proving duration = {}", duration); + } }