Skip to content

Commit c669d38

Browse files
committed
refactor comparison with expected file in timeseries_products unittests
1 parent 0b6af28 commit c669d38

5 files changed

Lines changed: 85 additions & 60 deletions

File tree

test_aodntools/base_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
import tempfile
33
import unittest
44

5+
import numpy as np
6+
from netCDF4 import Dataset
7+
58

69
class BaseTestCase(unittest.TestCase):
10+
EXPECTED_OUTPUT_FILE = None
711

812
@property
913
def temp_dir(self):
@@ -22,3 +26,49 @@ def temp_nc_file(self):
2226
def tearDown(self):
    """Remove the temporary directory recorded in ``self._temp_dir``, if any."""
    # Nothing to clean up when no temporary directory was recorded on this instance.
    if not hasattr(self, '_temp_dir'):
        return
    shutil.rmtree(self._temp_dir)
29+
30+
def compare_global_attributes(self, dataset,
                              attrs=('geospatial_lat_max', 'geospatial_lat_min',
                                     'geospatial_lon_max', 'geospatial_lon_min',
                                     'geospatial_vertical_max', 'geospatial_vertical_min',
                                     'time_coverage_start', 'time_coverage_end'
                                     )
                              ):
    """Compare global attributes of the given dataset with those in self.EXPECTED_OUTPUT_FILE.

    :param dataset: open dataset object providing ``getncattr(name)``
    :param attrs: names of the global attributes to compare

    All mismatches are collected as ``(attribute, message)`` tuples and reported
    together through a single ``assertEqual`` so one run shows every difference.
    """
    not_matching = []
    with Dataset(self.EXPECTED_OUTPUT_FILE) as expected:
        for attr in attrs:
            found_value = dataset.getncattr(attr)
            expected_value = expected.getncattr(attr)
            if found_value != expected_value:
                # Report the value from the expected file, not the found value twice.
                not_matching.append((attr,
                                     "expected: {exp}; found: {found}".format(exp=expected_value,
                                                                              found=found_value)
                                     ))

    self.assertEqual([], not_matching)
49+
50+
def check_nan_values(self, dataset):
    """Check that there are no NaN values in any variable (they should be fill values instead).

    :param dataset: object whose ``variables`` mapping yields ``(name, variable)`` pairs;
        each variable exposes ``dtype`` and supports ``variable[:]``

    Only float32/float64 variables are inspected. Offending variables are collected
    as ``(name, message)`` tuples and reported through a single ``assertEqual``.
    """
    float_dtypes = (np.dtype('float32'), np.dtype('float64'))
    nan_vars = []
    for name, var in dataset.variables.items():
        if var.dtype not in float_dtypes:
            continue
        # Reduce over the whole array: the builtin any() iterates the first axis only
        # and raises on multi-dimensional data. Masked entries count as "not NaN".
        nan_flags = np.ma.filled(np.isnan(var[:]), False)
        if nan_flags.any():
            nan_vars.append((name, "contains NaN values"))
    self.assertEqual([], nan_vars)
57+
58+
def compare_variables(self, dataset, skip_vars=('source_file', 'instrument_id')):
    """Compare dimensions and values of all variables in dataset with those in self.EXPECTED_OUTPUT_FILE,
    except for variables listed in skip_vars.

    :param dataset: open dataset object supporting ``dataset.variables`` and ``dataset[name]``
    :param skip_vars: names of variables to leave out of the comparison

    Differences are collected as ``(variable, message)`` tuples and reported through
    a single ``assertEqual``.
    """
    differences = []
    with Dataset(self.EXPECTED_OUTPUT_FILE) as expected:
        # Sorted so the failure report is stable from run to run.
        for var in sorted(set(expected.variables.keys()) - set(skip_vars)):
            if var not in dataset.variables:
                # Report a missing variable rather than erroring on the lookup below.
                differences.append((var, "missing from dataset"))
                continue

            found_var = dataset[var]
            expected_var = expected[var]
            if found_var.dimensions != expected_var.dimensions:
                differences.append((var, "dimensions differ"))
            if found_var.shape != expected_var.shape:
                differences.append((var, "shapes differ"))
                # An element-wise comparison is meaningless for different shapes.
                continue

            # compare the raw data arrays (not the masked_array); array_equal
            # reduces over every axis, so multi-dimensional variables work too
            found_data = np.ma.getdata(found_var[:])
            expected_data = np.ma.getdata(expected_var[:])
            if not np.array_equal(found_data, expected_data):
                differences.append((var, "variable values differ"))

    self.assertEqual([], differences)

test_aodntools/timeseries_products/test_aggregated_timeseries.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import os
44
import unittest
55

6-
import numpy as np
76
from netCDF4 import Dataset, chartostring
87

98
from aodntools import __version__
@@ -19,12 +18,13 @@
1918
'IMOS_ANMN-NRS_BCKOSTUZ_20181213T080038Z_NRSROT_FV01_NRSROT-1812-WQM-55_END-20181215T013118Z_C-20190828T000000Z.nc',
2019
BAD_FILE
2120
]
22-
EXPECTED_OUTPUT_FILE = os.path.join(
23-
TEST_ROOT, 'IMOS_ANMN-NRS_TZ_20181213_NRSROT_FV01_TEMP-aggregated-timeseries_END-20190523_C-20220607.nc'
24-
)
2521

2622

2723
class TestAggregatedTimeseries(BaseTestCase):
24+
EXPECTED_OUTPUT_FILE = os.path.join(
25+
TEST_ROOT, 'IMOS_ANMN-NRS_TZ_20181213_NRSROT_FV01_TEMP-aggregated-timeseries_END-20190523_C-20220607.nc'
26+
)
27+
2828
def test_main_aggregator(self):
2929
output_file, bad_files = main_aggregator(INPUT_FILES, 'TEMP', 'NRSROT', input_dir=TEST_ROOT,
3030
output_dir='/tmp')
@@ -70,29 +70,11 @@ def test_main_aggregator(self):
7070
self.assertIn(__version__, dataset.lineage)
7171
self.assertIn(BAD_FILE, dataset.rejected_files)
7272

73-
compare_attrs = ('Conventions', 'feature_type', 'author', 'author_email', 'file_version',
74-
'geospatial_lat_max', 'geospatial_lat_min', 'geospatial_lon_max', 'geospatial_lon_min',
75-
'geospatial_vertical_max', 'geospatial_vertical_min', 'naming_authority', 'project',
76-
'time_coverage_start', 'time_coverage_end'
77-
)
78-
expected = Dataset(EXPECTED_OUTPUT_FILE)
79-
for attr in compare_attrs:
80-
self.assertEqual(dataset.getncattr(attr), expected.getncattr(attr))
81-
82-
# check that there are no NaN values in any variable (they should be fill values instead)
83-
nan_vars = [name
84-
for name, var in dataset.variables.items()
85-
if var.dtype in (np.dtype('float32'), np.dtype('float64')) and any(np.isnan(var[:]))
86-
]
87-
self.assertEqual([], nan_vars)
88-
89-
# check aggregated variable values
90-
non_match_vars = []
91-
for var in set(expected.variables.keys()) - string_vars:
92-
# compare the raw data arrays (not the masked_array)
93-
if not all(dataset[var][:].data == expected[var][:].data):
94-
non_match_vars.append(var)
95-
self.assertEqual([], non_match_vars)
73+
self.compare_global_attributes(dataset)
74+
75+
self.check_nan_values(dataset)
76+
77+
self.compare_variables(dataset)
9678

9779
def test_source_file_attributes(self):
9880
output_file, bad_files = main_aggregator(INPUT_FILES, 'PSAL', 'NRSROT', input_dir=TEST_ROOT,

test_aodntools/timeseries_products/test_hourly_timeseries.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
BAD_FILE
2121
]
2222
INPUT_PATHS = [os.path.join(TEST_ROOT, f) for f in INPUT_FILES]
23-
EXPECTED_OUTPUT_FILE = os.path.join(
24-
TEST_ROOT, 'IMOS_ANMN-NRS_STZ_20181213_NRSROT_FV02_hourly-timeseries_END-20190523_C-20220428.nc'
25-
)
2623

2724
INST_VARIABLES = {'instrument_id', 'source_file', 'LONGITUDE', 'LATITUDE', 'NOMINAL_DEPTH'}
2825
OBS_VARIABLES = {'instrument_index', 'TIME'}
@@ -49,6 +46,10 @@
4946

5047

5148
class TestHourlyTimeseries(BaseTestCase):
49+
EXPECTED_OUTPUT_FILE = os.path.join(
50+
TEST_ROOT, 'IMOS_ANMN-NRS_STZ_20181213_NRSROT_FV02_hourly-timeseries_END-20190523_C-20220428.nc'
51+
)
52+
5253
def test_hourly_aggregator(self):
5354
output_file, bad_files = hourly_aggregator(files_to_aggregate=INPUT_PATHS,
5455
site_code='NRSROT',
@@ -87,15 +88,11 @@ def test_hourly_aggregator(self):
8788
self.assertIn('hourly_timeseries.py', dataset.lineage)
8889
self.assertIn(BAD_FILE, dataset.rejected_files)
8990

90-
# check variable values
91-
expected = Dataset(EXPECTED_OUTPUT_FILE)
92-
self.assertEqual(len(expected['TIME']), len(dataset['TIME']))
93-
compare_vars = ('TIME', 'NOMINAL_DEPTH', 'instrument_index',
94-
'TEMP', 'TEMP_count', 'TEMP_min', 'TEMP_max')
95-
non_match_vars = [var for var in compare_vars
96-
if not all(dataset[var][:] == expected[var][:])
97-
]
98-
self.assertEqual(non_match_vars, [])
91+
self.compare_global_attributes(dataset)
92+
93+
self.check_nan_values(dataset)
94+
95+
self.compare_variables(dataset)
9996

10097
def test_hourly_aggregator_with_nonqc(self):
10198
output_file, bad_files = hourly_aggregator(files_to_aggregate=INPUT_FILES,

test_aodntools/timeseries_products/test_velocity_aggregated_timeseries.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
'IMOS_ANMN-NRS_AETVZ_20191016T080000Z_NRSROT-ADCP_FV01_NRSROT-ADCP-1910-Sentinel-or-Monitor-Workhorse-ADCP-44_END-20191018T100000Z_C-20200430T000000Z.nc',
1919
BAD_FILE
2020
]
21-
EXPECTED_OUTPUT_FILE = os.path.join(
22-
TEST_ROOT, 'IMOS_ANMN-NRS_VZ_20180816_NRSROT_FV01_velocity-aggregated-timeseries_END-20191018_C-20200623.nc'
23-
)
2421

2522
OBS_VARS = {'TIME', 'DEPTH', 'DEPTH_quality_control', 'UCUR', 'UCUR_quality_control',
2623
'VCUR', 'VCUR_quality_control', 'WCUR', 'WCUR_quality_control', 'instrument_index', 'CELL_INDEX'}
@@ -29,6 +26,10 @@
2926

3027

3128
class TestVelocityAggregatedTimeseries(BaseTestCase):
29+
EXPECTED_OUTPUT_FILE = os.path.join(
30+
TEST_ROOT, 'IMOS_ANMN-NRS_VZ_20180816_NRSROT_FV01_velocity-aggregated-timeseries_END-20191018_C-20200623.nc'
31+
)
32+
3233
def test_velocity_aggregated(self):
3334
output_file, bad_files = velocity_aggregated(INPUT_FILES, 'NRSROT', input_dir=TEST_ROOT, output_dir='/tmp')
3435

@@ -56,13 +57,11 @@ def test_velocity_aggregated(self):
5657
self.assertEqual(__version__, dataset.generating_code_version)
5758
self.assertIn(__version__, dataset.lineage)
5859

59-
# check aggregated variable values
60-
expected = Dataset(EXPECTED_OUTPUT_FILE)
61-
compare_vars = set(expected.variables.keys()) - STR_VARS
62-
non_match_vars = [var for var in compare_vars
63-
if not all(dataset[var][:] == expected[var][:])
64-
]
65-
self.assertEqual(non_match_vars, [])
60+
self.compare_global_attributes(dataset)
61+
62+
self.check_nan_values(dataset)
63+
64+
self.compare_variables(dataset)
6665

6766
def test_all_rejected(self):
6867
self.assertRaises(NoInputFilesError, velocity_aggregated, [BAD_FILE], 'NRSROT',

test_aodntools/timeseries_products/test_velocity_hourly_timeseries.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919
'IMOS_ANMN-NRS_AETVZ_20191016T080000Z_NRSROT-ADCP_FV01_NRSROT-ADCP-1910-Sentinel-or-Monitor-Workhorse-ADCP-44_END-20191018T100000Z_C-20200430T000000Z.nc',
2020
BAD_FILE
2121
]
22-
EXPECTED_OUTPUT_FILE = os.path.join(
23-
TEST_ROOT, 'IMOS_ANMN-NRS_VZ_20180816_NRSROT_FV02_velocity-hourly-timeseries_END-20191018_C-20220502.nc'
24-
)
2522

2623
OBS_VARS = {'TIME', 'instrument_index', 'CELL_INDEX'}
2724
INST_VARS = {'LATITUDE', 'LONGITUDE', 'NOMINAL_DEPTH', 'SECONDS_TO_MIDDLE'}
@@ -33,6 +30,10 @@
3330

3431

3532
class TestVelocityHourlyTimeseries(BaseTestCase):
33+
EXPECTED_OUTPUT_FILE = os.path.join(
34+
TEST_ROOT, 'IMOS_ANMN-NRS_VZ_20180816_NRSROT_FV02_velocity-hourly-timeseries_END-20191018_C-20220502.nc'
35+
)
36+
3637
def test_velocity_hourly(self):
3738
output_file, bad_files = velocity_hourly_aggregated(INPUT_FILES, 'NRSROT',
3839
input_dir=TEST_ROOT, output_dir='/tmp')
@@ -61,15 +62,11 @@ def test_velocity_hourly(self):
6162
self.assertEqual(__version__, dataset.generating_code_version)
6263
self.assertIn(__version__, dataset.lineage)
6364

64-
# check aggregated variable values
65-
expected = Dataset(EXPECTED_OUTPUT_FILE)
66-
self.assertEqual(len(expected['TIME']), len(dataset['TIME']))
65+
self.compare_global_attributes(dataset)
66+
67+
self.check_nan_values(dataset)
6768

68-
non_match_vars = []
69-
for var in set(expected.variables.keys()) - STR_VARS:
70-
if not all(np.isclose(dataset[var], expected[var], equal_nan=True)):
71-
non_match_vars.append(var)
72-
self.assertEqual(non_match_vars, [])
69+
self.compare_variables(dataset)
7370

7471
def test_all_rejected(self):
7572
self.assertRaises(NoInputFilesError, velocity_hourly_aggregated, [BAD_FILE], 'NRSROT',

0 commit comments

Comments
 (0)