
Commit 2168d2d

Fix transform bugs
1 parent fbf36fc commit 2168d2d

16 files changed

Lines changed: 2884 additions & 47 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -178,3 +178,5 @@ main.py
 
 test*.py
 test*.ipynb
+
+!testing/*

configs/config_alignn.yml

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+
+trainer: property
+
+task:
+  # run_mode: train
+  name: "alignn_first_training"
+
+  reprocess: "False"
+
+  parallel: "True"
+  seed: 0
+  #seed=0 means random initalization
+
+  write_output: "True"
+  parallel: "True"
+  #Training print out frequency (print per n number of epochs)
+  verbosity: 1
+
+  #Ratios for train/val/test split out of a total of 1
+  train_ratio: 0.85
+  val_ratio: 0.05
+  test_ratio: 0.10
+
+model:
+  name: "ALIGNN_GRAPHITE"
+  load_model: "False"
+  save_model: "True"
+  model_path: "/global/cfs/projectdirs/m3641/Sidharth/MatDeepLearn_dev/testing/models/alignn_model_t1.pth"
+
+  #model attributes
+  alignn_layers: 4
+  gcn_layers: 4
+  atom_input_features: 114
+  edge_input_features: 50
+  triplet_input_features: 40
+  embedding_features: 32
+  hidden_features: 64
+  output_features: 1
+  # min_edge_distance: 0.0,
+  # max_edge_distance: 8.0,
+  # min_angle: 0.0,
+  # max_angle: torch.acos(torch.zeros(1)).item() * 2,
+  link: "identity"
+
+optim:
+  max_epochs: 300
+  lr: 0.001
+  #Loss functions (from pytorch) examples: l1_loss, mse_loss, binary_cross_entropy
+  loss_fn: "mse_loss"
+  batch_size: 64
+
+  optimizer:
+    optimizer_type: "AdamW"
+    optimizer_args: {"weight_decay": 0.00001}
+  scheduler:
+    scheduler_type: "OneCycleLR"
+    # Look further into steps per epoch, for now hardcoded calculation from paper
+    scheduler_args: {"max_lr": 0.001, "epochs": 300, "steps_per_epoch": 1}
+
+dataset:
+  processed: True # if False, need to preprocessor data and generate .pt file
+  # Whether to use "inmemory" or "large" format for pytorch-geometric dataset. Reccomend inmemory unless the dataset is too large
+  # dataset_type: "inmemory"
+  #Path to data files
+  src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_69K/raw/"
+  #Path to target file within data_path
+  target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_69K/targets.csv"
+  #Path to save processed data.pt file (a directory path not filepath)
+  pt_path: "/global/cfs/projectdirs/m3641/Sidharth/datasets/MP_data_69K/"
+  transforms:
+    - NumNodeTransform
+    - LineGraphMod
+    - ToFloat
+  #Format of data files (limit to those supported by ASE)
+  data_format: "json"
+  #Method of obtaining atom dictionary: available:(one-hot)
+  node_representation: "onehot"
+  #Print out processing info
+  verbose: "True"
+
+  #Loading dataset params
+  #Index of target column in targets.csv
+  target_index: 0
+
+  #graph specific settings
+  cutoff_radius : 8.0
+  n_neighbors : 12
+  edge_steps : 50
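
Note: the new dataset.transforms list names the transform classes registered in matdeeplearn/preprocessor/transforms.py (see that file's diff below). As a minimal sketch, assuming the reconstructed YAML nesting above and plain PyYAML loading (this is not the project's actual CLI entry point), the list would be read like this:

import yaml

# Hypothetical local path to the config added in this commit.
with open("configs/config_alignn.yml") as f:
    config = yaml.safe_load(f)

transform_names = config["dataset"]["transforms"]
print(transform_names)  # expected: ['NumNodeTransform', 'LineGraphMod', 'ToFloat']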

matdeeplearn/common/data.py

Lines changed: 12 additions & 8 deletions
@@ -3,9 +3,10 @@
 import torch
 from torch.utils.data import random_split
 from torch_geometric.loader import DataLoader
+from torch_geometric.transforms import Compose
 
 from matdeeplearn.preprocessor.datasets import LargeStructureDataset, StructureDataset
-from matdeeplearn.preprocessor.transforms import GetY
+from matdeeplearn.preprocessor.transforms import TRANSFORM_REGISTRY, GetY
 
 
 # train test split
@@ -58,7 +59,7 @@ def dataset_split(
 
 
 def get_dataset(
-    data_path, target_index: int = 0, transform_type="GetY", large_dataset=False
+    data_path, target_index: int = 0, transform_list=[], large_dataset=False
 ):
     """
     get dataset according to data_path
@@ -78,22 +79,25 @@
         particular dataset, thus we need to index one column for
         the current run/experiment
 
-    transform_type: transformation function/class to be applied
+    transform_list: transformation function/classes to be applied
     """
+
+    transforms = [GetY(index=target_index)]
 
     # set transform method
-    if transform_type == "GetY":
-        T = GetY
-    else:
-        raise ValueError("No such transform found for {transform}")
+    for transform in transform_list:
+        if transform in TRANSFORM_REGISTRY:
+            transforms.append(TRANSFORM_REGISTRY[transform]())
+        else:
+            raise ValueError("No such transform found for {transform}")
 
     # check if large dataset is needed
     if large_dataset:
         Dataset = LargeStructureDataset
     else:
         Dataset = StructureDataset
 
-    transform = T(index=target_index)
+    transform = Compose(transforms)
 
     return Dataset(data_path, processed_data_path="", transform=transform)
 
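
Note: get_dataset now builds a torch_geometric.transforms.Compose pipeline in which GetY always runs first and each name in transform_list is looked up in TRANSFORM_REGISTRY. A hedged usage sketch of the new signature (the path and arguments are placeholders, not values from this commit):

from matdeeplearn.common.data import get_dataset

dataset = get_dataset(
    "/path/to/processed_data",   # directory containing the processed .pt file
    target_index=0,              # column of targets.csv to predict
    transform_list=["NumNodeTransform", "LineGraphMod", "ToFloat"],
)
# Each name must be a key in TRANSFORM_REGISTRY, otherwise a ValueError is raised.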

matdeeplearn/models/alignn.py

Lines changed: 4 additions & 0 deletions
@@ -356,6 +356,10 @@ def __init__(
             )
         elif link == "logit":
             self.link = torch.sigmoid
+
+    @property
+    def target_attr(self):
+        return "y"
 
     def forward(self, g: Data):
         # Compute OTF transform to generate attributes for L(g)
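
Note: the new target_attr property exposes which attribute of the Data object holds the label. A hedged sketch of how calling code could consume it; the getattr lookup below is an assumption for illustration, not code from this commit:

# model-agnostic label lookup inside a hypothetical training step
target = getattr(batch, model.target_attr)   # here: batch.y
loss = loss_fn(model(batch), target)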

matdeeplearn/preprocessor/helpers.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from pathlib import Path
 
 import torch
+from torch_sparse import SparseTensor
 import torch.nn.functional as F
 from torch_geometric.utils import dense_to_sparse, degree, add_self_loops
 from torch_geometric.data.data import Data

matdeeplearn/preprocessor/transforms.py

Lines changed: 23 additions & 32 deletions
@@ -6,6 +6,7 @@
 from torch_geometric.utils import remove_self_loops
 from matdeeplearn.preprocessor.helpers import compute_bond_angles, triplets
 from scipy.spatial.distance import cdist
+from contextlib import contextmanager
 
 '''
 here resides the transform classes needed for data processing
@@ -16,6 +17,16 @@
 The data object will be transformed before every access.
 '''
 
+TRANSFORM_REGISTRY = {}
+
+
+def register_transform(transform_name):
+    '''Registers a transform function for bookkeeping.'''
+    def registered_transform(transform):
+        TRANSFORM_REGISTRY[transform_name] = transform
+        return transform
+    return registered_transform
+
 
 class GetY(object):
     def __init__(self, index=0):
@@ -28,6 +39,7 @@ def __call__(self, data):
         return data
 
 
+@register_transform("NumNodeTransform")
 class NumNodeTransform(object):
     '''
     Adds the number of nodes to the data object
@@ -38,6 +50,7 @@ def __call__(self, data):
         return data
 
 
+@register_transform("LineGraphMod")
 class LineGraphMod(object):
     '''
     Adds line graph attributes to the data object
@@ -47,59 +60,37 @@ def __call__(self, data):
         # CODE FROM PYG LINEGRAPH TRANSFORM (DIRECTED)
         N = data.num_nodes
         edge_index, edge_attr = data.edge_index, data.edge_attr
-        (row, col), edge_attr = coalesce(edge_index, edge_attr, N, N)
-
-        i = torch.arange(row.size(0), dtype=torch.long, device=row.device)
-        count = scatter_add(torch.ones_like(row), row, dim=0,
-                            dim_size=data.num_nodes)
-        cumsum = torch.cat([count.new_zeros(1), count.cumsum(0)], dim=0)
-
-        cols = [
-            i[cumsum[col[j]]:cumsum[col[j] + 1]]
-            for j in range(col.size(0))
-        ]
-        rows = [row.new_full((c.numel(), ), j) for j, c in enumerate(cols)]
-
-        row, col = torch.cat(rows, dim=0), torch.cat(cols, dim=0)
-
-        data.edge_index_lg = torch.stack([row, col], dim=0)
-        data.x_lg = data.edge_attr
-        data.num_nodes_lg = edge_index.size(1)
-
-        # CUSTOM CODE FOR CALCULATING EDGE ATTRIBUTES
-        edge_attr_lg = torch.zeros(
-            (data.edge_index_lg.shape[1], 1), device='cuda')
+        _, edge_attr = coalesce(edge_index, edge_attr, N, N)
 
         # compute bond angles
         angles, idx_kj, idx_ji = compute_bond_angles(
             data.pos, data.cell_offsets, data.edge_index, data.num_nodes)
         triplet_pairs = torch.stack([idx_kj, idx_ji], dim=0)
 
-        # move triplets and edges to CPU for sklearn based calculation
-        match_indices = torch.Tensor(
-            np.where(cdist(data.edge_index_lg.T.cpu(), triplet_pairs.T.cpu()) == 0)[
-                0].reshape(-1, 1)
-        ).type(torch.long)
+        data.edge_index_lg = triplet_pairs
+        data.x_lg = data.edge_attr
+        data.num_nodes_lg = edge_index.size(1)
 
         # assign bond angles to edge attributes
-        edge_attr_lg[match_indices.squeeze(-1)] = angles.reshape(-1, 1)
+        data.edge_attr_lg = angles.reshape(-1, 1)
 
-        data.edge_attr_lg = edge_attr_lg
-
         return data
 
+
+@register_transform("ToFloat")
 class ToFloat(object):
     '''
     Convert non-int attributes to float
     '''
+
     def __call__(self, data):
         data.x = data.x.float()
         data.x_lg = data.x_lg.float()
-
+
         data.distances = data.distances.float()
         data.pos = data.pos.float()
 
         data.edge_attr = data.edge_attr.float()
         data.edge_attr_lg = data.edge_attr_lg.float()
 
-        return data
+        return data
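
Note: register_transform lets config files refer to transforms by name. A hedged sketch of registering a hypothetical custom transform ("ScaleY" is illustrative only and not part of this commit); since get_dataset instantiates registered classes with no arguments, any constructor parameters need defaults:

from matdeeplearn.preprocessor.transforms import register_transform

@register_transform("ScaleY")
class ScaleY(object):
    '''Multiplies the target by a constant factor (illustrative example).'''
    def __init__(self, factor=1.0):
        self.factor = factor

    def __call__(self, data):
        data.y = data.y * self.factor
        return data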

matdeeplearn/trainers/base_trainer.py

Lines changed: 14 additions & 7 deletions
@@ -1,6 +1,7 @@
 import copy
 import csv
 import logging
+import re
 import os
 from abc import ABC, abstractmethod
 from datetime import datetime
@@ -41,7 +42,8 @@ def __init__(
         identifier: str = None,
         verbosity: int = None,
     ):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu")
         self.model = model.to(self.device)
         self.dataset = dataset
         self.optimizer = optimizer
@@ -105,7 +107,8 @@ def from_config(cls, config):
         train_loader, val_loader, test_loader = cls._load_dataloader(
            config["optim"], config["dataset"], dataset, sampler
         )
-        scheduler = cls._load_scheduler(config["optim"]["scheduler"], optimizer)
+        scheduler = cls._load_scheduler(
+            config["optim"]["scheduler"], optimizer)
         loss = cls._load_loss(config["optim"]["loss"])
 
         max_epochs = config["optim"]["max_epochs"]
@@ -133,7 +136,7 @@ def _load_dataset(dataset_config):
         dataset_path = dataset_config["pt_path"]
         target_index = dataset_config.get("target_index", 0)
 
-        dataset = get_dataset(dataset_path, target_index)
+        dataset = get_dataset(dataset_path, target_index, transforms_list=dataset_config["transforms"])
 
         return dataset
 
@@ -180,7 +183,8 @@ def _load_dataloader(optim_config, dataset_config, dataset, sampler):
         train_loader = get_dataloader(
             train_dataset, batch_size=batch_size, sampler=sampler
         )
-        val_loader = get_dataloader(val_dataset, batch_size=batch_size, sampler=sampler)
+        val_loader = get_dataloader(
+            val_dataset, batch_size=batch_size, sampler=sampler)
         test_loader = get_dataloader(
             test_dataset, batch_size=batch_size, sampler=sampler
         )
@@ -222,7 +226,8 @@ def predict(self):
 
     def update_best_model(self, val_metrics):
         """Updates the best val metric and model, saves the best model, and saves the best model predictions"""
-        self.best_val_metric = val_metrics[type(self.loss_fn).__name__]["metric"]
+        self.best_val_metric = val_metrics[type(
+            self.loss_fn).__name__]["metric"]
         self.best_model_state = copy.deepcopy(self.model.state_dict())
 
         self.save_model("best_checkpoint.pt", val_metrics, False)
@@ -247,7 +252,8 @@ def save_model(self, checkpoint_file, val_metrics=None, training_state=True):
                 "best_val_metric": self.best_val_metric,
             }
         else:
-            state = {"state_dict": self.model.state_dict(), "val_metrics": val_metrics}
+            state = {"state_dict": self.model.state_dict(),
+                     "val_metrics": val_metrics}
 
         checkpoint_dir = os.path.join(
             self.run_dir, "results", self.timestamp_id, "checkpoint"
@@ -268,7 +274,8 @@ def save_results(self, output, filename, node_level_predictions=False):
         if node_level_predictions:
             id_headers += ["node_id"]
         num_cols = (shape[1] - len(id_headers)) // 2
-        headers = id_headers + ["target"] * num_cols + ["prediction"] * num_cols
+        headers = id_headers + ["target"] * \
+            num_cols + ["prediction"] * num_cols
 
         with open(filename, "w") as f:
             csvwriter = csv.writer(f)
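
Note: _load_dataset now forwards the config's transforms list to get_dataset. A hedged sketch of the dataset_config dict it expects, mirroring the keys in config_alignn.yml (values are placeholders, not from this commit):

dataset_config = {
    "pt_path": "/path/to/processed_data/",
    "target_index": 0,
    "transforms": ["NumNodeTransform", "LineGraphMod", "ToFloat"],
}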
