diff --git a/cebra/models/criterions.py b/cebra/models/criterions.py index f78e298b..eeeb7f73 100644 --- a/cebra/models/criterions.py +++ b/cebra/models/criterions.py @@ -33,6 +33,7 @@ """ import math +import warnings from typing import Optional, Tuple import torch @@ -40,7 +41,28 @@ from torch import nn -@torch.jit.script +def _compile(fn): + """Apply ``torch.compile`` when available, falling back to uncompiled. + + ``torch.compile`` is the recommended replacement for ``torch.jit.script`` + starting from PyTorch 2.0. In environments where the compiler backend is + not available (e.g. certain CI configurations or incomplete installations), + the function is returned unchanged so that correctness is preserved. + A :class:`UserWarning` is emitted when the fallback path is taken. + """ + try: + return torch.compile(fn) + except (ImportError, RuntimeError, TypeError) as exc: + warnings.warn( + f"torch.compile is unavailable; falling back to uncompiled " + f"{fn.__name__!r}. Reason: {exc}", + UserWarning, + stacklevel=2, + ) + return fn + + +@_compile def dot_similarity(ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Cosine similarity the ref, pos and negative pairs @@ -59,7 +81,7 @@ def dot_similarity(ref: torch.Tensor, pos: torch.Tensor, return pos_dist, neg_dist -@torch.jit.script +@_compile def euclidean_similarity( ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -85,7 +107,7 @@ def euclidean_similarity( return pos_dist, neg_dist -@torch.jit.script +@_compile def infonce( pos_dist: torch.Tensor, neg_dist: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/tests/test_criterions.py b/tests/test_criterions.py index 0d6f8ff2..572ea667 100644 --- a/tests/test_criterions.py +++ b/tests/test_criterions.py @@ -25,8 +25,11 @@ import cebra.models.criterions as cebra_criterions +# Use the same _compile helper from criterions for consistency +_compile = cebra_criterions._compile -@torch.jit.script + +@_compile def ref_dot_similarity(ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor, temperature: float): pos_dist = torch.einsum("ni,ni->n", ref, pos) / temperature @@ -34,7 +37,7 @@ def ref_dot_similarity(ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor, return pos_dist, neg_dist -@torch.jit.script +@_compile def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor, temperature: float): ref_sq = torch.einsum("ni->n", ref**2) / temperature @@ -48,7 +51,7 @@ def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor, return pos_dist, neg_dist -@torch.jit.script +@_compile def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor): with torch.no_grad(): c, _ = neg_dist.max(dim=1, keepdim=True) @@ -61,7 +64,7 @@ def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor): return align + uniform, align, uniform -@torch.jit.script +@_compile def ref_infonce_not_stable(pos_dist: torch.Tensor, neg_dist: torch.Tensor): pos_dist = pos_dist neg_dist = neg_dist