diff --git a/const-oid/src/arcs.rs b/const-oid/src/arcs.rs index f245845fa..7e5056cee 100644 --- a/const-oid/src/arcs.rs +++ b/const-oid/src/arcs.rs @@ -26,8 +26,8 @@ pub(crate) const ARC_MAX_SECOND: Arc = 39; /// Maximum number of bytes supported in an arc. /// -/// Note that OIDs are LEB128 encoded (i.e. base 128), so we must consider how many bytes are -/// required when each byte can only represent 7-bits of the input. +/// Note that OIDs are base 128 encoded (with continuation bits), so we must consider how many bytes +/// are required when each byte can only represent 7-bits of the input. const ARC_MAX_BYTES: usize = (Arc::BITS as usize).div_ceil(7); /// Maximum value of the last byte in an arc. diff --git a/const-oid/src/checked.rs b/const-oid/src/checked.rs index 7ff16a2a7..c5941bdc6 100644 --- a/const-oid/src/checked.rs +++ b/const-oid/src/checked.rs @@ -5,7 +5,17 @@ macro_rules! checked_add { ($a:expr, $b:expr) => { match $a.checked_add($b) { Some(n) => n, - None => return Err(Error::Length), + None => return Err(Error::Overflow), + } + }; +} + +/// `const fn`-friendly checked addition helper. +macro_rules! checked_sub { + ($a:expr, $b:expr) => { + match $a.checked_sub($b) { + Some(n) => n, + None => return Err(Error::Overflow), } }; } diff --git a/const-oid/src/encoder.rs b/const-oid/src/encoder.rs index 1081dcd88..75a6dd2d2 100644 --- a/const-oid/src/encoder.rs +++ b/const-oid/src/encoder.rs @@ -24,7 +24,7 @@ enum State { /// Initial state - no arcs yet encoded. Initial, - /// First arc parsed. + /// First arc has been supplied and stored as the wrapped [`Arc`]. FirstArc(Arc), /// Encoding base 128 body of the OID. @@ -83,10 +83,7 @@ impl Encoder { self.cursor = 1; Ok(self) } - State::Body => { - let nbytes = base128_len(arc); - self.encode_base128(arc, nbytes) - } + State::Body => self.encode_base128(arc), } } @@ -104,64 +101,48 @@ impl Encoder { Ok(ObjectIdentifier { ber }) } - /// Encode a single byte of a Base 128 value. - const fn encode_base128(mut self, n: u32, remaining_len: usize) -> Result { - if self.cursor >= MAX_SIZE { + /// Encode base 128. + const fn encode_base128(mut self, arc: Arc) -> Result { + let nbytes = base128_len(arc); + let end_pos = checked_add!(self.cursor, nbytes); + + if end_pos > MAX_SIZE { return Err(Error::Length); } - let mask = if remaining_len > 0 { 0b10000000 } else { 0 }; - let (hi, lo) = split_hi_bits(n); - self.bytes[self.cursor] = hi | mask; - self.cursor = checked_add!(self.cursor, 1); - - match remaining_len.checked_sub(1) { - Some(len) => self.encode_base128(lo, len), - None => Ok(self), + let mut i = 0; + while i < nbytes { + // TODO(tarcieri): use `?` when stable in `const fn` + self.bytes[self.cursor] = match base128_byte(arc, i, nbytes) { + Ok(byte) => byte, + Err(e) => return Err(e), + }; + self.cursor = checked_add!(self.cursor, 1); + i = checked_add!(i, 1); } + + Ok(self) } } -/// Compute the length - 1 of an arc when encoded in base 128. +/// Compute the length of an arc when encoded in base 128. const fn base128_len(arc: Arc) -> usize { match arc { - 0..=0x7f => 0, - 0x80..=0x3fff => 1, - 0x4000..=0x1fffff => 2, - 0x200000..=0x1fffffff => 3, - _ => 4, + 0..=0x7f => 1, + 0x80..=0x3fff => 2, + 0x4000..=0x1fffff => 3, + 0x200000..=0x1fffffff => 4, + _ => 5, } } -/// Split the highest 7-bits of an [`Arc`] from the rest of an arc. -/// -/// Returns: `(hi, lo)` -#[inline] -const fn split_hi_bits(arc: Arc) -> (u8, Arc) { - if arc < 0x80 { - return (arc as u8, 0); - } - - let hi_bit = match 32u32.checked_sub(arc.leading_zeros()) { - Some(bit) => bit, - None => unreachable!(), - }; - - let hi_bit_mod7 = hi_bit % 7; - let upper_bit_offset = if hi_bit > 0 && hi_bit_mod7 == 0 { - 7 - } else { - hi_bit_mod7 - }; - - let upper_bit_pos = match hi_bit.checked_sub(upper_bit_offset) { - Some(bit) => bit, - None => unreachable!(), - }; - - let upper_bits = arc >> upper_bit_pos; - let lower_bits = arc ^ (upper_bits << upper_bit_pos); - (upper_bits as u8, lower_bits) +/// Compute the big endian base 128 encoding of the given [`Arc`] at the given byte. +const fn base128_byte(arc: Arc, pos: usize, total: usize) -> Result { + debug_assert!(pos < total); + let last_byte = checked_add!(pos, 1) == total; + let mask = if last_byte { 0 } else { 0b10000000 }; + let shift = checked_sub!(checked_sub!(total, pos), 1) * 7; + Ok(((arc >> shift) & 0b1111111) as u8 | mask) } #[cfg(test)] @@ -174,9 +155,14 @@ mod tests { const EXAMPLE_OID_BER: &[u8] = &hex!("2A8648CE3D0201"); #[test] - fn split_hi_bits_with_gaps() { - assert_eq!(super::split_hi_bits(0x3a00002), (0x1d, 0x2)); - assert_eq!(super::split_hi_bits(0x3a08000), (0x1d, 0x8000)); + fn base128_byte() { + let example_arc = 0x44332211; + assert_eq!(super::base128_len(example_arc), 5); + assert_eq!(super::base128_byte(example_arc, 0, 5).unwrap(), 0b10000100); + assert_eq!(super::base128_byte(example_arc, 1, 5).unwrap(), 0b10100001); + assert_eq!(super::base128_byte(example_arc, 2, 5).unwrap(), 0b11001100); + assert_eq!(super::base128_byte(example_arc, 3, 5).unwrap(), 0b11000100); + assert_eq!(super::base128_byte(example_arc, 4, 5).unwrap(), 0b10001); } #[test] diff --git a/const-oid/src/error.rs b/const-oid/src/error.rs index 39a70e8a4..9142d5733 100644 --- a/const-oid/src/error.rs +++ b/const-oid/src/error.rs @@ -37,6 +37,9 @@ pub enum Error { /// OID length is invalid (too short or too long). Length, + /// Arithmetic overflow (or underflow) errors. + Overflow, + /// Repeated `..` characters in input data. RepeatedDot, @@ -56,6 +59,7 @@ impl Error { Error::DigitExpected { .. } => panic!("OID expected to start with digit"), Error::Empty => panic!("OID value is empty"), Error::Length => panic!("OID length invalid"), + Error::Overflow => panic!("arithmetic calculation overflowed"), Error::RepeatedDot => panic!("repeated consecutive '..' characters in OID"), Error::TrailingDot => panic!("OID ends with invalid trailing '.'"), } @@ -73,6 +77,7 @@ impl fmt::Display for Error { } Error::Empty => f.write_str("OID value is empty"), Error::Length => f.write_str("OID length invalid"), + Error::Overflow => f.write_str("arithmetic calculation overflowed"), Error::RepeatedDot => f.write_str("repeated consecutive '..' characters in OID"), Error::TrailingDot => f.write_str("OID ends with invalid trailing '.'"), } diff --git a/const-oid/src/parser.rs b/const-oid/src/parser.rs index 4810294d9..5b5155b36 100644 --- a/const-oid/src/parser.rs +++ b/const-oid/src/parser.rs @@ -63,7 +63,7 @@ impl Parser { self.current_arc = match arc.checked_mul(10) { Some(arc) => match arc.checked_add(digit as Arc) { None => return Err(Error::ArcTooBig), - arc => arc, + Some(arc) => Some(arc), }, None => return Err(Error::ArcTooBig), }; diff --git a/const-oid/tests/oid.rs b/const-oid/tests/oid.rs index ad7a0f8e9..92bfc49c4 100644 --- a/const-oid/tests/oid.rs +++ b/const-oid/tests/oid.rs @@ -29,8 +29,8 @@ const EXAMPLE_OID_LARGE_ARC_0: ObjectIdentifier = ObjectIdentifier::new_unwrap(crate::EXAMPLE_OID_LARGE_ARC_0_STR); /// Example OID value with a large arc -const EXAMPLE_OID_LARGE_ARC_1_STR: &str = "0.9.2342.19200300.100.1.1"; -const EXAMPLE_OID_LARGE_ARC_1_BER: &[u8] = &hex!("0992268993F22C640101"); +const EXAMPLE_OID_LARGE_ARC_1_STR: &str = "1.1.1.60817410.1"; +const EXAMPLE_OID_LARGE_ARC_1_BER: &[u8] = &hex!("29019D80800201"); const EXAMPLE_OID_LARGE_ARC_1: ObjectIdentifier = ObjectIdentifier::new_unwrap(EXAMPLE_OID_LARGE_ARC_1_STR); @@ -45,54 +45,69 @@ pub fn oid(s: &str) -> ObjectIdentifier { ObjectIdentifier::new(s).unwrap() } +/// 0.9.2342.19200300.100.1.1 #[test] -fn from_bytes() { - // 0.9.2342.19200300.100.1.1 - let oid0 = ObjectIdentifier::from_bytes(EXAMPLE_OID_0_BER).unwrap(); - assert_eq!(oid0.arc(0).unwrap(), 0); - assert_eq!(oid0.arc(1).unwrap(), 9); - assert_eq!(oid0.arc(2).unwrap(), 2342); - assert_eq!(oid0, EXAMPLE_OID_0); +fn from_bytes_oid_0() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_0_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_0); + assert_eq!(oid.arc(0).unwrap(), 0); + assert_eq!(oid.arc(1).unwrap(), 9); + assert_eq!(oid.arc(2).unwrap(), 2342); +} - // 1.2.840.10045.2.1 - let oid1 = ObjectIdentifier::from_bytes(EXAMPLE_OID_1_BER).unwrap(); - assert_eq!(oid1.arc(0).unwrap(), 1); - assert_eq!(oid1.arc(1).unwrap(), 2); - assert_eq!(oid1.arc(2).unwrap(), 840); - assert_eq!(oid1, EXAMPLE_OID_1); +/// 1.2.840.10045.2.1 +#[test] +fn from_bytes_oid_1() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_1_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_1); + assert_eq!(oid.arc(0).unwrap(), 1); + assert_eq!(oid.arc(1).unwrap(), 2); + assert_eq!(oid.arc(2).unwrap(), 840); +} - // 2.16.840.1.101.3.4.1.42 - let oid2 = ObjectIdentifier::from_bytes(EXAMPLE_OID_2_BER).unwrap(); - assert_eq!(oid2.arc(0).unwrap(), 2); - assert_eq!(oid2.arc(1).unwrap(), 16); - assert_eq!(oid2.arc(2).unwrap(), 840); - assert_eq!(oid2, EXAMPLE_OID_2); +/// 2.16.840.1.101.3.4.1.42 +#[test] +fn from_bytes_oid_2() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_2_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_2); + assert_eq!(oid.arc(0).unwrap(), 2); + assert_eq!(oid.arc(1).unwrap(), 16); + assert_eq!(oid.arc(2).unwrap(), 840); +} - // 1.2.16384 - let oid_largearc0 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_0_BER).unwrap(); - assert_eq!(oid_largearc0.arc(0).unwrap(), 1); - assert_eq!(oid_largearc0.arc(1).unwrap(), 2); - assert_eq!(oid_largearc0.arc(2).unwrap(), 16384); - assert_eq!(oid_largearc0.arc(3), None); - assert_eq!(oid_largearc0, EXAMPLE_OID_LARGE_ARC_0); +/// 1.2.16384 +#[test] +fn from_bytes_oid_largearc_0() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_0_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_0); + assert_eq!(oid.arc(0).unwrap(), 1); + assert_eq!(oid.arc(1).unwrap(), 2); + assert_eq!(oid.arc(2).unwrap(), 16384); + assert_eq!(oid.arc(3), None); +} - // 0.9.2342.19200300.100.1.1 - let oid_largearc1 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_1_BER).unwrap(); - assert_eq!(oid_largearc1.arc(0).unwrap(), 0); - assert_eq!(oid_largearc1.arc(1).unwrap(), 9); - assert_eq!(oid_largearc1.arc(2).unwrap(), 2342); - assert_eq!(oid_largearc1.arc(3).unwrap(), 19200300); - assert_eq!(oid_largearc1.arc(4).unwrap(), 100); - assert_eq!(oid_largearc1.arc(5).unwrap(), 1); - assert_eq!(oid_largearc1.arc(6).unwrap(), 1); - assert_eq!(oid_largearc1, EXAMPLE_OID_LARGE_ARC_1); +/// 1.1.1.60817410.1 +#[test] +fn from_bytes_oid_largearc_1() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_1_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_1); + assert_eq!(oid.arc(0).unwrap(), 1); + assert_eq!(oid.arc(1).unwrap(), 1); + assert_eq!(oid.arc(2).unwrap(), 1); + assert_eq!(oid.arc(3).unwrap(), 60817410); + assert_eq!(oid.arc(4).unwrap(), 1); + assert_eq!(oid.arc(5), None); +} - // 1.2.4294967295 - let oid_largearc2 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_2_BER).unwrap(); - assert_eq!(oid_largearc2.arc(0).unwrap(), 1); - assert_eq!(oid_largearc2.arc(1).unwrap(), 2); - assert_eq!(oid_largearc2.arc(2).unwrap(), 4294967295); - assert_eq!(oid_largearc2, EXAMPLE_OID_LARGE_ARC_2); +/// 1.2.4294967295 +#[test] +fn from_bytes_oid_largearc_2() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_2_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_2); + assert_eq!(oid.arc(0).unwrap(), 1); + assert_eq!(oid.arc(1).unwrap(), 2); + assert_eq!(oid.arc(2).unwrap(), 4294967295); + assert_eq!(oid.arc(3), None); // Empty assert_eq!(ObjectIdentifier::from_bytes(&[]), Err(Error::Empty)); @@ -126,13 +141,11 @@ fn from_str() { let oid_largearc1 = EXAMPLE_OID_LARGE_ARC_1_STR .parse::() .unwrap(); - assert_eq!(oid_largearc1.arc(0).unwrap(), 0); - assert_eq!(oid_largearc1.arc(1).unwrap(), 9); - assert_eq!(oid_largearc1.arc(2).unwrap(), 2342); - assert_eq!(oid_largearc1.arc(3).unwrap(), 19200300); - assert_eq!(oid_largearc1.arc(4).unwrap(), 100); - assert_eq!(oid_largearc1.arc(5).unwrap(), 1); - assert_eq!(oid_largearc1.arc(6).unwrap(), 1); + assert_eq!(oid_largearc1.arc(0).unwrap(), 1); + assert_eq!(oid_largearc1.arc(1).unwrap(), 1); + assert_eq!(oid_largearc1.arc(2).unwrap(), 1); + assert_eq!(oid_largearc1.arc(3).unwrap(), 60817410); + assert_eq!(oid_largearc1.arc(4).unwrap(), 1); assert_eq!(oid_largearc1, EXAMPLE_OID_LARGE_ARC_1); let oid_largearc2 = EXAMPLE_OID_LARGE_ARC_2_STR