Skip to content

Commit 9e90c7d

Browse files
committed
Fix Megatron job finalization ordering
1 parent 3a679cb commit 9e90c7d

3 files changed (src/art/megatron/service.py, src/art/megatron/shared.py, src/art/megatron/train.py)

Lines changed: 28 additions & 38 deletions

File tree

src/art/megatron/service.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from ..vllm import get_llm, openai_server_task, run_on_workers
3030
from .jobs import (
3131
DEFAULT_JOBS_DIR,
32-
DEFAULT_TRAINING_LOG_PATH,
3332
DEFAULT_VLLM_WAKE_LOCK_PATH,
3433
MegatronTrainingJob,
3534
)
@@ -277,6 +276,8 @@ async def train(
277276
"moe_routing_replay_bundle is only supported for in-process/runtime APIs; "
278277
"MegatronService subprocess jobs must use moe_routing_replay_path."
279278
)
279+
log_dir = "/tmp/megatron_training_logs"
280+
os.makedirs(log_dir, exist_ok=True)
280281
job = MegatronTrainingJob(
281282
lora_path=lora_path,
282283
optimizer_state_path=self._optimizer_state_path,
@@ -285,6 +286,9 @@ async def train(
285286
experimental_config=_config,
286287
moe_routing_replay_path=_config.get("moe_routing_replay_path"),
287288
moe_routing_replay_strict=_config.get("moe_routing_replay_strict", True),
289+
log_path=os.path.join(
290+
log_dir, f"{datetime.datetime.now().isoformat()}.jsonl"
291+
),
288292
)
289293
job_path = os.path.join(
290294
DEFAULT_JOBS_DIR,
@@ -297,14 +301,14 @@ async def train(
297301
while True:
298302
await asyncio.sleep(0.1)
299303
try:
300-
with open(DEFAULT_TRAINING_LOG_PATH, "a+") as log_file:
304+
with open(job.log_path, "a+") as log_file:
301305
log_file.seek(0)
302306
lines = log_file.readlines()[num_lines:]
303307
for line in lines:
304308
if line := line.strip():
305309
if line == "all done":
306310
merge_lora_adapter(lora_path)
307-
os.remove(DEFAULT_TRAINING_LOG_PATH)
311+
os.remove(job.log_path)
308312
break
309313
num_lines += 1
310314
yield json.loads(line)

src/art/megatron/shared.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ def create_megatron_train_context(
5151
def run_megatron_rl_job(
5252
ctx: MegatronTrainContext,
5353
job: MegatronTrainingJob,
54-
*,
55-
job_path: str | None = None,
5654
) -> None:
5755
packed_tensors = None
5856
adapter_model = None
@@ -125,12 +123,6 @@ def run_megatron_rl_job(
125123
lora_path=job.lora_path,
126124
optimizer_state_path=job.optimizer_state_path,
127125
)
128-
_complete_job(
129-
ctx,
130-
job_path=job_path,
131-
log_path=job.log_path,
132-
cleanup_path=job.disk_packed_tensors["dir"],
133-
)
134126
finally:
135127
if packed_tensors is not None:
136128
del packed_tensors
@@ -149,8 +141,6 @@ def run_megatron_rl_job(
149141
def run_megatron_sft_job(
150142
ctx: MegatronTrainContext,
151143
job: MegatronSFTTrainingJob,
152-
*,
153-
job_path: str | None = None,
154144
) -> None:
155145
adapter_model = None
156146

@@ -257,12 +247,6 @@ def run_megatron_sft_job(
257247
lora_path=job.lora_path,
258248
optimizer_state_path=job.optimizer_state_path,
259249
)
260-
_complete_job(
261-
ctx,
262-
job_path=job_path,
263-
log_path=job.log_path,
264-
cleanup_path=job.sft_data_dir,
265-
)
266250
finally:
267251
if adapter_model is not None:
268252
del adapter_model
@@ -381,19 +365,11 @@ def _load_lora_and_optimizer(
381365
optimizer_state_path: str,
382366
) -> dict[str, torch.Tensor]:
383367
adapter_model_path = os.path.join(lora_path, "adapter_model.safetensors")
384-
if os.path.exists(adapter_model_path):
385-
print0(ctx.rank, "Loading adapter model from", adapter_model_path)
386-
adapter_model = load_file(adapter_model_path)
387-
load_adapter_into_model(ctx.model, adapter_model, ctx.optimizer)
388-
else:
389-
print0(ctx.rank, "No adapter model found at", adapter_model_path)
390-
adapter_model = {}
391-
with torch.no_grad():
392-
for chunk in ctx.model:
393-
for module in chunk.modules():
394-
if hasattr(module, "reset_lora_parameters"):
395-
module.reset_lora_parameters() # type: ignore[attr-defined]
396-
ctx.optimizer.reload_model_params()
368+
if not os.path.exists(adapter_model_path):
369+
raise FileNotFoundError(f"No adapter model found at {adapter_model_path}")
370+
print0(ctx.rank, "Loading adapter model from", adapter_model_path)
371+
adapter_model = load_file(adapter_model_path)
372+
load_adapter_into_model(ctx.model, adapter_model, ctx.optimizer)
397373

398374
optimizer_shard_path = os.path.join(
399375
optimizer_state_path,
@@ -449,7 +425,7 @@ def _save_lora_and_optimizer(
449425
torch.save(ctx.optimizer.state_dict(), optimizer_shard_path)
450426

451427

452-
def _complete_job(
428+
def finalize_megatron_job(
453429
ctx: MegatronTrainContext,
454430
*,
455431
job_path: str | None,
@@ -462,9 +438,10 @@ def _complete_job(
462438

463439
if job_path is not None and os.path.exists(job_path):
464440
os.remove(job_path)
441+
if os.path.exists(cleanup_path):
442+
shutil.rmtree(cleanup_path)
465443
with open(log_path, "a+", encoding="utf-8") as log_file:
466444
log_file.write("all done\n")
467-
shutil.rmtree(cleanup_path)
468445

469446

470447
def _placeholder_attention_mask(device: torch.device) -> torch.Tensor:

src/art/megatron/train.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from art.megatron.flex_attention import create_shared_prefix_attention_state
2727
from art.megatron.jobs import (
2828
DEFAULT_JOBS_DIR,
29-
DEFAULT_TRAINING_LOG_PATH,
3029
DEFAULT_VLLM_WAKE_LOCK_PATH,
3130
MegatronTrainingJob,
3231
)
@@ -565,7 +564,7 @@ def _run_service_loop(runtime: TrainingRuntime) -> None:
565564
offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state)
566565

567566
while True:
568-
from .shared import run_megatron_rl_job
567+
from .shared import finalize_megatron_job, run_megatron_rl_job
569568

570569
torch.distributed.barrier() # ty: ignore[possibly-missing-attribute]
571570
os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
@@ -590,8 +589,18 @@ def _run_service_loop(runtime: TrainingRuntime) -> None:
590589

591590
print0(runtime.rank, "Loaded job from", job_path)
592591
print0(runtime.rank, "Job:", job)
593-
run_megatron_rl_job(runtime, job, job_path=job_path)
594-
offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state)
592+
try:
593+
run_megatron_rl_job(runtime, job)
594+
finally:
595+
offload_to_cpu(
596+
runtime.model, runtime.optimizer, runtime.rank, offload_state
597+
)
598+
finalize_megatron_job(
599+
runtime,
600+
job_path=job_path,
601+
log_path=job.log_path,
602+
cleanup_path=job.disk_packed_tensors["dir"],
603+
)
595604

596605

597606
def main() -> None:

0 commit comments

Comments (0)