
Commit 6e486af

Fix activation checkpointing crash by using use_reentrant=False
Switches gradient_checkpointing_enable() to non-reentrant checkpointing, which preserves the dropout RNG state during recomputation and resolves the SystemError raised during loss.backward(). Issue: #3774
1 parent 3c8fc59 commit 6e486af
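As a hedged illustration of the mechanism described in the commit message (not part of the commit itself), the snippet below checkpoints a small dropout-containing segment with torch.utils.checkpoint and use_reentrant=False; the layer sizes and tensor shapes are invented for the example.

import torch
from torch.utils.checkpoint import checkpoint

# A tiny segment whose output depends on dropout randomness.
linear = torch.nn.Linear(16, 16)
drop = torch.nn.Dropout(p=0.5)

def segment(x):
    return drop(linear(x))

x = torch.randn(4, 16, requires_grad=True)

# Non-reentrant checkpointing saves the RNG state at forward time and restores
# it when the segment is re-run during backward, so the dropout mask used in
# the recomputation matches the one from the original forward pass.
out = checkpoint(segment, x, use_reentrant=False)
out.sum().backward()  # segment() is recomputed here with the preserved RNG state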

1 file changed: beginner_source/mosaic_memory_profiling_tutorial.py (3 additions & 1 deletion)
@@ -309,7 +309,9 @@ def run_training_ac(
     model = GPT2LMHeadModel.from_pretrained("gpt2")
 
     if activation_checkpointing:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(
+            gradient_checkpointing_kwargs={"use_reentrant": False}
+        )
         print("Activation checkpointing is ENABLED")
     else:
         print("Activation checkpointing is DISABLED")
