Skip to content

Commit ea2f11c

Browse files
Merge pull request #177 from aodn/gridded_timeseries_unittest
Add unittest for gridded timeseries
2 parents e032a76 + e713ce5 commit ea2f11c

4 files changed

Lines changed: 84 additions & 34 deletions

File tree

aodntools/timeseries_products/gridded_timeseries.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os.path
55
import json
66
from datetime import datetime, timezone
7+
from collections import defaultdict
78

89
import xarray as xr
910
import pandas as pd
@@ -91,24 +92,28 @@ def write_netCDF_aggfile(agg_dataset, output_path, encoding):
9192
return output_path
9293

9394

94-
def set_variableattr(varlist, variable_attribute_dictionary, add_variable_attribute):
95+
def set_variableattr(varlist, variable_attribute_dictionary):
9596
"""
96-
set variables variables atributes
97+
Set variable atributes, separate attributes that should be passed to xarray separately as encoding
98+
parameters
9799
98-
:param varlist: list of variable names
100+
:param varlist: list of variable names to pick out
99101
:param variable_attribute_dictionary: dictionary of the variable attributes
100-
:param add_variable_attribute: additional attributes to add
101-
:return: dictionary of attributes
102+
:return: tuple (dictionary of attributes, dictionary of encoding attributes)
102103
"""
103104

104-
# with open(templatefile) as json_file:
105-
# variable_metadata = json.load(json_file)['_variables']
106-
variable_attributes = {key: variable_attribute_dictionary[key] for key in varlist}
107-
if len(add_variable_attribute)>0:
108-
for key in add_variable_attribute.keys():
109-
variable_attributes[key].update(add_variable_attribute[key])
105+
encoding_attributes = {'_FillValue'}
106+
time_encoding_attributes = {'units', 'calendar'}
107+
variable_attributes = defaultdict(dict)
108+
variable_encodings = defaultdict(dict)
109+
for var in varlist:
110+
for name, value in variable_attribute_dictionary[var].items():
111+
if name in encoding_attributes or (var == 'TIME' and name in time_encoding_attributes):
112+
variable_encodings[var][name] = value
113+
else:
114+
variable_attributes[var][name] = value
110115

111-
return variable_attributes
116+
return variable_attributes, variable_encodings
112117

113118
def generate_netcdf_output_filename(nc, facility_code, data_code, VoI, site_code, product_type, file_version):
114119
"""
@@ -237,10 +242,7 @@ def grid_variable(input_file, VoI, depth_bins=None, max_separation=50, depth_bin
237242

238243
## set variable attributes
239244
varlist = list(VoI_interpolated.variables)
240-
add_variable_attribute = {}
241-
variable_attributes = set_variableattr(varlist, variable_attribute_dictionary, add_variable_attribute)
242-
time_units = variable_attributes['TIME'].pop('units')
243-
time_calendar = variable_attributes['TIME'].pop('calendar')
245+
variable_attributes, encoding = set_variableattr(varlist, variable_attribute_dictionary)
244246
for variable in varlist:
245247
VoI_interpolated[variable].attrs = variable_attributes[variable]
246248

@@ -293,22 +295,12 @@ def grid_variable(input_file, VoI, depth_bins=None, max_separation=50, depth_bin
293295
file_version=file_version)
294296
ncout_path = os.path.join(output_dir, ncout_filename)
295297

296-
encoding = {'TIME': {'_FillValue': None,
297-
'units': time_units,
298-
'calendar': time_calendar,
299-
'zlib': True,
300-
'complevel': 5},
301-
VoI: {'zlib': True,
302-
'complevel': 5,
303-
'dtype': np.dtype('float32')},
304-
VoI+'_count': {'dtype': np.dtype('int16'),
305-
'zlib': True,
306-
'complevel': 5},
307-
'DEPTH': {'dtype': np.dtype('float32'),
308-
'zlib': True,
309-
'complevel': 5},
310-
'LONGITUDE': {'_FillValue': False},
311-
'LATITUDE': {'_FillValue': False}}
298+
# data types and compression for encoding
299+
for var in {'TIME', VoI, VoI+'_count', 'DEPTH'}:
300+
encoding[var].update({'zlib': True, 'complevel': 5})
301+
encoding[VoI].update({'dtype': np.dtype('float32')})
302+
encoding[VoI+'_count'].update({'dtype': np.dtype('int16')})
303+
encoding['DEPTH'].update({'dtype': np.dtype('float32')})
312304

313305
write_netCDF_aggfile(VoI_interpolated, ncout_path, encoding)
314306

test_aodntools/base_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,24 @@ def check_nan_values(self, dataset):
5151
"check that there are no NaN values in any variable (they should be fill values instead)"
5252
nan_vars = [(name, "contains NaN values")
5353
for name, var in dataset.variables.items()
54-
if var.dtype in (np.dtype('float32'), np.dtype('float64')) and any(np.isnan(var[:]))
54+
if var.dtype in (np.dtype('float32'), np.dtype('float64')) and np.isnan(var[:]).any()
5555
]
5656
self.assertEqual([], nan_vars)
5757

5858
def compare_variables(self, dataset, skip_vars=('source_file', 'instrument_id')):
5959
"""Compare dimensions and values of all variables in dataset with those in self.EXPECTED_OUTPUT_FILE,
6060
except for variables listed in skip_vars.
6161
"""
62+
63+
def _arrays_equal(testvar, expected):
64+
"""compare two numpy arrays, handling the case of scalar variables"""
65+
if expected.shape == ():
66+
if np.isclose(testvar, expected):
67+
return True
68+
elif (np.isclose(testvar, expected)).all():
69+
return True
70+
return False
71+
6272
differences = []
6373
with Dataset(self.EXPECTED_OUTPUT_FILE) as expected:
6474
for var in set(expected.variables.keys()) - set(skip_vars):
@@ -68,7 +78,7 @@ def compare_variables(self, dataset, skip_vars=('source_file', 'instrument_id'))
6878
differences.append((var, "shapes differ"))
6979

7080
# compare the raw data arrays (not the masked_array)
71-
if not all(np.isclose(dataset[var][:].data, expected[var][:].data)):
81+
if not _arrays_equal(dataset[var][:].data, expected[var][:].data):
7282
differences.append((var, "variable values differ"))
7383

7484
self.assertEqual([], differences)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env python3
2+
3+
import os
4+
import unittest
5+
6+
from netCDF4 import Dataset
7+
8+
from test_aodntools.base_test import BaseTestCase
9+
from aodntools import __version__
10+
from aodntools.timeseries_products.gridded_timeseries import grid_variable
11+
12+
13+
TEST_ROOT = os.path.dirname(__file__)
14+
INPUT_FILE = 'IMOS_ANMN-NRS_STZ_20181213_NRSROT_FV02_hourly-timeseries_END-20190523_C-20220428.nc'
15+
16+
17+
class TestGriddedTimeseries(BaseTestCase):
18+
EXPECTED_OUTPUT_FILE = os.path.join(
19+
TEST_ROOT, 'IMOS_ANMN-NRS_TZ_20181213_NRSROT_FV02_TEMP-gridded-timeseries_END-20190523_C-20230110.nc'
20+
)
21+
22+
def test_grid_variable(self):
23+
output_file = grid_variable(INPUT_FILE, 'TEMP', input_dir=TEST_ROOT, output_dir='/tmp')
24+
25+
self.assertRegex(output_file,
26+
r'IMOS_ANMN-NRS_TZ_20181213_NRSROT_FV02_TEMP-gridded-timeseries_END-20190523_C-\d{8}\.nc'
27+
)
28+
29+
dataset = Dataset(output_file)
30+
self.assertSetEqual(set(dataset.dimensions), {'TIME', 'DEPTH'})
31+
self.assertSetEqual(set(dataset.variables.keys()),
32+
{'TIME', 'DEPTH', 'LATITUDE', 'LONGITUDE', 'TEMP', 'TEMP_count'})
33+
34+
# check metadata
35+
self.assertEqual(__version__, dataset.generating_code_version)
36+
self.assertIn(__version__, dataset.lineage)
37+
self.assertIn('gridded_timeseries.py', dataset.lineage)
38+
self.assertIn(INPUT_FILE, dataset.source_file)
39+
40+
self.compare_global_attributes(dataset)
41+
42+
self.check_nan_values(dataset)
43+
44+
self.compare_variables(dataset)
45+
46+
47+
if __name__ == '__main__':
48+
unittest.main()

0 commit comments

Comments (0)