Skip to content

Commit 19c906b

Browse files
committed
Rename Megatron merge helper
1 parent c2039fc commit 19c906b

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

src/art/megatron/service.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
33 33
DEFAULT_VLLM_WAKE_LOCK_PATH,
34 34
MegatronTrainingJob,
35 35
)
36-
from .shared import merge_sharded_lora_adapter
36+
from .shared import merge_lora_adapter
37 37

38 38
safetensors = importlib.import_module("safetensors")
39 39
safe_open = safetensors.safe_open
@@ -303,7 +303,7 @@ async def train(
303 303
for line in lines:
304 304
if line := line.strip():
305 305
if line == "all done":
306-
merge_sharded_lora_adapter(lora_path)
306+
merge_lora_adapter(lora_path)
307 307
os.remove(DEFAULT_TRAINING_LOG_PATH)
308 308
break
309 309
num_lines += 1

src/art/megatron/shared.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -270,7 +270,7 @@ def run_megatron_sft_job(
270 270
torch.cuda.empty_cache()
271 271

272 272

273-
def merge_sharded_lora_adapter(lora_path: str) -> None:
273+
def merge_lora_adapter(lora_path: str) -> None:
274 274
base_dir = Path(lora_path)
275 275
shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors"))
276 276
if not shard_filenames:

0 commit comments

Comments (0)