
Commit 0d0a8a4

Author: Henry Wallace

Use differentiable sorting for Spearman: diffsort fallback, log backend at train start

- loss_unified: when method=auto and torchsort is unavailable, use diffsort (HAS_DIFFSORT)
- get_spearman_backend() for logging; train_all_fronts prints the Spearman loss backend at start
- README/justfile/MAKING_IT_GOOD: recommend `uv sync --extra sorting`, document the backend order

1 parent 3e07098

5 files changed: 39 additions & 3 deletions

README.md (1 addition, 1 deletion)

````diff
@@ -8,7 +8,7 @@ ICF is normalized to \([0, 1]\): **0.0 = very common**, **1.0 = very rare**.
 
 ```bash
 uv sync --extra dev
-# Optional: uv sync --extra sorting (torchsort for differentiable Spearman in multi-task training)
+# Recommended for multi-task training: uv sync --extra sorting (torchsort or diffsort for differentiable Spearman; backend is logged at train start)
 
 # Train
 uv run tiny-icf-train --help
````

docs/guides/MAKING_IT_GOOD_MINIMAL_HEURISTICS.md (1 addition, 1 deletion)

```diff
@@ -20,7 +20,7 @@ Research-backed, low-heuristic improvements. No hand-picked anchor words or ad-h
 
 **Fix:** Use **differentiable Spearman** via soft sorting (Blondel et al., "Fast Differentiable Sorting and Ranking", ICML 2020; [arxiv 2002.08871](https://arxiv.org/abs/2002.08871)). Loss = \( \frac{1}{2}\|r - r_\Psi(\theta)\|^2 \) where \( r_\Psi \) are soft ranks. Implementations: **torchsort** (O(n log n), recommended), **diffsort** (O(n²(log n)²)).
 
-**Implemented:** `loss_unified.spearman_loss_tensor` prefers **torchsort** when `spearman_method` is `"auto"` (default) and torchsort is installed. Training uses `--spearman-method auto`; install with `uv sync --extra sorting` for O(n log n) differentiable Spearman. CLI: `--spearman-reg-strength 0.1`, `--spearman-method auto|torchsort|sigmoid`. Fallback is rank_relax or built-in sigmoid.
+**Implemented:** `loss_unified.spearman_loss_tensor` with `spearman_method="auto"` (default): use **torchsort** if available, else **diffsort**, else rank_relax or built-in soft_rank. All paths are differentiable. Install `uv sync --extra sorting` for torchsort and/or diffsort. At training start we log `Spearman loss backend: <torchsort|diffsort|rank_relax|built-in>`. CLI: `--spearman-reg-strength 0.1`, `--spearman-method auto|torchsort|diffsort|sigmoid`.
 
 ---
 
```
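The soft-rank construction underlying these backends can be sketched in plain Python. This is an illustrative, stdlib-only toy assuming sigmoid-relaxed pairwise comparisons; the project's actual `soft_rank_tensor` is a PyTorch implementation and may differ in detail:

```python
import math


def _sigmoid(z: float) -> float:
    """Numerically stable logistic sigmoid."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def soft_rank(xs, steepness=20.0):
    """Soft 1-based ranks: r_i = 1 + sum_{j != i} sigmoid(steepness * (x_i - x_j)).

    Each pairwise comparison is relaxed to a sigmoid, so the ranks are smooth
    in xs; as steepness -> infinity they approach the true ranks (for distinct values).
    """
    return [
        1.0 + sum(_sigmoid(steepness * (xi - xj)) for j, xj in enumerate(xs) if j != i)
        for i, xi in enumerate(xs)
    ]


def soft_spearman(preds, targets, steepness=20.0):
    """Pearson correlation of soft ranks, approximating Spearman's rho.

    A training loss would use 1 - rho (or the squared-rank-error form from the
    Blondel et al. paper cited above).
    """
    rp, rt = soft_rank(preds, steepness), soft_rank(targets, steepness)
    n = len(rp)
    mp, mt = sum(rp) / n, sum(rt) / n
    cov = sum((a - mp) * (b - mt) for a, b in zip(rp, rt))
    vp = math.sqrt(sum((a - mp) ** 2 for a in rp))
    vt = math.sqrt(sum((b - mt) ** 2 for b in rt))
    return cov / (vp * vt)
```

With a large steepness the soft ranks snap to the hard ranks and the correlation of perfectly co-ordered lists approaches 1.0 (and -1.0 for reversed order), which is the quantity the Spearman regularizer pushes on during training.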

justfile (1 addition, 1 deletion)

```diff
@@ -64,8 +64,8 @@ sync-s3:
     aws s3 sync models/ s3://arclabs-backups/tiny-icf/models/ --exclude "*" --include "multitask_*.pt" --include "v3_base*.pt" --include "*.pt.cal.json"
 
 # English-only training (better "the"/"and", no lang prefix); uses frequency sampling + spearman-method auto
+# For differentiable Spearman: uv sync --extra sorting (torchsort or diffsort; backend logged at start)
 # For custom EPOCHS/SAMPLES run: uv run python scripts/train_all_fronts.py ... --epochs N --train-max-samples M
-# Background run: nohup uv run python scripts/train_all_fronts.py ... > models/all_fronts_en/train_en_30ep.log 2>&1 &
 train-en DATA="data/word_frequency.csv" EPOCHS="30" SAMPLES="200000":
     mkdir -p models/all_fronts_en
     uv run python scripts/train_all_fronts.py \
```

scripts/train_all_fronts.py (5 additions, 0 deletions)

```diff
@@ -278,6 +278,11 @@ def main() -> int:
         },
     }
 
+    from tiny_icf.loss_unified import get_spearman_backend
+
+    spearman_backend = get_spearman_backend(config.get("spearman_method", "auto"))
+    print(f"Spearman loss backend: {spearman_backend} (method={args.spearman_method})")
+
     module = FlexibleIDFLightningModule(config=config, learning_rate=args.lr, weight_decay=args.weight_decay)
 
     # Optional init-from: load a UniversalICF checkpoint into the base model.
```

src/tiny_icf/loss_unified.py (31 additions, 0 deletions)

```diff
@@ -38,6 +38,26 @@
     HAS_TORCHSORT = False
     spearman_loss_torchsort = None
 
+# Fallback: diffsort (differentiable sorting networks) when torchsort unavailable
+try:
+    from tiny_icf.loss import _try_import_diffsort, spearman_loss_diffsort
+
+    HAS_DIFFSORT = _try_import_diffsort() is not None
+except Exception:
+    HAS_DIFFSORT = False
+    spearman_loss_diffsort = None
+
+
+def get_spearman_backend(method: str = "auto") -> str:
+    """Return which backend will be used for Spearman loss (for logging)."""
+    if method in ("torchsort", "auto") and HAS_TORCHSORT and spearman_loss_torchsort is not None:
+        return "torchsort"
+    if method in ("diffsort", "auto") and HAS_DIFFSORT and spearman_loss_diffsort is not None:
+        return "diffsort"
+    if HAS_RANK_RELAX:
+        return "rank_relax"
+    return "built-in (soft_rank)"
+
 
 def _to_list(tensor: torch.Tensor) -> List[float]:
     """Convert tensor to Python list for rank-relax."""
@@ -177,6 +197,17 @@ def spearman_loss_tensor(
             predictions, targets, regularization_strength=regularization_strength
         )
 
+    # Differentiable sorting fallback: diffsort when method is auto or diffsort
+    use_diffsort = (
+        (method in ("diffsort", "auto"))
+        and HAS_DIFFSORT
+        and spearman_loss_diffsort is not None
+        and predictions.numel() >= 2
+    )
+    if use_diffsort:
+        steepness = max(1.0, min(20.0, regularization_strength * 5.0))
+        return spearman_loss_diffsort(predictions, targets, steepness=steepness)
+
     if not HAS_RANK_RELAX:
         # Fallback: use soft ranking and compute correlation manually
         pred_ranks = soft_rank_tensor(
```
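One detail worth noting in the diffsort branch: `regularization_strength` is translated into diffsort's `steepness` through a clamp. A minimal restatement of that mapping (the helper name is hypothetical; the diff inlines the expression):

```python
def diffsort_steepness(regularization_strength: float) -> float:
    """Steepness passed to diffsort, clamped to [1.0, 20.0].

    Mirrors the inline expression in spearman_loss_tensor:
    steepness = max(1.0, min(20.0, regularization_strength * 5.0))
    """
    return max(1.0, min(20.0, regularization_strength * 5.0))
```

With the documented default `--spearman-reg-strength 0.1`, the product 0.5 is clamped up to the floor of 1.0, so the default diffsort relaxation is at its softest setting.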
