
Commit 38ee37b

szaman19 and bvanessen authored
Edge placement update (#11)
Updated API for the NCCL and NVSHMEM backends for DGraph, with running OGB and benchmark code. Incorporates caching to reduce overhead. In detail:

* Add multi-partition function
* Add preprocessing code
* Add padding functionality to NVSHMEM operations, enabling arbitrarily shaped inputs to the function
* Revamp the distributed graph implementation to simplify the code and remove the old slicing-based code; also adds documentation
* Update Graphcast implementation with DGraph distributed
* Incorporate NCCL cache into the backend engine
* Update local tensor getter to use a placement tensor mask
* Add Graphcast distributed trainer
* Fix distributed graph object
* Save intermediate preprocessed file so it can be reused
* Remove unnecessary complexity of file passing and use a single torch tensor
* Update Graphcast static graph generator with preprocessing code
* Update mesh graph placement algorithm
* Fix scatter test with new API
* Update distributed GCN with edge placement tensor
* Add GatherCacheGenerator and ScatterCacheGenerator
* Add Graphcast preprocessing
* Add static method for mesh partitioning
* Add grid vertex placement logic to MeshGraph
* Add OGBN-products update
* Overly complicated but correct graph data
* Add separated-out benchmarking code for small tests
* Fix missing batch dim
* More general fix
* Disable some incomplete NVSHMEM caching optimizations; add code to set the default PyTorch device
* Add NVSHMEM benchmark code; append backend type to output files; add sample plot generation code
* Add torch distributed init with NVSHMEM communicator
* Apply suggestions from code review
* Fix plotting script to grab the right log files
* Fix the cached benchmarks
* Apply review suggestions, remove dead code

Co-authored-by: Brian C. Van Essen <vanessen1@llnl.gov>
1 parent 1692164 · commit 38ee37b

31 files changed: 2730 additions & 517 deletions
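Among the listed changes, padding was added to the NVSHMEM operations so that arbitrarily shaped inputs can be handled. The DGraph padding code itself is not shown in this diff; the sketch below is only a hypothetical illustration of the general idea — zero-pad a tensor along one dimension up to a fixed multiple (as symmetric communication buffers typically require), remembering the original size so the padding can be stripped afterwards:

```python
import torch

def pad_to_multiple(t: torch.Tensor, multiple: int, dim: int = 0):
    """Zero-pad `t` along `dim` to the next multiple of `multiple`.

    Returns the padded tensor and the original size along `dim`,
    so the padding can be stripped after communication.
    """
    size = t.size(dim)
    pad = (-size) % multiple
    if pad == 0:
        return t, size
    pad_shape = list(t.shape)
    pad_shape[dim] = pad
    return torch.cat([t, t.new_zeros(pad_shape)], dim=dim), size

x = torch.ones(5, 3)
padded, orig_size = pad_to_multiple(x, 4, dim=0)  # padded to 8 rows
local = padded.narrow(0, 0, orig_size)            # strip padding again
```

The helper name and return convention here are invented for illustration, not the actual DGraph API.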

DGraph/Communicator.py

Lines changed: 32 additions & 2 deletions
@@ -12,9 +12,9 @@
 #
 # SPDX-License-Identifier: (Apache-2.0)
 import torch
-from DGraph.distributed.mpi import MPIBackendEngine
+
 from DGraph.distributed.nccl import NCCLBackendEngine
-from DGraph.distributed.nvshmem import NVSHMEMBackendEngine
+
 from DGraph.CommunicatorBase import CommunicatorBase

 SUPPORTED_BACKENDS = ["nccl", "mpi", "nvshmem"]
@@ -38,8 +38,12 @@ def __init__(self, backend: str, **kwargs) -> None:
         if backend == "nccl":
             self.__backend_engine = NCCLBackendEngine()
         elif backend == "mpi":
+            from DGraph.distributed.mpi import MPIBackendEngine
+
             self.__backend_engine = MPIBackendEngine(**kwargs)
         elif backend == "nvshmem":
+            from DGraph.distributed.nvshmem import NVSHMEMBackendEngine
+
             self.__backend_engine = NVSHMEMBackendEngine()
         else:
             raise NotImplementedError(f"Backend {backend} not implemented")
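The hunk above defers the MPI and NVSHMEM imports into their `elif` branches, so an environment that lacks those optional dependencies can still construct the NCCL backend. A self-contained sketch of this deferred-import pattern, using stdlib modules as stand-ins for the real backend engines:

```python
def make_serializer(backend: str):
    # Import each optional module only when it is actually selected,
    # mirroring how Communicator defers the MPI/NVSHMEM imports.
    if backend == "json":
        import json          # stand-in for the NCCL backend import
        return json.dumps
    elif backend == "pickle":
        import pickle        # stand-in for the MPI backend import
        return pickle.dumps
    else:
        raise NotImplementedError(f"Backend {backend} not implemented")

print(make_serializer("json")({"a": 1}))  # → {"a": 1}
```

The cost of an unavailable import is paid only on the branch that needs it, which is the point of the change.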
@@ -65,6 +69,32 @@ def get_local_rank_slice(self, tensor: torch.Tensor, dim: int = -1) -> torch.Ten
         self.__check_init()
         return self.__backend_engine.get_local_rank_slice(tensor, dim)

+    def get_local_tensor(
+        self, tensor: torch.Tensor, placement_tensor: torch.Tensor, dim: int = -1
+    ) -> torch.Tensor:
+        """Returns the part of the tensor owned by the current process, as
+        selected by the placement tensor.
+
+        Args:
+            tensor: The tensor to be sliced.
+            placement_tensor: A 1-D tensor of rank indices along `dim`; the
+                elements whose entry equals the current rank are kept.
+            dim: The dimension along which the tensor should be sliced.
+
+        Returns:
+            (torch.Tensor): The local tensor corresponding to the current process.
+        """
+        self.__check_init()
+        mask = (placement_tensor == self.get_rank()).bool()
+        mask_shape = [1] * tensor.ndim
+        mask_shape[dim] = mask.size(0)
+        mask_expanded = mask.view(mask_shape).expand_as(tensor)
+        masked_tensor = tensor[mask_expanded]
+        new_shape = list(tensor.shape)
+        new_shape[dim] = int(mask.sum().item())
+        masked_tensor = masked_tensor.view(new_shape)
+
+        return masked_tensor
+
     def scatter(self, *args, **kwargs) -> torch.Tensor:
         self.__check_init()
         return self.__backend_engine.scatter(*args, **kwargs)
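The new `get_local_tensor` builds a boolean mask from the placement tensor, selects the matching elements, and reshapes them back to the tensor's original rank. A standalone sketch of the same masking logic, with the rank passed in explicitly rather than read from the communicator:

```python
import torch

def local_slice(tensor: torch.Tensor, placement: torch.Tensor,
                rank: int, dim: int = -1) -> torch.Tensor:
    # True where the placement tensor assigns the element along `dim`
    # to this rank.
    mask = (placement == rank).bool()
    mask_shape = [1] * tensor.ndim
    mask_shape[dim] = mask.size(0)
    # Broadcast the 1-D mask across all other dimensions.
    expanded = mask.view(mask_shape).expand_as(tensor)
    selected = tensor[expanded]
    new_shape = list(tensor.shape)
    new_shape[dim] = int(mask.sum().item())
    return selected.view(new_shape)

features = torch.arange(12.0).view(4, 3)   # 4 nodes, 3 features each
placement = torch.tensor([0, 1, 0, 1])     # node -> owning rank
local = local_slice(features, placement, rank=0, dim=0)  # rows 0 and 2
```

The function name is invented for this sketch; the method signature in the commit additionally checks communicator initialization.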

DGraph/CommunicatorBase.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ class CommunicatorBase:
     _is_initialized = False

     def __init__(self):
+        self.backend = ""
         pass

     def init_process_group(self, backend: str, **kwargs):
