Skip to content

Commit f74588c

Browse files
set maxit on neuralee (#403)
1 parent ea93d6f commit f74588c

1 file changed

Lines changed: 16 additions & 23 deletions

File tree

  • openproblems/tasks/dimensionality_reduction/methods

openproblems/tasks/dimensionality_reduction/methods/neuralee.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@ def _create_neuralee_dataset(
3434
return dataset
3535

3636

37+
def _neuralee(dataset, d: int = 2, test: bool = False):
38+
from neuralee.embedding import NeuralEE
39+
40+
import torch
41+
42+
NEE = NeuralEE(dataset, d=d, device=torch.device("cpu"))
43+
fine_tune_kwargs = dict(verbose=False)
44+
if test:
45+
fine_tune_kwargs["maxit"] = 10
46+
res = NEE.fine_tune(**fine_tune_kwargs)
47+
48+
return res["X"].detach().cpu().numpy()
49+
50+
3751
@method(
3852
method_name="NeuralEE (CPU) (Default)",
3953
paper_name=" NeuralEE: A GPU-Accelerated Elastic Embedding "
@@ -46,13 +60,6 @@ def _create_neuralee_dataset(
4660
image="openproblems-python-extras",
4761
)
4862
def neuralee_default(adata: AnnData, test: bool = False) -> AnnData:
49-
from neuralee.embedding import NeuralEE
50-
51-
import torch
52-
53-
# Store raw counts for use by metrics
54-
adata.layers["counts"] = adata.X.copy()
55-
5663
# this can fail due to sparseness of data; if so, retry with more genes
5764
# note that this is a deviation from the true default behavior, which fails
5865
# see https://github.com/openproblems-bio/openproblems/issues/375
@@ -72,11 +79,7 @@ def neuralee_default(adata: AnnData, test: bool = False) -> AnnData:
7279
else:
7380
break
7481

75-
NEE = NeuralEE(dataset, d=2, device=torch.device("cpu"))
76-
res = NEE.fine_tune(verbose=False)
77-
78-
adata.obsm["X_emb"] = res["X"].detach().cpu().numpy()
79-
82+
adata.obsm["X_emb"] = _neuralee(dataset, test=test)
8083
return adata
8184

8285

@@ -92,17 +95,7 @@ def neuralee_default(adata: AnnData, test: bool = False) -> AnnData:
9295
image="openproblems-python-extras",
9396
)
9497
def neuralee_logCPM_1kHVG(adata: AnnData, test: bool = False) -> AnnData:
95-
from neuralee.embedding import NeuralEE
96-
97-
import torch
98-
9998
adata = log_cpm_hvg(adata)
100-
10199
dataset = _create_neuralee_dataset(adata, normalize=False, subsample_genes=None)
102-
103-
NEE = NeuralEE(dataset, d=2, device=torch.device("cpu"))
104-
res = NEE.fine_tune(verbose=False)
105-
106-
adata.obsm["X_emb"] = res["X"].detach().cpu().numpy()
107-
100+
adata.obsm["X_emb"] = _neuralee(dataset, test=test)
108101
return adata

0 commit comments

Comments
 (0)