Skip to content

Commit 0dda886

Browse files
authored
Refactor pathways workload scheduling to use Jinja template (#1116)
* Refactor pathways workload scheduling to use Jinja template * Refactor jinja injection and cleanup workload.py
1 parent 65ffde6 commit 0dda886

5 files changed

Lines changed: 372 additions & 420 deletions

File tree

recipes/Workload_create_pathways.md

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ docker buildx build --platform=linux/amd64 -f 4b6736a12db8ea0f78ce793fd0d4ee0c94
4242
docker tag dry-run-runner gcr.io/golden-project/dry-run-runner:prefix-current
4343
[XPK] Task: `Upload Docker Image` is implemented by the following command not running since it is a dry run.
4444
docker push gcr.io/golden-project/dry-run-runner:prefix-current
45-
[XPK] Temp file (5394ec8d9ca40eb8e048844d5622be8da9ac27cc9565535366c11a379ea35f58) content:
45+
[XPK] Temp file (8e311bca9f9f54ee09e88dbb12e7d20d536478aa69d990a59142ee6fb70da079) content:
4646
apiVersion: jobset.x-k8s.io/v1alpha2
4747
kind: JobSet
4848
metadata:
@@ -76,7 +76,6 @@ spec:
7676
dnsPolicy: ClusterFirstWithHostNet
7777
nodeSelector:
7878
cloud.google.com/gke-nodepool: cpu-np
79-
8079
initContainers:
8180
- name: pathways-proxy
8281
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
@@ -132,9 +131,7 @@ spec:
132131
cpu: "8"
133132
memory: 32G
134133
restartPolicy: Always
135-
136134
containers:
137-
138135
- name: jax-tpu
139136
image: gcr.io/golden-project/dry-run-runner:prefix-current
140137
imagePullPolicy: Always
@@ -209,8 +206,6 @@ spec:
209206
nodeSelector:
210207
cloud.google.com/gke-tpu-accelerator: tpu-v5p-slice
211208
cloud.google.com/gke-tpu-topology: 2x2x1
212-
213-
214209
containers:
215210
- name: pathways-worker
216211
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
@@ -282,7 +277,7 @@ spec:
282277
suspend: false
283278
284279
[XPK] Task: `Creating Workload` is implemented by the following command not running since it is a dry run.
285-
kubectl apply -f 5394ec8d9ca40eb8e048844d5622be8da9ac27cc9565535366c11a379ea35f58
280+
kubectl apply -f 8e311bca9f9f54ee09e88dbb12e7d20d536478aa69d990a59142ee6fb70da079
286281
[XPK] Task: `GKE Dashboard List` is implemented by the following command not running since it is a dry run.
287282
gcloud monitoring dashboards list --project=golden-project --filter="displayName:'GKE - TPU Monitoring Dashboard'" --format="value(name)" --verbosity=error
288283
[XPK] Check statistics and outlier mode of GKE metrics here: https://console.cloud.google.com/monitoring/dashboards/builder/0?project=golden-project&f.rlabel.cluster_name.ClusterName=golden-cluster. To view the metric data for your workload, select golden-workload from the JobName filter on the dashboard.

src/xpk/commands/workload.py

Lines changed: 88 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""
1616

1717
import urllib
18+
import argparse
19+
from ..core.system_characteristics import SystemCharacteristics
1820
from ..core.blueprint.blueprint_generator import (
1921
a3high_device_type,
2022
a4x_device_types,
@@ -41,14 +43,9 @@
4143
)
4244
from ..core.network import get_cluster_subnetworks
4345
from ..core.pathways import (
44-
append_custom_colocated_python_sidecar,
45-
append_custom_pathways_proxy_server,
46-
append_custom_pathways_server,
47-
append_custom_pathways_worker,
4846
check_if_pathways_job_is_installed,
4947
ensure_pathways_workload_prerequisites,
5048
get_pathways_unified_query_link,
51-
get_user_workload_for_pathways,
5249
try_to_delete_pathwaysjob_first,
5350
)
5451
from ..core.resources import get_cluster_capacity_type, get_cluster_system_characteristics_from_config_map
@@ -58,9 +55,7 @@
5855
ONE_TO_ONE_REPLICA_NODE_POOL_ASSIGNMENT_ANNOTATION,
5956
WorkloadScheduling,
6057
check_if_workload_can_schedule,
61-
create_tpu_machine_type,
6258
create_tpu_slice_topology_annotation,
63-
create_tpu_topology,
6459
get_cpu_affinity,
6560
get_gpu_scheduler,
6661
create_sub_slicing_annotations,
@@ -106,6 +101,8 @@
106101
from jinja2 import Environment, FileSystemLoader
107102
from ..utils.templates import get_templates_absolute_path
108103

104+
_PATHWAYS_WORKLOAD_TEMPLATE = 'pathways_workload_create.yaml.j2'
105+
109106
_SUPER_SLICING_WORKLOAD_NAME_LIMIT = 28
110107
"""Maximum safe workload name length to avoid exceeding GCE's 63-character limit.
111108
@@ -263,89 +260,90 @@
263260
containers:
264261
{container}
265262
"""
266-
# The indentation of PW_WORKLOAD_CREATE_YAML is intentional to allow reusing the user workload container YAML.
267-
PW_WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
268-
kind: JobSet
269-
metadata:
270-
name: {args.workload}
271-
labels:
272-
kueue.x-k8s.io/queue-name: {local_queue_name} # Name of the LocalQueue
273-
xpk.google.com/workload: {args.workload}
274-
spec:
275-
coordinator:
276-
replicatedJob: pathways-head
277-
network:
278-
enableDNSHostnames: true
279-
publishNotReadyAddresses: true
280-
failurePolicy:
281-
restartStrategy: Recreate
282-
replicatedJobs:
283-
- name: pathways-head
284-
replicas: 1
285-
template:
286-
spec:
287-
backoffLimit: 0
288-
completionMode: Indexed
289-
completions: 1
290-
parallelism: 1
291-
template:
292-
metadata:
293-
annotations:
294-
alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
295-
spec:
296-
hostNetwork: true
297-
dnsPolicy: ClusterFirstWithHostNet
298-
nodeSelector:
299-
cloud.google.com/gke-nodepool: cpu-np
300-
{autoprovisioning_args}
301-
{pathways_head_containers}
302-
restartPolicy: Never
303-
volumes:
304-
- hostPath:
305-
path: /tmp
306-
type: DirectoryOrCreate
307-
name: shared-tmp
308-
- name: worker
309-
replicas: {args.num_slices}
310-
template:
311-
spec:
312-
backoffLimit: {worker_backoff_limit}
313-
completionMode: Indexed
314-
completions: {vms_per_slice}
315-
parallelism: {vms_per_slice}
316-
template:
317-
metadata:
318-
labels:
319-
xpk.google.com/workload: {args.workload}
320-
annotations:
321-
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
322-
spec:
323-
hostNetwork: true
324-
dnsPolicy: ClusterFirstWithHostNet
325-
terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
326-
priorityClassName: {args.priority}
327-
nodeSelector:
328-
{accelerator_label}
329-
{node_selector_machine_label}
330-
{placement_policy_label}
331-
{autoprovisioning_args}
332-
containers:
333-
{custom_pathways_worker}
334-
restartPolicy: OnFailure
335-
volumes:
336-
- hostPath:
337-
path: /tmp
338-
type: DirectoryOrCreate
339-
name: shared-tmp
340-
startupPolicy:
341-
startupPolicyOrder: InOrder
342-
{success_policy}
343-
suspend: false
344-
"""
345263

346264
ARM_GPU_WORKLOAD_CREATE_JINJA_FILE = 'arm_gpu_workload_crate.yaml.j2'
347265

348266

267+
def _generate_pathways_workload_yaml(
268+
args: argparse.Namespace,
269+
workload_system: SystemCharacteristics,
270+
parallel_containers: int,
271+
placement_policy_label: str,
272+
autoprovisioning_args: str | None,
273+
) -> str:
274+
worker_backoff_limit = (
275+
(args.max_slice_restarts * workload_system.vms_per_slice)
276+
if getattr(args, 'elastic_slices', 0) > 0
277+
else (workload_system.vms_per_slice * 4)
278+
)
279+
280+
proxy_server_image = (
281+
getattr(args, 'proxy_server_image', None)
282+
or 'us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest'
283+
)
284+
server_image = (
285+
getattr(args, 'server_image', None)
286+
or 'us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest'
287+
)
288+
worker_image = getattr(args, 'worker_image', None) or server_image
289+
instance_type = (
290+
f'{workload_system.pathways_tpu_version}:{workload_system.topology}'
291+
if workload_system.pathways_tpu_version
292+
else workload_system.gce_machine_type
293+
)
294+
if args.headless:
295+
user_workload_container = ''
296+
user_workload_env_vars = []
297+
else:
298+
user_workload_container, _ = get_user_workload_container(
299+
args, workload_system, parallel_containers
300+
)
301+
302+
user_workload_env_vars = [
303+
{
304+
'name': 'PATHWAYS_HEAD',
305+
'valueFrom': "metadata.labels['jobset.sigs.k8s.io/coordinator']",
306+
},
307+
{
308+
'name': 'JAX_PLATFORMS',
309+
'value': 'proxy',
310+
},
311+
{
312+
'name': 'XCLOUD_ENVIRONMENT',
313+
'value': 'GCP',
314+
},
315+
{
316+
'name': 'JAX_BACKEND_TARGET',
317+
'value': 'grpc://$(PATHWAYS_HEAD):29000',
318+
},
319+
]
320+
321+
template_env = Environment(
322+
loader=FileSystemLoader(searchpath=get_templates_absolute_path()),
323+
trim_blocks=True,
324+
lstrip_blocks=True,
325+
keep_trailing_newline=True,
326+
)
327+
workload_create_yaml = template_env.get_template(_PATHWAYS_WORKLOAD_TEMPLATE)
328+
return workload_create_yaml.render(
329+
args=args,
330+
local_queue_name=LOCAL_QUEUE_NAME,
331+
proxy_server_image=proxy_server_image,
332+
server_image=server_image,
333+
instance_type=instance_type,
334+
user_workload_container=user_workload_container,
335+
user_workload_env_vars=user_workload_env_vars,
336+
worker_backoff_limit=worker_backoff_limit,
337+
vms_per_slice=workload_system.vms_per_slice,
338+
workload_system=workload_system,
339+
accelerator_label=create_accelerator_label(workload_system),
340+
node_selector_machine_label=create_machine_label(workload_system),
341+
placement_policy_label=placement_policy_label,
342+
autoprovisioning_args=autoprovisioning_args,
343+
worker_image=worker_image,
344+
)
345+
346+
349347
def workload_create_pathways(args) -> None:
350348
"""Run jobset apply command for a file, specifically for Pathways.
351349
@@ -695,46 +693,12 @@ def workload_create(args) -> None:
695693
elif args.use_pathways and ensure_pathways_workload_prerequisites(
696694
args, workload_system
697695
):
698-
if args.headless:
699-
pathways_head_containers = f""" containers:
700-
{append_custom_pathways_proxy_server(args)}
701-
{append_custom_pathways_server(args, workload_system)}
702-
{append_custom_colocated_python_sidecar(args)}"""
703-
success_policy = ''
704-
else:
705-
pathways_head_containers = f""" initContainers:
706-
{append_custom_pathways_proxy_server(args)}
707-
{append_custom_pathways_server(args, workload_system)}
708-
{append_custom_colocated_python_sidecar(args)}
709-
containers:
710-
{get_user_workload_for_pathways(args, workload_system, parallel_containers)}"""
711-
success_policy = """successPolicy:
712-
operator: All
713-
targetReplicatedJobs:
714-
- pathways-head"""
715-
716-
worker_backoff_limit = (
717-
(args.max_slice_restarts * workload_system.vms_per_slice)
718-
if getattr(args, 'elastic_slices', 0) > 0
719-
else (workload_system.vms_per_slice * 4)
720-
)
721-
722-
yml_string = PW_WORKLOAD_CREATE_YAML.format(
696+
yml_string = _generate_pathways_workload_yaml(
723697
args=args,
724-
topology=create_tpu_topology(workload_system),
725-
machine_type=create_tpu_machine_type(workload_system),
726-
pathways_head_containers=pathways_head_containers,
727-
custom_pathways_worker=append_custom_pathways_worker(
728-
args, workload_system
729-
),
730-
worker_backoff_limit=worker_backoff_limit,
731-
success_policy=success_policy,
732-
local_queue_name=LOCAL_QUEUE_NAME,
733-
autoprovisioning_args=autoprovisioning_args,
698+
workload_system=workload_system,
699+
parallel_containers=parallel_containers,
734700
placement_policy_label=placement_policy_label,
735-
vms_per_slice=workload_system.vms_per_slice,
736-
accelerator_label=create_accelerator_label(workload_system),
737-
node_selector_machine_label=create_machine_label(workload_system),
701+
autoprovisioning_args=autoprovisioning_args,
738702
)
739703
else:
740704
if use_sub_slicing:

src/xpk/commands/workload_test.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def test_workload_create_pathways_jobset_yaml(mocker):
329329
args.use_vertex_tensorboard = False
330330
args.headless = False
331331
args.num_slices = 2
332-
args.elastic_slices = 0
332+
args.elastic_slices = 2
333333
args.max_restarts = 1
334334
args.max_slice_restarts = 1
335335
args.termination_grace_period_seconds = 30
@@ -382,16 +382,25 @@ def test_workload_create_pathways_jobset_yaml(mocker):
382382
return_value=True,
383383
)
384384
mocker.patch(
385-
'xpk.core.pathways.get_user_workload_container',
386-
return_value=('- name: test-docker\n image: test-image', '123'),
387-
)
388-
mocker.patch('xpk.commands.workload.create_tpu_topology', return_value='4x4')
389-
mocker.patch(
390-
'xpk.commands.workload.create_tpu_machine_type',
391-
return_value='ct4p-hightpu-4t',
385+
'xpk.commands.workload.get_user_workload_container',
386+
return_value=(
387+
(
388+
'- name: test-docker\n image: test-image\n env:\n - name:'
389+
' FOO\n value: BAR'
390+
),
391+
'123',
392+
),
392393
)
393394

394-
mock_write_file = mocker.patch('builtins.open', mocker.mock_open())
395+
real_open = open
396+
m_open = mocker.mock_open()
397+
398+
def custom_open(file, *args, **kwargs):
399+
if str(file) == 'pw_manifest.yaml':
400+
return m_open(file, *args, **kwargs)
401+
return real_open(file, *args, **kwargs)
402+
403+
mocker.patch('builtins.open', side_effect=custom_open)
395404

396405
mocker.patch(
397406
'xpk.commands.workload.write_tmp_file', return_value='/tmp/test.yaml'
@@ -405,10 +414,8 @@ def test_workload_create_pathways_jobset_yaml(mocker):
405414

406415
workload_create(args)
407416

408-
mock_write_file.assert_called_once_with(
409-
'pw_manifest.yaml', 'w', encoding='utf-8'
410-
)
411-
written_content = mock_write_file.return_value.write.call_args[0][0]
417+
m_open.assert_called_once_with('pw_manifest.yaml', 'w', encoding='utf-8')
418+
written_content = m_open.return_value.write.call_args[0][0]
412419

413420
assert 'apiVersion: jobset.x-k8s.io/v1alpha2' in written_content
414421
assert 'kind: JobSet' in written_content
@@ -422,6 +429,14 @@ def test_workload_create_pathways_jobset_yaml(mocker):
422429
assert '- name: pathways-worker' in written_content
423430
assert f'replicas: {args.num_slices}' in written_content # worker replicas
424431

432+
# Assert custom arguments are correctly injected
433+
assert '- --custom_proxy_arg' in written_content
434+
assert '- --custom_server_arg' in written_content
435+
assert '- --custom_worker_arg' in written_content
436+
437+
# Assert elastic_slices is rendered
438+
assert '- --num_elastic_slices=2' in written_content
439+
425440
# Assert newly migrated JobSet specifics
426441
assert 'coordinator:' in written_content
427442
assert 'replicatedJob: pathways-head' in written_content
@@ -431,4 +446,14 @@ def test_workload_create_pathways_jobset_yaml(mocker):
431446
assert 'completionMode: Indexed' in written_content
432447
assert 'startupPolicyOrder: InOrder' in written_content
433448
assert 'operator: All' in written_content
434-
assert f'backoffLimit: {workload_system.vms_per_slice * 4}' in written_content
449+
assert (
450+
f'backoffLimit: {args.max_slice_restarts * workload_system.vms_per_slice}'
451+
in written_content
452+
)
453+
assert f'image: {args.proxy_server_image}' in written_content
454+
assert f'image: {args.server_image}' in written_content
455+
assert f'image: {args.colocated_python_sidecar_image}' in written_content
456+
assert f'image: {args.worker_image}' in written_content
457+
assert (
458+
f'--gcs_scratch_location={args.pathways_gcs_location}' in written_content
459+
)

0 commit comments

Comments
 (0)