Skip to content

Commit

Permalink
almost there with the mul
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexicon226 committed Jan 12, 2025
1 parent 889e18f commit 6141e41
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 40 deletions.
68 changes: 29 additions & 39 deletions src/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ pub const Hasher = struct {
var t: Element = undefined;
t = s.square(); // t = s ^ 2
t = t.square(); // t = s ^ 4
s.mul(t); // s = s ^ 5
s.multiply(t); // s = s ^ 5
}
}

Expand All @@ -118,7 +118,7 @@ pub const Hasher = struct {
for (0..hasher.state.len) |i| {
for (hasher.state.slice(), 0..) |*a, j| {
var t: Element = a.*;
t.mul(params.mds[i][j]);
t.multiply(params.mds[i][j]);
buffer[i].add(t);
}
}
Expand All @@ -129,24 +129,24 @@ pub const Hasher = struct {
pub const Element = struct {
value: u256,

const ZERO: Element = .{ .value = 0 };
/// The additive identity of the field.
pub const ZERO: Element = .{ .value = 0 };
/// The prime field modulus. NOT in Montgomery form.
pub const MODULUS = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001;
pub const MODULUS: u254 = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001;
/// `R = M % MODULUS` where `M` is the power of `2^64` closest to the number of bits
/// required to represent the field modulus.
const R = computeR() catch @panic("failed to compute R");
const N = 0x73f82f1d0d8341b2e39a9828990623916586864b4c6911b3c2e1f593efffffff;
/// `INVERSE = -MODULUS ^ (-1) mod 2^64`
const INVERSE = computeInv() catch @panic("failed to compute INV");

fn computeR() !u256 {
return try std.math.powi(u257, 2, 256) % MODULUS;
}

// TODO: compute it at comptime!
fn computeInv() !u64 {
return 0xC2E1F593EFFFFFFF;
}
pub const R = (std.math.powi(u257, 2, 256) catch @panic("failed to compute R")) % MODULUS;
/// `INVERSE = -MODULUS ^ (-1) mod 2^254`
///
/// `INVERSE = R ^ (-1)`
pub const INVERSE = inv: {
var inv: u256 = 1;
for (0..@bitSizeOf(@TypeOf(MODULUS))) |_| {
inv *%= inv;
inv *%= MODULUS;
}
break :inv -%inv;
};

/// Adds two field elements in the Montgomery domain.
fn add(self: *Element, other: Element) void {
Expand All @@ -158,34 +158,24 @@ pub const Element = struct {
}

/// Multiplies two field elements in the Montgomery domain.
fn mul(self: *Element, other: Element) void {
const a = self.value;
const c = other.value;

const product = @as(u512, a) * @as(u512, c);
fn multiply(self: *Element, other: Element) void {
const product = @as(u512, self.value) * @as(u512, other.value);
const low_prod: u256 = @truncate(product);
const high_prod: u256 = @intCast(product >> 256);

const reduced: u256 = low_prod *% N;

const reduced: u256 = low_prod *% INVERSE;
const mod_prod = @as(u512, reduced) * MODULUS;
const low_mod_prod: u256 = @truncate(mod_prod);
const high_mod_prod: u256 = @intCast(mod_prod >> 256);
const carry1 = @addWithOverflow(low_prod, low_mod_prod)[1];

const final_sum = @as(u512, carry1) + @as(u512, high_prod) + @as(u512, high_mod_prod);
const low_final_sum: u256 = @truncate(final_sum);
const carry2: u1 = @truncate(final_sum >> 256);
const carry = @addWithOverflow(low_prod, low_mod_prod)[1];
var final_sum: u256 = @truncate(@as(u512, carry) + @as(u512, high_prod) + @as(u512, high_mod_prod));

const adjusted_diff = -@as(i512, low_final_sum) - @as(i512, MODULUS);
const is_negative: u1 = @bitCast(@as(i1, @truncate(adjusted_diff >> 256)));
const low_adjusted_diff: u256 = @bitCast(@as(i256, @truncate(adjusted_diff)));

const is_carry_adjusted = @subWithOverflow(carry2, is_negative)[1];
const mask = 0 -% @as(u256, is_carry_adjusted);
const result: u256 = (mask & low_final_sum) | ((~mask) & low_adjusted_diff);
if (final_sum >= MODULUS) {
final_sum -%= MODULUS;
}

self.value = result;
self.value = final_sum;
}

/// Squares a field element in the Montgomery domain.
Expand All @@ -194,13 +184,13 @@ pub const Element = struct {
/// of a better solution.
fn square(self: Element) Element {
var out: Element = self;
out.mul(self);
out.multiply(self);
return out;
}

/// Translates a field element out of the Montgomery domain.
fn fromMontgomery(self: Element) Element {
const product: u256 = @truncate(@as(u512, self.value) * @as(u512, N));
const product: u256 = @truncate(@as(u512, self.value) * @as(u512, INVERSE));
const mod_prod = @as(u512, product) * MODULUS;
const low_mod_prod: u256 = @truncate(mod_prod);
const high_mod_prod: u256 = @truncate(mod_prod >> 256);
Expand All @@ -218,7 +208,7 @@ pub const Element = struct {
@as(u512, 0x216d0b17f4e44a58c49833d53bb808553fe3ab1e35c59e31bb8e645ae216da7);
const low_prod: u256 = @truncate(product);
const high_prod: u256 = @truncate(product >> 256);
const n_prod: u256 = @truncate(@as(u512, low_prod) * @as(u512, N));
const n_prod: u256 = @truncate(@as(u512, low_prod) * @as(u512, INVERSE));
const mod_prod = @as(u512, n_prod) * MODULUS;
const low_mod_prod: u256 = @truncate(mod_prod);
const high_mod_prod: u256 = @truncate(mod_prod >> 256);
Expand Down
3 changes: 2 additions & 1 deletion tests/test.zig
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const std = @import("std");
const Hasher = @import("poseidon").Hasher;
const poseidon = @import("poseidon");
const Hasher = poseidon.Hasher;

const expectEqualSlices = std.testing.expectEqualSlices;

Expand Down

0 comments on commit 6141e41

Please sign in to comment.