Skip to content

Commit 1692164

Browse files
authored
Re-enables NVSHMEM Scatter tests on tester (#10)
* Re-enables NVSHMEM Scatter tests on tester
  - Adds checks for new environment variables when installing
* Fixes issues with NVSHMEM scatter data setup
* Fixes issue with non-zeroed-out tensor init on shared memory
1 parent 67b1ba3 commit 1692164

4 files changed

Lines changed: 46 additions & 14 deletions

File tree

DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,12 @@ def _nvshmem_scatter(input_tensor, indices, rank_mappings, num_output_rows):
5555

5656
num_elem = num_output_rows * num_features
5757

58+
# TODO: Look into using calloc here to avoid zeroing out the tensor
5859
scattered_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory(
5960
num_elem, device.index
6061
).reshape((bs, num_output_rows, num_features))
62+
scattered_tensor.zero_()
63+
6164
cur_rank = nvshmem.NVSHMEMP2P.get_rank()
6265
indices = indices % num_output_rows
6366
local_send_tensor = input_tensor[rank_mappings == cur_rank].unsqueeze(0)

experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,14 @@ def _check_mpi():
3333
print("Checking if MPI_HOME is set")
3434

3535
error_signal = False
36-
if "MPI_HOME" not in os.environ:
37-
print("ERROR: MPI_HOME is not set\n")
36+
37+
usual_MPI_envs_names = ["MPI_HOME", "MPI_ROOT", "MPICH_HOME"]
38+
if not any(env_name in os.environ for env_name in usual_MPI_envs_names):
39+
print("ERROR: One of MPI_HOME, MPI_ROOT, or MPICH_HOME is not set\n")
3840
error_signal = True
3941
else:
40-
print(f"MPI_HOME: {os.environ['MPI_HOME']}\n")
42+
mpi_env_name = [x for x in usual_MPI_envs_names if x in os.environ][0]
43+
print(f"{mpi_env_name}: {os.environ[mpi_env_name]}\n")
4144

4245
return error_signal
4346

setup.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,20 @@
4444
raise EnvironmentError("NVSHMEM_HOME must be set to build DGraph")
4545

4646
# TODO: Try to add the ability to input this path as an argument
47-
if "MPI_HOME" not in os.environ:
47+
usual_MPI_envs_names = ["MPI_HOME", "MPI_ROOT", "MPICH_HOME"]
48+
49+
if not any(env_name in os.environ for env_name in usual_MPI_envs_names):
4850
raise EnvironmentError("MPI_HOME must be set to build DGraph")
4951

52+
mpi_env_name = [x for x in usual_MPI_envs_names if x in os.environ][0]
53+
5054
nvshmem_home = os.environ["NVSHMEM_HOME"]
5155
# print(f"Found NVSHMEM_HOME: {nvshmem_home}")
5256

5357
nvshmem_include = os.path.join(nvshmem_home, "include")
5458
nvshmem_lib = os.path.join(nvshmem_home, "lib")
5559

56-
mpi_home = os.environ["MPI_HOME"]
60+
mpi_home = os.environ[mpi_env_name]
5761
# print(f"Found MPI_HOME: {mpi_home}")
5862
mpi_include = os.path.join(mpi_home, "include")
5963
mpi_lib = os.path.join(mpi_home, "lib")

tests/test_nvshmem_backend.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,28 @@ def setup_scatter_data(init_nvshmem_backend):
6161
torch.manual_seed(0)
6262

6363
num_features = 8
64+
6465
all_rank_input_data = torch.randn(1, 8, num_features)
6566

6667
all_edge_coo = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 3], [1, 2, 3, 0, 3, 0, 3, 0]])
6768
rank_mappings = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1], [0, 1, 1, 0, 1, 0, 1, 0]])
6869

69-
all_rank_output = torch.zeros(2, 4, num_features)
70+
num_global_output_rows = 4
71+
all_rank_output = torch.zeros(2, num_global_output_rows, num_features)
7072

7173
for k in range(2):
7274
_indices = all_edge_coo[k].view(1, -1, 1).expand(1, -1, num_features)
7375
output_data = torch.zeros_like(all_rank_output[[k]])
7476
output_data.scatter_add_(1, _indices, all_rank_input_data)
7577
all_rank_output[k] = output_data
7678

77-
return all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output
79+
return (
80+
all_rank_input_data,
81+
all_edge_coo,
82+
rank_mappings,
83+
all_rank_output,
84+
num_global_output_rows,
85+
)
7886

7987

8088
def test_nvshmem_backend_init(init_nvshmem_backend):
@@ -131,12 +139,13 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
131139
comm = init_nvshmem_backend
132140
rank = comm.get_rank()
133141
world_size = comm.get_world_size()
134-
all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output = (
135-
setup_scatter_data
136-
)
137-
138-
all_edge_coo = all_edge_coo.T
139-
rank_mappings = rank_mappings.T
142+
(
143+
all_rank_input_data,
144+
all_edge_coo,
145+
rank_mappings,
146+
all_rank_output,
147+
num_global_output_rows,
148+
) = setup_scatter_data
140149

141150
input_slice_start = (all_rank_input_data.shape[1] // world_size) * rank
142151
input_slice_end = (all_rank_input_data.shape[1] // world_size) * (rank + 1)
@@ -147,7 +156,11 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
147156
local_input_data_gt = all_rank_input_data[:, input_slice_start:input_slice_end, :]
148157
local_edge_coo = all_edge_coo[:, edge_slice_start:edge_slice_end]
149158
local_rank_mappings_gt = rank_mappings[:, edge_slice_start:edge_slice_end]
150-
local_output_data_gt = all_rank_output[:, edge_slice_start:edge_slice_end, :]
159+
160+
output_slice_start = (num_global_output_rows // world_size) * rank
161+
output_slice_end = (num_global_output_rows // world_size) * (rank + 1)
162+
local_output_data_gt = all_rank_output[:, output_slice_start:output_slice_end, :]
163+
num_output_rows = local_output_data_gt.shape[1]
151164

152165
for i in range(2):
153166
local_indices_gt = local_edge_coo[[i], :]
@@ -161,3 +174,12 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
161174

162175
local_input_data = comm.get_local_rank_slice(all_rank_input_data, dim=1)
163176
assert torch.allclose(local_input_data, local_input_data_gt)
177+
178+
scattered_tensor = comm.scatter(
179+
local_input_data.cuda(),
180+
local_indices.cuda(),
181+
local_rank_mapping.cuda(),
182+
num_output_rows,
183+
)
184+
185+
assert torch.allclose(scattered_tensor, local_output_data_gt[[i]].cuda())

0 commit comments

Comments (0)