@@ -248,6 +248,7 @@ def __init__(
248248 mctree : str = "I3MCTree" ,
249249 mcpe_map : str = "I3MCPESeriesMapWithoutNoise" ,
250250 mcpe_map_id : str = "I3MCPESeriesMapParticleIDMap" ,
251+ pulse_source_map : Optional [str ] = None ,
251252 ):
252253 """Construct I3PulseOriginLabels.
253254
@@ -261,12 +262,14 @@ def __init__(
261262 mctree: Name of the MCTree in the I3 frame.
262263 mcpe_map: Name of the MCPE series map in the I3 frame.
263264 mcpe_map_id: Name of the MCPE series map particle ID map in the I3 frame.
265+ pulse_source_map: Name of the pulse source map in the I3 frame.
264266 """
265267 super ().__init__ (pulsemap , exclude , extractor_name )
266268 self ._time_window = time_window
267269 self ._mctree = mctree
268270 self ._mcpe_map = mcpe_map
269271 self ._mcpe_map_id = mcpe_map_id
272+ self ._pulse_source_map = pulse_source_map
270273
271274 def __call__ (self , frame : "icetray.I3Frame" ) -> Dict [str , List [Any ]]:
272275 """Extract MCPE labels from `frame`.
@@ -285,14 +288,16 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
285288 "dom_y" : [],
286289 "dom_z" : [],
287290 "neutrino_fraction" : [],
288- "neutrino_npe_fraction " : [],
291+ "noise_fraction " : [],
289292 "npe" : [],
290- "pulse_count" : [],
291- "noise_hit" : [],
293+ "hit_count" : [],
292294 "trackness" : [],
293295 "overlap_count" : [],
294296 "min_time_delta" : [],
297+ "total_npe_fraction" : [],
295298 }
299+ if self ._pulse_source_map is not None :
300+ output ["source_truth" ] = []
296301
297302 # Get OM data
298303 if self ._pulsemap in frame :
@@ -309,26 +314,52 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
309314 # Loop over pulses for each OM
310315 pulses = data [om_key ]
311316 pulse_times , pulse_charges = self ._get_pulse_info (pulses )
312- npe_list , times , nu_bool , track_like_list = self . _get_mcpe_info (
313- frame , om_key
317+ npe_list , times , nu_bool , track_like_list , noise_bool = (
318+ self . _get_mcpe_info ( frame , om_key )
314319 )
320+
315321 time_distance_matrix = pulse_times [:, None ] - times [None , :]
316322 weight_matrix = self ._get_gaussian_weight (time_distance_matrix ) * (
317323 np .abs (time_distance_matrix ) <= self ._time_window
318324 )
319325
326+ source_truth = []
327+ if (self ._pulse_source_map is not None ) and frame .Has (
328+ self ._pulse_source_map
329+ ):
330+ source_truth_series = frame [self ._pulse_source_map ][om_key ]
331+ source_t , source_q , source_truth = np .array (
332+ [[p .time , p .charge , p .source ] for p in source_truth_series ]
333+ ).T
334+ # for each pulse, find the closest source truth
335+ time_diffs = np .abs (pulse_times [:, None ] - source_t [None , :])
336+ # Record where there where no source truth within 50ns
337+ no_source_within_50ns = np .min (time_diffs , axis = 1 ) > 50.0
338+ closest_indices = np .argmin (time_diffs , axis = 1 )
339+ # assert that no pulses have the same source truth assigned
340+ # assert len(closest_indices) == len(set(closest_indices)), "Multiple pulses assigned to same source truth!"
341+ source_truth = source_truth [closest_indices ]
342+ source_truth [no_source_within_50ns ] = - 1
343+ else :
344+ source_truth = [- 1 ] * len (pulse_times ) # no source found
345+
346+ total_mcpe_npe = np .sum (npe_list )
320347 with np .errstate (invalid = "ignore" ):
321348 weight_matrix /= np .sum (weight_matrix , axis = 0 , keepdims = True )
322349 weight_matrix = np .nan_to_num (weight_matrix , nan = 0.0 )
323350 pulse_counts = np .sum (weight_matrix , axis = 1 )
324- neutrino_fractions = (weight_matrix @ nu_bool ) / pulse_counts
325- neutrino_npe_fractions = (
326- weight_matrix @ (npe_list * nu_bool )
327- ) / (weight_matrix @ npe_list )
351+ noise_fractions = (weight_matrix @ noise_bool ) / pulse_counts
352+ noise_fractions = np .nan_to_num (noise_fractions , nan = 1.0 )
353+ neutrino_fractions = (weight_matrix @ (npe_list * nu_bool )) / (
354+ weight_matrix @ npe_list
355+ )
328356 total_npe = weight_matrix @ npe_list
329357 trackness = weight_matrix @ track_like_list / pulse_counts
330358 min_time_deltas = (
331- np .min (np .abs (time_distance_matrix ), axis = 1 )
359+ time_distance_matrix [
360+ np .arange (time_distance_matrix .shape [0 ]),
361+ np .argmin (np .abs (time_distance_matrix ), axis = 1 ),
362+ ]
332363 if len (times ) > 0
333364 else np .array ([np .nan ] * len (pulse_times ))
334365 )
@@ -340,11 +371,8 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
340371 overlap_counts = np .sum (overlap_counts , axis = 1 )
341372
342373 output ["neutrino_fraction" ].extend (neutrino_fractions .tolist ())
343- output ["neutrino_npe_fraction" ].extend (
344- neutrino_npe_fractions .tolist ()
345- )
346374 output ["npe" ].extend (total_npe .tolist ())
347- output ["pulse_count " ].extend (pulse_counts .tolist ())
375+ output ["hit_count " ].extend (pulse_counts .tolist ())
348376 output ["charge" ].extend (pulse_charges .tolist ())
349377 output ["dom_time" ].extend (pulse_times .tolist ())
350378 output ["dom_x" ].extend (
@@ -356,16 +384,26 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
356384 output ["dom_z" ].extend (
357385 [self ._gcd_dict [om_key ].position .z ] * len (pulse_times )
358386 )
359- output ["noise_hit" ].extend ((pulse_counts == 0 ).tolist ())
387+ output ["noise_fraction" ].extend (
388+ noise_fractions .tolist ()
389+ ) # if noise is not present in the mcpe map, this will be a binary.
360390 output ["min_time_delta" ].extend (min_time_deltas .tolist ())
361391 output ["trackness" ].extend (trackness .tolist ())
362392 output ["overlap_count" ].extend (overlap_counts .tolist ())
363-
393+ output ["total_npe_fraction" ].extend (
394+ (total_npe / total_mcpe_npe ).tolist ()
395+ )
396+ if self ._pulse_source_map is not None :
397+ output ["source_truth" ].extend (
398+ source_truth
399+ if len (source_truth ) == len (pulse_times )
400+ else [- 1 ] * len (pulse_times )
401+ )
364402 return output
365403
366404 def _get_mcpe_info (
367405 self , frame : "icetray.I3Frame" , om_key : "icetray.OMKey"
368- ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
406+ ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , np . ndarray ]:
369407 """Determine the neutrino fraction of a pulse.
370408
371409 Args:
@@ -379,56 +417,91 @@ def _get_mcpe_info(
379417 nu_bool : list [bool ] = []
380418 npe_list : list [float ] = []
381419 track_like_list : list [float ] = []
420+ noise_list : list [bool ] = []
382421 try :
383- mcpe_info = frame [self ._mcpe_map ][om_key ]
422+ mcpe_info = np .array (frame [self ._mcpe_map ][om_key ])
423+ except KeyError as e :
424+ if self ._mcpe_map in str (e ):
425+ self .warning_once (
426+ f"MCPE map { self ._mcpe_map } not found in frame."
427+ )
428+ return (
429+ np .array (npe_list ),
430+ np .array (times ),
431+ np .array (nu_bool ),
432+ np .array (track_like_list ),
433+ np .array (noise_list ),
434+ )
435+ elif "Invalid key" in str (e ):
436+ return (
437+ np .array (npe_list ),
438+ np .array (times ),
439+ np .array (nu_bool ),
440+ np .array (track_like_list ),
441+ np .array (noise_list ),
442+ )
443+ else :
444+ raise e
445+
446+ mcpe_id_map_keys = []
447+ mcpe_index_map = []
448+ try :
449+ for id_key , id_vals in frame [self ._mcpe_map_id ][om_key ].items ():
450+ mcpe_id_map_keys .extend ([id_key ] * len (id_vals ))
451+ # mcpe_id_map_vals.extend(mcpe_info[id_vals])
452+ mcpe_index_map .extend (id_vals )
453+
384454 except KeyError :
385- return (
386- np . array ( npe_list ),
387- np . array ( times ),
388- np .array (nu_bool ),
389- np .array (track_like_list ),
390- )
455+ # This just means all the pulses are noise hits (do nothing)
456+ pass
457+
458+ mcpe_id_map_keys = np .array (mcpe_id_map_keys )
459+ mcpe_index_map = np .array (mcpe_index_map )
460+
391461 for i , mcpe in enumerate (mcpe_info ):
392- try :
393- nu_primary = (
394- frame [self ._mctree ].get_primary (mcpe .ID ).is_neutrino
395- )
396- track_like = frame [self ._mctree ].get_particle (mcpe .ID ).is_track
397- nu_bool .append (nu_primary )
398- except RuntimeError as e :
399- # backup to using the mcpe id map to figure out the parent type, if any part of the mcpe has a neutrino as the primary, we count it as a neutrino mcpe this is a choice, but the information about which part of the mcpe corresponds to which primary is lost.
400- if "particle not found" in str (e ):
401- ids = [
402- id_p
403- for id_p , indexval in frame [self ._mcpe_map_id ][
404- om_key
405- ].items ()
406- if i in indexval
407- ]
408- bool_val = any (
409- [
410- frame [self ._mctree ].get_primary (id_p ).is_neutrino
411- for id_p in ids
412- ]
413- )
414- track_like = [
415- frame [self ._mctree ].get_particle (id_p ).is_track
416- for id_p in ids
417- ]
418- track_like = sum (track_like ) / len (track_like )
419- nu_bool .append (bool_val )
462+ is_noise = False
463+ if mcpe .ID == dataclasses .I3ParticleID (0 , 0 ):
464+ # Noise hit
465+ track_like = 0.0
466+ neutrino_bool = False
467+ is_noise = True
468+
469+ elif mcpe .ID != dataclasses .I3ParticleID (0 , - 1 ):
470+ # Not a multiple parent hit - use direct method
471+ particle = frame [self ._mctree ].get_particle (mcpe .ID )
472+ track_like = float (particle .is_track )
473+ neutrino_bool = particle .is_neutrino
474+
475+ else :
476+ # Hit with multiple parent particles - need to loop over all parent particles
477+ particle_list = [
478+ frame [self ._mctree ].get_particle (p )
479+ for p in mcpe_id_map_keys [i == mcpe_index_map ]
480+ ]
481+ if len (particle_list ) == 0 :
482+ warning_string = f"No parent particles found for MCPE with ID { mcpe .ID } in OMKey { om_key } .\n "
483+ warning_string += f"{ frame ['I3EventHeader' ]} \n "
484+ self .warning (warning_string )
485+ track_like = 0.0
486+ neutrino_bool = False
420487 else :
421- raise e
488+ track_like = sum (
489+ [p .is_track for p in particle_list ]
490+ ) / len (particle_list )
491+ neutrino_bool = any ([p .is_neutrino for p in particle_list ])
422492
493+ track_like_list .append (track_like )
494+ nu_bool .append (neutrino_bool )
423495 times .append (mcpe .time )
424496 npe_list .append (mcpe .npe )
425- track_like_list .append (track_like )
497+ noise_list .append (is_noise )
426498
427499 return (
428500 np .array (npe_list ),
429501 np .array (times ),
430502 np .array (nu_bool ),
431503 np .array (track_like_list ),
504+ np .array (noise_list ),
432505 )
433506
434507 def _get_pulse_info (
0 commit comments