-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_jobs.py
More file actions
210 lines (167 loc) · 6.77 KB
/
batch_jobs.py
File metadata and controls
210 lines (167 loc) · 6.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""Tools for running batch jobs on a k8s cluster."""
import dataclasses
import shlex
import subprocess
from copy import deepcopy
from pathlib import Path
from typing import Any
from collections.abc import Sequence
from farconf import from_dict, to_dict
from farconf.cli import assign_from_dotlist
from names_generator import generate_name
from {{ cookiecutter.project_name }}.configs import RunConfig
from {{ cookiecutter.project_name }}.constants import PROJECT_SHORT, WANDB_ENTITY, WANDB_PROJECT
from {{ cookiecutter.project_name }}.utils.git import get_repo_root, git_latest_commit, validate_git_repo
# The k8s job template is read once, at import time. Its placeholders are
# filled in with `str.format` when the individual jobs are created.
JOB_TEMPLATE_PATH = get_repo_root() / "k8s" / "batch_job.yaml"
JOB_TEMPLATE = JOB_TEMPLATE_PATH.read_text()
@dataclasses.dataclass
class ClusterRunConfig:
    """A run configuration paired with the share of a node it requires.

    Attributes:
        config: The run configuration.
        node_frac_needed: Fraction of a node needed for the run. Defaults to
            1.0, i.e. a whole node.
    """

    config: RunConfig
    node_frac_needed: float = 1.0
@dataclasses.dataclass(frozen=True)
class ClusterOptions:
    """Immutable, launch-wide options for running jobs on the cluster.

    A single instance is typically shared between multiple runs. The options
    are converted to a dict and passed as keyword arguments when the job
    template is formatted, so the field names presumably mirror the
    template's placeholders (verify against `k8s/batch_job.yaml`).
    """

    CONTAINER_TAG: str = "latest"
    # The factory runs each time an instance is created without an explicit
    # COMMIT_HASH, not once at import time.
    COMMIT_HASH: str = dataclasses.field(default_factory=git_latest_commit)
    CPU: int = 4
    MEMORY: str = "20G"
    GPU: int = 1
    PRIORITY: str = "normal-batch"
    # The defaults on the right-hand side are the module-level constants of
    # the same name.
    WANDB_ENTITY: str = WANDB_ENTITY
    WANDB_PROJECT: str = WANDB_PROJECT
    WANDB_MODE: str = "online"
def create_cluster_run_configs(
    base_experiment_config: RunConfig,
    override_args_and_node_frac_needed: Sequence[tuple[dict[str, Any], float]],
) -> list[ClusterRunConfig]:
    """Builds one `ClusterRunConfig` per override spec on top of a shared base config.

    Args:
        base_experiment_config: base config used by all runs; each run applies
            its own modifications to a copy of it.
        override_args_and_node_frac_needed: a sequence of tuples (override_args,
            node_frac_needed). override_args is a dictionary with dot-separated
            keys that correspond to the config structure; its values are
            assigned into the base config to produce the run config.
            node_frac_needed is the fraction of a node needed for that run.

    Returns:
        list of `ClusterRunConfig`s describing configurations to be run on the
        cluster, in the same order as the override specs.
    """

    def _with_overrides(overrides: dict[str, Any]) -> RunConfig:
        # Each run starts from its own deep copy of the base config, so
        # overrides of one run cannot leak into another (or into the base).
        config_dict = to_dict(deepcopy(base_experiment_config))
        assert isinstance(config_dict, dict)
        for dotted_key, new_value in overrides.items():
            assign_from_dotlist(config_dict, dotted_key, new_value)
        return from_dict(config_dict, RunConfig)

    return [
        ClusterRunConfig(
            config=_with_overrides(overrides), node_frac_needed=node_frac
        )
        for overrides, node_frac in override_args_and_node_frac_needed
    ]
def run_multiple(
    run_configs: Sequence[ClusterRunConfig],
    cluster_options: ClusterOptions | None = None,
    dry_run: bool = False,
) -> None:
    """Runs batch jobs on the cluster with the runs specified in `run_configs`.

    The generated job configs are always printed. They are submitted with
    `kubectl create` only when `dry_run` is False.

    Args:
        run_configs: non-empty sequence of `ClusterRunConfig`s with
            configurations to be run. All of them must share the same
            experiment name.
        cluster_options: options for running jobs on the cluster. When None,
            a default `ClusterOptions()` is used.
        dry_run: if True, only print the job configs without running them.
            Otherwise, submit the jobs to the cluster.

    Raises:
        ValueError: if `run_configs` is empty.
        AssertionError: if the runs do not all share one experiment name.
        subprocess.CalledProcessError: if `kubectl create` exits with a
            non-zero status.
    """
    # Fail fast with a clear message: an empty launch would otherwise only
    # blow up later, inside job creation, with an unrelated-looking error.
    if not run_configs:
        raise ValueError("run_configs must contain at least one run.")

    if cluster_options is None:
        cluster_options = ClusterOptions()
    validate_git_repo()

    # All configs should have the same experiment name.
    experiment_name = run_configs[0].config.experiment_name
    assert all(
        run_config.config.experiment_name == experiment_name
        for run_config in run_configs
    ), "All run configs must share the same experiment name."

    jobs, launch_id = _create_jobs(run_configs, cluster_options)
    yamls_for_all_jobs = "\n\n---\n\n".join(jobs)

    print(yamls_for_all_jobs)
    if not dry_run:
        print(f"Launching jobs with {launch_id=}...")
        # The job configs are piped to kubectl through stdin ("-f -").
        subprocess.run(
            ["kubectl", "create", "-f", "-"],
            check=True,
            input=yamls_for_all_jobs.encode(),
        )
        print(
            "Jobs launched. To delete them run:\n"
            f"kubectl delete jobs -l launch-id={launch_id}"
        )
def _create_jobs(
    run_configs: Sequence[ClusterRunConfig], cluster_options: ClusterOptions
) -> tuple[Sequence[str], str]:
    """Creates k8s job configs for the runs, together with a launch id.

    The runs are first grouped by container; one job config is produced per
    group. Every job of the launch receives the same freshly generated launch
    id, and run indices are numbered consecutively across the groups.

    Returns:
        A pair (job configs, launch id).
    """
    launch_id = generate_name(style="hyphen")

    job_configs = []
    next_run_index = 0
    for container_runs in _organize_by_containers(run_configs):
        job_config = _create_job_for_multiple_runs(
            container_runs, cluster_options, launch_id, next_run_index
        )
        job_configs.append(job_config)
        # The following group continues numbering where this one stopped.
        next_run_index += len(container_runs)
    return job_configs, launch_id
def _create_job_for_multiple_runs(
    runs: Sequence[RunConfig],
    cluster_options: ClusterOptions,
    launch_id: str,
    start_index: int,
) -> str:
    """Creates a single k8s job config for runs to be run in a single container.

    Note that this sets `run_name` on each of the given run configs in place.

    Args:
        runs: the runs sharing the container; the experiment name of the first
            one is used for naming the job and the runs.
        cluster_options: options substituted into the job template.
        launch_id: identifier of the launch, substituted into the job template.
        start_index: index of the first run; subsequent runs are numbered
            consecutively from it.

    Returns:
        The job template formatted for this group of runs.
    """
    experiment_name = runs[0].experiment_name
    run_count = len(runs)

    # K8s job/pod names should be short for readability (hence cutting the name).
    # A multi-run job additionally carries the index of its last run.
    k8s_job_name = f"{PROJECT_SHORT}-{experiment_name[:16]}-{start_index:03}"
    if run_count != 1:
        k8s_job_name += f"-{start_index + run_count - 1:03}"
    k8s_job_name = k8s_job_name.replace("_", "-")

    shell_commands = []
    for run_index, run in enumerate(runs, start=start_index):
        run.run_name = f"{experiment_name}-{run_index:03}"
        shell_commands.append(
            shlex.join(["python", run.script_path, *run.to_cli()])
        )
    # Joined with "&", every run is started in the background; the trailing
    # "wait" blocks until all of them have finished. (Assumes the template
    # passes COMMAND to a shell.)
    shell_commands.append("wait")
    command = f"({' & '.join(shell_commands)})"

    template_options = to_dict(cluster_options)
    assert isinstance(template_options, dict)
    return JOB_TEMPLATE.format(
        NAME=k8s_job_name,
        COMMAND=command,
        LAUNCH_ID=launch_id,
        **template_options,
    )
def _organize_by_containers(
    cluster_runs: Sequence[ClusterRunConfig],
) -> list[list[RunConfig]]:
    """Splits runs into groups that will be run together in a single k8s container.

    Runs are sorted from "less demanding" to "more demanding" by
    `node_frac_needed` and then packed greedily: a run joins the current group
    as long as the group's total `node_frac_needed` stays at or below 1;
    otherwise it opens a new group.

    Every returned group contains at least one run. In particular, no runs
    yield no groups, and a run that on its own needs more than one node simply
    gets a group to itself.

    Args:
        cluster_runs: the runs to distribute over containers.

    Returns:
        The run configs, grouped by container.
    """
    runs_by_containers: list[list[RunConfig]] = []
    current_weight = 0.0
    for cluster_run in sorted(cluster_runs, key=lambda run: run.node_frac_needed):
        fits_in_current = current_weight + cluster_run.node_frac_needed <= 1
        if runs_by_containers and fits_in_current:
            runs_by_containers[-1].append(cluster_run.config)
            current_weight += cluster_run.node_frac_needed
        else:
            # Open a group only when there is a run to put into it, so that
            # no empty group ever reaches job creation.
            runs_by_containers.append([cluster_run.config])
            current_weight = cluster_run.node_frac_needed
    return runs_by_containers