Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: trace gen #305

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ debug = ["parallel"]
debug-proof = ["parallel", "perf"]
serial = []
neon = ["p3-blake3/neon"]
keccak = []

[[bench]]
name = "main"
Expand Down
178 changes: 103 additions & 75 deletions core/src/syscall/precompiles/keccak256/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
use p3_field::PrimeField32;
use p3_keccak_air::{generate_trace_rows, NUM_KECCAK_COLS, NUM_ROUNDS};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::prelude::ParallelIterator;
use p3_maybe_rayon::prelude::ParallelSlice;
use tracing::instrument;

use crate::{
air::MachineAir, runtime::ExecutionRecord, syscall::precompiles::keccak256::STATE_SIZE,
Expand All @@ -20,6 +23,7 @@
"KeccakPermute".to_string()
}

#[instrument(name = "generate keccak permute trace", skip_all)]
fn generate_trace(
&self,
input: &ExecutionRecord,
Expand All @@ -39,88 +43,112 @@
num_total_permutations = 1;
}

let mut new_field_events = Vec::new();
let mut rows = Vec::new();
for permutation_num in 0..num_total_permutations {
let is_real_permutation = permutation_num < num_real_permutations;

let event = if is_real_permutation {
Some(&input.keccak_permute_events[permutation_num])
} else {
None
};

let perm_input: [u64; STATE_SIZE] = if is_real_permutation {
event.unwrap().pre_state
} else {
[0; STATE_SIZE]
};

let start_clk = if is_real_permutation {
event.unwrap().clk
} else {
0
};

let shard = if is_real_permutation {
event.unwrap().shard
} else {
0
};

// First get the trace for the plonky3 keccak air.
let p3_keccak_trace = generate_trace_rows::<F>(vec![perm_input]);

// Create all the rows for the permutation.
for (i, p3_keccak_row) in (0..NUM_ROUNDS).zip(p3_keccak_trace.rows()) {
let mut row = [F::zero(); NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS];

// Copy the keccack row into the trace_row
row[..NUM_KECCAK_COLS].copy_from_slice(p3_keccak_row);

let mem_row = &mut row[NUM_KECCAK_COLS..];

let col: &mut KeccakMemCols<F> = mem_row.borrow_mut();
col.shard = F::from_canonical_u32(shard);
col.clk = F::from_canonical_u32(start_clk + i as u32 * 4);

// if this is the first row, then populate read memory accesses
if i == 0 && is_real_permutation {
for (j, read_record) in event.unwrap().state_read_records.iter().enumerate() {
col.state_mem[j].populate_read(*read_record, &mut new_field_events);
let chunk_size = std::cmp::max(num_total_permutations / num_cpus::get(), 1);

let rows_and_records = (0..num_total_permutations)
.collect::<Vec<_>>()
.par_chunks(chunk_size)
.map(|chunk_indices| {
let mut record = ExecutionRecord::default();
let mut new_field_events = Vec::new();
let mut chunk_rows = Vec::new();

chunk_indices.iter().for_each(|permutation_num| {
let is_real_permutation = *permutation_num < num_real_permutations;

let event = if is_real_permutation {
Some(&input.keccak_permute_events[*permutation_num])
} else {
None
};

let perm_input: [u64; STATE_SIZE] = if is_real_permutation {
event.unwrap().pre_state
} else {
[0; STATE_SIZE]
};

let start_clk = if is_real_permutation {
event.unwrap().clk
} else {
0
};

// First get the trace for the plonky3 keccak air.
let p3_keccak_trace = generate_trace_rows::<F>(vec![perm_input]);

let mut rows = Vec::new();

// Create all the rows for the permutation.
for (i, p3_keccak_row) in (0..NUM_ROUNDS).zip(p3_keccak_trace.rows()) {
// TODO: Is p3_keccak_trace_rows always 24 long? If so, we can just enumerate it.
let row_num = permutation_num * NUM_ROUNDS + i;
if row_num == num_rows {
break;
}

let mut row = [F::zero(); NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS];

// Copy the keccack row into the trace_row
row[..NUM_KECCAK_COLS].copy_from_slice(p3_keccak_row);

let mem_row = &mut row[NUM_KECCAK_COLS..];

let col: &mut KeccakMemCols<F> = mem_row.borrow_mut();

// if this is the first row, then populate read memory accesses
if i == 0 && is_real_permutation {
for (j, read_record) in
event.unwrap().state_read_records.iter().enumerate()
{
col.state_mem[j].populate_read(*read_record, &mut new_field_events);
}

col.state_addr = F::from_canonical_u32(event.unwrap().state_addr);
col.do_memory_check = F::one();
}

// if this is the last row, then populate write memory accesses
// if this is the last row, then populate write memory accesses
let last_row_num = NUM_ROUNDS - 1;
if i == last_row_num && is_real_permutation {
for (j, write_record) in
event.unwrap().state_write_records.iter().enumerate()
{
col.state_mem[j]
.populate_write(*write_record, &mut new_field_events);
}

col.state_addr = F::from_canonical_u32(event.unwrap().state_addr);
col.do_memory_check = F::one();
}

col.is_real = F::from_bool(is_real_permutation);

rows.push(row);
}

col.state_addr = F::from_canonical_u32(event.unwrap().state_addr);
col.do_memory_check = F::one();
}

// if this is the last row, then populate write memory accesses
let last_row_num = NUM_ROUNDS - 1;
if i == last_row_num && is_real_permutation {
for (j, write_record) in event.unwrap().state_write_records.iter().enumerate() {
col.state_mem[j].populate_write(*write_record, &mut new_field_events);
}

col.state_addr = F::from_canonical_u32(event.unwrap().state_addr);
col.do_memory_check = F::one();
}

col.is_real = F::from_bool(is_real_permutation);

rows.push(row);

if rows.len() == num_rows {
break;
}
}
chunk_rows.extend(rows);
});
record.add_field_events(&new_field_events);

(chunk_rows, record)
})
.collect::<Vec<_>>();

// Generate the trace rows for each event.
let mut rows: Vec<[F; NUM_KECCAK_COLS]> = vec![];
for mut row_and_record in rows_and_records {
rows.extend(row_and_record.0);

Check failure on line 141 in core/src/syscall/precompiles/keccak256/trace.rs

View workflow job for this annotation

GitHub Actions / Formatting & Clippy

the trait bound `std::vec::Vec<[F; 2733]>: std::iter::Extend<[F; 3388]>` is not satisfied

Check failure on line 141 in core/src/syscall/precompiles/keccak256/trace.rs

View workflow job for this annotation

GitHub Actions / CI Test Suite

the trait bound `Vec<[F; 2733]>: Extend<[F; 3388]>` is not satisfied
output.append(&mut row_and_record.1);
}

output.add_field_events(&new_field_events);

// Convert the trace to a row major matrix.
RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS,
)
}
}

#[cfg(test)]
mod tests {}
117 changes: 78 additions & 39 deletions core/src/syscall/precompiles/weierstrass/weierstrass_double.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ use p3_field::AbstractField;
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::MatrixRowSlices;
use p3_maybe_rayon::prelude::ParallelIterator;
use p3_maybe_rayon::prelude::ParallelSlice;
use sp1_derive::AlignedBorrow;
use std::fmt::Debug;
use std::marker::PhantomData;
use tracing::instrument;

pub const NUM_WEIERSTRASS_DOUBLE_COLS: usize = size_of::<WeierstrassDoubleAssignCols<u8>>();

Expand Down Expand Up @@ -162,49 +165,66 @@ impl<F: PrimeField32, E: EllipticCurve + WeierstrassParameters> MachineAir<F>
"WeierstrassDoubleAssign".to_string()
}

#[instrument(name = "generate WeierstrassDoubleAssign trace", skip_all)]
fn generate_trace(
&self,
input: &ExecutionRecord,
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
let mut rows = Vec::new();

let mut new_field_events = Vec::new();

for i in 0..input.weierstrass_double_events.len() {
let event = input.weierstrass_double_events[i];
let mut row = [F::zero(); NUM_WEIERSTRASS_DOUBLE_COLS];
let cols: &mut WeierstrassDoubleAssignCols<F> = row.as_mut_slice().borrow_mut();

// Decode affine points.
let p = &event.p;
let p = AffinePoint::<E>::from_words_le(p);
let (p_x, p_y) = (p.x, p.y);

// Populate basic columns.
cols.is_real = F::one();
cols.shard = F::from_canonical_u32(event.shard);
cols.clk = F::from_canonical_u32(event.clk);
cols.p_ptr = F::from_canonical_u32(event.p_ptr);

Self::populate_field_ops(cols, p_x, p_y);

// Populate the memory access columns.
for i in 0..NUM_WORDS_EC_POINT {
cols.p_access[i].populate(event.p_memory_records[i], &mut new_field_events);
}

rows.push(row);
let chunk_size = std::cmp::max(input.weierstrass_double_events.len() / num_cpus::get(), 1);

let rows_and_records = input
.weierstrass_double_events
.par_chunks(chunk_size)
.map(|events| {
let mut record = ExecutionRecord::default();
let mut new_field_events = Vec::new();
let mut rows = events
.iter()
.map(|event| {
let mut row = [F::zero(); NUM_WEIERSTRASS_DOUBLE_COLS];
let cols: &mut WeierstrassDoubleAssignCols<F> =
row.as_mut_slice().borrow_mut();

// Decode affine points.
let p = &event.p;
let p = AffinePoint::<E>::from_words_le(p);
let (p_x, p_y) = (p.x, p.y);

// Populate basic columns.
cols.is_real = F::one();
cols.shard = F::from_canonical_u32(event.shard);
cols.clk = F::from_canonical_u32(event.clk);
cols.p_ptr = F::from_canonical_u32(event.p_ptr);

Self::populate_field_ops(cols, p_x, p_y);

// Populate the memory access columns.
for i in 0..NUM_WORDS_EC_POINT {
cols.p_access[i]
.populate(event.p_memory_records[i], &mut new_field_events);
}
row
})
.collect::<Vec<_>>();
record.add_field_events(&new_field_events);
pad_rows(&mut rows, || {
let mut row = [F::zero(); NUM_WEIERSTRASS_DOUBLE_COLS];
let cols: &mut WeierstrassDoubleAssignCols<F> = row.as_mut_slice().borrow_mut();
let zero = BigUint::zero();
Self::populate_field_ops(cols, zero.clone(), zero.clone());
row
});
(rows, record)
})
.collect::<Vec<_>>();

// Generate the trace rows for each event.
let mut rows: Vec<[F; NUM_WEIERSTRASS_DOUBLE_COLS]> = vec![];
for mut row_and_record in rows_and_records {
rows.extend(row_and_record.0);
output.append(&mut row_and_record.1);
}
output.add_field_events(&new_field_events);

pad_rows(&mut rows, || {
let mut row = [F::zero(); NUM_WEIERSTRASS_DOUBLE_COLS];
let cols: &mut WeierstrassDoubleAssignCols<F> = row.as_mut_slice().borrow_mut();
let zero = BigUint::zero();
Self::populate_field_ops(cols, zero.clone(), zero.clone());
row
});

// Convert the trace to a row major matrix.
RowMajorMatrix::new(
Expand Down Expand Up @@ -344,11 +364,30 @@ where
#[cfg(test)]
pub mod tests {

use p3_baby_bear::BabyBear;
use p3_matrix::dense::RowMajorMatrix;

use crate::{
runtime::Program,
utils::{run_test, setup_logger, tests::SECP256K1_DOUBLE_ELF},
air::MachineAir,
alu::AluEvent,
runtime::{ExecutionRecord, Opcode, Program},
utils::{
ec::weierstrass::secp256k1::Secp256k1, run_test, setup_logger,
tests::SECP256K1_DOUBLE_ELF,
},
};

use super::WeierstrassDoubleAssignChip;

#[test]
fn generate_trace() {
let mut shard = ExecutionRecord::default();
shard.add_events = vec![AluEvent::new(0, Opcode::ADD, 14, 8, 6)];
let chip = WeierstrassDoubleAssignChip::<Secp256k1>::new();
let _trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&shard, &mut ExecutionRecord::default());
}

#[test]
fn test_secp256k1_double_simple() {
setup_logger();
Expand Down
Binary file modified tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf
Binary file not shown.
Loading
Loading