Skip to content

Commit 1e92e81

Browse files
timtreis and claude
committed
Fix shapes datashader colorbar exceeding data range (#559)
The default datashader reduction for shapes was "sum", causing overlapping shapes to inflate the colorbar beyond the true data maximum. Changed the default to "max" which preserves the actual data range and closely matches the matplotlib rendering. Also: extract _default_reduction to prevent log/aggregation drift, add logger_no_warns test helper, short-circuit _want_decorations for diverse color vectors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1b195f0 commit 1e92e81

6 files changed

Lines changed: 101 additions & 46 deletions

File tree

src/spatialdata_plot/_logging.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,40 @@ def logger_warns(
113113
if not any(pattern.search(r.getMessage()) for r in records):
114114
msgs = [r.getMessage() for r in records]
115115
raise AssertionError(f"Did not find log matching {match!r} in records: {msgs!r}")
116+
117+
118+
@contextmanager
119+
def logger_no_warns(
120+
caplog: LogCaptureFixture,
121+
logger: logging.Logger,
122+
match: str | None = None,
123+
level: int = logging.WARNING,
124+
) -> Iterator[None]:
125+
"""Assert that no log record matching *match* is emitted.
126+
127+
Counterpart to :func:`logger_warns`.
128+
"""
129+
initial_record_count = len(caplog.records)
130+
131+
handler = caplog.handler
132+
logger.addHandler(handler)
133+
original_level = logger.level
134+
logger.setLevel(level)
135+
136+
with caplog.at_level(level, logger=logger.name):
137+
try:
138+
yield
139+
finally:
140+
logger.removeHandler(handler)
141+
logger.setLevel(original_level)
142+
143+
records = [r for r in caplog.records[initial_record_count:] if r.levelno >= level]
144+
145+
if match is not None:
146+
pattern = re.compile(match)
147+
matching = [r.getMessage() for r in records if pattern.search(r.getMessage())]
148+
if matching:
149+
raise AssertionError(f"Found unexpected log matching {match!r}: {matching!r}")
150+
elif records:
151+
msgs = [r.getMessage() for r in records]
152+
raise AssertionError(f"Expected no log records at level>={level}, but got: {msgs!r}")

src/spatialdata_plot/pl/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ def render_shapes(
272272
273273
datashader_reduction : Literal[
274274
"sum", "mean", "any", "count", "std", "var", "max", "min"
275-
], default: "sum"
276-
Reduction method for datashader when coloring by continuous values. Defaults to 'sum'.
275+
], default: "max"
276+
Reduction method for datashader when coloring by continuous values. Defaults to 'max'.
277277
278278
279279
Notes

src/spatialdata_plot/pl/render.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
_ds_aggregate,
3535
_ds_shade_categorical,
3636
_ds_shade_continuous,
37+
_DsReduction,
3738
_render_ds_image,
3839
_render_ds_outlines,
3940
)
@@ -82,17 +83,26 @@ def _want_decorations(color_vector: Any, na_color: Color) -> bool:
8283
cv = np.asarray(color_vector)
8384
if cv.size == 0:
8485
return False
85-
# Fast check: if any value differs from the first, there is variety → show decorations.
8686
first = cv.flat[0]
8787
if not (cv == first).all():
8888
return True
89-
# All values are the same — suppress decorations when that value is the NA color.
9089
na_hex = na_color.get_hex()
9190
if isinstance(first, str) and first.startswith("#") and na_hex.startswith("#"):
9291
return _hex_no_alpha(first) != _hex_no_alpha(na_hex)
9392
return bool(first != na_hex)
9493

9594

95+
def _log_datashader_method(method: str, ds_reduction: _DsReduction | None, default: _DsReduction) -> None:
96+
"""Log the datashader backend and effective reduction being used."""
97+
effective = ds_reduction if ds_reduction is not None else default
98+
logger.info(
99+
f"Using '{method}' backend with '{effective}' as reduction"
100+
" method to speed up plotting. Depending on the reduction method, the value"
101+
" range of the plot might change. Set method to 'matplotlib' to disable"
102+
" this behaviour."
103+
)
104+
105+
96106
def _reparse_points(
97107
sdata_filt: sd.SpatialData,
98108
element: str,
@@ -437,14 +447,10 @@ def _render_shapes(
437447
if method is None:
438448
method = "datashader" if len(shapes) > 10000 else "matplotlib"
439449

450+
_default_reduction: _DsReduction = "max"
451+
440452
if method != "matplotlib":
441-
# we only notify the user when we switched away from matplotlib
442-
logger.info(
443-
f"Using '{method}' backend with '{render_params.ds_reduction}' as reduction"
444-
" method to speed up plotting. Depending on the reduction method, the value"
445-
" range of the plot might change. Set method to 'matplotlib' to disable"
446-
" this behaviour."
447-
)
453+
_log_datashader_method(method, render_params.ds_reduction, _default_reduction)
448454

449455
if method == "datashader":
450456
_geometry = shapes["geometry"]
@@ -518,7 +524,7 @@ def _render_shapes(
518524
col_for_color,
519525
color_by_categorical,
520526
render_params.ds_reduction,
521-
"mean",
527+
_default_reduction,
522528
"shapes",
523529
)
524530

@@ -851,14 +857,10 @@ def _render_points(
851857
if method is None:
852858
method = "datashader" if n_points > 10000 else "matplotlib"
853859

860+
_default_reduction: _DsReduction = "sum"
861+
854862
if method == "datashader":
855-
# we only notify the user when we switched away from matplotlib
856-
logger.info(
857-
f"Using '{method}' backend with '{render_params.ds_reduction}' as reduction"
858-
" method to speed up plotting. Depending on the reduction method, the value"
859-
" range of the plot might change. Set method to 'matplotlib' do disable"
860-
" this behaviour."
861-
)
863+
_log_datashader_method(method, render_params.ds_reduction, _default_reduction)
862864

863865
# NOTE: s in matplotlib is in units of points**2
864866
# use dpi/100 as a factor for cases where dpi!=100
@@ -917,7 +919,7 @@ def _render_points(
917919
col_for_color,
918920
color_by_categorical,
919921
render_params.ds_reduction,
920-
"sum",
922+
_default_reduction,
921923
"points",
922924
)
923925

src/spatialdata_plot/pl/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2507,9 +2507,6 @@ def _ensure_table_and_layer_exist_in_sdata(
25072507
if ds_reduction and (ds_reduction not in valid_ds_reduction_methods):
25082508
raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.")
25092509

2510-
if method == "datashader" and ds_reduction is None:
2511-
param_dict["ds_reduction"] = "sum"
2512-
25132510
return param_dict
25142511

25152512

tests/pl/test_render_points.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
import math
32

43
import dask.dataframe
@@ -24,7 +23,7 @@
2423
from spatialdata.transformations._utils import _set_transformations
2524

2625
import spatialdata_plot # noqa: F401
27-
from spatialdata_plot._logging import logger, logger_warns
26+
from spatialdata_plot._logging import logger, logger_no_warns, logger_warns
2827
from spatialdata_plot.pl._datashader import (
2928
_build_datashader_color_key,
3029
_ds_aggregate,
@@ -832,13 +831,8 @@ def test_ds_reduction_ignored_for_categorical(caplog):
832831
def test_ds_reduction_no_warning_when_none(caplog):
833832
"""No spurious warning when ds_reduction is None (the default)."""
834833
cvs, df = _make_ds_canvas_and_df()
835-
with caplog.at_level(logging.WARNING, logger=logger.name):
836-
logger.addHandler(caplog.handler)
837-
try:
838-
_ds_aggregate(cvs, df.copy(), "cat", True, None, "sum", "points")
839-
finally:
840-
logger.removeHandler(caplog.handler)
841-
assert not any("ignored" in r.message.lower() for r in caplog.records)
834+
with logger_no_warns(caplog, logger, match="ignored"):
835+
_ds_aggregate(cvs, df.copy(), "cat", True, None, "sum", "points")
842836

843837

844838
@pytest.mark.parametrize("reduction", ["mean", "max", "min", "count", "std", "var"])
@@ -866,13 +860,8 @@ def test_warn_groups_ignored_continuous_emits(caplog):
866860

867861
def test_warn_groups_ignored_continuous_silent_for_categorical(caplog):
868862
"""No warning when color_source_vector is present (categorical)."""
869-
with caplog.at_level(logging.WARNING, logger=logger.name):
870-
logger.addHandler(caplog.handler)
871-
try:
872-
_warn_groups_ignored_continuous(["A"], pd.Categorical(["A", "B"]), "cat_col")
873-
finally:
874-
logger.removeHandler(caplog.handler)
875-
assert not any("ignored" in r.message for r in caplog.records)
863+
with logger_no_warns(caplog, logger, match="ignored"):
864+
_warn_groups_ignored_continuous(["A"], pd.Categorical(["A", "B"]), "cat_col")
876865

877866

878867
def test_color_key_warns_on_short_color_vector(caplog):
@@ -893,13 +882,8 @@ def test_color_key_warns_on_long_color_vector(caplog):
893882
def test_color_key_no_warning_when_lengths_match(caplog):
894883
"""No warning when lengths match."""
895884
cat = pd.Categorical(["A", "B", "C"])
896-
with caplog.at_level(logging.WARNING, logger=logger.name):
897-
logger.addHandler(caplog.handler)
898-
try:
899-
_build_datashader_color_key(cat, ["#ff0000", "#00ff00", "#0000ff"], "#cccccc")
900-
finally:
901-
logger.removeHandler(caplog.handler)
902-
assert not any("color_vector length" in r.message for r in caplog.records)
885+
with logger_no_warns(caplog, logger, match="color_vector length"):
886+
_build_datashader_color_key(cat, ["#ff0000", "#00ff00", "#0000ff"], "#cccccc")
903887

904888

905889
def test_color_key_unseen_category_gets_na_color(caplog):

tests/pl/test_render_shapes.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,3 +1175,38 @@ def test_render_shapes_color_with_conflicting_index_name():
11751175

11761176
# Should not raise ValueError: cannot insert EntityID, already exists
11771177
sdata.pl.render_shapes("shapes", color="cell_type", table_name="table").pl.show()
1178+
1179+
1180+
def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
1181+
"""Datashader colorbar range must not exceed the actual data range for shapes.
1182+
1183+
Regression test for https://github.com/scverse/spatialdata-plot/issues/559.
1184+
Before the fix, shapes defaulted to 'sum' aggregation, causing overlapping
1185+
shapes to inflate the colorbar beyond the true data maximum.
1186+
"""
1187+
n = len(sdata_blobs.shapes["blobs_circles"])
1188+
rng = np.random.default_rng(0)
1189+
values = rng.uniform(0, 100, size=n)
1190+
sdata_blobs.shapes["blobs_circles"]["continuous_val"] = values
1191+
data_max = float(values.max())
1192+
data_min = float(values.min())
1193+
1194+
fig, ax = plt.subplots()
1195+
sdata_blobs.pl.render_shapes("blobs_circles", color="continuous_val", method="datashader").pl.show(ax=ax)
1196+
1197+
# Find the colorbar axis — it's a child axes with a ScalarMappable
1198+
cbar_vmax = None
1199+
cbar_vmin = None
1200+
for child in fig.get_children():
1201+
if isinstance(child, matplotlib.axes.Axes) and child is not ax:
1202+
ylim = child.get_ylim()
1203+
if ylim != (0.0, 1.0): # colorbar axes have non-default limits
1204+
cbar_vmin, cbar_vmax = ylim
1205+
1206+
assert cbar_vmax is not None, "Could not find colorbar in figure"
1207+
assert cbar_vmax <= data_max * 1.01, (
1208+
f"Colorbar max ({cbar_vmax:.2f}) exceeds data max ({data_max:.2f}); "
1209+
"datashader aggregation is likely using 'sum' instead of 'max'"
1210+
)
1211+
assert cbar_vmin >= data_min * 0.99 - 0.01, f"Colorbar min ({cbar_vmin:.2f}) is below data min ({data_min:.2f})"
1212+
plt.close(fig)

0 commit comments

Comments
 (0)