Skip to content

Commit b64ec50

Browse files
committed
forward() in GraphDefinition
1 parent df2288f commit b64ec50

2 files changed

Lines changed: 54 additions & 21 deletions

File tree

src/graphnet/models/data_representation/data_representation.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from numpy.random import default_rng, Generator
88

99
from graphnet.models.detector import Detector
10-
from graphnet.utilities.decorators import final
1110
from graphnet.models import Model
1211
from abc import abstractmethod
1312

@@ -109,7 +108,6 @@ def __init__(
109108
else:
110109
self.rng = default_rng()
111110

112-
@final
113111
def forward( # type: ignore
114112
self,
115113
input_features: np.ndarray,
@@ -121,10 +119,10 @@ def forward( # type: ignore
121119
loss_weight_default_value: Optional[float] = None,
122120
data_path: Optional[str] = None,
123121
) -> Data:
124-
"""Construct graph as ´Data´ object.
122+
"""Construct data as ´Data´ object.
125123
126124
Args:
127-
input_features: Input features for graph construction.
125+
input_features: Input features for data construction.
128126
Shape ´[num_rows, d]´
129127
input_feature_names: name of each column. Shape ´[,d]´.
130128
truth_dicts: Dictionary containing truth labels.
@@ -138,7 +136,7 @@ def forward( # type: ignore
138136
data_path: Path to dataset data files. Defaults to None.
139137
140138
Returns:
141-
graph
139+
data
142140
"""
143141
# Checks
144142
self._validate_input(
@@ -197,9 +195,6 @@ def forward( # type: ignore
197195
data=data, custom_label_functions=custom_label_functions
198196
)
199197

200-
# Do final processing steps
201-
data = self._forward_end(data, data_feature_names)
202-
203198
# DEPRECATION STAMP GRAPH_DEFINITION: REMOVE AT 2.0 LAUNCH
204199
# See https://github.com/graphnet-team/graphnet/issues/647
205200
data["graph_definition"] = self.__class__.__name__
@@ -439,18 +434,6 @@ def _create_data(self, input_features: torch.Tensor) -> Data:
439434
"""
440435
raise NotImplementedError
441436

442-
def _forward_end(
443-
self,
444-
data: Data,
445-
data_feature_names: List[str],
446-
) -> Data:
447-
"""Place to add any final data processing steps.
448-
449-
Override this method to add any final processing steps in the end of
450-
the forward pass.
451-
"""
452-
return data
453-
454437
def _label_repeater(self, label: torch.Tensor, data: Data) -> torch.Tensor:
455438
"""Handle the label repetition.
456439

src/graphnet/models/data_representation/graphs/graph_definition.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
and can be passed to dataloaders during training and deployment.
66
"""
77

8-
from typing import List, Optional, Dict, Union, Tuple
8+
from typing import List, Optional, Dict, Union, Tuple, Any, Callable
99
import torch
1010
from numpy.random import Generator
11+
import numpy as np
1112

1213
from graphnet.models.detector import Detector
1314
from .edges import EdgeDefinition
@@ -144,6 +145,55 @@ def _create_data(
144145

145146
return data, data_feature_names
146147

148+
def forward( # type: ignore
149+
self,
150+
input_features: np.ndarray,
151+
input_feature_names: List[str],
152+
truth_dicts: Optional[List[Dict[str, Any]]] = None,
153+
custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
154+
loss_weight_column: Optional[str] = None,
155+
loss_weight: Optional[float] = None,
156+
loss_weight_default_value: Optional[float] = None,
157+
data_path: Optional[str] = None,
158+
) -> Data:
159+
"""Construct graph as ´Data´ object.
160+
161+
Args:
162+
input_features: Input features for graph construction.
163+
Shape ´[num_rows, d]´
164+
input_feature_names: name of each column. Shape ´[,d]´.
165+
truth_dicts: Dictionary containing truth labels.
166+
custom_label_functions: Custom label functions.
167+
loss_weight_column: Name of column that holds loss weight.
168+
Defaults to None.
169+
loss_weight: Loss weight associated with event. Defaults to None.
170+
loss_weight_default_value: default value for loss weight.
171+
Used in instances where some events have
172+
no pre-defined loss weight. Defaults to None.
173+
data_path: Path to dataset data files. Defaults to None.
174+
175+
Returns:
176+
graph
177+
"""
178+
data = super().forward(
179+
input_features=input_features,
180+
input_feature_names=input_feature_names,
181+
truth_dicts=truth_dicts,
182+
custom_label_functions=custom_label_functions,
183+
loss_weight_column=loss_weight_column,
184+
loss_weight=loss_weight,
185+
loss_weight_default_value=loss_weight_default_value,
186+
data_path=data_path,
187+
)
188+
189+
if self._add_static_features:
190+
data = self._add_features_individually(
191+
data,
192+
self._node_definition._output_feature_names,
193+
)
194+
195+
return data
196+
147197
def _forward_end(
148198
self,
149199
data: Data,

0 commit comments

Comments
 (0)