Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c2039fc
Refresh shared training refactor on top of ART main
Kovbo Mar 28, 2026
19c906b
Rename Megatron merge helper
Kovbo Mar 28, 2026
9d75910
Deduplicate local and shared training logic
Kovbo Mar 28, 2026
6d0d2ae
Fix Megatron rope theta compatibility
Kovbo Mar 28, 2026
9c474c9
Remove Megatron rope theta workaround
Kovbo Mar 28, 2026
2fa8ffb
Align Unsloth SFT weight decay defaults
Kovbo Mar 28, 2026
8cb71cc
remove apex from no-build-isolation-package
Kovbo Mar 28, 2026
3a679cb
update install script
Kovbo Mar 28, 2026
9e90c7d
Fix Megatron job finalization ordering
Kovbo Mar 30, 2026
511d72c
Share Megatron worker loop
Kovbo Mar 30, 2026
2e64da0
Default Megatron grad accumulation by DP size
Kovbo Apr 1, 2026
0cee7cf
Collapse Megatron shared API into train module
Kovbo Apr 1, 2026
911c082
Remove Megatron shared shim
Kovbo Apr 1, 2026
0fa9a2b
Collapse Unsloth shared API into train module
Kovbo Apr 1, 2026
f6cd445
Lighten Megatron orchestration imports
Kovbo Apr 1, 2026
ff28081
Merge branch 'main' of github.com:OpenPipe/ART into feat/shared-train…
Kovbo Apr 2, 2026
3116a1b
Merge branch 'feat/shared-training-code' of github.com:OpenPipe/ART i…
Kovbo Apr 2, 2026
d08f2ad
fix: normalize SFT loss by token count before backward pass
Kovbo Apr 2, 2026
21dd5a3
Revert "fix: normalize SFT loss by token count before backward pass"
Kovbo Apr 3, 2026
d68ae3d
Support Megatron SFT in local backend
Kovbo Apr 3, 2026
f8fee63
refactor: extract create_identity_lora as standalone function
Kovbo Apr 3, 2026
baac098
Fix SFT main_grad fallback in Megatron
Kovbo Apr 6, 2026
aa2fd4b
Fix ART lint and type issues
Kovbo Apr 6, 2026
497ff3c
Simplify ty-safe optimizer access
Kovbo Apr 6, 2026
2be0333
test: drop megatron sft batch unit test
Kovbo Apr 7, 2026
b322072
refactor: revert direct safetensors import in moe conversion
Kovbo Apr 7, 2026
7c5a02b
style: format megatron oracle harness
Kovbo Apr 7, 2026
82fa9d0
refactor: use direct safetensors import in routing replay
Kovbo Apr 7, 2026
40e66aa
fix: isolate megatron optimizer states and step counts
FurtherAI Apr 7, 2026
9bf7001
Add SFT oracle coverage and shared grad scheduling
FurtherAI Apr 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions scripts/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,14 @@ else
echo "Skipping git reset/clean (GIT_RESET_CLEAN is not true). Preserving synced working tree."
fi

# Install astral-uv
if ! command -v uv >/dev/null 2>&1; then
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
echo "Failed to install uv." >&2
exit 1
fi
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
# Install astral-uv (standalone version)
# Always prepend standalone install path so it takes precedence over system/conda uv
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
echo "Failed to install uv." >&2
exit 1
fi

# Update uv
uv self update

# Sync the dependencies
if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then
uv sync --all-extras
Expand Down
105 changes: 105 additions & 0 deletions src/art/_backend_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from collections.abc import Iterable
import time
from typing import Literal

from . import dev
from .metrics_taxonomy import (
average_metric_samples,
build_training_summary_metrics,
summarize_trajectory_groups,
)
from .trajectories import TrajectoryGroup
from .types import TrainConfig


def build_rl_train_configs(
    *,
    learning_rate: float,
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    mask_prob_ratio: bool = False,
    ppo: bool = False,
    precalculate_logprobs: bool = False,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    max_negative_advantage_importance_sampling_weight: float | None = None,
    kimi_k2_tau: float | None = None,
    kl_penalty_coef: float = 0.0,
    allow_training_without_logprobs: bool | None = None,
    plot_tensors: bool | None = None,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
    num_trajectories_learning_rate_multiplier_power: float | None = None,
    kl_ref_adapter_path: str | None = None,
) -> tuple[TrainConfig, dev.TrainConfig]:
    """Build the (public, dev) RL train-config pair shared across backends.

    Required knobs are always written into the dev config. Each optional
    keyword is copied in only when it is not ``None``, so omitting it leaves
    the trainer's own default in effect rather than pinning an explicit value.

    Returns:
        A ``(TrainConfig, dev.TrainConfig)`` tuple ready to hand to a trainer.
    """
    train_config = TrainConfig(
        learning_rate=learning_rate,
        kl_penalty_coef=kl_penalty_coef,
    )
    dev_config: dev.TrainConfig = {
        "advantage_balance": advantage_balance,
        "importance_sampling_level": importance_sampling_level,
        "kl_penalty_coef": kl_penalty_coef,
        "mask_prob_ratio": mask_prob_ratio,
        "ppo": ppo,
        "precalculate_logprobs": precalculate_logprobs,
        "scale_rewards": scale_rewards,
    }
    # Optional knobs, merged only when the caller supplied an explicit value.
    optional_entries = (
        ("allow_training_without_logprobs", allow_training_without_logprobs),
        ("plot_tensors", plot_tensors),
        ("truncated_importance_sampling", truncated_importance_sampling),
        (
            "scale_learning_rate_by_reward_std_dev",
            scale_learning_rate_by_reward_std_dev,
        ),
        ("logprob_calculation_chunk_size", logprob_calculation_chunk_size),
        (
            "num_trajectories_learning_rate_multiplier_power",
            num_trajectories_learning_rate_multiplier_power,
        ),
        ("epsilon", epsilon),
        ("epsilon_high", epsilon_high),
        (
            "max_negative_advantage_importance_sampling_weight",
            max_negative_advantage_importance_sampling_weight,
        ),
        ("kimi_k2_tau", kimi_k2_tau),
        ("kl_ref_adapter_path", kl_ref_adapter_path),
    )
    for key, value in optional_entries:
        if value is not None:
            # Dynamic key into a TypedDict; safe because every key above is a
            # declared dev.TrainConfig field.
            dev_config[key] = value  # type: ignore[literal-required]
    return train_config, dev_config


def aggregate_rl_training_metrics(
    *,
    training_metrics: list[dict[str, float]],
    trajectory_groups: Iterable[TrajectoryGroup],
    trainer_started: float,
) -> dict[str, float]:
    """Merge averaged per-step trainer metrics with trajectory-group summaries.

    Trainer-reported metrics take precedence: summary metrics (and the
    ``time/step_trainer_s`` wall-clock fallback) only fill keys the trainer
    did not already emit.

    Args:
        training_metrics: Raw per-step metric dicts collected during training.
        trajectory_groups: Groups used for this step; consumed once.
        trainer_started: ``time.monotonic()`` timestamp taken before training.

    Returns:
        A single flat metrics dict for logging.
    """
    groups = list(trajectory_groups)
    merged = average_metric_samples(training_metrics)
    group_summary = summarize_trajectory_groups(groups)
    # Fallback wall-clock duration, kept only if the trainer didn't report one.
    merged.setdefault("time/step_trainer_s", time.monotonic() - trainer_started)
    summary_metrics = build_training_summary_metrics(
        group_summary,
        include_trainable_groups=True,
    )
    for key, value in summary_metrics.items():
        if key not in merged:
            merged[key] = value
    return merged
87 changes: 35 additions & 52 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@
from mp_actors import close_proxy, move_to_child_process

from .. import dev
from .._backend_training import (
aggregate_rl_training_metrics,
build_rl_train_configs,
)
from ..backend import AnyTrainableModel, Backend
from ..costs import build_cost_calculator, get_model_pricing
from ..metrics_taxonomy import (
TRAIN_GRADIENT_STEPS_KEY,
average_metric_samples,
build_training_summary_metrics,
summarize_trajectory_groups,
)
Expand Down Expand Up @@ -642,45 +645,36 @@ async def train( # type: ignore[override]
if adam_params is not None:
raise ValueError("LocalBackend requires adam_params=None.")

# Build config objects from explicit kwargs
config = TrainConfig(
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"allow_training_without_logprobs": allow_training_without_logprobs,
"importance_sampling_level": importance_sampling_level,
"kl_penalty_coef": kl_penalty_coef,
"mask_prob_ratio": mask_prob_ratio,
"plot_tensors": plot_tensors,
"ppo": loss_fn == "ppo",
"precalculate_logprobs": precalculate_logprobs,
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
"scale_rewards": scale_rewards,
"logprob_calculation_chunk_size": logprob_calculation_chunk_size,
"num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
}
# Only include optional fields if they're set
if epsilon is not None:
dev_config["epsilon"] = epsilon
if epsilon_high is not None:
dev_config["epsilon_high"] = epsilon_high
if max_negative_advantage_importance_sampling_weight is not None:
dev_config["max_negative_advantage_importance_sampling_weight"] = (
max_negative_advantage_importance_sampling_weight
)
if kimi_k2_tau is not None:
dev_config["kimi_k2_tau"] = kimi_k2_tau
if truncated_importance_sampling is not None:
dev_config["truncated_importance_sampling"] = truncated_importance_sampling
if kl_ref_adapter_path is not None:
dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
elif kl_penalty_reference_step is not None:
ref_checkpoint_dir = get_step_checkpoint_dir(
resolved_kl_ref_adapter_path = kl_ref_adapter_path
if (
resolved_kl_ref_adapter_path is None
and kl_penalty_reference_step is not None
):
resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
get_model_dir(model=model, art_path=self._path),
kl_penalty_reference_step,
)
dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir
config, dev_config = build_rl_train_configs(
learning_rate=learning_rate,
advantage_balance=advantage_balance,
scale_rewards=scale_rewards,
importance_sampling_level=importance_sampling_level,
mask_prob_ratio=mask_prob_ratio,
ppo=loss_fn == "ppo",
precalculate_logprobs=precalculate_logprobs,
epsilon=epsilon,
epsilon_high=epsilon_high,
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
kimi_k2_tau=kimi_k2_tau,
kl_penalty_coef=kl_penalty_coef,
allow_training_without_logprobs=allow_training_without_logprobs,
plot_tensors=plot_tensors,
truncated_importance_sampling=truncated_importance_sampling,
scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
logprob_calculation_chunk_size=logprob_calculation_chunk_size,
num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
kl_ref_adapter_path=resolved_kl_ref_adapter_path,
)

# Collect metrics from training
training_metrics: list[dict[str, float]] = []
Expand All @@ -690,21 +684,10 @@ async def train( # type: ignore[override]
):
training_metrics.append(metrics)

# Aggregate metrics
avg_metrics = average_metric_samples(training_metrics)
summary = summarize_trajectory_groups(groups_list)
avg_metrics.setdefault(
"time/step_trainer_s", time.monotonic() - trainer_started
)
avg_metrics.update(
{
key: value
for key, value in build_training_summary_metrics(
summary,
include_trainable_groups=True,
).items()
if key not in avg_metrics
}
avg_metrics = aggregate_rl_training_metrics(
training_metrics=training_metrics,
trajectory_groups=groups_list,
trainer_started=trainer_started,
)

# Get step and checkpoint path
Expand Down
40 changes: 40 additions & 0 deletions src/art/megatron/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Literal

from pydantic import BaseModel

from .. import dev, types
from ..preprocessing.pack import DiskPackedTensors
from .routing_replay import MoeRoutingReplayBundle

# Default filesystem locations used when a job does not override them.
# NOTE(review): /tmp is node-local and wiped on reboot — presumably intentional
# for scratch state; confirm nothing durable is expected to survive here.
DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl"
DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs"
DEFAULT_VLLM_WAKE_LOCK_PATH = "/tmp/megatron_vllm_waking"


class MegatronTrainingJob(BaseModel):
    """Description of one Megatron RL training job.

    A pydantic model, so it can be serialized for handoff between the
    orchestrating backend and the Megatron training worker.
    """

    # Path of the LoRA adapter being trained.
    lora_path: str
    # Path where the optimizer state for this adapter is persisted.
    optimizer_state_path: str
    # On-disk packed training tensors for this step.
    disk_packed_tensors: DiskPackedTensors
    # Public training hyperparameters (learning rate, KL penalty, ...).
    config: types.TrainConfig
    # Experimental/dev training options.
    experimental_config: dev.TrainConfig
    # Optional path to a saved MoE routing-replay bundle; None disables replay.
    moe_routing_replay_path: str | None = None
    # When True, replay mismatches are treated as errors rather than ignored.
    # NOTE(review): semantics inferred from the name — confirm against the worker.
    moe_routing_replay_strict: bool = True
    # JSONL file the worker appends training progress to.
    log_path: str = DEFAULT_TRAINING_LOG_PATH


# Force re-resolution of forward references on the job model.
# NOTE(review): MoeRoutingReplayBundle is injected into the namespace, but no
# declared field on MegatronTrainingJob references that type — this looks like
# a leftover from an earlier schema. Confirm (e.g. via DiskPackedTensors'
# annotations) before removing.
MegatronTrainingJob.model_rebuild(
    force=True,
    _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
)


class MegatronSFTTrainingJob(BaseModel):
    """Description of one Megatron supervised fine-tuning (SFT) job."""

    # Literal discriminator so SFT jobs can be told apart from RL jobs
    # when deserializing a mixed job stream.
    job_type: Literal["sft"] = "sft"
    # Path of the LoRA adapter being trained.
    lora_path: str
    # Path where the optimizer state for this adapter is persisted.
    optimizer_state_path: str
    # Directory containing the prepared SFT training data.
    sft_data_dir: str
    # Number of batches to consume from sft_data_dir.
    num_batches: int
    # Per-batch learning-rate schedule.
    # NOTE(review): presumably len(learning_rates) == num_batches — confirm.
    learning_rates: list[float]
    # AdamW-style weight decay; 0.0 disables it.
    weight_decay: float = 0.0
    # Gradient-norm clipping threshold.
    max_grad_norm: float = 1.0
    # JSONL file the worker appends training progress to.
    log_path: str = DEFAULT_TRAINING_LOG_PATH
6 changes: 5 additions & 1 deletion src/art/megatron/routing_replay.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
import importlib
import json
from pathlib import Path
import re
Expand All @@ -13,9 +14,12 @@
)
from megatron.core.transformer.moe.moe_utils import permute, sort_chunks_by_idxs
from pydantic import BaseModel, ConfigDict, model_validator
from safetensors.torch import load_file, save_file
import torch

safetensors_torch = importlib.import_module("safetensors.torch")
load_file = safetensors_torch.load_file
save_file = safetensors_torch.save_file

ROUTER_NAME_TOKEN = ".mlp.router"
ROUTER_KEY_FORMAT_VERSION = "moe_routing_replay_v1"
GLOBAL_TOKEN_UIDS_KEY = "global_token_uids"
Expand Down
15 changes: 15 additions & 0 deletions src/art/megatron/runtime_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os


def _set_cache_dir(env_var: str, default_path: str) -> None:
    """Point *env_var* at an existing cache directory.

    An unset or empty variable is populated with the user-expanded
    *default_path*; an existing non-empty value is respected as-is. The
    resulting directory is created if missing.
    """
    target = os.environ.get(env_var)
    if not target:
        target = os.path.expanduser(default_path)
        os.environ[env_var] = target
    os.makedirs(target, exist_ok=True)


def configure_megatron_runtime_env() -> None:
    """Set the environment variables Megatron workers rely on.

    Must run before CUDA / inductor / Triton initialize so the settings
    take effect; cache-dir variables already set by the caller are kept.
    """
    env = os.environ
    env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    # NOTE(review): hard-codes Hopper (sm_90) — confirm no other GPU archs
    # are expected to run this path.
    env["TORCH_CUDA_ARCH_LIST"] = "9.0"
    for var, default in (
        ("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor"),
        ("TRITON_CACHE_DIR", "~/.triton/cache"),
    ):
        _set_cache_dir(var, default)
Loading
Loading