From 9cf83316f7610448c865e685e4deda21407b6efb Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 19 Jul 2020 02:15:45 +0900 Subject: [PATCH 1/7] Move test of least_squares into tests/, and minor cleanup --- ndarray-linalg/src/least_squares.rs | 205 ++------------------------ ndarray-linalg/tests/least_squares.rs | 161 ++++++++++++++++++++ 2 files changed, 177 insertions(+), 189 deletions(-) create mode 100644 ndarray-linalg/tests/least_squares.rs diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index df99f0e8..c8f26355 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -414,117 +414,9 @@ fn compute_residual_array1>( #[cfg(test)] mod tests { - use super::*; + use crate::{error::LinalgError, *}; use approx::AbsDiffEq; - use ndarray::{ArcArray1, ArcArray2, Array1, Array2, CowArray}; - use num_complex::Complex; - - // - // Test cases taken from the scipy test suite for the scipy lstsq function - // https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py - // - #[test] - fn scipy_test_simple_exact() { - let a = array![[1., 20.], [-30., 4.]]; - let bs = vec![ - array![[1., 0.], [0., 1.]], - array![[1.], [0.]], - array![[2., 1.], [-30., 4.]], - ]; - for b in &bs { - let res = a.least_squares(b).unwrap(); - assert_eq!(res.rank, 2); - let b_hat = a.dot(&res.solution); - let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0)); - assert!(res - .residual_sum_of_squares - .unwrap() - .abs_diff_eq(&rssq, 1e-12)); - assert!(b_hat.abs_diff_eq(&b, 1e-12)); - } - } - - #[test] - fn scipy_test_simple_overdetermined() { - let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; - let b: Array1 = array![1., 2., 3.]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - let b_hat = a.dot(&res.solution); - let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum(); - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); - assert!(res - .solution - .abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12)); - } - - #[test] - fn scipy_test_simple_overdetermined_f32() { - let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; - let b: Array1 = array![1., 2., 3.]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - let b_hat = a.dot(&res.solution); - let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum(); - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-6)); - assert!(res - .solution - .abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-6)); - } - - fn c(re: f64, im: f64) -> Complex { - Complex::new(re, im) - } - - #[test] - fn scipy_test_simple_overdetermined_complex() { - let a: Array2 = array![ - [c(1., 2.), c(2., 0.)], - [c(4., 0.), c(5., 0.)], - [c(3., 0.), c(4., 0.)] - ]; - let b: Array1 = array![c(1., 0.), c(2., 4.), c(3., 0.)]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - let b_hat = a.dot(&res.solution); - let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum(); - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); - assert!(res.solution.abs_diff_eq( - &array![ - c(-0.4831460674157303, 0.258426966292135), - c(0.921348314606741, 0.292134831460674) - ], - 1e-12 - )); - } - - #[test] - fn scipy_test_simple_underdetermined() { - let a: Array2 = array![[1., 2., 3.], [4., 5., 6.]]; - let b: Array1 = array![1., 2.]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - assert!(res.residual_sum_of_squares.is_none()); - let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777]; - assert!(res.solution.abs_diff_eq(&expected, 1e-12)); - } - - /// This test case tests the underdetermined case for multiple right hand - /// sides. Adapted from scipy lstsq tests. - #[test] - fn scipy_test_simple_underdetermined_nrhs() { - let a: Array2 = array![[1., 2., 3.], [4., 5., 6.]]; - let b: Array2 = array![[1., 1.], [2., 2.]]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - assert!(res.residual_sum_of_squares.is_none()); - let expected = array![ - [-0.055555555555555, -0.055555555555555], - [0.111111111111111, 0.111111111111111], - [0.277777777777777, 0.277777777777777] - ]; - assert!(res.solution.abs_diff_eq(&expected, 1e-12)); - } + use ndarray::*; // // Test that the different lest squares traits work as intended on the @@ -554,7 +446,7 @@ mod tests { } #[test] - fn test_least_squares_on_arc() { + fn on_arc() { let a: ArcArray2 = array![[1., 2.], [4., 5.], [3., 4.]].into_shared(); let b: ArcArray1 = array![1., 2., 3.].into_shared(); let res = a.least_squares(&b).unwrap(); @@ -562,7 +454,7 @@ mod tests { } #[test] - fn test_least_squares_on_cow() { + fn on_cow() { let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]); let b = CowArray::from(array![1., 2., 3.]); let res = a.least_squares(&b).unwrap(); @@ -570,7 +462,7 @@ mod tests { } #[test] - fn test_least_squares_on_view() { + fn on_view() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let b: Array1 = array![1., 2., 3.]; let av = a.view(); @@ -580,7 +472,7 @@ mod tests { } #[test] - fn test_least_squares_on_view_mut() { + fn on_view_mut() { let mut a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let mut b: Array1 = array![1., 2., 3.]; let av = a.view_mut(); @@ -590,7 +482,7 @@ mod tests { } #[test] - fn test_least_squares_into_on_owned() { + fn into_on_owned() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let b: Array1 = array![1., 2., 3.]; let ac = a.clone(); @@ -600,7 +492,7 @@ mod tests { } #[test] - fn test_least_squares_into_on_arc() { + fn into_on_arc() { let a: ArcArray2 = array![[1., 2.], [4., 5.], [3., 4.]].into_shared(); let b: ArcArray1 = array![1., 2., 3.].into_shared(); let a2 = a.clone(); @@ -610,7 +502,7 @@ mod tests { } #[test] - fn test_least_squares_into_on_cow() { + fn into_on_cow() { let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]); let b = CowArray::from(array![1., 2., 3.]); let a2 = a.clone(); @@ -620,7 +512,7 @@ mod tests { } #[test] - fn test_least_squares_in_place_on_owned() { + fn in_place_on_owned() { let a = array![[1., 2.], [4., 5.], [3., 4.]]; let b = array![1., 2., 3.]; let mut a2 = a.clone(); @@ -630,7 +522,7 @@ mod tests { } #[test] - fn test_least_squares_in_place_on_cow() { + fn in_place_on_cow() { let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]); let b = CowArray::from(array![1., 2., 3.]); let mut a2 = a.clone(); @@ -640,7 +532,7 @@ mod tests { } #[test] - fn test_least_squares_in_place_on_mut_view() { + fn in_place_on_mut_view() { let a = array![[1., 2.], [4., 5.], [3., 4.]]; let b = array![1., 2., 3.]; let mut a2 = a.clone(); @@ -651,95 +543,30 @@ mod tests { assert_result(&a, &b, &res); } - // - // Test cases taken from the netlib documentation at - // https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code - // - #[test] - fn netlib_lapack_example_for_dgels_1() { - let a: Array2 = array![ - [1., 1., 1.], - [2., 3., 4.], - [3., 5., 2.], - [4., 2., 5.], - [5., 4., 3.] - ]; - let b: Array1 = array![-10., 12., 14., 16., 18.]; - let expected: Array1 = array![2., 1., 1.]; - let result = a.least_squares(&b).unwrap(); - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); - - let residual = b - a.dot(&result.solution); - let resid_ssq = result.residual_sum_of_squares.unwrap(); - assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); - } - - #[test] - fn netlib_lapack_example_for_dgels_2() { - let a: Array2 = array![ - [1., 1., 1.], - [2., 3., 4.], - [3., 5., 2.], - [4., 2., 5.], - [5., 4., 3.] - ]; - let b: Array1 = array![-3., 14., 12., 16., 16.]; - let expected: Array1 = array![1., 1., 2.]; - let result = a.least_squares(&b).unwrap(); - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); - - let residual = b - a.dot(&result.solution); - let resid_ssq = result.residual_sum_of_squares.unwrap(); - assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); - } - - #[test] - fn netlib_lapack_example_for_dgels_nrhs() { - let a: Array2 = array![ - [1., 1., 1.], - [2., 3., 4.], - [3., 5., 2.], - [4., 2., 5.], - [5., 4., 3.] - ]; - let b: Array2 = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]]; - let expected: Array2 = array![[2., 1.], [1., 1.], [1., 2.]]; - let result = a.least_squares(&b).unwrap(); - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); - - let residual = &b - &a.dot(&result.solution); - let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0)); - assert!(result - .residual_sum_of_squares - .unwrap() - .abs_diff_eq(&residual_ssq, 1e-12)); - } - // // Testing error cases // - use crate::layout::MatrixLayout; #[test] - fn test_incompatible_shape_error_on_mismatching_num_rows() { + fn incompatible_shape_error_on_mismatching_num_rows() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let b: Array1 = array![1., 2.]; let res = a.least_squares(&b); match res { - Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {} + Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {} _ => panic!("Expected Err()"), } } #[test] - fn test_incompatible_shape_error_on_mismatching_layout() { + fn incompatible_shape_error_on_mismatching_layout() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let b = array![[1.], [2.]].t().to_owned(); assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 }); let res = a.least_squares(&b); match res { - Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {} + Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {} _ => panic!("Expected Err()"), } } diff --git a/ndarray-linalg/tests/least_squares.rs b/ndarray-linalg/tests/least_squares.rs new file mode 100644 index 00000000..e59233ce --- /dev/null +++ b/ndarray-linalg/tests/least_squares.rs @@ -0,0 +1,161 @@ +use approx::AbsDiffEq; +use ndarray::*; +use ndarray_linalg::*; +use num_complex::Complex; + +fn c(re: f64, im: f64) -> Complex { + Complex::new(re, im) +} + +// +// Test cases taken from the scipy test suite for the scipy lstsq function +// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/basic.py +// +#[test] +fn least_squares_exact() { + let a = array![[1., 20.], [-30., 4.]]; + let bs = vec![ + array![[1., 0.], [0., 1.]], + array![[1.], [0.]], + array![[2., 1.], [-30., 4.]], + ]; + for b in &bs { + let res = a.least_squares(b).unwrap(); + assert_eq!(res.rank, 2); + let b_hat = a.dot(&res.solution); + let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0)); + assert!(res + .residual_sum_of_squares + .unwrap() + .abs_diff_eq(&rssq, 1e-12)); + assert!(b_hat.abs_diff_eq(&b, 1e-12)); + } +} + +#[test] +fn least_squares_overdetermined() { + let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; + let b: Array1 = array![1., 2., 3.]; + let res = a.least_squares(&b).unwrap(); + assert_eq!(res.rank, 2); + let b_hat = a.dot(&res.solution); + let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum(); + assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); + assert!(res + .solution + .abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12)); +} + +#[test] +fn least_squares_overdetermined_complex() { + let a: Array2 = array![ + [c(1., 2.), c(2., 0.)], + [c(4., 0.), c(5., 0.)], + [c(3., 0.), c(4., 0.)] + ]; + let b: Array1 = array![c(1., 0.), c(2., 4.), c(3., 0.)]; + let res = a.least_squares(&b).unwrap(); + assert_eq!(res.rank, 2); + let b_hat = a.dot(&res.solution); + let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum(); + assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); + assert!(res.solution.abs_diff_eq( + &array![ + c(-0.4831460674157303, 0.258426966292135), + c(0.921348314606741, 0.292134831460674) + ], + 1e-12 + )); +} + +#[test] +fn least_squares_underdetermined() { + let a: Array2 = array![[1., 2., 3.], [4., 5., 6.]]; + let b: Array1 = array![1., 2.]; + let res = a.least_squares(&b).unwrap(); + assert_eq!(res.rank, 2); + assert!(res.residual_sum_of_squares.is_none()); + let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777]; + assert!(res.solution.abs_diff_eq(&expected, 1e-12)); +} + +/// This test case tests the underdetermined case for multiple right hand +/// sides. Adapted from scipy lstsq tests. +#[test] +fn least_squares_underdetermined_nrhs() { + let a: Array2 = array![[1., 2., 3.], [4., 5., 6.]]; + let b: Array2 = array![[1., 1.], [2., 2.]]; + let res = a.least_squares(&b).unwrap(); + assert_eq!(res.rank, 2); + assert!(res.residual_sum_of_squares.is_none()); + let expected = array![ + [-0.055555555555555, -0.055555555555555], + [0.111111111111111, 0.111111111111111], + [0.277777777777777, 0.277777777777777] + ]; + assert!(res.solution.abs_diff_eq(&expected, 1e-12)); +} + +// +// Test cases taken from the netlib documentation at +// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code +// +#[test] +fn netlib_lapack_example_for_dgels_1() { + let a: Array2 = array![ + [1., 1., 1.], + [2., 3., 4.], + [3., 5., 2.], + [4., 2., 5.], + [5., 4., 3.] + ]; + let b: Array1 = array![-10., 12., 14., 16., 18.]; + let expected: Array1 = array![2., 1., 1.]; + let result = a.least_squares(&b).unwrap(); + assert!(result.solution.abs_diff_eq(&expected, 1e-12)); + + let residual = b - a.dot(&result.solution); + let resid_ssq = result.residual_sum_of_squares.unwrap(); + assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); +} + +#[test] +fn netlib_lapack_example_for_dgels_2() { + let a: Array2 = array![ + [1., 1., 1.], + [2., 3., 4.], + [3., 5., 2.], + [4., 2., 5.], + [5., 4., 3.] + ]; + let b: Array1 = array![-3., 14., 12., 16., 16.]; + let expected: Array1 = array![1., 1., 2.]; + let result = a.least_squares(&b).unwrap(); + assert!(result.solution.abs_diff_eq(&expected, 1e-12)); + + let residual = b - a.dot(&result.solution); + let resid_ssq = result.residual_sum_of_squares.unwrap(); + assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); +} + +#[test] +fn netlib_lapack_example_for_dgels_nrhs() { + let a: Array2 = array![ + [1., 1., 1.], + [2., 3., 4.], + [3., 5., 2.], + [4., 2., 5.], + [5., 4., 3.] + ]; + let b: Array2 = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]]; + let expected: Array2 = array![[2., 1.], [1., 1.], [1., 2.]]; + let result = a.least_squares(&b).unwrap(); + assert!(result.solution.abs_diff_eq(&expected, 1e-12)); + + let residual = &b - &a.dot(&result.solution); + let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0)); + assert!(result + .residual_sum_of_squares + .unwrap() + .abs_diff_eq(&residual_ssq, 1e-12)); +} From e51f966710b85d0e10651dfba7ce9b8953559f14 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 22 Jul 2020 00:03:33 +0900 Subject: [PATCH 2/7] Rewrite test cases --- ndarray-linalg/tests/least_squares.rs | 180 ++++++-------------------- 1 file changed, 41 insertions(+), 139 deletions(-) diff --git a/ndarray-linalg/tests/least_squares.rs b/ndarray-linalg/tests/least_squares.rs index e59233ce..51903251 100644 --- a/ndarray-linalg/tests/least_squares.rs +++ b/ndarray-linalg/tests/least_squares.rs @@ -1,161 +1,63 @@ +/// Solve least square problem `|b - Ax|` use approx::AbsDiffEq; use ndarray::*; use ndarray_linalg::*; -use num_complex::Complex; -fn c(re: f64, im: f64) -> Complex { - Complex::new(re, im) -} - -// -// Test cases taken from the scipy test suite for the scipy lstsq function -// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/basic.py -// +/// A is square. `x = A^{-1} b`, `|b - Ax| = 0` #[test] fn least_squares_exact() { - let a = array![[1., 20.], [-30., 4.]]; - let bs = vec![ - array![[1., 0.], [0., 1.]], - array![[1.], [0.]], - array![[2., 1.], [-30., 4.]], - ]; - for b in &bs { - let res = a.least_squares(b).unwrap(); - assert_eq!(res.rank, 2); - let b_hat = a.dot(&res.solution); - let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0)); - assert!(res - .residual_sum_of_squares - .unwrap() - .abs_diff_eq(&rssq, 1e-12)); - assert!(b_hat.abs_diff_eq(&b, 1e-12)); - } -} + let a: Array2 = random((3, 3)); + let b: Array1 = random(3); + let result = a.least_squares(&b).unwrap(); + // unpack result + let x = result.solution; + let residual_l2_square = result.residual_sum_of_squares.unwrap()[()]; -#[test] -fn least_squares_overdetermined() { - let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; - let b: Array1 = array![1., 2., 3.]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - let b_hat = a.dot(&res.solution); - let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum(); - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); - assert!(res - .solution - .abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12)); -} + // must be full-rank + assert_eq!(result.rank, 3); -#[test] -fn least_squares_overdetermined_complex() { - let a: Array2 = array![ - [c(1., 2.), c(2., 0.)], - [c(4., 0.), c(5., 0.)], - [c(3., 0.), c(4., 0.)] - ]; - let b: Array1 = array![c(1., 0.), c(2., 4.), c(3., 0.)]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - let b_hat = a.dot(&res.solution); - let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum(); - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); - assert!(res.solution.abs_diff_eq( - &array![ - c(-0.4831460674157303, 0.258426966292135), - c(0.921348314606741, 0.292134831460674) - ], - 1e-12 - )); -} + // |b - Ax| == 0 + assert!(residual_l2_square < 1.0e-7); -#[test] -fn least_squares_underdetermined() { - let a: Array2 = array![[1., 2., 3.], [4., 5., 6.]]; - let b: Array1 = array![1., 2.]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - assert!(res.residual_sum_of_squares.is_none()); - let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777]; - assert!(res.solution.abs_diff_eq(&expected, 1e-12)); + // b == Ax + let ax = a.dot(&x); + assert_close_l2!(&b, &ax, 1.0e-7); } -/// This test case tests the underdetermined case for multiple right hand -/// sides. Adapted from scipy lstsq tests. +/// #column < #row case. +/// Linear problem is overdetermined, `|b - Ax| > 0`. #[test] -fn least_squares_underdetermined_nrhs() { - let a: Array2 = array![[1., 2., 3.], [4., 5., 6.]]; - let b: Array2 = array![[1., 1.], [2., 2.]]; - let res = a.least_squares(&b).unwrap(); - assert_eq!(res.rank, 2); - assert!(res.residual_sum_of_squares.is_none()); - let expected = array![ - [-0.055555555555555, -0.055555555555555], - [0.111111111111111, 0.111111111111111], - [0.277777777777777, 0.277777777777777] - ]; - assert!(res.solution.abs_diff_eq(&expected, 1e-12)); -} - -// -// Test cases taken from the netlib documentation at -// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code -// -#[test] -fn netlib_lapack_example_for_dgels_1() { - let a: Array2 = array![ - [1., 1., 1.], - [2., 3., 4.], - [3., 5., 2.], - [4., 2., 5.], - [5., 4., 3.] - ]; - let b: Array1 = array![-10., 12., 14., 16., 18.]; - let expected: Array1 = array![2., 1., 1.]; +fn least_squares_overdetermined() { + let a: Array2 = random((4, 3)); + let b: Array1 = random(4); let result = a.least_squares(&b).unwrap(); - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); + // unpack result + let x = result.solution; + let residual_l2_square = result.residual_sum_of_squares.unwrap()[()]; - let residual = b - a.dot(&result.solution); - let resid_ssq = result.residual_sum_of_squares.unwrap(); - assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); -} + // Must be full-rank + assert_eq!(result.rank, 3); -#[test] -fn netlib_lapack_example_for_dgels_2() { - let a: Array2 = array![ - [1., 1., 1.], - [2., 3., 4.], - [3., 5., 2.], - [4., 2., 5.], - [5., 4., 3.] - ]; - let b: Array1 = array![-3., 14., 12., 16., 16.]; - let expected: Array1 = array![1., 1., 2.]; - let result = a.least_squares(&b).unwrap(); - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); + // eval `residual = b - Ax` + let residual = &b - &a.dot(&x); + assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), 1e-12)); - let residual = b - a.dot(&result.solution); - let resid_ssq = result.residual_sum_of_squares.unwrap(); - assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); + // `|residual| < |b|` + assert!(residual.norm_l2() < b.norm_l2()); } +/// #column > #row case. +/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique #[test] -fn netlib_lapack_example_for_dgels_nrhs() { - let a: Array2 = array![ - [1., 1., 1.], - [2., 3., 4.], - [3., 5., 2.], - [4., 2., 5.], - [5., 4., 3.] - ]; - let b: Array2 = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]]; - let expected: Array2 = array![[2., 1.], [1., 1.], [1., 2.]]; +fn least_squares_underdetermined() { + let a: Array2 = random((3, 4)); + let b: Array1 = random(3); let result = a.least_squares(&b).unwrap(); - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); + assert_eq!(result.rank, 3); + assert!(result.residual_sum_of_squares.is_none()); - let residual = &b - &a.dot(&result.solution); - let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0)); - assert!(result - .residual_sum_of_squares - .unwrap() - .abs_diff_eq(&residual_ssq, 1e-12)); + // b == Ax + let x = result.solution; + let ax = a.dot(&x); + assert_close_l2!(&b, &ax, 1.0e-7); } From a8a07fba5f3f46f78aa006c113bd2e23c1cc50f1 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 22 Jul 2020 18:24:31 +0900 Subject: [PATCH 3/7] Test for F-continuous and complex cases --- ndarray-linalg/tests/least_squares.rs | 98 ++++++++++++++++++++++----- 1 file changed, 82 insertions(+), 16 deletions(-) diff --git a/ndarray-linalg/tests/least_squares.rs b/ndarray-linalg/tests/least_squares.rs index 51903251..c388c9d7 100644 --- a/ndarray-linalg/tests/least_squares.rs +++ b/ndarray-linalg/tests/least_squares.rs @@ -4,10 +4,8 @@ use ndarray::*; use ndarray_linalg::*; /// A is square. `x = A^{-1} b`, `|b - Ax| = 0` -#[test] -fn least_squares_exact() { - let a: Array2 = random((3, 3)); - let b: Array1 = random(3); +fn test_exact(a: Array2) { + let b: Array1 = random(3); let result = a.least_squares(&b).unwrap(); // unpack result let x = result.solution; @@ -17,19 +15,43 @@ fn least_squares_exact() { assert_eq!(result.rank, 3); // |b - Ax| == 0 - assert!(residual_l2_square < 1.0e-7); + assert!(residual_l2_square < T::real(1.0e-4)); // b == Ax let ax = a.dot(&x); - assert_close_l2!(&b, &ax, 1.0e-7); + assert_close_l2!(&b, &ax, T::real(1.0e-4)); } +macro_rules! impl_exact { + ($scalar:ty) => { + paste::item! { + #[test] + fn []() { + let a: Array2 = random((3, 3)); + test_exact(a) + } + + #[test] + fn []() { + let a: Array2 = random((3, 3).f()); + test_exact(a) + } + } + }; +} + +impl_exact!(f32); +impl_exact!(f64); +impl_exact!(c32); +impl_exact!(c64); + /// #column < #row case. /// Linear problem is overdetermined, `|b - Ax| > 0`. -#[test] -fn least_squares_overdetermined() { - let a: Array2 = random((4, 3)); - let b: Array1 = random(4); +fn test_overdetermined(a: Array2) +where + T::Real: AbsDiffEq, +{ + let b: Array1 = random(4); let result = a.least_squares(&b).unwrap(); // unpack result let x = result.solution; @@ -40,18 +62,39 @@ fn least_squares_overdetermined() { // eval `residual = b - Ax` let residual = &b - &a.dot(&x); - assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), 1e-12)); + assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), T::real(1.0e-4))); // `|residual| < |b|` assert!(residual.norm_l2() < b.norm_l2()); } +macro_rules! impl_overdetermined { + ($scalar:ty) => { + paste::item! { + #[test] + fn []() { + let a: Array2 = random((4, 3)); + test_overdetermined(a) + } + + #[test] + fn []() { + let a: Array2 = random((4, 3).f()); + test_overdetermined(a) + } + } + }; +} + +impl_overdetermined!(f32); +impl_overdetermined!(f64); +impl_overdetermined!(c32); +impl_overdetermined!(c64); + /// #column > #row case. /// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique -#[test] -fn least_squares_underdetermined() { - let a: Array2 = random((3, 4)); - let b: Array1 = random(3); +fn test_underdetermined(a: Array2) { + let b: Array1 = random(3); let result = a.least_squares(&b).unwrap(); assert_eq!(result.rank, 3); assert!(result.residual_sum_of_squares.is_none()); @@ -59,5 +102,28 @@ fn least_squares_underdetermined() { // b == Ax let x = result.solution; let ax = a.dot(&x); - assert_close_l2!(&b, &ax, 1.0e-7); + assert_close_l2!(&b, &ax, T::real(1.0e-4)); +} + +macro_rules! impl_underdetermined { + ($scalar:ty) => { + paste::item! { + #[test] + fn []() { + let a: Array2 = random((3, 4)); + test_underdetermined(a) + } + + #[test] + fn []() { + let a: Array2 = random((3, 4).f()); + test_underdetermined(a) + } + } + }; } + +impl_underdetermined!(f32); +impl_underdetermined!(f64); +impl_underdetermined!(c32); +impl_underdetermined!(c64); From 05a2bcdbb7e0f8d9471c8ea6e4344abe1b1c012a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 23 Jul 2020 18:21:28 +0900 Subject: [PATCH 4/7] Fix F-contiguous case --- lax/src/least_squares.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index eaf496f2..69553a44 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -42,6 +42,10 @@ macro_rules! impl_least_squares { } let k = ::std::cmp::min(m, n); let nrhs = 1; + let ldb = match a_layout { + MatrixLayout::F { .. } => m.max(n), + MatrixLayout::C { .. } => 1, + }; let rcond: Self::Real = -1.; let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; let mut rank: i32 = 0; @@ -54,9 +58,7 @@ macro_rules! impl_least_squares { a, a_layout.lda(), b, - // this is the 'leading dimension of b', in the case where - // b is a single vector, this is 1 - nrhs, + ldb, &mut singular_values, rcond, &mut rank, From 6bbe87d74e70b62f70e5d55e5497cd9758cbb30e Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 24 Jul 2020 02:42:05 +0900 Subject: [PATCH 5/7] Add tests for multi-column `b` --- ndarray-linalg/tests/least_squares_nrhs.rs | 189 +++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 ndarray-linalg/tests/least_squares_nrhs.rs diff --git a/ndarray-linalg/tests/least_squares_nrhs.rs b/ndarray-linalg/tests/least_squares_nrhs.rs new file mode 100644 index 00000000..835091e5 --- /dev/null +++ b/ndarray-linalg/tests/least_squares_nrhs.rs @@ -0,0 +1,189 @@ +/// Solve least square problem `|b - Ax|` with multi-column `b` +use approx::AbsDiffEq; +use ndarray::*; +use ndarray_linalg::*; + +/// A is square. `x = A^{-1} b`, `|b - Ax| = 0` +fn test_exact(a: Array2, b: Array2) { + assert_eq!(a.layout().unwrap().size(), (3, 3)); + assert_eq!(b.layout().unwrap().size(), (3, 2)); + + let result = a.least_squares(&b).unwrap(); + // unpack result + let x: Array2 = result.solution; + let residual_l2_square: Array1 = result.residual_sum_of_squares.unwrap(); + + // must be full-rank + assert_eq!(result.rank, 3); + + // |b - Ax| == 0 + for &residual in &residual_l2_square { + assert!(residual < T::real(1.0e-4)); + } + + // b == Ax + let ax = a.dot(&x); + assert_close_l2!(&b, &ax, T::real(1.0e-4)); +} + +macro_rules! impl_exact { + ($scalar:ty) => { + paste::item! { + #[test] + fn []() { + let a: Array2 = random((3, 3)); + let b: Array2 = random((3, 2)); + test_exact(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((3, 3)); + let b: Array2 = random((3, 2).f()); + test_exact(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((3, 3).f()); + let b: Array2 = random((3, 2)); + test_exact(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((3, 3).f()); + let b: Array2 = random((3, 2).f()); + test_exact(a, b) + } + } + }; +} + +impl_exact!(f32); +impl_exact!(f64); +impl_exact!(c32); +impl_exact!(c64); + +/// #column < #row case. +/// Linear problem is overdetermined, `|b - Ax| > 0`. +fn test_overdetermined(a: Array2, bs: Array2) +where + T::Real: AbsDiffEq, +{ + assert_eq!(a.layout().unwrap().size(), (4, 3)); + assert_eq!(bs.layout().unwrap().size(), (4, 2)); + + let result = a.least_squares(&bs).unwrap(); + // unpack result + let xs = result.solution; + let residual_l2_square = result.residual_sum_of_squares.unwrap(); + + // Must be full-rank + assert_eq!(result.rank, 3); + + for j in 0..2 { + let b = bs.index_axis(Axis(1), j); + let x = xs.index_axis(Axis(1), j); + let residual = &b - &a.dot(&x); + let residual_l2_sq = residual_l2_square[j]; + assert!(residual_l2_sq.abs_diff_eq(&residual.norm_l2().powi(2), T::real(1.0e-4))); + + // `|residual| < |b|` + assert!(residual.norm_l2() < b.norm_l2()); + } +} + +macro_rules! impl_overdetermined { + ($scalar:ty) => { + paste::item! { + #[test] + fn []() { + let a: Array2 = random((4, 3)); + let b: Array2 = random((4, 2)); + test_overdetermined(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((4, 3).f()); + let b: Array2 = random((4, 2)); + test_overdetermined(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((4, 3)); + let b: Array2 = random((4, 2).f()); + test_overdetermined(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((4, 3).f()); + let b: Array2 = random((4, 2).f()); + test_overdetermined(a, b) + } + } + }; +} + +impl_overdetermined!(f32); +impl_overdetermined!(f64); +impl_overdetermined!(c32); +impl_overdetermined!(c64); + +/// #column > #row case. +/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique +fn test_underdetermined(a: Array2, b: Array2) { + assert_eq!(a.layout().unwrap().size(), (3, 4)); + assert_eq!(b.layout().unwrap().size(), (3, 2)); + + let result = a.least_squares(&b).unwrap(); + assert_eq!(result.rank, 3); + assert!(result.residual_sum_of_squares.is_none()); + + // b == Ax + let x = result.solution; + let ax = a.dot(&x); + assert_close_l2!(&b, &ax, T::real(1.0e-4)); +} + +macro_rules! impl_underdetermined { + ($scalar:ty) => { + paste::item! { + #[test] + fn []() { + let a: Array2 = random((3, 4)); + let b: Array2 = random((3, 2)); + test_underdetermined(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((3, 4).f()); + let b: Array2 = random((3, 2)); + test_underdetermined(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((3, 4)); + let b: Array2 = random((3, 2).f()); + test_underdetermined(a, b) + } + + #[test] + fn []() { + let a: Array2 = random((3, 4).f()); + let b: Array2 = random((3, 2).f()); + test_underdetermined(a, b) + } + } + }; +} + +impl_underdetermined!(f32); +impl_underdetermined!(f64); +impl_underdetermined!(c32); +impl_underdetermined!(c64); From b60cfc851c543c507b551046b79acfda12ae34e3 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 24 Jul 2020 17:07:30 +0900 Subject: [PATCH 6/7] Comment out Unsupported C/F-mixed cases --- ndarray-linalg/tests/least_squares_nrhs.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ndarray-linalg/tests/least_squares_nrhs.rs b/ndarray-linalg/tests/least_squares_nrhs.rs index 835091e5..4c964697 100644 --- a/ndarray-linalg/tests/least_squares_nrhs.rs +++ b/ndarray-linalg/tests/least_squares_nrhs.rs @@ -36,6 +36,8 @@ macro_rules! impl_exact { test_exact(a, b) } + /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 + #[test] fn []() { let a: Array2 = random((3, 3)); @@ -50,6 +52,8 @@ macro_rules! impl_exact { test_exact(a, b) } + */ + #[test] fn []() { let a: Array2 = random((3, 3).f()); @@ -104,6 +108,8 @@ macro_rules! impl_overdetermined { test_overdetermined(a, b) } + /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 + #[test] fn []() { let a: Array2 = random((4, 3).f()); @@ -118,6 +124,8 @@ macro_rules! impl_overdetermined { test_overdetermined(a, b) } + */ + #[test] fn []() { let a: Array2 = random((4, 3).f()); @@ -159,6 +167,8 @@ macro_rules! impl_underdetermined { test_underdetermined(a, b) } + /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 + #[test] fn []() { let a: Array2 = random((3, 4).f()); @@ -173,6 +183,8 @@ macro_rules! impl_underdetermined { test_underdetermined(a, b) } + */ + #[test] fn []() { let a: Array2 = random((3, 4).f()); From f6a9c2ac4ca85b599e543feb2a34579f589bdf84 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 24 Jul 2020 21:23:50 +0900 Subject: [PATCH 7/7] Fix memory layout for overdetermined case --- ndarray-linalg/src/least_squares.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index c8f26355..18d2033f 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -60,7 +60,7 @@ //! // `a` and `b` have been moved, no longer valid //! ``` -use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Ix0, Ix1, Ix2}; +use ndarray::*; use crate::error::*; use crate::lapack::least_squares::*; @@ -352,7 +352,10 @@ where // we need a new rhs b/c it will be overwritten with the solution // for which we need `n` entries let k = rhs.shape()[1]; - let mut new_rhs = Array2::::zeros((n, k)); + let mut new_rhs = match self.layout()? { + MatrixLayout::C { .. } => Array2::::zeros((n, k)), + MatrixLayout::F { .. } => Array2::::zeros((n, k).f()), + }; new_rhs.slice_mut(s![0..m, ..]).assign(rhs); compute_least_squares_nrhs(self, &mut new_rhs) } else {