Skip to content

Commit

Permalink
chore: cleanup prover (#551)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored Apr 20, 2024
1 parent 2f92f82 commit 08a6281
Show file tree
Hide file tree
Showing 27 changed files with 463 additions and 636 deletions.
9 changes: 5 additions & 4 deletions core/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub struct SP1Stdin {
/// Public values for the prover.
#[derive(Serialize, Deserialize)]
pub struct SP1PublicValues {
// TODO: fix
pub buffer: Buffer,
}

Expand Down Expand Up @@ -110,10 +111,10 @@ impl AsRef<[u8]> for SP1PublicValues {
pub mod proof_serde {
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};

use crate::stark::{Proof, StarkGenericConfig};
use crate::stark::{MachineProof, StarkGenericConfig};

pub fn serialize<S, SC: StarkGenericConfig + Serialize>(
proof: &Proof<SC>,
proof: &MachineProof<SC>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
Expand All @@ -130,7 +131,7 @@ pub mod proof_serde {

pub fn deserialize<'de, D, SC: StarkGenericConfig + DeserializeOwned>(
deserializer: D,
) -> Result<Proof<SC>, D::Error>
) -> Result<MachineProof<SC>, D::Error>
where
D: Deserializer<'de>,
{
Expand All @@ -140,7 +141,7 @@ pub mod proof_serde {
let proof = bincode::deserialize(&bytes).map_err(serde::de::Error::custom)?;
Ok(proof)
} else {
Proof::<SC>::deserialize(deserializer)
MachineProof::<SC>::deserialize(deserializer)
}
}
}
4 changes: 2 additions & 2 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ pub use io::*;
use runtime::{Program, Runtime};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use stark::Proof;
use stark::MachineProof;
use stark::StarkGenericConfig;

/// A proof of a RISCV ELF execution with given inputs and outputs.
#[derive(Serialize, Deserialize)]
#[deprecated(note = "Import from sp1_sdk instead of sp1_core")]
pub struct SP1ProofWithIO<SC: StarkGenericConfig + Serialize + DeserializeOwned> {
#[serde(with = "proof_serde")]
pub proof: Proof<SC>,
pub proof: MachineProof<SC>,
pub stdin: SP1Stdin,
pub public_values: SP1PublicValues,
}
8 changes: 4 additions & 4 deletions core/src/lookup/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use p3_matrix::Matrix;

use super::InteractionKind;
use crate::air::MachineAir;
use crate::stark::{MachineChip, MachineStark, ProvingKey, StarkGenericConfig, Val};
use crate::stark::{MachineChip, StarkGenericConfig, StarkMachine, StarkProvingKey, Val};

#[derive(Debug)]
pub struct InteractionData<F: Field> {
Expand Down Expand Up @@ -47,7 +47,7 @@ fn field_to_int<F: PrimeField32>(x: F) -> i32 {

pub fn debug_interactions<SC: StarkGenericConfig, A: MachineAir<Val<SC>>>(
chip: &MachineChip<SC, A>,
pkey: &ProvingKey<SC>,
pkey: &StarkProvingKey<SC>,
record: &A::Record,
interaction_kinds: Vec<InteractionKind>,
) -> (
Expand Down Expand Up @@ -126,8 +126,8 @@ pub fn debug_interactions<SC: StarkGenericConfig, A: MachineAir<Val<SC>>>(
/// Calculate the number of times we send and receive each event of the given interaction type,
/// and print out the ones for which the set of sends and receives don't match.
pub fn debug_interactions_with_all_chips<SC, A>(
machine: &MachineStark<SC, A>,
pkey: &ProvingKey<SC>,
machine: &StarkMachine<SC, A>,
pkey: &StarkProvingKey<SC>,
shards: &[A::Record],
interaction_kinds: Vec<InteractionKind>,
) -> bool
Expand Down
2 changes: 1 addition & 1 deletion core/src/memory/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ mod tests {
let program_clone = program.clone();
let mut runtime = Runtime::new(program);
runtime.run();
let machine: crate::stark::MachineStark<BabyBearPoseidon2, RiscvAir<BabyBear>> =
let machine: crate::stark::StarkMachine<BabyBearPoseidon2, RiscvAir<BabyBear>> =
RiscvAir::machine(BabyBearPoseidon2::new());
let (pkey, _) = machine.setup(&program_clone);
let shards = machine.shard(
Expand Down
6 changes: 3 additions & 3 deletions core/src/runtime/io.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io::Read;

use crate::stark::{Proof, VerifyingKey};
use crate::stark::{MachineProof, StarkVerifyingKey};
use crate::utils::BabyBearPoseidon2Inner;

use serde::de::DeserializeOwned;
Expand Down Expand Up @@ -34,8 +34,8 @@ impl Runtime {

pub fn write_proof(
&mut self,
proof: Proof<BabyBearPoseidon2Inner>,
vk: VerifyingKey<BabyBearPoseidon2Inner>,
proof: MachineProof<BabyBearPoseidon2Inner>,
vk: StarkVerifyingKey<BabyBearPoseidon2Inner>,
) {
self.state.proof_stream.push((proof, vk));
}
Expand Down
6 changes: 3 additions & 3 deletions core/src/runtime/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use serde_with::serde_as;

use crate::{
stark::{Proof, VerifyingKey},
stark::{MachineProof, StarkVerifyingKey},
utils::BabyBearPoseidon2Inner,
};

Expand Down Expand Up @@ -45,8 +45,8 @@ pub struct ExecutionState {
/// A stream of proofs inputted to the program.
#[serde(skip)] // TODO: fix serialization for VerifyingKey
pub proof_stream: Vec<(
Proof<BabyBearPoseidon2Inner>,
VerifyingKey<BabyBearPoseidon2Inner>,
MachineProof<BabyBearPoseidon2Inner>,
StarkVerifyingKey<BabyBearPoseidon2Inner>,
)>,

/// A ptr to the current position in the proof stream, incremented after verifying a proof.
Expand Down
6 changes: 3 additions & 3 deletions core/src/stark/air.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::MachineStark;
use super::StarkMachine;
pub use crate::air::SP1AirBuilder;
use crate::air::{MachineAir, SP1_PROOF_NUM_PV_ELTS};
use crate::memory::{MemoryChipType, MemoryProgramChip};
Expand Down Expand Up @@ -104,12 +104,12 @@ pub enum RiscvAir<F: PrimeField32> {
}

impl<F: PrimeField32> RiscvAir<F> {
pub fn machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> MachineStark<SC, Self> {
pub fn machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
let chips = Self::get_all()
.into_iter()
.map(Chip::new)
.collect::<Vec<_>>();
MachineStark::new(config, chips, SP1_PROOF_NUM_PV_ELTS)
StarkMachine::new(config, chips, SP1_PROOF_NUM_PV_ELTS)
}

/// Get all the different RISC-V AIRs.
Expand Down
34 changes: 17 additions & 17 deletions core/src/stark/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ use crate::stark::VerifierConstraintFolder;

use super::Chip;
use super::Com;
use super::MachineProof;
use super::PcsProverData;
use super::Proof;
use super::Prover;
use super::StarkGenericConfig;
use super::Val;
Expand All @@ -44,7 +44,7 @@ use super::Verifier;
pub type MachineChip<SC, A> = Chip<Val<SC>, A>;

/// A STARK for proving RISC-V execution.
pub struct MachineStark<SC: StarkGenericConfig, A> {
pub struct StarkMachine<SC: StarkGenericConfig, A> {
/// The STARK settings for the RISC-V STARK.
config: SC,
/// The chips that make up the RISC-V STARK machine, in order of their execution.
Expand All @@ -54,7 +54,7 @@ pub struct MachineStark<SC: StarkGenericConfig, A> {
num_pv_elts: usize,
}

impl<SC: StarkGenericConfig, A> MachineStark<SC, A> {
impl<SC: StarkGenericConfig, A> StarkMachine<SC, A> {
pub fn new(config: SC, chips: Vec<Chip<Val<SC>, A>>, num_pv_elts: usize) -> Self {
Self {
config,
Expand All @@ -64,43 +64,43 @@ impl<SC: StarkGenericConfig, A> MachineStark<SC, A> {
}
}

pub struct ProvingKey<SC: StarkGenericConfig> {
pub struct StarkProvingKey<SC: StarkGenericConfig> {
pub commit: Com<SC>,
pub pc_start: Val<SC>,
pub traces: Vec<RowMajorMatrix<Val<SC>>>,
pub data: PcsProverData<SC>,
pub chip_ordering: HashMap<String, usize>,
}

impl<SC: StarkGenericConfig> ProvingKey<SC> {
impl<SC: StarkGenericConfig> StarkProvingKey<SC> {
pub fn observe_into(&self, challenger: &mut SC::Challenger) {
challenger.observe(self.commit.clone());
challenger.observe(self.pc_start);
}
}

#[derive(Clone)]
pub struct VerifyingKey<SC: StarkGenericConfig> {
pub struct StarkVerifyingKey<SC: StarkGenericConfig> {
pub commit: Com<SC>,
pub pc_start: Val<SC>,
pub chip_information: Vec<(String, Dom<SC>, Dimensions)>,
pub chip_ordering: HashMap<String, usize>,
}

impl<SC: StarkGenericConfig> VerifyingKey<SC> {
impl<SC: StarkGenericConfig> StarkVerifyingKey<SC> {
pub fn observe_into(&self, challenger: &mut SC::Challenger) {
challenger.observe(self.commit.clone());
challenger.observe(self.pc_start);
}
}

impl<SC: StarkGenericConfig> Debug for VerifyingKey<SC> {
impl<SC: StarkGenericConfig> Debug for StarkVerifyingKey<SC> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("VerifyingKey").finish()
}
}

impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
/// Get an array containing a `ChipRef` for all the chips of this RISC-V STARK machine.
pub fn chips(&self) -> &[MachineChip<SC, A>] {
&self.chips
Expand Down Expand Up @@ -154,7 +154,7 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
///
/// Given a program, this function generates the proving and verifying keys. The keys correspond
/// to the program code and other preprocessed colunms such as lookup tables.
pub fn setup(&self, program: &A::Program) -> (ProvingKey<SC>, VerifyingKey<SC>) {
pub fn setup(&self, program: &A::Program) -> (StarkProvingKey<SC>, StarkVerifyingKey<SC>) {
let mut named_preprocessed_traces = self
.chips()
.iter()
Expand Down Expand Up @@ -213,14 +213,14 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
let pc_start = program.pc_start();

(
ProvingKey {
StarkProvingKey {
commit: commit.clone(),
pc_start,
traces,
data,
chip_ordering: chip_ordering.clone(),
},
VerifyingKey {
StarkVerifyingKey {
commit,
pc_start,
chip_information,
Expand Down Expand Up @@ -261,10 +261,10 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
/// a STARK proof that the execution record is valid.
pub fn prove<P: Prover<SC, A>>(
&self,
pk: &ProvingKey<SC>,
pk: &StarkProvingKey<SC>,
record: A::Record,
challenger: &mut SC::Challenger,
) -> Proof<SC>
) -> MachineProof<SC>
where
A: for<'a> Air<ProverConstraintFolder<'a, SC>>
+ Air<InteractionBuilder<Val<SC>>>
Expand All @@ -285,8 +285,8 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
/// Verify that a proof is complete and valid given a verifying key and a claimed digest.
pub fn verify(
&self,
vk: &VerifyingKey<SC>,
proof: &Proof<SC>,
vk: &StarkVerifyingKey<SC>,
proof: &MachineProof<SC>,
challenger: &mut SC::Challenger,
) -> Result<(PublicValuesDigest, DeferredDigest), ProgramVerificationError>
where
Expand Down Expand Up @@ -413,7 +413,7 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {

pub fn debug_constraints(
&self,
pk: &ProvingKey<SC>,
pk: &StarkProvingKey<SC>,
record: A::Record,
challenger: &mut SC::Challenger,
) where
Expand Down
24 changes: 12 additions & 12 deletions core/src/stark/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ use p3_util::log2_ceil_usize;
use p3_util::log2_strict_usize;
use web_time::Instant;

use super::{quotient_values, MachineStark, PcsProverData, Val};
use super::{quotient_values, PcsProverData, StarkMachine, Val};
use super::{types::*, StarkGenericConfig};
use super::{Com, OpeningProof};
use super::{ProvingKey, VerifierConstraintFolder};
use super::{StarkProvingKey, VerifierConstraintFolder};
use crate::air::MachineAir;
use crate::lookup::InteractionBuilder;
use crate::stark::record::MachineRecord;
Expand All @@ -43,11 +43,11 @@ fn chunk_vec<T>(mut vec: Vec<T>, chunk_size: usize) -> Vec<Vec<T>> {

pub trait Prover<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> {
fn prove_shards(
machine: &MachineStark<SC, A>,
pk: &ProvingKey<SC>,
machine: &StarkMachine<SC, A>,
pk: &StarkProvingKey<SC>,
shards: Vec<A::Record>,
challenger: &mut SC::Challenger,
) -> Proof<SC>
) -> MachineProof<SC>
where
A: for<'a> Air<ProverConstraintFolder<'a, SC>>
+ Air<InteractionBuilder<Val<SC>>>
Expand All @@ -66,11 +66,11 @@ where
A: MachineAir<Val<SC>>,
{
fn prove_shards(
machine: &MachineStark<SC, A>,
pk: &ProvingKey<SC>,
machine: &StarkMachine<SC, A>,
pk: &StarkProvingKey<SC>,
shards: Vec<A::Record>,
challenger: &mut SC::Challenger,
) -> Proof<SC>
) -> MachineProof<SC>
where
A: for<'a> Air<ProverConstraintFolder<'a, SC>>
+ Air<InteractionBuilder<Val<SC>>>
Expand Down Expand Up @@ -147,7 +147,7 @@ where
.collect::<Vec<_>>()
});

Proof { shard_proofs }
MachineProof { shard_proofs }
}
}

Expand All @@ -164,7 +164,7 @@ where
{
pub fn commit_main(
config: &SC,
machine: &MachineStark<SC, A>,
machine: &StarkMachine<SC, A>,
shard: &A::Record,
index: usize,
) -> ShardMainData<SC> {
Expand Down Expand Up @@ -223,7 +223,7 @@ where
/// Prove the program for the given shard and given a commitment to the main data.
pub fn prove_shard(
config: &SC,
pk: &ProvingKey<SC>,
pk: &StarkProvingKey<SC>,
chips: &[&MachineChip<SC, A>],
mut shard_data: ShardMainData<SC>,
challenger: &mut SC::Challenger,
Expand Down Expand Up @@ -521,7 +521,7 @@ where
}

pub fn commit_shards<F, EF>(
machine: &MachineStark<SC, A>,
machine: &StarkMachine<SC, A>,
shards: &[A::Record],
) -> (Vec<Com<SC>>, Vec<ShardMainDataWrapper<SC>>)
where
Expand Down
4 changes: 2 additions & 2 deletions core/src/stark/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ impl<SC: StarkGenericConfig> ShardProof<SC> {

#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
pub struct Proof<SC: StarkGenericConfig> {
pub struct MachineProof<SC: StarkGenericConfig> {
pub shard_proofs: Vec<ShardProof<SC>>,
}

impl<SC: StarkGenericConfig> Debug for Proof<SC> {
impl<SC: StarkGenericConfig> Debug for MachineProof<SC> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Proof")
.field("shard_proofs", &self.shard_proofs.len())
Expand Down
Loading

0 comments on commit 08a6281

Please sign in to comment.