From 813e2d73491d56e89deb89109eee35c22832885e Mon Sep 17 00:00:00 2001 From: Tamir Hemo Date: Wed, 3 Apr 2024 10:57:58 -0700 Subject: [PATCH] fix: public inputs in recursion program (#467) --- core/src/air/public_values.rs | 8 +++- recursion/program/src/challenger.rs | 9 ++++ recursion/program/src/constraints.rs | 50 +++++++++++++++++----- recursion/program/src/folder.rs | 7 +-- recursion/program/src/stark.rs | 64 +++++++++++++++++++++------- recursion/program/src/types.rs | 3 ++ 6 files changed, 112 insertions(+), 29 deletions(-) diff --git a/core/src/air/public_values.rs b/core/src/air/public_values.rs index e6cc34d3b6..1d5a557a5f 100644 --- a/core/src/air/public_values.rs +++ b/core/src/air/public_values.rs @@ -14,7 +14,7 @@ use super::Word; // TODO: Create a config struct that will store the num_words setting and the hash function // and initial entropy used. -const PV_DIGEST_NUM_WORDS: usize = 8; +pub const PV_DIGEST_NUM_WORDS: usize = 8; /// The PublicValuesDigest struct is used to represent the public values digest. This is the hash of all the /// bytes that the guest program has written to public values. @@ -81,6 +81,12 @@ impl From>> for Vec { } } +impl From<[T; PV_DIGEST_NUM_WORDS]> for PublicValuesDigest { + fn from(words: [T; PV_DIGEST_NUM_WORDS]) -> Self { + PublicValuesDigest(words) + } +} + /// Implement the IndexMut trait for PublicValuesDigest to index specific words. impl IndexMut for PublicValuesDigest { fn index_mut(&mut self, index: usize) -> &mut Self::Output { diff --git a/recursion/program/src/challenger.rs b/recursion/program/src/challenger.rs index 7043f6c855..3d54b15dc9 100644 --- a/recursion/program/src/challenger.rs +++ b/recursion/program/src/challenger.rs @@ -7,6 +7,15 @@ use crate::types::Commitment; pub trait CanObserveVariable { fn observe(&mut self, builder: &mut Builder, value: V); + + fn observe_slice(&mut self, builder: &mut Builder, values: &[V]) + where + V: Copy, + { + for value in values { + self.observe(builder, *value); + } + } } pub trait CanSampleVariable { diff --git a/recursion/program/src/constraints.rs b/recursion/program/src/constraints.rs index 729fd4bcec..50b7af143e 100644 --- a/recursion/program/src/constraints.rs +++ b/recursion/program/src/constraints.rs @@ -8,8 +8,11 @@ use p3_field::AbstractExtensionField; use p3_field::AbstractField; use p3_field::TwoAdicField; use sp1_core::air::MachineAir; +use sp1_core::air::PublicValuesDigest; +use sp1_core::air::Word; use sp1_core::stark::AirOpenedValues; use sp1_core::stark::{MachineChip, StarkGenericConfig}; +use sp1_recursion_compiler::ir::Felt; use crate::commit::PolynomialSpaceVariable; @@ -30,6 +33,7 @@ where builder: &mut Builder, chip: &MachineChip, opening: &ChipOpening, + public_values_digest: PublicValuesDigest>>, selectors: &LagrangeSelectors>, alpha: Ext, permutation_challenges: &[C::EF], @@ -56,6 +60,7 @@ where }; let zero: Ext = builder.eval(SC::Val::zero()); + let public_values: Vec> = public_values_digest.into(); let mut folder = RecursiveVerifierConstraintFolder { builder, preprocessed: opening.preprocessed.view(), @@ -63,6 +68,7 @@ where perm: perm_opening.view(), perm_challenges: permutation_challenges, cumulative_sum: opening.cumulative_sum, + public_values: &public_values, is_first_row: selectors.is_first_row, is_last_row: selectors.is_last_row, is_transition: selectors.is_transition, @@ -124,6 +130,7 @@ where builder: &mut Builder, chip: &MachineChip, opening: &ChipOpenedValuesVariable, + public_values_digest: PublicValuesDigest>>, trace_domain: TwoAdicMultiplicativeCosetVariable, qc_domains: Vec>, zeta: Ext, @@ -139,6 +146,7 @@ where builder, chip, &opening, + public_values_digest, &sels, alpha, permutation_challenges, @@ -156,6 +164,7 @@ mod tests { use itertools::{izip, Itertools}; use serde::{de::DeserializeOwned, Serialize}; use sp1_core::{ + air::{PublicValuesDigest, Word}, runtime::Program, stark::{ Chip, Com, Dom, MachineStark, OpeningProof, PcsProverData, RiscvAir, ShardCommitment, @@ -272,7 +281,6 @@ mod tests { } #[test] - #[ignore] fn test_verify_constraints_parts() { type SC = BabyBearPoseidon2; type F = ::Val; @@ -287,21 +295,25 @@ mod tests { let machine = A::machine(SC::default()); let (_, vk) = machine.setup(&Program::from(elf)); let mut challenger = machine.config().challenger(); - let proofs = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()) + let proof = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()) .unwrap() - .proof - .shard_proofs; + .proof; println!("Proof generated successfully"); challenger.observe(vk.commit); - proofs.iter().for_each(|proof| { + proof.shard_proofs.iter().for_each(|proof| { challenger.observe(proof.commitment.main_commit); }); + // Observe the public input digest + let pv_digest_field_elms: Vec = + PublicValuesDigest::>::new(proof.public_values_digest).into(); + challenger.observe_slice(&pv_digest_field_elms); + // Run the verify inside the DSL and compare it to the calculated value. let mut builder = VmBuilder::::default(); - for proof in proofs.into_iter().take(1) { + for proof in proof.shard_proofs.into_iter().take(1) { let ( chips, trace_domains_vals, @@ -311,6 +323,12 @@ mod tests { zeta_val, ) = get_shard_data(&machine, &proof, &mut challenger); + // Set up the public values digest. + let public_values_digest = PublicValuesDigest::from(core::array::from_fn(|i| { + let word_val = proof.public_values_digest[i]; + Word(core::array::from_fn(|j| builder.eval(word_val[j]))) + })); + for (chip, trace_domain_val, qc_domains_vals, values_vals) in izip!( chips.iter(), trace_domains_vals, @@ -341,6 +359,7 @@ mod tests { &mut builder, chip, &values, + public_values_digest, &sels, alpha, permutation_challenges.as_slice(), @@ -385,7 +404,6 @@ mod tests { } #[test] - #[ignore] fn test_verify_constraints_whole() { type SC = BabyBearPoseidon2; type F = ::Val; @@ -404,19 +422,24 @@ mod tests { SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()).unwrap(); SP1Verifier::verify_with_config(elf, &proof, machine.config().clone()).unwrap(); - let proofs = proof.proof.shard_proofs; + let proof = proof.proof; println!("Proof generated and verified successfully"); challenger.observe(vk.commit); - proofs.iter().for_each(|proof| { + proof.shard_proofs.iter().for_each(|proof| { challenger.observe(proof.commitment.main_commit); }); + // Observe the public input digest + let pv_digest_field_elms: Vec = + PublicValuesDigest::>::new(proof.public_values_digest).into(); + challenger.observe_slice(&pv_digest_field_elms); + // Run the verify inside the DSL and compare it to the calculated value. let mut builder = VmBuilder::::default(); - for proof in proofs.into_iter().take(1) { + for proof in proof.shard_proofs.into_iter().take(1) { let ( chips, trace_domains_vals, @@ -441,10 +464,17 @@ mod tests { .map(|domain| builder.eval_const(*domain)) .collect::>(); + // Set up the public values digest. + let public_values_digest = PublicValuesDigest::from(core::array::from_fn(|i| { + let word_val = proof.public_values_digest[i]; + Word(core::array::from_fn(|j| builder.eval(word_val[j]))) + })); + StarkVerifier::<_, SC>::verify_constraints::( &mut builder, chip, &opening, + public_values_digest, trace_domain, qc_domains, zeta, diff --git a/recursion/program/src/folder.rs b/recursion/program/src/folder.rs index 36ad05810a..77b667d4e8 100644 --- a/recursion/program/src/folder.rs +++ b/recursion/program/src/folder.rs @@ -5,7 +5,7 @@ use p3_air::{ use sp1_core::air::{EmptyMessageBuilder, MultiTableAirBuilder, PublicValuesBuilder}; use sp1_recursion_compiler::{ - ir::{Builder, Config, Ext}, + ir::{Builder, Config, Ext, Felt}, prelude::SymbolicExt, }; @@ -15,6 +15,7 @@ pub struct RecursiveVerifierConstraintFolder<'a, C: Config> { pub main: TwoRowMatrixView<'a, Ext>, pub perm: TwoRowMatrixView<'a, Ext>, pub perm_challenges: &'a [C::EF], + pub public_values: &'a [Felt], pub cumulative_sum: Ext, pub is_first_row: Ext, pub is_last_row: Ext, @@ -101,9 +102,9 @@ impl<'a, C: Config> EmptyMessageBuilder for RecursiveVerifierConstraintFolder<'a impl<'a, C: Config> PublicValuesBuilder for RecursiveVerifierConstraintFolder<'a, C> {} impl<'a, C: Config> AirBuilderWithPublicValues for RecursiveVerifierConstraintFolder<'a, C> { - type PublicVar = C::F; + type PublicVar = Felt; fn public_values(&self) -> &[Self::PublicVar] { - &[] + self.public_values } } diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index b7a7526903..70fb98d9b1 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -243,6 +243,7 @@ where builder, chip, &values, + proof.public_values_digest, trace_domain, qc_domains, zeta, @@ -262,6 +263,7 @@ pub(crate) mod tests { use crate::challenger::FeltChallenger; use p3_challenger::{CanObserve, FieldChallenger}; use p3_field::AbstractField; + use sp1_core::air::PublicValuesDigest; use sp1_core::runtime::Program; use sp1_core::{ air::MachineAir, @@ -269,6 +271,7 @@ pub(crate) mod tests { utils::BabyBearPoseidon2, }; use sp1_recursion_compiler::ir::Array; + use sp1_recursion_compiler::ir::Felt; use sp1_recursion_compiler::{ asm::{AsmConfig, VmBuilder}, ir::{Builder, Config, ExtConst, Usize}, @@ -276,6 +279,8 @@ pub(crate) mod tests { use sp1_recursion_core::runtime::{Runtime, DIGEST_SIZE}; use sp1_sdk::{SP1Prover, SP1Stdin}; + use sp1_core::air::Word; + use crate::{ challenger::DuplexChallengerVariable, fri::{ @@ -303,6 +308,12 @@ pub(crate) mod tests { { let index = builder.materialize(Usize::Const(proof.index)); + // Set up the public values digest. + let public_values_digest = PublicValuesDigest::from(core::array::from_fn(|i| { + let word_val = proof.public_values_digest[i]; + Word(core::array::from_fn(|j| builder.eval(word_val[j]))) + })); + // Set up the commitments. let mut main_commit: Commitment<_> = builder.dyn_array(DIGEST_SIZE); let mut permutation_commit: Commitment<_> = builder.dyn_array(DIGEST_SIZE); @@ -360,6 +371,7 @@ pub(crate) mod tests { opened_values, opening_proof, sorted_indices, + public_values_digest, } } @@ -427,7 +439,6 @@ pub(crate) mod tests { } #[test] - #[ignore] fn test_recursive_verify_shard() { // Generate a dummy proof. sp1_core::utils::setup_logger(); @@ -438,17 +449,23 @@ pub(crate) mod tests { let (_, vk) = machine.setup(&Program::from(elf)); let mut challenger_val = machine.config().challenger(); - let proofs = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()) + let proof = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()) .unwrap() - .proof - .shard_proofs; + .proof; + let mut challenger_ver = machine.config().challenger(); + machine.verify(&vk, &proof, &mut challenger_ver).unwrap(); println!("Proof generated successfully"); challenger_val.observe(vk.commit); - proofs.iter().for_each(|proof| { + proof.shard_proofs.iter().for_each(|proof| { challenger_val.observe(proof.commitment.main_commit); }); + // Observe the public input digest + let pv_digest_field_elms: Vec = + PublicValuesDigest::>::new(proof.public_values_digest).into(); + challenger_val.observe_slice(&pv_digest_field_elms); + let permutation_challenges = (0..2) .map(|_| challenger_val.sample_ext_element::()) .collect::>(); @@ -465,12 +482,18 @@ pub(crate) mod tests { challenger.observe(&mut builder, preprocessed_commit); let mut shard_proofs = vec![]; - for proof_val in proofs { + for proof_val in proof.shard_proofs { let proof = const_proof(&mut builder, &machine, proof_val); let ShardCommitment { main_commit, .. } = &proof.commitment; challenger.observe(&mut builder, main_commit.clone()); shard_proofs.push(proof); } + // Observe the public input digest + let pv_digest_felt: Vec> = pv_digest_field_elms + .iter() + .map(|x| builder.eval(*x)) + .collect(); + challenger.observe_slice(&mut builder, &pv_digest_felt); for proof in shard_proofs { StarkVerifier::::verify_shard( @@ -508,30 +531,36 @@ pub(crate) mod tests { let machine = A::machine(SC::default()); let (_, vk) = machine.setup(&Program::from(elf)); let mut challenger_val = machine.config().challenger(); - let proofs = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()) + let proof = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()) .unwrap() - .proof - .shard_proofs; + .proof; println!("Proof generated successfully"); - - proofs.iter().for_each(|proof| { + challenger_val.observe(vk.commit); + proof.shard_proofs.iter().for_each(|proof| { challenger_val.observe(proof.commitment.main_commit); }); + // Observe the public input digest + let pv_digest_field_elms: Vec = + PublicValuesDigest::>::new(proof.public_values_digest).into(); + challenger_val.observe_slice(&pv_digest_field_elms); + let permutation_challenges = (0..2) .map(|_| challenger_val.sample_ext_element::()) .collect::>(); - // Observe all the commitments. let mut builder = VmBuilder::::default(); let config = const_fri_config(&mut builder, default_fri_config()); let pcs = TwoAdicFriPcsVariable { config }; let mut challenger = DuplexChallengerVariable::new(&mut builder); + let preprocessed_commit_val: [F; DIGEST_SIZE] = vk.commit.into(); + let preprocessed_commit: Array = builder.eval_const(preprocessed_commit_val.to_vec()); + challenger.observe(&mut builder, preprocessed_commit); + let mut shard_proofs = vec![]; - for proof_val in proofs { - // Change a commitment to be incorrect. + for proof_val in proof.shard_proofs { let mut proof_val = proof_val; proof_val.commitment.main_commit = [F::zero(); DIGEST_SIZE].into(); let proof = const_proof(&mut builder, &machine, proof_val); @@ -539,7 +568,12 @@ pub(crate) mod tests { challenger.observe(&mut builder, main_commit.clone()); shard_proofs.push(proof); } - + // Observe the public input digest + let pv_digest_felt: Vec> = pv_digest_field_elms + .iter() + .map(|x| builder.eval(*x)) + .collect(); + challenger.observe_slice(&mut builder, &pv_digest_felt); for proof in shard_proofs { StarkVerifier::::verify_shard( &mut builder, diff --git a/recursion/program/src/types.rs b/recursion/program/src/types.rs index 23781219be..d247e5b92a 100644 --- a/recursion/program/src/types.rs +++ b/recursion/program/src/types.rs @@ -1,6 +1,8 @@ use p3_air::BaseAir; use p3_field::AbstractExtensionField; use p3_field::AbstractField; +use sp1_core::air::PublicValuesDigest; +use sp1_core::air::Word; use sp1_core::{ air::MachineAir, stark::{AirOpenedValues, Chip, ChipOpenedValues, ShardCommitment}, @@ -65,6 +67,7 @@ pub struct ShardProofVariable { pub commitment: ShardCommitment>, pub opened_values: ShardOpenedValuesVariable, pub opening_proof: TwoAdicPcsProofVariable, + pub public_values_digest: PublicValuesDigest>>, pub sorted_indices: Vec>, }