|
1 | 1 | """Class(es) for building/connecting graphs.""" |
2 | 2 |
|
3 | | -from typing import List, Tuple, Optional, Dict |
| 3 | +from typing import List, Tuple, Optional, Dict, Union |
4 | 4 | from abc import abstractmethod |
5 | 5 |
|
6 | 6 | import torch |
@@ -469,3 +469,278 @@ def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor: |
469 | 469 | graph[:event_length, idx] = x[ids, self.feature_indexes[feature]] |
470 | 470 |
|
471 | 471 | return graph |
| 472 | + |
| 473 | + |
class ClusterSummaryFeatures(NodeDefinition):
    """Represent pulse maps as clusters with summary features.

    If `cluster_on` is set to the xyz coordinates of optical modules
    e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be
    a unique optical module and the pulse information (e.g. charge, time)
    is summarized.
    NOTE: Developed to be used with features
    [dom_x, dom_y, dom_z, charge, time]

    Possible features per cluster:
    - total charge
      feature name: `total_charge`
    - charge accumulated after <X> time units
      feature name: `charge_after_<X>ns`
    - time of first hit in the optical module
      feature name: `time_of_first_hit`
    - time spread per optical module
      feature name: `time_spread`
    - time std per optical module
      feature name: `time_std`
    - time taken to collect <X> percent of total charge per cluster
      feature name: `time_after_charge_pct<X>`
    - number of pulses per cluster
      feature name: `counts`

    For more details on some of the features see
    Theo Glauch's thesis (chapter 5.3):
    https://mediatum.ub.tum.de/node?id=1584755
    """

    def __init__(
        self,
        cluster_on: List[str],
        input_feature_names: List[str],
        charge_label: str = "charge",
        time_label: str = "dom_time",
        total_charge: bool = True,
        charge_after_t: Optional[List[int]] = None,
        time_of_first_hit: bool = True,
        time_spread: bool = True,
        time_std: bool = True,
        time_after_charge_pct: Optional[List[int]] = None,
        charge_standardization: Union[float, str] = "log",
        time_standardization: float = 1e-3,
        order_in_time: bool = True,
        add_counts: bool = False,
    ) -> None:
        """Construct `ClusterSummaryFeatures`.

        Args:
            cluster_on: Names of features to create clusters from.
            input_feature_names: Column names for input features.
            charge_label: Name of the charge column.
            time_label: Name of the time column.
            total_charge: If True, calculates total charge as feature.
            charge_after_t: List of times at which the accumulated charge
                is calculated as a feature. Defaults to [10, 50, 100].
            time_of_first_hit: If True, time of first hit is added
                as a feature.
            time_spread: If True, time spread is added as a feature.
            time_std: If True, time std is added as a feature.
            time_after_charge_pct: List of percentiles to calculate time
                after charge. Defaults to [1, 3, 5, 11, 15, 20, 50, 80].
            charge_standardization: Either a number or 'log'. If a number,
                the features are multiplied by this factor. If 'log', the
                features are transformed to log10 scale.
            time_standardization: Standardization factor for features
                with a time.
            order_in_time: If True, clusters are ordered in time.
                If your data is already ordered in time, you can set this
                to False to avoid a potential overhead.
                NOTE: Should only be set to False if you are sure that
                the input data is already ordered in time. Will lead to
                incorrect results otherwise.
            add_counts: If True, number of log10(event counts per clusters)
                is added as a feature.

        NOTE: Make sure that either the input data is not already
            standardized or that the `charge_standardization` and
            `time_standardization` parameters are set to 1 to avoid a
            double standardization.
        """
        # Set member variables
        self._cluster_on = cluster_on
        self._charge_label = charge_label
        self._time_label = time_label
        self._order_in_time = order_in_time

        # Validate standardization settings up front (fail fast before
        # any data is processed).
        self._charge_standardization = charge_standardization
        self._time_standardization = time_standardization
        self._verify_standardization()

        # Feature member variables. List defaults are created per instance
        # (and user-supplied lists copied) to avoid the shared
        # mutable-default-argument pitfall.
        self._total_charge = total_charge
        self._charge_after_t = (
            [10, 50, 100] if charge_after_t is None else list(charge_after_t)
        )
        self._time_of_first_hit = time_of_first_hit
        self._time_spread = time_spread
        self._time_std = time_std
        self._time_after_charge_pct = (
            [1, 3, 5, 11, 15, 20, 50, 80]
            if time_after_charge_pct is None
            else list(time_after_charge_pct)
        )
        self._add_counts = add_counts

        # Base class constructor
        super().__init__(input_feature_names=input_feature_names)
        if self._order_in_time is False:
            self.info(
                "Setting `order_in_time` to False. "
                "Make sure that the input data is already ordered in time."
            )

    def _define_output_feature_names(
        self,
        input_feature_names: List[str],
    ) -> List[str]:
        """Return the output feature names in construction order.

        The order here must mirror the order in which features are appended
        in `_construct_nodes`.
        """
        self.set_indices(input_feature_names)
        # Start from a copy so the configured `cluster_on` list is not
        # mutated by the appends below.
        new_feature_names = list(self._cluster_on)
        if self._total_charge:
            new_feature_names.append("total_charge")
        for t in self._charge_after_t:
            new_feature_names.append(f"charge_after_{t}ns")
        if self._time_of_first_hit:
            new_feature_names.append("time_of_first_hit")
        if self._time_spread:
            new_feature_names.append("time_spread")
        if self._time_std:
            new_feature_names.append("time_std")
        for pct in self._time_after_charge_pct:
            new_feature_names.append(f"time_after_charge_pct{pct}")
        if self._add_counts:
            new_feature_names.append("counts")
        return new_feature_names

    def _construct_nodes(self, x: torch.Tensor) -> Data:
        """Construct nodes from raw node features ´x´.

        Each helper call on `cluster_class` appends new column(s) to
        `clustered_x`, so the freshly added features are always addressed
        with negative column slices immediately after the call.
        """
        # Cast to Numpy
        x = x.numpy()
        # Construct clusters with percentile-summarized features
        cluster_class = cluster_and_pad(
            x=x,
            cluster_columns=self._cluster_idx,
            sort_by=[self._time_idx] if self._order_in_time else [],
        )
        # calculate charge weighted median time as reference
        ref_time = cluster_class.reference_time(
            charge_index=self._charge_idx,
            time_index=self._time_idx,
        )

        # add total charge
        if self._total_charge:
            cluster_class.add_sum_charge(charge_index=self._charge_idx)
            cluster_class.clustered_x[:, -1] = self._standardize_features(
                cluster_class.clustered_x[:, -1],
                self._charge_standardization,
            )

        # add charge after t
        if len(self._charge_after_t) > 0:
            cluster_class.add_accumulated_value_after_t(
                time_index=self._time_idx,
                summarization_indices=[self._charge_idx],
                times=self._charge_after_t,
            )
            cluster_class.clustered_x[:, -len(self._charge_after_t) :] = (
                self._standardize_features(
                    cluster_class.clustered_x[:, -len(self._charge_after_t) :],
                    self._charge_standardization,
                )
            )

        # add time of first hit (shifted to the charge-weighted
        # reference time before standardization)
        if self._time_of_first_hit:
            cluster_class.add_time_first_pulse(
                time_index=self._time_idx,
            )
            cluster_class.clustered_x[:, -1] -= ref_time

            cluster_class.clustered_x[:, -1] = self._standardize_features(
                cluster_class.clustered_x[:, -1],
                self._time_standardization,
            )

        # add time spread
        if self._time_spread:
            cluster_class.add_spread(
                columns=[self._time_idx],
            )
            cluster_class.clustered_x[:, -1] = self._standardize_features(
                cluster_class.clustered_x[:, -1],
                self._time_standardization,
            )

        # add time std
        if self._time_std:
            cluster_class.add_std(
                columns=[self._time_idx],
            )
            cluster_class.clustered_x[:, -1] = self._standardize_features(
                cluster_class.clustered_x[:, -1],
                self._time_standardization,
            )

        # add time after charge percentiles (also shifted to the
        # reference time before standardization)
        if len(self._time_after_charge_pct) > 0:
            cluster_class.add_charge_threshold_summary(
                summarization_indices=[self._time_idx],
                percentiles=self._time_after_charge_pct,
                charge_index=self._charge_idx,
            )
            cluster_class.clustered_x[
                :, -len(self._time_after_charge_pct) :
            ] -= ref_time
            cluster_class.clustered_x[
                :, -len(self._time_after_charge_pct) :
            ] = self._standardize_features(
                cluster_class.clustered_x[
                    :, -len(self._time_after_charge_pct) :
                ],
                self._time_standardization,
            )

        if self._add_counts:
            cluster_class.add_counts()
        return torch.tensor(cluster_class.clustered_x)

    def set_indices(self, feature_names: List[str]) -> None:
        """Set the column indices for the input features.

        Raises:
            ValueError: If any of `cluster_on`, `charge_label` or
                `time_label` is not present in `feature_names`
                (raised by `list.index`).
        """
        self._cluster_idx = [
            feature_names.index(column) for column in self._cluster_on
        ]
        self._charge_idx = feature_names.index(self._charge_label)
        self._time_idx = feature_names.index(self._time_label)

    def _standardize_features(
        self,
        x: np.ndarray,
        standardization: Union[float, str],
    ) -> np.ndarray:
        """Standardize the features in the input array.

        A numeric `standardization` is applied as a multiplicative factor;
        the string 'log' applies a log10 transform.
        NOTE(review): log10 yields -inf/NaN for zero/negative entries —
        presumably charges are strictly positive; verify upstream.
        """
        # Accept ints as well as floats (the class docstring suggests
        # "set to 1" to disable standardization); exclude bool, which is
        # an int subclass and would silently act as a 0/1 factor.
        if isinstance(standardization, (int, float)) and not isinstance(
            standardization, bool
        ):
            return x * standardization
        elif standardization == "log":
            return np.log10(x)
        else:
            # should never happen (checked in `_verify_standardization`),
            # but just in case
            raise ValueError(
                f"standardization must be either a float or 'log', "
                f"but got {standardization}"
            )

    def _verify_standardization(
        self,
    ) -> None:
        """Verify settings of standardization of the features.

        Raises:
            ValueError: If `charge_standardization` is neither a real
                number nor the string 'log', or if `time_standardization`
                is not a real number.
        """
        charge = self._charge_standardization
        charge_is_number = isinstance(charge, (int, float)) and not isinstance(
            charge, bool
        )
        if not charge_is_number and charge != "log":
            raise ValueError(
                f"charge_standardization must be either a float or 'log', "
                f"but got {charge}"
            )

        time = self._time_standardization
        if not isinstance(time, (int, float)) or isinstance(time, bool):
            raise ValueError(
                f"time_standardization must be a float, "
                f"but got {time}"
            )
0 commit comments