Skip to content

Commit 20dd6a5

Browse files
committed
Centralize Megatron dependency install
1 parent 3a679cb commit 20dd6a5

2 files changed

Lines changed: 245 additions & 6 deletions

File tree

src/art/megatron/setup.sh

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0}"
77
apt-get update
88
apt-get install -y libcudnn9-headers-cuda-12 ninja-build
99

10-
# Python dependencies are declared in pyproject.toml extras.
11-
# Keep backend + megatron together so setup does not prune runtime deps (e.g. vllm).
12-
script_dir="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
13-
repo_root="$(cd -- "${script_dir}/../../.." && pwd)"
14-
cd "${repo_root}"
15-
uv sync --extra backend --extra megatron --frozen --active
10+
# Python dependencies are installed through art_megatron_install so
11+
# downstream repos can reuse the same source of truth for versions and VCS pins.
12+
uv run python -m art_megatron_install

src/art_megatron_install.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import importlib.metadata as metadata
5+
import json
6+
import os
7+
import shutil
8+
import subprocess
9+
from dataclasses import dataclass
10+
from urllib.parse import parse_qs
11+
12+
from packaging.requirements import Requirement
13+
from packaging.version import Version
14+
15+
_APEX_ENV = {
16+
"APEX_CPP_EXT": "1",
17+
"APEX_CUDA_EXT": "1",
18+
"APEX_FAST_LAYER_NORM": "1",
19+
"APEX_PARALLEL_BUILD": "16",
20+
"NVCC_APPEND_FLAGS": "--threads 4",
21+
}
22+
_TRANSFORMER_ENGINE_ENV = {"NVTE_NO_LOCAL_VERSION": "1"}
23+
_MEGATRON_REQUIREMENTS = {
24+
"torch": Requirement("torch>=2.8.0"),
25+
"quack-kernels": Requirement("quack-kernels==0.2.5"),
26+
"apex": Requirement("apex @ git+https://github.com/NVIDIA/apex.git@25.09"),
27+
"transformer-engine": Requirement("transformer-engine==2.11.0"),
28+
"transformer-engine-cu12": Requirement("transformer-engine-cu12==2.11.0"),
29+
"transformer-engine-torch": Requirement(
30+
"transformer-engine-torch @ git+https://github.com/NVIDIA/TransformerEngine.git@v2.11#subdirectory=transformer_engine/pytorch"
31+
),
32+
"megatron-core": Requirement("megatron-core==0.16.1"),
33+
"pybind11": Requirement("pybind11>=2.13.6"),
34+
"megatron-bridge": Requirement(
35+
"megatron-bridge @ git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@75f2c5ad4afb702b57b4781a00f5291a66bcf183"
36+
),
37+
"nvidia-ml-py": Requirement("nvidia-ml-py==13.580.82"),
38+
"ml-dtypes": Requirement("ml-dtypes>=0.5.0"),
39+
}
40+
_CORE_INSTALL_ORDER = (
41+
"torch",
42+
"pybind11",
43+
"ml-dtypes",
44+
"quack-kernels",
45+
"transformer-engine",
46+
"transformer-engine-cu12",
47+
"transformer-engine-torch",
48+
"megatron-core",
49+
)
50+
51+
52+
@dataclass(frozen=True)
53+
class _DirectUrlParts:
54+
vcs: str | None
55+
url: str
56+
requested_revision: str | None
57+
subdirectory: str | None
58+
59+
60+
def _megatron_requirements() -> dict[str, Requirement]:
61+
return dict(_MEGATRON_REQUIREMENTS)
62+
63+
64+
def _format_requirement(requirement: Requirement) -> str:
65+
extras = (
66+
f"[{','.join(sorted(requirement.extras))}]"
67+
if requirement.extras
68+
else ""
69+
)
70+
if requirement.url is not None:
71+
return f"{requirement.name}{extras} @ {requirement.url}"
72+
return f"{requirement.name}{extras}{requirement.specifier}"
73+
74+
75+
def _run(command: list[str], *, env: dict[str, str] | None = None, dry_run: bool) -> None:
76+
print("+", " ".join(command))
77+
if dry_run:
78+
return
79+
subprocess.run(command, check=True, env=env)
80+
81+
82+
def _uv_executable() -> str:
83+
uv = shutil.which("uv")
84+
if uv is None:
85+
raise RuntimeError("uv executable not found on PATH")
86+
return uv
87+
88+
89+
def _parse_direct_url(url: str) -> _DirectUrlParts:
90+
raw_url = url
91+
vcs = None
92+
if raw_url.startswith("git+"):
93+
raw_url = raw_url.removeprefix("git+")
94+
vcs = "git"
95+
raw_url, _, fragment = raw_url.partition("#")
96+
url_without_revision = raw_url
97+
requested_revision = None
98+
if "@" in raw_url:
99+
url_without_revision, requested_revision = raw_url.rsplit("@", 1)
100+
fragment_parts = parse_qs(fragment)
101+
subdirectory = fragment_parts.get("subdirectory", [None])[0]
102+
return _DirectUrlParts(
103+
vcs=vcs,
104+
url=url_without_revision,
105+
requested_revision=requested_revision,
106+
subdirectory=subdirectory,
107+
)
108+
109+
110+
def _direct_url_matches(requirement: Requirement) -> bool:
111+
try:
112+
dist = metadata.distribution(requirement.name)
113+
except metadata.PackageNotFoundError:
114+
return False
115+
direct_url_text = dist.read_text("direct_url.json")
116+
if direct_url_text is None:
117+
return False
118+
installed = json.loads(direct_url_text)
119+
expected = _parse_direct_url(requirement.url or "")
120+
if installed.get("url") != expected.url:
121+
return False
122+
if installed.get("subdirectory") != expected.subdirectory:
123+
return False
124+
if expected.vcs is None:
125+
return True
126+
vcs_info = installed.get("vcs_info") or {}
127+
if vcs_info.get("vcs") != expected.vcs:
128+
return False
129+
if expected.requested_revision is None:
130+
return True
131+
return vcs_info.get("requested_revision") == expected.requested_revision
132+
133+
134+
def _specifier_matches(requirement: Requirement) -> bool:
135+
try:
136+
installed_version = metadata.version(requirement.name)
137+
except metadata.PackageNotFoundError:
138+
return False
139+
return requirement.specifier.contains(Version(installed_version), prereleases=True)
140+
141+
142+
def dependencies_in_sync() -> bool:
143+
for requirement in _megatron_requirements().values():
144+
if requirement.url is not None:
145+
if not _direct_url_matches(requirement):
146+
return False
147+
continue
148+
if not _specifier_matches(requirement):
149+
return False
150+
return True
151+
152+
153+
def install_megatron_dependencies(*, dry_run: bool = False) -> None:
154+
requirements = _megatron_requirements()
155+
156+
apex_requirement = requirements.pop("apex", None)
157+
if apex_requirement is not None:
158+
_run(
159+
[
160+
_uv_executable(),
161+
"pip",
162+
"install",
163+
"--no-build-isolation",
164+
_format_requirement(apex_requirement),
165+
],
166+
env={**os.environ, **_APEX_ENV},
167+
dry_run=dry_run,
168+
)
169+
170+
core_requirements = [
171+
_format_requirement(requirements.pop(name))
172+
for name in _CORE_INSTALL_ORDER
173+
if name in requirements
174+
]
175+
if core_requirements:
176+
override_path = "/tmp/art-megatron-transformer-engine-override.txt"
177+
if not dry_run:
178+
with open(override_path, "w", encoding="utf-8") as override_file:
179+
override_file.write("transformer-engine==2.11.0\n")
180+
_run(
181+
[
182+
_uv_executable(),
183+
"pip",
184+
"install",
185+
"--no-build-isolation",
186+
"--override",
187+
override_path,
188+
*core_requirements,
189+
],
190+
env={**os.environ, **_TRANSFORMER_ENGINE_ENV},
191+
dry_run=dry_run,
192+
)
193+
if not dry_run and os.path.exists(override_path):
194+
os.remove(override_path)
195+
196+
bridge_requirement = requirements.pop("megatron-bridge", None)
197+
if bridge_requirement is not None:
198+
_run(
199+
[
200+
_uv_executable(),
201+
"pip",
202+
"install",
203+
"--no-deps",
204+
_format_requirement(bridge_requirement),
205+
],
206+
dry_run=dry_run,
207+
)
208+
209+
_run([_uv_executable(), "pip", "uninstall", "-y", "pynvml"], dry_run=dry_run)
210+
211+
remaining_requirements = [
212+
_format_requirement(requirement) for requirement in requirements.values()
213+
]
214+
if remaining_requirements:
215+
_run(
216+
[_uv_executable(), "pip", "install", *remaining_requirements],
217+
dry_run=dry_run,
218+
)
219+
220+
221+
def main() -> None:
222+
parser = argparse.ArgumentParser(description="Install ART Megatron dependencies")
223+
parser.add_argument(
224+
"--check",
225+
action="store_true",
226+
help="Exit 0 if installed Megatron dependencies already match ART's install plan.",
227+
)
228+
parser.add_argument(
229+
"--dry-run",
230+
action="store_true",
231+
help="Print uv pip commands without executing them.",
232+
)
233+
args = parser.parse_args()
234+
235+
if args.check:
236+
raise SystemExit(0 if dependencies_in_sync() else 1)
237+
238+
install_megatron_dependencies(dry_run=args.dry_run)
239+
240+
241+
if __name__ == "__main__":
242+
main()

0 commit comments

Comments
 (0)