Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,23 @@ def _add_legend_and_colorbar(
)


def _check_instance_ids_overlap(
sdata: sd.SpatialData,
table_name: str,
element_name: str,
element_index: abc.Iterable[Any],
) -> None:
"""Raise a clear error when a table annotates an element but no instance IDs overlap (#603)."""
_, region_key, instance_key = get_table_keys(sdata[table_name])
annotating = sdata[table_name].obs[sdata[table_name].obs[region_key].isin([element_name])]
if len(annotating) > 0 and set(annotating[instance_key]).isdisjoint(set(element_index)):
raise ValueError(
f"No instance IDs overlap between table '{table_name}' (instance_key='{instance_key}') "
f"and element '{element_name}'. Check that the table's '{instance_key}' column matches the "
f"element's index."
)


def _render_shapes(
sdata: sd.SpatialData,
render_params: ShapesRenderParams,
Expand All @@ -336,6 +353,9 @@ def _render_shapes(
table = None
shapes = sdata_filt[element]
else:
# check before mutating obs.index.name below so a failure leaves no half-restored state
_check_instance_ids_overlap(sdata, table_name, element, sdata_filt[element].index)

# Workaround for upstream spatialdata bug (scverse/spatialdata#1099):
# join_spatialelement_table calls table.obs.reset_index() which fails when
# the obs index name matches an existing column (e.g. "EntityID" in Merfish data).
Expand Down Expand Up @@ -742,6 +762,9 @@ def _render_points(

added_color_from_table = False
if col_for_color is not None and col_for_color not in points.columns:
if table_name is not None:
# guard against disjoint instance IDs (#603) for a clearer error than KeyError: None
_check_instance_ids_overlap(sdata, table_name, element, points.index)
color_values = get_values(
value_key=col_for_color,
sdata=sdata_filt,
Expand Down Expand Up @@ -1651,6 +1674,7 @@ def _render_labels(
instance_id = np.unique(label)
table = None
else:
_check_instance_ids_overlap(sdata, table_name, element, np.unique(label.values))
_, region_key, instance_key = get_table_keys(sdata[table_name])
table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]

Expand Down
29 changes: 29 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,32 @@ def test_render_labels_rejects_background_instance_id_in_table():
sdata.pl.render_labels("lbl", color="score", table_name="t").pl.show(ax=ax)
finally:
plt.close(fig)


def test_render_labels_disjoint_instance_ids_clear_error():
# regression test for #603: disjoint instance_id values must raise a clear ValueError
arr = np.zeros((20, 20), dtype=np.int32)
arr[3:8, 3:8] = 1
arr[12:17, 12:17] = 2
obs = pd.DataFrame(
{
"instance_id": [99, 100], # label has IDs 1, 2 (no overlap)
"region": pd.Categorical(["lbl"] * 2),
"cat": pd.Categorical(["A", "B"]),
}
)
obs.index = obs.index.astype(str)
table = TableModel.parse(
AnnData(X=np.zeros((2, 1)), obs=obs),
region=["lbl"],
region_key="region",
instance_key="instance_id",
)
sdata = SpatialData(labels={"lbl": Labels2DModel.parse(arr, dims=["y", "x"])}, tables={"t": table})

fig, ax = plt.subplots()
try:
with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 'lbl'"):
sdata.pl.render_labels("lbl", color="cat", table_name="t").pl.show(ax=ax)
finally:
plt.close(fig)
27 changes: 27 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,3 +1004,30 @@ def test_no_table_fallback_warning_for_element_column(caplog):
with logger_no_warns(caplog, logger, match="fallback for color mapping"):
sdata.pl.render_points("pts", color="cell_type").pl.show()
plt.close("all")


def test_render_points_disjoint_instance_ids_clear_error():
# regression test for #603: disjoint instance_id values must raise a clear ValueError
points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))
obs = pd.DataFrame(
{
"instance_id": [99, 100, 101], # points index is 0, 1, 2 (no overlap)
"region": pd.Categorical(["pts"] * 3),
"cat": pd.Categorical(["A", "B", "C"]),
}
)
obs.index = obs.index.astype(str)
table = TableModel.parse(
AnnData(X=np.zeros((3, 1)), obs=obs),
region=["pts"],
region_key="region",
instance_key="instance_id",
)
sdata = SpatialData(points={"pts": points}, tables={"t": table})

fig, ax = plt.subplots()
try:
with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 'pts'"):
sdata.pl.render_points("pts", color="cat", table_name="t").pl.show(ax=ax)
finally:
plt.close(fig)
33 changes: 30 additions & 3 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,9 +1076,7 @@ def test_gene_symbols_missing_column_raises_auto_detect(sdata_blobs: SpatialData
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles"
sdata_blobs["table"].var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]
with pytest.raises(KeyError, match="`gene_symbols=`"):
sdata_blobs.pl.render_shapes(
"blobs_circles", color="GeneA", gene_symbols="WRONGCOL"
).pl.show()
sdata_blobs.pl.render_shapes("blobs_circles", color="GeneA", gene_symbols="WRONGCOL").pl.show()


def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
Expand Down Expand Up @@ -1351,6 +1349,35 @@ def test_render_shapes_color_with_conflicting_index_name():
sdata.pl.render_shapes("shapes", color="cell_type", table_name="table").pl.show()


def test_render_shapes_disjoint_instance_ids_clear_error():
# regression test for #603: disjoint instance_id values must raise a clear ValueError
shapes = ShapesModel.parse(
gpd.GeoDataFrame({"geometry": [Point(5, 5), Point(15, 5), Point(25, 5)], "radius": [2.0] * 3})
)
obs = pd.DataFrame(
{
"instance_id": [99, 100, 101], # element has IDs 0, 1, 2 (no overlap)
"region": pd.Categorical(["s"] * 3),
"cat": pd.Categorical(["A", "B", "C"]),
}
)
obs.index = obs.index.astype(str)
table = TableModel.parse(
AnnData(X=np.zeros((3, 1)), obs=obs),
region=["s"],
region_key="region",
instance_key="instance_id",
)
sdata = SpatialData(shapes={"s": shapes}, tables={"t": table})

fig, ax = plt.subplots()
try:
with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 's'"):
sdata.pl.render_shapes("s", color="cat", table_name="t").pl.show(ax=ax)
finally:
plt.close(fig)


def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
"""Datashader colorbar range must not exceed the actual data range for shapes.

Expand Down
Loading