Skip to content

Commit f0a39bc

Browse files
committed
adjust docs and clean up
1 parent 89f8c92 commit f0a39bc

1 file changed

Lines changed: 9 additions & 13 deletions

File tree

  • src/graphnet/models/data_representation/graphs/nodes

src/graphnet/models/data_representation/graphs/nodes/nodes.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -539,12 +539,12 @@ def __init__(
539539
with a charge value.
540540
time_standardization: Standardization factor for features
541541
with a time
542-
order_in_time: If True, clusters are sorted by time.
543-
If your data is already sorted by time, you can set this
542+
order_in_time: If True, clusters are ordered in time.
543+
If your data is already ordered in time, you can set this
544544
to False to avoid a potential overhead.
545545
NOTE: Should only be set to False if you are sure that
546-
the input data is already sorted by time. Will lead to
547-
incorrect results if the data is not sorted by time.
546+
the input data is already ordered in time. Will lead to
547+
incorrect results otherwise.
548548
add_counts: If True, number of log10(event counts per clusters)
549549
is added as a feature.
550550
"""
@@ -573,7 +573,7 @@ def __init__(
573573
if self._order_in_time is False:
574574
self.info(
575575
"Setting `order_by_time` to False. "
576-
"Make sure that the input data is already sorted by time."
576+
"Make sure that the input data is already ordered in time."
577577
)
578578

579579
def _define_output_feature_names(
@@ -705,17 +705,13 @@ def _standardize_features(
705705
x: np.ndarray,
706706
standardization: Union[float, str],
707707
) -> np.ndarray:
708-
"""Standardize the features in the input tensor."""
708+
"""Standardize the features in the input array."""
709709
if isinstance(standardization, float):
710710
return x * standardization
711-
elif isinstance(standardization, str):
712-
if standardization != "log":
713-
raise ValueError(
714-
f"standardization must be either a float or 'log', "
715-
f"but got {standardization}"
716-
)
711+
elif standardization == "log":
717712
return np.log10(x)
718713
else:
714+
# should never happen, but just in case
719715
raise ValueError(
720716
f"standardization must be either a float or 'log', "
721717
f"but got {standardization}"
@@ -724,7 +720,7 @@ def _standardize_features(
724720
def _verify_standardization(
725721
self,
726722
) -> torch.Tensor:
727-
"""Verify the standardization of the features."""
723+
"""Verify settings of standardization of the features."""
728724
if not isinstance(self._charge_standardization, float):
729725
if isinstance(self._charge_standardization, str):
730726
if self._charge_standardization != "log":

0 commit comments

Comments
 (0)