Skip to content

Commit 84278f5

Browse files
new common baselines and cross import (#825)
* new common baselines and cross import * pre-commit * addressing comments * pre-commit * fix wrong import * wrong import * pre-commit * wrong import 2 * pre-commit --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Former-commit-id: db800b7
1 parent 456424d commit 84278f5

8 files changed

Lines changed: 142 additions & 228 deletions

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .baseline import batch_random_integration
2+
from .baseline import celltype_random_integration
3+
from .baseline import no_integration
4+
from .baseline import random_integration
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from .....tools.decorators import method
2+
from .....tools.utils import check_version
3+
4+
import functools
5+
import numpy as np
6+
7+
8+
def _set_uns(adata):
9+
adata.uns["neighbors"] = adata.uns["uni"]
10+
adata.uns["neighbors"]["connectivities_key"] = "connectivities"
11+
adata.uns["neighbors"]["distances_key"] = "distances"
12+
13+
14+
def _randomize_features(X, partition=None):
15+
X_out = X.copy()
16+
if partition is None:
17+
partition = np.full(X.shape[0], 0)
18+
else:
19+
partition = np.asarray(partition)
20+
for partition_name in np.unique(partition):
21+
partition_idx = np.argwhere(partition == partition_name).flatten()
22+
X_out[partition_idx] = X[np.random.permutation(partition_idx)]
23+
return X_out
24+
25+
26+
def _randomize_graph(adata, partition=None):
27+
distances, connectivities = (
28+
adata.obsp["uni_distances"],
29+
adata.obsp["uni_connectivities"],
30+
)
31+
new_idx = _randomize_features(np.arange(distances.shape[0]), partition=partition)
32+
adata.obsp["distances"] = distances[new_idx][:, new_idx]
33+
adata.obsp["connectivities"] = connectivities[new_idx][:, new_idx]
34+
_set_uns(adata)
35+
return adata
36+
37+
38+
def _random_embedding(partition):
39+
from sklearn.preprocessing import LabelEncoder
40+
from sklearn.preprocessing import OneHotEncoder
41+
42+
embedding = OneHotEncoder().fit_transform(
43+
LabelEncoder().fit_transform(partition)[:, None]
44+
)
45+
embedding = embedding + np.random.uniform(-0.01, 0.01, embedding.shape)
46+
return embedding
47+
48+
49+
_baseline_method = functools.partial(
50+
method,
51+
paper_name="Open Problems for Single Cell Analysis",
52+
paper_reference="openproblems",
53+
paper_year=2022,
54+
code_url="https://github.com/openproblems-bio/openproblems",
55+
is_baseline=True,
56+
)
57+
58+
59+
@_baseline_method(
60+
method_name="No Integration",
61+
)
62+
def no_integration(adata, test=False):
63+
adata.obsp["connectivities"] = adata.obsp["uni_connectivities"]
64+
adata.obsp["distances"] = adata.obsp["uni_distances"]
65+
_set_uns(adata)
66+
adata.obsm["X_emb"] = adata.obsm["X_uni_pca"]
67+
adata.uns["method_code_version"] = check_version("openproblems")
68+
return adata
69+
70+
71+
@_baseline_method(
72+
method_name="Random Integration",
73+
)
74+
def random_integration(adata, test=False):
75+
adata.X = _randomize_features(adata.X)
76+
adata.obsm["X_emb"] = _randomize_features(adata.obsm["X_uni_pca"])
77+
adata = _randomize_graph(adata)
78+
adata.uns["method_code_version"] = check_version("openproblems")
79+
return adata
80+
81+
82+
@_baseline_method(
83+
method_name="Random Integration by Celltype",
84+
paper_name="Random Integration by Celltype (baseline)",
85+
paper_reference="openproblems",
86+
paper_year=2022,
87+
code_url="https://github.com/openproblems-bio/openproblems",
88+
is_baseline=True,
89+
)
90+
def celltype_random_integration(adata, test=False):
91+
adata.obsm["X_emb"] = _randomize_features(
92+
adata.obsm["X_uni_pca"], partition=adata.obs["labels"]
93+
)
94+
adata.X = _randomize_features(adata.X, partition=adata.obs["labels"])
95+
adata = _randomize_graph(
96+
adata,
97+
partition=adata.obs["labels"].to_numpy(),
98+
)
99+
adata.uns["method_code_version"] = check_version("openproblems")
100+
return adata
101+
102+
103+
@_baseline_method(
104+
method_name="Random Integration by Batch",
105+
)
106+
def batch_random_integration(adata, test=False):
107+
adata.obsm["X_emb"] = _randomize_features(
108+
adata.obsm["X_uni_pca"], partition=adata.obs["batch"]
109+
)
110+
adata.X = _randomize_features(adata.X, partition=adata.obs["batch"])
111+
adata = _randomize_graph(
112+
adata,
113+
partition=adata.obs["batch"].to_numpy(),
114+
)
115+
adata.uns["method_code_version"] = check_version("openproblems")
116+
return adata

openproblems/tasks/_batch_integration/batch_integration_embed/methods/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from ..._common.methods.baseline import batch_random_integration
2+
from ..._common.methods.baseline import celltype_random_integration
3+
from ..._common.methods.baseline import no_integration
4+
from ..._common.methods.baseline import random_integration
5+
from ...batch_integration_graph.methods.baseline import celltype_random_graph
16
from ...batch_integration_graph.methods.combat import combat_full_scaled
27
from ...batch_integration_graph.methods.combat import combat_full_unscaled
38
from ...batch_integration_graph.methods.combat import combat_hvg_scaled
@@ -28,11 +33,7 @@
2833
from ...batch_integration_graph.methods.scanvi import scanvi_hvg_unscaled
2934
from ...batch_integration_graph.methods.scvi import scvi_full_unscaled
3035
from ...batch_integration_graph.methods.scvi import scvi_hvg_unscaled
31-
from .baseline import batch_random_integration
3236
from .baseline import celltype_random_embedding
33-
from .baseline import celltype_random_integration
34-
from .baseline import no_integration
3537
from .baseline import no_integration_batch
36-
from .baseline import random_integration
3738
from .scalex import scalex_full
3839
from .scalex import scalex_hvg

openproblems/tasks/_batch_integration/batch_integration_embed/methods/baseline.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from .....tools.decorators import method
22
from .....tools.utils import check_version
3-
from ...batch_integration_graph.methods.baseline import _random_embedding
4-
from ...batch_integration_graph.methods.baseline import _randomize_features
3+
from ..._common.methods.baseline import _random_embedding
54

65
import functools
76
import numpy as np
@@ -17,40 +16,6 @@
1716
)
1817

1918

20-
@_baseline_method(
21-
method_name="No Integration",
22-
)
23-
def no_integration(adata, test=False):
24-
adata.obsm["X_emb"] = adata.obsm["X_uni_pca"]
25-
adata.uns["method_code_version"] = check_version("openproblems")
26-
return adata
27-
28-
29-
@_baseline_method(
30-
method_name="Random Integration",
31-
)
32-
def random_integration(adata, test=False):
33-
adata.obsm["X_emb"] = _randomize_features(adata.obsm["X_uni_pca"])
34-
adata.uns["method_code_version"] = check_version("openproblems")
35-
return adata
36-
37-
38-
@method(
39-
method_name="Random Integration by Celltype",
40-
paper_name="Random Integration by Celltype (baseline)",
41-
paper_reference="openproblems",
42-
paper_year=2022,
43-
code_url="https://github.com/openproblems-bio/openproblems",
44-
is_baseline=True,
45-
)
46-
def celltype_random_integration(adata, test=False):
47-
adata.obsm["X_emb"] = _randomize_features(
48-
adata.obsm["X_uni_pca"], partition=adata.obs["labels"]
49-
)
50-
adata.uns["method_code_version"] = check_version("openproblems")
51-
return adata
52-
53-
5419
@_baseline_method(
5520
method_name="Random Embedding by Celltype",
5621
)
@@ -60,17 +25,6 @@ def celltype_random_embedding(adata, test=False):
6025
return adata
6126

6227

63-
@_baseline_method(
64-
method_name="Random Integration by Batch",
65-
)
66-
def batch_random_integration(adata, test=False):
67-
adata.obsm["X_emb"] = _randomize_features(
68-
adata.obsm["X_uni_pca"], partition=adata.obs["batch"]
69-
)
70-
adata.uns["method_code_version"] = check_version("openproblems")
71-
return adata
72-
73-
7428
@_baseline_method(
7529
method_name="No Integration by Batch",
7630
)

openproblems/tasks/_batch_integration/batch_integration_feature/methods/__init__.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# from ...batch_integration_graph.methods.seuratrpca import seuratrpca_full_scaled
2+
# from ...batch_integration_graph.methods.seuratrpca import seuratrpca_full_unscaled
3+
# from ...batch_integration_graph.methods.seuratrpca import seuratrpca_hvg_scaled
4+
# from ...batch_integration_graph.methods.seuratrpca import seuratrpca_hvg_unscaled
5+
from ..._common.methods.baseline import batch_random_integration
6+
from ..._common.methods.baseline import celltype_random_integration
7+
from ..._common.methods.baseline import no_integration
8+
from ..._common.methods.baseline import random_integration
9+
from ...batch_integration_embed.methods.baseline import celltype_random_embedding
10+
from ...batch_integration_embed.methods.baseline import no_integration_batch
11+
from ...batch_integration_graph.methods.baseline import celltype_random_graph
112
from ...batch_integration_graph.methods.combat import combat_full_scaled
213
from ...batch_integration_graph.methods.combat import combat_full_unscaled
314
from ...batch_integration_graph.methods.combat import combat_hvg_scaled
@@ -28,10 +39,6 @@
2839
from ...batch_integration_graph.methods.scanorama import scanorama_feature_full_unscaled
2940
from ...batch_integration_graph.methods.scanorama import scanorama_feature_hvg_scaled
3041
from ...batch_integration_graph.methods.scanorama import scanorama_feature_hvg_unscaled
31-
from .baseline import batch_random_integration
32-
from .baseline import celltype_random_integration
33-
from .baseline import no_integration
34-
from .baseline import random_integration
3542
from .scalex import scalex_full
3643
from .scalex import scalex_hvg
3744

@@ -44,8 +51,3 @@
4451
# from ...batch_integration_graph.methods.seurat_full import seurat_full_unscaled
4552
# from ...batch_integration_graph.methods.seurat_full import seurat_hvg_scaled
4653
# from ...batch_integration_graph.methods.seurat_full import seurat_hvg_unscaled
47-
48-
# from ...batch_integration_graph.methods.seuratrpca import seuratrpca_full_scaled
49-
# from ...batch_integration_graph.methods.seuratrpca import seuratrpca_full_unscaled
50-
# from ...batch_integration_graph.methods.seuratrpca import seuratrpca_hvg_scaled
51-
# from ...batch_integration_graph.methods.seuratrpca import seuratrpca_hvg_unscaled

openproblems/tasks/_batch_integration/batch_integration_feature/methods/baseline.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

openproblems/tasks/_batch_integration/batch_integration_graph/methods/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from .baseline import batch_random_integration
1+
from ..._common.methods.baseline import batch_random_integration
2+
from ..._common.methods.baseline import celltype_random_integration
3+
from ..._common.methods.baseline import no_integration
4+
from ..._common.methods.baseline import random_integration
25
from .baseline import celltype_random_graph
3-
from .baseline import celltype_random_integration
4-
from .baseline import no_integration
5-
from .baseline import random_integration
66
from .bbknn import bbknn_full_scaled
77
from .bbknn import bbknn_full_unscaled
88
from .bbknn import bbknn_hvg_scaled

0 commit comments

Comments
 (0)