1
+ use crate :: errors:: EmptyInput ;
1
2
use ndarray:: prelude:: * ;
2
3
use ndarray:: Data ;
3
4
use num_traits:: { Float , FromPrimitive } ;
@@ -41,10 +42,10 @@ where
41
42
/// ```
42
43
/// and similarly for ̅y.
43
44
///
44
- /// **Panics** if `ddof ` is greater than or equal to the number of
45
- /// observations, if the number of observations is zero and division by
46
- /// zero panics for type `A`, or if the type cast of `n_observations` from
47
- /// `usize` to `A` fails.
45
+ /// If `M ` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
46
+ ///
47
+ /// **Panics** if `ddof` is negative or greater than or equal to the number of
48
+ /// observations, or if the type cast of `n_observations` from `usize` to `A` fails.
48
49
///
49
50
/// # Example
50
51
///
@@ -54,13 +55,13 @@ where
54
55
///
55
56
/// let a = arr2(&[[1., 3., 5.],
56
57
/// [2., 4., 6.]]);
57
- /// let covariance = a.cov(1.);
58
+ /// let covariance = a.cov(1.).unwrap() ;
58
59
/// assert_eq!(
59
60
/// covariance,
60
61
/// aview2(&[[4., 4.], [4., 4.]])
61
62
/// );
62
63
/// ```
63
- fn cov ( & self , ddof : A ) -> Array2 < A >
64
+ fn cov ( & self , ddof : A ) -> Result < Array2 < A > , EmptyInput >
64
65
where
65
66
A : Float + FromPrimitive ;
66
67
@@ -89,30 +90,35 @@ where
89
90
/// R_ij = rho(X_i, X_j)
90
91
/// ```
91
92
///
92
- /// **Panics** if `M` is empty, if the type cast of `n_observations`
93
- /// from `usize` to `A` fails or if the standard deviation of one of the random
93
+ /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
94
+ ///
95
+ /// **Panics** if the type cast of `n_observations` from `usize` to `A` fails or
96
+ /// if the standard deviation of one of the random variables is zero and
97
+ /// division by zero panics for type A.
94
98
///
95
99
/// # Example
96
100
///
97
- /// variables is zero and division by zero panics for type A.
98
101
/// ```
102
+ /// use approx;
99
103
/// use ndarray::arr2;
100
104
/// use ndarray_stats::CorrelationExt;
105
+ /// use approx::AbsDiffEq;
101
106
///
102
107
/// let a = arr2(&[[1., 3., 5.],
103
108
/// [2., 4., 6.]]);
104
- /// let corr = a.pearson_correlation();
109
+ /// let corr = a.pearson_correlation().unwrap();
110
+ /// let epsilon = 1e-7;
105
111
/// assert!(
106
- /// corr.all_close (
112
+ /// corr.abs_diff_eq (
107
113
/// &arr2(&[
108
114
/// [1., 1.],
109
115
/// [1., 1.],
110
116
/// ]),
111
- /// 1e-7
117
+ /// epsilon
112
118
/// )
113
119
/// );
114
120
/// ```
115
- fn pearson_correlation ( & self ) -> Array2 < A >
121
+ fn pearson_correlation ( & self ) -> Result < Array2 < A > , EmptyInput >
116
122
where
117
123
A : Float + FromPrimitive ;
118
124
@@ -123,7 +129,7 @@ impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
123
129
where
124
130
S : Data < Elem = A > ,
125
131
{
126
- fn cov ( & self , ddof : A ) -> Array2 < A >
132
+ fn cov ( & self , ddof : A ) -> Result < Array2 < A > , EmptyInput >
127
133
where
128
134
A : Float + FromPrimitive ,
129
135
{
@@ -139,28 +145,37 @@ where
139
145
n_observations - ddof
140
146
} ;
141
147
let mean = self . mean_axis ( observation_axis) ;
142
- let denoised = self - & mean. insert_axis ( observation_axis) ;
143
- let covariance = denoised. dot ( & denoised. t ( ) ) ;
144
- covariance. mapv_into ( |x| x / dof)
148
+ match mean {
149
+ Some ( mean) => {
150
+ let denoised = self - & mean. insert_axis ( observation_axis) ;
151
+ let covariance = denoised. dot ( & denoised. t ( ) ) ;
152
+ Ok ( covariance. mapv_into ( |x| x / dof) )
153
+ }
154
+ None => Err ( EmptyInput ) ,
155
+ }
145
156
}
146
157
147
- fn pearson_correlation ( & self ) -> Array2 < A >
158
+ fn pearson_correlation ( & self ) -> Result < Array2 < A > , EmptyInput >
148
159
where
149
160
A : Float + FromPrimitive ,
150
161
{
151
- let observation_axis = Axis ( 1 ) ;
152
- // The ddof value doesn't matter, as long as we use the same one
153
- // for computing covariance and standard deviation
154
- // We choose -1 to avoid panicking when we only have one
155
- // observation per random variable (or no observations at all)
156
- let ddof = -A :: one ( ) ;
157
- let cov = self . cov ( ddof) ;
158
- let std = self
159
- . std_axis ( observation_axis, ddof)
160
- . insert_axis ( observation_axis) ;
161
- let std_matrix = std. dot ( & std. t ( ) ) ;
162
- // element-wise division
163
- cov / std_matrix
162
+ match self . dim ( ) {
163
+ ( n, m) if n > 0 && m > 0 => {
164
+ let observation_axis = Axis ( 1 ) ;
165
+ // The ddof value doesn't matter, as long as we use the same one
166
+ // for computing covariance and standard deviation
167
+ // We choose 0 as it is the smallest number admitted by std_axis
168
+ let ddof = A :: zero ( ) ;
169
+ let cov = self . cov ( ddof) . unwrap ( ) ;
170
+ let std = self
171
+ . std_axis ( observation_axis, ddof)
172
+ . insert_axis ( observation_axis) ;
173
+ let std_matrix = std. dot ( & std. t ( ) ) ;
174
+ // element-wise division
175
+ Ok ( cov / std_matrix)
176
+ }
177
+ _ => Err ( EmptyInput ) ,
178
+ }
164
179
}
165
180
166
181
private_impl ! { }
@@ -180,9 +195,10 @@ mod cov_tests {
180
195
let n_random_variables = 3 ;
181
196
let n_observations = 4 ;
182
197
let a = Array :: from_elem ( ( n_random_variables, n_observations) , value) ;
183
- a. cov ( 1. ) . all_close (
198
+ abs_diff_eq ! (
199
+ a. cov( 1. ) . unwrap( ) ,
184
200
& Array :: zeros( ( n_random_variables, n_random_variables) ) ,
185
- 1e-8 ,
201
+ epsilon = 1e-8 ,
186
202
)
187
203
}
188
204
@@ -194,8 +210,8 @@ mod cov_tests {
194
210
( n_random_variables, n_observations) ,
195
211
Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ,
196
212
) ;
197
- let covariance = a. cov ( 1. ) ;
198
- covariance . all_close ( & covariance. t ( ) , 1e-8 )
213
+ let covariance = a. cov ( 1. ) . unwrap ( ) ;
214
+ abs_diff_eq ! ( covariance , & covariance. t( ) , epsilon = 1e-8 )
199
215
}
200
216
201
217
#[ test]
@@ -205,31 +221,31 @@ mod cov_tests {
205
221
let n_observations = 4 ;
206
222
let a = Array :: random ( ( n_random_variables, n_observations) , Uniform :: new ( 0. , 10. ) ) ;
207
223
let invalid_ddof = ( n_observations as f64 ) + rand:: random :: < f64 > ( ) . abs ( ) ;
208
- a. cov ( invalid_ddof) ;
224
+ let _ = a. cov ( invalid_ddof) ;
209
225
}
210
226
211
227
#[ test]
212
228
fn test_covariance_zero_variables ( ) {
213
229
let a = Array2 :: < f32 > :: zeros ( ( 0 , 2 ) ) ;
214
230
let cov = a. cov ( 1. ) ;
215
- assert_eq ! ( cov. shape( ) , & [ 0 , 0 ] ) ;
231
+ assert ! ( cov. is_ok( ) ) ;
232
+ assert_eq ! ( cov. unwrap( ) . shape( ) , & [ 0 , 0 ] ) ;
216
233
}
217
234
218
235
#[ test]
219
236
fn test_covariance_zero_observations ( ) {
220
237
let a = Array2 :: < f32 > :: zeros ( ( 2 , 0 ) ) ;
221
238
// Negative ddof (-1 < 0) to avoid invalid-ddof panic
222
239
let cov = a. cov ( -1. ) ;
223
- assert_eq ! ( cov. shape( ) , & [ 2 , 2 ] ) ;
224
- cov. mapv ( |x| assert_eq ! ( x, 0. ) ) ;
240
+ assert_eq ! ( cov, Err ( EmptyInput ) ) ;
225
241
}
226
242
227
243
#[ test]
228
244
fn test_covariance_zero_variables_zero_observations ( ) {
229
245
let a = Array2 :: < f32 > :: zeros ( ( 0 , 0 ) ) ;
230
246
// Negative ddof (-1 < 0) to avoid invalid-ddof panic
231
247
let cov = a. cov ( -1. ) ;
232
- assert_eq ! ( cov. shape ( ) , & [ 0 , 0 ] ) ;
248
+ assert_eq ! ( cov, Err ( EmptyInput ) ) ;
233
249
}
234
250
235
251
#[ test]
@@ -255,7 +271,7 @@ mod cov_tests {
255
271
]
256
272
] ;
257
273
assert_eq ! ( a. ndim( ) , 2 ) ;
258
- assert ! ( a. cov( 1. ) . all_close ( & numpy_covariance, 1e-8 ) ) ;
274
+ assert_abs_diff_eq ! ( a. cov( 1. ) . unwrap ( ) , & numpy_covariance, epsilon = 1e-8 ) ;
259
275
}
260
276
261
277
#[ test]
@@ -264,7 +280,7 @@ mod cov_tests {
264
280
fn test_covariance_for_badly_conditioned_array ( ) {
265
281
let a: Array2 < f64 > = array ! [ [ 1e12 + 1. , 1e12 - 1. ] , [ 1e-6 + 1e-12 , 1e-6 - 1e-12 ] , ] ;
266
282
let expected_covariance = array ! [ [ 2. , 2e-12 ] , [ 2e-12 , 2e-24 ] ] ;
267
- assert ! ( a. cov( 1. ) . all_close ( & expected_covariance, 1e-24 ) ) ;
283
+ assert_abs_diff_eq ! ( a. cov( 1. ) . unwrap ( ) , & expected_covariance, epsilon = 1e-24 ) ;
268
284
}
269
285
}
270
286
@@ -284,8 +300,12 @@ mod pearson_correlation_tests {
284
300
( n_random_variables, n_observations) ,
285
301
Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ,
286
302
) ;
287
- let pearson_correlation = a. pearson_correlation ( ) ;
288
- pearson_correlation. all_close ( & pearson_correlation. t ( ) , 1e-8 )
303
+ let pearson_correlation = a. pearson_correlation ( ) . unwrap ( ) ;
304
+ abs_diff_eq ! (
305
+ pearson_correlation. view( ) ,
306
+ pearson_correlation. t( ) ,
307
+ epsilon = 1e-8
308
+ )
289
309
}
290
310
291
311
#[ quickcheck]
@@ -295,6 +315,7 @@ mod pearson_correlation_tests {
295
315
let a = Array :: from_elem ( ( n_random_variables, n_observations) , value) ;
296
316
let pearson_correlation = a. pearson_correlation ( ) ;
297
317
pearson_correlation
318
+ . unwrap ( )
298
319
. iter ( )
299
320
. map ( |x| x. is_nan ( ) )
300
321
. fold ( true , |acc, flag| acc & flag)
@@ -304,21 +325,21 @@ mod pearson_correlation_tests {
304
325
fn test_zero_variables ( ) {
305
326
let a = Array2 :: < f32 > :: zeros ( ( 0 , 2 ) ) ;
306
327
let pearson_correlation = a. pearson_correlation ( ) ;
307
- assert_eq ! ( pearson_correlation. shape ( ) , & [ 0 , 0 ] ) ;
328
+ assert_eq ! ( pearson_correlation, Err ( EmptyInput ) )
308
329
}
309
330
310
331
#[ test]
311
332
fn test_zero_observations ( ) {
312
333
let a = Array2 :: < f32 > :: zeros ( ( 2 , 0 ) ) ;
313
334
let pearson = a. pearson_correlation ( ) ;
314
- pearson . mapv ( |x| x . is_nan ( ) ) ;
335
+ assert_eq ! ( pearson , Err ( EmptyInput ) ) ;
315
336
}
316
337
317
338
#[ test]
318
339
fn test_zero_variables_zero_observations ( ) {
319
340
let a = Array2 :: < f32 > :: zeros ( ( 0 , 0 ) ) ;
320
341
let pearson = a. pearson_correlation ( ) ;
321
- assert_eq ! ( pearson. shape ( ) , & [ 0 , 0 ] ) ;
342
+ assert_eq ! ( pearson, Err ( EmptyInput ) ) ;
322
343
}
323
344
324
345
#[ test]
@@ -338,6 +359,10 @@ mod pearson_correlation_tests {
338
359
[ 0.1365648 , 0.38954398 , -0.17324776 , -0.8743213 , 1. ]
339
360
] ;
340
361
assert_eq ! ( a. ndim( ) , 2 ) ;
341
- assert ! ( a. pearson_correlation( ) . all_close( & numpy_corrcoeff, 1e-7 ) ) ;
362
+ assert_abs_diff_eq ! (
363
+ a. pearson_correlation( ) . unwrap( ) ,
364
+ numpy_corrcoeff,
365
+ epsilon = 1e-7
366
+ ) ;
342
367
}
343
368
}
0 commit comments