Commit 91a7ecc

Merge pull request mala-project#659 from nerkulec/rename_validation_settings
Rename validation settings (ref mala-project#636)
2 parents 5f51c1d + e3b0405 commit 91a7ecc

6 files changed

Lines changed: 92 additions & 71 deletions

docs/source/advanced_usage/hyperparameters.rst

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ a physical validation metric such as
  .. code-block:: python

- parameters.running.after_training_metric = "band_energy"
+ parameters.running.final_validation_metric = "band_energy"

  Advanced optimization algorithms
  ********************************

docs/source/advanced_usage/trainingmodel.rst

Lines changed: 7 additions & 7 deletions
@@ -71,13 +71,13 @@ is directly outputted by MALA. By default, this validation loss gives the
  mean squared error between LDOS prediction and actual value. From a purely
  ML point of view, this is fine; however, the correctness of the LDOS itself
  does not hold much physical virtue. Thus, MALA implements physical validation
- metrics to be accessed before and after the training routine.
+ metrics which can be evaluated for example after the training.

  Specifically, when setting

  .. code-block:: python

- parameters.running.after_training_metric = "band_energy"
+ parameters.running.final_validation_metric = "band_energy"

  the error in the band energy between actual and predicted LDOS will be
  calculated and printed before and after network training (in meV/atom).
@@ -212,7 +212,7 @@ in the file ``advanced/ex03_tensor_board``. Simply select a logger prior to trai
  .. code-block:: python

  parameters.running.logger = "tensorboard"
- parameters.running.logging_dir = "mala_vis"
+ parameters.running.logging_dir = "mala_logs"

  or

@@ -224,14 +224,14 @@ or
  entity="your_wandb_entity"
  )
  parameters.running.logger = "wandb"
- parameters.running.logging_dir = "mala_vis"
+ parameters.running.logging_dir = "mala_logs"

  where ``logging_dir`` specifies some directory in which to save the
  MALA logging data. You can also select which metrics to record via

  .. code-block:: python

- parameters.validation_metrics = ["ldos", "dos", "density", "total_energy"]
+ parameters.logging_metrics = ["ldos", "dos", "density", "total_energy"]

  Full list of available metrics:
  - "ldos": MSE of the LDOS.
@@ -249,14 +249,14 @@ To save time and resources you can specify the logging interval via
  .. code-block:: python

- parameters.running.validate_every_n_epochs = 10
+ parameters.running.logging_metrics_interval = 10

  If you want to monitor the degree to which the model overfits to the training data,
  you can use the option

  .. code-block:: python

- parameters.running.validate_on_training_data = True
+ parameters.running.log_metrics_on_train_set = True

  MALA will evaluate the validation metrics on the training set as well as the validation set.
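The renames in this commit can be collected into a single lookup table. The following is a minimal sketch of a migration helper for user scripts; it is not part of MALA, and the names `RENAMES` and `migrate_source` are invented for illustration. Only the old/new name pairs are taken from the diff.

```python
import re

# Old -> new setting names, as they appear in this commit's diff.
RENAMES = {
    "during_training_metric": "validation_metric",
    "after_training_metric": "final_validation_metric",
    "validation_metrics": "logging_metrics",
    "validate_every_n_epochs": "logging_metrics_interval",
    "validate_on_training_data": "log_metrics_on_train_set",
}

# One pattern that matches any old name as a whole identifier only.
_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, RENAMES)) + r")\b")


def migrate_source(text: str) -> str:
    """Replace every old setting name in `text` with its new name."""
    return _PATTERN.sub(lambda match: RENAMES[match.group(1)], text)


if __name__ == "__main__":
    old_line = 'parameters.running.after_training_metric = "band_energy"'
    print(migrate_source(old_line))
```

Because the pattern matches whole identifiers only, already migrated names such as `final_validation_metric` are left untouched, so running the helper twice should be harmless.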

examples/advanced/ex03_tensor_board.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
  # files into.
  parameters.running.logger = "tensorboard"
  parameters.running.logging_dir = "mala_vis"
- parameters.running.validation_metrics = ["ldos", "band_energy"]
- parameters.running.validate_every_n_epochs = 5
+ parameters.running.logging_metrics = ["ldos", "band_energy"]
+ parameters.running.logging_metrics_interval = 5

  data_handler = mala.DataHandler(parameters)
  data_handler.add_snapshot(

examples/advanced/ex06_distributed_hyperparameter_optimization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@
  parameters.hyperparameters.number_training_per_trial = 3

  # Hyperparameter optimization can be further refined by using ensemble training
- # at each step and by using a different metric then the validation loss
+ # at each step and by using a different metric then the test metric
  # (e.g. the band energy). It is recommended not to use the ensemble training
  # method in Single-GPU use, as it naturally drastically decreases performance.
- # For this small example setting, using the band energy as the after training
+ # For this small example setting, using the band energy as the test
  # metric is not recommended, since the small data size makes
  # an accurate hyperparameter search difficult. For larger systems, enabling
  # this option is recommended.
- # parameters.running.after_training_metric = "band_energy"
+ # parameters.running.final_validation_metric = "band_energy"

  data_handler = mala.DataHandler(parameters)

mala/common/parameters.py

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,8 +1038,8 @@ class ParametersRunning(ParametersBase):
  - "tensorboard": Tensorboard logger.
  - "wandb": Weights and Biases logger.

- validation_metrics : list
- List of metrics to be used for validation. Default is ["ldos"].
+ logging_metrics : list
+ List of metrics to be used for logging. Default is ["ldos"].
  Possible options are:

  - "ldos": MSE of the LDOS.
@@ -1049,15 +1049,21 @@ class ParametersRunning(ParametersBase):
  - "total_energy_actual_fe": Total energy computed with ground truth Fermi energy.
  - "fermi_energy": Fermi energy.
  - "density": Electron density.
- - "density_relative": Rlectron density (MAPE).
+ - "density_relative": Electron density (MAPE).
  - "dos": Density of states.
  - "dos_relative": Density of states (MAPE).
+
+ The units for energy metrics are meV/atom.
+ Selected metrics are evalauted every `logging_metrics_interval` (see below) epochs.
+ To use the energy metrics the validation snapshots need not be shuffled.
+ Note that evaluating the energy metrics takes considerably longer than just LDOS
+ and therefore it is discouraged.

- validate_on_training_data : bool
- Whether to validate on the training data as well. Default is False.
+ log_metrics_on_train_set : bool
+ Whether to also log metrics evaluated on the training set. Default is False.

- validate_every_n_epochs : int
- Determines how often validation is performed. Default is 1.
+ logging_metrics_interval : int
+ Determines how often (in the unit of epochs) metrics are logged. Default is 1.

  training_log_interval : int
  Determines how often detailed performance info is printed during
@@ -1103,8 +1109,8 @@ def __init__(self):
  self.learning_rate_scheduler = None
  self.learning_rate_decay = 0.1
  self.learning_rate_patience = 0
- self._during_training_metric = "ldos"
- self._after_training_metric = "ldos"
+ self._validation_metric = "ldos"
+ self._final_validation_metric = "ldos"
  # self.use_compression = False
  self.num_workers = 0
  self.use_shuffling_for_samplers = True
@@ -1116,9 +1122,9 @@ def __init__(self):
  self.logging_dir = "./mala_logging"
  self.logging_dir_append_date = True
  self.logger = None
- self.validation_metrics = ["ldos"]
- self.validate_on_training_data = False
- self.validate_every_n_epochs = 1
+ self.logging_metrics = ["ldos"]
+ self.log_metrics_on_train_set = False
+ self.logging_metrics_interval = 1
  self.inference_data_grid = [0, 0, 0]
  self.use_mixed_precision = False
  self.use_graphs = False
@@ -1137,60 +1143,75 @@ def _update_ddp(self, new_ddp):
  New DDP setting.
  """
  super(ParametersRunning, self)._update_ddp(new_ddp)
- self.during_training_metric = self.during_training_metric
- self.after_training_metric = self.after_training_metric
+ self.validation_metric = self.validation_metric
+ self.final_validation_metric = self.final_validation_metric

  @property
- def during_training_metric(self):
+ def validation_metric(self):
  """
- Control the metric used during training.
+ Control the metric used for validation.

- Metric for evaluated on the validation set during training.
+ Metric to be evaluated on the validation set during training.
  Default is "ldos", meaning that the regular loss on the LDOS will be
- used as a metric. Possible options are "band_energy" and
- "total_energy". For these, the band resp. total energy of the
- validation snapshots will be calculated and compared to the provided
- DFT results. Of these, the mean average error in eV/atom will be
- calculated.
- """
- return self._during_training_metric
+ used as a metric.
+
+ Possible options are:

- @during_training_metric.setter
- def during_training_metric(self, value):
+ - "ldos": MSE of the LDOS.
+ - "band_energy": Band energy.
+ - "band_energy_actual_fe": Band energy computed with ground truth Fermi energy.
+ - "total_energy": Total energy.
+ - "total_energy_actual_fe": Total energy computed with ground truth Fermi energy.
+ - "fermi_energy": Fermi energy.
+ - "density": Electron density.
+ - "density_relative": Electron density (MAPE).
+ - "dos": Density of states.
+ - "dos_relative": Density of states (MAPE).
+
+ The units for energy metrics are meV/atom.
+ Selected metric is evalauted after every epoch on the validation set.
+ The validation metric is used as a criterion for early stopping and also
+ for checkpointing the best model.
+ Note that evaluating the energy metrics takes considerably longer than LDOS
+ and therefore it is discouraged.
+ """
+ return self._validation_metric
+
+ @validation_metric.setter
+ def validation_metric(self, value):
  if value != "ldos":
  if self._configuration["ddp"]:
  raise Exception(
  "Currently, MALA can only operate with the "
  '"ldos" metric for ddp runs.'
  )
- if value not in self.validation_metrics:
- self.validation_metrics.append(value)
- self._during_training_metric = value
+ if value not in self.logging_metrics:
+ self.logging_metrics.append(value)
+ self._validation_metric = value

  @property
- def after_training_metric(self):
+ def final_validation_metric(self):
  """
- Get the metric used during training.
+ Metric for final model evaluation.

- Metric for evaluated on the validation and test set before and after
- training. Default is "LDOS", meaning that the regular loss on the LDOS
- will be used as a metric. Possible options are "band_energy" and
- "total_energy". For these, the band resp. total energy of the
- validation snapshots will be calculated and compared to the provided
- DFT results. Of these, the mean average error in eV/atom will be
- calculated.
+ This metric is evaluated on the validation set after training.
+ Available options are the same as for `validation_metric`.
+ Default is "LDOS", meaning that MSE of the LDOS
+ will be used as a metric.
+ The final validation metric is used as a target
+ for hyperparameter optimization.
  """
- return self._after_training_metric
+ return self._final_validation_metric

- @after_training_metric.setter
- def after_training_metric(self, value):
+ @final_validation_metric.setter
+ def final_validation_metric(self, value):
  if value != "ldos":
  if self._configuration["ddp"]:
  raise Exception(
  "Currently, MALA can only operate with the "
  '"ldos" metric for ddp runs.'
  )
- self._after_training_metric = value
+ self._final_validation_metric = value

  @property
  def use_graphs(self):
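The behaviour of the renamed `validation_metric` setter can be illustrated in isolation. The class below is a simplified stand-in, not MALA's `ParametersRunning`: the name `RunningSketch` and the plain `ddp` flag (replacing `self._configuration["ddp"]`) are assumptions, and because the diff view does not preserve indentation, the exact nesting of the `logging_metrics` check relative to the ddp check is also assumed.

```python
class RunningSketch:
    """Simplified stand-in for the renamed settings; not MALA's class."""

    def __init__(self, ddp=False):
        self.ddp = ddp  # hypothetical flag standing in for _configuration["ddp"]
        self.logging_metrics = ["ldos"]
        self._validation_metric = "ldos"

    @property
    def validation_metric(self):
        return self._validation_metric

    @validation_metric.setter
    def validation_metric(self, value):
        # Non-LDOS metrics are rejected for ddp runs, as in the diff.
        if value != "ldos" and self.ddp:
            raise Exception(
                "Currently, MALA can only operate with the "
                '"ldos" metric for ddp runs.'
            )
        # Make sure the chosen metric is also among the logged metrics.
        if value not in self.logging_metrics:
            self.logging_metrics.append(value)
        self._validation_metric = value


if __name__ == "__main__":
    running = RunningSketch()
    running.validation_metric = "band_energy"
    print(running.logging_metrics)
```

The point of the append is that the per-epoch error dictionary is built from `logging_metrics`, so the metric used for early stopping and checkpointing has to be in that list to be looked up later.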

mala/network/trainer.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -460,24 +460,24 @@ def train_network(self):
  total_batch_id += 1

  dataset_fractions = ["validation"]
- if self.parameters.validate_on_training_data:
+ if self.parameters.log_metrics_on_train_set:
  dataset_fractions.append("train")
- validation_metrics = ["ldos"]
+ logging_metrics = ["ldos"]
  if (
  epoch != 0
- and (epoch - 1) % self.parameters.validate_every_n_epochs == 0
+ and (epoch - 1) % self.parameters.logging_metrics_interval == 0
  ):
- validation_metrics = self.parameters.validation_metrics
- errors = self._validate_network(
- dataset_fractions, validation_metrics
+ logging_metrics = self.parameters.logging_metrics
+ errors = self._evaluate_metrics(
+ dataset_fractions, logging_metrics
  )
  for dataset_fraction in dataset_fractions:
  for metric in errors[dataset_fraction]:
  errors[dataset_fraction][metric] = np.mean(
  np.abs(errors[dataset_fraction][metric])
  )
  vloss = errors["validation"][
- self.parameters.during_training_metric
+ self.parameters.validation_metric
  ]
  if self.parameters_full.use_ddp:
  vloss = self.__average_validation(
@@ -589,17 +589,17 @@ def train_network(self):
  ############################
  # CALCULATE FINAL METRICS
  ############################
- if self.parameters.after_training_metric in errors["validation"]:
+ if self.parameters.final_validation_metric in errors["validation"]:
  self.final_validation_loss = errors["validation"][
- self.parameters.after_training_metric
+ self.parameters.final_validation_metric
  ]
  else:
- final_errors = self._validate_network(
- ["validation"], [self.parameters.after_training_metric]
+ final_errors = self._evaluate_metrics(
+ ["validation"], [self.parameters.final_validation_metric]
  )
  vloss = np.mean(
  final_errors["validation"][
- self.parameters.after_training_metric
+ self.parameters.final_validation_metric
  ]
  )
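The final-metrics hunk above follows a reuse-or-recompute pattern: if the final validation metric was already evaluated in the last epoch it is reused, otherwise only that one metric is evaluated. A minimal sketch of the pattern follows; the function name, the flat `errors` dictionary and the `evaluate` callable are hypothetical, and returning the averaged value as the final loss is an assumption, since the hunk ends before showing how `vloss` is stored.

```python
from statistics import fmean


def final_validation_loss(errors, metric, evaluate):
    """Return the final loss for `metric`, re-evaluating only if needed.

    `errors` maps metric names to already averaged validation errors;
    `evaluate` is a hypothetical callable returning raw per-snapshot errors.
    """
    if metric in errors:
        # The metric was already computed during the last epoch: reuse it.
        return errors[metric]
    # Otherwise evaluate just this one metric and average the result.
    return fmean(evaluate(metric))
```

For example, `final_validation_loss({"ldos": 0.5}, "ldos", evaluate)` never calls `evaluate`, which matters when the final metric is an expensive energy metric.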

@@ -619,7 +619,7 @@ def train_network(self):
  self._training_data_loaders.cleanup()
  self._validation_data_loaders.cleanup()

- def _validate_network(self, data_set_fractions, metrics):
+ def _evaluate_metrics(self, data_set_fractions, metrics):
  """
  Validate a network, using train or validation data.
@@ -770,11 +770,11 @@ def __calculate_validation_error_ldos_only(self, data_loaders):
  """
  Calculate the validation error for the LDOS only.

- This is a specialization of _validate_network that ONLY computes
+ This is a specialization of _evaluate_metrics that ONLY computes
  one error, namely the LDOS error. It is more efficient, especially
  in the distributed case, than the implementation called from
- _validate_network. As such it is mostly "legacy" code for now, until
- we adapt _validate_network.
+ _evaluate_metrics. As such it is mostly "legacy" code for now, until
+ we adapt _evaluate_metrics.

  Parameters
  ----------
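The condition guarding the full metric list in `train_network` (`epoch != 0 and (epoch - 1) % logging_metrics_interval == 0`) determines on which epochs all of `logging_metrics` are evaluated; on the remaining epochs only `["ldos"]` is. The sketch below simply enumerates integers against that condition. The helper names are invented, and whether MALA's epoch counter starts at 0 or 1 is not visible in the hunk, so the range used here is an assumption.

```python
def logs_full_metrics(epoch, interval):
    """Mirror the condition from the diff for a single epoch index."""
    return epoch != 0 and (epoch - 1) % interval == 0


def full_logging_epochs(num_epochs, interval):
    """List the epoch indices in range(num_epochs) that pass the condition."""
    return [e for e in range(num_epochs) if logs_full_metrics(e, interval)]


if __name__ == "__main__":
    # With an interval of 5, indices 1, 6 and 11 pass within the first 12.
    print(full_logging_epochs(12, 5))
```

With the default interval of 1, every nonzero epoch index passes, which matches the documented behaviour of logging the selected metrics every epoch.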
