@@ -157,7 +157,7 @@ def _run_experiment(
157157
158158 # This says where the edges are located
159159 edge_placement = rank_mappings [0 ]
160-
160+
161161 cache_prefix = f"cache/{ dset_name } "
162162 scatter_cache_file = f"{ cache_prefix } _scatter_cache_{ world_size } _{ rank } .pt"
163163 gather_cache_file = f"{ cache_prefix } _gather_cache_{ world_size } _{ rank } .pt"
@@ -187,8 +187,8 @@ def _run_experiment(
187187 world_size ,
188188 )
189189 with open (f"{ log_prefix } _gather_cache_{ world_size } _{ rank } .pt" , "wb" ) as f :
190- torch .save (gather_cache , f )
191-
190+ torch .save (gather_cache , f )
191+
192192 if scatter_cache is None :
193193 nodes_per_rank = dataset .graph_obj .get_nodes_per_rank ()
194194
@@ -230,12 +230,11 @@ def _run_experiment(
230230 end_time = perf_counter ()
231231 print (f"Rank: { rank } Cache Generation Time: { end_time - start_time :.4f} s" )
232232
233-
234- #with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f:
233+ # with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f:
235234 # torch.save(gather_cache, f)
236- #with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f:
235+ # with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f:
237236 # torch.save(scatter_cache, f)
238- #print(f"Rank: {rank} Cache Generated")
237+ # print(f"Rank: {rank} Cache Generated")
239238
240239 training_times = []
241240 for i in range (epochs ):
@@ -391,7 +390,7 @@ def main(
391390 use_cache = use_cache ,
392391 num_classes = num_classes ,
393392 dset_name = dset_name ,
394- in_dim = in_dims [dset_name ]
393+ in_dim = in_dims [dset_name ],
395394 )
396395 training_trajectores [i ] = training_traj
397396 validation_trajectores [i ] = val_traj
0 commit comments