Skip to content

Commit 858c180

Browse files
committed
alternative fix
1 parent b9b565c commit 858c180

1 file changed

Lines changed: 33 additions & 3 deletions

File tree

src/graphnet/data/extractors/icecube/i3calorimetry.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99

1010
from graphnet.utilities.imports import has_icecube_package
11+
from copy import deepcopy
1112

1213
if has_icecube_package() or TYPE_CHECKING:
1314
from icecube import (
@@ -32,6 +33,7 @@ def __init__(
3233
mmctracklist: str = "MMCTrackList",
3334
extractor_name: str = "I3Calorimetry",
3435
daughters: bool = False,
36+
highest_energy_primary: bool = False,
3537
**kwargs: Any,
3638
) -> None:
3739
"""Create a ConvexHull object from the GCD file.
@@ -49,12 +51,15 @@ def __init__(
4951
self.mctree = mctree
5052
self.mmctracklist = mmctracklist
5153
self.daughters = daughters
54+
self.highest_energy_primary = highest_energy_primary
5255
# Base class constructor
5356
super().__init__(extractor_name=extractor_name, **kwargs)
5457

5558
def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
5659
"""Extract all the visible particles entering the volume."""
5760
output = {}
61+
# copy the original mctree because we will be modifying it
62+
tree_copy = deepcopy(frame[self.mctree])
5863
if self.frame_contains_info(frame):
5964

6065
e_entrance_track, e_deposited_track = self.total_track_energy(
@@ -67,7 +72,12 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
6772
[
6873
p.energy
6974
for p in self.check_primary_energy(
70-
frame, self.get_primaries(frame, self.daughters)
75+
frame,
76+
self.get_primaries(
77+
frame,
78+
self.daughters,
79+
highest_energy_primary=self.highest_energy_primary,
80+
),
7181
)
7282
]
7383
)
@@ -113,6 +123,9 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
113123
)
114124

115125
output = {k: v for k, v in output.items() if k not in self._exclude}
126+
# restore original mctree
127+
frame.Delete(self.mctree)
128+
frame[self.mctree] = tree_copy
116129
return output
117130

118131
def frame_contains_info(self, frame: "icetray.I3Frame") -> bool:
@@ -138,8 +151,12 @@ def total_track_energy(
138151
]
139152
MMCTrackList = simclasses.I3MMCTrackList(MMCTrackList)
140153

141-
for track in MuonGun.Track.harvest(frame[self.mctree], MMCTrackList):
142-
assert track.is_track, "Track is not a track"
154+
track_list = MuonGun.Track.harvest(frame[self.mctree], MMCTrackList)
155+
track_ids = np.array([track.id for track in track_list])
156+
157+
total_tracks = len(track_list)
158+
while len(track_list) > 0:
159+
track = track_list[0]
143160

144161
# Find distance to entrance and exit from sampling volume
145162
intersections = self.hull.surface.intersection(
@@ -160,6 +177,19 @@ def total_track_energy(
160177
assert e_deposited <= sum(
161178
[p.energy for p in primaries]
162179
), "Energy deposited is greater than primary energy"
180+
# erase particle and children
181+
exclude_ids = set(
182+
[c.id for c in frame[self.mctree].children(track.id)]
183+
)
184+
exclude_ids.add(track.id)
185+
frame[self.mctree].erase_children(track.id)
186+
frame[self.mctree].erase(track.id)
187+
track_mask = [tid not in exclude_ids for tid in track_ids]
188+
track_list = list(np.array(track_list)[track_mask])
189+
track_ids = track_ids[track_mask]
190+
self._logger.debug(
191+
f"Remaining tracks: {len(track_list)}/{total_tracks}"
192+
)
163193
return e_entrance, e_deposited
164194

165195
def total_cascade_energy(

0 commit comments

Comments
 (0)