Skip to content

Commit 7e0f582

Browse files
authored
Merge pull request #218 from rust-ndarray/lapack-svd
SVD using LAPACK
2 parents 34c7f5a + 2373b43 commit 7e0f582

File tree

3 files changed

+223
-64
lines changed

3 files changed

+223
-64
lines changed

lax/src/svd.rs

Lines changed: 166 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,27 @@
22
33
use crate::{error::*, layout::MatrixLayout};
44
use cauchy::*;
5-
use num_traits::Zero;
5+
use num_traits::{ToPrimitive, Zero};
66

77
#[repr(u8)]
8+
#[derive(Debug, Copy, Clone)]
89
enum FlagSVD {
910
All = b'A',
1011
// OverWrite = b'O',
1112
// Separately = b'S',
1213
No = b'N',
1314
}
1415

16+
impl FlagSVD {
17+
fn from_bool(calc_uv: bool) -> Self {
18+
if calc_uv {
19+
FlagSVD::All
20+
} else {
21+
FlagSVD::No
22+
}
23+
}
24+
}
25+
1526
/// Result of SVD
1627
pub struct SVDOutput<A: Scalar> {
1728
/// diagonal values
@@ -24,6 +35,7 @@ pub struct SVDOutput<A: Scalar> {
2435

2536
/// Wraps `*gesvd`
2637
pub trait SVD_: Scalar {
38+
/// Calculate singular value decomposition $ A = U \Sigma V^T $
2739
unsafe fn svd(
2840
l: MatrixLayout,
2941
calc_u: bool,
@@ -32,7 +44,7 @@ pub trait SVD_: Scalar {
3244
) -> Result<SVDOutput<Self>>;
3345
}
3446

35-
macro_rules! impl_svd {
47+
macro_rules! impl_svd_real {
3648
($scalar:ty, $gesvd:path) => {
3749
impl SVD_ for $scalar {
3850
unsafe fn svd(
@@ -41,48 +53,169 @@ macro_rules! impl_svd {
4153
calc_vt: bool,
4254
mut a: &mut [Self],
4355
) -> Result<SVDOutput<Self>> {
44-
let (m, n) = l.size();
45-
let k = ::std::cmp::min(n, m);
46-
let lda = l.lda();
47-
let (ju, ldu, mut u) = if calc_u {
48-
(FlagSVD::All, m, vec![Self::zero(); (m * m) as usize])
49-
} else {
50-
(FlagSVD::No, 1, Vec::new())
56+
let ju = match l {
57+
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
58+
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
5159
};
52-
let (jvt, ldvt, mut vt) = if calc_vt {
53-
(FlagSVD::All, n, vec![Self::zero(); (n * n) as usize])
54-
} else {
55-
(FlagSVD::No, n, Vec::new())
60+
let jvt = match l {
61+
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
62+
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
63+
};
64+
65+
let m = l.lda();
66+
let mut u = match ju {
67+
FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]),
68+
FlagSVD::No => None,
5669
};
70+
71+
let n = l.len();
72+
let mut vt = match jvt {
73+
FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]),
74+
FlagSVD::No => None,
75+
};
76+
77+
let k = std::cmp::min(m, n);
5778
let mut s = vec![Self::Real::zero(); k as usize];
58-
let mut superb = vec![Self::Real::zero(); (k - 1) as usize];
79+
80+
// eval work size
81+
let mut info = 0;
82+
let mut work_size = [Self::zero()];
83+
$gesvd(
84+
ju as u8,
85+
jvt as u8,
86+
m,
87+
n,
88+
&mut a,
89+
m,
90+
&mut s,
91+
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
92+
m,
93+
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
94+
n,
95+
&mut work_size,
96+
-1,
97+
&mut info,
98+
);
99+
info.as_lapack_result()?;
100+
101+
// calc
102+
let lwork = work_size[0].to_usize().unwrap();
103+
let mut work = vec![Self::zero(); lwork];
59104
$gesvd(
60-
l.lapacke_layout(),
61105
ju as u8,
62106
jvt as u8,
63107
m,
64108
n,
65109
&mut a,
66-
lda,
110+
m,
67111
&mut s,
68-
&mut u,
69-
ldu,
70-
&mut vt,
71-
ldvt,
72-
&mut superb,
73-
)
74-
.as_lapack_result()?;
75-
Ok(SVDOutput {
76-
s,
77-
u: if calc_u { Some(u) } else { None },
78-
vt: if calc_vt { Some(vt) } else { None },
79-
})
112+
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
113+
m,
114+
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
115+
n,
116+
&mut work,
117+
lwork as i32,
118+
&mut info,
119+
);
120+
info.as_lapack_result()?;
121+
match l {
122+
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
123+
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
124+
}
125+
}
126+
}
127+
};
128+
} // impl_svd_real!
129+
130+
impl_svd_real!(f64, lapack::dgesvd);
131+
impl_svd_real!(f32, lapack::sgesvd);
132+
133+
macro_rules! impl_svd_complex {
134+
($scalar:ty, $gesvd:path) => {
135+
impl SVD_ for $scalar {
136+
unsafe fn svd(
137+
l: MatrixLayout,
138+
calc_u: bool,
139+
calc_vt: bool,
140+
mut a: &mut [Self],
141+
) -> Result<SVDOutput<Self>> {
142+
let ju = match l {
143+
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
144+
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
145+
};
146+
let jvt = match l {
147+
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
148+
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
149+
};
150+
151+
let m = l.lda();
152+
let mut u = match ju {
153+
FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]),
154+
FlagSVD::No => None,
155+
};
156+
157+
let n = l.len();
158+
let mut vt = match jvt {
159+
FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]),
160+
FlagSVD::No => None,
161+
};
162+
163+
let k = std::cmp::min(m, n);
164+
let mut s = vec![Self::Real::zero(); k as usize];
165+
166+
let mut rwork = vec![Self::Real::zero(); 5 * k as usize];
167+
168+
// eval work size
169+
let mut info = 0;
170+
let mut work_size = [Self::zero()];
171+
$gesvd(
172+
ju as u8,
173+
jvt as u8,
174+
m,
175+
n,
176+
&mut a,
177+
m,
178+
&mut s,
179+
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
180+
m,
181+
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
182+
n,
183+
&mut work_size,
184+
-1,
185+
&mut rwork,
186+
&mut info,
187+
);
188+
info.as_lapack_result()?;
189+
190+
// calc
191+
let lwork = work_size[0].to_usize().unwrap();
192+
let mut work = vec![Self::zero(); lwork];
193+
$gesvd(
194+
ju as u8,
195+
jvt as u8,
196+
m,
197+
n,
198+
&mut a,
199+
m,
200+
&mut s,
201+
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
202+
m,
203+
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
204+
n,
205+
&mut work,
206+
lwork as i32,
207+
&mut rwork,
208+
&mut info,
209+
);
210+
info.as_lapack_result()?;
211+
match l {
212+
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
213+
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
214+
}
80215
}
81216
}
82217
};
83-
} // impl_svd!
218+
} // impl_svd_real!
84219

85-
impl_svd!(f64, lapacke::dgesvd);
86-
impl_svd!(f32, lapacke::sgesvd);
87-
impl_svd!(c64, lapacke::zgesvd);
88-
impl_svd!(c32, lapacke::cgesvd);
220+
impl_svd_complex!(c64, lapack::zgesvd);
221+
impl_svd_complex!(c32, lapack::cgesvd);

ndarray-linalg/src/svd.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
55
use ndarray::*;
66

7-
use super::convert::*;
87
use super::error::*;
98
use super::layout::*;
109
use super::types::*;
@@ -99,12 +98,27 @@ where
9998
let l = self.layout()?;
10099
let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? };
101100
let (n, m) = l.size();
102-
let u = svd_res
103-
.u
104-
.map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches"));
105-
let vt = svd_res
106-
.vt
107-
.map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches"));
101+
let n = n as usize;
102+
let m = m as usize;
103+
104+
let u = svd_res.u.map(|u| {
105+
assert_eq!(u.len(), n * n);
106+
match l {
107+
MatrixLayout::F { .. } => Array::from_shape_vec((n, n).f(), u),
108+
MatrixLayout::C { .. } => Array::from_shape_vec((n, n), u),
109+
}
110+
.unwrap()
111+
});
112+
113+
let vt = svd_res.vt.map(|vt| {
114+
assert_eq!(vt.len(), m * m);
115+
match l {
116+
MatrixLayout::F { .. } => Array::from_shape_vec((m, m).f(), vt),
117+
MatrixLayout::C { .. } => Array::from_shape_vec((m, m), vt),
118+
}
119+
.unwrap()
120+
});
121+
108122
let s = ArrayBase::from(svd_res.s);
109123
Ok((u, s, vt))
110124
}

ndarray-linalg/tests/svd.rs

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use ndarray::*;
22
use ndarray_linalg::*;
33
use std::cmp::min;
44

5-
fn test(a: &Array2<f64>) {
5+
fn test<T: Scalar + Lapack>(a: &Array2<T>) {
66
let (n, m) = a.dim();
77
let answer = a.clone();
88
println!("a = \n{:?}", a);
@@ -12,14 +12,14 @@ fn test(a: &Array2<f64>) {
1212
println!("u = \n{:?}", &u);
1313
println!("s = \n{:?}", &s);
1414
println!("v = \n{:?}", &vt);
15-
let mut sm = Array::zeros((n, m));
15+
let mut sm = Array::<T, _>::zeros((n, m));
1616
for i in 0..min(n, m) {
17-
sm[(i, i)] = s[i];
17+
sm[(i, i)] = T::from(s[i]).unwrap();
1818
}
19-
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
19+
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7));
2020
}
2121

22-
fn test_no_vt(a: &Array2<f64>) {
22+
fn test_no_vt<T: Scalar + Lapack>(a: &Array2<T>) {
2323
let (n, _m) = a.dim();
2424
println!("a = \n{:?}", a);
2525
let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap();
@@ -30,7 +30,7 @@ fn test_no_vt(a: &Array2<f64>) {
3030
assert_eq!(u.dim().1, n);
3131
}
3232

33-
fn test_no_u(a: &Array2<f64>) {
33+
fn test_no_u<T: Scalar + Lapack>(a: &Array2<T>) {
3434
let (_n, m) = a.dim();
3535
println!("a = \n{:?}", a);
3636
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap();
@@ -41,40 +41,52 @@ fn test_no_u(a: &Array2<f64>) {
4141
assert_eq!(vt.dim().1, m);
4242
}
4343

44-
fn test_diag_only(a: &Array2<f64>) {
44+
fn test_diag_only<T: Scalar + Lapack>(a: &Array2<T>) {
4545
println!("a = \n{:?}", a);
4646
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap();
4747
assert!(u.is_none());
4848
assert!(vt.is_none());
4949
}
5050

5151
macro_rules! test_svd_impl {
52-
($test:ident, $n:expr, $m:expr) => {
52+
($type:ty, $test:ident, $n:expr, $m:expr) => {
5353
paste::item! {
5454
#[test]
55-
fn [<svd_ $test _ $n x $m>]() {
55+
fn [<svd_ $type _ $test _ $n x $m>]() {
5656
let a = random(($n, $m));
57-
$test(&a);
57+
$test::<$type>(&a);
5858
}
5959

6060
#[test]
61-
fn [<svd_ $test _ $n x $m _t>]() {
61+
fn [<svd_ $type _ $test _ $n x $m _t>]() {
6262
let a = random(($n, $m).f());
63-
$test(&a);
63+
$test::<$type>(&a);
6464
}
6565
}
6666
};
6767
}
6868

69-
test_svd_impl!(test, 3, 3);
70-
test_svd_impl!(test_no_vt, 3, 3);
71-
test_svd_impl!(test_no_u, 3, 3);
72-
test_svd_impl!(test_diag_only, 3, 3);
73-
test_svd_impl!(test, 4, 3);
74-
test_svd_impl!(test_no_vt, 4, 3);
75-
test_svd_impl!(test_no_u, 4, 3);
76-
test_svd_impl!(test_diag_only, 4, 3);
77-
test_svd_impl!(test, 3, 4);
78-
test_svd_impl!(test_no_vt, 3, 4);
79-
test_svd_impl!(test_no_u, 3, 4);
80-
test_svd_impl!(test_diag_only, 3, 4);
69+
test_svd_impl!(f64, test, 3, 3);
70+
test_svd_impl!(f64, test_no_vt, 3, 3);
71+
test_svd_impl!(f64, test_no_u, 3, 3);
72+
test_svd_impl!(f64, test_diag_only, 3, 3);
73+
test_svd_impl!(f64, test, 4, 3);
74+
test_svd_impl!(f64, test_no_vt, 4, 3);
75+
test_svd_impl!(f64, test_no_u, 4, 3);
76+
test_svd_impl!(f64, test_diag_only, 4, 3);
77+
test_svd_impl!(f64, test, 3, 4);
78+
test_svd_impl!(f64, test_no_vt, 3, 4);
79+
test_svd_impl!(f64, test_no_u, 3, 4);
80+
test_svd_impl!(f64, test_diag_only, 3, 4);
81+
test_svd_impl!(c64, test, 3, 3);
82+
test_svd_impl!(c64, test_no_vt, 3, 3);
83+
test_svd_impl!(c64, test_no_u, 3, 3);
84+
test_svd_impl!(c64, test_diag_only, 3, 3);
85+
test_svd_impl!(c64, test, 4, 3);
86+
test_svd_impl!(c64, test_no_vt, 4, 3);
87+
test_svd_impl!(c64, test_no_u, 4, 3);
88+
test_svd_impl!(c64, test_diag_only, 4, 3);
89+
test_svd_impl!(c64, test, 3, 4);
90+
test_svd_impl!(c64, test_no_vt, 3, 4);
91+
test_svd_impl!(c64, test_no_u, 3, 4);
92+
test_svd_impl!(c64, test_diag_only, 3, 4);

0 commit comments

Comments
 (0)