From e35234744e411b1087950e546a5486f0c09f1e20 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 17 Jul 2020 20:48:22 +0900 Subject: [PATCH 1/5] Drop unsafe of cholesky --- lax/src/cholesky.rs | 29 ++++++++++++++++++----------- ndarray-linalg/src/cholesky.rs | 18 ++++++++---------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index ef9473b4..673aaa78 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -8,32 +8,37 @@ pub trait Cholesky_: Sized { /// Cholesky: wrapper of `*potrf` /// /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + /// Wrapper of `*potri` /// /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + /// Wrapper of `*potrs` - unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) - -> Result<()>; + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } macro_rules! impl_cholesky { ($scalar:ty, $trf:path, $tri:path, $trs:path) => { impl Cholesky_ for $scalar { - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); - $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; + unsafe { + $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; + } Ok(()) } - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; + unsafe { + $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; + } Ok(()) } - unsafe fn solve_cholesky( + fn solve_cholesky( l: MatrixLayout, uplo: UPLO, a: &[Self], @@ -42,8 +47,10 @@ macro_rules! impl_cholesky { let (n, _) = l.size(); let nrhs = 1; let ldb = 1; - $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) - .as_lapack_result()?; + unsafe { + $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) + .as_lapack_result()?; + } Ok(()) } } diff --git a/ndarray-linalg/src/cholesky.rs b/ndarray-linalg/src/cholesky.rs index 79e240ab..3f445305 100644 --- a/ndarray-linalg/src/cholesky.rs +++ b/ndarray-linalg/src/cholesky.rs @@ -155,7 +155,7 @@ where fn invc_into(self) -> Result { let mut a = self.factor; - unsafe { A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)? }; + A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)?; triangular_fill_hermitian(&mut a, self.uplo); Ok(a) } @@ -173,14 +173,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_cholesky( - self.factor.square_layout()?, - self.uplo, - self.factor.as_allocated()?, - b.as_slice_mut().unwrap(), - )? - }; + A::solve_cholesky( + self.factor.square_layout()?, + self.uplo, + self.factor.as_allocated()?, + b.as_slice_mut().unwrap(), + )?; Ok(b) } } @@ -259,7 +257,7 @@ where S: DataMut, { fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> { - unsafe { A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)? }; + A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?; Ok(self.into_triangular(uplo)) } } From df396b516dc2fee4ec12178df7a90bc94ff91bfe Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 17 Jul 2020 21:34:08 +0900 Subject: [PATCH 2/5] WIP --- lax/src/cholesky.rs | 34 +++++++++++++++++++++++++--------- lax/src/lib.rs | 9 +++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 673aaa78..93c5147b 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -24,17 +24,29 @@ macro_rules! impl_cholesky { impl Cholesky_ for $scalar { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); + let mut info = 0; + let uplo = match l { + MatrixLayout::F { .. } => uplo, + MatrixLayout::C { .. } => uplo.t(), + }; unsafe { - $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; + $trf(uplo as u8, n, a, n, &mut info); } + info.as_lapack_result()?; Ok(()) } fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); + let mut info = 0; + let uplo = match l { + MatrixLayout::F { .. } => uplo, + MatrixLayout::C { .. } => uplo.t(), + }; unsafe { - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; + $tri(uplo as u8, n, a, l.lda(), &mut info); } + info.as_lapack_result()?; Ok(()) } @@ -46,18 +58,22 @@ macro_rules! impl_cholesky { ) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; - let ldb = 1; + let uplo = match l { + MatrixLayout::F { .. } => uplo, + MatrixLayout::C { .. } => uplo.t(), + }; + let mut info = 0; unsafe { - $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) - .as_lapack_result()?; + $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); } + info.as_lapack_result()?; Ok(()) } } }; } // end macro_rules -impl_cholesky!(f64, lapacke::dpotrf, lapacke::dpotri, lapacke::dpotrs); -impl_cholesky!(f32, lapacke::spotrf, lapacke::spotri, lapacke::spotrs); -impl_cholesky!(c64, lapacke::zpotrf, lapacke::zpotri, lapacke::zpotrs); -impl_cholesky!(c32, lapacke::cpotrf, lapacke::cpotri, lapacke::cpotrs); +impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs); +impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs); +impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs); +impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index be88410e..bbbdd85b 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -126,6 +126,15 @@ pub enum UPLO { Lower = b'L', } +impl UPLO { + pub fn t(self) -> Self { + match self { + UPLO::Upper => UPLO::Lower, + UPLO::Lower => UPLO::Upper, + } + } +} + #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum Transpose { From 06a6c658c7ccb17d779798d0a625ced47b20f00d Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 18 Jul 2020 16:20:23 +0900 Subject: [PATCH 3/5] square_transpose --- lax/src/cholesky.rs | 31 +++++++++++++++++-------------- lax/src/layout.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 93c5147b..618218f1 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -1,7 +1,7 @@ //! Cholesky decomposition use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; pub trait Cholesky_: Sized { @@ -24,45 +24,48 @@ macro_rules! impl_cholesky { impl Cholesky_ for $scalar { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } let mut info = 0; - let uplo = match l { - MatrixLayout::F { .. } => uplo, - MatrixLayout::C { .. } => uplo.t(), - }; unsafe { $trf(uplo as u8, n, a, n, &mut info); } info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } Ok(()) } fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } let mut info = 0; - let uplo = match l { - MatrixLayout::F { .. } => uplo, - MatrixLayout::C { .. } => uplo.t(), - }; unsafe { $tri(uplo as u8, n, a, l.lda(), &mut info); } info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } Ok(()) } fn solve_cholesky( l: MatrixLayout, - uplo: UPLO, + mut uplo: UPLO, a: &[Self], b: &mut [Self], ) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; - let uplo = match l { - MatrixLayout::F { .. } => uplo, - MatrixLayout::C { .. } => uplo.t(), - }; let mut info = 0; + if matches!(l, MatrixLayout::C { .. }) { + uplo = uplo.t(); + } unsafe { $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); } diff --git a/lax/src/layout.rs b/lax/src/layout.rs index aa9fe110..9dad70e6 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -37,6 +37,8 @@ //! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`. //! +use cauchy::Scalar; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MatrixLayout { C { row: i32, lda: i32 }, @@ -96,3 +98,44 @@ impl MatrixLayout { } } } + +/// In-place transpose of a square matrix by keeping F/C layout +/// +/// Transpose for C-continuous array +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::C { row: 2, lda: 2 }; +/// let mut a = vec![1., 2., 3., 4.]; +/// square_transpose(layout, &mut a); +/// assert_eq!(a, &[1., 3., 2., 4.]); +/// ``` +/// +/// Transpose for F-continuous array +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::F { col: 2, lda: 2 }; +/// let mut a = vec![1., 3., 2., 4.]; +/// square_transpose(layout, &mut a); +/// assert_eq!(a, &[1., 2., 3., 4.]); +/// ``` +/// +/// Panics +/// ------ +/// - If size of `a` and `layout` size mismatch +/// +pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { + let (m, n) = layout.size(); + let n = n as usize; + let m = m as usize; + assert_eq!(a.len(), n * m); + for i in 0..m { + for j in (i + 1)..n { + let a_ij = a[i * n + j]; + let a_ji = a[j * m + i]; + a[i * n + j] = a_ji.conj(); + a[j * m + i] = a_ij.conj(); + } + } +} From 76c06d4491ddfe60dbd18d5af24d9df1dde5c1ee Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 18 Jul 2020 16:55:57 +0900 Subject: [PATCH 4/5] Split tests of cholesky --- ndarray-linalg/tests/cholesky.rs | 390 ++++++++++++++++--------------- 1 file changed, 201 insertions(+), 189 deletions(-) diff --git a/ndarray-linalg/tests/cholesky.rs b/ndarray-linalg/tests/cholesky.rs index accdbdf8..b45afb5c 100644 --- a/ndarray-linalg/tests/cholesky.rs +++ b/ndarray-linalg/tests/cholesky.rs @@ -1,216 +1,228 @@ use ndarray::*; use ndarray_linalg::*; -#[test] -fn cholesky() { - macro_rules! cholesky { - ($elem:ty, $rtol:expr) => { - let a_orig: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a_orig); +macro_rules! cholesky { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a_orig: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a_orig); - let upper = a_orig.cholesky(UPLO::Upper).unwrap(); - assert_close_l2!( - &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); - - let lower = a_orig.cholesky(UPLO::Lower).unwrap(); - assert_close_l2!( - &lower.dot(&lower.t().mapv(|elem| elem.conj())), - &a_orig, - $rtol - ); - - let a: Array2<$elem> = replicate(&a_orig); - let upper = a.cholesky_into(UPLO::Upper).unwrap(); - assert_close_l2!( - &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); + let upper = a_orig.cholesky(UPLO::Upper).unwrap(); + assert_close_l2!( + &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); - let a: Array2<$elem> = replicate(&a_orig); - let lower = a.cholesky_into(UPLO::Lower).unwrap(); - assert_close_l2!( - &lower.dot(&lower.t().mapv(|elem| elem.conj())), - &a_orig, - $rtol - ); + let lower = a_orig.cholesky(UPLO::Lower).unwrap(); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); - let mut a: Array2<$elem> = replicate(&a_orig); - { - let upper = a.cholesky_inplace(UPLO::Upper).unwrap(); + let a: Array2<$elem> = replicate(&a_orig); + let upper = a.cholesky_into(UPLO::Upper).unwrap(); assert_close_l2!( &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), &a_orig, $rtol ); - } - assert_close_l2!( - &a.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); - let mut a: Array2<$elem> = replicate(&a_orig); - { - let lower = a.cholesky_inplace(UPLO::Lower).unwrap(); + let a: Array2<$elem> = replicate(&a_orig); + let lower = a.cholesky_into(UPLO::Lower).unwrap(); assert_close_l2!( &lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol ); + + let mut a: Array2<$elem> = replicate(&a_orig); + { + let upper = a.cholesky_inplace(UPLO::Upper).unwrap(); + assert_close_l2!( + &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); + } + assert_close_l2!( + &a.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); + + let mut a: Array2<$elem> = replicate(&a_orig); + { + let lower = a.cholesky_inplace(UPLO::Lower).unwrap(); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); + } + assert_close_l2!(&a.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); } - assert_close_l2!(&a.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); - }; - } - cholesky!(f64, 1e-9); - cholesky!(f32, 1e-5); - cholesky!(c64, 1e-9); - cholesky!(c32, 1e-5); + } // paste::item + }; } -#[test] -fn cholesky_into_lower_upper() { - macro_rules! cholesky_into_lower_upper { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let upper = a.cholesky(UPLO::Upper).unwrap(); - let fac_upper = a.factorizec(UPLO::Upper).unwrap(); - let fac_lower = a.factorizec(UPLO::Lower).unwrap(); - assert_close_l2!(&upper, &fac_lower.into_upper(), $rtol); - assert_close_l2!(&upper, &fac_upper.into_upper(), $rtol); - let lower = a.cholesky(UPLO::Lower).unwrap(); - let fac_upper = a.factorizec(UPLO::Upper).unwrap(); - let fac_lower = a.factorizec(UPLO::Lower).unwrap(); - assert_close_l2!(&lower, &fac_lower.into_lower(), $rtol); - assert_close_l2!(&lower, &fac_upper.into_lower(), $rtol); - }; - } - cholesky_into_lower_upper!(f64, 1e-9); - cholesky_into_lower_upper!(f32, 1e-5); - cholesky_into_lower_upper!(c64, 1e-9); - cholesky_into_lower_upper!(c32, 1e-5); -} +cholesky!(f64, 1e-9); +cholesky!(f32, 1e-5); +cholesky!(c64, 1e-9); +cholesky!(c32, 1e-5); -#[test] -fn cholesky_inverse() { - macro_rules! cholesky_into_inverse { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let inv = a.invc().unwrap(); - assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol); - let inv_into: Array2<$elem> = replicate(&a).invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_into), &Array2::eye(3), $rtol); - let inv_upper = a.factorizec(UPLO::Upper).unwrap().invc().unwrap(); - assert_close_l2!(&a.dot(&inv_upper), &Array2::eye(3), $rtol); - let inv_upper_into = a.factorizec(UPLO::Upper).unwrap().invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_upper_into), &Array2::eye(3), $rtol); - let inv_lower = a.factorizec(UPLO::Lower).unwrap().invc().unwrap(); - assert_close_l2!(&a.dot(&inv_lower), &Array2::eye(3), $rtol); - let inv_lower_into = a.factorizec(UPLO::Lower).unwrap().invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_lower_into), &Array2::eye(3), $rtol); - }; - } - cholesky_into_inverse!(f64, 1e-9); - cholesky_into_inverse!(f32, 1e-3); - cholesky_into_inverse!(c64, 1e-9); - cholesky_into_inverse!(c32, 1e-3); +macro_rules! cholesky_into_lower_upper { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let upper = a.cholesky(UPLO::Upper).unwrap(); + let fac_upper = a.factorizec(UPLO::Upper).unwrap(); + let fac_lower = a.factorizec(UPLO::Lower).unwrap(); + assert_close_l2!(&upper, &fac_lower.into_upper(), $rtol); + assert_close_l2!(&upper, &fac_upper.into_upper(), $rtol); + let lower = a.cholesky(UPLO::Lower).unwrap(); + let fac_upper = a.factorizec(UPLO::Upper).unwrap(); + let fac_lower = a.factorizec(UPLO::Lower).unwrap(); + assert_close_l2!(&lower, &fac_lower.into_lower(), $rtol); + assert_close_l2!(&lower, &fac_upper.into_lower(), $rtol); + } + } + }; } -#[test] -fn cholesky_det() { - macro_rules! cholesky_det { - ($elem:ty, $atol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let ln_det = a - .eigvalsh(UPLO::Upper) - .unwrap() - .mapv(|elem| elem.ln()) - .scalar_sum(); - let det = ln_det.exp(); - assert_aclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $atol); - assert_aclose!(a.factorizec(UPLO::Upper).unwrap().ln_detc(), ln_det, $atol); - assert_aclose!(a.factorizec(UPLO::Lower).unwrap().detc_into(), det, $atol); - assert_aclose!( - a.factorizec(UPLO::Lower).unwrap().ln_detc_into(), - ln_det, - $atol - ); - assert_aclose!(a.detc().unwrap(), det, $atol); - assert_aclose!(a.ln_detc().unwrap(), ln_det, $atol); - assert_aclose!(a.clone().detc_into().unwrap(), det, $atol); - assert_aclose!(a.ln_detc_into().unwrap(), ln_det, $atol); - }; - } - cholesky_det!(f64, 1e-9); - cholesky_det!(f32, 1e-3); - cholesky_det!(c64, 1e-9); - cholesky_det!(c32, 1e-3); +cholesky_into_lower_upper!(f64, 1e-9); +cholesky_into_lower_upper!(f32, 1e-5); +cholesky_into_lower_upper!(c64, 1e-9); +cholesky_into_lower_upper!(c32, 1e-5); + +macro_rules! cholesky_into_inverse { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let inv = a.invc().unwrap(); + assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol); + let inv_into: Array2<$elem> = replicate(&a).invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_into), &Array2::eye(3), $rtol); + let inv_upper = a.factorizec(UPLO::Upper).unwrap().invc().unwrap(); + assert_close_l2!(&a.dot(&inv_upper), &Array2::eye(3), $rtol); + let inv_upper_into = a.factorizec(UPLO::Upper).unwrap().invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_upper_into), &Array2::eye(3), $rtol); + let inv_lower = a.factorizec(UPLO::Lower).unwrap().invc().unwrap(); + assert_close_l2!(&a.dot(&inv_lower), &Array2::eye(3), $rtol); + let inv_lower_into = a.factorizec(UPLO::Lower).unwrap().invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_lower_into), &Array2::eye(3), $rtol); + } + } + }; } +cholesky_into_inverse!(f64, 1e-9); +cholesky_into_inverse!(f32, 1e-3); +cholesky_into_inverse!(c64, 1e-9); +cholesky_into_inverse!(c32, 1e-3); -#[test] -fn cholesky_solve() { - macro_rules! cholesky_solve { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - let x: Array1<$elem> = random(3); - let b = a.dot(&x); - println!("a = \n{:?}", a); - println!("x = \n{:?}", x); - assert_close_l2!(&a.solvec(&b).unwrap(), &x, $rtol); - assert_close_l2!(&a.solvec_into(b.clone()).unwrap(), &x, $rtol); - assert_close_l2!(&a.solvec_inplace(&mut b.clone()).unwrap(), &x, $rtol); - assert_close_l2!( - &a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Upper) - .unwrap() - .solvec_into(b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower) - .unwrap() - .solvec_into(b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Upper) +macro_rules! cholesky_det { + ($elem:ty, $atol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let ln_det = a + .eigvalsh(UPLO::Upper) .unwrap() - .solvec_inplace(&mut b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower) - .unwrap() - .solvec_inplace(&mut b.clone()) - .unwrap(), - &x, - $rtol - ); - }; - } - cholesky_solve!(f64, 1e-9); - cholesky_solve!(f32, 1e-3); - cholesky_solve!(c64, 1e-9); - cholesky_solve!(c32, 1e-3); + .mapv(|elem| elem.ln()) + .scalar_sum(); + let det = ln_det.exp(); + assert_aclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $atol); + assert_aclose!(a.factorizec(UPLO::Upper).unwrap().ln_detc(), ln_det, $atol); + assert_aclose!(a.factorizec(UPLO::Lower).unwrap().detc_into(), det, $atol); + assert_aclose!( + a.factorizec(UPLO::Lower).unwrap().ln_detc_into(), + ln_det, + $atol + ); + assert_aclose!(a.detc().unwrap(), det, $atol); + assert_aclose!(a.ln_detc().unwrap(), ln_det, $atol); + assert_aclose!(a.clone().detc_into().unwrap(), det, $atol); + assert_aclose!(a.ln_detc_into().unwrap(), ln_det, $atol); + } + } + }; +} +cholesky_det!(f64, 1e-9); +cholesky_det!(f32, 1e-3); +cholesky_det!(c64, 1e-9); +cholesky_det!(c32, 1e-3); + +macro_rules! cholesky_solve { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + let x: Array1<$elem> = random(3); + let b = a.dot(&x); + println!("a = \n{:?}", a); + println!("x = \n{:?}", x); + assert_close_l2!(&a.solvec(&b).unwrap(), &x, $rtol); + assert_close_l2!(&a.solvec_into(b.clone()).unwrap(), &x, $rtol); + assert_close_l2!(&a.solvec_inplace(&mut b.clone()).unwrap(), &x, $rtol); + assert_close_l2!( + &a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Upper) + .unwrap() + .solvec_into(b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower) + .unwrap() + .solvec_into(b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Upper) + .unwrap() + .solvec_inplace(&mut b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower) + .unwrap() + .solvec_inplace(&mut b.clone()) + .unwrap(), + &x, + $rtol + ); + } + } + }; } +cholesky_solve!(f64, 1e-9); +cholesky_solve!(f32, 1e-3); +cholesky_solve!(c64, 1e-9); +cholesky_solve!(c32, 1e-3); From 44b2adb89b3f51de01f2ea270ba057669289dde2 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 18 Jul 2020 17:44:25 +0900 Subject: [PATCH 5/5] Take complex conjugate --- lax/src/cholesky.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 618218f1..8305efe5 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -65,11 +65,19 @@ macro_rules! impl_cholesky { let mut info = 0; if matches!(l, MatrixLayout::C { .. }) { uplo = uplo.t(); + for val in b.iter_mut() { + *val = val.conj(); + } } unsafe { $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); } info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + for val in b.iter_mut() { + *val = val.conj(); + } + } Ok(()) } }