44"""
55
66from typing import List , Union , Optional , TYPE_CHECKING , Dict , Any
7-
87import numpy as np
98from torch_geometric .data import Data , Batch
10-
119from torch .cuda import OutOfMemoryError
1210from time import sleep
1311import torch
@@ -51,7 +49,6 @@ def __init__(
5149 prediction_columns : Optional [Union [List [str ], None ]] = None ,
5250 pulsemap : Optional [str ] = None ,
5351 multiple_models : bool = False ,
54- key_name : Optional [str ] = None ,
5552 requirements : Optional [callable ] = None ,
5653 device : Optional [str ] = "cpu" ,
5754 batch_size : Optional [int ] = 1 ,
@@ -75,8 +72,6 @@ def __init__(
7572 E.g. ['energy_reco']. Optional.
7673 pulsemap: the pulsmap that the model is expecting as input.
7774 multiple_models: process multiple models with the same feature set at once.
78- key_name: The name used for the key in the I3Frame. Will help define the
79- named entry in the I3Frame. E.g. "dynedge_predictions".
8075 """
8176 super ().__init__ (
8277 model_config = model_config ,
@@ -130,7 +125,6 @@ def __init__(
130125 self ._num_threads = num_threads
131126 self ._inference_speed_check = inference_speed_check
132127 self ._multiple_models = multiple_models
133- self ._key_name = key_name
134128 # Set GCD file for pulsemap extractor
135129 if gcd_file is not None :
136130 for i3_extractor in self ._i3_extractors :
@@ -266,26 +260,31 @@ def _create_dictionary(
266260 ) -> Dict [str , Any ]:
267261 """Transform predictions into a dictionary."""
268262 data = {}
269- for i in range (dim ):
270- data [self .model_name + "_" + self .prediction_columns [i ]] = (
271- I3Double (float (predictions [i ]))
272- )
263+ if self ._multiple_models == True :
264+ for i , key in enumerate (self .model_name ):
265+ data [key ] = float (predictions [i ])
266+ else :
267+ for i in range (dim ):
268+ data [self .model_name + "_" + self .prediction_columns [i ]] = (
269+ I3Double (float (predictions [i ]))
270+ )
273271
274- # try:
275- # assert len(predictions[:, i]) == 1
276- # data[
277- # self.model_name + "_" + self.prediction_columns[i]
278- # ] = I3Double(float(predictions[:, i][0]))
279- # except IndexError:
280- # data[
281- # self.model_name + "_" + self.prediction_columns[i]
282- # ] = I3Double(predictions[0])
272+ # try:
273+ # assert len(predictions[:, i]) == 1
274+ # data[
275+ # self.model_name + "_" + self.prediction_columns[i]
276+ # ] = I3Double(float(predictions[:, i][0]))
277+ # except IndexError:
278+ # data[
279+ # self.model_name + "_" + self.prediction_columns[i]
280+ # ] = I3Double(predictions[0])
283281 return data
284282
285283 def _apply_model (self , data : Data ) -> np .ndarray :
286284 """Apply model to `Data` and case-handling."""
287285 if data is not None :
288286 predictions = self ._inference (data )
287+ #print(predictions, type(predictions), type(predictions[0]))
289288 if isinstance (predictions , list ):
290289 predictions = np .concatenate (
291290 [pred .flatten () for pred in predictions ]
@@ -479,3 +478,25 @@ def _add_to_frame(self, frame, data):
479478 # for all the other values in data, add them to an I3Dictionary
480479 super ()._add_to_frame (frame = frame , data = data )
481480 return
481+
482+ class I3MultipleModelInferenceModule (I3InferenceModule ):
483+ """I3InferenceModule for I3Particle data."""
484+
485+ def __init__ (
486+ self ,
487+ key_name : List [str ],
488+ ** kwargs ,
489+ ):
490+ """
491+ key_name: The name used for the key in the I3Frame. Will help define the
492+ named entry in the I3Frame.
493+ """
494+ super ().__init__ (** kwargs )
495+
496+ self ._key_name = key_name
497+
498+ def _add_to_frame (self , frame , data ):
499+
500+ i3_score_container = dataclasses .I3MapStringDouble (data )
501+ frame .Put (self ._key_name , i3_score_container )
502+
0 commit comments