Skip to content

Commit

Permalink
feat: recursion program table + memory tracing (#356)
Browse files Browse the repository at this point in the history
Co-authored-by: John Guibas <[email protected]>
  • Loading branch information
jtguibas and John Guibas authored Mar 8, 2024
1 parent e56ab93 commit 428120a
Show file tree
Hide file tree
Showing 10 changed files with 373 additions and 39 deletions.
17 changes: 16 additions & 1 deletion recursion/core/src/cpu/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use p3_air::BaseAir;
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::MatrixRowSlices;
use sp1_core::air::AirInteraction;
use sp1_core::lookup::InteractionKind;
use sp1_core::stark::SP1AirBuilder;
use sp1_core::{air::MachineAir, utils::pad_to_power_of_two};
use std::borrow::Borrow;
Expand Down Expand Up @@ -42,6 +44,7 @@ impl<F: PrimeField32> MachineAir<F> for CpuChip<F> {
cols.instruction.op_c = event.instruction.op_c;
cols.instruction.imm_b = F::from_canonical_u32(event.instruction.imm_b as u32);
cols.instruction.imm_c = F::from_canonical_u32(event.instruction.imm_c as u32);
cols.is_real = F::one();
row
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -72,7 +75,19 @@ where
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let _: &CpuCols<AB::Var> = main.row_slice(0).borrow();
let local: &CpuCols<AB::Var> = main.row_slice(0).borrow();
let _: &CpuCols<AB::Var> = main.row_slice(1).borrow();
builder.send(AirInteraction::new(
vec![
local.instruction.opcode.into(),
local.instruction.op_a.into(),
local.instruction.op_b.into(),
local.instruction.op_c.into(),
local.instruction.imm_b.into(),
local.instruction.imm_c.into(),
],
local.is_real.into(),
InteractionKind::Program,
));
}
}
4 changes: 3 additions & 1 deletion recursion/core/src/cpu/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use sp1_core::operations::IsZeroOperation;
use sp1_derive::AlignedBorrow;

/// The column layout for the chip.
#[derive(AlignedBorrow, Default, Clone, Copy)]
#[derive(AlignedBorrow, Default, Clone, Copy, Debug)]
#[repr(C)]
pub struct CpuCols<T> {
pub clk: T,
Expand Down Expand Up @@ -54,6 +54,8 @@ pub struct CpuCols<T> {

// c = a == b;
pub a_eq_b: IsZeroOperation<T>,

pub is_real: T,
}

#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
Expand Down
22 changes: 8 additions & 14 deletions recursion/core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
pub mod air;
pub mod cpu;
pub mod memory;
pub mod program;
pub mod runtime;
pub mod stark;

#[cfg(test)]
pub mod tests {
use crate::runtime::{ExecutionRecord, Instruction, Opcode, Program, Runtime};
use crate::runtime::{Instruction, Opcode, Program, Runtime};
use crate::stark::RecursionAir;

use p3_baby_bear::BabyBear;
Expand All @@ -18,9 +19,9 @@ pub mod tests {

pub fn fibonacci_program<F: PrimeField32>() -> Program<F> {
// .main
// imm 0(fp) 1 <-- a = 1
// imm 1(fp) 1 <-- b = 1
// imm 2(fp) 10 <-- iterations = 10
// imm 0(fp) 1 <-- a = 1
// imm 1(fp) 1 <-- b = 1
// imm 2(fp) 10 <-- iterations = 10
// .body:
// add 3(fp) 0(fp) 1(fp) <-- tmp = a + b
// sw 0(fp) 1(fp) <-- a = b
Expand All @@ -45,18 +46,11 @@ pub mod tests {

#[test]
fn test_fibonacci_execute() {
let program = fibonacci_program();
let mut runtime = Runtime::<BabyBear> {
clk: BabyBear::zero(),
program,
fp: BabyBear::zero(),
pc: BabyBear::zero(),
memory: vec![BabyBear::zero(); 1024 * 1024],
record: ExecutionRecord::<BabyBear>::default(),
};
let program = fibonacci_program::<BabyBear>();
let mut runtime = Runtime::new(&program);
runtime.run();
println!("{:#?}", runtime.record.cpu_events);
assert_eq!(runtime.memory[1], BabyBear::from_canonical_u32(144));
assert_eq!(runtime.memory[1].value, BabyBear::from_canonical_u32(144));
}

#[test]
Expand Down
119 changes: 119 additions & 0 deletions recursion/core/src/memory/global.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use core::mem::size_of;
use p3_air::{Air, BaseAir};
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::MatrixRowSlices;
use sp1_core::air::SP1AirBuilder;
use sp1_core::utils::indices_arr;
use sp1_core::{air::MachineAir, utils::pad_to_power_of_two};
use sp1_derive::AlignedBorrow;
use std::borrow::Borrow;
use std::mem::transmute;

use crate::memory::Word;
use crate::runtime::ExecutionRecord;

#[allow(dead_code)]
#[derive(PartialEq)]
pub enum MemoryChipKind {
Init,
Finalize,
Program,
}

pub struct MemoryGlobalChip {
pub kind: MemoryChipKind,
}

#[allow(dead_code)]
impl MemoryGlobalChip {
pub fn new(kind: MemoryChipKind) -> Self {
Self { kind }
}
}

impl<F> BaseAir<F> for MemoryGlobalChip {
fn width(&self) -> usize {
NUM_MEMORY_INIT_COLS
}
}

impl<F: PrimeField32> MachineAir<F> for MemoryGlobalChip {
type Record = ExecutionRecord<F>;

fn name(&self) -> String {
match self.kind {
MemoryChipKind::Init => "MemoryInit".to_string(),
MemoryChipKind::Finalize => "MemoryFinalize".to_string(),
MemoryChipKind::Program => "MemoryProgram".to_string(),
}
}

#[allow(unused_variables)]
fn generate_trace(
&self,
input: &ExecutionRecord<F>,
_output: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
let memory_record = match self.kind {
MemoryChipKind::Init => &input.first_memory_record,
MemoryChipKind::Finalize => &input.last_memory_record,
MemoryChipKind::Program => &input.program_memory_record,
};
let rows: Vec<[F; 8]> = (0..memory_record.len()) // TODO: change this back to par_iter
.map(|i| [F::zero(); NUM_MEMORY_INIT_COLS])
.collect::<Vec<_>>();

let mut trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_MEMORY_INIT_COLS,
);

pad_to_power_of_two::<NUM_MEMORY_INIT_COLS, F>(&mut trace.values);

trace
}

fn included(&self, shard: &Self::Record) -> bool {
match self.kind {
MemoryChipKind::Init => !shard.first_memory_record.is_empty(),
MemoryChipKind::Finalize => !shard.last_memory_record.is_empty(),
MemoryChipKind::Program => !shard.program_memory_record.is_empty(),
}
}
}

#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
#[repr(C)]
pub struct MemoryInitCols<T> {
pub shard: T,
pub timestamp: T,
pub addr: T,
pub value: Word<T>,
pub is_real: T,
}

pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::<MemoryInitCols<u8>>();
#[allow(dead_code)]
pub(crate) const MEMORY_INIT_COL_MAP: MemoryInitCols<usize> = make_col_map();

const fn make_col_map() -> MemoryInitCols<usize> {
let indices_arr = indices_arr::<NUM_MEMORY_INIT_COLS>();
unsafe { transmute::<[usize; NUM_MEMORY_INIT_COLS], MemoryInitCols<usize>>(indices_arr) }
}

impl<AB> Air<AB> for MemoryGlobalChip
where
AB: SP1AirBuilder,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &MemoryInitCols<AB::Var> = main.row_slice(0).borrow();

// Dummy constraint of degree 3.
builder.assert_eq(
local.is_real * local.is_real * local.is_real,
local.is_real * local.is_real * local.is_real,
);
}
}
3 changes: 3 additions & 0 deletions recursion/core/src/memory/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
mod global;

use crate::air::Word;
use sp1_derive::AlignedBorrow;
use std::mem::size_of;

#[derive(Debug, Clone)]
pub struct MemoryRecord<F> {
pub addr: F,
pub value: F,
pub timestamp: F,
pub prev_value: F,
Expand Down
123 changes: 123 additions & 0 deletions recursion/core/src/program/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
use crate::{cpu::columns::InstructionCols, runtime::ExecutionRecord};
use core::mem::size_of;
use hashbrown::HashMap;
use p3_air::{Air, BaseAir};
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::MatrixRowSlices;
use sp1_core::lookup::InteractionKind;
use sp1_core::{
air::{AirInteraction, MachineAir, SP1AirBuilder},
utils::pad_to_power_of_two,
};
use sp1_derive::AlignedBorrow;
use std::borrow::Borrow;
use std::borrow::BorrowMut;

pub const NUM_PROGRAM_COLS: usize = size_of::<ProgramCols<u8>>();

#[derive(Default)]
pub struct ProgramChip;

#[derive(AlignedBorrow, Clone, Copy, Debug, Default)]
#[repr(C)]
pub struct ProgramCols<T> {
pub pc: T,
pub instruction: InstructionCols<T>,
pub multiplicity: T,
}

impl<F: PrimeField32> MachineAir<F> for ProgramChip {
type Record = ExecutionRecord<F>;

fn name(&self) -> String {
"Program".to_string()
}

fn generate_trace(
&self,
input: &ExecutionRecord<F>,
_output: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
let mut instruction_counts = HashMap::new();
input.cpu_events.iter().for_each(|event| {
let pc = event.pc;
instruction_counts
.entry(pc)
.and_modify(|count| *count += 1)
.or_insert(1);
});
let rows = input
.program
.instructions
.clone()
.into_iter()
.enumerate()
.map(|(i, instruction)| {
let pc = F::from_canonical_u32(i as u32);
let mut row = [F::zero(); NUM_PROGRAM_COLS];
let cols: &mut ProgramCols<F> = row.as_mut_slice().borrow_mut();
cols.pc = pc;
cols.instruction.opcode = F::from_canonical_u32(instruction.opcode as u32);
cols.instruction.op_a = instruction.op_a;
cols.instruction.op_b = instruction.op_b;
cols.instruction.op_c = instruction.op_c;
cols.instruction.imm_b = F::from_bool(instruction.imm_b);
cols.instruction.imm_c = F::from_bool(instruction.imm_c);
cols.multiplicity =
F::from_canonical_usize(*instruction_counts.get(&cols.pc).unwrap_or(&0));
row
})
.collect::<Vec<_>>();

// Convert the trace to a row major matrix.
let mut trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_PROGRAM_COLS,
);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_PROGRAM_COLS, F>(&mut trace.values);

trace
}

fn included(&self, _: &Self::Record) -> bool {
true
}
}

impl<F> BaseAir<F> for ProgramChip {
fn width(&self) -> usize {
NUM_PROGRAM_COLS
}
}

impl<AB> Air<AB> for ProgramChip
where
AB: SP1AirBuilder,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &ProgramCols<AB::Var> = main.row_slice(0).borrow();

// Dummy constraint of degree 3.
builder.assert_eq(
local.pc * local.pc * local.pc,
local.pc * local.pc * local.pc,
);

builder.receive(AirInteraction::new(
vec![
local.instruction.opcode.into(),
local.instruction.op_a.into(),
local.instruction.op_b.into(),
local.instruction.op_c.into(),
local.instruction.imm_b.into(),
local.instruction.imm_c.into(),
],
local.multiplicity.into(),
InteractionKind::Program,
));
}
}
Loading

0 comments on commit 428120a

Please sign in to comment.