Skip to content

Commit d4de9ba

Browse files
committed
removing pre_init
1 parent b64ec50 commit d4de9ba

2 files changed

Lines changed: 22 additions & 32 deletions

File tree

src/graphnet/models/data_representation/data_representation.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def __init__(
2525
sensor_mask: Optional[List[int]] = None,
2626
string_mask: Optional[List[int]] = None,
2727
repeat_labels: bool = False,
28-
**kwargs: Any,
2928
):
3029
"""Construct´DataRepresentation´. The ´detector´ holds.
3130
@@ -59,8 +58,6 @@ def __init__(
5958
# Base class constructor
6059
super().__init__(name=__name__, class_name=self.__class__.__name__)
6160

62-
self._pre_init(**kwargs)
63-
6461
# Member Variables
6562
self._detector = detector
6663
self._perturbation_dict = perturbation_dict
@@ -78,18 +75,9 @@ def __init__(
7875
) # noqa: E501 # type: ignore
7976
self._input_feature_names = input_feature_names
8077

81-
# Set final data column names
82-
self.output_feature_names = self._set_output_feature_names(
83-
self._input_feature_names
84-
)
85-
8678
# Set data type
8779
self.to(dtype)
8880

89-
# Set Input / Output dimensions
90-
self.nb_inputs = len(self._input_feature_names)
91-
self.nb_outputs = len(self.output_feature_names)
92-
9381
# Set perturbation_cols if needed
9482
if isinstance(self._perturbation_dict, dict):
9583
self._perturbation_cols = [
@@ -108,6 +96,25 @@ def __init__(
10896
else:
10997
self.rng = default_rng()
11098

99+
@property
100+
def nb_inputs(self) -> int:
101+
"""Return the number of input features."""
102+
return len(self._input_feature_names)
103+
104+
@property
105+
def nb_outputs(self) -> int:
106+
"""Return the number of output features."""
107+
return len(self.output_feature_names)
108+
109+
@property
110+
def output_feature_names(self) -> List[str]:
111+
"""Initialize / return the names of output features."""
112+
if not hasattr(self, "_output_feature_names"):
113+
self._output_feature_names = self._set_output_feature_names(
114+
self._input_feature_names
115+
)
116+
return self._output_feature_names
117+
111118
def forward( # type: ignore
112119
self,
113120
input_features: np.ndarray,
@@ -411,15 +418,6 @@ def _set_output_feature_names(
411418
"""Set the final data output feature names."""
412419
raise NotImplementedError
413420

414-
@abstractmethod
415-
def _pre_init(self, **kwargs: Any) -> None:
416-
"""Assign member varibales within `__init__()`.
417-
418-
E.g. Necessary member varibales for the `_set_output_feature_names()`
419-
call in `__init__()`.
420-
"""
421-
raise NotImplementedError
422-
423421
@abstractmethod
424422
def _create_data(self, input_features: torch.Tensor) -> Data:
425423
"""Create data from input features.

src/graphnet/models/data_representation/graphs/graph_definition.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
class GraphDefinition(DataRepresentation):
2121
"""An Abstract class to create graph definitions from."""
2222

23-
def _pre_init( # type: ignore[override]
24-
self, node_definition: NodeDefinition
25-
) -> None:
26-
"""Pre-initialization steps."""
27-
# Pre-initialization steps
28-
self._node_definition = node_definition
29-
3023
def __init__(
3124
self,
3225
detector: Detector,
@@ -81,9 +74,6 @@ def __init__(
8174
add_static_features: If True, the original features will be
8275
added as static attributes to the graph. Defaults to True.
8376
"""
84-
if node_definition is None:
85-
node_definition = NodesAsPulses()
86-
8777
# Base class constructor
8878
super().__init__(
8979
detector=detector,
@@ -95,9 +85,11 @@ def __init__(
9585
sensor_mask=sensor_mask,
9686
string_mask=string_mask,
9787
repeat_labels=repeat_labels,
98-
node_definition=node_definition, # -> kwargs
9988
)
10089

90+
if node_definition is None:
91+
node_definition = NodesAsPulses()
92+
self._node_definition = node_definition
10193
self._edge_definition = edge_definition
10294
if self._edge_definition is None:
10395
self.warning_once(

0 commit comments

Comments
 (0)