diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a675e00e..8f874fd0 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2479,9 +2479,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st palette_group = param_dict.get("palette") if element_type in ["shapes", "points", "labels"] and palette_group is not None and not isinstance(palette, dict): groups = param_dict.get("groups") - if groups is None: - raise ValueError("When specifying 'palette', 'groups' must also be specified.") - if len(groups) != len(palette_group): + if groups is not None and len(groups) != len(palette_group): raise ValueError( f"The length of 'palette' and 'groups' must be the same, length is {len(palette_group)} and" f"{len(groups)} respectively." diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 815071d9..fcf520b6 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -275,6 +275,23 @@ def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData): "blobs_polygons", color="cluster", groups=["c2", "c1"], palette=["green", "yellow"] ).pl.show() + def test_render_shapes_list_palette_without_groups(self, sdata_blobs: SpatialData): + # Regression test for #605: a list palette should map to categories in their natural order + # without requiring groups= to enumerate every category. + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1" + sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2" + sdata_blobs.shapes["blobs_polygons"]["cluster"] = sdata_blobs.shapes["blobs_polygons"]["cluster"].astype( + "category" + ) + + _, ax = plt.subplots() + sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", palette=["green", "yellow"]).pl.show(ax=ax) + legend = ax.get_legend() + assert legend is not None + assert {t.get_text() for t in legend.get_texts()} == {"c1", "c2"} + def test_plot_colorbar_respects_input_limits(self, sdata_blobs: SpatialData): sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"