11"""Class(es) for building/connecting graphs."""
22
3- from typing import List , Tuple , Optional , Dict
3+ from typing import List , Tuple , Optional , Dict , Union , Callable
44from abc import abstractmethod
55
66import torch
@@ -471,21 +471,29 @@ def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
471471 return graph
472472
473473
474- class DOMSummaryFeatures (NodeDefinition ):
475- """Represent DOMs as clusters with summary features.
474+ class ClusterSummaryFeatures (NodeDefinition ):
475+ """Represent pulse maps as clusters with summary features.
476476
477- If `cluster_on` is set to the xyz coordinates of DOMs
477+ If `cluster_on` is set to the xyz coordinates of optical modules
478478 e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be a
479- unique DOM and the pulse information (charge, time) is summarized by
479+ unique optical module and the pulse information (charge, time) is summarized.
480480 NOTE: Intended to be used with features [dom_x, dom_y, dom_z, charge, time]
481481
482+ Possible features per cluster:
482483 - total charge
483- - charge accumulated after times `charge_after_t`
484- - time of first hit per DOM
485- - time spread per DOM
486- - time std per DOM
487- - time took to collect the charge percentiles `time_after_charge_pct`
488- - number of pulses per DOM
484+ feature name: `total_charge`
485+ - charge accumulated after <X> time units
486+ feature name: `charge_after_<X>ns`
487+ - time of first hit in the optical module
488+ feature name: `time_of_first_hit`
489+ - time spread per optical module
490+ feature name: `time_spread`
491+ - time std per optical module
492+ feature name: `time_std`
493+ - time took to collect <X> percent of total charge per cluster
494+ feature name: `time_after_charge_pct<X>`
495+ - number of pulses per clusters
496+ feature name: `counts`
489497
490498 Taken from Theo Glauchs PhD thesis:
491499 https://mediatum.ub.tum.de/node?id=1584755
@@ -503,11 +511,12 @@ def __init__(
503511 time_spread : bool = True ,
504512 time_std : bool = True ,
505513 time_after_charge_pct : List [int ] = [1 , 3 , 5 , 11 , 15 , 20 , 50 , 80 ],
506- charge_standardization : float = 1e-2 ,
514+ charge_standardization : Union [ float , str ] = 1e-2 ,
507515 time_standardization : float = 1e-3 ,
516+ order_in_time : bool = True ,
508517 add_counts : bool = False ,
509518 ) -> None :
510- """Construct `PercentileClusters `.
519+ """Construct `ClusterSummaryFeatures `.
511520
512521 Args:
513522 cluster_on: Names of features to create clusters from.
@@ -527,15 +536,34 @@ def __init__(
527536 with a charge value.
528537 time_standardization: Standardization factor for features
529538 with a time
530- add_counts: If True, number of log10(counts per DOM) is added as
539+ order_in_time: If True, clusters are sorted by time.
540+ If your data is already sorted by time, you can set this
541+ to False to avoid a potential overhead.
542+ NOTE: Should only be set to False if you are sure that
543+ the input data is already sorted by time. Will lead to
544+ incorrect results if the data is not sorted by time.
545+ add_counts: If True, number of log10(event counts per clusters) is added as
531546 a feature.
532547 """
533548 # Set member variables
534549 self ._cluster_on = cluster_on
535550 self ._charge_label = charge_label
536551 self ._time_label = time_label
537- self ._charge_standardization = charge_standardization
552+ if isinstance (charge_standardization , float ):
553+ charge_std_fn = lambda x : x * charge_standardization
554+ elif isinstance (
555+ charge_standardization , str
556+ ):
557+ if charge_standardization != "log" :
558+ raise ValueError (
559+ f"charge_standardization must be either a float or 'log', "
560+ f"but got { charge_standardization } "
561+ )
562+ charge_std_fn = lambda x : np .log10 (x )
563+ self ._charge_std_fn = charge_std_fn
538564 self ._time_standardization = time_standardization
565+ self ._order_in_time = order_in_time
566+
539567
540568 # feature member variables
541569 self ._total_charge = total_charge
@@ -548,13 +576,18 @@ def __init__(
548576
549577 # Base class constructor
550578 super ().__init__ (input_feature_names = input_feature_names )
579+ if self ._order_in_time is False :
580+ self .info (
581+ "Setting `order_by_time` to False. "
582+ "Make sure that the input data is already sorted by time."
583+ )
551584
552585 def _define_output_feature_names (
553586 self ,
554587 input_feature_names : List [str ],
555588 ) -> List [str ]:
556589 """Set the output feature names."""
557- self .set_indeces (input_feature_names )
590+ self .set_indices (input_feature_names )
558591 new_feature_names = deepcopy (self ._cluster_on )
559592 if self ._total_charge :
560593 new_feature_names .append ("total_charge" )
@@ -580,7 +613,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
580613 cluster_class = cluster_and_pad (
581614 x = x ,
582615 cluster_columns = self ._cluster_idx ,
583- sort_by = [self ._time_idx ],
616+ sort_by = [self ._time_idx ] if self . _order_in_time else [] ,
584617 )
585618 # calculate charge weighted median time as reference
586619 ref_time = cluster_class .reference_time (
@@ -591,7 +624,9 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
591624 # add total charge
592625 if self ._total_charge :
593626 cluster_class .add_sum_charge (charge_index = self ._charge_idx )
594- cluster_class .clustered_x [:, - 1 ] *= self ._charge_standardization
627+ cluster_class .clustered_x [:, - 1 ] = self ._charge_std_fn (
628+ cluster_class .clustered_x [:, - 1 ]
629+ )
595630
596631 # add charge after t
597632 if len (self ._charge_after_t ) > 0 :
@@ -602,7 +637,11 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
602637 )
603638 cluster_class .clustered_x [
604639 :, - len (self ._charge_after_t ) :
605- ] *= self ._charge_standardization
640+ ] = self ._charge_std_fn (
641+ cluster_class .clustered_x [
642+ :, - len (self ._charge_after_t ) :
643+ ]
644+ )
606645
607646 # add time of first hit
608647 if self ._time_of_first_hit :
@@ -643,7 +682,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
643682 cluster_class .add_counts ()
644683 return torch .tensor (cluster_class .clustered_x )
645684
646- def set_indeces (self , feature_names : List [str ]) -> None :
685+ def set_indices (self , feature_names : List [str ]) -> None :
647686 """Set the indices for the input features."""
648687 self ._cluster_idx = [
649688 feature_names .index (column ) for column in self ._cluster_on
0 commit comments