Skip to content

Commit 261129a

Browse files
committed
basic unittest for gridded timeseries
1 parent 2b1f81e commit 261129a

3 files changed

Lines changed: 60 additions & 2 deletions

File tree

test_aodntools/base_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,24 @@ def check_nan_values(self, dataset):
5151
"check that there are no NaN values in any variable (they should be fill values instead)"
5252
nan_vars = [(name, "contains NaN values")
5353
for name, var in dataset.variables.items()
54-
if var.dtype in (np.dtype('float32'), np.dtype('float64')) and any(np.isnan(var[:]))
54+
if var.dtype in (np.dtype('float32'), np.dtype('float64')) and np.isnan(var[:]).any()
5555
]
5656
self.assertEqual([], nan_vars)
5757

5858
def compare_variables(self, dataset, skip_vars=('source_file', 'instrument_id')):
5959
"""Compare dimensions and values of all variables in dataset with those in self.EXPECTED_OUTPUT_FILE,
6060
except for variables listed in skip_vars.
6161
"""
62+
63+
def _arrays_equal(testvar, expected):
64+
"""compare two numpy arrays, handling the case of scalar variables"""
65+
if expected.shape == ():
66+
if np.isclose(testvar, expected):
67+
return True
68+
elif (np.isclose(testvar, expected)).all():
69+
return True
70+
return False
71+
6272
differences = []
6373
with Dataset(self.EXPECTED_OUTPUT_FILE) as expected:
6474
for var in set(expected.variables.keys()) - set(skip_vars):
@@ -68,7 +78,7 @@ def compare_variables(self, dataset, skip_vars=('source_file', 'instrument_id'))
6878
differences.append((var, "shapes differ"))
6979

7080
# compare the raw data arrays (not the masked_array)
71-
if not all(np.isclose(dataset[var][:].data, expected[var][:].data)):
81+
if not _arrays_equal(dataset[var][:].data, expected[var][:].data):
7282
differences.append((var, "variable values differ"))
7383

7484
self.assertEqual([], differences)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env python3
2+
3+
import os
4+
import unittest
5+
6+
from netCDF4 import Dataset
7+
8+
from test_aodntools.base_test import BaseTestCase
9+
from aodntools import __version__
10+
from aodntools.timeseries_products.gridded_timeseries import grid_variable
11+
12+
13+
TEST_ROOT = os.path.dirname(__file__)
14+
INPUT_FILE = 'IMOS_ANMN-NRS_STZ_20181213_NRSROT_FV02_hourly-timeseries_END-20190523_C-20220428.nc'
15+
16+
17+
class TestGriddedTimeseries(BaseTestCase):
18+
EXPECTED_OUTPUT_FILE = os.path.join(
19+
TEST_ROOT, 'IMOS_ANMN-NRS_TZ_20181213_NRSROT_FV02_TEMP-gridded-timeseries_END-20190523_C-20230110.nc'
20+
)
21+
22+
def test_grid_variable(self):
23+
output_file = grid_variable(INPUT_FILE, 'TEMP', input_dir=TEST_ROOT, output_dir='/tmp')
24+
25+
self.assertRegex(output_file,
26+
r'IMOS_ANMN-NRS_TZ_20181213_NRSROT_FV02_TEMP-gridded-timeseries_END-20190523_C-\d{8}\.nc'
27+
)
28+
29+
dataset = Dataset(output_file)
30+
self.assertSetEqual(set(dataset.dimensions), {'TIME', 'DEPTH'})
31+
self.assertSetEqual(set(dataset.variables.keys()),
32+
{'TIME', 'DEPTH', 'LATITUDE', 'LONGITUDE', 'TEMP', 'TEMP_count'})
33+
34+
# check metadata
35+
self.assertEqual(__version__, dataset.generating_code_version)
36+
self.assertIn(__version__, dataset.lineage)
37+
self.assertIn('gridded_timeseries.py', dataset.lineage)
38+
self.assertIn(INPUT_FILE, dataset.source_file)
39+
40+
self.compare_global_attributes(dataset)
41+
42+
self.check_nan_values(dataset)
43+
44+
self.compare_variables(dataset)
45+
46+
47+
if __name__ == '__main__':
48+
unittest.main()

0 commit comments

Comments
 (0)