Skip to content

SVD using LAPACK #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 166 additions & 33 deletions lax/src/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,27 @@

use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::Zero;
use num_traits::{ToPrimitive, Zero};

#[repr(u8)]
#[derive(Debug, Copy, Clone)]
enum FlagSVD {
All = b'A',
// OverWrite = b'O',
// Separately = b'S',
No = b'N',
}

impl FlagSVD {
fn from_bool(calc_uv: bool) -> Self {
if calc_uv {
FlagSVD::All
} else {
FlagSVD::No
}
}
}

/// Result of SVD
pub struct SVDOutput<A: Scalar> {
/// diagonal values
Expand All @@ -24,6 +35,7 @@ pub struct SVDOutput<A: Scalar> {

/// Wraps `*gesvd`
pub trait SVD_: Scalar {
/// Calculate singular value decomposition $ A = U \Sigma V^T $
unsafe fn svd(
l: MatrixLayout,
calc_u: bool,
Expand All @@ -32,7 +44,7 @@ pub trait SVD_: Scalar {
) -> Result<SVDOutput<Self>>;
}

macro_rules! impl_svd {
macro_rules! impl_svd_real {
($scalar:ty, $gesvd:path) => {
impl SVD_ for $scalar {
unsafe fn svd(
Expand All @@ -41,48 +53,169 @@ macro_rules! impl_svd {
calc_vt: bool,
mut a: &mut [Self],
) -> Result<SVDOutput<Self>> {
let (m, n) = l.size();
let k = ::std::cmp::min(n, m);
let lda = l.lda();
let (ju, ldu, mut u) = if calc_u {
(FlagSVD::All, m, vec![Self::zero(); (m * m) as usize])
} else {
(FlagSVD::No, 1, Vec::new())
let ju = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
};
let (jvt, ldvt, mut vt) = if calc_vt {
(FlagSVD::All, n, vec![Self::zero(); (n * n) as usize])
} else {
(FlagSVD::No, n, Vec::new())
let jvt = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
};

let m = l.lda();
let mut u = match ju {
FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]),
FlagSVD::No => None,
};

let n = l.len();
let mut vt = match jvt {
FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]),
FlagSVD::No => None,
};

let k = std::cmp::min(m, n);
let mut s = vec![Self::Real::zero(); k as usize];
let mut superb = vec![Self::Real::zero(); (k - 1) as usize];

// eval work size
let mut info = 0;
let mut work_size = [Self::zero()];
$gesvd(
ju as u8,
jvt as u8,
m,
n,
&mut a,
m,
&mut s,
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
m,
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work_size,
-1,
&mut info,
);
info.as_lapack_result()?;

// calc
let lwork = work_size[0].to_usize().unwrap();
let mut work = vec![Self::zero(); lwork];
$gesvd(
l.lapacke_layout(),
ju as u8,
jvt as u8,
m,
n,
&mut a,
lda,
m,
&mut s,
&mut u,
ldu,
&mut vt,
ldvt,
&mut superb,
)
.as_lapack_result()?;
Ok(SVDOutput {
s,
u: if calc_u { Some(u) } else { None },
vt: if calc_vt { Some(vt) } else { None },
})
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
m,
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work,
lwork as i32,
&mut info,
);
info.as_lapack_result()?;
match l {
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
}
}
}
};
} // impl_svd_real!

impl_svd_real!(f64, lapack::dgesvd);
impl_svd_real!(f32, lapack::sgesvd);

macro_rules! impl_svd_complex {
($scalar:ty, $gesvd:path) => {
impl SVD_ for $scalar {
unsafe fn svd(
l: MatrixLayout,
calc_u: bool,
calc_vt: bool,
mut a: &mut [Self],
) -> Result<SVDOutput<Self>> {
let ju = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
};
let jvt = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
};

let m = l.lda();
let mut u = match ju {
FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]),
FlagSVD::No => None,
};

let n = l.len();
let mut vt = match jvt {
FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]),
FlagSVD::No => None,
};

let k = std::cmp::min(m, n);
let mut s = vec![Self::Real::zero(); k as usize];

let mut rwork = vec![Self::Real::zero(); 5 * k as usize];

// eval work size
let mut info = 0;
let mut work_size = [Self::zero()];
$gesvd(
ju as u8,
jvt as u8,
m,
n,
&mut a,
m,
&mut s,
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
m,
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work_size,
-1,
&mut rwork,
&mut info,
);
info.as_lapack_result()?;

// calc
let lwork = work_size[0].to_usize().unwrap();
let mut work = vec![Self::zero(); lwork];
$gesvd(
ju as u8,
jvt as u8,
m,
n,
&mut a,
m,
&mut s,
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
m,
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work,
lwork as i32,
&mut rwork,
&mut info,
);
info.as_lapack_result()?;
match l {
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
}
}
}
};
} // impl_svd!
} // impl_svd_real!

impl_svd!(f64, lapacke::dgesvd);
impl_svd!(f32, lapacke::sgesvd);
impl_svd!(c64, lapacke::zgesvd);
impl_svd!(c32, lapacke::cgesvd);
impl_svd_complex!(c64, lapack::zgesvd);
impl_svd_complex!(c32, lapack::cgesvd);
28 changes: 21 additions & 7 deletions ndarray-linalg/src/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

use ndarray::*;

use super::convert::*;
use super::error::*;
use super::layout::*;
use super::types::*;
Expand Down Expand Up @@ -99,12 +98,27 @@ where
let l = self.layout()?;
let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? };
let (n, m) = l.size();
let u = svd_res
.u
.map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches"));
let vt = svd_res
.vt
.map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches"));
let n = n as usize;
let m = m as usize;

let u = svd_res.u.map(|u| {
assert_eq!(u.len(), n * n);
match l {
MatrixLayout::F { .. } => Array::from_shape_vec((n, n).f(), u),
MatrixLayout::C { .. } => Array::from_shape_vec((n, n), u),
}
.unwrap()
});

let vt = svd_res.vt.map(|vt| {
assert_eq!(vt.len(), m * m);
match l {
MatrixLayout::F { .. } => Array::from_shape_vec((m, m).f(), vt),
MatrixLayout::C { .. } => Array::from_shape_vec((m, m), vt),
}
.unwrap()
});

let s = ArrayBase::from(svd_res.s);
Ok((u, s, vt))
}
Expand Down
60 changes: 36 additions & 24 deletions ndarray-linalg/tests/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ndarray::*;
use ndarray_linalg::*;
use std::cmp::min;

fn test(a: &Array2<f64>) {
fn test<T: Scalar + Lapack>(a: &Array2<T>) {
let (n, m) = a.dim();
let answer = a.clone();
println!("a = \n{:?}", a);
Expand All @@ -12,14 +12,14 @@ fn test(a: &Array2<f64>) {
println!("u = \n{:?}", &u);
println!("s = \n{:?}", &s);
println!("v = \n{:?}", &vt);
let mut sm = Array::zeros((n, m));
let mut sm = Array::<T, _>::zeros((n, m));
for i in 0..min(n, m) {
sm[(i, i)] = s[i];
sm[(i, i)] = T::from(s[i]).unwrap();
}
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7));
}

fn test_no_vt(a: &Array2<f64>) {
fn test_no_vt<T: Scalar + Lapack>(a: &Array2<T>) {
let (n, _m) = a.dim();
println!("a = \n{:?}", a);
let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap();
Expand All @@ -30,7 +30,7 @@ fn test_no_vt(a: &Array2<f64>) {
assert_eq!(u.dim().1, n);
}

fn test_no_u(a: &Array2<f64>) {
fn test_no_u<T: Scalar + Lapack>(a: &Array2<T>) {
let (_n, m) = a.dim();
println!("a = \n{:?}", a);
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap();
Expand All @@ -41,40 +41,52 @@ fn test_no_u(a: &Array2<f64>) {
assert_eq!(vt.dim().1, m);
}

fn test_diag_only(a: &Array2<f64>) {
fn test_diag_only<T: Scalar + Lapack>(a: &Array2<T>) {
println!("a = \n{:?}", a);
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap();
assert!(u.is_none());
assert!(vt.is_none());
}

macro_rules! test_svd_impl {
($test:ident, $n:expr, $m:expr) => {
($type:ty, $test:ident, $n:expr, $m:expr) => {
paste::item! {
#[test]
fn [<svd_ $test _ $n x $m>]() {
fn [<svd_ $type _ $test _ $n x $m>]() {
let a = random(($n, $m));
$test(&a);
$test::<$type>(&a);
}

#[test]
fn [<svd_ $test _ $n x $m _t>]() {
fn [<svd_ $type _ $test _ $n x $m _t>]() {
let a = random(($n, $m).f());
$test(&a);
$test::<$type>(&a);
}
}
};
}

test_svd_impl!(test, 3, 3);
test_svd_impl!(test_no_vt, 3, 3);
test_svd_impl!(test_no_u, 3, 3);
test_svd_impl!(test_diag_only, 3, 3);
test_svd_impl!(test, 4, 3);
test_svd_impl!(test_no_vt, 4, 3);
test_svd_impl!(test_no_u, 4, 3);
test_svd_impl!(test_diag_only, 4, 3);
test_svd_impl!(test, 3, 4);
test_svd_impl!(test_no_vt, 3, 4);
test_svd_impl!(test_no_u, 3, 4);
test_svd_impl!(test_diag_only, 3, 4);
test_svd_impl!(f64, test, 3, 3);
test_svd_impl!(f64, test_no_vt, 3, 3);
test_svd_impl!(f64, test_no_u, 3, 3);
test_svd_impl!(f64, test_diag_only, 3, 3);
test_svd_impl!(f64, test, 4, 3);
test_svd_impl!(f64, test_no_vt, 4, 3);
test_svd_impl!(f64, test_no_u, 4, 3);
test_svd_impl!(f64, test_diag_only, 4, 3);
test_svd_impl!(f64, test, 3, 4);
test_svd_impl!(f64, test_no_vt, 3, 4);
test_svd_impl!(f64, test_no_u, 3, 4);
test_svd_impl!(f64, test_diag_only, 3, 4);
test_svd_impl!(c64, test, 3, 3);
test_svd_impl!(c64, test_no_vt, 3, 3);
test_svd_impl!(c64, test_no_u, 3, 3);
test_svd_impl!(c64, test_diag_only, 3, 3);
test_svd_impl!(c64, test, 4, 3);
test_svd_impl!(c64, test_no_vt, 4, 3);
test_svd_impl!(c64, test_no_u, 4, 3);
test_svd_impl!(c64, test_diag_only, 4, 3);
test_svd_impl!(c64, test, 3, 4);
test_svd_impl!(c64, test_no_vt, 3, 4);
test_svd_impl!(c64, test_no_u, 3, 4);
test_svd_impl!(c64, test_diag_only, 3, 4);