Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: verify shard transitions + fixes #482

Merged
merged 17 commits into from
Apr 17, 2024
14 changes: 12 additions & 2 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use p3_air::BaseAir;
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;

use crate::stark::MachineRecord;
use crate::{runtime::Program, stark::MachineRecord};

pub use sp1_derive::MachineAir;

Expand All @@ -11,7 +11,7 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
/// The execution record containing events for producing the air trace.
type Record: MachineRecord;

type Program: Send + Sync;
type Program: MachineProgram<F>;

/// A unique identifier for this AIR as part of a machine.
fn name(&self) -> String;
Expand Down Expand Up @@ -41,3 +41,13 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
None
}
}

pub trait MachineProgram<F>: Send + Sync {
fn pc_start(&self) -> F;
}

impl<F: Field> MachineProgram<F> for Program {
fn pc_start(&self) -> F {
F::from_canonical_u32(self.pc_start)
}
}
15 changes: 2 additions & 13 deletions core/src/cpu/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,10 @@ impl CpuChip {
.when(is_ecall_instruction.clone() * is_enter_unconstrained)
.assert_word_eq(local.op_a_val(), zero_word);

// When the syscall is not one of ENTER_UNCONSTRAINED, HINT_LEN, or HALT, op_a shouldn't change.
// When the syscall is not one of ENTER_UNCONSTRAINED or HINT_LEN, op_a shouldn't change.
builder
.when(is_ecall_instruction.clone())
.when_not(is_enter_unconstrained + is_hint_len + is_halt)
.when_not(is_enter_unconstrained + is_hint_len)
.assert_word_eq(local.op_a_val(), local.op_a_access.prev_value);

(
Expand Down Expand Up @@ -574,17 +574,6 @@ impl CpuChip {
builder.index_word_array(&commit_digest, &ecall_columns.index_bitmap);

let digest_word = local.op_c_access.prev_value();
// Verify b and c do not change during commit syscall.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can we just delete this? Doesn't we need this?

Copy link
Member Author

@ctian1 ctian1 Apr 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's read cols so there's no point even checking it? Also we actually check it in other places in the same file as well lol

builder
.when(
local.selectors.is_ecall * (is_commit.clone() + is_commit_deferred_proofs.clone()),
)
.assert_word_eq(*local.op_b_access.value(), *local.op_b_access.prev_value());
builder
.when(
local.selectors.is_ecall * (is_commit.clone() + is_commit_deferred_proofs.clone()),
)
.assert_word_eq(*local.op_c_access.value(), *local.op_c_access.prev_value());

// Verify the public_values_digest_word.
builder
Expand Down
4 changes: 2 additions & 2 deletions core/src/cpu/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,6 @@ impl CpuChip {
let syscall_id = cols.op_a_access.prev_value[0];
// let send_to_table = cols.op_a_access.prev_value[1];
// let num_cycles = cols.op_a_access.prev_value[2];
// let is_halt = cols.op_a_access.prev_value[3];

// Populate `is_enter_unconstrained`.
ecall_cols
Expand Down Expand Up @@ -621,7 +620,7 @@ mod tests {
use super::*;

use crate::runtime::{tests::simple_program, Instruction, Runtime};
use crate::utils::run_test;
use crate::utils::{run_test, setup_logger};

#[test]
fn generate_trace() {
Expand Down Expand Up @@ -671,6 +670,7 @@ mod tests {

#[test]
fn prove_trace() {
setup_logger();
let program = simple_program();
run_test(program).unwrap();
}
Expand Down
6 changes: 6 additions & 0 deletions core/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ impl SP1PublicValues {
}
}

impl AsRef<[u8]> for SP1PublicValues {
fn as_ref(&self) -> &[u8] {
&self.buffer.data
}
}

pub mod proof_serde {
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};

Expand Down
85 changes: 59 additions & 26 deletions core/src/memory/program.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use p3_air::{Air, BaseAir, PairBuilder};
use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
use p3_field::AbstractField;
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use std::collections::BTreeMap;

use sp1_derive::AlignedBorrow;

use crate::air::{AirInteraction, SP1AirBuilder};
use crate::air::{AirInteraction, PublicValues, SP1AirBuilder};
use crate::air::{MachineAir, Word};
use crate::operations::IsZeroOperation;
use crate::runtime::{ExecutionRecord, Program};
use crate::utils::pad_to_power_of_two;

Expand All @@ -24,16 +24,22 @@ pub const NUM_MEMORY_PROGRAM_MULT_COLS: usize = size_of::<MemoryProgramMultCols<
pub struct MemoryProgramPreprocessedCols<T> {
pub addr: T,
pub value: Word<T>,
pub is_real: T,
}

/// The column layout for the chip.
/// Multiplicity columns.
#[derive(AlignedBorrow, Clone, Copy, Default)]
#[repr(C)]
pub struct MemoryProgramMultCols<T> {
pub used: T,
/// The multiplicity of the event, must be 1 in the first shard and 0 otherwise.
pub multiplicity: T,
/// Columns to see if current shard is 1.
pub is_first_shard: IsZeroOperation<T>,
}

/// Chip that initializes memory that is provided from the program.
/// Chip that initializes memory that is provided from the program. The table is preprocessed and
/// receives each row in the first shard. This prevents any of these addresses from being
/// overwritten through the normal MemoryInit.
#[derive(Default)]
pub struct MemoryProgramChip;

Expand All @@ -58,13 +64,16 @@ impl<F: PrimeField> MachineAir<F> for MemoryProgramChip {

fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
let program_memory = program.memory_image.clone();
// Note that BTreeMap is guaranteed to be sorted by key. This makes the row order
// deterministic.
let rows = program_memory
.into_iter()
.map(|(addr, word)| {
let mut row = [F::zero(); NUM_MEMORY_PROGRAM_PREPROCESSED_COLS];
let cols: &mut MemoryProgramPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
cols.addr = F::from_canonical_u32(addr);
cols.value = Word::from(word);
cols.is_real = F::one();

row
})
Expand All @@ -91,30 +100,28 @@ impl<F: PrimeField> MachineAir<F> for MemoryProgramChip {
input: &ExecutionRecord,
_output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
// Build a map of each address in program memory image to whether it was used.
// We have to do it from program because only the last shard has all the events, but every
// preprocessed row needs a corresponding mult row even if it's not used.
let mut addr_used_map = input
let program_memory_addrs = input
.program
.memory_image
.keys()
.map(|addr| (*addr, false))
.collect::<BTreeMap<_, _>>();
for event in &input.program_memory_events {
if event.used == 1 {
if let Some(used) = addr_used_map.get_mut(&event.addr) {
*used = true;
}
}
}
.copied()
.collect::<Vec<_>>();

let mult = if input.index == 1 {
F::one()
} else {
F::zero()
};

// Generate the trace rows for each event.
let rows = addr_used_map
.values()
.map(|used| {
let rows = program_memory_addrs
.into_iter()
.map(|_| {
let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS];
let cols: &mut MemoryProgramMultCols<F> = row.as_mut_slice().borrow_mut();
cols.used = F::from_bool(*used);
cols.multiplicity = mult;
IsZeroOperation::populate(&mut cols.is_first_shard, input.index - 1);

row
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -147,21 +154,47 @@ where
AB: SP1AirBuilder + PairBuilder,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let preprocessed = builder.preprocessed();
let main = builder.main();

let prep_local = preprocessed.row_slice(0);
let prep_local: &MemoryProgramPreprocessedCols<AB::Var> = (*prep_local).borrow();

let mult_local = main.row_slice(0);
let mult_local: &MemoryProgramMultCols<AB::Var> = (*mult_local).borrow();

builder.assert_bool(mult_local.used);
// Get shard from public values and evaluate whether it is the first shard.
let public_values = PublicValues::<Word<AB::Expr>, AB::Expr>::from_vec(
builder
.public_values()
.iter()
.map(|elm| (*elm).into())
.collect::<Vec<_>>(),
);
IsZeroOperation::<AB::F>::eval(
builder,
public_values.shard - AB::Expr::one(),
mult_local.is_first_shard,
prep_local.is_real.into(),
);
let is_first_shard = mult_local.is_first_shard.result;

// Multiplicity must be either 0 or 1.
builder.assert_bool(mult_local.multiplicity);
// If first shard and preprocessed is real, multiplicity must be one.
builder
.when(is_first_shard * prep_local.is_real)
.assert_one(mult_local.multiplicity);
// If not first shard or preprocessed is not real, multiplicity must be zero.
builder
.when((AB::Expr::one() - is_first_shard) + (AB::Expr::one() - prep_local.is_real))
.assert_zero(mult_local.multiplicity);

let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), prep_local.addr.into()];
values.extend(prep_local.value.map(Into::into));
builder.receive(AirInteraction::new(
values,
mult_local.used.into(),
mult_local.multiplicity.into(),
crate::lookup::InteractionKind::Memory,
));
}
Expand Down
31 changes: 0 additions & 31 deletions core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ use std::io::Write;
use std::rc::Rc;
use std::sync::Arc;

use nohash_hasher::BuildNoHashHasher;

use crate::memory::MemoryInitializeFinalizeEvent;
use crate::utils::env;
use crate::{alu::AluEvent, cpu::CpuEvent};
Expand Down Expand Up @@ -967,16 +965,6 @@ impl Runtime {
}

// SECTION: Set up all MemoryInitializeFinalizeEvents needed for memory argument.

// Program Memory is the global constants of the program. We need to mark which of these
// addresses are used by the program, as some invocations might not touch all addresses.
// program_memory_map maps an addr to its value and whether it was touched during the program.
let mut program_memory_map = HashMap::with_hasher(BuildNoHashHasher::<u32>::default());

for (key, value) in &self.program.memory_image {
program_memory_map.insert(key, (*value, true));
}

let memory_finalize_events = &mut self.record.memory_finalize_events;

// We handle the addr = 0 case separately, as we constrain it to be 0 in the first row
Expand All @@ -1002,30 +990,11 @@ impl Runtime {
}

let record = *self.state.memory.get(addr).unwrap();
if record.shard == 0 && record.timestamp == 0 {
// This means that we never accessed this memory location throughout our entire program.
// The only way this can happen is if this was in the program memory image.
// We mark this (addr, value) as not touched in the `program_memory_map` map.
program_memory_map.insert(addr, (record.value, false));
continue;
}

memory_finalize_events.push(MemoryInitializeFinalizeEvent::finalize_from_record(
*addr, &record,
));
}

let mut program_memory_events = program_memory_map
.into_iter()
.map(|(addr, (value, used))| {
MemoryInitializeFinalizeEvent::initialize(*addr, value, used)
})
.collect::<Vec<MemoryInitializeFinalizeEvent>>();
// Sort the program_memory_events by addr to create a canonical ordering for the
// preprocessed table, as this is part of the vkey.
program_memory_events.sort_by_key(|event| event.addr);

self.record.program_memory_events = program_memory_events;
}

fn get_syscall(&mut self, code: SyscallCode) -> Option<&Rc<dyn Syscall>> {
Expand Down
7 changes: 0 additions & 7 deletions core/src/runtime/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ pub struct ExecutionRecord {

pub memory_finalize_events: Vec<MemoryInitializeFinalizeEvent>,

pub program_memory_events: Vec<MemoryInitializeFinalizeEvent>,

/// The public values.
pub public_values: PublicValues<u32, u32>,
}
Expand Down Expand Up @@ -264,8 +262,6 @@ impl MachineRecord for ExecutionRecord {
.append(&mut other.memory_initialize_events);
self.memory_finalize_events
.append(&mut other.memory_finalize_events);
self.program_memory_events
.append(&mut other.program_memory_events);
}

fn shard(mut self, config: &ShardingConfig) -> Vec<Self> {
Expand Down Expand Up @@ -468,9 +464,6 @@ impl MachineRecord for ExecutionRecord {
last_shard
.memory_finalize_events
.extend_from_slice(&self.memory_finalize_events);
last_shard
.program_memory_events
.extend_from_slice(&self.program_memory_events);

shards
}
Expand Down
4 changes: 0 additions & 4 deletions core/src/runtime/syscall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,6 @@ impl SyscallCode {
pub fn num_cycles(&self) -> u32 {
(*self as u32).to_le_bytes()[2].into()
}

pub fn is_halt(&self) -> u32 {
(*self as u32).to_le_bytes()[3].into()
}
}

pub trait Syscall {
Expand Down
Loading
Loading