Skip to content

Commit

Permalink
refactor from_field, add, sub, neg,
Browse files Browse the repository at this point in the history
  • Loading branch information
olehmisar committed Jan 20, 2025
1 parent e2d6b53 commit a793732
Showing 1 changed file with 93 additions and 79 deletions.
172 changes: 93 additions & 79 deletions src/fns/constrained_ops.nr
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,34 @@ pub(crate) fn from_field<let N: u32, let MOD_BITS: u32>(
// safty: we check that the resulting limbs represent the intended field element
// we check the bit length, the limbs being max 120 bits, and the value in total is less than the field modulus
let result = unsafe { __from_field::<N>(field) };
// validate the limbs are in range and the value in total is less than 2^254

let TWO_POW_120 = 0x1000000000000000000000000000000;
// validate that the last limb is less than the modulus
if N > 2 {
// validate that the result is less than the modulus
let mut grumpkin_modulus = [0; N];
grumpkin_modulus[0] = 0x33e84879b9709143e1f593f0000001;
grumpkin_modulus[1] = 0x4e72e131a029b85045b68181585d28;
grumpkin_modulus[2] = 0x3064;
validate_gt::<N, 254>(grumpkin_modulus, result);
// validate that the limbs are in range
validate_in_range::<N, 254>(result);

if !std::runtime::is_unconstrained() {
// validate the limbs are in range and the value in total is less than 2^254

let TWO_POW_120 = 0x1000000000000000000000000000000;
// validate that the last limb is less than the modulus
if N > 2 {
// validate that the result is less than the modulus
let mut grumpkin_modulus = [0; N];
grumpkin_modulus[0] = 0x33e84879b9709143e1f593f0000001;
grumpkin_modulus[1] = 0x4e72e131a029b85045b68181585d28;
grumpkin_modulus[2] = 0x3064;
validate_gt::<N, 254>(grumpkin_modulus, result);
// validate that the limbs are in range
validate_in_range::<N, 254>(result);
}
// validate the limbs sum up to the field value
let field_val = if N < 2 {
result[0]
} else if N == 2 {
validate_in_range::<N, 254>(result);
result[0] + result[1] * TWO_POW_120
} else {
validate_in_range::<N, 254>(result);
result[0] + result[1] * TWO_POW_120 + result[2] * TWO_POW_120 * TWO_POW_120
};
assert(field_val == field);
}
// validate the limbs sum up to the field value
let field_val = if N < 2 {
result[0]
} else if N == 2 {
validate_in_range::<N, 254>(result);
result[0] + result[1] * TWO_POW_120
} else {
validate_in_range::<N, 254>(result);
result[0] + result[1] * TWO_POW_120 + result[2] * TWO_POW_120 * TWO_POW_120
};
assert(field_val == field);
result
}

Expand Down Expand Up @@ -335,18 +338,22 @@ pub(crate) fn neg<let N: u32, let MOD_BITS: u32>(
) -> [Field; N] {
// so we do... p - x - r = 0 and there might be borrow flags
let (result, borrow_flags) = unsafe { __neg_with_flags(params, val) };
validate_in_range::<_, MOD_BITS>(result);
let modulus = params.modulus;
let borrow_shift = 0x1000000000000000000000000000000;
let result_limb = modulus[0] - val[0] - result[0] + (borrow_flags[0] as Field * borrow_shift);
assert(result_limb == 0);
for i in 1..N - 1 {
let result_limb = modulus[i] - val[i] - result[i] - borrow_flags[i - 1] as Field
+ (borrow_flags[i] as Field * borrow_shift);
if !std::runtime::is_unconstrained() {
validate_in_range::<_, MOD_BITS>(result);
let modulus = params.modulus;
let borrow_shift = 0x1000000000000000000000000000000;
let result_limb =
modulus[0] - val[0] - result[0] + (borrow_flags[0] as Field * borrow_shift);
assert(result_limb == 0);
for i in 1..N - 1 {
let result_limb = modulus[i] - val[i] - result[i] - borrow_flags[i - 1] as Field
+ (borrow_flags[i] as Field * borrow_shift);
assert(result_limb == 0);
}
let result_limb =
modulus[N - 1] - val[N - 1] - result[N - 1] - borrow_flags[N - 2] as Field;
assert(result_limb == 0);
}
let result_limb = modulus[N - 1] - val[N - 1] - result[N - 1] - borrow_flags[N - 2] as Field;
assert(result_limb == 0);
result
}

Expand All @@ -358,31 +365,36 @@ pub(crate) fn add<let N: u32, let MOD_BITS: u32>(
// so we do... p - x - r = 0 and there might be borrow flags
let (result, carry_flags, borrow_flags, overflow_modulus) =
unsafe { __add_with_flags(params, lhs, rhs) };
validate_in_range::<_, MOD_BITS>(result);
let modulus = params.modulus;
let borrow_shift = 0x1000000000000000000000000000000;
let carry_shift = 0x1000000000000000000000000000000;

let mut subtrahend: [Field; N] = [0; N];
if (overflow_modulus) {
subtrahend = modulus;
}
let result_limb = lhs[0] + rhs[0] - subtrahend[0] - result[0]
+ (borrow_flags[0] as Field * borrow_shift)
- (carry_flags[0] as Field * carry_shift);
assert(result_limb == 0);
for i in 1..N - 1 {
let result_limb = lhs[i] + rhs[i] - subtrahend[i] - result[i] - borrow_flags[i - 1] as Field
+ carry_flags[i - 1] as Field
+ ((borrow_flags[i] as Field - carry_flags[i] as Field) * borrow_shift);
if !std::runtime::is_unconstrained() {
validate_in_range::<_, MOD_BITS>(result);
let modulus = params.modulus;
let borrow_shift = 0x1000000000000000000000000000000;
let carry_shift = 0x1000000000000000000000000000000;

let mut subtrahend: [Field; N] = [0; N];
if (overflow_modulus) {
subtrahend = modulus;
}
let result_limb = lhs[0] + rhs[0] - subtrahend[0] - result[0]
+ (borrow_flags[0] as Field * borrow_shift)
- (carry_flags[0] as Field * carry_shift);
assert(result_limb == 0);
for i in 1..N - 1 {
let result_limb = lhs[i] + rhs[i]
- subtrahend[i]
- result[i]
- borrow_flags[i - 1] as Field
+ carry_flags[i - 1] as Field
+ ((borrow_flags[i] as Field - carry_flags[i] as Field) * borrow_shift);
assert(result_limb == 0);
}
let result_limb = lhs[N - 1] + rhs[N - 1]
- subtrahend[N - 1]
- result[N - 1]
- borrow_flags[N - 2] as Field
+ carry_flags[N - 2] as Field;
assert(result_limb == 0);
}
let result_limb = lhs[N - 1] + rhs[N - 1]
- subtrahend[N - 1]
- result[N - 1]
- borrow_flags[N - 2] as Field
+ carry_flags[N - 2] as Field;
assert(result_limb == 0);
result
}

Expand All @@ -396,30 +408,32 @@ pub(crate) fn sub<let N: u32, let MOD_BITS: u32>(
// p + a - b - r = 0
let (result, carry_flags, borrow_flags, underflow) =
unsafe { __sub_with_flags(params, lhs, rhs) };
validate_in_range::<_, MOD_BITS>(result);
let modulus = params.modulus;
let borrow_shift = 0x1000000000000000000000000000000;
let carry_shift = 0x1000000000000000000000000000000;

let mut addend: [Field; N] = [0; N];
if (underflow) {
addend = modulus;
}
let result_limb = lhs[0] - rhs[0] + addend[0] - result[0]
+ (borrow_flags[0] as Field * borrow_shift)
- (carry_flags[0] as Field * carry_shift);
assert(result_limb == 0);
for i in 1..N - 1 {
let result_limb = lhs[i] - rhs[i] + addend[i] - result[i] - borrow_flags[i - 1] as Field
+ carry_flags[i - 1] as Field
+ ((borrow_flags[i] as Field - carry_flags[i] as Field) * borrow_shift);
if !std::runtime::is_unconstrained() {
validate_in_range::<_, MOD_BITS>(result);
let modulus = params.modulus;
let borrow_shift = 0x1000000000000000000000000000000;
let carry_shift = 0x1000000000000000000000000000000;

let mut addend: [Field; N] = [0; N];
if (underflow) {
addend = modulus;
}
let result_limb = lhs[0] - rhs[0] + addend[0] - result[0]
+ (borrow_flags[0] as Field * borrow_shift)
- (carry_flags[0] as Field * carry_shift);
assert(result_limb == 0);
for i in 1..N - 1 {
let result_limb = lhs[i] - rhs[i] + addend[i] - result[i] - borrow_flags[i - 1] as Field
+ carry_flags[i - 1] as Field
+ ((borrow_flags[i] as Field - carry_flags[i] as Field) * borrow_shift);
assert(result_limb == 0);
}
let result_limb = lhs[N - 1] - rhs[N - 1] + addend[N - 1]
- result[N - 1]
- borrow_flags[N - 2] as Field
+ carry_flags[N - 2] as Field;
assert(result_limb == 0);
}
let result_limb = lhs[N - 1] - rhs[N - 1] + addend[N - 1]
- result[N - 1]
- borrow_flags[N - 2] as Field
+ carry_flags[N - 2] as Field;
assert(result_limb == 0);
result
}

Expand Down

0 comments on commit a793732

Please sign in to comment.