Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/gigl/env/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
str
] = "COMPUTE_CLUSTER_LOCAL_WORLD_SIZE"

# Environment variable to indicate the type of job.
# Values: "train", "inference"
JOB_TYPE_ENV_KEY: Final[str] = "GIGL_JOB_TYPE"
Comment thread
kmontemayor2-sc marked this conversation as resolved.
Outdated


@dataclass(frozen=True)
class DistributedContext:
Expand Down
15 changes: 15 additions & 0 deletions python/gigl/src/common/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os

from gigl.env.distributed import JOB_TYPE_ENV_KEY
from gigl.src.common.constants.components import GiGLComponents


def get_component() -> GiGLComponents:
"""Get the component of the current job.

Returns:
GiGLComponents: The component of the current job.
Raises:
ValueError: If the component is not valid.
"""
return GiGLComponents(os.environ[JOB_TYPE_ENV_KEY])
14 changes: 9 additions & 5 deletions python/gigl/src/common/vertex_ai_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
)
from gigl.common.logger import Logger
from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService
from gigl.env.distributed import COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY
from gigl.env.distributed import (
COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY,
JOB_TYPE_ENV_KEY,
)
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
GiglResourceConfigWrapper,
Expand Down Expand Up @@ -67,7 +70,6 @@ def launch_single_pool_job(
cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
container_uri = cpu_docker_uri if is_cpu_execution else cuda_docker_uri

job_config = _build_job_config(
job_name=job_name,
task_config_uri=task_config_uri,
Expand All @@ -77,7 +79,10 @@ def launch_single_pool_job(
use_cuda=is_cpu_execution,
container_uri=container_uri,
vertex_ai_resource_config=vertex_ai_resource_config,
env_vars=[env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3")],
env_vars=[
env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
env_var.EnvVar(name=JOB_TYPE_ENV_KEY, value=component.value),
],
labels=resource_config_wrapper.get_resource_labels(component=component),
)
logger.info(f"Launching {component.value} job with config: {job_config}")
Expand Down Expand Up @@ -150,10 +155,10 @@ def launch_graph_store_enabled_job(
name=COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY,
value=str(num_compute_processes),
),
env_var.EnvVar(name=JOB_TYPE_ENV_KEY, value=component.value),
]

labels = resource_config_wrapper.get_resource_labels(component=component)

# Create compute pool job config
compute_job_config = _build_job_config(
job_name=job_name,
Expand Down Expand Up @@ -221,7 +226,6 @@ def _build_job_config(

Args:
job_name (str): The base name for the job. Will be prefixed with "gigl_train_" or "gigl_infer_".
is_inference (bool): Whether this is an inference job (True) or training job (False).
task_config_uri (Uri): URI to the task configuration file.
resource_config_uri (Uri): URI to the resource configuration file.
command_str (str): The command to run in the container (will be split on spaces).
Expand Down
33 changes: 33 additions & 0 deletions python/tests/unit/src/common/env_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import unittest
from unittest import mock

from gigl.env.distributed import JOB_TYPE_ENV_KEY
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.env import get_component


class TestGetComponent(unittest.TestCase):
"""Test suite for get_component function."""

@mock.patch.dict(os.environ, {JOB_TYPE_ENV_KEY: GiGLComponents.Trainer.value})
def test_get_component_valid_value(self):
"""Test get_component returns correct component when env var is valid."""
result = get_component()
self.assertEqual(result, GiGLComponents.Trainer)

@mock.patch.dict(os.environ, {JOB_TYPE_ENV_KEY: "invalid_component"})
def test_get_component_invalid_value(self):
"""Test get_component raises ValueError when env var is invalid."""
with self.assertRaises(ValueError):
get_component()

@mock.patch.dict(os.environ, {}, clear=True)
def test_get_component_not_set(self):
"""Test get_component raises KeyError when env var is not set."""
with self.assertRaises(KeyError):
get_component()


if __name__ == "__main__":
unittest.main()