diff --git a/lax/src/qr.rs b/lax/src/qr.rs index 6c26273d..0bb00c2a 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -2,35 +2,120 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; -use std::cmp::min; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*geqrf` and `*orgqr` (`*ungqr` for complex numbers) pub trait QR_: Sized { - unsafe fn householder(l: MatrixLayout, a: &mut [Self]) -> Result>; - unsafe fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; - unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; + /// Execute Householder reflection as the first step of QR-decomposition + /// + /// For C-continuous array, + /// this will call LQ-decomposition of the transposed matrix $ A^T = LQ^T $ + fn householder(l: MatrixLayout, a: &mut [Self]) -> Result>; + + /// Reconstruct Q-matrix from Householder-reflectors + fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; + + /// Execute QR-decomposition at once + fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; } macro_rules! impl_qr { - ($scalar:ty, $qrf:path, $gqr:path) => { + ($scalar:ty, $qrf:path, $lqf:path, $gqr:path, $glq:path) => { impl QR_ for $scalar { - unsafe fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { - let (row, col) = l.size(); - let k = min(row, col); + fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { + let m = l.lda(); + let n = l.len(); + let k = m.min(n); let mut tau = vec![Self::zero(); k as usize]; - $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau).as_lapack_result()?; + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $qrf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + } + MatrixLayout::C { .. } => { + $lqf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); + } + } + } + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $qrf( + m, + n, + &mut a, + m, + &mut tau, + &mut work, + lwork as i32, + &mut info, + ); + } + MatrixLayout::C { .. } => { + $lqf( + m, + n, + &mut a, + m, + &mut tau, + &mut work, + lwork as i32, + &mut info, + ); + } + } + } + info.as_lapack_result()?; + Ok(tau) } - unsafe fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { - let (row, col) = l.size(); - let k = min(row, col); - $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau).as_lapack_result()?; + fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { + let m = l.lda(); + let n = l.len(); + let k = m.min(n); + assert_eq!(tau.len(), k as usize); + + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $gqr(m, k, k, &mut a, m, &tau, &mut work_size, -1, &mut info) + } + MatrixLayout::C { .. } => { + $glq(k, n, k, &mut a, m, &tau, &mut work_size, -1, &mut info) + } + } + }; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + match l { + MatrixLayout::F { .. } => { + $gqr(m, k, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) + } + MatrixLayout::C { .. } => { + $glq(k, n, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) + } + } + } + info.as_lapack_result()?; Ok(()) } - unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result> { + fn qr(l: MatrixLayout, a: &mut [Self]) -> Result> { let tau = Self::householder(l, a)?; let r = Vec::from(&*a); Self::q(l, a, &tau)?; @@ -40,7 +125,31 @@ macro_rules! impl_qr { }; } // endmacro -impl_qr!(f64, lapacke::dgeqrf, lapacke::dorgqr); -impl_qr!(f32, lapacke::sgeqrf, lapacke::sorgqr); -impl_qr!(c64, lapacke::zgeqrf, lapacke::zungqr); -impl_qr!(c32, lapacke::cgeqrf, lapacke::cungqr); +impl_qr!( + f64, + lapack::dgeqrf, + lapack::dgelqf, + lapack::dorgqr, + lapack::dorglq +); +impl_qr!( + f32, + lapack::sgeqrf, + lapack::sgelqf, + lapack::sorgqr, + lapack::sorglq +); +impl_qr!( + c64, + lapack::zgeqrf, + lapack::zgelqf, + lapack::zungqr, + lapack::zunglq +); +impl_qr!( + c32, + lapack::cgeqrf, + lapack::cgelqf, + lapack::cungqr, + lapack::cunglq +); diff --git a/ndarray-linalg/src/qr.rs b/ndarray-linalg/src/qr.rs index be2de0c2..ae7b2c25 100644 --- a/ndarray-linalg/src/qr.rs +++ b/ndarray-linalg/src/qr.rs @@ -61,7 +61,7 @@ where fn qr_square_inplace(&mut self) -> Result<(&mut Self, Self::R)> { let l = self.square_layout()?; - let r = unsafe { A::qr(l, self.as_allocated_mut()?)? }; + let r = A::qr(l, self.as_allocated_mut()?)?; let r: Array2<_> = into_matrix(l, r)?; Ok((self, r.into_triangular(UPLO::Upper))) } @@ -107,7 +107,7 @@ where let m = self.ncols(); let k = ::std::cmp::min(n, m); let l = self.layout()?; - let r = unsafe { A::qr(l, self.as_allocated_mut()?)? }; + let r = A::qr(l, self.as_allocated_mut()?)?; let r: Array2<_> = into_matrix(l, r)?; let q = self; Ok((take_slice(&q, n, k), take_slice_upper(&r, k, m)))