From c570c1bdd8610302fd68980f36754e5c48ab92ef Mon Sep 17 00:00:00 2001 From: Sebastien La Duca Date: Fri, 19 Apr 2024 13:58:09 -0400 Subject: [PATCH] fix(recursion): poseidon2 chip matches plonky3 (#548) --- recursion/core/src/poseidon2/external.rs | 136 +++++++++++------- recursion/core/src/poseidon2/mod.rs | 38 ----- recursion/core/src/poseidon2_wide/external.rs | 39 +++-- recursion/core/src/poseidon2_wide/mod.rs | 17 +-- 4 files changed, 105 insertions(+), 125 deletions(-) diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index e7a331f9e8..6289a45d2d 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -13,7 +13,10 @@ use sp1_primitives::RC_16_30_U32; use std::borrow::BorrowMut; use tracing::instrument; -use super::{apply_m_4, matmul_internal, MATRIX_DIAG_16_BABYBEAR_U32}; +use crate::poseidon2_wide::{ + apply_m_4, external_linear_layer, internal_linear_layer, matmul_internal, + MATRIX_DIAG_16_BABYBEAR_U32, +}; use crate::runtime::{ExecutionRecord, RecursionProgram}; /// The number of main trace columns for `AddChip`. @@ -60,9 +63,9 @@ impl MachineAir for Poseidon2Chip { let rounds_f = 8; let rounds_p = 22; - let rounds = rounds_f + rounds_p; - let rounds_f_beginning = rounds_f / 2; - let p_end = rounds_f_beginning + rounds_p; + let rounds = rounds_f + rounds_p + 1; + let rounds_p_beginning = 1 + rounds_f / 2; + let p_end = rounds_p_beginning + rounds_p; for poseidon2_event in input.poseidon2_events.iter() { let mut round_input = poseidon2_event.input; @@ -75,8 +78,8 @@ impl MachineAir for Poseidon2Chip { cols.rounds[r] = F::one(); let is_initial_layer = r == 0; - let is_external_layer = r != 0 - && (((r - 1) < rounds_f_beginning) || (p_end <= (r - 1) && (r - 1) < rounds)); + let is_external_layer = + (r >= 1 && r < rounds_p_beginning) || (r >= p_end && r < rounds); if is_initial_layer { // Mark the selector as initial. @@ -121,23 +124,9 @@ impl MachineAir for Poseidon2Chip { // Apply either the external or internal linear layer. if cols.is_initial == F::one() || cols.is_external == F::one() { - for j in (0..WIDTH).step_by(4) { - apply_m_4(&mut state[j..j + 4]); - } - let sums: [F; 4] = core::array::from_fn(|k| { - (0..WIDTH).step_by(4).map(|j| state[j + k]).sum::() - }); - for j in 0..WIDTH { - state[j] += sums[j % 4]; - } + external_linear_layer(&mut state); } else if cols.is_internal == F::one() { - let matmul_constants: [F; WIDTH] = MATRIX_DIAG_16_BABYBEAR_U32 - .iter() - .map(|x| F::from_wrapped_u32(*x)) - .collect::>() - .try_into() - .unwrap(); - matmul_internal(&mut state, matmul_constants); + internal_linear_layer(&mut state) } // Copy the state to the output. @@ -329,70 +318,109 @@ where #[cfg(test)] mod tests { - use std::borrow::BorrowMut; + use itertools::Itertools; + use std::borrow::Borrow; use std::time::Instant; use p3_baby_bear::BabyBear; use p3_baby_bear::DiffusionMatrixBabybear; use p3_field::AbstractField; - use p3_matrix::dense::RowMajorMatrix; + use p3_matrix::{dense::RowMajorMatrix, Matrix}; use p3_poseidon2::Poseidon2; use p3_poseidon2::Poseidon2ExternalMatrixGeneral; use sp1_core::stark::StarkGenericConfig; use sp1_core::utils::inner_perm; use sp1_core::{ air::MachineAir, - utils::{uni_stark_prove, BabyBearPoseidon2}, + utils::{uni_stark_prove, uni_stark_verify, BabyBearPoseidon2}, }; - use crate::poseidon2::external::WIDTH; - use crate::{poseidon2::external::Poseidon2Chip, runtime::ExecutionRecord}; + use crate::{ + poseidon2::{Poseidon2Chip, Poseidon2Event, WIDTH}, + runtime::ExecutionRecord, + }; use p3_symmetric::Permutation; - use super::{Poseidon2Cols, NUM_POSEIDON2_COLS}; + use super::Poseidon2Cols; + + const ROWS_PER_PERMUTATION: usize = 31; #[test] fn generate_trace() { let chip = Poseidon2Chip; - let trace: RowMajorMatrix = chip.generate_trace( - &ExecutionRecord::::default(), - &mut ExecutionRecord::::default(), - ); - println!("{:?}", trace.values) + let test_inputs = vec![ + [BabyBear::from_canonical_u32(1); WIDTH], + [BabyBear::from_canonical_u32(2); WIDTH], + [BabyBear::from_canonical_u32(3); WIDTH], + [BabyBear::from_canonical_u32(4); WIDTH], + ]; + + let gt: Poseidon2< + BabyBear, + Poseidon2ExternalMatrixGeneral, + DiffusionMatrixBabybear, + 16, + 7, + > = inner_perm(); + + let expected_outputs = test_inputs + .iter() + .map(|input| gt.permute(*input)) + .collect::>(); + + let mut input_exec = ExecutionRecord::::default(); + for input in test_inputs.iter().cloned() { + input_exec.poseidon2_events.push(Poseidon2Event { input }); + } + + let trace: RowMajorMatrix = + chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); + + for (i, expected_output) in expected_outputs.iter().enumerate() { + let row = trace.row(ROWS_PER_PERMUTATION * (i + 1) - 1).collect_vec(); + let cols: &Poseidon2Cols = row.as_slice().borrow(); + assert_eq!(expected_output, &cols.output); + } } #[test] - #[ignore] fn prove_babybear() { let config = BabyBearPoseidon2::new(); let mut challenger = config.challenger(); let chip = Poseidon2Chip; - let trace: RowMajorMatrix = chip.generate_trace( - &ExecutionRecord::::default(), - &mut ExecutionRecord::::default(), + + let test_inputs = (0..16) + .map(|i| [BabyBear::from_canonical_u32(i); WIDTH]) + .collect_vec(); + + let mut input_exec = ExecutionRecord::::default(); + for input in test_inputs.iter().cloned() { + input_exec.poseidon2_events.push(Poseidon2Event { input }); + } + let trace: RowMajorMatrix = + chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); + println!( + "trace dims is width: {:?}, height: {:?}", + trace.width(), + trace.height() ); - let gt: Poseidon2< + let start = Instant::now(); + let proof = uni_stark_prove(&config, &chip, &mut challenger, trace); + let duration = start.elapsed().as_secs_f64(); + println!("proof duration = {:?}", duration); + + let mut challenger: p3_challenger::DuplexChallenger< BabyBear, - Poseidon2ExternalMatrixGeneral, - DiffusionMatrixBabybear, + Poseidon2, 16, - 7, - > = inner_perm(); - let input = [BabyBear::one(); WIDTH]; - let output = gt.permute(input); - - let mut row: [BabyBear; NUM_POSEIDON2_COLS] = trace.values - [NUM_POSEIDON2_COLS * 30..(NUM_POSEIDON2_COLS) * 31] - .try_into() - .unwrap(); - let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); - assert_eq!(cols.output, output); - + > = config.challenger(); let start = Instant::now(); - uni_stark_prove(&config, &chip, &mut challenger, trace); + uni_stark_verify(&config, &chip, &mut challenger, &proof) + .expect("expected proof to be valid"); + let duration = start.elapsed().as_secs_f64(); - println!("duration = {:?}", duration); + println!("verify duration = {:?}", duration); } } diff --git a/recursion/core/src/poseidon2/mod.rs b/recursion/core/src/poseidon2/mod.rs index 99f032e11b..a4e6e58f8a 100644 --- a/recursion/core/src/poseidon2/mod.rs +++ b/recursion/core/src/poseidon2/mod.rs @@ -1,8 +1,6 @@ #![allow(clippy::needless_range_loop)] use crate::poseidon2::external::WIDTH; -use p3_field::{AbstractField, Field}; - mod external; pub use external::Poseidon2Chip; @@ -11,39 +9,3 @@ pub use external::Poseidon2Chip; pub struct Poseidon2Event { pub input: [F; WIDTH], } - -// TODO: Make this public inside Plonky3 and import directly. -pub fn apply_m_4(x: &mut [AF]) -where - AF: AbstractField, -{ - let t0 = x[0].clone() + x[1].clone(); - let t1 = x[2].clone() + x[3].clone(); - let t2 = x[1].clone() + x[1].clone() + t1.clone(); - let t3 = x[3].clone() + x[3].clone() + t0.clone(); - let t4 = t1.clone() + t1.clone() + t1.clone() + t1 + t3.clone(); - let t5 = t0.clone() + t0.clone() + t0.clone() + t0 + t2.clone(); - let t6 = t3 + t5.clone(); - let t7 = t2 + t4.clone(); - x[0] = t6; - x[1] = t5; - x[2] = t7; - x[3] = t4; -} - -// TODO: Make this public inside Plonky3 and import directly. -pub fn matmul_internal, const WIDTH: usize>( - state: &mut [AF; WIDTH], - mat_internal_diag_m_1: [F; WIDTH], -) { - let sum: AF = state.iter().cloned().sum(); - for i in 0..WIDTH { - state[i] *= AF::from_f(mat_internal_diag_m_1[i]); - state[i] += sum.clone(); - } -} - -pub const MATRIX_DIAG_16_BABYBEAR_U32: [u32; 16] = [ - 0x0a632d94, 0x6db657b7, 0x56fbdc9e, 0x052b3d8a, 0x33745201, 0x5c03108c, 0x0beba37b, 0x258c2e8b, - 0x12029f39, 0x694909ce, 0x6d231724, 0x21c3b222, 0x3c0904a5, 0x01d6acda, 0x27705c83, 0x5231c802, -]; diff --git a/recursion/core/src/poseidon2_wide/external.rs b/recursion/core/src/poseidon2_wide/external.rs index 723f48ba5d..40a96890ae 100644 --- a/recursion/core/src/poseidon2_wide/external.rs +++ b/recursion/core/src/poseidon2_wide/external.rs @@ -81,10 +81,10 @@ impl MachineAir for Poseidon2WideChip { let mut row = [F::zero(); NUM_POSEIDON2_WIDE_COLS]; let cols: &mut Poseidon2WideCols = row.as_mut_slice().borrow_mut(); - cols.input = event.input; - // Apply the initial round. - external_linear_layer(&cols.input, &mut cols.external_rounds[0].state); + cols.input = event.input; + cols.external_rounds[0].state = event.input; + external_linear_layer(&mut cols.external_rounds[0].state); // Apply the first half of external rounds. for r in 0..NUM_EXTERNAL_ROUNDS / 2 { @@ -141,7 +141,7 @@ fn populate_external_round( cols: &mut Poseidon2WideCols, r: usize, ) -> [F; WIDTH] { - let linear_layer_input = { + let mut state = { let round_cols = cols.external_rounds[r].borrow_mut(); // Add round constants. @@ -172,9 +172,8 @@ fn populate_external_round( }; // Apply the linear layer. - let mut next_state = [F::zero(); WIDTH]; - external_linear_layer(&linear_layer_input, &mut next_state); - next_state + external_linear_layer(&mut state); + state } fn populate_internal_rounds(cols: &mut Poseidon2WideCols) -> [F; WIDTH] { @@ -194,10 +193,8 @@ fn populate_internal_rounds(cols: &mut Poseidon2WideCols) -> let sbox_deg_7 = cols.sbox_deg_3[r] * cols.sbox_deg_3[r] * add_rc; // Apply the linear layer. - let mut linear_layer_input = state; - linear_layer_input[0] = sbox_deg_7; - - internal_linear_layer(&linear_layer_input, &mut state); + state[0] = sbox_deg_7; + internal_linear_layer(&mut state); // Optimization: since we're only applying the sbox to the 0th state element, we only // need to have columns for the 0th state element at every step. This is because the @@ -241,8 +238,8 @@ fn eval_external_round( } // Apply the linear layer. - let mut linear_layer_output: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero()); - external_linear_layer(&sbox_deg_7, &mut linear_layer_output); + let mut state = sbox_deg_7; + external_linear_layer(&mut state); let next_state_cols = if r == NUM_EXTERNAL_ROUNDS / 2 - 1 { &cols.internal_rounds.state @@ -252,7 +249,7 @@ fn eval_external_round( &cols.external_rounds[r + 1].state }; for i in 0..WIDTH { - builder.assert_eq(next_state_cols[i], linear_layer_output[i].clone()); + builder.assert_eq(next_state_cols[i], state[i].clone()); } } @@ -277,9 +274,8 @@ fn eval_internal_rounds(builder: &mut AB, cols: &Poseidon2Wid // Apply the linear layer. // See `populate_internal_rounds` for why we don't have columns for the new state here. - let mut linear_layer_input = state.clone(); - linear_layer_input[0] = sbox_deg_7.clone(); - internal_linear_layer(&linear_layer_input, &mut state); + state[0] = sbox_deg_7.clone(); + internal_linear_layer(&mut state); if r < NUM_INTERNAL_ROUNDS - 1 { builder.assert_eq(round_cols.s0[r], state[0].clone()); @@ -311,10 +307,10 @@ where // Apply the initial round. let initial_round_output = { - let input: [AB::Expr; WIDTH] = core::array::from_fn(|i| cols.input[i].into()); - let mut output: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero()); - external_linear_layer(&input, &mut output); - output + let mut initial_round_output: [AB::Expr; WIDTH] = + core::array::from_fn(|i| cols.input[i].into()); + external_linear_layer(&mut initial_round_output); + initial_round_output }; for i in 0..WIDTH { builder.assert_eq( @@ -400,7 +396,6 @@ mod tests { /// A test proving 2^10 permuations #[test] - #[ignore] fn prove_babybear() { let config = BabyBearPoseidon2Inner::new(); let mut challenger = config.challenger(); diff --git a/recursion/core/src/poseidon2_wide/mod.rs b/recursion/core/src/poseidon2_wide/mod.rs index b58402c22c..ee93a24dbb 100644 --- a/recursion/core/src/poseidon2_wide/mod.rs +++ b/recursion/core/src/poseidon2_wide/mod.rs @@ -40,35 +40,30 @@ pub fn matmul_internal, const WIDTH: usize>( state[i] += sum.clone(); } } -pub(crate) fn external_linear_layer( - input: &[AF; WIDTH], - output: &mut [AF; WIDTH], -) { - output.clone_from_slice(input); +pub(crate) fn external_linear_layer(state: &mut [AF; WIDTH]) { for j in (0..WIDTH).step_by(4) { - apply_m_4(&mut output[j..j + 4]); + apply_m_4(&mut state[j..j + 4]); } let sums: [AF; 4] = core::array::from_fn(|k| { (0..WIDTH) .step_by(4) - .map(|j| output[j + k].clone()) + .map(|j| state[j + k].clone()) .sum::() }); for j in 0..WIDTH { - output[j] += sums[j % 4].clone(); + state[j] += sums[j % 4].clone(); } } -pub(crate) fn internal_linear_layer(input: &[F; WIDTH], output: &mut [F; WIDTH]) { - output.clone_from_slice(input); +pub(crate) fn internal_linear_layer(state: &mut [F; WIDTH]) { let matmul_constants: [::F; WIDTH] = MATRIX_DIAG_16_BABYBEAR_U32 .iter() .map(|x| ::F::from_wrapped_u32(*x)) .collect::>() .try_into() .unwrap(); - matmul_internal(output, matmul_constants); + matmul_internal(state, matmul_constants); } pub const MATRIX_DIAG_16_BABYBEAR_U32: [u32; 16] = [