From 3004354311602dbd6c5fe8c5c8c0f7df62de3055 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Thu, 27 May 2021 20:58:32 -0400 Subject: [PATCH 1/2] Add more tests for Solve --- ndarray-linalg/tests/solve.rs | 202 +++++++++++++++++++++++++++++----- 1 file changed, 174 insertions(+), 28 deletions(-) diff --git a/ndarray-linalg/tests/solve.rs b/ndarray-linalg/tests/solve.rs index d069ec7a..86e95bdc 100644 --- a/ndarray-linalg/tests/solve.rs +++ b/ndarray-linalg/tests/solve.rs @@ -1,42 +1,188 @@ -use ndarray::*; -use ndarray_linalg::*; +use ndarray::prelude::*; +use ndarray_linalg::{ + assert_aclose, assert_close_l2, c32, c64, random, random_hpd, solve::*, OperationNorm, Scalar, +}; + +macro_rules! test_solve { + ( + [$($elem_type:ty => $rtol:expr),*], + $a_ident:ident = $a:expr, + $x_ident:ident = $x:expr, + b = $b:expr, + $solve:ident, + ) => { + $({ + let $a_ident: Array2<$elem_type> = $a; + let $x_ident: Array1<$elem_type> = $x; + let b: Array1<$elem_type> = $b; + let a = $a_ident; + let x = $x_ident; + let rtol = $rtol; + assert_close_l2!(&a.$solve(&b).unwrap(), &x, rtol); + assert_close_l2!(&a.factorize().unwrap().$solve(&b).unwrap(), &x, rtol); + assert_close_l2!(&a.factorize_into().unwrap().$solve(&b).unwrap(), &x, rtol); + })* + }; +} + +macro_rules! test_solve_into { + ( + [$($elem_type:ty => $rtol:expr),*], + $a_ident:ident = $a:expr, + $x_ident:ident = $x:expr, + b = $b:expr, + $solve_into:ident, + ) => { + $({ + let $a_ident: Array2<$elem_type> = $a; + let $x_ident: Array1<$elem_type> = $x; + let b: Array1<$elem_type> = $b; + let a = $a_ident; + let x = $x_ident; + let rtol = $rtol; + assert_close_l2!(&a.$solve_into(b.clone()).unwrap(), &x, rtol); + assert_close_l2!(&a.factorize().unwrap().$solve_into(b.clone()).unwrap(), &x, rtol); + assert_close_l2!(&a.factorize_into().unwrap().$solve_into(b.clone()).unwrap(), &x, rtol); + })* + }; +} + +macro_rules! test_solve_inplace { + ( + [$($elem_type:ty => $rtol:expr),*], + $a_ident:ident = $a:expr, + $x_ident:ident = $x:expr, + b = $b:expr, + $solve_inplace:ident, + ) => { + $({ + let $a_ident: Array2<$elem_type> = $a; + let $x_ident: Array1<$elem_type> = $x; + let b: Array1<$elem_type> = $b; + let a = $a_ident; + let x = $x_ident; + let rtol = $rtol; + { + let mut b = b.clone(); + assert_close_l2!(&a.$solve_inplace(&mut b).unwrap(), &x, rtol); + assert_close_l2!(&b, &x, rtol); + } + { + let mut b = b.clone(); + assert_close_l2!(&a.factorize().unwrap().$solve_inplace(&mut b).unwrap(), &x, rtol); + assert_close_l2!(&b, &x, rtol); + } + { + let mut b = b.clone(); + assert_close_l2!(&a.factorize_into().unwrap().$solve_inplace(&mut b).unwrap(), &x, rtol); + assert_close_l2!(&b, &x, rtol); + } + })* + }; +} + +macro_rules! test_solve_all { + ( + [$($elem_type:ty => $rtol:expr),*], + $a_ident:ident = $a:expr, + $x_ident:ident = $x:expr, + b = $b:expr, + [$solve:ident, $solve_into:ident, $solve_inplace:ident], + ) => { + test_solve!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve,); + test_solve_into!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve_into,); + test_solve_inplace!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve_inplace,); + }; +} + +#[test] +fn solve_random_float() { + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [f32 => 1e-3, f64 => 1e-9], + a = random([n; 2].set_f(set_f)), + x = random(n), + b = a.dot(&x), + [solve, solve_into, solve_inplace], + ); + } + } +} + +#[test] +fn solve_random_complex() { + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [c32 => 1e-3, c64 => 1e-9], + a = random([n; 2].set_f(set_f)), + x = random(n), + b = a.dot(&x), + [solve, solve_into, solve_inplace], + ); + } + } +} #[test] -fn solve_random() { - let a: Array2 = random((3, 3)); - let x: Array1 = random(3); - let b = a.dot(&x); - let y = a.solve_into(b).unwrap(); - assert_close_l2!(&x, &y, 1e-7); +fn solve_t_random_float() { + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [f32 => 1e-3, f64 => 1e-9], + a = random([n; 2].set_f(set_f)), + x = random(n), + b = a.t().dot(&x), + [solve_t, solve_t_into, solve_t_inplace], + ); + } + } } #[test] -fn solve_random_t() { - let a: Array2 = random((3, 3).f()); - let x: Array1 = random(3); - let b = a.dot(&x); - let y = a.solve_into(b).unwrap(); - assert_close_l2!(&x, &y, 1e-7); +fn solve_t_random_complex() { + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [c32 => 1e-3, c64 => 1e-9], + a = random([n; 2].set_f(set_f)), + x = random(n), + b = a.t().dot(&x), + [solve_t, solve_t_into, solve_t_inplace], + ); + } + } } #[test] -fn solve_factorized() { - let a: Array2 = random((3, 3)); - let ans: Array1 = random(3); - let b = a.dot(&ans); - let f = a.factorize_into().unwrap(); - let x = f.solve_into(b).unwrap(); - assert_close_l2!(&x, &ans, 1e-7); +fn solve_h_random_float() { + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [f32 => 1e-3, f64 => 1e-9], + a = random([n; 2].set_f(set_f)), + x = random(n), + b = a.t().mapv(|x| x.conj()).dot(&x), + [solve_h, solve_h_into, solve_h_inplace], + ); + } + } } #[test] -fn solve_factorized_t() { - let a: Array2 = random((3, 3).f()); - let ans: Array1 = random(3); - let b = a.dot(&ans); - let f = a.factorize_into().unwrap(); - let x = f.solve_into(b).unwrap(); - assert_close_l2!(&x, &ans, 1e-7); +fn solve_h_random_complex() { + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [c32 => 1e-3, c64 => 1e-9], + a = random([n; 2].set_f(set_f)), + x = random(n), + b = a.t().mapv(|x| x.conj()).dot(&x), + [solve_h, solve_h_into, solve_h_inplace], + ); + } + } } #[test] From ffe65cb1b9fd4dfe738020b0179909bed5fa5f23 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Thu, 27 May 2021 21:03:00 -0400 Subject: [PATCH 2/2] Fix Solve::solve_h for complex inputs with C layout --- lax/src/solve.rs | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 39498a04..3851bde2 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -75,18 +75,49 @@ macro_rules! impl_solve { ipiv: &Pivot, b: &mut [Self], ) -> Result<()> { - let t = match l { + // If the array has C layout, then it needs to be handled + // specially, since LAPACK expects a Fortran-layout array. + // Reinterpreting a C layout array as Fortran layout is + // equivalent to transposing it. So, we can handle the "no + // transpose" and "transpose" cases by swapping to "transpose" + // or "no transpose", respectively. For the "Hermite" case, we + // can take advantage of the following: + // + // ```text + // A^H x = b + // ⟺ conj(A^T) x = b + // ⟺ conj(conj(A^T) x) = conj(b) + // ⟺ conj(conj(A^T)) conj(x) = conj(b) + // ⟺ A^T conj(x) = conj(b) + // ``` + // + // So, we can handle this case by switching to "no transpose" + // (which is equivalent to transposing the array since it will + // be reinterpreted as Fortran layout) and applying the + // elementwise conjugate to `x` and `b`. + let (t, conj) = match l { MatrixLayout::C { .. } => match t { - Transpose::No => Transpose::Transpose, - Transpose::Transpose | Transpose::Hermite => Transpose::No, + Transpose::No => (Transpose::Transpose, false), + Transpose::Transpose => (Transpose::No, false), + Transpose::Hermite => (Transpose::No, true), }, - _ => t, + MatrixLayout::F { .. } => (t, false), }; let (n, _) = l.size(); let nrhs = 1; let ldb = l.lda(); let mut info = 0; + if conj { + for b_elem in &mut *b { + *b_elem = b_elem.conj(); + } + } unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; + if conj { + for b_elem in &mut *b { + *b_elem = b_elem.conj(); + } + } info.as_lapack_result()?; Ok(()) }