-
Notifications
You must be signed in to change notification settings - Fork 28
Add deviation functions #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+635
−2
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
58627a3
Port deviation functions from StatsBase.jl
munckymagik 5b54c65
SQUASH return type doc fixes
munckymagik f306f34
SQUASH try parenthesis to highlight the square root
munckymagik 55cb103
SQUASH fix copy and paste error in docs
munckymagik 9b4a26f
SQUASH add link from package summary to DeviationExt
munckymagik File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,376 @@ | ||
use ndarray::{ArrayBase, Data, Dimension, Zip}; | ||
use num_traits::{Signed, ToPrimitive}; | ||
use std::convert::Into; | ||
use std::ops::AddAssign; | ||
|
||
use crate::errors::{MultiInputError, ShapeMismatch}; | ||
|
||
/// An extension trait for `ArrayBase` providing functions | ||
/// to compute different deviation measures. | ||
pub trait DeviationExt<A, S, D> | ||
where | ||
S: Data<Elem = A>, | ||
D: Dimension, | ||
{ | ||
/// Counts the number of indices at which the elements of the arrays `self` | ||
/// and `other` are equal. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError> | ||
where | ||
A: PartialEq; | ||
|
||
/// Counts the number of indices at which the elements of the arrays `self` | ||
/// and `other` are not equal. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError> | ||
where | ||
A: PartialEq; | ||
|
||
/// Computes the [squared L2 distance] between `self` and `other`. | ||
/// | ||
/// ```text | ||
/// n | ||
/// ∑ |aᵢ - bᵢ|² | ||
/// i=1 | ||
/// ``` | ||
/// | ||
/// where `self` is `a` and `other` is `b`. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
/// | ||
/// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance | ||
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed; | ||
|
||
/// Computes the [L2 distance] between `self` and `other`. | ||
/// | ||
/// ```text | ||
/// n | ||
/// √ ( ∑ |aᵢ - bᵢ|² ) | ||
/// i=1 | ||
/// ``` | ||
/// | ||
/// where `self` is `a` and `other` is `b`. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
/// | ||
/// **Panics** if the type cast from `A` to `f64` fails. | ||
/// | ||
/// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance | ||
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive; | ||
|
||
/// Computes the [L1 distance] between `self` and `other`. | ||
/// | ||
/// ```text | ||
/// n | ||
/// ∑ |aᵢ - bᵢ| | ||
/// i=1 | ||
/// ``` | ||
/// | ||
/// where `self` is `a` and `other` is `b`. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
/// | ||
/// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry | ||
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed; | ||
|
||
/// Computes the [L∞ distance] between `self` and `other`. | ||
/// | ||
/// ```text | ||
/// max(|aᵢ - bᵢ|) | ||
/// ᵢ | ||
/// ``` | ||
/// | ||
/// where `self` is `a` and `other` is `b`. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
/// | ||
/// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance | ||
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError> | ||
where | ||
A: Clone + PartialOrd + Signed; | ||
|
||
/// Computes the [mean absolute error] between `self` and `other`. | ||
/// | ||
/// ```text | ||
/// n | ||
/// 1/n * ∑ |aᵢ - bᵢ| | ||
/// i=1 | ||
/// ``` | ||
/// | ||
/// where `self` is `a` and `other` is `b`. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
/// | ||
/// **Panics** if the type cast from `A` to `f64` fails. | ||
/// | ||
/// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error | ||
fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive; | ||
|
||
/// Computes the [mean squared error] between `self` and `other`. | ||
/// | ||
/// ```text | ||
/// n | ||
/// 1/n * ∑ |aᵢ - bᵢ|² | ||
/// i=1 | ||
/// ``` | ||
/// | ||
/// where `self` is `a` and `other` is `b`. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
/// | ||
/// **Panics** if the type cast from `A` to `f64` fails. | ||
/// | ||
/// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error | ||
fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive; | ||
|
||
/// Computes the unnormalized [root-mean-square error] between `self` and `other`. | ||
/// | ||
/// ```text | ||
/// √ mse(a, b) | ||
/// ``` | ||
/// | ||
/// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
/// | ||
/// **Panics** if the type cast from `A` to `f64` fails. | ||
/// | ||
/// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation | ||
fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive; | ||
|
||
/// Computes the [peak signal-to-noise ratio] between `self` and `other`. | ||
/// | ||
/// ```text | ||
/// 10 * log10(maxv^2 / mse(a, b)) | ||
/// ``` | ||
/// | ||
/// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error | ||
/// and `maxv` is the maximum possible value either array can take. | ||
/// | ||
/// The following **errors** may be returned: | ||
/// | ||
/// * `MultiInputError::EmptyInput` if `self` is empty | ||
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape | ||
/// | ||
/// **Panics** if the type cast from `A` to `f64` fails. | ||
/// | ||
/// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio | ||
fn peak_signal_to_noise_ratio( | ||
&self, | ||
other: &ArrayBase<S, D>, | ||
maxv: A, | ||
) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive; | ||
|
||
private_decl! {} | ||
} | ||
|
||
macro_rules! return_err_if_empty { | ||
($arr:expr) => { | ||
if $arr.len() == 0 { | ||
return Err(MultiInputError::EmptyInput); | ||
} | ||
}; | ||
} | ||
macro_rules! return_err_unless_same_shape { | ||
($arr_a:expr, $arr_b:expr) => { | ||
if $arr_a.shape() != $arr_b.shape() { | ||
return Err(ShapeMismatch { | ||
first_shape: $arr_a.shape().to_vec(), | ||
second_shape: $arr_b.shape().to_vec(), | ||
} | ||
.into()); | ||
} | ||
}; | ||
} | ||
|
||
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D> | ||
where | ||
S: Data<Elem = A>, | ||
D: Dimension, | ||
{ | ||
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError> | ||
where | ||
A: PartialEq, | ||
{ | ||
return_err_if_empty!(self); | ||
return_err_unless_same_shape!(self, other); | ||
|
||
let mut count = 0; | ||
|
||
Zip::from(self).and(other).apply(|a, b| { | ||
if a == b { | ||
count += 1; | ||
} | ||
}); | ||
|
||
Ok(count) | ||
} | ||
|
||
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError> | ||
where | ||
A: PartialEq, | ||
{ | ||
self.count_eq(other).map(|n_eq| self.len() - n_eq) | ||
} | ||
|
||
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed, | ||
{ | ||
return_err_if_empty!(self); | ||
return_err_unless_same_shape!(self, other); | ||
|
||
let mut result = A::zero(); | ||
|
||
Zip::from(self).and(other).apply(|self_i, other_i| { | ||
let (a, b) = (self_i.clone(), other_i.clone()); | ||
let abs_diff = (a - b).abs(); | ||
result += abs_diff.clone() * abs_diff; | ||
LukeMathWalker marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}); | ||
|
||
Ok(result) | ||
} | ||
|
||
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive, | ||
{ | ||
let sq_l2_dist = self | ||
.sq_l2_dist(other)? | ||
.to_f64() | ||
.expect("failed cast from type A to f64"); | ||
|
||
Ok(sq_l2_dist.sqrt()) | ||
} | ||
|
||
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed, | ||
{ | ||
return_err_if_empty!(self); | ||
return_err_unless_same_shape!(self, other); | ||
|
||
let mut result = A::zero(); | ||
|
||
Zip::from(self).and(other).apply(|self_i, other_i| { | ||
let (a, b) = (self_i.clone(), other_i.clone()); | ||
result += (a - b).abs(); | ||
}); | ||
|
||
Ok(result) | ||
} | ||
|
||
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError> | ||
where | ||
A: Clone + PartialOrd + Signed, | ||
{ | ||
return_err_if_empty!(self); | ||
return_err_unless_same_shape!(self, other); | ||
|
||
let mut max = A::zero(); | ||
|
||
Zip::from(self).and(other).apply(|self_i, other_i| { | ||
let (a, b) = (self_i.clone(), other_i.clone()); | ||
let diff = (a - b).abs(); | ||
if diff > max { | ||
max = diff; | ||
} | ||
}); | ||
|
||
Ok(max) | ||
} | ||
|
||
fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive, | ||
{ | ||
let l1_dist = self | ||
.l1_dist(other)? | ||
.to_f64() | ||
.expect("failed cast from type A to f64"); | ||
let n = self.len() as f64; | ||
|
||
Ok(l1_dist / n) | ||
} | ||
|
||
fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive, | ||
{ | ||
let sq_l2_dist = self | ||
.sq_l2_dist(other)? | ||
.to_f64() | ||
.expect("failed cast from type A to f64"); | ||
let n = self.len() as f64; | ||
|
||
Ok(sq_l2_dist / n) | ||
} | ||
|
||
fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive, | ||
{ | ||
let msd = self.mean_sq_err(other)?; | ||
Ok(msd.sqrt()) | ||
} | ||
|
||
fn peak_signal_to_noise_ratio( | ||
&self, | ||
other: &ArrayBase<S, D>, | ||
maxv: A, | ||
) -> Result<f64, MultiInputError> | ||
where | ||
A: AddAssign + Clone + Signed + ToPrimitive, | ||
{ | ||
let maxv_f = maxv.to_f64().expect("failed cast from type A to f64"); | ||
let msd = self.mean_sq_err(&other)?; | ||
let psnr = 10. * f64::log10(maxv_f * maxv_f / msd); | ||
|
||
Ok(psnr) | ||
} | ||
|
||
private_impl! {} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
use ndarray_stats::errors::{MultiInputError, ShapeMismatch}; | ||
use ndarray_stats::DeviationExt; | ||
|
||
use approx::assert_abs_diff_eq; | ||
use ndarray::{array, Array1}; | ||
use num_bigint::BigInt; | ||
use num_traits::Float; | ||
|
||
use std::f64; | ||
|
||
#[test] | ||
fn test_count_eq() -> Result<(), MultiInputError> { | ||
let a = array![0., 0.]; | ||
let b = array![1., 0.]; | ||
let c = array![0., 1.]; | ||
let d = array![1., 1.]; | ||
|
||
assert_eq!(a.count_eq(&a)?, 2); | ||
assert_eq!(a.count_eq(&b)?, 1); | ||
assert_eq!(a.count_eq(&c)?, 1); | ||
assert_eq!(a.count_eq(&d)?, 0); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_count_neq() -> Result<(), MultiInputError> { | ||
let a = array![0., 0.]; | ||
let b = array![1., 0.]; | ||
let c = array![0., 1.]; | ||
let d = array![1., 1.]; | ||
|
||
assert_eq!(a.count_neq(&a)?, 0); | ||
assert_eq!(a.count_neq(&b)?, 1); | ||
assert_eq!(a.count_neq(&c)?, 1); | ||
assert_eq!(a.count_neq(&d)?, 2); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_sq_l2_dist() -> Result<(), MultiInputError> { | ||
let a = array![0., 1., 4., 2.]; | ||
let b = array![1., 1., 2., 4.]; | ||
|
||
assert_eq!(a.sq_l2_dist(&b)?, 9.); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_l2_dist() -> Result<(), MultiInputError> { | ||
let a = array![0., 1., 4., 2.]; | ||
let b = array![1., 1., 2., 4.]; | ||
|
||
assert_eq!(a.l2_dist(&b)?, 3.); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_l1_dist() -> Result<(), MultiInputError> { | ||
let a = array![0., 1., 4., 2.]; | ||
let b = array![1., 1., 2., 4.]; | ||
|
||
assert_eq!(a.l1_dist(&b)?, 5.); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_linf_dist() -> Result<(), MultiInputError> { | ||
let a = array![0., 0.]; | ||
let b = array![1., 0.]; | ||
let c = array![1., 2.]; | ||
|
||
assert_eq!(a.linf_dist(&a)?, 0.); | ||
|
||
assert_eq!(a.linf_dist(&b)?, 1.); | ||
assert_eq!(b.linf_dist(&a)?, 1.); | ||
|
||
assert_eq!(a.linf_dist(&c)?, 2.); | ||
assert_eq!(c.linf_dist(&a)?, 2.); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_mean_abs_err() -> Result<(), MultiInputError> { | ||
let a = array![1., 1.]; | ||
let b = array![3., 5.]; | ||
|
||
assert_eq!(a.mean_abs_err(&a)?, 0.); | ||
assert_eq!(a.mean_abs_err(&b)?, 3.); | ||
assert_eq!(b.mean_abs_err(&a)?, 3.); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_mean_sq_err() -> Result<(), MultiInputError> { | ||
let a = array![1., 1.]; | ||
let b = array![3., 5.]; | ||
|
||
assert_eq!(a.mean_sq_err(&a)?, 0.); | ||
assert_eq!(a.mean_sq_err(&b)?, 10.); | ||
assert_eq!(b.mean_sq_err(&a)?, 10.); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_root_mean_sq_err() -> Result<(), MultiInputError> { | ||
let a = array![1., 1.]; | ||
let b = array![3., 5.]; | ||
|
||
assert_eq!(a.root_mean_sq_err(&a)?, 0.); | ||
assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 10.0.sqrt()); | ||
assert_abs_diff_eq!(b.root_mean_sq_err(&a)?, 10.0.sqrt()); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_peak_signal_to_noise_ratio() -> Result<(), MultiInputError> { | ||
let a = array![1., 1.]; | ||
assert!(a.peak_signal_to_noise_ratio(&a, 1.)?.is_infinite()); | ||
|
||
let a = array![1., 2., 3., 4., 5., 6., 7.]; | ||
let b = array![1., 3., 3., 4., 6., 7., 8.]; | ||
let maxv = 8.; | ||
let expected = 20. * Float::log10(maxv) - 10. * Float::log10(a.mean_sq_err(&b)?); | ||
let actual = a.peak_signal_to_noise_ratio(&b, maxv)?; | ||
|
||
assert_abs_diff_eq!(actual, expected); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_deviations_with_n_by_m_ints() -> Result<(), MultiInputError> { | ||
let a = array![[0, 1], [4, 2]]; | ||
let b = array![[1, 1], [2, 4]]; | ||
|
||
assert_eq!(a.count_eq(&a)?, 4); | ||
assert_eq!(a.count_neq(&a)?, 0); | ||
|
||
assert_eq!(a.sq_l2_dist(&b)?, 9); | ||
assert_eq!(a.l2_dist(&b)?, 3.); | ||
assert_eq!(a.l1_dist(&b)?, 5); | ||
assert_eq!(a.linf_dist(&b)?, 2); | ||
|
||
assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); | ||
assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); | ||
assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); | ||
assert_abs_diff_eq!(a.peak_signal_to_noise_ratio(&b, 4)?, 8.519374645445623); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_deviations_with_empty_receiver() { | ||
let a: Array1<f64> = array![]; | ||
let b: Array1<f64> = array![1.]; | ||
|
||
assert_eq!(a.count_eq(&b), Err(MultiInputError::EmptyInput)); | ||
assert_eq!(a.count_neq(&b), Err(MultiInputError::EmptyInput)); | ||
|
||
assert_eq!(a.sq_l2_dist(&b), Err(MultiInputError::EmptyInput)); | ||
assert_eq!(a.l2_dist(&b), Err(MultiInputError::EmptyInput)); | ||
assert_eq!(a.l1_dist(&b), Err(MultiInputError::EmptyInput)); | ||
assert_eq!(a.linf_dist(&b), Err(MultiInputError::EmptyInput)); | ||
|
||
assert_eq!(a.mean_abs_err(&b), Err(MultiInputError::EmptyInput)); | ||
assert_eq!(a.mean_sq_err(&b), Err(MultiInputError::EmptyInput)); | ||
assert_eq!(a.root_mean_sq_err(&b), Err(MultiInputError::EmptyInput)); | ||
assert_eq!( | ||
a.peak_signal_to_noise_ratio(&b, 0.), | ||
Err(MultiInputError::EmptyInput) | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_deviations_do_not_panic_if_nans() -> Result<(), MultiInputError> { | ||
let a: Array1<f64> = array![1., f64::NAN, 3., f64::NAN]; | ||
let b: Array1<f64> = array![1., f64::NAN, 3., 4.]; | ||
|
||
assert_eq!(a.count_eq(&b)?, 2); | ||
assert_eq!(a.count_neq(&b)?, 2); | ||
|
||
assert!(a.sq_l2_dist(&b)?.is_nan()); | ||
assert!(a.l2_dist(&b)?.is_nan()); | ||
assert!(a.l1_dist(&b)?.is_nan()); | ||
assert_eq!(a.linf_dist(&b)?, 0.); | ||
|
||
assert!(a.mean_abs_err(&b)?.is_nan()); | ||
assert!(a.mean_sq_err(&b)?.is_nan()); | ||
assert!(a.root_mean_sq_err(&b)?.is_nan()); | ||
assert!(a.peak_signal_to_noise_ratio(&b, 0.)?.is_nan()); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_deviations_with_empty_argument() { | ||
let a: Array1<f64> = array![1.]; | ||
let b: Array1<f64> = array![]; | ||
|
||
let shape_mismatch_err = MultiInputError::ShapeMismatch(ShapeMismatch { | ||
first_shape: a.shape().to_vec(), | ||
second_shape: b.shape().to_vec(), | ||
}); | ||
let expected_err_usize = Err(shape_mismatch_err.clone()); | ||
let expected_err_f64 = Err(shape_mismatch_err); | ||
|
||
assert_eq!(a.count_eq(&b), expected_err_usize); | ||
assert_eq!(a.count_neq(&b), expected_err_usize); | ||
|
||
assert_eq!(a.sq_l2_dist(&b), expected_err_f64); | ||
assert_eq!(a.l2_dist(&b), expected_err_f64); | ||
assert_eq!(a.l1_dist(&b), expected_err_f64); | ||
assert_eq!(a.linf_dist(&b), expected_err_f64); | ||
|
||
assert_eq!(a.mean_abs_err(&b), expected_err_f64); | ||
assert_eq!(a.mean_sq_err(&b), expected_err_f64); | ||
assert_eq!(a.root_mean_sq_err(&b), expected_err_f64); | ||
assert_eq!(a.peak_signal_to_noise_ratio(&b, 0.), expected_err_f64); | ||
} | ||
|
||
#[test] | ||
fn test_deviations_with_non_copyable() -> Result<(), MultiInputError> { | ||
let a: Array1<BigInt> = array![0.into(), 1.into(), 4.into(), 2.into()]; | ||
let b: Array1<BigInt> = array![1.into(), 1.into(), 2.into(), 4.into()]; | ||
|
||
assert_eq!(a.count_eq(&a)?, 4); | ||
assert_eq!(a.count_neq(&a)?, 0); | ||
|
||
assert_eq!(a.sq_l2_dist(&b)?, 9.into()); | ||
assert_eq!(a.l2_dist(&b)?, 3.); | ||
assert_eq!(a.l1_dist(&b)?, 5.into()); | ||
assert_eq!(a.linf_dist(&b)?, 2.into()); | ||
|
||
assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); | ||
assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); | ||
assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); | ||
assert_abs_diff_eq!( | ||
a.peak_signal_to_noise_ratio(&b, 4.into())?, | ||
8.519374645445623 | ||
); | ||
|
||
Ok(()) | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.