Skip to content

Commit

Permalink
feat: initial recursion core (#354)
Browse files Browse the repository at this point in the history
Co-authored-by: John Guibas <[email protected]>
  • Loading branch information
jtguibas and John Guibas authored Mar 8, 2024
1 parent 34bdc59 commit f2acc41
Show file tree
Hide file tree
Showing 52 changed files with 1,237 additions and 261 deletions.
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
members = ["core", "cli", "derive", "zkvm/*", "helper", "eval"]
members = ["core", "cli", "derive", "zkvm/*", "helper", "eval", "recursion/core"]
exclude = ["examples/target"]
resolver = "2"

Expand Down
14 changes: 7 additions & 7 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use p3_air::BaseAir;
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;

use crate::runtime::{ExecutionRecord, Program};
use crate::{runtime::Program, stark::MachineRecord};

pub use sp1_derive::MachineAir;

/// An AIR that is part of a Risc-V AIR arithmetization.
pub trait MachineAir<F: Field>: BaseAir<F> {
type Record: MachineRecord;

/// A unique identifier for this AIR as part of a machine.
fn name(&self) -> String;

Expand All @@ -16,14 +18,10 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
/// - `input` is the execution record containing the events to be written to the trace.
/// - `output` is the execution record containing events that the `MachineAir` can add to
/// the record such as byte lookup requests.
fn generate_trace(
&self,
input: &ExecutionRecord,
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F>;
fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix<F>;

/// Generate the dependencies for a given execution record.
fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
self.generate_trace(input, output);
}

Expand All @@ -36,4 +34,6 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
fn generate_preprocessed_trace(&self, program: &Program) -> Option<RowMajorMatrix<F>> {
None
}

fn included(&self, shard: &Self::Record) -> bool;
}
7 changes: 7 additions & 0 deletions core/src/alu/add/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::air::MachineAir;
use crate::air::{SP1AirBuilder, Word};
use crate::operations::AddOperation;
use crate::runtime::{ExecutionRecord, Opcode};
use crate::stark::MachineRecord;
use crate::utils::pad_to_power_of_two;

/// The number of main trace columns for `AddChip`.
Expand Down Expand Up @@ -40,6 +41,8 @@ pub struct AddCols<T> {
}

impl<F: PrimeField> MachineAir<F> for AddChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"Add".to_string()
}
Expand Down Expand Up @@ -88,6 +91,10 @@ impl<F: PrimeField> MachineAir<F> for AddChip {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.add_events.is_empty()
}
}

impl<F> BaseAir<F> for AddChip {
Expand Down
6 changes: 6 additions & 0 deletions core/src/alu/bitwise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ pub struct BitwiseCols<T> {
}

impl<F: PrimeField> MachineAir<F> for BitwiseChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"Bitwise".to_string()
}
Expand Down Expand Up @@ -98,6 +100,10 @@ impl<F: PrimeField> MachineAir<F> for BitwiseChip {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.bitwise_events.is_empty()
}
}

impl<F> BaseAir<F> for BitwiseChip {
Expand Down
6 changes: 6 additions & 0 deletions core/src/alu/divrem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ pub struct DivRemCols<T> {
}

impl<F: PrimeField> MachineAir<F> for DivRemChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"DivRem".to_string()
}
Expand Down Expand Up @@ -396,6 +398,10 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.divrem_events.is_empty()
}
}

impl<F> BaseAir<F> for DivRemChip {
Expand Down
6 changes: 6 additions & 0 deletions core/src/alu/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ impl LtCols<u32> {
}

impl<F: PrimeField> MachineAir<F> for LtChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"Lt".to_string()
}
Expand Down Expand Up @@ -169,6 +171,10 @@ impl<F: PrimeField> MachineAir<F> for LtChip {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.lt_events.is_empty()
}
}

impl<F> BaseAir<F> for LtChip {
Expand Down
7 changes: 7 additions & 0 deletions core/src/alu/mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use crate::alu::mul::utils::get_msb;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
use crate::disassembler::WORD_SIZE;
use crate::runtime::{ExecutionRecord, Opcode};
use crate::stark::MachineRecord;
use crate::utils::pad_to_power_of_two;

/// The number of main trace columns for `MulChip`.
Expand Down Expand Up @@ -110,6 +111,8 @@ pub struct MulCols<T> {
}

impl<F: PrimeField> MachineAir<F> for MulChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"Mul".to_string()
}
Expand Down Expand Up @@ -248,6 +251,10 @@ impl<F: PrimeField> MachineAir<F> for MulChip {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.mul_events.is_empty()
}
}

impl<F> BaseAir<F> for MulChip {
Expand Down
6 changes: 6 additions & 0 deletions core/src/alu/sll/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ pub struct ShiftLeftCols<T> {
}

impl<F: PrimeField> MachineAir<F> for ShiftLeft {
type Record = ExecutionRecord;

fn name(&self) -> String {
"ShiftLeft".to_string()
}
Expand Down Expand Up @@ -189,6 +191,10 @@ impl<F: PrimeField> MachineAir<F> for ShiftLeft {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.shift_left_events.is_empty()
}
}

impl<F> BaseAir<F> for ShiftLeft {
Expand Down
6 changes: 6 additions & 0 deletions core/src/alu/sr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ pub struct ShiftRightCols<T> {
}

impl<F: PrimeField> MachineAir<F> for ShiftRightChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"ShiftRight".to_string()
}
Expand Down Expand Up @@ -266,6 +268,10 @@ impl<F: PrimeField> MachineAir<F> for ShiftRightChip {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.shift_right_events.is_empty()
}
}

impl<F> BaseAir<F> for ShiftRightChip {
Expand Down
6 changes: 6 additions & 0 deletions core/src/alu/sub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub struct SubCols<T> {
}

impl<F: PrimeField> MachineAir<F> for SubChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"Sub".to_string()
}
Expand Down Expand Up @@ -99,6 +101,10 @@ impl<F: PrimeField> MachineAir<F> for SubChip {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.sub_events.is_empty()
}
}

impl<F> BaseAir<F> for SubChip {
Expand Down
6 changes: 6 additions & 0 deletions core/src/bytes/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use crate::{air::MachineAir, runtime::ExecutionRecord};
pub const NUM_ROWS: usize = 1 << 16;

impl<F: Field> MachineAir<F> for ByteChip<F> {
type Record = ExecutionRecord;

fn name(&self) -> String {
"Byte".to_string()
}
Expand All @@ -29,4 +31,8 @@ impl<F: Field> MachineAir<F> for ByteChip<F> {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.byte_lookups.is_empty()
}
}
6 changes: 6 additions & 0 deletions core/src/cpu/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ use std::borrow::BorrowMut;
use tracing::instrument;

impl<F: PrimeField> MachineAir<F> for CpuChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"CPU".to_string()
}
Expand Down Expand Up @@ -102,6 +104,10 @@ impl<F: PrimeField> MachineAir<F> for CpuChip {
output.add_field_events(&field_events);
});
}

fn included(&self, _: &Self::Record) -> bool {
true
}
}

impl CpuChip {
Expand Down
6 changes: 6 additions & 0 deletions core/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct FieldLtuCols<T> {
}

impl<F: PrimeField> MachineAir<F> for FieldLtuChip {
type Record = ExecutionRecord;

fn name(&self) -> String {
"FieldLTU".to_string()
}
Expand Down Expand Up @@ -91,6 +93,10 @@ impl<F: PrimeField> MachineAir<F> for FieldLtuChip {

trace
}

fn included(&self, shard: &Self::Record) -> bool {
!shard.field_events.is_empty()
}
}

pub const LTU_NB_BITS: usize = 29;
Expand Down
7 changes: 4 additions & 3 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ pub mod utils;

pub use io::*;

use crate::stark::RiscvAir;
use anyhow::Result;
use p3_commit::Pcs;
use p3_matrix::dense::RowMajorMatrix;
use runtime::{Program, Runtime};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use stark::StarkGenericConfig;
use stark::{OpeningProof, ProgramVerificationError, Proof, ShardMainData};
use stark::{RiscvStark, StarkGenericConfig};
use std::fs;
use utils::{prove_core, BabyBearBlake3, StarkUtils};

Expand Down Expand Up @@ -123,7 +124,7 @@ impl SP1Verifier {
) -> Result<(), ProgramVerificationError> {
let config = BabyBearBlake3::new();
let mut challenger = config.challenger();
let machine = RiscvStark::new(config);
let machine = RiscvAir::machine(config);
let (_, vk) = machine.setup(&Program::from(elf));
machine.verify(&vk, &proof.proof, &mut challenger)
}
Expand All @@ -145,7 +146,7 @@ impl SP1Verifier {
<SC as StarkGenericConfig>::Val: p3_field::PrimeField32,
{
let mut challenger = config.challenger();
let machine = RiscvStark::new(config);
let machine = RiscvAir::machine(config);

let (_, vk) = machine.setup(&Program::from(elf));
machine.verify(&vk, &proof.proof, &mut challenger)
Expand Down
22 changes: 12 additions & 10 deletions core/src/lookup/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use p3_field::{Field, PrimeField64};
use p3_matrix::Matrix;

use crate::air::MachineAir;
use crate::runtime::ExecutionRecord;
use crate::stark::{RiscvChip, StarkGenericConfig};
use crate::stark::{MachineChip, StarkGenericConfig};

use super::InteractionKind;

Expand Down Expand Up @@ -43,9 +42,9 @@ fn babybear_to_int(n: BabyBear) -> i32 {
}
}

pub fn debug_interactions<SC: StarkGenericConfig>(
chip: &RiscvChip<SC>,
record: &ExecutionRecord,
pub fn debug_interactions<SC: StarkGenericConfig, A: MachineAir<SC::Val>>(
chip: &MachineChip<SC, A>,
record: &A::Record,
interaction_kinds: Vec<InteractionKind>,
) -> (
BTreeMap<String, Vec<InteractionData<SC::Val>>>,
Expand All @@ -54,7 +53,7 @@ pub fn debug_interactions<SC: StarkGenericConfig>(
let mut key_to_vec_data = BTreeMap::new();
let mut key_to_count = BTreeMap::new();

let trace = chip.generate_trace(record, &mut ExecutionRecord::default());
let trace = chip.generate_trace(record, &mut A::Record::default());
let mut main = trace.clone();
let height = trace.clone().height();

Expand Down Expand Up @@ -109,15 +108,18 @@ pub fn debug_interactions<SC: StarkGenericConfig>(

/// Calculate the the number of times we send and receive each event of the given interaction type,
/// and print out the ones for which the set of sends and receives don't match.
pub fn debug_interactions_with_all_chips<SC: StarkGenericConfig<Val = BabyBear>>(
chips: &[RiscvChip<SC>],
segment: &ExecutionRecord,
pub fn debug_interactions_with_all_chips<
SC: StarkGenericConfig<Val = BabyBear>,
A: MachineAir<SC::Val>,
>(
chips: &[MachineChip<SC, A>],
segment: &A::Record,
interaction_kinds: Vec<InteractionKind>,
) -> bool {
let mut final_map = BTreeMap::new();

for chip in chips.iter() {
let (_, count) = debug_interactions::<SC>(chip, segment, interaction_kinds.clone());
let (_, count) = debug_interactions::<SC, A>(chip, segment, interaction_kinds.clone());

tracing::debug!("{} chip has {} distinct events", chip.name(), count.len());
for (key, value) in count.iter() {
Expand Down
Loading

0 comments on commit f2acc41

Please sign in to comment.