""" A wrapper class to write netCDF files directly from a dictionary.

written by: Hugo Oliveira ocehugo@gmail.com
"""

# TODO write tests
# TODO how do we handle groups!?
# TODO cleanup the user precedence rules of fill_values
# TODO is_dim_consistent too complex
# TODO change_time too complex
# TODO check_var too complex
# TODO createVariables too complex
# TODO implement from_file/from_cdl/from_json kwarg!?

| 15 | +import netCDF4 |
| 16 | +import numpy as np |
| 17 | + |
| 18 | + |
class NetCDFGroupDict(object):
    """Dictionary-backed description of a netCDF group."""

    def __init__(self,
                 dimensions=None,
                 variables=None,
                 global_attributes=None,
                 title='NetCDFGroupDict',
                 **kwargs):
        """ A dictionary to hold netCDF groups
        It consist of a generic class holding 3 different dictionaries:
            dimensions is a <key:int> dict (-1 marks a record dimension)
            variables is a <key:dict> dict; each variable dict may hold
                'type' (str or type), 'dims' (list or None) and
                'attr' (dict)
            global_attributes is a <key:str> dict

        This class has __add__ to grow variables/dims/global attrs
        and __sub__ to remove unwanted variables from
        other :NetCDFGroupDict: instances.

        Omitted dictionaries default to a fresh empty dict per instance.
        Extra keyword arguments are accepted and ignored.

        Raises:
            TypeError: if :title: is not a str.
            ValueError: if a dimension size is not an integer, a variable
                field has the wrong kind, a global attribute is not a
                string, or a variable uses an undefined dimension.
            KeyError: if a variable has no 'dims' entry.

        Example:
            dmn = {'lon':360,'lat':210}
            var = {}
            var['water'] = {'type':'double','dims':['lat','lon']}
            w1 = NetCDFGroupDict(dimensions=dmn,variables=var)
            dmn2 = {'time':300,'lon':720,'lat':330}
            var2 = {}
            var2['temp'] = {'type':'double','dims':['time','lat','lon']}
            w2 = NetCDFGroupDict(dimensions=dmn2,variables=var2)
            w3 = w1+w2
            #w3.variables.keys() = ['water','temp']
            #w3.dimensions = {'lon':720,'lat':330,'time':300}
            w4 = w2-w1
            #w4.variables.keys() = ['temp']
            #w4.dimensions = {'lon':720,'lat':330,'time':300}
        """
        if not isinstance(title, str):
            raise TypeError("Title is not a str object")

        self.title = title
        # Fresh dicts per instance: a `{}` default would be shared by every
        # instance constructed without that argument.
        self.dimensions = {} if dimensions is None else dimensions
        self.global_attributes = (
            {} if global_attributes is None else global_attributes)
        self.variables = {} if variables is None else variables

        self.check_dims(self.dimensions)
        self.check_var(self.variables)
        self.check_global_attributes(self.global_attributes)
        self.check_consistency(self.dimensions, self.variables)

        # NOTE: construction used to be guarded by `if self.is_dim_consistent:`
        # which tested the bound method (always truthy), so that check never
        # ran. It is deliberately not enforced here: is_dim_consistent() also
        # rejects dimensions that no variable uses, which __sub__ can
        # legitimately produce. Undefined dimensions are already rejected by
        # check_consistency() above.
        self.rdimensions = {
            name: size == -1
            for name, size in self.dimensions.items()
        }

    def __add__(self, other):
        """Merge two groups; entries of :other: win on key clashes."""
        return NetCDFGroupDict(
            dimensions={
                **self.dimensions,
                **other.dimensions
            },
            variables={
                **self.variables,
                **other.variables
            },
            global_attributes={
                **self.global_attributes,
                **other.global_attributes
            },
            title=self.title + '+' + other.title)

    def __sub__(self, other):
        """Drop from this group the variables also present in :other:.

        Dimensions and global attributes of both groups are kept, with the
        entries of this group winning on key clashes.
        """
        # Filter in insertion order so the resulting variable order is
        # deterministic (a set difference would not be).
        remaining = {
            name: spec
            for name, spec in self.variables.items()
            if name not in other.variables
        }
        return NetCDFGroupDict(
            dimensions={
                **other.dimensions,
                **self.dimensions
            },
            variables=remaining,
            global_attributes={
                **other.global_attributes,
                **self.global_attributes
            },
            title=self.title + '-' + other.title)

    def is_dim_consistent(self):
        """Check if the variable dictionary
        is consistent with current dimensions.

        Returns True only when the set of dimensions referenced by the
        variables equals the set of defined dimensions. Problems found on
        the way are reported with print().
        """
        used = set()
        for name, spec in self.variables.items():
            if 'dims' not in spec:
                missing = [
                    field for field in ('dims', 'type', 'attr')
                    if field not in spec
                ]
                print("Variable %s is missing information for: %s" %
                      (name, ', '.join(missing)))
                continue

            dims = spec['dims']
            if dims is None:  # scalar variable, uses no dimension
                continue
            if not isinstance(dims, (list, tuple)):
                print("Variable %s has invalid dimension information `dims`"
                      % name)
                continue
            used.update(dims)

        if used != set(self.dimensions):
            print("Consistent dimensions are: %s" % used)
            return False
        return True

    def search_time_in_vars(self):
        """Check all vars for specific time variables associated with them.

        Returns the set of values found under ['attr']['time']['value'],
        or None when no variable carries one.
        """
        tvars = set()
        for spec in self.variables.values():
            try:
                tvars.add(spec['attr']['time']['value'])
            except KeyError:
                continue
        return tvars or None

    def change_time(self, var, timevar):
        """Change the time dimension associated with variable :var:
        :var: a list or str
            Ex: 'zeta'
                ['zeta','u']
                ['u','v']
                ['Ptracer1','Ptracer2']
        :timevar: a list or str
            Ex: 'bry_time'
                ['zeta_time','uv_time']
                ['uv_time']
                ['ptime1','ptime2']

        A single :timevar: is applied to every entry of :var:. A name in
        :var: that is not a variable is treated as a substring and applied
        to every variable whose name contains it.

        Raises:
            ValueError: if the lengths of :var: and :timevar: cannot be
                paired, or a time variable is not both a variable and a
                dimension. Nothing is modified in that case.
        """
        var = [var] if isinstance(var, str) else list(var)
        timevar = [timevar] if isinstance(timevar, str) else list(timevar)

        if len(timevar) == 1:
            # Repeat the name itself; wrapping the list again would pair
            # every variable with an (unhashable) list.
            timevar = timevar * len(var)
        elif len(timevar) != len(var):
            raise ValueError('Invalid input')

        # varname should match dimname for time info
        for t in timevar:
            if t not in self.variables or t not in self.dimensions:
                raise ValueError('Time variable: %s not present!' % t)

        for v, t in zip(var, timevar):
            if v in self.variables:
                targets = [v]
            else:
                targets = [k for k in self.variables if v in k]
            for k in targets:
                self.variables[k]['dims'][0] = t
                self.variables[k]['attr']['time']['value'] = t

    @classmethod
    def check_dims(cls, dimdict):
        """Check that every dimension size is an integer.

        Raises:
            ValueError: for the first size that is not an int (bool is
                not accepted).
        """
        for name, size in dimdict.items():
            if isinstance(size, bool) or not isinstance(size, int):
                raise ValueError(
                    "Dimension %s is not an integer object" % name)

    @staticmethod
    def _is_valid_type(vtype):
        """Tell whether :vtype: is an acceptable variable datatype."""
        if isinstance(vtype, (str, type, np.dtype)):
            return True
        return isinstance(vtype, (netCDF4.CompoundType, netCDF4.VLType))

    @classmethod
    def check_var(cls, vardict, name=None):
        """ Check that the fields of a variable dictionary have the
        expected kinds.

        A dictionary holding none of 'dims', 'type' and 'attr' is treated
        as a collection of variables and each value is checked in turn.

        Raises:
            ValueError: if :vardict: is not a dictionary, 'dims' is neither
                None nor a list, 'attr' is not a dictionary, or 'type' is
                not a string/type/netCDF4 user type.
        """
        if name is None:
            name = 'input'

        if not isinstance(vardict, dict):
            raise ValueError(
                "Variable %s should be a dictionary object" % name)

        has_dims = 'dims' in vardict
        has_type = 'type' in vardict
        has_attr = 'attr' in vardict

        if not (has_dims or has_type or has_attr):
            for key, value in vardict.items():
                cls.check_var(value, name=key)
            return

        if has_dims:
            dims = vardict['dims']
            if dims is not None and not isinstance(dims, list):
                raise ValueError(
                    "Dim for %s should be a None or a list object" % name)

        if has_attr and not isinstance(vardict['attr'], dict):
            raise ValueError(
                "Attr for %s should be a dictionary object" % name)

        if has_type and not cls._is_valid_type(vardict['type']):
            raise ValueError(
                "Type for %s should be a string or type object" % name)

    @classmethod
    def check_global_attributes(cls, gadict):
        """Check that every global attribute value is a string.

        Raises:
            ValueError: for the first value that is not a str.
        """
        # NOTE(review): netCDF itself also allows numeric attributes; widen
        # this check if numeric global attributes need to be supported.
        for name, value in gadict.items():
            if not isinstance(value, str):
                raise ValueError(
                    "Global Attr %s is not a string object" % name)

    @classmethod
    def check_consistency(cls, dimdict, vdict):
        """Check that every variable only uses defined dimensions.

        Raises:
            ValueError: if a variable references a dimension that is not a
                key of :dimdict:.
            KeyError: if a variable has no 'dims' entry.
        """
        for name, spec in vdict.items():
            vardims = spec['dims']
            if vardims is None:
                continue
            missing = [x for x in vardims if x not in dimdict]
            if missing:
                raise ValueError("Variable %s has undefined dimensions: %s"
                                 % (name, missing))
| 270 | + |
| 271 | + |
class DictDataset(NetCDFGroupDict):
    """Write the group described by the parent class to a netCDF file.

    Typical use: construct with the dimension/variable/global attribute
    dictionaries, call set_output() to open the file, then create().
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are consumed by __init__ only.
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        """Initialise the group and the sets of recognised option names."""
        super().__init__(*args, **kwargs)
        # Keys of a variable dict that are creation options, not attributes.
        self.cattrs = {
            'zlib', 'complevel', 'shuffle', 'fletcher32', 'contiguous',
            'chunksizes', 'endian', 'least_significant_digit'
        }
        # Accepted spellings for the fill value creation option.
        self.fill_aliases = {
            'fill_value', 'missing_value', 'FillValue', '_FillValue'
        }

    def set_output(self, outfile, mode='w', **kwargs):
        """Create the dataset.

        Opens :outfile: with netCDF4.Dataset (extra keyword arguments are
        forwarded) and stores it as `self.ncobj`.
        """
        self.outfile = outfile
        self.ncobj = netCDF4.Dataset(self.outfile, mode=mode, **kwargs)

    def _create_var_opts(self, vdict):
        """Return a list with attribute names required for the creation of
        the variable defined by :vdict: This include creation/special
        options like:
            `zlib`
            `least_significant_digit`
            and at most one fill value alias.

        Raises:
            ValueError: if :vdict: holds more than one fill value alias.
        """
        keys = set(vdict)
        aliases = keys & self.fill_aliases
        if len(aliases) > 1:
            raise ValueError('You can only provide one missing value alias!')
        return list((keys & self.cattrs) | aliases)

    def _merge_create_opts(self, vspec, overrides):
        """Build the createVariable keyword arguments for one variable.

        Options found in :vspec: are overridden by :overrides:. Whatever
        alias was used, the fill value is passed on as `fill_value`, the
        one given in :overrides: taking precedence.
        """
        opts = {key: vspec[key] for key in self._create_var_opts(vspec)}
        user = dict(overrides)

        # Pull the fill values out of each source separately: the same
        # alias may be present in both and must only be removed once.
        var_fill = [opts.pop(k) for k in list(opts) if k in self.fill_aliases]
        user_fill = [user.pop(k) for k in list(user) if k in self.fill_aliases]

        opts.update(user)
        if user_fill:
            opts['fill_value'] = user_fill[-1]
        elif var_fill:
            opts['fill_value'] = var_fill[-1]
        return opts

    def createDimensions(self):
        """Create the dimensions on the netcdf file"""
        for dname, dval in self.dimensions.items():
            # -1 is this class' marker for a record dimension; netCDF4
            # expects None for an unlimited dimension.
            # NOTE(review): confirm -1 is meant as "unlimited" everywhere.
            size = None if dval == -1 else dval
            self.ncobj.createDimension(dname, size)

    def createVariables(self, **kwargs):
        """Create all variables for the current class
        **kwargs are included here to overload all options for all variables
        like `zlib` and friends.

        Variables whose 'dims' is None are created without any creation
        option. Every entry of a variable's 'attr' dictionary that is not
        a creation option must hold 'value' and 'type'; the value is cast
        to that type before being written as an attribute.
        """
        for varname, vspec in self.variables.items():
            datatype = vspec['type']
            dimensions = vspec['dims']

            if dimensions is None:  # no kwargs in createVariable
                self.ncobj.createVariable(varname, datatype)
            else:
                opts = self._merge_create_opts(vspec, kwargs)
                self.ncobj.createVariable(
                    varname, datatype, dimensions=dimensions, **opts)

            if 'attr' not in vspec:
                continue

            attrs = vspec['attr'].copy()
            for not_attr in self._create_var_opts(attrs):
                attrs.pop(not_attr)

            ncvar = self.ncobj.variables[varname]
            for attname, attspec in attrs.items():
                value = np.array(attspec['value']).astype(attspec['type'])
                ncvar.setncattr(attname, value)

    def createGlobalAttrs(self):
        """Add the global attributes for the current class"""
        for att, value in self.global_attributes.items():
            self.ncobj.setncattr(att, value)

    def create(self, **kwargs):
        """Create in the dimensions/variable and attributes and fill with
        basic information.

        set_output() must have been called first. The file is flushed,
        closed and reopened in append mode, so `self.ncobj` is ready to
        receive data afterwards. **kwargs are forwarded to
        createVariables().
        """
        self.createDimensions()
        self.createVariables(**kwargs)
        self.createGlobalAttrs()
        self.ncobj.sync()
        self.ncobj.close()
        self.ncobj = netCDF4.Dataset(self.outfile, 'a')
0 commit comments