@@ -174,31 +174,10 @@ def get_condition_embedding(self, condition: dict[str, ArrayLike], return_as_num
174174 return np .asarray (cond_mean ), np .asarray (cond_logvar )
175175 return cond_mean , cond_logvar
176176
177- def predict (
177+ def _predict_jit (
178178 self , x : ArrayLike , condition : dict [str , ArrayLike ], rng : jax .Array | None = None , ** kwargs : Any
179179 ) -> ArrayLike :
180- """Predict the translated source ``x`` under condition ``condition``.
181-
182- This function solves the ODE learnt with
183- the :class:`~cellflow.networks.ConditionalVelocityField`.
184-
185- Parameters
186- ----------
187- x
188- Input data of shape [batch_size, ...].
189- condition
190- Condition of the input data of shape [batch_size, ...].
191- rng
192- Random number generator to sample from the latent distribution,
193- only used if ``condition_mode='stochastic'``. If :obj:`None`, the
194- mean embedding is used.
195- kwargs
196- Keyword arguments for :func:`diffrax.diffeqsolve`.
197-
198- Returns
199- -------
200- The push-forward distribution of ``x`` under condition ``condition``.
201- """
180+ """See :meth:`OTFlowMatching.predict`."""
202181 kwargs .setdefault ("dt0" , None )
203182 kwargs .setdefault ("solver" , diffrax .Tsit5 ())
204183 kwargs .setdefault ("stepsize_controller" , diffrax .PIDController (rtol = 1e-5 , atol = 1e-5 ))
@@ -226,7 +205,67 @@ def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise:
226205 return result .ys [0 ]
227206
228207 x_pred = jax .jit (jax .vmap (solve_ode , in_axes = [0 , None , None ]))(x , condition , encoder_noise )
229- return np .array (x_pred )
208+ return x_pred
209+
210+ def predict (
211+ self ,
212+ x : ArrayLike | dict [str , ArrayLike ],
213+ condition : dict [str , ArrayLike ] | dict [str , dict [str , ArrayLike ]],
214+ rng : jax .Array | None = None ,
215+ batched : bool = False ,
216+ ** kwargs : Any ,
217+ ) -> ArrayLike | dict [str , ArrayLike ]:
218+ """Predict the translated source ``x`` under condition ``condition``.
219+
220+ This function solves the ODE learnt with
221+ the :class:`~cellflow.networks.ConditionalVelocityField`.
222+
223+ Parameters
224+ ----------
225+ x
226+ A dictionary with keys indicating the name of the condition and values containing
227+ the input data as arrays. If ``batched=False`` provide an array of shape [batch_size, ...].
228+ condition
229+ A dictionary with keys indicating the name of the condition and values containing
230+ the condition of input data as arrays. If ``batched=False`` provide an array of shape
231+ [batch_size, ...].
232+ rng
233+ Random number generator to sample from the latent distribution,
234+ only used if ``condition_mode='stochastic'``. If :obj:`None`, the
235+ mean embedding is used.
236+ batched
237+ Whether to use batched prediction. This is only supported if the input has
238+ the same number of cells for each condition. For example, this works when using
239+ :class:`~cellflow.data.ValidationSampler` to sample the validation data.
240+ kwargs
241+ Keyword arguments for :func:`diffrax.diffeqsolve`.
242+
243+ Returns
244+ -------
245+ The push-forward distribution of ``x`` under condition ``condition``.
246+ """
247+ if batched and not x :
248+ return {}
249+
250+ if batched :
251+ keys = sorted (x .keys ())
252+ condition_keys = sorted (set ().union (* (condition [k ].keys () for k in keys )))
253+ _predict_jit = jax .jit (lambda x , condition : self ._predict_jit (x , condition , rng , ** kwargs ))
254+ batched_predict = jax .vmap (_predict_jit , in_axes = (0 , dict .fromkeys (condition_keys , 0 )))
255+ # assert that the number of cells is the same for each condition
256+ n_cells = x [keys [0 ]].shape [0 ]
257+ for k in keys :
258+ assert x [k ].shape [0 ] == n_cells , "The number of cells must be the same for each condition"
259+ src_inputs = jnp .stack ([x [k ] for k in keys ], axis = 0 )
260+ batched_conditions = {}
261+ for cond_key in condition_keys :
262+ batched_conditions [cond_key ] = jnp .stack ([condition [k ][cond_key ] for k in keys ])
263+
264+ pred_targets = batched_predict (src_inputs , batched_conditions )
265+ return {k : pred_targets [i ] for i , k in enumerate (keys )}
266+ else :
267+ x_pred = self ._predict_jit (x , condition , rng , ** kwargs )
268+ return np .array (x_pred )
230269
231270 @property
232271 def is_trained (self ) -> bool :
0 commit comments