forked from openproblems-bio/openproblems
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathxgboost.py
More file actions
91 lines (74 loc) · 2.69 KB
/
xgboost.py
File metadata and controls
91 lines (74 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from ....tools.decorators import method
from ....tools.normalize import log_cp10k
from ....tools.normalize import log_scran_pooling
from ....tools.utils import check_version
from typing import Optional
import functools
import numpy as np
# Shared one-paragraph summary used by every XGBoost method variant below.
_XGBOOST_SUMMARY = (
    "XGBoost is a gradient boosting decision tree model that learns multiple tree"
    " structures in the form of a series of input features and their values,"
    " leading to a prediction decision, and averages predictions from all its"
    " trees. Here, input features are normalised gene expression values."
)

# Pre-bind the metadata common to all XGBoost variants so each concrete
# method decorator below only supplies its own name and docker image.
_xgboost_method = functools.partial(
    method,
    method_summary=_XGBOOST_SUMMARY,
    paper_name="XGBoost: A Scalable Tree Boosting System",
    paper_reference="chen2016xgboost",
    paper_year=2016,
    code_url="https://xgboost.readthedocs.io/en/stable/index.html",
)
def _xgboost(
    adata,
    test: bool = False,
    obsm: Optional[str] = None,
    num_round: Optional[int] = None,
    **kwargs,
):
    """Train an XGBoost classifier on the training cells and label the test cells.

    Parameters
    ----------
    adata
        AnnData with ``obs["labels"]`` (categorical ground truth) and
        ``obs["is_train"]`` (boolean train/test split).
    test
        If True, use a small number of boosting rounds for quick testing.
    obsm
        Optional key into ``adata.obsm`` to use as features instead of ``adata.X``.
    num_round
        Number of boosting rounds; defaults to 2 in test mode, 5 otherwise.
    **kwargs
        Extra parameters forwarded to ``xgb.train``'s ``param`` dict.

    Returns
    -------
    ``adata`` with ``obs["labels_pred"]`` filled for test cells (NaN for train
    cells) and ``uns["method_code_version"]`` set.
    """
    import xgboost as xgb

    if test:
        num_round = num_round or 2
    else:  # pragma: nocover
        num_round = num_round or 5

    # Encode categorical labels as integer codes for xgboost.
    adata.obs["labels_int"] = adata.obs["labels"].cat.codes
    categories = adata.obs["labels"].cat.categories
    adata_train = adata[adata.obs["is_train"]]
    adata_test = adata[~adata.obs["is_train"]].copy()

    xg_train = xgb.DMatrix(
        adata_train.obsm[obsm] if obsm else adata_train.X,
        label=adata_train.obs["labels_int"],
    )
    xg_test = xgb.DMatrix(
        adata_test.obsm[obsm] if obsm else adata_test.X,
        label=adata_test.obs["labels_int"],
    )
    param = dict(
        objective="multi:softmax",  # predict a single class index per cell
        num_class=len(categories),
        **kwargs,
    )
    watchlist = [(xg_train, "train")]
    xgb_op = xgb.train(param, xg_train, num_boost_round=num_round, evals=watchlist)

    # Predict on test data and map integer predictions back to category labels.
    pred = xgb_op.predict(xg_test).astype(int)
    adata_test.obs["labels_pred"] = categories[pred]
    # Align predictions back onto the full object in one vectorized pass:
    # reindex fills train cells (absent from adata_test) with NaN. This
    # replaces a per-cell `idx in obs_names` scan that was O(n^2).
    adata.obs["labels_pred"] = adata_test.obs["labels_pred"].reindex(
        adata.obs_names
    )
    adata.uns["method_code_version"] = check_version("xgboost")
    return adata
@_xgboost_method(
    method_name="XGBoost (log CP10k)",
    image="openproblems-python-extras",
)
def xgboost_log_cp10k(adata, test: bool = False, num_round: Optional[int] = None):
    """Run XGBoost on log CP10k-normalised expression values."""
    return _xgboost(log_cp10k(adata), test=test, num_round=num_round)
@_xgboost_method(
    method_name="XGBoost (log scran)",
    image="openproblems-r-extras",
)
def xgboost_scran(adata, test: bool = False, num_round: Optional[int] = None):
    """Run XGBoost on log scran-pooling-normalised expression values."""
    return _xgboost(log_scran_pooling(adata), test=test, num_round=num_round)