diff --git a/.scripts/ci/download_data.py b/.scripts/ci/download_data.py index c22566ff3..45e2837ad 100644 --- a/.scripts/ci/download_data.py +++ b/.scripts/ci/download_data.py @@ -19,10 +19,7 @@ def main(args: argparse.Namespace) -> None: from anndata import AnnData import squidpy as sq - from squidpy.datasets._downloader import get_downloader - - downloader = get_downloader() - registry = downloader.registry + from squidpy.datasets._registry import dataset_names # Visium samples tested in CI visium_samples_to_cache = [ @@ -35,23 +32,23 @@ def main(args: argparse.Namespace) -> None: logger.info("Cache: %s", settings.datasetdir) logger.info( "Would download: %d AnnData, %d images, %d SpatialData, %d Visium", - len(registry.anndata_datasets), - len(registry.image_datasets), - len(registry.spatialdata_datasets), + len(dataset_names("anndata")), + len(dataset_names("image")), + len(dataset_names("spatialdata")), len(visium_samples_to_cache), ) return # Download all datasets - the downloader handles caching - for name in registry.anndata_datasets: + for name in dataset_names("anndata"): obj = getattr(sq.datasets, name)() assert isinstance(obj, AnnData) - for name in registry.image_datasets: + for name in dataset_names("image"): obj = getattr(sq.datasets, name)() assert isinstance(obj, sq.im.ImageContainer) - for name in registry.spatialdata_datasets: + for name in dataset_names("spatialdata"): getattr(sq.datasets, name)() for sample in visium_samples_to_cache: diff --git a/pyproject.toml b/pyproject.toml index 9a67f5d83..43309f414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,12 +60,13 @@ dependencies = [ "omnipath>=1.0.7", "pandas>=2.1", "pillow>=8", - "pooch>=1.6", "pyyaml>=6", "scanpy>=1.9.3", "scikit-image>=0.25", # due to https://github.com/scikit-image/scikit-image/issues/6850 breaks rescale ufunc "scikit-learn>=0.24", + # dataset registry + downloader (pooch) now live in scverse-misc + "scverse-misc[datasets]", "spatialdata>=0.7.2", # 0.7.2 dropped 
xarray-schema (pkg_resources break, #1115) "spatialdata-plot>=0.3.3", "statsmodels>=0.12", diff --git a/src/squidpy/datasets/_datasets.py b/src/squidpy/datasets/_datasets.py index 482441f58..0890349c5 100644 --- a/src/squidpy/datasets/_datasets.py +++ b/src/squidpy/datasets/_datasets.py @@ -10,10 +10,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -from scanpy import settings - from squidpy.datasets._downloader import get_downloader -from squidpy.datasets._registry import DatasetType, get_registry +from squidpy.datasets._registry import dataset_names, get_registry from squidpy.read._utils import PathLike if TYPE_CHECKING: @@ -125,15 +123,12 @@ def visium( # Validate sample_id against known names downloader = get_downloader() - if sample_id not in downloader.registry: + if sample_id not in downloader.datasets: msg = f"Unknown Visium sample: {sample_id}. " - msg += f"Available samples: {downloader.registry.visium_datasets}" + msg += f"Available samples: {dataset_names('visium_10x')}" raise ValueError(msg) - # Use scanpy.settings.datasetdir/visium if base_dir not specified - if base_dir is None: - base_dir = Path(settings.datasetdir) / "visium" - + # downloads land in <base_dir>/visium_10x/<sample_id>/ return downloader.download(sample_id, base_dir, include_hires_tiff=include_hires_tiff) @@ -204,9 +199,9 @@ class _DocParts: return_type=":class:`squidpy.im.ImageContainer`\n The image data.", ) -_DOC_PARTS_BY_TYPE: dict[DatasetType, _DocParts] = { - DatasetType.ANNDATA: _ANNDATA_DOC, - DatasetType.IMAGE: _IMAGE_DOC, +_DOC_PARTS_BY_TYPE: dict[str, _DocParts] = { + "anndata": _ANNDATA_DOC, + "image": _IMAGE_DOC, } @@ -228,9 +223,9 @@ def loader(path: PathLike | None = None, **kwargs: Any): return get_downloader().download(dataset_name, path, **kwargs) loader.__doc__ = f""" - {entry.doc_header} + {entry.metadata.get("doc_header")} - {doc_parts.shape_prefix} ``{entry.shape}``. + {doc_parts.shape_prefix} ``{entry.metadata.get("shape")}``. 
Parameters ---------- diff --git a/src/squidpy/datasets/_downloader.py b/src/squidpy/datasets/_downloader.py index 8492f7905..966a1b1b5 100644 --- a/src/squidpy/datasets/_downloader.py +++ b/src/squidpy/datasets/_downloader.py @@ -1,290 +1,116 @@ -"""Unified dataset downloader using pooch.""" +"""squidpy's dataset loaders, registered against :mod:`scverse_misc.datasets`. + +The generic download/extract/verify machinery and the ``anndata``/``spatialdata`` loaders +live in ``scverse-misc``. squidpy registers loaders for its domain-specific types +(``image`` -> :class:`~squidpy.im.ImageContainer`, ``visium_10x`` -> :func:`squidpy.read.visium`) +and overrides the built-in ``anndata`` loader to emit squidpy's shape warning. +""" from __future__ import annotations -import shutil -import tarfile from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any -import pooch from scanpy import settings +from scverse_misc.datasets import fetch, register_loader from spatialdata._logging import logger as logg -from squidpy.datasets._registry import ( - DatasetEntry, - DatasetRegistry, - DatasetType, - FileEntry, - get_registry, -) +from squidpy.datasets._registry import get_base_url, get_registry if TYPE_CHECKING: + from collections.abc import Callable + from anndata import AnnData - from spatialdata import SpatialData + from scverse_misc.datasets import DatasetEntry - from squidpy.im import ImageContainer +__all__ = ["DatasetDownloader", "download", "get_downloader"] -__all__ = [ - "DatasetDownloader", - "download", - "get_downloader", -] +@register_loader("anndata") +def _load_anndata(entry: DatasetEntry, target: Path, download: Callable[..., Any], **kwargs: Any) -> AnnData: + import anndata -class DatasetDownloader: - """Unified downloader for all squidpy datasets. 
+ adata = anndata.read_h5ad(download(entry.file(suffix=".h5ad")), **kwargs) + shape = entry.metadata.get("shape") + if shape is not None and tuple(adata.shape) != tuple(shape): + logg.warning(f"Expected shape {tuple(shape)}, got {adata.shape}") + return adata - Parameters - ---------- - cache_dir - Directory to cache downloaded files. Defaults to :attr:`scanpy.settings.datasetdir`. - s3_base_url - Base URL for S3 bucket. If None, uses the value from datasets.yaml. - """ - - def __init__( - self, - registry: DatasetRegistry, - cache_dir: Path | str | None = None, - s3_base_url: str | None = None, - ): - self.cache_dir = Path(cache_dir or settings.datasetdir) - self.cache_dir.mkdir(parents=True, exist_ok=True) +@register_loader("image") +def _load_image(entry: DatasetEntry, target: Path, download: Callable[..., Any], **kwargs: Any) -> Any: + from squidpy.im import ImageContainer - self.registry = registry - self._s3_base_url = s3_base_url or self.registry.s3_base_url + img = ImageContainer() + img.add_img( + download(entry.file(suffix=".tiff")), layer="image", library_id=entry.metadata.get("library_id"), **kwargs + ) + return img - def _resolve_path( - self, - path: Path | str | None, - file_entry: FileEntry, - default_subdir: str, - ) -> tuple[Path, str]: - """Resolve target directory and filename from path argument.""" - if path is not None: - path = Path(path) - target_dir = path.parent - suffix = Path(file_entry.name).suffix - target_name = path.name if path.suffix else f"{path.name}{suffix}" - else: - target_dir = self.cache_dir / default_subdir - target_name = file_entry.name - return target_dir, target_name - def _download_file( - self, - file_entry: FileEntry, - target_dir: Path, - target_name: str | None = None, - ) -> Path: - """Download a single file.""" - target_dir.mkdir(parents=True, exist_ok=True) - filename = target_name or file_entry.name - local_path = target_dir / filename +@register_loader("visium_10x") +def _load_visium_10x( + entry: 
DatasetEntry, target: Path, download: Callable[..., Any], *, include_hires_tiff: bool = False, **kwargs: Any +) -> AnnData: + import pooch - if local_path.exists(): - logg.debug(f"Using cached file: {local_path}") - return local_path + from squidpy.read._read import visium as read_visium - urls = file_entry.get_urls(self._s3_base_url) - errors: list[Exception] = [] + sample_dir = target / entry.name + download(entry.file(name="filtered_feature_bc_matrix.h5"), dest=sample_dir) + download(entry.file(name="spatial.tar.gz"), dest=sample_dir, processor=pooch.Untar(extract_dir=".")) - for url in urls: + source_image_path = None + if include_hires_tiff: + image_file = next((f for f in entry.files if f.name.startswith("image.")), None) + if image_file is None: + logg.warning(f"High-res image not available for {entry.name}") + else: try: - logg.info(f"Downloading {filename} from {url}") - downloaded = pooch.retrieve( - url=url, - known_hash=(f"sha256:{file_entry.sha256}" if file_entry.sha256 else None), - fname=filename, - path=str(target_dir), - progressbar=True, - ) - return Path(downloaded) + source_image_path = download(image_file, dest=sample_dir) except (OSError, ValueError, RuntimeError) as e: - errors.append(e) - logg.warning(f"Failed to download from {url}: {e}") - - msg = f"Failed to download {filename}" - raise ExceptionGroup(msg, errors) - - def download(self, name: str, path: Path | str | None = None, **kwargs: Any) -> Any: - """Download a dataset by name and return the appropriate object. - - Parameters - ---------- - name - Dataset name from the registry. - path - Optional custom path for download. - **kwargs - Additional arguments passed to the loader. - - Returns - ------- - Loaded dataset. - """ - if name not in self.registry: - raise ValueError(f"Unknown dataset: {name}. 
Available: {self.registry.all_names}") - - entry = self.registry[name] - loaders = { - DatasetType.ANNDATA: lambda: self._load_anndata(entry, path, **kwargs), - DatasetType.IMAGE: lambda: self._load_image(entry, path, **kwargs), - DatasetType.SPATIALDATA: lambda: self._load_spatialdata(entry, path), - DatasetType.VISIUM_10X: lambda: self._load_visium_10x( - entry, - path, - include_hires_tiff=kwargs.pop("include_hires_tiff", False), - ), - } - - loader = loaders.get(entry.type) - if loader is None: - raise ValueError(f"Unknown dataset type: {entry.type}") - return loader() - - def _load_anndata( - self, - entry: DatasetEntry, - path: Path | str | None = None, - **kwargs: Any, - ) -> AnnData: - """Download and load an AnnData dataset.""" - import anndata - - file_entry = entry.get_file_by_suffix(".h5ad") - if file_entry is None: - raise ValueError(f"Dataset {entry.name} has no .h5ad file") - target_dir, target_name = self._resolve_path(path, file_entry, "anndata") - - local_path = self._download_file(file_entry, target_dir, target_name) - adata = anndata.read_h5ad(local_path, **kwargs) - - if entry.shape is not None and adata.shape != entry.shape: - logg.warning(f"Expected shape {entry.shape}, got {adata.shape}") - - return adata + logg.warning(f"Failed to download high-res image: {e}") - def _load_image( - self, - entry: DatasetEntry, - path: Path | str | None = None, - **kwargs: Any, - ) -> ImageContainer: - """Download and load an image dataset.""" - from squidpy.im import ImageContainer + if source_image_path is not None: + return read_visium(sample_dir, source_image_path=source_image_path) + return read_visium(sample_dir) - file_entry = entry.get_file_by_suffix(".tiff") - if file_entry is None: - raise ValueError(f"Dataset {entry.name} has no .tiff file") - target_dir, target_name = self._resolve_path(path, file_entry, "images") - local_path = self._download_file(file_entry, target_dir, target_name) - - img = ImageContainer() - img.add_img(local_path, 
layer="image", library_id=entry.library_id, **kwargs) - return img - - def _load_spatialdata( - self, - entry: DatasetEntry, - path: Path | str | None = None, - ) -> SpatialData: - """Download and load a SpatialData dataset.""" - import spatialdata as sd - - file_entry = entry.get_file_by_suffix(".zip") - if file_entry is None: - raise ValueError(f"Dataset {entry.name} has no .zip file") - folder = Path(path or self.cache_dir / "spatialdata") - folder.mkdir(parents=True, exist_ok=True) - - zarr_path = folder / f"{entry.name}.zarr" - - if zarr_path.exists(): - logg.info(f"Loading existing dataset from {zarr_path}") - return sd.read_zarr(zarr_path) - - zip_path = self._download_file(file_entry, folder) - logg.info(f"Extracting {zip_path} to {folder}") - shutil.unpack_archive(str(zip_path), folder) - - if not zarr_path.exists(): - raise RuntimeError(f"Expected extracted data at {zarr_path}, but not found") - - return sd.read_zarr(zarr_path) - - def _load_visium_10x( - self, - entry: DatasetEntry, - path: Path | str | None = None, - include_hires_tiff: bool = False, - ) -> AnnData: - """Download and load a 10x Genomics Visium dataset.""" - from squidpy.read._read import visium as read_visium - - base_dir = Path(path or self.cache_dir / "visium") - sample_dir = base_dir / entry.name - sample_dir.mkdir(parents=True, exist_ok=True) - - # Download feature matrix - matrix_file = entry.get_file("filtered_feature_bc_matrix.h5") - if matrix_file is None: - raise ValueError(f"Dataset {entry.name} missing filtered_feature_bc_matrix.h5") - self._download_file(matrix_file, sample_dir) - - # Download and extract spatial data - spatial_file = entry.get_file("spatial.tar.gz") - if spatial_file is None: - raise ValueError(f"Dataset {entry.name} missing spatial.tar.gz") +class DatasetDownloader: + """Thin squidpy wrapper over :func:`scverse_misc.datasets.fetch`. 
- spatial_path = self._download_file(spatial_file, sample_dir) - with tarfile.open(spatial_path) as f: - for member in f: - if not (sample_dir / member.name).exists(): - f.extract(member, sample_dir) + Parameters + ---------- + datasets + Mapping of dataset name to :class:`~scverse_misc.datasets.DatasetEntry`. + base_url + Base URL used to resolve ``s3_key`` files. + cache_dir + Base download directory. Defaults to :attr:`scanpy.settings.datasetdir`. + """ - # Optionally download high-res image - source_image_path = None - if include_hires_tiff: - image_file = entry.get_file_by_name_prefix("image.") - if image_file is None: - logg.warning(f"High-res image not available for {entry.name}") - else: - try: - self._download_file(image_file, sample_dir) - source_image_path = sample_dir / image_file.name - except (OSError, ValueError, RuntimeError) as e: - logg.warning(f"Failed to download high-res image: {e}") + def __init__( + self, datasets: dict[str, DatasetEntry], base_url: str | None, cache_dir: Path | str | None = None + ) -> None: + self.datasets = datasets + self.base_url = base_url + self.cache_dir = Path(cache_dir or settings.datasetdir) - if source_image_path and source_image_path.exists(): - return read_visium(sample_dir, source_image_path=source_image_path) - return read_visium(sample_dir) + def download(self, name: str, path: Path | str | None = None, **kwargs: Any) -> Any: + """Download and load a dataset by name, optionally into a custom ``path``.""" + if name not in self.datasets: + raise ValueError(f"Unknown dataset: {name}. 
Available: {sorted(self.datasets)}") + cache_dir = Path(path) if path is not None else self.cache_dir + return fetch(self.datasets[name], cache_dir, base_url=self.base_url, **kwargs) @lru_cache(maxsize=1) def get_downloader() -> DatasetDownloader: """Get the singleton downloader instance.""" - return DatasetDownloader(registry=get_registry()) + return DatasetDownloader(get_registry(), get_base_url()) def download(name: str, path: Path | str | None = None, **kwargs: Any) -> Any: - """Download a dataset by name. - - Parameters - ---------- - name - Dataset name. - path - Optional custom path. - **kwargs - Additional arguments passed to the loader. - - Returns - ------- - Loaded dataset. - """ + """Download a dataset by name (convenience wrapper around :func:`get_downloader`).""" return get_downloader().download(name, path, **kwargs) diff --git a/src/squidpy/datasets/_registry.py b/src/squidpy/datasets/_registry.py index da5cdb673..9bd217160 100644 --- a/src/squidpy/datasets/_registry.py +++ b/src/squidpy/datasets/_registry.py @@ -1,198 +1,39 @@ -"""Unified dataset registry loaded from YAML configuration.""" +"""squidpy's dataset registry: ``datasets.yaml`` parsed via :func:`scverse_misc.datasets.parse_registry`. + +squidpy-specific fields (``shape``, ``library_id``, ``doc_header``) land in each entry's +``metadata`` mapping automatically. 
+""" from __future__ import annotations import importlib.resources -from dataclasses import dataclass, field -from enum import Enum from functools import lru_cache from typing import TYPE_CHECKING -import yaml +from scverse_misc.datasets import parse_registry if TYPE_CHECKING: - from collections.abc import Iterator - from importlib.resources.abc import Traversable - - from squidpy.read._utils import PathLike - -__all__ = ["DatasetType", "FileEntry", "DatasetEntry", "DatasetRegistry", "get_registry"] - - -def _get_config_traversable() -> Traversable: - """Get the file-like object to datasets.yaml using importlib.resources for robustness.""" - # Using importlib.resources for robust path resolution across different installation methods - # (editable installs, zip imports, etc.) - return importlib.resources.files("squidpy.datasets").joinpath("datasets.yaml") - - -class DatasetType(Enum): - """Types of datasets.""" - - ANNDATA = "anndata" - IMAGE = "image" - SPATIALDATA = "spatialdata" - VISIUM_10X = "visium_10x" - - -@dataclass(frozen=True) -class FileEntry: - """Metadata for a single file within a dataset.""" - - name: str - s3_key: str - sha256: str | None = None - - def get_urls(self, s3_base_url: str) -> list[str]: - """Return list of URLs to try, primary (S3) first, then fallback.""" - urls = [] - if s3_base_url and self.s3_key: - urls.append(f"{s3_base_url.rstrip('/')}/{self.s3_key}") - return urls - - -@dataclass -class DatasetEntry: - """Metadata for a dataset (can have one or multiple files).""" - - name: str - type: DatasetType - files: list[FileEntry] - shape: tuple[int, ...] 
| None = None - doc_header: str | None = None - library_id: str | None = None - - def get_file(self, name: str) -> FileEntry | None: - """Get a specific file by name.""" - for f in self.files: - if f.name == name: - return f - return None - - def get_file_by_suffix(self, suffix: str) -> FileEntry | None: - """Get a file by suffix (e.g., 'filtered_feature_bc_matrix.h5').""" - for f in self.files: - if f.name.endswith(suffix): - return f - return None + from scverse_misc.datasets import DatasetEntry - def get_file_by_name_prefix(self, prefix: str) -> FileEntry | None: - """Get a file by prefix of its name (e.g., 'image.' to find image.tif or image.jpg).""" - for f in self.files: - if f.name.startswith(prefix): - return f - return None +__all__ = ["get_registry", "get_base_url", "dataset_names"] -@dataclass -class DatasetRegistry: - """Central registry for all squidpy datasets.""" - - s3_base_url: str = "" - datasets: dict[str, DatasetEntry] = field(default_factory=dict) - - @classmethod - def from_yaml(cls, config_path: PathLike | None = None) -> DatasetRegistry: - """Load registry from YAML configuration file.""" - # This case should be always true - # only for testing and tinkering config_path should be provided - if config_path is None: - with _get_config_traversable().open() as f: - config = yaml.safe_load(f) - else: - with open(config_path) as f: - config = yaml.safe_load(f) - - registry = cls(s3_base_url=config.get("s3_base_url", "")) - - # Load all datasets - for name, data in config.get("datasets", {}).items(): - # Parse files - files = [] - for file_data in data.get("files", []): - files.append( - FileEntry( - name=file_data["name"], - s3_key=file_data["s3_key"], - sha256=file_data.get("sha256"), - ) - ) - - # Parse shape - shape = None - if "shape" in data: - shape_data = data["shape"] - if isinstance(shape_data, list): - shape = tuple(shape_data) - else: - shape = shape_data - - registry.datasets[name] = DatasetEntry( - name=name, - 
type=DatasetType(data["type"]), - files=files, - shape=shape, - doc_header=data.get("doc_header"), - library_id=data.get("library_id"), - ) - - return registry - - def get(self, name: str) -> DatasetEntry | None: - """Get a dataset by name.""" - return self.datasets.get(name) - - def __getitem__(self, name: str) -> DatasetEntry: - """Get a dataset by name, raises KeyError if not found.""" - if name not in self.datasets: - raise KeyError(f"Unknown dataset: {name}. Available: {list(self.datasets.keys())}") - return self.datasets[name] - - def __contains__(self, name: str) -> bool: - """Check if dataset exists.""" - return name in self.datasets - - def iter_by_type(self, dataset_type: DatasetType) -> Iterator[DatasetEntry]: - """Iterate over datasets of a specific type.""" - for entry in self.datasets.values(): - if entry.type == dataset_type: - yield entry - - @property - def anndata_datasets(self) -> list[str]: - """Return names of all AnnData datasets.""" - return [name for name, entry in self.datasets.items() if entry.type == DatasetType.ANNDATA] - - @property - def image_datasets(self) -> list[str]: - """Return names of all image datasets.""" - return [name for name, entry in self.datasets.items() if entry.type == DatasetType.IMAGE] - - @property - def spatialdata_datasets(self) -> list[str]: - """Return names of all SpatialData datasets.""" - return [name for name, entry in self.datasets.items() if entry.type == DatasetType.SPATIALDATA] +@lru_cache(maxsize=1) +def _parsed() -> tuple[str | None, dict[str, DatasetEntry]]: + path = importlib.resources.files("squidpy.datasets").joinpath("datasets.yaml") + return parse_registry(str(path)) - @property - def visium_10x_datasets(self) -> list[str]: - """Return names of all 10x Genomics Visium datasets.""" - return [name for name, entry in self.datasets.items() if entry.type == DatasetType.VISIUM_10X] - @property - def visium_datasets(self) -> list[str]: - """Return names of all Visium datasets (alias for 
visium_10x_datasets).""" - return self.visium_10x_datasets +def get_registry() -> dict[str, DatasetEntry]: + """Return squidpy's datasets as ``{name: DatasetEntry}`` (cached).""" + return _parsed()[1] - @property - def all_names(self) -> list[str]: - """Return all dataset names.""" - return list(self.datasets.keys()) +def get_base_url() -> str | None: + """Return the registry's base URL.""" + return _parsed()[0] -@lru_cache(maxsize=1) -def get_registry() -> DatasetRegistry: - """Get the singleton dataset registry instance. - Uses lru_cache to ensure a single instance without mutable global state. - """ - return DatasetRegistry.from_yaml() +def dataset_names(dataset_type: str | None = None) -> list[str]: + """Return dataset names, optionally filtered by ``type`` (e.g. ``"visium_10x"``).""" + return [name for name, entry in get_registry().items() if dataset_type is None or entry.type == dataset_type] diff --git a/tests/datasets/test_download_visium_dataset.py b/tests/datasets/test_download_visium_dataset.py index 5f64bbd64..aa4977dde 100644 --- a/tests/datasets/test_download_visium_dataset.py +++ b/tests/datasets/test_download_visium_dataset.py @@ -39,7 +39,7 @@ def test_visium_datasets(sample): # Test that downloading tissue image works sample_dataset = visium(sample, include_hires_tiff=True) - expected_image_path = (Path(settings.datasetdir) / "visium" / sample / "image.tif").resolve() + expected_image_path = (Path(settings.datasetdir) / "visium_10x" / sample / "image.tif").resolve() spatial_metadata = sample_dataset.uns["spatial"][sample]["metadata"] image_path = Path(spatial_metadata["source_image_path"]).resolve() assert image_path == expected_image_path diff --git a/tests/datasets/test_downloader.py b/tests/datasets/test_downloader.py index 7eae27140..4590fbce6 100644 --- a/tests/datasets/test_downloader.py +++ b/tests/datasets/test_downloader.py @@ -1,4 +1,4 @@ -"""Tests for the unified dataset downloader.""" +"""Tests for squidpy's dataset loaders + 
downloader (built on scverse_misc.datasets).""" from __future__ import annotations @@ -6,60 +6,47 @@ import pytest from scanpy import settings +from scverse_misc.datasets import available_loaders from squidpy.datasets._downloader import ( DatasetDownloader, download, get_downloader, ) -from squidpy.datasets._registry import get_registry +from squidpy.datasets._registry import get_base_url, get_registry -class TestDatasetDownloader: - """Tests for DatasetDownloader class.""" - - def test_init_default_cache_dir(self): - downloader = DatasetDownloader(registry=get_registry()) - assert downloader.cache_dir == Path(settings.datasetdir) +class TestLoaderRegistration: + def test_squidpy_loaders_registered(self): + # importing the downloader module registers squidpy's domain loaders; + # anndata + spatialdata are shipped by scverse-misc + assert {"anndata", "image", "spatialdata", "visium_10x"} <= set(available_loaders()) - def test_init_custom_cache_dir(self, tmp_path: Path): - downloader = DatasetDownloader(registry=get_registry(), cache_dir=tmp_path / "custom_cache") - assert downloader.cache_dir == tmp_path / "custom_cache" - assert downloader.cache_dir.exists() - def test_init_custom_s3_url(self): - s3_url = "https://my-bucket.s3.amazonaws.com" - downloader = DatasetDownloader(registry=get_registry(), s3_base_url=s3_url) - assert downloader._s3_base_url == s3_url +class TestDatasetDownloader: + def test_default_cache_dir(self): + dl = DatasetDownloader(get_registry(), get_base_url()) + assert dl.cache_dir == Path(settings.datasetdir) - def test_registry_loaded(self): - downloader = DatasetDownloader(registry=get_registry()) - assert downloader.registry is not None - assert len(downloader.registry.datasets) > 0 + def test_custom_cache_dir(self, tmp_path: Path): + dl = DatasetDownloader(get_registry(), get_base_url(), cache_dir=tmp_path) + assert dl.cache_dir == tmp_path - def test_download_unknown_dataset(self, tmp_path: Path): - downloader = 
DatasetDownloader(registry=get_registry(), cache_dir=tmp_path) + def test_unknown_dataset_raises(self, tmp_path: Path): + dl = DatasetDownloader(get_registry(), get_base_url(), cache_dir=tmp_path) with pytest.raises(ValueError, match="Unknown dataset"): - downloader.download("nonexistent_dataset") + dl.download("nonexistent_dataset") class TestGetDownloader: - """Tests for get_downloader singleton function.""" - def test_returns_downloader(self): - downloader = get_downloader() - assert isinstance(downloader, DatasetDownloader) + assert isinstance(get_downloader(), DatasetDownloader) - def test_returns_same_instance(self): - # lru_cache ensures singleton behavior - downloader1 = get_downloader() - downloader2 = get_downloader() - assert downloader1 is downloader2 + def test_singleton(self): + assert get_downloader() is get_downloader() class TestDownloadFunction: - """Tests for download convenience function.""" - def test_unknown_dataset_raises(self): with pytest.raises(ValueError, match="Unknown dataset"): download("nonexistent_dataset") @@ -70,75 +57,20 @@ class TestDownloaderIntegration: @pytest.mark.timeout(120) @pytest.mark.internet() - def test_download_imc_dataset(self): - """Test downloading a small AnnData dataset.""" + def test_download_anndata(self): from anndata import AnnData - # Use scanpy.settings.datasetdir to match what download_data.py uses - downloader = DatasetDownloader(registry=get_registry(), cache_dir=settings.datasetdir) - adata = downloader.download("imc") - + adata = DatasetDownloader(get_registry(), get_base_url(), cache_dir=settings.datasetdir).download("imc") assert isinstance(adata, AnnData) assert adata.shape == (4668, 34) - @pytest.mark.timeout(120) - @pytest.mark.internet() - def test_download_caches_file(self): - """Test that downloaded files are cached.""" - cache_dir = Path(settings.datasetdir) - downloader = DatasetDownloader(registry=get_registry(), cache_dir=cache_dir) - - # First download - adata1 = 
downloader.download("imc") - - # Check file exists in cache - cache_files = list((cache_dir / "anndata").glob("*.h5ad")) - # At least one file (may have more from other tests) - assert len(cache_files) >= 1 - - # Second download should use cache (no network) - adata2 = downloader.download("imc") - assert adata1.shape == adata2.shape - @pytest.mark.timeout(180) @pytest.mark.internet() def test_download_visium_sample(self): - """Test downloading a Visium sample.""" from anndata import AnnData - downloader = DatasetDownloader(registry=get_registry(), cache_dir=settings.datasetdir) - adata = downloader.download("V1_Mouse_Kidney", include_hires_tiff=False) - + adata = DatasetDownloader(get_registry(), get_base_url(), cache_dir=settings.datasetdir).download( + "V1_Mouse_Kidney", include_hires_tiff=False + ) assert isinstance(adata, AnnData) assert "spatial" in adata.uns - - @pytest.mark.timeout(300) - @pytest.mark.internet() - def test_include_hires_tiff_caching_behavior(self): - """Test include_hires_tiff: cached files persist, return varies. - - On CI, V1_Mouse_Kidney is pre-cached via .scripts/ci/download_data.py - with include_hires_tiff=True, so this tests return behavior. 
- """ - sample_id = "V1_Mouse_Kidney" - cache_dir = Path(settings.datasetdir) - hires_image_path = cache_dir / "visium" / sample_id / "image.tif" - downloader = DatasetDownloader(registry=get_registry(), cache_dir=cache_dir) - - # include_hires_tiff=False: no source_image_path in metadata - adata = downloader.download(sample_id, include_hires_tiff=False) - metadata = adata.uns["spatial"][sample_id].get("metadata", {}) - assert "source_image_path" not in metadata - - # include_hires_tiff=True: source_image_path in metadata, file cached - adata = downloader.download(sample_id, include_hires_tiff=True) - metadata = adata.uns["spatial"][sample_id].get("metadata", {}) - assert "source_image_path" in metadata - assert Path(metadata["source_image_path"]).exists() - assert hires_image_path.exists() - - # include_hires_tiff=False again: cached file persists, not in metadata - adata = downloader.download(sample_id, include_hires_tiff=False) - metadata = adata.uns["spatial"][sample_id].get("metadata", {}) - assert "source_image_path" not in metadata - assert hires_image_path.exists() # file still cached diff --git a/tests/datasets/test_registry.py b/tests/datasets/test_registry.py index fdfb754e3..103a3b579 100644 --- a/tests/datasets/test_registry.py +++ b/tests/datasets/test_registry.py @@ -1,235 +1,61 @@ -"""Tests for the unified dataset registry.""" +"""Tests for squidpy's dataset registry (built on scverse_misc.datasets).""" from __future__ import annotations -import pytest - -from squidpy.datasets._registry import ( - DatasetEntry, - DatasetRegistry, - DatasetType, - FileEntry, - get_registry, -) - - -class TestFileEntry: - """Tests for FileEntry dataclass.""" - - def test_entry_creation(self): - entry = FileEntry( - name="test.h5ad", - s3_key="figshare/test.h5ad", - sha256="abc123", - ) - assert entry.name == "test.h5ad" - assert entry.sha256 == "abc123" - - def test_get_urls_with_s3(self): - entry = FileEntry( - name="test.h5ad", - s3_key="test.h5ad", - ) - urls = 
entry.get_urls("https://s3.example.com") - assert len(urls) == 1 - assert urls[0] == "https://s3.example.com/test.h5ad" - - entry_cells = FileEntry( - name="cells.zip", - s3_key="cells.zip", - ) - urls_cells = entry_cells.get_urls("https://s3.example.com") - assert urls_cells[0] == "https://s3.example.com/cells.zip" - - entry_10x = FileEntry( - name="filtered_feature_bc_matrix.h5", - s3_key="10x_genomics/sample/matrix.h5", - ) - urls_10x = entry_10x.get_urls("https://s3.example.com") - assert urls_10x[0] == "https://s3.example.com/10x_genomics/sample/matrix.h5" - - -class TestDatasetEntry: - """Tests for DatasetEntry dataclass.""" - - def test_single_file_dataset(self): - entry = DatasetEntry( - name="test", - type=DatasetType.ANNDATA, - files=[ - FileEntry( - name="test.h5ad", - s3_key="test.h5ad", - ) - ], - shape=(100, 50), - ) - assert len(entry.files) == 1 - assert entry.shape == (100, 50) - - def test_visium_10x_dataset(self): - entry = DatasetEntry( - name="V1_Test", - type=DatasetType.VISIUM_10X, - files=[ - FileEntry( - name="filtered_feature_bc_matrix.h5", - s3_key="test.h5", - ), - FileEntry(name="spatial.tar.gz", s3_key="test.tar.gz"), - FileEntry(name="image.tif", s3_key="test.tif"), - ], - ) - assert len(entry.files) == 3 - assert entry.type == DatasetType.VISIUM_10X - assert entry.get_file_by_name_prefix("image.") is not None - - def test_get_file(self): - entry = DatasetEntry( - name="test", - type=DatasetType.VISIUM_10X, - files=[ - FileEntry( - name="filtered_feature_bc_matrix.h5", - s3_key="test.h5", - ), - FileEntry(name="spatial.tar.gz", s3_key="test.tar.gz"), - ], - ) - f = entry.get_file("spatial.tar.gz") - assert f is not None - assert f.name == "spatial.tar.gz" - - assert entry.get_file("nonexistent") is None - - -class TestDatasetRegistry: - """Tests for DatasetRegistry class.""" - - def test_from_yaml_loads_config(self): - registry = DatasetRegistry.from_yaml() - assert registry is not None - assert len(registry.datasets) > 0 - - def 
test_anndata_datasets_loaded(self): - registry = DatasetRegistry.from_yaml() - assert "four_i" in registry - assert "imc" in registry - assert "seqfish" in registry - assert "visium_hne_adata" in registry - - def test_anndata_dataset_fields(self): - registry = DatasetRegistry.from_yaml() - four_i = registry["four_i"] - assert four_i.type == DatasetType.ANNDATA - assert four_i.shape == (270876, 43) - assert len(four_i.files) == 1 - - def test_image_datasets_loaded(self): - registry = DatasetRegistry.from_yaml() - assert "visium_hne_image" in registry - assert "visium_hne_image_crop" in registry - assert "visium_fluo_image_crop" in registry - - def test_image_has_library_id(self): - registry = DatasetRegistry.from_yaml() - img = registry["visium_hne_image"] - assert img.library_id == "V1_Adult_Mouse_Brain" - - def test_spatialdata_loaded(self): - registry = DatasetRegistry.from_yaml() - assert "visium_hne_sdata" in registry - sdata = registry["visium_hne_sdata"] - assert sdata.type == DatasetType.SPATIALDATA - - def test_visium_10x_datasets_loaded(self): - registry = DatasetRegistry.from_yaml() - # Check samples from different versions - assert "V1_Adult_Mouse_Brain" in registry - assert "Parent_Visium_Human_Cerebellum" in registry - assert "Visium_FFPE_Mouse_Brain" in registry - - def test_visium_10x_dataset_structure(self): - registry = DatasetRegistry.from_yaml() - v1_sample = registry["V1_Adult_Mouse_Brain"] - assert v1_sample.type == DatasetType.VISIUM_10X - assert len(v1_sample.files) == 3 # matrix, spatial, image - assert v1_sample.get_file_by_name_prefix("image.") is not None - - def test_visium_10x_has_jpg(self): - """Test that Visium_FFPE_Human_Normal_Prostate has jpg image.""" - registry = DatasetRegistry.from_yaml() - sample = registry["Visium_FFPE_Human_Normal_Prostate"] - assert sample.type == DatasetType.VISIUM_10X - # Check it's a jpg - img_file = sample.get_file_by_name_prefix("image.") - assert img_file is not None - assert img_file.name == 
"image.jpg" - - def test_get_dataset(self): - registry = DatasetRegistry.from_yaml() - entry = registry.get("four_i") - assert entry is not None - assert entry.name == "four_i" - - assert registry.get("nonexistent") is None - - def test_getitem(self): - registry = DatasetRegistry.from_yaml() - entry = registry["four_i"] - assert entry.name == "four_i" - - with pytest.raises(KeyError): - _ = registry["nonexistent"] - - def test_contains(self): - registry = DatasetRegistry.from_yaml() - assert "four_i" in registry - assert "nonexistent" not in registry - - def test_iter_by_type(self): - registry = DatasetRegistry.from_yaml() - anndata_entries = list(registry.iter_by_type(DatasetType.ANNDATA)) - assert len(anndata_entries) == 11 # 11 h5ad datasets - - visium_10x_entries = list(registry.iter_by_type(DatasetType.VISIUM_10X)) - assert len(visium_10x_entries) == 35 # 35 Visium samples - - def test_cells_dataset_loaded(self): - registry = DatasetRegistry.from_yaml() - assert "cells" in registry - cells = registry["cells"] - assert cells.type == DatasetType.SPATIALDATA - assert len(cells.files) == 1 - assert cells.files[0].name == "cells.zip" - assert cells.files[0].s3_key == "cells.zip" - - def test_property_lists(self): - registry = DatasetRegistry.from_yaml() - assert len(registry.anndata_datasets) == 11 - assert len(registry.image_datasets) == 3 - assert len(registry.spatialdata_datasets) == 2 - assert len(registry.visium_datasets) == 35 +from scverse_misc.datasets import DatasetEntry - def test_all_names(self): - registry = DatasetRegistry.from_yaml() - names = registry.all_names - assert "four_i" in names - assert "visium_hne_image" in names - assert "V1_Adult_Mouse_Brain" in names - assert "cells" in names - # Total: 11 + 3 + 2 + 35 = 51 - assert len(names) == 51 +from squidpy.datasets._registry import dataset_names, get_base_url, get_registry class TestGetRegistry: - """Tests for get_registry singleton function.""" + def test_returns_mapping_of_entries(self): + reg 
= get_registry() + assert isinstance(reg, dict) + assert all(isinstance(e, DatasetEntry) for e in reg.values()) + + def test_cached(self): + assert get_registry() is get_registry() + + def test_base_url(self): + assert get_base_url() == "https://exampledata.scverse.org/squidpy/" + + +class TestRegistryContents: + def test_anndata_entry(self): + four_i = get_registry()["four_i"] + assert four_i.type == "anndata" + assert four_i.metadata["shape"] == [270876, 43] + assert four_i.file(suffix=".h5ad").s3_key == "four_i.h5ad" + + def test_image_entry_has_library_id(self): + img = get_registry()["visium_hne_image"] + assert img.type == "image" + assert img.metadata["library_id"] == "V1_Adult_Mouse_Brain" - def test_returns_registry(self): - registry = get_registry() - assert isinstance(registry, DatasetRegistry) + def test_spatialdata_entries(self): + reg = get_registry() + assert reg["visium_hne_sdata"].type == "spatialdata" + assert reg["cells"].type == "spatialdata" + assert reg["cells"].file(suffix=".zip").name == "cells.zip" - def test_returns_same_instance(self): - registry1 = get_registry() - registry2 = get_registry() - assert registry1 is registry2 + def test_visium_10x_entry(self): + sample = get_registry()["V1_Adult_Mouse_Brain"] + assert sample.type == "visium_10x" + assert len(sample.files) == 3 + assert any(f.name.startswith("image.") for f in sample.files) + + def test_unknown_dataset(self): + assert "nonexistent" not in get_registry() + + +class TestDatasetNames: + def test_counts_by_type(self): + assert len(dataset_names("anndata")) == 11 + assert len(dataset_names("image")) == 3 + assert len(dataset_names("spatialdata")) == 2 + assert len(dataset_names("visium_10x")) == 35 + + def test_all_names(self): + names = dataset_names() + assert {"four_i", "visium_hne_image", "V1_Adult_Mouse_Brain", "cells"} <= set(names) + assert len(names) == 51