From 2c787f577739cbe7d7c828067b5e57338f7abbbe Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 19 Jun 2019 15:18:26 +0900 Subject: [PATCH 1/6] Add test for SVD without u or vt --- tests/svd.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/svd.rs b/tests/svd.rs index 459b83d2..46718d87 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -3,6 +3,12 @@ use ndarray_linalg::*; use std::cmp::min; fn test(a: &Array2, n: usize, m: usize) { + test_both(a, n, m); + test_u(a, n, m); + test_vt(a, n, m); +} + +fn test_both(a: &Array2, n: usize, m: usize) { let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap(); @@ -18,6 +24,26 @@ fn test(a: &Array2, n: usize, m: usize) { assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); } +fn test_u(a: &Array2, n: usize, _m: usize) { + println!("a = \n{:?}", a); + let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap(); + assert!(u.is_some()); + assert!(vt.is_none()); + let u = u.unwrap(); + assert_eq!(u.dim().0, n); + assert_eq!(u.dim().1, n); +} + +fn test_vt(a: &Array2, _n: usize, m: usize) { + println!("a = \n{:?}", a); + let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap(); + assert!(u.is_none()); + assert!(vt.is_some()); + let vt = vt.unwrap(); + assert_eq!(vt.dim().0, m); + assert_eq!(vt.dim().1, m); +} + #[test] fn svd_square() { let a = random((3, 3)); From 48c21066744524a7dadf4e763e98d3e957a0947b Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 20 Jun 2019 04:13:39 +0900 Subject: [PATCH 2/6] Generate tests using paste crate --- Cargo.toml | 3 +++ tests/svd.rs | 71 ++++++++++++++++++++++------------------------------ 2 files changed, 33 insertions(+), 41 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 486f44f4..214c8ee4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,3 +46,6 @@ version = "0.6" default-features = false features = ["static"] optional = true + +[dev-dependencies] +paste = "*" diff --git a/tests/svd.rs b/tests/svd.rs index 46718d87..9c6ff376 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -2,13 +2,8 @@ use ndarray::*; use ndarray_linalg::*; use std::cmp::min; -fn test(a: &Array2, n: usize, m: usize) { - test_both(a, n, m); - test_u(a, n, m); - test_vt(a, n, m); -} - -fn test_both(a: &Array2, n: usize, m: usize) { +fn test(a: &Array2) { + let (n, m) = a.dim(); let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap(); @@ -24,7 +19,8 @@ fn test_both(a: &Array2, n: usize, m: usize) { assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); } -fn test_u(a: &Array2, n: usize, _m: usize) { +fn test_u(a: &Array2) { + let (n, _m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap(); assert!(u.is_some()); @@ -34,7 +30,8 @@ fn test_u(a: &Array2, n: usize, _m: usize) { assert_eq!(u.dim().1, n); } -fn test_vt(a: &Array2, _n: usize, m: usize) { +fn test_vt(a: &Array2) { + let (_n, m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap(); assert!(u.is_none()); @@ -44,38 +41,30 @@ fn test_vt(a: &Array2, _n: usize, m: usize) { assert_eq!(vt.dim().1, m); } -#[test] -fn svd_square() { - let a = random((3, 3)); - test(&a, 3, 3); -} - -#[test] -fn svd_square_t() { - let a = random((3, 3).f()); - test(&a, 3, 3); -} - -#[test] -fn svd_3x4() { - let a = random((3, 4)); - test(&a, 3, 4); -} +macro_rules! test_svd_impl { + ($test:ident, $n:expr, $m:expr) => { + paste::item! { + #[test] + fn []() { + let a = random(($n, $m)); + $test(&a); + } -#[test] -fn svd_3x4_t() { - let a = random((3, 4).f()); - test(&a, 3, 4); + #[test] + fn []() { + let a = random(($n, $m).f()); + $test(&a); + } + } + }; } -#[test] -fn svd_4x3() { - let a = random((4, 3)); - test(&a, 4, 3); -} - -#[test] -fn svd_4x3_t() { - let a = random((4, 3).f()); - test(&a, 4, 3); -} +test_svd_impl!(test, 3, 3); +test_svd_impl!(test_u, 3, 3); +test_svd_impl!(test_vt, 3, 3); +test_svd_impl!(test, 4, 3); +test_svd_impl!(test_u, 4, 3); +test_svd_impl!(test_vt, 4, 3); +test_svd_impl!(test, 3, 4); +test_svd_impl!(test_u, 3, 4); +test_svd_impl!(test_vt, 3, 4); From 9ca911e0eb9ca3d7c83ebc6bbbe9f2fd93c5517c Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 20 Jun 2019 04:20:20 +0900 Subject: [PATCH 3/6] Fix leading dimensions LAPACK document says > The leading dimension of the array U. LDU >= 1; if JOBU = 'S' or 'A' --- src/lapack/svd.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lapack/svd.rs b/src/lapack/svd.rs index 8766ce52..b7543f10 100644 --- a/src/lapack/svd.rs +++ b/src/lapack/svd.rs @@ -42,12 +42,12 @@ macro_rules! impl_svd { let (ju, ldu, mut u) = if calc_u { (FlagSVD::All, m, vec![Self::zero(); (m * m) as usize]) } else { - (FlagSVD::No, 0, Vec::new()) + (FlagSVD::No, 1, Vec::new()) }; let (jvt, ldvt, mut vt) = if calc_vt { (FlagSVD::All, n, vec![Self::zero(); (n * n) as usize]) } else { - (FlagSVD::No, 0, Vec::new()) + (FlagSVD::No, 1, Vec::new()) }; let mut s = vec![Self::Real::zero(); k as usize]; let mut superb = vec![Self::Real::zero(); (k - 1) as usize]; From 44821c9a1e6e9d01f18961b5813bbdf034b47f83 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 20 Jun 2019 04:29:26 +0900 Subject: [PATCH 4/6] Bug fix of ldvt --- src/lapack/svd.rs | 7 ++++--- src/svd.rs | 8 ++++++-- tests/svd.rs | 16 ++++++++-------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/lapack/svd.rs b/src/lapack/svd.rs index b7543f10..ca06502a 100644 --- a/src/lapack/svd.rs +++ b/src/lapack/svd.rs @@ -47,10 +47,11 @@ macro_rules! impl_svd { let (jvt, ldvt, mut vt) = if calc_vt { (FlagSVD::All, n, vec![Self::zero(); (n * n) as usize]) } else { - (FlagSVD::No, 1, Vec::new()) + (FlagSVD::No, n, Vec::new()) }; let mut s = vec![Self::Real::zero(); k as usize]; let mut superb = vec![Self::Real::zero(); (k - 1) as usize]; + dbg!(ldvt); let info = $gesvd( l.lapacke_layout(), ju as u8, @@ -70,8 +71,8 @@ macro_rules! impl_svd { info, SVDOutput { s: s, - u: if ldu > 0 { Some(u) } else { None }, - vt: if ldvt > 0 { Some(vt) } else { None }, + u: if calc_u { Some(u) } else { None }, + vt: if calc_vt { Some(vt) } else { None }, }, ) } diff --git a/src/svd.rs b/src/svd.rs index bbe0804d..3acce0c2 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -75,8 +75,12 @@ where let l = self.layout()?; let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? }; let (n, m) = l.size(); - let u = svd_res.u.map(|u| into_matrix(l.resized(n, n), u).unwrap()); - let vt = svd_res.vt.map(|vt| into_matrix(l.resized(m, m), vt).unwrap()); + let u = svd_res + .u + .map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches")); + let vt = svd_res + .vt + .map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches")); let s = ArrayBase::from_vec(svd_res.s); Ok((u, s, vt)) } diff --git a/tests/svd.rs b/tests/svd.rs index 9c6ff376..aed32072 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -19,7 +19,7 @@ fn test(a: &Array2) { assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); } -fn test_u(a: &Array2) { +fn test_no_vt(a: &Array2) { let (n, _m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap(); @@ -30,7 +30,7 @@ fn test_u(a: &Array2) { assert_eq!(u.dim().1, n); } -fn test_vt(a: &Array2) { +fn test_no_u(a: &Array2) { let (_n, m) = a.dim(); println!("a = \n{:?}", a); let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap(); @@ -60,11 +60,11 @@ macro_rules! test_svd_impl { } test_svd_impl!(test, 3, 3); -test_svd_impl!(test_u, 3, 3); -test_svd_impl!(test_vt, 3, 3); +test_svd_impl!(test_no_vt, 3, 3); +test_svd_impl!(test_no_u, 3, 3); test_svd_impl!(test, 4, 3); -test_svd_impl!(test_u, 4, 3); -test_svd_impl!(test_vt, 4, 3); +test_svd_impl!(test_no_vt, 4, 3); +test_svd_impl!(test_no_u, 4, 3); test_svd_impl!(test, 3, 4); -test_svd_impl!(test_u, 3, 4); -test_svd_impl!(test_vt, 3, 4); +test_svd_impl!(test_no_vt, 3, 4); +test_svd_impl!(test_no_u, 3, 4); From ea8dc6c3316c2867b73e93d83f0c59b2bd0cf44a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 20 Jun 2019 04:35:31 +0900 Subject: [PATCH 5/6] Add test for diag-only --- tests/svd.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/svd.rs b/tests/svd.rs index aed32072..acc6ffca 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -41,6 +41,13 @@ fn test_no_u(a: &Array2) { assert_eq!(vt.dim().1, m); } +fn test_diag_only(a: &Array2) { + println!("a = \n{:?}", a); + let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap(); + assert!(u.is_none()); + assert!(vt.is_none()); +} + macro_rules! test_svd_impl { ($test:ident, $n:expr, $m:expr) => { paste::item! { @@ -62,9 +69,12 @@ macro_rules! test_svd_impl { test_svd_impl!(test, 3, 3); test_svd_impl!(test_no_vt, 3, 3); test_svd_impl!(test_no_u, 3, 3); +test_svd_impl!(test_diag_only, 3, 3); test_svd_impl!(test, 4, 3); test_svd_impl!(test_no_vt, 4, 3); test_svd_impl!(test_no_u, 4, 3); +test_svd_impl!(test_diag_only, 4, 3); test_svd_impl!(test, 3, 4); test_svd_impl!(test_no_vt, 3, 4); test_svd_impl!(test_no_u, 3, 4); +test_svd_impl!(test_diag_only, 3, 4); From 4d06adb740c15d5d5231be707920f517860a22e9 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 20 Jun 2019 04:36:00 +0900 Subject: [PATCH 6/6] paste = "0.1" --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 214c8ee4..9257b5dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,4 +48,4 @@ features = ["static"] optional = true [dev-dependencies] -paste = "*" +paste = "0.1"