Skip to content

Commit 680209f

Browse files
Update for ndarray to 0.13 (#52)
* Return a result for correlation functions * Fix geometric and harmonic mean * Fix all_close everywhere * Green tests * Update to released versions * Enable approx feature for the test suite * Fix deprecation warnings * Formatting * Bump minimum Rust version
1 parent ba56083 commit 680209f

File tree

8 files changed

+95
-59
lines changed

8 files changed

+95
-59
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ addons:
77
- libssl-dev
88
cache: cargo
99
rust:
10-
- 1.34.0
10+
- 1.37.0
1111
- stable
1212
- beta
1313
- nightly

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ keywords = ["array", "multidimensional", "statistics", "matrix", "ndarray"]
1616
categories = ["data-structures", "science"]
1717

1818
[dependencies]
19-
ndarray = "0.12.1"
19+
ndarray = "0.13"
2020
noisy_float = "0.1.8"
2121
num-integer = "0.1"
2222
num-traits = "0.2"
@@ -25,9 +25,10 @@ itertools = { version = "0.8.0", default-features = false }
2525
indexmap = "1.0"
2626

2727
[dev-dependencies]
28+
ndarray = { version = "0.13", features = ["approx"] }
2829
criterion = "0.2"
2930
quickcheck = { version = "0.8.1", default-features = false }
30-
ndarray-rand = "0.10"
31+
ndarray-rand = "0.11"
3132
approx = "0.3"
3233
quickcheck_macros = "0.8"
3334
num-bigint = "0.2.2"

src/correlation.rs

Lines changed: 73 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::errors::EmptyInput;
12
use ndarray::prelude::*;
23
use ndarray::Data;
34
use num_traits::{Float, FromPrimitive};
@@ -41,10 +42,10 @@ where
4142
/// ```
4243
/// and similarly for ̅y.
4344
///
44-
/// **Panics** if `ddof` is greater than or equal to the number of
45-
/// observations, if the number of observations is zero and division by
46-
/// zero panics for type `A`, or if the type cast of `n_observations` from
47-
/// `usize` to `A` fails.
45+
/// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
46+
///
47+
/// **Panics** if `ddof` is negative or greater than or equal to the number of
48+
/// observations, or if the type cast of `n_observations` from `usize` to `A` fails.
4849
///
4950
/// # Example
5051
///
@@ -54,13 +55,13 @@ where
5455
///
5556
/// let a = arr2(&[[1., 3., 5.],
5657
/// [2., 4., 6.]]);
57-
/// let covariance = a.cov(1.);
58+
/// let covariance = a.cov(1.).unwrap();
5859
/// assert_eq!(
5960
/// covariance,
6061
/// aview2(&[[4., 4.], [4., 4.]])
6162
/// );
6263
/// ```
63-
fn cov(&self, ddof: A) -> Array2<A>
64+
fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
6465
where
6566
A: Float + FromPrimitive;
6667

@@ -89,30 +90,35 @@ where
8990
/// R_ij = rho(X_i, X_j)
9091
/// ```
9192
///
92-
/// **Panics** if `M` is empty, if the type cast of `n_observations`
93-
/// from `usize` to `A` fails or if the standard deviation of one of the random
93+
/// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
94+
///
95+
/// **Panics** if the type cast of `n_observations` from `usize` to `A` fails or
96+
/// if the standard deviation of one of the random variables is zero and
97+
/// division by zero panics for type A.
9498
///
9599
/// # Example
96100
///
97-
/// variables is zero and division by zero panics for type A.
98101
/// ```
102+
/// use approx;
99103
/// use ndarray::arr2;
100104
/// use ndarray_stats::CorrelationExt;
105+
/// use approx::AbsDiffEq;
101106
///
102107
/// let a = arr2(&[[1., 3., 5.],
103108
/// [2., 4., 6.]]);
104-
/// let corr = a.pearson_correlation();
109+
/// let corr = a.pearson_correlation().unwrap();
110+
/// let epsilon = 1e-7;
105111
/// assert!(
106-
/// corr.all_close(
112+
/// corr.abs_diff_eq(
107113
/// &arr2(&[
108114
/// [1., 1.],
109115
/// [1., 1.],
110116
/// ]),
111-
/// 1e-7
117+
/// epsilon
112118
/// )
113119
/// );
114120
/// ```
115-
fn pearson_correlation(&self) -> Array2<A>
121+
fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
116122
where
117123
A: Float + FromPrimitive;
118124

@@ -123,7 +129,7 @@ impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
123129
where
124130
S: Data<Elem = A>,
125131
{
126-
fn cov(&self, ddof: A) -> Array2<A>
132+
fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
127133
where
128134
A: Float + FromPrimitive,
129135
{
@@ -139,28 +145,37 @@ where
139145
n_observations - ddof
140146
};
141147
let mean = self.mean_axis(observation_axis);
142-
let denoised = self - &mean.insert_axis(observation_axis);
143-
let covariance = denoised.dot(&denoised.t());
144-
covariance.mapv_into(|x| x / dof)
148+
match mean {
149+
Some(mean) => {
150+
let denoised = self - &mean.insert_axis(observation_axis);
151+
let covariance = denoised.dot(&denoised.t());
152+
Ok(covariance.mapv_into(|x| x / dof))
153+
}
154+
None => Err(EmptyInput),
155+
}
145156
}
146157

147-
fn pearson_correlation(&self) -> Array2<A>
158+
fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
148159
where
149160
A: Float + FromPrimitive,
150161
{
151-
let observation_axis = Axis(1);
152-
// The ddof value doesn't matter, as long as we use the same one
153-
// for computing covariance and standard deviation
154-
// We choose -1 to avoid panicking when we only have one
155-
// observation per random variable (or no observations at all)
156-
let ddof = -A::one();
157-
let cov = self.cov(ddof);
158-
let std = self
159-
.std_axis(observation_axis, ddof)
160-
.insert_axis(observation_axis);
161-
let std_matrix = std.dot(&std.t());
162-
// element-wise division
163-
cov / std_matrix
162+
match self.dim() {
163+
(n, m) if n > 0 && m > 0 => {
164+
let observation_axis = Axis(1);
165+
// The ddof value doesn't matter, as long as we use the same one
166+
// for computing covariance and standard deviation
167+
// We choose 0 as it is the smallest number admitted by std_axis
168+
let ddof = A::zero();
169+
let cov = self.cov(ddof).unwrap();
170+
let std = self
171+
.std_axis(observation_axis, ddof)
172+
.insert_axis(observation_axis);
173+
let std_matrix = std.dot(&std.t());
174+
// element-wise division
175+
Ok(cov / std_matrix)
176+
}
177+
_ => Err(EmptyInput),
178+
}
164179
}
165180

166181
private_impl! {}
@@ -180,9 +195,10 @@ mod cov_tests {
180195
let n_random_variables = 3;
181196
let n_observations = 4;
182197
let a = Array::from_elem((n_random_variables, n_observations), value);
183-
a.cov(1.).all_close(
198+
abs_diff_eq!(
199+
a.cov(1.).unwrap(),
184200
&Array::zeros((n_random_variables, n_random_variables)),
185-
1e-8,
201+
epsilon = 1e-8,
186202
)
187203
}
188204

@@ -194,8 +210,8 @@ mod cov_tests {
194210
(n_random_variables, n_observations),
195211
Uniform::new(-bound.abs(), bound.abs()),
196212
);
197-
let covariance = a.cov(1.);
198-
covariance.all_close(&covariance.t(), 1e-8)
213+
let covariance = a.cov(1.).unwrap();
214+
abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
199215
}
200216

201217
#[test]
@@ -205,31 +221,31 @@ mod cov_tests {
205221
let n_observations = 4;
206222
let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.));
207223
let invalid_ddof = (n_observations as f64) + rand::random::<f64>().abs();
208-
a.cov(invalid_ddof);
224+
let _ = a.cov(invalid_ddof);
209225
}
210226

211227
#[test]
212228
fn test_covariance_zero_variables() {
213229
let a = Array2::<f32>::zeros((0, 2));
214230
let cov = a.cov(1.);
215-
assert_eq!(cov.shape(), &[0, 0]);
231+
assert!(cov.is_ok());
232+
assert_eq!(cov.unwrap().shape(), &[0, 0]);
216233
}
217234

218235
#[test]
219236
fn test_covariance_zero_observations() {
220237
let a = Array2::<f32>::zeros((2, 0));
221238
// Negative ddof (-1 < 0) to avoid invalid-ddof panic
222239
let cov = a.cov(-1.);
223-
assert_eq!(cov.shape(), &[2, 2]);
224-
cov.mapv(|x| assert_eq!(x, 0.));
240+
assert_eq!(cov, Err(EmptyInput));
225241
}
226242

227243
#[test]
228244
fn test_covariance_zero_variables_zero_observations() {
229245
let a = Array2::<f32>::zeros((0, 0));
230246
// Negative ddof (-1 < 0) to avoid invalid-ddof panic
231247
let cov = a.cov(-1.);
232-
assert_eq!(cov.shape(), &[0, 0]);
248+
assert_eq!(cov, Err(EmptyInput));
233249
}
234250

235251
#[test]
@@ -255,7 +271,7 @@ mod cov_tests {
255271
]
256272
];
257273
assert_eq!(a.ndim(), 2);
258-
assert!(a.cov(1.).all_close(&numpy_covariance, 1e-8));
274+
assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8);
259275
}
260276

261277
#[test]
@@ -264,7 +280,7 @@ mod cov_tests {
264280
fn test_covariance_for_badly_conditioned_array() {
265281
let a: Array2<f64> = array![[1e12 + 1., 1e12 - 1.], [1e-6 + 1e-12, 1e-6 - 1e-12],];
266282
let expected_covariance = array![[2., 2e-12], [2e-12, 2e-24]];
267-
assert!(a.cov(1.).all_close(&expected_covariance, 1e-24));
283+
assert_abs_diff_eq!(a.cov(1.).unwrap(), &expected_covariance, epsilon = 1e-24);
268284
}
269285
}
270286

@@ -284,8 +300,12 @@ mod pearson_correlation_tests {
284300
(n_random_variables, n_observations),
285301
Uniform::new(-bound.abs(), bound.abs()),
286302
);
287-
let pearson_correlation = a.pearson_correlation();
288-
pearson_correlation.all_close(&pearson_correlation.t(), 1e-8)
303+
let pearson_correlation = a.pearson_correlation().unwrap();
304+
abs_diff_eq!(
305+
pearson_correlation.view(),
306+
pearson_correlation.t(),
307+
epsilon = 1e-8
308+
)
289309
}
290310

291311
#[quickcheck]
@@ -295,6 +315,7 @@ mod pearson_correlation_tests {
295315
let a = Array::from_elem((n_random_variables, n_observations), value);
296316
let pearson_correlation = a.pearson_correlation();
297317
pearson_correlation
318+
.unwrap()
298319
.iter()
299320
.map(|x| x.is_nan())
300321
.fold(true, |acc, flag| acc & flag)
@@ -304,21 +325,21 @@ mod pearson_correlation_tests {
304325
fn test_zero_variables() {
305326
let a = Array2::<f32>::zeros((0, 2));
306327
let pearson_correlation = a.pearson_correlation();
307-
assert_eq!(pearson_correlation.shape(), &[0, 0]);
328+
assert_eq!(pearson_correlation, Err(EmptyInput))
308329
}
309330

310331
#[test]
311332
fn test_zero_observations() {
312333
let a = Array2::<f32>::zeros((2, 0));
313334
let pearson = a.pearson_correlation();
314-
pearson.mapv(|x| x.is_nan());
335+
assert_eq!(pearson, Err(EmptyInput));
315336
}
316337

317338
#[test]
318339
fn test_zero_variables_zero_observations() {
319340
let a = Array2::<f32>::zeros((0, 0));
320341
let pearson = a.pearson_correlation();
321-
assert_eq!(pearson.shape(), &[0, 0]);
342+
assert_eq!(pearson, Err(EmptyInput));
322343
}
323344

324345
#[test]
@@ -338,6 +359,10 @@ mod pearson_correlation_tests {
338359
[0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.]
339360
];
340361
assert_eq!(a.ndim(), 2);
341-
assert!(a.pearson_correlation().all_close(&numpy_corrcoeff, 1e-7));
362+
assert_abs_diff_eq!(
363+
a.pearson_correlation().unwrap(),
364+
numpy_corrcoeff,
365+
epsilon = 1e-7
366+
);
342367
}
343368
}

src/histogram/bins.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ mod edges_tests {
349349

350350
#[quickcheck]
351351
fn check_sorted_from_array(v: Vec<i32>) -> bool {
352-
let a = Array1::from_vec(v);
352+
let a = Array1::from(v);
353353
let edges = Edges::from(a);
354354
let n = edges.len();
355355
for i in 1..n {

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ pub use crate::quantile::{interpolate, Quantile1dExt, QuantileExt};
3838
pub use crate::sort::Sort1dExt;
3939
pub use crate::summary_statistics::SummaryStatisticsExt;
4040

41+
#[cfg(test)]
42+
#[macro_use]
43+
extern crate approx;
44+
4145
#[macro_use]
4246
mod private {
4347
/// This is a public type in a private module, so it can be included in

src/summary_statistics/means.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,20 @@ where
2828
where
2929
A: Float + FromPrimitive,
3030
{
31-
self.map(|x| x.recip()).mean().map(|x| x.recip())
31+
self.map(|x| x.recip())
32+
.mean()
33+
.map(|x| x.recip())
34+
.ok_or(EmptyInput)
3235
}
3336

3437
fn geometric_mean(&self) -> Result<A, EmptyInput>
3538
where
3639
A: Float + FromPrimitive,
3740
{
38-
self.map(|x| x.ln()).mean().map(|x| x.exp())
41+
self.map(|x| x.ln())
42+
.mean()
43+
.map(|x| x.exp())
44+
.ok_or(EmptyInput)
3945
}
4046

4147
fn kurtosis(&self) -> Result<A, EmptyInput>
@@ -207,15 +213,15 @@ mod tests {
207213
#[test]
208214
fn test_means_with_empty_array_of_floats() {
209215
let a: Array1<f64> = array![];
210-
assert_eq!(a.mean(), Err(EmptyInput));
216+
assert_eq!(a.mean(), None);
211217
assert_eq!(a.harmonic_mean(), Err(EmptyInput));
212218
assert_eq!(a.geometric_mean(), Err(EmptyInput));
213219
}
214220

215221
#[test]
216222
fn test_means_with_empty_array_of_noisy_floats() {
217223
let a: Array1<N64> = array![];
218-
assert_eq!(a.mean(), Err(EmptyInput));
224+
assert_eq!(a.mean(), None);
219225
assert_eq!(a.harmonic_mean(), Err(EmptyInput));
220226
assert_eq!(a.geometric_mean(), Err(EmptyInput));
221227
}

tests/quantile.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ fn test_midpoint_overflow() {
279279

280280
#[quickcheck]
281281
fn test_quantiles_mut(xs: Vec<i64>) -> bool {
282-
let v = Array::from_vec(xs.clone());
282+
let v = Array::from(xs.clone());
283283

284284
// Unordered list of quantile indexes to look up, with a duplicate
285285
let quantile_indexes = Array::from(vec![

tests/sort.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ fn test_sorted_get_many_mut(mut xs: Vec<i64>) -> bool {
4949
if n == 0 {
5050
true
5151
} else {
52-
let mut v = Array::from_vec(xs.clone());
52+
let mut v = Array::from(xs.clone());
5353

5454
// Insert each index twice, to get a set of indexes with duplicates, not sorted
5555
let mut indexes: Vec<usize> = (0..n).into_iter().collect();
@@ -78,7 +78,7 @@ fn test_sorted_get_mut_as_sorting_algorithm(mut xs: Vec<i64>) -> bool {
7878
if n == 0 {
7979
true
8080
} else {
81-
let mut v = Array::from_vec(xs.clone());
81+
let mut v = Array::from(xs.clone());
8282
let sorted_v: Vec<_> = (0..n).map(|i| v.get_from_sorted_mut(i)).collect();
8383
xs.sort();
8484
xs == sorted_v

0 commit comments

Comments
 (0)