@@ -6,7 +6,7 @@
 from pathlib import Path
 import shutil
 import time
-from typing import Any
+from typing import Any, Callable

 from megatron.core import parallel_state as ps
 import torch
@@ -15,7 +15,7 @@
 from ..preprocessing.pack import PackedTensors, packed_tensors_from_dir
 from .finalize_grads import finalize_model_grads_extended
 from .flex_attention import create_shared_prefix_attention_state
-from .jobs import MegatronSFTTrainingJob, MegatronTrainingJob
+from .jobs import DEFAULT_JOBS_DIR, MegatronSFTTrainingJob, MegatronTrainingJob
 from .offload import clear_optimizer_state
 from .train import (
     DEFAULT_MODEL_IDENTIFIER,
@@ -40,6 +40,7 @@
 save_file = safetensors_torch.save_file

 MegatronTrainContext = TrainingRuntime
+MegatronJob = MegatronTrainingJob | MegatronSFTTrainingJob


 def create_megatron_train_context(
@@ -48,6 +49,50 @@ def create_megatron_train_context(
     return build_training_runtime(model_identifier=model_identifier)


+def run_megatron_worker_loop(
+    ctx: MegatronTrainContext,
+    *,
+    supports_sft: bool,
+    wait_until_ready: Callable[[], None] | None = None,
+    before_job: Callable[[], None] | None = None,
+    after_job: Callable[[], None] | None = None,
+) -> None:
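+    # Poll the shared jobs directory until a job spec (*.json) appears, run it,
+    # then finalize it and wait for the next one. The optional callables are
+    # hooks invoked around each job.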
+    while True:
+        torch.distributed.barrier()  # type: ignore[possibly-missing-attribute]
+        os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
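+        # Collect queued job specs; sorted order means sortable (e.g. timestamped)
+        # names run oldest-first.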
+        job_names = sorted(
+            job_name
+            for job_name in os.listdir(DEFAULT_JOBS_DIR)
+            if job_name.endswith(".json")
+        )
+        if not job_names:
+            time.sleep(1)
+            continue
+
+        if wait_until_ready is not None:
+            wait_until_ready()
+        if before_job is not None:
+            before_job()
+
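+        # Claim the first job in sorted order; ranks list the directory after the
+        # same barrier, so they should pick the same file.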
+        job_path = os.path.join(DEFAULT_JOBS_DIR, job_names[0])
+        job = _load_megatron_job(job_path, supports_sft=supports_sft)
+        print0(ctx.rank, "Loaded job from", job_path)
+        print0(ctx.rank, "Job:", job)
+
+        try:
+            _run_megatron_job(ctx, job)
+        finally:
+            if after_job is not None:
+                after_job()
+
+        finalize_megatron_job(
+            ctx,
+            job_path=job_path,
+            log_path=job.log_path,
+            cleanup_path=_job_cleanup_path(job),
+        )
+
+
 def run_megatron_rl_job(
     ctx: MegatronTrainContext,
     job: MegatronTrainingJob,
@@ -254,6 +299,29 @@ def run_megatron_sft_job(
     torch.cuda.empty_cache()


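+# Deserialize a queued job spec from disk, dispatching on its "job_type" field.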
+def _load_megatron_job(job_path: str, *, supports_sft: bool) -> MegatronJob:
+    with open(job_path, "rb") as handle:
+        job_data = json.loads(handle.read())
+    if job_data.get("job_type") == "sft":
+        if not supports_sft:
+            raise NotImplementedError("SFT jobs are not supported in this worker loop")
+        return MegatronSFTTrainingJob.model_validate(job_data)
+    return MegatronTrainingJob.model_validate(job_data)
+
+
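+# Route the job to the matching runner; any non-SFT job is treated as an RL job.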
+def _run_megatron_job(ctx: MegatronTrainContext, job: MegatronJob) -> None:
+    if isinstance(job, MegatronSFTTrainingJob):
+        run_megatron_sft_job(ctx, job)
+        return
+    run_megatron_rl_job(ctx, job)
+
+
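+# Directory handed to finalize_megatron_job for cleanup: the SFT dataset dir for
+# SFT jobs, otherwise the packed-tensors directory used by RL jobs.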
+def _job_cleanup_path(job: MegatronJob) -> str:
+    if isinstance(job, MegatronSFTTrainingJob):
+        return job.sft_data_dir
+    return job.disk_packed_tensors["dir"]
+
+
 def merge_lora_adapter(lora_path: str) -> None:
     base_dir = Path(lora_path)
     shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors"))
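
A minimal sketch of how a worker entrypoint might drive the new loop. The entrypoint itself is hypothetical and not part of this diff; it assumes torch.distributed has already been initialized by the launcher (e.g. torchrun) and reuses create_megatron_train_context and DEFAULT_MODEL_IDENTIFIER from this module:

    ctx = create_megatron_train_context(model_identifier=DEFAULT_MODEL_IDENTIFIER)
    run_megatron_worker_loop(
        ctx,
        supports_sft=True,                  # accept both RL and SFT job specs
        before_job=torch.cuda.empty_cache,  # hypothetical hook: free cached GPU memory
    )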