From 0a70ccca4880fbcd088b905e29a19a8d181a219f Mon Sep 17 00:00:00 2001
From: Michael Milton <ttmigueltt@gmail.com>
Date: Wed, 4 Aug 2021 00:02:53 +1000
Subject: [PATCH 1/4] Define and implement a cov_to_corr() method in the
 CorrelationExt trait

---
 src/correlation.rs | 87 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 86 insertions(+), 1 deletion(-)

diff --git a/src/correlation.rs b/src/correlation.rs
index 5ae194b..9592f55 100644
--- a/src/correlation.rs
+++ b/src/correlation.rs
@@ -1,7 +1,9 @@
 use crate::errors::EmptyInput;
 use ndarray::prelude::*;
-use ndarray::Data;
+use ndarray::{Data, ShapeError};
 use num_traits::{Float, FromPrimitive};
+use itertools::Itertools;
+use std::error::Error;
 
 /// Extension trait for `ArrayBase` providing functions
 /// to compute different correlation measures.
@@ -123,6 +125,22 @@ where
         A: Float + FromPrimitive;
 
     private_decl! {}
+
+    /// Given that self is a covariance matrix, returns the appropriate correlation matrix
+    ///
+    /// # Example
+    /// ```
+    /// use::ndarray::array;
+    /// use ndarray_stats::CorrelationExt;
+    ///
+    /// let a = array![[1., 3., 5.],
+    ///                [2., 4., 6.]];
+    /// let covariance = a.cov(1.).unwrap();
+    /// let corr = a.pearson_correlation().unwrap();
+    /// assert_eq!(covariance.cov_to_corr().unwrap(), corr);
+    /// ```
+    fn cov_to_corr(&self) -> Result<Array2<A>, Box<dyn Error>>
+        where A: Float + FromPrimitive;
 }
 
 impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
@@ -179,6 +197,27 @@ where
     }
 
     private_impl! {}
+
+    fn cov_to_corr<'a>(&self) -> Result<Array2<A>, Box<dyn Error>>
+    where A: Float + FromPrimitive
+    {
+        if !self.is_square() {
+            return Err("A covariance matrix must be square".into());
+        }
+
+        let vals = self
+            .indexed_iter()
+            .map(|((x, y), v)| {
+                // rho_ij = sigma_ij / sqrt(sigma_ii * sigma_jj)
+                *v / (self[[x, x]] * self[[y, y]]).powf(A::from_f64(0.5).unwrap())
+            })
+            .collect_vec();
+
+        match Array2::from_shape_vec(self.raw_dim(), vals){
+            Ok(x) => Ok(x),
+            Err(e) => Err(Box::new(e))
+        }
+    }
 }
 
 #[cfg(test)]
@@ -274,6 +313,7 @@ mod cov_tests {
         assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8);
     }
 
+
     #[test]
     #[should_panic]
     // We lose precision, hence the failing assert
@@ -367,3 +407,48 @@ mod pearson_correlation_tests {
         );
     }
 }
+
+#[cfg(test)]
+mod cov_to_corr_tests{
+    use ndarray::array;
+    use super::*;
+
+    #[test]
+    fn test_cov_2_corr_known(){
+        // Very basic maths that can be done in your head
+        let cov = array![
+            [ 4., 1. ],
+            [ 3., 4. ],
+        ];
+        assert_eq!(cov.cov_to_corr().unwrap(), array![
+            [1., 0.25],
+            [0.75, 1.]
+        ])
+    }
+
+    #[test]
+    fn test_cov_2_corr_failure(){
+        // A 1D array can't be a covariance matrix
+        let cov = array![
+            [ 4., 1. ],
+        ];
+        cov.cov_to_corr().unwrap_err();
+    }
+
+    #[test]
+    fn test_cov_2_corr_random(){
+        let a = array![
+            [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
+            [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
+            [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
+            [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
+            [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258]
+        ];
+
+        // Calculating cov, and then corr should always be equivalent to calculating corr directly
+        let cov = a.cov(1.).unwrap();
+        let corr = a.pearson_correlation().unwrap();
+        let corr_2 = cov.cov_to_corr().unwrap();
+        assert!(corr.abs_diff_eq(&corr_2, 0.001));
+    }
+}
\ No newline at end of file

From 3a3778e107ed588489605f3e84583b8d660ae9a6 Mon Sep 17 00:00:00 2001
From: Michael Milton <ttmigueltt@gmail.com>
Date: Wed, 4 Aug 2021 00:04:45 +1000
Subject: [PATCH 2/4] Fix cargo check

---
 src/correlation.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/correlation.rs b/src/correlation.rs
index 9592f55..b330dd2 100644
--- a/src/correlation.rs
+++ b/src/correlation.rs
@@ -1,6 +1,6 @@
 use crate::errors::EmptyInput;
 use ndarray::prelude::*;
-use ndarray::{Data, ShapeError};
+use ndarray::{Data};
 use num_traits::{Float, FromPrimitive};
 use itertools::Itertools;
 use std::error::Error;

From b3f11ac9fd19e4b11b01f5b269c419bd96860eba Mon Sep 17 00:00:00 2001
From: Michael Milton <ttmigueltt@gmail.com>
Date: Wed, 4 Aug 2021 00:05:40 +1000
Subject: [PATCH 3/4] rustfmt

---
 src/correlation.rs | 41 +++++++++++++++++------------------------
 1 file changed, 17 insertions(+), 24 deletions(-)

diff --git a/src/correlation.rs b/src/correlation.rs
index b330dd2..5ac1ae8 100644
--- a/src/correlation.rs
+++ b/src/correlation.rs
@@ -1,8 +1,8 @@
 use crate::errors::EmptyInput;
+use itertools::Itertools;
 use ndarray::prelude::*;
-use ndarray::{Data};
+use ndarray::Data;
 use num_traits::{Float, FromPrimitive};
-use itertools::Itertools;
 use std::error::Error;
 
 /// Extension trait for `ArrayBase` providing functions
@@ -140,7 +140,8 @@ where
     /// assert_eq!(covariance.cov_to_corr().unwrap(), corr);
     /// ```
     fn cov_to_corr(&self) -> Result<Array2<A>, Box<dyn Error>>
-        where A: Float + FromPrimitive;
+    where
+        A: Float + FromPrimitive;
 }
 
 impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
@@ -199,7 +200,8 @@ where
     private_impl! {}
 
     fn cov_to_corr<'a>(&self) -> Result<Array2<A>, Box<dyn Error>>
-    where A: Float + FromPrimitive
+    where
+        A: Float + FromPrimitive,
     {
         if !self.is_square() {
             return Err("A covariance matrix must be square".into());
@@ -213,9 +215,9 @@ where
             })
             .collect_vec();
 
-        match Array2::from_shape_vec(self.raw_dim(), vals){
+        match Array2::from_shape_vec(self.raw_dim(), vals) {
             Ok(x) => Ok(x),
-            Err(e) => Err(Box::new(e))
+            Err(e) => Err(Box::new(e)),
         }
     }
 }
@@ -313,7 +315,6 @@ mod cov_tests {
         assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8);
     }
 
-
     #[test]
     #[should_panic]
     // We lose precision, hence the failing assert
@@ -409,34 +410,26 @@ mod pearson_correlation_tests {
 }
 
 #[cfg(test)]
-mod cov_to_corr_tests{
-    use ndarray::array;
+mod cov_to_corr_tests {
     use super::*;
+    use ndarray::array;
 
     #[test]
-    fn test_cov_2_corr_known(){
+    fn test_cov_2_corr_known() {
         // Very basic maths that can be done in your head
-        let cov = array![
-            [ 4., 1. ],
-            [ 3., 4. ],
-        ];
-        assert_eq!(cov.cov_to_corr().unwrap(), array![
-            [1., 0.25],
-            [0.75, 1.]
-        ])
+        let cov = array![[4., 1.], [3., 4.],];
+        assert_eq!(cov.cov_to_corr().unwrap(), array![[1., 0.25], [0.75, 1.]])
     }
 
     #[test]
-    fn test_cov_2_corr_failure(){
+    fn test_cov_2_corr_failure() {
         // A 1D array can't be a covariance matrix
-        let cov = array![
-            [ 4., 1. ],
-        ];
+        let cov = array![[4., 1.],];
         cov.cov_to_corr().unwrap_err();
     }
 
     #[test]
-    fn test_cov_2_corr_random(){
+    fn test_cov_2_corr_random() {
         let a = array![
             [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
             [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
@@ -451,4 +444,4 @@ mod cov_to_corr_tests{
         let corr_2 = cov.cov_to_corr().unwrap();
         assert!(corr.abs_diff_eq(&corr_2, 0.001));
     }
-}
\ No newline at end of file
+}

From 43cf9806427fecbd03c11ec7b9c5200631468ad7 Mon Sep 17 00:00:00 2001
From: Michael Milton <ttmigueltt@gmail.com>
Date: Wed, 4 Aug 2021 00:07:35 +1000
Subject: [PATCH 4/4] Don't use itertools

---
 src/correlation.rs | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/correlation.rs b/src/correlation.rs
index 5ac1ae8..9ab8e5e 100644
--- a/src/correlation.rs
+++ b/src/correlation.rs
@@ -1,5 +1,4 @@
 use crate::errors::EmptyInput;
-use itertools::Itertools;
 use ndarray::prelude::*;
 use ndarray::Data;
 use num_traits::{Float, FromPrimitive};
@@ -213,7 +212,7 @@ where
                 // rho_ij = sigma_ij / sqrt(sigma_ii * sigma_jj)
                 *v / (self[[x, x]] * self[[y, y]]).powf(A::from_f64(0.5).unwrap())
             })
-            .collect_vec();
+            .collect();
 
         match Array2::from_shape_vec(self.raw_dim(), vals) {
             Ok(x) => Ok(x),