Skip to content

Commit

Permalink
restructure the project
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexicon226 committed Jan 12, 2025
1 parent accba50 commit 5996215
Show file tree
Hide file tree
Showing 8 changed files with 281 additions and 255 deletions.
17 changes: 12 additions & 5 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,25 @@ pub fn build(b: *std.Build) !void {
const target = b.standardTargetOptions(.{});
const optimize = b.standardOptimizeOption(.{});

const lib = b.addModule("poseidon", .{
.root_source_file = b.path("src/lib.zig"),
const ff_mod = b.addModule("ff", .{
.root_source_file = b.path("ff/ff.zig"),
.target = target,
.optimize = optimize,
});

const poseidon_mod = b.addModule("poseidon", .{
.root_source_file = b.path("poseidon/poseidon.zig"),
.target = target,
.optimize = optimize,
});
poseidon_mod.addImport("ff", ff_mod);

const test_exe = b.addTest(.{
.root_source_file = b.path("tests/test.zig"),
.target = target,
.optimize = optimize,
});
test_exe.root_module.addImport("poseidon", lib);
test_exe.root_module.addImport("poseidon", poseidon_mod);

const test_step = b.step("test", "Tests the library");
const run_test = b.addRunArtifact(test_exe);
Expand All @@ -28,7 +35,7 @@ pub fn build(b: *std.Build) !void {
.optimize = optimize,
});
fuzz_exe.linkLibC();
fuzz_exe.root_module.addImport("poseidon", lib);
fuzz_exe.root_module.addImport("poseidon", poseidon_mod);
fuzz_exe.root_module.omit_frame_pointer = false;
b.installArtifact(fuzz_exe);

Expand All @@ -44,7 +51,7 @@ pub fn build(b: *std.Build) !void {
.optimize = optimize,
});
bench_exe.linkLibC();
bench_exe.root_module.addImport("poseidon", lib);
bench_exe.root_module.addImport("poseidon", poseidon_mod);
b.installArtifact(bench_exe);

const bench_step = b.step("bench", "Benches the library");
Expand Down
130 changes: 130 additions & 0 deletions ff/element.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
const std = @import("std");

pub fn FieldElement(T: type, MOD: T) type {
return struct {
value: Int,

// Some helper types
const Element = @This();
const bit_size = @bitSizeOf(T);
/// The power of `2 ^ 64` closest to the number of bits required to represent
/// the field modulus.
const rounded_2_pow_64 = std.mem.alignForward(u32, bit_size, 64);
const DoubleInt = std.meta.Int(.unsigned, rounded_2_pow_64 * 2);
const Int = std.meta.Int(.unsigned, rounded_2_pow_64);

/// The additive identity of the field.
pub const ZERO: Element = .{ .value = 0 };
/// The prime field modulus. NOT in Montgomery form.
pub const MODULUS: u254 = MOD;
/// `R = M % MODULUS` where `M` is `2 ^ rounded_2_pow_64`.
pub const R: Int = (std.math.powi(DoubleInt, 2, rounded_2_pow_64) catch
@panic("failed to compute R")) % MODULUS;
/// `R2 = R^2 % MODULUS`.
pub const R2: Int = (std.math.powi(DoubleInt, R, 2) catch
@panic("failed to compute R2")) % MODULUS;
/// - `INVERSE = -MODULUS ^ (-1) mod 2^254`
/// - `INVERSE = R ^ (-1)`
pub const INVERSE = inv: {
var inv: Int = 1;
for (0..bit_size) |_| {
inv *%= inv;
inv *%= MODULUS;
}
break :inv -%inv;
};

/// Adds two field elements in the Montgomery domain.
pub fn add(self: *Element, other: Element) void {
var sum = self.value + other.value;
if (sum >= Element.MODULUS) {
sum -= Element.MODULUS;
}
self.value = @bitCast(sum);
}

/// Multiplies two field elements in the Montgomery domain.
pub fn multiply(self: *Element, other: Element) void {
const product = @as(DoubleInt, self.value) * @as(DoubleInt, other.value);
const low_prod: Int = @truncate(product);
const high_prod: Int = @intCast(product >> rounded_2_pow_64);

const reduced: Int = low_prod *% INVERSE;
const mod_prod = @as(DoubleInt, reduced) * MODULUS;
const low_mod_prod: Int = @truncate(mod_prod);
const high_mod_prod: Int = @intCast(mod_prod >> rounded_2_pow_64);

// TODO: this is a wrapping subtraction!
const carry = @addWithOverflow(low_prod, low_mod_prod)[1];
var final_sum: Int = @truncate(@as(DoubleInt, carry) +
@as(DoubleInt, high_prod) +
@as(DoubleInt, high_mod_prod));

if (final_sum >= MODULUS) {
final_sum -%= MODULUS;
}

self.value = final_sum;
}

/// Squares a field element in the Montgomery domain.
///
/// NOTE: Just performs `self.mul(self)` for now, I'm not aware
/// of a better solution.
pub fn square(self: Element) Element {
var out: Element = self;
out.multiply(self);
return out;
}

/// Translates a field element out of the Montgomery domain.
pub fn fromMontgomery(self: Element) Element {
const product: Int = @truncate(@as(DoubleInt, self.value) * @as(DoubleInt, INVERSE));
const mod_prod = @as(DoubleInt, product) * MODULUS;
const low_mod_prod: Int = @truncate(mod_prod);
const high_mod_prod: Int = @truncate(mod_prod >> rounded_2_pow_64);
const add_overflow: u1 = @addWithOverflow(self.value, low_mod_prod)[1];
const adjusted_diff = add_overflow + high_mod_prod;

// TODO: this is just a compare and reduce
const carry_adjust, const is_carry_adjusted = @subWithOverflow(adjusted_diff, MODULUS);
const is_negative = @subWithOverflow(0, is_carry_adjusted)[1];
const result = if (is_negative == 0) carry_adjust else adjusted_diff;
return .{ .value = result };
}

/// Translates a field element into of the Montgomery domain.
pub fn fromInteger(integer: Int) Element {
const product = @as(DoubleInt, integer) * @as(DoubleInt, R2);
const low_prod: Int = @truncate(product);
const high_prod: Int = @truncate(product >> rounded_2_pow_64);
const n_prod: Int = @truncate(@as(DoubleInt, low_prod) * @as(DoubleInt, INVERSE));
const mod_prod = @as(DoubleInt, n_prod) * MODULUS;
const low_mod_prod: Int = @truncate(mod_prod);
const high_mod_prod: Int = @truncate(mod_prod >> rounded_2_pow_64);
const add_overflow = @addWithOverflow(low_prod, low_mod_prod)[1];
const adjusted_diff = add_overflow + high_prod + high_mod_prod;

// TODO: this is just a compare and reduce
const carry_adjust, const is_carry_adjusted = @subWithOverflow(adjusted_diff, MODULUS);
const is_negative = @subWithOverflow(0, is_carry_adjusted)[1];
const result = if (is_negative == 0) carry_adjust else adjusted_diff;
return .{ .value = result };
}

/// A helper function for the parameters list. Assumes the array is little endian
/// and already in Montgomery form.
pub fn fromArray(array: [@divExact(rounded_2_pow_64, 64)]u64) Element {
return .{ .value = @bitCast(array) };
}

pub fn format(
elem: Element,
comptime _: []const u8,
_: std.fmt.FormatOptions,
writer: anytype,
) !void {
try writer.print("{d}", .{elem.value});
}
};
}
1 change: 1 addition & 0 deletions ff/ff.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub const FieldElement = @import("element.zig").FieldElement;
File renamed without changes.
7 changes: 5 additions & 2 deletions src/params.zig → poseidon/params.zig
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
const Hasher = @import("lib.zig").Hasher;
const Element = @import("lib.zig").Element;
const poseidon = @import("poseidon.zig");
const Hasher = poseidon.Hasher;
const MODULUS = poseidon.MODULUS;
const FieldElement = @import("ff").FieldElement;
const Element = FieldElement(u254, MODULUS);

const PARTIAL_ROUNDS: [15]u64 = .{ 56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64 };

Expand Down
132 changes: 132 additions & 0 deletions poseidon/poseidon.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
const std = @import("std");
const FieldElement = @import("ff").FieldElement;

const PARAMS: [12]Hasher.Params = .{
@import("params.zig").BN256_x5_2,
@import("params.zig").BN256_x5_3,
@import("params.zig").BN256_x5_4,
@import("params.zig").BN256_x5_5,
@import("params.zig").BN256_x5_6,
@import("params.zig").BN256_x5_7,
@import("params.zig").BN256_x5_8,
@import("params.zig").BN256_x5_9,
@import("params.zig").BN256_x5_10,
@import("params.zig").BN256_x5_11,
@import("params.zig").BN256_x5_12,
@import("params.zig").BN256_x5_13,
};

pub const MODULUS: u254 = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001;

pub const Hasher = struct {
endian: std.builtin.Endian,
state: std.BoundedArray(Element, 13),

const Element = FieldElement(u254, MODULUS);

pub const Params = struct {
/// Round constants.
ark: []const Element,
/// MSD matrix.
mds: []const []const Element,
/// The number of full rounds (where S-box is applied to all elements of the state).
full_rounds: u32,
/// The number of partial rounds (where S-box is applied only to the first element
/// of the state).
partial_rounds: u32,
/// The number of prime fields in the state.
width: u32,
/// Exponential used in S-box to power elements of the state
alpha: u32,
};

pub fn init(endian: std.builtin.Endian) Hasher {
var state: std.BoundedArray(Element, 13) = .{};
state.appendAssumeCapacity(Element.ZERO);
return .{
.endian = endian,
.state = state,
};
}

pub fn hash(bytes: []const u8, endian: std.builtin.Endian) ![32]u8 {
if (bytes.len % 32 != 0) return error.InputNotMultipleOf32;
var hasher = Hasher.init(endian);

var iter = std.mem.window(u8, bytes, 32, 32);
while (iter.next()) |slice| {
try hasher.append(slice);
}

return hasher.finish();
}

pub fn append(hasher: *Hasher, bytes: []const u8) !void {
const integer = std.mem.readInt(u256, bytes[0..32], hasher.endian);
if (integer >= Element.MODULUS) {
return error.LargerThanMod;
}
const element = Element.fromInteger(integer);
try hasher.state.append(element);
}

pub fn finish(hasher: *Hasher) ![32]u8 {
const width = hasher.state.len;
const params = PARAMS[width - 2];
if (width != params.width) return error.Unexpected;

const all_rounds = params.full_rounds + params.partial_rounds;
const half_rounds = params.full_rounds / 2;

for (0..half_rounds) |round| {
hasher.applyArk(params, round);
hasher.applySBoxFull(width);
hasher.applyMds(params);
}

for (half_rounds..half_rounds + params.partial_rounds) |round| {
hasher.applyArk(params, round);
hasher.applySBoxFull(1);
hasher.applyMds(params);
}

for (half_rounds + params.partial_rounds..all_rounds) |round| {
hasher.applyArk(params, round);
hasher.applySBoxFull(width);
hasher.applyMds(params);
}

var result = hasher.state.get(0).fromMontgomery();
if (hasher.endian == .big) result.value = @byteSwap(result.value);
return @bitCast(result.value);
}

fn applyArk(hasher: *Hasher, params: Params, round: u64) void {
for (hasher.state.slice(), 0..) |*a, i| {
a.add(params.ark[round * params.width + i]);
}
}

fn applySBoxFull(hasher: *Hasher, width: u64) void {
// compute s[i] ^ 5
for (hasher.state.slice()[0..width]) |*s| {
var t: Element = undefined;
t = s.square(); // t = s ^ 2
t = t.square(); // t = s ^ 4
s.multiply(t); // s = s ^ 5
}
}

fn applyMds(hasher: *Hasher, params: Params) void {
const width = params.width;
var buffer: [13]Element = .{Element.ZERO} ** 13;
for (0..hasher.state.len) |i| {
for (hasher.state.slice(), 0..) |*a, j| {
var t: Element = a.*;
t.multiply(params.mds[i][j]);
buffer[i].add(t);
}
}
@memcpy(hasher.state.slice(), buffer[0..width]);
}
};
Loading

0 comments on commit 5996215

Please sign in to comment.