@@ -250,12 +250,18 @@ def get_prediction_data(
250250 adata = adata , split_cov_combs = split_cov_combs
251251 )
252252
253+ control_to_perturbation = self ._get_control_to_perturbation (
254+ covariate_data = covariate_data ,
255+ perturbation_idx_to_covariates = cond_data .perturbation_idx_to_covariates ,
256+ split_cov_combs = split_cov_combs ,
257+ )
258+
253259 return PredictionData (
254260 cell_data = cell_data ,
255261 split_covariates_mask = split_covariates_mask ,
256262 split_idx_to_covariates = split_idx_to_covariates ,
257263 condition_data = cond_data .condition_data ,
258- control_to_perturbation = cond_data . control_to_perturbation ,
264+ control_to_perturbation = control_to_perturbation ,
259265 perturbation_idx_to_covariates = cond_data .perturbation_idx_to_covariates ,
260266 perturbation_idx_to_id = cond_data .perturbation_idx_to_id ,
261267 max_combination_length = cond_data .max_combination_length ,
@@ -830,6 +836,35 @@ def _get_split_covariates_mask(
830836 src_counter += 1
831837 return np .asarray (split_covariates_mask ), split_idx_to_covariates
832838
839+ def _get_control_to_perturbation (
840+ self ,
841+ covariate_data : pd .DataFrame ,
842+ perturbation_idx_to_covariates : dict [int , tuple [Any ]],
843+ split_cov_combs : np .ndarray | list [list [Any ]],
844+ ) -> dict [int , np .ndarray ]:
845+ control_to_perturbation = {}
846+
847+ if len (self ._split_covariates ) == 0 :
848+ control_to_perturbation [0 ] = sorted (perturbation_idx_to_covariates .keys ())
849+ else :
850+ for control_idx , split_combination in enumerate (split_cov_combs ):
851+ filter_dict = dict (zip (self .split_covariates , split_combination , strict = False ))
852+ split_cov_mask = (covariate_data [list (filter_dict .keys ())] == list (filter_dict .values ())).all (axis = 1 )
853+ # Get subset of covariate_data that matches this split combination
854+ matching_data = covariate_data [split_cov_mask ]
855+ # Find perturbation indices that correspond to this split combination
856+ perturbation_indices = []
857+ for pert_idx , pert_covariates in perturbation_idx_to_covariates .items ():
858+ for _ , row in matching_data .iterrows ():
859+ pert_values = tuple (row [self .perturb_covar_keys ])
860+ if pert_values == pert_covariates :
861+ perturbation_indices .append (pert_idx )
862+ break
863+
864+ control_to_perturbation [control_idx ] = sorted (perturbation_indices )
865+
866+ return control_to_perturbation
867+
833868 @staticmethod
834869 def _verify_perturbation_covariates (data : dict [str , Sequence [str ]] | None ) -> dict [str , list [str ]]:
835870 if data is None :
0 commit comments