|
1 | 1 | """Base detector-specific `Model` class(es).""" |
2 | 2 |
|
3 | 3 | from abc import abstractmethod |
4 | | -from typing import Dict, Callable, List |
| 4 | +from typing import Dict, Callable, List, Optional |
5 | 5 |
|
6 | 6 | from torch_geometric.data import Data |
7 | 7 | import torch |
|
14 | 14 | class Detector(Model): |
15 | 15 | """Base class for all detector-specific read-ins in graphnet.""" |
16 | 16 |
|
17 | | - def __init__(self) -> None: |
18 | | - """Construct `Detector`.""" |
| 17 | + def __init__( |
| 18 | + self, replace_with_identity: Optional[List[str]] = None |
| 19 | + ) -> None: |
| 20 | + """Construct `Detector`. |
| 21 | +
|
| 22 | + Args: |
| 23 | + replace_with_identity: A list of feature names from the |
| 24 | + feature_map that should be replaced with the identity |
| 25 | + function. |
| 26 | + """ |
19 | 27 | # Base class constructor |
20 | 28 | super().__init__(name=__name__, class_name=self.__class__.__name__) |
| 29 | + self._replace_with_identity = replace_with_identity |
21 | 30 |
|
22 | 31 | @abstractmethod |
23 | 32 | def feature_map(self) -> Dict[str, Callable]: |
@@ -64,9 +73,13 @@ def sensor_index_name(self) -> str: |
64 | 73 | def _standardize( |
65 | 74 | self, input_features: torch.tensor, input_feature_names: List[str] |
66 | 75 | ) -> Data: |
| 76 | + feature_map = self.feature_map() |
| 77 | + if self._replace_with_identity is not None: |
| 78 | + for feature in self._replace_with_identity: |
| 79 | + feature_map[feature] = self._identity |
67 | 80 | for idx, feature in enumerate(input_feature_names): |
68 | 81 | try: |
69 | | - input_features[:, idx] = self.feature_map()[ |
| 82 | + input_features[:, idx] = feature_map[ |
70 | 83 | feature |
71 | 84 | ]( # noqa: E501 # type: ignore |
72 | 85 | input_features[:, idx] |
|
0 commit comments