From 9c3f39286471b8bf88ed99b3faf4401e41115fd3 Mon Sep 17 00:00:00 2001 From: Sebastien Rousseau Date: Thu, 9 May 2024 17:51:47 +0100 Subject: [PATCH] test(kyberlib): :recycle: minor refactoring, add unit tests for `rng.rs` --- src/error.rs | 6 ++++ src/params.rs | 2 +- src/rng.rs | 4 +++ tests/test_rng.rs | 92 ++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 102 insertions(+), 2 deletions(-) diff --git a/src/error.rs b/src/error.rs index 83a6fb1..f209ca6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,9 @@ pub enum KyberLibError { /// Error when generating keys InvalidKey, + /// The length of the input buffer is invalid. + InvalidLength, + /// The ciphertext was unable to be authenticated. The shared secret was not decapsulated. Decapsulation, @@ -31,6 +34,9 @@ impl core::fmt::Display for KyberLibError { } KyberLibError::InvalidKey => { write!(f, "The secret and public key given does not match.") + }, + KyberLibError::InvalidLength => { + write!(f, "The length of the input buffer is invalid.") } } } diff --git a/src/params.rs b/src/params.rs index 4d4178e..9134e9a 100644 --- a/src/params.rs +++ b/src/params.rs @@ -23,7 +23,7 @@ pub const KYBER_ETA1: usize = /// - It determines the noise distribution's width in the encryption process. pub const KYBER_ETA2: usize = 2; -// Size of the hashes and seeds +/// Size of the hashes and seeds pub const KYBER_SYM_BYTES: usize = 32; /// The parameter N, representing the degree of the polynomial used in Kyber. diff --git a/src/rng.rs b/src/rng.rs index b0008e4..6b3e192 100644 --- a/src/rng.rs +++ b/src/rng.rs @@ -44,6 +44,10 @@ pub fn randombytes( where R: RngCore + CryptoRng, { + if len > x.len() { + return Err(KyberLibError::InvalidLength); + } + rng.try_fill_bytes(&mut x[..len]) .map_err(|_| KyberLibError::RandomBytesGeneration) } diff --git a/tests/test_rng.rs b/tests/test_rng.rs index c55ddbd..cbffbf0 100644 --- a/tests/test_rng.rs +++ b/tests/test_rng.rs @@ -4,7 +4,7 @@ #[cfg(test)] mod tests { - use kyberlib::rng::randombytes; + use kyberlib::{rng::randombytes, KyberLibError}; use rand_core::OsRng; #[test] @@ -113,4 +113,94 @@ mod tests { assert!(result2.is_ok()); assert_ne!(&buffer1[..], &buffer2[..]); } + + #[test] + fn test_randombytes_partial_fill() { + // Test filling a partial buffer + let mut buffer = [0u8; 32]; + let partial_len = 16; + + // Use OsRng as the RNG + let mut rng = OsRng; + + // Call randombytes to partially fill the buffer + let result = randombytes(&mut buffer, partial_len, &mut rng); + + // Check if the result is Ok, indicating successful random byte generation + assert!(result.is_ok()); + // Check that the buffer length is unchanged + assert_eq!(buffer.len(), 32); + // Check that the first 16 bytes are filled with random data + assert_ne!(&buffer[..partial_len], &[0u8; 16]); + } + + #[test] + fn test_randombytes_error_handling() { + // Test with a buffer of size 32 + let mut buffer = [0u8; 32]; + let buffer_len = buffer.len(); + + // Use OsRng as the RNG + let mut rng = OsRng; + + // Call randombytes with a valid length + let result = randombytes(&mut buffer, buffer_len, &mut rng); + + // Check if the result is ok + assert!(result.is_ok()); + + // Call randombytes with an invalid length + let result = randombytes(&mut buffer, buffer_len + 1, &mut rng); + + // Check if the result is an error + assert!(matches!(result, Err(KyberLibError::InvalidLength))); + } + + #[test] + fn test_randombytes_out_of_bounds() { + // Test with a buffer of size 32 + let mut buffer = [0u8; 32]; + let buffer_len = buffer.len(); + + // Use OsRng as the RNG + let mut rng = OsRng; + + // Call randombytes with an out-of-bounds length + let result = randombytes(&mut buffer, buffer_len + 1, &mut rng); + + // Check if the result is an InvalidLength error + assert!(matches!(result, Err(KyberLibError::InvalidLength))); + } + + #[test] + fn test_randombytes_invalid_rng() { + // Test with a buffer of size 32 + let mut buffer = [0u8; 32]; + let buffer_len = buffer.len(); + + // Use OsRng as the RNG + let mut rng = OsRng; + + // Call randombytes with a valid RNG + let result = randombytes(&mut buffer, buffer_len, &mut rng); + + // Check if the result is ok + assert!(result.is_ok()); + } + + #[test] + fn test_randombytes_invalid_length() { + // Test with a buffer of size 32 + let mut buffer = [0u8; 32]; + let invalid_len = 33; + + // Use OsRng as the RNG + let mut rng = OsRng; + + // Call randombytes with an invalid length + let result = randombytes(&mut buffer, invalid_len, &mut rng); + + // Check if the result is an InvalidLength error + assert!(matches!(result, Err(KyberLibError::InvalidLength))); + } }