Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/squidpy/experimental/im/_stain/_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,48 @@ def fit_decomposition(
) -> StainReference:
"""Fit a decomposition :class:`StainReference` (stain matrix + max concentrations)."""
od = _tissue_od(image_rgb, white_point, params.beta, tissue_mask=tissue_mask, image_key=image_key)
return _reference_from_od(
od, method, params, white_point, image_key=image_key, reference=reference, max_angle_deg=max_angle_deg
)


def fit_decomposition_pooled(
das: list[xr.DataArray],
masks: list[np.ndarray],
image_keys: list[str],
method: StainMethod,
params: Any,
white_point: np.ndarray,
*,
reference: dict[str, np.ndarray] = RUIFROK_HE,
max_angle_deg: float = 45.0,
) -> StainReference:
"""Fit a decomposition reference from the pooled tissue OD of several slides.

Each slide's tissue OD is gathered (naming the slide on empty tissue) and
stacked into one ``(SUM_N, 3)`` array; the single-image fit tail then runs on
the pooled OD. Pooling one slide is identical to :func:`fit_decomposition`.
"""
ods = [
_tissue_od(da, white_point, params.beta, tissue_mask=m, image_key=k)
for da, m, k in zip(das, masks, image_keys, strict=True)
]
return _reference_from_od(
np.vstack(ods), method, params, white_point, image_key=None, reference=reference, max_angle_deg=max_angle_deg
)


def _reference_from_od(
od: np.ndarray,
method: StainMethod,
params: Any,
white_point: np.ndarray,
*,
image_key: str | None,
reference: dict[str, np.ndarray],
max_angle_deg: float,
) -> StainReference:
"""Shared fit tail: stain matrix (gated) + max concentrations -> StainReference."""
matrix = _stain_matrix(od, method, params, image_key=image_key, reference=reference, max_angle_deg=max_angle_deg)
return StainReference(
method=method,
Expand Down
85 changes: 81 additions & 4 deletions src/squidpy/experimental/im/_stain/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
apply_decomposition,
decompose_to_concentrations,
fit_decomposition,
fit_decomposition_pooled,
)
from squidpy.experimental.im._stain._reference import StainMethod, StainReference
from squidpy.experimental.im._stain._reinhard import (
ReinhardParams,
_resolve_reinhard_params,
apply_reinhard,
fit_reinhard,
fit_reinhard_pooled,
)
from squidpy.experimental.im._stain._white_point import (
default_white_point,
Expand Down Expand Up @@ -211,13 +213,13 @@ def estimate_white_point(

def fit_stain_reference(
sdata: sd.SpatialData,
image_key: str,
image_key: str | list[str],
*,
method: StainMethod = "macenko",
scale: str | Literal["auto"] = "auto",
method_params: MethodParams = None,
white_point: np.ndarray | None = None,
tissue_mask_key: str | None = None,
tissue_mask_key: str | list[str] | None = None,
max_angle_deg: float = 45.0,
canonical_reference: Mapping[str, np.ndarray] | None = None,
) -> StainReference:
Expand All @@ -228,7 +230,9 @@ def fit_stain_reference(
sdata
SpatialData object containing the image.
image_key
Key of the RGB image in ``sdata.images`` to fit on.
Key of the RGB image in ``sdata.images`` to fit on, or a **list of keys**
to fit one reference from the pooled tissue pixels of several images
(e.g. a representative cohort). Pooled images must share a dtype.
method
Fitting method: ``"macenko"`` (default) or ``"vahadane"`` (physical
stain-matrix decomposition, usable by both :func:`normalize_stains` and
Expand All @@ -254,7 +258,9 @@ def fit_stain_reference(
:func:`!detect_tissue`) restricting the fit to
tissue pixels. If ``None``, ``f"{image_key}_tissue"`` is used. A tissue
mask is **required**: if neither exists, a :class:`KeyError` asks you to
run :func:`!detect_tissue` first.
run :func:`!detect_tissue` first. When ``image_key`` is a list, pass a
list of mask keys **order-matched** to it (or ``None`` for the
``{key}_tissue`` convention per image).
max_angle_deg
Tolerance of the H/E sanity gate for the decomposition methods: the fit
raises :class:`!StainFittingError` if either recovered stain vector
Expand All @@ -272,6 +278,20 @@ def fit_stain_reference(
"""
if method not in _VALID_METHODS:
raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.")
if not isinstance(image_key, str):
return _fit_pooled(
sdata,
list(image_key),
tissue_mask_key,
method=method,
scale=scale,
method_params=method_params,
white_point=white_point,
max_angle_deg=max_angle_deg,
canonical_reference=canonical_reference,
)
if tissue_mask_key is not None and not isinstance(tissue_mask_key, str):
raise ValueError("a single `image_key` takes a single `tissue_mask_key` (str) or None.")
da = _resolve_image(sdata, image_key, scale, prefer="coarsest")
validate_rgb_range(da)
params = _resolve_method_params(method, method_params)
Expand All @@ -292,6 +312,63 @@ def fit_stain_reference(
)


def _fit_pooled(
sdata: sd.SpatialData,
image_keys: list[str],
tissue_mask_key: str | list[str] | None,
*,
method: StainMethod,
scale: str | Literal["auto"],
method_params: MethodParams,
white_point: np.ndarray | None,
max_angle_deg: float,
canonical_reference: Mapping[str, np.ndarray] | None,
) -> StainReference:
"""Fit one reference by pooling the tissue pixels of several same-dtype images.

Each image is resolved + validated and masked by its own `tissue_mask_key`
(order-matched) or the `{key}_tissue` convention; decomposition pools tissue
OD, Reinhard pools tissue Lab pixels. Images must share a dtype so the white
point is well-defined. A blank slide raises (named); no silent skip.
"""
if not image_keys:
raise ValueError("`image_key` list is empty; pass at least one image key.")
if len(set(image_keys)) != len(image_keys):
raise ValueError("`image_key` list has duplicate keys.")
if tissue_mask_key is None:
mask_keys: list[str | None] = [None] * len(image_keys)
elif isinstance(tissue_mask_key, str):
raise ValueError(
"for multiple images, pass a list of `tissue_mask_key` (order-matched to `image_key`), "
"or None to use the `{image_key}_tissue` convention."
)
elif len(tissue_mask_key) != len(image_keys):
raise ValueError(
f"`tissue_mask_key` length ({len(tissue_mask_key)}) must match `image_key` ({len(image_keys)})."
)
else:
mask_keys = list(tissue_mask_key)

params = _resolve_method_params(method, method_params)
das = []
for k in image_keys:
da = _resolve_image(sdata, k, scale, prefer="coarsest")
validate_rgb_range(da)
das.append(da)
dtypes = {str(da.dtype) for da in das}
if len(dtypes) != 1:
raise ValueError(f"pooled images must share a dtype; got {sorted(dtypes)}.")
masks = [_resolve_tissue_bool_mask(sdata, k, da, mk) for k, da, mk in zip(image_keys, das, mask_keys, strict=True)]

if method == "reinhard":
return fit_reinhard_pooled(das, masks, params, image_keys)
bg = default_white_point(das[0]) if white_point is None else np.asarray(white_point, np.float64)
reference = RUIFROK_HE if canonical_reference is None else dict(canonical_reference)
return fit_decomposition_pooled(
das, masks, image_keys, method, params, bg, reference=reference, max_angle_deg=max_angle_deg
)


def normalize_stains(
sdata: sd.SpatialData,
image_key: str,
Expand Down
29 changes: 29 additions & 0 deletions src/squidpy/experimental/im/_stain/_reinhard.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,35 @@ def fit_reinhard(
return StainReference(method="reinhard", mu=mu, sigma=sigma)


def fit_reinhard_pooled(
das: list[xr.DataArray],
masks: list[np.ndarray],
params: ReinhardParams,
image_keys: list[str],
) -> StainReference:
"""Fit Reinhard stats from the pooled tissue Lab pixels of several slides.

Gathers each slide's tissue Lab pixels (naming the slide on empty tissue) and
concatenates them into one set, then takes ``mu``/``sigma`` over the pool -
matching :func:`_masked_channel_stats` (population std, ddof=0, tissue only).
"""
cols: list[np.ndarray] = []
for da, m, k in zip(das, masks, image_keys, strict=True):
lab = rgb_to_lab_ruderman(da)
masked = lab.where(mask) if (mask := _reinhard_mask(lab, params, m)) is not None else lab
pix = np.asarray(masked.transpose("c", "y", "x").data).reshape(3, -1)
pix = pix[:, np.all(np.isfinite(pix), axis=0)]
if pix.shape[1] == 0:
raise ValueError(f"Foreground mask leaves zero tissue pixels for image {k!r}.")
cols.append(pix)
pooled = np.concatenate(cols, axis=1)
return StainReference(
method="reinhard",
mu=np.asarray(pooled.mean(axis=1), dtype=np.float64),
sigma=np.asarray(pooled.std(axis=1, ddof=0), dtype=np.float64),
)


def apply_reinhard(
image_rgb: xr.DataArray,
reference: StainReference,
Expand Down
100 changes: 100 additions & 0 deletions tests/experimental/test_stain_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,103 @@ def test_plot_reinhard_before_after(self, sdata_hne) -> None:
_, axes = plt.subplots(1, 2, figsize=(8, 4))
sdata_hne.pl.render_images("hne_shifted").pl.show(ax=axes[0], title="before")
sdata_hne.pl.render_images("hne_normalized").pl.show(ax=axes[1], title="after")


# ---------------------------------------------------------------------------
# Multi-slide pooled fit (one reference from several images in one sdata)
# ---------------------------------------------------------------------------

from squidpy.experimental.im._stain._constants import RUIFROK_HE # noqa: E402


class TestPooledFit:
@staticmethod
def _he(seed: int, shape: tuple[int, int] = (48, 48), dtype=np.uint8) -> np.ndarray:
"""Synthetic H&E from the Ruifrok H/E vectors so macenko/vahadane can fit."""
rng = np.random.default_rng(seed)
h, w = shape
wmat = np.stack([RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]], axis=1) # (3, 2)
conc = rng.uniform(0.05, 1.3, (h * w, 2))
rgb = np.clip(255.0 * np.exp(-(conc @ wmat.T)), 0, 255).reshape(h, w, 3).transpose(2, 0, 1)
arr = rgb.astype(np.uint8)
return (arr.astype(np.uint16) * 257) if dtype == np.uint16 else arr

def _cohort(self, n: int = 3, dtype=np.uint8) -> tuple[sd.SpatialData, list[str]]:
sdata = sd.SpatialData()
keys = []
for i in range(n):
k = f"img{i}"
arr = self._he(seed=i + 1, dtype=dtype)
sdata.images[k] = Image2DModel.parse(arr, dims=("c", "y", "x"))
h, w = arr.shape[-2], arr.shape[-1]
sdata.labels[f"{k}_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x"))
keys.append(k)
return sdata, keys

@pytest.mark.parametrize("method", ["reinhard", "macenko", "vahadane"])
def test_pooled_fit_runs(self, method: str) -> None:
sdata, keys = self._cohort()
ref = fit_stain_reference(sdata, keys, method=method)
assert ref.method == method
assert (ref.mu.shape == (3,)) if method == "reinhard" else (ref.stain_matrix.shape == (3, 3))

@pytest.mark.parametrize("method", ["reinhard", "macenko", "vahadane"])
def test_pooled_of_one_matches_single(self, method: str) -> None:
sdata, keys = self._cohort(n=1)
single = fit_stain_reference(sdata, keys[0], method=method)
pooled = fit_stain_reference(sdata, keys, method=method)
if method == "reinhard":
# pooled path materialises + np-reduces; single uses xarray lazy reduction
np.testing.assert_allclose(pooled.mu, single.mu)
np.testing.assert_allclose(pooled.sigma, single.sigma)
else:
np.testing.assert_array_equal(pooled.stain_matrix, single.stain_matrix)
np.testing.assert_array_equal(pooled.max_concentrations, single.max_concentrations)

def test_order_matched_non_convention_masks(self) -> None:
# non-convention mask names selecting different halves; swapping the order
# must change the fit (proves order is honoured, not name-matched).
sdata, keys = self._cohort(n=2)
h, w = 48, 48
top = np.zeros((h, w), np.uint32)
top[: h // 2] = 1
bot = np.zeros((h, w), np.uint32)
bot[h // 2 :] = 1
sdata.labels["m_a"] = Labels2DModel.parse(top, dims=("y", "x"))
sdata.labels["m_b"] = Labels2DModel.parse(bot, dims=("y", "x"))
ref = fit_stain_reference(sdata, keys, method="reinhard", tissue_mask_key=["m_a", "m_b"])
swapped = fit_stain_reference(sdata, keys, method="reinhard", tissue_mask_key=["m_b", "m_a"])
assert not np.allclose(ref.mu, swapped.mu)

@pytest.mark.parametrize("method", ["reinhard", "macenko"])
def test_empty_slide_is_named(self, method: str) -> None:
sdata, keys = self._cohort(n=2)
sdata.labels[f"{keys[1]}_tissue"] = Labels2DModel.parse(np.zeros((48, 48), np.uint32), dims=("y", "x"))
with pytest.raises((ValueError, RuntimeError), match=keys[1]):
fit_stain_reference(sdata, keys, method=method)

def test_mixed_dtype_raises(self) -> None:
sdata, _ = self._cohort(n=1)
sdata.images["img16"] = Image2DModel.parse(self._he(seed=9, dtype=np.uint16), dims=("c", "y", "x"))
sdata.labels["img16_tissue"] = Labels2DModel.parse(np.ones((48, 48), np.uint32), dims=("y", "x"))
with pytest.raises(ValueError, match="share a dtype"):
fit_stain_reference(sdata, ["img0", "img16"], method="macenko")

@pytest.mark.parametrize(
("image_key", "tissue_mask_key", "match"),
[
([], None, "empty"),
(["img0", "img0"], None, "duplicate"),
(["img0", "img1"], "img0_tissue", "list of `tissue_mask_key`"),
(["img0", "img1"], ["img0_tissue"], "length"),
],
)
def test_validation(self, image_key, tissue_mask_key, match: str) -> None:
sdata, _ = self._cohort(n=2)
with pytest.raises(ValueError, match=match):
fit_stain_reference(sdata, image_key, method="macenko", tissue_mask_key=tissue_mask_key)

def test_list_mask_with_str_image_raises(self) -> None:
sdata, _ = self._cohort(n=1)
with pytest.raises(ValueError, match="single `tissue_mask_key`"):
fit_stain_reference(sdata, "img0", method="macenko", tissue_mask_key=["img0_tissue"])
Loading