Skip to content

Commit 9b1d49c

Browse files
authored
Merge pull request graphnet-team#797 from sevmag/summary_feature_node
add summary feature nodedefinition
2 parents 65c4568 + 27800c4 commit 9b1d49c

5 files changed

Lines changed: 461 additions & 16 deletions

File tree

src/graphnet/models/data_representation/graphs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
NodeAsDOMTimeSeries,
1919
PercentileClusters,
2020
IceMixNodes,
21+
ClusterSummaryFeatures,
2122
)
22-
2323
from .edges import (
2424
EdgeDefinition,
2525
KNNEdges,

src/graphnet/models/data_representation/graphs/nodes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
PercentileClusters,
1212
NodeAsDOMTimeSeries,
1313
IceMixNodes,
14+
ClusterSummaryFeatures,
1415
)

src/graphnet/models/data_representation/graphs/nodes/nodes.py

Lines changed: 276 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Class(es) for building/connecting graphs."""
22

3-
from typing import List, Tuple, Optional, Dict
3+
from typing import List, Tuple, Optional, Dict, Union
44
from abc import abstractmethod
55

66
import torch
@@ -469,3 +469,278 @@ def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
469469
graph[:event_length, idx] = x[ids, self.feature_indexes[feature]]
470470

471471
return graph
472+
473+
474+
class ClusterSummaryFeatures(NodeDefinition):
475+
"""Represent pulse maps as clusters with summary features.
476+
477+
If `cluster_on` is set to the xyz coordinates of optical modules
478+
e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be
479+
a unique optical module and the pulse information (e.g. charge, time)
480+
is summarized.
481+
NOTE: Developed to be used with features
482+
[dom_x, dom_y, dom_z, charge, time]
483+
484+
Possible features per cluster:
485+
- total charge
486+
feature name: `total_charge`
487+
- charge accumulated after <X> time units
488+
feature name: `charge_after_<X>ns`
489+
- time of first hit in the optical module
490+
feature name: `time_of_first_hit`
491+
- time spread per optical module
492+
feature name: `time_spread`
493+
- time std per optical module
494+
feature name: `time_std`
495+
- time took to collect <X> percent of total charge per cluster
496+
feature name: `time_after_charge_pct<X>`
497+
- number of pulses per clusters
498+
feature name: `counts`
499+
500+
For more details on some of the features see
501+
Theo Glauchs thesis (chapter 5.3):
502+
https://mediatum.ub.tum.de/node?id=1584755
503+
"""
504+
505+
def __init__(
506+
self,
507+
cluster_on: List[str],
508+
input_feature_names: List[str],
509+
charge_label: str = "charge",
510+
time_label: str = "dom_time",
511+
total_charge: bool = True,
512+
charge_after_t: List[int] = [10, 50, 100],
513+
time_of_first_hit: bool = True,
514+
time_spread: bool = True,
515+
time_std: bool = True,
516+
time_after_charge_pct: List[int] = [1, 3, 5, 11, 15, 20, 50, 80],
517+
charge_standardization: Union[float, str] = "log",
518+
time_standardization: float = 1e-3,
519+
order_in_time: bool = True,
520+
add_counts: bool = False,
521+
) -> None:
522+
"""Construct `ClusterSummaryFeatures`.
523+
524+
Args:
525+
cluster_on: Names of features to create clusters from.
526+
input_feature_names: Column names for input features.
527+
charge_label: Name of the charge column.
528+
time_label: Name of the time column.
529+
total_charge: If True, calculates total charge as feature.
530+
charge_after_t: List of times at which the accumulated charge
531+
is calculated as a feature.
532+
time_of_first_hit: If True, time of first hit is added
533+
as a feature.
534+
time_spread: If True, time spread is added as a feature.
535+
time_std: If True, time std is added as a feature.
536+
time_after_charge_pct: List of percentiles to calculate time after
537+
charge.
538+
charge_standardization: Either a float or 'log'. If a float,
539+
the features are multiplied by this factor. If 'log', the
540+
features are transformed to log10 scale.
541+
time_standardization: Standardization factor for features
542+
with a time
543+
order_in_time: If True, clusters are ordered in time.
544+
If your data is already ordered in time, you can set this
545+
to False to avoid a potential overhead.
546+
NOTE: Should only be set to False if you are sure that
547+
the input data is already ordered in time. Will lead to
548+
incorrect results otherwise.
549+
add_counts: If True, number of log10(event counts per clusters)
550+
is added as a feature.
551+
552+
NOTE: Make sure that either the input data is not already standardized
553+
or that the `charge_standardization` and `time_standardization`
554+
parameters are set to 1 to avoid a double standardization.
555+
"""
556+
# Set member variables
557+
self._cluster_on = cluster_on
558+
self._charge_label = charge_label
559+
self._time_label = time_label
560+
self._order_in_time = order_in_time
561+
562+
# Check if charge_standardization is a float or 'log'
563+
self._charge_standardization = charge_standardization
564+
self._time_standardization = time_standardization
565+
self._verify_standardization()
566+
567+
# feature member variables
568+
self._total_charge = total_charge
569+
self._charge_after_t = charge_after_t
570+
self._time_of_first_hit = time_of_first_hit
571+
self._time_spread = time_spread
572+
self._time_std = time_std
573+
self._time_after_charge_pct = time_after_charge_pct
574+
self._add_counts = add_counts
575+
576+
# Base class constructor
577+
super().__init__(input_feature_names=input_feature_names)
578+
if self._order_in_time is False:
579+
self.info(
580+
"Setting `order_by_time` to False. "
581+
"Make sure that the input data is already ordered in time."
582+
)
583+
584+
def _define_output_feature_names(
585+
self,
586+
input_feature_names: List[str],
587+
) -> List[str]:
588+
"""Set the output feature names."""
589+
self.set_indices(input_feature_names)
590+
new_feature_names = deepcopy(self._cluster_on)
591+
if self._total_charge:
592+
new_feature_names.append("total_charge")
593+
for t in self._charge_after_t:
594+
new_feature_names.append(f"charge_after_{t}ns")
595+
if self._time_of_first_hit:
596+
new_feature_names.append("time_of_first_hit")
597+
if self._time_spread:
598+
new_feature_names.append("time_spread")
599+
if self._time_std:
600+
new_feature_names.append("time_std")
601+
for pct in self._time_after_charge_pct:
602+
new_feature_names.append(f"time_after_charge_pct{pct}")
603+
if self._add_counts:
604+
new_feature_names.append("counts")
605+
return new_feature_names
606+
607+
def _construct_nodes(self, x: torch.Tensor) -> Data:
608+
"""Construct nodes from raw node features ´x´."""
609+
# Cast to Numpy
610+
x = x.numpy()
611+
# Construct clusters with percentile-summarized features
612+
cluster_class = cluster_and_pad(
613+
x=x,
614+
cluster_columns=self._cluster_idx,
615+
sort_by=[self._time_idx] if self._order_in_time else [],
616+
)
617+
# calculate charge weighted median time as reference
618+
ref_time = cluster_class.reference_time(
619+
charge_index=self._charge_idx,
620+
time_index=self._time_idx,
621+
)
622+
623+
# add total charge
624+
if self._total_charge:
625+
cluster_class.add_sum_charge(charge_index=self._charge_idx)
626+
cluster_class.clustered_x[:, -1] = self._standardize_features(
627+
cluster_class.clustered_x[:, -1],
628+
self._charge_standardization,
629+
)
630+
631+
# add charge after t
632+
if len(self._charge_after_t) > 0:
633+
cluster_class.add_accumulated_value_after_t(
634+
time_index=self._time_idx,
635+
summarization_indices=[self._charge_idx],
636+
times=self._charge_after_t,
637+
)
638+
cluster_class.clustered_x[:, -len(self._charge_after_t) :] = (
639+
self._standardize_features(
640+
cluster_class.clustered_x[:, -len(self._charge_after_t) :],
641+
self._charge_standardization,
642+
)
643+
)
644+
645+
# add time of first hit
646+
if self._time_of_first_hit:
647+
cluster_class.add_time_first_pulse(
648+
time_index=self._time_idx,
649+
)
650+
cluster_class.clustered_x[:, -1] -= ref_time
651+
652+
cluster_class.clustered_x[:, -1] = self._standardize_features(
653+
cluster_class.clustered_x[:, -1],
654+
self._time_standardization,
655+
)
656+
657+
# add time spread
658+
if self._time_spread:
659+
cluster_class.add_spread(
660+
columns=[self._time_idx],
661+
)
662+
cluster_class.clustered_x[:, -1] = self._standardize_features(
663+
cluster_class.clustered_x[:, -1],
664+
self._time_standardization,
665+
)
666+
667+
# add time std
668+
if self._time_std:
669+
cluster_class.add_std(
670+
columns=[self._time_idx],
671+
)
672+
cluster_class.clustered_x[:, -1] = self._standardize_features(
673+
cluster_class.clustered_x[:, -1],
674+
self._time_standardization,
675+
)
676+
677+
# add time after charge percentiles
678+
if len(self._time_after_charge_pct) > 0:
679+
cluster_class.add_charge_threshold_summary(
680+
summarization_indices=[self._time_idx],
681+
percentiles=self._time_after_charge_pct,
682+
charge_index=self._charge_idx,
683+
)
684+
cluster_class.clustered_x[
685+
:, -len(self._time_after_charge_pct) :
686+
] -= ref_time
687+
cluster_class.clustered_x[
688+
:, -len(self._time_after_charge_pct) :
689+
] = self._standardize_features(
690+
cluster_class.clustered_x[
691+
:, -len(self._time_after_charge_pct) :
692+
],
693+
self._time_standardization,
694+
)
695+
696+
if self._add_counts:
697+
cluster_class.add_counts()
698+
return torch.tensor(cluster_class.clustered_x)
699+
700+
def set_indices(self, feature_names: List[str]) -> None:
701+
"""Set the indices for the input features."""
702+
self._cluster_idx = [
703+
feature_names.index(column) for column in self._cluster_on
704+
]
705+
self._charge_idx = feature_names.index(self._charge_label)
706+
self._time_idx = feature_names.index(self._time_label)
707+
708+
def _standardize_features(
709+
self,
710+
x: np.ndarray,
711+
standardization: Union[float, str],
712+
) -> np.ndarray:
713+
"""Standardize the features in the input array."""
714+
if isinstance(standardization, float):
715+
return x * standardization
716+
elif standardization == "log":
717+
return np.log10(x)
718+
else:
719+
# should never happen, but just in case
720+
raise ValueError(
721+
f"standardization must be either a float or 'log', "
722+
f"but got {standardization}"
723+
)
724+
725+
def _verify_standardization(
726+
self,
727+
) -> torch.Tensor:
728+
"""Verify settings of standardization of the features."""
729+
if not isinstance(self._charge_standardization, float):
730+
if isinstance(self._charge_standardization, str):
731+
if self._charge_standardization != "log":
732+
raise ValueError(
733+
f"charge_standardization must be either a float or"
734+
f" 'log', but got {self._charge_standardization}"
735+
)
736+
else:
737+
raise ValueError(
738+
f"charge_standardization must be either a float or 'log', "
739+
f"but got {self._charge_standardization}"
740+
)
741+
742+
if not isinstance(self._time_standardization, float):
743+
raise ValueError(
744+
f"time_standardization must be a float, "
745+
f"but got {self._time_standardization}"
746+
)

0 commit comments

Comments
 (0)