Skip to content

Commit

Permalink
perf: recursion v2 tracegen (#1376)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
tqn and Ubuntu authored Aug 25, 2024
1 parent 4467d6a commit a7afc1b
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 251 deletions.
116 changes: 53 additions & 63 deletions crates/recursion/core-v2/src/chips/alu_base.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use core::borrow::Borrow;
use itertools::Itertools;
use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
use p3_field::{Field, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use sp1_core_machine::utils::pad_to_power_of_two;
use p3_maybe_rayon::prelude::*;
use sp1_core_machine::utils::next_power_of_two;
use sp1_derive::AlignedBorrow;
use sp1_stark::air::MachineAir;
use std::{borrow::BorrowMut, iter::zip};
Expand All @@ -23,6 +23,8 @@ pub struct BaseAluCols<F: Copy> {
pub values: [BaseAluValueCols<F>; NUM_BASE_ALU_ENTRIES_PER_ROW],
}

/// Number of columns occupied by a single [`BaseAluValueCols`] entry.
///
/// Taken as the byte size of the struct instantiated with `u8`; trace generation uses it as the
/// width of each per-event chunk of the flat value buffer.
pub const NUM_BASE_ALU_VALUE_COLS: usize = core::mem::size_of::<BaseAluValueCols<u8>>();

#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluValueCols<F: Copy> {
Expand All @@ -38,6 +40,8 @@ pub struct BaseAluPreprocessedCols<F: Copy> {
pub accesses: [BaseAluAccessCols<F>; NUM_BASE_ALU_ENTRIES_PER_ROW],
}

/// Number of columns occupied by a single [`BaseAluAccessCols`] entry.
///
/// Taken as the byte size of the struct instantiated with `u8`; preprocessed trace generation
/// uses it as the width of each per-instruction chunk of the flat value buffer.
pub const NUM_BASE_ALU_ACCESS_COLS: usize = core::mem::size_of::<BaseAluAccessCols<u8>>();

#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluAccessCols<F: Copy> {
Expand Down Expand Up @@ -69,14 +73,26 @@ impl<F: PrimeField32> MachineAir<F> for BaseAluChip {
}

fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
let rows = program
// Allocating an intermediate `Vec` is faster.
let instrs = program
.instructions
.iter()
.filter_map(|instruction| {
let Instruction::BaseAlu(BaseAluInstr { opcode, mult, addrs }) = instruction else {
return None;
};
let mut access = BaseAluAccessCols {
.iter() // Faster than using `rayon` for some reason. Maybe vectorization?
.filter_map(|instruction| match instruction {
Instruction::BaseAlu(x) => Some(x),
_ => None,
})
.collect::<Vec<_>>();

let nb_rows = instrs.len().div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW);
let padded_nb_rows = next_power_of_two(nb_rows, None);
let mut values = vec![F::zero(); padded_nb_rows * NUM_BASE_ALU_PREPROCESSED_COLS];
// Fill in the access columns of each instruction in parallel, one chunk per instruction.
let populate_len = instrs.len() * NUM_BASE_ALU_ACCESS_COLS;
values[..populate_len].par_chunks_mut(NUM_BASE_ALU_ACCESS_COLS).zip_eq(instrs).for_each(
|(row, instr)| {
let BaseAluInstr { opcode, mult, addrs } = instr;
let access: &mut BaseAluAccessCols<_> = row.borrow_mut();
*access = BaseAluAccessCols {
addrs: addrs.to_owned(),
is_add: F::from_bool(false),
is_sub: F::from_bool(false),
Expand All @@ -91,61 +107,33 @@ impl<F: PrimeField32> MachineAir<F> for BaseAluChip {
BaseAluOpcode::DivF => &mut access.is_div,
};
*target_flag = F::from_bool(true);

Some(access)
})
.chunks(NUM_BASE_ALU_ENTRIES_PER_ROW)
.into_iter()
.map(|row_accesses| {
let mut row = [F::zero(); NUM_BASE_ALU_PREPROCESSED_COLS];
let cols: &mut BaseAluPreprocessedCols<_> = row.as_mut_slice().borrow_mut();
for (cell, access) in zip(&mut cols.accesses, row_accesses) {
*cell = access;
}
row
})
.collect::<Vec<_>>();

// Convert the trace to a row major matrix.
let mut trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_BASE_ALU_PREPROCESSED_COLS,
},
);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_BASE_ALU_PREPROCESSED_COLS, F>(&mut trace.values);

Some(trace)
// Convert the trace to a row major matrix.
Some(RowMajorMatrix::new(values, NUM_BASE_ALU_PREPROCESSED_COLS))
}

/// Intentionally empty: neither the input record nor the output record is touched.
fn generate_dependencies(&self, _input: &Self::Record, _output: &mut Self::Record) {
    // Nothing to do for this chip.
}

// NOTE(review): this span comes from a rendered diff without +/- markers, so statements of the
// replaced implementation (per-row arrays collected into `rows`, then `pad_to_power_of_two`) and
// of the new implementation (one pre-sized flat `values` buffer) appear interleaved below.
/// Builds the main trace of the chip from `input.base_alu_events`.
///
/// In the flat-buffer variant the buffer holds `padded_nb_rows * NUM_BASE_ALU_COLS` zeroed cells
/// and only the first `events.len() * NUM_BASE_ALU_VALUE_COLS` of them are overwritten, so every
/// cell past the last event keeps its zero fill. The second record argument is unused.
fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
let events = &input.base_alu_events;
// Each row has room for `NUM_BASE_ALU_ENTRIES_PER_ROW` events, hence the rounding-up division.
let nb_rows = events.len().div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW);
// Presumably rounds the row count up to a power of two — confirm against `next_power_of_two`.
let padded_nb_rows = next_power_of_two(nb_rows, None);
let mut values = vec![F::zero(); padded_nb_rows * NUM_BASE_ALU_COLS];
// Generate the trace rows & corresponding records for each chunk of events in parallel.
let rows = input
.base_alu_events
.chunks(NUM_BASE_ALU_ENTRIES_PER_ROW)
.map(|row_events| {
let mut row = [F::zero(); NUM_BASE_ALU_COLS];
let cols: &mut BaseAluCols<_> = row.as_mut_slice().borrow_mut();
for (cell, &vals) in zip(&mut cols.values, row_events) {
*cell = BaseAluValueCols { vals };
}

row
})
.collect::<Vec<_>>();
// Only the cells backing real events are written; chunk `i` of the prefix receives event `i`.
let populate_len = events.len() * NUM_BASE_ALU_VALUE_COLS;
values[..populate_len].par_chunks_mut(NUM_BASE_ALU_VALUE_COLS).zip_eq(events).for_each(
|(row, &vals)| {
// Reinterpret the chunk of cells as one value entry and overwrite it in place.
let cols: &mut BaseAluValueCols<_> = row.borrow_mut();
*cols = BaseAluValueCols { vals };
},
);

// Convert the trace to a row major matrix.
let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_BASE_ALU_COLS);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_BASE_ALU_COLS, F>(&mut trace.values);

trace
RowMajorMatrix::new(values, NUM_BASE_ALU_COLS)
}

fn included(&self, _record: &Self::Record) -> bool {
Expand All @@ -165,23 +153,25 @@ where
let prep_local = prep.row_slice(0);
let prep_local: &BaseAluPreprocessedCols<AB::Var> = (*prep_local).borrow();

for (value, access) in zip(local.values, prep_local.accesses) {
let BaseAluValueCols { vals: BaseAluIo { out, in1, in2 } } = value;

for (
BaseAluValueCols { vals: BaseAluIo { out, in1, in2 } },
BaseAluAccessCols { addrs, is_add, is_sub, is_mul, is_div, mult },
) in zip(local.values, prep_local.accesses)
{
// Check exactly one flag is enabled.
let is_real = access.is_add + access.is_sub + access.is_mul + access.is_div;
let is_real = is_add + is_sub + is_mul + is_div;
builder.assert_bool(is_real.clone());

builder.when(access.is_add).assert_eq(in1 + in2, out);
builder.when(access.is_sub).assert_eq(in1, in2 + out);
builder.when(access.is_mul).assert_eq(in1 * in2, out);
builder.when(access.is_div).assert_eq(in1, in2 * out);
builder.when(is_add).assert_eq(in1 + in2, out);
builder.when(is_sub).assert_eq(in1, in2 + out);
builder.when(is_mul).assert_eq(in1 * in2, out);
builder.when(is_div).assert_eq(in1, in2 * out);

builder.receive_single(access.addrs.in1, in1, is_real.clone());
builder.receive_single(addrs.in1, in1, is_real.clone());

builder.receive_single(access.addrs.in2, in2, is_real);
builder.receive_single(addrs.in2, in2, is_real);

builder.send_single(access.addrs.out, out, access.mult);
builder.send_single(addrs.out, out, mult);
}
}
}
Expand Down
139 changes: 79 additions & 60 deletions crates/recursion/core-v2/src/chips/alu_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ use core::borrow::Borrow;
use p3_air::{Air, BaseAir, PairBuilder};
use p3_field::{extension::BinomiallyExtendable, Field, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use sp1_core_machine::utils::pad_to_power_of_two;
use p3_maybe_rayon::prelude::*;
use sp1_core_machine::utils::next_power_of_two;
use sp1_derive::AlignedBorrow;
use sp1_stark::air::{ExtensionAirBuilder, MachineAir};
use std::borrow::BorrowMut;
use std::{borrow::BorrowMut, iter::zip};

use crate::{builder::SP1RecursionAirBuilder, *};

/// Number of extension-field ALU entries stored side by side in one trace row.
///
/// Used as the array length of both the value columns and the preprocessed access columns.
pub const NUM_EXT_ALU_ENTRIES_PER_ROW: usize = 4;

/// Chip for the extension-field ALU operations (`AddE`, `SubE`, `MulE`, `DivE`).
///
/// The struct carries no state of its own.
#[derive(Default)]
pub struct ExtAluChip {}

Expand All @@ -17,6 +20,13 @@ pub const NUM_EXT_ALU_COLS: usize = core::mem::size_of::<ExtAluCols<u8>>();
/// Main-trace columns of one row of the extension-field ALU chip.
///
/// A row holds `NUM_EXT_ALU_ENTRIES_PER_ROW` value entries laid out back to back.
// `#[repr(C)]` matters here: the column-count constants in this file are derived from the
// `size_of` of these structs.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluCols<F: Copy> {
/// Values of every entry packed into this row.
pub values: [ExtAluValueCols<F>; NUM_EXT_ALU_ENTRIES_PER_ROW],
}
/// Number of columns occupied by a single [`ExtAluValueCols`] entry.
///
/// Taken as the byte size of the struct instantiated with `u8`; trace generation uses it as the
/// width of each per-event chunk of the flat value buffer. Exported like the other per-entry
/// column counts (`NUM_EXT_ALU_ACCESS_COLS`, `NUM_BASE_ALU_VALUE_COLS`).
pub const NUM_EXT_ALU_VALUE_COLS: usize = core::mem::size_of::<ExtAluValueCols<u8>>();

/// Values of a single extension-field ALU operation within a row.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluValueCols<F: Copy> {
/// The two inputs (`in1`, `in2`) and the output (`out`) of the operation, each a `Block`.
pub vals: ExtAluIo<Block<F>>,
}

Expand All @@ -25,6 +35,14 @@ pub const NUM_EXT_ALU_PREPROCESSED_COLS: usize = core::mem::size_of::<ExtAluPrep
/// Preprocessed-trace columns of one row of the extension-field ALU chip.
///
/// Holds one access description per entry of the row; the constraints pair `accesses[i]` with
/// the `i`-th value entry of the main-trace row.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluPreprocessedCols<F: Copy> {
/// Access description (addresses, opcode flags, multiplicity) of every entry in this row.
pub accesses: [ExtAluAccessCols<F>; NUM_EXT_ALU_ENTRIES_PER_ROW],
}

/// Number of columns occupied by a single [`ExtAluAccessCols`] entry.
///
/// Taken as the byte size of the struct instantiated with `u8`; preprocessed trace generation
/// uses it as the width of each per-instruction chunk of the flat value buffer.
pub const NUM_EXT_ALU_ACCESS_COLS: usize = core::mem::size_of::<ExtAluAccessCols<u8>>();

#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluAccessCols<F: Copy> {
pub addrs: ExtAluIo<Address<F>>,
pub is_add: F,
pub is_sub: F,
Expand Down Expand Up @@ -53,16 +71,26 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>> MachineAir<F> for ExtAluChip {
}

fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
let rows = program
// Allocating an intermediate `Vec` is faster.
let instrs = program
.instructions
.iter()
.filter_map(|instruction| {
let Instruction::ExtAlu(ExtAluInstr { opcode, mult, addrs }) = instruction else {
return None;
};
let mut row = [F::zero(); NUM_EXT_ALU_PREPROCESSED_COLS];
let cols: &mut ExtAluPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
*cols = ExtAluPreprocessedCols {
.iter() // Faster than using `rayon` for some reason. Maybe vectorization?
.filter_map(|instruction| match instruction {
Instruction::ExtAlu(x) => Some(x),
_ => None,
})
.collect::<Vec<_>>();

let nb_rows = instrs.len().div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW);
let padded_nb_rows = next_power_of_two(nb_rows, None);
let mut values = vec![F::zero(); padded_nb_rows * NUM_EXT_ALU_PREPROCESSED_COLS];
// Fill in the access columns of each instruction in parallel, one chunk per instruction.
let populate_len = instrs.len() * NUM_EXT_ALU_ACCESS_COLS;
values[..populate_len].par_chunks_mut(NUM_EXT_ALU_ACCESS_COLS).zip_eq(instrs).for_each(
|(row, instr)| {
let ExtAluInstr { opcode, mult, addrs } = instr;
let access: &mut ExtAluAccessCols<_> = row.borrow_mut();
*access = ExtAluAccessCols {
addrs: addrs.to_owned(),
is_add: F::from_bool(false),
is_sub: F::from_bool(false),
Expand All @@ -71,54 +99,39 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>> MachineAir<F> for ExtAluChip {
mult: mult.to_owned(),
};
let target_flag = match opcode {
ExtAluOpcode::AddE => &mut cols.is_add,
ExtAluOpcode::SubE => &mut cols.is_sub,
ExtAluOpcode::MulE => &mut cols.is_mul,
ExtAluOpcode::DivE => &mut cols.is_div,
ExtAluOpcode::AddE => &mut access.is_add,
ExtAluOpcode::SubE => &mut access.is_sub,
ExtAluOpcode::MulE => &mut access.is_mul,
ExtAluOpcode::DivE => &mut access.is_div,
};
*target_flag = F::from_bool(true);

Some(row)
})
.collect::<Vec<_>>();

// Convert the trace to a row major matrix.
let mut trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_EXT_ALU_PREPROCESSED_COLS,
},
);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_EXT_ALU_PREPROCESSED_COLS, F>(&mut trace.values);

Some(trace)
// Convert the trace to a row major matrix.
Some(RowMajorMatrix::new(values, NUM_EXT_ALU_PREPROCESSED_COLS))
}

/// Intentionally empty: neither the input record nor the output record is touched.
fn generate_dependencies(&self, _input: &Self::Record, _output: &mut Self::Record) {
    // Nothing to do for this chip.
}

// NOTE(review): this span comes from a rendered diff without +/- markers, so statements of the
// replaced implementation (cloned events mapped to per-row arrays, then `pad_to_power_of_two`)
// and of the new implementation (one pre-sized flat `values` buffer) appear interleaved below.
/// Builds the main trace of the chip from `input.ext_alu_events`.
///
/// In the flat-buffer variant the buffer holds `padded_nb_rows * NUM_EXT_ALU_COLS` zeroed cells
/// and only the first `events.len() * NUM_EXT_ALU_VALUE_COLS` of them are overwritten, so every
/// cell past the last event keeps its zero fill. The second record argument is unused.
fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
let ext_alu_events = input.ext_alu_events.clone();

let events = &input.ext_alu_events;
// Each row has room for `NUM_EXT_ALU_ENTRIES_PER_ROW` events, hence the rounding-up division.
let nb_rows = events.len().div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW);
// Presumably rounds the row count up to a power of two — confirm against `next_power_of_two`.
let padded_nb_rows = next_power_of_two(nb_rows, None);
let mut values = vec![F::zero(); padded_nb_rows * NUM_EXT_ALU_COLS];
// Generate the trace rows & corresponding records for each chunk of events in parallel.
let rows = ext_alu_events
.into_iter()
.map(|vals| {
let mut row = [F::zero(); NUM_EXT_ALU_COLS];
*row.as_mut_slice().borrow_mut() = ExtAluCols { vals };
row
})
.collect::<Vec<_>>();
// Only the cells backing real events are written; chunk `i` of the prefix receives event `i`.
let populate_len = events.len() * NUM_EXT_ALU_VALUE_COLS;
values[..populate_len].par_chunks_mut(NUM_EXT_ALU_VALUE_COLS).zip_eq(events).for_each(
|(row, &vals)| {
// Reinterpret the chunk of cells as one value entry and overwrite it in place.
let cols: &mut ExtAluValueCols<_> = row.borrow_mut();
*cols = ExtAluValueCols { vals };
},
);

// Convert the trace to a row major matrix.
let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_EXT_ALU_COLS);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_EXT_ALU_COLS, F>(&mut trace.values);

trace
RowMajorMatrix::new(values, NUM_EXT_ALU_COLS)
}

fn included(&self, _record: &Self::Record) -> bool {
Expand All @@ -138,26 +151,32 @@ where
let prep_local = prep.row_slice(0);
let prep_local: &ExtAluPreprocessedCols<AB::Var> = (*prep_local).borrow();

// Check exactly one flag is enabled.
let is_real = prep_local.is_add + prep_local.is_sub + prep_local.is_mul + prep_local.is_div;
builder.assert_bool(is_real.clone());
for (
ExtAluValueCols { vals },
ExtAluAccessCols { addrs, is_add, is_sub, is_mul, is_div, mult },
) in zip(local.values, prep_local.accesses)
{
let in1 = vals.in1.as_extension::<AB>();
let in2 = vals.in2.as_extension::<AB>();
let out = vals.out.as_extension::<AB>();

let in1 = local.vals.in1.as_extension::<AB>();
let in2 = local.vals.in2.as_extension::<AB>();
let out = local.vals.out.as_extension::<AB>();
// Check exactly one flag is enabled.
let is_real = is_add + is_sub + is_mul + is_div;
builder.assert_bool(is_real.clone());

builder.when(prep_local.is_add).assert_ext_eq(in1.clone() + in2.clone(), out.clone());
builder.when(prep_local.is_sub).assert_ext_eq(in1.clone(), in2.clone() + out.clone());
builder.when(prep_local.is_mul).assert_ext_eq(in1.clone() * in2.clone(), out.clone());
builder.when(prep_local.is_div).assert_ext_eq(in1, in2 * out);
builder.when(is_add).assert_ext_eq(in1.clone() + in2.clone(), out.clone());
builder.when(is_sub).assert_ext_eq(in1.clone(), in2.clone() + out.clone());
builder.when(is_mul).assert_ext_eq(in1.clone() * in2.clone(), out.clone());
builder.when(is_div).assert_ext_eq(in1, in2 * out);

// Read the inputs from memory.
builder.receive_block(prep_local.addrs.in1, local.vals.in1, is_real.clone());
// Read the inputs from memory.
builder.receive_block(addrs.in1, vals.in1, is_real.clone());

builder.receive_block(prep_local.addrs.in2, local.vals.in2, is_real);
builder.receive_block(addrs.in2, vals.in2, is_real);

// Write the output to memory.
builder.send_block(prep_local.addrs.out, local.vals.out, prep_local.mult);
// Write the output to memory.
builder.send_block(addrs.out, vals.out, mult);
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/recursion/core-v2/src/chips/mem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use sp1_derive::AlignedBorrow;

use crate::Address;

/// Number of columns occupied by a single [`MemoryAccessCols`], taken as the byte size of the
/// struct instantiated with `u8`.
pub const NUM_MEM_ACCESS_COLS: usize = core::mem::size_of::<MemoryAccessCols<u8>>();

/// Data describing in what manner to access a particular memory block.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
Expand Down
Loading

0 comments on commit a7afc1b

Please sign in to comment.