11use crate :: errors:: EmptyInput ;
22use ndarray:: prelude:: * ;
3- use ndarray:: Data ;
43use num_traits:: { Float , FromPrimitive } ;
54
6- /// Extension trait for `ArrayBase ` providing functions
5+ /// Extension trait for `ndarray ` providing functions
76/// to compute different correlation measures.
8- pub trait CorrelationExt < A , S >
9- where
10- S : Data < Elem = A > ,
11- {
7+ pub trait CorrelationExt < A > {
128 /// Return the covariance matrix `C` for a 2-dimensional
139 /// array of observations `M`.
1410 ///
@@ -125,10 +121,7 @@ where
125121 private_decl ! { }
126122}
127123
128- impl < A : ' static , S > CorrelationExt < A , S > for ArrayBase < S , Ix2 >
129- where
130- S : Data < Elem = A > ,
131- {
124+ impl < A : ' static > CorrelationExt < A > for ArrayRef2 < A > {
132125 fn cov ( & self , ddof : A ) -> Result < Array2 < A > , EmptyInput >
133126 where
134127 A : Float + FromPrimitive ,
@@ -147,7 +140,7 @@ where
147140 let mean = self . mean_axis ( observation_axis) ;
148141 match mean {
149142 Some ( mean) => {
150- let denoised = self - & mean. insert_axis ( observation_axis) ;
143+ let denoised = self - mean. insert_axis ( observation_axis) ;
151144 let covariance = denoised. dot ( & denoised. t ( ) ) ;
152145 Ok ( covariance. mapv_into ( |x| x / dof) )
153146 }
@@ -208,7 +201,7 @@ mod cov_tests {
208201 let n_observations = 4 ;
209202 let a = Array :: random (
210203 ( n_random_variables, n_observations) ,
211- Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ,
204+ Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) . unwrap ( ) ,
212205 ) ;
213206 let covariance = a. cov ( 1. ) . unwrap ( ) ;
214207 abs_diff_eq ! ( covariance, & covariance. t( ) , epsilon = 1e-8 )
@@ -219,7 +212,10 @@ mod cov_tests {
219212 fn test_invalid_ddof ( ) {
220213 let n_random_variables = 3 ;
221214 let n_observations = 4 ;
222- let a = Array :: random ( ( n_random_variables, n_observations) , Uniform :: new ( 0. , 10. ) ) ;
215+ let a = Array :: random (
216+ ( n_random_variables, n_observations) ,
217+ Uniform :: new ( 0. , 10. ) . unwrap ( ) ,
218+ ) ;
223219 let invalid_ddof = ( n_observations as f64 ) + rand:: random :: < f64 > ( ) . abs ( ) ;
224220 let _ = a. cov ( invalid_ddof) ;
225221 }
@@ -299,7 +295,7 @@ mod pearson_correlation_tests {
299295 let n_observations = 4 ;
300296 let a = Array :: random (
301297 ( n_random_variables, n_observations) ,
302- Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ,
298+ Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) . unwrap ( ) ,
303299 ) ;
304300 let pearson_correlation = a. pearson_correlation ( ) . unwrap ( ) ;
305301 abs_diff_eq ! (
0 commit comments