Skip to content

Commit af5da20

Browse files
committed
Fixing graph processing logic
1 parent 76e22ff commit af5da20

3 files changed

Lines changed: 17 additions & 16 deletions

File tree

src/graphnet/models/data_representation/graphs/graph_definition.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ def __init__(
112112
self._sort_by = sort_by
113113
self._add_static_features = add_static_features
114114

115+
# make sure output feature names are set also in node definition
116+
self._set_output_feature_names(self._input_feature_names)
117+
115118
def _set_output_feature_names(
116119
self, input_feature_names: List[str]
117120
) -> List[str]:
@@ -160,9 +163,8 @@ def forward( # type: ignore
160163
loss_weight_default_value=loss_weight_default_value,
161164
data_path=data_path,
162165
)
163-
164166
# Create graph & get new node feature names
165-
data = self._node_definition(data.x)
167+
data.x = self._node_definition(data.x)
166168
if self._sort_by is not None:
167169
data.x = data.x[data.x[:, self._sort_by].sort()[1]]
168170

@@ -178,7 +180,6 @@ def forward( # type: ignore
178180
data,
179181
self.output_feature_names,
180182
)
181-
182183
return data
183184

184185
def _add_features_individually(

src/graphnet/models/data_representation/graphs/nodes/nodes.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
)
3535

3636
@final
37-
def forward(self, x: torch.tensor) -> Tuple[Data, List[str]]:
37+
def forward(self, x: torch.tensor) -> torch.tensor:
3838
"""Construct nodes from raw node features.
3939
4040
Args:
@@ -45,9 +45,9 @@ def forward(self, x: torch.tensor) -> Tuple[Data, List[str]]:
4545
Returns:
4646
graph: a graph without edges
4747
"""
48-
graph = self._construct_nodes(x=x)
48+
data = self._construct_nodes(x=x)
4949

50-
return graph
50+
return data
5151

5252
@property
5353
def _output_feature_names(self) -> List[str]:
@@ -110,7 +110,7 @@ def _define_output_feature_names(
110110
"""
111111

112112
@abstractmethod
113-
def _construct_nodes(self, x: torch.tensor) -> Data:
113+
def _construct_nodes(self, x: torch.tensor) -> torch.tensor:
114114
"""Construct nodes from raw node features ´x´.
115115
116116
Args:
@@ -133,8 +133,8 @@ def _define_output_feature_names(
133133
) -> List[str]:
134134
return input_feature_names
135135

136-
def _construct_nodes(self, x: torch.Tensor) -> Data:
137-
return Data(x=x)
136+
def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
137+
return x
138138

139139

140140
class PercentileClusters(NodeDefinition):
@@ -195,7 +195,7 @@ def _get_indices_and_feature_names(
195195
new_feature_names.append("counts")
196196
return cluster_idx, summ_idx, new_feature_names
197197

198-
def _construct_nodes(self, x: torch.Tensor) -> Data:
198+
def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
199199
# Cast to Numpy
200200
x = x.numpy()
201201
# Construct clusters with percentile-summarized features
@@ -219,7 +219,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
219219
) # noqa
220220
raise AttributeError
221221

222-
return Data(x=torch.tensor(array))
222+
return torch.tensor(array)
223223

224224

225225
class NodeAsDOMTimeSeries(NodeDefinition):
@@ -270,7 +270,7 @@ def _define_output_feature_names(
270270
) -> List[str]:
271271
return input_feature_names + ["new_node_col"]
272272

273-
def _construct_nodes(self, x: torch.Tensor) -> Data:
273+
def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
274274
"""Construct nodes from raw node features ´x´."""
275275
# Cast to Numpy
276276
x = x.numpy()
@@ -309,7 +309,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
309309
new_node_col[0] = 1
310310
x = np.column_stack([x, new_node_col])
311311

312-
return Data(x=torch.tensor(x))
312+
return torch.tensor(x)
313313

314314

315315
class IceMixNodes(NodeDefinition):
@@ -440,7 +440,7 @@ def _pulse_sampler(
440440

441441
return ids
442442

443-
def _construct_nodes(self, x: torch.Tensor) -> Data:
443+
def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
444444

445445
event_length = x.shape[0]
446446
if self.hlc_name is not None:
@@ -468,4 +468,4 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
468468
for idx, feature in enumerate(non_ice_features):
469469
graph[:event_length, idx] = x[ids, self.feature_indexes[feature]]
470470

471-
return Data(x=graph)
471+
return graph

tests/models/test_node_definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_percentile_cluster() -> None:
4040
# Apply node definition to torch tensor with raw pulses
4141
graph = node_definition(tensor)
4242
new_features = node_definition._output_feature_names
43-
x_tilde = graph.x.numpy()
43+
x_tilde = graph.numpy()
4444

4545
# Calculate percentiles "the normal way" and compare that output of
4646
# node definition match.

0 commit comments

Comments
 (0)