Skip to content

Commit

Permalink
feat: recursion cpu constraints (#464)
Browse files Browse the repository at this point in the history
Co-authored-by: John Guibas <[email protected]>
  • Loading branch information
ctian1 and jtguibas authored Apr 3, 2024
1 parent 813e2d7 commit 1e7bbfa
Show file tree
Hide file tree
Showing 17 changed files with 337 additions and 98 deletions.
2 changes: 1 addition & 1 deletion core/src/air/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
/// The execution record containing events for producing the air trace.
type Record: MachineRecord;

type Program;
type Program: Send + Sync;

/// A unique identifier for this AIR as part of a machine.
fn name(&self) -> String;
Expand Down
18 changes: 15 additions & 3 deletions core/src/lookup/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ fn field_to_int<F: PrimeField32>(x: F) -> i32 {

pub fn debug_interactions<SC: StarkGenericConfig, A: MachineAir<Val<SC>>>(
chip: &MachineChip<SC, A>,
program: &A::Program,
record: &A::Record,
interaction_kinds: Vec<InteractionKind>,
) -> (
Expand All @@ -58,6 +59,7 @@ pub fn debug_interactions<SC: StarkGenericConfig, A: MachineAir<Val<SC>>>(
let mut key_to_count = BTreeMap::new();

let trace = chip.generate_trace(record, &mut A::Record::default());
let mut preprocessed_trace = chip.generate_preprocessed_trace(program);
let mut main = trace.clone();
let height = trace.clone().height();

Expand All @@ -72,13 +74,21 @@ pub fn debug_interactions<SC: StarkGenericConfig, A: MachineAir<Val<SC>>>(
if !interaction_kinds.contains(&interaction.kind) {
continue;
}
let mut empty = vec![];
let preprocessed_row = preprocessed_trace
.as_mut()
.map(|t| t.row_mut(row))
.or_else(|| Some(&mut empty))
.unwrap();
let is_send = m < nb_send_interactions;
let multiplicity_eval: Val<SC> = interaction.multiplicity.apply(&[], main.row_mut(row));
let multiplicity_eval: Val<SC> = interaction
.multiplicity
.apply(preprocessed_row, main.row_mut(row));

if !multiplicity_eval.is_zero() {
let mut values = vec![];
for value in &interaction.values {
let expr: Val<SC> = value.apply(&[], main.row_mut(row));
let expr: Val<SC> = value.apply(preprocessed_row, main.row_mut(row));
values.push(expr);
}
let key = format!(
Expand Down Expand Up @@ -114,6 +124,7 @@ pub fn debug_interactions<SC: StarkGenericConfig, A: MachineAir<Val<SC>>>(
/// and print out the ones for which the set of sends and receives don't match.
pub fn debug_interactions_with_all_chips<SC, A>(
chips: &[MachineChip<SC, A>],
program: &A::Program,
segment: &A::Record,
interaction_kinds: Vec<InteractionKind>,
) -> bool
Expand All @@ -127,7 +138,8 @@ where
let mut total = SC::Val::zero();

for chip in chips.iter() {
let (_, count) = debug_interactions::<SC, A>(chip, segment, interaction_kinds.clone());
let (_, count) =
debug_interactions::<SC, A>(chip, program, segment, interaction_kinds.clone());

tracing::info!("{} chip has {} distinct events", chip.name(), count.len());
for (key, value) in count.iter() {
Expand Down
4 changes: 4 additions & 0 deletions core/src/memory/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,15 @@ mod tests {
fn test_memory_lookup_interactions() {
setup_logger();
let program = sha_extend_program();
let program_clone = program.clone();
let mut runtime = Runtime::new(program);
runtime.run();

let machine: crate::stark::MachineStark<BabyBearPoseidon2, RiscvAir<BabyBear>> =
RiscvAir::machine(BabyBearPoseidon2::new());
debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
machine.chips(),
&program_clone,
&runtime.record,
vec![InteractionKind::Memory],
);
Expand All @@ -245,12 +247,14 @@ mod tests {
fn test_byte_lookup_interactions() {
setup_logger();
let program = sha_extend_program();
let program_clone = program.clone();
let mut runtime = Runtime::new(program);
runtime.run();

let machine = RiscvAir::machine(BabyBearPoseidon2::new());
debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
machine.chips(),
&program_clone,
&runtime.record,
vec![InteractionKind::Byte],
);
Expand Down
21 changes: 14 additions & 7 deletions core/src/stark/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {

pub fn debug_constraints(
&self,
program: &A::Program,
pk: &ProvingKey<SC>,
record: A::Record,
challenger: &mut SC::Challenger,
Expand All @@ -294,7 +295,12 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
// Generate the main trace for each chip.
let traces = chips
.par_iter()
.map(|chip| chip.generate_trace(shard, &mut A::Record::default()))
.map(|chip| {
(
chip.generate_trace(shard, &mut A::Record::default()),
chip.generate_preprocessed_trace(program),
)
})
.collect::<Vec<_>>();

// Get a permutation challenge.
Expand All @@ -311,9 +317,9 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
chips
.par_iter()
.zip(traces.par_iter())
.map(|(chip, main_trace)| {
.map(|(chip, (main_trace, pre_trace))| {
let perm_trace = chip.generate_permutation_trace(
None,
pre_trace.as_ref(),
main_trace,
&permutation_challenges,
);
Expand All @@ -331,15 +337,15 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {

// Compute some statistics.
for i in 0..chips.len() {
let trace_width = traces[i].width();
let trace_width = traces[i].0.width();
let permutation_width = permutation_traces[i].width();
let total_width = trace_width + permutation_width;
tracing::debug!(
"{:<11} | Cols = {:<5} | Rows = {:<5} | Cells = {:<10} | Main Cols = {:.2}% | Perm Cols = {:.2}%",
chips[i].name(),
total_width,
traces[i].height(),
total_width * traces[i].height(),
traces[i].0.height(),
total_width * traces[i].0.height(),
(100f32 * trace_width as f32) / total_width as f32,
(100f32 * permutation_width as f32) / total_width as f32);
}
Expand All @@ -353,7 +359,7 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
debug_constraints::<SC, A>(
chips[i],
permutation_trace,
&traces[i],
&traces[i].0,
&permutation_traces[i],
&permutation_challenges,
PublicValuesDigest::<Word<Val<SC>>>::new(shard.public_values_digest()),
Expand All @@ -371,6 +377,7 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
}
debug_interactions_with_all_chips::<SC, A>(
self.chips(),
program,
&record,
InteractionKind::all_kinds(),
);
Expand Down
4 changes: 2 additions & 2 deletions core/src/utils/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub fn run_test_core(
let nb_bytes = bincode::serialize(&proof).unwrap().len();

#[cfg(feature = "debug")]
machine.debug_constraints(&pk, record_clone, &mut challenger);
machine.debug_constraints(&runtime.program, &pk, record_clone, &mut challenger);

let mut challenger = machine.config().challenger();
machine.verify(&vk, &proof, &mut challenger)?;
Expand Down Expand Up @@ -137,7 +137,7 @@ where
#[cfg(feature = "debug")]
{
let record_clone = runtime.record.clone();
machine.debug_constraints(&pk, record_clone, &mut challenger);
machine.debug_constraints(&program, &pk, record_clone, &mut challenger);
}
let public_values = std::mem::take(&mut runtime.state.public_values_stream);
let proof = prove_core(machine.config().clone(), runtime);
Expand Down
77 changes: 56 additions & 21 deletions recursion/core/src/cpu/air.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::air::BinomialExtensionUtils;
use crate::air::BlockBuilder;
use crate::cpu::CpuChip;
use crate::runtime::Program;
Expand All @@ -11,10 +12,12 @@ use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use p3_matrix::MatrixRowSlices;
use sp1_core::air::AirInteraction;
use sp1_core::air::BinomialExtension;
use sp1_core::air::MachineAir;
use sp1_core::lookup::InteractionKind;
use sp1_core::stark::SP1AirBuilder;
use sp1_core::utils::indices_arr;
use sp1_core::{air::MachineAir, utils::pad_to_power_of_two};
use sp1_core::utils::pad_rows;
use std::borrow::Borrow;
use std::borrow::BorrowMut;
use std::mem::transmute;
Expand Down Expand Up @@ -44,7 +47,7 @@ impl<F: PrimeField32> MachineAir<F> for CpuChip<F> {
input: &ExecutionRecord<F>,
_: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
let rows = input
let mut rows = input
.cpu_events
.iter()
.map(|event| {
Expand All @@ -55,6 +58,8 @@ impl<F: PrimeField32> MachineAir<F> for CpuChip<F> {
cols.pc = event.pc;
cols.fp = event.fp;

cols.selectors.populate(&event.instruction);

cols.instruction.opcode = F::from_canonical_u32(event.instruction.opcode as u32);
cols.instruction.op_a = event.instruction.op_a;
cols.instruction.op_b = event.instruction.op_b;
Expand All @@ -76,18 +81,22 @@ impl<F: PrimeField32> MachineAir<F> for CpuChip<F> {
cols.c.value = event.instruction.op_c;
}

// cols.add_scratch = cols.b.value.0[0] + cols.c.value.0[0];
// cols.sub_scratch = cols.b.value.0[0] - cols.c.value.0[0];
// cols.mul_scratch = cols.b.value.0[0] * cols.c.value.0[0];
// cols.add_ext_scratch = (BinomialExtension::from_block(cols.b.value)
// + BinomialExtension::from_block(cols.c.value))
// .as_block();
// cols.sub_ext_scratch = (BinomialExtension::from_block(cols.b.value)
// - BinomialExtension::from_block(cols.c.value))
// .as_block();
// cols.mul_ext_scratch = (BinomialExtension::from_block(cols.b.value)
// * BinomialExtension::from_block(cols.c.value))
// .as_block();
let alu_cols = cols.opcode_specific.alu_mut();
if cols.selectors.is_add.is_one() {
alu_cols.add_scratch.0[0] = cols.b.value.0[0] + cols.c.value.0[0];
alu_cols.sub_scratch.0[0] = cols.b.value.0[0] - cols.c.value.0[0];
alu_cols.mul_scratch.0[0] = cols.b.value.0[0] * cols.c.value.0[0];
} else if cols.selectors.is_eadd.is_one() || cols.selectors.is_efadd.is_one() {
alu_cols.add_scratch = (BinomialExtension::from_block(cols.b.value)
+ BinomialExtension::from_block(cols.c.value))
.as_block();
alu_cols.sub_scratch = (BinomialExtension::from_block(cols.b.value)
- BinomialExtension::from_block(cols.c.value))
.as_block();
alu_cols.mul_scratch = (BinomialExtension::from_block(cols.b.value)
* BinomialExtension::from_block(cols.c.value))
.as_block();
}

// cols.a_eq_b
// .populate((cols.a.value.0[0] - cols.b.value.0[0]).as_canonical_u32());
Expand All @@ -101,12 +110,16 @@ impl<F: PrimeField32> MachineAir<F> for CpuChip<F> {
})
.collect::<Vec<_>>();

pad_rows(&mut rows, || {
let mut row = [F::zero(); NUM_CPU_COLS];
let cols: &mut CpuCols<F> = row.as_mut_slice().borrow_mut();
cols.selectors.is_noop = F::one();
row
});

let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_CPU_COLS);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_CPU_COLS, F>(&mut trace.values);

for i in input.cpu_events.len()..trace.height() {
trace.values[i * NUM_CPU_COLS + CPU_COL_MAP.clk] =
F::from_canonical_u32(4) * F::from_canonical_usize(i);
Expand Down Expand Up @@ -148,7 +161,7 @@ where
builder
.when_transition()
.when(next.is_real)
.assert_eq(local.clk + AB::F::from_canonical_u32(4), next.clk);
.assert_eq(local.clk.into() + AB::F::from_canonical_u32(4), next.clk);

// // Increment pc by 1 every cycle unless it is a branch instruction that is satisfied.
// builder
Expand All @@ -168,9 +181,24 @@ where
.assert_block_eq::<AB::Var, AB::Var>(local.c.value, local.instruction.op_c);

// Compute ALU.
// builder.assert_eq(local.b.value.0[0] + local.c.value.0[0], local.add_scratch);
// builder.assert_eq(local.b.value.0[0] - local.c.value.0[0], local.sub_scratch);
// builder.assert_eq(local.b.value.0[0] * local.c.value.0[0], local.mul_scratch);
let alu_cols = local.opcode_specific.alu();
builder.when(local.selectors.is_add).assert_eq(
local.b.value.0[0] + local.c.value.0[0],
alu_cols.add_scratch[0],
);
builder.when(local.selectors.is_add).assert_eq(
local.b.value.0[0] - local.c.value.0[0],
alu_cols.sub_scratch[0],
);
builder.when(local.selectors.is_add).assert_eq(
local.b.value.0[0] * local.c.value.0[0],
alu_cols.mul_scratch[0],
);

builder.assert_eq(
local.is_real * local.is_real * local.is_real,
local.is_real * local.is_real * local.is_real,
);

// // Compute extension ALU.
// builder.assert_ext_eq(
Expand Down Expand Up @@ -318,6 +346,13 @@ where
prog_interaction_vals.extend_from_slice(&local.instruction.op_c.map(|x| x.into()).0);
prog_interaction_vals.push(local.instruction.imm_b.into());
prog_interaction_vals.push(local.instruction.imm_c.into());
prog_interaction_vals.extend_from_slice(
&local
.selectors
.into_iter()
.map(|x| x.into())
.collect::<Vec<_>>(),
);
builder.send(AirInteraction::new(
prog_interaction_vals,
local.is_real.into(),
Expand Down
15 changes: 3 additions & 12 deletions recursion/core/src/cpu/columns/alu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,11 @@ pub struct AluCols<T> {
pub ext_b: Block<T>,

// c = a + b;
pub add_scratch: T,
pub add_scratch: Block<T>,

// c = a - b;
pub sub_scratch: T,
pub sub_scratch: Block<T>,

// c = a * b;
pub mul_scratch: T,

// ext(c) = ext(a) + ext(b);
pub add_ext_scratch: Block<T>,

// ext(c) = ext(a) - ext(b);
pub sub_ext_scratch: Block<T>,

// ext(c) = ext(a) * ext(b);
pub mul_ext_scratch: Block<T>,
pub mul_scratch: Block<T>,
}
2 changes: 2 additions & 0 deletions recursion/core/src/cpu/columns/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ impl<F: PrimeField> InstructionCols<F> {
self.op_a = instruction.op_a;
self.op_b = instruction.op_b;
self.op_c = instruction.op_c;
self.imm_b = F::from_bool(instruction.imm_b);
self.imm_c = F::from_bool(instruction.imm_c)
}
}

Expand Down
6 changes: 4 additions & 2 deletions recursion/core/src/cpu/columns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ pub use alu::*;
pub use instruction::*;
pub use opcode::*;

use self::opcode_specific::OpcodeSpecificCols;

/// The column layout for the chip.
#[derive(AlignedBorrow, Default, Clone, Debug)]
#[repr(C)]
pub struct CpuCols<T> {
pub struct CpuCols<T: Copy> {
pub clk: T,
pub pc: T,
pub fp: T,
Expand All @@ -27,7 +29,7 @@ pub struct CpuCols<T> {
pub b: MemoryReadWriteCols<T>,
pub c: MemoryReadWriteCols<T>,

pub alu: AluCols<T>,
pub opcode_specific: OpcodeSpecificCols<T>,

// result = operand_1 == operand_2;
pub eq_1_2: IsExtZeroOperation<T>,
Expand Down
Loading

0 comments on commit 1e7bbfa

Please sign in to comment.