Skip to content

Commit 8477757

Browse files
authored
BERT-style mask prediction pretraining (#851)
The added (experimental) features are: * a new task type that takes an augmentation-like model and a corresponding loss function to perform the unsupervised training * an analogue to the StandardModel that takes an augmentation task like the one introduced above and a mostly arbitrary model to pretrain the latter * default components for the new task, which perform masking and comparison against masked values * a minor new loss function that computes the negative cosine loss * an example file to illustrate the added functionality
1 parent 6b5608a commit 8477757

4 files changed

Lines changed: 491 additions & 0 deletions

File tree

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Minimal example for use of maskpred pretraining."""
2+
3+
from typing import Tuple
4+
5+
from graphnet.models.pretraining_maskpred import mask_pred_frame
6+
from graphnet.models.pretraining_maskpred import default_mask_augment
7+
from graphnet.models.pretraining_maskpred import default_loss_calc
8+
from graphnet.models import Model
9+
from torch_geometric.data import Data
10+
from graphnet.models.data_representation.graphs import KNNGraph
11+
from graphnet.data.dataset.sqlite.sqlite_dataset import SQLiteDataset
12+
from graphnet.data.dataloader import DataLoader
13+
from graphnet.constants import EXAMPLE_DATA_DIR
14+
15+
from torch_scatter import scatter
16+
17+
import torch
18+
from torch import Tensor
19+
20+
from graphnet.models.detector.prometheus import Prometheus
21+
from graphnet.models.graphs.nodes import NodesAsPulses
22+
23+
from graphnet.models.task.task import UnsupervisedTask
24+
25+
26+
class simple_model(Model):
    """Tiny stand-in encoder used to demonstrate maskpred pretraining.

    A two-layer perceptron (4 -> 10 -> 5 with a SELU in between) that is
    applied row-wise to `data.x`.
    """

    def __init__(self) -> None:
        """Build the small feed-forward network."""
        super().__init__()
        # Layers are created in the same order they are applied.
        layers = [
            torch.nn.Linear(4, 10),
            torch.nn.SELU(),
            torch.nn.Linear(10, 5),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, data: Data) -> Tuple[Tensor, Tensor]:
        """Encode the rows of `data.x` and max-pool them per batch index.

        Args:
            data: Input whose `x` rows are fed to the network and whose
                `batch` vector gives the group index of each row.

        Returns:
            A tuple `(per_row, pooled)`: the network output for every row
            of `data.x`, and the element-wise maximum of those outputs
            over all rows sharing the same `data.batch` index.
        """
        per_row = self.net(data.x)
        pooled = scatter(
            src=per_row,
            index=data.batch,
            dim=0,
            reduce="max",
        )
        return per_row, pooled
43+
44+
45+
def test() -> None:
    """Run a single batch through a maskpred pretraining model.

    Builds a KNN graph definition for the Prometheus detector, opens the
    example Prometheus SQLite database, takes the first batch from a
    dataloader (batch_size=3) and passes it through a `mask_pred_frame`
    wrapping the dummy encoder. The result is printed. Training and saving
    are only sketched in the trailing comments; nothing is trained or
    written to disk by this function.

    Raises:
        RuntimeError: If the dataloader yields no batch at all.
    """
    graph_definition = KNNGraph(
        detector=Prometheus(),
        node_definition=NodesAsPulses(),
        nb_nearest_neighbours=8,
    )

    dataset = SQLiteDataset(
        path=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
        pulsemaps="total",
        truth_table="mc_truth",
        features=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t", "q"],
        truth=["injection_energy", "injection_zenith"],
        data_representation=graph_definition,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=3,
        num_workers=10,
    )

    # Only the first batch is needed. Fetch it explicitly so that an empty
    # dataloader fails here with a clear message instead of leaving `data`
    # unbound and crashing later at the model call.
    data = next(iter(dataloader), None)
    if data is None:
        raise RuntimeError("The example dataloader did not yield any batch.")

    dummy_model = simple_model()
    default_task = UnsupervisedTask(
        default_mask_augment(), default_loss_calc()
    )

    model = mask_pred_frame(
        encoder=dummy_model,
        bert_task=default_task,
        encoder_out_dim=5,  # matches the last Linear layer of simple_model
        need_charge_rep=False,
    )

    out = model(data)
    print(out)

    # for training
    # model.fit(train_dataloader=dataloader, max_epochs=10, gpus=1)

    # for saving
    # model.save_pretrained_model('some/path')
92+
93+
94+
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test()

0 commit comments

Comments
 (0)