Skip to content

Commit bdc2b49

Browse files
committed
name changes and source truth info
1 parent 3439b5f commit bdc2b49

1 file changed

Lines changed: 127 additions & 54 deletions

File tree

src/graphnet/data/extractors/icecube/i3featureextractor.py

Lines changed: 127 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def __init__(
248248
mctree: str = "I3MCTree",
249249
mcpe_map: str = "I3MCPESeriesMapWithoutNoise",
250250
mcpe_map_id: str = "I3MCPESeriesMapParticleIDMap",
251+
pulse_source_map: Optional[str] = None,
251252
):
252253
"""Construct I3PulseOriginLabels.
253254
@@ -261,12 +262,14 @@ def __init__(
261262
mctree: Name of the MCTree in the I3 frame.
262263
mcpe_map: Name of the MCPE series map in the I3 frame.
263264
mcpe_map_id: Name of the MCPE series map particle ID map in the I3 frame.
265+
pulse_source_map: Name of the pulse source map in the I3 frame.
264266
"""
265267
super().__init__(pulsemap, exclude, extractor_name)
266268
self._time_window = time_window
267269
self._mctree = mctree
268270
self._mcpe_map = mcpe_map
269271
self._mcpe_map_id = mcpe_map_id
272+
self._pulse_source_map = pulse_source_map
270273

271274
def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
272275
"""Extract MCPE labels from `frame`.
@@ -285,14 +288,16 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
285288
"dom_y": [],
286289
"dom_z": [],
287290
"neutrino_fraction": [],
288-
"neutrino_npe_fraction": [],
291+
"noise_fraction": [],
289292
"npe": [],
290-
"pulse_count": [],
291-
"noise_hit": [],
293+
"hit_count": [],
292294
"trackness": [],
293295
"overlap_count": [],
294296
"min_time_delta": [],
297+
"total_npe_fraction": [],
295298
}
299+
if self._pulse_source_map is not None:
300+
output["source_truth"] = []
296301

297302
# Get OM data
298303
if self._pulsemap in frame:
@@ -309,26 +314,52 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
309314
# Loop over pulses for each OM
310315
pulses = data[om_key]
311316
pulse_times, pulse_charges = self._get_pulse_info(pulses)
312-
npe_list, times, nu_bool, track_like_list = self._get_mcpe_info(
313-
frame, om_key
317+
npe_list, times, nu_bool, track_like_list, noise_bool = (
318+
self._get_mcpe_info(frame, om_key)
314319
)
320+
315321
time_distance_matrix = pulse_times[:, None] - times[None, :]
316322
weight_matrix = self._get_gaussian_weight(time_distance_matrix) * (
317323
np.abs(time_distance_matrix) <= self._time_window
318324
)
319325

326+
source_truth = []
327+
if (self._pulse_source_map is not None) and frame.Has(
328+
self._pulse_source_map
329+
):
330+
source_truth_series = frame[self._pulse_source_map][om_key]
331+
source_t, source_q, source_truth = np.array(
332+
[[p.time, p.charge, p.source] for p in source_truth_series]
333+
).T
334+
# for each pulse, find the closest source truth
335+
time_diffs = np.abs(pulse_times[:, None] - source_t[None, :])
336+
# Record pulses for which there was no source truth within 50 ns
337+
no_source_within_50ns = np.min(time_diffs, axis=1) > 50.0
338+
closest_indices = np.argmin(time_diffs, axis=1)
339+
# assert that no pulses have the same source truth assigned
340+
# assert len(closest_indices) == len(set(closest_indices)), "Multiple pulses assigned to same source truth!"
341+
source_truth = source_truth[closest_indices]
342+
source_truth[no_source_within_50ns] = -1
343+
else:
344+
source_truth = [-1] * len(pulse_times) # no source found
345+
346+
total_mcpe_npe = np.sum(npe_list)
320347
with np.errstate(invalid="ignore"):
321348
weight_matrix /= np.sum(weight_matrix, axis=0, keepdims=True)
322349
weight_matrix = np.nan_to_num(weight_matrix, nan=0.0)
323350
pulse_counts = np.sum(weight_matrix, axis=1)
324-
neutrino_fractions = (weight_matrix @ nu_bool) / pulse_counts
325-
neutrino_npe_fractions = (
326-
weight_matrix @ (npe_list * nu_bool)
327-
) / (weight_matrix @ npe_list)
351+
noise_fractions = (weight_matrix @ noise_bool) / pulse_counts
352+
noise_fractions = np.nan_to_num(noise_fractions, nan=1.0)
353+
neutrino_fractions = (weight_matrix @ (npe_list * nu_bool)) / (
354+
weight_matrix @ npe_list
355+
)
328356
total_npe = weight_matrix @ npe_list
329357
trackness = weight_matrix @ track_like_list / pulse_counts
330358
min_time_deltas = (
331-
np.min(np.abs(time_distance_matrix), axis=1)
359+
time_distance_matrix[
360+
np.arange(time_distance_matrix.shape[0]),
361+
np.argmin(np.abs(time_distance_matrix), axis=1),
362+
]
332363
if len(times) > 0
333364
else np.array([np.nan] * len(pulse_times))
334365
)
@@ -340,11 +371,8 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
340371
overlap_counts = np.sum(overlap_counts, axis=1)
341372

342373
output["neutrino_fraction"].extend(neutrino_fractions.tolist())
343-
output["neutrino_npe_fraction"].extend(
344-
neutrino_npe_fractions.tolist()
345-
)
346374
output["npe"].extend(total_npe.tolist())
347-
output["pulse_count"].extend(pulse_counts.tolist())
375+
output["hit_count"].extend(pulse_counts.tolist())
348376
output["charge"].extend(pulse_charges.tolist())
349377
output["dom_time"].extend(pulse_times.tolist())
350378
output["dom_x"].extend(
@@ -356,16 +384,26 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
356384
output["dom_z"].extend(
357385
[self._gcd_dict[om_key].position.z] * len(pulse_times)
358386
)
359-
output["noise_hit"].extend((pulse_counts == 0).tolist())
387+
output["noise_fraction"].extend(
388+
noise_fractions.tolist()
389+
) # if noise is not present in the mcpe map, this will be a binary.
360390
output["min_time_delta"].extend(min_time_deltas.tolist())
361391
output["trackness"].extend(trackness.tolist())
362392
output["overlap_count"].extend(overlap_counts.tolist())
363-
393+
output["total_npe_fraction"].extend(
394+
(total_npe / total_mcpe_npe).tolist()
395+
)
396+
if self._pulse_source_map is not None:
397+
output["source_truth"].extend(
398+
source_truth
399+
if len(source_truth) == len(pulse_times)
400+
else [-1] * len(pulse_times)
401+
)
364402
return output
365403

366404
def _get_mcpe_info(
367405
self, frame: "icetray.I3Frame", om_key: "icetray.OMKey"
368-
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
406+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
369407
"""Determine the neutrino fraction of a pulse.
370408
371409
Args:
@@ -379,56 +417,91 @@ def _get_mcpe_info(
379417
nu_bool: list[bool] = []
380418
npe_list: list[float] = []
381419
track_like_list: list[float] = []
420+
noise_list: list[bool] = []
382421
try:
383-
mcpe_info = frame[self._mcpe_map][om_key]
422+
mcpe_info = np.array(frame[self._mcpe_map][om_key])
423+
except KeyError as e:
424+
if self._mcpe_map in str(e):
425+
self.warning_once(
426+
f"MCPE map {self._mcpe_map} not found in frame."
427+
)
428+
return (
429+
np.array(npe_list),
430+
np.array(times),
431+
np.array(nu_bool),
432+
np.array(track_like_list),
433+
np.array(noise_list),
434+
)
435+
elif "Invalid key" in str(e):
436+
return (
437+
np.array(npe_list),
438+
np.array(times),
439+
np.array(nu_bool),
440+
np.array(track_like_list),
441+
np.array(noise_list),
442+
)
443+
else:
444+
raise e
445+
446+
mcpe_id_map_keys = []
447+
mcpe_index_map = []
448+
try:
449+
for id_key, id_vals in frame[self._mcpe_map_id][om_key].items():
450+
mcpe_id_map_keys.extend([id_key] * len(id_vals))
451+
# mcpe_id_map_vals.extend(mcpe_info[id_vals])
452+
mcpe_index_map.extend(id_vals)
453+
384454
except KeyError:
385-
return (
386-
np.array(npe_list),
387-
np.array(times),
388-
np.array(nu_bool),
389-
np.array(track_like_list),
390-
)
455+
# This just means all the pulses are noise hits (do nothing)
456+
pass
457+
458+
mcpe_id_map_keys = np.array(mcpe_id_map_keys)
459+
mcpe_index_map = np.array(mcpe_index_map)
460+
391461
for i, mcpe in enumerate(mcpe_info):
392-
try:
393-
nu_primary = (
394-
frame[self._mctree].get_primary(mcpe.ID).is_neutrino
395-
)
396-
track_like = frame[self._mctree].get_particle(mcpe.ID).is_track
397-
nu_bool.append(nu_primary)
398-
except RuntimeError as e:
399-
# backup to using the mcpe id map to figure out the parent type, if any part of the mcpe has a neutrino as the primary, we count it as a neutrino mcpe this is a choice, but the information about which part of the mcpe corresponds to which primary is lost.
400-
if "particle not found" in str(e):
401-
ids = [
402-
id_p
403-
for id_p, indexval in frame[self._mcpe_map_id][
404-
om_key
405-
].items()
406-
if i in indexval
407-
]
408-
bool_val = any(
409-
[
410-
frame[self._mctree].get_primary(id_p).is_neutrino
411-
for id_p in ids
412-
]
413-
)
414-
track_like = [
415-
frame[self._mctree].get_particle(id_p).is_track
416-
for id_p in ids
417-
]
418-
track_like = sum(track_like) / len(track_like)
419-
nu_bool.append(bool_val)
462+
is_noise = False
463+
if mcpe.ID == dataclasses.I3ParticleID(0, 0):
464+
# Noise hit
465+
track_like = 0.0
466+
neutrino_bool = False
467+
is_noise = True
468+
469+
elif mcpe.ID != dataclasses.I3ParticleID(0, -1):
470+
# Not a multiple parent hit - use direct method
471+
particle = frame[self._mctree].get_particle(mcpe.ID)
472+
track_like = float(particle.is_track)
473+
neutrino_bool = particle.is_neutrino
474+
475+
else:
476+
# Hit with multiple parent particles - need to loop over all parent particles
477+
particle_list = [
478+
frame[self._mctree].get_particle(p)
479+
for p in mcpe_id_map_keys[i == mcpe_index_map]
480+
]
481+
if len(particle_list) == 0:
482+
warning_string = f"No parent particles found for MCPE with ID {mcpe.ID} in OMKey {om_key}.\n"
483+
warning_string += f"{frame['I3EventHeader']}\n"
484+
self.warning(warning_string)
485+
track_like = 0.0
486+
neutrino_bool = False
420487
else:
421-
raise e
488+
track_like = sum(
489+
[p.is_track for p in particle_list]
490+
) / len(particle_list)
491+
neutrino_bool = any([p.is_neutrino for p in particle_list])
422492

493+
track_like_list.append(track_like)
494+
nu_bool.append(neutrino_bool)
423495
times.append(mcpe.time)
424496
npe_list.append(mcpe.npe)
425-
track_like_list.append(track_like)
497+
noise_list.append(is_noise)
426498

427499
return (
428500
np.array(npe_list),
429501
np.array(times),
430502
np.array(nu_bool),
431503
np.array(track_like_list),
504+
np.array(noise_list),
432505
)
433506

434507
def _get_pulse_info(

0 commit comments

Comments
 (0)