Skip to content

Commit 0f3736b

Browse files
Precompute nn ranking for DR datasets (#841)
* precompute nn ranking on the full data
* bugfix
* bugfix 2
* fix `sample_dataset`
* subsample the zebrafish dataset
* remove the NaN check
1 parent 9d16650 commit 0f3736b

8 files changed

Lines changed: 81 additions & 81 deletions

File tree

openproblems/tasks/dimensionality_reduction/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ WARNING: other than most tasks, `adata.X` should contain log CP10k-normalized da
4848
highly on these metrics.
4949

5050
**Datasets** should provide *log CP10k normalized counts* in `adata.X` and store the
51-
original number of genes (i.e., `adata.shape[1]`) in `adata.uns["n_genes"]`.
51+
original number of genes (i.e., `adata.shape[1]`) in `adata.uns["n_genes"]`. Datasets
52+
should also contain the nearest-neighbor ranking matrix, required for the `nn_ranking`
53+
metrics, as computed by `_utils.ranking_matrix(adata.X)` on normalized counts.
5254

5355
**Methods** should assign dimensionally-reduced 2D embedding coordinates to
5456
`adata.obsm['X_emb']`. They *should not* modify the dimensionality of `adata.X` (e.g.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from numba import njit
2+
3+
import numpy as np
4+
5+
6+
@njit(cache=True, fastmath=True)
def _ranking_matrix(D: np.ndarray) -> np.ndarray:  # pragma: no cover
    """Convert a square pairwise-distance matrix into a rank matrix.

    ``R[i, j]`` is the rank of point ``j`` when all points are ordered by
    their distance to point ``i``.

    Parameters
    ----------
    D
        Square ``(m, m)`` matrix of pairwise distances.

    Returns
    -------
    ``(m, m)`` float array of ranks.
    """
    assert D.shape[0] == D.shape[1]
    R = np.zeros(D.shape)
    m = len(R)
    ks = np.arange(m)

    for i in range(m):
        for j in range(m):
            # rank of j w.r.t. i = number of points strictly closer to i
            # than j, plus the number of earlier-indexed points at
            # (numerically) the same distance: the `ks < j` term breaks
            # ties deterministically by original column order, with a
            # 1e-12 tolerance for floating-point equality.
            R[i, j] = np.sum(
                (D[i, :] < D[i, j]) | ((ks < j) & (np.abs(D[i, :] - D[i, j]) <= 1e-12))
            )

    return R
20+
21+
22+
def ranking_matrix(X):
    """Compute the nearest-neighbor ranking matrix for the rows of ``X``.

    All pairwise distances between observations are computed first, then
    converted to per-row neighbor ranks by the jitted ``_ranking_matrix``
    helper.
    """
    from sklearn.metrics import pairwise_distances

    distances = pairwise_distances(X)
    return _ranking_matrix(distances)

openproblems/tasks/dimensionality_reduction/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ...data.sample import load_sample_data
22
from ...tools.decorators import dataset
33
from ...tools.normalize import log_cp10k
4+
from . import _utils
45

56
import numpy as np
67

@@ -31,6 +32,7 @@ def sample_dataset():
3132
adata = load_sample_data()
3233
adata = log_cp10k(adata)
3334
adata.uns["n_genes"] = adata.shape[1]
35+
adata.obsm["X_ranking"] = _utils.ranking_matrix(adata.X)
3436
return adata
3537

3638

openproblems/tasks/dimensionality_reduction/datasets/mouse_blood_olsson_labelled.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ....data.mouse_blood_olsson_labelled import load_olsson_2016_mouse_blood
22
from ....tools.decorators import dataset
33
from ....tools.normalize import log_cp10k
4+
from .._utils import ranking_matrix
45

56

67
@dataset(
@@ -15,4 +16,6 @@
1516
def olsson_2016_mouse_blood(test=False):
1617
adata = load_olsson_2016_mouse_blood(test=test)
1718
adata.uns["n_genes"] = adata.shape[1]
18-
return log_cp10k(adata)
19+
adata = log_cp10k(adata)
20+
adata.obsm["X_ranking"] = ranking_matrix(adata.X)
21+
return adata

openproblems/tasks/dimensionality_reduction/datasets/mouse_hspc_nestorowa2016.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ....data.mouse_hspc_nestorowa2016 import load_mouse_hspc_nestorowa2016
22
from ....tools.decorators import dataset
33
from ....tools.normalize import log_cp10k
4+
from .._utils import ranking_matrix
45

56

67
@dataset(
@@ -15,4 +16,6 @@
1516
def mouse_hspc_nestorowa2016(test=False):
1617
adata = load_mouse_hspc_nestorowa2016(test=test)
1718
adata.uns["n_genes"] = adata.shape[1]
18-
return log_cp10k(adata)
19+
adata = log_cp10k(adata)
20+
adata.obsm["X_ranking"] = ranking_matrix(adata.X)
21+
return adata

openproblems/tasks/dimensionality_reduction/datasets/tenx_5k_pbmc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ....data.tenx import load_tenx_5k_pbmc
22
from ....tools.decorators import dataset
33
from ....tools.normalize import log_cp10k
4+
from .._utils import ranking_matrix
45

56

67
@dataset(
@@ -16,4 +17,6 @@
1617
def tenx_5k_pbmc(test=False):
1718
adata = load_tenx_5k_pbmc(test=test)
1819
adata.uns["n_genes"] = adata.shape[1]
19-
return log_cp10k(adata)
20+
adata = log_cp10k(adata)
21+
adata.obsm["X_ranking"] = ranking_matrix(adata.X)
22+
return adata

openproblems/tasks/dimensionality_reduction/datasets/zebrafish.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ....data.zebrafish import load_zebrafish
22
from ....tools.decorators import dataset
33
from ....tools.normalize import log_cp10k
4+
from .._utils import ranking_matrix
45

56

67
@dataset(
@@ -15,6 +16,13 @@
1516
),
1617
)
1718
def zebrafish_labs(test=False):
19+
import scanpy as sc
20+
1821
adata = load_zebrafish(test=test)
22+
if not test:
23+
# this dataset is too big
24+
sc.pp.subsample(adata, n_obs=25000)
1925
adata.uns["n_genes"] = adata.shape[1]
20-
return log_cp10k(adata)
26+
adata = log_cp10k(adata)
27+
adata.obsm["X_ranking"] = ranking_matrix(adata.X)
28+
return adata

openproblems/tasks/dimensionality_reduction/metrics/nn_ranking.py

Lines changed: 28 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
from ....tools.decorators import metric
18+
from .._utils import ranking_matrix
1819
from anndata import AnnData
1920
from numba import njit
2021
from typing import Tuple
@@ -33,22 +34,6 @@
3334
_K = 30
3435

3536

36-
@njit(cache=True, fastmath=True)
37-
def _ranking_matrix(D: np.ndarray) -> np.ndarray: # pragma: no cover
38-
assert D.shape[0] == D.shape[1]
39-
R = np.zeros(D.shape)
40-
m = len(R)
41-
ks = np.arange(m)
42-
43-
for i in range(m):
44-
for j in range(m):
45-
R[i, j] = np.sum(
46-
(D[i, :] < D[i, j]) | ((ks < j) & (np.abs(D[i, :] - D[i, j]) <= 1e-12))
47-
)
48-
49-
return R
50-
51-
5237
@njit(cache=True, fastmath=True)
5338
def _coranking_matrix(R1: np.ndarray, R2: np.ndarray) -> np.ndarray: # pragma: no cover
5439
assert R1.shape == R2.shape
@@ -63,22 +48,6 @@ def _coranking_matrix(R1: np.ndarray, R2: np.ndarray) -> np.ndarray: # pragma:
6348
return Q
6449

6550

66-
@njit(cache=True, fastmath=True)
67-
def _trustworthiness(Q: np.ndarray, m: int) -> np.ndarray: # pragma: no cover
68-
69-
T = np.zeros(m - 1) # trustworthiness
70-
71-
for k in range(m - 1):
72-
Qs = Q[k:, :k]
73-
# a column vector of weights. weight = rank error = actual_rank - k
74-
W = np.arange(Qs.shape[0]).reshape(-1, 1)
75-
# 1 - normalized hard-k-intrusions. lower-left region.
76-
# weighted by rank error (rank - k)
77-
T[k] = 1 - np.sum(Qs * W) / ((k + 1) * m * (m - 1 - k))
78-
79-
return T
80-
81-
8251
@njit(cache=True, fastmath=True)
8352
def _continuity(Q: np.ndarray, m: int) -> np.ndarray: # pragma: no cover
8453

@@ -133,65 +102,38 @@ def _qnn_auc(QNN: np.ndarray) -> float:
133102
return AUC # type: ignore
134103

135104

136-
def _metrics(
137-
Q: np.ndarray,
138-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, np.ndarray, int, float, float]:
139-
Q = Q[1:, 1:]
140-
m = len(Q)
141-
142-
T = _trustworthiness(Q, m)
143-
C = _continuity(Q, m)
144-
QNN = _qnn(Q, m)
145-
LCMC = _lcmc(QNN, m)
146-
kmax = _kmax(LCMC)
147-
Qlocal = _q_local(QNN, kmax)
148-
Qglobal = _q_global(QNN, kmax, m)
149-
AUC = _qnn_auc(QNN)
150-
151-
return T, C, QNN, AUC, LCMC, kmax, Qlocal, Qglobal
152-
153-
154-
def _high_dim(adata: AnnData) -> np.ndarray:
155-
from scipy.sparse import issparse
156-
157-
high_dim = adata.X
158-
return high_dim.A if issparse(high_dim) else high_dim
159-
105+
def _fit(adata: AnnData) -> Tuple[np.ndarray, int]:
    """Build the trimmed co-ranking matrix for an embedded dataset.

    Uses the precomputed high-dimensional ranking matrix stored in
    ``adata.obsm["X_ranking"]`` and ranks the embedding in
    ``adata.obsm["X_emb"]`` on the fly.

    Returns
    -------
    Q
        Co-ranking matrix with the zeroth row and column removed.
    m
        Size of the trimmed co-ranking matrix, ``len(Q)``.
    """
    Rx = adata.obsm["X_ranking"]
    E = adata.obsm["X_emb"]

    Re = ranking_matrix(E)
    Q = _coranking_matrix(Rx, Re)
    # drop the zeroth row/column (rank 0), as the old `_metrics` helper did
    Q = Q[1:, 1:]
    m = len(Q)

    # NOTE: the previous 7-float annotation was stale — the function now
    # returns the co-ranking matrix and its size, not the metric values.
    return Q, m
177115

178116

179117
@metric("continuity", paper_reference="zhang2021pydrmetrics", maximize=True)
180118
def continuity(adata: AnnData) -> float:
181-
_, C, _, *_ = _fit(_high_dim(adata), adata.obsm["X_emb"])
119+
Q, m = _fit(adata)
120+
C = _continuity(Q, m)[_K]
182121
return float(np.clip(C, 0.0, 1.0)) # in [0, 1]
183122

184123

185124
@metric("co-KNN size", paper_reference="zhang2021pydrmetrics", maximize=True)
186125
def qnn(adata: AnnData) -> float:
187-
_, _, QNN, *_ = _fit(_high_dim(adata), adata.obsm["X_emb"])
126+
Q, m = _fit(adata)
127+
QNN = _qnn(Q, m)[_K]
188128
# normalized in the code to [0, 1]
189129
return float(np.clip(QNN, 0.0, 1.0))
190130

191131

192132
@metric("co-KNN AUC", paper_reference="zhang2021pydrmetrics", maximize=True)
193133
def qnn_auc(adata: AnnData) -> float:
194-
_, _, _, AUC, *_ = _fit(_high_dim(adata), adata.obsm["X_emb"])
134+
Q, m = _fit(adata)
135+
QNN = _qnn(Q, m)
136+
AUC = _qnn_auc(QNN)
195137
return float(np.clip(AUC, 0.5, 1.0)) # in [0.5, 1]
196138

197139

@@ -201,19 +143,29 @@ def qnn_auc(adata: AnnData) -> float:
201143
maximize=True,
202144
)
203145
def lcmc(adata: AnnData) -> float:
204-
*_, LCMC, _, _ = _fit(_high_dim(adata), adata.obsm["X_emb"])
146+
Q, m = _fit(adata)
147+
QNN = _qnn(Q, m)
148+
LCMC = _lcmc(QNN, m)[_K]
205149
return LCMC
206150

207151

208152
@metric("local property", paper_reference="zhang2021pydrmetrics", maximize=True)
209153
def qlocal(adata: AnnData) -> float:
210154
# according to authors, this is usually preferred to
211155
# qglobal, because human are more sensitive to nearer neighbors
212-
*_, Qlocal, _ = _fit(_high_dim(adata), adata.obsm["X_emb"])
156+
Q, m = _fit(adata)
157+
QNN = _qnn(Q, m)
158+
LCMC = _lcmc(QNN, m)
159+
kmax = _kmax(LCMC)
160+
Qlocal = _q_local(QNN, kmax)
213161
return Qlocal
214162

215163

216164
@metric("global property", paper_reference="zhang2021pydrmetrics", maximize=True)
217165
def qglobal(adata: AnnData) -> float:
218-
*_, Qglobal = _fit(_high_dim(adata), adata.obsm["X_emb"])
166+
Q, m = _fit(adata)
167+
QNN = _qnn(Q, m)
168+
LCMC = _lcmc(QNN, m)
169+
kmax = _kmax(LCMC)
170+
Qglobal = _q_global(QNN, kmax, m)
219171
return Qglobal

0 commit comments

Comments
 (0)