Skip to content

Commit

Permalink
compute R2 at compile time instead of hardcoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexicon226 committed Jan 12, 2025
1 parent 6141e41 commit accba50
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ pub const Element = struct {
pub const MODULUS: u254 = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001;
/// `R = M % MODULUS` where `M` is the power of `2^64` closest to the number of bits
/// required to represent the field modulus.
pub const R = (std.math.powi(u257, 2, 256) catch @panic("failed to compute R")) % MODULUS;
pub const R: u256 = (std.math.powi(u257, 2, 256) catch @panic("failed to compute R")) % MODULUS;
/// `R2 = R^2 % MODULUS`
pub const R2: u256 = (std.math.powi(u512, R, 2) catch @panic("failed to compute R2")) % MODULUS;
/// `INVERSE = -MODULUS ^ (-1) mod 2^254`
///
/// `INVERSE = R ^ (-1)`
Expand Down Expand Up @@ -196,6 +198,8 @@ pub const Element = struct {
const high_mod_prod: u256 = @truncate(mod_prod >> 256);
const add_overflow: u1 = @addWithOverflow(self.value, low_mod_prod)[1];
const adjusted_diff = add_overflow + high_mod_prod;

// TODO: this is just a compare and reduce
const carry_adjust, const is_carry_adjusted = @subWithOverflow(adjusted_diff, MODULUS);
const is_negative = @subWithOverflow(0, is_carry_adjusted)[1];
const result = if (is_negative == 0) carry_adjust else adjusted_diff;
Expand All @@ -204,16 +208,17 @@ pub const Element = struct {

/// Translates a field element into of the Montgomery domain.
pub fn fromInteger(integer: u256) Element {
const product = @as(u512, integer) *
@as(u512, 0x216d0b17f4e44a58c49833d53bb808553fe3ab1e35c59e31bb8e645ae216da7);
const product = @as(u512, integer) * @as(u512, R2);
const low_prod: u256 = @truncate(product);
const high_prod: u256 = @truncate(product >> 256);
const n_prod: u256 = @truncate(@as(u512, low_prod) * @as(u512, INVERSE));
const mod_prod = @as(u512, n_prod) * MODULUS;
const low_mod_prod: u256 = @truncate(mod_prod);
const high_mod_prod: u256 = @truncate(mod_prod >> 256);
const add_overflow = @addWithOverflow(low_prod, low_mod_prod)[1];
const adjusted_diff = (@as(u256, add_overflow) + high_prod) + high_mod_prod;
const adjusted_diff = add_overflow + high_prod + high_mod_prod;

// TODO: this is just a compare and reduce
const carry_adjust, const is_carry_adjusted = @subWithOverflow(adjusted_diff, MODULUS);
const is_negative = @subWithOverflow(0, is_carry_adjusted)[1];
const result = if (is_negative == 0) carry_adjust else adjusted_diff;
Expand Down

0 comments on commit accba50

Please sign in to comment.