Skip to content

Commit 4898fb4

Browse files
authored
Merge pull request graphnet-team#763 from pweigel/visible_inelasticity
Improvements for inelasticity reconstructions and other utilities
2 parents 65ec11c + be64336 commit 4898fb4

6 files changed

Lines changed: 169 additions & 15 deletions

File tree

src/graphnet/data/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class TRUTH:
5050
"interaction_type",
5151
"interaction_time", # Added for vertex reconstruction
5252
"inelasticity",
53+
"visible_inelasticity",
54+
"visible_energy",
5355
"stopped_muon",
5456
]
5557
DEEPCORE = ICECUBE86

src/graphnet/data/extractors/icecube/i3truthextractor.py

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import matplotlib.path as mpath
5+
from scipy.spatial import ConvexHull, Delaunay
56
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
67

78
from .i3extractor import I3Extractor
@@ -12,10 +13,12 @@
1213
from graphnet.utilities.imports import has_icecube_package
1314

1415
if has_icecube_package() or TYPE_CHECKING:
15-
from icecube import (
16+
from icecube import ( # noqa: F401
1617
dataclasses,
1718
icetray,
1819
phys_services,
20+
dataio,
21+
LeptonInjector,
1922
) # pyright: reportMissingImports=false
2023

2124

@@ -27,6 +30,7 @@ def __init__(
2730
name: str = "truth",
2831
borders: Optional[List[np.ndarray]] = None,
2932
mctree: Optional[str] = "I3MCTree",
33+
extend_boundary: Optional[float] = 0.0,
3034
):
3135
"""Construct I3TruthExtractor.
3236
@@ -37,6 +41,8 @@ def __init__(
3741
stopping within the detector. Defaults to hard-coded boundary
3842
coordinates.
3943
mctree: Str of which MCTree to use for truth values.
44+
extend_boundary: Distance to extend the convex hull of the detector
45+
for defining starting events.
4046
"""
4147
# Base class constructor
4248
super().__init__(name)
@@ -78,15 +84,53 @@ def __init__(
7884
self._borders = [border_xy, border_z]
7985
else:
8086
self._borders = borders
87+
88+
self._extend_boundary = extend_boundary
8189
self._mctree = mctree
8290

91+
def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None:
92+
"""Extract GFrame and CFrame from i3/gcd-file pair.
93+
94+
Information from these frames will be set as member variables of
95+
`I3Extractor.`
96+
97+
Args:
98+
i3_file: Path to i3 file that is being converted.
99+
gcd_file: Path to GCD file. Defaults to None. If no GCD file is
100+
given, the method will attempt to find C and G frames in
101+
the i3 file instead. If either one of those are not
102+
present, `RuntimeErrors` will be raised.
103+
"""
104+
super().set_gcd(i3_file=i3_file, gcd_file=gcd_file)
105+
106+
# Modifications specific to I3TruthExtractor
107+
# These modifications are needed to identify starting events
108+
coordinates = []
109+
for _, g in self._gcd_dict.items():
110+
if g.position.z > 1200:
111+
continue # We want to exclude icetop
112+
coordinates.append([g.position.x, g.position.y, g.position.z])
113+
coordinates = np.array(coordinates)
114+
115+
if self._extend_boundary != 0.0:
116+
center = np.mean(coordinates, axis=0)
117+
d = coordinates - center
118+
norms = np.linalg.norm(d, axis=1, keepdims=True)
119+
dn = d / norms
120+
coordinates = coordinates + dn * self._extend_boundary
121+
122+
hull = ConvexHull(coordinates)
123+
124+
self.hull = hull
125+
self.delaunay = Delaunay(coordinates[self.hull.vertices])
126+
83127
def __call__(
84128
self, frame: "icetray.I3Frame", padding_value: Any = -1
85129
) -> Dict[str, Any]:
86130
"""Extract truth-level information."""
87131
is_mc = frame_is_montecarlo(frame, self._mctree)
88132
is_noise = frame_is_noise(frame, self._mctree)
89-
sim_type = self._find_data_type(is_mc, self._i3_file)
133+
sim_type = self._find_data_type(is_mc, self._i3_file, frame)
90134

91135
output = {
92136
"energy": padding_value,
@@ -119,6 +163,7 @@ def __call__(
119163
"L5_oscNext_bool": padding_value,
120164
"L6_oscNext_bool": padding_value,
121165
"L7_oscNext_bool": padding_value,
166+
"is_starting": padding_value,
122167
}
123168

124169
# Only InIceSplit P frames contain ML appropriate
@@ -230,6 +275,13 @@ def __call__(
230275
}
231276
)
232277

278+
is_starting = self._contained_vertex(output)
279+
output.update(
280+
{
281+
"is_starting": is_starting,
282+
}
283+
)
284+
233285
return output
234286

235287
def _extract_dbang_decay_length(
@@ -374,15 +426,34 @@ def _get_primary_particle_interaction_type_and_elasticity(
374426
# all variables and has no nans (always muon)
375427
else:
376428
MCInIcePrimary = None
377-
try:
378-
interaction_type = frame["I3MCWeightDict"]["InteractionType"]
379-
except KeyError:
380-
interaction_type = padding_value
381429

382-
try:
383-
elasticity = frame["I3GENIEResultDict"]["y"]
384-
except KeyError:
385-
elasticity = padding_value
430+
if sim_type == "LeptonInjector":
431+
event_properties = frame["EventProperties"]
432+
final_state_1 = event_properties.finalType1
433+
if final_state_1 in [
434+
dataclasses.I3Particle.NuE,
435+
dataclasses.I3Particle.NuMu,
436+
dataclasses.I3Particle.NuTau,
437+
dataclasses.I3Particle.NuEBar,
438+
dataclasses.I3Particle.NuMuBar,
439+
dataclasses.I3Particle.NuTauBar,
440+
]:
441+
interaction_type = 2 # NC
442+
else:
443+
interaction_type = 1 # CC
444+
445+
elasticity = 1 - event_properties.finalStateY
446+
447+
else:
448+
try:
449+
interaction_type = frame["I3MCWeightDict"]["InteractionType"]
450+
except KeyError:
451+
interaction_type = int(padding_value)
452+
453+
try:
454+
elasticity = 1 - frame["I3MCWeightDict"]["BjorkenY"]
455+
except KeyError:
456+
elasticity = padding_value
386457

387458
return MCInIcePrimary, interaction_type, elasticity
388459

@@ -418,12 +489,15 @@ def _get_primary_track_energy_and_inelasticity(
418489
return energy_track, energy_cascade, inelasticity
419490

420491
# Utility methods
421-
def _find_data_type(self, mc: bool, input_file: str) -> str:
492+
def _find_data_type(
493+
self, mc: bool, input_file: str, frame: "icetray.I3Frame"
494+
) -> str:
422495
"""Determine the data type.
423496
424497
Args:
425498
mc: Whether `input_file` is Monte Carlo simulation.
426499
input_file: Path to I3-file.
500+
frame: Physics frame containing MC record
427501
428502
Returns:
429503
The simulation/data type.
@@ -439,8 +513,26 @@ def _find_data_type(self, mc: bool, input_file: str) -> str:
439513
sim_type = "genie"
440514
elif "noise" in input_file:
441515
sim_type = "noise"
442-
elif "L2" in input_file: # not robust
443-
sim_type = "dbang"
444-
else:
516+
elif frame.Has("EventProprties") or frame.Has(
517+
"LeptonInjectorProperties"
518+
):
519+
sim_type = "LeptonInjector"
520+
elif frame.Has("I3MCWeightDict"):
445521
sim_type = "NuGen"
522+
else:
523+
raise NotImplementedError("Could not determine data type.")
446524
return sim_type
525+
526+
def _contained_vertex(self, truth: Dict[str, Any]) -> bool:
527+
"""Determine if an event is starting based on vertex position.
528+
529+
Args:
530+
truth: Dictionary of already extracted truth-level information.
531+
532+
Returns:
533+
True/False if vertex is inside detector.
534+
"""
535+
vertex = np.array(
536+
[truth["position_x"], truth["position_y"], truth["position_z"]]
537+
)
538+
return self.delaunay.find_simplex(vertex) >= 0

src/graphnet/data/extractors/icecube/utilities/i3_filters.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,30 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
6464
return True
6565

6666

67+
class SubEventStreamI3Filter(I3Filter):
68+
"""A filter that only keeps frames from select splits."""
69+
70+
def __init__(self, selection: List[str]):
71+
"""Initialize SubEventStreamI3Filter.
72+
73+
Args:
74+
selection: List of subevent streams to keep.
75+
"""
76+
self._selection = selection
77+
78+
def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
79+
"""Check if current frame should be kept.
80+
81+
Args:
82+
frame: I3-frame
83+
The I3-frame to check.
84+
"""
85+
if frame.Has("I3EventHeader"):
86+
if frame["I3EventHeader"].sub_event_stream not in self._selection:
87+
return False
88+
return True
89+
90+
6791
class I3FilterMask(I3Filter):
6892
"""Checks list of filters from the FilterMask in I3 frames."""
6993

src/graphnet/models/graphs/nodes/nodes.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ def __init__(
326326
"z_offset": None,
327327
"z_scaling": None,
328328
},
329+
sample_pulses: bool = True,
329330
) -> None:
330331
"""Construct `IceMixNodes`.
331332
@@ -339,6 +340,9 @@ def __init__(
339340
ice in IceCube are added to the feature set based on z coordinate.
340341
ice_args: Offset and scaling of the z coordinate in the Detector,
341342
to be able to make similar conversion in the ice data.
343+
sample_pulses: Enable sampling random pulses. If True and the
344+
event is longer than the max_length, they will be sampled. If
345+
False, then only the first max_length pulses will be selected.
342346
"""
343347
if input_feature_names is None:
344348
input_feature_names = [
@@ -384,6 +388,7 @@ def __init__(
384388
self.z_name = z_name
385389
self.hlc_name = hlc_name
386390
self.add_ice_properties = add_ice_properties
391+
self.sampling_enabled = sample_pulses
387392

388393
def _define_output_feature_names(
389394
self, input_feature_names: List[str]
@@ -437,7 +442,14 @@ def _construct_nodes(self, x: torch.Tensor) -> Tuple[Data, List[str]]:
437442
x[:, self.feature_indexes[self.hlc_name]] = torch.logical_not(
438443
x[:, self.feature_indexes[self.hlc_name]]
439444
) # hlc in kaggle was flipped
440-
ids = self._pulse_sampler(x, event_length)
445+
if self.sampling_enabled:
446+
ids = self._pulse_sampler(x, event_length)
447+
else:
448+
if event_length < self.max_length:
449+
ids = torch.arange(event_length)
450+
else:
451+
ids = torch.arange(self.max_length)
452+
441453
event_length = min(self.max_length, event_length)
442454

443455
graph = torch.zeros([event_length, self.n_features])

src/graphnet/models/task/reconstruction.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,19 @@ class InelasticityReconstruction(StandardLearnedTask):
232232
def _forward(self, x: Tensor) -> Tensor:
233233
# Transform output to unit range
234234
return torch.sigmoid(x)
235+
236+
237+
class VisibleInelasticityReconstruction(StandardLearnedTask):
238+
"""Reconstructs interaction visible inelasticity.
239+
240+
That is, 1-(visible track energy / visible hadronic energy).
241+
"""
242+
243+
# Requires one features: inelasticity itself
244+
default_target_labels = ["visible_inelasticity"]
245+
default_prediction_labels = ["visible_inelasticity_pred"]
246+
nb_inputs = 1
247+
248+
def _forward(self, x: Tensor) -> Tensor:
249+
# Transform output to unit range
250+
return 0.5 * (torch.tanh(2.0 * x) + 1.0)

src/graphnet/training/loss_functions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
6363
"""Syntax like `.forward`, for implentation in inheriting classes."""
6464

6565

66+
class MAELoss(LossFunction):
67+
"""Mean absolute error loss."""
68+
69+
def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
70+
"""Implement loss calculation."""
71+
return torch.mean(torch.abs(prediction - target), dim=-1)
72+
73+
6674
class MSELoss(LossFunction):
6775
"""Mean squared error loss."""
6876

0 commit comments

Comments
 (0)