Skip to content

Revise tests for least-square problems #227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 24, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lax/src/least_squares.rs
Original file line number Diff line number Diff line change
@@ -42,6 +42,10 @@ macro_rules! impl_least_squares {
}
let k = ::std::cmp::min(m, n);
let nrhs = 1;
let ldb = match a_layout {
MatrixLayout::F { .. } => m.max(n),
MatrixLayout::C { .. } => 1,
};
let rcond: Self::Real = -1.;
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
let mut rank: i32 = 0;
@@ -54,9 +58,7 @@ macro_rules! impl_least_squares {
a,
a_layout.lda(),
b,
// this is the 'leading dimension of b', in the case where
// b is a single vector, this is 1
nrhs,
ldb,
&mut singular_values,
rcond,
&mut rank,
212 changes: 21 additions & 191 deletions ndarray-linalg/src/least_squares.rs
Original file line number Diff line number Diff line change
@@ -60,7 +60,7 @@
//! // `a` and `b` have been moved, no longer valid
//! ```
use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Ix0, Ix1, Ix2};
use ndarray::*;

use crate::error::*;
use crate::lapack::least_squares::*;
@@ -352,7 +352,10 @@ where
// we need a new rhs b/c it will be overwritten with the solution
// for which we need `n` entries
let k = rhs.shape()[1];
let mut new_rhs = Array2::<E>::zeros((n, k));
let mut new_rhs = match self.layout()? {
MatrixLayout::C { .. } => Array2::<E>::zeros((n, k)),
MatrixLayout::F { .. } => Array2::<E>::zeros((n, k).f()),
};
new_rhs.slice_mut(s![0..m, ..]).assign(rhs);
compute_least_squares_nrhs(self, &mut new_rhs)
} else {
@@ -414,117 +417,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(

#[cfg(test)]
mod tests {
use super::*;
use crate::{error::LinalgError, *};
use approx::AbsDiffEq;
use ndarray::{ArcArray1, ArcArray2, Array1, Array2, CowArray};
use num_complex::Complex;

//
// Test cases taken from the scipy test suite for the scipy lstsq function
// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
//
#[test]
fn scipy_test_simple_exact() {
let a = array![[1., 20.], [-30., 4.]];
let bs = vec![
array![[1., 0.], [0., 1.]],
array![[1.], [0.]],
array![[2., 1.], [-30., 4.]],
];
for b in &bs {
let res = a.least_squares(b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0));
assert!(res
.residual_sum_of_squares
.unwrap()
.abs_diff_eq(&rssq, 1e-12));
assert!(b_hat.abs_diff_eq(&b, 1e-12));
}
}

#[test]
fn scipy_test_simple_overdetermined() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
assert!(res
.solution
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
}

#[test]
fn scipy_test_simple_overdetermined_f32() {
let a: Array2<f32> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f32> = array![1., 2., 3.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-6));
assert!(res
.solution
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-6));
}

fn c(re: f64, im: f64) -> Complex<f64> {
Complex::new(re, im)
}

#[test]
fn scipy_test_simple_overdetermined_complex() {
let a: Array2<c64> = array![
[c(1., 2.), c(2., 0.)],
[c(4., 0.), c(5., 0.)],
[c(3., 0.), c(4., 0.)]
];
let b: Array1<c64> = array![c(1., 0.), c(2., 4.), c(3., 0.)];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
assert!(res.solution.abs_diff_eq(
&array![
c(-0.4831460674157303, 0.258426966292135),
c(0.921348314606741, 0.292134831460674)
],
1e-12
));
}

#[test]
fn scipy_test_simple_underdetermined() {
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
let b: Array1<f64> = array![1., 2.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
assert!(res.residual_sum_of_squares.is_none());
let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777];
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
}

/// This test case tests the underdetermined case for multiple right hand
/// sides. Adapted from scipy lstsq tests.
#[test]
fn scipy_test_simple_underdetermined_nrhs() {
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
let b: Array2<f64> = array![[1., 1.], [2., 2.]];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
assert!(res.residual_sum_of_squares.is_none());
let expected = array![
[-0.055555555555555, -0.055555555555555],
[0.111111111111111, 0.111111111111111],
[0.277777777777777, 0.277777777777777]
];
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
}
use ndarray::*;

//
// Test that the different lest squares traits work as intended on the
@@ -554,23 +449,23 @@ mod tests {
}

#[test]
fn test_least_squares_on_arc() {
fn on_arc() {
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
let res = a.least_squares(&b).unwrap();
assert_result(&a, &b, &res);
}

#[test]
fn test_least_squares_on_cow() {
fn on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let res = a.least_squares(&b).unwrap();
assert_result(&a, &b, &res);
}

#[test]
fn test_least_squares_on_view() {
fn on_view() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let av = a.view();
@@ -580,7 +475,7 @@ mod tests {
}

#[test]
fn test_least_squares_on_view_mut() {
fn on_view_mut() {
let mut a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let mut b: Array1<f64> = array![1., 2., 3.];
let av = a.view_mut();
@@ -590,7 +485,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_owned() {
fn into_on_owned() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let ac = a.clone();
@@ -600,7 +495,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_arc() {
fn into_on_arc() {
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
let a2 = a.clone();
@@ -610,7 +505,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_cow() {
fn into_on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let a2 = a.clone();
@@ -620,7 +515,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_owned() {
fn in_place_on_owned() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![1., 2., 3.];
let mut a2 = a.clone();
@@ -630,7 +525,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_cow() {
fn in_place_on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let mut a2 = a.clone();
@@ -640,7 +535,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_mut_view() {
fn in_place_on_mut_view() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![1., 2., 3.];
let mut a2 = a.clone();
@@ -651,95 +546,30 @@ mod tests {
assert_result(&a, &b, &res);
}

//
// Test cases taken from the netlib documentation at
// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
//
#[test]
fn netlib_lapack_example_for_dgels_1() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
let expected: Array1<f64> = array![2., 1., 1.];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = b - a.dot(&result.solution);
let resid_ssq = result.residual_sum_of_squares.unwrap();
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
}

#[test]
fn netlib_lapack_example_for_dgels_2() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array1<f64> = array![-3., 14., 12., 16., 16.];
let expected: Array1<f64> = array![1., 1., 2.];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = b - a.dot(&result.solution);
let resid_ssq = result.residual_sum_of_squares.unwrap();
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
}

#[test]
fn netlib_lapack_example_for_dgels_nrhs() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array2<f64> = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
let expected: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = &b - &a.dot(&result.solution);
let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0));
assert!(result
.residual_sum_of_squares
.unwrap()
.abs_diff_eq(&residual_ssq, 1e-12));
}

//
// Testing error cases
//
use crate::layout::MatrixLayout;

#[test]
fn test_incompatible_shape_error_on_mismatching_num_rows() {
fn incompatible_shape_error_on_mismatching_num_rows() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2.];
let res = a.least_squares(&b);
match res {
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
_ => panic!("Expected Err()"),
}
}

#[test]
fn test_incompatible_shape_error_on_mismatching_layout() {
fn incompatible_shape_error_on_mismatching_layout() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![[1.], [2.]].t().to_owned();
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });

let res = a.least_squares(&b);
match res {
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
_ => panic!("Expected Err()"),
}
}
129 changes: 129 additions & 0 deletions ndarray-linalg/tests/least_squares.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/// Solve least square problem `|b - Ax|`
use approx::AbsDiffEq;
use ndarray::*;
use ndarray_linalg::*;

/// A is square. `x = A^{-1} b`, `|b - Ax| = 0`
fn test_exact<T: Scalar + Lapack>(a: Array2<T>) {
let b: Array1<T> = random(3);
let result = a.least_squares(&b).unwrap();
// unpack result
let x = result.solution;
let residual_l2_square = result.residual_sum_of_squares.unwrap()[()];

// must be full-rank
assert_eq!(result.rank, 3);

// |b - Ax| == 0
assert!(residual_l2_square < T::real(1.0e-4));

// b == Ax
let ax = a.dot(&x);
assert_close_l2!(&b, &ax, T::real(1.0e-4));
}

macro_rules! impl_exact {
($scalar:ty) => {
paste::item! {
#[test]
fn [<least_squares_ $scalar _exact>]() {
let a: Array2<f64> = random((3, 3));
test_exact(a)
}
Comment on lines +28 to +32
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/f64/$scalar/g


#[test]
fn [<least_squares_ $scalar _exact_t>]() {
let a: Array2<f64> = random((3, 3).f());
test_exact(a)
}
}
};
}

impl_exact!(f32);
impl_exact!(f64);
impl_exact!(c32);
impl_exact!(c64);

/// #column < #row case.
/// Linear problem is overdetermined, `|b - Ax| > 0`.
fn test_overdetermined<T: Scalar + Lapack>(a: Array2<T>)
where
T::Real: AbsDiffEq<Epsilon = T::Real>,
{
let b: Array1<T> = random(4);
let result = a.least_squares(&b).unwrap();
// unpack result
let x = result.solution;
let residual_l2_square = result.residual_sum_of_squares.unwrap()[()];

// Must be full-rank
assert_eq!(result.rank, 3);

// eval `residual = b - Ax`
let residual = &b - &a.dot(&x);
assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), T::real(1.0e-4)));

// `|residual| < |b|`
assert!(residual.norm_l2() < b.norm_l2());
}

macro_rules! impl_overdetermined {
($scalar:ty) => {
paste::item! {
#[test]
fn [<least_squares_ $scalar _overdetermined>]() {
let a: Array2<f64> = random((4, 3));
test_overdetermined(a)
}

#[test]
fn [<least_squares_ $scalar _overdetermined_t>]() {
let a: Array2<f64> = random((4, 3).f());
test_overdetermined(a)
}
}
};
}

impl_overdetermined!(f32);
impl_overdetermined!(f64);
impl_overdetermined!(c32);
impl_overdetermined!(c64);

/// #column > #row case.
/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique
fn test_underdetermined<T: Scalar + Lapack>(a: Array2<T>) {
let b: Array1<T> = random(3);
let result = a.least_squares(&b).unwrap();
assert_eq!(result.rank, 3);
assert!(result.residual_sum_of_squares.is_none());

// b == Ax
let x = result.solution;
let ax = a.dot(&x);
assert_close_l2!(&b, &ax, T::real(1.0e-4));
}

macro_rules! impl_underdetermined {
($scalar:ty) => {
paste::item! {
#[test]
fn [<least_squares_ $scalar _underdetermined>]() {
let a: Array2<f64> = random((3, 4));
test_underdetermined(a)
}

#[test]
fn [<least_squares_ $scalar _underdetermined_t>]() {
let a: Array2<f64> = random((3, 4).f());
test_underdetermined(a)
}
}
};
}

impl_underdetermined!(f32);
impl_underdetermined!(f64);
impl_underdetermined!(c32);
impl_underdetermined!(c64);
201 changes: 201 additions & 0 deletions ndarray-linalg/tests/least_squares_nrhs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/// Solve least square problem `|b - Ax|` with multi-column `b`
use approx::AbsDiffEq;
use ndarray::*;
use ndarray_linalg::*;

/// A is square. `x = A^{-1} b`, `|b - Ax| = 0`
fn test_exact<T: Scalar + Lapack>(a: Array2<T>, b: Array2<T>) {
assert_eq!(a.layout().unwrap().size(), (3, 3));
assert_eq!(b.layout().unwrap().size(), (3, 2));

let result = a.least_squares(&b).unwrap();
// unpack result
let x: Array2<T> = result.solution;
let residual_l2_square: Array1<T::Real> = result.residual_sum_of_squares.unwrap();

// must be full-rank
assert_eq!(result.rank, 3);

// |b - Ax| == 0
for &residual in &residual_l2_square {
assert!(residual < T::real(1.0e-4));
}

// b == Ax
let ax = a.dot(&x);
assert_close_l2!(&b, &ax, T::real(1.0e-4));
}

macro_rules! impl_exact {
($scalar:ty) => {
paste::item! {
#[test]
fn [<least_squares_ $scalar _exact_ac_bc>]() {
let a: Array2<f64> = random((3, 3));
let b: Array2<f64> = random((3, 2));
test_exact(a, b)
}

/* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234
#[test]
fn [<least_squares_ $scalar _exact_ac_bf>]() {
let a: Array2<f64> = random((3, 3));
let b: Array2<f64> = random((3, 2).f());
test_exact(a, b)
}
#[test]
fn [<least_squares_ $scalar _exact_af_bc>]() {
let a: Array2<f64> = random((3, 3).f());
let b: Array2<f64> = random((3, 2));
test_exact(a, b)
}
*/

#[test]
fn [<least_squares_ $scalar _exact_af_bf>]() {
let a: Array2<f64> = random((3, 3).f());
let b: Array2<f64> = random((3, 2).f());
test_exact(a, b)
}
}
};
}

impl_exact!(f32);
impl_exact!(f64);
impl_exact!(c32);
impl_exact!(c64);

/// #column < #row case.
/// Linear problem is overdetermined, `|b - Ax| > 0`.
fn test_overdetermined<T: Scalar + Lapack>(a: Array2<T>, bs: Array2<T>)
where
T::Real: AbsDiffEq<Epsilon = T::Real>,
{
assert_eq!(a.layout().unwrap().size(), (4, 3));
assert_eq!(bs.layout().unwrap().size(), (4, 2));

let result = a.least_squares(&bs).unwrap();
// unpack result
let xs = result.solution;
let residual_l2_square = result.residual_sum_of_squares.unwrap();

// Must be full-rank
assert_eq!(result.rank, 3);

for j in 0..2 {
let b = bs.index_axis(Axis(1), j);
let x = xs.index_axis(Axis(1), j);
let residual = &b - &a.dot(&x);
let residual_l2_sq = residual_l2_square[j];
assert!(residual_l2_sq.abs_diff_eq(&residual.norm_l2().powi(2), T::real(1.0e-4)));

// `|residual| < |b|`
assert!(residual.norm_l2() < b.norm_l2());
}
}

macro_rules! impl_overdetermined {
($scalar:ty) => {
paste::item! {
#[test]
fn [<least_squares_ $scalar _overdetermined_ac_bc>]() {
let a: Array2<f64> = random((4, 3));
let b: Array2<f64> = random((4, 2));
test_overdetermined(a, b)
}

/* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234
#[test]
fn [<least_squares_ $scalar _overdetermined_af_bc>]() {
let a: Array2<f64> = random((4, 3).f());
let b: Array2<f64> = random((4, 2));
test_overdetermined(a, b)
}
#[test]
fn [<least_squares_ $scalar _overdetermined_ac_bf>]() {
let a: Array2<f64> = random((4, 3));
let b: Array2<f64> = random((4, 2).f());
test_overdetermined(a, b)
}
*/

#[test]
fn [<least_squares_ $scalar _overdetermined_af_bf>]() {
let a: Array2<f64> = random((4, 3).f());
let b: Array2<f64> = random((4, 2).f());
test_overdetermined(a, b)
}
}
};
}

impl_overdetermined!(f32);
impl_overdetermined!(f64);
impl_overdetermined!(c32);
impl_overdetermined!(c64);

/// #column > #row case.
/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique
fn test_underdetermined<T: Scalar + Lapack>(a: Array2<T>, b: Array2<T>) {
assert_eq!(a.layout().unwrap().size(), (3, 4));
assert_eq!(b.layout().unwrap().size(), (3, 2));

let result = a.least_squares(&b).unwrap();
assert_eq!(result.rank, 3);
assert!(result.residual_sum_of_squares.is_none());

// b == Ax
let x = result.solution;
let ax = a.dot(&x);
assert_close_l2!(&b, &ax, T::real(1.0e-4));
}

macro_rules! impl_underdetermined {
($scalar:ty) => {
paste::item! {
#[test]
fn [<least_squares_ $scalar _underdetermined_ac_bc>]() {
let a: Array2<f64> = random((3, 4));
let b: Array2<f64> = random((3, 2));
test_underdetermined(a, b)
}

/* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234
#[test]
fn [<least_squares_ $scalar _underdetermined_af_bc>]() {
let a: Array2<f64> = random((3, 4).f());
let b: Array2<f64> = random((3, 2));
test_underdetermined(a, b)
}
#[test]
fn [<least_squares_ $scalar _underdetermined_ac_bf>]() {
let a: Array2<f64> = random((3, 4));
let b: Array2<f64> = random((3, 2).f());
test_underdetermined(a, b)
}
*/

#[test]
fn [<least_squares_ $scalar _underdetermined_af_bf>]() {
let a: Array2<f64> = random((3, 4).f());
let b: Array2<f64> = random((3, 2).f());
test_underdetermined(a, b)
}
}
};
}

impl_underdetermined!(f32);
impl_underdetermined!(f64);
impl_underdetermined!(c32);
impl_underdetermined!(c64);