Skip to content

Commit

Permalink
hm broken
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas committed Apr 7, 2024
2 parents e66dc7a + f6d6fd8 commit d72301a
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 38 deletions.
7 changes: 7 additions & 0 deletions recursion/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,13 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
Array::Dyn(dst, _) => self.push(AsmInstruction::Hint(dst.fp())),
_ => unimplemented!(),
},
DslIR::FriFold(m, input_ptr) => {
if let Array::Dyn(ptr, _) = input_ptr {
self.push(AsmInstruction::FriFold(m.fp(), ptr.fp()));
} else {
unimplemented!();
}
}
_ => unimplemented!(),
}
}
Expand Down
15 changes: 15 additions & 0 deletions recursion/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ pub enum AsmInstruction<F, EF> {

HintLen(i32),
Hint(i32),
// FRIFold(m, input) specific instructions.
FriFold(i32, i32),
}

impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
Expand Down Expand Up @@ -864,6 +866,16 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
false,
true,
),
AsmInstruction::FriFold(m, ptr) => Instruction::new(
Opcode::FRIFold,
i32_f(m),
i32_f_arr(ptr),
f_u32(F::zero()),
F::zero(),
F::zero(),
false,
true,
),
}
}

Expand Down Expand Up @@ -1148,6 +1160,9 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
AsmInstruction::Ext2Felt(dst, src) => write!(f, "ext2felt ({})fp, {})fp", dst, src),
AsmInstruction::HintLen(dst) => write!(f, "hint_len ({})fp", dst),
AsmInstruction::Hint(dst) => write!(f, "hint ({})fp", dst),
AsmInstruction::FriFold(m, input_ptr) => {
write!(f, "fri_fold ({})fp, ({})fp", m, input_ptr)
}
}
}
}
1 change: 1 addition & 0 deletions recursion/compiler/src/ir/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ impl<C: Config, V: MemVariable<C>> Array<C, V> {
todo!()
}
Self::Dyn(ptr, len) => {
assert!(V::size_of() == 1, "only support variables of size 1");
let new_address = builder.eval(ptr.address + shift);
let new_ptr = Ptr::<C::N> {
address: new_address,
Expand Down
14 changes: 14 additions & 0 deletions recursion/compiler/src/ir/fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use crate::ir::{Array, Config};
use crate::prelude::*;

#[derive(DslVariable, Debug, Clone)]
pub struct FriFoldInput<C: Config> {
pub z: Ext<C::F, C::EF>,
pub alpha: Ext<C::F, C::EF>,
pub x: Felt<C::F>,
pub log_height: Var<C::N>,
pub mat_opening: Array<C, Ext<C::F, C::EF>>,
pub ps_at_z: Array<C, Ext<C::F, C::EF>>,
pub alpha_pow: Array<C, Ext<C::F, C::EF>>,
pub ro: Array<C, Ext<C::F, C::EF>>,
}
4 changes: 3 additions & 1 deletion recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Array, MemIndex, Ptr};
use super::{Array, FriFoldInput, MemIndex, Ptr};

use super::{Config, Ext, Felt, Usize, Var};

Expand Down Expand Up @@ -113,6 +113,8 @@ pub enum DslIR<C: Config> {
HintVars(Array<C, Var<C::N>>),
HintFelts(Array<C, Felt<C::F>>),
HintExts(Array<C, Ext<C::F, C::EF>>),
// FRI specific instructions.
FriFold(Var<C::N>, Array<C, FriFoldInput<C>>),

// Circuit-specific instructions.
CircuitPoseidon2Permute([Var<C::N>; 3]),
Expand Down
2 changes: 2 additions & 0 deletions recursion/compiler/src/ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use p3_field::{ExtensionField, PrimeField, TwoAdicField};

mod builder;
mod collections;
mod fold;
mod instructions;
mod ptr;
mod symbolic;
Expand All @@ -11,6 +12,7 @@ mod var;

pub use builder::*;
pub use collections::*;
pub use fold::*;
pub use instructions::*;
pub use ptr::*;
pub use symbolic::*;
Expand Down
4 changes: 4 additions & 0 deletions recursion/core/src/cpu/columns/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub struct OpcodeSelectorCols<T> {
// System instructions.
pub is_trap: T,
pub is_noop: T,

pub is_fri_fold: T,
}

impl<F: Field> OpcodeSelectorCols<F> {
Expand Down Expand Up @@ -92,6 +94,7 @@ impl<F: Field> OpcodeSelectorCols<F> {
Opcode::HintBits => self.is_noop = F::one(),
Opcode::PrintF => self.is_noop = F::one(),
Opcode::PrintE => self.is_noop = F::one(),
Opcode::FRIFold => self.is_fri_fold = F::one(),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -130,6 +133,7 @@ impl<T: Copy> IntoIterator for &OpcodeSelectorCols<T> {
self.is_jalr,
self.is_trap,
self.is_noop,
self.is_fri_fold,
]
.into_iter()
}
Expand Down
48 changes: 48 additions & 0 deletions recursion/core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,54 @@ where
}
(a, b, c) = (a_val, b_val, c_val);
}
Opcode::FRIFold => {
let a_val = self.mr(self.fp + instruction.op_a, MemoryAccessPosition::A);
let b_val = self.mr(self.fp + instruction.op_b[0], MemoryAccessPosition::B);
let c_val = Block::<F>::default();

let m = a_val[0].as_canonical_u32() as usize;
let input_ptr = b_val[0].as_canonical_u32() as usize;

// Read the input values.
let mut ptr = input_ptr;
let z = self.memory[ptr].value.ext::<EF>();
ptr += 4;
let alpha = self.memory[ptr].value.ext::<EF>();
ptr += 4;
let x = self.memory[ptr].value[0];
ptr += 1;
let log_height = self.memory[ptr].value[0].as_canonical_u32() as usize;
ptr += 1;
let mat_opening_ptr = self.memory[ptr].value[0].as_canonical_u32() as usize;
ptr += 2;
let ps_at_z_ptr = self.memory[ptr].value[0].as_canonical_u32() as usize;
ptr += 2;
let alpha_pow_ptr = self.memory[ptr].value[0].as_canonical_u32() as usize;
ptr += 2;
let ro_ptr = self.memory[ptr].value[0].as_canonical_u32() as usize;

// Get the opening values.
let p_at_x = self.memory[mat_opening_ptr + m * EF::D].value.ext::<EF>();
let p_at_z = self.memory[ps_at_z_ptr + m * EF::D].value.ext::<EF>();

// Calculate the quotient and update the values
let quotient = (-p_at_z + p_at_x) / (-z + x);

// Modify the ro and alpha pow values.
let alpha_pow_at_log_height = self.memory[alpha_pow_ptr + log_height * EF::D]
.value
.ext::<EF>();
let ro_at_log_height =
self.memory[ro_ptr + log_height * EF::D].value.ext::<EF>();

self.memory[ro_ptr + log_height * EF::D].value = Block::from(
(ro_at_log_height + alpha_pow_at_log_height * quotient).as_base_slice(),
);
self.memory[alpha_pow_ptr + log_height * EF::D].value =
Block::from((alpha_pow_at_log_height * alpha).as_base_slice());

(a, b, c) = (a_val, b_val, c_val);
}
};

let event = CpuEvent {
Expand Down
1 change: 1 addition & 0 deletions recursion/core/src/runtime/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub enum Opcode {

HintLen = 37,
Hint = 38,
FRIFold = 36,
}

impl Opcode {
Expand Down
2 changes: 1 addition & 1 deletion recursion/program/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub fn const_fri_config(
builder.set(&mut subgroups, i, domain_value);
}
FriConfigVariable {
log_blowup: Val::from_canonical_usize(config.log_blowup),
log_blowup: config.log_blowup,
num_queries: config.num_queries,
proof_of_work_bits: config.proof_of_work_bits,
subgroups,
Expand Down
8 changes: 5 additions & 3 deletions recursion/program/src/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub use two_adic_pcs::*;
use p3_field::AbstractField;
use p3_field::Field;
use p3_field::TwoAdicField;

use sp1_recursion_compiler::ir::Array;
use sp1_recursion_compiler::ir::Builder;
use sp1_recursion_compiler::ir::Config;
Expand Down Expand Up @@ -65,7 +66,8 @@ pub fn verify_shape_and_sample_challenges<C: Config>(
challenger.check_witness(builder, config.proof_of_work_bits, proof.pow_witness);

let num_commit_phase_commits = proof.commit_phase_commits.len().materialize(builder);
let log_max_height: Var<_> = builder.eval(num_commit_phase_commits + config.log_blowup);
let log_max_height: Var<_> =
builder.eval(num_commit_phase_commits + C::N::from_canonical_usize(config.log_blowup));
let mut query_indices = builder.array(config.num_queries);
builder.range(0, config.num_queries).for_each(|i, builder| {
let index_bits = challenger.sample_bits(builder, Usize::Var(log_max_height));
Expand All @@ -82,7 +84,6 @@ pub fn verify_shape_and_sample_challenges<C: Config>(
///
/// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/fri/src/verifier.rs#L67
#[allow(clippy::type_complexity)]
#[allow(unused_variables)]
pub fn verify_challenges<C: Config>(
builder: &mut Builder<C>,
config: &FriConfigVariable<C>,
Expand All @@ -94,7 +95,8 @@ pub fn verify_challenges<C: Config>(
C::EF: TwoAdicField,
{
let nb_commit_phase_commits = proof.commit_phase_commits.len().materialize(builder);
let log_max_height = builder.eval(nb_commit_phase_commits + config.log_blowup);
let log_max_height =
builder.eval(nb_commit_phase_commits + C::N::from_canonical_usize(config.log_blowup));
builder
.range(0, challenges.query_indices.len())
.for_each(|i, builder| {
Expand Down
60 changes: 28 additions & 32 deletions recursion/program/src/fri/two_adic_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub fn verify_two_adic_pcs<C: Config>(
C::F: TwoAdicField,
C::EF: TwoAdicField,
{
let log_blowup = C::N::from_canonical_usize(config.log_blowup);
let blowup = C::N::from_canonical_usize(1 << config.log_blowup);
let alpha = challenger.sample_ext(builder);

let fri_challenges =
Expand All @@ -40,24 +42,24 @@ pub fn verify_two_adic_pcs<C: Config>(
.commit_phase_commits
.len()
.materialize(builder);
let log_global_max_height: Var<_> = builder.eval(commit_phase_commits_len + config.log_blowup);
let log_global_max_height: Var<_> = builder.eval(commit_phase_commits_len + log_blowup);

let mut reduced_openings: Array<C, Array<C, Ext<C::F, C::EF>>> =
builder.array(proof.query_openings.len());

builder
.range(0, proof.query_openings.len())
.for_each(|i, builder| {
let query_opening = builder.get(&proof.query_openings, i);
let index_bits = builder.get(&fri_challenges.query_indices, i);

let mut ro: Array<C, Ext<C::F, C::EF>> = builder.array(32);
let zero: Ext<C::F, C::EF> = builder.eval(SymbolicExt::Const(C::EF::zero()));
let mut alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(32);
for j in 0..32 {
builder.set(&mut ro, j, zero);
builder.set(&mut ro, j, C::EF::zero().cons());
}
let mut alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(32);
let one: Ext<C::F, C::EF> = builder.eval(SymbolicExt::Const(C::EF::one()));
for j in 0..32 {
builder.set(&mut alpha_pow, j, one);
builder.set(&mut alpha_pow, j, C::EF::one().cons());
}

builder.range(0, rounds.len()).for_each(|j, builder| {
Expand All @@ -69,14 +71,14 @@ pub fn verify_two_adic_pcs<C: Config>(
let mut batch_heights_log2: Array<C, Var<C::N>> = builder.array(mats.len());
builder.range(0, mats.len()).for_each(|k, builder| {
let mat = builder.get(&mats, k);
let height_log2: Var<_> = builder.eval(mat.domain.log_n + config.log_blowup);
let height_log2: Var<_> = builder.eval(mat.domain.log_n + log_blowup);
builder.set(&mut batch_heights_log2, k, height_log2);
});
let mut batch_dims: Array<C, DimensionsVariable<C>> = builder.array(mats.len());
builder.range(0, mats.len()).for_each(|k, builder| {
let mat = builder.get(&mats, k);
let dim = DimensionsVariable::<C> {
height: builder.eval(mat.domain.size() * C::N::two()), // TODO: fix this to use blowup
height: builder.eval(mat.domain.size() * blowup), // TODO: fix this to use blowup
};
builder.set(&mut batch_dims, k, dim);
});
Expand All @@ -103,8 +105,7 @@ pub fn verify_two_adic_pcs<C: Config>(
let mat_values = mat.values;

let log2_domain_size = mat.domain.log_n;
let log_height: Var<C::N> =
builder.eval(log2_domain_size + config.log_blowup);
let log_height: Var<C::N> = builder.eval(log2_domain_size + log_blowup);

let bits_reduced: Var<C::N> =
builder.eval(log_global_max_height - log_height);
Expand All @@ -122,28 +123,23 @@ pub fn verify_two_adic_pcs<C: Config>(
builder.range(0, mat_points.len()).for_each(|l, builder| {
let z: Ext<C::F, C::EF> = builder.get(&mat_points, l);
let ps_at_z = builder.get(&mat_values, l);

let input = FriFoldInput {
z,
alpha,
x,
log_height,
mat_opening: mat_opening.clone(),
ps_at_z: ps_at_z.clone(),
alpha_pow: alpha_pow.clone(),
ro: ro.clone(),
};

let mut input_ptr = builder.array::<FriFoldInput<_>>(1);
builder.set(&mut input_ptr, 0, input);

builder.range(0, ps_at_z.len()).for_each(|m, builder| {
let p_at_x: SymbolicExt<C::F, C::EF> =
builder.get(&mat_opening, m).into();
let p_at_z: SymbolicExt<C::F, C::EF> =
builder.get(&ps_at_z, m).into();

let quotient: SymbolicExt<C::F, C::EF> =
(-p_at_z + p_at_x) / (-z + x);

let ro_at_log_height = builder.get(&ro, log_height);
let alpha_pow_at_log_height = builder.get(&alpha_pow, log_height);

builder.set(
&mut ro,
log_height,
ro_at_log_height + alpha_pow_at_log_height * quotient,
);
builder.set(
&mut alpha_pow,
log_height,
alpha_pow_at_log_height * alpha,
);
builder.push(DslIR::FriFold(m, input_ptr.clone()));
});
});
});
Expand Down Expand Up @@ -319,7 +315,7 @@ pub(crate) mod tests {
builder.set(&mut subgroups, i, domain_value);
}
FriConfigVariable {
log_blowup: InnerVal::from_canonical_usize(config.log_blowup),
log_blowup: config.log_blowup,
num_queries: config.num_queries,
proof_of_work_bits: config.proof_of_work_bits,
subgroups,
Expand Down
2 changes: 1 addition & 1 deletion recursion/program/src/fri/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub type DigestVariable<C: Config> = Array<C, Felt<C::F>>;

#[derive(Clone)]
pub struct FriConfigVariable<C: Config> {
pub log_blowup: C::N,
pub log_blowup: usize,
pub num_queries: usize,
pub proof_of_work_bits: usize,
pub generators: Array<C, Felt<C::F>>,
Expand Down
1 change: 1 addition & 0 deletions recursion/program/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::fri::TwoAdicMultiplicativeCosetVariable;

pub type PublicValuesDigestVariable<C: Config> = Array<C, Felt<C::F>>;

/// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/fri/src/proof.rs#L12
#[derive(DslVariable, Clone)]
pub struct ShardProofVariable<C: Config> {
pub index: Var<C::N>,
Expand Down

0 comments on commit d72301a

Please sign in to comment.