From cb3010784aceccd36ee46735d7d8f56c379a309a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 12 Jul 2020 17:16:51 +0900 Subject: [PATCH 1/5] Split real/complex of svddc --- lax/src/svddc.rs | 60 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index 84f8394b..35c5076b 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -21,7 +21,7 @@ pub trait SVDDC_: Scalar { unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } -macro_rules! impl_svdd { +macro_rules! impl_svddc_real { ($scalar:ty, $gesdd:path) => { impl SVDDC_ for $scalar { unsafe fn svddc( @@ -70,7 +70,57 @@ macro_rules! impl_svdd { }; } -impl_svdd!(f32, lapacke::sgesdd); -impl_svdd!(f64, lapacke::dgesdd); -impl_svdd!(c32, lapacke::cgesdd); -impl_svdd!(c64, lapacke::zgesdd); +impl_svddc_real!(f32, lapacke::sgesdd); +impl_svddc_real!(f64, lapacke::dgesdd); + +macro_rules! impl_svddc_complex { + ($scalar:ty, $gesdd:path) => { + impl SVDDC_ for $scalar { + unsafe fn svddc( + l: MatrixLayout, + jobz: UVTFlag, + mut a: &mut [Self], + ) -> Result> { + let (m, n) = l.size(); + let k = m.min(n); + let lda = l.lda(); + let (ucol, vtrow) = match jobz { + UVTFlag::Full => (m, n), + UVTFlag::Some => (k, k), + UVTFlag::None => (1, 1), + }; + let mut s = vec![Self::Real::zero(); k.max(1) as usize]; + let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; + let ldu = l.resized(m, ucol).lda(); + let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; + let ldvt = l.resized(vtrow, n).lda(); + $gesdd( + l.lapacke_layout(), + jobz as u8, + m, + n, + &mut a, + lda, + &mut s, + &mut u, + ldu, + &mut vt, + ldvt, + ) + .as_lapack_result()?; + Ok(SVDOutput { + s, + u: if jobz == UVTFlag::None { None } else { Some(u) }, + vt: if jobz == UVTFlag::None { + None + } else { + Some(vt) + }, + }) + } + } + }; +} + +impl_svddc_complex!(c32, lapacke::cgesdd); +impl_svddc_complex!(c64, lapacke::zgesdd); From 8afd53c65226ae5f249b7073391525a5ed8b4bc7 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 12 Jul 2020 21:21:40 +0900 Subject: [PATCH 2/5] Rewrite impl_svddc_real --- lax/src/svddc.rs | 92 +++++++++++++++++++++++++------------ ndarray-linalg/src/svddc.rs | 42 +++++++++++------ 2 files changed, 91 insertions(+), 43 deletions(-) diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index 35c5076b..d46ef55a 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -1,7 +1,7 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; /// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. /// @@ -29,49 +29,81 @@ macro_rules! impl_svddc_real { jobz: UVTFlag, mut a: &mut [Self], ) -> Result> { - let (m, n) = l.size(); + let m = l.lda(); + let n = l.len(); let k = m.min(n); - let lda = l.lda(); - let (ucol, vtrow) = match jobz { - UVTFlag::Full => (m, n), + let mut s = vec![Self::Real::zero(); k as usize]; + + let (u_col, vt_row) = match jobz { + UVTFlag::Full | UVTFlag::None => (m, n), UVTFlag::Some => (k, k), - UVTFlag::None => (1, 1), }; - let mut s = vec![Self::Real::zero(); k.max(1) as usize]; - let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; - let ldu = l.resized(m, ucol).lda(); - let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; - let ldvt = l.resized(vtrow, n).lda(); + let (mut u, mut vt) = match jobz { + UVTFlag::Full => ( + Some(vec![Self::zero(); (m * m) as usize]), + Some(vec![Self::zero(); (n * n) as usize]), + ), + UVTFlag::Some => ( + Some(vec![Self::zero(); (m * u_col) as usize]), + Some(vec![Self::zero(); (n * vt_row) as usize]), + ), + UVTFlag::None => (None, None), + }; + + // eval work size + let mut info = 0; + let mut iwork = vec![0; 8 * k as usize]; + let mut work_size = [Self::zero()]; $gesdd( - l.lapacke_layout(), jobz as u8, m, n, &mut a, - lda, + m, &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if jobz == UVTFlag::None { None } else { Some(u) }, - vt: if jobz == UVTFlag::None { - None - } else { - Some(vt) - }, - }) + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work_size, + -1, + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + // do svd + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $gesdd( + jobz as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work, + lwork as i32, + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } } } }; } -impl_svddc_real!(f32, lapacke::sgesdd); -impl_svddc_real!(f64, lapacke::dgesdd); +impl_svddc_real!(f32, lapack::sgesdd); +impl_svddc_real!(f64, lapack::dgesdd); macro_rules! impl_svddc_complex { ($scalar:ty, $gesdd:path) => { diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index 22f3ae0c..7045fa1e 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -2,7 +2,6 @@ use ndarray::*; -use super::convert::*; use super::error::*; use super::layout::*; use super::types::*; @@ -85,19 +84,36 @@ where ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; - let (m, n) = l.size(); - let k = m.min(n); - let (ldu, tdu, ldvt, tdvt) = match uvt_flag { - UVTFlag::Full => (m, m, n, n), - UVTFlag::Some => (m, k, k, n), - UVTFlag::None => (1, 1, 1, 1), + let (n, m) = l.size(); + let k = std::cmp::min(n, m); + let n = n as usize; + let m = m as usize; + let k = k as usize; + + let (u_col, vt_row) = match uvt_flag { + UVTFlag::Full => (n, m), + UVTFlag::Some => (k, k), + UVTFlag::None => (0, 0), }; - let u = svd_res - .u - .map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches")); - let vt = svd_res - .vt - .map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches")); + + let u = svd_res.u.map(|u| { + assert_eq!(u.len(), n * u_col); + match l { + MatrixLayout::F { .. } => Array::from_shape_vec((n, u_col).f(), u), + MatrixLayout::C { .. } => Array::from_shape_vec((n, u_col), u), + } + .unwrap() + }); + + let vt = svd_res.vt.map(|vt| { + assert_eq!(vt.len(), m * vt_row); + match l { + MatrixLayout::F { .. } => Array::from_shape_vec((vt_row, m).f(), vt), + MatrixLayout::C { .. } => Array::from_shape_vec((vt_row, m), vt), + } + .unwrap() + }); + let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) } From 797b71b07f06daf4d7e2ea4a4159876f3464d304 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 13 Jul 2020 21:53:40 +0900 Subject: [PATCH 3/5] Add svddc test for complex numbers --- ndarray-linalg/tests/svddc.rs | 43 +++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ndarray-linalg/tests/svddc.rs b/ndarray-linalg/tests/svddc.rs index 2c9204c8..fb26c8d5 100644 --- a/ndarray-linalg/tests/svddc.rs +++ b/ndarray-linalg/tests/svddc.rs @@ -1,13 +1,13 @@ use ndarray::*; use ndarray_linalg::*; -fn test(a: &Array2, flag: UVTFlag) { +fn test(a: &Array2, flag: UVTFlag) { let (n, m) = a.dim(); let k = n.min(m); let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap(); - let mut sm = match flag { + let mut sm: Array2 = match flag { UVTFlag::Full => Array::zeros((n, m)), UVTFlag::Some => Array::zeros((k, k)), UVTFlag::None => { @@ -22,53 +22,56 @@ fn test(a: &Array2, flag: UVTFlag) { println!("s = \n{:?}", &s); println!("v = \n{:?}", &vt); for i in 0..k { - sm[(i, i)] = s[i]; + sm[(i, i)] = T::from_real(s[i]); } - assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7)); } macro_rules! test_svd_impl { - ($n:expr, $m:expr) => { + ($scalar:ty, $n:expr, $m:expr) => { paste::item! { #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::Full); + test::<$scalar>(&a, UVTFlag::Full); } #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::Some); + test::<$scalar>(&a, UVTFlag::Some); } #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::None); + test::<$scalar>(&a, UVTFlag::None); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::Full); + test::<$scalar>(&a, UVTFlag::Full); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::Some); + test::<$scalar>(&a, UVTFlag::Some); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::None); + test::<$scalar>(&a, UVTFlag::None); } } }; } -test_svd_impl!(3, 3); -test_svd_impl!(4, 3); -test_svd_impl!(3, 4); +test_svd_impl!(f64, 3, 3); +test_svd_impl!(f64, 4, 3); +test_svd_impl!(f64, 3, 4); +test_svd_impl!(c64, 3, 3); +test_svd_impl!(c64, 4, 3); +test_svd_impl!(c64, 3, 4); From acb88aa5a105510d39a5286b796312380b044026 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 13 Jul 2020 23:55:00 +0900 Subject: [PATCH 4/5] Impl SVDDC_ for c32/c64 --- lax/src/svddc.rs | 80 +++++++++++++++--------------------------------- 1 file changed, 24 insertions(+), 56 deletions(-) diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index d46ef55a..3e50d7bb 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -21,8 +21,14 @@ pub trait SVDDC_: Scalar { unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } -macro_rules! impl_svddc_real { - ($scalar:ty, $gesdd:path) => { +macro_rules! impl_svddc { + (@real, $scalar:ty, $gesdd:path) => { + impl_svddc!(@body, $scalar, $gesdd, ); + }; + (@complex, $scalar:ty, $gesdd:path) => { + impl_svddc!(@body, $scalar, $gesdd, rwork); + }; + (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { unsafe fn svddc( l: MatrixLayout, @@ -50,6 +56,16 @@ macro_rules! impl_svddc_real { UVTFlag::None => (None, None), }; + $( // for complex only + let mx = n.max(m) as usize; + let mn = n.min(m) as usize; + let lrwork = match jobz { + UVTFlag::None => 7 * mn, + _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), + }; + let mut $rwork_ident = vec![Self::Real::zero(); lrwork]; + )* + // eval work size let mut info = 0; let mut iwork = vec![0; 8 * k as usize]; @@ -67,6 +83,7 @@ macro_rules! impl_svddc_real { vt_row, &mut work_size, -1, + $(&mut $rwork_ident,)* &mut iwork, &mut info, ); @@ -88,6 +105,7 @@ macro_rules! impl_svddc_real { vt_row, &mut work, lwork as i32, + $(&mut $rwork_ident,)* &mut iwork, &mut info, ); @@ -102,57 +120,7 @@ macro_rules! impl_svddc_real { }; } -impl_svddc_real!(f32, lapack::sgesdd); -impl_svddc_real!(f64, lapack::dgesdd); - -macro_rules! impl_svddc_complex { - ($scalar:ty, $gesdd:path) => { - impl SVDDC_ for $scalar { - unsafe fn svddc( - l: MatrixLayout, - jobz: UVTFlag, - mut a: &mut [Self], - ) -> Result> { - let (m, n) = l.size(); - let k = m.min(n); - let lda = l.lda(); - let (ucol, vtrow) = match jobz { - UVTFlag::Full => (m, n), - UVTFlag::Some => (k, k), - UVTFlag::None => (1, 1), - }; - let mut s = vec![Self::Real::zero(); k.max(1) as usize]; - let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; - let ldu = l.resized(m, ucol).lda(); - let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; - let ldvt = l.resized(vtrow, n).lda(); - $gesdd( - l.lapacke_layout(), - jobz as u8, - m, - n, - &mut a, - lda, - &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if jobz == UVTFlag::None { None } else { Some(u) }, - vt: if jobz == UVTFlag::None { - None - } else { - Some(vt) - }, - }) - } - } - }; -} - -impl_svddc_complex!(c32, lapacke::cgesdd); -impl_svddc_complex!(c64, lapacke::zgesdd); +impl_svddc!(@real, f32, lapack::sgesdd); +impl_svddc!(@real, f64, lapack::dgesdd); +impl_svddc!(@complex, c32, lapack::cgesdd); +impl_svddc!(@complex, c64, lapack::zgesdd); From cf56af5336ee83a8658b7a7a6eef7c94684df66c Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 14 Jul 2020 00:04:31 +0900 Subject: [PATCH 5/5] Use convert::into_matrix --- ndarray-linalg/src/svddc.rs | 36 ++++++++++-------------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index 7045fa1e..b27212ec 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -1,11 +1,8 @@ //! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd) +use super::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -use super::error::*; -use super::layout::*; -use super::types::*; - pub use lapack::svddc::UVTFlag; /// Singular-value decomposition of matrix (copying) by divide-and-conquer @@ -84,35 +81,22 @@ where ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; - let (n, m) = l.size(); - let k = std::cmp::min(n, m); - let n = n as usize; - let m = m as usize; - let k = k as usize; + let (m, n) = l.size(); + let k = m.min(n); let (u_col, vt_row) = match uvt_flag { - UVTFlag::Full => (n, m), + UVTFlag::Full => (m, n), UVTFlag::Some => (k, k), UVTFlag::None => (0, 0), }; - let u = svd_res.u.map(|u| { - assert_eq!(u.len(), n * u_col); - match l { - MatrixLayout::F { .. } => Array::from_shape_vec((n, u_col).f(), u), - MatrixLayout::C { .. } => Array::from_shape_vec((n, u_col), u), - } - .unwrap() - }); + let u = svd_res + .u + .map(|u| into_matrix(l.resized(m, u_col), u).unwrap()); - let vt = svd_res.vt.map(|vt| { - assert_eq!(vt.len(), m * vt_row); - match l { - MatrixLayout::F { .. } => Array::from_shape_vec((vt_row, m).f(), vt), - MatrixLayout::C { .. } => Array::from_shape_vec((vt_row, m), vt), - } - .unwrap() - }); + let vt = svd_res + .vt + .map(|vt| into_matrix(l.resized(vt_row, n), vt).unwrap()); let s = ArrayBase::from(svd_res.s); Ok((u, s, vt))