Skip to content

Commit 77b6269

Browse files
committed
Rewrite lapack::eigh using lapack crate
1 parent e652fcc commit 77b6269

File tree

2 files changed

+192
-47
lines changed

2 files changed

+192
-47
lines changed

lax/src/eigh.rs

Lines changed: 183 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1-
//! Eigenvalue decomposition for Hermite matrices
1+
//! Eigenvalue decomposition for Symmetric/Hermite matrices
22
33
use super::*;
44
use crate::{error::*, layout::MatrixLayout};
55
use cauchy::*;
6-
use num_traits::Zero;
6+
use num_traits::{ToPrimitive, Zero};
77

8-
/// Wraps `*syev` for real and `*heev` for complex
98
pub trait Eigh_: Scalar {
10-
unsafe fn eigh(
9+
/// Wraps `*syev` for real and `*heev` for complex
10+
fn eigh(
1111
calc_eigenvec: bool,
12-
l: MatrixLayout,
12+
layout: MatrixLayout,
1313
uplo: UPLO,
1414
a: &mut [Self],
1515
) -> Result<Vec<Self::Real>>;
16-
unsafe fn eigh_generalized(
16+
17+
/// Wraps `*syegv` for real and `*heegv` for complex
18+
fn eigh_generalized(
1719
calc_eigenvec: bool,
18-
l: MatrixLayout,
20+
layout: MatrixLayout,
1921
uplo: UPLO,
2022
a: &mut [Self],
2123
b: &mut [Self],
@@ -25,50 +27,195 @@ pub trait Eigh_: Scalar {
2527
macro_rules! impl_eigh {
2628
($scalar:ty, $ev:path, $evg:path) => {
2729
impl Eigh_ for $scalar {
28-
unsafe fn eigh(
30+
fn eigh(
31+
calc_v: bool,
32+
layout: MatrixLayout,
33+
uplo: UPLO,
34+
mut a: &mut [Self],
35+
) -> Result<Vec<Self::Real>> {
36+
assert_eq!(layout.len(), layout.lda());
37+
let n = layout.len();
38+
let jobz = if calc_v { b'V' } else { b'N' };
39+
let mut eigs = vec![Self::Real::zero(); n as usize];
40+
let n = n as i32;
41+
42+
// calc work size
43+
let mut info = 0;
44+
let mut work_size = [0.0];
45+
unsafe {
46+
$ev(
47+
jobz,
48+
uplo as u8,
49+
n,
50+
&mut a,
51+
n,
52+
&mut eigs,
53+
&mut work_size,
54+
-1,
55+
&mut info,
56+
);
57+
}
58+
info.as_lapack_result()?;
59+
60+
// actual ev
61+
let lwork = work_size[0].to_usize().unwrap();
62+
let mut work = vec![Self::zero(); lwork];
63+
unsafe {
64+
$ev(
65+
jobz,
66+
uplo as u8,
67+
n,
68+
&mut a,
69+
n,
70+
&mut eigs,
71+
&mut work,
72+
lwork as i32,
73+
&mut info,
74+
);
75+
}
76+
info.as_lapack_result()?;
77+
Ok(eigs)
78+
}
79+
80+
fn eigh_generalized(
81+
calc_v: bool,
82+
layout: MatrixLayout,
83+
uplo: UPLO,
84+
mut a: &mut [Self],
85+
mut b: &mut [Self],
86+
) -> Result<Vec<Self::Real>> {
87+
assert_eq!(layout.len(), layout.lda());
88+
let n = layout.len();
89+
let jobz = if calc_v { b'V' } else { b'N' };
90+
let mut eigs = vec![Self::Real::zero(); n as usize];
91+
let n = n as i32;
92+
93+
// calc work size
94+
let mut info = 0;
95+
let mut work_size = [0.0];
96+
unsafe {
97+
$evg(
98+
&[1],
99+
jobz,
100+
uplo as u8,
101+
n,
102+
&mut a,
103+
n,
104+
&mut b,
105+
n,
106+
&mut eigs,
107+
&mut work_size,
108+
-1,
109+
&mut info,
110+
);
111+
}
112+
info.as_lapack_result()?;
113+
114+
// actual evg
115+
let lwork = work_size[0].to_usize().unwrap();
116+
let mut work = vec![Self::zero(); lwork];
117+
unsafe {
118+
$evg(
119+
&[1],
120+
jobz,
121+
uplo as u8,
122+
n,
123+
&mut a,
124+
n,
125+
&mut b,
126+
n,
127+
&mut eigs,
128+
&mut work,
129+
lwork as i32,
130+
&mut info,
131+
);
132+
}
133+
info.as_lapack_result()?;
134+
Ok(eigs)
135+
}
136+
}
137+
};
138+
} // impl_eigh!
139+
140+
impl_eigh!(f64, lapack::dsyev, lapack::dsygv);
141+
impl_eigh!(f32, lapack::ssyev, lapack::ssygv);
142+
143+
// splitted for RWORK
144+
macro_rules! impl_eighc {
145+
($scalar:ty, $ev:path, $evg:path) => {
146+
impl Eigh_ for $scalar {
147+
fn eigh(
29148
calc_v: bool,
30-
l: MatrixLayout,
149+
layout: MatrixLayout,
31150
uplo: UPLO,
32151
mut a: &mut [Self],
33152
) -> Result<Vec<Self::Real>> {
34-
let (n, _) = l.size();
153+
assert_eq!(layout.len(), layout.lda());
154+
let n = layout.len();
35155
let jobz = if calc_v { b'V' } else { b'N' };
36-
let mut w = vec![Self::Real::zero(); n as usize];
37-
$ev(l.lapacke_layout(), jobz, uplo as u8, n, &mut a, n, &mut w)
38-
.as_lapack_result()?;
39-
Ok(w)
156+
let mut eigs = vec![Self::Real::zero(); n as usize];
157+
let mut work = vec![Self::zero(); 2 * n as usize - 1];
158+
let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2];
159+
let mut info = 0;
160+
let n = n as i32;
161+
162+
unsafe {
163+
$ev(
164+
jobz,
165+
uplo as u8,
166+
n,
167+
&mut a,
168+
n,
169+
&mut eigs,
170+
&mut work,
171+
2 * n - 1,
172+
&mut rwork,
173+
&mut info,
174+
)
175+
};
176+
info.as_lapack_result()?;
177+
Ok(eigs)
40178
}
41179

42-
unsafe fn eigh_generalized(
180+
fn eigh_generalized(
43181
calc_v: bool,
44-
l: MatrixLayout,
182+
layout: MatrixLayout,
45183
uplo: UPLO,
46184
mut a: &mut [Self],
47185
mut b: &mut [Self],
48186
) -> Result<Vec<Self::Real>> {
49-
let (n, _) = l.size();
187+
assert_eq!(layout.len(), layout.lda());
188+
let n = layout.len();
50189
let jobz = if calc_v { b'V' } else { b'N' };
51-
let mut w = vec![Self::Real::zero(); n as usize];
52-
$evg(
53-
l.lapacke_layout(),
54-
1,
55-
jobz,
56-
uplo as u8,
57-
n,
58-
&mut a,
59-
n,
60-
&mut b,
61-
n,
62-
&mut w,
63-
)
64-
.as_lapack_result()?;
65-
Ok(w)
190+
let mut eigs = vec![Self::Real::zero(); n as usize];
191+
let mut work = vec![Self::zero(); 2 * n as usize - 1];
192+
let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2];
193+
let n = n as i32;
194+
let mut info = 0;
195+
196+
unsafe {
197+
$evg(
198+
&[1],
199+
jobz,
200+
uplo as u8,
201+
n,
202+
&mut a,
203+
n,
204+
&mut b,
205+
n,
206+
&mut eigs,
207+
&mut work,
208+
2 * n - 1,
209+
&mut rwork,
210+
&mut info,
211+
)
212+
};
213+
info.as_lapack_result()?;
214+
Ok(eigs)
66215
}
67216
}
68217
};
69218
} // impl_eigh!
70219

71-
impl_eigh!(f64, lapacke::dsyev, lapacke::dsygv);
72-
impl_eigh!(f32, lapacke::ssyev, lapacke::ssygv);
73-
impl_eigh!(c64, lapacke::zheev, lapacke::zhegv);
74-
impl_eigh!(c32, lapacke::cheev, lapacke::chegv);
220+
impl_eighc!(c64, lapack::zheev, lapack::zhegv);
221+
impl_eighc!(c32, lapack::cheev, lapack::chegv);

ndarray-linalg/src/eigh.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ where
9999
MatrixLayout::C { .. } => self.swap_axes(0, 1),
100100
MatrixLayout::F { .. } => {}
101101
}
102-
let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? };
102+
let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?;
103103
Ok((ArrayBase::from(s), self))
104104
}
105105
}
@@ -126,15 +126,13 @@ where
126126
MatrixLayout::F { .. } => {}
127127
}
128128

129-
let s = unsafe {
130-
A::eigh_generalized(
131-
true,
132-
self.0.square_layout()?,
133-
uplo,
134-
self.0.as_allocated_mut()?,
135-
self.1.as_allocated_mut()?,
136-
)?
137-
};
129+
let s = A::eigh_generalized(
130+
true,
131+
self.0.square_layout()?,
132+
uplo,
133+
self.0.as_allocated_mut()?,
134+
self.1.as_allocated_mut()?,
135+
)?;
138136

139137
Ok((ArrayBase::from(s), self))
140138
}
@@ -191,7 +189,7 @@ where
191189
type EigVal = Array1<A::Real>;
192190

193191
fn eigvalsh_inplace(&mut self, uplo: UPLO) -> Result<Self::EigVal> {
194-
let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? };
192+
let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?;
195193
Ok(ArrayBase::from(s))
196194
}
197195
}

0 commit comments

Comments
 (0)