diff --git a/core/src/stark/chip.rs b/core/src/stark/chip.rs index af0957e53b..881bcc3094 100644 --- a/core/src/stark/chip.rs +++ b/core/src/stark/chip.rs @@ -87,14 +87,21 @@ where where F: PrimeField, { + let batch_size = self.logup_batch_size(); generate_permutation_trace( &self.sends, &self.receives, preprocessed, main, random_elements, + batch_size, ) } + + pub fn logup_batch_size(&self) -> usize { + // TODO: calculate by log_quotient_degree. + 2 + } } impl BaseAir for Chip @@ -156,7 +163,8 @@ where // Evaluate the execution trace constraints. self.air.eval(builder); // Evaluate permutation constraints. - eval_permutation_constraints(&self.sends, &self.receives, builder); + let batch_size = self.logup_batch_size(); + eval_permutation_constraints(&self.sends, &self.receives, batch_size, builder); } } diff --git a/core/src/stark/permutation.rs b/core/src/stark/permutation.rs index 9c39cf4766..067c1ffd52 100644 --- a/core/src/stark/permutation.rs +++ b/core/src/stark/permutation.rs @@ -1,9 +1,9 @@ +use itertools::Itertools; use p3_air::{ExtensionBuilder, PairBuilder}; -use p3_field::{AbstractField, ExtensionField, Field, Powers, PrimeField}; -use p3_matrix::{dense::RowMajorMatrix, Matrix, MatrixRowSlices}; +use p3_field::{AbstractExtensionField, AbstractField, ExtensionField, Field, Powers, PrimeField}; +use p3_matrix::{dense::RowMajorMatrix, Matrix, MatrixRowSlices, MatrixRowSlicesMut}; use p3_maybe_rayon::prelude::*; -use super::util::batch_multiplicative_inverse_inplace; use crate::{air::MultiTableAirBuilder, lookup::Interaction}; /// Generates powers of a random element based on how many interactions there are in the chip. @@ -24,6 +24,48 @@ pub fn generate_interaction_rlc_elements( random_element.powers().skip(1).take(n).collect::>() } +#[allow(clippy::too_many_arguments)] +pub fn populate_permutation_row>( + row: &mut [EF], + preprocessed_row: &[F], + main_row: &[F], + sends: &[Interaction], + receives: &[Interaction], + alphas: &[EF], + betas: Powers, + batch_size: usize, +) { + let interaction_chunks = &sends + .iter() + .map(|int| (int, true)) + .chain(receives.iter().map(|int| (int, false))) + .chunks(batch_size); + let num_chunks = (sends.len() + receives.len() + 1) / batch_size; + assert_eq!(num_chunks + 1, row.len()); + // Compute the denominators \prod_{i\in B} row_fingerprint(alpha, beta). + for (value, chunk) in row.iter_mut().zip(interaction_chunks) { + *value = chunk + .into_iter() + .map(|(interaction, is_send)| { + let alpha = alphas[interaction.argument_index()]; + let mut denominator = alpha; + for (columns, beta) in interaction.values.iter().zip(betas.clone()) { + denominator += beta * columns.apply::(preprocessed_row, main_row) + } + let mut mult = interaction + .multiplicity + .apply::(preprocessed_row, main_row); + + if !is_send { + mult = -mult; + } + + EF::from_base(mult) / denominator + }) + .sum(); + } +} + /// Generates the permutation trace for the given chip and main trace based on a variant of LogUp. /// /// The permutation trace has (N+1)*EF::NUM_COLS columns, where N is the number of interactions in @@ -34,6 +76,7 @@ pub(crate) fn generate_permutation_trace>( preprocessed: Option<&RowMajorMatrix>, main: &RowMajorMatrix, random_elements: &[EF], + batch_size: usize, ) -> RowMajorMatrix { // Generate the RLC elements to uniquely identify each interaction. let alphas = generate_interaction_rlc_elements(sends, receives, random_elements[0]); @@ -49,119 +92,103 @@ pub(crate) fn generate_permutation_trace>( // where f_{i, c_k} is the value at row i for column c_k. The computed value is essentially a // fingerprint for the interaction. let chunk_rate = 1 << 8; - let permutation_trace_width = sends.len() + receives.len() + 1; - - let mut permutation_trace_values = { - // Compute the permutation trace values in parallel. + let permutation_trace_width = (sends.len() + receives.len() + 1) / batch_size + 1; + let height = main.height(); - match preprocessed { - Some(prep) => { - let mut values = prep - .par_row_chunks(chunk_rate) - .zip_eq(main.par_row_chunks(chunk_rate)) - .flat_map(|(prep_rows_chunk, main_rows_chunk)| { - prep_rows_chunk - .rows() - .zip(main_rows_chunk.rows()) - .flat_map(|(prep_row, main_row)| { - compute_permutation_row( - prep_row, - main_row, - sends, - receives, - &alphas, - betas.clone(), - ) - }) - .collect::>() - }) - .collect::>(); + let mut permutation_trace = RowMajorMatrix::new( + vec![EF::zero(); permutation_trace_width * height], + permutation_trace_width, + ); - // Compute the permutation trace values for the remainder. - let remainder = main.height() % chunk_rate; - for i in 0..remainder { - let perm_row = compute_permutation_row( - prep.row_slice(main.height() - remainder + i), - main.row_slice(main.height() - remainder + i), - sends, - receives, - &alphas, - betas.clone(), - ); - values.extend(perm_row); - } + // Compute the permutation trace values in parallel. - values + match preprocessed { + Some(prep) => { + permutation_trace + .par_row_chunks_mut(chunk_rate) + .zip_eq(prep.par_row_chunks(chunk_rate)) + .zip_eq(main.par_row_chunks(chunk_rate)) + .for_each(|((mut chunk, prep_rows_chunk), main_rows_chunk)| { + chunk + .rows_mut() + .zip(prep_rows_chunk.rows()) + .zip(main_rows_chunk.rows()) + .for_each(|((row, prep_row), main_row)| { + populate_permutation_row( + row, + prep_row, + main_row, + sends, + receives, + &alphas, + betas.clone(), + batch_size, + ) + }) + }); + // Compute the permutation trace values for the remainder. + let remainder = height % chunk_rate; + for i in 0..remainder { + let index = height - remainder + i; + populate_permutation_row( + permutation_trace.row_slice_mut(index), + prep.row_slice(index), + main.row_slice(index), + sends, + receives, + &alphas, + betas.clone(), + batch_size, + ); } - None => { - let mut values = main - .par_row_chunks(chunk_rate) - .flat_map(|main_rows_chunk| { - main_rows_chunk - .rows() - .flat_map(|main_row| { - compute_permutation_row( - &[], - main_row, - sends, - receives, - &alphas, - betas.clone(), - ) - }) - .collect::>() - }) - .collect::>(); - - // Compute the permutation trace values for the remainder. - let remainder = main.height() % chunk_rate; - for i in 0..remainder { - let perm_row = compute_permutation_row( - &[], - main.row_slice(main.height() - remainder + i), - sends, - receives, - &alphas, - betas.clone(), - ); - values.extend(perm_row); - } - - values + } + None => { + permutation_trace + .par_row_chunks_mut(chunk_rate) + .zip_eq(main.par_row_chunks(chunk_rate)) + .for_each(|(mut chunk, main_rows_chunk)| { + chunk + .rows_mut() + .zip(main_rows_chunk.rows()) + .for_each(|(row, main_row)| { + populate_permutation_row( + row, + &[], + main_row, + sends, + receives, + &alphas, + betas.clone(), + batch_size, + ) + }) + }); + // Compute the permutation trace values for the remainder. + let remainder = height % chunk_rate; + for i in 0..remainder { + let index = height - remainder + i; + populate_permutation_row( + permutation_trace.row_slice_mut(index), + &[], + main.row_slice(index), + sends, + receives, + &alphas, + betas.clone(), + batch_size, + ); } } - }; - - // The permutation trace is actually the multiplicative inverse of the RLC's we computed above. - permutation_trace_values - .chunks_mut(chunk_rate) - .par_bridge() - .for_each(|chunk| batch_multiplicative_inverse_inplace(chunk)); - let mut permutation_trace = - RowMajorMatrix::new(permutation_trace_values, permutation_trace_width); + } - // Weight each row of the permutation trace by the respective multiplicities. - let mut phi = vec![EF::zero(); permutation_trace.height()]; - let nb_sends = sends.len(); - for (i, (main_row, permutation_row)) in main - .rows() - .zip(permutation_trace.as_view_mut().rows_mut()) - .enumerate() - { - if i > 0 { - phi[i] = phi[i - 1]; - } - // All all sends - for (j, send) in sends.iter().enumerate() { - let mult = send.multiplicity.apply::(&[], main_row); - phi[i] += EF::from_base(mult) * permutation_row[j]; - } - // Subtract all receives - for (j, rec) in receives.iter().enumerate() { - let mult = rec.multiplicity.apply::(&[], main_row); - phi[i] -= EF::from_base(mult) * permutation_row[nb_sends + j]; - } - *permutation_row.last_mut().unwrap() = phi[i]; + // Write the cumultative sum. + let mut cumulative_sum = EF::zero(); + for permutation_row in permutation_trace.as_view_mut().rows_mut() { + cumulative_sum += permutation_row[0..permutation_trace_width - 1] + .iter() + .copied() + .sum::(); + *permutation_row.last_mut().unwrap() = cumulative_sum; } permutation_trace @@ -176,6 +203,7 @@ pub(crate) fn generate_permutation_trace>( pub fn eval_permutation_constraints( sends: &[Interaction], receives: &[Interaction], + batch_size: usize, builder: &mut AB, ) where F: Field, @@ -188,84 +216,94 @@ pub fn eval_permutation_constraints( let main = builder.main(); let main_local: &[AB::Var] = main.row_slice(0); - let main_next: &[AB::Var] = main.row_slice(1); let preprocessed = builder.preprocessed(); let preprocessed_local = preprocessed.row_slice(0); - let preprocessed_next = preprocessed.row_slice(1); let perm = builder.permutation(); let perm_width = perm.width(); let perm_local: &[AB::VarEF] = perm.row_slice(0); let perm_next: &[AB::VarEF] = perm.row_slice(1); - let phi_local = perm_local[perm_width - 1]; - let phi_next = perm_next[perm_width - 1]; - let alphas = generate_interaction_rlc_elements(sends, receives, alpha); let betas = beta.powers(); - let lhs: AB::ExprEF = phi_next.into() - phi_local.into(); - let mut rhs = AB::ExprEF::zero(); - let mut phi_0 = AB::ExprEF::zero(); + // Ensure that each batch sum m_i/f_i is computed correctly. + let interaction_chunks = &sends + .iter() + .map(|int| (int, true)) + .chain(receives.iter().map(|int| (int, false))) + .chunks(batch_size); + for (entry, chunk) in perm_local.iter().zip(interaction_chunks) { + // Assert that the i-eth entry is equal to the sum_i m_i/rlc_i by constraints: + // entry * \prod_i rlc_i = \sum_i m_i * \prod_{j!=i} rlc_j. + + // First, we calculate the random linear combinations and multiplicities with the correct + // sign depending on wetther the interaction is a send or a recieve. + let mut rlcs: Vec = Vec::with_capacity(batch_size); + let mut multiplicities: Vec = Vec::with_capacity(batch_size); + for (interaction, is_send) in chunk { + let mut rlc = AB::ExprEF::zero(); + for (field, beta) in interaction.values.iter().zip(betas.clone()) { + let elem = field.apply::(preprocessed_local, main_local); + rlc += beta * elem; + } + rlc += alphas[interaction.argument_index()].clone(); + rlcs.push(rlc); - let nb_sends = sends.len(); - for (m, interaction) in sends.iter().chain(receives.iter()).enumerate() { - // Ensure that the recipricals of the RLC's were properly calculated. - let mut rlc = AB::ExprEF::zero(); - for (field, beta) in interaction.values.iter().zip(betas.clone()) { - let elem = field.apply::(preprocessed_local, main_local); - rlc += beta * elem; + let send_factor = if is_send { AB::F::one() } else { -AB::F::one() }; + multiplicities.push( + interaction + .multiplicity + .apply::(preprocessed_local, main_local) + * send_factor, + ); } - rlc += alphas[interaction.argument_index()].clone(); - builder.assert_one_ext(rlc * perm_local[m].into()); - - let mult_local = interaction - .multiplicity - .apply::(preprocessed_local, main_local); - let mult_next = interaction - .multiplicity - .apply::(preprocessed_next, main_next); - // Ensure that the running sum is computed correctly. - if m < nb_sends { - phi_0 += perm_local[m].into() * mult_local; - rhs += perm_next[m].into() * mult_next; - } else { - phi_0 -= perm_local[m].into() * mult_local; - rhs -= perm_next[m].into() * mult_next; + // Now we can calculate the numerator and denominator of the combined batch. + let mut product = AB::ExprEF::one(); + let mut numerator = AB::ExprEF::zero(); + for (i, (m, rlc)) in multiplicities.into_iter().zip(rlcs.iter()).enumerate() { + // Calculate the running product of all rlcs. + product *= rlc.clone(); + // Calculate the product of all but the current rlc. + let mut all_but_current = AB::ExprEF::one(); + for other_rlc in rlcs + .iter() + .enumerate() + .filter(|(j, _)| i != *j) + .map(|(_, rlc)| rlc) + { + all_but_current *= other_rlc.clone(); + } + numerator += AB::ExprEF::from_base(m) * all_but_current; } + + // Finally, assert that the entry is equal to the numerator divided by the product. + let entry: AB::ExprEF = (*entry).into(); + builder.assert_eq_ext(product.clone() * entry.clone(), numerator); } - // Running sum constraints. - builder.when_transition().assert_eq_ext(lhs, rhs); + let sum_local = perm_local[..perm_width - 1] + .iter() + .map(|x| (*x).into()) + .sum::(); + + let sum_next = perm_next[..perm_width - 1] + .iter() + .map(|x| (*x).into()) + .sum::(); + + let phi_local: AB::ExprEF = (*perm_local.last().unwrap()).into(); + let phi_next: AB::ExprEF = (*perm_next.last().unwrap()).into(); builder - .when_first_row() - .assert_eq_ext(*perm_local.last().unwrap(), phi_0); + .when_transition() + .assert_eq_ext(phi_next - phi_local.clone(), sum_next); + + builder.when_first_row().assert_eq_ext(phi_local, sum_local); let cumulative_sum = builder.cumulative_sum(); builder .when_last_row() .assert_eq_ext(*perm_local.last().unwrap(), cumulative_sum); } - -/// Computes the permutation fingerprint of a row. -pub fn compute_permutation_row>( - preprocessed_row: &[F], - main_row: &[F], - sends: &[Interaction], - receives: &[Interaction], - alphas: &[EF], - betas: Powers, -) -> Vec { - let width = sends.len() + receives.len() + 1; - let mut row = vec![EF::zero(); width]; - for (i, interaction) in sends.iter().chain(receives.iter()).enumerate() { - let alpha = alphas[interaction.argument_index()]; - row[i] = alpha; - for (columns, beta) in interaction.values.iter().zip(betas.clone()) { - row[i] += beta * columns.apply::(preprocessed_row, main_row) - } - } - row -} diff --git a/core/src/stark/util.rs b/core/src/stark/util.rs index 2cae15a9f1..8ba06e98fb 100644 --- a/core/src/stark/util.rs +++ b/core/src/stark/util.rs @@ -1,6 +1,7 @@ use p3_field::Field; /// An implementation of `batch_multiplicative_inverse` that operates in place. +#[allow(dead_code)] pub fn batch_multiplicative_inverse_inplace(values: &mut [F]) { // Check if values are zero and construct a new vector with only nonzero values. let mut nonzero_values = Vec::with_capacity(values.len()); diff --git a/recursion/circuit/src/types.rs b/recursion/circuit/src/types.rs index e77560b6a1..61de3b1cb8 100644 --- a/recursion/circuit/src/types.rs +++ b/recursion/circuit/src/types.rs @@ -170,7 +170,9 @@ impl ChipOpening { local: vec![], next: vec![], }; - let permutation_width = C::EF::D * (chip.num_interactions() + 1); + let permutation_width = + C::EF::D * ((chip.num_interactions() + 1) / chip.logup_batch_size() + 1); + for i in 0..permutation_width { permutation.local.push(opening.permutation.local[i]); permutation.next.push(opening.permutation.next[i]); diff --git a/recursion/program/src/types.rs b/recursion/program/src/types.rs index 7b7ddd0c6c..ab370bddfe 100644 --- a/recursion/program/src/types.rs +++ b/recursion/program/src/types.rs @@ -102,7 +102,8 @@ impl ChipOpening { local: vec![], next: vec![], }; - let permutation_width = C::EF::D * (chip.num_interactions() + 1); + let permutation_width = + C::EF::D * ((chip.num_interactions() + 1) / chip.logup_batch_size() + 1); for i in 0..permutation_width { permutation .local