From 3fba6622f0cb62486dea071344ba4fb346b8d647 Mon Sep 17 00:00:00 2001 From: doraneko94 Date: Fri, 12 Jun 2020 18:14:17 +0900 Subject: [PATCH 1/8] add calculation for tridiagonal matrices (solve, factorize, det, rcond) --- examples/tridiagonal.rs | 30 ++ src/lapack/mod.rs | 4 +- src/lapack/tridiagonal.rs | 103 +++++++ src/lib.rs | 3 + src/tridiagonal.rs | 567 ++++++++++++++++++++++++++++++++++++++ tests/tridiagonal.rs | 143 ++++++++++ 6 files changed, 849 insertions(+), 1 deletion(-) create mode 100644 examples/tridiagonal.rs create mode 100644 src/lapack/tridiagonal.rs create mode 100644 src/tridiagonal.rs create mode 100644 tests/tridiagonal.rs diff --git a/examples/tridiagonal.rs b/examples/tridiagonal.rs new file mode 100644 index 00000000..d7509e4d --- /dev/null +++ b/examples/tridiagonal.rs @@ -0,0 +1,30 @@ +use ndarray::*; +use ndarray_linalg::*; + +// Solve `Ax=b` for tridiagonal matrix +fn solve() -> Result<(), error::LinalgError> { + let mut a: Array2 = random((3, 3)); + let b: Array1 = random(3); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let _x = a.solve_tridiagonal(&b)?; + Ok(()) +} + +// Solve `Ax=b` for many b with fixed A +fn factorize() -> Result<(), error::LinalgError> { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let f = a.factorize_tridiagonal()?; // LU factorize A (A is *not* consumed) + for _ in 0..10 { + let b: Array1 = random(3); + let _x = f.solve_tridiagonal_into(b)?; // solve Ax=b using factorized L, U + } + Ok(()) +} + +fn main() { + solve().unwrap(); + factorize().unwrap(); +} \ No newline at end of file diff --git a/src/lapack/mod.rs b/src/lapack/mod.rs index 6a6903fe..bd9f029f 100644 --- a/src/lapack/mod.rs +++ b/src/lapack/mod.rs @@ -10,6 +10,7 @@ pub mod solveh; pub mod svd; pub mod svddc; pub mod triangular; +pub mod tridiagonal; pub use self::cholesky::*; pub use self::eig::*; @@ -21,6 +22,7 @@ pub use self::solveh::*; pub use self::svd::*; pub use self::svddc::*; pub use self::triangular::*; +pub use self::tridiagonal::*; use super::error::*; use super::types::*; @@ -29,7 +31,7 @@ pub type Pivot = Vec; /// Trait for primitive types which implements LAPACK subroutines pub trait Lapack: - OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_ + OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_ + TriDiagonal_ { } diff --git a/src/lapack/tridiagonal.rs b/src/lapack/tridiagonal.rs new file mode 100644 index 00000000..1361880f --- /dev/null +++ b/src/lapack/tridiagonal.rs @@ -0,0 +1,103 @@ +//! Implement linear solver using LU decomposition +//! for tridiagonal matrix + +use lapacke; +use ndarray::*; +use num_traits::Zero; + +use super::NormType; +use super::{into_result, Pivot, Transpose}; + +use crate::error::*; +use crate::layout::MatrixLayout; +use crate::tridiagonal::{TriDiagonal, LUFactorizedTriDiagonal}; +use crate::types::*; + +/// Wraps `*gttrf`, `*gtcon` and `*gttrs` +pub trait TriDiagonal_: Scalar + Sized { + /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using + /// partial pivoting with row interchanges. + unsafe fn lu_tridiagonal(a: &mut TriDiagonal) -> Result<(Array1, Pivot)>; + /// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm. + unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal) -> Result; + unsafe fn solve_tridiagonal( + lu: &LUFactorizedTriDiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_tridiagonal { + ($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl TriDiagonal_ for $scalar { + unsafe fn lu_tridiagonal(a: &mut TriDiagonal) -> Result<(Array1, Pivot)> { + let (n, _) = a.l.size(); + let dl = a.dl.as_slice_mut().unwrap(); + let d = a.d.as_slice_mut().unwrap(); + let du = a.du.as_slice_mut().unwrap(); + let mut du2 = vec![Zero::zero(); (n-2) as usize]; + let mut ipiv = vec![0; n as usize]; + let info = $gttrf(n, dl, d, du, &mut du2, &mut ipiv); + into_result(info, (arr1(&du2), ipiv)) + } + + unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal) -> Result { + let (n, _) = lu.a.l.size(); + let dl = lu.a.dl.as_slice().unwrap(); + let d = lu.a.d.as_slice().unwrap(); + let du = lu.a.du.as_slice().unwrap(); + let du2 = lu.du2.as_slice().unwrap(); + let ipiv = &lu.ipiv; + let anorm = lu.a.n1; + let mut rcond = Self::Real::zero(); + let info = $gtcon( + NormType::One as u8, + n, + dl, + d, + du, + du2, + ipiv, + anorm, + &mut rcond, + ); + into_result(info, rcond) + } + + unsafe fn solve_tridiagonal( + lu: &LUFactorizedTriDiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self] + ) -> Result<()> { + let (n, _) = lu.a.l.size(); + let (_, nrhs) = bl.size(); + let dl = lu.a.dl.as_slice().unwrap(); + let d = lu.a.d.as_slice().unwrap(); + let du = lu.a.du.as_slice().unwrap(); + let du2 = lu.du2.as_slice().unwrap(); + let ipiv = &lu.ipiv; + let ldb = bl.lda(); + let info = $gttrs( + lu.a.l.lapacke_layout(), + t as u8, + n, + nrhs, + dl, + d, + du, + du2, + ipiv, + b, + ldb, + ); + into_result(info, ()) + } + } + }; +} // impl_tridiagonal! + +impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs); +impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs); +impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs); +impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs); \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index e3c90efe..06d9c89a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ //! - [General matrices](solve/index.html) //! - [Triangular matrices](triangular/index.html) //! - [Hermitian/real symmetric matrices](solveh/index.html) +//! - [Tridiagonal matrices](tridiagonal/index.html) //! - [Inverse matrix computation](solve/trait.Inverse.html) //! //! Naming Convention @@ -66,6 +67,7 @@ pub mod svd; pub mod svddc; pub mod trace; pub mod triangular; +pub mod tridiagonal; pub mod types; pub use assert::*; @@ -88,4 +90,5 @@ pub use svd::*; pub use svddc::*; pub use trace::*; pub use triangular::*; +pub use tridiagonal::*; pub use types::*; diff --git a/src/tridiagonal.rs b/src/tridiagonal.rs new file mode 100644 index 00000000..4ac99756 --- /dev/null +++ b/src/tridiagonal.rs @@ -0,0 +1,567 @@ +//! Vectors as a TriDiagonal matrix +//! & +//! Methods for tridiagonal matrices + +use ndarray::*; +use cauchy::Scalar; +use num_traits::One; + +use crate::opnorm::OperationNorm; + +use super::convert::*; +use super::error::*; +use super::lapack::*; +use super::layout::*; + +/// Represents a tridiagonal matrix as 3 one-dimensional vectors. +/// This struct also holds the layout and 1-norm of the raw matrix +/// for some methods (eg. rcond_tridiagonal()). +#[derive(Clone)] +pub struct TriDiagonal { + /// layout of raw matrix + pub l: MatrixLayout, + /// the one norm of raw matrix + pub n1: ::Real, + /// (n-1) sub-diagonal elements of matrix. + pub dl: Array1, + /// (n) diagonal elements of matrix. + pub d: Array1, + /// (n-1) super-diagonal elements of matrix. + pub du: Array1, +} + +/// An interface for making a TriDiagonal struct. +pub trait ToTriDiagonal { + /// Extract tridiagonal elements and layout of the raw matrix. + /// And also calculate 1-norm. + /// + /// If the raw matrix has some non-tridiagonal elements, + /// they will be ignored. + /// + /// The shape of raw matrix should be equal to or larger than (2, 2). + fn to_tridiagonal(&self) -> Result>; +} + +impl ToTriDiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data +{ + fn to_tridiagonal(&self) -> Result> { + let l = self.square_layout()?; + let (n, _) = l.size(); + if n < 2 { panic!("Cannot make a tridiagonal matrix of shape=(1, 1)!"); } + let n1 = self.opnorm_one()?; + + let dl = self.slice(s![1..n, 0..n-1]).diag().to_owned(); + let d = self.diag().to_owned(); + let du = self.slice(s![0..n-1, 1..n]).diag().to_owned(); + Ok(TriDiagonal { l, n1, dl, d, du }) + } +} + +pub trait SolveTriDiagonal { + /// Solves a system of linear equations `A * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result>; + /// Solves a system of linear equations `A * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result>; + /// Solves a system of linear equations `A^T * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result>; + /// Solves a system of linear equations `A^T * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_t_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result>; + /// Solves a system of linear equations `A^H * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result>; + /// Solves a system of linear equations `A^H * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_h_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result>; +} + +pub trait SolveTriDiagonalInplace { + /// Solves a system of linear equations `A * x = b` tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. The value of `x` is also assigned to the + /// argument. + fn solve_tridiagonal_inplace<'a, S: DataMut>( + &self, + b: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase>; + /// Solves a system of linear equations `A^T * x = b` tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. The value of `x` is also assigned to the + /// argument. + fn solve_t_tridiagonal_inplace<'a, S: DataMut>( + &self, + b: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase>; + /// Solves a system of linear equations `A^H * x = b` tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. The value of `x` is also assigned to the + /// argument. + fn solve_h_tridiagonal_inplace<'a, S: DataMut>( + &self, + b: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase>; +} + +/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`. +#[derive(Clone)] +pub struct LUFactorizedTriDiagonal { + /// A tridiagonal matrix which consists of + /// - l : layout of raw matrix + /// - n1: the one norm of raw matrix + /// - dl: (n-1) multipliers that define the matrix L. + /// - d : (n) diagonal elements of the upper triangular matrix U. + /// - du: (n-1) elements of the first super-diagonal of U. + pub a: TriDiagonal, + /// (n-2) elements of the second super-diagonal of U. + pub du2: Array1, + /// The pivot indices that define the permutation matrix `P`. + pub ipiv: Pivot, +} + +impl SolveTriDiagonal for LUFactorizedTriDiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } +} + +impl SolveTriDiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } +} + +impl SolveTriDiagonalInplace for LUFactorizedTriDiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::No, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::Transpose, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::Hermite, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } +} + +impl SolveTriDiagonalInplace for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_tridiagonal_inplace(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_t_tridiagonal_inplace(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_h_tridiagonal_inplace(rhs) + } +} + +impl SolveTriDiagonal for LUFactorizedTriDiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let b = b.to_owned(); + self.solve_tridiagonal_into(b) + } + fn solve_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let b = self.solve_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let b = b.to_owned(); + self.solve_t_tridiagonal_into(b) + } + fn solve_t_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let b = self.solve_t_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let b = b.to_owned(); + self.solve_h_tridiagonal_into(b) + } + fn solve_h_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let b = self.solve_h_tridiagonal_into(b)?; + Ok(flatten(b)) + } +} + +impl SolveTriDiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let b = b.to_owned(); + self.solve_tridiagonal_into(b) + } + fn solve_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let b = b.to_owned(); + self.solve_t_tridiagonal_into(b) + } + fn solve_t_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_t_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let b = b.to_owned(); + self.solve_h_tridiagonal_into(b) + } + fn solve_h_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_h_tridiagonal_into(b)?; + Ok(flatten(b)) + } +} + +/// An interface for computing LU factorizations of tridiagonal matrix refs. +pub trait FactorizeTriDiagonal { + /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation + /// matrix. + fn factorize_tridiagonal(&self) -> Result>; +} + +/// An interface for computing LU factorizations of tridiagonal matrices. +pub trait FactorizeTriDiagonalInto { + /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation + /// matrix. + fn factorize_tridiagonal_into(self) -> Result>; +} + +impl FactorizeTriDiagonalInto for TriDiagonal +where + A: Scalar + Lapack, +{ + fn factorize_tridiagonal_into(mut self) -> Result> { + let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? }; + Ok(LUFactorizedTriDiagonal { + a: self, + du2: du2, + ipiv: ipiv + }) + } +} + +impl FactorizeTriDiagonal for TriDiagonal +where + A: Scalar + Lapack, +{ + fn factorize_tridiagonal(&self) -> Result> { + let mut a = self.clone(); + let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; + Ok(LUFactorizedTriDiagonal { a, du2, ipiv }) + } +} + +impl FactorizeTriDiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data +{ + fn factorize_tridiagonal(&self) -> Result> { + let mut a = self.to_tridiagonal()?; + let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; + Ok(LUFactorizedTriDiagonal { a, du2, ipiv }) + } +} + +/// Calculates the recurrent relation, +/// f_k = a_k * f_{k-1} - c_{k-1} * b_{k-1} * f_{k-2} +/// where {a_1, a_2, ..., a_n} are diagonal elements, +/// {b_1, b_2, ..., b_{n-1}} are super-diagonal elements, and +/// {c_1, c_2, ..., c_{n-1}} are sub-diagonal elements of matrix. +/// +/// f[n] is used to calculate the determinant. +/// (https://en.wikipedia.org/wiki/Tridiagonal_matrix#Determinant) +/// +/// In the future, the vector `f` can be used to calculate the inverce matrix. +/// (https://en.wikipedia.org/wiki/Tridiagonal_matrix#Inversion) +fn rec_rel(tridiag: &TriDiagonal) -> Vec { + let n = tridiag.d.shape()[0]; + let mut f = Vec::with_capacity(n+1); + f.push(One::one()); + f.push(tridiag.d[0]); + for i in 1..n { + f.push(tridiag.d[i] * f[i] - tridiag.dl[i-1] * tridiag.du[i-1] * f[i-1]); + } + f +} + +/// An interface for calculating determinants of tridiagonal matrix refs. +pub trait DeterminantTriDiagonal { + /// Computes the determinant of the matrix. + /// Unlike `.det()` of Determinant trait, this method + /// doesn't returns the natural logarithm of the determinant + /// but the determinant itself. + fn det_tridiagonal(&self) -> Result; +} + +impl DeterminantTriDiagonal for TriDiagonal +where + A: Scalar, +{ + fn det_tridiagonal(&self) -> Result { + let n = self.d.shape()[0]; + Ok(rec_rel(&self)[n]) + } +} + +impl DeterminantTriDiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn det_tridiagonal(&self) -> Result { + let tridiag = self.to_tridiagonal()?; + let n = tridiag.d.shape()[0]; + Ok(rec_rel(&tridiag)[n]) + } +} + +/// An interface for *estimating* the reciprocal condition number of tridiagonal matrix refs. +pub trait ReciprocalConditionNumTriDiagonal { + /// *Estimates* the reciprocal of the condition number of the tridiagonal matrix in + /// 1-norm. + /// + /// This method uses the LAPACK `*gtcon` routines, which *estimate* + /// `self.inv_tridiagonal().opnorm_one()` and then compute `rcond = 1. / + /// (self.opnorm_one() * self.inv_tridiagonal().opnorm_one())`. + /// + /// * If `rcond` is near `0.`, the matrix is badly conditioned. + /// * If `rcond` is near `1.`, the matrix is well conditioned. + fn rcond_tridiagonal(&self) -> Result; +} + +/// An interface for *estimating* the reciprocal condition number of tridiagonal matrices. +pub trait ReciprocalConditionNumTriDiagonalInto { + /// *Estimates* the reciprocal of the condition number of the tridiagonal matrix in + /// 1-norm. + /// + /// This method uses the LAPACK `*gtcon` routines, which *estimate* + /// `self.inv_tridiagonal().opnorm_one()` and then compute `rcond = 1. / + /// (self.opnorm_one() * self.inv_tridiagonal().opnorm_one())`. + /// + /// * If `rcond` is near `0.`, the matrix is badly conditioned. + /// * If `rcond` is near `1.`, the matrix is well conditioned. + fn rcond_tridiagonal_into(self) -> Result; +} + +impl ReciprocalConditionNumTriDiagonal for LUFactorizedTriDiagonal +where + A: Scalar + Lapack, +{ + fn rcond_tridiagonal(&self) -> Result { + unsafe { A::rcond_tridiagonal(&self) } + } +} + +impl ReciprocalConditionNumTriDiagonalInto for LUFactorizedTriDiagonal +where + A: Scalar + Lapack, +{ + fn rcond_tridiagonal_into(self) -> Result { + self.rcond_tridiagonal() + } +} + +impl ReciprocalConditionNumTriDiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn rcond_tridiagonal(&self) -> Result { + self.factorize_tridiagonal()?.rcond_tridiagonal_into() + } +} diff --git a/tests/tridiagonal.rs b/tests/tridiagonal.rs new file mode 100644 index 00000000..7fda3f5c --- /dev/null +++ b/tests/tridiagonal.rs @@ -0,0 +1,143 @@ +use ndarray::*; +use ndarray_linalg::*; + +#[test] +fn to_tridiagonal() { + let a: Array2 = arr2(&[[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0]]); + let t = a.to_tridiagonal().unwrap(); + assert_close_l2!(&t.dl, &arr1(&[4.0, 8.0]), 1e-7); + assert_close_l2!(&t.d , &arr1(&[1.0, 5.0, 9.0]), 1e-7); + assert_close_l2!(&t.du, &arr1(&[2.0, 6.0]), 1e-7); +} + +#[test] +fn solve_tridiagonal_f64() { + // https://www.nag-j.co.jp/lapack/dgttrs.htm + let a: Array2 = arr2(&[[ 3.0, 2.1, 0.0, 0.0, 0.0], + [ 3.4, 2.3, -1.0, 0.0, 0.0], + [ 0.0, 3.6, -5.0, 1.9, 0.0], + [ 0.0, 0.0, 7.0, -0.9, 8.0], + [ 0.0, 0.0, 0.0, -6.0, 7.1]]); + let b: Array2 = arr2(&[[ 2.7, 6.6], + [ -0.5, 10.8], + [ 2.6, -3.2], + [ 0.6, -11.2], + [ 2.7, 19.1]]); + let x: Array2 = arr2(&[[ -4.0, 5.0], + [ 7.0, -4.0], + [ 3.0, -3.0], + [ -4.0, -2.0], + [ -3.0, 1.0]]); + let y = a.solve_tridiagonal_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); +} +//`*gttrf`, `*gtcon` and `*gttrs` +#[test] +fn solve_tridiagonal_c64() { + // https://www.nag-j.co.jp/lapack/zgttrs.htm + let a: Array2 = arr2(&[[ c64::new( -1.3, 1.3), c64::new( 2.0, -1.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0)], + [ c64::new( 1.0, -2.0), c64::new( -1.3, 1.3), c64::new( 2.0, 1.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0)], + [ c64::new( 0.0, 0.0), c64::new( 1.0, 1.0), c64::new( -1.3, 3.3), c64::new( -1.0, 1.0), c64::new( 0.0, 0.0)], + [ c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 2.0, -3.0), c64::new( -0.3, 4.3), c64::new( 1.0, -1.0)], + [ c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 1.0, 1.0), c64::new( -3.3, 1.3)]]); + let b: Array2 = arr2(&[[ c64::new( 2.4, -5.0), c64::new( 2.7, 6.9)], + [ c64::new( 3.4, 18.2), c64::new( -6.9, -5.3)], + [ c64::new(-14.7, 9.7), c64::new( -6.0, -0.6)], + [ c64::new( 31.9, -7.7), c64::new( -3.9, 9.3)], + [ c64::new( -1.0, 1.6), c64::new( -3.0, 12.2)]]); + let x: Array2 = arr2(&[[ c64::new( 1.0, 1.0), c64::new( 2.0, -1.0)], + [ c64::new( 3.0, -1.0), c64::new( 1.0, 2.0)], + [ c64::new( 4.0, 5.0), c64::new( -1.0, 1.0)], + [ c64::new( -1.0, -2.0), c64::new( 2.0, 1.0)], + [ c64::new( 1.0, -1.0), c64::new( 2.0, -2.0)]]); + let y = a.solve_tridiagonal_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); +} + +#[test] +fn solve_tridiagonal_random() { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let x: Array1 = random(3); + let b1 = a.dot(&x); + let b2 = b1.clone(); + let y1 = a.solve_tridiagonal_into(b1).unwrap(); + let y2 = a.solve_into(b2).unwrap(); + assert_close_l2!(&x, &y1, 1e-7); + assert_close_l2!(&y1, &y2, 1e-7); +} + +#[test] +fn solve_tridiagonal_random_t() { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let x: Array1 = random(3); + let at = a.t(); + let b1 = at.dot(&x); + let b2 = b1.clone(); + let y1 = a.solve_t_tridiagonal_into(b1).unwrap(); + let y2 = a.solve_t_into(b2).unwrap(); + assert_close_l2!(&x, &y1, 1e-7); + assert_close_l2!(&y1, &y2, 1e-7); +} + +#[test] +fn det_tridiagonal_f64() { + let a: Array2 = arr2(&[[ 10.0, -9.0, 0.0], + [ 7.0, -12.0, 11.0], + [ 0.0, 10.0, 3.0]]); + assert_aclose!(a.det_tridiagonal().unwrap(), -1271.0, 1e-7); + assert_aclose!(a.det_tridiagonal().unwrap(), a.det().unwrap(), 1e-7); +} + +#[test] +fn det_tridiagonal_random() { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + assert_aclose!(a.det_tridiagonal().unwrap(), a.det().unwrap(), 1e-7); +} + +#[test] +fn rcond_tridiagonal_f64() { + // https://www.nag-j.co.jp/lapack/dgtcon.htm + let a: Array2 = arr2(&[[ 3.0, 2.1, 0.0, 0.0, 0.0], + [ 3.4, 2.3, -1.0, 0.0, 0.0], + [ 0.0, 3.6, -5.0, 1.9, 0.0], + [ 0.0, 0.0, 7.0, -0.9, 8.0], + [ 0.0, 0.0, 0.0, -6.0, 7.1]]); + assert_aclose!(1.0 / a.rcond_tridiagonal().unwrap(), 9.27e1, 0.1); + assert_aclose!(a.rcond_tridiagonal().unwrap(), a.rcond().unwrap(), 1e-3); +} + +#[test] +fn rcond_tridiagonal_c64() { + // https://www.nag-j.co.jp/lapack/dgtcon.htm + let a: Array2 = arr2(&[[ c64::new( -1.3, 1.3), c64::new( 2.0, -1.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0)], + [ c64::new( 1.0, -2.0), c64::new( -1.3, 1.3), c64::new( 2.0, 1.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0)], + [ c64::new( 0.0, 0.0), c64::new( 1.0, 1.0), c64::new( -1.3, 3.3), c64::new( -1.0, 1.0), c64::new( 0.0, 0.0)], + [ c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 2.0, -3.0), c64::new( -0.3, 4.3), c64::new( 1.0, -1.0)], + [ c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 1.0, 1.0), c64::new( -3.3, 1.3)]]); + assert_aclose!(1.0 / a.rcond_tridiagonal().unwrap(), 1.84e2, 1.0); + assert_aclose!(a.rcond_tridiagonal().unwrap(), a.rcond().unwrap(), 1e-3); +} + +#[test] +fn rcond_tridiagonal_identity() { + macro_rules! rcond_identity { + ($elem:ty, $rows:expr, $atol:expr) => { + let a = Array2::<$elem>::eye($rows); + assert_aclose!(a.rcond_tridiagonal().unwrap(), 1., $atol); + }; + } + for rows in 2..6 { // cannot make 1x1 tridiagonal matrices. + rcond_identity!(f64, rows, 1e-9); + rcond_identity!(f32, rows, 1e-3); + rcond_identity!(c64, rows, 1e-9); + rcond_identity!(c32, rows, 1e-3); + } +} \ No newline at end of file From 819c664506f98bd7db4417e7f453a264aa089e81 Mon Sep 17 00:00:00 2001 From: doraneko94 Date: Fri, 12 Jun 2020 19:37:33 +0900 Subject: [PATCH 2/8] cargo fmt --- examples/tridiagonal.rs | 2 +- src/lapack/mod.rs | 12 ++- src/lapack/tridiagonal.rs | 31 +++---- src/tridiagonal.rs | 90 +++++++++++++------- tests/tridiagonal.rs | 171 +++++++++++++++++++++++++++----------- 5 files changed, 211 insertions(+), 95 deletions(-) diff --git a/examples/tridiagonal.rs b/examples/tridiagonal.rs index d7509e4d..676bfb28 100644 --- a/examples/tridiagonal.rs +++ b/examples/tridiagonal.rs @@ -27,4 +27,4 @@ fn factorize() -> Result<(), error::LinalgError> { fn main() { solve().unwrap(); factorize().unwrap(); -} \ No newline at end of file +} diff --git a/src/lapack/mod.rs b/src/lapack/mod.rs index bd9f029f..540fef64 100644 --- a/src/lapack/mod.rs +++ b/src/lapack/mod.rs @@ -31,7 +31,17 @@ pub type Pivot = Vec; /// Trait for primitive types which implements LAPACK subroutines pub trait Lapack: - OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_ + TriDiagonal_ + OperatorNorm_ + + QR_ + + SVD_ + + SVDDC_ + + Solve_ + + Solveh_ + + Cholesky_ + + Eig_ + + Eigh_ + + Triangular_ + + TriDiagonal_ { } diff --git a/src/lapack/tridiagonal.rs b/src/lapack/tridiagonal.rs index 1361880f..71ccd308 100644 --- a/src/lapack/tridiagonal.rs +++ b/src/lapack/tridiagonal.rs @@ -10,7 +10,7 @@ use super::{into_result, Pivot, Transpose}; use crate::error::*; use crate::layout::MatrixLayout; -use crate::tridiagonal::{TriDiagonal, LUFactorizedTriDiagonal}; +use crate::tridiagonal::{LUFactorizedTriDiagonal, TriDiagonal}; use crate::types::*; /// Wraps `*gttrf`, `*gtcon` and `*gttrs` @@ -24,7 +24,8 @@ pub trait TriDiagonal_: Scalar + Sized { lu: &LUFactorizedTriDiagonal, bl: MatrixLayout, t: Transpose, - b: &mut [Self]) -> Result<()>; + b: &mut [Self], + ) -> Result<()>; } macro_rules! impl_tridiagonal { @@ -32,20 +33,20 @@ macro_rules! impl_tridiagonal { impl TriDiagonal_ for $scalar { unsafe fn lu_tridiagonal(a: &mut TriDiagonal) -> Result<(Array1, Pivot)> { let (n, _) = a.l.size(); - let dl = a.dl.as_slice_mut().unwrap(); - let d = a.d.as_slice_mut().unwrap(); - let du = a.du.as_slice_mut().unwrap(); - let mut du2 = vec![Zero::zero(); (n-2) as usize]; + let dl = a.dl.as_slice_mut().unwrap(); + let d = a.d.as_slice_mut().unwrap(); + let du = a.du.as_slice_mut().unwrap(); + let mut du2 = vec![Zero::zero(); (n - 2) as usize]; let mut ipiv = vec![0; n as usize]; let info = $gttrf(n, dl, d, du, &mut du2, &mut ipiv); into_result(info, (arr1(&du2), ipiv)) } - + unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal) -> Result { let (n, _) = lu.a.l.size(); - let dl = lu.a.dl.as_slice().unwrap(); - let d = lu.a.d.as_slice().unwrap(); - let du = lu.a.du.as_slice().unwrap(); + let dl = lu.a.dl.as_slice().unwrap(); + let d = lu.a.d.as_slice().unwrap(); + let du = lu.a.du.as_slice().unwrap(); let du2 = lu.du2.as_slice().unwrap(); let ipiv = &lu.ipiv; let anorm = lu.a.n1; @@ -68,13 +69,13 @@ macro_rules! impl_tridiagonal { lu: &LUFactorizedTriDiagonal, bl: MatrixLayout, t: Transpose, - b: &mut [Self] + b: &mut [Self], ) -> Result<()> { let (n, _) = lu.a.l.size(); let (_, nrhs) = bl.size(); - let dl = lu.a.dl.as_slice().unwrap(); - let d = lu.a.d.as_slice().unwrap(); - let du = lu.a.du.as_slice().unwrap(); + let dl = lu.a.dl.as_slice().unwrap(); + let d = lu.a.d.as_slice().unwrap(); + let du = lu.a.du.as_slice().unwrap(); let du2 = lu.du2.as_slice().unwrap(); let ipiv = &lu.ipiv; let ldb = bl.lda(); @@ -100,4 +101,4 @@ macro_rules! impl_tridiagonal { impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs); impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs); impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs); -impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs); \ No newline at end of file +impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs); diff --git a/src/tridiagonal.rs b/src/tridiagonal.rs index 4ac99756..337e9c3c 100644 --- a/src/tridiagonal.rs +++ b/src/tridiagonal.rs @@ -2,8 +2,8 @@ //! & //! Methods for tridiagonal matrices -use ndarray::*; use cauchy::Scalar; +use ndarray::*; use num_traits::One; use crate::opnorm::OperationNorm; @@ -34,10 +34,10 @@ pub struct TriDiagonal { pub trait ToTriDiagonal { /// Extract tridiagonal elements and layout of the raw matrix. /// And also calculate 1-norm. - /// + /// /// If the raw matrix has some non-tridiagonal elements, /// they will be ignored. - /// + /// /// The shape of raw matrix should be equal to or larger than (2, 2). fn to_tridiagonal(&self) -> Result>; } @@ -45,17 +45,19 @@ pub trait ToTriDiagonal { impl ToTriDiagonal for ArrayBase where A: Scalar + Lapack, - S: Data + S: Data, { fn to_tridiagonal(&self) -> Result> { let l = self.square_layout()?; let (n, _) = l.size(); - if n < 2 { panic!("Cannot make a tridiagonal matrix of shape=(1, 1)!"); } + if n < 2 { + panic!("Cannot make a tridiagonal matrix of shape=(1, 1)!"); + } let n1 = self.opnorm_one()?; - let dl = self.slice(s![1..n, 0..n-1]).diag().to_owned(); - let d = self.diag().to_owned(); - let du = self.slice(s![0..n-1, 1..n]).diag().to_owned(); + let dl = self.slice(s![1..n, 0..n - 1]).diag().to_owned(); + let d = self.diag().to_owned(); + let du = self.slice(s![0..n - 1, 1..n]).diag().to_owned(); Ok(TriDiagonal { l, n1, dl, d, du }) } } @@ -73,22 +75,22 @@ pub trait SolveTriDiagonal { b: ArrayBase, ) -> Result>; /// Solves a system of linear equations `A^T * x = b` with tridiagonal - /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result>; /// Solves a system of linear equations `A^T * x = b` with tridiagonal - /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. fn solve_t_tridiagonal_into>( &self, b: ArrayBase, ) -> Result>; /// Solves a system of linear equations `A^H * x = b` with tridiagonal - /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result>; /// Solves a system of linear equations `A^H * x = b` with tridiagonal - /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. fn solve_h_tridiagonal_into>( &self, @@ -155,7 +157,10 @@ where self.solve_tridiagonal_inplace(&mut b)?; Ok(b) } - fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let mut b = replicate(b); self.solve_t_tridiagonal_inplace(&mut b)?; Ok(b) @@ -167,7 +172,10 @@ where self.solve_t_tridiagonal_inplace(&mut b)?; Ok(b) } - fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let mut b = replicate(b); self.solve_h_tridiagonal_inplace(&mut b)?; Ok(b) @@ -186,7 +194,10 @@ where A: Scalar + Lapack, S: Data, { - fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let mut b = replicate(b); self.solve_tridiagonal_inplace(&mut b)?; Ok(b) @@ -198,7 +209,10 @@ where self.solve_tridiagonal_inplace(&mut b)?; Ok(b) } - fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let mut b = replicate(b); self.solve_t_tridiagonal_inplace(&mut b)?; Ok(b) @@ -210,7 +224,10 @@ where self.solve_t_tridiagonal_inplace(&mut b)?; Ok(b) } - fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let mut b = replicate(b); self.solve_h_tridiagonal_inplace(&mut b)?; Ok(b) @@ -334,7 +351,10 @@ where let b = self.solve_tridiagonal_into(b)?; Ok(flatten(b)) } - fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let b = b.to_owned(); self.solve_t_tridiagonal_into(b) } @@ -346,7 +366,10 @@ where let b = self.solve_t_tridiagonal_into(b)?; Ok(flatten(b)) } - fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let b = b.to_owned(); self.solve_h_tridiagonal_into(b) } @@ -365,7 +388,10 @@ where A: Scalar + Lapack, S: Data, { - fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let b = b.to_owned(); self.solve_tridiagonal_into(b) } @@ -378,7 +404,10 @@ where let b = f.solve_tridiagonal_into(b)?; Ok(flatten(b)) } - fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let b = b.to_owned(); self.solve_t_tridiagonal_into(b) } @@ -391,7 +420,10 @@ where let b = f.solve_t_tridiagonal_into(b)?; Ok(flatten(b)) } - fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result> { + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { let b = b.to_owned(); self.solve_h_tridiagonal_into(b) } @@ -429,7 +461,7 @@ where Ok(LUFactorizedTriDiagonal { a: self, du2: du2, - ipiv: ipiv + ipiv: ipiv, }) } } @@ -448,7 +480,7 @@ where impl FactorizeTriDiagonal for ArrayBase where A: Scalar + Lapack, - S: Data + S: Data, { fn factorize_tridiagonal(&self) -> Result> { let mut a = self.to_tridiagonal()?; @@ -462,19 +494,19 @@ where /// where {a_1, a_2, ..., a_n} are diagonal elements, /// {b_1, b_2, ..., b_{n-1}} are super-diagonal elements, and /// {c_1, c_2, ..., c_{n-1}} are sub-diagonal elements of matrix. -/// +/// /// f[n] is used to calculate the determinant. /// (https://en.wikipedia.org/wiki/Tridiagonal_matrix#Determinant) -/// +/// /// In the future, the vector `f` can be used to calculate the inverce matrix. /// (https://en.wikipedia.org/wiki/Tridiagonal_matrix#Inversion) fn rec_rel(tridiag: &TriDiagonal) -> Vec { let n = tridiag.d.shape()[0]; - let mut f = Vec::with_capacity(n+1); + let mut f = Vec::with_capacity(n + 1); f.push(One::one()); f.push(tridiag.d[0]); for i in 1..n { - f.push(tridiag.d[i] * f[i] - tridiag.dl[i-1] * tridiag.du[i-1] * f[i-1]); + f.push(tridiag.d[i] * f[i] - tridiag.dl[i - 1] * tridiag.du[i - 1] * f[i - 1]); } f } @@ -483,7 +515,7 @@ fn rec_rel(tridiag: &TriDiagonal) -> Vec { pub trait DeterminantTriDiagonal { /// Computes the determinant of the matrix. /// Unlike `.det()` of Determinant trait, this method - /// doesn't returns the natural logarithm of the determinant + /// doesn't returns the natural logarithm of the determinant /// but the determinant itself. fn det_tridiagonal(&self) -> Result; } diff --git a/tests/tridiagonal.rs b/tests/tridiagonal.rs index 7fda3f5c..b04e592d 100644 --- a/tests/tridiagonal.rs +++ b/tests/tridiagonal.rs @@ -3,33 +3,37 @@ use ndarray_linalg::*; #[test] fn to_tridiagonal() { - let a: Array2 = arr2(&[[1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - [7.0, 8.0, 9.0]]); + let a: Array2 = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); let t = a.to_tridiagonal().unwrap(); assert_close_l2!(&t.dl, &arr1(&[4.0, 8.0]), 1e-7); - assert_close_l2!(&t.d , &arr1(&[1.0, 5.0, 9.0]), 1e-7); + assert_close_l2!(&t.d, &arr1(&[1.0, 5.0, 9.0]), 1e-7); assert_close_l2!(&t.du, &arr1(&[2.0, 6.0]), 1e-7); } #[test] fn solve_tridiagonal_f64() { // https://www.nag-j.co.jp/lapack/dgttrs.htm - let a: Array2 = arr2(&[[ 3.0, 2.1, 0.0, 0.0, 0.0], - [ 3.4, 2.3, -1.0, 0.0, 0.0], - [ 0.0, 3.6, -5.0, 1.9, 0.0], - [ 0.0, 0.0, 7.0, -0.9, 8.0], - [ 0.0, 0.0, 0.0, -6.0, 7.1]]); - let b: Array2 = arr2(&[[ 2.7, 6.6], - [ -0.5, 10.8], - [ 2.6, -3.2], - [ 0.6, -11.2], - [ 2.7, 19.1]]); - let x: Array2 = arr2(&[[ -4.0, 5.0], - [ 7.0, -4.0], - [ 3.0, -3.0], - [ -4.0, -2.0], - [ -3.0, 1.0]]); + let a: Array2 = arr2(&[ + [3.0, 2.1, 0.0, 0.0, 0.0], + [3.4, 2.3, -1.0, 0.0, 0.0], + [0.0, 3.6, -5.0, 1.9, 0.0], + [0.0, 0.0, 7.0, -0.9, 8.0], + [0.0, 0.0, 0.0, -6.0, 7.1], + ]); + let b: Array2 = arr2(&[ + [2.7, 6.6], + [-0.5, 10.8], + [2.6, -3.2], + [0.6, -11.2], + [2.7, 19.1], + ]); + let x: Array2 = arr2(&[ + [-4.0, 5.0], + [7.0, -4.0], + [3.0, -3.0], + [-4.0, -2.0], + [-3.0, 1.0], + ]); let y = a.solve_tridiagonal_into(b).unwrap(); assert_close_l2!(&x, &y, 1e-7); } @@ -37,21 +41,57 @@ fn solve_tridiagonal_f64() { #[test] fn solve_tridiagonal_c64() { // https://www.nag-j.co.jp/lapack/zgttrs.htm - let a: Array2 = arr2(&[[ c64::new( -1.3, 1.3), c64::new( 2.0, -1.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0)], - [ c64::new( 1.0, -2.0), c64::new( -1.3, 1.3), c64::new( 2.0, 1.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0)], - [ c64::new( 0.0, 0.0), c64::new( 1.0, 1.0), c64::new( -1.3, 3.3), c64::new( -1.0, 1.0), c64::new( 0.0, 0.0)], - [ c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 2.0, -3.0), c64::new( -0.3, 4.3), c64::new( 1.0, -1.0)], - [ c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 1.0, 1.0), c64::new( -3.3, 1.3)]]); - let b: Array2 = arr2(&[[ c64::new( 2.4, -5.0), c64::new( 2.7, 6.9)], - [ c64::new( 3.4, 18.2), c64::new( -6.9, -5.3)], - [ c64::new(-14.7, 9.7), c64::new( -6.0, -0.6)], - [ c64::new( 31.9, -7.7), c64::new( -3.9, 9.3)], - [ c64::new( -1.0, 1.6), c64::new( -3.0, 12.2)]]); - let x: Array2 = arr2(&[[ c64::new( 1.0, 1.0), c64::new( 2.0, -1.0)], - [ c64::new( 3.0, -1.0), c64::new( 1.0, 2.0)], - [ c64::new( 4.0, 5.0), c64::new( -1.0, 1.0)], - [ c64::new( -1.0, -2.0), c64::new( 2.0, 1.0)], - [ c64::new( 1.0, -1.0), c64::new( 2.0, -2.0)]]); + let a: Array2 = arr2(&[ + [ + c64::new(-1.3, 1.3), + c64::new(2.0, -1.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(1.0, -2.0), + c64::new(-1.3, 1.3), + c64::new(2.0, 1.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(1.0, 1.0), + c64::new(-1.3, 3.3), + c64::new(-1.0, 1.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(2.0, -3.0), + c64::new(-0.3, 4.3), + c64::new(1.0, -1.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(1.0, 1.0), + c64::new(-3.3, 1.3), + ], + ]); + let b: Array2 = arr2(&[ + [c64::new(2.4, -5.0), c64::new(2.7, 6.9)], + [c64::new(3.4, 18.2), c64::new(-6.9, -5.3)], + [c64::new(-14.7, 9.7), c64::new(-6.0, -0.6)], + [c64::new(31.9, -7.7), c64::new(-3.9, 9.3)], + [c64::new(-1.0, 1.6), c64::new(-3.0, 12.2)], + ]); + let x: Array2 = arr2(&[ + [c64::new(1.0, 1.0), c64::new(2.0, -1.0)], + [c64::new(3.0, -1.0), c64::new(1.0, 2.0)], + [c64::new(4.0, 5.0), c64::new(-1.0, 1.0)], + [c64::new(-1.0, -2.0), c64::new(2.0, 1.0)], + [c64::new(1.0, -1.0), c64::new(2.0, -2.0)], + ]); let y = a.solve_tridiagonal_into(b).unwrap(); assert_close_l2!(&x, &y, 1e-7); } @@ -87,9 +127,7 @@ fn solve_tridiagonal_random_t() { #[test] fn det_tridiagonal_f64() { - let a: Array2 = arr2(&[[ 10.0, -9.0, 0.0], - [ 7.0, -12.0, 11.0], - [ 0.0, 10.0, 3.0]]); + let a: Array2 = arr2(&[[10.0, -9.0, 0.0], [7.0, -12.0, 11.0], [0.0, 10.0, 3.0]]); assert_aclose!(a.det_tridiagonal().unwrap(), -1271.0, 1e-7); assert_aclose!(a.det_tridiagonal().unwrap(), a.det().unwrap(), 1e-7); } @@ -105,11 +143,13 @@ fn det_tridiagonal_random() { #[test] fn rcond_tridiagonal_f64() { // https://www.nag-j.co.jp/lapack/dgtcon.htm - let a: Array2 = arr2(&[[ 3.0, 2.1, 0.0, 0.0, 0.0], - [ 3.4, 2.3, -1.0, 0.0, 0.0], - [ 0.0, 3.6, -5.0, 1.9, 0.0], - [ 0.0, 0.0, 7.0, -0.9, 8.0], - [ 0.0, 0.0, 0.0, -6.0, 7.1]]); + let a: Array2 = arr2(&[ + [3.0, 2.1, 0.0, 0.0, 0.0], + [3.4, 2.3, -1.0, 0.0, 0.0], + [0.0, 3.6, -5.0, 1.9, 0.0], + [0.0, 0.0, 7.0, -0.9, 8.0], + [0.0, 0.0, 0.0, -6.0, 7.1], + ]); assert_aclose!(1.0 / a.rcond_tridiagonal().unwrap(), 9.27e1, 0.1); assert_aclose!(a.rcond_tridiagonal().unwrap(), a.rcond().unwrap(), 1e-3); } @@ -117,11 +157,43 @@ fn rcond_tridiagonal_f64() { #[test] fn rcond_tridiagonal_c64() { // https://www.nag-j.co.jp/lapack/dgtcon.htm - let a: Array2 = arr2(&[[ c64::new( -1.3, 1.3), c64::new( 2.0, -1.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0)], - [ c64::new( 1.0, -2.0), c64::new( -1.3, 1.3), c64::new( 2.0, 1.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0)], - [ c64::new( 0.0, 0.0), c64::new( 1.0, 1.0), c64::new( -1.3, 3.3), c64::new( -1.0, 1.0), c64::new( 0.0, 0.0)], - [ c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 2.0, -3.0), c64::new( -0.3, 4.3), c64::new( 1.0, -1.0)], - [ c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 0.0, 0.0), c64::new( 1.0, 1.0), c64::new( -3.3, 1.3)]]); + let a: Array2 = arr2(&[ + [ + c64::new(-1.3, 1.3), + c64::new(2.0, -1.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(1.0, -2.0), + c64::new(-1.3, 1.3), + c64::new(2.0, 1.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(1.0, 1.0), + c64::new(-1.3, 3.3), + c64::new(-1.0, 1.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(2.0, -3.0), + c64::new(-0.3, 4.3), + c64::new(1.0, -1.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(1.0, 1.0), + c64::new(-3.3, 1.3), + ], + ]); assert_aclose!(1.0 / a.rcond_tridiagonal().unwrap(), 1.84e2, 1.0); assert_aclose!(a.rcond_tridiagonal().unwrap(), a.rcond().unwrap(), 1e-3); } @@ -134,10 +206,11 @@ fn rcond_tridiagonal_identity() { assert_aclose!(a.rcond_tridiagonal().unwrap(), 1., $atol); }; } - for rows in 2..6 { // cannot make 1x1 tridiagonal matrices. + for rows in 2..6 { + // cannot make 1x1 tridiagonal matrices. rcond_identity!(f64, rows, 1e-9); rcond_identity!(f32, rows, 1e-3); rcond_identity!(c64, rows, 1e-9); rcond_identity!(c32, rows, 1e-3); } -} \ No newline at end of file +} From 5196445c35ccbfa3468a0705bdecff8d4d949432 Mon Sep 17 00:00:00 2001 From: doraneko94 Date: Sun, 14 Jun 2020 15:35:58 +0900 Subject: [PATCH 3/8] impl Solve_Tridiagonal for Tridiagonal --- src/tridiagonal.rs | 141 +++++++++++++++++++++++++++++++++++++++++++ tests/tridiagonal.rs | 15 +++++ 2 files changed, 156 insertions(+) diff --git a/src/tridiagonal.rs b/src/tridiagonal.rs index 337e9c3c..2ffd55cb 100644 --- a/src/tridiagonal.rs +++ b/src/tridiagonal.rs @@ -189,6 +189,57 @@ where } } +impl SolveTriDiagonal for TriDiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } +} + impl SolveTriDiagonal for ArrayBase where A: Scalar + Lapack, @@ -298,6 +349,42 @@ where } } +impl SolveTriDiagonalInplace for TriDiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_tridiagonal_inplace(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_t_tridiagonal_inplace(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_h_tridiagonal_inplace(rhs) + } +} + impl SolveTriDiagonalInplace for ArrayBase where A: Scalar + Lapack, @@ -383,6 +470,60 @@ where } } +impl SolveTriDiagonal for TriDiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_tridiagonal_into(b) + } + fn solve_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_t_tridiagonal_into(b) + } + fn solve_t_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_t_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_h_tridiagonal_into(b) + } + fn solve_h_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_h_tridiagonal_into(b)?; + Ok(flatten(b)) + } +} + impl SolveTriDiagonal for ArrayBase where A: Scalar + Lapack, diff --git a/tests/tridiagonal.rs b/tests/tridiagonal.rs index b04e592d..290ce510 100644 --- a/tests/tridiagonal.rs +++ b/tests/tridiagonal.rs @@ -125,6 +125,21 @@ fn solve_tridiagonal_random_t() { assert_close_l2!(&y1, &y2, 1e-7); } +#[test] +fn to_tridiagonal_solve_random() { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let tridiag = a.to_tridiagonal().unwrap(); + let x: Array1 = random(3); + let b1 = a.dot(&x); + let b2 = b1.clone(); + let y1 = tridiag.solve_tridiagonal_into(b1).unwrap(); + let y2 = a.solve_into(b2).unwrap(); + assert_close_l2!(&x, &y1, 1e-7); + assert_close_l2!(&y1, &y2, 1e-7); +} + #[test] fn det_tridiagonal_f64() { let a: Array2 = arr2(&[[10.0, -9.0, 0.0], [7.0, -12.0, 11.0], [0.0, 10.0, 3.0]]); From 47714387ca85ab7063487a32497adcf5e47cb82c Mon Sep 17 00:00:00 2001 From: doraneko94 Date: Fri, 19 Jun 2020 12:19:51 +0900 Subject: [PATCH 4/8] impl Index/IndexMut & impl opnorm --- src/lapack/tridiagonal.rs | 14 ++++-- src/opnorm.rs | 53 ++++++++++++++++++++++ src/tridiagonal.rs | 93 ++++++++++++++++++++++++++++++++------- tests/tridiagonal.rs | 31 +++++++++++++ 4 files changed, 171 insertions(+), 20 deletions(-) diff --git a/src/lapack/tridiagonal.rs b/src/lapack/tridiagonal.rs index 71ccd308..26aeceb1 100644 --- a/src/lapack/tridiagonal.rs +++ b/src/lapack/tridiagonal.rs @@ -10,6 +10,7 @@ use super::{into_result, Pivot, Transpose}; use crate::error::*; use crate::layout::MatrixLayout; +use crate::opnorm::*; use crate::tridiagonal::{LUFactorizedTriDiagonal, TriDiagonal}; use crate::types::*; @@ -17,7 +18,9 @@ use crate::types::*; pub trait TriDiagonal_: Scalar + Sized { /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using /// partial pivoting with row interchanges. - unsafe fn lu_tridiagonal(a: &mut TriDiagonal) -> Result<(Array1, Pivot)>; + unsafe fn lu_tridiagonal( + a: &mut TriDiagonal, + ) -> Result<(Array1, Self::Real, Pivot)>; /// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm. unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal) -> Result; unsafe fn solve_tridiagonal( @@ -31,15 +34,18 @@ pub trait TriDiagonal_: Scalar + Sized { macro_rules! impl_tridiagonal { ($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { impl TriDiagonal_ for $scalar { - unsafe fn lu_tridiagonal(a: &mut TriDiagonal) -> Result<(Array1, Pivot)> { + unsafe fn lu_tridiagonal( + a: &mut TriDiagonal, + ) -> Result<(Array1, Self::Real, Pivot)> { let (n, _) = a.l.size(); + let anom = a.opnorm_one()?; let dl = a.dl.as_slice_mut().unwrap(); let d = a.d.as_slice_mut().unwrap(); let du = a.du.as_slice_mut().unwrap(); let mut du2 = vec![Zero::zero(); (n - 2) as usize]; let mut ipiv = vec![0; n as usize]; let info = $gttrf(n, dl, d, du, &mut du2, &mut ipiv); - into_result(info, (arr1(&du2), ipiv)) + into_result(info, (arr1(&du2), anom, ipiv)) } unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal) -> Result { @@ -49,7 +55,7 @@ macro_rules! impl_tridiagonal { let du = lu.a.du.as_slice().unwrap(); let du2 = lu.du2.as_slice().unwrap(); let ipiv = &lu.ipiv; - let anorm = lu.a.n1; + let anorm = lu.anom; let mut rcond = Self::Real::zero(); let info = $gtcon( NormType::One as u8, diff --git a/src/opnorm.rs b/src/opnorm.rs index 37d38115..13080fc5 100644 --- a/src/opnorm.rs +++ b/src/opnorm.rs @@ -2,8 +2,10 @@ use ndarray::*; +use crate::convert::*; use crate::error::*; use crate::layout::*; +use crate::tridiagonal::TriDiagonal; use crate::types::*; pub use crate::lapack::NormType; @@ -46,3 +48,54 @@ where Ok(unsafe { A::opnorm(t, l, a) }) } } + +impl OperationNorm for TriDiagonal +where + A: Scalar + Lapack, +{ + type Output = A::Real; + + fn opnorm(&self, t: NormType) -> Result { + let arr = match t { + NormType::One => { + let zl: Array1 = Array::zeros(1); + let zu: Array1 = Array::zeros(1); + let dl = stack![Axis(0), self.dl.to_owned(), zl]; + let du = stack![Axis(0), zu, self.du.to_owned()]; + let arr = stack![ + Axis(0), + into_row(du), + into_row(self.d.to_owned()), + into_row(dl) + ]; + arr + } + NormType::Infinity => { + let zl: Array1 = Array::zeros(1); + let zu: Array1 = Array::zeros(1); + let dl = stack![Axis(0), zl, self.dl.to_owned()]; + let du = stack![Axis(0), self.du.to_owned(), zu]; + let arr = stack![ + Axis(1), + into_col(dl), + into_col(self.d.to_owned()), + into_col(du) + ]; + arr + } + NormType::Frobenius => { + let arr = stack![ + Axis(1), + into_row(self.dl.to_owned()), + into_row(self.d.to_owned()), + into_row(self.du.to_owned()) + ]; + arr + } + }; + + let l = arr.layout()?; + let a = arr.as_allocated()?; + Ok(unsafe { A::opnorm(t, l, a) }) + } +} diff --git a/src/tridiagonal.rs b/src/tridiagonal.rs index 2ffd55cb..96aa318b 100644 --- a/src/tridiagonal.rs +++ b/src/tridiagonal.rs @@ -2,26 +2,23 @@ //! & //! Methods for tridiagonal matrices +use std::ops::{Index, IndexMut}; + use cauchy::Scalar; use ndarray::*; use num_traits::One; -use crate::opnorm::OperationNorm; - use super::convert::*; use super::error::*; use super::lapack::*; use super::layout::*; /// Represents a tridiagonal matrix as 3 one-dimensional vectors. -/// This struct also holds the layout and 1-norm of the raw matrix -/// for some methods (eg. rcond_tridiagonal()). -#[derive(Clone)] +/// This struct also holds the layout of the raw matrix. +#[derive(Clone, PartialEq)] pub struct TriDiagonal { /// layout of raw matrix pub l: MatrixLayout, - /// the one norm of raw matrix - pub n1: ::Real, /// (n-1) sub-diagonal elements of matrix. pub dl: Array1, /// (n) diagonal elements of matrix. @@ -30,10 +27,73 @@ pub struct TriDiagonal { pub du: Array1, } +pub trait TridiagIndex { + fn to_tuple(&self) -> (i32, i32); +} +impl TridiagIndex for [Ix; 2] { + fn to_tuple(&self) -> (i32, i32) { + (self[0] as i32, self[1] as i32) + } +} + +fn debug_bounds_check_tridiag(n: i32, row: i32, col: i32) { + if std::cmp::max(row, col) >= n { + panic!( + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); + } +} + +impl Index for TriDiagonal +where + A: Scalar, + I: TridiagIndex, +{ + type Output = A; + #[inline] + fn index(&self, index: I) -> &A { + let (n, _) = self.l.size(); + let (row, col) = index.to_tuple(); + debug_bounds_check_tridiag(n, row, col); + match row - col { + 0 => &self.d[row as usize], + 1 => &self.dl[col as usize], + -1 => &self.du[row as usize], + _ => panic!( + "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", + [row, col] + ), + } + } +} + +impl IndexMut for TriDiagonal +where + A: Scalar, + I: TridiagIndex, +{ + #[inline] + fn index_mut(&mut self, index: I) -> &mut A { + let (n, _) = self.l.size(); + let (row, col) = index.to_tuple(); + debug_bounds_check_tridiag(n, row, col); + match row - col { + 0 => &mut self.d[row as usize], + 1 => &mut self.dl[col as usize], + -1 => &mut self.du[row as usize], + _ => panic!( + "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", + [row, col] + ), + } + } +} + /// An interface for making a TriDiagonal struct. pub trait ToTriDiagonal { /// Extract tridiagonal elements and layout of the raw matrix. - /// And also calculate 1-norm. /// /// If the raw matrix has some non-tridiagonal elements, /// they will be ignored. @@ -53,12 +113,11 @@ where if n < 2 { panic!("Cannot make a tridiagonal matrix of shape=(1, 1)!"); } - let n1 = self.opnorm_one()?; let dl = self.slice(s![1..n, 0..n - 1]).diag().to_owned(); let d = self.diag().to_owned(); let du = self.slice(s![0..n - 1, 1..n]).diag().to_owned(); - Ok(TriDiagonal { l, n1, dl, d, du }) + Ok(TriDiagonal { l, dl, d, du }) } } @@ -130,13 +189,14 @@ pub trait SolveTriDiagonalInplace { pub struct LUFactorizedTriDiagonal { /// A tridiagonal matrix which consists of /// - l : layout of raw matrix - /// - n1: the one norm of raw matrix /// - dl: (n-1) multipliers that define the matrix L. /// - d : (n) diagonal elements of the upper triangular matrix U. /// - du: (n-1) elements of the first super-diagonal of U. pub a: TriDiagonal, /// (n-2) elements of the second super-diagonal of U. pub du2: Array1, + /// 1-norm of raw matrix (used in .rcond_tridiagonal()). + pub anom: A::Real, /// The pivot indices that define the permutation matrix `P`. pub ipiv: Pivot, } @@ -598,10 +658,11 @@ where A: Scalar + Lapack, { fn factorize_tridiagonal_into(mut self) -> Result> { - let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? }; + let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? }; Ok(LUFactorizedTriDiagonal { a: self, du2: du2, + anom: anom, ipiv: ipiv, }) } @@ -613,8 +674,8 @@ where { fn factorize_tridiagonal(&self) -> Result> { let mut a = self.clone(); - let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; - Ok(LUFactorizedTriDiagonal { a, du2, ipiv }) + let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; + Ok(LUFactorizedTriDiagonal { a, du2, anom, ipiv }) } } @@ -625,8 +686,8 @@ where { fn factorize_tridiagonal(&self) -> Result> { let mut a = self.to_tridiagonal()?; - let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; - Ok(LUFactorizedTriDiagonal { a, du2, ipiv }) + let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; + Ok(LUFactorizedTriDiagonal { a, du2, anom, ipiv }) } } diff --git a/tests/tridiagonal.rs b/tests/tridiagonal.rs index 290ce510..19e66f60 100644 --- a/tests/tridiagonal.rs +++ b/tests/tridiagonal.rs @@ -10,6 +10,37 @@ fn to_tridiagonal() { assert_close_l2!(&t.du, &arr1(&[2.0, 6.0]), 1e-7); } +#[test] +fn tridiagonal_index() { + let a: Array2 = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); + let t1 = a.to_tridiagonal().unwrap(); + let mut t2 = Array2::::eye(3).to_tridiagonal().unwrap(); + t2[[0, 1]] = 2.0; + t2[[1, 0]] = 4.0; + t2[[1, 1]] += 4.0; + t2[[1, 2]] = 6.0; + t2[[2, 1]] = 8.0; + t2[[2, 2]] += 8.0; + assert_eq!(t1.dl, t2.dl); + assert_eq!(t1.d, t2.d); + assert_eq!(t1.du, t2.du); +} + +#[test] +fn opnorm_tridiagonal() { + let mut a: Array2 = random((4, 4)); + a[[0, 2]] = 0.0; + a[[0, 3]] = 0.0; + a[[1, 3]] = 0.0; + a[[2, 0]] = 0.0; + a[[3, 0]] = 0.0; + a[[3, 1]] = 0.0; + let t = a.to_tridiagonal().unwrap(); + assert_aclose!(a.opnorm_one().unwrap(), t.opnorm_one().unwrap(), 1e-7); + assert_aclose!(a.opnorm_inf().unwrap(), t.opnorm_inf().unwrap(), 1e-7); + assert_aclose!(a.opnorm_fro().unwrap(), t.opnorm_fro().unwrap(), 1e-7); +} + #[test] fn solve_tridiagonal_f64() { // https://www.nag-j.co.jp/lapack/dgttrs.htm From 4e438176796d9074ce386bc12b5b9f01d45f0a52 Mon Sep 17 00:00:00 2001 From: doraneko94 Date: Fri, 19 Jun 2020 15:46:50 +0900 Subject: [PATCH 5/8] add comments on opnorm --- src/opnorm.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/opnorm.rs b/src/opnorm.rs index 13080fc5..01a0d5c3 100644 --- a/src/opnorm.rs +++ b/src/opnorm.rs @@ -56,7 +56,18 @@ where type Output = A::Real; fn opnorm(&self, t: NormType) -> Result { + // `self` is a tridiagonal matrix like, + // [d0, u1, 0, ..., 0, + // l1, d1, u2, ..., + // 0, l2, d2, + // ... ..., u{n-1}, + // 0, ..., l{n-1}, d{n-1},] let arr = match t { + // opnorm_one() calculates muximum column sum. + // Therefore, This part align the columns and make a (3 x n) matrix like, + // [ 0, u1, u2, ..., u{n-1}, + // d0, d1, d2, ..., d{n-1}, + // l1, l2, l3, ..., 0,] NormType::One => { let zl: Array1 = Array::zeros(1); let zu: Array1 = Array::zeros(1); @@ -70,6 +81,13 @@ where ]; arr } + // opnorm_inf() calculates muximum row sum. + // Therefore, This part align the rows and make a (n x 3) matrix like, + // [ 0, d0, u1, + // l1, d1, u2, + // l2, d2, u3, + // ..., ..., ..., + // l{n-1}, d{n-1}, 0,] NormType::Infinity => { let zl: Array1 = Array::zeros(1); let zu: Array1 = Array::zeros(1); @@ -83,6 +101,10 @@ where ]; arr } + // opnorm_fro() calculates square root of sum of squares. + // Because it is independent of the shape of matrix, + // this part make a (1 x (3n-2)) matrix like, + // [l1, ..., l{n-1}, d0, ..., d{n-1}, u1, ..., u{n-1}] NormType::Frobenius => { let arr = stack![ Axis(1), From 0ff92e47edc578e7f0d4be3a1766ab08f9e05264 Mon Sep 17 00:00:00 2001 From: doraneko94 Date: Fri, 19 Jun 2020 15:47:57 +0900 Subject: [PATCH 6/8] cargo fmt --- src/opnorm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/opnorm.rs b/src/opnorm.rs index 01a0d5c3..487110f2 100644 --- a/src/opnorm.rs +++ b/src/opnorm.rs @@ -59,7 +59,7 @@ where // `self` is a tridiagonal matrix like, // [d0, u1, 0, ..., 0, // l1, d1, u2, ..., - // 0, l2, d2, + // 0, l2, d2, // ... ..., u{n-1}, // 0, ..., l{n-1}, d{n-1},] let arr = match t { From 4624ff2a22c879c696e158126764c9644f990066 Mon Sep 17 00:00:00 2001 From: doraneko94 Date: Sat, 27 Jun 2020 16:26:46 +0900 Subject: [PATCH 7/8] fix based on code review --- examples/tridiagonal.rs | 4 ++ src/error.rs | 11 +++ src/lapack/mod.rs | 2 +- src/lapack/tridiagonal.rs | 54 ++++++-------- src/opnorm.rs | 24 ++----- src/tridiagonal.rs | 144 ++++++++++++++++++++------------------ tests/tridiagonal.rs | 20 +++--- 7 files changed, 127 insertions(+), 132 deletions(-) diff --git a/examples/tridiagonal.rs b/examples/tridiagonal.rs index 676bfb28..cffa3173 100644 --- a/examples/tridiagonal.rs +++ b/examples/tridiagonal.rs @@ -27,4 +27,8 @@ fn factorize() -> Result<(), error::LinalgError> { fn main() { solve().unwrap(); factorize().unwrap(); + match arr2(&[[0.0]]).extract_tridiagonal() { + Ok(_) => {} + Err(err) => println!("{}", err), + } } diff --git a/src/error.rs b/src/error.rs index 4e487acf..c6524897 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,6 +17,12 @@ pub enum LinalgError { InvalidStride { s0: Ixs, s1: Ixs }, /// Memory is not aligned continously MemoryNotCont, + /// Obj cannot be made from a (rows, cols) matrix + NotStandardShape { + obj: &'static str, + rows: i32, + cols: i32, + }, /// Strides of the array is not supported Shape(ShapeError), } @@ -34,6 +40,11 @@ impl fmt::Display for LinalgError { write!(f, "invalid stride: s0={}, s1={}", s0, s1) } LinalgError::MemoryNotCont => write!(f, "Memory is not contiguous"), + LinalgError::NotStandardShape { obj, rows, cols } => write!( + f, + "{} cannot be made from a ({}, {}) matrix", + obj, rows, cols + ), LinalgError::Shape(err) => write!(f, "Shape Error: {}", err), } } diff --git a/src/lapack/mod.rs b/src/lapack/mod.rs index 540fef64..a3482505 100644 --- a/src/lapack/mod.rs +++ b/src/lapack/mod.rs @@ -41,7 +41,7 @@ pub trait Lapack: + Eig_ + Eigh_ + Triangular_ - + TriDiagonal_ + + Tridiagonal_ { } diff --git a/src/lapack/tridiagonal.rs b/src/lapack/tridiagonal.rs index 26aeceb1..468cc8ea 100644 --- a/src/lapack/tridiagonal.rs +++ b/src/lapack/tridiagonal.rs @@ -2,7 +2,6 @@ //! for tridiagonal matrix use lapacke; -use ndarray::*; use num_traits::Zero; use super::NormType; @@ -11,20 +10,18 @@ use super::{into_result, Pivot, Transpose}; use crate::error::*; use crate::layout::MatrixLayout; use crate::opnorm::*; -use crate::tridiagonal::{LUFactorizedTriDiagonal, TriDiagonal}; +use crate::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal}; use crate::types::*; /// Wraps `*gttrf`, `*gtcon` and `*gttrs` -pub trait TriDiagonal_: Scalar + Sized { +pub trait Tridiagonal_: Scalar + Sized { /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using /// partial pivoting with row interchanges. - unsafe fn lu_tridiagonal( - a: &mut TriDiagonal, - ) -> Result<(Array1, Self::Real, Pivot)>; + unsafe fn lu_tridiagonal(a: &mut Tridiagonal) -> Result<(Vec, Self::Real, Pivot)>; /// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm. - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal) -> Result; + unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; unsafe fn solve_tridiagonal( - lu: &LUFactorizedTriDiagonal, + lu: &LUFactorizedTridiagonal, bl: MatrixLayout, t: Transpose, b: &mut [Self], @@ -33,37 +30,30 @@ pub trait TriDiagonal_: Scalar + Sized { macro_rules! impl_tridiagonal { ($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { - impl TriDiagonal_ for $scalar { + impl Tridiagonal_ for $scalar { unsafe fn lu_tridiagonal( - a: &mut TriDiagonal, - ) -> Result<(Array1, Self::Real, Pivot)> { + a: &mut Tridiagonal, + ) -> Result<(Vec, Self::Real, Pivot)> { let (n, _) = a.l.size(); let anom = a.opnorm_one()?; - let dl = a.dl.as_slice_mut().unwrap(); - let d = a.d.as_slice_mut().unwrap(); - let du = a.du.as_slice_mut().unwrap(); let mut du2 = vec![Zero::zero(); (n - 2) as usize]; let mut ipiv = vec![0; n as usize]; - let info = $gttrf(n, dl, d, du, &mut du2, &mut ipiv); - into_result(info, (arr1(&du2), anom, ipiv)) + let info = $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv); + into_result(info, (du2, anom, ipiv)) } - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal) -> Result { + unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); - let dl = lu.a.dl.as_slice().unwrap(); - let d = lu.a.d.as_slice().unwrap(); - let du = lu.a.du.as_slice().unwrap(); - let du2 = lu.du2.as_slice().unwrap(); let ipiv = &lu.ipiv; let anorm = lu.anom; let mut rcond = Self::Real::zero(); let info = $gtcon( NormType::One as u8, n, - dl, - d, - du, - du2, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, ipiv, anorm, &mut rcond, @@ -72,17 +62,13 @@ macro_rules! impl_tridiagonal { } unsafe fn solve_tridiagonal( - lu: &LUFactorizedTriDiagonal, + lu: &LUFactorizedTridiagonal, bl: MatrixLayout, t: Transpose, b: &mut [Self], ) -> Result<()> { let (n, _) = lu.a.l.size(); let (_, nrhs) = bl.size(); - let dl = lu.a.dl.as_slice().unwrap(); - let d = lu.a.d.as_slice().unwrap(); - let du = lu.a.du.as_slice().unwrap(); - let du2 = lu.du2.as_slice().unwrap(); let ipiv = &lu.ipiv; let ldb = bl.lda(); let info = $gttrs( @@ -90,10 +76,10 @@ macro_rules! impl_tridiagonal { t as u8, n, nrhs, - dl, - d, - du, - du2, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, ipiv, b, ldb, diff --git a/src/opnorm.rs b/src/opnorm.rs index 487110f2..4358fd7f 100644 --- a/src/opnorm.rs +++ b/src/opnorm.rs @@ -5,7 +5,7 @@ use ndarray::*; use crate::convert::*; use crate::error::*; use crate::layout::*; -use crate::tridiagonal::TriDiagonal; +use crate::tridiagonal::Tridiagonal; use crate::types::*; pub use crate::lapack::NormType; @@ -49,7 +49,7 @@ where } } -impl OperationNorm for TriDiagonal +impl OperationNorm for Tridiagonal where A: Scalar + Lapack, { @@ -73,12 +73,7 @@ where let zu: Array1 = Array::zeros(1); let dl = stack![Axis(0), self.dl.to_owned(), zl]; let du = stack![Axis(0), zu, self.du.to_owned()]; - let arr = stack![ - Axis(0), - into_row(du), - into_row(self.d.to_owned()), - into_row(dl) - ]; + let arr = stack![Axis(0), into_row(du), into_row(arr1(&self.d)), into_row(dl)]; arr } // opnorm_inf() calculates muximum row sum. @@ -93,12 +88,7 @@ where let zu: Array1 = Array::zeros(1); let dl = stack![Axis(0), zl, self.dl.to_owned()]; let du = stack![Axis(0), self.du.to_owned(), zu]; - let arr = stack![ - Axis(1), - into_col(dl), - into_col(self.d.to_owned()), - into_col(du) - ]; + let arr = stack![Axis(1), into_col(dl), into_col(arr1(&self.d)), into_col(du)]; arr } // opnorm_fro() calculates square root of sum of squares. @@ -108,9 +98,9 @@ where NormType::Frobenius => { let arr = stack![ Axis(1), - into_row(self.dl.to_owned()), - into_row(self.d.to_owned()), - into_row(self.du.to_owned()) + into_row(arr1(&self.dl)), + into_row(arr1(&self.d)), + into_row(arr1(&self.du)) ]; arr } diff --git a/src/tridiagonal.rs b/src/tridiagonal.rs index 96aa318b..155aa463 100644 --- a/src/tridiagonal.rs +++ b/src/tridiagonal.rs @@ -1,4 +1,4 @@ -//! Vectors as a TriDiagonal matrix +//! Vectors as a Tridiagonal matrix //! & //! Methods for tridiagonal matrices @@ -16,15 +16,15 @@ use super::layout::*; /// Represents a tridiagonal matrix as 3 one-dimensional vectors. /// This struct also holds the layout of the raw matrix. #[derive(Clone, PartialEq)] -pub struct TriDiagonal { +pub struct Tridiagonal { /// layout of raw matrix pub l: MatrixLayout, /// (n-1) sub-diagonal elements of matrix. - pub dl: Array1, + pub dl: Vec, /// (n) diagonal elements of matrix. - pub d: Array1, + pub d: Vec, /// (n-1) super-diagonal elements of matrix. - pub du: Array1, + pub du: Vec, } pub trait TridiagIndex { @@ -36,17 +36,7 @@ impl TridiagIndex for [Ix; 2] { } } -fn debug_bounds_check_tridiag(n: i32, row: i32, col: i32) { - if std::cmp::max(row, col) >= n { - panic!( - "ndarray: index {:?} is out of bounds for array of shape {}", - [row, col], - n - ); - } -} - -impl Index for TriDiagonal +impl Index for Tridiagonal where A: Scalar, I: TridiagIndex, @@ -56,7 +46,12 @@ where fn index(&self, index: I) -> &A { let (n, _) = self.l.size(); let (row, col) = index.to_tuple(); - debug_bounds_check_tridiag(n, row, col); + assert!( + std::cmp::max(row, col) < n, + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); match row - col { 0 => &self.d[row as usize], 1 => &self.dl[col as usize], @@ -69,7 +64,7 @@ where } } -impl IndexMut for TriDiagonal +impl IndexMut for Tridiagonal where A: Scalar, I: TridiagIndex, @@ -78,7 +73,12 @@ where fn index_mut(&mut self, index: I) -> &mut A { let (n, _) = self.l.size(); let (row, col) = index.to_tuple(); - debug_bounds_check_tridiag(n, row, col); + assert!( + std::cmp::max(row, col) < n, + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); match row - col { 0 => &mut self.d[row as usize], 1 => &mut self.dl[col as usize], @@ -91,37 +91,41 @@ where } } -/// An interface for making a TriDiagonal struct. -pub trait ToTriDiagonal { +/// An interface for making a Tridiagonal struct. +pub trait ToTridiagonal { /// Extract tridiagonal elements and layout of the raw matrix. /// /// If the raw matrix has some non-tridiagonal elements, /// they will be ignored. /// /// The shape of raw matrix should be equal to or larger than (2, 2). - fn to_tridiagonal(&self) -> Result>; + fn extract_tridiagonal(&self) -> Result>; } -impl ToTriDiagonal for ArrayBase +impl ToTridiagonal for ArrayBase where A: Scalar + Lapack, S: Data, { - fn to_tridiagonal(&self) -> Result> { + fn extract_tridiagonal(&self) -> Result> { let l = self.square_layout()?; let (n, _) = l.size(); if n < 2 { - panic!("Cannot make a tridiagonal matrix of shape=(1, 1)!"); + return Err(LinalgError::NotStandardShape { + obj: "Tridiagonal", + rows: 1, + cols: 1, + }); } - let dl = self.slice(s![1..n, 0..n - 1]).diag().to_owned(); - let d = self.diag().to_owned(); - let du = self.slice(s![0..n - 1, 1..n]).diag().to_owned(); - Ok(TriDiagonal { l, dl, d, du }) + let dl = self.slice(s![1..n, 0..n - 1]).diag().to_vec(); + let d = self.diag().to_vec(); + let du = self.slice(s![0..n - 1, 1..n]).diag().to_vec(); + Ok(Tridiagonal { l, dl, d, du }) } } -pub trait SolveTriDiagonal { +pub trait SolveTridiagonal { /// Solves a system of linear equations `A * x = b` with tridiagonal /// matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. @@ -157,7 +161,7 @@ pub trait SolveTriDiagonal { ) -> Result>; } -pub trait SolveTriDiagonalInplace { +pub trait SolveTridiagonalInplace { /// Solves a system of linear equations `A * x = b` tridiagonal /// matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. The value of `x` is also assigned to the @@ -186,22 +190,22 @@ pub trait SolveTriDiagonalInplace { /// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`. #[derive(Clone)] -pub struct LUFactorizedTriDiagonal { +pub struct LUFactorizedTridiagonal { /// A tridiagonal matrix which consists of /// - l : layout of raw matrix /// - dl: (n-1) multipliers that define the matrix L. /// - d : (n) diagonal elements of the upper triangular matrix U. /// - du: (n-1) elements of the first super-diagonal of U. - pub a: TriDiagonal, + pub a: Tridiagonal, /// (n-2) elements of the second super-diagonal of U. - pub du2: Array1, + pub du2: Vec, /// 1-norm of raw matrix (used in .rcond_tridiagonal()). pub anom: A::Real, /// The pivot indices that define the permutation matrix `P`. pub ipiv: Pivot, } -impl SolveTriDiagonal for LUFactorizedTriDiagonal +impl SolveTridiagonal for LUFactorizedTridiagonal where A: Scalar + Lapack, { @@ -249,7 +253,7 @@ where } } -impl SolveTriDiagonal for TriDiagonal +impl SolveTridiagonal for Tridiagonal where A: Scalar + Lapack, { @@ -300,7 +304,7 @@ where } } -impl SolveTriDiagonal for ArrayBase +impl SolveTridiagonal for ArrayBase where A: Scalar + Lapack, S: Data, @@ -352,7 +356,7 @@ where } } -impl SolveTriDiagonalInplace for LUFactorizedTriDiagonal +impl SolveTridiagonalInplace for LUFactorizedTridiagonal where A: Scalar + Lapack, { @@ -409,7 +413,7 @@ where } } -impl SolveTriDiagonalInplace for TriDiagonal +impl SolveTridiagonalInplace for Tridiagonal where A: Scalar + Lapack, { @@ -445,7 +449,7 @@ where } } -impl SolveTriDiagonalInplace for ArrayBase +impl SolveTridiagonalInplace for ArrayBase where A: Scalar + Lapack, S: Data, @@ -482,7 +486,7 @@ where } } -impl SolveTriDiagonal for LUFactorizedTriDiagonal +impl SolveTridiagonal for LUFactorizedTridiagonal where A: Scalar + Lapack, { @@ -530,7 +534,7 @@ where } } -impl SolveTriDiagonal for TriDiagonal +impl SolveTridiagonal for Tridiagonal where A: Scalar + Lapack, { @@ -584,7 +588,7 @@ where } } -impl SolveTriDiagonal for ArrayBase +impl SolveTridiagonal for ArrayBase where A: Scalar + Lapack, S: Data, @@ -640,26 +644,26 @@ where } /// An interface for computing LU factorizations of tridiagonal matrix refs. -pub trait FactorizeTriDiagonal { +pub trait FactorizeTridiagonal { /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation /// matrix. - fn factorize_tridiagonal(&self) -> Result>; + fn factorize_tridiagonal(&self) -> Result>; } /// An interface for computing LU factorizations of tridiagonal matrices. -pub trait FactorizeTriDiagonalInto { +pub trait FactorizeTridiagonalInto { /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation /// matrix. - fn factorize_tridiagonal_into(self) -> Result>; + fn factorize_tridiagonal_into(self) -> Result>; } -impl FactorizeTriDiagonalInto for TriDiagonal +impl FactorizeTridiagonalInto for Tridiagonal where A: Scalar + Lapack, { - fn factorize_tridiagonal_into(mut self) -> Result> { + fn factorize_tridiagonal_into(mut self) -> Result> { let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? }; - Ok(LUFactorizedTriDiagonal { + Ok(LUFactorizedTridiagonal { a: self, du2: du2, anom: anom, @@ -668,26 +672,26 @@ where } } -impl FactorizeTriDiagonal for TriDiagonal +impl FactorizeTridiagonal for Tridiagonal where A: Scalar + Lapack, { - fn factorize_tridiagonal(&self) -> Result> { + fn factorize_tridiagonal(&self) -> Result> { let mut a = self.clone(); let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; - Ok(LUFactorizedTriDiagonal { a, du2, anom, ipiv }) + Ok(LUFactorizedTridiagonal { a, du2, anom, ipiv }) } } -impl FactorizeTriDiagonal for ArrayBase +impl FactorizeTridiagonal for ArrayBase where A: Scalar + Lapack, S: Data, { - fn factorize_tridiagonal(&self) -> Result> { - let mut a = self.to_tridiagonal()?; + fn factorize_tridiagonal(&self) -> Result> { + let mut a = self.extract_tridiagonal()?; let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; - Ok(LUFactorizedTriDiagonal { a, du2, anom, ipiv }) + Ok(LUFactorizedTridiagonal { a, du2, anom, ipiv }) } } @@ -702,8 +706,8 @@ where /// /// In the future, the vector `f` can be used to calculate the inverce matrix. /// (https://en.wikipedia.org/wiki/Tridiagonal_matrix#Inversion) -fn rec_rel(tridiag: &TriDiagonal) -> Vec { - let n = tridiag.d.shape()[0]; +fn rec_rel(tridiag: &Tridiagonal) -> Vec { + let n = tridiag.d.len(); let mut f = Vec::with_capacity(n + 1); f.push(One::one()); f.push(tridiag.d[0]); @@ -714,7 +718,7 @@ fn rec_rel(tridiag: &TriDiagonal) -> Vec { } /// An interface for calculating determinants of tridiagonal matrix refs. -pub trait DeterminantTriDiagonal { +pub trait DeterminantTridiagonal { /// Computes the determinant of the matrix. /// Unlike `.det()` of Determinant trait, this method /// doesn't returns the natural logarithm of the determinant @@ -722,30 +726,30 @@ pub trait DeterminantTriDiagonal { fn det_tridiagonal(&self) -> Result; } -impl DeterminantTriDiagonal for TriDiagonal +impl DeterminantTridiagonal for Tridiagonal where A: Scalar, { fn det_tridiagonal(&self) -> Result { - let n = self.d.shape()[0]; + let n = self.d.len(); Ok(rec_rel(&self)[n]) } } -impl DeterminantTriDiagonal for ArrayBase +impl DeterminantTridiagonal for ArrayBase where A: Scalar + Lapack, S: Data, { fn det_tridiagonal(&self) -> Result { - let tridiag = self.to_tridiagonal()?; - let n = tridiag.d.shape()[0]; + let tridiag = self.extract_tridiagonal()?; + let n = tridiag.d.len(); Ok(rec_rel(&tridiag)[n]) } } /// An interface for *estimating* the reciprocal condition number of tridiagonal matrix refs. -pub trait ReciprocalConditionNumTriDiagonal { +pub trait ReciprocalConditionNumTridiagonal { /// *Estimates* the reciprocal of the condition number of the tridiagonal matrix in /// 1-norm. /// @@ -759,7 +763,7 @@ pub trait ReciprocalConditionNumTriDiagonal { } /// An interface for *estimating* the reciprocal condition number of tridiagonal matrices. -pub trait ReciprocalConditionNumTriDiagonalInto { +pub trait ReciprocalConditionNumTridiagonalInto { /// *Estimates* the reciprocal of the condition number of the tridiagonal matrix in /// 1-norm. /// @@ -772,7 +776,7 @@ pub trait ReciprocalConditionNumTriDiagonalInto { fn rcond_tridiagonal_into(self) -> Result; } -impl ReciprocalConditionNumTriDiagonal for LUFactorizedTriDiagonal +impl ReciprocalConditionNumTridiagonal for LUFactorizedTridiagonal where A: Scalar + Lapack, { @@ -781,7 +785,7 @@ where } } -impl ReciprocalConditionNumTriDiagonalInto for LUFactorizedTriDiagonal +impl ReciprocalConditionNumTridiagonalInto for LUFactorizedTridiagonal where A: Scalar + Lapack, { @@ -790,7 +794,7 @@ where } } -impl ReciprocalConditionNumTriDiagonal for ArrayBase +impl ReciprocalConditionNumTridiagonal for ArrayBase where A: Scalar + Lapack, S: Data, diff --git a/tests/tridiagonal.rs b/tests/tridiagonal.rs index 19e66f60..38278951 100644 --- a/tests/tridiagonal.rs +++ b/tests/tridiagonal.rs @@ -2,19 +2,19 @@ use ndarray::*; use ndarray_linalg::*; #[test] -fn to_tridiagonal() { +fn extract_tridiagonal() { let a: Array2 = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); - let t = a.to_tridiagonal().unwrap(); - assert_close_l2!(&t.dl, &arr1(&[4.0, 8.0]), 1e-7); - assert_close_l2!(&t.d, &arr1(&[1.0, 5.0, 9.0]), 1e-7); - assert_close_l2!(&t.du, &arr1(&[2.0, 6.0]), 1e-7); + let t = a.extract_tridiagonal().unwrap(); + assert_close_l2!(&arr1(&t.dl), &arr1(&[4.0, 8.0]), 1e-7); + assert_close_l2!(&arr1(&t.d), &arr1(&[1.0, 5.0, 9.0]), 1e-7); + assert_close_l2!(&arr1(&t.du), &arr1(&[2.0, 6.0]), 1e-7); } #[test] fn tridiagonal_index() { let a: Array2 = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); - let t1 = a.to_tridiagonal().unwrap(); - let mut t2 = Array2::::eye(3).to_tridiagonal().unwrap(); + let t1 = a.extract_tridiagonal().unwrap(); + let mut t2 = Array2::::eye(3).extract_tridiagonal().unwrap(); t2[[0, 1]] = 2.0; t2[[1, 0]] = 4.0; t2[[1, 1]] += 4.0; @@ -35,7 +35,7 @@ fn opnorm_tridiagonal() { a[[2, 0]] = 0.0; a[[3, 0]] = 0.0; a[[3, 1]] = 0.0; - let t = a.to_tridiagonal().unwrap(); + let t = a.extract_tridiagonal().unwrap(); assert_aclose!(a.opnorm_one().unwrap(), t.opnorm_one().unwrap(), 1e-7); assert_aclose!(a.opnorm_inf().unwrap(), t.opnorm_inf().unwrap(), 1e-7); assert_aclose!(a.opnorm_fro().unwrap(), t.opnorm_fro().unwrap(), 1e-7); @@ -157,11 +157,11 @@ fn solve_tridiagonal_random_t() { } #[test] -fn to_tridiagonal_solve_random() { +fn extract_tridiagonal_solve_random() { let mut a: Array2 = random((3, 3)); a[[0, 2]] = 0.0; a[[2, 0]] = 0.0; - let tridiag = a.to_tridiagonal().unwrap(); + let tridiag = a.extract_tridiagonal().unwrap(); let x: Array1 = random(3); let b1 = a.dot(&x); let b2 = b1.clone(); From e005ae1b2d482c68c8b9bed06eaea1ed9d92cdc1 Mon Sep 17 00:00:00 2001 From: doraneko94 Date: Sat, 27 Jun 2020 16:35:48 +0900 Subject: [PATCH 8/8] remove private example code --- examples/tridiagonal.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/tridiagonal.rs b/examples/tridiagonal.rs index cffa3173..676bfb28 100644 --- a/examples/tridiagonal.rs +++ b/examples/tridiagonal.rs @@ -27,8 +27,4 @@ fn factorize() -> Result<(), error::LinalgError> { fn main() { solve().unwrap(); factorize().unwrap(); - match arr2(&[[0.0]]).extract_tridiagonal() { - Ok(_) => {} - Err(err) => println!("{}", err), - } }