@@ -34,6 +34,20 @@ def _create_neuralee_dataset(
3434 return dataset
3535
3636
37+ def _neuralee (dataset , d : int = 2 , test : bool = False ):
38+ from neuralee .embedding import NeuralEE
39+
40+ import torch
41+
42+ NEE = NeuralEE (dataset , d = d , device = torch .device ("cpu" ))
43+ fine_tune_kwargs = dict (verbose = False )
44+ if test :
45+ fine_tune_kwargs ["maxit" ] = 10
46+ res = NEE .fine_tune (** fine_tune_kwargs )
47+
48+ return res ["X" ].detach ().cpu ().numpy ()
49+
50+
3751@method (
3852 method_name = "NeuralEE (CPU) (Default)" ,
3953 paper_name = " NeuralEE: A GPU-Accelerated Elastic Embedding "
@@ -46,13 +60,6 @@ def _create_neuralee_dataset(
4660 image = "openproblems-python-extras" ,
4761)
4862def neuralee_default (adata : AnnData , test : bool = False ) -> AnnData :
49- from neuralee .embedding import NeuralEE
50-
51- import torch
52-
53- # Store raw counts for use by metrics
54- adata .layers ["counts" ] = adata .X .copy ()
55-
5663 # this can fail due to sparseness of data; if so, retry with more genes
5764 # note that this is a deviation from the true default behavior, which fails
5865 # see https://github.com/openproblems-bio/openproblems/issues/375
@@ -72,11 +79,7 @@ def neuralee_default(adata: AnnData, test: bool = False) -> AnnData:
7279 else :
7380 break
7481
75- NEE = NeuralEE (dataset , d = 2 , device = torch .device ("cpu" ))
76- res = NEE .fine_tune (verbose = False )
77-
78- adata .obsm ["X_emb" ] = res ["X" ].detach ().cpu ().numpy ()
79-
82+ adata .obsm ["X_emb" ] = _neuralee (dataset , test = test )
8083 return adata
8184
8285
@@ -92,17 +95,7 @@ def neuralee_default(adata: AnnData, test: bool = False) -> AnnData:
9295 image = "openproblems-python-extras" ,
9396)
9497def neuralee_logCPM_1kHVG (adata : AnnData , test : bool = False ) -> AnnData :
95- from neuralee .embedding import NeuralEE
96-
97- import torch
98-
9998 adata = log_cpm_hvg (adata )
100-
10199 dataset = _create_neuralee_dataset (adata , normalize = False , subsample_genes = None )
102-
103- NEE = NeuralEE (dataset , d = 2 , device = torch .device ("cpu" ))
104- res = NEE .fine_tune (verbose = False )
105-
106- adata .obsm ["X_emb" ] = res ["X" ].detach ().cpu ().numpy ()
107-
100+ adata .obsm ["X_emb" ] = _neuralee (dataset , test = test )
108101 return adata
0 commit comments