diff --git a/examples/tridiagonal.rs b/examples/tridiagonal.rs new file mode 100644 index 00000000..676bfb28 --- /dev/null +++ b/examples/tridiagonal.rs @@ -0,0 +1,30 @@ +use ndarray::*; +use ndarray_linalg::*; + +// Solve `Ax=b` for tridiagonal matrix +fn solve() -> Result<(), error::LinalgError> { + let mut a: Array2 = random((3, 3)); + let b: Array1 = random(3); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let _x = a.solve_tridiagonal(&b)?; + Ok(()) +} + +// Solve `Ax=b` for many b with fixed A +fn factorize() -> Result<(), error::LinalgError> { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let f = a.factorize_tridiagonal()?; // LU factorize A (A is *not* consumed) + for _ in 0..10 { + let b: Array1 = random(3); + let _x = f.solve_tridiagonal_into(b)?; // solve Ax=b using factorized L, U + } + Ok(()) +} + +fn main() { + solve().unwrap(); + factorize().unwrap(); +} diff --git a/src/error.rs b/src/error.rs index 4e487acf..c6524897 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,6 +17,12 @@ pub enum LinalgError { InvalidStride { s0: Ixs, s1: Ixs }, /// Memory is not aligned continously MemoryNotCont, + /// Obj cannot be made from a (rows, cols) matrix + NotStandardShape { + obj: &'static str, + rows: i32, + cols: i32, + }, /// Strides of the array is not supported Shape(ShapeError), } @@ -34,6 +40,11 @@ impl fmt::Display for LinalgError { write!(f, "invalid stride: s0={}, s1={}", s0, s1) } LinalgError::MemoryNotCont => write!(f, "Memory is not contiguous"), + LinalgError::NotStandardShape { obj, rows, cols } => write!( + f, + "{} cannot be made from a ({}, {}) matrix", + obj, rows, cols + ), LinalgError::Shape(err) => write!(f, "Shape Error: {}", err), } } diff --git a/src/lapack/mod.rs b/src/lapack/mod.rs index 6a6903fe..a3482505 100644 --- a/src/lapack/mod.rs +++ b/src/lapack/mod.rs @@ -10,6 +10,7 @@ pub mod solveh; pub mod svd; pub mod svddc; pub mod triangular; +pub mod tridiagonal; pub use self::cholesky::*; pub use self::eig::*; @@ -21,6 +22,7 @@ pub use self::solveh::*; pub use self::svd::*; pub use self::svddc::*; pub use self::triangular::*; +pub use self::tridiagonal::*; use super::error::*; use super::types::*; @@ -29,7 +31,17 @@ pub type Pivot = Vec; /// Trait for primitive types which implements LAPACK subroutines pub trait Lapack: - OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_ + OperatorNorm_ + + QR_ + + SVD_ + + SVDDC_ + + Solve_ + + Solveh_ + + Cholesky_ + + Eig_ + + Eigh_ + + Triangular_ + + Tridiagonal_ { } diff --git a/src/lapack/tridiagonal.rs b/src/lapack/tridiagonal.rs new file mode 100644 index 00000000..468cc8ea --- /dev/null +++ b/src/lapack/tridiagonal.rs @@ -0,0 +1,96 @@ +//! Implement linear solver using LU decomposition +//! for tridiagonal matrix + +use lapacke; +use num_traits::Zero; + +use super::NormType; +use super::{into_result, Pivot, Transpose}; + +use crate::error::*; +use crate::layout::MatrixLayout; +use crate::opnorm::*; +use crate::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal}; +use crate::types::*; + +/// Wraps `*gttrf`, `*gtcon` and `*gttrs` +pub trait Tridiagonal_: Scalar + Sized { + /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using + /// partial pivoting with row interchanges. + unsafe fn lu_tridiagonal(a: &mut Tridiagonal) -> Result<(Vec, Self::Real, Pivot)>; + /// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm. + unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; + unsafe fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()>; +} + +macro_rules! impl_tridiagonal { + ($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl Tridiagonal_ for $scalar { + unsafe fn lu_tridiagonal( + a: &mut Tridiagonal, + ) -> Result<(Vec, Self::Real, Pivot)> { + let (n, _) = a.l.size(); + let anom = a.opnorm_one()?; + let mut du2 = vec![Zero::zero(); (n - 2) as usize]; + let mut ipiv = vec![0; n as usize]; + let info = $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv); + into_result(info, (du2, anom, ipiv)) + } + + unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { + let (n, _) = lu.a.l.size(); + let ipiv = &lu.ipiv; + let anorm = lu.anom; + let mut rcond = Self::Real::zero(); + let info = $gtcon( + NormType::One as u8, + n, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + anorm, + &mut rcond, + ); + into_result(info, rcond) + } + + unsafe fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()> { + let (n, _) = lu.a.l.size(); + let (_, nrhs) = bl.size(); + let ipiv = &lu.ipiv; + let ldb = bl.lda(); + let info = $gttrs( + lu.a.l.lapacke_layout(), + t as u8, + n, + nrhs, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + b, + ldb, + ); + into_result(info, ()) + } + } + }; +} // impl_tridiagonal! + +impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs); +impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs); +impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs); +impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs); diff --git a/src/lib.rs b/src/lib.rs index e3c90efe..06d9c89a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ //! - [General matrices](solve/index.html) //! - [Triangular matrices](triangular/index.html) //! - [Hermitian/real symmetric matrices](solveh/index.html) +//! - [Tridiagonal matrices](tridiagonal/index.html) //! - [Inverse matrix computation](solve/trait.Inverse.html) //! //! Naming Convention @@ -66,6 +67,7 @@ pub mod svd; pub mod svddc; pub mod trace; pub mod triangular; +pub mod tridiagonal; pub mod types; pub use assert::*; @@ -88,4 +90,5 @@ pub use svd::*; pub use svddc::*; pub use trace::*; pub use triangular::*; +pub use tridiagonal::*; pub use types::*; diff --git a/src/opnorm.rs b/src/opnorm.rs index 37d38115..4358fd7f 100644 --- a/src/opnorm.rs +++ b/src/opnorm.rs @@ -2,8 +2,10 @@ use ndarray::*; +use crate::convert::*; use crate::error::*; use crate::layout::*; +use crate::tridiagonal::Tridiagonal; use crate::types::*; pub use crate::lapack::NormType; @@ -46,3 +48,66 @@ where Ok(unsafe { A::opnorm(t, l, a) }) } } + +impl OperationNorm for Tridiagonal +where + A: Scalar + Lapack, +{ + type Output = A::Real; + + fn opnorm(&self, t: NormType) -> Result { + // `self` is a tridiagonal matrix like, + // [d0, u1, 0, ..., 0, + // l1, d1, u2, ..., + // 0, l2, d2, + // ... ..., u{n-1}, + // 0, ..., l{n-1}, d{n-1},] + let arr = match t { + // opnorm_one() calculates muximum column sum. + // Therefore, This part align the columns and make a (3 x n) matrix like, + // [ 0, u1, u2, ..., u{n-1}, + // d0, d1, d2, ..., d{n-1}, + // l1, l2, l3, ..., 0,] + NormType::One => { + let zl: Array1 = Array::zeros(1); + let zu: Array1 = Array::zeros(1); + let dl = stack![Axis(0), self.dl.to_owned(), zl]; + let du = stack![Axis(0), zu, self.du.to_owned()]; + let arr = stack![Axis(0), into_row(du), into_row(arr1(&self.d)), into_row(dl)]; + arr + } + // opnorm_inf() calculates muximum row sum. + // Therefore, This part align the rows and make a (n x 3) matrix like, + // [ 0, d0, u1, + // l1, d1, u2, + // l2, d2, u3, + // ..., ..., ..., + // l{n-1}, d{n-1}, 0,] + NormType::Infinity => { + let zl: Array1 = Array::zeros(1); + let zu: Array1 = Array::zeros(1); + let dl = stack![Axis(0), zl, self.dl.to_owned()]; + let du = stack![Axis(0), self.du.to_owned(), zu]; + let arr = stack![Axis(1), into_col(dl), into_col(arr1(&self.d)), into_col(du)]; + arr + } + // opnorm_fro() calculates square root of sum of squares. + // Because it is independent of the shape of matrix, + // this part make a (1 x (3n-2)) matrix like, + // [l1, ..., l{n-1}, d0, ..., d{n-1}, u1, ..., u{n-1}] + NormType::Frobenius => { + let arr = stack![ + Axis(1), + into_row(arr1(&self.dl)), + into_row(arr1(&self.d)), + into_row(arr1(&self.du)) + ]; + arr + } + }; + + let l = arr.layout()?; + let a = arr.as_allocated()?; + Ok(unsafe { A::opnorm(t, l, a) }) + } +} diff --git a/src/tridiagonal.rs b/src/tridiagonal.rs new file mode 100644 index 00000000..155aa463 --- /dev/null +++ b/src/tridiagonal.rs @@ -0,0 +1,805 @@ +//! Vectors as a Tridiagonal matrix +//! & +//! Methods for tridiagonal matrices + +use std::ops::{Index, IndexMut}; + +use cauchy::Scalar; +use ndarray::*; +use num_traits::One; + +use super::convert::*; +use super::error::*; +use super::lapack::*; +use super::layout::*; + +/// Represents a tridiagonal matrix as 3 one-dimensional vectors. +/// This struct also holds the layout of the raw matrix. +#[derive(Clone, PartialEq)] +pub struct Tridiagonal { + /// layout of raw matrix + pub l: MatrixLayout, + /// (n-1) sub-diagonal elements of matrix. + pub dl: Vec, + /// (n) diagonal elements of matrix. + pub d: Vec, + /// (n-1) super-diagonal elements of matrix. + pub du: Vec, +} + +pub trait TridiagIndex { + fn to_tuple(&self) -> (i32, i32); +} +impl TridiagIndex for [Ix; 2] { + fn to_tuple(&self) -> (i32, i32) { + (self[0] as i32, self[1] as i32) + } +} + +impl Index for Tridiagonal +where + A: Scalar, + I: TridiagIndex, +{ + type Output = A; + #[inline] + fn index(&self, index: I) -> &A { + let (n, _) = self.l.size(); + let (row, col) = index.to_tuple(); + assert!( + std::cmp::max(row, col) < n, + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); + match row - col { + 0 => &self.d[row as usize], + 1 => &self.dl[col as usize], + -1 => &self.du[row as usize], + _ => panic!( + "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", + [row, col] + ), + } + } +} + +impl IndexMut for Tridiagonal +where + A: Scalar, + I: TridiagIndex, +{ + #[inline] + fn index_mut(&mut self, index: I) -> &mut A { + let (n, _) = self.l.size(); + let (row, col) = index.to_tuple(); + assert!( + std::cmp::max(row, col) < n, + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); + match row - col { + 0 => &mut self.d[row as usize], + 1 => &mut self.dl[col as usize], + -1 => &mut self.du[row as usize], + _ => panic!( + "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", + [row, col] + ), + } + } +} + +/// An interface for making a Tridiagonal struct. +pub trait ToTridiagonal { + /// Extract tridiagonal elements and layout of the raw matrix. + /// + /// If the raw matrix has some non-tridiagonal elements, + /// they will be ignored. + /// + /// The shape of raw matrix should be equal to or larger than (2, 2). + fn extract_tridiagonal(&self) -> Result>; +} + +impl ToTridiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn extract_tridiagonal(&self) -> Result> { + let l = self.square_layout()?; + let (n, _) = l.size(); + if n < 2 { + return Err(LinalgError::NotStandardShape { + obj: "Tridiagonal", + rows: 1, + cols: 1, + }); + } + + let dl = self.slice(s![1..n, 0..n - 1]).diag().to_vec(); + let d = self.diag().to_vec(); + let du = self.slice(s![0..n - 1, 1..n]).diag().to_vec(); + Ok(Tridiagonal { l, dl, d, du }) + } +} + +pub trait SolveTridiagonal { + /// Solves a system of linear equations `A * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result>; + /// Solves a system of linear equations `A * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result>; + /// Solves a system of linear equations `A^T * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_t_tridiagonal>(&self, b: &ArrayBase) -> Result>; + /// Solves a system of linear equations `A^T * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_t_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result>; + /// Solves a system of linear equations `A^H * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_h_tridiagonal>(&self, b: &ArrayBase) -> Result>; + /// Solves a system of linear equations `A^H * x = b` with tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. + fn solve_h_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result>; +} + +pub trait SolveTridiagonalInplace { + /// Solves a system of linear equations `A * x = b` tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. The value of `x` is also assigned to the + /// argument. + fn solve_tridiagonal_inplace<'a, S: DataMut>( + &self, + b: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase>; + /// Solves a system of linear equations `A^T * x = b` tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. The value of `x` is also assigned to the + /// argument. + fn solve_t_tridiagonal_inplace<'a, S: DataMut>( + &self, + b: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase>; + /// Solves a system of linear equations `A^H * x = b` tridiagonal + /// matrix `A`, where `A` is `self`, `b` is the argument, and + /// `x` is the successful result. The value of `x` is also assigned to the + /// argument. + fn solve_h_tridiagonal_inplace<'a, S: DataMut>( + &self, + b: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase>; +} + +/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`. +#[derive(Clone)] +pub struct LUFactorizedTridiagonal { + /// A tridiagonal matrix which consists of + /// - l : layout of raw matrix + /// - dl: (n-1) multipliers that define the matrix L. + /// - d : (n) diagonal elements of the upper triangular matrix U. + /// - du: (n-1) elements of the first super-diagonal of U. + pub a: Tridiagonal, + /// (n-2) elements of the second super-diagonal of U. + pub du2: Vec, + /// 1-norm of raw matrix (used in .rcond_tridiagonal()). + pub anom: A::Real, + /// The pivot indices that define the permutation matrix `P`. + pub ipiv: Pivot, +} + +impl SolveTridiagonal for LUFactorizedTridiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } +} + +impl SolveTridiagonal for Tridiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } +} + +impl SolveTridiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } +} + +impl SolveTridiagonalInplace for LUFactorizedTridiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::No, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::Transpose, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve_tridiagonal( + &self, + rhs.layout()?, + Transpose::Hermite, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } +} + +impl SolveTridiagonalInplace for Tridiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_tridiagonal_inplace(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_t_tridiagonal_inplace(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_h_tridiagonal_inplace(rhs) + } +} + +impl SolveTridiagonalInplace for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_tridiagonal_inplace(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_t_tridiagonal_inplace(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_h_tridiagonal_inplace(rhs) + } +} + +impl SolveTridiagonal for LUFactorizedTridiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { + let b = b.to_owned(); + self.solve_tridiagonal_into(b) + } + fn solve_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let b = self.solve_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_t_tridiagonal_into(b) + } + fn solve_t_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let b = self.solve_t_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_h_tridiagonal_into(b) + } + fn solve_h_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let b = self.solve_h_tridiagonal_into(b)?; + Ok(flatten(b)) + } +} + +impl SolveTridiagonal for Tridiagonal +where + A: Scalar + Lapack, +{ + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_tridiagonal_into(b) + } + fn solve_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_t_tridiagonal_into(b) + } + fn solve_t_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_t_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_h_tridiagonal_into(b) + } + fn solve_h_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_h_tridiagonal_into(b)?; + Ok(flatten(b)) + } +} + +impl SolveTridiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_tridiagonal_into(b) + } + fn solve_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_t_tridiagonal_into(b) + } + fn solve_t_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_t_tridiagonal_into(b)?; + Ok(flatten(b)) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let b = b.to_owned(); + self.solve_h_tridiagonal_into(b) + } + fn solve_h_tridiagonal_into>( + &self, + b: ArrayBase, + ) -> Result> { + let b = into_col(b); + let f = self.factorize_tridiagonal()?; + let b = f.solve_h_tridiagonal_into(b)?; + Ok(flatten(b)) + } +} + +/// An interface for computing LU factorizations of tridiagonal matrix refs. +pub trait FactorizeTridiagonal { + /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation + /// matrix. + fn factorize_tridiagonal(&self) -> Result>; +} + +/// An interface for computing LU factorizations of tridiagonal matrices. +pub trait FactorizeTridiagonalInto { + /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation + /// matrix. + fn factorize_tridiagonal_into(self) -> Result>; +} + +impl FactorizeTridiagonalInto for Tridiagonal +where + A: Scalar + Lapack, +{ + fn factorize_tridiagonal_into(mut self) -> Result> { + let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? }; + Ok(LUFactorizedTridiagonal { + a: self, + du2: du2, + anom: anom, + ipiv: ipiv, + }) + } +} + +impl FactorizeTridiagonal for Tridiagonal +where + A: Scalar + Lapack, +{ + fn factorize_tridiagonal(&self) -> Result> { + let mut a = self.clone(); + let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; + Ok(LUFactorizedTridiagonal { a, du2, anom, ipiv }) + } +} + +impl FactorizeTridiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn factorize_tridiagonal(&self) -> Result> { + let mut a = self.extract_tridiagonal()?; + let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? }; + Ok(LUFactorizedTridiagonal { a, du2, anom, ipiv }) + } +} + +/// Calculates the recurrent relation, +/// f_k = a_k * f_{k-1} - c_{k-1} * b_{k-1} * f_{k-2} +/// where {a_1, a_2, ..., a_n} are diagonal elements, +/// {b_1, b_2, ..., b_{n-1}} are super-diagonal elements, and +/// {c_1, c_2, ..., c_{n-1}} are sub-diagonal elements of matrix. +/// +/// f[n] is used to calculate the determinant. +/// (https://en.wikipedia.org/wiki/Tridiagonal_matrix#Determinant) +/// +/// In the future, the vector `f` can be used to calculate the inverce matrix. +/// (https://en.wikipedia.org/wiki/Tridiagonal_matrix#Inversion) +fn rec_rel(tridiag: &Tridiagonal) -> Vec { + let n = tridiag.d.len(); + let mut f = Vec::with_capacity(n + 1); + f.push(One::one()); + f.push(tridiag.d[0]); + for i in 1..n { + f.push(tridiag.d[i] * f[i] - tridiag.dl[i - 1] * tridiag.du[i - 1] * f[i - 1]); + } + f +} + +/// An interface for calculating determinants of tridiagonal matrix refs. +pub trait DeterminantTridiagonal { + /// Computes the determinant of the matrix. + /// Unlike `.det()` of Determinant trait, this method + /// doesn't returns the natural logarithm of the determinant + /// but the determinant itself. + fn det_tridiagonal(&self) -> Result; +} + +impl DeterminantTridiagonal for Tridiagonal +where + A: Scalar, +{ + fn det_tridiagonal(&self) -> Result { + let n = self.d.len(); + Ok(rec_rel(&self)[n]) + } +} + +impl DeterminantTridiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn det_tridiagonal(&self) -> Result { + let tridiag = self.extract_tridiagonal()?; + let n = tridiag.d.len(); + Ok(rec_rel(&tridiag)[n]) + } +} + +/// An interface for *estimating* the reciprocal condition number of tridiagonal matrix refs. +pub trait ReciprocalConditionNumTridiagonal { + /// *Estimates* the reciprocal of the condition number of the tridiagonal matrix in + /// 1-norm. + /// + /// This method uses the LAPACK `*gtcon` routines, which *estimate* + /// `self.inv_tridiagonal().opnorm_one()` and then compute `rcond = 1. / + /// (self.opnorm_one() * self.inv_tridiagonal().opnorm_one())`. + /// + /// * If `rcond` is near `0.`, the matrix is badly conditioned. + /// * If `rcond` is near `1.`, the matrix is well conditioned. + fn rcond_tridiagonal(&self) -> Result; +} + +/// An interface for *estimating* the reciprocal condition number of tridiagonal matrices. +pub trait ReciprocalConditionNumTridiagonalInto { + /// *Estimates* the reciprocal of the condition number of the tridiagonal matrix in + /// 1-norm. + /// + /// This method uses the LAPACK `*gtcon` routines, which *estimate* + /// `self.inv_tridiagonal().opnorm_one()` and then compute `rcond = 1. / + /// (self.opnorm_one() * self.inv_tridiagonal().opnorm_one())`. + /// + /// * If `rcond` is near `0.`, the matrix is badly conditioned. + /// * If `rcond` is near `1.`, the matrix is well conditioned. + fn rcond_tridiagonal_into(self) -> Result; +} + +impl ReciprocalConditionNumTridiagonal for LUFactorizedTridiagonal +where + A: Scalar + Lapack, +{ + fn rcond_tridiagonal(&self) -> Result { + unsafe { A::rcond_tridiagonal(&self) } + } +} + +impl ReciprocalConditionNumTridiagonalInto for LUFactorizedTridiagonal +where + A: Scalar + Lapack, +{ + fn rcond_tridiagonal_into(self) -> Result { + self.rcond_tridiagonal() + } +} + +impl ReciprocalConditionNumTridiagonal for ArrayBase +where + A: Scalar + Lapack, + S: Data, +{ + fn rcond_tridiagonal(&self) -> Result { + self.factorize_tridiagonal()?.rcond_tridiagonal_into() + } +} diff --git a/tests/tridiagonal.rs b/tests/tridiagonal.rs new file mode 100644 index 00000000..38278951 --- /dev/null +++ b/tests/tridiagonal.rs @@ -0,0 +1,262 @@ +use ndarray::*; +use ndarray_linalg::*; + +#[test] +fn extract_tridiagonal() { + let a: Array2 = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); + let t = a.extract_tridiagonal().unwrap(); + assert_close_l2!(&arr1(&t.dl), &arr1(&[4.0, 8.0]), 1e-7); + assert_close_l2!(&arr1(&t.d), &arr1(&[1.0, 5.0, 9.0]), 1e-7); + assert_close_l2!(&arr1(&t.du), &arr1(&[2.0, 6.0]), 1e-7); +} + +#[test] +fn tridiagonal_index() { + let a: Array2 = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); + let t1 = a.extract_tridiagonal().unwrap(); + let mut t2 = Array2::::eye(3).extract_tridiagonal().unwrap(); + t2[[0, 1]] = 2.0; + t2[[1, 0]] = 4.0; + t2[[1, 1]] += 4.0; + t2[[1, 2]] = 6.0; + t2[[2, 1]] = 8.0; + t2[[2, 2]] += 8.0; + assert_eq!(t1.dl, t2.dl); + assert_eq!(t1.d, t2.d); + assert_eq!(t1.du, t2.du); +} + +#[test] +fn opnorm_tridiagonal() { + let mut a: Array2 = random((4, 4)); + a[[0, 2]] = 0.0; + a[[0, 3]] = 0.0; + a[[1, 3]] = 0.0; + a[[2, 0]] = 0.0; + a[[3, 0]] = 0.0; + a[[3, 1]] = 0.0; + let t = a.extract_tridiagonal().unwrap(); + assert_aclose!(a.opnorm_one().unwrap(), t.opnorm_one().unwrap(), 1e-7); + assert_aclose!(a.opnorm_inf().unwrap(), t.opnorm_inf().unwrap(), 1e-7); + assert_aclose!(a.opnorm_fro().unwrap(), t.opnorm_fro().unwrap(), 1e-7); +} + +#[test] +fn solve_tridiagonal_f64() { + // https://www.nag-j.co.jp/lapack/dgttrs.htm + let a: Array2 = arr2(&[ + [3.0, 2.1, 0.0, 0.0, 0.0], + [3.4, 2.3, -1.0, 0.0, 0.0], + [0.0, 3.6, -5.0, 1.9, 0.0], + [0.0, 0.0, 7.0, -0.9, 8.0], + [0.0, 0.0, 0.0, -6.0, 7.1], + ]); + let b: Array2 = arr2(&[ + [2.7, 6.6], + [-0.5, 10.8], + [2.6, -3.2], + [0.6, -11.2], + [2.7, 19.1], + ]); + let x: Array2 = arr2(&[ + [-4.0, 5.0], + [7.0, -4.0], + [3.0, -3.0], + [-4.0, -2.0], + [-3.0, 1.0], + ]); + let y = a.solve_tridiagonal_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); +} +//`*gttrf`, `*gtcon` and `*gttrs` +#[test] +fn solve_tridiagonal_c64() { + // https://www.nag-j.co.jp/lapack/zgttrs.htm + let a: Array2 = arr2(&[ + [ + c64::new(-1.3, 1.3), + c64::new(2.0, -1.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(1.0, -2.0), + c64::new(-1.3, 1.3), + c64::new(2.0, 1.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(1.0, 1.0), + c64::new(-1.3, 3.3), + c64::new(-1.0, 1.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(2.0, -3.0), + c64::new(-0.3, 4.3), + c64::new(1.0, -1.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(1.0, 1.0), + c64::new(-3.3, 1.3), + ], + ]); + let b: Array2 = arr2(&[ + [c64::new(2.4, -5.0), c64::new(2.7, 6.9)], + [c64::new(3.4, 18.2), c64::new(-6.9, -5.3)], + [c64::new(-14.7, 9.7), c64::new(-6.0, -0.6)], + [c64::new(31.9, -7.7), c64::new(-3.9, 9.3)], + [c64::new(-1.0, 1.6), c64::new(-3.0, 12.2)], + ]); + let x: Array2 = arr2(&[ + [c64::new(1.0, 1.0), c64::new(2.0, -1.0)], + [c64::new(3.0, -1.0), c64::new(1.0, 2.0)], + [c64::new(4.0, 5.0), c64::new(-1.0, 1.0)], + [c64::new(-1.0, -2.0), c64::new(2.0, 1.0)], + [c64::new(1.0, -1.0), c64::new(2.0, -2.0)], + ]); + let y = a.solve_tridiagonal_into(b).unwrap(); + assert_close_l2!(&x, &y, 1e-7); +} + +#[test] +fn solve_tridiagonal_random() { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let x: Array1 = random(3); + let b1 = a.dot(&x); + let b2 = b1.clone(); + let y1 = a.solve_tridiagonal_into(b1).unwrap(); + let y2 = a.solve_into(b2).unwrap(); + assert_close_l2!(&x, &y1, 1e-7); + assert_close_l2!(&y1, &y2, 1e-7); +} + +#[test] +fn solve_tridiagonal_random_t() { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let x: Array1 = random(3); + let at = a.t(); + let b1 = at.dot(&x); + let b2 = b1.clone(); + let y1 = a.solve_t_tridiagonal_into(b1).unwrap(); + let y2 = a.solve_t_into(b2).unwrap(); + assert_close_l2!(&x, &y1, 1e-7); + assert_close_l2!(&y1, &y2, 1e-7); +} + +#[test] +fn extract_tridiagonal_solve_random() { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + let tridiag = a.extract_tridiagonal().unwrap(); + let x: Array1 = random(3); + let b1 = a.dot(&x); + let b2 = b1.clone(); + let y1 = tridiag.solve_tridiagonal_into(b1).unwrap(); + let y2 = a.solve_into(b2).unwrap(); + assert_close_l2!(&x, &y1, 1e-7); + assert_close_l2!(&y1, &y2, 1e-7); +} + +#[test] +fn det_tridiagonal_f64() { + let a: Array2 = arr2(&[[10.0, -9.0, 0.0], [7.0, -12.0, 11.0], [0.0, 10.0, 3.0]]); + assert_aclose!(a.det_tridiagonal().unwrap(), -1271.0, 1e-7); + assert_aclose!(a.det_tridiagonal().unwrap(), a.det().unwrap(), 1e-7); +} + +#[test] +fn det_tridiagonal_random() { + let mut a: Array2 = random((3, 3)); + a[[0, 2]] = 0.0; + a[[2, 0]] = 0.0; + assert_aclose!(a.det_tridiagonal().unwrap(), a.det().unwrap(), 1e-7); +} + +#[test] +fn rcond_tridiagonal_f64() { + // https://www.nag-j.co.jp/lapack/dgtcon.htm + let a: Array2 = arr2(&[ + [3.0, 2.1, 0.0, 0.0, 0.0], + [3.4, 2.3, -1.0, 0.0, 0.0], + [0.0, 3.6, -5.0, 1.9, 0.0], + [0.0, 0.0, 7.0, -0.9, 8.0], + [0.0, 0.0, 0.0, -6.0, 7.1], + ]); + assert_aclose!(1.0 / a.rcond_tridiagonal().unwrap(), 9.27e1, 0.1); + assert_aclose!(a.rcond_tridiagonal().unwrap(), a.rcond().unwrap(), 1e-3); +} + +#[test] +fn rcond_tridiagonal_c64() { + // https://www.nag-j.co.jp/lapack/dgtcon.htm + let a: Array2 = arr2(&[ + [ + c64::new(-1.3, 1.3), + c64::new(2.0, -1.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(1.0, -2.0), + c64::new(-1.3, 1.3), + c64::new(2.0, 1.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(1.0, 1.0), + c64::new(-1.3, 3.3), + c64::new(-1.0, 1.0), + c64::new(0.0, 0.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(2.0, -3.0), + c64::new(-0.3, 4.3), + c64::new(1.0, -1.0), + ], + [ + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(0.0, 0.0), + c64::new(1.0, 1.0), + c64::new(-3.3, 1.3), + ], + ]); + assert_aclose!(1.0 / a.rcond_tridiagonal().unwrap(), 1.84e2, 1.0); + assert_aclose!(a.rcond_tridiagonal().unwrap(), a.rcond().unwrap(), 1e-3); +} + +#[test] +fn rcond_tridiagonal_identity() { + macro_rules! rcond_identity { + ($elem:ty, $rows:expr, $atol:expr) => { + let a = Array2::<$elem>::eye($rows); + assert_aclose!(a.rcond_tridiagonal().unwrap(), 1., $atol); + }; + } + for rows in 2..6 { + // cannot make 1x1 tridiagonal matrices. + rcond_identity!(f64, rows, 1e-9); + rcond_identity!(f32, rows, 1e-3); + rcond_identity!(c64, rows, 1e-9); + rcond_identity!(c32, rows, 1e-3); + } +}