1
+ use ndarray:: prelude:: * ;
2
+ use ndarray:: Data ;
3
+ use num_traits:: { Float , FromPrimitive } ;
4
+
5
+ pub trait CorrelationExt < A , S >
6
+ where
7
+ S : Data < Elem = A > ,
8
+ {
9
+ /// Return the covariance matrix `C` for a 2-dimensional
10
+ /// array of observations `M`.
11
+ ///
12
+ /// Let `(r, o)` be the shape of `M`:
13
+ /// - `r` is the number of random variables;
14
+ /// - `o` is the number of observations we have collected
15
+ /// for each random variable.
16
+ ///
17
+ /// Every column in `M` is an experiment: a single observation for each
18
+ /// random variable.
19
+ /// Each row in `M` contains all the observations for a certain random variable.
20
+ ///
21
+ /// The parameter `ddof` specifies the "delta degrees of freedom". For
22
+ /// example, to calculate the population covariance, use `ddof = 0`, or to
23
+ /// calculate the sample covariance (unbiased estimate), use `ddof = 1`.
24
+ ///
25
+ /// The covariance of two random variables is defined as:
26
+ ///
27
+ /// ```text
28
+ /// 1 n
29
+ /// cov(X, Y) = ―――――――― ∑ (xᵢ - x̅)(yᵢ - y̅)
30
+ /// n - ddof i=1
31
+ /// ```
32
+ ///
33
+ /// where
34
+ ///
35
+ /// ```text
36
+ /// 1 n
37
+ /// x̅ = ― ∑ xᵢ
38
+ /// n i=1
39
+ /// ```
40
+ /// and similarly for ̅y.
41
+ ///
42
+ /// **Panics** if `ddof` is greater than or equal to the number of
43
+ /// observations, if `M` is emtpy or if the type cast of `n_observations`
44
+ /// from `usize` to `A` fails.
45
+ ///
46
+ /// # Example
47
+ ///
48
+ /// ```
49
+ /// extern crate ndarray;
50
+ /// extern crate ndarray_stats;
51
+ /// use ndarray::{aview2, arr2};
52
+ /// use ndarray_stats::CorrelationExt;
53
+ ///
54
+ /// let a = arr2(&[[1., 3., 5.],
55
+ /// [2., 4., 6.]]);
56
+ /// let covariance = a.cov(1.);
57
+ /// assert_eq!(
58
+ /// covariance,
59
+ /// aview2(&[[4., 4.], [4., 4.]])
60
+ /// );
61
+ /// ```
62
+ fn cov ( & self , ddof : A ) -> Array2 < A >
63
+ where
64
+ A : Float + FromPrimitive ;
65
+ }
66
+
67
+ impl < A : ' static , S > CorrelationExt < A , S > for ArrayBase < S , Ix2 >
68
+ where
69
+ S : Data < Elem = A > ,
70
+ {
71
+ fn cov ( & self , ddof : A ) -> Array2 < A >
72
+ where
73
+ A : Float + FromPrimitive ,
74
+ {
75
+ let observation_axis = Axis ( 1 ) ;
76
+ let n_observations = A :: from_usize ( self . len_of ( observation_axis) ) . unwrap ( ) ;
77
+ let dof =
78
+ if ddof >= n_observations {
79
+ panic ! ( "`ddof` needs to be strictly smaller than the \
80
+ number of observations provided for each \
81
+ random variable!")
82
+ } else {
83
+ n_observations - ddof
84
+ } ;
85
+ let mean = self . mean_axis ( observation_axis) ;
86
+ let denoised = self - & mean. insert_axis ( observation_axis) ;
87
+ let covariance = denoised. dot ( & denoised. t ( ) ) ;
88
+ covariance. mapv_into ( |x| x / dof)
89
+ }
90
+ }
91
+
92
#[cfg(test)]
mod tests {
    use super::*;
    use rand;
    use rand::distributions::Range;
    use ndarray_rand::RandomExt;

    quickcheck! {
        // A matrix whose entries are all equal has no spread, so every
        // entry of its covariance matrix should be (close to) zero.
        fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool {
            let (n_variables, n_samples) = (3, 4);
            let constant = Array::from_elem((n_variables, n_samples), value);
            let expected = Array::zeros((n_variables, n_variables));
            constant.cov(1.).all_close(&expected, 1e-8)
        }

        // The covariance matrix must be equal to its own transpose.
        fn covariance_matrix_is_symmetric(bound: f64) -> bool {
            let shape = (3, 4);
            let limit = bound.abs();
            let samples = Array::random(shape, Range::new(-limit, limit));
            let covariance = samples.cov(1.);
            covariance.all_close(&covariance.t(), 1e-8)
        }
    }

    #[test]
    #[should_panic]
    fn test_invalid_ddof() {
        let (n_variables, n_samples) = (3, 4);
        let samples = Array::random((n_variables, n_samples), Range::new(0., 10.));
        // `ddof` is at least the number of observations, which `cov` rejects.
        let excess = rand::random::<f64>().abs();
        let invalid_ddof = (n_samples as f64) + excess;
        samples.cov(invalid_ddof);
    }

    #[test]
    #[should_panic]
    fn test_empty_matrix() {
        let empty: Array2<f32> = array![[], []];
        // Negative ddof (-1 < 0) to avoid invalid-ddof panic
        empty.cov(-1.);
    }

    #[test]
    fn test_covariance_for_random_array() {
        let observations = array![
            [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
            [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
            [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
            [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
            [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258]
        ];
        // Reference covariance matrix the result is compared against.
        let numpy_covariance = array![
            [0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992],
            [0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555],
            [0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912],
            [0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557],
            [-0.06443992, -0.06715555, -0.06129912, -0.02355557, 0.09909855]
        ];
        assert_eq!(observations.ndim(), 2);
        let covariance = observations.cov(1.);
        assert!(covariance.all_close(&numpy_covariance, 1e-8));
    }

    #[test]
    #[should_panic]
    // We lose precision, hence the failing assert
    fn test_covariance_for_badly_conditioned_array() {
        let badly_conditioned: Array2<f64> = array![
            [1e12 + 1., 1e12 - 1.],
            [1e-6 + 1e-12, 1e-6 - 1e-12],
        ];
        let expected_covariance = array![
            [2., 2e-12], [2e-12, 2e-24]
        ];
        let covariance = badly_conditioned.cov(1.);
        assert!(covariance.all_close(&expected_covariance, 1e-24));
    }
}
0 commit comments