Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
475 changes: 253 additions & 222 deletions gigl/distributed/base_dist_loader.py

Large diffs are not rendered by default.

19 changes: 6 additions & 13 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import abc, defaultdict
from itertools import count
from typing import Callable, Optional, Union
from typing import Optional, Union

import torch
from graphlearn_torch.channel import SampleMessage
Expand All @@ -21,7 +21,6 @@
PPR_WEIGHT_METADATA_KEY,
)
from gigl.distributed.dist_sampling_producer import DistSamplingProducer
from gigl.distributed.graph_store.dist_server import DistServer
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler import (
NEGATIVE_LABEL_METADATA_KEY,
Expand Down Expand Up @@ -353,24 +352,19 @@ def __init__(
drop_last=drop_last,
)

# Build the producer: a pre-constructed producer for colocated mode,
# or an RPC callable for graph store mode.
producer: Optional[DistSamplingProducer] = None
if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
assert isinstance(dataset, DistDataset)
assert isinstance(worker_options, MpDistSamplingWorkerOptions)
channel = BaseDistLoader.create_colocated_channel(worker_options)
producer: Union[
DistSamplingProducer, Callable[..., int]
] = DistSamplingProducer(
producer = DistSamplingProducer(
data=dataset,
sampler_input=sampler_input,
sampling_config=sampling_config,
worker_options=worker_options,
channel=channel,
sampler_options=sampler_options,
)
else:
producer = DistServer.create_sampling_producer

# Call base class — handles metadata storage and connection initialization
# (including staggered init for colocated mode).
Expand Down Expand Up @@ -616,13 +610,12 @@ def _setup_for_graph_store(
edge_feature_info = dataset.fetch_edge_feature_info()
edge_types = dataset.fetch_edge_types()
compute_rank = torch.distributed.get_rank()
worker_key = (
f"compute_ablp_loader_rank_{compute_rank}_worker_{self._instance_count}"
)
backend_key = f"dist_ablp_loader_{self._instance_count}"
worker_key = f"{backend_key}_compute_rank_{compute_rank}"
logger.info(f"rank: {compute_rank}, worker_key: {worker_key}")
worker_options = BaseDistLoader.create_graph_store_worker_options(
dataset=dataset,
compute_rank=compute_rank,
loader_port_index=self._instance_count * 2 + 1,
worker_key=worker_key,
num_workers=num_workers,
worker_concurrency=worker_concurrency,
Expand Down
Loading