Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions iris/drivers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved.

"""
Shared driver package types for fabric backends.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from iris.drivers.base import BaseFabricDriver

__all__ = ["DriverStack"]


@dataclass
class DriverStack:
    """Fabric drivers available for a rank.

    Bundles the vendor identity with the optional fabric driver instance
    selected for the local GPU.
    """

    # Vendor name for the local accelerator. The exact value set is chosen by
    # whoever constructs the stack — not visible here; verify against callers.
    vendor: str
    # Cross-node fabric driver, or None when no fabric transport is available
    # for this rank.
    fabric: Optional[BaseFabricDriver]
78 changes: 78 additions & 0 deletions iris/drivers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved.

"""
Abstract base classes, shared dataclasses, and exceptions for fabric drivers.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from iris.topology import InterconnectLevel

__all__ = [
"PeerMapping",
"LocalAllocation",
"BaseFabricDriver",
"DriverError",
"DriverNotSupported",
]


@dataclass
class PeerMapping:
    """A remote rank's memory mapped into this rank's address space."""

    # Rank that owns the exported memory this mapping refers to.
    peer_rank: int
    # Interconnect level used to reach the peer (enum from iris.topology).
    transport: InterconnectLevel
    # Local virtual address at which the peer's allocation is mapped.
    remote_va: int
    # Size of the mapping — presumably in bytes; TODO confirm with backends.
    size: int
    # Opaque backend-specific state needed to undo the mapping
    # (consumed by BaseFabricDriver.cleanup_import).
    _driver_handle: Any = None


@dataclass
class LocalAllocation:
    """This rank's exportable allocation."""

    # Virtual address of the local allocation.
    va: int
    # Allocation size — presumably in bytes; TODO confirm with backends.
    size: int
    # Opaque backend-specific allocation handle (used by export_handle /
    # cleanup_local); its concrete type depends on the driver.
    handle: Any


class DriverError(RuntimeError):
    """Base exception for driver operations.

    All driver-specific failures should derive from this so callers can
    catch one type for any fabric-backend problem.
    """


class DriverNotSupported(DriverError):
    """The current hardware or software stack does not support this driver.

    Raised, for example, by stub backends whose implementation has not
    landed yet, or on hosts lacking the required fabric capability.
    """


class BaseFabricDriver(ABC):
    """Cross-node fabric memory sharing (for example NVSwitch or xGMI).

    Expected lifecycle (inferred from the method set — confirm with callers):
    ``initialize`` once per local GPU; the exporting side calls
    ``allocate_exportable`` then ``export_handle``; the importing side calls
    ``import_and_map``; the two ``cleanup_*`` methods tear down each side.
    """

    @abstractmethod
    def initialize(self, device_ordinal: int) -> None:
        """Prepare the driver for a specific local GPU.

        Args:
            device_ordinal: Index of the local GPU this driver instance manages.
        """

    @abstractmethod
    def allocate_exportable(self, size: int) -> LocalAllocation:
        """Allocate memory that can be shared through the fabric transport.

        Args:
            size: Requested allocation size (presumably bytes — confirm).

        Returns:
            A LocalAllocation describing the new exportable region.
        """

    @abstractmethod
    def export_handle(self, allocation: LocalAllocation) -> bytes:
        """Export a transport-specific handle for a local allocation.

        Returns:
            Opaque handle bytes a peer can feed to ``import_and_map``.
        """

    @abstractmethod
    def import_and_map(self, peer_rank: int, handle_bytes: bytes, size: int) -> PeerMapping:
        """Import a peer handle and map it into the local virtual address space.

        Args:
            peer_rank: Rank that exported the handle.
            handle_bytes: Bytes produced by the peer's ``export_handle``.
            size: Size of the peer allocation (presumably bytes — confirm).

        Returns:
            A PeerMapping describing the newly mapped remote memory.
        """

    @abstractmethod
    def cleanup_import(self, mapping: PeerMapping) -> None:
        """Release a mapped peer allocation."""

    @abstractmethod
    def cleanup_local(self, allocation: LocalAllocation) -> None:
        """Release a locally-exported allocation."""
44 changes: 44 additions & 0 deletions iris/drivers/fabric/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved.

"""
Shared handle serialization utilities for fabric drivers.

These helpers convert between raw fabric handles and uint8 tensors that higher
layers can exchange with torch.distributed collectives. Handle size validation,
when desired, is backend-specific and can be passed in explicitly.
"""

from __future__ import annotations

from typing import Optional, Union

import torch

__all__ = [
"fabric_handle_bytes_to_tensor",
"fabric_tensor_to_handle_bytes",
]


def fabric_handle_bytes_to_tensor(
    handle_bytes: bytes,
    device: Union[torch.device, str],
    expected_num_bytes: Optional[int] = None,
) -> torch.Tensor:
    """Serialize a raw fabric handle into a uint8 tensor.

    Args:
        handle_bytes: Raw transport handle to serialize.
        device: Device on which to place the resulting tensor.
        expected_num_bytes: If given, the exact handle length to enforce.

    Returns:
        A 1-D ``torch.uint8`` tensor with one element per byte of the handle.

    Raises:
        ValueError: If ``expected_num_bytes`` is given and does not match.
    """

    if expected_num_bytes is not None and len(handle_bytes) != expected_num_bytes:
        raise ValueError(f"Fabric handle must be {expected_num_bytes} bytes, got {len(handle_bytes)}")
    if not handle_bytes:
        # Guard: torch.frombuffer may reject zero-length buffers; preserve the
        # previous behavior of returning an empty uint8 tensor.
        return torch.empty(0, dtype=torch.uint8, device=device)
    # frombuffer performs one C-level copy instead of boxing a Python int per
    # byte (the old list(handle_bytes) path). The bytearray is a fresh writable
    # copy, so the tensor never aliases the caller's immutable bytes.
    return torch.frombuffer(bytearray(handle_bytes), dtype=torch.uint8).to(device)


def fabric_tensor_to_handle_bytes(handle_tensor: torch.Tensor, expected_num_bytes: Optional[int] = None) -> bytes:
    """Deserialize a uint8 tensor back into raw handle bytes.

    Args:
        handle_tensor: Tensor previously produced by
            ``fabric_handle_bytes_to_tensor`` (any shape; it is flattened).
        expected_num_bytes: If given, the exact element count to enforce.

    Returns:
        The raw handle bytes.

    Raises:
        ValueError: If the dtype is not ``torch.uint8`` or the element count
            does not match ``expected_num_bytes``.
    """

    flattened = handle_tensor.detach().flatten()
    if flattened.dtype != torch.uint8:
        raise ValueError("Fabric handle tensor must have dtype torch.uint8")
    if expected_num_bytes is not None and flattened.numel() != expected_num_bytes:
        raise ValueError(f"Fabric handle tensor must have {expected_num_bytes} elements, got {flattened.numel()}")
    # One C-level buffer copy instead of materializing a Python int per byte
    # (the old tolist() path). flatten() guarantees contiguity, but keep
    # contiguous() as a cheap no-op safety net.
    return flattened.cpu().contiguous().numpy().tobytes()
Comment on lines +24 to +44
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both helpers create Python lists (list(handle_bytes) / tolist()), which is relatively expensive and allocates per-element Python integers. If these run in hot paths (e.g., per-collective/per-peer), consider using a zero/low-copy conversion path (e.g., build a CPU uint8 tensor from a buffer and then .to(device), and convert back via a contiguous CPU view to raw bytes) to reduce overhead.

Copilot uses AI. Check for mistakes.
36 changes: 36 additions & 0 deletions iris/drivers/fabric/amd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved.

"""
AMD fabric driver stub.
"""

from __future__ import annotations

from iris.drivers.base import BaseFabricDriver, DriverNotSupported, LocalAllocation, PeerMapping

__all__ = ["AmdFabricDriver"]

_NOT_IMPLEMENTED_MESSAGE = "AMD fabric driver not yet implemented"


class AmdFabricDriver(BaseFabricDriver):
    """Stub AMD fabric backend.

    Every operation routes through a single raising helper until the real
    implementation lands, so callers always observe ``DriverNotSupported``
    with one consistent message.
    """

    @staticmethod
    def _unsupported():
        # Single raise site keeps the stub message uniform across methods.
        raise DriverNotSupported(_NOT_IMPLEMENTED_MESSAGE)

    def initialize(self, device_ordinal: int) -> None:
        self._unsupported()

    def allocate_exportable(self, size: int) -> LocalAllocation:
        self._unsupported()

    def export_handle(self, allocation: LocalAllocation) -> bytes:
        self._unsupported()

    def import_and_map(self, peer_rank: int, handle_bytes: bytes, size: int) -> PeerMapping:
        self._unsupported()

    def cleanup_import(self, mapping: PeerMapping) -> None:
        self._unsupported()

    def cleanup_local(self, allocation: LocalAllocation) -> None:
        self._unsupported()
Loading
Loading