Skip to content

Commit a18faef

Browse files
committed
Optimized cache generation with scatter call and inverse_indices
Add a standalone cache-generation script to asynchronously generate and save caches; update the run code to load the pre-saved cache files.
1 parent fd78e19 commit a18faef

3 files changed

Lines changed: 227 additions & 49 deletions

File tree

DGraph/distributed/RankLocalOps.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,13 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping):
137137
"""
138138
This function removes duplicates from the indices tensor.
139139
"""
140-
unique_indices = torch.unique(_indices).to(_indices.device)
140+
unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
141141
rank_mapping = rank_mapping.to(_indices.device)
142-
renumbered_indices = torch.zeros_like(_indices)
143-
unique_rank_mapping = torch.zeros_like(unique_indices)
144-
for i, idx in enumerate(unique_indices):
145-
renumbered_indices[_indices == idx] = i
146-
unique_rank_mapping[i] = rank_mapping[_indices == idx][0]
142+
renumbered_indices = inverse_indices
143+
unique_rank_mapping = torch.zeros_like(
144+
unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device
145+
)
146+
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)
147147

148148
return renumbered_indices, unique_indices, unique_rank_mapping
149149

experiments/OGB/GenerateCache.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
2+
# Produced at the Lawrence Livermore National Laboratory.
3+
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
4+
# the CONTRIBUTORS file. See the top-level LICENSE file for details.
5+
#
6+
# LLNL-CODE-697807.
7+
# All rights reserved.
8+
#
9+
# This file is part of LBANN: Livermore Big Artificial Neural Network
10+
# Toolkit. For details, see http://software.llnl.gov/LBANN or
11+
# https://github.com/LBANN and https://github.com/LLNL/LBANN.
12+
#
13+
# SPDX-License-Identifier: (Apache-2.0)
14+
15+
from DGraph.data.ogbn_datasets import process_homogenous_data
16+
from ogb.nodeproppred import NodePropPredDataset
17+
from fire import Fire
18+
import os
19+
import torch
20+
from DGraph.distributed.nccl._nccl_cache import (
21+
NCCLGatherCacheGenerator,
22+
NCCLScatterCacheGenerator,
23+
)
24+
from time import perf_counter
25+
from tqdm import tqdm
26+
from multiprocessing import get_context
27+
28+
29+
# Short per-dataset prefixes used when naming cache files on disk
# (the part of the OGB dataset name after "ogbn-").
cache_prefix = {
    name: name.split("-", 1)[1]
    for name in ("ogbn-arxiv", "ogbn-products", "ogbn-proteins")
}
34+
35+
36+
def generate_cache_file(
    dist_graph,
    src_indices,
    dst_indices,
    edge_placement,
    edge_src_placement,
    edge_dest_placement,
    cache_prefix_str: str,
    rank: int,
    world_size: int,
):
    """Generate and save the NCCL gather/scatter caches for one rank.

    Intended to run inside a multiprocessing pool worker, one call per rank.

    Args:
        dist_graph: Distributed graph object (project type) providing
            per-rank node features and per-rank node counts.
        src_indices: Global edge source-node indices.
        dst_indices: Global edge destination-node indices.
        edge_placement: Rank owning each edge.
        edge_src_placement: Rank owning each edge's source node.
        edge_dest_placement: Rank owning each edge's destination node.
        cache_prefix_str: Output path prefix, e.g. "cache/arxiv".
        rank: Rank whose caches are generated.
        world_size: Total number of ranks.

    Returns:
        0 on success (``pool.starmap`` collects these as status codes).
    """
    print(f"Generating cache for rank {rank}...")
    local_node_features = dist_graph.get_local_node_features(rank).unsqueeze(0)
    num_input_rows = local_node_features.size(1)

    print(
        f"Rank {rank} has {num_input_rows} input rows with shape {local_node_features.shape}"
    )
    gather_cache = NCCLGatherCacheGenerator(
        dst_indices,
        edge_placement,
        edge_dest_placement,
        num_input_rows,
        rank,
        world_size,
    )

    nodes_per_rank = dist_graph.get_nodes_per_rank()
    nodes_per_rank = int(nodes_per_rank[rank].item())

    scatter_cache = NCCLScatterCacheGenerator(
        src_indices,
        edge_placement,
        edge_src_placement,
        nodes_per_rank,
        rank,
        world_size,
    )
    print(f"Rank {rank} completed cache generation")

    # Make sure the output directory exists; open(..., "wb") will not create it.
    out_dir = os.path.dirname(cache_prefix_str)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # BUGFIX: filenames previously contained a stray "_rank" infix
    # (e.g. "..._gather_cache_rank_{world_size}_{rank}.pt"), which the
    # training script never looks for — it loads
    # f"{prefix}_gather_cache_{world_size}_{rank}.pt". Use the loader's
    # naming so the pre-generated caches are actually picked up.
    with open(
        f"{cache_prefix_str}_gather_cache_{world_size}_{rank}.pt", "wb"
    ) as f:
        torch.save(gather_cache, f)

    with open(
        f"{cache_prefix_str}_scatter_cache_{world_size}_{rank}.pt", "wb"
    ) as f:
        torch.save(scatter_cache, f)
    return 0
85+
86+
87+
def main(dset: str, world_size: int, node_rank_placement_file: str):
    """Pre-generate the NCCL gather/scatter cache files for every rank.

    Loads the OGB dataset, partitions it using the supplied node->rank
    placement, then generates and saves each rank's caches in parallel via
    a spawned process pool.

    Args:
        dset: One of "ogbn-arxiv", "ogbn-products", "ogbn-proteins".
        world_size: Number of ranks to generate caches for (> 0).
        node_rank_placement_file: Path to a torch-saved node->rank placement.
    """
    assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-proteins"]
    assert world_size > 0
    assert os.path.exists(
        node_rank_placement_file
    ), "Node rank placement file does not exist."

    node_rank_placement = torch.load(node_rank_placement_file)

    dataset = NodePropPredDataset(
        dset,
    )

    split_index = dataset.get_idx_split()
    assert split_index is not None, "Split index is None."

    graph, labels = dataset[0]

    # (num_dims, num_edges) — printed for a quick sanity check.
    print(graph["edge_index"].shape)

    # Build the distributed view once on rank 0; the pool workers only read it.
    dist_graph = process_homogenous_data(
        graph_data=graph,
        labels=labels,
        world_Size=world_size,  # NOTE(review): odd capitalization is the API's
        split_idx=split_index,
        node_rank_placement=node_rank_placement,
        rank=0,
    )

    edge_indices = dist_graph.get_global_edge_indices()
    rank_mappings = dist_graph.get_global_rank_mappings()

    print("Edge indices shape:", edge_indices.shape)
    print("Rank mappings shape:", rank_mappings.shape)

    edge_indices = edge_indices.unsqueeze(0)
    src_indices = edge_indices[:, 0, :]
    dst_indices = edge_indices[:, 1, :]

    # Row 0 of rank_mappings places each edge (source placement is the same
    # by construction — kept separate for clarity); row 1 places each edge's
    # destination node.
    edge_placement = rank_mappings[0]
    edge_src_placement = rank_mappings[0]
    edge_dest_placement = rank_mappings[1]

    start_time = perf_counter()
    cache_prefix_str = f"cache/{cache_prefix[dset]}"
    # The workers write into cache/; create it up front so they don't race
    # or fail on a missing directory.
    os.makedirs("cache", exist_ok=True)

    # Spawn (not fork) so each worker starts with a clean CUDA/NCCL state;
    # cap the pool size to avoid oversubscribing the host.
    with get_context("spawn").Pool(min(world_size, 8)) as pool:
        args = [
            (
                dist_graph,
                src_indices,
                dst_indices,
                edge_placement,
                edge_src_placement,
                edge_dest_placement,
                cache_prefix_str,
                rank,
                world_size,
            )
            for rank in range(world_size)
        ]
        statuses = pool.starmap(generate_cache_file, args)
        # Surface worker failures instead of silently discarding them.
        assert all(s == 0 for s in statuses), "Cache generation failed."

    end_time = perf_counter()
    print(f"Cache generation time: {end_time - start_time:.4f} seconds")
    print("Cache files generated successfully.")
    # Advertise the names the training script actually loads
    # (no "_rank" infix — see generate_cache_file).
    print(
        f"Gather cache file: {cache_prefix_str}_gather_cache_{world_size}_<rank>.pt"
    )
    print(
        f"Scatter cache file: {cache_prefix_str}_scatter_cache_{world_size}_<rank>.pt"
    )


if __name__ == "__main__":
    Fire(main)

experiments/OGB/main.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def _run_experiment(
9191
hidden_dims: int = 128,
9292
num_classes: int = 40,
9393
use_cache: bool = False,
94+
dset_name: str = "arxiv",
9495
):
9596
local_rank = comm.get_rank() % torch.cuda.device_count()
9697
print(f"Rank: {local_rank} Local Rank: {local_rank}")
@@ -114,9 +115,9 @@ def _run_experiment(
114115
node_features, edge_indices, rank_mappings, labels = dataset[0]
115116

116117
node_features = node_features.to(device).unsqueeze(0)
117-
edge_indices = edge_indices.to(device)[:, :-1].unsqueeze(0)
118+
edge_indices = edge_indices.to(device).unsqueeze(0)
118119
labels = labels.to(device).unsqueeze(0)
119-
rank_mappings = rank_mappings[:, :-1]
120+
rank_mappings = rank_mappings
120121

121122
if rank == 0:
122123
print("*" * 80)
@@ -144,42 +145,55 @@ def _run_experiment(
144145

145146
if use_cache:
146147
print(f"Rank: {rank} Using Cache. Generating Cache")
147-
start_time = perf_counter()
148-
src_indices = edge_indices[:, 0, :]
149-
dst_indices = edge_indices[:, 1, :]
150-
151-
# This says where the edges are located
152-
edge_placement = rank_mappings[0]
153-
154-
# These say where the source and destination nodes are located
155-
edge_src_placement = rank_mappings[
156-
0
157-
] # Redundant but making explicit for clarity
158-
edge_dest_placement = rank_mappings[1]
159-
160-
num_input_rows = node_features.size(1)
161-
local_num_edges = (edge_placement == rank).sum().item()
162-
163-
if gather_cache is None:
164-
gather_cache = NCCLGatherCacheGenerator(
165-
dst_indices,
166-
edge_placement,
167-
edge_dest_placement,
168-
num_input_rows,
169-
rank,
170-
world_size,
171-
)
172-
if scatter_cache is None:
173-
nodes_per_rank = dataset.graph_obj.get_nodes_per_rank()
174-
175-
scatter_cache = NCCLScatterCacheGenerator(
176-
src_indices,
177-
edge_placement,
178-
edge_src_placement,
179-
nodes_per_rank[rank],
180-
rank,
181-
world_size,
182-
)
148+
149+
# Check if the cache files already exist
150+
cache_prefix = f"cache/{dset_name}"
151+
152+
scatter_cache_file = f"{cache_prefix}_scatter_cache_{world_size}_{rank}.pt"
153+
gather_cache_file = f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt"
154+
155+
if os.path.exists(scatter_cache_file):
156+
scatter_cache = torch.load(scatter_cache_file, weights_only=False)
157+
if os.path.exists(gather_cache_file):
158+
gather_cache = torch.load(gather_cache_file, weights_only=False)
159+
160+
if gather_cache is None or scatter_cache is None:
161+
start_time = perf_counter()
162+
src_indices = edge_indices[:, 0, :]
163+
dst_indices = edge_indices[:, 1, :]
164+
165+
# This says where the edges are located
166+
edge_placement = rank_mappings[0]
167+
168+
# These say where the source and destination nodes are located
169+
edge_src_placement = rank_mappings[
170+
0
171+
] # Redundant but making explicit for clarity
172+
edge_dest_placement = rank_mappings[1]
173+
174+
num_input_rows = node_features.size(1)
175+
local_num_edges = (edge_placement == rank).sum().item()
176+
177+
if gather_cache is None:
178+
gather_cache = NCCLGatherCacheGenerator(
179+
dst_indices,
180+
edge_placement,
181+
edge_dest_placement,
182+
num_input_rows,
183+
rank,
184+
world_size,
185+
)
186+
if scatter_cache is None:
187+
nodes_per_rank = dataset.graph_obj.get_nodes_per_rank()
188+
189+
scatter_cache = NCCLScatterCacheGenerator(
190+
src_indices,
191+
edge_placement,
192+
edge_src_placement,
193+
nodes_per_rank[rank],
194+
rank,
195+
world_size,
196+
)
183197

184198
# Sanity checks for the cache
185199
for key, value in gather_cache.gather_send_local_placement.items():
@@ -208,11 +222,10 @@ def _run_experiment(
208222
end_time = perf_counter()
209223
print(f"Rank: {rank} Cache Generation Time: {end_time - start_time:.4f} s")
210224

211-
if rank == 0:
212-
with open(f"{log_prefix}_gather_cache_{world_size}.pt", "wb") as f:
213-
torch.save(gather_cache, f)
214-
with open(f"{log_prefix}_scatter_cache_{world_size}.pt", "wb") as f:
215-
torch.save(scatter_cache, f)
225+
with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f:
226+
torch.save(gather_cache, f)
227+
with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f:
228+
torch.save(scatter_cache, f)
216229
print(f"Rank: {rank} Cache Generated")
217230

218231
training_times = []
@@ -366,6 +379,7 @@ def main(
366379
log_prefix,
367380
use_cache=use_cache,
368381
num_classes=num_classes,
382+
dset_name=dataset,
369383
)
370384
training_trajectores[i] = training_traj
371385
validation_trajectores[i] = val_traj

0 commit comments

Comments
 (0)