diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index addfffc8..3aaf5491 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -122,12 +122,10 @@ jobs: cargo check --verbose --package p3-merkle-tree cargo check --verbose --package p3-mersenne-31 cargo check --verbose --package p3-monolith - cargo check --verbose --package p3-multi-stark cargo check --verbose --package p3-poseidon cargo check --verbose --package p3-poseidon2 cargo check --verbose --package p3-reed-solomon cargo check --verbose --package p3-rescue cargo check --verbose --package p3-symmetric - cargo check --verbose --package p3-tensor-pcs cargo check --verbose --package p3-uni-stark cargo check --verbose --package p3-util \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index a8b0d2bf..d32e4c64 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "blake3", "brakedown", "challenger", + "circle", "code", "commit", "dft", @@ -22,13 +23,11 @@ members = [ "maybe-rayon", "mersenne-31", "monolith", - "multi-stark", "poseidon", "poseidon2", "reed-solomon", "rescue", "symmetric", - "tensor-pcs", "util", "uni-stark", ] diff --git a/README.md b/README.md index 94156c2d..d9989bec 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,12 @@ Fields: - [x] "complex" extension field - [x] ~128 bit extension field - [x] AVX2 - - [ ] AVX-512 + - [x] AVX-512 - [x] NEON - [x] BabyBear - [x] ~128 bit extension field - - [ ] AVX2 - - [ ] AVX-512 + - [x] AVX2 + - [x] AVX-512 - [x] NEON - [x] Goldilocks - [x] ~128 bit extension field @@ -64,6 +64,22 @@ We sometimes use a Keccak AIR to compare Plonky3's performance to other librarie RUST_LOG=info cargo run --example prove_baby_bear_keccak --release --features parallel ``` +## CPU features + +Plonky3 contains optimizations that rely on newer CPU instructions that are not available in older processors. These instruction sets include x86's [BMI1 and 2](https://en.wikipedia.org/wiki/X86_Bit_manipulation_instruction_set), [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#Advanced_Vector_Extensions_2), and [AVX-512](https://en.wikipedia.org/wiki/AVX-512). Rustc does not emit those instructions by default; they must be explicitly enabled through the `target-feature` compiler option (or implicitly by setting `target-cpu`). To enable all features that are supported on your machine, you can set `target-cpu` to `native`. For example, to run the tests: +``` +RUSTFLAGS="-Ctarget-cpu=native" cargo test +``` + +Support for some instructions, such as AVX-512, is still experimental. They are only available in the nightly build of Rustc and are enabled by the [`nightly-features` feature flag](#nightly-only-optimizations). To use them, you must enable the flag in Rustc (e.g. by setting `target-feature`) and you must also enable the `nightly-features` feature. + +## Nightly-only optimizations + +Some optimizations (in particular, AVX-512-optimized math) rely on features that are currently available only in the nightly build of Rustc. To use them, you need to enable the `nightly-features` feature. For example, to run the tests: +``` +cargo test --features nightly-features +``` + ## License diff --git a/air/src/virtual_column.rs b/air/src/virtual_column.rs index 0ccc9835..eb5ca88c 100644 --- a/air/src/virtual_column.rs +++ b/air/src/virtual_column.rs @@ -1,3 +1,4 @@ +use alloc::borrow::Cow; use alloc::vec; use alloc::vec::Vec; use core::ops::Mul; @@ -6,8 +7,8 @@ use p3_field::{AbstractField, Field}; /// An affine function over columns in a PAIR. #[derive(Clone, Debug)] -pub struct VirtualPairCol { - column_weights: Vec<(PairCol, F)>, +pub struct VirtualPairCol<'a, F: Field> { + column_weights: Cow<'a, [(PairCol, F)]>, constant: F, } @@ -19,7 +20,7 @@ pub enum PairCol { } impl PairCol { - fn get(&self, preprocessed: &[T], main: &[T]) -> T { + pub const fn get(&self, preprocessed: &[T], main: &[T]) -> T { match self { PairCol::Preprocessed(i) => preprocessed[*i], PairCol::Main(i) => main[*i], @@ -27,14 +28,28 @@ impl PairCol { } } -impl VirtualPairCol { - pub fn new(column_weights: Vec<(PairCol, F)>, constant: F) -> Self { +impl<'a, F: Field> VirtualPairCol<'a, F> { + pub const fn new(column_weights: Cow<'a, [(PairCol, F)]>, constant: F) -> Self { Self { column_weights, constant, } } + pub const fn new_owned(column_weights: Vec<(PairCol, F)>, constant: F) -> Self { + Self { + column_weights: Cow::Owned(column_weights), + constant, + } + } + + pub const fn new_borrowed(column_weights: &'a [(PairCol, F)], constant: F) -> Self { + Self { + column_weights: Cow::Borrowed(column_weights), + constant, + } + } + pub fn new_preprocessed(column_weights: Vec<(usize, F)>, constant: F) -> Self { Self::new( column_weights @@ -63,7 +78,7 @@ impl VirtualPairCol { #[must_use] pub fn constant(x: F) -> Self { Self { - column_weights: vec![], + column_weights: Cow::Owned(vec![]), constant: x, } } @@ -71,7 +86,7 @@ impl VirtualPairCol { #[must_use] pub fn single(column: PairCol) -> Self { Self { - column_weights: vec![(column, F::one())], + column_weights: Cow::Owned(vec![(column, F::one())]), constant: F::zero(), } } @@ -117,7 +132,7 @@ impl VirtualPairCol { Var: Into + Copy, { let mut result = self.constant.into(); - for (column, weight) in &self.column_weights { + for (column, weight) in self.column_weights.iter() { result += column.get(preprocessed, main).into() * *weight; } result diff --git a/baby-bear/Cargo.toml b/baby-bear/Cargo.toml index ab989de4..13375411 100644 --- a/baby-bear/Cargo.toml +++ b/baby-bear/Cargo.toml @@ -4,13 +4,22 @@ version = "0.1.0" edition = "2021" license = "MIT OR Apache-2.0" +[features] +nightly-features = [] + [dependencies] p3-field = { path = "../field" } +p3-mds = { path = "../mds" } +p3-poseidon2 = { path = "../poseidon2" } +p3-symmetric = { path = "../symmetric" } rand = "0.8.5" serde = { version = "1.0", default-features = false, features = ["derive"] } [dev-dependencies] p3-field-testing = { path = "../field-testing" } +ark-ff = { version = "^0.4.0", default-features = false } +zkhash = { git = "https://github.com/HorizenLabs/poseidon2" } +rand = { version = "0.8.5", features = ["min_const_gen"] } criterion = "0.5.1" rand_chacha = "0.3.1" serde_json = "1.0.113" diff --git a/baby-bear/src/baby_bear.rs b/baby-bear/src/baby_bear.rs index 88f03180..fc8930dd 100644 --- a/baby-bear/src/baby_bear.rs +++ b/baby-bear/src/baby_bear.rs @@ -3,8 +3,8 @@ use core::iter::{Product, Sum}; use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; use p3_field::{ - exp_1725656503, exp_u64_by_squaring, AbstractField, Field, Packable, PrimeField, PrimeField32, - PrimeField64, TwoAdicField, + exp_1725656503, exp_u64_by_squaring, halve_u32, AbstractField, Field, Packable, PrimeField, + PrimeField32, PrimeField64, TwoAdicField, }; use rand::distributions::{Distribution, Standard}; use rand::Rng; @@ -186,11 +186,30 @@ impl AbstractField for BabyBear { impl Field for BabyBear { #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] type Packing = crate::PackedBabyBearNeon; - #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(feature = "nightly-features", target_feature = "avx512f")) + ))] type Packing = crate::PackedBabyBearAVX2; + #[cfg(all( + feature = "nightly-features", + target_arch = "x86_64", + target_feature = "avx512f" + ))] + type Packing = crate::PackedBabyBearAVX512; #[cfg(not(any( all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2"), + all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(feature = "nightly-features", target_feature = "avx512f")) + ), + all( + feature = "nightly-features", + target_arch = "x86_64", + target_feature = "avx512f" + ), )))] type Packing = Self; @@ -238,6 +257,13 @@ impl Field for BabyBear { Some(p1110111111111111111111111111111) } + + #[inline] + fn halve(&self) -> Self { + BabyBear { + value: halve_u32::

(self.value), + } + } } impl PrimeField for BabyBear {} @@ -249,20 +275,6 @@ impl PrimeField64 for BabyBear { fn as_canonical_u64(&self) -> u64 { u64::from(self.as_canonical_u32()) } - - #[inline] - fn linear_combination_u64(u: [u64; N], v: &[Self; N]) -> Self { - // In order not to overflow a u64, we must have sum(u) <= 2^32. - debug_assert!(u.iter().sum::() <= (1u64 << 32)); - - let mut dot = u[0] * v[0].value as u64; - for i in 1..N { - dot += u[i] * v[i].value as u64; - } - Self { - value: (dot % (P as u64)) as u32, - } - } } impl PrimeField32 for BabyBear { @@ -418,6 +430,23 @@ const fn to_monty(x: u32) -> u32 { (((x as u64) << MONTY_BITS) % P as u64) as u32 } +/// Convert a constant u32 array into a constant Babybear array. +/// Saves every element in Monty Form +#[inline] +#[must_use] +pub(crate) const fn to_babybear_array(input: [u32; N]) -> [BabyBear; N] { + let mut output = [BabyBear { value: 0 }; N]; + let mut i = 0; + loop { + if i == N { + break; + } + output[i].value = to_monty(input[i]); + i += 1; + } + output +} + #[inline] #[must_use] fn to_monty_64(x: u64) -> u32 { diff --git a/baby-bear/src/lib.rs b/baby-bear/src/lib.rs index 3e85090c..26a34c43 100644 --- a/baby-bear/src/lib.rs +++ b/baby-bear/src/lib.rs @@ -1,11 +1,23 @@ #![no_std] +#![cfg_attr( + all( + feature = "nightly-features", + target_arch = "x86_64", + target_feature = "avx512f" + ), + feature(stdarch_x86_avx512) +)] extern crate alloc; mod baby_bear; mod extension; +mod mds; +mod poseidon2; pub use baby_bear::*; +pub use mds::*; +pub use poseidon2::DiffusionMatrixBabybear; #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] mod aarch64_neon; @@ -16,3 +28,16 @@ pub use aarch64_neon::*; mod x86_64_avx2; #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] pub use x86_64_avx2::*; + +#[cfg(all( + feature = "nightly-features", + target_arch = "x86_64", + target_feature = "avx512f" +))] +mod x86_64_avx512; +#[cfg(all( + feature = "nightly-features", + target_arch = "x86_64", + target_feature = "avx512f" +))] +pub use x86_64_avx512::*; diff --git a/baby-bear/src/mds.rs b/baby-bear/src/mds.rs new file mode 100644 index 00000000..27beac04 --- /dev/null +++ b/baby-bear/src/mds.rs @@ -0,0 +1,564 @@ +//! MDS matrices over the BabyBear field, and permutations defined by them. +//! +//! NB: Not all sizes have fast implementations of their permutations. +//! Supported sizes: 8, 12, 16, 24, 32, 64. +//! Sizes 8 and 12 are from Plonky2, size 16 was found as part of concurrent +//! work by Angus Gruen and Hamish Ivey-Law. Other sizes are from Ulrich Haböck's +//! database. + +use p3_field::{PrimeField32, PrimeField64}; +use p3_mds::karatsuba_convolution::Convolve; +use p3_mds::util::{dot_product, first_row_to_first_col}; +use p3_mds::MdsPermutation; +use p3_symmetric::Permutation; + +use crate::BabyBear; + +#[derive(Clone, Default)] +pub struct MdsMatrixBabyBear; + +/// Instantiate convolution for "small" RHS vectors over BabyBear. +/// +/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) < +/// 2^24 (roughly), though in practice the sum will be less than 2^9. +struct SmallConvolveBabyBear; +impl Convolve for SmallConvolveBabyBear { + /// Return the lift of a BabyBear element, satisfying 0 <= + /// input.value < P < 2^31. Note that BabyBear elements are + /// represented in Monty form. + #[inline(always)] + fn read(input: BabyBear) -> i64 { + input.value as i64 + } + + /// For a convolution of size N, |x| < N * 2^31 and (as per the + /// assumption above), |y| < 2^24. So the product is at most N * 2^55 + /// which will not overflow for N <= 16. + /// + /// Note that the LHS element is in Monty form, while the RHS + /// element is an "plain integer". This informs the implementation + /// of `reduce()` below. + #[inline(always)] + fn parity_dot(u: [i64; N], v: [i64; N]) -> i64 { + dot_product(u, v) + } + + /// The assumptions above mean z < N^2 * 2^55, which is at most + /// 2^63 when N <= 16. + /// + /// Because the LHS elements were in Monty form and the RHS + /// elements were plain integers, reduction is simply the usual + /// reduction modulo P, rather than "Monty reduction". + /// + /// NB: Even though intermediate values could be negative, the + /// output must be non-negative since the inputs were + /// non-negative. + #[inline(always)] + fn reduce(z: i64) -> BabyBear { + debug_assert!(z >= 0); + BabyBear { + value: (z as u64 % BabyBear::ORDER_U64) as u32, + } + } +} + +/// Given |x| < 2^80 compute x' such that: +/// |x'| < 2**50 +/// x' = x mod p +/// x' = x mod 2^10 +/// See Thm 1 (Below function) for a proof that this function is correct. +#[inline(always)] +fn barret_red_babybear(input: i128) -> i64 { + const N: usize = 40; // beta = 2^N, fixing N = 40 here + const P: i128 = BabyBear::ORDER_U32 as i128; + const I: i64 = (((1_i128) << (2 * N)) / P) as i64; // I = 2^80 / P => I < 2**50 + // I: i64 = 0x22222221d950c + const MASK: i64 = !((1 << 10) - 1); // Lets us 0 out the bottom 10 digits of an i64. + + // input = input_low + beta*input_high + // So input_high < 2**63 and fits in an i64. + let input_high = (input >> N) as i64; // input_high < input / beta < 2**{80 - N} + + // I, input_high are i64's so this mulitiplication can't overflow. + let quot = (((input_high as i128) * (I as i128)) >> N) as i64; + + // Replace quot by a close value which is divisibly by 2^10. + let quot_2adic = quot & MASK; + + // quot_2adic, P are i64's so this can't overflow. + // sub is by construction divisible by both P and 2^10. + let sub = (quot_2adic as i128) * P; + + (input - sub) as i64 +} + +// Theorem 1: +// Given |x| < 2^80, barret_red(x) computes an x' such that: +// x' = x mod p +// x' = x mod 2^10 +// |x'| < 2**50. +/////////////////////////////////////////////////////////////////////////////////////// +// PROOF: +// By construction P, 2**10 | sub and so we immediately see that +// x' = x mod p +// x' = x mod 2^10. +// +// It remains to prove that |x'| < 2**50. +// +// We start by introducing some simple inequalities and relations bewteen our variables: +// +// First consider the relationship between bitshift and division. +// It's easy to check that for all x: +// 1: (x >> N) <= x / 2**N <= 1 + (x >> N) +// +// Similarly, as our mask just 0's the last 10 bits, +// 2: x + 1 - 2^10 <= x & mask <= x +// +// Now if x, y are positive integers then +// (x / y) - 1 <= x // y <= x / y +// Where // denotes integer division. +// +// From this last inequality we immediately derive: +// 3: (2**{2N} / P) - 1 <= I <= (2**{2N} / P) +// 3a: 2**{2N} - P <= PI +// +// Finally, note that by definition: +// input = input_high*(2**N) + input_low +// Hence a simple rearrangement gets us +// 4: input_high*(2**N) = input - input_low +// +// +// We now need to split into cases depending on the sign of input. +// Note that if x = 0 then x' = 0 so that case is trivial. +/////////////////////////////////////////////////////////////////////////// +// CASE 1: input > 0 +// +// If input > 0 then: +// sub = Q*P = ((((input >> N) * I) >> N) & mask) * P <= P * (input / 2**{N}) * (2**{2N} / P) / 2**{N} = input +// So input - sub >= 0. +// +// We need to improve our bound on Q. Observe that: +// Q = (((input_high * I) >> N) & mask) +// --(2) => Q + (2^10 - 1) >= (input_high * I) >> N) +// --(1) => Q + 2^10 >= (I*x_high)/(2**N) +// => (2**N)*Q + 2^10*(2**N) >= I*x_high +// +// Hence we find that: +// (2**N)*Q*P + 2^10*(2**N)*P >= input_high*I*P +// --(3a) >= input_high*2**{2N} - P*input_high +// --(4) >= (2**N)*input - (2**N)*input_low - (2**N)*input_high (Assuming P < 2**N) +// +// Dividing by 2**N we get +// Q*P + 2^{10}*P >= input - input_low - input_high +// which rearranges to +// x' = input - Q*P <= 2^{10}*P + input_low + input_high +// +// Picking N = 40 we see that 2^{10}*P, input_low, input_high are all bounded by 2**40 +// Hence x' < 2**42 < 2**50 as desired. +// +// +// +/////////////////////////////////////////////////////////////////////////// +// CASE 2: input < 0 +// +// This case will be similar but all our inequalities will change slightly as negatives complicate things. +// First observe that: +// (input >> N) * I >= (input >> N) * 2**(2N) / P +// >= (1 + (input / 2**N)) * 2**(2N) / P +// >= (2**N + input) * 2**N / P +// +// Thus: +// Q = ((input >> N) * I) >> N >= ((2**N + input) * 2**N / P) >> N +// >= ((2**N + input) / P) - 1 +// +// And so sub = Q*P >= 2**N - P + input. +// Hence input - sub < 2**N - P. +// +// Thus if input - sub > 0 then |input - sub| < 2**50. +// Thus we are left with bounding -(input - sub) = (sub - input). +// Again we will proceed by improving our bound on Q. +// +// Q = (((input_high * I) >> N) & mask) +// --(2) => Q <= (input_high * I) >> N) <= (I*x_high)/(2**N) +// --(1) => Q <= (I*x_high)/(2**N) +// => (2**N)*Q <= I*x_high +// +// Hence we find that: +// (2**N)*Q*P <= input_high*I*P +// --(3a) <= input_high*2**{2N} - P*input_high +// --(4) <= (2**N)*input - (2**N)*input_low - (2**N)*input_high (Assuming P < 2**N) +// +// Dividing by 2**N we get +// Q*P <= input - input_low - input_high +// which rearranges to +// -x' = -input + Q*P <= -input_high - input_low < 2**50 +// +// This completes the proof. + +/// Instantiate convolution for "large" RHS vectors over BabyBear. +/// +/// Here "large" means the elements can be as big as the field +/// characteristic, and the size N of the RHS is <= 64. +struct LargeConvolveBabyBear; +impl Convolve for LargeConvolveBabyBear { + /// Return the lift of a BabyBear element, satisfying 0 <= + /// input.value < P < 2^31. Note that BabyBear elements are + /// represented in Monty form. + #[inline(always)] + fn read(input: BabyBear) -> i64 { + input.value as i64 + } + + #[inline(always)] + fn parity_dot(u: [i64; N], v: [i64; N]) -> i64 { + // For a convolution of size N, |x|, |y| < N * 2^31, so the + // product could be as much as N^2 * 2^62. This will overflow an + // i64, so we first widen to i128. Note that N^2 * 2^62 < 2^80 + // for N <= 64, as required by `barret_red_babybear()`. + + let mut dp = 0i128; + for i in 0..N { + dp += u[i] as i128 * v[i] as i128; + } + barret_red_babybear(dp) + } + + #[inline(always)] + fn reduce(z: i64) -> BabyBear { + // After the barret reduction method, the output z of parity + // dot satisfies |z| < 2^50 (See Thm 1 above). + // + // In the recombining steps, conv_n maps (wo, w1) -> + // ((wo + w1)/2, (wo + w1)/2) which has no effect on the maximal + // size. (Indeed, it makes sizes almost strictly smaller). + // + // On the other hand, negacyclic_conv_n (ignoring the re-index) + // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1). + // Hence if the input is <= K, the output is <= 3K. + // + // Thus the values appearing at the end are bounded by 3^n 2^50 + // where n is the maximal number of negacyclic_conv + // recombination steps. When N = 64, we need to recombine for + // singed_conv_32, singed_conv_16, singed_conv_8 so the + // overall bound will be 3^3 2^50 < 32 * 2^50 < 2^55. + debug_assert!(z > -(1i64 << 55)); + debug_assert!(z < (1i64 << 55)); + + // Note we do NOT move it into MONTY form. We assume it is already + // in this form. + let red = (z % (BabyBear::ORDER_U32 as i64)) as u32; + + // If z >= 0: 0 <= red < P is the correct value and P + red will + // not overflow. + // If z < 0: -P < red < 0 and the value we want is P + red. + // On bits, + acts identically for i32 and u32. Hence we can use + // u32's and just check for overflow. + + let (corr, over) = red.overflowing_add(BabyBear::ORDER_U32); + let value = if over { corr } else { red }; + BabyBear { value } + } +} + +const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; + +impl Permutation<[BabyBear; 8]> for MdsMatrixBabyBear { + fn permute(&self, input: [BabyBear; 8]) -> [BabyBear; 8] { + const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] = + first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW); + SmallConvolveBabyBear::apply( + input, + MATRIX_CIRC_MDS_8_SML_COL, + SmallConvolveBabyBear::conv8, + ) + } + + fn permute_mut(&self, input: &mut [BabyBear; 8]) { + *input = self.permute(*input); + } +} +impl MdsPermutation for MdsMatrixBabyBear {} + +const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10]; + +impl Permutation<[BabyBear; 12]> for MdsMatrixBabyBear { + fn permute(&self, input: [BabyBear; 12]) -> [BabyBear; 12] { + const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] = + first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW); + SmallConvolveBabyBear::apply( + input, + MATRIX_CIRC_MDS_12_SML_COL, + SmallConvolveBabyBear::conv12, + ) + } + + fn permute_mut(&self, input: &mut [BabyBear; 12]) { + *input = self.permute(*input); + } +} +impl MdsPermutation for MdsMatrixBabyBear {} + +#[rustfmt::skip] +const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] = [ + 1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3 +]; + +impl Permutation<[BabyBear; 16]> for MdsMatrixBabyBear { + fn permute(&self, input: [BabyBear; 16]) -> [BabyBear; 16] { + const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] = + first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW); + SmallConvolveBabyBear::apply( + input, + MATRIX_CIRC_MDS_16_SML_COL, + SmallConvolveBabyBear::conv16, + ) + } + + fn permute_mut(&self, input: &mut [BabyBear; 16]) { + *input = self.permute(*input); + } +} +impl MdsPermutation for MdsMatrixBabyBear {} + +#[rustfmt::skip] +const MATRIX_CIRC_MDS_24_BABYBEAR_ROW: [i64; 24] = [ + 0x2D0AAAAB, 0x64850517, 0x17F5551D, 0x04ECBEB5, + 0x6D91A8D5, 0x60703026, 0x18D6F3CA, 0x729601A7, + 0x77CDA9E2, 0x3C0F5038, 0x26D52A61, 0x0360405D, + 0x68FC71C8, 0x2495A71D, 0x5D57AFC2, 0x1689DD98, + 0x3C2C3DBE, 0x0C23DC41, 0x0524C7F2, 0x6BE4DF69, + 0x0A6E572C, 0x5C7790FA, 0x17E118F6, 0x0878A07F, +]; + +impl Permutation<[BabyBear; 24]> for MdsMatrixBabyBear { + fn permute(&self, input: [BabyBear; 24]) -> [BabyBear; 24] { + const MATRIX_CIRC_MDS_24_BABYBEAR_COL: [i64; 24] = + first_row_to_first_col(&MATRIX_CIRC_MDS_24_BABYBEAR_ROW); + LargeConvolveBabyBear::apply( + input, + MATRIX_CIRC_MDS_24_BABYBEAR_COL, + LargeConvolveBabyBear::conv24, + ) + } + + fn permute_mut(&self, input: &mut [BabyBear; 24]) { + *input = self.permute(*input); + } +} +impl MdsPermutation for MdsMatrixBabyBear {} + +#[rustfmt::skip] +const MATRIX_CIRC_MDS_32_BABYBEAR_ROW: [i64; 32] = [ + 0x0BC00000, 0x2BED8F81, 0x337E0652, 0x4C4535D1, + 0x4AF2DC32, 0x2DB4050F, 0x676A7CE3, 0x3A06B68E, + 0x5E95C1B1, 0x2C5F54A0, 0x2332F13D, 0x58E757F1, + 0x3AA6DCCE, 0x607EE630, 0x4ED57FF0, 0x6E08555B, + 0x4C155556, 0x587FD0CE, 0x462F1551, 0x032A43CC, + 0x5E2E43EA, 0x71609B02, 0x0ED97E45, 0x562CA7E9, + 0x2CB70B1D, 0x4E941E23, 0x174A61C1, 0x117A9426, + 0x73562137, 0x54596086, 0x487C560B, 0x68A4ACAB, +]; + +impl Permutation<[BabyBear; 32]> for MdsMatrixBabyBear { + fn permute(&self, input: [BabyBear; 32]) -> [BabyBear; 32] { + const MATRIX_CIRC_MDS_32_BABYBEAR_COL: [i64; 32] = + first_row_to_first_col(&MATRIX_CIRC_MDS_32_BABYBEAR_ROW); + LargeConvolveBabyBear::apply( + input, + MATRIX_CIRC_MDS_32_BABYBEAR_COL, + LargeConvolveBabyBear::conv32, + ) + } + + fn permute_mut(&self, input: &mut [BabyBear; 32]) { + *input = self.permute(*input); + } +} +impl MdsPermutation for MdsMatrixBabyBear {} + +#[rustfmt::skip] +const MATRIX_CIRC_MDS_64_BABYBEAR_ROW: [i64; 64] = [ + 0x39577778, 0x0072F4E1, 0x0B1B8404, 0x041E9C88, + 0x32D22F9F, 0x4E4BF946, 0x20C7B6D7, 0x0587C267, + 0x55877229, 0x4D186EC4, 0x4A19FD23, 0x1A64A20F, + 0x2965CA4D, 0x16D98A5A, 0x471E544A, 0x193D5C8B, + 0x6E66DF0C, 0x28BF1F16, 0x26DB0BC8, 0x5B06CDDB, + 0x100DCCA2, 0x65C268AD, 0x199F09E7, 0x36BA04BE, + 0x06C393F2, 0x51B06DFD, 0x6951B0C4, 0x6683A4C2, + 0x3B53D11B, 0x26E5134C, 0x45A5F1C5, 0x6F4D2433, + 0x3CE2D82E, 0x36309A7D, 0x3DD9B459, 0x68051E4C, + 0x5C3AA720, 0x11640517, 0x0634D995, 0x1B0F6406, + 0x72A18430, 0x26513CC5, 0x67C0B93C, 0x548AB4A3, + 0x6395D20D, 0x3E5DBC41, 0x332AF630, 0x3C5DDCB3, + 0x0AA95792, 0x66EB5492, 0x3F78DDDC, 0x5AC41627, + 0x16CD5124, 0x3564DA96, 0x461867C9, 0x157B4E11, + 0x1AA486C8, 0x0C5095A9, 0x3833C0C6, 0x008FEBA5, + 0x52ECBE2E, 0x1D178A67, 0x58B3C04B, 0x6E95CB51, +]; + +impl Permutation<[BabyBear; 64]> for MdsMatrixBabyBear { + fn permute(&self, input: [BabyBear; 64]) -> [BabyBear; 64] { + const MATRIX_CIRC_MDS_64_BABYBEAR_COL: [i64; 64] = + first_row_to_first_col(&MATRIX_CIRC_MDS_64_BABYBEAR_ROW); + LargeConvolveBabyBear::apply( + input, + MATRIX_CIRC_MDS_64_BABYBEAR_COL, + LargeConvolveBabyBear::conv64, + ) + } + + fn permute_mut(&self, input: &mut [BabyBear; 64]) { + *input = self.permute(*input); + } +} +impl MdsPermutation for MdsMatrixBabyBear {} + +#[cfg(test)] +mod tests { + use p3_field::AbstractField; + use p3_symmetric::Permutation; + + use super::{BabyBear, MdsMatrixBabyBear}; + + #[test] + fn babybear8() { + let input: [BabyBear; 8] = [ + 391474477, 1174409341, 666967492, 1852498830, 1801235316, 820595865, 585587525, + 1348326858, + ] + .map(BabyBear::from_canonical_u64); + + let output = MdsMatrixBabyBear.permute(input); + + let expected: [BabyBear; 8] = [ + 1752937716, 1801468855, 1102954394, 284747746, 1636355768, 205443234, 1235359747, + 1159982032, + ] + .map(BabyBear::from_canonical_u64); + + assert_eq!(output, expected); + } + + #[test] + fn babybear12() { + let input: [BabyBear; 12] = [ + 918423259, 673549090, 364157140, 9832898, 493922569, 1171855651, 246075034, 1542167926, + 1787615541, 1696819900, 1884530130, 422386768, + ] + .map(BabyBear::from_canonical_u64); + + let output = MdsMatrixBabyBear.permute(input); + + let expected: [BabyBear; 12] = [ + 1631062293, 890348490, 1304705406, 1888740923, 845648570, 717048224, 1082440815, + 914769887, 1872991191, 1366539339, 1805116914, 1998032485, + ] + .map(BabyBear::from_canonical_u64); + + assert_eq!(output, expected); + } + + #[test] + fn babybear16() { + let input: [BabyBear; 16] = [ + 1983708094, 1477844074, 1638775686, 98517138, 70746308, 968700066, 275567720, + 1359144511, 960499489, 1215199187, 474302783, 79320256, 1923147803, 1197733438, + 1638511323, 303948902, + ] + .map(BabyBear::from_canonical_u64); + + let output = MdsMatrixBabyBear.permute(input); + + let expected: [BabyBear; 16] = [ + 1497569692, 1038070871, 669165859, 456905446, 1116763366, 1267622262, 1985953057, + 1060497461, 704264985, 306103349, 1271339089, 1551541970, 1796459417, 889229849, + 1731972538, 439594789, + ] + .map(BabyBear::from_canonical_u64); + + assert_eq!(output, expected); + } + + #[test] + fn babybear24() { + let input: [BabyBear; 24] = [ + 1307148929, 1603957607, 1515498600, 1412393512, 785287979, 988718522, 1750345556, + 853137995, 534387281, 930390055, 1600030977, 903985158, 1141020507, 636889442, + 966037834, 1778991639, 1440427266, 1379431959, 853403277, 959593575, 733455867, + 908584009, 817124993, 418826476, + ] + .map(BabyBear::from_canonical_u64); + + let output = MdsMatrixBabyBear.permute(input); + + let expected: [BabyBear; 24] = [ + 1537871777, 1626055274, 1705000179, 1426678258, 1688760658, 1347225494, 1291221794, + 1224656589, 1791446853, 1978133881, 1820380039, 1366829700, 27479566, 409595531, + 1223347944, 1752750033, 594548873, 1447473111, 1385412872, 1111945102, 1366585917, + 138866947, 1326436332, 656898133, + ] + .map(BabyBear::from_canonical_u64); + + assert_eq!(output, expected); + } + + #[test] + fn babybear32() { + let input: [BabyBear; 32] = [ + 1346087634, 1511946000, 1883470964, 54906057, 233060279, 5304922, 1881494193, + 743728289, 404047361, 1148556479, 144976634, 1726343008, 29659471, 1350407160, + 1636652429, 385978955, 327649601, 1248138459, 1255358242, 84164877, 1005571393, + 1713215328, 72913800, 1683904606, 904763213, 316800515, 656395998, 788184609, + 1824512025, 1177399063, 1358745087, 444151496, + ] + .map(BabyBear::from_canonical_u64); + + let output = MdsMatrixBabyBear.permute(input); + + let expected: [BabyBear; 32] = [ + 1359576919, 1657405784, 1031581836, 212090105, 699048671, 877916349, 205627787, + 1211567750, 210807569, 1696391051, 558468987, 161148427, 304343518, 76611896, + 532792005, 1963649139, 1283500358, 250848292, 1109842541, 2007388683, 433801252, + 1189712914, 626158024, 1436409738, 456315160, 1836818120, 1645024941, 925447491, + 1599571860, 1055439714, 353537136, 379644130, + ] + .map(BabyBear::from_canonical_u64); + + assert_eq!(output, expected); + } + + #[test] + fn babybear64() { + let input: [BabyBear; 64] = [ + 1931358930, 1322576114, 1658000717, 134388215, 1517892791, 1486447670, 93570662, + 898466034, 1576905917, 283824713, 1433559150, 1730678909, 155340881, 1978472263, + 1980644590, 1814040165, 654743892, 849954227, 323176597, 146970735, 252703735, + 1856579399, 162749290, 986745196, 352038183, 1239527508, 828473247, 1184743572, + 1017249065, 36804843, 1378131210, 1286724687, 596095979, 1916924908, 528946791, + 397247884, 23477278, 299412064, 415288430, 935825754, 1218003667, 1954592289, + 1594612673, 664096455, 958392778, 497208288, 1544504580, 1829423324, 956111902, + 458327015, 1736664598, 430977734, 599887171, 1100074154, 1197653896, 427838651, + 466509871, 1236918100, 940670246, 1421951147, 255557957, 1374188100, 315300068, + 623354170, + ] + .map(BabyBear::from_canonical_u64); + + let output = MdsMatrixBabyBear.permute(input); + + let expected: [BabyBear; 64] = [ + 442300274, 756862170, 167612495, 1103336044, 546496433, 1211822920, 329094196, + 1334376959, 944085937, 977350947, 1445060130, 918469957, 800346119, 1957918170, + 739098112, 1862817833, 1831589884, 1673860978, 698081523, 1128978338, 387929536, + 1106772486, 1367460469, 1911237185, 362669171, 819949894, 1801786287, 1943505026, + 586738185, 996076080, 1641277705, 1680239311, 1005815192, 63087470, 593010310, + 364673774, 543368618, 1576179136, 47618763, 1990080335, 1608655220, 499504830, + 861863262, 765074289, 139277832, 1139970138, 1510286607, 244269525, 43042067, + 119733624, 1314663255, 893295811, 1444902994, 914930267, 1675139862, 1148717487, + 1601328192, 534383401, 296215929, 1924587380, 1336639141, 34897994, 2005302060, + 1780337352, + ] + .map(BabyBear::from_canonical_u64); + + assert_eq!(output, expected); + } +} diff --git a/baby-bear/src/poseidon2.rs b/baby-bear/src/poseidon2.rs new file mode 100644 index 00000000..02cdbc34 --- /dev/null +++ b/baby-bear/src/poseidon2.rs @@ -0,0 +1,140 @@ +use p3_field::AbstractField; +use p3_poseidon2::{matmul_internal, DiffusionPermutation}; +use p3_symmetric::Permutation; + +use crate::{to_babybear_array, BabyBear}; + +// Diffusion matrices for Babybear16 and Babybear24. +// +// Reference: https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2_instance_babybear.rs +const MATRIX_DIAG_16_BABYBEAR_U32: [u32; 16] = [ + 0x0a632d94, 0x6db657b7, 0x56fbdc9e, 0x052b3d8a, 0x33745201, 0x5c03108c, 0x0beba37b, 0x258c2e8b, + 0x12029f39, 0x694909ce, 0x6d231724, 0x21c3b222, 0x3c0904a5, 0x01d6acda, 0x27705c83, 0x5231c802, +]; + +const MATRIX_DIAG_24_BABYBEAR_U32: [u32; 24] = [ + 0x409133f0, 0x1667a8a1, 0x06a6c7b6, 0x6f53160e, 0x273b11d1, 0x03176c5d, 0x72f9bbf9, 0x73ceba91, + 0x5cdef81d, 0x01393285, 0x46daee06, 0x065d7ba6, 0x52d72d6f, 0x05dd05e0, 0x3bab4b63, 0x6ada3842, + 0x2fc5fbec, 0x770d61b0, 0x5715aae9, 0x03ef0e90, 0x75b6c770, 0x242adf5f, 0x00d0ca4c, 0x36c0e388, +]; + +// Convert the above arrays of u32's into arrays of BabyBear field elements saved in MONTY form. +const MATRIX_DIAG_16_BABYBEAR_MONTY: [BabyBear; 16] = + to_babybear_array(MATRIX_DIAG_16_BABYBEAR_U32); +const MATRIX_DIAG_24_BABYBEAR_MONTY: [BabyBear; 24] = + to_babybear_array(MATRIX_DIAG_24_BABYBEAR_U32); + +#[derive(Debug, Clone, Default)] +pub struct DiffusionMatrixBabybear; + +impl> Permutation<[AF; 16]> for DiffusionMatrixBabybear { + fn permute_mut(&self, state: &mut [AF; 16]) { + matmul_internal::(state, MATRIX_DIAG_16_BABYBEAR_MONTY); + } +} + +impl> DiffusionPermutation for DiffusionMatrixBabybear {} + +impl> Permutation<[AF; 24]> for DiffusionMatrixBabybear { + fn permute_mut(&self, state: &mut [AF; 24]) { + matmul_internal::(state, MATRIX_DIAG_24_BABYBEAR_MONTY); + } +} + +impl> DiffusionPermutation for DiffusionMatrixBabybear {} + +#[cfg(test)] +mod tests { + use alloc::vec::Vec; + + use ark_ff::{BigInteger, PrimeField}; + use p3_poseidon2::Poseidon2; + use rand::Rng; + use zkhash::fields::babybear::FpBabyBear; + use zkhash::poseidon2::poseidon2::Poseidon2 as Poseidon2Ref; + use zkhash::poseidon2::poseidon2_instance_babybear::{POSEIDON2_BABYBEAR_16_PARAMS, RC16}; + + use super::*; + + // These are currently saved as their true values. It will be far more efficient to save them in Monty Form. + + #[test] + fn test_poseidon2_constants() { + let monty_constant = MATRIX_DIAG_16_BABYBEAR_U32.map(BabyBear::from_canonical_u32); + assert_eq!(monty_constant, MATRIX_DIAG_16_BABYBEAR_MONTY); + + let monty_constant = MATRIX_DIAG_24_BABYBEAR_U32.map(BabyBear::from_canonical_u32); + assert_eq!(monty_constant, MATRIX_DIAG_24_BABYBEAR_MONTY); + } + + fn babybear_from_ark_ff(input: FpBabyBear) -> BabyBear { + let as_bigint = input.into_bigint(); + let mut as_bytes = as_bigint.to_bytes_le(); + as_bytes.resize(4, 0); + let as_u32 = u32::from_le_bytes(as_bytes[0..4].try_into().unwrap()); + BabyBear::from_wrapped_u32(as_u32) + } + + #[test] + fn test_poseidon2_babybear_width_16() { + const WIDTH: usize = 16; + const D: u64 = 7; + const ROUNDS_F: usize = 8; + const ROUNDS_P: usize = 13; + + type F = BabyBear; + + let mut rng = rand::thread_rng(); + + // Poiseidon2 reference implementation from zkhash repo. + let poseidon2_ref = Poseidon2Ref::new(&POSEIDON2_BABYBEAR_16_PARAMS); + + // Copy over round constants from zkhash. + let round_constants: Vec<[F; WIDTH]> = RC16 + .iter() + .map(|vec| { + vec.iter() + .cloned() + .map(babybear_from_ark_ff) + .collect::>() + .try_into() + .unwrap() + }) + .collect(); + + // Our Poseidon2 implementation. + let poseidon2: Poseidon2 = + Poseidon2::new(ROUNDS_F, ROUNDS_P, round_constants, DiffusionMatrixBabybear); + + // Generate random input and convert to both BabyBear field formats. + let input_u32 = rng.gen::<[u32; WIDTH]>(); + let input_ref = input_u32 + .iter() + .cloned() + .map(FpBabyBear::from) + .collect::>(); + let input = input_u32.map(F::from_wrapped_u32); + + // Check that the conversion is correct. + assert!(input_ref + .iter() + .zip(input.iter()) + .all(|(a, b)| babybear_from_ark_ff(*a) == *b)); + + // Run reference implementation. + let output_ref = poseidon2_ref.permutation(&input_ref); + let expected: [F; WIDTH] = output_ref + .iter() + .cloned() + .map(babybear_from_ark_ff) + .collect::>() + .try_into() + .unwrap(); + + // Run our implementation. + let mut output = input; + poseidon2.permute_mut(&mut output); + + assert_eq!(output, expected); + } +} diff --git a/baby-bear/src/x86_64_avx2.rs b/baby-bear/src/x86_64_avx2.rs index 4a778159..0be97282 100644 --- a/baby-bear/src/x86_64_avx2.rs +++ b/baby-bear/src/x86_64_avx2.rs @@ -146,16 +146,16 @@ fn add(lhs: __m256i, rhs: __m256i) -> __m256i { // This implementation is based on [1] but with minor changes. The reduction is as follows: // // Constants: P = 2^31 - 2^27 + 1 -// B = 2^31 -// mu = P^-1 mod B +// B = 2^32 +// μ = P^-1 mod B // Input: 0 <= C < P B // Output: 0 <= R < P such that R = C B^-1 (mod P) -// 1. Q := mu C mod B +// 1. Q := μ C mod B // 2. D := (C - Q P) / B // 3. R := if D < 0 then D + P else D // // We first show that the division in step 2. is exact. It suffices to show that C = Q P (mod B). By -// definition of Q and mu, we have Q P = mu C P = P^-1 C P = C (mod B). We also have +// definition of Q and μ, we have Q P = μ C P = P^-1 C P = C (mod B). We also have // C - Q P = C (mod P), so thus D = C B^-1 (mod P). // // It remains to show that R is in the correct range. It suffices to show that -P <= D < P. We know @@ -669,7 +669,7 @@ unsafe impl PackedField for PackedBabyBearAVX2 { #[cfg(test)] mod tests { - use rand::{Rng, SeedableRng}; + use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use super::*; diff --git a/baby-bear/src/x86_64_avx512.rs b/baby-bear/src/x86_64_avx512.rs new file mode 100644 index 00000000..c5f5a4a9 --- /dev/null +++ b/baby-bear/src/x86_64_avx512.rs @@ -0,0 +1,1350 @@ +use core::arch::x86_64::{self, __m512i, __mmask16, __mmask8}; +use core::iter::{Product, Sum}; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; + +use p3_field::{AbstractField, Field, PackedField, PackedValue}; +use rand::distributions::{Distribution, Standard}; +use rand::Rng; + +use crate::BabyBear; + +const WIDTH: usize = 16; +const P: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0x78000001; WIDTH]) }; +// On x86 MONTY_BITS is always 32, so MU = P^-1 (mod 2^32) = 0x88000001. +const MU: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0x88000001; WIDTH]) }; +const EVENS: __mmask16 = 0b0101010101010101; +const EVENS4: __mmask16 = 0x0f0f; + +/// Vectorized AVX-512F implementation of `BabyBear` arithmetic. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(transparent)] // This needed to make `transmute`s safe. +pub struct PackedBabyBearAVX512(pub [BabyBear; WIDTH]); + +impl PackedBabyBearAVX512 { + #[inline] + #[must_use] + /// Get an arch-specific vector representing the packed values. + fn to_vector(self) -> __m512i { + unsafe { + // Safety: `BabyBear` is `repr(transparent)` so it can be transmuted to `u32`. It + // follows that `[BabyBear; WIDTH]` can be transmuted to `[u32; WIDTH]`, which can be + // transmuted to `__m512i`, since arrays are guaranteed to be contiguous in memory. + // Finally `PackedBabyBearAVX512` is `repr(transparent)` so it can be transmuted to + // `[BabyBear; WIDTH]`. + transmute(self) + } + } + + #[inline] + #[must_use] + /// Make a packed field vector from an arch-specific vector. + /// + /// SAFETY: The caller must ensure that each element of `vector` represents a valid + /// `BabyBear`. In particular, each element of vector must be in `0..=P`. + unsafe fn from_vector(vector: __m512i) -> Self { + // Safety: It is up to the user to ensure that elements of `vector` represent valid + // `BabyBear` values. We must only reason about memory representations. `__m512i` can be + // transmuted to `[u32; WIDTH]` (since arrays elements are contiguous in memory), which can + // be transmuted to `[BabyBear; WIDTH]` (since `BabyBear` is `repr(transparent)`), which + // in turn can be transmuted to `PackedBabyBearAVX512` (since `PackedBabyBearAVX512` is also + // `repr(transparent)`). + transmute(vector) + } + + /// Copy `value` to all positions in a packed vector. This is the same as + /// `From::from`, but `const`. + #[inline] + #[must_use] + const fn broadcast(value: BabyBear) -> Self { + Self([value; WIDTH]) + } +} + +impl Add for PackedBabyBearAVX512 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let lhs = self.to_vector(); + let rhs = rhs.to_vector(); + let res = add(lhs, rhs); + unsafe { + // Safety: `add` returns values in canonical form when given values in canonical form. + Self::from_vector(res) + } + } +} + +impl Mul for PackedBabyBearAVX512 { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let lhs = self.to_vector(); + let rhs = rhs.to_vector(); + let res = mul(lhs, rhs); + unsafe { + // Safety: `mul` returns values in canonical form when given values in canonical form. + Self::from_vector(res) + } + } +} + +impl Neg for PackedBabyBearAVX512 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + let val = self.to_vector(); + let res = neg(val); + unsafe { + // Safety: `neg` returns values in canonical form when given values in canonical form. + Self::from_vector(res) + } + } +} + +impl Sub for PackedBabyBearAVX512 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let lhs = self.to_vector(); + let rhs = rhs.to_vector(); + let res = sub(lhs, rhs); + unsafe { + // Safety: `sub` returns values in canonical form when given values in canonical form. + Self::from_vector(res) + } + } +} + +/// Add two vectors of Baby Bear field elements in canonical form. +/// If the inputs are not in canonical form, the result is undefined. +#[inline] +#[must_use] +fn add(lhs: __m512i, rhs: __m512i) -> __m512i { + // We want this to compile to: + // vpaddd t, lhs, rhs + // vpsubd u, t, P + // vpminud res, t, u + // throughput: 1.5 cyc/vec (10.67 els/cyc) + // latency: 3 cyc + + // Let t := lhs + rhs. We want to return t mod P. Recall that lhs and rhs are in + // 0, ..., P - 1, so t is in 0, ..., 2 P - 2 (< 2^32). It suffices to return t if t < P and + // t - P otherwise. + // Let u := (t - P) mod 2^32 and r := unsigned_min(t, u). + // If t is in 0, ..., P - 1, then u is in (P - 1 <) 2^32 - P, ..., 2^32 - 1 and r = t. + // Otherwise, t is in P, ..., 2 P - 2, u is in 0, ..., P - 2 (< P) and r = u. Hence, r is t if + // t < P and t - P otherwise, as desired. + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + let t = x86_64::_mm512_add_epi32(lhs, rhs); + let u = x86_64::_mm512_sub_epi32(t, P); + x86_64::_mm512_min_epu32(t, u) + } +} + +// MONTGOMERY MULTIPLICATION +// This implementation is based on [1] but with minor changes. The reduction is as follows: +// +// Constants: P = 2^31 - 2^27 + 1 +// B = 2^32 +// μ = P^-1 mod B +// Input: 0 <= C < P B +// Output: 0 <= R < P such that R = C B^-1 (mod P) +// 1. Q := μ C mod B +// 2. D := (C - Q P) / B +// 3. R := if D < 0 then D + P else D +// +// We first show that the division in step 2. is exact. It suffices to show that C = Q P (mod B). By +// definition of Q and μ, we have Q P = μ C P = P^-1 C P = C (mod B). We also have +// C - Q P = C (mod P), so thus D = C B^-1 (mod P). +// +// It remains to show that R is in the correct range. It suffices to show that -P <= D < P. We know +// that 0 <= C < P B and 0 <= Q P < P B. Then -P B < C - QP < P B and -P < D < P, as desired. +// +// [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann, Cambridge University Press, +// 2010, algorithm 2.7. + +/// Viewing the input as a vector of 16 `u32`s, copy the odd elements into the even elements below +/// them. In other words, for all `0 <= i < 8`, set the even elements according to +/// `res[2 * i] := a[2 * i + 1]`, and the odd elements according to +/// `res[2 * i + 1] := a[2 * i + 1]`. +#[inline] +#[must_use] +fn movehdup_epi32(a: __m512i) -> __m512i { + // The instruction is only available in the floating-point flavor; this distinction is only for + // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. + unsafe { + x86_64::_mm512_castps_si512(x86_64::_mm512_movehdup_ps(x86_64::_mm512_castsi512_ps(a))) + } +} + +/// Viewing `a` as a vector of 16 `u32`s, copy the odd elements into the even elements below them, +/// then merge with `src` according to the mask provided. In other words, for all `0 <= i < 8`, set +/// the even elements according to `res[2 * i] := if k[2 * i] { a[2 * i + 1] } else { src[2 * i] }`, +/// and the odd elements according to +/// `res[2 * i + 1] := if k[2 * i + 1] { a[2 * i + 1] } else { src[2 * i + 1] }`. +#[inline] +#[must_use] +fn mask_movehdup_epi32(src: __m512i, k: __mmask16, a: __m512i) -> __m512i { + // The instruction is only available in the floating-point flavor; this distinction is only for + // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. + unsafe { + let src = x86_64::_mm512_castsi512_ps(src); + let a = x86_64::_mm512_castsi512_ps(a); + x86_64::_mm512_castps_si512(x86_64::_mm512_mask_movehdup_ps(src, k, a)) + } +} + +/// Multiply vectors of Baby Bear field elements in canonical form. +/// If the inputs are not in canonical form, the result is undefined. +#[inline] +#[must_use] +#[allow(non_snake_case)] +fn mul(lhs: __m512i, rhs: __m512i) -> __m512i { + // We want this to compile to: + // vmovshdup lhs_odd, lhs + // vmovshdup rhs_odd, rhs + // vpmuludq prod_evn, lhs, rhs + // vpmuludq prod_hi, lhs_odd, rhs_odd + // vpmuludq q_evn, prod_evn, MU + // vpmuludq q_odd, prod_hi, MU + // vmovshdup prod_hi{EVENS}, prod_evn + // vpmuludq q_P_evn, q_evn, P + // vpmuludq q_P_hi, q_odd, P + // vmovshdup q_P_hi{EVENS}, q_P_evn + // vpcmpltud underflow, prod_hi, q_P_hi + // vpsubd res, prod_hi, q_P_hi + // vpaddd res{underflow}, res, P + // throughput: 6.5 cyc/vec (2.46 els/cyc) + // latency: 21 cyc + unsafe { + // `vpmuludq` only reads the even doublewords, so when we pass `lhs` and `rhs` directly we + // get the eight products at even positions. + let lhs_evn = lhs; + let rhs_evn = rhs; + + // Copy the odd doublewords into even positions to compute the eight products at odd + // positions. + // NB: The odd doublewords are ignored by `vpmuludq`, so we have a lot of choices for how to + // do this; `vmovshdup` is nice because it runs on a memory port if the operand is in + // memory, thus improving our throughput. + let lhs_odd = movehdup_epi32(lhs); + let rhs_odd = movehdup_epi32(rhs); + + let prod_evn = x86_64::_mm512_mul_epu32(lhs_evn, rhs_evn); + let prod_odd = x86_64::_mm512_mul_epu32(lhs_odd, rhs_odd); + + let q_evn = x86_64::_mm512_mul_epu32(prod_evn, MU); + let q_odd = x86_64::_mm512_mul_epu32(prod_odd, MU); + + // Get all the high halves as one vector: this is `(lhs * rhs) >> 32`. + // NB: `vpermt2d` may feel like a more intuitive choice here, but it has much higher + // latency. + let prod_hi = mask_movehdup_epi32(prod_odd, EVENS, prod_evn); + + // Normally we'd want to mask to perform % 2**32, but the instruction below only reads the + // low 32 bits anyway. + let q_P_evn = x86_64::_mm512_mul_epu32(q_evn, P); + let q_P_odd = x86_64::_mm512_mul_epu32(q_odd, P); + + // We can ignore all the low halves of `q_P` as they cancel out. Get all the high halves as + // one vector. + let q_P_hi = mask_movehdup_epi32(q_P_odd, EVENS, q_P_evn); + + // Subtraction `prod_hi - q_P_hi` modulo `P`. + // NB: Normally we'd `vpaddd P` and take the `vpminud`, but `vpminud` runs on port 0, which + // is already under a lot of pressure performing multiplications. To relieve this pressure, + // we check for underflow to generate a mask, and then conditionally add `P`. The underflow + // check runs on port 5, increasing our throughput, although it does cost us an additional + // cycle of latency. + let underflow = x86_64::_mm512_cmplt_epu32_mask(prod_hi, q_P_hi); + let t = x86_64::_mm512_sub_epi32(prod_hi, q_P_hi); + x86_64::_mm512_mask_add_epi32(t, underflow, t, P) + } +} + +/// Negate a vector of Baby Bear field elements in canonical form. +/// If the inputs are not in canonical form, the result is undefined. +#[inline] +#[must_use] +fn neg(val: __m512i) -> __m512i { + // We want this to compile to: + // vptestmd nonzero, val, val + // vpsubd res{nonzero}{z}, P, val + // throughput: 1 cyc/vec (16 els/cyc) + // latency: 4 cyc + + // NB: This routine prioritizes throughput over latency. An alternative method would be to do + // sub(0, val), which would result in shorter latency, but also lower throughput. + + // If val is nonzero, then val is in {1, ..., P - 1} and P - val is in the same range. If val + // is zero, then the result is zeroed by masking. + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + let nonzero = x86_64::_mm512_test_epi32_mask(val, val); + x86_64::_mm512_maskz_sub_epi32(nonzero, P, val) + } +} + +/// Subtract vectors of Baby Bear field elements in canonical form. +/// If the inputs are not in canonical form, the result is undefined. +#[inline] +#[must_use] +fn sub(lhs: __m512i, rhs: __m512i) -> __m512i { + // We want this to compile to: + // vpsubd t, lhs, rhs + // vpaddd u, t, P + // vpminud res, t, u + // throughput: 1.5 cyc/vec (10.67 els/cyc) + // latency: 3 cyc + + // Let t := lhs - rhs. We want to return t mod P. Recall that lhs and rhs are in + // 0, ..., P - 1, so t is in (-2^31 <) -P + 1, ..., P - 1 (< 2^31). It suffices to return t if + // t >= 0 and t + P otherwise. + // Let u := (t + P) mod 2^32 and r := unsigned_min(t, u). + // If t is in 0, ..., P - 1, then u is in P, ..., 2 P - 1 and r = t. + // Otherwise, t is in -P + 1, ..., -1; u is in 1, ..., P - 1 (< P) and r = u. Hence, r is t if + // t < P and t - P otherwise, as desired. + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + let t = x86_64::_mm512_sub_epi32(lhs, rhs); + let u = x86_64::_mm512_add_epi32(t, P); + x86_64::_mm512_min_epu32(t, u) + } +} + +impl From for PackedBabyBearAVX512 { + #[inline] + fn from(value: BabyBear) -> Self { + Self::broadcast(value) + } +} + +impl Default for PackedBabyBearAVX512 { + #[inline] + fn default() -> Self { + BabyBear::default().into() + } +} + +impl AddAssign for PackedBabyBearAVX512 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl MulAssign for PackedBabyBearAVX512 { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl SubAssign for PackedBabyBearAVX512 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl Sum for PackedBabyBearAVX512 { + #[inline] + fn sum(iter: I) -> Self + where + I: Iterator, + { + iter.reduce(|lhs, rhs| lhs + rhs).unwrap_or(Self::zero()) + } +} + +impl Product for PackedBabyBearAVX512 { + #[inline] + fn product(iter: I) -> Self + where + I: Iterator, + { + iter.reduce(|lhs, rhs| lhs * rhs).unwrap_or(Self::one()) + } +} + +impl AbstractField for PackedBabyBearAVX512 { + type F = BabyBear; + + #[inline] + fn zero() -> Self { + BabyBear::zero().into() + } + + #[inline] + fn one() -> Self { + BabyBear::one().into() + } + + #[inline] + fn two() -> Self { + BabyBear::two().into() + } + + #[inline] + fn neg_one() -> Self { + BabyBear::neg_one().into() + } + + #[inline] + fn from_f(f: Self::F) -> Self { + f.into() + } + + #[inline] + fn from_bool(b: bool) -> Self { + BabyBear::from_bool(b).into() + } + #[inline] + fn from_canonical_u8(n: u8) -> Self { + BabyBear::from_canonical_u8(n).into() + } + #[inline] + fn from_canonical_u16(n: u16) -> Self { + BabyBear::from_canonical_u16(n).into() + } + #[inline] + fn from_canonical_u32(n: u32) -> Self { + BabyBear::from_canonical_u32(n).into() + } + #[inline] + fn from_canonical_u64(n: u64) -> Self { + BabyBear::from_canonical_u64(n).into() + } + #[inline] + fn from_canonical_usize(n: usize) -> Self { + BabyBear::from_canonical_usize(n).into() + } + + #[inline] + fn from_wrapped_u32(n: u32) -> Self { + BabyBear::from_wrapped_u32(n).into() + } + #[inline] + fn from_wrapped_u64(n: u64) -> Self { + BabyBear::from_wrapped_u64(n).into() + } + + #[inline] + fn generator() -> Self { + BabyBear::generator().into() + } +} + +impl Add for PackedBabyBearAVX512 { + type Output = Self; + #[inline] + fn add(self, rhs: BabyBear) -> Self { + self + Self::from(rhs) + } +} + +impl Mul for PackedBabyBearAVX512 { + type Output = Self; + #[inline] + fn mul(self, rhs: BabyBear) -> Self { + self * Self::from(rhs) + } +} + +impl Sub for PackedBabyBearAVX512 { + type Output = Self; + #[inline] + fn sub(self, rhs: BabyBear) -> Self { + self - Self::from(rhs) + } +} + +impl AddAssign for PackedBabyBearAVX512 { + #[inline] + fn add_assign(&mut self, rhs: BabyBear) { + *self += Self::from(rhs) + } +} + +impl MulAssign for PackedBabyBearAVX512 { + #[inline] + fn mul_assign(&mut self, rhs: BabyBear) { + *self *= Self::from(rhs) + } +} + +impl SubAssign for PackedBabyBearAVX512 { + #[inline] + fn sub_assign(&mut self, rhs: BabyBear) { + *self -= Self::from(rhs) + } +} + +impl Sum for PackedBabyBearAVX512 { + #[inline] + fn sum(iter: I) -> Self + where + I: Iterator, + { + iter.sum::().into() + } +} + +impl Product for PackedBabyBearAVX512 { + #[inline] + fn product(iter: I) -> Self + where + I: Iterator, + { + iter.product::().into() + } +} + +impl Div for PackedBabyBearAVX512 { + type Output = Self; + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline] + fn div(self, rhs: BabyBear) -> Self { + self * rhs.inverse() + } +} + +impl Add for BabyBear { + type Output = PackedBabyBearAVX512; + #[inline] + fn add(self, rhs: PackedBabyBearAVX512) -> PackedBabyBearAVX512 { + PackedBabyBearAVX512::from(self) + rhs + } +} + +impl Mul for BabyBear { + type Output = PackedBabyBearAVX512; + #[inline] + fn mul(self, rhs: PackedBabyBearAVX512) -> PackedBabyBearAVX512 { + PackedBabyBearAVX512::from(self) * rhs + } +} + +impl Sub for BabyBear { + type Output = PackedBabyBearAVX512; + #[inline] + fn sub(self, rhs: PackedBabyBearAVX512) -> PackedBabyBearAVX512 { + PackedBabyBearAVX512::from(self) - rhs + } +} + +impl Distribution for Standard { + #[inline] + fn sample(&self, rng: &mut R) -> PackedBabyBearAVX512 { + PackedBabyBearAVX512(rng.gen()) + } +} + +// vpshrdq requires AVX-512VBMI2. +#[cfg(target_feature = "avx512vbmi2")] +#[inline] +#[must_use] +fn interleave1_antidiagonal(x: __m512i, y: __m512i) -> __m512i { + unsafe { + // Safety: If this code got compiled then AVX-512VBMI2 intrinsics are available. + x86_64::_mm512_shrdi_epi64::<32>(y, x) + } +} + +// If we can't use vpshrdq, then do a vpermi2d, but we waste a register and double the latency. +#[cfg(not(target_feature = "avx512vbmi2"))] +#[inline] +#[must_use] +fn interleave1_antidiagonal(x: __m512i, y: __m512i) -> __m512i { + const INTERLEAVE1_INDICES: __m512i = unsafe { + // Safety: `[u32; 16]` is trivially transmutable to `__m512i`. + transmute::<[u32; WIDTH], _>([ + 0x01, 0x10, 0x03, 0x12, 0x05, 0x14, 0x07, 0x16, 0x09, 0x18, 0x0b, 0x1a, 0x0d, 0x1c, + 0x0f, 0x1e, + ]) + }; + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + x86_64::_mm512_permutex2var_epi32(x, INTERLEAVE1_INDICES, y) + } +} + +#[inline] +#[must_use] +fn interleave1(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // If we have AVX-512VBMI2, we want this to compile to: + // vpshrdq t, x, y, 32 + // vpblendmd res0 {EVENS}, t, x + // vpblendmd res1 {EVENS}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 2 cyc + // + // Otherwise, we want it to compile to: + // vmovdqa32 t, INTERLEAVE1_INDICES + // vpermi2d t, x, y + // vpblendmd res0 {EVENS}, t, x + // vpblendmd res1 {EVENS}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 4 cyc + + // We currently have: + // x = [ x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xa xb xc xd xe xf ], + // y = [ y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 ya yb yc yd ye yf ]. + // First form + // t = [ x1 y0 x3 y2 x5 y4 x7 y6 x9 y8 xb ya xd yc xf ye ]. + let t = interleave1_antidiagonal(x, y); + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + + // Then + // res0 = [ x0 y0 x2 y2 x4 y4 x6 y6 x8 y8 xa ya xc yc xe ye ], + // res1 = [ x1 y1 x3 y3 x5 y5 x7 y7 x9 y9 xb yb xd yd xf yf ]. + ( + x86_64::_mm512_mask_blend_epi32(EVENS, t, x), + x86_64::_mm512_mask_blend_epi32(EVENS, y, t), + ) + } +} + +#[inline] +#[must_use] +fn shuffle_epi64(a: __m512i, b: __m512i) -> __m512i { + // The instruction is only available in the floating-point flavor; this distinction is only for + // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. + unsafe { + let a = x86_64::_mm512_castsi512_pd(a); + let b = x86_64::_mm512_castsi512_pd(b); + x86_64::_mm512_castpd_si512(x86_64::_mm512_shuffle_pd::(a, b)) + } +} + +#[inline] +#[must_use] +fn interleave2(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // We want this to compile to: + // vshufpd t, x, y, 55h + // vpblendmq res0 {EVENS}, t, x + // vpblendmq res1 {EVENS}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 2 cyc + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + + // We currently have: + // x = [ x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xa xb xc xd xe xf ], + // y = [ y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 ya yb yc yd ye yf ]. + // First form + // t = [ x2 x3 y0 y1 x6 x7 y4 y5 xa xb y8 y9 xe xf yc yd ]. + let t = shuffle_epi64::<0b01010101>(x, y); + + // Then + // res0 = [ x0 x1 y0 y1 x4 x5 y4 y5 x8 x9 y8 y9 xc xd yc yd ], + // res1 = [ x2 x3 y2 y3 x6 x7 y6 y7 xa xb ya yb xe xf ye yf ]. + ( + x86_64::_mm512_mask_blend_epi64(EVENS as __mmask8, t, x), + x86_64::_mm512_mask_blend_epi64(EVENS as __mmask8, y, t), + ) + } +} + +#[inline] +#[must_use] +fn interleave4(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // We want this to compile to: + // vmovdqa64 t, INTERLEAVE4_INDICES + // vpermi2q t, x, y + // vpblendmd res0 {EVENS4}, t, x + // vpblendmd res1 {EVENS4}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 4 cyc + + const INTERLEAVE4_INDICES: __m512i = unsafe { + // Safety: `[u64; 8]` is trivially transmutable to `__m512i`. + transmute::<[u64; WIDTH / 2], _>([0o02, 0o03, 0o10, 0o11, 0o06, 0o07, 0o14, 0o15]) + }; + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + + // We currently have: + // x = [ x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xa xb xc xd xe xf ], + // y = [ y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 ya yb yc yd ye yf ]. + // First form + // t = [ x4 x5 x6 x7 y0 y1 y2 y3 xc xd xe xf y8 y9 ya yb ]. + let t = x86_64::_mm512_permutex2var_epi64(x, INTERLEAVE4_INDICES, y); + + // Then + // res0 = [ x0 x1 x2 x3 y0 y1 y2 y3 x8 x9 xa xb y8 y9 ya yb ], + // res1 = [ x4 x5 x6 x7 y4 y5 y6 y7 xc xd xe xf yc yd ye yf ]. + ( + x86_64::_mm512_mask_blend_epi32(EVENS4, t, x), + x86_64::_mm512_mask_blend_epi32(EVENS4, y, t), + ) + } +} + +#[inline] +#[must_use] +fn interleave8(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // We want this to compile to: + // vshufi64x2 t, x, b, 4eh + // vpblendmq res0 {EVENS4}, t, x + // vpblendmq res1 {EVENS4}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 4 cyc + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + + // We currently have: + // x = [ x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xa xb xc xd xe xf ], + // y = [ y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 ya yb yc yd ye yf ]. + // First form + // t = [ x8 x9 xa xb xc xd xe xf y0 y1 y2 y3 y4 y5 y6 y7 ]. + let t = x86_64::_mm512_shuffle_i64x2::<0b01_00_11_10>(x, y); + + // Then + // res0 = [ x0 x1 x2 x3 x4 x5 x6 x7 y0 y1 y2 y3 y4 y5 y6 y7 ], + // res1 = [ x8 x9 xa xb xc xd xe xf y8 y9 ya yb yc yd ye yf ]. + ( + x86_64::_mm512_mask_blend_epi64(EVENS4 as __mmask8, t, x), + x86_64::_mm512_mask_blend_epi64(EVENS4 as __mmask8, y, t), + ) + } +} + +unsafe impl PackedValue for PackedBabyBearAVX512 { + type Value = BabyBear; + + const WIDTH: usize = WIDTH; + + #[inline] + fn from_slice(slice: &[BabyBear]) -> &Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { + // Safety: `[BabyBear; WIDTH]` can be transmuted to `PackedBabyBearAVX512` since the + // latter is `repr(transparent)`. They have the same alignment, so the reference cast is + // safe too. + &*slice.as_ptr().cast() + } + } + #[inline] + fn from_slice_mut(slice: &mut [BabyBear]) -> &mut Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { + // Safety: `[BabyBear; WIDTH]` can be transmuted to `PackedBabyBearAVX512` since the + // latter is `repr(transparent)`. They have the same alignment, so the reference cast is + // safe too. + &mut *slice.as_mut_ptr().cast() + } + } + + /// Similar to `core:array::from_fn`. + #[inline] + fn from_fn BabyBear>(f: F) -> Self { + let vals_arr: [_; WIDTH] = core::array::from_fn(f); + Self(vals_arr) + } + + #[inline] + fn as_slice(&self) -> &[BabyBear] { + &self.0[..] + } + #[inline] + fn as_slice_mut(&mut self) -> &mut [BabyBear] { + &mut self.0[..] + } +} + +unsafe impl PackedField for PackedBabyBearAVX512 { + type Scalar = BabyBear; + + #[inline] + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { + let (v0, v1) = (self.to_vector(), other.to_vector()); + let (res0, res1) = match block_len { + 1 => interleave1(v0, v1), + 2 => interleave2(v0, v1), + 4 => interleave4(v0, v1), + 8 => interleave8(v0, v1), + 16 => (v0, v1), + _ => panic!("unsupported block_len"), + }; + unsafe { + // Safety: all values are in canonical form (we haven't changed them). + (Self::from_vector(res0), Self::from_vector(res1)) + } + } +} + +#[cfg(test)] +mod tests { + use rand::SeedableRng; + use rand_chacha::ChaCha20Rng; + + use super::*; + + type F = BabyBear; + type P = PackedBabyBearAVX512; + + const fn array_from_valid_reps(vals: [u32; WIDTH]) -> [F; WIDTH] { + let mut res = [BabyBear { value: 0 }; WIDTH]; + let mut i = 0; + while i < WIDTH { + res[i] = BabyBear { value: vals[i] }; + i += 1; + } + res + } + + const fn packed_from_valid_reps(vals: [u32; WIDTH]) -> P { + PackedBabyBearAVX512(array_from_valid_reps(vals)) + } + + fn array_from_random(seed: u64) -> [F; WIDTH] { + let mut rng = ChaCha20Rng::seed_from_u64(seed); + [(); WIDTH].map(|_| rng.gen()) + } + + fn packed_from_random(seed: u64) -> P { + PackedBabyBearAVX512(array_from_random(seed)) + } + + const SPECIAL_VALS: [F; WIDTH] = array_from_valid_reps([ + 0x00000000, 0x00000001, 0x78000000, 0x77ffffff, 0x3c000000, 0x0ffffffe, 0x68000003, + 0x70000002, 0x00000000, 0x00000001, 0x78000000, 0x77ffffff, 0x3c000000, 0x0ffffffe, + 0x68000003, 0x70000002, + ]); + + #[test] + fn test_interleave_1() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let expected0 = packed_from_valid_reps([ + 0x00, 0x10, 0x02, 0x12, 0x04, 0x14, 0x06, 0x16, 0x08, 0x18, 0x0a, 0x1a, 0x0c, 0x1c, + 0x0e, 0x1e, + ]); + let expected1 = packed_from_valid_reps([ + 0x01, 0x11, 0x03, 0x13, 0x05, 0x15, 0x07, 0x17, 0x09, 0x19, 0x0b, 0x1b, 0x0d, 0x1d, + 0x0f, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 1); + assert_eq!(res0, expected0); + assert_eq!(res1, expected1); + } + + #[test] + fn test_interleave_2() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let expected0 = packed_from_valid_reps([ + 0x00, 0x01, 0x10, 0x11, 0x04, 0x05, 0x14, 0x15, 0x08, 0x09, 0x18, 0x19, 0x0c, 0x0d, + 0x1c, 0x1d, + ]); + let expected1 = packed_from_valid_reps([ + 0x02, 0x03, 0x12, 0x13, 0x06, 0x07, 0x16, 0x17, 0x0a, 0x0b, 0x1a, 0x1b, 0x0e, 0x0f, + 0x1e, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 2); + assert_eq!(res0, expected0); + assert_eq!(res1, expected1); + } + + #[test] + fn test_interleave_4() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let expected0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13, 0x08, 0x09, 0x0a, 0x0b, 0x18, 0x19, + 0x1a, 0x1b, + ]); + let expected1 = packed_from_valid_reps([ + 0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17, 0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 4); + assert_eq!(res0, expected0); + assert_eq!(res1, expected1); + } + + #[test] + fn test_interleave_8() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let expected0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, + ]); + let expected1 = packed_from_valid_reps([ + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 8); + assert_eq!(res0, expected0); + assert_eq!(res1, expected1); + } + + #[test] + fn test_interleave_16() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 16); + assert_eq!(res0, vec0); + assert_eq!(res1, vec1); + } + + #[test] + fn test_add_associative() { + let vec0 = packed_from_random(0x8b078c2b693c893f); + let vec1 = packed_from_random(0x4ff5dec04791e481); + let vec2 = packed_from_random(0x5806c495e9451f8e); + + let res0 = (vec0 + vec1) + vec2; + let res1 = vec0 + (vec1 + vec2); + + assert_eq!(res0, res1); + } + + #[test] + fn test_add_commutative() { + let vec0 = packed_from_random(0xe1bf9cac02e9072a); + let vec1 = packed_from_random(0xb5061e7de6a6c677); + + let res0 = vec0 + vec1; + let res1 = vec1 + vec0; + + assert_eq!(res0, res1); + } + + #[test] + fn test_additive_identity_right() { + let vec = packed_from_random(0xbcd56facf6a714b5); + let res = vec + P::zero(); + assert_eq!(res, vec); + } + + #[test] + fn test_additive_identity_left() { + let vec = packed_from_random(0xb614285cd641233c); + let res = P::zero() + vec; + assert_eq!(res, vec); + } + + #[test] + fn test_additive_inverse_add_neg() { + let vec = packed_from_random(0x4b89c8d023c9c62e); + let neg_vec = -vec; + let res = vec + neg_vec; + assert_eq!(res, P::zero()); + } + + #[test] + fn test_additive_inverse_sub() { + let vec = packed_from_random(0x2c94652ee5561341); + let res = vec - vec; + assert_eq!(res, P::zero()); + } + + #[test] + fn test_sub_anticommutative() { + let vec0 = packed_from_random(0xf3783730a14b460e); + let vec1 = packed_from_random(0x5b6f827a023525ee); + + let res0 = vec0 - vec1; + let res1 = -(vec1 - vec0); + + assert_eq!(res0, res1); + } + + #[test] + fn test_sub_zero() { + let vec = packed_from_random(0xc1a526f8226ec1e5); + let res = vec - P::zero(); + assert_eq!(res, vec); + } + + #[test] + fn test_zero_sub() { + let vec = packed_from_random(0x4444b9c090519333); + let res0 = P::zero() - vec; + let res1 = -vec; + assert_eq!(res0, res1); + } + + #[test] + fn test_neg_own_inverse() { + let vec = packed_from_random(0xee4df174b850a35f); + let res = --vec; + assert_eq!(res, vec); + } + + #[test] + fn test_sub_is_add_neg() { + let vec0 = packed_from_random(0x18f4b5c3a08e49fe); + let vec1 = packed_from_random(0x39bd37a1dc24d492); + let res0 = vec0 - vec1; + let res1 = vec0 + (-vec1); + assert_eq!(res0, res1); + } + + #[test] + fn test_mul_associative() { + let vec0 = packed_from_random(0x0b1ee4d7c979d50c); + let vec1 = packed_from_random(0x39faa0844a36e45a); + let vec2 = packed_from_random(0x08fac4ee76260e44); + + let res0 = (vec0 * vec1) * vec2; + let res1 = vec0 * (vec1 * vec2); + + assert_eq!(res0, res1); + } + + #[test] + fn test_mul_commutative() { + let vec0 = packed_from_random(0x10debdcbd409270c); + let vec1 = packed_from_random(0x927bc073c1c92b2f); + + let res0 = vec0 * vec1; + let res1 = vec1 * vec0; + + assert_eq!(res0, res1); + } + + #[test] + fn test_multiplicative_identity_right() { + let vec = packed_from_random(0xdf0a646b6b0c2c36); + let res = vec * P::one(); + assert_eq!(res, vec); + } + + #[test] + fn test_multiplicative_identity_left() { + let vec = packed_from_random(0x7b4d890bf7a38bd2); + let res = P::one() * vec; + assert_eq!(res, vec); + } + + #[test] + fn test_multiplicative_inverse() { + let arr = array_from_random(0xb0c7a5153103c5a8); + let arr_inv = arr.map(|x| x.inverse()); + + let vec = PackedBabyBearAVX512(arr); + let vec_inv = PackedBabyBearAVX512(arr_inv); + + let res = vec * vec_inv; + assert_eq!(res, P::one()); + } + + #[test] + fn test_mul_zero() { + let vec = packed_from_random(0x7f998daa72489bd7); + let res = vec * P::zero(); + assert_eq!(res, P::zero()); + } + + #[test] + fn test_zero_mul() { + let vec = packed_from_random(0x683bc2dd355b06e5); + let res = P::zero() * vec; + assert_eq!(res, P::zero()); + } + + #[test] + fn test_mul_negone() { + let vec = packed_from_random(0x97cb9670a8251202); + let res0 = vec * P::neg_one(); + let res1 = -vec; + assert_eq!(res0, res1); + } + + #[test] + fn test_negone_mul() { + let vec = packed_from_random(0xadae69873b5d3baf); + let res0 = P::neg_one() * vec; + let res1 = -vec; + assert_eq!(res0, res1); + } + + #[test] + fn test_neg_distributivity_left() { + let vec0 = packed_from_random(0xd0efd6f272c7de93); + let vec1 = packed_from_random(0xd5dd2cf5e76dd694); + + let res0 = vec0 * -vec1; + let res1 = -(vec0 * vec1); + + assert_eq!(res0, res1); + } + + #[test] + fn test_neg_distributivity_right() { + let vec0 = packed_from_random(0x0da9b03cd4b79b09); + let vec1 = packed_from_random(0x9964d3f4beaf1857); + + let res0 = -vec0 * vec1; + let res1 = -(vec0 * vec1); + + assert_eq!(res0, res1); + } + + #[test] + fn test_add_distributivity_left() { + let vec0 = packed_from_random(0x278d9e202925a1d1); + let vec1 = packed_from_random(0xf04cbac0cbad419f); + let vec2 = packed_from_random(0x76976e2abdc5a056); + + let res0 = vec0 * (vec1 + vec2); + let res1 = vec0 * vec1 + vec0 * vec2; + + assert_eq!(res0, res1); + } + + #[test] + fn test_add_distributivity_right() { + let vec0 = packed_from_random(0xbe1b606eafe2a2b8); + let vec1 = packed_from_random(0x552686a0978ab571); + let vec2 = packed_from_random(0x36f6eec4fd31a460); + + let res0 = (vec0 + vec1) * vec2; + let res1 = vec0 * vec2 + vec1 * vec2; + + assert_eq!(res0, res1); + } + + #[test] + fn test_sub_distributivity_left() { + let vec0 = packed_from_random(0x817d4a27febb0349); + let vec1 = packed_from_random(0x1eaf62a921d6519b); + let vec2 = packed_from_random(0xfec0fb8d3849465a); + + let res0 = vec0 * (vec1 - vec2); + let res1 = vec0 * vec1 - vec0 * vec2; + + assert_eq!(res0, res1); + } + + #[test] + fn test_sub_distributivity_right() { + let vec0 = packed_from_random(0x5a4a82e8e2394585); + let vec1 = packed_from_random(0x6006b1443a22b102); + let vec2 = packed_from_random(0x5a22deac65fcd454); + + let res0 = (vec0 - vec1) * vec2; + let res1 = vec0 * vec2 - vec1 * vec2; + + assert_eq!(res0, res1); + } + + #[test] + fn test_one_plus_one() { + assert_eq!(P::one() + P::one(), P::two()); + } + + #[test] + fn test_negone_plus_two() { + assert_eq!(P::neg_one() + P::two(), P::one()); + } + + #[test] + fn test_double() { + let vec = packed_from_random(0x2e61a907650881e9); + let res0 = P::two() * vec; + let res1 = vec + vec; + assert_eq!(res0, res1); + } + + #[test] + fn test_add_vs_scalar() { + let arr0 = array_from_random(0xac23b5a694dabf70); + let arr1 = array_from_random(0xd249ec90e8a6e733); + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 + vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] + arr1[i]); + } + } + + #[test] + fn test_add_vs_scalar_special_vals_left() { + let arr0 = SPECIAL_VALS; + let arr1 = array_from_random(0x1e2b153f07b64cf3); + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 + vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] + arr1[i]); + } + } + + #[test] + fn test_add_vs_scalar_special_vals_right() { + let arr0 = array_from_random(0xfcf974ac7625a260); + let arr1 = SPECIAL_VALS; + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 + vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] + arr1[i]); + } + } + + #[test] + fn test_sub_vs_scalar() { + let arr0 = array_from_random(0x167ce9d8e920876e); + let arr1 = array_from_random(0x52ddcdd3461e046f); + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 - vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] - arr1[i]); + } + } + + #[test] + fn test_sub_vs_scalar_special_vals_left() { + let arr0 = SPECIAL_VALS; + let arr1 = array_from_random(0x358498640bfe1375); + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 - vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] - arr1[i]); + } + } + + #[test] + fn test_sub_vs_scalar_special_vals_right() { + let arr0 = array_from_random(0x05d81ebfb8f0005c); + let arr1 = SPECIAL_VALS; + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 - vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] - arr1[i]); + } + } + + #[test] + fn test_mul_vs_scalar() { + let arr0 = array_from_random(0x4242ebdc09b74d77); + let arr1 = array_from_random(0x9937b275b3c056cd); + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 * vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] * arr1[i]); + } + } + + #[test] + fn test_mul_vs_scalar_special_vals_left() { + let arr0 = SPECIAL_VALS; + let arr1 = array_from_random(0x5285448b835458a3); + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 * vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] * arr1[i]); + } + } + + #[test] + fn test_mul_vs_scalar_special_vals_right() { + let arr0 = array_from_random(0x22508dc80001d865); + let arr1 = SPECIAL_VALS; + + let vec0 = PackedBabyBearAVX512(arr0); + let vec1 = PackedBabyBearAVX512(arr1); + let vec_res = vec0 * vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] * arr1[i]); + } + } + + #[test] + fn test_neg_vs_scalar() { + let arr = array_from_random(0xc3c273a9b334372f); + + let vec = PackedBabyBearAVX512(arr); + let vec_res = -vec; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], -arr[i]); + } + } + + #[test] + fn test_neg_vs_scalar_special_vals() { + let arr = SPECIAL_VALS; + + let vec = PackedBabyBearAVX512(arr); + let vec_res = -vec; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], -arr[i]); + } + } +} diff --git a/challenger/src/duplex_challenger.rs b/challenger/src/duplex_challenger.rs index 5ed9f56b..97335894 100644 --- a/challenger/src/duplex_challenger.rs +++ b/challenger/src/duplex_challenger.rs @@ -100,6 +100,21 @@ where } } +// for TrivialPcs +impl CanObserve>> for DuplexChallenger +where + F: Copy, + P: CryptographicPermutation<[F; WIDTH]>, +{ + fn observe(&mut self, valuess: Vec>) { + for values in valuess { + for value in values { + self.observe(value); + } + } + } +} + impl CanSample for DuplexChallenger where F: Field, diff --git a/multi-stark/Cargo.toml b/circle/Cargo.toml similarity index 61% rename from multi-stark/Cargo.toml rename to circle/Cargo.toml index 20ceae1b..786957c8 100644 --- a/multi-stark/Cargo.toml +++ b/circle/Cargo.toml @@ -1,18 +1,26 @@ [package] -name = "p3-multi-stark" +name = "p3-circle" version = "0.1.0" edition = "2021" license = "MIT OR Apache-2.0" [dependencies] -p3-air = { path = "../air" } -p3-field = { path = "../field" } p3-challenger = { path = "../challenger" } p3-commit = { path = "../commit" } +p3-dft = { path = "../dft" } +p3-field = { path = "../field" } p3-matrix = { path = "../matrix" } p3-maybe-rayon = { path = "../maybe-rayon" } p3-util = { path = "../util" } +tracing = "0.1.37" +itertools = "0.12.0" + [dev-dependencies] +p3-mds = { path = "../mds" } +p3-mersenne-31 = { path = "../mersenne-31" } +p3-merkle-tree = { path = "../merkle-tree" } +p3-poseidon = { path = "../poseidon" } p3-symmetric = { path = "../symmetric" } + rand = "0.8.5" diff --git a/circle/src/cfft.rs b/circle/src/cfft.rs new file mode 100644 index 00000000..c732e1a9 --- /dev/null +++ b/circle/src/cfft.rs @@ -0,0 +1,231 @@ +//! The Circle FFT and its inverse, as detailed in +//! Circle STARKs, Section 4.2 (page 14 of the first revision PDF) +//! This code is based on Angus Gruen's implementation, which uses a slightly +//! different cfft basis than that of the paper. Basically, it continues using the +//! same twiddles for the second half of the chunk, which only changes the sign of the +//! resulting basis. For a full explanation see the comments in `util::circle_basis`. +//! This alternate basis doesn't cause any change to the code apart from our testing functions. + +use alloc::rc::Rc; +use alloc::vec::Vec; +use core::cell::RefCell; + +use p3_commit::PolynomialSpace; +use p3_dft::divide_by_height; +use p3_field::extension::{Complex, ComplexExtendable}; +use p3_field::{AbstractField, PackedValue}; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{Matrix, MatrixRowChunksMut, MatrixRowSlices, MatrixRowSlicesMut}; +use p3_maybe_rayon::prelude::*; +use p3_util::log2_strict_usize; +use tracing::instrument; + +use crate::domain::CircleDomain; +use crate::twiddles::TwiddleCache; + +#[derive(Default, Clone)] +pub struct Cfft(Rc>>); + +impl Cfft { + pub fn cfft(&self, vec: Vec) -> Vec { + self.cfft_batch(RowMajorMatrix::new_col(vec)).values + } + pub fn cfft_batch>(&self, mat: M) -> M { + let log_n = log2_strict_usize(mat.height()); + self.coset_cfft_batch(mat, F::circle_two_adic_generator(log_n + 1)) + } + /// The cfft: interpolating evaluations over a domain to the (sign-switched) cfft basis + #[instrument(skip_all, fields(dims = %mat.dimensions()))] + pub fn coset_cfft_batch>(&self, mut mat: M, shift: Complex) -> M { + let n = mat.height(); + let log_n = log2_strict_usize(n); + + let mut cache = self.0.borrow_mut(); + let twiddles = cache.get_twiddles(log_n, shift, true); + + for (i, twiddle) in twiddles.iter().enumerate() { + let block_size = 1 << (log_n - i); + let half_block_size = block_size >> 1; + assert_eq!(twiddle.len(), half_block_size); + + mat.par_row_chunks_mut(block_size).for_each(|mut chunk| { + for (i, &t) in twiddle.iter().enumerate() { + let (lo, hi) = chunk.row_pair_slices_mut(i, block_size - i - 1); + let (lo_packed, lo_suffix) = F::Packing::pack_slice_with_suffix_mut(lo); + let (hi_packed, hi_suffix) = F::Packing::pack_slice_with_suffix_mut(hi); + dif_butterfly(lo_packed, hi_packed, t.into()); + dif_butterfly(lo_suffix, hi_suffix, t); + } + }); + } + // TODO: omit this? + divide_by_height(&mut mat); + mat + } + + pub fn icfft(&self, vec: Vec) -> Vec { + self.icfft_batch(RowMajorMatrix::new_col(vec)).values + } + pub fn icfft_batch>(&self, mat: M) -> M { + let log_n = log2_strict_usize(mat.height()); + self.coset_icfft_batch(mat, F::circle_two_adic_generator(log_n + 1)) + } + /// The icfft: evaluating a polynomial in monomial basis over a domain + #[instrument(skip_all, fields(dims = %mat.dimensions()))] + pub fn coset_icfft_batch>(&self, mat: M, shift: Complex) -> M { + self.coset_icfft_batch_skipping_first_layers(mat, shift, 0) + } + #[instrument(skip_all, fields(dims = %mat.dimensions()))] + fn coset_icfft_batch_skipping_first_layers>( + &self, + mut mat: M, + shift: Complex, + num_skipped_layers: usize, + ) -> M { + let n = mat.height(); + let log_n = log2_strict_usize(n); + + let mut cache = self.0.borrow_mut(); + let twiddles = cache.get_twiddles(log_n, shift, false); + + for (i, twiddle) in twiddles.iter().rev().enumerate().skip(num_skipped_layers) { + let block_size = 1 << (i + 1); + let half_block_size = block_size >> 1; + assert_eq!(twiddle.len(), half_block_size); + + mat.par_row_chunks_mut(block_size).for_each(|mut chunk| { + for (i, &t) in twiddle.iter().enumerate() { + let (lo, hi) = chunk.row_pair_slices_mut(i, block_size - i - 1); + let (lo_packed, lo_suffix) = F::Packing::pack_slice_with_suffix_mut(lo); + let (hi_packed, hi_suffix) = F::Packing::pack_slice_with_suffix_mut(hi); + dit_butterfly(lo_packed, hi_packed, t.into()); + dit_butterfly(lo_suffix, hi_suffix, t); + } + }); + } + + mat + } + + #[instrument(skip_all, fields(dims = %mat.dimensions()))] + pub fn lde>( + &self, + mut mat: M, + src_domain: CircleDomain, + target_domain: CircleDomain, + ) -> RowMajorMatrix { + assert_eq!(mat.height(), src_domain.size()); + assert!(target_domain.size() >= src_domain.size()); + let added_bits = target_domain.log_n - src_domain.log_n; + + // CFFT + mat = self.coset_cfft_batch(mat, src_domain.shift); + + /* + To do an LDE, we could interleave zeros into the coefficients, but + the first `added_bits` layers will simply fill out the zeros with the + lower order values. (In `ibutterfly`, `hi` will start as zero, and + both `lo` and `hi` are set to `lo`). So instead, we do the tiling directly + and skip the first `added_bits` layers. + */ + let tiled_mat = tile_rows(mat, 1 << added_bits); + debug_assert_eq!(tiled_mat.height(), target_domain.size()); + + self.coset_icfft_batch_skipping_first_layers(tiled_mat, target_domain.shift, added_bits) + } +} + +/// Division-in-frequency +#[inline(always)] +fn dif_butterfly(lo_chunk: &mut [F], hi_chunk: &mut [F], twiddle: F) { + for (lo, hi) in lo_chunk.iter_mut().zip(hi_chunk) { + let sum = *lo + *hi; + let diff = (*lo - *hi) * twiddle; + *lo = sum; + *hi = diff; + } +} + +/// Division-in-time +#[inline(always)] +fn dit_butterfly(lo_chunk: &mut [F], hi_chunk: &mut [F], twiddle: F) { + for (lo, hi) in lo_chunk.iter_mut().zip(hi_chunk) { + let hi_twiddle = *hi * twiddle; + let sum = *lo + hi_twiddle; + let diff = *lo - hi_twiddle; + *lo = sum; + *hi = diff; + } +} + +// Repeats rows +// TODO this can be micro-optimized +fn tile_rows(mat: impl MatrixRowSlices, repetitions: usize) -> RowMajorMatrix { + let mut values = Vec::with_capacity(mat.width() * mat.height() * repetitions); + for r in mat.row_slices() { + for _ in 0..repetitions { + values.extend_from_slice(r); + } + } + RowMajorMatrix::new(values, mat.width()) +} + +#[cfg(test)] +mod tests { + use p3_dft::bit_reversed_zero_pad; + use p3_mersenne_31::Mersenne31; + use rand::{thread_rng, Rng}; + + use super::*; + use crate::domain::CircleDomain; + use crate::util::{eval_circle_polys, univariate_to_point}; + + type F = Mersenne31; + + fn do_test_cfft(log_n: usize) { + let n = 1 << log_n; + let cfft = Cfft::default(); + + let shift: Complex = univariate_to_point(thread_rng().gen()).unwrap(); + + let evals = RowMajorMatrix::::rand(&mut thread_rng(), n, 1 << 5); + let coeffs = cfft.coset_cfft_batch(evals.clone(), shift); + + assert_eq!(evals.clone(), cfft.coset_icfft_batch(coeffs.clone(), shift)); + + let d = CircleDomain { shift, log_n }; + for (pt, ys) in d.points().zip(evals.rows()) { + assert_eq!(ys, eval_circle_polys(&coeffs, pt)); + } + } + + #[test] + fn test_cfft() { + do_test_cfft(5); + do_test_cfft(8); + } + + fn do_test_lde(log_n: usize, added_bits: usize) { + let n = 1 << log_n; + let cfft = Cfft::::default(); + + let shift: Complex = univariate_to_point(thread_rng().gen()).unwrap(); + + let evals = RowMajorMatrix::::rand(&mut thread_rng(), n, 1); + let src_domain = CircleDomain { log_n, shift }; + let target_domain = CircleDomain::standard(log_n + added_bits); + + let mut coeffs = cfft.coset_cfft_batch(evals.clone(), src_domain.shift); + bit_reversed_zero_pad(&mut coeffs, added_bits); + let expected = cfft.coset_icfft_batch(coeffs, target_domain.shift); + + let actual = cfft.lde(evals, src_domain, target_domain); + + assert_eq!(actual, expected); + } + + #[test] + fn test_lde() { + do_test_lde(3, 1); + } +} diff --git a/circle/src/domain.rs b/circle/src/domain.rs new file mode 100644 index 00000000..910ab8b2 --- /dev/null +++ b/circle/src/domain.rs @@ -0,0 +1,362 @@ +use alloc::vec; +use alloc::vec::Vec; + +use itertools::Itertools; +use p3_commit::{LagrangeSelectors, PolynomialSpace}; +use p3_field::extension::{Complex, ComplexExtendable}; +use p3_field::{batch_multiplicative_inverse, AbstractField, ExtensionField}; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::Matrix; +use p3_util::{log2_ceil_usize, log2_strict_usize}; +use tracing::instrument; + +use crate::util::{point_to_univariate, s_p_at_p, univariate_to_point, v_0, v_n}; + +/// A twin-coset of the circle group on F. It has a power-of-two size and an arbitrary shift. +/// +/// X is generator, O is the first coset, goes counterclockwise +/// ```text +/// O X . +/// . . +/// . O <- start = shift +/// . . - (1,0) +/// O . +/// . . +/// . . O +/// ``` +/// +/// For ordering reasons, the other half will start at gen / shift: +/// ```text +/// . X O <- start = gen/shift +/// . . +/// O . +/// . . - (1,0) +/// . O +/// . . +/// O . . +/// ``` +/// +/// The full domain is the interleaving of these two cosets +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct CircleDomain { + // log_n corresponds to the log size of the WHOLE domain + pub(crate) log_n: usize, + pub(crate) shift: Complex, +} + +impl CircleDomain { + pub(crate) fn new(log_n: usize, shift: Complex) -> Self { + Self { log_n, shift } + } + pub(crate) fn standard(log_n: usize) -> Self { + Self { + log_n, + shift: F::circle_two_adic_generator(log_n + 1), + } + } + fn is_standard(&self) -> bool { + self.shift == F::circle_two_adic_generator(self.log_n + 1) + } + pub(crate) fn points(&self) -> impl Iterator> { + let half_gen = F::circle_two_adic_generator(self.log_n - 1); + let coset0 = half_gen.shifted_powers(self.shift); + let coset1 = half_gen.shifted_powers(half_gen / self.shift); + coset0.interleave(coset1).take(1 << self.log_n) + } + + /// Computes the lagrange basis at point, not yet normalized by the value of the domain + /// vanishing poly, since that is more efficient to compute after the dot product. + #[instrument(skip_all, fields(log_n = %self.log_n))] + pub(crate) fn lagrange_basis>(&self, point: Complex) -> Vec { + let domain = self.points().collect_vec(); + + // the denominator so that the lagrange basis is normalized to 1 + // TODO: this depends only on domain, so should be precomputed + let lagrange_normalizer: Vec = domain + .iter() + .map(|p| s_p_at_p(p.real(), p.imag(), self.log_n)) + .collect(); + + let basis = domain + .into_iter() + .zip(&lagrange_normalizer) + .map(|(p, &ln)| { + // ext * base + // TODO: this can be sped up + v_0(p.conjugate().rotate(point)) * ln + }) + .collect_vec(); + + batch_multiplicative_inverse(&basis) + } +} + +impl PolynomialSpace for CircleDomain { + type Val = F; + + fn size(&self) -> usize { + 1 << self.log_n + } + + fn first_point(&self) -> Self::Val { + point_to_univariate(self.shift).unwrap() + } + + fn next_point>(&self, x: Ext) -> Option { + // Only in standard position do we have an algebraic expression to access the next point. + if self.is_standard() { + let gen = F::circle_two_adic_generator(self.log_n); + Some(point_to_univariate(gen.rotate(univariate_to_point(x).unwrap())).unwrap()) + } else { + None + } + } + + fn create_disjoint_domain(&self, min_size: usize) -> Self { + // Right now we simply guarantee the domain is disjoint by returning a + // larger standard position coset, which is fine because we always ask for a larger + // domain. If we wanted good performance for a disjoint domain of the same size, + // we could change the shift. Also we could support nonstandard twin cosets. + assert!( + self.is_standard(), + "create_disjoint_domain not currently supported for nonstandard twin cosets" + ); + let log_n = log2_ceil_usize(min_size); + // Any standard position coset that is not the same size as us will be disjoint. + Self::standard(if log_n == self.log_n { + log_n + 1 + } else { + log_n + }) + } + + fn zp_at_point>(&self, point: Ext) -> Ext { + v_n(univariate_to_point(point).unwrap().real(), self.log_n) + - v_n(self.shift.real(), self.log_n) + } + + fn selectors_at_point>( + &self, + point: Ext, + ) -> LagrangeSelectors { + let zeroifier = self.zp_at_point(point); + let p = univariate_to_point(point).unwrap(); + LagrangeSelectors { + is_first_row: zeroifier / v_0(self.shift.conjugate().rotate(p)), + is_last_row: zeroifier / v_0(self.shift.rotate(p)), + // This is the transition selector from the paper, but seems to send + // the quotient out of FFT space. It has a simple zero at the last point + // and a simple pole at the negation of the last point. + // is_transition: v_0(self.shift.rotate(p)), + // Instead we use this selector which has two zeros at the last point, + // which seems to work. TODO: More investigation is needed. + is_transition: self.shift.rotate(p).real() - Ext::one(), + inv_zeroifier: zeroifier.inverse(), + } + } + + // wow, really slow! + #[instrument(skip_all, fields(log_n = %coset.log_n))] + fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors> { + let sels = coset + .points() + .map(|p| self.selectors_at_point(point_to_univariate(p).unwrap())) + .collect_vec(); + LagrangeSelectors { + is_first_row: sels.iter().map(|s| s.is_first_row).collect(), + is_last_row: sels.iter().map(|s| s.is_last_row).collect(), + is_transition: sels.iter().map(|s| s.is_transition).collect(), + inv_zeroifier: sels.iter().map(|s| s.inv_zeroifier).collect(), + } + } + + /// Decompose a domain into disjoint twin-cosets. + fn split_domains(&self, num_chunks: usize) -> Vec { + assert!(self.is_standard()); + let log_chunks = log2_strict_usize(num_chunks); + self.points() + .take(num_chunks) + .map(|shift| CircleDomain { + log_n: self.log_n - log_chunks, + shift, + }) + .collect() + } + + /* + chunks=2: + + 1 . 1 + . . + 0 0 <-- start + . . - (1,0) + 0 0 + . . + 1 . 1 + + + idx -> which chunk to put it in: + chunks=2: 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0 + chunks=4: 0 1 2 3 3 2 1 0 0 1 2 3 3 2 1 0 + */ + fn split_evals( + &self, + num_chunks: usize, + evals: RowMajorMatrix, + ) -> Vec> { + let log_chunks = log2_strict_usize(num_chunks); + assert!(evals.height() >> (log_chunks + 1) >= 1); + let width = evals.width(); + let mut values: Vec> = vec![vec![]; num_chunks]; + let mut rows = evals.rows(); + for _ in 0..(evals.height() >> (log_chunks + 1)) { + for chunk in values.iter_mut() { + chunk.extend_from_slice(rows.next().unwrap()); + } + for chunk in values.iter_mut().rev() { + chunk.extend_from_slice(rows.next().unwrap()); + } + } + assert!(rows.next().is_none()); + + values + .into_iter() + .map(|v| RowMajorMatrix::new(v, width)) + .collect() + } +} + +#[cfg(test)] +mod tests { + + use std::collections::HashSet; + + use itertools::izip; + use p3_matrix::routines::columnwise_dot_product; + use p3_mersenne_31::Mersenne31; + use rand::{thread_rng, Rng}; + + use super::*; + use crate::util::eval_circle_polys; + use crate::Cfft; + + fn assert_is_twin_coset(d: CircleDomain) { + let pts = d.points().collect_vec(); + let half_n = pts.len() >> 1; + for (l, r) in izip!(&pts[..half_n], pts[half_n..].iter().rev()) { + assert_eq!(*l, r.conjugate()); + } + } + + fn do_test_circle_domain(log_n: usize, width: usize) { + let n = 1 << log_n; + + type F = Mersenne31; + let d = CircleDomain::::standard(log_n); + + // we can move around the circle and end up where we started + let p0 = d.first_point(); + let mut p1 = p0; + for _ in 0..(n - 1) { + p1 = d.next_point(p1).unwrap(); + assert_ne!(p1, p0); + } + assert_eq!(d.next_point(p1).unwrap(), p0); + + // .points() is the same as first_point -> next_point + let mut uni_point = d.first_point(); + for p in d.points() { + assert_eq!(univariate_to_point(uni_point), Some(p)); + uni_point = d.next_point(uni_point).unwrap(); + } + + // disjoint domain is actually disjoint, and large enough + let seen: HashSet> = d.points().collect(); + for disjoint_size in [10, 100, n - 5, n + 15] { + let dd = d.create_disjoint_domain(disjoint_size); + assert!(dd.size() >= disjoint_size); + for pt in dd.points() { + assert!(!seen.contains(&pt)); + } + } + + // zp is zero + for p in d.points() { + assert_eq!(d.zp_at_point(point_to_univariate(p).unwrap()), F::zero()); + } + + // split domains + let evals = RowMajorMatrix::rand(&mut thread_rng(), n, width); + let orig: Vec<(Complex, Vec)> = + d.points().zip(evals.rows().map(|r| r.to_vec())).collect(); + for num_chunks in [1, 2, 4, 8] { + let mut combined = vec![]; + + let sds = d.split_domains(num_chunks); + assert_eq!(sds.len(), num_chunks); + let ses = d.split_evals(num_chunks, evals.clone()); + assert_eq!(ses.len(), num_chunks); + for (sd, se) in izip!(sds, ses) { + // Split domains are twin cosets + assert_is_twin_coset(sd); + // Split domains have correct size wrt original domain + assert_eq!(sd.size() * num_chunks, d.size()); + assert_eq!(se.width(), evals.width()); + assert_eq!(se.height() * num_chunks, d.size()); + combined.extend(sd.points().zip(se.rows().map(|r| r.to_vec()))); + } + // Union of split domains and evals is the original domain and evals + assert_eq!( + orig.iter().map(|x| x.0).collect::>(), + combined.iter().map(|x| x.0).collect::>(), + "union of split domains is orig domain" + ); + assert_eq!( + orig.iter().map(|x| &x.1).collect::>(), + combined.iter().map(|x| &x.1).collect::>(), + "union of split evals is orig evals" + ); + assert_eq!( + orig.iter().collect::>(), + combined.iter().collect::>(), + "split domains and evals correspond to orig domains and evals" + ); + } + } + + #[test] + fn test_circle_domain() { + do_test_circle_domain(4, 32); + do_test_circle_domain(10, 32); + } + + #[test] + fn test_barycentric() { + let log_n = 10; + let n = 1 << log_n; + + type F = Mersenne31; + + let evals = RowMajorMatrix::::rand(&mut thread_rng(), n, 16); + + let cfft = Cfft::default(); + + let shift: Complex = univariate_to_point(thread_rng().gen()).unwrap(); + let d = CircleDomain { shift, log_n }; + + let coeffs = cfft.coset_cfft_batch(evals.clone(), shift); + + // simple barycentric + let zeta: Complex = univariate_to_point(thread_rng().gen()).unwrap(); + + let basis = d.lagrange_basis(zeta); + let v_n_at_zeta = v_n(zeta.real(), log_n) - v_n(shift.real(), log_n); + + let ys = columnwise_dot_product(evals, basis.into_iter()) + .into_iter() + .map(|x| x * v_n_at_zeta) + .collect_vec(); + + assert_eq!(ys, eval_circle_polys(&coeffs, zeta)); + } +} diff --git a/circle/src/lib.rs b/circle/src/lib.rs new file mode 100644 index 00000000..7d4283f0 --- /dev/null +++ b/circle/src/lib.rs @@ -0,0 +1,15 @@ +//! A framework for operating over the unit circle of a finite field, +//! following the [Circle STARKs paper](https://eprint.iacr.org/2024/278) by Haböck, Levit and Papini. + +#![cfg_attr(not(test), no_std)] + +extern crate alloc; + +mod cfft; +mod domain; +mod pcs; +mod twiddles; +mod util; + +pub use cfft::*; +pub use pcs::*; diff --git a/circle/src/pcs.rs b/circle/src/pcs.rs new file mode 100644 index 00000000..4151388d --- /dev/null +++ b/circle/src/pcs.rs @@ -0,0 +1,146 @@ +use alloc::vec::Vec; + +use itertools::izip; +use p3_commit::{DirectMmcs, OpenedValues, Pcs}; +use p3_field::extension::ComplexExtendable; +use p3_field::ExtensionField; +use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; +use p3_matrix::routines::columnwise_dot_product; +use p3_matrix::{Matrix, MatrixRows}; +use p3_util::log2_strict_usize; +use tracing::instrument; + +use crate::cfft::Cfft; +use crate::domain::CircleDomain; +use crate::util::{univariate_to_point, v_n}; + +pub struct CirclePcs { + pub log_blowup: usize, + pub cfft: Cfft, + pub mmcs: InputMmcs, +} + +pub struct ProverData { + committed_domains: Vec>, + mmcs_data: MmcsData, +} + +impl Pcs for CirclePcs +where + Val: ComplexExtendable, + Challenge: ExtensionField, + InputMmcs: 'static + for<'a> DirectMmcs = RowMajorMatrixView<'a, Val>>, +{ + type Domain = CircleDomain; + type Commitment = InputMmcs::Commitment; + type ProverData = ProverData; + type Proof = (); + type Error = (); + + fn natural_domain_for_degree(&self, degree: usize) -> Self::Domain { + CircleDomain::standard(log2_strict_usize(degree)) + } + + fn commit( + &self, + evaluations: Vec<(Self::Domain, RowMajorMatrix)>, + ) -> (Self::Commitment, Self::ProverData) { + let (committed_domains, ldes): (Vec<_>, Vec<_>) = evaluations + .into_iter() + .map(|(domain, evals)| { + let committed_domain = CircleDomain::standard(domain.log_n + self.log_blowup); + // bitrev for fri? + let lde = self.cfft.lde(evals, domain, committed_domain); + (committed_domain, lde) + }) + .unzip(); + let (comm, mmcs_data) = self.mmcs.commit(ldes); + ( + comm, + ProverData { + committed_domains, + mmcs_data, + }, + ) + } + + fn get_evaluations_on_domain( + &self, + data: &Self::ProverData, + idx: usize, + domain: Self::Domain, + ) -> RowMajorMatrix { + // TODO do this correctly + let mat = self.mmcs.get_matrices(&data.mmcs_data)[idx]; + assert_eq!(mat.height(), 1 << domain.log_n); + assert_eq!(domain, data.committed_domains[idx]); + mat.to_row_major_matrix() + } + + #[instrument(skip_all)] + fn open( + &self, + // For each round, + rounds: Vec<( + &Self::ProverData, + // for each matrix, + Vec< + // points to open + Vec, + >, + )>, + _challenger: &mut Challenger, + ) -> (OpenedValues, Self::Proof) { + let values: OpenedValues = rounds + .into_iter() + .map(|(data, points_for_mats)| { + let mats = self.mmcs.get_matrices(&data.mmcs_data); + izip!(&data.committed_domains, mats, points_for_mats) + .map(|(domain, mat, points_for_mat)| { + let log_n = log2_strict_usize(mat.height()); + points_for_mat + .into_iter() + .map(|zeta| { + let zeta_point = univariate_to_point(zeta).unwrap(); + let basis: Vec = domain.lagrange_basis(zeta_point); + let v_n_at_zeta = + v_n(zeta_point.real(), log_n) - v_n(domain.shift.real(), log_n); + columnwise_dot_product(mat, basis.into_iter()) + .into_iter() + .map(|x| x * v_n_at_zeta) + .collect() + }) + .collect() + }) + .collect() + }) + .collect(); + // todo: fri prove + (values, ()) + } + + fn verify( + &self, + // For each round: + _rounds: Vec<( + Self::Commitment, + // for each matrix: + Vec<( + // its domain, + Self::Domain, + // for each point: + Vec<( + // the point, + Challenge, + // values at the point + Vec, + )>, + )>, + )>, + _proof: &Self::Proof, + _challenger: &mut Challenger, + ) -> Result<(), Self::Error> { + // todo: fri verify + Ok(()) + } +} diff --git a/circle/src/twiddles.rs b/circle/src/twiddles.rs new file mode 100644 index 00000000..601cc9d5 --- /dev/null +++ b/circle/src/twiddles.rs @@ -0,0 +1,78 @@ +use alloc::vec::Vec; +use core::mem; + +use itertools::Itertools; +use p3_field::batch_multiplicative_inverse; +use p3_field::extension::{Complex, ComplexExtendable}; +use p3_util::linear_map::LinearMap; +use tracing::instrument; + +use crate::domain::CircleDomain; + +#[derive(Default)] +pub(crate) struct TwiddleCache( + // (log_n, shift) -> (twiddles, inverse_twiddles) + #[allow(clippy::type_complexity)] + LinearMap<(usize, Complex), (Vec>, Option>>)>, +); + +impl TwiddleCache { + pub(crate) fn get_twiddles( + &mut self, + log_n: usize, + shift: Complex, + inv: bool, + ) -> &Vec> { + let cache = self + .0 + .get_or_insert_with((log_n, shift), || (compute_twiddles(log_n, shift), None)); + if !inv { + &cache.0 + } else { + cache.1.get_or_insert_with(|| { + cache + .0 + .iter() + .map(|xs| batch_multiplicative_inverse(xs)) + .collect() + }) + } + } +} + +/// Computes all (non-inverted) twiddles for the FFT over a circle domain of size 2^log_n, for all layers of the FFT. +#[instrument(skip(shift))] +fn compute_twiddles(log_n: usize, shift: Complex) -> Vec> { + let size = 1 << (log_n - 1); + + let init_domain = CircleDomain::new(log_n, shift) + .points() + .take(size) + .collect_vec(); + + // After the first step we only need the real part. + let mut working_domain: Vec<_> = init_domain + .iter() + .take(size / 2) + .map(|x| x.real()) + .collect(); + + (0..log_n) + .map(|i| { + let size = working_domain.len(); + let output = if i == 0 { + // The twiddles in step one are the inverse imaginary parts. + init_domain.iter().map(|x| x.imag()).collect_vec() + } else { + let new_working_domain = working_domain + .iter() + .take(size / 2) + // When we square a point, the real part changes as x -> 2x^2 - 1. + .map(|x| x.square().double() - F::one()) + .collect(); + mem::replace(&mut working_domain, new_working_domain) + }; + output + }) + .collect() +} diff --git a/circle/src/util.rs b/circle/src/util.rs new file mode 100644 index 00000000..dddb8f3a --- /dev/null +++ b/circle/src/util.rs @@ -0,0 +1,167 @@ +use p3_field::extension::Complex; +#[cfg(test)] +use p3_field::extension::ComplexExtendable; +use p3_field::{ExtensionField, Field}; +#[cfg(test)] +use p3_matrix::dense::RowMajorMatrix; +#[cfg(test)] +use p3_util::{log2_strict_usize, reverse_slice_index_bits}; + +/// Get the cfft polynomial basis. +/// The basis consists off all multi-linear products of: y, x, 2x^2 - 1, 2(2x^2 - 1)^2 - 1, ... +/// The ordering of these basis elements is the bit reversal of the sequence: 1, y, x, xy, (2x^2 - 1), (2x^2 - 1)y, ... +/// We also need to throw in a couple of negative signs for technical reasons. +#[cfg(test)] +pub(crate) fn circle_basis(point: Complex, log_n: usize) -> Vec { + if log_n == 0 { + return vec![F::one()]; + } + + let mut basis = vec![F::one()]; + basis.reserve(1 << log_n); + + // First compute the repeated applications of the squaring map π(x) = 2x^2 - 1 + let mut cur = point.real(); + for _ in 0..(log_n - 1) { + for i in 0..basis.len() { + basis.push(basis[i] * cur); + } + cur = F::two() * cur.square() - F::one(); + } + + // Bit reverse, and compute the second half of the array, + // which is just each element of the first half times y + reverse_slice_index_bits(&mut basis); + for i in 0..basis.len() { + basis.push(basis[i] * point.imag()); + } + + // Negate each element each time the binary representation of its index has a pair of adjacent ones, + // or equivalently, if the number of adjacent ones is odd. + // This comes from a peculiarity in how we compute the CFFT: + // The butterfly zips the first half of the domain with the second half reversed, because that maps each point + // to its involution. After each layer, the second half is still in reverse order, so we should use the twiddles + // in reverse order as well, but we ignore that and use the same twiddles for both halves. + // Using t(g^(N-k)) instead of t(g^k) just adds a negative sign. It turns out the number of negations is the number + // of adjacent ones in the index. + for (i, val) in basis.iter_mut().enumerate() { + let num_adjacent_ones = (i & (i >> 1)).count_ones(); + if num_adjacent_ones % 2 == 1 { + *val = -*val; + } + } + + basis +} + +#[cfg(test)] +pub(crate) fn eval_circle_polys( + coeffs: &RowMajorMatrix, + point: Complex, +) -> Vec { + use itertools::izip; + use p3_matrix::Matrix; + + let log_n = log2_strict_usize(coeffs.height()); + let mut accs = vec![F::zero(); coeffs.width()]; + for (row, basis) in coeffs.rows().zip(circle_basis(point, log_n)) { + for (acc, coeff) in izip!(&mut accs, row) { + *acc += *coeff * basis; + } + } + accs +} + +/// Circle STARKs, Section 3, Lemma 1: (page 4 of the first revision PDF) +/// (x, y) = ((1-t^2)/(1+t^2), 2t/(1+t^2)) +/// Returns None if t^2 = -1 (corresponding to the point at infinity). +pub(crate) fn univariate_to_point(t: F) -> Option> { + let t2 = t.square(); + let inv_denom = (F::one() + t2).try_inverse()?; + Some(Complex::new( + (F::one() - t2) * inv_denom, + t.double() * inv_denom, + )) +} + +/// Circle STARKs, Section 3, Lemma 1: (page 4 of the first revision PDF) +/// t = y / (x + 1) +/// If F has i, this should return that instead, but we don't have access.. +pub(crate) fn point_to_univariate(p: Complex) -> Option { + p.imag().try_div(p.real() + F::one()) +} + +/// Formula for the group operation in univariate coordinates +/// Circle STARKs, Section 3.1, Remark 4: (page 5 of the first revision PDF) +/// same as above, this *could* handle point at infinity if we had Field::try_sqrt +#[allow(unused)] +pub(crate) fn rotate_univariate>(t1: EF, t2: F) -> Option { + (t1 + t2).try_div(EF::one() - t1 * t2) +} + +/// Evaluate the vanishing polynomial for the standard position coset of size 2^log_n +/// at the point `p` (which has x coordinate `p_x`). +/// Circle STARKs, Section 3.3, Equation 8 (page 10 of the first revision PDF) +pub(crate) fn v_n(mut p_x: F, log_n: usize) -> F { + for _ in 0..(log_n - 1) { + p_x = p_x.square().double() - F::one(); + } + p_x +} + +/// Evaluate the formal derivative of `v_n` at the point `p` (which has x coordinate `p_x`). +/// Circle STARKs, Section 5.1, Remark 15 (page 21 of the first revision PDF) +fn v_n_prime(p_x: F, log_n: usize) -> F { + F::two().exp_u64((2 * (log_n - 1)) as u64) * (1..log_n).map(|i| v_n(p_x, i)).product() +} + +/// Simple zero at (1,0), simple pole at (-1,0) +/// Circle STARKs, Section 5.1, Lemma 11 (page 21 of the first revision PDF) +/// panics if called with x = -1 +pub(crate) fn v_0(p: Complex) -> F { + p.imag() / (p.real() + F::one()) +} + +/// The concrete value of the selector s_P = v_n / (v_0 . T_p⁻¹) at P, used for normalization to 1. +/// Circle STARKs, Section 5.1, Remark 16 (page 22 of the first revision PDF) +pub(crate) fn s_p_at_p(p_x: F, p_y: F, log_n: usize) -> F { + -F::two() * v_n_prime(p_x, log_n) * p_y +} + +#[cfg(test)] +mod tests { + use p3_field::AbstractField; + use p3_mersenne_31::Mersenne31; + + use super::*; + + type F = Mersenne31; + type C = Complex; + + #[test] + fn test_uni_to_point() { + // 0 -> (1, 0) + assert_eq!(univariate_to_point(F::zero()), Some(C::new_real(F::one()))); + // 1 -> (0, 1) + assert_eq!(univariate_to_point(F::one()), Some(C::new_imag(F::one()))); + // -1 -> (0, -1) + assert_eq!( + univariate_to_point(F::neg_one()), + Some(C::new_imag(F::neg_one())) + ); + } + + #[test] + fn test_s_p_at_p() { + // from sage + assert_eq!( + s_p_at_p( + // random point on the circle + F::from_canonical_u32(383393203), + F::from_canonical_u32(415518596), + 3 + ), + F::from_canonical_u32(1612953309) + ); + } +} diff --git a/commit/Cargo.toml b/commit/Cargo.toml index 873c287e..061b694f 100644 --- a/commit/Cargo.toml +++ b/commit/Cargo.toml @@ -4,8 +4,22 @@ version = "0.1.0" edition = "2021" license = "MIT OR Apache-2.0" +[features] +test-utils = ["dep:p3-dft"] + [dependencies] p3-challenger = { path = "../challenger" } p3-field = { path = "../field" } p3-matrix = { path = "../matrix" } +p3-util = { path = "../util" } + +itertools = "0.12.0" serde = { version = "1.0", default-features = false } + +# for testing +p3-dft = { path = "../dft", optional = true } + +[dev-dependencies] +p3-baby-bear = { path = "../baby-bear" } +p3-dft = { path = "../dft" } +rand = "0.8.5" diff --git a/commit/src/adapters/mod.rs b/commit/src/adapters/mod.rs index 5c77ff7a..9809fe60 100644 --- a/commit/src/adapters/mod.rs +++ b/commit/src/adapters/mod.rs @@ -1,9 +1,5 @@ //! Adapters for converting between different types of commitment schemes. mod extension_mmcs; -mod multi_from_uni_pcs; -mod uni_from_multi_pcs; pub use extension_mmcs::*; -pub use multi_from_uni_pcs::*; -pub use uni_from_multi_pcs::*; diff --git a/commit/src/adapters/multi_from_uni_pcs.rs b/commit/src/adapters/multi_from_uni_pcs.rs deleted file mode 100644 index 69734946..00000000 --- a/commit/src/adapters/multi_from_uni_pcs.rs +++ /dev/null @@ -1,21 +0,0 @@ -use core::marker::PhantomData; - -use p3_challenger::FieldChallenger; -use p3_field::{ExtensionField, Field}; -use p3_matrix::MatrixRows; - -use crate::pcs::UnivariatePcs; - -pub struct MultiFromUniPcs -where - Val: Field, - EF: ExtensionField, - In: MatrixRows, - U: UnivariatePcs, - Challenger: FieldChallenger, -{ - _uni: U, - _phantom: PhantomData<(Val, EF, In, Challenger)>, -} - -// TODO: Impl PCS, MultivariatePcs diff --git a/commit/src/adapters/uni_from_multi_pcs.rs b/commit/src/adapters/uni_from_multi_pcs.rs deleted file mode 100644 index baa7033f..00000000 --- a/commit/src/adapters/uni_from_multi_pcs.rs +++ /dev/null @@ -1,21 +0,0 @@ -use core::marker::PhantomData; - -use p3_challenger::FieldChallenger; -use p3_field::{ExtensionField, Field}; -use p3_matrix::MatrixRows; - -use crate::pcs::MultivariatePcs; - -pub struct UniFromMultiPcs -where - Val: Field, - EF: ExtensionField, - In: MatrixRows, - M: MultivariatePcs, - Challenger: FieldChallenger, -{ - _multi: M, - _phantom: PhantomData<(Val, EF, In, Challenger)>, -} - -// impl> UnivariatePcs for UniFromMultiPcs {} diff --git a/commit/src/domain.rs b/commit/src/domain.rs new file mode 100644 index 00000000..8ce5cb2e --- /dev/null +++ b/commit/src/domain.rs @@ -0,0 +1,163 @@ +use alloc::vec::Vec; + +use itertools::Itertools; +use p3_field::{ + batch_multiplicative_inverse, cyclic_subgroup_coset_known_order, ExtensionField, Field, + TwoAdicField, +}; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::MatrixRows; +use p3_util::{log2_ceil_usize, log2_strict_usize}; + +pub struct LagrangeSelectors { + pub is_first_row: T, + pub is_last_row: T, + pub is_transition: T, + pub inv_zeroifier: T, +} + +pub trait PolynomialSpace: Copy { + type Val: Field; + + fn size(&self) -> usize; + + fn first_point(&self) -> Self::Val; + + // This is only defined for cosets. + fn next_point>(&self, x: Ext) -> Option; + + // There are many choices for this, but we must pick a canonical one + // for both prover/verifier determinism and LDE caching. + fn create_disjoint_domain(&self, min_size: usize) -> Self; + + // Split this domain into `num_chunks` even chunks. + fn split_domains(&self, num_chunks: usize) -> Vec; + // Split the evals into chunks of evals, corresponding to each domain + // from `split_domains`. + fn split_evals( + &self, + num_chunks: usize, + evals: RowMajorMatrix, + ) -> Vec>; + + fn zp_at_point>(&self, point: Ext) -> Ext; + + // Unnormalized + fn selectors_at_point>( + &self, + point: Ext, + ) -> LagrangeSelectors; + + // Unnormalized + fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors>; +} + +#[derive(Copy, Clone)] +pub struct TwoAdicMultiplicativeCoset { + pub log_n: usize, + pub shift: Val, +} + +impl TwoAdicMultiplicativeCoset { + fn gen(&self) -> Val { + Val::two_adic_generator(self.log_n) + } +} + +impl PolynomialSpace for TwoAdicMultiplicativeCoset { + type Val = Val; + + fn size(&self) -> usize { + 1 << self.log_n + } + + fn first_point(&self) -> Self::Val { + self.shift + } + fn next_point>(&self, x: Ext) -> Option { + Some(x * self.gen()) + } + + fn create_disjoint_domain(&self, min_size: usize) -> Self { + Self { + log_n: log2_ceil_usize(min_size), + shift: self.shift * Val::generator(), + } + } + fn zp_at_point>(&self, point: Ext) -> Ext { + (point * self.shift.inverse()).exp_power_of_2(self.log_n) - Ext::one() + } + + fn split_domains(&self, num_chunks: usize) -> Vec { + let log_chunks = log2_strict_usize(num_chunks); + (0..num_chunks) + .map(|i| Self { + log_n: self.log_n - log_chunks, + shift: self.shift * self.gen().exp_u64(i as u64), + }) + .collect() + } + fn split_evals( + &self, + num_chunks: usize, + evals: RowMajorMatrix, + ) -> Vec> { + let view = evals.as_view(); + // todo less copy + (0..num_chunks) + .map(|i| view.vertically_strided(num_chunks, i).to_row_major_matrix()) + .collect() + } + + fn selectors_at_point>(&self, point: Ext) -> LagrangeSelectors { + let unshifted_point = point * self.shift.inverse(); + let z_h = unshifted_point.exp_power_of_2(self.log_n) - Ext::one(); + LagrangeSelectors { + is_first_row: z_h / (unshifted_point - Ext::one()), + is_last_row: z_h / (unshifted_point - self.gen().inverse()), + is_transition: unshifted_point - self.gen().inverse(), + inv_zeroifier: z_h.inverse(), + } + } + + fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors> { + assert_eq!(self.shift, Val::one()); + assert!(coset.log_n >= self.log_n); + let rate_bits = coset.log_n - self.log_n; + + let s_pow_n = coset.shift.exp_power_of_2(self.log_n); + // evals of Z_H(X) = X^n - 1 + let evals = Val::two_adic_generator(rate_bits) + .powers() + .take(1 << rate_bits) + .map(|x| s_pow_n * x - Val::one()) + .collect_vec(); + + let xs = cyclic_subgroup_coset_known_order(coset.gen(), coset.shift, 1 << coset.log_n) + .collect_vec(); + + let single_point_selector = |i: u64| { + let denoms = xs.iter().map(|&x| x - self.gen().exp_u64(i)).collect_vec(); + let invs = batch_multiplicative_inverse(&denoms); + evals + .iter() + .cycle() + .zip(invs) + .map(|(&z_h, inv)| z_h * inv) + .collect_vec() + }; + + let subgroup_last = self.gen().inverse(); + + LagrangeSelectors { + is_first_row: single_point_selector(0), + is_last_row: single_point_selector((1 << self.log_n) - 1), + is_transition: xs.into_iter().map(|x| x - subgroup_last).collect(), + inv_zeroifier: batch_multiplicative_inverse(&evals) + .into_iter() + .cycle() + .take(1 << coset.log_n) + .collect(), + } + } +} diff --git a/commit/src/lib.rs b/commit/src/lib.rs index a164f645..2a2f6b97 100644 --- a/commit/src/lib.rs +++ b/commit/src/lib.rs @@ -5,9 +5,14 @@ extern crate alloc; mod adapters; +mod domain; mod mmcs; mod pcs; +#[cfg(any(test, feature = "test-utils"))] +pub mod testing; + pub use adapters::*; +pub use domain::*; pub use mmcs::*; pub use pcs::*; diff --git a/commit/src/pcs.rs b/commit/src/pcs.rs index 5bdb6df0..9f6e9354 100644 --- a/commit/src/pcs.rs +++ b/commit/src/pcs.rs @@ -1,22 +1,25 @@ //! Traits for polynomial commitment schemes. -use alloc::vec; use alloc::vec::Vec; use core::fmt::Debug; -use p3_challenger::FieldChallenger; -use p3_field::{ExtensionField, Field}; -use p3_matrix::{Dimensions, MatrixGet, MatrixRows}; +use p3_field::ExtensionField; +use p3_matrix::dense::RowMajorMatrix; use serde::de::DeserializeOwned; use serde::Serialize; -/// A (not necessarily hiding) polynomial commitment scheme, for committing to (batches of) -/// polynomials defined over the field `F`. -/// -/// This high-level trait is agnostic with respect to the structure of a point; see `UnivariatePCS` -/// and `MultivariatePcs` for more specific subtraits. +use crate::PolynomialSpace; + +pub type Val = ::Val; + +/// A (not necessarily hiding) polynomial commitment scheme, for committing to (batches of) polynomials // TODO: Should we have a super-trait for weakly-binding PCSs, like FRI outside unique decoding radius? -pub trait Pcs> { +pub trait Pcs +where + Challenge: ExtensionField>, +{ + type Domain: PolynomialSpace; + /// The commitment that's sent to the verifier. type Commitment: Clone + Serialize + DeserializeOwned; @@ -28,96 +31,61 @@ pub trait Pcs> { type Error: Debug; - fn commit_batches(&self, polynomials: Vec) -> (Self::Commitment, Self::ProverData); - - fn commit_batch(&self, polynomials: In) -> (Self::Commitment, Self::ProverData) { - self.commit_batches(vec![polynomials]) - } -} - -pub type OpenedValues = Vec>; -pub type OpenedValuesForRound = Vec>; -pub type OpenedValuesForMatrix = Vec>; -pub type OpenedValuesForPoint = Vec; - -pub trait UnivariatePcs: Pcs -where - Val: Field, - EF: ExtensionField, - In: MatrixRows, - Challenger: FieldChallenger, -{ - fn open_multi_batches( - &self, - prover_data_and_points: &[(&Self::ProverData, &[Vec])], - challenger: &mut Challenger, - ) -> (OpenedValues, Self::Proof); - - fn verify_multi_batches( - &self, - commits_and_points: &[(Self::Commitment, &[Vec])], - dims: &[Vec], - values: OpenedValues, - proof: &Self::Proof, - challenger: &mut Challenger, - ) -> Result<(), Self::Error>; -} - -/// A `UnivariatePcs` where the commitment process involves computing a low-degree extension (LDE) -/// of each polynomial. These LDEs can be reused in other prover work. -pub trait UnivariatePcsWithLde: - UnivariatePcs -where - Val: Field, - EF: ExtensionField, - In: MatrixRows, - Challenger: FieldChallenger, -{ - type Lde<'a>: MatrixRows + MatrixGet + Sync - where - Self: 'a; - - fn coset_shift(&self) -> Val; - - fn log_blowup(&self) -> usize; - - fn get_ldes<'a, 'b>(&'a self, prover_data: &'b Self::ProverData) -> Vec> - where - 'a: 'b; + /// This should return a coset domain (s.t. Domain::next_point returns Some) + fn natural_domain_for_degree(&self, degree: usize) -> Self::Domain; - // Commit to polys that are already defined over a coset. - fn commit_shifted_batches( + #[allow(clippy::type_complexity)] + fn commit( &self, - polynomials: Vec, - coset_shift: &[Val], + evaluations: Vec<(Self::Domain, RowMajorMatrix>)>, ) -> (Self::Commitment, Self::ProverData); - fn commit_shifted_batch( + fn get_evaluations_on_domain( &self, - polynomials: In, - coset_shift: Val, - ) -> (Self::Commitment, Self::ProverData) { - self.commit_shifted_batches(vec![polynomials], &[coset_shift]) - } -} + prover_data: &Self::ProverData, + idx: usize, + domain: Self::Domain, + ) -> RowMajorMatrix>; -pub trait MultivariatePcs: Pcs -where - Val: Field, - EF: ExtensionField, - In: MatrixRows, - Challenger: FieldChallenger, -{ - fn open_multi_batches( + fn open( &self, - prover_data_and_points: &[(&Self::ProverData, &[Vec])], + // For each round, + rounds: Vec<( + &Self::ProverData, + // for each matrix, + Vec< + // points to open + Vec, + >, + )>, challenger: &mut Challenger, - ) -> (OpenedValues, Self::Proof); + ) -> (OpenedValues, Self::Proof); - fn verify_multi_batches( + #[allow(clippy::type_complexity)] + fn verify( &self, - commits_and_points: &[(Self::Commitment, &[Vec])], - values: OpenedValues, + // For each round: + rounds: Vec<( + Self::Commitment, + // for each matrix: + Vec<( + // its domain, + Self::Domain, + // for each point: + Vec<( + // the point, + Challenge, + // values at the point + Vec, + )>, + )>, + )>, proof: &Self::Proof, + challenger: &mut Challenger, ) -> Result<(), Self::Error>; } + +pub type OpenedValues = Vec>; +pub type OpenedValuesForRound = Vec>; +pub type OpenedValuesForMatrix = Vec>; +pub type OpenedValuesForPoint = Vec; diff --git a/commit/src/testing.rs b/commit/src/testing.rs new file mode 100644 index 00000000..2d31a8f3 --- /dev/null +++ b/commit/src/testing.rs @@ -0,0 +1,173 @@ +use alloc::vec; +use alloc::vec::Vec; +use core::marker::PhantomData; + +use p3_challenger::CanSample; +use p3_dft::TwoAdicSubgroupDft; +use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{Matrix, MatrixRowSlices, MatrixRows}; +use p3_util::log2_strict_usize; +use serde::{Deserialize, Serialize}; + +use crate::{OpenedValues, Pcs, PolynomialSpace, TwoAdicMultiplicativeCoset}; + +/// A trivial PCS: its commitment is simply the coefficients of each poly. +pub struct TrivialPcs> { + pub dft: Dft, + // degree bound + pub log_n: usize, + pub _phantom: PhantomData, +} + +pub fn eval_coeffs_at_pt>( + coeffs: &RowMajorMatrix, + x: EF, +) -> Vec { + let mut acc = vec![EF::zero(); coeffs.width()]; + for r in (0..coeffs.height()).rev() { + let row = coeffs.row_slice(r); + for (acc_c, row_c) in acc.iter_mut().zip(row) { + *acc_c *= x; + *acc_c += *row_c; + } + } + acc +} + +impl Pcs for TrivialPcs +where + Val: TwoAdicField, + Challenge: ExtensionField, + Challenger: CanSample, + + Dft: TwoAdicSubgroupDft, + + Vec>: Serialize + for<'de> Deserialize<'de>, +{ + type Domain = TwoAdicMultiplicativeCoset; + type Commitment = Vec>; + type ProverData = Vec>; + type Proof = (); + type Error = (); + + fn natural_domain_for_degree(&self, degree: usize) -> Self::Domain { + TwoAdicMultiplicativeCoset { + log_n: log2_strict_usize(degree), + shift: Val::one(), + } + } + + fn commit( + &self, + evaluations: Vec<(Self::Domain, RowMajorMatrix)>, + ) -> (Self::Commitment, Self::ProverData) { + let coeffs: Vec<_> = evaluations + .into_iter() + .map(|(domain, evals)| { + let log_domain_size = log2_strict_usize(domain.size()); + // for now, only commit on larger domain than natural + assert!(log_domain_size >= self.log_n); + assert_eq!(domain.size(), evals.height()); + // coset_idft_batch + let mut coeffs = self.dft.idft_batch(evals); + coeffs + .rows_mut() + .zip(domain.shift.inverse().powers()) + .for_each(|(row, weight)| { + row.iter_mut().for_each(|coeff| { + *coeff *= weight; + }) + }); + coeffs + }) + .collect(); + ( + coeffs.clone().into_iter().map(|m| m.values).collect(), + coeffs, + ) + } + + fn get_evaluations_on_domain( + &self, + prover_data: &Self::ProverData, + idx: usize, + domain: Self::Domain, + ) -> RowMajorMatrix { + let mut coeffs = prover_data[idx].clone(); + assert!(domain.log_n >= self.log_n); + coeffs.values.resize( + coeffs.values.len() << (domain.log_n - self.log_n), + Val::zero(), + ); + self.dft + .coset_dft_batch(coeffs, domain.shift) + .to_row_major_matrix() + } + + fn open( + &self, + // For each round, + rounds: Vec<( + &Self::ProverData, + // for each matrix, + Vec< + // points to open + Vec, + >, + )>, + _challenger: &mut Challenger, + ) -> (OpenedValues, Self::Proof) { + ( + rounds + .into_iter() + .map(|(coeffs_for_round, points_for_round)| { + coeffs_for_round + .iter() + .zip(points_for_round) + .map(|(coeffs_for_mat, points_for_mat)| { + points_for_mat + .into_iter() + .map(|pt| eval_coeffs_at_pt(coeffs_for_mat, pt)) + .collect() + }) + .collect() + }) + .collect(), + (), + ) + } + + fn verify( + &self, + // For each round: + rounds: Vec<( + Self::Commitment, + // for each matrix: + Vec<( + // its domain, + Self::Domain, + // for each point: + Vec<( + Challenge, + // values at this point + Vec, + )>, + )>, + )>, + _proof: &Self::Proof, + _challenger: &mut Challenger, + ) -> Result<(), Self::Error> { + for (comm, round_opening) in rounds { + for (coeff_vec, (domain, points_and_values)) in comm.into_iter().zip(round_opening) { + let width = coeff_vec.len() / domain.size(); + assert_eq!(width * domain.size(), coeff_vec.len()); + let coeffs = RowMajorMatrix::new(coeff_vec, width); + for (pt, values) in points_and_values { + assert_eq!(eval_coeffs_at_pt(&coeffs, pt), values); + } + } + } + Ok(()) + } +} diff --git a/dft/Cargo.toml b/dft/Cargo.toml index 947d3f2a..284a930b 100644 --- a/dft/Cargo.toml +++ b/dft/Cargo.toml @@ -9,6 +9,7 @@ p3-field = { path = "../field" } p3-matrix = { path = "../matrix" } p3-maybe-rayon = { path = "../maybe-rayon" } p3-util = { path = "../util" } +tracing = "0.1.37" [dev-dependencies] p3-baby-bear = { path = "../baby-bear" } diff --git a/dft/src/lib.rs b/dft/src/lib.rs index 28517e29..4720eab8 100644 --- a/dft/src/lib.rs +++ b/dft/src/lib.rs @@ -19,3 +19,4 @@ pub use radix_2_bowers::*; pub use radix_2_dit::*; pub use radix_2_dit_parallel::*; pub use traits::*; +pub use util::*; diff --git a/dft/src/radix_2_dit_parallel.rs b/dft/src/radix_2_dit_parallel.rs index 9a2edd19..dcc6cb29 100644 --- a/dft/src/radix_2_dit_parallel.rs +++ b/dft/src/radix_2_dit_parallel.rs @@ -7,6 +7,7 @@ use p3_matrix::util::reverse_matrix_index_bits; use p3_matrix::Matrix; use p3_maybe_rayon::prelude::*; use p3_util::{log2_strict_usize, reverse_bits, reverse_slice_index_bits}; +use tracing::instrument; use crate::butterflies::dit_butterfly; use crate::util::bit_reversed_zero_pad; @@ -46,6 +47,7 @@ impl TwoAdicSubgroupDft for Radix2DitParallel { mat.bit_reverse_rows() } + #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))] fn coset_lde_batch( &self, mut mat: RowMajorMatrix, diff --git a/dft/src/util.rs b/dft/src/util.rs index 6790cdbd..2babd88e 100644 --- a/dft/src/util.rs +++ b/dft/src/util.rs @@ -1,23 +1,24 @@ use alloc::vec; -use p3_field::Field; +use p3_field::{Field, PackedValue}; use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::Matrix; +use p3_matrix::MatrixRowSlicesMut; /// Divide each coefficient of the given matrix by its height. -pub(crate) fn divide_by_height(mat: &mut RowMajorMatrix) { +pub fn divide_by_height(mat: &mut impl MatrixRowSlicesMut) { let h = mat.height(); let h_inv = F::from_canonical_usize(h).inverse(); - let (prefix, shorts, suffix) = unsafe { mat.values.align_to_mut::() }; - prefix.iter_mut().for_each(|x| *x *= h_inv); - shorts.iter_mut().for_each(|x| *x *= h_inv); - suffix.iter_mut().for_each(|x| *x *= h_inv); + for r in 0..h { + let (packed, suffix) = F::Packing::pack_slice_with_suffix_mut(mat.row_slice_mut(r)); + packed.iter_mut().for_each(|x| *x *= h_inv); + suffix.iter_mut().for_each(|x| *x *= h_inv); + } } /// Append zeros to the "end" of the given matrix, except that the matrix is in bit-reversed order, /// so in actuality we're interleaving zero rows. #[inline] -pub(crate) fn bit_reversed_zero_pad(mat: &mut RowMajorMatrix, added_bits: usize) { +pub fn bit_reversed_zero_pad(mat: &mut RowMajorMatrix, added_bits: usize) { if added_bits == 0 { return; } diff --git a/field-testing/src/lib.rs b/field-testing/src/lib.rs index 107c0db4..0b6959d4 100644 --- a/field-testing/src/lib.rs +++ b/field-testing/src/lib.rs @@ -26,6 +26,7 @@ where assert_eq!(x + (-x), F::zero()); assert_eq!(-x, F::zero() - x); assert_eq!(x + x, x * F::two()); + assert_eq!(x, x.halve() * F::two()); assert_eq!(x * (-x), -x.square()); assert_eq!(x + y, y + x); assert_eq!(x * y, y * x); diff --git a/field/src/extension/binomial_extension.rs b/field/src/extension/binomial_extension.rs index b773ad82..d44cfb30 100644 --- a/field/src/extension/binomial_extension.rs +++ b/field/src/extension/binomial_extension.rs @@ -218,6 +218,12 @@ impl, const D: usize> Field for BinomialExtensionFiel _ => Some(self.frobenius_inv()), } } + + fn halve(&self) -> Self { + Self { + value: self.value.map(|x| x.halve()), + } + } } impl Display for BinomialExtensionField diff --git a/field/src/extension/complex.rs b/field/src/extension/complex.rs index 5da75e8c..0143b9b2 100644 --- a/field/src/extension/complex.rs +++ b/field/src/extension/complex.rs @@ -1,5 +1,5 @@ use super::{BinomialExtensionField, BinomiallyExtendable, HasTwoAdicBionmialExtension}; -use crate::{AbstractField, Field}; +use crate::{AbstractExtensionField, AbstractField, Field}; pub type Complex = BinomialExtensionField; @@ -56,6 +56,14 @@ impl Complex { pub fn to_array(&self) -> [AF; 2] { self.value.clone() } + // Sometimes we want to rotate over an extension that's not necessarily ComplexExtendable, + // but still on the circle. + pub fn rotate>(&self, rhs: Complex) -> Complex { + Complex::::new( + rhs.real() * self.real() - rhs.imag() * self.imag(), + rhs.imag() * self.real() + rhs.real() * self.imag(), + ) + } } /// The complex extension of this field has a binomial extension. diff --git a/field/src/field.rs b/field/src/field.rs index d547b55a..d5fee19e 100644 --- a/field/src/field.rs +++ b/field/src/field.rs @@ -145,6 +145,14 @@ pub trait AbstractField: fn dot_product(u: &[Self; N], v: &[Self; N]) -> Self { u.iter().zip(v).map(|(x, y)| x.clone() * y.clone()).sum() } + + fn try_div(self, rhs: Rhs) -> Option<>::Output> + where + Rhs: Field, + Self: Mul, + { + rhs.try_inverse().map(|inv| self * inv) + } } /// An element of a finite field. @@ -207,6 +215,17 @@ pub trait Field: fn inverse(&self) -> Self { self.try_inverse().expect("Tried to invert zero") } + + /// Computes input/2. + /// Should be overwritten by most field implementations to use bitshifts. + /// Will error if the field characteristic is 2. + #[must_use] + fn halve(&self) -> Self { + let half = Self::two() + .try_inverse() + .expect("Cannot divide by 2 in fields with characteristic 2"); + *self * half + } } pub trait PrimeField: Field + Ord {} @@ -224,14 +243,6 @@ pub trait PrimeField64: PrimeField { /// Return the representative of `value` that is less than `ORDER_U64`. fn as_canonical_u64(&self) -> u64; - - /// Return the value \sum_{i=0}^N u[i] * v[i]. - /// - /// NB: Assumes that sum(u) <= 2^32 to allow implementations to avoid - /// overflow handling. - /// - /// TODO: Mark unsafe because of the assumption? - fn linear_combination_u64(u: [u64; N], v: &[Self; N]) -> Self; } /// A prime field of order less than `2^32`. diff --git a/field/src/helpers.rs b/field/src/helpers.rs index 3167c506..aa52f286 100644 --- a/field/src/helpers.rs +++ b/field/src/helpers.rs @@ -58,7 +58,7 @@ where x.iter_mut().zip(y).for_each(|(x_i, y_i)| *x_i += y_i * s); } -/// Extend a field `AF` element `x` to an arry of length `D` +/// Extend a field `AF` element `x` to an array of length `D` /// by filling zeros. pub fn field_to_array(x: AF) -> [AF; D] { let mut arr = array::from_fn(|_| AF::zero()); @@ -99,3 +99,31 @@ pub fn eval_poly(poly: &[AF], x: AF) -> AF { } acc } + +/// Given an element x from a 32 bit field F_P compute x/2. +#[inline] +pub fn halve_u32(input: u32) -> u32 { + let shift = (P + 1) >> 1; + let shr = input >> 1; + let lo_bit = input & 1; + let shr_corr = shr + shift; + if lo_bit == 0 { + shr + } else { + shr_corr + } +} + +/// Given an element x from a 64 bit field F_P compute x/2. +#[inline] +pub fn halve_u64(input: u64) -> u64 { + let shift = (P + 1) >> 1; + let shr = input >> 1; + let lo_bit = input & 1; + let shr_corr = shr + shift; + if lo_bit == 0 { + shr + } else { + shr_corr + } +} diff --git a/field/src/packed.rs b/field/src/packed.rs index fdda9065..918b8e91 100644 --- a/field/src/packed.rs +++ b/field/src/packed.rs @@ -63,6 +63,11 @@ pub unsafe trait PackedValue: 'static + Copy + From + Send + Sync { unsafe { slice::from_raw_parts_mut(buf_ptr, n) } } + fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) { + let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH); + (Self::pack_slice_mut(packed), suffix) + } + fn unpack_slice(buf: &[Self]) -> &[Self::Value] { assert!(mem::align_of::() >= mem::align_of::()); let buf_ptr = buf.as_ptr().cast::(); diff --git a/fri/src/lib.rs b/fri/src/lib.rs index 14b7dfcb..de97342d 100644 --- a/fri/src/lib.rs +++ b/fri/src/lib.rs @@ -8,7 +8,7 @@ mod config; mod fold_even_odd; mod proof; pub mod prover; -pub mod two_adic_pcs; +mod two_adic_pcs; pub mod verifier; pub use config::*; diff --git a/fri/src/two_adic_pcs.rs b/fri/src/two_adic_pcs.rs index 1845bef0..9f99ca21 100644 --- a/fri/src/two_adic_pcs.rs +++ b/fri/src/two_adic_pcs.rs @@ -1,11 +1,11 @@ use alloc::vec; use alloc::vec::Vec; -use core::fmt::{Debug, Formatter}; +use core::fmt::Debug; use core::marker::PhantomData; use itertools::{izip, Itertools}; -use p3_challenger::{CanObserve, CanSample, FieldChallenger, GrindingChallenger}; -use p3_commit::{DirectMmcs, Mmcs, OpenedValues, Pcs, UnivariatePcs, UnivariatePcsWithLde}; +use p3_challenger::{CanObserve, CanSample, GrindingChallenger}; +use p3_commit::{DirectMmcs, Mmcs, OpenedValues, Pcs, PolynomialSpace, TwoAdicMultiplicativeCoset}; use p3_dft::TwoAdicSubgroupDft; use p3_field::{ batch_multiplicative_inverse, cyclic_subgroup_coset_known_order, AbstractField, ExtensionField, @@ -13,7 +13,7 @@ use p3_field::{ }; use p3_interpolation::interpolate_coset; use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView}; -use p3_matrix::dense::RowMajorMatrixView; +use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; use p3_matrix::{Dimensions, Matrix, MatrixRows}; use p3_maybe_rayon::prelude::*; use p3_util::linear_map::LinearMap; @@ -24,167 +24,135 @@ use tracing::{info_span, instrument}; use crate::verifier::{self, FriError}; use crate::{prover, FriConfig, FriProof}; -/// We group all of our type bounds into this trait to reduce duplication across signatures. -pub trait TwoAdicFriPcsGenericConfig: Default { - type Val: TwoAdicField; - type Challenge: TwoAdicField + ExtensionField; - type Challenger: FieldChallenger - + GrindingChallenger - + CanObserve<>::Commitment> - + CanSample; - type Dft: TwoAdicSubgroupDft; - type InputMmcs: 'static - + for<'a> DirectMmcs = RowMajorMatrixView<'a, Self::Val>>; - type FriMmcs: DirectMmcs; +pub struct TwoAdicFriPcs { + // degree bound + log_n: usize, + dft: Dft, + mmcs: InputMmcs, + fri: FriConfig, + _phantom: PhantomData, } -pub struct TwoAdicFriPcsConfig( - PhantomData<(Val, Challenge, Challenger, Dft, InputMmcs, FriMmcs)>, -); -impl Default - for TwoAdicFriPcsConfig -{ - fn default() -> Self { - Self(PhantomData) - } -} - -impl TwoAdicFriPcsGenericConfig - for TwoAdicFriPcsConfig -where - Val: TwoAdicField, - Challenge: TwoAdicField + ExtensionField, - Challenger: FieldChallenger - + GrindingChallenger - + CanObserve<>::Commitment> - + CanSample, - Dft: TwoAdicSubgroupDft, - InputMmcs: 'static + for<'a> DirectMmcs = RowMajorMatrixView<'a, Val>>, - FriMmcs: DirectMmcs, -{ - type Val = Val; - type Challenge = Challenge; - type Challenger = Challenger; - type Dft = Dft; - type InputMmcs = InputMmcs; - type FriMmcs = FriMmcs; -} - -pub struct TwoAdicFriPcs { - fri: FriConfig, - dft: C::Dft, - mmcs: C::InputMmcs, -} - -impl TwoAdicFriPcs { - pub fn new(fri: FriConfig, dft: C::Dft, mmcs: C::InputMmcs) -> Self { - Self { fri, dft, mmcs } +impl TwoAdicFriPcs { + pub fn new(log_n: usize, dft: Dft, mmcs: InputMmcs, fri: FriConfig) -> Self { + Self { + log_n, + dft, + mmcs, + fri, + _phantom: PhantomData, + } } } -pub enum VerificationError { - InputMmcsError(>::Error), - FriError(FriError<>::Error>), -} - -impl Debug for VerificationError { - fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - match self { - VerificationError::InputMmcsError(e) => { - f.debug_tuple("InputMmcsError").field(e).finish() - } - VerificationError::FriError(e) => f.debug_tuple("FriError").field(e).finish(), - } - } +#[derive(Debug)] +pub enum VerificationError { + InputMmcsError(InputMmcsError), + FriError(FriError), } #[derive(Serialize, Deserialize)] #[serde(bound = "")] -pub struct TwoAdicFriPcsProof { - pub(crate) fri_proof: FriProof, +pub struct TwoAdicFriPcsProof< + Val: Field, + Challenge: Field, + InputMmcs: Mmcs, + FriMmcs: Mmcs, +> { + pub(crate) fri_proof: FriProof, /// For each query, for each committed batch, query openings for that batch - pub(crate) query_openings: Vec>>, + pub(crate) query_openings: Vec>>, } #[derive(Serialize, Deserialize)] -pub struct BatchOpening { - pub(crate) opened_values: Vec>, - pub(crate) opening_proof: >::Proof, -} - -impl> Pcs for TwoAdicFriPcs { - type Commitment = >::Commitment; - type ProverData = >::ProverData; - type Proof = TwoAdicFriPcsProof; - type Error = VerificationError; - - fn commit_batches(&self, polynomials: Vec) -> (Self::Commitment, Self::ProverData) { - let ones = vec![C::Val::one(); polynomials.len()]; - self.commit_shifted_batches(polynomials, &ones) - } +#[serde(bound = "")] +pub struct BatchOpening> { + pub(crate) opened_values: Vec>, + pub(crate) opening_proof: >::Proof, } -impl> - UnivariatePcsWithLde for TwoAdicFriPcs +impl Pcs + for TwoAdicFriPcs +where + Val: TwoAdicField, + Dft: TwoAdicSubgroupDft, + InputMmcs: 'static + for<'a> DirectMmcs = RowMajorMatrixView<'a, Val>>, + FriMmcs: DirectMmcs, + Challenge: TwoAdicField + ExtensionField, + Challenger: + CanObserve + CanSample + GrindingChallenger, { - type Lde<'a> = BitReversedMatrixView<>::Mat<'a>> where Self: 'a; - - fn coset_shift(&self) -> C::Val { - C::Val::generator() - } - - fn log_blowup(&self) -> usize { - self.fri.log_blowup - } - - fn get_ldes<'a, 'b>(&'a self, prover_data: &'b Self::ProverData) -> Vec> - where - 'a: 'b, - { - // We committed to the bit-reversed LDE, so now we wrap it to return in natural order. - self.mmcs - .get_matrices(prover_data) - .into_iter() - .map(BitReversedMatrixView::new) - .collect() + type Domain = TwoAdicMultiplicativeCoset; + type Commitment = InputMmcs::Commitment; + type ProverData = InputMmcs::ProverData; + type Proof = TwoAdicFriPcsProof; + type Error = VerificationError; + + fn natural_domain_for_degree(&self, degree: usize) -> Self::Domain { + let log_n = log2_strict_usize(degree); + assert!(log_n <= self.log_n); + TwoAdicMultiplicativeCoset { + log_n, + shift: Val::one(), + } } - fn commit_shifted_batches( + fn commit( &self, - polynomials: Vec, - coset_shifts: &[C::Val], + evaluations: Vec<(Self::Domain, RowMajorMatrix)>, ) -> (Self::Commitment, Self::ProverData) { - let ldes = info_span!("compute all coset LDEs").in_scope(|| { - polynomials - .into_iter() - .zip_eq(coset_shifts) - .map(|(poly, coset_shift)| { - let shift = C::Val::generator() / *coset_shift; - let input = poly.to_row_major_matrix(); - // Commit to the bit-reversed LDE. - self.dft - .coset_lde_batch(input, self.fri.log_blowup, shift) - .bit_reverse_rows() - .to_row_major_matrix() - }) - .collect() - }); + let ldes: Vec> = evaluations + .into_iter() + .map(|(domain, evals)| { + assert_eq!(domain.size(), evals.height()); + let log_n = log2_strict_usize(domain.size()); + assert!(log_n <= self.log_n); + let shift = Val::generator() / domain.shift; + // Commit to the bit-reversed LDE. + self.dft + .coset_lde_batch(evals, self.fri.log_blowup, shift) + .bit_reverse_rows() + .to_row_major_matrix() + }) + .collect(); + self.mmcs.commit(ldes) } -} -impl> - UnivariatePcs for TwoAdicFriPcs -{ - #[instrument(name = "open_multi_batches", skip_all)] - fn open_multi_batches( + fn get_evaluations_on_domain( &self, - prover_data_and_points: &[(&Self::ProverData, &[Vec])], - challenger: &mut C::Challenger, - ) -> (OpenedValues, Self::Proof) { - // Batch combination challenge - let alpha = >::sample(challenger); + prover_data: &Self::ProverData, + idx: usize, + domain: Self::Domain, + ) -> RowMajorMatrix { + // todo: handle extrapolation for LDEs we don't have + assert_eq!(domain.shift, Val::generator()); + let lde = self.mmcs.get_matrices(prover_data)[idx]; + assert!(lde.height() >= domain.size()); + let extra_bits = log2_strict_usize(lde.height()) - log2_strict_usize(domain.size()); + // TODO get rid of these 2 copies + let strided = lde + .to_row_major_matrix() + .bit_reverse_rows() + .vertically_strided(1 << extra_bits, 0) + .to_row_major_matrix(); + assert_eq!(strided.height(), domain.size()); + strided + } + fn open( + &self, + // For each round, + rounds: Vec<( + &Self::ProverData, + // for each matrix, + Vec< + // points to open + Vec, + >, + )>, + challenger: &mut Challenger, + ) -> (OpenedValues, Self::Proof) { /* A quick rundown of the optimizations in this function: @@ -222,35 +190,37 @@ impl> */ - let mats_and_points = prover_data_and_points + // Batch combination challenge + let alpha: Challenge = challenger.sample(); + + let mats_and_points = rounds .iter() - .map(|(data, points)| (self.mmcs.get_matrices(data), *points)) + .map(|(data, points)| (self.mmcs.get_matrices(data), points)) .collect_vec(); let max_width = mats_and_points .iter() .flat_map(|(mats, _)| mats) - .map(|mat| mat.width()) + .map(|m| m.width()) .max() .unwrap(); - let alpha_reducer = PowersReducer::::new(alpha, max_width); + let alpha_reducer = PowersReducer::::new(alpha, max_width); // For each unique opening point z, we will find the largest degree bound // for that point, and precompute 1/(X - z) for the largest subgroup (in bitrev order). - let inv_denoms = compute_inverse_denominators(&mats_and_points, C::Val::generator()); + let inv_denoms = compute_inverse_denominators(&mats_and_points, Val::generator()); - let mut all_opened_values: OpenedValues = vec![]; + let mut all_opened_values: OpenedValues = vec![]; let mut reduced_openings: [_; 32] = core::array::from_fn(|_| None); let mut num_reduced = [0; 32]; - for (data, points) in prover_data_and_points { - let mats = self.mmcs.get_matrices(data); + for (mats, points) in mats_and_points { let opened_values_for_round = all_opened_values.pushed_mut(vec![]); - for (mat, points_for_mat) in izip!(mats, *points) { + for (mat, points_for_mat) in izip!(mats, points) { let log_height = log2_strict_usize(mat.height()); let reduced_opening_for_log_height = reduced_openings[log_height] - .get_or_insert_with(|| vec![C::Challenge::zero(); mat.height()]); + .get_or_insert_with(|| vec![Challenge::zero(); mat.height()]); debug_assert_eq!(reduced_opening_for_log_height.len(), mat.height()); let opened_values_for_mat = opened_values_for_round.pushed_mut(vec![]); @@ -265,7 +235,7 @@ impl> mat.split_rows(mat.height() >> self.fri.log_blowup); interpolate_coset( &BitReversedMatrixView::new(low_coset), - C::Val::generator(), + Val::generator(), point, ) }); @@ -299,9 +269,10 @@ impl> let query_openings = query_indices .into_iter() .map(|index| { - prover_data_and_points + rounds .iter() .map(|(data, _)| { + // needs to recombine decomposed openings.. or something... let (opened_values, opening_proof) = self.mmcs.open_batch(index, data); BatchOpening { opened_values, @@ -321,16 +292,29 @@ impl> ) } - fn verify_multi_batches( + fn verify( &self, - commits_and_points: &[(Self::Commitment, &[Vec])], - dims: &[Vec], - values: OpenedValues, + // For each round: + rounds: Vec<( + Self::Commitment, + // for each matrix: + Vec<( + // its domain, + Self::Domain, + // for each point: + Vec<( + // the point, + Challenge, + // values at the point + Vec, + )>, + )>, + )>, proof: &Self::Proof, - challenger: &mut C::Challenger, + challenger: &mut Challenger, ) -> Result<(), Self::Error> { // Batch combination challenge - let alpha = >::sample(challenger); + let alpha: Challenge = challenger.sample(); let fri_challenges = verifier::verify_shape_and_sample_challenges(&self.fri, &proof.fri_proof, challenger) @@ -338,41 +322,43 @@ impl> let log_max_height = proof.fri_proof.commit_phase_commits.len() + self.fri.log_blowup; - let reduced_openings: Vec<[C::Challenge; 32]> = proof + let reduced_openings: Vec<[Challenge; 32]> = proof .query_openings .iter() .zip(&fri_challenges.query_indices) .map(|(query_opening, &index)| { - let mut ro = [C::Challenge::zero(); 32]; - let mut alpha_pow = [C::Challenge::one(); 32]; - for (batch_opening, batch_dims, (batch_commit, batch_points), batch_at_z) in - izip!(query_opening, dims, commits_and_points, &values) - { + let mut ro = [Challenge::zero(); 32]; + let mut alpha_pow = [Challenge::one(); 32]; + for (batch_opening, (batch_commit, mats)) in izip!(query_opening, &rounds) { + let batch_dims: Vec = mats + .iter() + .map(|(domain, _)| Dimensions { + // todo: mmcs doesn't really need width + width: 0, + height: domain.size(), + }) + .collect_vec(); self.mmcs.verify_batch( batch_commit, - batch_dims, + &batch_dims, index, &batch_opening.opened_values, &batch_opening.opening_proof, )?; - for (mat_opening, mat_dims, mat_points, mat_at_z) in izip!( - &batch_opening.opened_values, - batch_dims, - *batch_points, - batch_at_z - ) { - let log_height = log2_strict_usize(mat_dims.height) + self.fri.log_blowup; + for (mat_opening, (mat_domain, mat_points_and_values)) in + izip!(&batch_opening.opened_values, mats) + { + let log_height = log2_strict_usize(mat_domain.size()) + self.fri.log_blowup; let bits_reduced = log_max_height - log_height; let rev_reduced_index = reverse_bits_len(index >> bits_reduced, log_height); - let x = C::Val::generator() - * C::Val::two_adic_generator(log_height) - .exp_u64(rev_reduced_index as u64); + let x = Val::generator() + * Val::two_adic_generator(log_height).exp_u64(rev_reduced_index as u64); - for (&z, ps_at_z) in izip!(mat_points, mat_at_z) { + for (z, ps_at_z) in mat_points_and_values { for (&p_at_x, &p_at_z) in izip!(mat_opening, ps_at_z) { - let quotient = (-p_at_z + p_at_x) / (-z + x); + let quotient = (-p_at_z + p_at_x) / (-*z + x); ro[log_height] += alpha_pow[log_height] * quotient; alpha_pow[log_height] *= alpha; } @@ -381,7 +367,7 @@ impl> } Ok(ro) }) - .collect::, >::Error>>() + .collect::, InputMmcs::Error>>() .map_err(VerificationError::InputMmcsError)?; verifier::verify_challenges( @@ -398,7 +384,7 @@ impl> #[instrument(skip_all)] fn compute_inverse_denominators, M: Matrix>( - mats_and_points: &[(Vec, &[Vec])], + mats_and_points: &[(Vec, &Vec>)], coset_shift: F, ) -> LinearMap> { let mut max_log_height_for_point: LinearMap = LinearMap::new(); diff --git a/fri/tests/fri.rs b/fri/tests/fri.rs index 45d78ed5..0e371a54 100644 --- a/fri/tests/fri.rs +++ b/fri/tests/fri.rs @@ -1,5 +1,5 @@ use itertools::Itertools; -use p3_baby_bear::BabyBear; +use p3_baby_bear::{BabyBear, DiffusionMatrixBabybear}; use p3_challenger::{CanSampleBits, DuplexChallenger, FieldChallenger}; use p3_commit::ExtensionMmcs; use p3_dft::{Radix2Dit, TwoAdicSubgroupDft}; @@ -10,7 +10,7 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::util::reverse_matrix_index_bits; use p3_matrix::{Matrix, MatrixRows}; use p3_merkle_tree::FieldMerkleTreeMmcs; -use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2}; +use p3_poseidon2::Poseidon2; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use p3_util::log2_strict_usize; use rand::{Rng, SeedableRng}; diff --git a/fri/tests/pcs.rs b/fri/tests/pcs.rs index e302312b..d5dc7f18 100644 --- a/fri/tests/pcs.rs +++ b/fri/tests/pcs.rs @@ -1,14 +1,13 @@ -use p3_baby_bear::BabyBear; +use p3_baby_bear::{BabyBear, DiffusionMatrixBabybear}; use p3_challenger::{CanObserve, DuplexChallenger, FieldChallenger}; -use p3_commit::{ExtensionMmcs, Pcs, UnivariatePcs}; +use p3_commit::{ExtensionMmcs, Pcs}; use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; use p3_field::Field; -use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use p3_fri::{FriConfig, TwoAdicFriPcs}; use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::Matrix; use p3_merkle_tree::FieldMerkleTreeMmcs; -use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2}; +use p3_poseidon2::Poseidon2; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use rand::thread_rng; @@ -49,48 +48,48 @@ fn make_test_fri_pcs(log_degrees: &[usize]) { proof_of_work_bits: 8, mmcs: challenge_mmcs, }; - type Pcs = - TwoAdicFriPcs>; - let pcs = Pcs::new(fri_config, dft, val_mmcs); + type MyPcs = TwoAdicFriPcs; + let max_log_n = log_degrees.iter().copied().max().unwrap(); + let pcs: MyPcs = MyPcs::new(max_log_n, dft, val_mmcs, fri_config); let mut challenger = Challenger::new(perm.clone()); - let polynomials = log_degrees + let domains_and_polys = log_degrees .iter() - .map(|d| RowMajorMatrix::rand(&mut rng, 1 << *d, 10)) + .map(|&d| { + ( + >::natural_domain_for_degree(&pcs, 1 << d), + RowMajorMatrix::::rand(&mut rng, 1 << d, 10), + ) + }) .collect::>(); - let (commit, data) = pcs.commit_batches(polynomials.clone()); + let (commit, data) = + >::commit(&pcs, domains_and_polys.clone()); challenger.observe(commit); let zeta = challenger.sample_ext_element::(); - let points = polynomials.iter().map(|_| vec![zeta]).collect::>(); + let points = domains_and_polys + .iter() + .map(|_| vec![zeta]) + .collect::>(); - let (opening, proof) = , _>>::open_multi_batches( - &pcs, - &[(&data, &points)], - &mut challenger, - ); + let (opening, proof) = pcs.open(vec![(&data, points)], &mut challenger); // verify the proof. let mut challenger = Challenger::new(perm); challenger.observe(commit); let _ = challenger.sample_ext_element::(); - let dims = polynomials + + let os = domains_and_polys .iter() - .map(|p| p.dimensions()) - .collect::>(); - , _>>::verify_multi_batches( - &pcs, - &[(commit, &points)], - &[dims], - opening, - &proof, - &mut challenger, - ) - .expect("verification error"); + .zip(&opening[0]) + .map(|((domain, _), mat_openings)| (*domain, vec![(zeta, mat_openings[0].clone())])) + .collect(); + pcs.verify(vec![(commit, os)], &proof, &mut challenger) + .unwrap() } #[test] diff --git a/goldilocks/Cargo.toml b/goldilocks/Cargo.toml index c270e0c7..3dbe3ca6 100644 --- a/goldilocks/Cargo.toml +++ b/goldilocks/Cargo.toml @@ -6,12 +6,19 @@ license = "MIT OR Apache-2.0" [dependencies] p3-field = { path = "../field" } +p3-dft = { path = "../dft" } +p3-mds = { path = "../mds" } +p3-symmetric = { path = "../symmetric" } p3-util = { path = "../util" } +p3-poseidon2 = { path = "../poseidon2" } rand = "0.8.5" serde = { version = "1.0", default-features = false, features = ["derive"] } [dev-dependencies] p3-field-testing = { path = "../field-testing" } +ark-ff = { version = "^0.4.0", default-features = false } +zkhash = { git = "https://github.com/HorizenLabs/poseidon2" } +rand = { version = "0.8.5", features = ["min_const_gen"] } criterion = "0.5.1" [[bench]] diff --git a/goldilocks/src/lib.rs b/goldilocks/src/lib.rs index e5f6557b..a11e8971 100644 --- a/goldilocks/src/lib.rs +++ b/goldilocks/src/lib.rs @@ -2,7 +2,11 @@ #![no_std] +extern crate alloc; + mod extension; +mod mds; +mod poseidon2; use core::fmt; use core::fmt::{Debug, Display, Formatter}; @@ -10,15 +14,20 @@ use core::hash::{Hash, Hasher}; use core::iter::{Product, Sum}; use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; +pub use mds::*; use p3_field::{ - exp_10540996611094048183, exp_u64_by_squaring, AbstractField, Field, Packable, PrimeField, - PrimeField64, TwoAdicField, + exp_10540996611094048183, exp_u64_by_squaring, halve_u64, AbstractField, Field, Packable, + PrimeField, PrimeField64, TwoAdicField, }; use p3_util::{assume, branch_hint}; +pub use poseidon2::DiffusionMatrixGoldilocks; use rand::distributions::{Distribution, Standard}; use rand::Rng; use serde::{Deserialize, Serialize}; +/// The Goldilocks prime +const P: u64 = 0xFFFF_FFFF_0000_0001; + /// The prime field known as Goldilocks, defined as `F_p` where `p = 2^64 - 2^32 + 1`. #[derive(Copy, Clone, Default, Serialize, Deserialize)] pub struct Goldilocks { @@ -210,12 +219,17 @@ impl Field for Goldilocks { // compute base^1111111111111111111111111111111011111111111111111111111111111111 Some(t63.square() * *self) } + + #[inline] + fn halve(&self) -> Self { + Goldilocks::new(halve_u64::

(self.value)) + } } impl PrimeField for Goldilocks {} impl PrimeField64 for Goldilocks { - const ORDER_U64: u64 = 0xFFFF_FFFF_0000_0001; + const ORDER_U64: u64 = P; #[inline] fn as_canonical_u64(&self) -> u64 { @@ -226,20 +240,6 @@ impl PrimeField64 for Goldilocks { } c } - - fn linear_combination_u64(u: [u64; N], v: &[Self; N]) -> Self { - // In order not to overflow a u128, we must have sum(u) <= 2^64. - // However, we enforce the stronger condition sum(u) <= 2^32 - // to ensure the semantics of this function are consistent - // between the implementations. - debug_assert!(u.into_iter().map(u128::from).sum::() <= (1u128 << 32)); - - let mut dot = u[0] as u128 * v[0].value as u128; - for i in 1..N { - dot += u[i] as u128 * v[i].value as u128; - } - reduce128(dot) - } } impl TwoAdicField for Goldilocks { @@ -356,9 +356,6 @@ impl Div for Goldilocks { } } -// HELPER FUNCTIONS -// ================================================================================================ - /// Squares the base N number of times and multiplies the result by the tail value. #[inline(always)] fn exp_acc(base: Goldilocks, tail: Goldilocks) -> Goldilocks { @@ -368,7 +365,7 @@ fn exp_acc(base: Goldilocks, tail: Goldilocks) -> Goldilocks { /// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the /// field order and `2^64`. #[inline] -fn reduce128(x: u128) -> Goldilocks { +pub(crate) fn reduce128(x: u128) -> Goldilocks { let (x_lo, x_hi) = split(x); // This is a no-op let x_hi_hi = x_hi >> 32; let x_hi_lo = x_hi & Goldilocks::NEG_ORDER; @@ -429,6 +426,22 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { res_wrapped + Goldilocks::NEG_ORDER * u64::from(carry) } +/// Convert a constant u64 array into a constant Goldilocks array. +#[inline] +#[must_use] +pub(crate) const fn to_goldilocks_array(input: [u64; N]) -> [Goldilocks; N] { + let mut output = [Goldilocks { value: 0 }; N]; + let mut i = 0; + loop { + if i == N { + break; + } + output[i].value = input[i]; + i += 1; + } + output +} + #[cfg(test)] mod tests { use p3_field_testing::{test_field, test_two_adic_field}; diff --git a/mds/src/goldilocks.rs b/goldilocks/src/mds.rs similarity index 84% rename from mds/src/goldilocks.rs rename to goldilocks/src/mds.rs index 4cb42839..f9ee86a8 100644 --- a/mds/src/goldilocks.rs +++ b/goldilocks/src/mds.rs @@ -2,26 +2,75 @@ //! //! NB: Not all sizes have fast implementations of their permutations. //! Supported sizes: 8, 12, 16, 24, 32, 64, 68. -//! Sizes 8 and 12 are from Plonky2. Other sizes are from Ulrich Haböck's database. +//! Sizes 8 and 12 are from Plonky2, size 16 was found as part of concurrent +//! work by Angus Gruen and Hamish Ivey-Law. Other sizes are from Ulrich Haböck's +//! database. use p3_dft::Radix2Bowers; -use p3_goldilocks::Goldilocks; +use p3_mds::karatsuba_convolution::Convolve; +use p3_mds::util::{apply_circulant, apply_circulant_fft, first_row_to_first_col}; +use p3_mds::MdsPermutation; use p3_symmetric::Permutation; -use crate::util::{ - apply_circulant, apply_circulant_12_sml, apply_circulant_8_sml, apply_circulant_fft, - first_row_to_first_col, -}; -use crate::MdsPermutation; +use crate::{reduce128, Goldilocks}; #[derive(Clone, Default)] pub struct MdsMatrixGoldilocks; +/// Instantiate convolution for "small" RHS vectors over Goldilocks. +/// +/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) < +/// 2^51, though in practice the sum will be less than 2^9. +pub struct SmallConvolveGoldilocks; +impl Convolve for SmallConvolveGoldilocks { + /// Return the lift of a Goldilocks element, 0 <= input.value <= P + /// < 2^64. We widen immediately, since some valid Goldilocks elements + /// don't fit in an i64, and since in any case overflow can occur + /// for even the smallest convolutions. + #[inline(always)] + fn read(input: Goldilocks) -> i128 { + input.value as i128 + } + + /// For a convolution of size N, |x| < N * 2^64 and (as per the + /// assumption above), |y| < 2^51. So the product is at most N * + /// 2^115 which will not overflow for N <= 16. We widen `y` at + /// this point to perform the multiplication. + #[inline(always)] + fn parity_dot(u: [i128; N], v: [i64; N]) -> i128 { + let mut s = 0i128; + for i in 0..N { + s += u[i] * v[i] as i128; + } + s + } + + /// The assumptions above mean z < N^2 * 2^115, which is at most + /// 2^123 when N <= 16. + /// + /// NB: Even though intermediate values could be negative, the + /// output must be non-negative since the inputs were + /// non-negative. + #[inline(always)] + fn reduce(z: i128) -> Goldilocks { + debug_assert!(z >= 0); + reduce128(z as u128) + } +} + const FFT_ALGO: Radix2Bowers = Radix2Bowers; +const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; + impl Permutation<[Goldilocks; 8]> for MdsMatrixGoldilocks { fn permute(&self, input: [Goldilocks; 8]) -> [Goldilocks; 8] { - apply_circulant_8_sml(input) + const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] = + first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW); + SmallConvolveGoldilocks::apply( + input, + MATRIX_CIRC_MDS_8_SML_COL, + SmallConvolveGoldilocks::conv8, + ) } fn permute_mut(&self, input: &mut [Goldilocks; 8]) { @@ -30,9 +79,17 @@ impl Permutation<[Goldilocks; 8]> for MdsMatrixGoldilocks { } impl MdsPermutation for MdsMatrixGoldilocks {} +const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10]; + impl Permutation<[Goldilocks; 12]> for MdsMatrixGoldilocks { fn permute(&self, input: [Goldilocks; 12]) -> [Goldilocks; 12] { - apply_circulant_12_sml(input) + const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] = + first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW); + SmallConvolveGoldilocks::apply( + input, + MATRIX_CIRC_MDS_12_SML_COL, + SmallConvolveGoldilocks::conv12, + ) } fn permute_mut(&self, input: &mut [Goldilocks; 12]) { @@ -41,18 +98,18 @@ impl Permutation<[Goldilocks; 12]> for MdsMatrixGoldilocks { } impl MdsPermutation for MdsMatrixGoldilocks {} -#[rustfmt::skip] -const MATRIX_CIRC_MDS_16_GOLDILOCKS: [u64; 16] = [ - 0x0FFFFFFFF0001000, 0xF8FC7C7D47E3E3F3, 0xEC43C780F1D87790, 0xEAFD5FAB0A814029, - 0x29999FFFCFFFFCCD, 0x4E7D0C1750C5F9D0, 0xF3C5A1E6977E1D30, 0x90DEBDBDF4283830, - 0x4FFFFFFFAFFFFAAB, 0xE50D7B81579423EF, 0xEC34B87D2E278690, 0xF7011FDB0D7E4039, - 0x36665FFFCFFFFCCD, 0x8F7CFBE74FC1FE11, 0xF3C1DE178881E0F0, 0x511EC2B933D84731, -]; +const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] = + [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3]; impl Permutation<[Goldilocks; 16]> for MdsMatrixGoldilocks { fn permute(&self, input: [Goldilocks; 16]) -> [Goldilocks; 16] { - const ENTRIES: [u64; 16] = first_row_to_first_col(&MATRIX_CIRC_MDS_16_GOLDILOCKS); - apply_circulant_fft(FFT_ALGO, ENTRIES, &input) + const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] = + first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW); + SmallConvolveGoldilocks::apply( + input, + MATRIX_CIRC_MDS_16_SML_COL, + SmallConvolveGoldilocks::conv16, + ) } fn permute_mut(&self, input: &mut [Goldilocks; 16]) { @@ -173,10 +230,9 @@ impl MdsPermutation for MdsMatrixGoldilocks {} #[cfg(test)] mod tests { use p3_field::AbstractField; - use p3_goldilocks::Goldilocks; use p3_symmetric::Permutation; - use super::MdsMatrixGoldilocks; + use super::{Goldilocks, MdsMatrixGoldilocks}; #[test] fn goldilocks8() { @@ -195,14 +251,14 @@ mod tests { let output = MdsMatrixGoldilocks.permute(input); let expected: [Goldilocks; 8] = [ - 7296579203883891650, - 15846818354170800942, - 2722920531482623643, - 9616208848921711631, - 490813044365975970, - 5031976952389823366, - 7947699737923523585, - 12198158979238091825, + 16726687146516531007, + 14721040752765534861, + 15566838577475948790, + 9095485010737904250, + 11353934351835864222, + 11056556168691087893, + 4199602889124860181, + 315643510993921470, ] .map(Goldilocks::from_canonical_u64); @@ -230,18 +286,18 @@ mod tests { let output = MdsMatrixGoldilocks.permute(input); let expected: [Goldilocks; 12] = [ - 1843219901452929153, - 8403333524301862517, - 6376512008882165421, - 8955522364079524476, - 9670564897072663334, - 3938053462378634031, - 6601899746530774049, - 12760892837989840359, - 18262125928170834728, - 16489603729927565926, - 9216989093042288220, - 14240946967822758312, + 9322351889214742299, + 8700136572060418355, + 4881757876459003977, + 9899544690241851021, + 480548822895830465, + 5445915149371405525, + 14955363277757168581, + 6672733082273363313, + 190938676320003294, + 1613225933948270736, + 3549006224849989171, + 12169032187873197425, ] .map(Goldilocks::from_canonical_u64); @@ -273,22 +329,22 @@ mod tests { let output = MdsMatrixGoldilocks.permute(input); let expected: [Goldilocks; 16] = [ - 5524669282304516875, - 17505467846953098022, - 7505835506215945517, - 4678037345724403903, - 10895647714009331453, - 5085395390658218948, - 9415955230270042820, - 612277897076940754, - 6973621272151388239, - 3749044944784924855, - 18059026573819502927, - 2497516531324297048, - 4238565225225375968, - 10076249375516184572, - 11967060791800253810, - 6267956432712136737, + 9484392671298797780, + 149770626972189150, + 12125722600598304117, + 15945232149672903756, + 13199929870021500593, + 18443980893262804946, + 317150800081307627, + 16910019239751125049, + 1996802739033818490, + 11668458913264624237, + 11078800762167869397, + 13758408662406282356, + 11119677412113674380, + 7344117715971661026, + 4202436890275702092, + 681166793519210465, ] .map(Goldilocks::from_canonical_u64); diff --git a/goldilocks/src/poseidon2.rs b/goldilocks/src/poseidon2.rs new file mode 100644 index 00000000..32664ca7 --- /dev/null +++ b/goldilocks/src/poseidon2.rs @@ -0,0 +1,294 @@ +//! Diffusion matrices for Goldilocks8, Goldilocks12, Goldilocks16, and Goldilocks20. +//! +//! Reference: https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2_instance_goldilocks.rs + +use p3_field::AbstractField; +use p3_poseidon2::{matmul_internal, DiffusionPermutation}; +use p3_symmetric::Permutation; + +use crate::{to_goldilocks_array, Goldilocks}; + +pub const MATRIX_DIAG_8_GOLDILOCKS_U64: [u64; 8] = [ + 0xa98811a1fed4e3a5, + 0x1cc48b54f377e2a0, + 0xe40cd4f6c5609a26, + 0x11de79ebca97a4a3, + 0x9177c73d8b7e929c, + 0x2a6fe8085797e791, + 0x3de6e93329f8d5ad, + 0x3f7af9125da962fe, +]; + +pub const MATRIX_DIAG_12_GOLDILOCKS_U64: [u64; 12] = [ + 0xc3b6c08e23ba9300, + 0xd84b5de94a324fb6, + 0x0d0c371c5b35b84f, + 0x7964f570e7188037, + 0x5daf18bbd996604b, + 0x6743bc47b9595257, + 0x5528b9362c59bb70, + 0xac45e25b7127b68b, + 0xa2077d7dfbb606b5, + 0xf3faac6faee378ae, + 0x0c6388b51545e883, + 0xd27dbb6944917b60, +]; + +pub const MATRIX_DIAG_16_GOLDILOCKS_U64: [u64; 16] = [ + 0xde9b91a467d6afc0, + 0xc5f16b9c76a9be17, + 0x0ab0fef2d540ac55, + 0x3001d27009d05773, + 0xed23b1f906d3d9eb, + 0x5ce73743cba97054, + 0x1c3bab944af4ba24, + 0x2faa105854dbafae, + 0x53ffb3ae6d421a10, + 0xbcda9df8884ba396, + 0xfc1273e4a31807bb, + 0xc77952573d5142c0, + 0x56683339a819b85e, + 0x328fcbd8f0ddc8eb, + 0xb5101e303fce9cb7, + 0x774487b8c40089bb, +]; + +pub const MATRIX_DIAG_20_GOLDILOCKS_U64: [u64; 20] = [ + 0x95c381fda3b1fa57, + 0xf36fe9eb1288f42c, + 0x89f5dcdfef277944, + 0x106f22eadeb3e2d2, + 0x684e31a2530e5111, + 0x27435c5d89fd148e, + 0x3ebed31c414dbf17, + 0xfd45b0b2d294e3cc, + 0x48c904473a7f6dbf, + 0xe0d1b67809295b4d, + 0xddd1941e9d199dcb, + 0x8cfe534eeb742219, + 0xa6e5261d9e3b8524, + 0x6897ee5ed0f82c1b, + 0x0e7dcd0739ee5f78, + 0x493253f3d0d32363, + 0xbb2737f5845f05c0, + 0xa187e810b06ad903, + 0xb635b995936c4918, + 0x0b3694a940bd2394, +]; + +// Convert the above arrays of u64's into arrays of Goldilocks field elements. +const MATRIX_DIAG_8_GOLDILOCKS: [Goldilocks; 8] = to_goldilocks_array(MATRIX_DIAG_8_GOLDILOCKS_U64); +const MATRIX_DIAG_12_GOLDILOCKS: [Goldilocks; 12] = + to_goldilocks_array(MATRIX_DIAG_12_GOLDILOCKS_U64); +const MATRIX_DIAG_16_GOLDILOCKS: [Goldilocks; 16] = + to_goldilocks_array(MATRIX_DIAG_16_GOLDILOCKS_U64); +const MATRIX_DIAG_20_GOLDILOCKS: [Goldilocks; 20] = + to_goldilocks_array(MATRIX_DIAG_20_GOLDILOCKS_U64); + +#[derive(Debug, Clone, Default)] +pub struct DiffusionMatrixGoldilocks; + +impl> Permutation<[AF; 8]> for DiffusionMatrixGoldilocks { + fn permute_mut(&self, state: &mut [AF; 8]) { + matmul_internal::(state, MATRIX_DIAG_8_GOLDILOCKS); + } +} + +impl> DiffusionPermutation for DiffusionMatrixGoldilocks {} + +impl> Permutation<[AF; 12]> for DiffusionMatrixGoldilocks { + fn permute_mut(&self, state: &mut [AF; 12]) { + matmul_internal::(state, MATRIX_DIAG_12_GOLDILOCKS); + } +} + +impl> DiffusionPermutation for DiffusionMatrixGoldilocks {} + +impl> Permutation<[AF; 16]> for DiffusionMatrixGoldilocks { + fn permute_mut(&self, state: &mut [AF; 16]) { + matmul_internal::(state, MATRIX_DIAG_16_GOLDILOCKS); + } +} + +impl> DiffusionPermutation for DiffusionMatrixGoldilocks {} + +impl> Permutation<[AF; 20]> for DiffusionMatrixGoldilocks { + fn permute_mut(&self, state: &mut [AF; 20]) { + matmul_internal::(state, MATRIX_DIAG_20_GOLDILOCKS); + } +} + +impl> DiffusionPermutation for DiffusionMatrixGoldilocks {} + +#[cfg(test)] +mod tests { + use alloc::vec::Vec; + + use ark_ff::{BigInteger, PrimeField}; + use p3_poseidon2::Poseidon2; + use rand::Rng; + use zkhash::fields::goldilocks::FpGoldiLocks; + use zkhash::poseidon2::poseidon2::Poseidon2 as Poseidon2Ref; + use zkhash::poseidon2::poseidon2_instance_goldilocks::{ + POSEIDON2_GOLDILOCKS_12_PARAMS, POSEIDON2_GOLDILOCKS_8_PARAMS, RC12, RC8, + }; + + use super::*; + + #[test] + fn test_poseidon2_constants() { + let monty_constant = MATRIX_DIAG_8_GOLDILOCKS_U64.map(Goldilocks::from_canonical_u64); + assert_eq!(monty_constant, MATRIX_DIAG_8_GOLDILOCKS); + + let monty_constant = MATRIX_DIAG_12_GOLDILOCKS_U64.map(Goldilocks::from_canonical_u64); + assert_eq!(monty_constant, MATRIX_DIAG_12_GOLDILOCKS); + + let monty_constant = MATRIX_DIAG_16_GOLDILOCKS_U64.map(Goldilocks::from_canonical_u64); + assert_eq!(monty_constant, MATRIX_DIAG_16_GOLDILOCKS); + + let monty_constant = MATRIX_DIAG_20_GOLDILOCKS_U64.map(Goldilocks::from_canonical_u64); + assert_eq!(monty_constant, MATRIX_DIAG_20_GOLDILOCKS); + } + + fn goldilocks_from_ark_ff(input: FpGoldiLocks) -> Goldilocks { + let as_bigint = input.into_bigint(); + let mut as_bytes = as_bigint.to_bytes_le(); + as_bytes.resize(8, 0); + let as_u64 = u64::from_le_bytes(as_bytes[0..8].try_into().unwrap()); + Goldilocks::from_wrapped_u64(as_u64) + } + + #[test] + fn test_poseidon2_goldilocks_width_8() { + const WIDTH: usize = 8; + const D: u64 = 7; + const ROUNDS_F: usize = 8; + const ROUNDS_P: usize = 22; + + type F = Goldilocks; + + let mut rng = rand::thread_rng(); + + // Poiseidon2 reference implementation from zkhash repo. + let poseidon2_ref = Poseidon2Ref::new(&POSEIDON2_GOLDILOCKS_8_PARAMS); + + // Copy over round constants from zkhash. + let round_constants: Vec<[F; WIDTH]> = RC8 + .iter() + .map(|vec| { + vec.iter() + .cloned() + .map(goldilocks_from_ark_ff) + .collect::>() + .try_into() + .unwrap() + }) + .collect(); + + // Our Poseidon2 implementation. + let poseidon2: Poseidon2 = Poseidon2::new( + ROUNDS_F, + ROUNDS_P, + round_constants, + DiffusionMatrixGoldilocks, + ); + + // Generate random input and convert to both Goldilocks field formats. + let input_u64 = rng.gen::<[u64; WIDTH]>(); + let input_ref = input_u64 + .iter() + .cloned() + .map(FpGoldiLocks::from) + .collect::>(); + let input = input_u64.map(F::from_wrapped_u64); + + // Check that the conversion is correct. + assert!(input_ref + .iter() + .zip(input.iter()) + .all(|(a, b)| goldilocks_from_ark_ff(*a) == *b)); + + // Run reference implementation. + let output_ref = poseidon2_ref.permutation(&input_ref); + let expected: [F; WIDTH] = output_ref + .iter() + .cloned() + .map(goldilocks_from_ark_ff) + .collect::>() + .try_into() + .unwrap(); + + // Run our implementation. + let mut output = input; + poseidon2.permute_mut(&mut output); + + assert_eq!(output, expected); + } + + #[test] + fn test_poseidon2_goldilocks_width_12() { + const WIDTH: usize = 12; + const D: u64 = 7; + const ROUNDS_F: usize = 8; + const ROUNDS_P: usize = 22; + + type F = Goldilocks; + + let mut rng = rand::thread_rng(); + + // Poiseidon2 reference implementation from zkhash repo. + let poseidon2_ref = Poseidon2Ref::new(&POSEIDON2_GOLDILOCKS_12_PARAMS); + + // Copy over round constants from zkhash. + let round_constants: Vec<[F; WIDTH]> = RC12 + .iter() + .map(|vec| { + vec.iter() + .cloned() + .map(goldilocks_from_ark_ff) + .collect::>() + .try_into() + .unwrap() + }) + .collect(); + + // Our Poseidon2 implementation. + let poseidon2: Poseidon2 = Poseidon2::new( + ROUNDS_F, + ROUNDS_P, + round_constants, + DiffusionMatrixGoldilocks, + ); + + // Generate random input and convert to both Goldilocks field formats. + let input_u64 = rng.gen::<[u64; WIDTH]>(); + let input_ref = input_u64 + .iter() + .cloned() + .map(FpGoldiLocks::from) + .collect::>(); + let input = input_u64.map(F::from_wrapped_u64); + + // Check that the conversion is correct. + assert!(input_ref + .iter() + .zip(input.iter()) + .all(|(a, b)| goldilocks_from_ark_ff(*a) == *b)); + + // Run reference implementation. + let output_ref = poseidon2_ref.permutation(&input_ref); + let expected: [F; WIDTH] = output_ref + .iter() + .cloned() + .map(goldilocks_from_ark_ff) + .collect::>() + .try_into() + .unwrap(); + + // Run our implementation. + let mut output = input; + poseidon2.permute_mut(&mut output); + + assert_eq!(output, expected); + } +} diff --git a/keccak-air/Cargo.toml b/keccak-air/Cargo.toml index c1b5f97d..592e721e 100644 --- a/keccak-air/Cargo.toml +++ b/keccak-air/Cargo.toml @@ -14,6 +14,7 @@ tracing = "0.1.37" [dev-dependencies] p3-baby-bear = { path = "../baby-bear" } p3-challenger = { path = "../challenger" } +p3-circle = { path = "../circle" } p3-commit = { path = "../commit" } p3-dft = { path = "../dft" } p3-fri = { path = "../fri" } @@ -22,6 +23,8 @@ p3-keccak = { path = "../keccak" } p3-maybe-rayon = { path = "../maybe-rayon" } p3-mds = { path = "../mds" } p3-merkle-tree = { path = "../merkle-tree" } +p3-mersenne-31 = { path = "../mersenne-31" } +p3-poseidon = {path = "../poseidon"} p3-poseidon2 = { path = "../poseidon2" } p3-symmetric = { path = "../symmetric" } p3-uni-stark = { path = "../uni-stark" } @@ -38,6 +41,9 @@ name = "prove_baby_bear_poseidon2" [[example]] name = "prove_goldilocks_keccak" +[[example]] +name = "prove_goldilocks_poseidon" + [features] # TODO: Consider removing, at least when this gets split off into another repository. # We should be able to enable p3-maybe-rayon/parallel directly; this just doesn't diff --git a/keccak-air/examples/prove_baby_bear_keccak.rs b/keccak-air/examples/prove_baby_bear_keccak.rs index 0c308e8e..a9580bce 100644 --- a/keccak-air/examples/prove_baby_bear_keccak.rs +++ b/keccak-air/examples/prove_baby_bear_keccak.rs @@ -3,12 +3,14 @@ use p3_challenger::{HashChallenger, SerializingChallenger32}; use p3_commit::ExtensionMmcs; use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; -use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use p3_fri::{FriConfig, TwoAdicFriPcs}; use p3_keccak::Keccak256Hash; use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_matrix::Matrix; use p3_merkle_tree::FieldMerkleTreeMmcs; use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; +use p3_util::log2_ceil_usize; use rand::random; use tracing_forest::util::LevelFilter; use tracing_forest::ForestLayer; @@ -50,23 +52,23 @@ fn main() -> Result<(), VerificationError> { type Challenger = SerializingChallenger32>; + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::(inputs); + let fri_config = FriConfig { log_blowup: 1, num_queries: 100, proof_of_work_bits: 16, mmcs: challenge_mmcs, }; - type Pcs = - TwoAdicFriPcs>; - let pcs = Pcs::new(fri_config, dft, val_mmcs); + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(log2_ceil_usize(trace.height()), dft, val_mmcs, fri_config); - type MyConfig = StarkConfig; - let config = StarkConfig::new(pcs); + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); let mut challenger = Challenger::from_hasher(vec![], byte_hash); - let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); - let trace = generate_trace_rows::(inputs); let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); let mut challenger = Challenger::from_hasher(vec![], byte_hash); diff --git a/keccak-air/examples/prove_baby_bear_poseidon2.rs b/keccak-air/examples/prove_baby_bear_poseidon2.rs index 048cde7b..deb32689 100644 --- a/keccak-air/examples/prove_baby_bear_poseidon2.rs +++ b/keccak-air/examples/prove_baby_bear_poseidon2.rs @@ -1,15 +1,17 @@ -use p3_baby_bear::BabyBear; +use p3_baby_bear::{BabyBear, DiffusionMatrixBabybear}; use p3_challenger::DuplexChallenger; use p3_commit::ExtensionMmcs; use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; use p3_field::Field; -use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use p3_fri::{FriConfig, TwoAdicFriPcs}; use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_matrix::Matrix; use p3_merkle_tree::FieldMerkleTreeMmcs; -use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2}; +use p3_poseidon2::Poseidon2; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; +use p3_util::log2_ceil_usize; use rand::{random, thread_rng}; use tracing_forest::util::LevelFilter; use tracing_forest::ForestLayer; @@ -58,23 +60,23 @@ fn main() -> Result<(), VerificationError> { type Challenger = DuplexChallenger; + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::(inputs); + let fri_config = FriConfig { log_blowup: 1, num_queries: 100, proof_of_work_bits: 16, mmcs: challenge_mmcs, }; - type Pcs = - TwoAdicFriPcs>; - let pcs = Pcs::new(fri_config, dft, val_mmcs); + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(log2_ceil_usize(trace.height()), dft, val_mmcs, fri_config); - type MyConfig = StarkConfig; - let config = StarkConfig::new(pcs); + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); let mut challenger = Challenger::new(perm.clone()); - let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); - let trace = generate_trace_rows::(inputs); let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); let mut challenger = Challenger::new(perm); diff --git a/keccak-air/examples/prove_goldilocks_keccak.rs b/keccak-air/examples/prove_goldilocks_keccak.rs index 9a446204..0b721ff2 100644 --- a/keccak-air/examples/prove_goldilocks_keccak.rs +++ b/keccak-air/examples/prove_goldilocks_keccak.rs @@ -2,13 +2,15 @@ use p3_challenger::{HashChallenger, SerializingChallenger64}; use p3_commit::ExtensionMmcs; use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; -use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use p3_fri::{FriConfig, TwoAdicFriPcs}; use p3_goldilocks::Goldilocks; use p3_keccak::Keccak256Hash; use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_matrix::Matrix; use p3_merkle_tree::FieldMerkleTreeMmcs; use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher64}; use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; +use p3_util::log2_ceil_usize; use rand::random; use tracing_forest::util::LevelFilter; use tracing_forest::ForestLayer; @@ -50,23 +52,23 @@ fn main() -> Result<(), VerificationError> { type Challenger = SerializingChallenger64>; + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::(inputs); + let fri_config = FriConfig { log_blowup: 1, num_queries: 100, proof_of_work_bits: 16, mmcs: challenge_mmcs, }; - type Pcs = - TwoAdicFriPcs>; - let pcs = Pcs::new(fri_config, dft, val_mmcs); + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(log2_ceil_usize(trace.height()), dft, val_mmcs, fri_config); - type MyConfig = StarkConfig; - let config = StarkConfig::new(pcs); + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); let mut challenger = Challenger::from_hasher(vec![], byte_hash); - let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); - let trace = generate_trace_rows::(inputs); let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); let mut challenger = Challenger::from_hasher(vec![], byte_hash); diff --git a/keccak-air/examples/prove_goldilocks_poseidon.rs b/keccak-air/examples/prove_goldilocks_poseidon.rs new file mode 100644 index 00000000..f17e54ca --- /dev/null +++ b/keccak-air/examples/prove_goldilocks_poseidon.rs @@ -0,0 +1,84 @@ +use p3_challenger::DuplexChallenger; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_field::Field; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; +use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_matrix::Matrix; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon::Poseidon; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; +use p3_util::log2_ceil_usize; +use rand::{random, thread_rng}; +use tracing_forest::util::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Registry}; + +const NUM_HASHES: usize = 680; + +fn main() -> Result<(), VerificationError> { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .init(); + + type Val = Goldilocks; + type Challenge = BinomialExtensionField; + + type Perm = Poseidon; + let perm = Perm::new_from_rng(4, 22, MdsMatrixGoldilocks, &mut thread_rng()); + + type MyHash = PaddingFreeSponge; + let hash = MyHash::new(perm.clone()); + + type MyCompress = TruncatedPermutation; + let compress = MyCompress::new(perm.clone()); + + type ValMmcs = FieldMerkleTreeMmcs< + ::Packing, + ::Packing, + MyHash, + MyCompress, + 4, + >; + let val_mmcs = ValMmcs::new(hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Dft = Radix2DitParallel; + let dft = Dft {}; + + type Challenger = DuplexChallenger; + + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::(inputs); + + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(log2_ceil_usize(trace.height()), dft, val_mmcs, fri_config); + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); + + let mut challenger = Challenger::new(perm.clone()); + + let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); + + let mut challenger = Challenger::new(perm); + verify(&config, &KeccakAir {}, &mut challenger, &proof) +} diff --git a/keccak-air/examples/prove_m31_keccak.rs b/keccak-air/examples/prove_m31_keccak.rs new file mode 100644 index 00000000..d4d61341 --- /dev/null +++ b/keccak-air/examples/prove_m31_keccak.rs @@ -0,0 +1,85 @@ +use p3_challenger::{HashChallenger, SerializingChallenger32}; +use p3_circle::{Cfft, CirclePcs}; +use p3_commit::ExtensionMmcs; +use p3_fri::FriConfig; +use p3_keccak::Keccak256Hash; +use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_matrix::Matrix; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_mersenne_31::Mersenne31; +use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; +use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; +use p3_util::log2_strict_usize; +use rand::random; +use tracing_forest::util::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Registry}; + +const NUM_HASHES: usize = 680; + +fn main() -> Result<(), VerificationError> { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .init(); + + type Val = Mersenne31; + // type Challenge = BinomialExtensionField; + type Challenge = Val; + + type ByteHash = Keccak256Hash; + type FieldHash = SerializingHasher32; + let byte_hash = ByteHash {}; + let field_hash = FieldHash::new(Keccak256Hash {}); + + type MyCompress = CompressionFunctionFromHasher; + let compress = MyCompress::new(byte_hash); + + type ValMmcs = FieldMerkleTreeMmcs; + let val_mmcs = ValMmcs::new(field_hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Challenger = SerializingChallenger32>; + + let _fri_config = FriConfig { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + + type Pcs = CirclePcs; + let pcs = Pcs { + log_blowup: 1, + cfft: Cfft::default(), + mmcs: val_mmcs, + }; + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); + + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::(inputs); + + dbg!(trace.height(), log2_strict_usize(trace.height())); + + let air = KeccakAir {}; + + let mut challenger = Challenger::from_hasher(vec![], byte_hash); + + let proof = prove::(&config, &air, &mut challenger, trace); + + let mut challenger = Challenger::from_hasher(vec![], byte_hash); + verify(&config, &air, &mut challenger, &proof)?; + + println!("OK!!! 👍"); + Ok(()) +} diff --git a/matrix/Cargo.toml b/matrix/Cargo.toml index 7d5dfe1b..c7f50a1f 100644 --- a/matrix/Cargo.toml +++ b/matrix/Cargo.toml @@ -8,8 +8,10 @@ license = "MIT OR Apache-2.0" p3-field = { path = "../field" } p3-maybe-rayon = { path = "../maybe-rayon" } p3-util = { path = "../util" } +itertools = "0.12.0" rand = "0.8.5" serde = { version = "1.0", features = ["derive"] } +tracing = "0.1.37" [dev-dependencies] criterion = "0.5.1" diff --git a/matrix/src/dense.rs b/matrix/src/dense.rs index c1a4fc21..68690e81 100644 --- a/matrix/src/dense.rs +++ b/matrix/src/dense.rs @@ -9,7 +9,10 @@ use rand::distributions::{Distribution, Standard}; use rand::Rng; use serde::{Deserialize, Serialize}; -use crate::{Matrix, MatrixGet, MatrixRowSlices, MatrixRowSlicesMut, MatrixRows, MatrixTranspose}; +use crate::{ + Matrix, MatrixGet, MatrixRowChunksMut, MatrixRowSlices, MatrixRowSlicesMut, MatrixRows, + MatrixTranspose, +}; /// A default constant for block size matrix transposition. The value was chosen with 32-byte type, in mind. const TRANSPOSE_BLOCK_SIZE: usize = 64; @@ -229,6 +232,18 @@ impl MatrixRowSlicesMut for RowMajorMatrix { } } +impl MatrixRowChunksMut for RowMajorMatrix { + type RowChunkMut<'a> = RowMajorMatrixViewMut<'a, T> where T: 'a; + fn par_row_chunks_mut( + &mut self, + chunk_rows: usize, + ) -> impl IndexedParallelIterator> { + self.values + .par_chunks_exact_mut(self.width * chunk_rows) + .map(|slice| RowMajorMatrixViewMut::new(slice, self.width)) + } +} + #[derive(Copy, Clone)] pub struct RowMajorMatrixView<'a, T> { pub values: &'a [T], diff --git a/matrix/src/lib.rs b/matrix/src/lib.rs index 80df3c98..7e4b79a6 100644 --- a/matrix/src/lib.rs +++ b/matrix/src/lib.rs @@ -7,12 +7,15 @@ extern crate alloc; use alloc::vec::Vec; use core::fmt::{Debug, Display, Formatter}; +use p3_maybe_rayon::prelude::*; + use crate::dense::RowMajorMatrix; use crate::strided::VerticallyStridedMatrixView; pub mod bitrev; pub mod dense; pub mod mul; +pub mod routines; pub mod sparse; pub mod stack; pub mod strided; @@ -62,6 +65,10 @@ pub trait MatrixRows: Matrix { fn row(&self, r: usize) -> Self::Row<'_>; + fn rows(&self) -> impl Iterator> { + (0..self.height()).map(|r| self.row(r)) + } + fn row_vec(&self, r: usize) -> Vec { self.row(r).into_iter().collect() } @@ -100,11 +107,48 @@ pub trait MatrixRows: Matrix { /// A `Matrix` which supports access its rows as slices. pub trait MatrixRowSlices: MatrixRows { fn row_slice(&self, r: usize) -> &[T]; + + fn row_slices<'a>(&'a self) -> impl Iterator + where + T: 'a, + { + (0..self.height()).map(|r| self.row_slice(r)) + } } /// A `Matrix` which supports access its rows as mutable slices. pub trait MatrixRowSlicesMut: MatrixRowSlices { fn row_slice_mut(&mut self, r: usize) -> &mut [T]; + + // BEWARE: if we add a matrix type that has several rows in the same memory location, + // these default implementations will be invalid + // For example, a "tiling" matrix view that repeats its rows + + /// # Safety + /// Each row index in `rs` must be unique. + unsafe fn disjoint_row_slices_mut(&mut self, rs: [usize; N]) -> [&mut [T]; N] { + rs.map(|r| { + let s = self.row_slice_mut(r); + // launder the lifetime to 'a instead of being bound to self + unsafe { core::slice::from_raw_parts_mut(s.as_mut_ptr(), s.len()) } + }) + } + fn row_pair_slices_mut(&mut self, r0: usize, r1: usize) -> (&mut [T], &mut [T]) { + // make it safe by ensuring rs unique + assert_ne!(r0, r1); + let [s0, s1] = unsafe { self.disjoint_row_slices_mut([r0, r1]) }; + (s0, s1) + } +} + +pub trait MatrixRowChunksMut: MatrixRowSlicesMut { + type RowChunkMut<'a>: MatrixRowSlicesMut + Send + where + Self: 'a; + fn par_row_chunks_mut( + &mut self, + chunk_rows: usize, + ) -> impl IndexedParallelIterator>; } /// A `TransposeMatrix` which supports transpose logic for matrices diff --git a/matrix/src/routines.rs b/matrix/src/routines.rs new file mode 100644 index 00000000..417aa7d3 --- /dev/null +++ b/matrix/src/routines.rs @@ -0,0 +1,28 @@ +use alloc::vec; +use alloc::vec::Vec; + +use itertools::izip; +use p3_field::{ExtensionField, Field}; +use tracing::instrument; + +use crate::MatrixRows; + +/// Tranposed matrix-vector product: Mᵀv +/// Can handle a vector of extensions of the matrix field, the other way around +/// would require a different method. +/// TODO: make faster (currently ~100ms of m31_keccak) +#[instrument(skip_all, fields(dims = %m.dimensions()))] +pub fn columnwise_dot_product(m: M, v: impl Iterator) -> Vec +where + F: Field, + EF: ExtensionField, + M: MatrixRows, +{ + let mut accs = vec![EF::zero(); m.width()]; + for (row, vx) in izip!(m.rows(), v) { + for (acc, mx) in izip!(&mut accs, row) { + *acc += vx * mx; + } + } + accs +} diff --git a/mds/Cargo.toml b/mds/Cargo.toml index 3eb09dac..bf3e65f7 100644 --- a/mds/Cargo.toml +++ b/mds/Cargo.toml @@ -5,18 +5,19 @@ edition = "2021" license = "MIT OR Apache-2.0" [dependencies] -p3-baby-bear = { path = "../baby-bear" } p3-dft = { path = "../dft" } p3-field = { path = "../field" } -p3-goldilocks = { path = "../goldilocks" } p3-matrix = { path = "../matrix" } -p3-mersenne-31 = { path = "../mersenne-31" } p3-symmetric = { path = "../symmetric" } p3-util = { path = "../util" } rand = { version = "0.8.5", features = ["min_const_gen"] } +itertools = { version = "0.11.0" } [dev-dependencies] criterion = "0.5.1" +p3-baby-bear = { path = "../baby-bear" } +p3-goldilocks = { path = "../goldilocks" } +p3-mersenne-31 = { path = "../mersenne-31/" } [[bench]] name = "mds" diff --git a/mds/benches/mds.rs b/mds/benches/mds.rs index f9917fcc..d461be27 100644 --- a/mds/benches/mds.rs +++ b/mds/benches/mds.rs @@ -1,16 +1,13 @@ use std::any::type_name; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use p3_baby_bear::BabyBear; +use p3_baby_bear::{BabyBear, MdsMatrixBabyBear}; use p3_field::{AbstractField, Field}; -use p3_goldilocks::Goldilocks; -use p3_mds::babybear::MdsMatrixBabyBear; +use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; use p3_mds::coset_mds::CosetMds; -use p3_mds::goldilocks::MdsMatrixGoldilocks; use p3_mds::integrated_coset_mds::IntegratedCosetMds; -use p3_mds::mersenne31::MdsMatrixMersenne31; use p3_mds::MdsPermutation; -use p3_mersenne_31::Mersenne31; +use p3_mersenne_31::{MdsMatrixMersenne31, Mersenne31}; use rand::distributions::{Distribution, Standard}; use rand::{thread_rng, Rng}; @@ -19,15 +16,25 @@ fn bench_all_mds(c: &mut Criterion) { bench_mds::<::Packing, IntegratedCosetMds, 16>(c); bench_mds::, 16>(c); bench_mds::<::Packing, CosetMds, 16>(c); + + bench_mds::(c); bench_mds::(c); + bench_mds::(c); bench_mds::(c); + bench_mds::(c); + bench_mds::(c); bench_mds::(c); bench_mds::(c); bench_mds::(c); + bench_mds::(c); + bench_mds::(c); + bench_mds::(c); + bench_mds::(c); bench_mds::(c); bench_mds::(c); + bench_mds::(c); } fn bench_mds(c: &mut Criterion) diff --git a/mds/src/babybear.rs b/mds/src/babybear.rs deleted file mode 100644 index a0a8400e..00000000 --- a/mds/src/babybear.rs +++ /dev/null @@ -1,290 +0,0 @@ -//! MDS matrices over the BabyBear field, and permutations defined by them. -//! -//! NB: Not all sizes have fast implementations of their permutations. -//! Supported sizes: 8, 12, 16, 24, 32, 64. -//! Sizes 8 and 12 are from Plonky2. Other sizes are from Ulrich Haböck's database. - -use p3_baby_bear::BabyBear; -use p3_dft::Radix2Bowers; -use p3_symmetric::Permutation; - -use crate::util::{ - apply_circulant, apply_circulant_12_sml, apply_circulant_8_sml, apply_circulant_fft, - first_row_to_first_col, -}; -use crate::MdsPermutation; - -#[derive(Clone, Default)] -pub struct MdsMatrixBabyBear; - -const FFT_ALGO: Radix2Bowers = Radix2Bowers; - -impl Permutation<[BabyBear; 8]> for MdsMatrixBabyBear { - fn permute(&self, input: [BabyBear; 8]) -> [BabyBear; 8] { - apply_circulant_8_sml(input) - } - - fn permute_mut(&self, input: &mut [BabyBear; 8]) { - *input = self.permute(*input); - } -} -impl MdsPermutation for MdsMatrixBabyBear {} - -impl Permutation<[BabyBear; 12]> for MdsMatrixBabyBear { - fn permute(&self, input: [BabyBear; 12]) -> [BabyBear; 12] { - apply_circulant_12_sml(input) - } - - fn permute_mut(&self, input: &mut [BabyBear; 12]) { - *input = self.permute(*input); - } -} -impl MdsPermutation for MdsMatrixBabyBear {} - -#[rustfmt::skip] -const MATRIX_CIRC_MDS_16_BABYBEAR: [u64; 16] = [ - 0x07801000, 0x4ACAAC32, 0x6A709B76, 0x20413E94, - 0x00928499, 0x31C34CA3, 0x03BBC192, 0x3F20868B, - 0x257FFAAB, 0x5F05F559, 0x55B43EA9, 0x2BC659ED, - 0x2C6D7501, 0x1D110184, 0x0E1F608D, 0x2032F0C6, -]; - -impl Permutation<[BabyBear; 16]> for MdsMatrixBabyBear { - fn permute(&self, input: [BabyBear; 16]) -> [BabyBear; 16] { - const ENTRIES: [u64; 16] = first_row_to_first_col(&MATRIX_CIRC_MDS_16_BABYBEAR); - apply_circulant_fft(FFT_ALGO, ENTRIES, &input) - } - - fn permute_mut(&self, input: &mut [BabyBear; 16]) { - *input = self.permute(*input); - } -} -impl MdsPermutation for MdsMatrixBabyBear {} - -#[rustfmt::skip] -const MATRIX_CIRC_MDS_24_BABYBEAR: [u64; 24] = [ - 0x2D0AAAAB, 0x64850517, 0x17F5551D, 0x04ECBEB5, - 0x6D91A8D5, 0x60703026, 0x18D6F3CA, 0x729601A7, - 0x77CDA9E2, 0x3C0F5038, 0x26D52A61, 0x0360405D, - 0x68FC71C8, 0x2495A71D, 0x5D57AFC2, 0x1689DD98, - 0x3C2C3DBE, 0x0C23DC41, 0x0524C7F2, 0x6BE4DF69, - 0x0A6E572C, 0x5C7790FA, 0x17E118F6, 0x0878A07F, -]; - -impl Permutation<[BabyBear; 24]> for MdsMatrixBabyBear { - fn permute(&self, input: [BabyBear; 24]) -> [BabyBear; 24] { - apply_circulant(&MATRIX_CIRC_MDS_24_BABYBEAR, input) - } - - fn permute_mut(&self, input: &mut [BabyBear; 24]) { - *input = self.permute(*input); - } -} -impl MdsPermutation for MdsMatrixBabyBear {} - -#[rustfmt::skip] -const MATRIX_CIRC_MDS_32_BABYBEAR: [u64; 32] = [ - 0x0BC00000, 0x2BED8F81, 0x337E0652, 0x4C4535D1, - 0x4AF2DC32, 0x2DB4050F, 0x676A7CE3, 0x3A06B68E, - 0x5E95C1B1, 0x2C5F54A0, 0x2332F13D, 0x58E757F1, - 0x3AA6DCCE, 0x607EE630, 0x4ED57FF0, 0x6E08555B, - 0x4C155556, 0x587FD0CE, 0x462F1551, 0x032A43CC, - 0x5E2E43EA, 0x71609B02, 0x0ED97E45, 0x562CA7E9, - 0x2CB70B1D, 0x4E941E23, 0x174A61C1, 0x117A9426, - 0x73562137, 0x54596086, 0x487C560B, 0x68A4ACAB, -]; - -impl Permutation<[BabyBear; 32]> for MdsMatrixBabyBear { - fn permute(&self, input: [BabyBear; 32]) -> [BabyBear; 32] { - const ENTRIES: [u64; 32] = first_row_to_first_col(&MATRIX_CIRC_MDS_32_BABYBEAR); - apply_circulant_fft(FFT_ALGO, ENTRIES, &input) - } - - fn permute_mut(&self, input: &mut [BabyBear; 32]) { - *input = self.permute(*input); - } -} -impl MdsPermutation for MdsMatrixBabyBear {} - -#[rustfmt::skip] -const MATRIX_CIRC_MDS_64_BABYBEAR: [u64; 64] = [ - 0x39577778, 0x0072F4E1, 0x0B1B8404, 0x041E9C88, - 0x32D22F9F, 0x4E4BF946, 0x20C7B6D7, 0x0587C267, - 0x55877229, 0x4D186EC4, 0x4A19FD23, 0x1A64A20F, - 0x2965CA4D, 0x16D98A5A, 0x471E544A, 0x193D5C8B, - 0x6E66DF0C, 0x28BF1F16, 0x26DB0BC8, 0x5B06CDDB, - 0x100DCCA2, 0x65C268AD, 0x199F09E7, 0x36BA04BE, - 0x06C393F2, 0x51B06DFD, 0x6951B0C4, 0x6683A4C2, - 0x3B53D11B, 0x26E5134C, 0x45A5F1C5, 0x6F4D2433, - 0x3CE2D82E, 0x36309A7D, 0x3DD9B459, 0x68051E4C, - 0x5C3AA720, 0x11640517, 0x0634D995, 0x1B0F6406, - 0x72A18430, 0x26513CC5, 0x67C0B93C, 0x548AB4A3, - 0x6395D20D, 0x3E5DBC41, 0x332AF630, 0x3C5DDCB3, - 0x0AA95792, 0x66EB5492, 0x3F78DDDC, 0x5AC41627, - 0x16CD5124, 0x3564DA96, 0x461867C9, 0x157B4E11, - 0x1AA486C8, 0x0C5095A9, 0x3833C0C6, 0x008FEBA5, - 0x52ECBE2E, 0x1D178A67, 0x58B3C04B, 0x6E95CB51, -]; - -impl Permutation<[BabyBear; 64]> for MdsMatrixBabyBear { - fn permute(&self, input: [BabyBear; 64]) -> [BabyBear; 64] { - const ENTRIES: [u64; 64] = first_row_to_first_col(&MATRIX_CIRC_MDS_64_BABYBEAR); - apply_circulant_fft(FFT_ALGO, ENTRIES, &input) - } - - fn permute_mut(&self, input: &mut [BabyBear; 64]) { - *input = self.permute(*input); - } -} -impl MdsPermutation for MdsMatrixBabyBear {} - -#[cfg(test)] -mod tests { - use p3_baby_bear::BabyBear; - use p3_field::AbstractField; - use p3_symmetric::Permutation; - - use super::MdsMatrixBabyBear; - - #[test] - fn babybear8() { - let input: [BabyBear; 8] = [ - 391474477, 1174409341, 666967492, 1852498830, 1801235316, 820595865, 585587525, - 1348326858, - ] - .map(BabyBear::from_canonical_u64); - - let output = MdsMatrixBabyBear.permute(input); - - let expected: [BabyBear; 8] = [ - 504128309, 1915631392, 1485872679, 1192473153, 1425656962, 634837116, 1385055496, - 795071948, - ] - .map(BabyBear::from_canonical_u64); - - assert_eq!(output, expected); - } - - #[test] - fn babybear12() { - let input: [BabyBear; 12] = [ - 918423259, 673549090, 364157140, 9832898, 493922569, 1171855651, 246075034, 1542167926, - 1787615541, 1696819900, 1884530130, 422386768, - ] - .map(BabyBear::from_canonical_u64); - - let output = MdsMatrixBabyBear.permute(input); - - let expected: [BabyBear; 12] = [ - 772551966, 2009480750, 430187688, 1134406614, 351991333, 1100020355, 777201441, - 109334185, 2000422332, 226001108, 1763301937, 631922975, - ] - .map(BabyBear::from_canonical_u64); - - assert_eq!(output, expected); - } - - #[test] - fn babybear16() { - let input: [BabyBear; 16] = [ - 1983708094, 1477844074, 1638775686, 98517138, 70746308, 968700066, 275567720, - 1359144511, 960499489, 1215199187, 474302783, 79320256, 1923147803, 1197733438, - 1638511323, 303948902, - ] - .map(BabyBear::from_canonical_u64); - - let output = MdsMatrixBabyBear.permute(input); - - let expected: [BabyBear; 16] = [ - 556401834, 683220320, 1810464928, 1169932617, 638040805, 1006828793, 1808829293, - 1614898838, 23062004, 622101715, 967448737, 519782760, 579530259, 157817176, - 1439772057, 54268721, - ] - .map(BabyBear::from_canonical_u64); - - assert_eq!(output, expected); - } - - #[test] - fn babybear24() { - let input: [BabyBear; 24] = [ - 1307148929, 1603957607, 1515498600, 1412393512, 785287979, 988718522, 1750345556, - 853137995, 534387281, 930390055, 1600030977, 903985158, 1141020507, 636889442, - 966037834, 1778991639, 1440427266, 1379431959, 853403277, 959593575, 733455867, - 908584009, 817124993, 418826476, - ] - .map(BabyBear::from_canonical_u64); - - let output = MdsMatrixBabyBear.permute(input); - - let expected: [BabyBear; 24] = [ - 1537871777, 1626055274, 1705000179, 1426678258, 1688760658, 1347225494, 1291221794, - 1224656589, 1791446853, 1978133881, 1820380039, 1366829700, 27479566, 409595531, - 1223347944, 1752750033, 594548873, 1447473111, 1385412872, 1111945102, 1366585917, - 138866947, 1326436332, 656898133, - ] - .map(BabyBear::from_canonical_u64); - - assert_eq!(output, expected); - } - - #[test] - fn babybear32() { - let input: [BabyBear; 32] = [ - 1346087634, 1511946000, 1883470964, 54906057, 233060279, 5304922, 1881494193, - 743728289, 404047361, 1148556479, 144976634, 1726343008, 29659471, 1350407160, - 1636652429, 385978955, 327649601, 1248138459, 1255358242, 84164877, 1005571393, - 1713215328, 72913800, 1683904606, 904763213, 316800515, 656395998, 788184609, - 1824512025, 1177399063, 1358745087, 444151496, - ] - .map(BabyBear::from_canonical_u64); - - let output = MdsMatrixBabyBear.permute(input); - - let expected: [BabyBear; 32] = [ - 1359576919, 1657405784, 1031581836, 212090105, 699048671, 877916349, 205627787, - 1211567750, 210807569, 1696391051, 558468987, 161148427, 304343518, 76611896, - 532792005, 1963649139, 1283500358, 250848292, 1109842541, 2007388683, 433801252, - 1189712914, 626158024, 1436409738, 456315160, 1836818120, 1645024941, 925447491, - 1599571860, 1055439714, 353537136, 379644130, - ] - .map(BabyBear::from_canonical_u64); - - assert_eq!(output, expected); - } - - #[test] - fn babybear64() { - let input: [BabyBear; 64] = [ - 1931358930, 1322576114, 1658000717, 134388215, 1517892791, 1486447670, 93570662, - 898466034, 1576905917, 283824713, 1433559150, 1730678909, 155340881, 1978472263, - 1980644590, 1814040165, 654743892, 849954227, 323176597, 146970735, 252703735, - 1856579399, 162749290, 986745196, 352038183, 1239527508, 828473247, 1184743572, - 1017249065, 36804843, 1378131210, 1286724687, 596095979, 1916924908, 528946791, - 397247884, 23477278, 299412064, 415288430, 935825754, 1218003667, 1954592289, - 1594612673, 664096455, 958392778, 497208288, 1544504580, 1829423324, 956111902, - 458327015, 1736664598, 430977734, 599887171, 1100074154, 1197653896, 427838651, - 466509871, 1236918100, 940670246, 1421951147, 255557957, 1374188100, 315300068, - 623354170, - ] - .map(BabyBear::from_canonical_u64); - - let output = MdsMatrixBabyBear.permute(input); - - let expected: [BabyBear; 64] = [ - 442300274, 756862170, 167612495, 1103336044, 546496433, 1211822920, 329094196, - 1334376959, 944085937, 977350947, 1445060130, 918469957, 800346119, 1957918170, - 739098112, 1862817833, 1831589884, 1673860978, 698081523, 1128978338, 387929536, - 1106772486, 1367460469, 1911237185, 362669171, 819949894, 1801786287, 1943505026, - 586738185, 996076080, 1641277705, 1680239311, 1005815192, 63087470, 593010310, - 364673774, 543368618, 1576179136, 47618763, 1990080335, 1608655220, 499504830, - 861863262, 765074289, 139277832, 1139970138, 1510286607, 244269525, 43042067, - 119733624, 1314663255, 893295811, 1444902994, 914930267, 1675139862, 1148717487, - 1601328192, 534383401, 296215929, 1924587380, 1336639141, 34897994, 2005302060, - 1780337352, - ] - .map(BabyBear::from_canonical_u64); - - assert_eq!(output, expected); - } -} diff --git a/mds/src/karatsuba_convolution.rs b/mds/src/karatsuba_convolution.rs new file mode 100644 index 00000000..0b28a342 --- /dev/null +++ b/mds/src/karatsuba_convolution.rs @@ -0,0 +1,388 @@ +//! Calculate the convolution of two vectors using a Karatsuba-style +//! decomposition and the CRT. +//! +//! This is not a new idea, but we did have the pleasure of +//! reinventing it independently. Some references: +//! - https://cr.yp.to/lineartime/multapps-20080515.pdf +//! - https://2π.com/23/convolution/ +//! +//! Given a vector v \in F^N, let v(x) \in F[X] denote the polynomial +//! v_0 + v_1 x + ... + v_{N - 1} x^{N - 1}. Then w is equal to the +//! convolution v * u if and only if w(x) = v(x)u(x) mod x^N - 1. +//! Additionally, define the negacyclic convolution by w(x) = v(x)u(x) +//! mod x^N + 1. Using the Chinese remainder theorem we can compute +//! w(x) as +//! w(x) = 1/2 (w_0(x) + w_1(x)) + x^{N/2}/2 (w_0(x) - w_1(x)) +//! where +//! w_0 = v(x)u(x) mod x^{N/2} - 1 +//! w_1 = v(x)u(x) mod x^{N/2} + 1 +//! +//! To compute w_0 and w_1 we first compute +//! v_0(x) = v(x) mod x^{N/2} - 1 +//! v_1(x) = v(x) mod x^{N/2} + 1 +//! u_0(x) = u(x) mod x^{N/2} - 1 +//! u_1(x) = u(x) mod x^{N/2} + 1 +//! +//! Now w_0 is the convolution of v_0 and u_0 which we can compute +//! recursively. For w_1 we compute the negacyclic convolution +//! v_1(x)u_1(x) mod x^{N/2} + 1 using Karatsuba. +//! +//! There are 2 possible approaches to applying Karatsuba which mirror +//! the DIT vs DIF approaches to FFT's, the left/right decomposition +//! or the even/odd decomposition. The latter seems to have fewer +//! operations and so it is the one implemented below, though it does +//! require a bit more data manipulation. It works as follows: +//! +//! Define the even v_e and odd v_o parts so that v(x) = (v_e(x^2) + x v_o(x^2)). +//! Then v(x)u(x) +//! = (v_e(x^2)u_e(x^2) + x^2 v_o(x^2)u_o(x^2)) +//! + x ((v_e(x^2) + v_o(x^2))(u_e(x^2) + u_o(x^2)) +//! - (v_e(x^2)u_e(x^2) + v_o(x^2)u_o(x^2))) +//! This reduces the problem to 3 negacyclic convolutions of size N/2 which +//! can be computed recursively. +//! +//! Of course, for small sizes we just explicitly write out the O(n^2) +//! approach. + +use core::ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign}; + +/// This trait collects the operations needed by `Convolve` below. +/// +/// TODO: Think of a better name for this. +pub trait RngElt: + Add + + AddAssign + + Copy + + Default + + Neg + + ShrAssign + + Sub + + SubAssign +{ +} + +impl RngElt for i64 {} +impl RngElt for i128 {} + +/// Template function to perform convolution of vectors. +/// +/// Roughly speaking, for a convolution of size `N`, it should be +/// possible to add `N` elements of type `T` without overflowing, and +/// similarly for `U`. Then multiplication via `Self::mul` should +/// produce an element of type `V` which will not overflow after about +/// `N` additions (this is an over-estimate). +/// +/// For example usage, see `{mersenne-31,baby-bear,goldilocks}/src/mds.rs`. +/// +/// NB: In practice, one of the parameters to the convolution will be +/// constant (the MDS matrix). After inspecting Godbolt output, it +/// seems that the compiler does indeed generate single constants as +/// inputs to the multiplication, rather than doing all that +/// arithmetic on the constant values every time. Note however that, +/// for MDS matrices with large entries (N >= 24), these compile-time +/// generated constants will be about N times bigger than they need to +/// be in principle, which could be a potential avenue for some minor +/// improvements. +/// +/// NB: If primitive multiplications are still the bottleneck, a +/// further possibility would be to find an MDS matrix some of whose +/// entries are powers of 2. Then the multiplication can be replaced +/// with a shift, which on most architectures has better throughput +/// and latency, and is issued on different ports (1*p06) to +/// multiplication (1*p1). +pub trait Convolve { + /// Given an input element, retrieve the corresponding internal + /// element that will be used in calculations. + fn read(input: F) -> T; + + /// Given input vectors `lhs` and `rhs`, calculate their dot + /// product. The result can be reduced with respect to the modulus + /// (of `F`), but it must have the same lower 10 bits as the dot + /// product if all inputs are considered integers. See + /// `baby-bear/src/mds.rs::barret_red_babybear()` for an example + /// of how this can be implemented in practice. + fn parity_dot(lhs: [T; N], rhs: [U; N]) -> V; + + /// Convert an internal element of type `V` back into an external + /// element. + fn reduce(z: V) -> F; + + /// Convolve `lhs` and `rhs`. + /// + /// The parameter `conv` should be the function in this trait that + /// corresponds to length `N`. + #[inline(always)] + fn apply( + lhs: [F; N], + rhs: [U; N], + conv: C, + ) -> [F; N] { + let lhs = lhs.map(Self::read); + let mut output = [V::default(); N]; + conv(lhs, rhs, &mut output); + output.map(Self::reduce) + } + + #[inline(always)] + fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) { + output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]); + output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]); + output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]); + } + + #[inline(always)] + fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) { + output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]); + output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]); + output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]); + } + + #[inline(always)] + fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) { + // NB: This is just explicitly implementing + // conv_n_recursive::<4, 2, _, _>(lhs, rhs, output, Self::conv2, Self::negacyclic_conv2) + let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]]; + let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]]; + let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]]; + let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]]; + + output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]); + output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]); + output[2] = Self::parity_dot(u_p, v_p); + output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]); + + output[0] += output[2]; + output[1] += output[3]; + + output[0] >>= 1; + output[1] >>= 1; + + output[2] -= output[0]; + output[3] -= output[1]; + } + + #[inline(always)] + fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) { + output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]); + output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]); + output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]); + output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]); + } + + #[inline(always)] + fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) { + conv_n_recursive::<6, 3, T, U, V, _, _>( + lhs, + rhs, + output, + Self::conv3, + Self::negacyclic_conv3, + ) + } + + #[inline(always)] + fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) { + negacyclic_conv_n_recursive::<6, 3, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv3) + } + + #[inline(always)] + fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) { + conv_n_recursive::<8, 4, T, U, V, _, _>( + lhs, + rhs, + output, + Self::conv4, + Self::negacyclic_conv4, + ) + } + + #[inline(always)] + fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) { + negacyclic_conv_n_recursive::<8, 4, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv4) + } + + #[inline(always)] + fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) { + conv_n_recursive::<12, 6, T, U, V, _, _>( + lhs, + rhs, + output, + Self::conv6, + Self::negacyclic_conv6, + ) + } + + #[inline(always)] + fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) { + negacyclic_conv_n_recursive::<12, 6, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv6) + } + + #[inline(always)] + fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) { + conv_n_recursive::<16, 8, T, U, V, _, _>( + lhs, + rhs, + output, + Self::conv8, + Self::negacyclic_conv8, + ) + } + + #[inline(always)] + fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) { + negacyclic_conv_n_recursive::<16, 8, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv8) + } + + #[inline(always)] + fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [V]) { + conv_n_recursive::<24, 12, T, U, V, _, _>( + lhs, + rhs, + output, + Self::conv12, + Self::negacyclic_conv12, + ) + } + + #[inline(always)] + fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) { + conv_n_recursive::<32, 16, T, U, V, _, _>( + lhs, + rhs, + output, + Self::conv16, + Self::negacyclic_conv16, + ) + } + + #[inline(always)] + fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) { + negacyclic_conv_n_recursive::<32, 16, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv16) + } + + #[inline(always)] + fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [V]) { + conv_n_recursive::<64, 32, T, U, V, _, _>( + lhs, + rhs, + output, + Self::conv32, + Self::negacyclic_conv32, + ) + } +} + +/// Compute output(x) = lhs(x)rhs(x) mod x^N - 1. +/// Do this recursively using a convolution and negacyclic convolution of size HALF_N = N/2. +#[inline(always)] +fn conv_n_recursive( + lhs: [T; N], + rhs: [U; N], + output: &mut [V], + inner_conv: C, + inner_negacyclic_conv: NC, +) where + T: RngElt, + U: RngElt, + V: RngElt, + C: Fn([T; HALF_N], [U; HALF_N], &mut [V]), + NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]), +{ + debug_assert_eq!(2 * HALF_N, N); + // NB: The compiler is smart enough not to initialise these arrays. + let mut lhs_pos = [T::default(); HALF_N]; // lhs_pos = lhs(x) mod x^{N/2} - 1 + let mut lhs_neg = [T::default(); HALF_N]; // lhs_neg = lhs(x) mod x^{N/2} + 1 + let mut rhs_pos = [U::default(); HALF_N]; // rhs_pos = rhs(x) mod x^{N/2} - 1 + let mut rhs_neg = [U::default(); HALF_N]; // rhs_neg = rhs(x) mod x^{N/2} + 1 + + for i in 0..HALF_N { + let s = lhs[i]; + let t = lhs[i + HALF_N]; + + lhs_pos[i] = s + t; + lhs_neg[i] = s - t; + + let s = rhs[i]; + let t = rhs[i + HALF_N]; + + rhs_pos[i] = s + t; + rhs_neg[i] = s - t; + } + + let (left, right) = output.split_at_mut(HALF_N); + + // left = w1 = lhs(x)rhs(x) mod x^{N/2} + 1 + inner_negacyclic_conv(lhs_neg, rhs_neg, left); + + // right = w0 = lhs(x)rhs(x) mod x^{N/2} - 1 + inner_conv(lhs_pos, rhs_pos, right); + + for i in 0..HALF_N { + left[i] += right[i]; // w_0 + w_1 + left[i] >>= 1; // (w_0 + w_1)/2 + right[i] -= left[i]; // (w_0 - w_1)/2 + } +} + +/// Compute output(x) = lhs(x)rhs(x) mod x^N + 1. +/// Do this recursively using three negacyclic convolutions of size HALF_N = N/2. +#[inline(always)] +fn negacyclic_conv_n_recursive( + lhs: [T; N], + rhs: [U; N], + output: &mut [V], + inner_negacyclic_conv: NC, +) where + T: RngElt, + U: RngElt, + V: RngElt, + NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]), +{ + debug_assert_eq!(2 * HALF_N, N); + // NB: The compiler is smart enough not to initialise these arrays. + let mut lhs_even = [T::default(); HALF_N]; + let mut lhs_odd = [T::default(); HALF_N]; + let mut lhs_sum = [T::default(); HALF_N]; + let mut rhs_even = [U::default(); HALF_N]; + let mut rhs_odd = [U::default(); HALF_N]; + let mut rhs_sum = [U::default(); HALF_N]; + + for i in 0..HALF_N { + let s = lhs[2 * i]; + let t = lhs[2 * i + 1]; + lhs_even[i] = s; + lhs_odd[i] = t; + lhs_sum[i] = s + t; + + let s = rhs[2 * i]; + let t = rhs[2 * i + 1]; + rhs_even[i] = s; + rhs_odd[i] = t; + rhs_sum[i] = s + t; + } + + let mut even_s_conv = [V::default(); HALF_N]; + let (left, right) = output.split_at_mut(HALF_N); + + // Recursively compute the size N/2 negacyclic convolutions of + // the even parts, odd parts, and sums. + inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv); + inner_negacyclic_conv(lhs_odd, rhs_odd, left); + inner_negacyclic_conv(lhs_sum, rhs_sum, right); + + // Adjust so that the correct values are in right and + // even_s_conv respectively: + right[0] -= even_s_conv[0] + left[0]; + even_s_conv[0] -= left[HALF_N - 1]; + + for i in 1..HALF_N { + right[i] -= even_s_conv[i] + left[i]; + even_s_conv[i] += left[i - 1]; + } + + // Interleave even_s_conv and right in the output: + for i in 0..HALF_N { + output[2 * i] = even_s_conv[i]; + output[2 * i + 1] = output[i + HALF_N]; + } +} diff --git a/mds/src/lib.rs b/mds/src/lib.rs index f38730a1..3e5ee260 100644 --- a/mds/src/lib.rs +++ b/mds/src/lib.rs @@ -6,13 +6,10 @@ extern crate alloc; use p3_symmetric::Permutation; -pub mod babybear; mod butterflies; pub mod coset_mds; -pub mod goldilocks; pub mod integrated_coset_mds; -pub mod m4; -pub mod mersenne31; +pub mod karatsuba_convolution; pub mod util; pub trait MdsPermutation: Permutation<[T; WIDTH]> {} diff --git a/mds/src/util.rs b/mds/src/util.rs index 0e951c6d..b4d13a64 100644 --- a/mds/src/util.rs +++ b/mds/src/util.rs @@ -1,15 +1,39 @@ use alloc::vec::Vec; use core::array; +use core::ops::{AddAssign, Mul}; use p3_dft::TwoAdicSubgroupDft; -use p3_field::{AbstractField, PrimeField64, TwoAdicField}; - -// NB: These four are MDS for M31, BabyBear and Goldilocks -//const MATRIX_CIRC_MDS_8_2EXP: [u64; 8] = [1, 1, 2, 1, 8, 32, 4, 256]; -const MATRIX_CIRC_MDS_8_SML: [u64; 8] = [4, 1, 2, 9, 10, 5, 1, 1]; - -//const MATRIX_CIRC_MDS_12_2EXP: [u64; 12] = [1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024]; -const MATRIX_CIRC_MDS_12_SML: [u64; 12] = [9, 7, 4, 1, 16, 2, 256, 128, 3, 32, 1, 1]; +use p3_field::{AbstractField, TwoAdicField}; + +// NB: These are all MDS for M31, BabyBear and Goldilocks +// const MATRIX_CIRC_MDS_8_2EXP: [u64; 8] = [1, 1, 2, 1, 8, 32, 4, 256]; +// const MATRIX_CIRC_MDS_8_SML: [u64; 8] = [4, 1, 2, 9, 10, 5, 1, 1]; +// Much smaller: [1, 1, -1, 2, 3, 8, 2, -3] but need to deal with the -ve's + +// const MATRIX_CIRC_MDS_12_2EXP: [u64; 12] = [1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024]; +// const MATRIX_CIRC_MDS_12_SML: [u64; 12] = [9, 7, 4, 1, 16, 2, 256, 128, 3, 32, 1, 1]; +// const MATRIX_CIRC_MDS_12_SML: [u64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10]; + +// Trying to maximise the # of 1's in the vector. +// Not clear exactly what we should be optimising here but that seems reasonable. +// const MATRIX_CIRC_MDS_16_SML: [u64; 16] = +// [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3]; +// 1, 1, 51, 52, 11, 63, 1, 2, 1, 2, 15, 67, 2, 22, 13, 3 +// [1, 1, 2, 1, 8, 32, 2, 65, 77, 8, 91, 31, 3, 65, 32, 7]; + +/// This will throw an error if N = 0 but it's hard to imagine this case coming up. +#[inline(always)] +pub fn dot_product(u: [T; N], v: [T; N]) -> T +where + T: Copy + AddAssign + Mul, +{ + debug_assert_ne!(N, 0); + let mut dp = u[0] * v[0]; + for i in 1..N { + dp += u[i] * v[i]; + } + dp +} /// Given the first row `circ_matrix` of an NxN circulant matrix, say /// C, return the product `C*input`. @@ -32,89 +56,6 @@ pub fn apply_circulant( output } -/// Given an array `input` and an `offset`, return the array whose -/// elements are those of `input` shifted `offset` places to the -/// right. -/// -/// NB: The algorithm is inefficient but simple enough that this -/// function can be declared `const`, and that is the intended use. In -/// non-`const` contexts you probably want `[T]::rotate_right()` from -/// the standard library. -pub(crate) const fn rotate_right(input: [u64; N], offset: usize) -> [u64; N] { - let mut output = [0u64; N]; - let mut i = 0; - loop { - if i == N { - break; - } - output[i] = input[(N - offset + i) % N]; - i += 1; - } - output -} - -/// As for `apply_circulant()` above, but with `circ_matrix` set to a -/// fixed 8x8 MDS matrix with small entries that satisfy the condition -/// on `PrimeField64::z_linear_combination_sml()`. -pub(crate) fn apply_circulant_8_sml(input: [F; 8]) -> [F; 8] { - const N: usize = 8; - let mut output = [F::zero(); N]; - - const MAT_0: [u64; N] = MATRIX_CIRC_MDS_8_SML; - output[0] = F::linear_combination_u64(MAT_0, &input); - const MAT_1: [u64; N] = rotate_right(MATRIX_CIRC_MDS_8_SML, 1); - output[1] = F::linear_combination_u64(MAT_1, &input); - const MAT_2: [u64; N] = rotate_right(MATRIX_CIRC_MDS_8_SML, 2); - output[2] = F::linear_combination_u64(MAT_2, &input); - const MAT_3: [u64; N] = rotate_right(MATRIX_CIRC_MDS_8_SML, 3); - output[3] = F::linear_combination_u64(MAT_3, &input); - const MAT_4: [u64; N] = rotate_right(MATRIX_CIRC_MDS_8_SML, 4); - output[4] = F::linear_combination_u64(MAT_4, &input); - const MAT_5: [u64; N] = rotate_right(MATRIX_CIRC_MDS_8_SML, 5); - output[5] = F::linear_combination_u64(MAT_5, &input); - const MAT_6: [u64; N] = rotate_right(MATRIX_CIRC_MDS_8_SML, 6); - output[6] = F::linear_combination_u64(MAT_6, &input); - const MAT_7: [u64; N] = rotate_right(MATRIX_CIRC_MDS_8_SML, 7); - output[7] = F::linear_combination_u64(MAT_7, &input); - - output -} - -/// As for `apply_circulant()` above, but with `circ_matrix` set to a -/// fixed 12x12 MDS matrix with small entries that satisfy the condition -/// on `PrimeField64::z_linear_combination_sml()`. -pub(crate) fn apply_circulant_12_sml(input: [F; 12]) -> [F; 12] { - const N: usize = 12; - let mut output = [F::zero(); N]; - - const MAT_0: [u64; N] = MATRIX_CIRC_MDS_12_SML; - output[0] = F::linear_combination_u64(MAT_0, &input); - const MAT_1: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 1); - output[1] = F::linear_combination_u64(MAT_1, &input); - const MAT_2: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 2); - output[2] = F::linear_combination_u64(MAT_2, &input); - const MAT_3: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 3); - output[3] = F::linear_combination_u64(MAT_3, &input); - const MAT_4: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 4); - output[4] = F::linear_combination_u64(MAT_4, &input); - const MAT_5: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 5); - output[5] = F::linear_combination_u64(MAT_5, &input); - const MAT_6: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 6); - output[6] = F::linear_combination_u64(MAT_6, &input); - const MAT_7: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 7); - output[7] = F::linear_combination_u64(MAT_7, &input); - const MAT_8: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 8); - output[8] = F::linear_combination_u64(MAT_8, &input); - const MAT_9: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 9); - output[9] = F::linear_combination_u64(MAT_9, &input); - const MAT_10: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 10); - output[10] = F::linear_combination_u64(MAT_10, &input); - const MAT_11: [u64; N] = rotate_right(MATRIX_CIRC_MDS_12_SML, 11); - output[11] = F::linear_combination_u64(MAT_11, &input); - - output -} - /// Given the first row of a circulant matrix, return the first column /// of that circulant matrix. For example, v = [0, 1, 2, 3, 4, 5], /// then output = [0, 5, 4, 3, 2, 1], i.e. the first element is the @@ -127,9 +68,9 @@ pub(crate) fn apply_circulant_12_sml(input: [F; 12]) -> [F; 12] /// NB: The algorithm is inefficient but simple enough that this /// function can be declared `const`, and that is the intended context /// for use. -pub(crate) const fn first_row_to_first_col(v: &[u64; N]) -> [u64; N] { - let mut output = [0u64; N]; - output[0] = v[0]; +pub const fn first_row_to_first_col(v: &[T; N]) -> [T; N] { + // Can do this to get a simple Default value. Might be better ways? + let mut output = [v[0]; N]; let mut i = 1; loop { if i >= N { @@ -146,7 +87,7 @@ pub(crate) const fn first_row_to_first_col(v: &[u64; N]) -> [u64 /// be specified by its first *column*, not its first row. If you have /// the row as an array, you can obtain the column with `first_row_to_first_col()`. #[inline] -pub(crate) fn apply_circulant_fft>( +pub fn apply_circulant_fft>( fft: FFT, column: [u64; N], input: &[F; N], @@ -168,21 +109,13 @@ pub(crate) fn apply_circulant_fft for SmallConvolveMersenne31 { + /// Return the lift of an (almost) reduced Mersenne31 element. + /// The Mersenne31 implementation guarantees that + /// 0 <= input.value <= P < 2^31. + #[inline(always)] + fn read(input: Mersenne31) -> i64 { + input.value as i64 + } + + /// FIXME: Refactor the dot product + /// For a convolution of size N, |x| < N * 2^31 and (as per the + /// assumption above), |y| < 2^24. So the product is at most N * 2^55 + /// which will not overflow for N <= 16. + #[inline(always)] + fn parity_dot(u: [i64; N], v: [i64; N]) -> i64 { + dot_product(u, v) + } + + /// The assumptions above mean z < N^2 * 2^55, which is at most + /// 2^63 when N <= 16. + /// + /// NB: Even though intermediate values could be negative, the + /// output must be non-negative since the inputs were + /// non-negative. + #[inline(always)] + fn reduce(z: i64) -> Mersenne31 { + debug_assert!(z >= 0); + Mersenne31::from_wrapped_u64(z as u64) + } +} + +/// Instantiate convolution for "large" RHS vectors over Mersenne31. +/// +/// Here "large" means the elements can be as big as the field +/// characteristic, and the size N of the RHS is <= 64. +struct LargeConvolveMersenne31; +impl Convolve for LargeConvolveMersenne31 { + /// Return the lift of an (almost) reduced Mersenne31 element. + /// The Mersenne31 implementation guarantees that + /// 0 <= input.value <= P < 2^31. + #[inline(always)] + fn read(input: Mersenne31) -> i64 { + input.value as i64 + } + + #[inline] + fn parity_dot(u: [i64; N], v: [i64; N]) -> i64 { + // For a convolution of size N, |x|, |y| < N * 2^31, so the product + // could be as much as N^2 * 2^62. This will overflow an i64, so + // we first widen to i128. + + let mut dp = 0i128; + for i in 0..N { + dp += u[i] as i128 * v[i] as i128; + } + + const LOWMASK: i128 = (1 << 42) - 1; // Gets the bits lower than 42. + const HIGHMASK: i128 = !(LOWMASK); // Gets all bits higher than 42. + + let low_bits = (dp & LOWMASK) as i64; // low_bits < 2**42 + let high_bits = ((dp & HIGHMASK) >> 31) as i64; // |high_bits| < 2**(n - 31) + + // Proof that low_bits + high_bits is congruent to dp (mod p) + // and congruent to dp (mod 2^11): + // + // The individual bounds clearly show that low_bits + + // high_bits < 2**(n - 30). + // + // Next observe that low_bits + high_bits = input - (2**31 - + // 1) * (high_bits) = input mod P. + // + // Finally note that 2**11 divides high_bits and so low_bits + + // high_bits = low_bits mod 2**11 = input mod 2**11. + + low_bits + high_bits + } + + #[inline] + fn reduce(z: i64) -> Mersenne31 { + // After the dot product, the maximal size is N^2 * 2^62 < 2^74 + // as N = 64 is the biggest size. So, after the partial + // reduction, the output z of parity dot satisfies |z| < 2^44 + // (Where 44 is 74 - 30). + // + // In the recombining steps, conv maps (wo, w1) -> ((wo + w1)/2, + // (wo + w1)/2) which has no effect on the maximal size. (Indeed, + // it makes sizes almost strictly smaller). + // + // On the other hand, negacyclic_conv (ignoring the re-index) + // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1). Hence + // if the input is <= K, the output is <= 3K. + // + // Thus the values appearing at the end are bounded by 3^n 2^44 + // where n is the maximal number of negacyclic_conv recombination + // steps. When N = 64, we need to recombine for singed_conv_32, + // singed_conv_16, singed_conv_8 so the overall bound will be 3^3 + // 2^44 < 32 * 2^44 < 2^49. + debug_assert!(z > -(1i64 << 49)); + debug_assert!(z < (1i64 << 49)); + + const MASK: i64 = (1 << 31) - 1; + // Morally, our value is a i62 not a i64 as the top 3 bits are + // guaranteed to be equal. + let low_bits = Mersenne31::from_canonical_u32((z & MASK) as u32); + let high_bits = ((z >> 31) & MASK) as i32; + let sign_bits = (z >> 62) as i32; + + // Note that high_bits + sign_bits > 0 as by assumption b[63] = b[61]. + let high = Mersenne31::from_canonical_u32((high_bits + sign_bits) as u32); + low_bits + high + } +} + +const MATRIX_CIRC_MDS_8_SML_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; + impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 { fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] { - apply_circulant_8_sml(input) + const MATRIX_CIRC_MDS_8_SML_COL: [i64; 8] = + first_row_to_first_col(&MATRIX_CIRC_MDS_8_SML_ROW); + SmallConvolveMersenne31::apply( + input, + MATRIX_CIRC_MDS_8_SML_COL, + SmallConvolveMersenne31::conv8, + ) } fn permute_mut(&self, input: &mut [Mersenne31; 8]) { @@ -24,9 +154,17 @@ impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 { } impl MdsPermutation for MdsMatrixMersenne31 {} +const MATRIX_CIRC_MDS_12_SML_ROW: [i64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10]; + impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 { fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] { - apply_circulant_12_sml(input) + const MATRIX_CIRC_MDS_12_SML_COL: [i64; 12] = + first_row_to_first_col(&MATRIX_CIRC_MDS_12_SML_ROW); + SmallConvolveMersenne31::apply( + input, + MATRIX_CIRC_MDS_12_SML_COL, + SmallConvolveMersenne31::conv12, + ) } fn permute_mut(&self, input: &mut [Mersenne31; 12]) { @@ -35,17 +173,18 @@ impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 { } impl MdsPermutation for MdsMatrixMersenne31 {} -#[rustfmt::skip] -const MATRIX_CIRC_MDS_16_MERSENNE31: [u64; 16] = [ - 0x327ACB92, 0x58C99138, 0x3AC486B5, 0x25123B13, - 0x2C74BDE9, 0x108BD51A, 0x4E911F9D, 0x19DD8E68, - 0x06227198, 0x516EE062, 0x0F742AE6, 0x738B4216, - 0x7AEDC4EC, 0x653B794A, 0x47366EC7, 0x6D85346D -]; +const MATRIX_CIRC_MDS_16_SML_ROW: [i64; 16] = + [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3]; impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 { fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] { - apply_circulant(&MATRIX_CIRC_MDS_16_MERSENNE31, input) + const MATRIX_CIRC_MDS_16_SML_COL: [i64; 16] = + first_row_to_first_col(&MATRIX_CIRC_MDS_16_SML_ROW); + SmallConvolveMersenne31::apply( + input, + MATRIX_CIRC_MDS_16_SML_COL, + SmallConvolveMersenne31::conv16, + ) } fn permute_mut(&self, input: &mut [Mersenne31; 16]) { @@ -55,7 +194,7 @@ impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 { impl MdsPermutation for MdsMatrixMersenne31 {} #[rustfmt::skip] -const MATRIX_CIRC_MDS_32_MERSENNE31: [u64; 32] = [ +const MATRIX_CIRC_MDS_32_MERSENNE31_ROW: [i64; 32] = [ 0x1896DC78, 0x559D1E29, 0x04EBD732, 0x3FF449D7, 0x2DB0E2CE, 0x26776B85, 0x76018E57, 0x1025FA13, 0x06486BAB, 0x37706EBA, 0x25EB966B, 0x113C24E5, @@ -68,7 +207,13 @@ const MATRIX_CIRC_MDS_32_MERSENNE31: [u64; 32] = [ impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 { fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] { - apply_circulant(&MATRIX_CIRC_MDS_32_MERSENNE31, input) + const MATRIX_CIRC_MDS_32_MERSENNE31_COL: [i64; 32] = + first_row_to_first_col(&MATRIX_CIRC_MDS_32_MERSENNE31_ROW); + LargeConvolveMersenne31::apply( + input, + MATRIX_CIRC_MDS_32_MERSENNE31_COL, + LargeConvolveMersenne31::conv32, + ) } fn permute_mut(&self, input: &mut [Mersenne31; 32]) { @@ -78,7 +223,7 @@ impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 { impl MdsPermutation for MdsMatrixMersenne31 {} #[rustfmt::skip] -const MATRIX_CIRC_MDS_64_MERSENNE31: [u64; 64] = [ +const MATRIX_CIRC_MDS_64_MERSENNE31_ROW: [i64; 64] = [ 0x570227A5, 0x3702983F, 0x4B7B3B0A, 0x74F13DE3, 0x485314B0, 0x0157E2EC, 0x1AD2E5DE, 0x721515E3, 0x5452ADA3, 0x0C74B6C1, 0x67DA9450, 0x33A48369, @@ -99,7 +244,13 @@ const MATRIX_CIRC_MDS_64_MERSENNE31: [u64; 64] = [ impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 { fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] { - apply_circulant(&MATRIX_CIRC_MDS_64_MERSENNE31, input) + const MATRIX_CIRC_MDS_64_MERSENNE31_COL: [i64; 64] = + first_row_to_first_col(&MATRIX_CIRC_MDS_64_MERSENNE31_ROW); + LargeConvolveMersenne31::apply( + input, + MATRIX_CIRC_MDS_64_MERSENNE31_COL, + LargeConvolveMersenne31::conv64, + ) } fn permute_mut(&self, input: &mut [Mersenne31; 64]) { @@ -111,10 +262,9 @@ impl MdsPermutation for MdsMatrixMersenne31 {} #[cfg(test)] mod tests { use p3_field::AbstractField; - use p3_mersenne_31::Mersenne31; use p3_symmetric::Permutation; - use super::MdsMatrixMersenne31; + use super::{MdsMatrixMersenne31, Mersenne31}; #[test] fn mersenne8() { @@ -127,8 +277,8 @@ mod tests { let output = MdsMatrixMersenne31.permute(input); let expected: [Mersenne31; 8] = [ - 1796260072, 48130602, 971886692, 1460399885, 745498940, 352898876, 223078564, - 2090539234, + 895992680, 1343855369, 2107796831, 266468728, 846686506, 252887121, 205223309, + 260248790, ] .map(Mersenne31::from_canonical_u64); @@ -146,8 +296,8 @@ mod tests { let output = MdsMatrixMersenne31.permute(input); let expected: [Mersenne31; 12] = [ - 492952161, 916402585, 1541871876, 799921480, 707671572, 1293088641, 866554196, - 1471029895, 35362849, 2107961577, 1616107486, 762379007, + 860812289, 399778981, 1228500858, 798196553, 673507779, 1116345060, 829764188, + 138346433, 578243475, 553581995, 578183208, 1527769050, ] .map(Mersenne31::from_canonical_u64); @@ -166,9 +316,9 @@ mod tests { let output = MdsMatrixMersenne31.permute(input); let expected: [Mersenne31; 16] = [ - 1929166367, 1352685756, 1090911983, 379953343, 62410403, 637712268, 1637633936, - 555902167, 850536312, 913896503, 2070446350, 814495093, 651934716, 419066839, - 603091570, 1453848863, + 1858869691, 1607793806, 1200396641, 1400502985, 1511630695, 187938132, 1332411488, + 2041577083, 2014246632, 802022141, 796807132, 1647212930, 813167618, 1867105010, + 508596277, 1457551581, ] .map(Mersenne31::from_canonical_u64); diff --git a/mersenne-31/src/mersenne_31.rs b/mersenne-31/src/mersenne_31.rs index febed03a..1c141958 100644 --- a/mersenne-31/src/mersenne_31.rs +++ b/mersenne-31/src/mersenne_31.rs @@ -5,13 +5,16 @@ use core::iter::{Product, Sum}; use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; use p3_field::{ - exp_1717986917, exp_u64_by_squaring, AbstractField, Field, Packable, PrimeField, PrimeField32, - PrimeField64, + exp_1717986917, exp_u64_by_squaring, halve_u32, AbstractField, Field, Packable, PrimeField, + PrimeField32, PrimeField64, }; use rand::distributions::{Distribution, Standard}; use rand::Rng; use serde::{Deserialize, Serialize}; +/// The Mersenne31 prime +const P: u32 = (1 << 31) - 1; + /// The prime field `F_p` where `p = 2^31 - 1`. #[derive(Copy, Clone, Default, Serialize, Deserialize)] pub struct Mersenne31 { @@ -168,7 +171,31 @@ impl AbstractField for Mersenne31 { impl Field for Mersenne31 { #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] type Packing = crate::PackedMersenne31Neon; - #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))] + #[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(feature = "nightly-features", target_feature = "avx512f")) + ))] + type Packing = crate::PackedMersenne31AVX2; + #[cfg(all( + feature = "nightly-features", + target_arch = "x86_64", + target_feature = "avx512f" + ))] + type Packing = crate::PackedMersenne31AVX512; + #[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all( + target_arch = "x86_64", + target_feature = "avx2", + not(all(feature = "nightly-features", target_feature = "avx512f")) + ), + all( + feature = "nightly-features", + target_arch = "x86_64", + target_feature = "avx512f" + ), + )))] type Packing = Self; #[inline] @@ -225,12 +252,17 @@ impl Field for Mersenne31 { p1111111111111111111111111111.exp_power_of_2(3) * p101; Some(p1111111111111111111111111111101) } + + #[inline] + fn halve(&self) -> Self { + Mersenne31::new(halve_u32::

(self.value)) + } } impl PrimeField for Mersenne31 {} impl PrimeField32 for Mersenne31 { - const ORDER_U32: u32 = (1 << 31) - 1; + const ORDER_U32: u32 = P; #[inline] fn as_canonical_u32(&self) -> u32 { @@ -251,18 +283,6 @@ impl PrimeField64 for Mersenne31 { fn as_canonical_u64(&self) -> u64 { u64::from(self.as_canonical_u32()) } - - #[inline] - fn linear_combination_u64(u: [u64; N], v: &[Self; N]) -> Self { - // In order not to overflow a u64, we must have sum(u) <= 2^32. - debug_assert!(u.iter().sum::() <= (1u64 << 32)); - - let mut dot = u[0] * v[0].value as u64; - for i in 1..N { - dot += u[i] * v[i].value as u64; - } - Self::from_wrapped_u64(dot) - } } impl Add for Mersenne31 { diff --git a/mersenne-31/src/x86_64_avx2.rs b/mersenne-31/src/x86_64_avx2.rs index 4151002c..4d855c84 100644 --- a/mersenne-31/src/x86_64_avx2.rs +++ b/mersenne-31/src/x86_64_avx2.rs @@ -643,7 +643,7 @@ unsafe impl PackedField for PackedMersenne31AVX2 { #[cfg(test)] mod tests { - use rand::{Rng, SeedableRng}; + use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use super::*; diff --git a/mersenne-31/src/x86_64_avx512.rs b/mersenne-31/src/x86_64_avx512.rs new file mode 100644 index 00000000..59b954ce --- /dev/null +++ b/mersenne-31/src/x86_64_avx512.rs @@ -0,0 +1,1308 @@ +use core::arch::x86_64::{self, __m512i, __mmask16, __mmask8}; +use core::iter::{Product, Sum}; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; + +use p3_field::{AbstractField, Field, PackedField, PackedValue}; + +use crate::Mersenne31; + +const WIDTH: usize = 16; +const P: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0x7fffffff; WIDTH]) }; +const EVENS: __mmask16 = 0b0101010101010101; +const ODDS: __mmask16 = 0b1010101010101010; +const EVENS4: __mmask16 = 0x0f0f; + +/// Vectorized AVX-512F implementation of `Mersenne31` arithmetic. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(transparent)] // This needed to make `transmute`s safe. +pub struct PackedMersenne31AVX512(pub [Mersenne31; WIDTH]); + +impl PackedMersenne31AVX512 { + #[inline] + #[must_use] + /// Get an arch-specific vector representing the packed values. + fn to_vector(self) -> __m512i { + unsafe { + // Safety: `Mersenne31` is `repr(transparent)` so it can be transmuted to `u32`. It + // follows that `[Mersenne31; WIDTH]` can be transmuted to `[u32; WIDTH]`, which can be + // transmuted to `__m512i`, since arrays are guaranteed to be contiguous in memory. + // Finally `PackedMersenne31AVX512` is `repr(transparent)` so it can be transmuted to + // `[Mersenne31; WIDTH]`. + transmute(self) + } + } + + #[inline] + #[must_use] + /// Make a packed field vector from an arch-specific vector. + /// + /// SAFETY: The caller must ensure that each element of `vector` represents a valid + /// `Mersenne31`. In particular, each element of vector must be in `0..=P`. + unsafe fn from_vector(vector: __m512i) -> Self { + // Safety: It is up to the user to ensure that elements of `vector` represent valid + // `Mersenne31` values. We must only reason about memory representations. `__m512i` can be + // transmuted to `[u32; WIDTH]` (since arrays elements are contiguous in memory), which can + // be transmuted to `[Mersenne31; WIDTH]` (since `Mersenne31` is `repr(transparent)`), which + // in turn can be transmuted to `PackedMersenne31AVX512` (since `PackedMersenne31AVX512` is also + // `repr(transparent)`). + transmute(vector) + } + + /// Copy `value` to all positions in a packed vector. This is the same as + /// `From::from`, but `const`. + #[inline] + #[must_use] + const fn broadcast(value: Mersenne31) -> Self { + Self([value; WIDTH]) + } +} + +impl Add for PackedMersenne31AVX512 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + let lhs = self.to_vector(); + let rhs = rhs.to_vector(); + let res = add(lhs, rhs); + unsafe { + // Safety: `add` returns values in canonical form when given values in canonical form. + Self::from_vector(res) + } + } +} + +impl Mul for PackedMersenne31AVX512 { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + let lhs = self.to_vector(); + let rhs = rhs.to_vector(); + let res = mul(lhs, rhs); + unsafe { + // Safety: `mul` returns values in canonical form when given values in canonical form. + Self::from_vector(res) + } + } +} + +impl Neg for PackedMersenne31AVX512 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + let val = self.to_vector(); + let res = neg(val); + unsafe { + // Safety: `neg` returns values in canonical form when given values in canonical form. + Self::from_vector(res) + } + } +} + +impl Sub for PackedMersenne31AVX512 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + let lhs = self.to_vector(); + let rhs = rhs.to_vector(); + let res = sub(lhs, rhs); + unsafe { + // Safety: `sub` returns values in canonical form when given values in canonical form. + Self::from_vector(res) + } + } +} + +/// Add two vectors of Mersenne-31 field elements represented as values in {0, ..., P}. +/// If the inputs do not conform to this representation, the result is undefined. +#[inline] +#[must_use] +fn add(lhs: __m512i, rhs: __m512i) -> __m512i { + // We want this to compile to: + // vpaddd t, lhs, rhs + // vpsubd u, t, P + // vpminud res, t, u + // throughput: 1.5 cyc/vec (10.67 els/cyc) + // latency: 3 cyc + + // Let t := lhs + rhs. We want to return a value r in {0, ..., P} such that r = t (mod P). + // Define u := (t - P) mod 2^32 and r := min(t, u). t is in {0, ..., 2 P}. We argue by cases. + // If t is in {0, ..., P - 1}, then u is in {(P - 1 <) 2^32 - P, ..., 2^32 - 1}, so r = t is + // in the correct range. + // If t is in {P, ..., 2 P}, then u is in {0, ..., P} and r = u is in the correct range. + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + let t = x86_64::_mm512_add_epi32(lhs, rhs); + let u = x86_64::_mm512_sub_epi32(t, P); + x86_64::_mm512_min_epu32(t, u) + } +} + +#[inline] +#[must_use] +fn movehdup_epi32(a: __m512i) -> __m512i { + // The instruction is only available in the floating-point flavor; this distinction is only for + // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. + unsafe { + x86_64::_mm512_castps_si512(x86_64::_mm512_movehdup_ps(x86_64::_mm512_castsi512_ps(a))) + } +} + +#[inline] +#[must_use] +fn mask_movehdup_epi32(src: __m512i, k: __mmask16, a: __m512i) -> __m512i { + // The instruction is only available in the floating-point flavor; this distinction is only for + // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. + unsafe { + let src = x86_64::_mm512_castsi512_ps(src); + let a = x86_64::_mm512_castsi512_ps(a); + x86_64::_mm512_castps_si512(x86_64::_mm512_mask_movehdup_ps(src, k, a)) + } +} + +#[inline] +#[must_use] +fn mask_moveldup_epi32(src: __m512i, k: __mmask16, a: __m512i) -> __m512i { + // The instruction is only available in the floating-point flavor; this distinction is only for + // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. + unsafe { + let src = x86_64::_mm512_castsi512_ps(src); + let a = x86_64::_mm512_castsi512_ps(a); + x86_64::_mm512_castps_si512(x86_64::_mm512_mask_moveldup_ps(src, k, a)) + } +} + +/// Multiply vectors of Mersenne-31 field elements represented as values in {0, ..., P}. +/// If the inputs do not conform to this representation, the result is undefined. +#[inline] +#[must_use] +fn mul(lhs: __m512i, rhs: __m512i) -> __m512i { + // We want this to compile to: + // vpaddd lhs_evn_dbl, lhs, lhs + // vmovshdup rhs_odd, rhs + // vpsrlq lhs_odd_dbl, lhs, 31 + // vpmuludq prod_lo_dbl, lhs_evn_dbl, rhs + // vpmuludq prod_odd_dbl, lhs_odd_dbl, rhs_odd + // vmovdqa32 prod_hi, prod_odd_dbl + // vmovshdup prod_hi{EVENS}, prod_lo_dbl + // vmovsldup prod_lo_dbl{ODDS}, prod_odd_dbl + // vpsrld prod_lo, prod_lo_dbl, 1 + // vpaddd t, prod_lo, prod_hi + // vpsubd u, t, P + // vpminud res, t, u + // throughput: 5.5 cyc/vec (2.91 els/cyc) + // latency: (lhs->res) 15 cyc, (rhs->res) 14 cyc + unsafe { + // vpmuludq only reads the bottom 32 bits of every 64-bit quadword. + // The even indices are already in the bottom 32 bits of a quadword, so we can leave them. + let rhs_evn = rhs; + // Again, vpmuludq only reads the bottom 32 bits so we don't need to clear the top. But we + // do want to double the lhs. + let lhs_evn_dbl = x86_64::_mm512_add_epi32(lhs, lhs); + // Copy the high 32 bits in each quadword of rhs down to the low 32. + let rhs_odd = movehdup_epi32(rhs); + // Right shift by 31 is equivalent to moving the high 32 bits down to the low 32, and then + // doubling it. So these are the odd indices in lhs, but doubled. + let lhs_odd_dbl = x86_64::_mm512_srli_epi64::<31>(lhs); + + // Multiply odd indices; since lhs_odd_dbl is doubled, these products are also doubled. + // prod_odd_dbl.quadword[i] = 2 * lhs.doubleword[2 * i + 1] * rhs.doubleword[2 * i + 1] + let prod_odd_dbl = x86_64::_mm512_mul_epu32(lhs_odd_dbl, rhs_odd); + // Multiply even indices; these are also doubled. + // prod_evn_dbl.quadword[i] = 2 * lhs.doubleword[2 * i] * rhs.doubleword[2 * i] + let prod_evn_dbl = x86_64::_mm512_mul_epu32(lhs_evn_dbl, rhs_evn); + + // Move the low halves of odd products into odd positions; keep the low halves of even + // products in even positions (where they already are). Note that the products are doubled, + // so the result is a vector of all the low halves, but doubled. + let prod_lo_dbl = mask_moveldup_epi32(prod_evn_dbl, ODDS, prod_odd_dbl); + // Move the high halves of even products into even positions, keeping the high halves of odd + // products where they are. The products are doubled, but we are looking at (prod >> 32), + // which cancels out the doubling, so this result is _not_ doubled. + let prod_hi = mask_movehdup_epi32(prod_odd_dbl, EVENS, prod_evn_dbl); + // Right shift to undo the doubling. + let prod_lo = x86_64::_mm512_srli_epi32::<1>(prod_lo_dbl); + + // Standard addition of two 31-bit values. + add(prod_lo, prod_hi) + } +} + +/// Negate a vector of Mersenne-31 field elements represented as values in {0, ..., P}. +/// If the input does not conform to this representation, the result is undefined. +#[inline] +#[must_use] +fn neg(val: __m512i) -> __m512i { + // We want this to compile to: + // vpxord res, val, P + // throughput: .5 cyc/vec (32 els/cyc) + // latency: 1 cyc + + // Since val is in {0, ..., P (= 2^31 - 1)}, res = val XOR P = P - val. Then res is in {0, + // ..., P}. + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + x86_64::_mm512_xor_epi32(val, P) + } +} + +/// Subtract vectors of Mersenne-31 field elements represented as values in {0, ..., P}. +/// If the inputs do not conform to this representation, the result is undefined. +#[inline] +#[must_use] +fn sub(lhs: __m512i, rhs: __m512i) -> __m512i { + // We want this to compile to: + // vpsubd t, lhs, rhs + // vpaddd u, t, P + // vpminud res, t, u + // throughput: 1.5 cyc/vec (10.67 els/cyc) + // latency: 3 cyc + + // Let d := lhs - rhs and t := d mod 2^32. We want to return a value r in {0, ..., P} such + // that r = d (mod P). + // Define u := (t + P) mod 2^32 and r := min(t, u). d is in {-P, ..., P}. We argue by cases. + // If d is in {0, ..., P}, then t = d and u is in {P, ..., 2 P}. r = t is in the correct + // range. + // If d is in {-P, ..., -1}, then t is in {2^32 - P, ..., 2^32 - 1} and u is in + // {0, ..., P - 1}. r = u is in the correct range. + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + let t = x86_64::_mm512_sub_epi32(lhs, rhs); + let u = x86_64::_mm512_add_epi32(t, P); + x86_64::_mm512_min_epu32(t, u) + } +} + +impl From for PackedMersenne31AVX512 { + #[inline] + fn from(value: Mersenne31) -> Self { + Self::broadcast(value) + } +} + +impl Default for PackedMersenne31AVX512 { + #[inline] + fn default() -> Self { + Mersenne31::default().into() + } +} + +impl AddAssign for PackedMersenne31AVX512 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl MulAssign for PackedMersenne31AVX512 { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl SubAssign for PackedMersenne31AVX512 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl Sum for PackedMersenne31AVX512 { + #[inline] + fn sum(iter: I) -> Self + where + I: Iterator, + { + iter.reduce(|lhs, rhs| lhs + rhs).unwrap_or(Self::zero()) + } +} + +impl Product for PackedMersenne31AVX512 { + #[inline] + fn product(iter: I) -> Self + where + I: Iterator, + { + iter.reduce(|lhs, rhs| lhs * rhs).unwrap_or(Self::one()) + } +} + +impl AbstractField for PackedMersenne31AVX512 { + type F = Mersenne31; + + #[inline] + fn zero() -> Self { + Mersenne31::zero().into() + } + + #[inline] + fn one() -> Self { + Mersenne31::one().into() + } + + #[inline] + fn two() -> Self { + Mersenne31::two().into() + } + + #[inline] + fn neg_one() -> Self { + Mersenne31::neg_one().into() + } + + #[inline] + fn from_f(f: Self::F) -> Self { + f.into() + } + + #[inline] + fn from_bool(b: bool) -> Self { + Mersenne31::from_bool(b).into() + } + #[inline] + fn from_canonical_u8(n: u8) -> Self { + Mersenne31::from_canonical_u8(n).into() + } + #[inline] + fn from_canonical_u16(n: u16) -> Self { + Mersenne31::from_canonical_u16(n).into() + } + #[inline] + fn from_canonical_u32(n: u32) -> Self { + Mersenne31::from_canonical_u32(n).into() + } + #[inline] + fn from_canonical_u64(n: u64) -> Self { + Mersenne31::from_canonical_u64(n).into() + } + #[inline] + fn from_canonical_usize(n: usize) -> Self { + Mersenne31::from_canonical_usize(n).into() + } + + #[inline] + fn from_wrapped_u32(n: u32) -> Self { + Mersenne31::from_wrapped_u32(n).into() + } + #[inline] + fn from_wrapped_u64(n: u64) -> Self { + Mersenne31::from_wrapped_u64(n).into() + } + + #[inline] + fn generator() -> Self { + Mersenne31::generator().into() + } +} + +impl Add for PackedMersenne31AVX512 { + type Output = Self; + #[inline] + fn add(self, rhs: Mersenne31) -> Self { + self + Self::from(rhs) + } +} + +impl Mul for PackedMersenne31AVX512 { + type Output = Self; + #[inline] + fn mul(self, rhs: Mersenne31) -> Self { + self * Self::from(rhs) + } +} + +impl Sub for PackedMersenne31AVX512 { + type Output = Self; + #[inline] + fn sub(self, rhs: Mersenne31) -> Self { + self - Self::from(rhs) + } +} + +impl AddAssign for PackedMersenne31AVX512 { + #[inline] + fn add_assign(&mut self, rhs: Mersenne31) { + *self += Self::from(rhs) + } +} + +impl MulAssign for PackedMersenne31AVX512 { + #[inline] + fn mul_assign(&mut self, rhs: Mersenne31) { + *self *= Self::from(rhs) + } +} + +impl SubAssign for PackedMersenne31AVX512 { + #[inline] + fn sub_assign(&mut self, rhs: Mersenne31) { + *self -= Self::from(rhs) + } +} + +impl Sum for PackedMersenne31AVX512 { + #[inline] + fn sum(iter: I) -> Self + where + I: Iterator, + { + iter.sum::().into() + } +} + +impl Product for PackedMersenne31AVX512 { + #[inline] + fn product(iter: I) -> Self + where + I: Iterator, + { + iter.product::().into() + } +} + +impl Div for PackedMersenne31AVX512 { + type Output = Self; + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline] + fn div(self, rhs: Mersenne31) -> Self { + self * rhs.inverse() + } +} + +impl Add for Mersenne31 { + type Output = PackedMersenne31AVX512; + #[inline] + fn add(self, rhs: PackedMersenne31AVX512) -> PackedMersenne31AVX512 { + PackedMersenne31AVX512::from(self) + rhs + } +} + +impl Mul for Mersenne31 { + type Output = PackedMersenne31AVX512; + #[inline] + fn mul(self, rhs: PackedMersenne31AVX512) -> PackedMersenne31AVX512 { + PackedMersenne31AVX512::from(self) * rhs + } +} + +impl Sub for Mersenne31 { + type Output = PackedMersenne31AVX512; + #[inline] + fn sub(self, rhs: PackedMersenne31AVX512) -> PackedMersenne31AVX512 { + PackedMersenne31AVX512::from(self) - rhs + } +} + +// vpshrdq requires AVX-512VBMI2. +#[cfg(target_feature = "avx512vbmi2")] +#[inline] +#[must_use] +fn interleave1_antidiagonal(x: __m512i, y: __m512i) -> __m512i { + unsafe { + // Safety: If this code got compiled then AVX-512VBMI2 intrinsics are available. + x86_64::_mm512_shrdi_epi64::<32>(y, x) + } +} + +// If we can't use vpshrdq, then do a vpermi2d, but we waste a register and double the latency. +#[cfg(not(target_feature = "avx512vbmi2"))] +#[inline] +#[must_use] +fn interleave1_antidiagonal(x: __m512i, y: __m512i) -> __m512i { + const INTERLEAVE1_INDICES: __m512i = unsafe { + // Safety: `[u32; 16]` is trivially transmutable to `__m512i`. + transmute::<[u32; WIDTH], _>([ + 0x01, 0x10, 0x03, 0x12, 0x05, 0x14, 0x07, 0x16, 0x09, 0x18, 0x0b, 0x1a, 0x0d, 0x1c, + 0x0f, 0x1e, + ]) + }; + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + x86_64::_mm512_permutex2var_epi32(x, INTERLEAVE1_INDICES, y) + } +} + +#[inline] +#[must_use] +fn interleave1(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // If we have AVX-512VBMI2, we want this to compile to: + // vpshrdq t, x, y, 32 + // vpblendmd res0 {EVENS}, t, x + // vpblendmd res1 {EVENS}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 2 cyc + // + // Otherwise, we want it to compile to: + // vmovdqa32 t, INTERLEAVE1_INDICES + // vpermi2d t, x, y + // vpblendmd res0 {EVENS}, t, x + // vpblendmd res1 {EVENS}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 4 cyc + + // We currently have: + // x = [ x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xa xb xc xd xe xf ], + // y = [ y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 ya yb yc yd ye yf ]. + // First form + // t = [ x1 y0 x3 y2 x5 y4 x7 y6 x9 y8 xb ya xd yc xf ye ]. + let t = interleave1_antidiagonal(x, y); + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + + // Then + // res0 = [ x0 y0 x2 y2 x4 y4 x6 y6 x8 y8 xa ya xc yc xe ye ], + // res1 = [ x1 y1 x3 y3 x5 y5 x7 y7 x9 y9 xb yb xd yd xf yf ]. + ( + x86_64::_mm512_mask_blend_epi32(EVENS, t, x), + x86_64::_mm512_mask_blend_epi32(EVENS, y, t), + ) + } +} + +#[inline] +#[must_use] +fn shuffle_epi64(a: __m512i, b: __m512i) -> __m512i { + // The instruction is only available in the floating-point flavor; this distinction is only for + // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. + unsafe { + let a = x86_64::_mm512_castsi512_pd(a); + let b = x86_64::_mm512_castsi512_pd(b); + x86_64::_mm512_castpd_si512(x86_64::_mm512_shuffle_pd::(a, b)) + } +} + +#[inline] +#[must_use] +fn interleave2(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // We want this to compile to: + // vshufpd t, x, y, 55h + // vpblendmq res0 {EVENS}, t, x + // vpblendmq res1 {EVENS}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 2 cyc + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + + // We currently have: + // x = [ x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xa xb xc xd xe xf ], + // y = [ y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 ya yb yc yd ye yf ]. + // First form + // t = [ x2 x3 y0 y1 x6 x7 y4 y5 xa xb y8 y9 xe xf yc yd ]. + let t = shuffle_epi64::<0b01010101>(x, y); + + // Then + // res0 = [ x0 x1 y0 y1 x4 x5 y4 y5 x8 x9 y8 y9 xc xd yc yd ], + // res1 = [ x2 x3 y2 y3 x6 x7 y6 y7 xa xb ya yb xe xf ye yf ]. + ( + x86_64::_mm512_mask_blend_epi64(EVENS as __mmask8, t, x), + x86_64::_mm512_mask_blend_epi64(EVENS as __mmask8, y, t), + ) + } +} + +#[inline] +#[must_use] +fn interleave4(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // We want this to compile to: + // vmovdqa64 t, INTERLEAVE4_INDICES + // vpermi2q t, x, y + // vpblendmd res0 {EVENS4}, t, x + // vpblendmd res1 {EVENS4}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 4 cyc + + const INTERLEAVE4_INDICES: __m512i = unsafe { + // Safety: `[u64; 8]` is trivially transmutable to `__m512i`. + transmute::<[u64; WIDTH / 2], _>([0o02, 0o03, 0o10, 0o11, 0o06, 0o07, 0o14, 0o15]) + }; + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + + // We currently have: + // x = [ x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xa xb xc xd xe xf ], + // y = [ y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 ya yb yc yd ye yf ]. + // First form + // t = [ x4 x5 x6 x7 y0 y1 y2 y3 xc xd xe xf y8 y9 ya yb ]. + let t = x86_64::_mm512_permutex2var_epi64(x, INTERLEAVE4_INDICES, y); + + // Then + // res0 = [ x0 x1 x2 x3 y0 y1 y2 y3 x8 x9 xa xb y8 y9 ya yb ], + // res1 = [ x4 x5 x6 x7 y4 y5 y6 y7 xc xd xe xf yc yd ye yf ]. + ( + x86_64::_mm512_mask_blend_epi32(EVENS4, t, x), + x86_64::_mm512_mask_blend_epi32(EVENS4, y, t), + ) + } +} + +#[inline] +#[must_use] +fn interleave8(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // We want this to compile to: + // vshufi64x2 t, x, b, 4eh + // vpblendmq res0 {EVENS4}, t, x + // vpblendmq res1 {EVENS4}, y, t + // throughput: 1.5 cyc/2 vec (21.33 els/cyc) + // latency: 4 cyc + + unsafe { + // Safety: If this code got compiled then AVX-512F intrinsics are available. + + // We currently have: + // x = [ x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xa xb xc xd xe xf ], + // y = [ y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 ya yb yc yd ye yf ]. + // First form + // t = [ x8 x9 xa xb xc xd xe xf y0 y1 y2 y3 y4 y5 y6 y7 ]. + let t = x86_64::_mm512_shuffle_i64x2::<0b01_00_11_10>(x, y); + + // Then + // res0 = [ x0 x1 x2 x3 x4 x5 x6 x7 y0 y1 y2 y3 y4 y5 y6 y7 ], + // res1 = [ x8 x9 xa xb xc xd xe xf y8 y9 ya yb yc yd ye yf ]. + ( + x86_64::_mm512_mask_blend_epi64(EVENS4 as __mmask8, t, x), + x86_64::_mm512_mask_blend_epi64(EVENS4 as __mmask8, y, t), + ) + } +} + +unsafe impl PackedValue for PackedMersenne31AVX512 { + type Value = Mersenne31; + + const WIDTH: usize = WIDTH; + + #[inline] + fn from_slice(slice: &[Mersenne31]) -> &Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { + // Safety: `[Mersenne31; WIDTH]` can be transmuted to `PackedMersenne31AVX512` since the + // latter is `repr(transparent)`. They have the same alignment, so the reference cast is + // safe too. + &*slice.as_ptr().cast() + } + } + #[inline] + fn from_slice_mut(slice: &mut [Mersenne31]) -> &mut Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { + // Safety: `[Mersenne31; WIDTH]` can be transmuted to `PackedMersenne31AVX512` since the + // latter is `repr(transparent)`. They have the same alignment, so the reference cast is + // safe too. + &mut *slice.as_mut_ptr().cast() + } + } + + /// Similar to `core:array::from_fn`. + #[inline] + fn from_fn Mersenne31>(f: F) -> Self { + let vals_arr: [_; WIDTH] = core::array::from_fn(f); + Self(vals_arr) + } + + #[inline] + fn as_slice(&self) -> &[Mersenne31] { + &self.0[..] + } + #[inline] + fn as_slice_mut(&mut self) -> &mut [Mersenne31] { + &mut self.0[..] + } +} + +unsafe impl PackedField for PackedMersenne31AVX512 { + type Scalar = Mersenne31; + + #[inline] + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { + let (v0, v1) = (self.to_vector(), other.to_vector()); + let (res0, res1) = match block_len { + 1 => interleave1(v0, v1), + 2 => interleave2(v0, v1), + 4 => interleave4(v0, v1), + 8 => interleave8(v0, v1), + 16 => (v0, v1), + _ => panic!("unsupported block_len"), + }; + unsafe { + // Safety: all values are in canonical form (we haven't changed them). + (Self::from_vector(res0), Self::from_vector(res1)) + } + } +} + +#[cfg(test)] +mod tests { + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha20Rng; + + use super::*; + + type F = Mersenne31; + type P = PackedMersenne31AVX512; + + const fn array_from_valid_reps(vals: [u32; WIDTH]) -> [F; WIDTH] { + let mut res = [Mersenne31 { value: 0 }; WIDTH]; + let mut i = 0; + while i < WIDTH { + res[i] = Mersenne31 { value: vals[i] }; + i += 1; + } + res + } + + const fn packed_from_valid_reps(vals: [u32; WIDTH]) -> P { + PackedMersenne31AVX512(array_from_valid_reps(vals)) + } + + fn array_from_random(seed: u64) -> [F; WIDTH] { + let mut rng = ChaCha20Rng::seed_from_u64(seed); + [(); WIDTH].map(|_| rng.gen()) + } + + fn packed_from_random(seed: u64) -> P { + PackedMersenne31AVX512(array_from_random(seed)) + } + + /// Zero has a redundant representation, so let's test both. + const BOTH_ZEROS: P = packed_from_valid_reps([ + 0x00000000, 0x7fffffff, 0x00000000, 0x7fffffff, 0x00000000, 0x7fffffff, 0x00000000, + 0x7fffffff, 0x00000000, 0x7fffffff, 0x00000000, 0x7fffffff, 0x00000000, 0x7fffffff, + 0x00000000, 0x7fffffff, + ]); + + const SPECIAL_VALS: [F; WIDTH] = array_from_valid_reps([ + 0x00000000, 0x7fffffff, 0x00000001, 0x7ffffffe, 0x00000002, 0x7ffffffd, 0x40000000, + 0x3fffffff, 0x00000000, 0x7fffffff, 0x00000001, 0x7ffffffe, 0x00000002, 0x7ffffffd, + 0x40000000, 0x3fffffff, + ]); + + #[test] + fn test_interleave_1() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let expected0 = packed_from_valid_reps([ + 0x00, 0x10, 0x02, 0x12, 0x04, 0x14, 0x06, 0x16, 0x08, 0x18, 0x0a, 0x1a, 0x0c, 0x1c, + 0x0e, 0x1e, + ]); + let expected1 = packed_from_valid_reps([ + 0x01, 0x11, 0x03, 0x13, 0x05, 0x15, 0x07, 0x17, 0x09, 0x19, 0x0b, 0x1b, 0x0d, 0x1d, + 0x0f, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 1); + assert_eq!(res0, expected0); + assert_eq!(res1, expected1); + } + + #[test] + fn test_interleave_2() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let expected0 = packed_from_valid_reps([ + 0x00, 0x01, 0x10, 0x11, 0x04, 0x05, 0x14, 0x15, 0x08, 0x09, 0x18, 0x19, 0x0c, 0x0d, + 0x1c, 0x1d, + ]); + let expected1 = packed_from_valid_reps([ + 0x02, 0x03, 0x12, 0x13, 0x06, 0x07, 0x16, 0x17, 0x0a, 0x0b, 0x1a, 0x1b, 0x0e, 0x0f, + 0x1e, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 2); + assert_eq!(res0, expected0); + assert_eq!(res1, expected1); + } + + #[test] + fn test_interleave_4() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let expected0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13, 0x08, 0x09, 0x0a, 0x0b, 0x18, 0x19, + 0x1a, 0x1b, + ]); + let expected1 = packed_from_valid_reps([ + 0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17, 0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 4); + assert_eq!(res0, expected0); + assert_eq!(res1, expected1); + } + + #[test] + fn test_interleave_8() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let expected0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, + ]); + let expected1 = packed_from_valid_reps([ + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 8); + assert_eq!(res0, expected0); + assert_eq!(res1, expected1); + } + + #[test] + fn test_interleave_16() { + let vec0 = packed_from_valid_reps([ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, + ]); + let vec1 = packed_from_valid_reps([ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, + ]); + + let (res0, res1) = vec0.interleave(vec1, 16); + assert_eq!(res0, vec0); + assert_eq!(res1, vec1); + } + + #[test] + fn test_add_associative() { + let vec0 = packed_from_random(0x8b078c2b693c893f); + let vec1 = packed_from_random(0x4ff5dec04791e481); + let vec2 = packed_from_random(0x5806c495e9451f8e); + + let res0 = (vec0 + vec1) + vec2; + let res1 = vec0 + (vec1 + vec2); + + assert_eq!(res0, res1); + } + + #[test] + fn test_add_commutative() { + let vec0 = packed_from_random(0xe1bf9cac02e9072a); + let vec1 = packed_from_random(0xb5061e7de6a6c677); + + let res0 = vec0 + vec1; + let res1 = vec1 + vec0; + + assert_eq!(res0, res1); + } + + #[test] + fn test_additive_identity_right() { + let vec = packed_from_random(0xbcd56facf6a714b5); + let res = vec + BOTH_ZEROS; + assert_eq!(res, vec); + } + + #[test] + fn test_additive_identity_left() { + let vec = packed_from_random(0xb614285cd641233c); + let res = BOTH_ZEROS + vec; + assert_eq!(res, vec); + } + + #[test] + fn test_additive_inverse_add_neg() { + let vec = packed_from_random(0x4b89c8d023c9c62e); + let neg_vec = -vec; + let res = vec + neg_vec; + assert_eq!(res, P::zero()); + } + + #[test] + fn test_additive_inverse_sub() { + let vec = packed_from_random(0x2c94652ee5561341); + let res = vec - vec; + assert_eq!(res, P::zero()); + } + + #[test] + fn test_sub_anticommutative() { + let vec0 = packed_from_random(0xf3783730a14b460e); + let vec1 = packed_from_random(0x5b6f827a023525ee); + + let res0 = vec0 - vec1; + let res1 = -(vec1 - vec0); + + assert_eq!(res0, res1); + } + + #[test] + fn test_sub_zero() { + let vec = packed_from_random(0xc1a526f8226ec1e5); + let res = vec - BOTH_ZEROS; + assert_eq!(res, vec); + } + + #[test] + fn test_zero_sub() { + let vec = packed_from_random(0x4444b9c090519333); + let res0 = BOTH_ZEROS - vec; + let res1 = -vec; + assert_eq!(res0, res1); + } + + #[test] + fn test_neg_own_inverse() { + let vec = packed_from_random(0xee4df174b850a35f); + let res = --vec; + assert_eq!(res, vec); + } + + #[test] + fn test_sub_is_add_neg() { + let vec0 = packed_from_random(0x18f4b5c3a08e49fe); + let vec1 = packed_from_random(0x39bd37a1dc24d492); + let res0 = vec0 - vec1; + let res1 = vec0 + (-vec1); + assert_eq!(res0, res1); + } + + #[test] + fn test_mul_associative() { + let vec0 = packed_from_random(0x0b1ee4d7c979d50c); + let vec1 = packed_from_random(0x39faa0844a36e45a); + let vec2 = packed_from_random(0x08fac4ee76260e44); + + let res0 = (vec0 * vec1) * vec2; + let res1 = vec0 * (vec1 * vec2); + + assert_eq!(res0, res1); + } + + #[test] + fn test_mul_commutative() { + let vec0 = packed_from_random(0x10debdcbd409270c); + let vec1 = packed_from_random(0x927bc073c1c92b2f); + + let res0 = vec0 * vec1; + let res1 = vec1 * vec0; + + assert_eq!(res0, res1); + } + + #[test] + fn test_multiplicative_identity_right() { + let vec = packed_from_random(0xdf0a646b6b0c2c36); + let res = vec * P::one(); + assert_eq!(res, vec); + } + + #[test] + fn test_multiplicative_identity_left() { + let vec = packed_from_random(0x7b4d890bf7a38bd2); + let res = P::one() * vec; + assert_eq!(res, vec); + } + + #[test] + fn test_multiplicative_inverse() { + let arr = array_from_random(0xb0c7a5153103c5a8); + let arr_inv = arr.map(|x| x.inverse()); + + let vec = PackedMersenne31AVX512(arr); + let vec_inv = PackedMersenne31AVX512(arr_inv); + + let res = vec * vec_inv; + assert_eq!(res, P::one()); + } + + #[test] + fn test_mul_zero() { + let vec = packed_from_random(0x7f998daa72489bd7); + let res = vec * BOTH_ZEROS; + assert_eq!(res, P::zero()); + } + + #[test] + fn test_zero_mul() { + let vec = packed_from_random(0x683bc2dd355b06e5); + let res = BOTH_ZEROS * vec; + assert_eq!(res, P::zero()); + } + + #[test] + fn test_mul_negone() { + let vec = packed_from_random(0x97cb9670a8251202); + let res0 = vec * P::neg_one(); + let res1 = -vec; + assert_eq!(res0, res1); + } + + #[test] + fn test_negone_mul() { + let vec = packed_from_random(0xadae69873b5d3baf); + let res0 = P::neg_one() * vec; + let res1 = -vec; + assert_eq!(res0, res1); + } + + #[test] + fn test_neg_distributivity_left() { + let vec0 = packed_from_random(0xd0efd6f272c7de93); + let vec1 = packed_from_random(0xd5dd2cf5e76dd694); + + let res0 = vec0 * -vec1; + let res1 = -(vec0 * vec1); + + assert_eq!(res0, res1); + } + + #[test] + fn test_neg_distributivity_right() { + let vec0 = packed_from_random(0x0da9b03cd4b79b09); + let vec1 = packed_from_random(0x9964d3f4beaf1857); + + let res0 = -vec0 * vec1; + let res1 = -(vec0 * vec1); + + assert_eq!(res0, res1); + } + + #[test] + fn test_add_distributivity_left() { + let vec0 = packed_from_random(0x278d9e202925a1d1); + let vec1 = packed_from_random(0xf04cbac0cbad419f); + let vec2 = packed_from_random(0x76976e2abdc5a056); + + let res0 = vec0 * (vec1 + vec2); + let res1 = vec0 * vec1 + vec0 * vec2; + + assert_eq!(res0, res1); + } + + #[test] + fn test_add_distributivity_right() { + let vec0 = packed_from_random(0xbe1b606eafe2a2b8); + let vec1 = packed_from_random(0x552686a0978ab571); + let vec2 = packed_from_random(0x36f6eec4fd31a460); + + let res0 = (vec0 + vec1) * vec2; + let res1 = vec0 * vec2 + vec1 * vec2; + + assert_eq!(res0, res1); + } + + #[test] + fn test_sub_distributivity_left() { + let vec0 = packed_from_random(0x817d4a27febb0349); + let vec1 = packed_from_random(0x1eaf62a921d6519b); + let vec2 = packed_from_random(0xfec0fb8d3849465a); + + let res0 = vec0 * (vec1 - vec2); + let res1 = vec0 * vec1 - vec0 * vec2; + + assert_eq!(res0, res1); + } + + #[test] + fn test_sub_distributivity_right() { + let vec0 = packed_from_random(0x5a4a82e8e2394585); + let vec1 = packed_from_random(0x6006b1443a22b102); + let vec2 = packed_from_random(0x5a22deac65fcd454); + + let res0 = (vec0 - vec1) * vec2; + let res1 = vec0 * vec2 - vec1 * vec2; + + assert_eq!(res0, res1); + } + + #[test] + fn test_one_plus_one() { + assert_eq!(P::one() + P::one(), P::two()); + } + + #[test] + fn test_negone_plus_two() { + assert_eq!(P::neg_one() + P::two(), P::one()); + } + + #[test] + fn test_double() { + let vec = packed_from_random(0x2e61a907650881e9); + let res0 = P::two() * vec; + let res1 = vec + vec; + assert_eq!(res0, res1); + } + + #[test] + fn test_add_vs_scalar() { + let arr0 = array_from_random(0xac23b5a694dabf70); + let arr1 = array_from_random(0xd249ec90e8a6e733); + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 + vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] + arr1[i]); + } + } + + #[test] + fn test_add_vs_scalar_special_vals_left() { + let arr0 = SPECIAL_VALS; + let arr1 = array_from_random(0x1e2b153f07b64cf3); + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 + vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] + arr1[i]); + } + } + + #[test] + fn test_add_vs_scalar_special_vals_right() { + let arr0 = array_from_random(0xfcf974ac7625a260); + let arr1 = SPECIAL_VALS; + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 + vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] + arr1[i]); + } + } + + #[test] + fn test_sub_vs_scalar() { + let arr0 = array_from_random(0x167ce9d8e920876e); + let arr1 = array_from_random(0x52ddcdd3461e046f); + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 - vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] - arr1[i]); + } + } + + #[test] + fn test_sub_vs_scalar_special_vals_left() { + let arr0 = SPECIAL_VALS; + let arr1 = array_from_random(0x358498640bfe1375); + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 - vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] - arr1[i]); + } + } + + #[test] + fn test_sub_vs_scalar_special_vals_right() { + let arr0 = array_from_random(0x05d81ebfb8f0005c); + let arr1 = SPECIAL_VALS; + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 - vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] - arr1[i]); + } + } + + #[test] + fn test_mul_vs_scalar() { + let arr0 = array_from_random(0x4242ebdc09b74d77); + let arr1 = array_from_random(0x9937b275b3c056cd); + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 * vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] * arr1[i]); + } + } + + #[test] + fn test_mul_vs_scalar_special_vals_left() { + let arr0 = SPECIAL_VALS; + let arr1 = array_from_random(0x5285448b835458a3); + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 * vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] * arr1[i]); + } + } + + #[test] + fn test_mul_vs_scalar_special_vals_right() { + let arr0 = array_from_random(0x22508dc80001d865); + let arr1 = SPECIAL_VALS; + + let vec0 = PackedMersenne31AVX512(arr0); + let vec1 = PackedMersenne31AVX512(arr1); + let vec_res = vec0 * vec1; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], arr0[i] * arr1[i]); + } + } + + #[test] + fn test_neg_vs_scalar() { + let arr = array_from_random(0xc3c273a9b334372f); + + let vec = PackedMersenne31AVX512(arr); + let vec_res = -vec; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], -arr[i]); + } + } + + #[test] + fn test_neg_vs_scalar_special_vals() { + let arr = SPECIAL_VALS; + + let vec = PackedMersenne31AVX512(arr); + let vec_res = -vec; + + for i in 0..WIDTH { + assert_eq!(vec_res.0[i], -arr[i]); + } + } +} diff --git a/monolith/benches/permute.rs b/monolith/benches/permute.rs index 0c0686a3..f0531802 100644 --- a/monolith/benches/permute.rs +++ b/monolith/benches/permute.rs @@ -1,21 +1,30 @@ use criterion::{criterion_group, criterion_main, Criterion}; use p3_field::AbstractField; -use p3_mersenne_31::Mersenne31; -use p3_monolith::{MonolithMdsMatrixMersenne31, MonolithMersenne31}; +use p3_mds::MdsPermutation; +use p3_mersenne_31::{MdsMatrixMersenne31, Mersenne31}; +use p3_monolith::MonolithMersenne31; -fn permute_benchmark(c: &mut Criterion) { - let mds = MonolithMdsMatrixMersenne31::<6>; - let monolith: MonolithMersenne31<_, 16, 5> = MonolithMersenne31::new(mds); +fn bench_monolith(c: &mut Criterion) { + monolith::<_, 12>(c, MdsMatrixMersenne31); + monolith::<_, 16>(c, MdsMatrixMersenne31); +} + +fn monolith(c: &mut Criterion, mds: Mds) +where + Mds: MdsPermutation, +{ + let monolith: MonolithMersenne31<_, WIDTH, 5> = MonolithMersenne31::new(mds); - let mut input: [Mersenne31; 16] = [Mersenne31::zero(); 16]; + let mut input: [Mersenne31; WIDTH] = [Mersenne31::zero(); WIDTH]; for (i, inp) in input.iter_mut().enumerate() { *inp = Mersenne31::from_canonical_usize(i); } - c.bench_function("monolith permutation", |b| { + let name = format!("monolith::", WIDTH); + c.bench_function(name.as_str(), |b| { b.iter(|| monolith.permutation(&mut input)) }); } -criterion_group!(benches, permute_benchmark); +criterion_group!(benches, bench_monolith); criterion_main!(benches); diff --git a/monolith/src/monolith.rs b/monolith/src/monolith.rs index 8442beb4..f8df0836 100644 --- a/monolith/src/monolith.rs +++ b/monolith/src/monolith.rs @@ -113,10 +113,12 @@ where .map(|arr| arr.map(|_| Self::random_field_element(&mut shake))) } + #[inline] pub fn concrete(&self, state: &mut [Mersenne31; WIDTH]) { self.mds.permute_mut(state); } + #[inline] pub fn add_round_constants( &self, state: &mut [Mersenne31; WIDTH], @@ -128,6 +130,7 @@ where } } + #[inline] pub fn bricks(state: &mut [Mersenne31; WIDTH]) { // Feistel Type-3 for (x, x_mut) in (state.to_owned()).iter().zip(state.iter_mut().skip(1)) { @@ -135,6 +138,7 @@ where } } + #[inline] pub fn bar(&self, el: Mersenne31) -> Mersenne31 { let val = &mut el.as_canonical_u32(); @@ -151,6 +155,7 @@ where Mersenne31::from_canonical_u32(*val) } + #[inline] pub fn bars(&self, state: &mut [Mersenne31; WIDTH]) { state .iter_mut() diff --git a/multi-stark/src/config.rs b/multi-stark/src/config.rs deleted file mode 100644 index f837e87b..00000000 --- a/multi-stark/src/config.rs +++ /dev/null @@ -1,65 +0,0 @@ -use core::marker::PhantomData; - -use p3_challenger::FieldChallenger; -use p3_commit::MultivariatePcs; -use p3_field::{AbstractExtensionField, ExtensionField, Field, PackedField}; -use p3_matrix::dense::RowMajorMatrixView; - -pub trait StarkGenericConfig { - /// A value of the trace. - type Val: Field; - - /// The field from which most random challenges are drawn. - type Challenge: ExtensionField; - - type PackedChallenge: PackedField - + AbstractExtensionField<::Packing>; - - /// The PCS used to commit to trace polynomials. - type Pcs: for<'a> MultivariatePcs< - Self::Val, - Self::Challenge, - RowMajorMatrixView<'a, Self::Val>, - Self::Challenger, - >; - - type Challenger: FieldChallenger; - - fn pcs(&self) -> &Self::Pcs; -} - -pub struct StarkConfig { - pcs: Pcs, - _phantom: PhantomData<(Val, Challenge, PackedChallenge, Challenger)>, -} - -impl - StarkConfig -{ - pub fn new(pcs: Pcs) -> Self { - Self { - pcs, - _phantom: PhantomData, - } - } -} - -impl StarkGenericConfig - for StarkConfig -where - Val: Field, - Challenge: ExtensionField, - PackedChallenge: PackedField + AbstractExtensionField, - Pcs: for<'a> MultivariatePcs, Challenger>, - Challenger: FieldChallenger, -{ - type Val = Val; - type Challenge = Challenge; - type PackedChallenge = PackedChallenge; - type Pcs = Pcs; - type Challenger = Challenger; - - fn pcs(&self) -> &Self::Pcs { - &self.pcs - } -} diff --git a/multi-stark/src/folder.rs b/multi-stark/src/folder.rs deleted file mode 100644 index 1e526aab..00000000 --- a/multi-stark/src/folder.rs +++ /dev/null @@ -1,53 +0,0 @@ -use p3_air::{AirBuilder, TwoRowMatrixView}; -use p3_field::{AbstractExtensionField, ExtensionField, Field, PackedField}; - -pub struct ConstraintFolder<'a, F, Challenge, PackedChallenge> -where - F: Field, -{ - pub(crate) main: TwoRowMatrixView<'a, F::Packing>, - pub(crate) is_first_row: F::Packing, - pub(crate) is_last_row: F::Packing, - pub(crate) is_transition: F::Packing, - pub(crate) alpha: Challenge, - pub(crate) accumulator: PackedChallenge, -} - -impl<'a, F, Challenge, PackedChallenge> AirBuilder - for ConstraintFolder<'a, F, Challenge, PackedChallenge> -where - F: Field, - Challenge: ExtensionField, - PackedChallenge: PackedField + AbstractExtensionField, -{ - type F = F; - type Expr = F::Packing; - type Var = F::Packing; - type M = TwoRowMatrixView<'a, F::Packing>; - - fn main(&self) -> Self::M { - self.main - } - - fn is_first_row(&self) -> Self::Expr { - self.is_first_row - } - - fn is_last_row(&self) -> Self::Expr { - self.is_last_row - } - - fn is_transition_window(&self, size: usize) -> Self::Expr { - if size == 2 { - self.is_transition - } else { - panic!("multi-stark only supports a window size of 2") - } - } - - fn assert_zero>(&mut self, x: I) { - let x: F::Packing = x.into(); - self.accumulator *= self.alpha; - self.accumulator += x; - } -} diff --git a/multi-stark/src/lib.rs b/multi-stark/src/lib.rs deleted file mode 100644 index 72109be3..00000000 --- a/multi-stark/src/lib.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! A minimal multivariate STARK framework. - -#![no_std] - -extern crate alloc; - -mod config; -mod folder; -mod prover; - -pub use config::*; -pub use folder::*; -pub use prover::*; diff --git a/multi-stark/src/prover.rs b/multi-stark/src/prover.rs deleted file mode 100644 index d4c6513b..00000000 --- a/multi-stark/src/prover.rs +++ /dev/null @@ -1,21 +0,0 @@ -use p3_air::Air; -use p3_challenger::FieldChallenger; -use p3_commit::Pcs; -use p3_matrix::dense::RowMajorMatrix; - -use crate::{ConstraintFolder, StarkGenericConfig}; - -pub fn prove( - config: &SC, - _air: &A, - _challenger: &mut Challenger, - trace: RowMajorMatrix, -) where - SC: StarkGenericConfig, - A: for<'a> Air>, - Challenger: FieldChallenger, -{ - let (_trace_commit, _trace_data) = config.pcs().commit_batch(trace.as_view()); - - // challenger.observe_ext_element(trace_commit); // TODO -} diff --git a/multi-stark/tests/mul_air.rs b/multi-stark/tests/mul_air.rs deleted file mode 100644 index 7f3c4eac..00000000 --- a/multi-stark/tests/mul_air.rs +++ /dev/null @@ -1,72 +0,0 @@ -// use p3_air::{Air, AirBuilder}; -// use p3_multi_stark::{prove, StarkConfig}; -// use p3_challenger::DuplexChallenger; -// use p3_fri::FRIBasedPcs; -// use p3_lde::NaiveCosetLde; -// use p3_matrix::dense::RowMajorMatrix; -// use p3_matrix::Matrix; -// use p3_merkle_tree::MerkleTreeMMCS; -// use p3_poseidon::Poseidon; -// use p3_symmetric::TruncatedPermutation; -// use p3_symmetric::{ArrayPermutation, CryptographicPermutation, MDSPermutation}; -// use p3_symmetric::PaddingFreeSponge; -// use rand::thread_rng; -// use p3_mersenne_31::Mersenne31; -// use p3_tensor_pcs::TensorPcs; -// -// struct MulAir; -// -// impl Air for MulAir { -// fn eval(&self, builder: &mut AB) { -// let main = builder.main(); -// let main_local = main.row(0); -// let diff = main_local[0] * main_local[1] - main_local[2]; -// builder.assert_zero(diff); -// } -// } -// -// #[test] -// #[ignore] // TODO: Not ready yet. -// fn test_prove_goldilocks() { -// type Val = Mersenne31; -// type Challenge = Mersenne31; // TODO -// -// #[derive(Clone)] -// struct MyMds; -// impl CryptographicPermutation<[Val; 8]> for MyMds { -// fn permute(&self, input: [Val; 8]) -> [Val; 8] { -// input // TODO -// } -// } -// impl ArrayPermutation for MyMds {} -// impl MdsPermutation for MyMds {} -// -// type Mds = MyMds; -// let mds = MyMds; -// -// type Perm = Poseidon; -// let perm = Perm::new(5, 5, vec![], mds); -// -// type H4 = PaddingFreeSponge; -// let h4 = H4::new(perm.clone()); -// -// type C = TruncatedPermutation; -// let c = C::new(perm.clone()); -// -// type Mmcs = MerkleTreeMMCS; -// type Pcs = TensorPcs; -// type MyConfig = StarkConfig; -// -// let mut rng = thread_rng(); -// let trace = RowMajorMatrix::rand(&mut rng, 256, 10); -// let pcs = todo!(); -// let config = StarkConfig::new(pcs); -// let mut challenger = DuplexChallenger::new(perm); -// prove::(&MulAir, config, &mut challenger, trace); -// } -// -// #[test] -// #[ignore] // TODO: Not ready yet. -// fn test_prove_mersenne_31() { -// todo!() -// } diff --git a/poseidon/benches/poseidon.rs b/poseidon/benches/poseidon.rs index 7f9f3e73..3b95da0c 100644 --- a/poseidon/benches/poseidon.rs +++ b/poseidon/benches/poseidon.rs @@ -2,15 +2,12 @@ use std::any::type_name; use std::array; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use p3_baby_bear::BabyBear; +use p3_baby_bear::{BabyBear, MdsMatrixBabyBear}; use p3_field::{AbstractField, Field, PrimeField}; -use p3_goldilocks::Goldilocks; -use p3_mds::babybear::MdsMatrixBabyBear; +use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; use p3_mds::coset_mds::CosetMds; -use p3_mds::goldilocks::MdsMatrixGoldilocks; -use p3_mds::mersenne31::MdsMatrixMersenne31; use p3_mds::MdsPermutation; -use p3_mersenne_31::Mersenne31; +use p3_mersenne_31::{MdsMatrixMersenne31, Mersenne31}; use p3_poseidon::Poseidon; use p3_symmetric::Permutation; use rand::distributions::{Distribution, Standard}; diff --git a/poseidon2/Cargo.toml b/poseidon2/Cargo.toml index 3f60c247..eac6a2e0 100644 --- a/poseidon2/Cargo.toml +++ b/poseidon2/Cargo.toml @@ -5,17 +5,14 @@ edition = "2021" license = "MIT OR Apache-2.0" [dependencies] -p3-baby-bear = { path = "../baby-bear" } -p3-goldilocks = { path = "../goldilocks" } -p3-mersenne-31 = { path = "../mersenne-31" } p3-field = { path = "../field" } -p3-mds = { path = "../mds" } p3-symmetric = { path = "../symmetric" } -rand = "0.8.5" +rand = { version = "0.8.5", features = ["min_const_gen"] } [dev-dependencies] -ark-ff = { version = "^0.4.0", default-features = false } -zkhash = { git = "https://github.com/HorizenLabs/poseidon2" } +p3-mersenne-31 = { path = "../mersenne-31" } +p3-baby-bear = { path = "../baby-bear" } +p3-goldilocks = { path = "../goldilocks" } criterion = "0.5.1" [[bench]] diff --git a/poseidon2/benches/poseidon2.rs b/poseidon2/benches/poseidon2.rs index 746a574e..1f2778ca 100644 --- a/poseidon2/benches/poseidon2.rs +++ b/poseidon2/benches/poseidon2.rs @@ -1,12 +1,10 @@ use std::any::type_name; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use p3_baby_bear::BabyBear; +use p3_baby_bear::{BabyBear, DiffusionMatrixBabybear}; use p3_field::PrimeField64; -use p3_goldilocks::Goldilocks; -use p3_poseidon2::{ - DiffusionMatrixBabybear, DiffusionMatrixGoldilocks, DiffusionPermutation, Poseidon2, -}; +use p3_goldilocks::{DiffusionMatrixGoldilocks, Goldilocks}; +use p3_poseidon2::{DiffusionPermutation, Poseidon2}; use p3_symmetric::Permutation; use rand::distributions::{Distribution, Standard}; use rand::thread_rng; diff --git a/poseidon2/src/babybear.rs b/poseidon2/src/babybear.rs deleted file mode 100644 index 40007027..00000000 --- a/poseidon2/src/babybear.rs +++ /dev/null @@ -1,40 +0,0 @@ -//! Diffusion matrices for Babybear16 and Babybear24. -//! -//! Reference: https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2_instance_babybear.rs - -use p3_baby_bear::BabyBear; -use p3_field::AbstractField; -use p3_symmetric::Permutation; - -use crate::diffusion::matmul_internal; -use crate::DiffusionPermutation; - -pub const MATRIX_DIAG_16_BABYBEAR: [u64; 16] = [ - 0x0a632d94, 0x6db657b7, 0x56fbdc9e, 0x052b3d8a, 0x33745201, 0x5c03108c, 0x0beba37b, 0x258c2e8b, - 0x12029f39, 0x694909ce, 0x6d231724, 0x21c3b222, 0x3c0904a5, 0x01d6acda, 0x27705c83, 0x5231c802, -]; - -pub const MATRIX_DIAG_24_BABYBEAR: [u64; 24] = [ - 0x409133f0, 0x1667a8a1, 0x06a6c7b6, 0x6f53160e, 0x273b11d1, 0x03176c5d, 0x72f9bbf9, 0x73ceba91, - 0x5cdef81d, 0x01393285, 0x46daee06, 0x065d7ba6, 0x52d72d6f, 0x05dd05e0, 0x3bab4b63, 0x6ada3842, - 0x2fc5fbec, 0x770d61b0, 0x5715aae9, 0x03ef0e90, 0x75b6c770, 0x242adf5f, 0x00d0ca4c, 0x36c0e388, -]; - -#[derive(Debug, Clone, Default)] -pub struct DiffusionMatrixBabybear; - -impl> Permutation<[AF; 16]> for DiffusionMatrixBabybear { - fn permute_mut(&self, state: &mut [AF; 16]) { - matmul_internal::(state, MATRIX_DIAG_16_BABYBEAR); - } -} - -impl> DiffusionPermutation for DiffusionMatrixBabybear {} - -impl> Permutation<[AF; 24]> for DiffusionMatrixBabybear { - fn permute_mut(&self, state: &mut [AF; 24]) { - matmul_internal::(state, MATRIX_DIAG_24_BABYBEAR); - } -} - -impl DiffusionPermutation for DiffusionMatrixBabybear {} diff --git a/poseidon2/src/diffusion.rs b/poseidon2/src/diffusion.rs index fd362f10..f4856f2d 100644 --- a/poseidon2/src/diffusion.rs +++ b/poseidon2/src/diffusion.rs @@ -8,18 +8,18 @@ //! //! This file implements a trait for linear layers that satisfy these three properties. -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use p3_symmetric::Permutation; pub trait DiffusionPermutation: Permutation<[T; WIDTH]> {} -pub fn matmul_internal( +pub fn matmul_internal, const WIDTH: usize>( state: &mut [AF; WIDTH], - mat_internal_diag_m_1: [u64; WIDTH], + mat_internal_diag_m_1: [F; WIDTH], ) { let sum: AF = state.iter().cloned().sum(); for i in 0..WIDTH { - state[i] *= AF::from_canonical_u64(mat_internal_diag_m_1[i]); + state[i] *= AF::from_f(mat_internal_diag_m_1[i]); state[i] += sum.clone(); } } diff --git a/poseidon2/src/goldilocks.rs b/poseidon2/src/goldilocks.rs deleted file mode 100644 index 5aa67bd3..00000000 --- a/poseidon2/src/goldilocks.rs +++ /dev/null @@ -1,113 +0,0 @@ -//! Diffusion matrices for Goldilocks8, Goldilocks12, Goldilocks16, and Goldilocks20. -//! -//! Reference: https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2_instance_goldilocks.rs - -use p3_field::AbstractField; -use p3_goldilocks::Goldilocks; -use p3_symmetric::Permutation; - -use crate::diffusion::matmul_internal; -use crate::DiffusionPermutation; - -pub const MATRIX_DIAG_8_GOLDILOCKS: [u64; 8] = [ - 0xa98811a1fed4e3a5, - 0x1cc48b54f377e2a0, - 0xe40cd4f6c5609a26, - 0x11de79ebca97a4a3, - 0x9177c73d8b7e929c, - 0x2a6fe8085797e791, - 0x3de6e93329f8d5ad, - 0x3f7af9125da962fe, -]; - -pub const MATRIX_DIAG_12_GOLDILOCKS: [u64; 12] = [ - 0xc3b6c08e23ba9300, - 0xd84b5de94a324fb6, - 0x0d0c371c5b35b84f, - 0x7964f570e7188037, - 0x5daf18bbd996604b, - 0x6743bc47b9595257, - 0x5528b9362c59bb70, - 0xac45e25b7127b68b, - 0xa2077d7dfbb606b5, - 0xf3faac6faee378ae, - 0x0c6388b51545e883, - 0xd27dbb6944917b60, -]; - -pub const MATRIX_DIAG_16_GOLDILOCKS: [u64; 16] = [ - 0xde9b91a467d6afc0, - 0xc5f16b9c76a9be17, - 0x0ab0fef2d540ac55, - 0x3001d27009d05773, - 0xed23b1f906d3d9eb, - 0x5ce73743cba97054, - 0x1c3bab944af4ba24, - 0x2faa105854dbafae, - 0x53ffb3ae6d421a10, - 0xbcda9df8884ba396, - 0xfc1273e4a31807bb, - 0xc77952573d5142c0, - 0x56683339a819b85e, - 0x328fcbd8f0ddc8eb, - 0xb5101e303fce9cb7, - 0x774487b8c40089bb, -]; - -pub const MATRIX_DIAG_20_GOLDILOCKS: [u64; 20] = [ - 0x95c381fda3b1fa57, - 0xf36fe9eb1288f42c, - 0x89f5dcdfef277944, - 0x106f22eadeb3e2d2, - 0x684e31a2530e5111, - 0x27435c5d89fd148e, - 0x3ebed31c414dbf17, - 0xfd45b0b2d294e3cc, - 0x48c904473a7f6dbf, - 0xe0d1b67809295b4d, - 0xddd1941e9d199dcb, - 0x8cfe534eeb742219, - 0xa6e5261d9e3b8524, - 0x6897ee5ed0f82c1b, - 0x0e7dcd0739ee5f78, - 0x493253f3d0d32363, - 0xbb2737f5845f05c0, - 0xa187e810b06ad903, - 0xb635b995936c4918, - 0x0b3694a940bd2394, -]; - -#[derive(Debug, Clone, Default)] -pub struct DiffusionMatrixGoldilocks; - -impl> Permutation<[AF; 8]> for DiffusionMatrixGoldilocks { - fn permute_mut(&self, state: &mut [AF; 8]) { - matmul_internal::(state, MATRIX_DIAG_8_GOLDILOCKS); - } -} - -impl> DiffusionPermutation for DiffusionMatrixGoldilocks {} - -impl> Permutation<[AF; 12]> for DiffusionMatrixGoldilocks { - fn permute_mut(&self, state: &mut [AF; 12]) { - matmul_internal::(state, MATRIX_DIAG_12_GOLDILOCKS); - } -} - -impl> DiffusionPermutation for DiffusionMatrixGoldilocks {} - -impl> Permutation<[AF; 16]> for DiffusionMatrixGoldilocks { - fn permute_mut(&self, state: &mut [AF; 16]) { - matmul_internal::(state, MATRIX_DIAG_16_GOLDILOCKS); - } -} - -impl> DiffusionPermutation for DiffusionMatrixGoldilocks {} - -impl> Permutation<[AF; 20]> for DiffusionMatrixGoldilocks { - fn permute_mut(&self, state: &mut [AF; 20]) { - matmul_internal::(state, MATRIX_DIAG_20_GOLDILOCKS); - } -} - -impl> DiffusionPermutation for DiffusionMatrixGoldilocks {} diff --git a/poseidon2/src/lib.rs b/poseidon2/src/lib.rs index 5db151ac..e609d4e9 100644 --- a/poseidon2/src/lib.rs +++ b/poseidon2/src/lib.rs @@ -8,20 +8,15 @@ extern crate alloc; -mod babybear; mod diffusion; -mod goldilocks; mod matrix; use alloc::vec::Vec; -pub use babybear::DiffusionMatrixBabybear; -pub use diffusion::DiffusionPermutation; -pub use goldilocks::DiffusionMatrixGoldilocks; +pub use diffusion::{matmul_internal, DiffusionPermutation}; use matrix::Poseidon2MEMatrix; use p3_field::{AbstractField, PrimeField}; use p3_symmetric::{CryptographicPermutation, Permutation}; -use rand::distributions::Standard; -use rand::prelude::Distribution; +use rand::distributions::{Distribution, Standard}; use rand::Rng; const SUPPORTED_WIDTHS: [usize; 8] = [2, 3, 4, 8, 12, 16, 20, 24]; @@ -161,238 +156,3 @@ where Diffusion: DiffusionPermutation, { } - -#[cfg(test)] -mod tests { - use alloc::vec::Vec; - - use ark_ff::{BigInteger, PrimeField}; - use p3_baby_bear::BabyBear; - use p3_field::AbstractField; - use p3_goldilocks::Goldilocks; - use p3_symmetric::Permutation; - use rand::Rng; - use zkhash::fields::babybear::FpBabyBear; - use zkhash::fields::goldilocks::FpGoldiLocks; - use zkhash::poseidon2::poseidon2::Poseidon2 as Poseidon2Ref; - use zkhash::poseidon2::poseidon2_instance_babybear::{POSEIDON2_BABYBEAR_16_PARAMS, RC16}; - use zkhash::poseidon2::poseidon2_instance_goldilocks::{ - POSEIDON2_GOLDILOCKS_12_PARAMS, POSEIDON2_GOLDILOCKS_8_PARAMS, RC12, RC8, - }; - - use crate::goldilocks::DiffusionMatrixGoldilocks; - use crate::{DiffusionMatrixBabybear, Poseidon2}; - - fn goldilocks_from_ark_ff(input: FpGoldiLocks) -> Goldilocks { - let as_bigint = input.into_bigint(); - let mut as_bytes = as_bigint.to_bytes_le(); - as_bytes.resize(8, 0); - let as_u64 = u64::from_le_bytes(as_bytes[0..8].try_into().unwrap()); - Goldilocks::from_wrapped_u64(as_u64) - } - - fn babybear_from_ark_ff(input: FpBabyBear) -> BabyBear { - let as_bigint = input.into_bigint(); - let mut as_bytes = as_bigint.to_bytes_le(); - as_bytes.resize(4, 0); - let as_u32 = u32::from_le_bytes(as_bytes[0..4].try_into().unwrap()); - BabyBear::from_wrapped_u32(as_u32) - } - - #[test] - fn test_poseidon2_goldilocks_width_8() { - const WIDTH: usize = 8; - const D: u64 = 7; - const ROUNDS_F: usize = 8; - const ROUNDS_P: usize = 22; - - type F = Goldilocks; - - let mut rng = rand::thread_rng(); - - // Poiseidon2 reference implementation from zkhash repo. - let poseidon2_ref = Poseidon2Ref::new(&POSEIDON2_GOLDILOCKS_8_PARAMS); - - // Copy over round constants from zkhash. - let round_constants: Vec<[F; WIDTH]> = RC8 - .iter() - .map(|vec| { - vec.iter() - .cloned() - .map(goldilocks_from_ark_ff) - .collect::>() - .try_into() - .unwrap() - }) - .collect(); - - // Our Poseidon2 implementation. - let poseidon2: Poseidon2 = Poseidon2::new( - ROUNDS_F, - ROUNDS_P, - round_constants, - DiffusionMatrixGoldilocks, - ); - - // Generate random input and convert to both Goldilocks field formats. - let input_u64 = rng.gen::<[u64; WIDTH]>(); - let input_ref = input_u64 - .iter() - .cloned() - .map(FpGoldiLocks::from) - .collect::>(); - let input = input_u64.map(F::from_wrapped_u64); - - // Check that the conversion is correct. - assert!(input_ref - .iter() - .zip(input.iter()) - .all(|(a, b)| goldilocks_from_ark_ff(*a) == *b)); - - // Run reference implementation. - let output_ref = poseidon2_ref.permutation(&input_ref); - let expected: [F; WIDTH] = output_ref - .iter() - .cloned() - .map(goldilocks_from_ark_ff) - .collect::>() - .try_into() - .unwrap(); - - // Run our implementation. - let mut output = input; - poseidon2.permute_mut(&mut output); - - assert_eq!(output, expected); - } - - #[test] - fn test_poseidon2_goldilocks_width_12() { - const WIDTH: usize = 12; - const D: u64 = 7; - const ROUNDS_F: usize = 8; - const ROUNDS_P: usize = 22; - - type F = Goldilocks; - - let mut rng = rand::thread_rng(); - - // Poiseidon2 reference implementation from zkhash repo. - let poseidon2_ref = Poseidon2Ref::new(&POSEIDON2_GOLDILOCKS_12_PARAMS); - - // Copy over round constants from zkhash. - let round_constants: Vec<[F; WIDTH]> = RC12 - .iter() - .map(|vec| { - vec.iter() - .cloned() - .map(goldilocks_from_ark_ff) - .collect::>() - .try_into() - .unwrap() - }) - .collect(); - - // Our Poseidon2 implementation. - let poseidon2: Poseidon2 = Poseidon2::new( - ROUNDS_F, - ROUNDS_P, - round_constants, - DiffusionMatrixGoldilocks, - ); - - // Generate random input and convert to both Goldilocks field formats. - let input_u64 = rng.gen::<[u64; WIDTH]>(); - let input_ref = input_u64 - .iter() - .cloned() - .map(FpGoldiLocks::from) - .collect::>(); - let input = input_u64.map(F::from_wrapped_u64); - - // Check that the conversion is correct. - assert!(input_ref - .iter() - .zip(input.iter()) - .all(|(a, b)| goldilocks_from_ark_ff(*a) == *b)); - - // Run reference implementation. - let output_ref = poseidon2_ref.permutation(&input_ref); - let expected: [F; WIDTH] = output_ref - .iter() - .cloned() - .map(goldilocks_from_ark_ff) - .collect::>() - .try_into() - .unwrap(); - - // Run our implementation. - let mut output = input; - poseidon2.permute_mut(&mut output); - - assert_eq!(output, expected); - } - - #[test] - fn test_poseidon2_babybear_width_16() { - const WIDTH: usize = 16; - const D: u64 = 7; - const ROUNDS_F: usize = 8; - const ROUNDS_P: usize = 13; - - type F = BabyBear; - - let mut rng = rand::thread_rng(); - - // Poiseidon2 reference implementation from zkhash repo. - let poseidon2_ref = Poseidon2Ref::new(&POSEIDON2_BABYBEAR_16_PARAMS); - - // Copy over round constants from zkhash. - let round_constants: Vec<[F; WIDTH]> = RC16 - .iter() - .map(|vec| { - vec.iter() - .cloned() - .map(babybear_from_ark_ff) - .collect::>() - .try_into() - .unwrap() - }) - .collect(); - - // Our Poseidon2 implementation. - let poseidon2: Poseidon2 = - Poseidon2::new(ROUNDS_F, ROUNDS_P, round_constants, DiffusionMatrixBabybear); - - // Generate random input and convert to both BabyBear field formats. - let input_u32 = rng.gen::<[u32; WIDTH]>(); - let input_ref = input_u32 - .iter() - .cloned() - .map(FpBabyBear::from) - .collect::>(); - let input = input_u32.map(F::from_wrapped_u32); - - // Check that the conversion is correct. - assert!(input_ref - .iter() - .zip(input.iter()) - .all(|(a, b)| babybear_from_ark_ff(*a) == *b)); - - // Run reference implementation. - let output_ref = poseidon2_ref.permutation(&input_ref); - let expected: [F; WIDTH] = output_ref - .iter() - .cloned() - .map(babybear_from_ark_ff) - .collect::>() - .try_into() - .unwrap(); - - // Run our implementation. - let mut output = input; - poseidon2.permute_mut(&mut output); - - assert_eq!(output, expected); - } -} diff --git a/poseidon2/src/matrix.rs b/poseidon2/src/matrix.rs index f2a389e2..fd67bea1 100644 --- a/poseidon2/src/matrix.rs +++ b/poseidon2/src/matrix.rs @@ -33,6 +33,30 @@ where x[3] = t4; } +// At some point we should switch this matrix to: +// [ 2 3 1 1 ] +// [ 1 2 3 1 ] +// [ 1 1 2 3 ] +// [ 3 1 1 2 ]. +// This is more efficient than the one above (11 additions vs 16 additions) and leads to a ~5% speed up. +// Unfortunately it breaks all the tests as we are testing against the implementation from zkhash. +// Hence will leave this as a comment for now and implement later. +// fn apply_m_4(x: &mut [AF]) +// where +// AF: AbstractField, +// AF::F: PrimeField, +// { +// let t01 = x[0].clone() + x[1].clone(); +// let t23 = x[2].clone() + x[3].clone(); +// let t0123 = t01.clone() + t23.clone(); +// let t01123 = t0123.clone() + x[1].clone(); +// let t01233 = t0123.clone() + x[3].clone(); +// x[3] = t01233.clone() + x[0].clone() + x[0].clone(); // 3*x[0] + x[1] + x[2] + 2*x[3] +// x[1] = t01123.clone() + x[2].clone() + x[2].clone(); // x[0] + 2*x[1] + 3*x[2] + x[3] +// x[0] = t01123 + t01; // 2*x[0] + 3*x[1] + x[2] + x[3] +// x[2] = t01233 + t23; // x[0] + x[1] + 2*x[2] + 3*x[3] +// } + impl Permutation<[AF; WIDTH]> for Poseidon2MEMatrix where AF: AbstractField, diff --git a/rescue/benches/rescue.rs b/rescue/benches/rescue.rs index c5de0395..f8227153 100644 --- a/rescue/benches/rescue.rs +++ b/rescue/benches/rescue.rs @@ -2,15 +2,12 @@ use std::any::type_name; use std::array; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use p3_baby_bear::BabyBear; +use p3_baby_bear::{BabyBear, MdsMatrixBabyBear}; use p3_field::{AbstractField, Field, PrimeField64}; -use p3_goldilocks::Goldilocks; -use p3_mds::babybear::MdsMatrixBabyBear; -use p3_mds::goldilocks::MdsMatrixGoldilocks; +use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; use p3_mds::integrated_coset_mds::IntegratedCosetMds; -use p3_mds::mersenne31::MdsMatrixMersenne31; use p3_mds::MdsPermutation; -use p3_mersenne_31::Mersenne31; +use p3_mersenne_31::{MdsMatrixMersenne31, Mersenne31}; use p3_rescue::{BasicSboxLayer, Rescue}; use p3_symmetric::Permutation; use rand::distributions::{Distribution, Standard}; diff --git a/rescue/src/rescue.rs b/rescue/src/rescue.rs index 29d21318..05b3bcd4 100644 --- a/rescue/src/rescue.rs +++ b/rescue/src/rescue.rs @@ -154,8 +154,7 @@ where #[cfg(test)] mod tests { use p3_field::AbstractField; - use p3_mds::mersenne31::MdsMatrixMersenne31; - use p3_mersenne_31::Mersenne31; + use p3_mersenne_31::{MdsMatrixMersenne31, Mersenne31}; use p3_symmetric::{CryptographicHasher, PaddingFreeSponge, Permutation}; use crate::rescue::Rescue; @@ -194,16 +193,16 @@ mod tests { // https://github.com/KULeuven-COSIC/Marvellous/blob/master/rescue_prime.sage const PERMUTATION_OUTPUTS: [[u64; WIDTH]; NUM_TESTS] = [ [ - 983158113, 88736227, 182376113, 380581876, 1054929865, 873254619, 1742172525, - 1018880997, 1922857524, 2128461101, 1878468735, 736900567, + 1415867641, 1662872101, 1070605392, 450708029, 1752877321, 144003686, 623713963, + 13124252, 1719755748, 1164265443, 1031746503, 656034061, ], [ - 504747180, 1708979401, 1023327691, 414948293, 1811202621, 427591394, 666516466, - 1900855073, 1511950466, 346735768, 708718627, 2070146754, + 745601819, 399135364, 1705560828, 1125372012, 2039222953, 1144119753, 1606567447, + 1152559313, 1762793605, 424623198, 651056006, 1227670410, ], [ - 2043076197, 1832583290, 59074227, 991951621, 1166633601, 629305333, 1869192382, - 1355209324, 1919016607, 175801753, 279984593, 2086613859, + 277798368, 1055656487, 366843969, 917136738, 1286790161, 1840518903, 161567750, + 974017246, 1102241644, 633393178, 896102012, 1791619348, ], ]; @@ -231,7 +230,7 @@ mod tests { let input: [Mersenne31; 6] = [1, 2, 3, 4, 5, 6].map(Mersenne31::from_canonical_u64); let expected: [Mersenne31; 6] = [ - 337439389, 568168673, 983336666, 1144682541, 1342961449, 386074361, + 2055426095, 968531194, 1592692524, 136824376, 175318858, 1160805485, ] .map(Mersenne31::from_canonical_u64); diff --git a/tensor-pcs/Cargo.toml b/tensor-pcs/Cargo.toml deleted file mode 100644 index fde22e7a..00000000 --- a/tensor-pcs/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "p3-tensor-pcs" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -[dependencies] -p3-challenger = { path = "../challenger" } -p3-code = { path = "../code" } -p3-commit = { path = "../commit" } -p3-field = { path = "../field" } -p3-matrix = { path = "../matrix" } -p3-util = { path = "../util" } -serde = "1.0.196" diff --git a/tensor-pcs/src/lib.rs b/tensor-pcs/src/lib.rs deleted file mode 100644 index 8fcadcf2..00000000 --- a/tensor-pcs/src/lib.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! A PCS using degree 2 tensor codes, based on BCG20 . - -#![no_std] - -extern crate alloc; - -mod reshape; -mod tensor_pcs; -mod wrapped_matrix; - -pub use tensor_pcs::*; diff --git a/tensor-pcs/src/reshape.rs b/tensor-pcs/src/reshape.rs deleted file mode 100644 index 3e2b4e83..00000000 --- a/tensor-pcs/src/reshape.rs +++ /dev/null @@ -1,13 +0,0 @@ -use p3_util::log2_strict_usize; - -pub(crate) fn optimal_wraps(width: usize, height: usize) -> usize { - let height_bits = log2_strict_usize(height); - (0..height_bits) - .min_by_key(|&wrap_bits| estimate_cost(width << wrap_bits, height >> wrap_bits)) - .unwrap() -} - -fn estimate_cost(width: usize, height: usize) -> usize { - // TODO: Better constants... - width + height -} diff --git a/tensor-pcs/src/tensor_pcs.rs b/tensor-pcs/src/tensor_pcs.rs deleted file mode 100644 index 1b487565..00000000 --- a/tensor-pcs/src/tensor_pcs.rs +++ /dev/null @@ -1,66 +0,0 @@ -use alloc::vec::Vec; -use core::marker::PhantomData; - -use p3_code::LinearCodeFamily; -use p3_commit::{DirectMmcs, Pcs}; -use p3_field::Field; -use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::MatrixRows; - -use crate::reshape::optimal_wraps; -use crate::wrapped_matrix::WrappedMatrix; - -pub struct TensorPcs -where - F: Field, - C: LinearCodeFamily>>, - M: DirectMmcs, -{ - codes: C, - mmcs: M, - _phantom_f: PhantomData, - _phantom_m: PhantomData, -} - -impl TensorPcs -where - F: Field, - C: LinearCodeFamily>, Out = RowMajorMatrix>, - M: DirectMmcs, -{ - pub fn new(codes: C, mmcs: M) -> Self { - Self { - codes, - mmcs, - _phantom_f: PhantomData, - _phantom_m: PhantomData, - } - } -} - -impl Pcs for TensorPcs -where - F: Field, - In: MatrixRows, - C: LinearCodeFamily>, Out = RowMajorMatrix>, - M: DirectMmcs, -{ - type Commitment = M::Commitment; - type ProverData = M::ProverData; - type Proof = Vec; - type Error = (); - - fn commit_batches(&self, polynomials: Vec) -> (Self::Commitment, Self::ProverData) { - let encoded_polynomials = polynomials - .into_iter() - .map(|mat| { - let wraps = optimal_wraps(mat.width(), mat.height()); - let wrapped = WrappedMatrix::new(mat.to_row_major_matrix(), wraps); - self.codes.encode_batch(wrapped) - }) - .collect(); - self.mmcs.commit(encoded_polynomials) - } -} - -// TODO: Impl MultivariatePcs diff --git a/tensor-pcs/src/wrapped_matrix.rs b/tensor-pcs/src/wrapped_matrix.rs deleted file mode 100644 index b922ed0d..00000000 --- a/tensor-pcs/src/wrapped_matrix.rs +++ /dev/null @@ -1,85 +0,0 @@ -use core::marker::PhantomData; - -use p3_matrix::{Matrix, MatrixRows}; - -pub struct WrappedMatrix { - inner: M, - wraps: usize, - _phantom_t: PhantomData, -} - -impl WrappedMatrix -where - M: Matrix, -{ - pub fn new(inner: M, wraps: usize) -> Self { - assert_eq!(inner.height() % wraps, 0); - Self { - inner, - wraps, - _phantom_t: PhantomData, - } - } -} - -impl Matrix for WrappedMatrix -where - M: Matrix, -{ - fn width(&self) -> usize { - self.inner.width() * self.wraps - } - - fn height(&self) -> usize { - self.inner.width() / self.wraps - } -} - -impl MatrixRows for WrappedMatrix -where - M: MatrixRows, -{ - type Row<'a> = WrappedMatrixRow<'a, T, M> where T: 'a, M: 'a; - - fn row(&self, r: usize) -> Self::Row<'_> { - WrappedMatrixRow { - wrapped_matrix: self, - row: r, - current_iter: self.inner.row(r).into_iter(), - next_wrap: 1, - } - } -} - -pub struct WrappedMatrixRow<'a, T, M> -where - T: 'a, - M: MatrixRows, -{ - wrapped_matrix: &'a WrappedMatrix, - row: usize, - current_iter: as IntoIterator>::IntoIter, - next_wrap: usize, -} - -impl<'a, T, M> Iterator for WrappedMatrixRow<'a, T, M> -where - T: 'a, - M: MatrixRows, -{ - type Item = T; - - fn next(&mut self) -> Option { - self.current_iter.next().or_else(|| { - (self.next_wrap < self.wrapped_matrix.wraps).then(|| { - self.current_iter = self - .wrapped_matrix - .inner - .row(self.next_wrap * self.wrapped_matrix.wraps + self.row) - .into_iter(); - self.next_wrap += 1; - self.current_iter.next().unwrap() - }) - }) - } -} diff --git a/uni-stark/Cargo.toml b/uni-stark/Cargo.toml index 475cd3df..4e007aa4 100644 --- a/uni-stark/Cargo.toml +++ b/uni-stark/Cargo.toml @@ -19,7 +19,10 @@ serde = { version = "1.0", default-features = false, features = ["derive", "allo [dev-dependencies] p3-baby-bear = { path = "../baby-bear" } +p3-commit = { path = "../commit", features = ["test-utils"] } +p3-circle = { path = "../circle" } p3-fri = { path = "../fri" } +p3-keccak = { path = "../keccak" } p3-mds = { path = "../mds" } p3-merkle-tree = { path = "../merkle-tree" } p3-goldilocks = { path = "../goldilocks" } diff --git a/uni-stark/src/config.rs b/uni-stark/src/config.rs index 960ff273..e9668600 100644 --- a/uni-stark/src/config.rs +++ b/uni-stark/src/config.rs @@ -1,43 +1,45 @@ use core::marker::PhantomData; -use p3_challenger::{CanObserve, FieldChallenger}; -use p3_commit::{Pcs, UnivariatePcsWithLde}; -use p3_field::{ExtensionField, Field, TwoAdicField}; -use p3_matrix::dense::RowMajorMatrix; +use p3_challenger::{CanObserve, CanSample, FieldChallenger}; +use p3_commit::{Pcs, PolynomialSpace}; +use p3_field::{ExtensionField, Field}; -pub type PackedVal = <::Val as Field>::Packing; -pub type PackedChallenge = <::Challenge as ExtensionField< - ::Val, ->>::ExtensionPacking; +pub type Domain = <::Pcs as Pcs< + ::Challenge, + ::Challenger, +>>::Domain; -pub trait StarkGenericConfig { - /// The field over which trace data is encoded. - type Val: TwoAdicField; +pub type Val = <<::Pcs as Pcs< + ::Challenge, + ::Challenger, +>>::Domain as PolynomialSpace>::Val; - /// The field from which most random challenges are drawn. - type Challenge: ExtensionField + TwoAdicField; +pub type PackedVal = as Field>::Packing; +pub type PackedChallenge = + <::Challenge as ExtensionField>>::ExtensionPacking; + +pub trait StarkGenericConfig { /// The PCS used to commit to trace polynomials. - type Pcs: UnivariatePcsWithLde< - Self::Val, - Self::Challenge, - RowMajorMatrix, - Self::Challenger, - >; + type Pcs: Pcs; + + /// The field from which most random challenges are drawn. + type Challenge: ExtensionField>; /// The challenger (Fiat-Shamir) implementation used. - type Challenger: FieldChallenger - + CanObserve<>>::Commitment>; + type Challenger: FieldChallenger> + + CanObserve<>::Commitment> + + CanSample; fn pcs(&self) -> &Self::Pcs; } -pub struct StarkConfig { +pub struct StarkConfig { pcs: Pcs, - _phantom: PhantomData<(Val, Challenge, Challenger)>, + _phantom: PhantomData<(Challenge, Challenger)>, } -impl StarkConfig { +impl StarkConfig { pub fn new(pcs: Pcs) -> Self { Self { pcs, @@ -46,18 +48,16 @@ impl StarkConfig StarkGenericConfig - for StarkConfig +impl StarkGenericConfig for StarkConfig where - Val: TwoAdicField, - Challenge: ExtensionField + TwoAdicField, - Pcs: UnivariatePcsWithLde, Challenger>, - Challenger: FieldChallenger - + CanObserve<>>::Commitment>, + Challenge: ExtensionField<::Val>, + Pcs: p3_commit::Pcs, + Challenger: FieldChallenger<::Val> + + CanObserve<>::Commitment> + + CanSample, { - type Val = Val; - type Challenge = Challenge; type Pcs = Pcs; + type Challenge = Challenge; type Challenger = Challenger; fn pcs(&self) -> &Self::Pcs { diff --git a/uni-stark/src/decompose.rs b/uni-stark/src/decompose.rs deleted file mode 100644 index 1858d614..00000000 --- a/uni-stark/src/decompose.rs +++ /dev/null @@ -1,131 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; - -use p3_field::{AbstractExtensionField, TwoAdicField}; -use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::prelude::*; -use p3_util::log2_strict_usize; -use tracing::instrument; - -/// Decompose the quotient polynomial into chunks using a generalization of even-odd decomposition. -/// Then, arrange the results in a row-major matrix, so that each chunk of the decomposed polynomial -/// becomes `D` columns of the resulting matrix, where `D` is the field extension degree. -#[instrument(name = "decompose and flatten quotient", skip_all)] -pub fn decompose_and_flatten( - quotient_poly: Vec, - shift: Challenge, - log_chunks: usize, -) -> RowMajorMatrix -where - Val: TwoAdicField, - Challenge: AbstractExtensionField + TwoAdicField, -{ - let chunks: Vec> = decompose(quotient_poly, shift, log_chunks); - let degree = chunks[0].len(); - let quotient_chunks_flattened: Vec = (0..degree) - .into_par_iter() - .flat_map_iter(|row| { - chunks - .iter() - .flat_map(move |chunk| chunk[row].as_base_slice().iter().copied()) - }) - .collect(); - let challenge_ext_degree = >::D; - RowMajorMatrix::new( - quotient_chunks_flattened, - challenge_ext_degree << log_chunks, - ) -} - -/// A generalization of even-odd decomposition. -fn decompose(poly: Vec, shift: F, log_chunks: usize) -> Vec> { - // For now, we use a naive recursive method. - // A more optimized method might look similar to a decimation-in-time FFT, - // but only the first `log_chunks` layers. It should also be parallelized. - - if log_chunks == 0 { - return vec![poly]; - } - - let n = poly.len(); - debug_assert!(n > 1); - let log_n = log2_strict_usize(n); - let half_n = poly.len() / 2; - let g_inv = F::two_adic_generator(log_n).inverse(); - - let one_half = F::two().inverse(); - let (first, second) = poly.split_at(half_n); - - // Note that - // p_e(g^(2i)) = (p(g^i) + p(g^(n/2 + i))) / 2 - // p_o(g^(2i)) = (p(g^i) - p(g^(n/2 + i))) / (2 s g^i) - - // p_e(g^(2i)) = (a + b) / 2 - // p_o(g^(2i)) = (a - b) / (2 s g^i) - let mut g_inv_powers = g_inv.shifted_powers(shift.inverse()); - let g_inv_powers = (0..first.len()) - .map(|_| g_inv_powers.next().unwrap()) - .collect::>(); - let (even, odd): (Vec<_>, Vec<_>) = first - .par_iter() - .zip(second.par_iter()) - .zip(g_inv_powers.par_iter()) - .map(|((&a, &b), g_inv_power)| { - let sum = a + b; - let diff = a - b; - (sum * one_half, diff * one_half * *g_inv_power) - }) - .unzip(); - - let (even_decomp, odd_decomp) = join( - || decompose(even, shift.square(), log_chunks - 1), - || decompose(odd, shift.square(), log_chunks - 1), - ); - - let mut combined = even_decomp; - combined.extend(odd_decomp); - combined -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use p3_baby_bear::BabyBear; - use p3_dft::{Radix2Dit, TwoAdicSubgroupDft}; - use p3_field::AbstractField; - use p3_util::reverse_slice_index_bits; - use rand::{thread_rng, Rng}; - - use super::*; - - // If we decompose evaluations over a coset s*g^i, we should get - // evaluations over s^log_chunks * g^(log_chunks*i). - #[test] - fn test_decompose_coset() { - type F = BabyBear; - - let mut rng = thread_rng(); - let dft = Radix2Dit::default(); - - let log_n = 5; - let n = 1 << log_n; - let log_chunks = 3; - let chunks = 1 << log_chunks; - let shift = F::generator(); - - let coeffs = (0..n).map(|_| rng.gen::()).collect::>(); - - let coset_evals = dft.coset_dft(coeffs.clone(), shift); - let mut decomp = decompose(coset_evals, shift, log_chunks); - - reverse_slice_index_bits(&mut decomp); - - for (i, e) in decomp.into_iter().enumerate() { - let chunk_coeffs = coeffs.iter().cloned().skip(i).step_by(chunks).collect_vec(); - assert_eq!( - dft.coset_dft(chunk_coeffs, shift.exp_power_of_2(log_chunks)), - e - ); - } - } -} diff --git a/uni-stark/src/folder.rs b/uni-stark/src/folder.rs index 25e5beca..3015ee85 100644 --- a/uni-stark/src/folder.rs +++ b/uni-stark/src/folder.rs @@ -1,7 +1,7 @@ use p3_air::{AirBuilder, TwoRowMatrixView}; -use p3_field::{AbstractField, Field}; +use p3_field::AbstractField; -use crate::{PackedChallenge, PackedVal, StarkGenericConfig}; +use crate::{PackedChallenge, PackedVal, StarkGenericConfig, Val}; pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { pub main: TwoRowMatrixView<'a, PackedVal>, @@ -12,17 +12,17 @@ pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { pub accumulator: PackedChallenge, } -pub struct VerifierConstraintFolder<'a, Challenge> { - pub main: TwoRowMatrixView<'a, Challenge>, - pub is_first_row: Challenge, - pub is_last_row: Challenge, - pub is_transition: Challenge, - pub alpha: Challenge, - pub accumulator: Challenge, +pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> { + pub main: TwoRowMatrixView<'a, SC::Challenge>, + pub is_first_row: SC::Challenge, + pub is_last_row: SC::Challenge, + pub is_transition: SC::Challenge, + pub alpha: SC::Challenge, + pub accumulator: SC::Challenge, } impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { - type F = SC::Val; + type F = Val; type Expr = PackedVal; type Var = PackedVal; type M = TwoRowMatrixView<'a, PackedVal>; @@ -54,11 +54,11 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { } } -impl<'a, Challenge: Field> AirBuilder for VerifierConstraintFolder<'a, Challenge> { - type F = Challenge; - type Expr = Challenge; - type Var = Challenge; - type M = TwoRowMatrixView<'a, Challenge>; +impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> { + type F = Val; + type Expr = SC::Challenge; + type Var = SC::Challenge; + type M = TwoRowMatrixView<'a, SC::Challenge>; fn main(&self) -> Self::M { self.main @@ -81,7 +81,7 @@ impl<'a, Challenge: Field> AirBuilder for VerifierConstraintFolder<'a, Challenge } fn assert_zero>(&mut self, x: I) { - let x: Challenge = x.into(); + let x: SC::Challenge = x.into(); self.accumulator *= self.alpha; self.accumulator += x; } diff --git a/uni-stark/src/lib.rs b/uni-stark/src/lib.rs index 4132539c..7bbf5623 100644 --- a/uni-stark/src/lib.rs +++ b/uni-stark/src/lib.rs @@ -5,7 +5,6 @@ extern crate alloc; mod config; -mod decompose; mod folder; mod proof; mod prover; @@ -19,7 +18,6 @@ mod check_constraints; pub use check_constraints::*; pub use config::*; -pub use decompose::*; pub use folder::*; pub use proof::*; pub use prover::*; diff --git a/uni-stark/src/proof.rs b/uni-stark/src/proof.rs index b03a00b0..85933e93 100644 --- a/uni-stark/src/proof.rs +++ b/uni-stark/src/proof.rs @@ -1,15 +1,18 @@ use alloc::vec::Vec; use p3_commit::Pcs; -use p3_matrix::dense::RowMajorMatrix; use serde::{Deserialize, Serialize}; use crate::StarkGenericConfig; -type Val = ::Val; -type ValMat = RowMajorMatrix>; -type Com = <::Pcs as Pcs, ValMat>>::Commitment; -type PcsProof = <::Pcs as Pcs, ValMat>>::Proof; +type Com = <::Pcs as Pcs< + ::Challenge, + ::Challenger, +>>::Commitment; +type PcsProof = <::Pcs as Pcs< + ::Challenge, + ::Challenger, +>>::Proof; #[derive(Serialize, Deserialize)] #[serde(bound = "")] @@ -30,5 +33,5 @@ pub struct Commitments { pub struct OpenedValues { pub(crate) trace_local: Vec, pub(crate) trace_next: Vec, - pub(crate) quotient_chunks: Vec, + pub(crate) quotient_chunks: Vec>, } diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index 473ea948..552f9c11 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -1,40 +1,37 @@ use alloc::vec; use alloc::vec::Vec; -use itertools::Itertools; +use itertools::{izip, Itertools}; use p3_air::{Air, TwoRowMatrixView}; -use p3_challenger::{CanObserve, FieldChallenger}; -use p3_commit::{Pcs, UnivariatePcs, UnivariatePcsWithLde}; -use p3_field::{ - cyclic_subgroup_coset_known_order, AbstractExtensionField, AbstractField, Field, PackedValue, - TwoAdicField, -}; +use p3_challenger::{CanObserve, CanSample, FieldChallenger}; +use p3_commit::{Pcs, PolynomialSpace}; +use p3_field::{AbstractExtensionField, AbstractField, PackedValue}; use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::{Matrix, MatrixGet, MatrixRows}; +use p3_matrix::{Matrix, MatrixGet}; use p3_maybe_rayon::prelude::*; use p3_util::log2_strict_usize; use tracing::{info_span, instrument}; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; use crate::{ - decompose_and_flatten, Commitments, OpenedValues, PackedChallenge, PackedVal, Proof, - ProverConstraintFolder, StarkGenericConfig, ZerofierOnCoset, + Commitments, Domain, OpenedValues, PackedChallenge, PackedVal, Proof, ProverConstraintFolder, + StarkGenericConfig, Val, }; #[instrument(skip_all)] pub fn prove< SC, - #[cfg(debug_assertions)] A: for<'a> Air>, + #[cfg(debug_assertions)] A: for<'a> Air>>, #[cfg(not(debug_assertions))] A, >( config: &SC, air: &A, challenger: &mut SC::Challenger, - trace: RowMajorMatrix, + trace: RowMajorMatrix>, ) -> Proof where SC: StarkGenericConfig, - A: Air> + for<'a> Air>, + A: Air>> + for<'a> Air>, { #[cfg(debug_assertions)] crate::check_constraints::check_constraints(air, &trace); @@ -42,44 +39,36 @@ where let degree = trace.height(); let log_degree = log2_strict_usize(degree); - let log_quotient_degree = get_log_quotient_degree::(air); - - let g_subgroup = SC::Val::two_adic_generator(log_degree); + let log_quotient_degree = get_log_quotient_degree::, A>(air); + let quotient_degree = 1 << log_quotient_degree; let pcs = config.pcs(); + let trace_domain = pcs.natural_domain_for_degree(degree); + let (trace_commit, trace_data) = - info_span!("commit to trace data").in_scope(|| pcs.commit_batch(trace)); + info_span!("commit to trace data").in_scope(|| pcs.commit(vec![(trace_domain, trace)])); challenger.observe(trace_commit.clone()); let alpha: SC::Challenge = challenger.sample_ext_element(); - let mut trace_ldes = pcs.get_ldes(&trace_data); - assert_eq!(trace_ldes.len(), 1); - let trace_lde = trace_ldes.pop().unwrap(); + let quotient_domain = + trace_domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree)); - let log_stride_for_quotient = pcs.log_blowup() - log_quotient_degree; - let trace_lde_for_quotient = trace_lde.vertically_strided(1 << log_stride_for_quotient, 0); + let trace_on_quotient_domain = pcs.get_evaluations_on_domain(&trace_data, 0, quotient_domain); let quotient_values = quotient_values( - config, air, - log_degree, - log_quotient_degree, - trace_lde_for_quotient, + trace_domain, + quotient_domain, + trace_on_quotient_domain, alpha, ); - let quotient_chunks_flattened = decompose_and_flatten( - quotient_values, - SC::Challenge::from_base(pcs.coset_shift()), - log_quotient_degree, - ); - let (quotient_commit, quotient_data) = - info_span!("commit to quotient poly chunks").in_scope(|| { - pcs.commit_shifted_batch( - quotient_chunks_flattened, - pcs.coset_shift().exp_power_of_2(log_quotient_degree), - ) - }); + let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base(); + let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat); + let qc_domains = quotient_domain.split_domains(quotient_degree); + + let (quotient_commit, quotient_data) = info_span!("commit to quotient poly chunks") + .in_scope(|| pcs.commit(izip!(qc_domains, quotient_chunks).collect_vec())); challenger.observe(quotient_commit.clone()); let commitments = Commitments { @@ -87,20 +76,23 @@ where quotient_chunks: quotient_commit, }; - let zeta: SC::Challenge = challenger.sample_ext_element(); - let (opened_values, opening_proof) = pcs.open_multi_batches( - &[ - (&trace_data, &[vec![zeta, zeta * g_subgroup]]), + let zeta: SC::Challenge = challenger.sample(); + let zeta_next = trace_domain.next_point(zeta).unwrap(); + + let (opened_values, opening_proof) = pcs.open( + vec![ + (&trace_data, vec![vec![zeta, zeta_next]]), ( "ient_data, - &[vec![zeta.exp_power_of_2(log_quotient_degree)]], + // open every chunk at zeta + (0..quotient_degree).map(|_| vec![zeta]).collect_vec(), ), ], challenger, ); let trace_local = opened_values[0][0][0].clone(); let trace_next = opened_values[0][0][1].clone(); - let quotient_chunks = opened_values[1][0][0].clone(); + let quotient_chunks = opened_values[1].iter().map(|v| v[0].clone()).collect_vec(); let opened_values = OpenedValues { trace_local, trace_next, @@ -116,75 +108,53 @@ where #[instrument(name = "compute quotient polynomial", skip_all)] fn quotient_values( - config: &SC, air: &A, - degree_bits: usize, - quotient_degree_bits: usize, - trace_lde: Mat, + trace_domain: Domain, + quotient_domain: Domain, + trace_on_quotient_domain: Mat, alpha: SC::Challenge, ) -> Vec where SC: StarkGenericConfig, A: for<'a> Air>, - Mat: MatrixGet + Sync, + Mat: MatrixGet> + Sync, { - let degree = 1 << degree_bits; - let quotient_size_bits = degree_bits + quotient_degree_bits; - let quotient_size = 1 << quotient_size_bits; - let g_subgroup = SC::Val::two_adic_generator(degree_bits); - let g_extended = SC::Val::two_adic_generator(quotient_size_bits); - let subgroup_last = g_subgroup.inverse(); - let coset_shift = config.pcs().coset_shift(); - let next_step = 1 << quotient_degree_bits; - - let mut coset: Vec<_> = - cyclic_subgroup_coset_known_order(g_extended, coset_shift, quotient_size).collect(); - - let zerofier_on_coset = ZerofierOnCoset::new(degree_bits, quotient_degree_bits, coset_shift); - - // Evaluations of L_first(x) = Z_H(x) / (x - 1) on our coset s H. - let mut lagrange_first_evals = zerofier_on_coset.lagrange_basis_unnormalized(0); - let mut lagrange_last_evals = zerofier_on_coset.lagrange_basis_unnormalized(degree - 1); - - // We have a few vectors of length `quotient_size`, and we're going to take slices therein of - // length `WIDTH`. In the edge case where `quotient_size < WIDTH`, we need to pad those vectors - // in order for the slices to exist. The entries beyond quotient_size will be ignored, so we can - // just use default values. - for _ in quotient_size..PackedVal::::WIDTH { - coset.push(SC::Val::default()); - lagrange_first_evals.push(SC::Val::default()); - lagrange_last_evals.push(SC::Val::default()); - } + let quotient_size = quotient_domain.size(); + let width = trace_on_quotient_domain.width(); + let sels = trace_domain.selectors_on_coset(quotient_domain); + + let qdb = log2_strict_usize(quotient_domain.size()) - log2_strict_usize(trace_domain.size()); + let next_step = 1 << qdb; + + assert!(quotient_size >= PackedVal::::WIDTH); (0..quotient_size) .into_par_iter() .step_by(PackedVal::::WIDTH) - .flat_map_iter(|i_local_start| { + .flat_map_iter(|i_start| { let wrap = |i| i % quotient_size; - let i_next_start = wrap(i_local_start + next_step); - let i_range = i_local_start..i_local_start + PackedVal::::WIDTH; + let i_range = i_start..i_start + PackedVal::::WIDTH; - let x = *PackedVal::::from_slice(&coset[i_range.clone()]); - let is_transition = x - subgroup_last; - let is_first_row = *PackedVal::::from_slice(&lagrange_first_evals[i_range.clone()]); - let is_last_row = *PackedVal::::from_slice(&lagrange_last_evals[i_range]); + let is_first_row = *PackedVal::::from_slice(&sels.is_first_row[i_range.clone()]); + let is_last_row = *PackedVal::::from_slice(&sels.is_last_row[i_range.clone()]); + let is_transition = *PackedVal::::from_slice(&sels.is_transition[i_range.clone()]); + let inv_zeroifier = *PackedVal::::from_slice(&sels.inv_zeroifier[i_range.clone()]); - let local: Vec<_> = (0..trace_lde.width()) + let local = (0..width) .map(|col| { PackedVal::::from_fn(|offset| { - let row = wrap(i_local_start + offset); - trace_lde.get(row, col) + trace_on_quotient_domain.get(wrap(i_start + offset), col) }) }) - .collect(); - let next: Vec<_> = (0..trace_lde.width()) + .collect_vec(); + + let next = (0..width) .map(|col| { PackedVal::::from_fn(|offset| { - let row = wrap(i_next_start + offset); - trace_lde.get(row, col) + trace_on_quotient_domain.get(wrap(i_start + next_step + offset), col) }) }) - .collect(); + .collect_vec(); let accumulator = PackedChallenge::::zero(); let mut folder = ProverConstraintFolder { @@ -201,13 +171,11 @@ where air.eval(&mut folder); // quotient(x) = constraints(x) / Z_H(x) - let zerofier_inv: PackedVal = zerofier_on_coset.eval_inverse_packed(i_local_start); - let quotient = folder.accumulator * zerofier_inv; + let quotient = folder.accumulator * inv_zeroifier; // "Transpose" D packed base coefficients into WIDTH scalar extension coefficients. - let limit = PackedVal::::WIDTH.min(quotient_size); - (0..limit).map(move |idx_in_packing| { - let quotient_value = (0..>::D) + (0..PackedVal::::WIDTH).map(move |idx_in_packing| { + let quotient_value = (0..>>::D) .map(|coeff_idx| quotient.as_base_slice()[coeff_idx].as_slice()[idx_in_packing]) .collect_vec(); SC::Challenge::from_base_slice("ient_value) diff --git a/uni-stark/src/symbolic_builder.rs b/uni-stark/src/symbolic_builder.rs index 19ce5cec..9a02e04f 100644 --- a/uni-stark/src/symbolic_builder.rs +++ b/uni-stark/src/symbolic_builder.rs @@ -39,7 +39,7 @@ where .unwrap_or(0) } -#[instrument(name = "evalute constraints symbolically", skip_all, level = "debug")] +#[instrument(name = "evaluate constraints symbolically", skip_all, level = "debug")] pub fn get_symbolic_constraints(air: &A) -> Vec> where F: Field, diff --git a/uni-stark/src/verifier.rs b/uni-stark/src/verifier.rs index 96944aa2..3c92dff3 100644 --- a/uni-stark/src/verifier.rs +++ b/uni-stark/src/verifier.rs @@ -1,16 +1,14 @@ use alloc::vec; -use alloc::vec::Vec; +use itertools::Itertools; use p3_air::{Air, BaseAir, TwoRowMatrixView}; -use p3_challenger::{CanObserve, FieldChallenger}; -use p3_commit::UnivariatePcs; -use p3_field::{AbstractExtensionField, AbstractField, Field, TwoAdicField}; -use p3_matrix::Dimensions; -use p3_util::reverse_slice_index_bits; +use p3_challenger::{CanObserve, CanSample, FieldChallenger}; +use p3_commit::{Pcs, PolynomialSpace}; +use p3_field::{AbstractExtensionField, AbstractField, Field}; use tracing::instrument; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; -use crate::{Proof, StarkGenericConfig, VerifierConstraintFolder}; +use crate::{Proof, StarkGenericConfig, Val, VerifierConstraintFolder}; #[instrument(skip_all)] pub fn verify( @@ -21,11 +19,8 @@ pub fn verify( ) -> Result<(), VerificationError> where SC: StarkGenericConfig, - A: Air> + for<'a> Air>, + A: Air>> + for<'a> Air>, { - let log_quotient_degree = get_log_quotient_degree::(air); - let quotient_degree = 1 << log_quotient_degree; - let Proof { commitments, opened_values, @@ -33,86 +28,99 @@ where degree_bits, } = proof; - let air_width = >::width(air); - let quotient_chunks = quotient_degree * >::D; + let degree = 1 << degree_bits; + let log_quotient_degree = get_log_quotient_degree::, A>(air); + let quotient_degree = 1 << log_quotient_degree; + + let pcs = config.pcs(); + let trace_domain = pcs.natural_domain_for_degree(degree); + let quotient_domain = + trace_domain.create_disjoint_domain(1 << (degree_bits + log_quotient_degree)); + let quotient_chunks_domains = quotient_domain.split_domains(quotient_degree); + + let air_width = >>::width(air); let valid_shape = opened_values.trace_local.len() == air_width && opened_values.trace_next.len() == air_width - && opened_values.quotient_chunks.len() == quotient_chunks; + && opened_values.quotient_chunks.len() == quotient_degree + && opened_values + .quotient_chunks + .iter() + .all(|qc| qc.len() == >>::D); if !valid_shape { return Err(VerificationError::InvalidProofShape); } - let g_subgroup = SC::Val::two_adic_generator(*degree_bits); - challenger.observe(commitments.trace.clone()); let alpha: SC::Challenge = challenger.sample_ext_element(); challenger.observe(commitments.quotient_chunks.clone()); - let zeta: SC::Challenge = challenger.sample_ext_element(); - - let local_and_next = [vec![zeta, zeta * g_subgroup]]; - let commits_and_points = &[ - (commitments.trace.clone(), local_and_next.as_slice()), - ( - commitments.quotient_chunks.clone(), - &[vec![zeta.exp_power_of_2(log_quotient_degree)]], - ), - ]; - let values = vec![ - vec![vec![ - opened_values.trace_local.clone(), - opened_values.trace_next.clone(), - ]], - vec![vec![opened_values.quotient_chunks.clone()]], - ]; - let dims = &[ - vec![Dimensions { - width: air_width, - height: 1 << degree_bits, - }], - vec![Dimensions { - width: quotient_chunks, - height: 1 << degree_bits, - }], - ]; - config - .pcs() - .verify_multi_batches(commits_and_points, dims, values, opening_proof, challenger) - .map_err(|_| VerificationError::InvalidOpeningArgument)?; - - // Derive the opening of the quotient polynomial, which was split into degree n chunks, then - // flattened into D base field polynomials. We first undo the flattening. - let challenge_ext_degree = >::D; - let mut quotient_parts: Vec = opened_values - .quotient_chunks - .chunks(challenge_ext_degree) - .map(|chunk| { - chunk + + let zeta: SC::Challenge = challenger.sample(); + let zeta_next = trace_domain.next_point(zeta).unwrap(); + + pcs.verify( + vec![ + ( + commitments.trace.clone(), + vec![( + trace_domain, + vec![ + (zeta, opened_values.trace_local.clone()), + (zeta_next, opened_values.trace_next.clone()), + ], + )], + ), + ( + commitments.quotient_chunks.clone(), + quotient_chunks_domains + .iter() + .zip(&opened_values.quotient_chunks) + .map(|(domain, values)| (*domain, vec![(zeta, values.clone())])) + .collect_vec(), + ), + ], + opening_proof, + challenger, + ) + .map_err(|_| VerificationError::InvalidOpeningArgument)?; + + let zps = quotient_chunks_domains + .iter() + .enumerate() + .map(|(i, domain)| { + quotient_chunks_domains .iter() .enumerate() - .map(|(i, &c)| >::monomial(i) * c) - .sum() + .filter(|(j, _)| *j != i) + .map(|(_, other_domain)| { + other_domain.zp_at_point(zeta) + * other_domain.zp_at_point(domain.first_point()).inverse() + }) + .product::() + }) + .collect_vec(); + + let quotient = opened_values + .quotient_chunks + .iter() + .enumerate() + .map(|(ch_i, ch)| { + ch.iter() + .enumerate() + .map(|(e_i, &c)| zps[ch_i] * SC::Challenge::monomial(e_i) * c) + .sum::() }) - .collect(); - // Then we reconstruct the larger quotient polynomial from its degree-n parts. - reverse_slice_index_bits(&mut quotient_parts); - let quotient: SC::Challenge = zeta - .powers() - .zip(quotient_parts) - .map(|(weight, part)| part * weight) - .sum(); - - let z_h = zeta.exp_power_of_2(*degree_bits) - SC::Challenge::one(); - let is_first_row = z_h / (zeta - SC::Val::one()); - let is_last_row = z_h / (zeta - g_subgroup.inverse()); - let is_transition = zeta - g_subgroup.inverse(); + .sum::(); + + let sels = trace_domain.selectors_at_point(zeta); + let mut folder = VerifierConstraintFolder { main: TwoRowMatrixView { local: &opened_values.trace_local, next: &opened_values.trace_next, }, - is_first_row, - is_last_row, - is_transition, + is_first_row: sels.is_first_row, + is_last_row: sels.is_last_row, + is_transition: sels.is_transition, alpha, accumulator: SC::Challenge::zero(), }; @@ -120,8 +128,8 @@ where let folded_constraints = folder.accumulator; // Finally, check that - // folded_constraints(zeta) = Z_H(zeta) * quotient(zeta) - if folded_constraints != z_h * quotient { + // folded_constraints(zeta) / Z_H(zeta) = quotient(zeta) + if folded_constraints * sels.inv_zeroifier != quotient { return Err(VerificationError::OodEvaluationMismatch); } diff --git a/uni-stark/tests/mul_air.rs b/uni-stark/tests/mul_air.rs index 5c45896a..d3166276 100644 --- a/uni-stark/tests/mul_air.rs +++ b/uni-stark/tests/mul_air.rs @@ -1,30 +1,86 @@ +use std::marker::PhantomData; + use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_baby_bear::BabyBear; -use p3_challenger::DuplexChallenger; +use p3_baby_bear::{BabyBear, DiffusionMatrixBabybear}; +use p3_challenger::{DuplexChallenger, HashChallenger, SerializingChallenger32}; +use p3_circle::{Cfft, CirclePcs}; +use p3_commit::testing::TrivialPcs; use p3_commit::ExtensionMmcs; use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; -use p3_field::Field; -use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use p3_field::{AbstractField, Field}; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_keccak::Keccak256Hash; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::MatrixRowSlices; use p3_merkle_tree::FieldMerkleTreeMmcs; -use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2}; -use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; -use p3_uni_stark::{prove, verify, StarkConfig, VerificationError}; +use p3_mersenne_31::Mersenne31; +use p3_poseidon2::Poseidon2; +use p3_symmetric::{ + CompressionFunctionFromHasher, PaddingFreeSponge, SerializingHasher32, TruncatedPermutation, +}; +use p3_uni_stark::{prove, verify, StarkConfig, StarkGenericConfig, Val, VerificationError}; use rand::distributions::{Distribution, Standard}; use rand::{thread_rng, Rng}; -use tracing_forest::ForestLayer; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::{EnvFilter, Registry}; /// How many `a * b = c` operations to do per row in the AIR. -const REPETITIONS: usize = 10; +const REPETITIONS: usize = 20; const TRACE_WIDTH: usize = REPETITIONS * 3; -struct MulAir; +/* +In its basic form, asserts a^(self.degree-1) * b = c +(so that the total constraint degree is self.degree) + + +If `uses_transition_constraints`, checks that on transition rows, the first a = row number +*/ +pub struct MulAir { + degree: u64, + uses_boundary_constraints: bool, + uses_transition_constraints: bool, +} + +impl Default for MulAir { + fn default() -> Self { + MulAir { + degree: 3, + uses_boundary_constraints: true, + uses_transition_constraints: true, + } + } +} + +impl MulAir { + pub fn random_valid_trace(&self, rows: usize, valid: bool) -> RowMajorMatrix + where + Standard: Distribution, + { + let mut rng = thread_rng(); + let mut trace_values = vec![F::default(); rows * TRACE_WIDTH]; + for (i, (a, b, c)) in trace_values.iter_mut().tuples().enumerate() { + let row = i / REPETITIONS; + + *a = if self.uses_transition_constraints { + F::from_canonical_usize(i) + } else { + rng.gen() + }; + *b = if self.uses_boundary_constraints && row == 0 { + a.square() + F::one() + } else { + rng.gen() + }; + *c = a.exp_u64(self.degree - 1) * *b; + + if !valid { + // make it invalid + *c *= F::two(); + } + } + RowMajorMatrix::new(trace_values, TRACE_WIDTH) + } +} impl BaseAir for MulAir { fn width(&self) -> usize { @@ -36,40 +92,104 @@ impl Air for MulAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); let main_local = main.row_slice(0); + let main_next = main.row_slice(1); for i in 0..REPETITIONS { let start = i * 3; let a = main_local[start]; let b = main_local[start + 1]; let c = main_local[start + 2]; - builder.assert_zero(a * b - c); + builder.assert_zero(a.into().exp_u64(self.degree - 1) * b - c); + if self.uses_boundary_constraints { + builder + .when_first_row() + .assert_eq(a * a + AB::Expr::one(), b); + } + if self.uses_transition_constraints { + let next_a = main_next[start]; + builder + .when_transition() + .assert_eq(a + AB::Expr::from_canonical_usize(REPETITIONS), next_a); + } } } } -fn random_valid_trace(rows: usize) -> RowMajorMatrix +fn do_test( + config: SC, + air: MulAir, + log_height: usize, + challenger: SC::Challenger, +) -> Result<(), VerificationError> where - Standard: Distribution, + SC::Challenger: Clone, + Standard: Distribution>, { - let mut rng = thread_rng(); - let mut trace_values = vec![F::default(); rows * TRACE_WIDTH]; - for (a, b, c) in trace_values.iter_mut().tuples() { - *a = rng.gen(); - *b = rng.gen(); - *c = *a * *b; - } - RowMajorMatrix::new(trace_values, TRACE_WIDTH) + let trace = air.random_valid_trace(log_height, true); + + let mut p_challenger = challenger.clone(); + let proof = prove(&config, &air, &mut p_challenger, trace); + + let serialized_proof = postcard::to_allocvec(&proof).expect("unable to serialize proof"); + tracing::debug!("serialized_proof len: {} bytes", serialized_proof.len()); + + let deserialized_proof = + postcard::from_bytes(&serialized_proof).expect("unable to deserialize proof"); + + let mut v_challenger = challenger.clone(); + verify(&config, &air, &mut v_challenger, &deserialized_proof) +} + +fn do_test_bb_trivial(degree: u64, log_n: usize) -> Result<(), VerificationError> { + type Val = BabyBear; + type Challenge = BinomialExtensionField; + + type Perm = Poseidon2; + let perm = Perm::new_from_rng(8, 22, DiffusionMatrixBabybear, &mut thread_rng()); + + type Dft = Radix2DitParallel; + let dft = Dft {}; + + type Challenger = DuplexChallenger; + + type Pcs = TrivialPcs; + let pcs = p3_commit::testing::TrivialPcs { + dft, + log_n, + _phantom: PhantomData, + }; + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); + + let air = MulAir { + degree, + ..Default::default() + }; + + do_test(config, air, 1 << log_n, Challenger::new(perm)) } #[test] -fn test_prove_baby_bear() -> Result<(), VerificationError> { - Registry::default() - .with(EnvFilter::from_default_env()) - .with(ForestLayer::default()) - .init(); +fn prove_bb_trivial_deg2() -> Result<(), VerificationError> { + do_test_bb_trivial(2, 10) +} - const HEIGHT: usize = 1 << 6; +#[test] +fn prove_bb_trivial_deg3() -> Result<(), VerificationError> { + do_test_bb_trivial(3, 10) +} +#[test] +fn prove_bb_trivial_deg4() -> Result<(), VerificationError> { + do_test_bb_trivial(4, 10) +} + +fn do_test_bb_twoadic( + log_blowup: usize, + degree: u64, + log_n: usize, +) -> Result<(), VerificationError> { type Val = BabyBear; type Challenge = BinomialExtensionField; @@ -100,34 +220,107 @@ fn test_prove_baby_bear() -> Result<(), VerificationError> { type Challenger = DuplexChallenger; let fri_config = FriConfig { - log_blowup: 1, + log_blowup, num_queries: 40, proof_of_work_bits: 8, mmcs: challenge_mmcs, }; - type Pcs = - TwoAdicFriPcs>; - let pcs = Pcs::new(fri_config, dft, val_mmcs); + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(log_n, dft, val_mmcs, fri_config); - type MyConfig = StarkConfig; - let config = StarkConfig::new(pcs); + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); - let mut challenger = Challenger::new(perm.clone()); - let trace = random_valid_trace::(HEIGHT); - let proof = prove::(&config, &MulAir, &mut challenger, trace); + let air = MulAir { + degree, + ..Default::default() + }; - let serialized_proof = postcard::to_allocvec(&proof).expect("unable to serialize proof"); - tracing::debug!("serialized_proof len: {} bytes", serialized_proof.len()); + do_test(config, air, 1 << log_n, Challenger::new(perm)) +} - let deserialized_proof = - postcard::from_bytes(&serialized_proof).expect("unable to deserialize proof"); +#[test] +fn prove_bb_twoadic_deg2() -> Result<(), VerificationError> { + do_test_bb_twoadic(1, 2, 10) +} + +#[test] +fn prove_bb_twoadic_deg3() -> Result<(), VerificationError> { + do_test_bb_twoadic(1, 3, 10) +} + +#[test] +fn prove_bb_twoadic_deg4() -> Result<(), VerificationError> { + do_test_bb_twoadic(2, 4, 10) +} + +#[test] +fn prove_bb_twoadic_deg5() -> Result<(), VerificationError> { + do_test_bb_twoadic(2, 5, 10) +} + +fn do_test_m31_circle( + log_blowup: usize, + degree: u64, + log_n: usize, +) -> Result<(), VerificationError> { + type Val = Mersenne31; + // type Challenge = BinomialExtensionField; + type Challenge = Mersenne31; - let mut challenger = Challenger::new(perm); - verify(&config, &MulAir, &mut challenger, &deserialized_proof) + type ByteHash = Keccak256Hash; + type FieldHash = SerializingHasher32; + let byte_hash = ByteHash {}; + let field_hash = FieldHash::new(byte_hash); + + type MyCompress = CompressionFunctionFromHasher; + let compress = MyCompress::new(byte_hash); + + type ValMmcs = FieldMerkleTreeMmcs; + let val_mmcs = ValMmcs::new(field_hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Challenger = SerializingChallenger32>; + + let _fri_config = FriConfig { + log_blowup, + num_queries: 40, + proof_of_work_bits: 8, + mmcs: challenge_mmcs, + }; + + type Pcs = CirclePcs; + let pcs = Pcs { + log_blowup: 1, + cfft: Cfft::default(), + mmcs: val_mmcs, + }; + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); + + let air = MulAir { + degree, + uses_boundary_constraints: true, + uses_transition_constraints: true, + }; + + do_test( + config, + air, + 1 << log_n, + Challenger::from_hasher(vec![], byte_hash), + ) +} + +#[test] +fn prove_m31_circle_deg2() -> Result<(), VerificationError> { + do_test_m31_circle(1, 2, 12) } #[test] -#[ignore] // TODO: Not ready yet. -fn test_prove_mersenne_31() { - todo!() +fn prove_m31_circle_deg3() -> Result<(), VerificationError> { + do_test_m31_circle(1, 3, 14) } diff --git a/util/src/linear_map.rs b/util/src/linear_map.rs index 165a1289..1c30d023 100644 --- a/util/src/linear_map.rs +++ b/util/src/linear_map.rs @@ -4,7 +4,8 @@ use core::mem; use crate::VecExt; /// O(n) Vec-backed map for keys that only implement Eq. -/// Only use this for a very small number of keys. +/// Only use this for a very small number of keys, operations +/// on it can easily become O(n^2). pub struct LinearMap(Vec<(K, V)>); impl Default for LinearMap { @@ -23,6 +24,7 @@ impl LinearMap { pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { self.0.iter_mut().find(|(kk, _)| kk == k).map(|(_, v)| v) } + /// This is O(n), because we check for an existing entry. pub fn insert(&mut self, k: K, mut v: V) -> Option { if let Some(vv) = self.get_mut(&k) { mem::swap(&mut v, vv); @@ -47,6 +49,7 @@ impl LinearMap { } impl FromIterator<(K, V)> for LinearMap { + /// This calls `insert` in a loop, so is O(n^2)!! fn from_iter>(iter: T) -> Self { let mut me = LinearMap::default(); for (k, v) in iter {