Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@ packages = [{include = "stationbench"}]

[tool.poetry.dependencies]
python = ">=3.11,<3.13"
xarray = {extras = ["io"], version = "^2025.03.0"}
xarray = {extras = ["io"], version = "^2026.02.0"}
pandas = "^2.2.0"
dask = {extras = ["distributed"], version = "^2024.1.1"}
dask = {extras = ["distributed"], version = "^2026.1.2"}
plotly = "^5.18.0"
zarr = "^3.0.4"
gcsfs = "^2024.2.0"
gcsfs = "^2026.1.0"
nbformat = "^5.10.4"
kaleido = "0.2.1"
wandb = "^0.19.7"
scoringrules = {version = "^0.8", optional = true}

[tool.poetry.extras]
ensemble = ["scoringrules"]

[tool.poetry.group.dev.dependencies]
ruff = "0.4.2"
Expand Down
3 changes: 3 additions & 0 deletions stationbench/calculate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,11 @@ def generate_benchmarks(
logger.info("Calculating metrics")
metrics_list = []

has_ensemble = "member" in forecast.dims
# Calculate each metric
for metric in AVAILABLE_METRICS.values():
if metric.is_ensemble and not has_ensemble:
continue
metrics_list.append(metric.compute(forecast, stations))
# Merge all metrics into one dataset
return xr.merge(metrics_list)
Expand Down
15 changes: 12 additions & 3 deletions stationbench/compare_forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,22 @@ def process_temporal_and_spatial_metrics(
temporal_regions.append(temp_dataset)

# Concatenate datasets for each region
combined_temporal_regions = xr.concat(temporal_regions, dim="region")
combined_temporal_regions = xr.concat(
temporal_regions,
dim="region",
)
temporal_metrics.append(combined_temporal_regions)

# Concatenate datasets for each metric
combined_temporal_metric = xr.concat(temporal_metrics, dim="metric")
combined_temporal_metric = xr.concat(
temporal_metrics,
dim="metric",
)
temporal_metric_datasets.append(combined_temporal_metric)
combined_spatial_metric = xr.concat(spatial_metrics, dim="metric")
combined_spatial_metric = xr.concat(
spatial_metrics,
dim="metric",
)
spatial_metric_datasets.append(combined_spatial_metric)
return temporal_metric_datasets, spatial_metric_datasets

Expand Down
36 changes: 36 additions & 0 deletions stationbench/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
class Metric(ABC):
"""Base class for all metrics."""

is_ensemble: bool = False

@abstractmethod
def compute(self, forecast: xr.Dataset, ground_truth: xr.Dataset) -> xr.Dataset:
"""Compute metric between forecast and ground truth."""
Expand Down Expand Up @@ -56,7 +58,41 @@ def compute(self, forecast: xr.Dataset, ground_truth: xr.Dataset) -> xr.Dataset:
return xr.Dataset(mbe).expand_dims(metric=["mbe"])


class CRPSEnsemble(Metric):
is_ensemble = True

def compute(self, forecast: xr.Dataset, ground_truth: xr.Dataset) -> xr.Dataset:
"""Compute Continuous Ranked Probability Score for ensemble forecasts.

For each station (s) and lead time (l):
CRPS(s,l) = 1/T * sum_t[crps_ensemble(o_{s,t}, f_{s,t,l,:})]

where:
- t: time index
- T: total number of time steps
- f: ensemble forecast with member dimension
- o: observation (ground truth)

Uses scoringrules.crps_ensemble with the ensemble member dimension.
"""
import scoringrules as sr

crps = {}
for var in forecast.data_vars:
crps[var] = xr.apply_ufunc(
sr.crps_ensemble,
ground_truth[var],
forecast[var],
input_core_dims=[[], ["member"]],
dask="parallelized",
output_dtypes=[float],
).mean("init_time", skipna=True)

return xr.Dataset(crps).expand_dims(metric=["crps"])


AVAILABLE_METRICS = {
"rmse": RMSE(),
"mbe": MBE(),
"crps": CRPSEnsemble(),
}
94 changes: 94 additions & 0 deletions tests/test_calculate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,100 @@ def test_mbe_calculation(sample_forecast, sample_stations):
)


@pytest.fixture
def sample_ensemble_forecast():
"""Create a sample ensemble forecast dataset."""
times = pd.date_range("2022-01-01", "2022-01-02", freq="24h")
lead_times = pd.timedelta_range("0h", "24h", freq="24h")
stations = ["ST1"]
members = np.arange(5)
lats = [50.0]
lons = [5.0]

ds = xr.Dataset(
data_vars={
"2m_temperature": (
("time", "prediction_timedelta", "station_id", "member"),
np.random.randn(
len(times), len(lead_times), len(stations), len(members)
),
),
"10m_wind_speed": (
("time", "prediction_timedelta", "station_id", "member"),
np.random.randn(
len(times), len(lead_times), len(stations), len(members)
),
),
},
coords={
"time": times,
"prediction_timedelta": lead_times,
"station_id": stations,
"member": members,
"latitude": ("station_id", lats),
"longitude": ("station_id", lons),
},
)
return ds


def test_ensemble_pipeline(sample_ensemble_forecast, sample_stations):
"""Test the full pipeline with ensemble forecast includes CRPS."""
sr = pytest.importorskip("scoringrules")

# Rename dims to match post-prepare_forecast format
forecast = sample_ensemble_forecast.rename(
{"time": "init_time", "prediction_timedelta": "lead_time"}
)
forecast.coords["valid_time"] = forecast.init_time + forecast.lead_time

stations = sample_stations.copy()

benchmarks = generate_benchmarks(forecast=forecast, stations=stations)

assert "crps" in benchmarks.metric.values
assert "rmse" in benchmarks.metric.values
assert "mbe" in benchmarks.metric.values


def test_crps_calculation(sample_ensemble_forecast, sample_stations):
"""Test CRPS calculation with a known case: perfect ensemble should yield CRPS ~0."""
sr = pytest.importorskip("scoringrules")

forecast = sample_ensemble_forecast.rename(
{"time": "init_time", "prediction_timedelta": "lead_time"}
)
forecast.coords["valid_time"] = forecast.init_time + forecast.lead_time

# Set all ensemble members to the same value as observations
forecast["10m_wind_speed"][:] = 3.0
stations = sample_stations.copy()
stations["10m_wind_speed"][:] = 3.0

benchmarks = generate_benchmarks(forecast=forecast, stations=stations)

np.testing.assert_allclose(
benchmarks.sel(metric="crps")["10m_wind_speed"].values,
0.0,
atol=1e-6,
err_msg="CRPS should be ~0 for a perfect ensemble",
)


def test_no_ensemble_skips_crps(sample_forecast, sample_stations):
"""Test that CRPS is skipped when forecast has no member dimension."""
forecast = sample_forecast.copy()
forecast = forecast.rename({"time": "init_time"})
forecast = forecast.rename({"prediction_timedelta": "lead_time"})
forecast.coords["valid_time"] = forecast.init_time + forecast.lead_time

benchmarks = generate_benchmarks(forecast=forecast, stations=sample_stations)

assert "crps" not in benchmarks.metric.values
assert "rmse" in benchmarks.metric.values
assert "mbe" in benchmarks.metric.values


def test_invalid_path():
"""Test handling of invalid file paths."""
with pytest.raises(Exception): # Should raise some kind of file not found error
Expand Down