3
3
use super :: * ;
4
4
use crate :: { error:: * , layout:: MatrixLayout } ;
5
5
use cauchy:: * ;
6
- use num_traits:: Zero ;
6
+ use num_traits:: { ToPrimitive , Zero } ;
7
7
8
8
pub trait Solve_ : Scalar + Sized {
9
9
/// Computes the LU factorization of a general `m x n` matrix `a` using
@@ -14,59 +14,55 @@ pub trait Solve_: Scalar + Sized {
14
14
/// Error
15
15
/// ------
16
16
/// - `LapackComputationalFailure { return_code }` when the matrix is singular
17
- /// - `U[(return_code-1, return_code-1)]` is exactly zero.
18
- /// - Division by zero will occur if it is used to solve a system of equations .
17
+ /// - Division by zero will occur if it is used to solve a system of equations
18
+ /// because `U[(return_code-1, return_code-1)]` is exactly zero .
19
19
fn lu ( l : MatrixLayout , a : & mut [ Self ] ) -> Result < Pivot > ;
20
20
21
21
fn inv ( l : MatrixLayout , a : & mut [ Self ] , p : & Pivot ) -> Result < ( ) > ;
22
22
23
- /// Estimates the the reciprocal of the condition number of the matrix in 1-norm.
24
- ///
25
- /// `anorm` should be the 1-norm of the matrix `a`.
26
- fn rcond ( l : MatrixLayout , a : & [ Self ] , anorm : Self :: Real ) -> Result < Self :: Real > ;
27
-
28
23
fn solve ( l : MatrixLayout , t : Transpose , a : & [ Self ] , p : & Pivot , b : & mut [ Self ] ) -> Result < ( ) > ;
29
24
}
30
25
31
26
macro_rules! impl_solve {
32
- ( $scalar: ty, $getrf: path, $getri: path, $gecon : path , $ getrs: path) => {
27
+ ( $scalar: ty, $getrf: path, $getri: path, $getrs: path) => {
33
28
impl Solve_ for $scalar {
34
29
fn lu( l: MatrixLayout , a: & mut [ Self ] ) -> Result <Pivot > {
35
30
let ( row, col) = l. size( ) ;
31
+ assert_eq!( a. len( ) as i32 , row * col) ;
36
32
let k = :: std:: cmp:: min( row, col) ;
37
33
let mut ipiv = vec![ 0 ; k as usize ] ;
38
- unsafe {
39
- $getrf( l. lapacke_layout( ) , row, col, a, l. lda( ) , & mut ipiv)
40
- . as_lapack_result( ) ?;
41
- }
34
+ let mut info = 0 ;
35
+ unsafe { $getrf( l. lda( ) , l. len( ) , a, l. lda( ) , & mut ipiv, & mut info) } ;
36
+ info. as_lapack_result( ) ?;
42
37
Ok ( ipiv)
43
38
}
44
39
45
40
fn inv( l: MatrixLayout , a: & mut [ Self ] , ipiv: & Pivot ) -> Result <( ) > {
46
41
let ( n, _) = l. size( ) ;
47
- unsafe {
48
- $getri( l. lapacke_layout( ) , n, a, l. lda( ) , ipiv) . as_lapack_result( ) ?;
49
- }
50
- Ok ( ( ) )
51
- }
52
42
53
- fn rcond( l: MatrixLayout , a: & [ Self ] , anorm: Self :: Real ) -> Result <Self :: Real > {
54
- let ( n, _) = l. size( ) ;
55
- let mut rcond = Self :: Real :: zero( ) ;
43
+ // calc work size
44
+ let mut info = 0 ;
45
+ let mut work_size = [ Self :: zero( ) ] ;
46
+ unsafe { $getri( n, a, l. lda( ) , ipiv, & mut work_size, -1 , & mut info) } ;
47
+ info. as_lapack_result( ) ?;
48
+
49
+ // actual
50
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
51
+ let mut work = vec![ Self :: zero( ) ; lwork] ;
56
52
unsafe {
57
- $gecon(
58
- l. lapacke_layout( ) ,
59
- NormType :: One as u8 ,
60
- n,
53
+ $getri(
54
+ l. len( ) ,
61
55
a,
62
56
l. lda( ) ,
63
- anorm,
64
- & mut rcond,
57
+ ipiv,
58
+ & mut work,
59
+ lwork as i32 ,
60
+ & mut info,
65
61
)
66
- }
67
- . as_lapack_result( ) ?;
62
+ } ;
63
+ info . as_lapack_result( ) ?;
68
64
69
- Ok ( rcond )
65
+ Ok ( ( ) )
70
66
}
71
67
72
68
fn solve(
@@ -76,54 +72,26 @@ macro_rules! impl_solve {
76
72
ipiv: & Pivot ,
77
73
b: & mut [ Self ] ,
78
74
) -> Result <( ) > {
75
+ let t = match l {
76
+ MatrixLayout :: C { .. } => match t {
77
+ Transpose :: No => Transpose :: Transpose ,
78
+ Transpose :: Transpose | Transpose :: Hermite => Transpose :: No ,
79
+ } ,
80
+ _ => t,
81
+ } ;
79
82
let ( n, _) = l. size( ) ;
80
83
let nrhs = 1 ;
81
- let ldb = 1 ;
82
- unsafe {
83
- $getrs(
84
- l. lapacke_layout( ) ,
85
- t as u8 ,
86
- n,
87
- nrhs,
88
- a,
89
- l. lda( ) ,
90
- ipiv,
91
- b,
92
- ldb,
93
- )
94
- . as_lapack_result( ) ?;
95
- }
84
+ let ldb = l. lda( ) ;
85
+ let mut info = 0 ;
86
+ unsafe { $getrs( t as u8 , n, nrhs, a, l. lda( ) , ipiv, b, ldb, & mut info) } ;
87
+ info. as_lapack_result( ) ?;
96
88
Ok ( ( ) )
97
89
}
98
90
}
99
91
} ;
100
92
} // impl_solve!
101
93
102
- impl_solve ! (
103
- f64 ,
104
- lapacke:: dgetrf,
105
- lapacke:: dgetri,
106
- lapacke:: dgecon,
107
- lapacke:: dgetrs
108
- ) ;
109
- impl_solve ! (
110
- f32 ,
111
- lapacke:: sgetrf,
112
- lapacke:: sgetri,
113
- lapacke:: sgecon,
114
- lapacke:: sgetrs
115
- ) ;
116
- impl_solve ! (
117
- c64,
118
- lapacke:: zgetrf,
119
- lapacke:: zgetri,
120
- lapacke:: zgecon,
121
- lapacke:: zgetrs
122
- ) ;
123
- impl_solve ! (
124
- c32,
125
- lapacke:: cgetrf,
126
- lapacke:: cgetri,
127
- lapacke:: cgecon,
128
- lapacke:: cgetrs
129
- ) ;
94
+ impl_solve ! ( f64 , lapack:: dgetrf, lapack:: dgetri, lapack:: dgetrs) ;
95
+ impl_solve ! ( f32 , lapack:: sgetrf, lapack:: sgetri, lapack:: sgetrs) ;
96
+ impl_solve ! ( c64, lapack:: zgetrf, lapack:: zgetri, lapack:: zgetrs) ;
97
+ impl_solve ! ( c32, lapack:: cgetrf, lapack:: cgetri, lapack:: cgetrs) ;
0 commit comments