@@ -51,14 +51,24 @@ def check_nan_values(self, dataset):
5151 "check that there are no NaN values in any variable (they should be fill values instead)"
5252 nan_vars = [(name , "contains NaN values" )
5353 for name , var in dataset .variables .items ()
54- if var .dtype in (np .dtype ('float32' ), np .dtype ('float64' )) and any ( np .isnan (var [:]))
54+ if var .dtype in (np .dtype ('float32' ), np .dtype ('float64' )) and np .isnan (var [:]). any ( )
5555 ]
5656 self .assertEqual ([], nan_vars )
5757
5858 def compare_variables (self , dataset , skip_vars = ('source_file' , 'instrument_id' )):
5959 """Compare dimensions and values of all variables in dataset with those in self.EXPECTED_OUTPUT_FILE,
6060 except for variables listed in skip_vars.
6161 """
62+
63+ def _arrays_equal (testvar , expected ):
64+ """compare two numpy arrays, handling the case of scalar variables"""
65+ if expected .shape == ():
66+ if np .isclose (testvar , expected ):
67+ return True
68+ elif (np .isclose (testvar , expected )).all ():
69+ return True
70+ return False
71+
6272 differences = []
6373 with Dataset (self .EXPECTED_OUTPUT_FILE ) as expected :
6474 for var in set (expected .variables .keys ()) - set (skip_vars ):
@@ -68,7 +78,7 @@ def compare_variables(self, dataset, skip_vars=('source_file', 'instrument_id'))
6878 differences .append ((var , "shapes differ" ))
6979
7080 # compare the raw data arrays (not the masked_array)
71- if not all ( np . isclose ( dataset [var ][:].data , expected [var ][:].data ) ):
81+ if not _arrays_equal ( dataset [var ][:].data , expected [var ][:].data ):
7282 differences .append ((var , "variable values differ" ))
7383
7484 self .assertEqual ([], differences )
0 commit comments