
Commit 9d75910

Deduplicate local and shared training logic

1 parent 19c906b, commit 9d75910
3 files changed, 72 additions and 641 deletions

src/art/megatron/train.py

5 additions, 146 deletions
@@ -6,10 +6,8 @@
 
 import gc
 import importlib
-import json
 import math
 import os
-import shutil
 import time
 from typing import Any, Callable, cast
 
@@ -30,6 +28,7 @@
     DEFAULT_JOBS_DIR,
     DEFAULT_TRAINING_LOG_PATH,
     DEFAULT_VLLM_WAKE_LOCK_PATH,
+    MegatronTrainingJob,
 )
 from art.megatron.lora import apply_lora_adapters
 from art.megatron.offload import (
@@ -44,34 +43,14 @@
     MoeRoutingReplayController,
 )
 from art.preprocessing.pack import (
-    DiskPackedTensors,
     PackedTensors,
-    packed_tensors_from_dir,
 )
 
 safetensors_torch = importlib.import_module("safetensors.torch")
-load_file = safetensors_torch.load_file
-save_file = safetensors_torch.save_file
 
 DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507"
 
 
-class TrainingJob(BaseModel):
-    lora_path: str
-    optimizer_state_path: str
-    disk_packed_tensors: DiskPackedTensors
-    config: types.TrainConfig
-    experimental_config: dev.TrainConfig
-    moe_routing_replay_path: str | None = None
-    moe_routing_replay_strict: bool = True
-
-
-TrainingJob.model_rebuild(
-    force=True,
-    _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
-)
-
-
 class TrainingRuntime(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
@@ -586,6 +565,8 @@ def _run_service_loop(runtime: TrainingRuntime) -> None:
     offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state)
 
     while True:
+        from .shared import run_megatron_rl_job
+
         torch.distributed.barrier()  # ty: ignore[possibly-missing-attribute]
         os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
         job_names = sorted(
@@ -605,135 +586,13 @@ def _run_service_loop(runtime: TrainingRuntime) -> None:
         job_name = job_names[0]
         job_path = os.path.join(DEFAULT_JOBS_DIR, job_name)
         with open(job_path, "rb") as handle:
-            job = TrainingJob.model_validate_json(handle.read())
-        config = job.config
-        experimental_config = job.experimental_config
-
-        configure_moe_routing_replay(
-            runtime,
-            replay_bundle_path=job.moe_routing_replay_path,
-            strict=job.moe_routing_replay_strict,
-        )
+            job = MegatronTrainingJob.model_validate_json(handle.read())
 
         print0(runtime.rank, "Loaded job from", job_path)
         print0(runtime.rank, "Job:", job)
-
-        adapter_model_path = f"{job.lora_path}/adapter_model.safetensors"
-        if not os.path.exists(adapter_model_path):
-            raise FileNotFoundError(f"No adapter model found at {adapter_model_path}")
-        print0(runtime.rank, "Loading adapter model from", adapter_model_path)
-        adapter_model = load_file(adapter_model_path)
-        load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer)
-
-        optimizer_shard_path = os.path.join(
-            job.optimizer_state_path,
-            f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt",
-        )
-        if os.path.exists(optimizer_shard_path):
-            print("Loading optimizer state from", optimizer_shard_path)
-            runtime.optimizer.load_state_dict(torch.load(optimizer_shard_path))
-        else:
-            print(
-                "No optimizer state found at",
-                optimizer_shard_path,
-                "- resetting optimizer for new run",
-            )
-            clear_optimizer_state(runtime.optimizer)
-        runtime.optimizer.reload_model_params()
-
-        print0(
-            runtime.rank, "Loading packed tensors from", job.disk_packed_tensors["dir"]
-        )
-        packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors)
-        template = _clone_packed_tensors(select_indexed_inputs(packed_tensors, 0))
-        zero_template = _zero_contribution_inputs(template)
-        num_sequences = job.disk_packed_tensors["num_sequences"]
-        global_grad_accumulation_sequences = config.grad_accumulation_sequences
-        num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences)
-        for step_index in range(num_steps):
-            micro_indices = build_micro_sample_indices(
-                step_index=step_index,
-                num_sequences=num_sequences,
-                global_grad_accumulation_sequences=global_grad_accumulation_sequences,
-            )
-            micro_inputs = select_micro_inputs(
-                packed_tensors, micro_indices, zero_template
-            )
-            try:
-                step_result = run_training_step(
-                    model_chunks=runtime.model,
-                    optimizer=runtime.optimizer,
-                    learning_rate=config.learning_rate,
-                    inputs=micro_inputs,
-                    config=config,
-                    experimental_config=experimental_config,
-                    ref_logprobs=None,
-                    step_index=step_index,
-                    sample_index=micro_indices,
-                    moe_routing_replay_controller=runtime.moe_routing_replay_controller,
-                )
-            except Exception:
-                raise
-            print0(
-                runtime.rank,
-                "Correlation between old and new probabilities:",
-                step_result.probs_corr,
-            )
-
-            if runtime.rank == 0:
-                with open(
-                    DEFAULT_TRAINING_LOG_PATH, "a+", encoding="utf-8"
-                ) as log_file:
-                    log_msg = json.dumps(
-                        {
-                            "loss": step_result.reduced_loss.item(),
-                            "grad_norm": step_result.grad_norm,
-                            "probs_corr": step_result.probs_corr,
-                        }
-                    )
-                    print("Logging", log_msg)
-                    log_file.write(log_msg + "\n")
-
-        sharded_state_dict, sharded_state_manifest = collect_sharded_lora_state(
-            runtime.model,
-            adapter_model,
-        )
-        shard_path = os.path.join(
-            job.lora_path,
-            f"adapter_model-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.safetensors",
-        )
-        manifest_path = os.path.join(
-            job.lora_path,
-            f"adapter_manifest-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.json",
-        )
-        print("Saving adapter shard to", shard_path)
-        save_file(sharded_state_dict, shard_path)
-        print("Saving adapter shard manifest to", manifest_path)
-        with open(manifest_path, "w", encoding="utf-8") as manifest_file:
-            json.dump(sharded_state_manifest, manifest_file, sort_keys=True)
-
-        print("Saving optimizer shard to", optimizer_shard_path)
-        os.makedirs(job.optimizer_state_path, exist_ok=True)
-        torch.save(runtime.optimizer.state_dict(), optimizer_shard_path)
-
+        run_megatron_rl_job(runtime, job, job_path=job_path)
         offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state)
 
-        del packed_tensors
-        del template
-        del zero_template
-        del adapter_model
-        if "micro_inputs" in locals():
-            del micro_inputs
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        torch.distributed.barrier()  # ty: ignore[possibly-missing-attribute]
-        if runtime.rank == 0:
-            os.remove(job_path)
-            with open(DEFAULT_TRAINING_LOG_PATH, "a+", encoding="utf-8") as log_file:
-                log_file.write("all done\n")
-            shutil.rmtree(job.disk_packed_tensors["dir"])
-
 
 def main() -> None:
     runtime = build_training_runtime(
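Only the consumer side of the deduplication is visible in this file: MegatronTrainingJob replaces the local TrainingJob model, and the body of the service loop collapses to a single run_megatron_rl_job(runtime, job, job_path=job_path) call imported from .shared. The sketch below is a rough guess at the shape of that shared module, reconstructed from the code deleted above; the two names may live in different modules, the import paths are assumptions, and the helper body is summarized as comments rather than reproduced.

# Hedged sketch only -- not the contents of the new shared module.
# This commit confirms just two facts about it: MegatronTrainingJob is
# importable alongside DEFAULT_JOBS_DIR, and .shared exposes
# run_megatron_rl_job(runtime, job, job_path=...). Field names below mirror
# the deleted TrainingJob model; import paths and the helper body are guesses.
from pydantic import BaseModel

from art import dev, types  # assumed import path
from art.preprocessing.pack import DiskPackedTensors


class MegatronTrainingJob(BaseModel):
    lora_path: str
    optimizer_state_path: str
    disk_packed_tensors: DiskPackedTensors
    config: types.TrainConfig
    experimental_config: dev.TrainConfig
    moe_routing_replay_path: str | None = None
    moe_routing_replay_strict: bool = True


def run_megatron_rl_job(runtime, job: MegatronTrainingJob, *, job_path: str) -> None:
    """Outline of the logic deleted from _run_service_loop above."""
    # 1. Configure MoE routing replay from job.moe_routing_replay_path.
    # 2. Load adapter_model.safetensors from job.lora_path into the model and
    #    restore (or reset) this rank's optimizer shard from job.optimizer_state_path.
    # 3. Load packed tensors from job.disk_packed_tensors and run one
    #    run_training_step(...) per gradient-accumulation step, logging loss,
    #    grad_norm, and probs_corr to DEFAULT_TRAINING_LOG_PATH on rank 0.
    # 4. Save per-rank adapter shards, shard manifests, and optimizer state.
    # 5. Free buffers, barrier, then on rank 0 remove job_path, log "all done",
    #    and delete the packed-tensor directory.
    raise NotImplementedError  # placeholder; the real body lives in .shared

The lazy "from .shared import run_megatron_rl_job" inside the while loop is presumably there to avoid a circular import, since TrainingRuntime (the runtime argument) still lives in train.py.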
