diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py
index 7ddd957f..84130e82 100644
--- a/src/spatialdata_plot/pl/render.py
+++ b/src/spatialdata_plot/pl/render.py
@@ -311,6 +311,49 @@ def _add_legend_and_colorbar(
     )
 
 
+def _check_instance_ids_overlap(
+    sdata: sd.SpatialData,
+    table_name: str,
+    element_name: str,
+    element_index: abc.Iterable[Any],
+) -> None:
+    """Raise a clear error when a table annotates an element but no instance IDs overlap (#603).
+
+    A table without any row for ``element_name`` passes silently; only the case
+    "rows exist for the element, yet none of their instance IDs occurs in it" is rejected.
+
+    Parameters
+    ----------
+    sdata
+        Object holding the table; only ``sdata[table_name]`` is accessed.
+    table_name
+        Key of the annotating table in ``sdata``.
+    element_name
+        Name of the element, matched against the table's region column.
+    element_index
+        Instance IDs present in the element (e.g. its index or its unique label values).
+
+    Raises
+    ------
+    ValueError
+        If the table has rows for ``element_name`` but none of their instance IDs is in ``element_index``.
+    """
+    table = sdata[table_name]
+    _, region_key, instance_key = get_table_keys(table)
+    annotating = table.obs[table.obs[region_key].isin([element_name])]
+    if annotating.empty:
+        # the table has no rows for this element, so there is nothing to cross-check
+        return
+    # isdisjoint() accepts any iterable, so the element index is consumed directly
+    # instead of being copied into a second set first
+    if set(annotating[instance_key]).isdisjoint(element_index):
+        raise ValueError(
+            f"No instance IDs overlap between table '{table_name}' (instance_key='{instance_key}') "
+            f"and element '{element_name}'. Check that the table's '{instance_key}' column matches the "
+            f"element's index."
+        )
+
+
 def _render_shapes(
     sdata: sd.SpatialData,
     render_params: ShapesRenderParams,
@@ -336,6 +379,8 @@ def _render_shapes(
         table = None
         shapes = sdata_filt[element]
     else:
+        _check_instance_ids_overlap(sdata_filt, table_name, element, sdata_filt[element].index)
+
         # Workaround for upstream spatialdata bug (scverse/spatialdata#1099):
         # join_spatialelement_table calls table.obs.reset_index() which fails when
         # the obs index name matches an existing column (e.g. "EntityID" in Merfish data).
@@ -742,6 +787,9 @@ def _render_points(
 
     added_color_from_table = False
     if col_for_color is not None and col_for_color not in points.columns:
+        if table_name is not None:
+            # guard against disjoint instance IDs (#603) for a clearer error than KeyError: None
+            _check_instance_ids_overlap(sdata_filt, table_name, element, points.index)
         color_values = get_values(
             value_key=col_for_color,
             sdata=sdata_filt,
@@ -1651,6 +1699,7 @@ def _render_labels(
         instance_id = np.unique(label)
         table = None
     else:
+        _check_instance_ids_overlap(sdata_filt, table_name, element, np.unique(label.values))
         _, region_key, instance_key = get_table_keys(sdata[table_name])
         table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]
 
diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py
index 431a0dc1..931ac24c 100644
--- a/tests/pl/test_render_labels.py
+++ b/tests/pl/test_render_labels.py
@@ -521,3 +521,32 @@ def test_render_labels_rejects_background_instance_id_in_table():
             sdata.pl.render_labels("lbl", color="score", table_name="t").pl.show(ax=ax)
     finally:
         plt.close(fig)
+
+
+def test_render_labels_disjoint_instance_ids_clear_error():
+    # regression test for #603: disjoint instance_id values must raise a clear ValueError
+    arr = np.zeros((20, 20), dtype=np.int32)
+    arr[3:8, 3:8] = 1
+    arr[12:17, 12:17] = 2
+    obs = pd.DataFrame(
+        {
+            "instance_id": [99, 100],  # label has IDs 1, 2 (no overlap)
+            "region": pd.Categorical(["lbl"] * 2),
+            "cat": pd.Categorical(["A", "B"]),
+        }
+    )
+    obs.index = obs.index.astype(str)
+    table = TableModel.parse(
+        AnnData(X=np.zeros((2, 1)), obs=obs),
+        region=["lbl"],
+        region_key="region",
+        instance_key="instance_id",
+    )
+    sdata = SpatialData(labels={"lbl": Labels2DModel.parse(arr, dims=["y", "x"])}, tables={"t": table})
+
+    fig, ax = plt.subplots()
+    try:
+        with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 'lbl'"):
+            sdata.pl.render_labels("lbl", color="cat", table_name="t").pl.show(ax=ax)
+    finally:
+        plt.close(fig)
diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py
index bdb3218a..dcc4267c 100644
--- a/tests/pl/test_render_points.py
+++ b/tests/pl/test_render_points.py
@@ -1004,3 +1004,30 @@ def test_no_table_fallback_warning_for_element_column(caplog):
     with logger_no_warns(caplog, logger, match="fallback for color mapping"):
         sdata.pl.render_points("pts", color="cell_type").pl.show()
     plt.close("all")
+
+
+def test_render_points_disjoint_instance_ids_clear_error():
+    # regression test for #603: disjoint instance_id values must raise a clear ValueError
+    points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))
+    obs = pd.DataFrame(
+        {
+            "instance_id": [99, 100, 101],  # points index is 0, 1, 2 (no overlap)
+            "region": pd.Categorical(["pts"] * 3),
+            "cat": pd.Categorical(["A", "B", "C"]),
+        }
+    )
+    obs.index = obs.index.astype(str)
+    table = TableModel.parse(
+        AnnData(X=np.zeros((3, 1)), obs=obs),
+        region=["pts"],
+        region_key="region",
+        instance_key="instance_id",
+    )
+    sdata = SpatialData(points={"pts": points}, tables={"t": table})
+
+    fig, ax = plt.subplots()
+    try:
+        with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 'pts'"):
+            sdata.pl.render_points("pts", color="cat", table_name="t").pl.show(ax=ax)
+    finally:
+        plt.close(fig)
diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py
index e2c436e4..a401fef8 100644
--- a/tests/pl/test_render_shapes.py
+++ b/tests/pl/test_render_shapes.py
@@ -1076,9 +1076,7 @@ def test_gene_symbols_missing_column_raises_auto_detect(sdata_blobs: SpatialData
     sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles"
     sdata_blobs["table"].var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]
     with pytest.raises(KeyError, match="`gene_symbols=`"):
-        sdata_blobs.pl.render_shapes(
-            "blobs_circles", color="GeneA", gene_symbols="WRONGCOL"
-        ).pl.show()
+        sdata_blobs.pl.render_shapes("blobs_circles", color="GeneA", gene_symbols="WRONGCOL").pl.show()
 
 
 def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
@@ -1351,6 +1349,35 @@ def test_render_shapes_color_with_conflicting_index_name():
     sdata.pl.render_shapes("shapes", color="cell_type", table_name="table").pl.show()
 
 
+def test_render_shapes_disjoint_instance_ids_clear_error():
+    # regression test for #603: disjoint instance_id values must raise a clear ValueError
+    shapes = ShapesModel.parse(
+        gpd.GeoDataFrame({"geometry": [Point(5, 5), Point(15, 5), Point(25, 5)], "radius": [2.0] * 3})
+    )
+    obs = pd.DataFrame(
+        {
+            "instance_id": [99, 100, 101],  # element has IDs 0, 1, 2 (no overlap)
+            "region": pd.Categorical(["s"] * 3),
+            "cat": pd.Categorical(["A", "B", "C"]),
+        }
+    )
+    obs.index = obs.index.astype(str)
+    table = TableModel.parse(
+        AnnData(X=np.zeros((3, 1)), obs=obs),
+        region=["s"],
+        region_key="region",
+        instance_key="instance_id",
+    )
+    sdata = SpatialData(shapes={"s": shapes}, tables={"t": table})
+
+    fig, ax = plt.subplots()
+    try:
+        with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 's'"):
+            sdata.pl.render_shapes("s", color="cat", table_name="t").pl.show(ax=ax)
+    finally:
+        plt.close(fig)
+
+
 def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
     """Datashader colorbar range must not exceed the actual data range for shapes.
 