diff --git a/pyproject.toml b/pyproject.toml
index 899bbd2..82ad176 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,15 +9,19 @@ packages = [{include = "stationbench"}]
 
 [tool.poetry.dependencies]
 python = ">=3.11,<3.13"
-xarray = {extras = ["io"], version = "^2025.03.0"}
+xarray = {extras = ["io"], version = "^2026.02.0"}
 pandas = "^2.2.0"
-dask = {extras = ["distributed"], version = "^2024.1.1"}
+dask = {extras = ["distributed"], version = "^2026.1.2"}
 plotly = "^5.18.0"
 zarr = "^3.0.4"
-gcsfs = "^2024.2.0"
+gcsfs = "^2026.1.0"
 nbformat = "^5.10.4"
 kaleido = "0.2.1"
 wandb = "^0.19.7"
+scoringrules = {version = "^0.8", optional = true}
+
+[tool.poetry.extras]
+ensemble = ["scoringrules"]
 
 [tool.poetry.group.dev.dependencies]
 ruff = "0.4.2"
diff --git a/stationbench/calculate_metrics.py b/stationbench/calculate_metrics.py
index 504ec79..3a0a970 100644
--- a/stationbench/calculate_metrics.py
+++ b/stationbench/calculate_metrics.py
@@ -219,8 +219,14 @@ def generate_benchmarks(
 
     logger.info("Calculating metrics")
     metrics_list = []
+    # Ensemble-only metrics require a "member" dimension on the forecast.
+    has_ensemble = "member" in forecast.dims
     # Calculate each metric
     for metric in AVAILABLE_METRICS.values():
+        if metric.is_ensemble and not has_ensemble:
+            continue
+        # NOTE(review): deterministic metrics also receive the full ensemble
+        # here; confirm they reduce over "member" as intended.
         metrics_list.append(metric.compute(forecast, stations))
     # Merge all metrics into one dataset
     return xr.merge(metrics_list)
diff --git a/stationbench/compare_forecasts.py b/stationbench/compare_forecasts.py
index 39669c1..f52efc1 100644
--- a/stationbench/compare_forecasts.py
+++ b/stationbench/compare_forecasts.py
@@ -115,13 +115,22 @@ def process_temporal_and_spatial_metrics(
                 temporal_regions.append(temp_dataset)
 
             # Concatenate datasets for each region
-            combined_temporal_regions = xr.concat(temporal_regions, dim="region")
+            combined_temporal_regions = xr.concat(
+                temporal_regions,
+                dim="region",
+            )
             temporal_metrics.append(combined_temporal_regions)
 
         # Concatenate datasets for each metric
-        combined_temporal_metric = xr.concat(temporal_metrics, dim="metric")
+        combined_temporal_metric = xr.concat(
+            temporal_metrics,
+            dim="metric",
+        )
         temporal_metric_datasets.append(combined_temporal_metric)
-        combined_spatial_metric = xr.concat(spatial_metrics, dim="metric")
+        combined_spatial_metric = xr.concat(
+            spatial_metrics,
+            dim="metric",
+        )
         spatial_metric_datasets.append(combined_spatial_metric)
 
     return temporal_metric_datasets, spatial_metric_datasets
diff --git a/stationbench/utils/metrics.py b/stationbench/utils/metrics.py
index ca044f8..7f2d461 100644
--- a/stationbench/utils/metrics.py
+++ b/stationbench/utils/metrics.py
@@ -5,6 +5,8 @@
 class Metric(ABC):
     """Base class for all metrics."""
 
+    is_ensemble: bool = False
+
     @abstractmethod
     def compute(self, forecast: xr.Dataset, ground_truth: xr.Dataset) -> xr.Dataset:
         """Compute metric between forecast and ground truth."""
@@ -56,7 +58,47 @@ def compute(self, forecast: xr.Dataset, ground_truth: xr.Dataset) -> xr.Dataset:
         return xr.Dataset(mbe).expand_dims(metric=["mbe"])
 
 
+class CRPSEnsemble(Metric):
+    """CRPS for ensemble forecasts (requires a ``member`` dimension)."""
+
+    is_ensemble = True
+
+    def compute(self, forecast: xr.Dataset, ground_truth: xr.Dataset) -> xr.Dataset:
+        """Compute Continuous Ranked Probability Score for ensemble forecasts.
+
+        For each station (s) and lead time (l):
+            CRPS(s,l) = 1/T * sum_t[crps_ensemble(o_{s,t}, f_{s,t,l,:})]
+
+        where:
+        - t: time index
+        - T: total number of time steps
+        - f: ensemble forecast with member dimension
+        - o: observation (ground truth)
+
+        Uses scoringrules.crps_ensemble with the ensemble member dimension.
+        """
+        # Imported lazily: scoringrules is an optional dependency ("ensemble" extra).
+        import scoringrules as sr
+
+        crps = {}
+        for var in forecast.data_vars:
+            crps[var] = xr.apply_ufunc(
+                sr.crps_ensemble,
+                ground_truth[var],
+                forecast[var],
+                input_core_dims=[[], ["member"]],
+                dask="parallelized",
+                # The core dimension must live in a single chunk; let dask
+                # rechunk "member" instead of raising on multi-chunk ensembles.
+                dask_gufunc_kwargs={"allow_rechunk": True},
+                output_dtypes=[float],
+            ).mean("init_time", skipna=True)
+
+        return xr.Dataset(crps).expand_dims(metric=["crps"])
+
+
 AVAILABLE_METRICS = {
     "rmse": RMSE(),
     "mbe": MBE(),
+    "crps": CRPSEnsemble(),
 }
diff --git a/tests/test_calculate_metrics.py b/tests/test_calculate_metrics.py
index 0b905e8..bf2bfee 100644
--- a/tests/test_calculate_metrics.py
+++ b/tests/test_calculate_metrics.py
@@ -265,6 +265,100 @@ def test_mbe_calculation(sample_forecast, sample_stations):
     )
 
 
+@pytest.fixture
+def sample_ensemble_forecast():
+    """Create a sample ensemble forecast dataset."""
+    times = pd.date_range("2022-01-01", "2022-01-02", freq="24h")
+    lead_times = pd.timedelta_range("0h", "24h", freq="24h")
+    stations = ["ST1"]
+    members = np.arange(5)
+    lats = [50.0]
+    lons = [5.0]
+
+    ds = xr.Dataset(
+        data_vars={
+            "2m_temperature": (
+                ("time", "prediction_timedelta", "station_id", "member"),
+                np.random.randn(
+                    len(times), len(lead_times), len(stations), len(members)
+                ),
+            ),
+            "10m_wind_speed": (
+                ("time", "prediction_timedelta", "station_id", "member"),
+                np.random.randn(
+                    len(times), len(lead_times), len(stations), len(members)
+                ),
+            ),
+        },
+        coords={
+            "time": times,
+            "prediction_timedelta": lead_times,
+            "station_id": stations,
+            "member": members,
+            "latitude": ("station_id", lats),
+            "longitude": ("station_id", lons),
+        },
+    )
+    return ds
+
+
+def test_ensemble_pipeline(sample_ensemble_forecast, sample_stations):
+    """Test the full pipeline with ensemble forecast includes CRPS."""
+    pytest.importorskip("scoringrules")
+
+    # Rename dims to match post-prepare_forecast format
+    forecast = sample_ensemble_forecast.rename(
+        {"time": "init_time", "prediction_timedelta": "lead_time"}
+    )
+    forecast.coords["valid_time"] = forecast.init_time + forecast.lead_time
+
+    stations = sample_stations.copy()
+
+    benchmarks = generate_benchmarks(forecast=forecast, stations=stations)
+
+    assert "crps" in benchmarks.metric.values
+    assert "rmse" in benchmarks.metric.values
+    assert "mbe" in benchmarks.metric.values
+
+
+def test_crps_calculation(sample_ensemble_forecast, sample_stations):
+    """Test CRPS calculation with a known case: perfect ensemble should yield CRPS ~0."""
+    pytest.importorskip("scoringrules")
+
+    forecast = sample_ensemble_forecast.rename(
+        {"time": "init_time", "prediction_timedelta": "lead_time"}
+    )
+    forecast.coords["valid_time"] = forecast.init_time + forecast.lead_time
+
+    # Set all ensemble members to the same value as observations
+    forecast["10m_wind_speed"][:] = 3.0
+    stations = sample_stations.copy()
+    stations["10m_wind_speed"][:] = 3.0
+
+    benchmarks = generate_benchmarks(forecast=forecast, stations=stations)
+
+    np.testing.assert_allclose(
+        benchmarks.sel(metric="crps")["10m_wind_speed"].values,
+        0.0,
+        atol=1e-6,
+        err_msg="CRPS should be ~0 for a perfect ensemble",
+    )
+
+
+def test_no_ensemble_skips_crps(sample_forecast, sample_stations):
+    """Test that CRPS is skipped when forecast has no member dimension."""
+    forecast = sample_forecast.copy()
+    forecast = forecast.rename({"time": "init_time"})
+    forecast = forecast.rename({"prediction_timedelta": "lead_time"})
+    forecast.coords["valid_time"] = forecast.init_time + forecast.lead_time
+
+    benchmarks = generate_benchmarks(forecast=forecast, stations=sample_stations)
+
+    assert "crps" not in benchmarks.metric.values
+    assert "rmse" in benchmarks.metric.values
+    assert "mbe" in benchmarks.metric.values
+
+
 def test_invalid_path():
     """Test handling of invalid file paths."""
     with pytest.raises(Exception):  # Should raise some kind of file not found error