From 45db0b095d06a9348d9868d126ec3cb5fd5ec989 Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Tue, 26 May 2026 10:46:35 +0200 Subject: [PATCH] perturbation_space: skip None-keyed layers in _combine sc.get.aggregate can leave a None-keyed layer on the AnnData it returns; PseudobulkSpace.compute strips it, but the _combine refactor in #994 added `key.endswith("_control_diff")` calls that crash on a stray None key (e.g. when callers hand-build the input). Filter to string keys before doing the string ops and the iteration that builds the new AnnData. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../_perturbation_space.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index e3d1a0cf..ea25dae5 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -186,11 +186,6 @@ def _combine( if ensure_consistency: adata = self.compute_control_diff(adata, copy=True, all_data=True, target_col=target_col) - rename_back = { - key: key.removesuffix("_control_diff") - for key in [*adata.layers.keys(), *adata.obsm.keys()] - if key.endswith("_control_diff") - } else: warnings.warn( "Combining perturbations without `ensure_consistency=True` is only well-defined " @@ -198,7 +193,21 @@ def _combine( "(otherwise perturbation - perturbation != control).", stacklevel=3, ) - rename_back = {} + + # sc.get.aggregate can leave a `None`-keyed layer behind (the pre-aggregation .X); + # PseudobulkSpace.compute strips it but defensively re-strip here so callers passing + # a hand-built AnnData don't crash inside the string ops below. + layer_keys = [k for k in adata.layers if isinstance(k, str)] + obsm_keys = [k for k in adata.obsm if isinstance(k, str)] + rename_back = ( + { + key: key.removesuffix("_control_diff") + for key in [*layer_keys, *obsm_keys] + if key.endswith("_control_diff") + } + if ensure_consistency + else {} + ) def _running(values: np.ndarray) -> np.ndarray: result = values[adata.obs_names.get_loc(reference_key)].astype(float, copy=True) @@ -207,13 +216,15 @@ def _running(values: np.ndarray) -> np.ndarray: return result new_layers: dict[str, np.ndarray] = {} - for layer_key, mat in adata.layers.items(): + for layer_key in layer_keys: + mat = adata.layers[layer_key] new_layers[rename_back.get(layer_key, layer_key)] = np.concatenate( (np.asarray(mat), _running(np.asarray(mat))[None, :]), axis=0 ) new_obsm: dict[str, np.ndarray] = {} - for embedding_key, mat in adata.obsm.items(): + for embedding_key in obsm_keys: + mat = adata.obsm[embedding_key] new_obsm[rename_back.get(embedding_key, embedding_key)] = np.concatenate( (np.asarray(mat), _running(np.asarray(mat))[None, :]), axis=0 )