
Commit aad4809

Merge branch 'master' into master
2 parents: b2e17ab + 3bdebc0

25 files changed: 1474 additions & 80 deletions

.github/workflows/aws-torch-latest-full.yml

Lines changed: 30 additions & 11 deletions
@@ -2,13 +2,13 @@
 # DeepSpeed CI - AWS L40S GPU Full Tests (PyTorch Latest)
 #
 # Runs the full DeepSpeed unit test suite on AWS self-hosted runners.
-# Uses 4x NVIDIA L40S GPUs on g6e.12xlarge instances.
+# Prefers 4x NVIDIA L40S GPUs on g6e.12xlarge instances, with AWS-side
+# fallback to 8x A100 nodes when L40S capacity is unavailable.
 #
 # This workflow runs:
 # - Parallel tests with pytest-xdist (-n 8)
 # - Sequential tests marked with @pytest.mark.sequential
-#
-# Nightly schedule: skips if no new commits since last successful run.
+# - Nightly schedule: skips if no new commits since last successful run
 ################################################################################

 name: aws-torch-latest-full
@@ -26,7 +26,6 @@ jobs:
   check-changes:
     name: Check for new commits
     runs-on: ubuntu-latest
-    # Only check on schedule; workflow_dispatch always runs
    if: github.event_name == 'schedule'
    outputs:
      has_changes: ${{ steps.check.outputs.has_changes }}
@@ -38,28 +37,26 @@ jobs:
        run: |
          default_branch="${{ github.event.repository.default_branch }}"

-          # Get the HEAD SHA of the last successful run of this workflow
          last_sha=$(gh api \
            "repos/${{ github.repository }}/actions/workflows/aws-torch-latest-full.yml/runs?status=success&branch=${default_branch}&per_page=1" \
            --jq '.workflow_runs[0].head_sha // empty')

          current_sha="${{ github.sha }}"

          if [ -z "$last_sha" ]; then
-            echo "No previous successful run found running tests"
+            echo "No previous successful run found - running tests"
            echo "has_changes=true" >> "$GITHUB_OUTPUT"
          elif [ "$last_sha" = "$current_sha" ]; then
-            echo "No new commits since last successful run ($last_sha) skipping"
+            echo "No new commits since last successful run ($last_sha) - skipping"
            echo "has_changes=false" >> "$GITHUB_OUTPUT"
          else
-            echo "New commits detected: $last_sha -> $current_sha running tests"
+            echo "New commits detected: $last_sha -> $current_sha - running tests"
            echo "has_changes=true" >> "$GITHUB_OUTPUT"
          fi

  unit-tests:
    name: Unit Tests (Full)
    needs: [check-changes]
-    # Run if: (a) workflow_dispatch, or (b) schedule with new commits
    if: |
      always() &&
      (github.event_name == 'workflow_dispatch' || needs.check-changes.outputs.has_changes == 'true')
@@ -134,8 +131,30 @@ jobs:
          echo "CUTLASS_PATH: $CUTLASS_PATH"
          ls -la $CUTLASS_PATH/include/ | head -5

+      - name: Detect GPU architecture
+        run: |
+          python - <<'PY'
+          import os
+          import torch
+
+          torch.cuda.init()
+          major, minor = torch.cuda.get_device_capability(0)
+          arch = f"{major}.{minor}"
+          gpu_count = torch.cuda.device_count()
+          gpu_name = torch.cuda.get_device_name(0)
+
+          with open(os.environ["GITHUB_ENV"], "a", encoding="utf-8") as env_file:
+              env_file.write(f"TORCH_CUDA_ARCH_LIST={arch}\n")
+              env_file.write(f"GPU_COUNT={gpu_count}\n")
+
+          print(f"Detected GPU: {gpu_name}")
+          print(f"Detected compute capability: {arch}")
+          print(f"Detected GPU count: {gpu_count}")
+          PY
+
      - name: Install DeepSpeed
        run: |
+          echo "Using TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST"
          # Initialize CUDA before install so setup.py can detect NCCL version
          python -c "import torch; torch.cuda.init(); print(f'NCCL version: {torch.cuda.nccl.version()}')"
          # Use --no-build-isolation so setup.py can access pre-installed PyTorch
@@ -148,7 +167,7 @@ jobs:

      - name: Unit tests (parallel)
        run: |
-          export TORCH_CUDA_ARCH_LIST="8.9"
+          echo "Running parallel tests with TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST on $GPU_COUNT GPUs"
          cd tests
          # Skip tests requiring unavailable hardware or known issues:
          # - nvme checkpointing: no nvme device
@@ -166,7 +185,7 @@ jobs:

      - name: Unit tests (sequential)
        run: |
-          export TORCH_CUDA_ARCH_LIST="8.9"
+          echo "Running sequential tests with TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST on $GPU_COUNT GPUs"
          cd tests
          rm -rf /mnt/aio/pytest
          pytest --instafail --timeout 600 --forked -m 'sequential' --basetemp=/mnt/aio/pytest unit/ \

AGENTS.md

Lines changed: 1 addition & 2 deletions
@@ -7,11 +7,10 @@

 - All commits MUST have a `Signed-off-by` line (use `--signoff`). Get the name and email from `git config user.name` / `git config user.email`.
 - Formatting: yapf (column_limit=119, `.style.yapf`) + flake8 (`.flake8`).
-- Always verify changed files pass pre-commit checks before committing. Config: `.pre-commit-config.yaml`.
+- Always verify changed files pass pre-commit checks before committing: `pre-commit run --files <changed_files>`. Only check modified files, not the entire codebase. Config: `.pre-commit-config.yaml`.
 - `check-torchdist` hook: NEVER directly import torch's distributed module. Use `import deepspeed.comm as dist` instead.
 - New files require license header:
   ```
-  # Copyright (c) Microsoft Corporation.
   # SPDX-License-Identifier: Apache-2.0
   # DeepSpeed Team
   ```

CLAUDE.md

Lines changed: 1 addition & 2 deletions
@@ -7,11 +7,10 @@

 - All commits MUST have a `Signed-off-by` line (use `--signoff`). Get the name and email from `git config user.name` / `git config user.email`.
 - Formatting: yapf (column_limit=119, `.style.yapf`) + flake8 (`.flake8`).
-- Always verify changed files pass pre-commit checks before committing. Config: `.pre-commit-config.yaml`.
+- Always verify changed files pass pre-commit checks before committing: `pre-commit run --files <changed_files>`. Only check modified files, not the entire codebase. Config: `.pre-commit-config.yaml`.
 - `check-torchdist` hook: NEVER directly import torch's distributed module. Use `import deepspeed.comm as dist` instead.
 - New files require license header:
   ```
-  # Copyright (c) Microsoft Corporation.
   # SPDX-License-Identifier: Apache-2.0
   # DeepSpeed Team
   ```

README.md

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@

 ## Latest News

+* [2026/03] DeepSpeed Team gave a tutorial at ASPLOS 2026 titled ["Building Efficient Large-Scale Model Systems with DeepSpeed: From Open-Source Foundations to Emerging Research"](https://supercomputing-system-ai-lab.github.io/events/asplos2026-llm-tutorial/index.html)
+
 * [2026/03] [Our SuperOffload work received an Honorable Mention for the ASPLOS 2026 Best Paper Award](https://dl.acm.org/doi/10.1145/3760250.3762217)

 * [2025/12] [DeepSpeed Core API updates: PyTorch-style backward and low-precision master states](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/core_api_update/README.md)

deepspeed/compile/config.py

Lines changed: 6 additions & 0 deletions
@@ -3,8 +3,11 @@

 # DeepSpeed Team

+from typing import List, Optional, Literal
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel

+PassName = Literal["z1", "z3", "autosp"]
+

 class CompileConfig(DeepSpeedConfigModel):
     """ Configure compile settings """
@@ -53,3 +56,6 @@ class CompileConfig(DeepSpeedConfigModel):

     keep_all_input_tensors: bool = False
     """ Keep real values for all input tensors in InputStorage instead of using dummy values """
+
+    passes: Optional[List[PassName]] = None
+    """ Composes different optimizations. """

deepspeed/compile/constants.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+#########################################
+# AUTOSP
+#########################################
+AUTOSP_INPUT_ID_KEY = "input_id"
+AUTOSP_LABEL_ID_KEY = "label_id"
+AUTOSP_POSITION_ID_KEY = "position_id"
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .all_to_all import all_to_all
+from . import sp_dp_registry
+
+__all__ = ["all_to_all", "sp_dp_registry", "sp_compat"]
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import deepspeed.comm as dist
+from torch.utils._sympy.functions import FloorDiv
+from .sp_dp_registry import get_group, is_setup, sp_size
+
+
+@torch.library.custom_op("autosp::all_to_all", mutates_args=())
+def all_to_all(
+    input: torch.Tensor,
+    scatter_idx: int,
+    gather_idx: int,
+    name: str,
+) -> torch.Tensor:
+    """
+    All-to-all collective for SDPA tensors [B, N, S, H].
+
+    For QKV (scatter_idx=1, gather_idx=2):
+        [B, N, S/P, H] -> [B, N/P, S, H]
+    For O (scatter_idx=2, gather_idx=1):
+        [B, N/P, S, H] -> [B, N, S/P, H]
+    """
+    assert is_setup(), 'Incorrect initialization of SP/DP mesh.'
+    B, dim1, dim2, H = input.shape
+    gid = dist.get_rank() // sp_size()
+    group = get_group(gid)
+
+    if scatter_idx == 1:
+        N, local_S = dim1, dim2
+        input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H)
+        input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
+
+        output = torch.empty_like(input_t)
+        dist.all_to_all_single(output, input_t, group=group)
+
+        output = output.permute(1, 2, 0, 3, 4).contiguous()
+        output = output.reshape(B, N // sp_size(), sp_size() * local_S, H)
+    else:
+        local_N, S = dim1, dim2
+        input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H)
+        input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
+
+        output = torch.empty_like(input_t)
+        dist.all_to_all_single(output, input_t, group=group)
+
+        output = output.permute(1, 0, 2, 3, 4).contiguous()
+        output = output.reshape(B, sp_size() * local_N, S // sp_size(), H)
+
+    return output
+
+
+@torch.library.register_fake("autosp::all_to_all")
+def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str):
+
+    def maybe_restore_sharded_dim(dim: torch.SymInt, factor: int):
+        # Torch 2.9 may keep `P * (s // P)` distinct from the original `s` during
+        # fake shape propagation. When the local dim is exactly `FloorDiv(s, P)`,
+        # restore the original symbol so downstream ops see a consistent sequence dim.
+        node = getattr(dim, "node", None)
+        if node is None:
+            return dim * factor
+
+        expr = node.expr
+        if isinstance(expr, FloorDiv) and expr.args[1] == factor:
+            hint = node.hint * factor if node.has_hint() else None
+            return node.shape_env.create_symintnode(expr.args[0], hint=hint)
+
+        return dim * factor
+
+    B, dim1, dim2, H = input.shape
+    if scatter_idx == 1:
+        return input.new_empty(B, dim1 // sp_size(), maybe_restore_sharded_dim(dim2, sp_size()), H)
+    else:
+        return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H)
+
+
+def _all_to_all_backward_setup(ctx, inputs, output):
+    _, scatter_idx, gather_idx, name = inputs
+    ctx.scatter_idx = gather_idx
+    ctx.gather_idx = scatter_idx
+    ctx.name = name + "_grad"
+
+
+def _all_to_all_backward(ctx, grad):
+    return (all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), None, None, None)
+
+
+torch.library.register_autograd("autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup)
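
As a sanity check on the shape bookkeeping in `autosp::all_to_all`, the following single-process sketch replays the `scatter_idx == 1` reshape/permute sequence with an identity copy standing in for `all_to_all_single`. The sizes are arbitrary; only the `[B, N, S/P, H] -> [B, N/P, S, H]` transform is exercised, not the actual data exchange:

```python
import torch

# Illustrative sizes: P sequence-parallel ranks, each holding S/P tokens.
B, N, S, H, P = 2, 8, 16, 4, 4
local_S = S // P

x = torch.randn(B, N, local_S, H)  # per-rank QKV shard: [B, N, S/P, H]

# Same steps as the scatter_idx == 1 branch of autosp::all_to_all:
t = x.reshape(B, P, N // P, local_S, H).permute(1, 0, 2, 3, 4).contiguous()

# all_to_all_single would exchange the leading P slices across ranks;
# a clone keeps the shape identical on a single process.
out = t.clone()

out = out.permute(1, 2, 0, 3, 4).contiguous().reshape(B, N // P, P * local_S, H)
assert out.shape == (B, N // P, S, H)  # [B, N/P, S, H]
```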
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+from packaging.version import Version
+
+
+def _check_autosp_compatibility():
+    # Strip the local version segment (e.g. +cu128) so CUDA builds don't sort
+    # above the max bound when using packaging's local-version ordering rules.
+    torch_version = Version(torch.__version__.split("+")[0])
+    if torch_version < Version("2.9"):
+        raise RuntimeError("AutoSP requires PyTorch >= 2.9, found "
+                           f"{torch.__version__}.")
+
+    try:
+        import transformers
+        if Version(transformers.__version__) > Version("4.50.3"):
+            raise RuntimeError("AutoSP requires transformers <= 4.50.3, found "
+                               f"{transformers.__version__}.")
+    except ImportError:
+        pass  # transformers not installed; skip the check
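
The local-version comment in `_check_autosp_compatibility` is the subtle part: under PEP 440 ordering, which `packaging` implements, a local version such as `+cu128` sorts immediately above its public counterpart, so an unstripped version string can slip past an upper bound. A standalone illustration:

```python
from packaging.version import Version

# A local version sorts immediately *above* its public counterpart...
assert Version("4.50.3+custom") > Version("4.50.3")

# ...so without stripping the segment, a CUDA-tagged build like "2.9.0+cu128"
# compares as strictly greater than the boundary release:
assert Version("2.9.0+cu128") > Version("2.9.0")

# Stripping the segment, as the check above does, restores equality.
assert Version("2.9.0+cu128".split("+")[0]) == Version("2.9.0")
```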
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import deepspeed.comm as dist
+
+GROUP_REGISTRY = {}  # int -> dist.ProcessGroup
+
+
+def register_groups(groups):
+    """groups: List[List[int]], e.g. [[0,1],[2,3]]"""
+    for gid, ranks in enumerate(groups):
+        if gid not in GROUP_REGISTRY:
+            GROUP_REGISTRY[gid] = dist.new_group(ranks)
+
+
+def get_group(gid: int):
+    return GROUP_REGISTRY[gid] if gid is not None else dist.get_world_group()
+
+
+def get_registry():
+    return GROUP_REGISTRY
+
+
+def is_setup():
+    return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False
+
+
+def extract_mesh_size(param_dict):
+    sp_size = param_dict.get('sequence_parallel_size', 1)
+    assert dist.get_world_size() % sp_size == 0, 'World mesh-size should be divisible by SP_SIZE'
+    dp_size = dist.get_world_size() // sp_size
+
+    return sp_size, dp_size
+
+
+def sp_size():
+    assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.'
+
+    return GROUP_REGISTRY['SP_SIZE']
+
+
+def dp_size():
+    assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly'
+
+    return GROUP_REGISTRY['DP_SIZE']
+
+
+def populate_registry(SP_SIZE, DP_SIZE):
+    """ Populate rank to SP/DP mesh index. """
+
+    if GROUP_REGISTRY.get('is_reg', False):
+        return
+
+    group_listing = []
+    offset = 0
+    for _ in range(DP_SIZE):
+        group_listing.append([i + offset for i in range(SP_SIZE)])
+        offset += SP_SIZE
+
+    register_groups(group_listing)
+
+    ## Extraneous metadata required for proper instantiation. ##
+    GROUP_REGISTRY['SP_SIZE'] = SP_SIZE
+    GROUP_REGISTRY['DP_SIZE'] = DP_SIZE
+    GROUP_REGISTRY['is_reg'] = True
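
For a concrete picture of the mesh `populate_registry` builds, here is a pure-Python sketch of the rank layout with arbitrarily chosen sizes; no distributed state is touched, so it only mirrors the `group_listing` construction and the `gid = rank // sp_size()` lookup used in `all_to_all`:

```python
# Hypothetical mesh: 8 ranks, sequence-parallel degree 4 -> data-parallel degree 2.
WORLD, SP = 8, 4
DP = WORLD // SP

# Same construction as populate_registry's group_listing loop:
groups = [[dp * SP + i for i in range(SP)] for dp in range(DP)]
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# all_to_all picks a rank's SP group the same way: gid = rank // SP_SIZE.
for rank in range(WORLD):
    assert rank in groups[rank // SP]
```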
