Skip to content

Commit 89f8c92

Browse files
committed
more adjustments for imports etc
1 parent bea975c commit 89f8c92

3 files changed

Lines changed: 88 additions & 36 deletions

File tree

src/graphnet/models/data_representation/graphs/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
IceMixNodes,
2121
ClusterSummaryFeatures,
2222
)
23-
2423
from .edges import (
2524
EdgeDefinition,
2625
KNNEdges,

src/graphnet/models/data_representation/graphs/nodes/nodes.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Class(es) for building/connecting graphs."""
22

3-
from typing import List, Tuple, Optional, Dict, Union, Callable
3+
from typing import List, Tuple, Optional, Dict, Union
44
from abc import abstractmethod
55

66
import torch
@@ -475,9 +475,11 @@ class ClusterSummaryFeatures(NodeDefinition):
475475
"""Represent pulse maps as clusters with summary features.
476476
477477
If `cluster_on` is set to the xyz coordinates of optical modules
478-
e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be a
479-
unique optical module and the pulse information (charge, time) is summarized.
480-
NOTE: Intended to be used with features [dom_x, dom_y, dom_z, charge, time]
478+
e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be
479+
a unique optical module and the pulse information (e.g. charge, time)
480+
is summarized.
481+
NOTE: Developed to be used with features
482+
[dom_x, dom_y, dom_z, charge, time]
481483
482484
Possible features per cluster:
483485
- total charge
@@ -495,7 +497,8 @@ class ClusterSummaryFeatures(NodeDefinition):
495497
- number of pulses per cluster
496498
feature name: `counts`
497499
498-
Taken from Theo Glauchs PhD thesis:
500+
For more details on some of the features see
501+
Theo Glauch's thesis (chapter 5.3):
499502
https://mediatum.ub.tum.de/node?id=1584755
500503
"""
501504

@@ -539,31 +542,22 @@ def __init__(
539542
order_in_time: If True, clusters are sorted by time.
540543
If your data is already sorted by time, you can set this
541544
to False to avoid a potential overhead.
542-
NOTE: Should only be set to False if you are sure that
545+
NOTE: Should only be set to False if you are sure that
543546
the input data is already sorted by time. Will lead to
544547
incorrect results if the data is not sorted by time.
545-
add_counts: If True, number of log10(event counts per clusters) is added as
546-
a feature.
548+
add_counts: If True, the log10 of the number of counts per cluster
549+
is added as a feature.
547550
"""
548551
# Set member variables
549552
self._cluster_on = cluster_on
550553
self._charge_label = charge_label
551554
self._time_label = time_label
552-
if isinstance(charge_standardization, float):
553-
charge_std_fn = lambda x: x * charge_standardization
554-
elif isinstance(
555-
charge_standardization, str
556-
):
557-
if charge_standardization != "log":
558-
raise ValueError(
559-
f"charge_standardization must be either a float or 'log', "
560-
f"but got {charge_standardization}"
561-
)
562-
charge_std_fn = lambda x: np.log10(x)
563-
self._charge_std_fn = charge_std_fn
564-
self._time_standardization = time_standardization
565555
self._order_in_time = order_in_time
566556

557+
# Check if charge_standardization is a float or 'log'
558+
self._charge_standardization = charge_standardization
559+
self._time_standardization = time_standardization
560+
self._verify_standardization()
567561

568562
# feature member variables
569563
self._total_charge = total_charge
@@ -624,8 +618,9 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
624618
# add total charge
625619
if self._total_charge:
626620
cluster_class.add_sum_charge(charge_index=self._charge_idx)
627-
cluster_class.clustered_x[:, -1] = self._charge_std_fn(
628-
cluster_class.clustered_x[:, -1]
621+
cluster_class.clustered_x[:, -1] = self._standardize_features(
622+
cluster_class.clustered_x[:, -1],
623+
self._charge_standardization,
629624
)
630625

631626
# add charge after t
@@ -635,12 +630,11 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
635630
summarization_indices=[self._charge_idx],
636631
times=self._charge_after_t,
637632
)
638-
cluster_class.clustered_x[
639-
:, -len(self._charge_after_t) :
640-
] = self._charge_std_fn(
641-
cluster_class.clustered_x[
642-
:, -len(self._charge_after_t) :
643-
]
633+
cluster_class.clustered_x[:, -len(self._charge_after_t) :] = (
634+
self._standardize_features(
635+
cluster_class.clustered_x[:, -len(self._charge_after_t) :],
636+
self._charge_standardization,
637+
)
644638
)
645639

646640
# add time of first hit
@@ -649,21 +643,31 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
649643
time_index=self._time_idx,
650644
)
651645
cluster_class.clustered_x[:, -1] -= ref_time
652-
cluster_class.clustered_x[:, -1] *= self._time_standardization
646+
647+
cluster_class.clustered_x[:, -1] = self._standardize_features(
648+
cluster_class.clustered_x[:, -1],
649+
self._time_standardization,
650+
)
653651

654652
# add time spread
655653
if self._time_spread:
656654
cluster_class.add_spread(
657655
columns=[self._time_idx],
658656
)
659-
cluster_class.clustered_x[:, -1] *= self._time_standardization
657+
cluster_class.clustered_x[:, -1] = self._standardize_features(
658+
cluster_class.clustered_x[:, -1],
659+
self._time_standardization,
660+
)
660661

661662
# add time std
662663
if self._time_std:
663664
cluster_class.add_std(
664665
columns=[self._time_idx],
665666
)
666-
cluster_class.clustered_x[:, -1] *= self._time_standardization
667+
cluster_class.clustered_x[:, -1] = self._standardize_features(
668+
cluster_class.clustered_x[:, -1],
669+
self._time_standardization,
670+
)
667671

668672
# add time after charge percentiles
669673
if len(self._time_after_charge_pct) > 0:
@@ -677,7 +681,13 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
677681
] -= ref_time
678682
cluster_class.clustered_x[
679683
:, -len(self._time_after_charge_pct) :
680-
] *= self._time_standardization
684+
] = self._standardize_features(
685+
cluster_class.clustered_x[
686+
:, -len(self._time_after_charge_pct) :
687+
],
688+
self._time_standardization,
689+
)
690+
681691
if self._add_counts:
682692
cluster_class.add_counts()
683693
return torch.tensor(cluster_class.clustered_x)
@@ -689,3 +699,47 @@ def set_indices(self, feature_names: List[str]) -> None:
689699
]
690700
self._charge_idx = feature_names.index(self._charge_label)
691701
self._time_idx = feature_names.index(self._time_label)
702+
703+
def _standardize_features(
704+
self,
705+
x: np.ndarray,
706+
standardization: Union[float, str],
707+
) -> np.ndarray:
708+
"""Standardize the features in the input tensor."""
709+
if isinstance(standardization, float):
710+
return x * standardization
711+
elif isinstance(standardization, str):
712+
if standardization != "log":
713+
raise ValueError(
714+
f"standardization must be either a float or 'log', "
715+
f"but got {standardization}"
716+
)
717+
return np.log10(x)
718+
else:
719+
raise ValueError(
720+
f"standardization must be either a float or 'log', "
721+
f"but got {standardization}"
722+
)
723+
724+
def _verify_standardization(
725+
self,
726+
) -> torch.Tensor:
727+
"""Verify the standardization of the features."""
728+
if not isinstance(self._charge_standardization, float):
729+
if isinstance(self._charge_standardization, str):
730+
if self._charge_standardization != "log":
731+
raise ValueError(
732+
f"charge_standardization must be either a float or"
733+
f" 'log', but got {self._charge_standardization}"
734+
)
735+
else:
736+
raise ValueError(
737+
f"charge_standardization must be either a float or 'log', "
738+
f"but got {self._charge_standardization}"
739+
)
740+
741+
if not isinstance(self._time_standardization, float):
742+
raise ValueError(
743+
f"time_standardization must be a float, "
744+
f"but got {self._time_standardization}"
745+
)

src/graphnet/models/graphs/nodes/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
PercentileClusters,
1313
NodeAsDOMTimeSeries,
1414
IceMixNodes,
15-
DOMSummaryFeatures,
15+
ClusterSummaryFeatures,
1616
)
1717

18-
1918
Logger(log_folder=None).warning_once(
2019
(
2120
"`graphnet.models.graphs` will be depricated soon. "

0 commit comments

Comments
 (0)