Skip to content

Commit 165fedb

Browse files
committed
automatic_name_generation
1 parent eaa12e5 commit 165fedb

1 file changed

Lines changed: 82 additions & 12 deletions

File tree

src/graphnet/models/graphs/utils.py

Lines changed: 82 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,20 @@ class cluster_and_pad:
202202
# Gets the clustered matrix with all the aggregate statistics.
203203
"""
204204

205-
def __init__(self, x: np.ndarray, cluster_columns: List[int]) -> None:
205+
def __init__(
206+
self,
207+
x: np.ndarray,
208+
cluster_columns: List[int],
209+
input_names: Optional[List[str]] = None,
210+
) -> None:
206211
"""Initialize the class with the data and cluster columns.
207212
208213
Args:
209214
x: Array to be clustered
210215
cluster_columns: List of column indices on which the clusters
211216
are constructed.
217+
input_names: Names of the columns in the input data for automatic
218+
generation of names.
212219
Adds:
213220
clustered_x: Added to the class
214221
_counts: Added to the class
@@ -244,6 +251,14 @@ def __init__(self, x: np.ndarray, cluster_columns: List[int]) -> None:
244251
self._padded_x[i, : self._counts[i]] = x[: self._counts[i]]
245252
x = x[self._counts[i] :]
246253

254+
self._input_names = input_names
255+
if self._input_names is not None:
256+
assert (
257+
len(self._input_names) == x.shape[1]
258+
), "The input names must have the same length as the input data"
259+
260+
self._cluster_names = np.array(input_names)[cluster_columns]
261+
247262
def _add_column(
248263
self, column: np.ndarray, location: Optional[int] = None
249264
) -> None:
@@ -263,6 +278,25 @@ def _add_column(
263278
self.clustered_x, location, column, axis=1
264279
)
265280

281+
def _add_column_names(
282+
self, names: List[str], location: Optional[int] = None
283+
) -> None:
284+
"""Add names to the columns of the clustered tensor.
285+
286+
Args:
287+
names: Names to be added to the columns of the tensor
288+
location: Location to insert the names in the clustered tensor
289+
Altered:
290+
_cluster_names: The names are added at the end of the tensor
291+
or inserted at the specified location
292+
"""
293+
if location is None:
294+
self._cluster_names = np.append(self._cluster_names, names)
295+
else:
296+
self._cluster_names = np.insert(
297+
self._cluster_names, location, names
298+
)
299+
266300
def _calculate_charge_sum(self, charge_index: int) -> np.ndarray:
267301
"""Calculate the sum of the charge."""
268302
assert not hasattr(
@@ -310,6 +344,8 @@ def add_charge_threshold_summary(
310344
of the charge divided by the total charge
311345
clustered_x: The summarization indices are added at the end
312346
of the tensor or inserted at the specified location.
347+
_cluster_names: The names are added at the end of the tensor
348+
or inserted at the specified location
313349
"""
314350
# convert the charge to the cumulative sum of the charge divided
315351
# by the total charge
@@ -340,6 +376,15 @@ def add_charge_threshold_summary(
340376
)
341377
self._add_column(selections, location)
342378

379+
# update the cluster names
380+
if self._input_names is not None:
381+
new_names = [
382+
self._input_names[i] + "_charge_threshold_" + str(p)
383+
for i in summarization_indices
384+
for p in percentiles
385+
]
386+
self._add_column_names(new_names, location)
387+
343388
def add_percentile_summary(
344389
self,
345390
summarization_indices: List[int],
@@ -359,6 +404,8 @@ def add_percentile_summary(
359404
Altered:
360405
clustered_x: The summarization indices are added at the end of
361406
the tensor or inserted at the specified location
407+
_cluster_names: The names are added at the end of the tensor
408+
or inserted at the specified location
362409
"""
363410
percentiles_x = np.nanpercentile(
364411
self._padded_x[:, :, summarization_indices],
@@ -372,48 +419,71 @@ def add_percentile_summary(
372419
)
373420
self._add_column(percentiles_x, location)
374421

422+
# update the cluster names
423+
if self._input_names is not None:
424+
new_names = [
425+
self._input_names[i] + "_percentile_" + str(p)
426+
for i in summarization_indices
427+
for p in percentiles
428+
]
429+
self._add_column_names(new_names, location)
430+
375431
def add_counts(self, location: Optional[int] = None) -> np.ndarray:
376432
"""Add the counts of the sensor to the summarization features."""
377433
self._add_column(np.log10(self._counts), location)
434+
new_name = ["counts"]
435+
self._add_column_names(new_name, location)
378436

379-
def add_sum_charge(self, location: Optional[int] = None) -> np.ndarray:
437+
def add_sum_charge(
438+
self, charge_index: int, location: Optional[int] = None
439+
) -> np.ndarray:
380440
"""Add the sum of the charge to the summarization features."""
381-
assert hasattr(
382-
self, "_charge_sum"
383-
), "Charge sum has not been calculated, \
384-
please run calculate_charge_sum"
441+
if not hasattr(self, "_charge_sum"):
442+
self._calculate_charge_sum(charge_index)
385443
self._add_column(self._charge_sum, location)
444+
# update the cluster names
445+
if self._input_names is not None:
446+
new_name = [self._input_names[charge_index] + "_sum"]
447+
self._add_column_names(new_name, location)
386448

387449
def add_std(
388450
self,
389-
column: int,
451+
columns: List[int],
390452
location: Optional[int] = None,
391453
weights: Union[np.ndarray, int] = 1,
392454
) -> np.ndarray:
393455
"""Add the standard deviation of the column.
394456
395457
Args:
396-
column: Index of the column in the padded tensor to
397-
calculate the standard deviation
458+
columns: Index of the columns from which to calculate the standard
459+
deviation.
398460
location: Location to insert the standard deviation in the
399461
clustered tensor defaults to adding at the end
400462
weights: Optional weights to be applied to the standard deviation
401463
"""
402464
self._add_column(
403-
np.nanstd(self._padded_x[:, :, column] * weights, axis=1), location
465+
np.nanstd(self._padded_x[:, :, columns] * weights, axis=1),
466+
location,
404467
)
468+
if self._input_names is not None:
469+
new_names = [self._input_names[i] + "_std" for i in columns]
470+
self._add_column_names(new_names, location)
405471

406472
def add_mean(
407473
self,
408-
column: int,
474+
columns: List[int],
409475
location: Optional[int] = None,
410476
weights: Union[np.ndarray, int] = 1,
411477
) -> np.ndarray:
412478
"""Add the mean of the column."""
413479
self._add_column(
414-
np.nanmean(self._padded_x[:, :, column] * weights, axis=1),
480+
np.nanmean(self._padded_x[:, :, columns] * weights, axis=1),
415481
location,
416482
)
483+
# update the cluster names
484+
if self._input_names is not None:
485+
new_names = [self._input_names[i] + "_mean" for i in columns]
486+
self._add_column_names(new_names, location)
417487

418488

419489
def ice_transparency(

0 commit comments

Comments
 (0)