Commit 86ae933

Fix lint and typing regressions in shared training refactor
Parent: a1b8efc

6 files changed

Lines changed: 51 additions & 22 deletions

src/art/_backend_training.py

Lines changed: 2 additions & 6 deletions
@@ -51,9 +51,7 @@ def build_rl_train_configs(
     }
 
     if allow_training_without_logprobs is not None:
-        dev_config["allow_training_without_logprobs"] = (
-            allow_training_without_logprobs
-        )
+        dev_config["allow_training_without_logprobs"] = allow_training_without_logprobs
     if plot_tensors is not None:
         dev_config["plot_tensors"] = plot_tensors
     if truncated_importance_sampling is not None:
@@ -63,9 +61,7 @@ def build_rl_train_configs(
             scale_learning_rate_by_reward_std_dev
         )
     if logprob_calculation_chunk_size is not None:
-        dev_config["logprob_calculation_chunk_size"] = (
-            logprob_calculation_chunk_size
-        )
+        dev_config["logprob_calculation_chunk_size"] = logprob_calculation_chunk_size
     if num_trajectories_learning_rate_multiplier_power is not None:
         dev_config["num_trajectories_learning_rate_multiplier_power"] = (
             num_trajectories_learning_rate_multiplier_power
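
The two hunks above are pure reformatting, but the pattern they touch is the core of build_rl_train_configs: each optional argument is copied into dev_config only when it is not None, so unset options never mask the backend's defaults. A minimal, self-contained sketch of that pattern (the two-parameter signature is illustrative, not the function's real one):

from typing import Any


def build_dev_config(
    allow_training_without_logprobs: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
) -> dict[str, Any]:
    # Only explicitly passed options are recorded, so a None default
    # never clobbers whatever the backend would otherwise choose.
    dev_config: dict[str, Any] = {}
    if allow_training_without_logprobs is not None:
        dev_config["allow_training_without_logprobs"] = allow_training_without_logprobs
    if logprob_calculation_chunk_size is not None:
        dev_config["logprob_calculation_chunk_size"] = logprob_calculation_chunk_size
    return dev_config


assert build_dev_config() == {}
assert build_dev_config(logprob_calculation_chunk_size=512) == {
    "logprob_calculation_chunk_size": 512
}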

src/art/megatron/service.py

Lines changed: 7 additions & 2 deletions
@@ -2,6 +2,7 @@
 from dataclasses import asdict, dataclass
 import datetime
 from functools import cached_property
+import importlib
 import json
 import os
 from pathlib import Path
@@ -10,8 +11,6 @@
 from typing import Any, AsyncIterator
 
 from peft.tuners.lora.config import LoraConfig
-from safetensors import safe_open
-from safetensors.torch import load_file, save_file
 import torch
 from vllm import AsyncEngineArgs
 from vllm.lora.request import LoRARequest
@@ -31,6 +30,12 @@
     MegatronTrainingJob,
 )
 
+safetensors = importlib.import_module("safetensors")
+safetensors_torch = importlib.import_module("safetensors.torch")
+safe_open = safetensors.safe_open
+load_file = safetensors_torch.load_file
+save_file = safetensors_torch.save_file
+
 
 @dataclass
 class MegatronService:
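
Swapping the direct safetensors imports for importlib.import_module keeps runtime behavior identical while giving static analysis an Any-typed module object instead of an import it cannot resolve, which fits this commit's lint/typing theme (the exact checker being appeased is not stated in the diff). A standalone sketch of the same aliasing, with a small round-trip added to show the aliases behave like the direct imports:

import importlib
import os
import tempfile

import torch

# Resolve the module at runtime; equivalent to
#   from safetensors.torch import load_file, save_file
# but opaque (Any-typed) to a type checker lacking safetensors stubs.
safetensors_torch = importlib.import_module("safetensors.torch")
load_file = safetensors_torch.load_file
save_file = safetensors_torch.save_file

# Round-trip a tensor through a safetensors file via the aliases.
path = os.path.join(tempfile.gettempdir(), "demo.safetensors")
save_file({"weight": torch.ones(2, 2)}, path)
restored = load_file(path)
assert torch.equal(restored["weight"], torch.ones(2, 2))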

src/art/megatron/shared.py

Lines changed: 12 additions & 4 deletions
@@ -1,21 +1,25 @@
+from dataclasses import dataclass
 import gc
+import importlib
 import json
 import math
 import os
 import shutil
 import time
-from dataclasses import dataclass
 from typing import Any
 
 from megatron.core import parallel_state as ps
-from safetensors.torch import load_file, save_file
 import torch
 
 from ..loss import loss_fn, shift_tensor
 from ..preprocessing.pack import PackedTensors, packed_tensors_from_dir
 from .flex_attention import create_shared_prefix_attention_state
 from .jobs import MegatronSFTTrainingJob, MegatronTrainingJob
 
+safetensors_torch = importlib.import_module("safetensors.torch")
+load_file = safetensors_torch.load_file
+save_file = safetensors_torch.save_file
+
 
 @dataclass
 class MegatronTrainContext:
@@ -274,7 +278,9 @@ def run_megatron_sft_job(
         update_successful, grad_norm, num_zeros_in_grad = ctx.optimizer.step()
         ctx.optimizer.zero_grad()
 
-        torch.distributed.reduce(batch_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
+        torch.distributed.reduce(
+            batch_loss, dst=0, op=torch.distributed.ReduceOp.SUM
+        )
         avg_loss = batch_loss / num_trainable_tokens
 
         batch_time = time.perf_counter() - batch_start_time
@@ -289,7 +295,9 @@ def run_megatron_sft_job(
                     "loss": avg_loss.item(),
                     "learning_rate": job.learning_rates[batch_idx],
                     "grad_norm": float(grad_norm),
-                    "num_trajectories": float(batch_metadata["num_trajectories"]),
+                    "num_trajectories": float(
+                        batch_metadata["num_trajectories"]
+                    ),
                     "num_trainable_tokens": float(num_trainable_tokens),
                     "tokens_per_second": tokens_per_second,
                 }
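
The reformatted torch.distributed.reduce call is the point where each rank's summed batch_loss is accumulated onto rank 0 before being normalized by num_trainable_tokens. A minimal sketch of that reduce-then-normalize step, initialized with world_size=1 and the gloo backend so it runs as a single process; the loss and token values are made up:

import os

import torch
import torch.distributed as dist

# Single-process group purely to exercise the collective call.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Each rank contributes its local sum of per-token losses.
batch_loss = torch.tensor(12.0)
num_trainable_tokens = 48

# Sum the local losses onto rank 0; only rank 0 holds the global total.
dist.reduce(batch_loss, dst=0, op=dist.ReduceOp.SUM)

# Normalizing by the token count yields a mean per-token loss.
if dist.get_rank() == 0:
    avg_loss = batch_loss / num_trainable_tokens
    print(f"avg loss: {avg_loss.item():.4f}")

dist.destroy_process_group()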

src/art/megatron/train.py

Lines changed: 3 additions & 1 deletion
@@ -28,7 +28,9 @@
     torch.distributed.barrier()
     os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
     job_names = sorted(
-        job_name for job_name in os.listdir(DEFAULT_JOBS_DIR) if job_name.endswith(".json")
+        job_name
+        for job_name in os.listdir(DEFAULT_JOBS_DIR)
+        if job_name.endswith(".json")
     )
     if not job_names:
         time.sleep(1)
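
The reflowed generator expression sits inside a polling loop (visible from the time.sleep(1) context line) that waits for .json job files to land in DEFAULT_JOBS_DIR and handles them in sorted order. A standalone sketch of that loop, assuming a hypothetical jobs directory and plain-JSON job payloads:

import json
import os
import time

# Illustrative stand-in for DEFAULT_JOBS_DIR in src/art/megatron/train.py.
JOBS_DIR = "/tmp/megatron_jobs"


def wait_for_next_job() -> dict:
    """Poll JOBS_DIR until a .json job file appears, then load the first."""
    os.makedirs(JOBS_DIR, exist_ok=True)
    while True:
        job_names = sorted(
            job_name
            for job_name in os.listdir(JOBS_DIR)
            if job_name.endswith(".json")
        )
        if not job_names:
            time.sleep(1)  # nothing queued yet; poll again
            continue
        with open(os.path.join(JOBS_DIR, job_names[0])) as f:
            return json.load(f)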

src/art/unsloth/service.py

Lines changed: 9 additions & 4 deletions
@@ -8,7 +8,7 @@
 import os
 import subprocess
 import sys
-from typing import Any, AsyncIterator
+from typing import Any, AsyncIterator, cast
 
 from trl import GRPOTrainer
 from vllm import AsyncEngineArgs
@@ -18,6 +18,7 @@
 from .. import dev, types
 from ..dev.validate import is_dedicated_mode
 from ..local.checkpoints import get_last_checkpoint_dir
+from ..preprocessing.inputs import TrainInputs
 from ..preprocessing.pack import DiskPackedTensors
 from ..preprocessing.tokenize import SFTBatch
 from ..utils.convert_moe_lora import convert_checkpoint_if_needed
@@ -34,6 +35,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 def save_checkpoint(
     trainer: GRPOTrainer,
     output_dir: str,
@@ -558,7 +560,7 @@ async def train_sft(
 
     @cached_property
     def _state(self) -> UnslothTrainContext:
-        init_args = dict(self.config.get("init_args", {}))
+        init_args = dict(cast(dict[str, Any], self.config.get("init_args") or {}))
         checkpoint_dir = get_last_checkpoint_dir(self.output_dir)
         if checkpoint_dir:
             init_args["model_name"] = checkpoint_dir
@@ -567,8 +569,11 @@ def _state(self) -> UnslothTrainContext:
 
         return create_unsloth_train_context(
            init_args=init_args,
-            peft_args=dict(self.config.get("peft_args", {})),
-            trainer_args=dict(self.config.get("trainer_args", {})),
+            peft_args=cast(dict[str, Any], self.config.get("peft_args") or {}),
+            trainer_args=cast(
+                dict[str, Any],
+                self.config.get("trainer_args") or {},
+            ),
         )
 
     @cached_property
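
The cast(dict[str, Any], self.config.get(...) or {}) rewrite does two things at once: `or {}` also covers a key that is present but explicitly None (which the old .get("peft_args", {}) form would pass through as None), and the cast tells the type checker what the loosely typed config value actually is. A small sketch of both effects, with a hypothetical TypedDict standing in for the service config type:

from typing import Any, TypedDict, cast


class ServiceConfig(TypedDict, total=False):
    # Hypothetical shape: keys may be absent, or present but None.
    peft_args: dict[str, Any] | None


def resolve_peft_args(config: ServiceConfig) -> dict[str, Any]:
    # .get("peft_args", {}) returns None when the key exists with a
    # None value; `or {}` normalizes both cases to an empty dict, and
    # the cast gives the checker a concrete dict[str, Any] type.
    return cast(dict[str, Any], config.get("peft_args") or {})


assert resolve_peft_args({}) == {}
assert resolve_peft_args({"peft_args": None}) == {}
assert resolve_peft_args({"peft_args": {"r": 16}}) == {"r": 16}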

src/art/unsloth/shared.py

Lines changed: 18 additions & 5 deletions
@@ -73,7 +73,10 @@ def offload_to_cpu(self) -> None:
         if optimizer is not None and hasattr(optimizer, "state"):
             for param_id, state in optimizer.state.items():
                 for key, value in state.items():
-                    if not isinstance(value, torch.Tensor) or value.device.type != "cuda":
+                    if (
+                        not isinstance(value, torch.Tensor)
+                        or value.device.type != "cuda"
+                    ):
                         continue
                     buffer_key = f"opt_{id(param_id)}_{key}"
                     if (
@@ -108,9 +111,14 @@ def reload_to_gpu(self, device: str = "cuda:0") -> None:
         if optimizer is not None and hasattr(optimizer, "state"):
             for state in optimizer.state.values():
                 for key, value in state.items():
-                    if not isinstance(value, torch.Tensor) or value.device.type != "cpu":
+                    if (
+                        not isinstance(value, torch.Tensor)
+                        or value.device.type != "cpu"
+                    ):
                         continue
-                    gpu_tensor = torch.empty(value.shape, dtype=value.dtype, device=device)
+                    gpu_tensor = torch.empty(
+                        value.shape, dtype=value.dtype, device=device
+                    )
                     gpu_tensor.copy_(value, non_blocking=True)
                     state[key] = gpu_tensor
 
@@ -224,7 +232,10 @@ def create_unsloth_train_context(
         loader_cls.from_pretrained(**init_args),
     )
 
-    if hasattr(model, "peft_config") and getattr(model, "peft_config", None) is not None:
+    if (
+        hasattr(model, "peft_config")
+        and getattr(model, "peft_config", None) is not None
+    ):
         peft_model = cast(peft.peft_model.PeftModelForCausalLM, model)
     else:
         peft_model = cast(
@@ -301,7 +312,9 @@ def _precalculate_new_logprobs(
             if isinstance(value, torch.Tensor)
         },
         pixel_values=packed_tensors["pixel_values"][offset : offset + 1],
-        image_grid_thw=packed_tensors["image_grid_thw"][offset : offset + 1],
+        image_grid_thw=packed_tensors["image_grid_thw"][
+            offset : offset + 1
+        ],
         config=config,
         _config=_config,
         return_new_logprobs=True,
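
The reflowed conditions in the first two hunks guard the same optimizer-state shuttle: offload_to_cpu skips entries that are not CUDA tensors, and reload_to_gpu skips entries that are not CPU tensors before copying each into a freshly allocated tensor on the target device with non_blocking=True. A sketch of the reload half against a plain dict, using a CPU target so it runs without a GPU; reload_state_to_device is an illustrative name, not the module's API:

import torch


def reload_state_to_device(state: dict, device: str) -> None:
    # Mirror the guard in reload_to_gpu: skip non-tensors and tensors
    # that are not currently on the CPU.
    for key, value in state.items():
        if (
            not isinstance(value, torch.Tensor)
            or value.device.type != "cpu"
        ):
            continue
        # Allocate on the target device, then copy; non_blocking allows
        # the copy to overlap compute when the backend supports it.
        target = torch.empty(value.shape, dtype=value.dtype, device=device)
        target.copy_(value, non_blocking=True)
        state[key] = target


state = {"exp_avg": torch.zeros(4), "step": 10}
reload_state_to_device(state, device="cpu")
assert isinstance(state["exp_avg"], torch.Tensor)
assert state["step"] == 10  # non-tensor entries are left untouched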
