Skip to content

Commit 76e22ff

Browse files
committed
getting rid of _create_data
1 parent 03dec8a commit 76e22ff

2 files changed

Lines changed: 14 additions & 44 deletions

File tree

src/graphnet/models/data_representation/data_representation.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def forward( # type: ignore
170170
input_features = self._detector(input_features, input_feature_names)
171171

172172
# Create data & get new final data feature names
173-
data = self._create_data(input_features=input_features)
173+
data = Data(x=input_features)
174174

175175
# Attach number of pulses as static attribute.
176176
data.n_pulses = torch.tensor(len(input_features), dtype=torch.int32)
@@ -413,20 +413,6 @@ def _set_output_feature_names(
413413
"""Set the final data output feature names."""
414414
raise NotImplementedError
415415

416-
@abstractmethod
417-
def _create_data(self, input_features: torch.Tensor) -> Data:
418-
"""Create data from input features.
419-
420-
Enforce the dtype of the feature tensor.
421-
E.g.: `data.x = data.x.type(self.dtype)`
422-
if the training data is stored in `data.x`.
423-
424-
Should return:
425-
- data: torch_geometric.data.Data object representing the event.
426-
- data_feature_names: List of feature names in the data object.
427-
"""
428-
raise NotImplementedError
429-
430416
def _label_repeater(self, label: torch.Tensor, data: Data) -> torch.Tensor:
431417
"""Handle the label repetition.
432418

src/graphnet/models/data_representation/graphs/graph_definition.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
and can be passed to dataloaders during training and deployment.
66
"""
77

8-
from typing import List, Optional, Dict, Union, Tuple, Any, Callable
8+
from typing import List, Optional, Dict, Union, Any, Callable
99
import torch
1010
from numpy.random import Generator
1111
import numpy as np
@@ -120,23 +120,6 @@ def _set_output_feature_names(
120120
self._node_definition.set_output_feature_names(input_feature_names)
121121
return self._node_definition._output_feature_names
122122

123-
def _create_data(
124-
self, input_features: torch.Tensor
125-
) -> Tuple[Data, List[str]]:
126-
# Create graph & get new node feature names
127-
data = self._node_definition(input_features)
128-
if self._sort_by is not None:
129-
data.x = data.x[data.x[:, self._sort_by].sort()[1]]
130-
131-
# Enforce dtype
132-
data.x = data.x.type(self.dtype)
133-
134-
# Assign edges
135-
if self._edge_definition is not None:
136-
data = self._edge_definition(data)
137-
138-
return data
139-
140123
def forward( # type: ignore
141124
self,
142125
input_features: np.ndarray,
@@ -178,6 +161,18 @@ def forward( # type: ignore
178161
data_path=data_path,
179162
)
180163

164+
# Create graph & get new node feature names
165+
data = self._node_definition(data.x)
166+
if self._sort_by is not None:
167+
data.x = data.x[data.x[:, self._sort_by].sort()[1]]
168+
169+
# Enforce dtype
170+
data.x = data.x.type(self.dtype)
171+
172+
# Assign edges
173+
if self._edge_definition is not None:
174+
data = self._edge_definition(data)
175+
181176
if self._add_static_features:
182177
data = self._add_features_individually(
183178
data,
@@ -186,17 +181,6 @@ def forward( # type: ignore
186181

187182
return data
188183

189-
def _forward_end(
190-
self,
191-
data: Data,
192-
data_feature_names: List[str],
193-
) -> Data:
194-
"""Add processing steps at the end of the forward pass."""
195-
# Add original features as attributes
196-
if self._add_static_features:
197-
data = self._add_features_individually(data, data_feature_names)
198-
return data
199-
200184
def _add_features_individually(
201185
self,
202186
data: Data,

0 commit comments

Comments
 (0)