Skip to content

Commit ccc4954

Browse files
committed
bugfix
1 parent 4af92f5 commit ccc4954

2 files changed

Lines changed: 8 additions & 14 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -846,8 +846,7 @@ def _render_images(
846846
else:
847847
layers = {}
848848
for ch_idx, ch in enumerate(channels):
849-
print(channels, ch)
850-
layers[ch_idx] = img.sel(c=ch).copy(deep=True).squeeze()
849+
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
851850
if isinstance(render_params.cmap_params, list):
852851
ch_norm = render_params.cmap_params[ch_idx].norm
853852
ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default
@@ -861,7 +860,7 @@ def _render_images(
861860
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
862861
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
863862
if render_params.cmap_params.cmap_is_default: # -> use RGB
864-
stacked = np.stack([layers[ch_idx] for ch_idx in layers], axis=-1)
863+
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
865864
else: # -> use given cmap for each channel
866865
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
867866
stacked = (
@@ -896,7 +895,7 @@ def _render_images(
896895
seed_colors = ["#ff0000ff", "#00ff00ff"]
897896
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
898897
colored = np.stack(
899-
[channel_cmaps[ch_ind](layers[ch_ind]) for ch_ind, ch in enumerate(channels)],
898+
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
900899
0,
901900
).sum(0)
902901
colored = colored[:, :, :3]
@@ -936,9 +935,9 @@ def _render_images(
936935
comp_rgb = np.zeros((H, W, 3), dtype=float)
937936

938937
# For each channel: map to RGBA, apply constant alpha, then add
939-
for idx, ch in enumerate(channels):
940-
layer_arr = layers[idx]
941-
rgba = channel_cmaps[idx](layer_arr)
938+
for ch_idx, ch in enumerate(channels):
939+
layer_arr = layers[ch]
940+
rgba = channel_cmaps[ch_idx](layer_arr)
942941
rgba[..., 3] = render_params.alpha
943942
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]
944943

@@ -955,8 +954,8 @@ def _render_images(
955954
H, W = next(iter(layers.values())).shape
956955
pixel_matrix = np.stack(
957956
[
958-
(layers[ch_idx].data.ravel() if hasattr(layers[ch_idx], "data") else layers[ch_idx].ravel())
959-
for ch_idx, _ in enumerate(channels)
957+
(layers[ch].data.ravel() if hasattr(layers[ch], "data") else layers[ch].ravel())
958+
for ch in channels
960959
],
961960
axis=1,
962961
)

src/spatialdata_plot/pl/utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,12 +2072,7 @@ def _validate_image_render_params(
20722072
# If channel is a list, ensure all elements are the same type
20732073
if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
20742074
raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")
2075-
# At this point, channel is either list[str] or list[int]
20762075

2077-
channels_are_strings = isinstance(channel[0], str)
2078-
channels_are_ints = isinstance(channel[0], int)
2079-
2080-
print(channel, spatial_element_ch)
20812076
invalid = [c for c in channel if c not in spatial_element_ch]
20822077
if invalid:
20832078
raise ValueError(

0 commit comments

Comments
 (0)