Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class ImageTilesDataset(Dataset):
system; this back-transforms the target tile into the pixel coordinates. If the back-transformed tile is not
aligned with the pixel grid, the returned tile will correspond to the bounding box of the back-transformed tile
(so that the returned tile is axis-aligned to the pixel grid).
return_genes:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Two comments:

  1. I would specify that the layers are AnnData layers and the default layer is X.
  2. I would also allow to pass just a list instead of a dict, that would be interpreted as {'X': genes_list}

If not `None`, return the gene expression values from the table. The dictionary should have the following
structure: `{"layer_name": None}` or `{"layer": ["gene_name1", "gene_name2"]}`.
If the value is `None`, all the genes are returned.
return_annotations
If not `None`, one or more values from the table are returned together with the image tile in a tuple.
Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` can be returned.
Expand Down Expand Up @@ -122,6 +126,7 @@ def __init__(
tile_scale: float = 1.0,
tile_dim_in_units: float | None = None,
rasterize: bool = False,
return_genes: Mapping[str, list[str] | None] | None = None,
return_annotations: str | list[str] | None = None,
table_name: str | None = None,
transform: Callable[[Any], Any] | None = None,
Expand Down Expand Up @@ -158,13 +163,14 @@ def _validate(
sdata: SpatialData,
regions_to_images: dict[str, str],
regions_to_coordinate_systems: dict[str, str],
return_genes: Mapping[str, list[str] | None] | None,
return_annotations: str | list[str] | None,
table_name: str | None,
) -> None:
"""Validate input parameters."""
self.sdata = sdata
if return_annotations is not None and table_name is None:
raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.")
if (return_annotations is not None) or (return_genes is None) and table_name is None:
raise ValueError("`table_name` must be provided if `return_annotations` or `return_genes` is not `None`.")

# check that the regions specified in the two dicts are the same
assert set(regions_to_images.keys()) == set(
Expand Down Expand Up @@ -264,6 +270,7 @@ def _preprocess(

if table_name is not None:
table_subset = filtered_table[filtered_table.obs[region_key] == region_name]
table_subset.uns["spatialdata_attrs"]["region"] = region_name
circles_sdata = SpatialData.init_from_elements({region_name: circles}, tables=table_subset.copy())
_, table = join_spatialelement_table(
sdata=circles_sdata,
Expand Down Expand Up @@ -302,6 +309,7 @@ def _return_function(
dataset_table: AnnData,
dataset_index: pd.DataFrame,
table_name: str | None,
return_genes: Mapping[str, list[str] | None],
return_annot: str | list[str] | None,
return_array: bool = False,
) -> tuple[Any, Any] | SpatialData:
Expand All @@ -312,7 +320,7 @@ def _return_function(
# where return_table can be a single column or a list of columns
return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
# return tuple of (tile, table)
if np.all([i in dataset_table.obs for i in return_annot]):
if np.all(dataset_table.obs.columns.isin(return_annot)):
return tile, dataset_table.obs[return_annot].iloc[idx].values.reshape(1, -1)
if np.all([i in dataset_table.var_names for i in return_annot]):
if issparse(dataset_table.X):
Expand All @@ -336,6 +344,43 @@ def _return_function(
)
return SpatialData(images={dataset_index.iloc[idx][ImageTilesDataset.IMAGE_KEY]: tile})

@staticmethod
def _return_annotations(
idx: int,
dataset_table: AnnData,
dataset_index: pd.DataFrame,
table_name: str | None,
return_annot: str | list[str],
) -> pd.DataFrame:
# table is always returned as array shape (1, len(return_annot))
# where return_table can be a single column or a list of columns
return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
# return tuple of (tile, table)
if np.all(dataset_table.obs.columns.isin(return_annot)):
return dataset_table.obs[return_annot].iloc[idx].values.reshape(1, -1)
else:
raise KeyError("Missing some valid annotations in the table.")

@staticmethod
def _return_genes(
idx: int,
dataset_table: AnnData,
dataset_index: pd.DataFrame,
table_name: str | None,
return_genes: Mapping[str, list[str] | None],
) -> pd.DataFrame:
k, v = next(iter(return_genes.items()))
layer = dataset_table.X if k == "X" else dataset_table.layers[k].X
if v is None:
if issparse(layer):
return layer[idx].X.A
return layer[idx].X
if isinstance(v, list) and np.all(dataset_table.var_names.isin(v)):
if issparse(layer):
return layer[idx, v].X.A
return layer[idx, v].X
raise KeyError("Missing some valid genes in the table.")

def _get_return(
self,
return_annot: str | list[str] | None,
Expand Down