Skip to content

Commit 299535a

Browse files
committed
Rewrite lapack::eigh using lapack crate
1 parent 04e33ae commit 299535a

File tree

2 files changed

+148
-46
lines changed

2 files changed

+148
-46
lines changed

lax/src/eigh.rs

Lines changed: 139 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1-
//! Eigenvalue decomposition for Hermite matrices
1+
//! Eigenvalue decomposition for Symmetric/Hermite matrices
22
33
use super::*;
44
use crate::{error::*, layout::MatrixLayout};
55
use cauchy::*;
66
use num_traits::Zero;
77

8-
/// Wraps `*syev` for real and `*heev` for complex
98
pub trait Eigh_: Scalar {
10-
unsafe fn eigh(
9+
/// Wraps `*syev` for real and `*heev` for complex
10+
fn eigh(
1111
calc_eigenvec: bool,
12-
l: MatrixLayout,
12+
layout: MatrixLayout,
1313
uplo: UPLO,
1414
a: &mut [Self],
1515
) -> Result<Vec<Self::Real>>;
16-
unsafe fn eigh_generalized(
16+
17+
/// Wraps `*syegv` for real and `*heegv` for complex
18+
fn eigh_generalized(
1719
calc_eigenvec: bool,
18-
l: MatrixLayout,
20+
layout: MatrixLayout,
1921
uplo: UPLO,
2022
a: &mut [Self],
2123
b: &mut [Self],
@@ -25,50 +27,152 @@ pub trait Eigh_: Scalar {
2527
macro_rules! impl_eigh {
2628
($scalar:ty, $ev:path, $evg:path) => {
2729
impl Eigh_ for $scalar {
28-
unsafe fn eigh(
30+
fn eigh(
31+
calc_v: bool,
32+
layout: MatrixLayout,
33+
uplo: UPLO,
34+
mut a: &mut [Self],
35+
) -> Result<Vec<Self::Real>> {
36+
assert_eq!(layout.len(), layout.lda());
37+
let n = layout.len();
38+
let jobz = if calc_v { b'V' } else { b'N' };
39+
let mut eigs = vec![Self::Real::zero(); n as usize];
40+
let mut work = vec![Self::zero(); 3 * n as usize];
41+
let n = n as i32;
42+
let mut info = 0;
43+
unsafe {
44+
$ev(
45+
jobz,
46+
uplo as u8,
47+
n,
48+
&mut a,
49+
n,
50+
&mut eigs,
51+
&mut work,
52+
3 * n - 1,
53+
&mut info,
54+
);
55+
}
56+
info.as_lapack_result()?;
57+
Ok(eigs)
58+
}
59+
60+
fn eigh_generalized(
61+
calc_v: bool,
62+
layout: MatrixLayout,
63+
uplo: UPLO,
64+
mut a: &mut [Self],
65+
mut b: &mut [Self],
66+
) -> Result<Vec<Self::Real>> {
67+
assert_eq!(layout.len(), layout.lda());
68+
let n = layout.len();
69+
let jobz = if calc_v { b'V' } else { b'N' };
70+
let mut eigs = vec![Self::Real::zero(); n as usize];
71+
let mut work = vec![Self::zero(); 3 * n as usize - 1];
72+
let n = n as i32;
73+
let mut info = 0;
74+
unsafe {
75+
$evg(
76+
&[1],
77+
jobz,
78+
uplo as u8,
79+
n,
80+
&mut a,
81+
n,
82+
&mut b,
83+
n,
84+
&mut eigs,
85+
&mut work,
86+
3 * n - 1,
87+
&mut info,
88+
);
89+
}
90+
info.as_lapack_result()?;
91+
Ok(eigs)
92+
}
93+
}
94+
};
95+
} // impl_eigh!
96+
97+
impl_eigh!(f64, lapack::dsyev, lapack::dsygv);
98+
impl_eigh!(f32, lapack::ssyev, lapack::ssygv);
99+
100+
// splitted for RWORK
101+
macro_rules! impl_eighc {
102+
($scalar:ty, $ev:path, $evg:path) => {
103+
impl Eigh_ for $scalar {
104+
fn eigh(
29105
calc_v: bool,
30-
l: MatrixLayout,
106+
layout: MatrixLayout,
31107
uplo: UPLO,
32108
mut a: &mut [Self],
33109
) -> Result<Vec<Self::Real>> {
34-
let (n, _) = l.size();
110+
assert_eq!(layout.len(), layout.lda());
111+
let n = layout.len();
35112
let jobz = if calc_v { b'V' } else { b'N' };
36-
let mut w = vec![Self::Real::zero(); n as usize];
37-
$ev(l.lapacke_layout(), jobz, uplo as u8, n, &mut a, n, &mut w)
38-
.as_lapack_result()?;
39-
Ok(w)
113+
let mut eigs = vec![Self::Real::zero(); n as usize];
114+
let mut work = vec![Self::zero(); 2 * n as usize - 1];
115+
let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2];
116+
let mut info = 0;
117+
let n = n as i32;
118+
119+
unsafe {
120+
$ev(
121+
jobz,
122+
uplo as u8,
123+
n,
124+
&mut a,
125+
n,
126+
&mut eigs,
127+
&mut work,
128+
2 * n - 1,
129+
&mut rwork,
130+
&mut info,
131+
)
132+
};
133+
info.as_lapack_result()?;
134+
Ok(eigs)
40135
}
41136

42-
unsafe fn eigh_generalized(
137+
fn eigh_generalized(
43138
calc_v: bool,
44-
l: MatrixLayout,
139+
layout: MatrixLayout,
45140
uplo: UPLO,
46141
mut a: &mut [Self],
47142
mut b: &mut [Self],
48143
) -> Result<Vec<Self::Real>> {
49-
let (n, _) = l.size();
144+
assert_eq!(layout.len(), layout.lda());
145+
let n = layout.len();
50146
let jobz = if calc_v { b'V' } else { b'N' };
51-
let mut w = vec![Self::Real::zero(); n as usize];
52-
$evg(
53-
l.lapacke_layout(),
54-
1,
55-
jobz,
56-
uplo as u8,
57-
n,
58-
&mut a,
59-
n,
60-
&mut b,
61-
n,
62-
&mut w,
63-
)
64-
.as_lapack_result()?;
65-
Ok(w)
147+
let mut eigs = vec![Self::Real::zero(); n as usize];
148+
let mut work = vec![Self::zero(); 2 * n as usize - 1];
149+
let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2];
150+
let n = n as i32;
151+
let mut info = 0;
152+
153+
unsafe {
154+
$evg(
155+
&[1],
156+
jobz,
157+
uplo as u8,
158+
n,
159+
&mut a,
160+
n,
161+
&mut b,
162+
n,
163+
&mut eigs,
164+
&mut work,
165+
2 * n - 1,
166+
&mut rwork,
167+
&mut info,
168+
)
169+
};
170+
info.as_lapack_result()?;
171+
Ok(eigs)
66172
}
67173
}
68174
};
69175
} // impl_eigh!
70176

71-
impl_eigh!(f64, lapacke::dsyev, lapacke::dsygv);
72-
impl_eigh!(f32, lapacke::ssyev, lapacke::ssygv);
73-
impl_eigh!(c64, lapacke::zheev, lapacke::zhegv);
74-
impl_eigh!(c32, lapacke::cheev, lapacke::chegv);
177+
impl_eighc!(c64, lapack::zheev, lapack::zhegv);
178+
impl_eighc!(c32, lapack::cheev, lapack::chegv);

ndarray-linalg/src/eigh.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ where
9999
MatrixLayout::C { .. } => self.swap_axes(0, 1),
100100
MatrixLayout::F { .. } => {}
101101
}
102-
let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? };
102+
let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?;
103103
Ok((ArrayBase::from(s), self))
104104
}
105105
}
@@ -126,15 +126,13 @@ where
126126
MatrixLayout::F { .. } => {}
127127
}
128128

129-
let s = unsafe {
130-
A::eigh_generalized(
131-
true,
132-
self.0.square_layout()?,
133-
uplo,
134-
self.0.as_allocated_mut()?,
135-
self.1.as_allocated_mut()?,
136-
)?
137-
};
129+
let s = A::eigh_generalized(
130+
true,
131+
self.0.square_layout()?,
132+
uplo,
133+
self.0.as_allocated_mut()?,
134+
self.1.as_allocated_mut()?,
135+
)?;
138136

139137
Ok((ArrayBase::from(s), self))
140138
}
@@ -191,7 +189,7 @@ where
191189
type EigVal = Array1<A::Real>;
192190

193191
fn eigvalsh_inplace(&mut self, uplo: UPLO) -> Result<Self::EigVal> {
194-
let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? };
192+
let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?;
195193
Ok(ArrayBase::from(s))
196194
}
197195
}

0 commit comments

Comments
 (0)