|
38 | 38 | HAS_TORCHSORT = False |
39 | 39 | spearman_loss_torchsort = None |
40 | 40 |
|
| 41 | +# Fallback: diffsort (differentiable sorting networks) when torchsort unavailable |
| 42 | +try: |
| 43 | + from tiny_icf.loss import _try_import_diffsort, spearman_loss_diffsort |
| 44 | + |
| 45 | + HAS_DIFFSORT = _try_import_diffsort() is not None |
| 46 | +except Exception: |
| 47 | + HAS_DIFFSORT = False |
| 48 | + spearman_loss_diffsort = None |
| 49 | + |
| 50 | + |
| 51 | +def get_spearman_backend(method: str = "auto") -> str: |
| 52 | + """Return which backend will be used for Spearman loss (for logging).""" |
| 53 | + if method in ("torchsort", "auto") and HAS_TORCHSORT and spearman_loss_torchsort is not None: |
| 54 | + return "torchsort" |
| 55 | + if method in ("diffsort", "auto") and HAS_DIFFSORT and spearman_loss_diffsort is not None: |
| 56 | + return "diffsort" |
| 57 | + if HAS_RANK_RELAX: |
| 58 | + return "rank_relax" |
| 59 | + return "built-in (soft_rank)" |
| 60 | + |
41 | 61 |
|
42 | 62 | def _to_list(tensor: torch.Tensor) -> List[float]: |
43 | 63 | """Convert tensor to Python list for rank-relax.""" |
@@ -177,6 +197,17 @@ def spearman_loss_tensor( |
177 | 197 | predictions, targets, regularization_strength=regularization_strength |
178 | 198 | ) |
179 | 199 |
|
| 200 | + # Differentiable sorting fallback: diffsort when method is auto or diffsort |
| 201 | + use_diffsort = ( |
| 202 | + (method in ("diffsort", "auto")) |
| 203 | + and HAS_DIFFSORT |
| 204 | + and spearman_loss_diffsort is not None |
| 205 | + and predictions.numel() >= 2 |
| 206 | + ) |
| 207 | + if use_diffsort: |
| 208 | + steepness = max(1.0, min(20.0, regularization_strength * 5.0)) |
| 209 | + return spearman_loss_diffsort(predictions, targets, steepness=steepness) |
| 210 | + |
180 | 211 | if not HAS_RANK_RELAX: |
181 | 212 | # Fallback: use soft ranking and compute correlation manually |
182 | 213 | pred_ranks = soft_rank_tensor( |
|
0 commit comments