Skip to content

Commit 6a00c27

Browse files
author
nerkulec
committed
Rename parameters and update documentation (refs mala-project#636)
1 parent c40c57e commit 6a00c27

4 files changed

Lines changed: 45 additions & 24 deletions

File tree

docs/source/advanced_usage/trainingmodel.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ in the file ``advanced/ex03_tensor_board``. Simply select a logger prior to trai
212212
.. code-block:: python
213213
214214
parameters.running.logger = "tensorboard"
215-
parameters.running.logging_dir = "mala_vis"
215+
parameters.running.logging_dir = "mala_logs"
216216
217217
or
218218

@@ -224,7 +224,7 @@ or
224224
entity="your_wandb_entity"
225225
)
226226
parameters.running.logger = "wandb"
227-
parameters.running.logging_dir = "mala_vis"
227+
parameters.running.logging_dir = "mala_logs"
228228
229229
where ``logging_dir`` specifies some directory in which to save the
230230
MALA logging data. You can also select which metrics to record via
@@ -249,7 +249,7 @@ To save time and resources you can specify the logging interval via
249249

250250
.. code-block:: python
251251
252-
parameters.running.logging_metrics_freq = 10
252+
parameters.running.logging_metrics_interval = 10
253253
254254
If you want to monitor the degree to which the model overfits to the training data,
255255
you can use the option

examples/advanced/ex03_tensor_board.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
parameters.running.logger = "tensorboard"
2929
parameters.running.logging_dir = "mala_vis"
3030
parameters.running.logging_metrics = ["ldos", "band_energy"]
31-
parameters.running.logging_metrics_freq = 5
31+
parameters.running.logging_metrics_interval = 5
3232

3333
data_handler = mala.DataHandler(parameters)
3434
data_handler.add_snapshot(

mala/common/parameters.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,7 @@ class ParametersRunning(ParametersBase):
10391039
- "wandb": Weights and Biases logger.
10401040
10411041
logging_metrics : list
1042-
List of metrics to be used for validation. Default is ["ldos"].
1042+
List of metrics to be used for logging. Default is ["ldos"].
10431043
Possible options are:
10441044
10451045
- "ldos": MSE of the LDOS.
@@ -1049,15 +1049,21 @@ class ParametersRunning(ParametersBase):
10491049
- "total_energy_actual_fe": Total energy computed with ground truth Fermi energy.
10501050
- "fermi_energy": Fermi energy.
10511051
- "density": Electron density.
1052-
- "density_relative": Rlectron density (MAPE).
1052+
- "density_relative": Electron density (MAPE).
10531053
- "dos": Density of states.
10541054
- "dos_relative": Density of states (MAPE).
1055+
1056+
The units for energy metrics are meV/atom.
1057+
Selected metrics are evalauted every `logging_metrics_interval` (see below) epochs.
1058+
To use the energy metrics the validation snapshots need not be shuffled.
1059+
Note that evaluating the energy metrics takes considerably longer than just LDOS
1060+
and therefore it is discouraged.
10551061
10561062
log_metrics_on_train_set : bool
1057-
Whether to validate on the training data as well. Default is False.
1063+
Whether to also log metrics evaluated on the training set. Default is False.
10581064
1059-
logging_metrics_freq : int
1060-
Determines how often validation is performed. Default is 1.
1065+
logging_metrics_interval : int
1066+
Determines how often (in the unit of epochs) metrics are logged. Default is 1.
10611067
10621068
training_log_interval : int
10631069
Determines how often detailed performance info is printed during
@@ -1118,7 +1124,7 @@ def __init__(self):
11181124
self.logger = None
11191125
self.logging_metrics = ["ldos"]
11201126
self.log_metrics_on_train_set = False
1121-
self.logging_metrics_freq = 1
1127+
self.logging_metrics_interval = 1
11221128
self.inference_data_grid = [0, 0, 0]
11231129
self.use_mixed_precision = False
11241130
self.use_graphs = False
@@ -1147,11 +1153,27 @@ def validation_metric(self):
11471153
11481154
Metric for evaluated on the validation set during training.
11491155
Default is "ldos", meaning that the regular loss on the LDOS will be
1150-
used as a metric. Possible options are "band_energy" and
1151-
"total_energy". For these, the band resp. total energy of the
1152-
validation snapshots will be calculated and compared to the provided
1153-
DFT results. Of these, the mean average error in eV/atom will be
1154-
calculated.
1156+
used as a metric.
1157+
1158+
Possible options are:
1159+
1160+
- "ldos": MSE of the LDOS.
1161+
- "band_energy": Band energy.
1162+
- "band_energy_actual_fe": Band energy computed with ground truth Fermi energy.
1163+
- "total_energy": Total energy.
1164+
- "total_energy_actual_fe": Total energy computed with ground truth Fermi energy.
1165+
- "fermi_energy": Fermi energy.
1166+
- "density": Electron density.
1167+
- "density_relative": Electron density (MAPE).
1168+
- "dos": Density of states.
1169+
- "dos_relative": Density of states (MAPE).
1170+
1171+
The units for energy metrics are meV/atom.
1172+
Selected metric is evalauted every epoch.
1173+
The validation metric is used as a criterion for early stopping and also
1174+
for checkpointing the best model.
1175+
Note that evaluating the energy metrics takes considerably longer than LDOS
1176+
and therefore it is discouraged.
11551177
"""
11561178
return self._validation_metric
11571179

@@ -1170,15 +1192,14 @@ def validation_metric(self, value):
11701192
@property
11711193
def final_validation_metric(self):
11721194
"""
1173-
Get the metric used during training.
1195+
Metric for final model evaluation.
11741196
1175-
Metric for evaluated on the validation and test set before and after
1176-
training. Default is "LDOS", meaning that the regular loss on the LDOS
1177-
will be used as a metric. Possible options are "band_energy" and
1178-
"total_energy". For these, the band resp. total energy of the
1179-
validation snapshots will be calculated and compared to the provided
1180-
DFT results. Of these, the mean average error in eV/atom will be
1181-
calculated.
1197+
This metric is evaluated on the validation set after training.
1198+
Available options are the same as for `validation_metric`.
1199+
Default is "LDOS", meaning that MSE of the LDOS
1200+
will be used as a metric.
1201+
The final validation metric is used as a target
1202+
for hyperparameter optimization.
11821203
"""
11831204
return self._final_validation_metric
11841205

mala/network/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def train_network(self):
465465
logging_metrics = ["ldos"]
466466
if (
467467
epoch != 0
468-
and (epoch - 1) % self.parameters.logging_metrics_freq == 0
468+
and (epoch - 1) % self.parameters.logging_metrics_interval == 0
469469
):
470470
logging_metrics = self.parameters.logging_metrics
471471
errors = self._evaluate_metrics(

0 commit comments

Comments
 (0)