Skip to content

Commit

Permalink
feat: fri-fold precompile (#479)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirhemo authored Apr 6, 2024
1 parent 5ccb3b9 commit f6d6fd8
Show file tree
Hide file tree
Showing 13 changed files with 133 additions and 45 deletions.
7 changes: 7 additions & 0 deletions recursion/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,13 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
}
_ => unimplemented!(),
},
DslIR::FriFold(m, input_ptr) => {
if let Array::Dyn(ptr, _) = input_ptr {
self.push(AsmInstruction::FriFold(m.fp(), ptr.fp()));
} else {
unimplemented!();
}
}
_ => unimplemented!(),
}
}
Expand Down
16 changes: 16 additions & 0 deletions recursion/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ pub enum AsmInstruction<F, EF> {
PrintF(i32),
PrintE(i32),
Ext2Felt(i32, i32),

// FRIFold(m, input) specific instructions.
FriFold(i32, i32),
}

impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
Expand Down Expand Up @@ -841,6 +844,16 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
false,
true,
),
AsmInstruction::FriFold(m, ptr) => Instruction::new(
Opcode::FRIFold,
i32_f(m),
i32_f_arr(ptr),
f_u32(F::zero()),
F::zero(),
F::zero(),
false,
true,
),
}
}

Expand Down Expand Up @@ -1123,6 +1136,9 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
write!(f, "print_e ({})fp", dst)
}
AsmInstruction::Ext2Felt(dst, src) => write!(f, "ext2felt ({})fp, {})fp", dst, src),
AsmInstruction::FriFold(m, input_ptr) => {
write!(f, "fri_fold ({})fp, ({})fp", m, input_ptr)
}
}
}
}
1 change: 1 addition & 0 deletions recursion/compiler/src/ir/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ impl<C: Config, V: MemVariable<C>> Array<C, V> {
todo!()
}
Self::Dyn(ptr, len) => {
assert!(V::size_of() == 1, "only support variables of size 1");
let new_address = builder.eval(ptr.address + shift);
let new_ptr = Ptr::<C::N> {
address: new_address,
Expand Down
14 changes: 14 additions & 0 deletions recursion/compiler/src/ir/fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use crate::ir::{Array, Config};
use crate::prelude::*;

#[derive(DslVariable, Debug, Clone)]
pub struct FriFoldInput<C: Config> {
pub z: Ext<C::F, C::EF>,
pub alpha: Ext<C::F, C::EF>,
pub x: Felt<C::F>,
pub log_height: Var<C::N>,
pub mat_opening: Array<C, Ext<C::F, C::EF>>,
pub ps_at_z: Array<C, Ext<C::F, C::EF>>,
pub alpha_pow: Array<C, Ext<C::F, C::EF>>,
pub ro: Array<C, Ext<C::F, C::EF>>,
}
5 changes: 4 additions & 1 deletion recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Array, MemIndex, Ptr};
use super::{Array, FriFoldInput, MemIndex, Ptr};

use super::{Config, Ext, Felt, Usize, Var};

Expand Down Expand Up @@ -109,6 +109,9 @@ pub enum DslIR<C: Config> {
ExpUsizeF(Felt<C::F>, Felt<C::F>, Usize<C::N>),
Ext2Felt(Array<C, Felt<C::F>>, Ext<C::F, C::EF>),

// FRI specific instructions.
FriFold(Var<C::N>, Array<C, FriFoldInput<C>>),

// Circuit-specific instructions.
CircuitPoseidon2Permute([Var<C::N>; 3]),
CircuitNum2BitsV(Var<C::N>, usize, Vec<Var<C::N>>),
Expand Down
2 changes: 2 additions & 0 deletions recursion/compiler/src/ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use p3_field::{ExtensionField, PrimeField, TwoAdicField};

mod builder;
mod collections;
mod fold;
mod instructions;
mod ptr;
mod symbolic;
Expand All @@ -11,6 +12,7 @@ mod var;

pub use builder::*;
pub use collections::*;
pub use fold::*;
pub use instructions::*;
pub use ptr::*;
pub use symbolic::*;
Expand Down
4 changes: 4 additions & 0 deletions recursion/core/src/cpu/columns/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub struct OpcodeSelectorCols<T> {
// System instructions.
pub is_trap: T,
pub is_noop: T,

pub is_fri_fold: T,
}

impl<F: Field> OpcodeSelectorCols<F> {
Expand Down Expand Up @@ -92,6 +94,7 @@ impl<F: Field> OpcodeSelectorCols<F> {
Opcode::HintBits => self.is_noop = F::one(),
Opcode::PrintF => self.is_noop = F::one(),
Opcode::PrintE => self.is_noop = F::one(),
Opcode::FRIFold => self.is_fri_fold = F::one(),
}
}
}
Expand Down Expand Up @@ -129,6 +132,7 @@ impl<T: Copy> IntoIterator for &OpcodeSelectorCols<T> {
self.is_jalr,
self.is_trap,
self.is_noop,
self.is_fri_fold,
]
.into_iter()
}
Expand Down
48 changes: 48 additions & 0 deletions recursion/core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,54 @@ where
}
(a, b, c) = (a_val, b_val, c_val);
}
Opcode::FRIFold => {
let a_val = self.mr(self.fp + instruction.op_a, MemoryAccessPosition::A);
let b_val = self.mr(self.fp + instruction.op_b[0], MemoryAccessPosition::B);
let c_val = Block::<F>::default();

let m = a_val[0].as_canonical_u32() as usize;
let input_ptr = b_val[0].as_canonical_u32() as usize;

// Read the input values.
let mut ptr = input_ptr;
let z = self.memory[ptr].value.ext::<EF>();
ptr += 4;
let alpha = self.memory[ptr].value.ext::<EF>();
ptr += 4;
let x = self.memory[ptr].value[0];
ptr += 1;
let log_height = self.memory[ptr].value[0].as_canonical_u32() as usize;
ptr += 1;
let mat_opening_ptr = self.memory[ptr].value[0].as_canonical_u32() as usize;
ptr += 2;
let ps_at_z_ptr = self.memory[ptr].value[0].as_canonical_u32() as usize;
ptr += 2;
let alpha_pow_ptr = self.memory[ptr].value[0].as_canonical_u32() as usize;
ptr += 2;
let ro_ptr = self.memory[ptr].value[0].as_canonical_u32() as usize;

// Get the opening values.
let p_at_x = self.memory[mat_opening_ptr + m * EF::D].value.ext::<EF>();
let p_at_z = self.memory[ps_at_z_ptr + m * EF::D].value.ext::<EF>();

// Calculate the quotient and update the values
let quotient = (-p_at_z + p_at_x) / (-z + x);

// Modify the ro and alpha pow values.
let alpha_pow_at_log_height = self.memory[alpha_pow_ptr + log_height * EF::D]
.value
.ext::<EF>();
let ro_at_log_height =
self.memory[ro_ptr + log_height * EF::D].value.ext::<EF>();

self.memory[ro_ptr + log_height * EF::D].value = Block::from(
(ro_at_log_height + alpha_pow_at_log_height * quotient).as_base_slice(),
);
self.memory[alpha_pow_ptr + log_height * EF::D].value =
Block::from((alpha_pow_at_log_height * alpha).as_base_slice());

(a, b, c) = (a_val, b_val, c_val);
}
};

let event = CpuEvent {
Expand Down
2 changes: 2 additions & 0 deletions recursion/core/src/runtime/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub enum Opcode {
PrintF = 33,
PrintE = 34,
Ext2Felt = 35,

FRIFold = 36,
}

impl Opcode {
Expand Down
2 changes: 1 addition & 1 deletion recursion/program/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pub fn const_fri_config(
builder.set(&mut subgroups, i, domain_value);
}
FriConfigVariable {
log_blowup: Val::from_canonical_usize(config.log_blowup),
log_blowup: config.log_blowup,
num_queries: config.num_queries,
proof_of_work_bits: config.proof_of_work_bits,
subgroups,
Expand Down
11 changes: 5 additions & 6 deletions recursion/program/src/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod domain;
mod two_adic_pcs;

pub use domain::*;

use sp1_recursion_compiler::ir::Array;
use sp1_recursion_compiler::ir::Builder;
use sp1_recursion_compiler::ir::Config;
Expand All @@ -17,9 +18,6 @@ pub use two_adic_pcs::*;
#[cfg(test)]
pub(crate) use two_adic_pcs::tests::*;

// #[cfg(test)]
// pub(crate) use domain::tests::*;

use p3_field::AbstractField;
use p3_field::Field;
use p3_field::TwoAdicField;
Expand Down Expand Up @@ -66,7 +64,8 @@ pub fn verify_shape_and_sample_challenges<C: Config>(
challenger.check_witness(builder, config.proof_of_work_bits, proof.pow_witness);

let num_commit_phase_commits = proof.commit_phase_commits.len().materialize(builder);
let log_max_height: Var<_> = builder.eval(num_commit_phase_commits + config.log_blowup);
let log_max_height: Var<_> =
builder.eval(num_commit_phase_commits + C::N::from_canonical_usize(config.log_blowup));
let mut query_indices = builder.array(config.num_queries);
builder.range(0, config.num_queries).for_each(|i, builder| {
let index_bits = challenger.sample_bits(builder, Usize::Var(log_max_height));
Expand All @@ -83,7 +82,6 @@ pub fn verify_shape_and_sample_challenges<C: Config>(
///
/// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/fri/src/verifier.rs#L67
#[allow(clippy::type_complexity)]
#[allow(unused_variables)]
pub fn verify_challenges<C: Config>(
builder: &mut Builder<C>,
config: &FriConfigVariable<C>,
Expand All @@ -95,7 +93,8 @@ pub fn verify_challenges<C: Config>(
C::EF: TwoAdicField,
{
let nb_commit_phase_commits = proof.commit_phase_commits.len().materialize(builder);
let log_max_height = builder.eval(nb_commit_phase_commits + config.log_blowup);
let log_max_height =
builder.eval(nb_commit_phase_commits + C::N::from_canonical_usize(config.log_blowup));
builder
.range(0, challenges.query_indices.len())
.for_each(|i, builder| {
Expand Down
64 changes: 28 additions & 36 deletions recursion/program/src/fri/two_adic_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ pub fn verify_two_adic_pcs<C: Config>(
C::F: TwoAdicField,
C::EF: TwoAdicField,
{
let log_blowup = C::N::from_canonical_usize(config.log_blowup);
let blowup = C::N::from_canonical_usize(1 << config.log_blowup);
let alpha = challenger.sample_ext(builder);

let fri_challenges =
Expand All @@ -66,24 +68,24 @@ pub fn verify_two_adic_pcs<C: Config>(
.commit_phase_commits
.len()
.materialize(builder);
let log_global_max_height: Var<_> = builder.eval(commit_phase_commits_len + config.log_blowup);
let log_global_max_height: Var<_> = builder.eval(commit_phase_commits_len + log_blowup);

let mut reduced_openings: Array<C, Array<C, Ext<C::F, C::EF>>> =
builder.array(proof.query_openings.len());

builder
.range(0, proof.query_openings.len())
.for_each(|i, builder| {
let query_opening = builder.get(&proof.query_openings, i);
let index_bits = builder.get(&fri_challenges.query_indices, i);

let mut ro: Array<C, Ext<C::F, C::EF>> = builder.array(32);
let zero: Ext<C::F, C::EF> = builder.eval(SymbolicExt::Const(C::EF::zero()));
let mut alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(32);
for j in 0..32 {
builder.set(&mut ro, j, zero);
builder.set(&mut ro, j, C::EF::zero().cons());
}
let mut alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(32);
let one: Ext<C::F, C::EF> = builder.eval(SymbolicExt::Const(C::EF::one()));
for j in 0..32 {
builder.set(&mut alpha_pow, j, one);
builder.set(&mut alpha_pow, j, C::EF::one().cons());
}

builder.range(0, rounds.len()).for_each(|j, builder| {
Expand All @@ -95,14 +97,14 @@ pub fn verify_two_adic_pcs<C: Config>(
let mut batch_heights_log2: Array<C, Var<C::N>> = builder.array(mats.len());
builder.range(0, mats.len()).for_each(|k, builder| {
let mat = builder.get(&mats, k);
let height_log2: Var<_> = builder.eval(mat.domain.log_n + config.log_blowup);
let height_log2: Var<_> = builder.eval(mat.domain.log_n + log_blowup);
builder.set(&mut batch_heights_log2, k, height_log2);
});
let mut batch_dims: Array<C, Dimensions<C>> = builder.array(mats.len());
builder.range(0, mats.len()).for_each(|k, builder| {
let mat = builder.get(&mats, k);
let dim = Dimensions::<C> {
height: builder.eval(mat.domain.size() * C::N::two()), // TODO: fix this to use blowup
height: builder.eval(mat.domain.size() * blowup), // TODO: fix this to use blowup
};
builder.set(&mut batch_dims, k, dim);
});
Expand All @@ -129,8 +131,7 @@ pub fn verify_two_adic_pcs<C: Config>(
let mat_values = mat.values;

let log2_domain_size = mat.domain.log_n;
let log_height: Var<C::N> =
builder.eval(log2_domain_size + config.log_blowup);
let log_height: Var<C::N> = builder.eval(log2_domain_size + log_blowup);

let bits_reduced: Var<C::N> =
builder.eval(log_global_max_height - log_height);
Expand All @@ -148,32 +149,23 @@ pub fn verify_two_adic_pcs<C: Config>(
builder.range(0, mat_points.len()).for_each(|l, builder| {
let z: Ext<C::F, C::EF> = builder.get(&mat_points, l);
let ps_at_z = builder.get(&mat_values, l);

let input = FriFoldInput {
z,
alpha,
x,
log_height,
mat_opening: mat_opening.clone(),
ps_at_z: ps_at_z.clone(),
alpha_pow: alpha_pow.clone(),
ro: ro.clone(),
};

let mut input_ptr = builder.array::<FriFoldInput<_>>(1);
builder.set(&mut input_ptr, 0, input);

builder.range(0, ps_at_z.len()).for_each(|m, builder| {
let p_at_x: SymbolicExt<C::F, C::EF> =
builder.get(&mat_opening, m).into();
let p_at_z: SymbolicExt<C::F, C::EF> =
builder.get(&ps_at_z, m).into();

let quotient: SymbolicExt<C::F, C::EF> =
(-p_at_z + p_at_x) / (-z + x);
// let quotient = builder.eval(quotient);
// builder.print_e(quotient);

let ro_at_log_height = builder.get(&ro, log_height);
// builder.print_e(ro_at_log_height);
let alpha_pow_at_log_height = builder.get(&alpha_pow, log_height);
// builder.print_e(alpha_pow_at_log_height);

builder.set(
&mut ro,
log_height,
ro_at_log_height + alpha_pow_at_log_height * quotient,
);
builder.set(
&mut alpha_pow,
log_height,
alpha_pow_at_log_height * alpha,
);
builder.push(DslIR::FriFold(m, input_ptr.clone()));
});
});
});
Expand Down Expand Up @@ -377,7 +369,7 @@ pub(crate) mod tests {
builder.set(&mut subgroups, i, domain_value);
}
FriConfigVariable {
log_blowup: Val::from_canonical_usize(config.log_blowup),
log_blowup: config.log_blowup,
num_queries: config.num_queries,
proof_of_work_bits: config.proof_of_work_bits,
subgroups,
Expand Down
2 changes: 1 addition & 1 deletion recursion/program/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub type Commitment<C: Config> = Array<C, Felt<C::F>>;
/// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/fri/src/config.rs#L1
#[derive(Clone)]
pub struct FriConfigVariable<C: Config> {
pub log_blowup: C::N,
pub log_blowup: usize,
pub num_queries: usize,
pub proof_of_work_bits: usize,
pub generators: Array<C, Felt<C::F>>,
Expand Down

0 comments on commit f6d6fd8

Please sign in to comment.