Skip to content

Commit c40c57e

Browse files
author
nerkulec
committed
Rename validation and logging related parameters (refs mala-project#636)
1 parent 5f51c1d commit c40c57e

6 files changed

Lines changed: 49 additions & 49 deletions

File tree

docs/source/advanced_usage/hyperparameters.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ a physical validation metric such as
114114

115115
.. code-block:: python
116116
117-
parameters.running.after_training_metric = "band_energy"
117+
parameters.running.final_validation_metric = "band_energy"
118118
119119
Advanced optimization algorithms
120120
********************************

docs/source/advanced_usage/trainingmodel.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ is directly outputted by MALA. By default, this validation loss gives the
7171
mean squared error between LDOS prediction and actual value. From a purely
7272
ML point of view, this is fine; however, the correctness of the LDOS itself
7373
does not hold much physical virtue. Thus, MALA implements physical validation
74-
metrics to be accessed before and after the training routine.
74+
metrics which can be evaluated, for example, after the training.
7575

7676
Specifically, when setting
7777

7878
.. code-block:: python
7979
80-
parameters.running.after_training_metric = "band_energy"
80+
parameters.running.final_validation_metric = "band_energy"
8181
8282
the error in the band energy between actual and predicted LDOS will be
8383
calculated and printed before and after network training (in meV/atom).
@@ -231,7 +231,7 @@ MALA logging data. You can also select which metrics to record via
231231

232232
.. code-block:: python
233233
234-
parameters.validation_metrics = ["ldos", "dos", "density", "total_energy"]
234+
parameters.logging_metrics = ["ldos", "dos", "density", "total_energy"]
235235
236236
Full list of available metrics:
237237
- "ldos": MSE of the LDOS.
@@ -249,14 +249,14 @@ To save time and resources you can specify the logging interval via
249249

250250
.. code-block:: python
251251
252-
parameters.running.validate_every_n_epochs = 10
252+
parameters.running.logging_metrics_freq = 10
253253
254254
If you want to monitor the degree to which the model overfits to the training data,
255255
you can use the option
256256

257257
.. code-block:: python
258258
259-
parameters.running.validate_on_training_data = True
259+
parameters.running.log_metrics_on_train_set = True
260260
261261
MALA will evaluate the validation metrics on the training set as well as the validation set.
262262

examples/advanced/ex03_tensor_board.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
# files into.
2828
parameters.running.logger = "tensorboard"
2929
parameters.running.logging_dir = "mala_vis"
30-
parameters.running.validation_metrics = ["ldos", "band_energy"]
31-
parameters.running.validate_every_n_epochs = 5
30+
parameters.running.logging_metrics = ["ldos", "band_energy"]
31+
parameters.running.logging_metrics_freq = 5
3232

3333
data_handler = mala.DataHandler(parameters)
3434
data_handler.add_snapshot(

examples/advanced/ex06_distributed_hyperparameter_optimization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@
4242
parameters.hyperparameters.number_training_per_trial = 3
4343

4444
# Hyperparameter optimization can be further refined by using ensemble training
45-
# at each step and by using a different metric then the validation loss
45+
# at each step and by using a different metric than the test metric
4646
# (e.g. the band energy). It is recommended not to use the ensemble training
4747
# method in Single-GPU use, as it naturally drastically decreases performance.
48-
# For this small example setting, using the band energy as the after training
48+
# For this small example setting, using the band energy as the test
4949
# metric is not recommended, since the small data size makes
5050
# an accurate hyperparameter search difficult. For larger systems, enabling
5151
# this option is recommended.
52-
# parameters.running.after_training_metric = "band_energy"
52+
# parameters.running.final_validation_metric = "band_energy"
5353

5454
data_handler = mala.DataHandler(parameters)
5555

mala/common/parameters.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,7 @@ class ParametersRunning(ParametersBase):
10381038
- "tensorboard": Tensorboard logger.
10391039
- "wandb": Weights and Biases logger.
10401040
1041-
validation_metrics : list
1041+
logging_metrics : list
10421042
List of metrics to be used for validation. Default is ["ldos"].
10431043
Possible options are:
10441044
@@ -1053,10 +1053,10 @@ class ParametersRunning(ParametersBase):
10531053
- "dos": Density of states.
10541054
- "dos_relative": Density of states (MAPE).
10551055
1056-
validate_on_training_data : bool
1056+
log_metrics_on_train_set : bool
10571057
Whether to validate on the training data as well. Default is False.
10581058
1059-
validate_every_n_epochs : int
1059+
logging_metrics_freq : int
10601060
Determines how often validation is performed. Default is 1.
10611061
10621062
training_log_interval : int
@@ -1103,8 +1103,8 @@ def __init__(self):
11031103
self.learning_rate_scheduler = None
11041104
self.learning_rate_decay = 0.1
11051105
self.learning_rate_patience = 0
1106-
self._during_training_metric = "ldos"
1107-
self._after_training_metric = "ldos"
1106+
self._validation_metric = "ldos"
1107+
self._final_validation_metric = "ldos"
11081108
# self.use_compression = False
11091109
self.num_workers = 0
11101110
self.use_shuffling_for_samplers = True
@@ -1116,9 +1116,9 @@ def __init__(self):
11161116
self.logging_dir = "./mala_logging"
11171117
self.logging_dir_append_date = True
11181118
self.logger = None
1119-
self.validation_metrics = ["ldos"]
1120-
self.validate_on_training_data = False
1121-
self.validate_every_n_epochs = 1
1119+
self.logging_metrics = ["ldos"]
1120+
self.log_metrics_on_train_set = False
1121+
self.logging_metrics_freq = 1
11221122
self.inference_data_grid = [0, 0, 0]
11231123
self.use_mixed_precision = False
11241124
self.use_graphs = False
@@ -1137,11 +1137,11 @@ def _update_ddp(self, new_ddp):
11371137
New DDP setting.
11381138
"""
11391139
super(ParametersRunning, self)._update_ddp(new_ddp)
1140-
self.during_training_metric = self.during_training_metric
1141-
self.after_training_metric = self.after_training_metric
1140+
self.validation_metric = self.validation_metric
1141+
self.final_validation_metric = self.final_validation_metric
11421142

11431143
@property
1144-
def during_training_metric(self):
1144+
def validation_metric(self):
11451145
"""
11461146
Control the metric used during training.
11471147
@@ -1153,22 +1153,22 @@ def during_training_metric(self):
11531153
DFT results. Of these, the mean average error in eV/atom will be
11541154
calculated.
11551155
"""
1156-
return self._during_training_metric
1156+
return self._validation_metric
11571157

1158-
@during_training_metric.setter
1159-
def during_training_metric(self, value):
1158+
@validation_metric.setter
1159+
def validation_metric(self, value):
11601160
if value != "ldos":
11611161
if self._configuration["ddp"]:
11621162
raise Exception(
11631163
"Currently, MALA can only operate with the "
11641164
'"ldos" metric for ddp runs.'
11651165
)
1166-
if value not in self.validation_metrics:
1167-
self.validation_metrics.append(value)
1168-
self._during_training_metric = value
1166+
if value not in self.logging_metrics:
1167+
self.logging_metrics.append(value)
1168+
self._validation_metric = value
11691169

11701170
@property
1171-
def after_training_metric(self):
1171+
def final_validation_metric(self):
11721172
"""
11731173
Get the metric used during training.
11741174
@@ -1180,17 +1180,17 @@ def after_training_metric(self):
11801180
DFT results. Of these, the mean average error in eV/atom will be
11811181
calculated.
11821182
"""
1183-
return self._after_training_metric
1183+
return self._final_validation_metric
11841184

1185-
@after_training_metric.setter
1186-
def after_training_metric(self, value):
1185+
@final_validation_metric.setter
1186+
def final_validation_metric(self, value):
11871187
if value != "ldos":
11881188
if self._configuration["ddp"]:
11891189
raise Exception(
11901190
"Currently, MALA can only operate with the "
11911191
'"ldos" metric for ddp runs.'
11921192
)
1193-
self._after_training_metric = value
1193+
self._final_validation_metric = value
11941194

11951195
@property
11961196
def use_graphs(self):

mala/network/trainer.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -460,24 +460,24 @@ def train_network(self):
460460
total_batch_id += 1
461461

462462
dataset_fractions = ["validation"]
463-
if self.parameters.validate_on_training_data:
463+
if self.parameters.log_metrics_on_train_set:
464464
dataset_fractions.append("train")
465-
validation_metrics = ["ldos"]
465+
logging_metrics = ["ldos"]
466466
if (
467467
epoch != 0
468-
and (epoch - 1) % self.parameters.validate_every_n_epochs == 0
468+
and (epoch - 1) % self.parameters.logging_metrics_freq == 0
469469
):
470-
validation_metrics = self.parameters.validation_metrics
471-
errors = self._validate_network(
472-
dataset_fractions, validation_metrics
470+
logging_metrics = self.parameters.logging_metrics
471+
errors = self._evaluate_metrics(
472+
dataset_fractions, logging_metrics
473473
)
474474
for dataset_fraction in dataset_fractions:
475475
for metric in errors[dataset_fraction]:
476476
errors[dataset_fraction][metric] = np.mean(
477477
np.abs(errors[dataset_fraction][metric])
478478
)
479479
vloss = errors["validation"][
480-
self.parameters.during_training_metric
480+
self.parameters.validation_metric
481481
]
482482
if self.parameters_full.use_ddp:
483483
vloss = self.__average_validation(
@@ -589,17 +589,17 @@ def train_network(self):
589589
############################
590590
# CALCULATE FINAL METRICS
591591
############################
592-
if self.parameters.after_training_metric in errors["validation"]:
592+
if self.parameters.final_validation_metric in errors["validation"]:
593593
self.final_validation_loss = errors["validation"][
594-
self.parameters.after_training_metric
594+
self.parameters.final_validation_metric
595595
]
596596
else:
597-
final_errors = self._validate_network(
598-
["validation"], [self.parameters.after_training_metric]
597+
final_errors = self._evaluate_metrics(
598+
["validation"], [self.parameters.final_validation_metric]
599599
)
600600
vloss = np.mean(
601601
final_errors["validation"][
602-
self.parameters.after_training_metric
602+
self.parameters.final_validation_metric
603603
]
604604
)
605605

@@ -619,7 +619,7 @@ def train_network(self):
619619
self._training_data_loaders.cleanup()
620620
self._validation_data_loaders.cleanup()
621621

622-
def _validate_network(self, data_set_fractions, metrics):
622+
def _evaluate_metrics(self, data_set_fractions, metrics):
623623
"""
624624
Validate a network, using train or validation data.
625625
@@ -770,11 +770,11 @@ def __calculate_validation_error_ldos_only(self, data_loaders):
770770
"""
771771
Calculate the validation error for the LDOS only.
772772
773-
This is a specialization of _validate_network that ONLY computes
773+
This is a specialization of _evaluate_metrics that ONLY computes
774774
one error, namely the LDOS error. It is more efficient, especially
775775
in the distributed case, than the implementation called from
776-
_validate_network. As such it is mostly "legacy" code for now, until
777-
we adapt _validate_network.
776+
_evaluate_metrics. As such it is mostly "legacy" code for now, until
777+
we adapt _evaluate_metrics.
778778
779779
Parameters
780780
----------

0 commit comments

Comments
 (0)