[DCP] Add DefaultStager example to distributed async checkpoint recipe (#3711)
Fixes #3710
## Description
This PR updates the `distributed_async_checkpoint_recipe` to include the
`DefaultStager` functionality introduced in PyTorch 2.9.
**Motivation:**
In large-scale training, even with standard `async_save`, the initial memory copy (the staging phase, GPU -> CPU) runs on the main thread and blocks the training loop. This PR introduces `DefaultStager`, which offloads this copy to a background thread, enabling **full computation-communication overlap**.
**Key Changes:**
1. **New Section**: Added "Fully Asynchronous Staging with
DefaultStager" to the recipe.
2. **Version Note**: Added `.. versionadded:: 2.9` to indicate version
requirements.
3. **Advanced Example**: Provided a code example demonstrating how to overlap the D2H copy with the **entire forward and backward pass** (see the sketch after this list):
   * Check `staging_completion` *after* backward but *before* `optimizer.step()` to ensure data consistency while maximizing parallel execution.
   * Check `upload_completion` before the next save to manage memory backpressure.
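
The overlap pattern in one place, as a minimal hedged sketch (single process, toy `nn.Linear`; `staging_completion`/`upload_completion` are the `AsyncSaveResponse` futures described above; the recipe's full FSDP example is in the diff below, and `DefaultStager`'s pinned-memory defaults assume a CUDA build):

```python
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.staging import DefaultStager

# Single-process setup so the sketch can run standalone (gloo backend).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 8).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

response = None  # AsyncSaveResponse from the previous save, if any
for step in range(4):
    loss = model(torch.randn(2, 8, device=device)).sum()
    loss.backward()

    # After backward, before optimizer.step(): wait for the previous save
    # to finish snapshotting the weights we are about to mutate.
    if response is not None:
        response.staging_completion.result()

    optimizer.step()
    optimizer.zero_grad()

    # Before the next save: bound memory by keeping at most one upload in flight.
    if response is not None:
        response.upload_completion.result()

    response = dcp.async_save(
        {"model": model.state_dict()},
        checkpoint_id=f"checkpoint_step{step}",
        async_stager=DefaultStager(),
    )

response.upload_completion.result()
dist.destroy_process_group()
```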
Co-authored-by: Alanna Burke <burkealanna@meta.com>
Co-authored-by: Alanna Burke <aburke626@gmail.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
recipes_source/distributed_async_checkpoint_recipe.rst

@@ -6,4 +6,4 @@
-Checkpointing is often a bottle-neck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
+Checkpointing is often a bottleneck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
 One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example
 from the `Getting Started with Distributed Checkpoint Tutorial <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__
 to show how this can be integrated quite easily with ``torch.distributed.checkpoint.async_save``.
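For reviewers skimming the diff: the integration the text refers to is essentially a one-line swap from `dcp.save` to `dcp.async_save`. A hedged sketch, with `state_dict` assumed to be built as in the getting-started recipe:

```python
import torch.distributed.checkpoint as dcp

# dcp.async_save takes the same state_dict as dcp.save, but returns
# immediately with a Future that resolves when the write completes.
future = dcp.async_save(state_dict, checkpoint_id="checkpoint_step100")
# ... training continues here while the checkpoint is persisted ...
future.result()  # block only when durability is required
```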
@@ -26,7 +26,7 @@ to show how this can be integrated quite easily with ``torch.distributed.checkpo
 
 Asynchronous Checkpointing Overview
 ------------------------------------
-Before getting started with Asynchronous Checkpointing, it's important to understand it's differences and limitations as compared to synchronous checkpointing.
+Before getting started with Asynchronous Checkpointing, it's important to understand its differences and limitations as compared to synchronous checkpointing.
 Specifically:
 
 * Memory requirements - Asynchronous checkpointing works by first copying models into internal CPU-buffers.
@@ -35,8 +35,8 @@ Specifically:
   the memory constraints of their systems. Specifically, pinned memory implies the usage of ``page-lock`` memory, which can be scarce as compared to
   ``pageable`` memory.
 
-* Checkpoint Management - Since checkpointing is asynchronous, it is up to the user to manage concurrently run checkpoints. In general, users can
-  employ their own management strategies by handling the future object returned form ``async_save``. For most users, we recommend limiting
+* Checkpoint Management - Since checkpointing is asynchronous, it is up to the user to manage concurrently run checkpoints.
+  In general, users can employ their own management strategies by handling the future object returned from ``async_save``. For most users, we recommend limiting
   checkpoints to one asynchronous request at a time, avoiding additional memory pressure per request.
 
 
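The "one request at a time" recommendation above can be implemented by holding on to the returned future; a short sketch under assumed names (`train_step`, `state_dict`, `total_steps`, `save_every`):

```python
import torch.distributed.checkpoint as dcp

pending = None
for step in range(total_steps):
    train_step()  # assumed: forward/backward/optimizer update
    if step % save_every == 0:
        if pending is not None:
            # Wait for the previous checkpoint before queueing another,
            # so at most one set of CPU staging buffers is alive at once.
            pending.result()
        pending = dcp.async_save(state_dict, checkpoint_id=f"checkpoint_step{step}")
```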
@@ -72,7 +72,7 @@ Specifically:
             self.optimizer = optimizer
 
         def state_dict(self):
-            # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+            # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
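For context, the comment in this hunk refers to the recipe's `get_state_dict` call. A sketch of the full `Stateful` wrapper as it appears in the recipe family (paraphrased for illustration, not part of this diff):

```python
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

class AppState(Stateful):
    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # get_state_dict manages FSDP FQNs and picks the sharded state dict type by default
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        # set_state_dict mirrors get_state_dict, restoring both objects in one call
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
```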
@@ -279,6 +279,154 @@ checkpoint requests users can take advantage of direct memory access to speed up
     )
 
 
+Fully Asynchronous Staging with DefaultStager
+---------------------------------------------
+
+.. versionadded:: 2.9
+    The ``async_stager`` argument and ``DefaultStager`` class were introduced in PyTorch 2.9.
+
+While ``async_save`` handles the disk write asynchronously, the process of copying data from GPU to CPU (known as "staging") typically happens on the main thread. Even with pinned memory, this Device-to-Host (D2H) copy can block the training loop for large models.
+
+To achieve maximum overlap between computation and checkpointing, we can use the ``DefaultStager``. This component offloads the state dictionary creation and the D2H copy to a background thread.
+
+.. note::
+    Using ``AsyncStager`` introduces a background thread that consumes CPU resources. Ensure your environment has sufficient CPU cores to handle this without impacting the main training process.
+
+.. code-block:: python
+
+    import os
+
+    import torch
+    import torch.distributed as dist
+    import torch.distributed.checkpoint as dcp
+    import torch.multiprocessing as mp
+    import torch.nn as nn
+
+    from torch.distributed.fsdp import fully_shard
+    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+    from torch.distributed.checkpoint.stateful import Stateful
+    from torch.distributed.checkpoint.staging import DefaultStager
+    from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
+
+    CHECKPOINT_DIR = "checkpoint"
+
+
+    class AppState(Stateful):
+        """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
+        with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
+        dcp.save/load APIs.
+
+        Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
+        and optimizer.
+        """
+
+        def __init__(self, model, optimizer=None):
+            self.model = model
+            self.optimizer = optimizer
+
+        def state_dict(self):
+            # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT

...

+        # Pass the DefaultStager to enable fully asynchronous staging.
+        # This offloads the state_dict creation and GPU-to-CPU copy to a background thread.
+        # The return object (AsyncSaveResponse) exposes distinct futures for staging and upload.
+        checkpoint_future = dcp.async_save(
+            state_dict,
+            checkpoint_id=f"{CHECKPOINT_DIR}_step{step}",
+            async_stager=DefaultStager(),
+        )
+
+        # Ensure the last checkpoint completes
+        if checkpoint_future:
+            checkpoint_future.upload_completion.result()
+
+        cleanup()
+
+
+    if __name__ == "__main__":
+        world_size = torch.cuda.device_count()
+        print(f"Running async checkpoint example on {world_size} devices.")
+        mp.spawn(
+            run_fsdp_checkpoint_save_example,
+            args=(world_size,),
+            nprocs=world_size,
+            join=True,
+        )
 
 Conclusion
 ----------
 In conclusion, we have learned how to use DCP's :func:`async_save` API to generate checkpoints off the critical training path. We've also learned about the additional memory and concurrency overhead introduced by using this API, as well as optimizations which utilize pinned memory to speed things up even further.
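
Not shown in this diff: restoring from a checkpoint produced by `async_save` uses the ordinary synchronous load path. A hedged sketch reusing the `AppState` wrapper above (`model`, `optimizer`, and `step` assumed):

```python
import torch.distributed.checkpoint as dcp

state_dict = {"app": AppState(model, optimizer)}
# dcp.load mutates the Stateful objects in place via their load_state_dict hooks.
dcp.load(state_dict, checkpoint_id=f"checkpoint_step{step}")
```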