
Commit 963de8c

Adding ability to save model output predictions to csv
1 parent 0467914 commit 963de8c

11 files changed: 187 additions & 46 deletions


.gitignore

Lines changed: 4 additions & 1 deletion
@@ -163,11 +163,14 @@ dmypy.json
 .DS_Store

 # data
-./data/*
+data/**

 # config
 ./config/*

+# results
+results/**
+
 server/

 main.py

.pre-commit-config.yaml

Lines changed: 17 additions & 6 deletions
@@ -6,15 +6,26 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   # isort
-  - repo: https://github.com/asottile/seed-isort-config
-    rev: v2.2.0
-    hooks:
-      - id: seed-isort-config
-  - repo: https://github.com/pre-commit/mirrors-isort
-    rev: v5.10.1
+  #- repo: https://github.com/asottile/seed-isort-config
+  #  rev: v2.2.0
+  #  hooks:
+  #    - id: seed-isort-config
+  - repo: https://github.com/pycqa/isort
+    rev: 5.10.1
     hooks:
       - id: isort
         args: ["--profile", "black"]
+  #- repo: https://github.com/pycqa/isort
+  #  rev: 5.8.0
+  #  hooks:
+  #    - id: isort
+  #      name: isort (python)
+  #    - id: isort
+  #      name: isort (cython)
+  #      types: [cython]
+  #    - id: isort
+  #      name: isort (pyi)
+  #      types: [pyi]
   # flake8
   - repo: https://github.com/pycqa/flake8
     rev: 5.0.4

configs/config.yml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ trainer: property

 task:
   # run_mode: train
-  name: "my_train_job"
+  identifier: "my_train_job"

   reprocess: False

configs/examples/DOS_STO.yml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 trainer: property

 task:
-  name: "my_train_job"
+  identifier: "my_train_job"
   reprocess: False
   parallel: True
   seed: 0
@@ -38,7 +38,7 @@ optim:
   scheduler_args: {"mode":"min", "factor":0.8, "patience":40, "min_lr":0.00001, "threshold":0.0002}

 dataset:
-  processed: False
+  processed: True
   src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/STO_DOS_data/raw/"
   target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/STO_DOS_data/targets.csv"
   pt_path: "/global/cfs/projectdirs/m3641/Sarah/datasets/processed/STO_DOS_data/"

matdeeplearn/common/registry.py

Lines changed: 27 additions & 0 deletions
@@ -50,6 +50,7 @@ class Registry:
         "model_name_mapping": {},
         "logger_name_mapping": {},
         "trainer_name_mapping": {},
+        "loss_name_mapping": {},
         "state": {},
     }
@@ -165,6 +166,28 @@ def wrap(func):

         return wrap

+    @classmethod
+    def register_loss(cls, name):
+        r"""Register a loss class to registry with key 'name'
+
+        Args:
+            name: Key with which the loss will be registered.
+
+        Usage::
+
+            from matdeeplearn.common.registry import registry
+
+            @registry.register_loss("dos_loss")
+            class DOSLoss():
+                ...
+        """
+
+        def wrap(func):
+            cls.mapping["loss_name_mapping"][name] = func
+            return func
+
+        return wrap
+
     @classmethod
     def register(cls, name, obj):
         r"""Register an item to registry with key 'name'
@@ -248,6 +271,10 @@ def get_logger_class(cls, name):
     def get_trainer_class(cls, name):
         return cls.get_class(name, "trainer_name_mapping")

+    @classmethod
+    def get_loss_class(cls, name):
+        return cls.get_class(name, "loss_name_mapping")
+
     @classmethod
     def get(cls, name, default=None, no_warning=False):
         r"""Get an item from registry with key 'name'

matdeeplearn/models/base_model.py

Lines changed: 7 additions & 3 deletions
@@ -1,9 +1,8 @@
 import warnings
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod

 import torch
 import torch.nn as nn
-from torch_geometric.nn import radius_graph
 from torch_geometric.utils import dense_to_sparse

 from matdeeplearn.preprocessor.helpers import (
@@ -14,12 +13,17 @@
 )


-class BaseModel(nn.Module):
+class BaseModel(nn.Module, metaclass=ABCMeta):
     def __init__(self, edge_steps: int = 50, self_loop: bool = True) -> None:
         super(BaseModel, self).__init__()
         self.edge_steps = edge_steps
         self.self_loop = self_loop

+    @property
+    @abstractmethod
+    def target_attr(self):
+        """Specifies the target attribute property for writing output to file"""
+
     def __str__(self):
         # Prints model summary
         str_representation = "\n"
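
With BaseModel now an abstract base class, every concrete model must declare which attribute on a Data object holds its training target. A minimal sketch of a conforming subclass; MyModel is hypothetical and omits the rest of the model definition:

from matdeeplearn.models.base_model import BaseModel


class MyModel(BaseModel):
    @property
    def target_attr(self):
        # Name of the attribute on each Data object that holds the ground truth,
        # used when writing targets and predictions out to csv.
        return "y"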

matdeeplearn/models/cgcnn.py

Lines changed: 8 additions & 1 deletion
@@ -60,7 +60,10 @@ def __init__(
         self.gc_dim, self.post_fc_dim = dim1, dim1

         # Determine output dimension length
-        self.output_dim = 1 if data[0].y.ndim == 0 else len(data[0].y[0])
+        if data[0][self.target_attr].ndim == 0:
+            self.output_dim = 1
+        else:
+            self.output_dim = len(data[0][self.target_attr][0])

         # setup layers
         self.pre_lin_list = self._setup_pre_gnn_layers()
@@ -75,6 +78,10 @@ def __init__(
         # workaround for doubled dimension by set2set; if late pooling not recommended to use set2set
         self.lin_out_2 = torch.nn.Linear(self.output_dim * 2, self.output_dim)

+    @property
+    def target_attr(self):
+        return "y"
+
     def _setup_pre_gnn_layers(self):
         """Sets up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)."""
         pre_lin_list = torch.nn.ModuleList()
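
The output-dimension logic above now keys off target_attr instead of hard-coding data[0].y. A rough illustration with toy torch_geometric Data objects; the shapes below are assumptions chosen for the example:

import torch
from torch_geometric.data import Data

scalar = Data(y=torch.tensor(1.23))  # y.ndim == 0   -> output_dim = 1
vector = Data(y=torch.zeros(1, 5))   # len(y[0]) == 5 -> output_dim = 5

for data in (scalar, vector):
    target = data["y"]
    output_dim = 1 if target.ndim == 0 else len(target[0])
    print(output_dim)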

matdeeplearn/models/dos_predict.py

Lines changed: 8 additions & 1 deletion
@@ -46,7 +46,10 @@ def __init__(
         self.gc_dim, self.post_fc_dim = dim1, dim1

         # Determine output dimension length
-        self.output_dim = 1 if data[0].scaled.ndim == 0 else len(data[0].scaled[0])
+        if data[0][self.target_attr].ndim == 0:
+            self.output_dim = 1
+        else:
+            self.output_dim = len(data[0][self.target_attr][0])

         # setup layers
         self.pre_lin_list = self._setup_pre_gnn_layers()
@@ -65,6 +68,10 @@ def __init__(
             Linear(self.dim2, 1),
         )

+    @property
+    def target_attr(self):
+        return "scaled"
+
     def _setup_pre_gnn_layers(self):
         """Sets up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)."""
         pre_lin_list = torch.nn.ModuleList()

matdeeplearn/modules/loss.py

Lines changed: 10 additions & 13 deletions
@@ -4,7 +4,10 @@
 from torch import nn
 from torch_geometric.data import Batch

+from matdeeplearn.common.registry import registry

+
+@registry.register_loss("DOSLoss")
 class DOSLoss(nn.Module):
     def __init__(
         self,
@@ -47,34 +50,28 @@ def forward(self, predictions: tuple[torch.Tensor, torch.Tensor], target: Batch)
     def get_dos_features(self, x, dos):
         """get dos features"""
         dos = torch.abs(dos)
+        dos_sum = torch.sum(dos, axis=1)

-        center = torch.sum(x * dos, axis=1) / torch.sum(dos, axis=1)
+        center = torch.sum(x * dos, axis=1) / dos_sum
         x_offset = (
             torch.repeat_interleave(x[np.newaxis, :], dos.shape[0], axis=0)
             - center[:, None]
         )
-        width = torch.diagonal(torch.mm((x_offset**2), dos.T)) / torch.sum(
-            dos, axis=1
-        )
-        skew = (
-            torch.diagonal(torch.mm((x_offset**3), dos.T))
-            / torch.sum(dos, axis=1)
-            / width**1.5
-        )
+        width = torch.diagonal(torch.mm((x_offset**2), dos.T)) / dos_sum
+        skew = torch.diagonal(torch.mm((x_offset**3), dos.T)) / dos_sum / width**1.5
         kurtosis = (
-            torch.diagonal(torch.mm((x_offset**4), dos.T))
-            / torch.sum(dos, axis=1)
-            / width**2
+            torch.diagonal(torch.mm((x_offset**4), dos.T)) / dos_sum / width**2
         )

-        # find zero index (fermi leve)
+        # find zero index (fermi level)
         zero_index = torch.abs(x - 0).argmin().long()
         ef_states = torch.sum(dos[:, zero_index - 20 : zero_index + 20], axis=1) * abs(
             x[0] - x[1]
         )
         return torch.stack((center, width, skew, kurtosis, ef_states), axis=1)


+@registry.register_loss("TorchLossWrapper")
 class TorchLossWrapper(nn.Module):
     def __init__(self, loss_fn="l1_loss"):
         super().__init__()
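
With both loss classes registered, a loss can be selected purely from config. A sketch of the config shape consumed by the reworked BaseTrainer._load_loss; the keys and names come from this diff, the values are just examples:

import matdeeplearn.modules.loss  # noqa: F401  importing runs the @registry.register_loss decorators

from matdeeplearn.common.registry import registry

# Mirrors config["optim"]["loss"] as read by BaseTrainer._load_loss
loss_config = {"loss_type": "TorchLossWrapper", "loss_args": {"loss_fn": "l1_loss"}}

loss_cls = registry.get_loss_class(loss_config["loss_type"])
loss_fn = loss_cls(**loss_config.get("loss_args", {}))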

matdeeplearn/trainers/base_trainer.py

Lines changed: 42 additions & 10 deletions
@@ -1,5 +1,8 @@
+import csv
 import logging
+import os
 from abc import ABC, abstractmethod
+from datetime import datetime

 import torch
 import torch.optim as optim
@@ -17,7 +20,6 @@
 from matdeeplearn.common.registry import registry
 from matdeeplearn.models.base_model import BaseModel
 from matdeeplearn.modules.evaluator import Evaluator
-from matdeeplearn.modules.loss import *
 from matdeeplearn.modules.scheduler import LRScheduler


@@ -35,6 +37,7 @@ def __init__(
         test_loader: DataLoader,
         loss: nn.Module,
         max_epochs: int,
+        identifier: str = None,
         verbosity: int = None,
     ):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -56,9 +59,19 @@ def __init__(
         self.step = 0
         self.metrics = {}
         self.epoch_time = None
+        self.best_val_metric = 1e10

         self.evaluator = Evaluator()

+        self.run_dir = os.getcwd()
+
+        timestamp = torch.tensor(datetime.now().timestamp()).to(self.device)
+        self.timestamp_id = datetime.fromtimestamp(timestamp.int()).strftime(
+            "%Y-%m-%d-%H-%M-%S"
+        )
+        if identifier:
+            self.timestamp_id = f"{self.timestamp_id}-{identifier}"
+
         if self.train_verbosity:
             logging.info(
                 f"GPU is available: {torch.cuda.is_available()}, Quantity: {torch.cuda.device_count()}"
@@ -94,6 +107,7 @@ def from_config(cls, config):
         loss = cls._load_loss(config["optim"]["loss"])

         max_epochs = config["optim"]["max_epochs"]
+        identifier = config["task"].get("identifier", None)
         verbosity = config["task"].get("verbosity", None)

         return cls(
@@ -107,6 +121,7 @@ def from_config(cls, config):
             test_loader=test_loader,
             loss=loss,
             max_epochs=max_epochs,
+            identifier=identifier,
             verbosity=verbosity,
         )

@@ -180,15 +195,12 @@ def _load_scheduler(scheduler_config, optimizer)
     @staticmethod
     def _load_loss(loss_config):
         """Loads the loss from either the TorchLossWrapper or custom loss functions in matdeeplearn"""
-        try:
-            loss_type = loss_config["loss_type"]
-            # if there are other params for loss type, include in call
-            if loss_config.get("loss_args"):
-                return eval(loss_type)(**loss_config["loss_args"])
-            else:
-                return eval(loss_type)()
-        except (AttributeError, NameError):
-            raise NotImplementedError(f"Unknown loss class name: {loss_type}")
+        loss_cls = registry.get_loss_class(loss_config["loss_type"])
+        # if there are other params for loss type, include in call
+        if loss_config.get("loss_args"):
+            return loss_cls(**loss_config["loss_args"])
+        else:
+            return loss_cls()

     @abstractmethod
     def _load_task(self):
@@ -205,3 +217,23 @@ def validate(self):
     @abstractmethod
     def predict(self):
         """Implemented by derived classes."""
+
+    def save_results(self, output, filename, node_level_predictions=False):
+        results_path = os.path.join(self.run_dir, "results", self.timestamp_id)
+        os.makedirs(results_path, exist_ok=True)
+        filename = os.path.join(results_path, filename)
+        shape = output.shape
+
+        id_headers = ["structure_id"]
+        if node_level_predictions:
+            id_headers += ["node_id"]
+        num_cols = (shape[1] - len(id_headers)) // 2
+        headers = id_headers + ["target"] * num_cols + ["prediction"] * num_cols
+
+        with open(filename, "w") as f:
+            csvwriter = csv.writer(f)
+            for i in range(0, len(output) + 1):
+                if i == 0:
+                    csvwriter.writerow(headers)
+                elif i > 0:
+                    csvwriter.writerow(output[i - 1, :])
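
save_results expects a 2-D array whose columns are the id column(s) followed by an equal number of target and prediction columns, which is how num_cols is derived above. A hedged sketch of how a derived trainer might assemble and save such an array; the ids and values below are made up:

import numpy as np

# One id column, then num_cols target columns, then num_cols prediction columns,
# matching the header logic: "structure_id", "target" * n, "prediction" * n.
structure_ids = np.array([["mp-1"], ["mp-2"], ["mp-3"]], dtype=object)
targets = np.array([[0.12], [0.34], [0.56]])
predictions = np.array([[0.10], [0.31], [0.60]])

output = np.column_stack((structure_ids, targets, predictions))
# trainer.save_results(output, "train_predictions.csv")
# -> writes <run_dir>/results/<timestamp_id>/train_predictions.csv with columns
#    structure_id, target, prediction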
