diff --git a/src/lapack_traits/mod.rs b/src/lapack_traits/mod.rs index 89ce1a46..3ae4fd11 100644 --- a/src/lapack_traits/mod.rs +++ b/src/lapack_traits/mod.rs @@ -56,3 +56,21 @@ pub enum Transpose { Transpose = b'T', Hermite = b'C', } + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum NormType { + One = b'O', + Infinity = b'I', + Frobenius = b'F', +} + +impl NormType { + pub(crate) fn transpose(self) -> Self { + match self { + NormType::One => NormType::Infinity, + NormType::Infinity => NormType::One, + NormType::Frobenius => NormType::Frobenius, + } + } +} diff --git a/src/lapack_traits/opnorm.rs b/src/lapack_traits/opnorm.rs index 81163cff..d3823c00 100644 --- a/src/lapack_traits/opnorm.rs +++ b/src/lapack_traits/opnorm.rs @@ -6,22 +6,7 @@ use lapack::c::Layout::ColumnMajor as cm; use layout::MatrixLayout; use types::*; -#[repr(u8)] -pub enum NormType { - One = b'o', - Infinity = b'i', - Frobenius = b'f', -} - -impl NormType { - fn transpose(self) -> Self { - match self { - NormType::One => NormType::Infinity, - NormType::Infinity => NormType::One, - NormType::Frobenius => NormType::Frobenius, - } - } -} +use super::NormType; pub trait OperatorNorm_: AssociatedReal { unsafe fn opnorm(NormType, MatrixLayout, &[Self]) -> Self::Real; diff --git a/src/lapack_traits/solve.rs b/src/lapack_traits/solve.rs index 6936826c..b3fa5d93 100644 --- a/src/lapack_traits/solve.rs +++ b/src/lapack_traits/solve.rs @@ -4,12 +4,14 @@ use lapack::c; use error::*; use layout::MatrixLayout; +use num_traits::Zero; use types::*; use super::{Pivot, Transpose, into_result}; +use super::NormType; /// Wraps `*getrf`, `*getri`, and `*getrs` -pub trait Solve_: Sized { +pub trait Solve_: AssociatedReal + Sized { /// Computes the LU factorization of a general `m x n` matrix `a` using /// partial pivoting with row interchanges. /// @@ -20,11 +22,15 @@ pub trait Solve_: Sized { /// if it is used to solve a system of equations. unsafe fn lu(MatrixLayout, a: &mut [Self]) -> Result; unsafe fn inv(MatrixLayout, a: &mut [Self], &Pivot) -> Result<()>; + /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. + /// + /// `anorm` should be the 1-norm of the matrix `a`. + unsafe fn rcond(MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; unsafe fn solve(MatrixLayout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { + ($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => { impl Solve_ for $scalar { unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { @@ -41,6 +47,13 @@ impl Solve_ for $scalar { into_result(info, ()) } + unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + let (n, _) = l.size(); + let mut rcond = Self::Real::zero(); + let info = $gecon(l.lapacke_layout(), NormType::One as u8, n, a, l.lda(), anorm, &mut rcond); + into_result(info, rcond) + } + unsafe fn solve(l: MatrixLayout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; @@ -52,7 +65,7 @@ impl Solve_ for $scalar { }} // impl_solve! -impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs); -impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs); -impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs); -impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs); +impl_solve!(f64, c::dgetrf, c::dgetri, c::dgecon, c::dgetrs); +impl_solve!(f32, c::sgetrf, c::sgetri, c::sgecon, c::sgetrs); +impl_solve!(c64, c::zgetrf, c::zgetri, c::zgecon, c::zgetrs); +impl_solve!(c32, c::cgetrf, c::cgetri, c::cgecon, c::cgetrs); diff --git a/src/solve.rs b/src/solve.rs index 2b9f2f24..cdb2afbb 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -50,6 +50,7 @@ use ndarray::*; use super::convert::*; use super::error::*; use super::layout::*; +use super::opnorm::OperationNorm; use super::types::*; pub use lapack_traits::{Pivot, Transpose}; @@ -419,3 +420,77 @@ where } } } + +/// An interface for *estimating* the reciprocal condition number of matrix refs. +pub trait ReciprocalConditionNum { + /// *Estimates* the reciprocal of the condition number of the matrix in + /// 1-norm. + /// + /// This method uses the LAPACK `*gecon` routines, which *estimate* + /// `self.inv().opnorm_one()` and then compute `rcond = 1. / + /// (self.opnorm_one() * self.inv().opnorm_one())`. + /// + /// * If `rcond` is near `0.`, the matrix is badly conditioned. + /// * If `rcond` is near `1.`, the matrix is well conditioned. + fn rcond(&self) -> Result; +} + +/// An interface for *estimating* the reciprocal condition number of matrices. +pub trait ReciprocalConditionNumInto { + /// *Estimates* the reciprocal of the condition number of the matrix in + /// 1-norm. + /// + /// This method uses the LAPACK `*gecon` routines, which *estimate* + /// `self.inv().opnorm_one()` and then compute `rcond = 1. / + /// (self.opnorm_one() * self.inv().opnorm_one())`. + /// + /// * If `rcond` is near `0.`, the matrix is badly conditioned. + /// * If `rcond` is near `1.`, the matrix is well conditioned. + fn rcond_into(self) -> Result; +} + +impl ReciprocalConditionNum for LUFactorized +where + A: Scalar, + S: Data, +{ + fn rcond(&self) -> Result { + unsafe { + A::rcond( + self.a.layout()?, + self.a.as_allocated()?, + self.a.opnorm_one()?, + ) + } + } +} + +impl ReciprocalConditionNumInto for LUFactorized +where + A: Scalar, + S: Data, +{ + fn rcond_into(self) -> Result { + self.rcond() + } +} + +impl ReciprocalConditionNum for ArrayBase +where + A: Scalar, + S: Data, +{ + fn rcond(&self) -> Result { + self.factorize()?.rcond_into() + } +} + +impl ReciprocalConditionNumInto for ArrayBase +where + A: Scalar, + S: DataMut, +{ + fn rcond_into(self) -> Result { + self.factorize_into()?.rcond_into() + } +} diff --git a/tests/solve.rs b/tests/solve.rs index 48b66782..5bc6357c 100644 --- a/tests/solve.rs +++ b/tests/solve.rs @@ -23,3 +23,51 @@ fn solve_random_t() { let y = a.solve_into(b).unwrap(); assert_close_l2!(&x, &y, 1e-7); } + +#[test] +fn rcond() { + macro_rules! rcond { + ($elem:ty, $rows:expr, $atol:expr) => { + let a: Array2<$elem> = random(($rows, $rows)); + let rcond = 1. / (a.opnorm_one().unwrap() * a.inv().unwrap().opnorm_one().unwrap()); + assert_aclose!(a.rcond().unwrap(), rcond, $atol); + assert_aclose!(a.rcond_into().unwrap(), rcond, $atol); + } + } + for rows in 1..6 { + rcond!(f64, rows, 0.2); + rcond!(f32, rows, 0.5); + rcond!(c64, rows, 0.2); + rcond!(c32, rows, 0.5); + } +} + +#[test] +fn rcond_hilbert() { + macro_rules! rcond_hilbert { + ($elem:ty, $rows:expr, $atol:expr) => { + let a = Array2::<$elem>::from_shape_fn(($rows, $rows), |(i, j)| 1. / (i as $elem + j as $elem - 1.)); + assert_aclose!(a.rcond().unwrap(), 0., $atol); + assert_aclose!(a.rcond_into().unwrap(), 0., $atol); + } + } + rcond_hilbert!(f64, 10, 1e-9); + rcond_hilbert!(f32, 10, 1e-3); +} + +#[test] +fn rcond_identity() { + macro_rules! rcond_identity { + ($elem:ty, $rows:expr, $atol:expr) => { + let a = Array2::<$elem>::eye($rows); + assert_aclose!(a.rcond().unwrap(), 1., $atol); + assert_aclose!(a.rcond_into().unwrap(), 1., $atol); + } + } + for rows in 1..6 { + rcond_identity!(f64, rows, 1e-9); + rcond_identity!(f32, rows, 1e-3); + rcond_identity!(c64, rows, 1e-9); + rcond_identity!(c32, rows, 1e-3); + } +}