Skip to content

Commit 1492423

Browse files
committed
Fixing output_feature_names issue
1 parent d4de9ba commit 1492423

3 files changed

Lines changed: 13 additions & 9 deletions

File tree

src/graphnet/models/data_representation/data_representation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,7 @@ def forward( # type: ignore
173173
input_features = self._detector(input_features, input_feature_names)
174174

175175
# Create data & get new final data feature names
176-
data, data_feature_names = self._create_data(
177-
input_features=input_features
178-
)
176+
data = self._create_data(input_features=input_features)
179177

180178
# Attach number of pulses as static attribute.
181179
data.n_pulses = torch.tensor(len(input_features), dtype=torch.int32)

src/graphnet/models/data_representation/graphs/graph_definition.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _create_data(
124124
self, input_features: torch.Tensor
125125
) -> Tuple[Data, List[str]]:
126126
# Create graph & get new node feature names
127-
data, data_feature_names = self._node_definition(input_features)
127+
data = self._node_definition(input_features)
128128
if self._sort_by is not None:
129129
data.x = data.x[data.x[:, self._sort_by].sort()[1]]
130130

@@ -135,7 +135,7 @@ def _create_data(
135135
if self._edge_definition is not None:
136136
data = self._edge_definition(data)
137137

138-
return data, data_feature_names
138+
return data
139139

140140
def forward( # type: ignore
141141
self,
@@ -181,7 +181,7 @@ def forward( # type: ignore
181181
if self._add_static_features:
182182
data = self._add_features_individually(
183183
data,
184-
self._node_definition._output_feature_names,
184+
self.output_feature_names,
185185
)
186186

187187
return data

src/graphnet/models/data_representation/graphs/nodes/nodes.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,14 @@ def forward(self, x: torch.tensor) -> Tuple[Data, List[str]]:
4646
graph: a graph without edges
4747
"""
4848
graph = self._construct_nodes(x=x)
49+
50+
return graph
51+
52+
@property
53+
def _output_feature_names(self) -> List[str]:
54+
"""Return output feature names."""
4955
try:
50-
self._output_feature_names
56+
self._hidden_output_feature_names
5157
except AttributeError as e:
5258
self.error(
5359
f"""{self.__class__.__name__} was instantiated without
@@ -57,7 +63,7 @@ def forward(self, x: torch.tensor) -> Tuple[Data, List[str]]:
5763
with `input_feature_names`."""
5864
) # noqa
5965
raise e
60-
return graph, self._output_feature_names
66+
return self._hidden_output_feature_names
6167

6268
@property
6369
def nb_outputs(self) -> int:
@@ -85,7 +91,7 @@ def set_output_feature_names(self, input_feature_names: List[str]) -> None:
8591
input_feature_names: List of column names of the input to the
8692
node definition.
8793
"""
88-
self._output_feature_names = self._define_output_feature_names(
94+
self._hidden_output_feature_names = self._define_output_feature_names(
8995
input_feature_names
9096
)
9197

0 commit comments

Comments
 (0)