Commit 511d72c

Share Megatron worker loop
1 parent 9e90c7d commit 511d72c

2 files changed: 83 additions & 39 deletions


src/art/megatron/shared.py

Lines changed: 70 additions & 2 deletions
@@ -6,7 +6,7 @@
 from pathlib import Path
 import shutil
 import time
-from typing import Any
+from typing import Any, Callable

 from megatron.core import parallel_state as ps
 import torch
@@ -15,7 +15,7 @@
 from ..preprocessing.pack import PackedTensors, packed_tensors_from_dir
 from .finalize_grads import finalize_model_grads_extended
 from .flex_attention import create_shared_prefix_attention_state
-from .jobs import MegatronSFTTrainingJob, MegatronTrainingJob
+from .jobs import DEFAULT_JOBS_DIR, MegatronSFTTrainingJob, MegatronTrainingJob
 from .offload import clear_optimizer_state
 from .train import (
     DEFAULT_MODEL_IDENTIFIER,
@@ -40,6 +40,7 @@
 save_file = safetensors_torch.save_file

 MegatronTrainContext = TrainingRuntime
+MegatronJob = MegatronTrainingJob | MegatronSFTTrainingJob


 def create_megatron_train_context(
@@ -48,6 +49,50 @@ def create_megatron_train_context(
     return build_training_runtime(model_identifier=model_identifier)


+def run_megatron_worker_loop(
+    ctx: MegatronTrainContext,
+    *,
+    supports_sft: bool,
+    wait_until_ready: Callable[[], None] | None = None,
+    before_job: Callable[[], None] | None = None,
+    after_job: Callable[[], None] | None = None,
+) -> None:
+    while True:
+        torch.distributed.barrier()  # ty: ignore[possibly-missing-attribute]
+        os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
+        job_names = sorted(
+            job_name
+            for job_name in os.listdir(DEFAULT_JOBS_DIR)
+            if job_name.endswith(".json")
+        )
+        if not job_names:
+            time.sleep(1)
+            continue
+
+        if wait_until_ready is not None:
+            wait_until_ready()
+        if before_job is not None:
+            before_job()
+
+        job_path = os.path.join(DEFAULT_JOBS_DIR, job_names[0])
+        job = _load_megatron_job(job_path, supports_sft=supports_sft)
+        print0(ctx.rank, "Loaded job from", job_path)
+        print0(ctx.rank, "Job:", job)
+
+        try:
+            _run_megatron_job(ctx, job)
+        finally:
+            if after_job is not None:
+                after_job()
+
+            finalize_megatron_job(
+                ctx,
+                job_path=job_path,
+                log_path=job.log_path,
+                cleanup_path=_job_cleanup_path(job),
+            )
+
+
 def run_megatron_rl_job(
     ctx: MegatronTrainContext,
     job: MegatronTrainingJob,
@@ -254,6 +299,29 @@ def run_megatron_sft_job(
     torch.cuda.empty_cache()


+def _load_megatron_job(job_path: str, *, supports_sft: bool) -> MegatronJob:
+    with open(job_path, "rb") as handle:
+        job_data = json.loads(handle.read())
+    if job_data.get("job_type") == "sft":
+        if not supports_sft:
+            raise NotImplementedError("SFT jobs are not supported in this worker loop")
+        return MegatronSFTTrainingJob.model_validate(job_data)
+    return MegatronTrainingJob.model_validate(job_data)
+
+
+def _run_megatron_job(ctx: MegatronTrainContext, job: MegatronJob) -> None:
+    if isinstance(job, MegatronSFTTrainingJob):
+        run_megatron_sft_job(ctx, job)
+        return
+    run_megatron_rl_job(ctx, job)
+
+
+def _job_cleanup_path(job: MegatronJob) -> str:
+    if isinstance(job, MegatronSFTTrainingJob):
+        return job.sft_data_dir
+    return job.disk_packed_tensors["dir"]
+
+
 def merge_lora_adapter(lora_path: str) -> None:
     base_dir = Path(lora_path)
     shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors"))
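
The loop above is the piece now shared between worker entry points: it polls DEFAULT_JOBS_DIR for *.json job files, takes the lexicographically first one, dispatches on its job_type, and always finalizes and cleans up even if the job raises. A minimal sketch of how another worker might drive it follows; only run_megatron_worker_loop and its keyword arguments come from this commit, and the surrounding entry-point scaffolding is an assumption for illustration.

# Hypothetical SFT-capable worker built on the shared loop (not part of
# this commit).
from art.megatron.shared import (
    create_megatron_train_context,
    run_megatron_worker_loop,
)


def main() -> None:
    # Assumed: create_megatron_train_context defaults its model identifier.
    ctx = create_megatron_train_context()
    run_megatron_worker_loop(
        ctx,
        supports_sft=True,  # dispatch both RL and SFT jobs
        # wait_until_ready, before_job, and after_job default to None, so
        # this worker runs jobs back to back with no vLLM gating or
        # CPU offloading between them.
    )


if __name__ == "__main__":
    main()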

src/art/megatron/train.py

Lines changed: 13 additions & 37 deletions
@@ -25,9 +25,7 @@
 from art.megatron.finalize_grads import finalize_model_grads_extended
 from art.megatron.flex_attention import create_shared_prefix_attention_state
 from art.megatron.jobs import (
-    DEFAULT_JOBS_DIR,
     DEFAULT_VLLM_WAKE_LOCK_PATH,
-    MegatronTrainingJob,
 )
 from art.megatron.lora import apply_lora_adapters
 from art.megatron.offload import (
@@ -562,45 +560,23 @@ def run_training_step(
 def _run_service_loop(runtime: TrainingRuntime) -> None:
     offload_state = OffloadState()
     offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state)
+    from .shared import run_megatron_worker_loop

-    while True:
-        from .shared import finalize_megatron_job, run_megatron_rl_job
-
-        torch.distributed.barrier()  # ty: ignore[possibly-missing-attribute]
-        os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
-        job_names = sorted(
-            job_name
-            for job_name in os.listdir(DEFAULT_JOBS_DIR)
-            if job_name.endswith(".json")
-        )
-        if not job_names:
-            time.sleep(1)
-            continue
-
+    def wait_until_ready() -> None:
         while os.path.exists(DEFAULT_VLLM_WAKE_LOCK_PATH):
             time.sleep(0.2)

-        reload_to_gpu(runtime.model, runtime.optimizer, runtime.rank, offload_state)
-
-        job_name = job_names[0]
-        job_path = os.path.join(DEFAULT_JOBS_DIR, job_name)
-        with open(job_path, "rb") as handle:
-            job = MegatronTrainingJob.model_validate_json(handle.read())
-
-        print0(runtime.rank, "Loaded job from", job_path)
-        print0(runtime.rank, "Job:", job)
-        try:
-            run_megatron_rl_job(runtime, job)
-        finally:
-            offload_to_cpu(
-                runtime.model, runtime.optimizer, runtime.rank, offload_state
-            )
-            finalize_megatron_job(
-                runtime,
-                job_path=job_path,
-                log_path=job.log_path,
-                cleanup_path=job.disk_packed_tensors["dir"],
-            )
+    run_megatron_worker_loop(
+        runtime,
+        supports_sft=False,
+        wait_until_ready=wait_until_ready,
+        before_job=lambda: reload_to_gpu(
+            runtime.model, runtime.optimizer, runtime.rank, offload_state
+        ),
+        after_job=lambda: offload_to_cpu(
+            runtime.model, runtime.optimizer, runtime.rank, offload_state
+        ),
+    )


 def main() -> None:
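
With the inlined RL loop deleted, _run_service_loop keeps only its policy decisions (no SFT support, wait out the vLLM wake lock, reload to GPU before each job, offload after) and hands the polling to the shared loop. The other side of the handshake is a controller dropping a job file into DEFAULT_JOBS_DIR; a rough sketch of that submission path is below. Only the .json suffix, sorted-first pickup, and the job_type key are confirmed by the diff; the helper, payload shape, and atomic-rename precaution are assumptions.

# Illustrative controller-side job submission (hypothetical; not part of
# this commit).
import json
import os

from art.megatron.jobs import DEFAULT_JOBS_DIR


def submit_job(payload: dict, name: str) -> str:
    os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
    job_path = os.path.join(DEFAULT_JOBS_DIR, f"{name}.json")
    # Write to a .tmp name first: the worker only lists *.json files, so the
    # os.replace makes the job visible all at once, never half-written.
    tmp_path = job_path + ".tmp"
    with open(tmp_path, "w") as handle:
        json.dump(payload, handle)
    os.replace(tmp_path, job_path)
    return job_path


# e.g. submit_job({"job_type": "sft", ...}, "0001-warmup")  # fields assumed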
