@@ -149,42 +149,10 @@ def __call__(self, frame: I3Frame) -> bool:
149149 return True
150150 # inference
151151 memory_watch = False
152- if self ._inference_speed_check is True :
153- # create log file if it does not exist
154- data_repr_start = time ()
155152 try :
156- if not self .multiple_models :
157- data = self ._create_data_representation (frame = frame ).to (
158- self ._device
159- )
160- if self ._inference_speed_check is True :
161- data_repr_end = time ()
162- data_repr_time = data_repr_end - data_repr_start
163- inference_start = time ()
164- predictions = self ._apply_model (data = data )
165- else :
166- features = self ._extract_feature_array_from_frame (frame = frame )
167- if self ._inference_speed_check is True :
168- data_repr_end = time ()
169- data_repr_time = data_repr_end - data_repr_start
170- inference_start = time ()
171- model_input_data = []
172- for _ , graph_definition in enumerate (self ._graph_definitions ):
173- data = graph_definition (
174- input_features = features [_ ],
175- input_feature_names = self .features_list [_ ],
176- )
177- model_input_data .append (Batch .from_data_list ([data ]))
178-
179- predictions = self ._apply_model (data = model_input_data )
180-
181- if self ._inference_speed_check is True :
182- inference_end = time ()
183- inference_time = inference_end - inference_start
184- self ._logger .info (
185- f"Data representation time: { data_repr_time :.4f} s\n "
186- f"Inference time: { inference_time :.4f} s\n "
187- )
153+ predictions , data_repr_time , inference_time = (
154+ self ._create_data_and_apply (frame = frame )
155+ )
188156
189157 except OutOfMemoryError :
190158 self .error (
@@ -195,22 +163,12 @@ def __call__(self, frame: I3Frame) -> bool:
195163 save_device = self ._device
196164 self ._device = "cpu"
197165 self .model .to (self ._device )
198- data = self ._create_data_representation (frame = frame )
199- if self ._inference_speed_check is True :
200- data_repr_end = time ()
201- data_repr_time = data_repr_end - data_repr_start
202- inference_start = time ()
203- predictions = self ._apply_model (data = data )
204- if self ._inference_speed_check is True :
205- inference_end = time ()
206- inference_time = inference_end - inference_start
207- self ._logger .info (
208- f"Data representation time: { data_repr_time :.4f} s\n "
209- f"Inference time: { inference_time :.4f} s\n "
210- )
211- self ._device = save_device
166+
167+ predictions , data_repr_time , inference_time = (
168+ self ._create_data_and_apply (frame = frame )
169+ )
212170 memory_watch = True
213- del data
171+ self . _device = save_device
214172
215173 if self ._inference_speed_check is True :
216174 write_start = time ()
@@ -238,6 +196,49 @@ def __call__(self, frame: I3Frame) -> bool:
238196 self .model .to (self ._device )
239197 return True
240198
199+ def _create_data_and_apply (self , frame : I3Frame ) -> tuple :
200+ data_repr_time = - 1
201+ inference_time = - 1
202+ if self ._inference_speed_check is True :
203+ # create log file if it does not exist
204+ data_repr_start = time ()
205+ if not self .multiple_models :
206+ data = self ._create_data_representation (frame = frame ).to (
207+ self ._device
208+ )
209+ if self ._inference_speed_check is True :
210+ data_repr_end = time ()
211+ data_repr_time = data_repr_end - data_repr_start
212+ inference_start = time ()
213+ predictions = self ._apply_model (data = data )
214+ else :
215+ features = self ._extract_feature_array_from_frame (frame = frame )
216+ if self ._inference_speed_check is True :
217+ data_repr_end = time ()
218+ data_repr_time = data_repr_end - data_repr_start
219+ inference_start = time ()
220+ model_input_data = []
221+ for _ , graph_definition in enumerate (self ._graph_definitions ):
222+ data = graph_definition (
223+ input_features = features [_ ],
224+ input_feature_names = self .features_list [_ ],
225+ )
226+ model_input_data .append (
227+ Batch .from_data_list ([data .to (self ._device )])
228+ )
229+
230+ predictions = self ._apply_model (data = model_input_data )
231+
232+ if self ._inference_speed_check is True :
233+ inference_end = time ()
234+ inference_time = inference_end - inference_start
235+ self ._logger .info (
236+ f"Data representation time: { data_repr_time :.4f} s\n "
237+ f"Inference time: { inference_time :.4f} s\n "
238+ )
239+ del data
240+ return predictions , data_repr_time , inference_time
241+
241242 def _check_dimensions (self , predictions : np .ndarray ) -> int :
242243 if len (predictions .shape ) > 1 :
243244 dim = predictions .shape [1 ]
@@ -284,7 +285,7 @@ def _apply_model(self, data: Data) -> np.ndarray:
284285 """Apply model to `Data` and case-handling."""
285286 if data is not None :
286287 predictions = self ._inference (data )
287- #print(predictions, type(predictions), type(predictions[0]))
288+ # print(predictions, type(predictions), type(predictions[0]))
288289 if isinstance (predictions , list ):
289290 predictions = np .concatenate (
290291 [pred .flatten () for pred in predictions ]
@@ -316,8 +317,8 @@ def _create_data_representation(self, frame: I3Frame) -> Data:
316317 data = self ._graph_definition (
317318 input_features = input_features ,
318319 input_feature_names = self ._features ,
319- )
320- return Batch .from_data_list ([data . to ( self . _device ) ])
320+ ). to ( self . _device )
321+ return Batch .from_data_list ([data ])
321322 else :
322323 return None
323324
@@ -421,10 +422,10 @@ def __init__(
421422 len (self ._positions ) == 3
422423 ), "positions must be a list of 3 elements"
423424
424- def _get_min_time (self , frame : I3Frame ) -> float :
425+ def _get_min_time (self , frame : I3Frame , pulsemap : str ) -> float :
425426 """Get the minimum time of the first pulse in the frame."""
426427 min_time = np .inf
427- doms = frame [self . _pulsemap ].apply (frame ).values ()
428+ doms = frame [pulsemap ].apply (frame ).values ()
428429 # seach for the minimum time
429430 for dom in doms :
430431 if dom [0 ].time < min_time :
@@ -451,10 +452,11 @@ def _add_to_frame(self, frame, data):
451452
452453 if self ._shift_time :
453454 # Shift time to be relative to the first pulse
454- shift_time = (
455- self ._get_min_time (frame )
456- - frame ["CVStatistics" ].min_pulse_time
457- )
455+ shift_time = self ._get_min_time (frame , self ._pulsemap )
456+ if "CVStatistics" in frame :
457+ shift_time -= frame ["CVStatistics" ].min_pulse_time
458+ else :
459+ shift_time -= self ._get_min_time (frame , "InIcePulses" )
458460 particle .time = data [self ._time ].value + shift_time
459461 else :
460462 particle .time = data [self ._time ].value
@@ -479,6 +481,7 @@ def _add_to_frame(self, frame, data):
479481 super ()._add_to_frame (frame = frame , data = data )
480482 return
481483
484+
482485class I3MultipleModelInferenceModule (I3InferenceModule ):
483486 """I3InferenceModule for I3Particle data."""
484487
@@ -499,4 +502,3 @@ def _add_to_frame(self, frame, data):
499502
500503 i3_score_container = dataclasses .I3MapStringDouble (data )
501504 frame .Put (self ._key_name , i3_score_container )
502-
0 commit comments