Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 274 additions & 1 deletion gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
"""

import math
import queue
import sys
import threading
import time
from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import Callable, Optional, Union
from typing import Callable, Final, Optional, Union

import torch
from graphlearn_torch.channel import RemoteReceivingChannel, ShmChannel
Expand Down Expand Up @@ -51,9 +53,55 @@

logger = Logger()

# Identity-compared marker (checked with ``is``): the background collation
# worker enqueues this object to tell the consumer the epoch is exhausted.
_COLLATION_SENTINEL: Final = object()  # Signals end-of-epoch to consumer

# NOTE(review): usage is outside this view — presumably a fallback CPU-thread
# count for producer/sampler setup; confirm against the rest of the file.
DEFAULT_NUM_CPU_THREADS = 2


class TimingStats:
    """Collects per-key timing samples and renders an aligned text report."""

    def __init__(self, name: str):
        # Report title shown in the summary header.
        self._name = name
        self._totals: dict[str, float] = defaultdict(float)
        self._counts: dict[str, int] = defaultdict(int)
        self._mins: dict[str, float] = {}
        self._maxs: dict[str, float] = {}
        # Keys in first-seen order so the report layout is stable across runs.
        self._order: list[str] = []

    def record(self, key: str, elapsed: float) -> None:
        """Accumulate one measurement (seconds) under ``key``."""
        if key not in self._totals:
            self._order.append(key)
        self._totals[key] += elapsed
        self._counts[key] += 1
        prev_min = self._mins.get(key)
        if prev_min is None or elapsed < prev_min:
            self._mins[key] = elapsed
        prev_max = self._maxs.get(key)
        if prev_max is None or elapsed > prev_max:
            self._maxs[key] = elapsed

    def summary(self) -> str:
        """Render every recorded key as a banner-delimited, aligned report."""
        rule = "=" * 80
        out = [f"\n{rule}", f" {self._name} — Timing Summary", f"{rule}"]
        for key in self._order:
            n = self._counts[key]
            total = self._totals[key]
            if n == 1:
                # A single sample needs no aggregate columns.
                out.append(f" {key:<45s} {total:>10.4f}s")
                continue
            mean = total / n if n > 0 else 0
            out.append(
                f" {key:<45s} total={total:>10.4f}s n={n:>6d} "
                f"avg={mean:>8.4f}s min={self._mins.get(key, 0):>8.4f}s "
                f"max={self._maxs.get(key, 0):>8.4f}s"
            )
        out.append(f"{rule}\n")
        return "\n".join(out)


# We don't see logs for graph store mode for whatever reason.
# TODO(#442): Revert this once the GCP issues are resolved.
def _flush() -> None:
Expand Down Expand Up @@ -114,6 +162,12 @@ class BaseDistLoader(DistLoader):
``batch_index * process_start_gap_seconds`` before dispatching.
Only applies to graph store mode. Defaults to ``None``
(no staggering).
background_collation_queue_size: If set to a positive integer, enables
background collation in a daemon thread. The collation of sampled
messages (via ``_collate_fn``) is performed in a background thread,
overlapping with GPU training. The value controls the maximum number
of pre-collated batches buffered in memory. ``None`` disables
background collation (default behavior).
"""

@staticmethod
Expand Down Expand Up @@ -220,6 +274,7 @@ def __init__(
producer: Union[DistSamplingProducer, Callable[..., int]],
sampler_options: SamplerOptions,
process_start_gap_seconds: float = 60.0,
background_collation_queue_size: Optional[int] = None,
max_concurrent_producer_inits: Optional[int] = None,
):
if max_concurrent_producer_inits is None:
Expand All @@ -229,6 +284,30 @@ def __init__(
# Will be set to False once connections are initialized.
self._shutdowned = True

# --- Background collation setup (validate early, before heavy init) ---
if (
background_collation_queue_size is not None
and background_collation_queue_size < 1
):
raise ValueError(
f"background_collation_queue_size must be >= 1 if provided, "
f"got {background_collation_queue_size}"
)
self._background_collation_queue_size = background_collation_queue_size
self._collation_thread: Optional[threading.Thread] = None
self._collated_queue: Optional[queue.Queue] = None
self._collation_stop_event: Optional[threading.Event] = None

# --- Timing instrumentation ---
mode_label = (
"background_collation"
if background_collation_queue_size is not None
else "synchronous"
)
self._timing = TimingStats(f"BaseDistLoader ({mode_label})")
self._epoch_start_time: Optional[float] = None
self._log_timing_every_n_batches: Final[int] = 10

# Store dataset metadata for subclass _collate_fn usage
self._is_homogeneous_with_labeled_edge_type = (
dataset_schema.is_homogeneous_with_labeled_edge_type
Expand Down Expand Up @@ -799,10 +878,181 @@ def _init_graph_store_connections(
)
_flush()

# --- Background collation methods ---

def _maybe_log_timing(self) -> None:
    """Emit the accumulated timing report every N received batches."""
    batches_seen = self._num_recv
    if batches_seen % self._log_timing_every_n_batches != 0:
        return
    logger.info(self._timing.summary())
    _flush()

def __next__(self):  # type: ignore[override]
    """Returns the next collated batch.

    Dispatches to the background-collation consumer when that mode is
    enabled; otherwise performs the synchronous recv-then-collate sequence
    (replicated from GLT ``DistLoader``), timing each phase.

    Returns:
        A ``Data`` or ``HeteroData`` batch.

    Raises:
        StopIteration: When the epoch is exhausted.
    """
    if self._background_collation_queue_size is not None:
        return self._next_from_background_collation()

    # Synchronous path: the epoch ends once the expected count is reached.
    if self._num_recv == self._num_expected:
        logger.info(
            f"[sync] Epoch done. Total batches: {self._num_recv}, "
            f"epoch wall time: {time.time() - (self._epoch_start_time or 0.0):.2f}s"
        )
        logger.info(self._timing.summary())
        _flush()
        raise StopIteration

    started = time.time()
    if self._with_channel:
        sampled_msg = self._channel.recv()
    else:
        sampled_msg = self._collocated_producer.sample()
    received_at = time.time()
    self._timing.record("sync/recv", received_at - started)

    batch = self._collate_fn(sampled_msg)
    collated_at = time.time()
    self._timing.record("sync/collate_fn", collated_at - received_at)
    self._timing.record("sync/total_next", collated_at - started)

    self._num_recv += 1
    self._maybe_log_timing()
    return batch

def _next_from_background_collation(self):
    """Pulls the next pre-collated batch off the background queue.

    Returns:
        A ``Data`` or ``HeteroData`` batch.

    Raises:
        StopIteration: When the worker's end-of-epoch sentinel is seen.
    """
    assert self._collated_queue is not None
    waited_from = time.time()
    # Snapshot depth first so we know how full the buffer was before get().
    depth_before = self._collated_queue.qsize()
    item = self._collated_queue.get()
    got_at = time.time()
    self._timing.record("bg_consumer/queue_get", got_at - waited_from)
    # NOTE: this records a queue depth (a count, not seconds) into TimingStats.
    self._timing.record("bg_consumer/queue_size_at_get", depth_before)

    if item is _COLLATION_SENTINEL:
        logger.info(
            f"[bg] Epoch done. Total batches: {self._num_recv}, "
            f"epoch wall time: {time.time() - (self._epoch_start_time or 0.0):.2f}s"
        )
        logger.info(self._timing.summary())
        _flush()
        raise StopIteration
    # The worker forwards any exception it hit; re-raise it on the consumer.
    if isinstance(item, BaseException):
        raise item

    self._num_recv += 1
    self._maybe_log_timing()
    return item

def _collation_worker(self) -> None:
    """Target function for the background collation daemon thread.

    Continuously receives messages from the channel (or collocated
    producer) and runs ``_collate_fn``, placing collated results into
    ``_collated_queue``. Exits when the epoch batch count is reached,
    a ``StopIteration`` is received from the channel, or the stop event
    is set. Any other exception is forwarded to the consumer via the
    queue (re-raised in ``_next_from_background_collation``).
    """
    assert self._collated_queue is not None
    assert self._collation_stop_event is not None
    num_produced = 0
    try:
        while True:
            # For finite epochs, exit after producing all expected batches
            # (the sentinel tells the consumer the epoch is over).
            if (
                self._num_expected != float("inf")
                and num_produced >= self._num_expected
            ):
                self._collated_queue.put(_COLLATION_SENTINEL)
                return

            # Receive next sampled message
            t0 = time.time()
            try:
                if self._with_channel:
                    msg = self._channel.recv()
                else:
                    msg = self._collocated_producer.sample()
            except StopIteration:
                # Upstream signalled end-of-epoch; translate to the sentinel.
                self._collated_queue.put(_COLLATION_SENTINEL)
                return
            t_recv = time.time()
            self._timing.record("bg_producer/recv", t_recv - t0)

            # Check stop event between recv and collate. On stop, the message
            # just received is dropped and NO sentinel is enqueued — assumes
            # the consumer is not blocked on get() during shutdown.
            if self._collation_stop_event.is_set():
                return

            result = self._collate_fn(msg)
            t_collate = time.time()
            self._timing.record("bg_producer/collate_fn", t_collate - t_recv)

            # Bounded put(): blocks when the buffer is full, providing
            # backpressure against the training loop.
            self._collated_queue.put(result)
            t_put = time.time()
            self._timing.record("bg_producer/queue_put", t_put - t_collate)
            self._timing.record("bg_producer/total_iteration", t_put - t0)

            num_produced += 1
    except Exception as e:
        # Forward the failure to the consumer thread instead of dying silently.
        self._collated_queue.put(e)

def _start_collation_thread(self) -> None:
    """Creates and starts a fresh background collation thread."""
    assert self._background_collation_queue_size is not None
    # Fresh primitives each start so no stale items/flags leak across epochs.
    self._collation_stop_event = threading.Event()
    self._collated_queue = queue.Queue(
        maxsize=self._background_collation_queue_size
    )
    worker = threading.Thread(target=self._collation_worker, daemon=True)
    self._collation_thread = worker
    worker.start()

def _stop_collation_thread(self) -> None:
    """Stops the background collation thread if it is running.

    Sets the stop event, then repeatedly drains the queue while joining
    in short intervals. A single drain is not sufficient: after being
    unblocked from ``put()`` the worker may enqueue one more item (e.g.
    the end-of-epoch sentinel into a size-1 queue) and block again
    before ``join()`` returns, so the original drain-once-then-join
    pattern could spuriously hit the timeout. Gives up after a
    10-second deadline and logs a warning.
    """
    if self._collation_thread is None or not self._collation_thread.is_alive():
        return
    assert self._collation_stop_event is not None
    assert self._collated_queue is not None

    self._collation_stop_event.set()
    # Use a monotonic clock for the deadline (immune to wall-clock jumps).
    deadline = time.monotonic() + 10.0
    while self._collation_thread.is_alive() and time.monotonic() < deadline:
        # Drain the queue to unblock the worker if it's blocked on put().
        while True:
            try:
                self._collated_queue.get_nowait()
            except queue.Empty:
                break
        self._collation_thread.join(timeout=0.1)
    if self._collation_thread.is_alive():
        logger.warning(
            "Background collation thread did not terminate within 10 seconds."
        )

# Overwrite DistLoader.shutdown so we can use our own shutdown and rpc calls
def shutdown(self) -> None:
if self._shutdowned:
return
if self._background_collation_queue_size is not None:
self._stop_collation_thread()
if self._is_collocated_worker:
self._collocated_producer.shutdown()
elif self._is_mp_worker:
Expand All @@ -821,6 +1071,25 @@ def shutdown(self) -> None:

# Overwrite DistLoader.__iter__ so we can use our own __iter__ and rpc calls
def __iter__(self) -> Self:
if self._background_collation_queue_size is not None:
self._stop_collation_thread()

# Log previous epoch timing (if any) and reset for new epoch
if self._epoch > 0:
logger.info(
f"[iter] Resetting for epoch {self._epoch}. " f"Previous epoch timing:"
)
logger.info(self._timing.summary())
_flush()

mode_label = (
"background_collation"
if self._background_collation_queue_size is not None
else "synchronous"
)
self._timing = TimingStats(f"BaseDistLoader ({mode_label}) epoch={self._epoch}")
self._epoch_start_time = time.time()

self._num_recv = 0
if self._is_collocated_worker:
self._collocated_producer.reset()
Expand All @@ -841,4 +1110,8 @@ def __iter__(self) -> Self:
torch.futures.wait_all(rpc_futures)
self._channel.reset()
self._epoch += 1

if self._background_collation_queue_size is not None:
self._start_collation_thread()

return self
8 changes: 8 additions & 0 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate this
local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this
local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this
background_collation_queue_size: Optional[int] = None,
):
"""
Neighbor loader for Anchor Based Link Prediction (ABLP) tasks.
Expand Down Expand Up @@ -215,6 +216,12 @@ def __init__(
context (deprecated - will be removed soon) (Optional[DistributedContext]): Distributed context information of the current process.
local_process_rank (deprecated - will be removed soon) (int): The local rank of the current process within a node.
local_process_world_size (deprecated - will be removed soon) (int): The total number of processes within a node.
background_collation_queue_size (Optional[int]): If set to a positive
integer, enables background collation in a daemon thread. The
collation of sampled messages is performed in a background thread,
overlapping with GPU training. The value controls the maximum
number of pre-collated batches buffered in memory. ``None``
disables background collation (default behavior).
"""

# Set self._shutdowned right away, that way if we throw here, and __del__ is called,
Expand Down Expand Up @@ -385,6 +392,7 @@ def __init__(
producer=producer,
sampler_options=sampler_options,
process_start_gap_seconds=process_start_gap_seconds,
background_collation_queue_size=background_collation_queue_size,
max_concurrent_producer_inits=max_concurrent_producer_inits,
)

Expand Down
Loading