Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions kernels/src/kernels/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def _sort_variants(
1. AcceptedVariant before RejectedVariant.
2. Torch stable ABI arch kernels, with highest compatible version first,
then highest compatible CUDA version.
2. Torch arch kernels with with the highest compatible CUDA version.
2. Torch arch kernels (tagless before C++ ABI-tagged) with the highest compatible CUDA version.
3. tvm-ffi arch kernels with with the highest compatible CUDA version.
4. Torch noarch kernels.
5. Old Torch universal kernels.
Expand All @@ -510,7 +510,8 @@ def sort_key(vs: Decision) -> tuple[int, ...]:
)
elif isinstance(v.framework, Torch):
framework_order = 1
abi_version_order = (0, 0)
# Prefer tagless (cxx11_abi is None) over ABI-tagged.
abi_version_order = (0, 1 if v.framework.cxx11_abi is not None else 0)
else:
framework_order = 2
abi_version_order = (0, 0)
Expand Down
25 changes: 25 additions & 0 deletions kernels/tests/test_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,31 @@ def test_resolve_stable_abi_newest_version_preferred():
assert {vs.variant for vs in trace} == set(variants)


def test_resolve_tagless_preferred_over_abi_tagged():
# Tagless variant (e.g. torch210-cu128) should be preferred over ABI-tagged
# (e.g. torch210-cxx11-cu128) when both are accepted.
variants = [
parse_variant(s)
for s in [
"torch210-cxx11-cu128-x86_64-linux",
"torch210-cu128-x86_64-linux",
]
]
result, trace = _resolve_variant_for_system(
variants=variants,
selected_backend=CUDA(Version("12.8")),
cpu="x86_64",
os="linux",
torch_version=Version("2.10"),
torch_cxx11_abi=True,
tvm_ffi_version=None,
)
assert result != []
assert result[0].variant_str == "torch210-cu128-x86_64-linux"
assert result == [vs.variant for vs in trace if isinstance(vs, VariantAccepted)]
assert {vs.variant for vs in trace} == set(variants)


def test_resolve_stable_abi_preferred_over_torch():
# TorchStableAbi variant is preferred over a regular Torch variant of the same version.
variants = [
Expand Down
Loading