Skip to content

Commit 0a31e0d

Browse files
authored
2249 hi lo cg correction bin pset data 1 (IMAP-Science-Operations-Center#2308)
* Change how HiPointingSet implements ram/anti-ram filtering * Revert changes in HiPointingSet that modify esa_energy_step coordinate * Start integrating CG correction into Hi L2 * Fix test failing due to unspecified coordinate values * Remove spin_phase argument to HiPointingSet * Pre PR cleanup * Pre PR cleanup * Give a more descriptive name to TypeVar * Fix test broken by reordering * PR feedback * Fix doc build
1 parent 81572fe commit 0a31e0d

6 files changed

Lines changed: 334 additions & 146 deletions

File tree

imap_processing/ena_maps/ena_maps.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -647,28 +647,21 @@ class HiPointingSet(LoHiBasePointingSet):
647647
----------
648648
dataset : xarray.Dataset | str | Path
649649
Hi L1C pointing set data loaded in a xarray.DataArray.
650-
spin_phase : str
651-
Include ENAs from "full", "ram" or "anti-ram" phases of the spin.
652650
"""
653651

654-
def __init__(self, dataset: xr.Dataset | str | Path, spin_phase: str):
655-
super().__init__(dataset, spice_reference_frame=geometry.SpiceFrame.ECLIPJ2000)
652+
def __init__(self, dataset: xr.Dataset | str | Path):
653+
super().__init__(dataset, spice_reference_frame=geometry.SpiceFrame.IMAP_HAE)
654+
655+
self.spatial_coords = ("spin_angle_bin",)
656656

657-
# Filter out ENAs from non-selected portions of the spin.
658-
if spin_phase not in ["full", "ram", "anti"]:
659-
raise ValueError(f"Unrecognized spin_phase value: {spin_phase}.")
657+
# Naively generate the ram_mask variable assuming spacecraft frame
658+
# binning. The ram_mask variable gets updated in the CG correction
659+
# code if the CG correction is applied.
660+
ram_mask = xr.zeros_like(self.data["spin_angle_bin"], dtype=bool)
660661
# ram only includes spin-phase interval [0, 0.5)
661662
# which is the first half of the spin_angle_bins
662-
elif spin_phase == "ram":
663-
self.data = self.data.isel(
664-
spin_angle_bin=slice(0, self.data["spin_angle_bin"].data.size // 2)
665-
)
666-
# anti-ram includes spin-phase interval [0.5, 1)
667-
# which is the second half of the spin_angle_bins
668-
elif spin_phase == "anti":
669-
self.data = self.data.isel(
670-
spin_angle_bin=slice(self.data["spin_angle_bin"].data.size // 2, None)
671-
)
663+
ram_mask[slice(0, self.data["spin_angle_bin"].data.size // 2)] = True
664+
self.data["ram_mask"] = ram_mask
672665

673666
# Rename some PSET vars to match L2 variables
674667
self.data = self.data.rename(
@@ -684,8 +677,6 @@ def __init__(self, dataset: xr.Dataset | str | Path, spin_phase: str):
684677
self.data["exposure_factor"], self.data["epoch"].values[0]
685678
)
686679

687-
self.spatial_coords = ("spin_angle_bin",)
688-
689680
# Update az_el_points using the base class method
690681
self.update_az_el_points()
691682

@@ -810,12 +801,12 @@ def num_points(self) -> int:
810801
"""
811802
return self.az_el_points.shape[0]
812803

813-
def project_pset_values_to_map(
804+
def project_pset_values_to_map( # noqa: PLR0912
814805
self,
815806
pointing_set: PointingSet,
816807
value_keys: list[str] | None = None,
817808
index_match_method: IndexMatchMethod = IndexMatchMethod.PUSH,
818-
pset_valid_mask: NDArray | None = None,
809+
pset_valid_mask: NDArray | xr.DataArray | None = None,
819810
) -> None:
820811
"""
821812
Project a pointing set's values to the map grid.
@@ -837,7 +828,7 @@ def project_pset_values_to_map(
837828
index_match_method : IndexMatchMethod, optional
838829
The method of index matching to use for all values.
839830
Default is IndexMatchMethod.PUSH.
840-
pset_valid_mask : NDArray, optional
831+
pset_valid_mask : xarray.DataArray or NDArray, optional
841832
A boolean mask of shape (number of pointing set pixels,) indicating
842833
which pixels in the pointing set should be considered valid for projection.
843834
If None, all pixels are considered valid. Default is None.
@@ -849,9 +840,9 @@ def project_pset_values_to_map(
849840
"""
850841
if value_keys is None:
851842
value_keys = list(pointing_set.data.data_vars.keys())
852-
for value_key in value_keys:
853-
if value_key not in pointing_set.data.data_vars:
854-
raise ValueError(f"Value key {value_key} not found in pointing set.")
843+
844+
if missing_keys := set(value_keys) - set(pointing_set.data.data_vars):
845+
raise KeyError(f"Value keys not found in pointing set: {missing_keys}")
855846

856847
if pset_valid_mask is None:
857848
pset_valid_mask = np.ones(pointing_set.num_points, dtype=bool)
@@ -876,9 +867,12 @@ def project_pset_values_to_map(
876867
)
877868

878869
for value_key in value_keys:
870+
if value_key not in pointing_set.data.data_vars:
871+
raise ValueError(f"Value key {value_key} not found in pointing set.")
872+
879873
# If multiple spatial axes present
880874
# (i.e (az, el) for rectangular coordinate PSET),
881-
# flatten them in the values array to match the raveled indices
875+
# stack them into a single coordinate to match the raveled indices
882876
raveled_pset_data = pointing_set.data[value_key].stack(
883877
{CoordNames.GENERIC_PIXEL.value: pointing_set.spatial_coords}
884878
)
@@ -907,13 +901,22 @@ def project_pset_values_to_map(
907901
data_bc, indices_bc = xr.broadcast(
908902
raveled_pset_data, matched_indices_push
909903
)
904+
# If the valid mask is a xr.DataArray, broadcast it to the same shape
905+
if isinstance(pset_valid_mask, xr.DataArray):
906+
stacked_valid_mask = pset_valid_mask.stack(
907+
{CoordNames.GENERIC_PIXEL.value: pointing_set.spatial_coords}
908+
)
909+
pset_valid_mask_bc, _ = xr.broadcast(data_bc, stacked_valid_mask)
910+
pset_valid_mask_values = pset_valid_mask_bc.values
911+
else:
912+
pset_valid_mask_values = pset_valid_mask
910913

911914
# Extract numpy arrays for bincount operation
912915
pointing_projected_values = map_utils.bin_single_array_at_indices(
913916
value_array=data_bc.values,
914917
projection_grid_shape=self.binning_grid_shape,
915918
projection_indices=indices_bc.values,
916-
input_valid_mask=pset_valid_mask,
919+
input_valid_mask=pset_valid_mask_values,
917920
)
918921
# TODO: we may need to allow for unweighted/weighted means here by
919922
# dividing pointing_projected_values by some binned weights.
@@ -934,10 +937,6 @@ def project_pset_values_to_map(
934937
self.data_1d[value_key].values[..., valid_map_mask] += (
935938
pointing_projected_values
936939
)
937-
else:
938-
raise NotImplementedError(
939-
"Only PUSH and PULL index matching methods are supported."
940-
)
941940

942941
# TODO: The max epoch needs to include the pset duration. Right now it
943942
# is just capturing the start epoch. See issue #1747

imap_processing/ena_maps/utils/corrections.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
"""L2 corrections common to multiple IMAP ENA instruments."""
22

33
from pathlib import Path
4+
from typing import TypeVar
45

56
import numpy as np
67
import pandas as pd
78
import xarray as xr
89
from numpy.polynomial import Polynomial
910
from scipy.constants import electron_volt, erg, proton_mass
1011

11-
from imap_processing.ena_maps.ena_maps import LoHiBasePointingSet
12+
from imap_processing.ena_maps.ena_maps import (
13+
LoHiBasePointingSet,
14+
)
1215
from imap_processing.ena_maps.utils.coordinates import CoordNames
1316
from imap_processing.spice import geometry
1417
from imap_processing.spice.time import ttj2000ns_to_et
1518

19+
# Create a TypeVar to represent the specific class being passed in
20+
# Bound to LoHiBasePointingSet, meaning it must be LoHiBasePointingSet
21+
# or a subclass of it
22+
LoHiBasePsetSubclass = TypeVar("LoHiBasePsetSubclass", bound=LoHiBasePointingSet)
23+
1624
# Physical constants for Compton-Getting correction
1725
# Units: electron_volt = [J / eV]
1826
# erg = [J / erg]
@@ -307,7 +315,9 @@ class or child classes.
307315
return corrected_flux, corrected_flux_stat_unc
308316

309317

310-
def _add_spacecraft_velocity_to_pset(pset: LoHiBasePointingSet) -> None:
318+
def _add_spacecraft_velocity_to_pset(
319+
pset: LoHiBasePsetSubclass,
320+
) -> LoHiBasePsetSubclass:
311321
"""
312322
Calculate and add spacecraft velocity data to pointing set.
313323
@@ -316,6 +326,11 @@ def _add_spacecraft_velocity_to_pset(pset: LoHiBasePointingSet) -> None:
316326
pset : LoHiBasePointingSet
317327
Pointing set object to be updated.
318328
329+
Returns
330+
-------
331+
pset : LoHiBasePointingSet
332+
Pointing set object with spacecraft velocity data added.
333+
319334
Notes
320335
-----
321336
Adds the following DataArrays to pset.data:
@@ -342,8 +357,10 @@ def _add_spacecraft_velocity_to_pset(pset: LoHiBasePointingSet) -> None:
342357
)
343358
pset.data["sc_direction_vector"] = pset.data["sc_velocity"] / sc_velocity_km_per_sec
344359

360+
return pset
361+
345362

346-
def _add_cartesian_look_direction(pset: LoHiBasePointingSet) -> None:
363+
def _add_cartesian_look_direction(pset: LoHiBasePsetSubclass) -> LoHiBasePsetSubclass:
347364
"""
348365
Calculate and add look direction vectors to pointing set.
349366
@@ -352,6 +369,11 @@ def _add_cartesian_look_direction(pset: LoHiBasePointingSet) -> None:
352369
pset : LoHiBasePointingSet
353370
Pointing set object to be updated.
354371
372+
Returns
373+
-------
374+
pset : LoHiBasePointingSet
375+
Pointing set object with look direction vectors added.
376+
355377
Notes
356378
-----
357379
Adds the following DataArray to pset.data:
@@ -376,11 +398,13 @@ def _add_cartesian_look_direction(pset: LoHiBasePointingSet) -> None:
376398
dims=[*longitudes.dims, CoordNames.CARTESIAN_VECTOR.value],
377399
)
378400

401+
return pset
402+
379403

380404
def _calculate_compton_getting_transform(
381-
pset: LoHiBasePointingSet,
405+
pset: LoHiBasePsetSubclass,
382406
energy_hf: xr.DataArray,
383-
) -> None:
407+
) -> LoHiBasePsetSubclass:
384408
"""
385409
Apply Compton-Getting transformation to compute ENA source directions.
386410
@@ -400,14 +424,24 @@ def _calculate_compton_getting_transform(
400424
energy_hf : xr.DataArray
401425
ENA energies in the heliosphere frame in eV.
402426
427+
Returns
428+
-------
429+
pset : LoHiBasePointingSet
430+
Pointing set object with Compton-Getting related variables added and
431+
updated az_el_points.
432+
403433
Notes
404434
-----
405435
The algorithm is based on the "Appendix A. The IMAP-Lo Mapping Algorithms"
406436
document.
407437
Adds the following DataArrays to pset.data:
408438
- "energy_sc": ENA energies in spacecraft frame (eV)
409-
- "ena_source_hae_longitude": ENA source longitudes in heliosphere frame (degrees)
410-
- "ena_source_hae_latitude": ENA source latitudes in heliosphere frame (degrees)
439+
- "energy_hf": ENA energies in the heliosphere frame (eV)
440+
- "ram_mask": Mask indicating whether ENA source direction is from the ram
441+
direction.
442+
Updates the following DataArrays in pset.data:
443+
- "hae_longitude": ENA source longitudes in heliosphere frame (degrees)
444+
- "hae_latitude": ENA source latitudes in heliosphere frame (degrees)
411445
"""
412446
# Store heliosphere frame energies
413447
pset.data["energy_hf"] = energy_hf
@@ -455,6 +489,8 @@ def _calculate_compton_getting_transform(
455489
# Velocity magnitude factor calculation (Equation 62)
456490
# x_k = (êₛ · û_sc) + sqrt(y² + (êₛ · û_sc)² - 1)
457491
x = dot_product + np.sqrt(y**2 + dot_product**2 - 1)
492+
# Get the dimensions in the right order so that spatial is last
493+
x = x.transpose(dot_product.dims[0], y.dims[0], dot_product.dims[1])
458494

459495
# Calculate ENA speed in the spacecraft frame
460496
# |v⃗_sc| = x_k * U_sc
@@ -504,11 +540,13 @@ def _calculate_compton_getting_transform(
504540
dims=velocity_vector_helio.dims[:-1],
505541
)
506542

543+
return pset
544+
507545

508546
def apply_compton_getting_correction(
509-
pset: LoHiBasePointingSet,
547+
pset: LoHiBasePsetSubclass,
510548
energy_hf: xr.DataArray,
511-
) -> None:
549+
) -> LoHiBasePsetSubclass:
512550
"""
513551
Apply Compton-Getting correction to a pointing set and update coordinates.
514552
@@ -532,6 +570,11 @@ def apply_compton_getting_correction(
532570
ENA energies in the heliosphere frame in eV. Must be 1D with an
533571
energy dimension.
534572
573+
Returns
574+
-------
575+
pset : LoHiBasePointingSet
576+
Updated pointing set object with Compton-Getting related variables added.
577+
535578
Notes
536579
-----
537580
This function adds the following variables to the pointing set dataset:
@@ -540,20 +583,23 @@ def apply_compton_getting_correction(
540583
- "look_direction": Cartesian unit vectors of observation directions
541584
- "energy_hf": ENA energies in heliosphere frame (eV)
542585
- "energy_sc": ENA energies in spacecraft frame (eV)
543-
- "ena_source_hae_longitude": ENA source longitudes in heliosphere frame (degrees)
544-
- "ena_source_hae_latitude": ENA source latitudes in heliosphere frame (degrees)
586+
This function modifies the following variables in the pointing set dataset:
587+
- "hae_longitude": ENA source longitudes in heliosphere frame (degrees)
588+
- "hae_latitude": ENA source latitudes in heliosphere frame (degrees)
545589
546590
The az_el_points attribute is updated to use the corrected coordinates,
547591
which will be used for subsequent binning operations.
548592
"""
549593
# Step 1: Add spacecraft velocity and direction to pset
550-
_add_spacecraft_velocity_to_pset(pset)
594+
pset = _add_spacecraft_velocity_to_pset(pset)
551595

552596
# Step 2: Calculate and add look direction vectors to pset
553-
_add_cartesian_look_direction(pset)
597+
pset = _add_cartesian_look_direction(pset)
554598

555599
# Step 3: Apply Compton-Getting transformation
556-
_calculate_compton_getting_transform(pset, energy_hf)
600+
pset = _calculate_compton_getting_transform(pset, energy_hf)
557601

558602
# Step 4: Update az_el_points to use the corrected coordinates
559603
pset.update_az_el_points()
604+
605+
return pset

0 commit comments

Comments
 (0)