Skip to content

Commit 3d19f5e

Browse files
authored
NVSHMEM initialization test code (#17)
* Experiments running after build system change
* Adding init test code to run with torchrun-hpc
* Update memory allocation to allow for non-NVLink transport for inter-node NVSHMEM
* Adding unit test to check NVSHMEM init matches with NCCL
* Adding readme for NVSHMEM init script
1 parent 1d5c7ef commit 3d19f5e

4 files changed

Lines changed: 112 additions & 5 deletions

File tree

DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@ def _nvshmmem_gather(send_tensor, indices, rank_mappings):
2626
num_output_rows = indices.shape[1]
2727
num_features = send_tensor.shape[2]
2828

29-
gathered_tensor = torch.zeros((bs, num_output_rows, num_features)).to(
30-
send_tensor.device
29+
num_elem = bs * num_output_rows * num_features
30+
gathered_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory(
31+
num_elem, send_tensor.device.index
3132
)
33+
gathered_tensor.fill_(0).float()
34+
gathered_tensor = gathered_tensor.reshape((bs, num_output_rows, num_features))
35+
3236
# Gather the tensors
3337

3438
# TODO: Add an option to cache the max value
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import DGraph.Communicator as Comm
2+
import torch.distributed as dist
3+
import torch
4+
import DGraph.torch_nvshmem_p2p as nvshmem
5+
6+
7+
def main():
    """Smoke-test NVSHMEM initialization against the NCCL process group.

    Initializes the DGraph communicator with the NVSHMEM backend, verifies
    that its rank/world size agree with the already-initialized NCCL process
    group, then exercises symmetric-memory allocation and a `dist_get`
    all-gather-style transfer, asserting the received values.

    Assumes the launcher (e.g. torchrun-hpc) has already initialized
    `torch.distributed` with the NCCL backend before this script runs.
    """
    comm = Comm.Communicator.init_process_group("nvshmem")
    rank = comm.get_rank()
    world_size = comm.get_world_size()

    dist.barrier()
    # NVSHMEM topology must match NCCL's, or the two backends disagree about
    # which peer is which.
    assert dist.is_initialized(), "NCCL process group not initialized"
    assert dist.get_backend() == "nccl", "expected the NCCL backend"
    assert dist.get_rank() == rank, "NCCL process group rank mismatch"
    assert dist.get_world_size() == world_size, "NCCL process group world size mismatch"

    dist.barrier()
    # Serialize the check-in prints so ranks don't interleave output.
    for i in range(world_size):
        if rank == i:
            print(
                f"Rank {rank} checking in. ",
                f"Number of available GPUs: {torch.cuda.device_count()}",
            )
        dist.barrier()

    # Set device for this process (one GPU per local rank).
    local_rank = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Allocate a symmetric tensor with one slot per rank and stamp it with
    # this rank's id.  NOTE: `.fill_()` mutates in place; do NOT chain
    # `.float()` here — its result would be a new, non-symmetric tensor and
    # would be silently discarded.
    num_elements = world_size
    nvshmem_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory(
        num_elements, local_rank
    )
    nvshmem_tensor.fill_(rank)

    dist.barrier()
    for i in range(world_size):
        if rank == i:
            print(
                f"Rank {rank}: ",
                f"Tensor: {nvshmem_tensor}",
            )
        dist.barrier()
    assert torch.allclose(
        nvshmem_tensor, torch.full((num_elements,), rank).cuda().float()
    ), "Tensor values do not match expected values"

    # Pull element i from rank i: after dist_get, output_tensor[i] == i.
    indices = torch.arange(num_elements, dtype=torch.int64).cuda()
    output_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory(
        num_elements, local_rank
    )
    output_tensor.fill_(0)
    ranks = torch.arange(world_size).cuda()
    nvshmem.NVSHMEMP2P.dist_get(
        nvshmem_tensor, output_tensor, indices, ranks, 1, num_elements, 1, num_elements
    )
    dist.barrier()
    for i in range(world_size):
        if rank == i:
            print(
                f"Rank {rank}: ",
                f"Output Tensor: {output_tensor}",
            )
        dist.barrier()

    assert torch.allclose(
        output_tensor, torch.arange(world_size).cuda().float()
    ), "Output tensor values do not match expected values"


if __name__ == "__main__":
    main()

experiments/NVSHMEM-Enabled-DGRAPH/README.md

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ Enabling the NVSHMEM backend in DGraph requires MPI and NVSHMEM to be installed
44

55
## Pre-requisites
66

7-
DGraph must be built with NVSHMEM, MPI, and CUDA in order to use the NVSHMEM backend. The
8-
setup script will install the appropriate submodules but the dependencies must be installed
9-
and available on the system.
7+
DGraph must be built with NVSHMEM, MPI, and CUDA in order to use the NVSHMEM backend. The setup script will install the appropriate submodules but the dependencies must be installed and available on the system.
108

119
DGraph searches for NVSHMEM, MPI, and CUDA based on the following environment variables:
1210
- `NVSHMEM_HOME`
@@ -40,3 +38,19 @@ NVSHMEM compilation information can be usually found by running the `nvshmem-inf
4038
$NVSHMEM_HOME/bin/nvshmem-info -b
4139
```
4240

41+
## Building DGraph with NVSHMEM
42+
To build DGraph with NVSHMEM, make sure the environment variables are set and run the following command:
43+
44+
```shell
45+
pip install -e .
46+
```
47+
48+
## Running DGraph with NVSHMEM
49+
50+
DGraph builds on top of the PyTorch distributed package, so it is important to initialize the PyTorch distributed package before using DGraph, and also to initialize MPI. Using a distributed launcher simplifies this process. We recommend using [`torchrun-hpc`](https://github.com/lbann/HPC-launcher).
51+
52+
```shell
53+
torchrun-hpc -N<NUM_NODES> -n<NUM_PROCESSES_PER_NODE> NVSHMEM_init.py
54+
```
55+
56+
The script assumes that the launcher starts the processes and that `torch.distributed` is initialized. If not using a launcher, you must initialize the PyTorch distributed package yourself.

tests/test_nvshmem_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,21 @@ def test_nvshmem_backend_init(init_nvshmem_backend):
9191
print(f"Rank: {rank}")
9292

9393

94+
def test_nvshmem_backend_dist_init(init_nvshmem_backend):
    """Check that NVSHMEM backend init agrees with the NCCL process group.

    Verifies rank and world size reported by the NVSHMEM communicator match
    those of the `torch.distributed` (NCCL) process group.  If NCCL is not
    initialized there is nothing to compare against, so the test exits early.
    """
    comm = init_nvshmem_backend
    rank = comm.get_rank()
    world_size = comm.get_world_size()

    if not dist.is_initialized():
        # Nothing to compare against; bail out.  A bare `return` (not
        # `return True`) — pytest warns on non-None returns from tests.
        print("NCCL process group not initialized, skipping test...")
        return
    # The guard above already established dist.is_initialized().
    assert dist.get_rank() == rank, "NCCL process group rank mismatch"
    assert dist.get_world_size() == world_size, "NCCL process group world size mismatch"
94109
def test_nvshmem_backend_gather(init_nvshmem_backend, setup_gather_data):
95110
comm: Comm.Communicator = init_nvshmem_backend
96111
all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output = (

0 commit comments

Comments
 (0)