88import numpy as np
99
1010from graphnet .utilities .imports import has_icecube_package
11+ from copy import deepcopy
1112
1213if has_icecube_package () or TYPE_CHECKING :
1314 from icecube import (
@@ -32,6 +33,7 @@ def __init__(
3233 mmctracklist : str = "MMCTrackList" ,
3334 extractor_name : str = "I3Calorimetry" ,
3435 daughters : bool = False ,
36+ highest_energy_primary : bool = False ,
3537 ** kwargs : Any ,
3638 ) -> None :
3739 """Create a ConvexHull object from the GCD file.
@@ -49,12 +51,15 @@ def __init__(
4951 self .mctree = mctree
5052 self .mmctracklist = mmctracklist
5153 self .daughters = daughters
54+ self .highest_energy_primary = highest_energy_primary
5255 # Base class constructor
5356 super ().__init__ (extractor_name = extractor_name , ** kwargs )
5457
5558 def __call__ (self , frame : "icetray.I3Frame" ) -> Dict [str , Any ]:
5659 """Extract all the visible particles entering the volume."""
5760 output = {}
61+ # copy the original mctree because we will be modifying it
62+ tree_copy = deepcopy (frame [self .mctree ])
5863 if self .frame_contains_info (frame ):
5964
6065 e_entrance_track , e_deposited_track = self .total_track_energy (
@@ -67,7 +72,12 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
6772 [
6873 p .energy
6974 for p in self .check_primary_energy (
70- frame , self .get_primaries (frame , self .daughters )
75+ frame ,
76+ self .get_primaries (
77+ frame ,
78+ self .daughters ,
79+ highest_energy_primary = self .highest_energy_primary ,
80+ ),
7181 )
7282 ]
7383 )
@@ -113,6 +123,9 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
113123 )
114124
115125 output = {k : v for k , v in output .items () if k not in self ._exclude }
126+ # restore original mctree
127+ frame .Delete (self .mctree )
128+ frame [self .mctree ] = tree_copy
116129 return output
117130
118131 def frame_contains_info (self , frame : "icetray.I3Frame" ) -> bool :
@@ -138,8 +151,12 @@ def total_track_energy(
138151 ]
139152 MMCTrackList = simclasses .I3MMCTrackList (MMCTrackList )
140153
141- for track in MuonGun .Track .harvest (frame [self .mctree ], MMCTrackList ):
142- assert track .is_track , "Track is not a track"
154+ track_list = MuonGun .Track .harvest (frame [self .mctree ], MMCTrackList )
155+ track_ids = np .array ([track .id for track in track_list ])
156+
157+ total_tracks = len (track_list )
158+ while len (track_list ) > 0 :
159+ track = track_list [0 ]
143160
144161 # Find distance to entrance and exit from sampling volume
145162 intersections = self .hull .surface .intersection (
@@ -160,6 +177,19 @@ def total_track_energy(
160177 assert e_deposited <= sum (
161178 [p .energy for p in primaries ]
162179 ), "Energy deposited is greater than primary energy"
180+ # erase particle and children
181+ exclude_ids = set (
182+ [c .id for c in frame [self .mctree ].children (track .id )]
183+ )
184+ exclude_ids .add (track .id )
185+ frame [self .mctree ].erase_children (track .id )
186+ frame [self .mctree ].erase (track .id )
187+ track_mask = [tid not in exclude_ids for tid in track_ids ]
188+ track_list = list (np .array (track_list )[track_mask ])
189+ track_ids = track_ids [track_mask ]
190+ self ._logger .debug (
191+ f"Remaining tracks: { len (track_list )} /{ total_tracks } "
192+ )
163193 return e_entrance , e_deposited
164194
165195 def total_cascade_energy (
0 commit comments