|
1 | 1 | //! Cholesky decomposition
|
2 | 2 |
|
3 | 3 | use super::*;
|
4 |
| -use crate::{error::*, layout::MatrixLayout}; |
| 4 | +use crate::{error::*, layout::*}; |
5 | 5 | use cauchy::*;
|
6 | 6 |
|
7 | 7 | pub trait Cholesky_: Sized {
|
8 | 8 | /// Cholesky: wrapper of `*potrf`
|
9 | 9 | ///
|
10 | 10 | /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
|
11 |
| - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; |
| 11 | + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; |
| 12 | + |
12 | 13 | /// Wrapper of `*potri`
|
13 | 14 | ///
|
14 | 15 | /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
|
15 |
| - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; |
| 16 | + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; |
| 17 | + |
16 | 18 | /// Wrapper of `*potrs`
|
17 |
| - unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) |
18 |
| - -> Result<()>; |
| 19 | + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; |
19 | 20 | }
|
20 | 21 |
|
21 | 22 | macro_rules! impl_cholesky {
|
22 | 23 | ($scalar:ty, $trf:path, $tri:path, $trs:path) => {
|
23 | 24 | impl Cholesky_ for $scalar {
|
24 |
| - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { |
| 25 | + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { |
25 | 26 | let (n, _) = l.size();
|
26 |
| - $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; |
| 27 | + if matches!(l, MatrixLayout::C { .. }) { |
| 28 | + square_transpose(l, a); |
| 29 | + } |
| 30 | + let mut info = 0; |
| 31 | + unsafe { |
| 32 | + $trf(uplo as u8, n, a, n, &mut info); |
| 33 | + } |
| 34 | + info.as_lapack_result()?; |
| 35 | + if matches!(l, MatrixLayout::C { .. }) { |
| 36 | + square_transpose(l, a); |
| 37 | + } |
27 | 38 | Ok(())
|
28 | 39 | }
|
29 | 40 |
|
30 |
| - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { |
| 41 | + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { |
31 | 42 | let (n, _) = l.size();
|
32 |
| - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; |
| 43 | + if matches!(l, MatrixLayout::C { .. }) { |
| 44 | + square_transpose(l, a); |
| 45 | + } |
| 46 | + let mut info = 0; |
| 47 | + unsafe { |
| 48 | + $tri(uplo as u8, n, a, l.lda(), &mut info); |
| 49 | + } |
| 50 | + info.as_lapack_result()?; |
| 51 | + if matches!(l, MatrixLayout::C { .. }) { |
| 52 | + square_transpose(l, a); |
| 53 | + } |
33 | 54 | Ok(())
|
34 | 55 | }
|
35 | 56 |
|
36 |
| - unsafe fn solve_cholesky( |
| 57 | + fn solve_cholesky( |
37 | 58 | l: MatrixLayout,
|
38 |
| - uplo: UPLO, |
| 59 | + mut uplo: UPLO, |
39 | 60 | a: &[Self],
|
40 | 61 | b: &mut [Self],
|
41 | 62 | ) -> Result<()> {
|
42 | 63 | let (n, _) = l.size();
|
43 | 64 | let nrhs = 1;
|
44 |
| - let ldb = 1; |
45 |
| - $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) |
46 |
| - .as_lapack_result()?; |
| 65 | + let mut info = 0; |
| 66 | + if matches!(l, MatrixLayout::C { .. }) { |
| 67 | + uplo = uplo.t(); |
| 68 | + for val in b.iter_mut() { |
| 69 | + *val = val.conj(); |
| 70 | + } |
| 71 | + } |
| 72 | + unsafe { |
| 73 | + $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); |
| 74 | + } |
| 75 | + info.as_lapack_result()?; |
| 76 | + if matches!(l, MatrixLayout::C { .. }) { |
| 77 | + for val in b.iter_mut() { |
| 78 | + *val = val.conj(); |
| 79 | + } |
| 80 | + } |
47 | 81 | Ok(())
|
48 | 82 | }
|
49 | 83 | }
|
50 | 84 | };
|
51 | 85 | } // end macro_rules
|
52 | 86 |
|
53 |
| -impl_cholesky!(f64, lapacke::dpotrf, lapacke::dpotri, lapacke::dpotrs); |
54 |
| -impl_cholesky!(f32, lapacke::spotrf, lapacke::spotri, lapacke::spotrs); |
55 |
| -impl_cholesky!(c64, lapacke::zpotrf, lapacke::zpotri, lapacke::zpotrs); |
56 |
| -impl_cholesky!(c32, lapacke::cpotrf, lapacke::cpotri, lapacke::cpotrs); |
| 87 | +impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs); |
| 88 | +impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs); |
| 89 | +impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs); |
| 90 | +impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs); |
0 commit comments