Skip to content

Commit cb54822

Browse files
committed
Merge remote-tracking branch 'maxwell_fork/both_models' into multitask_bundle_combine
2 parents ac63d3a + 0bdf36b commit cb54822

4 files changed

Lines changed: 61 additions & 33 deletions

File tree

src/graphnet/deployment/deployment_module.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def _resolve_prediction_columns(
111111
prediction_columns[i]
112112
)
113113
else:
114-
resolved_prediction_columns.append(model.prediction_labels)
114+
# Only Take First Label
115+
resolved_prediction_columns.append(model.prediction_labels[0])
115116
return resolved_prediction_columns
116117

117118
def _inference(self, data: Union[Data, Batch]) -> List[np.ndarray]:
@@ -139,5 +140,9 @@ def _inference(self, data: Union[Data, Batch]) -> List[np.ndarray]:
139140
output = model(data=data[_])
140141
for k in range(len(output)):
141142
output[k] = output[k].detach().cpu().numpy()
143+
if output[0].shape[1] > 1:
144+
output = np.delete(output[0], 1, axis=1)
145+
else:
146+
output = output[0]
142147
outputs.append(output)
143148
return outputs
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Deployment modules specific to IceCube."""
22

3-
from .inference_module import I3InferenceModule, I3ParticleInferenceModule
3+
from .inference_module import I3InferenceModule, I3ParticleInferenceModule, I3MultipleModelInferenceModule
44
from .cleaning_module import I3PulseCleanerModule
55
from .i3deployer import I3Deployer

src/graphnet/deployment/icecube/inference_module.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
"""
55

66
from typing import List, Union, Optional, TYPE_CHECKING, Dict, Any
7-
87
import numpy as np
98
from torch_geometric.data import Data, Batch
10-
119
from torch.cuda import OutOfMemoryError
1210
from time import sleep
1311
import torch
@@ -51,7 +49,6 @@ def __init__(
5149
prediction_columns: Optional[Union[List[str], None]] = None,
5250
pulsemap: Optional[str] = None,
5351
multiple_models: bool = False,
54-
key_name: Optional[str] = None,
5552
requirements: Optional[callable] = None,
5653
device: Optional[str] = "cpu",
5754
batch_size: Optional[int] = 1,
@@ -75,8 +72,6 @@ def __init__(
7572
E.g. ['energy_reco']. Optional.
7673
pulsemap: the pulsemap that the model is expecting as input.
7774
multiple_models: process multiple models with the same feature set at once.
78-
key_name: The name used for the key in the I3Frame. Will help define the
79-
named entry in the I3Frame. E.g. "dynedge_predictions".
8075
"""
8176
super().__init__(
8277
model_config=model_config,
@@ -130,7 +125,6 @@ def __init__(
130125
self._num_threads = num_threads
131126
self._inference_speed_check = inference_speed_check
132127
self._multiple_models = multiple_models
133-
self._key_name = key_name
134128
# Set GCD file for pulsemap extractor
135129
if gcd_file is not None:
136130
for i3_extractor in self._i3_extractors:
@@ -266,26 +260,31 @@ def _create_dictionary(
266260
) -> Dict[str, Any]:
267261
"""Transform predictions into a dictionary."""
268262
data = {}
269-
for i in range(dim):
270-
data[self.model_name + "_" + self.prediction_columns[i]] = (
271-
I3Double(float(predictions[i]))
272-
)
263+
if self._multiple_models == True:
264+
for i, key in enumerate(self.model_name):
265+
data[key] = float(predictions[i])
266+
else:
267+
for i in range(dim):
268+
data[self.model_name + "_" + self.prediction_columns[i]] = (
269+
I3Double(float(predictions[i]))
270+
)
273271

274-
# try:
275-
# assert len(predictions[:, i]) == 1
276-
# data[
277-
# self.model_name + "_" + self.prediction_columns[i]
278-
# ] = I3Double(float(predictions[:, i][0]))
279-
# except IndexError:
280-
# data[
281-
# self.model_name + "_" + self.prediction_columns[i]
282-
# ] = I3Double(predictions[0])
272+
# try:
273+
# assert len(predictions[:, i]) == 1
274+
# data[
275+
# self.model_name + "_" + self.prediction_columns[i]
276+
# ] = I3Double(float(predictions[:, i][0]))
277+
# except IndexError:
278+
# data[
279+
# self.model_name + "_" + self.prediction_columns[i]
280+
# ] = I3Double(predictions[0])
283281
return data
284282

285283
def _apply_model(self, data: Data) -> np.ndarray:
286284
"""Apply model to `Data` and case-handling."""
287285
if data is not None:
288286
predictions = self._inference(data)
287+
#print(predictions, type(predictions), type(predictions[0]))
289288
if isinstance(predictions, list):
290289
predictions = np.concatenate(
291290
[pred.flatten() for pred in predictions]
@@ -479,3 +478,25 @@ def _add_to_frame(self, frame, data):
479478
# for all the other values in data, add them to an I3Dictionary
480479
super()._add_to_frame(frame=frame, data=data)
481480
return
481+
482+
class I3MultipleModelInferenceModule(I3InferenceModule):
483+
"""I3InferenceModule for running multiple models at once."""
484+
485+
def __init__(
486+
self,
487+
key_name: List[str],
488+
**kwargs,
489+
):
490+
"""
491+
key_name: The name used for the key in the I3Frame. Will help define the
492+
named entry in the I3Frame.
493+
"""
494+
super().__init__(**kwargs)
495+
496+
self._key_name = key_name
497+
498+
def _add_to_frame(self, frame, data):
499+
500+
i3_score_container = dataclasses.I3MapStringDouble(data)
501+
frame.Put(self._key_name, i3_score_container)
502+

src/graphnet/models/standard_model.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,20 @@ def __init__(
9494
self._data_representation = data_representation
9595
self.backbone = backbone
9696
self._split = split
97-
assert (
98-
sum(self._split[0]) == self.backbone.nb_outputs
99-
), "Split dimensions do not match backbone output dimension check your configuration"
100-
101-
if learned_multitask_weights != -1:
102-
assert isinstance(tasks, list)
103-
# init the module for learned task weights
104-
self.loss_weight_balancing = LossWeightBalancing(
105-
tasks, late_activation=learned_multitask_weights
106-
)
107-
else:
108-
self.loss_weight_balancing = None
97+
98+
if self._split is not None:
99+
assert (
100+
sum(self._split[0]) == self.backbone.nb_outputs
101+
), "Split dimensions do not match backbone output dimension check your configuration"
102+
103+
if learned_multitask_weights != -1:
104+
assert isinstance(tasks, list)
105+
# init the module for learned task weights
106+
self.loss_weight_balancing = LossWeightBalancing(
107+
tasks, late_activation=learned_multitask_weights
108+
)
109+
else:
110+
self.loss_weight_balancing = None
109111

110112
def compute_loss(
111113
self, preds: Tensor, data: List[Data], verbose: bool = False

0 commit comments

Comments
 (0)