Skip to content

Commit d54ba6a

Browse files
Alpaca233claude
andcommitted
fix: Address PR Cephla-Lab#19 review comments
- Add docstring entries for registration_z and registration_t parameters - Fix type hints: z_level and time_idx now Optional[int] instead of int - Update read_zarr_tile and read_zarr_region to accept z_level and time_idx - Zarr reads now respect z-level and timepoint selection instead of hardcoding Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c9527b2 commit d54ba6a

2 files changed

Lines changed: 51 additions & 13 deletions

File tree

src/tilefusion/core.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ class TileFusion:
8484
Channel index for registration.
8585
multiscale_downsample : str
8686
Either "stride" (default) or "block_mean" to control multiscale reduction.
87+
registration_z : int, optional
88+
Z-level to use for registration. If None, uses middle z-level.
89+
registration_t : int
90+
Timepoint to use for registration. Defaults to 0.
8791
"""
8892

8993
def __init__(
@@ -461,7 +465,9 @@ def _update_profiles(self) -> None:
461465
# I/O methods (delegate to format-specific loaders)
462466
# -------------------------------------------------------------------------
463467

464-
def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) -> np.ndarray:
468+
def _read_tile(
469+
self, tile_idx: int, z_level: Optional[int] = None, time_idx: Optional[int] = None
470+
) -> np.ndarray:
465471
"""Read a single tile from the input data (all channels)."""
466472
if z_level is None:
467473
z_level = self._registration_z # Default to registration z-level
@@ -471,7 +477,7 @@ def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) -
471477
if self._is_zarr_format:
472478
zarr_ts = self._metadata["tensorstore"]
473479
is_3d = self._metadata.get("is_3d", False)
474-
tile = read_zarr_tile(zarr_ts, tile_idx, is_3d)
480+
tile = read_zarr_tile(zarr_ts, tile_idx, is_3d, z_level=z_level, time_idx=time_idx)
475481
elif self._is_individual_tiffs_format:
476482
tile = read_individual_tiffs_tile(
477483
self._metadata["image_folder"],
@@ -508,8 +514,8 @@ def _read_tile_region(
508514
tile_idx: int,
509515
y_slice: slice,
510516
x_slice: slice,
511-
z_level: int = None,
512-
time_idx: int = None,
517+
z_level: Optional[int] = None,
518+
time_idx: Optional[int] = None,
513519
) -> np.ndarray:
514520
"""Read a region of a tile from the input data."""
515521
if z_level is None:
@@ -521,7 +527,14 @@ def _read_tile_region(
521527
zarr_ts = self._metadata["tensorstore"]
522528
is_3d = self._metadata.get("is_3d", False)
523529
region = read_zarr_region(
524-
zarr_ts, tile_idx, y_slice, x_slice, self.channel_to_use, is_3d
530+
zarr_ts,
531+
tile_idx,
532+
y_slice,
533+
x_slice,
534+
self.channel_to_use,
535+
is_3d,
536+
z_level=z_level,
537+
time_idx=time_idx,
525538
)
526539
elif self._is_individual_tiffs_format:
527540
region = read_individual_tiffs_region(

src/tilefusion/io/zarr.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def read_zarr_tile(
100100
zarr_ts: ts.TensorStore,
101101
tile_idx: int,
102102
is_3d: bool = False,
103+
z_level: int = None,
104+
time_idx: int = 0,
103105
) -> np.ndarray:
104106
"""
105107
Read all channels of a tile from Zarr format.
@@ -111,18 +113,27 @@ def read_zarr_tile(
111113
tile_idx : int
112114
Index of the tile.
113115
is_3d : bool
114-
If True, data is 3D and max projection is applied.
116+
If True, data is 3D.
117+
z_level : int, optional
118+
Z-level to read. If None and is_3d, uses max projection.
119+
time_idx : int
120+
Timepoint to read. Defaults to 0.
115121
116122
Returns
117123
-------
118124
arr : ndarray of shape (C, Y, X)
119125
Tile data as float32.
120126
"""
121127
if is_3d:
122-
arr = zarr_ts[0, tile_idx, :, :, :, :].read().result()
123-
arr = np.max(arr, axis=1) # Max projection along Z
128+
if z_level is not None:
129+
# Read specific z-level
130+
arr = zarr_ts[time_idx, tile_idx, :, z_level, :, :].read().result()
131+
else:
132+
# Max projection along Z (legacy behavior)
133+
arr = zarr_ts[time_idx, tile_idx, :, :, :, :].read().result()
134+
arr = np.max(arr, axis=1)
124135
else:
125-
arr = zarr_ts[0, tile_idx, :, :, :].read().result()
136+
arr = zarr_ts[time_idx, tile_idx, :, :, :].read().result()
126137
return arr.astype(np.float32)
127138

128139

@@ -133,6 +144,8 @@ def read_zarr_region(
133144
x_slice: slice,
134145
channel_idx: int = 0,
135146
is_3d: bool = False,
147+
z_level: int = None,
148+
time_idx: int = 0,
136149
) -> np.ndarray:
137150
"""
138151
Read a region of a single channel from Zarr format.
@@ -149,18 +162,30 @@ def read_zarr_region(
149162
Channel index.
150163
is_3d : bool
151164
If True, data is 3D.
165+
z_level : int, optional
166+
Z-level to read. If None and is_3d, uses max projection.
167+
time_idx : int
168+
Timepoint to read. Defaults to 0.
152169
153170
Returns
154171
-------
155172
arr : ndarray of shape (1, h, w)
156173
Tile region as float32.
157174
"""
158175
if is_3d:
159-
arr = zarr_ts[0, tile_idx, channel_idx, :, y_slice, x_slice].read().result()
160-
arr = np.max(arr, axis=0)
161-
arr = arr[np.newaxis, :, :]
176+
if z_level is not None:
177+
# Read specific z-level
178+
arr = (
179+
zarr_ts[time_idx, tile_idx, channel_idx, z_level, y_slice, x_slice].read().result()
180+
)
181+
arr = arr[np.newaxis, :, :]
182+
else:
183+
# Max projection along Z (legacy behavior)
184+
arr = zarr_ts[time_idx, tile_idx, channel_idx, :, y_slice, x_slice].read().result()
185+
arr = np.max(arr, axis=0)
186+
arr = arr[np.newaxis, :, :]
162187
else:
163-
arr = zarr_ts[0, tile_idx, channel_idx, y_slice, x_slice].read().result()
188+
arr = zarr_ts[time_idx, tile_idx, channel_idx, y_slice, x_slice].read().result()
164189
arr = arr[np.newaxis, :, :]
165190
return arr.astype(np.float32)
166191

0 commit comments

Comments
 (0)