Skip to content

Commit b9d8b45

Browse files
committed
deepcopy of mctree to avoid changes between extractors
1 parent ad36dc6 commit b9d8b45

3 files changed

Lines changed: 70 additions & 23 deletions

File tree

src/graphnet/data/extractors/combine_extractors.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from typing import List, Dict
88

99
if has_icecube_package() or TYPE_CHECKING:
10-
from icecube import icetray # pyright: reportMissingImports=false
10+
from icecube import (
11+
icetray,
12+
dataclasses,
13+
) # pyright: reportMissingImports=false
14+
15+
from copy import deepcopy
1116

1217

1318
class CombinedExtractor(I3Extractor):
@@ -44,5 +49,10 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
4449
"""
4550
output = {}
4651
for extractor in self._extractors:
52+
if hasattr(extractor, "mctree") and frame.Has(extractor.mctree):
53+
mctree_backup = deepcopy(frame[extractor.mctree])
4754
output.update(extractor(frame))
55+
if hasattr(extractor, "mctree") and frame.Has(extractor.mctree):
56+
frame.Delete(extractor.mctree)
57+
frame[extractor.mctree] = mctree_backup
4858
return output

src/graphnet/data/extractors/icecube/i3truthextractor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
self._borders = borders
8989

9090
self._extend_boundary = extend_boundary
91-
self._mctree = mctree
91+
self.mctree = mctree
9292

9393
def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None:
9494
"""Extract GFrame and CFrame from i3/gcd-file pair.
@@ -130,8 +130,8 @@ def __call__(
130130
self, frame: "icetray.I3Frame", padding_value: Any = -1
131131
) -> Dict[str, Any]:
132132
"""Extract truth-level information."""
133-
is_mc = frame_is_montecarlo(frame, self._mctree)
134-
is_noise = frame_is_noise(frame, self._mctree)
133+
is_mc = frame_is_montecarlo(frame, self.mctree)
134+
is_noise = frame_is_noise(frame, self.mctree)
135135
sim_type = self._find_data_type(is_mc, self._i3_file, frame)
136136

137137
output = {
@@ -289,7 +289,7 @@ def __call__(
289289
def _extract_dbang_decay_length(
290290
self, frame: "icetray.I3Frame", padding_value: float = -1
291291
) -> float:
292-
mctree = frame[self._mctree]
292+
mctree = frame[self.mctree]
293293
try:
294294
p_true = mctree.primaries[0]
295295
p_daughters = mctree.get_daughters(p_true)
@@ -418,12 +418,12 @@ def _get_primary_particle_interaction_type_and_elasticity(
418418
try:
419419
MCInIcePrimary = frame["MCInIcePrimary"]
420420
except KeyError:
421-
MCInIcePrimary = frame[self._mctree][0]
421+
MCInIcePrimary = frame[self.mctree][0]
422422
if (
423423
MCInIcePrimary.energy != MCInIcePrimary.energy
424424
): # This is a nan check. Only happens for some muons
425425
# where second item in MCTree is primary. Weird!
426-
MCInIcePrimary = frame[self._mctree][1]
426+
MCInIcePrimary = frame[self.mctree][1]
427427
# For some strange reason the second entry is identical in
428428
# all variables and has no nans (always muon)
429429
else:
@@ -472,7 +472,7 @@ def _get_primary_track_energy_and_inelasticity(
472472
Tuple containing the energy of tracks from primary, and the
473473
corresponding inelasticity.
474474
"""
475-
mc_tree = frame[self._mctree]
475+
mc_tree = frame[self.mctree]
476476
primary = mc_tree.primaries[0]
477477
daughters = mc_tree.get_daughters(primary)
478478
tracks = []

src/graphnet/data/readers/i3reader.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@
1212
from graphnet.utilities.filesys import find_i3_files
1313
from .graphnet_file_reader import GraphNeTFileReader
1414

15+
from copy import deepcopy
16+
1517

1618
if has_icecube_package():
17-
from icecube import icetray, dataio # pyright: reportMissingImports=false
19+
from icecube import (
20+
icetray,
21+
dataio,
22+
) # pyright: reportMissingImports=false
1823

1924

2025
class I3Reader(GraphNeTFileReader):
@@ -87,22 +92,54 @@ def __call__(
8792
# Open I3 file
8893
i3_file_io = dataio.I3File(file_path.i3_file, "r")
8994
data = list()
90-
while i3_file_io.more():
91-
try:
92-
frame = i3_file_io.pop_physics()
93-
except Exception as e:
94-
if "I3" in str(e):
95+
try:
96+
while i3_file_io.more():
97+
try:
98+
frame = i3_file_io.pop_physics()
99+
except Exception as e:
100+
if "I3" in str(e):
101+
continue
102+
# check if frame should be skipped
103+
if self._skip_frame(frame):
95104
continue
96-
# check if frame should be skipped
97-
if self._skip_frame(frame):
98-
continue
99-
100-
# Try to extract data from I3Frame
101-
results = [extractor(frame) for extractor in self._extractors]
102-
103-
data_dict = OrderedDict(zip(self.extracor_names, results))
104105

105-
data.append(data_dict)
106+
# Try to extract data from I3Frame
107+
results = []
108+
try:
109+
for extractor in self._extractors:
110+
if hasattr(extractor, "mctree") and frame.Has(
111+
extractor.mctree
112+
):
113+
mctree_backup = deepcopy(frame[extractor.mctree])
114+
115+
results.append(extractor(frame))
116+
if hasattr(extractor, "mctree") and frame.Has(
117+
extractor.mctree
118+
):
119+
frame.Delete(extractor.mctree)
120+
frame[extractor.mctree] = mctree_backup
121+
122+
data_dict = OrderedDict(zip(self.extracor_names, results))
123+
124+
except KeyError as e:
125+
if "Deserialization failed for object" in str(e):
126+
self.warning(
127+
f"KeyError {e} in file {file_path.i3_file}"
128+
" - skipping frame."
129+
)
130+
continue
131+
else:
132+
raise e
133+
134+
data.append(data_dict)
135+
except KeyError as e:
136+
if "Deserialization failed for object" in str(e):
137+
self.warning(
138+
f"KeyError {e} in file {file_path.i3_file}"
139+
" - skipping remaining frames."
140+
)
141+
else:
142+
raise e
106143
return data
107144

108145
def find_files(self, path: Union[str, List[str]]) -> List[I3FileSet]:

0 commit comments

Comments
 (0)