11"""Class(es) for building/connecting graphs."""
22
3- from typing import List , Tuple , Optional , Dict , Union , Callable
3+ from typing import List , Tuple , Optional , Dict , Union
44from abc import abstractmethod
55
66import torch
@@ -475,9 +475,11 @@ class ClusterSummaryFeatures(NodeDefinition):
475475 """Represent pulse maps as clusters with summary features.
476476
477477 If `cluster_on` is set to the xyz coordinates of optical modules
478- e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be a
479- unique optical module and the pulse information (charge, time) is summarized.
480- NOTE: Intended to be used with features [dom_x, dom_y, dom_z, charge, time]
478+ e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be
479+ a unique optical module and the pulse information (e.g. charge, time)
480+ is summarized.
481+ NOTE: Developed to be used with features
482+ [dom_x, dom_y, dom_z, charge, time]
481483
482484 Possible features per cluster:
483485 - total charge
@@ -495,7 +497,8 @@ class ClusterSummaryFeatures(NodeDefinition):
495497 - number of pulses per clusters
496498 feature name: `counts`
497499
498- Taken from Theo Glauchs PhD thesis:
500+ For more details on some of the features see
501+ Theo Glauch's thesis (chapter 5.3):
499502 https://mediatum.ub.tum.de/node?id=1584755
500503 """
501504
@@ -539,31 +542,22 @@ def __init__(
539542 order_in_time: If True, clusters are sorted by time.
540543 If your data is already sorted by time, you can set this
541544 to False to avoid a potential overhead.
542- NOTE: Should only be set to False if you are sure that
545+ NOTE: Should only be set to False if you are sure that
543546 the input data is already sorted by time. Will lead to
544547 incorrect results if the data is not sorted by time.
545- add_counts: If True, number of log10(event counts per clusters) is added as
546- a feature.
548+ add_counts: If True, number of log10(event counts per clusters)
549+ is added as a feature.
547550 """
548551 # Set member variables
549552 self ._cluster_on = cluster_on
550553 self ._charge_label = charge_label
551554 self ._time_label = time_label
552- if isinstance (charge_standardization , float ):
553- charge_std_fn = lambda x : x * charge_standardization
554- elif isinstance (
555- charge_standardization , str
556- ):
557- if charge_standardization != "log" :
558- raise ValueError (
559- f"charge_standardization must be either a float or 'log', "
560- f"but got { charge_standardization } "
561- )
562- charge_std_fn = lambda x : np .log10 (x )
563- self ._charge_std_fn = charge_std_fn
564- self ._time_standardization = time_standardization
565555 self ._order_in_time = order_in_time
566556
557+ # Check if charge_standardization is a float or 'log'
558+ self ._charge_standardization = charge_standardization
559+ self ._time_standardization = time_standardization
560+ self ._verify_standardization ()
567561
568562 # feature member variables
569563 self ._total_charge = total_charge
@@ -624,8 +618,9 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
624618 # add total charge
625619 if self ._total_charge :
626620 cluster_class .add_sum_charge (charge_index = self ._charge_idx )
627- cluster_class .clustered_x [:, - 1 ] = self ._charge_std_fn (
628- cluster_class .clustered_x [:, - 1 ]
621+ cluster_class .clustered_x [:, - 1 ] = self ._standardize_features (
622+ cluster_class .clustered_x [:, - 1 ],
623+ self ._charge_standardization ,
629624 )
630625
631626 # add charge after t
@@ -635,12 +630,11 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
635630 summarization_indices = [self ._charge_idx ],
636631 times = self ._charge_after_t ,
637632 )
638- cluster_class .clustered_x [
639- :, - len (self ._charge_after_t ) :
640- ] = self ._charge_std_fn (
641- cluster_class .clustered_x [
642- :, - len (self ._charge_after_t ) :
643- ]
633+ cluster_class .clustered_x [:, - len (self ._charge_after_t ) :] = (
634+ self ._standardize_features (
635+ cluster_class .clustered_x [:, - len (self ._charge_after_t ) :],
636+ self ._charge_standardization ,
637+ )
644638 )
645639
646640 # add time of first hit
@@ -649,21 +643,31 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
649643 time_index = self ._time_idx ,
650644 )
651645 cluster_class .clustered_x [:, - 1 ] -= ref_time
652- cluster_class .clustered_x [:, - 1 ] *= self ._time_standardization
646+
647+ cluster_class .clustered_x [:, - 1 ] = self ._standardize_features (
648+ cluster_class .clustered_x [:, - 1 ],
649+ self ._time_standardization ,
650+ )
653651
654652 # add time spread
655653 if self ._time_spread :
656654 cluster_class .add_spread (
657655 columns = [self ._time_idx ],
658656 )
659- cluster_class .clustered_x [:, - 1 ] *= self ._time_standardization
657+ cluster_class .clustered_x [:, - 1 ] = self ._standardize_features (
658+ cluster_class .clustered_x [:, - 1 ],
659+ self ._time_standardization ,
660+ )
660661
661662 # add time std
662663 if self ._time_std :
663664 cluster_class .add_std (
664665 columns = [self ._time_idx ],
665666 )
666- cluster_class .clustered_x [:, - 1 ] *= self ._time_standardization
667+ cluster_class .clustered_x [:, - 1 ] = self ._standardize_features (
668+ cluster_class .clustered_x [:, - 1 ],
669+ self ._time_standardization ,
670+ )
667671
668672 # add time after charge percentiles
669673 if len (self ._time_after_charge_pct ) > 0 :
@@ -677,7 +681,13 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
677681 ] -= ref_time
678682 cluster_class .clustered_x [
679683 :, - len (self ._time_after_charge_pct ) :
680- ] *= self ._time_standardization
684+ ] = self ._standardize_features (
685+ cluster_class .clustered_x [
686+ :, - len (self ._time_after_charge_pct ) :
687+ ],
688+ self ._time_standardization ,
689+ )
690+
681691 if self ._add_counts :
682692 cluster_class .add_counts ()
683693 return torch .tensor (cluster_class .clustered_x )
@@ -689,3 +699,47 @@ def set_indices(self, feature_names: List[str]) -> None:
689699 ]
690700 self ._charge_idx = feature_names .index (self ._charge_label )
691701 self ._time_idx = feature_names .index (self ._time_label )
702+
703+ def _standardize_features (
704+ self ,
705+ x : np .ndarray ,
706+ standardization : Union [float , str ],
707+ ) -> np .ndarray :
708+ """Standardize the features in the input tensor."""
709+ if isinstance (standardization , float ):
710+ return x * standardization
711+ elif isinstance (standardization , str ):
712+ if standardization != "log" :
713+ raise ValueError (
714+ f"standardization must be either a float or 'log', "
715+ f"but got { standardization } "
716+ )
717+ return np .log10 (x )
718+ else :
719+ raise ValueError (
720+ f"standardization must be either a float or 'log', "
721+ f"but got { standardization } "
722+ )
723+
724+ def _verify_standardization (
725+ self ,
726+ ) -> torch .Tensor :
727+ """Verify the standardization of the features."""
728+ if not isinstance (self ._charge_standardization , float ):
729+ if isinstance (self ._charge_standardization , str ):
730+ if self ._charge_standardization != "log" :
731+ raise ValueError (
732+ f"charge_standardization must be either a float or"
733+ f" 'log', but got { self ._charge_standardization } "
734+ )
735+ else :
736+ raise ValueError (
737+ f"charge_standardization must be either a float or 'log', "
738+ f"but got { self ._charge_standardization } "
739+ )
740+
741+ if not isinstance (self ._time_standardization , float ):
742+ raise ValueError (
743+ f"time_standardization must be a float, "
744+ f"but got { self ._time_standardization } "
745+ )
0 commit comments