Skip to content

Commit

Permalink
LapackStrict trait for strict memory management API
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Aug 3, 2020
1 parent d7eea76 commit ba45af3
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 107 deletions.
83 changes: 40 additions & 43 deletions lax/src/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

pub trait Eigh: Sized {
type Elem: Scalar;

pub(crate) trait Eigh: Scalar {
/// Allocate working memory for eigenvalue problem
fn eigh_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO) -> Result<Self>;
fn eigh_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO) -> Result<EighWork<Self>>;

/// Solve eigenvalue problem
fn eigh_calc(&mut self, a: &mut [Self::Elem]) -> Result<&[<Self::Elem as Scalar>::Real]>;
fn eigh_calc<'work>(
work: &'work mut EighWork<Self>,
a: &mut [Self],
) -> Result<&'work [Self::Real]>;
}

/// Working memory for symmetric/Hermitian eigenvalue problem. See [Eigh trait](trait.Eigh.html)
/// Working memory for symmetric/Hermitian eigenvalue problem. See [LapackStrict trait](trait.LapackStrict.html)
pub struct EighWork<T: Scalar> {
jobz: u8,
uplo: UPLO,
Expand All @@ -29,17 +30,15 @@ pub struct EighWork<T: Scalar> {

macro_rules! impl_eigh_work_real {
($scalar:ty, $ev:path) => {
impl Eigh for EighWork<$scalar> {
type Elem = $scalar;

fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result<Self> {
impl Eigh for $scalar {
fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result<EighWork<Self>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { b'V' } else { b'N' };
let mut eigs = unsafe { vec_uninit(n as usize) };

let mut info = 0;
let mut work_size = [Self::Elem::zero()];
let mut work_size = [Self::zero()];
unsafe {
$ev(
jobz,
Expand All @@ -66,28 +65,28 @@ macro_rules! impl_eigh_work_real {
})
}

fn eigh_calc(
&mut self,
a: &mut [Self::Elem],
) -> Result<&[<Self::Elem as Scalar>::Real]> {
assert_eq!(a.len(), (self.n * self.n) as usize);
fn eigh_calc<'work>(
work: &'work mut EighWork<Self>,
a: &mut [Self],
) -> Result<&'work [Self::Real]> {
assert_eq!(a.len(), (work.n * work.n) as usize);
let mut info = 0;
let lwork = self.work.len() as i32;
let lwork = work.work.len() as i32;
unsafe {
$ev(
self.jobz,
self.uplo as u8,
self.n,
work.jobz,
work.uplo as u8,
work.n,
a,
self.n,
&mut self.eigs,
&mut self.work,
work.n,
&mut work.eigs,
&mut work.work,
lwork,
&mut info,
);
}
info.as_lapack_result()?;
Ok(&self.eigs)
Ok(&work.eigs)
}
}
};
Expand All @@ -98,17 +97,15 @@ impl_eigh_work_real!(f64, lapack::dsyev);

macro_rules! impl_eigh_work_complex {
($scalar:ty, $ev:path) => {
impl Eigh for EighWork<$scalar> {
type Elem = $scalar;

fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result<Self> {
impl Eigh for $scalar {
fn eigh_work(calc_v: bool, layout: MatrixLayout, uplo: UPLO) -> Result<EighWork<Self>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { b'V' } else { b'N' };
let mut eigs = unsafe { vec_uninit(n as usize) };

let mut info = 0;
let mut work_size = [Self::Elem::zero()];
let mut work_size = [Self::zero()];
let mut rwork = unsafe { vec_uninit(3 * n as usize - 2) };
unsafe {
$ev(
Expand Down Expand Up @@ -137,29 +134,29 @@ macro_rules! impl_eigh_work_complex {
})
}

fn eigh_calc(
&mut self,
a: &mut [Self::Elem],
) -> Result<&[<Self::Elem as Scalar>::Real]> {
assert_eq!(a.len(), (self.n * self.n) as usize);
fn eigh_calc<'work>(
work: &'work mut EighWork<Self>,
a: &mut [Self],
) -> Result<&'work [Self::Real]> {
assert_eq!(a.len(), (work.n * work.n) as usize);
let mut info = 0;
let lwork = self.work.len() as i32;
let lwork = work.work.len() as i32;
unsafe {
$ev(
self.jobz,
self.uplo as u8,
self.n,
work.jobz,
work.uplo as u8,
work.n,
a,
self.n,
&mut self.eigs,
&mut self.work,
work.n,
&mut work.eigs,
&mut work.work,
lwork,
self.rwork.as_mut().unwrap(),
work.rwork.as_mut().unwrap(),
&mut info,
);
}
info.as_lapack_result()?;
Ok(&self.eigs)
Ok(&work.eigs)
}
}
};
Expand Down
113 changes: 49 additions & 64 deletions lax/src/eigh_generalized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,25 @@ use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

/// Types of generalized eigenvalue problem
#[allow(dead_code)] // FIXME create interface to use ABxlx and BAxlx
#[repr(i32)]
pub enum ITYPE {
/// Solve $ A x = \lambda B x $
AxlBx = 1,
/// Solve $ A B x = \lambda x $
ABxlx = 2,
/// Solve $ B A x = \lambda x $
BAxlx = 3,
}

/// Generalized eigenvalue problem for Symmetric/Hermite matrices
pub trait EighGeneralized: Sized {
type Elem: Scalar;

pub(crate) trait EighGeneralized: Scalar {
/// Allocate working memory
fn eigh_generalized_work(calc_eigenvec: bool, layout: MatrixLayout, uplo: UPLO)
-> Result<Self>;
fn eigh_generalized_work(
calc_eigenvec: bool,
layout: MatrixLayout,
uplo: UPLO,
) -> Result<EighGeneralizedWork<Self>>;

/// Solve generalized eigenvalue problem
fn eigh_generalized_calc(
&mut self,
a: &mut [Self::Elem],
b: &mut [Self::Elem],
) -> Result<&[<Self::Elem as Scalar>::Real]>;
fn eigh_generalized_calc<'work>(
work: &'work mut EighGeneralizedWork<Self>,
a: &mut [Self],
b: &mut [Self],
) -> Result<&'work [Self::Real]>;
}

/// Working memory for symmetric/Hermitian generalized eigenvalue problem.
/// See [EighGeneralized trait](trait.EighGeneralized.html)
/// See [LapackStrict trait](trait.LapackStrict.html)
pub struct EighGeneralizedWork<T: Scalar> {
jobz: u8,
uplo: UPLO,
Expand All @@ -46,21 +35,19 @@ pub struct EighGeneralizedWork<T: Scalar> {

macro_rules! impl_eigh_work_real {
($scalar:ty, $ev:path) => {
impl EighGeneralized for EighGeneralizedWork<$scalar> {
type Elem = $scalar;

impl EighGeneralized for $scalar {
fn eigh_generalized_work(
calc_v: bool,
layout: MatrixLayout,
uplo: UPLO,
) -> Result<Self> {
) -> Result<EighGeneralizedWork<Self>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { b'V' } else { b'N' };
let mut eigs = unsafe { vec_uninit(n as usize) };

let mut info = 0;
let mut work_size = [Self::Elem::zero()];
let mut work_size = [Self::zero()];
unsafe {
$ev(
&[ITYPE::AxlBx as i32],
Expand Down Expand Up @@ -90,32 +77,32 @@ macro_rules! impl_eigh_work_real {
})
}

fn eigh_generalized_calc(
&mut self,
a: &mut [Self::Elem],
b: &mut [Self::Elem],
) -> Result<&[<Self::Elem as Scalar>::Real]> {
assert_eq!(a.len(), (self.n * self.n) as usize);
fn eigh_generalized_calc<'work>(
work: &'work mut EighGeneralizedWork<Self>,
a: &mut [Self],
b: &mut [Self],
) -> Result<&'work [Self::Real]> {
assert_eq!(a.len(), (work.n * work.n) as usize);
let mut info = 0;
let lwork = self.work.len() as i32;
let lwork = work.work.len() as i32;
unsafe {
$ev(
&[ITYPE::AxlBx as i32],
self.jobz,
self.uplo as u8,
self.n,
work.jobz,
work.uplo as u8,
work.n,
a,
self.n,
work.n,
b,
self.n,
&mut self.eigs,
&mut self.work,
work.n,
&mut work.eigs,
&mut work.work,
lwork,
&mut info,
);
}
info.as_lapack_result()?;
Ok(&self.eigs)
Ok(&work.eigs)
}
}
};
Expand All @@ -126,14 +113,12 @@ impl_eigh_work_real!(f64, lapack::dsygv);

macro_rules! impl_eigh_work_complex {
($scalar:ty, $ev:path) => {
impl EighGeneralized for EighGeneralizedWork<$scalar> {
type Elem = $scalar;

impl EighGeneralized for $scalar {
fn eigh_generalized_work(
calc_v: bool,
layout: MatrixLayout,
uplo: UPLO,
) -> Result<Self> {
) -> Result<EighGeneralizedWork<Self>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { b'V' } else { b'N' };
Expand All @@ -142,7 +127,7 @@ macro_rules! impl_eigh_work_complex {
let mut eigs = unsafe { vec_uninit(n as usize) };

let mut info = 0;
let mut work_size = [Self::Elem::zero()];
let mut work_size = [Self::zero()];
let mut rwork = unsafe { vec_uninit(3 * n as usize - 2) };
unsafe {
$ev(
Expand Down Expand Up @@ -174,33 +159,33 @@ macro_rules! impl_eigh_work_complex {
})
}

fn eigh_generalized_calc(
&mut self,
a: &mut [Self::Elem],
b: &mut [Self::Elem],
) -> Result<&[<Self::Elem as Scalar>::Real]> {
assert_eq!(a.len(), (self.n * self.n) as usize);
fn eigh_generalized_calc<'work>(
work: &'work mut EighGeneralizedWork<Self>,
a: &mut [Self],
b: &mut [Self],
) -> Result<&'work [Self::Real]> {
assert_eq!(a.len(), (work.n * work.n) as usize);
let mut info = 0;
let lwork = self.work.len() as i32;
let lwork = work.work.len() as i32;
unsafe {
$ev(
&[ITYPE::AxlBx as i32],
self.jobz,
self.uplo as u8,
self.n,
work.jobz,
work.uplo as u8,
work.n,
a,
self.n,
work.n,
b,
self.n,
&mut self.eigs,
&mut self.work,
work.n,
&mut work.eigs,
&mut work.work,
lwork,
self.rwork.as_mut().unwrap(),
work.rwork.as_mut().unwrap(),
&mut info,
);
}
info.as_lapack_result()?;
Ok(&self.eigs)
Ok(&work.eigs)
}
}
};
Expand Down
14 changes: 14 additions & 0 deletions lax/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ mod qr;
mod rcond;
mod solve;
mod solveh;
mod strict;
mod svd;
mod svddc;
mod traits;
Expand All @@ -96,6 +97,7 @@ pub use self::qr::*;
pub use self::rcond::*;
pub use self::solve::*;
pub use self::solveh::*;
pub use self::strict::*;
pub use self::svd::*;
pub use self::svddc::*;
pub use self::traits::*;
Expand Down Expand Up @@ -147,6 +149,18 @@ impl NormType {
}
}

/// Types of generalized eigenvalue problem
#[allow(dead_code)] // FIXME create interface to use ABxlx and BAxlx
#[repr(i32)]
pub enum ITYPE {
/// Solve $ A x = \lambda B x $
AxlBx = 1,
/// Solve $ A B x = \lambda x $
ABxlx = 2,
/// Solve $ B A x = \lambda x $
BAxlx = 3,
}

/// Create a vector without initialization
///
/// Safety
Expand Down
Loading

0 comments on commit ba45af3

Please sign in to comment.