Skip to content

Commit e25a1ea

Browse files
committed
add tests
1 parent 0ee9396 commit e25a1ea

4 files changed

Lines changed: 58 additions & 6 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ def _render_labels(
10301030
mask = np.isin(instance_id, labels_in_rasterized_image)
10311031
instance_id = instance_id[mask]
10321032
color_vector = color_vector[mask]
1033-
if pd.api.types.is_categorical_dtype(color_vector.dtype):
1033+
if isinstance(color_vector.dtype, pd.CategoricalDtype):
10341034
color_vector = color_vector.remove_unused_categories()
10351035
if color_source_vector is not None:
10361036
color_source_vector = color_source_vector[mask]
59.9 KB
Loading
65.9 KB
Loading

tests/pl/test_render_labels.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from matplotlib.colors import Normalize
1010
from spatial_image import to_spatial_image
1111
from spatialdata import SpatialData, deepcopy, get_element_instances
12-
from spatialdata.models import TableModel
12+
from spatialdata.models import Labels2DModel, TableModel
1313

1414
import spatialdata_plot # noqa: F401
1515
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
@@ -76,7 +76,13 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData):
7676
fill_alpha=1,
7777
outline_alpha=0,
7878
)
79-
.pl.render_labels(element="blobs_labels", na_color="blue", fill_alpha=0, outline_alpha=1, contour_px=15)
79+
.pl.render_labels(
80+
element="blobs_labels",
81+
na_color="blue",
82+
fill_alpha=0,
83+
outline_alpha=1,
84+
contour_px=15,
85+
)
8086
.pl.show()
8187
)
8288

@@ -146,7 +152,11 @@ def test_plot_two_calls_with_coloring_result_in_two_colorbars(self, sdata_blobs:
146152

147153
def test_plot_can_control_label_outline(self, sdata_blobs: SpatialData):
148154
sdata_blobs.pl.render_labels(
149-
"blobs_labels", color="channel_0_sum", outline_alpha=0.4, fill_alpha=0.0, contour_px=15
155+
"blobs_labels",
156+
color="channel_0_sum",
157+
outline_alpha=0.4,
158+
fill_alpha=0.0,
159+
contour_px=15,
150160
).pl.show()
151161

152162
def test_plot_can_control_label_infill(self, sdata_blobs: SpatialData):
@@ -162,7 +172,11 @@ def test_plot_label_colorbar_uses_alpha_of_less_transparent_infill(
162172
sdata_blobs: SpatialData,
163173
):
164174
sdata_blobs.pl.render_labels(
165-
"blobs_labels", color="channel_0_sum", fill_alpha=0.1, outline_alpha=0.7, contour_px=15
175+
"blobs_labels",
176+
color="channel_0_sum",
177+
fill_alpha=0.1,
178+
outline_alpha=0.7,
179+
contour_px=15,
166180
).pl.show()
167181

168182
def test_plot_label_colorbar_uses_alpha_of_less_transparent_outline(
@@ -233,7 +247,10 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str
233247

234248
def test_plot_can_color_with_norm_and_clipping(self, sdata_blobs: SpatialData):
235249
sdata_blobs.pl.render_labels(
236-
"blobs_labels", color="channel_0_sum", norm=Normalize(400, 1000, clip=True), cmap=_viridis_with_under_over()
250+
"blobs_labels",
251+
color="channel_0_sum",
252+
norm=Normalize(400, 1000, clip=True),
253+
cmap=_viridis_with_under_over(),
237254
).pl.show()
238255

239256
def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
@@ -247,3 +264,38 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
247264
def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData):
248265
sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape)
249266
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()
267+
268+
def _prepare_small_labels(self, sdata_blobs: SpatialData) -> SpatialData:
269+
# add a categorical column
270+
adata = sdata_blobs["table"]
271+
sdata_blobs["table"].obs["category"] = ["a"] * 10 + ["b"] * 10 + ["c"] * 6
272+
273+
sdata_blobs["table"].obs["category"] = sdata_blobs["table"].obs["category"].astype("category")
274+
275+
labels = sdata_blobs["blobs_labels"].data.compute()
276+
277+
# make label 1 small
278+
mask = labels == 1
279+
labels[mask] = 0
280+
labels[200, 200] = 1
281+
282+
sdata_blobs["blobs_labels"] = Labels2DModel.parse(labels)
283+
284+
# tile the labels object
285+
arr = da.tile(sdata_blobs["blobs_labels"], (4, 4))
286+
sdata_blobs["blobs_labels_large"] = Labels2DModel.parse(arr)
287+
288+
adata.obs["region"] = "blobs_labels_large"
289+
sdata_blobs.set_table_annotates_spatialelement("table", region="blobs_labels_large")
290+
return sdata_blobs
291+
292+
def test_plot_can_handle_dropping_small_labels_after_rasterize_continuous(self, sdata_blobs: SpatialData):
293+
# reported here https://github.com/scverse/spatialdata-plot/issues/443
294+
sdata_blobs = self._prepare_small_labels(sdata_blobs)
295+
296+
sdata_blobs.pl.render_labels("blobs_labels_large", color="channel_0_sum", table_name="table").pl.show()
297+
298+
def test_plot_can_handle_dropping_small_labels_after_rasterize_categorical(self, sdata_blobs: SpatialData):
299+
sdata_blobs = self._prepare_small_labels(sdata_blobs)
300+
301+
sdata_blobs.pl.render_labels("blobs_labels_large", color="category", table_name="table").pl.show()

0 commit comments

Comments
 (0)