Skip to content

Commit fea408d

Browse files
authored
Merge pull request #225 from rust-ndarray/lapack-cholesky
Cholesky factorization by LAPACK
2 parents 5f6a2c3 + 44b2adb commit fea408d

File tree

5 files changed

+313
-217
lines changed

5 files changed

+313
-217
lines changed

lax/src/cholesky.rs

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,90 @@
11
//! Cholesky decomposition
22
33
use super::*;
4-
use crate::{error::*, layout::MatrixLayout};
4+
use crate::{error::*, layout::*};
55
use cauchy::*;
66

77
pub trait Cholesky_: Sized {
88
/// Cholesky: wrapper of `*potrf`
99
///
1010
/// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
11-
unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
11+
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
12+
1213
/// Wrapper of `*potri`
1314
///
1415
/// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
15-
unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
16+
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
17+
1618
/// Wrapper of `*potrs`
17-
unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self])
18-
-> Result<()>;
19+
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
1920
}
2021

2122
macro_rules! impl_cholesky {
2223
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
2324
impl Cholesky_ for $scalar {
24-
unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
25+
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
2526
let (n, _) = l.size();
26-
$trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?;
27+
if matches!(l, MatrixLayout::C { .. }) {
28+
square_transpose(l, a);
29+
}
30+
let mut info = 0;
31+
unsafe {
32+
$trf(uplo as u8, n, a, n, &mut info);
33+
}
34+
info.as_lapack_result()?;
35+
if matches!(l, MatrixLayout::C { .. }) {
36+
square_transpose(l, a);
37+
}
2738
Ok(())
2839
}
2940

30-
unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
41+
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
3142
let (n, _) = l.size();
32-
$tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?;
43+
if matches!(l, MatrixLayout::C { .. }) {
44+
square_transpose(l, a);
45+
}
46+
let mut info = 0;
47+
unsafe {
48+
$tri(uplo as u8, n, a, l.lda(), &mut info);
49+
}
50+
info.as_lapack_result()?;
51+
if matches!(l, MatrixLayout::C { .. }) {
52+
square_transpose(l, a);
53+
}
3354
Ok(())
3455
}
3556

36-
unsafe fn solve_cholesky(
57+
fn solve_cholesky(
3758
l: MatrixLayout,
38-
uplo: UPLO,
59+
mut uplo: UPLO,
3960
a: &[Self],
4061
b: &mut [Self],
4162
) -> Result<()> {
4263
let (n, _) = l.size();
4364
let nrhs = 1;
44-
let ldb = 1;
45-
$trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb)
46-
.as_lapack_result()?;
65+
let mut info = 0;
66+
if matches!(l, MatrixLayout::C { .. }) {
67+
uplo = uplo.t();
68+
for val in b.iter_mut() {
69+
*val = val.conj();
70+
}
71+
}
72+
unsafe {
73+
$trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info);
74+
}
75+
info.as_lapack_result()?;
76+
if matches!(l, MatrixLayout::C { .. }) {
77+
for val in b.iter_mut() {
78+
*val = val.conj();
79+
}
80+
}
4781
Ok(())
4882
}
4983
}
5084
};
5185
} // end macro_rules
5286

53-
impl_cholesky!(f64, lapacke::dpotrf, lapacke::dpotri, lapacke::dpotrs);
54-
impl_cholesky!(f32, lapacke::spotrf, lapacke::spotri, lapacke::spotrs);
55-
impl_cholesky!(c64, lapacke::zpotrf, lapacke::zpotri, lapacke::zpotrs);
56-
impl_cholesky!(c32, lapacke::cpotrf, lapacke::cpotri, lapacke::cpotrs);
87+
impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs);
88+
impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs);
89+
impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs);
90+
impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs);

lax/src/layout.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
//! This `S` for a matrix `A` is called the "leading dimension of the array A" in the LAPACK documentation, and is denoted by `lda`.
3838
//!
3939
40+
use cauchy::Scalar;
41+
4042
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
4143
pub enum MatrixLayout {
4244
C { row: i32, lda: i32 },
@@ -96,3 +98,44 @@ impl MatrixLayout {
9698
}
9799
}
98100
}
101+
102+
/// In-place transpose of a square matrix by keeping F/C layout
103+
///
104+
/// Transpose for C-continuous array
105+
///
106+
/// ```rust
107+
/// # use lax::layout::*;
108+
/// let layout = MatrixLayout::C { row: 2, lda: 2 };
109+
/// let mut a = vec![1., 2., 3., 4.];
110+
/// square_transpose(layout, &mut a);
111+
/// assert_eq!(a, &[1., 3., 2., 4.]);
112+
/// ```
113+
///
114+
/// Transpose for F-continuous array
115+
///
116+
/// ```rust
117+
/// # use lax::layout::*;
118+
/// let layout = MatrixLayout::F { col: 2, lda: 2 };
119+
/// let mut a = vec![1., 3., 2., 4.];
120+
/// square_transpose(layout, &mut a);
121+
/// assert_eq!(a, &[1., 2., 3., 4.]);
122+
/// ```
123+
///
124+
/// Panics
125+
/// ------
126+
/// - If size of `a` and `layout` size mismatch
127+
///
128+
pub fn square_transpose<T: Scalar>(layout: MatrixLayout, a: &mut [T]) {
129+
let (m, n) = layout.size();
130+
let n = n as usize;
131+
let m = m as usize;
132+
assert_eq!(a.len(), n * m);
133+
for i in 0..m {
134+
for j in (i + 1)..n {
135+
let a_ij = a[i * n + j];
136+
let a_ji = a[j * m + i];
137+
a[i * n + j] = a_ji.conj();
138+
a[j * m + i] = a_ij.conj();
139+
}
140+
}
141+
}

lax/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ pub enum UPLO {
126126
Lower = b'L',
127127
}
128128

129+
impl UPLO {
130+
pub fn t(self) -> Self {
131+
match self {
132+
UPLO::Upper => UPLO::Lower,
133+
UPLO::Lower => UPLO::Upper,
134+
}
135+
}
136+
}
137+
129138
#[derive(Debug, Clone, Copy)]
130139
#[repr(u8)]
131140
pub enum Transpose {

ndarray-linalg/src/cholesky.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ where
155155

156156
fn invc_into(self) -> Result<Self::Output> {
157157
let mut a = self.factor;
158-
unsafe { A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)? };
158+
A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)?;
159159
triangular_fill_hermitian(&mut a, self.uplo);
160160
Ok(a)
161161
}
@@ -173,14 +173,12 @@ where
173173
where
174174
Sb: DataMut<Elem = A>,
175175
{
176-
unsafe {
177-
A::solve_cholesky(
178-
self.factor.square_layout()?,
179-
self.uplo,
180-
self.factor.as_allocated()?,
181-
b.as_slice_mut().unwrap(),
182-
)?
183-
};
176+
A::solve_cholesky(
177+
self.factor.square_layout()?,
178+
self.uplo,
179+
self.factor.as_allocated()?,
180+
b.as_slice_mut().unwrap(),
181+
)?;
184182
Ok(b)
185183
}
186184
}
@@ -259,7 +257,7 @@ where
259257
S: DataMut<Elem = A>,
260258
{
261259
fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> {
262-
unsafe { A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)? };
260+
A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?;
263261
Ok(self.into_triangular(uplo))
264262
}
265263
}

0 commit comments

Comments
 (0)