Commit 56888eb

Merge pull request #235 from theislab/feature/predict_batch
Feature and Speedup: `predict_batch`
2 parents d169628 + 9068a93 commit 56888eb

5 files changed

Lines changed: 195 additions & 27 deletions


src/cellflow/model/_cellflow.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -10,6 +10,7 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
 import pandas as pd
 from ott.neural.methods.flows import dynamics
@@ -625,6 +626,8 @@ def predict(
         batch = pred_loader.sample()
         src = batch["source"]
         condition = batch.get("condition", None)
+        # using jax.tree.map to batch the prediction
+        # because PredictionSampler can return a different number of cells for each condition
         out = jax.tree.map(
             functools.partial(self.solver.predict, rng=rng, **kwargs),
             src,
@@ -637,9 +640,10 @@ def predict(
                 f"When saving predictions to `adata`, all control cells must be from the same control \
                     population, but found {len(pred_data.control_to_perturbation)} control populations."
             )
+        out_np = {k: np.array(v) for k, v in out.items()}
        _write_predictions(
            adata=adata,
-            predictions=out,
+            predictions=out_np,
            key_added_prefix=key_added_prefix,
        )
```

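The `jax.tree.map` call is kept here because it handles the ragged case that `batched=True` cannot: `PredictionSampler` may return a different number of cells per condition, so the sources cannot be stacked for `jax.vmap`. A minimal standalone sketch of that behaviour (not CellFlow code; `predict_one` is a hypothetical stand-in for `solver.predict`):

```python
import functools
import jax
import numpy as np

def predict_one(x, condition, rng=None):
    # hypothetical stand-in for a single-condition solver.predict
    return x * 2.0

src = {"cond_a": np.random.rand(10, 5), "cond_b": np.random.rand(7, 5)}  # ragged cell counts
condition = {"cond_a": {"drug": np.random.rand(1, 3)},
             "cond_b": {"drug": np.random.rand(1, 3)}}

# each array leaf of `src` is paired with the matching subtree of `condition`,
# so predict_one runs once per condition, sequentially
out = jax.tree.map(functools.partial(predict_one, rng=None), src, condition)
assert out["cond_a"].shape == (10, 5) and out["cond_b"].shape == (7, 5)
```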
src/cellflow/solvers/_genot.py

Lines changed: 37 additions & 1 deletion
```diff
@@ -234,6 +234,7 @@ def predict(
         condition: dict[str, ArrayLike] | None = None,
         rng: ArrayLike | None = None,
         rng_genot: ArrayLike | None = None,
+        batched: bool = False,
         **kwargs: Any,
     ) -> ArrayLike | tuple[ArrayLike, diffrax.Solution]:
         """Generate the push-forward of ``x`` under condition ``condition``.
@@ -253,13 +254,48 @@
             mean embedding is used.
         rng_genot
             Random generator used to sample from the latent distribution in cell space.
+        batched
+            Whether to use batched prediction. This is only supported if the input has
+            the same number of cells for each condition. For example, this works when using
+            :class:`~cellflow.data.ValidationSampler` to sample the validation data.
         kwargs
             Keyword arguments for :func:`diffrax.diffeqsolve`.

         Returns
         -------
         The push-forward distribution of ``x`` under condition ``condition``.
         """
+        if batched and not x:
+            return {}
+
+        if batched:
+            keys = sorted(x.keys())
+            condition_keys = sorted(set().union(*(condition[k].keys() for k in keys)))
+            _predict_jit = jax.jit(lambda x, condition: self._predict_jit(x, condition, rng, **kwargs))
+            batched_predict = jax.vmap(_predict_jit, in_axes=(0, dict.fromkeys(condition_keys, 0)))
+            # assert that the number of cells is the same for each condition
+            n_cells = x[keys[0]].shape[0]
+            for k in keys:
+                assert x[k].shape[0] == n_cells, "The number of cells must be the same for each condition"
+            src_inputs = jnp.stack([x[k] for k in keys], axis=0)
+            batched_conditions = {}
+            for cond_key in condition_keys:
+                batched_conditions[cond_key] = jnp.stack([condition[k][cond_key] for k in keys])
+
+            pred_targets = batched_predict(src_inputs, batched_conditions)
+            return {k: pred_targets[i] for i, k in enumerate(keys)}
+        else:
+            x_pred = self._predict_jit(x, condition, rng, rng_genot, **kwargs)
+            return np.array(x_pred)
+
+    def _predict_jit(
+        self,
+        x: ArrayLike,
+        condition: dict[str, ArrayLike] | None = None,
+        rng: ArrayLike | None = None,
+        rng_genot: ArrayLike | None = None,
+        **kwargs: Any,
+    ) -> ArrayLike | tuple[ArrayLike, diffrax.Solution]:
         kwargs.setdefault("dt0", None)
         kwargs.setdefault("solver", diffrax.Tsit5())
         kwargs.setdefault("stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5))
@@ -291,7 +327,7 @@ def solve_ode(
             return sol.ys[0]

         x_pred = jax.jit(jax.vmap(solve_ode, in_axes=[0, 0, None, None]))(latent, x, condition, encoder_noise)
-        return np.array(x_pred)
+        return x_pred

     @property
     def is_trained(self) -> bool:
```

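The core of the speedup is the `jax.vmap` over a stacked leading axis, with `in_axes=(0, dict.fromkeys(condition_keys, 0))` telling vmap to map over axis 0 of the sources and of every array in the condition dict, so XLA compiles one program for all conditions instead of dispatching per condition. A self-contained sketch of the same trick (toy `predict_one` in place of the real ODE solve):

```python
import jax
import jax.numpy as jnp

def predict_one(x, condition):
    # toy stand-in for the single-condition prediction
    return x + condition["drug"].sum()

keys = ["drug_1", "drug_2"]
x = {k: jnp.ones((10, 5)) for k in keys}
condition = {k: {"drug": jnp.full((1, 3), float(i))} for i, k in enumerate(keys)}

condition_keys = sorted({ck for k in keys for ck in condition[k]})
# map over axis 0 of the stacked sources and of every condition array
batched_predict = jax.vmap(jax.jit(predict_one), in_axes=(0, dict.fromkeys(condition_keys, 0)))

src_inputs = jnp.stack([x[k] for k in keys])  # [n_conditions, n_cells, dim]
batched_conditions = {ck: jnp.stack([condition[k][ck] for k in keys]) for ck in condition_keys}

pred = batched_predict(src_inputs, batched_conditions)  # [n_conditions, n_cells, dim]
out = {k: pred[i] for i, k in enumerate(keys)}
assert out["drug_1"].shape == (10, 5)
```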
src/cellflow/solvers/_otfm.py

Lines changed: 63 additions & 24 deletions
```diff
@@ -174,31 +174,10 @@ def get_condition_embedding(self, condition: dict[str, ArrayLike], return_as_num
             return np.asarray(cond_mean), np.asarray(cond_logvar)
         return cond_mean, cond_logvar

-    def predict(
+    def _predict_jit(
         self, x: ArrayLike, condition: dict[str, ArrayLike], rng: jax.Array | None = None, **kwargs: Any
     ) -> ArrayLike:
-        """Predict the translated source ``x`` under condition ``condition``.
-
-        This function solves the ODE learnt with
-        the :class:`~cellflow.networks.ConditionalVelocityField`.
-
-        Parameters
-        ----------
-        x
-            Input data of shape [batch_size, ...].
-        condition
-            Condition of the input data of shape [batch_size, ...].
-        rng
-            Random number generator to sample from the latent distribution,
-            only used if ``condition_mode='stochastic'``. If :obj:`None`, the
-            mean embedding is used.
-        kwargs
-            Keyword arguments for :func:`diffrax.diffeqsolve`.
-
-        Returns
-        -------
-        The push-forward distribution of ``x`` under condition ``condition``.
-        """
+        """See :meth:`OTFlowMatching.predict`."""
         kwargs.setdefault("dt0", None)
         kwargs.setdefault("solver", diffrax.Tsit5())
         kwargs.setdefault("stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5))
@@ -226,7 +205,67 @@ def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise:
             return result.ys[0]

         x_pred = jax.jit(jax.vmap(solve_ode, in_axes=[0, None, None]))(x, condition, encoder_noise)
-        return np.array(x_pred)
+        return x_pred
+
+    def predict(
+        self,
+        x: ArrayLike | dict[str, ArrayLike],
+        condition: dict[str, ArrayLike] | dict[str, dict[str, ArrayLike]],
+        rng: jax.Array | None = None,
+        batched: bool = False,
+        **kwargs: Any,
+    ) -> ArrayLike | dict[str, ArrayLike]:
+        """Predict the translated source ``x`` under condition ``condition``.
+
+        This function solves the ODE learnt with
+        the :class:`~cellflow.networks.ConditionalVelocityField`.
+
+        Parameters
+        ----------
+        x
+            A dictionary with keys indicating the name of the condition and values containing
+            the input data as arrays. If ``batched=False``, provide an array of shape [batch_size, ...].
+        condition
+            A dictionary with keys indicating the name of the condition and values containing
+            the condition of the input data as arrays. If ``batched=False``, provide an array of shape
+            [batch_size, ...].
+        rng
+            Random number generator to sample from the latent distribution,
+            only used if ``condition_mode='stochastic'``. If :obj:`None`, the
+            mean embedding is used.
+        batched
+            Whether to use batched prediction. This is only supported if the input has
+            the same number of cells for each condition. For example, this works when using
+            :class:`~cellflow.data.ValidationSampler` to sample the validation data.
+        kwargs
+            Keyword arguments for :func:`diffrax.diffeqsolve`.
+
+        Returns
+        -------
+        The push-forward distribution of ``x`` under condition ``condition``.
+        """
+        if batched and not x:
+            return {}
+
+        if batched:
+            keys = sorted(x.keys())
+            condition_keys = sorted(set().union(*(condition[k].keys() for k in keys)))
+            _predict_jit = jax.jit(lambda x, condition: self._predict_jit(x, condition, rng, **kwargs))
+            batched_predict = jax.vmap(_predict_jit, in_axes=(0, dict.fromkeys(condition_keys, 0)))
+            # assert that the number of cells is the same for each condition
+            n_cells = x[keys[0]].shape[0]
+            for k in keys:
+                assert x[k].shape[0] == n_cells, "The number of cells must be the same for each condition"
+            src_inputs = jnp.stack([x[k] for k in keys], axis=0)
+            batched_conditions = {}
+            for cond_key in condition_keys:
+                batched_conditions[cond_key] = jnp.stack([condition[k][cond_key] for k in keys])
+
+            pred_targets = batched_predict(src_inputs, batched_conditions)
+            return {k: pred_targets[i] for i, k in enumerate(keys)}
+        else:
+            x_pred = self._predict_jit(x, condition, rng, **kwargs)
+            return np.array(x_pred)

     @property
     def is_trained(self) -> bool:
```

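Taken together, the new `predict` surface can be called in two ways. A usage sketch, assuming `solver` is an already trained `OTFlowMatching` instance and every source array has the same number of cells (required for `batched=True`); the condition keys are hypothetical. The test file below constructs such a solver end to end:

```python
import functools
import jax
import numpy as np

src = {("drug_a",): np.random.rand(10, 5), ("drug_b",): np.random.rand(10, 5)}
cond = {("drug_a",): {"drug": np.random.rand(1, 1, 3)},
        ("drug_b",): {"drug": np.random.rand(1, 1, 3)}}

# one vmapped ODE solve across all conditions
preds = solver.predict(src, cond, batched=True)

# equivalent per-condition fallback: one solve per condition, so it is slower
preds_seq = jax.tree.map(functools.partial(solver.predict, batched=False), src, cond)
```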
src/cellflow/training/_trainer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -61,7 +61,7 @@ def _validation_step(
             condition = batch.get("condition", None)
             true_tgt = batch["target"]
             valid_source_data[val_key] = src
-            valid_pred_data[val_key] = jax.tree.map(self.solver.predict, src, condition)
+            valid_pred_data[val_key] = self.solver.predict(src, condition=condition, batched=True)
             valid_true_data[val_key] = true_tgt

         return valid_source_data, valid_true_data, valid_pred_data
```

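This one-line trainer change is valid because, per the docstrings above, `ValidationSampler` yields the same number of cells for every condition, which is exactly what `batched=True` requires. A sketch of that shape contract, with hypothetical sizes:

```python
import numpy as np

# every source array must share the same leading dimension (n_cells);
# otherwise the stack for jax.vmap inside predict(batched=True) would fail
batch = {
    "source": {
        ("cond_a",): np.random.rand(256, 50),
        ("cond_b",): np.random.rand(256, 50),  # same 256 cells as cond_a
    },
    "condition": {
        ("cond_a",): {"drug": np.random.rand(1, 1, 3)},
        ("cond_b",): {"drug": np.random.rand(1, 1, 3)},
    },
}
```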
tests/solver/test_solver.py

Lines changed: 89 additions & 0 deletions
```diff
@@ -0,0 +1,89 @@
+import functools
+import time
+
+import jax
+import numpy as np
+import optax
+import pytest
+from ott.neural.methods.flows import dynamics
+
+import cellflow
+from cellflow.solvers import _genot, _otfm
+from cellflow.utils import match_linear
+
+src = {
+    ("drug_1",): np.random.rand(10, 5),
+    ("drug_2",): np.random.rand(10, 5),
+}
+cond = {
+    ("drug_1",): {"drug": np.random.rand(1, 1, 3)},
+    ("drug_2",): {"drug": np.random.rand(1, 1, 3)},
+}
+vf_rng = jax.random.PRNGKey(111)
+
+
+class TestSolver:
+    @pytest.mark.parametrize("solver_class", ["otfm", "genot"])
+    def test_predict_batch(self, dataloader, solver_class):
+        if solver_class == "otfm":
+            vf_class = cellflow.networks.ConditionalVelocityField
+        else:
+            vf_class = cellflow.networks.GENOTConditionalVelocityField
+
+        opt = optax.adam(1e-3)
+        vf = vf_class(
+            output_dim=5,
+            max_combination_length=2,
+            condition_embedding_dim=12,
+            hidden_dims=(32, 32),
+            decoder_dims=(32, 32),
+        )
+        if solver_class == "otfm":
+            solver = _otfm.OTFlowMatching(
+                vf=vf,
+                match_fn=match_linear,
+                probability_path=dynamics.ConstantNoiseFlow(0.0),
+                optimizer=opt,
+                conditions={"drug": np.random.rand(2, 1, 3)},
+                rng=vf_rng,
+            )
+        else:
+            solver = _genot.GENOT(
+                vf=vf,
+                data_match_fn=match_linear,
+                probability_path=dynamics.ConstantNoiseFlow(0.0),
+                optimizer=opt,
+                source_dim=5,
+                target_dim=5,
+                conditions={"drug": np.random.rand(2, 1, 3)},
+                rng=vf_rng,
+            )
+
+        trainer = cellflow.training.CellFlowTrainer(solver=solver)
+        trainer.train(
+            dataloader=dataloader,
+            num_iterations=2,
+            valid_freq=1,
+        )
+        start_batched = time.time()
+        x_pred_batched = solver.predict(src, cond, batched=True)
+        end_batched = time.time()
+        diff_batched = end_batched - start_batched
+
+        start_nonbatched = time.time()
+        x_pred_nonbatched = jax.tree.map(
+            functools.partial(solver.predict, batched=False),
+            src,
+            cond,  # type: ignore[attr-defined]
+        )
+        end_nonbatched = time.time()
+        diff_nonbatched = end_nonbatched - start_nonbatched
+
+        assert x_pred_batched[("drug_1",)].shape == x_pred_nonbatched[("drug_1",)].shape
+        assert np.allclose(
+            x_pred_batched[("drug_1",)],
+            x_pred_nonbatched[("drug_1",)],
+            atol=1e-1,
+            rtol=1e-2,
+        )
+        assert diff_nonbatched - diff_batched > 2
```
