Skip to content

Commit 954bd99

Browse files
szaman19Copilot
andauthored
Fixes for OGB experiments (#18)
* Fix for single process error on OGB data * Fix for multi-gpu run and updated README * Update experiments/OGB/Readme.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Add locks to account for race condition when setting up the data for the first time * Add keita's suggestion for safe makedir * Patch ranklocal operation to fix device issues * Updated main with new cache retrieval * Add graph cache generation code * Add dataset name to experiment func to load and store the cache * Quick fix for different hidden dims * Modify model to allow for different graph vertex input dimensions The last fix was too quick * Add guard on accuracy calculations for ranks with 0 masks --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 3d19f5e commit 954bd99

8 files changed

Lines changed: 253 additions & 30 deletions

File tree

DGraph/CommunicatorBase.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,6 @@ def get_rank(self) -> int:
2626

2727
def get_world_size(self) -> int:
2828
raise NotImplementedError
29+
30+
def barrier(self):
31+
raise NotImplementedError

DGraph/data/ogbn_datasets.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,25 @@ def __init__(
174174
self._rank = self.comm_object.get_rank()
175175
self._world_size = self.comm_object.get_world_size()
176176

177-
self.dataset = NodePropPredDataset(
178-
name=dname,
179-
)
177+
comm_object.barrier()
178+
# Load the dataset on rank 0
179+
if comm_object.get_rank() == 0:
180+
self.dataset = NodePropPredDataset(
181+
name=dname,
182+
)
183+
# Block until rank 0 loads and processes the data
184+
# For the first time, the code downloads and processes the data
185+
# doing that on all ranks causes a race condition
186+
comm_object.barrier()
187+
# Load the dataset on all other ranks
188+
# This is to use the processed data that was generated by rank 0
189+
# This should account for a race condition
190+
191+
if comm_object.get_rank() != 0:
192+
self.dataset = NodePropPredDataset(
193+
name=dname,
194+
)
195+
comm_object.barrier()
180196
graph_data, labels = self.dataset[0]
181197

182198
self.split_idx = self.dataset.get_idx_split()
@@ -185,7 +201,7 @@ def __init__(
185201
dir_name = dir_name if dir_name is not None else os.getcwd() + "/data"
186202

187203
if not os.path.exists(dir_name):
188-
os.makedirs(dir_name)
204+
os.makedirs(dir_name, exist_ok=True)
189205

190206
cached_graph_file = f"{dir_name}/{dname}_graph_data_{self._world_size}.pt"
191207

DGraph/distributed/RankLocalOps.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def OptimizedRankLocalMaskedGather(
6969
num_features = src.shape[-1]
7070
local_masked_gather(
7171
src,
72-
indices,
73-
rank_mapping,
72+
indices.cuda(),
73+
rank_mapping.cuda(),
7474
output,
7575
bs,
7676
num_src_rows,
@@ -137,13 +137,11 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping):
137137
"""
138138
This function removes duplicates from the indices tensor.
139139
"""
140-
unique_indices = torch.unique(_indices).to(_indices.device)
140+
unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
141141
rank_mapping = rank_mapping.to(_indices.device)
142-
renumbered_indices = torch.zeros_like(_indices)
143-
unique_rank_mapping = torch.zeros_like(unique_indices)
144-
for i, idx in enumerate(unique_indices):
145-
renumbered_indices[_indices == idx] = i
146-
unique_rank_mapping[i] = rank_mapping[_indices == idx][0]
142+
renumbered_indices = inverse_indices
143+
unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device)
144+
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)
147145

148146
return renumbered_indices, unique_indices, unique_rank_mapping
149147

DGraph/distributed/nccl/NCCLBackendEngine.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -689,10 +689,18 @@ def gather(
689689
return output_tensor # type: ignore
690690

691691
def destroy(self) -> None:
692-
if self._initialized:
692+
if NCCLBackendEngine._is_initialized:
693693
# dist.destroy_process_group()
694-
self._initialized = False
694+
NCCLBackendEngine._is_initialized = False
695695

696696
def finalize(self) -> None:
697-
if self._initialized:
697+
if NCCLBackendEngine._is_initialized:
698698
dist.barrier()
699+
700+
def barrier(self) -> None:
701+
if NCCLBackendEngine._is_initialized:
702+
dist.barrier()
703+
else:
704+
raise RuntimeError(
705+
"NCCLBackendEngine is not initialized, cannot call barrier"
706+
)

experiments/OGB/GenerateCache.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
2+
# Produced at the Lawrence Livermore National Laboratory.
3+
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
4+
# the CONTRIBUTORS file. See the top-level LICENSE file for details.
5+
#
6+
# LLNL-CODE-697807.
7+
# All rights reserved.
8+
#
9+
# This file is part of LBANN: Livermore Big Artificial Neural Network
10+
# Toolkit. For details, see http://software.llnl.gov/LBANN or
11+
# https://github.com/LBANN and https://github.com/LLNL/LBANN.
12+
#
13+
# SPDX-License-Identifier: (Apache-2.0)
14+
15+
from DGraph.data.ogbn_datasets import process_homogenous_data
16+
from ogb.nodeproppred import NodePropPredDataset
17+
from fire import Fire
18+
import os
19+
import torch
20+
from DGraph.distributed.nccl._nccl_cache import (
21+
NCCLGatherCacheGenerator,
22+
NCCLScatterCacheGenerator,
23+
)
24+
from time import perf_counter
25+
from tqdm import tqdm
26+
from multiprocessing import get_context
27+
28+
29+
cache_prefix = {
30+
"ogbn-arxiv": "arxiv",
31+
"ogbn-products": "products",
32+
"ogbn-papers100M": "papers100M",
33+
}
34+
35+
36+
def generate_cache_file(
37+
dist_graph,
38+
src_indices,
39+
dst_indices,
40+
edge_placement,
41+
edge_src_placement,
42+
edge_dest_placement,
43+
cache_prefix_str: str,
44+
rank: int,
45+
world_size: int,
46+
):
47+
print(f"Generating cache for rank {rank}...")
48+
local_node_features = dist_graph.get_local_node_features(rank).unsqueeze(0)
49+
num_input_rows = local_node_features.size(1)
50+
51+
print(
52+
f"Rank {rank} has {num_input_rows} input rows with shape {local_node_features.shape}"
53+
)
54+
gather_cache = NCCLGatherCacheGenerator(
55+
dst_indices,
56+
edge_placement,
57+
edge_dest_placement,
58+
num_input_rows,
59+
rank,
60+
world_size,
61+
)
62+
63+
nodes_per_rank = dist_graph.get_nodes_per_rank()
64+
nodes_per_rank = int(nodes_per_rank[rank].item())
65+
66+
scatter_cache = NCCLScatterCacheGenerator(
67+
src_indices,
68+
edge_placement,
69+
edge_src_placement,
70+
nodes_per_rank,
71+
rank,
72+
world_size,
73+
)
74+
print(f"Rank {rank} completed cache generation")
75+
with open(
76+
f"{cache_prefix_str}_gather_cache_rank_{world_size}_{rank}.pt", "wb"
77+
) as f:
78+
torch.save(gather_cache, f)
79+
80+
with open(
81+
f"{cache_prefix_str}_scatter_cache_rank_{world_size}_{rank}.pt", "wb"
82+
) as f:
83+
torch.save(scatter_cache, f)
84+
return 0
85+
86+
87+
def main(dset: str, world_size: int, node_rank_placement_file: str):
88+
assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"]
89+
90+
assert world_size > 0
91+
assert os.path.exists(
92+
node_rank_placement_file
93+
), "Node rank placement file does not exist."
94+
95+
node_rank_placement = torch.load(node_rank_placement_file)
96+
97+
dataset = NodePropPredDataset(
98+
dset,
99+
)
100+
101+
split_index = dataset.get_idx_split()
102+
assert split_index is not None, "Split index is None."
103+
104+
graph, labels = dataset[0]
105+
106+
num_edges = graph["edge_index"].shape
107+
print(num_edges)
108+
109+
dist_graph = process_homogenous_data(
110+
graph_data=graph,
111+
labels=labels,
112+
world_Size=world_size,
113+
split_idx=split_index,
114+
node_rank_placement=node_rank_placement,
115+
rank=0,
116+
)
117+
118+
edge_indices = dist_graph.get_global_edge_indices()
119+
rank_mappings = dist_graph.get_global_rank_mappings()
120+
121+
print("Edge indices shape:", edge_indices.shape)
122+
print("Rank mappings shape:", rank_mappings.shape)
123+
124+
edge_indices = edge_indices.unsqueeze(0)
125+
src_indices = edge_indices[:, 0, :]
126+
dst_indices = edge_indices[:, 1, :]
127+
128+
edge_placement = rank_mappings[0]
129+
edge_src_placement = rank_mappings[0]
130+
edge_dest_placement = rank_mappings[1]
131+
132+
start_time = perf_counter()
133+
cache_prefix_str = f"cache/{cache_prefix[dset]}"
134+
with get_context("spawn").Pool(min(world_size, 8)) as pool:
135+
args = [
136+
(
137+
dist_graph,
138+
src_indices,
139+
dst_indices,
140+
edge_placement,
141+
edge_src_placement,
142+
edge_dest_placement,
143+
cache_prefix_str,
144+
rank,
145+
world_size,
146+
)
147+
for rank in range(world_size)
148+
]
149+
150+
out = pool.starmap(generate_cache_file, args)
151+
152+
end_time = perf_counter()
153+
print(f"Cache generation time: {end_time - start_time:.4f} seconds")
154+
print("Cache files generated successfully.")
155+
print(
156+
f"Gather cache file: {cache_prefix_str}_gather_cache_rank_{world_size}_<rank>.pt"
157+
)
158+
print(
159+
f"Scatter cache file: {cache_prefix_str}_scatter_cache_rank_{world_size}_<rank>.pt"
160+
)
161+
162+
163+
if __name__ == "__main__":
164+
Fire(main)

experiments/OGB/Readme.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@ DGraph supports distributed training using the `nccl`, `nvshmem`, and `mpi` back
2323
In order to run the experiments with the `nccl` backend, run the following command:
2424

2525
```bash
26-
torchrun --nnodes <nodes> --nproc-per-node <gpus> main.py --backend nccl --lr lr --epochs epochs --runs runs --log_dir log-dir
26+
torchrun-hpc -N <nodes> -n <gpus> main.py --backend nccl --lr lr --epochs epochs --runs runs --node_rank_placement_file <file_dir> --log_dir log-dir
2727
```
28+
You may have to pass ``--xargs=--mpibind=off`` and ``--xargs=--gpu-bind=none`` in your Slurm script to avoid binding issues.
29+
30+
**Note that we use `torchrun-hpc` instead of `torchrun`.** The run command may vary based on your environment.
31+
32+
2833

2934
### Additional Notes
3035
The experiments use some additional libraries. Use the [ogb] option

0 commit comments

Comments
 (0)