
Commit e1e7d58

Fix Ulysses SP backward with SDPA (#13328)
* add UT for backward
* fix SDPA attention backward
1 parent a93f7f1 commit e1e7d58

2 files changed: 119 additions & 16 deletions


src/diffusers/models/attention_dispatch.py

Lines changed: 16 additions & 16 deletions
```diff
@@ -862,23 +862,23 @@ def _native_attention_backward_op(
     key.requires_grad_(True)
     value.requires_grad_(True)
 
-    query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    out = torch.nn.functional.scaled_dot_product_attention(
-        query=query_t,
-        key=key_t,
-        value=value_t,
-        attn_mask=ctx.attn_mask,
-        dropout_p=ctx.dropout_p,
-        is_causal=ctx.is_causal,
-        scale=ctx.scale,
-        enable_gqa=ctx.enable_gqa,
-    )
-    out = out.permute(0, 2, 1, 3)
+    with torch.enable_grad():
+        query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query_t,
+            key=key_t,
+            value=value_t,
+            attn_mask=ctx.attn_mask,
+            dropout_p=ctx.dropout_p,
+            is_causal=ctx.is_causal,
+            scale=ctx.scale,
+            enable_gqa=ctx.enable_gqa,
+        )
+        out = out.permute(0, 2, 1, 3)
 
-    grad_out_t = grad_out.permute(0, 2, 1, 3)
-    grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
-        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
-    )
+    grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
+        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
+    )
 
     grad_query = grad_query_t.permute(0, 2, 1, 3)
     grad_key = grad_key_t.permute(0, 2, 1, 3)
```
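The change fixes two problems in this recompute-based backward. First, PyTorch runs `autograd.Function.backward` with grad mode disabled, so the SDPA recompute built no autograd graph and `torch.autograd.grad` had nothing to differentiate; wrapping the recompute in `torch.enable_grad()` restores the graph. Second, `out` is permuted back to the `(batch, seq, heads, head_dim)` layout before differentiation, so the incoming `grad_out`, which is already in that layout, must not be permuted again. A minimal standalone sketch of the corrected pattern (illustrative shapes, not the diffusers code):

```python
import torch
import torch.nn.functional as F

B, S, H, D = 1, 16, 8, 64  # illustrative shapes: (batch, seq_len, heads, head_dim)
query = torch.randn(B, S, H, D, requires_grad=True)
key = torch.randn(B, S, H, D, requires_grad=True)
value = torch.randn(B, S, H, D, requires_grad=True)
grad_out = torch.randn(B, S, H, D)  # same layout as the permuted-back output

with torch.no_grad():  # mimics the grad-disabled context of the backward pass
    with torch.enable_grad():  # without this, `out` has no grad_fn to differentiate
        # SDPA expects (batch, heads, seq, head_dim), hence the permutes
        q_t, k_t, v_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = F.scaled_dot_product_attention(q_t, k_t, v_t).permute(0, 2, 1, 3)
    # `out` was permuted back to (B, S, H, D), so `grad_out` needs no permute
    gq, gk, gv = torch.autograd.grad(out, [q_t, k_t, v_t], grad_out)

assert gq.shape == q_t.shape  # gradients come back in the permuted (B, H, S, D) layout
```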

tests/models/testing_utils/parallelism.py

Lines changed: 103 additions & 0 deletions
```diff
@@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
         dist.destroy_process_group()
 
 
+def _context_parallel_backward_worker(
+    rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
+):
+    """Worker function for context parallel backward pass testing."""
+    try:
+        # Set up distributed environment
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(master_port)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+
+        # Get device configuration
+        device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
+        backend = device_config["backend"]
+        device_module = device_config["module"]
+
+        # Initialize process group
+        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
+
+        # Set device for this process
+        device_module.set_device(rank)
+        device = torch.device(f"{torch_device}:{rank}")
+
+        # Create model in training mode
+        model = model_class(**init_dict)
+        model.to(device)
+        model.train()
+
+        # Move inputs to device
+        inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
+
+        # Enable context parallelism
+        cp_config = ContextParallelConfig(**cp_dict)
+        model.enable_parallelism(config=cp_config)
+
+        # Run forward and backward pass
+        output = model(**inputs_on_device, return_dict=False)[0]
+        loss = output.sum()
+        loss.backward()
+
+        # Check that backward actually produced at least one valid gradient
+        grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
+        has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
+
+        # Only rank 0 reports results
+        if rank == 0:
+            return_dict["status"] = "success"
+            return_dict["has_valid_grads"] = bool(has_valid_grads)
+
+    except Exception as e:
+        if rank == 0:
+            return_dict["status"] = "error"
+            return_dict["error"] = str(e)
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+
 def _custom_mesh_worker(
     rank,
     world_size,
```
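The new worker follows the usual `mp.spawn` harness: set the rendezvous env vars, `init_process_group`, do the work, report back through a `Manager` dict (rank 0 only), and always tear the group down in `finally`. A minimal CPU-runnable sketch of that skeleton, assuming the `gloo` backend and a hard-coded port (the real test resolves the backend from `DEVICE_CONFIG` and picks a free port with `_find_free_port`):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size, master_port, return_dict):
    # Rendezvous info for the default env:// init method
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(master_port)
    try:
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        t = torch.ones(1) * rank
        dist.all_reduce(t)  # every rank now holds 0 + 1 + ... + (world_size - 1)
        if rank == 0:  # only rank 0 reports, mirroring the test worker
            return_dict["status"] = "success"
            return_dict["sum"] = t.item()
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    manager = mp.Manager()
    return_dict = manager.dict()  # shared across spawned processes
    mp.spawn(_worker, args=(2, 29511, return_dict), nprocs=2, join=True)
    assert return_dict.get("status") == "success"
```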
```diff
@@ -204,6 +262,51 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
     def test_context_parallel_batch_inputs(self, cp_type):
         self.test_context_parallel_inference(cp_type, batch_size=2)
 
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
+
+        # Move all tensors to CPU for multiprocessing
+        inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
+        cp_dict = {cp_type: world_size}
+
+        # Find a free port for distributed communication
+        master_port = _find_free_port()
+
+        # Use multiprocessing manager for cross-process communication
+        manager = mp.Manager()
+        return_dict = manager.dict()
+
+        # Spawn worker processes
+        mp.spawn(
+            _context_parallel_backward_worker,
+            args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
+            nprocs=world_size,
+            join=True,
+        )
+
+        assert return_dict.get("status") == "success", (
+            f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
+        )
+        assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
+
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_backward_batch_inputs(self, cp_type):
+        self.test_context_parallel_backward(cp_type, batch_size=2)
+
     @pytest.mark.parametrize(
         "cp_type,mesh_shape,mesh_dim_names",
         [
```
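The pass/fail criterion is deliberately weak: at least one parameter gradient must exist, and every gradient present must be finite (no NaN/Inf). A standalone sketch of that check on a toy model, mirroring the worker's `has_valid_grads` logic:

```python
import torch
import torch.nn as nn


def has_valid_grads(model: nn.Module) -> bool:
    # Collect gradients that backward actually populated
    grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
    # Require at least one gradient, and all of them finite
    return len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)


# Tiny usage example on a toy model
model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
assert has_valid_grads(model)
```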
