From 77235f4f3fbb705d177fc36cadf2b247e913bef9 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 29 Sep 2022 20:54:44 +0900 Subject: [PATCH 01/15] InvWork --- lax/src/solve.rs | 63 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index d0f764fd..89dc1113 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -64,6 +64,69 @@ pub trait Solve_: Scalar + Sized { fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } +pub struct InvWork { + pub layout: MatrixLayout, + pub work: Vec>, +} + +pub trait InvWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>; +} + +macro_rules! impl_inv_work { + ($s:ty, $tri:path) => { + impl InvWorkImpl for InvWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $tri( + &n, + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(InvWork { layout, work }) + } + + fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $tri( + &self.layout.len(), + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(()) + } + } + }; +} + +impl_inv_work!(c64, lapack_sys::zgetri_); +impl_inv_work!(c32, lapack_sys::cgetri_); +impl_inv_work!(f64, lapack_sys::dgetri_); +impl_inv_work!(f32, lapack_sys::sgetri_); + macro_rules! impl_solve { ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { impl Solve_ for $scalar { From 842ad79f4d445931f070ef3633b11fbd25abcef0 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 30 Sep 2022 01:01:10 +0900 Subject: [PATCH 02/15] LuImpl, SolveImpl --- lax/src/lib.rs | 73 +++++++++++++- lax/src/solve.rs | 258 +++++++++++++++-------------------------------- 2 files changed, 154 insertions(+), 177 deletions(-) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 199f2dc2..6ed8498c 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -90,6 +90,7 @@ pub mod eigh; pub mod eigh_generalized; pub mod least_squares; pub mod qr; +pub mod solve; pub mod svd; pub mod svddc; @@ -97,7 +98,6 @@ mod alloc; mod cholesky; mod opnorm; mod rcond; -mod solve; mod solveh; mod triangular; mod tridiagonal; @@ -107,7 +107,6 @@ pub use self::flags::*; pub use self::least_squares::LeastSquaresOwned; pub use self::opnorm::*; pub use self::rcond::*; -pub use self::solve::*; pub use self::solveh::*; pub use self::svd::{SvdOwned, SvdRef}; pub use self::triangular::*; @@ -122,7 +121,7 @@ pub type Pivot = Vec; #[cfg_attr(doc, katexit::katexit)] /// Trait for primitive types which implements LAPACK subroutines pub trait Lapack: - OperatorNorm_ + Solve_ + Solveh_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_ + OperatorNorm_ + Solveh_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_ { /// Compute right eigenvalue and eigenvectors for a general matrix fn eig( @@ -181,6 +180,51 @@ pub trait Lapack: b_layout: MatrixLayout, b: &mut [Self], ) -> Result>; + + /// Computes the LU decomposition of a general $m \times n$ matrix + /// with partial pivoting with row interchanges. + /// + /// Output + /// ------- + /// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded. + /// - $P$ is returned as [Pivot] + /// + /// Error + /// ------ + /// - if the matrix is singular + /// - On this case, `return_code` in [Error::LapackComputationalFailure] means + /// `return_code`-th diagonal element of $U$ becomes zero. + /// + /// LAPACK correspondance + /// ---------------------- + /// + /// | f32 | f64 | c32 | c64 | + /// |:-------|:-------|:-------|:-------| + /// | sgetrf | dgetrf | cgetrf | zgetrf | + /// + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; + + /// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition + /// + /// LAPACK correspondance + /// ---------------------- + /// + /// | f32 | f64 | c32 | c64 | + /// |:-------|:-------|:-------|:-------| + /// | sgetri | dgetri | cgetri | zgetri | + /// + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + + /// Solve linear equations $Ax = b$ using the output of LU-decomposition + /// + /// LAPACK correspondance + /// ---------------------- + /// + /// | f32 | f64 | c32 | c64 | + /// |:-------|:-------|:-------|:-------| + /// | sgetrs | dgetrs | cgetrs | zgetrs | + /// + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_lapack { @@ -276,6 +320,29 @@ macro_rules! impl_lapack { let work = LeastSquaresWork::<$s>::new(a_layout, b_layout)?; work.eval(a, b) } + + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { + use solve::*; + LuImpl::lu(l, a) + } + + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()> { + use solve::*; + let mut work = InvWork::<$s>::new(l)?; + work.calc(a, p)?; + Ok(()) + } + + fn solve( + l: MatrixLayout, + t: Transpose, + a: &[Self], + p: &Pivot, + b: &mut [Self], + ) -> Result<()> { + use solve::*; + SolveImpl::solve(l, t, a, p, b) + } } }; } diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 89dc1113..ba67c847 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -17,119 +17,13 @@ use num_traits::{ToPrimitive, Zero}; /// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$ /// using the output of LU decomposition. /// -pub trait Solve_: Scalar + Sized { - /// Computes the LU decomposition of a general $m \times n$ matrix - /// with partial pivoting with row interchanges. - /// - /// Output - /// ------- - /// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded. - /// - $P$ is returned as [Pivot] - /// - /// Error - /// ------ - /// - if the matrix is singular - /// - On this case, `return_code` in [Error::LapackComputationalFailure] means - /// `return_code`-th diagonal element of $U$ becomes zero. - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetrf | dgetrf | cgetrf | zgetrf | - /// +pub trait LuImpl: Scalar { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; - - /// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetri | dgetri | cgetri | zgetri | - /// - fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; - - /// Solve linear equations $Ax = b$ using the output of LU-decomposition - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetrs | dgetrs | cgetrs | zgetrs | - /// - fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; -} - -pub struct InvWork { - pub layout: MatrixLayout, - pub work: Vec>, -} - -pub trait InvWorkImpl: Sized { - type Elem: Scalar; - fn new(layout: MatrixLayout) -> Result; - fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>; -} - -macro_rules! impl_inv_work { - ($s:ty, $tri:path) => { - impl InvWorkImpl for InvWork<$s> { - type Elem = $s; - - fn new(layout: MatrixLayout) -> Result { - let (n, _) = layout.size(); - let mut info = 0; - let mut work_size = [Self::Elem::zero()]; - unsafe { - $tri( - &n, - std::ptr::null_mut(), - &layout.lda(), - std::ptr::null(), - AsPtr::as_mut_ptr(&mut work_size), - &(-1), - &mut info, - ) - }; - info.as_lapack_result()?; - let lwork = work_size[0].to_usize().unwrap(); - let work = vec_uninit(lwork); - Ok(InvWork { layout, work }) - } - - fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { - let lwork = self.work.len().to_i32().unwrap(); - let mut info = 0; - unsafe { - $tri( - &self.layout.len(), - AsPtr::as_mut_ptr(a), - &self.layout.lda(), - ipiv.as_ptr(), - AsPtr::as_mut_ptr(&mut self.work), - &lwork, - &mut info, - ) - }; - info.as_lapack_result()?; - Ok(()) - } - } - }; } -impl_inv_work!(c64, lapack_sys::zgetri_); -impl_inv_work!(c32, lapack_sys::cgetri_); -impl_inv_work!(f64, lapack_sys::dgetri_); -impl_inv_work!(f32, lapack_sys::sgetri_); - -macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { - impl Solve_ for $scalar { +macro_rules! impl_lu { + ($scalar:ty, $getrf:path) => { + impl LuImpl for $scalar { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); assert_eq!(a.len() as i32, row * col); @@ -154,49 +48,22 @@ macro_rules! impl_solve { let ipiv = unsafe { ipiv.assume_init() }; Ok(ipiv) } + } + }; +} - fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { - let (n, _) = l.size(); - if n == 0 { - // Do nothing for empty matrices. - return Ok(()); - } - - // calc work size - let mut info = 0; - let mut work_size = [Self::zero()]; - unsafe { - $getri( - &n, - AsPtr::as_mut_ptr(a), - &l.lda(), - ipiv.as_ptr(), - AsPtr::as_mut_ptr(&mut work_size), - &(-1), - &mut info, - ) - }; - info.as_lapack_result()?; - - // actual - let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec> = vec_uninit(lwork); - unsafe { - $getri( - &l.len(), - AsPtr::as_mut_ptr(a), - &l.lda(), - ipiv.as_ptr(), - AsPtr::as_mut_ptr(&mut work), - &(lwork as i32), - &mut info, - ) - }; - info.as_lapack_result()?; +impl_lu!(c64, lapack_sys::zgetrf_); +impl_lu!(c32, lapack_sys::cgetrf_); +impl_lu!(f64, lapack_sys::dgetrf_); +impl_lu!(f32, lapack_sys::sgetrf_); - Ok(()) - } +pub trait SolveImpl: Scalar { + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; +} +macro_rules! impl_solve { + ($scalar:ty, $getrs:path) => { + impl SolveImpl for $scalar { fn solve( l: MatrixLayout, t: Transpose, @@ -266,27 +133,70 @@ macro_rules! impl_solve { }; } // impl_solve! -impl_solve!( - f64, - lapack_sys::dgetrf_, - lapack_sys::dgetri_, - lapack_sys::dgetrs_ -); -impl_solve!( - f32, - lapack_sys::sgetrf_, - lapack_sys::sgetri_, - lapack_sys::sgetrs_ -); -impl_solve!( - c64, - lapack_sys::zgetrf_, - lapack_sys::zgetri_, - lapack_sys::zgetrs_ -); -impl_solve!( - c32, - lapack_sys::cgetrf_, - lapack_sys::cgetri_, - lapack_sys::cgetrs_ -); +impl_solve!(f64, lapack_sys::dgetrs_); +impl_solve!(f32, lapack_sys::sgetrs_); +impl_solve!(c64, lapack_sys::zgetrs_); +impl_solve!(c32, lapack_sys::cgetrs_); + +pub struct InvWork { + pub layout: MatrixLayout, + pub work: Vec>, +} + +pub trait InvWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>; +} + +macro_rules! impl_inv_work { + ($s:ty, $tri:path) => { + impl InvWorkImpl for InvWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $tri( + &n, + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(InvWork { layout, work }) + } + + fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $tri( + &self.layout.len(), + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(()) + } + } + }; +} + +impl_inv_work!(c64, lapack_sys::zgetri_); +impl_inv_work!(c32, lapack_sys::cgetri_); +impl_inv_work!(f64, lapack_sys::dgetri_); +impl_inv_work!(f32, lapack_sys::sgetri_); From db395057690061d5a36f162f709649e47ddfd235 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 30 Sep 2022 01:31:14 +0900 Subject: [PATCH 03/15] Fix 0-sized case --- lax/src/solve.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index ba67c847..6dfa5433 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -176,6 +176,9 @@ macro_rules! impl_inv_work { } fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + if self.layout.len() == 0 { + return Ok(()); + } let lwork = self.work.len().to_i32().unwrap(); let mut info = 0; unsafe { From 23baa44c29295872d26eef35d4ca426afcb41459 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 30 Sep 2022 01:31:28 +0900 Subject: [PATCH 04/15] Update documents --- lax/src/lib.rs | 64 +++++++++++++++++--------------------- lax/src/solve.rs | 80 +++++++++++++++++++++++++++++------------------- 2 files changed, 76 insertions(+), 68 deletions(-) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 6ed8498c..49ee1702 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -1,21 +1,24 @@ -//! ndarray-free safe Rust wrapper for LAPACK FFI +//! Safe Rust wrapper for LAPACK without external dependency. //! -//! `Lapack` trait and sub-traits -//! ------------------------------- +//! [Lapack] trait +//! ---------------- //! -//! This crates provides LAPACK wrapper as `impl` of traits to base scalar types. -//! For example, LU decomposition to double-precision matrix is provided like: +//! This crates provides LAPACK wrapper as a traits. +//! For example, LU decomposition of general matrices is provided like: //! -//! ```ignore -//! impl Solve_ for f64 { -//! fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { ... } +//! ``` +//! pub trait Lapack{ +//! fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; //! } //! ``` //! -//! see [Solve_] for detail. You can use it like `f64::lu`: +//! see [Lapack] for detail. +//! This trait is implemented for [f32], [f64], [c32] which is an alias to `num::Complex`, +//! and [c64] which is an alias to `num::Complex`. +//! You can use it like `f64::lu`: //! //! ``` -//! use lax::{Solve_, layout::MatrixLayout, Transpose}; +//! use lax::{Lapack, layout::MatrixLayout, Transpose}; //! //! let mut a = vec![ //! 1.0, 2.0, @@ -31,9 +34,9 @@ //! this trait can be used as a trait bound: //! //! ``` -//! use lax::{Solve_, layout::MatrixLayout, Transpose}; +//! use lax::{Lapack, layout::MatrixLayout, Transpose}; //! -//! fn solve_at_once(layout: MatrixLayout, a: &mut [T], b: &mut [T]) -> Result<(), lax::error::Error> { +//! fn solve_at_once(layout: MatrixLayout, a: &mut [T], b: &mut [T]) -> Result<(), lax::error::Error> { //! let pivot = T::lu(layout, a)?; //! T::solve(layout, Transpose::No, a, &pivot, b)?; //! Ok(()) @@ -48,7 +51,7 @@ //! //! According to the property input metrix, several types of triangular decomposition are used: //! -//! - [Solve_] trait provides methods for LU-decomposition for general matrix. +//! - [solve] module provides methods for LU-decomposition for general matrix. //! - [Solveh_] triat provides methods for Bunch-Kaufman diagonal pivoting method for symmetric/hermite indefinite matrix. //! - [Cholesky_] triat provides methods for Cholesky decomposition for symmetric/hermite positive dinite matrix. //! @@ -184,6 +187,18 @@ pub trait Lapack: /// Computes the LU decomposition of a general $m \times n$ matrix /// with partial pivoting with row interchanges. /// + /// For a given matrix $A$, LU decomposition is described as $A = PLU$ where: + /// + /// - $L$ is lower matrix + /// - $U$ is upper matrix + /// - $P$ is permutation matrix represented by [Pivot] + /// + /// This is designed as two step computation according to LAPACK API: + /// + /// 1. Factorize input matrix $A$ into $L$, $U$, and $P$. + /// 2. Solve linear equation $Ax = b$ by [Lapack::solve] + /// or compute inverse matrix $A^{-1}$ by [Lapack::inv] using the output of LU decomposition. + /// /// Output /// ------- /// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded. @@ -195,35 +210,12 @@ pub trait Lapack: /// - On this case, `return_code` in [Error::LapackComputationalFailure] means /// `return_code`-th diagonal element of $U$ becomes zero. /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetrf | dgetrf | cgetrf | zgetrf | - /// fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; /// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetri | dgetri | cgetri | zgetri | - /// fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; /// Solve linear equations $Ax = b$ using the output of LU-decomposition - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetrs | dgetrs | cgetrs | zgetrs | - /// fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 6dfa5433..1b3239f5 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -1,21 +1,17 @@ +//! Solve linear equations using LU-decomposition + use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -#[cfg_attr(doc, katexit::katexit)] -/// Solve linear equations using LU-decomposition -/// -/// For a given matrix $A$, LU decomposition is described as $A = PLU$ where: -/// -/// - $L$ is lower matrix -/// - $U$ is upper matrix -/// - $P$ is permutation matrix represented by [Pivot] +/// Helper trait to abstract `*getrf` LAPACK routines for implementing [Lapack::lu] /// -/// This is designed as two step computation according to LAPACK API: +/// LAPACK correspondance +/// ---------------------- /// -/// 1. Factorize input matrix $A$ into $L$, $U$, and $P$. -/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$ -/// using the output of LU decomposition. +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | sgetrf | dgetrf | cgetrf | zgetrf | /// pub trait LuImpl: Scalar { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; @@ -57,6 +53,36 @@ impl_lu!(c32, lapack_sys::cgetrf_); impl_lu!(f64, lapack_sys::dgetrf_); impl_lu!(f32, lapack_sys::sgetrf_); +/// Helper trait to abstract `*getrs` LAPACK routines for implementing [Lapack::solve] +/// +/// If the array has C layout, then it needs to be handled +/// specially, since LAPACK expects a Fortran-layout array. +/// Reinterpreting a C layout array as Fortran layout is +/// equivalent to transposing it. So, we can handle the "no +/// transpose" and "transpose" cases by swapping to "transpose" +/// or "no transpose", respectively. For the "Hermite" case, we +/// can take advantage of the following: +/// +/// ```text +/// A^H x = b +/// ⟺ conj(A^T) x = b +/// ⟺ conj(conj(A^T) x) = conj(b) +/// ⟺ conj(conj(A^T)) conj(x) = conj(b) +/// ⟺ A^T conj(x) = conj(b) +/// ``` +/// +/// So, we can handle this case by switching to "no transpose" +/// (which is equivalent to transposing the array since it will +/// be reinterpreted as Fortran layout) and applying the +/// elementwise conjugate to `x` and `b`. +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | sgetrs | dgetrs | cgetrs | zgetrs | +/// pub trait SolveImpl: Scalar { fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } @@ -71,26 +97,6 @@ macro_rules! impl_solve { ipiv: &Pivot, b: &mut [Self], ) -> Result<()> { - // If the array has C layout, then it needs to be handled - // specially, since LAPACK expects a Fortran-layout array. - // Reinterpreting a C layout array as Fortran layout is - // equivalent to transposing it. So, we can handle the "no - // transpose" and "transpose" cases by swapping to "transpose" - // or "no transpose", respectively. For the "Hermite" case, we - // can take advantage of the following: - // - // ```text - // A^H x = b - // ⟺ conj(A^T) x = b - // ⟺ conj(conj(A^T) x) = conj(b) - // ⟺ conj(conj(A^T)) conj(x) = conj(b) - // ⟺ A^T conj(x) = conj(b) - // ``` - // - // So, we can handle this case by switching to "no transpose" - // (which is equivalent to transposing the array since it will - // be reinterpreted as Fortran layout) and applying the - // elementwise conjugate to `x` and `b`. let (t, conj) = match l { MatrixLayout::C { .. } => match t { Transpose::No => (Transpose::Transpose, false), @@ -138,11 +144,21 @@ impl_solve!(f32, lapack_sys::sgetrs_); impl_solve!(c64, lapack_sys::zgetrs_); impl_solve!(c32, lapack_sys::cgetrs_); +/// Working memory for computing inverse matrix pub struct InvWork { pub layout: MatrixLayout, pub work: Vec>, } +/// Helper trait to abstract `*getri` LAPACK rotuines for implementing [Lapack::inv] +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | sgetri | dgetri | cgetri | zgetri | +/// pub trait InvWorkImpl: Sized { type Elem: Scalar; fn new(layout: MatrixLayout) -> Result; From 3dcf19b78a8ff97e411b970d99d70f81f363f0d8 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 30 Sep 2022 21:24:29 +0900 Subject: [PATCH 05/15] Add BkWork, InvhWork, SolvehImpl --- lax/src/solveh.rs | 161 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index bbc6f363..e357a195 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -2,6 +2,167 @@ use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; +pub struct BkWork { + pub layout: MatrixLayout, + pub work: Vec>, + pub ipiv: Vec>, +} + +pub trait BkWorkImpl: Sized { + type Elem: Scalar; + fn new(l: MatrixLayout) -> Result; + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]>; + fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result; +} + +macro_rules! impl_bk_work { + ($s:ty, $trf:path) => { + impl BkWorkImpl for BkWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let ipiv = vec_uninit(n as usize); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $trf( + UPLO::Upper.as_ptr(), + &n, + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null_mut(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(BkWork { layout, work, ipiv }) + } + + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]> { + let (n, _) = self.layout.size(); + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $trf( + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + AsPtr::as_mut_ptr(&mut self.ipiv), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(unsafe { self.ipiv.slice_assume_init_ref() }) + } + + fn eval(mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result { + let _ref = self.calc(uplo, a)?; + Ok(unsafe { self.ipiv.assume_init() }) + } + } + }; +} +impl_bk_work!(c64, lapack_sys::zhetrf_); +impl_bk_work!(c32, lapack_sys::chetrf_); +impl_bk_work!(f64, lapack_sys::dsytrf_); +impl_bk_work!(f32, lapack_sys::ssytrf_); + +pub struct InvhWork { + pub layout: MatrixLayout, + pub work: Vec>, +} + +pub trait InvhWorkImpl: Sized { + type Elem; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()>; +} + +macro_rules! impl_invh_work { + ($s:ty, $tri:path) => { + impl InvhWorkImpl for InvhWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let work = vec_uninit(n as usize); + Ok(InvhWork { layout, work }) + } + + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + let (n, _) = self.layout.size(); + let mut info = 0; + unsafe { + $tri( + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut self.work), + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(()) + } + } + }; +} +impl_invh_work!(c64, lapack_sys::zhetri_); +impl_invh_work!(c32, lapack_sys::chetri_); +impl_invh_work!(f64, lapack_sys::dsytri_); +impl_invh_work!(f32, lapack_sys::ssytri_); + +pub trait SolvehImpl: Scalar { + fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solveh_ { + ($s:ty, $trs:path) => { + impl SolvehImpl for $s { + fn solveh( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + ipiv: &Pivot, + b: &mut [Self], + ) -> Result<()> { + let (n, _) = l.size(); + let mut info = 0; + unsafe { + $trs( + uplo.as_ptr(), + &n, + &1, + AsPtr::as_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b), + &n, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(()) + } + } + }; +} + +impl_solveh_!(c64, lapack_sys::zhetrs_); +impl_solveh_!(c32, lapack_sys::chetrs_); +impl_solveh_!(f64, lapack_sys::dsytrs_); +impl_solveh_!(f32, lapack_sys::ssytrs_); + #[cfg_attr(doc, katexit::katexit)] /// Solve symmetric/hermite indefinite linear problem using the [Bunch-Kaufman diagonal pivoting method][BK]. /// From 4f7404d85c5e93c519c4104b70afe5be489104e8 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 30 Sep 2022 21:41:58 +0900 Subject: [PATCH 06/15] Merge Solveh_ into Lapack trait --- lax/src/lib.rs | 51 +++++++++++- lax/src/solveh.rs | 198 +++++++--------------------------------------- 2 files changed, 75 insertions(+), 174 deletions(-) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 49ee1702..30421d74 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -123,9 +123,7 @@ pub type Pivot = Vec; #[cfg_attr(doc, katexit::katexit)] /// Trait for primitive types which implements LAPACK subroutines -pub trait Lapack: - OperatorNorm_ + Solveh_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_ -{ +pub trait Lapack: OperatorNorm_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_ { /// Compute right eigenvalue and eigenvectors for a general matrix fn eig( calc_v: bool, @@ -217,6 +215,30 @@ pub trait Lapack: /// Solve linear equations $Ax = b$ using the output of LU-decomposition fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; + + /// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method + /// + /// + /// For a given symmetric matrix $A$, + /// this method factorizes $A = U^T D U$ or $A = L D L^T$ where + /// + /// - $U$ (or $L$) are is a product of permutation and unit upper (lower) triangular matrices + /// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks. + /// + /// This takes two-step approach based in LAPACK: + /// + /// 1. Factorize given matrix $A$ into upper ($U$) or lower ($L$) form with diagonal matrix $D$ + /// 2. Then solve linear equation $Ax = b$, and/or calculate inverse matrix $A^{-1}$ + /// + /// [BK]: https://doi.org/10.2307/2005787 + /// + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; + + /// Compute inverse matrix $A^{-1}$ of symmetric/Hermitian matrix using factroized result + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; + + /// Solve symmetric/Hermitian linear equation $Ax = b$ using factroized result + fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_lapack { @@ -335,6 +357,29 @@ macro_rules! impl_lapack { use solve::*; SolveImpl::solve(l, t, a, p, b) } + + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { + use solveh::*; + let work = BkWork::<$s>::new(l)?; + work.eval(uplo, a) + } + + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + use solveh::*; + let mut work = InvhWork::<$s>::new(l)?; + work.calc(uplo, a, ipiv) + } + + fn solveh( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + ipiv: &Pivot, + b: &mut [Self], + ) -> Result<()> { + use solveh::*; + SolvehImpl::solveh(l, uplo, a, ipiv, b) + } } }; } diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index e357a195..e9af59b6 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -8,6 +8,15 @@ pub struct BkWork { pub ipiv: Vec>, } +/// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | ssytrf | dsytrf | chetrf | zhetrf | +/// pub trait BkWorkImpl: Sized { type Elem: Scalar; fn new(l: MatrixLayout) -> Result; @@ -80,6 +89,15 @@ pub struct InvhWork { pub work: Vec>, } +/// Compute inverse matrix of symmetric/Hermitian matrix +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | ssytri | dsytri | chetri | zhetri | +/// pub trait InvhWorkImpl: Sized { type Elem; fn new(layout: MatrixLayout) -> Result; @@ -122,6 +140,15 @@ impl_invh_work!(c32, lapack_sys::chetri_); impl_invh_work!(f64, lapack_sys::dsytri_); impl_invh_work!(f32, lapack_sys::ssytri_); +/// Solve symmetric/Hermitian linear equation +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | ssytrs | dsytrs | chetrs | zhetrs | +/// pub trait SolvehImpl: Scalar { fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; } @@ -162,174 +189,3 @@ impl_solveh_!(c64, lapack_sys::zhetrs_); impl_solveh_!(c32, lapack_sys::chetrs_); impl_solveh_!(f64, lapack_sys::dsytrs_); impl_solveh_!(f32, lapack_sys::ssytrs_); - -#[cfg_attr(doc, katexit::katexit)] -/// Solve symmetric/hermite indefinite linear problem using the [Bunch-Kaufman diagonal pivoting method][BK]. -/// -/// For a given symmetric matrix $A$, -/// this method factorizes $A = U^T D U$ or $A = L D L^T$ where -/// -/// - $U$ (or $L$) are is a product of permutation and unit upper (lower) triangular matrices -/// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks. -/// -/// This takes two-step approach based in LAPACK: -/// -/// 1. Factorize given matrix $A$ into upper ($U$) or lower ($L$) form with diagonal matrix $D$ -/// 2. Then solve linear equation $Ax = b$, and/or calculate inverse matrix $A^{-1}$ -/// -/// [BK]: https://doi.org/10.2307/2005787 -/// -pub trait Solveh_: Sized { - /// Factorize input matrix using Bunch-Kaufman diagonal pivoting method - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | ssytrf | dsytrf | chetrf | zhetrf | - /// - fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; - - /// Compute inverse matrix $A^{-1}$ from factroized result - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | ssytri | dsytri | chetri | zhetri | - /// - fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; - - /// Solve linear equation $Ax = b$ using factroized result - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | ssytrs | dsytrs | chetrs | zhetrs | - /// - fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; -} - -macro_rules! impl_solveh { - ($scalar:ty, $trf:path, $tri:path, $trs:path) => { - impl Solveh_ for $scalar { - fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { - let (n, _) = l.size(); - let mut ipiv = vec_uninit(n as usize); - if n == 0 { - return Ok(Vec::new()); - } - - // calc work size - let mut info = 0; - let mut work_size = [Self::zero()]; - unsafe { - $trf( - uplo.as_ptr(), - &n, - AsPtr::as_mut_ptr(a), - &l.lda(), - AsPtr::as_mut_ptr(&mut ipiv), - AsPtr::as_mut_ptr(&mut work_size), - &(-1), - &mut info, - ) - }; - info.as_lapack_result()?; - - // actual - let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec> = vec_uninit(lwork); - unsafe { - $trf( - uplo.as_ptr(), - &n, - AsPtr::as_mut_ptr(a), - &l.lda(), - AsPtr::as_mut_ptr(&mut ipiv), - AsPtr::as_mut_ptr(&mut work), - &(lwork as i32), - &mut info, - ) - }; - info.as_lapack_result()?; - let ipiv = unsafe { ipiv.assume_init() }; - Ok(ipiv) - } - - fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { - let (n, _) = l.size(); - let mut info = 0; - let mut work: Vec> = vec_uninit(n as usize); - unsafe { - $tri( - uplo.as_ptr(), - &n, - AsPtr::as_mut_ptr(a), - &l.lda(), - ipiv.as_ptr(), - AsPtr::as_mut_ptr(&mut work), - &mut info, - ) - }; - info.as_lapack_result()?; - Ok(()) - } - - fn solveh( - l: MatrixLayout, - uplo: UPLO, - a: &[Self], - ipiv: &Pivot, - b: &mut [Self], - ) -> Result<()> { - let (n, _) = l.size(); - let mut info = 0; - unsafe { - $trs( - uplo.as_ptr(), - &n, - &1, - AsPtr::as_ptr(a), - &l.lda(), - ipiv.as_ptr(), - AsPtr::as_mut_ptr(b), - &n, - &mut info, - ) - }; - info.as_lapack_result()?; - Ok(()) - } - } - }; -} // impl_solveh! - -impl_solveh!( - f64, - lapack_sys::dsytrf_, - lapack_sys::dsytri_, - lapack_sys::dsytrs_ -); -impl_solveh!( - f32, - lapack_sys::ssytrf_, - lapack_sys::ssytri_, - lapack_sys::ssytrs_ -); -impl_solveh!( - c64, - lapack_sys::zhetrf_, - lapack_sys::zhetri_, - lapack_sys::zhetrs_ -); -impl_solveh!( - c32, - lapack_sys::chetrf_, - lapack_sys::chetri_, - lapack_sys::chetrs_ -); From 608010c55a3749e0042ded6d07c136f0170c9f6e Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 30 Sep 2022 21:44:29 +0900 Subject: [PATCH 07/15] Make solveh submodule public --- lax/src/lib.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 30421d74..8b0be740 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -52,7 +52,7 @@ //! According to the property input metrix, several types of triangular decomposition are used: //! //! - [solve] module provides methods for LU-decomposition for general matrix. -//! - [Solveh_] triat provides methods for Bunch-Kaufman diagonal pivoting method for symmetric/hermite indefinite matrix. +//! - [solveh] module provides methods for Bunch-Kaufman diagonal pivoting method for symmetric/hermite indefinite matrix. //! - [Cholesky_] triat provides methods for Cholesky decomposition for symmetric/hermite positive dinite matrix. //! //! Eigenvalue Problem @@ -94,6 +94,7 @@ pub mod eigh_generalized; pub mod least_squares; pub mod qr; pub mod solve; +pub mod solveh; pub mod svd; pub mod svddc; @@ -101,7 +102,6 @@ mod alloc; mod cholesky; mod opnorm; mod rcond; -mod solveh; mod triangular; mod tridiagonal; @@ -110,7 +110,6 @@ pub use self::flags::*; pub use self::least_squares::LeastSquaresOwned; pub use self::opnorm::*; pub use self::rcond::*; -pub use self::solveh::*; pub use self::svd::{SvdOwned, SvdRef}; pub use self::triangular::*; pub use self::tridiagonal::*; From d42fc3ec231ad73ffdad0f4b809dda18cf600951 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 1 Oct 2022 15:14:13 +0900 Subject: [PATCH 08/15] Add CholeskyImpl, InvCholeskyImpl, SolveCholeskyImpl --- lax/src/cholesky.rs | 110 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 9b213246..1c0d3612 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -2,6 +2,116 @@ use super::*; use crate::{error::*, layout::*}; use cauchy::*; +pub trait CholeskyImpl: Scalar { + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_cholesky_ { + ($s:ty, $trf:path) => { + impl CholeskyImpl for $s { + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } + let mut info = 0; + unsafe { + $trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info); + } + info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } + Ok(()) + } + } + }; +} +impl_cholesky_!(c64, lapack_sys::zpotrf_); +impl_cholesky_!(c32, lapack_sys::cpotrf_); +impl_cholesky_!(f64, lapack_sys::dpotrf_); +impl_cholesky_!(f32, lapack_sys::spotrf_); + +pub trait InvCholeskyImpl: Scalar { + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_inv_cholesky { + ($s:ty, $tri:path) => { + impl InvCholeskyImpl for $s { + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } + let mut info = 0; + unsafe { + $tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info); + } + info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + square_transpose(l, a); + } + Ok(()) + } + } + }; +} +impl_inv_cholesky!(c64, lapack_sys::zpotri_); +impl_inv_cholesky!(c32, lapack_sys::cpotri_); +impl_inv_cholesky!(f64, lapack_sys::dpotri_); +impl_inv_cholesky!(f32, lapack_sys::spotri_); + +pub trait SolveCholeskyImpl: Scalar { + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solve_cholesky { + ($s:ty, $trs:path) => { + impl SolveCholeskyImpl for $s { + fn solve_cholesky( + l: MatrixLayout, + mut uplo: UPLO, + a: &[Self], + b: &mut [Self], + ) -> Result<()> { + let (n, _) = l.size(); + let nrhs = 1; + let mut info = 0; + if matches!(l, MatrixLayout::C { .. }) { + uplo = uplo.t(); + for val in b.iter_mut() { + *val = val.conj(); + } + } + unsafe { + $trs( + uplo.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(a), + &l.lda(), + AsPtr::as_mut_ptr(b), + &n, + &mut info, + ); + } + info.as_lapack_result()?; + if matches!(l, MatrixLayout::C { .. }) { + for val in b.iter_mut() { + *val = val.conj(); + } + } + Ok(()) + } + } + }; +} +impl_solve_cholesky!(c64, lapack_sys::zpotrs_); +impl_solve_cholesky!(c32, lapack_sys::cpotrs_); +impl_solve_cholesky!(f64, lapack_sys::dpotrs_); +impl_solve_cholesky!(f32, lapack_sys::spotrs_); + #[cfg_attr(doc, katexit::katexit)] /// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition /// From 062a345b383ddc88fb6541bbf94c90860a63dee7 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 1 Oct 2022 15:19:10 +0900 Subject: [PATCH 09/15] Merge Cholesky_ into Lapack --- lax/src/cholesky.rs | 175 +++++++------------------------------------- lax/src/lib.rs | 43 ++++++++++- 2 files changed, 69 insertions(+), 149 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 1c0d3612..0f853173 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -2,6 +2,15 @@ use super::*; use crate::{error::*, layout::*}; use cauchy::*; +/// Compute Cholesky decomposition according to [UPLO] +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotrf | dpotrf | cpotrf | zpotrf | +/// pub trait CholeskyImpl: Scalar { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; } @@ -32,6 +41,15 @@ impl_cholesky_!(c32, lapack_sys::cpotrf_); impl_cholesky_!(f64, lapack_sys::dpotrf_); impl_cholesky_!(f32, lapack_sys::spotrf_); +/// Compute inverse matrix using Cholesky factroization result +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotri | dpotri | cpotri | zpotri | +/// pub trait InvCholeskyImpl: Scalar { fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; } @@ -62,6 +80,15 @@ impl_inv_cholesky!(c32, lapack_sys::cpotri_); impl_inv_cholesky!(f64, lapack_sys::dpotri_); impl_inv_cholesky!(f32, lapack_sys::spotri_); +/// Solve linear equation using Cholesky factroization result +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotrs | dpotrs | cpotrs | zpotrs | +/// pub trait SolveCholeskyImpl: Scalar { fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } @@ -111,151 +138,3 @@ impl_solve_cholesky!(c64, lapack_sys::zpotrs_); impl_solve_cholesky!(c32, lapack_sys::cpotrs_); impl_solve_cholesky!(f64, lapack_sys::dpotrs_); impl_solve_cholesky!(f32, lapack_sys::spotrs_); - -#[cfg_attr(doc, katexit::katexit)] -/// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition -/// -/// For a given positive definite matrix $A$, -/// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where -/// -/// - $L$ is lower matrix -/// - $U$ is upper matrix -/// -/// This is designed as two step computation according to LAPACK API -/// -/// 1. Factorize input matrix $A$ into $L$ or $U$ -/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$ -/// using $U$ or $L$. -pub trait Cholesky_: Sized { - /// Compute Cholesky decomposition $A = U^T U$ or $A = L L^T$ according to [UPLO] - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | spotrf | dpotrf | cpotrf | zpotrf | - /// - fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - - /// Compute inverse matrix $A^{-1}$ using $U$ or $L$ - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | spotri | dpotri | cpotri | zpotri | - /// - fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - - /// Solve linear equation $Ax = b$ using $U$ or $L$ - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | spotrs | dpotrs | cpotrs | zpotrs | - /// - fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; -} - -macro_rules! impl_cholesky { - ($scalar:ty, $trf:path, $tri:path, $trs:path) => { - impl Cholesky_ for $scalar { - fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { - let (n, _) = l.size(); - if matches!(l, MatrixLayout::C { .. }) { - square_transpose(l, a); - } - let mut info = 0; - unsafe { - $trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info); - } - info.as_lapack_result()?; - if matches!(l, MatrixLayout::C { .. }) { - square_transpose(l, a); - } - Ok(()) - } - - fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { - let (n, _) = l.size(); - if matches!(l, MatrixLayout::C { .. }) { - square_transpose(l, a); - } - let mut info = 0; - unsafe { - $tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info); - } - info.as_lapack_result()?; - if matches!(l, MatrixLayout::C { .. }) { - square_transpose(l, a); - } - Ok(()) - } - - fn solve_cholesky( - l: MatrixLayout, - mut uplo: UPLO, - a: &[Self], - b: &mut [Self], - ) -> Result<()> { - let (n, _) = l.size(); - let nrhs = 1; - let mut info = 0; - if matches!(l, MatrixLayout::C { .. }) { - uplo = uplo.t(); - for val in b.iter_mut() { - *val = val.conj(); - } - } - unsafe { - $trs( - uplo.as_ptr(), - &n, - &nrhs, - AsPtr::as_ptr(a), - &l.lda(), - AsPtr::as_mut_ptr(b), - &n, - &mut info, - ); - } - info.as_lapack_result()?; - if matches!(l, MatrixLayout::C { .. }) { - for val in b.iter_mut() { - *val = val.conj(); - } - } - Ok(()) - } - } - }; -} // end macro_rules - -impl_cholesky!( - f64, - lapack_sys::dpotrf_, - lapack_sys::dpotri_, - lapack_sys::dpotrs_ -); -impl_cholesky!( - f32, - lapack_sys::spotrf_, - lapack_sys::spotri_, - lapack_sys::spotrs_ -); -impl_cholesky!( - c64, - lapack_sys::zpotrf_, - lapack_sys::zpotri_, - lapack_sys::zpotrs_ -); -impl_cholesky!( - c32, - lapack_sys::cpotrf_, - lapack_sys::cpotri_, - lapack_sys::cpotrs_ -); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 8b0be740..21190ab9 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -122,7 +122,7 @@ pub type Pivot = Vec; #[cfg_attr(doc, katexit::katexit)] /// Trait for primitive types which implements LAPACK subroutines -pub trait Lapack: OperatorNorm_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_ { +pub trait Lapack: OperatorNorm_ + Triangular_ + Tridiagonal_ + Rcond_ { /// Compute right eigenvalue and eigenvectors for a general matrix fn eig( calc_v: bool, @@ -238,6 +238,27 @@ pub trait Lapack: OperatorNorm_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond /// Solve symmetric/Hermitian linear equation $Ax = b$ using factroized result fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; + + /// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition + /// + /// For a given positive definite matrix $A$, + /// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where + /// + /// - $L$ is lower matrix + /// - $U$ is upper matrix + /// + /// This is designed as two step computation according to LAPACK API + /// + /// 1. Factorize input matrix $A$ into $L$ or $U$ + /// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$ + /// using $U$ or $L$. + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + + /// Compute inverse matrix $A^{-1}$ using $U$ or $L$ + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + + /// Solve linear equation $Ax = b$ using $U$ or $L$ + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } macro_rules! impl_lapack { @@ -379,6 +400,26 @@ macro_rules! impl_lapack { use solveh::*; SolvehImpl::solveh(l, uplo, a, ipiv, b) } + + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + use cholesky::*; + CholeskyImpl::cholesky(l, uplo, a) + } + + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + use cholesky::*; + InvCholeskyImpl::inv_cholesky(l, uplo, a) + } + + fn solve_cholesky( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + b: &mut [Self], + ) -> Result<()> { + use cholesky::*; + SolveCholeskyImpl::solve_cholesky(l, uplo, a, b) + } } }; } From 29c848c66c3ca4ceb354c788f3a32fe3780c02f1 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 1 Oct 2022 15:29:24 +0900 Subject: [PATCH 10/15] Make cholesky submodule public, update documents --- lax/src/lib.rs | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 21190ab9..d3e713ce 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -52,8 +52,8 @@ //! According to the property input metrix, several types of triangular decomposition are used: //! //! - [solve] module provides methods for LU-decomposition for general matrix. -//! - [solveh] module provides methods for Bunch-Kaufman diagonal pivoting method for symmetric/hermite indefinite matrix. -//! - [Cholesky_] triat provides methods for Cholesky decomposition for symmetric/hermite positive dinite matrix. +//! - [solveh] module provides methods for Bunch-Kaufman diagonal pivoting method for symmetric/Hermitian indefinite matrix. +//! - [cholesky] module provides methods for Cholesky decomposition for symmetric/Hermitian positive dinite matrix. //! //! Eigenvalue Problem //! ------------------- @@ -62,8 +62,8 @@ //! there are several types of eigenvalue problem API //! //! - [eig] module for eigenvalue problem for general matrix. -//! - [eigh] module for eigenvalue problem for symmetric/hermite matrix. -//! - [eigh_generalized] module for generalized eigenvalue problem for symmetric/hermite matrix. +//! - [eigh] module for eigenvalue problem for symmetric/Hermitian matrix. +//! - [eigh_generalized] module for generalized eigenvalue problem for symmetric/Hermitian matrix. //! //! Singular Value Decomposition //! ----------------------------- @@ -88,6 +88,7 @@ pub mod error; pub mod flags; pub mod layout; +pub mod cholesky; pub mod eig; pub mod eigh; pub mod eigh_generalized; @@ -99,7 +100,6 @@ pub mod svd; pub mod svddc; mod alloc; -mod cholesky; mod opnorm; mod rcond; mod triangular; @@ -130,7 +130,7 @@ pub trait Lapack: OperatorNorm_ + Triangular_ + Tridiagonal_ + Rcond_ { a: &mut [Self], ) -> Result<(Vec, Vec)>; - /// Compute right eigenvalue and eigenvectors for a symmetric or hermite matrix + /// Compute right eigenvalue and eigenvectors for a symmetric or Hermitian matrix fn eigh( calc_eigenvec: bool, layout: MatrixLayout, @@ -138,7 +138,7 @@ pub trait Lapack: OperatorNorm_ + Triangular_ + Tridiagonal_ + Rcond_ { a: &mut [Self], ) -> Result>; - /// Compute right eigenvalue and eigenvectors for a symmetric or hermite matrix + /// Compute right eigenvalue and eigenvectors for a symmetric or Hermitian matrix fn eigh_generalized( calc_eigenvec: bool, layout: MatrixLayout, @@ -217,7 +217,6 @@ pub trait Lapack: OperatorNorm_ + Triangular_ + Tridiagonal_ + Rcond_ { /// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method /// - /// /// For a given symmetric matrix $A$, /// this method factorizes $A = U^T D U$ or $A = L D L^T$ where /// @@ -233,13 +232,13 @@ pub trait Lapack: OperatorNorm_ + Triangular_ + Tridiagonal_ + Rcond_ { /// fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; - /// Compute inverse matrix $A^{-1}$ of symmetric/Hermitian matrix using factroized result + /// Compute inverse matrix $A^{-1}$ using the result of [Lapack::bk] fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; - /// Solve symmetric/Hermitian linear equation $Ax = b$ using factroized result + /// Solve symmetric/Hermitian linear equation $Ax = b$ using the result of [Lapack::bk] fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; - /// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition + /// Solve symmetric/Hermitian positive-definite linear equations using Cholesky decomposition /// /// For a given positive definite matrix $A$, /// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where @@ -250,14 +249,15 @@ pub trait Lapack: OperatorNorm_ + Triangular_ + Tridiagonal_ + Rcond_ { /// This is designed as two step computation according to LAPACK API /// /// 1. Factorize input matrix $A$ into $L$ or $U$ - /// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$ - /// using $U$ or $L$. + /// 2. Solve linear equation $Ax = b$ by [Lapack::solve_cholesky] + /// or compute inverse matrix $A^{-1}$ by [Lapack::inv_cholesky] + /// fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - /// Compute inverse matrix $A^{-1}$ using $U$ or $L$ + /// Compute inverse matrix $A^{-1}$ using $U$ or $L$ calculated by [Lapack::cholesky] fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - /// Solve linear equation $Ax = b$ using $U$ or $L$ + /// Solve linear equation $Ax = b$ using $U$ or $L$ calculated by [Lapack::cholesky] fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } From 5824df154c26c72d1c4d8e7a86cd022019bc07ea Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 1 Oct 2022 15:45:39 +0900 Subject: [PATCH 11/15] Ignore first doctest --- lax/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lax/src/lib.rs b/lax/src/lib.rs index d3e713ce..e844b1da 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -6,8 +6,8 @@ //! This crates provides LAPACK wrapper as a traits. //! For example, LU decomposition of general matrices is provided like: //! -//! ``` -//! pub trait Lapack{ +//! ```ignore +//! pub trait Lapack { //! fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; //! } //! ``` From e381bf2d3e4b9ff118e35a02d8ca422d3111705a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 1 Oct 2022 15:45:56 +0900 Subject: [PATCH 12/15] Fix for 0-sized case --- lax/src/solveh.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index e9af59b6..4587d8e8 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -55,6 +55,9 @@ macro_rules! impl_bk_work { fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]> { let (n, _) = self.layout.size(); let lwork = self.work.len().to_i32().unwrap(); + if lwork == 0 { + return Ok(&[]); + } let mut info = 0; unsafe { $trf( From a621b1b0c0eaada181f3acb28a3a0be8387776f9 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 1 Oct 2022 15:54:50 +0900 Subject: [PATCH 13/15] Update module-level documents --- lax/src/cholesky.rs | 2 ++ lax/src/lib.rs | 2 -- lax/src/solveh.rs | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 0f853173..785f6e5e 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -1,3 +1,5 @@ +//! Factorize positive-definite symmetric/Hermitian matrices using Cholesky algorithm + use super::*; use crate::{error::*, layout::*}; use cauchy::*; diff --git a/lax/src/lib.rs b/lax/src/lib.rs index e844b1da..e673d261 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -228,8 +228,6 @@ pub trait Lapack: OperatorNorm_ + Triangular_ + Tridiagonal_ + Rcond_ { /// 1. Factorize given matrix $A$ into upper ($U$) or lower ($L$) form with diagonal matrix $D$ /// 2. Then solve linear equation $Ax = b$, and/or calculate inverse matrix $A^{-1}$ /// - /// [BK]: https://doi.org/10.2307/2005787 - /// fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; /// Compute inverse matrix $A^{-1}$ using the result of [Lapack::bk] diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 4587d8e8..abb75cb8 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -1,3 +1,8 @@ +//! Factorize symmetric/Hermitian matrix using [Bunch-Kaufman diagonal pivoting method][BK] +//! +//! [BK]: https://doi.org/10.2307/2005787 +//! + use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; From 1b1bc821d5b3741e2e7433e11a0bff81cb86f7c0 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 1 Oct 2022 16:08:40 +0900 Subject: [PATCH 14/15] Rewrite transpose note into KaTeX --- lax/src/solve.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 1b3239f5..df6372f9 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -53,6 +53,7 @@ impl_lu!(c32, lapack_sys::cgetrf_); impl_lu!(f64, lapack_sys::dgetrf_); impl_lu!(f32, lapack_sys::sgetrf_); +#[cfg_attr(doc, katexit::katexit)] /// Helper trait to abstract `*getrs` LAPACK routines for implementing [Lapack::solve] /// /// If the array has C layout, then it needs to be handled @@ -63,13 +64,15 @@ impl_lu!(f32, lapack_sys::sgetrf_); /// or "no transpose", respectively. For the "Hermite" case, we /// can take advantage of the following: /// -/// ```text -/// A^H x = b -/// ⟺ conj(A^T) x = b -/// ⟺ conj(conj(A^T) x) = conj(b) -/// ⟺ conj(conj(A^T)) conj(x) = conj(b) -/// ⟺ A^T conj(x) = conj(b) -/// ``` +/// $$ +/// \begin{align*} +/// A^H x &= b \\\\ +/// \Leftrightarrow \overline{A^T} x &= b \\\\ +/// \Leftrightarrow \overline{\overline{A^T} x} &= \overline{b} \\\\ +/// \Leftrightarrow \overline{\overline{A^T}} \overline{x} &= \overline{b} \\\\ +/// \Leftrightarrow A^T \overline{x} &= \overline{b} +/// \end{align*} +/// $$ /// /// So, we can handle this case by switching to "no transpose" /// (which is equivalent to transposing the array since it will From 7e61539e4a36ed0de858ca3d81a2c0ccdcf95390 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 1 Oct 2022 16:12:28 +0900 Subject: [PATCH 15/15] Fix markdown table --- lax/src/solve.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index df6372f9..63f69983 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -79,14 +79,14 @@ impl_lu!(f32, lapack_sys::sgetrf_); /// be reinterpreted as Fortran layout) and applying the /// elementwise conjugate to `x` and `b`. /// -/// LAPACK correspondance -/// ---------------------- -/// -/// | f32 | f64 | c32 | c64 | -/// |:-------|:-------|:-------|:-------| -/// | sgetrs | dgetrs | cgetrs | zgetrs | -/// pub trait SolveImpl: Scalar { + /// LAPACK correspondance + /// ---------------------- + /// + /// | f32 | f64 | c32 | c64 | + /// |:-------|:-------|:-------|:-------| + /// | sgetrs | dgetrs | cgetrs | zgetrs | + /// fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; }