Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,10 @@ def _render_shapes(

shapes = sdata_filt[element]

# Capture the transformation *before* any groups filtering that may strip
# coordinate-system metadata from the element (see #420, #447).
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)

# get color vector (categorical or continuous)
color_source_vector, color_vector, _ = _set_color_source_vec(
sdata=sdata_filt,
Expand Down Expand Up @@ -425,9 +429,6 @@ def _render_shapes(
# necessary in case different shapes elements are annotated with one table
color_source_vector = color_source_vector.remove_unused_categories()

# Apply the transformation to the PatchCollection's paths
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)

shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
# convert shapes if necessary
if render_params.shape is not None:
Expand Down
43 changes: 42 additions & 1 deletion tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from shapely.geometry import MultiPolygon, Point, Polygon
from spatialdata import SpatialData, deepcopy
from spatialdata.models import ShapesModel, TableModel
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation, set_transformation
from spatialdata.transformations._utils import _set_transformations

import spatialdata_plot # noqa: F401
Expand Down Expand Up @@ -1067,6 +1067,47 @@ def test_plot_can_handle_non_numeric_radius_values(sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", color="red").pl.show()


def test_groups_filtering_preserves_transformation(sdata_blobs: SpatialData):
"""Regression test for #420: groups filtering must not strip coordinate-system metadata.

Simulates the exact sequence that ``_render_shapes`` performs —
filter_by_coordinate_system -> groups boolean-index -> reset_index ->
re-assign to sdata_filt -> GeoDataFrame re-wrap — then asserts that
``_prepare_transformation`` can still retrieve the correct transformation.
"""
from spatialdata_plot.pl.utils import _prepare_transformation

scale_factor = 2.5
cs = "not_global"
set_transformation(
sdata_blobs["blobs_polygons"],
transformation={cs: Scale([scale_factor, scale_factor], axes=("x", "y"))},
set_all=True,
)
sdata_blobs.shapes["blobs_polygons"]["cluster"] = pd.Categorical(["c1", "c2", "c1", "c2", "c1"])

sdata_filt = sdata_blobs.filter_by_coordinate_system(coordinate_system=cs, filter_tables=False)

# Replicate groups filtering: boolean-index -> reset_index -> re-assign
shapes = sdata_filt.shapes["blobs_polygons"]
keep = shapes["cluster"] == "c1"
shapes = shapes[keep].reset_index(drop=True)
sdata_filt["blobs_polygons"] = shapes
# GeoDataFrame re-wrap strips .attrs (this is what _render_shapes does next)
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")

# sdata_filt's element must still carry the correct transformation
trans, _ = _prepare_transformation(sdata_filt.shapes["blobs_polygons"], cs)
matrix = trans.get_matrix()
np.testing.assert_allclose(matrix[0, 0], scale_factor, err_msg="x-scale lost after groups filtering")
np.testing.assert_allclose(matrix[1, 1], scale_factor, err_msg="y-scale lost after groups filtering")

# The GeoDataFrame re-wrap strips attrs — reading the transform from
# the re-wrapped object must fail, proving why early capture matters.
with pytest.raises(AssertionError):
_prepare_transformation(shapes, cs)


def test_plot_can_handle_mixed_numeric_and_color_data(sdata_blobs: SpatialData):
"""Test that mixed numeric and color-like data raises a clear error."""
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
Expand Down
Loading