diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index e3d1a0cf..ea25dae5 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -186,11 +186,6 @@ def _combine( if ensure_consistency: adata = self.compute_control_diff(adata, copy=True, all_data=True, target_col=target_col) - rename_back = { - key: key.removesuffix("_control_diff") - for key in [*adata.layers.keys(), *adata.obsm.keys()] - if key.endswith("_control_diff") - } else: warnings.warn( "Combining perturbations without `ensure_consistency=True` is only well-defined " @@ -198,7 +193,21 @@ def _combine( "(otherwise perturbation - perturbation != control).", stacklevel=3, ) - rename_back = {} + + # sc.get.aggregate can leave a `None`-keyed layer behind (the pre-aggregation .X); + # PseudobulkSpace.compute strips it but defensively re-strip here so callers passing + # a hand-built AnnData don't crash inside the string ops below. + layer_keys = [k for k in adata.layers if isinstance(k, str)] + obsm_keys = [k for k in adata.obsm if isinstance(k, str)] + rename_back = ( + { + key: key.removesuffix("_control_diff") + for key in [*layer_keys, *obsm_keys] + if key.endswith("_control_diff") + } + if ensure_consistency + else {} + ) def _running(values: np.ndarray) -> np.ndarray: result = values[adata.obs_names.get_loc(reference_key)].astype(float, copy=True) @@ -207,13 +216,15 @@ def _running(values: np.ndarray) -> np.ndarray: return result new_layers: dict[str, np.ndarray] = {} - for layer_key, mat in adata.layers.items(): + for layer_key in layer_keys: + mat = adata.layers[layer_key] new_layers[rename_back.get(layer_key, layer_key)] = np.concatenate( (np.asarray(mat), _running(np.asarray(mat))[None, :]), axis=0 ) new_obsm: dict[str, np.ndarray] = {} - for embedding_key, mat in adata.obsm.items(): + for embedding_key in obsm_keys: + mat = adata.obsm[embedding_key] new_obsm[rename_back.get(embedding_key, embedding_key)] = np.concatenate( (np.asarray(mat), _running(np.asarray(mat))[None, :]), axis=0 )