@@ -647,28 +647,21 @@ class HiPointingSet(LoHiBasePointingSet):
647647 ----------
648648 dataset : xarray.Dataset | str | Path
649649 Hi L1C pointing set data loaded in a xarray.DataArray.
650- spin_phase : str
651- Include ENAs from "full", "ram" or "anti-ram" phases of the spin.
652650 """
653651
654- def __init__ (self , dataset : xr .Dataset | str | Path , spin_phase : str ):
655- super ().__init__ (dataset , spice_reference_frame = geometry .SpiceFrame .ECLIPJ2000 )
652+ def __init__ (self , dataset : xr .Dataset | str | Path ):
653+ super ().__init__ (dataset , spice_reference_frame = geometry .SpiceFrame .IMAP_HAE )
654+
655+ self .spatial_coords = ("spin_angle_bin" ,)
656656
657- # Filter out ENAs from non-selected portions of the spin.
658- if spin_phase not in ["full" , "ram" , "anti" ]:
659- raise ValueError (f"Unrecognized spin_phase value: { spin_phase } ." )
657+ # Naively generate the ram_mask variable assuming spacecraft frame
658+ # binning. The ram_mask variable gets updated in the CG correction
659+ # code if the CG correction is applied.
660+ ram_mask = xr .zeros_like (self .data ["spin_angle_bin" ], dtype = bool )
660661 # ram only includes spin-phase interval [0, 0.5)
661662 # which is the first half of the spin_angle_bins
662- elif spin_phase == "ram" :
663- self .data = self .data .isel (
664- spin_angle_bin = slice (0 , self .data ["spin_angle_bin" ].data .size // 2 )
665- )
666- # anti-ram includes spin-phase interval [0.5, 1)
667- # which is the second half of the spin_angle_bins
668- elif spin_phase == "anti" :
669- self .data = self .data .isel (
670- spin_angle_bin = slice (self .data ["spin_angle_bin" ].data .size // 2 , None )
671- )
663+ ram_mask [slice (0 , self .data ["spin_angle_bin" ].data .size // 2 )] = True
664+ self .data ["ram_mask" ] = ram_mask
672665
673666 # Rename some PSET vars to match L2 variables
674667 self .data = self .data .rename (
@@ -684,8 +677,6 @@ def __init__(self, dataset: xr.Dataset | str | Path, spin_phase: str):
684677 self .data ["exposure_factor" ], self .data ["epoch" ].values [0 ]
685678 )
686679
687- self .spatial_coords = ("spin_angle_bin" ,)
688-
689680 # Update az_el_points using the base class method
690681 self .update_az_el_points ()
691682
@@ -810,12 +801,12 @@ def num_points(self) -> int:
810801 """
811802 return self .az_el_points .shape [0 ]
812803
813- def project_pset_values_to_map (
804+ def project_pset_values_to_map ( # noqa: PLR0912
814805 self ,
815806 pointing_set : PointingSet ,
816807 value_keys : list [str ] | None = None ,
817808 index_match_method : IndexMatchMethod = IndexMatchMethod .PUSH ,
818- pset_valid_mask : NDArray | None = None ,
809+ pset_valid_mask : NDArray | xr . DataArray | None = None ,
819810 ) -> None :
820811 """
821812 Project a pointing set's values to the map grid.
@@ -837,7 +828,7 @@ def project_pset_values_to_map(
837828 index_match_method : IndexMatchMethod, optional
838829 The method of index matching to use for all values.
839830 Default is IndexMatchMethod.PUSH.
840- pset_valid_mask : NDArray, optional
831+ pset_valid_mask : xarray.DataArray or NDArray, optional
841832 A boolean mask of shape (number of pointing set pixels,) indicating
842833 which pixels in the pointing set should be considered valid for projection.
843834 If None, all pixels are considered valid. Default is None.
@@ -849,9 +840,9 @@ def project_pset_values_to_map(
849840 """
850841 if value_keys is None :
851842 value_keys = list (pointing_set .data .data_vars .keys ())
852- for value_key in value_keys :
853- if value_key not in pointing_set .data .data_vars :
854- raise ValueError (f"Value key { value_key } not found in pointing set. " )
843+
844+ if missing_keys := set ( value_keys ) - set ( pointing_set .data .data_vars ) :
845+ raise KeyError (f"Value keys not found in pointing set: { missing_keys } " )
855846
856847 if pset_valid_mask is None :
857848 pset_valid_mask = np .ones (pointing_set .num_points , dtype = bool )
@@ -876,9 +867,12 @@ def project_pset_values_to_map(
876867 )
877868
878869 for value_key in value_keys :
870+ if value_key not in pointing_set .data .data_vars :
871+ raise ValueError (f"Value key { value_key } not found in pointing set." )
872+
879873 # If multiple spatial axes present
880874 # (i.e (az, el) for rectangular coordinate PSET),
881- # flatten them in the values array to match the raveled indices
875+ # stack them into a single coordinate to match the raveled indices
882876 raveled_pset_data = pointing_set .data [value_key ].stack (
883877 {CoordNames .GENERIC_PIXEL .value : pointing_set .spatial_coords }
884878 )
@@ -907,13 +901,22 @@ def project_pset_values_to_map(
907901 data_bc , indices_bc = xr .broadcast (
908902 raveled_pset_data , matched_indices_push
909903 )
904+ # If the valid mask is a xr.DataArray, broadcast it to the same shape
905+ if isinstance (pset_valid_mask , xr .DataArray ):
906+ stacked_valid_mask = pset_valid_mask .stack (
907+ {CoordNames .GENERIC_PIXEL .value : pointing_set .spatial_coords }
908+ )
909+ pset_valid_mask_bc , _ = xr .broadcast (data_bc , stacked_valid_mask )
910+ pset_valid_mask_values = pset_valid_mask_bc .values
911+ else :
912+ pset_valid_mask_values = pset_valid_mask
910913
911914 # Extract numpy arrays for bincount operation
912915 pointing_projected_values = map_utils .bin_single_array_at_indices (
913916 value_array = data_bc .values ,
914917 projection_grid_shape = self .binning_grid_shape ,
915918 projection_indices = indices_bc .values ,
916- input_valid_mask = pset_valid_mask ,
919+ input_valid_mask = pset_valid_mask_values ,
917920 )
918921 # TODO: we may need to allow for unweighted/weighted means here by
919922 # dividing pointing_projected_values by some binned weights.
@@ -934,10 +937,6 @@ def project_pset_values_to_map(
934937 self .data_1d [value_key ].values [..., valid_map_mask ] += (
935938 pointing_projected_values
936939 )
937- else :
938- raise NotImplementedError (
939- "Only PUSH and PULL index matching methods are supported."
940- )
941940
942941 # TODO: The max epoch needs to include the pset duration. Right now it
943942 # is just capturing the start epoch. See issue #1747
0 commit comments