Skip to content

Commit c826a2d

Browse files
authored
Merge pull request graphnet-team#769 from Aske-Rosted/replace_with_identity
Optional replacing of feature mapping with identity
2 parents 4898fb4 + 0966f89 commit c826a2d

1 file changed

Lines changed: 17 additions & 4 deletions

File tree

src/graphnet/models/detector/detector.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Base detector-specific `Model` class(es)."""
22

33
from abc import abstractmethod
4-
from typing import Dict, Callable, List
4+
from typing import Dict, Callable, List, Optional
55

66
from torch_geometric.data import Data
77
import torch
@@ -14,10 +14,19 @@
1414
class Detector(Model):
1515
"""Base class for all detector-specific read-ins in graphnet."""
1616

17-
def __init__(self) -> None:
18-
"""Construct `Detector`."""
17+
def __init__(
18+
self, replace_with_identity: Optional[List[str]] = None
19+
) -> None:
20+
"""Construct `Detector`.
21+
22+
Args:
23+
replace_with_identity: A list of feature names from the
24+
feature_map that should be replaced with the identity
25+
function.
26+
"""
1927
# Base class constructor
2028
super().__init__(name=__name__, class_name=self.__class__.__name__)
29+
self._replace_with_identity = replace_with_identity
2130

2231
@abstractmethod
2332
def feature_map(self) -> Dict[str, Callable]:
@@ -64,9 +73,13 @@ def sensor_index_name(self) -> str:
6473
def _standardize(
6574
self, input_features: torch.tensor, input_feature_names: List[str]
6675
) -> Data:
76+
feature_map = self.feature_map()
77+
if self._replace_with_identity is not None:
78+
for feature in self._replace_with_identity:
79+
feature_map[feature] = self._identity
6780
for idx, feature in enumerate(input_feature_names):
6881
try:
69-
input_features[:, idx] = self.feature_map()[
82+
input_features[:, idx] = feature_map[
7083
feature
7184
]( # noqa: E501 # type: ignore
7285
input_features[:, idx]

0 commit comments

Comments
 (0)