Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into chris/sdk-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ctian1 committed Apr 24, 2024
2 parents ed51684 + b4ae919 commit b009149
Show file tree
Hide file tree
Showing 13 changed files with 700 additions and 525 deletions.
4 changes: 2 additions & 2 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ tempfile = "3.9.0"
tracing = "0.1.40"
tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] }
tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] }
strum_macros = "0.26.2"
strum = "0.26.2"
strum_macros = "0.26"
strum = "0.26"
web-time = "1.1.0"
rayon-scan = "0.1.1"

Expand Down
4 changes: 2 additions & 2 deletions recursion/circuit/src/challenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ mod tests {
challenger.observe(&mut builder, c);
let result2 = challenger.sample_ext(&mut builder);

builder.assert_ext_eq(SymbolicExt::Const(gt1), result1);
builder.assert_ext_eq(SymbolicExt::Const(gt2), result2);
builder.assert_ext_eq(SymbolicExt::from_f(gt1), result1);
builder.assert_ext_eq(SymbolicExt::from_f(gt2), result2);

let mut backend = ConstraintCompiler::<OuterConfig>::default();
let constraints = backend.emit(builder.operations);
Expand Down
8 changes: 4 additions & 4 deletions recursion/circuit/src/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use sp1_core::stark::AirOpenedValues;
use sp1_core::stark::PROOF_MAX_NUM_PVS;
use sp1_core::stark::{MachineChip, StarkGenericConfig};
use sp1_recursion_compiler::ir::Array;
use sp1_recursion_compiler::ir::ExtensionOperand;
use sp1_recursion_compiler::ir::Felt;
use sp1_recursion_compiler::ir::SymbolicFelt;
use sp1_recursion_compiler::ir::{Builder, Config, Ext};
use sp1_recursion_compiler::prelude::SymbolicExt;
use sp1_recursion_program::commit::PolynomialSpaceVariable;
Expand Down Expand Up @@ -45,7 +45,7 @@ where
.iter()
.enumerate()
.map(|(e_i, &x)| {
x * SymbolicExt::<C::F, C::EF>::Const(C::EF::monomial(e_i))
x * SymbolicExt::<C::F, C::EF>::from_f(C::EF::monomial(e_i))
})
.sum::<SymbolicExt<_, _>>(),
)
Expand Down Expand Up @@ -101,8 +101,8 @@ where
// Calculate: other_domain.zp_at_point(zeta)
// * other_domain.zp_at_point(domain.first_point()).inverse()
let first_point = domain.first_point(builder);
let first_point: Ext<_, _> =
builder.eval(SymbolicExt::Base(SymbolicFelt::Val(first_point).into()));
let first_point_ext = first_point.to_operand().symbolic();
let first_point: Ext<_, _> = builder.eval(first_point_ext);
let z = other_domain.zp_at_point(builder, first_point);
other_domain.zp_at_point(builder, zeta) * z.inverse()
})
Expand Down
18 changes: 9 additions & 9 deletions recursion/circuit/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ pub fn verify_two_adic_pcs<C: Config>(
.zip(&fri_challenges.query_indices)
.map(|(query_opening, &index)| {
let mut ro: [Ext<C::F, C::EF>; 32] =
[builder.eval(SymbolicExt::Const(C::EF::zero())); 32];
[builder.eval(SymbolicExt::from_f(C::EF::zero())); 32];
let mut alpha_pow: [Ext<C::F, C::EF>; 32] =
[builder.eval(SymbolicExt::Const(C::EF::one())); 32];
[builder.eval(SymbolicExt::from_f(C::EF::one())); 32];

for (batch_opening, round) in izip!(query_opening.clone(), &rounds) {
let batch_commit = round.batch_commit;
Expand Down Expand Up @@ -174,8 +174,8 @@ pub fn verify_query<C: Config>(
reduced_openings: [Ext<C::F, C::EF>; 32],
log_max_height: usize,
) -> Ext<C::F, C::EF> {
let mut folded_eval: Ext<C::F, C::EF> = builder.eval(SymbolicExt::Const(C::EF::zero()));
let two_adic_generator = builder.eval(SymbolicExt::Const(C::EF::two_adic_generator(
let mut folded_eval: Ext<C::F, C::EF> = builder.eval(SymbolicExt::from_f(C::EF::zero()));
let two_adic_generator = builder.eval(SymbolicExt::from_f(C::EF::two_adic_generator(
log_max_height,
)));
let index_bits = builder.num2bits_v_circuit(index, 256);
Expand Down Expand Up @@ -288,7 +288,7 @@ pub mod tests {
.iter()
.map(|commit_phase_opening| {
let sibling_value =
builder.eval(SymbolicExt::Const(commit_phase_opening.sibling_value));
builder.eval(SymbolicExt::from_f(commit_phase_opening.sibling_value));
let opening_proof = commit_phase_opening
.opening_proof
.iter()
Expand All @@ -313,7 +313,7 @@ pub mod tests {
FriProofVariable {
commit_phase_commits,
query_proofs,
final_poly: builder.eval(SymbolicExt::Const(fri_proof.final_poly)),
final_poly: builder.eval(SymbolicExt::from_f(fri_proof.final_poly)),
pow_witness: builder.eval(fri_proof.pow_witness),
}
}
Expand Down Expand Up @@ -372,14 +372,14 @@ pub mod tests {
for (domain, poly) in os.into_iter() {
let points: Vec<Ext<OuterVal, OuterChallenge>> = poly
.iter()
.map(|(p, _)| builder.eval(SymbolicExt::Const(*p)))
.map(|(p, _)| builder.eval(SymbolicExt::from_f(*p)))
.collect::<Vec<_>>();
let values: Vec<Vec<Ext<OuterVal, OuterChallenge>>> = poly
.iter()
.map(|(_, v)| {
v.clone()
.iter()
.map(|t| builder.eval(SymbolicExt::Const(*t)))
.map(|t| builder.eval(SymbolicExt::from_f(*t)))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -469,7 +469,7 @@ pub mod tests {

for i in 0..fri_challenges_gt.betas.len() {
builder.assert_ext_eq(
SymbolicExt::Const(fri_challenges_gt.betas[i]),
SymbolicExt::from_f(fri_challenges_gt.betas[i]),
fri_challenges.betas[i],
);
}
Expand Down
38 changes: 22 additions & 16 deletions recursion/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,53 +567,59 @@ impl<'a, C: Config> IfBuilder<'a, C> {

fn condition(&mut self) -> IfCondition<C::N> {
match (self.lhs.clone(), self.rhs.clone(), self.is_eq) {
(SymbolicVar::Const(lhs), SymbolicVar::Const(rhs), true) => {
(SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), true) => {
IfCondition::EqConst(lhs, rhs)
}
(SymbolicVar::Const(lhs), SymbolicVar::Const(rhs), false) => {
(SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), false) => {
IfCondition::NeConst(lhs, rhs)
}
(SymbolicVar::Const(lhs), SymbolicVar::Val(rhs), true) => IfCondition::EqI(rhs, lhs),
(SymbolicVar::Const(lhs), SymbolicVar::Val(rhs), false) => IfCondition::NeI(rhs, lhs),
(SymbolicVar::Const(lhs), rhs, true) => {
(SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), true) => {
IfCondition::EqI(rhs, lhs)
}
(SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), false) => {
IfCondition::NeI(rhs, lhs)
}
(SymbolicVar::Const(lhs, _), rhs, true) => {
let rhs: Var<C::N> = self.builder.eval(rhs);
IfCondition::EqI(rhs, lhs)
}
(SymbolicVar::Const(lhs), rhs, false) => {
(SymbolicVar::Const(lhs, _), rhs, false) => {
let rhs: Var<C::N> = self.builder.eval(rhs);
IfCondition::NeI(rhs, lhs)
}
(SymbolicVar::Val(lhs), SymbolicVar::Const(rhs), true) => {
(SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), true) => {
let lhs: Var<C::N> = self.builder.eval(lhs);
IfCondition::EqI(lhs, rhs)
}
(SymbolicVar::Val(lhs), SymbolicVar::Const(rhs), false) => {
(SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), false) => {
let lhs: Var<C::N> = self.builder.eval(lhs);
IfCondition::NeI(lhs, rhs)
}
(lhs, SymbolicVar::Const(rhs), true) => {
(lhs, SymbolicVar::Const(rhs, _), true) => {
let lhs: Var<C::N> = self.builder.eval(lhs);
IfCondition::EqI(lhs, rhs)
}
(lhs, SymbolicVar::Const(rhs), false) => {
(lhs, SymbolicVar::Const(rhs, _), false) => {
let lhs: Var<C::N> = self.builder.eval(lhs);
IfCondition::NeI(lhs, rhs)
}
(SymbolicVar::Val(lhs), SymbolicVar::Val(rhs), true) => IfCondition::Eq(lhs, rhs),
(SymbolicVar::Val(lhs), SymbolicVar::Val(rhs), false) => IfCondition::Ne(lhs, rhs),
(SymbolicVar::Val(lhs), rhs, true) => {
(SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), true) => IfCondition::Eq(lhs, rhs),
(SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), false) => {
IfCondition::Ne(lhs, rhs)
}
(SymbolicVar::Val(lhs, _), rhs, true) => {
let rhs: Var<C::N> = self.builder.eval(rhs);
IfCondition::Eq(lhs, rhs)
}
(SymbolicVar::Val(lhs), rhs, false) => {
(SymbolicVar::Val(lhs, _), rhs, false) => {
let rhs: Var<C::N> = self.builder.eval(rhs);
IfCondition::Ne(lhs, rhs)
}
(lhs, SymbolicVar::Val(rhs), true) => {
(lhs, SymbolicVar::Val(rhs, _), true) => {
let lhs: Var<C::N> = self.builder.eval(lhs);
IfCondition::Eq(lhs, rhs)
}
(lhs, SymbolicVar::Val(rhs), false) => {
(lhs, SymbolicVar::Val(rhs, _), false) => {
let lhs: Var<C::N> = self.builder.eval(lhs);
IfCondition::Ne(lhs, rhs)
}
Expand Down
26 changes: 13 additions & 13 deletions recursion/compiler/src/ir/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct Ptr<N> {
pub address: Var<N>,
}

pub struct SymbolicPtr<N> {
pub struct SymbolicPtr<N: Field> {
pub address: SymbolicVar<N>,
}

Expand Down Expand Up @@ -77,15 +77,15 @@ impl<C: Config> MemVariable<C> for Ptr<C::N> {
}
}

impl<N> From<Ptr<N>> for SymbolicPtr<N> {
impl<N: Field> From<Ptr<N>> for SymbolicPtr<N> {
fn from(ptr: Ptr<N>) -> Self {
SymbolicPtr {
address: SymbolicVar::Val(ptr.address),
address: SymbolicVar::from(ptr.address),
}
}
}

impl<N> Add for Ptr<N> {
impl<N: Field> Add for Ptr<N> {
type Output = SymbolicPtr<N>;

fn add(self, rhs: Self) -> Self::Output {
Expand All @@ -95,7 +95,7 @@ impl<N> Add for Ptr<N> {
}
}

impl<N> Sub for Ptr<N> {
impl<N: Field> Sub for Ptr<N> {
type Output = SymbolicPtr<N>;

fn sub(self, rhs: Self) -> Self::Output {
Expand All @@ -105,7 +105,7 @@ impl<N> Sub for Ptr<N> {
}
}

impl<N> Add for SymbolicPtr<N> {
impl<N: Field> Add for SymbolicPtr<N> {
type Output = Self;

fn add(self, rhs: Self) -> Self {
Expand All @@ -115,7 +115,7 @@ impl<N> Add for SymbolicPtr<N> {
}
}

impl<N> Sub for SymbolicPtr<N> {
impl<N: Field> Sub for SymbolicPtr<N> {
type Output = Self;

fn sub(self, rhs: Self) -> Self {
Expand All @@ -125,7 +125,7 @@ impl<N> Sub for SymbolicPtr<N> {
}
}

impl<N> Add<Ptr<N>> for SymbolicPtr<N> {
impl<N: Field> Add<Ptr<N>> for SymbolicPtr<N> {
type Output = Self;

fn add(self, rhs: Ptr<N>) -> Self {
Expand All @@ -135,7 +135,7 @@ impl<N> Add<Ptr<N>> for SymbolicPtr<N> {
}
}

impl<N> Sub<Ptr<N>> for SymbolicPtr<N> {
impl<N: Field> Sub<Ptr<N>> for SymbolicPtr<N> {
type Output = Self;

fn sub(self, rhs: Ptr<N>) -> Self {
Expand All @@ -145,7 +145,7 @@ impl<N> Sub<Ptr<N>> for SymbolicPtr<N> {
}
}

impl<N> Add<SymbolicPtr<N>> for Ptr<N> {
impl<N: Field> Add<SymbolicPtr<N>> for Ptr<N> {
type Output = SymbolicPtr<N>;

fn add(self, rhs: SymbolicPtr<N>) -> SymbolicPtr<N> {
Expand All @@ -155,7 +155,7 @@ impl<N> Add<SymbolicPtr<N>> for Ptr<N> {
}
}

impl<N> Add<SymbolicVar<N>> for Ptr<N> {
impl<N: Field> Add<SymbolicVar<N>> for Ptr<N> {
type Output = SymbolicPtr<N>;

fn add(self, rhs: SymbolicVar<N>) -> SymbolicPtr<N> {
Expand All @@ -165,7 +165,7 @@ impl<N> Add<SymbolicVar<N>> for Ptr<N> {
}
}

impl<N> Sub<SymbolicVar<N>> for Ptr<N> {
impl<N: Field> Sub<SymbolicVar<N>> for Ptr<N> {
type Output = SymbolicPtr<N>;

fn sub(self, rhs: SymbolicVar<N>) -> SymbolicPtr<N> {
Expand All @@ -175,7 +175,7 @@ impl<N> Sub<SymbolicVar<N>> for Ptr<N> {
}
}

impl<N> Sub<SymbolicPtr<N>> for Ptr<N> {
impl<N: Field> Sub<SymbolicPtr<N>> for Ptr<N> {
type Output = SymbolicPtr<N>;

fn sub(self, rhs: SymbolicPtr<N>) -> SymbolicPtr<N> {
Expand Down
Loading

0 comments on commit b009149

Please sign in to comment.