Skip to content

Commit 79d7baf

Browse files
authored
Merge pull request graphnet-team#777 from pweigel/grit
Add GRIT model
2 parents 83c7330 + acab6e5 commit 79d7baf

10 files changed

Lines changed: 1481 additions & 12 deletions

File tree

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
"""Example of training Model."""
2+
3+
import os
4+
from typing import Any, Dict, List, Optional
5+
6+
from pytorch_lightning.loggers import WandbLogger
7+
import torch
8+
from torch.optim.adam import Adam
9+
10+
from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
11+
from graphnet.data.constants import FEATURES, TRUTH
12+
from graphnet.models import StandardModel
13+
from graphnet.models.detector.prometheus import Prometheus
14+
from graphnet.models.gnn import GRIT
15+
from graphnet.models.graphs import KNNGraphRRWP
16+
from graphnet.models.task.reconstruction import EnergyReconstruction
17+
from graphnet.training.callbacks import PiecewiseLinearLR
18+
from graphnet.training.loss_functions import LogCoshLoss
19+
from graphnet.utilities.argparse import ArgumentParser
20+
from graphnet.utilities.logging import Logger
21+
from graphnet.data import GraphNeTDataModule
22+
from graphnet.data.dataset import SQLiteDataset
23+
from graphnet.data.dataset import ParquetDataset
24+
25+
# Module-level constants: the feature and truth column sets that the
# Prometheus example data is read with.
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS
28+
29+
30+
def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Train a GRIT energy-reconstruction model and save its outputs.

    Builds a `KNNGraphRRWP` graph definition, a `GraphNeTDataModule`, and a
    `StandardModel` with a `GRIT` backbone and an `EnergyReconstruction`
    task, fits the model, predicts on the validation dataloader and writes
    the predictions and the model to disk.

    Args:
        path: Path to the dataset. A path ending in ".db" is read with
            `SQLiteDataset`; anything else is read with `ParquetDataset`.
        pulsemap: Name of the pulsemap; passed as the only entry of the
            dataset's `pulsemaps` argument.
        target: Name of the truth label used as regression target.
        truth_table: Name of the truth table passed to the dataset.
        gpus: GPU indices forwarded to `model.fit` and to
            `model.predict_as_dataframe`.
        max_epochs: Maximum number of epochs; also sets the last milestone
            of the learning-rate schedule.
        early_stopping_patience: Forwarded to `model.fit`.
        batch_size: Batch size for the train and test dataloaders.
        num_workers: Number of workers for the train and test dataloaders.
        wandb: If True, create a Weights & Biases logger, record the run
            configuration in it, and pass it to `model.fit`.
    """
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run. `wandb_logger` stays None when
    # W&B is disabled, so it is always bound when passed to `model.fit`.
    wandb_logger = None
    if wandb:
        # Make sure W&B output directory exists
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")

    # Configuration
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
            "distribution_strategy": "ddp_find_unused_parameters_true",
        },
        "dataset_reference": (
            SQLiteDataset if path.endswith(".db") else ParquetDataset
        ),
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
    run_name = "grit_{}_example".format(config["target"])
    if wandb_logger is not None:
        # Log configuration to W&B
        wandb_logger.experiment.config.update(config)

    # The same walk length is given to the graph definition and, as
    # `ksteps`, to the GRIT backbone below, so the two stay in sync.
    walk_length = 6
    graph_definition = KNNGraphRRWP(
        detector=Prometheus(),
        input_feature_names=features,
        nb_nearest_neighbours=5,
        walk_length=walk_length,
    )
    dm = GraphNeTDataModule(
        dataset_reference=config["dataset_reference"],
        dataset_args={
            "truth": truth,
            "truth_table": truth_table,
            "features": features,
            "graph_definition": graph_definition,
            "pulsemaps": [config["pulsemap"]],
            "path": config["path"],
        },
        train_dataloader_kwargs={
            "batch_size": config["batch_size"],
            "num_workers": config["num_workers"],
        },
        test_dataloader_kwargs={
            "batch_size": config["batch_size"],
            "num_workers": config["num_workers"],
        },
    )

    training_dataloader = dm.train_dataloader
    validation_dataloader = dm.val_dataloader

    # Building model
    backbone = GRIT(
        nb_inputs=graph_definition.nb_outputs,
        hidden_dim=32,
        ksteps=walk_length,
    )

    # The loss is computed on log10 of prediction and target; at inference
    # the prediction is mapped back with 10**x.
    task = EnergyReconstruction(
        hidden_size=backbone.nb_outputs,
        target_labels=config["target"],
        loss_function=LogCoshLoss(),
        transform_prediction_and_target=lambda x: torch.log10(x),
        transform_inference=lambda x: torch.pow(10, x),
    )

    # The scheduler interval is "step", so the milestones are presumably
    # counted in optimiser steps: start, half the length of the training
    # dataloader, and that length times the number of epochs.
    steps_per_epoch = len(training_dataloader)
    model = StandardModel(
        graph_definition=graph_definition,
        backbone=backbone,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            "milestones": [
                0,
                steps_per_epoch / 2,
                steps_per_epoch * config["fit"]["max_epochs"],
            ],
            "factors": [1e-2, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )

    # Training model
    model.fit(
        training_dataloader,
        validation_dataloader,
        early_stopping_patience=config["early_stopping_patience"],
        logger=wandb_logger,
        **config["fit"],
    )

    # Get predictions
    additional_attributes = model.target_labels
    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=additional_attributes + ["event_no"],
        gpus=config["fit"]["gpus"],
    )

    # Save predictions and model to file. The dataset name is the last path
    # component up to its first dot. `normpath` drops a trailing separator
    # (e.g. a dataset directory given as "dir/") so the name is not empty,
    # and `basename` uses the platform's separators, not a hard-coded "/".
    dataset_name = os.path.basename(os.path.normpath(path)).split(".")[0]
    outdir = os.path.join(archive, dataset_name, run_name)
    logger.info(f"Writing results to {outdir}")
    os.makedirs(outdir, exist_ok=True)

    results.to_csv(f"{outdir}/results.csv")

    model.save(f"{outdir}/model.pth")
    model.save_state_dict(f"{outdir}/state_dict.pth")
    model.save_config(f"{outdir}/model_config.yml")
180+
181+
182+
if __name__ == "__main__":

    # Command-line interface for this example script.
    parser = ArgumentParser(
        description="""
Train GNN model without the use of config files.
"""
    )

    # Script-specific options as (flag, help text, default value).
    option_specs = (
        (
            "--path",
            "Path to dataset file (default: %(default)s)",
            f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
        ),
        (
            "--pulsemap",
            "Name of pulsemap to use (default: %(default)s)",
            "total",
        ),
        (
            "--target",
            (
                "Name of feature to use as regression target (default: "
                "%(default)s)"
            ),
            "total_energy",
        ),
        (
            "--truth-table",
            "Name of truth table to be used (default: %(default)s)",
            "mc_truth",
        ),
    )
    for flag, help_text, default_value in option_specs:
        parser.add_argument(flag, help=help_text, default=default_value)

    # Standard arguments; the tuples override the defaults for this example.
    parser.with_standard_arguments(
        "gpus",
        ("max-epochs", 1),
        "early-stopping-patience",
        ("batch-size", 16),
        "num-workers",
    )

    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If True, Weights & Biases are used to track the experiment.",
    )

    # Arguments the parser does not recognise are collected, not rejected.
    args, unknown = parser.parse_known_args()

    main(
        path=args.path,
        pulsemap=args.pulsemap,
        target=args.target,
        truth_table=args.truth_table,
        gpus=args.gpus,
        max_epochs=args.max_epochs,
        early_stopping_patience=args.early_stopping_patience,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        wandb=args.wandb,
    )

0 commit comments

Comments
 (0)