|
5 | 5 | and can be passed to dataloaders during training and deployment. |
6 | 6 | """ |
7 | 7 |
|
8 | | -from typing import List, Optional, Dict, Union, Tuple |
| 8 | +from typing import List, Optional, Dict, Union, Tuple, Any, Callable |
9 | 9 | import torch |
10 | 10 | from numpy.random import Generator |
| 11 | +import numpy as np |
11 | 12 |
|
12 | 13 | from graphnet.models.detector import Detector |
13 | 14 | from .edges import EdgeDefinition |
@@ -144,6 +145,55 @@ def _create_data( |
144 | 145 |
|
145 | 146 | return data, data_feature_names |
146 | 147 |
|
| 148 | + def forward( # type: ignore |
| 149 | + self, |
| 150 | + input_features: np.ndarray, |
| 151 | + input_feature_names: List[str], |
| 152 | + truth_dicts: Optional[List[Dict[str, Any]]] = None, |
| 153 | + custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None, |
| 154 | + loss_weight_column: Optional[str] = None, |
| 155 | + loss_weight: Optional[float] = None, |
| 156 | + loss_weight_default_value: Optional[float] = None, |
| 157 | + data_path: Optional[str] = None, |
| 158 | + ) -> Data: |
| 159 | + """Construct graph as ´Data´ object. |
| 160 | +
|
| 161 | + Args: |
| 162 | + input_features: Input features for graph construction. |
| 163 | + Shape ´[num_rows, d]´ |
| 164 | + input_feature_names: name of each column. Shape ´[,d]´. |
| 165 | + truth_dicts: Dictionary containing truth labels. |
| 166 | + custom_label_functions: Custom label functions. |
| 167 | + loss_weight_column: Name of column that holds loss weight. |
| 168 | + Defaults to None. |
| 169 | + loss_weight: Loss weight associated with event. Defaults to None. |
| 170 | + loss_weight_default_value: default value for loss weight. |
| 171 | + Used in instances where some events have |
| 172 | + no pre-defined loss weight. Defaults to None. |
| 173 | + data_path: Path to dataset data files. Defaults to None. |
| 174 | +
|
| 175 | + Returns: |
| 176 | + graph |
| 177 | + """ |
| 178 | + data = super().forward( |
| 179 | + input_features=input_features, |
| 180 | + input_feature_names=input_feature_names, |
| 181 | + truth_dicts=truth_dicts, |
| 182 | + custom_label_functions=custom_label_functions, |
| 183 | + loss_weight_column=loss_weight_column, |
| 184 | + loss_weight=loss_weight, |
| 185 | + loss_weight_default_value=loss_weight_default_value, |
| 186 | + data_path=data_path, |
| 187 | + ) |
| 188 | + |
| 189 | + if self._add_static_features: |
| 190 | + data = self._add_features_individually( |
| 191 | + data, |
| 192 | + self._node_definition._output_feature_names, |
| 193 | + ) |
| 194 | + |
| 195 | + return data |
| 196 | + |
147 | 197 | def _forward_end( |
148 | 198 | self, |
149 | 199 | data: Data, |
|
0 commit comments