Skip to content

Commit 8096569

Browse files
committed
speed improvements
1 parent 93024c9 commit 8096569

1 file changed

Lines changed: 14 additions & 15 deletions

File tree

src/graphnet/data/extractors/icecube/i3calorimetry.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from graphnet.utilities.imports import has_icecube_package
1111
from copy import deepcopy
12+
from collections import deque
1213

1314
if has_icecube_package() or TYPE_CHECKING:
1415
from icecube import (
@@ -18,6 +19,8 @@
1819
simclasses,
1920
) # pyright: reportMissingImports=false
2021

22+
DARK = dataclasses.I3Particle.ParticleShape.Dark
23+
2124

2225
class I3Calorimetry(I3Extractor):
2326
"""Event level energy labeling for IceCube data.
@@ -224,7 +227,7 @@ def total_cascade_energy(
224227
frame: "icetray.I3Frame",
225228
) -> float:
226229
"""Get the total energy of cascade particles on entrance."""
227-
particles = np.array(
230+
particles = deque(
228231
self.get_primaries(
229232
frame, self.daughters, self.highest_energy_primary
230233
)
@@ -240,31 +243,27 @@ def total_cascade_energy(
240243
[],
241244
[],
242245
)
246+
247+
mctree = frame[self.mctree]
243248
while len(particles) > 0:
244-
p = particles[0]
245-
particles = particles[1:]
246-
p_children = dataclasses.I3MCTree.get_daughters(
247-
frame[self.mctree], p
248-
)
249+
p = particles.popleft()
250+
p_children = mctree.get_daughters(p)
249251
if len(p_children) > 0:
250-
particles = np.concatenate((particles, p_children))
252+
particles.extend(p_children)
251253
continue
252-
elif (
253-
p.is_track
254-
or p.shape == dataclasses.I3Particle.ParticleShape.Dark
255-
):
254+
if p.is_track or p.shape == DARK:
256255
continue
257256

258-
pos_list.append(np.array(p.pos))
259-
direc_list.append(np.array([p.dir.x, p.dir.y, p.dir.z]))
257+
pos_list.append([p.pos.x, p.pos.y, p.pos.z])
258+
direc_list.append([p.dir.x, p.dir.y, p.dir.z])
260259
length_list.append(p.length)
261260
cascade_bool.append(p.is_cascade)
262261
energies.append(p.energy)
263262

264263
length = np.array(length_list).astype(float)
265264
length[np.isnan(length)] = 0
266-
pos = np.array(pos_list)
267-
direc = np.array(direc_list)
265+
pos = np.asarray(pos_list)
266+
direc = np.asarray(direc_list)
268267
cascade_bool = np.array(cascade_bool)
269268
energies = np.array(energies)
270269
pos = (pos.T + direc.T * length).T

0 commit comments

Comments
 (0)