1111from . import tools
1212
1313
14+ def _select_svd_algorithm (
15+ S : Union [sp .spmatrix , np .ndarray ],
16+ svd_algorithm : Optional [int ],
17+ auto_select_algorithm : bool ,
18+ verbose : bool = True
19+ ) -> int :
20+ """
21+ Select the optimal SVD algorithm based on matrix properties.
22+
23+ Parameters
24+ ----------
25+ S : scipy.sparse matrix or numpy.ndarray
26+ Input matrix (features × cells).
27+ svd_algorithm : int or None
28+ User-specified algorithm (if None, automatic selection is performed).
29+ auto_select_algorithm : bool
30+ Whether to enable automatic algorithm selection.
31+ verbose : bool
32+ Whether to print selection rationale.
33+
34+ Returns
35+ -------
36+ int
37+ Selected algorithm code (0=IRLB, 1=Halko, 2=Feng, 3=PRIMME).
38+
39+ Selection Logic
40+ ---------------
41+ 1. If matrix exceeds 32-bit indexing limits (>2^31-1 elements), force PRIMME
42+ 2. For sparse matrices:
43+ - Large & very sparse (>70% sparse, >10M elements): PRIMME
44+ - Otherwise: IRLB
45+ 3. For dense matrices: Halko (fastest)
46+ """
47+ # If algorithm is explicitly specified, use it
48+ if svd_algorithm is not None :
49+ return svd_algorithm
50+
51+ # If auto-select is disabled, use default
52+ if not auto_select_algorithm :
53+ return 0 # IRLB default
54+
55+ # Calculate matrix properties
56+ total_elements = np .prod (S .shape )
57+
58+ # Check for 32-bit overflow (2^31 - 1 = 2,147,483,647)
59+ # Many sparse matrix libraries use 32-bit integers for indexing
60+ MAX_32BIT = 2_147_483_647
61+
62+ if total_elements > MAX_32BIT :
63+ if verbose :
64+ print (f"⚠ Matrix exceeds 32-bit indexing limit ({ total_elements :,} > { MAX_32BIT :,} elements)" )
65+ print (f"→ Auto-selected PRIMME for safe handling of large matrices" )
66+ return 3 # PRIMME
67+
68+ # Determine sparsity and select algorithm
69+ if sp .issparse (S ):
70+ nnz = S .nnz
71+ sparsity = 1.0 - (nnz / total_elements )
72+
73+ # For large, very sparse matrices, PRIMME is most memory-efficient
74+ if sparsity > 0.7 and total_elements > 1_000_000_000 :
75+ if verbose :
76+ print (f"Auto-selected PRIMME for large sparse matrix "
77+ f"({ sparsity :.1%} sparse, { total_elements :,} elements)" )
78+ return 3 # PRIMME
79+ else :
80+ if verbose :
81+ print (f"Auto-selected IRLB for sparse matrix "
82+ f"({ sparsity :.1%} sparse, { total_elements :,} elements)" )
83+ return 0 # IRLB
84+ else :
85+ # Dense matrices: Halko is typically fastest
86+ if verbose :
87+ print (f"Auto-selected Halko for dense matrix ({ total_elements :,} elements)" )
88+ return 1 # Halko
89+
90+
1491def reduce_kernel (
1592 adata : AnnData ,
1693 n_components : int = 30 ,
1794 layer : Optional [str ] = None ,
1895 key_added : str = "action" ,
19- svd_algorithm : int = 0 ,
96+ svd_algorithm : Optional [ int ] = None ,
2097 max_iter : int = 0 ,
2198 seed : int = 0 ,
2299 verbose : bool = True ,
23100 inplace : bool = True ,
101+ auto_select_algorithm : bool = True ,
24102) -> Optional [AnnData ]:
25103 """
26104 Compute a low-rank approximation of the kernel matrix for ACTION decomposition and store the results in AnnData.
@@ -29,14 +107,19 @@ def reduce_kernel(
29107 ----------
30108 adata : AnnData
31109 Annotated data matrix (cells × features).
32- n_components : int, optional (default: 50 )
110+ n_components : int, optional (default: 30 )
33111 Number of singular vectors (components) to compute.
34112 layer : str or None, optional (default: None)
35113 Layer in AnnData to use for computation. If None, uses adata.X.
36114 key_added : str, optional (default: "action")
37115 Key under which to store the results in adata.obsm and related fields.
38- svd_algorithm : int, optional (default: 0)
39- SVD algorithm to use (0=irlb, 1=halko, 2=feng).
116+ svd_algorithm : int or None, optional (default: None)
117+ SVD algorithm to use:
118+ - 0: IRLB (Implicitly Restarted Lanczos Bidiagonalization)
119+ - 1: Halko (Randomized SVD)
120+ - 2: Feng (Feng's randomized algorithm)
121+ - 3: PRIMME (PReconditioned Iterative MultiMethod Eigensolver)
122+ - None: Automatic selection based on matrix properties (recommended)
40123 max_iter : int, optional (default: 0)
41124 Maximum number of iterations for SVD solver (0=auto).
42125 seed : int, optional (default: 0)
@@ -45,6 +128,13 @@ def reduce_kernel(
45128 Whether to print progress messages.
46129 inplace : bool, optional (default: True)
47130 If True, modifies the AnnData object in place. If False, returns a new AnnData object with the results.
131+ auto_select_algorithm : bool, optional (default: True)
132+ If True and svd_algorithm is None, automatically selects the best algorithm based on
133+ matrix properties. The overhead of this check is negligible.
134+ Selection logic:
135+ - For large, sparse matrices (>70% sparse, >10M elements): PRIMME
136+ - For dense matrices: Halko (fastest)
137+ - Otherwise: IRLB (default)
48138
49139 Returns
50140 -------
@@ -62,12 +152,15 @@ def reduce_kernel(
62152 adata.varm[f"{key_added}_A"] : np.ndarray
63153 A matrix from decomposition (features × n_components).
64154 adata.uns[f"{key_added}_params"] : dict
65- Parameters used for reduction (e.g., sigma, n_components).
155+ Parameters used for reduction (e.g., sigma, n_components, svd_algorithm ).
66156 """
67157 if not inplace :
68158 adata = adata .copy ()
69159 S = anndata_to_matrix (adata , layer = layer , transpose = True )
70160
161+ # Select SVD algorithm (automatic selection has negligible overhead: ~1-2 microseconds)
162+ svd_algorithm = _select_svd_algorithm (S , svd_algorithm , auto_select_algorithm , verbose )
163+
71164 if sp .issparse (S ):
72165 result = _core .reduce_kernel_sparse (S , n_components , svd_algorithm , max_iter , seed , verbose )
73166 else :
@@ -79,9 +172,14 @@ def reduce_kernel(
79172 adata .varm [f"{ key_added } _A" ] = result ["A" ]
80173 adata .obsm [f"{ key_added } _B" ] = result ["B" ]
81174
175+ # Map algorithm code to name for better user understanding
176+ algorithm_names = {0 : 'IRLB' , 1 : 'Halko' , 2 : 'Feng' , 3 : 'PRIMME' }
177+
82178 adata .uns [f"{ key_added } _params" ] = {
83179 "sigma" : result ["sigma" ],
84180 "n_components" : n_components ,
181+ "svd_algorithm" : svd_algorithm ,
182+ "svd_algorithm_name" : algorithm_names .get (svd_algorithm , f'Unknown({ svd_algorithm } )' ),
85183 }
86184
87185 if not inplace :
0 commit comments