Skip to content

Commit

Permalink
feat: add shard to byte and program table (#463)
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Tian <[email protected]>
  • Loading branch information
puma314 and ctian1 authored Apr 3, 2024
1 parent 4497641 commit 08f5a8e
Show file tree
Hide file tree
Showing 53 changed files with 766 additions and 374 deletions.
12 changes: 3 additions & 9 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ nohash-hasher = "0.2.0"
num = { version = "0.4.1" }
p3-air = { workspace = true }
p3-baby-bear = { workspace = true }
p3-blake3 = { workspace = true }
p3-blake3 = { workspace = true, features = ["parallel"] }
p3-challenger = { workspace = true }
p3-commit = { workspace = true }
p3-dft = { workspace = true }
Expand All @@ -24,7 +24,7 @@ p3-goldilocks = { workspace = true }
p3-keccak = { workspace = true }
p3-keccak-air = { workspace = true }
p3-matrix = { workspace = true }
p3-maybe-rayon = { workspace = true }
p3-maybe-rayon = { workspace = true, features = ["parallel"] }
p3-mds = { workspace = true }
p3-merkle-tree = { workspace = true }
p3-poseidon2 = { workspace = true }
Expand Down Expand Up @@ -71,14 +71,8 @@ num = { version = "0.4.1", features = ["rand"] }
rand = "0.8.5"

[features]
debug = ["parallel"]
debug-proof = ["parallel", "perf"]
default = ["perf"]
keccak = []
debug = []
neon = ["p3-blake3/neon"]
parallel = ["p3-maybe-rayon/parallel", "p3-blake3/parallel"]
perf = ["parallel"]
serial = []

[[bench]]
harness = false
Expand Down
3 changes: 0 additions & 3 deletions core/benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ use sp1_core::utils::{run_and_prove, BabyBearPoseidon2};

#[allow(unreachable_code)]
pub fn criterion_benchmark(c: &mut Criterion) {
#[cfg(not(feature = "perf"))]
unreachable!("--features=perf must be enabled to run this benchmark");

let mut group = c.benchmark_group("prove");
group.sample_size(10);
let programs = ["fibonacci"];
Expand Down
90 changes: 72 additions & 18 deletions core/src/air/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,84 +92,108 @@ pub trait BaseAirBuilder: AirBuilder + MessageBuilder<AirInteraction<Self::Expr>
/// A trait which contains methods for byte interactions in an AIR.
pub trait ByteAirBuilder: BaseAirBuilder {
/// Sends a byte operation to be processed.
fn send_byte<EOp, Ea, Eb, Ec, EMult>(
fn send_byte<EOp, Ea, Eb, Ec, EShard, EMult>(
&mut self,
opcode: EOp,
a: Ea,
b: Eb,
c: Ec,
shard: EShard,
multiplicity: EMult,
) where
EOp: Into<Self::Expr>,
Ea: Into<Self::Expr>,
Eb: Into<Self::Expr>,
Ec: Into<Self::Expr>,
EShard: Into<Self::Expr>,
EMult: Into<Self::Expr>,
{
self.send_byte_pair(opcode, a, Self::Expr::zero(), b, c, multiplicity)
self.send_byte_pair(opcode, a, Self::Expr::zero(), b, c, shard, multiplicity)
}

/// Sends a byte operation with two outputs to be processed.
fn send_byte_pair<EOp, Ea1, Ea2, Eb, Ec, EMult>(
#[allow(clippy::too_many_arguments)]
fn send_byte_pair<EOp, Ea1, Ea2, Eb, Ec, EShard, EMult>(
&mut self,
opcode: EOp,
a1: Ea1,
a2: Ea2,
b: Eb,
c: Ec,
shard: EShard,
multiplicity: EMult,
) where
EOp: Into<Self::Expr>,
Ea1: Into<Self::Expr>,
Ea2: Into<Self::Expr>,
Eb: Into<Self::Expr>,
Ec: Into<Self::Expr>,
EShard: Into<Self::Expr>,
EMult: Into<Self::Expr>,
{
self.send(AirInteraction::new(
vec![opcode.into(), a1.into(), a2.into(), b.into(), c.into()],
vec![
opcode.into(),
a1.into(),
a2.into(),
b.into(),
c.into(),
shard.into(),
],
multiplicity.into(),
InteractionKind::Byte,
));
}

/// Receives a byte operation to be processed.
fn receive_byte<EOp, Ea, Eb, Ec, EMult>(
fn receive_byte<EOp, Ea, Eb, Ec, EMult, EShard>(
&mut self,
opcode: EOp,
a: Ea,
b: Eb,
c: Ec,
shard: EShard,
multiplicity: EMult,
) where
EOp: Into<Self::Expr>,
Ea: Into<Self::Expr>,
Eb: Into<Self::Expr>,
Ec: Into<Self::Expr>,
EShard: Into<Self::Expr>,
EMult: Into<Self::Expr>,
{
self.receive_byte_pair(opcode, a, Self::Expr::zero(), b, c, multiplicity)
self.receive_byte_pair(opcode, a, Self::Expr::zero(), b, c, shard, multiplicity)
}

/// Receives a byte operation with two outputs to be processed.
fn receive_byte_pair<EOp, Ea1, Ea2, Eb, Ec, EMult>(
#[allow(clippy::too_many_arguments)]
fn receive_byte_pair<EOp, Ea1, Ea2, Eb, Ec, EMult, EShard>(
&mut self,
opcode: EOp,
a1: Ea1,
a2: Ea2,
b: Eb,
c: Ec,
shard: EShard,
multiplicity: EMult,
) where
EOp: Into<Self::Expr>,
Ea1: Into<Self::Expr>,
Ea2: Into<Self::Expr>,
Eb: Into<Self::Expr>,
Ec: Into<Self::Expr>,
EShard: Into<Self::Expr>,
EMult: Into<Self::Expr>,
{
self.receive(AirInteraction::new(
vec![opcode.into(), a1.into(), a2.into(), b.into(), c.into()],
vec![
opcode.into(),
a1.into(),
a2.into(),
b.into(),
c.into(),
shard.into(),
],
multiplicity.into(),
InteractionKind::Byte,
));
Expand Down Expand Up @@ -217,9 +241,14 @@ pub trait WordAirBuilder: ByteAirBuilder {
}

/// Check that each limb of the given slice is a u8.
fn slice_range_check_u8<EWord: Into<Self::Expr> + Clone, EMult: Into<Self::Expr> + Clone>(
fn slice_range_check_u8<
EWord: Into<Self::Expr> + Clone,
EShard: Into<Self::Expr> + Clone,
EMult: Into<Self::Expr> + Clone,
>(
&mut self,
input: &[EWord],
shard: EShard,
mult: EMult,
) {
let mut index = 0;
Expand All @@ -229,6 +258,7 @@ pub trait WordAirBuilder: ByteAirBuilder {
Self::Expr::zero(),
input[index].clone(),
input[index + 1].clone(),
shard.clone(),
mult.clone(),
);
index += 2;
Expand All @@ -239,15 +269,21 @@ pub trait WordAirBuilder: ByteAirBuilder {
Self::Expr::zero(),
input[index].clone(),
Self::Expr::zero(),
shard.clone(),
mult.clone(),
);
}
}

/// Check that each limb of the given slice is a u16.
fn slice_range_check_u16<EWord: Into<Self::Expr> + Copy, EMult: Into<Self::Expr> + Clone>(
fn slice_range_check_u16<
EWord: Into<Self::Expr> + Copy,
EShard: Into<Self::Expr> + Clone,
EMult: Into<Self::Expr> + Clone,
>(
&mut self,
input: &[EWord],
shard: EShard,
mult: EMult,
) {
input.iter().for_each(|limb| {
Expand All @@ -256,6 +292,7 @@ pub trait WordAirBuilder: ByteAirBuilder {
*limb,
Self::Expr::zero(),
Self::Expr::zero(),
shard.clone(),
mult.clone(),
);
});
Expand All @@ -265,24 +302,27 @@ pub trait WordAirBuilder: ByteAirBuilder {
/// A trait which contains methods related to ALU interactions in an AIR.
pub trait AluAirBuilder: BaseAirBuilder {
/// Sends an ALU operation to be processed.
fn send_alu<EOp, Ea, Eb, Ec, EMult>(
fn send_alu<EOp, Ea, Eb, Ec, EShard, EMult>(
&mut self,
opcode: EOp,
a: Word<Ea>,
b: Word<Eb>,
c: Word<Ec>,
shard: EShard,
multiplicity: EMult,
) where
EOp: Into<Self::Expr>,
Ea: Into<Self::Expr>,
Eb: Into<Self::Expr>,
Ec: Into<Self::Expr>,
EShard: Into<Self::Expr>,
EMult: Into<Self::Expr>,
{
let values = once(opcode.into())
.chain(a.0.into_iter().map(Into::into))
.chain(b.0.into_iter().map(Into::into))
.chain(c.0.into_iter().map(Into::into))
.chain(once(shard.into()))
.collect();

self.send(AirInteraction::new(
Expand All @@ -293,24 +333,27 @@ pub trait AluAirBuilder: BaseAirBuilder {
}

/// Receives an ALU operation to be processed.
fn receive_alu<EOp, Ea, Eb, Ec, EMult>(
fn receive_alu<EOp, Ea, Eb, Ec, EShard, EMult>(
&mut self,
opcode: EOp,
a: Word<Ea>,
b: Word<Eb>,
c: Word<Ec>,
shard: EShard,
multiplicity: EMult,
) where
EOp: Into<Self::Expr>,
Ea: Into<Self::Expr>,
Eb: Into<Self::Expr>,
Ec: Into<Self::Expr>,
EShard: Into<Self::Expr>,
EMult: Into<Self::Expr>,
{
let values = once(opcode.into())
.chain(a.0.into_iter().map(Into::into))
.chain(b.0.into_iter().map(Into::into))
.chain(c.0.into_iter().map(Into::into))
.chain(once(shard.into()))
.collect();

self.receive(AirInteraction::new(
Expand Down Expand Up @@ -457,12 +500,12 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
) where
Eb: Into<Self::Expr> + Clone,
EVerify: Into<Self::Expr>,
EShard: Into<Self::Expr>,
EShard: Into<Self::Expr> + Clone,
EClk: Into<Self::Expr>,
{
let do_check: Self::Expr = do_check.into();
let compare_clk: Self::Expr = mem_access.compare_clk.clone().into();
let shard: Self::Expr = shard.into();
let shard: Self::Expr = shard.clone().into();
let prev_shard: Self::Expr = mem_access.prev_shard.clone().into();

// First verify that compare_clk's value is correct.
Expand All @@ -478,7 +521,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
mem_access.prev_shard.clone(),
);

let current_comp_val = self.if_else(compare_clk.clone(), clk.into(), shard);
let current_comp_val = self.if_else(compare_clk.clone(), clk.into(), shard.clone());

// Assert `current_comp_val > prev_comp_val`. We check this by asserting that
// `0 <= current_comp_val-prev_comp_val-1 < 2^24`.
Expand All @@ -495,6 +538,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
diff_minus_one,
mem_access.diff_16bit_limb.clone(),
mem_access.diff_8bit_limb.clone(),
shard.clone(),
do_check,
);
}
Expand All @@ -505,15 +549,17 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
/// check on it's limbs. It will also verify that the limbs are correct. This method is needed
/// since the memory access timestamp check (see [Self::verify_mem_access_ts]) needs to assume
/// the clk is within 24 bits.
fn verify_range_24bits<EValue, ELimb, EVerify>(
fn verify_range_24bits<EValue, ELimb, EShard, EVerify>(
&mut self,
value: EValue,
limb_16: ELimb,
limb_8: ELimb,
shard: EShard,
do_check: EVerify,
) where
EValue: Into<Self::Expr>,
ELimb: Into<Self::Expr> + Clone,
EShard: Into<Self::Expr> + Clone,
EVerify: Into<Self::Expr> + Clone,
{
// Verify that value = limb_16 + limb_8 * 2^16.
Expand All @@ -529,6 +575,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
limb_16,
Self::Expr::zero(),
Self::Expr::zero(),
shard.clone(),
do_check.clone(),
);

Expand All @@ -537,6 +584,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
Self::Expr::zero(),
Self::Expr::zero(),
limb_8,
shard.clone(),
do_check,
)
}
Expand Down Expand Up @@ -571,22 +619,25 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
/// A trait which contains methods related to program interactions in an AIR.
pub trait ProgramAirBuilder: BaseAirBuilder {
/// Sends an instruction.
fn send_program<EPc, EInst, ESel, EMult>(
fn send_program<EPc, EInst, ESel, EShard, EMult>(
&mut self,
pc: EPc,
instruction: InstructionCols<EInst>,
selectors: OpcodeSelectorCols<ESel>,
shard: EShard,
multiplicity: EMult,
) where
EPc: Into<Self::Expr>,
EInst: Into<Self::Expr> + Copy,
ESel: Into<Self::Expr> + Copy,
EShard: Into<Self::Expr> + Copy,
EMult: Into<Self::Expr>,
{
let values = once(pc.into())
.chain(once(instruction.opcode.into()))
.chain(instruction.into_iter().map(|x| x.into()))
.chain(selectors.into_iter().map(|x| x.into()))
.chain(once(shard.into()))
.collect();

self.send(AirInteraction::new(
Expand All @@ -597,22 +648,25 @@ pub trait ProgramAirBuilder: BaseAirBuilder {
}

/// Receives an instruction.
fn receive_program<EPc, EInst, ESel, EMult>(
fn receive_program<EPc, EInst, ESel, EShard, EMult>(
&mut self,
pc: EPc,
instruction: InstructionCols<EInst>,
selectors: OpcodeSelectorCols<ESel>,
shard: EShard,
multiplicity: EMult,
) where
EPc: Into<Self::Expr>,
EInst: Into<Self::Expr> + Copy,
ESel: Into<Self::Expr> + Copy,
EShard: Into<Self::Expr> + Copy,
EMult: Into<Self::Expr>,
{
let values: Vec<<Self as AirBuilder>::Expr> = once(pc.into())
.chain(once(instruction.opcode.into()))
.chain(instruction.into_iter().map(|x| x.into()))
.chain(selectors.into_iter().map(|x| x.into()))
.chain(once(shard.into()))
.collect();

self.receive(AirInteraction::new(
Expand Down
Loading

0 comments on commit 08f5a8e

Please sign in to comment.