
Commit 125c82c

Authored by niyunsheng, AlannaBurke, and svekars
[DCP] Add DefaultStager example to distributed async checkpoint recipe (#3711)
Fixes #3710

## Description

This PR updates the `distributed_async_checkpoint_recipe` to include the `DefaultStager` functionality introduced in PyTorch 2.9.

**Motivation:** In large-scale training, even with standard `async_save`, the initial memory copy (Staging phase, GPU -> CPU) occurs on the main thread. This blocks the training loop. This PR introduces `DefaultStager`, which offloads this copy to a background thread, enabling **full computation-communication overlap**.

**Key Changes:**

1. **New Section**: Added "Fully Asynchronous Staging with DefaultStager" to the recipe.
2. **Version Note**: Added `.. versionadded:: 2.9` to indicate version requirements.
3. **Advanced Example**: Provided a code example demonstrating how to overlap the D2H copy with the **entire Forward and Backward pass**:
   * Check `staging_completion` *after* backward but *before* `optimizer.step()` to ensure data consistency while maximizing parallel execution.
   * Check `upload_completion` before the next save to manage memory backpressure.

Co-authored-by: Alanna Burke <burkealanna@meta.com>
Co-authored-by: Alanna Burke <aburke626@gmail.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent a5d2a6d commit 125c82c
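For orientation before the diff, here is a compressed, hedged sketch of the loop the commit message above describes. The wrapper function `train_with_overlapped_checkpoints` and its parameters (`model`, `optimizer`, an `AppState`-style `app_state`, `data_iter`) are hypothetical stand-ins for the full FSDP example added below; the `staging_completion` and `upload_completion` futures come from the `AsyncSaveResponse` that `dcp.async_save` returns when an `async_stager` is passed (PyTorch 2.9+):

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.staging import DefaultStager


def train_with_overlapped_checkpoints(model, optimizer, app_state, data_iter,
                                      num_steps=10, checkpoint_dir="checkpoint"):
    """Hypothetical sketch: overlap each step's forward/backward with the
    previous checkpoint's GPU -> CPU staging, keeping one upload in flight."""
    response = None
    for step in range(num_steps):
        optimizer.zero_grad()
        model(next(data_iter)).sum().backward()

        # The previous checkpoint must finish staging (GPU -> CPU copy) before
        # optimizer.step() mutates the parameters being copied.
        if response is not None:
            response.staging_completion.result()
        optimizer.step()

        # Bound backpressure: wait for the previous upload before queuing another.
        if response is not None:
            response.upload_completion.result()

        response = dcp.async_save(
            {"app": app_state},
            checkpoint_id=f"{checkpoint_dir}_step{step}",
            async_stager=DefaultStager(),
        )

    # Drain the final request before returning.
    if response is not None:
        response.upload_completion.result()
```

The ordering is the point: staging must complete before `optimizer.step()` mutates parameters, while the upload wait merely bounds how many checkpoints are in flight.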

1 file changed: recipes_source/distributed_async_checkpoint_recipe.rst
Lines changed: 155 additions & 7 deletions
@@ -1,9 +1,9 @@
Asynchronous Saving with Distributed Checkpoint (DCP)
=====================================================

-**Author:** `Lucas Pasqualin <https://github.com/lucasllc>`__, `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__
+**Author:** `Lucas Pasqualin <https://github.com/lucasllc>`__, `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__, `Yunsheng Ni <https://github.com/niyunsheng>`__

-Checkpointing is often a bottle-neck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
+Checkpointing is often a bottleneck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example
from the `Getting Started with Distributed Checkpoint Tutorial <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__
to show how this can be integrated quite easily with ``torch.distributed.checkpoint.async_save``.
@@ -26,7 +26,7 @@ to show how this can be integrated quite easily with ``torch.distributed.checkpo

Asynchronous Checkpointing Overview
------------------------------------
-Before getting started with Asynchronous Checkpointing, it's important to understand it's differences and limitations as compared to synchronous checkpointing.
+Before getting started with Asynchronous Checkpointing, it's important to understand its differences and limitations as compared to synchronous checkpointing.
Specifically:

* Memory requirements - Asynchronous checkpointing works by first copying models into internal CPU-buffers.
@@ -35,8 +35,8 @@ Specifically:
  the memory constraints of their systems. Specifically, pinned memory implies the usage of ``page-lock`` memory, which can be scarce as compared to
  ``pageable`` memory.

-* Checkpoint Management - Since checkpointing is asynchronous, it is up to the user to manage concurrently run checkpoints. In general, users can
-  employ their own management strategies by handling the future object returned form ``async_save``. For most users, we recommend limiting
+* Checkpoint Management - Since checkpointing is asynchronous, it is up to the user to manage concurrently run checkpoints.
+  In general, users can employ their own management strategies by handling the future object returned from ``async_save``. For most users, we recommend limiting
  checkpoints to one asynchronous request at a time, avoiding additional memory pressure per request.
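The Checkpoint Management bullet above is easiest to see in code. Below is a minimal, single-process sketch (an illustration, not part of the recipe: a CPU-only ``gloo`` backend, a throwaway temp directory, and a plain ``nn.Linear`` stand in for the real FSDP setup) of handling the future returned from ``async_save`` so that at most one checkpoint request is in flight:

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn

# Minimal single-process illustration of keeping one async checkpoint in flight.
# CPU-only gloo backend; the checkpoint directory is a throwaway temp folder.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12356")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(16, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
checkpoint_dir = tempfile.mkdtemp()

future = None
for step in range(3):
    optimizer.zero_grad()
    model(torch.rand(4, 16)).sum().backward()
    optimizer.step()

    # Handle the future returned by the previous async_save before issuing a new
    # request, so only one staging buffer / upload exists at a time.
    if future is not None:
        future.result()
    future = dcp.async_save(
        {"model": model.state_dict()},
        checkpoint_id=f"{checkpoint_dir}/step_{step}",
    )

if future is not None:
    future.result()
dist.destroy_process_group()
```

Blocking on the previous future before issuing the next ``async_save`` is what keeps the additional CPU memory bounded to a single staged copy.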
@@ -72,7 +72,7 @@ Specifically:
        self.optimizer = optimizer

    def state_dict(self):
-        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+        # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
@@ -196,7 +196,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
        self.optimizer = optimizer

    def state_dict(self):
-        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+        # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
@@ -279,6 +279,154 @@ checkpoint requests users can take advantage of direct memory access to speed up
        )


+Fully Asynchronous Staging with DefaultStager
+---------------------------------------------
+
+.. versionadded:: 2.9
+   The ``async_stager`` argument and ``DefaultStager`` class were introduced in PyTorch 2.9.
+
+While ``async_save`` handles the disk write asynchronously, the process of copying data from GPU to CPU (known as "staging") typically happens on the main thread. Even with Pinned Memory, this Device-to-Host (D2H) copy can block the training loop for large models.
+
+To achieve maximum overlap between computation and checkpointing, we can use the ``DefaultStager``. This component offloads the state dictionary creation and the D2H copy to a background thread.
+
+**Timeline Comparison:**
+
+* **Standard async_save:** ``[GPU Compute] -> [CPU Copy (Blocking)] -> [Disk Write (Async)]``
+* **With AsyncStager:** ``[GPU Compute] || [CPU Copy (Async)] -> [Disk Write (Async)]``
+
+.. note::
+   Using ``AsyncStager`` introduces a background thread that consumes CPU resources. Ensure your environment has sufficient CPU cores to handle this without impacting the main training process.
+
+.. code-block:: python
+
+    import os
+
+    import torch
+    import torch.distributed as dist
+    import torch.distributed.checkpoint as dcp
+    import torch.multiprocessing as mp
+    import torch.nn as nn
+
+    from torch.distributed.fsdp import fully_shard
+    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+    from torch.distributed.checkpoint.stateful import Stateful
+    from torch.distributed.checkpoint.staging import DefaultStager
+    from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
+
+    CHECKPOINT_DIR = "checkpoint"
+
+
+    class AppState(Stateful):
+        """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
+        with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
+        dcp.save/load APIs.
+
+        Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
+        and optimizer.
+        """
+
+        def __init__(self, model, optimizer=None):
+            self.model = model
+            self.optimizer = optimizer
+
+        def state_dict(self):
+            # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+            return {
+                "model": model_state_dict,
+                "optim": optimizer_state_dict
+            }
+
+        def load_state_dict(self, state_dict):
+            # sets our state dicts on the model and optimizer, now that we've loaded
+            set_state_dict(
+                self.model,
+                self.optimizer,
+                model_state_dict=state_dict["model"],
+                optim_state_dict=state_dict["optim"]
+            )
+
+    class ToyModel(nn.Module):
+        def __init__(self):
+            super(ToyModel, self).__init__()
+            self.net1 = nn.Linear(16, 16)
+            self.relu = nn.ReLU()
+            self.net2 = nn.Linear(16, 8)
+
+        def forward(self, x):
+            return self.net2(self.relu(self.net1(x)))
+
+
+    def setup(rank, world_size):
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+
+        # initialize the process group
+        dist.init_process_group("gloo", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+
+
+    def cleanup():
+        dist.destroy_process_group()
+
+
+    def run_fsdp_checkpoint_save_example(rank, world_size):
+        print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
+        setup(rank, world_size)
+
+        # create a model and move it to GPU with id rank
+        model = ToyModel().to(rank)
+        model = fully_shard(model)
+
+        loss_fn = nn.MSELoss()
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+
+        checkpoint_future = None
+        for step in range(10):
+            print(f"Step {step} starting...")
+            optimizer.zero_grad()
+            model(torch.rand(8, 16, device="cuda")).sum().backward()
+
+            # Critical: We must ensure the previous checkpoint's D2H copy (staging)
+            # is complete before the optimizer modifies the model parameters.
+            # Placing this await AFTER the backward pass allows us to overlap
+            # the D2H copy with the current step's Forward and Backward computation.
+            if checkpoint_future is not None:
+                checkpoint_future.staging_completion.result()
+            optimizer.step()
+
+            # waits for checkpointing to finish if one exists, avoiding queuing more than one checkpoint request at a time
+            if checkpoint_future is not None:
+                checkpoint_future.upload_completion.result()
+
+            state_dict = { "app": AppState(model, optimizer) }
+
+            # Pass the DefaultStager to enable fully asynchronous staging.
+            # This offloads the state_dict creation and GPU-to-CPU copy to a background thread.
+            # The return object (AsyncSaveResponse) exposes distinct futures for staging and upload.
+            checkpoint_future = dcp.async_save(
+                state_dict,
+                checkpoint_id=f"{CHECKPOINT_DIR}_step{step}",
+                async_stager=DefaultStager(),
+            )
+
+        # Ensure the last checkpoint completes
+        if checkpoint_future:
+            checkpoint_future.upload_completion.result()
+
+        cleanup()
+
+
+    if __name__ == "__main__":
+        world_size = torch.cuda.device_count()
+        print(f"Running async checkpoint example on {world_size} devices.")
+        mp.spawn(
+            run_fsdp_checkpoint_save_example,
+            args=(world_size,),
+            nprocs=world_size,
+            join=True,
+        )
+
Conclusion
----------
In conclusion, we have learned how to use DCP's :func:`async_save` API to generate checkpoints off the critical training path. We've also learned about the
