Skip to content

Commit cef4e5c

Browse files
scottgigante-immunaiLuckyMDgithub-actions[bot]
authored
[label projection] scVI+scArches+XGBoost (#853)
* added scvi-scarches-xgb to label projection * pre-commit * added methods to init and renamed * updated method naming and added imports * pre-commit * made labels categorical * test integer label encoding * pre-commit * share xgboost code with scarches_xgb --------- Co-authored-by: LuckyMD <m.d.luecken@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent dc45472 commit cef4e5c

3 files changed

Lines changed: 82 additions & 7 deletions

File tree

openproblems/tasks/label_projection/methods/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from .scvi_tools import scanvi_hvg
1212
from .scvi_tools import scarches_scanvi_all_genes
1313
from .scvi_tools import scarches_scanvi_hvg
14+
from .scvi_tools import scarches_scanvi_xgb_all_genes
15+
from .scvi_tools import scarches_scanvi_xgb_hvg
1416
from .seurat import seurat
1517
from .xgboost import xgboost_log_cp10k
1618
from .xgboost import xgboost_scran

openproblems/tasks/label_projection/methods/scvi_tools.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from ....tools.decorators import method
22
from ....tools.utils import check_version
3+
from .xgboost import _xgboost
4+
from typing import Optional
35

46
import functools
57

@@ -104,7 +106,15 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
104106
return preds
105107

106108

107-
def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
109+
def _scanvi_scarches(
110+
adata,
111+
test=False,
112+
n_hidden=None,
113+
n_latent=None,
114+
n_layers=None,
115+
prediction_method="scanvi",
116+
):
117+
import numpy as np
108118
import scvi
109119

110120
if test:
@@ -116,11 +126,14 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
116126
n_layers = n_layers or 2
117127
n_hidden = n_hidden or 128
118128

129+
unlabeled_category = "Unknown"
130+
119131
# new obs labels to mask test set
132+
adata.obs["scanvi_labels"] = np.where(
133+
adata.obs["is_train"], adata.obs["labels"], unlabeled_category
134+
)
120135
adata_train = adata[adata.obs["is_train"]].copy()
121-
adata_train.obs["scanvi_labels"] = adata_train.obs["labels"].copy()
122136
adata_test = adata[~adata.obs["is_train"]].copy()
123-
adata_test.obs["scanvi_labels"] = "Unknown"
124137
scvi.model.SCVI.setup_anndata(
125138
adata_train, batch_key="batch", labels_key="scanvi_labels"
126139
)
@@ -145,7 +158,9 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
145158
train_kwargs["limit_train_batches"] = 10
146159
train_kwargs["limit_val_batches"] = 10
147160
scvi_model.train(**train_kwargs)
148-
model = scvi.model.SCANVI.from_scvi_model(scvi_model, unlabeled_category="Unknown")
161+
model = scvi.model.SCANVI.from_scvi_model(
162+
scvi_model, unlabeled_category=unlabeled_category
163+
)
149164
model.train(**train_kwargs)
150165

151166
query_model = scvi.model.SCANVI.load_query_data(adata_test, model)
@@ -156,6 +171,15 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
156171
train_kwargs["limit_val_batches"] = 10
157172
query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs)
158173

174+
if prediction_method == "scanvi":
175+
preds = _pred_scanvi(adata, query_model)
176+
elif prediction_method == "xgboost":
177+
preds = _pred_xgb(adata, query_model, test=test)
178+
179+
return preds
180+
181+
182+
def _pred_scanvi(adata, query_model):
159183
# this is temporary and won't be used
160184
adata.obs["scanvi_labels"] = "Unknown"
161185
preds = query_model.predict(adata)
@@ -164,6 +188,20 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
164188
return preds
165189

166190

191+
# note: could extend test option
192+
def _pred_xgb(
    adata,
    query_model,
    test=False,
    num_round: Optional[int] = None,
):
    # Predict labels by running the shared XGBoost helper on the latent
    # representation produced by `query_model` instead of on `adata.X`.
    #
    # adata: annotated data object; modified in place (gains obsm["X_emb"]).
    # query_model: object exposing `get_latent_representation(adata)`.
    # test / num_round: forwarded unchanged to `_xgboost`.
    #
    # Returns the "labels_pred" column of the object returned by `_xgboost`.
    embedding_key = "X_emb"

    # Store the embedding under a fixed obsm key so `_xgboost` can be pointed
    # at it through its `obsm` argument.
    latent = query_model.get_latent_representation(adata)
    adata.obsm[embedding_key] = latent

    # `tree_method` is not a named parameter of `_xgboost`; it travels through
    # that helper's **kwargs.
    xgb_options = dict(
        test=test,
        obsm=embedding_key,
        num_round=num_round,
        tree_method="hist",
    )
    result = _xgboost(adata, **xgb_options)
    return result.obs["labels_pred"]
203+
204+
167205
@_scanvi_method(method_name="scANVI (All genes)")
168206
def scanvi_all_genes(adata, test=False):
169207
adata.obs["labels_pred"] = _scanvi(adata, test=test)
@@ -194,3 +232,25 @@ def scarches_scanvi_hvg(adata, test=False):
194232
adata.obs["labels_pred"] = _scanvi_scarches(bdata, test=test)
195233
adata.uns["method_code_version"] = check_version("scvi-tools")
196234
return adata
235+
236+
237+
@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)")
def scarches_scanvi_xgb_all_genes(adata, test=False):
    # Run the scArches/scANVI pipeline on every gene in `adata`, using the
    # XGBoost prediction path rather than the default "scanvi" one.
    predictions = _scanvi_scarches(
        adata,
        test=test,
        prediction_method="xgboost",
    )
    adata.obs["labels_pred"] = predictions

    # Record the value reported by check_version for scvi-tools.
    code_version = check_version("scvi-tools")
    adata.uns["method_code_version"] = code_version
    return adata
245+
246+
247+
@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (Seurat v3 2000 HVG)")
def scarches_scanvi_xgb_hvg(adata, test=False):
    # Same pipeline as the all-genes variant, but the model only sees the
    # genes flagged `highly_variable` in the table returned by `_hvg`.
    hvg_mask = _hvg(adata, test).highly_variable

    # Work on a copy of the gene subset; predictions are written back onto
    # the full `adata` that the caller passed in.
    adata_hvg = adata[:, hvg_mask].copy()
    predictions = _scanvi_scarches(
        adata_hvg,
        test=test,
        prediction_method="xgboost",
    )
    adata.obs["labels_pred"] = predictions

    # Record the value reported by check_version for scvi-tools.
    code_version = check_version("scvi-tools")
    adata.uns["method_code_version"] = code_version
    return adata

openproblems/tasks/label_projection/methods/xgboost.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
)
2323

2424

25-
def _xgboost(adata, test: bool = False, num_round: Optional[int] = None):
25+
def _xgboost(
26+
adata,
27+
test: bool = False,
28+
obsm: Optional[str] = None,
29+
num_round: Optional[int] = None,
30+
**kwargs,
31+
):
2632
import xgboost as xgb
2733

2834
if test:
@@ -36,12 +42,19 @@ def _xgboost(adata, test: bool = False, num_round: Optional[int] = None):
3642
adata_train = adata[adata.obs["is_train"]]
3743
adata_test = adata[~adata.obs["is_train"]].copy()
3844

39-
xg_train = xgb.DMatrix(adata_train.X, label=adata_train.obs["labels_int"])
40-
xg_test = xgb.DMatrix(adata_test.X, label=adata_test.obs["labels_int"])
45+
xg_train = xgb.DMatrix(
46+
adata_train.obsm[obsm] if obsm else adata_train.X,
47+
label=adata_train.obs["labels_int"],
48+
)
49+
xg_test = xgb.DMatrix(
50+
adata_test.obsm[obsm] if obsm else adata_test.X,
51+
label=adata_test.obs["labels_int"],
52+
)
4153

4254
param = dict(
4355
objective="multi:softmax",
4456
num_class=len(categories),
57+
**kwargs,
4558
)
4659

4760
watchlist = [(xg_train, "train")]

0 commit comments

Comments
 (0)