-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
accba50
commit 5996215
Showing
8 changed files
with
281 additions
and
255 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
const std = @import("std"); | ||
|
||
pub fn FieldElement(T: type, MOD: T) type { | ||
return struct { | ||
value: Int, | ||
|
||
// Some helper types | ||
const Element = @This(); | ||
const bit_size = @bitSizeOf(T); | ||
/// The power of `2 ^ 64` closest to the number of bits required to represent | ||
/// the field modulus. | ||
const rounded_2_pow_64 = std.mem.alignForward(u32, bit_size, 64); | ||
const DoubleInt = std.meta.Int(.unsigned, rounded_2_pow_64 * 2); | ||
const Int = std.meta.Int(.unsigned, rounded_2_pow_64); | ||
|
||
/// The additive identity of the field. | ||
pub const ZERO: Element = .{ .value = 0 }; | ||
/// The prime field modulus. NOT in Montgomery form. | ||
pub const MODULUS: u254 = MOD; | ||
/// `R = M % MODULUS` where `M` is `2 ^ rounded_2_pow_64`. | ||
pub const R: Int = (std.math.powi(DoubleInt, 2, rounded_2_pow_64) catch | ||
@panic("failed to compute R")) % MODULUS; | ||
/// `R2 = R^2 % MODULUS`. | ||
pub const R2: Int = (std.math.powi(DoubleInt, R, 2) catch | ||
@panic("failed to compute R2")) % MODULUS; | ||
/// - `INVERSE = -MODULUS ^ (-1) mod 2^254` | ||
/// - `INVERSE = R ^ (-1)` | ||
pub const INVERSE = inv: { | ||
var inv: Int = 1; | ||
for (0..bit_size) |_| { | ||
inv *%= inv; | ||
inv *%= MODULUS; | ||
} | ||
break :inv -%inv; | ||
}; | ||
|
||
/// Adds two field elements in the Montgomery domain. | ||
pub fn add(self: *Element, other: Element) void { | ||
var sum = self.value + other.value; | ||
if (sum >= Element.MODULUS) { | ||
sum -= Element.MODULUS; | ||
} | ||
self.value = @bitCast(sum); | ||
} | ||
|
||
/// Multiplies two field elements in the Montgomery domain. | ||
pub fn multiply(self: *Element, other: Element) void { | ||
const product = @as(DoubleInt, self.value) * @as(DoubleInt, other.value); | ||
const low_prod: Int = @truncate(product); | ||
const high_prod: Int = @intCast(product >> rounded_2_pow_64); | ||
|
||
const reduced: Int = low_prod *% INVERSE; | ||
const mod_prod = @as(DoubleInt, reduced) * MODULUS; | ||
const low_mod_prod: Int = @truncate(mod_prod); | ||
const high_mod_prod: Int = @intCast(mod_prod >> rounded_2_pow_64); | ||
|
||
// TODO: this is a wrapping subtraction! | ||
const carry = @addWithOverflow(low_prod, low_mod_prod)[1]; | ||
var final_sum: Int = @truncate(@as(DoubleInt, carry) + | ||
@as(DoubleInt, high_prod) + | ||
@as(DoubleInt, high_mod_prod)); | ||
|
||
if (final_sum >= MODULUS) { | ||
final_sum -%= MODULUS; | ||
} | ||
|
||
self.value = final_sum; | ||
} | ||
|
||
/// Squares a field element in the Montgomery domain. | ||
/// | ||
/// NOTE: Just performs `self.mul(self)` for now, I'm not aware | ||
/// of a better solution. | ||
pub fn square(self: Element) Element { | ||
var out: Element = self; | ||
out.multiply(self); | ||
return out; | ||
} | ||
|
||
/// Translates a field element out of the Montgomery domain. | ||
pub fn fromMontgomery(self: Element) Element { | ||
const product: Int = @truncate(@as(DoubleInt, self.value) * @as(DoubleInt, INVERSE)); | ||
const mod_prod = @as(DoubleInt, product) * MODULUS; | ||
const low_mod_prod: Int = @truncate(mod_prod); | ||
const high_mod_prod: Int = @truncate(mod_prod >> rounded_2_pow_64); | ||
const add_overflow: u1 = @addWithOverflow(self.value, low_mod_prod)[1]; | ||
const adjusted_diff = add_overflow + high_mod_prod; | ||
|
||
// TODO: this is just a compare and reduce | ||
const carry_adjust, const is_carry_adjusted = @subWithOverflow(adjusted_diff, MODULUS); | ||
const is_negative = @subWithOverflow(0, is_carry_adjusted)[1]; | ||
const result = if (is_negative == 0) carry_adjust else adjusted_diff; | ||
return .{ .value = result }; | ||
} | ||
|
||
/// Translates a field element into of the Montgomery domain. | ||
pub fn fromInteger(integer: Int) Element { | ||
const product = @as(DoubleInt, integer) * @as(DoubleInt, R2); | ||
const low_prod: Int = @truncate(product); | ||
const high_prod: Int = @truncate(product >> rounded_2_pow_64); | ||
const n_prod: Int = @truncate(@as(DoubleInt, low_prod) * @as(DoubleInt, INVERSE)); | ||
const mod_prod = @as(DoubleInt, n_prod) * MODULUS; | ||
const low_mod_prod: Int = @truncate(mod_prod); | ||
const high_mod_prod: Int = @truncate(mod_prod >> rounded_2_pow_64); | ||
const add_overflow = @addWithOverflow(low_prod, low_mod_prod)[1]; | ||
const adjusted_diff = add_overflow + high_prod + high_mod_prod; | ||
|
||
// TODO: this is just a compare and reduce | ||
const carry_adjust, const is_carry_adjusted = @subWithOverflow(adjusted_diff, MODULUS); | ||
const is_negative = @subWithOverflow(0, is_carry_adjusted)[1]; | ||
const result = if (is_negative == 0) carry_adjust else adjusted_diff; | ||
return .{ .value = result }; | ||
} | ||
|
||
/// A helper function for the parameters list. Assumes the array is little endian | ||
/// and already in Montgomery form. | ||
pub fn fromArray(array: [@divExact(rounded_2_pow_64, 64)]u64) Element { | ||
return .{ .value = @bitCast(array) }; | ||
} | ||
|
||
pub fn format( | ||
elem: Element, | ||
comptime _: []const u8, | ||
_: std.fmt.FormatOptions, | ||
writer: anytype, | ||
) !void { | ||
try writer.print("{d}", .{elem.value}); | ||
} | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pub const FieldElement = @import("element.zig").FieldElement; |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
const std = @import("std"); | ||
const FieldElement = @import("ff").FieldElement; | ||
|
||
const PARAMS: [12]Hasher.Params = .{ | ||
@import("params.zig").BN256_x5_2, | ||
@import("params.zig").BN256_x5_3, | ||
@import("params.zig").BN256_x5_4, | ||
@import("params.zig").BN256_x5_5, | ||
@import("params.zig").BN256_x5_6, | ||
@import("params.zig").BN256_x5_7, | ||
@import("params.zig").BN256_x5_8, | ||
@import("params.zig").BN256_x5_9, | ||
@import("params.zig").BN256_x5_10, | ||
@import("params.zig").BN256_x5_11, | ||
@import("params.zig").BN256_x5_12, | ||
@import("params.zig").BN256_x5_13, | ||
}; | ||
|
||
pub const MODULUS: u254 = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001; | ||
|
||
pub const Hasher = struct { | ||
endian: std.builtin.Endian, | ||
state: std.BoundedArray(Element, 13), | ||
|
||
const Element = FieldElement(u254, MODULUS); | ||
|
||
pub const Params = struct { | ||
/// Round constants. | ||
ark: []const Element, | ||
/// MSD matrix. | ||
mds: []const []const Element, | ||
/// The number of full rounds (where S-box is applied to all elements of the state). | ||
full_rounds: u32, | ||
/// The number of partial rounds (where S-box is applied only to the first element | ||
/// of the state). | ||
partial_rounds: u32, | ||
/// The number of prime fields in the state. | ||
width: u32, | ||
/// Exponential used in S-box to power elements of the state | ||
alpha: u32, | ||
}; | ||
|
||
pub fn init(endian: std.builtin.Endian) Hasher { | ||
var state: std.BoundedArray(Element, 13) = .{}; | ||
state.appendAssumeCapacity(Element.ZERO); | ||
return .{ | ||
.endian = endian, | ||
.state = state, | ||
}; | ||
} | ||
|
||
pub fn hash(bytes: []const u8, endian: std.builtin.Endian) ![32]u8 { | ||
if (bytes.len % 32 != 0) return error.InputNotMultipleOf32; | ||
var hasher = Hasher.init(endian); | ||
|
||
var iter = std.mem.window(u8, bytes, 32, 32); | ||
while (iter.next()) |slice| { | ||
try hasher.append(slice); | ||
} | ||
|
||
return hasher.finish(); | ||
} | ||
|
||
pub fn append(hasher: *Hasher, bytes: []const u8) !void { | ||
const integer = std.mem.readInt(u256, bytes[0..32], hasher.endian); | ||
if (integer >= Element.MODULUS) { | ||
return error.LargerThanMod; | ||
} | ||
const element = Element.fromInteger(integer); | ||
try hasher.state.append(element); | ||
} | ||
|
||
pub fn finish(hasher: *Hasher) ![32]u8 { | ||
const width = hasher.state.len; | ||
const params = PARAMS[width - 2]; | ||
if (width != params.width) return error.Unexpected; | ||
|
||
const all_rounds = params.full_rounds + params.partial_rounds; | ||
const half_rounds = params.full_rounds / 2; | ||
|
||
for (0..half_rounds) |round| { | ||
hasher.applyArk(params, round); | ||
hasher.applySBoxFull(width); | ||
hasher.applyMds(params); | ||
} | ||
|
||
for (half_rounds..half_rounds + params.partial_rounds) |round| { | ||
hasher.applyArk(params, round); | ||
hasher.applySBoxFull(1); | ||
hasher.applyMds(params); | ||
} | ||
|
||
for (half_rounds + params.partial_rounds..all_rounds) |round| { | ||
hasher.applyArk(params, round); | ||
hasher.applySBoxFull(width); | ||
hasher.applyMds(params); | ||
} | ||
|
||
var result = hasher.state.get(0).fromMontgomery(); | ||
if (hasher.endian == .big) result.value = @byteSwap(result.value); | ||
return @bitCast(result.value); | ||
} | ||
|
||
fn applyArk(hasher: *Hasher, params: Params, round: u64) void { | ||
for (hasher.state.slice(), 0..) |*a, i| { | ||
a.add(params.ark[round * params.width + i]); | ||
} | ||
} | ||
|
||
fn applySBoxFull(hasher: *Hasher, width: u64) void { | ||
// compute s[i] ^ 5 | ||
for (hasher.state.slice()[0..width]) |*s| { | ||
var t: Element = undefined; | ||
t = s.square(); // t = s ^ 2 | ||
t = t.square(); // t = s ^ 4 | ||
s.multiply(t); // s = s ^ 5 | ||
} | ||
} | ||
|
||
fn applyMds(hasher: *Hasher, params: Params) void { | ||
const width = params.width; | ||
var buffer: [13]Element = .{Element.ZERO} ** 13; | ||
for (0..hasher.state.len) |i| { | ||
for (hasher.state.slice(), 0..) |*a, j| { | ||
var t: Element = a.*; | ||
t.multiply(params.mds[i][j]); | ||
buffer[i].add(t); | ||
} | ||
} | ||
@memcpy(hasher.state.slice(), buffer[0..width]); | ||
} | ||
}; |
Oops, something went wrong.