Skip to content

Commit

Permalink
[clustering] Derive {Des,S}erialize for all public items (#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyqsimon authored Nov 23, 2023
1 parent 083fc9a commit 00e59f6
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 4 deletions.
22 changes: 18 additions & 4 deletions algorithms/linfa-clustering/src/appx_dbscan/cells_grid/cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@ use linfa::Float;
use linfa_nn::distance::{Distance, L2Dist};
use ndarray::{Array1, ArrayView1, ArrayView2, ArrayViewMut1};
use partitions::PartitionVec;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A point in a D dimensional euclidean space that memorizes its
/// status: 'core' or 'non core'
pub struct StatusPoint {
Expand All @@ -16,10 +23,7 @@ pub struct StatusPoint {

impl StatusPoint {
/// Creates a `StatusPoint` referring to the point at `point_index`.
///
/// The new point starts out flagged as non-core (`is_core` is `false`).
pub fn new(point_index: usize) -> StatusPoint {
    // Exactly one struct literal may form the tail expression; the older
    // multi-line spelling of the same literal was a leftover duplicate.
    StatusPoint { point_index, is_core: false }
}

pub fn is_core(&self) -> bool {
Expand All @@ -32,6 +36,11 @@ impl StatusPoint {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Information regarding the cell used in various stages of the approximate DBSCAN
/// algorithm if it is a core cell
pub struct CoreCellInfo<F: Float> {
Expand All @@ -42,6 +51,11 @@ pub struct CoreCellInfo<F: Float> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A cell from a grid that partitions the D dimensional euclidean space.
pub struct Cell<F: Float> {
/// The index of the intervals of the D dimensional axes where this cell lies
Expand Down
7 changes: 7 additions & 0 deletions algorithms/linfa-clustering/src/appx_dbscan/cells_grid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use linfa::Float;
use linfa_nn::{distance::L2Dist, NearestNeighbour};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use partitions::PartitionVec;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use cell::{Cell, StatusPoint};

Expand All @@ -16,6 +18,11 @@ pub type CellVector<F> = PartitionVec<Cell<F>>;
pub type CellTable = HashMap<Array1<i64>, usize>;

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct CellsGrid<F: Float> {
table: CellTable,
cells: CellVector<F>,
Expand Down
12 changes: 12 additions & 0 deletions algorithms/linfa-clustering/src/appx_dbscan/counting_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,28 @@ use crate::appx_dbscan::AppxDbscanValidParams;
use linfa::Float;
use linfa_nn::distance::{Distance, L2Dist};
use ndarray::{Array1, Array2, ArrayView1, Axis};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Three-way outcome of an intersection test: fully covered, disjoint, or
/// partially intersecting.
///
/// Serialization support is only compiled in when the `serde` feature is enabled.
// NOTE(review): presumably this classifies how a query region relates to a cell
// of the counting tree defined in this module — confirm against the code that
// produces these values.
pub enum IntersectionType {
// NOTE(review): presumably one region lies entirely inside the other — confirm which way round.
FullyCovered,
// NOTE(review): presumably the two regions share no points — confirm.
Disjoint,
// NOTE(review): presumably the regions overlap only partially — confirm.
Intersecting,
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Tree structure that divides the space in nested cells to perform approximate range counting
/// Each member of this structure is a node in the tree
pub struct TreeStructure<F: Float> {
Expand Down
7 changes: 7 additions & 0 deletions algorithms/linfa-clustering/src/dbscan/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@ use linfa_nn::{
CommonNearestNeighbour, NearestNeighbour, NearestNeighbourIndex,
};
use ndarray::{Array1, ArrayBase, Data, Ix2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::collections::VecDeque;

use linfa::Float;
use linfa::{traits::Transformer, DatasetBase};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
/// clusters together points which are close together with enough neighbors and
/// labels points which are sparsely neighbored as noise. As points may be
Expand Down
5 changes: 5 additions & 0 deletions algorithms/linfa-clustering/src/k_means/hyperparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ pub struct KMeansValidParams<F: Float, R: Rng, D: Distance<F>> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A helper struct used to construct a set of [valid hyperparameters](KMeansValidParams) for
/// the [K-means algorithm](crate::KMeans) (using the builder pattern).
///
/// It is a newtype whose single field is a [`KMeansValidParams`] with the same
/// type parameters.
pub struct KMeansParams<F: Float, R: Rng, D: Distance<F>>(KMeansValidParams<F, R, D>);
Expand Down
10 changes: 10 additions & 0 deletions algorithms/linfa-clustering/src/optics/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ pub struct Optics;
/// This struct represents a data point in the dataset with its associated distances obtained from
/// the OPTICS analysis
#[derive(Debug, Clone)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct Sample<F> {
/// Index of the observation in the dataset
index: usize,
Expand Down Expand Up @@ -103,6 +108,11 @@ impl<F: Float> Ord for Sample<F> {
/// that of the dataset instead ordering based on the clustering structure worked out during
/// analysis.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct OpticsAnalysis<F: Float> {
/// A list of the samples in the dataset sorted and with their reachability and core distances
/// computed.
Expand Down
5 changes: 5 additions & 0 deletions algorithms/linfa-clustering/src/optics/hyperparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ impl<F: Float, D, N> OpticsValidParams<F, D, N> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Newtype whose single field is an [`OpticsValidParams`] with the same type
/// parameters.
///
/// Serialization support is only compiled in when the `serde` feature is enabled.
// NOTE(review): presumably this is the not-yet-validated builder for the OPTICS
// hyperparameters, with `OpticsValidParams` being the checked form — confirm
// against the `impl` blocks in this file.
pub struct OpticsParams<F, D, N>(OpticsValidParams<F, D, N>);

impl<F: Float, D, N> OpticsParams<F, D, N> {
Expand Down

0 comments on commit 00e59f6

Please sign in to comment.