Skip to content

Commit 6f73ea2

Browse files
Add AutoSP unit and end-to-end tests
Signed-off-by: Ahan Gupta <ahangupta.96@gmail.com> Co-authored-by: Neel Dani <neeldani98@gmail.com>
1 parent 4a194c6 commit 6f73ea2

6 files changed

Lines changed: 477 additions & 9 deletions

File tree

deepspeed/compile/custom_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
from .all_to_all import all_to_all
77
from . import sp_dp_registry
88

9-
__all__ = ["all_to_all", "sp_dp_registry"]
9+
__all__ = ["all_to_all", "sp_dp_registry", "sp_compat"]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import torch
7+
from packaging.version import Version
8+
9+
10+
def _check_autosp_compatibility():
11+
# Strip the local version segment (e.g. +cu128) so CUDA builds don't sort
12+
# above the max bound when using packaging's local-version ordering rules.
13+
torch_version = Version(torch.__version__.split("+")[0])
14+
if torch_version < Version("2.6") or torch_version >= Version("2.8"):
15+
raise RuntimeError(
16+
"AutoSP requires PyTorch >= 2.6 and <= 2.7, found "
17+
f"{torch.__version__}."
18+
)
19+
20+
try:
21+
import transformers
22+
if Version(transformers.__version__) > Version("4.50.3"):
23+
raise RuntimeError(
24+
"AutoSP requires transformers <= 4.50.3, found "
25+
f"{transformers.__version__}."
26+
)
27+
except ImportError:
28+
pass # transformers not installed; skip the check

deepspeed/compile/init_sp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
from .passes.sp_compile import apply_autosp
99
from .passes.long_context_checkpointing import register_long_context_checkpointing
1010
from .custom_ops.sp_dp_registry import extract_mesh_size
11+
from .custom_ops.sp_compat import _check_autosp_compatibility
1112

1213

1314
def init_autosp(config):
15+
_check_autosp_compatibility()
1416
sp_size, dp_size = extract_mesh_size(config._param_dict)
1517
register_long_context_checkpointing()
1618

deepspeed/compile/passes/long_context_checkpointing.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,8 @@ def register_long_context_checkpointing():
9393
lines = src.split('\n')
9494

9595
# Locate the original should_ban_recomputation and the function after it.
96-
start = next(
97-
i for i, l in enumerate(lines)
98-
if l.startswith(' def should_ban_recomputation(')
99-
)
100-
end = next(
101-
i for i, l in enumerate(lines)
102-
if i > start and l.startswith(' def ')
103-
)
96+
start = next(i for i, l in enumerate(lines) if l.startswith(' def should_ban_recomputation('))
97+
end = next(i for i, l in enumerate(lines) if i > start and l.startswith(' def '))
10498

10599
# Indent the replacement to the nesting level inside solve_min_cut (4 spaces).
106100
replacement = textwrap.indent(_CUSTOM_SHOULD_BAN, ' ')
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import operator
7+
from unittest.mock import patch
8+
9+
import pytest
10+
import torch
11+
import torch.nn.functional as F
12+
13+
from deepspeed.utils.torch import required_torch_version
14+
from deepspeed.accelerator import get_accelerator
15+
from deepspeed.compile import constants
16+
17+
from unit.v1.compile.util import compare_sp_loss, create_gm_nodes, find_sym_seq_node
18+
from unit.common import DistributedTest
19+
from unit.util import bf16_required_version_check, skip_on_arch
20+
21+
pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.6),
22+
reason="AutoSP tests require PyTorch >= 2.6")
23+
24+
# Fixed sp_size injected into mocks.
25+
_SP_SIZE = 2
26+
27+
28+
class TestAutoSPCompile(DistributedTest):
29+
world_size = 4
30+
non_daemonic_procs = True
31+
32+
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float32])
33+
@pytest.mark.parametrize('zero_stage', [0, 1])
34+
@pytest.mark.parametrize('sp_size', [2, 4])
35+
def test(self, zero_stage, dtype, sp_size):
36+
if dtype == torch.bfloat16:
37+
skip_on_arch(min_arch=8)
38+
if dtype == torch.bfloat16 and not bf16_required_version_check():
39+
pytest.skip(
40+
"DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly"
41+
)
42+
if get_accelerator().device_name() == "cpu":
43+
pytest.skip("CPU does not support this test yet")
44+
45+
dp_size = self.world_size // sp_size
46+
47+
config_dict = {
48+
"train_micro_batch_size_per_gpu": 1,
49+
"train_batch_size": dp_size,
50+
"steps_per_print": 1,
51+
"optimizer": {
52+
"type": "Adam",
53+
"params": {
54+
"lr": 1e-4
55+
}
56+
},
57+
"zero_optimization": {
58+
"stage": zero_stage,
59+
},
60+
"compile": {
61+
"deepcompile": True,
62+
"passes": ["autosp"]
63+
},
64+
"sequence_parallel_size": sp_size,
65+
"gradient_clipping": 1.0,
66+
}
67+
68+
if dtype == torch.bfloat16:
69+
config_dict["bf16"] = {"enabled": True}
70+
71+
compare_sp_loss(self, config_dict, sp_size)
72+
73+
74+
# Plain pytest classes — no distributed runtime needed because these functions
75+
# perform pure IR-level graph rewrites; sp_size and get_rank are mocked.
76+
77+
78+
class TestSDPANodesCompile:
79+
80+
@pytest.mark.parametrize('seq_len', [64, 128, 256])
81+
def test(self, seq_len):
82+
from deepspeed.compile.util import get_sdpa_nodes
83+
84+
gm, _ = create_gm_nodes(seq_len=seq_len)
85+
sdpa_nodes = get_sdpa_nodes(gm)
86+
87+
assert len(sdpa_nodes) >= 1, f"Expected at least 1 SDPA node, got {len(sdpa_nodes)}"
88+
for node in sdpa_nodes:
89+
assert node.target == F.scaled_dot_product_attention
90+
91+
92+
class TestInputIdCompile:
93+
94+
@pytest.mark.parametrize('seq_len', [64, 128, 256])
95+
def test(self, seq_len):
96+
from deepspeed.compile.util import get_input_id_node
97+
98+
gm, _ = create_gm_nodes(seq_len=seq_len)
99+
node = get_input_id_node(gm)
100+
101+
assert node.op == "placeholder"
102+
tensor_dict = node.meta.get("tensor_dict", {})
103+
assert tensor_dict.get("tag") == constants.AUTOSP_INPUT_ID_KEY
104+
105+
106+
class TestLabelIdCompile:
107+
108+
@pytest.mark.parametrize('seq_len', [64, 128, 256])
109+
def test(self, seq_len):
110+
from deepspeed.compile.util import get_label_id_node
111+
112+
gm, _ = create_gm_nodes(seq_len=seq_len)
113+
node = get_label_id_node(gm)
114+
115+
assert node.op == "placeholder"
116+
tensor_dict = node.meta.get("tensor_dict", {})
117+
assert tensor_dict.get("tag") == constants.AUTOSP_LABEL_ID_KEY
118+
119+
120+
class TestPositionIdCompile:
121+
122+
@pytest.mark.parametrize('seq_len', [64, 128, 256])
123+
def test(self, seq_len):
124+
from deepspeed.compile.util import get_position_id_node
125+
126+
gm, _ = create_gm_nodes(seq_len=seq_len)
127+
node = get_position_id_node(gm)
128+
129+
assert node is not None, "position_id node not found in graph"
130+
assert node.op == "placeholder"
131+
tensor_dict = node.meta.get("tensor_dict", {})
132+
assert tensor_dict.get("tag") == constants.AUTOSP_POSITION_ID_KEY
133+
134+
135+
class TestShardOffsetsCompile:
136+
137+
@pytest.mark.parametrize('seq_len', [64, 128, 256])
138+
def test(self, seq_len):
139+
import deepspeed.comm as _dist
140+
from deepspeed.compile.custom_ops import sp_dp_registry as _registry
141+
from deepspeed.compile.util import create_shard_offsets
142+
143+
gm, _ = create_gm_nodes(seq_len=seq_len)
144+
sym_seq_node = find_sym_seq_node(gm)
145+
assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph"
146+
147+
with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \
148+
patch.object(_dist, 'get_rank', return_value=0):
149+
start_node, end_node = create_shard_offsets(gm, sym_seq_node)
150+
151+
# create_shard_offsets emits: chunk = seq // sp_size; start = rank * chunk; end = start + chunk.
152+
# Verify the three-node chain has the right operators and wiring.
153+
chunk_size_node = start_node.args[1] # start = rank * chunk → chunk is arg[1]
154+
155+
assert chunk_size_node.target == operator.floordiv
156+
assert chunk_size_node.args[0] is sym_seq_node
157+
assert chunk_size_node.args[1] == _SP_SIZE
158+
159+
assert start_node.target == operator.mul
160+
assert start_node.args[0] == 0 # rank 0 baked in at transform time
161+
assert start_node.args[1] is chunk_size_node
162+
163+
assert end_node.target == operator.add
164+
assert end_node.args[0] is start_node
165+
assert end_node.args[1] is chunk_size_node
166+
167+
168+
class TestSymSliceCompile:
169+
170+
@pytest.mark.parametrize('seq_len', [64, 128, 256])
171+
def test(self, seq_len):
172+
import deepspeed.comm as _dist
173+
from deepspeed.compile.custom_ops import sp_dp_registry as _registry
174+
from deepspeed.compile.util import create_symbolic_slice_indices
175+
176+
gm, _ = create_gm_nodes(seq_len=seq_len)
177+
sym_seq_node = find_sym_seq_node(gm)
178+
assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph"
179+
180+
with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \
181+
patch.object(_dist, 'get_rank', return_value=0):
182+
slice_all, slice_range = create_symbolic_slice_indices(gm, sym_seq_node)
183+
184+
# slice_all = slice(None, None, None) — selects the batch dimension unchanged
185+
assert slice_all.target == slice
186+
assert slice_all.args == (None, None, None)
187+
188+
# slice_range selects [start, end) along the sequence dim, where start and
189+
# end come from create_shard_offsets (mul and add nodes respectively).
190+
assert slice_range.target == slice
191+
start_arg, end_arg, step_arg = slice_range.args
192+
assert step_arg is None
193+
194+
# start = rank * chunk → verify the full shard-offset wiring
195+
chunk_size_node = start_arg.args[1]
196+
assert start_arg.target == operator.mul
197+
assert start_arg.args[0] == 0 # rank 0 baked in at transform time
198+
assert chunk_size_node.target == operator.floordiv
199+
assert chunk_size_node.args[0] is sym_seq_node
200+
assert chunk_size_node.args[1] == _SP_SIZE
201+
202+
# end = start + chunk
203+
assert end_arg.target == operator.add
204+
assert end_arg.args[0] is start_arg
205+
assert end_arg.args[1] is chunk_size_node
206+
207+
208+
class TestShardTensorCompile:
209+
210+
@pytest.mark.parametrize('seq_len', [64, 128, 256])
211+
def test(self, seq_len):
212+
import deepspeed.comm as _dist
213+
from deepspeed.compile.custom_ops import sp_dp_registry as _registry
214+
from deepspeed.compile.util import shard_tensor_node, get_input_id_node
215+
216+
gm, _ = create_gm_nodes(seq_len=seq_len)
217+
input_ids_node = get_input_id_node(gm)
218+
original_users = set(input_ids_node.users.keys())
219+
assert len(original_users) > 0, "input_ids_node must have users before sharding"
220+
221+
with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \
222+
patch.object(_dist, 'get_rank', return_value=0):
223+
shard_tensor_node(gm, input_ids_node)
224+
225+
getitem_nodes = [n for n in gm.graph.nodes if n.target == operator.getitem and n.args[0] is input_ids_node]
226+
assert len(getitem_nodes) == 1, f"Expected 1 slice node after sharding, got {len(getitem_nodes)}"
227+
sliced_node = getitem_nodes[0]
228+
229+
# After sharding, the raw node must only feed the slice; all downstream
230+
# consumers are rewired to sliced_node by replace_node_users.
231+
assert set(input_ids_node.users.keys()) == {sliced_node}
232+
233+
for user in original_users:
234+
assert input_ids_node not in user.all_input_nodes, \
235+
f"User '{user.name}' still references the unsharded input_ids_node"
236+
assert sliced_node in user.all_input_nodes, \
237+
f"User '{user.name}' does not reference the sliced node"

0 commit comments

Comments
 (0)