Skip to content

Commit

Permalink
feat: stark cleanup and verification (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirhemo authored Apr 22, 2024
1 parent b579519 commit ab103ee
Show file tree
Hide file tree
Showing 14 changed files with 288 additions and 49 deletions.
3 changes: 3 additions & 0 deletions core/src/cpu/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ impl CpuChip {
// Verify that the clk increments are correct. Most clk increment should be 4, but for some
// precompiles, there are additional cycles.
let num_extra_cycles = self.get_num_extra_ecall_cycles::<AB>(local);

// We already assert that `local.clk < 2^24`. `num_extra_cycles` is an entry of a word and
// therefore less than `2^8`, this means that the sum cannot overflow in a 31 bit field.
let expected_next_clk =
local.clk + AB::Expr::from_canonical_u32(4) + num_extra_cycles.clone();

Expand Down
29 changes: 27 additions & 2 deletions core/src/stark/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use p3_util::log2_ceil_usize;

use crate::{
air::{MachineAir, MultiTableAirBuilder, SP1AirBuilder},
lookup::{Interaction, InteractionBuilder},
lookup::{Interaction, InteractionBuilder, InteractionKind},
};

use super::{eval_permutation_constraints, generate_permutation_trace};
use super::{eval_permutation_constraints, generate_permutation_trace, permutation_trace_width};

/// An Air that encodes lookups based on interactions.
pub struct Chip<F: Field, A> {
Expand Down Expand Up @@ -74,10 +74,21 @@ where
}
}

#[inline]
pub fn num_interactions(&self) -> usize {
self.sends.len() + self.receives.len()
}

#[inline]
pub fn num_sends_by_kind(&self, kind: InteractionKind) -> usize {
self.sends.iter().filter(|i| i.kind == kind).count()
}

#[inline]
pub fn num_receives_by_kind(&self, kind: InteractionKind) -> usize {
self.receives.iter().filter(|i| i.kind == kind).count()
}

pub fn generate_permutation_trace<EF: ExtensionField<F>>(
&self,
preprocessed: Option<&RowMajorMatrix<F>>,
Expand All @@ -98,6 +109,20 @@ where
)
}

#[inline]
pub fn permutation_width(&self) -> usize {
permutation_trace_width(
self.sends().len() + self.receives().len(),
self.logup_batch_size(),
)
}

#[inline]
pub fn quotient_width(&self) -> usize {
1 << self.log_quotient_degree
}

#[inline]
pub fn logup_batch_size(&self) -> usize {
// TODO: calculate by log_quotient_degree.
2
Expand Down
42 changes: 37 additions & 5 deletions core/src/stark/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
vk: &StarkVerifyingKey<SC>,
proof: &MachineProof<SC>,
challenger: &mut SC::Challenger,
) -> Result<(PublicValuesDigest, DeferredDigest), ProgramVerificationError>
) -> Result<(PublicValuesDigest, DeferredDigest), ProgramVerificationError<SC>>
where
SC::Challenger: Clone,
A: for<'a> Air<VerifierConstraintFolder<'a, SC>>,
Expand Down Expand Up @@ -522,16 +522,48 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
}
}

#[derive(Debug)]
pub enum ProgramVerificationError {
InvalidSegmentProof(VerificationError),
InvalidGlobalProof(VerificationError),
pub enum ProgramVerificationError<SC: StarkGenericConfig> {
InvalidSegmentProof(VerificationError<SC>),
InvalidGlobalProof(VerificationError<SC>),
NonZeroCumulativeSum,
InvalidShardTransition(&'static str),
InvalidPublicValuesDigest,
DebugInteractionsFailed,
}

impl<SC: StarkGenericConfig> Debug for ProgramVerificationError<SC> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProgramVerificationError::InvalidSegmentProof(e) => {
write!(f, "Invalid segment proof: {:?}", e)
}
ProgramVerificationError::InvalidGlobalProof(e) => {
write!(f, "Invalid global proof: {:?}", e)
}
ProgramVerificationError::NonZeroCumulativeSum => {
write!(f, "Non-zero cumulative sum")
}
ProgramVerificationError::InvalidShardTransition(s) => {
write!(f, "Invalid shard transition: {}", s)
}
ProgramVerificationError::InvalidPublicValuesDigest => {
write!(f, "Invalid public values digest")
}
ProgramVerificationError::DebugInteractionsFailed => {
write!(f, "Debug interactions failed")
}
}
}
}

impl<SC: StarkGenericConfig> std::fmt::Display for ProgramVerificationError<SC> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(self, f)
}
}

impl<SC: StarkGenericConfig> std::error::Error for ProgramVerificationError<SC> {}

#[cfg(test)]
#[allow(non_snake_case)]
pub mod tests {
Expand Down
6 changes: 5 additions & 1 deletion core/src/stark/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ pub fn populate_permutation_row<F: PrimeField, EF: ExtensionField<F>>(
}
}

pub const fn permutation_trace_width(num_interactions: usize, batch_size: usize) -> usize {
(num_interactions + 1) / batch_size + 1
}

/// Generates the permutation trace for the given chip and main trace based on a variant of LogUp.
///
/// The permutation trace has (N+1)*EF::NUM_COLS columns, where N is the number of interactions in
Expand All @@ -96,7 +100,7 @@ pub(crate) fn generate_permutation_trace<F: PrimeField, EF: ExtensionField<F>>(
//
// where f_{i, c_k} is the value at row i for column c_k. The computed value is essentially a
// fingerprint for the interaction.
let permutation_trace_width = (sends.len() + receives.len() + 1) / batch_size + 1;
let permutation_trace_width = permutation_trace_width(sends.len() + receives.len(), batch_size);
let height = main.height();

let mut permutation_trace = RowMajorMatrix::new(
Expand Down
3 changes: 1 addition & 2 deletions core/src/stark/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ where
permutation_trace_on_quotient_domains,
&packed_perm_challenges,
alpha,
shard_data.public_values.clone(),
&shard_data.public_values,
)
})
.collect::<Vec<_>>()
Expand Down Expand Up @@ -505,7 +505,6 @@ where
.collect::<Vec<_>>();

ShardProof::<SC> {
index: shard_data.index,
commitment: ShardCommitment {
main_commit: shard_data.main_commit.clone(),
permutation_commit,
Expand Down
5 changes: 2 additions & 3 deletions core/src/stark/quotient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub fn quotient_values<SC, A, Mat>(
permutation_trace_on_quotient_domain: Mat,
perm_challenges: &[PackedChallenge<SC>],
alpha: SC::Challenge,
public_values: Vec<Val<SC>>,
public_values: &[Val<SC>],
) -> Vec<SC::Challenge>
where
A: for<'a> Air<ProverConstraintFolder<'a, SC>>,
Expand Down Expand Up @@ -116,7 +116,6 @@ where
.collect();

let accumulator = PackedChallenge::<SC>::zero();
let public_values = public_values.to_vec();
let mut folder = ProverConstraintFolder {
preprocessed: VerticalPair::new(
RowMajorMatrixView::new_row(&prep_local),
Expand All @@ -137,7 +136,7 @@ where
is_transition,
alpha,
accumulator,
public_values: &public_values,
public_values,
};
chip.eval(&mut folder);

Expand Down
1 change: 0 additions & 1 deletion core/src/stark/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ pub const PROOF_MAX_NUM_PVS: usize = SP1_PROOF_NUM_PV_ELTS;
#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
pub struct ShardProof<SC: StarkGenericConfig> {
pub index: usize,
pub commitment: ShardCommitment<Com<SC>>,
pub opened_values: ShardOpenedValues<Challenge<SC>>,
pub opening_proof: OpeningProof<SC>,
Expand Down
Loading

0 comments on commit ab103ee

Please sign in to comment.