Skip to content

Commit 33202ab

Browse files
committed
Add optional atom-resolved weighting of training gradients
1 parent f8b62d7 commit 33202ab

30 files changed

Lines changed: 243 additions & 3 deletions

README.rst

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,43 @@ results obtained by the neural network implementation
1616
Installation
1717
============
1818

19-
|build status|
19+
Please note, that this package has been tested for Python 3.X support. Its usage
20+
additionally requires
21+
22+
- `numerical Python <https://numpy.org/doc/stable/reference/>`_ (`numpy`)
23+
- `pythonic HDF5 <http://www.h5py.org/>`_ (`h5py`)
24+
- `Atomic Simulation Environment <https://wiki.fysik.dtu.dk/ase/>`_ (`ase`)
25+
26+
as well as the `pytest` framework in order to run the regression tests.
27+
28+
Via the Python Package Index
29+
----------------------------
2030

2131
The package can be downloaded and installed via pip into the active Python
22-
interpreter (preferably using a virtual python environment) by::
32+
interpreter (preferably using a virtual python environment) by ::
2333

2434
pip install fortnet-python
2535

2636
or into the user space issueing::
2737

2838
pip install --user fortnet-python
2939

40+
Locally from Source
41+
-------------------
42+
43+
Alternatively, you can install it locally from source, i.e. from the root folder
44+
of the project::
45+
46+
python -m pip install .
47+
48+
Testing
49+
=======
50+
51+
The regression testsuite utilizes the `pytest` framework and may be executed by
52+
::
53+
54+
python -m pytest --basetemp=Testing
55+
3056
Documentation
3157
=============
3258

src/fortformat/fnetdata.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ def __init__(self, atoms=None, features=None, targets=None, atomic=False):
7575

7676
self._weights = np.ones((self._nsystems,), dtype=int)
7777

78+
if self._withatoms:
79+
self._atomicweights = []
80+
for entry in self._atoms:
81+
self._atomicweights.append(np.ones((len(entry),), dtype=float))
82+
elif self._withfeatures:
83+
self._atomicweights = []
84+
for entry in self._features:
85+
self._atomicweights.append(
86+
np.ones((entry.shape[0],), dtype=float))
87+
else:
88+
self._atomicweights = None
89+
7890

7991
def _process_data(self):
8092
'''Based on the stored data, a list of dictionaries,
@@ -216,6 +228,7 @@ def _create_contiguous_hdf(self, fname, data, zz):
216228
subroot = datagrp.create_group('datapoint{}'.format(isys + 1))
217229

218230
hdf_append_weight(subroot, self._weights[isys])
231+
hdf_append_atomicweights(subroot, self._atomicweights[isys])
219232

220233
if self._withatoms:
221234
hdf_append_geometry(subroot, data[isys], True)
@@ -279,6 +292,55 @@ def weights(self, weights):
279292
self._weights = weights
280293

281294

295+
@property
296+
def atomicweights(self):
297+
'''Defines property, providing the gradient weight of each atom.
298+
299+
Returns:
300+
301+
atomicweights (list): float-valued list of atomic gradient weights
302+
303+
'''
304+
305+
return self._atomicweights
306+
307+
308+
@atomicweights.setter
309+
def atomicweights(self, atomicweights):
310+
'''Sets user-specified gradient weighting of each atom.'''
311+
312+
# enable providing arrays of several dtypes
313+
for ii, entry in enumerate(atomicweights):
314+
atomicweights[ii] = np.array(entry, dtype=float)
315+
316+
if not self._withatoms:
317+
msg = 'Trying to set atomic gradient weighting but the object ' + \
318+
'was initialized without geometry information.'
319+
raise FnetdataError(msg)
320+
321+
for weights in atomicweights:
322+
weights = np.array(weights)
323+
324+
if not len(atomicweights) == len(self._atoms):
325+
msg = 'Mismatch in list length of atomic gradient weighting ' + \
326+
'and geometries.'
327+
raise FnetdataError(msg)
328+
329+
# check consistency with geometries and whether (weights >= 0.0)
330+
for isys, weights in enumerate(atomicweights):
331+
if not len(weights) == len(self._atoms[isys]):
332+
msg = 'Mismatch in number of atomic gradient weights and ' + \
333+
'number of atoms of corresponding geometry (index: {}).' \
334+
.format(isys + 1)
335+
raise FnetdataError(msg)
336+
if any(weights < 0.0):
337+
msg = 'Negative atomic gradient weight(s) obtained ' + \
338+
'(index: {}).'.format(isys + 1)
339+
raise FnetdataError(msg)
340+
341+
self._atomicweights = atomicweights
342+
343+
282344
@property
283345
def ndatapoints(self):
284346
'''Defines property, providing the number of datapoints.
@@ -461,6 +523,20 @@ def hdf_append_weight(root, weight):
461523
root.attrs['weight'] = weight
462524

463525

526+
def hdf_append_atomicweights(root, data):
527+
'''Appends atomic gradient weights to a given in-memory hdf file.
528+
529+
Args:
530+
531+
root (hdf group): hdf group
532+
data (1darray): atomic weights of current datapoint
533+
534+
'''
535+
536+
weights = root.create_dataset('atomicweights', data.shape, dtype='float')
537+
weights[...] = data
538+
539+
464540
def hdf_append_geometry(root, data, frac):
465541
'''Appends geometry information to a given in-memory hdf file.
466542

test/common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,28 @@ def get_atomicweights_byatoms(atoms):
406406
weights.append(np.asfarray(np.random.randint(1, 100, natom, dtype=int)))
407407

408408
return weights
409+
410+
411+
def get_batomicweights_byatoms(atoms):
412+
'''Generates dummy properties for regression testing.
413+
414+
Args:
415+
416+
atoms (ASE atoms list): list of ASE Atoms objects
417+
418+
Returns:
419+
420+
weights (list): atomic gradient weighting
421+
422+
'''
423+
424+
# fix random seed for reproduction purposes
425+
np.random.seed(42)
426+
sample = [True, False]
427+
428+
weights = []
429+
for atom in atoms:
430+
natom = len(atom)
431+
weights.append(np.random.choice(sample, size=natom))
432+
433+
return weights
816 Bytes
Binary file not shown.
816 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
816 Bytes
Binary file not shown.
816 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)