Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions cebra/models/criterions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,36 @@
"""

import math
import warnings
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn


@torch.jit.script
def _compile(fn):
"""Apply ``torch.compile`` when available, falling back to uncompiled.

``torch.compile`` is the recommended replacement for ``torch.jit.script``
starting from PyTorch 2.0. In environments where the compiler backend is
not available (e.g. certain CI configurations or incomplete installations),
the function is returned unchanged so that correctness is preserved.
A :class:`UserWarning` is emitted when the fallback path is taken.
"""
try:
return torch.compile(fn)
except (ImportError, RuntimeError, TypeError) as exc:
warnings.warn(
f"torch.compile is unavailable; falling back to uncompiled "
f"{fn.__name__!r}. Reason: {exc}",
UserWarning,
stacklevel=2,
)
return fn


@_compile
def dot_similarity(ref: torch.Tensor, pos: torch.Tensor,
neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Cosine similarity the ref, pos and negative pairs
Expand All @@ -59,7 +81,7 @@ def dot_similarity(ref: torch.Tensor, pos: torch.Tensor,
return pos_dist, neg_dist


@torch.jit.script
@_compile
def euclidean_similarity(
ref: torch.Tensor, pos: torch.Tensor,
neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -85,7 +107,7 @@ def euclidean_similarity(
return pos_dist, neg_dist


@torch.jit.script
@_compile
def infonce(
pos_dist: torch.Tensor, neg_dist: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down
11 changes: 7 additions & 4 deletions tests/test_criterions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@

import cebra.models.criterions as cebra_criterions

# Use the same _compile helper from criterions for consistency
_compile = cebra_criterions._compile

@torch.jit.script

@_compile
def ref_dot_similarity(ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor,
temperature: float):
pos_dist = torch.einsum("ni,ni->n", ref, pos) / temperature
neg_dist = torch.einsum("ni,mi->nm", ref, neg) / temperature
return pos_dist, neg_dist


@torch.jit.script
@_compile
def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor,
neg: torch.Tensor, temperature: float):
ref_sq = torch.einsum("ni->n", ref**2) / temperature
Expand All @@ -48,7 +51,7 @@ def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor,
return pos_dist, neg_dist


@torch.jit.script
@_compile
def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
with torch.no_grad():
c, _ = neg_dist.max(dim=1, keepdim=True)
Expand All @@ -61,7 +64,7 @@ def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
return align + uniform, align, uniform


@torch.jit.script
@_compile
def ref_infonce_not_stable(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
pos_dist = pos_dist
neg_dist = neg_dist
Expand Down
Loading