Skip to content

Commit e9c3481

Browse files
committed
Impl triangular inv/solve using LAPACK
1 parent 24888ec commit e9c3481

File tree

2 files changed

+58
-42
lines changed

2 files changed

+58
-42
lines changed

lax/src/triangular.rs

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
//! Implement linear solver and inverse matrix
22
33
use super::*;
4-
use crate::{error::*, layout::MatrixLayout};
4+
use crate::{error::*, layout::*};
55
use cauchy::*;
6+
use num_traits::Zero;
67

78
#[derive(Debug, Clone, Copy)]
89
#[repr(u8)]
@@ -12,9 +13,8 @@ pub enum Diag {
1213
}
1314

1415
/// Wraps `*trtri` and `*trtrs`
15-
pub trait Triangular_: Sized {
16-
unsafe fn inv_triangular(l: MatrixLayout, uplo: UPLO, d: Diag, a: &mut [Self]) -> Result<()>;
17-
unsafe fn solve_triangular(
16+
pub trait Triangular_: Scalar {
17+
fn solve_triangular(
1818
al: MatrixLayout,
1919
bl: MatrixLayout,
2020
uplo: UPLO,
@@ -27,50 +27,66 @@ pub trait Triangular_: Sized {
2727
macro_rules! impl_triangular {
2828
($scalar:ty, $trtri:path, $trtrs:path) => {
2929
impl Triangular_ for $scalar {
30-
unsafe fn inv_triangular(
31-
l: MatrixLayout,
32-
uplo: UPLO,
33-
diag: Diag,
34-
a: &mut [Self],
35-
) -> Result<()> {
36-
let (n, _) = l.size();
37-
let lda = l.lda();
38-
$trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda).as_lapack_result()?;
39-
Ok(())
40-
}
41-
42-
unsafe fn solve_triangular(
43-
al: MatrixLayout,
44-
bl: MatrixLayout,
30+
fn solve_triangular(
31+
a_layout: MatrixLayout,
32+
b_layout: MatrixLayout,
4533
uplo: UPLO,
4634
diag: Diag,
4735
a: &[Self],
48-
mut b: &mut [Self],
36+
b: &mut [Self],
4937
) -> Result<()> {
50-
let (n, _) = al.size();
51-
let lda = al.lda();
52-
let (_, nrhs) = bl.size();
53-
let ldb = bl.lda();
54-
$trtrs(
55-
al.lapacke_layout(),
56-
uplo as u8,
57-
Transpose::No as u8,
58-
diag as u8,
59-
n,
60-
nrhs,
61-
a,
62-
lda,
63-
&mut b,
64-
ldb,
65-
)
66-
.as_lapack_result()?;
38+
// Transpose if a is C-continuous
39+
let mut a_t = None;
40+
let a_layout = match a_layout {
41+
MatrixLayout::C { .. } => {
42+
a_t = Some(vec![Self::zero(); a.len()]);
43+
transpose(a_layout, a, a_t.as_mut().unwrap())
44+
}
45+
MatrixLayout::F { .. } => a_layout,
46+
};
47+
48+
// Transpose if b is C-continuous
49+
let mut b_t = None;
50+
let b_layout = match b_layout {
51+
MatrixLayout::C { .. } => {
52+
b_t = Some(vec![Self::zero(); b.len()]);
53+
transpose(b_layout, b, b_t.as_mut().unwrap())
54+
}
55+
MatrixLayout::F { .. } => b_layout,
56+
};
57+
58+
let (m, n) = a_layout.size();
59+
let (n_, nrhs) = b_layout.size();
60+
assert_eq!(n, n_);
61+
62+
let mut info = 0;
63+
unsafe {
64+
$trtrs(
65+
uplo as u8,
66+
Transpose::No as u8,
67+
diag as u8,
68+
m,
69+
nrhs,
70+
a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a),
71+
a_layout.lda(),
72+
b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
73+
b_layout.lda(),
74+
&mut info,
75+
);
76+
}
77+
info.as_lapack_result()?;
78+
79+
// Re-transpose b
80+
if let Some(b_t) = b_t {
81+
transpose(b_layout, &b_t, b);
82+
}
6783
Ok(())
6884
}
6985
}
7086
};
7187
} // impl_triangular!
7288

73-
impl_triangular!(f64, lapacke::dtrtri, lapacke::dtrtrs);
74-
impl_triangular!(f32, lapacke::strtri, lapacke::strtrs);
75-
impl_triangular!(c64, lapacke::ztrtri, lapacke::ztrtrs);
76-
impl_triangular!(c32, lapacke::ctrtri, lapacke::ctrtrs);
89+
impl_triangular!(f64, lapack::dtrtri, lapack::dtrtrs);
90+
impl_triangular!(f32, lapack::strtri, lapack::strtrs);
91+
impl_triangular!(c64, lapack::ztrtri, lapack::ztrtrs);
92+
impl_triangular!(c32, lapack::ctrtri, lapack::ctrtrs);

ndarray-linalg/src/triangular.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ where
8585
transpose_data(b)?;
8686
}
8787
let lb = b.layout()?;
88-
unsafe { A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)? };
88+
A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?;
8989
Ok(b)
9090
}
9191
}

0 commit comments

Comments
 (0)