|
1 | 1 | """I3Extractor class(es) for extracting specific, reconstructed features.""" |
2 | 2 |
|
3 | | -from typing import TYPE_CHECKING, Any, Dict, List |
| 3 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional |
4 | 4 | from .i3extractor import I3Extractor |
5 | 5 | from graphnet.data.extractors.icecube.utilities.frames import ( |
6 | 6 | get_om_keys_and_pulseseries, |
7 | 7 | ) |
8 | 8 | from graphnet.utilities.imports import has_icecube_package |
9 | 9 |
|
10 | 10 | if has_icecube_package() or TYPE_CHECKING: |
11 | | - from icecube import icetray # pyright: reportMissingImports=false |
| 11 | + from icecube import ( |
| 12 | + icetray, |
| 13 | + dataclasses, |
| 14 | + ) # pyright: reportMissingImports=false |
12 | 15 |
|
| 16 | +import numpy as np |
13 | 17 |
|
14 | | -class I3FeatureExtractor(I3Extractor): |
| 18 | + |
| 19 | +class I3PulseLevelExtractor(I3Extractor): |
15 | 20 | """Base class for extracting specific, reconstructed features.""" |
16 | 21 |
|
17 | | - def __init__(self, pulsemap: str, exclude: list = [None]): |
18 | | - """Construct I3FeatureExtractor. |
| 22 | + def __init__( |
| 23 | + self, |
| 24 | + pulsemap: str, |
| 25 | + exclude: list = [None], |
| 26 | + extractor_name: Optional[str] = None, |
| 27 | + ): |
| 28 | + """Construct I3PulseLevelExtractor. |
19 | 29 |
|
20 | 30 | Args: |
21 | 31 | pulsemap: Name of the pulse (series) map for which to extract |
22 | 32 | reconstructed features. |
23 | 33 | exclude: List of keys to exclude from the extracted data. |
| 34 | + extractor_name: Name of the extractor. |
24 | 35 | """ |
25 | 36 | # Member variable(s) |
26 | 37 | self._pulsemap = pulsemap |
| 38 | + if extractor_name is None: |
| 39 | + extractor_name = pulsemap |
27 | 40 |
|
28 | 41 | # Base class constructor |
| 42 | + super().__init__(extractor_name, exclude=exclude) |
| 43 | + |
| 44 | + |
| 45 | +class I3FeatureExtractor(I3PulseLevelExtractor): |
| 46 | + """Old class now contained in I3PulseLevelExtractor.""" |
| 47 | + |
| 48 | + def __init__(self, pulsemap: str, exclude: list = [None]): |
| 49 | + """Construct I3FeatureExtractor. |
| 50 | +
|
| 51 | + Args: |
| 52 | + pulsemap: Name of the pulse (series) map for which to extract |
| 53 | + reconstructed features. |
| 54 | + exclude: List of keys to exclude from the extracted data. |
| 55 | + """ |
| 56 | + self.warning_once( |
| 57 | + "I3FeatureExtractor is deprecated and will be removed in a future release. Please use I3PulseLevelExtractor instead." |
| 58 | + ) |
29 | 59 | super().__init__(pulsemap, exclude=exclude) |
30 | 60 |
|
31 | 61 |
|
32 | | -class I3FeatureExtractorIceCube86(I3FeatureExtractor): |
| 62 | +class I3FeatureExtractorIceCube86(I3PulseLevelExtractor): |
33 | 63 | """Class for extracting reconstructed features for IceCube-86.""" |
34 | 64 |
|
35 | 65 | def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]: |
@@ -206,6 +236,217 @@ def _parse_awtd_flag( |
206 | 236 | return pulse.width < (fadc_min_width_ns * icetray.I3Units.ns) |
207 | 237 |
|
208 | 238 |
|
| 239 | +class I3PulseOriginLabels(I3PulseLevelExtractor): |
| 240 | + """Class for extracting MCPE labels for IceCube-86.""" |
| 241 | + |
| 242 | + def __init__( |
| 243 | + self, |
| 244 | + pulsemap: str, |
| 245 | + exclude: list = [None], |
| 246 | + extractor_name: str = "PulseOrigin", |
| 247 | + time_window: float = 10.0, |
| 248 | + mctree: str = "I3MCTree", |
| 249 | + mcpe_map: str = "I3MCPESeriesMapWithoutNoise", |
| 250 | + mcpe_map_id: str = "I3MCPESeriesMapParticleIDMap", |
| 251 | + ): |
| 252 | + """Construct I3PulseOriginLabels. |
| 253 | +
|
| 254 | + Args: |
| 255 | + pulsemap: Name of the pulse (series) map for which to extract |
| 256 | + reconstructed features. |
| 257 | + exclude: List of keys to exclude from the extracted data. |
| 258 | + extractor_name: Name of the extractor. |
| 259 | + time_window: Time window (in ns) around each pulse to consider |
| 260 | + MCPEs for label assignment. |
| 261 | + mctree: Name of the MCTree in the I3 frame. |
| 262 | + mcpe_map: Name of the MCPE series map in the I3 frame. |
| 263 | + mcpe_map_id: Name of the MCPE series map particle ID map in the I3 frame. |
| 264 | + """ |
| 265 | + super().__init__(pulsemap, exclude, extractor_name) |
| 266 | + self._time_window = time_window |
| 267 | + self._mctree = mctree |
| 268 | + self._mcpe_map = mcpe_map |
| 269 | + self._mcpe_map_id = mcpe_map_id |
| 270 | + |
| 271 | + def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]: |
| 272 | + """Extract MCPE labels from `frame`. |
| 273 | +
|
| 274 | + Args: |
| 275 | + frame: Physics (P) I3-frame from which to extract MCPE labels. |
| 276 | +
|
| 277 | + Returns: |
| 278 | + Dictionary of MCPE labels for all pulses in `pulsemap`, |
| 279 | + in pure-python format. |
| 280 | + """ |
| 281 | + output: Dict[str, List[Any]] = { |
| 282 | + "charge": [], |
| 283 | + "dom_time": [], |
| 284 | + "dom_x": [], |
| 285 | + "dom_y": [], |
| 286 | + "dom_z": [], |
| 287 | + "neutrino_fraction": [], |
| 288 | + "neutrino_npe_fraction": [], |
| 289 | + "npe": [], |
| 290 | + "pulse_count": [], |
| 291 | + "noise_hit": [], |
| 292 | + "trackness": [], |
| 293 | + "overlap_count": [], |
| 294 | + "min_time_delta": [], |
| 295 | + } |
| 296 | + |
| 297 | + # Get OM data |
| 298 | + if self._pulsemap in frame: |
| 299 | + om_keys, data = get_om_keys_and_pulseseries( |
| 300 | + frame, |
| 301 | + self._pulsemap, |
| 302 | + self._calibration, |
| 303 | + ) |
| 304 | + else: |
| 305 | + self.warning_once(f"Pulsemap {self._pulsemap} not found in frame.") |
| 306 | + return output |
| 307 | + |
| 308 | + for om_key in om_keys: |
| 309 | + # Loop over pulses for each OM |
| 310 | + pulses = data[om_key] |
| 311 | + pulse_times, pulse_charges = self._get_pulse_info(pulses) |
| 312 | + npe_list, times, nu_bool, track_like_list = self._get_mcpe_info( |
| 313 | + frame, om_key |
| 314 | + ) |
| 315 | + time_distance_matrix = pulse_times[:, None] - times[None, :] |
| 316 | + weight_matrix = self._get_gaussian_weight(time_distance_matrix) * ( |
| 317 | + np.abs(time_distance_matrix) <= self._time_window |
| 318 | + ) |
| 319 | + |
| 320 | + with np.errstate(invalid="ignore"): |
| 321 | + weight_matrix /= np.sum(weight_matrix, axis=0, keepdims=True) |
| 322 | + weight_matrix = np.nan_to_num(weight_matrix, nan=0.0) |
| 323 | + pulse_counts = np.sum(weight_matrix, axis=1) |
| 324 | + neutrino_fractions = (weight_matrix @ nu_bool) / pulse_counts |
| 325 | + neutrino_npe_fractions = ( |
| 326 | + weight_matrix @ (npe_list * nu_bool) |
| 327 | + ) / (weight_matrix @ npe_list) |
| 328 | + total_npe = weight_matrix @ npe_list |
| 329 | + trackness = weight_matrix @ track_like_list / pulse_counts |
| 330 | + min_time_deltas = ( |
| 331 | + np.min(np.abs(time_distance_matrix), axis=1) |
| 332 | + if len(times) > 0 |
| 333 | + else np.array([np.nan] * len(pulse_times)) |
| 334 | + ) |
| 335 | + |
| 336 | + # Determine how many other mcpe's overlap with any other pulse for a given pulse |
| 337 | + weight_matrix_binary = (weight_matrix > 0).astype(float) |
| 338 | + overlap_counts = weight_matrix_binary @ weight_matrix_binary.T |
| 339 | + np.fill_diagonal(overlap_counts, 0) |
| 340 | + overlap_counts = np.sum(overlap_counts, axis=1) |
| 341 | + |
| 342 | + output["neutrino_fraction"].extend(neutrino_fractions.tolist()) |
| 343 | + output["neutrino_npe_fraction"].extend( |
| 344 | + neutrino_npe_fractions.tolist() |
| 345 | + ) |
| 346 | + output["npe"].extend(total_npe.tolist()) |
| 347 | + output["pulse_count"].extend(pulse_counts.tolist()) |
| 348 | + output["charge"].extend(pulse_charges.tolist()) |
| 349 | + output["dom_time"].extend(pulse_times.tolist()) |
| 350 | + output["dom_x"].extend( |
| 351 | + [self._gcd_dict[om_key].position.x] * len(pulse_times) |
| 352 | + ) |
| 353 | + output["dom_y"].extend( |
| 354 | + [self._gcd_dict[om_key].position.y] * len(pulse_times) |
| 355 | + ) |
| 356 | + output["dom_z"].extend( |
| 357 | + [self._gcd_dict[om_key].position.z] * len(pulse_times) |
| 358 | + ) |
| 359 | + output["noise_hit"].extend((pulse_counts == 0).tolist()) |
| 360 | + output["min_time_delta"].extend(min_time_deltas.tolist()) |
| 361 | + output["trackness"].extend(trackness.tolist()) |
| 362 | + output["overlap_count"].extend(overlap_counts.tolist()) |
| 363 | + |
| 364 | + return output |
| 365 | + |
| 366 | + def _get_mcpe_info( |
| 367 | + self, frame: "icetray.I3Frame", om_key: "icetray.OMKey" |
| 368 | + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
| 369 | + """Determine the neutrino fraction of a pulse. |
| 370 | +
|
| 371 | + Args: |
| 372 | + pulse: The pulse for which to determine the neutrino fraction. |
| 373 | + mcpe_map: Name of the MCPE series map in the I3 frame. |
| 374 | +
|
| 375 | + Returns: |
| 376 | + Neutrino fraction of the pulse. |
| 377 | + """ |
| 378 | + times: list[float] = [] |
| 379 | + nu_bool: list[bool] = [] |
| 380 | + npe_list: list[float] = [] |
| 381 | + track_like_list: list[float] = [] |
| 382 | + try: |
| 383 | + mcpe_info = frame[self._mcpe_map][om_key] |
| 384 | + except KeyError: |
| 385 | + return ( |
| 386 | + np.array(npe_list), |
| 387 | + np.array(times), |
| 388 | + np.array(nu_bool), |
| 389 | + np.array(track_like_list), |
| 390 | + ) |
| 391 | + for i, mcpe in enumerate(mcpe_info): |
| 392 | + try: |
| 393 | + nu_primary = ( |
| 394 | + frame[self._mctree].get_primary(mcpe.ID).is_neutrino |
| 395 | + ) |
| 396 | + track_like = frame[self._mctree].get_particle(mcpe.ID).is_track |
| 397 | + nu_bool.append(nu_primary) |
| 398 | + except RuntimeError as e: |
| 399 | + # backup to using the mcpe id map to figure out the parent type, if any part of the mcpe has a neutrino as the primary, we count it as a neutrino mcpe this is a choice, but the information about which part of the mcpe corresponds to which primary is lost. |
| 400 | + if "particle not found" in str(e): |
| 401 | + ids = [ |
| 402 | + id_p |
| 403 | + for id_p, indexval in frame[self._mcpe_map_id][ |
| 404 | + om_key |
| 405 | + ].items() |
| 406 | + if i in indexval |
| 407 | + ] |
| 408 | + bool_val = any( |
| 409 | + [ |
| 410 | + frame[self._mctree].get_primary(id_p).is_neutrino |
| 411 | + for id_p in ids |
| 412 | + ] |
| 413 | + ) |
| 414 | + track_like = [ |
| 415 | + frame[self._mctree].get_particle(id_p).is_track |
| 416 | + for id_p in ids |
| 417 | + ] |
| 418 | + track_like = sum(track_like) / len(track_like) |
| 419 | + nu_bool.append(bool_val) |
| 420 | + else: |
| 421 | + raise e |
| 422 | + |
| 423 | + times.append(mcpe.time) |
| 424 | + npe_list.append(mcpe.npe) |
| 425 | + track_like_list.append(track_like) |
| 426 | + |
| 427 | + return ( |
| 428 | + np.array(npe_list), |
| 429 | + np.array(times), |
| 430 | + np.array(nu_bool), |
| 431 | + np.array(track_like_list), |
| 432 | + ) |
| 433 | + |
| 434 | + def _get_pulse_info( |
| 435 | + self, pulses: List["icetray.I3RecoPulse"] |
| 436 | + ) -> tuple[np.ndarray, np.ndarray]: |
| 437 | + """Create an nd array of pulse times, charge.""" |
| 438 | + times, charges = np.array([[p.time, p.charge] for p in pulses]).T |
| 439 | + return times, charges |
| 440 | + |
| 441 | + def _get_gaussian_weight( |
| 442 | + self, time_distance_matrix: np.ndarray |
| 443 | + ) -> np.ndarray: |
| 444 | + """Create gaussian weight matrix based on time distance matrix.""" |
| 445 | + return np.exp( |
| 446 | + -0.5 * (time_distance_matrix / (self._time_window / 2)) ** 2 |
| 447 | + ) |
| 448 | + |
| 449 | + |
209 | 450 | class I3FeatureExtractorIceCubeDeepCore(I3FeatureExtractorIceCube86): |
210 | 451 | """Class for extracting reconstructed features for IceCube-DeepCore.""" |
211 | 452 |
|
|
0 commit comments