Skip to content

Commit 43df220

Browse files
authored
Merge pull request #262 from theislab/fix/control_split
Fix control splitting
2 parents c574350 + 0ab4e4f commit 43df220

4 files changed

Lines changed: 56 additions & 8 deletions

File tree

src/cellflow/data/_datamanager.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,18 @@ def get_prediction_data(
250250
adata=adata, split_cov_combs=split_cov_combs
251251
)
252252

253+
control_to_perturbation = self._get_control_to_perturbation(
254+
covariate_data=covariate_data,
255+
perturbation_idx_to_covariates=cond_data.perturbation_idx_to_covariates,
256+
split_cov_combs=split_cov_combs,
257+
)
258+
253259
return PredictionData(
254260
cell_data=cell_data,
255261
split_covariates_mask=split_covariates_mask,
256262
split_idx_to_covariates=split_idx_to_covariates,
257263
condition_data=cond_data.condition_data,
258-
control_to_perturbation=cond_data.control_to_perturbation,
264+
control_to_perturbation=control_to_perturbation,
259265
perturbation_idx_to_covariates=cond_data.perturbation_idx_to_covariates,
260266
perturbation_idx_to_id=cond_data.perturbation_idx_to_id,
261267
max_combination_length=cond_data.max_combination_length,
@@ -830,6 +836,35 @@ def _get_split_covariates_mask(
830836
src_counter += 1
831837
return np.asarray(split_covariates_mask), split_idx_to_covariates
832838

839+
def _get_control_to_perturbation(
840+
self,
841+
covariate_data: pd.DataFrame,
842+
perturbation_idx_to_covariates: dict[int, tuple[Any]],
843+
split_cov_combs: np.ndarray | list[list[Any]],
844+
) -> dict[int, np.ndarray]:
845+
control_to_perturbation = {}
846+
847+
if len(self._split_covariates) == 0:
848+
control_to_perturbation[0] = sorted(perturbation_idx_to_covariates.keys())
849+
else:
850+
for control_idx, split_combination in enumerate(split_cov_combs):
851+
filter_dict = dict(zip(self.split_covariates, split_combination, strict=False))
852+
split_cov_mask = (covariate_data[list(filter_dict.keys())] == list(filter_dict.values())).all(axis=1)
853+
# Get subset of covariate_data that matches this split combination
854+
matching_data = covariate_data[split_cov_mask]
855+
# Find perturbation indices that correspond to this split combination
856+
perturbation_indices = []
857+
for pert_idx, pert_covariates in perturbation_idx_to_covariates.items():
858+
for _, row in matching_data.iterrows():
859+
pert_values = tuple(row[self.perturb_covar_keys])
860+
if pert_values == pert_covariates:
861+
perturbation_indices.append(pert_idx)
862+
break
863+
864+
control_to_perturbation[control_idx] = sorted(perturbation_indices)
865+
866+
return control_to_perturbation
867+
833868
@staticmethod
834869
def _verify_perturbation_covariates(data: dict[str, Sequence[str]] | None) -> dict[str, list[str]]:
835870
if data is None:

src/cellflow/model/_cellflow.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,16 @@ def prepare_validation_data(
232232
n_conditions_on_train_end=n_conditions_on_train_end,
233233
)
234234
self._validation_data[name] = val_data
235+
# Batched prediction is not compatible with split covariates
236+
# as all conditions need to be the same size
237+
split_val = len(val_data.control_to_perturbation) > 1
238+
predict_kwargs = predict_kwargs or {}
239+
# Check if predict_kwargs is already provided from an earlier call
240+
if "predict_kwargs" in self._validation_data:
241+
predict_kwargs = self._validation_data["predict_kwargs"].update(predict_kwargs)
242+
# Set batched prediction to False if split_val is True
243+
if split_val:
244+
predict_kwargs["batched"] = False
235245
self._validation_data["predict_kwargs"] = predict_kwargs
236246

237247
def prepare_model(
@@ -494,10 +504,8 @@ def prepare_model(
494504
)
495505
else:
496506
raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}")
497-
if "predict_kwargs" in self.validation_data:
498-
self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=self.validation_data["predict_kwargs"]) # type: ignore[arg-type]
499-
else:
500-
self._trainer = CellFlowTrainer(solver=self.solver) # type: ignore[arg-type]
507+
508+
self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=self.validation_data["predict_kwargs"]) # type: ignore[arg-type]
501509

502510
def train(
503511
self,

src/cellflow/solvers/_otfm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable
2+
from functools import partial
23
from typing import Any
34

45
import diffrax
@@ -263,6 +264,12 @@ def predict(
263264

264265
pred_targets = batched_predict(src_inputs, batched_conditions)
265266
return {k: pred_targets[i] for i, k in enumerate(keys)}
267+
elif isinstance(x, dict):
268+
return jax.tree.map(
269+
partial(self._predict_jit, rng=rng, **kwargs),
270+
x,
271+
condition, # type: ignore[attr-defined]
272+
)
266273
else:
267274
x_pred = self._predict_jit(x, condition, rng, **kwargs)
268275
return np.array(x_pred)

src/cellflow/training/_trainer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ def _validation_step(
6767
condition = batch.get("condition", None)
6868
true_tgt = batch["target"]
6969
valid_source_data[val_key] = src
70-
valid_pred_data[val_key] = self.solver.predict(
71-
src, condition=condition, batched=True, **self.predict_kwargs
72-
)
70+
valid_pred_data[val_key] = self.solver.predict(src, condition=condition, **self.predict_kwargs)
7371
valid_true_data[val_key] = true_tgt
7472

7573
return valid_source_data, valid_true_data, valid_pred_data

0 commit comments

Comments
 (0)