Skip to content

Commit 699a991

Browse files
committed
SkyCoord for flarePosition
Fits de/serialize hook for converting into fits serializeable icrs frame
1 parent 8febb3e commit 699a991

4 files changed

Lines changed: 180 additions & 28 deletions

File tree

stixcore/io/product_processors/fits/processors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,8 @@ def write_fits(self, prod, *, version=0):
11411141
elif fitspath_complete.exists():
11421142
logger.warning("Complete Fits file %s exists will be overridden", fitspath.name)
11431143

1144-
data = prod.data
1144+
data = prod.data.copy()
1145+
prod.on_serialize(data)
11451146

11461147
primary_header, header_override = self.generate_primary_header(filename, prod, version=version)
11471148

stixcore/products/level3/flarelist.py

Lines changed: 77 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -97,24 +97,12 @@ def add_flare_position(
9797
# SkyCoord(HeliographicStonyhurst(0 * u.deg, 0 * u.deg))
9898
# SkyCoord(0 * u.deg, 0 * u.deg, frame=helio_frame)
9999

100-
n = len(data)
101-
102-
data["flareposition_obs_hgs_x"] = Column(
103-
np.zeros(n, dtype=float) * u.km, description="HeliographicStonyhurst X of observer"
104-
)
105-
data["flareposition_obs_hgs_y"] = Column(
106-
np.zeros(n, dtype=float) * u.km, description="HeliographicStonyhurst Y of observer"
107-
)
108-
data["flareposition_obs_hgs_z"] = Column(
109-
np.zeros(n, dtype=float) * u.km, description="HeliographicStonyhurst Z of observer"
110-
)
111-
data["flareposition_hp_tx"] = Column(np.zeros(n, dtype=float) * u.arcsec, description="Helioprojective Tx")
112-
data["flareposition_hp_ty"] = Column(np.zeros(n, dtype=float) * u.arcsec, description="Helioprojective Ty")
113-
114100
data["anc_ephemeris_path"] = Column(" " * 500, dtype=str, description="TDB")
115101
data["cpd_path"] = Column(" " * 500, dtype=str, description="TDB")
116102
data["_position_status"] = Column(False, dtype=bool, description="TDB")
117103
data["_position_message"] = Column(" " * 500, dtype=str, description="TDB")
104+
tx_list, ty_list = [], []
105+
solo_x_list, solo_y_list, solo_z_list, peak_time_list = [], [], [], []
118106

119107
to_remove = []
120108
pass_filter = 0
@@ -127,12 +115,11 @@ def add_flare_position(
127115
day_asp_ephemeris_cache = dict()
128116

129117
for i, row in enumerate(data):
130-
if filter_function(row): # and i < 200:
118+
peak_time = row[peak_time_colname]
119+
start_time = row[start_time_colname]
120+
end_time = row[end_time_colname]
121+
if filter_function(row): # and i < 60:
131122
pass_filter += 1
132-
peak_time = row[peak_time_colname]
133-
start_time = row[start_time_colname]
134-
end_time = row[end_time_colname]
135-
136123
day = peak_time.to_datetime().date()
137124

138125
if day in day_asp_ephemeris_cache:
@@ -249,21 +236,54 @@ def add_flare_position(
249236
solo = HeliographicStonyhurst(*solo_xyz, obstime=peak_time, representation_type="cartesian")
250237
with SphericalScreen(solo, only_off_disk=True):
251238
center_hpc = coord.transform_to(Helioprojective(observer=solo))
252-
253-
data[i]["flareposition_obs_hgs_x"] = solo_xyz[0].to(u.km)
254-
data[i]["flareposition_obs_hgs_y"] = solo_xyz[1].to(u.km)
255-
data[i]["flareposition_obs_hgs_z"] = solo_xyz[2].to(u.km)
256-
data[i]["flareposition_hp_tx"] = center_hpc.Tx.to(u.arcsec)
257-
data[i]["flareposition_hp_ty"] = center_hpc.Ty.to(u.arcsec)
239+
tx_list.append(center_hpc.Tx)
240+
ty_list.append(center_hpc.Ty)
241+
solo_x_list.append(solo.cartesian.x)
242+
solo_y_list.append(solo.cartesian.y)
243+
solo_z_list.append(solo.cartesian.z)
244+
peak_time_list.append(peak_time)
258245

259246
data[i]["_position_status"] = True
260247
data[i]["_position_message"] = "OK"
261248
except Exception as e:
262249
data[i]["_position_status"] = False
263250
data[i]["_position_message"] = f"Error: {type(e)}"
264-
logger.warn(f"Error calculating flare position for flare at time {start_time} : {end_time}: {e}")
251+
logger.warning(f"Error calculating flare position for flare at time {start_time} : {end_time}: {e}")
252+
tx_list.append(np.nan * u.arcsec)
253+
ty_list.append(np.nan * u.arcsec)
254+
solo_x_list.append(np.nan * u.km)
255+
solo_y_list.append(np.nan * u.km)
256+
solo_z_list.append(np.nan * u.km)
257+
peak_time_list.append(peak_time)
265258
else:
266259
to_remove.append(i)
260+
tx_list.append(np.nan * u.arcsec)
261+
ty_list.append(np.nan * u.arcsec)
262+
solo_x_list.append(np.nan * u.km)
263+
solo_y_list.append(np.nan * u.km)
264+
solo_z_list.append(np.nan * u.km)
265+
peak_time_list.append(peak_time)
266+
data[i]["_position_status"] = False
267+
data[i]["_position_message"] = "flare did not pass the filter function"
268+
269+
solo_times = Time(peak_time_list)
270+
hgs_coords = SkyCoord(
271+
u.Quantity(solo_x_list),
272+
u.Quantity(solo_y_list),
273+
u.Quantity(solo_z_list),
274+
frame=HeliographicStonyhurst(obstime=solo_times),
275+
representation_type="cartesian",
276+
)
277+
278+
hp_coords = SkyCoord(
279+
u.Quantity(tx_list), u.Quantity(ty_list), frame=Helioprojective(obstime=solo_times, observer=hgs_coords)
280+
)
281+
282+
data["location_hgs"] = hgs_coords
283+
# description="Flare location in Heliographic Stonyhurst coordinates"
284+
285+
data["location_hp"] = hp_coords
286+
# description="Flare location in Helioprojective coordinates"
267287

268288
if not keep_all_flares:
269289
data.remove_rows(to_remove)
@@ -276,6 +296,34 @@ def add_flare_position(
276296
f"finally {len(data)} flares remaining"
277297
)
278298

299+
def on_serialize(self, data):
300+
for col_name in ("location_hgs", "location_hp"):
301+
if col_name in data.colnames:
302+
icrs = data[col_name].icrs
303+
icrs_coord = SkyCoord(icrs.ra, icrs.dec, icrs.distance, frame="icrs")
304+
col_idx = data.colnames.index(col_name)
305+
data.remove_column(col_name)
306+
data.add_column(icrs_coord, name=col_name, index=col_idx)
307+
s = super()
308+
if hasattr(s, "on_serialize"):
309+
s.on_serialize(data)
310+
311+
def on_deserialize(self, data, *, peak_time_colname=None):
312+
peak_col = peak_time_colname or self.peak_time_colname
313+
if peak_col not in data.colnames:
314+
logger.warning(f"on_deserialize: column '{peak_col}' not found, skipping location transform")
315+
else:
316+
obstime = Time(data[peak_col])
317+
if "location_hgs" in data.colnames:
318+
data["location_hgs"] = data["location_hgs"].transform_to(HeliographicStonyhurst(obstime=obstime))
319+
if "location_hp" in data.colnames:
320+
data["location_hp"] = data["location_hp"].transform_to(
321+
Helioprojective(obstime=obstime, observer=data["location_hgs"])
322+
)
323+
s = super()
324+
if hasattr(s, "on_deserialize"):
325+
s.on_deserialize(data)
326+
279327

280328
class FlareSOOPMixin:
281329
"""_summary_"""
@@ -611,6 +659,7 @@ def __init__(self, *, service_type=0, service_subtype=0, ssid=3, data, month, **
611659

612660
self.name = FlarelistSDCLoc.NAME
613661
self.ssid = 3
662+
self.peak_time_colname = "peak_UTC"
614663

615664
def enhance_from_product(self, in_prod: GenericProduct):
616665
pass
@@ -628,7 +677,7 @@ def add_flare_position(cls, data, fido_client: STIXClient, *, month=None):
628677
peak_time_colname="peak_UTC",
629678
start_time_colname="start_UTC",
630679
end_time_colname="end_UTC",
631-
keep_all_flares=False,
680+
keep_all_flares=True,
632681
month=month,
633682
)
634683

@@ -745,6 +794,7 @@ def __init__(self, *, service_type=0, service_subtype=0, ssid=7, data, month, **
745794

746795
self.name = FlarelistSCLoc.NAME
747796
self.ssid = 7
797+
self.peak_time_colname = "peak_UTC"
748798

749799
def enhance_from_product(self, in_prod: GenericProduct):
750800
pass

stixcore/products/product.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ def __call__(self, *args, **kwargs):
315315
month=month,
316316
)
317317

318+
if hasattr(p, "on_deserialize") and callable(getattr(p, "on_deserialize")):
319+
p.on_deserialize(p.data)
320+
318321
# store the old fits header for later reuse
319322
if isinstance(p, (L1Mixin, L2Mixin)):
320323
p.fits_header = pri_header
@@ -562,6 +565,18 @@ def max_exposure(self):
562565
# default for FITS HEADER
563566
return 0.0
564567

568+
def on_serialize(self, data):
569+
"""Hook called before writing data to FITS. Mixins override and chain via super()."""
570+
s = super()
571+
if hasattr(s, "on_serialize"):
572+
s.on_serialize(data)
573+
574+
def on_deserialize(self, data):
575+
"""Hook called after reading data from FITS. Mixins override and chain via super()."""
576+
s = super()
577+
if hasattr(s, "on_deserialize"):
578+
s.on_deserialize(data)
579+
565580
def find_parent_products(self, root):
566581
"""
567582
Convenient way to get access to the parent products.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from datetime import date
2+
3+
import numpy as np
4+
import pytest
5+
from sunpy.coordinates import HeliographicStonyhurst, Helioprojective
6+
7+
import astropy.units as u
8+
from astropy.coordinates import SkyCoord
9+
from astropy.io import fits
10+
from astropy.table import QTable
11+
from astropy.time import Time
12+
13+
from stixcore.io.product_processors.fits.processors import FitsL3Processor
14+
from stixcore.products.level3.flarelist import FlarelistSDCLoc
15+
from stixcore.products.product import Product
16+
17+
N = 10
18+
19+
20+
@pytest.fixture
21+
def flare_data():
22+
peak_times = Time("2022-01-01T12:00:00") + np.arange(N) * 600 * u.s
23+
24+
hgs_coords = SkyCoord(
25+
lon=np.linspace(0, 30, N) * u.deg,
26+
lat=np.linspace(-5, 5, N) * u.deg,
27+
radius=np.ones(N) * 1.0 * u.AU,
28+
frame=HeliographicStonyhurst(obstime=peak_times),
29+
)
30+
31+
hp_coords = SkyCoord(
32+
Tx=np.linspace(-300, 300, N) * u.arcsec,
33+
Ty=np.linspace(-200, 200, N) * u.arcsec,
34+
frame=Helioprojective(obstime=peak_times, observer=hgs_coords),
35+
)
36+
37+
data = QTable()
38+
data["peak_UTC"] = peak_times
39+
data["start_UTC"] = peak_times - 60 * u.s
40+
data["end_UTC"] = peak_times + 60 * u.s
41+
data["duration"] = np.ones(N) * 120 * u.s
42+
data["lc_peak"] = np.ones((N, 5)) * u.ct / u.s
43+
data["location_hgs"] = hgs_coords
44+
data["location_hp"] = hp_coords
45+
46+
return data
47+
48+
49+
def test_flarelist_sdcloc_location_roundtrip(flare_data, tmp_path):
50+
prod = FlarelistSDCLoc(
51+
data=flare_data,
52+
month=date(2022, 1, 1),
53+
control=QTable(),
54+
)
55+
56+
# minimal header bypasses the Spice-dependent header generation chain
57+
header = fits.Header()
58+
header["LEVEL"] = "L3"
59+
header["STYPE"] = 0
60+
header["SSTYPE"] = 0
61+
header["SSID"] = 3
62+
header["DATE-BEG"] = "2022-01-01T00:00:00"
63+
prod.fits_header = header
64+
65+
# energy/additional_header_keywords are not set for freshly created products
66+
prod.energy = None
67+
prod._additional_header_keywords = []
68+
69+
orig_hgs_lon = prod.data["location_hgs"].lon.copy()
70+
orig_hgs_lat = prod.data["location_hgs"].lat.copy()
71+
orig_hp_tx = prod.data["location_hp"].Tx.copy()
72+
orig_hp_ty = prod.data["location_hp"].Ty.copy()
73+
74+
# write via FitsL3Processor — calls on_serialize internally, prod.data unchanged
75+
writer = FitsL3Processor(tmp_path)
76+
written_file_name = writer.write_fits(prod)
77+
assert len(written_file_name) == 1
78+
79+
# read back via Product factory — calls on_deserialize internally
80+
recovered = Product(written_file_name[0])
81+
82+
assert isinstance(recovered, FlarelistSDCLoc)
83+
assert u.allclose(recovered.data["location_hgs"].lon, orig_hgs_lon, atol=1e-6 * u.deg)
84+
assert u.allclose(recovered.data["location_hgs"].lat, orig_hgs_lat, atol=1e-6 * u.deg)
85+
assert u.allclose(recovered.data["location_hp"].Tx, orig_hp_tx, atol=1e-3 * u.arcsec)
86+
assert u.allclose(recovered.data["location_hp"].Ty, orig_hp_ty, atol=1e-3 * u.arcsec)

0 commit comments

Comments
 (0)