Skip to content

Commit

Permalink
fix: public inputs in recursion program (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirhemo authored Apr 3, 2024
1 parent bde25d6 commit 813e2d7
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 29 deletions.
8 changes: 7 additions & 1 deletion core/src/air/public_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::Word;

// TODO: Create a config struct that will store the num_words setting and the hash function
// and initial entropy used.
const PV_DIGEST_NUM_WORDS: usize = 8;
pub const PV_DIGEST_NUM_WORDS: usize = 8;

/// The PublicValuesDigest struct is used to represent the public values digest. This is the hash of all the
/// bytes that the guest program has written to public values.
Expand Down Expand Up @@ -81,6 +81,12 @@ impl<T: Debug + Copy> From<PublicValuesDigest<Word<T>>> for Vec<T> {
}
}

impl<T> From<[T; PV_DIGEST_NUM_WORDS]> for PublicValuesDigest<T> {
fn from(words: [T; PV_DIGEST_NUM_WORDS]) -> Self {
PublicValuesDigest(words)
}
}

/// Implement the IndexMut trait for PublicValuesDigest to index specific words.
impl<T> IndexMut<usize> for PublicValuesDigest<T> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
Expand Down
9 changes: 9 additions & 0 deletions recursion/program/src/challenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ use crate::types::Commitment;

pub trait CanObserveVariable<C: Config, V> {
fn observe(&mut self, builder: &mut Builder<C>, value: V);

fn observe_slice(&mut self, builder: &mut Builder<C>, values: &[V])
where
V: Copy,
{
for value in values {
self.observe(builder, *value);
}
}
}

pub trait CanSampleVariable<C: Config, V> {
Expand Down
50 changes: 40 additions & 10 deletions recursion/program/src/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ use p3_field::AbstractExtensionField;
use p3_field::AbstractField;
use p3_field::TwoAdicField;
use sp1_core::air::MachineAir;
use sp1_core::air::PublicValuesDigest;
use sp1_core::air::Word;
use sp1_core::stark::AirOpenedValues;
use sp1_core::stark::{MachineChip, StarkGenericConfig};
use sp1_recursion_compiler::ir::Felt;

use crate::commit::PolynomialSpaceVariable;

Expand All @@ -30,6 +33,7 @@ where
builder: &mut Builder<C>,
chip: &MachineChip<SC, A>,
opening: &ChipOpening<C>,
public_values_digest: PublicValuesDigest<Word<Felt<C::F>>>,
selectors: &LagrangeSelectors<Ext<C::F, C::EF>>,
alpha: Ext<C::F, C::EF>,
permutation_challenges: &[C::EF],
Expand All @@ -56,13 +60,15 @@ where
};

let zero: Ext<SC::Val, SC::Challenge> = builder.eval(SC::Val::zero());
let public_values: Vec<Felt<C::F>> = public_values_digest.into();
let mut folder = RecursiveVerifierConstraintFolder {
builder,
preprocessed: opening.preprocessed.view(),
main: opening.main.view(),
perm: perm_opening.view(),
perm_challenges: permutation_challenges,
cumulative_sum: opening.cumulative_sum,
public_values: &public_values,
is_first_row: selectors.is_first_row,
is_last_row: selectors.is_last_row,
is_transition: selectors.is_transition,
Expand Down Expand Up @@ -124,6 +130,7 @@ where
builder: &mut Builder<C>,
chip: &MachineChip<SC, A>,
opening: &ChipOpenedValuesVariable<C>,
public_values_digest: PublicValuesDigest<Word<Felt<C::F>>>,
trace_domain: TwoAdicMultiplicativeCosetVariable<C>,
qc_domains: Vec<TwoAdicMultiplicativeCosetVariable<C>>,
zeta: Ext<C::F, C::EF>,
Expand All @@ -139,6 +146,7 @@ where
builder,
chip,
&opening,
public_values_digest,
&sels,
alpha,
permutation_challenges,
Expand All @@ -156,6 +164,7 @@ mod tests {
use itertools::{izip, Itertools};
use serde::{de::DeserializeOwned, Serialize};
use sp1_core::{
air::{PublicValuesDigest, Word},
runtime::Program,
stark::{
Chip, Com, Dom, MachineStark, OpeningProof, PcsProverData, RiscvAir, ShardCommitment,
Expand Down Expand Up @@ -272,7 +281,6 @@ mod tests {
}

#[test]
#[ignore]
fn test_verify_constraints_parts() {
type SC = BabyBearPoseidon2;
type F = <SC as StarkGenericConfig>::Val;
Expand All @@ -287,21 +295,25 @@ mod tests {
let machine = A::machine(SC::default());
let (_, vk) = machine.setup(&Program::from(elf));
let mut challenger = machine.config().challenger();
let proofs = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone())
let proof = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone())
.unwrap()
.proof
.shard_proofs;
.proof;
println!("Proof generated successfully");

challenger.observe(vk.commit);
proofs.iter().for_each(|proof| {
proof.shard_proofs.iter().for_each(|proof| {
challenger.observe(proof.commitment.main_commit);
});

// Observe the public input digest
let pv_digest_field_elms: Vec<F> =
PublicValuesDigest::<Word<F>>::new(proof.public_values_digest).into();
challenger.observe_slice(&pv_digest_field_elms);

// Run the verify inside the DSL and compare it to the calculated value.
let mut builder = VmBuilder::<F, EF>::default();

for proof in proofs.into_iter().take(1) {
for proof in proof.shard_proofs.into_iter().take(1) {
let (
chips,
trace_domains_vals,
Expand All @@ -311,6 +323,12 @@ mod tests {
zeta_val,
) = get_shard_data(&machine, &proof, &mut challenger);

// Set up the public values digest.
let public_values_digest = PublicValuesDigest::from(core::array::from_fn(|i| {
let word_val = proof.public_values_digest[i];
Word(core::array::from_fn(|j| builder.eval(word_val[j])))
}));

for (chip, trace_domain_val, qc_domains_vals, values_vals) in izip!(
chips.iter(),
trace_domains_vals,
Expand Down Expand Up @@ -341,6 +359,7 @@ mod tests {
&mut builder,
chip,
&values,
public_values_digest,
&sels,
alpha,
permutation_challenges.as_slice(),
Expand Down Expand Up @@ -385,7 +404,6 @@ mod tests {
}

#[test]
#[ignore]
fn test_verify_constraints_whole() {
type SC = BabyBearPoseidon2;
type F = <SC as StarkGenericConfig>::Val;
Expand All @@ -404,19 +422,24 @@ mod tests {
SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()).unwrap();
SP1Verifier::verify_with_config(elf, &proof, machine.config().clone()).unwrap();

let proofs = proof.proof.shard_proofs;
let proof = proof.proof;
println!("Proof generated and verified successfully");

challenger.observe(vk.commit);

proofs.iter().for_each(|proof| {
proof.shard_proofs.iter().for_each(|proof| {
challenger.observe(proof.commitment.main_commit);
});

// Observe the public input digest
let pv_digest_field_elms: Vec<F> =
PublicValuesDigest::<Word<F>>::new(proof.public_values_digest).into();
challenger.observe_slice(&pv_digest_field_elms);

// Run the verify inside the DSL and compare it to the calculated value.
let mut builder = VmBuilder::<F, EF>::default();

for proof in proofs.into_iter().take(1) {
for proof in proof.shard_proofs.into_iter().take(1) {
let (
chips,
trace_domains_vals,
Expand All @@ -441,10 +464,17 @@ mod tests {
.map(|domain| builder.eval_const(*domain))
.collect::<Vec<_>>();

// Set up the public values digest.
let public_values_digest = PublicValuesDigest::from(core::array::from_fn(|i| {
let word_val = proof.public_values_digest[i];
Word(core::array::from_fn(|j| builder.eval(word_val[j])))
}));

StarkVerifier::<_, SC>::verify_constraints::<A>(
&mut builder,
chip,
&opening,
public_values_digest,
trace_domain,
qc_domains,
zeta,
Expand Down
7 changes: 4 additions & 3 deletions recursion/program/src/folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use p3_air::{
use sp1_core::air::{EmptyMessageBuilder, MultiTableAirBuilder, PublicValuesBuilder};

use sp1_recursion_compiler::{
ir::{Builder, Config, Ext},
ir::{Builder, Config, Ext, Felt},
prelude::SymbolicExt,
};

Expand All @@ -15,6 +15,7 @@ pub struct RecursiveVerifierConstraintFolder<'a, C: Config> {
pub main: TwoRowMatrixView<'a, Ext<C::F, C::EF>>,
pub perm: TwoRowMatrixView<'a, Ext<C::F, C::EF>>,
pub perm_challenges: &'a [C::EF],
pub public_values: &'a [Felt<C::F>],
pub cumulative_sum: Ext<C::F, C::EF>,
pub is_first_row: Ext<C::F, C::EF>,
pub is_last_row: Ext<C::F, C::EF>,
Expand Down Expand Up @@ -101,9 +102,9 @@ impl<'a, C: Config> EmptyMessageBuilder for RecursiveVerifierConstraintFolder<'a
impl<'a, C: Config> PublicValuesBuilder for RecursiveVerifierConstraintFolder<'a, C> {}

impl<'a, C: Config> AirBuilderWithPublicValues for RecursiveVerifierConstraintFolder<'a, C> {
type PublicVar = C::F;
type PublicVar = Felt<C::F>;

fn public_values(&self) -> &[Self::PublicVar] {
&[]
self.public_values
}
}
Loading

0 comments on commit 813e2d7

Please sign in to comment.