
Commit 8076088

Donglai Wei and claude committed
Zarr I/O, val skeleton precompute, data-check visualization, lazy zarr label_aux
- read_volume/save_volume: add zarr format support (split store/subkey paths)
- LazyZarrVolumeDataset: add label_aux_paths, load non-zarr files eagerly
- data_factory: precompute val skeleton alongside train, val uses same dataset class
- VisualizationCallback: log image+label on first batch (data_check, no prediction)
- VisualizationCallback: head="all" visualizes all heads separately
- config_io: allow head="all" in validation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ecd3520 commit 8076088

5 files changed

Lines changed: 116 additions & 33 deletions
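As a quick orientation before the diffs: the zarr support added to read_volume/save_volume splits a path at the ".zarr" boundary into a store path and an optional subkey. A minimal usage sketch of that convention (file names here are hypothetical, not from the commit):

import numpy as np
from connectomics.data.io import read_volume, save_volume

vol = np.zeros((8, 64, 64), dtype=np.float32)

# "out.zarr/seg" is split into store "out.zarr" and subkey "seg".
save_volume("out.zarr/seg", vol)    # writes array "seg" inside the out.zarr group
save_volume("plain.zarr", vol)      # writes a bare zarr array store
seg = read_volume("out.zarr/seg")   # reads the array back as a numpy array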


connectomics/data/io/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
     read_images,
     read_volume,
     save_volume,
+    volume_exists,
     write_hdf5,
 )
 from .transforms import (
@@ -32,6 +33,7 @@
     "read_volume",
     "save_volume",
     "get_vol_shape",
+    "volume_exists",
     "LoadVolumed",
     "SaveVolumed",
     "TileLoaderd",

connectomics/data/io/io.py

Lines changed: 60 additions & 19 deletions
@@ -13,6 +13,7 @@
 import glob
 import logging
 import os
+from pathlib import Path
 from typing import List, Optional, Union

 import h5py
@@ -35,11 +36,9 @@ def _detect_format(filename: str) -> str:
     Returns canonical format string:
     'h5', 'tiff', 'png', 'nifti', 'zarr'.
     """
-    if ".zarr" in filename:
-        return "zarr"
     if filename.endswith(".nii.gz"):
         return "nifti"
-    suffix = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
+    suffix = Path(filename).suffix.lower().lstrip(".")
     _SUFFIX_MAP = {
         "h5": "h5",
         "hdf5": "h5",
@@ -49,12 +48,22 @@ def _detect_format(filename: str) -> str:
         "nii": "nifti",
     }
     fmt = _SUFFIX_MAP.get(suffix)
-    if fmt is None:
-        raise ValueError(
-            f"Unrecognizable file format for {filename}. "
-            f"Expected: h5, hdf5, tif, tiff, png, nii, nii.gz"
-        )
-    return fmt
+    if fmt is not None:
+        return fmt
+    if ".zarr" in filename:
+        return "zarr"
+    raise ValueError(
+        f"Unrecognizable file format for {filename}. "
+        f"Expected: h5, hdf5, tif, tiff, png, nii, nii.gz, zarr"
+    )
+
+
+def _split_zarr_path(filename: str) -> tuple[str, Optional[str]]:
+    """Split a zarr path into store path and optional subkey."""
+    zarr_idx = filename.index(".zarr")
+    zarr_path = filename[: zarr_idx + 5]
+    sub_key = filename[zarr_idx + 5 :].strip("/") or None
+    return zarr_path, sub_key


 # =============================================================================
@@ -346,10 +355,7 @@ def read_volume(
     elif fmt == "zarr":
         import zarr

-        # Path may be "dir.zarr/subkey" — split at .zarr boundary.
-        zarr_idx = filename.index(".zarr")
-        zarr_path = filename[: zarr_idx + 5]
-        sub_key = filename[zarr_idx + 5 :].strip("/") or None
+        zarr_path, sub_key = _split_zarr_path(filename)
         store = zarr.open(zarr_path, mode="r")
         arr = store[sub_key] if sub_key else store
         data = np.asarray(arr)
@@ -374,7 +380,7 @@ def save_volume(
     filename: str,
     volume: np.ndarray,
     dataset: str = "main",
-    file_format: str = "h5",
+    file_format: Optional[str] = None,
 ) -> None:
     """Save volumetric data in specified format.

@@ -384,9 +390,27 @@
         dataset: Dataset name for HDF5 format.
         file_format: 'h5', 'tiff', 'png', 'nii', 'nii.gz'.
     """
+    file_format = file_format or _detect_format(filename)
+
     if file_format == "h5":
         write_hdf5(filename, volume, dataset=dataset)

+    elif file_format == "zarr":
+        import zarr
+
+        zarr_path, sub_key = _split_zarr_path(filename)
+        if sub_key:
+            group = zarr.open_group(zarr_path, mode="a")
+            group.create_dataset(sub_key, data=volume, overwrite=True)
+        else:
+            array = zarr.open(
+                zarr_path,
+                mode="w",
+                shape=volume.shape,
+                dtype=volume.dtype,
+            )
+            array[...] = volume
+
     elif file_format in ("tif", "tiff"):
         import tifffile

@@ -410,7 +434,7 @@ def save_volume(

     else:
         raise ValueError(
-            f"Unsupported format: {file_format}. " f"Expected: h5, tiff, png, nii, nii.gz"
+            f"Unsupported format: {file_format}. " f"Expected: h5, zarr, tiff, png, nii, nii.gz"
         )


@@ -436,17 +460,19 @@ def get_vol_shape(
     Returns shape consistent with what read_volume would
     produce: (D, H, W) or (C, D, H, W).
     """
-    if not os.path.exists(filename):
-        raise FileNotFoundError(f"File not found: {filename}")
-
     fmt = _detect_format(filename)

     if fmt == "zarr":
         try:
             import zarr
         except ModuleNotFoundError as exc:
             raise ModuleNotFoundError("zarr required. pip install zarr") from exc
-        obj = zarr.open(filename, mode="r")
+        zarr_path, sub_key = _split_zarr_path(filename)
+        if not os.path.exists(zarr_path):
+            raise FileNotFoundError(f"File not found: {zarr_path}")
+        obj = zarr.open(zarr_path, mode="r")
+        if sub_key:
+            return tuple(obj[sub_key].shape)
         if hasattr(obj, "shape"):
             return tuple(obj.shape)
         if dataset is not None:
@@ -456,6 +482,9 @@
             raise ValueError(f"No arrays in zarr group: {filename}")
         return tuple(obj[keys[0]].shape)

+    if not os.path.exists(filename):
+        raise FileNotFoundError(f"File not found: {filename}")
+
     if fmt == "h5":
         with h5py.File(filename, "r") as f:
             if dataset is None:
@@ -483,3 +512,15 @@
         return _get_nifti_shape(filename)

     raise ValueError(f"Unsupported format: {fmt}")
+
+
+def volume_exists(
+    filename: str,
+    dataset: Optional[str] = None,
+) -> bool:
+    """Return True when a volume path can be opened by this IO layer."""
+    try:
+        get_vol_shape(filename, dataset=dataset)
+    except (FileNotFoundError, KeyError, ValueError, OSError):
+        return False
+    return True
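The new volume_exists helper shown above simply asks get_vol_shape whether a path can be opened, so it understands the same zarr store/subkey convention. A small sketch of the intended use (paths are hypothetical):

from connectomics.data.io import get_vol_shape, volume_exists

if volume_exists("out.zarr/seg"):          # True once the "seg" array has been written
    print(get_vol_shape("out.zarr/seg"))   # e.g. (8, 64, 64), read from metadata only
if not volume_exists("missing.zarr/seg"):  # missing store or key returns False, no exception
    print("cache not found, will precompute")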

connectomics/data/processing/distance.py

Lines changed: 23 additions & 9 deletions
@@ -2,6 +2,7 @@

 from typing import Dict, Optional, Tuple

+import cc3d
 import kimimaro
 import numpy as np
 from scipy.ndimage import binary_fill_holes, distance_transform_edt
@@ -13,8 +14,6 @@
     remove_small_holes,
 )

-import cc3d
-
 from .bbox_processor import BBoxInstanceProcessor, BBoxProcessorConfig
 from .quantize import energy_quantize

@@ -459,7 +458,7 @@ def kimimaro_config(label: np.ndarray, resolution: Tuple[float, ...]) -> dict:

     # --- dust threshold ---
     # Skip instances smaller than a 5³-voxel cube.
-    dust_threshold = max(5 ** label.ndim, 5)
+    dust_threshold = max(5**label.ndim, 5)

     # --- flags ---
     # fix_branching: improves branch-point accuracy but ~1.3x slower.
@@ -605,7 +604,10 @@ def precompute_sdt_volume(

     t0 = time.time()
     sdt = skeleton_aware_distance_transform(
-        label, resolution=resolution, alpha=alpha, bg_value=bg_value,
+        label,
+        resolution=resolution,
+        alpha=alpha,
+        bg_value=bg_value,
         max_parallel=parallel,
     )
     elapsed = time.time() - t0
@@ -665,7 +667,9 @@ def precompute_skeleton_volume(
         skel_vol[verts[:, 0], verts[:, 1]] = inst_id

     n_skel_voxels = int((skel_vol > 0).sum())
-    print(f" Skeleton volume: {n_skel_voxels} voxels ({n_skel_voxels / max(skel_vol.size, 1) * 100:.2f}%)")
+    print(
+        f" Skeleton volume: {n_skel_voxels} voxels ({n_skel_voxels / max(skel_vol.size, 1) * 100:.2f}%)"
+    )

     save_volume(output_path, skel_vol)
     print(f" Saved to {output_path}")
@@ -751,14 +755,24 @@ def compute_edt_with_skeleton(
 def sdt_path_for_label(label_path: str, mode: str = "sdt") -> str:
     """Derive the precomputed cache path from a label file path.

+    HDF5 labels produce sibling ``*.h5`` cache files. Zarr dataset paths such
+    as ``data.zarr/seg`` produce sibling arrays inside the same store, for
+    example ``data.zarr/seg_skeleton``. A bare ``data.zarr`` label path falls
+    back to a sibling store such as ``data_skeleton.zarr``.
+
     Args:
         mode: ``"sdt"`` for full SDT, ``"skeleton"`` for skeleton volume.
-
-    Examples:
-        ``train-labels.tif`` → ``train-labels_sdt.h5``
-        ``train-labels.tif`` → ``train-labels_skeleton.h5``
     """
     import os

+    if ".zarr" in label_path:
+        zarr_idx = label_path.index(".zarr")
+        zarr_path = label_path[: zarr_idx + 5]
+        sub_key = label_path[zarr_idx + 5 :].strip("/")
+        if sub_key:
+            return f"{zarr_path}/{sub_key}_{mode}"
+        base = zarr_path[: -len(".zarr")]
+        return f"{base}_{mode}.zarr"
+
     base, _ = os.path.splitext(label_path)
     return base + f"_{mode}.h5"
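For reference, the cache naming that the updated docstring describes, written out as calls (the first example restates the old .tif behavior; the zarr paths are illustrative and the import path is an assumption):

from connectomics.data.processing.distance import sdt_path_for_label  # import path assumed

sdt_path_for_label("train-labels.tif", mode="sdt")     # -> "train-labels_sdt.h5"
sdt_path_for_label("data.zarr/seg", mode="skeleton")   # -> "data.zarr/seg_skeleton"
sdt_path_for_label("data.zarr", mode="skeleton")       # -> "data_skeleton.zarr"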

connectomics/training/lightning/callbacks.py

Lines changed: 28 additions & 0 deletions
@@ -159,6 +159,9 @@ def on_train_batch_end(
         """Store first batch for epoch-end visualization."""
         if batch_idx == 0:
             self._last_train_batch = self._build_cached_batch(batch)
+            # Log image+label on the very first batch (no prediction) for data sanity check.
+            if trainer.current_epoch == 0 and trainer.logger is not None:
+                self._log_data_check(trainer, batch)

     def on_validation_batch_end(
         self,
@@ -324,6 +327,31 @@ def _log_visualization(
             selected_channels=self.selected_channels,
         )

+    def _log_data_check(self, trainer, batch: Dict[str, torch.Tensor]) -> None:
+        """Log image + label from the first training batch (no prediction).
+
+        Runs once at the start of training so the user can visually verify
+        data loading, augmentation, and label transforms before waiting for
+        the first epoch to finish.
+        """
+        try:
+            writer = trainer.logger.experiment
+            image = batch["image"].cpu()
+            label = batch["label"].cpu()
+
+            self._log_visualization(
+                image=image,
+                label=label,
+                mask=None,
+                pred=label,  # show label in the pred slot too (no model output yet)
+                writer=writer,
+                iteration=0,
+                prefix="data_check",
+            )
+            logger.info("Logged data check visualization (image + label, no prediction)")
+        except Exception as e:
+            logger.warning("Data check visualization failed: %s", e)
+
     @staticmethod
     def _to_tensor(pred):
         """Extract a tensor from possible deep-supervision dict outputs."""

connectomics/training/lightning/data_factory.py

Lines changed: 3 additions & 5 deletions
@@ -16,7 +16,7 @@
     build_val_transforms,
 )
 from ...data.datasets import create_data_dicts_from_paths
-from ...data.io import get_vol_shape
+from ...data.io import get_vol_shape, volume_exists
 from .data import ConnectomicsDataModule, SimpleDataModule
 from .path_utils import expand_file_paths

@@ -76,12 +76,10 @@ def _maybe_precompute_label_aux(

     print(f"label_aux_type={mode} ({split_name}): " f"resolution={list(resolution)}, alpha={alpha}")

-    import os
-
     paths = []
     for lp in label_paths:
         sp = sdt_path_for_label(lp, mode=mode)
-        if not os.path.exists(sp):
+        if not volume_exists(sp):
             if mode == "sdt":
                 precompute_sdt_volume(lp, sp, resolution=resolution, alpha=alpha, bg_value=bg_value)
             else:
@@ -616,7 +614,7 @@ def create_datamodule(
         logger.info("Auto-computing iter_num from volume size...")

         from ...data.datasets.sampling import compute_total_samples
-        from ...data.io import get_vol_shape
+        from ...data.io import get_vol_shape, volume_exists

         # Get volume sizes
         volume_sizes = []
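Taken together with the io changes, the precompute gating in _maybe_precompute_label_aux now works for zarr caches as well: a cache is regenerated only when volume_exists reports it missing. A condensed sketch of that loop (import paths and keyword defaults are assumptions):

from connectomics.data.io import volume_exists
from connectomics.data.processing.distance import precompute_sdt_volume, sdt_path_for_label

def ensure_sdt_caches(label_paths, resolution=(30, 8, 8), alpha=8.0):
    """Precompute SDT caches only for labels whose cache volume is missing."""
    cache_paths = []
    for lp in label_paths:
        sp = sdt_path_for_label(lp, mode="sdt")
        if not volume_exists(sp):  # handles h5 files and zarr subkeys alike
            precompute_sdt_volume(lp, sp, resolution=resolution, alpha=alpha)
        cache_paths.append(sp)
    return cache_paths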
