Skip to content

Add deviation functions #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 4, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ quickcheck = { version = "0.8.1", default-features = false }
ndarray-rand = "0.9"
approx = "0.3"
quickcheck_macros = "0.8"
num-bigint = "0.2.2"

[[bench]]
name = "sort"
376 changes: 376 additions & 0 deletions src/deviation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,376 @@
use ndarray::{ArrayBase, Data, Dimension, Zip};
use num_traits::{Signed, ToPrimitive};
use std::convert::Into;
use std::ops::AddAssign;

use crate::errors::{MultiInputError, ShapeMismatch};

/// An extension trait for `ArrayBase` providing functions
/// to compute different deviation measures.
///
/// Every method compares `self` with a second array `other` that has the
/// same storage type `S` and dimension type `D`, and returns a
/// `MultiInputError` instead of a value when the inputs cannot be compared.
pub trait DeviationExt<A, S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Counts the number of indices at which the elements of the arrays `self`
    /// and `other` are equal.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
    where
        A: PartialEq;

    /// Counts the number of indices at which the elements of the arrays `self`
    /// and `other` are not equal.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
    where
        A: PartialEq;

    /// Computes the [squared L2 distance] between `self` and `other`.
    ///
    /// ```text
    ///  n
    ///  ∑  |aᵢ - bᵢ|²
    /// i=1
    /// ```
    ///
    /// where `self` is `a` and `other` is `b`.
    ///
    /// The result has the element type `A`; no conversion to floating point
    /// takes place.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    ///
    /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
    fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
    where
        A: AddAssign + Clone + Signed;

    /// Computes the [L2 distance] between `self` and `other`.
    ///
    /// ```text
    ///      n
    /// √ (  ∑  |aᵢ - bᵢ|² )
    ///     i=1
    /// ```
    ///
    /// where `self` is `a` and `other` is `b`.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    ///
    /// **Panics** if the type cast from `A` to `f64` fails.
    ///
    /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
    fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
    where
        A: AddAssign + Clone + Signed + ToPrimitive;

    /// Computes the [L1 distance] between `self` and `other`.
    ///
    /// ```text
    ///  n
    ///  ∑  |aᵢ - bᵢ|
    /// i=1
    /// ```
    ///
    /// where `self` is `a` and `other` is `b`.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    ///
    /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
    fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
    where
        A: AddAssign + Clone + Signed;

    /// Computes the [L∞ distance] between `self` and `other`.
    ///
    /// ```text
    /// max(|aᵢ - bᵢ|)
    ///  ᵢ
    /// ```
    ///
    /// where `self` is `a` and `other` is `b`.
    ///
    /// Note: in the implementation for `ArrayBase`, the running maximum
    /// starts at zero and is only replaced when `diff > max` holds, so
    /// differences that compare as unordered (e.g. a floating-point NaN)
    /// are skipped rather than propagated.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    ///
    /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
    fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
    where
        A: Clone + PartialOrd + Signed;

    /// Computes the [mean absolute error] between `self` and `other`.
    ///
    /// ```text
    ///        n
    /// 1/n *  ∑  |aᵢ - bᵢ|
    ///       i=1
    /// ```
    ///
    /// where `self` is `a` and `other` is `b`.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    ///
    /// **Panics** if the type cast from `A` to `f64` fails.
    ///
    /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
    fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
    where
        A: AddAssign + Clone + Signed + ToPrimitive;

    /// Computes the [mean squared error] between `self` and `other`.
    ///
    /// ```text
    ///        n
    /// 1/n *  ∑  |aᵢ - bᵢ|²
    ///       i=1
    /// ```
    ///
    /// where `self` is `a` and `other` is `b`.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    ///
    /// **Panics** if the type cast from `A` to `f64` fails.
    ///
    /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
    fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
    where
        A: AddAssign + Clone + Signed + ToPrimitive;

    /// Computes the unnormalized [root-mean-square error] between `self` and `other`.
    ///
    /// ```text
    /// √ mse(a, b)
    /// ```
    ///
    /// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    ///
    /// **Panics** if the type cast from `A` to `f64` fails.
    ///
    /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
    fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
    where
        A: AddAssign + Clone + Signed + ToPrimitive;

    /// Computes the [peak signal-to-noise ratio] between `self` and `other`.
    ///
    /// ```text
    /// 10 * log10(maxv^2 / mse(a, b))
    /// ```
    ///
    /// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error
    /// and `maxv` is the maximum possible value either array can take.
    ///
    /// `maxv` is supplied by the caller and is used as given. The ratio is
    /// evaluated in `f64`, so a zero mean squared error with a nonzero
    /// `maxv` yields an infinite result rather than an error.
    ///
    /// The following **errors** may be returned:
    ///
    /// * `MultiInputError::EmptyInput` if `self` is empty
    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
    ///
    /// **Panics** if the type cast from `A` to `f64` fails.
    ///
    /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    fn peak_signal_to_noise_ratio(
        &self,
        other: &ArrayBase<S, D>,
        maxv: A,
    ) -> Result<f64, MultiInputError>
    where
        A: AddAssign + Clone + Signed + ToPrimitive;

    private_decl! {}
}

// Returns early from the enclosing function with
// `Err(MultiInputError::EmptyInput)` when `$arr` holds no elements.
//
// `$arr` may be any expression whose value offers an `is_empty()` method;
// the enclosing function must return a `Result` whose error type is
// `MultiInputError`.
macro_rules! return_err_if_empty {
    ($arr:expr) => {
        if $arr.is_empty() {
            return Err(MultiInputError::EmptyInput);
        }
    };
}
// Returns early from the enclosing function with a `ShapeMismatch` error
// (converted through `Into`) when the two arguments report different shapes.
//
// Both arguments may be any expressions whose values offer a `shape()`
// method returning a slice of `usize`. Each argument expression is evaluated
// exactly once: the values are bound to locals first, so an argument with
// side effects is not re-run when the error is built.
macro_rules! return_err_unless_same_shape {
    ($arr_a:expr, $arr_b:expr) => {{
        let (first, second) = (&$arr_a, &$arr_b);
        if first.shape() != second.shape() {
            return Err(ShapeMismatch {
                first_shape: first.shape().to_vec(),
                second_shape: second.shape().to_vec(),
            }
            .into());
        }
    }};
}

impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
where
S: Data<Elem = A>,
D: Dimension,
{
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
where
A: PartialEq,
{
return_err_if_empty!(self);
return_err_unless_same_shape!(self, other);

let mut count = 0;

Zip::from(self).and(other).apply(|a, b| {
if a == b {
count += 1;
}
});

Ok(count)
}

fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
where
A: PartialEq,
{
self.count_eq(other).map(|n_eq| self.len() - n_eq)
}

fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
where
A: AddAssign + Clone + Signed,
{
return_err_if_empty!(self);
return_err_unless_same_shape!(self, other);

let mut result = A::zero();

Zip::from(self).and(other).apply(|self_i, other_i| {
let (a, b) = (self_i.clone(), other_i.clone());
let abs_diff = (a - b).abs();
result += abs_diff.clone() * abs_diff;
});

Ok(result)
}

fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
where
A: AddAssign + Clone + Signed + ToPrimitive,
{
let sq_l2_dist = self
.sq_l2_dist(other)?
.to_f64()
.expect("failed cast from type A to f64");

Ok(sq_l2_dist.sqrt())
}

fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
where
A: AddAssign + Clone + Signed,
{
return_err_if_empty!(self);
return_err_unless_same_shape!(self, other);

let mut result = A::zero();

Zip::from(self).and(other).apply(|self_i, other_i| {
let (a, b) = (self_i.clone(), other_i.clone());
result += (a - b).abs();
});

Ok(result)
}

fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
where
A: Clone + PartialOrd + Signed,
{
return_err_if_empty!(self);
return_err_unless_same_shape!(self, other);

let mut max = A::zero();

Zip::from(self).and(other).apply(|self_i, other_i| {
let (a, b) = (self_i.clone(), other_i.clone());
let diff = (a - b).abs();
if diff > max {
max = diff;
}
});

Ok(max)
}

fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
where
A: AddAssign + Clone + Signed + ToPrimitive,
{
let l1_dist = self
.l1_dist(other)?
.to_f64()
.expect("failed cast from type A to f64");
let n = self.len() as f64;

Ok(l1_dist / n)
}

fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
where
A: AddAssign + Clone + Signed + ToPrimitive,
{
let sq_l2_dist = self
.sq_l2_dist(other)?
.to_f64()
.expect("failed cast from type A to f64");
let n = self.len() as f64;

Ok(sq_l2_dist / n)
}

fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
where
A: AddAssign + Clone + Signed + ToPrimitive,
{
let msd = self.mean_sq_err(other)?;
Ok(msd.sqrt())
}

fn peak_signal_to_noise_ratio(
&self,
other: &ArrayBase<S, D>,
maxv: A,
) -> Result<f64, MultiInputError>
where
A: AddAssign + Clone + Signed + ToPrimitive,
{
let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
let msd = self.mean_sq_err(&other)?;
let psnr = 10. * f64::log10(maxv_f * maxv_f / msd);

Ok(psnr)
}

private_impl! {}
}
4 changes: 2 additions & 2 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ impl From<EmptyInput> for MinMaxError {
/// An error used by methods and functions that take two arrays as argument and
/// expect them to have exactly the same shape
/// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `false`).
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct ShapeMismatch {
pub first_shape: Vec<usize>,
pub second_shape: Vec<usize>,
@@ -65,7 +65,7 @@ impl fmt::Display for ShapeMismatch {
impl Error for ShapeMismatch {}

/// An error for methods that take multiple non-empty array inputs.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub enum MultiInputError {
/// One or more of the arrays were empty.
EmptyInput,
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
//! - [partitioning];
//! - [correlation analysis] (covariance, pearson correlation);
//! - [measures from information theory] (entropy, KL divergence, etc.);
//! - [measures of deviation] (count equal, L1, L2 distances, mean squared error, etc.);
//! - [histogram computation].
//!
//! Please feel free to contribute new functionality! A roadmap can be found [here].
@@ -21,13 +22,15 @@
//! [partitioning]: trait.Sort1dExt.html
//! [summary statistics]: trait.SummaryStatisticsExt.html
//! [correlation analysis]: trait.CorrelationExt.html
//! [measures of deviation]: trait.DeviationExt.html
//! [measures from information theory]: trait.EntropyExt.html
//! [histogram computation]: histogram/index.html
//! [here]: https://github.com/rust-ndarray/ndarray-stats/issues/1
//! [`NumPy`]: https://docs.scipy.org/doc/numpy-1.14.1/reference/routines.statistics.html
//! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/
pub use crate::correlation::CorrelationExt;
pub use crate::deviation::DeviationExt;
pub use crate::entropy::EntropyExt;
pub use crate::histogram::HistogramExt;
pub use crate::maybe_nan::{MaybeNan, MaybeNanExt};
@@ -69,6 +72,7 @@ mod private {
}

mod correlation;
mod deviation;
mod entropy;
pub mod errors;
pub mod histogram;
252 changes: 252 additions & 0 deletions tests/deviation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
use ndarray_stats::errors::{MultiInputError, ShapeMismatch};
use ndarray_stats::DeviationExt;

use approx::assert_abs_diff_eq;
use ndarray::{array, Array1};
use num_bigint::BigInt;
use num_traits::Float;

use std::f64;

/// `count_eq` reports how many positions hold equal elements.
#[test]
fn test_count_eq() -> Result<(), MultiInputError> {
    let base = array![0., 0.];
    let first_differs = array![1., 0.];
    let second_differs = array![0., 1.];
    let both_differ = array![1., 1.];

    // Each entry pairs an array with the number of positions it shares with `base`.
    let expectations = [
        (&base, 2),
        (&first_differs, 1),
        (&second_differs, 1),
        (&both_differ, 0),
    ];
    for &(other, n_equal) in expectations.iter() {
        assert_eq!(base.count_eq(other)?, n_equal);
    }

    Ok(())
}

/// `count_neq` reports how many positions hold differing elements.
#[test]
fn test_count_neq() -> Result<(), MultiInputError> {
    let base = array![0., 0.];
    let first_differs = array![1., 0.];
    let second_differs = array![0., 1.];
    let both_differ = array![1., 1.];

    // Each entry pairs an array with the number of positions where it departs from `base`.
    let expectations = [
        (&base, 0),
        (&first_differs, 1),
        (&second_differs, 1),
        (&both_differ, 2),
    ];
    for &(other, n_different) in expectations.iter() {
        assert_eq!(base.count_neq(other)?, n_different);
    }

    Ok(())
}

/// Squared L2 distance of two small float vectors.
#[test]
fn test_sq_l2_dist() -> Result<(), MultiInputError> {
    let lhs = array![0., 1., 4., 2.];
    let rhs = array![1., 1., 2., 4.];

    // (0-1)² + (1-1)² + (4-2)² + (2-4)² = 1 + 0 + 4 + 4
    let squared_distance = lhs.sq_l2_dist(&rhs)?;
    assert_eq!(squared_distance, 9.);

    Ok(())
}

/// L2 distance is the square root of the squared L2 distance.
#[test]
fn test_l2_dist() -> Result<(), MultiInputError> {
    let lhs = array![0., 1., 4., 2.];
    let rhs = array![1., 1., 2., 4.];

    // The squared distance of these vectors is 9, whose root is exactly 3.
    let distance = lhs.l2_dist(&rhs)?;
    assert_eq!(distance, 3.);

    Ok(())
}

/// L1 distance sums the absolute element-wise differences.
#[test]
fn test_l1_dist() -> Result<(), MultiInputError> {
    let lhs = array![0., 1., 4., 2.];
    let rhs = array![1., 1., 2., 4.];

    // |0-1| + |1-1| + |4-2| + |2-4| = 1 + 0 + 2 + 2
    let distance = lhs.l1_dist(&rhs)?;
    assert_eq!(distance, 5.);

    Ok(())
}

/// L∞ distance picks the largest absolute difference and is symmetric.
#[test]
fn test_linf_dist() -> Result<(), MultiInputError> {
    let origin = array![0., 0.];
    let unit = array![1., 0.];
    let far = array![1., 2.];

    assert_eq!(origin.linf_dist(&origin)?, 0.);

    // Check every pair in both argument orders.
    let pairs = [(&origin, &unit, 1.), (&origin, &far, 2.)];
    for &(x, y, expected) in pairs.iter() {
        assert_eq!(x.linf_dist(y)?, expected);
        assert_eq!(y.linf_dist(x)?, expected);
    }

    Ok(())
}

/// Mean absolute error is zero against itself and symmetric otherwise.
#[test]
fn test_mean_abs_err() -> Result<(), MultiInputError> {
    let ones = array![1., 1.];
    let shifted = array![3., 5.];

    let against_self = ones.mean_abs_err(&ones)?;
    assert_eq!(against_self, 0.);

    // (|1-3| + |1-5|) / 2 = 3 in either direction.
    let forward = ones.mean_abs_err(&shifted)?;
    assert_eq!(forward, 3.);
    let backward = shifted.mean_abs_err(&ones)?;
    assert_eq!(backward, 3.);

    Ok(())
}

/// Mean squared error is zero against itself and symmetric otherwise.
#[test]
fn test_mean_sq_err() -> Result<(), MultiInputError> {
    let ones = array![1., 1.];
    let shifted = array![3., 5.];

    let against_self = ones.mean_sq_err(&ones)?;
    assert_eq!(against_self, 0.);

    // ((1-3)² + (1-5)²) / 2 = 10 in either direction.
    let forward = ones.mean_sq_err(&shifted)?;
    assert_eq!(forward, 10.);
    let backward = shifted.mean_sq_err(&ones)?;
    assert_eq!(backward, 10.);

    Ok(())
}

/// Root-mean-square error is the square root of the mean squared error.
#[test]
fn test_root_mean_sq_err() -> Result<(), MultiInputError> {
    let ones = array![1., 1.];
    let shifted = array![3., 5.];

    let against_self = ones.root_mean_sq_err(&ones)?;
    assert_eq!(against_self, 0.);

    // The mean squared error of this pair is 10 in either direction.
    let forward = ones.root_mean_sq_err(&shifted)?;
    assert_abs_diff_eq!(forward, 10.0.sqrt());
    let backward = shifted.root_mean_sq_err(&ones)?;
    assert_abs_diff_eq!(backward, 10.0.sqrt());

    Ok(())
}

/// PSNR is infinite for identical inputs and otherwise matches the
/// `20 * log10(maxv) - 10 * log10(mse)` form of the definition.
#[test]
fn test_peak_signal_to_noise_ratio() -> Result<(), MultiInputError> {
    let constant = array![1., 1.];
    let psnr_of_identical = constant.peak_signal_to_noise_ratio(&constant, 1.)?;
    assert!(psnr_of_identical.is_infinite());

    let reference = array![1., 2., 3., 4., 5., 6., 7.];
    let distorted = array![1., 3., 3., 4., 6., 7., 8.];
    let peak = 8.;

    let expected =
        20. * Float::log10(peak) - 10. * Float::log10(reference.mean_sq_err(&distorted)?);
    let actual = reference.peak_signal_to_noise_ratio(&distorted, peak)?;
    assert_abs_diff_eq!(actual, expected);

    Ok(())
}

/// All measures also work on two-dimensional integer arrays.
#[test]
fn test_deviations_with_n_by_m_ints() -> Result<(), MultiInputError> {
    let lhs = array![[0, 1], [4, 2]];
    let rhs = array![[1, 1], [2, 4]];

    // Counting measures, compared against the array itself.
    assert_eq!(lhs.count_eq(&lhs)?, 4);
    assert_eq!(lhs.count_neq(&lhs)?, 0);

    // Distances: the integer-valued ones stay integers, `l2_dist` is an `f64`.
    assert_eq!(lhs.sq_l2_dist(&rhs)?, 9);
    assert_eq!(lhs.l2_dist(&rhs)?, 3.);
    assert_eq!(lhs.l1_dist(&rhs)?, 5);
    assert_eq!(lhs.linf_dist(&rhs)?, 2);

    // Averaged error measures are always `f64`.
    assert_abs_diff_eq!(lhs.mean_abs_err(&rhs)?, 1.25);
    assert_abs_diff_eq!(lhs.mean_sq_err(&rhs)?, 2.25);
    assert_abs_diff_eq!(lhs.root_mean_sq_err(&rhs)?, 1.5);
    assert_abs_diff_eq!(
        lhs.peak_signal_to_noise_ratio(&rhs, 4)?,
        8.519374645445623
    );

    Ok(())
}

/// Every measure reports `EmptyInput` when the receiver has no elements,
/// even though the argument's shape differs as well.
#[test]
fn test_deviations_with_empty_receiver() {
    let empty: Array1<f64> = array![];
    let single: Array1<f64> = array![1.];

    let expected = MultiInputError::EmptyInput;

    assert_eq!(empty.count_eq(&single).unwrap_err(), expected);
    assert_eq!(empty.count_neq(&single).unwrap_err(), expected);

    assert_eq!(empty.sq_l2_dist(&single).unwrap_err(), expected);
    assert_eq!(empty.l2_dist(&single).unwrap_err(), expected);
    assert_eq!(empty.l1_dist(&single).unwrap_err(), expected);
    assert_eq!(empty.linf_dist(&single).unwrap_err(), expected);

    assert_eq!(empty.mean_abs_err(&single).unwrap_err(), expected);
    assert_eq!(empty.mean_sq_err(&single).unwrap_err(), expected);
    assert_eq!(empty.root_mean_sq_err(&single).unwrap_err(), expected);
    assert_eq!(
        empty.peak_signal_to_noise_ratio(&single, 0.).unwrap_err(),
        expected
    );
}

/// NaN elements never cause a panic: counts treat NaN as unequal, sums
/// become NaN, and the L∞ distance skips NaN differences.
#[test]
fn test_deviations_do_not_panic_if_nans() -> Result<(), MultiInputError> {
    let with_two_nans: Array1<f64> = array![1., f64::NAN, 3., f64::NAN];
    let with_one_nan: Array1<f64> = array![1., f64::NAN, 3., 4.];

    assert_eq!(with_two_nans.count_eq(&with_one_nan)?, 2);
    assert_eq!(with_two_nans.count_neq(&with_one_nan)?, 2);

    let sq_l2 = with_two_nans.sq_l2_dist(&with_one_nan)?;
    assert!(sq_l2.is_nan());
    let l2 = with_two_nans.l2_dist(&with_one_nan)?;
    assert!(l2.is_nan());
    let l1 = with_two_nans.l1_dist(&with_one_nan)?;
    assert!(l1.is_nan());
    let linf = with_two_nans.linf_dist(&with_one_nan)?;
    assert_eq!(linf, 0.);

    let mae = with_two_nans.mean_abs_err(&with_one_nan)?;
    assert!(mae.is_nan());
    let mse = with_two_nans.mean_sq_err(&with_one_nan)?;
    assert!(mse.is_nan());
    let rmse = with_two_nans.root_mean_sq_err(&with_one_nan)?;
    assert!(rmse.is_nan());
    let psnr = with_two_nans.peak_signal_to_noise_ratio(&with_one_nan, 0.)?;
    assert!(psnr.is_nan());

    Ok(())
}

/// A non-empty receiver compared with an empty argument is a shape
/// mismatch, and the error carries both shapes.
#[test]
fn test_deviations_with_empty_argument() {
    let single: Array1<f64> = array![1.];
    let empty: Array1<f64> = array![];

    let expected = MultiInputError::ShapeMismatch(ShapeMismatch {
        first_shape: single.shape().to_vec(),
        second_shape: empty.shape().to_vec(),
    });

    assert_eq!(single.count_eq(&empty).unwrap_err(), expected);
    assert_eq!(single.count_neq(&empty).unwrap_err(), expected);

    assert_eq!(single.sq_l2_dist(&empty).unwrap_err(), expected);
    assert_eq!(single.l2_dist(&empty).unwrap_err(), expected);
    assert_eq!(single.l1_dist(&empty).unwrap_err(), expected);
    assert_eq!(single.linf_dist(&empty).unwrap_err(), expected);

    assert_eq!(single.mean_abs_err(&empty).unwrap_err(), expected);
    assert_eq!(single.mean_sq_err(&empty).unwrap_err(), expected);
    assert_eq!(single.root_mean_sq_err(&empty).unwrap_err(), expected);
    assert_eq!(
        single.peak_signal_to_noise_ratio(&empty, 0.).unwrap_err(),
        expected
    );
}

/// The measures only require `Clone`, so they work for element types that
/// are not `Copy`, such as `BigInt`.
#[test]
fn test_deviations_with_non_copyable() -> Result<(), MultiInputError> {
    let lhs: Array1<BigInt> = array![
        BigInt::from(0),
        BigInt::from(1),
        BigInt::from(4),
        BigInt::from(2)
    ];
    let rhs: Array1<BigInt> = array![
        BigInt::from(1),
        BigInt::from(1),
        BigInt::from(2),
        BigInt::from(4)
    ];

    assert_eq!(lhs.count_eq(&lhs)?, 4);
    assert_eq!(lhs.count_neq(&lhs)?, 0);

    assert_eq!(lhs.sq_l2_dist(&rhs)?, BigInt::from(9));
    assert_eq!(lhs.l2_dist(&rhs)?, 3.);
    assert_eq!(lhs.l1_dist(&rhs)?, BigInt::from(5));
    assert_eq!(lhs.linf_dist(&rhs)?, BigInt::from(2));

    assert_abs_diff_eq!(lhs.mean_abs_err(&rhs)?, 1.25);
    assert_abs_diff_eq!(lhs.mean_sq_err(&rhs)?, 2.25);
    assert_abs_diff_eq!(lhs.root_mean_sq_err(&rhs)?, 1.5);
    assert_abs_diff_eq!(
        lhs.peak_signal_to_noise_ratio(&rhs, BigInt::from(4))?,
        8.519374645445623
    );

    Ok(())
}