Skip to content

Commit 9b0cf72

Browse files
committed
Preserve points data columns dropped by PointsModel.parse (#615)
PointsModel.parse silently strips columns whose names collide with reserved spatial-axis names (currently "z"). When _reparse_points re-registered a points element, a data column named "z" requested via color="z" was dropped before color lookup ran, producing a misleading KeyError. _reparse_points now takes the color column name and re-attaches it from the source DataFrame when parsing dropped it, so coloring by a data column that shadows a reserved axis name works.
1 parent db17a4f commit 9b0cf72

2 files changed

Lines changed: 54 additions & 4 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,20 @@ def _reparse_points(
117117
df: pd.DataFrame,
118118
transformation: Any,
119119
coordinate_system: str,
120+
color_column: str | None = None,
120121
) -> None:
121-
"""Re-register a points DataFrame in *sdata_filt* with its transformation."""
122+
"""Re-register a points DataFrame in *sdata_filt* with its transformation.
123+
124+
When ``color_column`` is provided and refers to a column that
125+
``PointsModel.parse`` would silently drop because it collides with a
126+
reserved coordinate axis (e.g. a data column literally named ``"z"``),
127+
re-attach it so downstream color lookup can find it as a ``"df"`` origin.
128+
"""
122129
dd_frame = dask.dataframe.from_pandas(df, npartitions=1)
123-
sdata_filt.points[element] = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"})
130+
parsed = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"})
131+
if color_column is not None and color_column in df.columns and color_column not in parsed.columns:
132+
parsed[color_column] = dd_frame[color_column]
133+
sdata_filt.points[element] = parsed
124134
set_transformation(
125135
element=sdata_filt.points[element],
126136
transformation=transformation,
@@ -820,7 +830,7 @@ def _render_points(
820830

821831
# Convert back to dask dataframe to modify sdata
822832
transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system]
823-
_reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system)
833+
_reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system, col_for_color)
824834

825835
if col_for_color is not None:
826836
assert isinstance(col_for_color, str)
@@ -877,6 +887,7 @@ def _render_points(
877887
points_pd_with_color,
878888
transformation_in_cs,
879889
coordinate_system,
890+
col_for_color,
880891
)
881892

882893
_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
@@ -897,7 +908,7 @@ def _render_points(
897908
# filter the materialized points, adata, and re-register in sdata_filt
898909
points = points[keep].reset_index(drop=True)
899910
adata = adata[keep]
900-
_reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system)
911+
_reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system, col_for_color)
901912

902913
# color_source_vector is None when the values aren't categorical
903914
if color_source_vector is None and render_params.transfunc is not None:

tests/pl/test_render_points.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,45 @@ def test_no_table_fallback_warning_for_element_column(caplog):
10061006
plt.close("all")
10071007

10081008

1009+
def test_render_points_color_by_z_data_column():
1010+
# regression test for #615: a data column named "z" must remain usable
1011+
# for coloring. PointsModel.parse drops columns colliding with reserved
1012+
# coordinate names; _reparse_points must re-attach the requested color
1013+
# column when needed so color lookup does not crash.
1014+
pts = PointsModel.parse(
1015+
pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0], "z": [0.1, 0.5, 0.9]}),
1016+
)
1017+
assert "z" in pts.columns
1018+
sdata = SpatialData(points={"p": pts})
1019+
fig, ax = plt.subplots()
1020+
try:
1021+
sdata.pl.render_points("p", color="z").pl.show(ax=ax)
1022+
finally:
1023+
plt.close(fig)
1024+
1025+
1026+
def test_render_points_color_by_z_with_extra_columns():
1027+
# #615 follow-up: re-attaching the color column must not disturb other
1028+
# data columns. Color by a non-conflicting column on a frame that also
1029+
# carries a "z" column, which PointsModel.parse drops when the element is re-parsed.
1030+
pts = PointsModel.parse(
1031+
pd.DataFrame(
1032+
{
1033+
"x": [1.0, 2.0, 3.0],
1034+
"y": [1.0, 2.0, 3.0],
1035+
"z": [0.1, 0.5, 0.9],
1036+
"score": [0.0, 0.5, 1.0],
1037+
}
1038+
),
1039+
)
1040+
sdata = SpatialData(points={"p": pts})
1041+
fig, ax = plt.subplots()
1042+
try:
1043+
sdata.pl.render_points("p", color="score").pl.show(ax=ax)
1044+
finally:
1045+
plt.close(fig)
1046+
1047+
10091048
def test_render_points_disjoint_instance_ids_clear_error():
10101049
# regression test for #603: disjoint instance_id values must raise a clear ValueError
10111050
points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))

0 commit comments

Comments
 (0)