1
- //! Eigenvalue decomposition for Hermite matrices
1
+ //! Eigenvalue decomposition for Symmetric/ Hermite matrices
2
2
3
3
use super :: * ;
4
4
use crate :: { error:: * , layout:: MatrixLayout } ;
5
5
use cauchy:: * ;
6
- use num_traits:: Zero ;
6
+ use num_traits:: { ToPrimitive , Zero } ;
7
7
8
- /// Wraps `*syev` for real and `*heev` for complex
9
8
pub trait Eigh_ : Scalar {
10
- unsafe fn eigh (
9
+ /// Wraps `*syev` for real and `*heev` for complex
10
+ fn eigh (
11
11
calc_eigenvec : bool ,
12
- l : MatrixLayout ,
12
+ layout : MatrixLayout ,
13
13
uplo : UPLO ,
14
14
a : & mut [ Self ] ,
15
15
) -> Result < Vec < Self :: Real > > ;
16
- unsafe fn eigh_generalized (
16
+
17
+ /// Wraps `*syegv` for real and `*heegv` for complex
18
+ fn eigh_generalized (
17
19
calc_eigenvec : bool ,
18
- l : MatrixLayout ,
20
+ layout : MatrixLayout ,
19
21
uplo : UPLO ,
20
22
a : & mut [ Self ] ,
21
23
b : & mut [ Self ] ,
@@ -25,50 +27,195 @@ pub trait Eigh_: Scalar {
25
27
macro_rules! impl_eigh {
26
28
( $scalar: ty, $ev: path, $evg: path) => {
27
29
impl Eigh_ for $scalar {
28
- unsafe fn eigh(
30
+ fn eigh(
31
+ calc_v: bool ,
32
+ layout: MatrixLayout ,
33
+ uplo: UPLO ,
34
+ mut a: & mut [ Self ] ,
35
+ ) -> Result <Vec <Self :: Real >> {
36
+ assert_eq!( layout. len( ) , layout. lda( ) ) ;
37
+ let n = layout. len( ) ;
38
+ let jobz = if calc_v { b'V' } else { b'N' } ;
39
+ let mut eigs = vec![ Self :: Real :: zero( ) ; n as usize ] ;
40
+ let n = n as i32 ;
41
+
42
+ // calc work size
43
+ let mut info = 0 ;
44
+ let mut work_size = [ 0.0 ] ;
45
+ unsafe {
46
+ $ev(
47
+ jobz,
48
+ uplo as u8 ,
49
+ n,
50
+ & mut a,
51
+ n,
52
+ & mut eigs,
53
+ & mut work_size,
54
+ -1 ,
55
+ & mut info,
56
+ ) ;
57
+ }
58
+ info. as_lapack_result( ) ?;
59
+
60
+ // actual ev
61
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
62
+ let mut work = vec![ Self :: zero( ) ; lwork] ;
63
+ unsafe {
64
+ $ev(
65
+ jobz,
66
+ uplo as u8 ,
67
+ n,
68
+ & mut a,
69
+ n,
70
+ & mut eigs,
71
+ & mut work,
72
+ lwork as i32 ,
73
+ & mut info,
74
+ ) ;
75
+ }
76
+ info. as_lapack_result( ) ?;
77
+ Ok ( eigs)
78
+ }
79
+
80
+ fn eigh_generalized(
81
+ calc_v: bool ,
82
+ layout: MatrixLayout ,
83
+ uplo: UPLO ,
84
+ mut a: & mut [ Self ] ,
85
+ mut b: & mut [ Self ] ,
86
+ ) -> Result <Vec <Self :: Real >> {
87
+ assert_eq!( layout. len( ) , layout. lda( ) ) ;
88
+ let n = layout. len( ) ;
89
+ let jobz = if calc_v { b'V' } else { b'N' } ;
90
+ let mut eigs = vec![ Self :: Real :: zero( ) ; n as usize ] ;
91
+ let n = n as i32 ;
92
+
93
+ // calc work size
94
+ let mut info = 0 ;
95
+ let mut work_size = [ 0.0 ] ;
96
+ unsafe {
97
+ $evg(
98
+ & [ 1 ] ,
99
+ jobz,
100
+ uplo as u8 ,
101
+ n,
102
+ & mut a,
103
+ n,
104
+ & mut b,
105
+ n,
106
+ & mut eigs,
107
+ & mut work_size,
108
+ -1 ,
109
+ & mut info,
110
+ ) ;
111
+ }
112
+ info. as_lapack_result( ) ?;
113
+
114
+ // actual evg
115
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
116
+ let mut work = vec![ Self :: zero( ) ; lwork] ;
117
+ unsafe {
118
+ $evg(
119
+ & [ 1 ] ,
120
+ jobz,
121
+ uplo as u8 ,
122
+ n,
123
+ & mut a,
124
+ n,
125
+ & mut b,
126
+ n,
127
+ & mut eigs,
128
+ & mut work,
129
+ lwork as i32 ,
130
+ & mut info,
131
+ ) ;
132
+ }
133
+ info. as_lapack_result( ) ?;
134
+ Ok ( eigs)
135
+ }
136
+ }
137
+ } ;
138
+ } // impl_eigh!
139
+
140
+ impl_eigh ! ( f64 , lapack:: dsyev, lapack:: dsygv) ;
141
+ impl_eigh ! ( f32 , lapack:: ssyev, lapack:: ssygv) ;
142
+
143
+ // splitted for RWORK
144
+ macro_rules! impl_eighc {
145
+ ( $scalar: ty, $ev: path, $evg: path) => {
146
+ impl Eigh_ for $scalar {
147
+ fn eigh(
29
148
calc_v: bool ,
30
- l : MatrixLayout ,
149
+ layout : MatrixLayout ,
31
150
uplo: UPLO ,
32
151
mut a: & mut [ Self ] ,
33
152
) -> Result <Vec <Self :: Real >> {
34
- let ( n, _) = l. size( ) ;
153
+ assert_eq!( layout. len( ) , layout. lda( ) ) ;
154
+ let n = layout. len( ) ;
35
155
let jobz = if calc_v { b'V' } else { b'N' } ;
36
- let mut w = vec![ Self :: Real :: zero( ) ; n as usize ] ;
37
- $ev( l. lapacke_layout( ) , jobz, uplo as u8 , n, & mut a, n, & mut w)
38
- . as_lapack_result( ) ?;
39
- Ok ( w)
156
+ let mut eigs = vec![ Self :: Real :: zero( ) ; n as usize ] ;
157
+ let mut work = vec![ Self :: zero( ) ; 2 * n as usize - 1 ] ;
158
+ let mut rwork = vec![ Self :: Real :: zero( ) ; 3 * n as usize - 2 ] ;
159
+ let mut info = 0 ;
160
+ let n = n as i32 ;
161
+
162
+ unsafe {
163
+ $ev(
164
+ jobz,
165
+ uplo as u8 ,
166
+ n,
167
+ & mut a,
168
+ n,
169
+ & mut eigs,
170
+ & mut work,
171
+ 2 * n - 1 ,
172
+ & mut rwork,
173
+ & mut info,
174
+ )
175
+ } ;
176
+ info. as_lapack_result( ) ?;
177
+ Ok ( eigs)
40
178
}
41
179
42
- unsafe fn eigh_generalized(
180
+ fn eigh_generalized(
43
181
calc_v: bool ,
44
- l : MatrixLayout ,
182
+ layout : MatrixLayout ,
45
183
uplo: UPLO ,
46
184
mut a: & mut [ Self ] ,
47
185
mut b: & mut [ Self ] ,
48
186
) -> Result <Vec <Self :: Real >> {
49
- let ( n, _) = l. size( ) ;
187
+ assert_eq!( layout. len( ) , layout. lda( ) ) ;
188
+ let n = layout. len( ) ;
50
189
let jobz = if calc_v { b'V' } else { b'N' } ;
51
- let mut w = vec![ Self :: Real :: zero( ) ; n as usize ] ;
52
- $evg(
53
- l. lapacke_layout( ) ,
54
- 1 ,
55
- jobz,
56
- uplo as u8 ,
57
- n,
58
- & mut a,
59
- n,
60
- & mut b,
61
- n,
62
- & mut w,
63
- )
64
- . as_lapack_result( ) ?;
65
- Ok ( w)
190
+ let mut eigs = vec![ Self :: Real :: zero( ) ; n as usize ] ;
191
+ let mut work = vec![ Self :: zero( ) ; 2 * n as usize - 1 ] ;
192
+ let mut rwork = vec![ Self :: Real :: zero( ) ; 3 * n as usize - 2 ] ;
193
+ let n = n as i32 ;
194
+ let mut info = 0 ;
195
+
196
+ unsafe {
197
+ $evg(
198
+ & [ 1 ] ,
199
+ jobz,
200
+ uplo as u8 ,
201
+ n,
202
+ & mut a,
203
+ n,
204
+ & mut b,
205
+ n,
206
+ & mut eigs,
207
+ & mut work,
208
+ 2 * n - 1 ,
209
+ & mut rwork,
210
+ & mut info,
211
+ )
212
+ } ;
213
+ info. as_lapack_result( ) ?;
214
+ Ok ( eigs)
66
215
}
67
216
}
68
217
} ;
69
218
} // impl_eigh!
70
219
71
- impl_eigh ! ( f64 , lapacke:: dsyev, lapacke:: dsygv) ;
72
- impl_eigh ! ( f32 , lapacke:: ssyev, lapacke:: ssygv) ;
73
- impl_eigh ! ( c64, lapacke:: zheev, lapacke:: zhegv) ;
74
- impl_eigh ! ( c32, lapacke:: cheev, lapacke:: chegv) ;
220
+ impl_eighc ! ( c64, lapack:: zheev, lapack:: zhegv) ;
221
+ impl_eighc ! ( c32, lapack:: cheev, lapack:: chegv) ;
0 commit comments