Skip to content

Commit acd677d

Browse files
Ottjax>0.5 support, optional dependency handling fixes and dataset download fixes (#288)
* ottjax compat initial fix attempt * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * comment on copied file * compat errors and dataset error * scvi is also properly optional * scvi fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * formatting * add new dataset tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix draft on scvi * relax test bounds * remove unused import * OTFM change to genot * formattting * try out new batching inpredict * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * reformat --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent da7e094 commit acd677d

17 files changed

Lines changed: 385 additions & 164 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@ dependencies = [
3737
"dask",
3838
"diffrax",
3939
"flax",
40+
"huggingface-hub",
4041
"orbax",
41-
"ott-jax==0.5",
42-
"pyarrow", # required for dask.dataframe
42+
"ott-jax[neural]>=0.5",
43+
"pyarrow", # required for dask.dataframe
4344
"scanpy",
4445
"scikit-learn==1.5.1",
45-
"scipy<1.16", # see https://github.com/statsmodels/statsmodels/issues/9584
46+
"scipy<1.16", # see https://github.com/statsmodels/statsmodels/issues/9584
4647
"session-info",
4748
]
4849

@@ -168,6 +169,7 @@ addopts = [
168169
]
169170
markers = [
170171
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
172+
"internet: marks tests that require internet access (deselect with '-m \"not internet\"')",
171173
]
172174

173175
[tool.coverage.run]

src/cellflow/_compat.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Compatibility helpers for optional and version-gated dependencies.
2+
3+
``ott-jax>=0.6`` removed ``ott.neural.methods.flows.dynamics`` and the
4+
``ott.neural.networks.velocity_field.VelocityField`` (flax linen) class.
5+
This module re-exports the symbols needed by CellFlow so that both
6+
``ott-jax>=0.5,<0.6`` and ``ott-jax>=0.6`` are supported.
7+
8+
The embedding helpers (``torch``, ``transformers``) are optional and only
9+
required when using gene-embedding functionality.
10+
"""
11+
12+
# ---------------------------------------------------------------------------
13+
# Probability-path dynamics (BaseFlow, ConstantNoiseFlow, BrownianBridge)
14+
#
15+
# For ott-jax <0.6 we import directly from ott. For ott-jax >=0.6 the
16+
# module was removed, so we provide a vendored copy below.
17+
#
18+
# The fallback classes are a verbatim copy of
19+
# ott.neural.methods.flows.dynamics
20+
# from ott-jax 0.5.0 (commit 690b1ae, 2024-12-03).
21+
# ott-jax is licensed under the Apache License 2.0, which permits
22+
# reproduction and distribution of derivative works provided the license
23+
# and copyright notice are retained. See:
24+
# https://github.com/ott-jax/ott/blob/0.5.0/LICENSE
25+
# ---------------------------------------------------------------------------
26+
try:
27+
from ott.neural.methods.flows.dynamics import ( # ott-jax <0.6
28+
BaseFlow,
29+
BrownianBridge,
30+
ConstantNoiseFlow,
31+
)
32+
except ImportError:
33+
# -- Vendored from ott-jax 0.5.0 (Apache-2.0) --------------------------
34+
# Source: src/ott/neural/methods/flows/dynamics.py
35+
# Copyright OTT-JAX contributors
36+
# -----------------------------------------------------------------------
37+
import abc
38+
39+
import jax
40+
import jax.numpy as jnp
41+
42+
class BaseFlow(abc.ABC):
43+
"""Base class for all flows."""
44+
45+
def __init__(self, sigma: float):
46+
self.sigma = sigma
47+
48+
@abc.abstractmethod
49+
def compute_mu_t(self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: ...
50+
51+
@abc.abstractmethod
52+
def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: ...
53+
54+
@abc.abstractmethod
55+
def compute_ut(self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: ...
56+
57+
def compute_xt(self, rng: jax.Array, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray:
58+
"""Sample from the probability path."""
59+
noise = jax.random.normal(rng, shape=x0.shape)
60+
mu_t = self.compute_mu_t(t, x0, x1)
61+
sigma_t = self.compute_sigma_t(t)
62+
return mu_t + sigma_t * noise
63+
64+
class _StraightFlow(BaseFlow, abc.ABC):
65+
def compute_mu_t(self, t, x0, x1):
66+
return (1.0 - t) * x0 + t * x1
67+
68+
def compute_ut(self, t, x, x0, x1):
69+
del t, x
70+
return x1 - x0
71+
72+
class ConstantNoiseFlow(_StraightFlow):
73+
r"""Flow with straight paths and constant noise :math:`\sigma`."""
74+
75+
def compute_sigma_t(self, t):
76+
return jnp.full_like(t, fill_value=self.sigma)
77+
78+
class BrownianBridge(_StraightFlow):
79+
r"""Brownian Bridge with :math:`\sigma_t = \sigma \sqrt{t(1-t)}`."""
80+
81+
def compute_sigma_t(self, t):
82+
return self.sigma * jnp.sqrt(t * (1.0 - t))
83+
84+
def compute_ut(self, t, x, x0, x1):
85+
drift_term = (1 - 2 * t) / (2 * t * (1 - t)) * (x - (t * x1 + (1 - t) * x0))
86+
control_term = x1 - x0
87+
return drift_term + control_term
88+
89+
90+
# ---------------------------------------------------------------------------
91+
# Optional embedding dependencies (torch, transformers)
92+
# ---------------------------------------------------------------------------
93+
_EMBEDDING_ERR_MSG = (
94+
"To use gene embedding, please install `transformers` and `torch` e.g. via `pip install cellflow['embedding']`."
95+
)
96+
97+
try:
98+
import torch # noqa: F401
99+
import transformers # noqa: F401
100+
101+
HAS_EMBEDDING_DEPS = True
102+
except ImportError:
103+
HAS_EMBEDDING_DEPS = False
104+
105+
106+
def check_embedding_deps() -> None:
107+
"""Raise a helpful error if torch/transformers are not installed."""
108+
if not HAS_EMBEDDING_DEPS:
109+
raise ImportError(_EMBEDDING_ERR_MSG)
110+
111+
112+
__all__ = [
113+
"BaseFlow",
114+
"BrownianBridge",
115+
"ConstantNoiseFlow",
116+
"HAS_EMBEDDING_DEPS",
117+
"check_embedding_deps",
118+
]

src/cellflow/datasets.py

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
1-
import os
21
from typing import Any
32

43
import anndata as ad
5-
from scanpy.readwrite import _check_datafile_present_and_download
6-
7-
from cellflow._types import PathLike
4+
from huggingface_hub import hf_hub_download
85

96
__all__ = [
107
"ineurons",
118
"pbmc_cytokines",
129
]
1310

11+
_HF_REPO = "theislab/cellflow-datasets"
12+
1413

1514
def ineurons(
16-
path: PathLike = "~/.cache/cellflow/ineurons.h5ad",
1715
force_download: bool = False,
1816
**kwargs: Any,
1917
) -> ad.AnnData:
@@ -24,28 +22,23 @@ def ineurons(
2422
2523
Parameters
2624
----------
27-
path
28-
Path where to save the file.
2925
force_download
3026
Whether to force-download the data.
3127
kwargs
32-
Keyword arguments for :func:`scanpy.read`.
28+
Keyword arguments for :func:`anndata.read_h5ad`.
3329
3430
Returns
3531
-------
3632
Annotated data object.
3733
"""
38-
return _load_dataset_from_url(
39-
path,
40-
backup_url="https://figshare.com/ndownloader/files/52852961",
41-
expected_shape=(54134, 2000), # TODO: adapt this, and enable check
34+
return _load_dataset(
35+
filename="ineurons.h5ad",
4236
force_download=force_download,
4337
**kwargs,
4438
)
4539

4640

4741
def pbmc_cytokines(
48-
path: PathLike = "~/.cache/cellflow/pbmc_parse.h5ad",
4942
force_download: bool = False,
5043
**kwargs: Any,
5144
) -> ad.AnnData:
@@ -57,28 +50,23 @@ def pbmc_cytokines(
5750
5851
Parameters
5952
----------
60-
path
61-
Path where to save the file.
6253
force_download
6354
Whether to force-download the data.
6455
kwargs
65-
Keyword arguments for :func:`scanpy.read`.
56+
Keyword arguments for :func:`anndata.read_h5ad`.
6657
6758
Returns
6859
-------
6960
Annotated data object.
7061
"""
71-
return _load_dataset_from_url(
72-
path,
73-
backup_url="https://figshare.com/ndownloader/files/53372768",
74-
expected_shape=(54134, 2000), # TODO: adapt this, and enable check
62+
return _load_dataset(
63+
filename="pbmc_parse.h5ad",
7564
force_download=force_download,
7665
**kwargs,
7766
)
7867

7968

8069
def zesta(
81-
path: PathLike = "~/.cache/cellflow/zesta.h5ad",
8270
force_download: bool = False,
8371
**kwargs: Any,
8472
) -> ad.AnnData:
@@ -90,47 +78,33 @@ def zesta(
9078
9179
Parameters
9280
----------
93-
path
94-
Path where to save the file.
9581
force_download
9682
Whether to force-download the data.
9783
kwargs
98-
Keyword arguments for :func:`scanpy.read`.
84+
Keyword arguments for :func:`anndata.read_h5ad`.
9985
10086
Returns
10187
-------
10288
Annotated data object.
10389
"""
104-
return _load_dataset_from_url(
105-
path,
106-
backup_url="https://figshare.com/ndownloader/files/52966469",
107-
expected_shape=(54134, 2000), # TODO: adapt this, and enable check
90+
return _load_dataset(
91+
filename="zesta.h5ad",
10892
force_download=force_download,
10993
**kwargs,
11094
)
11195

11296

113-
def _load_dataset_from_url(
114-
fpath: PathLike,
97+
def _load_dataset(
98+
filename: str,
11599
*,
116-
backup_url: str,
117-
expected_shape: tuple[int, int],
100+
repo_id: str = _HF_REPO,
118101
force_download: bool = False,
119102
**kwargs: Any,
120103
) -> ad.AnnData:
121-
fpath = os.path.expanduser(fpath)
122-
if not fpath.endswith(".h5ad"):
123-
fpath += ".h5ad"
124-
if force_download and os.path.exists(fpath):
125-
os.remove(fpath)
126-
if not _check_datafile_present_and_download(backup_url=backup_url, path=fpath):
127-
raise FileNotFoundError(f"File `{fpath}` not found or download failed.")
128-
data = ad.read_h5ad(filename=fpath, **kwargs)
129-
130-
# TODO: enable the dataset shape check
131-
# if data.shape != expected_shape:
132-
# raise ValueError(
133-
# f"Expected AnnData object to have shape `{expected_shape}`, found `{data.shape}`."
134-
# )
135-
136-
return data
104+
fpath = hf_hub_download(
105+
repo_id=repo_id,
106+
filename=filename,
107+
repo_type="dataset",
108+
force_download=force_download,
109+
)
110+
return ad.read_h5ad(fpath, **kwargs)

src/cellflow/external/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1 @@
1-
try:
2-
from cellflow.external._scvi import CFJaxSCVI
3-
except ImportError as e:
4-
raise ImportError(
5-
"cellflow.external requires more dependencies. Please install via pip install 'cellflow[external]'"
6-
) from e
1+
from cellflow.external._scvi import CFJaxSCVI

0 commit comments

Comments
 (0)