1717import spatialdata as sd
1818from anndata import AnnData
1919from matplotlib .cm import ScalarMappable
20- from matplotlib .colors import ListedColormap , Normalize
20+ from matplotlib .colors import Colormap , ListedColormap , Normalize
2121from scanpy ._settings import settings as sc_settings
2222from spatialdata import get_extent , get_values , join_spatialelement_table
2323from spatialdata ._core .query .relational_query import match_table_to_element
@@ -1175,6 +1175,24 @@ def _render_points(
11751175 )
11761176
11771177
1178+ def _additive_blend (
1179+ layers : dict [str | int , np .ndarray ],
1180+ channels : list [Any ],
1181+ channel_cmaps : list [Colormap ],
1182+ ) -> np .ndarray :
1183+ """Additive blend of colormapped channels, matching Napari's additive mode.
1184+
1185+ Each channel is mapped through its colormap (which must return RGBA),
1186+ the RGB components are summed, and the result is clamped to [0, 1].
1187+ """
1188+ height , width = next (iter (layers .values ())).shape
1189+ composite = np .zeros ((height , width , 3 ), dtype = float )
1190+ for ch , cmap in zip (channels , channel_cmaps , strict = True ):
1191+ rgba = cmap (np .asarray (layers [ch ]))
1192+ composite += rgba [..., :3 ]
1193+ return np .clip (composite , 0 , 1 , out = composite )
1194+
1195+
11781196def _render_images (
11791197 sdata : sd .SpatialData ,
11801198 render_params : ImageRenderParams ,
@@ -1238,6 +1256,11 @@ def _render_images(
12381256 if isinstance (render_params .cmap_params , list ) and len (render_params .cmap_params ) != n_channels :
12391257 raise ValueError ("If 'cmap' is provided, its length must match the number of channels." )
12401258
1259+ if render_params .norms is not None and len (render_params .norms ) != n_channels :
1260+ raise ValueError (
1261+ f"Length of 'norm' list ({ len (render_params .norms )} ) must match the number of channels ({ n_channels } )."
1262+ )
1263+
12411264 _ , trans_data = _prepare_transformation (img , coordinate_system , ax )
12421265
12431266 # 1) Image has only 1 channel
@@ -1255,13 +1278,14 @@ def _render_images(
12551278 cmap ._lut [:, - 1 ] = render_params .alpha
12561279
12571280 # norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
1281+ single_norm = render_params .norms [0 ] if render_params .norms else render_params .cmap_params .norm
12581282 _ax_show_and_transform (
12591283 layer ,
12601284 trans_data ,
12611285 ax ,
12621286 cmap = cmap ,
12631287 zorder = render_params .zorder ,
1264- norm = render_params . cmap_params . norm ,
1288+ norm = single_norm ,
12651289 )
12661290
12671291 wants_colorbar = _should_request_colorbar (
@@ -1271,7 +1295,7 @@ def _render_images(
12711295 auto_condition = n_channels == 1 ,
12721296 )
12731297 if wants_colorbar and legend_params .colorbar and colorbar_requests is not None :
1274- sm = plt .cm .ScalarMappable (cmap = cmap , norm = render_params . cmap_params . norm )
1298+ sm = plt .cm .ScalarMappable (cmap = cmap , norm = single_norm )
12751299 colorbar_requests .append (
12761300 ColorbarSpec (
12771301 ax = ax ,
@@ -1290,7 +1314,9 @@ def _render_images(
12901314 layers = {}
12911315 for ch_idx , ch in enumerate (channels ):
12921316 layers [ch ] = img .sel (c = ch ).copy (deep = True ).squeeze ()
1293- if isinstance (render_params .cmap_params , list ):
1317+ if render_params .norms is not None :
1318+ ch_norm = render_params .norms [ch_idx ]
1319+ elif isinstance (render_params .cmap_params , list ):
12941320 ch_norm = render_params .cmap_params [ch_idx ].norm
12951321 else :
12961322 ch_norm = render_params .cmap_params .norm
@@ -1304,14 +1330,7 @@ def _render_images(
13041330 stacked = np .stack ([layers [ch ] for ch in layers ], axis = - 1 )
13051331 else : # -> use given cmap for each channel
13061332 channel_cmaps = [render_params .cmap_params .cmap ] * n_channels
1307- stacked = (
1308- np .stack (
1309- [channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )],
1310- 0 ,
1311- ).sum (0 )
1312- / n_channels
1313- )
1314- stacked = stacked [:, :, :3 ]
1333+ stacked = _additive_blend (layers , channels , channel_cmaps )
13151334 logger .warning (
13161335 "One cmap was given for multiple channels and is now used for each channel. "
13171336 "You're blending multiple cmaps. "
@@ -1332,22 +1351,9 @@ def _render_images(
13321351 # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
13331352 elif palette is None and not got_multiple_cmaps :
13341353 # overwrite if n_channels == 2 for intuitive result
1354+ # Pick seed colors based on channel count
13351355 if n_channels == 2 :
13361356 seed_colors = ["#ff0000ff" , "#00ff00ff" ]
1337- channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
1338- colored = np .stack (
1339- [channel_cmaps [ch_ind ](layers [ch ]) for ch_ind , ch in enumerate (channels )],
1340- 0 ,
1341- ).sum (0 )
1342- colored = colored [:, :, :3 ]
1343- elif n_channels == 3 :
1344- seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
1345- channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
1346- colored = np .stack (
1347- [channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )],
1348- 0 ,
1349- ).sum (0 )
1350- colored = colored [:, :, :3 ]
13511357 else :
13521358 if isinstance (render_params .cmap_params , list ):
13531359 cmap_is_default = render_params .cmap_params [0 ].cmap_is_default
@@ -1364,24 +1370,9 @@ def _render_images(
13641370 ]
13651371 else :
13661372 seed_colors = [render_params .cmap_params .cmap (i / (n_channels - 1 )) for i in range (n_channels )]
1367- channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
1368-
1369- # Stack (n_channels, height, width) → (height*width, n_channels)
1370- H , W = next (iter (layers .values ())).shape
1371- comp_rgb = np .zeros ((H , W , 3 ), dtype = float )
13721373
1373- # For each channel: map to RGBA, apply constant alpha, then add
1374- for ch_idx , ch in enumerate (channels ):
1375- layer_arr = layers [ch ]
1376- rgba = channel_cmaps [ch_idx ](layer_arr )
1377- rgba [..., 3 ] = render_params .alpha
1378- comp_rgb += rgba [..., :3 ] * rgba [..., 3 ][..., None ]
1379-
1380- colored = np .clip (comp_rgb , 0 , 1 )
1381- logger .info (
1382- f"Your image has { n_channels } channels. Sampling categorical colors and using "
1383- f"multichannel strategy 'stack' to render."
1384- ) # TODO: update when pca is added as strategy
1374+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
1375+ colored = _additive_blend (layers , channels , channel_cmaps )
13851376
13861377 _ax_show_and_transform (
13871378 colored ,
@@ -1396,9 +1387,8 @@ def _render_images(
13961387 if len (palette ) != n_channels :
13971388 raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
13981389
1399- channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in palette if isinstance (c , str )]
1400- colored = np .stack ([channel_cmaps [i ](layers [c ]) for i , c in enumerate (channels )], 0 ).sum (0 )
1401- colored = colored [:, :, :3 ]
1390+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in palette ]
1391+ colored = _additive_blend (layers , channels , channel_cmaps )
14021392
14031393 _ax_show_and_transform (
14041394 colored ,
@@ -1410,14 +1400,7 @@ def _render_images(
14101400
14111401 elif palette is None and got_multiple_cmaps :
14121402 channel_cmaps = [cp .cmap for cp in render_params .cmap_params ] # type: ignore[union-attr]
1413- colored = (
1414- np .stack (
1415- [channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )],
1416- 0 ,
1417- ).sum (0 )
1418- / n_channels
1419- )
1420- colored = colored [:, :, :3 ]
1403+ colored = _additive_blend (layers , channels , channel_cmaps )
14211404
14221405 _ax_show_and_transform (
14231406 colored ,
0 commit comments