Skip to content

Commit d75844e

Browse files
Covariance (#2)
* Added new file. * Added correlation module to lib.rs * Added stub for covariance method. * Implement signature for covariance, alongside first failing test. * Added one test for panic if working on a 1d array. * Check number of dimensions before proceeding. Panic if invalid. * First implementation of covariance for 2-dimensional arrays. * Improved test using all_close. * Added one test with a random array. * Added a new test to check covariance_matrix is symmetric. * Added another test to check for panic when passing an invalid ddof. * Using quickcheck to test the symmetry property for covariance matrices. * Moved constant matrix test under quickcheck to generalize on the constant value. * Added docs to cov. Published the CorrelationExt crate. * Added another reason for panic to the docs. * Added one more test and one more reason to panic to the docs. * Added one more test for a badly conditioned array.
1 parent fcc60cf commit d75844e

File tree

3 files changed

+192
-0
lines changed

3 files changed

+192
-0
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ rand = "0.5"
1111

1212
[dev-dependencies]
1313
quickcheck = "0.7"
14+
ndarray-rand = "0.8"
1415

1516
[patch.crates-io]
1617
noisy_float = { git = "https://github.com/SergiusIW/noisy_float-rs.git", rev = "c33a94803987475bbd205c9ff5a697af533f9a17" }

src/correlation.rs

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
use ndarray::prelude::*;
2+
use ndarray::Data;
3+
use num_traits::{Float, FromPrimitive};
4+
5+
pub trait CorrelationExt<A, S>
6+
where
7+
S: Data<Elem = A>,
8+
{
9+
/// Return the covariance matrix `C` for a 2-dimensional
10+
/// array of observations `M`.
11+
///
12+
/// Let `(r, o)` be the shape of `M`:
13+
/// - `r` is the number of random variables;
14+
/// - `o` is the number of observations we have collected
15+
/// for each random variable.
16+
///
17+
/// Every column in `M` is an experiment: a single observation for each
18+
/// random variable.
19+
/// Each row in `M` contains all the observations for a certain random variable.
20+
///
21+
/// The parameter `ddof` specifies the "delta degrees of freedom". For
22+
/// example, to calculate the population covariance, use `ddof = 0`, or to
23+
/// calculate the sample covariance (unbiased estimate), use `ddof = 1`.
24+
///
25+
/// The covariance of two random variables is defined as:
26+
///
27+
/// ```text
28+
/// 1 n
29+
/// cov(X, Y) = ―――――――― ∑ (xᵢ - x̅)(yᵢ - y̅)
30+
/// n - ddof i=1
31+
/// ```
32+
///
33+
/// where
34+
///
35+
/// ```text
36+
/// 1 n
37+
/// x̅ = ― ∑ xᵢ
38+
/// n i=1
39+
/// ```
40+
/// and similarly for ̅y.
41+
///
42+
/// **Panics** if `ddof` is greater than or equal to the number of
43+
/// observations, if `M` is emtpy or if the type cast of `n_observations`
44+
/// from `usize` to `A` fails.
45+
///
46+
/// # Example
47+
///
48+
/// ```
49+
/// extern crate ndarray;
50+
/// extern crate ndarray_stats;
51+
/// use ndarray::{aview2, arr2};
52+
/// use ndarray_stats::CorrelationExt;
53+
///
54+
/// let a = arr2(&[[1., 3., 5.],
55+
/// [2., 4., 6.]]);
56+
/// let covariance = a.cov(1.);
57+
/// assert_eq!(
58+
/// covariance,
59+
/// aview2(&[[4., 4.], [4., 4.]])
60+
/// );
61+
/// ```
62+
fn cov(&self, ddof: A) -> Array2<A>
63+
where
64+
A: Float + FromPrimitive;
65+
}
66+
67+
impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
68+
where
69+
S: Data<Elem = A>,
70+
{
71+
fn cov(&self, ddof: A) -> Array2<A>
72+
where
73+
A: Float + FromPrimitive,
74+
{
75+
let observation_axis = Axis(1);
76+
let n_observations = A::from_usize(self.len_of(observation_axis)).unwrap();
77+
let dof =
78+
if ddof >= n_observations {
79+
panic!("`ddof` needs to be strictly smaller than the \
80+
number of observations provided for each \
81+
random variable!")
82+
} else {
83+
n_observations - ddof
84+
};
85+
let mean = self.mean_axis(observation_axis);
86+
let denoised = self - &mean.insert_axis(observation_axis);
87+
let covariance = denoised.dot(&denoised.t());
88+
covariance.mapv_into(|x| x / dof)
89+
}
90+
}
91+
92+
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::RandomExt;
    use rand;
    use rand::distributions::Range;

    quickcheck! {
        // Three random variables that take the same value in all four
        // observations: every entry of the covariance matrix should be
        // (close to) zero.
        fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool {
            let (rows, cols) = (3, 4);
            let constant = Array::from_elem((rows, cols), value);
            let all_zeros = Array::zeros((rows, rows));
            constant.cov(1.).all_close(&all_zeros, 1e-8)
        }

        // A covariance matrix must coincide with its own transpose.
        fn covariance_matrix_is_symmetric(bound: f64) -> bool {
            let (rows, cols) = (3, 4);
            let samples = Array::random(
                (rows, cols),
                Range::new(-bound.abs(), bound.abs())
            );
            let computed = samples.cov(1.);
            computed.all_close(&computed.t(), 1e-8)
        }
    }

    #[test]
    #[should_panic]
    fn test_invalid_ddof() {
        let (rows, cols) = (3, 4);
        let samples = Array::random((rows, cols), Range::new(0., 10.));
        // Number of observations plus a non-negative random offset: always
        // greater than or equal to the number of observations.
        let offset = rand::random::<f64>().abs();
        let invalid_ddof = (cols as f64) + offset;
        samples.cov(invalid_ddof);
    }

    #[test]
    #[should_panic]
    fn test_empty_matrix() {
        let empty: Array2<f32> = array![[], []];
        // Negative ddof (-1 < 0), so the invalid-ddof check is not what
        // satisfies `should_panic` here.
        empty.cov(-1.);
    }

    #[test]
    fn test_covariance_for_random_array() {
        let observations = array![
            [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
            [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
            [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
            [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
            [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258]
        ];
        // Reference values for the same input, as named by the original
        // `numpy_covariance` binding.
        let reference = array![
            [0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992],
            [0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555],
            [0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912],
            [0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557],
            [-0.06443992, -0.06715555, -0.06129912, -0.02355557, 0.09909855]
        ];
        assert_eq!(observations.ndim(), 2);
        let computed = observations.cov(1.);
        assert!(computed.all_close(&reference, 1e-8));
    }

    #[test]
    #[should_panic]
    // Precision is lost on this input, so the closeness assertion below is
    // expected to fail.
    fn test_covariance_for_badly_conditioned_array() {
        let ill_conditioned: Array2<f64> = array![
            [1e12 + 1., 1e12 - 1.],
            [1e-6 + 1e-12, 1e-6 - 1e-12],
        ];
        let exact = array![[2., 2e-12], [2e-12, 2e-24]];
        let computed = ill_conditioned.cov(1.);
        assert!(computed.all_close(&exact, 1e-24));
    }
}

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
#[macro_use(azip, s)]
2+
#[cfg_attr(test, macro_use(array))]
23
extern crate ndarray;
34
extern crate noisy_float;
45
extern crate num_traits;
56
extern crate rand;
67

8+
#[cfg(test)]
9+
extern crate ndarray_rand;
710
#[cfg(test)]
811
#[macro_use(quickcheck)]
912
extern crate quickcheck;
1013

1114
pub use maybe_nan::{MaybeNan, MaybeNanExt};
1215
pub use quantile::{interpolate, QuantileExt};
1316
pub use sort::Sort1dExt;
17+
pub use correlation::CorrelationExt;
1418

1519
mod maybe_nan;
1620
mod quantile;
1721
mod sort;
22+
mod correlation;

0 commit comments

Comments
 (0)