Skip to content

Commit 68abcb7

Browse files
timtreisclaude
andcommitted
Address review: normalize non-float dtypes, fix edge cases
- Normalize uint8/uint16 RGB(A) images to [0, 1] before passing to imshow - Add length check to _is_rgb_image to prevent false positives with duplicate channel names - Remove .squeeze() calls that could drop spatial dimensions - Add tests for uint8, uint16, and duplicate channel name edge cases Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e1cb399 commit 68abcb7

2 files changed

Lines changed: 46 additions & 4 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,9 +1034,9 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]:
10341034
(is_rgb, has_alpha) — whether the image is RGB and whether it includes an alpha channel.
10351035
"""
10361036
names = {str(c).lower() for c in channel_coords}
1037-
if names == {"r", "g", "b", "a"}:
1037+
if names == {"r", "g", "b", "a"} and len(channel_coords) == 4:
10381038
return True, True
1039-
if names == {"r", "g", "b"}:
1039+
if names == {"r", "g", "b"} and len(channel_coords) == 3:
10401040
return True, False
10411041
return False, False
10421042

@@ -1114,12 +1114,26 @@ def _render_images(
11141114
if is_rgb and palette is None and not got_multiple_cmaps and not has_explicit_cmap:
11151115
coord_map = {str(c).lower(): c for c in channels}
11161116
ordered = [coord_map[ch] for ch in ("r", "g", "b")]
1117-
stacked = np.moveaxis(img.sel(c=ordered).squeeze().values, 0, -1)
1117+
stacked = np.moveaxis(img.sel(c=ordered).values, 0, -1)
1118+
1119+
# Normalize to [0, 1] for matplotlib: uint8 → /255, other int dtypes → /max, float → clip
1120+
if stacked.dtype == np.uint8:
1121+
stacked = stacked.astype(np.float64) / 255.0
1122+
elif stacked.dtype.kind in ("u", "i"):
1123+
stacked = stacked.astype(np.float64) / np.iinfo(stacked.dtype).max
1124+
else:
1125+
stacked = np.clip(stacked, 0, 1)
11181126

11191127
show_kwargs: dict[str, Any] = {"zorder": render_params.zorder}
11201128

11211129
if has_alpha and render_params.alpha == 1.0:
1122-
alpha_layer = np.clip(img.sel(c=coord_map["a"]).squeeze().values, 0, 1)
1130+
alpha_raw = img.sel(c=coord_map["a"]).values
1131+
if alpha_raw.dtype == np.uint8:
1132+
alpha_layer = alpha_raw.astype(np.float64) / 255.0
1133+
elif alpha_raw.dtype.kind in ("u", "i"):
1134+
alpha_layer = alpha_raw.astype(np.float64) / np.iinfo(alpha_raw.dtype).max
1135+
else:
1136+
alpha_layer = np.clip(alpha_raw.astype(np.float64), 0, 1)
11231137
stacked = np.concatenate([stacked, alpha_layer[..., np.newaxis]], axis=-1)
11241138
else:
11251139
show_kwargs["alpha"] = render_params.alpha

tests/pl/test_render_images.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def test_partial_rgb_not_detected(self):
181181
def test_rgb_with_extra_channel_not_detected(self):
182182
assert _is_rgb_image(["r", "g", "b", "x"]) == (False, False)
183183

184+
def test_duplicate_channel_names_not_detected(self):
185+
assert _is_rgb_image(["r", "g", "b", "b"]) == (False, False)
186+
184187

185188
class TestRGBARendering:
186189
"""Regression tests for #406: RGBA images rendered correctly."""
@@ -242,3 +245,28 @@ def test_explicit_alpha_overrides_per_pixel(self):
242245
fig, ax = plt.subplots()
243246
sdata.pl.render_images("img", alpha=0.3).pl.show(ax=ax)
244247
plt.close("all")
248+
249+
def test_uint8_rgb_renders(self):
250+
"""uint8 RGB image should be normalized to [0, 1] and render correctly."""
251+
data = np.zeros((3, 50, 50), dtype=np.uint8)
252+
data[0] = 200
253+
data[1] = 100
254+
data[2] = 50
255+
img = Image2DModel.parse(data, dims=("c", "y", "x"), c_coords=["r", "g", "b"])
256+
sdata = SpatialData(images={"img": img})
257+
fig, ax = plt.subplots()
258+
sdata.pl.render_images("img").pl.show(ax=ax)
259+
plt.close("all")
260+
261+
def test_uint16_rgba_renders(self):
262+
"""uint16 RGBA image should be normalized and render correctly."""
263+
data = np.zeros((4, 50, 50), dtype=np.uint16)
264+
data[0] = 50000
265+
data[1] = 30000
266+
data[2] = 10000
267+
data[3] = 65535 # fully opaque
268+
img = Image2DModel.parse(data, dims=("c", "y", "x"), c_coords=["r", "g", "b", "a"])
269+
sdata = SpatialData(images={"img": img})
270+
fig, ax = plt.subplots()
271+
sdata.pl.render_images("img").pl.show(ax=ax)
272+
plt.close("all")

0 commit comments

Comments
 (0)