From c16cb9d84e732b036bad5c5244bdc9147071b598 Mon Sep 17 00:00:00 2001 From: John Guibas Date: Mon, 6 May 2024 15:05:31 -0700 Subject: [PATCH] feat: nextgen ci for sp1-prover (#663) --- .github/workflows/tests.yml | 31 +++++++- core/src/stark/machine.rs | 30 +++---- core/src/utils/prove.rs | 8 +- prover/Cargo.toml | 4 - prover/scripts/e2e.rs | 105 ------------------------ prover/src/build.rs | 4 +- prover/src/lib.rs | 127 ++++++++++++++++-------------- prover/src/verify.rs | 125 ++++++++++++++++++++++++----- recursion/core/src/stark/utils.rs | 6 +- sdk/src/lib.rs | 38 +++++---- sdk/src/provers/local.rs | 2 +- sdk/src/provers/mock.rs | 10 ++- sdk/src/provers/mod.rs | 14 +++- 13 files changed, 267 insertions(+), 237 deletions(-) delete mode 100644 prover/scripts/e2e.rs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 350dcb3a45..d88183dee9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,7 +19,7 @@ concurrency: jobs: test: - name: Core + name: SP1 Prover E2E runs-on: warp-ubuntu-latest-arm64-32x env: CARGO_NET_GIT_FETCH_WITH_CLI: "true" @@ -36,13 +36,40 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: -p sp1-core -p sp1-recursion-compiler -p sp1-recursion-program -p sp1-recursion-circuit -p sp1-sdk --release + args: -p sp1-prover --release env: RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0 -C target-cpu=native RUST_LOG: 1 RUST_BACKTRACE: 1 CARGO_INCREMENTAL: 1 FRI_QUERIES: 1 + SP1_DEV_WRAPPER: false + + test-main: + name: SP1 Tests + runs-on: warp-ubuntu-latest-arm64-32x + if: github.ref == 'refs/heads/main' + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Setup CI + uses: ./.github/actions/setup + with: + pull_token: ${{ secrets.PULL_TOKEN }} + + - name: Run cargo test + uses: actions-rs/cargo@v1 + with: + command: test + args: --release + env: + RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0 -C target-cpu=native + RUST_LOG: 1 + RUST_BACKTRACE: 1 + CARGO_INCREMENTAL: 1 + FRI_QUERIES: 1 + SP1_DEV_WRAPPER: false lints: name: Formatting & Clippy diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index 84ce7b6d6c..7646ae5641 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -296,7 +296,7 @@ impl>> StarkMachine { vk: &StarkVerifyingKey, proof: &MachineProof, challenger: &mut SC::Challenger, - ) -> Result<(), ProgramVerificationError> + ) -> Result<(), MachineVerificationError> where SC::Challenger: Clone, A: for<'a> Air>, @@ -312,7 +312,7 @@ impl>> StarkMachine { // Verify the shard proofs. if proof.shard_proofs.is_empty() { - return Err(ProgramVerificationError::EmptyProof); + return Err(MachineVerificationError::EmptyProof); } tracing::debug_span!("verify shard proofs").in_scope(|| { @@ -328,7 +328,7 @@ impl>> StarkMachine { &mut challenger.clone(), shard_proof, ) - .map_err(ProgramVerificationError::InvalidSegmentProof) + .map_err(MachineVerificationError::InvalidSegmentProof) })?; } @@ -343,7 +343,7 @@ impl>> StarkMachine { } match sum.is_zero() { true => Ok(()), - false => Err(ProgramVerificationError::NonZeroCumulativeSum), + false => Err(MachineVerificationError::NonZeroCumulativeSum), } }) } @@ -461,7 +461,7 @@ impl>> StarkMachine { } } -pub enum ProgramVerificationError { +pub enum MachineVerificationError { InvalidSegmentProof(VerificationError), InvalidGlobalProof(VerificationError), NonZeroCumulativeSum, @@ -471,41 +471,41 @@ pub enum ProgramVerificationError { InvalidPublicValues(&'static str), } -impl Debug for ProgramVerificationError { +impl Debug for MachineVerificationError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ProgramVerificationError::InvalidSegmentProof(e) => { + MachineVerificationError::InvalidSegmentProof(e) => { write!(f, "Invalid segment proof: {:?}", e) } - ProgramVerificationError::InvalidGlobalProof(e) => { + MachineVerificationError::InvalidGlobalProof(e) => { write!(f, "Invalid global proof: {:?}", e) } - ProgramVerificationError::NonZeroCumulativeSum => { + MachineVerificationError::NonZeroCumulativeSum => { write!(f, "Non-zero cumulative sum") } - ProgramVerificationError::InvalidPublicValuesDigest => { + MachineVerificationError::InvalidPublicValuesDigest => { write!(f, "Invalid public values digest") } - ProgramVerificationError::EmptyProof => { + MachineVerificationError::EmptyProof => { write!(f, "Empty proof") } - ProgramVerificationError::DebugInteractionsFailed => { + MachineVerificationError::DebugInteractionsFailed => { write!(f, "Debug interactions failed") } - ProgramVerificationError::InvalidPublicValues(s) => { + MachineVerificationError::InvalidPublicValues(s) => { write!(f, "Invalid public values: {}", s) } } } } -impl std::fmt::Display for ProgramVerificationError { +impl std::fmt::Display for MachineVerificationError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(self, f) } } -impl std::error::Error for ProgramVerificationError {} +impl std::error::Error for MachineVerificationError {} #[cfg(test)] #[allow(non_snake_case)] diff --git a/core/src/utils/prove.rs b/core/src/utils/prove.rs index 7cd796fcbb..515672a646 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -33,7 +33,7 @@ const LOG_DEGREE_BOUND: usize = 31; pub fn run_test_io( program: Program, inputs: SP1Stdin, -) -> Result> { +) -> Result> { let runtime = tracing::info_span!("runtime.run(...)").in_scope(|| { let mut runtime = Runtime::new(program); runtime.write_vecs(&inputs.buffer); @@ -49,7 +49,7 @@ pub fn run_test( program: Program, ) -> Result< crate::stark::MachineProof, - crate::stark::ProgramVerificationError, + crate::stark::MachineVerificationError, > { let runtime = tracing::info_span!("runtime.run(...)").in_scope(|| { let mut runtime = Runtime::new(program); @@ -64,7 +64,7 @@ pub fn run_test_core( runtime: Runtime, ) -> Result< crate::stark::MachineProof, - crate::stark::ProgramVerificationError, + crate::stark::MachineVerificationError, > { let config = BabyBearPoseidon2::new(); let machine = RiscvAir::machine(config); @@ -80,7 +80,7 @@ pub fn run_test_machine( machine: StarkMachine, pk: StarkProvingKey, vk: StarkVerifyingKey, -) -> Result, crate::stark::ProgramVerificationError> +) -> Result, crate::stark::MachineVerificationError> where A: MachineAir + for<'a> Air> diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 66b7a9a5d7..61ab087ab8 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -52,10 +52,6 @@ path = "scripts/tendermint_sweep.rs" name = "fibonacci_groth16" path = "scripts/fibonacci_groth16.rs" -[[bin]] -name = "e2e" -path = "scripts/e2e.rs" - [[bin]] name = "test_groth16_verification" path = "scripts/test_groth16_verification.rs" diff --git a/prover/scripts/e2e.rs b/prover/scripts/e2e.rs deleted file mode 100644 index 397ceb1d82..0000000000 --- a/prover/scripts/e2e.rs +++ /dev/null @@ -1,105 +0,0 @@ -#![feature(generic_const_exprs)] -#![allow(incomplete_features)] - -use clap::Parser; -use p3_baby_bear::BabyBear; -use sp1_core::io::SP1Stdin; -use sp1_prover::utils::{babybear_bytes_to_bn254, babybears_to_bn254, words_to_bytes}; -use sp1_prover::SP1Prover; -use sp1_recursion_circuit::stark::build_wrap_circuit; -use sp1_recursion_circuit::witness::Witnessable; -use sp1_recursion_compiler::ir::Witness; -use sp1_recursion_core::air::RecursionPublicValues; -use sp1_recursion_gnark_ffi::{convert, verify, Groth16Prover}; -use subtle_encoding::hex; - -#[derive(Parser, Debug)] -#[clap(author, version, about, long_about = None)] -struct Args { - #[clap(short, long)] - build_dir: String, -} - -pub fn main() { - sp1_core::utils::setup_logger(); - std::env::set_var("RECONSTRUCT_COMMITMENTS", "false"); - - let args = Args::parse(); - - let elf = include_bytes!("../../tests/fibonacci/elf/riscv32im-succinct-zkvm-elf"); - - tracing::info!("initializing prover"); - let prover = SP1Prover::new(); - - tracing::info!("setup elf"); - let (pk, vk) = prover.setup(elf); - - tracing::info!("prove core"); - let stdin = SP1Stdin::new(); - let core_proof = prover.prove_core(&pk, &stdin); - - tracing::info!("reduce"); - let reduced_proof = prover.compress(&vk, core_proof, vec![]); - - tracing::info!("compress"); - let compressed_proof = prover.shrink(&vk, reduced_proof); - - tracing::info!("wrap"); - let wrapped_proof = prover.wrap_bn254(&vk, compressed_proof); - - tracing::info!("building verifier constraints"); - let constraints = tracing::info_span!("wrap circuit") - .in_scope(|| build_wrap_circuit(&prover.wrap_vk, wrapped_proof.clone())); - - tracing::info!("building template witness"); - let pv = RecursionPublicValues::from_vec(wrapped_proof.public_values.clone()); - let vkey_hash = babybears_to_bn254(&pv.sp1_vk_digest); - let committed_values_digest_bytes: [BabyBear; 32] = words_to_bytes(&pv.committed_value_digest) - .try_into() - .unwrap(); - let committed_values_digest = babybear_bytes_to_bn254(&committed_values_digest_bytes); - - let mut witness = Witness::default(); - wrapped_proof.write(&mut witness); - witness.write_commited_values_digest(committed_values_digest); - witness.write_vkey_hash(vkey_hash); - - tracing::info!("sanity check gnark test"); - Groth16Prover::test(constraints.clone(), witness.clone()); - - tracing::info!("sanity check gnark build"); - Groth16Prover::build( - constraints.clone(), - witness.clone(), - args.build_dir.clone().into(), - ); - - tracing::info!("sanity check gnark prove"); - let groth16_prover = Groth16Prover::new(args.build_dir.clone().into()); - - tracing::info!("gnark prove"); - let proof = groth16_prover.prove(witness.clone()); - - tracing::info!("verify gnark proof"); - let verified = verify(proof.clone(), &args.build_dir.clone().into()); - assert!(verified); - - tracing::info!("convert gnark proof"); - let solidity_proof = convert(proof.clone(), &args.build_dir.clone().into()); - - // tracing::info!("sanity check plonk bn254 build"); - // PlonkBn254Prover::build( - // constraints.clone(), - // witness.clone(), - // args.build_dir.clone().into(), - // ); - - // tracing::info!("sanity check plonk bn254 prove"); - // let proof = PlonkBn254Prover::prove(witness.clone(), args.build_dir.clone().into()); - - println!( - "{:?}", - String::from_utf8(hex::encode(proof.encoded_proof)).unwrap() - ); - println!("solidity proof: {:?}", solidity_proof); -} diff --git a/prover/src/build.rs b/prover/src/build.rs index 25975f9f21..5c274ddcce 100644 --- a/prover/src/build.rs +++ b/prover/src/build.rs @@ -50,11 +50,11 @@ fn dummy_proof() -> (StarkVerifyingKey, ShardProof) { tracing::info!("wrap"); let wrapped_proof = prover.wrap_bn254(&vk, shrink_proof); - (prover.wrap_vk, wrapped_proof) + (prover.wrap_vk, wrapped_proof.proof) } /// Build the verifier constraints and template witness for the circuit. -fn build_constraints( +pub fn build_constraints( wrap_vk: &StarkVerifyingKey, wrapped_proof: &ShardProof, ) -> (Vec, Witness) { diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 84b2876340..aa6820790e 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -136,7 +136,7 @@ impl SP1Prover { let core_machine = RiscvAir::machine(CoreSC::default()); let compress_machine = RecursionAirWideDeg3::machine(InnerSC::default()); let shrink_machine = RecursionAirSkinnyDeg7::machine(InnerSC::compressed()); - let wrap_machine = RecursionAirSkinnyDeg7::machine(OuterSC::default()); + let wrap_machine = RecursionAirSkinnyDeg7::wrap_machine(OuterSC::default()); Self { recursion_setup_program, recursion_program, @@ -650,7 +650,7 @@ impl SP1Prover { &self, vk: &SP1VerifyingKey, reduced_proof: SP1ReduceProof, - ) -> ShardProof { + ) -> SP1ReduceProof { // Get verify_start_challenger from the reduce proof's public values. let pv = RecursionPublicValues::from_vec(reduced_proof.proof.public_values.clone()); let mut core_challenger = self.core_machine.config().challenger(); @@ -674,12 +674,12 @@ impl SP1Prover { true, true, ) - .proof } /// Wrap the STARK proven over a SNARK-friendly field into a Groth16 proof. #[instrument(name = "wrap_groth16", level = "info", skip_all)] - pub fn wrap_groth16(&self, proof: ShardProof, build_dir: PathBuf) -> Groth16Proof { + pub fn wrap_groth16(&self, proof: SP1ReduceProof, build_dir: PathBuf) -> Groth16Proof { + let proof = &proof.proof; let pv = RecursionPublicValues::from_vec(proof.public_values.clone()); // Convert pv.vkey_digest to a bn254 field element @@ -734,18 +734,19 @@ impl SP1Prover { mod tests { use super::*; + use crate::build::build_constraints; use p3_field::PrimeField32; - use sp1_core::air::{PublicValues, Word}; use sp1_core::io::SP1Stdin; + use sp1_core::stark::MachineVerificationError; use sp1_core::utils::setup_logger; + /// Tests an end-to-end workflow of proving a program across the entire proof generation + /// pipeline. + /// + /// TODO: Remove the fact that we ignore [MachineVerificationError::NonZeroCumulativeSum]. #[test] - #[ignore] - fn test_prove_sp1() { + fn test_e2e() { setup_logger(); - std::env::set_var("RECONSTRUCT_COMMITMENTS", "false"); - - // Generate SP1 proof let elf = include_bytes!("../../tests/fibonacci/elf/riscv32im-succinct-zkvm-elf"); tracing::info!("initializing prover"); @@ -764,49 +765,73 @@ mod tests { tracing::info!("compress"); let compressed_proof = prover.compress(&vk, core_proof, vec![]); + tracing::info!("verify compressed"); + let result = prover.verify_compressed(&compressed_proof, &vk); + if let Err(MachineVerificationError::NonZeroCumulativeSum) = result { + tracing::warn!("non-zero cumulative sum for compress"); + } else { + result.unwrap(); + } + + tracing::info!("shrink"); + let shrink_proof = prover.shrink(&vk, compressed_proof); + + tracing::info!("verify shrink"); + let result = prover.verify_shrink(&shrink_proof, &vk); + if let Err(MachineVerificationError::NonZeroCumulativeSum) = result { + tracing::warn!("non-zero cumulative sum for shrink"); + } else { + result.unwrap(); + } + tracing::info!("wrap bn254"); - let wrapped_bn254_proof = prover.wrap_bn254(&vk, compressed_proof); + let wrapped_bn254_proof = prover.wrap_bn254(&vk, shrink_proof); + + tracing::info!("verify wrap bn254"); + let result = prover.verify_wrap_bn254(&wrapped_bn254_proof, &vk); + if let Err(MachineVerificationError::NonZeroCumulativeSum) = result { + tracing::warn!("non-zero cumulative sum for wrap bn254"); + } else { + result.unwrap(); + } - tracing::info!("groth16"); - prover.wrap_groth16(wrapped_bn254_proof, PathBuf::from("build")); + // TODO: replace this with real groth16 proof generation. + tracing::info!("generate groth16 proof"); + let (constraints, witness) = build_constraints(&prover.wrap_vk, &wrapped_bn254_proof.proof); + Groth16Prover::test(constraints.clone(), witness.clone()); } - /// This test ensures that a proof can be deferred in the core vm and verified in recursion. + /// Tests an end-to-end workflow of proving a program across the entire proof generation + /// pipeline in addition to verifying deferred proofs. + /// + /// TODO: Remove the fact that we ignore [MachineVerificationError::NonZeroCumulativeSum]. #[test] - #[ignore] - fn test_deferred_verify() { + fn test_e2e_with_deferred_proofs() { setup_logger(); - std::env::set_var("RECONSTRUCT_COMMITMENTS", "false"); - std::env::set_var("SHARD_SIZE", "262144"); - std::env::set_var("MAX_RECURSION_PROGRAM_SIZE", "1"); - // keccak program which proves keccak of various inputs + // Test program which proves the Keccak-256 hash of various inputs. let keccak_elf = include_bytes!("../../tests/keccak256/elf/riscv32im-succinct-zkvm-elf"); - // verify program which verifies proofs of a vkey and a list of committed inputs + + // Test program which verifies proofs of a vkey and a list of committed inputs. let verify_elf = include_bytes!("../../tests/verify-proof/elf/riscv32im-succinct-zkvm-elf"); tracing::info!("initializing prover"); let prover = SP1Prover::new(); - tracing::info!("setup elf"); + tracing::info!("setup keccak elf"); let (keccak_pk, keccak_vk) = prover.setup(keccak_elf); + + tracing::info!("setup verify elf"); let (verify_pk, verify_vk) = prover.setup(verify_elf); - // Prove keccak of various inputs tracing::info!("prove subproof 1"); let mut stdin = SP1Stdin::new(); stdin.write(&1usize); stdin.write(&vec![0u8, 0, 0]); let deferred_proof_1 = prover.prove_core(&keccak_pk, &stdin); let pv_1 = deferred_proof_1.public_values.as_slice().to_vec().clone(); - println!("proof 1 pv: {:?}", hex::encode(pv_1.clone())); - let pv_digest_1 = deferred_proof_1.proof.0[0].public_values[..32] - .iter() - .map(|x| x.as_canonical_u32() as u8) - .collect::>(); - println!("proof 1 pv_digest: {:?}", hex::encode(pv_digest_1.clone())); - // Generate a second proof of keccak of various inputs + // Generate a second proof of keccak of various inputs. tracing::info!("prove subproof 2"); let mut stdin = SP1Stdin::new(); stdin.write(&3usize); @@ -815,22 +840,16 @@ mod tests { stdin.write(&vec![5, 6, 7]); let deferred_proof_2 = prover.prove_core(&keccak_pk, &stdin); let pv_2 = deferred_proof_2.public_values.as_slice().to_vec().clone(); - println!("proof 2 pv: {:?}", hex::encode(pv_2.clone())); - let pv_digest_2 = deferred_proof_2.proof.0[0].public_values[..32] - .iter() - .map(|x| x.as_canonical_u32() as u8) - .collect::>(); - println!("proof 2 pv_digest: {:?}", hex::encode(pv_digest_2.clone())); - // Generate recursive proof of first subproof - println!("reduce subproof 1"); + // Generate recursive proof of first subproof. + tracing::info!("compress subproof 1"); let deferred_reduce_1 = prover.compress(&keccak_vk, deferred_proof_1, vec![]); - // Generate recursive proof of second subproof - println!("reduce subproof 2"); + // Generate recursive proof of second subproof. + tracing::info!("compress subproof 2"); let deferred_reduce_2 = prover.compress(&keccak_vk, deferred_proof_2, vec![]); - // Run verify program with keccak vkey, subproofs, and their committed values + // Run verify program with keccak vkey, subproofs, and their committed values. let mut stdin = SP1Stdin::new(); let vkey_digest = keccak_vk.hash(); let vkey_digest: [u32; 8] = vkey_digest @@ -845,17 +864,11 @@ mod tests { stdin.write_proof(deferred_reduce_2.proof.clone(), keccak_vk.vk.clone()); stdin.write_proof(deferred_reduce_2.proof.clone(), keccak_vk.vk.clone()); - // Prove verify program - println!("proving verify program (core)"); + tracing::info!("proving verify program (core)"); let verify_proof = prover.prove_core(&verify_pk, &stdin); - let pv = PublicValues::, BabyBear>::from_vec( - verify_proof.proof.0[0].public_values.clone(), - ); - - println!("deferred_hash: {:?}", pv.deferred_proofs_digest); // Generate recursive proof of verify program - println!("proving verify program (recursion)"); + tracing::info!("compress verify program"); let verify_reduce = prover.compress( &verify_vk, verify_proof, @@ -865,15 +878,13 @@ mod tests { deferred_reduce_2.proof, ], ); - let reduce_pv = RecursionPublicValues::from_vec(verify_reduce.proof.public_values.clone()); - println!("deferred_hash: {:?}", reduce_pv.deferred_proofs_digest); - println!("complete: {:?}", reduce_pv.is_complete); - - let reduced_proof = SP1ReducedProofData(verify_reduce.proof); - prover.verify_reduced(&reduced_proof, &verify_vk).unwrap(); - std::env::remove_var("RECONSTRUCT_COMMITMENTS"); - std::env::remove_var("SHARD_SIZE"); - std::env::remove_var("MAX_RECURSION_PROGRAM_SIZE"); + tracing::info!("verify verify program"); + let result = prover.verify_compressed(&verify_reduce, &verify_vk); + if let Err(MachineVerificationError::NonZeroCumulativeSum) = result { + tracing::warn!("non-zero cumulative sum for verify"); + } else { + result.unwrap(); + } } } diff --git a/prover/src/verify.rs b/prover/src/verify.rs index 64c895fc9d..2cced6a87a 100644 --- a/prover/src/verify.rs +++ b/prover/src/verify.rs @@ -3,12 +3,13 @@ use p3_baby_bear::BabyBear; use p3_field::AbstractField; use sp1_core::{ air::PublicValues, - stark::{MachineProof, ProgramVerificationError, StarkGenericConfig}, + stark::{MachineProof, MachineVerificationError, StarkGenericConfig}, + utils::BabyBearPoseidon2, }; -use sp1_recursion_core::air::RecursionPublicValues; +use sp1_recursion_core::{air::RecursionPublicValues, stark::config::BabyBearPoseidon2Outer}; use crate::{ - CoreSC, HashableKey, SP1CoreProofData, SP1Prover, SP1ReducedProofData, SP1VerifyingKey, + CoreSC, HashableKey, OuterSC, SP1CoreProofData, SP1Prover, SP1ReduceProof, SP1VerifyingKey, }; impl SP1Prover { @@ -18,7 +19,7 @@ impl SP1Prover { &self, proof: &SP1CoreProofData, vk: &SP1VerifyingKey, - ) -> Result<(), ProgramVerificationError> { + ) -> Result<(), MachineVerificationError> { let mut challenger = self.core_machine.config().challenger(); let machine_proof = MachineProof { shard_proofs: proof.0.to_vec(), @@ -33,12 +34,12 @@ impl SP1Prover { if i == 0 { // If it's the first shard, index should be 1. if public_values.shard != BabyBear::one() { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( "first shard not 1", )); } if public_values.start_pc != vk.vk.pc_start { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( "wrong pc_start", )); } @@ -48,13 +49,13 @@ impl SP1Prover { PublicValues::from_vec(prev_shard_proof.public_values.clone()); // For non-first shards, the index should be the previous index + 1. if public_values.shard != prev_public_values.shard + BabyBear::one() { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( "non incremental shard index", )); } // Start pc should be what the next pc declared in the previous shard was. if public_values.start_pc != prev_public_values.next_pc { - return Err(ProgramVerificationError::InvalidPublicValues("pc mismatch")); + return Err(MachineVerificationError::InvalidPublicValues("pc mismatch")); } // Digests and exit code should be the same in all shards. if public_values.committed_value_digest != prev_public_values.committed_value_digest @@ -62,19 +63,19 @@ impl SP1Prover { != prev_public_values.deferred_proofs_digest || public_values.exit_code != prev_public_values.exit_code { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( "digest or exit code mismatch", )); } // The last shard should be halted. Halt is signaled with next_pc == 0. if i == proof.0.len() - 1 && public_values.next_pc != BabyBear::zero() { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( "last shard isn't halted", )); } // All non-last shards should not be halted. if i != proof.0.len() - 1 && public_values.next_pc == BabyBear::zero() { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( "non-last shard is halted", )); } @@ -84,25 +85,25 @@ impl SP1Prover { Ok(()) } - /// Verify a reduced proof. - pub fn verify_reduced( + /// Verify a compressed proof. + pub fn verify_compressed( &self, - proof: &SP1ReducedProofData, + proof: &SP1ReduceProof, vk: &SP1VerifyingKey, - ) -> Result<(), ProgramVerificationError> { + ) -> Result<(), MachineVerificationError> { let mut challenger = self.compress_machine.config().challenger(); let machine_proof = MachineProof { - shard_proofs: vec![proof.0.clone()], + shard_proofs: vec![proof.proof.clone()], }; self.compress_machine .verify(&self.compress_vk, &machine_proof, &mut challenger)?; // Validate public values - let public_values = RecursionPublicValues::from_vec(proof.0.public_values.clone()); + let public_values = RecursionPublicValues::from_vec(proof.proof.public_values.clone()); // `is_complete` should be 1. In the reduce program, this ensures that the proof is fully reduced. if public_values.is_complete != BabyBear::one() { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( "is_complete is not 1", )); } @@ -110,7 +111,7 @@ impl SP1Prover { // Verify that the proof is for the sp1 vkey we are expecting. let vkey_hash = vk.hash(); if public_values.sp1_vk_digest != vkey_hash { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( "sp1 vk hash mismatch", )); } @@ -118,7 +119,91 @@ impl SP1Prover { // Verify that the reduce program is the one we are expecting. let recursion_vkey_hash = self.compress_vk.hash(); if public_values.recursion_vk_digest != recursion_vkey_hash { - return Err(ProgramVerificationError::InvalidPublicValues( + return Err(MachineVerificationError::InvalidPublicValues( + "recursion vk hash mismatch", + )); + } + + Ok(()) + } + + /// Verify a shrink proof. + pub fn verify_shrink( + &self, + proof: &SP1ReduceProof, + vk: &SP1VerifyingKey, + ) -> Result<(), MachineVerificationError> { + let mut challenger = self.shrink_machine.config().challenger(); + let machine_proof = MachineProof { + shard_proofs: vec![proof.proof.clone()], + }; + self.shrink_machine + .verify(&self.shrink_vk, &machine_proof, &mut challenger)?; + + // Validate public values + let public_values = RecursionPublicValues::from_vec(proof.proof.public_values.clone()); + + // `is_complete` should be 1. In the reduce program, this ensures that the proof is fully reduced. + if public_values.is_complete != BabyBear::one() { + return Err(MachineVerificationError::InvalidPublicValues( + "is_complete is not 1", + )); + } + + // Verify that the proof is for the sp1 vkey we are expecting. + let vkey_hash = vk.hash(); + if public_values.sp1_vk_digest != vkey_hash { + return Err(MachineVerificationError::InvalidPublicValues( + "sp1 vk hash mismatch", + )); + } + + // Verify that the reduce program is the one we are expecting. + let recursion_vkey_hash = self.shrink_vk.hash(); + if public_values.recursion_vk_digest != recursion_vkey_hash { + return Err(MachineVerificationError::InvalidPublicValues( + "recursion vk hash mismatch", + )); + } + + Ok(()) + } + + /// Verify a wrap bn254 proof. + pub fn verify_wrap_bn254( + &self, + proof: &SP1ReduceProof, + vk: &SP1VerifyingKey, + ) -> Result<(), MachineVerificationError> { + let mut challenger = self.wrap_machine.config().challenger(); + let machine_proof = MachineProof { + shard_proofs: vec![proof.proof.clone()], + }; + self.wrap_machine + .verify(&self.wrap_vk, &machine_proof, &mut challenger)?; + + // Validate public values + let public_values = RecursionPublicValues::from_vec(proof.proof.public_values.clone()); + + // `is_complete` should be 1. In the reduce program, this ensures that the proof is fully reduced. + if public_values.is_complete != BabyBear::one() { + return Err(MachineVerificationError::InvalidPublicValues( + "is_complete is not 1", + )); + } + + // Verify that the proof is for the sp1 vkey we are expecting. + let vkey_hash = vk.hash(); + if public_values.sp1_vk_digest != vkey_hash { + return Err(MachineVerificationError::InvalidPublicValues( + "sp1 vk hash mismatch", + )); + } + + // Verify that the reduce program is the one we are expecting. + let recursion_vkey_hash = self.shrink_vk.hash(); + if public_values.recursion_vk_digest != recursion_vkey_hash { + return Err(MachineVerificationError::InvalidPublicValues( "recursion vk hash mismatch", )); } diff --git a/recursion/core/src/stark/utils.rs b/recursion/core/src/stark/utils.rs index bd3055cc56..6881ae4d24 100644 --- a/recursion/core/src/stark/utils.rs +++ b/recursion/core/src/stark/utils.rs @@ -9,7 +9,7 @@ use crate::runtime::RecursionProgram; use crate::runtime::Runtime; use crate::stark::RecursionAirSkinnyDeg7; use p3_field::PrimeField32; -use sp1_core::stark::ProgramVerificationError; +use sp1_core::stark::MachineVerificationError; use sp1_core::utils::run_test_machine; use std::collections::VecDeque; @@ -49,7 +49,7 @@ pub fn run_test_recursion( let record = runtime.record.clone(); let result = run_test_machine(record, machine, pk, vk); if let Err(e) = result { - if let ProgramVerificationError::::NonZeroCumulativeSum = e { + if let MachineVerificationError::::NonZeroCumulativeSum = e { // For now we ignore this error, as the cumulative sum checking is expected to fail. } else { panic!("Verification failed: {:?}", e); @@ -63,7 +63,7 @@ pub fn run_test_recursion( let record = runtime.record.clone(); let result = run_test_machine(record, machine, pk, vk); if let Err(e) = result { - if let ProgramVerificationError::::NonZeroCumulativeSum = e { + if let MachineVerificationError::::NonZeroCumulativeSum = e { // For now we ignore this error, as the cumulative sum checking is expected to fail. } else { panic!("Verification failed: {:?}", e); diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index dffc74e606..6c11436775 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -24,7 +24,7 @@ use std::{env, fmt::Debug, fs::File, path::Path}; use anyhow::{Ok, Result}; pub use provers::{LocalProver, MockProver, NetworkProver, Prover}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use sp1_core::stark::ShardProof; +use sp1_core::stark::{MachineVerificationError, ShardProof}; pub use sp1_prover::{ CoreSC, Groth16Proof, InnerSC, PlonkBn254Proof, SP1CoreProof, SP1Prover, SP1ProvingKey, SP1PublicValues, SP1Stdin, SP1VerifyingKey, @@ -48,9 +48,11 @@ pub struct SP1ProofWithPublicValues

{ /// A [SP1ProofWithPublicValues] generated with [ProverClient::prove]. pub type SP1Proof = SP1ProofWithPublicValues>>; +pub type SP1ProofVerificationError = MachineVerificationError; /// A [SP1ProofWithPublicValues] generated with [ProverClient::prove_compressed]. pub type SP1CompressedProof = SP1ProofWithPublicValues>; +pub type SP1CompressedProofVerificationError = MachineVerificationError; /// A [SP1ProofWithPublicValues] generated with [ProverClient::prove_groth16]. pub type SP1Groth16Proof = SP1ProofWithPublicValues; @@ -68,7 +70,7 @@ impl ProverClient { /// /// ### Examples /// - /// ``` + /// ```no_run /// use sp1_sdk::ProverClient; /// /// std::env::set_var("SP1_PROVER", "local"); @@ -102,7 +104,7 @@ impl ProverClient { /// /// ### Examples /// - /// ``` + /// ```no_run /// use sp1_sdk::ProverClient; /// /// let client = ProverClient::mock(); @@ -120,7 +122,7 @@ impl ProverClient { /// /// ### Examples /// - /// ``` + /// ```no_run /// use sp1_sdk::ProverClient; /// /// let client = ProverClient::local(); @@ -137,7 +139,7 @@ impl ProverClient { /// /// ### Examples /// - /// ``` + /// ```no_run /// use sp1_sdk::ProverClient; /// /// let client = ProverClient::remote(); @@ -154,7 +156,7 @@ impl ProverClient { /// /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// // Load the program. @@ -181,7 +183,7 @@ impl ProverClient { /// data (such as lookup tables) that are used to prove the program's correctness. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// let elf = include_bytes!("../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf"); @@ -201,7 +203,7 @@ impl ProverClient { /// [Self::prove_groth16], or [Self::prove_plonk] methods. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// // Load the program. @@ -230,7 +232,7 @@ impl ProverClient { /// proof that is of constant size and friendly for recursion and off-chain verification. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// // Load the program. @@ -263,7 +265,7 @@ impl ProverClient { /// proof that is of constant size and friendly for on-chain verification. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// // Load the program. @@ -293,7 +295,7 @@ impl ProverClient { /// proof that is of constant size and friendly for on-chain verification. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// // Load the program. @@ -320,7 +322,7 @@ impl ProverClient { /// [Self::setup]. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// let elf = include_bytes!("../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf"); @@ -331,7 +333,11 @@ impl ProverClient { /// let proof = client.prove(&pk, stdin).unwrap(); /// client.verify(&proof, &vk).unwrap(); /// ``` - pub fn verify(&self, proof: &SP1Proof, vkey: &SP1VerifyingKey) -> Result<()> { + pub fn verify( + &self, + proof: &SP1Proof, + vkey: &SP1VerifyingKey, + ) -> Result<(), SP1ProofVerificationError> { self.prover.verify(proof, vkey) } @@ -339,7 +345,7 @@ impl ProverClient { /// produced by [Self::setup]. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// // Load the program. @@ -371,7 +377,7 @@ impl ProverClient { /// produced by [Self::setup]. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// // Load the program. @@ -401,7 +407,7 @@ impl ProverClient { /// produced by [Self::setup]. /// /// ### Examples - /// ``` + /// ```no_run /// use sp1_sdk::{ProverClient, SP1Stdin}; /// /// // Load the program. diff --git a/sdk/src/provers/local.rs b/sdk/src/provers/local.rs index 5e5676981f..c965be8598 100644 --- a/sdk/src/provers/local.rs +++ b/sdk/src/provers/local.rs @@ -80,7 +80,7 @@ impl Prover for LocalProver { } sp1_prover::build::groth16_artifacts( &self.prover.wrap_vk, - &outer_proof, + &outer_proof.proof, build_dir.clone(), ); build_dir diff --git a/sdk/src/provers/mock.rs b/sdk/src/provers/mock.rs index ffea8ae55a..addb91b257 100644 --- a/sdk/src/provers/mock.rs +++ b/sdk/src/provers/mock.rs @@ -1,7 +1,7 @@ #![allow(unused_variables)] use crate::{ - Prover, SP1CompressedProof, SP1Groth16Proof, SP1PlonkProof, SP1Proof, SP1ProofWithPublicValues, - SP1ProvingKey, SP1VerifyingKey, + Prover, SP1CompressedProof, SP1Groth16Proof, SP1PlonkProof, SP1Proof, + SP1ProofVerificationError, SP1ProofWithPublicValues, SP1ProvingKey, SP1VerifyingKey, }; use anyhow::Result; use sp1_prover::{SP1Prover, SP1Stdin}; @@ -57,7 +57,11 @@ impl Prover for MockProver { todo!() } - fn verify(&self, _proof: &SP1Proof, _vkey: &SP1VerifyingKey) -> Result<()> { + fn verify( + &self, + _proof: &SP1Proof, + _vkey: &SP1VerifyingKey, + ) -> Result<(), SP1ProofVerificationError> { Ok(()) } diff --git a/sdk/src/provers/mod.rs b/sdk/src/provers/mod.rs index cbd5906805..417361dc7e 100644 --- a/sdk/src/provers/mod.rs +++ b/sdk/src/provers/mod.rs @@ -11,7 +11,9 @@ pub use network::NetworkProver; use sha2::{Digest, Sha256}; use sp1_core::air::PublicValues; use sp1_core::stark::MachineProof; +use sp1_core::stark::MachineVerificationError; use sp1_core::stark::StarkGenericConfig; +use sp1_prover::CoreSC; use sp1_prover::SP1Prover; use sp1_prover::{SP1ProvingKey, SP1Stdin, SP1VerifyingKey}; @@ -36,20 +38,24 @@ pub trait Prover: Send + Sync { fn prove_plonk(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result; /// Verify that an SP1 proof is valid given its vkey and metadata. - fn verify(&self, proof: &SP1Proof, vkey: &SP1VerifyingKey) -> Result<()> { + fn verify( + &self, + proof: &SP1Proof, + vkey: &SP1VerifyingKey, + ) -> Result<(), MachineVerificationError> { let pv = PublicValues::from_vec(proof.proof[0].public_values.clone()); let pv_digest: [u8; 32] = Sha256::digest(proof.public_values.as_slice()).into(); if pv_digest != *pv.commit_digest_bytes() { - return Err(anyhow::anyhow!("Public values digest mismatch")); + return Err(MachineVerificationError::InvalidPublicValuesDigest); } let machine_proof = MachineProof { shard_proofs: proof.proof.clone(), }; let sp1_prover = self.sp1_prover(); let mut challenger = sp1_prover.core_machine.config().challenger(); - Ok(sp1_prover + sp1_prover .core_machine - .verify(&vkey.vk, &machine_proof, &mut challenger)?) + .verify(&vkey.vk, &machine_proof, &mut challenger) } /// Verify that a compressed SP1 proof is valid given its vkey and metadata.