Skip to content

Commit 27800c4

Browse files
committed
time_first_hit as class attribute
1 parent c953953 commit 27800c4

1 file changed

Lines changed: 29 additions & 10 deletions

File tree

  • src/graphnet/models/data_representation/graphs

src/graphnet/models/data_representation/graphs/utils.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,17 @@ def _calculate_reference_time(self, time_index: int) -> np.ndarray:
356356
weights=self._charge_weights.flatten(),
357357
)
358358

359+
def _calculate_time_first_pulse(self, time_index: int) -> np.ndarray:
360+
"""Calculate the time of the first pulse."""
361+
assert not hasattr(
362+
self, "_time_first_pulse"
363+
), "Time of first pulse has already been calculated, \
364+
re-calculation is not allowed"
365+
self._time_first_pulse = np.nanmin(
366+
self._padded_x[:, :, time_index],
367+
axis=1,
368+
)
369+
359370
def add_charge_threshold_summary(
360371
self,
361372
summarization_indices: List[int],
@@ -534,8 +545,12 @@ def add_time_first_pulse(
534545
_cluster_names: The names are added at the end of the tensor
535546
or inserted at the specified location
536547
"""
537-
time_first_pulse = np.nanmin(self._padded_x[:, :, time_index], axis=1)
538-
self._add_column(time_first_pulse, location)
548+
if not hasattr(self, "_time_first_pulse"):
549+
self._calculate_time_first_pulse(time_index)
550+
551+
# Add the time of the first pulse to the clustered tensor
552+
self._add_column(self._time_first_pulse, location)
553+
539554
# update the cluster names
540555
if self._input_names is not None:
541556
new_name = [self._input_names[time_index] + "_first_pulse"]
@@ -582,19 +597,20 @@ def add_accumulated_value_after_t(
582597
_cluster_names: The names are added at the end of the tensor
583598
or inserted at the specified location
584599
"""
585-
# Summarize the values at different times
586-
time_first_pulse = np.nanmin(
587-
self._padded_x[:, :, time_index],
588-
axis=1,
589-
)[:, np.newaxis]
600+
# Calculate the time of the first pulse if not already done
601+
if not hasattr(self, "_time_first_pulse"):
602+
self._calculate_time_first_pulse(time_index)
603+
604+
# Create array with threshold times
590605
tmp_times = (
591606
np.tile(
592607
np.array(times),
593-
(len(time_first_pulse), 1),
608+
(len(self._time_first_pulse[:, np.newaxis]), 1),
594609
)
595-
+ time_first_pulse
610+
+ self._time_first_pulse[:, np.newaxis]
596611
)
597-
# print(times)
612+
613+
# Create a mask for the times
598614
mask = (
599615
self._padded_x[:, :, time_index][:, np.newaxis, :]
600616
>= tmp_times[:, :, np.newaxis]
@@ -615,7 +631,10 @@ def add_accumulated_value_after_t(
615631
selections = selections.transpose(0, 2, 1).reshape(
616632
len(self.clustered_x), -1
617633
)
634+
635+
# Add the selections to the clustered tensor
618636
self._add_column(selections, location)
637+
619638
# update the cluster names
620639
if self._input_names is not None:
621640
new_names = [

0 commit comments

Comments
 (0)