From 6797d7bf6973e8ba9bfda2ba3dca3c7628f4e625 Mon Sep 17 00:00:00 2001 From: Nil Goyette Date: Thu, 26 Sep 2019 11:26:32 -0400 Subject: [PATCH 1/8] Move summary statistic tests outside --- src/summary_statistics/means.rs | 313 -------------------------------- tests/summary_statistics.rs | 310 +++++++++++++++++++++++++++++++ 2 files changed, 310 insertions(+), 313 deletions(-) create mode 100644 tests/summary_statistics.rs diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs index e5fd566e..66aa7a81 100644 --- a/src/summary_statistics/means.rs +++ b/src/summary_statistics/means.rs @@ -251,316 +251,3 @@ where } result } - -#[cfg(test)] -mod tests { - use super::SummaryStatisticsExt; - use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch}; - use approx::{abs_diff_eq, assert_abs_diff_eq}; - use ndarray::{arr0, array, Array, Array1, Array2, Axis}; - use ndarray_rand::RandomExt; - use noisy_float::types::N64; - use quickcheck::{quickcheck, TestResult}; - use rand::distributions::Uniform; - use std::f64; - - #[test] - fn test_means_with_nan_values() { - let a = array![f64::NAN, 1.]; - assert!(a.mean().unwrap().is_nan()); - assert!(a.weighted_mean(&array![1.0, f64::NAN]).unwrap().is_nan()); - assert!(a.weighted_sum(&array![1.0, f64::NAN]).unwrap().is_nan()); - assert!(a - .weighted_mean_axis(Axis(0), &array![1.0, f64::NAN]) - .unwrap() - .into_scalar() - .is_nan()); - assert!(a - .weighted_sum_axis(Axis(0), &array![1.0, f64::NAN]) - .unwrap() - .into_scalar() - .is_nan()); - assert!(a.harmonic_mean().unwrap().is_nan()); - assert!(a.geometric_mean().unwrap().is_nan()); - } - - #[test] - fn test_means_with_empty_array_of_floats() { - let a: Array1 = array![]; - assert_eq!(a.mean(), None); - assert_eq!( - a.weighted_mean(&array![1.0]), - Err(MultiInputError::EmptyInput) - ); - assert_eq!( - a.weighted_mean_axis(Axis(0), &array![1.0]), - Err(MultiInputError::EmptyInput) - ); - assert_eq!(a.harmonic_mean(), Err(EmptyInput)); - assert_eq!(a.geometric_mean(), Err(EmptyInput)); - - // The sum methods accept empty arrays - assert_eq!(a.weighted_sum(&array![]), Ok(0.0)); - assert_eq!(a.weighted_sum_axis(Axis(0), &array![]), Ok(arr0(0.0))); - } - - #[test] - fn test_means_with_empty_array_of_noisy_floats() { - let a: Array1 = array![]; - assert_eq!(a.mean(), None); - assert_eq!(a.weighted_mean(&array![]), Err(MultiInputError::EmptyInput)); - assert_eq!( - a.weighted_mean_axis(Axis(0), &array![]), - Err(MultiInputError::EmptyInput) - ); - assert_eq!(a.harmonic_mean(), Err(EmptyInput)); - assert_eq!(a.geometric_mean(), Err(EmptyInput)); - - // The sum methods accept empty arrays - assert_eq!(a.weighted_sum(&array![]), Ok(N64::new(0.0))); - assert_eq!( - a.weighted_sum_axis(Axis(0), &array![]), - Ok(arr0(N64::new(0.0))) - ); - } - - #[test] - fn test_means_with_array_of_floats() { - let a: Array1 = array![ - 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, - 0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, - 0.78995021, 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, 0.34429457, - 0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, 0.79140914, - 0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, - 0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691, - 0.43599069, 0.77867775, 0.88267754, 0.83003623, 0.67016118, 0.67547638, 0.65220036, - 0.68043427 - ]; - // Computed using NumPy - let expected_mean = 
0.5475494059146699; - let expected_weighted_mean = 0.6782420496397121; - // Computed using SciPy - let expected_harmonic_mean = 0.21790094950226022; - let expected_geometric_mean = 0.4345897639796527; - - assert_abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = 1e-9); - assert_abs_diff_eq!( - a.harmonic_mean().unwrap(), - expected_harmonic_mean, - epsilon = 1e-7 - ); - assert_abs_diff_eq!( - a.geometric_mean().unwrap(), - expected_geometric_mean, - epsilon = 1e-12 - ); - - // weighted_mean with itself, normalized - let weights = &a / a.sum(); - assert_abs_diff_eq!( - a.weighted_sum(&weights).unwrap(), - expected_weighted_mean, - epsilon = 1e-12 - ); - - let data = a.into_shape((2, 5, 5)).unwrap(); - let weights = array![0.1, 0.5, 0.25, 0.15, 0.2]; - assert_abs_diff_eq!( - data.weighted_mean_axis(Axis(1), &weights).unwrap(), - array![ - [0.50202721, 0.53347361, 0.29086033, 0.56995637, 0.37087139], - [0.58028328, 0.50485216, 0.59349973, 0.70308937, 0.72280630] - ], - epsilon = 1e-8 - ); - assert_abs_diff_eq!( - data.weighted_mean_axis(Axis(2), &weights).unwrap(), - array![ - [0.33434378, 0.38365259, 0.56405781, 0.48676574, 0.55016179], - [0.71112376, 0.55134174, 0.45566513, 0.74228516, 0.68405851] - ], - epsilon = 1e-8 - ); - assert_abs_diff_eq!( - data.weighted_sum_axis(Axis(1), &weights).unwrap(), - array![ - [0.60243266, 0.64016833, 0.34903240, 0.68394765, 0.44504567], - [0.69633993, 0.60582259, 0.71219968, 0.84370724, 0.86736757] - ], - epsilon = 1e-8 - ); - assert_abs_diff_eq!( - data.weighted_sum_axis(Axis(2), &weights).unwrap(), - array![ - [0.40121254, 0.46038311, 0.67686937, 0.58411889, 0.66019415], - [0.85334851, 0.66161009, 0.54679815, 0.89074219, 0.82087021] - ], - epsilon = 1e-8 - ); - } - - #[test] - fn weighted_sum_dimension_zero() { - let a = Array2::::zeros((0, 20)); - assert_eq!( - a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap(), - Array1::from_elem(20, 0) - ); - assert_eq!( - a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap(), - Array1::from_elem(0, 0) - ); - assert_eq!( - a.weighted_sum_axis(Axis(0), &Array1::zeros(1)), - Err(MultiInputError::ShapeMismatch(ShapeMismatch { - first_shape: vec![0, 20], - second_shape: vec![1] - })) - ); - assert_eq!( - a.weighted_sum(&Array2::zeros((10, 20))), - Err(MultiInputError::ShapeMismatch(ShapeMismatch { - first_shape: vec![0, 20], - second_shape: vec![10, 20] - })) - ); - } - - #[test] - fn mean_eq_if_uniform_weights() { - fn prop(a: Vec) -> TestResult { - if a.len() < 1 { - return TestResult::discard(); - } - let a = Array1::from(a); - let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64); - let m = a.mean().unwrap(); - let wm = a.weighted_mean(&weights).unwrap(); - let ws = a.weighted_sum(&weights).unwrap(); - TestResult::from_bool( - abs_diff_eq!(m, wm, epsilon = 1e-9) && abs_diff_eq!(wm, ws, epsilon = 1e-9), - ) - } - quickcheck(prop as fn(Vec) -> TestResult); - } - - #[test] - fn mean_axis_eq_if_uniform_weights() { - fn prop(mut a: Vec) -> TestResult { - if a.len() < 24 { - return TestResult::discard(); - } - let depth = a.len() / 12; - a.truncate(depth * 3 * 4); - let weights = Array1::from_elem(depth, 1.0 / depth as f64); - let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap(); - let ma = a.mean_axis(Axis(0)).unwrap(); - let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap(); - let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap(); - TestResult::from_bool( - abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e12), - ) - } - quickcheck(prop as fn(Vec) -> 
TestResult); - } - - #[test] - fn test_central_moment_with_empty_array_of_floats() { - let a: Array1 = array![]; - for order in 0..=3 { - assert_eq!(a.central_moment(order), Err(EmptyInput)); - assert_eq!(a.central_moments(order), Err(EmptyInput)); - } - } - - #[test] - fn test_zeroth_central_moment_is_one() { - let n = 50; - let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); - assert_eq!(a.central_moment(0).unwrap(), 1.); - } - - #[test] - fn test_first_central_moment_is_zero() { - let n = 50; - let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); - assert_eq!(a.central_moment(1).unwrap(), 0.); - } - - #[test] - fn test_central_moments() { - let a: Array1 = array![ - 0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261, - 0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832, - 0.2222465, 0.42162446, 0.99909868, 0.47619762, 0.91696979, 0.9972741, 0.09891734, - 0.76934818, 0.77566862, 0.7692585, 0.2235759, 0.44821286, 0.79732186, 0.04804275, - 0.87863238, 0.1111003, 0.6653943, 0.44386445, 0.2133176, 0.39397086, 0.4374617, - 0.95896624, 0.57850146, 0.29301706, 0.02329879, 0.2123203, 0.62005503, 0.996492, - 0.5342986, 0.97822099, 0.5028445, 0.6693834, 0.14256682, 0.52724704, 0.73482372, - 0.1809703, - ]; - // Computed using scipy.stats.moment - let expected_moments = vec![ - 1., - 0., - 0.09339920262960291, - -0.0026849636727735186, - 0.015403769257729755, - -0.001204176487006564, - 0.002976822584939186, - ]; - for (order, expected_moment) in expected_moments.iter().enumerate() { - assert_abs_diff_eq!( - a.central_moment(order as u16).unwrap(), - expected_moment, - epsilon = 1e-8 - ); - } - } - - #[test] - fn test_bulk_central_moments() { - // Test that the bulk method is coherent with the non-bulk method - let n = 50; - let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); - let order = 10; - let central_moments = a.central_moments(order).unwrap(); - for i in 0..=order { - assert_eq!(a.central_moment(i).unwrap(), central_moments[i as usize]); - } - } - - #[test] - fn test_kurtosis_and_skewness_is_none_with_empty_array_of_floats() { - let a: Array1 = array![]; - assert_eq!(a.skewness(), Err(EmptyInput)); - assert_eq!(a.kurtosis(), Err(EmptyInput)); - } - - #[test] - fn test_kurtosis_and_skewness() { - let a: Array1 = array![ - 0.33310096, 0.98757449, 0.9789796, 0.96738114, 0.43545674, 0.06746873, 0.23706562, - 0.04241815, 0.38961714, 0.52421271, 0.93430327, 0.33911604, 0.05112372, 0.5013455, - 0.05291507, 0.62511183, 0.20749633, 0.22132433, 0.14734804, 0.51960608, 0.00449208, - 0.4093339, 0.2237519, 0.28070469, 0.7887231, 0.92224523, 0.43454188, 0.18335111, - 0.08646856, 0.87979847, 0.25483457, 0.99975627, 0.52712442, 0.41163279, 0.85162594, - 0.52618733, 0.75815023, 0.30640695, 0.14205781, 0.59695813, 0.851331, 0.39524328, - 0.73965373, 0.4007615, 0.02133069, 0.92899207, 0.79878191, 0.38947334, 0.22042183, - 0.77768353, - ]; - // Computed using scipy.stats.kurtosis(a, fisher=False) - let expected_kurtosis = 1.821933711687523; - // Computed using scipy.stats.skew - let expected_skewness = 0.2604785422878771; - - let kurtosis = a.kurtosis().unwrap(); - let skewness = a.skewness().unwrap(); - - assert_abs_diff_eq!(kurtosis, expected_kurtosis, epsilon = 1e-12); - assert_abs_diff_eq!(skewness, expected_skewness, epsilon = 1e-8); - } -} diff --git a/tests/summary_statistics.rs b/tests/summary_statistics.rs new file mode 
100644 index 00000000..6f045575 --- /dev/null +++ b/tests/summary_statistics.rs @@ -0,0 +1,310 @@ +use approx::{abs_diff_eq, assert_abs_diff_eq}; +use ndarray::{arr0, array, Array, Array1, Array2, Axis}; +use ndarray_rand::RandomExt; +use ndarray_stats::{ + errors::{EmptyInput, MultiInputError, ShapeMismatch}, + SummaryStatisticsExt, +}; +use noisy_float::types::N64; +use quickcheck::{quickcheck, TestResult}; +use rand::distributions::Uniform; +use std::f64; + +#[test] +fn test_means_with_nan_values() { + let a = array![f64::NAN, 1.]; + assert!(a.mean().unwrap().is_nan()); + assert!(a.weighted_mean(&array![1.0, f64::NAN]).unwrap().is_nan()); + assert!(a.weighted_sum(&array![1.0, f64::NAN]).unwrap().is_nan()); + assert!(a + .weighted_mean_axis(Axis(0), &array![1.0, f64::NAN]) + .unwrap() + .into_scalar() + .is_nan()); + assert!(a + .weighted_sum_axis(Axis(0), &array![1.0, f64::NAN]) + .unwrap() + .into_scalar() + .is_nan()); + assert!(a.harmonic_mean().unwrap().is_nan()); + assert!(a.geometric_mean().unwrap().is_nan()); +} + +#[test] +fn test_means_with_empty_array_of_floats() { + let a: Array1 = array![]; + assert_eq!(a.mean(), None); + assert_eq!( + a.weighted_mean(&array![1.0]), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_mean_axis(Axis(0), &array![1.0]), + Err(MultiInputError::EmptyInput) + ); + assert_eq!(a.harmonic_mean(), Err(EmptyInput)); + assert_eq!(a.geometric_mean(), Err(EmptyInput)); + + // The sum methods accept empty arrays + assert_eq!(a.weighted_sum(&array![]), Ok(0.0)); + assert_eq!(a.weighted_sum_axis(Axis(0), &array![]), Ok(arr0(0.0))); +} + +#[test] +fn test_means_with_empty_array_of_noisy_floats() { + let a: Array1 = array![]; + assert_eq!(a.mean(), None); + assert_eq!(a.weighted_mean(&array![]), Err(MultiInputError::EmptyInput)); + assert_eq!( + a.weighted_mean_axis(Axis(0), &array![]), + Err(MultiInputError::EmptyInput) + ); + assert_eq!(a.harmonic_mean(), Err(EmptyInput)); + assert_eq!(a.geometric_mean(), Err(EmptyInput)); + + // The sum methods accept empty arrays + assert_eq!(a.weighted_sum(&array![]), Ok(N64::new(0.0))); + assert_eq!( + a.weighted_sum_axis(Axis(0), &array![]), + Ok(arr0(N64::new(0.0))) + ); +} + +#[test] +fn test_means_with_array_of_floats() { + let a: Array1 = array![ + 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, + 0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, + 0.78995021, 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, 0.34429457, + 0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, 0.79140914, + 0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, + 0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691, + 0.43599069, 0.77867775, 0.88267754, 0.83003623, 0.67016118, 0.67547638, 0.65220036, + 0.68043427 + ]; + // Computed using NumPy + let expected_mean = 0.5475494059146699; + let expected_weighted_mean = 0.6782420496397121; + // Computed using SciPy + let expected_harmonic_mean = 0.21790094950226022; + let expected_geometric_mean = 0.4345897639796527; + + assert_abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = 1e-9); + assert_abs_diff_eq!( + a.harmonic_mean().unwrap(), + expected_harmonic_mean, + epsilon = 1e-7 + ); + assert_abs_diff_eq!( + a.geometric_mean().unwrap(), + expected_geometric_mean, + epsilon = 1e-12 + ); + + // weighted_mean with itself, normalized + let weights = &a / a.sum(); + assert_abs_diff_eq!( + a.weighted_sum(&weights).unwrap(), 
+ expected_weighted_mean, + epsilon = 1e-12 + ); + + let data = a.into_shape((2, 5, 5)).unwrap(); + let weights = array![0.1, 0.5, 0.25, 0.15, 0.2]; + assert_abs_diff_eq!( + data.weighted_mean_axis(Axis(1), &weights).unwrap(), + array![ + [0.50202721, 0.53347361, 0.29086033, 0.56995637, 0.37087139], + [0.58028328, 0.50485216, 0.59349973, 0.70308937, 0.72280630] + ], + epsilon = 1e-8 + ); + assert_abs_diff_eq!( + data.weighted_mean_axis(Axis(2), &weights).unwrap(), + array![ + [0.33434378, 0.38365259, 0.56405781, 0.48676574, 0.55016179], + [0.71112376, 0.55134174, 0.45566513, 0.74228516, 0.68405851] + ], + epsilon = 1e-8 + ); + assert_abs_diff_eq!( + data.weighted_sum_axis(Axis(1), &weights).unwrap(), + array![ + [0.60243266, 0.64016833, 0.34903240, 0.68394765, 0.44504567], + [0.69633993, 0.60582259, 0.71219968, 0.84370724, 0.86736757] + ], + epsilon = 1e-8 + ); + assert_abs_diff_eq!( + data.weighted_sum_axis(Axis(2), &weights).unwrap(), + array![ + [0.40121254, 0.46038311, 0.67686937, 0.58411889, 0.66019415], + [0.85334851, 0.66161009, 0.54679815, 0.89074219, 0.82087021] + ], + epsilon = 1e-8 + ); +} + +#[test] +fn weighted_sum_dimension_zero() { + let a = Array2::::zeros((0, 20)); + assert_eq!( + a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap(), + Array1::from_elem(20, 0) + ); + assert_eq!( + a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap(), + Array1::from_elem(0, 0) + ); + assert_eq!( + a.weighted_sum_axis(Axis(0), &Array1::zeros(1)), + Err(MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: vec![0, 20], + second_shape: vec![1] + })) + ); + assert_eq!( + a.weighted_sum(&Array2::zeros((10, 20))), + Err(MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: vec![0, 20], + second_shape: vec![10, 20] + })) + ); +} + +#[test] +fn mean_eq_if_uniform_weights() { + fn prop(a: Vec) -> TestResult { + if a.len() < 1 { + return TestResult::discard(); + } + let a = Array1::from(a); + let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64); + let m = a.mean().unwrap(); + let wm = a.weighted_mean(&weights).unwrap(); + let ws = a.weighted_sum(&weights).unwrap(); + TestResult::from_bool( + abs_diff_eq!(m, wm, epsilon = 1e-9) && abs_diff_eq!(wm, ws, epsilon = 1e-9), + ) + } + quickcheck(prop as fn(Vec) -> TestResult); +} + +#[test] +fn mean_axis_eq_if_uniform_weights() { + fn prop(mut a: Vec) -> TestResult { + if a.len() < 24 { + return TestResult::discard(); + } + let depth = a.len() / 12; + a.truncate(depth * 3 * 4); + let weights = Array1::from_elem(depth, 1.0 / depth as f64); + let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap(); + let ma = a.mean_axis(Axis(0)).unwrap(); + let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap(); + let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap(); + TestResult::from_bool( + abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e12), + ) + } + quickcheck(prop as fn(Vec) -> TestResult); +} + +#[test] +fn test_central_moment_with_empty_array_of_floats() { + let a: Array1 = array![]; + for order in 0..=3 { + assert_eq!(a.central_moment(order), Err(EmptyInput)); + assert_eq!(a.central_moments(order), Err(EmptyInput)); + } +} + +#[test] +fn test_zeroth_central_moment_is_one() { + let n = 50; + let bound: f64 = 200.; + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + assert_eq!(a.central_moment(0).unwrap(), 1.); +} + +#[test] +fn test_first_central_moment_is_zero() { + let n = 50; + let bound: f64 = 200.; + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + 
assert_eq!(a.central_moment(1).unwrap(), 0.); +} + +#[test] +fn test_central_moments() { + let a: Array1 = array![ + 0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261, + 0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832, + 0.2222465, 0.42162446, 0.99909868, 0.47619762, 0.91696979, 0.9972741, 0.09891734, + 0.76934818, 0.77566862, 0.7692585, 0.2235759, 0.44821286, 0.79732186, 0.04804275, + 0.87863238, 0.1111003, 0.6653943, 0.44386445, 0.2133176, 0.39397086, 0.4374617, 0.95896624, + 0.57850146, 0.29301706, 0.02329879, 0.2123203, 0.62005503, 0.996492, 0.5342986, 0.97822099, + 0.5028445, 0.6693834, 0.14256682, 0.52724704, 0.73482372, 0.1809703, + ]; + // Computed using scipy.stats.moment + let expected_moments = vec![ + 1., + 0., + 0.09339920262960291, + -0.0026849636727735186, + 0.015403769257729755, + -0.001204176487006564, + 0.002976822584939186, + ]; + for (order, expected_moment) in expected_moments.iter().enumerate() { + assert_abs_diff_eq!( + a.central_moment(order as u16).unwrap(), + expected_moment, + epsilon = 1e-8 + ); + } +} + +#[test] +fn test_bulk_central_moments() { + // Test that the bulk method is coherent with the non-bulk method + let n = 50; + let bound: f64 = 200.; + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + let order = 10; + let central_moments = a.central_moments(order).unwrap(); + for i in 0..=order { + assert_eq!(a.central_moment(i).unwrap(), central_moments[i as usize]); + } +} + +#[test] +fn test_kurtosis_and_skewness_is_none_with_empty_array_of_floats() { + let a: Array1 = array![]; + assert_eq!(a.skewness(), Err(EmptyInput)); + assert_eq!(a.kurtosis(), Err(EmptyInput)); +} + +#[test] +fn test_kurtosis_and_skewness() { + let a: Array1 = array![ + 0.33310096, 0.98757449, 0.9789796, 0.96738114, 0.43545674, 0.06746873, 0.23706562, + 0.04241815, 0.38961714, 0.52421271, 0.93430327, 0.33911604, 0.05112372, 0.5013455, + 0.05291507, 0.62511183, 0.20749633, 0.22132433, 0.14734804, 0.51960608, 0.00449208, + 0.4093339, 0.2237519, 0.28070469, 0.7887231, 0.92224523, 0.43454188, 0.18335111, + 0.08646856, 0.87979847, 0.25483457, 0.99975627, 0.52712442, 0.41163279, 0.85162594, + 0.52618733, 0.75815023, 0.30640695, 0.14205781, 0.59695813, 0.851331, 0.39524328, + 0.73965373, 0.4007615, 0.02133069, 0.92899207, 0.79878191, 0.38947334, 0.22042183, + 0.77768353, + ]; + // Computed using scipy.stats.kurtosis(a, fisher=False) + let expected_kurtosis = 1.821933711687523; + // Computed using scipy.stats.skew + let expected_skewness = 0.2604785422878771; + + let kurtosis = a.kurtosis().unwrap(); + let skewness = a.skewness().unwrap(); + + assert_abs_diff_eq!(kurtosis, expected_kurtosis, epsilon = 1e-12); + assert_abs_diff_eq!(skewness, expected_skewness, epsilon = 1e-8); +} From 8660232a567c82c4aa512a4768973087bb1ad57e Mon Sep 17 00:00:00 2001 From: Nil Goyette Date: Thu, 26 Sep 2019 14:15:58 -0400 Subject: [PATCH 2/8] Add weighted variance and standard deviation --- src/summary_statistics/means.rs | 34 ++++++++++++++++++++++++++++++++- src/summary_statistics/mod.rs | 34 ++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs index 66aa7a81..03026e1d 100644 --- a/src/summary_statistics/means.rs +++ b/src/summary_statistics/means.rs @@ -3,7 +3,7 @@ use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch}; use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis}; use 
num_integer::IterBinomial; use num_traits::{Float, FromPrimitive, Zero}; -use std::ops::{Add, Div, Mul}; +use std::ops::{Add, AddAssign, Div, Mul}; impl SummaryStatisticsExt for ArrayBase where @@ -105,6 +105,38 @@ where .ok_or(EmptyInput) } + fn weighted_var(&self, weights: &Self, ddof: A) -> Result + where + A: AddAssign + Float + FromPrimitive, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, weights); + let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + let one = A::from_usize(1).expect("Converting 1 to `A` must not fail."); + assert!( + !(ddof < zero || ddof > one), + "`ddof` must not be less than zero or greater than the length of the axis", + ); + + let mut weight_sum = zero; + let mut mean = zero; + let mut s = zero; + for (&x, &w) in self.iter().zip(weights.iter()) { + weight_sum += w; + let x_m_m = x - mean; + mean += (w / weight_sum) * x_m_m; + s += w * x_m_m * (x - mean); + } + Ok(s / (weight_sum - ddof)) + } + + fn weighted_std(&self, weights: &Self, ddof: A) -> Result + where + A: AddAssign + Float + FromPrimitive, + { + Ok(self.weighted_var(weights, ddof)?.sqrt()) + } + fn kurtosis(&self) -> Result where A: Float + FromPrimitive, diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs index 3f00ca98..238aeb06 100644 --- a/src/summary_statistics/mod.rs +++ b/src/summary_statistics/mod.rs @@ -2,7 +2,7 @@ use crate::errors::{EmptyInput, MultiInputError}; use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis}; use num_traits::{Float, FromPrimitive, Zero}; -use std::ops::{Add, Div, Mul}; +use std::ops::{Add, AddAssign, Div, Mul}; /// Extension trait for `ArrayBase` providing methods /// to compute several summary statistics (e.g. mean, variance, etc.). @@ -156,6 +156,38 @@ where where A: Float + FromPrimitive; + /// Return weighted variance of all elements in the array. + /// + /// The weighted variance is computed using the [`West, D. H. D.`] incremental algorithm. + /// Equivalent to `var_axis` if the `weights` are normalized. + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the + /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. + /// + /// **Panics** if `ddof` is less than zero or greater than one, if `axis` is out of bounds, or + /// if `A::from_usize()` fails for 0 or 1. + /// + /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + fn weighted_var(&self, weights: &Self, ddof: A) -> Result + where + A: AddAssign + Float + FromPrimitive; + + /// Return weighted standard deviation of all elements in the array. + /// + /// The weighted weighted standard deviation is computed using the [`West, D. H. D.`] + /// incremental algorithm. Equivalent to `var_axis` if the `weights` are normalized. + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the + /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. + /// + /// **Panics** if `ddof` is less than zero or greater than one, if `axis` is out of bounds, or + /// if `A::from_usize()` fails for 0 or 1. + /// + /// [`West, D. H. 
D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + fn weighted_std(&self, weights: &Self, ddof: A) -> Result + where + A: AddAssign + Float + FromPrimitive; + /// Returns the [kurtosis] `Kurt[X]` of all elements in the array: /// /// ```text From 8e1e7b987d5ec1ffccdced9f9f40005a782d63e4 Mon Sep 17 00:00:00 2001 From: Nil Goyette Date: Thu, 26 Sep 2019 14:50:07 -0400 Subject: [PATCH 3/8] Add tests --- tests/summary_statistics.rs | 97 ++++++++++++++++++++++++++++++------- 1 file changed, 80 insertions(+), 17 deletions(-) diff --git a/tests/summary_statistics.rs b/tests/summary_statistics.rs index 6f045575..37f2fb9b 100644 --- a/tests/summary_statistics.rs +++ b/tests/summary_statistics.rs @@ -11,39 +11,48 @@ use rand::distributions::Uniform; use std::f64; #[test] -fn test_means_with_nan_values() { +fn test_with_nan_values() { let a = array![f64::NAN, 1.]; + let weights = array![1.0, f64::NAN]; assert!(a.mean().unwrap().is_nan()); - assert!(a.weighted_mean(&array![1.0, f64::NAN]).unwrap().is_nan()); - assert!(a.weighted_sum(&array![1.0, f64::NAN]).unwrap().is_nan()); + assert!(a.weighted_mean(&weights).unwrap().is_nan()); + assert!(a.weighted_sum(&weights).unwrap().is_nan()); assert!(a - .weighted_mean_axis(Axis(0), &array![1.0, f64::NAN]) + .weighted_mean_axis(Axis(0), &weights) .unwrap() .into_scalar() .is_nan()); assert!(a - .weighted_sum_axis(Axis(0), &array![1.0, f64::NAN]) + .weighted_sum_axis(Axis(0), &weights) .unwrap() .into_scalar() .is_nan()); assert!(a.harmonic_mean().unwrap().is_nan()); assert!(a.geometric_mean().unwrap().is_nan()); + assert!(a.weighted_var(&weights, 0.0).unwrap().is_nan()); + assert!(a.weighted_std(&weights, 0.0).unwrap().is_nan()); } #[test] -fn test_means_with_empty_array_of_floats() { +fn test_with_empty_array_of_floats() { let a: Array1 = array![]; + let weights = array![1.0]; assert_eq!(a.mean(), None); + assert_eq!(a.weighted_mean(&weights), Err(MultiInputError::EmptyInput)); assert_eq!( - a.weighted_mean(&array![1.0]), + a.weighted_mean_axis(Axis(0), &weights), Err(MultiInputError::EmptyInput) ); + assert_eq!(a.harmonic_mean(), Err(EmptyInput)); + assert_eq!(a.geometric_mean(), Err(EmptyInput)); assert_eq!( - a.weighted_mean_axis(Axis(0), &array![1.0]), + a.weighted_var(&weights, 0.0), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_std(&weights, 0.0), Err(MultiInputError::EmptyInput) ); - assert_eq!(a.harmonic_mean(), Err(EmptyInput)); - assert_eq!(a.geometric_mean(), Err(EmptyInput)); // The sum methods accept empty arrays assert_eq!(a.weighted_sum(&array![]), Ok(0.0)); @@ -51,27 +60,36 @@ fn test_means_with_empty_array_of_floats() { } #[test] -fn test_means_with_empty_array_of_noisy_floats() { +fn test_with_empty_array_of_noisy_floats() { let a: Array1 = array![]; + let weights = array![]; assert_eq!(a.mean(), None); - assert_eq!(a.weighted_mean(&array![]), Err(MultiInputError::EmptyInput)); + assert_eq!(a.weighted_mean(&weights), Err(MultiInputError::EmptyInput)); assert_eq!( - a.weighted_mean_axis(Axis(0), &array![]), + a.weighted_mean_axis(Axis(0), &weights), Err(MultiInputError::EmptyInput) ); assert_eq!(a.harmonic_mean(), Err(EmptyInput)); assert_eq!(a.geometric_mean(), Err(EmptyInput)); + assert_eq!( + a.weighted_var(&weights, N64::new(0.0)), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_std(&weights, N64::new(0.0)), + Err(MultiInputError::EmptyInput) + ); // The sum methods accept empty arrays - assert_eq!(a.weighted_sum(&array![]), Ok(N64::new(0.0))); 
+ assert_eq!(a.weighted_sum(&weights), Ok(N64::new(0.0))); assert_eq!( - a.weighted_sum_axis(Axis(0), &array![]), + a.weighted_sum_axis(Axis(0), &weights), Ok(arr0(N64::new(0.0))) ); } #[test] -fn test_means_with_array_of_floats() { +fn test_with_array_of_floats() { let a: Array1 = array![ 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, 0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, @@ -85,6 +103,7 @@ fn test_means_with_array_of_floats() { // Computed using NumPy let expected_mean = 0.5475494059146699; let expected_weighted_mean = 0.6782420496397121; + let expected_weighted_var = 0.04306695637838332; // Computed using SciPy let expected_harmonic_mean = 0.21790094950226022; let expected_geometric_mean = 0.4345897639796527; @@ -101,13 +120,23 @@ fn test_means_with_array_of_floats() { epsilon = 1e-12 ); - // weighted_mean with itself, normalized + // Input array used as weights, normalized let weights = &a / a.sum(); assert_abs_diff_eq!( a.weighted_sum(&weights).unwrap(), expected_weighted_mean, epsilon = 1e-12 ); + assert_abs_diff_eq!( + a.weighted_var(&weights, 0.0).unwrap(), + expected_weighted_var, + epsilon = 1e-12 + ); + assert_abs_diff_eq!( + a.weighted_std(&weights, 0.0).unwrap(), + expected_weighted_var.sqrt(), + epsilon = 1e-12 + ); let data = a.into_shape((2, 5, 5)).unwrap(); let weights = array![0.1, 0.5, 0.25, 0.15, 0.2]; @@ -210,6 +239,40 @@ fn mean_axis_eq_if_uniform_weights() { quickcheck(prop as fn(Vec) -> TestResult); } +#[test] +fn weighted_var_eq_var_if_uniform_weight() { + fn prop(a: Vec) -> TestResult { + if a.len() < 1 { + return TestResult::discard(); + } + let a = Array1::from(a); + let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64); + let weighted_var = a.weighted_var(&weights, 0.0).unwrap(); + let var = a.var_axis(Axis(0), 0.0).into_scalar(); + TestResult::from_bool(abs_diff_eq!(weighted_var, var, epsilon = 1e-10)) + } + quickcheck(prop as fn(Vec) -> TestResult); +} + +#[test] +fn weighted_var_algo_eq_simple_algo() { + fn prop(a: Vec) -> TestResult { + if a.len() < 1 { + return TestResult::discard(); + } + let a = Array1::from(a); + let weights = Array::random(a.len(), Uniform::new(0.0, 1.0)); + let mean = a.weighted_mean(&weights).unwrap(); + let res_1_pass = a.weighted_var(&weights, 0.0).unwrap(); + let res_2_pass = (a - mean) + .mapv_into(|v| v.powi(2)) + .weighted_mean(&weights) + .unwrap(); + TestResult::from_bool(abs_diff_eq!(res_1_pass, res_2_pass, epsilon = 1e-10)) + } + quickcheck(prop as fn(Vec) -> TestResult); +} + #[test] fn test_central_moment_with_empty_array_of_floats() { let a: Array1 = array![]; From a0d4212564fd621862a1c7cc8f7e711719fc647a Mon Sep 17 00:00:00 2001 From: Nil Goyette Date: Thu, 26 Sep 2019 15:37:16 -0400 Subject: [PATCH 4/8] Add axis versions --- src/summary_statistics/means.rs | 86 +++++++++++++++++++++++++++------ src/summary_statistics/mod.rs | 52 +++++++++++++++++--- 2 files changed, 118 insertions(+), 20 deletions(-) diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs index 03026e1d..6a6fc617 100644 --- a/src/summary_statistics/means.rs +++ b/src/summary_statistics/means.rs @@ -112,22 +112,11 @@ where return_err_if_empty!(self); return_err_unless_same_shape!(self, weights); let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); - let one = A::from_usize(1).expect("Converting 1 to `A` must not fail."); assert!( - !(ddof < zero || ddof > one), - "`ddof` must not be less than zero or greater than the 
length of the axis", + !(ddof < zero || ddof > A::from_usize(1).unwrap()), + "`ddof` must not be less than zero or greater than one", ); - - let mut weight_sum = zero; - let mut mean = zero; - let mut s = zero; - for (&x, &w) in self.iter().zip(weights.iter()) { - weight_sum += w; - let x_m_m = x - mean; - mean += (w / weight_sum) * x_m_m; - s += w * x_m_m * (x - mean); - } - Ok(s / (weight_sum - ddof)) + inner_weighted_var(self, weights, ddof, zero) } fn weighted_std(&self, weights: &Self, ddof: A) -> Result @@ -137,6 +126,51 @@ where Ok(self.weighted_var(weights, ddof)?.sqrt()) } + fn weighted_var_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ddof: A, + ) -> Result, MultiInputError> + where + A: AddAssign + Float + FromPrimitive, + D: RemoveAxis, + { + return_err_if_empty!(self); + if self.shape()[axis.index()] != weights.len() { + return Err(MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: self.shape().to_vec(), + second_shape: weights.shape().to_vec(), + })); + } + let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + assert!( + !(ddof < zero || ddof > A::from_usize(1).unwrap()), + "`ddof` must not be less than zero or greater than one", + ); + + // `weights` must be a view because `lane` is a view in this context. + let weights = weights.view(); + Ok(self.map_axis(axis, |lane| { + inner_weighted_var(&lane, &weights, ddof, zero).unwrap() + })) + } + + fn weighted_std_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ddof: A, + ) -> Result, MultiInputError> + where + A: AddAssign + Float + FromPrimitive, + D: RemoveAxis, + { + Ok(self + .weighted_var_axis(axis, weights, ddof)? + .mapv_into(|x| x.sqrt())) + } + fn kurtosis(&self) -> Result where A: Float + FromPrimitive, @@ -208,6 +242,30 @@ where private_impl! {} } +/// Private function for `weighted_var` without conditions and asserts. +fn inner_weighted_var( + arr: &ArrayBase, + weights: &ArrayBase, + ddof: A, + zero: A, +) -> Result +where + S: Data, + A: AddAssign + Float + FromPrimitive, + D: Dimension, +{ + let mut weight_sum = zero; + let mut mean = zero; + let mut s = zero; + for (&x, &w) in arr.iter().zip(weights.iter()) { + weight_sum += w; + let x_m_m = x - mean; + mean += (w / weight_sum) * x_m_m; + s += w * x_m_m * (x - mean); + } + Ok(s / (weight_sum - ddof)) +} + /// Returns a vector containing all moments of the array elements up to /// *order*, where the *p*-th moment is defined as: /// diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs index 238aeb06..425a2f9d 100644 --- a/src/summary_statistics/mod.rs +++ b/src/summary_statistics/mod.rs @@ -164,8 +164,7 @@ where /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. /// - /// **Panics** if `ddof` is less than zero or greater than one, if `axis` is out of bounds, or - /// if `A::from_usize()` fails for 0 or 1. + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds. /// /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm fn weighted_var(&self, weights: &Self, ddof: A) -> Result @@ -174,20 +173,61 @@ where /// Return weighted standard deviation of all elements in the array. /// - /// The weighted weighted standard deviation is computed using the [`West, D. H. D.`] - /// incremental algorithm. Equivalent to `var_axis` if the `weights` are normalized. 
+ /// The weighted standard deviation is computed using the [`West, D. H. D.`] incremental + /// algorithm. Equivalent to `var_axis` if the `weights` are normalized. /// /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. /// - /// **Panics** if `ddof` is less than zero or greater than one, if `axis` is out of bounds, or - /// if `A::from_usize()` fails for 0 or 1. + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds. /// /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm fn weighted_std(&self, weights: &Self, ddof: A) -> Result where A: AddAssign + Float + FromPrimitive; + /// Return weighted variance along `axis`. + /// + /// The weighted variance is computed using the [`West, D. H. D.`] incremental algorithm. + /// Equivalent to `var_axis` if the `weights` are normalized. + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the + /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. + /// + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds. + /// + /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + fn weighted_var_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ddof: A, + ) -> Result, MultiInputError> + where + A: AddAssign + Float + FromPrimitive, + D: RemoveAxis; + + /// Return weighted standard deviation along `axis`. + /// + /// The weighted standard deviation is computed using the [`West, D. H. D.`] incremental + /// algorithm. Equivalent to `var_axis` if the `weights` are normalized. + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the + /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. + /// + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds. + /// + /// [`West, D. H. 
D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + fn weighted_std_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ddof: A, + ) -> Result, MultiInputError> + where + A: AddAssign + Float + FromPrimitive, + D: RemoveAxis; + /// Returns the [kurtosis] `Kurt[X]` of all elements in the array: /// /// ```text From 4ca755855d317e6adc6f97abbe91f00ac198958c Mon Sep 17 00:00:00 2001 From: Nil Goyette Date: Thu, 26 Sep 2019 15:54:03 -0400 Subject: [PATCH 5/8] Add tests for axis versions --- tests/summary_statistics.rs | 56 +++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/tests/summary_statistics.rs b/tests/summary_statistics.rs index 37f2fb9b..a754d780 100644 --- a/tests/summary_statistics.rs +++ b/tests/summary_statistics.rs @@ -31,6 +31,16 @@ fn test_with_nan_values() { assert!(a.geometric_mean().unwrap().is_nan()); assert!(a.weighted_var(&weights, 0.0).unwrap().is_nan()); assert!(a.weighted_std(&weights, 0.0).unwrap().is_nan()); + assert!(a + .weighted_var_axis(Axis(0), &weights, 0.0) + .unwrap() + .into_scalar() + .is_nan()); + assert!(a + .weighted_std_axis(Axis(0), &weights, 0.0) + .unwrap() + .into_scalar() + .is_nan()); } #[test] @@ -53,6 +63,14 @@ fn test_with_empty_array_of_floats() { a.weighted_std(&weights, 0.0), Err(MultiInputError::EmptyInput) ); + assert_eq!( + a.weighted_var_axis(Axis(0), &weights, 0.0), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_std_axis(Axis(0), &weights, 0.0), + Err(MultiInputError::EmptyInput) + ); // The sum methods accept empty arrays assert_eq!(a.weighted_sum(&array![]), Ok(0.0)); @@ -79,6 +97,14 @@ fn test_with_empty_array_of_noisy_floats() { a.weighted_std(&weights, N64::new(0.0)), Err(MultiInputError::EmptyInput) ); + assert_eq!( + a.weighted_var_axis(Axis(0), &weights, N64::new(0.0)), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_std_axis(Axis(0), &weights, N64::new(0.0)), + Err(MultiInputError::EmptyInput) + ); // The sum methods accept empty arrays assert_eq!(a.weighted_sum(&weights), Ok(N64::new(0.0))); @@ -256,19 +282,27 @@ fn weighted_var_eq_var_if_uniform_weight() { #[test] fn weighted_var_algo_eq_simple_algo() { - fn prop(a: Vec) -> TestResult { - if a.len() < 1 { + fn prop(mut a: Vec) -> TestResult { + if a.len() < 24 { return TestResult::discard(); } - let a = Array1::from(a); - let weights = Array::random(a.len(), Uniform::new(0.0, 1.0)); - let mean = a.weighted_mean(&weights).unwrap(); - let res_1_pass = a.weighted_var(&weights, 0.0).unwrap(); - let res_2_pass = (a - mean) - .mapv_into(|v| v.powi(2)) - .weighted_mean(&weights) - .unwrap(); - TestResult::from_bool(abs_diff_eq!(res_1_pass, res_2_pass, epsilon = 1e-10)) + let depth = a.len() / 12; + a.truncate(depth * 3 * 4); + let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap(); + let mut success = true; + for axis in 0..3 { + let axis = Axis(axis); + + let weights = Array::random(a.len_of(axis), Uniform::new(0.0, 1.0)); + let mean = a.weighted_mean_axis(axis, &weights).unwrap().insert_axis(axis); + let res_1_pass = a.weighted_var_axis(axis, &weights, 0.0).unwrap(); + let res_2_pass = (&a - &mean) + .mapv_into(|v| v.powi(2)) + .weighted_mean_axis(axis, &weights) + .unwrap(); + success &= abs_diff_eq!(res_1_pass, res_2_pass, epsilon = 1e-10); + } + TestResult::from_bool(success) } quickcheck(prop as fn(Vec) -> TestResult); } From 5629d11578cd14f64d652c6d0404b68e0018869d Mon Sep 17 00:00:00 2001 From: Nil Goyette Date: Mon, 7 Oct 2019 
11:24:17 -0400 Subject: [PATCH 6/8] ddof expect and doc --- src/summary_statistics/means.rs | 12 +++++++----- src/summary_statistics/mod.rs | 12 ++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs index 6a6fc617..1ed3bcad 100644 --- a/src/summary_statistics/means.rs +++ b/src/summary_statistics/means.rs @@ -112,8 +112,9 @@ where return_err_if_empty!(self); return_err_unless_same_shape!(self, weights); let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + let one = A::from_usize(1).expect("Converting 1 to `A` must not fail."); assert!( - !(ddof < zero || ddof > A::from_usize(1).unwrap()), + !(ddof < zero || ddof > one), "`ddof` must not be less than zero or greater than one", ); inner_weighted_var(self, weights, ddof, zero) @@ -144,8 +145,9 @@ where })); } let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + let one = A::from_usize(1).expect("Converting 1 to `A` must not fail."); assert!( - !(ddof < zero || ddof > A::from_usize(1).unwrap()), + !(ddof < zero || ddof > one), "`ddof` must not be less than zero or greater than one", ); @@ -259,9 +261,9 @@ where let mut s = zero; for (&x, &w) in arr.iter().zip(weights.iter()) { weight_sum += w; - let x_m_m = x - mean; - mean += (w / weight_sum) * x_m_m; - s += w * x_m_m * (x - mean); + let x_minus_mean = x - mean; + mean += (w / weight_sum) * x_minus_mean; + s += w * x_minus_mean * (x - mean); } Ok(s / (weight_sum - ddof)) } diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs index 425a2f9d..1f8fe000 100644 --- a/src/summary_statistics/mod.rs +++ b/src/summary_statistics/mod.rs @@ -164,7 +164,8 @@ where /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. /// - /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds. + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, + /// or if `A::from_usize()` fails for zero or one. /// /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm fn weighted_var(&self, weights: &Self, ddof: A) -> Result @@ -179,7 +180,8 @@ where /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. /// - /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds. + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, + /// or if `A::from_usize()` fails for zero or one. /// /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm fn weighted_std(&self, weights: &Self, ddof: A) -> Result @@ -194,7 +196,8 @@ where /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. /// - /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds. + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, + /// or if `A::from_usize()` fails for zero or one. /// /// [`West, D. H. 
D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm fn weighted_var_axis( @@ -215,7 +218,8 @@ where /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. /// - /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds. + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, + /// or if `A::from_usize()` fails for zero or one. /// /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm fn weighted_std_axis( From 6254b28e2d3522891bd581eef9ee0a6de63f9f60 Mon Sep 17 00:00:00 2001 From: Nil Goyette Date: Mon, 7 Oct 2019 11:30:25 -0400 Subject: [PATCH 7/8] Fmt --- tests/summary_statistics.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/summary_statistics.rs b/tests/summary_statistics.rs index a754d780..ed22bde6 100644 --- a/tests/summary_statistics.rs +++ b/tests/summary_statistics.rs @@ -294,7 +294,10 @@ fn weighted_var_algo_eq_simple_algo() { let axis = Axis(axis); let weights = Array::random(a.len_of(axis), Uniform::new(0.0, 1.0)); - let mean = a.weighted_mean_axis(axis, &weights).unwrap().insert_axis(axis); + let mean = a + .weighted_mean_axis(axis, &weights) + .unwrap() + .insert_axis(axis); let res_1_pass = a.weighted_var_axis(axis, &weights, 0.0).unwrap(); let res_2_pass = (&a - &mean) .mapv_into(|v| v.powi(2)) From 092cfa61c17dba281f748d04489b5006bda6b8f9 Mon Sep 17 00:00:00 2001 From: Nil Goyette Date: Mon, 7 Oct 2019 13:05:28 -0400 Subject: [PATCH 8/8] Add benches --- Cargo.toml | 4 ++++ benches/summary_statistics.rs | 37 +++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 benches/summary_statistics.rs diff --git a/Cargo.toml b/Cargo.toml index 0fecc1f3..0dbd1a28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,3 +36,7 @@ num-bigint = "0.2.2" [[bench]] name = "sort" harness = false + +[[bench]] +name = "summary_statistics" +harness = false \ No newline at end of file diff --git a/benches/summary_statistics.rs b/benches/summary_statistics.rs new file mode 100644 index 00000000..74694e6b --- /dev/null +++ b/benches/summary_statistics.rs @@ -0,0 +1,37 @@ +use criterion::{ + black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion, + ParameterizedBenchmark, PlotConfiguration, +}; +use ndarray::prelude::*; +use ndarray_rand::RandomExt; +use ndarray_stats::SummaryStatisticsExt; +use rand::distributions::Uniform; + +fn weighted_std(c: &mut Criterion) { + let lens = vec![10, 100, 1000, 10000]; + let benchmark = ParameterizedBenchmark::new( + "weighted_std", + |bencher, &len| { + let data = Array::random(len, Uniform::new(0.0, 1.0)); + let mut weights = Array::random(len, Uniform::new(0.0, 1.0)); + weights /= weights.sum(); + bencher.iter_batched( + || data.clone(), + |arr| { + black_box(arr.weighted_std(&weights, 0.0).unwrap()); + }, + BatchSize::SmallInput, + ) + }, + lens, + ) + .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + c.bench("weighted_std", benchmark); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = weighted_std +} +criterion_main!(benches);
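For context, a minimal usage sketch of the API introduced by this patch series (`weighted_var`, `weighted_std` and their `_axis` variants). It is illustrative only: it assumes the patches above are applied on top of ndarray-stats, and the variable names and sample data are invented for the example rather than taken from the crate.

use ndarray::{array, Axis};
use ndarray_stats::SummaryStatisticsExt;

fn main() {
    let data = array![1.0, 2.0, 3.0, 4.0];
    // Weights do not have to be normalized; with `ddof = 0` the result does not
    // depend on their overall scale, because the algorithm divides by their sum.
    let weights = array![0.1, 0.2, 0.3, 0.4];

    // Population (`ddof = 0`) weighted variance and standard deviation.
    let var = data.weighted_var(&weights, 0.0).unwrap();
    let std = data.weighted_std(&weights, 0.0).unwrap();
    assert!((std * std - var).abs() < 1e-12);

    // The `_axis` variants reduce along one axis of a higher-dimensional array;
    // `weights` must have the same length as that axis.
    let matrix = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
    let axis_weights = array![0.2, 0.3, 0.5];
    let var_per_column = matrix
        .weighted_var_axis(Axis(0), &axis_weights, 0.0)
        .unwrap();
    println!("var = {}, std = {}, per-column var = {}", var, std, var_per_column);
}

Passing `ddof = 1` instead gives the sample-variance convention described in the doc comments above; values of `ddof` outside `[0, 1]` panic.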