From 08f5a8e078d8ddb91a1bdd30ec1d8250cf5952d3 Mon Sep 17 00:00:00 2001 From: puma314 Date: Wed, 3 Apr 2024 14:41:46 -0700 Subject: [PATCH] feat: add shard to byte and program table (#463) Co-authored-by: Chris Tian --- core/Cargo.toml | 12 +-- core/benches/main.rs | 3 - core/src/air/builder.rs | 90 +++++++++++++++---- core/src/alu/add_sub/mod.rs | 33 +++++-- core/src/alu/bitwise/mod.rs | 17 ++-- core/src/alu/divrem/mod.rs | 41 ++++++--- core/src/alu/lt/mod.rs | 35 ++++---- core/src/alu/mod.rs | 6 +- core/src/alu/mul/mod.rs | 29 ++++-- core/src/alu/sll/mod.rs | 19 ++-- core/src/alu/sr/mod.rs | 24 +++-- core/src/bytes/air.rs | 19 ++-- core/src/bytes/columns.rs | 4 + core/src/bytes/event.rs | 6 +- core/src/bytes/mod.rs | 31 ++++--- core/src/bytes/trace.rs | 23 +++-- core/src/cpu/air/branch.rs | 3 + core/src/cpu/air/memory.rs | 1 + core/src/cpu/air/mod.rs | 21 ++++- core/src/cpu/trace.rs | 36 +++++++- core/src/lookup/debug.rs | 70 +++++++++++---- core/src/memory/global.rs | 34 ++++--- core/src/memory/trace.rs | 20 ++++- core/src/operations/add.rs | 21 +++-- core/src/operations/add4.rs | 23 ++--- core/src/operations/add5.rs | 21 +++-- core/src/operations/and.rs | 5 +- core/src/operations/fixed_rotate_right.rs | 11 ++- core/src/operations/fixed_shift_right.rs | 11 ++- core/src/operations/not.rs | 6 +- core/src/operations/or.rs | 6 +- core/src/operations/xor.rs | 5 +- core/src/program/mod.rs | 7 ++ core/src/runtime/mod.rs | 1 + core/src/runtime/record.rs | 64 ++++++++----- core/src/runtime/syscall.rs | 9 +- core/src/stark/air.rs | 4 +- core/src/stark/machine.rs | 39 ++++---- core/src/stark/prover.rs | 18 ---- core/src/stark/verifier.rs | 17 ---- .../precompiles/blake3/compress/air.rs | 8 +- .../precompiles/blake3/compress/columns.rs | 2 +- .../syscall/precompiles/blake3/compress/g.rs | 61 ++++++++----- .../precompiles/blake3/compress/mod.rs | 30 +++---- .../precompiles/blake3/compress/trace.rs | 6 +- core/src/syscall/precompiles/keccak256/air.rs | 2 +- .../precompiles/sha256/compress/air.rs | 31 ++++++- .../precompiles/sha256/compress/trace.rs | 72 ++++++++------- .../syscall/precompiles/sha256/extend/air.rs | 11 +++ .../syscall/precompiles/sha256/extend/mod.rs | 2 +- .../precompiles/sha256/extend/trace.rs | 54 +++++++---- core/src/utils/prove.rs | 14 +-- recursion/program/src/stark.rs | 2 +- 53 files changed, 766 insertions(+), 374 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index 8e45968ec0..f9a02d2c13 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -14,7 +14,7 @@ nohash-hasher = "0.2.0" num = { version = "0.4.1" } p3-air = { workspace = true } p3-baby-bear = { workspace = true } -p3-blake3 = { workspace = true } +p3-blake3 = { workspace = true, features = ["parallel"] } p3-challenger = { workspace = true } p3-commit = { workspace = true } p3-dft = { workspace = true } @@ -24,7 +24,7 @@ p3-goldilocks = { workspace = true } p3-keccak = { workspace = true } p3-keccak-air = { workspace = true } p3-matrix = { workspace = true } -p3-maybe-rayon = { workspace = true } +p3-maybe-rayon = { workspace = true, features = ["parallel"] } p3-mds = { workspace = true } p3-merkle-tree = { workspace = true } p3-poseidon2 = { workspace = true } @@ -71,14 +71,8 @@ num = { version = "0.4.1", features = ["rand"] } rand = "0.8.5" [features] -debug = ["parallel"] -debug-proof = ["parallel", "perf"] -default = ["perf"] -keccak = [] +debug = [] neon = ["p3-blake3/neon"] -parallel = ["p3-maybe-rayon/parallel", "p3-blake3/parallel"] -perf = ["parallel"] -serial = [] [[bench]] harness = false diff 
--git a/core/benches/main.rs b/core/benches/main.rs index dd16bb4739..71ac8fd4b8 100644 --- a/core/benches/main.rs +++ b/core/benches/main.rs @@ -4,9 +4,6 @@ use sp1_core::utils::{run_and_prove, BabyBearPoseidon2}; #[allow(unreachable_code)] pub fn criterion_benchmark(c: &mut Criterion) { - #[cfg(not(feature = "perf"))] - unreachable!("--features=perf must be enabled to run this benchmark"); - let mut group = c.benchmark_group("prove"); group.sample_size(10); let programs = ["fibonacci"]; diff --git a/core/src/air/builder.rs b/core/src/air/builder.rs index 680a5d7854..572ea64cf2 100644 --- a/core/src/air/builder.rs +++ b/core/src/air/builder.rs @@ -92,31 +92,35 @@ pub trait BaseAirBuilder: AirBuilder + MessageBuilder /// A trait which contains methods for byte interactions in an AIR. pub trait ByteAirBuilder: BaseAirBuilder { /// Sends a byte operation to be processed. - fn send_byte( + fn send_byte( &mut self, opcode: EOp, a: Ea, b: Eb, c: Ec, + shard: EShard, multiplicity: EMult, ) where EOp: Into, Ea: Into, Eb: Into, Ec: Into, + EShard: Into, EMult: Into, { - self.send_byte_pair(opcode, a, Self::Expr::zero(), b, c, multiplicity) + self.send_byte_pair(opcode, a, Self::Expr::zero(), b, c, shard, multiplicity) } /// Sends a byte operation with two outputs to be processed. - fn send_byte_pair( + #[allow(clippy::too_many_arguments)] + fn send_byte_pair( &mut self, opcode: EOp, a1: Ea1, a2: Ea2, b: Eb, c: Ec, + shard: EShard, multiplicity: EMult, ) where EOp: Into, @@ -124,41 +128,53 @@ pub trait ByteAirBuilder: BaseAirBuilder { Ea2: Into, Eb: Into, Ec: Into, + EShard: Into, EMult: Into, { self.send(AirInteraction::new( - vec![opcode.into(), a1.into(), a2.into(), b.into(), c.into()], + vec![ + opcode.into(), + a1.into(), + a2.into(), + b.into(), + c.into(), + shard.into(), + ], multiplicity.into(), InteractionKind::Byte, )); } /// Receives a byte operation to be processed. - fn receive_byte( + fn receive_byte( &mut self, opcode: EOp, a: Ea, b: Eb, c: Ec, + shard: EShard, multiplicity: EMult, ) where EOp: Into, Ea: Into, Eb: Into, Ec: Into, + EShard: Into, EMult: Into, { - self.receive_byte_pair(opcode, a, Self::Expr::zero(), b, c, multiplicity) + self.receive_byte_pair(opcode, a, Self::Expr::zero(), b, c, shard, multiplicity) } /// Receives a byte operation with two outputs to be processed. - fn receive_byte_pair( + #[allow(clippy::too_many_arguments)] + fn receive_byte_pair( &mut self, opcode: EOp, a1: Ea1, a2: Ea2, b: Eb, c: Ec, + shard: EShard, multiplicity: EMult, ) where EOp: Into, @@ -166,10 +182,18 @@ pub trait ByteAirBuilder: BaseAirBuilder { Ea2: Into, Eb: Into, Ec: Into, + EShard: Into, EMult: Into, { self.receive(AirInteraction::new( - vec![opcode.into(), a1.into(), a2.into(), b.into(), c.into()], + vec![ + opcode.into(), + a1.into(), + a2.into(), + b.into(), + c.into(), + shard.into(), + ], multiplicity.into(), InteractionKind::Byte, )); @@ -217,9 +241,14 @@ pub trait WordAirBuilder: ByteAirBuilder { } /// Check that each limb of the given slice is a u8. 
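A minimal sketch of how a downstream chip calls the widened `send_byte` under the new trait signature; `local.value`, `local.a`, `local.b`, `local.shard`, and `local.is_real` are hypothetical column names standing in for whatever the calling chip defines:

// Inside some chip's `eval`, forwarding the row's shard so the byte lookup
// multiplicity is attributed to that shard's byte table.
builder.send_byte(
    AB::F::from_canonical_u32(ByteOpcode::AND as u32),
    local.value[i],
    local.a[i],
    local.b[i],
    local.shard,   // argument added by this patch
    local.is_real,
);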
- fn slice_range_check_u8 + Clone, EMult: Into + Clone>( + fn slice_range_check_u8< + EWord: Into + Clone, + EShard: Into + Clone, + EMult: Into + Clone, + >( &mut self, input: &[EWord], + shard: EShard, mult: EMult, ) { let mut index = 0; @@ -229,6 +258,7 @@ pub trait WordAirBuilder: ByteAirBuilder { Self::Expr::zero(), input[index].clone(), input[index + 1].clone(), + shard.clone(), mult.clone(), ); index += 2; @@ -239,15 +269,21 @@ pub trait WordAirBuilder: ByteAirBuilder { Self::Expr::zero(), input[index].clone(), Self::Expr::zero(), + shard.clone(), mult.clone(), ); } } /// Check that each limb of the given slice is a u16. - fn slice_range_check_u16 + Copy, EMult: Into + Clone>( + fn slice_range_check_u16< + EWord: Into + Copy, + EShard: Into + Clone, + EMult: Into + Clone, + >( &mut self, input: &[EWord], + shard: EShard, mult: EMult, ) { input.iter().for_each(|limb| { @@ -256,6 +292,7 @@ pub trait WordAirBuilder: ByteAirBuilder { *limb, Self::Expr::zero(), Self::Expr::zero(), + shard.clone(), mult.clone(), ); }); @@ -265,24 +302,27 @@ pub trait WordAirBuilder: ByteAirBuilder { /// A trait which contains methods related to ALU interactions in an AIR. pub trait AluAirBuilder: BaseAirBuilder { /// Sends an ALU operation to be processed. - fn send_alu( + fn send_alu( &mut self, opcode: EOp, a: Word, b: Word, c: Word, + shard: EShard, multiplicity: EMult, ) where EOp: Into, Ea: Into, Eb: Into, Ec: Into, + EShard: Into, EMult: Into, { let values = once(opcode.into()) .chain(a.0.into_iter().map(Into::into)) .chain(b.0.into_iter().map(Into::into)) .chain(c.0.into_iter().map(Into::into)) + .chain(once(shard.into())) .collect(); self.send(AirInteraction::new( @@ -293,24 +333,27 @@ pub trait AluAirBuilder: BaseAirBuilder { } /// Receives an ALU operation to be processed. - fn receive_alu( + fn receive_alu( &mut self, opcode: EOp, a: Word, b: Word, c: Word, + shard: EShard, multiplicity: EMult, ) where EOp: Into, Ea: Into, Eb: Into, Ec: Into, + EShard: Into, EMult: Into, { let values = once(opcode.into()) .chain(a.0.into_iter().map(Into::into)) .chain(b.0.into_iter().map(Into::into)) .chain(c.0.into_iter().map(Into::into)) + .chain(once(shard.into())) .collect(); self.receive(AirInteraction::new( @@ -457,12 +500,12 @@ pub trait MemoryAirBuilder: BaseAirBuilder { ) where Eb: Into + Clone, EVerify: Into, - EShard: Into, + EShard: Into + Clone, EClk: Into, { let do_check: Self::Expr = do_check.into(); let compare_clk: Self::Expr = mem_access.compare_clk.clone().into(); - let shard: Self::Expr = shard.into(); + let shard: Self::Expr = shard.clone().into(); let prev_shard: Self::Expr = mem_access.prev_shard.clone().into(); // First verify that compare_clk's value is correct. @@ -478,7 +521,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { mem_access.prev_shard.clone(), ); - let current_comp_val = self.if_else(compare_clk.clone(), clk.into(), shard); + let current_comp_val = self.if_else(compare_clk.clone(), clk.into(), shard.clone()); // Assert `current_comp_val > prev_comp_val`. We check this by asserting that // `0 <= current_comp_val-prev_comp_val-1 < 2^24`. @@ -495,6 +538,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { diff_minus_one, mem_access.diff_16bit_limb.clone(), mem_access.diff_8bit_limb.clone(), + shard.clone(), do_check, ); } @@ -505,15 +549,17 @@ pub trait MemoryAirBuilder: BaseAirBuilder { /// check on it's limbs. It will also verify that the limbs are correct. 
This method is needed /// since the memory access timestamp check (see [Self::verify_mem_access_ts]) needs to assume /// the clk is within 24 bits. - fn verify_range_24bits( + fn verify_range_24bits( &mut self, value: EValue, limb_16: ELimb, limb_8: ELimb, + shard: EShard, do_check: EVerify, ) where EValue: Into, ELimb: Into + Clone, + EShard: Into + Clone, EVerify: Into + Clone, { // Verify that value = limb_16 + limb_8 * 2^16. @@ -529,6 +575,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { limb_16, Self::Expr::zero(), Self::Expr::zero(), + shard.clone(), do_check.clone(), ); @@ -537,6 +584,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { Self::Expr::zero(), Self::Expr::zero(), limb_8, + shard.clone(), do_check, ) } @@ -571,22 +619,25 @@ pub trait MemoryAirBuilder: BaseAirBuilder { /// A trait which contains methods related to program interactions in an AIR. pub trait ProgramAirBuilder: BaseAirBuilder { /// Sends an instruction. - fn send_program( + fn send_program( &mut self, pc: EPc, instruction: InstructionCols, selectors: OpcodeSelectorCols, + shard: EShard, multiplicity: EMult, ) where EPc: Into, EInst: Into + Copy, ESel: Into + Copy, + EShard: Into + Copy, EMult: Into, { let values = once(pc.into()) .chain(once(instruction.opcode.into())) .chain(instruction.into_iter().map(|x| x.into())) .chain(selectors.into_iter().map(|x| x.into())) + .chain(once(shard.into())) .collect(); self.send(AirInteraction::new( @@ -597,22 +648,25 @@ pub trait ProgramAirBuilder: BaseAirBuilder { } /// Receives an instruction. - fn receive_program( + fn receive_program( &mut self, pc: EPc, instruction: InstructionCols, selectors: OpcodeSelectorCols, + shard: EShard, multiplicity: EMult, ) where EPc: Into, EInst: Into + Copy, ESel: Into + Copy, + EShard: Into + Copy, EMult: Into, { let values: Vec<::Expr> = once(pc.into()) .chain(once(instruction.opcode.into())) .chain(instruction.into_iter().map(|x| x.into())) .chain(selectors.into_iter().map(|x| x.into())) + .chain(once(shard.into())) .collect(); self.receive(AirInteraction::new( diff --git a/core/src/alu/add_sub/mod.rs b/core/src/alu/add_sub/mod.rs index 313c65307d..7becf84f9c 100644 --- a/core/src/alu/add_sub/mod.rs +++ b/core/src/alu/add_sub/mod.rs @@ -32,6 +32,9 @@ pub struct AddSubChip; #[derive(AlignedBorrow, Default, Clone, Copy)] #[repr(C)] pub struct AddSubCols { + /// The shard number, used for byte lookup table. + pub shard: T, + /// Boolean to indicate whether the row is for an add operation. pub is_add: T, /// Boolean to indicate whether the row is for a sub operation. 
@@ -84,6 +87,7 @@ impl MachineAir for AddSubChip { let mut row = [F::zero(); NUM_ADD_SUB_COLS]; let cols: &mut AddSubCols = row.as_mut_slice().borrow_mut(); let is_add = event.opcode == Opcode::ADD; + cols.shard = F::from_canonical_u32(event.shard); cols.is_add = F::from_bool(is_add); cols.is_sub = F::from_bool(!is_add); @@ -91,7 +95,7 @@ impl MachineAir for AddSubChip { let operand_2 = event.c; cols.add_operation - .populate(&mut record, operand_1, operand_2); + .populate(&mut record, event.shard, operand_1, operand_2); cols.operand_1 = Word::from(operand_1); cols.operand_2 = Word::from(operand_2); row @@ -149,6 +153,7 @@ where local.operand_1, local.operand_2, local.add_operation, + local.shard, is_real, ); @@ -159,6 +164,7 @@ where local.add_operation.value, local.operand_1, local.operand_2, + local.shard, local.is_add, ); @@ -168,6 +174,7 @@ where local.operand_1, local.add_operation.value, local.operand_2, + local.shard, local.is_sub, ); @@ -201,7 +208,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.add_events = vec![AluEvent::new(0, Opcode::ADD, 14, 8, 6)]; + shard.add_events = vec![AluEvent::new(0, 0, Opcode::ADD, 14, 8, 6)]; let chip = AddSubChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -218,17 +225,27 @@ mod tests { let operand_1 = thread_rng().gen_range(0..u32::MAX); let operand_2 = thread_rng().gen_range(0..u32::MAX); let result = operand_1.wrapping_add(operand_2); - shard - .add_events - .push(AluEvent::new(0, Opcode::ADD, result, operand_1, operand_2)); + shard.add_events.push(AluEvent::new( + 0, + 0, + Opcode::ADD, + result, + operand_1, + operand_2, + )); } for _ in 0..1000 { let operand_1 = thread_rng().gen_range(0..u32::MAX); let operand_2 = thread_rng().gen_range(0..u32::MAX); let result = operand_1.wrapping_sub(operand_2); - shard - .add_events - .push(AluEvent::new(0, Opcode::SUB, result, operand_1, operand_2)); + shard.add_events.push(AluEvent::new( + 0, + 0, + Opcode::SUB, + result, + operand_1, + operand_2, + )); } let chip = AddSubChip::default(); diff --git a/core/src/alu/bitwise/mod.rs b/core/src/alu/bitwise/mod.rs index 7986e3b3e7..38c02a1d1c 100644 --- a/core/src/alu/bitwise/mod.rs +++ b/core/src/alu/bitwise/mod.rs @@ -22,7 +22,11 @@ pub struct BitwiseChip; /// The column layout for the chip. #[derive(AlignedBorrow, Default, Clone, Copy)] +#[repr(C)] pub struct BitwiseCols { + /// The shard number, used for byte lookup table. + pub shard: T, + /// The output operand. pub a: Word, @@ -68,6 +72,7 @@ impl MachineAir for BitwiseChip { let b = event.b.to_le_bytes(); let c = event.c.to_le_bytes(); + cols.shard = F::from_canonical_u32(event.shard); cols.a = Word::from(event.a); cols.b = Word::from(event.b); cols.c = Word::from(event.c); @@ -78,6 +83,7 @@ impl MachineAir for BitwiseChip { for ((b_a, b_b), b_c) in a.into_iter().zip(b).zip(c) { let byte_event = ByteLookupEvent { + shard: event.shard, opcode: ByteOpcode::from(event.opcode), a1: b_a as u32, a2: 0, @@ -130,7 +136,7 @@ where // Get a multiplicity of `1` only for a true row. let mult = local.is_xor + local.is_or + local.is_and; for ((a, b), c) in local.a.into_iter().zip(local.b).zip(local.c) { - builder.send_byte(opcode.clone(), a, b, c, mult.clone()); + builder.send_byte(opcode.clone(), a, b, c, local.shard, mult.clone()); } // Receive the arguments. 
@@ -141,6 +147,7 @@ where local.a, local.b, local.c, + local.shard, local.is_xor + local.is_or + local.is_and, ); @@ -168,7 +175,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.bitwise_events = vec![AluEvent::new(0, Opcode::XOR, 25, 10, 19)]; + shard.bitwise_events = vec![AluEvent::new(0, 0, Opcode::XOR, 25, 10, 19)]; let chip = BitwiseChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -182,9 +189,9 @@ mod tests { let mut shard = ExecutionRecord::default(); shard.bitwise_events = [ - AluEvent::new(0, Opcode::XOR, 25, 10, 19), - AluEvent::new(0, Opcode::OR, 27, 10, 19), - AluEvent::new(0, Opcode::AND, 2, 10, 19), + AluEvent::new(0, 0, Opcode::XOR, 25, 10, 19), + AluEvent::new(0, 0, Opcode::OR, 27, 10, 19), + AluEvent::new(0, 0, Opcode::AND, 2, 10, 19), ] .repeat(1000); let chip = BitwiseChip::default(); diff --git a/core/src/alu/divrem/mod.rs b/core/src/alu/divrem/mod.rs index 559ff89497..d478d6646f 100644 --- a/core/src/alu/divrem/mod.rs +++ b/core/src/alu/divrem/mod.rs @@ -100,6 +100,9 @@ pub struct DivRemChip; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct DivRemCols { + /// The shard number, used for byte lookup table. + pub shard: T, + /// The output operand. pub a: Word, @@ -214,6 +217,7 @@ impl MachineAir for DivRemChip { cols.a = Word::from(event.a); cols.b = Word::from(event.b); cols.c = Word::from(event.c); + cols.shard = F::from_canonical_u32(event.shard); cols.is_real = F::one(); cols.is_divu = F::from_bool(event.opcode == Opcode::DIVU); cols.is_remu = F::from_bool(event.opcode == Opcode::REMU); @@ -255,6 +259,7 @@ impl MachineAir for DivRemChip { for word in words.iter() { let most_significant_byte = word.to_le_bytes()[WORD_SIZE - 1]; blu_events.push(ByteLookupEvent { + shard: event.shard, opcode: ByteOpcode::MSB, a1: get_msb(*word) as u32, a2: 0, @@ -314,6 +319,7 @@ impl MachineAir for DivRemChip { } let lower_multiplication = AluEvent { + shard: event.shard, clk: event.clk, opcode: Opcode::MUL, a: lower_word, @@ -323,6 +329,7 @@ impl MachineAir for DivRemChip { output.add_mul_event(lower_multiplication); let upper_multiplication = AluEvent { + shard: event.shard, clk: event.clk, opcode: { if is_signed_operation(event.opcode) { @@ -340,6 +347,7 @@ impl MachineAir for DivRemChip { let lt_event = if is_signed_operation(event.opcode) { AluEvent { + shard: event.shard, opcode: Opcode::SLT, a: 1, b: (remainder as i32).abs() as u32, @@ -348,6 +356,7 @@ impl MachineAir for DivRemChip { } } else { AluEvent { + shard: event.shard, opcode: Opcode::SLTU, a: 1, b: remainder, @@ -360,9 +369,9 @@ impl MachineAir for DivRemChip { // Range check. 
{ - output.add_u8_range_checks("ient.to_le_bytes()); - output.add_u8_range_checks(&remainder.to_le_bytes()); - output.add_u8_range_checks(&c_times_quotient); + output.add_u8_range_checks(event.shard, "ient.to_le_bytes()); + output.add_u8_range_checks(event.shard, &remainder.to_le_bytes()); + output.add_u8_range_checks(event.shard, &c_times_quotient); } } @@ -455,6 +464,7 @@ where Word(lower_half), local.quotient, local.c, + local.shard, local.is_real, ); @@ -478,6 +488,7 @@ where Word(upper_half), local.quotient, local.c, + local.shard, local.is_real, ); } @@ -669,6 +680,7 @@ where Word([one.clone(), zero.clone(), zero.clone(), zero.clone()]), local.abs_remainder, local.max_abs_c_or_1, + local.shard, local.is_real, ); } @@ -684,20 +696,20 @@ where for msb_pair in msb_pairs.iter() { let msb = msb_pair.0; let byte = msb_pair.1; - builder.send_byte(opcode, msb, byte, zero.clone(), local.is_real); + builder.send_byte(opcode, msb, byte, zero.clone(), local.shard, local.is_real); } } // Range check all the bytes. { - builder.slice_range_check_u8(&local.quotient.0, local.is_real); - builder.slice_range_check_u8(&local.remainder.0, local.is_real); + builder.slice_range_check_u8(&local.quotient.0, local.shard, local.is_real); + builder.slice_range_check_u8(&local.remainder.0, local.shard, local.is_real); local.carry.iter().for_each(|carry| { builder.assert_bool(*carry); }); - builder.slice_range_check_u8(&local.c_times_quotient, local.is_real); + builder.slice_range_check_u8(&local.c_times_quotient, local.shard, local.is_real); } // Check that the flags are boolean. @@ -741,7 +753,14 @@ where + local.is_rem * rem }; - builder.receive_alu(opcode, local.a, local.b, local.c, local.is_real); + builder.receive_alu( + opcode, + local.a, + local.b, + local.c, + local.shard, + local.is_real, + ); } // A dummy constraint to keep the degree 3. @@ -773,7 +792,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.divrem_events = vec![AluEvent::new(0, Opcode::DIVU, 2, 17, 3)]; + shard.divrem_events = vec![AluEvent::new(0, 0, Opcode::DIVU, 2, 17, 3)]; let chip = DivRemChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -826,12 +845,12 @@ mod tests { (Opcode::REM, 0, 1 << 31, neg(1)), ]; for t in divrems.iter() { - divrem_events.push(AluEvent::new(0, t.0, t.1, t.2, t.3)); + divrem_events.push(AluEvent::new(0, 0, t.0, t.1, t.2, t.3)); } // Append more events until we have 1000 tests. for _ in 0..(1000 - divrems.len()) { - divrem_events.push(AluEvent::new(0, Opcode::DIVU, 1, 1, 1)); + divrem_events.push(AluEvent::new(0, 0, Opcode::DIVU, 1, 1, 1)); } let mut shard = ExecutionRecord::default(); diff --git a/core/src/alu/lt/mod.rs b/core/src/alu/lt/mod.rs index 6de0c30ee5..9d5e17ea01 100644 --- a/core/src/alu/lt/mod.rs +++ b/core/src/alu/lt/mod.rs @@ -26,6 +26,9 @@ pub struct LtChip; #[derive(AlignedBorrow, Default, Clone, Copy)] #[repr(C)] pub struct LtCols { + /// The shard number, used for byte lookup table. + pub shard: T, + /// The output operand. 
pub a: Word, @@ -101,6 +104,7 @@ impl MachineAir for LtChip { let b = event.b.to_le_bytes(); let c = event.c.to_le_bytes(); + cols.shard = F::from_canonical_u32(event.shard); cols.a = Word(a.map(F::from_canonical_u8)); cols.b = Word(b.map(F::from_canonical_u8)); cols.c = Word(c.map(F::from_canonical_u8)); @@ -295,6 +299,7 @@ where local.a, local.b, local.c, + local.shard, local.is_slt + local.is_sltu, ); } @@ -322,7 +327,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.lt_events = vec![AluEvent::new(0, Opcode::SLT, 0, 3, 2)]; + shard.lt_events = vec![AluEvent::new(0, 0, Opcode::SLT, 0, 3, 2)]; let chip = LtChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -350,21 +355,21 @@ mod tests { const NEG_4: u32 = 0b11111111111111111111111111111100; shard.lt_events = vec![ // 0 == 3 < 2 - AluEvent::new(0, Opcode::SLT, 0, 3, 2), + AluEvent::new(0, 0, Opcode::SLT, 0, 3, 2), // 1 == 2 < 3 - AluEvent::new(1, Opcode::SLT, 1, 2, 3), + AluEvent::new(0, 1, Opcode::SLT, 1, 2, 3), // 0 == 5 < -3 - AluEvent::new(3, Opcode::SLT, 0, 5, NEG_3), + AluEvent::new(0, 3, Opcode::SLT, 0, 5, NEG_3), // 1 == -3 < 5 - AluEvent::new(2, Opcode::SLT, 1, NEG_3, 5), + AluEvent::new(0, 2, Opcode::SLT, 1, NEG_3, 5), // 0 == -3 < -4 - AluEvent::new(4, Opcode::SLT, 0, NEG_3, NEG_4), + AluEvent::new(0, 4, Opcode::SLT, 0, NEG_3, NEG_4), // 1 == -4 < -3 - AluEvent::new(4, Opcode::SLT, 1, NEG_4, NEG_3), + AluEvent::new(0, 4, Opcode::SLT, 1, NEG_4, NEG_3), // 0 == 3 < 3 - AluEvent::new(5, Opcode::SLT, 0, 3, 3), + AluEvent::new(0, 5, Opcode::SLT, 0, 3, 3), // 0 == -3 < -3 - AluEvent::new(5, Opcode::SLT, 0, NEG_3, NEG_3), + AluEvent::new(0, 5, Opcode::SLT, 0, NEG_3, NEG_3), ]; prove_babybear_template(&mut shard); @@ -377,17 +382,17 @@ mod tests { const LARGE: u32 = 0b11111111111111111111111111111101; shard.lt_events = vec![ // 0 == 3 < 2 - AluEvent::new(0, Opcode::SLTU, 0, 3, 2), + AluEvent::new(0, 0, Opcode::SLTU, 0, 3, 2), // 1 == 2 < 3 - AluEvent::new(1, Opcode::SLTU, 1, 2, 3), + AluEvent::new(0, 1, Opcode::SLTU, 1, 2, 3), // 0 == LARGE < 5 - AluEvent::new(2, Opcode::SLTU, 0, LARGE, 5), + AluEvent::new(0, 2, Opcode::SLTU, 0, LARGE, 5), // 1 == 5 < LARGE - AluEvent::new(3, Opcode::SLTU, 1, 5, LARGE), + AluEvent::new(0, 3, Opcode::SLTU, 1, 5, LARGE), // 0 == 0 < 0 - AluEvent::new(5, Opcode::SLTU, 0, 0, 0), + AluEvent::new(0, 5, Opcode::SLTU, 0, 0, 0), // 0 == LARGE < LARGE - AluEvent::new(5, Opcode::SLTU, 0, LARGE, LARGE), + AluEvent::new(0, 5, Opcode::SLTU, 0, LARGE, LARGE), ]; prove_babybear_template(&mut shard); diff --git a/core/src/alu/mod.rs b/core/src/alu/mod.rs index 9dc0a7b01b..1f780988b3 100644 --- a/core/src/alu/mod.rs +++ b/core/src/alu/mod.rs @@ -21,6 +21,9 @@ use crate::runtime::Opcode; /// A standard format for describing ALU operations that need to be proven. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct AluEvent { + /// The shard number, used for byte lookup table. + pub shard: u32, + /// The clock cycle that the operation occurs on. pub clk: u32, @@ -39,8 +42,9 @@ pub struct AluEvent { impl AluEvent { /// Creates a new `AluEvent`. 
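With the new `shard` field, the constructor just below takes the shard as its leading argument; a call under the updated signature, using the same illustrative values as the patch's own tests, looks like:

// shard, clk, opcode, a, b, c
let event = AluEvent::new(0, 0, Opcode::ADD, 14, 8, 6);
assert_eq!(event.shard, 0);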
- pub fn new(clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) -> Self { + pub fn new(shard: u32, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) -> Self { Self { + shard, clk, opcode, a, diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index 1fbb1f6fa5..deff5ebd2d 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -72,6 +72,9 @@ pub struct MulChip; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct MulCols { + /// The shard number, used for byte lookup table. + pub shard: T, + /// The output operand. pub a: Word, @@ -187,6 +190,7 @@ impl MachineAir for MulChip { for word in words.iter() { let most_significant_byte = word[WORD_SIZE - 1]; blu_events.push(ByteLookupEvent { + shard: event.shard, opcode: ByteOpcode::MSB, a1: get_msb(*word) as u32, a2: 0, @@ -229,11 +233,12 @@ impl MachineAir for MulChip { cols.is_mulh = F::from_bool(event.opcode == Opcode::MULH); cols.is_mulhu = F::from_bool(event.opcode == Opcode::MULHU); cols.is_mulhsu = F::from_bool(event.opcode == Opcode::MULHSU); + cols.shard = F::from_canonical_u32(event.shard); // Range check. { - record.add_u16_range_checks(&carry); - record.add_u8_range_checks(&product.map(|x| x as u8)); + record.add_u16_range_checks(event.shard, &carry); + record.add_u8_range_checks(event.shard, &product.map(|x| x as u8)); } row }) @@ -294,7 +299,7 @@ where for msb_pair in msb_pairs.iter() { let msb = msb_pair.0; let byte = msb_pair.1; - builder.send_byte(opcode, msb, byte, zero.clone(), local.is_real); + builder.send_byte(opcode, msb, byte, zero.clone(), local.shard, local.is_real); } (local.b_msb, local.c_msb) }; @@ -420,13 +425,20 @@ where // Ensure that the carry is at most 2^16. This ensures that // product_before_carry_propagation - carry * base + last_carry never overflows or // underflows enough to "wrap" around to create a second solution. - builder.slice_range_check_u16(&local.carry, local.is_real); + builder.slice_range_check_u16(&local.carry, local.shard, local.is_real); - builder.slice_range_check_u8(&local.product, local.is_real); + builder.slice_range_check_u8(&local.product, local.shard, local.is_real); } // Receive the arguments. - builder.receive_alu(opcode, local.a, local.b, local.c, local.is_real); + builder.receive_alu( + opcode, + local.a, + local.b, + local.c, + local.shard, + local.is_real, + ); // A dummy constraint to keep the degree at least 3. builder.assert_zero( @@ -462,6 +474,7 @@ mod tests { let mut mul_events: Vec = Vec::new(); for _ in 0..10i32.pow(7) { mul_events.push(AluEvent::new( + 0, 0, Opcode::MULHSU, 0x80004000, @@ -536,12 +549,12 @@ mod tests { (Opcode::MULH, 0xffffffff, 0x00000001, 0xffffffff), ]; for t in mul_instructions.iter() { - mul_events.push(AluEvent::new(0, t.0, t.1, t.2, t.3)); + mul_events.push(AluEvent::new(0, 0, t.0, t.1, t.2, t.3)); } // Append more events until we have 1000 tests. for _ in 0..(1000 - mul_instructions.len()) { - mul_events.push(AluEvent::new(0, Opcode::MUL, 1, 1, 1)); + mul_events.push(AluEvent::new(0, 0, Opcode::MUL, 1, 1, 1)); } shard.mul_events = mul_events; diff --git a/core/src/alu/sll/mod.rs b/core/src/alu/sll/mod.rs index a8b1d83d00..da1601c7d3 100644 --- a/core/src/alu/sll/mod.rs +++ b/core/src/alu/sll/mod.rs @@ -60,6 +60,9 @@ pub struct ShiftLeft; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct ShiftLeftCols { + /// The shard number, used for byte lookup table. + pub shard: T, + /// The output operand. 
pub a: Word, @@ -114,6 +117,7 @@ impl MachineAir for ShiftLeft { let a = event.a.to_le_bytes(); let b = event.b.to_le_bytes(); let c = event.c.to_le_bytes(); + cols.shard = F::from_canonical_u32(event.shard); cols.a = Word(a.map(F::from_canonical_u8)); cols.b = Word(b.map(F::from_canonical_u8)); cols.c = Word(c.map(F::from_canonical_u8)); @@ -152,8 +156,8 @@ impl MachineAir for ShiftLeft { // Range checks. { - output.add_u8_range_checks(&bit_shift_result); - output.add_u8_range_checks(&bit_shift_result_carry); + output.add_u8_range_checks(event.shard, &bit_shift_result); + output.add_u8_range_checks(event.shard, &bit_shift_result_carry); } // Sanity check. @@ -308,8 +312,8 @@ where // Range check. { - builder.slice_range_check_u8(&local.bit_shift_result, local.is_real); - builder.slice_range_check_u8(&local.bit_shift_result_carry, local.is_real); + builder.slice_range_check_u8(&local.bit_shift_result, local.shard, local.is_real); + builder.slice_range_check_u8(&local.bit_shift_result_carry, local.shard, local.is_real); } for shift in local.shift_by_n_bytes.iter() { @@ -332,6 +336,7 @@ where local.a, local.b, local.c, + local.shard, local.is_real, ); @@ -364,7 +369,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.shift_left_events = vec![AluEvent::new(0, Opcode::SLL, 16, 8, 1)]; + shard.shift_left_events = vec![AluEvent::new(0, 0, Opcode::SLL, 16, 8, 1)]; let chip = ShiftLeft::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -399,12 +404,12 @@ mod tests { (Opcode::SLL, 0x00000000, 0x21212120, 0xffffffff), ]; for t in shift_instructions.iter() { - shift_events.push(AluEvent::new(0, t.0, t.1, t.2, t.3)); + shift_events.push(AluEvent::new(0, 0, t.0, t.1, t.2, t.3)); } // Append more events until we have 1000 tests. for _ in 0..(1000 - shift_instructions.len()) { - //shift_events.push(AluEvent::new(0, Opcode::SLL, 14, 8, 6)); + //shift_events.push(AluEvent::new(0, 0, Opcode::SLL, 14, 8, 6)); } let mut shard = ExecutionRecord::default(); diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index 5280b93208..a29e7a1c9e 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -79,6 +79,9 @@ pub struct ShiftRightChip; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct ShiftRightCols { + /// The shard number, used for byte lookup table. + pub shard: T, + /// The output operand. pub a: Word, @@ -146,6 +149,7 @@ impl MachineAir for ShiftRightChip { let cols: &mut ShiftRightCols = row.as_mut_slice().borrow_mut(); // Initialize cols with basic operands and flags derived from the current event. { + cols.shard = F::from_canonical_u32(event.shard); cols.a = Word::from(event.a); cols.b = Word::from(event.b); cols.c = Word::from(event.c); @@ -164,6 +168,7 @@ impl MachineAir for ShiftRightChip { // Insert the MSB lookup event. let most_significant_byte = event.b.to_le_bytes()[WORD_SIZE - 1]; output.add_byte_lookup_events(vec![ByteLookupEvent { + shard: event.shard, opcode: ByteOpcode::MSB, a1: ((most_significant_byte >> 7) & 1) as u32, a2: 0, @@ -212,6 +217,7 @@ impl MachineAir for ShiftRightChip { let (shift, carry) = shr_carry(byte_shift_result[i], num_bits_to_shift as u8); let byte_event = ByteLookupEvent { + shard: event.shard, opcode: ByteOpcode::ShrCarry, a1: shift as u32, a2: carry as u32, @@ -234,10 +240,10 @@ impl MachineAir for ShiftRightChip { debug_assert_eq!(cols.a[i], cols.bit_shift_result[i].clone()); } // Range checks. 
- output.add_u8_range_checks(&byte_shift_result); - output.add_u8_range_checks(&bit_shift_result); - output.add_u8_range_checks(&shr_carry_output_carry); - output.add_u8_range_checks(&shr_carry_output_shifted_byte); + output.add_u8_range_checks(event.shard, &byte_shift_result); + output.add_u8_range_checks(event.shard, &bit_shift_result); + output.add_u8_range_checks(event.shard, &shr_carry_output_carry); + output.add_u8_range_checks(event.shard, &shr_carry_output_shifted_byte); } rows.push(row); @@ -297,7 +303,7 @@ where let byte = local.b[WORD_SIZE - 1]; let opcode = AB::F::from_canonical_u32(ByteOpcode::MSB as u32); let msb = local.b_msb; - builder.send_byte(opcode, msb, byte, zero.clone(), local.is_real); + builder.send_byte(opcode, msb, byte, zero.clone(), local.shard, local.is_real); } // Calculate the number of bits and bytes to shift by from c. @@ -400,6 +406,7 @@ where local.shr_carry_output_carry[i], local.byte_shift_result[i], num_bits_to_shift.clone(), + local.shard, local.is_real, ); } @@ -446,7 +453,7 @@ where ]; for long_word in long_words.iter() { - builder.slice_range_check_u8(long_word, local.is_real); + builder.slice_range_check_u8(long_word, local.shard, local.is_real); } } @@ -461,6 +468,7 @@ where local.a, local.b, local.c, + local.shard, local.is_real, ); } @@ -487,7 +495,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.shift_right_events = vec![AluEvent::new(0, Opcode::SRL, 6, 12, 1)]; + shard.shift_right_events = vec![AluEvent::new(0, 0, Opcode::SRL, 6, 12, 1)]; let chip = ShiftRightChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -538,7 +546,7 @@ mod tests { ]; let mut shift_events: Vec = Vec::new(); for t in shifts.iter() { - shift_events.push(AluEvent::new(0, t.0, t.1, t.2, t.3)); + shift_events.push(AluEvent::new(0, 0, t.0, t.1, t.2, t.3)); } let mut shard = ExecutionRecord::default(); shard.shift_right_events = shift_events; diff --git a/core/src/bytes/air.rs b/core/src/bytes/air.rs index 4f7b3ca2b4..22c2f04d56 100644 --- a/core/src/bytes/air.rs +++ b/core/src/bytes/air.rs @@ -33,19 +33,22 @@ impl Air for ByteChip { for (i, opcode) in ByteOpcode::all().iter().enumerate() { let field_op = opcode.as_field::(); let mult = local_mult.multiplicities[i]; + let shard = local_mult.shard; match opcode { ByteOpcode::AND => { - builder.receive_byte(field_op, local.and, local.b, local.c, mult) + builder.receive_byte(field_op, local.and, local.b, local.c, shard, mult) + } + ByteOpcode::OR => { + builder.receive_byte(field_op, local.or, local.b, local.c, shard, mult) } - ByteOpcode::OR => builder.receive_byte(field_op, local.or, local.b, local.c, mult), ByteOpcode::XOR => { - builder.receive_byte(field_op, local.xor, local.b, local.c, mult) + builder.receive_byte(field_op, local.xor, local.b, local.c, shard, mult) } ByteOpcode::SLL => { - builder.receive_byte(field_op, local.sll, local.b, local.c, mult) + builder.receive_byte(field_op, local.sll, local.b, local.c, shard, mult) } ByteOpcode::U8Range => { - builder.receive_byte(field_op, AB::F::zero(), local.b, local.c, mult) + builder.receive_byte(field_op, AB::F::zero(), local.b, local.c, shard, mult) } ByteOpcode::ShrCarry => builder.receive_byte_pair( field_op, @@ -53,19 +56,21 @@ impl Air for ByteChip { local.shr_carry, local.b, local.c, + shard, mult, ), ByteOpcode::LTU => { - builder.receive_byte(field_op, local.ltu, local.b, local.c, mult) + builder.receive_byte(field_op, local.ltu, local.b, local.c, shard, 
mult) } ByteOpcode::MSB => { - builder.receive_byte(field_op, local.msb, local.b, AB::F::zero(), mult) + builder.receive_byte(field_op, local.msb, local.b, AB::F::zero(), shard, mult) } ByteOpcode::U16Range => builder.receive_byte( field_op, local.value_u16, AB::F::zero(), AB::F::zero(), + shard, mult, ), } diff --git a/core/src/bytes/columns.rs b/core/src/bytes/columns.rs index a0f66f4dda..7a7a188792 100644 --- a/core/src/bytes/columns.rs +++ b/core/src/bytes/columns.rs @@ -43,8 +43,12 @@ pub struct BytePreprocessedCols { pub value_u16: T, } +/// For each byte operation in the preprocessed table, a corresponding ByteMultCols row tracks the +/// number of times the operation is used. #[derive(Debug, Clone, Copy, AlignedBorrow)] #[repr(C)] pub struct ByteMultCols { pub multiplicities: [T; NUM_BYTE_OPS], + /// Shard number is tracked so that the multiplicities do not overflow. + pub shard: T, } diff --git a/core/src/bytes/event.rs b/core/src/bytes/event.rs index 466401b53b..3d56ac7226 100644 --- a/core/src/bytes/event.rs +++ b/core/src/bytes/event.rs @@ -4,6 +4,9 @@ use serde::{Deserialize, Serialize}; /// A byte lookup event. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct ByteLookupEvent { + /// The shard number, used for byte lookup table. + pub shard: u32, + /// The opcode of the operation. pub opcode: ByteOpcode, @@ -22,8 +25,9 @@ pub struct ByteLookupEvent { impl ByteLookupEvent { /// Creates a new `ByteLookupEvent`. - pub fn new(opcode: ByteOpcode, a1: u32, a2: u32, b: u32, c: u32) -> Self { + pub fn new(shard: u32, opcode: ByteOpcode, a1: u32, a2: u32, b: u32, c: u32) -> Self { Self { + shard, opcode, a1, a2, diff --git a/core/src/bytes/mod.rs b/core/src/bytes/mod.rs index 36160a136c..e7d4054c8f 100644 --- a/core/src/bytes/mod.rs +++ b/core/src/bytes/mod.rs @@ -36,7 +36,9 @@ impl ByteChip { /// - `trace` is a matrix containing all possible byte operations. /// - `map` is a map map from a byte lookup to the corresponding row it appears in the table and /// the index of the result in the array of multiplicities. - pub fn trace_and_map() -> (RowMajorMatrix, BTreeMap) { + pub fn trace_and_map( + shard: u32, + ) -> (RowMajorMatrix, BTreeMap) { // A map from a byte lookup to its corresponding row in the table and index in the array of // multiplicities. 
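// Illustrative note (hypothetical values): because `ByteLookupEvent` now carries a
// `shard`, two otherwise identical lookups from different shards are distinct map
// keys, e.g. `ByteLookupEvent::new(1, ByteOpcode::AND, 8, 0, 12, 10)` and
// `ByteLookupEvent::new(2, ByteOpcode::AND, 8, 0, 12, 10)` compare unequal, so
// multiplicities are tallied per shard rather than globally.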
let mut event_map = BTreeMap::new(); @@ -66,44 +68,53 @@ impl ByteChip { ByteOpcode::AND => { let and = b & c; col.and = F::from_canonical_u8(and); - ByteLookupEvent::new(*opcode, and as u32, 0, b as u32, c as u32) + ByteLookupEvent::new(shard, *opcode, and as u32, 0, b as u32, c as u32) } ByteOpcode::OR => { let or = b | c; col.or = F::from_canonical_u8(or); - ByteLookupEvent::new(*opcode, or as u32, 0, b as u32, c as u32) + ByteLookupEvent::new(shard, *opcode, or as u32, 0, b as u32, c as u32) } ByteOpcode::XOR => { let xor = b ^ c; col.xor = F::from_canonical_u8(xor); - ByteLookupEvent::new(*opcode, xor as u32, 0, b as u32, c as u32) + ByteLookupEvent::new(shard, *opcode, xor as u32, 0, b as u32, c as u32) } ByteOpcode::SLL => { let sll = b << (c & 7); col.sll = F::from_canonical_u8(sll); - ByteLookupEvent::new(*opcode, sll as u32, 0, b as u32, c as u32) + ByteLookupEvent::new(shard, *opcode, sll as u32, 0, b as u32, c as u32) + } + ByteOpcode::U8Range => { + ByteLookupEvent::new(shard, *opcode, 0, 0, b as u32, c as u32) } - ByteOpcode::U8Range => ByteLookupEvent::new(*opcode, 0, 0, b as u32, c as u32), ByteOpcode::ShrCarry => { let (res, carry) = shr_carry(b, c); col.shr = F::from_canonical_u8(res); col.shr_carry = F::from_canonical_u8(carry); - ByteLookupEvent::new(*opcode, res as u32, carry as u32, b as u32, c as u32) + ByteLookupEvent::new( + shard, + *opcode, + res as u32, + carry as u32, + b as u32, + c as u32, + ) } ByteOpcode::LTU => { let ltu = b < c; col.ltu = F::from_bool(ltu); - ByteLookupEvent::new(*opcode, ltu as u32, 0, b as u32, c as u32) + ByteLookupEvent::new(shard, *opcode, ltu as u32, 0, b as u32, c as u32) } ByteOpcode::MSB => { let msb = (b & 0b1000_0000) != 0; col.msb = F::from_bool(msb); - ByteLookupEvent::new(*opcode, msb as u32, 0, b as u32, 0 as u32) + ByteLookupEvent::new(shard, *opcode, msb as u32, 0, b as u32, 0 as u32) } ByteOpcode::U16Range => { let v = ((b as u32) << 8) + c as u32; col.value_u16 = F::from_canonical_u32(v); - ByteLookupEvent::new(*opcode, v, 0, 0, 0) + ByteLookupEvent::new(shard, *opcode, v, 0, 0, 0) } }; event_map.insert(event, (row_index, i)); diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index 92324d9a3a..22b8204208 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -1,8 +1,10 @@ +use std::borrow::BorrowMut; + use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use super::{ - columns::{NUM_BYTE_MULT_COLS, NUM_BYTE_PREPROCESSED_COLS}, + columns::{ByteMultCols, NUM_BYTE_MULT_COLS, NUM_BYTE_PREPROCESSED_COLS}, ByteChip, }; use crate::{ @@ -26,28 +28,39 @@ impl MachineAir for ByteChip { } fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option> { - let (trace, _) = Self::trace_and_map(); + // TODO: We should be able to make this a constant. Also, trace / map should be separate. + // Since we only need the trace and not the map, we can just pass 0 as the shard. + let (trace, _) = Self::trace_and_map(0); Some(trace) } + fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + // Do nothing since this chip has no dependencies. 
+ } + fn generate_trace( &self, input: &ExecutionRecord, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let (_, event_map) = Self::trace_and_map(); + let shard = input.index; + let (_, event_map) = Self::trace_and_map(shard); let mut trace = RowMajorMatrix::new( vec![F::zero(); NUM_BYTE_MULT_COLS * NUM_ROWS], NUM_BYTE_MULT_COLS, ); - for (lookup, mult) in input.byte_lookups.iter() { + for (lookup, mult) in input.byte_lookups[&shard].iter() { let (row, index) = event_map[lookup]; + let cols: &mut ByteMultCols = trace.row_mut(row).borrow_mut(); // Update the trace multiplicity - trace.row_mut(row)[index] += F::from_canonical_usize(*mult); + cols.multiplicities[index] += F::from_canonical_usize(*mult); + + // Set the shard column as the current shard. + cols.shard = F::from_canonical_u32(shard); } trace diff --git a/core/src/cpu/air/branch.rs b/core/src/cpu/air/branch.rs index e71658a69f..0611f73f91 100644 --- a/core/src/cpu/air/branch.rs +++ b/core/src/cpu/air/branch.rs @@ -57,6 +57,7 @@ impl CpuChip { branch_cols.next_pc, branch_cols.pc, local.op_c_val(), + local.shard, local.branching, ); @@ -138,6 +139,7 @@ impl CpuChip { Word::extend_var::(branch_cols.a_lt_b), local.op_a_val(), local.op_b_val(), + local.shard, is_branch_instruction.clone(), ); @@ -148,6 +150,7 @@ impl CpuChip { Word::extend_var::(branch_cols.a_gt_b), local.op_b_val(), local.op_a_val(), + local.shard, is_branch_instruction.clone(), ); } diff --git a/core/src/cpu/air/memory.rs b/core/src/cpu/air/memory.rs index 5175a1d4ec..85826633cd 100644 --- a/core/src/cpu/air/memory.rs +++ b/core/src/cpu/air/memory.rs @@ -83,6 +83,7 @@ impl CpuChip { local.op_a_val(), local.unsigned_mem_val, signed_value, + local.shard, local.mem_value_is_neg, ); diff --git a/core/src/cpu/air/mod.rs b/core/src/cpu/air/mod.rs index f0c613d96e..4e47333403 100644 --- a/core/src/cpu/air/mod.rs +++ b/core/src/cpu/air/mod.rs @@ -37,7 +37,13 @@ where let is_alu_instruction: AB::Expr = self.is_alu_instruction::(&local.selectors); // Program constraints. - builder.send_program(local.pc, local.instruction, local.selectors, local.is_real); + builder.send_program( + local.pc, + local.instruction, + local.selectors, + local.shard, + local.is_real, + ); // Load immediates into b and c, if the immediate flags are on. builder @@ -110,7 +116,11 @@ where ); // Check that each addr_word element is a byte. - builder.slice_range_check_u8(&memory_columns.addr_word.0, is_memory_instruction.clone()); + builder.slice_range_check_u8( + &memory_columns.addr_word.0, + local.shard, + is_memory_instruction.clone(), + ); // Send to the ALU table to verify correct calculation of addr_word. 
builder.send_alu( @@ -118,6 +128,7 @@ where memory_columns.addr_word, local.op_b_val(), local.op_c_val(), + local.shard, is_memory_instruction.clone(), ); @@ -142,6 +153,7 @@ where local.op_a_val(), local.op_b_val(), local.op_c_val(), + local.shard, is_alu_instruction, ); @@ -227,6 +239,7 @@ impl CpuChip { jump_columns.next_pc, jump_columns.pc, local.op_b_val(), + local.shard, local.selectors.is_jal, ); @@ -236,6 +249,7 @@ impl CpuChip { jump_columns.next_pc, local.op_b_val(), local.op_c_val(), + local.shard, local.selectors.is_jalr, ); } @@ -256,6 +270,7 @@ impl CpuChip { local.op_a_val(), auipc_columns.pc, local.op_b_val(), + local.shard, local.selectors.is_auipc, ); } @@ -377,6 +392,7 @@ impl CpuChip { local.shard, AB::Expr::zero(), AB::Expr::zero(), + local.shard, local.is_real, ); @@ -402,6 +418,7 @@ impl CpuChip { local.clk, local.clk_16bit_limb, local.clk_8bit_limb, + local.shard, local.is_real, ); } diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index ce571277e4..0cfc8a03df 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -185,15 +185,36 @@ impl CpuChip { new_blu_events: &mut Vec, ) { cols.shard = F::from_canonical_u32(event.shard); - new_blu_events.push(ByteLookupEvent::new(U16Range, event.shard, 0, 0, 0)); + new_blu_events.push(ByteLookupEvent::new( + event.shard, + U16Range, + event.shard, + 0, + 0, + 0, + )); cols.clk = F::from_canonical_u32(event.clk); let clk_16bit_limb = event.clk & 0xffff; cols.clk_16bit_limb = F::from_canonical_u32(clk_16bit_limb); let clk_8bit_limb = (event.clk >> 16) & 0xff; cols.clk_8bit_limb = F::from_canonical_u32(clk_8bit_limb); - new_blu_events.push(ByteLookupEvent::new(U16Range, clk_16bit_limb, 0, 0, 0)); - new_blu_events.push(ByteLookupEvent::new(U8Range, 0, 0, 0, clk_8bit_limb)); + new_blu_events.push(ByteLookupEvent::new( + event.shard, + U16Range, + clk_16bit_limb, + 0, + 0, + 0, + )); + new_blu_events.push(ByteLookupEvent::new( + event.shard, + U8Range, + 0, + 0, + 0, + clk_8bit_limb, + )); } /// Populates columns related to memory. 
@@ -227,6 +248,7 @@ impl CpuChip { // Add event to ALU check to check that addr == b + c let add_event = AluEvent { + shard: event.shard, clk: event.clk, opcode: Opcode::ADD, a: memory_addr, @@ -290,6 +312,7 @@ impl CpuChip { if memory_columns.most_sig_byte_decomp[7] == F::one() { cols.mem_value_is_neg = F::one(); let sub_event = AluEvent { + shard: event.shard, clk: event.clk, opcode: Opcode::SUB, a: event.a, @@ -309,6 +332,7 @@ impl CpuChip { let addr_bytes = memory_addr.to_le_bytes(); for byte_pair in addr_bytes.chunks_exact(2) { new_blu_events.push(ByteLookupEvent { + shard: event.shard, opcode: ByteOpcode::U8Range, a1: 0, a2: 0, @@ -351,6 +375,7 @@ impl CpuChip { }; // Add the ALU events for the comparisons let lt_comp_event = AluEvent { + shard: event.shard, clk: event.clk, opcode: alu_op_code, a: a_lt_b as u32, @@ -364,6 +389,7 @@ impl CpuChip { .or_insert(vec![lt_comp_event]); let gt_comp_event = AluEvent { + shard: event.shard, clk: event.clk, opcode: alu_op_code, a: a_gt_b as u32, @@ -396,6 +422,7 @@ impl CpuChip { branch_columns.next_pc = next_pc.into(); let add_event = AluEvent { + shard: event.shard, clk: event.clk, opcode: Opcode::ADD, a: next_pc, @@ -430,6 +457,7 @@ impl CpuChip { jump_columns.next_pc = next_pc.into(); let add_event = AluEvent { + shard: event.shard, clk: event.clk, opcode: Opcode::ADD, a: next_pc, @@ -447,6 +475,7 @@ impl CpuChip { jump_columns.next_pc = next_pc.into(); let add_event = AluEvent { + shard: event.shard, clk: event.clk, opcode: Opcode::ADD, a: next_pc, @@ -477,6 +506,7 @@ impl CpuChip { auipc_columns.pc = event.pc.into(); let add_event = AluEvent { + shard: event.shard, clk: event.clk, opcode: Opcode::ADD, a: event.a, diff --git a/core/src/lookup/debug.rs b/core/src/lookup/debug.rs index 6fa7fc2ba4..7961199edc 100644 --- a/core/src/lookup/debug.rs +++ b/core/src/lookup/debug.rs @@ -6,7 +6,7 @@ use p3_field::{Field, PrimeField64}; use p3_matrix::Matrix; use crate::air::MachineAir; -use crate::stark::{MachineChip, StarkGenericConfig, Val}; +use crate::stark::{MachineChip, MachineStark, ProvingKey, StarkGenericConfig, Val}; use super::InteractionKind; @@ -48,7 +48,7 @@ fn field_to_int(x: F) -> i32 { pub fn debug_interactions>>( chip: &MachineChip, - program: &A::Program, + pkey: &ProvingKey, record: &A::Record, interaction_kinds: Vec, ) -> ( @@ -59,7 +59,11 @@ pub fn debug_interactions>>( let mut key_to_count = BTreeMap::new(); let trace = chip.generate_trace(record, &mut A::Record::default()); - let mut preprocessed_trace = chip.generate_preprocessed_trace(program); + let mut pre_traces = pkey.traces.clone(); + let mut preprocessed_trace = pkey + .chip_ordering + .get(&chip.name()) + .map(|&index| pre_traces.get_mut(index).unwrap()); let mut main = trace.clone(); let height = trace.clone().height(); @@ -123,9 +127,9 @@ pub fn debug_interactions>>( /// Calculate the number of times we send and receive each event of the given interaction type, /// and print out the ones for which the set of sends and receives don't match. 
pub fn debug_interactions_with_all_chips( - chips: &[MachineChip], - program: &A::Program, - segment: &A::Record, + machine: &MachineStark, + pkey: &ProvingKey, + shards: &[A::Record], interaction_kinds: Vec, ) -> bool where @@ -134,22 +138,25 @@ where A: MachineAir, { let mut final_map = BTreeMap::new(); - let mut total = SC::Val::zero(); + let chips = machine.chips(); for chip in chips.iter() { - let (_, count) = - debug_interactions::(chip, program, segment, interaction_kinds.clone()); - - tracing::info!("{} chip has {} distinct events", chip.name(), count.len()); - for (key, value) in count.iter() { - let entry = final_map - .entry(key.clone()) - .or_insert((SC::Val::zero(), BTreeMap::new())); - entry.0 += *value; - total += *value; - *entry.1.entry(chip.name()).or_insert(SC::Val::zero()) += *value; + let mut total_events = 0; + for shard in shards { + let (_, count) = + debug_interactions::(chip, pkey, shard, interaction_kinds.clone()); + total_events += count.len(); + for (key, value) in count.iter() { + let entry = final_map + .entry(key.clone()) + .or_insert((SC::Val::zero(), BTreeMap::new())); + entry.0 += *value; + total += *value; + *entry.1.entry(chip.name()).or_insert(SC::Val::zero()) += *value; + } } + tracing::info!("{} chip has {} distinct events", chip.name(), total_events); } tracing::info!("Final counts below."); @@ -197,3 +204,30 @@ where !any_nonzero } + +#[cfg(test)] +mod test { + use crate::{ + lookup::InteractionKind, + runtime::{Program, Runtime, ShardingConfig}, + stark::RiscvAir, + utils::{setup_logger, tests::FIBONACCI_ELF, BabyBearPoseidon2}, + }; + + use super::debug_interactions_with_all_chips; + + #[test] + fn test_debug_interactions() { + setup_logger(); + let program = Program::from(FIBONACCI_ELF); + let config = BabyBearPoseidon2::new(); + let machine = RiscvAir::machine(config); + let (pk, _) = machine.setup(&program); + let mut runtime = Runtime::new(program); + runtime.run(); + let shards = machine.shard(runtime.record, &ShardingConfig::default()); + let ok = + debug_interactions_with_all_chips(&machine, &pk, &shards, InteractionKind::all_kinds()); + assert!(ok); + } +} diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index 4b02c8231c..38a4dc8cb2 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -172,17 +172,17 @@ where #[cfg(test)] mod tests { + use super::*; use crate::lookup::{debug_interactions_with_all_chips, InteractionKind}; + use crate::runtime::tests::simple_program; use crate::runtime::Runtime; + use crate::stark::MachineRecord; use crate::stark::{RiscvAir, StarkGenericConfig}; use crate::syscall::precompiles::sha256::extend_tests::sha_extend_program; + use crate::utils::{setup_logger, BabyBearPoseidon2}; use crate::utils::{uni_stark_prove as prove, uni_stark_verify as verify}; use p3_baby_bear::BabyBear; - use super::*; - use crate::runtime::tests::simple_program; - use crate::utils::{setup_logger, BabyBearPoseidon2}; - #[test] fn test_memory_generate_trace() { let program = simple_program(); @@ -232,13 +232,18 @@ mod tests { let program_clone = program.clone(); let mut runtime = Runtime::new(program); runtime.run(); - let machine: crate::stark::MachineStark> = RiscvAir::machine(BabyBearPoseidon2::new()); + let (pkey, _) = machine.setup(&program_clone); + let shards = machine.shard( + runtime.record, + &::Config::default(), + ); + assert_eq!(shards.len(), 1); debug_interactions_with_all_chips::>( - machine.chips(), - &program_clone, - &runtime.record, + &machine, + &pkey, + &shards, 
vec![InteractionKind::Memory], ); } @@ -250,12 +255,17 @@ mod tests { let program_clone = program.clone(); let mut runtime = Runtime::new(program); runtime.run(); - let machine = RiscvAir::machine(BabyBearPoseidon2::new()); + let (pkey, _) = machine.setup(&program_clone); + let shards = machine.shard( + runtime.record, + &::Config::default(), + ); + assert_eq!(shards.len(), 1); debug_interactions_with_all_chips::>( - machine.chips(), - &program_clone, - &runtime.record, + &machine, + &pkey, + &shards, vec![InteractionKind::Byte], ); } diff --git a/core/src/memory/trace.rs b/core/src/memory/trace.rs index 44a69ffd37..e55efa2db6 100644 --- a/core/src/memory/trace.rs +++ b/core/src/memory/trace.rs @@ -137,10 +137,26 @@ impl MemoryAccessCols { let diff_8bit_limb = (diff_minus_one >> 16) & 0xff; self.diff_8bit_limb = F::from_canonical_u32(diff_8bit_limb); + let shard = current_record.shard; + // Add a byte table lookup with the 16Range op. - new_blu_events.push(ByteLookupEvent::new(U16Range, diff_16bit_limb, 0, 0, 0)); + new_blu_events.push(ByteLookupEvent::new( + shard, + U16Range, + diff_16bit_limb, + 0, + 0, + 0, + )); // Add a byte table lookup with the U8Range op. - new_blu_events.push(ByteLookupEvent::new(U8Range, 0, 0, 0, diff_8bit_limb)); + new_blu_events.push(ByteLookupEvent::new( + shard, + U8Range, + 0, + 0, + 0, + diff_8bit_limb, + )); } } diff --git a/core/src/operations/add.rs b/core/src/operations/add.rs index 801ae8d149..3a506c0414 100644 --- a/core/src/operations/add.rs +++ b/core/src/operations/add.rs @@ -20,7 +20,13 @@ pub struct AddOperation { } impl AddOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, a_u32: u32, b_u32: u32) -> u32 { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + a_u32: u32, + b_u32: u32, + ) -> u32 { let expected = a_u32.wrapping_add(b_u32); self.value = Word::from(expected); let a = a_u32.to_le_bytes(); @@ -48,9 +54,9 @@ impl AddOperation { // Range check { - record.add_u8_range_checks(&a); - record.add_u8_range_checks(&b); - record.add_u8_range_checks(&expected.to_le_bytes()); + record.add_u8_range_checks(shard, &a); + record.add_u8_range_checks(shard, &b); + record.add_u8_range_checks(shard, &expected.to_le_bytes()); } expected } @@ -60,6 +66,7 @@ impl AddOperation { a: Word, b: Word, cols: AddOperation, + shard: AB::Var, is_real: AB::Expr, ) { let one = AB::Expr::one(); @@ -96,9 +103,9 @@ impl AddOperation { // Range check each byte. { - builder.slice_range_check_u8(&a.0, is_real.clone()); - builder.slice_range_check_u8(&b.0, is_real.clone()); - builder.slice_range_check_u8(&cols.value.0, is_real); + builder.slice_range_check_u8(&a.0, shard, is_real.clone()); + builder.slice_range_check_u8(&b.0, shard, is_real.clone()); + builder.slice_range_check_u8(&cols.value.0, shard, is_real); } // Degree 3 constraint to avoid "OodEvaluationMismatch". diff --git a/core/src/operations/add4.rs b/core/src/operations/add4.rs index 3e3ed3ca37..96e8723262 100644 --- a/core/src/operations/add4.rs +++ b/core/src/operations/add4.rs @@ -35,6 +35,7 @@ impl Add4Operation { pub fn populate( &mut self, record: &mut ExecutionRecord, + shard: u32, a_u32: u32, b_u32: u32, c_u32: u32, @@ -69,31 +70,33 @@ impl Add4Operation { // Range check. 
{ - record.add_u8_range_checks(&a); - record.add_u8_range_checks(&b); - record.add_u8_range_checks(&c); - record.add_u8_range_checks(&d); - record.add_u8_range_checks(&expected.to_le_bytes()); + record.add_u8_range_checks(shard, &a); + record.add_u8_range_checks(shard, &b); + record.add_u8_range_checks(shard, &c); + record.add_u8_range_checks(shard, &d); + record.add_u8_range_checks(shard, &expected.to_le_bytes()); } expected } + #[allow(clippy::too_many_arguments)] pub fn eval( builder: &mut AB, a: Word, b: Word, c: Word, d: Word, + shard: AB::Var, is_real: AB::Var, cols: Add4Operation, ) { // Range check each byte. { - builder.slice_range_check_u8(&a.0, is_real); - builder.slice_range_check_u8(&b.0, is_real); - builder.slice_range_check_u8(&c.0, is_real); - builder.slice_range_check_u8(&d.0, is_real); - builder.slice_range_check_u8(&cols.value.0, is_real); + builder.slice_range_check_u8(&a.0, shard, is_real); + builder.slice_range_check_u8(&b.0, shard, is_real); + builder.slice_range_check_u8(&c.0, shard, is_real); + builder.slice_range_check_u8(&d.0, shard, is_real); + builder.slice_range_check_u8(&cols.value.0, shard, is_real); } builder.assert_bool(is_real); diff --git a/core/src/operations/add5.rs b/core/src/operations/add5.rs index 98195a5a8c..5be3081c30 100644 --- a/core/src/operations/add5.rs +++ b/core/src/operations/add5.rs @@ -37,9 +37,11 @@ pub struct Add5Operation { } impl Add5Operation { + #[allow(clippy::too_many_arguments)] pub fn populate( &mut self, - shard: &mut ExecutionRecord, + record: &mut ExecutionRecord, + shard: u32, a_u32: u32, b_u32: u32, c_u32: u32, @@ -80,12 +82,12 @@ impl Add5Operation { // Range check. { - shard.add_u8_range_checks(&a); - shard.add_u8_range_checks(&b); - shard.add_u8_range_checks(&c); - shard.add_u8_range_checks(&d); - shard.add_u8_range_checks(&e); - shard.add_u8_range_checks(&expected.to_le_bytes()); + record.add_u8_range_checks(shard, &a); + record.add_u8_range_checks(shard, &b); + record.add_u8_range_checks(shard, &c); + record.add_u8_range_checks(shard, &d); + record.add_u8_range_checks(shard, &e); + record.add_u8_range_checks(shard, &expected.to_le_bytes()); } expected @@ -94,6 +96,7 @@ impl Add5Operation { pub fn eval( builder: &mut AB, words: &[Word; 5], + shard: AB::Var, is_real: AB::Var, cols: Add5Operation, ) { @@ -102,8 +105,8 @@ impl Add5Operation { { words .iter() - .for_each(|word| builder.slice_range_check_u8(&word.0, is_real)); - builder.slice_range_check_u8(&cols.value.0, is_real); + .for_each(|word| builder.slice_range_check_u8(&word.0, shard, is_real)); + builder.slice_range_check_u8(&cols.value.0, shard, is_real); } let mut builder_is_real = builder.when(is_real); diff --git a/core/src/operations/and.rs b/core/src/operations/and.rs index 97a3e6fda7..aba5d0078d 100644 --- a/core/src/operations/and.rs +++ b/core/src/operations/and.rs @@ -18,7 +18,7 @@ pub struct AndOperation { } impl AndOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, x: u32, y: u32) -> u32 { + pub fn populate(&mut self, record: &mut ExecutionRecord, shard: u32, x: u32, y: u32) -> u32 { let expected = x & y; let x_bytes = x.to_le_bytes(); let y_bytes = y.to_le_bytes(); @@ -27,6 +27,7 @@ impl AndOperation { self.value[i] = F::from_canonical_u8(and); let byte_event = ByteLookupEvent { + shard, opcode: ByteOpcode::AND, a1: and as u32, a2: 0, @@ -44,6 +45,7 @@ impl AndOperation { a: Word, b: Word, cols: AndOperation, + shard: AB::Var, is_real: AB::Var, ) { for i in 0..WORD_SIZE { @@ -52,6 +54,7 @@ impl AndOperation { cols.value[i], a[i], 
b[i], + shard, is_real, ); } diff --git a/core/src/operations/fixed_rotate_right.rs b/core/src/operations/fixed_rotate_right.rs index 40af827980..1ba61941d9 100644 --- a/core/src/operations/fixed_rotate_right.rs +++ b/core/src/operations/fixed_rotate_right.rs @@ -40,7 +40,13 @@ impl FixedRotateRightOperation { 1 << (8 - nb_bits_to_shift) } - pub fn populate(&mut self, record: &mut ExecutionRecord, input: u32, rotation: usize) -> u32 { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + input: u32, + rotation: usize, + ) -> u32 { let input_bytes = input.to_le_bytes().map(F::from_canonical_u8); let expected = input.rotate_right(rotation as u32); @@ -68,6 +74,7 @@ impl FixedRotateRightOperation { let (shift, carry) = shr_carry(b, c); let byte_event = ByteLookupEvent { + shard, opcode: ByteOpcode::ShrCarry, a1: shift as u32, a2: carry as u32, @@ -102,6 +109,7 @@ impl FixedRotateRightOperation { input: Word, rotation: usize, cols: FixedRotateRightOperation, + shard: AB::Var, is_real: AB::Var, ) { // Compute some constants with respect to the rotation needed for the rotation. @@ -128,6 +136,7 @@ impl FixedRotateRightOperation { cols.carry[i], input_bytes_rotated[i], AB::F::from_canonical_usize(nb_bits_to_shift), + shard, is_real, ); diff --git a/core/src/operations/fixed_shift_right.rs b/core/src/operations/fixed_shift_right.rs index 766cc1cddf..4944ab94ba 100644 --- a/core/src/operations/fixed_shift_right.rs +++ b/core/src/operations/fixed_shift_right.rs @@ -40,7 +40,13 @@ impl FixedShiftRightOperation { 1 << (8 - nb_bits_to_shift) } - pub fn populate(&mut self, record: &mut ExecutionRecord, input: u32, rotation: usize) -> u32 { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + input: u32, + rotation: usize, + ) -> u32 { let input_bytes = input.to_le_bytes().map(F::from_canonical_u8); let expected = input >> rotation; @@ -67,6 +73,7 @@ impl FixedShiftRightOperation { let c = nb_bits_to_shift as u8; let (shift, carry) = shr_carry(b, c); let byte_event = ByteLookupEvent { + shard, opcode: ByteOpcode::ShrCarry, a1: shift as u32, a2: carry as u32, @@ -101,6 +108,7 @@ impl FixedShiftRightOperation { input: Word, rotation: usize, cols: FixedShiftRightOperation, + shard: AB::Var, is_real: AB::Var, ) { // Compute some constants with respect to the rotation needed for the rotation. 
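As a reading aid for the ShrCarry lookups being populated in these fixed rotate/shift operations: each fixed rotation is decomposed into per-byte shifts whose spilled bits carry into the neighbouring limb, and each (shift, carry) pair is exactly what a ShrCarry ByteLookupEvent records (now tagged with the shard). The sketch below is a standalone illustration of that decomposition; the helper names are invented for this note, and the real `shr_carry` in the bytes module may use a different carry convention.

// Standalone sketch (hypothetical helpers, not the crate's `shr_carry`) of the
// shift-with-carry decomposition that the ShrCarry byte lookups certify.
fn shr_carry_sketch(b: u8, c: u32) -> (u8, u8) {
    debug_assert!((1..=7).contains(&c));
    let shift = b >> c; // bits that stay in this limb
    let carry = b << (8 - c); // bits shifted out, parked in the high positions
    (shift, carry)
}

// Reassemble a 32-bit rotate from whole-byte moves plus per-limb (shift, carry)
// pairs, roughly the way the fixed-rotate columns are combined.
fn rotate_right_bytewise(x: u32, r: u32) -> u32 {
    let byte_rot = (r / 8) % 4;
    let bit_rot = r % 8;
    let x = x.rotate_right(byte_rot * 8);
    if bit_rot == 0 {
        return x;
    }
    let le = x.to_le_bytes();
    let mut out = [0u8; 4];
    for i in 0..4 {
        let (shift, _) = shr_carry_sketch(le[i], bit_rot);
        let (_, carry) = shr_carry_sketch(le[(i + 1) % 4], bit_rot);
        out[i] = shift | carry;
    }
    u32::from_le_bytes(out)
}

For example, `rotate_right_bytewise(x, 12)` equals `x.rotate_right(12)`; that scalar identity is what the per-limb shift/carry constraints encode.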
@@ -128,6 +136,7 @@ impl FixedShiftRightOperation { cols.carry[i], input_bytes_rotated[i].clone(), AB::F::from_canonical_usize(nb_bits_to_shift), + shard, is_real, ); diff --git a/core/src/operations/not.rs b/core/src/operations/not.rs index 198ea2afd1..d7aa8859e1 100644 --- a/core/src/operations/not.rs +++ b/core/src/operations/not.rs @@ -18,13 +18,13 @@ pub struct NotOperation { } impl NotOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, x: u32) -> u32 { + pub fn populate(&mut self, record: &mut ExecutionRecord, shard: u32, x: u32) -> u32 { let expected = !x; let x_bytes = x.to_le_bytes(); for i in 0..WORD_SIZE { self.value[i] = F::from_canonical_u8(!x_bytes[i]); } - record.add_u8_range_checks(&x_bytes); + record.add_u8_range_checks(shard, &x_bytes); expected } @@ -33,6 +33,7 @@ impl NotOperation { builder: &mut AB, a: Word, cols: NotOperation, + shard: AB::Var, is_real: AB::Var, ) { for i in (0..WORD_SIZE).step_by(2) { @@ -42,6 +43,7 @@ impl NotOperation { AB::F::zero(), a[i], a[i + 1], + shard, is_real, ); } diff --git a/core/src/operations/or.rs b/core/src/operations/or.rs index 713b1894c4..ea161cdee0 100644 --- a/core/src/operations/or.rs +++ b/core/src/operations/or.rs @@ -19,13 +19,13 @@ pub struct OrOperation { } impl OrOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, x: u32, y: u32) -> u32 { + pub fn populate(&mut self, record: &mut ExecutionRecord, shard: u32, x: u32, y: u32) -> u32 { let expected = x | y; let x_bytes = x.to_le_bytes(); let y_bytes = y.to_le_bytes(); for i in 0..WORD_SIZE { self.value[i] = F::from_canonical_u8(x_bytes[i] | y_bytes[i]); - record.lookup_or(x_bytes[i], y_bytes[i]); + record.lookup_or(shard, x_bytes[i], y_bytes[i]); } expected } @@ -35,6 +35,7 @@ impl OrOperation { a: Word, b: Word, cols: OrOperation, + shard: AB::Var, is_real: AB::Var, ) { for i in 0..WORD_SIZE { @@ -43,6 +44,7 @@ impl OrOperation { cols.value[i], a[i], b[i], + shard, is_real, ); } diff --git a/core/src/operations/xor.rs b/core/src/operations/xor.rs index 9d26e08aa1..1620bbe00e 100644 --- a/core/src/operations/xor.rs +++ b/core/src/operations/xor.rs @@ -18,7 +18,7 @@ pub struct XorOperation { } impl XorOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, x: u32, y: u32) -> u32 { + pub fn populate(&mut self, record: &mut ExecutionRecord, shard: u32, x: u32, y: u32) -> u32 { let expected = x ^ y; let x_bytes = x.to_le_bytes(); let y_bytes = y.to_le_bytes(); @@ -27,6 +27,7 @@ impl XorOperation { self.value[i] = F::from_canonical_u8(xor); let byte_event = ByteLookupEvent { + shard, opcode: ByteOpcode::XOR, a1: xor as u32, a2: 0, @@ -44,6 +45,7 @@ impl XorOperation { a: Word, b: Word, cols: XorOperation, + shard: AB::Var, is_real: AB::Var, ) { for i in 0..WORD_SIZE { @@ -52,6 +54,7 @@ impl XorOperation { cols.value[i], a[i], b[i], + shard, is_real, ); } diff --git a/core/src/program/mod.rs b/core/src/program/mod.rs index 2099579b20..bfbcaa52f1 100644 --- a/core/src/program/mod.rs +++ b/core/src/program/mod.rs @@ -31,6 +31,7 @@ pub struct ProgramPreprocessedCols { #[derive(AlignedBorrow, Clone, Copy, Default)] #[repr(C)] pub struct ProgramMultiplicityCols { + pub shard: T, pub multiplicity: T, } @@ -87,6 +88,10 @@ impl MachineAir for ProgramChip { Some(trace) } + fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + // Do nothing since this chip has no dependencies. 
+ } + fn generate_trace( &self, input: &ExecutionRecord, @@ -115,6 +120,7 @@ impl MachineAir for ProgramChip { let pc = input.program.pc_base + (i as u32 * 4); let mut row = [F::zero(); NUM_PROGRAM_MULT_COLS]; let cols: &mut ProgramMultiplicityCols = row.as_mut_slice().borrow_mut(); + cols.shard = F::from_canonical_u32(input.index); cols.multiplicity = F::from_canonical_usize(*instruction_counts.get(&pc).unwrap_or(&0)); row @@ -166,6 +172,7 @@ where prep_local.pc, prep_local.instruction, prep_local.selectors, + mult_local.shard, mult_local.multiplicity, ); } diff --git a/core/src/runtime/mod.rs b/core/src/runtime/mod.rs index ff55888ee1..0844ac7a8a 100644 --- a/core/src/runtime/mod.rs +++ b/core/src/runtime/mod.rs @@ -376,6 +376,7 @@ impl Runtime { /// Emit an ALU event. fn emit_alu(&mut self, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) { let event = AluEvent { + shard: self.shard(), clk, opcode, a, diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index d3eb8bf126..25be558b23 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -59,8 +59,9 @@ pub struct ExecutionRecord { /// A trace of the SLT, SLTI, SLTU, and SLTIU events. pub lt_events: Vec, - /// A trace of the byte lookups needed. - pub byte_lookups: BTreeMap, + /// All byte lookups that are needed. The layout is shard -> (event -> count). Byte lookups are + /// sharded to prevent the multiplicities from overflowing. + pub byte_lookups: BTreeMap>, pub sha_extend_events: Vec, @@ -241,11 +242,20 @@ impl MachineRecord for ExecutionRecord { self.blake3_compress_inner_events .append(&mut other.blake3_compress_inner_events); - for (event, mult) in other.byte_lookups.iter_mut() { - self.byte_lookups - .entry(*event) - .and_modify(|i| *i += *mult) - .or_insert(*mult); + // Merge the byte lookups. + for (shard, events_map) in std::mem::take(&mut other.byte_lookups).into_iter() { + match self.byte_lookups.get_mut(&shard) { + Some(existing) => { + // If there's already a map for this shard, update counts for each event. + for (event, count) in events_map.iter() { + *existing.entry(*event).or_insert(0) += count; + } + } + None => { + // If there isn't a map for this shard, insert the whole map. + self.byte_lookups.insert(shard, events_map); + } + } } self.memory_initialize_events @@ -270,6 +280,7 @@ impl MachineRecord for ExecutionRecord { .collect::>(); let mut start_idx = 0; let mut current_shard_num = 1; + for (i, cpu_event) in self.cpu_events.iter().enumerate() { let at_last_event = i == num_cpu_events - 1; if cpu_event.shard != current_shard_num || at_last_event { @@ -278,8 +289,17 @@ impl MachineRecord for ExecutionRecord { let shard = &mut shards[current_shard_num as usize - 1]; shard.index = current_shard_num; shard.cpu_events = self.cpu_events[start_idx..last_idx].to_vec(); + // Each shard needs program because we use it in ProgramChip. shard.program = self.program.clone(); + // Byte lookups are already sharded, so put this shard's lookups in. + shard.byte_lookups.insert( + current_shard_num, + self.byte_lookups + .remove(¤t_shard_num) + .unwrap_or_default(), + ); + if !(at_last_event) { start_idx = i; current_shard_num = cpu_event.shard; @@ -422,9 +442,6 @@ impl MachineRecord for ExecutionRecord { // Blake3 compress events . 
first.blake3_compress_inner_events = std::mem::take(&mut self.blake3_compress_inner_events); - // Put all byte lookups in the first shard (as the table size is fixed) - first.byte_lookups = std::mem::take(&mut self.byte_lookups); - // Put the memory records in the last shard. let last_shard = shards.last_mut().unwrap(); @@ -475,10 +492,12 @@ impl ExecutionRecord { } pub fn add_byte_lookup_event(&mut self, blu_event: ByteLookupEvent) { - self.byte_lookups + *self + .byte_lookups + .entry(blu_event.shard) + .or_default() .entry(blu_event) - .and_modify(|i| *i += 1) - .or_insert(1); + .or_insert(0) += 1 } pub fn add_alu_events(&mut self, alu_events: HashMap>) { @@ -522,8 +541,9 @@ impl ExecutionRecord { } /// Adds a `ByteLookupEvent` to verify `a` and `b are indeed bytes to the shard. - pub fn add_u8_range_check(&mut self, a: u8, b: u8) { + pub fn add_u8_range_check(&mut self, shard: u32, a: u8, b: u8) { self.add_byte_lookup_event(ByteLookupEvent { + shard, opcode: ByteOpcode::U8Range, a1: 0, a2: 0, @@ -533,8 +553,9 @@ impl ExecutionRecord { } /// Adds a `ByteLookupEvent` to verify `a` is indeed u16. - pub fn add_u16_range_check(&mut self, a: u32) { + pub fn add_u16_range_check(&mut self, shard: u32, a: u32) { self.add_byte_lookup_event(ByteLookupEvent { + shard, opcode: ByteOpcode::U16Range, a1: a, a2: 0, @@ -544,26 +565,27 @@ impl ExecutionRecord { } /// Adds `ByteLookupEvent`s to verify that all the bytes in the input slice are indeed bytes. - pub fn add_u8_range_checks(&mut self, ls: &[u8]) { + pub fn add_u8_range_checks(&mut self, shard: u32, ls: &[u8]) { let mut index = 0; while index + 1 < ls.len() { - self.add_u8_range_check(ls[index], ls[index + 1]); + self.add_u8_range_check(shard, ls[index], ls[index + 1]); index += 2; } if index < ls.len() { // If the input slice's length is odd, we need to add a check for the last byte. - self.add_u8_range_check(ls[index], 0); + self.add_u8_range_check(shard, ls[index], 0); } } /// Adds `ByteLookupEvent`s to verify that all the bytes in the input slice are indeed bytes. - pub fn add_u16_range_checks(&mut self, ls: &[u32]) { - ls.iter().for_each(|x| self.add_u16_range_check(*x)); + pub fn add_u16_range_checks(&mut self, shard: u32, ls: &[u32]) { + ls.iter().for_each(|x| self.add_u16_range_check(shard, *x)); } /// Adds a `ByteLookupEvent` to compute the bitwise OR of the two input values. 
- pub fn lookup_or(&mut self, b: u8, c: u8) { + pub fn lookup_or(&mut self, shard: u32, b: u8, c: u8) { self.add_byte_lookup_event(ByteLookupEvent { + shard, opcode: ByteOpcode::OR, a1: (b | c) as u32, a2: 0, diff --git a/core/src/runtime/syscall.rs b/core/src/runtime/syscall.rs index ab53a9bebe..adcf957ebf 100644 --- a/core/src/runtime/syscall.rs +++ b/core/src/runtime/syscall.rs @@ -1,5 +1,4 @@ use crate::runtime::{Register, Runtime}; -use crate::syscall::precompiles::blake3::Blake3CompressInnerChip; use crate::syscall::precompiles::edwards::EdAddAssignChip; use crate::syscall::precompiles::edwards::EdDecompressChip; use crate::syscall::precompiles::k256::K256DecompressChip; @@ -263,10 +262,6 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::BN254_DOUBLE, Rc::new(WeierstrassDoubleAssignChip::::new()), ); - syscall_map.insert( - SyscallCode::BLAKE3_COMPRESS_INNER, - Rc::new(Blake3CompressInnerChip::new()), - ); syscall_map.insert( SyscallCode::ENTER_UNCONSTRAINED, Rc::new(SyscallEnterUnconstrained::new()), @@ -292,6 +287,10 @@ mod tests { fn test_syscalls_in_default_map() { let default_syscall_map = default_syscall_map(); for code in SyscallCode::iter() { + if code == SyscallCode::BLAKE3_COMPRESS_INNER { + // Blake3 is currently disabled. + continue; + } default_syscall_map.get(&code).unwrap(); } } diff --git a/core/src/stark/air.rs b/core/src/stark/air.rs index 44520968e6..1e2aef7cc9 100644 --- a/core/src/stark/air.rs +++ b/core/src/stark/air.rs @@ -86,7 +86,7 @@ pub enum RiscvAir { Secp256k1Double(WeierstrassDoubleAssignChip>), /// A precompile for the Keccak permutation. KeccakP(KeccakPermuteChip), - /// A precompile for the Blake3 compression function. + /// A precompile for the Blake3 compression function. (Disabled by default.) Blake3Compress(Blake3CompressInnerChip), /// A precompile for addition on the Elliptic curve bn254. Bn254Add(WeierstrassAddAssignChip>), @@ -129,8 +129,6 @@ impl RiscvAir { chips.push(RiscvAir::Secp256k1Double(secp256k1_double_assign)); let keccak_permute = KeccakPermuteChip::new(); chips.push(RiscvAir::KeccakP(keccak_permute)); - let blake3_compress_inner = Blake3CompressInnerChip::new(); - chips.push(RiscvAir::Blake3Compress(blake3_compress_inner)); let bn254_add_assign = WeierstrassAddAssignChip::>::new(); chips.push(RiscvAir::Bn254Add(bn254_add_assign)); let bn254_double_assign = WeierstrassDoubleAssignChip::>::new(); diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index 64ccd7567a..e1b2e33a04 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -236,7 +236,6 @@ impl>> MachineStark { challenger.observe(vk.commit.clone()); // TODO: Observe the challenges in a tree-like structure for easily verifiable reconstruction // in a map-reduce recursion setting. - #[cfg(feature = "perf")] tracing::debug_span!("observe challenges for all shards").in_scope(|| { proof.shard_proofs.iter().for_each(|proof| { challenger.observe(proof.commitment.main_commit.clone()); @@ -259,8 +258,9 @@ impl>> MachineStark { .map_err(ProgramVerificationError::InvalidSegmentProof) })?; } - tracing::info!("success"); + tracing::info!("verifying individual shards succeeded"); + tracing::info!("verifying cumulative sum is 0"); // Verify the cumulative sum is 0. 
let mut sum = SC::Challenge::zero(); for proof in proof.shard_proofs.iter() { @@ -274,7 +274,6 @@ impl>> MachineStark { pub fn debug_constraints( &self, - program: &A::Program, pk: &ProvingKey, record: A::Record, challenger: &mut SC::Challenger, @@ -283,7 +282,7 @@ impl>> MachineStark { A: for<'a> Air, SC::Challenge>>, { tracing::debug!("sharding the execution record"); - let mut shards = self.shard(record, &::Config::default()); + let shards = self.shard(record, &::Config::default()); tracing::debug!("checking constraints for each shard"); @@ -293,15 +292,19 @@ impl>> MachineStark { let chips = self.shard_chips(shard).collect::>(); // Generate the main trace for each chip. - let traces = chips - .par_iter() + let pre_traces = chips + .iter() .map(|chip| { - ( - chip.generate_trace(shard, &mut A::Record::default()), - chip.generate_preprocessed_trace(program), - ) + pk.chip_ordering + .get(&chip.name()) + .map(|index| &pk.traces[*index]) }) .collect::>(); + let traces = chips + .par_iter() + .map(|chip| chip.generate_trace(shard, &mut A::Record::default())) + .zip(pre_traces) + .collect::>(); // Get a permutation challenge. // Obtain the challenges used for the permutation argument. @@ -319,7 +322,7 @@ impl>> MachineStark { .zip(traces.par_iter()) .map(|(chip, (main_trace, pre_trace))| { let perm_trace = chip.generate_permutation_trace( - pre_trace.as_ref(), + *pre_trace, main_trace, &permutation_challenges, ); @@ -370,15 +373,10 @@ impl>> MachineStark { // If the cumulative sum is not zero, debug the interactions. if !cumulative_sum.is_zero() { - // Get the total record - let mut record = A::Record::default(); - for shard in shards.iter_mut() { - record.append(shard); - } debug_interactions_with_all_chips::( - self.chips(), - program, - &record, + self, + pk, + &shards, InteractionKind::all_kinds(), ); } @@ -489,6 +487,7 @@ pub mod tests { #[test] fn test_lt_prove() { + setup_logger(); let less_than = [Opcode::SLT, Opcode::SLTU]; for lt_op in less_than.iter() { let instructions = vec![ @@ -503,6 +502,7 @@ pub mod tests { #[test] fn test_bitwise_prove() { + setup_logger(); let bitwise_opcodes = [Opcode::XOR, Opcode::OR, Opcode::AND]; for bitwise_op in bitwise_opcodes.iter() { @@ -518,6 +518,7 @@ pub mod tests { #[test] fn test_divrem_prove() { + setup_logger(); let div_rem_ops = [Opcode::DIV, Opcode::DIVU, Opcode::REM, Opcode::REMU]; let operands = [ (1, 1), diff --git a/core/src/stark/prover.rs b/core/src/stark/prover.rs index 15982862c4..73688b2f63 100644 --- a/core/src/stark/prover.rs +++ b/core/src/stark/prover.rs @@ -588,24 +588,6 @@ where .unzip() }); - #[cfg(not(feature = "perf"))] - { - let bytes_written = shard_main_data - .iter() - .map(|data| match data { - ShardMainDataWrapper::InMemory(_) => 0, - ShardMainDataWrapper::TempFile(_, bytes_written) => *bytes_written, - ShardMainDataWrapper::Empty() => 0, - }) - .sum::(); - if bytes_written > 0 { - tracing::debug!( - "total main data written to disk: {}", - size::Size::from_bytes(bytes_written) - ); - } - } - (commitments, shard_main_data) } } diff --git a/core/src/stark/verifier.rs b/core/src/stark/verifier.rs index 65f540ff5c..d9d14876dd 100644 --- a/core/src/stark/verifier.rs +++ b/core/src/stark/verifier.rs @@ -26,7 +26,6 @@ pub struct Verifier(PhantomData, PhantomData); impl>> Verifier { /// Verify a proof for a collection of air chips. 
- #[cfg(feature = "perf")] pub fn verify_shard( config: &SC, vk: &VerifyingKey, @@ -74,7 +73,6 @@ impl>> Verifier { .map(|_| challenger.sample_ext_element::()) .collect::>(); - #[cfg(feature = "perf")] challenger.observe(permutation_commit.clone()); let alpha = challenger.sample_ext_element::(); @@ -191,23 +189,10 @@ impl>> Verifier { ) .map_err(|_| VerificationError::OodEvaluationMismatch(chip.name()))?; } - - Ok(()) - } - - #[cfg(not(feature = "perf"))] - pub fn verify_shard( - _config: &SC, - _vk: &VerifyingKey, - _chips: &[&MachineChip], - _challenger: &mut SC::Challenger, - _proof: &ShardProof, - ) -> Result<(), VerificationError> { Ok(()) } #[allow(clippy::too_many_arguments)] - #[cfg(feature = "perf")] fn verify_constraints( chip: &MachineChip, opening: ChipOpenedValues, @@ -241,7 +226,6 @@ impl>> Verifier { } } - #[cfg(feature = "perf")] pub fn eval_constraints( chip: &MachineChip, opening: &ChipOpenedValues, @@ -291,7 +275,6 @@ impl>> Verifier { folder.accumulator } - #[cfg(feature = "perf")] pub fn recompute_quotient( opening: &ChipOpenedValues, qc_domains: &[Domain], diff --git a/core/src/syscall/precompiles/blake3/compress/air.rs b/core/src/syscall/precompiles/blake3/compress/air.rs index 53d02b7a1a..0cb32872d1 100644 --- a/core/src/syscall/precompiles/blake3/compress/air.rs +++ b/core/src/syscall/precompiles/blake3/compress/air.rs @@ -37,7 +37,7 @@ where // TODO: constraint ecall_receive column. // TODO: constraint clk column to increment by 1 within same invocation of syscall. builder.receive_syscall( - local.segment, // TODO: rename this to "shard" + local.shard, local.clk, AB::F::from_canonical_u32(SyscallCode::BLAKE3_COMPRESS_INNER.syscall_id()), local.state_ptr, @@ -143,7 +143,7 @@ impl Blake3CompressInnerChip { // Read & write the state. for i in 0..NUM_STATE_WORDS_PER_CALL { builder.constraint_memory_access( - local.segment, + local.shard, local.clk, local.state_ptr + local.state_index[i] * AB::F::from_canonical_usize(WORD_SIZE), &local.state_reads_writes[i], @@ -179,7 +179,7 @@ impl Blake3CompressInnerChip { // Read the message. for i in 0..NUM_MSG_WORDS_PER_CALL { builder.constraint_memory_access( - local.segment, + local.shard, local.clk, local.message_ptr + local.msg_schedule[i] * AB::F::from_canonical_usize(WORD_SIZE), &local.message_reads[i], @@ -208,7 +208,7 @@ impl Blake3CompressInnerChip { ]; // Call the g function. - GOperation::::eval(builder, input, local.g, local.is_real); + GOperation::::eval(builder, input, local.g, local.shard, local.is_real); // Finally, the results of the g function should be written to the memory. 
for i in 0..NUM_STATE_WORDS_PER_CALL { diff --git a/core/src/syscall/precompiles/blake3/compress/columns.rs b/core/src/syscall/precompiles/blake3/compress/columns.rs index ac1246d6c9..12ea0139c9 100644 --- a/core/src/syscall/precompiles/blake3/compress/columns.rs +++ b/core/src/syscall/precompiles/blake3/compress/columns.rs @@ -16,7 +16,7 @@ pub const NUM_BLAKE3_COMPRESS_INNER_COLS: usize = size_of:: { - pub segment: T, + pub shard: T, pub clk: T, pub ecall_receive: T, diff --git a/core/src/syscall/precompiles/blake3/compress/g.rs b/core/src/syscall/precompiles/blake3/compress/g.rs index 4e4654794f..17b5f633de 100644 --- a/core/src/syscall/precompiles/blake3/compress/g.rs +++ b/core/src/syscall/precompiles/blake3/compress/g.rs @@ -46,7 +46,12 @@ pub struct GOperation { } impl GOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, input: [u32; 6]) -> [u32; 4] { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + input: [u32; 6], + ) -> [u32; 4] { let mut a = input[0]; let mut b = input[1]; let mut c = input[2]; @@ -57,37 +62,37 @@ impl GOperation { // First 4 steps. { // a = a + b + x. - a = self.a_plus_b.populate(record, a, b); - a = self.a_plus_b_plus_x.populate(record, a, x); + a = self.a_plus_b.populate(record, shard, a, b); + a = self.a_plus_b_plus_x.populate(record, shard, a, x); // d = (d ^ a).rotate_right(16). - d = self.d_xor_a.populate(record, d, a); + d = self.d_xor_a.populate(record, shard, d, a); d = d.rotate_right(16); // c = c + d. - c = self.c_plus_d.populate(record, c, d); + c = self.c_plus_d.populate(record, shard, c, d); // b = (b ^ c).rotate_right(12). - b = self.b_xor_c.populate(record, b, c); - b = self.b_xor_c_rotate_right_12.populate(record, b, 12); + b = self.b_xor_c.populate(record, shard, b, c); + b = self.b_xor_c_rotate_right_12.populate(record, shard, b, 12); } // Second 4 steps. { // a = a + b + y. - a = self.a_plus_b_2.populate(record, a, b); - a = self.a_plus_b_2_add_y.populate(record, a, y); + a = self.a_plus_b_2.populate(record, shard, a, b); + a = self.a_plus_b_2_add_y.populate(record, shard, a, y); // d = (d ^ a).rotate_right(8). - d = self.d_xor_a_2.populate(record, d, a); + d = self.d_xor_a_2.populate(record, shard, d, a); d = d.rotate_right(8); // c = c + d. - c = self.c_plus_d_2.populate(record, c, d); + c = self.c_plus_d_2.populate(record, shard, c, d); // b = (b ^ c).rotate_right(7). - b = self.b_xor_c_2.populate(record, b, c); - b = self.b_xor_c_2_rotate_right_7.populate(record, b, 7); + b = self.b_xor_c_2.populate(record, shard, b, c); + b = self.b_xor_c_2_rotate_right_7.populate(record, shard, b, 7); } let result = [a, b, c, d]; @@ -100,6 +105,7 @@ impl GOperation { builder: &mut AB, input: [Word; 6], cols: GOperation, + shard: AB::Var, is_real: AB::Var, ) { builder.assert_bool(is_real); @@ -113,29 +119,30 @@ impl GOperation { // First 4 steps. { // a = a + b + x. - AddOperation::::eval(builder, a, b, cols.a_plus_b, is_real.into()); + AddOperation::::eval(builder, a, b, cols.a_plus_b, shard, is_real.into()); a = cols.a_plus_b.value; - AddOperation::::eval(builder, a, x, cols.a_plus_b_plus_x, is_real.into()); + AddOperation::::eval(builder, a, x, cols.a_plus_b_plus_x, shard, is_real.into()); a = cols.a_plus_b_plus_x.value; // d = (d ^ a).rotate_right(16). - XorOperation::::eval(builder, d, a, cols.d_xor_a, is_real); + XorOperation::::eval(builder, d, a, cols.d_xor_a, shard, is_real); d = cols.d_xor_a.value; // Rotate right by 16 bits. d = Word([d[2], d[3], d[0], d[1]]); // c = c + d. 
- AddOperation::::eval(builder, c, d, cols.c_plus_d, is_real.into()); + AddOperation::::eval(builder, c, d, cols.c_plus_d, shard, is_real.into()); c = cols.c_plus_d.value; // b = (b ^ c).rotate_right(12). - XorOperation::::eval(builder, b, c, cols.b_xor_c, is_real); + XorOperation::::eval(builder, b, c, cols.b_xor_c, shard, is_real); b = cols.b_xor_c.value; FixedRotateRightOperation::::eval( builder, b, 12, cols.b_xor_c_rotate_right_12, + shard, is_real, ); b = cols.b_xor_c_rotate_right_12.value; @@ -144,29 +151,37 @@ impl GOperation { // Second 4 steps. { // a = a + b + y. - AddOperation::::eval(builder, a, b, cols.a_plus_b_2, is_real.into()); + AddOperation::::eval(builder, a, b, cols.a_plus_b_2, shard, is_real.into()); a = cols.a_plus_b_2.value; - AddOperation::::eval(builder, a, y, cols.a_plus_b_2_add_y, is_real.into()); + AddOperation::::eval( + builder, + a, + y, + cols.a_plus_b_2_add_y, + shard, + is_real.into(), + ); a = cols.a_plus_b_2_add_y.value; // d = (d ^ a).rotate_right(8). - XorOperation::::eval(builder, d, a, cols.d_xor_a_2, is_real); + XorOperation::::eval(builder, d, a, cols.d_xor_a_2, shard, is_real); d = cols.d_xor_a_2.value; // Rotate right by 8 bits. d = Word([d[1], d[2], d[3], d[0]]); // c = c + d. - AddOperation::::eval(builder, c, d, cols.c_plus_d_2, is_real.into()); + AddOperation::::eval(builder, c, d, cols.c_plus_d_2, shard, is_real.into()); c = cols.c_plus_d_2.value; // b = (b ^ c).rotate_right(7). - XorOperation::::eval(builder, b, c, cols.b_xor_c_2, is_real); + XorOperation::::eval(builder, b, c, cols.b_xor_c_2, shard, is_real); b = cols.b_xor_c_2.value; FixedRotateRightOperation::::eval( builder, b, 7, cols.b_xor_c_2_rotate_right_7, + shard, is_real, ); b = cols.b_xor_c_2_rotate_right_7.value; diff --git a/core/src/syscall/precompiles/blake3/compress/mod.rs b/core/src/syscall/precompiles/blake3/compress/mod.rs index 2506560aa3..e8c1f8c3e1 100644 --- a/core/src/syscall/precompiles/blake3/compress/mod.rs +++ b/core/src/syscall/precompiles/blake3/compress/mod.rs @@ -115,9 +115,6 @@ pub mod compress_tests { use crate::runtime::Opcode; use crate::runtime::Register; use crate::runtime::SyscallCode; - use crate::utils::run_test; - use crate::utils::setup_logger; - use crate::utils::tests::BLAKE3_COMPRESS_ELF; use crate::Program; use super::MSG_SIZE; @@ -164,17 +161,18 @@ pub mod compress_tests { Program::new(instructions, 0, 0) } - #[test] - fn prove_babybear() { - setup_logger(); - let program = blake3_compress_internal_program(); - run_test(program).unwrap(); - } - - #[test] - fn test_blake3_compress_inner_elf() { - setup_logger(); - let program = Program::from(BLAKE3_COMPRESS_ELF); - run_test(program).unwrap(); - } + // Tests disabled because syscall is not enabled in default runtime/chip configs. 
+ // #[test] + // fn prove_babybear() { + // setup_logger(); + // let program = blake3_compress_internal_program(); + // run_test(program).unwrap(); + // } + + // #[test] + // fn test_blake3_compress_inner_elf() { + // setup_logger(); + // let program = Program::from(BLAKE3_COMPRESS_ELF); + // run_test(program).unwrap(); + // } } diff --git a/core/src/syscall/precompiles/blake3/compress/trace.rs b/core/src/syscall/precompiles/blake3/compress/trace.rs index 1567686952..9017a6adf0 100644 --- a/core/src/syscall/precompiles/blake3/compress/trace.rs +++ b/core/src/syscall/precompiles/blake3/compress/trace.rs @@ -37,7 +37,7 @@ impl MachineAir for Blake3CompressInnerChip { for i in 0..input.blake3_compress_inner_events.len() { let event = input.blake3_compress_inner_events[i].clone(); - + let shard = event.shard; let mut clk = event.clk; for round in 0..ROUND_COUNT { for operation in 0..OPERATION_COUNT { @@ -46,7 +46,7 @@ impl MachineAir for Blake3CompressInnerChip { // Assign basic values to the columns. { - cols.segment = F::from_canonical_u32(event.shard); + cols.shard = F::from_canonical_u32(event.shard); cols.clk = F::from_canonical_u32(clk); cols.round_index = F::from_canonical_u32(round as u32); @@ -99,7 +99,7 @@ impl MachineAir for Blake3CompressInnerChip { event.message_reads[round][operation][1].value, ]; - cols.g.populate(output, input); + cols.g.populate(output, shard, input); } clk += 1; diff --git a/core/src/syscall/precompiles/keccak256/air.rs b/core/src/syscall/precompiles/keccak256/air.rs index ab9c4a679f..5ee4723281 100644 --- a/core/src/syscall/precompiles/keccak256/air.rs +++ b/core/src/syscall/precompiles/keccak256/air.rs @@ -136,7 +136,6 @@ where } } -#[cfg(feature = "keccak")] #[cfg(test)] mod test { use crate::SP1Stdin; @@ -151,6 +150,7 @@ mod test { const NUM_TEST_CASES: usize = 45; #[test] + #[ignore] fn test_keccak_random() { setup_logger(); let mut rng = rand::rngs::StdRng::seed_from_u64(0); diff --git a/core/src/syscall/precompiles/sha256/compress/air.rs b/core/src/syscall/precompiles/sha256/compress/air.rs index 2c11f92c12..e5f710ce07 100644 --- a/core/src/syscall/precompiles/sha256/compress/air.rs +++ b/core/src/syscall/precompiles/sha256/compress/air.rs @@ -285,6 +285,7 @@ impl ShaCompressChip { local.e, 6, local.e_rr_6, + local.shard, local.is_compression, ); // Calculate e rightrotate 11. @@ -293,6 +294,7 @@ impl ShaCompressChip { local.e, 11, local.e_rr_11, + local.shard, local.is_compression, ); // Calculate e rightrotate 25. @@ -301,6 +303,7 @@ impl ShaCompressChip { local.e, 25, local.e_rr_25, + local.shard, local.is_compression, ); // Calculate (e rightrotate 6) xor (e rightrotate 11). @@ -309,6 +312,7 @@ impl ShaCompressChip { local.e_rr_6.value, local.e_rr_11.value, local.s1_intermediate, + local.shard, local.is_compression, ); // Calculate S1 := ((e rightrotate 6) xor (e rightrotate 11)) xor (e rightrotate 25). @@ -317,6 +321,7 @@ impl ShaCompressChip { local.s1_intermediate.value, local.e_rr_25.value, local.s1, + local.shard, local.is_compression, ); @@ -327,16 +332,24 @@ impl ShaCompressChip { local.e, local.f, local.e_and_f, + local.shard, local.is_compression, ); // Calculate not e. - NotOperation::::eval(builder, local.e, local.e_not, local.is_compression); + NotOperation::::eval( + builder, + local.e, + local.e_not, + local.shard, + local.is_compression, + ); // Calculate (not e) and g. 
AndOperation::::eval( builder, local.e_not.value, local.g, local.e_not_and_g, + local.shard, local.is_compression, ); // Calculate ch := (e and f) xor ((not e) and g). @@ -345,6 +358,7 @@ impl ShaCompressChip { local.e_and_f.value, local.e_not_and_g.value, local.ch, + local.shard, local.is_compression, ); @@ -358,6 +372,7 @@ impl ShaCompressChip { local.k, local.mem.access.value, ], + local.shard, local.is_compression, local.temp1, ); @@ -369,6 +384,7 @@ impl ShaCompressChip { local.a, 2, local.a_rr_2, + local.shard, local.is_compression, ); // Calculate a rightrotate 13. @@ -377,6 +393,7 @@ impl ShaCompressChip { local.a, 13, local.a_rr_13, + local.shard, local.is_compression, ); // Calculate a rightrotate 22. @@ -385,6 +402,7 @@ impl ShaCompressChip { local.a, 22, local.a_rr_22, + local.shard, local.is_compression, ); // Calculate (a rightrotate 2) xor (a rightrotate 13). @@ -393,6 +411,7 @@ impl ShaCompressChip { local.a_rr_2.value, local.a_rr_13.value, local.s0_intermediate, + local.shard, local.is_compression, ); // Calculate S0 := ((a rightrotate 2) xor (a rightrotate 13)) xor (a rightrotate 22). @@ -401,6 +420,7 @@ impl ShaCompressChip { local.s0_intermediate.value, local.a_rr_22.value, local.s0, + local.shard, local.is_compression, ); @@ -411,6 +431,7 @@ impl ShaCompressChip { local.a, local.b, local.a_and_b, + local.shard, local.is_compression, ); // Calculate a and c. @@ -419,6 +440,7 @@ impl ShaCompressChip { local.a, local.c, local.a_and_c, + local.shard, local.is_compression, ); // Calculate b and c. @@ -427,6 +449,7 @@ impl ShaCompressChip { local.b, local.c, local.b_and_c, + local.shard, local.is_compression, ); // Calculate (a and b) xor (a and c). @@ -435,6 +458,7 @@ impl ShaCompressChip { local.a_and_b.value, local.a_and_c.value, local.maj_intermediate, + local.shard, local.is_compression, ); // Calculate maj := ((a and b) xor (a and c)) xor (b and c). 
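For orientation while reading these compression constraints (each of which now also receives `local.shard`), this is the plain scalar round they encode; it is standard SHA-256 and the function name is invented for this note. Each intermediate corresponds to one of the operation columns above (s1, ch, temp1, s0, maj, temp2, d_add_temp1, temp1_add_temp2).

// Scalar reference for one SHA-256 compression round; in the chip every
// and/not/xor/rotate/add below is a byte-lookup-backed operation that now
// also receives the shard.
fn sha256_compress_round(state: &mut [u32; 8], w_j: u32, k_j: u32) {
    let [a, b, c, d, e, f, g, h] = *state;
    let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
    let ch = (e & f) ^ (!e & g);
    let temp1 = h
        .wrapping_add(s1)
        .wrapping_add(ch)
        .wrapping_add(k_j)
        .wrapping_add(w_j);
    let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
    let maj = (a & b) ^ (a & c) ^ (b & c);
    let temp2 = s0.wrapping_add(maj);
    *state = [
        temp1.wrapping_add(temp2), // new a
        a,                         // new b
        b,                         // new c
        c,                         // new d
        d.wrapping_add(temp1),     // new e (the chip's d_add_temp1)
        e,                         // new f
        f,                         // new g
        g,                         // new h
    ];
}

Here temp1 is the chip's Add5Operation over h, s1, ch, w[j], and k[j], while temp2, d_add_temp1, and temp1_add_temp2 are the AddOperations evaluated just below.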
@@ -443,6 +467,7 @@ impl ShaCompressChip { local.maj_intermediate.value, local.b_and_c.value, local.maj, + local.shard, local.is_compression, ); @@ -452,6 +477,7 @@ impl ShaCompressChip { local.s0.value, local.maj.value, local.temp2, + local.shard, local.is_compression.into(), ); @@ -461,6 +487,7 @@ impl ShaCompressChip { local.d, local.temp1.value, local.d_add_temp1, + local.shard, local.is_compression.into(), ); @@ -470,6 +497,7 @@ impl ShaCompressChip { local.temp1.value, local.temp2.value, local.temp1_add_temp2, + local.shard, local.is_compression.into(), ); } @@ -505,6 +533,7 @@ impl ShaCompressChip { local.mem.prev_value, local.finalized_operand, local.finalize_add, + local.shard, is_finalize.into(), ); diff --git a/core/src/syscall/precompiles/sha256/compress/trace.rs b/core/src/syscall/precompiles/sha256/compress/trace.rs index b5aa88a11a..27b76e9037 100644 --- a/core/src/syscall/precompiles/sha256/compress/trace.rs +++ b/core/src/syscall/precompiles/sha256/compress/trace.rs @@ -33,6 +33,7 @@ impl MachineAir for ShaCompressChip { let mut new_byte_lookup_events = Vec::new(); for i in 0..input.sha_compress_events.len() { let mut event = input.sha_compress_events[i].clone(); + let shard = event.shard; let og_h = event.h; @@ -107,37 +108,43 @@ impl MachineAir for ShaCompressChip { cols.g = Word::from(g); cols.h = Word::from(h); - let e_rr_6 = cols.e_rr_6.populate(output, e, 6); - let e_rr_11 = cols.e_rr_11.populate(output, e, 11); - let e_rr_25 = cols.e_rr_25.populate(output, e, 25); - let s1_intermediate = cols.s1_intermediate.populate(output, e_rr_6, e_rr_11); - let s1 = cols.s1.populate(output, s1_intermediate, e_rr_25); - - let e_and_f = cols.e_and_f.populate(output, e, f); - let e_not = cols.e_not.populate(output, e); - let e_not_and_g = cols.e_not_and_g.populate(output, e_not, g); - let ch = cols.ch.populate(output, e_and_f, e_not_and_g); - - let temp1 = cols - .temp1 - .populate(output, h, s1, ch, event.w[j], SHA_COMPRESS_K[j]); - - let a_rr_2 = cols.a_rr_2.populate(output, a, 2); - let a_rr_13 = cols.a_rr_13.populate(output, a, 13); - let a_rr_22 = cols.a_rr_22.populate(output, a, 22); - let s0_intermediate = cols.s0_intermediate.populate(output, a_rr_2, a_rr_13); - let s0 = cols.s0.populate(output, s0_intermediate, a_rr_22); - - let a_and_b = cols.a_and_b.populate(output, a, b); - let a_and_c = cols.a_and_c.populate(output, a, c); - let b_and_c = cols.b_and_c.populate(output, b, c); - let maj_intermediate = cols.maj_intermediate.populate(output, a_and_b, a_and_c); - let maj = cols.maj.populate(output, maj_intermediate, b_and_c); - - let temp2 = cols.temp2.populate(output, s0, maj); - - let d_add_temp1 = cols.d_add_temp1.populate(output, d, temp1); - let temp1_add_temp2 = cols.temp1_add_temp2.populate(output, temp1, temp2); + let e_rr_6 = cols.e_rr_6.populate(output, shard, e, 6); + let e_rr_11 = cols.e_rr_11.populate(output, shard, e, 11); + let e_rr_25 = cols.e_rr_25.populate(output, shard, e, 25); + let s1_intermediate = cols + .s1_intermediate + .populate(output, shard, e_rr_6, e_rr_11); + let s1 = cols.s1.populate(output, shard, s1_intermediate, e_rr_25); + + let e_and_f = cols.e_and_f.populate(output, shard, e, f); + let e_not = cols.e_not.populate(output, shard, e); + let e_not_and_g = cols.e_not_and_g.populate(output, shard, e_not, g); + let ch = cols.ch.populate(output, shard, e_and_f, e_not_and_g); + + let temp1 = + cols.temp1 + .populate(output, shard, h, s1, ch, event.w[j], SHA_COMPRESS_K[j]); + + let a_rr_2 = cols.a_rr_2.populate(output, shard, a, 2); + let a_rr_13 
= cols.a_rr_13.populate(output, shard, a, 13); + let a_rr_22 = cols.a_rr_22.populate(output, shard, a, 22); + let s0_intermediate = cols + .s0_intermediate + .populate(output, shard, a_rr_2, a_rr_13); + let s0 = cols.s0.populate(output, shard, s0_intermediate, a_rr_22); + + let a_and_b = cols.a_and_b.populate(output, shard, a, b); + let a_and_c = cols.a_and_c.populate(output, shard, a, c); + let b_and_c = cols.b_and_c.populate(output, shard, b, c); + let maj_intermediate = cols + .maj_intermediate + .populate(output, shard, a_and_b, a_and_c); + let maj = cols.maj.populate(output, shard, maj_intermediate, b_and_c); + + let temp2 = cols.temp2.populate(output, shard, s0, maj); + + let d_add_temp1 = cols.d_add_temp1.populate(output, shard, d, temp1); + let temp1_add_temp2 = cols.temp1_add_temp2.populate(output, shard, temp1, temp2); event.h[7] = g; event.h[6] = f; @@ -174,7 +181,8 @@ impl MachineAir for ShaCompressChip { cols.octet[j] = F::one(); cols.octet_num[octet_num_idx] = F::one(); - cols.finalize_add.populate(output, og_h[j], event.h[j]); + cols.finalize_add + .populate(output, shard, og_h[j], event.h[j]); cols.mem .populate_write(event.h_write_records[j], &mut new_byte_lookup_events); cols.mem_addr = F::from_canonical_u32(event.h_ptr + (j * 4) as u32); diff --git a/core/src/syscall/precompiles/sha256/extend/air.rs b/core/src/syscall/precompiles/sha256/extend/air.rs index e8476ceeb2..1de06e0672 100644 --- a/core/src/syscall/precompiles/sha256/extend/air.rs +++ b/core/src/syscall/precompiles/sha256/extend/air.rs @@ -89,6 +89,7 @@ where *local.w_i_minus_15.value(), 7, local.w_i_minus_15_rr_7, + local.shard, local.is_real, ); // w[i-15] rightrotate 18. @@ -97,6 +98,7 @@ where *local.w_i_minus_15.value(), 18, local.w_i_minus_15_rr_18, + local.shard, local.is_real, ); // w[i-15] rightshift 3. @@ -105,6 +107,7 @@ where *local.w_i_minus_15.value(), 3, local.w_i_minus_15_rs_3, + local.shard, local.is_real, ); // (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) @@ -113,6 +116,7 @@ where local.w_i_minus_15_rr_7.value, local.w_i_minus_15_rr_18.value, local.s0_intermediate, + local.shard, local.is_real, ); // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) @@ -121,6 +125,7 @@ where local.s0_intermediate.value, local.w_i_minus_15_rs_3.value, local.s0, + local.shard, local.is_real, ); @@ -131,6 +136,7 @@ where *local.w_i_minus_2.value(), 17, local.w_i_minus_2_rr_17, + local.shard, local.is_real, ); // w[i-2] rightrotate 19. @@ -139,6 +145,7 @@ where *local.w_i_minus_2.value(), 19, local.w_i_minus_2_rr_19, + local.shard, local.is_real, ); // w[i-2] rightshift 10. 
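The extend constraints on either side of this hunk encode the standard SHA-256 message-schedule recurrence; the scalar form is below (the function name is invented for this note). In the AIR, each rotate, shift, and xor is one of the shard-tagged operations above, and the final four-way sum is the Add4Operation into `local.s2` a few lines further down.

// Scalar reference for the message-schedule extension the chip constrains:
// w[16..64] built from w[i-16], w[i-15], w[i-7], w[i-2] (standard SHA-256).
fn sha256_extend_schedule(w: &mut [u32; 64]) {
    for i in 16..64 {
        let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
        let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
        w[i] = w[i - 16]
            .wrapping_add(s0)
            .wrapping_add(w[i - 7])
            .wrapping_add(s1);
    }
}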
@@ -147,6 +154,7 @@ where *local.w_i_minus_2.value(), 10, local.w_i_minus_2_rs_10, + local.shard, local.is_real, ); // (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) @@ -155,6 +163,7 @@ where local.w_i_minus_2_rr_17.value, local.w_i_minus_2_rr_19.value, local.s1_intermediate, + local.shard, local.is_real, ); // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) @@ -163,6 +172,7 @@ where local.s1_intermediate.value, local.w_i_minus_2_rs_10.value, local.s1, + local.shard, local.is_real, ); @@ -173,6 +183,7 @@ where local.s0.value, *local.w_i_minus_7.value(), local.s1.value, + local.shard, local.is_real, local.s2, ); diff --git a/core/src/syscall/precompiles/sha256/extend/mod.rs b/core/src/syscall/precompiles/sha256/extend/mod.rs index c19ae489c9..529e7c2687 100644 --- a/core/src/syscall/precompiles/sha256/extend/mod.rs +++ b/core/src/syscall/precompiles/sha256/extend/mod.rs @@ -90,7 +90,7 @@ pub mod extend_tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.add_events = vec![AluEvent::new(0, Opcode::ADD, 14, 8, 6)]; + shard.add_events = vec![AluEvent::new(0, 0, Opcode::ADD, 14, 8, 6)]; let chip = ShaExtendChip::new(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); diff --git a/core/src/syscall/precompiles/sha256/extend/trace.rs b/core/src/syscall/precompiles/sha256/extend/trace.rs index 7b338e2c9a..01c4fa24f0 100644 --- a/core/src/syscall/precompiles/sha256/extend/trace.rs +++ b/core/src/syscall/precompiles/sha256/extend/trace.rs @@ -29,6 +29,7 @@ impl MachineAir for ShaExtendChip { let mut new_byte_lookup_events = Vec::new(); for i in 0..input.sha_extend_events.len() { let event = input.sha_extend_events[i].clone(); + let shard = event.shard; for j in 0..48usize { let mut row = [F::zero(); NUM_SHA_EXTEND_COLS]; let cols: &mut ShaExtendCols = row.as_mut_slice().borrow_mut(); @@ -49,28 +50,51 @@ impl MachineAir for ShaExtendChip { // `s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3)`. let w_i_minus_15 = event.w_i_minus_15_reads[j].value; - let w_i_minus_15_rr_7 = cols.w_i_minus_15_rr_7.populate(output, w_i_minus_15, 7); - let w_i_minus_15_rr_18 = cols.w_i_minus_15_rr_18.populate(output, w_i_minus_15, 18); - let w_i_minus_15_rs_3 = cols.w_i_minus_15_rs_3.populate(output, w_i_minus_15, 3); - let s0_intermediate = - cols.s0_intermediate - .populate(output, w_i_minus_15_rr_7, w_i_minus_15_rr_18); - let s0 = cols.s0.populate(output, s0_intermediate, w_i_minus_15_rs_3); + let w_i_minus_15_rr_7 = + cols.w_i_minus_15_rr_7 + .populate(output, shard, w_i_minus_15, 7); + let w_i_minus_15_rr_18 = + cols.w_i_minus_15_rr_18 + .populate(output, shard, w_i_minus_15, 18); + let w_i_minus_15_rs_3 = + cols.w_i_minus_15_rs_3 + .populate(output, shard, w_i_minus_15, 3); + let s0_intermediate = cols.s0_intermediate.populate( + output, + shard, + w_i_minus_15_rr_7, + w_i_minus_15_rr_18, + ); + let s0 = cols + .s0 + .populate(output, shard, s0_intermediate, w_i_minus_15_rs_3); // `s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10)`. 
let w_i_minus_2 = event.w_i_minus_2_reads[j].value; - let w_i_minus_2_rr_17 = cols.w_i_minus_2_rr_17.populate(output, w_i_minus_2, 17); - let w_i_minus_2_rr_19 = cols.w_i_minus_2_rr_19.populate(output, w_i_minus_2, 19); - let w_i_minus_2_rs_10 = cols.w_i_minus_2_rs_10.populate(output, w_i_minus_2, 10); - let s1_intermediate = - cols.s1_intermediate - .populate(output, w_i_minus_2_rr_17, w_i_minus_2_rr_19); - let s1 = cols.s1.populate(output, s1_intermediate, w_i_minus_2_rs_10); + let w_i_minus_2_rr_17 = + cols.w_i_minus_2_rr_17 + .populate(output, shard, w_i_minus_2, 17); + let w_i_minus_2_rr_19 = + cols.w_i_minus_2_rr_19 + .populate(output, shard, w_i_minus_2, 19); + let w_i_minus_2_rs_10 = + cols.w_i_minus_2_rs_10 + .populate(output, shard, w_i_minus_2, 10); + let s1_intermediate = cols.s1_intermediate.populate( + output, + shard, + w_i_minus_2_rr_17, + w_i_minus_2_rr_19, + ); + let s1 = cols + .s1 + .populate(output, shard, s1_intermediate, w_i_minus_2_rs_10); // Compute `s2`. let w_i_minus_7 = event.w_i_minus_7_reads[j].value; let w_i_minus_16 = event.w_i_minus_16_reads[j].value; - cols.s2.populate(output, w_i_minus_16, s0, w_i_minus_7, s1); + cols.s2 + .populate(output, shard, w_i_minus_16, s0, w_i_minus_7, s1); cols.w_i .populate(event.w_i_writes[j], &mut new_byte_lookup_events); diff --git a/core/src/utils/prove.rs b/core/src/utils/prove.rs index ac4778ca7f..1e528722f4 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -61,6 +61,7 @@ pub fn run_test( run_test_core(runtime) } +#[allow(unused_variables)] pub fn run_test_core( runtime: Runtime, ) -> Result, crate::stark::ProgramVerificationError> { @@ -70,8 +71,12 @@ pub fn run_test_core( let mut challenger = machine.config().challenger(); #[cfg(feature = "debug")] - let record_clone = runtime.record.clone(); - + { + let mut challenger_clone = machine.config().challenger(); + let record_clone = runtime.record.clone(); + machine.debug_constraints(&pk, record_clone, &mut challenger_clone); + log::debug!("debug_constraints done"); + } let start = Instant::now(); let proof = tracing::info_span!("prove") .in_scope(|| machine.prove::>(&pk, runtime.record, &mut challenger)); @@ -80,9 +85,6 @@ pub fn run_test_core( let time = start.elapsed().as_millis(); let nb_bytes = bincode::serialize(&proof).unwrap().len(); - #[cfg(feature = "debug")] - machine.debug_constraints(&runtime.program, &pk, record_clone, &mut challenger); - let mut challenger = machine.config().challenger(); machine.verify(&vk, &proof, &mut challenger)?; @@ -137,7 +139,7 @@ where #[cfg(feature = "debug")] { let record_clone = runtime.record.clone(); - machine.debug_constraints(&program, &pk, record_clone, &mut challenger); + machine.debug_constraints(&pk, record_clone, &mut challenger); } let public_values = std::mem::take(&mut runtime.state.public_values_stream); let proof = prove_core(machine.config().clone(), runtime); diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index 9a2a59c36e..815c6eaeb6 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -657,7 +657,7 @@ pub(crate) mod tests { let mut challenger = machine.config().challenger(); let record_clone = runtime.record.clone(); - machine.debug_constraints(&program, &pk, record_clone, &mut challenger); + machine.debug_constraints(&pk, record_clone, &mut challenger); let start = Instant::now(); let mut challenger = machine.config().challenger();
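Since the sharded byte-lookup bookkeeping is spread across the record.rs hunks above, here is a self-contained sketch of the layout this change moves ExecutionRecord to (shard -> event -> multiplicity). The struct name and the generic parameter `E` are invented for illustration; the real record keys a ByteLookupEvent, and `shard()` hands each shard only its own map so no single byte table's multiplicities can overflow.

use std::collections::{BTreeMap, HashMap};
use std::hash::Hash;

// Illustrative stand-in for the record's sharded byte lookups; `E` plays the
// role of ByteLookupEvent.
struct ShardedByteLookups<E: Eq + Hash> {
    byte_lookups: BTreeMap<u32, HashMap<E, usize>>,
}

impl<E: Eq + Hash> ShardedByteLookups<E> {
    fn new() -> Self {
        Self {
            byte_lookups: BTreeMap::new(),
        }
    }

    // Same shape as `add_byte_lookup_event`: bump the event's multiplicity
    // inside its own shard's table.
    fn add(&mut self, shard: u32, event: E) {
        *self
            .byte_lookups
            .entry(shard)
            .or_default()
            .entry(event)
            .or_insert(0) += 1;
    }

    // Same shape as the merge in `append`: fold another record's per-shard
    // maps into this one, summing counts for events both sides have seen.
    fn merge(&mut self, other: &mut Self) {
        for (shard, events) in std::mem::take(&mut other.byte_lookups) {
            let table = self.byte_lookups.entry(shard).or_default();
            for (event, count) in events {
                *table.entry(event).or_insert(0) += count;
            }
        }
    }
}

With the old flat map, every lookup was piled into the first shard's byte table; keying by shard lets `shard()` move exactly one map into each ExecutionRecord, which is why the "put all byte lookups in the first shard" step could be deleted.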