@@ -75,6 +75,18 @@ def __init__(self, atoms=None, features=None, targets=None, atomic=False):
7575
7676 self ._weights = np .ones ((self ._nsystems ,), dtype = int )
7777
78+ if self ._withatoms :
79+ self ._atomicweights = []
80+ for entry in self ._atoms :
81+ self ._atomicweights .append (np .ones ((len (entry ),), dtype = float ))
82+ elif self ._withfeatures :
83+ self ._atomicweights = []
84+ for entry in self ._features :
85+ self ._atomicweights .append (
86+ np .ones ((entry .shape [0 ],), dtype = float ))
87+ else :
88+ self ._atomicweights = None
89+
7890
7991 def _process_data (self ):
8092 '''Based on the stored data, a list of dictionaries,
@@ -216,6 +228,7 @@ def _create_contiguous_hdf(self, fname, data, zz):
216228 subroot = datagrp .create_group ('datapoint{}' .format (isys + 1 ))
217229
218230 hdf_append_weight (subroot , self ._weights [isys ])
231+ hdf_append_atomicweights (subroot , self ._atomicweights [isys ])
219232
220233 if self ._withatoms :
221234 hdf_append_geometry (subroot , data [isys ], True )
@@ -279,6 +292,55 @@ def weights(self, weights):
279292 self ._weights = weights
280293
281294
295+ @property
296+ def atomicweights (self ):
297+ '''Defines property, providing the gradient weight of each atom.
298+
299+ Returns:
300+
301+ atomicweights (list): float-valued list of atomic gradient weights
302+
303+ '''
304+
305+ return self ._atomicweights
306+
307+
308+ @atomicweights .setter
309+ def atomicweights (self , atomicweights ):
310+ '''Sets user-specified gradient weighting of each atom.'''
311+
312+ # enable providing arrays of several dtypes
313+ for ii , entry in enumerate (atomicweights ):
314+ atomicweights [ii ] = np .array (entry , dtype = float )
315+
316+ if not self ._withatoms :
317+ msg = 'Trying to set atomic gradient weighting but the object ' + \
318+ 'was initialized without geometry information.'
319+ raise FnetdataError (msg )
320+
321+ for weights in atomicweights :
322+ weights = np .array (weights )
323+
324+ if not len (atomicweights ) == len (self ._atoms ):
325+ msg = 'Mismatch in list length of atomic gradient weighting ' + \
326+ 'and geometries.'
327+ raise FnetdataError (msg )
328+
329+ # check consistency with geometries and whether (weights >= 0.0)
330+ for isys , weights in enumerate (atomicweights ):
331+ if not len (weights ) == len (self ._atoms [isys ]):
332+ msg = 'Mismatch in number of atomic gradient weights and ' + \
333+ 'number of atoms of corresponding geometry (index: {}).' \
334+ .format (isys + 1 )
335+ raise FnetdataError (msg )
336+ if any (weights < 0.0 ):
337+ msg = 'Negative atomic gradient weight(s) obtained ' + \
338+ '(index: {}).' .format (isys + 1 )
339+ raise FnetdataError (msg )
340+
341+ self ._atomicweights = atomicweights
342+
343+
282344 @property
283345 def ndatapoints (self ):
284346 '''Defines property, providing the number of datapoints.
@@ -461,6 +523,20 @@ def hdf_append_weight(root, weight):
461523 root .attrs ['weight' ] = weight
462524
463525
526+ def hdf_append_atomicweights (root , data ):
527+ '''Appends atomic gradient weights to a given in-memory hdf file.
528+
529+ Args:
530+
531+ root (hdf group): hdf group
532+ data (1darray): atomic weights of current datapoint
533+
534+ '''
535+
536+ weights = root .create_dataset ('atomicweights' , data .shape , dtype = 'float' )
537+ weights [...] = data
538+
539+
464540def hdf_append_geometry (root , data , frac ):
465541 '''Appends geometry information to a given in-memory hdf file.
466542
0 commit comments