Skip to content

Commit

Permalink
feat: add support for witness in programs (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored Apr 7, 2024
1 parent f6d6fd8 commit 72cb1ba
Show file tree
Hide file tree
Showing 29 changed files with 1,311 additions and 800 deletions.
2 changes: 1 addition & 1 deletion core/src/air/public_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub const PV_DIGEST_NUM_WORDS: usize = 8;
/// The PublicValuesDigest struct is used to represent the public values digest. This is the hash of all the
/// bytes that the guest program has written to public values.
#[derive(Serialize, Deserialize, Clone, Copy, Default, Debug)]
pub struct PublicValuesDigest<T>([T; PV_DIGEST_NUM_WORDS]);
pub struct PublicValuesDigest<T>(pub [T; PV_DIGEST_NUM_WORDS]);

/// Conversion from a byte array into a PublicValuesDigest<u32>.
impl From<&[u8]> for PublicValuesDigest<u32> {
Expand Down
8 changes: 8 additions & 0 deletions core/src/stark/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::lookup::InteractionKind;
use crate::stark::record::MachineRecord;
use crate::stark::DebugConstraintBuilder;
use crate::stark::ProverConstraintFolder;
use crate::stark::ShardProof;
use crate::stark::VerifierConstraintFolder;

use p3_air::Air;
Expand Down Expand Up @@ -94,6 +95,13 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
.sorted_by_key(|chip| chip_ordering.get(&chip.name()))
}

pub fn chips_sorted_indices(&self, proof: &ShardProof<SC>) -> Vec<Option<usize>> {
self.chips()
.iter()
.map(|chip| proof.chip_ordering.get(&chip.name()).cloned())
.collect()
}

/// The setup preprocessing phase.
///
/// Given a program, this function generates the proving and verifying keys. The keys correspond
Expand Down
3 changes: 2 additions & 1 deletion prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ pub fn prove_sp1() -> (Proof<InnerSC>, VerifyingKey<InnerSC>) {
}

pub fn prove_compress(sp1_proof: Proof<InnerSC>, vk: VerifyingKey<InnerSC>) {
let program = build_compress(sp1_proof, vk);
let (program, witness_stream) = build_compress(sp1_proof, vk);

let config = InnerSC::default();
let machine = InnerA::machine(config);
let mut runtime = Runtime::<InnerF, InnerEF, _>::new(&program, machine.config().perm.clone());
runtime.witness_stream = witness_stream;

let time = Instant::now();
runtime.run();
Expand Down
2 changes: 1 addition & 1 deletion recursion/circuit/src/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use p3_field::Field;
use p3_field::{AbstractField, TwoAdicField};
use sp1_recursion_compiler::prelude::*;
use sp1_recursion_program::commit::PolynomialSpaceVariable;
use sp1_recursion_program::types::FriConfigVariable;
use sp1_recursion_program::fri::types::FriConfigVariable;

#[derive(Clone, Copy)]
pub struct TwoAdicMultiplicativeCosetVariable<C: Config> {
Expand Down
24 changes: 24 additions & 0 deletions recursion/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,13 +524,37 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
}
_ => unimplemented!(),
},
DslIR::HintLen(dst) => self.push(AsmInstruction::HintLen(dst.fp())),
DslIR::HintVars(dst) => match dst {
Array::Dyn(dst, _) => self.push(AsmInstruction::Hint(dst.fp())),
_ => unimplemented!(),
},
DslIR::HintFelts(dst) => match dst {
Array::Dyn(dst, _) => self.push(AsmInstruction::Hint(dst.fp())),
_ => unimplemented!(),
},
DslIR::HintExts(dst) => match dst {
Array::Dyn(dst, _) => self.push(AsmInstruction::Hint(dst.fp())),
_ => unimplemented!(),
},
DslIR::FriFold(m, input_ptr) => {
if let Array::Dyn(ptr, _) = input_ptr {
self.push(AsmInstruction::FriFold(m.fp(), ptr.fp()));
} else {
unimplemented!();
}
}
DslIR::Poseidon2CompressBabyBear(result, left, right) => {
match (result, left, right) {
(Array::Dyn(result, _), Array::Dyn(left, _), Array::Dyn(right, _)) => self
.push(AsmInstruction::Poseidon2Compress(
result.fp(),
left.fp(),
right.fp(),
)),
_ => unimplemented!(),
}
}
_ => unimplemented!(),
}
}
Expand Down
42 changes: 42 additions & 0 deletions recursion/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,15 @@ pub enum AsmInstruction<F, EF> {

/// Perform a permutation of the Poseidon2 hash function on the array specified by the ptr.
Poseidon2Permute(i32, i32),
Poseidon2Compress(i32, i32, i32),

PrintV(i32),
PrintF(i32),
PrintE(i32),
Ext2Felt(i32, i32),

HintLen(i32),
Hint(i32),
// FRIFold(m, input) specific instructions.
FriFold(i32, i32),
}
Expand Down Expand Up @@ -844,6 +847,26 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
false,
true,
),
AsmInstruction::HintLen(dst) => Instruction::new(
Opcode::HintLen,
i32_f(dst),
i32_f_arr(dst),
f_u32(F::zero()),
F::zero(),
F::zero(),
false,
true,
),
AsmInstruction::Hint(dst) => Instruction::new(
Opcode::Hint,
i32_f(dst),
i32_f_arr(dst),
f_u32(F::zero()),
F::zero(),
F::zero(),
false,
true,
),
AsmInstruction::FriFold(m, ptr) => Instruction::new(
Opcode::FRIFold,
i32_f(m),
Expand All @@ -854,6 +877,16 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
false,
true,
),
AsmInstruction::Poseidon2Compress(result, src1, src2) => Instruction::new(
Opcode::Poseidon2Compress,
i32_f(result),
i32_f_arr(src1),
i32_f_arr(src2),
F::zero(),
F::zero(),
false,
false,
),
}
}

Expand Down Expand Up @@ -1136,9 +1169,18 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
write!(f, "print_e ({})fp", dst)
}
AsmInstruction::Ext2Felt(dst, src) => write!(f, "ext2felt ({})fp, {})fp", dst, src),
AsmInstruction::HintLen(dst) => write!(f, "hint_len ({})fp", dst),
AsmInstruction::Hint(dst) => write!(f, "hint ({})fp", dst),
AsmInstruction::FriFold(m, input_ptr) => {
write!(f, "fri_fold ({})fp, ({})fp", m, input_ptr)
}
AsmInstruction::Poseidon2Compress(result, src1, src2) => {
write!(
f,
"poseidon2_compress ({})fp, {})fp, {})fp",
result, src1, src2
)
}
}
}
}
74 changes: 67 additions & 7 deletions recursion/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,22 @@ impl<C: Config> Builder<C> {
));
}

/// Applies the Poseidon2 permutation to the given array.
///
/// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/poseidon2/src/lib.rs#L119
pub fn poseidon2_compress_x(
&mut self,
result: &mut Array<C, Felt<C::F>>,
left: &Array<C, Felt<C::F>>,
right: &Array<C, Felt<C::F>>,
) {
self.operations.push(DslIR::Poseidon2CompressBabyBear(
result.clone(),
left.clone(),
right.clone(),
));
}

/// Applies the Poseidon2 permutation to the given array.
///
/// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/poseidon2/src/lib.rs#L119
Expand Down Expand Up @@ -394,13 +410,8 @@ impl<C: Config> Builder<C> {
builder.poseidon2_permute_mut(&state);
});

let mut result = self.dyn_array(DIGEST_SIZE);
for i in 0..DIGEST_SIZE {
let el = self.get(&state, i);
self.set(&mut result, i, el);
}

result
state.truncate(self, Usize::Const(DIGEST_SIZE));
state
}

/// Applies the Poseidon2 compression function to the given array.
Expand Down Expand Up @@ -649,6 +660,55 @@ impl<C: Config> Builder<C> {
pub fn power_of_two_expr(&mut self, power: Usize<C::N>) -> Ext<C::F, C::EF> {
self.sll(C::EF::one().cons(), power)
}

pub fn hint_len(&mut self) -> Var<C::N> {
let len = self.uninit();
self.operations.push(DslIR::HintLen(len));
len
}

pub fn hint_var(&mut self) -> Var<C::N> {
let len = self.hint_len();
let arr = self.dyn_array(len);
self.operations.push(DslIR::HintVars(arr.clone()));
self.get(&arr, 0)
}

pub fn hint_felt(&mut self) -> Felt<C::F> {
let len = self.hint_len();
let arr = self.dyn_array(len);
self.operations.push(DslIR::HintFelts(arr.clone()));
self.get(&arr, 0)
}

pub fn hint_ext(&mut self) -> Ext<C::F, C::EF> {
let len = self.hint_len();
let arr = self.dyn_array(len);
self.operations.push(DslIR::HintExts(arr.clone()));
self.get(&arr, 0)
}

pub fn hint_vars(&mut self) -> Array<C, Var<C::N>> {
let len = self.hint_len();
self.print_v(len);
let arr = self.dyn_array(len);
self.operations.push(DslIR::HintVars(arr.clone()));
arr
}

pub fn hint_felts(&mut self) -> Array<C, Felt<C::F>> {
let len = self.hint_len();
let arr = self.dyn_array(len);
self.operations.push(DslIR::HintFelts(arr.clone()));
arr
}

pub fn hint_exts(&mut self) -> Array<C, Ext<C::F, C::EF>> {
let len = self.hint_len();
let arr = self.dyn_array(len);
self.operations.push(DslIR::HintExts(arr.clone()));
arr
}
}

pub struct IfBuilder<'a, C: Config> {
Expand Down
9 changes: 9 additions & 0 deletions recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,21 @@ pub enum DslIR<C: Config> {
HintBitsV(Array<C, Var<C::N>>, Var<C::N>),
HintBitsF(Array<C, Var<C::N>>, Felt<C::F>),
Poseidon2PermuteBabyBear(Array<C, Felt<C::F>>, Array<C, Felt<C::F>>),
Poseidon2CompressBabyBear(
Array<C, Felt<C::F>>,
Array<C, Felt<C::F>>,
Array<C, Felt<C::F>>,
),
TwoAdicGenerator(Felt<C::F>, Usize<C::N>),
ReverseBitsLen(Usize<C::N>, Usize<C::N>, Usize<C::N>),
ExpUsizeV(Var<C::N>, Var<C::N>, Usize<C::N>),
ExpUsizeF(Felt<C::F>, Felt<C::F>, Usize<C::N>),
Ext2Felt(Array<C, Felt<C::F>>, Ext<C::F, C::EF>),

HintLen(Var<C::N>),
HintVars(Array<C, Var<C::N>>),
HintFelts(Array<C, Felt<C::F>>),
HintExts(Array<C, Ext<C::F, C::EF>>),
// FRI specific instructions.
FriFold(Var<C::N>, Array<C, FriFoldInput<C>>),

Expand Down
2 changes: 1 addition & 1 deletion recursion/compiler/src/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,7 @@ impl<C: Config> Variable<C> for Ext<C::F, C::EF> {

impl<C: Config> MemVariable<C> for Ext<C::F, C::EF> {
fn size_of() -> usize {
4
1
}

fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
Expand Down
11 changes: 3 additions & 8 deletions recursion/compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#![allow(clippy::needless_range_loop)]
#![allow(clippy::type_complexity)]

use asm::AsmConfig;
use p3_baby_bear::BabyBear;
use p3_bn254_fr::Bn254Fr;
use p3_field::extension::BinomialExtensionField;
use prelude::Config;
use sp1_recursion_core::stark::config::{InnerChallenge, InnerVal};
extern crate alloc;

pub mod asm;
Expand All @@ -19,14 +21,7 @@ pub mod prelude {
pub use sp1_recursion_derive::DslVariable;
}

#[derive(Clone, Default, Debug)]
pub struct InnerConfig;

impl Config for InnerConfig {
type N = BabyBear;
type F = BabyBear;
type EF = BinomialExtensionField<BabyBear, 4>;
}
pub type InnerConfig = AsmConfig<InnerVal, InnerChallenge>;

#[derive(Clone, Default, Debug)]
pub struct OuterConfig;
Expand Down
42 changes: 42 additions & 0 deletions recursion/compiler/tests/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use p3_field::AbstractField;
use sp1_core::stark::StarkGenericConfig;
use sp1_core::utils::BabyBearPoseidon2;
use sp1_recursion_compiler::asm::VmBuilder;
use sp1_recursion_core::runtime::Runtime;

#[test]
fn test_io() {
type SC = BabyBearPoseidon2;
type F = <SC as StarkGenericConfig>::Val;
type EF = <SC as StarkGenericConfig>::Challenge;
let mut builder = VmBuilder::<F, EF>::default();

let arr = builder.hint_vars();
builder.range(0, arr.len()).for_each(|i, builder| {
let el = builder.get(&arr, i);
builder.print_v(el);
});

let arr = builder.hint_felts();
builder.range(0, arr.len()).for_each(|i, builder| {
let el = builder.get(&arr, i);
builder.print_f(el);
});

let arr = builder.hint_exts();
builder.range(0, arr.len()).for_each(|i, builder| {
let el = builder.get(&arr, i);
builder.print_e(el);
});

let program = builder.compile();

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.witness_stream = vec![
vec![F::zero().into(), F::zero().into(), F::one().into()],
vec![F::zero().into(), F::zero().into(), F::two().into()],
vec![F::one().into(), F::one().into(), F::two().into()],
];
runtime.run();
}
1 change: 1 addition & 0 deletions recursion/core/src/cpu/columns/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl<F: Field> OpcodeSelectorCols<F> {
Opcode::PrintF => self.is_noop = F::one(),
Opcode::PrintE => self.is_noop = F::one(),
Opcode::FRIFold => self.is_fri_fold = F::one(),
_ => unreachable!(),
}
}
}
Expand Down
Loading

0 comments on commit 72cb1ba

Please sign in to comment.