diff --git a/core/Cargo.toml b/core/Cargo.toml index 7014e87893..0717bfaa9e 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -70,6 +70,7 @@ debug = ["parallel"] debug-proof = ["parallel", "perf"] serial = [] neon = ["p3-blake3/neon"] +keccak = [] [[bench]] name = "main" diff --git a/core/src/syscall/precompiles/keccak256/trace.rs b/core/src/syscall/precompiles/keccak256/trace.rs index ff4d0d13f8..63c87ca4cf 100644 --- a/core/src/syscall/precompiles/keccak256/trace.rs +++ b/core/src/syscall/precompiles/keccak256/trace.rs @@ -5,6 +5,9 @@ use alloc::vec::Vec; use p3_field::PrimeField32; use p3_keccak_air::{generate_trace_rows, NUM_KECCAK_COLS, NUM_ROUNDS}; use p3_matrix::dense::RowMajorMatrix; +use p3_maybe_rayon::prelude::ParallelIterator; +use p3_maybe_rayon::prelude::ParallelSlice; +use tracing::instrument; use crate::{ air::MachineAir, runtime::ExecutionRecord, syscall::precompiles::keccak256::STATE_SIZE, @@ -20,6 +23,7 @@ impl MachineAir for KeccakPermuteChip { "KeccakPermute".to_string() } + #[instrument(name = "generate keccak permute trace", skip_all)] fn generate_trace( &self, input: &ExecutionRecord, @@ -39,84 +43,105 @@ impl MachineAir for KeccakPermuteChip { num_total_permutations = 1; } - let mut new_field_events = Vec::new(); - let mut rows = Vec::new(); - for permutation_num in 0..num_total_permutations { - let is_real_permutation = permutation_num < num_real_permutations; - - let event = if is_real_permutation { - Some(&input.keccak_permute_events[permutation_num]) - } else { - None - }; - - let perm_input: [u64; STATE_SIZE] = if is_real_permutation { - event.unwrap().pre_state - } else { - [0; STATE_SIZE] - }; - - let start_clk = if is_real_permutation { - event.unwrap().clk - } else { - 0 - }; - - let shard = if is_real_permutation { - event.unwrap().shard - } else { - 0 - }; - - // First get the trace for the plonky3 keccak air. - let p3_keccak_trace = generate_trace_rows::(vec![perm_input]); - - // Create all the rows for the permutation. - for (i, p3_keccak_row) in (0..NUM_ROUNDS).zip(p3_keccak_trace.rows()) { - let mut row = [F::zero(); NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS]; - - // Copy the keccack row into the trace_row - row[..NUM_KECCAK_COLS].copy_from_slice(p3_keccak_row); - - let mem_row = &mut row[NUM_KECCAK_COLS..]; - - let col: &mut KeccakMemCols = mem_row.borrow_mut(); - col.shard = F::from_canonical_u32(shard); - col.clk = F::from_canonical_u32(start_clk + i as u32 * 4); - - // if this is the first row, then populate read memory accesses - if i == 0 && is_real_permutation { - for (j, read_record) in event.unwrap().state_read_records.iter().enumerate() { - col.state_mem[j].populate_read(*read_record, &mut new_field_events); + let chunk_size = std::cmp::max(num_total_permutations / num_cpus::get(), 1); + + let rows_and_records = (0..num_total_permutations) + .collect::>() + .par_chunks(chunk_size) + .map(|chunk_indices| { + let mut record = ExecutionRecord::default(); + let mut new_field_events = Vec::new(); + let mut chunk_rows = Vec::new(); + + chunk_indices.iter().for_each(|permutation_num| { + let is_real_permutation = *permutation_num < num_real_permutations; + + let event = if is_real_permutation { + Some(&input.keccak_permute_events[*permutation_num]) + } else { + None + }; + + let perm_input: [u64; STATE_SIZE] = if is_real_permutation { + event.unwrap().pre_state + } else { + [0; STATE_SIZE] + }; + + let start_clk = if is_real_permutation { + event.unwrap().clk + } else { + 0 + }; + + // First get the trace for the plonky3 keccak air. + let p3_keccak_trace = generate_trace_rows::(vec![perm_input]); + + let mut rows = Vec::new(); + + // Create all the rows for the permutation. + for (i, p3_keccak_row) in (0..NUM_ROUNDS).zip(p3_keccak_trace.rows()) { + // TODO: Is p3_keccak_trace_rows always 24 long? If so, we can just enumerate it. + let row_num = permutation_num * NUM_ROUNDS + i; + if row_num == num_rows { + break; + } + + let mut row = [F::zero(); NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS]; + + // Copy the keccack row into the trace_row + row[..NUM_KECCAK_COLS].copy_from_slice(p3_keccak_row); + + let mem_row = &mut row[NUM_KECCAK_COLS..]; + + let col: &mut KeccakMemCols = mem_row.borrow_mut(); + + // if this is the first row, then populate read memory accesses + if i == 0 && is_real_permutation { + for (j, read_record) in + event.unwrap().state_read_records.iter().enumerate() + { + col.state_mem[j].populate_read(*read_record, &mut new_field_events); + } + + col.state_addr = F::from_canonical_u32(event.unwrap().state_addr); + col.do_memory_check = F::one(); + } + + // if this is the last row, then populate write memory accesses + // if this is the last row, then populate write memory accesses + let last_row_num = NUM_ROUNDS - 1; + if i == last_row_num && is_real_permutation { + for (j, write_record) in + event.unwrap().state_write_records.iter().enumerate() + { + col.state_mem[j] + .populate_write(*write_record, &mut new_field_events); + } + + col.state_addr = F::from_canonical_u32(event.unwrap().state_addr); + col.do_memory_check = F::one(); + } + + col.is_real = F::from_bool(is_real_permutation); + + rows.push(row); } - - col.state_addr = F::from_canonical_u32(event.unwrap().state_addr); - col.do_memory_check = F::one(); - } - - // if this is the last row, then populate write memory accesses - let last_row_num = NUM_ROUNDS - 1; - if i == last_row_num && is_real_permutation { - for (j, write_record) in event.unwrap().state_write_records.iter().enumerate() { - col.state_mem[j].populate_write(*write_record, &mut new_field_events); - } - - col.state_addr = F::from_canonical_u32(event.unwrap().state_addr); - col.do_memory_check = F::one(); - } - - col.is_real = F::from_bool(is_real_permutation); - - rows.push(row); - - if rows.len() == num_rows { - break; - } - } + chunk_rows.extend(rows); + }); + record.add_field_events(&new_field_events); + + (chunk_rows, record) + }) + .collect::>(); + + // Generate the trace rows for each event. + let mut rows: Vec<[F; NUM_KECCAK_COLS]> = vec![]; + for mut row_and_record in rows_and_records { + rows.extend(row_and_record.0); + output.append(&mut row_and_record.1); } - output.add_field_events(&new_field_events); - // Convert the trace to a row major matrix. RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), @@ -124,3 +149,6 @@ impl MachineAir for KeccakPermuteChip { ) } } + +#[cfg(test)] +mod tests {} diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index ced770596d..0cf2ef3063 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -27,9 +27,12 @@ use p3_field::AbstractField; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::MatrixRowSlices; +use p3_maybe_rayon::prelude::ParallelIterator; +use p3_maybe_rayon::prelude::ParallelSlice; use sp1_derive::AlignedBorrow; use std::fmt::Debug; use std::marker::PhantomData; +use tracing::instrument; pub const NUM_WEIERSTRASS_DOUBLE_COLS: usize = size_of::>(); @@ -162,49 +165,66 @@ impl MachineAir "WeierstrassDoubleAssign".to_string() } + #[instrument(name = "generate WeierstrassDoubleAssign trace", skip_all)] fn generate_trace( &self, input: &ExecutionRecord, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mut rows = Vec::new(); - - let mut new_field_events = Vec::new(); - - for i in 0..input.weierstrass_double_events.len() { - let event = input.weierstrass_double_events[i]; - let mut row = [F::zero(); NUM_WEIERSTRASS_DOUBLE_COLS]; - let cols: &mut WeierstrassDoubleAssignCols = row.as_mut_slice().borrow_mut(); - - // Decode affine points. - let p = &event.p; - let p = AffinePoint::::from_words_le(p); - let (p_x, p_y) = (p.x, p.y); - - // Populate basic columns. - cols.is_real = F::one(); - cols.shard = F::from_canonical_u32(event.shard); - cols.clk = F::from_canonical_u32(event.clk); - cols.p_ptr = F::from_canonical_u32(event.p_ptr); - - Self::populate_field_ops(cols, p_x, p_y); - - // Populate the memory access columns. - for i in 0..NUM_WORDS_EC_POINT { - cols.p_access[i].populate(event.p_memory_records[i], &mut new_field_events); - } - - rows.push(row); + let chunk_size = std::cmp::max(input.weierstrass_double_events.len() / num_cpus::get(), 1); + + let rows_and_records = input + .weierstrass_double_events + .par_chunks(chunk_size) + .map(|events| { + let mut record = ExecutionRecord::default(); + let mut new_field_events = Vec::new(); + let mut rows = events + .iter() + .map(|event| { + let mut row = [F::zero(); NUM_WEIERSTRASS_DOUBLE_COLS]; + let cols: &mut WeierstrassDoubleAssignCols = + row.as_mut_slice().borrow_mut(); + + // Decode affine points. + let p = &event.p; + let p = AffinePoint::::from_words_le(p); + let (p_x, p_y) = (p.x, p.y); + + // Populate basic columns. + cols.is_real = F::one(); + cols.shard = F::from_canonical_u32(event.shard); + cols.clk = F::from_canonical_u32(event.clk); + cols.p_ptr = F::from_canonical_u32(event.p_ptr); + + Self::populate_field_ops(cols, p_x, p_y); + + // Populate the memory access columns. + for i in 0..NUM_WORDS_EC_POINT { + cols.p_access[i] + .populate(event.p_memory_records[i], &mut new_field_events); + } + row + }) + .collect::>(); + record.add_field_events(&new_field_events); + pad_rows(&mut rows, || { + let mut row = [F::zero(); NUM_WEIERSTRASS_DOUBLE_COLS]; + let cols: &mut WeierstrassDoubleAssignCols = row.as_mut_slice().borrow_mut(); + let zero = BigUint::zero(); + Self::populate_field_ops(cols, zero.clone(), zero.clone()); + row + }); + (rows, record) + }) + .collect::>(); + + // Generate the trace rows for each event. + let mut rows: Vec<[F; NUM_WEIERSTRASS_DOUBLE_COLS]> = vec![]; + for mut row_and_record in rows_and_records { + rows.extend(row_and_record.0); + output.append(&mut row_and_record.1); } - output.add_field_events(&new_field_events); - - pad_rows(&mut rows, || { - let mut row = [F::zero(); NUM_WEIERSTRASS_DOUBLE_COLS]; - let cols: &mut WeierstrassDoubleAssignCols = row.as_mut_slice().borrow_mut(); - let zero = BigUint::zero(); - Self::populate_field_ops(cols, zero.clone(), zero.clone()); - row - }); // Convert the trace to a row major matrix. RowMajorMatrix::new( @@ -344,11 +364,30 @@ where #[cfg(test)] pub mod tests { + use p3_baby_bear::BabyBear; + use p3_matrix::dense::RowMajorMatrix; + use crate::{ - runtime::Program, - utils::{run_test, setup_logger, tests::SECP256K1_DOUBLE_ELF}, + air::MachineAir, + alu::AluEvent, + runtime::{ExecutionRecord, Opcode, Program}, + utils::{ + ec::weierstrass::secp256k1::Secp256k1, run_test, setup_logger, + tests::SECP256K1_DOUBLE_ELF, + }, }; + use super::WeierstrassDoubleAssignChip; + + #[test] + fn generate_trace() { + let mut shard = ExecutionRecord::default(); + shard.add_events = vec![AluEvent::new(0, Opcode::ADD, 14, 8, 6)]; + let chip = WeierstrassDoubleAssignChip::::new(); + let _trace: RowMajorMatrix = + chip.generate_trace(&shard, &mut ExecutionRecord::default()); + } + #[test] fn test_secp256k1_double_simple() { setup_logger(); diff --git a/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf index ec3c1d332a..2a3499310c 100755 Binary files a/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-double/src/main.rs b/tests/secp256k1-double/src/main.rs index 52aa791145..8a2d8b8c20 100644 --- a/tests/secp256k1-double/src/main.rs +++ b/tests/secp256k1-double/src/main.rs @@ -6,31 +6,33 @@ extern "C" { } pub fn main() { - // generator. - // 55066263022277343669578718895168534326250603453777594175500187360389116729240 - // 32670510020758816978083085130507043184471273380659243275938904335757337482424 - let mut a: [u8; 64] = [ - 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, 206, - 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, 208, 71, - 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, 101, 196, 163, - 38, 119, 218, 58, 72, - ]; + for _ in 0..16384 { + // generator. + // 55066263022277343669578718895168534326250603453777594175500187360389116729240 + // 32670510020758816978083085130507043184471273380659243275938904335757337482424 + let mut a: [u8; 64] = [ + 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, + 206, 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, + 208, 71, 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, + 101, 196, 163, 38, 119, 218, 58, 72, + ]; - unsafe { - syscall_secp256k1_double(a.as_mut_ptr() as *mut u32); - } + unsafe { + syscall_secp256k1_double(a.as_mut_ptr() as *mut u32); + } - // 2 * generator. - // 89565891926547004231252920425935692360644145829622209833684329913297188986597 - // 12158399299693830322967808612713398636155367887041628176798871954788371653930 - let b: [u8; 64] = [ - 229, 158, 112, 92, 185, 9, 172, 171, 167, 60, 239, 140, 75, 142, 119, 92, 216, 124, 192, - 149, 110, 64, 69, 48, 109, 125, 237, 65, 148, 127, 4, 198, 42, 229, 207, 80, 169, 49, 100, - 35, 225, 208, 102, 50, 101, 50, 246, 247, 238, 234, 108, 70, 25, 132, 197, 163, 57, 195, - 61, 166, 254, 104, 225, 26, - ]; + // 2 * generator. + // 89565891926547004231252920425935692360644145829622209833684329913297188986597 + // 12158399299693830322967808612713398636155367887041628176798871954788371653930 + let b: [u8; 64] = [ + 229, 158, 112, 92, 185, 9, 172, 171, 167, 60, 239, 140, 75, 142, 119, 92, 216, 124, + 192, 149, 110, 64, 69, 48, 109, 125, 237, 65, 148, 127, 4, 198, 42, 229, 207, 80, 169, + 49, 100, 35, 225, 208, 102, 50, 101, 50, 246, 247, 238, 234, 108, 70, 25, 132, 197, + 163, 57, 195, 61, 166, 254, 104, 225, 26, + ]; - assert_eq!(a, b); + assert_eq!(a, b); + } println!("done"); }