@@ -51,8 +51,6 @@ def create_megatron_train_context(
5151def run_megatron_rl_job (
5252 ctx : MegatronTrainContext ,
5353 job : MegatronTrainingJob ,
54- * ,
55- job_path : str | None = None ,
5654) -> None :
5755 packed_tensors = None
5856 adapter_model = None
@@ -125,12 +123,6 @@ def run_megatron_rl_job(
125123 lora_path = job .lora_path ,
126124 optimizer_state_path = job .optimizer_state_path ,
127125 )
128- _complete_job (
129- ctx ,
130- job_path = job_path ,
131- log_path = job .log_path ,
132- cleanup_path = job .disk_packed_tensors ["dir" ],
133- )
134126 finally :
135127 if packed_tensors is not None :
136128 del packed_tensors
@@ -149,8 +141,6 @@ def run_megatron_rl_job(
149141def run_megatron_sft_job (
150142 ctx : MegatronTrainContext ,
151143 job : MegatronSFTTrainingJob ,
152- * ,
153- job_path : str | None = None ,
154144) -> None :
155145 adapter_model = None
156146
@@ -257,12 +247,6 @@ def run_megatron_sft_job(
257247 lora_path = job .lora_path ,
258248 optimizer_state_path = job .optimizer_state_path ,
259249 )
260- _complete_job (
261- ctx ,
262- job_path = job_path ,
263- log_path = job .log_path ,
264- cleanup_path = job .sft_data_dir ,
265- )
266250 finally :
267251 if adapter_model is not None :
268252 del adapter_model
@@ -381,19 +365,11 @@ def _load_lora_and_optimizer(
381365 optimizer_state_path : str ,
382366) -> dict [str , torch .Tensor ]:
383367 adapter_model_path = os .path .join (lora_path , "adapter_model.safetensors" )
384- if os .path .exists (adapter_model_path ):
385- print0 (ctx .rank , "Loading adapter model from" , adapter_model_path )
386- adapter_model = load_file (adapter_model_path )
387- load_adapter_into_model (ctx .model , adapter_model , ctx .optimizer )
388- else :
389- print0 (ctx .rank , "No adapter model found at" , adapter_model_path )
390- adapter_model = {}
391- with torch .no_grad ():
392- for chunk in ctx .model :
393- for module in chunk .modules ():
394- if hasattr (module , "reset_lora_parameters" ):
395- module .reset_lora_parameters () # type: ignore[attr-defined]
396- ctx .optimizer .reload_model_params ()
368+ if not os .path .exists (adapter_model_path ):
369+ raise FileNotFoundError (f"No adapter model found at { adapter_model_path } " )
370+ print0 (ctx .rank , "Loading adapter model from" , adapter_model_path )
371+ adapter_model = load_file (adapter_model_path )
372+ load_adapter_into_model (ctx .model , adapter_model , ctx .optimizer )
397373
398374 optimizer_shard_path = os .path .join (
399375 optimizer_state_path ,
@@ -449,7 +425,7 @@ def _save_lora_and_optimizer(
449425 torch .save (ctx .optimizer .state_dict (), optimizer_shard_path )
450426
451427
452- def _complete_job (
428+ def finalize_megatron_job (
453429 ctx : MegatronTrainContext ,
454430 * ,
455431 job_path : str | None ,
@@ -462,9 +438,10 @@ def _complete_job(
462438
463439 if job_path is not None and os .path .exists (job_path ):
464440 os .remove (job_path )
441+ if os .path .exists (cleanup_path ):
442+ shutil .rmtree (cleanup_path )
465443 with open (log_path , "a+" , encoding = "utf-8" ) as log_file :
466444 log_file .write ("all done\n " )
467- shutil .rmtree (cleanup_path )
468445
469446
470447def _placeholder_attention_mask (device : torch .device ) -> torch .Tensor :
0 commit comments