Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,20 @@ def _reparse_points(
df: pd.DataFrame,
transformation: Any,
coordinate_system: str,
color_column: str | None = None,
) -> None:
"""Re-register a points DataFrame in *sdata_filt* with its transformation."""
"""Re-register a points DataFrame in *sdata_filt* with its transformation.

When ``color_column`` is provided and refers to a column that
``PointsModel.parse`` would silently drop because it collides with a
reserved coordinate axis (e.g. a data column literally named ``"z"``),
re-attach it so downstream color lookup can find it as a ``"df"`` origin.
"""
dd_frame = dask.dataframe.from_pandas(df, npartitions=1)
sdata_filt.points[element] = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"})
parsed = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"})
if color_column is not None and color_column in df.columns and color_column not in parsed.columns:
parsed[color_column] = dd_frame[color_column]
sdata_filt.points[element] = parsed
set_transformation(
element=sdata_filt.points[element],
transformation=transformation,
Expand Down Expand Up @@ -820,7 +830,7 @@ def _render_points(

# Convert back to dask dataframe to modify sdata
transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system]
_reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system)
_reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system, col_for_color)

if col_for_color is not None:
assert isinstance(col_for_color, str)
Expand Down Expand Up @@ -877,6 +887,7 @@ def _render_points(
points_pd_with_color,
transformation_in_cs,
coordinate_system,
col_for_color,
)

_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
Expand All @@ -897,7 +908,7 @@ def _render_points(
# filter the materialized points, adata, and re-register in sdata_filt
points = points[keep].reset_index(drop=True)
adata = adata[keep]
_reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system)
_reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system, col_for_color)

# color_source_vector is None when the values aren't categorical
if color_source_vector is None and render_params.transfunc is not None:
Expand Down
39 changes: 39 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,45 @@ def test_no_table_fallback_warning_for_element_column(caplog):
plt.close("all")


def test_render_points_color_by_z_data_column():
    # Regression test for #615: coloring by a data column that is literally
    # named "z" must keep working. PointsModel.parse drops columns whose name
    # collides with a reserved coordinate name, so _reparse_points has to put
    # the requested color column back; otherwise the color lookup crashes.
    frame = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0], "z": [0.1, 0.5, 0.9]})
    points = PointsModel.parse(frame)
    assert "z" in points.columns

    sdata = SpatialData(points={"p": points})
    figure, axes = plt.subplots()
    try:
        # The test passes as long as rendering and showing do not raise.
        rendered = sdata.pl.render_points("p", color="z")
        rendered.pl.show(ax=axes)
    finally:
        # Always release the figure, even when rendering fails.
        plt.close(figure)


def test_render_points_color_by_z_with_extra_columns():
    # Follow-up to #615: re-attaching the color column must leave the other
    # data columns alone. Here the frame carries a "z" column (the one that
    # parse drops on re-registration) but coloring uses the non-conflicting
    # "score" column instead.
    columns = {
        "x": [1.0, 2.0, 3.0],
        "y": [1.0, 2.0, 3.0],
        "z": [0.1, 0.5, 0.9],
        "score": [0.0, 0.5, 1.0],
    }
    points = PointsModel.parse(pd.DataFrame(columns))

    sdata = SpatialData(points={"p": points})
    figure, axes = plt.subplots()
    try:
        # The test passes as long as rendering and showing do not raise.
        rendered = sdata.pl.render_points("p", color="score")
        rendered.pl.show(ax=axes)
    finally:
        # Always release the figure, even when rendering fails.
        plt.close(figure)


def test_render_points_disjoint_instance_ids_clear_error():
# regression test for #603: disjoint instance_id values must raise a clear ValueError
points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))
Expand Down
Loading