Skip to content

Commit

Permalink
fix(recursion): poseidon2 chip matches plonky3 (#548)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sladuca authored Apr 19, 2024
1 parent e290ae8 commit c570c1b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 125 deletions.
136 changes: 82 additions & 54 deletions recursion/core/src/poseidon2/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ use sp1_primitives::RC_16_30_U32;
use std::borrow::BorrowMut;
use tracing::instrument;

use super::{apply_m_4, matmul_internal, MATRIX_DIAG_16_BABYBEAR_U32};
use crate::poseidon2_wide::{
apply_m_4, external_linear_layer, internal_linear_layer, matmul_internal,
MATRIX_DIAG_16_BABYBEAR_U32,
};
use crate::runtime::{ExecutionRecord, RecursionProgram};

/// The number of main trace columns for `AddChip`.
Expand Down Expand Up @@ -60,9 +63,9 @@ impl<F: PrimeField32> MachineAir<F> for Poseidon2Chip {

let rounds_f = 8;
let rounds_p = 22;
let rounds = rounds_f + rounds_p;
let rounds_f_beginning = rounds_f / 2;
let p_end = rounds_f_beginning + rounds_p;
let rounds = rounds_f + rounds_p + 1;
let rounds_p_beginning = 1 + rounds_f / 2;
let p_end = rounds_p_beginning + rounds_p;

for poseidon2_event in input.poseidon2_events.iter() {
let mut round_input = poseidon2_event.input;
Expand All @@ -75,8 +78,8 @@ impl<F: PrimeField32> MachineAir<F> for Poseidon2Chip {

cols.rounds[r] = F::one();
let is_initial_layer = r == 0;
let is_external_layer = r != 0
&& (((r - 1) < rounds_f_beginning) || (p_end <= (r - 1) && (r - 1) < rounds));
let is_external_layer =
(r >= 1 && r < rounds_p_beginning) || (r >= p_end && r < rounds);

if is_initial_layer {
// Mark the selector as initial.
Expand Down Expand Up @@ -121,23 +124,9 @@ impl<F: PrimeField32> MachineAir<F> for Poseidon2Chip {

// Apply either the external or internal linear layer.
if cols.is_initial == F::one() || cols.is_external == F::one() {
for j in (0..WIDTH).step_by(4) {
apply_m_4(&mut state[j..j + 4]);
}
let sums: [F; 4] = core::array::from_fn(|k| {
(0..WIDTH).step_by(4).map(|j| state[j + k]).sum::<F>()
});
for j in 0..WIDTH {
state[j] += sums[j % 4];
}
external_linear_layer(&mut state);
} else if cols.is_internal == F::one() {
let matmul_constants: [F; WIDTH] = MATRIX_DIAG_16_BABYBEAR_U32
.iter()
.map(|x| F::from_wrapped_u32(*x))
.collect::<Vec<_>>()
.try_into()
.unwrap();
matmul_internal(&mut state, matmul_constants);
internal_linear_layer(&mut state)
}

// Copy the state to the output.
Expand Down Expand Up @@ -329,70 +318,109 @@ where

#[cfg(test)]
mod tests {
use std::borrow::BorrowMut;
use itertools::Itertools;
use std::borrow::Borrow;
use std::time::Instant;

use p3_baby_bear::BabyBear;
use p3_baby_bear::DiffusionMatrixBabybear;
use p3_field::AbstractField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_poseidon2::Poseidon2;
use p3_poseidon2::Poseidon2ExternalMatrixGeneral;
use sp1_core::stark::StarkGenericConfig;
use sp1_core::utils::inner_perm;
use sp1_core::{
air::MachineAir,
utils::{uni_stark_prove, BabyBearPoseidon2},
utils::{uni_stark_prove, uni_stark_verify, BabyBearPoseidon2},
};

use crate::poseidon2::external::WIDTH;
use crate::{poseidon2::external::Poseidon2Chip, runtime::ExecutionRecord};
use crate::{
poseidon2::{Poseidon2Chip, Poseidon2Event, WIDTH},
runtime::ExecutionRecord,
};
use p3_symmetric::Permutation;

use super::{Poseidon2Cols, NUM_POSEIDON2_COLS};
use super::Poseidon2Cols;

const ROWS_PER_PERMUTATION: usize = 31;

#[test]
fn generate_trace() {
let chip = Poseidon2Chip;
let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(
&ExecutionRecord::<BabyBear>::default(),
&mut ExecutionRecord::<BabyBear>::default(),
);
println!("{:?}", trace.values)
let test_inputs = vec![
[BabyBear::from_canonical_u32(1); WIDTH],
[BabyBear::from_canonical_u32(2); WIDTH],
[BabyBear::from_canonical_u32(3); WIDTH],
[BabyBear::from_canonical_u32(4); WIDTH],
];

let gt: Poseidon2<
BabyBear,
Poseidon2ExternalMatrixGeneral,
DiffusionMatrixBabybear,
16,
7,
> = inner_perm();

let expected_outputs = test_inputs
.iter()
.map(|input| gt.permute(*input))
.collect::<Vec<_>>();

let mut input_exec = ExecutionRecord::<BabyBear>::default();
for input in test_inputs.iter().cloned() {
input_exec.poseidon2_events.push(Poseidon2Event { input });
}

let trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());

for (i, expected_output) in expected_outputs.iter().enumerate() {
let row = trace.row(ROWS_PER_PERMUTATION * (i + 1) - 1).collect_vec();
let cols: &Poseidon2Cols<BabyBear> = row.as_slice().borrow();
assert_eq!(expected_output, &cols.output);
}
}

#[test]
#[ignore]
fn prove_babybear() {
let config = BabyBearPoseidon2::new();
let mut challenger = config.challenger();

let chip = Poseidon2Chip;
let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(
&ExecutionRecord::<BabyBear>::default(),
&mut ExecutionRecord::<BabyBear>::default(),

let test_inputs = (0..16)
.map(|i| [BabyBear::from_canonical_u32(i); WIDTH])
.collect_vec();

let mut input_exec = ExecutionRecord::<BabyBear>::default();
for input in test_inputs.iter().cloned() {
input_exec.poseidon2_events.push(Poseidon2Event { input });
}
let trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());
println!(
"trace dims is width: {:?}, height: {:?}",
trace.width(),
trace.height()
);

let gt: Poseidon2<
let start = Instant::now();
let proof = uni_stark_prove(&config, &chip, &mut challenger, trace);
let duration = start.elapsed().as_secs_f64();
println!("proof duration = {:?}", duration);

let mut challenger: p3_challenger::DuplexChallenger<
BabyBear,
Poseidon2ExternalMatrixGeneral,
DiffusionMatrixBabybear,
Poseidon2<BabyBear, Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabybear, 16, 7>,
16,
7,
> = inner_perm();
let input = [BabyBear::one(); WIDTH];
let output = gt.permute(input);

let mut row: [BabyBear; NUM_POSEIDON2_COLS] = trace.values
[NUM_POSEIDON2_COLS * 30..(NUM_POSEIDON2_COLS) * 31]
.try_into()
.unwrap();
let cols: &mut Poseidon2Cols<BabyBear> = row.as_mut_slice().borrow_mut();
assert_eq!(cols.output, output);

> = config.challenger();
let start = Instant::now();
uni_stark_prove(&config, &chip, &mut challenger, trace);
uni_stark_verify(&config, &chip, &mut challenger, &proof)
.expect("expected proof to be valid");

let duration = start.elapsed().as_secs_f64();
println!("duration = {:?}", duration);
println!("verify duration = {:?}", duration);
}
}
38 changes: 0 additions & 38 deletions recursion/core/src/poseidon2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#![allow(clippy::needless_range_loop)]

use crate::poseidon2::external::WIDTH;
use p3_field::{AbstractField, Field};

mod external;

pub use external::Poseidon2Chip;
Expand All @@ -11,39 +9,3 @@ pub use external::Poseidon2Chip;
pub struct Poseidon2Event<F> {
pub input: [F; WIDTH],
}

// TODO: Make this public inside Plonky3 and import directly.
pub fn apply_m_4<AF>(x: &mut [AF])
where
AF: AbstractField,
{
let t0 = x[0].clone() + x[1].clone();
let t1 = x[2].clone() + x[3].clone();
let t2 = x[1].clone() + x[1].clone() + t1.clone();
let t3 = x[3].clone() + x[3].clone() + t0.clone();
let t4 = t1.clone() + t1.clone() + t1.clone() + t1 + t3.clone();
let t5 = t0.clone() + t0.clone() + t0.clone() + t0 + t2.clone();
let t6 = t3 + t5.clone();
let t7 = t2 + t4.clone();
x[0] = t6;
x[1] = t5;
x[2] = t7;
x[3] = t4;
}

// TODO: Make this public inside Plonky3 and import directly.
pub fn matmul_internal<F: Field, AF: AbstractField<F = F>, const WIDTH: usize>(
state: &mut [AF; WIDTH],
mat_internal_diag_m_1: [F; WIDTH],
) {
let sum: AF = state.iter().cloned().sum();
for i in 0..WIDTH {
state[i] *= AF::from_f(mat_internal_diag_m_1[i]);
state[i] += sum.clone();
}
}

pub const MATRIX_DIAG_16_BABYBEAR_U32: [u32; 16] = [
0x0a632d94, 0x6db657b7, 0x56fbdc9e, 0x052b3d8a, 0x33745201, 0x5c03108c, 0x0beba37b, 0x258c2e8b,
0x12029f39, 0x694909ce, 0x6d231724, 0x21c3b222, 0x3c0904a5, 0x01d6acda, 0x27705c83, 0x5231c802,
];
39 changes: 17 additions & 22 deletions recursion/core/src/poseidon2_wide/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ impl<F: PrimeField32> MachineAir<F> for Poseidon2WideChip {
let mut row = [F::zero(); NUM_POSEIDON2_WIDE_COLS];
let cols: &mut Poseidon2WideCols<F> = row.as_mut_slice().borrow_mut();

cols.input = event.input;

// Apply the initial round.
external_linear_layer(&cols.input, &mut cols.external_rounds[0].state);
cols.input = event.input;
cols.external_rounds[0].state = event.input;
external_linear_layer(&mut cols.external_rounds[0].state);

// Apply the first half of external rounds.
for r in 0..NUM_EXTERNAL_ROUNDS / 2 {
Expand Down Expand Up @@ -141,7 +141,7 @@ fn populate_external_round<F: PrimeField32>(
cols: &mut Poseidon2WideCols<F>,
r: usize,
) -> [F; WIDTH] {
let linear_layer_input = {
let mut state = {
let round_cols = cols.external_rounds[r].borrow_mut();

// Add round constants.
Expand Down Expand Up @@ -172,9 +172,8 @@ fn populate_external_round<F: PrimeField32>(
};

// Apply the linear layer.
let mut next_state = [F::zero(); WIDTH];
external_linear_layer(&linear_layer_input, &mut next_state);
next_state
external_linear_layer(&mut state);
state
}

fn populate_internal_rounds<F: PrimeField32>(cols: &mut Poseidon2WideCols<F>) -> [F; WIDTH] {
Expand All @@ -194,10 +193,8 @@ fn populate_internal_rounds<F: PrimeField32>(cols: &mut Poseidon2WideCols<F>) ->
let sbox_deg_7 = cols.sbox_deg_3[r] * cols.sbox_deg_3[r] * add_rc;

// Apply the linear layer.
let mut linear_layer_input = state;
linear_layer_input[0] = sbox_deg_7;

internal_linear_layer(&linear_layer_input, &mut state);
state[0] = sbox_deg_7;
internal_linear_layer(&mut state);

// Optimization: since we're only applying the sbox to the 0th state element, we only
// need to have columns for the 0th state element at every step. This is because the
Expand Down Expand Up @@ -241,8 +238,8 @@ fn eval_external_round<AB: SP1AirBuilder>(
}

// Apply the linear layer.
let mut linear_layer_output: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero());
external_linear_layer(&sbox_deg_7, &mut linear_layer_output);
let mut state = sbox_deg_7;
external_linear_layer(&mut state);

let next_state_cols = if r == NUM_EXTERNAL_ROUNDS / 2 - 1 {
&cols.internal_rounds.state
Expand All @@ -252,7 +249,7 @@ fn eval_external_round<AB: SP1AirBuilder>(
&cols.external_rounds[r + 1].state
};
for i in 0..WIDTH {
builder.assert_eq(next_state_cols[i], linear_layer_output[i].clone());
builder.assert_eq(next_state_cols[i], state[i].clone());
}
}

Expand All @@ -277,9 +274,8 @@ fn eval_internal_rounds<AB: SP1AirBuilder>(builder: &mut AB, cols: &Poseidon2Wid

// Apply the linear layer.
// See `populate_internal_rounds` for why we don't have columns for the new state here.
let mut linear_layer_input = state.clone();
linear_layer_input[0] = sbox_deg_7.clone();
internal_linear_layer(&linear_layer_input, &mut state);
state[0] = sbox_deg_7.clone();
internal_linear_layer(&mut state);

if r < NUM_INTERNAL_ROUNDS - 1 {
builder.assert_eq(round_cols.s0[r], state[0].clone());
Expand Down Expand Up @@ -311,10 +307,10 @@ where

// Apply the initial round.
let initial_round_output = {
let input: [AB::Expr; WIDTH] = core::array::from_fn(|i| cols.input[i].into());
let mut output: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero());
external_linear_layer(&input, &mut output);
output
let mut initial_round_output: [AB::Expr; WIDTH] =
core::array::from_fn(|i| cols.input[i].into());
external_linear_layer(&mut initial_round_output);
initial_round_output
};
for i in 0..WIDTH {
builder.assert_eq(
Expand Down Expand Up @@ -400,7 +396,6 @@ mod tests {

/// A test proving 2^10 permuations
#[test]
#[ignore]
fn prove_babybear() {
let config = BabyBearPoseidon2Inner::new();
let mut challenger = config.challenger();
Expand Down
17 changes: 6 additions & 11 deletions recursion/core/src/poseidon2_wide/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,30 @@ pub fn matmul_internal<F: Field, AF: AbstractField<F = F>, const WIDTH: usize>(
state[i] += sum.clone();
}
}
pub(crate) fn external_linear_layer<AF: AbstractField>(
input: &[AF; WIDTH],
output: &mut [AF; WIDTH],
) {
output.clone_from_slice(input);
pub(crate) fn external_linear_layer<AF: AbstractField>(state: &mut [AF; WIDTH]) {
for j in (0..WIDTH).step_by(4) {
apply_m_4(&mut output[j..j + 4]);
apply_m_4(&mut state[j..j + 4]);
}
let sums: [AF; 4] = core::array::from_fn(|k| {
(0..WIDTH)
.step_by(4)
.map(|j| output[j + k].clone())
.map(|j| state[j + k].clone())
.sum::<AF>()
});

for j in 0..WIDTH {
output[j] += sums[j % 4].clone();
state[j] += sums[j % 4].clone();
}
}

pub(crate) fn internal_linear_layer<F: AbstractField>(input: &[F; WIDTH], output: &mut [F; WIDTH]) {
output.clone_from_slice(input);
pub(crate) fn internal_linear_layer<F: AbstractField>(state: &mut [F; WIDTH]) {
let matmul_constants: [<F as AbstractField>::F; WIDTH] = MATRIX_DIAG_16_BABYBEAR_U32
.iter()
.map(|x| <F as AbstractField>::F::from_wrapped_u32(*x))
.collect::<Vec<_>>()
.try_into()
.unwrap();
matmul_internal(output, matmul_constants);
matmul_internal(state, matmul_constants);
}

pub const MATRIX_DIAG_16_BABYBEAR_U32: [u32; 16] = [
Expand Down

0 comments on commit c570c1b

Please sign in to comment.