diff --git a/Cargo.toml b/Cargo.toml index 89dd4e40..91cb0618 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,11 @@ name = "boxed_residue" harness = false required-features = ["alloc"] +[[bench]] +name = "boxed_uint" +harness = false +required-features = ["alloc"] + [[bench]] name = "dyn_residue" harness = false diff --git a/benches/boxed_uint.rs b/benches/boxed_uint.rs new file mode 100644 index 00000000..b34eaa33 --- /dev/null +++ b/benches/boxed_uint.rs @@ -0,0 +1,48 @@ +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use crypto_bigint::BoxedUint; +use rand_core::OsRng; + +/// Size of `BoxedUint` to use in benchmark. +const UINT_BITS: u32 = 4096; + +fn bench_shifts(c: &mut Criterion) { + let mut group = c.benchmark_group("bit shifts"); + + group.bench_function("shl_vartime", |b| { + b.iter_batched( + || BoxedUint::random(&mut OsRng, UINT_BITS), + |x| black_box(x.shl_vartime(UINT_BITS / 2 + 10)), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shl", |b| { + b.iter_batched( + || BoxedUint::random(&mut OsRng, UINT_BITS), + |x| x.shl(UINT_BITS / 2 + 10), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shr_vartime", |b| { + b.iter_batched( + || BoxedUint::random(&mut OsRng, UINT_BITS), + |x| black_box(x.shr_vartime(UINT_BITS / 2 + 10)), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shr", |b| { + b.iter_batched( + || BoxedUint::random(&mut OsRng, UINT_BITS), + |x| x.shr(UINT_BITS / 2 + 10), + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +criterion_group!(benches, bench_shifts); + +criterion_main!(benches); diff --git a/benches/uint.rs b/benches/uint.rs index 944d8c24..3688a2d9 100644 --- a/benches/uint.rs +++ b/benches/uint.rs @@ -1,11 +1,10 @@ -use criterion::{ - black_box, criterion_group, criterion_main, measurement::Measurement, BatchSize, - BenchmarkGroup, Criterion, -}; -use crypto_bigint::{Limb, NonZero, Random, Reciprocal, U128, U2048, U256}; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use crypto_bigint::{Limb, NonZero, Random, Reciprocal, Uint, U128, U2048, U256}; use rand_core::OsRng; -fn bench_division(group: &mut BenchmarkGroup<'_, M>) { +fn bench_division(c: &mut Criterion) { + let mut group = c.benchmark_group("wrapping ops"); + group.bench_function("div/rem, U256/U128, full size", |b| { b.iter_batched( || { @@ -69,9 +68,13 @@ fn bench_division(group: &mut BenchmarkGroup<'_, M>) { BatchSize::SmallInput, ) }); + + group.finish(); } -fn bench_shifts(group: &mut BenchmarkGroup<'_, M>) { +fn bench_shl(c: &mut Criterion) { + let mut group = c.benchmark_group("left shift"); + group.bench_function("shl_vartime, small, U2048", |b| { b.iter_batched(|| U2048::ONE, |x| x.shl_vartime(10), BatchSize::SmallInput) }); @@ -84,16 +87,54 @@ fn bench_shifts(group: &mut BenchmarkGroup<'_, M>) { ) }); + group.bench_function("shl_vartime_wide, large, U2048", |b| { + b.iter_batched( + || (U2048::ONE, U2048::ONE), + |x| Uint::shl_vartime_wide(x, 1024 + 10), + BatchSize::SmallInput, + ) + }); + group.bench_function("shl, U2048", |b| { b.iter_batched(|| U2048::ONE, |x| x.shl(1024 + 10), BatchSize::SmallInput) }); + group.finish(); +} + +fn bench_shr(c: &mut Criterion) { + let mut group = c.benchmark_group("right shift"); + + group.bench_function("shr_vartime, small, U2048", |b| { + b.iter_batched(|| U2048::ONE, |x| x.shr_vartime(10), BatchSize::SmallInput) + }); + + group.bench_function("shr_vartime, large, U2048", |b| { + b.iter_batched( + || U2048::ONE, + |x| x.shr_vartime(1024 + 10), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shr_vartime_wide, large, U2048", |b| { + b.iter_batched( + || (U2048::ONE, U2048::ONE), + |x| Uint::shr_vartime_wide(x, 1024 + 10), + BatchSize::SmallInput, + ) + }); + group.bench_function("shr, U2048", |b| { b.iter_batched(|| U2048::ONE, |x| x.shr(1024 + 10), BatchSize::SmallInput) }); + + group.finish(); } -fn bench_inv_mod(group: &mut BenchmarkGroup<'_, M>) { +fn bench_inv_mod(c: &mut Criterion) { + let mut group = c.benchmark_group("modular ops"); + group.bench_function("inv_odd_mod, U256", |b| { b.iter_batched( || { @@ -144,21 +185,10 @@ fn bench_inv_mod(group: &mut BenchmarkGroup<'_, M>) { BatchSize::SmallInput, ) }); -} -fn bench_wrapping_ops(c: &mut Criterion) { - let mut group = c.benchmark_group("wrapping ops"); - bench_division(&mut group); - group.finish(); -} - -fn bench_modular_ops(c: &mut Criterion) { - let mut group = c.benchmark_group("modular ops"); - bench_shifts(&mut group); - bench_inv_mod(&mut group); group.finish(); } -criterion_group!(benches, bench_wrapping_ops, bench_modular_ops); +criterion_group!(benches, bench_shl, bench_shr, bench_division, bench_inv_mod); criterion_main!(benches); diff --git a/src/lib.rs b/src/lib.rs index 5b1b1fc4..ecf524cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,7 +45,7 @@ //! U256::from_be_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"); //! //! // Compute `MODULUS` shifted right by 1 at compile time -//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1); +//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1).0; //! ``` //! //! Note that large constant computations may accidentally trigger a the `const_eval_limit` of the compiler. diff --git a/src/limb/bit_not.rs b/src/limb/bit_not.rs index 26676d59..6d728d45 100644 --- a/src/limb/bit_not.rs +++ b/src/limb/bit_not.rs @@ -5,6 +5,7 @@ use core::ops::Not; impl Limb { /// Calculates `!a`. + #[inline(always)] pub const fn not(self) -> Self { Limb(!self.0) } diff --git a/src/limb/bit_or.rs b/src/limb/bit_or.rs index f863ac0d..340f4f76 100644 --- a/src/limb/bit_or.rs +++ b/src/limb/bit_or.rs @@ -5,6 +5,7 @@ use core::ops::{BitOr, BitOrAssign}; impl Limb { /// Calculates `a | b`. + #[inline(always)] pub const fn bitor(self, rhs: Self) -> Self { Limb(self.0 | rhs.0) } diff --git a/src/limb/bit_xor.rs b/src/limb/bit_xor.rs index a5078229..7c04e7b7 100644 --- a/src/limb/bit_xor.rs +++ b/src/limb/bit_xor.rs @@ -5,6 +5,7 @@ use core::ops::BitXor; impl Limb { /// Calculates `a ^ b`. + #[inline(always)] pub const fn bitxor(self, rhs: Self) -> Self { Limb(self.0 ^ rhs.0) } diff --git a/src/limb/bits.rs b/src/limb/bits.rs index 1c7674f4..4553137b 100644 --- a/src/limb/bits.rs +++ b/src/limb/bits.rs @@ -2,21 +2,25 @@ use super::Limb; impl Limb { /// Calculate the number of bits needed to represent this number. + #[inline(always)] pub const fn bits(self) -> u32 { Limb::BITS - self.0.leading_zeros() } /// Calculate the number of leading zeros in the binary representation of this number. + #[inline(always)] pub const fn leading_zeros(self) -> u32 { self.0.leading_zeros() } /// Calculate the number of trailing zeros in the binary representation of this number. + #[inline(always)] pub const fn trailing_zeros(self) -> u32 { self.0.trailing_zeros() } /// Calculate the number of trailing ones the binary representation of this number. + #[inline(always)] pub const fn trailing_ones(self) -> u32 { self.0.trailing_ones() } diff --git a/src/limb/mul.rs b/src/limb/mul.rs index 7f8b0845..1ea73b4e 100644 --- a/src/limb/mul.rs +++ b/src/limb/mul.rs @@ -17,7 +17,7 @@ impl Limb { } /// Perform saturating multiplication. - #[inline] + #[inline(always)] pub const fn saturating_mul(&self, rhs: Self) -> Self { Limb(self.0.saturating_mul(rhs.0)) } diff --git a/src/limb/shl.rs b/src/limb/shl.rs index 03e4a103..0e655387 100644 --- a/src/limb/shl.rs +++ b/src/limb/shl.rs @@ -10,6 +10,12 @@ impl Limb { pub const fn shl(self, shift: u32) -> Self { Limb(self.0 << shift) } + + /// Computes `self << 1` and return the result and the carry (0 or 1). + #[inline(always)] + pub(crate) const fn shl1(self) -> (Self, Self) { + (Self(self.0 << 1), Self(self.0 >> Self::HI_BIT)) + } } impl Shl for Limb { diff --git a/src/limb/shr.rs b/src/limb/shr.rs index a91c65d5..10fca6dc 100644 --- a/src/limb/shr.rs +++ b/src/limb/shr.rs @@ -10,6 +10,12 @@ impl Limb { pub const fn shr(self, shift: u32) -> Self { Limb(self.0 >> shift) } + + /// Computes `self >> 1` and return the result and the carry (0 or `1 << HI_BIT`). + #[inline(always)] + pub(crate) const fn shr1(self) -> (Self, Self) { + (Self(self.0 >> 1), Self(self.0 << Self::HI_BIT)) + } } impl Shr for Limb { diff --git a/src/modular/div_by_2.rs b/src/modular/div_by_2.rs index 278f3dda..12d82ed7 100644 --- a/src/modular/div_by_2.rs +++ b/src/modular/div_by_2.rs @@ -18,7 +18,7 @@ pub(crate) fn div_by_2(a: &Uint, modulus: &Uint { diff --git a/src/uint/boxed/bits.rs b/src/uint/boxed/bits.rs index fa681d22..e60ebae0 100644 --- a/src/uint/boxed/bits.rs +++ b/src/uint/boxed/bits.rs @@ -84,7 +84,7 @@ mod tests { fn uint_with_bits_at(positions: &[u32]) -> BoxedUint { let mut result = BoxedUint::zero_with_precision(256); for &pos in positions { - result |= BoxedUint::one_with_precision(256).shl_vartime(pos); + result |= BoxedUint::one_with_precision(256).shl_vartime(pos).unwrap(); } result } diff --git a/src/uint/boxed/div.rs b/src/uint/boxed/div.rs index c8a66213..d82bb982 100644 --- a/src/uint/boxed/div.rs +++ b/src/uint/boxed/div.rs @@ -37,7 +37,8 @@ impl BoxedUint { let mb = rhs.bits(); let mut bd = self.bits_precision() - mb; let mut rem = self.clone(); - let mut c = rhs.shl_vartime(bd); + // Will not overflow since `bd < bits_precision` + let mut c = rhs.shl_vartime(bd).expect("shift within range"); loop { let (r, borrow) = rem.sbb(&c, Limb::ZERO); @@ -77,7 +78,7 @@ impl BoxedUint { let bits_precision = self.bits_precision(); let mut rem = self.clone(); let mut quo = Self::zero_with_precision(bits_precision); - let mut c = rhs.shl(bits_precision - mb); + let (mut c, _overflow) = rhs.shl(bits_precision - mb); let mut i = bits_precision; let mut done = Choice::from(0u8); @@ -110,7 +111,8 @@ impl BoxedUint { let mut bd = self.bits_precision() - mb; let mut remainder = self.clone(); let mut quotient = Self::zero_with_precision(self.bits_precision()); - let mut c = rhs.shl_vartime(bd); + // Will not overflow since `bd < bits_precision` + let mut c = rhs.shl_vartime(bd).expect("shift within range"); loop { let (mut r, borrow) = remainder.sbb(&c, Limb::ZERO); diff --git a/src/uint/boxed/inv_mod.rs b/src/uint/boxed/inv_mod.rs index d4c063c5..d873ef57 100644 --- a/src/uint/boxed/inv_mod.rs +++ b/src/uint/boxed/inv_mod.rs @@ -11,7 +11,7 @@ impl BoxedUint { // Decompose `modulus = s * 2^k` where `s` is odd let k = modulus.trailing_zeros(); - let s = modulus.shr(k); + let s = modulus >> k; // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` @@ -26,7 +26,7 @@ impl BoxedUint { let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists // This part is mod 2^k - let mask = Self::one().shl(k).wrapping_sub(&Self::one()); + let mask = (Self::one() << k).wrapping_sub(&Self::one()); let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask); // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`, @@ -126,9 +126,9 @@ impl BoxedUint { let cyy = new_u.conditional_adc_assign(modulus, cy); debug_assert!(bool::from(cy.ct_eq(&cyy))); - let (new_a, overflow) = a.shr1_with_overflow(); - debug_assert!(bool::from(!modulus_is_odd | !overflow)); - let (mut new_u, cy) = new_u.shr1_with_overflow(); + let (new_a, carry) = a.shr1_with_carry(); + debug_assert!(bool::from(!modulus_is_odd | !carry)); + let (mut new_u, cy) = new_u.shr1_with_carry(); let cy = new_u.conditional_adc_assign(&m1hp, cy); debug_assert!(bool::from(!modulus_is_odd | !cy)); diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index 7fa9d4af..8daee882 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -1,92 +1,106 @@ //! [`BoxedUint`] bitwise left shift operations. -use crate::{BoxedUint, CtChoice, Limb, Word}; +use crate::{BoxedUint, Limb}; use core::ops::{Shl, ShlAssign}; use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { /// Computes `self << shift`. - /// Returns zero if `shift >= Self::BITS`. - pub fn shl(&self, shift: u32) -> Self { + /// + /// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`, + /// or the result and a falsy `Choice` otherwise. + pub fn shl(&self, shift: u32) -> (Self, Choice) { + // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < bits_precision`). + let shift_bits = u32::BITS - (self.bits_precision() - 1).leading_zeros(); let overflow = !shift.ct_lt(&self.bits_precision()); let shift = shift % self.bits_precision(); - let log2_bits = u32::BITS - self.bits_precision().leading_zeros(); let mut result = self.clone(); + let mut temp = self.clone(); - for i in 0..log2_bits { + for i in 0..shift_bits { let bit = Choice::from(((shift >> i) & 1) as u8); - result = Self::conditional_select(&result, &result.shl_vartime(1 << i), bit); + temp.set_to_zero(); + // Will not overflow by construction + result + .shl_vartime_into(&mut temp, 1 << i) + .expect("shift within range"); + result.conditional_assign(&temp, bit); } - Self::conditional_select( - &result, - &Self::zero_with_precision(self.bits_precision()), - overflow, - ) + result.conditional_set_to_zero(overflow); + + (result, overflow) } - /// Computes `self << shift`. + /// Computes `self << shift` and writes the result into `dest`. + /// Returns `None` if `shift >= self.bits_precision()`. + /// + /// WARNING: for performance reasons, `dest` is assumed to be pre-zeroized. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect to `self`. #[inline(always)] - pub fn shl_vartime(&self, shift: u32) -> Self { - let nlimbs = self.nlimbs(); - let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); - + fn shl_vartime_into(&self, dest: &mut Self, shift: u32) -> Option<()> { if shift >= self.bits_precision() { - return Self { limbs }; + return None; } + let nlimbs = self.nlimbs(); let shift_num = (shift / Limb::BITS) as usize; let rem = shift % Limb::BITS; - let mut i = nlimbs; - while i > shift_num { - i -= 1; - limbs[i] = self.limbs[i - shift_num]; + for i in shift_num..nlimbs { + dest.limbs[i] = self.limbs[i - shift_num]; } - let (new_lower, _carry) = (Self { limbs }).shl_limb(rem); - new_lower - } + if rem == 0 { + return Some(()); + } - /// Computes `self << shift` where `0 <= shift < Limb::BITS`, - /// returning the result and the carry. - #[inline(always)] - pub(crate) fn shl_limb(&self, shift: u32) -> (Self, Limb) { - let nlimbs = self.nlimbs(); - let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); - - let nz = CtChoice::from_u32_nonzero(shift); - let lshift = shift; - let rshift = nz.if_true_u32(Limb::BITS - shift); - let carry = nz.if_true_word(self.limbs[nlimbs - 1].0.wrapping_shr(Word::BITS - shift)); - - let mut i = nlimbs - 1; - while i > 0 { - let mut limb = self.limbs[i].0 << lshift; - let hi = self.limbs[i - 1].0 >> rshift; - limb |= nz.if_true_word(hi); - limbs[i] = Limb(limb); - i -= 1 + let mut carry = Limb::ZERO; + + for i in shift_num..nlimbs { + let shifted = dest.limbs[i].shl(rem); + let new_carry = dest.limbs[i].shr(Limb::BITS - rem); + dest.limbs[i] = shifted.bitor(carry); + carry = new_carry; } - limbs[0] = Limb(self.limbs[0].0 << lshift); - (Self { limbs }, Limb(carry)) + Some(()) + } + + /// Computes `self << shift`. + /// Returns `None` if `shift >= self.bits_precision()`. + /// + /// NOTE: this operation is variable time with respect to `shift` *ONLY*. + /// + /// When used with a fixed `shift`, this function is constant-time with respect to `self`. + #[inline(always)] + pub fn shl_vartime(&self, shift: u32) -> Option { + let mut result = Self::zero_with_precision(self.bits_precision()); + let success = self.shl_vartime_into(&mut result, shift); + success.map(|_| result) } - /// Computes `self >> 1` in constant-time. + /// Computes `self << 1` in constant-time. pub(crate) fn shl1(&self) -> Self { - // TODO(tarcieri): optimized implementation - self.shl_vartime(1) + let mut ret = self.clone(); + ret.shl1_assign(); + ret } - /// Computes `self >> 1` in-place in constant-time. + /// Computes `self << 1` in-place in constant-time. pub(crate) fn shl1_assign(&mut self) { - // TODO(tarcieri): optimized implementation - *self = self.shl1(); + let mut carry = self.limbs[0].0 >> Limb::HI_BIT; + self.limbs[0].shl_assign(1); + for i in 1..self.limbs.len() { + let new_carry = self.limbs[i].0 >> Limb::HI_BIT; + self.limbs[i].shl_assign(1); + self.limbs[i].0 |= carry; + carry = new_carry + } } } @@ -94,7 +108,7 @@ impl Shl for BoxedUint { type Output = BoxedUint; fn shl(self, shift: u32) -> BoxedUint { - Self::shl(&self, shift) + <&BoxedUint as Shl>::shl(&self, shift) } } @@ -102,7 +116,9 @@ impl Shl for &BoxedUint { type Output = BoxedUint; fn shl(self, shift: u32) -> BoxedUint { - self.shl(shift) + let (result, overflow) = self.shl(shift); + assert!(!bool::from(overflow), "attempt to shift left with overflow"); + result } } @@ -117,15 +133,35 @@ impl ShlAssign for BoxedUint { mod tests { use super::BoxedUint; + #[test] + fn shl1_assign() { + let mut n = BoxedUint::from(0x3c442b21f19185fe433f0a65af902b8fu128); + let n_shl1 = BoxedUint::from(0x78885643e3230bfc867e14cb5f20571eu128); + n.shl1_assign(); + assert_eq!(n, n_shl1); + } + + #[test] + fn shl() { + let one = BoxedUint::one_with_precision(128); + + assert_eq!(BoxedUint::from(2u8), one.shl(1).0); + assert_eq!(BoxedUint::from(4u8), one.shl(2).0); + assert_eq!( + BoxedUint::from(0x80000000000000000u128), + one.shl_vartime(67).unwrap() + ); + } + #[test] fn shl_vartime() { let one = BoxedUint::one_with_precision(128); - assert_eq!(BoxedUint::from(2u8), one.shl_vartime(1)); - assert_eq!(BoxedUint::from(4u8), one.shl_vartime(2)); + assert_eq!(BoxedUint::from(2u8), one.shl_vartime(1).unwrap()); + assert_eq!(BoxedUint::from(4u8), one.shl_vartime(2).unwrap()); assert_eq!( BoxedUint::from(0x80000000000000000u128), - one.shl_vartime(67) + one.shl_vartime(67).unwrap() ); } } diff --git a/src/uint/boxed/shr.rs b/src/uint/boxed/shr.rs index ba5a3487..ac1acee5 100644 --- a/src/uint/boxed/shr.rs +++ b/src/uint/boxed/shr.rs @@ -5,69 +5,86 @@ use core::ops::{Shr, ShrAssign}; use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { - /// Computes `self << shift`. - /// Returns zero if `shift >= Self::BITS`. - pub fn shr(&self, shift: u32) -> Self { + /// Computes `self >> shift`. + /// + /// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`, + /// or the result and a falsy `Choice` otherwise. + pub fn shr(&self, shift: u32) -> (Self, Choice) { + // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < bits_precision`). + let shift_bits = u32::BITS - (self.bits_precision() - 1).leading_zeros(); let overflow = !shift.ct_lt(&self.bits_precision()); let shift = shift % self.bits_precision(); - let log2_bits = u32::BITS - self.bits_precision().leading_zeros(); let mut result = self.clone(); + let mut temp = self.clone(); - for i in 0..log2_bits { + for i in 0..shift_bits { let bit = Choice::from(((shift >> i) & 1) as u8); - result = Self::conditional_select(&result, &result.shr_vartime(1 << i), bit); + temp.set_to_zero(); + // Will not overflow by construction + result + .shr_vartime_into(&mut temp, 1 << i) + .expect("shift within range"); + result.conditional_assign(&temp, bit); } - Self::conditional_select( - &result, - &Self::zero_with_precision(self.bits_precision()), - overflow, - ) + result.conditional_set_to_zero(overflow); + + (result, overflow) } /// Computes `self >> shift`. + /// Returns `None` if `shift >= self.bits_precision()`. + /// + /// WARNING: for performance reasons, `dest` is assumed to be pre-zeroized. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect to `self`. #[inline(always)] - pub fn shr_vartime(&self, shift: u32) -> Self { + fn shr_vartime_into(&self, dest: &mut Self, shift: u32) -> Option<()> { + if shift >= self.bits_precision() { + return None; + } + let nlimbs = self.nlimbs(); - let full_shifts = (shift / Limb::BITS) as usize; - let small_shift = shift & (Limb::BITS - 1); - let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); + let shift_num = (shift / Limb::BITS) as usize; + let rem = shift % Limb::BITS; + + for i in 0..nlimbs - shift_num { + dest.limbs[i] = self.limbs[i + shift_num]; + } - if shift > self.bits_precision() { - return Self { limbs }; + if rem == 0 { + return Some(()); } - let n = nlimbs - full_shifts; - let mut i = 0; - - if small_shift == 0 { - while i < n { - limbs[i] = Limb(self.limbs[i + full_shifts].0); - i += 1; - } - } else { - while i < n { - let mut lo = self.limbs[i + full_shifts].0 >> small_shift; - - if i < (nlimbs - 1) - full_shifts { - lo |= self.limbs[i + full_shifts + 1].0 << (Limb::BITS - small_shift); - } - - limbs[i] = Limb(lo); - i += 1; - } + for i in 0..nlimbs - shift_num - 1 { + let shifted = dest.limbs[i].shr(rem); + let carry = dest.limbs[i + 1].shl(Limb::BITS - rem); + dest.limbs[i] = shifted.bitor(carry); } + dest.limbs[nlimbs - shift_num - 1] = dest.limbs[nlimbs - shift_num - 1].shr(rem); + + Some(()) + } - Self { limbs } + /// Computes `self >> shift`. + /// Returns `None` if `shift >= self.bits_precision()`. + /// + /// NOTE: this operation is variable time with respect to `shift` *ONLY*. + /// + /// When used with a fixed `shift`, this function is constant-time with respect to `self`. + #[inline(always)] + pub fn shr_vartime(&self, shift: u32) -> Option { + let mut result = Self::zero_with_precision(self.bits_precision()); + let success = self.shr_vartime_into(&mut result, shift); + success.map(|_| result) } - /// Computes `self >> 1` in constant-time, returning a true [`Choice`] if the overflowing bit - /// was set, and a false [`Choice::FALSE`] otherwise. - pub(crate) fn shr1_with_overflow(&self) -> (Self, Choice) { + /// Computes `self >> 1` in constant-time, returning a true [`Choice`] + /// if the least significant bit was set, and a false [`Choice::FALSE`] otherwise. + pub(crate) fn shr1_with_carry(&self) -> (Self, Choice) { let carry = self.limbs[0].0 & 1; (self.shr1(), Choice::from(carry as u8)) } @@ -95,7 +112,7 @@ impl Shr for BoxedUint { type Output = BoxedUint; fn shr(self, shift: u32) -> BoxedUint { - Self::shr(&self, shift) + <&BoxedUint as Shr>::shr(&self, shift) } } @@ -103,7 +120,12 @@ impl Shr for &BoxedUint { type Output = BoxedUint; fn shr(self, shift: u32) -> BoxedUint { - self.shr(shift) + let (result, overflow) = self.shr(shift); + assert!( + !bool::from(overflow), + "attempt to shift right with overflow" + ); + result } } @@ -126,12 +148,21 @@ mod tests { assert_eq!(n, n_shr1); } + #[test] + fn shr() { + let n = BoxedUint::from(0x80000000000000000u128); + assert_eq!(BoxedUint::zero(), n.shr(68).0); + assert_eq!(BoxedUint::one(), n.shr(67).0); + assert_eq!(BoxedUint::from(2u8), n.shr(66).0); + assert_eq!(BoxedUint::from(4u8), n.shr(65).0); + } + #[test] fn shr_vartime() { let n = BoxedUint::from(0x80000000000000000u128); - assert_eq!(BoxedUint::zero(), n.shr_vartime(68)); - assert_eq!(BoxedUint::one(), n.shr_vartime(67)); - assert_eq!(BoxedUint::from(2u8), n.shr_vartime(66)); - assert_eq!(BoxedUint::from(4u8), n.shr_vartime(65)); + assert_eq!(BoxedUint::zero(), n.shr_vartime(68).unwrap()); + assert_eq!(BoxedUint::one(), n.shr_vartime(67).unwrap()); + assert_eq!(BoxedUint::from(2u8), n.shr_vartime(66).unwrap()); + assert_eq!(BoxedUint::from(4u8), n.shr_vartime(65).unwrap()); } } diff --git a/src/uint/div.rs b/src/uint/div.rs index ac50a206..b634a9c9 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -53,7 +53,9 @@ impl Uint { let mb = rhs.bits(); let mut rem = *self; let mut quo = Self::ZERO; - let mut c = rhs.shl(Self::BITS - mb); + // If there is overflow, it means `mb == 0`, so `rhs == 0`. + let (mut c, overflow) = rhs.shl(Self::BITS - mb); + let is_some = overflow.not(); let mut i = Self::BITS; let mut done = CtChoice::FALSE; @@ -73,7 +75,6 @@ impl Uint { quo = Self::ct_select(&quo.shl1(), &quo, done); } - let is_some = Limb(mb as Word).ct_is_nonzero(); quo = Self::ct_select(&Self::ZERO, &quo, is_some); (quo, rem, is_some) } @@ -93,7 +94,8 @@ impl Uint { let mut bd = Self::BITS - mb; let mut rem = *self; let mut quo = Self::ZERO; - let mut c = rhs.shl_vartime(bd); + let (mut c, overflow) = rhs.shl_vartime(bd); + let is_some = overflow.not(); loop { let (mut r, borrow) = rem.sbb(&c, Limb::ZERO); @@ -108,7 +110,6 @@ impl Uint { quo = quo.shl1(); } - let is_some = CtChoice::from_u32_nonzero(mb); quo = Self::ct_select(&Self::ZERO, &quo, is_some); (quo, rem, is_some) } @@ -125,7 +126,8 @@ impl Uint { let mb = rhs.bits_vartime(); let mut bd = Self::BITS - mb; let mut rem = *self; - let mut c = rhs.shl_vartime(bd); + let (mut c, overflow) = rhs.shl_vartime(bd); + let is_some = overflow.not(); loop { let (r, borrow) = rem.sbb(&c, Limb::ZERO); @@ -137,7 +139,6 @@ impl Uint { c = c.shr1(); } - let is_some = CtChoice::from_u32_nonzero(mb); (rem, is_some) } @@ -158,7 +159,7 @@ impl Uint { let (mut lower, mut upper) = lower_upper; // Factor of the modulus, split into two halves - let mut c = Self::shl_vartime_wide((*rhs, Uint::ZERO), bd); + let (mut c, _overflow) = Self::shl_vartime_wide((*rhs, Uint::ZERO), bd); loop { let (lower_sub, borrow) = lower.sbb(&c.0, Limb::ZERO); @@ -170,7 +171,8 @@ impl Uint { break; } bd -= 1; - c = Self::shr_vartime_wide(c, 1); + let (new_c, _overflow) = Self::shr_vartime_wide(c, 1); + c = new_c; } let is_some = CtChoice::from_u32_nonzero(mb); @@ -696,8 +698,8 @@ mod tests { fn div() { let mut rng = ChaChaRng::from_seed([7u8; 32]); for _ in 0..25 { - let num = U256::random(&mut rng).shr_vartime(128); - let den = U256::random(&mut rng).shr_vartime(128); + let (num, _) = U256::random(&mut rng).shr_vartime(128); + let (den, _) = U256::random(&mut rng).shr_vartime(128); let n = num.checked_mul(&den); if n.is_some().into() { let (q, _, is_some) = n.unwrap().const_div_rem(&den); @@ -808,7 +810,7 @@ mod tests { for _ in 0..25 { let num = U256::random(&mut rng); let k = rng.next_u32() % 256; - let den = U256::ONE.shl_vartime(k); + let (den, _) = U256::ONE.shl_vartime(k); let a = num.rem2k(k); let e = num.wrapping_rem(&den); diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index 28694720..236f0ac1 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -30,7 +30,8 @@ impl Uint { // b_{i+1} = (b_i - a * X_i) / 2 b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr1(); // Store the X_i bit in the result (x = x | (1 << X_i)) - x = x.bitor(&Uint::from_word(x_i).shl_vartime(i)); + let (shifted, _overflow) = Uint::from_word(x_i).shl_vartime(i); + x = x.bitor(&shifted); i += 1; } @@ -127,9 +128,9 @@ impl Uint { let (new_u, cyy) = new_u.conditional_wrapping_add(modulus, cy); debug_assert!(cy.is_true_vartime() == cyy.is_true_vartime()); - let (new_a, overflow) = a.shr1_with_overflow(); - debug_assert!(modulus_is_odd.not().or(overflow.not()).is_true_vartime()); - let (new_u, cy) = new_u.shr1_with_overflow(); + let (new_a, carry) = a.shr1_with_carry(); + debug_assert!(modulus_is_odd.not().or(carry.not()).is_true_vartime()); + let (new_u, cy) = new_u.shr1_with_carry(); let (new_u, cy) = new_u.conditional_wrapping_add(&m1hp, cy); debug_assert!(modulus_is_odd.not().or(cy.not()).is_true_vartime()); @@ -161,7 +162,7 @@ impl Uint { pub const fn inv_mod(&self, modulus: &Self) -> (Self, CtChoice) { // Decompose `modulus = s * 2^k` where `s` is odd let k = modulus.trailing_zeros(); - let s = modulus.shr(k); + let (s, _overflow) = modulus.shr(k); // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` @@ -176,7 +177,9 @@ impl Uint { let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists // This part is mod 2^k - let mask = Uint::ONE.shl(k).wrapping_sub(&Uint::ONE); + // Will not overflow since `modulus` is nonzero, and therefore `k < BITS`. + let (shifted, _overflow) = Uint::ONE.shl(k); + let mask = shifted.wrapping_sub(&Uint::ONE); let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask); // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`, diff --git a/src/uint/mul.rs b/src/uint/mul.rs index a668f2ce..41e9e0eb 100644 --- a/src/uint/mul.rs +++ b/src/uint/mul.rs @@ -135,7 +135,7 @@ impl Uint { // Double the current result, this accounts for the other half of the multiplication grid. // TODO: The top word is empty so we can also use a special purpose shl. - (lo, hi) = Self::shl_vartime_wide((lo, hi), 1); + (lo, hi) = Self::shl_vartime_wide((lo, hi), 1).0; // Handle the diagonal of the multiplication grid, which finishes the multiplication grid. let mut carry = Limb::ZERO; diff --git a/src/uint/shl.rs b/src/uint/shl.rs index ce4b9047..2b885d8d 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -5,66 +5,93 @@ use core::ops::{Shl, ShlAssign}; impl Uint { /// Computes `self << shift`. - /// Returns zero if `shift >= Self::BITS`. - pub const fn shl(&self, shift: u32) -> Self { + /// If `shift >= Self::BITS`, returns zero as the first tuple element, + /// and `CtChoice::TRUE` as the second element. + pub const fn shl(&self, shift: u32) -> (Self, CtChoice) { + // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < BITS`). + let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros(); let overflow = CtChoice::from_u32_lt(shift, Self::BITS).not(); let shift = shift % Self::BITS; let mut result = *self; let mut i = 0; - while i < Self::LOG2_BITS + 1 { + while i < shift_bits { let bit = CtChoice::from_u32_lsb((shift >> i) & 1); - result = Uint::ct_select(&result, &result.shl_vartime(1 << i), bit); + result = Uint::ct_select(&result, &result.shl_vartime(1 << i).0, bit); i += 1; } - Uint::ct_select(&result, &Self::ZERO, overflow) + (Uint::ct_select(&result, &Self::ZERO, overflow), overflow) } /// Computes `self << shift`. + /// If `shift >= Self::BITS`, returns zero as the first tuple element, + /// and `CtChoice::TRUE` as the second element. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn shl_vartime(&self, shift: u32) -> Self { + pub const fn shl_vartime(&self, shift: u32) -> (Self, CtChoice) { let mut limbs = [Limb::ZERO; LIMBS]; if shift >= Self::BITS { - return Self { limbs }; + return (Self::ZERO, CtChoice::TRUE); } let shift_num = (shift / Limb::BITS) as usize; let rem = shift % Limb::BITS; - let mut i = LIMBS; - while i > shift_num { - i -= 1; + let mut i = shift_num; + while i < LIMBS { limbs[i] = self.limbs[i - shift_num]; + i += 1; + } + + if rem == 0 { + return (Self { limbs }, CtChoice::FALSE); + } + + let mut carry = Limb::ZERO; + + let mut i = shift_num; + while i < LIMBS { + let shifted = limbs[i].shl(rem); + let new_carry = limbs[i].shr(Limb::BITS - rem); + limbs[i] = shifted.bitor(carry); + carry = new_carry; + i += 1; } - let (new_lower, _carry) = (Self { limbs }).shl_limb(rem); - new_lower + (Self { limbs }, CtChoice::FALSE) } /// Computes a left shift on a wide input as `(lo, hi)`. + /// If `shift >= Self::BITS`, returns a tuple of zeros as the first element, + /// and `CtChoice::TRUE` as the second element. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn shl_vartime_wide(lower_upper: (Self, Self), shift: u32) -> (Self, Self) { - let (lower, mut upper) = lower_upper; - let new_lower = lower.shl_vartime(shift); - upper = upper.shl_vartime(shift); - if shift >= Self::BITS { - upper = upper.bitor(&lower.shl_vartime(shift - Self::BITS)); + pub const fn shl_vartime_wide( + lower_upper: (Self, Self), + shift: u32, + ) -> ((Self, Self), CtChoice) { + let (lower, upper) = lower_upper; + if shift >= 2 * Self::BITS { + ((Self::ZERO, Self::ZERO), CtChoice::TRUE) + } else if shift >= Self::BITS { + let (upper, _) = lower.shl_vartime(shift - Self::BITS); + ((Self::ZERO, upper), CtChoice::FALSE) } else { - upper = upper.bitor(&lower.shr_vartime(Self::BITS - shift)); + let (new_lower, _) = lower.shl_vartime(shift); + let (upper_lo, _) = lower.shr_vartime(Self::BITS - shift); + let (upper_hi, _) = upper.shl_vartime(shift); + ((new_lower, upper_lo.bitor(&upper_hi)), CtChoice::FALSE) } - - (new_lower, upper) } /// Computes `self << shift` where `0 <= shift < Limb::BITS`, @@ -78,23 +105,40 @@ impl Uint { let rshift = nz.if_true_u32(Limb::BITS - shift); let carry = nz.if_true_word(self.limbs[LIMBS - 1].0.wrapping_shr(Word::BITS - shift)); - let mut i = LIMBS - 1; - while i > 0 { + limbs[0] = Limb(self.limbs[0].0 << lshift); + let mut i = 1; + while i < LIMBS { let mut limb = self.limbs[i].0 << lshift; let hi = self.limbs[i - 1].0 >> rshift; limb |= nz.if_true_word(hi); limbs[i] = Limb(limb); - i -= 1 + i += 1 } - limbs[0] = Limb(self.limbs[0].0 << lshift); (Uint::::new(limbs), Limb(carry)) } - /// Computes `self >> 1` in constant-time. + /// Computes `self << 1` in constant-time, returning [`CtChoice::TRUE`] + /// if the most significant bit was set, and [`CtChoice::FALSE`] otherwise. + #[inline(always)] + pub(crate) const fn shl1_with_carry(&self) -> (Self, CtChoice) { + let mut ret = Self::ZERO; + let mut i = 0; + let mut carry = Limb::ZERO; + while i < LIMBS { + let (shifted, new_carry) = self.limbs[i].shl1(); + ret.limbs[i] = shifted.bitor(carry); + carry = new_carry; + i += 1; + } + + (ret, CtChoice::from_word_lsb(carry.0)) + } + + /// Computes `self << 1` in constant-time. pub(crate) const fn shl1(&self) -> Self { // TODO(tarcieri): optimized implementation - self.shl_vartime(1) + self.shl1_with_carry().0 } } @@ -102,7 +146,7 @@ impl Shl for Uint { type Output = Uint; fn shl(self, shift: u32) -> Uint { - Uint::::shl(&self, shift) + <&Uint as Shl>::shl(&self, shift) } } @@ -110,7 +154,12 @@ impl Shl for &Uint { type Output = Uint; fn shl(self, shift: u32) -> Uint { - self.shl(shift) + let (result, overflow) = Uint::::shl(self, shift); + assert!( + !overflow.is_true_vartime(), + "attempt to shift left with overflow" + ); + result } } @@ -122,7 +171,7 @@ impl ShlAssign for Uint { #[cfg(test)] mod tests { - use crate::{Limb, Uint, U128, U256}; + use crate::{CtChoice, Limb, Uint, U128, U256}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -153,6 +202,7 @@ mod tests { #[test] fn shl1() { assert_eq!(N << 1, TWO_N); + assert_eq!(N.shl1(), TWO_N); } #[test] @@ -171,8 +221,15 @@ mod tests { } #[test] + fn shl256_const() { + assert_eq!(N.shl(256), (U256::ZERO, CtChoice::TRUE)); + assert_eq!(N.shl_vartime(256), (U256::ZERO, CtChoice::TRUE)); + } + + #[test] + #[should_panic(expected = "attempt to shift left with overflow")] fn shl256() { - assert_eq!(N << 256, U256::default()); + let _ = N << 256; } #[test] @@ -184,7 +241,11 @@ mod tests { fn shl_wide_1_1_128() { assert_eq!( Uint::shl_vartime_wide((U128::ONE, U128::ONE), 128), - (U128::ZERO, U128::ONE) + ((U128::ZERO, U128::ONE), CtChoice::FALSE) + ); + assert_eq!( + Uint::shl_vartime_wide((U128::ONE, U128::ONE), 128), + ((U128::ZERO, U128::ONE), CtChoice::FALSE) ); } @@ -192,7 +253,10 @@ mod tests { fn shl_wide_max_0_1() { assert_eq!( Uint::shl_vartime_wide((U128::MAX, U128::ZERO), 1), - (U128::MAX.sbb(&U128::ONE, Limb::ZERO).0, U128::ONE) + ( + (U128::MAX.sbb(&U128::ONE, Limb::ZERO).0, U128::ONE), + CtChoice::FALSE + ) ); } @@ -200,7 +264,7 @@ mod tests { fn shl_wide_max_max_256() { assert_eq!( Uint::shl_vartime_wide((U128::MAX, U128::MAX), 256), - (U128::ZERO, U128::ZERO) + ((U128::ZERO, U128::ZERO), CtChoice::TRUE) ); } } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index 8f2b0b69..5bb8093b 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -1,107 +1,119 @@ //! [`Uint`] bitwise right shift operations. -use super::Uint; -use crate::{CtChoice, Limb}; +use crate::{CtChoice, Limb, Uint}; use core::ops::{Shr, ShrAssign}; impl Uint { - /// Computes `self << shift`. - /// Returns zero if `shift >= Self::BITS`. - pub const fn shr(&self, shift: u32) -> Self { + /// Computes `self >> shift`. + /// If `shift >= Self::BITS`, returns zero as the first tuple element, + /// and `CtChoice::TRUE` as the second element. + pub const fn shr(&self, shift: u32) -> (Self, CtChoice) { + // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < BITS`). + let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros(); let overflow = CtChoice::from_u32_lt(shift, Self::BITS).not(); let shift = shift % Self::BITS; let mut result = *self; let mut i = 0; - while i < Self::LOG2_BITS + 1 { + while i < shift_bits { let bit = CtChoice::from_u32_lsb((shift >> i) & 1); - result = Uint::ct_select(&result, &result.shr_vartime(1 << i), bit); + result = Uint::ct_select(&result, &result.shr_vartime(1 << i).0, bit); i += 1; } - Uint::ct_select(&result, &Self::ZERO, overflow) + (Uint::ct_select(&result, &Self::ZERO, overflow), overflow) } /// Computes `self >> shift`. + /// If `shift >= Self::BITS`, returns zero as the first tuple element, + /// and `CtChoice::TRUE` as the second element. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn shr_vartime(&self, shift: u32) -> Self { - let full_shifts = (shift / Limb::BITS) as usize; - let small_shift = shift & (Limb::BITS - 1); + pub const fn shr_vartime(&self, shift: u32) -> (Self, CtChoice) { let mut limbs = [Limb::ZERO; LIMBS]; - if shift > Self::BITS { - return Self { limbs }; + if shift >= Self::BITS { + return (Self::ZERO, CtChoice::TRUE); } - let shift = LIMBS - full_shifts; + let shift_num = (shift / Limb::BITS) as usize; + let rem = shift % Limb::BITS; + let mut i = 0; + while i < LIMBS - shift_num { + limbs[i] = self.limbs[i + shift_num]; + i += 1; + } - if small_shift == 0 { - while i < shift { - limbs[i] = Limb(self.limbs[i + full_shifts].0); - i += 1; - } - } else { - while i < shift { - let mut lo = self.limbs[i + full_shifts].0 >> small_shift; + if rem == 0 { + return (Self { limbs }, CtChoice::FALSE); + } - if i < (LIMBS - 1) - full_shifts { - lo |= self.limbs[i + full_shifts + 1].0 << (Limb::BITS - small_shift); - } + let mut carry = Limb::ZERO; - limbs[i] = Limb(lo); - i += 1; - } + while i > 0 { + i -= 1; + let shifted = limbs[i].shr(rem); + let new_carry = limbs[i].shl(Limb::BITS - rem); + limbs[i] = shifted.bitor(carry); + carry = new_carry; } - Self { limbs } + (Self { limbs }, CtChoice::FALSE) } /// Computes a right shift on a wide input as `(lo, hi)`. + /// If `shift >= Self::BITS`, returns a tuple of zeros as the first element, + /// and `CtChoice::TRUE` as the second element. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn shr_vartime_wide(lower_upper: (Self, Self), shift: u32) -> (Self, Self) { - let (mut lower, upper) = lower_upper; - let new_upper = upper.shr_vartime(shift); - lower = lower.shr_vartime(shift); - if shift >= Self::BITS { - lower = lower.bitor(&upper.shr_vartime(shift - Self::BITS)); + pub const fn shr_vartime_wide( + lower_upper: (Self, Self), + shift: u32, + ) -> ((Self, Self), CtChoice) { + let (lower, upper) = lower_upper; + if shift >= 2 * Self::BITS { + ((Self::ZERO, Self::ZERO), CtChoice::TRUE) + } else if shift >= Self::BITS { + let (lower, _) = upper.shr_vartime(shift - Self::BITS); + ((lower, Self::ZERO), CtChoice::FALSE) } else { - lower = lower.bitor(&upper.shl_vartime(Self::BITS - shift)); + let (new_upper, _) = upper.shr_vartime(shift); + let (lower_hi, _) = upper.shl_vartime(Self::BITS - shift); + let (lower_lo, _) = lower.shr_vartime(shift); + ((lower_lo.bitor(&lower_hi), new_upper), CtChoice::FALSE) } - - (lower, new_upper) } - /// Computes `self >> 1` in constant-time, returning [`CtChoice::TRUE`] if the overflowing bit - /// was set, and [`CtChoice::FALSE`] otherwise. - pub(crate) const fn shr1_with_overflow(&self) -> (Self, CtChoice) { - let carry = CtChoice::from_word_lsb(self.limbs[0].0 & 1); + /// Computes `self >> 1` in constant-time, returning [`CtChoice::TRUE`] + /// if the least significant bit was set, and [`CtChoice::FALSE`] otherwise. + #[inline(always)] + pub(crate) const fn shr1_with_carry(&self) -> (Self, CtChoice) { let mut ret = Self::ZERO; - ret.limbs[0] = self.limbs[0].shr(1); - - let mut i = 1; - while i < LIMBS { - // set carry bit - ret.limbs[i - 1].0 |= (self.limbs[i].0 & 1) << Limb::HI_BIT; - ret.limbs[i] = self.limbs[i].shr(1); - i += 1; + let mut i = LIMBS; + let mut carry = Limb::ZERO; + while i > 0 { + i -= 1; + let (shifted, new_carry) = self.limbs[i].shr1(); + ret.limbs[i] = shifted.bitor(carry); + carry = new_carry; } - (ret, carry) + (ret, CtChoice::from_word_lsb(carry.0 >> Limb::HI_BIT)) } /// Computes `self >> 1` in constant-time. pub(crate) const fn shr1(&self) -> Self { - self.shr1_with_overflow().0 + // TODO(tarcieri): optimized implementation + self.shr1_with_carry().0 } } @@ -109,7 +121,7 @@ impl Shr for Uint { type Output = Uint; fn shr(self, shift: u32) -> Uint { - Uint::::shr(&self, shift) + <&Uint as Shr>::shr(&self, shift) } } @@ -117,7 +129,12 @@ impl Shr for &Uint { type Output = Uint; fn shr(self, shift: u32) -> Uint { - self.shr(shift) + let (result, overflow) = Uint::::shr(self, shift); + assert!( + !overflow.is_true_vartime(), + "attempt to shift right with overflow" + ); + result } } @@ -129,7 +146,7 @@ impl ShrAssign for Uint { #[cfg(test)] mod tests { - use crate::{Uint, U128, U256}; + use crate::{CtChoice, Uint, U128, U256}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -139,14 +156,27 @@ mod tests { #[test] fn shr1() { + assert_eq!(N.shr1(), N_2); assert_eq!(N >> 1, N_2); } + #[test] + fn shr256_const() { + assert_eq!(N.shr(256), (U256::ZERO, CtChoice::TRUE)); + assert_eq!(N.shr_vartime(256), (U256::ZERO, CtChoice::TRUE)); + } + + #[test] + #[should_panic(expected = "attempt to shift right with overflow")] + fn shr256() { + let _ = N >> 256; + } + #[test] fn shr_wide_1_1_128() { assert_eq!( Uint::shr_vartime_wide((U128::ONE, U128::ONE), 128), - (U128::ONE, U128::ZERO) + ((U128::ONE, U128::ZERO), CtChoice::FALSE) ); } @@ -154,7 +184,7 @@ mod tests { fn shr_wide_0_max_1() { assert_eq!( Uint::shr_vartime_wide((U128::ZERO, U128::MAX), 1), - (U128::ONE << 127, U128::MAX >> 1) + ((U128::ONE << 127, U128::MAX >> 1), CtChoice::FALSE) ); } @@ -162,7 +192,7 @@ mod tests { fn shr_wide_max_max_256() { assert_eq!( Uint::shr_vartime_wide((U128::MAX, U128::MAX), 256), - (U128::ZERO, U128::ZERO) + ((U128::ZERO, U128::ZERO), CtChoice::TRUE) ); } } diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index eed95826..17394c0f 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -15,7 +15,8 @@ impl Uint { // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. - let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) + // Will not overflow since `b <= BITS`. + let (mut x, _overflow) = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) // Repeat enough times to guarantee result has stabilized. let mut i = 0; @@ -49,7 +50,8 @@ impl Uint { // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. - let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) + // Will not overflow since `b <= BITS`. + let (mut x, _overflow) = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) // Stop right away if `x` is zero to avoid divizion by zero. while !x.cmp_vartime(&Self::ZERO).is_eq() { diff --git a/tests/boxed_uint_proptests.rs b/tests/boxed_uint_proptests.rs index 4fcb99d6..424d4b68 100644 --- a/tests/boxed_uint_proptests.rs +++ b/tests/boxed_uint_proptests.rs @@ -5,6 +5,7 @@ use core::cmp::Ordering; use crypto_bigint::{BoxedUint, CheckedAdd, Limb, NonZero}; use num_bigint::{BigUint, ModInverse}; +use num_traits::identities::One; use proptest::prelude::*; fn to_biguint(uint: &BoxedUint) -> BigUint { @@ -212,4 +213,75 @@ proptest! { prop_assert_eq!(expected, to_biguint(&actual)); } } + + #[test] + fn shl(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (a.bits_precision() * 2); + + let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << a.bits_precision() as usize) - BigUint::one())); + let (actual, overflow) = a.shl(shift); + + assert_eq!(expected, actual); + if shift >= a.bits_precision() { + assert_eq!(actual, BoxedUint::zero()); + assert!(bool::from(overflow)); + } + } + + #[test] + fn shl_vartime(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (a.bits_precision() * 2); + + let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << a.bits_precision() as usize) - BigUint::one())); + let actual = a.shl_vartime(shift); + + if shift >= a.bits_precision() { + assert!(actual.is_none()); + } + else { + assert_eq!(expected, actual.unwrap()); + } + } + + #[test] + fn shr(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (a.bits_precision() * 2); + + let expected = to_uint(a_bi >> shift as usize); + let (actual, overflow) = a.shr(shift); + + assert_eq!(expected, actual); + if shift >= a.bits_precision() { + assert_eq!(actual, BoxedUint::zero()); + assert!(bool::from(overflow)); + } + } + + + #[test] + fn shr_vartime(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (a.bits_precision() * 2); + + let expected = to_uint(a_bi >> shift as usize); + let actual = a.shr_vartime(shift); + + if shift >= a.bits_precision() { + assert!(actual.is_none()); + } + else { + assert_eq!(expected, actual.unwrap()); + } + } } diff --git a/tests/uint_proptests.rs b/tests/uint_proptests.rs index 9c884e25..e1b52f3a 100644 --- a/tests/uint_proptests.rs +++ b/tests/uint_proptests.rs @@ -60,10 +60,17 @@ proptest! { fn shl_vartime(a in uint(), shift in any::()) { let a_bi = to_biguint(&a); - let expected = to_uint(a_bi << shift.into()); - let actual = a.shl_vartime(shift.into()); + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (U256::BITS * 2); + + let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << U256::BITS as usize) - BigUint::one())); + let (actual, overflow) = a.shl_vartime(shift.into()); assert_eq!(expected, actual); + if shift >= U256::BITS { + assert_eq!(actual, U256::ZERO); + assert_eq!(overflow, CtChoice::TRUE); + } } #[test] @@ -74,9 +81,30 @@ proptest! { let shift = u32::from(shift) % (U256::BITS * 2); let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << U256::BITS as usize) - BigUint::one())); - let actual = a.shl(shift); + let (actual, overflow) = a.shl(shift); assert_eq!(expected, actual); + if shift >= U256::BITS { + assert_eq!(actual, U256::ZERO); + assert_eq!(overflow, CtChoice::TRUE); + } + } + + #[test] + fn shr_vartime(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (U256::BITS * 2); + + let expected = to_uint(a_bi >> shift as usize); + let (actual, overflow) = a.shr_vartime(shift); + + assert_eq!(expected, actual); + if shift >= U256::BITS { + assert_eq!(actual, U256::ZERO); + assert_eq!(overflow, CtChoice::TRUE); + } } #[test] @@ -87,9 +115,13 @@ proptest! { let shift = u32::from(shift) % (U256::BITS * 2); let expected = to_uint(a_bi >> shift as usize); - let actual = a.shr(shift); + let (actual, overflow) = a.shr(shift); assert_eq!(expected, actual); + if shift >= U256::BITS { + assert_eq!(actual, U256::ZERO); + assert_eq!(overflow, CtChoice::TRUE); + } } #[test]