11from ....tools .decorators import method
22from ....tools .utils import check_version
3+ from .xgboost import _xgboost
4+ from typing import Optional
35
46import functools
57
@@ -104,7 +106,15 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
104106 return preds
105107
106108
107- def _scanvi_scarches (adata , test = False , n_hidden = None , n_latent = None , n_layers = None ):
109+ def _scanvi_scarches (
110+ adata ,
111+ test = False ,
112+ n_hidden = None ,
113+ n_latent = None ,
114+ n_layers = None ,
115+ prediction_method = "scanvi" ,
116+ ):
117+ import numpy as np
108118 import scvi
109119
110120 if test :
@@ -116,11 +126,14 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
116126 n_layers = n_layers or 2
117127 n_hidden = n_hidden or 128
118128
129+ unlabeled_category = "Unknown"
130+
119131 # new obs labels to mask test set
132+ adata .obs ["scanvi_labels" ] = np .where (
133+ adata .obs ["is_train" ], adata .obs ["labels" ], unlabeled_category
134+ )
120135 adata_train = adata [adata .obs ["is_train" ]].copy ()
121- adata_train .obs ["scanvi_labels" ] = adata_train .obs ["labels" ].copy ()
122136 adata_test = adata [~ adata .obs ["is_train" ]].copy ()
123- adata_test .obs ["scanvi_labels" ] = "Unknown"
124137 scvi .model .SCVI .setup_anndata (
125138 adata_train , batch_key = "batch" , labels_key = "scanvi_labels"
126139 )
@@ -145,7 +158,9 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
145158 train_kwargs ["limit_train_batches" ] = 10
146159 train_kwargs ["limit_val_batches" ] = 10
147160 scvi_model .train (** train_kwargs )
148- model = scvi .model .SCANVI .from_scvi_model (scvi_model , unlabeled_category = "Unknown" )
161+ model = scvi .model .SCANVI .from_scvi_model (
162+ scvi_model , unlabeled_category = unlabeled_category
163+ )
149164 model .train (** train_kwargs )
150165
151166 query_model = scvi .model .SCANVI .load_query_data (adata_test , model )
@@ -156,6 +171,15 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
156171 train_kwargs ["limit_val_batches" ] = 10
157172 query_model .train (plan_kwargs = dict (weight_decay = 0.0 ), ** train_kwargs )
158173
174+ if prediction_method == "scanvi" :
175+ preds = _pred_scanvi (adata , query_model )
176+ elif prediction_method == "xgboost" :
177+ preds = _pred_xgb (adata , query_model , test = test )
178+
179+ return preds
180+
181+
182+ def _pred_scanvi (adata , query_model ):
159183 # this is temporary and won't be used
160184 adata .obs ["scanvi_labels" ] = "Unknown"
161185 preds = query_model .predict (adata )
@@ -164,6 +188,20 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
164188 return preds
165189
166190
191+ # note: could extend test option
192+ def _pred_xgb (
193+ adata ,
194+ query_model ,
195+ test = False ,
196+ num_round : Optional [int ] = None ,
197+ ):
198+ adata .obsm ["X_emb" ] = query_model .get_latent_representation (adata )
199+ adata = _xgboost (
200+ adata , test = test , obsm = "X_emb" , num_round = num_round , tree_method = "hist"
201+ )
202+ return adata .obs ["labels_pred" ]
203+
204+
167205@_scanvi_method (method_name = "scANVI (All genes)" )
168206def scanvi_all_genes (adata , test = False ):
169207 adata .obs ["labels_pred" ] = _scanvi (adata , test = test )
@@ -194,3 +232,25 @@ def scarches_scanvi_hvg(adata, test=False):
194232 adata .obs ["labels_pred" ] = _scanvi_scarches (bdata , test = test )
195233 adata .uns ["method_code_version" ] = check_version ("scvi-tools" )
196234 return adata
235+
236+
237+ @_scanvi_scarches_method (method_name = "scArches+scANVI+xgboost (All genes)" )
238+ def scarches_scanvi_xgb_all_genes (adata , test = False ):
239+ adata .obs ["labels_pred" ] = _scanvi_scarches (
240+ adata , test = test , prediction_method = "xgboost"
241+ )
242+
243+ adata .uns ["method_code_version" ] = check_version ("scvi-tools" )
244+ return adata
245+
246+
247+ @_scanvi_scarches_method (method_name = "scArches+scANVI+xgboost (Seurat v3 2000 HVG)" )
248+ def scarches_scanvi_xgb_hvg (adata , test = False ):
249+ hvg_df = _hvg (adata , test )
250+ bdata = adata [:, hvg_df .highly_variable ].copy ()
251+ adata .obs ["labels_pred" ] = _scanvi_scarches (
252+ bdata , test = test , prediction_method = "xgboost"
253+ )
254+
255+ adata .uns ["method_code_version" ] = check_version ("scvi-tools" )
256+ return adata
0 commit comments