Skip to content

Commit 92cb5eb

Browse files
timtreisclaude
andcommitted
Unify multi-channel compositing to additive blending and add per-channel norm
Replace inconsistent compositing formulas in _render_images with a shared _additive_blend helper that implements standard additive blending with clamping, matching Napari/ImageJ/FIJI behavior. Fixes: - Averaging bug in paths 2A-cmap and 2D (divided by n_channels) - Missing clip in paths 2B (2-3ch) and 2C - Double-alpha bug in path 2B (4+ch) - Redundant isinstance(c, str) filter in path 2C Adds per-channel normalization support: render_images() now accepts norm as a list of Normalize objects (one per channel), enabling proper contrast control for multi-channel fluorescence data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2893343 commit 92cb5eb

4 files changed

Lines changed: 54 additions & 60 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def render_images(
512512
*,
513513
channel: list[str] | list[int] | str | int | None = None,
514514
cmap: list[Colormap | str] | Colormap | str | None = None,
515-
norm: Normalize | None = None,
515+
norm: list[Normalize] | Normalize | None = None,
516516
na_color: ColorLike | None = "default",
517517
palette: list[str] | str | None = None,
518518
alpha: float | int = 1.0,
@@ -541,9 +541,10 @@ def render_images(
541541
cmap : list[Colormap | str] | Colormap | str | None
542542
Colormap or list of colormaps for continuous annotations, see :class:`matplotlib.colors.Colormap`.
543543
Each colormap applies to a corresponding channel.
544-
norm : Normalize | None, optional
544+
norm : list[Normalize] | Normalize | None, optional
545545
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
546-
Applies to all channels if set.
546+
Can be a single :class:`~matplotlib.colors.Normalize` (applied to all channels) or a list
547+
of :class:`~matplotlib.colors.Normalize` objects (one per channel) for per-channel control.
547548
na_color : ColorLike | None, default "default" (gets set to "lightgray")
548549
Color to be used for NAs values, if present. Can either be a named color ("red"), a hex representation
549550
("#000000ff") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). When None, the values
@@ -596,12 +597,15 @@ def render_images(
596597
n_steps = len(sdata.plotting_tree.keys())
597598

598599
for element, param_values in params_dict.items():
600+
# When norm is a list, per-channel norms are stored separately in ImageRenderParams.norms
601+
# and _prepare_cmap_norm receives None so it creates a default (unused) norm.
602+
scalar_norm = None if isinstance(norm, list) else norm
599603
cmap_params: list[CmapParams] | CmapParams
600604
if isinstance(cmap, list):
601605
cmap_params = [
602606
_prepare_cmap_norm(
603607
cmap=c,
604-
norm=norm,
608+
norm=scalar_norm,
605609
na_color=param_values["na_color"],
606610
)
607611
for c in cmap
@@ -610,7 +614,7 @@ def render_images(
610614
else:
611615
cmap_params = _prepare_cmap_norm(
612616
cmap=cmap,
613-
norm=norm,
617+
norm=scalar_norm,
614618
na_color=param_values["na_color"],
615619
**kwargs,
616620
)
@@ -619,6 +623,7 @@ def render_images(
619623
channel=param_values["channel"],
620624
cmap_params=cmap_params,
621625
palette=param_values["palette"],
626+
norms=norm if isinstance(norm, list) else None,
622627
alpha=param_values["alpha"],
623628
scale=param_values["scale"],
624629
zorder=n_steps,

src/spatialdata_plot/pl/render.py

Lines changed: 37 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import spatialdata as sd
1818
from anndata import AnnData
1919
from matplotlib.cm import ScalarMappable
20-
from matplotlib.colors import ListedColormap, Normalize
20+
from matplotlib.colors import Colormap, ListedColormap, Normalize
2121
from scanpy._settings import settings as sc_settings
2222
from spatialdata import get_extent, get_values, join_spatialelement_table
2323
from spatialdata._core.query.relational_query import match_table_to_element
@@ -1175,6 +1175,24 @@ def _render_points(
11751175
)
11761176

11771177

1178+
def _additive_blend(
1179+
layers: dict[str | int, np.ndarray],
1180+
channels: list[Any],
1181+
channel_cmaps: list[Colormap],
1182+
) -> np.ndarray:
1183+
"""Additive blend of colormapped channels, matching Napari's additive mode.
1184+
1185+
Each channel is mapped through its colormap (which must return RGBA),
1186+
the RGB components are summed, and the result is clamped to [0, 1].
1187+
"""
1188+
height, width = next(iter(layers.values())).shape
1189+
composite = np.zeros((height, width, 3), dtype=float)
1190+
for ch, cmap in zip(channels, channel_cmaps, strict=True):
1191+
rgba = cmap(np.asarray(layers[ch]))
1192+
composite += rgba[..., :3]
1193+
return np.clip(composite, 0, 1, out=composite)
1194+
1195+
11781196
def _render_images(
11791197
sdata: sd.SpatialData,
11801198
render_params: ImageRenderParams,
@@ -1238,6 +1256,11 @@ def _render_images(
12381256
if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels:
12391257
raise ValueError("If 'cmap' is provided, its length must match the number of channels.")
12401258

1259+
if render_params.norms is not None and len(render_params.norms) != n_channels:
1260+
raise ValueError(
1261+
f"Length of 'norm' list ({len(render_params.norms)}) must match the number of channels ({n_channels})."
1262+
)
1263+
12411264
_, trans_data = _prepare_transformation(img, coordinate_system, ax)
12421265

12431266
# 1) Image has only 1 channel
@@ -1255,13 +1278,14 @@ def _render_images(
12551278
cmap._lut[:, -1] = render_params.alpha
12561279

12571280
# norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
1281+
single_norm = render_params.norms[0] if render_params.norms else render_params.cmap_params.norm
12581282
_ax_show_and_transform(
12591283
layer,
12601284
trans_data,
12611285
ax,
12621286
cmap=cmap,
12631287
zorder=render_params.zorder,
1264-
norm=render_params.cmap_params.norm,
1288+
norm=single_norm,
12651289
)
12661290

12671291
wants_colorbar = _should_request_colorbar(
@@ -1271,7 +1295,7 @@ def _render_images(
12711295
auto_condition=n_channels == 1,
12721296
)
12731297
if wants_colorbar and legend_params.colorbar and colorbar_requests is not None:
1274-
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
1298+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=single_norm)
12751299
colorbar_requests.append(
12761300
ColorbarSpec(
12771301
ax=ax,
@@ -1290,7 +1314,9 @@ def _render_images(
12901314
layers = {}
12911315
for ch_idx, ch in enumerate(channels):
12921316
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
1293-
if isinstance(render_params.cmap_params, list):
1317+
if render_params.norms is not None:
1318+
ch_norm = render_params.norms[ch_idx]
1319+
elif isinstance(render_params.cmap_params, list):
12941320
ch_norm = render_params.cmap_params[ch_idx].norm
12951321
else:
12961322
ch_norm = render_params.cmap_params.norm
@@ -1304,14 +1330,7 @@ def _render_images(
13041330
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
13051331
else: # -> use given cmap for each channel
13061332
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
1307-
stacked = (
1308-
np.stack(
1309-
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
1310-
0,
1311-
).sum(0)
1312-
/ n_channels
1313-
)
1314-
stacked = stacked[:, :, :3]
1333+
stacked = _additive_blend(layers, channels, channel_cmaps)
13151334
logger.warning(
13161335
"One cmap was given for multiple channels and is now used for each channel. "
13171336
"You're blending multiple cmaps. "
@@ -1332,22 +1351,9 @@ def _render_images(
13321351
# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
13331352
elif palette is None and not got_multiple_cmaps:
13341353
# overwrite if n_channels == 2 for intuitive result
1354+
# Pick seed colors based on channel count
13351355
if n_channels == 2:
13361356
seed_colors = ["#ff0000ff", "#00ff00ff"]
1337-
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
1338-
colored = np.stack(
1339-
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
1340-
0,
1341-
).sum(0)
1342-
colored = colored[:, :, :3]
1343-
elif n_channels == 3:
1344-
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
1345-
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
1346-
colored = np.stack(
1347-
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
1348-
0,
1349-
).sum(0)
1350-
colored = colored[:, :, :3]
13511357
else:
13521358
if isinstance(render_params.cmap_params, list):
13531359
cmap_is_default = render_params.cmap_params[0].cmap_is_default
@@ -1364,24 +1370,9 @@ def _render_images(
13641370
]
13651371
else:
13661372
seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)]
1367-
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
1368-
1369-
# Stack (n_channels, height, width) → (height*width, n_channels)
1370-
H, W = next(iter(layers.values())).shape
1371-
comp_rgb = np.zeros((H, W, 3), dtype=float)
13721373

1373-
# For each channel: map to RGBA, apply constant alpha, then add
1374-
for ch_idx, ch in enumerate(channels):
1375-
layer_arr = layers[ch]
1376-
rgba = channel_cmaps[ch_idx](layer_arr)
1377-
rgba[..., 3] = render_params.alpha
1378-
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]
1379-
1380-
colored = np.clip(comp_rgb, 0, 1)
1381-
logger.info(
1382-
f"Your image has {n_channels} channels. Sampling categorical colors and using "
1383-
f"multichannel strategy 'stack' to render."
1384-
) # TODO: update when pca is added as strategy
1374+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
1375+
colored = _additive_blend(layers, channels, channel_cmaps)
13851376

13861377
_ax_show_and_transform(
13871378
colored,
@@ -1396,9 +1387,8 @@ def _render_images(
13961387
if len(palette) != n_channels:
13971388
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
13981389

1399-
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)]
1400-
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
1401-
colored = colored[:, :, :3]
1390+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette]
1391+
colored = _additive_blend(layers, channels, channel_cmaps)
14021392

14031393
_ax_show_and_transform(
14041394
colored,
@@ -1410,14 +1400,7 @@ def _render_images(
14101400

14111401
elif palette is None and got_multiple_cmaps:
14121402
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
1413-
colored = (
1414-
np.stack(
1415-
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
1416-
0,
1417-
).sum(0)
1418-
/ n_channels
1419-
)
1420-
colored = colored[:, :, :3]
1403+
colored = _additive_blend(layers, channels, channel_cmaps)
14211404

14221405
_ax_show_and_transform(
14231406
colored,

src/spatialdata_plot/pl/render_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ class ImageRenderParams:
264264
element: str
265265
channel: list[str] | list[int] | int | str | None = None
266266
palette: ListedColormap | list[str] | None = None
267+
norms: list[Normalize] | None = None
267268
alpha: float = 1.0
268269
percentiles_for_norm: tuple[float | None, float | None] = (None, None)
269270
scale: str | None = None

src/spatialdata_plot/pl/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2394,7 +2394,12 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
23942394

23952395
norm = param_dict.get("norm")
23962396
if norm is not None:
2397-
if element_type in {"images", "labels"} and not isinstance(norm, Normalize):
2397+
if element_type == "images" and isinstance(norm, list):
2398+
if len(norm) == 0:
2399+
raise ValueError("'norm' list must not be empty.")
2400+
if not all(isinstance(n, Normalize) for n in norm):
2401+
raise TypeError("All elements of 'norm' list must be of type Normalize.")
2402+
elif element_type in {"images", "labels"} and not isinstance(norm, Normalize):
23982403
raise TypeError("Parameter 'norm' must be of type Normalize.")
23992404
if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize):
24002405
raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.")

0 commit comments

Comments
 (0)