Skip to content

Commit

Permalink
refactor air in keccak
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirhemo committed Feb 24, 2024
1 parent ff14294 commit 22c3880
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 49 deletions.
41 changes: 20 additions & 21 deletions core/src/syscall/precompiles/keccak256/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::borrow::Borrow;

use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::AbstractField;
use p3_keccak_air::{KeccakAir, U64_LIMBS};
use p3_keccak_air::{KeccakAir, KeccakCols, NUM_KECCAK_COLS, U64_LIMBS};
use p3_matrix::MatrixRowSlices;

use crate::{
Expand All @@ -11,13 +11,13 @@ use crate::{
};

use super::{
columns::{KeccakCols, NUM_KECCAK_COLS},
columns::{KeccakMemCols, NUM_KECCAK_MEM_COLS},
KeccakPermuteChip, STATE_NUM_WORDS, STATE_SIZE,
};

impl<F> BaseAir<F> for KeccakPermuteChip {
fn width(&self) -> usize {
NUM_KECCAK_COLS
NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS
}
}

Expand All @@ -27,22 +27,23 @@ where
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &KeccakCols<AB::Var> = main.row_slice(0).borrow();

let local_keccak: &KeccakCols<AB::Var> = main.row_slice(0)[..NUM_KECCAK_COLS].borrow();
let local_mem: &KeccakMemCols<AB::Var> = main.row_slice(0)[NUM_KECCAK_COLS..].borrow();

builder.assert_eq(
(local.p3_keccak_cols.step_flags[0] + local.p3_keccak_cols.step_flags[23])
* local.is_real,
local.do_memory_check,
(local_keccak.step_flags[0] + local_keccak.step_flags[23]) * local_mem.is_real,
local_mem.do_memory_check,
);

// Constrain memory
for i in 0..STATE_NUM_WORDS as u32 {
builder.constraint_memory_access(
local.shard,
local.clk,
local.state_addr + AB::Expr::from_canonical_u32(i * 4),
&local.state_mem[i as usize],
local.do_memory_check,
local_mem.shard,
local_mem.clk,
local_mem.state_addr + AB::Expr::from_canonical_u32(i * 4),
&local_mem.state_mem[i as usize],
local_mem.do_memory_check,
);
}

Expand All @@ -53,8 +54,8 @@ where
let expr_2_pow_8 = AB::Expr::from_canonical_u32(2u32.pow(8));

for i in 0..STATE_SIZE as u32 {
let least_sig_word = local.state_mem[(i * 2) as usize].value();
let most_sig_word = local.state_mem[(i * 2 + 1) as usize].value();
let least_sig_word = local_mem.state_mem[(i * 2) as usize].value();
let most_sig_word = local_mem.state_mem[(i * 2 + 1) as usize].value();
let memory_limbs = [
least_sig_word.0[0] + least_sig_word.0[1] * expr_2_pow_8.clone(),
least_sig_word.0[2] + least_sig_word.0[3] * expr_2_pow_8.clone(),
Expand All @@ -66,28 +67,26 @@ where
let x_idx = i % 5;

// When step_flags[0] == 1, then verify memory matches with local.p3_keccak_cols.a
let a_value_limbs = local.p3_keccak_cols.a[y_idx as usize][x_idx as usize];
let a_value_limbs = local_keccak.a[y_idx as usize][x_idx as usize];
for i in 0..U64_LIMBS {
builder
.when(local.p3_keccak_cols.step_flags[0] * local.is_real)
.when(local_keccak.step_flags[0] * local_mem.is_real)
.assert_eq(memory_limbs[i].clone(), a_value_limbs[i]);
}

// When step_flags[23] == 1, then verify memory matches with local.p3_keccak_cols.a_prime_prime_prime
for i in 0..U64_LIMBS {
builder
.when(local.p3_keccak_cols.step_flags[23] * local.is_real)
.when(local_keccak.step_flags[23] * local_mem.is_real)
.assert_eq(
memory_limbs[i].clone(),
local
.p3_keccak_cols
.a_prime_prime_prime(x_idx as usize, y_idx as usize, i),
local_keccak.a_prime_prime_prime(x_idx as usize, y_idx as usize, i),
)
}
}

let mut sub_builder =
SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, self.p3_keccak_col_range.clone());
SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_COLS);

// Eval the plonky3 keccak air
self.p3_keccak.eval(&mut sub_builder);
Expand Down
12 changes: 5 additions & 7 deletions core/src/syscall/precompiles/keccak256/columns.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use core::borrow::{Borrow, BorrowMut};
use core::mem::{offset_of, size_of};
use core::mem::size_of;

use p3_keccak_air::KeccakCols as P3KeccakCols;
use sp1_derive::AlignedBorrow;

use crate::memory::MemoryReadWriteCols;
Expand All @@ -10,12 +9,11 @@ use super::STATE_NUM_WORDS;

#[derive(AlignedBorrow)]
#[repr(C)]
pub(crate) struct KeccakCols<T> {
pub(crate) struct KeccakMemCols<T> {
pub shard: T,
pub clk: T,

pub p3_keccak_cols: P3KeccakCols<T>,

// pub p3_keccak_cols: P3KeccakCols<T>,
pub state_mem: [MemoryReadWriteCols<T>; STATE_NUM_WORDS],
pub state_addr: T,

Expand All @@ -24,5 +22,5 @@ pub(crate) struct KeccakCols<T> {
pub is_real: T,
}

pub const NUM_KECCAK_COLS: usize = size_of::<KeccakCols<u8>>();
pub const P3_KECCAK_COLS_OFFSET: usize = offset_of!(KeccakCols<u8>, p3_keccak_cols);
pub const NUM_KECCAK_MEM_COLS: usize = size_of::<KeccakMemCols<u8>>();
// pub const P3_KECCAK_COLS_OFFSET: usize = offset_of!(KeccakCols<u8>, p3_keccak_cols);
9 changes: 2 additions & 7 deletions core/src/syscall/precompiles/keccak256/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::ops::Range;

use crate::syscall::precompiles::{MemoryReadRecord, MemoryWriteRecord};

use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as P3_NUM_KECCAK_COLS};
use p3_keccak_air::KeccakAir;

use self::columns::P3_KECCAK_COLS_OFFSET;
// use self::columns::P3_KECCAK_COLS_OFFSET;

mod air;
pub mod columns;
Expand All @@ -29,7 +27,6 @@ pub struct KeccakPermuteEvent {

pub struct KeccakPermuteChip {
p3_keccak: KeccakAir,
p3_keccak_col_range: Range<usize>,
}

impl KeccakPermuteChip {
Expand All @@ -38,8 +35,6 @@ impl KeccakPermuteChip {
let p3_keccak_air = KeccakAir {};
Self {
p3_keccak: p3_keccak_air,
p3_keccak_col_range: P3_KECCAK_COLS_OFFSET
..(P3_KECCAK_COLS_OFFSET + P3_NUM_KECCAK_COLS),
}
}
}
Expand Down
27 changes: 13 additions & 14 deletions core/src/syscall/precompiles/keccak256/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@ use std::borrow::BorrowMut;
use alloc::vec::Vec;

use p3_field::PrimeField32;
use p3_keccak_air::{generate_trace_rows, NUM_ROUNDS};
use p3_keccak_air::{generate_trace_rows, NUM_KECCAK_COLS, NUM_ROUNDS};
use p3_matrix::dense::RowMajorMatrix;

use crate::{
air::MachineAir,
runtime::ExecutionRecord,
syscall::precompiles::keccak256::{
columns::{KeccakCols, NUM_KECCAK_COLS},
STATE_SIZE,
},
air::MachineAir, runtime::ExecutionRecord, syscall::precompiles::keccak256::STATE_SIZE,
};

use super::KeccakPermuteChip;
use super::{
columns::{KeccakMemCols, NUM_KECCAK_MEM_COLS},
KeccakPermuteChip,
};

impl<F: PrimeField32> MachineAir<F> for KeccakPermuteChip {
fn name(&self) -> String {
Expand Down Expand Up @@ -75,13 +73,14 @@ impl<F: PrimeField32> MachineAir<F> for KeccakPermuteChip {

// Create all the rows for the permutation.
for (i, p3_keccak_row) in (0..NUM_ROUNDS).zip(p3_keccak_trace.rows()) {
let mut row = [F::zero(); NUM_KECCAK_COLS];
let mut row = [F::zero(); NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS];

// Copy the keccack row into the trace_row
row[..NUM_KECCAK_COLS].copy_from_slice(p3_keccak_row);

// copy over the p3_keccak_row to the row
row[self.p3_keccak_col_range.start..self.p3_keccak_col_range.end]
.copy_from_slice(p3_keccak_row);
let mem_row = &mut row[NUM_KECCAK_COLS..];

let col: &mut KeccakCols<F> = row.as_mut_slice().borrow_mut();
let col: &mut KeccakMemCols<F> = mem_row.borrow_mut();
col.shard = F::from_canonical_u32(shard);
col.clk = F::from_canonical_u32(start_clk + i as u32 * 4);

Expand Down Expand Up @@ -121,7 +120,7 @@ impl<F: PrimeField32> MachineAir<F> for KeccakPermuteChip {
// Convert the trace to a row major matrix.
RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_KECCAK_COLS,
NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS,
)
}
}

0 comments on commit 22c3880

Please sign in to comment.