diff --git a/lax/Cargo.toml b/lax/Cargo.toml index e42c8d74..84604074 100644 --- a/lax/Cargo.toml +++ b/lax/Cargo.toml @@ -11,10 +11,10 @@ netlib = ["lapack-src/netlib", "blas-src/netlib"] openblas = ["lapack-src/openblas", "blas-src/openblas"] [dependencies] -thiserror = "1" -cauchy = "0.2" -lapacke = "0.2.0" +thiserror = "1.0" +cauchy = "0.2.0" num-traits = "0.2" +lapack = "0.16.0" [dependencies.blas-src] version = "0.6.1" diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index ef9473b4..8305efe5 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -1,56 +1,90 @@ //! Cholesky decomposition use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; pub trait Cholesky_: Sized { /// Cholesky: wrapper of `*potrf` /// /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + /// Wrapper of `*potri` /// /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + /// Wrapper of `*potrs` - unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) - -> Result<()>; + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } macro_rules! impl_cholesky { ($scalar:ty, $trf:path, $tri:path, $trs:path) => { impl Cholesky_ for $scalar { - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); - $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. 
}) { + square_transpose(l, a); + } + let mut info = 0; + unsafe { + $trf(uplo as u8, n, a, n, &mut info); + } + info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } Ok(()) } - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } + let mut info = 0; + unsafe { + $tri(uplo as u8, n, a, l.lda(), &mut info); + } + info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } Ok(()) } - unsafe fn solve_cholesky( + fn solve_cholesky( l: MatrixLayout, - uplo: UPLO, + mut uplo: UPLO, a: &[Self], b: &mut [Self], ) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; - let ldb = 1; - $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) - .as_lapack_result()?; + let mut info = 0; + if matches!(l, MatrixLayout::C { .. }) { + uplo = uplo.t(); + for val in b.iter_mut() { + *val = val.conj(); + } + } + unsafe { + $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); + } + info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. 
}) { + for val in b.iter_mut() { + *val = val.conj(); + } + } Ok(()) } } }; } // end macro_rules -impl_cholesky!(f64, lapacke::dpotrf, lapacke::dpotri, lapacke::dpotrs); -impl_cholesky!(f32, lapacke::spotrf, lapacke::spotri, lapacke::spotrs); -impl_cholesky!(c64, lapacke::zpotrf, lapacke::zpotri, lapacke::zpotrs); -impl_cholesky!(c32, lapacke::cpotrf, lapacke::cpotri, lapacke::cpotrs); +impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs); +impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs); +impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs); +impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs); diff --git a/lax/src/eig.rs b/lax/src/eig.rs index e8f9a8f8..0a34977d 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -2,11 +2,12 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*geev` for real/complex +/// Wraps `*geev` for general matrices pub trait Eig_: Scalar { - unsafe fn eig( + /// Calculate Right eigenvalue + fn eig( calc_v: bool, l: MatrixLayout, a: &mut [Self], @@ -16,117 +17,242 @@ pub trait Eig_: Scalar { macro_rules! impl_eig_complex { ($scalar:ty, $ev:path) => { impl Eig_ for $scalar { - unsafe fn eig( + fn eig( calc_v: bool, l: MatrixLayout, mut a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); - let jobvr = if calc_v { b'V' } else { b'N' }; - let mut w = vec![Self::Complex::zero(); n as usize]; - let mut vl = Vec::new(); - let mut vr = vec![Self::Complex::zero(); (n * n) as usize]; - $ev( - l.lapacke_layout(), - b'N', - jobvr, - n, - &mut a, - n, - &mut w, - &mut vl, - n, - &mut vr, - n, - ) - .as_lapack_result()?; - Ok((w, vr)) + // Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate. 
+ // However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A + let (jobvl, jobvr) = if calc_v { + match l { + MatrixLayout::C { .. } => (b'V', b'N'), + MatrixLayout::F { .. } => (b'N', b'V'), + } + } else { + (b'N', b'N') + }; + let mut eigs = vec![Self::Complex::zero(); n as usize]; + let mut rwork = vec![Self::Real::zero(); 2 * n as usize]; + + let mut vl = if jobvl == b'V' { + Some(vec![Self::Complex::zero(); (n * n) as usize]) + } else { + None + }; + let mut vr = if jobvr == b'V' { + Some(vec![Self::Complex::zero(); (n * n) as usize]) + } else { + None + }; + + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + $ev( + jobvl, + jobvr, + n, + &mut a, + n, + &mut eigs, + &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + // actal ev + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $ev( + jobvl, + jobvr, + n, + &mut a, + n, + &mut eigs, + &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + // Hermite conjugate + if jobvl == b'V' { + for c in vl.as_mut().unwrap().iter_mut() { + c.im = -c.im + } + } + + Ok((eigs, vr.or(vl).unwrap_or(Vec::new()))) } } }; } +impl_eig_complex!(c64, lapack::zgeev); +impl_eig_complex!(c32, lapack::cgeev); + macro_rules! 
impl_eig_real { ($scalar:ty, $ev:path) => { impl Eig_ for $scalar { - unsafe fn eig( + fn eig( calc_v: bool, l: MatrixLayout, mut a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); - let jobvr = if calc_v { b'V' } else { b'N' }; - let mut wr = vec![Self::Real::zero(); n as usize]; - let mut wi = vec![Self::Real::zero(); n as usize]; - let mut vl = Vec::new(); - let mut vr = vec![Self::Real::zero(); (n * n) as usize]; - let info = $ev( - l.lapacke_layout(), - b'N', - jobvr, - n, - &mut a, - n, - &mut wr, - &mut wi, - &mut vl, - n, - &mut vr, - n, - ); - let w: Vec = wr + // Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate. + // However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A + let (jobvl, jobvr) = if calc_v { + match l { + MatrixLayout::C { .. } => (b'V', b'N'), + MatrixLayout::F { .. } => (b'N', b'V'), + } + } else { + (b'N', b'N') + }; + let mut eig_re = vec![Self::zero(); n as usize]; + let mut eig_im = vec![Self::zero(); n as usize]; + + let mut vl = if jobvl == b'V' { + Some(vec![Self::zero(); (n * n) as usize]) + } else { + None + }; + let mut vr = if jobvr == b'V' { + Some(vec![Self::zero(); (n * n) as usize]) + } else { + None + }; + + // calc work size + let mut info = 0; + let mut work_size = [0.0]; + unsafe { + $ev( + jobvl, + jobvr, + n, + &mut a, + n, + &mut eig_re, + &mut eig_im, + vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + &mut info, + ) + }; + info.as_lapack_result()?; + + // actual ev + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $ev( + jobvl, + jobvr, + n, + &mut a, + n, + &mut eig_re, + &mut eig_im, + vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + &mut info, + 
) + }; + info.as_lapack_result()?; + + // reconstruct eigenvalues + let eigs: Vec = eig_re .iter() - .zip(wi.iter()) - .map(|(&r, &i)| Self::Complex::new(r, i)) + .zip(eig_im.iter()) + .map(|(&re, &im)| Self::complex(re, im)) .collect(); - // If the j-th eigenvalue is real, then - // eigenvector = [ vr[j], vr[j+n], vr[j+2*n], ... ]. + if !calc_v { + return Ok((eigs, Vec::new())); + } + + // Reconstruct eigenvectors into complex-array + // -------------------------------------------- // - // If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, - // eigenvector(j) = [ vr[j] + i*vr[j+1], vr[j+n] + i*vr[j+n+1], vr[j+2*n] + i*vr[j+2*n+1], ... ] and - // eigenvector(j+1) = [ vr[j] - i*vr[j+1], vr[j+n] - i*vr[j+n+1], vr[j+2*n] - i*vr[j+2*n+1], ... ]. + // From LAPACK API https://software.intel.com/en-us/node/469230 // - // Therefore, if eigenvector(j) is written as [ v_{j0}, v_{j1}, v_{j2}, ... ], - // you have to make - // v = vec![ v_{00}, v_{10}, v_{20}, ..., v_{jk}, v_{(j+1)k}, v_{(j+2)k}, ... ] (v.len() = n*n) - // based on wi and vr. - // After that, v is converted to Array2 (see ../eig.rs). + // - If the j-th eigenvalue is real, + // - v(j) = VR(:,j), the j-th column of VR. + // + // - If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, + // - v(j) = VR(:,j) + i*VR(:,j+1) + // - v(j+1) = VR(:,j) - i*VR(:,j+1). + // + // ``` + // j -> <----pair----> <----pair----> + // [ ... (real), (imag), (imag), (imag), (imag), ... 
] : eigs + // ^ ^ ^ ^ ^ + // false false true false true : is_conjugate_pair + // ``` let n = n as usize; - let mut flg = false; - let conj: Vec = wi - .iter() - .map(|&i| { - if flg { - flg = false; - -1 - } else if i != 0.0 { - flg = true; - 1 - } else { - 0 + let v = vr.or(vl).unwrap(); + let mut eigvecs = vec![Self::Complex::zero(); n * n]; + let mut is_conjugate_pair = false; // flag for check `j` is complex conjugate + for j in 0..n { + if eig_im[j] == 0.0 { + // j-th eigenvalue is real + for i in 0..n { + eigvecs[i + j * n] = Self::complex(v[i + j * n], 0.0); } - }) - .collect(); - let v: Vec = (0..n * n) - .map(|i| { - let j = i % n; - match conj[j] { - 1 => Self::Complex::new(vr[i], vr[i + 1]), - -1 => Self::Complex::new(vr[i - 1], -vr[i]), - _ => Self::Complex::new(vr[i], 0.0), + } else { + // j-th eigenvalue is complex + // complex conjugated pair can be `j-1` or `j+1` + if is_conjugate_pair { + let j_pair = j - 1; + assert!(j_pair < n); + for i in 0..n { + eigvecs[i + j * n] = Self::complex(v[i + j_pair * n], v[i + j * n]); + } + } else { + let j_pair = j + 1; + assert!(j_pair < n); + for i in 0..n { + eigvecs[i + j * n] = + Self::complex(v[i + j * n], -v[i + j_pair * n]); + } } - }) - .collect(); + is_conjugate_pair = !is_conjugate_pair; + } + } - info.as_lapack_result()?; - Ok((w, v)) + Ok((eigs, eigvecs)) } } }; } -impl_eig_real!(f64, lapacke::dgeev); -impl_eig_real!(f32, lapacke::sgeev); -impl_eig_complex!(c64, lapacke::zgeev); -impl_eig_complex!(c32, lapacke::cgeev); +impl_eig_real!(f64, lapack::dgeev); +impl_eig_real!(f32, lapack::sgeev); diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index a992a67e..0920dfa1 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -1,21 +1,23 @@ -//! Eigenvalue decomposition for Hermite matrices +//! 
Eigenvalue decomposition for Symmetric/Hermite matrices use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*syev` for real and `*heev` for complex pub trait Eigh_: Scalar { - unsafe fn eigh( + /// Wraps `*syev` for real and `*heev` for complex + fn eigh( calc_eigenvec: bool, - l: MatrixLayout, + layout: MatrixLayout, uplo: UPLO, a: &mut [Self], ) -> Result>; - unsafe fn eigh_generalized( + + /// Wraps `*syegv` for real and `*heegv` for complex + fn eigh_generalized( calc_eigenvec: bool, - l: MatrixLayout, + layout: MatrixLayout, uplo: UPLO, a: &mut [Self], b: &mut [Self], @@ -23,52 +25,135 @@ pub trait Eigh_: Scalar { } macro_rules! impl_eigh { - ($scalar:ty, $ev:path, $evg:path) => { + (@real, $scalar:ty, $ev:path, $evg:path) => { + impl_eigh!(@body, $scalar, $ev, $evg, ); + }; + (@complex, $scalar:ty, $ev:path, $evg:path) => { + impl_eigh!(@body, $scalar, $ev, $evg, rwork); + }; + (@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => { impl Eigh_ for $scalar { - unsafe fn eigh( + fn eigh( calc_v: bool, - l: MatrixLayout, + layout: MatrixLayout, uplo: UPLO, mut a: &mut [Self], ) -> Result> { - let (n, _) = l.size(); + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; - let mut w = vec![Self::Real::zero(); n as usize]; - $ev(l.lapacke_layout(), jobz, uplo as u8, n, &mut a, n, &mut w) - .as_lapack_result()?; - Ok(w) + let mut eigs = vec![Self::Real::zero(); n as usize]; + + $( + let mut $rwork_ident = vec![Self::Real::zero(); 3 * n as usize - 2]; + )* + + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + $ev( + jobz, + uplo as u8, + n, + &mut a, + n, + &mut eigs, + &mut work_size, + -1, + $(&mut $rwork_ident,)* + &mut info, + ); + } + info.as_lapack_result()?; + + // actual ev + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); 
lwork]; + unsafe { + $ev( + jobz, + uplo as u8, + n, + &mut a, + n, + &mut eigs, + &mut work, + lwork as i32, + $(&mut $rwork_ident,)* + &mut info, + ); + } + info.as_lapack_result()?; + Ok(eigs) } - unsafe fn eigh_generalized( + fn eigh_generalized( calc_v: bool, - l: MatrixLayout, + layout: MatrixLayout, uplo: UPLO, mut a: &mut [Self], mut b: &mut [Self], ) -> Result> { - let (n, _) = l.size(); + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); let jobz = if calc_v { b'V' } else { b'N' }; - let mut w = vec![Self::Real::zero(); n as usize]; - $evg( - l.lapacke_layout(), - 1, - jobz, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut w, - ) - .as_lapack_result()?; - Ok(w) + let mut eigs = vec![Self::Real::zero(); n as usize]; + + $( + let mut $rwork_ident = vec![Self::Real::zero(); 3 * n as usize - 2]; + )* + + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + $evg( + &[1], + jobz, + uplo as u8, + n, + &mut a, + n, + &mut b, + n, + &mut eigs, + &mut work_size, + -1, + $(&mut $rwork_ident,)* + &mut info, + ); + } + info.as_lapack_result()?; + + // actual evg + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $evg( + &[1], + jobz, + uplo as u8, + n, + &mut a, + n, + &mut b, + n, + &mut eigs, + &mut work, + lwork as i32, + $(&mut $rwork_ident,)* + &mut info, + ); + } + info.as_lapack_result()?; + Ok(eigs) } } }; } // impl_eigh! 
-impl_eigh!(f64, lapacke::dsyev, lapacke::dsygv); -impl_eigh!(f32, lapacke::ssyev, lapacke::ssygv); -impl_eigh!(c64, lapacke::zheev, lapacke::zhegv); -impl_eigh!(c32, lapacke::cheev, lapacke::chegv); +impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv); +impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv); +impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv); +impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv); diff --git a/lax/src/layout.rs b/lax/src/layout.rs index aa9fe110..e7ab1da4 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -37,6 +37,8 @@ //! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`. //! +use cauchy::Scalar; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MatrixLayout { C { row: i32, lda: i32 }, @@ -78,15 +80,12 @@ impl MatrixLayout { self.len() == 0 } - pub fn lapacke_layout(&self) -> lapacke::Layout { - match *self { - MatrixLayout::C { .. } => lapacke::Layout::RowMajor, - MatrixLayout::F { .. } => lapacke::Layout::ColumnMajor, - } - } - pub fn same_order(&self, other: &MatrixLayout) -> bool { - self.lapacke_layout() == other.lapacke_layout() + match (self, other) { + (MatrixLayout::C { .. }, MatrixLayout::C { .. }) => true, + (MatrixLayout::F { .. }, MatrixLayout::F { .. }) => true, + _ => false, + } } pub fn toggle_order(&self) -> Self { @@ -95,4 +94,132 @@ impl MatrixLayout { MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col }, } } + + /// Transpose without changing memory representation + /// + /// C-contigious row=2, lda=3 + /// + /// ```text + /// [[1, 2, 3] + /// [4, 5, 6]] + /// ``` + /// + /// and F-contigious col=2, lda=3 + /// + /// ```text + /// [[1, 4] + /// [2, 5] + /// [3, 6]] + /// ``` + /// + /// have same memory representation `[1, 2, 3, 4, 5, 6]`, and this toggles them. 
+ /// + /// ``` + /// # use lax::layout::*; + /// let layout = MatrixLayout::C { row: 2, lda: 3 }; + /// assert_eq!(layout.t(), MatrixLayout::F { col: 2, lda: 3 }); + /// ``` + pub fn t(&self) -> Self { + match *self { + MatrixLayout::C { row, lda } => MatrixLayout::F { col: row, lda }, + MatrixLayout::F { col, lda } => MatrixLayout::C { row: col, lda }, + } + } +} + +/// In-place transpose of a square matrix by keeping F/C layout +/// +/// Transpose for C-continuous array +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::C { row: 2, lda: 2 }; +/// let mut a = vec![1., 2., 3., 4.]; +/// square_transpose(layout, &mut a); +/// assert_eq!(a, &[1., 3., 2., 4.]); +/// ``` +/// +/// Transpose for F-continuous array +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::F { col: 2, lda: 2 }; +/// let mut a = vec![1., 3., 2., 4.]; +/// square_transpose(layout, &mut a); +/// assert_eq!(a, &[1., 2., 3., 4.]); +/// ``` +/// +/// Panics +/// ------ +/// - If size of `a` and `layout` size mismatch +/// +pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { + let (m, n) = layout.size(); + let n = n as usize; + let m = m as usize; + assert_eq!(a.len(), n * m); + for i in 0..m { + for j in (i + 1)..n { + let a_ij = a[i * n + j]; + let a_ji = a[j * m + i]; + a[i * n + j] = a_ji.conj(); + a[j * m + i] = a_ij.conj(); + } + } +} + +/// Out-place transpose for general matrix +/// +/// Inplace transpose of non-square matrices is hard. 
+/// See also: https://en.wikipedia.org/wiki/In-place_matrix_transposition +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::C { row: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let mut b = vec![0.0; a.len()]; +/// let l = transpose(layout, &a, &mut b); +/// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::F { col: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let mut b = vec![0.0; a.len()]; +/// let l = transpose(layout, &a, &mut b); +/// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// Panics +/// ------ +/// - If size of `a` and `layout` size mismatch +/// +pub fn transpose(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout { + let (m, n) = layout.size(); + let transposed = layout.resized(n, m).t(); + let m = m as usize; + let n = n as usize; + assert_eq!(from.len(), m * n); + assert_eq!(to.len(), m * n); + + match layout { + MatrixLayout::C { .. } => { + for i in 0..m { + for j in 0..n { + to[j * m + i] = from[i * n + j]; + } + } + } + MatrixLayout::F { .. } => { + for i in 0..m { + for j in 0..n { + to[i * n + j] = from[j * m + i]; + } + } + } + } + transposed } diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 69553a44..d684c9b8 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -1,8 +1,8 @@ //! 
Least squares -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; /// Result of LeastSquares pub struct LeastSquaresOutput { @@ -14,13 +14,13 @@ pub struct LeastSquaresOutput { /// Wraps `*gelsd` pub trait LeastSquaresSvdDivideConquer_: Scalar { - unsafe fn least_squares( + fn least_squares( a_layout: MatrixLayout, a: &mut [Self], b: &mut [Self], ) -> Result>; - unsafe fn least_squares_nrhs( + fn least_squares_nrhs( a_layout: MatrixLayout, a: &mut [Self], b_layout: MatrixLayout, @@ -29,81 +29,129 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar { } macro_rules! impl_least_squares { - ($scalar:ty, $gelsd:path) => { + (@real, $scalar:ty, $gelsd:path) => { + impl_least_squares!(@body, $scalar, $gelsd, ); + }; + (@complex, $scalar:ty, $gelsd:path) => { + impl_least_squares!(@body, $scalar, $gelsd, rwork); + }; + + (@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => { impl LeastSquaresSvdDivideConquer_ for $scalar { - unsafe fn least_squares( - a_layout: MatrixLayout, + fn least_squares( + l: MatrixLayout, a: &mut [Self], b: &mut [Self], ) -> Result> { - let (m, n) = a_layout.size(); - if (m as usize) > b.len() || (n as usize) > b.len() { - return Err(Error::InvalidShape); - } - let k = ::std::cmp::min(m, n); - let nrhs = 1; - let ldb = match a_layout { - MatrixLayout::F { .. } => m.max(n), - MatrixLayout::C { .. 
} => 1, - }; - let rcond: Self::Real = -1.; - let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; - let mut rank: i32 = 0; - - $gelsd( - a_layout.lapacke_layout(), - m, - n, - nrhs, - a, - a_layout.lda(), - b, - ldb, - &mut singular_values, - rcond, - &mut rank, - ) - .as_lapack_result()?; - - Ok(LeastSquaresOutput { - singular_values, - rank, - }) + let b_layout = l.resized(b.len() as i32, 1); + Self::least_squares_nrhs(l, a, b_layout, b) } - unsafe fn least_squares_nrhs( + fn least_squares_nrhs( a_layout: MatrixLayout, a: &mut [Self], b_layout: MatrixLayout, b: &mut [Self], ) -> Result> { + // Minimize |b - Ax|_2 + // + // where + // A : (m, n) + // b : (max(m, n), nrhs) // `b` has to store `x` on exit + // x : (n, nrhs) let (m, n) = a_layout.size(); - if (m as usize) > b.len() - || (n as usize) > b.len() - || a_layout.lapacke_layout() != b_layout.lapacke_layout() - { - return Err(Error::InvalidShape); - } - let k = ::std::cmp::min(m, n); - let nrhs = b_layout.size().1; + let (m_, nrhs) = b_layout.size(); + let k = m.min(n); + assert!(m_ >= m); + + // Transpose if a is C-continuous + let mut a_t = None; + let a_layout = match a_layout { + MatrixLayout::C { .. } => { + a_t = Some(vec![Self::zero(); a.len()]); + transpose(a_layout, a, a_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => a_layout, + }; + + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. 
} => b_layout, + }; + let rcond: Self::Real = -1.; let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; let mut rank: i32 = 0; - $gelsd( - a_layout.lapacke_layout(), - m, - n, - nrhs, - a, - a_layout.lda(), - b, - b_layout.lda(), - &mut singular_values, - rcond, - &mut rank, - ) - .as_lapack_result()?; + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + let mut iwork_size = [0]; + $( + let mut $rwork = [Self::Real::zero()]; + )* + unsafe { + $gelsd( + m, + n, + nrhs, + a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), + &mut singular_values, + rcond, + &mut rank, + &mut work_size, + -1, + $(&mut $rwork,)* + &mut iwork_size, + &mut info, + ) + }; + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + let liwork = iwork_size[0].to_usize().unwrap(); + let mut iwork = vec![0; liwork]; + $( + let lrwork = $rwork[0].to_usize().unwrap(); + let mut $rwork = vec![Self::Real::zero(); lrwork]; + )* + unsafe { + $gelsd( + m, + n, + nrhs, + a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), + &mut singular_values, + rcond, + &mut rank, + &mut work, + lwork as i32, + $(&mut $rwork,)* + &mut iwork, + &mut info, + ); + } + info.as_lapack_result()?; + + // Skip a_t -> a transpose because A has been destroyed + // Re-transpose b + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } + Ok(LeastSquaresOutput { singular_values, rank, @@ -113,7 +161,7 @@ macro_rules! 
impl_least_squares { }; } -impl_least_squares!(f64, lapacke::dgelsd); -impl_least_squares!(f32, lapacke::sgelsd); -impl_least_squares!(c64, lapacke::zgelsd); -impl_least_squares!(c32, lapacke::cgelsd); +impl_least_squares!(@real, f64, lapack::dgelsd); +impl_least_squares!(@real, f32, lapack::sgelsd); +impl_least_squares!(@complex, c64, lapack::zgelsd); +impl_least_squares!(@complex, c32, lapack::cgelsd); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 9565dcab..5d1fb0eb 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -70,6 +70,7 @@ pub mod layout; pub mod least_squares; pub mod opnorm; pub mod qr; +pub mod rcond; pub mod solve; pub mod solveh; pub mod svd; @@ -83,6 +84,7 @@ pub use self::eigh::*; pub use self::least_squares::*; pub use self::opnorm::*; pub use self::qr::*; +pub use self::rcond::*; pub use self::solve::*; pub use self::solveh::*; pub use self::svd::*; @@ -107,6 +109,7 @@ pub trait Lapack: + Eigh_ + Triangular_ + Tridiagonal_ + + Rcond_ + LeastSquaresSvdDivideConquer_ { } @@ -124,6 +127,15 @@ pub enum UPLO { Lower = b'L', } +impl UPLO { + pub fn t(self) -> Self { + match self { + UPLO::Upper => UPLO::Lower, + UPLO::Lower => UPLO::Upper, + } + } +} + #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum Transpose { diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index 4786fd6e..0c594d92 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -2,30 +2,36 @@ use crate::layout::MatrixLayout; use cauchy::*; -use lapacke::Layout::ColumnMajor as cm; +use num_traits::Zero; pub use super::NormType; pub trait OperatorNorm_: Scalar { - unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real; + fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real; } macro_rules! 
impl_opnorm { ($scalar:ty, $lange:path) => { impl OperatorNorm_ for $scalar { - unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { - match l { - MatrixLayout::F { col, lda } => $lange(cm, t as u8, lda, col, a, lda), - MatrixLayout::C { row, lda } => { - $lange(cm, t.transpose() as u8, lda, row, a, lda) - } - } + fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { + let m = l.lda(); + let n = l.len(); + let t = match l { + MatrixLayout::F { .. } => t, + MatrixLayout::C { .. } => t.transpose(), + }; + let mut work = if matches!(t, NormType::Infinity) { + vec![Self::Real::zero(); m as usize] + } else { + Vec::new() + }; + unsafe { $lange(t as u8, m, n, a, m, &mut work) } } } }; } // impl_opnorm! -impl_opnorm!(f64, lapacke::dlange); -impl_opnorm!(f32, lapacke::slange); -impl_opnorm!(c64, lapacke::zlange); -impl_opnorm!(c32, lapacke::clange); +impl_opnorm!(f64, lapack::dlange); +impl_opnorm!(f32, lapack::slange); +impl_opnorm!(c64, lapack::zlange); +impl_opnorm!(c32, lapack::clange); diff --git a/lax/src/qr.rs b/lax/src/qr.rs index 6c26273d..0bb00c2a 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -2,35 +2,120 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; -use std::cmp::min; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*geqrf` and `*orgqr` (`*ungqr` for complex numbers) pub trait QR_: Sized { - unsafe fn householder(l: MatrixLayout, a: &mut [Self]) -> Result>; - unsafe fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; - unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; + /// Execute Householder reflection as the first step of QR-decomposition + /// + /// For C-continuous array, + /// this will call LQ-decomposition of the transposed matrix $ A^T = LQ^T $ + fn householder(l: MatrixLayout, a: &mut [Self]) -> Result>; + + /// Reconstruct Q-matrix from Householder-reflectors + fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; + + /// Execute 
QR-decomposition at once + fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; } macro_rules! impl_qr { - ($scalar:ty, $qrf:path, $gqr:path) => { + ($scalar:ty, $qrf:path, $lqf:path, $gqr:path, $glq:path) => { impl QR_ for $scalar { - unsafe fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { - let (row, col) = l.size(); - let k = min(row, col); + fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { + let m = l.lda(); + let n = l.len(); + let k = m.min(n); let mut tau = vec![Self::zero(); k as usize]; - $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau).as_lapack_result()?; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $qrf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + } + MatrixLayout::C { .. } => { + $lqf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + } + } + } + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $qrf( + m, + n, + &mut a, + m, + &mut tau, + &mut work, + lwork as i32, + &mut info, + ); + } + MatrixLayout::C { .. } => { + $lqf( + m, + n, + &mut a, + m, + &mut tau, + &mut work, + lwork as i32, + &mut info, + ); + } + } + } + info.as_lapack_result()?; + Ok(tau) } - unsafe fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { - let (row, col) = l.size(); - let k = min(row, col); - $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau).as_lapack_result()?; + fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { + let m = l.lda(); + let n = l.len(); + let k = m.min(n); + assert_eq!(tau.len(), k as usize); + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $gqr(m, k, k, &mut a, m, &tau, &mut work_size, -1, &mut info) + } + MatrixLayout::C { .. 
} => { + $glq(k, n, k, &mut a, m, &tau, &mut work_size, -1, &mut info) + } + } + }; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $gqr(m, k, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) + } + MatrixLayout::C { .. } => { + $glq(k, n, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) + } + } + } + info.as_lapack_result()?; Ok(()) } - unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result> { + fn qr(l: MatrixLayout, a: &mut [Self]) -> Result> { let tau = Self::householder(l, a)?; let r = Vec::from(&*a); Self::q(l, a, &tau)?; @@ -40,7 +125,31 @@ macro_rules! impl_qr { }; } // endmacro -impl_qr!(f64, lapacke::dgeqrf, lapacke::dorgqr); -impl_qr!(f32, lapacke::sgeqrf, lapacke::sorgqr); -impl_qr!(c64, lapacke::zgeqrf, lapacke::zungqr); -impl_qr!(c32, lapacke::cgeqrf, lapacke::cungqr); +impl_qr!( + f64, + lapack::dgeqrf, + lapack::dgelqf, + lapack::dorgqr, + lapack::dorglq +); +impl_qr!( + f32, + lapack::sgeqrf, + lapack::sgelqf, + lapack::sorgqr, + lapack::sorglq +); +impl_qr!( + c64, + lapack::zgeqrf, + lapack::zgelqf, + lapack::zungqr, + lapack::zunglq +); +impl_qr!( + c32, + lapack::cgeqrf, + lapack::cgelqf, + lapack::cungqr, + lapack::cunglq +); diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs new file mode 100644 index 00000000..135c4a12 --- /dev/null +++ b/lax/src/rcond.rs @@ -0,0 +1,86 @@ +use super::*; +use crate::{error::*, layout::MatrixLayout}; +use cauchy::*; +use num_traits::Zero; + +pub trait Rcond_: Scalar + Sized { + /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. + /// + /// `anorm` should be the 1-norm of the matrix `a`. + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; +} + +macro_rules! 
impl_rcond_real { + ($scalar:ty, $gecon:path) => { + impl Rcond_ for $scalar { + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + let (n, _) = l.size(); + let mut rcond = Self::Real::zero(); + let mut info = 0; + + let mut work = vec![Self::zero(); 4 * n as usize]; + let mut iwork = vec![0; n as usize]; + let norm_type = match l { + MatrixLayout::C { .. } => NormType::Infinity, + MatrixLayout::F { .. } => NormType::One, + } as u8; + unsafe { + $gecon( + norm_type, + n, + a, + l.lda(), + anorm, + &mut rcond, + &mut work, + &mut iwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + Ok(rcond) + } + } + }; +} + +impl_rcond_real!(f32, lapack::sgecon); +impl_rcond_real!(f64, lapack::dgecon); + +macro_rules! impl_rcond_complex { + ($scalar:ty, $gecon:path) => { + impl Rcond_ for $scalar { + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + let (n, _) = l.size(); + let mut rcond = Self::Real::zero(); + let mut info = 0; + let mut work = vec![Self::zero(); 2 * n as usize]; + let mut rwork = vec![Self::Real::zero(); 2 * n as usize]; + let norm_type = match l { + MatrixLayout::C { .. } => NormType::Infinity, + MatrixLayout::F { .. } => NormType::One, + } as u8; + unsafe { + $gecon( + norm_type, + n, + a, + l.lda(), + anorm, + &mut rcond, + &mut work, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + Ok(rcond) + } + } + }; +} + +impl_rcond_complex!(c32, lapack::cgecon); +impl_rcond_complex!(c64, lapack::zgecon); diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 67af6409..93aa4722 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -3,119 +3,99 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*getrf`, `*getri`, and `*getrs` pub trait Solve_: Scalar + Sized { /// Computes the LU factorization of a general `m x n` matrix `a` using /// partial pivoting with row interchanges. 
/// - /// If the result matches `Err(LinalgError::Lapack(LapackError { - /// return_code )) if return_code > 0`, then `U[(return_code-1, - /// return_code-1)]` is exactly zero. The factorization has been completed, - /// but the factor `U` is exactly singular, and division by zero will occur - /// if it is used to solve a system of equations. - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; - unsafe fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; - /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. + /// $ PA = LU $ /// - /// `anorm` should be the 1-norm of the matrix `a`. - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; - unsafe fn solve( - l: MatrixLayout, - t: Transpose, - a: &[Self], - p: &Pivot, - b: &mut [Self], - ) -> Result<()>; + /// Error + /// ------ + /// - `LapackComputationalFailure { return_code }` when the matrix is singular + /// - Division by zero will occur if it is used to solve a system of equations + /// because `U[(return_code-1, return_code-1)]` is exactly zero. + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; + + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! 
impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => { + ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { impl Solve_ for $scalar { - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); + assert_eq!(a.len() as i32, row * col); + if row == 0 || col == 0 { + // Do nothing for empty matrix + return Ok(Vec::new()); + } let k = ::std::cmp::min(row, col); let mut ipiv = vec![0; k as usize]; - $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv).as_lapack_result()?; + let mut info = 0; + unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) }; + info.as_lapack_result()?; Ok(ipiv) } - unsafe fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); - $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; - Ok(()) - } - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { - let (n, _) = l.size(); - let mut rcond = Self::Real::zero(); - $gecon( - l.lapacke_layout(), - NormType::One as u8, - n, - a, - l.lda(), - anorm, - &mut rcond, - ) - .as_lapack_result()?; - Ok(rcond) + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) }; + info.as_lapack_result()?; + + // actual + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $getri( + l.len(), + a, + l.lda(), + ipiv, + &mut work, + lwork as i32, + &mut info, + ) + }; + info.as_lapack_result()?; + + Ok(()) } - unsafe fn solve( + fn solve( l: MatrixLayout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self], ) -> Result<()> { + let t = match l { + MatrixLayout::C { .. 
} => match t { + Transpose::No => Transpose::Transpose, + Transpose::Transpose | Transpose::Hermite => Transpose::No, + }, + _ => t, + }; let (n, _) = l.size(); let nrhs = 1; - let ldb = 1; - $getrs( - l.lapacke_layout(), - t as u8, - n, - nrhs, - a, - l.lda(), - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + let ldb = l.lda(); + let mut info = 0; + unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; + info.as_lapack_result()?; Ok(()) } } }; } // impl_solve! -impl_solve!( - f64, - lapacke::dgetrf, - lapacke::dgetri, - lapacke::dgecon, - lapacke::dgetrs -); -impl_solve!( - f32, - lapacke::sgetrf, - lapacke::sgetri, - lapacke::sgecon, - lapacke::sgetrs -); -impl_solve!( - c64, - lapacke::zgetrf, - lapacke::zgetri, - lapacke::zgecon, - lapacke::zgetrs -); -impl_solve!( - c32, - lapacke::cgetrf, - lapacke::cgetri, - lapacke::cgecon, - lapacke::cgetrs -); +impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs); +impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs); +impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs); +impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs); diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 01e90f13..da2ecdf5 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -5,50 +5,73 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; +use num_traits::{ToPrimitive, Zero}; pub trait Solveh_: Sized { /// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf` - unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; /// Wrapper of `*sytri` and `*hetri` - unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; /// Wrapper of `*sytrs` and `*hetrs` - unsafe fn solveh( - l: MatrixLayout, - uplo: UPLO, - a: &[Self], - ipiv: &Pivot, - b: &mut [Self], - ) -> Result<()>; + 
fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_solveh { ($scalar:ty, $trf:path, $tri:path, $trs:path) => { impl Solveh_ for $scalar { - unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { let (n, _) = l.size(); let mut ipiv = vec![0; n as usize]; if n == 0 { - // Work around bug in LAPACKE functions. - Ok(ipiv) - } else { - $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv) - .as_lapack_result()?; - Ok(ipiv) + return Ok(Vec::new()); } + + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + $trf( + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + &mut work_size, + -1, + &mut info, + ) + }; + info.as_lapack_result()?; + + // actual + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $trf( + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + &mut work, + lwork as i32, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(ipiv) } - unsafe fn invh( - l: MatrixLayout, - uplo: UPLO, - a: &mut [Self], - ipiv: &Pivot, - ) -> Result<()> { + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv).as_lapack_result()?; + let mut info = 0; + let mut work = vec![Self::zero(); n as usize]; + unsafe { $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info) }; + info.as_lapack_result()?; Ok(()) } - unsafe fn solveh( + fn solveh( l: MatrixLayout, uplo: UPLO, a: &[Self], @@ -56,30 +79,16 @@ macro_rules! impl_solveh { b: &mut [Self], ) -> Result<()> { let (n, _) = l.size(); - let nrhs = 1; - let ldb = match l { - MatrixLayout::C { .. } => 1, - MatrixLayout::F { .. 
} => n, - }; - $trs( - l.lapacke_layout(), - uplo as u8, - n, - nrhs, - a, - l.lda(), - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + let mut info = 0; + unsafe { $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info) }; + info.as_lapack_result()?; Ok(()) } } }; } // impl_solveh! -impl_solveh!(f64, lapacke::dsytrf, lapacke::dsytri, lapacke::dsytrs); -impl_solveh!(f32, lapacke::ssytrf, lapacke::ssytri, lapacke::ssytrs); -impl_solveh!(c64, lapacke::zhetrf, lapacke::zhetri, lapacke::zhetrs); -impl_solveh!(c32, lapacke::chetrf, lapacke::chetri, lapacke::chetrs); +impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs); +impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs); +impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs); +impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs); diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 47a9a1be..d5e48ff5 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -2,9 +2,10 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; #[repr(u8)] +#[derive(Debug, Copy, Clone)] enum FlagSVD { All = b'A', // OverWrite = b'O', @@ -12,6 +13,16 @@ enum FlagSVD { No = b'N', } +impl FlagSVD { + fn from_bool(calc_uv: bool) -> Self { + if calc_uv { + FlagSVD::All + } else { + FlagSVD::No + } + } +} + /// Result of SVD pub struct SVDOutput { /// diagonal values @@ -24,65 +35,106 @@ pub struct SVDOutput { /// Wraps `*gesvd` pub trait SVD_: Scalar { - unsafe fn svd( - l: MatrixLayout, - calc_u: bool, - calc_vt: bool, - a: &mut [Self], - ) -> Result>; + /// Calculate singular value decomposition $ A = U \Sigma V^T $ + fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) + -> Result>; } macro_rules! 
impl_svd { - ($scalar:ty, $gesvd:path) => { + (@real, $scalar:ty, $gesvd:path) => { + impl_svd!(@body, $scalar, $gesvd, ); + }; + (@complex, $scalar:ty, $gesvd:path) => { + impl_svd!(@body, $scalar, $gesvd, rwork); + }; + (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => { impl SVD_ for $scalar { - unsafe fn svd( - l: MatrixLayout, - calc_u: bool, - calc_vt: bool, - mut a: &mut [Self], - ) -> Result> { - let (m, n) = l.size(); - let k = ::std::cmp::min(n, m); - let lda = l.lda(); - let (ju, ldu, mut u) = if calc_u { - (FlagSVD::All, m, vec![Self::zero(); (m * m) as usize]) - } else { - (FlagSVD::No, 1, Vec::new()) + fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self],) -> Result> { + let ju = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), + }; + let jvt = match l { + MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), + MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), }; - let (jvt, ldvt, mut vt) = if calc_vt { - (FlagSVD::All, n, vec![Self::zero(); (n * n) as usize]) - } else { - (FlagSVD::No, n, Vec::new()) + + let m = l.lda(); + let mut u = match ju { + FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]), + FlagSVD::No => None, }; + + let n = l.len(); + let mut vt = match jvt { + FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]), + FlagSVD::No => None, + }; + + let k = std::cmp::min(m, n); let mut s = vec![Self::Real::zero(); k as usize]; - let mut superb = vec![Self::Real::zero(); (k - 1) as usize]; - $gesvd( - l.lapacke_layout(), - ju as u8, - jvt as u8, - m, - n, - &mut a, - lda, - &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - &mut superb, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if calc_u { Some(u) } else { None }, - vt: if calc_vt { Some(vt) } else { None }, - }) + + $( + let mut $rwork_ident = vec![Self::Real::zero(); 5 * k as usize]; + )* + + // eval work size + let mut info = 0; + let mut work_size = 
[Self::zero()]; + unsafe { + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work_size, + -1, + $(&mut $rwork_ident,)* + &mut info, + ); + } + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $gesvd( + ju as u8, + jvt as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + n, + &mut work, + lwork as i32, + $(&mut $rwork_ident,)* + &mut info, + ); + } + info.as_lapack_result()?; + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } } } }; } // impl_svd! -impl_svd!(f64, lapacke::dgesvd); -impl_svd!(f32, lapacke::sgesvd); -impl_svd!(c64, lapacke::zgesvd); -impl_svd!(c32, lapacke::cgesvd); +impl_svd!(@real, f64, lapack::dgesvd); +impl_svd!(@real, f32, lapack::sgesvd); +impl_svd!(@complex, c64, lapack::zgesvd); +impl_svd!(@complex, c32, lapack::cgesvd); diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index 84f8394b..12ed129d 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -1,7 +1,7 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; /// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. /// @@ -18,59 +18,109 @@ pub enum UVTFlag { } pub trait SVDDC_: Scalar { - unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; + fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } -macro_rules! impl_svdd { - ($scalar:ty, $gesdd:path) => { +macro_rules! 
impl_svddc { + (@real, $scalar:ty, $gesdd:path) => { + impl_svddc!(@body, $scalar, $gesdd, ); + }; + (@complex, $scalar:ty, $gesdd:path) => { + impl_svddc!(@body, $scalar, $gesdd, rwork); + }; + (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { - unsafe fn svddc( - l: MatrixLayout, - jobz: UVTFlag, - mut a: &mut [Self], - ) -> Result> { - let (m, n) = l.size(); + fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self],) -> Result> { + let m = l.lda(); + let n = l.len(); let k = m.min(n); - let lda = l.lda(); - let (ucol, vtrow) = match jobz { - UVTFlag::Full => (m, n), + let mut s = vec![Self::Real::zero(); k as usize]; + + let (u_col, vt_row) = match jobz { + UVTFlag::Full | UVTFlag::None => (m, n), UVTFlag::Some => (k, k), - UVTFlag::None => (1, 1), }; - let mut s = vec![Self::Real::zero(); k.max(1) as usize]; - let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; - let ldu = l.resized(m, ucol).lda(); - let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; - let ldvt = l.resized(vtrow, n).lda(); - $gesdd( - l.lapacke_layout(), - jobz as u8, - m, - n, - &mut a, - lda, - &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if jobz == UVTFlag::None { None } else { Some(u) }, - vt: if jobz == UVTFlag::None { - None - } else { - Some(vt) - }, - }) + let (mut u, mut vt) = match jobz { + UVTFlag::Full => ( + Some(vec![Self::zero(); (m * m) as usize]), + Some(vec![Self::zero(); (n * n) as usize]), + ), + UVTFlag::Some => ( + Some(vec![Self::zero(); (m * u_col) as usize]), + Some(vec![Self::zero(); (n * vt_row) as usize]), + ), + UVTFlag::None => (None, None), + }; + + $( // for complex only + let mx = n.max(m) as usize; + let mn = n.min(m) as usize; + let lrwork = match jobz { + UVTFlag::None => 7 * mn, + _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), + }; + let mut $rwork_ident = vec![Self::Real::zero(); lrwork]; + )* + + // eval work size + let mut info 
= 0; + let mut iwork = vec![0; 8 * k as usize]; + let mut work_size = [Self::zero()]; + unsafe { + $gesdd( + jobz as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work_size, + -1, + $(&mut $rwork_ident,)* + &mut iwork, + &mut info, + ); + } + info.as_lapack_result()?; + + // do svd + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $gesdd( + jobz as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work, + lwork as i32, + $(&mut $rwork_ident,)* + &mut iwork, + &mut info, + ); + } + info.as_lapack_result()?; + + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } } } }; } -impl_svdd!(f32, lapacke::sgesdd); -impl_svdd!(f64, lapacke::dgesdd); -impl_svdd!(c32, lapacke::cgesdd); -impl_svdd!(c64, lapacke::zgesdd); +impl_svddc!(@real, f32, lapack::sgesdd); +impl_svddc!(@real, f64, lapack::dgesdd); +impl_svddc!(@complex, c32, lapack::cgesdd); +impl_svddc!(@complex, c64, lapack::zgesdd); diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index e725fa9e..5da50dde 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -1,8 +1,9 @@ //! 
Implement linear solver and inverse matrix use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; +use num_traits::Zero; #[derive(Debug, Clone, Copy)] #[repr(u8)] @@ -12,9 +13,8 @@ pub enum Diag { } /// Wraps `*trtri` and `*trtrs` -pub trait Triangular_: Sized { - unsafe fn inv_triangular(l: MatrixLayout, uplo: UPLO, d: Diag, a: &mut [Self]) -> Result<()>; - unsafe fn solve_triangular( +pub trait Triangular_: Scalar { + fn solve_triangular( al: MatrixLayout, bl: MatrixLayout, uplo: UPLO, @@ -27,50 +27,66 @@ pub trait Triangular_: Sized { macro_rules! impl_triangular { ($scalar:ty, $trtri:path, $trtrs:path) => { impl Triangular_ for $scalar { - unsafe fn inv_triangular( - l: MatrixLayout, - uplo: UPLO, - diag: Diag, - a: &mut [Self], - ) -> Result<()> { - let (n, _) = l.size(); - let lda = l.lda(); - $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda).as_lapack_result()?; - Ok(()) - } - - unsafe fn solve_triangular( - al: MatrixLayout, - bl: MatrixLayout, + fn solve_triangular( + a_layout: MatrixLayout, + b_layout: MatrixLayout, uplo: UPLO, diag: Diag, a: &[Self], - mut b: &mut [Self], + b: &mut [Self], ) -> Result<()> { - let (n, _) = al.size(); - let lda = al.lda(); - let (_, nrhs) = bl.size(); - let ldb = bl.lda(); - $trtrs( - al.lapacke_layout(), - uplo as u8, - Transpose::No as u8, - diag as u8, - n, - nrhs, - a, - lda, - &mut b, - ldb, - ) - .as_lapack_result()?; + // Transpose if a is C-continuous + let mut a_t = None; + let a_layout = match a_layout { + MatrixLayout::C { .. } => { + a_t = Some(vec![Self::zero(); a.len()]); + transpose(a_layout, a, a_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => a_layout, + }; + + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. 
} => b_layout, + }; + + let (m, n) = a_layout.size(); + let (n_, nrhs) = b_layout.size(); + assert_eq!(n, n_); + + let mut info = 0; + unsafe { + $trtrs( + uplo as u8, + Transpose::No as u8, + diag as u8, + m, + nrhs, + a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), + &mut info, + ); + } + info.as_lapack_result()?; + + // Re-transpose b + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } Ok(()) } } }; } // impl_triangular! -impl_triangular!(f64, lapacke::dtrtri, lapacke::dtrtrs); -impl_triangular!(f32, lapacke::strtri, lapacke::strtrs); -impl_triangular!(c64, lapacke::ztrtri, lapacke::ztrtrs); -impl_triangular!(c32, lapacke::ctrtri, lapacke::ctrtrs); +impl_triangular!(f64, lapack::dtrtri, lapack::dtrtrs); +impl_triangular!(f32, lapack::strtri, lapack::strtrs); +impl_triangular!(c64, lapack::ztrtri, lapack::ztrtrs); +impl_triangular!(c32, lapack::ctrtri, lapack::ctrtrs); diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index 4eb8ff13..ea5bb119 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -2,7 +2,7 @@ //! for tridiagonal matrix use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; use num_traits::Zero; use std::ops::{Index, IndexMut}; @@ -130,11 +130,11 @@ impl IndexMut<[i32; 2]> for Tridiagonal { pub trait Tridiagonal_: Scalar + Sized { /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using /// partial pivoting with row interchanges. 
- unsafe fn lu_tridiagonal(a: Tridiagonal) -> Result>; + fn lu_tridiagonal(a: Tridiagonal) -> Result>; - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; - unsafe fn solve_tridiagonal( + fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, bl: MatrixLayout, t: Transpose, @@ -143,18 +143,23 @@ pub trait Tridiagonal_: Scalar + Sized { } macro_rules! impl_tridiagonal { - ($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + (@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork); + }; + (@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, ); + }; + (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => { impl Tridiagonal_ for $scalar { - unsafe fn lu_tridiagonal( - mut a: Tridiagonal, - ) -> Result> { + fn lu_tridiagonal(mut a: Tridiagonal) -> Result> { let (n, _) = a.l.size(); let mut du2 = vec![Zero::zero(); (n - 2) as usize]; let mut ipiv = vec![0; n as usize]; // We have to calc one-norm before LU factorization let a_opnorm_one = a.opnorm_one(); - $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv) - .as_lapack_result()?; + let mut info = 0; + unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) }; + info.as_lapack_result()?; Ok(LUFactorizedTridiagonal { a, du2, @@ -163,56 +168,80 @@ macro_rules! 
impl_tridiagonal { }) } - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; + let mut work = vec![Self::zero(); 2 * n as usize]; + $( + let mut $iwork = vec![0; n as usize]; + )* let mut rcond = Self::Real::zero(); - $gtcon( - NormType::One as u8, - n, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - lu.a_opnorm_one, - &mut rcond, - ) - .as_lapack_result()?; + let mut info = 0; + unsafe { + $gtcon( + NormType::One as u8, + n, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + lu.a_opnorm_one, + &mut rcond, + &mut work, + $(&mut $iwork,)* + &mut info, + ); + } + info.as_lapack_result()?; Ok(rcond) } - unsafe fn solve_tridiagonal( + fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, - bl: MatrixLayout, + b_layout: MatrixLayout, t: Transpose, b: &mut [Self], ) -> Result<()> { let (n, _) = lu.a.l.size(); - let (_, nrhs) = bl.size(); let ipiv = &lu.ipiv; - let ldb = bl.lda(); - $gttrs( - lu.a.l.lapacke_layout(), - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => b_layout, + }; + let (ldb, nrhs) = b_layout.size(); + let mut info = 0; + unsafe { + $gttrs( + t as u8, + n, + nrhs, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + ldb, + &mut info, + ); + } + info.as_lapack_result()?; + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } Ok(()) } } }; } // impl_tridiagonal! 
-impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs); -impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs); -impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs); -impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs); +impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs); +impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs); +impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs); +impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs); diff --git a/ndarray-linalg/src/cholesky.rs b/ndarray-linalg/src/cholesky.rs index 79e240ab..3f445305 100644 --- a/ndarray-linalg/src/cholesky.rs +++ b/ndarray-linalg/src/cholesky.rs @@ -155,7 +155,7 @@ where fn invc_into(self) -> Result { let mut a = self.factor; - unsafe { A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)? }; + A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)?; triangular_fill_hermitian(&mut a, self.uplo); Ok(a) } @@ -173,14 +173,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_cholesky( - self.factor.square_layout()?, - self.uplo, - self.factor.as_allocated()?, - b.as_slice_mut().unwrap(), - )? - }; + A::solve_cholesky( + self.factor.square_layout()?, + self.uplo, + self.factor.as_allocated()?, + b.as_slice_mut().unwrap(), + )?; Ok(b) } } @@ -259,7 +257,7 @@ where S: DataMut, { fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> { - unsafe { A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)? 
}; + A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?; Ok(self.into_triangular(uplo)) } } diff --git a/ndarray-linalg/src/eig.rs b/ndarray-linalg/src/eig.rs index e9f09080..17f5a1e8 100644 --- a/ndarray-linalg/src/eig.rs +++ b/ndarray-linalg/src/eig.rs @@ -11,6 +11,29 @@ pub trait Eig { type EigVal; type EigVec; /// Calculate eigenvalues with the right eigenvector + /// + /// $$ A u_i = \lambda_i u_i $$ + /// + /// ``` + /// use ndarray::*; + /// use ndarray_linalg::*; + /// + /// let a: Array2 = array![ + /// [-1.01, 0.86, -4.60, 3.31, -4.81], + /// [ 3.98, 0.53, -7.04, 5.29, 3.55], + /// [ 3.30, 8.26, -3.89, 8.20, -1.51], + /// [ 4.43, 4.96, -7.66, -7.33, 6.18], + /// [ 7.31, -6.43, -6.16, 2.47, 5.58], + /// ]; + /// let (eigs, vecs) = a.eig().unwrap(); + /// + /// let a = a.map(|v| v.as_c()); + /// for (&e, vec) in eigs.iter().zip(vecs.axis_iter(Axis(1))) { + /// let ev = vec.map(|v| v * e); + /// let av = a.dot(&vec); + /// assert_close_l2!(&av, &ev, 1e-5); + /// } + /// ``` fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)>; } @@ -25,13 +48,11 @@ where fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)> { let mut a = self.to_owned(); let layout = a.square_layout()?; - let (s, t) = unsafe { A::eig(true, layout, a.as_allocated_mut()?)? }; - let (n, _) = layout.size(); + let (s, t) = A::eig(true, layout, a.as_allocated_mut()?)?; + let n = layout.len() as usize; Ok(( ArrayBase::from(s), - ArrayBase::from(t) - .into_shape((n as usize, n as usize)) - .unwrap(), + Array2::from_shape_vec((n, n).f(), t).unwrap(), )) } } @@ -51,7 +72,7 @@ where fn eigvals(&self) -> Result { let mut a = self.to_owned(); - let (s, _) = unsafe { A::eig(true, a.square_layout()?, a.as_allocated_mut()?)? 
}; + let (s, _) = A::eig(true, a.square_layout()?, a.as_allocated_mut()?)?; Ok(ArrayBase::from(s)) } } diff --git a/ndarray-linalg/src/eigh.rs b/ndarray-linalg/src/eigh.rs index b0438a99..86f1fb46 100644 --- a/ndarray-linalg/src/eigh.rs +++ b/ndarray-linalg/src/eigh.rs @@ -99,7 +99,7 @@ where MatrixLayout::C { .. } => self.swap_axes(0, 1), MatrixLayout::F { .. } => {} } - let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? }; + let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?; Ok((ArrayBase::from(s), self)) } } @@ -126,15 +126,13 @@ where MatrixLayout::F { .. } => {} } - let s = unsafe { - A::eigh_generalized( - true, - self.0.square_layout()?, - uplo, - self.0.as_allocated_mut()?, - self.1.as_allocated_mut()?, - )? - }; + let s = A::eigh_generalized( + true, + self.0.square_layout()?, + uplo, + self.0.as_allocated_mut()?, + self.1.as_allocated_mut()?, + )?; Ok((ArrayBase::from(s), self)) } @@ -191,7 +189,7 @@ where type EigVal = Array1; fn eigvalsh_inplace(&mut self, uplo: UPLO) -> Result { - let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? }; + let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?; Ok(ArrayBase::from(s)) } } diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index 18d2033f..0ff518ad 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -76,6 +76,7 @@ use crate::types::*; /// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix /// (which can be seen as solving `Ax = b` k times for different b) and /// the solution is a `m x k` matrix. 
+#[derive(Debug, Clone)] pub struct LeastSquaresResult { /// The singular values of the matrix A in `Ax = b` pub singular_values: Array1, @@ -266,6 +267,9 @@ where &mut self, rhs: &mut ArrayBase, ) -> Result> { + if self.shape()[0] != rhs.shape()[0] { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); + } let (m, n) = (self.shape()[0], self.shape()[1]); if n > m { // we need a new rhs b/c it will be overwritten with the solution @@ -284,21 +288,19 @@ fn compute_least_squares_srhs( rhs: &mut ArrayBase, ) -> Result> where - E: Scalar + Lapack + LeastSquaresSvdDivideConquer_, + E: Scalar + Lapack, D1: DataMut, D2: DataMut, { let LeastSquaresOutput:: { singular_values, rank, - } = unsafe { - ::least_squares( - a.layout()?, - a.as_allocated_mut()?, - rhs.as_slice_memory_order_mut() - .ok_or_else(|| LinalgError::MemoryNotCont)?, - )? - }; + } = E::least_squares( + a.layout()?, + a.as_allocated_mut()?, + rhs.as_slice_memory_order_mut() + .ok_or_else(|| LinalgError::MemoryNotCont)?, + )?; let (m, n) = (a.shape()[0], a.shape()[1]); let solution = rhs.slice(s![0..n]).to_owned(); @@ -347,6 +349,9 @@ where &mut self, rhs: &mut ArrayBase, ) -> Result> { + if self.shape()[0] != rhs.shape()[0] { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); + } let (m, n) = (self.shape()[0], self.shape()[1]); if n > m { // we need a new rhs b/c it will be overwritten with the solution @@ -378,14 +383,12 @@ where let LeastSquaresOutput:: { singular_values, rank, - } = unsafe { - E::least_squares_nrhs( - a_layout, - a.as_allocated_mut()?, - rhs_layout, - rhs.as_allocated_mut()?, - )? 
- }; + } = E::least_squares_nrhs( + a_layout, + a.as_allocated_mut()?, + rhs_layout, + rhs.as_allocated_mut()?, + )?; let solution: Array2 = rhs.slice(s![..a.shape()[1], ..]).to_owned(); let singular_values = Array::from_shape_vec((singular_values.len(),), singular_values)?; @@ -549,28 +552,13 @@ mod tests { // // Testing error cases // - #[test] fn incompatible_shape_error_on_mismatching_num_rows() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let b: Array1 = array![1., 2.]; - let res = a.least_squares(&b); - match res { - Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {} - _ => panic!("Expected Err()"), - } - } - - #[test] - fn incompatible_shape_error_on_mismatching_layout() { - let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; - let b = array![[1.], [2.]].t().to_owned(); - assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 }); - - let res = a.least_squares(&b); - match res { - Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {} - _ => panic!("Expected Err()"), + match a.least_squares(&b) { + Err(LinalgError::Shape(e)) if e.kind() == ErrorKind::IncompatibleShape => {} + _ => panic!("Should be raise IncompatibleShape"), } } } diff --git a/ndarray-linalg/src/opnorm.rs b/ndarray-linalg/src/opnorm.rs index 4358fd7f..b4956867 100644 --- a/ndarray-linalg/src/opnorm.rs +++ b/ndarray-linalg/src/opnorm.rs @@ -45,7 +45,7 @@ where fn opnorm(&self, t: NormType) -> Result { let l = self.layout()?; let a = self.as_allocated()?; - Ok(unsafe { A::opnorm(t, l, a) }) + Ok(A::opnorm(t, l, a)) } } @@ -108,6 +108,6 @@ where let l = arr.layout()?; let a = arr.as_allocated()?; - Ok(unsafe { A::opnorm(t, l, a) }) + Ok(A::opnorm(t, l, a)) } } diff --git a/ndarray-linalg/src/qr.rs b/ndarray-linalg/src/qr.rs index be2de0c2..ae7b2c25 100644 --- a/ndarray-linalg/src/qr.rs +++ b/ndarray-linalg/src/qr.rs @@ -61,7 +61,7 @@ where fn qr_square_inplace(&mut self) -> Result<(&mut Self, Self::R)> { let 
l = self.square_layout()?; - let r = unsafe { A::qr(l, self.as_allocated_mut()?)? }; + let r = A::qr(l, self.as_allocated_mut()?)?; let r: Array2<_> = into_matrix(l, r)?; Ok((self, r.into_triangular(UPLO::Upper))) } @@ -107,7 +107,7 @@ where let m = self.ncols(); let k = ::std::cmp::min(n, m); let l = self.layout()?; - let r = unsafe { A::qr(l, self.as_allocated_mut()?)? }; + let r = A::qr(l, self.as_allocated_mut()?)?; let r: Array2<_> = into_matrix(l, r)?; let q = self; Ok((take_slice(&q, n, k), take_slice_upper(&r, k, m))) diff --git a/ndarray-linalg/src/solve.rs b/ndarray-linalg/src/solve.rs index 566511f3..fd4b3017 100644 --- a/ndarray-linalg/src/solve.rs +++ b/ndarray-linalg/src/solve.rs @@ -167,15 +167,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::No, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::No, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_t_inplace<'a, Sb>( @@ -185,15 +183,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::Transpose, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::Transpose, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_h_inplace<'a, Sb>( @@ -203,15 +199,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::Hermite, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::Hermite, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } } @@ -273,7 +267,7 @@ where S: DataMut + RawDataClone, { fn factorize_into(mut self) -> Result> { - let ipiv = unsafe { A::lu(self.layout()?, self.as_allocated_mut()?)? 
}; + let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?; Ok(LUFactorized { a: self, ipiv }) } } @@ -285,7 +279,7 @@ where { fn factorize(&self) -> Result>> { let mut a: Array2 = replicate(self); - let ipiv = unsafe { A::lu(a.layout()?, a.as_allocated_mut()?)? }; + let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?; Ok(LUFactorized { a, ipiv }) } } @@ -312,13 +306,11 @@ where type Output = ArrayBase; fn inv_into(mut self) -> Result> { - unsafe { - A::inv( - self.a.square_layout()?, - self.a.as_allocated_mut()?, - &self.ipiv, - )? - }; + A::inv( + self.a.square_layout()?, + self.a.as_allocated_mut()?, + &self.ipiv, + )?; Ok(self.a) } } @@ -539,13 +531,11 @@ where S: Data + RawDataClone, { fn rcond(&self) -> Result { - unsafe { - Ok(A::rcond( - self.a.layout()?, - self.a.as_allocated()?, - self.a.opnorm_one()?, - )?) - } + Ok(A::rcond( + self.a.layout()?, + self.a.as_allocated()?, + self.a.opnorm_one()?, + )?) } } diff --git a/ndarray-linalg/src/solveh.rs b/ndarray-linalg/src/solveh.rs index c05e2bde..102158aa 100644 --- a/ndarray-linalg/src/solveh.rs +++ b/ndarray-linalg/src/solveh.rs @@ -113,15 +113,13 @@ where where Sb: DataMut, { - unsafe { - A::solveh( - self.a.square_layout()?, - UPLO::Upper, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solveh( + self.a.square_layout()?, + UPLO::Upper, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } } @@ -165,7 +163,7 @@ where S: DataMut, { fn factorizeh_into(mut self) -> Result> { - let ipiv = unsafe { A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)? }; + let ipiv = A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)?; Ok(BKFactorized { a: self, ipiv }) } } @@ -177,7 +175,7 @@ where { fn factorizeh(&self) -> Result>> { let mut a: Array2 = replicate(self); - let ipiv = unsafe { A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)? 
}; + let ipiv = A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)?; Ok(BKFactorized { a, ipiv }) } } @@ -204,14 +202,12 @@ where type Output = ArrayBase; fn invh_into(mut self) -> Result> { - unsafe { - A::invh( - self.a.square_layout()?, - UPLO::Upper, - self.a.as_allocated_mut()?, - &self.ipiv, - )? - }; + A::invh( + self.a.square_layout()?, + UPLO::Upper, + self.a.as_allocated_mut()?, + &self.ipiv, + )?; triangular_fill_hermitian(&mut self.a, UPLO::Upper); Ok(self.a) } @@ -314,6 +310,7 @@ where S: Data, A: Scalar + Lapack, { + let layout = a.layout().unwrap(); let mut sign = A::Real::one(); let mut ln_det = A::Real::zero(); let mut ipiv_enum = ipiv_iter.enumerate(); @@ -337,9 +334,15 @@ where debug_assert_eq!(lower_diag.im(), Zero::zero()); // Off-diagonal elements, can be complex. - let off_diag = match uplo { - UPLO::Upper => unsafe { a.uget((k, k + 1)) }, - UPLO::Lower => unsafe { a.uget((k + 1, k)) }, + let off_diag = match layout { + MatrixLayout::C { .. } => match uplo { + UPLO::Upper => unsafe { a.uget((k + 1, k)) }, + UPLO::Lower => unsafe { a.uget((k, k + 1)) }, + }, + MatrixLayout::F { .. } => match uplo { + UPLO::Upper => unsafe { a.uget((k, k + 1)) }, + UPLO::Lower => unsafe { a.uget((k + 1, k)) }, + }, }; // Determinant of 2x2 block. diff --git a/ndarray-linalg/src/svd.rs b/ndarray-linalg/src/svd.rs index 9bb90977..5c9d59b1 100644 --- a/ndarray-linalg/src/svd.rs +++ b/ndarray-linalg/src/svd.rs @@ -2,13 +2,9 @@ //! //! [Wikipedia article on SVD](https://en.wikipedia.org/wiki/Singular_value_decomposition) +use crate::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -use super::convert::*; -use super::error::*; -use super::layout::*; -use super::types::*; - /// singular-value decomposition of matrix reference pub trait SVD { type U; @@ -97,14 +93,13 @@ where calc_vt: bool, ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; - let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? 
}; + let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?; let (n, m) = l.size(); - let u = svd_res - .u - .map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches")); + + let u = svd_res.u.map(|u| into_matrix(l.resized(n, n), u).unwrap()); let vt = svd_res .vt - .map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches")); + .map(|vt| into_matrix(l.resized(m, m), vt).unwrap()); let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) } diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index 22f3ae0c..c79fa0e2 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -1,12 +1,8 @@ //! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd) +use super::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -use super::convert::*; -use super::error::*; -use super::layout::*; -use super::types::*; - pub use lapack::svddc::UVTFlag; /// Singular-value decomposition of matrix (copying) by divide-and-conquer @@ -84,20 +80,24 @@ where uvt_flag: UVTFlag, ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; - let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? 
}; + let svd_res = A::svddc(l, uvt_flag, self.as_allocated_mut()?)?; let (m, n) = l.size(); let k = m.min(n); - let (ldu, tdu, ldvt, tdvt) = match uvt_flag { - UVTFlag::Full => (m, m, n, n), - UVTFlag::Some => (m, k, k, n), - UVTFlag::None => (1, 1, 1, 1), + + let (u_col, vt_row) = match uvt_flag { + UVTFlag::Full => (m, n), + UVTFlag::Some => (k, k), + UVTFlag::None => (0, 0), }; + let u = svd_res .u - .map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches")); + .map(|u| into_matrix(l.resized(m, u_col), u).unwrap()); + let vt = svd_res .vt - .map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches")); + .map(|vt| into_matrix(l.resized(vt_row, n), vt).unwrap()); + let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) } diff --git a/ndarray-linalg/src/triangular.rs b/ndarray-linalg/src/triangular.rs index c54beafd..b378dcd2 100644 --- a/ndarray-linalg/src/triangular.rs +++ b/ndarray-linalg/src/triangular.rs @@ -85,7 +85,7 @@ where transpose_data(b)?; } let lb = b.layout()?; - unsafe { A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)? }; + A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?; Ok(b) } } diff --git a/ndarray-linalg/src/tridiagonal.rs b/ndarray-linalg/src/tridiagonal.rs index 1fb0d558..92eae05b 100644 --- a/ndarray-linalg/src/tridiagonal.rs +++ b/ndarray-linalg/src/tridiagonal.rs @@ -271,14 +271,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_tridiagonal( - &self, - rhs.layout()?, - Transpose::No, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::No, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_t_tridiagonal_inplace<'a, Sb>( @@ -288,14 +286,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_tridiagonal( - &self, - rhs.layout()?, - Transpose::Transpose, - rhs.as_slice_mut().unwrap(), - )? 
- }; + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::Transpose, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_h_tridiagonal_inplace<'a, Sb>( @@ -305,14 +301,12 @@ where where Sb: DataMut, { - unsafe { - A::solve_tridiagonal( - &self, - rhs.layout()?, - Transpose::Hermite, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::Hermite, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } } @@ -566,7 +560,7 @@ where A: Scalar + Lapack, { fn factorize_tridiagonal_into(self) -> Result> { - Ok(unsafe { A::lu_tridiagonal(self)? }) + Ok(A::lu_tridiagonal(self)?) } } @@ -576,7 +570,7 @@ where { fn factorize_tridiagonal(&self) -> Result> { let a = self.clone(); - Ok(unsafe { A::lu_tridiagonal(a)? }) + Ok(A::lu_tridiagonal(a)?) } } @@ -587,7 +581,7 @@ where { fn factorize_tridiagonal(&self) -> Result> { let a = self.extract_tridiagonal()?; - Ok(unsafe { A::lu_tridiagonal(a)? }) + Ok(A::lu_tridiagonal(a)?) } } @@ -677,7 +671,7 @@ where A: Scalar + Lapack, { fn rcond_tridiagonal(&self) -> Result { - unsafe { Ok(A::rcond_tridiagonal(&self)?) } + Ok(A::rcond_tridiagonal(&self)?) } } diff --git a/ndarray-linalg/tests/cholesky.rs b/ndarray-linalg/tests/cholesky.rs index accdbdf8..b45afb5c 100644 --- a/ndarray-linalg/tests/cholesky.rs +++ b/ndarray-linalg/tests/cholesky.rs @@ -1,216 +1,228 @@ use ndarray::*; use ndarray_linalg::*; -#[test] -fn cholesky() { - macro_rules! cholesky { - ($elem:ty, $rtol:expr) => { - let a_orig: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a_orig); +macro_rules! cholesky { + ($elem:ty, $rtol:expr) => { + paste::item! 
{ + #[test] + fn []() { + let a_orig: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a_orig); - let upper = a_orig.cholesky(UPLO::Upper).unwrap(); - assert_close_l2!( - &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); - - let lower = a_orig.cholesky(UPLO::Lower).unwrap(); - assert_close_l2!( - &lower.dot(&lower.t().mapv(|elem| elem.conj())), - &a_orig, - $rtol - ); - - let a: Array2<$elem> = replicate(&a_orig); - let upper = a.cholesky_into(UPLO::Upper).unwrap(); - assert_close_l2!( - &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); + let upper = a_orig.cholesky(UPLO::Upper).unwrap(); + assert_close_l2!( + &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); - let a: Array2<$elem> = replicate(&a_orig); - let lower = a.cholesky_into(UPLO::Lower).unwrap(); - assert_close_l2!( - &lower.dot(&lower.t().mapv(|elem| elem.conj())), - &a_orig, - $rtol - ); + let lower = a_orig.cholesky(UPLO::Lower).unwrap(); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); - let mut a: Array2<$elem> = replicate(&a_orig); - { - let upper = a.cholesky_inplace(UPLO::Upper).unwrap(); + let a: Array2<$elem> = replicate(&a_orig); + let upper = a.cholesky_into(UPLO::Upper).unwrap(); assert_close_l2!( &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), &a_orig, $rtol ); - } - assert_close_l2!( - &a.t().mapv(|elem| elem.conj()).dot(&upper.view()), - &a_orig, - $rtol - ); - let mut a: Array2<$elem> = replicate(&a_orig); - { - let lower = a.cholesky_inplace(UPLO::Lower).unwrap(); + let a: Array2<$elem> = replicate(&a_orig); + let lower = a.cholesky_into(UPLO::Lower).unwrap(); assert_close_l2!( &lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol ); + + let mut a: Array2<$elem> = replicate(&a_orig); + { + let upper = a.cholesky_inplace(UPLO::Upper).unwrap(); + assert_close_l2!( + &upper.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, 
+ $rtol + ); + } + assert_close_l2!( + &a.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); + + let mut a: Array2<$elem> = replicate(&a_orig); + { + let lower = a.cholesky_inplace(UPLO::Lower).unwrap(); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); + } + assert_close_l2!(&a.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); } - assert_close_l2!(&a.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); - }; - } - cholesky!(f64, 1e-9); - cholesky!(f32, 1e-5); - cholesky!(c64, 1e-9); - cholesky!(c32, 1e-5); + } // paste::item + }; } -#[test] -fn cholesky_into_lower_upper() { - macro_rules! cholesky_into_lower_upper { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let upper = a.cholesky(UPLO::Upper).unwrap(); - let fac_upper = a.factorizec(UPLO::Upper).unwrap(); - let fac_lower = a.factorizec(UPLO::Lower).unwrap(); - assert_close_l2!(&upper, &fac_lower.into_upper(), $rtol); - assert_close_l2!(&upper, &fac_upper.into_upper(), $rtol); - let lower = a.cholesky(UPLO::Lower).unwrap(); - let fac_upper = a.factorizec(UPLO::Upper).unwrap(); - let fac_lower = a.factorizec(UPLO::Lower).unwrap(); - assert_close_l2!(&lower, &fac_lower.into_lower(), $rtol); - assert_close_l2!(&lower, &fac_upper.into_lower(), $rtol); - }; - } - cholesky_into_lower_upper!(f64, 1e-9); - cholesky_into_lower_upper!(f32, 1e-5); - cholesky_into_lower_upper!(c64, 1e-9); - cholesky_into_lower_upper!(c32, 1e-5); -} +cholesky!(f64, 1e-9); +cholesky!(f32, 1e-5); +cholesky!(c64, 1e-9); +cholesky!(c32, 1e-5); -#[test] -fn cholesky_inverse() { - macro_rules! 
cholesky_into_inverse { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let inv = a.invc().unwrap(); - assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol); - let inv_into: Array2<$elem> = replicate(&a).invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_into), &Array2::eye(3), $rtol); - let inv_upper = a.factorizec(UPLO::Upper).unwrap().invc().unwrap(); - assert_close_l2!(&a.dot(&inv_upper), &Array2::eye(3), $rtol); - let inv_upper_into = a.factorizec(UPLO::Upper).unwrap().invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_upper_into), &Array2::eye(3), $rtol); - let inv_lower = a.factorizec(UPLO::Lower).unwrap().invc().unwrap(); - assert_close_l2!(&a.dot(&inv_lower), &Array2::eye(3), $rtol); - let inv_lower_into = a.factorizec(UPLO::Lower).unwrap().invc_into().unwrap(); - assert_close_l2!(&a.dot(&inv_lower_into), &Array2::eye(3), $rtol); - }; - } - cholesky_into_inverse!(f64, 1e-9); - cholesky_into_inverse!(f32, 1e-3); - cholesky_into_inverse!(c64, 1e-9); - cholesky_into_inverse!(c32, 1e-3); +macro_rules! cholesky_into_lower_upper { + ($elem:ty, $rtol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let upper = a.cholesky(UPLO::Upper).unwrap(); + let fac_upper = a.factorizec(UPLO::Upper).unwrap(); + let fac_lower = a.factorizec(UPLO::Lower).unwrap(); + assert_close_l2!(&upper, &fac_lower.into_upper(), $rtol); + assert_close_l2!(&upper, &fac_upper.into_upper(), $rtol); + let lower = a.cholesky(UPLO::Lower).unwrap(); + let fac_upper = a.factorizec(UPLO::Upper).unwrap(); + let fac_lower = a.factorizec(UPLO::Lower).unwrap(); + assert_close_l2!(&lower, &fac_lower.into_lower(), $rtol); + assert_close_l2!(&lower, &fac_upper.into_lower(), $rtol); + } + } + }; } -#[test] -fn cholesky_det() { - macro_rules! 
cholesky_det { - ($elem:ty, $atol:expr) => { - let a: Array2<$elem> = random_hpd(3); - println!("a = \n{:?}", a); - let ln_det = a - .eigvalsh(UPLO::Upper) - .unwrap() - .mapv(|elem| elem.ln()) - .scalar_sum(); - let det = ln_det.exp(); - assert_aclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $atol); - assert_aclose!(a.factorizec(UPLO::Upper).unwrap().ln_detc(), ln_det, $atol); - assert_aclose!(a.factorizec(UPLO::Lower).unwrap().detc_into(), det, $atol); - assert_aclose!( - a.factorizec(UPLO::Lower).unwrap().ln_detc_into(), - ln_det, - $atol - ); - assert_aclose!(a.detc().unwrap(), det, $atol); - assert_aclose!(a.ln_detc().unwrap(), ln_det, $atol); - assert_aclose!(a.clone().detc_into().unwrap(), det, $atol); - assert_aclose!(a.ln_detc_into().unwrap(), ln_det, $atol); - }; - } - cholesky_det!(f64, 1e-9); - cholesky_det!(f32, 1e-3); - cholesky_det!(c64, 1e-9); - cholesky_det!(c32, 1e-3); +cholesky_into_lower_upper!(f64, 1e-9); +cholesky_into_lower_upper!(f32, 1e-5); +cholesky_into_lower_upper!(c64, 1e-9); +cholesky_into_lower_upper!(c32, 1e-5); + +macro_rules! cholesky_into_inverse { + ($elem:ty, $rtol:expr) => { + paste::item! 
{ + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let inv = a.invc().unwrap(); + assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol); + let inv_into: Array2<$elem> = replicate(&a).invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_into), &Array2::eye(3), $rtol); + let inv_upper = a.factorizec(UPLO::Upper).unwrap().invc().unwrap(); + assert_close_l2!(&a.dot(&inv_upper), &Array2::eye(3), $rtol); + let inv_upper_into = a.factorizec(UPLO::Upper).unwrap().invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_upper_into), &Array2::eye(3), $rtol); + let inv_lower = a.factorizec(UPLO::Lower).unwrap().invc().unwrap(); + assert_close_l2!(&a.dot(&inv_lower), &Array2::eye(3), $rtol); + let inv_lower_into = a.factorizec(UPLO::Lower).unwrap().invc_into().unwrap(); + assert_close_l2!(&a.dot(&inv_lower_into), &Array2::eye(3), $rtol); + } + } + }; } +cholesky_into_inverse!(f64, 1e-9); +cholesky_into_inverse!(f32, 1e-3); +cholesky_into_inverse!(c64, 1e-9); +cholesky_into_inverse!(c32, 1e-3); -#[test] -fn cholesky_solve() { - macro_rules! cholesky_solve { - ($elem:ty, $rtol:expr) => { - let a: Array2<$elem> = random_hpd(3); - let x: Array1<$elem> = random(3); - let b = a.dot(&x); - println!("a = \n{:?}", a); - println!("x = \n{:?}", x); - assert_close_l2!(&a.solvec(&b).unwrap(), &x, $rtol); - assert_close_l2!(&a.solvec_into(b.clone()).unwrap(), &x, $rtol); - assert_close_l2!(&a.solvec_inplace(&mut b.clone()).unwrap(), &x, $rtol); - assert_close_l2!( - &a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Upper) - .unwrap() - .solvec_into(b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower) - .unwrap() - .solvec_into(b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Upper) +macro_rules! 
cholesky_det { + ($elem:ty, $atol:expr) => { + paste::item! { + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + println!("a = \n{:?}", a); + let ln_det = a + .eigvalsh(UPLO::Upper) .unwrap() - .solvec_inplace(&mut b.clone()) - .unwrap(), - &x, - $rtol - ); - assert_close_l2!( - &a.factorizec(UPLO::Lower) - .unwrap() - .solvec_inplace(&mut b.clone()) - .unwrap(), - &x, - $rtol - ); - }; - } - cholesky_solve!(f64, 1e-9); - cholesky_solve!(f32, 1e-3); - cholesky_solve!(c64, 1e-9); - cholesky_solve!(c32, 1e-3); + .mapv(|elem| elem.ln()) + .scalar_sum(); + let det = ln_det.exp(); + assert_aclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $atol); + assert_aclose!(a.factorizec(UPLO::Upper).unwrap().ln_detc(), ln_det, $atol); + assert_aclose!(a.factorizec(UPLO::Lower).unwrap().detc_into(), det, $atol); + assert_aclose!( + a.factorizec(UPLO::Lower).unwrap().ln_detc_into(), + ln_det, + $atol + ); + assert_aclose!(a.detc().unwrap(), det, $atol); + assert_aclose!(a.ln_detc().unwrap(), ln_det, $atol); + assert_aclose!(a.clone().detc_into().unwrap(), det, $atol); + assert_aclose!(a.ln_detc_into().unwrap(), ln_det, $atol); + } + } + }; +} +cholesky_det!(f64, 1e-9); +cholesky_det!(f32, 1e-3); +cholesky_det!(c64, 1e-9); +cholesky_det!(c32, 1e-3); + +macro_rules! cholesky_solve { + ($elem:ty, $rtol:expr) => { + paste::item! 
{ + #[test] + fn []() { + let a: Array2<$elem> = random_hpd(3); + let x: Array1<$elem> = random(3); + let b = a.dot(&x); + println!("a = \n{:?}", a); + println!("x = \n{:?}", x); + assert_close_l2!(&a.solvec(&b).unwrap(), &x, $rtol); + assert_close_l2!(&a.solvec_into(b.clone()).unwrap(), &x, $rtol); + assert_close_l2!(&a.solvec_inplace(&mut b.clone()).unwrap(), &x, $rtol); + assert_close_l2!( + &a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Upper) + .unwrap() + .solvec_into(b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower) + .unwrap() + .solvec_into(b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Upper) + .unwrap() + .solvec_inplace(&mut b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower) + .unwrap() + .solvec_inplace(&mut b.clone()) + .unwrap(), + &x, + $rtol + ); + } + } + }; } +cholesky_solve!(f64, 1e-9); +cholesky_solve!(f32, 1e-3); +cholesky_solve!(c64, 1e-9); +cholesky_solve!(c32, 1e-3); diff --git a/ndarray-linalg/tests/eig.rs b/ndarray-linalg/tests/eig.rs index 1eb0e7bb..28314b8a 100644 --- a/ndarray-linalg/tests/eig.rs +++ b/ndarray-linalg/tests/eig.rs @@ -92,11 +92,11 @@ fn test_matrix_real_t() -> Array2 { fn answer_eig_real() -> Array1 { array![ - T::complex(2.86, 10.76), - T::complex(2.86, -10.76), + T::complex(-10.46, 0.00), T::complex(-0.69, 4.70), T::complex(-0.69, -4.70), - T::complex(-10.46, 0.00), + T::complex(2.86, 10.76), + T::complex(2.86, -10.76), ] } diff --git a/ndarray-linalg/tests/least_squares.rs b/ndarray-linalg/tests/least_squares.rs index c388c9d7..33e20ca7 100644 --- a/ndarray-linalg/tests/least_squares.rs +++ b/ndarray-linalg/tests/least_squares.rs @@ -19,7 +19,7 @@ fn test_exact(a: Array2) { // b == Ax let ax = a.dot(&x); - assert_close_l2!(&b, &ax, 
T::real(1.0e-4)); + assert_close_max!(&b, &ax, T::real(1.0e-4)); } macro_rules! impl_exact { @@ -27,13 +27,13 @@ macro_rules! impl_exact { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 3)); + let a: Array2<$scalar> = random((3, 3)); test_exact(a) } #[test] fn []() { - let a: Array2 = random((3, 3).f()); + let a: Array2<$scalar> = random((3, 3).f()); test_exact(a) } } @@ -73,13 +73,13 @@ macro_rules! impl_overdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((4, 3)); + let a: Array2<$scalar> = random((4, 3)); test_overdetermined(a) } #[test] fn []() { - let a: Array2 = random((4, 3).f()); + let a: Array2<$scalar> = random((4, 3).f()); test_overdetermined(a) } } @@ -102,7 +102,7 @@ fn test_underdetermined(a: Array2) { // b == Ax let x = result.solution; let ax = a.dot(&x); - assert_close_l2!(&b, &ax, T::real(1.0e-4)); + assert_close_max!(&b, &ax, T::real(1.0e-4)); } macro_rules! impl_underdetermined { @@ -110,13 +110,13 @@ macro_rules! impl_underdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 4)); + let a: Array2<$scalar> = random((3, 4)); test_underdetermined(a) } #[test] fn []() { - let a: Array2 = random((3, 4).f()); + let a: Array2<$scalar> = random((3, 4).f()); test_underdetermined(a) } } diff --git a/ndarray-linalg/tests/least_squares_nrhs.rs b/ndarray-linalg/tests/least_squares_nrhs.rs index 4c964697..dd7d283c 100644 --- a/ndarray-linalg/tests/least_squares_nrhs.rs +++ b/ndarray-linalg/tests/least_squares_nrhs.rs @@ -9,6 +9,7 @@ fn test_exact(a: Array2, b: Array2) { assert_eq!(b.layout().unwrap().size(), (3, 2)); let result = a.least_squares(&b).unwrap(); + dbg!(&result); // unpack result let x: Array2 = result.solution; let residual_l2_square: Array1 = result.residual_sum_of_squares.unwrap(); @@ -31,33 +32,29 @@ macro_rules! impl_exact { paste::item! 
{ #[test] fn []() { - let a: Array2 = random((3, 3)); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 3)); + let b: Array2<$scalar> = random((3, 2)); test_exact(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { - let a: Array2 = random((3, 3)); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 3)); + let b: Array2<$scalar> = random((3, 2).f()); test_exact(a, b) } #[test] fn []() { - let a: Array2 = random((3, 3).f()); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 3).f()); + let b: Array2<$scalar> = random((3, 2)); test_exact(a, b) } - */ - #[test] fn []() { - let a: Array2 = random((3, 3).f()); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 3).f()); + let b: Array2<$scalar> = random((3, 2).f()); test_exact(a, b) } } @@ -103,33 +100,29 @@ macro_rules! impl_overdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((4, 3)); - let b: Array2 = random((4, 2)); + let a: Array2<$scalar> = random((4, 3)); + let b: Array2<$scalar> = random((4, 2)); test_overdetermined(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { - let a: Array2 = random((4, 3).f()); - let b: Array2 = random((4, 2)); + let a: Array2<$scalar> = random((4, 3).f()); + let b: Array2<$scalar> = random((4, 2)); test_overdetermined(a, b) } #[test] fn []() { - let a: Array2 = random((4, 3)); - let b: Array2 = random((4, 2).f()); + let a: Array2<$scalar> = random((4, 3)); + let b: Array2<$scalar> = random((4, 2).f()); test_overdetermined(a, b) } - */ - #[test] fn []() { - let a: Array2 = random((4, 3).f()); - let b: Array2 = random((4, 2).f()); + let a: Array2<$scalar> = random((4, 3).f()); + let b: Array2<$scalar> = random((4, 2).f()); test_overdetermined(a, b) } } @@ -162,33 +155,29 @@ macro_rules! impl_underdetermined { paste::item! 
{ #[test] fn []() { - let a: Array2 = random((3, 4)); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 4)); + let b: Array2<$scalar> = random((3, 2)); test_underdetermined(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { - let a: Array2 = random((3, 4).f()); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 4).f()); + let b: Array2<$scalar> = random((3, 2)); test_underdetermined(a, b) } #[test] fn []() { - let a: Array2 = random((3, 4)); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 4)); + let b: Array2<$scalar> = random((3, 2).f()); test_underdetermined(a, b) } - */ - #[test] fn []() { - let a: Array2 = random((3, 4).f()); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 4).f()); + let b: Array2<$scalar> = random((3, 2).f()); test_underdetermined(a, b) } } diff --git a/ndarray-linalg/tests/svd.rs b/ndarray-linalg/tests/svd.rs index acc6ffca..c83885e1 100644 --- a/ndarray-linalg/tests/svd.rs +++ b/ndarray-linalg/tests/svd.rs @@ -2,7 +2,7 @@ use ndarray::*; use ndarray_linalg::*; use std::cmp::min; -fn test(a: &Array2) { +fn test(a: &Array2) { let (n, m) = a.dim(); let answer = a.clone(); println!("a = \n{:?}", a); @@ -12,14 +12,14 @@ fn test(a: &Array2) { println!("u = \n{:?}", &u); println!("s = \n{:?}", &s); println!("v = \n{:?}", &vt); - let mut sm = Array::zeros((n, m)); + let mut sm = Array::::zeros((n, m)); for i in 0..min(n, m) { - sm[(i, i)] = s[i]; + sm[(i, i)] = T::from(s[i]).unwrap(); } - assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7)); } -fn test_no_vt(a: &Array2) { +fn test_no_vt(a: &Array2) { let (n, _m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap(); @@ -30,7 +30,7 @@ fn test_no_vt(a: &Array2) { assert_eq!(u.dim().1, n); } -fn test_no_u(a: &Array2) { +fn 
test_no_u(a: &Array2) { let (_n, m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap(); @@ -41,7 +41,7 @@ fn test_no_u(a: &Array2) { assert_eq!(vt.dim().1, m); } -fn test_diag_only(a: &Array2) { +fn test_diag_only(a: &Array2) { println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap(); assert!(u.is_none()); @@ -49,32 +49,44 @@ fn test_diag_only(a: &Array2) { } macro_rules! test_svd_impl { - ($test:ident, $n:expr, $m:expr) => { + ($type:ty, $test:ident, $n:expr, $m:expr) => { paste::item! { #[test] - fn []() { + fn []() { let a = random(($n, $m)); - $test(&a); + $test::<$type>(&a); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - $test(&a); + $test::<$type>(&a); } } }; } -test_svd_impl!(test, 3, 3); -test_svd_impl!(test_no_vt, 3, 3); -test_svd_impl!(test_no_u, 3, 3); -test_svd_impl!(test_diag_only, 3, 3); -test_svd_impl!(test, 4, 3); -test_svd_impl!(test_no_vt, 4, 3); -test_svd_impl!(test_no_u, 4, 3); -test_svd_impl!(test_diag_only, 4, 3); -test_svd_impl!(test, 3, 4); -test_svd_impl!(test_no_vt, 3, 4); -test_svd_impl!(test_no_u, 3, 4); -test_svd_impl!(test_diag_only, 3, 4); +test_svd_impl!(f64, test, 3, 3); +test_svd_impl!(f64, test_no_vt, 3, 3); +test_svd_impl!(f64, test_no_u, 3, 3); +test_svd_impl!(f64, test_diag_only, 3, 3); +test_svd_impl!(f64, test, 4, 3); +test_svd_impl!(f64, test_no_vt, 4, 3); +test_svd_impl!(f64, test_no_u, 4, 3); +test_svd_impl!(f64, test_diag_only, 4, 3); +test_svd_impl!(f64, test, 3, 4); +test_svd_impl!(f64, test_no_vt, 3, 4); +test_svd_impl!(f64, test_no_u, 3, 4); +test_svd_impl!(f64, test_diag_only, 3, 4); +test_svd_impl!(c64, test, 3, 3); +test_svd_impl!(c64, test_no_vt, 3, 3); +test_svd_impl!(c64, test_no_u, 3, 3); +test_svd_impl!(c64, test_diag_only, 3, 3); +test_svd_impl!(c64, test, 4, 3); +test_svd_impl!(c64, test_no_vt, 4, 3); +test_svd_impl!(c64, test_no_u, 4, 3); +test_svd_impl!(c64, test_diag_only, 4, 3); +test_svd_impl!(c64, 
test, 3, 4); +test_svd_impl!(c64, test_no_vt, 3, 4); +test_svd_impl!(c64, test_no_u, 3, 4); +test_svd_impl!(c64, test_diag_only, 3, 4); diff --git a/ndarray-linalg/tests/svddc.rs b/ndarray-linalg/tests/svddc.rs index 2c9204c8..fb26c8d5 100644 --- a/ndarray-linalg/tests/svddc.rs +++ b/ndarray-linalg/tests/svddc.rs @@ -1,13 +1,13 @@ use ndarray::*; use ndarray_linalg::*; -fn test(a: &Array2, flag: UVTFlag) { +fn test(a: &Array2, flag: UVTFlag) { let (n, m) = a.dim(); let k = n.min(m); let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap(); - let mut sm = match flag { + let mut sm: Array2 = match flag { UVTFlag::Full => Array::zeros((n, m)), UVTFlag::Some => Array::zeros((k, k)), UVTFlag::None => { @@ -22,53 +22,56 @@ fn test(a: &Array2, flag: UVTFlag) { println!("s = \n{:?}", &s); println!("v = \n{:?}", &vt); for i in 0..k { - sm[(i, i)] = s[i]; + sm[(i, i)] = T::from_real(s[i]); } - assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7)); } macro_rules! test_svd_impl { - ($n:expr, $m:expr) => { + ($scalar:ty, $n:expr, $m:expr) => { paste::item! 
{ #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::Full); + test::<$scalar>(&a, UVTFlag::Full); } #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::Some); + test::<$scalar>(&a, UVTFlag::Some); } #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::None); + test::<$scalar>(&a, UVTFlag::None); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::Full); + test::<$scalar>(&a, UVTFlag::Full); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::Some); + test::<$scalar>(&a, UVTFlag::Some); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::None); + test::<$scalar>(&a, UVTFlag::None); } } }; } -test_svd_impl!(3, 3); -test_svd_impl!(4, 3); -test_svd_impl!(3, 4); +test_svd_impl!(f64, 3, 3); +test_svd_impl!(f64, 4, 3); +test_svd_impl!(f64, 3, 4); +test_svd_impl!(c64, 3, 3); +test_svd_impl!(c64, 4, 3); +test_svd_impl!(c64, 3, 4);