Skip to content

Commit bea975c

Browse files
committed
adding standardization and PR comments
1 parent b5ecc2b commit bea975c

4 files changed

Lines changed: 61 additions & 27 deletions

File tree

src/graphnet/models/data_representation/graphs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
NodeAsDOMTimeSeries,
1919
PercentileClusters,
2020
IceMixNodes,
21-
DOMSummaryFeatures,
21+
ClusterSummaryFeatures,
2222
)
2323

2424
from .edges import (

src/graphnet/models/data_representation/graphs/nodes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
PercentileClusters,
1212
NodeAsDOMTimeSeries,
1313
IceMixNodes,
14-
DOMSummaryFeatures,
14+
ClusterSummaryFeatures,
1515
)

src/graphnet/models/data_representation/graphs/nodes/nodes.py

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Class(es) for building/connecting graphs."""
22

3-
from typing import List, Tuple, Optional, Dict
3+
from typing import List, Tuple, Optional, Dict, Union, Callable
44
from abc import abstractmethod
55

66
import torch
@@ -471,21 +471,29 @@ def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
471471
return graph
472472

473473

474-
class DOMSummaryFeatures(NodeDefinition):
475-
"""Represent DOMs as clusters with summary features.
474+
class ClusterSummaryFeatures(NodeDefinition):
475+
"""Represent pulse maps as clusters with summary features.
476476
477-
If `cluster_on` is set to the xyz coordinates of DOMs
477+
If `cluster_on` is set to the xyz coordinates of optical modules
478478
e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be a
479-
unique DOM and the pulse information (charge, time) is summarized by
479+
unique optical module and the pulse information (charge, time) is summarized.
480480
NOTE: Intended to be used with features [dom_x, dom_y, dom_z, charge, time]
481481
482+
Possible features per cluster:
482483
- total charge
483-
- charge accumulated after times `charge_after_t`
484-
- time of first hit per DOM
485-
- time spread per DOM
486-
- time std per DOM
487-
- time took to collect the charge percentiles `time_after_charge_pct`
488-
- number of pulses per DOM
484+
feature name: `total_charge`
485+
- charge accumulated after <X> time units
486+
feature name: `charge_after_<X>ns`
487+
- time of first hit in the optical module
488+
feature name: `time_of_first_hit`
489+
- time spread per optical module
490+
feature name: `time_spread`
491+
- time std per optical module
492+
feature name: `time_std`
493+
- time took to collect <X> percent of total charge per cluster
494+
feature name: `time_after_charge_pct<X>`
495+
- number of pulses per clusters
496+
feature name: `counts`
489497
490498
Taken from Theo Glauchs PhD thesis:
491499
https://mediatum.ub.tum.de/node?id=1584755
@@ -503,11 +511,12 @@ def __init__(
503511
time_spread: bool = True,
504512
time_std: bool = True,
505513
time_after_charge_pct: List[int] = [1, 3, 5, 11, 15, 20, 50, 80],
506-
charge_standardization: float = 1e-2,
514+
charge_standardization: Union[float, str] = 1e-2,
507515
time_standardization: float = 1e-3,
516+
order_in_time: bool = True,
508517
add_counts: bool = False,
509518
) -> None:
510-
"""Construct `PercentileClusters`.
519+
"""Construct `ClusterSummaryFeatures`.
511520
512521
Args:
513522
cluster_on: Names of features to create clusters from.
@@ -527,15 +536,34 @@ def __init__(
527536
with a charge value.
528537
time_standardization: Standardization factor for features
529538
with a time
530-
add_counts: If True, number of log10(counts per DOM) is added as
539+
order_in_time: If True, clusters are sorted by time.
540+
If your data is already sorted by time, you can set this
541+
to False to avoid a potential overhead.
542+
NOTE: Should only be set to False if you are sure that
543+
the input data is already sorted by time. Will lead to
544+
incorrect results if the data is not sorted by time.
545+
add_counts: If True, number of log10(event counts per clusters) is added as
531546
a feature.
532547
"""
533548
# Set member variables
534549
self._cluster_on = cluster_on
535550
self._charge_label = charge_label
536551
self._time_label = time_label
537-
self._charge_standardization = charge_standardization
552+
if isinstance(charge_standardization, float):
553+
charge_std_fn = lambda x: x * charge_standardization
554+
elif isinstance(
555+
charge_standardization, str
556+
):
557+
if charge_standardization != "log":
558+
raise ValueError(
559+
f"charge_standardization must be either a float or 'log', "
560+
f"but got {charge_standardization}"
561+
)
562+
charge_std_fn = lambda x: np.log10(x)
563+
self._charge_std_fn = charge_std_fn
538564
self._time_standardization = time_standardization
565+
self._order_in_time = order_in_time
566+
539567

540568
# feature member variables
541569
self._total_charge = total_charge
@@ -548,13 +576,18 @@ def __init__(
548576

549577
# Base class constructor
550578
super().__init__(input_feature_names=input_feature_names)
579+
if self._order_in_time is False:
580+
self.info(
581+
"Setting `order_by_time` to False. "
582+
"Make sure that the input data is already sorted by time."
583+
)
551584

552585
def _define_output_feature_names(
553586
self,
554587
input_feature_names: List[str],
555588
) -> List[str]:
556589
"""Set the output feature names."""
557-
self.set_indeces(input_feature_names)
590+
self.set_indices(input_feature_names)
558591
new_feature_names = deepcopy(self._cluster_on)
559592
if self._total_charge:
560593
new_feature_names.append("total_charge")
@@ -580,7 +613,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
580613
cluster_class = cluster_and_pad(
581614
x=x,
582615
cluster_columns=self._cluster_idx,
583-
sort_by=[self._time_idx],
616+
sort_by=[self._time_idx] if self._order_in_time else [],
584617
)
585618
# calculate charge weighted median time as reference
586619
ref_time = cluster_class.reference_time(
@@ -591,7 +624,9 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
591624
# add total charge
592625
if self._total_charge:
593626
cluster_class.add_sum_charge(charge_index=self._charge_idx)
594-
cluster_class.clustered_x[:, -1] *= self._charge_standardization
627+
cluster_class.clustered_x[:, -1] = self._charge_std_fn(
628+
cluster_class.clustered_x[:, -1]
629+
)
595630

596631
# add charge after t
597632
if len(self._charge_after_t) > 0:
@@ -602,7 +637,11 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
602637
)
603638
cluster_class.clustered_x[
604639
:, -len(self._charge_after_t) :
605-
] *= self._charge_standardization
640+
] = self._charge_std_fn(
641+
cluster_class.clustered_x[
642+
:, -len(self._charge_after_t) :
643+
]
644+
)
606645

607646
# add time of first hit
608647
if self._time_of_first_hit:
@@ -643,7 +682,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
643682
cluster_class.add_counts()
644683
return torch.tensor(cluster_class.clustered_x)
645684

646-
def set_indeces(self, feature_names: List[str]) -> None:
685+
def set_indices(self, feature_names: List[str]) -> None:
647686
"""Set the indices for the input features."""
648687
self._cluster_idx = [
649688
feature_names.index(column) for column in self._cluster_on

src/graphnet/models/data_representation/graphs/utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -582,11 +582,6 @@ def add_accumulated_value_after_t(
582582
_cluster_names: The names are added at the end of the tensor
583583
or inserted at the specified location
584584
"""
585-
assert (
586-
self._sort_by[-1] == time_index
587-
), """Data is not sorted by time index.
588-
Make sure that the last element of
589-
sort_by is set to the time index"""
590585
# Summarize the values at different times
591586
time_first_pulse = np.nanmin(
592587
self._padded_x[:, :, time_index],

0 commit comments

Comments
 (0)