Skip to content

Commit 7703af6

Browse files
authored
Fix rechunking during canonicalisation (#30)
* Fix rechunking during canonicalisation * Fix ruff version
1 parent e605ed7 commit 7703af6

10 files changed

Lines changed: 33 additions & 37 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dev = [
3838
"mypy~=1.14",
3939
"pandas-stubs~=2.2.0",
4040
"pytest~=8.3",
41-
"ruff~=0.8",
41+
"ruff~=0.15",
4242
"types-requests~=2.32.0.20241016",
4343
"types-tqdm~=4.67.0.20250301",
4444
"universal-pathlib~=0.2.0",

src/climatebenchpress/data_loader/__init__.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def open_downloaded_canonicalized_dataset(
5151
ds = canon.canonicalize_dataset(ds)
5252

5353
with monitor.progress_bar(progress):
54-
ds.to_zarr(standardized, encoding=dict(), compute=False).compute()
54+
ds.to_zarr(standardized, compute=False).compute()
5555

5656
return xr.open_dataset(standardized, chunks=dict(), engine="zarr")
5757

@@ -96,20 +96,9 @@ def open_downloaded_tiny_canonicalized_dataset(
9696
ds = canon.canonical_tiny_dataset(ds, slices=slices)
9797
# Rechunk the data because "tiny-fication" can lead to inconsistent or
9898
# suboptimal chunking.
99-
ds = _rechunk_dataset(ds)
99+
ds = ds.chunk(-1)
100100

101101
with monitor.progress_bar(progress):
102-
ds.to_zarr(
103-
standardized, encoding=dict(), compute=False, consolidated=True
104-
).compute()
102+
ds.to_zarr(standardized, compute=False, consolidated=True).compute()
105103

106104
return xr.open_dataset(standardized, chunks=dict(), engine="zarr")
107-
108-
109-
def _rechunk_dataset(ds: xr.Dataset) -> xr.Dataset:
110-
rechunked = ds.copy()
111-
for var_name in ds.data_vars:
112-
if hasattr(ds[var_name].data, "chunks"):
113-
rechunked[var_name] = ds[var_name].chunk("auto")
114-
115-
return rechunked

src/climatebenchpress/data_loader/datasets/cams.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ def download(download_path: Path, progress: bool = True):
4343

4444
@staticmethod
4545
def open(download_path: Path) -> xr.Dataset:
46-
ds = xr.open_dataset(download_path / Path(NO2_FILE).name).chunk(-1)
46+
ds = (
47+
xr.open_dataset(download_path / Path(NO2_FILE).name)
48+
.drop_encoding()
49+
.chunk(-1)
50+
)
4751

4852
# valid_time contains actual dates, whereas step is the seconds (in simulated time)
4953
# since the model as been initialised.

src/climatebenchpress/data_loader/datasets/cmip6/abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ def download_with(
6262
ds = ds[variable_selector]
6363

6464
with monitor.progress_bar(progress):
65-
ds.to_zarr(downloadfile, mode="w", encoding=dict(), compute=False).compute()
65+
ds.to_zarr(downloadfile, mode="w", compute=False).compute()
6666

6767
donefile.touch()
6868

6969
@staticmethod
7070
def open(download_path: Path) -> xr.Dataset:
71-
return xr.open_zarr(download_path / "download.zarr")
71+
return xr.open_zarr(download_path / "download.zarr").drop_encoding().chunk(-1)
7272

7373
@lru_cache
7474
@staticmethod

src/climatebenchpress/data_loader/datasets/era5.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,18 @@ def download(download_path: Path, progress: bool = True):
4444
"10m_u_component_of_wind",
4545
"10m_v_component_of_wind",
4646
]
47-
].chunk(-1)
47+
]
4848
# Needed to make the dataset CF-compliant.
4949
ds.time.attrs["standard_name"] = "time"
5050
ds.longitude.attrs["axis"] = "X"
5151
ds.latitude.attrs["axis"] = "Y"
5252
with monitor.progress_bar(progress):
53-
ds.to_zarr(downloadfile, mode="w", encoding=dict(), compute=False).compute()
53+
ds.to_zarr(downloadfile, mode="w", compute=False).compute()
5454
donefile.touch()
5555

5656
@staticmethod
5757
def open(download_path: Path) -> xr.Dataset:
58-
return xr.open_zarr(download_path / "download.zarr")
58+
return xr.open_zarr(download_path / "download.zarr").drop_encoding().chunk(-1)
5959

6060

6161
if __name__ == "__main__":

src/climatebenchpress/data_loader/datasets/esa_biomass_cci.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def download(download_path: Path, progress: bool = True):
4747
@staticmethod
4848
def open(download_path: Path) -> xr.Dataset:
4949
# Need string conversion for argument to be interpreted as a glob pattern.
50-
ds = xr.open_mfdataset(str(download_path / "*.nc"))
50+
ds = xr.open_mfdataset(str(download_path / "*.nc")).drop_encoding()
5151
# Needed to make the dataset CF-compliant.
5252
ds.lon.attrs["axis"] = "X"
5353
ds.lat.attrs["axis"] = "Y"

src/climatebenchpress/data_loader/datasets/ifs_humidity.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,11 @@ def download(download_path: Path, progress: bool = True):
3939
)
4040
downloadfile = download_path / "ifs_humidity.zarr"
4141
with monitor.progress_bar(progress):
42-
ds_regridded.to_zarr(
43-
downloadfile, mode="w", encoding=dict(), compute=False
44-
).compute()
42+
ds_regridded.to_zarr(downloadfile, mode="w", compute=False).compute()
4543

4644
@staticmethod
4745
def open(download_path: Path) -> xr.Dataset:
48-
ds = xr.open_dataset(download_path / "ifs_humidity.zarr")
46+
ds = xr.open_zarr(download_path / "ifs_humidity.zarr").drop_encoding()
4947
num_levels = ds["level"].size
5048
ds = ds.isel(time=slice(0, 1)).chunk(
5149
{

src/climatebenchpress/data_loader/datasets/ifs_uncompressed.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ def download(download_path: Path, progress: bool = True):
3939
)
4040
downloadfile = download_path / "ifs_uncompressed.zarr"
4141
with monitor.progress_bar(progress):
42-
ds_regridded.to_zarr(
43-
downloadfile, mode="w", encoding=dict(), compute=False
44-
).compute()
42+
ds_regridded.to_zarr(downloadfile, mode="w", compute=False).compute()
4543

4644
@staticmethod
4745
def open(download_path: Path) -> xr.Dataset:
48-
ds = xr.open_dataset(download_path / "ifs_uncompressed.zarr")
46+
ds = (
47+
xr.open_dataset(download_path / "ifs_uncompressed.zarr")
48+
.drop_encoding()
49+
.chunk(-1)
50+
)
4951

5052
# Needed to make the dataset CF-compliant.
5153
ds.longitude.attrs["axis"] = "X"
@@ -106,7 +108,11 @@ def load_hplp_data(leveltype=None, gridtype=None, step=None, remap=False):
106108
return xr.open_dataset(
107109
"reference://",
108110
engine="zarr",
109-
backend_kwargs=dict(storage_options=dict(fo=ref, asynchronous=False)),
111+
backend_kwargs=dict(
112+
storage_options=dict(
113+
fo=ref, asynchronous=False, remote_options=dict(ssl=False)
114+
)
115+
),
110116
consolidated=False,
111117
)
112118

@@ -149,9 +155,9 @@ def regrid_to_regular(ds, in_grid, out_grid):
149155
out_data[var].append(r)
150156

151157
dx = out_grid["grid"][0]
152-
assert (
153-
out_grid["grid"][0] == out_grid["grid"][1]
154-
), "Only grids with equal latitude and longitude spacing are supported."
158+
assert out_grid["grid"][0] == out_grid["grid"][1], (
159+
"Only grids with equal latitude and longitude spacing are supported."
160+
)
155161
lats = np.linspace(90, -90, int(180 / dx) + 1)
156162
lons = np.linspace(0, 360 - dx, int(360 / dx))
157163
coords = {

src/climatebenchpress/data_loader/datasets/nextgems.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def download(download_path: Path, progress: bool = True):
7171
ds.lat.attrs["axis"] = "Y"
7272

7373
with monitor.progress_bar(progress):
74-
ds.to_zarr(downloadfile, mode="w", encoding=dict(), compute=False).compute()
74+
ds.to_zarr(downloadfile, mode="w", compute=False).compute()
7575
donefile.touch()
7676

7777
@staticmethod
7878
def open(download_path: Path) -> xr.Dataset:
79-
return xr.open_zarr(download_path / "download.zarr")
79+
return xr.open_zarr(download_path / "download.zarr").drop_encoding().chunk(-1)
8080

8181

8282
def _get_nn_lon_lat_index(nside, lons, lats):

tests/test_virtual.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def download(download_path: Path, progress: bool = True):
4040
ds.to_zarr(
4141
download_path / "download.zarr",
4242
mode="w",
43-
encoding=dict(),
4443
compute=False,
4544
).compute()
4645

0 commit comments

Comments
 (0)