diff --git a/recursion/core/src/poseidon2_wide/external.rs b/recursion/core/src/poseidon2_wide/external.rs index 40a96890ae..3c9cf067d3 100644 --- a/recursion/core/src/poseidon2_wide/external.rs +++ b/recursion/core/src/poseidon2_wide/external.rs @@ -1,6 +1,6 @@ use core::borrow::Borrow; use core::mem::size_of; -use p3_air::{Air, BaseAir}; +use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; @@ -36,6 +36,8 @@ pub struct Poseidon2WideCols { pub output: [T; WIDTH], external_rounds: [Poseidon2WideExternalRoundCols; NUM_EXTERNAL_ROUNDS], internal_rounds: Poseidon2WideInternalRoundsCols, + + pub is_real: T, } /// A grouping of columns for a single external round. @@ -84,6 +86,7 @@ impl MachineAir for Poseidon2WideChip { // Apply the initial round. cols.input = event.input; cols.external_rounds[0].state = event.input; + cols.is_real = F::one(); external_linear_layer(&mut cols.external_rounds[0].state); // Apply the first half of external rounds. @@ -213,6 +216,7 @@ fn eval_external_round( builder: &mut AB, cols: &Poseidon2WideCols, r: usize, + is_real: AB::Var, ) { let round_cols = cols.external_rounds[r]; @@ -223,7 +227,7 @@ fn eval_external_round( r + NUM_INTERNAL_ROUNDS }; let add_rc: [AB::Expr; WIDTH] = core::array::from_fn(|i| { - round_cols.state[i].into() + AB::Expr::from_wrapped_u32(RC_16_30_U32[round][i]) + round_cols.state[i].into() + is_real * AB::F::from_wrapped_u32(RC_16_30_U32[round][i]) }); // Apply the sboxes. @@ -253,7 +257,11 @@ fn eval_external_round( } } -fn eval_internal_rounds(builder: &mut AB, cols: &Poseidon2WideCols) { +fn eval_internal_rounds( + builder: &mut AB, + cols: &Poseidon2WideCols, + is_real: AB::Var, +) { let round_cols = &cols.internal_rounds; let mut state: [AB::Expr; WIDTH] = core::array::from_fn(|i| round_cols.state[i].into()); for r in 0..NUM_INTERNAL_ROUNDS { @@ -263,7 +271,7 @@ fn eval_internal_rounds(builder: &mut AB, cols: &Poseidon2Wid state[0].clone() } else { round_cols.s0[r - 1].into() - } + AB::Expr::from_wrapped_u32(RC_16_30_U32[round][0]); + } + is_real * AB::Expr::from_wrapped_u32(RC_16_30_U32[round][0]); let sbox_deg_3 = add_rc.clone() * add_rc.clone() * add_rc.clone(); builder.assert_eq(round_cols.sbox_deg_3[r], sbox_deg_3); @@ -313,7 +321,7 @@ where initial_round_output }; for i in 0..WIDTH { - builder.assert_eq( + builder.when(cols.is_real).assert_eq( cols.external_rounds[0].state[i], initial_round_output[i].clone(), ); @@ -321,15 +329,15 @@ where // Apply the first half of external rounds. for r in 0..NUM_EXTERNAL_ROUNDS / 2 { - eval_external_round(builder, cols, r); + eval_external_round(builder, cols, r, cols.is_real); } // Apply the internal rounds. - eval_internal_rounds(builder, cols); + eval_internal_rounds(builder, cols, cols.is_real); // Apply the second half of external rounds. for r in NUM_EXTERNAL_ROUNDS / 2..NUM_EXTERNAL_ROUNDS { - eval_external_round(builder, cols, r); + eval_external_round(builder, cols, r, cols.is_real); } } } @@ -396,18 +404,18 @@ mod tests { /// A test proving 2^10 permuations #[test] - fn prove_babybear() { + fn poseidon2_wide_prove_babybear() { let config = BabyBearPoseidon2Inner::new(); let mut challenger = config.challenger(); let chip = Poseidon2WideChip; - let test_inputs = (0..1024) + let test_inputs = (0..1000) .map(|i| [BabyBear::from_canonical_u32(i); WIDTH]) .collect_vec(); let mut input_exec = ExecutionRecord::::default(); - for input in test_inputs.iter().cloned() { + for input in test_inputs { input_exec.poseidon2_events.push(Poseidon2Event { input }); } let trace: RowMajorMatrix =