From 1307b6e421af1d707a0508d3a650f66dd27b740a Mon Sep 17 00:00:00 2001 From: John Guibas Date: Tue, 23 Apr 2024 12:36:07 -0700 Subject: [PATCH] fix: groth16 prover issues (#571) --- .gitignore | 5 +- Cargo.lock | 2 +- prover/Makefile | 2 +- prover/src/lib.rs | 5 +- recursion/circuit/Cargo.toml | 2 +- recursion/circuit/src/challenger.rs | 19 ++-- recursion/circuit/src/constraints.rs | 7 +- recursion/circuit/src/fri.rs | 11 +-- recursion/circuit/src/poseidon2.rs | 15 ++-- recursion/circuit/src/stark.rs | 86 +++++++++--------- recursion/circuit/src/witness.rs | 7 +- .../compiler/src/constraints/groth16_ffi.rs | 87 ------------------- recursion/compiler/src/constraints/mod.rs | 64 +------------- recursion/groth16-ffi/src/lib.rs | 21 +++-- recursion/groth16/main_test.go | 81 +++++++++++++---- 15 files changed, 151 insertions(+), 263 deletions(-) delete mode 100644 recursion/compiler/src/constraints/groth16_ffi.rs diff --git a/.gitignore b/.gitignore index 5c9f4bc870..7562666488 100644 --- a/.gitignore +++ b/.gitignore @@ -21,5 +21,6 @@ benchmark.csv # Environment .env -# Groth16 FFI -recursion/groth16-ffi/build \ No newline at end of file +# Build Artifacts +recursion/groth16-ffi/build +prover/build \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 32d502132e..f73e260328 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4495,11 +4495,11 @@ dependencies = [ "p3-util", "rand", "serde", - "serial_test", "sp1-core", "sp1-recursion-compiler", "sp1-recursion-core", "sp1-recursion-derive", + "sp1-recursion-groth16-ffi", "sp1-recursion-program", "zkhash", ] diff --git a/prover/Makefile b/prover/Makefile index be6472442c..cdec807a73 100644 --- a/prover/Makefile +++ b/prover/Makefile @@ -6,7 +6,7 @@ all: groth16: RUST_LOG=info RUSTFLAGS='-C target-cpu=native' \ cargo run -p sp1-prover --release --bin e2e -- \ - --build-dir=../recursion/groth16-ffi/build + --build-dir=./build fibonacci-sweep: mkdir -p scripts/results && \ diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 8f1126f2d1..b7776c94c3 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -7,6 +7,7 @@ //! 3. Wrap the shard proof into a SNARK-friendly field. //! 4. Wrap the last shard proof, proven over the SNARK-friendly field, into a Groth16/PLONK proof. +#![warn(unused_extern_crates)] #![allow(incomplete_features)] #![feature(generic_const_exprs)] #![allow(deprecated)] @@ -36,13 +37,13 @@ use sp1_core::{ }; use sp1_recursion_circuit::stark::build_wrap_circuit; use sp1_recursion_circuit::witness::Witnessable; -use sp1_recursion_compiler::constraints::groth16_ffi; use sp1_recursion_compiler::ir::Witness; use sp1_recursion_core::runtime::RecursionProgram; use sp1_recursion_core::{ runtime::Runtime as RecursionRuntime, stark::{config::BabyBearPoseidon2Outer, RecursionAir}, }; +use sp1_recursion_groth16_ffi::Groth16Prover; use sp1_recursion_program::reduce::ReduceProgram; use sp1_recursion_program::{hints::Hintable, stark::EMPTY}; @@ -425,7 +426,7 @@ impl SP1Prover { let mut witness = Witness::default(); proof.write(&mut witness); let constraints = build_wrap_circuit(&self.reduce_vk_outer, proof); - groth16_ffi::test_prove(constraints, witness); + Groth16Prover::test(constraints, witness); } // TODO: Get rid of this method by reading it from public values. diff --git a/recursion/circuit/Cargo.toml b/recursion/circuit/Cargo.toml index 14e256e403..9861804e93 100644 --- a/recursion/circuit/Cargo.toml +++ b/recursion/circuit/Cargo.toml @@ -21,7 +21,6 @@ sp1-recursion-compiler = { path = "../compiler" } sp1-recursion-program = { path = "../program" } p3-bn254-fr = { workspace = true } p3-baby-bear = { workspace = true } -serial_test = "3.0.0" bincode = "1.3.3" [dev-dependencies] @@ -33,3 +32,4 @@ p3-merkle-tree = { workspace = true } p3-poseidon2 = { workspace = true } zkhash = { git = "https://github.com/HorizenLabs/poseidon2" } rand = "0.8.4" +sp1-recursion-groth16-ffi = { path = "../groth16-ffi" } \ No newline at end of file diff --git a/recursion/circuit/src/challenger.rs b/recursion/circuit/src/challenger.rs index cf2769e86f..3fd1982f8d 100644 --- a/recursion/circuit/src/challenger.rs +++ b/recursion/circuit/src/challenger.rs @@ -158,12 +158,12 @@ mod tests { use p3_field::split_32 as split_32_gt; use p3_field::AbstractField; use p3_symmetric::Hash; - use serial_test::serial; use sp1_recursion_compiler::config::OuterConfig; - use sp1_recursion_compiler::constraints::{groth16_ffi, ConstraintCompiler}; + use sp1_recursion_compiler::constraints::ConstraintCompiler; use sp1_recursion_compiler::ir::SymbolicExt; use sp1_recursion_compiler::ir::{Builder, Witness}; use sp1_recursion_core::stark::config::{outer_perm, OuterChallenger}; + use sp1_recursion_groth16_ffi::Groth16Prover; use super::reduce_32; use super::split_32; @@ -171,7 +171,6 @@ mod tests { use crate::DIGEST_SIZE; #[test] - #[serial] fn test_num2bits_v() { let mut builder = Builder::::default(); let mut value_u32 = 1345237507; @@ -184,11 +183,10 @@ mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } #[test] - #[serial] fn test_reduce_32() { let value_1 = BabyBear::from_canonical_u32(1345237507); let value_2 = BabyBear::from_canonical_u32(1000001); @@ -202,11 +200,10 @@ mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } #[test] - #[serial] fn test_split_32() { let value = Bn254Fr::from_canonical_u32(1345237507); let gt: Vec = split_32_gt(value, 3); @@ -221,11 +218,10 @@ mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } #[test] - #[serial] fn test_challenger() { let perm = outer_perm(); let mut challenger = OuterChallenger::new(perm).unwrap(); @@ -259,11 +255,10 @@ mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } #[test] - #[serial] fn test_challenger_sample_ext() { let perm = outer_perm(); let mut challenger = OuterChallenger::new(perm).unwrap(); @@ -302,6 +297,6 @@ mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } } diff --git a/recursion/circuit/src/constraints.rs b/recursion/circuit/src/constraints.rs index 654f22d80e..5bc353ced1 100644 --- a/recursion/circuit/src/constraints.rs +++ b/recursion/circuit/src/constraints.rs @@ -169,14 +169,13 @@ mod tests { use p3_challenger::{CanObserve, FieldChallenger}; use p3_commit::{Pcs, PolynomialSpace}; use serde::{de::DeserializeOwned, Serialize}; - use serial_test::serial; use sp1_core::stark::{ Chip, Com, Dom, LocalProver, OpeningProof, PcsProverData, ShardCommitment, ShardMainData, ShardProof, StarkGenericConfig, StarkMachine, }; use sp1_recursion_compiler::{ config::OuterConfig, - constraints::{groth16_ffi, ConstraintCompiler}, + constraints::ConstraintCompiler, ir::{Builder, Witness}, prelude::ExtConst, }; @@ -184,6 +183,7 @@ mod tests { runtime::Runtime, stark::{config::BabyBearPoseidon2Outer, RecursionAir}, }; + use sp1_recursion_groth16_ffi::Groth16Prover; use crate::stark::{tests::basic_program, StarkVerifierCircuit}; @@ -282,7 +282,6 @@ mod tests { } #[test] - #[serial] fn test_verify_constraints_whole() { type SC = BabyBearPoseidon2Outer; type F = ::Val; @@ -361,6 +360,6 @@ mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } } diff --git a/recursion/circuit/src/fri.rs b/recursion/circuit/src/fri.rs index 096cf73019..3cdae64774 100644 --- a/recursion/circuit/src/fri.rs +++ b/recursion/circuit/src/fri.rs @@ -34,7 +34,6 @@ pub fn verify_shape_and_sample_challenges( } assert_eq!(proof.query_proofs.len(), config.num_queries); - challenger.check_witness(builder, config.proof_of_work_bits, proof.pow_witness); let log_max_height = proof.commit_phase_commits.len() + config.log_blowup; @@ -242,16 +241,16 @@ pub mod tests { use p3_fri::{verifier, TwoAdicFriPcsProof}; use p3_matrix::dense::RowMajorMatrix; use rand::rngs::OsRng; - use serial_test::serial; use sp1_recursion_compiler::{ config::OuterConfig, - constraints::{groth16_ffi, ConstraintCompiler}, + constraints::ConstraintCompiler, ir::{Builder, Ext, Felt, SymbolicExt, Var, Witness}, }; use sp1_recursion_core::stark::config::{ outer_perm, test_fri_config, OuterChallenge, OuterChallengeMmcs, OuterChallenger, OuterCompress, OuterDft, OuterFriProof, OuterHash, OuterPcs, OuterVal, OuterValMmcs, }; + use sp1_recursion_groth16_ffi::Groth16Prover; use super::{verify_shape_and_sample_challenges, verify_two_adic_pcs, TwoAdicPcsRoundVariable}; use crate::{ @@ -402,7 +401,6 @@ pub mod tests { } #[test] - #[serial] fn test_fri_verify_shape_and_sample_challenges() { let mut rng = &mut OsRng; let log_degrees = &[16, 9, 7, 4, 2]; @@ -485,11 +483,10 @@ pub mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } #[test] - #[serial] fn test_verify_two_adic_pcs() { let mut rng = &mut OsRng; let log_degrees = &[19, 19]; @@ -559,6 +556,6 @@ pub mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } } diff --git a/recursion/circuit/src/poseidon2.rs b/recursion/circuit/src/poseidon2.rs index 38fe161327..6a1ef7d96a 100644 --- a/recursion/circuit/src/poseidon2.rs +++ b/recursion/circuit/src/poseidon2.rs @@ -58,17 +58,16 @@ pub mod tests { use p3_bn254_fr::Bn254Fr; use p3_field::AbstractField; use p3_symmetric::{CryptographicHasher, Permutation, PseudoCompressionFunction}; - use serial_test::serial; use sp1_recursion_compiler::config::OuterConfig; - use sp1_recursion_compiler::constraints::{groth16_ffi, ConstraintCompiler}; + use sp1_recursion_compiler::constraints::ConstraintCompiler; use sp1_recursion_compiler::ir::{Builder, Felt, Var, Witness}; use sp1_recursion_core::stark::config::{outer_perm, OuterCompress, OuterHash}; + use sp1_recursion_groth16_ffi::Groth16Prover; use crate::poseidon2::Poseidon2CircuitBuilder; use crate::types::OuterDigestVariable; #[test] - #[serial] fn test_p2_permute_mut() { let poseidon2 = outer_perm(); let input: [Bn254Fr; 3] = [ @@ -91,11 +90,10 @@ pub mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } #[test] - #[serial] fn test_p2_hash() { let perm = outer_perm(); let hasher = OuterHash::new(perm.clone()).unwrap(); @@ -109,7 +107,7 @@ pub mod tests { BabyBear::from_canonical_u32(2), BabyBear::from_canonical_u32(2), ]; - let output = hasher.hash_iter(input.into_iter()); + let output = hasher.hash_iter(input); let mut builder = Builder::::default(); let a: Felt<_> = builder.eval(input[0]); @@ -125,11 +123,10 @@ pub mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } #[test] - #[serial] fn test_p2_compress() { let perm = outer_perm(); let compressor = OuterCompress::new(perm.clone()); @@ -146,6 +143,6 @@ pub mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); + Groth16Prover::test::(constraints.clone(), Witness::default()); } } diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index 77aab3af06..244491fbf2 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; +use crate::fri::verify_two_adic_pcs; use crate::types::OuterDigestVariable; use crate::witness::Witnessable; use p3_air::Air; @@ -13,10 +14,10 @@ use sp1_core::{ }; use sp1_recursion_compiler::config::OuterConfig; use sp1_recursion_compiler::constraints::{Constraint, ConstraintCompiler}; -use sp1_recursion_compiler::ir::{Builder, Config}; +use sp1_recursion_compiler::ir::{Builder, Config, Felt}; use sp1_recursion_compiler::ir::{Usize, Witness}; use sp1_recursion_compiler::prelude::SymbolicVar; -use sp1_recursion_core::stark::config::BabyBearPoseidon2Outer; +use sp1_recursion_core::stark::config::{outer_fri_config, BabyBearPoseidon2Outer}; use sp1_recursion_core::stark::RecursionAir; use sp1_recursion_program::commit::PolynomialSpaceVariable; use sp1_recursion_program::folder::RecursiveVerifierConstraintFolder; @@ -45,7 +46,7 @@ where machine: &StarkMachine, challenger: &mut MultiField32ChallengerVariable, proof: &RecursionShardProofVariable, - _sorted_chips: Vec, + sorted_chips: Vec, sorted_indices: Vec, ) where A: MachineAir + for<'a> Air>, @@ -66,13 +67,13 @@ where quotient_commit, } = commitment; - let _permutation_challenges = (0..2) + let permutation_challenges = (0..2) .map(|_| challenger.sample_ext(builder)) .collect::>(); challenger.observe_commitment(builder, *permutation_commit); - let _alpha = challenger.sample_ext(builder); + let alpha = challenger.sample_ext(builder); challenger.observe_commitment(builder, *quotient_commit); @@ -200,37 +201,31 @@ where rounds.push(main_round); rounds.push(perm_round); rounds.push(quotient_round); - // let config = outer_fri_config(); - // verify_two_adic_pcs(builder, &config, &proof.opening_proof, challenger, rounds); - - // for (i, sorted_chip) in sorted_chips.iter().enumerate() { - // for chip in machine.chips() { - // if chip.name() == *sorted_chip { - // println!("chip {} = {}", i, sorted_chip); - // builder.print_debug(4 + i); - // if chip.preprocessed_width() > 0 { - // continue; - // } - // let values = &opened_values.chips[i]; - // let trace_domain = &trace_domains[i]; - // let quotient_domain = "ient_domains[i]; - // let qc_domains = - // quotient_domain.split_domains(builder, chip.log_quotient_degree()); - // Self::verify_constraints( - // builder, - // chip, - // values, - // proof.public_values.clone(), - // trace_domain.clone(), - // qc_domains, - // zeta, - // alpha, - // &permutation_challenges, - // ); - // builder.print_debug(4 + i); - // } - // } - // } + let config = outer_fri_config(); + verify_two_adic_pcs(builder, &config, &proof.opening_proof, challenger, rounds); + + for (i, sorted_chip) in sorted_chips.iter().enumerate() { + for chip in machine.chips() { + if chip.name() == *sorted_chip { + let values = &opened_values.chips[i]; + let trace_domain = &trace_domains[i]; + let quotient_domain = "ient_domains[i]; + let qc_domains = + quotient_domain.split_domains(builder, chip.log_quotient_degree()); + Self::verify_constraints( + builder, + chip, + values, + proof.public_values.clone(), + trace_domain.clone(), + qc_domains, + zeta, + alpha, + &permutation_challenges, + ); + } + } + } } } @@ -252,6 +247,8 @@ pub fn build_wrap_circuit( let preprocessed_commit: OuterDigestVariable = [builder.eval(preprocessed_commit_val[0])]; challenger.observe_commitment(&mut builder, preprocessed_commit); + let pc_start: Felt<_> = builder.eval(vk.pc_start); + challenger.observe(&mut builder, pc_start); let chips = outer_machine .shard_chips_ordered(&dummy_proof.chip_ordering) @@ -303,15 +300,15 @@ pub(crate) mod tests { use crate::witness::Witnessable; use p3_baby_bear::DiffusionMatrixBabybear; use p3_field::PrimeField32; - use serial_test::serial; use sp1_core::stark::{LocalProver, StarkGenericConfig}; + use sp1_recursion_compiler::config::OuterConfig; use sp1_recursion_compiler::ir::Witness; - use sp1_recursion_compiler::{config::OuterConfig, constraints::groth16_ffi}; use sp1_recursion_core::{ cpu::Instruction, runtime::{Opcode, RecursionProgram, Runtime}, stark::{config::BabyBearPoseidon2Outer, RecursionAir}, }; + use sp1_recursion_groth16_ffi::Groth16Prover; pub fn basic_program() -> RecursionProgram { let zero = [F::zero(); 4]; @@ -348,7 +345,6 @@ pub(crate) mod tests { } #[test] - #[serial] fn test_recursive_verify_shard_v2() { type SC = BabyBearPoseidon2Outer; type F = ::Val; @@ -363,19 +359,23 @@ pub(crate) mod tests { let machine = A::machine(config); let (pk, vk) = machine.setup(&program); let mut challenger = machine.config().challenger(); - let mut proofs = machine - .prove::>(&pk, runtime.record, &mut challenger) - .shard_proofs; + let proof = machine.prove::>(&pk, runtime.record, &mut challenger); + let mut proofs = proof.shard_proofs.clone(); let mut runtime = Runtime::::new_no_perm(&program); runtime.run(); + // Uncomment these lines to verify the proof for debugging purposes. + // + // let mut challenger = machine.config().challenger(); + // machine.verify(&vk, &proof, &mut challenger).unwrap(); + let mut witness = Witness::default(); let proof = proofs.pop().unwrap(); proof.write(&mut witness); let constraints = build_wrap_circuit(&vk, proof); - groth16_ffi::test_prove::(constraints, witness); + Groth16Prover::test::(constraints, witness); } } diff --git a/recursion/circuit/src/witness.rs b/recursion/circuit/src/witness.rs index 6e5f876553..53f0c07504 100644 --- a/recursion/circuit/src/witness.rs +++ b/recursion/circuit/src/witness.rs @@ -330,16 +330,15 @@ mod tests { use p3_baby_bear::BabyBear; use p3_bn254_fr::Bn254Fr; use p3_field::AbstractField; - use serial_test::serial; use sp1_recursion_compiler::{ config::OuterConfig, - constraints::{groth16_ffi, ConstraintCompiler}, + constraints::ConstraintCompiler, ir::{Builder, ExtConst, Witness}, }; use sp1_recursion_core::stark::config::OuterChallenge; + use sp1_recursion_groth16_ffi::Groth16Prover; #[test] - #[serial] fn test_witness_simple() { let mut builder = Builder::::default(); let a = builder.witness_var(); @@ -365,7 +364,7 @@ mod tests { let mut backend = ConstraintCompiler::::default(); let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::( + Groth16Prover::test::( constraints, Witness { vars: vec![Bn254Fr::one(), Bn254Fr::two()], diff --git a/recursion/compiler/src/constraints/groth16_ffi.rs b/recursion/compiler/src/constraints/groth16_ffi.rs deleted file mode 100644 index e525cb2ae1..0000000000 --- a/recursion/compiler/src/constraints/groth16_ffi.rs +++ /dev/null @@ -1,87 +0,0 @@ -use std::{ - fs::File, - io::Write, - process::{Command, Stdio}, -}; - -use p3_field::AbstractExtensionField; -use p3_field::AbstractField; -use p3_field::PrimeField; -use serde::Deserialize; -use serde::Serialize; - -use super::Constraint; -use crate::prelude::{Config, Witness}; - -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct Groth16Witness { - pub vars: Vec, - pub felts: Vec, - pub exts: Vec>, -} - -pub fn test_prove(constraints: Vec, mut witness: Witness) { - let serialized = serde_json::to_string(&constraints).unwrap(); - let manifest_dir = env!("CARGO_MANIFEST_DIR"); - let dir = format!("{}/../groth16", manifest_dir); - - // Append some dummy elements to the witness to avoid compilation errors. - witness.vars.push(C::N::from_canonical_usize(999)); - witness.felts.push(C::F::from_canonical_usize(999)); - witness.exts.push(C::EF::from_canonical_usize(999)); - - // Write constraints. - let constraints_path = format!("{}/constraints.json", dir); - let mut file = File::create(constraints_path).unwrap(); - file.write_all(serialized.as_bytes()).unwrap(); - - // Write witness. - let witness_path = format!("{}/witness.json", dir); - let gnark_witness = Groth16Witness { - vars: witness - .vars - .into_iter() - .map(|w| w.as_canonical_biguint().to_string()) - .collect(), - felts: witness - .felts - .into_iter() - .map(|w| w.as_canonical_biguint().to_string()) - .collect(), - exts: witness - .exts - .into_iter() - .map(|w| { - w.as_base_slice() - .iter() - .map(|x| x.as_canonical_biguint().to_string()) - .collect() - }) - .collect(), - }; - let mut file = File::create(witness_path).unwrap(); - let serialized = serde_json::to_string(&gnark_witness).unwrap(); - file.write_all(serialized.as_bytes()).unwrap(); - - let result = Command::new("go") - .args([ - "test", - "-tags=prover_checks", - "-v", - "-timeout", - "100000s", - "-run", - "^TestMain$", - "github.com/succinctlabs/sp1-recursion-groth16", - ]) - .current_dir(dir) - .stderr(Stdio::inherit()) - .stdout(Stdio::inherit()) - .stdin(Stdio::inherit()) - .output() - .unwrap(); - - if !result.status.success() { - panic!("failed to run test circuit"); - } -} diff --git a/recursion/compiler/src/constraints/mod.rs b/recursion/compiler/src/constraints/mod.rs index c9fdbc5cb8..8d70b9eb8c 100644 --- a/recursion/compiler/src/constraints/mod.rs +++ b/recursion/compiler/src/constraints/mod.rs @@ -1,4 +1,4 @@ -pub mod groth16_ffi; +// pub mod groth16_ffi; pub mod opcodes; use core::fmt::Debug; @@ -340,65 +340,3 @@ impl ConstraintCompiler { constraints } } - -#[cfg(test)] -mod tests { - - use p3_baby_bear::BabyBear; - use p3_bn254_fr::Bn254Fr; - use p3_field::{extension::BinomialExtensionField, AbstractField}; - use serial_test::serial; - - use super::*; - use crate::{ - config::OuterConfig, - ir::{Builder, Ext, Felt, Var}, - prelude::Witness, - }; - - #[test] - #[serial] - fn test_imm() { - let program = vec![ - DslIr::ImmV(Var::new(0), Bn254Fr::zero()), - DslIr::ImmF(Felt::new(1), BabyBear::one()), - DslIr::ImmE(Ext::new(2), BinomialExtensionField::::one()), - DslIr::PrintV(Var::new(0)), - DslIr::PrintF(Felt::new(1)), - DslIr::PrintE(Ext::new(2)), - ]; - let mut backend = ConstraintCompiler::::default(); - let constraints = backend.emit(program.into()); - groth16_ffi::test_prove::(constraints, Witness::default()); - } - - #[test] - #[serial] - fn test_basic_program() { - let mut builder = Builder::::default(); - let a: Var<_> = builder.eval(Bn254Fr::two()); - let b: Var<_> = builder.eval(Bn254Fr::from_canonical_u32(100)); - let c: Var<_> = builder.eval(a * b); - builder.assert_var_eq(c, Bn254Fr::from_canonical_u32(200)); - - let mut backend = ConstraintCompiler::::default(); - let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); - } - - #[test] - #[serial] - fn test_num2bits_v() { - let mut builder = Builder::::default(); - let value_u32 = 100; - let a: Var<_> = builder.eval(Bn254Fr::from_canonical_u32(value_u32)); - let bits = builder.num2bits_v_circuit(a, 32); - for i in 0..32 { - builder.assert_var_eq(bits[i], Bn254Fr::from_canonical_u32((value_u32 >> i) & 1)); - } - - let mut backend = ConstraintCompiler::::default(); - let constraints = backend.emit(builder.operations); - groth16_ffi::test_prove::(constraints, Witness::default()); - } -} diff --git a/recursion/groth16-ffi/src/lib.rs b/recursion/groth16-ffi/src/lib.rs index 4cd08a0617..3be6a01cbe 100644 --- a/recursion/groth16-ffi/src/lib.rs +++ b/recursion/groth16-ffi/src/lib.rs @@ -1,3 +1,5 @@ +#![warn(unused_extern_crates)] + pub mod witness; use std::{ @@ -53,19 +55,17 @@ impl Groth16Prover { pub fn test(constraints: Vec, witness: Witness) { let serialized = serde_json::to_string(&constraints).unwrap(); let manifest_dir = env!("CARGO_MANIFEST_DIR"); - let dir = format!("{}/../groth16", manifest_dir); + let groth16_dir = format!("{}/../groth16", manifest_dir); // Write constraints. - let constraints_path = format!("{}/constraints.json", dir); - let mut file = File::create(constraints_path).unwrap(); - file.write_all(serialized.as_bytes()).unwrap(); + let mut constraints_file = tempfile::NamedTempFile::new().unwrap(); + constraints_file.write_all(serialized.as_bytes()).unwrap(); // Write witness. - let witness_path = format!("{}/witness.json", dir); + let mut witness_file = tempfile::NamedTempFile::new().unwrap(); let gnark_witness: Groth16Witness = witness.into(); - let mut file = File::create(witness_path).unwrap(); let serialized = serde_json::to_string(&gnark_witness).unwrap(); - file.write_all(serialized.as_bytes()).unwrap(); + witness_file.write_all(serialized.as_bytes()).unwrap(); let result = Command::new("go") .args([ @@ -78,7 +78,12 @@ impl Groth16Prover { "^TestMain$", "github.com/succinctlabs/sp1-recursion-groth16", ]) - .current_dir(dir) + .current_dir(groth16_dir) + .env("WITNESS_JSON", witness_file.path().to_str().unwrap()) + .env( + "CONSTRAINTS_JSON", + constraints_file.path().to_str().unwrap(), + ) .stderr(Stdio::inherit()) .stdout(Stdio::inherit()) .stdin(Stdio::inherit()) diff --git a/recursion/groth16/main_test.go b/recursion/groth16/main_test.go index 61247d65e0..a6998a595e 100644 --- a/recursion/groth16/main_test.go +++ b/recursion/groth16/main_test.go @@ -2,19 +2,18 @@ package main import ( "encoding/json" - "fmt" "os" "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/test" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/succinctlabs/sp1-recursion-groth16/babybear" ) func TestMain(t *testing.T) { - assert := test.NewAssert(t) + // assert := test.NewAssert(t) // Get the file name from an environment variable. fileName := os.Getenv("WITNESS_JSON") @@ -29,27 +28,24 @@ func TestMain(t *testing.T) { } // Deserialize the JSON data into a slice of Instruction structs - var witness Inputs - err = json.Unmarshal(data, &witness) + var inputs Inputs + err = json.Unmarshal(data, &inputs) if err != nil { panic(err) } - vars := make([]frontend.Variable, len(witness.Vars)) - felts := make([]*babybear.Variable, len(witness.Felts)) - exts := make([]*babybear.ExtensionVariable, len(witness.Exts)) - for i := 0; i < len(witness.Vars); i++ { - vars[i] = frontend.Variable(witness.Vars[i]) + vars := make([]frontend.Variable, len(inputs.Vars)) + felts := make([]*babybear.Variable, len(inputs.Felts)) + exts := make([]*babybear.ExtensionVariable, len(inputs.Exts)) + for i := 0; i < len(inputs.Vars); i++ { + vars[i] = frontend.Variable(inputs.Vars[i]) } - fmt.Println("NbVars:", len(vars)) - for i := 0; i < len(witness.Felts); i++ { - felts[i] = babybear.NewF(witness.Felts[i]) + for i := 0; i < len(inputs.Felts); i++ { + felts[i] = babybear.NewF(inputs.Felts[i]) } - fmt.Println("NbFelts:", len(felts)) - for i := 0; i < len(witness.Exts); i++ { - exts[i] = babybear.NewE(witness.Exts[i]) + for i := 0; i < len(inputs.Exts); i++ { + exts[i] = babybear.NewE(inputs.Exts[i]) } - fmt.Println("NbExts:", len(exts)) // Run some sanity checks. circuit := Circuit{ @@ -57,5 +53,52 @@ func TestMain(t *testing.T) { Felts: felts, Exts: exts, } - assert.CheckCircuit(&circuit, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16)) + + // Compile the circuit. + builder := r1cs.NewBuilder + r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), builder, &circuit) + if err != nil { + panic(err) + } + + // Run the dummy setup. + var pk groth16.ProvingKey + pk, err = groth16.DummySetup(r1cs) + if err != nil { + panic(err) + } + + // Generate witness. + vars = make([]frontend.Variable, len(inputs.Vars)) + felts = make([]*babybear.Variable, len(inputs.Felts)) + exts = make([]*babybear.ExtensionVariable, len(inputs.Exts)) + for i := 0; i < len(inputs.Vars); i++ { + vars[i] = frontend.Variable(inputs.Vars[i]) + } + for i := 0; i < len(inputs.Felts); i++ { + felts[i] = babybear.NewF(inputs.Felts[i]) + } + for i := 0; i < len(inputs.Exts); i++ { + exts[i] = babybear.NewE(inputs.Exts[i]) + } + assignment := Circuit{ + Vars: vars, + Felts: felts, + Exts: exts, + } + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } + + // Generate the proof. + _, err = groth16.Prove(r1cs, pk, witness) + if err != nil { + panic(err) + } + + // This was the old way we were testing the circuit, but it seems to have edge cases where it + // doesn't properly check that the prover will succeed. + // + // assert.CheckCircuit(&circuit, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16)) }