77 ErrCorrForms ,
88 RandomCorrelation ,
99 SystematicCorrelation ,
10+ ErrCorrMatrixCorrelation ,
1011)
1112from obsarray .test .test_unc_accessor import create_ds
1213
@@ -65,18 +66,26 @@ class BasicErrCorrForm(BaseErrCorrForm):
6566 form = "basic"
6667
6768 def build_matrix (self , sli ):
68- return None
69+ full_matrix = np .arange (144 ).reshape ((12 , 12 ))
70+ return self .slice_errcorr_matrix (full_matrix , (2 , 2 , 3 ), sli )
6971
7072 self .BasicErrCorrForm = BasicErrCorrForm
7173
74+ def test_get_varshape_errcorr (self ):
75+ basicerrcorr = self .BasicErrCorrForm (
76+ self .ds , "u_ran_temperature" , ["x" , "y" , "time" ], [], []
77+ )
78+ shape = basicerrcorr .get_varshape_errcorr ()
79+ np .testing .assert_equal (shape , (2 , 2 , 3 ))
80+
7281 def test_slice_full_cov_full (self ):
7382 basicerrcorr = self .BasicErrCorrForm (
7483 self .ds , "u_ran_temperature" , ["x" ], [], []
7584 )
7685
7786 full_matrix = np .arange (144 ).reshape ((12 , 12 ))
78- slice_matrix = basicerrcorr .slice_full_cov (
79- full_matrix , (slice (None ), slice (None ), slice (None ))
87+ slice_matrix = basicerrcorr .build_matrix (
88+ (slice (None ), slice (None ), slice (None ))
8089 )
8190
8291 np .testing .assert_equal (full_matrix , slice_matrix )
@@ -118,13 +127,13 @@ def test_get_sliced_shape_errcorr(self):
118127 shape = basicerrcorr .get_sliced_shape_errcorr ((slice (None ), 0 , slice (0 , 2 , 1 )))
119128 assert shape == (2 , 2 )
120129
121- def test_slice_flattened_matrix (self ):
130+ def test_slice_errcorr_matrix (self ):
122131 basicerrcorr = self .BasicErrCorrForm (
123- self .ds , "u_ran_temperature" , ["x" ], [], []
132+ self .ds , "u_ran_temperature" , ["x" , "y" , "z" ], [], []
124133 )
125134
126135 full_matrix = np .arange (144 ).reshape ((12 , 12 ))
127- slice_matrix = basicerrcorr .slice_flattened_matrix (
136+ slice_matrix = basicerrcorr .slice_errcorr_matrix (
128137 full_matrix , (2 , 2 , 3 ), (slice (None ), slice (None ), 0 )
129138 )
130139
@@ -136,13 +145,10 @@ def test_slice_flattened_matrix(self):
136145
137146 def test_slice_full_cov_slice (self ):
138147 basicerrcorr = self .BasicErrCorrForm (
139- self .ds , "u_ran_temperature" , ["x" ], [], []
148+ self .ds , "u_ran_temperature" , ["x" , "y" , "z" ], [], []
140149 )
141150
142- full_matrix = np .arange (144 ).reshape ((12 , 12 ))
143- slice_matrix = basicerrcorr .slice_full_cov (
144- full_matrix , (slice (None ), slice (None ), 0 )
145- )
151+ slice_matrix = basicerrcorr .build_dot_matrix ((slice (None ), slice (None ), 0 ))
146152
147153 exp_slice_matrix = np .array (
148154 [[0 , 3 , 6 , 9 ], [36 , 39 , 42 , 45 ], [72 , 75 , 78 , 81 ], [108 , 111 , 114 , 117 ]]
@@ -206,5 +212,40 @@ def test_build_dot_matrix(self):
206212 np .testing .assert_equal ((x .dot (y )).dot (time ), np .ones ((12 , 12 )))
207213
208214
215+ class TestErrCorrMatrixCorrelation (unittest .TestCase ):
216+ def setUp (self ) -> None :
217+ self .ds = create_ds ()
218+
219+ def test_build_matrix_full (self ):
220+ ec = ErrCorrMatrixCorrelation (
221+ self .ds ,
222+ "u_str_temperature" ,
223+ ["x" , "time" ],
224+ ["err_corr_str_temperature" ],
225+ [],
226+ )
227+
228+ ecrm = ec .build_matrix ((slice (None ), slice (None ), slice (None )))
229+ np .testing .assert_equal (ecrm , np .ones ((6 , 6 )))
230+
231+ def test_build_matrix_sliced (self ):
232+ ec = ErrCorrMatrixCorrelation (
233+ self .ds ,
234+ "u_str_temperature" ,
235+ ["x" , "time" ],
236+ ["err_corr_str_temperature" ],
237+ [],
238+ )
239+
240+ ecrm = ec .build_matrix ((0 , slice (None ), slice (None )))
241+ np .testing .assert_equal (ecrm , np .ones ((3 , 3 )))
242+
243+ ecrm = ec .build_matrix ((slice (None ), 0 , slice (None )))
244+ np .testing .assert_equal (ecrm , np .ones ((6 , 6 )))
245+
246+ ecrm = ec .build_matrix ((slice (None ), slice (None ), 0 ))
247+ np .testing .assert_equal (ecrm , np .ones ((2 , 2 )))
248+
249+
209250if __name__ == "main" :
210251 unittest .main ()
0 commit comments