Skip to content

Histogram error handling #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
bfb0db7
Use Option as return type where things might fail
LukeMathWalker Jan 29, 2019
4057a72
Test suite aligned with docs
LukeMathWalker Jan 29, 2019
fa1eb34
Equispaced does not panic anymore
Jan 30, 2019
669a33f
Fixed some tests
Jan 30, 2019
0e5eb6b
Fixed FD tests
Jan 30, 2019
04cf0a7
Fixed wrong condition in IF
Jan 30, 2019
4f429bc
Fixed wrong test
Jan 30, 2019
4e74c48
Added new test for EquiSpaced and fixed old one
Jan 30, 2019
56b7e45
Fixed doc tests
Jan 30, 2019
12906fd
Fix docs.
LukeMathWalker Feb 2, 2019
9d1862f
Fix docs.
LukeMathWalker Feb 2, 2019
64789d6
Fix docs.
LukeMathWalker Feb 2, 2019
fe150d1
Fmt
LukeMathWalker Mar 26, 2019
facd4c4
Merge master
LukeMathWalker Mar 26, 2019
b28c35a
Create StrategyError
LukeMathWalker Mar 26, 2019
c06f382
Fmt
LukeMathWalker Mar 26, 2019
4a24f5a
Return Result. Fix Equispaced, Sqrt and Rice
LukeMathWalker Mar 26, 2019
f708a17
Fix Rice
LukeMathWalker Mar 26, 2019
58788db
Fixed Sturges
LukeMathWalker Mar 26, 2019
3014f77
Fix strategies
LukeMathWalker Mar 26, 2019
17e5efc
Fix match
LukeMathWalker Mar 26, 2019
63abed5
Tests compile
LukeMathWalker Mar 26, 2019
4a4b489
Fix assertion
LukeMathWalker Mar 26, 2019
f692887
Fmt
LukeMathWalker Mar 26, 2019
a8ad4b1
Add more error types
jturner314 Mar 31, 2019
29f56f3
Rename StrategyError to BinsBuildError
jturner314 Apr 1, 2019
bca2dc9
Make GridBuilder::from_array return Result
jturner314 Apr 1, 2019
f41b317
Make BinsBuildError enum non-exhaustive
jturner314 Apr 1, 2019
308e0e7
Merge pull request #4 from jturner314/histogram-error-handling
LukeMathWalker Apr 1, 2019
c280c6b
Use lazy OR operator.
LukeMathWalker Apr 1, 2019
6481509
Merge remote-tracking branch 'origin/histogram-error-handling' into h…
LukeMathWalker Apr 1, 2019
701842d
Use lazy OR operator.
LukeMathWalker Apr 1, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 48 additions & 49 deletions src/entropy.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Information theory (e.g. entropy, KL divergence, etc.).
use crate::errors::ShapeMismatch;
use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use num_traits::Float;

@@ -19,7 +19,7 @@ where
/// i=1
/// ```
///
/// If the array is empty, `None` is returned.
/// If the array is empty, `Err(EmptyInput)` is returned.
///
/// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
///
@@ -38,7 +38,7 @@ where
///
/// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
fn entropy(&self) -> Option<A>
fn entropy(&self) -> Result<A, EmptyInput>
where
A: Float;

@@ -53,8 +53,9 @@ where
/// i=1
/// ```
///
/// If the arrays are empty, Ok(`None`) is returned.
/// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
/// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
/// If the array shapes are not identical,
/// `Err(MultiInputError::ShapeMismatch)` is returned.
///
/// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
/// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
@@ -73,7 +74,7 @@ where
///
/// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float;
@@ -89,8 +90,9 @@ where
/// i=1
/// ```
///
/// If the arrays are empty, Ok(`None`) is returned.
/// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
/// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
/// If the array shapes are not identical,
/// `Err(MultiInputError::ShapeMismatch)` is returned.
///
/// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
/// is a panic cause for `A`.
@@ -114,7 +116,7 @@ where
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
/// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
/// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float;
@@ -125,14 +127,14 @@ where
S: Data<Elem = A>,
D: Dimension,
{
fn entropy(&self) -> Option<A>
fn entropy(&self) -> Result<A, EmptyInput>
where
A: Float,
{
if self.len() == 0 {
None
Err(EmptyInput)
} else {
let entropy = self
let entropy = -self
.mapv(|x| {
if x == A::zero() {
A::zero()
@@ -141,23 +143,24 @@ where
}
})
.sum();
Some(-entropy)
Ok(entropy)
}
}

fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
A: Float,
S2: Data<Elem = A>,
{
if self.len() == 0 {
return Ok(None);
return Err(MultiInputError::EmptyInput);
}
if self.shape() != q.shape() {
return Err(ShapeMismatch {
first_shape: self.shape().to_vec(),
second_shape: q.shape().to_vec(),
});
}
.into());
}

let mut temp = Array::zeros(self.raw_dim());
@@ -174,22 +177,23 @@ where
}
});
let kl_divergence = -temp.sum();
Ok(Some(kl_divergence))
Ok(kl_divergence)
}

fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
where
S2: Data<Elem = A>,
A: Float,
{
if self.len() == 0 {
return Ok(None);
return Err(MultiInputError::EmptyInput);
}
if self.shape() != q.shape() {
return Err(ShapeMismatch {
first_shape: self.shape().to_vec(),
second_shape: q.shape().to_vec(),
});
}
.into());
}

let mut temp = Array::zeros(self.raw_dim());
@@ -206,15 +210,15 @@ where
}
});
let cross_entropy = -temp.sum();
Ok(Some(cross_entropy))
Ok(cross_entropy)
}
}

#[cfg(test)]
mod tests {
use super::EntropyExt;
use approx::assert_abs_diff_eq;
use errors::ShapeMismatch;
use errors::{EmptyInput, MultiInputError};
use ndarray::{array, Array1};
use noisy_float::types::n64;
use std::f64;
@@ -228,7 +232,7 @@ mod tests {
#[test]
fn test_entropy_with_empty_array_of_floats() {
let a: Array1<f64> = array![];
assert!(a.entropy().is_none());
assert_eq!(a.entropy(), Err(EmptyInput));
}

#[test]
@@ -251,13 +255,13 @@ mod tests {
}

#[test]
fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
let a = array![f64::NAN, 1.];
let b = array![2., 1.];
assert!(a.cross_entropy(&b)?.unwrap().is_nan());
assert!(b.cross_entropy(&a)?.unwrap().is_nan());
assert!(a.kl_divergence(&b)?.unwrap().is_nan());
assert!(b.kl_divergence(&a)?.unwrap().is_nan());
assert!(a.cross_entropy(&b)?.is_nan());
assert!(b.cross_entropy(&a)?.is_nan());
assert!(a.kl_divergence(&b)?.is_nan());
assert!(b.kl_divergence(&a)?.is_nan());
Ok(())
}

@@ -284,20 +288,19 @@ mod tests {
}

#[test]
fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
let p: Array1<f64> = array![];
let q: Array1<f64> = array![];
assert!(p.cross_entropy(&q)?.is_none());
assert!(p.kl_divergence(&q)?.is_none());
Ok(())
assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
}

#[test]
fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
let p = array![1.];
let q = array![-1.];
let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
let cross_entropy: f64 = p.cross_entropy(&q)?;
let kl_divergence: f64 = p.kl_divergence(&q)?;
assert!(cross_entropy.is_nan());
assert!(kl_divergence.is_nan());
Ok(())
@@ -320,26 +323,26 @@ mod tests {
}

#[test]
fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
let p = array![0., 0.];
let q = array![0., 0.5];
assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
assert_eq!(p.cross_entropy(&q)?, 0.);
assert_eq!(p.kl_divergence(&q)?, 0.);
Ok(())
}

#[test]
fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
) -> Result<(), ShapeMismatch> {
) -> Result<(), MultiInputError> {
let p = array![0.5, 0.5];
let mut q = array![0.5, 0.];
assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
Ok(())
}

#[test]
fn test_cross_entropy() -> Result<(), ShapeMismatch> {
fn test_cross_entropy() -> Result<(), MultiInputError> {
// Arrays of probability values - normalized and positive.
let p: Array1<f64> = array![
0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
@@ -356,16 +359,12 @@ mod tests {
// Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
let expected_cross_entropy = 3.385347705020779;

assert_abs_diff_eq!(
p.cross_entropy(&q)?.unwrap(),
expected_cross_entropy,
epsilon = 1e-6
);
assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
Ok(())
}

#[test]
fn test_kl() -> Result<(), ShapeMismatch> {
fn test_kl() -> Result<(), MultiInputError> {
// Arrays of probability values - normalized and positive.
let p: Array1<f64> = array![
0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
@@ -390,7 +389,7 @@ mod tests {
// Computed using scipy.stats.entropy(p, q)
let expected_kl = 0.3555862567800096;

assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
Ok(())
}
}
92 changes: 91 additions & 1 deletion src/errors.rs
Original file line number Diff line number Diff line change
@@ -2,10 +2,50 @@
use std::error::Error;
use std::fmt;

#[derive(Debug)]
/// An error that indicates that the input array was empty.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmptyInput;

impl fmt::Display for EmptyInput {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Empty input.")
}
}

impl Error for EmptyInput {}

/// An error computing a minimum/maximum value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MinMaxError {
/// The input was empty.
EmptyInput,
/// The ordering between a tested pair of values was undefined.
UndefinedOrder,
}

impl fmt::Display for MinMaxError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
MinMaxError::EmptyInput => write!(f, "Empty input."),
MinMaxError::UndefinedOrder => {
write!(f, "Undefined ordering between a tested pair of values.")
}
}
}
}

impl Error for MinMaxError {}

impl From<EmptyInput> for MinMaxError {
fn from(_: EmptyInput) -> MinMaxError {
MinMaxError::EmptyInput
}
}

/// An error used by methods and functions that take two arrays as argument and
/// expect them to have exactly the same shape
/// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `False`).
#[derive(Clone, Debug)]
pub struct ShapeMismatch {
pub first_shape: Vec<usize>,
pub second_shape: Vec<usize>,
@@ -22,3 +62,53 @@ impl fmt::Display for ShapeMismatch {
}

impl Error for ShapeMismatch {}

/// An error for methods that take multiple non-empty array inputs.
#[derive(Clone, Debug)]
pub enum MultiInputError {
/// One or more of the arrays were empty.
EmptyInput,
/// The arrays did not have the same shape.
ShapeMismatch(ShapeMismatch),
}

impl MultiInputError {
/// Returns whether `self` is the `EmptyInput` variant.
pub fn is_empty_input(&self) -> bool {
match self {
MultiInputError::EmptyInput => true,
_ => false,
}
}

/// Returns whether `self` is the `ShapeMismatch` variant.
pub fn is_shape_mismatch(&self) -> bool {
match self {
MultiInputError::ShapeMismatch(_) => true,
_ => false,
}
}
}

impl fmt::Display for MultiInputError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
MultiInputError::EmptyInput => write!(f, "Empty input."),
MultiInputError::ShapeMismatch(e) => write!(f, "Shape mismatch: {}", e),
}
}
}

impl Error for MultiInputError {}

impl From<EmptyInput> for MultiInputError {
fn from(_: EmptyInput) -> Self {
MultiInputError::EmptyInput
}
}

impl From<ShapeMismatch> for MultiInputError {
fn from(err: ShapeMismatch) -> Self {
MultiInputError::ShapeMismatch(err)
}
}
58 changes: 55 additions & 3 deletions src/histogram/errors.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::errors::{EmptyInput, MinMaxError};
use std::error;
use std::fmt;

@@ -15,9 +16,60 @@ impl error::Error for BinNotFound {
fn description(&self) -> &str {
"No bin has been found."
}
}

/// Error computing the set of histogram bins.
#[derive(Debug, Clone)]
pub enum BinsBuildError {
/// The input array was empty.
EmptyInput,
/// The strategy for computing appropriate bins failed.
Strategy,
#[doc(hidden)]
__NonExhaustive,
}

impl BinsBuildError {
/// Returns whether `self` is the `EmptyInput` variant.
pub fn is_empty_input(&self) -> bool {
match self {
BinsBuildError::EmptyInput => true,
_ => false,
}
}

/// Returns whether `self` is the `Strategy` variant.
pub fn is_strategy(&self) -> bool {
match self {
BinsBuildError::Strategy => true,
_ => false,
}
}
}

impl fmt::Display for BinsBuildError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "The strategy failed to determine a non-zero bin width.")
}
}

impl error::Error for BinsBuildError {
fn description(&self) -> &str {
"The strategy failed to determine a non-zero bin width."
}
}

impl From<EmptyInput> for BinsBuildError {
fn from(_: EmptyInput) -> Self {
BinsBuildError::EmptyInput
}
}

fn cause(&self) -> Option<&error::Error> {
// Generic error, underlying cause isn't tracked.
None
impl From<MinMaxError> for BinsBuildError {
fn from(err: MinMaxError) -> BinsBuildError {
match err {
MinMaxError::EmptyInput => BinsBuildError::EmptyInput,
MinMaxError::UndefinedOrder => BinsBuildError::Strategy,
}
}
}
12 changes: 8 additions & 4 deletions src/histogram/grid.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::bins::Bins;
use super::errors::BinsBuildError;
use super::strategies::BinsBuildingStrategy;
use itertools::izip;
use ndarray::{ArrayBase, Axis, Data, Ix1, Ix2};
@@ -54,7 +55,7 @@ use std::ops::Range;
///
/// // The optimal grid layout is inferred from the data,
/// // specifying a strategy (Auto in this case)
/// let grid = GridBuilder::<Auto<usize>>::from_array(&observations).build();
/// let grid = GridBuilder::<Auto<usize>>::from_array(&observations).unwrap().build();
/// let expected_grid = Grid::from(vec![Bins::new(Edges::from(vec![1, 20, 39, 58, 77, 96, 115]))]);
/// assert_eq!(grid, expected_grid);
///
@@ -169,17 +170,20 @@ where
/// it returns a `GridBuilder` instance that has learned the required parameter
/// to build a [`Grid`] according to the specified [`strategy`].
///
/// It returns `Err` if it is not possible to build a [`Grid`] given
/// the observed data according to the chosen [`strategy`].
///
/// [`Grid`]: struct.Grid.html
/// [`strategy`]: strategies/index.html
pub fn from_array<S>(array: &ArrayBase<S, Ix2>) -> Self
pub fn from_array<S>(array: &ArrayBase<S, Ix2>) -> Result<Self, BinsBuildError>
where
S: Data<Elem = A>,
{
let bin_builders = array
.axis_iter(Axis(1))
.map(|data| B::from_array(&data))
.collect();
Self { bin_builders }
.collect::<Result<Vec<B>, BinsBuildError>>()?;
Ok(Self { bin_builders })
}

/// Returns a [`Grid`] instance, built accordingly to the specified [`strategy`]
2 changes: 1 addition & 1 deletion src/histogram/histograms.rs
Original file line number Diff line number Diff line change
@@ -123,7 +123,7 @@ where
/// [n64(-1.), n64(-0.5)],
/// [n64(0.5), n64(-1.)]
/// ];
/// let grid = GridBuilder::<Sqrt<N64>>::from_array(&observations).build();
/// let grid = GridBuilder::<Sqrt<N64>>::from_array(&observations).unwrap().build();
/// let expected_grid = Grid::from(
/// vec![
/// Bins::new(Edges::from(vec![n64(-1.), n64(0.), n64(1.), n64(2.)])),
206 changes: 131 additions & 75 deletions src/histogram/strategies.rs

Large diffs are not rendered by default.

97 changes: 52 additions & 45 deletions src/quantile.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::errors::{EmptyInput, MinMaxError, MinMaxError::UndefinedOrder};
use interpolate::Interpolate;
use ndarray::prelude::*;
use ndarray::{s, Data, DataMut, RemoveAxis};
@@ -184,11 +185,11 @@ where
{
/// Finds the index of the minimum value of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
/// are undefined. (For example, this occurs if there are any
/// floating-point NaN values in the array.)
/// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise
/// orderings tested by the function are undefined. (For example, this
/// occurs if there are any floating-point NaN values in the array.)
///
/// Returns `None` if the array is empty.
/// Returns `Err(MinMaxError::EmptyInput)` if the array is empty.
///
/// Even if there are multiple (equal) elements that are minima, only one
/// index is returned. (Which one is returned is unspecified and may depend
@@ -205,9 +206,9 @@ where
///
/// let a = array![[1., 3., 5.],
/// [2., 0., 6.]];
/// assert_eq!(a.argmin(), Some((1, 1)));
/// assert_eq!(a.argmin(), Ok((1, 1)));
/// ```
fn argmin(&self) -> Option<D::Pattern>
fn argmin(&self) -> Result<D::Pattern, MinMaxError>
where
A: PartialOrd;

@@ -240,16 +241,16 @@ where

/// Finds the elementwise minimum of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
/// are undefined. (For example, this occurs if there are any
/// floating-point NaN values in the array.)
/// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise
/// orderings tested by the function are undefined. (For example, this
/// occurs if there are any floating-point NaN values in the array.)
///
/// Additionally, returns `None` if the array is empty.
/// Returns `Err(MinMaxError::EmptyInput)` if the array is empty.
///
/// Even if there are multiple (equal) elements that are minima, only one
/// is returned. (Which one is returned is unspecified and may depend on
/// the memory layout of the array.)
fn min(&self) -> Option<&A>
fn min(&self) -> Result<&A, MinMaxError>
where
A: PartialOrd;

@@ -269,11 +270,11 @@ where

/// Finds the index of the maximum value of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
/// are undefined. (For example, this occurs if there are any
/// floating-point NaN values in the array.)
/// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise
/// orderings tested by the function are undefined. (For example, this
/// occurs if there are any floating-point NaN values in the array.)
///
/// Returns `None` if the array is empty.
/// Returns `Err(MinMaxError::EmptyInput)` if the array is empty.
///
/// Even if there are multiple (equal) elements that are maxima, only one
/// index is returned. (Which one is returned is unspecified and may depend
@@ -290,9 +291,9 @@ where
///
/// let a = array![[1., 3., 7.],
/// [2., 5., 6.]];
/// assert_eq!(a.argmax(), Some((0, 2)));
/// assert_eq!(a.argmax(), Ok((0, 2)));
/// ```
fn argmax(&self) -> Option<D::Pattern>
fn argmax(&self) -> Result<D::Pattern, MinMaxError>
where
A: PartialOrd;

@@ -325,16 +326,16 @@ where

/// Finds the elementwise maximum of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
/// are undefined. (For example, this occurs if there are any
/// floating-point NaN values in the array.)
/// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise
/// orderings tested by the function are undefined. (For example, this
/// occurs if there are any floating-point NaN values in the array.)
///
/// Additionally, returns `None` if the array is empty.
/// Returns `Err(EmptyInput)` if the array is empty.
///
/// Even if there are multiple (equal) elements that are maxima, only one
/// is returned. (Which one is returned is unspecified and may depend on
/// the memory layout of the array.)
fn max(&self) -> Option<&A>
fn max(&self) -> Result<&A, MinMaxError>
where
A: PartialOrd;

@@ -406,21 +407,21 @@ where
S: Data<Elem = A>,
D: Dimension,
{
fn argmin(&self) -> Option<D::Pattern>
fn argmin(&self) -> Result<D::Pattern, MinMaxError>
where
A: PartialOrd,
{
let mut current_min = self.first()?;
let mut current_min = self.first().ok_or(EmptyInput)?;
let mut current_pattern_min = D::zeros(self.ndim()).into_pattern();

for (pattern, elem) in self.indexed_iter() {
if elem.partial_cmp(current_min)? == cmp::Ordering::Less {
if elem.partial_cmp(current_min).ok_or(UndefinedOrder)? == cmp::Ordering::Less {
current_pattern_min = pattern;
current_min = elem
}
}

Some(current_pattern_min)
Ok(current_pattern_min)
}

fn argmin_skipnan(&self) -> Option<D::Pattern>
@@ -445,14 +446,17 @@ where
}
}

fn min(&self) -> Option<&A>
fn min(&self) -> Result<&A, MinMaxError>
where
A: PartialOrd,
{
let first = self.first()?;
self.fold(Some(first), |acc, elem| match elem.partial_cmp(acc?)? {
cmp::Ordering::Less => Some(elem),
_ => acc,
let first = self.first().ok_or(EmptyInput)?;
self.fold(Ok(first), |acc, elem| {
let acc = acc?;
match elem.partial_cmp(acc).ok_or(UndefinedOrder)? {
cmp::Ordering::Less => Ok(elem),
_ => Ok(acc),
}
})
}

@@ -470,21 +474,21 @@ where
}))
}

fn argmax(&self) -> Option<D::Pattern>
fn argmax(&self) -> Result<D::Pattern, MinMaxError>
where
A: PartialOrd,
{
let mut current_max = self.first()?;
let mut current_max = self.first().ok_or(EmptyInput)?;
let mut current_pattern_max = D::zeros(self.ndim()).into_pattern();

for (pattern, elem) in self.indexed_iter() {
if elem.partial_cmp(current_max)? == cmp::Ordering::Greater {
if elem.partial_cmp(current_max).ok_or(UndefinedOrder)? == cmp::Ordering::Greater {
current_pattern_max = pattern;
current_max = elem
}
}

Some(current_pattern_max)
Ok(current_pattern_max)
}

fn argmax_skipnan(&self) -> Option<D::Pattern>
@@ -509,14 +513,17 @@ where
}
}

fn max(&self) -> Option<&A>
fn max(&self) -> Result<&A, MinMaxError>
where
A: PartialOrd,
{
let first = self.first()?;
self.fold(Some(first), |acc, elem| match elem.partial_cmp(acc?)? {
cmp::Ordering::Greater => Some(elem),
_ => acc,
let first = self.first().ok_or(EmptyInput)?;
self.fold(Ok(first), |acc, elem| {
let acc = acc?;
match elem.partial_cmp(acc).ok_or(UndefinedOrder)? {
cmp::Ordering::Greater => Ok(elem),
_ => Ok(acc),
}
})
}

@@ -619,10 +626,10 @@ where
/// - worst case: O(`m`^2);
/// where `m` is the number of elements in the array.
///
/// Returns `None` if the array is empty.
/// Returns `Err(EmptyInput)` if the array is empty.
///
/// **Panics** if `q` is not between `0.` and `1.` (inclusive).
fn quantile_mut<I>(&mut self, q: f64) -> Option<A>
fn quantile_mut<I>(&mut self, q: f64) -> Result<A, EmptyInput>
where
A: Ord + Clone,
S: DataMut,
@@ -633,16 +640,16 @@ impl<A, S> Quantile1dExt<A, S> for ArrayBase<S, Ix1>
where
S: Data<Elem = A>,
{
fn quantile_mut<I>(&mut self, q: f64) -> Option<A>
fn quantile_mut<I>(&mut self, q: f64) -> Result<A, EmptyInput>
where
A: Ord + Clone,
S: DataMut,
I: Interpolate<A>,
{
if self.is_empty() {
None
Err(EmptyInput)
} else {
Some(self.quantile_axis_mut::<I>(Axis(0), q).into_scalar())
Ok(self.quantile_axis_mut::<I>(Axis(0), q).into_scalar())
}
}
}
24 changes: 13 additions & 11 deletions src/summary_statistics/means.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::SummaryStatisticsExt;
use crate::errors::EmptyInput;
use ndarray::{ArrayBase, Data, Dimension};
use num_traits::{Float, FromPrimitive, Zero};
use std::ops::{Add, Div};
@@ -8,28 +9,28 @@ where
S: Data<Elem = A>,
D: Dimension,
{
fn mean(&self) -> Option<A>
fn mean(&self) -> Result<A, EmptyInput>
where
A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero,
{
let n_elements = self.len();
if n_elements == 0 {
None
Err(EmptyInput)
} else {
let n_elements = A::from_usize(n_elements)
.expect("Converting number of elements to `A` must not fail.");
Some(self.sum() / n_elements)
Ok(self.sum() / n_elements)
}
}

fn harmonic_mean(&self) -> Option<A>
fn harmonic_mean(&self) -> Result<A, EmptyInput>
where
A: Float + FromPrimitive,
{
self.map(|x| x.recip()).mean().map(|x| x.recip())
}

fn geometric_mean(&self) -> Option<A>
fn geometric_mean(&self) -> Result<A, EmptyInput>
where
A: Float + FromPrimitive,
{
@@ -40,6 +41,7 @@ where
#[cfg(test)]
mod tests {
use super::SummaryStatisticsExt;
use crate::errors::EmptyInput;
use approx::abs_diff_eq;
use ndarray::{array, Array1};
use noisy_float::types::N64;
@@ -56,17 +58,17 @@ mod tests {
#[test]
fn test_means_with_empty_array_of_floats() {
let a: Array1<f64> = array![];
assert!(a.mean().is_none());
assert!(a.harmonic_mean().is_none());
assert!(a.geometric_mean().is_none());
assert_eq!(a.mean(), Err(EmptyInput));
assert_eq!(a.harmonic_mean(), Err(EmptyInput));
assert_eq!(a.geometric_mean(), Err(EmptyInput));
}

#[test]
fn test_means_with_empty_array_of_noisy_floats() {
let a: Array1<N64> = array![];
assert!(a.mean().is_none());
assert!(a.harmonic_mean().is_none());
assert!(a.geometric_mean().is_none());
assert_eq!(a.mean(), Err(EmptyInput));
assert_eq!(a.harmonic_mean(), Err(EmptyInput));
assert_eq!(a.geometric_mean(), Err(EmptyInput));
}

#[test]
13 changes: 7 additions & 6 deletions src/summary_statistics/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Summary statistics (e.g. mean, variance, etc.).
use crate::errors::EmptyInput;
use ndarray::{Data, Dimension};
use num_traits::{Float, FromPrimitive, Zero};
use std::ops::{Add, Div};
@@ -18,12 +19,12 @@ where
/// n i=1
/// ```
///
/// If the array is empty, `None` is returned.
/// If the array is empty, `Err(EmptyInput)` is returned.
///
/// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
///
/// [`arithmetic mean`]: https://en.wikipedia.org/wiki/Arithmetic_mean
fn mean(&self) -> Option<A>
fn mean(&self) -> Result<A, EmptyInput>
where
A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero;

@@ -35,12 +36,12 @@ where
/// ⎝i=1 ⎠
/// ```
///
/// If the array is empty, `None` is returned.
/// If the array is empty, `Err(EmptyInput)` is returned.
///
/// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
///
/// [`harmonic mean`]: https://en.wikipedia.org/wiki/Harmonic_mean
fn harmonic_mean(&self) -> Option<A>
fn harmonic_mean(&self) -> Result<A, EmptyInput>
where
A: Float + FromPrimitive;

@@ -52,12 +53,12 @@ where
/// ⎝i=1 ⎠
/// ```
///
/// If the array is empty, `None` is returned.
/// If the array is empty, `Err(EmptyInput)` is returned.
///
/// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
///
/// [`geometric mean`]: https://en.wikipedia.org/wiki/Geometric_mean
fn geometric_mean(&self) -> Option<A>
fn geometric_mean(&self) -> Result<A, EmptyInput>
where
A: Float + FromPrimitive;
}
33 changes: 17 additions & 16 deletions tests/quantile.rs
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ extern crate quickcheck;

use ndarray::prelude::*;
use ndarray_stats::{
errors::MinMaxError,
interpolate::{Higher, Linear, Lower, Midpoint, Nearest},
Quantile1dExt, QuantileExt,
};
@@ -13,22 +14,22 @@ use quickcheck::quickcheck;
#[test]
fn test_argmin() {
let a = array![[1, 5, 3], [2, 0, 6]];
assert_eq!(a.argmin(), Some((1, 1)));
assert_eq!(a.argmin(), Ok((1, 1)));

let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.argmin(), Some((1, 1)));
assert_eq!(a.argmin(), Ok((1, 1)));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.argmin(), None);
assert_eq!(a.argmin(), Err(MinMaxError::UndefinedOrder));

let a: Array2<i32> = array![[], []];
assert_eq!(a.argmin(), None);
assert_eq!(a.argmin(), Err(MinMaxError::EmptyInput));
}

quickcheck! {
fn argmin_matches_min(data: Vec<f32>) -> bool {
let a = Array1::from(data);
a.argmin().map(|i| a[i]) == a.min().cloned()
a.argmin().map(|i| &a[i]) == a.min()
}
}

@@ -66,13 +67,13 @@ quickcheck! {
#[test]
fn test_min() {
let a = array![[1, 5, 3], [2, 0, 6]];
assert_eq!(a.min(), Some(&0));
assert_eq!(a.min(), Ok(&0));

let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.min(), Some(&0.));
assert_eq!(a.min(), Ok(&0.));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.min(), None);
assert_eq!(a.min(), Err(MinMaxError::UndefinedOrder));
}

#[test]
@@ -93,22 +94,22 @@ fn test_min_skipnan_all_nan() {
#[test]
fn test_argmax() {
let a = array![[1, 5, 3], [2, 0, 6]];
assert_eq!(a.argmax(), Some((1, 2)));
assert_eq!(a.argmax(), Ok((1, 2)));

let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.argmax(), Some((1, 2)));
assert_eq!(a.argmax(), Ok((1, 2)));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.argmax(), None);
assert_eq!(a.argmax(), Err(MinMaxError::UndefinedOrder));

let a: Array2<i32> = array![[], []];
assert_eq!(a.argmax(), None);
assert_eq!(a.argmax(), Err(MinMaxError::EmptyInput));
}

quickcheck! {
fn argmax_matches_max(data: Vec<f32>) -> bool {
let a = Array1::from(data);
a.argmax().map(|i| a[i]) == a.max().cloned()
a.argmax().map(|i| &a[i]) == a.max()
}
}

@@ -149,13 +150,13 @@ quickcheck! {
#[test]
fn test_max() {
let a = array![[1, 5, 7], [2, 0, 6]];
assert_eq!(a.max(), Some(&7));
assert_eq!(a.max(), Ok(&7));

let a = array![[1., 5., 7.], [2., 0., 6.]];
assert_eq!(a.max(), Some(&7.));
assert_eq!(a.max(), Ok(&7.));

let a = array![[1., 5., 7.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.max(), None);
assert_eq!(a.max(), Err(MinMaxError::UndefinedOrder));
}

#[test]