Skip to content

Commit

Permalink
perf: mul trace gen (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
ratankaliani authored Feb 24, 2024
1 parent 17e6ebc commit 7f6775d
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 104 deletions.
1 change: 0 additions & 1 deletion cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use std::cmp::min;
use std::fs::File as SyncFile;
use std::io::Write;
use std::process::{Command, Stdio};
use target_lexicon;

pub const RUSTUP_TOOLCHAIN_NAME: &str = "succinct";

Expand Down
235 changes: 132 additions & 103 deletions core/src/alu/mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ use p3_field::AbstractField;
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::MatrixRowSlices;
use p3_maybe_rayon::prelude::ParallelIterator;
use p3_maybe_rayon::prelude::ParallelSlice;
use sp1_derive::AlignedBorrow;
use tracing::instrument;

Expand Down Expand Up @@ -118,101 +120,123 @@ impl<F: PrimeField> MachineAir<F> for MulChip {
input: &ExecutionRecord,
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let mut rows: Vec<[F; NUM_MUL_COLS]> = vec![];
let mul_events = input.mul_events.clone();
for event in mul_events.iter() {
assert!(
event.opcode == Opcode::MUL
|| event.opcode == Opcode::MULHU
|| event.opcode == Opcode::MULH
|| event.opcode == Opcode::MULHSU
);
let mut row = [F::zero(); NUM_MUL_COLS];
let cols: &mut MulCols<F> = row.as_mut_slice().borrow_mut();
let a_word = event.a.to_le_bytes();
let b_word = event.b.to_le_bytes();
let c_word = event.c.to_le_bytes();

let mut b = b_word.to_vec();
let mut c = c_word.to_vec();

// Handle b and c's signs.
{
let b_msb = get_msb(b_word);
cols.b_msb = F::from_canonical_u8(b_msb);
let c_msb = get_msb(c_word);
cols.c_msb = F::from_canonical_u8(c_msb);

// If b is signed and it is negative, sign extend b.
if (event.opcode == Opcode::MULH || event.opcode == Opcode::MULHSU) && b_msb == 1 {
cols.b_sign_extend = F::one();
b.resize(PRODUCT_SIZE, BYTE_MASK);
}

// If c is signed and it is negative, sign extend c.
if event.opcode == Opcode::MULH && c_msb == 1 {
cols.c_sign_extend = F::one();
c.resize(PRODUCT_SIZE, BYTE_MASK);
}

// Insert the MSB lookup events.
{
let words = [b_word, c_word];
let mut blu_events: Vec<ByteLookupEvent> = vec![];
for word in words.iter() {
let most_significant_byte = word[WORD_SIZE - 1];
blu_events.push(ByteLookupEvent {
opcode: ByteOpcode::MSB,
a1: get_msb(*word) as u32,
a2: 0,
b: most_significant_byte as u32,
c: 0,
});
}
output.add_byte_lookup_events(blu_events);
}
}

let mut product = [0u32; PRODUCT_SIZE];
for i in 0..b.len() {
for j in 0..c.len() {
if i + j < PRODUCT_SIZE {
product[i + j] += (b[i] as u32) * (c[j] as u32);
}
}
}
// Compute the chunk size based on the number of events and the number of CPUs.
let chunk_size = std::cmp::max(mul_events.len() / num_cpus::get(), 1);

// Generate the trace rows & corresponding records for each chunk of events in parallel.
let rows_and_records = mul_events
.par_chunks(chunk_size)
.map(|events| {
let mut record = ExecutionRecord::default();
let rows = events
.iter()
.map(|event| {
// Ensure that the opcode is MUL, MULHU, MULH, or MULHSU.
assert!(
event.opcode == Opcode::MUL
|| event.opcode == Opcode::MULHU
|| event.opcode == Opcode::MULH
|| event.opcode == Opcode::MULHSU
);
let mut row = [F::zero(); NUM_MUL_COLS];
let cols: &mut MulCols<F> = row.as_mut_slice().borrow_mut();

let a_word = event.a.to_le_bytes();
let b_word = event.b.to_le_bytes();
let c_word = event.c.to_le_bytes();

let mut b = b_word.to_vec();
let mut c = c_word.to_vec();

// Handle b and c's signs.
{
let b_msb = get_msb(b_word);
cols.b_msb = F::from_canonical_u8(b_msb);
let c_msb = get_msb(c_word);
cols.c_msb = F::from_canonical_u8(c_msb);

// If b is signed and it is negative, sign extend b.
if (event.opcode == Opcode::MULH || event.opcode == Opcode::MULHSU)
&& b_msb == 1
{
cols.b_sign_extend = F::one();
b.resize(PRODUCT_SIZE, BYTE_MASK);
}

// If c is signed and it is negative, sign extend c.
if event.opcode == Opcode::MULH && c_msb == 1 {
cols.c_sign_extend = F::one();
c.resize(PRODUCT_SIZE, BYTE_MASK);
}

// Insert the MSB lookup events.
{
let words = [b_word, c_word];
let mut blu_events: Vec<ByteLookupEvent> = vec![];
for word in words.iter() {
let most_significant_byte = word[WORD_SIZE - 1];
blu_events.push(ByteLookupEvent {
opcode: ByteOpcode::MSB,
a1: get_msb(*word) as u32,
a2: 0,
b: most_significant_byte as u32,
c: 0,
});
}
record.add_byte_lookup_events(blu_events);
}
}

let mut product = [0u32; PRODUCT_SIZE];
for i in 0..b.len() {
for j in 0..c.len() {
if i + j < PRODUCT_SIZE {
product[i + j] += (b[i] as u32) * (c[j] as u32);
}
}
}

// Calculate the correct product using the `product` array. We store the correct carry
// value for verification.
let base = 1 << BYTE_SIZE;
let mut carry = [0u32; PRODUCT_SIZE];
for i in 0..PRODUCT_SIZE {
carry[i] = product[i] / base;
product[i] %= base;
if i + 1 < PRODUCT_SIZE {
product[i + 1] += carry[i];
}
cols.carry[i] = F::from_canonical_u32(carry[i]);
}

cols.product = product.map(F::from_canonical_u32);
cols.a = Word(a_word.map(F::from_canonical_u8));
cols.b = Word(b_word.map(F::from_canonical_u8));
cols.c = Word(c_word.map(F::from_canonical_u8));
cols.is_real = F::one();
cols.is_mul = F::from_bool(event.opcode == Opcode::MUL);
cols.is_mulh = F::from_bool(event.opcode == Opcode::MULH);
cols.is_mulhu = F::from_bool(event.opcode == Opcode::MULHU);
cols.is_mulhsu = F::from_bool(event.opcode == Opcode::MULHSU);

// Range check.
{
record.add_u16_range_checks(&carry);
record.add_u8_range_checks(&product.map(|x| x as u8));
}
row
})
.collect::<Vec<_>>();
(rows, record)
})
.collect::<Vec<_>>();

// Calculate the correct product using the `product` array. We store the correct carry
// value for verification.
let base = 1 << BYTE_SIZE;
let mut carry = [0u32; PRODUCT_SIZE];
for i in 0..PRODUCT_SIZE {
carry[i] = product[i] / base;
product[i] %= base;
if i + 1 < PRODUCT_SIZE {
product[i + 1] += carry[i];
}
cols.carry[i] = F::from_canonical_u32(carry[i]);
}

cols.product = product.map(F::from_canonical_u32);
cols.a = Word(a_word.map(F::from_canonical_u8));
cols.b = Word(b_word.map(F::from_canonical_u8));
cols.c = Word(c_word.map(F::from_canonical_u8));
cols.is_real = F::one();
cols.is_mul = F::from_bool(event.opcode == Opcode::MUL);
cols.is_mulh = F::from_bool(event.opcode == Opcode::MULH);
cols.is_mulhu = F::from_bool(event.opcode == Opcode::MULHU);
cols.is_mulhsu = F::from_bool(event.opcode == Opcode::MULHSU);

// Range check.
{
output.add_u16_range_checks(&carry);
output.add_u8_range_checks(&product.map(|x| x as u8));
}

rows.push(row);
// Generate the trace rows for each event.
let mut rows: Vec<[F; NUM_MUL_COLS]> = vec![];
for mut row_and_record in rows_and_records {
rows.extend(row_and_record.0);
output.append(&mut row_and_record.1);
}

// Convert the trace to a row major matrix.
Expand Down Expand Up @@ -404,19 +428,24 @@ mod tests {
use super::MulChip;

#[test]
fn generate_trace() {
fn generate_trace_mul() {
let mut shard = ExecutionRecord::default();
shard.mul_events = vec![AluEvent::new(
0,
Opcode::MULHSU,
0x80004000,
0x80000000,
0xffff8000,
)];

// Fill mul_events with 10^7 MULHSU events.
let mut mul_events: Vec<AluEvent> = Vec::new();
for _ in 0..10i32.pow(7) {
mul_events.push(AluEvent::new(
0,
Opcode::MULHSU,
0x80004000,
0x80000000,
0xffff8000,
));
}
shard.mul_events = mul_events;
let chip = MulChip::default();
let trace: RowMajorMatrix<BabyBear> =
let _trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&shard, &mut ExecutionRecord::default());
println!("{:?}", trace.values)
}

#[test]
Expand Down

0 comments on commit 7f6775d

Please sign in to comment.