Skip to content
Draft
57 changes: 54 additions & 3 deletions gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@
R = TypeVar("R")


def _slice_nodes_for_shard(
nodes: torch.Tensor, shard_index: int, num_shards: int
) -> torch.Tensor:
"""Return a contiguous local shard with ``torch.tensor_split`` semantics."""
Comment thread
kmontemayor2-sc marked this conversation as resolved.
Outdated
if num_shards <= 0:
raise ValueError(f"num_shards must be > 0, received {num_shards}")
if shard_index < 0 or shard_index >= num_shards:
raise ValueError(
"shard_index must be in [0, num_shards). "
f"Received shard_index={shard_index}, num_shards={num_shards}"
)

num_nodes = nodes.size(0)
base_shard_size = num_nodes // num_shards
remainder = num_nodes % num_shards
start = shard_index * base_shard_size + min(shard_index, remainder)
Comment thread
kmontemayor2-sc marked this conversation as resolved.
Outdated
length = base_shard_size + (1 if shard_index < remainder else 0)
end = start + length
return nodes[start:end]


class DistServer:
r"""A server that supports launching remote sampling workers for
training clients.
Expand Down Expand Up @@ -305,6 +326,8 @@ def get_node_ids(
node_type=request.node_type,
rank=request.rank,
world_size=request.world_size,
shard_index=request.shard_index,
num_shards=request.num_shards,
)

def _get_node_ids(
Expand All @@ -313,6 +336,8 @@ def _get_node_ids(
node_type: Optional[NodeType],
rank: Optional[int] = None,
world_size: Optional[int] = None,
shard_index: Optional[int] = None,
num_shards: Optional[int] = None,
Comment thread
kmontemayor2-sc marked this conversation as resolved.
Outdated
) -> torch.Tensor:
"""Core implementation for fetching node IDs by split, type, and sharding.

Expand All @@ -325,20 +350,34 @@ def _get_node_ids(
with ``world_size``.
world_size: Total number of processes for sharding. Must be
provided together with ``rank``.
shard_index: Local shard index for storage-rank-local sharding.
Must be provided together with ``num_shards``.
num_shards: Total number of local shards for this storage rank.
Must be provided together with ``shard_index``.

Returns:
The node IDs tensor, optionally sharded by rank.

Raises:
ValueError: If rank/world_size are not provided together, the
split is invalid, or the node type is inconsistent with
the dataset type (homogeneous vs. heterogeneous).
ValueError: If the sharding inputs are invalid, the split is
invalid, or the node type is inconsistent with the dataset
type (homogeneous vs. heterogeneous).
"""
if (rank is None) ^ (world_size is None):
raise ValueError(
"rank and world_size must be provided together. "
f"Received rank={rank}, world_size={world_size}"
)
if (shard_index is None) ^ (num_shards is None):
raise ValueError(
"shard_index and num_shards must be provided together. "
f"Received shard_index={shard_index}, num_shards={num_shards}"
)
if rank is not None and shard_index is not None:
raise ValueError(
"rank/world_size and shard_index/num_shards are mutually exclusive. "
f"Received rank={rank}, world_size={world_size}, shard_index={shard_index}, num_shards={num_shards}"
)

if split == "train":
nodes = self.dataset.train_node_ids
Expand Down Expand Up @@ -367,6 +406,8 @@ def _get_node_ids(

if rank is not None and world_size is not None:
return shard_nodes_by_process(nodes, rank, world_size)
if shard_index is not None and num_shards is not None:
return _slice_nodes_for_shard(nodes, shard_index, num_shards)
return nodes

def get_edge_types(self) -> Optional[list[EdgeType]]:
Expand Down Expand Up @@ -419,10 +460,20 @@ def get_ablp_input(
node_type=request.node_type,
rank=request.rank,
world_size=request.world_size,
shard_index=request.shard_index,
num_shards=request.num_shards,
)
positive_label_edge_type, negative_label_edge_type = select_label_edge_types(
request.supervision_edge_type, self.dataset.get_edge_types()
)
if anchors.numel() == 0:
Comment thread
kmontemayor2-sc marked this conversation as resolved.
Outdated
empty_positive_labels = torch.empty(0, 0, dtype=torch.int64)
empty_negative_labels = (
torch.empty(0, 0, dtype=torch.int64)
if negative_label_edge_type is not None
else None
)
return anchors, empty_positive_labels, empty_negative_labels
positive_labels, negative_labels = get_labels_for_anchor_nodes(
self.dataset, anchors, positive_label_edge_type, negative_label_edge_type
)
Expand Down
74 changes: 60 additions & 14 deletions gigl/distributed/graph_store/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,38 @@
from gigl.src.common.types.graph_data import EdgeType, NodeType


def _validate_sharding_mode(
rank: Optional[int],
world_size: Optional[int],
shard_index: Optional[int],
num_shards: Optional[int],
) -> None:
"""Validate that requests use at most one sharding mode."""
if (rank is None) ^ (world_size is None):
raise ValueError(
"rank and world_size must be provided together. "
f"Received rank={rank}, world_size={world_size}"
)
if (shard_index is None) ^ (num_shards is None):
raise ValueError(
"shard_index and num_shards must be provided together. "
f"Received shard_index={shard_index}, num_shards={num_shards}"
)
if rank is not None and shard_index is not None:
raise ValueError(
"rank/world_size and shard_index/num_shards are mutually exclusive. "
f"Received rank={rank}, world_size={world_size}, shard_index={shard_index}, num_shards={num_shards}"
)
if shard_index is not None and num_shards is not None:
if num_shards <= 0:
raise ValueError(f"num_shards must be > 0, received {num_shards}")
if shard_index < 0 or shard_index >= num_shards:
raise ValueError(
"shard_index must be in [0, num_shards). "
f"Received shard_index={shard_index}, num_shards={num_shards}"
)


@dataclass(frozen=True)
class FetchNodesRequest:
"""Request for fetching node IDs from a storage server.
Expand All @@ -15,6 +47,10 @@ class FetchNodesRequest:
Must be provided together with ``world_size``.
world_size: The total number of processes in the distributed setup.
Must be provided together with ``rank``.
shard_index: The local shard index to fetch from this storage rank.
Must be provided together with ``num_shards``.
num_shards: The total number of local shards for this storage rank.
Must be provided together with ``shard_index``.
split: The split of the dataset to get node ids from.
node_type: The type of nodes to get node ids for.

Expand All @@ -34,20 +70,23 @@ class FetchNodesRequest:

rank: Optional[int] = None
world_size: Optional[int] = None
shard_index: Optional[int] = None
num_shards: Optional[int] = None
split: Optional[Union[Literal["train", "val", "test"], str]] = None
node_type: Optional[NodeType] = None

def validate(self) -> None:
    """Validate that the request has a consistent sharding mode.

    Delegates to ``_validate_sharding_mode``: rank/world_size and
    shard_index/num_shards must each be provided together, and the two
    sharding modes are mutually exclusive.

    Raises:
        ValueError: If the request mixes or partially specifies sharding modes.
    """
    # NOTE: the diff residue that interleaved the pre-change body (the old
    # inline rank/world_size XOR check) has been removed; the shared helper
    # performs that check and the new shard_index/num_shards checks.
    _validate_sharding_mode(
        rank=self.rank,
        world_size=self.world_size,
        shard_index=self.shard_index,
        num_shards=self.num_shards,
    )


@dataclass(frozen=True)
Expand All @@ -62,6 +101,10 @@ class FetchABLPInputRequest:
Must be provided together with ``world_size``.
world_size: The total number of processes in the distributed setup.
Must be provided together with ``rank``.
shard_index: The local shard index to fetch from this storage rank.
Must be provided together with ``num_shards``.
num_shards: The total number of local shards for this storage rank.
Must be provided together with ``shard_index``.

Examples:
Fetch training ABLP input without sharding:
Expand All @@ -78,15 +121,18 @@ class FetchABLPInputRequest:
supervision_edge_type: EdgeType
rank: Optional[int] = None
world_size: Optional[int] = None
shard_index: Optional[int] = None
num_shards: Optional[int] = None

def validate(self) -> None:
    """Validate that the request has a consistent sharding mode.

    Delegates to ``_validate_sharding_mode``: rank/world_size and
    shard_index/num_shards must each be provided together, and the two
    sharding modes are mutually exclusive.

    Raises:
        ValueError: If the request mixes or partially specifies sharding modes.
    """
    # NOTE: the diff residue that interleaved the pre-change body (the old
    # inline rank/world_size XOR check) has been removed; the shared helper
    # performs that check and the new shard_index/num_shards checks.
    _validate_sharding_mode(
        rank=self.rank,
        world_size=self.world_size,
        shard_index=self.shard_index,
        num_shards=self.num_shards,
    )
Loading