Skip to content

Commit bde41ed

Browse files
committed
WIP: impl eig based on LAPACK
1 parent 37f9e4d commit bde41ed

File tree

2 files changed

+145
-55
lines changed

2 files changed

+145
-55
lines changed

lax/src/eig.rs

Lines changed: 141 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
33
use crate::{error::*, layout::MatrixLayout};
44
use cauchy::*;
5-
use num_traits::Zero;
5+
use num_traits::{ToPrimitive, Zero};
66

7-
/// Wraps `*geev` for real/complex
7+
/// Wraps `*geev` for general matrices
88
pub trait Eig_: Scalar {
9-
unsafe fn eig(
9+
/// Calculate Right eigenvalue
10+
fn eig(
1011
calc_v: bool,
1112
l: MatrixLayout,
1213
a: &mut [Self],
@@ -16,65 +17,159 @@ pub trait Eig_: Scalar {
1617
macro_rules! impl_eig_complex {
1718
($scalar:ty, $ev:path) => {
1819
impl Eig_ for $scalar {
19-
unsafe fn eig(
20+
fn eig(
2021
calc_v: bool,
2122
l: MatrixLayout,
2223
mut a: &mut [Self],
2324
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
2425
let (n, _) = l.size();
25-
let jobvr = if calc_v { b'V' } else { b'N' };
26-
let mut w = vec![Self::Complex::zero(); n as usize];
27-
let mut vl = Vec::new();
28-
let mut vr = vec![Self::Complex::zero(); (n * n) as usize];
29-
$ev(
30-
l.lapacke_layout(),
31-
b'N',
32-
jobvr,
33-
n,
34-
&mut a,
35-
n,
36-
&mut w,
37-
&mut vl,
38-
n,
39-
&mut vr,
40-
n,
41-
)
42-
.as_lapack_result()?;
43-
Ok((w, vr))
26+
// Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate.
27+
// However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A
28+
let (jobvl, jobvr) = if calc_v {
29+
match l {
30+
MatrixLayout::C { .. } => (b'V', b'N'),
31+
MatrixLayout::F { .. } => (b'N', b'V'),
32+
}
33+
} else {
34+
(b'N', b'N')
35+
};
36+
let mut eigs = vec![Self::Complex::zero(); n as usize];
37+
let mut rwork = vec![Self::Real::zero(); 2 * n as usize];
38+
39+
let mut vl = if jobvl == b'V' {
40+
Some(vec![Self::Complex::zero(); (n * n) as usize])
41+
} else {
42+
None
43+
};
44+
let mut vr = if jobvr == b'V' {
45+
Some(vec![Self::Complex::zero(); (n * n) as usize])
46+
} else {
47+
None
48+
};
49+
50+
// calc work size
51+
let mut info = 0;
52+
let mut work_size = [Self::zero()];
53+
unsafe {
54+
$ev(
55+
jobvl,
56+
jobvr,
57+
n,
58+
&mut a,
59+
n,
60+
&mut eigs,
61+
&mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
62+
n,
63+
&mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
64+
n,
65+
&mut work_size,
66+
-1,
67+
&mut rwork,
68+
&mut info,
69+
)
70+
};
71+
info.as_lapack_result()?;
72+
73+
// actal ev
74+
let lwork = work_size[0].to_usize().unwrap();
75+
let mut work = vec![Self::zero(); lwork];
76+
unsafe {
77+
$ev(
78+
jobvl,
79+
jobvr,
80+
n,
81+
&mut a,
82+
n,
83+
&mut eigs,
84+
&mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
85+
n,
86+
&mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
87+
n,
88+
&mut work,
89+
lwork as i32,
90+
&mut rwork,
91+
&mut info,
92+
)
93+
};
94+
info.as_lapack_result()?;
95+
96+
// Hermite conjugate
97+
if jobvl == b'V' {
98+
for c in vl.as_mut().unwrap().iter_mut() {
99+
c.im = -c.im
100+
}
101+
}
102+
103+
Ok((eigs, vr.or(vl).unwrap_or(Vec::new())))
44104
}
45105
}
46106
};
47107
}
48108

109+
impl_eig_complex!(c64, lapack::zgeev);
110+
impl_eig_complex!(c32, lapack::cgeev);
111+
49112
macro_rules! impl_eig_real {
50113
($scalar:ty, $ev:path) => {
51114
impl Eig_ for $scalar {
52-
unsafe fn eig(
115+
fn eig(
53116
calc_v: bool,
54117
l: MatrixLayout,
55118
mut a: &mut [Self],
56119
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
57120
let (n, _) = l.size();
58121
let jobvr = if calc_v { b'V' } else { b'N' };
59-
let mut wr = vec![Self::Real::zero(); n as usize];
60-
let mut wi = vec![Self::Real::zero(); n as usize];
61-
let mut vl = Vec::new();
62-
let mut vr = vec![Self::Real::zero(); (n * n) as usize];
63-
let info = $ev(
64-
l.lapacke_layout(),
65-
b'N',
66-
jobvr,
67-
n,
68-
&mut a,
69-
n,
70-
&mut wr,
71-
&mut wi,
72-
&mut vl,
73-
n,
74-
&mut vr,
75-
n,
76-
);
77-
let w: Vec<Self::Complex> = wr
122+
let mut wr = vec![Self::zero(); n as usize];
123+
let mut wi = vec![Self::zero(); n as usize];
124+
let mut vr = vec![Self::zero(); (n * n) as usize];
125+
126+
// calc work size
127+
let mut info = 0;
128+
let mut work_size = [0.0];
129+
unsafe {
130+
$ev(
131+
b'N',
132+
jobvr,
133+
n,
134+
&mut a,
135+
n,
136+
&mut wr,
137+
&mut wi,
138+
&mut [],
139+
n,
140+
&mut vr,
141+
n,
142+
&mut work_size,
143+
-1,
144+
&mut info,
145+
)
146+
};
147+
info.as_lapack_result()?;
148+
149+
// actual ev
150+
let lwork = work_size[0].to_usize().unwrap();
151+
let mut work = vec![Self::zero(); lwork];
152+
unsafe {
153+
$ev(
154+
b'N',
155+
jobvr,
156+
n,
157+
&mut a,
158+
n,
159+
&mut wr,
160+
&mut wi,
161+
&mut [],
162+
n,
163+
&mut vr,
164+
n,
165+
&mut work,
166+
lwork as i32,
167+
&mut info,
168+
)
169+
};
170+
info.as_lapack_result()?;
171+
172+
let eigs: Vec<Self::Complex> = wr
78173
.iter()
79174
.zip(wi.iter())
80175
.map(|(&r, &i)| Self::Complex::new(r, i))
@@ -119,14 +214,11 @@ macro_rules! impl_eig_real {
119214
})
120215
.collect();
121216

122-
info.as_lapack_result()?;
123-
Ok((w, v))
217+
Ok((eigs, v))
124218
}
125219
}
126220
};
127221
}
128222

129-
impl_eig_real!(f64, lapacke::dgeev);
130-
impl_eig_real!(f32, lapacke::sgeev);
131-
impl_eig_complex!(c64, lapacke::zgeev);
132-
impl_eig_complex!(c32, lapacke::cgeev);
223+
impl_eig_real!(f64, lapack::dgeev);
224+
impl_eig_real!(f32, lapack::sgeev);

ndarray-linalg/src/eig.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,11 @@ where
4848
fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)> {
4949
let mut a = self.to_owned();
5050
let layout = a.square_layout()?;
51-
let (s, t) = unsafe { A::eig(true, layout, a.as_allocated_mut()?)? };
52-
let (n, _) = layout.size();
51+
let (s, t) = A::eig(true, layout, a.as_allocated_mut()?)?;
52+
let n = layout.len() as usize;
5353
Ok((
5454
ArrayBase::from(s),
55-
ArrayBase::from(t)
56-
.into_shape((n as usize, n as usize))
57-
.unwrap(),
55+
Array2::from_shape_vec((n, n).f(), t).unwrap(),
5856
))
5957
}
6058
}
@@ -74,7 +72,7 @@ where
7472

7573
fn eigvals(&self) -> Result<Self::EigVal> {
7674
let mut a = self.to_owned();
77-
let (s, _) = unsafe { A::eig(true, a.square_layout()?, a.as_allocated_mut()?)? };
75+
let (s, _) = A::eig(true, a.square_layout()?, a.as_allocated_mut()?)?;
7876
Ok(ArrayBase::from(s))
7977
}
8078
}

0 commit comments

Comments
 (0)