Skip to content

Commit 82b48e9

Browse files
committed
Optimized cache generation with scatter call and inverse_indices
Add a standalone cache-generation script that asynchronously generates and saves caches; update the run code to load the pre-saved cache files.
1 parent 2ea4bca commit 82b48e9

3 files changed

Lines changed: 15 additions & 9 deletions

File tree

DGraph/distributed/RankLocalOps.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,13 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping):
140140
unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
141141
rank_mapping = rank_mapping.to(_indices.device)
142142
renumbered_indices = inverse_indices
143+
unique_rank_mapping = torch.zeros_like(
144+
unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device
)
145+
144150
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)
145151

146152
return renumbered_indices, unique_indices, unique_rank_mapping

experiments/OGB/GenerateCache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"ogbn-arxiv": "arxiv",
3131
"ogbn-products": "products",
3232
"ogbn-papers100M": "papers100M",
33+
"ogbn-proteins": "proteins",
3334
}
3435

3536

@@ -85,7 +86,7 @@ def generate_cache_file(
8586

8687

8788
def main(dset: str, world_size: int, node_rank_placement_file: str):
88-
assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"]
89+
assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M", "ogbn-proteins"]
8990

9091
assert world_size > 0
9192
assert os.path.exists(

experiments/OGB/main.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _run_experiment(
157157

158158
# This says where the edges are located
159159
edge_placement = rank_mappings[0]
160-
160+
161161
cache_prefix = f"cache/{dset_name}"
162162
scatter_cache_file = f"{cache_prefix}_scatter_cache_{world_size}_{rank}.pt"
163163
gather_cache_file = f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt"
@@ -187,8 +187,8 @@ def _run_experiment(
187187
world_size,
188188
)
189189
with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f:
190-
torch.save(gather_cache, f)
191-
190+
torch.save(gather_cache, f)
191+
192192
if scatter_cache is None:
193193
nodes_per_rank = dataset.graph_obj.get_nodes_per_rank()
194194

@@ -230,12 +230,11 @@ def _run_experiment(
230230
end_time = perf_counter()
231231
print(f"Rank: {rank} Cache Generation Time: {end_time - start_time:.4f} s")
232232

233-
234-
#with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f:
233+
# with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f:
235234
# torch.save(gather_cache, f)
236-
#with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f:
235+
# with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f:
237236
# torch.save(scatter_cache, f)
238-
#print(f"Rank: {rank} Cache Generated")
237+
# print(f"Rank: {rank} Cache Generated")
239238

240239
training_times = []
241240
for i in range(epochs):
@@ -391,7 +390,7 @@ def main(
391390
use_cache=use_cache,
392391
num_classes=num_classes,
393392
dset_name=dset_name,
394-
in_dim=in_dims[dset_name]
393+
in_dim=in_dims[dset_name],
395394
)
396395
training_trajectores[i] = training_traj
397396
validation_trajectores[i] = val_traj

0 commit comments

Comments
 (0)