@@ -169,7 +169,7 @@ def run_scatter_benchmark(
169169
170170def main ():
171171 parser = argparse .ArgumentParser ()
172- parser .add_argument ("--message_size" , type = int , default = 128 )
172+ parser .add_argument ("--message_size" , type = int , default = 2 )
173173 parser .add_argument ("--benchmark_cache" , action = "store_true" )
174174 parser .add_argument ("--num_iters" , type = int , default = 1000 )
175175 parser .add_argument ("--log_dir" , type = str , default = "logs" )
@@ -196,92 +196,114 @@ def main():
196196 benchmark .print (f"Running NCCL Benchmark on { world_size } ranks" )
197197
198198 # Built in small message benchmarks, in future we can add more
199- gather_graph_data = get_nccl_gather_benchmark_data (message_size , world_size , device )
200199
201- benchmark . print ( "*" * 50 )
202- benchmark . print ( "Running Gather Benchmark" )
203- times = run_gather_benchmark ( benchmark , num_iters , gather_graph_data , cache = None )
204-
205- benchmark . print ( "Saving Gather Benchmark Times" )
206-
207- for i in range ( world_size ):
208- benchmark . save_np ( times , f" { log_dir } /NCCL_gather_times_ { i } .npy" , rank_to_save = i )
200+ for i in range ( 1 , 20 ):
201+ message_size *= 2
202+ benchmark . print ( "*" * 50 )
203+ benchmark . print ( f"Running NCCL Benchmark for message size { message_size } " )
204+ gather_graph_data = get_nccl_gather_benchmark_data (
205+ message_size , world_size , device
206+ )
207+ dist . barrier ( )
209208
210- benchmark .print ("Gather Benchmark Complete" )
211- benchmark .print ("*" * 50 )
209+ benchmark .print ("Running Gather Benchmark" )
210+ times = run_gather_benchmark (
211+ benchmark , num_iters , gather_graph_data , cache = None
212+ )
212213
213- if benchmark_cache :
214- edge_placement = gather_graph_data .edge_rank_placement
215- edge_src_rank = gather_graph_data .edge_src_rank
216- indices = gather_graph_data .edge_indices
214+ benchmark .print ("Saving Gather Benchmark Times" )
217215
218- gather_cache = NCCLGatherCacheGenerator (
219- indices ,
220- edge_placement .view (- 1 ),
221- edge_src_rank .view (- 1 ),
222- 1 ,
223- rank ,
224- world_size ,
216+ benchmark .save_np (
217+ times ,
218+ f"{ log_dir } /NCCL_gather_times_message_size_{ message_size } "
219+ + f"_world_size_{ world_size } .npy" ,
220+ rank_to_save = 0 ,
225221 )
222+
223+ benchmark .print ("Gather Benchmark Complete" )
226224 benchmark .print ("*" * 50 )
227- benchmark .print ("Running Gather Benchmark with Cache" )
228- times = run_gather_benchmark (
229- benchmark , num_iters , gather_graph_data , cache = gather_cache
230- )
231225
232- benchmark .print ("Saving Gather Benchmark with Cache Times" )
233- for i in range (world_size ):
226+ if benchmark_cache :
227+ edge_placement = gather_graph_data .edge_rank_placement
228+ edge_src_rank = gather_graph_data .edge_src_rank
229+ indices = gather_graph_data .edge_indices
230+
231+ gather_cache = NCCLGatherCacheGenerator (
232+ indices ,
233+ edge_placement .view (- 1 ),
234+ edge_src_rank .view (- 1 ),
235+ 1 ,
236+ rank ,
237+ world_size ,
238+ )
239+ benchmark .print ("*" * 50 )
240+ benchmark .print ("Running Gather Benchmark with Cache" )
241+ times = run_gather_benchmark (
242+ benchmark , num_iters , gather_graph_data , cache = gather_cache
243+ )
244+
245+ benchmark .print ("Saving Gather Benchmark with Cache Times" )
234246 benchmark .save_np (
235- times , f"{ log_dir } /NCCL_gather_with_cache_times_{ i } .npy" , rank_to_save = i
247+ times ,
248+ f"{ log_dir } /NCCL_gather_with_cache_message_size_{ message_size } "
249+ + f"_world_size_{ world_size } .npy" ,
250+ rank_to_save = 0 ,
236251 )
237252
238- benchmark .print ("Gather Benchmark with Cache Complete" )
239- benchmark .print ("*" * 50 )
253+ benchmark .print ("Gather Benchmark with Cache Complete" )
254+ benchmark .print ("*" * 50 )
240255
241- scatter_graph_data = get_nccl_scatter_benchmark_data (
242- message_size , world_size , device
243- )
244- benchmark .print ("*" * 50 )
245- benchmark .print ("Running Scatter Benchmark" )
246- times = run_scatter_benchmark (benchmark , num_iters , scatter_graph_data , cache = None )
256+ scatter_graph_data = get_nccl_scatter_benchmark_data (
257+ message_size , world_size , device
258+ )
247259
248- benchmark .print ("Saving Scatter Benchmark Times" )
249- for i in range ( world_size ):
250- benchmark . save_np (
251- times , f" { log_dir } /NCCL_scatter_times_ { i } .npy" , rank_to_save = i
252- )
260+ benchmark .print ("*" * 50 )
261+ benchmark . print ( "Running Scatter Benchmark" )
262+ times = run_scatter_benchmark (
263+ benchmark , num_iters , scatter_graph_data , cache = None
264+ )
253265
254- benchmark .print ("Scatter Benchmark Complete" )
255- benchmark .print ("*" * 50 )
256- if benchmark_cache :
257- edge_placement = scatter_graph_data .edge_rank_placement
258- edge_dest_rank = scatter_graph_data .edge_dest_rank
259- indices = scatter_graph_data .edge_indices
260-
261- scatter_cache = NCCLScatterCacheGenerator (
262- indices ,
263- edge_placement .view (- 1 ),
264- edge_dest_rank .view (- 1 ),
265- 1 ,
266- rank ,
267- world_size ,
268- )
269- benchmark .print ("*" * 50 )
270- benchmark .print ("Running Scatter Benchmark with Cache" )
271- times = run_scatter_benchmark (
272- benchmark , num_iters , scatter_graph_data , cache = scatter_cache
273- )
266+ benchmark .print ("Saving Scatter Benchmark Times" )
274267
275- benchmark .print ("Saving Scatter Benchmark with Cache Times" )
276- for i in range (world_size ):
277268 benchmark .save_np (
278269 times ,
279- f"{ log_dir } /NCCL_scatter_with_cache_times_{ i } .npy" ,
280- rank_to_save = i ,
270+ f"{ log_dir } /NCCL_scatter_times_message_size_{ message_size } "
271+ + f"_world_size_{ world_size } .npy" ,
272+ rank_to_save = 0 ,
281273 )
282274
283- benchmark .print ("Scatter Benchmark with Cache Complete" )
284- benchmark .print ("*" * 50 )
275+ benchmark .print ("Scatter Benchmark Complete" )
276+ benchmark .print ("*" * 50 )
277+ if benchmark_cache :
278+ edge_placement = scatter_graph_data .edge_rank_placement
279+ edge_dest_rank = scatter_graph_data .edge_dest_rank
280+ indices = scatter_graph_data .edge_indices
281+
282+ scatter_cache = NCCLScatterCacheGenerator (
283+ indices ,
284+ edge_placement .view (- 1 ),
285+ edge_dest_rank .view (- 1 ),
286+ 1 ,
287+ rank ,
288+ world_size ,
289+ )
290+ benchmark .print ("*" * 50 )
291+ benchmark .print ("Running Scatter Benchmark with Cache" )
292+ times = run_scatter_benchmark (
293+ benchmark , num_iters , scatter_graph_data , cache = scatter_cache
294+ )
295+
296+ benchmark .print ("Saving Scatter Benchmark with Cache Times" )
297+
298+ benchmark .save_np (
299+ times ,
300+ f"{ log_dir } /NCCL_scatter_with_cache_message_size_{ message_size } "
301+ + f"_world_size_{ world_size } .npy" ,
302+ rank_to_save = 0 ,
303+ )
304+
305+ benchmark .print ("Scatter Benchmark with Cache Complete" )
306+ benchmark .print ("*" * 50 )
285307
286308 dist .destroy_process_group ()
287309
0 commit comments