Skip to content

Commit 4171bb2

Browse files
authored
Merge pull request #9 from Fung-Lab/feature/save_model_output
feature: save model checkpoint and model output
2 parents 65b7cbd + ca8bc2f commit 4171bb2

11 files changed

Lines changed: 215 additions & 47 deletions

File tree

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,14 @@ dmypy.json
164164
.DS_Store
165165

166166
# data
167-
./data/*
167+
data/**
168168

169169
# config
170170
./config/*
171171

172+
# results
173+
results/**
174+
172175
server/
173176

174177
main.py

.pre-commit-config.yaml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@ repos:
66
- id: end-of-file-fixer
77
- id: trailing-whitespace
88
# isort
9-
- repo: https://github.com/asottile/seed-isort-config
10-
rev: v2.2.0
11-
hooks:
12-
- id: seed-isort-config
13-
- repo: https://github.com/pre-commit/mirrors-isort
14-
rev: v5.10.1
9+
- repo: https://github.com/pycqa/isort
10+
rev: 5.10.1
1511
hooks:
1612
- id: isort
1713
args: ["--profile", "black"]

configs/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ trainer: property
33

44
task:
55
# run_mode: train
6-
name: "my_train_job"
6+
identifier: "my_train_job"
77

88
reprocess: False
99

configs/examples/DOS_STO.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
trainer: property
22

33
task:
4-
name: "my_train_job"
4+
identifier: "my_train_job"
55
reprocess: False
66
parallel: True
77
seed: 0
@@ -38,7 +38,7 @@ optim:
3838
scheduler_args: {"mode":"min", "factor":0.8, "patience":40, "min_lr":0.00001, "threshold":0.0002}
3939

4040
dataset:
41-
processed: False
41+
processed: True
4242
src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/STO_DOS_data/raw/"
4343
target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/STO_DOS_data/targets.csv"
4444
pt_path: "/global/cfs/projectdirs/m3641/Sarah/datasets/processed/STO_DOS_data/"

matdeeplearn/common/registry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class Registry:
5050
"model_name_mapping": {},
5151
"logger_name_mapping": {},
5252
"trainer_name_mapping": {},
53+
"loss_name_mapping": {},
5354
"state": {},
5455
}
5556

@@ -165,6 +166,28 @@ def wrap(func):
165166

166167
return wrap
167168

169+
@classmethod
170+
def register_loss(cls, name):
171+
r"""Register a loss class to registry with key 'name'
172+
173+
Args:
174+
name: Key with which the loss will be registered.
175+
176+
Usage::
177+
178+
from matdeeplearn.common.registry import registry
179+
180+
@registry.register_loss("dos_loss")
181+
class DOSLoss():
182+
...
183+
"""
184+
185+
def wrap(func):
186+
cls.mapping["loss_name_mapping"][name] = func
187+
return func
188+
189+
return wrap
190+
168191
@classmethod
169192
def register(cls, name, obj):
170193
r"""Register an item to registry with key 'name'
@@ -248,6 +271,10 @@ def get_logger_class(cls, name):
248271
def get_trainer_class(cls, name):
249272
return cls.get_class(name, "trainer_name_mapping")
250273

274+
@classmethod
275+
def get_loss_class(cls, name):
276+
return cls.get_class(name, "loss_name_mapping")
277+
251278
@classmethod
252279
def get(cls, name, default=None, no_warning=False):
253280
r"""Get an item from registry with key 'name'

matdeeplearn/models/base_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import warnings
2-
from abc import abstractmethod
2+
from abc import ABCMeta, abstractmethod
33

44
import torch
55
import torch.nn as nn
6-
from torch_geometric.nn import radius_graph
76
from torch_geometric.utils import dense_to_sparse
87

98
from matdeeplearn.preprocessor.helpers import (
@@ -14,12 +13,17 @@
1413
)
1514

1615

17-
class BaseModel(nn.Module):
16+
class BaseModel(nn.Module, metaclass=ABCMeta):
1817
def __init__(self, edge_steps: int = 50, self_loop: bool = True) -> None:
1918
super(BaseModel, self).__init__()
2019
self.edge_steps = edge_steps
2120
self.self_loop = self_loop
2221

22+
@property
23+
@abstractmethod
24+
def target_attr(self):
25+
"""Specifies the target attribute property for writing output to file"""
26+
2327
def __str__(self):
2428
# Prints model summary
2529
str_representation = "\n"

matdeeplearn/models/cgcnn.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ def __init__(
6060
self.gc_dim, self.post_fc_dim = dim1, dim1
6161

6262
# Determine output dimension length
63-
self.output_dim = 1 if data[0].y.ndim == 0 else len(data[0].y[0])
63+
if data[0][self.target_attr].ndim == 0:
64+
self.output_dim = 1
65+
else:
66+
self.output_dim = len(data[0][self.target_attr][0])
6467

6568
# setup layers
6669
self.pre_lin_list = self._setup_pre_gnn_layers()
@@ -75,6 +78,10 @@ def __init__(
7578
# workaround for doubled dimension by set2set; if late pooling not recommended to use set2set
7679
self.lin_out_2 = torch.nn.Linear(self.output_dim * 2, self.output_dim)
7780

81+
@property
82+
def target_attr(self):
83+
return "y"
84+
7885
def _setup_pre_gnn_layers(self):
7986
"""Sets up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)."""
8087
pre_lin_list = torch.nn.ModuleList()

matdeeplearn/models/dos_predict.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ def __init__(
4646
self.gc_dim, self.post_fc_dim = dim1, dim1
4747

4848
# Determine output dimension length
49-
self.output_dim = 1 if data[0].scaled.ndim == 0 else len(data[0].scaled[0])
49+
if data[0][self.target_attr].ndim == 0:
50+
self.output_dim = 1
51+
else:
52+
self.output_dim = len(data[0][self.target_attr][0])
5053

5154
# setup layers
5255
self.pre_lin_list = self._setup_pre_gnn_layers()
@@ -65,6 +68,10 @@ def __init__(
6568
Linear(self.dim2, 1),
6669
)
6770

71+
@property
72+
def target_attr(self):
73+
return "scaled"
74+
6875
def _setup_pre_gnn_layers(self):
6976
"""Sets up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)."""
7077
pre_lin_list = torch.nn.ModuleList()

matdeeplearn/modules/loss.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from torch import nn
55
from torch_geometric.data import Batch
66

7+
from matdeeplearn.common.registry import registry
78

9+
10+
@registry.register_loss("DOSLoss")
811
class DOSLoss(nn.Module):
912
def __init__(
1013
self,
@@ -47,34 +50,28 @@ def forward(self, predictions: tuple[torch.Tensor, torch.Tensor], target: Batch)
4750
def get_dos_features(self, x, dos):
4851
"""get dos features"""
4952
dos = torch.abs(dos)
53+
dos_sum = torch.sum(dos, axis=1)
5054

51-
center = torch.sum(x * dos, axis=1) / torch.sum(dos, axis=1)
55+
center = torch.sum(x * dos, axis=1) / dos_sum
5256
x_offset = (
5357
torch.repeat_interleave(x[np.newaxis, :], dos.shape[0], axis=0)
5458
- center[:, None]
5559
)
56-
width = torch.diagonal(torch.mm((x_offset**2), dos.T)) / torch.sum(
57-
dos, axis=1
58-
)
59-
skew = (
60-
torch.diagonal(torch.mm((x_offset**3), dos.T))
61-
/ torch.sum(dos, axis=1)
62-
/ width**1.5
63-
)
60+
width = torch.diagonal(torch.mm((x_offset**2), dos.T)) / dos_sum
61+
skew = torch.diagonal(torch.mm((x_offset**3), dos.T)) / dos_sum / width**1.5
6462
kurtosis = (
65-
torch.diagonal(torch.mm((x_offset**4), dos.T))
66-
/ torch.sum(dos, axis=1)
67-
/ width**2
63+
torch.diagonal(torch.mm((x_offset**4), dos.T)) / dos_sum / width**2
6864
)
6965

70-
# find zero index (fermi leve)
66+
# find zero index (fermi level)
7167
zero_index = torch.abs(x - 0).argmin().long()
7268
ef_states = torch.sum(dos[:, zero_index - 20 : zero_index + 20], axis=1) * abs(
7369
x[0] - x[1]
7470
)
7571
return torch.stack((center, width, skew, kurtosis, ef_states), axis=1)
7672

7773

74+
@registry.register_loss("TorchLossWrapper")
7875
class TorchLossWrapper(nn.Module):
7976
def __init__(self, loss_fn="l1_loss"):
8077
super().__init__()

matdeeplearn/trainers/base_trainer.py

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import copy
2+
import csv
13
import logging
4+
import os
25
from abc import ABC, abstractmethod
6+
from datetime import datetime
37

48
import torch
59
import torch.optim as optim
@@ -17,7 +21,6 @@
1721
from matdeeplearn.common.registry import registry
1822
from matdeeplearn.models.base_model import BaseModel
1923
from matdeeplearn.modules.evaluator import Evaluator
20-
from matdeeplearn.modules.loss import *
2124
from matdeeplearn.modules.scheduler import LRScheduler
2225

2326

@@ -35,6 +38,7 @@ def __init__(
3538
test_loader: DataLoader,
3639
loss: nn.Module,
3740
max_epochs: int,
41+
identifier: str = None,
3842
verbosity: int = None,
3943
):
4044
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -56,9 +60,20 @@ def __init__(
5660
self.step = 0
5761
self.metrics = {}
5862
self.epoch_time = None
63+
self.best_val_metric = 1e10
64+
self.best_model_state = None
5965

6066
self.evaluator = Evaluator()
6167

68+
self.run_dir = os.getcwd()
69+
70+
timestamp = torch.tensor(datetime.now().timestamp()).to(self.device)
71+
self.timestamp_id = datetime.fromtimestamp(timestamp.int()).strftime(
72+
"%Y-%m-%d-%H-%M-%S"
73+
)
74+
if identifier:
75+
self.timestamp_id = f"{self.timestamp_id}-{identifier}"
76+
6277
if self.train_verbosity:
6378
logging.info(
6479
f"GPU is available: {torch.cuda.is_available()}, Quantity: {torch.cuda.device_count()}"
@@ -94,6 +109,7 @@ def from_config(cls, config):
94109
loss = cls._load_loss(config["optim"]["loss"])
95110

96111
max_epochs = config["optim"]["max_epochs"]
112+
identifier = config["task"].get("identifier", None)
97113
verbosity = config["task"].get("verbosity", None)
98114

99115
return cls(
@@ -107,6 +123,7 @@ def from_config(cls, config):
107123
test_loader=test_loader,
108124
loss=loss,
109125
max_epochs=max_epochs,
126+
identifier=identifier,
110127
verbosity=verbosity,
111128
)
112129

@@ -180,15 +197,12 @@ def _load_scheduler(scheduler_config, optimizer):
180197
@staticmethod
181198
def _load_loss(loss_config):
182199
"""Loads the loss from either the TorchLossWrapper or custom loss functions in matdeeplearn"""
183-
try:
184-
loss_type = loss_config["loss_type"]
185-
# if there are other params for loss type, include in call
186-
if loss_config.get("loss_args"):
187-
return eval(loss_type)(**loss_config["loss_args"])
188-
else:
189-
return eval(loss_type)()
190-
except (AttributeError, NameError):
191-
raise NotImplementedError(f"Unknown loss class name: {loss_type}")
200+
loss_cls = registry.get_loss_class(loss_config["loss_type"])
201+
# if there are other params for loss type, include in call
202+
if loss_config.get("loss_args"):
203+
return loss_cls(**loss_config["loss_args"])
204+
else:
205+
return loss_cls()
192206

193207
@abstractmethod
194208
def _load_task(self):
@@ -205,3 +219,67 @@ def validate(self):
205219
@abstractmethod
206220
def predict(self):
207221
"""Implemented by derived classes."""
222+
223+
def update_best_model(self, val_metrics):
224+
"""Updates the best val metric and model, saves the best model, and saves the best model predictions"""
225+
self.best_val_metric = val_metrics[type(self.loss_fn).__name__]["metric"]
226+
self.best_model_state = copy.deepcopy(self.model.state_dict())
227+
228+
self.save_model("best_checkpoint.pt", val_metrics, False)
229+
230+
logging.debug(
231+
f"Saving prediction results for epoch {self.epoch} to: /results/{self.timestamp_id}/"
232+
)
233+
self.predict(self.train_loader, "train")
234+
self.predict(self.val_loader, "val")
235+
self.predict(self.test_loader, "test")
236+
237+
def save_model(self, checkpoint_file, val_metrics=None, training_state=True):
238+
"""Saves the model state dict"""
239+
240+
if training_state:
241+
state = {
242+
"epoch": self.epoch,
243+
"step": self.step,
244+
"state_dict": self.model.state_dict(),
245+
"optimizer": self.optimizer.state_dict(),
246+
"scheduler": self.scheduler.scheduler.state_dict(),
247+
"best_val_metric": self.best_val_metric,
248+
}
249+
else:
250+
state = {"state_dict": self.model.state_dict(), "val_metrics": val_metrics}
251+
252+
checkpoint_dir = os.path.join(
253+
self.run_dir, "results", self.timestamp_id, "checkpoint"
254+
)
255+
os.makedirs(checkpoint_dir, exist_ok=True)
256+
filename = os.path.join(checkpoint_dir, checkpoint_file)
257+
258+
torch.save(state, filename)
259+
return filename
260+
261+
def save_results(self, output, filename, node_level_predictions=False):
262+
results_path = os.path.join(self.run_dir, "results", self.timestamp_id)
263+
os.makedirs(results_path, exist_ok=True)
264+
filename = os.path.join(results_path, filename)
265+
shape = output.shape
266+
267+
id_headers = ["structure_id"]
268+
if node_level_predictions:
269+
id_headers += ["node_id"]
270+
num_cols = (shape[1] - len(id_headers)) // 2
271+
headers = id_headers + ["target"] * num_cols + ["prediction"] * num_cols
272+
273+
with open(filename, "w") as f:
274+
csvwriter = csv.writer(f)
275+
for i in range(0, len(output)):
276+
if i == 0:
277+
csvwriter.writerow(headers)
278+
elif i > 0:
279+
csvwriter.writerow(output[i - 1, :])
280+
return filename
281+
282+
def load_checkpoint(self):
283+
"""Loads the model from a checkpoint.pt file"""
284+
# TODO: implement this method
285+
pass

0 commit comments

Comments
 (0)