Skip to content

Commit 3cf60ed

Browse files
committed
Impl solve using LAPACK
1 parent f7d93f4 commit 3cf60ed

File tree

2 files changed

+44
-73
lines changed

2 files changed

+44
-73
lines changed

lax/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ pub mod layout;
7070
pub mod least_squares;
7171
pub mod opnorm;
7272
pub mod qr;
73+
pub mod rcond;
7374
pub mod solve;
7475
pub mod solveh;
7576
pub mod svd;
@@ -83,6 +84,7 @@ pub use self::eigh::*;
8384
pub use self::least_squares::*;
8485
pub use self::opnorm::*;
8586
pub use self::qr::*;
87+
pub use self::rcond::*;
8688
pub use self::solve::*;
8789
pub use self::solveh::*;
8890
pub use self::svd::*;
@@ -107,6 +109,7 @@ pub trait Lapack:
107109
+ Eigh_
108110
+ Triangular_
109111
+ Tridiagonal_
112+
+ Rcond_
110113
{
111114
}
112115

lax/src/solve.rs

Lines changed: 41 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use super::*;
44
use crate::{error::*, layout::MatrixLayout};
55
use cauchy::*;
6-
use num_traits::Zero;
6+
use num_traits::{ToPrimitive, Zero};
77

88
pub trait Solve_: Scalar + Sized {
99
/// Computes the LU factorization of a general `m x n` matrix `a` using
@@ -14,59 +14,55 @@ pub trait Solve_: Scalar + Sized {
1414
/// Error
1515
/// ------
1616
/// - `LapackComputationalFailure { return_code }` when the matrix is singular
17-
/// - `U[(return_code-1, return_code-1)]` is exactly zero.
18-
/// - Division by zero will occur if it is used to solve a system of equations.
17+
/// - Division by zero will occur if it is used to solve a system of equations
18+
/// because `U[(return_code-1, return_code-1)]` is exactly zero.
1919
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
2020

2121
fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
2222

23-
/// Estimates the the reciprocal of the condition number of the matrix in 1-norm.
24-
///
25-
/// `anorm` should be the 1-norm of the matrix `a`.
26-
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>;
27-
2823
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
2924
}
3025

3126
macro_rules! impl_solve {
32-
($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => {
27+
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
3328
impl Solve_ for $scalar {
3429
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
3530
let (row, col) = l.size();
31+
assert_eq!(a.len() as i32, row * col);
3632
let k = ::std::cmp::min(row, col);
3733
let mut ipiv = vec![0; k as usize];
38-
unsafe {
39-
$getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv)
40-
.as_lapack_result()?;
41-
}
34+
let mut info = 0;
35+
unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) };
36+
info.as_lapack_result()?;
4237
Ok(ipiv)
4338
}
4439

4540
fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
4641
let (n, _) = l.size();
47-
unsafe {
48-
$getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?;
49-
}
50-
Ok(())
51-
}
5242

53-
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
54-
let (n, _) = l.size();
55-
let mut rcond = Self::Real::zero();
43+
// calc work size
44+
let mut info = 0;
45+
let mut work_size = [Self::zero()];
46+
unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) };
47+
info.as_lapack_result()?;
48+
49+
// actual
50+
let lwork = work_size[0].to_usize().unwrap();
51+
let mut work = vec![Self::zero(); lwork];
5652
unsafe {
57-
$gecon(
58-
l.lapacke_layout(),
59-
NormType::One as u8,
60-
n,
53+
$getri(
54+
l.len(),
6155
a,
6256
l.lda(),
63-
anorm,
64-
&mut rcond,
57+
ipiv,
58+
&mut work,
59+
lwork as i32,
60+
&mut info,
6561
)
66-
}
67-
.as_lapack_result()?;
62+
};
63+
info.as_lapack_result()?;
6864

69-
Ok(rcond)
65+
Ok(())
7066
}
7167

7268
fn solve(
@@ -76,54 +72,26 @@ macro_rules! impl_solve {
7672
ipiv: &Pivot,
7773
b: &mut [Self],
7874
) -> Result<()> {
75+
let t = match l {
76+
MatrixLayout::C { .. } => match t {
77+
Transpose::No => Transpose::Transpose,
78+
Transpose::Transpose | Transpose::Hermite => Transpose::No,
79+
},
80+
_ => t,
81+
};
7982
let (n, _) = l.size();
8083
let nrhs = 1;
81-
let ldb = 1;
82-
unsafe {
83-
$getrs(
84-
l.lapacke_layout(),
85-
t as u8,
86-
n,
87-
nrhs,
88-
a,
89-
l.lda(),
90-
ipiv,
91-
b,
92-
ldb,
93-
)
94-
.as_lapack_result()?;
95-
}
84+
let ldb = l.lda();
85+
let mut info = 0;
86+
unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
87+
info.as_lapack_result()?;
9688
Ok(())
9789
}
9890
}
9991
};
10092
} // impl_solve!
10193

102-
impl_solve!(
103-
f64,
104-
lapacke::dgetrf,
105-
lapacke::dgetri,
106-
lapacke::dgecon,
107-
lapacke::dgetrs
108-
);
109-
impl_solve!(
110-
f32,
111-
lapacke::sgetrf,
112-
lapacke::sgetri,
113-
lapacke::sgecon,
114-
lapacke::sgetrs
115-
);
116-
impl_solve!(
117-
c64,
118-
lapacke::zgetrf,
119-
lapacke::zgetri,
120-
lapacke::zgecon,
121-
lapacke::zgetrs
122-
);
123-
impl_solve!(
124-
c32,
125-
lapacke::cgetrf,
126-
lapacke::cgetri,
127-
lapacke::cgecon,
128-
lapacke::cgetrs
129-
);
94+
impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs);
95+
impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs);
96+
impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs);
97+
impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs);

0 commit comments

Comments
 (0)