diff --git a/src/squidpy/experimental/im/_stain/_decomposition.py b/src/squidpy/experimental/im/_stain/_decomposition.py index 939efb615..538de332e 100644 --- a/src/squidpy/experimental/im/_stain/_decomposition.py +++ b/src/squidpy/experimental/im/_stain/_decomposition.py @@ -223,6 +223,48 @@ def fit_decomposition( ) -> StainReference: """Fit a decomposition :class:`StainReference` (stain matrix + max concentrations).""" od = _tissue_od(image_rgb, white_point, params.beta, tissue_mask=tissue_mask, image_key=image_key) + return _reference_from_od( + od, method, params, white_point, image_key=image_key, reference=reference, max_angle_deg=max_angle_deg + ) + + +def fit_decomposition_pooled( + das: list[xr.DataArray], + masks: list[np.ndarray], + image_keys: list[str], + method: StainMethod, + params: Any, + white_point: np.ndarray, + *, + reference: dict[str, np.ndarray] = RUIFROK_HE, + max_angle_deg: float = 45.0, +) -> StainReference: + """Fit a decomposition reference from the pooled tissue OD of several slides. + + Each slide's tissue OD is gathered (naming the slide on empty tissue) and + stacked into one ``(SUM_N, 3)`` array; the single-image fit tail then runs on + the pooled OD. Pooling one slide is identical to :func:`fit_decomposition`. + """ + ods = [ + _tissue_od(da, white_point, params.beta, tissue_mask=m, image_key=k) + for da, m, k in zip(das, masks, image_keys, strict=True) + ] + return _reference_from_od( + np.vstack(ods), method, params, white_point, image_key=None, reference=reference, max_angle_deg=max_angle_deg + ) + + +def _reference_from_od( + od: np.ndarray, + method: StainMethod, + params: Any, + white_point: np.ndarray, + *, + image_key: str | None, + reference: dict[str, np.ndarray], + max_angle_deg: float, +) -> StainReference: + """Shared fit tail: stain matrix (gated) + max concentrations -> StainReference.""" matrix = _stain_matrix(od, method, params, image_key=image_key, reference=reference, max_angle_deg=max_angle_deg) return StainReference( method=method, diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index 2731c9bc4..58350d9ff 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -33,6 +33,7 @@ apply_decomposition, decompose_to_concentrations, fit_decomposition, + fit_decomposition_pooled, ) from squidpy.experimental.im._stain._reference import StainMethod, StainReference from squidpy.experimental.im._stain._reinhard import ( @@ -40,6 +41,7 @@ _resolve_reinhard_params, apply_reinhard, fit_reinhard, + fit_reinhard_pooled, ) from squidpy.experimental.im._stain._white_point import ( default_white_point, @@ -211,13 +213,13 @@ def estimate_white_point( def fit_stain_reference( sdata: sd.SpatialData, - image_key: str, + image_key: str | list[str], *, method: StainMethod = "macenko", scale: str | Literal["auto"] = "auto", method_params: MethodParams = None, white_point: np.ndarray | None = None, - tissue_mask_key: str | None = None, + tissue_mask_key: str | list[str] | None = None, max_angle_deg: float = 45.0, canonical_reference: Mapping[str, np.ndarray] | None = None, ) -> StainReference: @@ -228,7 +230,9 @@ def fit_stain_reference( sdata SpatialData object containing the image. image_key - Key of the RGB image in ``sdata.images`` to fit on. + Key of the RGB image in ``sdata.images`` to fit on, or a **list of keys** + to fit one reference from the pooled tissue pixels of several images + (e.g. a representative cohort). Pooled images must share a dtype. method Fitting method: ``"macenko"`` (default) or ``"vahadane"`` (physical stain-matrix decomposition, usable by both :func:`normalize_stains` and @@ -254,7 +258,9 @@ def fit_stain_reference( :func:`!detect_tissue`) restricting the fit to tissue pixels. If ``None``, ``f"{image_key}_tissue"`` is used. A tissue mask is **required**: if neither exists, a :class:`KeyError` asks you to - run :func:`!detect_tissue` first. + run :func:`!detect_tissue` first. When ``image_key`` is a list, pass a + list of mask keys **order-matched** to it (or ``None`` for the + ``{key}_tissue`` convention per image). max_angle_deg Tolerance of the H/E sanity gate for the decomposition methods: the fit raises :class:`!StainFittingError` if either recovered stain vector @@ -272,6 +278,20 @@ def fit_stain_reference( """ if method not in _VALID_METHODS: raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.") + if not isinstance(image_key, str): + return _fit_pooled( + sdata, + list(image_key), + tissue_mask_key, + method=method, + scale=scale, + method_params=method_params, + white_point=white_point, + max_angle_deg=max_angle_deg, + canonical_reference=canonical_reference, + ) + if tissue_mask_key is not None and not isinstance(tissue_mask_key, str): + raise ValueError("a single `image_key` takes a single `tissue_mask_key` (str) or None.") da = _resolve_image(sdata, image_key, scale, prefer="coarsest") validate_rgb_range(da) params = _resolve_method_params(method, method_params) @@ -292,6 +312,63 @@ def fit_stain_reference( ) +def _fit_pooled( + sdata: sd.SpatialData, + image_keys: list[str], + tissue_mask_key: str | list[str] | None, + *, + method: StainMethod, + scale: str | Literal["auto"], + method_params: MethodParams, + white_point: np.ndarray | None, + max_angle_deg: float, + canonical_reference: Mapping[str, np.ndarray] | None, +) -> StainReference: + """Fit one reference by pooling the tissue pixels of several same-dtype images. + + Each image is resolved + validated and masked by its own `tissue_mask_key` + (order-matched) or the `{key}_tissue` convention; decomposition pools tissue + OD, Reinhard pools tissue Lab pixels. Images must share a dtype so the white + point is well-defined. A blank slide raises (named); no silent skip. + """ + if not image_keys: + raise ValueError("`image_key` list is empty; pass at least one image key.") + if len(set(image_keys)) != len(image_keys): + raise ValueError("`image_key` list has duplicate keys.") + if tissue_mask_key is None: + mask_keys: list[str | None] = [None] * len(image_keys) + elif isinstance(tissue_mask_key, str): + raise ValueError( + "for multiple images, pass a list of `tissue_mask_key` (order-matched to `image_key`), " + "or None to use the `{image_key}_tissue` convention." + ) + elif len(tissue_mask_key) != len(image_keys): + raise ValueError( + f"`tissue_mask_key` length ({len(tissue_mask_key)}) must match `image_key` ({len(image_keys)})." + ) + else: + mask_keys = list(tissue_mask_key) + + params = _resolve_method_params(method, method_params) + das = [] + for k in image_keys: + da = _resolve_image(sdata, k, scale, prefer="coarsest") + validate_rgb_range(da) + das.append(da) + dtypes = {str(da.dtype) for da in das} + if len(dtypes) != 1: + raise ValueError(f"pooled images must share a dtype; got {sorted(dtypes)}.") + masks = [_resolve_tissue_bool_mask(sdata, k, da, mk) for k, da, mk in zip(image_keys, das, mask_keys, strict=True)] + + if method == "reinhard": + return fit_reinhard_pooled(das, masks, params, image_keys) + bg = default_white_point(das[0]) if white_point is None else np.asarray(white_point, np.float64) + reference = RUIFROK_HE if canonical_reference is None else dict(canonical_reference) + return fit_decomposition_pooled( + das, masks, image_keys, method, params, bg, reference=reference, max_angle_deg=max_angle_deg + ) + + def normalize_stains( sdata: sd.SpatialData, image_key: str, diff --git a/src/squidpy/experimental/im/_stain/_reinhard.py b/src/squidpy/experimental/im/_stain/_reinhard.py index 5c0021c5d..404a4b110 100644 --- a/src/squidpy/experimental/im/_stain/_reinhard.py +++ b/src/squidpy/experimental/im/_stain/_reinhard.py @@ -139,6 +139,35 @@ def fit_reinhard( return StainReference(method="reinhard", mu=mu, sigma=sigma) +def fit_reinhard_pooled( + das: list[xr.DataArray], + masks: list[np.ndarray], + params: ReinhardParams, + image_keys: list[str], +) -> StainReference: + """Fit Reinhard stats from the pooled tissue Lab pixels of several slides. + + Gathers each slide's tissue Lab pixels (naming the slide on empty tissue) and + concatenates them into one set, then takes ``mu``/``sigma`` over the pool - + matching :func:`_masked_channel_stats` (population std, ddof=0, tissue only). + """ + cols: list[np.ndarray] = [] + for da, m, k in zip(das, masks, image_keys, strict=True): + lab = rgb_to_lab_ruderman(da) + masked = lab.where(mask) if (mask := _reinhard_mask(lab, params, m)) is not None else lab + pix = np.asarray(masked.transpose("c", "y", "x").data).reshape(3, -1) + pix = pix[:, np.all(np.isfinite(pix), axis=0)] + if pix.shape[1] == 0: + raise ValueError(f"Foreground mask leaves zero tissue pixels for image {k!r}.") + cols.append(pix) + pooled = np.concatenate(cols, axis=1) + return StainReference( + method="reinhard", + mu=np.asarray(pooled.mean(axis=1), dtype=np.float64), + sigma=np.asarray(pooled.std(axis=1, ddof=0), dtype=np.float64), + ) + + def apply_reinhard( image_rgb: xr.DataArray, reference: StainReference, diff --git a/tests/experimental/test_stain_normalize.py b/tests/experimental/test_stain_normalize.py index e09c910ba..446b80bda 100644 --- a/tests/experimental/test_stain_normalize.py +++ b/tests/experimental/test_stain_normalize.py @@ -251,3 +251,103 @@ def test_plot_reinhard_before_after(self, sdata_hne) -> None: _, axes = plt.subplots(1, 2, figsize=(8, 4)) sdata_hne.pl.render_images("hne_shifted").pl.show(ax=axes[0], title="before") sdata_hne.pl.render_images("hne_normalized").pl.show(ax=axes[1], title="after") + + +# --------------------------------------------------------------------------- +# Multi-slide pooled fit (one reference from several images in one sdata) +# --------------------------------------------------------------------------- + +from squidpy.experimental.im._stain._constants import RUIFROK_HE # noqa: E402 + + +class TestPooledFit: + @staticmethod + def _he(seed: int, shape: tuple[int, int] = (48, 48), dtype=np.uint8) -> np.ndarray: + """Synthetic H&E from the Ruifrok H/E vectors so macenko/vahadane can fit.""" + rng = np.random.default_rng(seed) + h, w = shape + wmat = np.stack([RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]], axis=1) # (3, 2) + conc = rng.uniform(0.05, 1.3, (h * w, 2)) + rgb = np.clip(255.0 * np.exp(-(conc @ wmat.T)), 0, 255).reshape(h, w, 3).transpose(2, 0, 1) + arr = rgb.astype(np.uint8) + return (arr.astype(np.uint16) * 257) if dtype == np.uint16 else arr + + def _cohort(self, n: int = 3, dtype=np.uint8) -> tuple[sd.SpatialData, list[str]]: + sdata = sd.SpatialData() + keys = [] + for i in range(n): + k = f"img{i}" + arr = self._he(seed=i + 1, dtype=dtype) + sdata.images[k] = Image2DModel.parse(arr, dims=("c", "y", "x")) + h, w = arr.shape[-2], arr.shape[-1] + sdata.labels[f"{k}_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) + keys.append(k) + return sdata, keys + + @pytest.mark.parametrize("method", ["reinhard", "macenko", "vahadane"]) + def test_pooled_fit_runs(self, method: str) -> None: + sdata, keys = self._cohort() + ref = fit_stain_reference(sdata, keys, method=method) + assert ref.method == method + assert (ref.mu.shape == (3,)) if method == "reinhard" else (ref.stain_matrix.shape == (3, 3)) + + @pytest.mark.parametrize("method", ["reinhard", "macenko", "vahadane"]) + def test_pooled_of_one_matches_single(self, method: str) -> None: + sdata, keys = self._cohort(n=1) + single = fit_stain_reference(sdata, keys[0], method=method) + pooled = fit_stain_reference(sdata, keys, method=method) + if method == "reinhard": + # pooled path materialises + np-reduces; single uses xarray lazy reduction + np.testing.assert_allclose(pooled.mu, single.mu) + np.testing.assert_allclose(pooled.sigma, single.sigma) + else: + np.testing.assert_array_equal(pooled.stain_matrix, single.stain_matrix) + np.testing.assert_array_equal(pooled.max_concentrations, single.max_concentrations) + + def test_order_matched_non_convention_masks(self) -> None: + # non-convention mask names selecting different halves; swapping the order + # must change the fit (proves order is honoured, not name-matched). + sdata, keys = self._cohort(n=2) + h, w = 48, 48 + top = np.zeros((h, w), np.uint32) + top[: h // 2] = 1 + bot = np.zeros((h, w), np.uint32) + bot[h // 2 :] = 1 + sdata.labels["m_a"] = Labels2DModel.parse(top, dims=("y", "x")) + sdata.labels["m_b"] = Labels2DModel.parse(bot, dims=("y", "x")) + ref = fit_stain_reference(sdata, keys, method="reinhard", tissue_mask_key=["m_a", "m_b"]) + swapped = fit_stain_reference(sdata, keys, method="reinhard", tissue_mask_key=["m_b", "m_a"]) + assert not np.allclose(ref.mu, swapped.mu) + + @pytest.mark.parametrize("method", ["reinhard", "macenko"]) + def test_empty_slide_is_named(self, method: str) -> None: + sdata, keys = self._cohort(n=2) + sdata.labels[f"{keys[1]}_tissue"] = Labels2DModel.parse(np.zeros((48, 48), np.uint32), dims=("y", "x")) + with pytest.raises((ValueError, RuntimeError), match=keys[1]): + fit_stain_reference(sdata, keys, method=method) + + def test_mixed_dtype_raises(self) -> None: + sdata, _ = self._cohort(n=1) + sdata.images["img16"] = Image2DModel.parse(self._he(seed=9, dtype=np.uint16), dims=("c", "y", "x")) + sdata.labels["img16_tissue"] = Labels2DModel.parse(np.ones((48, 48), np.uint32), dims=("y", "x")) + with pytest.raises(ValueError, match="share a dtype"): + fit_stain_reference(sdata, ["img0", "img16"], method="macenko") + + @pytest.mark.parametrize( + ("image_key", "tissue_mask_key", "match"), + [ + ([], None, "empty"), + (["img0", "img0"], None, "duplicate"), + (["img0", "img1"], "img0_tissue", "list of `tissue_mask_key`"), + (["img0", "img1"], ["img0_tissue"], "length"), + ], + ) + def test_validation(self, image_key, tissue_mask_key, match: str) -> None: + sdata, _ = self._cohort(n=2) + with pytest.raises(ValueError, match=match): + fit_stain_reference(sdata, image_key, method="macenko", tissue_mask_key=tissue_mask_key) + + def test_list_mask_with_str_image_raises(self) -> None: + sdata, _ = self._cohort(n=1) + with pytest.raises(ValueError, match="single `tissue_mask_key`"): + fit_stain_reference(sdata, "img0", method="macenko", tissue_mask_key=["img0_tissue"])