From 45165a8fcb0ed93f1a8e6ca492e0a609621e5caa Mon Sep 17 00:00:00 2001 From: anon Date: Sat, 30 May 2026 01:07:15 +0200 Subject: [PATCH 1/4] test: add quantitative recovery floor from validation sweep Lock the validation-sweep outcome: at min_confidence=0.5 the deterministic fixture recovers >=50% of cut pieces with no intact false-merges. min_confidence default stays 0.7 (full attainable recall, zero false merges); gap_proximity kept in the 5-feature score. Co-Authored-By: Claude Opus 4.8 --- tests/experimental/test_tiling_stitch.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/experimental/test_tiling_stitch.py b/tests/experimental/test_tiling_stitch.py index a22c40efc..c55b65548 100644 --- a/tests/experimental/test_tiling_stitch.py +++ b/tests/experimental/test_tiling_stitch.py @@ -74,6 +74,23 @@ def test_no_intact_cells_stitched_at_high_threshold(self, sdata_tile_boundary): n_false = int((intact & adata.obs["is_stitched"].astype(bool)).sum()) assert n_false <= 5 + def test_recovery_meets_quantitative_bounds(self, sdata_tile_boundary): + """Quantitative floor from the validation sweep (deterministic fixture). + + At ``min_confidence=0.5`` the sweep recovers ~64% of cut pieces with zero + intact false-merges; assert a conservative recall floor and a near-zero + false-merge bound (small tolerance for skimage version drift). 
+ """ + sdata, gt = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata, min_confidence=0.5) + lid = adata.obs["label_id"].astype(int) + stitched = adata.obs["is_stitched"].astype(bool) + n_cut_stitched = int((lid.isin(gt.cut_cell_ids) & stitched).sum()) + n_false = int((lid.isin(gt.intact_cell_ids) & stitched).sum()) + recall = n_cut_stitched / max(len(gt.cut_cell_ids), 1) + assert recall >= 0.5, f"recall {recall:.2f} below 0.5 floor" + assert n_false <= 2, f"too many intact false merges: {n_false}" + def test_uns_records_params_and_features(self, sdata_tile_boundary): sdata, _ = sdata_tile_boundary meta = _run_qc_and_stitch(sdata, min_confidence=0.7, max_gap=4.0).uns["tiling_stitch"] From afb4bde70d1140614723abbc1f868a84a48dbd1b Mon Sep 17 00:00:00 2001 From: anon Date: Sat, 30 May 2026 01:47:13 +0200 Subject: [PATCH 2/4] feat(experimental): add make_stitched_labels (materialise stitched labels) PR-C of the #1170 split. Adds sq.experimental.im.make_stitched_labels: reads the stitch_group_id mapping written by assign_stitch_groups and registers a new labels element where each stitched group shares one ID (original labels left untouched), plus an optional collapsed AnnData (one row per group) with configurable merge_strategy. Stacks on PR-B (feature/tiling-stitch-algo) and reuses its full-pipeline tests. Scale rework (must handle 100k x 100k, never materialise the full array): - LUT remap stays lazy; labels present in the image but absent from the QC table (e.g. min_area-filtered cells) pass through instead of indexing OOB. - join_labels is now chunk-aware via dask map_overlap (depth = close_radius+2), never computing the whole image or running a full-frame regionprops. - group aggregation vectorised (argsort+split, pandas groupby) instead of the O(cells x groups) per-group np.where scan. - X aggregation is sparse-safe: sum/mean via sparse matmul, first via sparse gather, other reducers per-group-bounded; the full matrix is never densified. 
- validate label_id/is_stitched up front; squeeze a leading singleton dim. Co-Authored-By: Claude Opus 4.8 --- docs/api.md | 1 + src/squidpy/experimental/im/__init__.py | 2 + .../experimental/im/_stitched_labels.py | 482 ++++++++++++++++++ tests/experimental/test_stitched_labels.py | 348 +++++++++++++ 4 files changed, 833 insertions(+) create mode 100644 src/squidpy/experimental/im/_stitched_labels.py create mode 100644 tests/experimental/test_stitched_labels.py diff --git a/docs/api.md b/docs/api.md index c701c4058..86a3d80bd 100644 --- a/docs/api.md +++ b/docs/api.md @@ -151,6 +151,7 @@ See the {doc}`extensibility guide ` for how to implement a custo experimental.tl.TilingQCParams experimental.tl.assign_stitch_groups experimental.tl.StitchParams + experimental.im.make_stitched_labels experimental.pl.tiling_qc experimental.im.fit_stain_reference experimental.im.apply_stain_normalization diff --git a/src/squidpy/experimental/im/__init__.py b/src/squidpy/experimental/im/__init__.py index 1a661b53a..0d4661f15 100644 --- a/src/squidpy/experimental/im/__init__.py +++ b/src/squidpy/experimental/im/__init__.py @@ -15,6 +15,7 @@ apply_stain_normalization, fit_stain_reference, ) +from ._stitched_labels import make_stitched_labels __all__ = [ "BackgroundDetectionParams", @@ -26,6 +27,7 @@ "apply_stain_normalization", "detect_tissue", "fit_stain_reference", + "make_stitched_labels", "make_tiles", "make_tiles_from_spots", "qc_image", diff --git a/src/squidpy/experimental/im/_stitched_labels.py b/src/squidpy/experimental/im/_stitched_labels.py new file mode 100644 index 000000000..f1b48ad5c --- /dev/null +++ b/src/squidpy/experimental/im/_stitched_labels.py @@ -0,0 +1,482 @@ +"""Materialise a stitched labels element from an assign_stitch_groups result. + +Companion to :func:`squidpy.experimental.tl.assign_stitch_groups`. Takes the +piece-to-group mapping from ``stitch_group_id`` in the QC table and writes +a new labels element where stitched pieces share a single ID. 
The original +labels element is untouched. +""" + +from __future__ import annotations + +import copy as _copy +from collections.abc import Callable + +import anndata as ad +import dask.array as da +import numpy as np +import pandas as pd +import scipy.sparse as sp +import spatialdata as sd +import xarray as xr +from scipy.ndimage import binary_closing +from skimage.morphology import disk as morph_disk +from spatialdata._logging import logger as logg +from spatialdata.models import Labels2DModel, TableModel +from spatialdata.transformations import get_transformation + +from squidpy.experimental.utils._labels import resolve_labels_array + +__all__ = ["make_stitched_labels"] + + +_LUT_DENSITY_RATIO = 8 # max_id <= len(label_ids) * 8 -> LUT is reasonable +_LUT_ABSOLUTE_CAP = 100_000_000 # never allocate more than 100M entries + + +def _build_lookup(adata_obs: pd.DataFrame, dtype: np.dtype) -> np.ndarray: + """Build an int->int LUT from ``label_id`` to ``stitch_group_id``. + + LUT covers ``[0, max_label_id]``; unmapped indices keep their own value + (identity), so background (0) and any cells absent from the QC table are + preserved. + + Raises + ------ + ValueError + If ``stitch_group_id`` (or ``label_id``) values exceed the labels' + dtype range -- silent truncation here would alias unrelated cells. + ValueError + If ``max(label_id)`` is so much larger than the number of cells that + the dense LUT would be wasteful (sparse-but-large ID spaces). Users + with this label scheme should remap to contiguous IDs first. + """ + label_ids = adata_obs["label_id"].astype(np.int64).to_numpy() + group_ids = adata_obs["stitch_group_id"].astype(np.int64).to_numpy() + if np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + worst = max(int(label_ids.max(initial=0)), int(group_ids.max(initial=0))) + if worst > info.max: + raise ValueError( + f"label_id / stitch_group_id values up to {worst} exceed the labels " + f"dtype range {dtype} (max {info.max}); cannot build a safe LUT." 
+ ) + max_id = int(label_ids.max(initial=0)) + n_cells = int(label_ids.size) + if max_id > _LUT_ABSOLUTE_CAP or (n_cells > 0 and max_id > _LUT_DENSITY_RATIO * n_cells and max_id > 1000): + raise ValueError( + f"Cannot allocate a {max_id + 1}-entry LUT for {n_cells} cells " + f"(sparse label IDs). Remap your labels to contiguous IDs starting " + f"from 1 before calling make_stitched_labels." + ) + lut = np.arange(max_id + 1, dtype=dtype) + lut[label_ids] = group_ids.astype(dtype) + return lut + + +def _apply_lut(labels_da: xr.DataArray, lut: np.ndarray) -> da.Array | np.ndarray: + """Lazily remap a labels DataArray via the LUT over its dask blocks. + + Labels present in the image but absent from the LUT (e.g. small cells the + QC table dropped via ``min_area``, whose pixels still exist) are kept as-is + -- they index past the LUT, so we map only in-range values and leave the + rest at their original identity. Returns a bare array (dask or numpy) so + the caller can re-parse via Labels2DModel without colliding metadata. + """ + src = labels_da.data + max_id = lut.shape[0] - 1 + + def _remap(block: np.ndarray, _lut: np.ndarray = lut, _max: int = max_id) -> np.ndarray: + out = np.asarray(block).copy() + in_range = out <= _max + out[in_range] = _lut[out[in_range]] + return out + + if isinstance(src, da.Array): + return src.map_blocks(_remap, dtype=lut.dtype) + return _remap(np.asarray(src)) + + +def _join_stitched_labels( + labels_arr: da.Array | np.ndarray, + stitched_group_ids: set[int], + close_radius: int = 3, +) -> da.Array | np.ndarray: + """Morphologically close gaps between pieces of each stitched group. + + The basic LUT remap leaves stitched groups as multi-component regions (the + cut stripe between pieces stays at 0). This pass fills only background + pixels inside the closed hull of each stitched group, so each becomes a + single connected component; other cells' pixels are never overwritten. 
+ + Chunk-aware and lazy: a dask array is processed block-by-block via + :func:`dask.array.map_overlap` with ``depth = close_radius + 2`` (so groups + split across a block boundary still close correctly), never materialising + the full image. Each block touches only the stitched labels it contains. + Returns a dask array for dask input, numpy for numpy input. + """ + if not stitched_group_ids: + return labels_arr + stitched = frozenset(int(g) for g in stitched_group_ids) + structure = morph_disk(close_radius) + + def _close_block(block: np.ndarray) -> np.ndarray: + block = np.asarray(block) + while block.ndim > 2: + block = block.squeeze(0) + present = stitched.intersection(np.unique(block).tolist()) + if not present: + return block + out = block.copy() + for gid in present: + mask = block == gid + closed = binary_closing(mask, structure=structure) + # Only fill genuine background pixels -- never overwrite another cell. + fill = closed & ~mask & (block == 0) + if fill.any(): + out[fill] = gid + return out + + if isinstance(labels_arr, da.Array): + depth = close_radius + 2 + return da.map_overlap(_close_block, labels_arr, depth=depth, boundary=0, dtype=labels_arr.dtype) + return _close_block(labels_arr) + + +_BUILTIN_STRATEGIES: dict[str, Callable[[pd.Series], object]] = { + "sum": lambda s: s.sum(), + "min": lambda s: s.min(), + "max": lambda s: s.max(), + "mean": lambda s: s.mean(), + "median": lambda s: s.median(), + "first": lambda s: s.iloc[0], +} + +# Vectorised counterparts: ``f(block) -> 1-D array of length n_cols``. Used +# in :func:`_aggregate_X` to avoid an O(groups*cols) Python loop when the +# user passes a built-in strategy name. Callable strategies fall back to +# the per-column path. 
+_BUILTIN_X_REDUCERS: dict[str, Callable[[np.ndarray], np.ndarray]] = { + "sum": lambda b: b.sum(axis=0), + "min": lambda b: b.min(axis=0), + "max": lambda b: b.max(axis=0), + "mean": lambda b: b.mean(axis=0), + "median": lambda b: np.median(b, axis=0), + "first": lambda b: b[0], +} + +# Columns whose value is shared across all members of a stitch group; we always +# take the first member's value rather than aggregating. +_GROUP_INVARIANT_COLS = frozenset({"stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence", "region"}) + + +def _resolve_strategy(strategy: str | Callable[[pd.Series], object]) -> Callable[[pd.Series], object]: + if callable(strategy): + return strategy + if strategy not in _BUILTIN_STRATEGIES: + raise ValueError( + f"Unknown merge_strategy {strategy!r}. Use one of {sorted(_BUILTIN_STRATEGIES)} or pass a callable." + ) + return _BUILTIN_STRATEGIES[strategy] + + +_INTEGER_PRESERVING_STRATEGIES = frozenset({"sum", "min", "max", "first"}) + + +def _aggregate_X( + X, + group_indices: list[np.ndarray], + strategy: str | Callable[[pd.Series], object], +): + """Aggregate ``X`` row-blocks into one row per group, column-wise. + + Scale-safe: the full matrix is **never** densified. ``sum`` / ``mean`` use + a sparse group-indicator matmul, ``first`` a sparse row gather -- both keep + a sparse result for sparse input (important when most groups are singletons, + so the output is nearly as tall as the input). Other reducers (``min`` / + ``max`` / ``median`` / callables) pass singleton groups through and densify + only each multi-member group's small block. + + For integer ``X`` the output dtype is preserved only for integer-safe + strategies (``sum``, ``min``, ``max``, ``first``); mean/median and callables + promote to ``float64`` so a uint16 count matrix doesn't truncate. 
+ """ + n_groups = len(group_indices) + n_cols = X.shape[1] + if n_cols == 0: + return np.empty((n_groups, 0), dtype=np.float32) + sparse_in = sp.issparse(X) + if np.issubdtype(X.dtype, np.integer) and ( + not isinstance(strategy, str) or strategy not in _INTEGER_PRESERVING_STRATEGIES + ): + out_dtype = np.float64 + else: + out_dtype = X.dtype + + # Vectorised, non-densifying paths for the common strategies. + if isinstance(strategy, str) and strategy in ("sum", "mean"): + rows = np.concatenate([np.full(len(idx), i, dtype=np.int64) for i, idx in enumerate(group_indices)]) + cols = np.concatenate(group_indices).astype(np.int64) + if strategy == "mean": + sizes = np.array([len(idx) for idx in group_indices], dtype=np.float64) + data = (1.0 / sizes)[rows] + else: + data = np.ones(cols.size, dtype=np.float64) + indicator = sp.csr_matrix((data, (rows, cols)), shape=(n_groups, X.shape[0])) + res = indicator @ X + return res.astype(out_dtype) if sparse_in else np.asarray(res, dtype=out_dtype) + + if isinstance(strategy, str) and strategy == "first": + first_rows = np.array([idx[0] for idx in group_indices], dtype=np.int64) + res = X[first_rows] + return res.astype(out_dtype) if sparse_in else np.asarray(res, dtype=out_dtype) + + # General path: bounded per-group work, sparse-preserving output. 
+ reducer = _BUILTIN_X_REDUCERS[strategy] if isinstance(strategy, str) else None + strategy_fn = None if reducer is not None else _resolve_strategy(strategy) + Xc = X.tocsr() if sparse_in else np.asarray(X) + out = sp.lil_matrix((n_groups, n_cols), dtype=out_dtype) if sparse_in else np.zeros((n_groups, n_cols), out_dtype) + for i, idx in enumerate(group_indices): + if len(idx) == 1: + out[i] = Xc[idx[0]].toarray().ravel() if sparse_in else Xc[idx[0]] + continue + block = Xc[idx].toarray() if sparse_in else Xc[idx] + if reducer is not None: + out[i] = reducer(block) + else: + for c in range(n_cols): + out[i, c] = strategy_fn(pd.Series(block[:, c])) + return out.tocsr() if sparse_in else out + + +def _collapse_groups( + adata: ad.AnnData, + new_labels_key: str, + merge_strategy: str | Callable[[pd.Series], object], +) -> ad.AnnData: + """Collapse each stitch group into a single row. + + Output has one row per unique ``stitch_group_id``: unstitched cells (their + own group) keep their row unchanged, stitched groups (size 2-4) collapse + via ``merge_strategy``. ``.obs`` columns, ``.uns``, ``.var`` and ``.X`` + are preserved/aggregated; ``spatialdata_attrs`` and the ``region`` column + are rewritten to point at the new labels element. + + Aggregation rules: + - ``label_id``: rewritten to the group id (matches new labels element). + - ``stitch_group_id``, ``is_stitched``, ``n_pieces``, ``stitch_confidence``, + ``region``: members agree -> first value. + - Other numeric obs columns and all of ``X``: ``merge_strategy`` (default + ``"sum"``). Built-ins: ``sum``, ``min``, ``max``, ``mean``, ``median``, + ``first``. A callable receives a :class:`pandas.Series` and returns a + scalar; it's applied column-wise to both ``.obs`` and ``.X``. + - Non-numeric obs columns: ``"first"`` regardless of ``merge_strategy`` + (sum/mean don't make sense for strings/categoricals). 
+ + Note: ``merge_strategy="sum"`` is the right default for additive features + (area, intensity, count) but wrong for centroids, scores, fractions. + Override accordingly for those. + + .. warning:: + ``.obsm``, ``.obsp``, ``.layers`` are neither aggregated nor copied: + the collapsed table omits them, and a warning is logged when the + input has any. Re-attach them manually if they are needed. + """ + obs = adata.obs + if "stitch_group_id" not in obs.columns: + raise ValueError("AnnData missing 'stitch_group_id'; run assign_stitch_groups first.") + if "label_id" not in obs.columns: + raise ValueError("AnnData missing 'label_id'.") + + _resolve_strategy(merge_strategy) # validate strategy name early + group_ids = obs["stitch_group_id"].astype(int).to_numpy() + # Positional indices per group in one linear pass (sorted by group id), + # instead of an O(n_cells * n_groups) per-group np.where scan. + order = np.argsort(group_ids, kind="stable") + unique_groups, first_idx = np.unique(group_ids[order], return_index=True) + indices_by_group = np.split(order, first_idx[1:]) + + # ---- Aggregate obs via vectorised groupby ---- + # Group-invariant + non-numeric columns take the first member's value; + # numeric columns use merge_strategy. label_id is set to the group id. + cols = [c for c in obs.columns if c != "label_id"] + numeric_cols = [c for c in cols if c not in _GROUP_INVARIANT_COLS and pd.api.types.is_numeric_dtype(obs[c])] + first_cols = [c for c in cols if c not in numeric_cols] + gb = obs.groupby(group_ids, sort=True) + pieces = [] + if first_cols: + pieces.append(gb[first_cols].first()) + if numeric_cols: + pieces.append(gb[numeric_cols].agg(merge_strategy)) + new_obs = pd.concat(pieces, axis=1) if pieces else pd.DataFrame(index=unique_groups) + new_obs["label_id"] = unique_groups + new_obs = new_obs[list(obs.columns)] + # Preserve dtypes where possible (agg can promote/lose categorical).
+ for col in new_obs.columns: + try: + new_obs[col] = new_obs[col].astype(obs[col].dtype) + except (TypeError, ValueError): + pass + # Update the region column to point at the new labels element. + if "region" in new_obs.columns: + new_obs["region"] = pd.Categorical([new_labels_key] * len(new_obs)) + new_obs.index = [f"group_{gid}" for gid in unique_groups] + + # ---- Aggregate X ---- + if adata.X is not None and adata.X.shape[1] > 0: + new_X = _aggregate_X(adata.X, indices_by_group, merge_strategy) + else: + new_X = np.empty((len(unique_groups), 0), dtype=np.float32) + + # ---- Preserve var / uns / pass-through obsm-style fields ---- + new_uns = _copy.deepcopy(dict(adata.uns)) + new_uns["spatialdata_attrs"] = { + "region": new_labels_key, + "region_key": "region", + "instance_key": "label_id", + } + out = ad.AnnData(X=new_X, obs=new_obs, var=adata.var.copy(), uns=new_uns) + + # Warn if there are row-dimensioned fields we didn't aggregate; user can + # decide whether to drop them. + skipped = [name for name in ("obsm", "obsp", "layers") if getattr(adata, name, None)] + if skipped: + logg.warning( + f"AnnData has {skipped}; these were not aggregated and the " + "resulting table omits them. Pass them through manually if needed." + ) + + return out + + +def make_stitched_labels( + sdata: sd.SpatialData, + labels_key: str, + qc_table_key: str | None = None, + labels_key_added: str | None = None, + table_key_added: str | None = None, + write_table: bool = True, + merge_strategy: str | Callable[[pd.Series], object] = "sum", + join_labels: bool = False, + join_close_radius: int = 3, + inplace: bool = True, +) -> dict[str, object] | None: + """Materialise a stitched labels element from an assign_stitch_groups result. + + Reads the ``stitch_group_id`` mapping in the QC table, builds a lazy + int->int LUT, and registers a new labels element where each stitched + group shares a single ID. The original labels element is **not** + modified. 
+ + Optionally also writes a companion AnnData (``write_table=True``) with one + row per unique ``stitch_group_id`` -- unstitched cells keep their row + unchanged, stitched groups (size 2-4) collapse via ``merge_strategy``. + + Parameters + ---------- + sdata + :class:`~spatialdata.SpatialData` with a labels element and a QC + table that has been processed by + :func:`squidpy.experimental.tl.assign_stitch_groups`. + labels_key + Key in ``sdata.labels`` of the original labels element. + qc_table_key + Key of the QC table. Defaults to ``"{labels_key}_qc"``. + labels_key_added + Key for the new labels element. Defaults to + ``"{labels_key}_stitched"``. Existing element at this key is + overwritten with a warning. + table_key_added + Key for the optional collapsed AnnData (one row per unique + ``stitch_group_id``). Defaults to ``"{labels_key_added}_table"`` + (must differ from the labels element key -- SpatialData requires + unique names across element types). + write_table + If ``True``, also write the collapsed AnnData to + ``sdata.tables[table_key_added]``. + merge_strategy + How to aggregate numeric ``.obs`` columns and ``.X`` across the + 2-4 pieces of each stitched cell. String options: + ``"sum"`` (default), ``"min"``, ``"max"``, ``"mean"``, ``"median"``, + ``"first"``. Callable: receives a :class:`pandas.Series` (one + column of one group's members) and returns a scalar; applied + column-wise. + + ``"sum"`` is the right default for additive features (area, + intensity); for centroids, scores, or fractions, override with + ``"mean"`` or pass a callable. + + Two classes of columns are **always** taken from the first member + regardless of ``merge_strategy`` (including callables): + + - Group-invariant columns -- ``stitch_group_id``, ``is_stitched``, + ``n_pieces``, ``stitch_confidence``, ``region`` -- because every + member of a group already shares the same value. + - Non-numeric columns (strings, categoricals) -- because + ``sum`` / ``mean`` / etc.
don't have a meaningful interpretation. + join_labels + If ``True``, morphologically close the gap between pieces of each + stitched group so the resulting labels are single connected + components instead of multi-component regions sharing an ID. Only + background pixels inside each group's closed hull are filled; + other cells are never overwritten. Dask-backed labels stay lazy: + closing runs per chunk with a ``join_close_radius + 2`` px overlap. + Default ``False`` preserves the original gap pixels. + join_close_radius + Radius (px) of the disk structuring element used when + ``join_labels=True``. Default ``3`` matches the closing radius + used during scoring; raise it if pieces remain disconnected after + joining. + inplace + If ``True`` (default), write the new labels element (and table when + ``write_table=True``) into ``sdata``. If ``False``, return the + materialised objects in a dict ``{"labels": ..., "table": ...}`` + without mutating ``sdata``; ``"table"`` is ``None`` when + ``write_table=False``. + """ + if labels_key not in sdata.labels: + raise ValueError(f"Labels key '{labels_key}' not found in sdata.labels.") + table_key = qc_table_key if qc_table_key is not None else f"{labels_key}_qc" + if table_key not in sdata.tables: + raise ValueError(f"QC table '{table_key}' not found in sdata.tables.") + adata = sdata.tables[table_key] + required = ("label_id", "stitch_group_id", "is_stitched") + missing = [c for c in required if c not in adata.obs.columns] + if missing: + raise ValueError( + f"QC table '{table_key}' is missing {missing}; run squidpy.experimental.tl.assign_stitch_groups first."
+ ) + + qc_params = adata.uns.get("tiling_qc", {}) + scale = qc_params.get("scale") + labels_da = resolve_labels_array(sdata, labels_key, scale) + + lut = _build_lookup(adata.obs, labels_da.dtype) + new_data = _apply_lut(labels_da, lut) + if join_labels: + stitched_gids = adata.obs.loc[adata.obs["is_stitched"].astype(bool), "stitch_group_id"].astype(int).unique() + new_data = _join_stitched_labels(new_data, {int(g) for g in stitched_gids}, close_radius=join_close_radius) + + out_key = labels_key_added if labels_key_added is not None else f"{labels_key}_stitched" + new_labels = Labels2DModel.parse( + data=new_data, + dims=("y", "x"), + transformations=get_transformation(sdata.labels[labels_key], get_all=True), + ) + new_table = None + if write_table: + collapsed = _collapse_groups(adata, out_key, merge_strategy) + new_table = TableModel.parse(collapsed) + + if not inplace: + return {"labels": new_labels, "table": new_table} + + if out_key in sdata.labels: + logg.warning(f"Overwriting existing labels element '{out_key}'.") + sdata.labels[out_key] = new_labels + + if new_table is not None: + tbl_key = table_key_added if table_key_added is not None else f"{out_key}_table" + if tbl_key in sdata.tables: + logg.warning(f"Overwriting existing table '{tbl_key}'.") + sdata.tables[tbl_key] = new_table + return None diff --git a/tests/experimental/test_stitched_labels.py b/tests/experimental/test_stitched_labels.py new file mode 100644 index 000000000..53013e019 --- /dev/null +++ b/tests/experimental/test_stitched_labels.py @@ -0,0 +1,348 @@ +"""Tests for sq.experimental.im.make_stitched_labels.""" + +from __future__ import annotations + +import numpy as np +import pytest + +import squidpy as sq + + +def _qc_and_stitch(sdata, **stitch_kwargs): + sq.experimental.tl.calculate_tiling_qc( + sdata, + labels_key="labels", + tile_size=200, + nmads_cut=1.0, + nmads_smoothed=1.5, + ) + sq.experimental.tl.assign_stitch_groups(sdata, labels_key="labels", **stitch_kwargs) + + +class 
TestMakeStitchedLabels: + def test_creates_new_labels_element(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + assert "labels_stitched" not in sdata.labels + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + assert "labels_stitched" in sdata.labels + + def test_original_labels_unchanged(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + original_arr = np.asarray(sdata.labels["labels"].values).copy() + _qc_and_stitch(sdata) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + after_arr = np.asarray(sdata.labels["labels"].values) + np.testing.assert_array_equal(original_arr, after_arr) + + def test_remap_unifies_stitched_pieces(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + stitched = adata.obs[adata.obs["is_stitched"].astype(bool)] + if len(stitched) == 0: + pytest.skip("no stitched cells in this fixture realisation") + + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + new_arr = np.asarray(sdata.labels["labels_stitched"].values) + old_arr = np.asarray(sdata.labels["labels"].values) + + # Pick one stitched group with >= 2 pieces + gid = int(stitched["stitch_group_id"].iloc[0]) + pieces = stitched.loc[stitched["stitch_group_id"] == gid, "label_id"].astype(int).tolist() + assert len(pieces) >= 2 + + # All original pixels of those pieces should now carry the group id + for piece_id in pieces: + mask = old_arr == piece_id + assert mask.any() + assert (new_arr[mask] == gid).all(), f"piece {piece_id} not remapped to {gid}" + + def test_unstitched_pieces_keep_their_id(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + old_arr = np.asarray(sdata.labels["labels"].values) + new_arr = np.asarray(sdata.labels["labels_stitched"].values) + # Pixels with label 0 (background) 
stay 0 + bg = old_arr == 0 + assert (new_arr[bg] == 0).all() + # Cells whose group_id == label_id are unchanged in the remap + adata = sdata.tables["labels_qc"] + unstitched = adata.obs[adata.obs["stitch_group_id"].astype(int) == adata.obs["label_id"].astype(int)] + # Spot-check the first 5 unstitched + for lid in unstitched["label_id"].astype(int).iloc[:5]: + mask = old_arr == lid + if mask.any(): + assert (new_arr[mask] == lid).all() + + def test_collapsed_table_one_row_per_group(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True) + assert "labels_stitched_table" in sdata.tables + agg = sdata.tables["labels_stitched_table"] + adata = sdata.tables["labels_qc"] + # Output has one row per unique stitch_group_id (unstitched cells stay + # as singleton groups, stitched groups collapse to one row). + n_groups = adata.obs["stitch_group_id"].nunique() + assert agg.n_obs == n_groups + for col in ("label_id", "stitch_group_id", "n_pieces", "is_stitched", "stitch_confidence"): + assert col in agg.obs.columns + + def test_collapsed_table_includes_unstitched_cells(self, sdata_tile_boundary): + """Both stitched (collapsed) and unstitched (passthrough) rows present.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True) + agg = sdata.tables["labels_stitched_table"] + # At least some unstitched cells should be in the output. + assert (~agg.obs["is_stitched"].astype(bool)).sum() > 0, "expected unstitched rows" + # The is_stitched column flags which rows are collapsed groups. 
+ if agg.obs["is_stitched"].astype(bool).sum() > 0: + assert (agg.obs.loc[agg.obs["is_stitched"].astype(bool), "n_pieces"] >= 2).all() + + def test_merge_strategy_sum_aggregates_numeric_columns(self, sdata_tile_boundary): + """For a stitched group, a synthetic numeric column should sum across pieces.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + adata.obs["fake_area"] = 100.0 + sdata.tables["labels_qc"] = adata + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, merge_strategy="sum") + agg = sdata.tables["labels_stitched_table"] + stitched = agg.obs[agg.obs["is_stitched"].astype(bool)] + if len(stitched) == 0: + pytest.skip("no stitched groups in this realisation") + # Each stitched group has n_pieces members each contributing 100. + np.testing.assert_array_equal( + stitched["fake_area"].to_numpy(), + stitched["n_pieces"].to_numpy() * 100.0, + ) + + def test_merge_strategy_mean_aggregates_numeric_columns(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + adata.obs["fake_intensity"] = 42.0 + sdata.tables["labels_qc"] = adata + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, merge_strategy="mean") + agg = sdata.tables["labels_stitched_table"] + stitched = agg.obs[agg.obs["is_stitched"].astype(bool)] + if len(stitched) > 0: + np.testing.assert_allclose(stitched["fake_intensity"].to_numpy(), 42.0) + + def test_merge_strategy_callable(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + adata.obs["fake_count"] = 1 + sdata.tables["labels_qc"] = adata + sq.experimental.im.make_stitched_labels( + sdata, + labels_key="labels", + write_table=True, + merge_strategy=lambda s: len(s), + ) + agg = sdata.tables["labels_stitched_table"] + # Callable returns 
len of group, so fake_count == n_pieces post-merge. + np.testing.assert_array_equal( + agg.obs["fake_count"].to_numpy(), + agg.obs["n_pieces"].to_numpy(), + ) + + def test_group_invariant_columns_take_first(self, sdata_tile_boundary): + """is_stitched, n_pieces, stitch_confidence are not affected by sum strategy.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata_orig = sdata.tables["labels_qc"] + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, merge_strategy="sum") + agg = sdata.tables["labels_stitched_table"] + # n_pieces should be in {1, 2, 3, 4} -- if "sum" had been applied to it, + # a 4-piece group would show n_pieces = 16. + assert (agg.obs["n_pieces"].astype(int) <= 4).all() + # Members of a stitch group share is_stitched value; collapsed row should match. + stitched = adata_orig.obs[adata_orig.obs["is_stitched"].astype(bool)] + for gid in stitched["stitch_group_id"].astype(int).unique(): + row = agg.obs[agg.obs["stitch_group_id"].astype(int) == gid] + assert len(row) == 1 + assert bool(row["is_stitched"].iloc[0]) is True + + def test_aggregated_table_preserves_qc_columns_and_uns(self, sdata_tile_boundary): + """The reduced table must keep the QC table's obs columns and uns + instead of constructing a fresh AnnData from scratch.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + # User adds a custom obs column to simulate downstream annotation. 
+ adata.obs["my_custom_flag"] = True + sdata.tables["labels_qc"] = adata + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True) + agg = sdata.tables["labels_stitched_table"] + # Original QC obs columns survive + for col in ( + "max_straight_edge_ratio", + "cardinal_alignment_score", + "cut_score", + "smoothed_cut_score", + "is_outlier", + "nhood_outlier_fraction", + "centroid_y", + "centroid_x", + "my_custom_flag", + ): + assert col in agg.obs.columns, f"missing preserved column: {col}" + # Uns surfaces survive (tiling_qc params, tiling_stitch params) + assert "tiling_qc" in agg.uns + assert "tiling_stitch" in agg.uns + # spatialdata_attrs now points at the stitched labels element + attrs = agg.uns["spatialdata_attrs"] + assert attrs["region"] == "labels_stitched" + assert attrs["instance_key"] == "label_id" + + def test_aggregated_table_label_id_matches_new_element_ids(self, sdata_tile_boundary): + """label_id values in the table must equal the IDs in the new labels + element (the stitch_group_id values become the new instance keys).""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True) + agg = sdata.tables["labels_stitched_table"] + new_arr = np.asarray(sdata.labels["labels_stitched"].values) + unique_in_image = set(np.unique(new_arr).tolist()) - {0} + unique_in_table = set(agg.obs["label_id"].astype(int).tolist()) + # Every row in the table must reference an existing instance in the labels element. 
+ assert unique_in_table.issubset(unique_in_image), f"orphan rows: {unique_in_table - unique_in_image}" + + @pytest.mark.parametrize( + ("setup", "kwargs", "match"), + [ + ("qc_only", {"labels_key": "labels"}, "stitch_group_id"), + ("qc_and_stitch", {"labels_key": "bogus"}, "not found"), + ("qc_and_stitch", {"labels_key": "labels", "merge_strategy": "bogus"}, "Unknown merge_strategy"), + ], + ids=["stitch_not_run", "missing_labels_key", "invalid_merge_strategy"], + ) + def test_invalid_input_raises(self, sdata_tile_boundary, setup, kwargs, match): + sdata, _ = sdata_tile_boundary + if setup == "qc_only": + sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="labels", tile_size=200) + else: + _qc_and_stitch(sdata) + with pytest.raises(ValueError, match=match): + sq.experimental.im.make_stitched_labels(sdata, **kwargs) + + def test_idempotent(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + first = np.asarray(sdata.labels["labels_stitched"].values).copy() + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + second = np.asarray(sdata.labels["labels_stitched"].values) + np.testing.assert_array_equal(first, second) + + def test_join_labels_false_keeps_multi_component(self, sdata_tile_boundary): + from skimage.measure import label as cc_label + + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + stitched = adata.obs[adata.obs["is_stitched"].astype(bool)] + if len(stitched) == 0: + pytest.skip("no stitched cells in this fixture realisation") + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", join_labels=False) + arr = np.asarray(sdata.labels["labels_stitched"].values) + # At least one stitched group should have >1 connected component + # (the unjoined behaviour leaves the cut stripe as background). 
+ any_multi = False + for gid in stitched["stitch_group_id"].astype(int).unique()[:5]: + mask = arr == gid + if mask.any(): + ncc = int(cc_label(mask).max()) + if ncc > 1: + any_multi = True + break + assert any_multi, "expected at least one multi-component stitched group with join_labels=False" + + def test_join_labels_true_unifies_components(self, sdata_tile_boundary): + from skimage.measure import label as cc_label + + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + stitched = adata.obs[adata.obs["is_stitched"].astype(bool)] + if len(stitched) == 0: + pytest.skip("no stitched cells in this fixture realisation") + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", join_labels=True) + arr = np.asarray(sdata.labels["labels_stitched"].values) + for gid in stitched["stitch_group_id"].astype(int).unique(): + mask = arr == gid + if not mask.any(): + continue + ncc = int(cc_label(mask).max()) + assert ncc == 1, f"group {gid} still has {ncc} components after join_labels=True" + + def test_join_labels_does_not_overwrite_other_cells(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + # Snapshot every non-stitched cell's pixel set before joining, then + # confirm none of those pixels changed identity afterwards. 
+ sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", join_labels=False) + before_arr = np.asarray(sdata.labels["labels_stitched"].values).copy() + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", join_labels=True) + after_arr = np.asarray(sdata.labels["labels_stitched"].values) + non_stitched_gids = ( + adata.obs.loc[~adata.obs["is_stitched"].astype(bool), "stitch_group_id"].astype(int).unique() + ) + for gid in non_stitched_gids[:20]: + before_mask = before_arr == gid + if not before_mask.any(): + continue + # Non-stitched cells must keep all their original pixels. + assert (after_arr[before_mask] == gid).all(), f"non-stitched cell {gid} was overwritten" + + +class TestScaleRework: + """Lock the scale/correctness rework: out-of-range passthrough, lazy join, + sparse-safe and vectorised aggregation.""" + + def test_unmapped_image_label_passes_through(self): + """A label present in the image but absent from the QC table (e.g. a + min_area-filtered cell) must survive the LUT remap, not crash (C1).""" + import pandas as pd + import xarray as xr + + from squidpy.experimental.im._stitched_labels import _apply_lut, _build_lookup + + obs = pd.DataFrame({"label_id": [1, 2], "stitch_group_id": [1, 1]}) + labels = np.array([[0, 1, 2], [5, 5, 0]], dtype=np.int32) # label 5 not in table + lut = _build_lookup(obs, labels.dtype) + out = np.asarray(_apply_lut(xr.DataArray(labels, dims=("y", "x")), lut)) + assert (out[labels == 5] == 5).all(), "unmapped label 5 was not preserved" + assert (out[labels == 2] == 1).all(), "label 2 should remap to group 1" + + def test_join_labels_stays_lazy_on_dask_input(self, sdata_tile_boundary): + import dask.array as da + + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + res = sq.experimental.im.make_stitched_labels( + sdata, labels_key="labels", join_labels=True, write_table=False, inplace=False + ) + # The fixture labels are dask-backed; the joined output must remain lazy. 
+ assert isinstance(res["labels"].data, da.Array), "join_labels materialised the full array" + + @pytest.mark.parametrize("strategy", ["sum", "mean", "first", "max"]) + def test_aggregate_X_sparse_matches_dense(self, strategy): + from scipy import sparse + + from squidpy.experimental.im._stitched_labels import _aggregate_X + + rng = np.random.default_rng(0) + dense = rng.integers(0, 5, size=(6, 4)).astype(np.float64) + groups = [np.array([0, 1, 2]), np.array([3]), np.array([4, 5])] + out_dense = np.asarray(_aggregate_X(dense, groups, strategy)) + out_sparse = _aggregate_X(sparse.csr_matrix(dense), groups, strategy) + # Sparse input must NOT be densified into a dense result. + assert sparse.issparse(out_sparse), "sparse input should yield sparse output" + np.testing.assert_allclose(out_dense, out_sparse.toarray()) From 8a196c8395136e130b5543d343916e1e6b88dcdc Mon Sep 17 00:00:00 2001 From: anon Date: Fri, 12 Jun 2026 16:41:21 +0200 Subject: [PATCH 3/4] refactor: inline single-use tiling-stitch geometry helpers, no behaviour change - inline equivalent_diameter + largest_contour (single consumer _tiling_stitch) into their call site and delete utils/_geometry.py; keep both guards (1px pad precondition + empty-contour skip). - route _resolve_qc_params through the shared resolve_params helper instead of a duplicated Mapping-validation block; drop the now-dead _QC_FIELDS and unused fields import. utils/_params.py now has two genuine callers. 
Co-Authored-By: Claude Fable 5 --- src/squidpy/experimental/tl/_tiling_qc.py | 15 ++-------- src/squidpy/experimental/tl/_tiling_stitch.py | 10 ++++--- src/squidpy/experimental/utils/_geometry.py | 30 ------------------- 3 files changed, 9 insertions(+), 46 deletions(-) delete mode 100644 src/squidpy/experimental/utils/_geometry.py diff --git a/src/squidpy/experimental/tl/_tiling_qc.py b/src/squidpy/experimental/tl/_tiling_qc.py index 287610a45..1514d35ac 100644 --- a/src/squidpy/experimental/tl/_tiling_qc.py +++ b/src/squidpy/experimental/tl/_tiling_qc.py @@ -28,7 +28,7 @@ import math from collections.abc import Mapping -from dataclasses import asdict, dataclass, fields +from dataclasses import asdict, dataclass from typing import Any, Literal import anndata as ad @@ -54,6 +54,7 @@ ) from squidpy.experimental.tl._tiling_stitch import _STITCH_COLUMNS, _STITCH_PARAM_KEYS, StitchParams from squidpy.experimental.utils._labels import resolve_labels_array +from squidpy.experimental.utils._params import resolve_params __all__ = ["TilingQCParams", "calculate_tiling_qc"] @@ -92,23 +93,13 @@ def __post_init__(self) -> None: _QC_DEFAULTS = TilingQCParams() -_QC_FIELDS = frozenset(f.name for f in fields(TilingQCParams)) def _resolve_qc_params(qc_params: TilingQCParams | Mapping[str, Any] | None) -> TilingQCParams: """Normalise the ``tiling_qc_params`` argument to a :class:`TilingQCParams` instance.""" if qc_params is None: return _QC_DEFAULTS - if isinstance(qc_params, TilingQCParams): - return qc_params - if isinstance(qc_params, Mapping): - unknown = set(qc_params) - _QC_FIELDS - if unknown: - raise ValueError( - f"Unknown `tiling_qc_params` field(s): {sorted(unknown)}; expected from {sorted(_QC_FIELDS)}." 
- ) - return TilingQCParams(**qc_params) - raise TypeError(f"`tiling_qc_params` must be TilingQCParams, Mapping, or None; got {type(qc_params).__name__}.") + return resolve_params(qc_params, TilingQCParams, label="`tiling_qc_params`") # Standard consistency factor sd ~ 1.4826 x MAD for normal distributions. diff --git a/src/squidpy/experimental/tl/_tiling_stitch.py b/src/squidpy/experimental/tl/_tiling_stitch.py index 0c8b69cb1..859dbc0b0 100644 --- a/src/squidpy/experimental/tl/_tiling_stitch.py +++ b/src/squidpy/experimental/tl/_tiling_stitch.py @@ -31,12 +31,12 @@ from scipy.ndimage import binary_closing from scipy.sparse import csr_matrix from scipy.sparse.csgraph import connected_components +from skimage.measure import find_contours from skimage.measure import label as cc_label from skimage.measure import regionprops from skimage.morphology import disk as morph_disk from spatialdata._logging import logger as logg -from squidpy.experimental.utils._geometry import equivalent_diameter, largest_contour from squidpy.experimental.utils._labels import iter_chunked_regionprops, resolve_labels_array from squidpy.experimental.utils._params import resolve_params @@ -302,10 +302,12 @@ def _extract_cut_edges( if not cell_mask.any(): continue outlier_crops[lid] = cell_mask + # 1px zero-pad so cells filling their bbox still trace a closed contour. 
mask = np.pad(cell_mask.astype(np.float32), 1, mode="constant", constant_values=0) - contour = largest_contour(mask) - if contour is None: + contours = find_contours(mask, 0.5) + if not contours: # degenerate mask traces nothing; skip it continue + contour = max(contours, key=len) contour_global = contour.copy() contour_global[:, 0] += min_r - 1 contour_global[:, 1] += min_c - 1 @@ -315,7 +317,7 @@ def _extract_cut_edges( cy = float(ys.mean()) + min_r - 1 cx = float(xs.mean()) + min_c - 1 area = float(mask.sum()) - eq_diameter = equivalent_diameter(area) + eq_diameter = float(np.sqrt(4 * area / np.pi)) # diameter of the equal-area circle min_len = max(min_edge_length, min_edge_length_ratio * eq_diameter) # find_contours places level set 0.5 outside the integer pixel boundary. diff --git a/src/squidpy/experimental/utils/_geometry.py b/src/squidpy/experimental/utils/_geometry.py deleted file mode 100644 index 4f6c152c2..000000000 --- a/src/squidpy/experimental/utils/_geometry.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Shared internal geometry helpers for mask/contour analysis. - -Not part of the public API - symbols here are private and may change -without notice. -""" - -from __future__ import annotations - -import numpy as np -from skimage.measure import find_contours - - -def equivalent_diameter(area: float) -> float: - """Diameter of the circle with the given area: ``sqrt(4 * area / pi)``.""" - return float(np.sqrt(4 * area / np.pi)) - - -def largest_contour(padded_mask: np.ndarray, level: float = 0.5) -> np.ndarray | None: - """Return the longest :func:`skimage.measure.find_contours` contour, or ``None``. - - The mask must be **already 1px zero-padded** by the caller so that cells - touching the crop edge (e.g. filling their bbox) are traced closed. Padding - is left to the caller because its placement relative to other steps (e.g. - downsampling) is order-sensitive and differs between call sites. Returned - coordinates are in the padded mask's frame. 
- """ - contours = find_contours(padded_mask, level) - if not contours: - return None - return max(contours, key=len) From 9a6f45dbcec96b3aa57dd2b2fe19f029f0f9b086 Mon Sep 17 00:00:00 2001 From: anon Date: Fri, 12 Jun 2026 16:54:13 +0200 Subject: [PATCH 4/4] fix(experimental): correct make_stitched_labels aggregation, multiscale + validation Addresses a max-effort review of the stitched-labels path: - collapsed table no longer sums QC-contract columns: centroid_y/x take the mean of the pieces (a merged cell's position), the per-piece cut-artifact scores take the max (worst piece); merge_strategy now applies only to genuine additive user features. Drops the dtype-restore that re-narrowed/truncated promoted sums and means. - _aggregate_X: sum no longer treated as integer-preserving, so an integer .X promotes to float64 instead of wrapping (uint16 200000 -> 3392); a callable strategy is now applied to singleton groups too (obs/X parity). - multiscale: parse the new element with the resolved scale's transform (not the DataTree base/Identity), so the output overlays correctly; documented as a single-scale element at the QC scale's resolution. - validate label_id (positive, unique, non-NaN) and merge_strategy up front, so bad input fails fast even with write_table=False. - preserve varm (var axis is unchanged); correct the join_labels docstring (it is lazy via map_overlap and bridges seams up to 2*close_radius px, not "forces materialisation"). - add TestReviewFixes regression locks (each fails on the pre-fix code). 
Co-Authored-By: Claude Fable 5 --- docs/release/notes-dev.md | 1 + .../experimental/im/_stitched_labels.py | 166 ++++++++++++------ src/squidpy/experimental/tl/_tiling_stitch.py | 3 +- tests/experimental/test_stitched_labels.py | 137 +++++++++++++++ 4 files changed, 254 insertions(+), 53 deletions(-) diff --git a/docs/release/notes-dev.md b/docs/release/notes-dev.md index e5fda1ca7..eb0589df4 100644 --- a/docs/release/notes-dev.md +++ b/docs/release/notes-dev.md @@ -2,6 +2,7 @@ ## Features +- Add {func}`squidpy.experimental.im.make_stitched_labels` to materialise a stitched labels element (and an optional collapsed table) from an {func}`squidpy.experimental.tl.assign_stitch_groups` result, completing the tile-cut stitching workflow. - Fix {func}`squidpy.tl.var_by_distance` behaviour when providing {mod}`numpy` arrays of coordinates as anchor point. - Update :attr:`squidpy.pl.var_by_distance` to show multiple variables on same plot. [@LLehner](https://github.com/LLehner) diff --git a/src/squidpy/experimental/im/_stitched_labels.py b/src/squidpy/experimental/im/_stitched_labels.py index f1b48ad5c..3952394df 100644 --- a/src/squidpy/experimental/im/_stitched_labels.py +++ b/src/squidpy/experimental/im/_stitched_labels.py @@ -168,6 +168,23 @@ def _close_block(block: np.ndarray) -> np.ndarray: # take the first member's value rather than aggregating. _GROUP_INVARIANT_COLS = frozenset({"stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence", "region"}) +# QC-contract columns that calculate_tiling_qc always writes and that must NOT be +# summed: a stitched cell's centroid is the (unweighted) mean of its pieces', and +# the per-piece cut-artifact scores keep the group's worst (max) value. Summing +# either produces out-of-range / out-of-bounds garbage (the merge_strategy default +# is "sum", aimed at additive *user* features like area/intensity). 
+_CENTROID_COLS = frozenset({"centroid_y", "centroid_x"}) +_QC_SCORE_COLS = frozenset( + { + "max_straight_edge_ratio", + "cardinal_alignment_score", + "cut_score", + "smoothed_cut_score", + "nhood_outlier_fraction", + "is_outlier", + } +) + def _resolve_strategy(strategy: str | Callable[[pd.Series], object]) -> Callable[[pd.Series], object]: if callable(strategy): @@ -179,7 +196,10 @@ def _resolve_strategy(strategy: str | Callable[[pd.Series], object]) -> Callable return _BUILTIN_STRATEGIES[strategy] -_INTEGER_PRESERVING_STRATEGIES = frozenset({"sum", "min", "max", "first"}) +# Strategies whose result is always one of the input values, so an integer +# input dtype is safe to keep. ``sum`` is deliberately excluded: it can exceed +# the input range and would wrap/saturate on cast-back (e.g. uint16 200000 -> 3392). +_INTEGER_PRESERVING_STRATEGIES = frozenset({"min", "max", "first"}) def _aggregate_X( @@ -196,9 +216,11 @@ def _aggregate_X( ``max`` / ``median`` / callables) pass singleton groups through and densify only each multi-member group's small block. - For integer ``X`` the output dtype is preserved only for integer-safe - strategies (``sum``, ``min``, ``max``, ``first``); mean/median and callables - promote to ``float64`` so a uint16 count matrix doesn't truncate. + For integer ``X`` the output dtype is preserved only for range-preserving + strategies (``min`` / ``max`` / ``first``, whose result is always an input + value); ``sum`` (which can exceed the input range), ``mean`` / ``median`` + and callables promote to ``float64`` so a uint16 count matrix neither + overflows nor truncates. 
""" n_groups = len(group_indices) n_cols = X.shape[1] @@ -236,7 +258,10 @@ def _aggregate_X( Xc = X.tocsr() if sparse_in else np.asarray(X) out = sp.lil_matrix((n_groups, n_cols), dtype=out_dtype) if sparse_in else np.zeros((n_groups, n_cols), out_dtype) for i, idx in enumerate(group_indices): - if len(idx) == 1: + # A builtin reducer on a 1-row block returns that row, so the singleton + # short-circuit is exact -- but a *callable* must still be applied (it may + # be non-idempotent, e.g. len), matching the obs aggregation path. + if reducer is not None and len(idx) == 1: out[i] = Xc[idx[0]].toarray().ravel() if sparse_in else Xc[idx[0]] continue block = Xc[idx].toarray() if sparse_in else Xc[idx] @@ -265,16 +290,23 @@ def _collapse_groups( - ``label_id``: rewritten to the group id (matches new labels element). - ``stitch_group_id``, ``is_stitched``, ``n_pieces``, ``stitch_confidence``, ``region``: members agree -> first value. - - Other numeric obs columns and all of ``X``: ``merge_strategy`` (default - ``"sum"``). Built-ins: ``sum``, ``min``, ``max``, ``mean``, ``median``, - ``first``. A callable receives a :class:`pandas.Series` and returns a - scalar; it's applied column-wise to both ``.obs`` and ``.X``. - - Non-numeric obs columns: ``"first"`` regardless of ``merge_strategy`` - (sum/mean don't make sense for strings/categoricals). - - Note: ``merge_strategy="sum"`` is the right default for additive features - (area, intensity, count) but wrong for centroids, scores, fractions. - Override accordingly for those. + - ``centroid_y`` / ``centroid_x``: mean of the pieces' centroids (the merged + cell's position) -- never summed. + - QC cut-artifact scores (``cut_score``, ``smoothed_cut_score``, + ``max_straight_edge_ratio``, ``cardinal_alignment_score``, + ``nhood_outlier_fraction``, ``is_outlier``): max -> keep the worst piece's + value; summing per-piece diagnostics is meaningless. 
+ - Remaining numeric obs columns (genuine user features) and all of ``X``: + ``merge_strategy`` (default ``"sum"``). Built-ins: ``sum``, ``min``, + ``max``, ``mean``, ``median``, ``first``. A callable receives a + :class:`pandas.Series` and returns a scalar; applied column-wise to both + ``.obs`` and ``.X``. + - Non-numeric obs columns: first member's value regardless of + ``merge_strategy`` (sum/mean don't make sense for strings/categoricals). + + ``merge_strategy="sum"`` is the right default for additive user features + (area, intensity, count); the QC contract columns above are handled + automatically and ignore it. .. warning:: ``.obsm``, ``.obsp``, ``.layers`` are passed through but not @@ -296,26 +328,31 @@ def _collapse_groups( indices_by_group = np.split(order, first_idx[1:]) # ---- Aggregate obs via vectorised groupby ---- - # Group-invariant + non-numeric columns take the first member's value; - # numeric columns use merge_strategy. label_id is set to the group id. + # Column policy: invariant + non-numeric -> first member; centroids -> mean + # (the merged cell's position); QC cut-artifact scores -> max (keep the worst + # piece); remaining numeric (genuine user features) -> merge_strategy. + # label_id is set to the group id. We deliberately do NOT cast aggregated + # columns back to the source dtype: a summed/averaged int column must keep its + # promoted (int64/float) dtype or it would overflow / truncate. 
cols = [c for c in obs.columns if c != "label_id"] numeric_cols = [c for c in cols if c not in _GROUP_INVARIANT_COLS and pd.api.types.is_numeric_dtype(obs[c])] + centroid_cols = [c for c in numeric_cols if c in _CENTROID_COLS] + score_cols = [c for c in numeric_cols if c in _QC_SCORE_COLS] + user_cols = [c for c in numeric_cols if c not in _CENTROID_COLS and c not in _QC_SCORE_COLS] first_cols = [c for c in cols if c not in numeric_cols] gb = obs.groupby(group_ids, sort=True) pieces = [] if first_cols: pieces.append(gb[first_cols].first()) - if numeric_cols: - pieces.append(gb[numeric_cols].agg(merge_strategy)) + if centroid_cols: + pieces.append(gb[centroid_cols].mean()) + if score_cols: + pieces.append(gb[score_cols].max()) + if user_cols: + pieces.append(gb[user_cols].agg(merge_strategy)) new_obs = pd.concat(pieces, axis=1) if pieces else pd.DataFrame(index=unique_groups) new_obs["label_id"] = unique_groups new_obs = new_obs[list(obs.columns)] - # Preserve dtypes where possible (agg can promote/lose categorical). - for col in new_obs.columns: - try: - new_obs[col] = new_obs[col].astype(obs[col].dtype) - except (TypeError, ValueError): - pass # Update the region column to point at the new labels element. if "region" in new_obs.columns: new_obs["region"] = pd.Categorical([new_labels_key] * len(new_obs)) @@ -336,6 +373,11 @@ def _collapse_groups( } out = ad.AnnData(X=new_X, obs=new_obs, var=adata.var.copy(), uns=new_uns) + # The var axis is unchanged by a row collapse, so varm stays consistent and + # is carried over verbatim (unlike the obs-dimensioned fields below). + for key in getattr(adata, "varm", {}): + out.varm[key] = adata.varm[key].copy() + # Warn if there are row-dimensioned fields we didn't aggregate; user can # decide whether to drop them. 
skipped = [name for name in ("obsm", "obsp", "layers") if getattr(adata, name, None)] @@ -394,33 +436,34 @@ def make_stitched_labels( If ``True``, also write the collapsed AnnData to ``sdata.tables[table_key_added]``. merge_strategy - How to aggregate numeric ``.obs`` columns and ``.X`` across the - 2-4 pieces of each stitched cell. String options: - ``"sum"`` (default), ``"min"``, ``"max"``, ``"mean"``, ``"median"``, - ``"first"``. Callable: receives a :class:`pandas.Series` (one - column of one group's members) and returns a scalar; applied - column-wise. - - ``"sum"`` is the right default for additive features (area, - intensity); for centroids, scores, or fractions, override with - ``"mean"`` or pass a callable. - - Two classes of columns are **always** taken from the first member - regardless of ``merge_strategy`` (including callables): - - - Group-invariant columns -- ``stitch_group_id``, ``is_stitched``, - ``n_pieces``, ``stitch_confidence``, ``region`` -- because every - member of a group already shares the same value. - - Non-numeric columns (strings, categoricals, booleans) -- because - ``sum`` / ``mean`` / etc. don't have a meaningful interpretation. + How to aggregate genuine numeric user feature columns in ``.obs`` and + all of ``.X`` across the 2-4 pieces of each stitched cell. String + options: ``"sum"`` (default, for additive features like area / + intensity), ``"min"``, ``"max"``, ``"mean"``, ``"median"``, ``"first"``. + Callable: receives a :class:`pandas.Series` (one column of one group's + members) and returns a scalar; applied column-wise to ``.obs`` and ``.X``. + + ``merge_strategy`` does **not** apply to the columns below, which are + handled automatically (so a stray ``"sum"`` can't corrupt them): + + - Group-invariant columns (``stitch_group_id``, ``is_stitched``, + ``n_pieces``, ``stitch_confidence``, ``region``) and any non-numeric + column -> first member's value. + - ``centroid_y`` / ``centroid_x`` -> mean of the pieces' centroids. 
+ - QC cut-artifact scores (``cut_score``, ``smoothed_cut_score``, + ``max_straight_edge_ratio``, ``cardinal_alignment_score``, + ``nhood_outlier_fraction``, ``is_outlier``) -> max (worst piece). join_labels If ``True``, morphologically close the gap between pieces of each - stitched group so the resulting labels are single connected - components instead of multi-component regions sharing an ID. Only - background pixels inside each group's closed hull are filled; - other cells are never overwritten. **Forces materialisation of - the labels array** -- cost is O(image_size) plus O(stitched x bbox). - Default ``False`` preserves the original gap pixels. + stitched group, so a group that the basic remap leaves as several + components sharing an ID becomes a single connected component. Only + background pixels inside each group's closed hull are filled; other + cells are never overwritten. This stays lazy on dask input (a + per-block :func:`dask.array.map_overlap` pass, never materialising the + whole image). Closing bridges seams up to ``2 * join_close_radius`` px + wide; pieces separated by a wider gap stay disconnected -- raise + ``join_close_radius`` for those. Default ``False`` preserves the + original gap pixels. join_close_radius Radius (px) of the disk structuring element used when ``join_labels=True``. Default ``3`` matches the closing radius @@ -445,6 +488,23 @@ def make_stitched_labels( raise ValueError( f"QC table '{table_key}' is missing {missing}; run squidpy.experimental.tl.assign_stitch_groups first." ) + # Validate merge_strategy up front so an invalid value fails fast even when + # write_table=False (the aggregation that would otherwise raise is skipped). + _resolve_strategy(merge_strategy) + # label_id is the instance key that drives the LUT remap; a NaN, non-positive, + # or duplicated id would crash cryptically or silently mis-map pixels (0 is the + # background sentinel; duplicates make lut[label_id]=group_id keep only the last). 
+ label_id = adata.obs["label_id"] + if label_id.isna().any(): + raise ValueError(f"QC table '{table_key}' has NaN in 'label_id'; cannot build the relabel lookup.") + label_id = label_id.astype(np.int64) + if (label_id <= 0).any(): + raise ValueError( + f"QC table '{table_key}' has non-positive 'label_id' (0 is the background sentinel); " + "label ids must be positive instance keys." + ) + if label_id.duplicated().any(): + raise ValueError(f"QC table '{table_key}' has duplicate 'label_id' values; each cell must appear once.") qc_params = adata.uns.get("tiling_qc", {}) scale = qc_params.get("scale") @@ -457,10 +517,14 @@ def make_stitched_labels( new_data = _join_stitched_labels(new_data, {int(g) for g in stitched_gids}, close_radius=join_close_radius) out_key = labels_key_added if labels_key_added is not None else f"{labels_key}_stitched" + # Take the transform from the RESOLVED array (the chosen scale's DataArray), + # not the DataTree: a multi-scale element's base transform is Identity, but the + # resolved level carries the Scale that maps it back to global coordinates. The + # output is a single-scale element at the QC scale's resolution (see docstring). 
new_labels = Labels2DModel.parse( data=new_data, dims=("y", "x"), - transformations=get_transformation(sdata.labels[labels_key], get_all=True), + transformations=get_transformation(labels_da, get_all=True), ) new_table = None if write_table: diff --git a/src/squidpy/experimental/tl/_tiling_stitch.py b/src/squidpy/experimental/tl/_tiling_stitch.py index 859dbc0b0..7df08c629 100644 --- a/src/squidpy/experimental/tl/_tiling_stitch.py +++ b/src/squidpy/experimental/tl/_tiling_stitch.py @@ -31,9 +31,8 @@ from scipy.ndimage import binary_closing from scipy.sparse import csr_matrix from scipy.sparse.csgraph import connected_components -from skimage.measure import find_contours +from skimage.measure import find_contours, regionprops from skimage.measure import label as cc_label -from skimage.measure import regionprops from skimage.morphology import disk as morph_disk from spatialdata._logging import logger as logg diff --git a/tests/experimental/test_stitched_labels.py b/tests/experimental/test_stitched_labels.py index 53013e019..d2b101145 100644 --- a/tests/experimental/test_stitched_labels.py +++ b/tests/experimental/test_stitched_labels.py @@ -346,3 +346,140 @@ def test_aggregate_X_sparse_matches_dense(self, strategy): # Sparse input must NOT be densified into a dense result. 
assert sparse.issparse(out_sparse), "sparse input should yield sparse output" np.testing.assert_allclose(out_dense, out_sparse.toarray()) + + +class TestReviewFixes: + """Regression locks for the review findings (each fails on the pre-fix code).""" + + def test_centroid_is_mean_not_sum_and_in_bounds(self, sdata_tile_boundary): + """M1: QC contract columns must not be summed by the default merge_strategy.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + qc = sdata.tables["labels_qc"].obs + stitched = qc[qc["is_stitched"].astype(bool)] + if stitched.empty: + pytest.skip("no stitched cells in this realisation") + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, merge_strategy="sum") + agg = sdata.tables["labels_stitched_table"].obs + h, w = np.asarray(sdata.labels["labels"].values).shape[-2:] + gid = int(stitched["stitch_group_id"].iloc[0]) + members = qc[qc["stitch_group_id"].astype(int) == gid] + assert len(members) >= 2 + row = agg[agg["stitch_group_id"].astype(int) == gid] + assert len(row) == 1 + # centroid -> mean of pieces (a sum would roughly double it and leave the image) + np.testing.assert_allclose(row["centroid_y"].iloc[0], members["centroid_y"].mean(), rtol=1e-6) + np.testing.assert_allclose(row["centroid_x"].iloc[0], members["centroid_x"].mean(), rtol=1e-6) + assert 0.0 <= row["centroid_y"].iloc[0] <= h + assert 0.0 <= row["centroid_x"].iloc[0] <= w + # cut_score -> max of pieces (a sum would push it past its natural range) + np.testing.assert_allclose(row["cut_score"].iloc[0], members["cut_score"].max(), rtol=1e-6) + + def test_aggregate_X_integer_sum_no_overflow(self): + """M3: summing an integer .X must not wrap on cast-back to the input dtype.""" + from scipy import sparse + + from squidpy.experimental.im._stitched_labels import _aggregate_X + + X = np.array([[40000], [30000], [20000]], dtype=np.uint16) # sum 90000 > uint16 max + groups = [np.array([0, 1, 2])] + dense = 
np.asarray(_aggregate_X(X, groups, "sum")) + assert dense[0, 0] == 90000, "uint16 sum wrapped (90000 % 65536 == 24464)" + sp_out = _aggregate_X(sparse.csr_matrix(X), groups, "sum") + assert sp_out.toarray()[0, 0] == 90000 + + def test_aggregate_X_callable_applied_to_singletons(self): + """M5: a callable strategy must be applied to singleton groups too (obs/X parity).""" + from squidpy.experimental.im._stitched_labels import _aggregate_X + + X = np.array([[10.0], [20.0], [30.0]]) + groups = [np.array([0]), np.array([1, 2])] # group 0 is a singleton + out = np.asarray(_aggregate_X(X, groups, lambda s: float(len(s)))) + assert out[0, 0] == 1.0, "callable bypassed for singleton (got the raw row, not len)" + assert out[1, 0] == 2.0 + + def test_int_obs_column_mean_not_truncated(self, sdata_tile_boundary): + """M6: aggregating an integer obs column with mean must keep the fractional value.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + qc = sdata.tables["labels_qc"] + qc.obs["int_feature"] = np.arange(1, qc.n_obs + 1, dtype=np.int64) + sdata.tables["labels_qc"] = qc + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, merge_strategy="mean") + agg = sdata.tables["labels_stitched_table"] + assert agg.obs["int_feature"].dtype.kind == "f", "mean of an int column was truncated back to int" + stitched = qc.obs[qc.obs["is_stitched"].astype(bool)] + if not stitched.empty: + gid = int(stitched["stitch_group_id"].iloc[0]) + members = qc.obs[qc.obs["stitch_group_id"].astype(int) == gid] + got = agg.obs.loc[agg.obs["stitch_group_id"].astype(int) == gid, "int_feature"].iloc[0] + np.testing.assert_allclose(got, members["int_feature"].mean(), rtol=1e-6) + + @pytest.mark.parametrize( + ("mutate", "match"), + [("duplicate", "duplicate 'label_id'"), ("zero", "non-positive")], + ids=["duplicate", "zero"], + ) + def test_invalid_label_id_raises(self, sdata_tile_boundary, mutate, match): + """M7: a duplicated / 
non-positive label_id must raise an actionable error, not mis-map. + + (NaN label_id is also guarded in-code but is unreachable through the public + path -- SpatialData's TableModel rejects a null instance key on assignment.) + """ + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + qc = sdata.tables["labels_qc"] + lid = qc.obs["label_id"].to_numpy().copy() + if mutate == "duplicate": + lid[1] = lid[0] + else: + lid[0] = 0 + qc.obs["label_id"] = lid + sdata.tables["labels_qc"] = qc + with pytest.raises(ValueError, match=match): + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + + def test_invalid_merge_strategy_raises_without_table(self, sdata_tile_boundary): + """M9: merge_strategy is validated eagerly even when write_table=False.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + with pytest.raises(ValueError, match="Unknown merge_strategy"): + sq.experimental.im.make_stitched_labels( + sdata, labels_key="labels", write_table=False, merge_strategy="bogus" + ) + + def test_multiscale_output_carries_resolved_scale_transform(self): + """M2/M4: the stitched element must use the resolved scale's transform, not the + DataTree base (Identity), and sit at that scale's resolution.""" + import dask.array as da + import xarray as xr + from spatialdata import SpatialData + from spatialdata.models import Labels2DModel + from spatialdata.transformations import get_transformation + + from squidpy.experimental.utils._labels import resolve_labels_array + from tests.experimental.conftest import make_tile_boundary_sdata + + base, _ = make_tile_boundary_sdata() + arr = np.asarray(base.labels["labels"].values) + ms = Labels2DModel.parse( + xr.DataArray(da.from_array(arr, chunks=(200, 200)), dims=("y", "x")), scale_factors=[2] + ) + sdata = SpatialData(images={"image": base.images["image"]}, labels={"labels": ms}) + sq.experimental.tl.calculate_tiling_qc( + sdata, labels_key="labels", scale="scale1", tile_size=100, nmads_cut=1.0, 
nmads_smoothed=1.5 + ) + sq.experimental.tl.assign_stitch_groups(sdata, labels_key="labels") + res = sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=False, inplace=False) + out = res["labels"] + resolved = resolve_labels_array(sdata, "labels", "scale1") + + def _affine(elem): + t = next(iter(get_transformation(elem, get_all=True).values())) + return t.to_affine_matrix(("y", "x"), ("y", "x")) + + # output overlays at the resolved scale (scale1 carries a 2x Scale to global) + np.testing.assert_allclose(_affine(out), _affine(resolved)) + # the bug attached the DataTree base (Identity) transform instead + assert not np.allclose(_affine(out), _affine(sdata.labels["labels"])) + assert out.shape[-2:] == resolved.shape[-2:]