1515"""
1616
1717from ....tools .decorators import metric
18+ from .._utils import ranking_matrix
1819from anndata import AnnData
1920from numba import njit
2021from typing import Tuple
3334_K = 30
3435
3536
@njit(cache=True, fastmath=True)
def _ranking_matrix(D: np.ndarray) -> np.ndarray:  # pragma: no cover
    """Convert a square pairwise-distance matrix into a row-wise ranking matrix.

    ``R[i, j]`` is the rank of point ``j`` among all points ordered by
    distance from point ``i``; ties (within 1e-12) are broken by the
    original column order.
    """
    assert D.shape[0] == D.shape[1]
    n = len(D)
    ranks = np.zeros(D.shape)
    cols = np.arange(n)

    for row in range(n):
        for col in range(n):
            d = D[row, col]
            # points strictly closer to `row` than `col` is ...
            closer = D[row, :] < d
            # ... plus earlier columns at (numerically) the same distance
            tied_earlier = (cols < col) & (np.abs(D[row, :] - d) <= 1e-12)
            ranks[row, col] = np.sum(closer | tied_earlier)

    return ranks
51-
5237@njit (cache = True , fastmath = True )
5338def _coranking_matrix (R1 : np .ndarray , R2 : np .ndarray ) -> np .ndarray : # pragma: no cover
5439 assert R1 .shape == R2 .shape
@@ -63,22 +48,6 @@ def _coranking_matrix(R1: np.ndarray, R2: np.ndarray) -> np.ndarray: # pragma:
6348 return Q
6449
6550
@njit(cache=True, fastmath=True)
def _trustworthiness(Q: np.ndarray, m: int) -> np.ndarray:  # pragma: no cover
    """Trustworthiness curve T(k), k = 1..m-1, from a co-ranking matrix."""
    T = np.zeros(m - 1)

    for k in range(m - 1):
        # lower-left region of Q: hard-k-intrusions (points that are
        # k-neighbours in the embedding but not in the original space)
        region = Q[k:, :k]
        # column vector of weights; weight = rank error = actual_rank - k
        weights = np.arange(region.shape[0]).reshape(-1, 1)
        penalty = np.sum(region * weights)
        # 1 - normalized, rank-error-weighted intrusion count
        T[k] = 1 - penalty / ((k + 1) * m * (m - 1 - k))

    return T
81-
8251@njit (cache = True , fastmath = True )
8352def _continuity (Q : np .ndarray , m : int ) -> np .ndarray : # pragma: no cover
8453
@@ -133,65 +102,38 @@ def _qnn_auc(QNN: np.ndarray) -> float:
133102 return AUC # type: ignore
134103
135104
def _metrics(
    Q: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, np.ndarray, int, float, float]:
    """Evaluate all co-ranking-based quality measures from the matrix ``Q``."""
    # discard rank 0: every point is its own nearest neighbour
    Q = Q[1:, 1:]
    m = len(Q)

    trust = _trustworthiness(Q, m)
    cont = _continuity(Q, m)
    qnn_curve = _qnn(Q, m)
    lcmc_curve = _lcmc(qnn_curve, m)
    k_best = _kmax(lcmc_curve)
    q_loc = _q_local(qnn_curve, k_best)
    q_glob = _q_global(qnn_curve, k_best, m)
    auc = _qnn_auc(qnn_curve)

    return trust, cont, qnn_curve, auc, lcmc_curve, k_best, q_loc, q_glob
153-
154- def _high_dim (adata : AnnData ) -> np .ndarray :
155- from scipy .sparse import issparse
156-
157- high_dim = adata .X
158- return high_dim .A if issparse (high_dim ) else high_dim
159-
def _fit(adata: "AnnData") -> Tuple[np.ndarray, int]:
    """Build the trimmed co-ranking matrix shared by all metrics.

    Combines the precomputed high-dimensional ranking stored in
    ``adata.obsm["X_ranking"]`` with the ranking of the embedding
    ``adata.obsm["X_emb"]``.

    Returns
    -------
    The co-ranking matrix with the rank-0 row/column removed, and its
    size ``m``. (The previous ``Tuple[float, ...]`` annotation was a
    leftover from an older version that returned the final metrics.)
    """
    Rx = adata.obsm["X_ranking"]
    E = adata.obsm["X_emb"]

    Re = ranking_matrix(E)
    Q = _coranking_matrix(Rx, Re)
    # drop rank 0: every point is its own nearest neighbour
    Q = Q[1:, 1:]
    m = len(Q)

    return Q, m
177115
178116
@metric("continuity", paper_reference="zhang2021pydrmetrics", maximize=True)
def continuity(adata: AnnData) -> float:
    """Continuity of the embedding at ``_K`` neighbours, clipped to [0, 1]."""
    Q, m = _fit(adata)
    curve = _continuity(Q, m)
    return float(np.clip(curve[_K], 0.0, 1.0))  # in [0, 1]
183122
184123
@metric("co-KNN size", paper_reference="zhang2021pydrmetrics", maximize=True)
def qnn(adata: AnnData) -> float:
    """Co-k-nearest-neighbour size at ``_K`` neighbours."""
    Q, m = _fit(adata)
    curve = _qnn(Q, m)
    # normalized in the code to [0, 1]
    return float(np.clip(curve[_K], 0.0, 1.0))
190130
191131
@metric("co-KNN AUC", paper_reference="zhang2021pydrmetrics", maximize=True)
def qnn_auc(adata: AnnData) -> float:
    """Area under the co-k-NN curve, clipped to [0.5, 1]."""
    Q, m = _fit(adata)
    auc = _qnn_auc(_qnn(Q, m))
    return float(np.clip(auc, 0.5, 1.0))  # in [0.5, 1]
196138
197139
@@ -201,19 +143,29 @@ def qnn_auc(adata: AnnData) -> float:
201143 maximize = True ,
202144)
def lcmc(adata: AnnData) -> float:
    """Local continuity meta-criterion at ``_K`` neighbours."""
    Q, m = _fit(adata)
    QNN = _qnn(Q, m)
    LCMC = _lcmc(QNN, m)[_K]
    # cast so the annotated `float` is honoured, consistent with the
    # sibling metrics (continuity/qnn/qnn_auc) that return plain floats
    return float(LCMC)
206150
207151
@metric("local property", paper_reference="zhang2021pydrmetrics", maximize=True)
def qlocal(adata: AnnData) -> float:
    """Local quality of the embedding: mean of the co-k-NN curve up to kmax.

    According to the authors, this is usually preferred to ``qglobal``,
    because humans are more sensitive to nearer neighbours.
    """
    Q, m = _fit(adata)
    QNN = _qnn(Q, m)
    LCMC = _lcmc(QNN, m)
    kmax = _kmax(LCMC)
    Qlocal = _q_local(QNN, kmax)
    # cast so the annotated `float` is honoured, consistent with the
    # sibling metrics (continuity/qnn/qnn_auc) that return plain floats
    return float(Qlocal)
214162
215163
@metric("global property", paper_reference="zhang2021pydrmetrics", maximize=True)
def qglobal(adata: AnnData) -> float:
    """Global quality of the embedding: mean of the co-k-NN curve beyond kmax."""
    Q, m = _fit(adata)
    QNN = _qnn(Q, m)
    LCMC = _lcmc(QNN, m)
    kmax = _kmax(LCMC)
    Qglobal = _q_global(QNN, kmax, m)
    # cast so the annotated `float` is honoured, consistent with the
    # sibling metrics (continuity/qnn/qnn_auc) that return plain floats
    return float(Qglobal)
0 commit comments