Skip to content

Commit

Permalink
Faster random_mod (#703)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvdplm authored Jan 3, 2025
1 parent 6993ac0 commit a0e1b3a
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 16 deletions.
122 changes: 120 additions & 2 deletions benches/uint.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,125 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use crypto_bigint::{
Limb, NonZero, Odd, Random, RandomMod, Reciprocal, Uint, U128, U2048, U256, U4096,
Limb, NonZero, Odd, Random, RandomBits, RandomMod, Reciprocal, Uint, U1024, U128, U2048, U256,
U4096, U512,
};
use rand_core::OsRng;
use rand_chacha::ChaCha8Rng;
use rand_core::{OsRng, RngCore, SeedableRng};

/// Builds the deterministic ChaCha8 generator used by the benchmarks.
///
/// Every call returns a generator seeded with the same fixed 32 bytes, so each
/// benchmark that starts from a fresh `make_rng()` sees the same random stream.
fn make_rng() -> ChaCha8Rng {
    let seed = *b"01234567890123456789012345678901";
    ChaCha8Rng::from_seed(seed)
}

fn bench_random(c: &mut Criterion) {
let mut group = c.benchmark_group("bounded random");

let mut rng = make_rng();
group.bench_function("random_mod, U1024", |b| {
let bound = U1024::random(&mut rng);
let bound_nz = NonZero::new(bound).unwrap();
b.iter(|| black_box(U1024::random_mod(&mut rng, &bound_nz)));
});

let mut rng = make_rng();
group.bench_function("random_bits, U1024", |b| {
let bound = U1024::random(&mut rng);
let bound_bits = bound.bits_vartime();
b.iter(|| {
let mut r = U1024::random_bits(&mut rng, bound_bits);
while r >= bound {
r = U1024::random_bits(&mut rng, bound_bits);
}
black_box(r)
});
});

let mut rng = make_rng();
group.bench_function("random_mod, U1024, small bound", |b| {
let bound = U1024::from_u64(rng.next_u64());
let bound_nz = NonZero::new(bound).unwrap();
b.iter(|| black_box(U1024::random_mod(&mut rng, &bound_nz)));
});

let mut rng = make_rng();
group.bench_function("random_bits, U1024, small bound", |b| {
let bound = U1024::from_u64(rng.next_u64());
let bound_bits = bound.bits_vartime();
b.iter(|| {
let mut r = U1024::random_bits(&mut rng, bound_bits);
while r >= bound {
r = U1024::random_bits(&mut rng, bound_bits);
}
black_box(r)
});
});

let mut rng = make_rng();
group.bench_function("random_mod, U1024, 512 bit bound low", |b| {
let bound = U512::random(&mut rng);
let bound = U1024::from((bound, U512::ZERO));
let bound_nz = NonZero::new(bound).unwrap();
b.iter(|| black_box(U1024::random_mod(&mut rng, &bound_nz)));
});

let mut rng = make_rng();
group.bench_function("random_bits, U1024, 512 bit bound low", |b| {
let bound = U512::random(&mut rng);
let bound = U1024::from((bound, U512::ZERO));
let bound_bits = bound.bits_vartime();
b.iter(|| {
let mut r = U1024::random_bits(&mut rng, bound_bits);
while r >= bound {
r = U1024::random_bits(&mut rng, bound_bits);
}
black_box(r)
});
});

let mut rng = make_rng();
group.bench_function("random_mod, U1024, 512 bit bound hi", |b| {
let bound = U512::random(&mut rng);
let bound = U1024::from((U512::ZERO, bound));
let bound_nz = NonZero::new(bound).unwrap();
b.iter(|| black_box(U1024::random_mod(&mut rng, &bound_nz)));
});

let mut rng = make_rng();
group.bench_function("random_bits, U1024, 512 bit bound hi", |b| {
let bound = U512::random(&mut rng);
let bound = U1024::from((U512::ZERO, bound));
let bound_bits = bound.bits_vartime();
b.iter(|| {
let mut r = U1024::random_bits(&mut rng, bound_bits);
while r >= bound {
r = U1024::random_bits(&mut rng, bound_bits);
}
black_box(r)
});
});

// Slow case: the hi limb is just `2`
let mut rng = make_rng();
group.bench_function("random_mod, U1024, tiny high limb", |b| {
let hex_1024 = "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000291A6B42D1C7D2A7184D13E36F65773BBEFB4FA7996101300D49F09962A361F00";
let modulus = U1024::from_be_hex(hex_1024);
let modulus_nz = NonZero::new(modulus).unwrap();
b.iter(|| black_box(U1024::random_mod(&mut rng, &modulus_nz)));
});

// Slow case: the hi limb is just `2`
let mut rng = make_rng();
group.bench_function("random_bits, U1024, tiny high limb", |b| {
let hex_1024 = "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000291A6B42D1C7D2A7184D13E36F65773BBEFB4FA7996101300D49F09962A361F00";
let bound = U1024::from_be_hex(hex_1024);
let bound_bits = bound.bits_vartime();
b.iter(|| {
let mut r = U1024::random_bits(&mut rng, bound_bits);
while r >= bound {
r = U1024::random_bits(&mut rng, bound_bits);
}
});
});
}

fn bench_mul(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");
Expand Down Expand Up @@ -370,6 +487,7 @@ fn bench_sqrt(c: &mut Criterion) {

criterion_group!(
benches,
bench_random,
bench_mul,
bench_division,
bench_gcd,
Expand Down
36 changes: 22 additions & 14 deletions src/uint/rand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,42 +93,50 @@ pub(super) fn random_mod_core<T>(
modulus: &NonZero<T>,
n_bits: u32,
) where
T: AsMut<[Limb]> + ConstantTimeLess + Zero,
T: AsMut<[Limb]> + AsRef<[Limb]> + ConstantTimeLess + Zero,
{
let n_bytes = ((n_bits + 7) / 8) as usize;
#[cfg(target_pointer_width = "64")]
let mut next_word = || rng.next_u64();
#[cfg(target_pointer_width = "32")]
let mut next_word = || rng.next_u32();

let n_limbs = n_bits.div_ceil(Limb::BITS) as usize;
let hi_bytes = n_bytes - (n_limbs - 1) * Limb::BYTES;

let mut bytes = Limb::ZERO.to_le_bytes();
let hi_word_modulus = modulus.as_ref().as_ref()[n_limbs - 1].0;
let mask = !0 >> hi_word_modulus.leading_zeros();
let mut hi_word = next_word() & mask;

loop {
while hi_word > hi_word_modulus {
hi_word = next_word() & mask;
}
// Set high limb
n.as_mut()[n_limbs - 1] = Limb::from_le_bytes(hi_word.to_le_bytes());
// Set low limbs
for i in 0..n_limbs - 1 {
rng.fill_bytes(bytes.as_mut());
// Need to deserialize from little-endian to make sure that two 32-bit limbs
// deserialized sequentially are equal to one 64-bit limb produced from the same
// byte stream.
n.as_mut()[i] = Limb::from_le_bytes(bytes);
n.as_mut()[i] = Limb::from_le_bytes(next_word().to_le_bytes());
}

// Generate the high limb which may need to only be filled partially.
bytes = Limb::ZERO.to_le_bytes();
rng.fill_bytes(&mut bytes[..hi_bytes]);
n.as_mut()[n_limbs - 1] = Limb::from_le_bytes(bytes);

// If the high limb is equal to the modulus' high limb, it's still possible
// that the full uint is too big so we check and repeat if it is.
if n.ct_lt(modulus).into() {
break;
}
hi_word = next_word() & mask;
}
}

#[cfg(test)]
mod tests {
use crate::{Limb, NonZero, RandomBits, RandomMod, U256};
use rand_chacha::ChaCha8Rng;
use rand_core::SeedableRng;

#[test]
fn random_mod() {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);
let mut rng = ChaCha8Rng::seed_from_u64(1);

// Ensure `random_mod` runs in a reasonable amount of time
let modulus = NonZero::new(U256::from(42u8)).unwrap();
Expand All @@ -148,7 +156,7 @@ mod tests {

#[test]
fn random_bits() {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);
let mut rng = ChaCha8Rng::seed_from_u64(1);

let lower_bound = 16;

Expand Down

0 comments on commit a0e1b3a

Please sign in to comment.