-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathdistributed.py
More file actions
100 lines (79 loc) · 3.23 KB
/
distributed.py
File metadata and controls
100 lines (79 loc) · 3.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Information about distributed environments."""
import os
from dataclasses import dataclass
from typing import Final
# Name of the environment variable for the compute cluster's local world size
# (readers/writers of this variable live outside this module — confirm there).
COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: Final[str] = (
    "COMPUTE_CLUSTER_LOCAL_WORLD_SIZE"
)

# Name of the environment variable that signals which kind of job is running.
# Expected values: "train" or "inference".
JOB_TYPE_ENV_KEY: Final[str] = "GIGL_JOB_TYPE"
@dataclass(frozen=True)
class DistributedContext:
    """Immutable description of this machine's place in a GiGL distributed job.

    The dataclass is frozen, so fields cannot be reassigned once constructed.
    """

    # TODO (mkolodner-sc): Investigate adding local rank and local world size

    # IP address of the main worker, used for RPC communication.
    main_worker_ip_address: str
    # Rank of this machine.
    global_rank: int
    # Total number of machines taking part in the job.
    global_world_size: int
@dataclass(frozen=True)
class GraphStoreInfo:
    """Information about a graph store cluster.

    The whole cluster consists of a compute cluster and a storage cluster.
    Global ranks (read from the ``RANK`` environment variable) are laid out as:

    * ``[0, num_compute_nodes)`` -> compute nodes
    * ``[num_compute_nodes, num_cluster_nodes)`` -> storage nodes
    """

    # Number of nodes in the storage cluster
    num_storage_nodes: int
    # Number of nodes in the compute cluster
    num_compute_nodes: int
    # IP address of the master node for the whole cluster
    cluster_master_ip: str
    # IP address of the master node for the storage cluster
    storage_cluster_master_ip: str
    # IP address of the master node for the compute cluster
    compute_cluster_master_ip: str
    # Port of the master node for the whole cluster
    cluster_master_port: int
    # Port of the master node for the storage cluster
    storage_cluster_master_port: int
    # Port of the master node for the compute cluster
    compute_cluster_master_port: int
    # Number of processes per compute machine
    # See documentation on the VertexAiGraphStoreConfig message for more details.
    # https://snapchat.github.io/GiGL/docs/api/snapchat/research/gbml/gigl_resource_config_pb2/index.html#snapchat.research.gbml.gigl_resource_config_pb2.VertexAiGraphStoreConfig
    num_processes_per_compute: int

    @property
    def num_cluster_nodes(self) -> int:
        """Total number of nodes across the compute and storage clusters."""
        return self.num_storage_nodes + self.num_compute_nodes

    @property
    def compute_cluster_world_size(self) -> int:
        """Total number of processes in the compute cluster.

        Computed as ``num_compute_nodes * num_processes_per_compute``.
        """
        return self.num_compute_nodes * self.num_processes_per_compute

    @staticmethod
    def _read_global_rank() -> int:
        """Return this node's global rank, parsed from the ``RANK`` env var.

        Raises:
            KeyError: If the ``RANK`` environment variable is not set.
            ValueError: If ``RANK`` is not a valid integer.
        """
        try:
            raw_rank = os.environ["RANK"]
        except KeyError:
            # Keep the KeyError type callers may already catch, but say what
            # is actually missing instead of the bare "'RANK'".
            raise KeyError(
                "Environment variable 'RANK' is not set; cannot determine the "
                "global rank of this node."
            ) from None
        return int(raw_rank)

    @property
    def storage_node_rank(self) -> int:
        """Get the rank of the storage node in the storage cluster.

        Storage nodes occupy global ranks ``[num_compute_nodes, num_cluster_nodes)``,
        so the storage rank is the global rank shifted down by ``num_compute_nodes``.

        Raises:
            KeyError: If the ``RANK`` environment variable is not set.
            ValueError: If the node is not in the storage cluster.
        """
        global_rank = self._read_global_rank()
        if not self.num_compute_nodes <= global_rank < self.num_cluster_nodes:
            raise ValueError(
                f"Global rank {global_rank} is not a storage rank. Expected storage rank to be in [{self.num_compute_nodes}, {self.num_cluster_nodes})"
            )
        return global_rank - self.num_compute_nodes

    @property
    def compute_node_rank(self) -> int:
        """Get the rank of the compute node in the compute cluster.

        Compute nodes occupy global ranks ``[0, num_compute_nodes)``, so the
        compute rank equals the global rank.

        Raises:
            KeyError: If the ``RANK`` environment variable is not set.
            ValueError: If the node is not in the compute cluster.
        """
        global_rank = self._read_global_rank()
        if not 0 <= global_rank < self.num_compute_nodes:
            raise ValueError(
                f"Global rank {global_rank} is not a compute rank. Expected compute rank to be in [0, {self.num_compute_nodes})"
            )
        return global_rank