Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ dynamic = [
dependencies = [
"aiohttp>=3.8.1",
"anndata>=0.9",
"centrosome>=1.2.3",
"cp-measure>=0.1.19,<0.2",
"cycler>=0.11",
"dask[array]>=2021.2",
"dask-image>=0.5",
Expand Down
682 changes: 503 additions & 179 deletions src/squidpy/experimental/im/_calculate_image_features.py

Large diffs are not rendered by default.

229 changes: 82 additions & 147 deletions src/squidpy/experimental/im/_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
cell. Non-owned cells are zeroed out in each tile's mask so that
downstream processing never double-counts.

All functions accept pre-computed centroid dicts and image shapes; they
never materialize the full image or label array.
"""

Expand All @@ -20,6 +20,16 @@
from skimage.measure import regionprops


def yx_size(da: xr.DataArray) -> tuple[int, int]:
    """Return ``(height, width)`` of a DataArray.

    The sizes of the ``"y"`` and ``"x"`` dimensions are used when they are
    present in ``da.sizes``; otherwise the lengths of the second-to-last and
    last axes are used, respectively.

    Parameters
    ----------
    da
        Array exposing ``sizes`` (mapping of dimension name to length) and
        ``shape``.

    Returns
    -------
    ``(height, width)`` as plain Python ints.
    """
    # The positional fallbacks are evaluated eagerly, so ``da.shape`` must have
    # at least two entries even when both named dimensions exist.
    return int(da.sizes.get("y", da.shape[-2])), int(da.sizes.get("x", da.shape[-1]))


def _as_2d(arr: np.ndarray) -> np.ndarray:
"""Drop singleton leading dims so a labels array is 2-D."""
return arr.squeeze() if arr.ndim > 2 else arr


@dataclass(frozen=True)
class CellInfo:
"""Centroid and bounding box for a single label."""
Expand All @@ -29,6 +39,8 @@ class CellInfo:
centroid_x: float
bbox_h: int # height of bounding box
bbox_w: int # width of bounding box
bbox_y0: int = 0 # top edge (row) of bounding box
bbox_x0: int = 0 # left edge (col) of bounding box


@dataclass(frozen=True)
Expand Down Expand Up @@ -78,6 +90,8 @@ def compute_cell_info(labels: np.ndarray) -> dict[int, CellInfo]:
centroid_x=p.centroid[1],
bbox_h=max_row - min_row,
bbox_w=max_col - min_col,
bbox_y0=min_row,
bbox_x0=min_col,
)
return info

Expand All @@ -95,20 +109,16 @@ def compute_cell_info_multiscale(
return {}

def _spatial_size(k: str) -> int:
da = labels_node[k].ds["image"]
h = int(da.sizes.get("y", da.shape[-2]))
w = int(da.sizes.get("x", da.shape[-1]))
h, w = yx_size(labels_node[k].ds["image"])
return h * w

coarsest = min(available, key=_spatial_size)
coarse_da = labels_node[coarsest].ds["image"]
coarse_labels = np.asarray(coarse_da.values).squeeze()
coarse_labels = np.asarray(labels_node[coarsest].ds["image"].values).squeeze()

if coarse_labels.ndim != 2:
raise ValueError(f"Expected 2-D labels at scale {coarsest}, got shape {coarse_labels.shape}")

target_da = labels_node[target_scale].ds["image"]
target_h, target_w = target_da.sizes.get("y", target_da.shape[-2]), target_da.sizes.get("x", target_da.shape[-1])
target_h, target_w = yx_size(labels_node[target_scale].ds["image"])
coarse_h, coarse_w = coarse_labels.shape
scale_y = target_h / coarse_h
scale_x = target_w / coarse_w
Expand All @@ -121,75 +131,73 @@ def _spatial_size(k: str) -> int:
centroid_x=p.centroid[1] * scale_x,
bbox_h=int(np.ceil((p.bbox[2] - p.bbox[0]) * scale_y)),
bbox_w=int(np.ceil((p.bbox[3] - p.bbox[1]) * scale_x)),
bbox_y0=int(np.floor(p.bbox[0] * scale_y)),
bbox_x0=int(np.floor(p.bbox[1] * scale_x)),
)
for p in props
}


@dataclass
class _Accum:
"""Per-label running totals while streaming chunks (a cell may span chunks)."""

sum_y: float = 0.0 # centroid_y * area, summed across chunks (for area-weighted centroid)
sum_x: float = 0.0
area: float = 0.0
min_y: float = np.inf
max_y: float = -np.inf
min_x: float = np.inf
max_x: float = -np.inf


def compute_cell_info_tiled(
    labels_da: xr.DataArray,
    chunk_size: int = 4096,
) -> dict[int, CellInfo]:
    """Compute per-label centroids and bounding boxes without materializing the full array.

    The labels are read in ``chunk_size`` blocks. This chunking is internal to the
    scan and is independent of the featurization tiles from :func:`build_tile_specs`.
    A label spanning a block boundary is partitioned across blocks; its centroid is
    recovered as the area-weighted mean of the per-block centroids and its bounding
    box as the union of the per-block boxes.

    Parameters
    ----------
    labels_da
        2-D (y, x) dask-backed xarray DataArray.
    chunk_size
        Side length in pixels of each read block.

    Returns
    -------
    Mapping from label ID to :class:`CellInfo`. Label ``0`` is never included.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    height, width = yx_size(labels_da)
    accums: dict[int, _Accum] = {}

    for y0 in range(0, height, chunk_size):
        for x0 in range(0, width, chunk_size):
            # Slice ends may run past the array edge; they are not clamped here
            # and rely on ordinary slice truncation.
            chunk = _as_2d(labels_da.isel(y=slice(y0, y0 + chunk_size), x=slice(x0, x0 + chunk_size)).values)
            for prop in regionprops(chunk):
                a = accums.setdefault(prop.label, _Accum())
                area = float(prop.area)
                a.area += area
                # Shift block-local coordinates by the block origin to get
                # coordinates in the full array.
                a.sum_y += (prop.centroid[0] + y0) * area
                a.sum_x += (prop.centroid[1] + x0) * area
                a.min_y, a.max_y = min(a.min_y, prop.bbox[0] + y0), max(a.max_y, prop.bbox[2] + y0)
                a.min_x, a.max_x = min(a.min_x, prop.bbox[1] + x0), max(a.max_x, prop.bbox[3] + x0)

    return {
        lid: CellInfo(
            label=lid,
            centroid_y=a.sum_y / a.area,
            centroid_x=a.sum_x / a.area,
            bbox_h=int(a.max_y - a.min_y),
            bbox_w=int(a.max_x - a.min_x),
            bbox_y0=int(a.min_y),
            bbox_x0=int(a.min_x),
        )
        for lid, a in accums.items()
        if lid != 0
    }


# Tile spec building
Expand All @@ -206,19 +214,19 @@ def _auto_margin(cell_info: dict[int, CellInfo]) -> int:


def build_tile_specs(
image_shape: tuple[int, int],
grid_shape: tuple[int, int],
cell_info: dict[int, CellInfo],
tile_size: int = 2048,
overlap_margin: int | Literal["auto"] = "auto",
) -> list[TileSpec]:
"""Build tile specifications from pre-computed centroids.

No pixel data is neededonly the image dimensions and centroid dict.
No pixel data is needed, only the grid dimensions and centroid dict.

Parameters
----------
image_shape
``(H, W)`` of the full-resolution image/labels.
grid_shape
``(height, width)`` of the full-resolution labels grid.
cell_info
Pre-computed centroids from :func:`compute_cell_info`,
:func:`compute_cell_info_multiscale`, or :func:`compute_cell_info_tiled`.
Expand All @@ -233,21 +241,18 @@ def build_tile_specs(
List of :class:`TileSpec`, one per grid cell that owns at least one
label. Empty tiles (no cells) are omitted.
"""
H, W = image_shape
height, width = grid_shape
if tile_size <= 0:
raise ValueError(f"tile_size must be positive, got {tile_size}")

if isinstance(overlap_margin, str) and overlap_margin == "auto":
margin = _auto_margin(cell_info)
else:
margin = int(overlap_margin)
margin = _auto_margin(cell_info) if overlap_margin == "auto" else int(overlap_margin)
if margin < 0:
raise ValueError(f"overlap_margin must be non-negative, got {margin}")

cell_to_tile: dict[int, tuple[int, int]] = {}
for lid, ci in cell_info.items():
tile_row = min(int(ci.centroid_y) // tile_size, (H - 1) // tile_size)
tile_col = min(int(ci.centroid_x) // tile_size, (W - 1) // tile_size)
for lid, cell in cell_info.items():
tile_row = min(int(cell.centroid_y) // tile_size, (height - 1) // tile_size)
tile_col = min(int(cell.centroid_x) // tile_size, (width - 1) // tile_size)
cell_to_tile[lid] = (tile_row, tile_col)

tile_to_cells: dict[tuple[int, int], set[int]] = {}
Expand All @@ -258,13 +263,13 @@ def build_tile_specs(
for (row, col), owned in sorted(tile_to_cells.items()):
by0 = row * tile_size
bx0 = col * tile_size
by1 = min(by0 + tile_size, H)
bx1 = min(bx0 + tile_size, W)
by1 = min(by0 + tile_size, height)
bx1 = min(bx0 + tile_size, width)

cy0 = max(by0 - margin, 0)
cx0 = max(bx0 - margin, 0)
cy1 = min(by1 + margin, H)
cx1 = min(bx1 + margin, W)
cy1 = min(by1 + margin, height)
cx1 = min(bx1 + margin, width)

specs.append(
TileSpec(
Expand All @@ -280,41 +285,14 @@ def build_tile_specs(
# Tile extraction


def extract_tile(
    image: np.ndarray,
    labels: np.ndarray,
    spec: TileSpec,
) -> tuple[np.ndarray, np.ndarray]:
    """Extract a tile from numpy arrays, zeroing out non-owned cells.

    Parameters
    ----------
    image
        ``(C, H, W)`` numpy array.
    labels
        ``(H, W)`` numpy label array.
    spec
        Tile specification.

    Returns
    -------
    tile_image, tile_labels
        ``tile_image`` is a slice of ``image`` over the crop region;
        ``tile_labels`` is a copy of the matching ``labels`` region, passed
        through ``_zero_non_owned`` with ``spec.owned_ids``.
    """
    cy0, cx0, cy1, cx1 = spec.crop
    tile_image = image[:, cy0:cy1, cx0:cx1]
    # Copy before masking so the caller's ``labels`` array is left untouched.
    tile_labels = labels[cy0:cy1, cx0:cx1].copy()
    _zero_non_owned(tile_labels, spec.owned_ids)
    return tile_image, tile_labels


def extract_tile_lazy(
image_da: xr.DataArray,
labels_da: xr.DataArray,
spec: TileSpec,
) -> tuple[np.ndarray, np.ndarray]:
"""Extract a tile from dask-backed xarray arrays.

Materializes only the tile's crop region (~2k×2k), not the full image.
Materializes only the tile's crop region (~2k x 2k), not the full image.

Parameters
----------
Expand All @@ -334,11 +312,7 @@ def extract_tile_lazy(
"""
cy0, cx0, cy1, cx1 = spec.crop
tile_image = image_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values
tile_labels = labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values.copy()
if tile_labels.ndim > 2:
tile_labels = tile_labels.squeeze()
_zero_non_owned(tile_labels, spec.owned_ids)
return tile_image, tile_labels
return tile_image, extract_labels_tile_lazy(labels_da, spec)


def extract_labels_tile_lazy(
Expand All @@ -362,9 +336,7 @@ def extract_labels_tile_lazy(
``(crop_h, crop_w)`` numpy array with non-owned cells zeroed.
"""
cy0, cx0, cy1, cx1 = spec.crop
tile_labels = labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values.copy()
if tile_labels.ndim > 2:
tile_labels = tile_labels.squeeze()
tile_labels = _as_2d(labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values.copy())
_zero_non_owned(tile_labels, spec.owned_ids)
return tile_labels

Expand Down Expand Up @@ -397,40 +369,3 @@ def _zero_non_owned(tile_labels: np.ndarray, owned_ids: frozenset[int]) -> None:
else:
owned_arr = np.fromiter(owned_ids, dtype=tile_labels.dtype, count=len(owned_ids))
tile_labels[~np.isin(tile_labels, owned_arr)] = 0


# Coverage verification


def verify_coverage(
    all_label_ids: set[int],
    specs: list[TileSpec],
) -> None:
    """Assert that tile specs provide full, non-overlapping cell coverage.

    Parameters
    ----------
    all_label_ids
        Set of all nonzero label IDs expected in the image.
    specs
        Tile specifications to verify.

    Raises
    ------
    ValueError
        If any cell is assigned to more than one tile, is missing from every
        tile, or is referenced by a tile but absent from ``all_label_ids``.
    """
    owned_union: set[int] = set()
    for spec in specs:
        # Any ID already collected from an earlier spec is a double assignment.
        overlap = owned_union & spec.owned_ids
        if overlap:
            raise ValueError(f"Cells {overlap} assigned to multiple tiles")
        owned_union |= spec.owned_ids

    missing = all_label_ids - owned_union
    if missing:
        raise ValueError(f"Cells {missing} not assigned to any tile")

    extra = owned_union - all_label_ids
    if extra:
        raise ValueError(f"Tile specs reference non-existent labels {extra}")
Loading
Loading