Skip to content

Commit 282cd5c

Browse files
committed
fix word wrapping
1 parent d2204bb commit 282cd5c

1 file changed

Lines changed: 96 additions & 96 deletions

File tree

Lines changed: 96 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,96 @@
1-
"""Base detector-specific `Model` class(es)."""
2-
3-
from abc import abstractmethod
4-
from typing import Dict, Callable, List, Optional
5-
6-
from torch_geometric.data import Data
7-
import torch
8-
import pandas as pd
9-
10-
from graphnet.models import Model
11-
from graphnet.utilities.decorators import final
12-
13-
14-
class Detector(Model):
15-
"""Base class for all detector-specific read-ins in graphnet."""
16-
17-
def __init__(
18-
self, replace_with_identity: Optional[List[str]] = None
19-
) -> None:
20-
"""Construct `Detector`.
21-
22-
Args:
23-
replace_with_identity: A list of feature names from the
24-
feature_map that should be replaced with the identity
25-
function.
26-
"""
27-
# Base class constructor
28-
super().__init__(name=__name__, class_name=self.__class__.__name__)
29-
self._replace_with_identity = replace_with_identity
30-
31-
@abstractmethod
32-
def feature_map(self) -> Dict[str, Callable]:
33-
"""List of features used/assumed by inheriting `Detector` objects."""
34-
35-
@final
36-
def forward( # type: ignore
37-
self, input_features: torch.tensor, input_feature_names: List[str]
38-
) -> Data:
39-
"""Pre-process graph `Data` features and build graph adjacency."""
40-
return self._standardize(input_features, input_feature_names)
41-
42-
@property
43-
def geometry_table(self) -> pd.DataFrame:
44-
"""Public get method for retrieving a `Detector`s geometry table."""
45-
if ~hasattr(self, "_geometry_table"):
46-
try:
47-
assert hasattr(self, "geometry_table_path")
48-
except AssertionError as e:
49-
self.error(
50-
f"""{self.__class__.__name__} does not have class
51-
variable `geometry_table_path` set."""
52-
)
53-
raise e
54-
self._geometry_table = pd.read_parquet(self.geometry_table_path)
55-
return self._geometry_table
56-
57-
@property
58-
def string_index_name(self) -> str:
59-
"""Public get method for retrieving the string index column name."""
60-
return self.string_id_column
61-
62-
@property
63-
def sensor_position_names(self) -> List[str]:
64-
"""Public get method for retrieving the xyz coordinate column names."""
65-
return self.xyz
66-
67-
@property
68-
def sensor_index_name(self) -> str:
69-
"""Public get method for retrieving the sensor id column name."""
70-
return self.sensor_id_column
71-
72-
@final
73-
def _standardize(
74-
self, input_features: torch.tensor, input_feature_names: List[str]
75-
) -> Data:
76-
feature_map = self.feature_map()
77-
if self._replace_with_identity is not None:
78-
for feature in self._replace_with_identity:
79-
feature_map[feature] = self._identity
80-
for idx, feature in enumerate(input_feature_names):
81-
try:
82-
input_features[:, idx] = feature_map[
83-
feature
84-
]( # noqa: E501 # type: ignore
85-
input_features[:, idx]
86-
)
87-
except KeyError as e:
88-
self.warning(
89-
f"""No Standardization function found for '{feature}'"""
90-
)
91-
raise e
92-
return input_features
93-
94-
def _identity(self, x: torch.tensor) -> torch.tensor:
95-
"""Apply no standardization to input."""
96-
return x
1+
"""Base detector-specific `Model` class(es)."""
2+
3+
from abc import abstractmethod
4+
from typing import Dict, Callable, List, Optional
5+
6+
from torch_geometric.data import Data
7+
import torch
8+
import pandas as pd
9+
10+
from graphnet.models import Model
11+
from graphnet.utilities.decorators import final
12+
13+
14+
class Detector(Model):
15+
"""Base class for all detector-specific read-ins in graphnet."""
16+
17+
def __init__(
18+
self, replace_with_identity: Optional[List[str]] = None
19+
) -> None:
20+
"""Construct `Detector`.
21+
22+
Args:
23+
replace_with_identity: A list of feature names from the
24+
feature_map that should be replaced with the identity
25+
function.
26+
"""
27+
# Base class constructor
28+
super().__init__(name=__name__, class_name=self.__class__.__name__)
29+
self._replace_with_identity = replace_with_identity
30+
31+
@abstractmethod
32+
def feature_map(self) -> Dict[str, Callable]:
33+
"""List of features used/assumed by inheriting `Detector` objects."""
34+
35+
@final
36+
def forward( # type: ignore
37+
self, input_features: torch.tensor, input_feature_names: List[str]
38+
) -> Data:
39+
"""Pre-process graph `Data` features and build graph adjacency."""
40+
return self._standardize(input_features, input_feature_names)
41+
42+
@property
43+
def geometry_table(self) -> pd.DataFrame:
44+
"""Public get method for retrieving a `Detector`s geometry table."""
45+
if ~hasattr(self, "_geometry_table"):
46+
try:
47+
assert hasattr(self, "geometry_table_path")
48+
except AssertionError as e:
49+
self.error(
50+
f"""{self.__class__.__name__} does not have class
51+
variable `geometry_table_path` set."""
52+
)
53+
raise e
54+
self._geometry_table = pd.read_parquet(self.geometry_table_path)
55+
return self._geometry_table
56+
57+
@property
58+
def string_index_name(self) -> str:
59+
"""Public get method for retrieving the string index column name."""
60+
return self.string_id_column
61+
62+
@property
63+
def sensor_position_names(self) -> List[str]:
64+
"""Public get method for retrieving the xyz coordinate column names."""
65+
return self.xyz
66+
67+
@property
68+
def sensor_index_name(self) -> str:
69+
"""Public get method for retrieving the sensor id column name."""
70+
return self.sensor_id_column
71+
72+
@final
73+
def _standardize(
74+
self, input_features: torch.tensor, input_feature_names: List[str]
75+
) -> Data:
76+
feature_map = self.feature_map()
77+
if self._replace_with_identity is not None:
78+
for feature in self._replace_with_identity:
79+
feature_map[feature] = self._identity
80+
for idx, feature in enumerate(input_feature_names):
81+
try:
82+
input_features[:, idx] = feature_map[
83+
feature
84+
]( # noqa: E501 # type: ignore
85+
input_features[:, idx]
86+
)
87+
except KeyError as e:
88+
self.warning(
89+
f"""No Standardization function found for '{feature}'"""
90+
)
91+
raise e
92+
return input_features
93+
94+
def _identity(self, x: torch.tensor) -> torch.tensor:
95+
"""Apply no standardization to input."""
96+
return x

0 commit comments

Comments
 (0)