Skip to content

Commit 3439b5f

Browse files
committed
add pulseOriginLabels
1 parent 5fe7f9f commit 3439b5f

2 files changed

Lines changed: 249 additions & 7 deletions

File tree

src/graphnet/data/extractors/icecube/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from .i3extractor import I3Extractor
44
from .i3featureextractor import (
5-
I3FeatureExtractor,
5+
I3PulseLevelExtractor,
66
I3FeatureExtractorIceCube86,
77
I3FeatureExtractorIceCubeDeepCore,
88
I3FeatureExtractorIceCubeUpgrade,
99
I3PulseNoiseTruthFlagIceCubeUpgrade,
10+
I3PulseOriginLabels,
1011
)
1112
from .i3truthextractor import I3TruthExtractor
1213
from .i3retroextractor import I3RetroExtractor

src/graphnet/data/extractors/icecube/i3featureextractor.py

Lines changed: 247 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,65 @@
11
"""I3Extractor class(es) for extracting specific, reconstructed features."""
22

3-
from typing import TYPE_CHECKING, Any, Dict, List
3+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
44
from .i3extractor import I3Extractor
55
from graphnet.data.extractors.icecube.utilities.frames import (
66
get_om_keys_and_pulseseries,
77
)
88
from graphnet.utilities.imports import has_icecube_package
99

1010
if has_icecube_package() or TYPE_CHECKING:
11-
from icecube import icetray # pyright: reportMissingImports=false
11+
from icecube import (
12+
icetray,
13+
dataclasses,
14+
) # pyright: reportMissingImports=false
1215

16+
import numpy as np
1317

14-
class I3FeatureExtractor(I3Extractor):
18+
19+
class I3PulseLevelExtractor(I3Extractor):
1520
"""Base class for extracting specific, reconstructed features."""
1621

17-
def __init__(self, pulsemap: str, exclude: list = [None]):
18-
"""Construct I3FeatureExtractor.
22+
def __init__(
23+
self,
24+
pulsemap: str,
25+
exclude: list = [None],
26+
extractor_name: Optional[str] = None,
27+
):
28+
"""Construct I3PulseLevelExtractor.
1929
2030
Args:
2131
pulsemap: Name of the pulse (series) map for which to extract
2232
reconstructed features.
2333
exclude: List of keys to exclude from the extracted data.
34+
extractor_name: Name of the extractor.
2435
"""
2536
# Member variable(s)
2637
self._pulsemap = pulsemap
38+
if extractor_name is None:
39+
extractor_name = pulsemap
2740

2841
# Base class constructor
42+
super().__init__(extractor_name, exclude=exclude)
43+
44+
45+
class I3FeatureExtractor(I3PulseLevelExtractor):
46+
"""Old class now contained in I3PulseLevelExtractor."""
47+
48+
def __init__(self, pulsemap: str, exclude: list = [None]):
49+
"""Construct I3FeatureExtractor.
50+
51+
Args:
52+
pulsemap: Name of the pulse (series) map for which to extract
53+
reconstructed features.
54+
exclude: List of keys to exclude from the extracted data.
55+
"""
56+
self.warning_once(
57+
"I3FeatureExtractor is deprecated and will be removed in a future release. Please use I3PulseLevelExtractor instead."
58+
)
2959
super().__init__(pulsemap, exclude=exclude)
3060

3161

32-
class I3FeatureExtractorIceCube86(I3FeatureExtractor):
62+
class I3FeatureExtractorIceCube86(I3PulseLevelExtractor):
3363
"""Class for extracting reconstructed features for IceCube-86."""
3464

3565
def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
@@ -206,6 +236,217 @@ def _parse_awtd_flag(
206236
return pulse.width < (fadc_min_width_ns * icetray.I3Units.ns)
207237

208238

239+
class I3PulseOriginLabels(I3PulseLevelExtractor):
240+
"""Class for extracting MCPE labels for IceCube-86."""
241+
242+
def __init__(
243+
self,
244+
pulsemap: str,
245+
exclude: list = [None],
246+
extractor_name: str = "PulseOrigin",
247+
time_window: float = 10.0,
248+
mctree: str = "I3MCTree",
249+
mcpe_map: str = "I3MCPESeriesMapWithoutNoise",
250+
mcpe_map_id: str = "I3MCPESeriesMapParticleIDMap",
251+
):
252+
"""Construct I3PulseOriginLabels.
253+
254+
Args:
255+
pulsemap: Name of the pulse (series) map for which to extract
256+
reconstructed features.
257+
exclude: List of keys to exclude from the extracted data.
258+
extractor_name: Name of the extractor.
259+
time_window: Time window (in ns) around each pulse to consider
260+
MCPEs for label assignment.
261+
mctree: Name of the MCTree in the I3 frame.
262+
mcpe_map: Name of the MCPE series map in the I3 frame.
263+
mcpe_map_id: Name of the MCPE series map particle ID map in the I3 frame.
264+
"""
265+
super().__init__(pulsemap, exclude, extractor_name)
266+
self._time_window = time_window
267+
self._mctree = mctree
268+
self._mcpe_map = mcpe_map
269+
self._mcpe_map_id = mcpe_map_id
270+
271+
def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
272+
"""Extract MCPE labels from `frame`.
273+
274+
Args:
275+
frame: Physics (P) I3-frame from which to extract MCPE labels.
276+
277+
Returns:
278+
Dictionary of MCPE labels for all pulses in `pulsemap`,
279+
in pure-python format.
280+
"""
281+
output: Dict[str, List[Any]] = {
282+
"charge": [],
283+
"dom_time": [],
284+
"dom_x": [],
285+
"dom_y": [],
286+
"dom_z": [],
287+
"neutrino_fraction": [],
288+
"neutrino_npe_fraction": [],
289+
"npe": [],
290+
"pulse_count": [],
291+
"noise_hit": [],
292+
"trackness": [],
293+
"overlap_count": [],
294+
"min_time_delta": [],
295+
}
296+
297+
# Get OM data
298+
if self._pulsemap in frame:
299+
om_keys, data = get_om_keys_and_pulseseries(
300+
frame,
301+
self._pulsemap,
302+
self._calibration,
303+
)
304+
else:
305+
self.warning_once(f"Pulsemap {self._pulsemap} not found in frame.")
306+
return output
307+
308+
for om_key in om_keys:
309+
# Loop over pulses for each OM
310+
pulses = data[om_key]
311+
pulse_times, pulse_charges = self._get_pulse_info(pulses)
312+
npe_list, times, nu_bool, track_like_list = self._get_mcpe_info(
313+
frame, om_key
314+
)
315+
time_distance_matrix = pulse_times[:, None] - times[None, :]
316+
weight_matrix = self._get_gaussian_weight(time_distance_matrix) * (
317+
np.abs(time_distance_matrix) <= self._time_window
318+
)
319+
320+
with np.errstate(invalid="ignore"):
321+
weight_matrix /= np.sum(weight_matrix, axis=0, keepdims=True)
322+
weight_matrix = np.nan_to_num(weight_matrix, nan=0.0)
323+
pulse_counts = np.sum(weight_matrix, axis=1)
324+
neutrino_fractions = (weight_matrix @ nu_bool) / pulse_counts
325+
neutrino_npe_fractions = (
326+
weight_matrix @ (npe_list * nu_bool)
327+
) / (weight_matrix @ npe_list)
328+
total_npe = weight_matrix @ npe_list
329+
trackness = weight_matrix @ track_like_list / pulse_counts
330+
min_time_deltas = (
331+
np.min(np.abs(time_distance_matrix), axis=1)
332+
if len(times) > 0
333+
else np.array([np.nan] * len(pulse_times))
334+
)
335+
336+
# Determine how many other mcpe's overlap with any other pulse for a given pulse
337+
weight_matrix_binary = (weight_matrix > 0).astype(float)
338+
overlap_counts = weight_matrix_binary @ weight_matrix_binary.T
339+
np.fill_diagonal(overlap_counts, 0)
340+
overlap_counts = np.sum(overlap_counts, axis=1)
341+
342+
output["neutrino_fraction"].extend(neutrino_fractions.tolist())
343+
output["neutrino_npe_fraction"].extend(
344+
neutrino_npe_fractions.tolist()
345+
)
346+
output["npe"].extend(total_npe.tolist())
347+
output["pulse_count"].extend(pulse_counts.tolist())
348+
output["charge"].extend(pulse_charges.tolist())
349+
output["dom_time"].extend(pulse_times.tolist())
350+
output["dom_x"].extend(
351+
[self._gcd_dict[om_key].position.x] * len(pulse_times)
352+
)
353+
output["dom_y"].extend(
354+
[self._gcd_dict[om_key].position.y] * len(pulse_times)
355+
)
356+
output["dom_z"].extend(
357+
[self._gcd_dict[om_key].position.z] * len(pulse_times)
358+
)
359+
output["noise_hit"].extend((pulse_counts == 0).tolist())
360+
output["min_time_delta"].extend(min_time_deltas.tolist())
361+
output["trackness"].extend(trackness.tolist())
362+
output["overlap_count"].extend(overlap_counts.tolist())
363+
364+
return output
365+
366+
def _get_mcpe_info(
367+
self, frame: "icetray.I3Frame", om_key: "icetray.OMKey"
368+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
369+
"""Determine the neutrino fraction of a pulse.
370+
371+
Args:
372+
pulse: The pulse for which to determine the neutrino fraction.
373+
mcpe_map: Name of the MCPE series map in the I3 frame.
374+
375+
Returns:
376+
Neutrino fraction of the pulse.
377+
"""
378+
times: list[float] = []
379+
nu_bool: list[bool] = []
380+
npe_list: list[float] = []
381+
track_like_list: list[float] = []
382+
try:
383+
mcpe_info = frame[self._mcpe_map][om_key]
384+
except KeyError:
385+
return (
386+
np.array(npe_list),
387+
np.array(times),
388+
np.array(nu_bool),
389+
np.array(track_like_list),
390+
)
391+
for i, mcpe in enumerate(mcpe_info):
392+
try:
393+
nu_primary = (
394+
frame[self._mctree].get_primary(mcpe.ID).is_neutrino
395+
)
396+
track_like = frame[self._mctree].get_particle(mcpe.ID).is_track
397+
nu_bool.append(nu_primary)
398+
except RuntimeError as e:
399+
# backup to using the mcpe id map to figure out the parent type, if any part of the mcpe has a neutrino as the primary, we count it as a neutrino mcpe this is a choice, but the information about which part of the mcpe corresponds to which primary is lost.
400+
if "particle not found" in str(e):
401+
ids = [
402+
id_p
403+
for id_p, indexval in frame[self._mcpe_map_id][
404+
om_key
405+
].items()
406+
if i in indexval
407+
]
408+
bool_val = any(
409+
[
410+
frame[self._mctree].get_primary(id_p).is_neutrino
411+
for id_p in ids
412+
]
413+
)
414+
track_like = [
415+
frame[self._mctree].get_particle(id_p).is_track
416+
for id_p in ids
417+
]
418+
track_like = sum(track_like) / len(track_like)
419+
nu_bool.append(bool_val)
420+
else:
421+
raise e
422+
423+
times.append(mcpe.time)
424+
npe_list.append(mcpe.npe)
425+
track_like_list.append(track_like)
426+
427+
return (
428+
np.array(npe_list),
429+
np.array(times),
430+
np.array(nu_bool),
431+
np.array(track_like_list),
432+
)
433+
434+
def _get_pulse_info(
435+
self, pulses: List["icetray.I3RecoPulse"]
436+
) -> tuple[np.ndarray, np.ndarray]:
437+
"""Create an nd array of pulse times, charge."""
438+
times, charges = np.array([[p.time, p.charge] for p in pulses]).T
439+
return times, charges
440+
441+
def _get_gaussian_weight(
442+
self, time_distance_matrix: np.ndarray
443+
) -> np.ndarray:
444+
"""Create gaussian weight matrix based on time distance matrix."""
445+
return np.exp(
446+
-0.5 * (time_distance_matrix / (self._time_window / 2)) ** 2
447+
)
448+
449+
209450
class I3FeatureExtractorIceCubeDeepCore(I3FeatureExtractorIceCube86):
210451
"""Class for extracting reconstructed features for IceCube-DeepCore."""
211452

0 commit comments

Comments
 (0)