Skip to content

Commit 1caa044

Browse files
jitter baseline patch (#838)
* jitter baseline patch
* celltype_random_embedding_w/o_jitter
* Split jitter and no jitter
* import

---------

Co-authored-by: Scott Gigante <84813314+scottgigante-immunai@users.noreply.github.com>
Former-commit-id: d4a00be
1 parent 84278f5 commit 1caa044

3 files changed

Lines changed: 14 additions & 3 deletions

File tree

openproblems/tasks/_batch_integration/_common/methods/baseline.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,14 +35,15 @@ def _randomize_graph(adata, partition=None):
3535
return adata
3636

3737

38-
def _random_embedding(partition):
38+
def _random_embedding(partition, jitter=0.01):
3939
from sklearn.preprocessing import LabelEncoder
4040
from sklearn.preprocessing import OneHotEncoder
4141

4242
embedding = OneHotEncoder().fit_transform(
4343
LabelEncoder().fit_transform(partition)[:, None]
4444
)
45-
embedding = embedding + np.random.uniform(-0.01, 0.01, embedding.shape)
45+
if jitter is not None:
46+
embedding = embedding + np.random.uniform(-1 * jitter, jitter, embedding.shape)
4647
return embedding
4748

4849

openproblems/tasks/_batch_integration/batch_integration_embed/methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
from ...batch_integration_graph.methods.scvi import scvi_full_unscaled
3535
from ...batch_integration_graph.methods.scvi import scvi_hvg_unscaled
3636
from .baseline import celltype_random_embedding
37+
from .baseline import celltype_random_embedding_jitter
3738
from .baseline import no_integration_batch
3839
from .scalex import scalex_full
3940
from .scalex import scalex_hvg

openproblems/tasks/_batch_integration/batch_integration_embed/methods/baseline.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,11 +16,20 @@
1616
)
1717

1818

19+
@_baseline_method(
20+
method_name="Random Embedding by Celltype (with jitter)",
21+
)
22+
def celltype_random_embedding_jitter(adata, test=False):
23+
adata.obsm["X_emb"] = _random_embedding(partition=adata.obs["labels"], jitter=0.01)
24+
adata.uns["method_code_version"] = check_version("openproblems")
25+
return adata
26+
27+
1928
@_baseline_method(
2029
method_name="Random Embedding by Celltype",
2130
)
2231
def celltype_random_embedding(adata, test=False):
23-
adata.obsm["X_emb"] = _random_embedding(partition=adata.obs["labels"])
32+
adata.obsm["X_emb"] = _random_embedding(partition=adata.obs["labels"], jitter=None)
2433
adata.uns["method_code_version"] = check_version("openproblems")
2534
return adata
2635

0 commit comments

Comments
 (0)