From a7afc1bbffbc14c7b14441150424f47981709f48 Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Sat, 24 Aug 2024 18:46:22 -0700 Subject: [PATCH] perf: recursion v2 tracegen (#1376) Co-authored-by: Ubuntu --- .../recursion/core-v2/src/chips/alu_base.rs | 116 +++++++-------- crates/recursion/core-v2/src/chips/alu_ext.rs | 139 ++++++++++-------- crates/recursion/core-v2/src/chips/mem/mod.rs | 2 + .../core-v2/src/chips/mem/variable.rs | 53 +++---- .../core-v2/src/chips/poseidon2_wide/trace.rs | 115 ++++++++------- crates/recursion/core-v2/src/lib.rs | 4 +- crates/stark/src/machine.rs | 51 +++---- crates/stark/src/prover.rs | 28 ++-- 8 files changed, 257 insertions(+), 251 deletions(-) diff --git a/crates/recursion/core-v2/src/chips/alu_base.rs b/crates/recursion/core-v2/src/chips/alu_base.rs index fef1875fad..ccce0581ed 100644 --- a/crates/recursion/core-v2/src/chips/alu_base.rs +++ b/crates/recursion/core-v2/src/chips/alu_base.rs @@ -1,9 +1,9 @@ use core::borrow::Borrow; -use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir, PairBuilder}; use p3_field::{Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use sp1_core_machine::utils::pad_to_power_of_two; +use p3_maybe_rayon::prelude::*; +use sp1_core_machine::utils::next_power_of_two; use sp1_derive::AlignedBorrow; use sp1_stark::air::MachineAir; use std::{borrow::BorrowMut, iter::zip}; @@ -23,6 +23,8 @@ pub struct BaseAluCols { pub values: [BaseAluValueCols; NUM_BASE_ALU_ENTRIES_PER_ROW], } +pub const NUM_BASE_ALU_VALUE_COLS: usize = core::mem::size_of::>(); + #[derive(AlignedBorrow, Debug, Clone, Copy)] #[repr(C)] pub struct BaseAluValueCols { @@ -38,6 +40,8 @@ pub struct BaseAluPreprocessedCols { pub accesses: [BaseAluAccessCols; NUM_BASE_ALU_ENTRIES_PER_ROW], } +pub const NUM_BASE_ALU_ACCESS_COLS: usize = core::mem::size_of::>(); + #[derive(AlignedBorrow, Debug, Clone, Copy)] #[repr(C)] pub struct BaseAluAccessCols { @@ -69,14 +73,26 @@ impl MachineAir for BaseAluChip { } fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option> { - let rows = program + // Allocating an intermediate `Vec` is faster. + let instrs = program .instructions - .iter() - .filter_map(|instruction| { - let Instruction::BaseAlu(BaseAluInstr { opcode, mult, addrs }) = instruction else { - return None; - }; - let mut access = BaseAluAccessCols { + .iter() // Faster than using `rayon` for some reason. Maybe vectorization? + .filter_map(|instruction| match instruction { + Instruction::BaseAlu(x) => Some(x), + _ => None, + }) + .collect::>(); + + let nb_rows = instrs.len().div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW); + let padded_nb_rows = next_power_of_two(nb_rows, None); + let mut values = vec![F::zero(); padded_nb_rows * NUM_BASE_ALU_PREPROCESSED_COLS]; + // Generate the trace rows & corresponding records for each chunk of events in parallel. + let populate_len = instrs.len() * NUM_BASE_ALU_ACCESS_COLS; + values[..populate_len].par_chunks_mut(NUM_BASE_ALU_ACCESS_COLS).zip_eq(instrs).for_each( + |(row, instr)| { + let BaseAluInstr { opcode, mult, addrs } = instr; + let access: &mut BaseAluAccessCols<_> = row.borrow_mut(); + *access = BaseAluAccessCols { addrs: addrs.to_owned(), is_add: F::from_bool(false), is_sub: F::from_bool(false), @@ -91,31 +107,11 @@ impl MachineAir for BaseAluChip { BaseAluOpcode::DivF => &mut access.is_div, }; *target_flag = F::from_bool(true); - - Some(access) - }) - .chunks(NUM_BASE_ALU_ENTRIES_PER_ROW) - .into_iter() - .map(|row_accesses| { - let mut row = [F::zero(); NUM_BASE_ALU_PREPROCESSED_COLS]; - let cols: &mut BaseAluPreprocessedCols<_> = row.as_mut_slice().borrow_mut(); - for (cell, access) in zip(&mut cols.accesses, row_accesses) { - *cell = access; - } - row - }) - .collect::>(); - - // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_BASE_ALU_PREPROCESSED_COLS, + }, ); - // Pad the trace to a power of two. - pad_to_power_of_two::(&mut trace.values); - - Some(trace) + // Convert the trace to a row major matrix. + Some(RowMajorMatrix::new(values, NUM_BASE_ALU_PREPROCESSED_COLS)) } fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { @@ -123,29 +119,21 @@ impl MachineAir for BaseAluChip { } fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix { + let events = &input.base_alu_events; + let nb_rows = events.len().div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW); + let padded_nb_rows = next_power_of_two(nb_rows, None); + let mut values = vec![F::zero(); padded_nb_rows * NUM_BASE_ALU_COLS]; // Generate the trace rows & corresponding records for each chunk of events in parallel. - let rows = input - .base_alu_events - .chunks(NUM_BASE_ALU_ENTRIES_PER_ROW) - .map(|row_events| { - let mut row = [F::zero(); NUM_BASE_ALU_COLS]; - let cols: &mut BaseAluCols<_> = row.as_mut_slice().borrow_mut(); - for (cell, &vals) in zip(&mut cols.values, row_events) { - *cell = BaseAluValueCols { vals }; - } - - row - }) - .collect::>(); + let populate_len = events.len() * NUM_BASE_ALU_VALUE_COLS; + values[..populate_len].par_chunks_mut(NUM_BASE_ALU_VALUE_COLS).zip_eq(events).for_each( + |(row, &vals)| { + let cols: &mut BaseAluValueCols<_> = row.borrow_mut(); + *cols = BaseAluValueCols { vals }; + }, + ); // Convert the trace to a row major matrix. - let mut trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_BASE_ALU_COLS); - - // Pad the trace to a power of two. - pad_to_power_of_two::(&mut trace.values); - - trace + RowMajorMatrix::new(values, NUM_BASE_ALU_COLS) } fn included(&self, _record: &Self::Record) -> bool { @@ -165,23 +153,25 @@ where let prep_local = prep.row_slice(0); let prep_local: &BaseAluPreprocessedCols = (*prep_local).borrow(); - for (value, access) in zip(local.values, prep_local.accesses) { - let BaseAluValueCols { vals: BaseAluIo { out, in1, in2 } } = value; - + for ( + BaseAluValueCols { vals: BaseAluIo { out, in1, in2 } }, + BaseAluAccessCols { addrs, is_add, is_sub, is_mul, is_div, mult }, + ) in zip(local.values, prep_local.accesses) + { // Check exactly one flag is enabled. - let is_real = access.is_add + access.is_sub + access.is_mul + access.is_div; + let is_real = is_add + is_sub + is_mul + is_div; builder.assert_bool(is_real.clone()); - builder.when(access.is_add).assert_eq(in1 + in2, out); - builder.when(access.is_sub).assert_eq(in1, in2 + out); - builder.when(access.is_mul).assert_eq(in1 * in2, out); - builder.when(access.is_div).assert_eq(in1, in2 * out); + builder.when(is_add).assert_eq(in1 + in2, out); + builder.when(is_sub).assert_eq(in1, in2 + out); + builder.when(is_mul).assert_eq(in1 * in2, out); + builder.when(is_div).assert_eq(in1, in2 * out); - builder.receive_single(access.addrs.in1, in1, is_real.clone()); + builder.receive_single(addrs.in1, in1, is_real.clone()); - builder.receive_single(access.addrs.in2, in2, is_real); + builder.receive_single(addrs.in2, in2, is_real); - builder.send_single(access.addrs.out, out, access.mult); + builder.send_single(addrs.out, out, mult); } } } diff --git a/crates/recursion/core-v2/src/chips/alu_ext.rs b/crates/recursion/core-v2/src/chips/alu_ext.rs index 0b259927b9..39a92c7f8a 100644 --- a/crates/recursion/core-v2/src/chips/alu_ext.rs +++ b/crates/recursion/core-v2/src/chips/alu_ext.rs @@ -2,13 +2,16 @@ use core::borrow::Borrow; use p3_air::{Air, BaseAir, PairBuilder}; use p3_field::{extension::BinomiallyExtendable, Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use sp1_core_machine::utils::pad_to_power_of_two; +use p3_maybe_rayon::prelude::*; +use sp1_core_machine::utils::next_power_of_two; use sp1_derive::AlignedBorrow; use sp1_stark::air::{ExtensionAirBuilder, MachineAir}; -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, iter::zip}; use crate::{builder::SP1RecursionAirBuilder, *}; +pub const NUM_EXT_ALU_ENTRIES_PER_ROW: usize = 4; + #[derive(Default)] pub struct ExtAluChip {} @@ -17,6 +20,13 @@ pub const NUM_EXT_ALU_COLS: usize = core::mem::size_of::>(); #[derive(AlignedBorrow, Debug, Clone, Copy)] #[repr(C)] pub struct ExtAluCols { + pub values: [ExtAluValueCols; NUM_EXT_ALU_ENTRIES_PER_ROW], +} +const NUM_EXT_ALU_VALUE_COLS: usize = core::mem::size_of::>(); + +#[derive(AlignedBorrow, Debug, Clone, Copy)] +#[repr(C)] +pub struct ExtAluValueCols { pub vals: ExtAluIo>, } @@ -25,6 +35,14 @@ pub const NUM_EXT_ALU_PREPROCESSED_COLS: usize = core::mem::size_of:: { + pub accesses: [ExtAluAccessCols; NUM_EXT_ALU_ENTRIES_PER_ROW], +} + +pub const NUM_EXT_ALU_ACCESS_COLS: usize = core::mem::size_of::>(); + +#[derive(AlignedBorrow, Debug, Clone, Copy)] +#[repr(C)] +pub struct ExtAluAccessCols { pub addrs: ExtAluIo>, pub is_add: F, pub is_sub: F, @@ -53,16 +71,26 @@ impl> MachineAir for ExtAluChip { } fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option> { - let rows = program + // Allocating an intermediate `Vec` is faster. + let instrs = program .instructions - .iter() - .filter_map(|instruction| { - let Instruction::ExtAlu(ExtAluInstr { opcode, mult, addrs }) = instruction else { - return None; - }; - let mut row = [F::zero(); NUM_EXT_ALU_PREPROCESSED_COLS]; - let cols: &mut ExtAluPreprocessedCols = row.as_mut_slice().borrow_mut(); - *cols = ExtAluPreprocessedCols { + .iter() // Faster than using `rayon` for some reason. Maybe vectorization? + .filter_map(|instruction| match instruction { + Instruction::ExtAlu(x) => Some(x), + _ => None, + }) + .collect::>(); + + let nb_rows = instrs.len().div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW); + let padded_nb_rows = next_power_of_two(nb_rows, None); + let mut values = vec![F::zero(); padded_nb_rows * NUM_EXT_ALU_PREPROCESSED_COLS]; + // Generate the trace rows & corresponding records for each chunk of events in parallel. + let populate_len = instrs.len() * NUM_EXT_ALU_ACCESS_COLS; + values[..populate_len].par_chunks_mut(NUM_EXT_ALU_ACCESS_COLS).zip_eq(instrs).for_each( + |(row, instr)| { + let ExtAluInstr { opcode, mult, addrs } = instr; + let access: &mut ExtAluAccessCols<_> = row.borrow_mut(); + *access = ExtAluAccessCols { addrs: addrs.to_owned(), is_add: F::from_bool(false), is_sub: F::from_bool(false), @@ -71,27 +99,17 @@ impl> MachineAir for ExtAluChip { mult: mult.to_owned(), }; let target_flag = match opcode { - ExtAluOpcode::AddE => &mut cols.is_add, - ExtAluOpcode::SubE => &mut cols.is_sub, - ExtAluOpcode::MulE => &mut cols.is_mul, - ExtAluOpcode::DivE => &mut cols.is_div, + ExtAluOpcode::AddE => &mut access.is_add, + ExtAluOpcode::SubE => &mut access.is_sub, + ExtAluOpcode::MulE => &mut access.is_mul, + ExtAluOpcode::DivE => &mut access.is_div, }; *target_flag = F::from_bool(true); - - Some(row) - }) - .collect::>(); - - // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_EXT_ALU_PREPROCESSED_COLS, + }, ); - // Pad the trace to a power of two. - pad_to_power_of_two::(&mut trace.values); - - Some(trace) + // Convert the trace to a row major matrix. + Some(RowMajorMatrix::new(values, NUM_EXT_ALU_PREPROCESSED_COLS)) } fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { @@ -99,26 +117,21 @@ impl> MachineAir for ExtAluChip { } fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix { - let ext_alu_events = input.ext_alu_events.clone(); - + let events = &input.ext_alu_events; + let nb_rows = events.len().div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW); + let padded_nb_rows = next_power_of_two(nb_rows, None); + let mut values = vec![F::zero(); padded_nb_rows * NUM_EXT_ALU_COLS]; // Generate the trace rows & corresponding records for each chunk of events in parallel. - let rows = ext_alu_events - .into_iter() - .map(|vals| { - let mut row = [F::zero(); NUM_EXT_ALU_COLS]; - *row.as_mut_slice().borrow_mut() = ExtAluCols { vals }; - row - }) - .collect::>(); + let populate_len = events.len() * NUM_EXT_ALU_VALUE_COLS; + values[..populate_len].par_chunks_mut(NUM_EXT_ALU_VALUE_COLS).zip_eq(events).for_each( + |(row, &vals)| { + let cols: &mut ExtAluValueCols<_> = row.borrow_mut(); + *cols = ExtAluValueCols { vals }; + }, + ); // Convert the trace to a row major matrix. - let mut trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_EXT_ALU_COLS); - - // Pad the trace to a power of two. - pad_to_power_of_two::(&mut trace.values); - - trace + RowMajorMatrix::new(values, NUM_EXT_ALU_COLS) } fn included(&self, _record: &Self::Record) -> bool { @@ -138,26 +151,32 @@ where let prep_local = prep.row_slice(0); let prep_local: &ExtAluPreprocessedCols = (*prep_local).borrow(); - // Check exactly one flag is enabled. - let is_real = prep_local.is_add + prep_local.is_sub + prep_local.is_mul + prep_local.is_div; - builder.assert_bool(is_real.clone()); + for ( + ExtAluValueCols { vals }, + ExtAluAccessCols { addrs, is_add, is_sub, is_mul, is_div, mult }, + ) in zip(local.values, prep_local.accesses) + { + let in1 = vals.in1.as_extension::(); + let in2 = vals.in2.as_extension::(); + let out = vals.out.as_extension::(); - let in1 = local.vals.in1.as_extension::(); - let in2 = local.vals.in2.as_extension::(); - let out = local.vals.out.as_extension::(); + // Check exactly one flag is enabled. + let is_real = is_add + is_sub + is_mul + is_div; + builder.assert_bool(is_real.clone()); - builder.when(prep_local.is_add).assert_ext_eq(in1.clone() + in2.clone(), out.clone()); - builder.when(prep_local.is_sub).assert_ext_eq(in1.clone(), in2.clone() + out.clone()); - builder.when(prep_local.is_mul).assert_ext_eq(in1.clone() * in2.clone(), out.clone()); - builder.when(prep_local.is_div).assert_ext_eq(in1, in2 * out); + builder.when(is_add).assert_ext_eq(in1.clone() + in2.clone(), out.clone()); + builder.when(is_sub).assert_ext_eq(in1.clone(), in2.clone() + out.clone()); + builder.when(is_mul).assert_ext_eq(in1.clone() * in2.clone(), out.clone()); + builder.when(is_div).assert_ext_eq(in1, in2 * out); - // Read the inputs from memory. - builder.receive_block(prep_local.addrs.in1, local.vals.in1, is_real.clone()); + // Read the inputs from memory. + builder.receive_block(addrs.in1, vals.in1, is_real.clone()); - builder.receive_block(prep_local.addrs.in2, local.vals.in2, is_real); + builder.receive_block(addrs.in2, vals.in2, is_real); - // Write the output to memory. - builder.send_block(prep_local.addrs.out, local.vals.out, prep_local.mult); + // Write the output to memory. + builder.send_block(addrs.out, vals.out, mult); + } } } diff --git a/crates/recursion/core-v2/src/chips/mem/mod.rs b/crates/recursion/core-v2/src/chips/mem/mod.rs index d6926ce48f..f318db027a 100644 --- a/crates/recursion/core-v2/src/chips/mem/mod.rs +++ b/crates/recursion/core-v2/src/chips/mem/mod.rs @@ -8,6 +8,8 @@ use sp1_derive::AlignedBorrow; use crate::Address; +pub const NUM_MEM_ACCESS_COLS: usize = core::mem::size_of::>(); + /// Data describing in what manner to access a particular memory block. #[derive(AlignedBorrow, Debug, Clone, Copy)] #[repr(C)] diff --git a/crates/recursion/core-v2/src/chips/mem/variable.rs b/crates/recursion/core-v2/src/chips/mem/variable.rs index 090f244277..a2eafb10a3 100644 --- a/crates/recursion/core-v2/src/chips/mem/variable.rs +++ b/crates/recursion/core-v2/src/chips/mem/variable.rs @@ -1,17 +1,17 @@ use core::borrow::Borrow; use instruction::{HintBitsInstr, HintExt2FeltsInstr, HintInstr}; -use itertools::Itertools; use p3_air::{Air, BaseAir, PairBuilder}; use p3_field::PrimeField32; use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use sp1_core_machine::utils::pad_to_power_of_two; +use p3_maybe_rayon::prelude::*; +use sp1_core_machine::utils::{next_power_of_two, pad_to_power_of_two}; use sp1_derive::AlignedBorrow; use sp1_stark::air::MachineAir; use std::{borrow::BorrowMut, iter::zip, marker::PhantomData}; use crate::{builder::SP1RecursionAirBuilder, *}; -use super::MemoryAccessCols; +use super::{MemoryAccessCols, NUM_MEM_ACCESS_COLS}; pub const NUM_MEM_ENTRIES_PER_ROW: usize = 16; @@ -56,50 +56,35 @@ impl MachineAir for MemoryChip { } fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option> { - let rows = program + // Allocating an intermediate `Vec` is faster. + let accesses = program .instructions - .iter() - .flat_map(|instruction| match instruction { + .par_iter() // Using `rayon` here provides a big speedup. + .flat_map_iter(|instruction| match instruction { Instruction::Hint(HintInstr { output_addrs_mults }) | Instruction::HintBits(HintBitsInstr { output_addrs_mults, input_addr: _, // No receive interaction for the hint operation - }) => output_addrs_mults - .iter() - .map(|&(addr, mult)| MemoryAccessCols { addr, mult }) - .collect(), + }) => output_addrs_mults.iter().collect(), Instruction::HintExt2Felts(HintExt2FeltsInstr { output_addrs_mults, input_addr: _, // No receive interaction for the hint operation - }) => output_addrs_mults - .iter() - .map(|&(addr, mult)| MemoryAccessCols { addr, mult }) - .collect(), - + }) => output_addrs_mults.iter().collect(), _ => vec![], }) - .chunks(NUM_MEM_ENTRIES_PER_ROW) - .into_iter() - .map(|row_accesses| { - let mut row = [F::zero(); NUM_MEM_PREPROCESSED_INIT_COLS]; - let cols: &mut MemoryPreprocessedCols<_> = row.as_mut_slice().borrow_mut(); - for (cell, access) in zip(&mut cols.accesses, row_accesses) { - *cell = access; - } - row - }) .collect::>(); - // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_MEM_PREPROCESSED_INIT_COLS, - ); - - // Pad the trace to a power of two. - pad_to_power_of_two::(&mut trace.values); + let nb_rows = accesses.len().div_ceil(NUM_MEM_ENTRIES_PER_ROW); + let padded_nb_rows = next_power_of_two(nb_rows, None); + let mut values = vec![F::zero(); padded_nb_rows * NUM_MEM_PREPROCESSED_INIT_COLS]; + // Generate the trace rows & corresponding records for each chunk of events in parallel. + let populate_len = accesses.len() * NUM_MEM_ACCESS_COLS; + values[..populate_len] + .par_chunks_mut(NUM_MEM_ACCESS_COLS) + .zip_eq(accesses) + .for_each(|(row, &(addr, mult))| *row.borrow_mut() = MemoryAccessCols { addr, mult }); - Some(trace) + Some(RowMajorMatrix::new(values, NUM_MEM_PREPROCESSED_INIT_COLS)) } fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { diff --git a/crates/recursion/core-v2/src/chips/poseidon2_wide/trace.rs b/crates/recursion/core-v2/src/chips/poseidon2_wide/trace.rs index 71089558d2..c1cffae1e1 100644 --- a/crates/recursion/core-v2/src/chips/poseidon2_wide/trace.rs +++ b/crates/recursion/core-v2/src/chips/poseidon2_wide/trace.rs @@ -1,10 +1,10 @@ use std::{borrow::BorrowMut, mem::size_of}; -use itertools::Itertools; use p3_air::BaseAir; use p3_field::PrimeField32; use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use sp1_core_machine::utils::pad_rows_fixed; +use p3_maybe_rayon::prelude::*; +use sp1_core_machine::utils::next_power_of_two; use sp1_primitives::RC_16_30_U32; use sp1_stark::air::MachineAir; use tracing::instrument; @@ -43,28 +43,33 @@ impl MachineAir for Poseidon2WideChip, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mut rows = Vec::new(); - - let num_columns = as BaseAir>::width(self); - - for event in &input.poseidon2_events { - let mut row = vec![F::zero(); num_columns]; - self.populate_perm(event.input, Some(event.output), row.as_mut_slice()); - rows.push(row); - } - - if self.pad { - // Pad the trace to a power of two. - // This will need to be adjusted when the AIR constraints are implemented. - let mut dummy_row = vec![F::zero(); num_columns]; - self.populate_perm([F::zero(); WIDTH], None, &mut dummy_row); - - pad_rows_fixed(&mut rows, || dummy_row.clone(), self.fixed_log2_rows); - } + let events = &input.poseidon2_events; + let padded_nb_rows = next_power_of_two(events.len(), self.fixed_log2_rows); + let num_columns = >::width(self); + let mut values = vec![F::zero(); padded_nb_rows * num_columns]; + + let mut dummy_row = vec![F::zero(); num_columns]; + self.populate_perm([F::zero(); WIDTH], None, &mut dummy_row); + + let populate_len = events.len() * num_columns; + let (values_pop, values_dummy) = values.split_at_mut(populate_len); + join( + || { + values_pop.par_chunks_mut(num_columns).zip_eq(&input.poseidon2_events).for_each( + |(row, &event)| { + self.populate_perm(event.input, Some(event.output), row); + }, + ) + }, + || { + values_dummy + .par_chunks_mut(num_columns) + .for_each(|row| row.copy_from_slice(&dummy_row)) + }, + ); // Convert the trace to a row major matrix. - let trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), num_columns); + let trace = RowMajorMatrix::new(values, num_columns); #[cfg(debug_assertions)] println!( @@ -85,44 +90,40 @@ impl MachineAir for Poseidon2WideChip Option> { - let instructions = - program.instructions.iter().filter_map(|instruction| match instruction { - Poseidon2(instr) => Some(instr), + // Allocating an intermediate `Vec` is faster. + let instrs = program + .instructions + .iter() // Faster than using `rayon` for some reason. Maybe vectorization? + .filter_map(|instruction| match instruction { + Poseidon2(instr) => Some(instr.as_ref()), _ => None, - }); - - let num_instructions = instructions.clone().count(); - - let mut rows = vec![[F::zero(); PREPROCESSED_POSEIDON2_WIDTH]; num_instructions]; - - // Iterate over the instructions and take NUM_EXTERNAL_ROUNDS + 2 rows for each instruction. - instructions.zip_eq(rows.iter_mut()).for_each(|(instruction, row)| { - let cols: &mut Poseidon2PreprocessedCols<_> = (*row).as_mut_slice().borrow_mut(); - - // Set the memory columns. We read once, at the first iteration, - // and write once, at the last iteration. - cols.memory_preprocessed = std::array::from_fn(|j| { - if j < WIDTH { - MemoryAccessCols { addr: instruction.addrs.input[j], mult: F::neg_one() } - } else { - MemoryAccessCols { - addr: instruction.addrs.output[j - WIDTH], - mult: instruction.mults[j - WIDTH], + }) + .collect::>(); + + let padded_nb_rows = next_power_of_two(instrs.len(), self.fixed_log2_rows); + let mut values = vec![F::zero(); padded_nb_rows * PREPROCESSED_POSEIDON2_WIDTH]; + + let populate_len = instrs.len() * PREPROCESSED_POSEIDON2_WIDTH; + values[..populate_len] + .par_chunks_mut(PREPROCESSED_POSEIDON2_WIDTH) + .zip_eq(instrs) + .for_each(|(row, instr)| { + let cols: &mut Poseidon2PreprocessedCols<_> = row.borrow_mut(); + + // Set the memory columns. We read once, at the first iteration, + // and write once, at the last iteration. + cols.memory_preprocessed = std::array::from_fn(|j| { + if j < WIDTH { + MemoryAccessCols { addr: instr.addrs.input[j], mult: F::neg_one() } + } else { + MemoryAccessCols { + addr: instr.addrs.output[j - WIDTH], + mult: instr.mults[j - WIDTH], + } } - } + }); }); - }); - if self.pad { - // Pad the trace to a power of two. - // This may need to be adjusted when the AIR constraints are implemented. - pad_rows_fixed( - &mut rows, - || [F::zero(); PREPROCESSED_POSEIDON2_WIDTH], - self.fixed_log2_rows, - ); - } - let trace_rows = rows.into_iter().flatten().collect::>(); - Some(RowMajorMatrix::new(trace_rows, PREPROCESSED_POSEIDON2_WIDTH)) + Some(RowMajorMatrix::new(values, PREPROCESSED_POSEIDON2_WIDTH)) } } diff --git a/crates/recursion/core-v2/src/lib.rs b/crates/recursion/core-v2/src/lib.rs index 4ebaefe8c8..61dd5cdc81 100644 --- a/crates/recursion/core-v2/src/lib.rs +++ b/crates/recursion/core-v2/src/lib.rs @@ -21,7 +21,7 @@ use crate::chips::poseidon2_skinny::WIDTH; #[derive( AlignedBorrow, Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Default, )] -#[repr(C)] +#[repr(transparent)] pub struct Address(pub F); impl Address { @@ -35,6 +35,7 @@ impl Address { /// The inputs and outputs to an operation of the base field ALU. #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[repr(C)] pub struct BaseAluIo { pub out: V, pub in1: V, @@ -55,6 +56,7 @@ pub struct BaseAluInstr { /// The inputs and outputs to an operation of the extension field ALU. #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[repr(C)] pub struct ExtAluIo { pub out: V, pub in1: V, diff --git a/crates/stark/src/machine.rs b/crates/stark/src/machine.rs index e4e4c4ff88..882bd2a04d 100644 --- a/crates/stark/src/machine.rs +++ b/crates/stark/src/machine.rs @@ -7,7 +7,7 @@ use p3_field::{AbstractExtensionField, AbstractField, Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Dimensions, Matrix}; use p3_maybe_rayon::prelude::*; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use std::{cmp::Reverse, fmt::Debug}; +use std::{cmp::Reverse, fmt::Debug, time::Instant}; use tracing::instrument; use super::{debug_constraints, Dom}; @@ -156,30 +156,31 @@ impl>> StarkMachine { #[allow(clippy::map_unwrap_or)] #[allow(clippy::redundant_closure_for_method_calls)] pub fn setup(&self, program: &A::Program) -> (StarkProvingKey, StarkVerifyingKey) { - let mut named_preprocessed_traces = tracing::debug_span!("generate preprocessed traces") - .in_scope(|| { - self.chips() - .iter() - .map(|chip| { - let prep_trace = chip.generate_preprocessed_trace(program); - // Assert that the chip width data is correct. - let expected_width = prep_trace.as_ref().map(|t| t.width()).unwrap_or(0); - assert_eq!( - expected_width, - chip.preprocessed_width(), - "Incorrect number of preprocessed columns for chip {}", - chip.name() - ); - - (chip.name(), prep_trace) - }) - .filter(|(_, prep_trace)| prep_trace.is_some()) - .map(|(name, prep_trace)| { - let prep_trace = prep_trace.unwrap(); - (name, prep_trace) - }) - .collect::>() - }); + let parent_span = tracing::debug_span!("generate preprocessed traces"); + let mut named_preprocessed_traces = parent_span.in_scope(|| { + self.chips() + .par_iter() + .filter_map(|chip| { + let chip_name = chip.name(); + let begin = Instant::now(); + let prep_trace = chip.generate_preprocessed_trace(program); + tracing::debug!( + parent: &parent_span, + "generated preprocessed trace for chip {} in {:?}", + chip_name, + begin.elapsed() + ); + // Assert that the chip width data is correct. + let expected_width = prep_trace.as_ref().map(|t| t.width()).unwrap_or(0); + assert_eq!( + expected_width, + chip.preprocessed_width(), + "Incorrect number of preprocessed columns for chip {chip_name}" + ); + prep_trace.map(move |t| (chip_name, t)) + }) + .collect::>() + }); // Order the chips and traces by trace size (biggest first), and get the ordering map. named_preprocessed_traces.sort_by_key(|(_, trace)| Reverse(trace.height())); diff --git a/crates/stark/src/prover.rs b/crates/stark/src/prover.rs index 0d08661d6c..d3839b3f57 100644 --- a/crates/stark/src/prover.rs +++ b/crates/stark/src/prover.rs @@ -1,7 +1,7 @@ use core::fmt::Display; use itertools::Itertools; use serde::{de::DeserializeOwned, Serialize}; -use std::{cmp::Reverse, error::Error}; +use std::{cmp::Reverse, error::Error, time::Instant}; use crate::{AirOpenedValues, ChipOpenedValues, ShardOpenedValues}; use p3_air::Air; @@ -54,16 +54,22 @@ pub trait MachineProver>: // For each chip, generate the trace. let parent_span = tracing::debug_span!("generate traces for shard"); parent_span.in_scope(|| { - shard_chips - .par_iter() - .map(|chip| { - let chip_name = chip.name(); - let trace = tracing::debug_span!(parent: &parent_span, "generate trace for chip", %chip_name) - .in_scope(|| chip.generate_trace(record, &mut A::Record::default())); - (chip_name, trace) - }) - .collect::>() - }) + shard_chips + .par_iter() + .map(|chip| { + let chip_name = chip.name(); + let begin = Instant::now(); + let trace = chip.generate_trace(record, &mut A::Record::default()); + tracing::debug!( + parent: &parent_span, + "generated trace for chip {} in {:?}", + chip_name, + begin.elapsed() + ); + (chip_name, trace) + }) + .collect::>() + }) } /// Commit to the main traces.