Skip to content

Commit f2404e0

Browse files
timtreisclaude
andcommitted
Fix IndexError in _build_datashader_color_key and restore _hex_no_alpha
- Convert color_vector to numpy array before indexing to avoid IndexError when color_vector is shorter than cat_series (fixes CI failure in test_color_key_unseen_category_gets_na_color) - Guard against out-of-bounds index when unique code appears beyond the color_vector length - Revert inline c[:7] slice back to _hex_no_alpha for correctness (the slice silently mangles non-hex 9-char strings) - Use explicit `is not None` check for col_for_color guard Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 54dd3db commit f2404e0

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

src/spatialdata_plot/pl/_datashader.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,11 @@ def _build_datashader_color_key(
6161
) -> dict[str, str]:
6262
"""Build a datashader ``color_key`` dict from a categorical series and its color vector."""
6363
na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex
64+
colors_arr = np.asarray(color_vector, dtype=object)
6465
categories = np.asarray(cat_series.categories, dtype=str)
6566
codes = np.asarray(cat_series.codes)
6667

67-
if len(color_vector) != len(codes):
68+
if len(colors_arr) != len(codes):
6869
logger.warning(
6970
f"color_vector length ({len(color_vector)}) does not match categorical series length "
7071
f"({len(codes)}); some categories may receive the na_color fallback."
@@ -76,9 +77,9 @@ def _build_datashader_color_key(
7677

7778
first_color: dict[str, str] = {}
7879
for code, idx in zip(unique_codes, first_indices, strict=True):
79-
if code < 0:
80+
if code < 0 or idx >= len(colors_arr):
8081
continue
81-
c = color_vector[idx]
82+
c = colors_arr[idx]
8283
first_color[categories[code]] = _hex_no_alpha(c) if isinstance(c, str) and c.startswith("#") else c
8384

8485
return {cat: first_color.get(cat, na_hex) for cat in categories}

src/spatialdata_plot/pl/render.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def _render_points(
786786

787787
# When color was already loaded from a table (line 690), pass it directly
788788
# to avoid a redundant get_values() call inside _set_color_source_vec.
789-
_preloaded = points_pd_with_color[col_for_color] if added_color_from_table and col_for_color else None
789+
_preloaded = points_pd_with_color[col_for_color] if added_color_from_table and col_for_color is not None else None
790790

791791
color_source_vector, color_vector, _ = _set_color_source_vec(
792792
sdata=sdata_filt,
@@ -926,7 +926,7 @@ def _render_points(
926926
and isinstance(color_vector[0], str)
927927
and color_vector[0].startswith("#")
928928
):
929-
color_vector = np.asarray([c[:7] if len(c) == 9 else c for c in color_vector])
929+
color_vector = np.asarray([_hex_no_alpha(c) for c in color_vector])
930930

931931
nan_shaded = None
932932
if color_by_categorical or col_for_color is None:

0 commit comments

Comments
 (0)