Skip to content

Commit

Permalink
feat: add methods for covariance decomposition, mean, covariance, and…
Browse files Browse the repository at this point in the history
… precision in MultivariateNormal
  • Loading branch information
Qazalbash committed Jan 15, 2025
1 parent 7c028b5 commit 768861f
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,66 @@ where
.ln(),
)
}

/// Returns the Cholesky decomposition of the covariance matrix
///
/// # Example
///
/// ```
/// use nalgebra::OMatrix;
/// use statrs::distribution::MultivariateNormal;
///
/// let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.]).unwrap();
/// assert_eq!(mvn.cov_chol_decomp().shape(), (2, 2));
/// ```
pub fn cov_chol_decomp(&self) -> &OMatrix<f64, D, D> {
&self.cov_chol_decomp
}

/// Returns the mean of the multivariate normal distribution
///
/// # Example
///
/// ```
/// use nalgebra::OVector;
/// use statrs::distribution::MultivariateNormal;
///
/// let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.]).unwrap();
/// assert_eq!(mvn.mu(), &OVector::from_vec(vec![0., 0.]));
/// ```
pub fn mu(&self) -> &OVector<f64, D> {
&self.mu
}

/// Returns the mean of the multivariate normal distribution
///
/// # Example
///
/// ```
/// use nalgebra::OVector;
/// use statrs::distribution::MultivariateNormal;
///
/// let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.]).unwrap();
/// assert_eq!(mvn.mean(), &OVector::from_vec(vec![0., 0.]));
/// ```
pub fn cov(&self) -> &OMatrix<f64, D, D> {
&self.cov
}

/// Returns the precision matrix of the multivariate normal distribution
///
/// # Example
///
/// ```
/// use nalgebra::OMatrix;
/// use statrs::distribution::MultivariateNormal;
///
/// let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.]).unwrap();
/// assert_eq!(mvn.precision().shape(), (2, 2));
/// ```
pub fn precision(&self) -> &OMatrix<f64, D, D> {
&self.precision
}
}

impl<D> std::fmt::Display for MultivariateNormal<D>
Expand Down

0 comments on commit 768861f

Please sign in to comment.