From e652fccf915c739b202410b0f6b8c99b69bccb05 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 29 Jun 2020 13:17:13 +0900 Subject: [PATCH 01/49] Add lapack and blas --- lax/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/lax/Cargo.toml b/lax/Cargo.toml index e42c8d74..642dcf36 100644 --- a/lax/Cargo.toml +++ b/lax/Cargo.toml @@ -15,6 +15,7 @@ thiserror = "1" cauchy = "0.2" lapacke = "0.2.0" num-traits = "0.2" +lapack = "*" [dependencies.blas-src] version = "0.6.1" From 77b626967e6474f0a380c598614907cec536097c Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 30 Jun 2020 21:45:08 +0900 Subject: [PATCH 02/49] Rewrite lapack::eigh using lapack crate --- lax/src/eigh.rs | 219 +++++++++++++++++++++++++++++++------ ndarray-linalg/src/eigh.rs | 20 ++-- 2 files changed, 192 insertions(+), 47 deletions(-) diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index a992a67e..f43ece13 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -1,21 +1,23 @@ -//! Eigenvalue decomposition for Hermite matrices +//! Eigenvalue decomposition for Symmetric/Hermite matrices use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*syev` for real and `*heev` for complex pub trait Eigh_: Scalar { - unsafe fn eigh( + /// Wraps `*syev` for real and `*heev` for complex + fn eigh( calc_eigenvec: bool, - l: MatrixLayout, + layout: MatrixLayout, uplo: UPLO, a: &mut [Self], ) -> Result>; - unsafe fn eigh_generalized( + + /// Wraps `*syegv` for real and `*heegv` for complex + fn eigh_generalized( calc_eigenvec: bool, - l: MatrixLayout, + layout: MatrixLayout, uplo: UPLO, a: &mut [Self], b: &mut [Self], @@ -25,50 +27,195 @@ pub trait Eigh_: Scalar { macro_rules! impl_eigh { ($scalar:ty, $ev:path, $evg:path) => { impl Eigh_ for $scalar { - unsafe fn eigh( + fn eigh( + calc_v: bool, + layout: MatrixLayout, + uplo: UPLO, + mut a: &mut [Self], + ) -> Result> { + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); + let jobz = if calc_v { b'V' } else { b'N' }; + let mut eigs = vec![Self::Real::zero(); n as usize]; + let n = n as i32; + + // calc work size + let mut info = 0; + let mut work_size = [0.0]; + unsafe { + $ev( + jobz, + uplo as u8, + n, + &mut a, + n, + &mut eigs, + &mut work_size, + -1, + &mut info, + ); + } + info.as_lapack_result()?; + + // actual ev + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $ev( + jobz, + uplo as u8, + n, + &mut a, + n, + &mut eigs, + &mut work, + lwork as i32, + &mut info, + ); + } + info.as_lapack_result()?; + Ok(eigs) + } + + fn eigh_generalized( + calc_v: bool, + layout: MatrixLayout, + uplo: UPLO, + mut a: &mut [Self], + mut b: &mut [Self], + ) -> Result> { + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); + let jobz = if calc_v { b'V' } else { b'N' }; + let mut eigs = vec![Self::Real::zero(); n as usize]; + let n = n as i32; + + // calc work size + let mut info = 0; + let mut work_size = [0.0]; + unsafe { + $evg( + &[1], + jobz, + uplo as u8, + n, + &mut a, + n, + &mut b, + n, + &mut eigs, + &mut work_size, + -1, + &mut info, + ); + } + info.as_lapack_result()?; + + // actual evg + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $evg( + &[1], + jobz, + uplo as u8, + n, + &mut a, + n, + &mut b, + n, + &mut eigs, + &mut work, + lwork as i32, + &mut info, + ); + } + info.as_lapack_result()?; + Ok(eigs) + } + } + }; +} // impl_eigh! + +impl_eigh!(f64, lapack::dsyev, lapack::dsygv); +impl_eigh!(f32, lapack::ssyev, lapack::ssygv); + +// splitted for RWORK +macro_rules! impl_eighc { + ($scalar:ty, $ev:path, $evg:path) => { + impl Eigh_ for $scalar { + fn eigh( calc_v: bool, - l: MatrixLayout, + layout: MatrixLayout, uplo: UPLO, mut a: &mut [Self], ) -> Result> { - let (n, _) = l.size(); + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; - let mut w = vec![Self::Real::zero(); n as usize]; - $ev(l.lapacke_layout(), jobz, uplo as u8, n, &mut a, n, &mut w) - .as_lapack_result()?; - Ok(w) + let mut eigs = vec![Self::Real::zero(); n as usize]; + let mut work = vec![Self::zero(); 2 * n as usize - 1]; + let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2]; + let mut info = 0; + let n = n as i32; + + unsafe { + $ev( + jobz, + uplo as u8, + n, + &mut a, + n, + &mut eigs, + &mut work, + 2 * n - 1, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(eigs) } - unsafe fn eigh_generalized( + fn eigh_generalized( calc_v: bool, - l: MatrixLayout, + layout: MatrixLayout, uplo: UPLO, mut a: &mut [Self], mut b: &mut [Self], ) -> Result> { - let (n, _) = l.size(); + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; - let mut w = vec![Self::Real::zero(); n as usize]; - $evg( - l.lapacke_layout(), - 1, - jobz, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut w, - ) - .as_lapack_result()?; - Ok(w) + let mut eigs = vec![Self::Real::zero(); n as usize]; + let mut work = vec![Self::zero(); 2 * n as usize - 1]; + let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2]; + let n = n as i32; + let mut info = 0; + + unsafe { + $evg( + &[1], + jobz, + uplo as u8, + n, + &mut a, + n, + &mut b, + n, + &mut eigs, + &mut work, + 2 * n - 1, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(eigs) } } }; } // impl_eigh! -impl_eigh!(f64, lapacke::dsyev, lapacke::dsygv); -impl_eigh!(f32, lapacke::ssyev, lapacke::ssygv); -impl_eigh!(c64, lapacke::zheev, lapacke::zhegv); -impl_eigh!(c32, lapacke::cheev, lapacke::chegv); +impl_eighc!(c64, lapack::zheev, lapack::zhegv); +impl_eighc!(c32, lapack::cheev, lapack::chegv); diff --git a/ndarray-linalg/src/eigh.rs b/ndarray-linalg/src/eigh.rs index b0438a99..86f1fb46 100644 --- a/ndarray-linalg/src/eigh.rs +++ b/ndarray-linalg/src/eigh.rs @@ -99,7 +99,7 @@ where MatrixLayout::C { .. } => self.swap_axes(0, 1), MatrixLayout::F { .. } => {} } - let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? }; + let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?; Ok((ArrayBase::from(s), self)) } } @@ -126,15 +126,13 @@ where MatrixLayout::F { .. } => {} } - let s = unsafe { - A::eigh_generalized( - true, - self.0.square_layout()?, - uplo, - self.0.as_allocated_mut()?, - self.1.as_allocated_mut()?, - )? - }; + let s = A::eigh_generalized( + true, + self.0.square_layout()?, + uplo, + self.0.as_allocated_mut()?, + self.1.as_allocated_mut()?, + )?; Ok((ArrayBase::from(s), self)) } @@ -191,7 +189,7 @@ where type EigVal = Array1; fn eigvalsh_inplace(&mut self, uplo: UPLO) -> Result { - let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? }; + let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?; Ok(ArrayBase::from(s)) } } From e4ee6d6e1c4ebf398c6e5b473b5231f0b125871e Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 6 Jul 2020 00:59:28 +0900 Subject: [PATCH 03/49] Split eigenvalue and eigenvector tests --- ndarray-linalg/tests/eig.rs | 429 ++++++++++++++++++++++-------------- 1 file changed, 264 insertions(+), 165 deletions(-) diff --git a/ndarray-linalg/tests/eig.rs b/ndarray-linalg/tests/eig.rs index ac520152..28314b8a 100644 --- a/ndarray-linalg/tests/eig.rs +++ b/ndarray-linalg/tests/eig.rs @@ -1,190 +1,289 @@ use ndarray::*; use ndarray_linalg::*; -#[test] -fn dgeev() { - // https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/dgeev_ex.f.htm - let a: Array2 = arr2(&[ - [-1.01, 0.86, -4.60, 3.31, -4.81], - [3.98, 0.53, -7.04, 5.29, 3.55], - [3.30, 8.26, -3.89, 8.20, -1.51], - [4.43, 4.96, -7.66, -7.33, 6.18], - [7.31, -6.43, -6.16, 2.47, 5.58], - ]); - let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap(); - assert_close_l2!( - &e, - &arr1(&[ - c64::new(2.86, 10.76), - c64::new(2.86, -10.76), - c64::new(-0.69, 4.70), - c64::new(-0.69, -4.70), - c64::new(-10.46, 0.00) - ]), - 1.0e-3 - ); - - /* - let answer = &arr2(&[[c64::new( 0.11, 0.17), c64::new( 0.11, -0.17), c64::new( 0.73, 0.00), c64::new( 0.73, 0.00), c64::new( 0.46, 0.00)], - [c64::new( 0.41, -0.26), c64::new( 0.41, 0.26), c64::new( -0.03, -0.02), c64::new( -0.03, 0.02), c64::new( 0.34, 0.00)], - [c64::new( 0.10, -0.51), c64::new( 0.10, 0.51), c64::new( 0.19, -0.29), c64::new( 0.19, 0.29), c64::new( 0.31, 0.00)], - [c64::new( 0.40, -0.09), c64::new( 0.40, 0.09), c64::new( -0.08, -0.08), c64::new( -0.08, 0.08), c64::new( -0.74, 0.00)], - [c64::new( 0.54, 0.00), c64::new( 0.54, 0.00), c64::new( -0.29, -0.49), c64::new( -0.29, 0.49), c64::new( 0.16, 0.00)]]); - */ - - let a_c: Array2 = a.map(|f| c64::new(*f, 0.0)); - for (i, v) in vecs.axis_iter(Axis(1)).enumerate() { - let av = a_c.dot(&v); - let ev = v.mapv(|f| e[i] * f); - assert_close_l2!(&av, &ev, 1.0e-7); - } -} - -#[test] -fn fgeev() { - // https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/dgeev_ex.f.htm - let a: Array2 = arr2(&[ - [-1.01, 0.86, -4.60, 3.31, -4.81], - [3.98, 0.53, -7.04, 5.29, 3.55], - [3.30, 8.26, -3.89, 8.20, -1.51], - [4.43, 4.96, -7.66, -7.33, 6.18], - [7.31, -6.43, -6.16, 2.47, 5.58], - ]); - let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap(); - assert_close_l2!( - &e, - &arr1(&[ - c32::new(2.86, 10.76), - c32::new(2.86, -10.76), - c32::new(-0.69, 4.70), - c32::new(-0.69, -4.70), - c32::new(-10.46, 0.00) - ]), - 1.0e-3 - ); - - /* - let answer = &arr2(&[[c32::new( 0.11, 0.17), c32::new( 0.11, -0.17), c32::new( 0.73, 0.00), c32::new( 0.73, 0.00), c32::new( 0.46, 0.00)], - [c32::new( 0.41, -0.26), c32::new( 0.41, 0.26), c32::new( -0.03, -0.02), c32::new( -0.03, 0.02), c32::new( 0.34, 0.00)], - [c32::new( 0.10, -0.51), c32::new( 0.10, 0.51), c32::new( 0.19, -0.29), c32::new( 0.19, 0.29), c32::new( 0.31, 0.00)], - [c32::new( 0.40, -0.09), c32::new( 0.40, 0.09), c32::new( -0.08, -0.08), c32::new( -0.08, 0.08), c32::new( -0.74, 0.00)], - [c32::new( 0.54, 0.00), c32::new( 0.54, 0.00), c32::new( -0.29, -0.49), c32::new( -0.29, 0.49), c32::new( 0.16, 0.00)]]); - */ - - let a_c: Array2 = a.map(|f| c32::new(*f, 0.0)); - for (i, v) in vecs.axis_iter(Axis(1)).enumerate() { - let av = a_c.dot(&v); - let ev = v.mapv(|f| e[i] * f); - assert_close_l2!(&av, &ev, 1.0e-5); +// Test Av_i = e_i v_i for i = 0..n +fn test_eig(a: Array2, eigs: Array1, vecs: Array2) +where + T::Complex: Lapack, +{ + println!("a\n{:+.4}", &a); + println!("eigs\n{:+.4}", &eigs); + println!("vec\n{:+.4}", &vecs); + let a: Array2 = a.map(|v| v.as_c()); + for (&e, v) in eigs.iter().zip(vecs.axis_iter(Axis(1))) { + let av = a.dot(&v); + let ev = v.mapv(|val| val * e); + println!("av = {:+.4}", &av); + println!("ev = {:+.4}", &ev); + assert_close_l2!(&av, &ev, T::real(1e-3)); } } -#[test] -fn zgeev() { - // https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/zgeev_ex.f.htm - let a: Array2 = arr2(&[ +// Test case for real Eigenvalue problem +// +// -1.01 0.86 -4.60 3.31 -4.81 +// 3.98 0.53 -7.04 5.29 3.55 +// 3.30 8.26 -3.89 8.20 -1.51 +// 4.43 4.96 -7.66 -7.33 6.18 +// 7.31 -6.43 -6.16 2.47 5.58 +// +// Eigenvalues +// ( 2.86, 10.76) ( 2.86,-10.76) ( -0.69, 4.70) ( -0.69, -4.70) -10.46 +// +// Left eigenvectors +// ( 0.04, 0.29) ( 0.04, -0.29) ( -0.13, -0.33) ( -0.13, 0.33) 0.04 +// ( 0.62, 0.00) ( 0.62, 0.00) ( 0.69, 0.00) ( 0.69, 0.00) 0.56 +// ( -0.04, -0.58) ( -0.04, 0.58) ( -0.39, -0.07) ( -0.39, 0.07) -0.13 +// ( 0.28, 0.01) ( 0.28, -0.01) ( -0.02, -0.19) ( -0.02, 0.19) -0.80 +// ( -0.04, 0.34) ( -0.04, -0.34) ( -0.40, 0.22) ( -0.40, -0.22) 0.18 +// +// Right eigenvectors +// ( 0.11, 0.17) ( 0.11, -0.17) ( 0.73, 0.00) ( 0.73, 0.00) 0.46 +// ( 0.41, -0.26) ( 0.41, 0.26) ( -0.03, -0.02) ( -0.03, 0.02) 0.34 +// ( 0.10, -0.51) ( 0.10, 0.51) ( 0.19, -0.29) ( 0.19, 0.29) 0.31 +// ( 0.40, -0.09) ( 0.40, 0.09) ( -0.08, -0.08) ( -0.08, 0.08) -0.74 +// ( 0.54, 0.00) ( 0.54, 0.00) ( -0.29, -0.49) ( -0.29, 0.49) 0.16 +// +// - https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/dgeev_ex.f.htm +// +fn test_matrix_real() -> Array2 { + array![ [ - c64::new(-3.84, 2.25), - c64::new(-8.94, -4.75), - c64::new(8.95, -6.53), - c64::new(-9.87, 4.82), + T::real(-1.01), + T::real(0.86), + T::real(-4.60), + T::real(3.31), + T::real(-4.81) ], [ - c64::new(-0.66, 0.83), - c64::new(-4.40, -3.82), - c64::new(-3.50, -4.26), - c64::new(-3.15, 7.36), + T::real(3.98), + T::real(0.53), + T::real(-7.04), + T::real(5.29), + T::real(3.55) ], [ - c64::new(-3.99, -4.73), - c64::new(-5.88, -6.60), - c64::new(-3.36, -0.40), - c64::new(-0.75, 5.23), + T::real(3.30), + T::real(8.26), + T::real(-3.89), + T::real(8.20), + T::real(-1.51) ], [ - c64::new(7.74, 4.18), - c64::new(3.66, -7.53), - c64::new(2.58, 3.60), - c64::new(4.59, 5.41), + T::real(4.43), + T::real(4.96), + T::real(-7.66), + T::real(-7.33), + T::real(6.18) ], - ]); - let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap(); - assert_close_l2!( - &e, - &arr1(&[ - c64::new(-9.43, -12.98), - c64::new(-3.44, 12.69), - c64::new(0.11, -3.40), - c64::new(5.76, 7.13) - ]), - 1.0e-3 - ); - - /* - let answer = &arr2(&[[c64::new( 0.43, 0.33), c64::new( 0.83, 0.00), c64::new( 0.60, 0.00), c64::new( -0.31, 0.03)], - [c64::new( 0.51, -0.03), c64::new( 0.08, -0.25), c64::new( -0.40, -0.20), c64::new( 0.04, 0.34)], - [c64::new( 0.62, 0.00), c64::new( -0.25, 0.28), c64::new( -0.09, -0.48), c64::new( 0.36, 0.06)], - [c64::new( -0.23, 0.11), c64::new( -0.10, -0.32), c64::new( -0.43, 0.13), c64::new( 0.81, 0.00)]]); - */ - - for (i, v) in vecs.axis_iter(Axis(1)).enumerate() { - let av = a.dot(&v); - let ev = v.mapv(|f| e[i] * f); - assert_close_l2!(&av, &ev, 1.0e-7); - } + [ + T::real(7.31), + T::real(-6.43), + T::real(-6.16), + T::real(2.47), + T::real(5.58) + ], + ] } -#[test] -fn cgeev() { - // https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/zgeev_ex.f.htm - let a: Array2 = arr2(&[ +fn test_matrix_real_t() -> Array2 { + test_matrix_real::().t().permuted_axes([1, 0]).to_owned() +} + +fn answer_eig_real() -> Array1 { + array![ + T::complex(-10.46, 0.00), + T::complex(-0.69, 4.70), + T::complex(-0.69, -4.70), + T::complex(2.86, 10.76), + T::complex(2.86, -10.76), + ] +} + +// Test case for {c,z}geev +// +// ( -3.84, 2.25) ( -8.94, -4.75) ( 8.95, -6.53) ( -9.87, 4.82) +// ( -0.66, 0.83) ( -4.40, -3.82) ( -3.50, -4.26) ( -3.15, 7.36) +// ( -3.99, -4.73) ( -5.88, -6.60) ( -3.36, -0.40) ( -0.75, 5.23) +// ( 7.74, 4.18) ( 3.66, -7.53) ( 2.58, 3.60) ( 4.59, 5.41) +// +// Eigenvalues +// ( -9.43,-12.98) ( -3.44, 12.69) ( 0.11, -3.40) ( 5.76, 7.13) +// +// Left eigenvectors +// ( 0.24, -0.18) ( 0.61, 0.00) ( -0.18, -0.33) ( 0.28, 0.09) +// ( 0.79, 0.00) ( -0.05, -0.27) ( 0.82, 0.00) ( -0.55, 0.16) +// ( 0.22, -0.27) ( -0.21, 0.53) ( -0.37, 0.15) ( 0.45, 0.09) +// ( -0.02, 0.41) ( 0.40, -0.24) ( 0.06, 0.12) ( 0.62, 0.00) +// +// Right eigenvectors +// ( 0.43, 0.33) ( 0.83, 0.00) ( 0.60, 0.00) ( -0.31, 0.03) +// ( 0.51, -0.03) ( 0.08, -0.25) ( -0.40, -0.20) ( 0.04, 0.34) +// ( 0.62, 0.00) ( -0.25, 0.28) ( -0.09, -0.48) ( 0.36, 0.06) +// ( -0.23, 0.11) ( -0.10, -0.32) ( -0.43, 0.13) ( 0.81, 0.00) +// +// - https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/zgeev_ex.f.htm +// +fn test_matrix_complex() -> Array2 { + array![ [ - c32::new(-3.84, 2.25), - c32::new(-8.94, -4.75), - c32::new(8.95, -6.53), - c32::new(-9.87, 4.82), + T::complex(-3.84, 2.25), + T::complex(-8.94, -4.75), + T::complex(8.95, -6.53), + T::complex(-9.87, 4.82) ], [ - c32::new(-0.66, 0.83), - c32::new(-4.40, -3.82), - c32::new(-3.50, -4.26), - c32::new(-3.15, 7.36), + T::complex(-0.66, 0.83), + T::complex(-4.40, -3.82), + T::complex(-3.50, -4.26), + T::complex(-3.15, 7.36) ], [ - c32::new(-3.99, -4.73), - c32::new(-5.88, -6.60), - c32::new(-3.36, -0.40), - c32::new(-0.75, 5.23), + T::complex(-3.99, -4.73), + T::complex(-5.88, -6.60), + T::complex(-3.36, -0.40), + T::complex(-0.75, 5.23) ], [ - c32::new(7.74, 4.18), - c32::new(3.66, -7.53), - c32::new(2.58, 3.60), - c32::new(4.59, 5.41), + T::complex(7.74, 4.18), + T::complex(3.66, -7.53), + T::complex(2.58, 3.60), + T::complex(4.59, 5.41) + ] + ] +} + +fn test_matrix_complex_t() -> Array2 { + test_matrix_complex::() + .t() + .permuted_axes([1, 0]) + .to_owned() +} + +fn answer_eig_complex() -> Array1 { + array![ + T::complex(-9.43, -12.98), + T::complex(-3.44, 12.69), + T::complex(0.11, -3.40), + T::complex(5.76, 7.13) + ] +} + +// re-evaluated eigenvalues in f64 accuracy +fn answer_eigvectors_complex() -> Array2 { + array![ + [ + T::complex(0.4308565200776108, 0.32681273781262143), + T::complex(0.8256820507672813, 0.), + T::complex(0.5983959785539453, 0.), + T::complex(-0.30543190348437826, 0.03333164861799901) ], - ]); - let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap(); - assert_close_l2!( - &e, - &arr1(&[ - c32::new(-9.43, -12.98), - c32::new(-3.44, 12.69), - c32::new(0.11, -3.40), - c32::new(5.76, 7.13) - ]), - 1.0e-3 - ); - - /* - let answer = &arr2(&[[c32::new( 0.43, 0.33), c32::new( 0.83, 0.00), c32::new( 0.60, 0.00), c32::new( -0.31, 0.03)], - [c32::new( 0.51, -0.03), c32::new( 0.08, -0.25), c32::new( -0.40, -0.20), c32::new( 0.04, 0.34)], - [c32::new( 0.62, 0.00), c32::new( -0.25, 0.28), c32::new( -0.09, -0.48), c32::new( 0.36, 0.06)], - [c32::new( -0.23, 0.11), c32::new( -0.10, -0.32), c32::new( -0.43, 0.13), c32::new( 0.81, 0.00)]]); - */ - - for (i, v) in vecs.axis_iter(Axis(1)).enumerate() { - let av = a.dot(&v); - let ev = v.mapv(|f| e[i] * f); - assert_close_l2!(&av, &ev, 1.0e-5); - } + [ + T::complex(0.5087414602970965, -0.02883342170692809), + T::complex(0.07502916788141115, -0.2487285045091665), + T::complex(-0.40047616275207687, -0.2014492227625603), + T::complex(0.03978282815783273, 0.34450765221546126) + ], + [ + T::complex(0.6198496527657755, 0.), + T::complex(-0.24575578997801528, 0.27887240221169646), + T::complex(-0.09008001907594984, -0.4752646215391732), + T::complex(0.3583254365159844, 0.06064506988524665) + ], + [ + T::complex(-0.22692824331926856, 0.11043927846403584), + T::complex(-0.10343406372814358, -0.3192014653632327), + T::complex(-0.43484029549540404, 0.13372491785816037), + T::complex(0.8082432893178352, 0.) + ] + ] +} + +macro_rules! impl_test_real { + ($real:ty) => { + paste::item! { + #[test] + fn [<$real _eigvals >]() { + let a = test_matrix_real::<$real>(); + let (e, _vecs) = a.eig().unwrap(); + assert_close_l2!(&e, &answer_eig_real::<$real>(), 1.0e-3); + } + + #[test] + fn [<$real _eigvals_t>]() { + let a = test_matrix_real_t::<$real>(); + let (e, _vecs) = a.eig().unwrap(); + assert_close_l2!(&e, &answer_eig_real::<$real>(), 1.0e-3); + } + + #[test] + fn [<$real _eig>]() { + let a = test_matrix_real::<$real>(); + let (e, vecs) = a.eig().unwrap(); + test_eig(a, e, vecs); + } + + #[test] + fn [<$real _eig_t>]() { + let a = test_matrix_real_t::<$real>(); + let (e, vecs) = a.eig().unwrap(); + test_eig(a, e, vecs); + } + + } // paste::item! + }; +} + +impl_test_real!(f32); +impl_test_real!(f64); + +macro_rules! impl_test_complex { + ($complex:ty) => { + paste::item! { + #[test] + fn [<$complex _eigvals >]() { + let a = test_matrix_complex::<$complex>(); + let (e, _vecs) = a.eig().unwrap(); + assert_close_l2!(&e, &answer_eig_complex::<$complex>(), 1.0e-3); + } + + #[test] + fn [<$complex _eigvals_t>]() { + let a = test_matrix_complex_t::<$complex>(); + let (e, _vecs) = a.eig().unwrap(); + assert_close_l2!(&e, &answer_eig_complex::<$complex>(), 1.0e-3); + } + + #[test] + fn [<$complex _eigvector>]() { + let a = test_matrix_complex::<$complex>(); + let (_e, vecs) = a.eig().unwrap(); + assert_close_l2!(&vecs, &answer_eigvectors_complex::<$complex>(), 1.0e-3); + } + + #[test] + fn [<$complex _eigvector_t>]() { + let a = test_matrix_complex_t::<$complex>(); + let (_e, vecs) = a.eig().unwrap(); + assert_close_l2!(&vecs, &answer_eigvectors_complex::<$complex>(), 1.0e-3); + } + + #[test] + fn [<$complex _eig>]() { + let a = test_matrix_complex::<$complex>(); + let (e, vecs) = a.eig().unwrap(); + test_eig(a, e, vecs); + } + + #[test] + fn [<$complex _eig_t>]() { + let a = test_matrix_complex_t::<$complex>(); + let (e, vecs) = a.eig().unwrap(); + test_eig(a, e, vecs); + } + } // paste::item! + }; } + +impl_test_complex!(c32); +impl_test_complex!(c64); From 27f5bfb995ccb8d83d0a9601eb73e165d2388637 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 6 Jul 2020 00:59:47 +0900 Subject: [PATCH 04/49] Add doctest --- ndarray-linalg/src/eig.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ndarray-linalg/src/eig.rs b/ndarray-linalg/src/eig.rs index e9f09080..51d69804 100644 --- a/ndarray-linalg/src/eig.rs +++ b/ndarray-linalg/src/eig.rs @@ -11,6 +11,29 @@ pub trait Eig { type EigVal; type EigVec; /// Calculate eigenvalues with the right eigenvector + /// + /// $$ A u_i = \lambda_i u_i $$ + /// + /// ``` + /// use ndarray::*; + /// use ndarray_linalg::*; + /// + /// let a: Array2 = array![ + /// [-1.01, 0.86, -4.60, 3.31, -4.81], + /// [ 3.98, 0.53, -7.04, 5.29, 3.55], + /// [ 3.30, 8.26, -3.89, 8.20, -1.51], + /// [ 4.43, 4.96, -7.66, -7.33, 6.18], + /// [ 7.31, -6.43, -6.16, 2.47, 5.58], + /// ]; + /// let (eigs, vecs) = a.eig().unwrap(); + /// + /// let a = a.map(|v| v.as_c()); + /// for (&e, vec) in eigs.iter().zip(vecs.axis_iter(Axis(1))) { + /// let ev = vec.map(|v| v * e); + /// let av = a.dot(&vec); + /// assert_close_l2!(&av, &ev, 1e-5); + /// } + /// ``` fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)>; } From 261e79add06073d18c8bd16083fb659690889d7f Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 6 Jul 2020 00:59:55 +0900 Subject: [PATCH 05/49] Impl eig using LAPACK instead of LAPACKE --- lax/src/eig.rs | 296 +++++++++++++++++++++++++++----------- ndarray-linalg/src/eig.rs | 10 +- 2 files changed, 215 insertions(+), 91 deletions(-) diff --git a/lax/src/eig.rs b/lax/src/eig.rs index e8f9a8f8..0a34977d 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -2,11 +2,12 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*geev` for real/complex +/// Wraps `*geev` for general matrices pub trait Eig_: Scalar { - unsafe fn eig( + /// Calculate Right eigenvalue + fn eig( calc_v: bool, l: MatrixLayout, a: &mut [Self], @@ -16,117 +17,242 @@ pub trait Eig_: Scalar { macro_rules! impl_eig_complex { ($scalar:ty, $ev:path) => { impl Eig_ for $scalar { - unsafe fn eig( + fn eig( calc_v: bool, l: MatrixLayout, mut a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); - let jobvr = if calc_v { b'V' } else { b'N' }; - let mut w = vec![Self::Complex::zero(); n as usize]; - let mut vl = Vec::new(); - let mut vr = vec![Self::Complex::zero(); (n * n) as usize]; - $ev( - l.lapacke_layout(), - b'N', - jobvr, - n, - &mut a, - n, - &mut w, - &mut vl, - n, - &mut vr, - n, - ) - .as_lapack_result()?; - Ok((w, vr)) + // Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate. + // However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A + let (jobvl, jobvr) = if calc_v { + match l { + MatrixLayout::C { .. } => (b'V', b'N'), + MatrixLayout::F { .. } => (b'N', b'V'), + } + } else { + (b'N', b'N') + }; + let mut eigs = vec![Self::Complex::zero(); n as usize]; + let mut rwork = vec![Self::Real::zero(); 2 * n as usize]; + + let mut vl = if jobvl == b'V' { + Some(vec![Self::Complex::zero(); (n * n) as usize]) + } else { + None + }; + let mut vr = if jobvr == b'V' { + Some(vec![Self::Complex::zero(); (n * n) as usize]) + } else { + None + }; + + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + $ev( + jobvl, + jobvr, + n, + &mut a, + n, + &mut eigs, + &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + // actal ev + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $ev( + jobvl, + jobvr, + n, + &mut a, + n, + &mut eigs, + &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + // Hermite conjugate + if jobvl == b'V' { + for c in vl.as_mut().unwrap().iter_mut() { + c.im = -c.im + } + } + + Ok((eigs, vr.or(vl).unwrap_or(Vec::new()))) } } }; } +impl_eig_complex!(c64, lapack::zgeev); +impl_eig_complex!(c32, lapack::cgeev); + macro_rules! impl_eig_real { ($scalar:ty, $ev:path) => { impl Eig_ for $scalar { - unsafe fn eig( + fn eig( calc_v: bool, l: MatrixLayout, mut a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); - let jobvr = if calc_v { b'V' } else { b'N' }; - let mut wr = vec![Self::Real::zero(); n as usize]; - let mut wi = vec![Self::Real::zero(); n as usize]; - let mut vl = Vec::new(); - let mut vr = vec![Self::Real::zero(); (n * n) as usize]; - let info = $ev( - l.lapacke_layout(), - b'N', - jobvr, - n, - &mut a, - n, - &mut wr, - &mut wi, - &mut vl, - n, - &mut vr, - n, - ); - let w: Vec = wr + // Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate. + // However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A + let (jobvl, jobvr) = if calc_v { + match l { + MatrixLayout::C { .. } => (b'V', b'N'), + MatrixLayout::F { .. } => (b'N', b'V'), + } + } else { + (b'N', b'N') + }; + let mut eig_re = vec![Self::zero(); n as usize]; + let mut eig_im = vec![Self::zero(); n as usize]; + + let mut vl = if jobvl == b'V' { + Some(vec![Self::zero(); (n * n) as usize]) + } else { + None + }; + let mut vr = if jobvr == b'V' { + Some(vec![Self::zero(); (n * n) as usize]) + } else { + None + }; + + // calc work size + let mut info = 0; + let mut work_size = [0.0]; + unsafe { + $ev( + jobvl, + jobvr, + n, + &mut a, + n, + &mut eig_re, + &mut eig_im, + vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + &mut info, + ) + }; + info.as_lapack_result()?; + + // actual ev + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $ev( + jobvl, + jobvr, + n, + &mut a, + n, + &mut eig_re, + &mut eig_im, + vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + &mut info, + ) + }; + info.as_lapack_result()?; + + // reconstruct eigenvalues + let eigs: Vec = eig_re .iter() - .zip(wi.iter()) - .map(|(&r, &i)| Self::Complex::new(r, i)) + .zip(eig_im.iter()) + .map(|(&re, &im)| Self::complex(re, im)) .collect(); - // If the j-th eigenvalue is real, then - // eigenvector = [ vr[j], vr[j+n], vr[j+2*n], ... ]. + if !calc_v { + return Ok((eigs, Vec::new())); + } + + // Reconstruct eigenvectors into complex-array + // -------------------------------------------- // - // If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, - // eigenvector(j) = [ vr[j] + i*vr[j+1], vr[j+n] + i*vr[j+n+1], vr[j+2*n] + i*vr[j+2*n+1], ... ] and - // eigenvector(j+1) = [ vr[j] - i*vr[j+1], vr[j+n] - i*vr[j+n+1], vr[j+2*n] - i*vr[j+2*n+1], ... ]. + // From LAPACK API https://software.intel.com/en-us/node/469230 // - // Therefore, if eigenvector(j) is written as [ v_{j0}, v_{j1}, v_{j2}, ... ], - // you have to make - // v = vec![ v_{00}, v_{10}, v_{20}, ..., v_{jk}, v_{(j+1)k}, v_{(j+2)k}, ... ] (v.len() = n*n) - // based on wi and vr. - // After that, v is converted to Array2 (see ../eig.rs). + // - If the j-th eigenvalue is real, + // - v(j) = VR(:,j), the j-th column of VR. + // + // - If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, + // - v(j) = VR(:,j) + i*VR(:,j+1) + // - v(j+1) = VR(:,j) - i*VR(:,j+1). + // + // ``` + // j -> <----pair----> <----pair----> + // [ ... (real), (imag), (imag), (imag), (imag), ... ] : eigs + // ^ ^ ^ ^ ^ + // false false true false true : is_conjugate_pair + // ``` let n = n as usize; - let mut flg = false; - let conj: Vec = wi - .iter() - .map(|&i| { - if flg { - flg = false; - -1 - } else if i != 0.0 { - flg = true; - 1 - } else { - 0 + let v = vr.or(vl).unwrap(); + let mut eigvecs = vec![Self::Complex::zero(); n * n]; + let mut is_conjugate_pair = false; // flag for check `j` is complex conjugate + for j in 0..n { + if eig_im[j] == 0.0 { + // j-th eigenvalue is real + for i in 0..n { + eigvecs[i + j * n] = Self::complex(v[i + j * n], 0.0); } - }) - .collect(); - let v: Vec = (0..n * n) - .map(|i| { - let j = i % n; - match conj[j] { - 1 => Self::Complex::new(vr[i], vr[i + 1]), - -1 => Self::Complex::new(vr[i - 1], -vr[i]), - _ => Self::Complex::new(vr[i], 0.0), + } else { + // j-th eigenvalue is complex + // complex conjugated pair can be `j-1` or `j+1` + if is_conjugate_pair { + let j_pair = j - 1; + assert!(j_pair < n); + for i in 0..n { + eigvecs[i + j * n] = Self::complex(v[i + j_pair * n], v[i + j * n]); + } + } else { + let j_pair = j + 1; + assert!(j_pair < n); + for i in 0..n { + eigvecs[i + j * n] = + Self::complex(v[i + j * n], -v[i + j_pair * n]); + } } - }) - .collect(); + is_conjugate_pair = !is_conjugate_pair; + } + } - info.as_lapack_result()?; - Ok((w, v)) + Ok((eigs, eigvecs)) } } }; } -impl_eig_real!(f64, lapacke::dgeev); -impl_eig_real!(f32, lapacke::sgeev); -impl_eig_complex!(c64, lapacke::zgeev); -impl_eig_complex!(c32, lapacke::cgeev); +impl_eig_real!(f64, lapack::dgeev); +impl_eig_real!(f32, lapack::sgeev); diff --git a/ndarray-linalg/src/eig.rs b/ndarray-linalg/src/eig.rs index 51d69804..17f5a1e8 100644 --- a/ndarray-linalg/src/eig.rs +++ b/ndarray-linalg/src/eig.rs @@ -48,13 +48,11 @@ where fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)> { let mut a = self.to_owned(); let layout = a.square_layout()?; - let (s, t) = unsafe { A::eig(true, layout, a.as_allocated_mut()?)? }; - let (n, _) = layout.size(); + let (s, t) = A::eig(true, layout, a.as_allocated_mut()?)?; + let n = layout.len() as usize; Ok(( ArrayBase::from(s), - ArrayBase::from(t) - .into_shape((n as usize, n as usize)) - .unwrap(), + Array2::from_shape_vec((n, n).f(), t).unwrap(), )) } } @@ -74,7 +72,7 @@ where fn eigvals(&self) -> Result { let mut a = self.to_owned(); - let (s, _) = unsafe { A::eig(true, a.square_layout()?, a.as_allocated_mut()?)? }; + let (s, _) = A::eig(true, a.square_layout()?, a.as_allocated_mut()?)?; Ok(ArrayBase::from(s)) } } From f7d93f4578d1277842d8a03dfe315b41a849d0f2 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 6 Jul 2020 01:30:33 +0900 Subject: [PATCH 06/49] Fix unsafe signature --- lax/src/solve.rs | 94 ++++++++++++++++++++----------------- ndarray-linalg/src/solve.rs | 76 +++++++++++++----------------- 2 files changed, 84 insertions(+), 86 deletions(-) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 67af6409..7c39cf88 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -5,65 +5,71 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; use num_traits::Zero; -/// Wraps `*getrf`, `*getri`, and `*getrs` pub trait Solve_: Scalar + Sized { /// Computes the LU factorization of a general `m x n` matrix `a` using /// partial pivoting with row interchanges. /// - /// If the result matches `Err(LinalgError::Lapack(LapackError { - /// return_code )) if return_code > 0`, then `U[(return_code-1, - /// return_code-1)]` is exactly zero. The factorization has been completed, - /// but the factor `U` is exactly singular, and division by zero will occur - /// if it is used to solve a system of equations. - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; - unsafe fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + /// $ PA = LU $ + /// + /// Error + /// ------ + /// - `LapackComputationalFailure { return_code }` when the matrix is singular + /// - `U[(return_code-1, return_code-1)]` is exactly zero. + /// - Division by zero will occur if it is used to solve a system of equations. + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; + + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. /// /// `anorm` should be the 1-norm of the matrix `a`. - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; - unsafe fn solve( - l: MatrixLayout, - t: Transpose, - a: &[Self], - p: &Pivot, - b: &mut [Self], - ) -> Result<()>; + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; + + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_solve { ($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => { impl Solve_ for $scalar { - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); let k = ::std::cmp::min(row, col); let mut ipiv = vec![0; k as usize]; - $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv).as_lapack_result()?; + unsafe { + $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv) + .as_lapack_result()?; + } Ok(ipiv) } - unsafe fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); - $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; + unsafe { + $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; + } Ok(()) } - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { let (n, _) = l.size(); let mut rcond = Self::Real::zero(); - $gecon( - l.lapacke_layout(), - NormType::One as u8, - n, - a, - l.lda(), - anorm, - &mut rcond, - ) + unsafe { + $gecon( + l.lapacke_layout(), + NormType::One as u8, + n, + a, + l.lda(), + anorm, + &mut rcond, + ) + } .as_lapack_result()?; + Ok(rcond) } - unsafe fn solve( + fn solve( l: MatrixLayout, t: Transpose, a: &[Self], @@ -73,18 +79,20 @@ macro_rules! impl_solve { let (n, _) = l.size(); let nrhs = 1; let ldb = 1; - $getrs( - l.lapacke_layout(), - t as u8, - n, - nrhs, - a, - l.lda(), - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + unsafe { + $getrs( + l.lapacke_layout(), + t as u8, + n, + nrhs, + a, + l.lda(), + ipiv, + b, + ldb, + ) + .as_lapack_result()?; + } Ok(()) } } diff --git a/ndarray-linalg/src/solve.rs b/ndarray-linalg/src/solve.rs index 566511f3..fd4b3017 100644 --- a/ndarray-linalg/src/solve.rs +++ b/ndarray-linalg/src/solve.rs @@ -167,15 +167,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::No, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::No, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_t_inplace<'a, Sb>( @@ -185,15 +183,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::Transpose, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::Transpose, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_h_inplace<'a, Sb>( @@ -203,15 +199,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::Hermite, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::Hermite, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } } @@ -273,7 +267,7 @@ where S: DataMut + RawDataClone, { fn factorize_into(mut self) -> Result> { - let ipiv = unsafe { A::lu(self.layout()?, self.as_allocated_mut()?)? }; + let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?; Ok(LUFactorized { a: self, ipiv }) } } @@ -285,7 +279,7 @@ where { fn factorize(&self) -> Result>> { let mut a: Array2 = replicate(self); - let ipiv = unsafe { A::lu(a.layout()?, a.as_allocated_mut()?)? }; + let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?; Ok(LUFactorized { a, ipiv }) } } @@ -312,13 +306,11 @@ where type Output = ArrayBase; fn inv_into(mut self) -> Result> { - unsafe { - A::inv( - self.a.square_layout()?, - self.a.as_allocated_mut()?, - &self.ipiv, - )? - }; + A::inv( + self.a.square_layout()?, + self.a.as_allocated_mut()?, + &self.ipiv, + )?; Ok(self.a) } } @@ -539,13 +531,11 @@ where S: Data + RawDataClone, { fn rcond(&self) -> Result { - unsafe { - Ok(A::rcond( - self.a.layout()?, - self.a.as_allocated()?, - self.a.opnorm_one()?, - )?) - } + Ok(A::rcond( + self.a.layout()?, + self.a.as_allocated()?, + self.a.opnorm_one()?, + )?) } } From a7dc42c00086bbd289c1fbf7e7f83f874503bd16 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 7 Jul 2020 03:14:26 +0900 Subject: [PATCH 07/49] Impl solve using LAPACK --- lax/src/lib.rs | 3 ++ lax/src/rcond.rs | 78 ++++++++++++++++++++++++++++++++ lax/src/solve.rs | 114 +++++++++++++++++------------------------------ 3 files changed, 122 insertions(+), 73 deletions(-) create mode 100644 lax/src/rcond.rs diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 769910db..be88410e 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -70,6 +70,7 @@ pub mod layout; pub mod least_squares; pub mod opnorm; pub mod qr; +pub mod rcond; pub mod solve; pub mod solveh; pub mod svd; @@ -83,6 +84,7 @@ pub use self::eigh::*; pub use self::least_squares::*; pub use self::opnorm::*; pub use self::qr::*; +pub use self::rcond::*; pub use self::solve::*; pub use self::solveh::*; pub use self::svd::*; @@ -107,6 +109,7 @@ pub trait Lapack: + Eigh_ + Triangular_ + Tridiagonal_ + + Rcond_ { } diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs new file mode 100644 index 00000000..7ca24262 --- /dev/null +++ b/lax/src/rcond.rs @@ -0,0 +1,78 @@ +use super::*; +use crate::{error::*, layout::MatrixLayout}; +use cauchy::*; +use num_traits::Zero; + +pub trait Rcond_: Scalar + Sized { + /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. + /// + /// `anorm` should be the 1-norm of the matrix `a`. + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; +} + +macro_rules! impl_rcond_real { + ($scalar:ty, $gecon:path) => { + impl Rcond_ for $scalar { + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + let (n, _) = l.size(); + let mut rcond = Self::Real::zero(); + let mut info = 0; + + let mut work = vec![Self::zero(); 4 * n as usize]; + let mut iwork = vec![0; n as usize]; + unsafe { + $gecon( + NormType::One as u8, + n, + a, + l.lda(), + anorm, + &mut rcond, + &mut work, + &mut iwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + Ok(rcond) + } + } + }; +} + +impl_rcond_real!(f32, lapack::sgecon); +impl_rcond_real!(f64, lapack::dgecon); + +macro_rules! impl_rcond_complex { + ($scalar:ty, $gecon:path) => { + impl Rcond_ for $scalar { + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + let (n, _) = l.size(); + let mut rcond = Self::Real::zero(); + let mut info = 0; + let mut work = vec![Self::zero(); 2 * n as usize]; + let mut rwork = vec![Self::Real::zero(); 2 * n as usize]; + unsafe { + $gecon( + NormType::One as u8, + n, + a, + l.lda(), + anorm, + &mut rcond, + &mut work, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + Ok(rcond) + } + } + }; +} + +impl_rcond_complex!(c32, lapack::cgecon); +impl_rcond_complex!(c64, lapack::zgecon); diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 7c39cf88..a00ce9c3 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -3,7 +3,7 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; pub trait Solve_: Scalar + Sized { /// Computes the LU factorization of a general `m x n` matrix `a` using @@ -14,59 +14,55 @@ pub trait Solve_: Scalar + Sized { /// Error /// ------ /// - `LapackComputationalFailure { return_code }` when the matrix is singular - /// - `U[(return_code-1, return_code-1)]` is exactly zero. - /// - Division by zero will occur if it is used to solve a system of equations. + /// - Division by zero will occur if it is used to solve a system of equations + /// because `U[(return_code-1, return_code-1)]` is exactly zero. fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; - /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. - /// - /// `anorm` should be the 1-norm of the matrix `a`. - fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; - fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => { + ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { impl Solve_ for $scalar { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); + assert_eq!(a.len() as i32, row * col); let k = ::std::cmp::min(row, col); let mut ipiv = vec![0; k as usize]; - unsafe { - $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv) - .as_lapack_result()?; - } + let mut info = 0; + unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) }; + info.as_lapack_result()?; Ok(ipiv) } fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); - unsafe { - $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; - } - Ok(()) - } - fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { - let (n, _) = l.size(); - let mut rcond = Self::Real::zero(); + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) }; + info.as_lapack_result()?; + + // actual + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; unsafe { - $gecon( - l.lapacke_layout(), - NormType::One as u8, - n, + $getri( + l.len(), a, l.lda(), - anorm, - &mut rcond, + ipiv, + &mut work, + lwork as i32, + &mut info, ) - } - .as_lapack_result()?; + }; + info.as_lapack_result()?; - Ok(rcond) + Ok(()) } fn solve( @@ -76,54 +72,26 @@ macro_rules! impl_solve { ipiv: &Pivot, b: &mut [Self], ) -> Result<()> { + let t = match l { + MatrixLayout::C { .. } => match t { + Transpose::No => Transpose::Transpose, + Transpose::Transpose | Transpose::Hermite => Transpose::No, + }, + _ => t, + }; let (n, _) = l.size(); let nrhs = 1; - let ldb = 1; - unsafe { - $getrs( - l.lapacke_layout(), - t as u8, - n, - nrhs, - a, - l.lda(), - ipiv, - b, - ldb, - ) - .as_lapack_result()?; - } + let ldb = l.lda(); + let mut info = 0; + unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; + info.as_lapack_result()?; Ok(()) } } }; } // impl_solve! -impl_solve!( - f64, - lapacke::dgetrf, - lapacke::dgetri, - lapacke::dgecon, - lapacke::dgetrs -); -impl_solve!( - f32, - lapacke::sgetrf, - lapacke::sgetri, - lapacke::sgecon, - lapacke::sgetrs -); -impl_solve!( - c64, - lapacke::zgetrf, - lapacke::zgetri, - lapacke::zgecon, - lapacke::zgetrs -); -impl_solve!( - c32, - lapacke::cgetrf, - lapacke::cgetri, - lapacke::cgecon, - lapacke::cgetrs -); +impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs); +impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs); +impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs); +impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs); From b6e48b9b23f9f18e123d8736dae68e157b75a5aa Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 10 Jul 2020 17:37:20 +0900 Subject: [PATCH 08/49] Handling empty matrix case --- lax/src/solve.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index a00ce9c3..93aa4722 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -29,6 +29,10 @@ macro_rules! impl_solve { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); assert_eq!(a.len() as i32, row * col); + if row == 0 || col == 0 { + // Do nothing for empty matrix + return Ok(Vec::new()); + } let k = ::std::cmp::min(row, col); let mut ipiv = vec![0; k as usize]; let mut info = 0; From c9b6d1330a6d965c3de8da0bb02a24544e7d662f Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 10 Jul 2020 21:33:52 +0900 Subject: [PATCH 09/49] Handle norm type based on matrix layout --- lax/src/rcond.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs index 7ca24262..135c4a12 100644 --- a/lax/src/rcond.rs +++ b/lax/src/rcond.rs @@ -20,9 +20,13 @@ macro_rules! impl_rcond_real { let mut work = vec![Self::zero(); 4 * n as usize]; let mut iwork = vec![0; n as usize]; + let norm_type = match l { + MatrixLayout::C { .. } => NormType::Infinity, + MatrixLayout::F { .. } => NormType::One, + } as u8; unsafe { $gecon( - NormType::One as u8, + norm_type, n, a, l.lda(), @@ -53,9 +57,13 @@ macro_rules! impl_rcond_complex { let mut info = 0; let mut work = vec![Self::zero(); 2 * n as usize]; let mut rwork = vec![Self::Real::zero(); 2 * n as usize]; + let norm_type = match l { + MatrixLayout::C { .. } => NormType::Infinity, + MatrixLayout::F { .. } => NormType::One, + } as u8; unsafe { $gecon( - NormType::One as u8, + norm_type, n, a, l.lda(), From ea9123153b712da58a75c064eb75aa5f3806cd90 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 10 Jul 2020 22:54:44 +0900 Subject: [PATCH 10/49] Impl bk, solveh, and invh using LAPACK --- lax/src/solveh.rs | 72 +++++++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 01e90f13..508b2ec6 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -5,6 +5,7 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; +use num_traits::{ToPrimitive, Zero}; pub trait Solveh_: Sized { /// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf` @@ -28,13 +29,39 @@ macro_rules! impl_solveh { let (n, _) = l.size(); let mut ipiv = vec![0; n as usize]; if n == 0 { - // Work around bug in LAPACKE functions. - Ok(ipiv) - } else { - $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv) - .as_lapack_result()?; - Ok(ipiv) + return Ok(Vec::new()); } + + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + $trf( + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + &mut work_size, + -1, + &mut info, + ); + info.as_lapack_result()?; + + // actual + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $trf( + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + &mut work, + lwork as i32, + &mut info, + ); + info.as_lapack_result()?; + Ok(ipiv) } unsafe fn invh( @@ -44,7 +71,10 @@ macro_rules! impl_solveh { ipiv: &Pivot, ) -> Result<()> { let (n, _) = l.size(); - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv).as_lapack_result()?; + let mut info = 0; + let mut work = vec![Self::zero(); n as usize]; + $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info); + info.as_lapack_result()?; Ok(()) } @@ -56,30 +86,16 @@ macro_rules! impl_solveh { b: &mut [Self], ) -> Result<()> { let (n, _) = l.size(); - let nrhs = 1; - let ldb = match l { - MatrixLayout::C { .. } => 1, - MatrixLayout::F { .. } => n, - }; - $trs( - l.lapacke_layout(), - uplo as u8, - n, - nrhs, - a, - l.lda(), - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + let mut info = 0; + $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info); + info.as_lapack_result()?; Ok(()) } } }; } // impl_solveh! -impl_solveh!(f64, lapacke::dsytrf, lapacke::dsytri, lapacke::dsytrs); -impl_solveh!(f32, lapacke::ssytrf, lapacke::ssytri, lapacke::ssytrs); -impl_solveh!(c64, lapacke::zhetrf, lapacke::zhetri, lapacke::zhetrs); -impl_solveh!(c32, lapacke::chetrf, lapacke::chetri, lapacke::chetrs); +impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs); +impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs); +impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs); +impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs); From 9ccb6af5487f3eae51592a7d49036e2def232fcd Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 10 Jul 2020 22:55:44 +0900 Subject: [PATCH 11/49] Revise deth --- ndarray-linalg/src/solveh.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ndarray-linalg/src/solveh.rs b/ndarray-linalg/src/solveh.rs index c05e2bde..da748f77 100644 --- a/ndarray-linalg/src/solveh.rs +++ b/ndarray-linalg/src/solveh.rs @@ -314,6 +314,7 @@ where S: Data, A: Scalar + Lapack, { + let layout = a.layout().unwrap(); let mut sign = A::Real::one(); let mut ln_det = A::Real::zero(); let mut ipiv_enum = ipiv_iter.enumerate(); @@ -337,9 +338,15 @@ where debug_assert_eq!(lower_diag.im(), Zero::zero()); // Off-diagonal elements, can be complex. - let off_diag = match uplo { - UPLO::Upper => unsafe { a.uget((k, k + 1)) }, - UPLO::Lower => unsafe { a.uget((k + 1, k)) }, + let off_diag = match layout { + MatrixLayout::C { .. } => match uplo { + UPLO::Upper => unsafe { a.uget((k + 1, k)) }, + UPLO::Lower => unsafe { a.uget((k, k + 1)) }, + }, + MatrixLayout::F { .. } => match uplo { + UPLO::Upper => unsafe { a.uget((k, k + 1)) }, + UPLO::Lower => unsafe { a.uget((k + 1, k)) }, + }, }; // Determinant of 2x2 block. From 34c7f5a7d5daa1cf6e8fcb1de9b32f933d19d0bf Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 10 Jul 2020 23:23:04 +0900 Subject: [PATCH 12/49] Drop unsafe of solveh and others in #216 --- lax/src/solveh.rs | 71 ++++++++++++++++-------------------- ndarray-linalg/src/solveh.rs | 34 ++++++++--------- 2 files changed, 47 insertions(+), 58 deletions(-) diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 508b2ec6..da2ecdf5 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -9,23 +9,17 @@ use num_traits::{ToPrimitive, Zero}; pub trait Solveh_: Sized { /// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf` - unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; /// Wrapper of `*sytri` and `*hetri` - unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; /// Wrapper of `*sytrs` and `*hetrs` - unsafe fn solveh( - l: MatrixLayout, - uplo: UPLO, - a: &[Self], - ipiv: &Pivot, - b: &mut [Self], - ) -> Result<()>; + fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_solveh { ($scalar:ty, $trf:path, $tri:path, $trs:path) => { impl Solveh_ for $scalar { - unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { let (n, _) = l.size(); let mut ipiv = vec![0; n as usize]; if n == 0 { @@ -35,50 +29,49 @@ macro_rules! impl_solveh { // calc work size let mut info = 0; let mut work_size = [Self::zero()]; - $trf( - uplo as u8, - n, - a, - l.lda(), - &mut ipiv, - &mut work_size, - -1, - &mut info, - ); + unsafe { + $trf( + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + &mut work_size, + -1, + &mut info, + ) + }; info.as_lapack_result()?; // actual let lwork = work_size[0].to_usize().unwrap(); let mut work = vec![Self::zero(); lwork]; - $trf( - uplo as u8, - n, - a, - l.lda(), - &mut ipiv, - &mut work, - lwork as i32, - &mut info, - ); + unsafe { + $trf( + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + &mut work, + lwork as i32, + &mut info, + ) + }; info.as_lapack_result()?; Ok(ipiv) } - unsafe fn invh( - l: MatrixLayout, - uplo: UPLO, - a: &mut [Self], - ipiv: &Pivot, - ) -> Result<()> { + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); let mut info = 0; let mut work = vec![Self::zero(); n as usize]; - $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info); + unsafe { $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info) }; info.as_lapack_result()?; Ok(()) } - unsafe fn solveh( + fn solveh( l: MatrixLayout, uplo: UPLO, a: &[Self], @@ -87,7 +80,7 @@ macro_rules! impl_solveh { ) -> Result<()> { let (n, _) = l.size(); let mut info = 0; - $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info); + unsafe { $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info) }; info.as_lapack_result()?; Ok(()) } diff --git a/ndarray-linalg/src/solveh.rs b/ndarray-linalg/src/solveh.rs index da748f77..102158aa 100644 --- a/ndarray-linalg/src/solveh.rs +++ b/ndarray-linalg/src/solveh.rs @@ -113,15 +113,13 @@ where where Sb: DataMut, { - unsafe { - A::solveh( - self.a.square_layout()?, - UPLO::Upper, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solveh( + self.a.square_layout()?, + UPLO::Upper, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } } @@ -165,7 +163,7 @@ where S: DataMut, { fn factorizeh_into(mut self) -> Result> { - let ipiv = unsafe { A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)? }; + let ipiv = A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)?; Ok(BKFactorized { a: self, ipiv }) } } @@ -177,7 +175,7 @@ where { fn factorizeh(&self) -> Result>> { let mut a: Array2 = replicate(self); - let ipiv = unsafe { A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)? }; + let ipiv = A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)?; Ok(BKFactorized { a, ipiv }) } } @@ -204,14 +202,12 @@ where type Output = ArrayBase; fn invh_into(mut self) -> Result> { - unsafe { - A::invh( - self.a.square_layout()?, - UPLO::Upper, - self.a.as_allocated_mut()?, - &self.ipiv, - )? - }; + A::invh( + self.a.square_layout()?, + UPLO::Upper, + self.a.as_allocated_mut()?, + &self.ipiv, + )?; triangular_fill_hermitian(&mut self.a, UPLO::Upper); Ok(self.a) } From ea930d0d4aa61b1b9e1a932f840d4507d2b2e804 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 11 Jul 2020 17:06:29 +0900 Subject: [PATCH 13/49] SVD using LAPACK --- lax/src/svd.rs | 199 +++++++++++++++++++++++++++++++------- ndarray-linalg/src/svd.rs | 28 ++++-- 2 files changed, 187 insertions(+), 40 deletions(-) diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 47a9a1be..0e7bb4bb 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -2,9 +2,10 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; #[repr(u8)] +#[derive(Debug, Copy, Clone)] enum FlagSVD { All = b'A', // OverWrite = b'O', @@ -12,6 +13,16 @@ enum FlagSVD { No = b'N', } +impl FlagSVD { + fn from_bool(calc_uv: bool) -> Self { + if calc_uv { + FlagSVD::All + } else { + FlagSVD::No + } + } +} + /// Result of SVD pub struct SVDOutput { /// diagonal values @@ -24,6 +35,7 @@ pub struct SVDOutput { /// Wraps `*gesvd` pub trait SVD_: Scalar { + /// Calculate singular value decomposition $ A = U \Sigma V^T $ unsafe fn svd( l: MatrixLayout, calc_u: bool, @@ -32,7 +44,7 @@ pub trait SVD_: Scalar { ) -> Result>; } -macro_rules! impl_svd { +macro_rules! impl_svd_real { ($scalar:ty, $gesvd:path) => { impl SVD_ for $scalar { unsafe fn svd( @@ -41,48 +53,169 @@ macro_rules! impl_svd { calc_vt: bool, mut a: &mut [Self], ) -> Result> { - let (m, n) = l.size(); - let k = ::std::cmp::min(n, m); - let lda = l.lda(); - let (ju, ldu, mut u) = if calc_u { - (FlagSVD::All, m, vec![Self::zero(); (m * m) as usize]) - } else { - (FlagSVD::No, 1, Vec::new()) + let ju = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), }; - let (jvt, ldvt, mut vt) = if calc_vt { - (FlagSVD::All, n, vec![Self::zero(); (n * n) as usize]) - } else { - (FlagSVD::No, n, Vec::new()) + let jvt = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), + }; + + let m = l.lda(); + let mut u = match ju { + FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]), + FlagSVD::No => None, }; + + let n = l.len(); + let mut vt = match jvt { + FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]), + FlagSVD::No => None, + }; + + let k = std::cmp::min(m, n); let mut s = vec![Self::Real::zero(); k as usize]; - let mut superb = vec![Self::Real::zero(); (k - 1) as usize]; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + &mut info, + ); + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; $gesvd( - l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, - lda, + m, &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - &mut superb, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if calc_u { Some(u) } else { None }, - vt: if calc_vt { Some(vt) } else { None }, - }) + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + &mut info, + ); + info.as_lapack_result()?; + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } + } + } + }; +} // impl_svd_real! + +impl_svd_real!(f64, lapack::dgesvd); +impl_svd_real!(f32, lapack::sgesvd); + +macro_rules! impl_svd_complex { + ($scalar:ty, $gesvd:path) => { + impl SVD_ for $scalar { + unsafe fn svd( + l: MatrixLayout, + calc_u: bool, + calc_vt: bool, + mut a: &mut [Self], + ) -> Result> { + let ju = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), + }; + let jvt = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), + }; + + let m = l.lda(); + let mut u = match ju { + FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]), + FlagSVD::No => None, + }; + + let n = l.len(); + let mut vt = match jvt { + FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]), + FlagSVD::No => None, + }; + + let k = std::cmp::min(m, n); + let mut s = vec![Self::Real::zero(); k as usize]; + + let mut rwork = vec![Self::Real::zero(); 5 * k as usize]; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + &mut rwork, + &mut info, + ); + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + &mut rwork, + &mut info, + ); + info.as_lapack_result()?; + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } } } }; -} // impl_svd! +} // impl_svd_real! -impl_svd!(f64, lapacke::dgesvd); -impl_svd!(f32, lapacke::sgesvd); -impl_svd!(c64, lapacke::zgesvd); -impl_svd!(c32, lapacke::cgesvd); +impl_svd_complex!(c64, lapack::zgesvd); +impl_svd_complex!(c32, lapack::cgesvd); diff --git a/ndarray-linalg/src/svd.rs b/ndarray-linalg/src/svd.rs index 9bb90977..5dce4851 100644 --- a/ndarray-linalg/src/svd.rs +++ b/ndarray-linalg/src/svd.rs @@ -4,7 +4,6 @@ use ndarray::*; -use super::convert::*; use super::error::*; use super::layout::*; use super::types::*; @@ -99,12 +98,27 @@ where let l = self.layout()?; let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? }; let (n, m) = l.size(); - let u = svd_res - .u - .map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches")); - let vt = svd_res - .vt - .map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches")); + let n = n as usize; + let m = m as usize; + + let u = svd_res.u.map(|u| { + assert_eq!(u.len(), n * n); + match l { + MatrixLayout::F { .. } => Array::from_shape_vec((n, n).f(), u), + MatrixLayout::C { .. } => Array::from_shape_vec((n, n), u), + } + .unwrap() + }); + + let vt = svd_res.vt.map(|vt| { + assert_eq!(vt.len(), m * m); + match l { + MatrixLayout::F { .. } => Array::from_shape_vec((m, m).f(), vt), + MatrixLayout::C { .. } => Array::from_shape_vec((m, m), vt), + } + .unwrap() + }); + let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) } From 2373b434e07152d6d6934add395c57783d44699c Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 11 Jul 2020 17:26:12 +0900 Subject: [PATCH 14/49] Add test for complex in SVD --- ndarray-linalg/tests/svd.rs | 60 ++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/ndarray-linalg/tests/svd.rs b/ndarray-linalg/tests/svd.rs index acc6ffca..c83885e1 100644 --- a/ndarray-linalg/tests/svd.rs +++ b/ndarray-linalg/tests/svd.rs @@ -2,7 +2,7 @@ use ndarray::*; use ndarray_linalg::*; use std::cmp::min; -fn test(a: &Array2) { +fn test(a: &Array2) { let (n, m) = a.dim(); let answer = a.clone(); println!("a = \n{:?}", a); @@ -12,14 +12,14 @@ fn test(a: &Array2) { println!("u = \n{:?}", &u); println!("s = \n{:?}", &s); println!("v = \n{:?}", &vt); - let mut sm = Array::zeros((n, m)); + let mut sm = Array::::zeros((n, m)); for i in 0..min(n, m) { - sm[(i, i)] = s[i]; + sm[(i, i)] = T::from(s[i]).unwrap(); } - assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7)); } -fn test_no_vt(a: &Array2) { +fn test_no_vt(a: &Array2) { let (n, _m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap(); @@ -30,7 +30,7 @@ fn test_no_vt(a: &Array2) { assert_eq!(u.dim().1, n); } -fn test_no_u(a: &Array2) { +fn test_no_u(a: &Array2) { let (_n, m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap(); @@ -41,7 +41,7 @@ fn test_no_u(a: &Array2) { assert_eq!(vt.dim().1, m); } -fn test_diag_only(a: &Array2) { +fn test_diag_only(a: &Array2) { println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap(); assert!(u.is_none()); @@ -49,32 +49,44 @@ fn test_diag_only(a: &Array2) { } macro_rules! test_svd_impl { - ($test:ident, $n:expr, $m:expr) => { + ($type:ty, $test:ident, $n:expr, $m:expr) => { paste::item! { #[test] - fn []() { + fn []() { let a = random(($n, $m)); - $test(&a); + $test::<$type>(&a); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - $test(&a); + $test::<$type>(&a); } } }; } -test_svd_impl!(test, 3, 3); -test_svd_impl!(test_no_vt, 3, 3); -test_svd_impl!(test_no_u, 3, 3); -test_svd_impl!(test_diag_only, 3, 3); -test_svd_impl!(test, 4, 3); -test_svd_impl!(test_no_vt, 4, 3); -test_svd_impl!(test_no_u, 4, 3); -test_svd_impl!(test_diag_only, 4, 3); -test_svd_impl!(test, 3, 4); -test_svd_impl!(test_no_vt, 3, 4); -test_svd_impl!(test_no_u, 3, 4); -test_svd_impl!(test_diag_only, 3, 4); +test_svd_impl!(f64, test, 3, 3); +test_svd_impl!(f64, test_no_vt, 3, 3); +test_svd_impl!(f64, test_no_u, 3, 3); +test_svd_impl!(f64, test_diag_only, 3, 3); +test_svd_impl!(f64, test, 4, 3); +test_svd_impl!(f64, test_no_vt, 4, 3); +test_svd_impl!(f64, test_no_u, 4, 3); +test_svd_impl!(f64, test_diag_only, 4, 3); +test_svd_impl!(f64, test, 3, 4); +test_svd_impl!(f64, test_no_vt, 3, 4); +test_svd_impl!(f64, test_no_u, 3, 4); +test_svd_impl!(f64, test_diag_only, 3, 4); +test_svd_impl!(c64, test, 3, 3); +test_svd_impl!(c64, test_no_vt, 3, 3); +test_svd_impl!(c64, test_no_u, 3, 3); +test_svd_impl!(c64, test_diag_only, 3, 3); +test_svd_impl!(c64, test, 4, 3); +test_svd_impl!(c64, test_no_vt, 4, 3); +test_svd_impl!(c64, test_no_u, 4, 3); +test_svd_impl!(c64, test_diag_only, 4, 3); +test_svd_impl!(c64, test, 3, 4); +test_svd_impl!(c64, test_no_vt, 3, 4); +test_svd_impl!(c64, test_no_u, 3, 4); +test_svd_impl!(c64, test_diag_only, 3, 4); From cb3010784aceccd36ee46735d7d8f56c379a309a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 12 Jul 2020 17:16:51 +0900 Subject: [PATCH 15/49] Split real/complex of svddc --- lax/src/svddc.rs | 60 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index 84f8394b..35c5076b 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -21,7 +21,7 @@ pub trait SVDDC_: Scalar { unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } -macro_rules! impl_svdd { +macro_rules! impl_svddc_real { ($scalar:ty, $gesdd:path) => { impl SVDDC_ for $scalar { unsafe fn svddc( @@ -70,7 +70,57 @@ macro_rules! impl_svdd { }; } -impl_svdd!(f32, lapacke::sgesdd); -impl_svdd!(f64, lapacke::dgesdd); -impl_svdd!(c32, lapacke::cgesdd); -impl_svdd!(c64, lapacke::zgesdd); +impl_svddc_real!(f32, lapacke::sgesdd); +impl_svddc_real!(f64, lapacke::dgesdd); + +macro_rules! impl_svddc_complex { + ($scalar:ty, $gesdd:path) => { + impl SVDDC_ for $scalar { + unsafe fn svddc( + l: MatrixLayout, + jobz: UVTFlag, + mut a: &mut [Self], + ) -> Result> { + let (m, n) = l.size(); + let k = m.min(n); + let lda = l.lda(); + let (ucol, vtrow) = match jobz { + UVTFlag::Full => (m, n), + UVTFlag::Some => (k, k), + UVTFlag::None => (1, 1), + }; + let mut s = vec![Self::Real::zero(); k.max(1) as usize]; + let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; + let ldu = l.resized(m, ucol).lda(); + let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; + let ldvt = l.resized(vtrow, n).lda(); + $gesdd( + l.lapacke_layout(), + jobz as u8, + m, + n, + &mut a, + lda, + &mut s, + &mut u, + ldu, + &mut vt, + ldvt, + ) + .as_lapack_result()?; + Ok(SVDOutput { + s, + u: if jobz == UVTFlag::None { None } else { Some(u) }, + vt: if jobz == UVTFlag::None { + None + } else { + Some(vt) + }, + }) + } + } + }; +} + +impl_svddc_complex!(c32, lapacke::cgesdd); +impl_svddc_complex!(c64, lapacke::zgesdd); From 8afd53c65226ae5f249b7073391525a5ed8b4bc7 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 12 Jul 2020 21:21:40 +0900 Subject: [PATCH 16/49] Rewrite impl_svddc_real --- lax/src/svddc.rs | 92 +++++++++++++++++++++++++------------ ndarray-linalg/src/svddc.rs | 42 +++++++++++------ 2 files changed, 91 insertions(+), 43 deletions(-) diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index 35c5076b..d46ef55a 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -1,7 +1,7 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; /// Specifies how many of the columns of *U* and rows of *V*áµ€ are computed and returned. /// @@ -29,49 +29,81 @@ macro_rules! impl_svddc_real { jobz: UVTFlag, mut a: &mut [Self], ) -> Result> { - let (m, n) = l.size(); + let m = l.lda(); + let n = l.len(); let k = m.min(n); - let lda = l.lda(); - let (ucol, vtrow) = match jobz { - UVTFlag::Full => (m, n), + let mut s = vec![Self::Real::zero(); k as usize]; + + let (u_col, vt_row) = match jobz { + UVTFlag::Full | UVTFlag::None => (m, n), UVTFlag::Some => (k, k), - UVTFlag::None => (1, 1), }; - let mut s = vec![Self::Real::zero(); k.max(1) as usize]; - let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; - let ldu = l.resized(m, ucol).lda(); - let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; - let ldvt = l.resized(vtrow, n).lda(); + let (mut u, mut vt) = match jobz { + UVTFlag::Full => ( + Some(vec![Self::zero(); (m * m) as usize]), + Some(vec![Self::zero(); (n * n) as usize]), + ), + UVTFlag::Some => ( + Some(vec![Self::zero(); (m * u_col) as usize]), + Some(vec![Self::zero(); (n * vt_row) as usize]), + ), + UVTFlag::None => (None, None), + }; + + // eval work size + let mut info = 0; + let mut iwork = vec![0; 8 * k as usize]; + let mut work_size = [Self::zero()]; $gesdd( - l.lapacke_layout(), jobz as u8, m, n, &mut a, - lda, + m, &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if jobz == UVTFlag::None { None } else { Some(u) }, - vt: if jobz == UVTFlag::None { - None - } else { - Some(vt) - }, - }) + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work_size, + -1, + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + // do svd + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $gesdd( + jobz as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work, + lwork as i32, + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } } } }; } -impl_svddc_real!(f32, lapacke::sgesdd); -impl_svddc_real!(f64, lapacke::dgesdd); +impl_svddc_real!(f32, lapack::sgesdd); +impl_svddc_real!(f64, lapack::dgesdd); macro_rules! impl_svddc_complex { ($scalar:ty, $gesdd:path) => { diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index 22f3ae0c..7045fa1e 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -2,7 +2,6 @@ use ndarray::*; -use super::convert::*; use super::error::*; use super::layout::*; use super::types::*; @@ -85,19 +84,36 @@ where ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; - let (m, n) = l.size(); - let k = m.min(n); - let (ldu, tdu, ldvt, tdvt) = match uvt_flag { - UVTFlag::Full => (m, m, n, n), - UVTFlag::Some => (m, k, k, n), - UVTFlag::None => (1, 1, 1, 1), + let (n, m) = l.size(); + let k = std::cmp::min(n, m); + let n = n as usize; + let m = m as usize; + let k = k as usize; + + let (u_col, vt_row) = match uvt_flag { + UVTFlag::Full => (n, m), + UVTFlag::Some => (k, k), + UVTFlag::None => (0, 0), }; - let u = svd_res - .u - .map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches")); - let vt = svd_res - .vt - .map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches")); + + let u = svd_res.u.map(|u| { + assert_eq!(u.len(), n * u_col); + match l { + MatrixLayout::F { .. } => Array::from_shape_vec((n, u_col).f(), u), + MatrixLayout::C { .. } => Array::from_shape_vec((n, u_col), u), + } + .unwrap() + }); + + let vt = svd_res.vt.map(|vt| { + assert_eq!(vt.len(), m * vt_row); + match l { + MatrixLayout::F { .. } => Array::from_shape_vec((vt_row, m).f(), vt), + MatrixLayout::C { .. } => Array::from_shape_vec((vt_row, m), vt), + } + .unwrap() + }); + let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) } From 797b71b07f06daf4d7e2ea4a4159876f3464d304 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 13 Jul 2020 21:53:40 +0900 Subject: [PATCH 17/49] Add svddc test for complex numbers --- ndarray-linalg/tests/svddc.rs | 43 +++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ndarray-linalg/tests/svddc.rs b/ndarray-linalg/tests/svddc.rs index 2c9204c8..fb26c8d5 100644 --- a/ndarray-linalg/tests/svddc.rs +++ b/ndarray-linalg/tests/svddc.rs @@ -1,13 +1,13 @@ use ndarray::*; use ndarray_linalg::*; -fn test(a: &Array2, flag: UVTFlag) { +fn test(a: &Array2, flag: UVTFlag) { let (n, m) = a.dim(); let k = n.min(m); let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap(); - let mut sm = match flag { + let mut sm: Array2 = match flag { UVTFlag::Full => Array::zeros((n, m)), UVTFlag::Some => Array::zeros((k, k)), UVTFlag::None => { @@ -22,53 +22,56 @@ fn test(a: &Array2, flag: UVTFlag) { println!("s = \n{:?}", &s); println!("v = \n{:?}", &vt); for i in 0..k { - sm[(i, i)] = s[i]; + sm[(i, i)] = T::from_real(s[i]); } - assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7)); } macro_rules! test_svd_impl { - ($n:expr, $m:expr) => { + ($scalar:ty, $n:expr, $m:expr) => { paste::item! { #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::Full); + test::<$scalar>(&a, UVTFlag::Full); } #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::Some); + test::<$scalar>(&a, UVTFlag::Some); } #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::None); + test::<$scalar>(&a, UVTFlag::None); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::Full); + test::<$scalar>(&a, UVTFlag::Full); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::Some); + test::<$scalar>(&a, UVTFlag::Some); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::None); + test::<$scalar>(&a, UVTFlag::None); } } }; } -test_svd_impl!(3, 3); -test_svd_impl!(4, 3); -test_svd_impl!(3, 4); +test_svd_impl!(f64, 3, 3); +test_svd_impl!(f64, 4, 3); +test_svd_impl!(f64, 3, 4); +test_svd_impl!(c64, 3, 3); +test_svd_impl!(c64, 4, 3); +test_svd_impl!(c64, 3, 4); From acb88aa5a105510d39a5286b796312380b044026 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 13 Jul 2020 23:55:00 +0900 Subject: [PATCH 18/49] Impl SVDDC_ for c32/c64 --- lax/src/svddc.rs | 80 +++++++++++++++--------------------------------- 1 file changed, 24 insertions(+), 56 deletions(-) diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index d46ef55a..3e50d7bb 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -21,8 +21,14 @@ pub trait SVDDC_: Scalar { unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } -macro_rules! impl_svddc_real { - ($scalar:ty, $gesdd:path) => { +macro_rules! impl_svddc { + (@real, $scalar:ty, $gesdd:path) => { + impl_svddc!(@body, $scalar, $gesdd, ); + }; + (@complex, $scalar:ty, $gesdd:path) => { + impl_svddc!(@body, $scalar, $gesdd, rwork); + }; + (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { unsafe fn svddc( l: MatrixLayout, @@ -50,6 +56,16 @@ macro_rules! impl_svddc_real { UVTFlag::None => (None, None), }; + $( // for complex only + let mx = n.max(m) as usize; + let mn = n.min(m) as usize; + let lrwork = match jobz { + UVTFlag::None => 7 * mn, + _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), + }; + let mut $rwork_ident = vec![Self::Real::zero(); lrwork]; + )* + // eval work size let mut info = 0; let mut iwork = vec![0; 8 * k as usize]; @@ -67,6 +83,7 @@ macro_rules! impl_svddc_real { vt_row, &mut work_size, -1, + $(&mut $rwork_ident,)* &mut iwork, &mut info, ); @@ -88,6 +105,7 @@ macro_rules! impl_svddc_real { vt_row, &mut work, lwork as i32, + $(&mut $rwork_ident,)* &mut iwork, &mut info, ); @@ -102,57 +120,7 @@ macro_rules! impl_svddc_real { }; } -impl_svddc_real!(f32, lapack::sgesdd); -impl_svddc_real!(f64, lapack::dgesdd); - -macro_rules! impl_svddc_complex { - ($scalar:ty, $gesdd:path) => { - impl SVDDC_ for $scalar { - unsafe fn svddc( - l: MatrixLayout, - jobz: UVTFlag, - mut a: &mut [Self], - ) -> Result> { - let (m, n) = l.size(); - let k = m.min(n); - let lda = l.lda(); - let (ucol, vtrow) = match jobz { - UVTFlag::Full => (m, n), - UVTFlag::Some => (k, k), - UVTFlag::None => (1, 1), - }; - let mut s = vec![Self::Real::zero(); k.max(1) as usize]; - let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; - let ldu = l.resized(m, ucol).lda(); - let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; - let ldvt = l.resized(vtrow, n).lda(); - $gesdd( - l.lapacke_layout(), - jobz as u8, - m, - n, - &mut a, - lda, - &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if jobz == UVTFlag::None { None } else { Some(u) }, - vt: if jobz == UVTFlag::None { - None - } else { - Some(vt) - }, - }) - } - } - }; -} - -impl_svddc_complex!(c32, lapacke::cgesdd); -impl_svddc_complex!(c64, lapacke::zgesdd); +impl_svddc!(@real, f32, lapack::sgesdd); +impl_svddc!(@real, f64, lapack::dgesdd); +impl_svddc!(@complex, c32, lapack::cgesdd); +impl_svddc!(@complex, c64, lapack::zgesdd); From cf56af5336ee83a8658b7a7a6eef7c94684df66c Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 14 Jul 2020 00:04:31 +0900 Subject: [PATCH 19/49] Use convert::into_matrix --- ndarray-linalg/src/svddc.rs | 36 ++++++++++-------------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index 7045fa1e..b27212ec 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -1,11 +1,8 @@ //! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd) +use super::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -use super::error::*; -use super::layout::*; -use super::types::*; - pub use lapack::svddc::UVTFlag; /// Singular-value decomposition of matrix (copying) by divide-and-conquer @@ -84,35 +81,22 @@ where ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; - let (n, m) = l.size(); - let k = std::cmp::min(n, m); - let n = n as usize; - let m = m as usize; - let k = k as usize; + let (m, n) = l.size(); + let k = m.min(n); let (u_col, vt_row) = match uvt_flag { - UVTFlag::Full => (n, m), + UVTFlag::Full => (m, n), UVTFlag::Some => (k, k), UVTFlag::None => (0, 0), }; - let u = svd_res.u.map(|u| { - assert_eq!(u.len(), n * u_col); - match l { - MatrixLayout::F { .. } => Array::from_shape_vec((n, u_col).f(), u), - MatrixLayout::C { .. } => Array::from_shape_vec((n, u_col), u), - } - .unwrap() - }); + let u = svd_res + .u + .map(|u| into_matrix(l.resized(m, u_col), u).unwrap()); - let vt = svd_res.vt.map(|vt| { - assert_eq!(vt.len(), m * vt_row); - match l { - MatrixLayout::F { .. } => Array::from_shape_vec((vt_row, m).f(), vt), - MatrixLayout::C { .. } => Array::from_shape_vec((vt_row, m), vt), - } - .unwrap() - }); + let vt = svd_res + .vt + .map(|vt| into_matrix(l.resized(vt_row, n), vt).unwrap()); let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) From 8b23d8cfdf0cd7ef5ea94846af183655fbc14255 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 15 Jul 2020 21:31:52 +0900 Subject: [PATCH 20/49] Merge impl_svd_{real,complex} macros --- lax/src/svd.rs | 110 ++++++++----------------------------------------- 1 file changed, 17 insertions(+), 93 deletions(-) diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 0e7bb4bb..d51ffd27 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -44,94 +44,14 @@ pub trait SVD_: Scalar { ) -> Result>; } -macro_rules! impl_svd_real { - ($scalar:ty, $gesvd:path) => { - impl SVD_ for $scalar { - unsafe fn svd( - l: MatrixLayout, - calc_u: bool, - calc_vt: bool, - mut a: &mut [Self], - ) -> Result> { - let ju = match l { - MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), - MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), - }; - let jvt = match l { - MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), - MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), - }; - - let m = l.lda(); - let mut u = match ju { - FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]), - FlagSVD::No => None, - }; - - let n = l.len(); - let mut vt = match jvt { - FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]), - FlagSVD::No => None, - }; - - let k = std::cmp::min(m, n); - let mut s = vec![Self::Real::zero(); k as usize]; - - // eval work size - let mut info = 0; - let mut work_size = [Self::zero()]; - $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, - &mut info, - ); - info.as_lapack_result()?; - - // calc - let lwork = work_size[0].to_usize().unwrap(); - let mut work = vec![Self::zero(); lwork]; - $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, - &mut info, - ); - info.as_lapack_result()?; - match l { - MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), - MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), - } - } - } +macro_rules! impl_svd { + (@real, $scalar:ty, $gesvd:path) => { + impl_svd!(@body, $scalar, $gesvd, ); }; -} // impl_svd_real! - -impl_svd_real!(f64, lapack::dgesvd); -impl_svd_real!(f32, lapack::sgesvd); - -macro_rules! impl_svd_complex { - ($scalar:ty, $gesvd:path) => { + (@complex, $scalar:ty, $gesvd:path) => { + impl_svd!(@body, $scalar, $gesvd, rwork); + }; + (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => { impl SVD_ for $scalar { unsafe fn svd( l: MatrixLayout, @@ -163,7 +83,9 @@ macro_rules! impl_svd_complex { let k = std::cmp::min(m, n); let mut s = vec![Self::Real::zero(); k as usize]; - let mut rwork = vec![Self::Real::zero(); 5 * k as usize]; + $( + let mut $rwork_ident = vec![Self::Real::zero(); 5 * k as usize]; + )* // eval work size let mut info = 0; @@ -182,7 +104,7 @@ macro_rules! impl_svd_complex { n, &mut work_size, -1, - &mut rwork, + $(&mut $rwork_ident,)* &mut info, ); info.as_lapack_result()?; @@ -204,7 +126,7 @@ macro_rules! impl_svd_complex { n, &mut work, lwork as i32, - &mut rwork, + $(&mut $rwork_ident,)* &mut info, ); info.as_lapack_result()?; @@ -215,7 +137,9 @@ macro_rules! impl_svd_complex { } } }; -} // impl_svd_real! +} // impl_svd! -impl_svd_complex!(c64, lapack::zgesvd); -impl_svd_complex!(c32, lapack::cgesvd); +impl_svd!(@real, f64, lapack::dgesvd); +impl_svd!(@real, f32, lapack::sgesvd); +impl_svd!(@complex, c64, lapack::zgesvd); +impl_svd!(@complex, c32, lapack::cgesvd); From e3a77678d789e6938a641433134dfd8a50acf2be Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 15 Jul 2020 21:46:24 +0900 Subject: [PATCH 21/49] Merge impl_eigh! and impl_eighc! --- lax/src/eigh.rs | 112 +++++++++++------------------------------------- 1 file changed, 25 insertions(+), 87 deletions(-) diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index f43ece13..0920dfa1 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -25,7 +25,13 @@ pub trait Eigh_: Scalar { } macro_rules! impl_eigh { - ($scalar:ty, $ev:path, $evg:path) => { + (@real, $scalar:ty, $ev:path, $evg:path) => { + impl_eigh!(@body, $scalar, $ev, $evg, ); + }; + (@complex, $scalar:ty, $ev:path, $evg:path) => { + impl_eigh!(@body, $scalar, $ev, $evg, rwork); + }; + (@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => { impl Eigh_ for $scalar { fn eigh( calc_v: bool, @@ -37,11 +43,14 @@ macro_rules! impl_eigh { let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; let mut eigs = vec![Self::Real::zero(); n as usize]; - let n = n as i32; + + $( + let mut $rwork_ident = vec![Self::Real::zero(); 3 * n as usize - 2]; + )* // calc work size let mut info = 0; - let mut work_size = [0.0]; + let mut work_size = [Self::zero()]; unsafe { $ev( jobz, @@ -52,6 +61,7 @@ macro_rules! impl_eigh { &mut eigs, &mut work_size, -1, + $(&mut $rwork_ident,)* &mut info, ); } @@ -70,6 +80,7 @@ macro_rules! impl_eigh { &mut eigs, &mut work, lwork as i32, + $(&mut $rwork_ident,)* &mut info, ); } @@ -88,11 +99,14 @@ macro_rules! impl_eigh { let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; let mut eigs = vec![Self::Real::zero(); n as usize]; - let n = n as i32; + + $( + let mut $rwork_ident = vec![Self::Real::zero(); 3 * n as usize - 2]; + )* // calc work size let mut info = 0; - let mut work_size = [0.0]; + let mut work_size = [Self::zero()]; unsafe { $evg( &[1], @@ -106,6 +120,7 @@ macro_rules! impl_eigh { &mut eigs, &mut work_size, -1, + $(&mut $rwork_ident,)* &mut info, ); } @@ -127,6 +142,7 @@ macro_rules! impl_eigh { &mut eigs, &mut work, lwork as i32, + $(&mut $rwork_ident,)* &mut info, ); } @@ -137,85 +153,7 @@ macro_rules! impl_eigh { }; } // impl_eigh! -impl_eigh!(f64, lapack::dsyev, lapack::dsygv); -impl_eigh!(f32, lapack::ssyev, lapack::ssygv); - -// splitted for RWORK -macro_rules! impl_eighc { - ($scalar:ty, $ev:path, $evg:path) => { - impl Eigh_ for $scalar { - fn eigh( - calc_v: bool, - layout: MatrixLayout, - uplo: UPLO, - mut a: &mut [Self], - ) -> Result> { - assert_eq!(layout.len(), layout.lda()); - let n = layout.len(); - let jobz = if calc_v { b'V' } else { b'N' }; - let mut eigs = vec![Self::Real::zero(); n as usize]; - let mut work = vec![Self::zero(); 2 * n as usize - 1]; - let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2]; - let mut info = 0; - let n = n as i32; - - unsafe { - $ev( - jobz, - uplo as u8, - n, - &mut a, - n, - &mut eigs, - &mut work, - 2 * n - 1, - &mut rwork, - &mut info, - ) - }; - info.as_lapack_result()?; - Ok(eigs) - } - - fn eigh_generalized( - calc_v: bool, - layout: MatrixLayout, - uplo: UPLO, - mut a: &mut [Self], - mut b: &mut [Self], - ) -> Result> { - assert_eq!(layout.len(), layout.lda()); - let n = layout.len(); - let jobz = if calc_v { b'V' } else { b'N' }; - let mut eigs = vec![Self::Real::zero(); n as usize]; - let mut work = vec![Self::zero(); 2 * n as usize - 1]; - let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2]; - let n = n as i32; - let mut info = 0; - - unsafe { - $evg( - &[1], - jobz, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut eigs, - &mut work, - 2 * n - 1, - &mut rwork, - &mut info, - ) - }; - info.as_lapack_result()?; - Ok(eigs) - } - } - }; -} // impl_eigh! - -impl_eighc!(c64, lapack::zheev, lapack::zhegv); -impl_eighc!(c32, lapack::cheev, lapack::chegv); +impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv); +impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv); +impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv); +impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv); From cbe6df7388bab07a5ac459fb527f2a526ccfc8a1 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 16 Jul 2020 22:19:11 +0900 Subject: [PATCH 22/49] Rewrite QR using LAPACK --- lax/src/qr.rs | 149 +++++++++++++++++++++++++++++++++------ ndarray-linalg/src/qr.rs | 4 +- 2 files changed, 131 insertions(+), 22 deletions(-) diff --git a/lax/src/qr.rs b/lax/src/qr.rs index 6c26273d..0bb00c2a 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -2,35 +2,120 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; -use std::cmp::min; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*geqrf` and `*orgqr` (`*ungqr` for complex numbers) pub trait QR_: Sized { - unsafe fn householder(l: MatrixLayout, a: &mut [Self]) -> Result>; - unsafe fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; - unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; + /// Execute Householder reflection as the first step of QR-decomposition + /// + /// For C-continuous array, + /// this will call LQ-decomposition of the transposed matrix $ A^T = LQ^T $ + fn householder(l: MatrixLayout, a: &mut [Self]) -> Result>; + + /// Reconstruct Q-matrix from Householder-reflectors + fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; + + /// Execute QR-decomposition at once + fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; } macro_rules! impl_qr { - ($scalar:ty, $qrf:path, $gqr:path) => { + ($scalar:ty, $qrf:path, $lqf:path, $gqr:path, $glq:path) => { impl QR_ for $scalar { - unsafe fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { - let (row, col) = l.size(); - let k = min(row, col); + fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { + let m = l.lda(); + let n = l.len(); + let k = m.min(n); let mut tau = vec![Self::zero(); k as usize]; - $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau).as_lapack_result()?; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $qrf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + } + MatrixLayout::C { .. } => { + $lqf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + } + } + } + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $qrf( + m, + n, + &mut a, + m, + &mut tau, + &mut work, + lwork as i32, + &mut info, + ); + } + MatrixLayout::C { .. } => { + $lqf( + m, + n, + &mut a, + m, + &mut tau, + &mut work, + lwork as i32, + &mut info, + ); + } + } + } + info.as_lapack_result()?; + Ok(tau) } - unsafe fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { - let (row, col) = l.size(); - let k = min(row, col); - $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau).as_lapack_result()?; + fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { + let m = l.lda(); + let n = l.len(); + let k = m.min(n); + assert_eq!(tau.len(), k as usize); + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $gqr(m, k, k, &mut a, m, &tau, &mut work_size, -1, &mut info) + } + MatrixLayout::C { .. } => { + $glq(k, n, k, &mut a, m, &tau, &mut work_size, -1, &mut info) + } + } + }; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $gqr(m, k, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) + } + MatrixLayout::C { .. } => { + $glq(k, n, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) + } + } + } + info.as_lapack_result()?; Ok(()) } - unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result> { + fn qr(l: MatrixLayout, a: &mut [Self]) -> Result> { let tau = Self::householder(l, a)?; let r = Vec::from(&*a); Self::q(l, a, &tau)?; @@ -40,7 +125,31 @@ macro_rules! impl_qr { }; } // endmacro -impl_qr!(f64, lapacke::dgeqrf, lapacke::dorgqr); -impl_qr!(f32, lapacke::sgeqrf, lapacke::sorgqr); -impl_qr!(c64, lapacke::zgeqrf, lapacke::zungqr); -impl_qr!(c32, lapacke::cgeqrf, lapacke::cungqr); +impl_qr!( + f64, + lapack::dgeqrf, + lapack::dgelqf, + lapack::dorgqr, + lapack::dorglq +); +impl_qr!( + f32, + lapack::sgeqrf, + lapack::sgelqf, + lapack::sorgqr, + lapack::sorglq +); +impl_qr!( + c64, + lapack::zgeqrf, + lapack::zgelqf, + lapack::zungqr, + lapack::zunglq +); +impl_qr!( + c32, + lapack::cgeqrf, + lapack::cgelqf, + lapack::cungqr, + lapack::cunglq +); diff --git a/ndarray-linalg/src/qr.rs b/ndarray-linalg/src/qr.rs index be2de0c2..ae7b2c25 100644 --- a/ndarray-linalg/src/qr.rs +++ b/ndarray-linalg/src/qr.rs @@ -61,7 +61,7 @@ where fn qr_square_inplace(&mut self) -> Result<(&mut Self, Self::R)> { let l = self.square_layout()?; - let r = unsafe { A::qr(l, self.as_allocated_mut()?)? }; + let r = A::qr(l, self.as_allocated_mut()?)?; let r: Array2<_> = into_matrix(l, r)?; Ok((self, r.into_triangular(UPLO::Upper))) } @@ -107,7 +107,7 @@ where let m = self.ncols(); let k = ::std::cmp::min(n, m); let l = self.layout()?; - let r = unsafe { A::qr(l, self.as_allocated_mut()?)? }; + let r = A::qr(l, self.as_allocated_mut()?)?; let r: Array2<_> = into_matrix(l, r)?; let q = self; Ok((take_slice(&q, n, k), take_slice_upper(&r, k, m))) From 5f2b59898fd2c3c4ad8e412a3ac65740691c9083 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 17 Jul 2020 20:37:44 +0900 Subject: [PATCH 23/49] clippy mem::replace must_use warning --- ndarray-linalg/src/convert.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndarray-linalg/src/convert.rs b/ndarray-linalg/src/convert.rs index 307eb477..1742ac4d 100644 --- a/ndarray-linalg/src/convert.rs +++ b/ndarray-linalg/src/convert.rs @@ -91,7 +91,7 @@ where { let l = a.layout()?.toggle_order(); let new = clone_with_layout(l, a); - ::std::mem::replace(a, new); + *a = new; Ok(a) } From 4a96e4d59d8d0b9efe653f1cd6965b09bd1c1188 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 17 Jul 2020 20:42:44 +0900 Subject: [PATCH 24/49] Use into_matrix --- ndarray-linalg/src/svd.rs | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/ndarray-linalg/src/svd.rs b/ndarray-linalg/src/svd.rs index 5dce4851..486982cd 100644 --- a/ndarray-linalg/src/svd.rs +++ b/ndarray-linalg/src/svd.rs @@ -2,12 +2,9 @@ //! //! [Wikipedia article on SVD](https://en.wikipedia.org/wiki/Singular_value_decomposition) +use crate::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -use super::error::*; -use super::layout::*; -use super::types::*; - /// singular-value decomposition of matrix reference pub trait SVD { type U; @@ -98,27 +95,11 @@ where let l = self.layout()?; let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? }; let (n, m) = l.size(); - let n = n as usize; - let m = m as usize; - - let u = svd_res.u.map(|u| { - assert_eq!(u.len(), n * n); - match l { - MatrixLayout::F { .. } => Array::from_shape_vec((n, n).f(), u), - MatrixLayout::C { .. } => Array::from_shape_vec((n, n), u), - } - .unwrap() - }); - - let vt = svd_res.vt.map(|vt| { - assert_eq!(vt.len(), m * m); - match l { - MatrixLayout::F { .. } => Array::from_shape_vec((m, m).f(), vt), - MatrixLayout::C { .. } => Array::from_shape_vec((m, m), vt), - } - .unwrap() - }); + let u = svd_res.u.map(|u| into_matrix(l.resized(n, n), u).unwrap()); + let vt = svd_res + .vt + .map(|vt| into_matrix(l.resized(m, m), vt).unwrap()); let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) } From e35234744e411b1087950e546a5486f0c09f1e20 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 17 Jul 2020 20:48:22 +0900 Subject: [PATCH 25/49] Drop unsafe of cholesky --- lax/src/cholesky.rs | 29 ++++++++++++++++++----------- ndarray-linalg/src/cholesky.rs | 18 ++++++++---------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index ef9473b4..673aaa78 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -8,32 +8,37 @@ pub trait Cholesky_: Sized { /// Cholesky: wrapper of `*potrf` /// /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + /// Wrapper of `*potri` /// /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + /// Wrapper of `*potrs` - unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) - -> Result<()>; + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } macro_rules! impl_cholesky { ($scalar:ty, $trf:path, $tri:path, $trs:path) => { impl Cholesky_ for $scalar { - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); - $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; + unsafe { + $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; + } Ok(()) } - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; + unsafe { + $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; + } Ok(()) } - unsafe fn solve_cholesky( + fn solve_cholesky( l: MatrixLayout, uplo: UPLO, a: &[Self], @@ -42,8 +47,10 @@ macro_rules! impl_cholesky { let (n, _) = l.size(); let nrhs = 1; let ldb = 1; - $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) - .as_lapack_result()?; + unsafe { + $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) + .as_lapack_result()?; + } Ok(()) } } diff --git a/ndarray-linalg/src/cholesky.rs b/ndarray-linalg/src/cholesky.rs index 79e240ab..3f445305 100644 --- a/ndarray-linalg/src/cholesky.rs +++ b/ndarray-linalg/src/cholesky.rs @@ -155,7 +155,7 @@ where fn invc_into(self) -> Result { let mut a = self.factor; - unsafe { A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)? }; + A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)?; triangular_fill_hermitian(&mut a, self.uplo); Ok(a) } @@ -173,14 +173,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_cholesky( - self.factor.square_layout()?, - self.uplo, - self.factor.as_allocated()?, - b.as_slice_mut().unwrap(), - )? - }; + A::solve_cholesky( + self.factor.square_layout()?, + self.uplo, + self.factor.as_allocated()?, + b.as_slice_mut().unwrap(), + )?; Ok(b) } } @@ -259,7 +257,7 @@ where S: DataMut, { fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> { - unsafe { A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)? }; + A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?; Ok(self.into_triangular(uplo)) } } From df396b516dc2fee4ec12178df7a90bc94ff91bfe Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 17 Jul 2020 21:34:08 +0900 Subject: [PATCH 26/49] WIP --- lax/src/cholesky.rs | 34 +++++++++++++++++++++++++--------- lax/src/lib.rs | 9 +++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 673aaa78..93c5147b 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -24,17 +24,29 @@ macro_rules! impl_cholesky { impl Cholesky_ for $scalar { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); + let mut info = 0; + let uplo = match l { + MatrixLayout::F { .. } => uplo, + MatrixLayout::C { .. } => uplo.t(), + }; unsafe { - $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; + $trf(uplo as u8, n, a, n, &mut info); } + info.as_lapack_result()?; Ok(()) } fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); + let mut info = 0; + let uplo = match l { + MatrixLayout::F { .. } => uplo, + MatrixLayout::C { .. } => uplo.t(), + }; unsafe { - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; + $tri(uplo as u8, n, a, l.lda(), &mut info); } + info.as_lapack_result()?; Ok(()) } @@ -46,18 +58,22 @@ macro_rules! impl_cholesky { ) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; - let ldb = 1; + let uplo = match l { + MatrixLayout::F { .. } => uplo, + MatrixLayout::C { .. } => uplo.t(), + }; + let mut info = 0; unsafe { - $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) - .as_lapack_result()?; + $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); } + info.as_lapack_result()?; Ok(()) } } }; } // end macro_rules -impl_cholesky!(f64, lapacke::dpotrf, lapacke::dpotri, lapacke::dpotrs); -impl_cholesky!(f32, lapacke::spotrf, lapacke::spotri, lapacke::spotrs); -impl_cholesky!(c64, lapacke::zpotrf, lapacke::zpotri, lapacke::zpotrs); -impl_cholesky!(c32, lapacke::cpotrf, lapacke::cpotri, lapacke::cpotrs); +impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs); +impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs); +impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs); +impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index be88410e..bbbdd85b 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -126,6 +126,15 @@ pub enum UPLO { Lower = b'L', } +impl UPLO { + pub fn t(self) -> Self { + match self { + UPLO::Upper => UPLO::Lower, + UPLO::Lower => UPLO::Upper, + } + } +} + #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum Transpose { From 5f6a2c39b16cca3295fe12d2025f22066f51f7e6 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 18 Jul 2020 03:03:21 +0900 Subject: [PATCH 27/49] Rewrite opnorm by LAPACK --- lax/src/opnorm.rs | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index 4786fd6e..c76d9f59 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -2,7 +2,7 @@ use crate::layout::MatrixLayout; use cauchy::*; -use lapacke::Layout::ColumnMajor as cm; +use num_traits::Zero; pub use super::NormType; @@ -14,18 +14,24 @@ macro_rules! impl_opnorm { ($scalar:ty, $lange:path) => { impl OperatorNorm_ for $scalar { unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { - match l { - MatrixLayout::F { col, lda } => $lange(cm, t as u8, lda, col, a, lda), - MatrixLayout::C { row, lda } => { - $lange(cm, t.transpose() as u8, lda, row, a, lda) - } - } + let m = l.lda(); + let n = l.len(); + let t = match l { + MatrixLayout::F { .. } => t, + MatrixLayout::C { .. } => t.transpose(), + }; + let mut work = if matches!(t, NormType::Infinity) { + vec![Self::Real::zero(); m as usize] + } else { + Vec::new() + }; + $lange(t as u8, m, n, a, m, &mut work) } } }; } // impl_opnorm! -impl_opnorm!(f64, lapacke::dlange); -impl_opnorm!(f32, lapacke::slange); -impl_opnorm!(c64, lapacke::zlange); -impl_opnorm!(c32, lapacke::clange); +impl_opnorm!(f64, lapack::dlange); +impl_opnorm!(f32, lapack::slange); +impl_opnorm!(c64, lapack::zlange); +impl_opnorm!(c32, lapack::clange); From 06a6c658c7ccb17d779798d0a625ced47b20f00d Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 18 Jul 2020 16:20:23 +0900 Subject: [PATCH 28/49] square_transpose --- lax/src/cholesky.rs | 31 +++++++++++++++++-------------- lax/src/layout.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 93c5147b..618218f1 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -1,7 +1,7 @@ //! Cholesky decomposition use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; pub trait Cholesky_: Sized { @@ -24,45 +24,48 @@ macro_rules! impl_cholesky { impl Cholesky_ for $scalar { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } let mut info = 0; - let uplo = match l { - MatrixLayout::F { .. } => uplo, - MatrixLayout::C { .. } => uplo.t(), - }; unsafe { $trf(uplo as u8, n, a, n, &mut info); } info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } Ok(()) } fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } let mut info = 0; - let uplo = match l { - MatrixLayout::F { .. } => uplo, - MatrixLayout::C { .. } => uplo.t(), - }; unsafe { $tri(uplo as u8, n, a, l.lda(), &mut info); } info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } Ok(()) } fn solve_cholesky( l: MatrixLayout, - uplo: UPLO, + mut uplo: UPLO, a: &[Self], b: &mut [Self], ) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; - let uplo = match l { - MatrixLayout::F { .. } => uplo, - MatrixLayout::C { .. } => uplo.t(), - }; let mut info = 0; + if matches!(l, MatrixLayout::C { .. }) { + uplo = uplo.t(); + } unsafe { $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); } diff --git a/lax/src/layout.rs b/lax/src/layout.rs index aa9fe110..9dad70e6 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -37,6 +37,8 @@ //! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`. //! +use cauchy::Scalar; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MatrixLayout { C { row: i32, lda: i32 }, @@ -96,3 +98,44 @@ impl MatrixLayout { } } } + +/// In-place transpose of a square matrix by keeping F/C layout +/// +/// Transpose for C-continuous array +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::C { row: 2, lda: 2 }; +/// let mut a = vec![1., 2., 3., 4.]; +/// square_transpose(layout, &mut a); +/// assert_eq!(a, &[1., 3., 2., 4.]); +/// ``` +/// +/// Transpose for F-continuous array +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::F { col: 2, lda: 2 }; +/// let mut a = vec![1., 3., 2., 4.]; +/// square_transpose(layout, &mut a); +/// assert_eq!(a, &[1., 2., 3., 4.]); +/// ``` +/// +/// Panics +/// ------ +/// - If size of `a` and `layout` size mismatch +/// +pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { + let (m, n) = layout.size(); + let n = n as usize; + let m = m as usize; + assert_eq!(a.len(), n * m); + for i in 0..m { + for j in (i + 1)..n { + let a_ij = a[i * n + j]; + let a_ji = a[j * m + i]; + a[i * n + j] = a_ji.conj(); + a[j * m + i] = a_ij.conj(); + } + } +} From 76c06d4491ddfe60dbd18d5af24d9df1dde5c1ee Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 18 Jul 2020 16:55:57 +0900 Subject: [PATCH 29/49] Split tests of cholesky --- ndarray-linalg/tests/cholesky.rs | 390 ++++++++++++++++--------------- 1 file changed, 201 insertions(+), 189 deletions(-) diff --git a/ndarray-linalg/tests/cholesky.rs b/ndarray-linalg/tests/cholesky.rs index accdbdf8..b45afb5c 100644 --- a/ndarray-linalg/tests/cholesky.rs +++ b/ndarray-linalg/tests/cholesky.rs @@ -1,216 +1,228 @@ use ndarray::*; use ndarray_linalg::*; -#[test] -fn cholesky() { - macro_rules! cholesky { - ($elem:ty, $rtol:expr) => { - let a_orig: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a_orig); +macro_rules! cholesky { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a_orig: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a_orig); - let upper = a_orig.cholesky(UPLO::Upper).unwrap(); - assert_close_l2!( - &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); - - let lower = a_orig.cholesky(UPLO::Lower).unwrap(); - assert_close_l2!( - &lower.dot(&lower.t().mapv(|elem| elem.conj())), - &a_orig, - $rtol - ); - - let a: Array2<$elem> = replicate(&a_orig); - let upper = a.cholesky_into(UPLO::Upper).unwrap(); - assert_close_l2!( - &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); + let upper = a_orig.cholesky(UPLO::Upper).unwrap(); + assert_close_l2!( + &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); - let a: Array2<$elem> = replicate(&a_orig); - let lower = a.cholesky_into(UPLO::Lower).unwrap(); - assert_close_l2!( - &lower.dot(&lower.t().mapv(|elem| elem.conj())), - &a_orig, - $rtol - ); + let lower = a_orig.cholesky(UPLO::Lower).unwrap(); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); - let mut a: Array2<$elem> = replicate(&a_orig); - { - let upper = a.cholesky_inplace(UPLO::Upper).unwrap(); + let a: Array2<$elem> = replicate(&a_orig); + let upper = a.cholesky_into(UPLO::Upper).unwrap(); assert_close_l2!( &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), &a_orig, $rtol ); - } - assert_close_l2!( - &a.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); - let mut a: Array2<$elem> = replicate(&a_orig); - { - let lower = a.cholesky_inplace(UPLO::Lower).unwrap(); + let a: Array2<$elem> = replicate(&a_orig); + let lower = a.cholesky_into(UPLO::Lower).unwrap(); assert_close_l2!( &lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol ); + + let mut a: Array2<$elem> = replicate(&a_orig); + { + let upper = a.cholesky_inplace(UPLO::Upper).unwrap(); + assert_close_l2!( + &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); + } + assert_close_l2!( + &a.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); + + let mut a: Array2<$elem> = replicate(&a_orig); + { + let lower = a.cholesky_inplace(UPLO::Lower).unwrap(); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); + } + assert_close_l2!(&a.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); } - assert_close_l2!(&a.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); - }; - } - cholesky!(f64, 1e-9); - cholesky!(f32, 1e-5); - cholesky!(c64, 1e-9); - cholesky!(c32, 1e-5); + } // paste::item + }; } -#[test] -fn cholesky_into_lower_upper() { - macro_rules! cholesky_into_lower_upper { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let upper = a.cholesky(UPLO::Upper).unwrap(); - let fac_upper = a.factorizec(UPLO::Upper).unwrap(); - let fac_lower = a.factorizec(UPLO::Lower).unwrap(); - assert_close_l2!(&upper, &fac_lower.into_upper(), $rtol); - assert_close_l2!(&upper, &fac_upper.into_upper(), $rtol); - let lower = a.cholesky(UPLO::Lower).unwrap(); - let fac_upper = a.factorizec(UPLO::Upper).unwrap(); - let fac_lower = a.factorizec(UPLO::Lower).unwrap(); - assert_close_l2!(&lower, &fac_lower.into_lower(), $rtol); - assert_close_l2!(&lower, &fac_upper.into_lower(), $rtol); - }; - } - cholesky_into_lower_upper!(f64, 1e-9); - cholesky_into_lower_upper!(f32, 1e-5); - cholesky_into_lower_upper!(c64, 1e-9); - cholesky_into_lower_upper!(c32, 1e-5); -} +cholesky!(f64, 1e-9); +cholesky!(f32, 1e-5); +cholesky!(c64, 1e-9); +cholesky!(c32, 1e-5); -#[test] -fn cholesky_inverse() { - macro_rules! cholesky_into_inverse { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let inv = a.invc().unwrap(); - assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol); - let inv_into: Array2<$elem> = replicate(&a).invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_into), &Array2::eye(3), $rtol); - let inv_upper = a.factorizec(UPLO::Upper).unwrap().invc().unwrap(); - assert_close_l2!(&a.dot(&inv_upper), &Array2::eye(3), $rtol); - let inv_upper_into = a.factorizec(UPLO::Upper).unwrap().invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_upper_into), &Array2::eye(3), $rtol); - let inv_lower = a.factorizec(UPLO::Lower).unwrap().invc().unwrap(); - assert_close_l2!(&a.dot(&inv_lower), &Array2::eye(3), $rtol); - let inv_lower_into = a.factorizec(UPLO::Lower).unwrap().invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_lower_into), &Array2::eye(3), $rtol); - }; - } - cholesky_into_inverse!(f64, 1e-9); - cholesky_into_inverse!(f32, 1e-3); - cholesky_into_inverse!(c64, 1e-9); - cholesky_into_inverse!(c32, 1e-3); +macro_rules! cholesky_into_lower_upper { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let upper = a.cholesky(UPLO::Upper).unwrap(); + let fac_upper = a.factorizec(UPLO::Upper).unwrap(); + let fac_lower = a.factorizec(UPLO::Lower).unwrap(); + assert_close_l2!(&upper, &fac_lower.into_upper(), $rtol); + assert_close_l2!(&upper, &fac_upper.into_upper(), $rtol); + let lower = a.cholesky(UPLO::Lower).unwrap(); + let fac_upper = a.factorizec(UPLO::Upper).unwrap(); + let fac_lower = a.factorizec(UPLO::Lower).unwrap(); + assert_close_l2!(&lower, &fac_lower.into_lower(), $rtol); + assert_close_l2!(&lower, &fac_upper.into_lower(), $rtol); + } + } + }; } -#[test] -fn cholesky_det() { - macro_rules! cholesky_det { - ($elem:ty, $atol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let ln_det = a - .eigvalsh(UPLO::Upper) - .unwrap() - .mapv(|elem| elem.ln()) - .scalar_sum(); - let det = ln_det.exp(); - assert_aclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $atol); - assert_aclose!(a.factorizec(UPLO::Upper).unwrap().ln_detc(), ln_det, $atol); - assert_aclose!(a.factorizec(UPLO::Lower).unwrap().detc_into(), det, $atol); - assert_aclose!( - a.factorizec(UPLO::Lower).unwrap().ln_detc_into(), - ln_det, - $atol - ); - assert_aclose!(a.detc().unwrap(), det, $atol); - assert_aclose!(a.ln_detc().unwrap(), ln_det, $atol); - assert_aclose!(a.clone().detc_into().unwrap(), det, $atol); - assert_aclose!(a.ln_detc_into().unwrap(), ln_det, $atol); - }; - } - cholesky_det!(f64, 1e-9); - cholesky_det!(f32, 1e-3); - cholesky_det!(c64, 1e-9); - cholesky_det!(c32, 1e-3); +cholesky_into_lower_upper!(f64, 1e-9); +cholesky_into_lower_upper!(f32, 1e-5); +cholesky_into_lower_upper!(c64, 1e-9); +cholesky_into_lower_upper!(c32, 1e-5); + +macro_rules! cholesky_into_inverse { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let inv = a.invc().unwrap(); + assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol); + let inv_into: Array2<$elem> = replicate(&a).invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_into), &Array2::eye(3), $rtol); + let inv_upper = a.factorizec(UPLO::Upper).unwrap().invc().unwrap(); + assert_close_l2!(&a.dot(&inv_upper), &Array2::eye(3), $rtol); + let inv_upper_into = a.factorizec(UPLO::Upper).unwrap().invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_upper_into), &Array2::eye(3), $rtol); + let inv_lower = a.factorizec(UPLO::Lower).unwrap().invc().unwrap(); + assert_close_l2!(&a.dot(&inv_lower), &Array2::eye(3), $rtol); + let inv_lower_into = a.factorizec(UPLO::Lower).unwrap().invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_lower_into), &Array2::eye(3), $rtol); + } + } + }; } +cholesky_into_inverse!(f64, 1e-9); +cholesky_into_inverse!(f32, 1e-3); +cholesky_into_inverse!(c64, 1e-9); +cholesky_into_inverse!(c32, 1e-3); -#[test] -fn cholesky_solve() { - macro_rules! cholesky_solve { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - let x: Array1<$elem> = random(3); - let b = a.dot(&x); - println!("a = \n{:?}", a); - println!("x = \n{:?}", x); - assert_close_l2!(&a.solvec(&b).unwrap(), &x, $rtol); - assert_close_l2!(&a.solvec_into(b.clone()).unwrap(), &x, $rtol); - assert_close_l2!(&a.solvec_inplace(&mut b.clone()).unwrap(), &x, $rtol); - assert_close_l2!( - &a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Upper) - .unwrap() - .solvec_into(b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower) - .unwrap() - .solvec_into(b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Upper) +macro_rules! cholesky_det { + ($elem:ty, $atol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let ln_det = a + .eigvalsh(UPLO::Upper) .unwrap() - .solvec_inplace(&mut b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower) - .unwrap() - .solvec_inplace(&mut b.clone()) - .unwrap(), - &x, - $rtol - ); - }; - } - cholesky_solve!(f64, 1e-9); - cholesky_solve!(f32, 1e-3); - cholesky_solve!(c64, 1e-9); - cholesky_solve!(c32, 1e-3); + .mapv(|elem| elem.ln()) + .scalar_sum(); + let det = ln_det.exp(); + assert_aclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $atol); + assert_aclose!(a.factorizec(UPLO::Upper).unwrap().ln_detc(), ln_det, $atol); + assert_aclose!(a.factorizec(UPLO::Lower).unwrap().detc_into(), det, $atol); + assert_aclose!( + a.factorizec(UPLO::Lower).unwrap().ln_detc_into(), + ln_det, + $atol + ); + assert_aclose!(a.detc().unwrap(), det, $atol); + assert_aclose!(a.ln_detc().unwrap(), ln_det, $atol); + assert_aclose!(a.clone().detc_into().unwrap(), det, $atol); + assert_aclose!(a.ln_detc_into().unwrap(), ln_det, $atol); + } + } + }; +} +cholesky_det!(f64, 1e-9); +cholesky_det!(f32, 1e-3); +cholesky_det!(c64, 1e-9); +cholesky_det!(c32, 1e-3); + +macro_rules! cholesky_solve { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + let x: Array1<$elem> = random(3); + let b = a.dot(&x); + println!("a = \n{:?}", a); + println!("x = \n{:?}", x); + assert_close_l2!(&a.solvec(&b).unwrap(), &x, $rtol); + assert_close_l2!(&a.solvec_into(b.clone()).unwrap(), &x, $rtol); + assert_close_l2!(&a.solvec_inplace(&mut b.clone()).unwrap(), &x, $rtol); + assert_close_l2!( + &a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Upper) + .unwrap() + .solvec_into(b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower) + .unwrap() + .solvec_into(b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Upper) + .unwrap() + .solvec_inplace(&mut b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower) + .unwrap() + .solvec_inplace(&mut b.clone()) + .unwrap(), + &x, + $rtol + ); + } + } + }; } +cholesky_solve!(f64, 1e-9); +cholesky_solve!(f32, 1e-3); +cholesky_solve!(c64, 1e-9); +cholesky_solve!(c32, 1e-3); From 44b2adb89b3f51de01f2ea270ba057669289dde2 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 18 Jul 2020 17:44:25 +0900 Subject: [PATCH 30/49] Take complex conjugate --- lax/src/cholesky.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 618218f1..8305efe5 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -65,11 +65,19 @@ macro_rules! impl_cholesky { let mut info = 0; if matches!(l, MatrixLayout::C { .. }) { uplo = uplo.t(); + for val in b.iter_mut() { + *val = val.conj(); + } } unsafe { $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); } info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + for val in b.iter_mut() { + *val = val.conj(); + } + } Ok(()) } } From d7cf5037100c477474e97fc132b9f60fba2c3281 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 15 Jul 2020 17:49:33 +0900 Subject: [PATCH 31/49] Split impl for real based on LAPACK --- lax/src/least_squares.rs | 167 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 164 insertions(+), 3 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 69553a44..269fa952 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -2,7 +2,7 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; /// Result of LeastSquares pub struct LeastSquaresOutput { @@ -28,6 +28,169 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar { ) -> Result>; } +/// Eval iwork size +/// +/// - NLVL = INT( LOG_2( MIN( M, N ) / ( SMLSIZ + 1 ) ) ) + 1 +/// - LIWORK = 3 * MIN( M, N ) * NLVL + 11 * MIN( M, N ) +/// +/// where SMLSIZ is returned by ILAENV and is equal to the maximum size of the subproblems +/// at the bottom of the computation tree (usually about 25). +/// +/// We put SMLSIZ=0 to estimate NLVL as large as possible +/// because its size will usually very small. +fn iwork_size(m: i32, n: i32) -> usize { + let mn = m.min(n) as usize; + let nlvl = (mn.to_f32().unwrap().log2() + 1.0) + .max(0.0) + .to_usize() + .unwrap(); + (3 * mn * nlvl + 11 * mn).max(1) +} + +macro_rules! impl_least_squares_real { + ($scalar:ty, $gelsd:path) => { + impl LeastSquaresSvdDivideConquer_ for $scalar { + unsafe fn least_squares( + l: MatrixLayout, + a: &mut [Self], + b: &mut [Self], + ) -> Result> { + let m = l.lda(); + let n = l.len(); + let k = m.min(n); + if (m as usize) > b.len() || (n as usize) > b.len() { + return Err(Error::InvalidShape); + } + let rcond: Self::Real = -1.; + let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; + let mut rank: i32 = 0; + + let mut iwork = vec![0; iwork_size(m, n)]; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + $gelsd( + m, + n, + 1, // nrhs + a, + m, + b, + b.len() as i32, + &mut singular_values, + rcond, + &mut rank, + &mut work_size, + -1, + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $gelsd( + m, + n, + 1, // nrhs + a, + m, + b, + b.len() as i32, + &mut singular_values, + rcond, + &mut rank, + &mut work, + lwork as i32, + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + Ok(LeastSquaresOutput { + singular_values, + rank, + }) + } + + unsafe fn least_squares_nrhs( + a_layout: MatrixLayout, + a: &mut [Self], + b_layout: MatrixLayout, + b: &mut [Self], + ) -> Result> { + let m = a_layout.lda(); + let n = a_layout.len(); + let k = m.min(n); + if (m as usize) > b.len() + || (n as usize) > b.len() + || a_layout.lapacke_layout() != b_layout.lapacke_layout() + { + return Err(Error::InvalidShape); + } + let (b_lda, nrhs) = b_layout.size(); + let rcond: Self::Real = -1.; + let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; + let mut rank: i32 = 0; + + let mut iwork = vec![0; iwork_size(m, n)]; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + $gelsd( + m, + n, + nrhs, + a, + m, + b, + b_lda, + &mut singular_values, + rcond, + &mut rank, + &mut work_size, + -1, + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $gelsd( + m, + n, + nrhs, + a, + m, + b, + b_lda, + &mut singular_values, + rcond, + &mut rank, + &mut work, + lwork as i32, + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + Ok(LeastSquaresOutput { + singular_values, + rank, + }) + } + } + }; +} + +impl_least_squares_real!(f64, lapack::dgelsd); +impl_least_squares_real!(f32, lapack::sgelsd); + macro_rules! impl_least_squares { ($scalar:ty, $gelsd:path) => { impl LeastSquaresSvdDivideConquer_ for $scalar { @@ -113,7 +276,5 @@ macro_rules! impl_least_squares { }; } -impl_least_squares!(f64, lapacke::dgelsd); -impl_least_squares!(f32, lapacke::sgelsd); impl_least_squares!(c64, lapacke::zgelsd); impl_least_squares!(c32, lapacke::cgelsd); From 45c717031d18115eb504c22eb2bf78d4dcadb068 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 15 Jul 2020 19:54:25 +0900 Subject: [PATCH 32/49] Query liwork --- lax/src/least_squares.rs | 33 ++++++++------------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 269fa952..102ac261 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -28,25 +28,6 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar { ) -> Result>; } -/// Eval iwork size -/// -/// - NLVL = INT( LOG_2( MIN( M, N ) / ( SMLSIZ + 1 ) ) ) + 1 -/// - LIWORK = 3 * MIN( M, N ) * NLVL + 11 * MIN( M, N ) -/// -/// where SMLSIZ is returned by ILAENV and is equal to the maximum size of the subproblems -/// at the bottom of the computation tree (usually about 25). -/// -/// We put SMLSIZ=0 to estimate NLVL as large as possible -/// because its size will usually very small. -fn iwork_size(m: i32, n: i32) -> usize { - let mn = m.min(n) as usize; - let nlvl = (mn.to_f32().unwrap().log2() + 1.0) - .max(0.0) - .to_usize() - .unwrap(); - (3 * mn * nlvl + 11 * mn).max(1) -} - macro_rules! impl_least_squares_real { ($scalar:ty, $gelsd:path) => { impl LeastSquaresSvdDivideConquer_ for $scalar { @@ -65,11 +46,10 @@ macro_rules! impl_least_squares_real { let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; let mut rank: i32 = 0; - let mut iwork = vec![0; iwork_size(m, n)]; - // eval work size let mut info = 0; let mut work_size = [Self::zero()]; + let mut iwork_size = [0]; $gelsd( m, n, @@ -83,7 +63,7 @@ macro_rules! impl_least_squares_real { &mut rank, &mut work_size, -1, - &mut iwork, + &mut iwork_size, &mut info, ); info.as_lapack_result()?; @@ -91,6 +71,8 @@ macro_rules! impl_least_squares_real { // calc let lwork = work_size[0].to_usize().unwrap(); let mut work = vec![Self::zero(); lwork]; + let liwork = iwork_size[0].to_usize().unwrap(); + let mut iwork = vec![0; liwork]; $gelsd( m, n, @@ -135,11 +117,10 @@ macro_rules! impl_least_squares_real { let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; let mut rank: i32 = 0; - let mut iwork = vec![0; iwork_size(m, n)]; - // eval work size let mut info = 0; let mut work_size = [Self::zero()]; + let mut iwork_size = [0]; $gelsd( m, n, @@ -153,7 +134,7 @@ macro_rules! impl_least_squares_real { &mut rank, &mut work_size, -1, - &mut iwork, + &mut iwork_size, &mut info, ); info.as_lapack_result()?; @@ -161,6 +142,8 @@ macro_rules! impl_least_squares_real { // calc let lwork = work_size[0].to_usize().unwrap(); let mut work = vec![Self::zero(); lwork]; + let liwork = iwork_size[0].to_usize().unwrap(); + let mut iwork = vec![0; liwork]; $gelsd( m, n, From 6df2779080f70755ff51938e720ba6adb0c3040a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 15 Jul 2020 21:11:35 +0900 Subject: [PATCH 33/49] Use least_squares_nrhs for 1-dim b case --- lax/src/least_squares.rs | 61 ++-------------------------------------- 1 file changed, 2 insertions(+), 59 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 102ac261..468a759b 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -36,65 +36,8 @@ macro_rules! impl_least_squares_real { a: &mut [Self], b: &mut [Self], ) -> Result> { - let m = l.lda(); - let n = l.len(); - let k = m.min(n); - if (m as usize) > b.len() || (n as usize) > b.len() { - return Err(Error::InvalidShape); - } - let rcond: Self::Real = -1.; - let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; - let mut rank: i32 = 0; - - // eval work size - let mut info = 0; - let mut work_size = [Self::zero()]; - let mut iwork_size = [0]; - $gelsd( - m, - n, - 1, // nrhs - a, - m, - b, - b.len() as i32, - &mut singular_values, - rcond, - &mut rank, - &mut work_size, - -1, - &mut iwork_size, - &mut info, - ); - info.as_lapack_result()?; - - // calc - let lwork = work_size[0].to_usize().unwrap(); - let mut work = vec![Self::zero(); lwork]; - let liwork = iwork_size[0].to_usize().unwrap(); - let mut iwork = vec![0; liwork]; - $gelsd( - m, - n, - 1, // nrhs - a, - m, - b, - b.len() as i32, - &mut singular_values, - rcond, - &mut rank, - &mut work, - lwork as i32, - &mut iwork, - &mut info, - ); - info.as_lapack_result()?; - - Ok(LeastSquaresOutput { - singular_values, - rank, - }) + let b_layout = l.resized(b.len() as i32, 1); + Self::least_squares_nrhs(l, a, b_layout, b) } unsafe fn least_squares_nrhs( From f417f6f15e19e8c09abecc9a6ac88ba77ea222ca Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 24 Jul 2020 21:48:12 +0900 Subject: [PATCH 34/49] Fix test scalar types --- ndarray-linalg/tests/least_squares.rs | 12 +++--- ndarray-linalg/tests/least_squares_nrhs.rs | 48 +++++++++++----------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/ndarray-linalg/tests/least_squares.rs b/ndarray-linalg/tests/least_squares.rs index c388c9d7..e2df3370 100644 --- a/ndarray-linalg/tests/least_squares.rs +++ b/ndarray-linalg/tests/least_squares.rs @@ -27,13 +27,13 @@ macro_rules! impl_exact { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 3)); + let a: Array2<$scalar> = random((3, 3)); test_exact(a) } #[test] fn []() { - let a: Array2 = random((3, 3).f()); + let a: Array2<$scalar> = random((3, 3).f()); test_exact(a) } } @@ -73,13 +73,13 @@ macro_rules! impl_overdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((4, 3)); + let a: Array2<$scalar> = random((4, 3)); test_overdetermined(a) } #[test] fn []() { - let a: Array2 = random((4, 3).f()); + let a: Array2<$scalar> = random((4, 3).f()); test_overdetermined(a) } } @@ -110,13 +110,13 @@ macro_rules! impl_underdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 4)); + let a: Array2<$scalar> = random((3, 4)); test_underdetermined(a) } #[test] fn []() { - let a: Array2 = random((3, 4).f()); + let a: Array2<$scalar> = random((3, 4).f()); test_underdetermined(a) } } diff --git a/ndarray-linalg/tests/least_squares_nrhs.rs b/ndarray-linalg/tests/least_squares_nrhs.rs index 4c964697..8072c409 100644 --- a/ndarray-linalg/tests/least_squares_nrhs.rs +++ b/ndarray-linalg/tests/least_squares_nrhs.rs @@ -31,8 +31,8 @@ macro_rules! impl_exact { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 3)); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 3)); + let b: Array2<$scalar> = random((3, 2)); test_exact(a, b) } @@ -40,15 +40,15 @@ macro_rules! impl_exact { #[test] fn []() { - let a: Array2 = random((3, 3)); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 3)); + let b: Array2<$scalar> = random((3, 2).f()); test_exact(a, b) } #[test] fn []() { - let a: Array2 = random((3, 3).f()); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 3).f()); + let b: Array2<$scalar> = random((3, 2)); test_exact(a, b) } @@ -56,8 +56,8 @@ macro_rules! impl_exact { #[test] fn []() { - let a: Array2 = random((3, 3).f()); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 3).f()); + let b: Array2<$scalar> = random((3, 2).f()); test_exact(a, b) } } @@ -103,8 +103,8 @@ macro_rules! impl_overdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((4, 3)); - let b: Array2 = random((4, 2)); + let a: Array2<$scalar> = random((4, 3)); + let b: Array2<$scalar> = random((4, 2)); test_overdetermined(a, b) } @@ -112,15 +112,15 @@ macro_rules! impl_overdetermined { #[test] fn []() { - let a: Array2 = random((4, 3).f()); - let b: Array2 = random((4, 2)); + let a: Array2<$scalar> = random((4, 3).f()); + let b: Array2<$scalar> = random((4, 2)); test_overdetermined(a, b) } #[test] fn []() { - let a: Array2 = random((4, 3)); - let b: Array2 = random((4, 2).f()); + let a: Array2<$scalar> = random((4, 3)); + let b: Array2<$scalar> = random((4, 2).f()); test_overdetermined(a, b) } @@ -128,8 +128,8 @@ macro_rules! impl_overdetermined { #[test] fn []() { - let a: Array2 = random((4, 3).f()); - let b: Array2 = random((4, 2).f()); + let a: Array2<$scalar> = random((4, 3).f()); + let b: Array2<$scalar> = random((4, 2).f()); test_overdetermined(a, b) } } @@ -162,8 +162,8 @@ macro_rules! impl_underdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 4)); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 4)); + let b: Array2<$scalar> = random((3, 2)); test_underdetermined(a, b) } @@ -171,15 +171,15 @@ macro_rules! impl_underdetermined { #[test] fn []() { - let a: Array2 = random((3, 4).f()); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 4).f()); + let b: Array2<$scalar> = random((3, 2)); test_underdetermined(a, b) } #[test] fn []() { - let a: Array2 = random((3, 4)); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 4)); + let b: Array2<$scalar> = random((3, 2).f()); test_underdetermined(a, b) } @@ -187,8 +187,8 @@ macro_rules! impl_underdetermined { #[test] fn []() { - let a: Array2 = random((3, 4).f()); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 4).f()); + let b: Array2<$scalar> = random((3, 2).f()); test_underdetermined(a, b) } } From 8ff555d2ca17fcccae3af15e6b9444547996f0dc Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 25 Jul 2020 02:29:21 +0900 Subject: [PATCH 35/49] Out-place transpose --- lax/src/layout.rs | 87 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/lax/src/layout.rs b/lax/src/layout.rs index 9dad70e6..43b7ee87 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -97,6 +97,37 @@ impl MatrixLayout { MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col }, } } + + /// Transpose without changing memory representation + /// + /// C-contigious row=2, lda=3 + /// + /// ```text + /// [[1, 2, 3] + /// [4, 5, 6]] + /// ``` + /// + /// and F-contigious col=2, lda=3 + /// + /// ```text + /// [[1, 4] + /// [2, 5] + /// [3, 6]] + /// ``` + /// + /// have same memory representation `[1, 2, 3, 4, 5, 6]`, and this toggles them. + /// + /// ``` + /// # use lax::layout::*; + /// let layout = MatrixLayout::C { row: 2, lda: 3 }; + /// assert_eq!(layout.t(), MatrixLayout::F { col: 2, lda: 3 }); + /// ``` + pub fn t(&self) -> Self { + match *self { + MatrixLayout::C { row, lda } => MatrixLayout::F { col: row, lda }, + MatrixLayout::F { col, lda } => MatrixLayout::C { row: col, lda }, + } + } } /// In-place transpose of a square matrix by keeping F/C layout @@ -139,3 +170,59 @@ pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { } } } + +/// Out-place transpose for general matrix +/// +/// Inplace transpose of non-square matrices is hard. +/// See also: https://en.wikipedia.org/wiki/In-place_matrix_transposition +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::C { row: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let mut b = vec![0.0; a.len()]; +/// let l = transpose(layout, &a, &mut b); +/// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::F { col: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let mut b = vec![0.0; a.len()]; +/// let l = transpose(layout, &a, &mut b); +/// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// Panics +/// ------ +/// - If size of `a` and `layout` size mismatch +/// +pub fn transpose(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout { + let (m, n) = layout.size(); + let transposed = layout.resized(n, m).t(); + let m = m as usize; + let n = n as usize; + assert_eq!(from.len(), m * n); + assert_eq!(to.len(), m * n); + + match layout { + MatrixLayout::C { .. } => { + for i in 0..m { + for j in 0..n { + to[j * m + i] = from[i * n + j]; + } + } + } + MatrixLayout::F { .. } => { + for i in 0..m { + for j in 0..n { + to[i * n + j] = from[j * m + i]; + } + } + } + } + transposed +} From 200b9d435a3a0d1adb562e7956841cbf7a0b9689 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 25 Jul 2020 17:07:06 +0900 Subject: [PATCH 36/49] Take transpose --- lax/src/least_squares.rs | 67 +++++++++++++++------- ndarray-linalg/src/least_squares.rs | 1 + ndarray-linalg/tests/least_squares_nrhs.rs | 1 + 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 468a759b..64349051 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -1,6 +1,6 @@ //! Least squares -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; @@ -46,16 +46,37 @@ macro_rules! impl_least_squares_real { b_layout: MatrixLayout, b: &mut [Self], ) -> Result> { - let m = a_layout.lda(); - let n = a_layout.len(); + // Minimize |b - Ax|_2 + // + // where + // A : (m, n) + // b : (m, p) + // x : (n, p) + let (m, n) = a_layout.size(); + let (m_, p) = b_layout.size(); let k = m.min(n); - if (m as usize) > b.len() - || (n as usize) > b.len() - || a_layout.lapacke_layout() != b_layout.lapacke_layout() - { - return Err(Error::InvalidShape); - } - let (b_lda, nrhs) = b_layout.size(); + assert_eq!(m, m_); + + // Transpose if a is C-continuous + let mut a_t = None; + let a_layout = match a_layout { + MatrixLayout::C { .. } => { + a_t = Some(vec![Self::zero(); a.len()]); + transpose(a_layout, a, a_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => a_layout, + }; + + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => b_layout, + }; + let rcond: Self::Real = -1.; let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; let mut rank: i32 = 0; @@ -67,11 +88,11 @@ macro_rules! impl_least_squares_real { $gelsd( m, n, - nrhs, - a, - m, - b, - b_lda, + p, + a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), &mut singular_values, rcond, &mut rank, @@ -90,11 +111,11 @@ macro_rules! impl_least_squares_real { $gelsd( m, n, - nrhs, - a, - m, - b, - b_lda, + p, + a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), &mut singular_values, rcond, &mut rank, @@ -105,6 +126,12 @@ macro_rules! impl_least_squares_real { ); info.as_lapack_result()?; + // Skip a_t -> a transpose because A has been destroyed + // Re-transpose b + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } + Ok(LeastSquaresOutput { singular_values, rank, diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index 18d2033f..1a6e3a47 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -76,6 +76,7 @@ use crate::types::*; /// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix /// (which can be seen as solving `Ax = b` k times for different b) and /// the solution is a `m x k` matrix. +#[derive(Debug, Clone)] pub struct LeastSquaresResult { /// The singular values of the matrix A in `Ax = b` pub singular_values: Array1, diff --git a/ndarray-linalg/tests/least_squares_nrhs.rs b/ndarray-linalg/tests/least_squares_nrhs.rs index 8072c409..95737873 100644 --- a/ndarray-linalg/tests/least_squares_nrhs.rs +++ b/ndarray-linalg/tests/least_squares_nrhs.rs @@ -9,6 +9,7 @@ fn test_exact(a: Array2, b: Array2) { assert_eq!(b.layout().unwrap().size(), (3, 2)); let result = a.least_squares(&b).unwrap(); + dbg!(&result); // unpack result let x: Array2 = result.solution; let residual_l2_square: Array1 = result.residual_sum_of_squares.unwrap(); From 87f0cea7ea4e7f0bdbbd8cd20bb25e2ce73f12f4 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 25 Jul 2020 17:15:49 +0900 Subject: [PATCH 37/49] Support complex case --- lax/src/least_squares.rs | 114 ++++++++------------------------------- 1 file changed, 22 insertions(+), 92 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 64349051..2c963158 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -28,8 +28,15 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar { ) -> Result>; } -macro_rules! impl_least_squares_real { - ($scalar:ty, $gelsd:path) => { +macro_rules! impl_least_squares { + (@real, $scalar:ty, $gelsd:path) => { + impl_least_squares!(@body, $scalar, $gelsd, ); + }; + (@complex, $scalar:ty, $gelsd:path) => { + impl_least_squares!(@body, $scalar, $gelsd, rwork); + }; + + (@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => { impl LeastSquaresSvdDivideConquer_ for $scalar { unsafe fn least_squares( l: MatrixLayout, @@ -85,6 +92,9 @@ macro_rules! impl_least_squares_real { let mut info = 0; let mut work_size = [Self::zero()]; let mut iwork_size = [0]; + $( + let mut $rwork = [Self::Real::zero()]; + )* $gelsd( m, n, @@ -98,6 +108,7 @@ macro_rules! impl_least_squares_real { &mut rank, &mut work_size, -1, + $(&mut $rwork,)* &mut iwork_size, &mut info, ); @@ -108,6 +119,10 @@ macro_rules! impl_least_squares_real { let mut work = vec![Self::zero(); lwork]; let liwork = iwork_size[0].to_usize().unwrap(); let mut iwork = vec![0; liwork]; + $( + let lrwork = $rwork[0].to_usize().unwrap(); + let mut $rwork = vec![Self::Real::zero(); lrwork]; + )* $gelsd( m, n, @@ -121,6 +136,7 @@ macro_rules! impl_least_squares_real { &mut rank, &mut work, lwork as i32, + $(&mut $rwork,)* &mut iwork, &mut info, ); @@ -141,93 +157,7 @@ macro_rules! impl_least_squares_real { }; } -impl_least_squares_real!(f64, lapack::dgelsd); -impl_least_squares_real!(f32, lapack::sgelsd); - -macro_rules! impl_least_squares { - ($scalar:ty, $gelsd:path) => { - impl LeastSquaresSvdDivideConquer_ for $scalar { - unsafe fn least_squares( - a_layout: MatrixLayout, - a: &mut [Self], - b: &mut [Self], - ) -> Result> { - let (m, n) = a_layout.size(); - if (m as usize) > b.len() || (n as usize) > b.len() { - return Err(Error::InvalidShape); - } - let k = ::std::cmp::min(m, n); - let nrhs = 1; - let ldb = match a_layout { - MatrixLayout::F { .. } => m.max(n), - MatrixLayout::C { .. } => 1, - }; - let rcond: Self::Real = -1.; - let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; - let mut rank: i32 = 0; - - $gelsd( - a_layout.lapacke_layout(), - m, - n, - nrhs, - a, - a_layout.lda(), - b, - ldb, - &mut singular_values, - rcond, - &mut rank, - ) - .as_lapack_result()?; - - Ok(LeastSquaresOutput { - singular_values, - rank, - }) - } - - unsafe fn least_squares_nrhs( - a_layout: MatrixLayout, - a: &mut [Self], - b_layout: MatrixLayout, - b: &mut [Self], - ) -> Result> { - let (m, n) = a_layout.size(); - if (m as usize) > b.len() - || (n as usize) > b.len() - || a_layout.lapacke_layout() != b_layout.lapacke_layout() - { - return Err(Error::InvalidShape); - } - let k = ::std::cmp::min(m, n); - let nrhs = b_layout.size().1; - let rcond: Self::Real = -1.; - let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; - let mut rank: i32 = 0; - - $gelsd( - a_layout.lapacke_layout(), - m, - n, - nrhs, - a, - a_layout.lda(), - b, - b_layout.lda(), - &mut singular_values, - rcond, - &mut rank, - ) - .as_lapack_result()?; - Ok(LeastSquaresOutput { - singular_values, - rank, - }) - } - } - }; -} - -impl_least_squares!(c64, lapacke::zgelsd); -impl_least_squares!(c32, lapacke::cgelsd); +impl_least_squares!(@real, f64, lapack::dgelsd); +impl_least_squares!(@real, f32, lapack::sgelsd); +impl_least_squares!(@complex, c64, lapack::zgelsd); +impl_least_squares!(@complex, c32, lapack::cgelsd); From b742f6fe65907cc6af79734457ec3d2b85ce8149 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 25 Jul 2020 17:45:35 +0900 Subject: [PATCH 38/49] Fix assert to support under-determined case --- lax/src/least_squares.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 2c963158..02f682fb 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -57,12 +57,12 @@ macro_rules! impl_least_squares { // // where // A : (m, n) - // b : (m, p) - // x : (n, p) + // b : (max(m, n), nrhs) // `b` has to store `x` on exit + // x : (n, nrhs) let (m, n) = a_layout.size(); - let (m_, p) = b_layout.size(); + let (m_, nrhs) = b_layout.size(); let k = m.min(n); - assert_eq!(m, m_); + assert!(m_ >= m); // Transpose if a is C-continuous let mut a_t = None; @@ -98,7 +98,7 @@ macro_rules! impl_least_squares { $gelsd( m, n, - p, + nrhs, a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), a_layout.lda(), b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), @@ -126,7 +126,7 @@ macro_rules! impl_least_squares { $gelsd( m, n, - p, + nrhs, a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), a_layout.lda(), b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), From 5ad434371cdf7644b9008e51f4a10fac7113f612 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 25 Jul 2020 18:26:12 +0900 Subject: [PATCH 39/49] Enable tests for C/F mixed cases #234 --- ndarray-linalg/tests/least_squares_nrhs.rs | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/ndarray-linalg/tests/least_squares_nrhs.rs b/ndarray-linalg/tests/least_squares_nrhs.rs index 95737873..dd7d283c 100644 --- a/ndarray-linalg/tests/least_squares_nrhs.rs +++ b/ndarray-linalg/tests/least_squares_nrhs.rs @@ -37,8 +37,6 @@ macro_rules! impl_exact { test_exact(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { let a: Array2<$scalar> = random((3, 3)); @@ -53,8 +51,6 @@ macro_rules! impl_exact { test_exact(a, b) } - */ - #[test] fn []() { let a: Array2<$scalar> = random((3, 3).f()); @@ -109,8 +105,6 @@ macro_rules! impl_overdetermined { test_overdetermined(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { let a: Array2<$scalar> = random((4, 3).f()); @@ -125,8 +119,6 @@ macro_rules! impl_overdetermined { test_overdetermined(a, b) } - */ - #[test] fn []() { let a: Array2<$scalar> = random((4, 3).f()); @@ -168,8 +160,6 @@ macro_rules! impl_underdetermined { test_underdetermined(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { let a: Array2<$scalar> = random((3, 4).f()); @@ -184,8 +174,6 @@ macro_rules! impl_underdetermined { test_underdetermined(a, b) } - */ - #[test] fn []() { let a: Array2<$scalar> = random((3, 4).f()); From 00b044b7cca26968b4671b978288cb26fcd1ff97 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 25 Jul 2020 18:26:42 +0900 Subject: [PATCH 40/49] Fix error handling --- ndarray-linalg/src/least_squares.rs | 31 ++++++++++------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index 1a6e3a47..0431a822 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -267,6 +267,9 @@ where &mut self, rhs: &mut ArrayBase, ) -> Result> { + if self.shape()[0] != rhs.shape()[0] { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); + } let (m, n) = (self.shape()[0], self.shape()[1]); if n > m { // we need a new rhs b/c it will be overwritten with the solution @@ -285,7 +288,7 @@ fn compute_least_squares_srhs( rhs: &mut ArrayBase, ) -> Result> where - E: Scalar + Lapack + LeastSquaresSvdDivideConquer_, + E: Scalar + Lapack, D1: DataMut, D2: DataMut, { @@ -293,7 +296,7 @@ where singular_values, rank, } = unsafe { - ::least_squares( + E::least_squares( a.layout()?, a.as_allocated_mut()?, rhs.as_slice_memory_order_mut() @@ -348,6 +351,9 @@ where &mut self, rhs: &mut ArrayBase, ) -> Result> { + if self.shape()[0] != rhs.shape()[0] { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); + } let (m, n) = (self.shape()[0], self.shape()[1]); if n > m { // we need a new rhs b/c it will be overwritten with the solution @@ -550,28 +556,13 @@ mod tests { // // Testing error cases // - #[test] fn incompatible_shape_error_on_mismatching_num_rows() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let b: Array1 = array![1., 2.]; - let res = a.least_squares(&b); - match res { - Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {} - _ => panic!("Expected Err()"), - } - } - - #[test] - fn incompatible_shape_error_on_mismatching_layout() { - let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; - let b = array![[1.], [2.]].t().to_owned(); - assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 }); - - let res = a.least_squares(&b); - match res { - Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {} - _ => panic!("Expected Err()"), + match a.least_squares(&b) { + Err(LinalgError::Shape(e)) if e.kind() == ErrorKind::IncompatibleShape => {} + _ => panic!("Should be raise IncompatibleShape"), } } } From 962497eaecdeff4a5c0a2c5d614534217b037d79 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 25 Jul 2020 18:32:13 +0900 Subject: [PATCH 41/49] Drop unsafe --- lax/src/least_squares.rs | 80 +++++++++++++++-------------- ndarray-linalg/src/least_squares.rs | 28 +++++----- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 02f682fb..d684c9b8 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -14,13 +14,13 @@ pub struct LeastSquaresOutput { /// Wraps `*gelsd` pub trait LeastSquaresSvdDivideConquer_: Scalar { - unsafe fn least_squares( + fn least_squares( a_layout: MatrixLayout, a: &mut [Self], b: &mut [Self], ) -> Result>; - unsafe fn least_squares_nrhs( + fn least_squares_nrhs( a_layout: MatrixLayout, a: &mut [Self], b_layout: MatrixLayout, @@ -38,7 +38,7 @@ macro_rules! impl_least_squares { (@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => { impl LeastSquaresSvdDivideConquer_ for $scalar { - unsafe fn least_squares( + fn least_squares( l: MatrixLayout, a: &mut [Self], b: &mut [Self], @@ -47,7 +47,7 @@ macro_rules! impl_least_squares { Self::least_squares_nrhs(l, a, b_layout, b) } - unsafe fn least_squares_nrhs( + fn least_squares_nrhs( a_layout: MatrixLayout, a: &mut [Self], b_layout: MatrixLayout, @@ -95,23 +95,25 @@ macro_rules! impl_least_squares { $( let mut $rwork = [Self::Real::zero()]; )* - $gelsd( - m, - n, - nrhs, - a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), - &mut singular_values, - rcond, - &mut rank, - &mut work_size, - -1, - $(&mut $rwork,)* - &mut iwork_size, - &mut info, - ); + unsafe { + $gelsd( + m, + n, + nrhs, + a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), + &mut singular_values, + rcond, + &mut rank, + &mut work_size, + -1, + $(&mut $rwork,)* + &mut iwork_size, + &mut info, + ) + }; info.as_lapack_result()?; // calc @@ -123,23 +125,25 @@ macro_rules! impl_least_squares { let lrwork = $rwork[0].to_usize().unwrap(); let mut $rwork = vec![Self::Real::zero(); lrwork]; )* - $gelsd( - m, - n, - nrhs, - a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), - &mut singular_values, - rcond, - &mut rank, - &mut work, - lwork as i32, - $(&mut $rwork,)* - &mut iwork, - &mut info, - ); + unsafe { + $gelsd( + m, + n, + nrhs, + a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), + &mut singular_values, + rcond, + &mut rank, + &mut work, + lwork as i32, + $(&mut $rwork,)* + &mut iwork, + &mut info, + ); + } info.as_lapack_result()?; // Skip a_t -> a transpose because A has been destroyed diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index 0431a822..0ff518ad 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -295,14 +295,12 @@ where let LeastSquaresOutput:: { singular_values, rank, - } = unsafe { - E::least_squares( - a.layout()?, - a.as_allocated_mut()?, - rhs.as_slice_memory_order_mut() - .ok_or_else(|| LinalgError::MemoryNotCont)?, - )? - }; + } = E::least_squares( + a.layout()?, + a.as_allocated_mut()?, + rhs.as_slice_memory_order_mut() + .ok_or_else(|| LinalgError::MemoryNotCont)?, + )?; let (m, n) = (a.shape()[0], a.shape()[1]); let solution = rhs.slice(s![0..n]).to_owned(); @@ -385,14 +383,12 @@ where let LeastSquaresOutput:: { singular_values, rank, - } = unsafe { - E::least_squares_nrhs( - a_layout, - a.as_allocated_mut()?, - rhs_layout, - rhs.as_allocated_mut()?, - )? - }; + } = E::least_squares_nrhs( + a_layout, + a.as_allocated_mut()?, + rhs_layout, + rhs.as_allocated_mut()?, + )?; let solution: Array2 = rhs.slice(s![..a.shape()[1], ..]).to_owned(); let singular_values = Array::from_shape_vec((singular_values.len(),), singular_values)?; From e9c34817d3cf709bfd84e12fd0c348c1eb58468e Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 02:06:06 +0900 Subject: [PATCH 42/49] Impl triangular inv/solve using LAPACK --- lax/src/triangular.rs | 98 +++++++++++++++++++------------- ndarray-linalg/src/triangular.rs | 2 +- 2 files changed, 58 insertions(+), 42 deletions(-) diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index e725fa9e..5da50dde 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -1,8 +1,9 @@ //! Implement linear solver and inverse matrix use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; +use num_traits::Zero; #[derive(Debug, Clone, Copy)] #[repr(u8)] @@ -12,9 +13,8 @@ pub enum Diag { } /// Wraps `*trtri` and `*trtrs` -pub trait Triangular_: Sized { - unsafe fn inv_triangular(l: MatrixLayout, uplo: UPLO, d: Diag, a: &mut [Self]) -> Result<()>; - unsafe fn solve_triangular( +pub trait Triangular_: Scalar { + fn solve_triangular( al: MatrixLayout, bl: MatrixLayout, uplo: UPLO, @@ -27,50 +27,66 @@ pub trait Triangular_: Sized { macro_rules! impl_triangular { ($scalar:ty, $trtri:path, $trtrs:path) => { impl Triangular_ for $scalar { - unsafe fn inv_triangular( - l: MatrixLayout, - uplo: UPLO, - diag: Diag, - a: &mut [Self], - ) -> Result<()> { - let (n, _) = l.size(); - let lda = l.lda(); - $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda).as_lapack_result()?; - Ok(()) - } - - unsafe fn solve_triangular( - al: MatrixLayout, - bl: MatrixLayout, + fn solve_triangular( + a_layout: MatrixLayout, + b_layout: MatrixLayout, uplo: UPLO, diag: Diag, a: &[Self], - mut b: &mut [Self], + b: &mut [Self], ) -> Result<()> { - let (n, _) = al.size(); - let lda = al.lda(); - let (_, nrhs) = bl.size(); - let ldb = bl.lda(); - $trtrs( - al.lapacke_layout(), - uplo as u8, - Transpose::No as u8, - diag as u8, - n, - nrhs, - a, - lda, - &mut b, - ldb, - ) - .as_lapack_result()?; + // Transpose if a is C-continuous + let mut a_t = None; + let a_layout = match a_layout { + MatrixLayout::C { .. } => { + a_t = Some(vec![Self::zero(); a.len()]); + transpose(a_layout, a, a_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => a_layout, + }; + + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => b_layout, + }; + + let (m, n) = a_layout.size(); + let (n_, nrhs) = b_layout.size(); + assert_eq!(n, n_); + + let mut info = 0; + unsafe { + $trtrs( + uplo as u8, + Transpose::No as u8, + diag as u8, + m, + nrhs, + a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), + &mut info, + ); + } + info.as_lapack_result()?; + + // Re-transpose b + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } Ok(()) } } }; } // impl_triangular! -impl_triangular!(f64, lapacke::dtrtri, lapacke::dtrtrs); -impl_triangular!(f32, lapacke::strtri, lapacke::strtrs); -impl_triangular!(c64, lapacke::ztrtri, lapacke::ztrtrs); -impl_triangular!(c32, lapacke::ctrtri, lapacke::ctrtrs); +impl_triangular!(f64, lapack::dtrtri, lapack::dtrtrs); +impl_triangular!(f32, lapack::strtri, lapack::strtrs); +impl_triangular!(c64, lapack::ztrtri, lapack::ztrtrs); +impl_triangular!(c32, lapack::ctrtri, lapack::ctrtrs); diff --git a/ndarray-linalg/src/triangular.rs b/ndarray-linalg/src/triangular.rs index c54beafd..b378dcd2 100644 --- a/ndarray-linalg/src/triangular.rs +++ b/ndarray-linalg/src/triangular.rs @@ -85,7 +85,7 @@ where transpose_data(b)?; } let lb = b.layout()?; - unsafe { A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)? }; + A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?; Ok(b) } } From 373a18a977fc8ed65d45390308a1b53779e7be4a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 02:26:45 +0900 Subject: [PATCH 43/49] Impl tridiagonal by LAPACK --- lax/src/tridiagonal.rs | 52 ++++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index 4eb8ff13..3dabd0c6 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -143,7 +143,13 @@ pub trait Tridiagonal_: Scalar + Sized { } macro_rules! impl_tridiagonal { - ($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + (@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork); + }; + (@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, ); + }; + (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => { impl Tridiagonal_ for $scalar { unsafe fn lu_tridiagonal( mut a: Tridiagonal, @@ -153,8 +159,11 @@ macro_rules! impl_tridiagonal { let mut ipiv = vec![0; n as usize]; // We have to calc one-norm before LU factorization let a_opnorm_one = a.opnorm_one(); - $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv) - .as_lapack_result()?; + let mut info = 0; + $gttrf( + n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info, + ); + info.as_lapack_result()?; Ok(LUFactorizedTridiagonal { a, du2, @@ -166,7 +175,12 @@ macro_rules! impl_tridiagonal { unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; + let mut work = vec![Self::zero(); 2 * n as usize]; + $( + let mut $iwork = vec![0; n as usize]; + )* let mut rcond = Self::Real::zero(); + let mut info = 0; $gtcon( NormType::One as u8, n, @@ -177,8 +191,11 @@ macro_rules! impl_tridiagonal { ipiv, lu.a_opnorm_one, &mut rcond, - ) - .as_lapack_result()?; + &mut work, + $(&mut $iwork,)* + &mut info, + ); + info.as_lapack_result()?; Ok(rcond) } @@ -192,27 +209,18 @@ macro_rules! impl_tridiagonal { let (_, nrhs) = bl.size(); let ipiv = &lu.ipiv; let ldb = bl.lda(); + let mut info = 0; $gttrs( - lu.a.l.lapacke_layout(), - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + t as u8, n, nrhs, &lu.a.dl, &lu.a.d, &lu.a.du, &lu.du2, ipiv, b, ldb, &mut info, + ); + info.as_lapack_result()?; Ok(()) } } }; } // impl_tridiagonal! -impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs); -impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs); -impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs); -impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs); +impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs); +impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs); +impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs); +impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs); From 46b0dcd8097dfd1eaec1793943b39543233b174c Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 17:51:23 +0900 Subject: [PATCH 44/49] Transpose if C-contiguous --- lax/src/tridiagonal.rs | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index 3dabd0c6..e995e6f5 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -2,7 +2,7 @@ //! for tridiagonal matrix use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; use num_traits::Zero; use std::ops::{Index, IndexMut}; @@ -201,19 +201,40 @@ macro_rules! impl_tridiagonal { unsafe fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, - bl: MatrixLayout, + b_layout: MatrixLayout, t: Transpose, b: &mut [Self], ) -> Result<()> { let (n, _) = lu.a.l.size(); - let (_, nrhs) = bl.size(); let ipiv = &lu.ipiv; - let ldb = bl.lda(); + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => b_layout, + }; + let (ldb, nrhs) = b_layout.size(); let mut info = 0; $gttrs( - t as u8, n, nrhs, &lu.a.dl, &lu.a.d, &lu.a.du, &lu.du2, ipiv, b, ldb, &mut info, + t as u8, + n, + nrhs, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + ldb, + &mut info, ); info.as_lapack_result()?; + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } Ok(()) } } From 2a6154f10e90326bd28dfdcdfa3684c9aa5a204f Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 18:07:02 +0900 Subject: [PATCH 45/49] Drop unsafe --- lax/src/tridiagonal.rs | 76 +++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index e995e6f5..ea5bb119 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -130,11 +130,11 @@ impl IndexMut<[i32; 2]> for Tridiagonal { pub trait Tridiagonal_: Scalar + Sized { /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using /// partial pivoting with row interchanges. - unsafe fn lu_tridiagonal(a: Tridiagonal) -> Result>; + fn lu_tridiagonal(a: Tridiagonal) -> Result>; - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; - unsafe fn solve_tridiagonal( + fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, bl: MatrixLayout, t: Transpose, @@ -151,18 +151,14 @@ macro_rules! impl_tridiagonal { }; (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => { impl Tridiagonal_ for $scalar { - unsafe fn lu_tridiagonal( - mut a: Tridiagonal, - ) -> Result> { + fn lu_tridiagonal(mut a: Tridiagonal) -> Result> { let (n, _) = a.l.size(); let mut du2 = vec![Zero::zero(); (n - 2) as usize]; let mut ipiv = vec![0; n as usize]; // We have to calc one-norm before LU factorization let a_opnorm_one = a.opnorm_one(); let mut info = 0; - $gttrf( - n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info, - ); + unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) }; info.as_lapack_result()?; Ok(LUFactorizedTridiagonal { a, @@ -172,7 +168,7 @@ macro_rules! impl_tridiagonal { }) } - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; let mut work = vec![Self::zero(); 2 * n as usize]; @@ -181,25 +177,27 @@ macro_rules! impl_tridiagonal { )* let mut rcond = Self::Real::zero(); let mut info = 0; - $gtcon( - NormType::One as u8, - n, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - lu.a_opnorm_one, - &mut rcond, - &mut work, - $(&mut $iwork,)* - &mut info, - ); + unsafe { + $gtcon( + NormType::One as u8, + n, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + lu.a_opnorm_one, + &mut rcond, + &mut work, + $(&mut $iwork,)* + &mut info, + ); + } info.as_lapack_result()?; Ok(rcond) } - unsafe fn solve_tridiagonal( + fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, b_layout: MatrixLayout, t: Transpose, @@ -218,19 +216,21 @@ macro_rules! impl_tridiagonal { }; let (ldb, nrhs) = b_layout.size(); let mut info = 0; - $gttrs( - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - ldb, - &mut info, - ); + unsafe { + $gttrs( + t as u8, + n, + nrhs, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + ldb, + &mut info, + ); + } info.as_lapack_result()?; if let Some(b_t) = b_t { transpose(b_layout, &b_t, b); From 29bef1a022fc5526c4eab29442a372d24dde5023 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 18:21:33 +0900 Subject: [PATCH 46/49] Remove lapacke dep --- lax/Cargo.toml | 7 +++---- lax/src/layout.rs | 13 +++++-------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/lax/Cargo.toml b/lax/Cargo.toml index 642dcf36..84604074 100644 --- a/lax/Cargo.toml +++ b/lax/Cargo.toml @@ -11,11 +11,10 @@ netlib = ["lapack-src/netlib", "blas-src/netlib"] openblas = ["lapack-src/openblas", "blas-src/openblas"] [dependencies] -thiserror = "1" -cauchy = "0.2" -lapacke = "0.2.0" +thiserror = "1.0" +cauchy = "0.2.0" num-traits = "0.2" -lapack = "*" +lapack = "0.16.0" [dependencies.blas-src] version = "0.6.1" diff --git a/lax/src/layout.rs b/lax/src/layout.rs index 43b7ee87..e7ab1da4 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -80,15 +80,12 @@ impl MatrixLayout { self.len() == 0 } - pub fn lapacke_layout(&self) -> lapacke::Layout { - match *self { - MatrixLayout::C { .. } => lapacke::Layout::RowMajor, - MatrixLayout::F { .. } => lapacke::Layout::ColumnMajor, - } - } - pub fn same_order(&self, other: &MatrixLayout) -> bool { - self.lapacke_layout() == other.lapacke_layout() + match (self, other) { + (MatrixLayout::C { .. }, MatrixLayout::C { .. }) => true, + (MatrixLayout::F { .. }, MatrixLayout::F { .. }) => true, + _ => false, + } } pub fn toggle_order(&self) -> Self { From 6f870a5c1ac3379772296a67b12267a61c8c1603 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 18:24:54 +0900 Subject: [PATCH 47/49] Remove unnecessarily unsafe --- ndarray-linalg/src/tridiagonal.rs | 50 ++++++++++++++----------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/ndarray-linalg/src/tridiagonal.rs b/ndarray-linalg/src/tridiagonal.rs index 1fb0d558..92eae05b 100644 --- a/ndarray-linalg/src/tridiagonal.rs +++ b/ndarray-linalg/src/tridiagonal.rs @@ -271,14 +271,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_tridiagonal( - &self, - rhs.layout()?, - Transpose::No, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::No, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_t_tridiagonal_inplace<'a, Sb>( @@ -288,14 +286,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_tridiagonal( - &self, - rhs.layout()?, - Transpose::Transpose, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::Transpose, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_h_tridiagonal_inplace<'a, Sb>( @@ -305,14 +301,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_tridiagonal( - &self, - rhs.layout()?, - Transpose::Hermite, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::Hermite, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } } @@ -566,7 +560,7 @@ where A: Scalar + Lapack, { fn factorize_tridiagonal_into(self) -> Result> { - Ok(unsafe { A::lu_tridiagonal(self)? }) + Ok(A::lu_tridiagonal(self)?) } } @@ -576,7 +570,7 @@ where { fn factorize_tridiagonal(&self) -> Result> { let a = self.clone(); - Ok(unsafe { A::lu_tridiagonal(a)? }) + Ok(A::lu_tridiagonal(a)?) } } @@ -587,7 +581,7 @@ where { fn factorize_tridiagonal(&self) -> Result> { let a = self.extract_tridiagonal()?; - Ok(unsafe { A::lu_tridiagonal(a)? }) + Ok(A::lu_tridiagonal(a)?) } } @@ -677,7 +671,7 @@ where A: Scalar + Lapack, { fn rcond_tridiagonal(&self) -> Result { - unsafe { Ok(A::rcond_tridiagonal(&self)?) } + Ok(A::rcond_tridiagonal(&self)?) } } From d470ad26eb49e1c61e4d049c87ae53bd9bf9c78a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 18:32:25 +0900 Subject: [PATCH 48/49] Remove unnecessary unsafe more --- lax/src/opnorm.rs | 6 +-- lax/src/svd.rs | 87 +++++++++++++++++------------------- lax/src/svddc.rs | 80 ++++++++++++++++----------------- ndarray-linalg/src/opnorm.rs | 4 +- ndarray-linalg/src/svd.rs | 2 +- ndarray-linalg/src/svddc.rs | 2 +- 6 files changed, 88 insertions(+), 93 deletions(-) diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index c76d9f59..0c594d92 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -7,13 +7,13 @@ use num_traits::Zero; pub use super::NormType; pub trait OperatorNorm_: Scalar { - unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real; + fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real; } macro_rules! impl_opnorm { ($scalar:ty, $lange:path) => { impl OperatorNorm_ for $scalar { - unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { + fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { let m = l.lda(); let n = l.len(); let t = match l { @@ -25,7 +25,7 @@ macro_rules! impl_opnorm { } else { Vec::new() }; - $lange(t as u8, m, n, a, m, &mut work) + unsafe { $lange(t as u8, m, n, a, m, &mut work) } } } }; diff --git a/lax/src/svd.rs b/lax/src/svd.rs index d51ffd27..d5e48ff5 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -36,12 +36,8 @@ pub struct SVDOutput { /// Wraps `*gesvd` pub trait SVD_: Scalar { /// Calculate singular value decomposition $ A = U \Sigma V^T $ - unsafe fn svd( - l: MatrixLayout, - calc_u: bool, - calc_vt: bool, - a: &mut [Self], - ) -> Result>; + fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) + -> Result>; } macro_rules! impl_svd { @@ -53,12 +49,7 @@ macro_rules! impl_svd { }; (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => { impl SVD_ for $scalar { - unsafe fn svd( - l: MatrixLayout, - calc_u: bool, - calc_vt: bool, - mut a: &mut [Self], - ) -> Result> { + fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self],) -> Result> { let ju = match l { MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), @@ -90,45 +81,49 @@ macro_rules! impl_svd { // eval work size let mut info = 0; let mut work_size = [Self::zero()]; - $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, - $(&mut $rwork_ident,)* - &mut info, - ); + unsafe { + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + $(&mut $rwork_ident,)* + &mut info, + ); + } info.as_lapack_result()?; // calc let lwork = work_size[0].to_usize().unwrap(); let mut work = vec![Self::zero(); lwork]; - $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* - &mut info, - ); + unsafe { + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + $(&mut $rwork_ident,)* + &mut info, + ); + } info.as_lapack_result()?; match l { MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index 3e50d7bb..12ed129d 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -18,7 +18,7 @@ pub enum UVTFlag { } pub trait SVDDC_: Scalar { - unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; + fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } macro_rules! impl_svddc { @@ -30,11 +30,7 @@ macro_rules! impl_svddc { }; (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { - unsafe fn svddc( - l: MatrixLayout, - jobz: UVTFlag, - mut a: &mut [Self], - ) -> Result> { + fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self],) -> Result> { let m = l.lda(); let n = l.len(); let k = m.min(n); @@ -70,45 +66,49 @@ macro_rules! impl_svddc { let mut info = 0; let mut iwork = vec![0; 8 * k as usize]; let mut work_size = [Self::zero()]; - $gesdd( - jobz as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - vt_row, - &mut work_size, - -1, - $(&mut $rwork_ident,)* - &mut iwork, - &mut info, - ); + unsafe { + $gesdd( + jobz as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work_size, + -1, + $(&mut $rwork_ident,)* + &mut iwork, + &mut info, + ); + } info.as_lapack_result()?; // do svd let lwork = work_size[0].to_usize().unwrap(); let mut work = vec![Self::zero(); lwork]; - $gesdd( - jobz as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - vt_row, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* - &mut iwork, - &mut info, - ); + unsafe { + $gesdd( + jobz as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work, + lwork as i32, + $(&mut $rwork_ident,)* + &mut iwork, + &mut info, + ); + } info.as_lapack_result()?; match l { diff --git a/ndarray-linalg/src/opnorm.rs b/ndarray-linalg/src/opnorm.rs index 4358fd7f..b4956867 100644 --- a/ndarray-linalg/src/opnorm.rs +++ b/ndarray-linalg/src/opnorm.rs @@ -45,7 +45,7 @@ where fn opnorm(&self, t: NormType) -> Result { let l = self.layout()?; let a = self.as_allocated()?; - Ok(unsafe { A::opnorm(t, l, a) }) + Ok(A::opnorm(t, l, a)) } } @@ -108,6 +108,6 @@ where let l = arr.layout()?; let a = arr.as_allocated()?; - Ok(unsafe { A::opnorm(t, l, a) }) + Ok(A::opnorm(t, l, a)) } } diff --git a/ndarray-linalg/src/svd.rs b/ndarray-linalg/src/svd.rs index 486982cd..5c9d59b1 100644 --- a/ndarray-linalg/src/svd.rs +++ b/ndarray-linalg/src/svd.rs @@ -93,7 +93,7 @@ where calc_vt: bool, ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; - let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? }; + let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?; let (n, m) = l.size(); let u = svd_res.u.map(|u| into_matrix(l.resized(n, n), u).unwrap()); diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index b27212ec..c79fa0e2 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -80,7 +80,7 @@ where uvt_flag: UVTFlag, ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; - let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; + let svd_res = A::svddc(l, uvt_flag, self.as_allocated_mut()?)?; let (m, n) = l.size(); let k = m.min(n); From 6571d86b650b4203561c10eb34604d8ebb235d20 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 18:58:25 +0900 Subject: [PATCH 49/49] Use sup norm for testing least_squares --- ndarray-linalg/tests/least_squares.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ndarray-linalg/tests/least_squares.rs b/ndarray-linalg/tests/least_squares.rs index e2df3370..33e20ca7 100644 --- a/ndarray-linalg/tests/least_squares.rs +++ b/ndarray-linalg/tests/least_squares.rs @@ -19,7 +19,7 @@ fn test_exact(a: Array2) { // b == Ax let ax = a.dot(&x); - assert_close_l2!(&b, &ax, T::real(1.0e-4)); + assert_close_max!(&b, &ax, T::real(1.0e-4)); } macro_rules! impl_exact { @@ -102,7 +102,7 @@ fn test_underdetermined(a: Array2) { // b == Ax let x = result.solution; let ax = a.dot(&x); - assert_close_l2!(&b, &ax, T::real(1.0e-4)); + assert_close_max!(&b, &ax, T::real(1.0e-4)); } macro_rules! impl_underdetermined {