Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions examples/weather/forecast_solar_runs_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Compare recent solar-radiation forecast runs at a point, valid-time aligned.

Plots 1h surface downwelling shortwave flux (SSRD) at Zurich for the most recent
runs of Helios and ICON-EU. Every run is drawn against valid time, colored by
model with a light->dark gradient over init time (older runs lighter).

A black "ground truth" line is built from each Helios run's T+1h value (valid at
init + 1h), stitched across runs to approximate the analysis.

This uses cheap single-point queries, so it runs against the live latest
forecasts without needing a raised credit limit.
"""

import logging
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

from jua import JuaClient
from jua.types.geo import LatLon
from jua.weather import Models, Variables

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ZURICH = LatLon(lat=47.3769, lon=8.5417, label="Zurich")
VARIABLE = Variables.SURFACE_DOWNWELLING_SHORTWAVE_FLUX_SUM_1H

# Only keep runs initialized within this many hours of the latest available run.
WINDOW_HOURS = 6

# (label, matplotlib colormap) per model.
MODEL_CONFIG = {
Models.EPT2_HELIOS: ("Helios", "Oranges"),
Models.ICON_EU: ("ICON-EU", "Purples"),
}


def _naive_utc(ts) -> pd.Timestamp:
"""Normalize a timestamp to tz-naive UTC."""
t = pd.Timestamp(ts)
if t.tzinfo is not None:
t = t.tz_convert("UTC").tz_localize(None)
return t


def available_init_times(model_obj) -> list[pd.Timestamp]:
"""Sorted (ascending) tz-naive init times available for a model."""
available = model_obj.get_available_forecasts(limit=50)
return sorted(_naive_utc(f.init_time) for f in available.forecasts)


def ground_truth(model_obj, init_times, day_start, day_end):
"""Stitch each run's T+1h value into a pseudo ground-truth series."""
inits = [t for t in init_times if day_start <= t + timedelta(hours=1) <= day_end]
if not inits:
return None, None

forecast = model_obj.get_forecasts(
init_time=[t.to_pydatetime() for t in inits],
points=ZURICH,
variables=[VARIABLE],
prediction_timedelta=[np.timedelta64(60, "m")],
stream=False,
)
da = forecast[VARIABLE].squeeze()
valid_times = pd.to_datetime(da["init_time"].values) + pd.Timedelta(minutes=60)
values = np.asarray(da.values).ravel()
order = np.argsort(valid_times)
return valid_times[order], values[order]


def main():
client = JuaClient()

helios = client.weather.get_model(Models.EPT2_HELIOS)
helios_inits = available_init_times(helios)
if not helios_inits:
logger.warning("No Helios runs available; nothing to plot.")
return

# Frame the chart on the day of the latest Helios run.
latest_init = helios_inits[-1]
day_start = latest_init.normalize()
day_end = day_start + timedelta(days=1)

fig, ax = plt.subplots(figsize=(14, 7))
legend_handles = []

for model_enum, (label, cmap_name) in MODEL_CONFIG.items():
model_obj = client.weather.get_model(model_enum)
init_times = available_init_times(model_obj)
runs = [
t for t in init_times if t >= init_times[-1] - timedelta(hours=WINDOW_HOURS)
]
if not runs:
logger.warning(f"No runs found for {label}")
continue

cmap = plt.get_cmap(cmap_name)
n = len(runs)
logger.info(f"{label}: {n} runs from {runs[0]} to {runs[-1]}")

for i, init_time in enumerate(runs):
max_lead = int((day_end - init_time).total_seconds() / 3600) + 1
if max_lead <= 0:
continue

forecast = model_obj.get_forecasts(
init_time=init_time.to_pydatetime(),
points=ZURICH,
variables=[VARIABLE],
max_lead_time=min(48, max_lead),
stream=False,
)
da = forecast[VARIABLE].to_absolute_time().squeeze()
times = pd.to_datetime(da["time"].values)
values = np.asarray(da.values).ravel()

mask = times <= day_end
times, values = times[mask], values[mask]
if len(times) == 0:
continue

# Older runs lighter, newest darkest.
shade = 0.35 + 0.6 * (i / max(n - 1, 1))
ax.plot(times, values, color=cmap(shade), linewidth=1.6, alpha=0.9)

legend_handles.append(
Line2D([0], [0], color=cmap(0.85), linewidth=2.5, label=label)
)

# Pseudo ground truth: Helios T+1h stitched across runs.
gt_times, gt_values = ground_truth(helios, helios_inits, day_start, day_end)
if gt_times is not None:
ax.plot(gt_times, gt_values, color="black", linewidth=2.8, zorder=10)
legend_handles.append(
Line2D(
[0],
[0],
color="black",
linewidth=2.8,
label="Ground truth (Helios T+1h)",
)
)

ax.set_xlim(day_start, day_end)
ax.set_xlabel("Valid time (UTC)")
ax.set_ylabel(VARIABLE.display_name_with_unit)
ax.set_title(
f"SSRD 1h at {ZURICH.label} — recent runs (last {WINDOW_HOURS}h), "
"gradient = init time (light=older)"
)
ax.legend(handles=legend_handles)
ax.grid(True, alpha=0.3)
fig.autofmt_xdate()
plt.tight_layout()
plt.show()


if __name__ == "__main__":
main()
33 changes: 26 additions & 7 deletions src/jua/weather/_model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ class TemporalResolution:
"""Internal class to store model temporal resolution

Used for models with variable temporal resolution, such as EPT2.
Resolutions are expressed in hours and may be fractional (e.g. ``0.5``
for a 30-minute cadence such as EPT2 Helios).

Attributes:
default: The default temporal resolution for the model.
segments: The resolution of the model for prediction_timedelta ranges.
base: The default temporal resolution for the model, in hours.
special: The resolution of the model for prediction_timedelta ranges.
Defined as `(resolution, from_hour, to_hour)`, where the model has a
prediction every `resolution` hours when the prediction_timedelta is
in the interval [`from_hour`, `to_hour`].
"""

base: int
special: tuple[tuple[int, int, int], ...] = tuple()
base: float
special: tuple[tuple[float, int, int], ...] = tuple()

def __post_init__(self) -> None:
"""Checks that the special cases make sense"""
Expand All @@ -45,6 +47,9 @@ def __post_init__(self) -> None:
def num_prediction_timedeltas(self, from_hour: int, to_hour: int) -> int:
"""Determines the number of `prediction_timedeltas` in an interval.

Iterates internally in minutes so that sub-hourly resolutions
(e.g. a 30-minute cadence) are counted correctly.

Attributes:
from_hour: The start hour for the interval
to_hour: The end hour for the interval
Expand All @@ -56,13 +61,15 @@ def num_prediction_timedeltas(self, from_hour: int, to_hour: int) -> int:
)

num_timedeltas = 0
for h in range(from_hour, to_hour + 1):
for minute in range(from_hour * 60, to_hour * 60 + 1):
hour = minute / 60
resolution = self.base
for s_res, s_start, s_end in self.special:
if s_start <= h <= s_end:
if s_start <= hour <= s_end:
resolution = s_res
break
if h % resolution == 0:
resolution_minutes = round(resolution * 60)
if minute % resolution_minutes == 0:
num_timedeltas += 1

return num_timedeltas
Expand Down Expand Up @@ -153,6 +160,18 @@ class ModelMetaInfo:
full_forecasted_hours=480,
temporal_resolution=TemporalResolution(base=6, special=((1, 0, 10 * 24),)),
)
_MODEL_META_INFO[Models.EPT2_HELIOS] = ModelMetaInfo(
has_grid_access=True,
full_forecasted_hours=48,
forecasts_per_day=48,
temporal_resolution=TemporalResolution(base=0.5),
)
_MODEL_META_INFO[Models.EPT2_EUROPA] = ModelMetaInfo(
has_grid_access=True,
full_forecasted_hours=48,
forecasts_per_day=24,
temporal_resolution=TemporalResolution(base=1),
)
_MODEL_META_INFO[Models.AIFS] = ModelMetaInfo(
has_grid_access=True,
forecast_name_mapping="aifs",
Expand Down
2 changes: 2 additions & 0 deletions src/jua/weather/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Models(str, Enum):
EPT2_HRRR = "ept2_hrrr"
EPT2_RR = "ept2_rr"
EPT2_REASONING = "ept2_reasoning"
EPT2_HELIOS = "ept2_1_helios"
EPT2_EUROPA = "ept2_1_europa"
AIFS = "aifs"
AURORA = "aurora"
ECMWF_IFS_SINGLE = "ecmwf_ifs_single"
Expand Down
8 changes: 8 additions & 0 deletions src/jua/weather/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,18 @@ class Variables(Enum):
"surface_downwelling_shortwave_flux_sum_1h", "J / m^2", "ssrd", None
)

SURFACE_DOWNWELLING_SHORTWAVE_FLUX_SUM_30MIN = Variable(
"surface_downwelling_shortwave_flux_sum_30min", "J / m^2", None, None
)

SURFACE_DIRECT_DOWNWELLING_SHORTWAVE_FLUX_SUM_1H = Variable(
"surface_direct_downwelling_shortwave_flux_sum_1h", "J / m^2", "fdir", None
)

SURFACE_DIRECT_DOWNWELLING_SHORTWAVE_FLUX_SUM_30MIN = Variable(
"surface_direct_downwelling_shortwave_flux_sum_30min", "J / m^2", None, None
)

SURFACE_NET_DOWNWARD_SHORTWAVE_FLUX_SUM_1H = Variable(
"surface_net_downward_shortwave_flux_sum_1h", "J / m^2", "ssr", None
)
Expand Down
10 changes: 8 additions & 2 deletions tests/functional/test_forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,14 @@
Models.ICON_EU: datetime(2026, 2, 9, 0, 0, 0),
}

ALL_MODELS = list(Models)
INTERNAL_MODELS = [m for m in Models if get_model_meta_info(m).has_grid_access]
SOLAR_ONLY_MODELS = {Models.EPT2_HELIOS}

ALL_MODELS = [m for m in Models if m not in SOLAR_ONLY_MODELS]
INTERNAL_MODELS = [
m
for m in Models
if get_model_meta_info(m).has_grid_access and m not in SOLAR_ONLY_MODELS
]


def get_forecast_date(model: Models) -> datetime:
Expand Down
6 changes: 4 additions & 2 deletions tests/weather/test_temporal_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
((1, 0, 24), (3, 24, 48)),
[((0, 48), 33), ((0, 72), 37)],
),
# Sub-hourly (30min) resolution: 2 steps per hour
(0.5, tuple(), [((0, 1), 3), ((0, 2), 5), ((0, 48), 97)]),
],
)
def test_temporal_resolution(
base: int,
special: tuple[tuple[int, int, int]],
base: float,
special: tuple[tuple[float, int, int]],
test_cases: list[tuple[int, int], int],
) -> None:
tr = TemporalResolution(base=base, special=special)
Expand Down
Loading