Skip to content

Commit fc796d6

Browse files
committed
Raise on color name collision with column in render_* (#619)
When `color=<str>` was both a valid matplotlib color and a column name in the element or an annotating table, the literal color silently won and the column was ignored (only a default-hidden `logger.info` was emitted). Now raise `ValueError` with disambiguation guidance: pass hex (e.g. `"#ffa500"`), an RGB(A) tuple, or rename the column.
1 parent 3749098 commit fc796d6

5 files changed

Lines changed: 106 additions & 2 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,9 @@ def render_shapes(
293293
-----
294294
- Empty geometries will be removed at the time of plotting.
295295
- An `outline_width` of 0.0 leads to no border being plotted.
296-
- When passing a color-like to 'color', this has precedence over the potential existence as a column name.
296+
- If ``color`` is a string that is both a matplotlib color name and a column name in the
297+
element or an annotating table, a ``ValueError`` is raised. Disambiguate by passing
298+
a hex string (e.g. ``"#ffa500"``) or an RGB(A) tuple, or by renaming the column.
297299
298300
Returns
299301
-------

src/spatialdata_plot/pl/utils.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2280,6 +2280,41 @@ def _validate_show_parameters(
22802280
)
22812281

22822282

2283+
def _check_color_column_collision(
2284+
sdata: SpatialData,
2285+
elements: list[str],
2286+
color: str,
2287+
element_type: str,
2288+
) -> None:
2289+
"""Raise if ``color`` is a color-like string that also names a column in the element or its tables."""
2290+
matches: list[str] = []
2291+
for el in elements:
2292+
if element_type in {"shapes", "points"}:
2293+
try:
2294+
el_cols = sdata[el].columns
2295+
except (KeyError, AttributeError):
2296+
el_cols = ()
2297+
if color in el_cols:
2298+
matches.append(f"element '{el}'")
2299+
continue
2300+
try:
2301+
tables = get_element_annotators(sdata, el)
2302+
except (KeyError, ValueError):
2303+
tables = set()
2304+
for t in tables:
2305+
adata = sdata[t]
2306+
if color in adata.obs.columns or color in adata.var_names:
2307+
matches.append(f"table '{t}' (annotating '{el}')")
2308+
break
2309+
if matches:
2310+
locations = ", ".join(sorted(set(matches)))
2311+
raise ValueError(
2312+
f"`color={color!r}` is ambiguous: it is a valid matplotlib color name AND a column "
2313+
f"name in {locations}. Disambiguate by either passing an unambiguous color form "
2314+
f"(hex string like '#ffa500' or an RGB(A) tuple), or by renaming the column."
2315+
)
2316+
2317+
22832318
def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]:
22842319
colorbar = param_dict.get("colorbar", "auto")
22852320
if colorbar not in {True, False, None, "auto"}:
@@ -2330,7 +2365,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
23302365
if not isinstance(color, str | tuple | list):
23312366
raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.")
23322367
if _is_color_like(color):
2333-
logger.info("Value for parameter 'color' appears to be a color, using it as such.")
2368+
if isinstance(color, str):
2369+
_check_color_column_collision(param_dict["sdata"], param_dict["element"], color, element_type)
23342370
param_dict["col_for_color"] = None
23352371
param_dict["color"] = Color(color)
23362372
if param_dict["color"].alpha_is_user_defined():

tests/pl/test_render_labels.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,29 @@ def test_render_labels_disjoint_instance_ids_clear_error():
550550
sdata.pl.render_labels("lbl", color="cat", table_name="t").pl.show(ax=ax)
551551
finally:
552552
plt.close(fig)
553+
554+
555+
def test_render_labels_color_column_name_collision_raises():
556+
# regression test for #619: color="orange" + obs column "orange" must raise.
557+
arr = np.zeros((10, 10), dtype=np.int32)
558+
arr[2:5, 2:5] = 1
559+
arr[6:9, 6:9] = 2
560+
obs = pd.DataFrame(
561+
{
562+
"region": pd.Categorical(["lbl"] * 2),
563+
"instance_id": [1, 2],
564+
"orange": pd.Categorical(["A", "B"]),
565+
}
566+
)
567+
table = TableModel.parse(
568+
AnnData(X=np.zeros((2, 1)), obs=obs),
569+
region="lbl",
570+
region_key="region",
571+
instance_key="instance_id",
572+
)
573+
sdata = SpatialData(labels={"lbl": Labels2DModel.parse(arr, dims=["y", "x"])}, tables={"t": table})
574+
575+
with pytest.raises(ValueError, match=r"color='orange'.*ambiguous.*column"):
576+
sdata.pl.render_labels("lbl", color="orange", table_name="t")
577+
578+
sdata.pl.render_labels("lbl", color="#ffa500", table_name="t")

tests/pl/test_render_points.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,3 +1031,16 @@ def test_render_points_disjoint_instance_ids_clear_error():
10311031
sdata.pl.render_points("pts", color="cat", table_name="t").pl.show(ax=ax)
10321032
finally:
10331033
plt.close(fig)
1034+
1035+
1036+
def test_render_points_color_column_name_collision_raises():
1037+
# regression test for #619: color="orange" + element column "orange" must raise.
1038+
points = PointsModel.parse(
1039+
pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0], "y": [1.0, 2.0, 3.0, 4.0], "orange": [0.1, 0.2, 0.3, 0.4]})
1040+
)
1041+
sdata = SpatialData(points={"pts": points})
1042+
1043+
with pytest.raises(ValueError, match=r"color='orange'.*ambiguous.*column"):
1044+
sdata.pl.render_points("pts", color="orange")
1045+
1046+
sdata.pl.render_points("pts", color="#ffa500")

tests/pl/test_render_shapes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,33 @@ def test_render_shapes_disjoint_instance_ids_clear_error():
13781378
plt.close(fig)
13791379

13801380

1381+
def test_render_shapes_color_column_name_collision_raises():
1382+
# regression test for #619: color="orange" + column "orange" must raise rather than silently
1383+
# treat the value as a literal color and shadow the column.
1384+
n = 4
1385+
shapes = ShapesModel.parse(gpd.GeoDataFrame({"geometry": [Point(i, 0) for i in range(n)], "radius": [0.5] * n}))
1386+
obs = pd.DataFrame(
1387+
{
1388+
"region": pd.Categorical(["s"] * n),
1389+
"instance_id": list(range(n)),
1390+
"orange": pd.Categorical(["A", "B", "A", "B"]),
1391+
}
1392+
)
1393+
table = TableModel.parse(
1394+
AnnData(X=np.zeros((n, 1)), obs=obs),
1395+
region="s",
1396+
region_key="region",
1397+
instance_key="instance_id",
1398+
)
1399+
sdata = SpatialData(shapes={"s": shapes}, tables={"t": table})
1400+
1401+
with pytest.raises(ValueError, match=r"color='orange'.*ambiguous.*column"):
1402+
sdata.pl.render_shapes("s", color="orange")
1403+
1404+
sdata.pl.render_shapes("s", color="#ffa500")
1405+
sdata.pl.render_shapes("s", color=(1.0, 0.65, 0.0))
1406+
1407+
13811408
def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
13821409
"""Datashader colorbar range must not exceed the actual data range for shapes.
13831410

0 commit comments

Comments
 (0)