@@ -6,10 +6,8 @@
 
 import gc
 import importlib
-import json
 import math
 import os
-import shutil
 import time
 from typing import Any, Callable, cast
 
@@ -30,6 +28,7 @@
     DEFAULT_JOBS_DIR,
     DEFAULT_TRAINING_LOG_PATH,
     DEFAULT_VLLM_WAKE_LOCK_PATH,
+    MegatronTrainingJob,
 )
 from art.megatron.lora import apply_lora_adapters
 from art.megatron.offload import (
@@ -44,34 +43,14 @@
     MoeRoutingReplayController,
 )
 from art.preprocessing.pack import (
-    DiskPackedTensors,
     PackedTensors,
-    packed_tensors_from_dir,
 )
 
 safetensors_torch = importlib.import_module("safetensors.torch")
-load_file = safetensors_torch.load_file
-save_file = safetensors_torch.save_file
 
 DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507"
 
 
-class TrainingJob(BaseModel):
-    lora_path: str
-    optimizer_state_path: str
-    disk_packed_tensors: DiskPackedTensors
-    config: types.TrainConfig
-    experimental_config: dev.TrainConfig
-    moe_routing_replay_path: str | None = None
-    moe_routing_replay_strict: bool = True
-
-
-TrainingJob.model_rebuild(
-    force=True,
-    _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
-)
-
-
 class TrainingRuntime(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
@@ -586,6 +565,8 @@ def _run_service_loop(runtime: TrainingRuntime) -> None:
     offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state)
 
     while True:
+        from .shared import run_megatron_rl_job
+
         torch.distributed.barrier()  # ty: ignore[possibly-missing-attribute]
         os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
         job_names = sorted(
@@ -605,135 +586,13 @@ def _run_service_loop(runtime: TrainingRuntime) -> None:
         job_name = job_names[0]
         job_path = os.path.join(DEFAULT_JOBS_DIR, job_name)
         with open(job_path, "rb") as handle:
-            job = TrainingJob.model_validate_json(handle.read())
-        config = job.config
-        experimental_config = job.experimental_config
-
-        configure_moe_routing_replay(
-            runtime,
-            replay_bundle_path=job.moe_routing_replay_path,
-            strict=job.moe_routing_replay_strict,
-        )
+            job = MegatronTrainingJob.model_validate_json(handle.read())
 
         print0(runtime.rank, "Loaded job from", job_path)
         print0(runtime.rank, "Job:", job)
-
-        adapter_model_path = f"{job.lora_path}/adapter_model.safetensors"
-        if not os.path.exists(adapter_model_path):
-            raise FileNotFoundError(f"No adapter model found at {adapter_model_path}")
-        print0(runtime.rank, "Loading adapter model from", adapter_model_path)
-        adapter_model = load_file(adapter_model_path)
-        load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer)
-
-        optimizer_shard_path = os.path.join(
-            job.optimizer_state_path,
-            f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt",
-        )
-        if os.path.exists(optimizer_shard_path):
-            print("Loading optimizer state from", optimizer_shard_path)
-            runtime.optimizer.load_state_dict(torch.load(optimizer_shard_path))
-        else:
-            print(
-                "No optimizer state found at",
-                optimizer_shard_path,
-                "- resetting optimizer for new run",
-            )
-            clear_optimizer_state(runtime.optimizer)
-        runtime.optimizer.reload_model_params()
-
-        print0(
-            runtime.rank, "Loading packed tensors from", job.disk_packed_tensors["dir"]
-        )
-        packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors)
-        template = _clone_packed_tensors(select_indexed_inputs(packed_tensors, 0))
-        zero_template = _zero_contribution_inputs(template)
-        num_sequences = job.disk_packed_tensors["num_sequences"]
-        global_grad_accumulation_sequences = config.grad_accumulation_sequences
-        num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences)
-        for step_index in range(num_steps):
-            micro_indices = build_micro_sample_indices(
-                step_index=step_index,
-                num_sequences=num_sequences,
-                global_grad_accumulation_sequences=global_grad_accumulation_sequences,
-            )
-            micro_inputs = select_micro_inputs(
-                packed_tensors, micro_indices, zero_template
-            )
-            try:
-                step_result = run_training_step(
-                    model_chunks=runtime.model,
-                    optimizer=runtime.optimizer,
-                    learning_rate=config.learning_rate,
-                    inputs=micro_inputs,
-                    config=config,
-                    experimental_config=experimental_config,
-                    ref_logprobs=None,
-                    step_index=step_index,
-                    sample_index=micro_indices,
-                    moe_routing_replay_controller=runtime.moe_routing_replay_controller,
-                )
-            except Exception:
-                raise
-            print0(
-                runtime.rank,
-                "Correlation between old and new probabilities:",
-                step_result.probs_corr,
-            )
-
-            if runtime.rank == 0:
-                with open(
-                    DEFAULT_TRAINING_LOG_PATH, "a+", encoding="utf-8"
-                ) as log_file:
-                    log_msg = json.dumps(
-                        {
-                            "loss": step_result.reduced_loss.item(),
-                            "grad_norm": step_result.grad_norm,
-                            "probs_corr": step_result.probs_corr,
-                        }
-                    )
-                    print("Logging", log_msg)
-                    log_file.write(log_msg + "\n")
-
-        sharded_state_dict, sharded_state_manifest = collect_sharded_lora_state(
-            runtime.model,
-            adapter_model,
-        )
-        shard_path = os.path.join(
-            job.lora_path,
-            f"adapter_model-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.safetensors",
-        )
-        manifest_path = os.path.join(
-            job.lora_path,
-            f"adapter_manifest-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.json",
-        )
-        print("Saving adapter shard to", shard_path)
-        save_file(sharded_state_dict, shard_path)
-        print("Saving adapter shard manifest to", manifest_path)
-        with open(manifest_path, "w", encoding="utf-8") as manifest_file:
-            json.dump(sharded_state_manifest, manifest_file, sort_keys=True)
-
-        print("Saving optimizer shard to", optimizer_shard_path)
-        os.makedirs(job.optimizer_state_path, exist_ok=True)
-        torch.save(runtime.optimizer.state_dict(), optimizer_shard_path)
-
+        run_megatron_rl_job(runtime, job, job_path=job_path)
         offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state)
 
-        del packed_tensors
-        del template
-        del zero_template
-        del adapter_model
-        if "micro_inputs" in locals():
-            del micro_inputs
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        torch.distributed.barrier()  # ty: ignore[possibly-missing-attribute]
-        if runtime.rank == 0:
-            os.remove(job_path)
-            with open(DEFAULT_TRAINING_LOG_PATH, "a+", encoding="utf-8") as log_file:
-                log_file.write("all done\n")
-            shutil.rmtree(job.disk_packed_tensors["dir"])
-
 
 def main() -> None:
     runtime = build_training_runtime(
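
Note: the deleted TrainingJob model is superseded by MegatronTrainingJob, now
imported alongside DEFAULT_JOBS_DIR. Its definition is not part of this diff;
judging from the deleted fields, it presumably looks roughly like the sketch
below (a hypothetical reconstruction, not the actual relocated code):

    # Sketch only: reconstructed from the fields of the deleted TrainingJob.
    class MegatronTrainingJob(BaseModel):
        lora_path: str
        optimizer_state_path: str
        disk_packed_tensors: DiskPackedTensors
        config: types.TrainConfig
        experimental_config: dev.TrainConfig
        moe_routing_replay_path: str | None = None
        moe_routing_replay_strict: bool = True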
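
Note: the ~130 deleted lines of job execution (MoE routing-replay setup, adapter
and optimizer-shard loading, gradient-accumulated training steps, shard saving,
and cleanup) are replaced by one call to run_megatron_rl_job from .shared
(presumably art/megatron/shared.py, not shown in this diff). Assuming the
deleted loop body moved over largely unchanged and that helpers such as
load_file, packed_tensors_from_dir, and run_training_step are available in the
new module, its shape is roughly the abridged sketch below, not the actual
implementation:

    # Abridged sketch, inferred from the deleted loop body above.
    def run_megatron_rl_job(
        runtime: TrainingRuntime, job: MegatronTrainingJob, *, job_path: str
    ) -> None:
        config = job.config

        # Replay recorded MoE routing decisions, if a bundle was provided.
        configure_moe_routing_replay(
            runtime,
            replay_bundle_path=job.moe_routing_replay_path,
            strict=job.moe_routing_replay_strict,
        )

        # Load the LoRA adapter and this rank's optimizer shard (or reset it).
        adapter_model = load_file(f"{job.lora_path}/adapter_model.safetensors")
        load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer)

        # Run gradient-accumulated training steps over the packed tensors.
        packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors)
        template = _clone_packed_tensors(select_indexed_inputs(packed_tensors, 0))
        zero_template = _zero_contribution_inputs(template)
        num_sequences = job.disk_packed_tensors["num_sequences"]
        num_steps = math.ceil(num_sequences / config.grad_accumulation_sequences)
        for step_index in range(num_steps):
            micro_indices = build_micro_sample_indices(
                step_index=step_index,
                num_sequences=num_sequences,
                global_grad_accumulation_sequences=config.grad_accumulation_sequences,
            )
            run_training_step(
                model_chunks=runtime.model,
                optimizer=runtime.optimizer,
                learning_rate=config.learning_rate,
                inputs=select_micro_inputs(packed_tensors, micro_indices, zero_template),
                config=config,
                experimental_config=job.experimental_config,
                ref_logprobs=None,
                step_index=step_index,
                sample_index=micro_indices,
                moe_routing_replay_controller=runtime.moe_routing_replay_controller,
            )

        # Persist per-rank adapter and optimizer shards, then (on rank 0) remove
        # the job file and the packed-tensor directory, as the deleted code did.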