Skip to content

Commit ca8bc2f

Browse files
committed
Adding ability to save model checkpoint and best model
1 parent 963de8c commit ca8bc2f

3 files changed

Lines changed: 54 additions & 27 deletions

File tree

.pre-commit-config.yaml

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,11 @@ repos:
66
- id: end-of-file-fixer
77
- id: trailing-whitespace
88
# isort
9-
#- repo: https://github.com/asottile/seed-isort-config
10-
# rev: v2.2.0
11-
# hooks:
12-
# - id: seed-isort-config
139
- repo: https://github.com/pycqa/isort
1410
rev: 5.10.1
1511
hooks:
1612
- id: isort
1713
args: ["--profile", "black"]
18-
#- repo: https://github.com/pycqa/isort
19-
# rev: 5.8.0
20-
# hooks:
21-
# - id: isort
22-
# name: isort (python)
23-
# - id: isort
24-
# name: isort (cython)
25-
# types: [cython]
26-
# - id: isort
27-
# name: isort (pyi)
28-
# types: [pyi]
2914
# flake8
3015
- repo: https://github.com/pycqa/flake8
3116
rev: 5.0.4

matdeeplearn/trainers/base_trainer.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import csv
23
import logging
34
import os
@@ -60,6 +61,7 @@ def __init__(
6061
self.metrics = {}
6162
self.epoch_time = None
6263
self.best_val_metric = 1e10
64+
self.best_model_state = None
6365

6466
self.evaluator = Evaluator()
6567

@@ -218,6 +220,44 @@ def validate(self):
218220
def predict(self):
219221
"""Implemented by derived classes."""
220222

223+
def update_best_model(self, val_metrics):
224+
"""Updates the best val metric and model, saves the best model, and saves the best model predictions"""
225+
self.best_val_metric = val_metrics[type(self.loss_fn).__name__]["metric"]
226+
self.best_model_state = copy.deepcopy(self.model.state_dict())
227+
228+
self.save_model("best_checkpoint.pt", val_metrics, False)
229+
230+
logging.debug(
231+
f"Saving prediction results for epoch {self.epoch} to: /results/{self.timestamp_id}/"
232+
)
233+
self.predict(self.train_loader, "train")
234+
self.predict(self.val_loader, "val")
235+
self.predict(self.test_loader, "test")
236+
237+
def save_model(self, checkpoint_file, val_metrics=None, training_state=True):
238+
"""Saves the model state dict"""
239+
240+
if training_state:
241+
state = {
242+
"epoch": self.epoch,
243+
"step": self.step,
244+
"state_dict": self.model.state_dict(),
245+
"optimizer": self.optimizer.state_dict(),
246+
"scheduler": self.scheduler.scheduler.state_dict(),
247+
"best_val_metric": self.best_val_metric,
248+
}
249+
else:
250+
state = {"state_dict": self.model.state_dict(), "val_metrics": val_metrics}
251+
252+
checkpoint_dir = os.path.join(
253+
self.run_dir, "results", self.timestamp_id, "checkpoint"
254+
)
255+
os.makedirs(checkpoint_dir, exist_ok=True)
256+
filename = os.path.join(checkpoint_dir, checkpoint_file)
257+
258+
torch.save(state, filename)
259+
return filename
260+
221261
def save_results(self, output, filename, node_level_predictions=False):
222262
results_path = os.path.join(self.run_dir, "results", self.timestamp_id)
223263
os.makedirs(results_path, exist_ok=True)
@@ -237,3 +277,9 @@ def save_results(self, output, filename, node_level_predictions=False):
237277
csvwriter.writerow(headers)
238278
elif i > 0:
239279
csvwriter.writerow(output[i - 1, :])
280+
return filename
281+
282+
def load_checkpoint(self):
283+
"""Loads the model from a checkpoint.pt file"""
284+
# TODO: implement this method
285+
pass

matdeeplearn/trainers/property_trainer.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,11 @@ def train(self):
8181
_metrics = self._compute_metrics(out, batch, _metrics)
8282
self.metrics = self.evaluator.update("loss", loss.item(), _metrics)
8383

84+
# TODO: could add param to eval and save on increments instead of every time
85+
# Save current model
86+
self.save_model(checkpoint_file="checkpoint.pt", training_state=True)
87+
8488
# Evaluate on validation set if it exists
85-
# TODO: could add param to eval on increments instead of every time
8689
if self.val_loader:
8790
val_metrics = self.validate()
8891

@@ -92,25 +95,18 @@ def train(self):
9295
if epoch % self.train_verbosity == 0:
9396
self._log_metrics(val_metrics)
9497

95-
# update best_val_metric and save predicted outputs for train, test, val
96-
# TODO save checkpoint if metric is best so far
98+
# Update best val metric and model, and save best model and predicted outputs
9799
if (
98100
val_metrics[type(self.loss_fn).__name__]["metric"]
99101
< self.best_val_metric
100102
):
101-
self.best_val_metric = val_metrics[type(self.loss_fn).__name__][
102-
"metric"
103-
]
104-
logging.debug(
105-
f"Saving prediction results for epoch {epoch} to: /results/{self.timestamp_id}/"
106-
)
107-
self.predict(self.train_loader, "train")
108-
self.predict(self.val_loader, "val")
109-
self.predict(self.test_loader, "test")
103+
self.update_best_model(val_metrics)
110104

111105
# step scheduler, using validation error
112106
self._scheduler_step()
113107

108+
return self.best_model_state
109+
114110
def validate(self, split="val"):
115111
self.model.eval()
116112
evaluator, metrics = Evaluator(), {}

0 commit comments

Comments
 (0)