Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core,host): initial aggregation API #375

Merged
merged 44 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
ba7f009
initial proof aggregation implementation
Brechtpd Aug 20, 2024
4639e32
aggregation improvements + risc0 aggregation
Brechtpd Aug 23, 2024
91cbbd7
sp1 aggregation fixes
Brechtpd Aug 23, 2024
ee21b87
sp1 aggregation elf
Brechtpd Aug 23, 2024
461c943
uuid support for risc0 aggregation
Brechtpd Aug 23, 2024
ae303ce
Merge remote-tracking branch 'origin/main' into proof-aggregation
Brechtpd Aug 23, 2024
bd1135a
risc0 aggregation circuit compile fixes
Brechtpd Aug 23, 2024
f07300b
fix sgx proof aggregation
Brechtpd Aug 25, 2024
ee6fe0a
fmt
Brechtpd Aug 25, 2024
8e86975
Merge branch 'main' into proof-aggregation
petarvujovic98 Sep 17, 2024
83d538b
feat(core,host): initial aggregation API
petarvujovic98 Sep 17, 2024
df758fa
fix(core,host,sgx): fix compiler and clippy errors
petarvujovic98 Sep 17, 2024
973431e
fix(core,lib,provers): revert merge bugs and add sp1 stubs
petarvujovic98 Sep 17, 2024
ccb2fb3
Merge branch 'proof-aggregation' into proof-aggregation-api
petarvujovic98 Sep 17, 2024
9b829db
fix(core): remove double member
petarvujovic98 Sep 17, 2024
2ccac77
fix(sp1): fix dependency naming
petarvujovic98 Sep 17, 2024
247ffee
refactor(risc0): clean up aggregation file
petarvujovic98 Sep 17, 2024
2e6f7d8
fix(sp1): enable verification for proof aggregation
petarvujovic98 Sep 17, 2024
29e6fb1
feat(host): migrate to v3 API
petarvujovic98 Sep 17, 2024
a887cc6
feat(sp1): run cargo fmt
petarvujovic98 Sep 17, 2024
5d29f78
feat(core): make `l1_inclusion_block_number` optional
petarvujovic98 Sep 18, 2024
8a3947b
fixproof req input into prove state manager
smtmfft Sep 19, 2024
7276c7b
feat(core,host,lib,tasks): add aggregation tasks and API
petarvujovic98 Sep 19, 2024
83df208
fix(core): fix typo
petarvujovic98 Sep 19, 2024
728cb3a
Merge remote-tracking branch 'origin/main' into proof-aggregation-api
smtmfft Sep 20, 2024
16e4abc
fix v3 error return
smtmfft Sep 20, 2024
fa6fe88
feat(sp1): implement aggregate function
petarvujovic98 Sep 20, 2024
1aaec18
fix sgx aggregation for back compatibility
smtmfft Sep 23, 2024
789a599
fix(lib): fix typo
petarvujovic98 Sep 23, 2024
0358a26
fix risc0 aggregation
smtmfft Sep 24, 2024
a033011
Merge branch 'proof-aggregation-api' of https://github.com/taikoxyz/r…
smtmfft Sep 24, 2024
9da8834
fix(host,sp1): handle statuses
petarvujovic98 Sep 26, 2024
bdcb5f3
enable sp1 aggregation
smtmfft Sep 28, 2024
334d93f
feat(host): error out on empty proof array request
petarvujovic98 Sep 30, 2024
e1d79ff
fix(host): return proper status report
petarvujovic98 Sep 30, 2024
e34567f
feat(host,tasks): adding details to error statuses
petarvujovic98 Sep 30, 2024
d14774d
fix sp1 aggregation
smtmfft Oct 2, 2024
62c29a5
update prove-block script
smtmfft Oct 3, 2024
f2a2e2c
fix(fmt): run cargo fmt
petarvujovic98 Oct 7, 2024
e26ef1b
fix(clippy): fix clippy issues
petarvujovic98 Oct 7, 2024
d04e8a0
chore(repo): cleanup captured vars in format calls
petarvujovic98 Oct 7, 2024
e9bd312
Merge branch 'main' into proof-aggregation-api
petarvujovic98 Oct 9, 2024
c55f162
fix(sp1): convert to proper types
petarvujovic98 Oct 9, 2024
8df544b
chore(sp1): remove the unneccessary
petarvujovic98 Oct 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

169 changes: 166 additions & 3 deletions core/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ use alloy_primitives::{Address, B256};
use clap::{Args, ValueEnum};
use raiko_lib::{
consts::VerifierType,
input::{BlobProofType, GuestInput, GuestOutput},
input::{
AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput,
},
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::{serde_as, DisplayFromStr};
use std::{collections::HashMap, path::Path, str::FromStr};
use std::{collections::HashMap, fmt::Display, path::Path, str::FromStr};
use utoipa::ToSchema;

#[derive(Debug, thiserror::Error, ToSchema)]
Expand Down Expand Up @@ -203,6 +205,47 @@ impl ProofType {
}
}

/// Run the prover driver depending on the proof type.
pub async fn aggregate_proofs(
&self,
input: AggregationGuestInput,
output: &AggregationGuestOutput,
config: &Value,
store: Option<&mut dyn IdWrite>,
) -> RaikoResult<Proof> {
let proof = match self {
ProofType::Native => NativeProver::aggregate(input.clone(), output, config, store)
.await
.map_err(<ProverError as Into<RaikoError>>::into),
ProofType::Sp1 => {
#[cfg(feature = "sp1")]
return sp1_driver::Sp1Prover::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "sp1"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Risc0 => {
#[cfg(feature = "risc0")]
return risc0_driver::Risc0Prover::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "risc0"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Sgx => {
#[cfg(feature = "sgx")]
return sgx_prover::SgxProver::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "sgx"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
}?;

Ok(proof)
}

pub async fn cancel_proof(
&self,
proof_key: ProofKey,
Expand Down Expand Up @@ -302,7 +345,7 @@ pub struct ProofRequestOpt {
pub prover_args: ProverSpecificOpts,
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args)]
#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args, PartialEq, Eq, Hash)]
pub struct ProverSpecificOpts {
/// Native prover specific options.
pub native: Option<Value>,
Expand Down Expand Up @@ -398,3 +441,123 @@ impl TryFrom<ProofRequestOpt> for ProofRequest {
})
}
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema)]
#[serde(default)]
/// A request for proof aggregation of multiple proofs.
pub struct AggregationRequest {
/// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for.
pub block_numbers: Vec<(u64, Option<u64>)>,
/// The network to generate the proof for.
pub network: Option<String>,
/// The L1 network to generate the proof for.
pub l1_network: Option<String>,
// Graffiti.
pub graffiti: Option<String>,
/// The protocol instance data.
pub prover: Option<String>,
/// The proof type.
pub proof_type: Option<String>,
/// Blob proof type.
pub blob_proof_type: Option<String>,
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
#[serde(flatten)]
/// Any additional prover params in JSON format.
pub prover_args: ProverSpecificOpts,
}

impl AggregationRequest {
/// Merge proof request options into aggregation request options.
pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> {
let this = serde_json::to_value(&self)?;
let mut opts = serde_json::to_value(opts)?;
merge(&mut opts, &this);
*self = serde_json::from_value(opts)?;
Ok(())
}
}

impl From<AggregationRequest> for Vec<ProofRequestOpt> {
fn from(value: AggregationRequest) -> Self {
value
.block_numbers
.iter()
.map(
|&(block_number, l1_inclusion_block_number)| ProofRequestOpt {
block_number: Some(block_number),
l1_inclusion_block_number,
network: value.network.clone(),
l1_network: value.l1_network.clone(),
graffiti: value.graffiti.clone(),
prover: value.prover.clone(),
proof_type: value.proof_type.clone(),
blob_proof_type: value.blob_proof_type.clone(),
prover_args: value.prover_args.clone(),
},
)
.collect()
}
}

impl From<ProofRequestOpt> for AggregationRequest {
fn from(value: ProofRequestOpt) -> Self {
let block_numbers = if let Some(block_number) = value.block_number {
vec![(block_number, value.l1_inclusion_block_number)]
} else {
vec![]
};

Self {
block_numbers,
network: value.network,
l1_network: value.l1_network,
graffiti: value.graffiti,
prover: value.prover,
proof_type: value.proof_type,
blob_proof_type: value.blob_proof_type,
prover_args: value.prover_args,
}
}
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, PartialEq, Eq, Hash)]
#[serde(default)]
/// A request for proof aggregation of multiple proofs.
pub struct AggregationOnlyRequest {
/// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for.
pub proofs: Vec<Proof>,
/// The proof type.
pub proof_type: Option<String>,
#[serde(flatten)]
/// Any additional prover params in JSON format.
pub prover_args: ProverSpecificOpts,
}

impl Display for AggregationOnlyRequest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&format!(
"AggregationOnlyRequest {{ {:?}, {:?} }}",
self.proof_type, self.prover_args
))
}
}

impl From<(AggregationRequest, Vec<Proof>)> for AggregationOnlyRequest {
fn from((request, proofs): (AggregationRequest, Vec<Proof>)) -> Self {
Self {
proofs,
proof_type: request.proof_type,
prover_args: request.prover_args,
}
}
}

impl AggregationOnlyRequest {
/// Merge proof request options into aggregation request options.
pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> {
let this = serde_json::to_value(&self)?;
let mut opts = serde_json::to_value(opts)?;
merge(&mut opts, &this);
*self = serde_json::from_value(opts)?;
Ok(())
}
}
71 changes: 59 additions & 12 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ mod tests {
use clap::ValueEnum;
use raiko_lib::{
consts::{Network, SupportedChainSpecs},
input::BlobProofType,
input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType},
primitives::B256,
prover::Proof,
};
use serde_json::{json, Value};
use std::{collections::HashMap, env};
Expand All @@ -242,7 +243,7 @@ mod tests {
ci == "1"
}

fn test_proof_params() -> HashMap<String, Value> {
fn test_proof_params(enable_aggregation: bool) -> HashMap<String, Value> {
let mut prover_args = HashMap::new();
prover_args.insert(
"native".to_string(),
Expand All @@ -256,7 +257,7 @@ mod tests {
"sp1".to_string(),
json! {
{
"recursion": "core",
"recursion": if enable_aggregation { "compressed" } else { "plonk" },
"prover": "mock",
"verify": true
}
Expand All @@ -278,8 +279,8 @@ mod tests {
json! {
{
"instance_id": 121,
"setup": true,
"bootstrap": true,
"setup": enable_aggregation,
"bootstrap": enable_aggregation,
"prove": true,
}
},
Expand All @@ -291,7 +292,7 @@ mod tests {
l1_chain_spec: ChainSpec,
taiko_chain_spec: ChainSpec,
proof_request: ProofRequest,
) {
) -> Proof {
let provider =
RpcBlockDataProvider::new(&taiko_chain_spec.rpc, proof_request.block_number - 1)
.expect("Could not create RpcBlockDataProvider");
Expand All @@ -301,10 +302,10 @@ mod tests {
.await
.expect("input generation failed");
let output = raiko.get_output(&input).expect("output generation failed");
let _proof = raiko
raiko
.prove(input, &output, None)
.await
.expect("proof generation failed");
.expect("proof generation failed")
}

#[ignore]
Expand Down Expand Up @@ -332,7 +333,7 @@ mod tests {
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
Expand Down Expand Up @@ -361,7 +362,7 @@ mod tests {
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
Expand Down Expand Up @@ -399,7 +400,7 @@ mod tests {
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
Expand Down Expand Up @@ -432,9 +433,55 @@ mod tests {
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
}

#[tokio::test(flavor = "multi_thread")]
async fn test_prove_block_taiko_a7_aggregated() {
let proof_type = get_proof_type_from_env();
let l1_network = Network::Holesky.to_string();
let network = Network::TaikoA7.to_string();
// Give the CI an simpler block to test because it doesn't have enough memory.
// Unfortunately that also means that kzg is not getting fully verified by CI.
let block_number = if is_ci() { 105987 } else { 101368 };
let taiko_chain_spec = SupportedChainSpecs::default()
.get_chain_spec(&network)
.unwrap();
let l1_chain_spec = SupportedChainSpecs::default()
.get_chain_spec(&l1_network)
.unwrap();

let proof_request = ProofRequest {
block_number,
l1_inclusion_block_number: 0,
network,
graffiti: B256::ZERO,
prover: Address::ZERO,
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(true),
};
let proof = prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;

let input = AggregationGuestInput {
proofs: vec![proof.clone(), proof],
};

let output = AggregationGuestOutput { hash: B256::ZERO };

let aggregated_proof = proof_type
.aggregate_proofs(
input,
&output,
&serde_json::to_value(&test_proof_params(false)).unwrap(),
None,
)
.await
.expect("proof aggregation failed");
println!("aggregated proof: {aggregated_proof:?}");
}
}
5 changes: 1 addition & 4 deletions core/src/preflight/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,8 @@ pub async fn prepare_taiko_chain_input(
RaikoError::Preflight("No L1 inclusion block hash for the requested block".to_owned())
})?;
info!(
"L1 inclusion block number: {:?}, hash: {:?}. L1 state block number: {:?}, hash: {:?}",
l1_inclusion_block_number,
l1_inclusion_block_hash,
"L1 inclusion block number: {l1_inclusion_block_number:?}, hash: {l1_inclusion_block_hash:?}. L1 state block number: {:?}, hash: {l1_state_block_hash:?}",
l1_state_header.number,
l1_state_block_hash
);

// Fetch the tx data from either calldata or blobdata
Expand Down
14 changes: 14 additions & 0 deletions core/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,28 @@ impl Prover for NativeProver {
}

Ok(Proof {
input: None,
proof: None,
quote: None,
uuid: None,
kzg_proof: None,
})
}

async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> {
Ok(())
}

async fn aggregate(
_input: raiko_lib::input::AggregationGuestInput,
_output: &raiko_lib::input::AggregationGuestOutput,
_config: &ProverConfig,
_store: Option<&mut dyn IdWrite>,
) -> ProverResult<Proof> {
Ok(Proof {
..Default::default()
})
}
}

#[ignore = "Only used to test serialized data"]
Expand Down
Loading
Loading