From 3569c1035f83e1818704eb14d582ee143268a3e4 Mon Sep 17 00:00:00 2001 From: GBathie Date: Fri, 1 Mar 2024 18:05:20 +0100 Subject: [PATCH] Add RNG to random projection structs RNG defaults to Xoshiro256Plus if not provided by user. Also added tests for minimum dimension using values from scikit-learn. --- algorithms/linfa-reduction/Cargo.toml | 1 + .../examples/gaussian_projection.rs | 8 ++--- .../examples/sparse_projection.rs | 5 ++- .../src/random_projection/common.rs | 33 +++++++++++++++++-- .../random_projection/gaussian/algorithms.rs | 27 ++++++++++----- .../random_projection/gaussian/hyperparams.rs | 10 +++--- .../src/random_projection/mod.rs | 11 ++++--- .../src/random_projection/projection.rs | 2 +- .../random_projection/sparse/algorithms.rs | 26 ++++++++++++--- .../random_projection/sparse/hyperparams.rs | 28 ++++++++++++---- 10 files changed, 112 insertions(+), 39 deletions(-) diff --git a/algorithms/linfa-reduction/Cargo.toml b/algorithms/linfa-reduction/Cargo.toml index 28221b18a..967dfebdf 100644 --- a/algorithms/linfa-reduction/Cargo.toml +++ b/algorithms/linfa-reduction/Cargo.toml @@ -45,6 +45,7 @@ rand = { version = "0.8", features = ["small_rng"] } linfa = { version = "0.7.0", path = "../.." } linfa-kernel = { version = "0.7.0", path = "../linfa-kernel" } sprs = "0.11.1" +rand_xoshiro = "0.6.0" [dev-dependencies] ndarray-npy = { version = "0.8", default-features = false } diff --git a/algorithms/linfa-reduction/examples/gaussian_projection.rs b/algorithms/linfa-reduction/examples/gaussian_projection.rs index f6aabbaf6..9548479a3 100644 --- a/algorithms/linfa-reduction/examples/gaussian_projection.rs +++ b/algorithms/linfa-reduction/examples/gaussian_projection.rs @@ -6,7 +6,8 @@ use linfa_trees::{DecisionTree, SplitQuality}; use mnist::{MnistBuilder, NormalizedMnist}; use ndarray::{Array1, Array2}; -use rand::thread_rng; +use rand::SeedableRng; +use rand_xoshiro::Xoshiro256Plus; /// Train a Decision tree on the MNIST data set, with and without dimensionality reduction. fn main() -> Result<(), Box> { @@ -14,6 +15,7 @@ fn main() -> Result<(), Box> { let train_sz = 10_000usize; let test_sz = 1_000usize; let reduced_dim = 100; + let rng = Xoshiro256Plus::seed_from_u64(42); let NormalizedMnist { trn_img, @@ -54,10 +56,8 @@ fn main() -> Result<(), Box> { println!("Training reduced model..."); let start = Instant::now(); // Compute the random projection and train the model on the reduced dataset. - let rng = thread_rng(); - let proj = GaussianRandomProjection::::params() + let proj = GaussianRandomProjection::::params_with_rng(rng) .target_dim(reduced_dim) - .with_rng(rng) .fit(&train_dataset)?; let reduced_train_ds = proj.transform(&train_dataset); let reduced_test_data = proj.transform(&test_data); diff --git a/algorithms/linfa-reduction/examples/sparse_projection.rs b/algorithms/linfa-reduction/examples/sparse_projection.rs index 9116a7276..d775d34e0 100644 --- a/algorithms/linfa-reduction/examples/sparse_projection.rs +++ b/algorithms/linfa-reduction/examples/sparse_projection.rs @@ -6,6 +6,8 @@ use linfa_trees::{DecisionTree, SplitQuality}; use mnist::{MnistBuilder, NormalizedMnist}; use ndarray::{Array1, Array2}; +use rand::SeedableRng; +use rand_xoshiro::Xoshiro256Plus; /// Train a Decision tree on the MNIST data set, with and without dimensionality reduction. fn main() -> Result<(), Box> { @@ -13,6 +15,7 @@ fn main() -> Result<(), Box> { let train_sz = 10_000usize; let test_sz = 1_000usize; let reduced_dim = 100; + let rng = Xoshiro256Plus::seed_from_u64(42); let NormalizedMnist { trn_img, @@ -53,7 +56,7 @@ fn main() -> Result<(), Box> { println!("Training reduced model..."); let start = Instant::now(); // Compute the random projection and train the model on the reduced dataset. - let proj = SparseRandomProjection::::params() + let proj = SparseRandomProjection::::params_with_rng(rng) .target_dim(reduced_dim) .fit(&train_dataset)?; let reduced_train_ds = proj.transform(&train_dataset); diff --git a/algorithms/linfa-reduction/src/random_projection/common.rs b/algorithms/linfa-reduction/src/random_projection/common.rs index fa0f9325e..ef8e36d56 100644 --- a/algorithms/linfa-reduction/src/random_projection/common.rs +++ b/algorithms/linfa-reduction/src/random_projection/common.rs @@ -5,7 +5,34 @@ /// - [D. Achlioptas, JCSS](https://www.sciencedirect.com/science/article/pii/S0022000003000254) /// - [Li et al., SIGKDD'06](https://hastie.su.domains/Papers/Ping/KDD06_rp.pdf) pub(crate) fn johnson_lindenstrauss_min_dim(n_samples: usize, eps: f64) -> usize { - let log_samples = (n_samples as f64).log2(); - let value = 4. * log_samples * (eps.powi(2) / 2. - eps.powi(3) / 3.); - value.ceil() as usize + let log_samples = (n_samples as f64).ln(); + let value = 4. * log_samples / (eps.powi(2) / 2. - eps.powi(3) / 3.); + value as usize +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + /// Test against values computed by the scikit-learn implementation + /// of `johnson_lindenstrauss_min_dim`. + fn test_johnson_lindenstrauss() { + assert_eq!(johnson_lindenstrauss_min_dim(100, 0.05), 15244); + assert_eq!(johnson_lindenstrauss_min_dim(100, 0.1), 3947); + assert_eq!(johnson_lindenstrauss_min_dim(100, 0.2), 1062); + assert_eq!(johnson_lindenstrauss_min_dim(100, 0.5), 221); + assert_eq!(johnson_lindenstrauss_min_dim(1000, 0.05), 22867); + assert_eq!(johnson_lindenstrauss_min_dim(1000, 0.1), 5920); + assert_eq!(johnson_lindenstrauss_min_dim(1000, 0.2), 1594); + assert_eq!(johnson_lindenstrauss_min_dim(1000, 0.5), 331); + assert_eq!(johnson_lindenstrauss_min_dim(5000, 0.05), 28194); + assert_eq!(johnson_lindenstrauss_min_dim(5000, 0.1), 7300); + assert_eq!(johnson_lindenstrauss_min_dim(5000, 0.2), 1965); + assert_eq!(johnson_lindenstrauss_min_dim(5000, 0.5), 408); + assert_eq!(johnson_lindenstrauss_min_dim(10000, 0.05), 30489); + assert_eq!(johnson_lindenstrauss_min_dim(10000, 0.1), 7894); + assert_eq!(johnson_lindenstrauss_min_dim(10000, 0.2), 2125); + assert_eq!(johnson_lindenstrauss_min_dim(10000, 0.5), 442); + } } diff --git a/algorithms/linfa-reduction/src/random_projection/gaussian/algorithms.rs b/algorithms/linfa-reduction/src/random_projection/gaussian/algorithms.rs index 2086e6cab..2cda409a1 100644 --- a/algorithms/linfa-reduction/src/random_projection/gaussian/algorithms.rs +++ b/algorithms/linfa-reduction/src/random_projection/gaussian/algorithms.rs @@ -4,7 +4,8 @@ use ndarray_rand::{ rand_distr::{Normal, StandardNormal}, RandomExt, }; -use rand::{prelude::Distribution, rngs::SmallRng, Rng}; +use rand::{prelude::Distribution, Rng, SeedableRng}; +use rand_xoshiro::Xoshiro256Plus; use super::super::common::johnson_lindenstrauss_min_dim; use super::hyperparams::GaussianRandomProjectionParamsInner; @@ -28,6 +29,7 @@ where fn fit(&self, dataset: &linfa::DatasetBase) -> Result { let n_samples = dataset.nsamples(); let n_features = dataset.nfeatures(); + let mut rng = self.rng.clone(); let n_dims = match &self.params { GaussianRandomProjectionParamsInner::Dimension { target_dim } => *target_dim, @@ -39,22 +41,31 @@ where let std_dev = F::cast(n_features).sqrt().recip(); let gaussian = Normal::new(F::zero(), std_dev)?; - let proj = match self.rng.clone() { - Some(mut rng) => Array::random_using((n_features, n_dims), gaussian, &mut rng), - None => Array::random((n_features, n_dims), gaussian), - }; + let proj = Array::random_using((n_features, n_dims), gaussian, &mut rng); Ok(GaussianRandomProjection { projection: proj }) } } impl GaussianRandomProjection { + /// Create new parameters for a [`GaussianRandomProjection`] with default value + /// `precision = 0.1` and a [`Xoshiro256Plus`] RNG. + pub fn params() -> GaussianRandomProjectionParams { + GaussianRandomProjectionParams(GaussianRandomProjectionValidParams { + params: GaussianRandomProjectionParamsInner::Precision { precision: 0.1 }, + rng: Xoshiro256Plus::seed_from_u64(42), + }) + } + /// Create new parameters for a [`GaussianRandomProjection`] with default values - /// `precision = 0.1` and no custom [`Rng`] provided. - pub fn params() -> GaussianRandomProjectionParams { + /// `precision = 0.1` and the provided [`Rng`]. + pub fn params_with_rng(rng: R) -> GaussianRandomProjectionParams + where + R: Rng + Clone, + { GaussianRandomProjectionParams(GaussianRandomProjectionValidParams { params: GaussianRandomProjectionParamsInner::Precision { precision: 0.1 }, - rng: None, + rng, }) } } diff --git a/algorithms/linfa-reduction/src/random_projection/gaussian/hyperparams.rs b/algorithms/linfa-reduction/src/random_projection/gaussian/hyperparams.rs index 8cc031eab..6b8569a50 100644 --- a/algorithms/linfa-reduction/src/random_projection/gaussian/hyperparams.rs +++ b/algorithms/linfa-reduction/src/random_projection/gaussian/hyperparams.rs @@ -43,12 +43,10 @@ impl GaussianRandomProjectionParams { } /// Specify the random number generator to use to generate the projection matrix. - /// - /// Optional: if no RNG is specified, uses the default RNG in [ndarray_rand::RandomExt]. pub fn with_rng(self, rng: R2) -> GaussianRandomProjectionParams { GaussianRandomProjectionParams(GaussianRandomProjectionValidParams { params: self.0.params, - rng: Some(rng), + rng, }) } } @@ -68,7 +66,7 @@ impl GaussianRandomProjectionParams { #[derive(Debug, Clone, PartialEq)] pub struct GaussianRandomProjectionValidParams { pub(super) params: GaussianRandomProjectionParamsInner, - pub(super) rng: Option, + pub(super) rng: R, } /// Internal data structure that either holds the dimension or the embedding, @@ -107,8 +105,8 @@ impl GaussianRandomProjectionValidParams { self.params.eps() } - pub fn rng(&self) -> Option<&R> { - self.rng.as_ref() + pub fn rng(&self) -> &R { + &self.rng } } diff --git a/algorithms/linfa-reduction/src/random_projection/mod.rs b/algorithms/linfa-reduction/src/random_projection/mod.rs index e68ade1c7..7cb04df22 100644 --- a/algorithms/linfa-reduction/src/random_projection/mod.rs +++ b/algorithms/linfa-reduction/src/random_projection/mod.rs @@ -39,15 +39,16 @@ pub use sparse::{ #[cfg(test)] mod tests { use super::*; - use rand::rngs::SmallRng; + + use rand_xoshiro::Xoshiro256Plus; #[test] fn autotraits_gaussian() { fn has_autotraits() {} has_autotraits::>(); has_autotraits::>(); - has_autotraits::>(); - has_autotraits::>(); + has_autotraits::>(); + has_autotraits::>(); } #[test] @@ -55,7 +56,7 @@ mod tests { fn has_autotraits() {} has_autotraits::>(); has_autotraits::>(); - has_autotraits::(); - has_autotraits::(); + has_autotraits::>(); + has_autotraits::>(); } } diff --git a/algorithms/linfa-reduction/src/random_projection/projection.rs b/algorithms/linfa-reduction/src/random_projection/projection.rs index f5bf69543..c27eb9e5a 100644 --- a/algorithms/linfa-reduction/src/random_projection/projection.rs +++ b/algorithms/linfa-reduction/src/random_projection/projection.rs @@ -1,5 +1,5 @@ /// Macro that implements [`linfa::traits::Transformer`] -/// for [`GaussianRandomProjection`] and [`SparseRandomProjection`], +/// for [`super::GaussianRandomProjection`] and [`super::SparseRandomProjection`], /// to avoid some code duplication. #[macro_export] macro_rules! impl_proj { diff --git a/algorithms/linfa-reduction/src/random_projection/sparse/algorithms.rs b/algorithms/linfa-reduction/src/random_projection/sparse/algorithms.rs index c9da5dabb..db47cfa17 100644 --- a/algorithms/linfa-reduction/src/random_projection/sparse/algorithms.rs +++ b/algorithms/linfa-reduction/src/random_projection/sparse/algorithms.rs @@ -1,7 +1,9 @@ use linfa::{prelude::Records, traits::Fit, Float}; use ndarray::Ix2; use ndarray_rand::rand_distr::StandardNormal; -use rand::{distributions::Bernoulli, prelude::Distribution, thread_rng, Rng}; +use rand::SeedableRng; +use rand::{distributions::Bernoulli, prelude::Distribution, Rng}; +use rand_xoshiro::Xoshiro256Plus; use sprs::{CsMat, TriMat}; use super::super::common::johnson_lindenstrauss_min_dim; @@ -14,17 +16,19 @@ pub struct SparseRandomProjection { projection: CsMat, } -impl Fit for SparseRandomProjectionValidParams +impl Fit for SparseRandomProjectionValidParams where F: Float, Rec: Records, StandardNormal: Distribution, + R: Rng + Clone, { type Object = SparseRandomProjection; fn fit(&self, dataset: &linfa::DatasetBase) -> Result { let n_samples = dataset.nsamples(); let n_features = dataset.nfeatures(); + let mut rng = self.rng.clone(); let n_dims = match &self.params { SparseRandomProjectionParamsInner::Dimension { target_dim } => *target_dim, @@ -36,7 +40,6 @@ where let scale = (n_features as f64).sqrt(); let p = 1f64 / scale; let dist = SparseDistribution::new(F::cast(scale), p); - let mut rng = thread_rng(); let (mut row_inds, mut col_inds, mut values) = (Vec::new(), Vec::new(), Vec::new()); for row in 0..n_features { @@ -90,10 +93,23 @@ impl Distribution> for SparseDistribution { impl SparseRandomProjection { /// Create new parameters for a [`SparseRandomProjection`] with default value - /// `precision = 0.1`. - pub fn params() -> SparseRandomProjectionParams { + /// `precision = 0.1` and a [`Xoshiro256Plus`] RNG. + pub fn params() -> SparseRandomProjectionParams { SparseRandomProjectionParams(SparseRandomProjectionValidParams { params: SparseRandomProjectionParamsInner::Precision { precision: 0.1 }, + rng: Xoshiro256Plus::seed_from_u64(42), + }) + } + + /// Create new parameters for a [`SparseRandomProjection`] with default values + /// `precision = 0.1` and the provided [`Rng`]. + pub fn params_with_rng(rng: R) -> SparseRandomProjectionParams + where + R: Rng + Clone, + { + SparseRandomProjectionParams(SparseRandomProjectionValidParams { + params: SparseRandomProjectionParamsInner::Precision { precision: 0.1 }, + rng, }) } } diff --git a/algorithms/linfa-reduction/src/random_projection/sparse/hyperparams.rs b/algorithms/linfa-reduction/src/random_projection/sparse/hyperparams.rs index 39333b248..556cfd308 100644 --- a/algorithms/linfa-reduction/src/random_projection/sparse/hyperparams.rs +++ b/algorithms/linfa-reduction/src/random_projection/sparse/hyperparams.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; use linfa::ParamGuard; +use rand::Rng; use crate::ReductionError; @@ -13,9 +14,11 @@ use crate::ReductionError; /// However, this lemma makes a very conservative estimate of the required dimension, /// and does not leverage the structure of the data, therefore it is also possible /// to manually specify the dimension of the embedding. -pub struct SparseRandomProjectionParams(pub(crate) SparseRandomProjectionValidParams); +pub struct SparseRandomProjectionParams( + pub(crate) SparseRandomProjectionValidParams, +); -impl SparseRandomProjectionParams { +impl SparseRandomProjectionParams { /// Set the dimension of output of the embedding. /// /// Setting the target dimension with this function @@ -35,6 +38,14 @@ impl SparseRandomProjectionParams { self } + + /// Specify the random number generator to use to generate the projection matrix. + pub fn with_rng(self, rng: R2) -> SparseRandomProjectionParams { + SparseRandomProjectionParams(SparseRandomProjectionValidParams { + params: self.0.params, + rng, + }) + } } /// Sparse random projection hyperparameters @@ -47,8 +58,9 @@ impl SparseRandomProjectionParams { /// and does not leverage the structure of the data, therefore it is also possible /// to manually specify the dimension of the embedding. #[derive(Debug, Clone, PartialEq)] -pub struct SparseRandomProjectionValidParams { +pub struct SparseRandomProjectionValidParams { pub(super) params: SparseRandomProjectionParamsInner, + pub(super) rng: R, } /// Internal data structure that either holds the dimension or the embedding, @@ -78,7 +90,7 @@ impl SparseRandomProjectionParamsInner { } } -impl SparseRandomProjectionValidParams { +impl SparseRandomProjectionValidParams { pub fn target_dim(&self) -> Option { self.params.target_dim() } @@ -86,10 +98,14 @@ impl SparseRandomProjectionValidParams { pub fn precision(&self) -> Option { self.params.eps() } + + pub fn rng(&self) -> &R { + &self.rng + } } -impl ParamGuard for SparseRandomProjectionParams { - type Checked = SparseRandomProjectionValidParams; +impl ParamGuard for SparseRandomProjectionParams { + type Checked = SparseRandomProjectionValidParams; type Error = ReductionError; fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {