@dataclass
class DriverStack:
    """Bundle of fabric drivers available to a single rank.

    Attributes:
        vendor: GPU vendor identifier for this rank.
        fabric: Fabric driver instance, or ``None`` when no fabric
            transport is available for this rank.
    """

    vendor: str
    fabric: Optional[BaseFabricDriver]
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +from iris.topology import InterconnectLevel + +__all__ = [ + "PeerMapping", + "LocalAllocation", + "BaseFabricDriver", + "DriverError", + "DriverNotSupported", +] + + +@dataclass +class PeerMapping: + """A remote rank's memory mapped into this rank's address space.""" + + peer_rank: int + transport: InterconnectLevel + remote_va: int + size: int + _driver_handle: Any = None + + +@dataclass +class LocalAllocation: + """This rank's exportable allocation.""" + + va: int + size: int + handle: Any + + +class DriverError(RuntimeError): + """Base exception for driver operations.""" + + +class DriverNotSupported(DriverError): + """The current hardware or software stack does not support this driver.""" + + +class BaseFabricDriver(ABC): + """Cross-node fabric memory sharing (for example NVSwitch or xGMI).""" + + @abstractmethod + def initialize(self, device_ordinal: int) -> None: + """Prepare the driver for a specific local GPU.""" + + @abstractmethod + def allocate_exportable(self, size: int) -> LocalAllocation: + """Allocate memory that can be shared through the fabric transport.""" + + @abstractmethod + def export_handle(self, allocation: LocalAllocation) -> bytes: + """Export a transport-specific handle for a local allocation.""" + + @abstractmethod + def import_and_map(self, peer_rank: int, handle_bytes: bytes, size: int) -> PeerMapping: + """Import a peer handle and map it into the local virtual address space.""" + + @abstractmethod + def cleanup_import(self, mapping: PeerMapping) -> None: + """Release a mapped peer allocation.""" + + @abstractmethod + def cleanup_local(self, allocation: LocalAllocation) -> None: + """Release a locally-exported allocation.""" diff --git a/iris/drivers/fabric/__init__.py b/iris/drivers/fabric/__init__.py new file mode 100644 index 000000000..42dd830a5 --- /dev/null +++ 
def fabric_handle_bytes_to_tensor(
    handle_bytes: bytes,
    device: Union[torch.device, str],
    expected_num_bytes: Optional[int] = None,
) -> torch.Tensor:
    """Serialize a raw fabric handle into a uint8 tensor on *device*.

    When ``expected_num_bytes`` is given, the handle length is validated
    against it (handle sizes are backend-specific).
    """
    actual = len(handle_bytes)
    if expected_num_bytes is not None and actual != expected_num_bytes:
        raise ValueError(f"Fabric handle must be {expected_num_bytes} bytes, got {actual}")
    return torch.tensor(list(handle_bytes), dtype=torch.uint8, device=device)


def fabric_tensor_to_handle_bytes(handle_tensor: torch.Tensor, expected_num_bytes: Optional[int] = None) -> bytes:
    """Deserialize a uint8 tensor back into raw fabric handle bytes."""
    flat = handle_tensor.detach().flatten()
    if flat.dtype != torch.uint8:
        raise ValueError("Fabric handle tensor must have dtype torch.uint8")
    if expected_num_bytes is not None and flat.numel() != expected_num_bytes:
        raise ValueError(f"Fabric handle tensor must have {expected_num_bytes} elements, got {flat.numel()}")
    # Copy to host first; the tensor may live on an accelerator device.
    return bytes(flat.to("cpu", copy=True).tolist())
_NOT_IMPLEMENTED_MESSAGE = "AMD fabric driver not yet implemented"


class AmdFabricDriver(BaseFabricDriver):
    """AMD fabric driver placeholder.

    Every operation raises ``DriverNotSupported`` until a real xGMI-backed
    implementation lands; the class exists so the driver stack can be wired
    up symmetrically with the NVIDIA backend.
    """

    def initialize(self, device_ordinal: int) -> None:
        """Not implemented; always raises DriverNotSupported."""
        raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE)

    def allocate_exportable(self, size: int) -> LocalAllocation:
        """Not implemented; always raises DriverNotSupported."""
        raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE)

    def export_handle(self, allocation: LocalAllocation) -> bytes:
        """Not implemented; always raises DriverNotSupported."""
        raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE)

    def import_and_map(self, peer_rank: int, handle_bytes: bytes, size: int) -> PeerMapping:
        """Not implemented; always raises DriverNotSupported."""
        raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE)

    def cleanup_import(self, mapping: PeerMapping) -> None:
        """Not implemented; always raises DriverNotSupported."""
        raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE)

    def cleanup_local(self, allocation: LocalAllocation) -> None:
        """Not implemented; always raises DriverNotSupported."""
        raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE)
+""" + +from __future__ import annotations + +import ctypes +import logging +from typing import Any, Optional + +import torch + +from iris.drivers.base import ( + BaseFabricDriver, + DriverError, + DriverNotSupported, + LocalAllocation, + PeerMapping, +) +from iris.topology import InterconnectLevel + +logger = logging.getLogger("iris.drivers.fabric") + +__all__ = [ + "CudaFabricError", + "CudaFabricNotSupported", + "FABRIC_HANDLE_BYTES", + "NvidiaFabricDriver", +] + +# Load CUDA driver library at module level +_cuda_driver = None +try: + _cuda_driver = ctypes.CDLL("libcuda.so.1") +except OSError: + try: + _cuda_driver = ctypes.CDLL("libcuda.so") + except OSError: + pass + +CUDA_SUCCESS = 0 +CUDA_ERROR_NOT_SUPPORTED = 801 +FABRIC_HANDLE_BYTES = 64 + +# CUDA VMM constants +_CU_MEM_ALLOCATION_TYPE_PINNED = 1 +_CU_MEM_LOCATION_TYPE_DEVICE = 1 +_CU_MEM_HANDLE_TYPE_FABRIC = 0x8 +_CU_MEM_ALLOC_GRANULARITY_MINIMUM = 0 +_CU_MEM_ACCESS_FLAGS_PROT_READWRITE = 0x3 + + +class CudaFabricError(DriverError): + """CUDA fabric/VMM operation failed.""" + + +class CudaFabricNotSupported(DriverNotSupported): + """The local CUDA stack does not support fabric handles.""" + + +def _cuda_try(err: int, op_name: str = "CUDA operation") -> None: + """Check CUDA driver return code and raise on error.""" + if err == CUDA_SUCCESS: + return + error_name = str(err) + if _cuda_driver is not None and hasattr(_cuda_driver, "cuGetErrorName"): + ptr = ctypes.c_char_p() + if _cuda_driver.cuGetErrorName(err, ctypes.byref(ptr)) == CUDA_SUCCESS and ptr.value: + error_name = ptr.value.decode("utf-8") + message = f"{op_name} failed with {error_name} ({err})" + if err == CUDA_ERROR_NOT_SUPPORTED: + raise CudaFabricNotSupported(message) + raise CudaFabricError(message) + + +def _round_up(value: int, granularity: int) -> int: + if granularity <= 0: + raise ValueError(f"granularity must be > 0, got {granularity}") + return ((value + granularity - 1) // granularity) * granularity + + +def 
_normalize_fabric_handle_bytes(raw_handle: Any) -> bytes: + if isinstance(raw_handle, memoryview): + data = raw_handle.tobytes() + elif isinstance(raw_handle, (bytes, bytearray)): + data = bytes(raw_handle) + elif isinstance(raw_handle, torch.Tensor): + data = bytes(raw_handle.detach().to("cpu", copy=True).flatten().tolist()) + else: + try: + data = bytes(raw_handle) + except Exception: + try: + data = ctypes.string_at(ctypes.addressof(raw_handle), FABRIC_HANDLE_BYTES) + except Exception as exc: + raise CudaFabricError("Unable to convert fabric handle object to bytes") from exc + + if len(data) != FABRIC_HANDLE_BYTES: + raise CudaFabricError(f"Fabric handle serialization expected {FABRIC_HANDLE_BYTES} bytes, got {len(data)}") + return data + + +def _get_required_cuda_symbol(name: str) -> Any: + if _cuda_driver is None: + raise CudaFabricNotSupported("CUDA driver library (libcuda.so) not found") + + symbol = getattr(_cuda_driver, name, None) + if symbol is None: + raise CudaFabricNotSupported(f"CUDA driver missing required VMM symbol: {name}") + return symbol + + +def _run_cleanup_steps(*steps) -> None: + first_error = None + for step in steps: + try: + step() + except Exception as exc: + if first_error is None: + first_error = exc + if first_error is not None: + raise first_error + + +# ctypes structure definitions for CUDA VMM API +class _MemLocation(ctypes.Structure): + _fields_ = [("type", ctypes.c_int), ("id", ctypes.c_int)] + + +class _MemAllocationFlags(ctypes.Structure): + _fields_ = [ + ("compressionType", ctypes.c_ubyte), + ("gpuDirectRDMACapable", ctypes.c_ubyte), + ("usage", ctypes.c_ushort), + ("reserved", ctypes.c_ubyte * 4), + ] + + +class _MemAllocationProp(ctypes.Structure): + _fields_ = [ + ("type", ctypes.c_int), + ("requestedHandleTypes", ctypes.c_int), + ("location", _MemLocation), + ("win32HandleMetaData", ctypes.c_void_p), + ("allocFlags", _MemAllocationFlags), + ] + + +class _MemAccessDesc(ctypes.Structure): + _fields_ = [("location", 
_MemLocation), ("flags", ctypes.c_ulonglong)] + + +def _configure_cuda_signatures() -> None: + """Configure ctypes signatures for the required CUDA VMM driver API.""" + if _cuda_driver is None: + return + + cu_init = _get_required_cuda_symbol("cuInit") + cu_device_get = _get_required_cuda_symbol("cuDeviceGet") + cu_device_primary_ctx_retain = _get_required_cuda_symbol("cuDevicePrimaryCtxRetain") + cu_ctx_set_current = _get_required_cuda_symbol("cuCtxSetCurrent") + cu_mem_get_allocation_granularity = _get_required_cuda_symbol("cuMemGetAllocationGranularity") + cu_mem_address_reserve = _get_required_cuda_symbol("cuMemAddressReserve") + cu_mem_address_free = _get_required_cuda_symbol("cuMemAddressFree") + cu_mem_create = _get_required_cuda_symbol("cuMemCreate") + cu_mem_release = _get_required_cuda_symbol("cuMemRelease") + cu_mem_map = _get_required_cuda_symbol("cuMemMap") + cu_mem_unmap = _get_required_cuda_symbol("cuMemUnmap") + cu_mem_set_access = _get_required_cuda_symbol("cuMemSetAccess") + cu_mem_export_to_shareable_handle = _get_required_cuda_symbol("cuMemExportToShareableHandle") + cu_mem_import_from_shareable_handle = _get_required_cuda_symbol("cuMemImportFromShareableHandle") + + cu_init.argtypes = [ctypes.c_uint] + cu_init.restype = ctypes.c_int + + cu_device_get.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int] + cu_device_get.restype = ctypes.c_int + + cu_device_primary_ctx_retain.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_int, + ] + cu_device_primary_ctx_retain.restype = ctypes.c_int + + cu_ctx_set_current.argtypes = [ctypes.c_void_p] + cu_ctx_set_current.restype = ctypes.c_int + + cu_mem_get_allocation_granularity.argtypes = [ + ctypes.POINTER(ctypes.c_size_t), + ctypes.POINTER(_MemAllocationProp), + ctypes.c_int, + ] + cu_mem_get_allocation_granularity.restype = ctypes.c_int + + cu_mem_address_reserve.argtypes = [ + ctypes.POINTER(ctypes.c_uint64), + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_uint64, + ctypes.c_ulonglong, + ] 
+ cu_mem_address_reserve.restype = ctypes.c_int + + cu_mem_address_free.argtypes = [ctypes.c_uint64, ctypes.c_size_t] + cu_mem_address_free.restype = ctypes.c_int + + cu_mem_create.argtypes = [ + ctypes.POINTER(ctypes.c_uint64), + ctypes.c_size_t, + ctypes.POINTER(_MemAllocationProp), + ctypes.c_ulonglong, + ] + cu_mem_create.restype = ctypes.c_int + + cu_mem_release.argtypes = [ctypes.c_uint64] + cu_mem_release.restype = ctypes.c_int + + cu_mem_map.argtypes = [ + ctypes.c_uint64, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_uint64, + ctypes.c_ulonglong, + ] + cu_mem_map.restype = ctypes.c_int + + cu_mem_unmap.argtypes = [ctypes.c_uint64, ctypes.c_size_t] + cu_mem_unmap.restype = ctypes.c_int + + cu_mem_set_access.argtypes = [ + ctypes.c_uint64, + ctypes.c_size_t, + ctypes.POINTER(_MemAccessDesc), + ctypes.c_size_t, + ] + cu_mem_set_access.restype = ctypes.c_int + + cu_mem_export_to_shareable_handle.argtypes = [ + ctypes.c_void_p, + ctypes.c_uint64, + ctypes.c_int, + ctypes.c_ulonglong, + ] + cu_mem_export_to_shareable_handle.restype = ctypes.c_int + + cu_mem_import_from_shareable_handle.argtypes = [ + ctypes.POINTER(ctypes.c_uint64), + ctypes.c_void_p, + ctypes.c_int, + ] + cu_mem_import_from_shareable_handle.restype = ctypes.c_int + + cu_get_error_name = getattr(_cuda_driver, "cuGetErrorName", None) + if cu_get_error_name is not None: + cu_get_error_name.argtypes = [ + ctypes.c_int, + ctypes.POINTER(ctypes.c_char_p), + ] + cu_get_error_name.restype = ctypes.c_int + + +class NvidiaFabricDriver(BaseFabricDriver): + """ + NVIDIA CUDA VMM fabric driver. + + Uses ctypes to interface with libcuda.so for CUDA Virtual Memory Management + operations required for fabric handle export/import. 
+ """ + + def __init__(self) -> None: + self._device_ordinal: int = 0 + self._granularity: Optional[int] = None + self._initialized: bool = False + + def _make_alloc_props(self) -> _MemAllocationProp: + props = _MemAllocationProp() + props.type = _CU_MEM_ALLOCATION_TYPE_PINNED + props.requestedHandleTypes = _CU_MEM_HANDLE_TYPE_FABRIC + props.location.type = _CU_MEM_LOCATION_TYPE_DEVICE + props.location.id = self._device_ordinal + props.win32HandleMetaData = None + return props + + def _get_granularity(self) -> int: + if self._granularity is not None: + return self._granularity + props = self._make_alloc_props() + granularity = ctypes.c_size_t() + _cuda_try( + _cuda_driver.cuMemGetAllocationGranularity( + ctypes.byref(granularity), + ctypes.byref(props), + _CU_MEM_ALLOC_GRANULARITY_MINIMUM, + ), + "cuMemGetAllocationGranularity", + ) + self._granularity = int(granularity.value) + return self._granularity + + def _mem_set_access(self, va: int, size: int) -> None: + desc = _MemAccessDesc() + desc.location.type = _CU_MEM_LOCATION_TYPE_DEVICE + desc.location.id = self._device_ordinal + desc.flags = _CU_MEM_ACCESS_FLAGS_PROT_READWRITE + _cuda_try(_cuda_driver.cuMemSetAccess(va, size, ctypes.byref(desc), 1), "cuMemSetAccess") + + def initialize(self, device_ordinal: int) -> None: + if _cuda_driver is None: + raise CudaFabricNotSupported("CUDA driver library (libcuda.so) not found") + + _configure_cuda_signatures() + _cuda_try(_cuda_driver.cuInit(0), "cuInit") + dev = ctypes.c_int() + _cuda_try(_cuda_driver.cuDeviceGet(ctypes.byref(dev), device_ordinal), "cuDeviceGet") + ctx = ctypes.c_void_p() + _cuda_try(_cuda_driver.cuDevicePrimaryCtxRetain(ctypes.byref(ctx), dev.value), "cuDevicePrimaryCtxRetain") + _cuda_try(_cuda_driver.cuCtxSetCurrent(ctx), "cuCtxSetCurrent") + self._device_ordinal = device_ordinal + self._granularity = None + self._initialized = True + logger.info("NvidiaFabricDriver initialized (device %d)", device_ordinal) + + def _check_initialized(self) -> 
None: + if not self._initialized: + raise CudaFabricError("NvidiaFabricDriver not initialized — call initialize() first") + + def allocate_exportable(self, size: int) -> LocalAllocation: + self._check_initialized() + props = self._make_alloc_props() + granularity = self._get_granularity() + alloc_size = _round_up(size, granularity) + + va = ctypes.c_uint64() + handle = ctypes.c_uint64() + mapped = False + + try: + _cuda_try( + _cuda_driver.cuMemAddressReserve(ctypes.byref(va), alloc_size, granularity, 0, 0), + "cuMemAddressReserve", + ) + _cuda_try( + _cuda_driver.cuMemCreate(ctypes.byref(handle), alloc_size, ctypes.byref(props), 0), + "cuMemCreate", + ) + _cuda_try(_cuda_driver.cuMemMap(va.value, alloc_size, 0, handle.value, 0), "cuMemMap") + mapped = True + self._mem_set_access(int(va.value), alloc_size) + return LocalAllocation(va=int(va.value), size=alloc_size, handle=int(handle.value)) + except Exception: + if mapped: + try: + _cuda_try(_cuda_driver.cuMemUnmap(va.value, alloc_size), "cuMemUnmap") + except Exception: + pass + if handle.value: + try: + _cuda_try(_cuda_driver.cuMemRelease(handle.value), "cuMemRelease") + except Exception: + pass + if va.value: + try: + _cuda_try(_cuda_driver.cuMemAddressFree(va.value, alloc_size), "cuMemAddressFree") + except Exception: + pass + raise + + def export_handle(self, allocation: LocalAllocation) -> bytes: + self._check_initialized() + raw = (ctypes.c_ubyte * FABRIC_HANDLE_BYTES)() + _cuda_try( + _cuda_driver.cuMemExportToShareableHandle( + ctypes.byref(raw), + int(allocation.handle), + _CU_MEM_HANDLE_TYPE_FABRIC, + 0, + ), + "cuMemExportToShareableHandle", + ) + return bytes(raw) + + def _import_handle(self, handle_bytes: bytes) -> int: + handle_bytes = _normalize_fabric_handle_bytes(handle_bytes) + imported = ctypes.c_uint64() + raw = (ctypes.c_ubyte * FABRIC_HANDLE_BYTES).from_buffer_copy(handle_bytes) + _cuda_try( + _cuda_driver.cuMemImportFromShareableHandle( + ctypes.byref(imported), + ctypes.byref(raw), + 
_CU_MEM_HANDLE_TYPE_FABRIC, + ), + "cuMemImportFromShareableHandle", + ) + return int(imported.value) + + def import_and_map(self, peer_rank: int, handle_bytes: bytes, size: int) -> PeerMapping: + self._check_initialized() + imported_handle = self._import_handle(handle_bytes) + + granularity = self._get_granularity() + va = ctypes.c_uint64() + + mapped = False + try: + _cuda_try( + _cuda_driver.cuMemAddressReserve(ctypes.byref(va), size, granularity, 0, 0), + "cuMemAddressReserve", + ) + _cuda_try(_cuda_driver.cuMemMap(va.value, size, 0, imported_handle, 0), "cuMemMap") + mapped = True + self._mem_set_access(int(va.value), size) + except Exception: + if mapped: + try: + _cuda_try(_cuda_driver.cuMemUnmap(va.value, size), "cuMemUnmap") + except Exception: + pass + try: + _cuda_try(_cuda_driver.cuMemRelease(imported_handle), "cuMemRelease") + except Exception: + pass + if va.value: + try: + _cuda_try(_cuda_driver.cuMemAddressFree(va.value, size), "cuMemAddressFree") + except Exception: + pass + raise + + return PeerMapping( + peer_rank=peer_rank, + transport=InterconnectLevel.INTRA_RACK_FABRIC, + remote_va=int(va.value), + size=size, + _driver_handle=imported_handle, + ) + + def cleanup_import(self, mapping: PeerMapping) -> None: + self._check_initialized() + _run_cleanup_steps( + lambda: _cuda_try(_cuda_driver.cuMemUnmap(mapping.remote_va, mapping.size), "cuMemUnmap"), + lambda: _cuda_try(_cuda_driver.cuMemRelease(mapping._driver_handle), "cuMemRelease"), + lambda: _cuda_try(_cuda_driver.cuMemAddressFree(mapping.remote_va, mapping.size), "cuMemAddressFree"), + ) + + def cleanup_local(self, allocation: LocalAllocation) -> None: + self._check_initialized() + _run_cleanup_steps( + lambda: _cuda_try(_cuda_driver.cuMemUnmap(allocation.va, allocation.size), "cuMemUnmap"), + lambda: _cuda_try(_cuda_driver.cuMemRelease(allocation.handle), "cuMemRelease"), + lambda: _cuda_try(_cuda_driver.cuMemAddressFree(allocation.va, allocation.size), "cuMemAddressFree"), + ) diff --git 
class TestExceptions:
    def test_cuda_fabric_error_uses_driver_error_hierarchy(self):
        # Callers catching the generic driver exceptions must also catch
        # the CUDA-specific variants.
        assert issubclass(CudaFabricError, DriverError)
        assert issubclass(CudaFabricNotSupported, DriverNotSupported)


class TestFabricHandleSerialization:
    def test_round_trip(self):
        payload = bytes(range(FABRIC_HANDLE_BYTES))
        tensor = fabric_handle_bytes_to_tensor(payload, "cpu", expected_num_bytes=FABRIC_HANDLE_BYTES)
        assert tensor.shape == (FABRIC_HANDLE_BYTES,)
        assert tensor.dtype == torch.uint8
        assert fabric_tensor_to_handle_bytes(tensor, expected_num_bytes=FABRIC_HANDLE_BYTES) == payload

    def test_wrong_size_bytes(self):
        with pytest.raises(ValueError, match="64 bytes"):
            fabric_handle_bytes_to_tensor(b"\x00" * 32, "cpu", expected_num_bytes=FABRIC_HANDLE_BYTES)

    def test_wrong_size_tensor(self):
        with pytest.raises(ValueError, match="64 elements"):
            fabric_tensor_to_handle_bytes(torch.zeros(32, dtype=torch.uint8), expected_num_bytes=FABRIC_HANDLE_BYTES)

    def test_wrong_tensor_dtype(self):
        with pytest.raises(ValueError, match="dtype torch.uint8"):
            fabric_tensor_to_handle_bytes(torch.zeros(FABRIC_HANDLE_BYTES, dtype=torch.int32))


class TestNvidiaFabricHelpers:
    def test_round_up(self):
        assert _round_up(1, 64) == 64
        assert _round_up(64, 64) == 64
        assert _round_up(65, 64) == 128

    def test_round_up_rejects_nonpositive_granularity(self):
        with pytest.raises(ValueError, match="granularity must be > 0"):
            _round_up(1, 0)

    def test_normalize_fabric_handle_bytes(self):
        payload = bytes(range(FABRIC_HANDLE_BYTES))
        # Every accepted representation must normalize to identical bytes.
        assert _normalize_fabric_handle_bytes(payload) == payload
        assert _normalize_fabric_handle_bytes(bytearray(payload)) == payload
        assert _normalize_fabric_handle_bytes(memoryview(payload)) == payload
        assert _normalize_fabric_handle_bytes(torch.tensor(list(payload), dtype=torch.uint8)) == payload

    def test_normalize_fabric_handle_bytes_rejects_wrong_size(self):
        with pytest.raises(CudaFabricError, match="expected 64 bytes"):
            _normalize_fabric_handle_bytes(b"\x00" * 8)

    def test_normalize_fabric_handle_bytes_rejects_unconvertible_objects(self):
        class Unconvertible:
            def __bytes__(self):
                raise TypeError("boom")

        with pytest.raises(CudaFabricError, match="Unable to convert"):
            _normalize_fabric_handle_bytes(Unconvertible())


class TestNvidiaFabricDriver:
    @pytest.mark.parametrize(
        ("method_name", "args"),
        [
            ("allocate_exportable", (4096,)),
            ("export_handle", (LocalAllocation(va=0, size=0, handle=0),)),
            ("import_and_map", (0, b"\x00" * FABRIC_HANDLE_BYTES, 4096)),
            (
                "cleanup_import",
                (
                    PeerMapping(
                        peer_rank=0,
                        transport=InterconnectLevel.INTRA_RACK_FABRIC,
                        remote_va=0,
                        size=0,
                    ),
                ),
            ),
            ("cleanup_local", (LocalAllocation(va=0, size=0, handle=0),)),
        ],
    )
    def test_public_methods_require_initialize(self, method_name, args):
        driver = NvidiaFabricDriver()
        with pytest.raises(CudaFabricError, match="not initialized"):
            getattr(driver, method_name)(*args)

    def test_initialize_raises_when_no_cuda_driver(self, monkeypatch):
        monkeypatch.setattr(nvidia_driver_module, "_cuda_driver", None)
        driver = NvidiaFabricDriver()
        with pytest.raises(CudaFabricNotSupported, match="libcuda.so.*not found"):
            driver.initialize(0)

    def test_initialize_raises_not_supported_for_missing_required_symbol(self, monkeypatch):
        class IncompleteCudaDriver:
            pass

        monkeypatch.setattr(nvidia_driver_module, "_cuda_driver", IncompleteCudaDriver())
        driver = NvidiaFabricDriver()
        with pytest.raises(CudaFabricNotSupported, match="missing required VMM symbol: cuInit"):
            driver.initialize(0)

    def test_cleanup_import_attempts_all_cleanup_steps(self, monkeypatch):
        calls = []

        class FakeCudaDriver:
            # cuMemUnmap fails (nonzero), but the later steps must still run.
            def cuMemUnmap(self, remote_va, size):
                calls.append(("unmap", remote_va, size))
                return 1

            def cuMemRelease(self, handle):
                calls.append(("release", handle))
                return 0

            def cuMemAddressFree(self, remote_va, size):
                calls.append(("free", remote_va, size))
                return 0

        monkeypatch.setattr(nvidia_driver_module, "_cuda_driver", FakeCudaDriver())
        driver = NvidiaFabricDriver()
        driver._initialized = True
        mapping = PeerMapping(
            peer_rank=2,
            transport=InterconnectLevel.INTRA_RACK_FABRIC,
            remote_va=0x2000,
            size=4096,
            _driver_handle=99,
        )

        with pytest.raises(CudaFabricError, match="cuMemUnmap"):
            driver.cleanup_import(mapping)

        assert calls == [
            ("unmap", 0x2000, 4096),
            ("release", 99),
            ("free", 0x2000, 4096),
        ]

    def test_cleanup_local_attempts_all_cleanup_steps(self, monkeypatch):
        calls = []

        class FakeCudaDriver:
            # cuMemUnmap fails (nonzero), but the later steps must still run.
            def cuMemUnmap(self, va, size):
                calls.append(("unmap", va, size))
                return 1

            def cuMemRelease(self, handle):
                calls.append(("release", handle))
                return 0

            def cuMemAddressFree(self, va, size):
                calls.append(("free", va, size))
                return 0

        monkeypatch.setattr(nvidia_driver_module, "_cuda_driver", FakeCudaDriver())
        driver = NvidiaFabricDriver()
        driver._initialized = True
        allocation = LocalAllocation(va=0x1000, size=8192, handle=77)

        with pytest.raises(CudaFabricError, match="cuMemUnmap"):
            driver.cleanup_local(allocation)

        assert calls == [
            ("unmap", 0x1000, 8192),
            ("release", 77),
            ("free", 0x1000, 8192),
        ]