Skip to content

Commit 34f7b99

Browse files
authored
Merge pull request #206 from rust-ndarray/lapack
Rewrite LAPACKE by Rust, call LAPACK directly
2 parents 4d0d8c3 + 6571d86 commit 34f7b99

34 files changed

+1795
-1023
lines changed

lax/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ netlib = ["lapack-src/netlib", "blas-src/netlib"]
1111
openblas = ["lapack-src/openblas", "blas-src/openblas"]
1212

1313
[dependencies]
14-
thiserror = "1"
15-
cauchy = "0.2"
16-
lapacke = "0.2.0"
14+
thiserror = "1.0"
15+
cauchy = "0.2.0"
1716
num-traits = "0.2"
17+
lapack = "0.16.0"
1818

1919
[dependencies.blas-src]
2020
version = "0.6.1"

lax/src/cholesky.rs

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,90 @@
11
//! Cholesky decomposition
22
33
use super::*;
4-
use crate::{error::*, layout::MatrixLayout};
4+
use crate::{error::*, layout::*};
55
use cauchy::*;
66

77
pub trait Cholesky_: Sized {
88
/// Cholesky: wrapper of `*potrf`
99
///
1010
/// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
11-
unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
11+
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
12+
1213
/// Wrapper of `*potri`
1314
///
1415
/// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
15-
unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
16+
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
17+
1618
/// Wrapper of `*potrs`
17-
unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self])
18-
-> Result<()>;
19+
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
1920
}
2021

2122
macro_rules! impl_cholesky {
2223
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
2324
impl Cholesky_ for $scalar {
24-
unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
25+
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
2526
let (n, _) = l.size();
26-
$trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?;
27+
if matches!(l, MatrixLayout::C { .. }) {
28+
square_transpose(l, a);
29+
}
30+
let mut info = 0;
31+
unsafe {
32+
$trf(uplo as u8, n, a, n, &mut info);
33+
}
34+
info.as_lapack_result()?;
35+
if matches!(l, MatrixLayout::C { .. }) {
36+
square_transpose(l, a);
37+
}
2738
Ok(())
2839
}
2940

30-
unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
41+
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
3142
let (n, _) = l.size();
32-
$tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?;
43+
if matches!(l, MatrixLayout::C { .. }) {
44+
square_transpose(l, a);
45+
}
46+
let mut info = 0;
47+
unsafe {
48+
$tri(uplo as u8, n, a, l.lda(), &mut info);
49+
}
50+
info.as_lapack_result()?;
51+
if matches!(l, MatrixLayout::C { .. }) {
52+
square_transpose(l, a);
53+
}
3354
Ok(())
3455
}
3556

36-
unsafe fn solve_cholesky(
57+
fn solve_cholesky(
3758
l: MatrixLayout,
38-
uplo: UPLO,
59+
mut uplo: UPLO,
3960
a: &[Self],
4061
b: &mut [Self],
4162
) -> Result<()> {
4263
let (n, _) = l.size();
4364
let nrhs = 1;
44-
let ldb = 1;
45-
$trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb)
46-
.as_lapack_result()?;
65+
let mut info = 0;
66+
if matches!(l, MatrixLayout::C { .. }) {
67+
uplo = uplo.t();
68+
for val in b.iter_mut() {
69+
*val = val.conj();
70+
}
71+
}
72+
unsafe {
73+
$trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info);
74+
}
75+
info.as_lapack_result()?;
76+
if matches!(l, MatrixLayout::C { .. }) {
77+
for val in b.iter_mut() {
78+
*val = val.conj();
79+
}
80+
}
4781
Ok(())
4882
}
4983
}
5084
};
5185
} // end macro_rules
5286

53-
impl_cholesky!(f64, lapacke::dpotrf, lapacke::dpotri, lapacke::dpotrs);
54-
impl_cholesky!(f32, lapacke::spotrf, lapacke::spotri, lapacke::spotrs);
55-
impl_cholesky!(c64, lapacke::zpotrf, lapacke::zpotri, lapacke::zpotrs);
56-
impl_cholesky!(c32, lapacke::cpotrf, lapacke::cpotri, lapacke::cpotrs);
87+
impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs);
88+
impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs);
89+
impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs);
90+
impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs);

0 commit comments

Comments
 (0)