Skip to content

Commit cce5c0e

Browse files
committed
Modified graph cast and benchmark code to log performance benchmarks
1 parent 82b48e9 commit cce5c0e

11 files changed

Lines changed: 181 additions & 92 deletions

File tree

experiments/Benchmarks/TestNCCL.py

Lines changed: 91 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def run_scatter_benchmark(
169169

170170
def main():
171171
parser = argparse.ArgumentParser()
172-
parser.add_argument("--message_size", type=int, default=128)
172+
parser.add_argument("--message_size", type=int, default=2)
173173
parser.add_argument("--benchmark_cache", action="store_true")
174174
parser.add_argument("--num_iters", type=int, default=1000)
175175
parser.add_argument("--log_dir", type=str, default="logs")
@@ -196,92 +196,114 @@ def main():
196196
benchmark.print(f"Running NCCL Benchmark on {world_size} ranks")
197197

198198
# Built-in small-message benchmarks; in the future we can add more
199-
gather_graph_data = get_nccl_gather_benchmark_data(message_size, world_size, device)
200199

201-
benchmark.print("*" * 50)
202-
benchmark.print("Running Gather Benchmark")
203-
times = run_gather_benchmark(benchmark, num_iters, gather_graph_data, cache=None)
204-
205-
benchmark.print("Saving Gather Benchmark Times")
206-
207-
for i in range(world_size):
208-
benchmark.save_np(times, f"{log_dir}/NCCL_gather_times_{i}.npy", rank_to_save=i)
200+
for i in range(1, 20):
201+
message_size *= 2
202+
benchmark.print("*" * 50)
203+
benchmark.print(f"Running NCCL Benchmark for message size {message_size}")
204+
gather_graph_data = get_nccl_gather_benchmark_data(
205+
message_size, world_size, device
206+
)
207+
dist.barrier()
209208

210-
benchmark.print("Gather Benchmark Complete")
211-
benchmark.print("*" * 50)
209+
benchmark.print("Running Gather Benchmark")
210+
times = run_gather_benchmark(
211+
benchmark, num_iters, gather_graph_data, cache=None
212+
)
212213

213-
if benchmark_cache:
214-
edge_placement = gather_graph_data.edge_rank_placement
215-
edge_src_rank = gather_graph_data.edge_src_rank
216-
indices = gather_graph_data.edge_indices
214+
benchmark.print("Saving Gather Benchmark Times")
217215

218-
gather_cache = NCCLGatherCacheGenerator(
219-
indices,
220-
edge_placement.view(-1),
221-
edge_src_rank.view(-1),
222-
1,
223-
rank,
224-
world_size,
216+
benchmark.save_np(
217+
times,
218+
f"{log_dir}/NCCL_gather_times_message_size_{message_size}"
219+
+ f"_world_size_{world_size}.npy",
220+
rank_to_save=0,
225221
)
222+
223+
benchmark.print("Gather Benchmark Complete")
226224
benchmark.print("*" * 50)
227-
benchmark.print("Running Gather Benchmark with Cache")
228-
times = run_gather_benchmark(
229-
benchmark, num_iters, gather_graph_data, cache=gather_cache
230-
)
231225

232-
benchmark.print("Saving Gather Benchmark with Cache Times")
233-
for i in range(world_size):
226+
if benchmark_cache:
227+
edge_placement = gather_graph_data.edge_rank_placement
228+
edge_src_rank = gather_graph_data.edge_src_rank
229+
indices = gather_graph_data.edge_indices
230+
231+
gather_cache = NCCLGatherCacheGenerator(
232+
indices,
233+
edge_placement.view(-1),
234+
edge_src_rank.view(-1),
235+
1,
236+
rank,
237+
world_size,
238+
)
239+
benchmark.print("*" * 50)
240+
benchmark.print("Running Gather Benchmark with Cache")
241+
times = run_gather_benchmark(
242+
benchmark, num_iters, gather_graph_data, cache=gather_cache
243+
)
244+
245+
benchmark.print("Saving Gather Benchmark with Cache Times")
234246
benchmark.save_np(
235-
times, f"{log_dir}/NCCL_gather_with_cache_times_{i}.npy", rank_to_save=i
247+
times,
248+
f"{log_dir}/NCCL_gather_with_cache_message_size_{message_size}"
249+
+ f"_world_size_{world_size}.npy",
250+
rank_to_save=0,
236251
)
237252

238-
benchmark.print("Gather Benchmark with Cache Complete")
239-
benchmark.print("*" * 50)
253+
benchmark.print("Gather Benchmark with Cache Complete")
254+
benchmark.print("*" * 50)
240255

241-
scatter_graph_data = get_nccl_scatter_benchmark_data(
242-
message_size, world_size, device
243-
)
244-
benchmark.print("*" * 50)
245-
benchmark.print("Running Scatter Benchmark")
246-
times = run_scatter_benchmark(benchmark, num_iters, scatter_graph_data, cache=None)
256+
scatter_graph_data = get_nccl_scatter_benchmark_data(
257+
message_size, world_size, device
258+
)
247259

248-
benchmark.print("Saving Scatter Benchmark Times")
249-
for i in range(world_size):
250-
benchmark.save_np(
251-
times, f"{log_dir}/NCCL_scatter_times_{i}.npy", rank_to_save=i
252-
)
260+
benchmark.print("*" * 50)
261+
benchmark.print("Running Scatter Benchmark")
262+
times = run_scatter_benchmark(
263+
benchmark, num_iters, scatter_graph_data, cache=None
264+
)
253265

254-
benchmark.print("Scatter Benchmark Complete")
255-
benchmark.print("*" * 50)
256-
if benchmark_cache:
257-
edge_placement = scatter_graph_data.edge_rank_placement
258-
edge_dest_rank = scatter_graph_data.edge_dest_rank
259-
indices = scatter_graph_data.edge_indices
260-
261-
scatter_cache = NCCLScatterCacheGenerator(
262-
indices,
263-
edge_placement.view(-1),
264-
edge_dest_rank.view(-1),
265-
1,
266-
rank,
267-
world_size,
268-
)
269-
benchmark.print("*" * 50)
270-
benchmark.print("Running Scatter Benchmark with Cache")
271-
times = run_scatter_benchmark(
272-
benchmark, num_iters, scatter_graph_data, cache=scatter_cache
273-
)
266+
benchmark.print("Saving Scatter Benchmark Times")
274267

275-
benchmark.print("Saving Scatter Benchmark with Cache Times")
276-
for i in range(world_size):
277268
benchmark.save_np(
278269
times,
279-
f"{log_dir}/NCCL_scatter_with_cache_times_{i}.npy",
280-
rank_to_save=i,
270+
f"{log_dir}/NCCL_scatter_times_message_size_{message_size}"
271+
+ f"_world_size_{world_size}.npy",
272+
rank_to_save=0,
281273
)
282274

283-
benchmark.print("Scatter Benchmark with Cache Complete")
284-
benchmark.print("*" * 50)
275+
benchmark.print("Scatter Benchmark Complete")
276+
benchmark.print("*" * 50)
277+
if benchmark_cache:
278+
edge_placement = scatter_graph_data.edge_rank_placement
279+
edge_dest_rank = scatter_graph_data.edge_dest_rank
280+
indices = scatter_graph_data.edge_indices
281+
282+
scatter_cache = NCCLScatterCacheGenerator(
283+
indices,
284+
edge_placement.view(-1),
285+
edge_dest_rank.view(-1),
286+
1,
287+
rank,
288+
world_size,
289+
)
290+
benchmark.print("*" * 50)
291+
benchmark.print("Running Scatter Benchmark with Cache")
292+
times = run_scatter_benchmark(
293+
benchmark, num_iters, scatter_graph_data, cache=scatter_cache
294+
)
295+
296+
benchmark.print("Saving Scatter Benchmark with Cache Times")
297+
298+
benchmark.save_np(
299+
times,
300+
f"{log_dir}/NCCL_scatter_with_cache_message_size_{message_size}"
301+
+ f"_world_size_{world_size}.npy",
302+
rank_to_save=0,
303+
)
304+
305+
benchmark.print("Scatter Benchmark with Cache Complete")
306+
benchmark.print("*" * 50)
285307

286308
dist.destroy_process_group()
287309

experiments/Benchmarks/generate_plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,5 @@ def generate_cache_comparison_plot():
7676

7777
if __name__ == "__main__":
7878
generate_plots("nccl")
79-
generate_plots("nvshmem")
79+
# generate_plots("nvshmem")
8080
generate_cache_comparison_plot()

experiments/GraphCast/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,8 @@ Run with benchmarking with the following command:
3838
python main.py --benchmark
3939
```
4040
**Note:** The graph requires a large amount of memory, so it is better to run on the CPU and on a machine with a large amount of memory.
41+
42+
Run with multiple processes per GPU with the following command:
43+
```bash
44+
torchrun-hpc --xargs=--mpibind=off --xargs=--gpu-bind=none train_graphcast.py --is_distributed True --procs_per_gpu 4
45+
```

experiments/GraphCast/data_utils/graphcast_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict):
219219
contigous_edge_mapping, renumbered_edges = torch.sort(meshtogrid_edge_placement)
220220

221221
src_grid_indices = src_grid_indices[renumbered_edges]
222-
grid_vertex_rank_placement = torch.zeros_like(lat_lon_grid_flat)
222+
grid_vertex_rank_placement = torch.zeros_like(lat_lon_grid_flat[:, 0])
223+
223224
for i, rank in enumerate(meshtogrid_edge_placement):
224225
loc = src_grid_indices[i]
225226
grid_vertex_rank_placement[loc] = rank
@@ -254,6 +255,7 @@ def get_mesh2grid_graph(
254255
)
255256

256257
edge_features, src_mesh_indices, dst_grid_indices = m2g_graph
258+
breakpoint()
257259
src_mesh_indices = renumbered_vertices[src_mesh_indices]
258260
dst_grid_indices = renumbered_grid[dst_grid_indices]
259261

experiments/GraphCast/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
self.lat_lon_grid = torch.stack(
8888
torch.meshgrid(self.latitudes, self.longitudes, indexing="ij"), dim=-1
8989
)
90+
9091
self.graph_cast_graph = DistributedGraphCastGraphGenerator(
9192
self.lat_lon_grid,
9293
mesh_level=self.mesh_level,

experiments/GraphCast/layers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from DGraph.Communicator import Communicator
2121
from dist_utils import SingleProcessDummyCommunicator
2222

23+
# class MLPSiLuWithRecompute(nn.Module):
24+
2325

2426
class MeshGraphMLP(nn.Module):
2527
"""MLP for graph processing"""

experiments/GraphCast/model.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,10 @@ def __init__(self, cfg: Config, comm, *args, **kwargs):
330330
)
331331

332332
def forward(
333-
self, input_grid_features: Tensor, static_graph: DistributedGraphCastGraph
333+
self,
334+
input_grid_features: Tensor,
335+
static_graph: DistributedGraphCastGraph,
336+
device: Optional[torch.device] = None,
334337
) -> Tensor:
335338
"""
336339
Args:
@@ -340,18 +343,19 @@ def forward(
340343
Returns:
341344
(Tensor): The predicted output grid
342345
"""
343-
344-
input_grid_features = input_grid_features.squeeze(0)
345-
input_mesh_features = static_graph.mesh_graph_node_features
346-
mesh2mesh_edge_features = static_graph.mesh_graph_edge_features
347-
grid2mesh_edge_features = static_graph.grid2mesh_graph_edge_features
348-
mesh2grid_edge_features = static_graph.mesh2grid_graph_edge_features
349-
mesh2mesh_edge_indices_src = static_graph.mesh_graph_src_indices
350-
mesh2mesh_edge_indices_dst = static_graph.mesh_graph_dst_indices
351-
mesh2grid_edge_indices_src = static_graph.mesh2grid_graph_src_indices
352-
mesh2grid_edge_indices_dst = static_graph.mesh2grid_graph_dst_indices
353-
grid2mesh_edge_indices_src = static_graph.grid2mesh_graph_src_indices
354-
grid2mesh_edge_indices_dst = static_graph.grid2mesh_graph_dst_indices
346+
if device is None:
347+
device = input_grid_features.device
348+
input_grid_features = input_grid_features.squeeze(0).to(device)
349+
input_mesh_features = static_graph.mesh_graph_node_features.to(device)
350+
mesh2mesh_edge_features = static_graph.mesh_graph_edge_features.to(device)
351+
grid2mesh_edge_features = static_graph.grid2mesh_graph_edge_features.to(device)
352+
mesh2grid_edge_features = static_graph.mesh2grid_graph_edge_features.to(device)
353+
mesh2mesh_edge_indices_src = static_graph.mesh_graph_src_indices.to(device)
354+
mesh2mesh_edge_indices_dst = static_graph.mesh_graph_dst_indices.to(device)
355+
mesh2grid_edge_indices_src = static_graph.mesh2grid_graph_src_indices.to(device)
356+
mesh2grid_edge_indices_dst = static_graph.mesh2grid_graph_dst_indices.to(device)
357+
grid2mesh_edge_indices_src = static_graph.grid2mesh_graph_src_indices.to(device)
358+
grid2mesh_edge_indices_dst = static_graph.grid2mesh_graph_dst_indices.to(device)
355359

356360
out = self.embedder(
357361
input_grid_features,

experiments/GraphCast/train_graphcast.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ def main(
6060
comm = Communicator.init_process_group(
6161
_communicator, ranks_per_graph=procs_per_graph
6262
)
63+
mesh_graph_placement = torch.load("mesh_vertex_rank_placement_4.pt")
6364
else:
6465
comm = SingleProcessDummyCommunicator()
66+
mesh_graph_placement = torch.zeros(40962, dtype=torch.int64)
6567
if not use_synthetic_data:
6668
raise NotImplementedError("Real data is not yet supported yet.")
6769

@@ -106,6 +108,7 @@ def main(
106108
dataset = SyntheticWeatherDataset(
107109
channels=[x for x in range(cfg.data.num_channels_climate)],
108110
num_samples_per_year=cfg.data.num_samples_per_year_train,
111+
mesh_vertex_placement=mesh_graph_placement,
109112
num_steps=cfg.data.num_history,
110113
device=torch.device("cpu"),
111114
)
@@ -127,12 +130,12 @@ def main(
127130
break_training = False
128131

129132
for data in dataloader:
130-
in_data = data["invar"]
131-
ground_truth = data["outvar"]
133+
in_data = data["invar"].to(device)
134+
ground_truth = data["outvar"].to(device)
132135

133136
model.train()
134137
optimizer.zero_grad()
135-
predicted_grid = model(in_data, static_graph)
138+
predicted_grid = model(in_data, static_graph, device=device)
136139
loss = compute_loss(ground_truth, predicted_grid, comm)
137140
loss.backward()
138141
optimizer.step()

experiments/OGB/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def _run_experiment(
163163
gather_cache_file = f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt"
164164

165165
if os.path.exists(gather_cache_file):
166+
print(f"Rank: {rank} Loading gather cache from {gather_cache_file}")
166167
gather_cache = torch.load(gather_cache_file, weights_only=False)
167168

168169
if os.path.exists(scatter_cache_file):
@@ -379,6 +380,11 @@ def main(
379380
validation_trajectores = np.zeros((runs, epochs))
380381
validation_accuracies = np.zeros((runs, epochs))
381382
world_size = comm.get_world_size()
383+
384+
dist.barrier()
385+
print(f"Running experiment with {world_size} processes on dataset {dataset}")
386+
print(f"Using cache: {use_cache}")
387+
382388
for i in range(runs):
383389
log_prefix = f"{log_dir}/{dataset}_{world_size}_cache={use_cache}_run_{i}"
384390
training_traj, val_traj, val_accuracy = _run_experiment(

0 commit comments

Comments
 (0)