Skip to content

Commit 0467914

Browse files
feature: custom loss function and DOS predict model (#6)
* Adding in custom DOS loss function * Adding DOS predict model
1 parent a70526b commit 0467914

13 files changed

Lines changed: 384 additions & 49 deletions

File tree

configs/config.yml

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,25 @@ task:
55
# run_mode: train
66
name: "my_train_job"
77

8-
reprocess: "False"
8+
reprocess: False
99

1010

11-
parallel: "True"
11+
parallel: True
1212
seed: 0
1313
#seed=0 means random initialization
1414

1515

16-
write_output: "True"
17-
parallel: "True"
16+
write_output: True
17+
parallel: True
1818
#Training print out frequency (print per n number of epochs)
1919
verbosity: 5
2020

21-
#Ratios for train/val/test split out of a total of 1
22-
train_ratio: 0.8
23-
val_ratio: 0.05
24-
test_ratio: 0.15
25-
2621

2722

2823
model:
2924
name: CGCNN
30-
load_model: "False"
31-
save_model: "True"
25+
load_model: False
26+
save_model: True
3227
model_path: "my_model.pth"
3328
edge_steps: 50
3429
self_loop: True
@@ -40,16 +35,19 @@ model:
4035
post_fc_count: 3
4136
pool: "global_mean_pool"
4237
pool_order: "early"
43-
batch_norm: "True"
44-
batch_track_stats: "True"
38+
batch_norm: True
39+
batch_track_stats: True
4540
act: "relu"
4641
dropout_rate: 0.0
4742

4843
optim:
4944
max_epochs: 250
5045
lr: 0.002
51-
#Loss functions (from pytorch) examples: l1_loss, mse_loss, binary_cross_entropy
52-
loss_fn: "l1_loss"
46+
#Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper
47+
loss:
48+
loss_type: "TorchLossWrapper"
49+
loss_args: {"loss_fn": "l1_loss"}
50+
5351
batch_size: 100
5452
optimizer:
5553
optimizer_type: "AdamW"
@@ -72,8 +70,9 @@ dataset:
7270
data_format: "json"
7371
#Method of obtaining atom dictionary: available:(onehot)
7472
node_representation: "onehot"
73+
additional_attributes: []
7574
#Print out processing info
76-
verbose: "True"
75+
verbose: True
7776

7877
#Loading dataset params
7978
#Index of target column in targets.csv
@@ -83,4 +82,8 @@ dataset:
8382
cutoff_radius : 8.0
8483
n_neighbors : 12
8584
edge_steps : 50
86-
85+
86+
#Ratios for train/val/test split out of a total of 1
87+
train_ratio: 0.8
88+
val_ratio: 0.05
89+
test_ratio: 0.15

configs/examples/DOS_STO.yml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
trainer: property
2+
3+
task:
4+
name: "my_train_job"
5+
reprocess: False
6+
parallel: True
7+
seed: 0
8+
write_output: True
9+
verbosity: 5
10+
11+
model:
12+
name: DOSPredict
13+
load_model: False
14+
save_model: True
15+
model_path: "my_model.pth"
16+
edge_steps: 50
17+
self_loop: True
18+
dim1: 370
19+
dim2: 370
20+
pre_fc_count: 1
21+
gc_count: 9
22+
batch_norm: True
23+
batch_track_stats: False
24+
dropout_rate: 0.05
25+
26+
optim:
27+
max_epochs: 2000
28+
lr: 0.00047
29+
loss:
30+
loss_type: "DOSLoss"
31+
loss_args: {"loss_fn": "l1_loss", "scaling_weight": 0.05, "cumsum_weight": 0.005, "features_weight": 0.15}
32+
batch_size: 180
33+
optimizer:
34+
optimizer_type: "AdamW"
35+
optimizer_args: {"weight_decay":0.1}
36+
scheduler:
37+
scheduler_type: "ReduceLROnPlateau"
38+
scheduler_args: {"mode":"min", "factor":0.8, "patience":40, "min_lr":0.00001, "threshold":0.0002}
39+
40+
dataset:
41+
processed: False
42+
src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/STO_DOS_data/raw/"
43+
target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/STO_DOS_data/targets.csv"
44+
pt_path: "/global/cfs/projectdirs/m3641/Sarah/datasets/processed/STO_DOS_data/"
45+
data_format: "vasp"
46+
node_representation: "onehot"
47+
additional_attributes: ["features", "scaled", "scaling_factor"]
48+
verbose: True
49+
target_index: 0
50+
cutoff_radius : 8.0
51+
n_neighbors : 12
52+
edge_steps : 50
53+
train_ratio: 0.8
54+
val_ratio: 0.05
55+
test_ratio: 0.15

matdeeplearn/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +0,0 @@
1-
from matdeeplearn.common.data import *
2-
3-
from .models import *
4-
from .preprocessor import *

matdeeplearn/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
__all__ = ["BaseModel", "CGCNN", "DOSPredict"]
2+
13
from .base_model import BaseModel
4+
from .cgcnn import CGCNN
5+
from .dos_predict import DOSPredict

matdeeplearn/models/base_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from abc import abstractmethod
23

34
import torch
45
import torch.nn as nn
@@ -52,6 +53,10 @@ def __str__(self):
5253

5354
return str_representation
5455

56+
@abstractmethod
57+
def forward(self):
58+
"""The forward method for the model."""
59+
5560
def generate_graph(self, data, r, n_neighbors, otf: bool = False):
5661
"""
5762
generates the graph on-the-fly.

matdeeplearn/models/cgcnn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def __init__(
5959
else:
6060
self.gc_dim, self.post_fc_dim = dim1, dim1
6161

62-
# Determine output dimension length
63-
self.output_dim = 1 if data[0].y.ndim == 0 else len(data[0].y[0])
62+
# Determine output dimension length
63+
self.output_dim = 1 if data[0].y.ndim == 0 else len(data[0].y[0])
6464

6565
# setup layers
6666
self.pre_lin_list = self._setup_pre_gnn_layers()
@@ -99,7 +99,7 @@ def _setup_gnn_layers(self):
9999
)
100100
conv_list.append(conv)
101101
# Track running stats set to false can prevent some instabilities; this causes other issues with different val/test performance from loader size?
102-
if self.batch_norm == "True":
102+
if self.batch_norm:
103103
bn = BatchNorm1d(
104104
self.gc_dim, track_running_stats=self.batch_track_stats
105105
)
@@ -147,7 +147,7 @@ def forward(self, data):
147147
# GNN layers
148148
for i in range(0, len(self.conv_list)):
149149
if len(self.pre_lin_list) == 0 and i == 0:
150-
if self.batch_norm == "True":
150+
if self.batch_norm:
151151
out = self.conv_list[i](
152152
data.x, data.edge_index, data.edge_attr.float()
153153
)
@@ -157,7 +157,7 @@ def forward(self, data):
157157
data.x, data.edge_index, data.edge_attr.float()
158158
)
159159
else:
160-
if self.batch_norm == "True":
160+
if self.batch_norm:
161161
out = self.conv_list[i](
162162
out, data.edge_index, data.edge_attr.float()
163163
)

matdeeplearn/models/dos_predict.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from torch import Tensor
6+
from torch.nn import BatchNorm1d, Linear, Sequential
7+
from torch_geometric.nn.conv import MessagePassing
8+
from torch_geometric.typing import Adj, OptTensor, PairTensor, Size
9+
10+
from matdeeplearn.common.registry import registry
11+
from matdeeplearn.models.base_model import BaseModel
12+
13+
14+
@registry.register_model("DOSPredict")
class DOSPredict(BaseModel):
    """Graph model that predicts a density-of-states vector plus a scalar
    scaling factor per graph-node representation.

    The architecture is: optional pre-GNN dense layers -> a stack of
    ``GCBlock`` message-passing layers (each optionally batch-normalized)
    -> dropout -> two parallel MLP heads (DOS vector and scaling scalar).
    """

    def __init__(
        self,
        edge_steps,
        self_loop,
        data,
        dim1=64,
        dim2=64,
        pre_fc_count=1,
        gc_count=3,
        batch_norm=True,
        batch_track_stats=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(edge_steps, self_loop)
        self.dim1 = dim1
        self.dim2 = dim2
        self.pre_fc_count = pre_fc_count
        self.gc_count = gc_count
        self.num_features = data.num_features
        self.num_edge_features = data.num_edge_features
        self.batch_norm = batch_norm
        self.batch_track_stats = batch_track_stats
        self.dropout_rate = dropout_rate

        # At least one graph-convolution layer is required.
        assert gc_count > 0, "Need at least 1 GC layer"

        # Width of the GC stack: raw feature count when no pre-FC layers
        # exist, otherwise the pre-FC output width (dim1).
        if pre_fc_count == 0:
            self.gc_dim = self.post_fc_dim = data.num_features
        else:
            self.gc_dim = self.post_fc_dim = dim1

        # Output length is taken from the first sample's `scaled` target
        # (scalar target -> 1, vector target -> its length).
        first_scaled = data[0].scaled
        self.output_dim = 1 if first_scaled.ndim == 0 else len(first_scaled[0])

        # Layer construction.
        self.pre_lin_list = self._setup_pre_gnn_layers()
        self.conv_list, self.bn_list = self._setup_gnn_layers()

        # Head producing the DOS vector.
        self.dos_mlp = Sequential(
            Linear(self.post_fc_dim, self.dim2),
            torch.nn.PReLU(),
            Linear(self.dim2, self.output_dim),
            torch.nn.PReLU(),
        )

        # Head producing a single scaling value.
        self.scaling_mlp = Sequential(
            Linear(self.post_fc_dim, self.dim2),
            torch.nn.PReLU(),
            Linear(self.dim2, 1),
        )

    def _setup_pre_gnn_layers(self):
        """Sets up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)."""
        layers = []
        for idx in range(self.pre_fc_count):
            # First layer maps raw node features to dim1; the rest are dim1->dim1.
            in_width = self.num_features if idx == 0 else self.dim1
            layers.append(
                Sequential(torch.nn.Linear(in_width, self.dim1), torch.nn.PReLU())
            )
        return torch.nn.ModuleList(layers)

    def _setup_gnn_layers(self):
        """Sets up GNN layers."""
        conv_list = torch.nn.ModuleList()
        bn_list = torch.nn.ModuleList()
        for _ in range(self.gc_count):
            conv_list.append(GCBlock(self.gc_dim, self.num_edge_features, aggr="mean"))
            # Track running stats set to false can prevent some instabilities;
            # this causes other issues with different val/test performance
            # from loader size?
            if self.batch_norm:
                bn_list.append(
                    BatchNorm1d(
                        self.gc_dim,
                        track_running_stats=self.batch_track_stats,
                        affine=True,
                    )
                )
        return conv_list, bn_list

    def forward(self, data):
        """Run the model on a batched graph; returns (dos, scaling).

        `dos` is flattened to 1-D when the output dimension is 1; `scaling`
        is always 1-D (one value per node representation).
        """
        # Pre-GNN dense layers; node features are cast to float on entry.
        out = data.x
        for idx, pre_lin in enumerate(self.pre_lin_list):
            out = pre_lin(out.float() if idx == 0 else out)

        # GNN layers (with optional batch norm after each conv).
        for idx, conv in enumerate(self.conv_list):
            out = conv(out, data.edge_index, data.edge_attr.float())
            if self.batch_norm:
                out = self.bn_list[idx](out)

        out = F.dropout(out, p=self.dropout_rate, training=self.training)

        # Parallel output heads.
        dos_out = self.dos_mlp(out)
        scaling = self.scaling_mlp(out).view(-1)

        if dos_out.shape[1] == 1:
            return dos_out.view(-1), scaling
        return dos_out, scaling
126+
127+
128+
class GCBlock(MessagePassing):
    """Message-passing block with a residual connection.

    Each message concatenates the two endpoint node features with an
    MLP-transformed edge feature, passes the result through a dense layer
    with PReLU, aggregates (mean by default), and adds the receiving node's
    own features back in.
    """

    def __init__(
        self,
        channels: int | tuple[int, int],
        dim: int = 0,
        aggr: str = "mean",
        **kwargs,
    ):
        super().__init__(aggr=aggr, **kwargs)
        self.channels = channels
        self.dim = dim

        # Normalize to a (source_width, target_width) pair.
        widths = (channels, channels) if isinstance(channels, int) else channels

        # Message transform: [x_i | x_j | edge] -> target width.
        self.mlp = Sequential(
            Linear(sum(widths) + dim, widths[1]),
            torch.nn.PReLU(),
        )
        # Edge-feature transform applied before concatenation.
        self.mlp2 = Sequential(
            Linear(dim, dim),
            torch.nn.PReLU(),
        )

    def forward(
        self,
        x: Tensor | PairTensor,
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        # A single tensor is treated as a (source, target) pair of itself.
        pair: PairTensor = (x, x) if isinstance(x, Tensor) else x

        # propagate_type: (x: PairTensor, edge_attr: OptTensor)
        aggregated = self.propagate(edge_index, x=pair, edge_attr=edge_attr, size=size)
        # Residual connection with the target-side node features.
        return aggregated + pair[1]

    def message(self, x_i, x_j, edge_attr: OptTensor) -> Tensor:
        edge_feat = self.mlp2(edge_attr)
        return self.mlp(torch.cat((x_i, x_j, edge_feat), dim=-1))

matdeeplearn/modules/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
__all__ = ["Evaluator", "DOSLoss", "TorchLossWrapper", "LRScheduler"]
2+
3+
from .evaluator import Evaluator
4+
from .loss import DOSLoss, TorchLossWrapper
5+
from .scheduler import LRScheduler

matdeeplearn/modules/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, task=None):
1111
def eval(self, prediction, target, loss_method, prev_metrics={}):
1212
metrics = prev_metrics
1313
res = loss_method(prediction, target)
14-
metrics = self.update(loss_method.__name__, res.item(), metrics)
14+
metrics = self.update(type(loss_method).__name__, res.item(), metrics)
1515

1616
return metrics
1717

0 commit comments

Comments
 (0)