Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
001f77c
Initial plan
Copilot Feb 27, 2026
b90aee5
Revert "fix: update 1 file reformatted."
Copilot Feb 27, 2026
b6da9af
Merge pull request #5 from nathon-lee/copilot/git-revert-ff886701
nathon-lee Feb 27, 2026
bb7f64f
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 6, 2026
cbc816c
Initial plan
Copilot Mar 6, 2026
5fcc9a7
Reapply "fix: update 1 file reformatted."
Copilot Mar 6, 2026
f7c5d75
Merge pull request #6 from nathon-lee/copilot/remove-commits-from-master
nathon-lee Mar 6, 2026
18efbcc
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 25, 2026
e2ac74d
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 27, 2026
da07382
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 30, 2026
5d8875c
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 31, 2026
316b6dd
Merge branch 'deepspeedai:master' into master
nathon-lee Apr 1, 2026
2020543
Merge branch 'deepspeedai:master' into master
nathon-lee Apr 2, 2026
1a8694c
Merge branch 'deepspeedai:master' into master
nathon-lee Apr 16, 2026
d6725be
Merge branch 'deepspeedai:master' into master
nathon-lee Apr 23, 2026
a06c548
Merge branch 'deepspeedai:master' into master
nathon-lee May 1, 2026
6959eb4
Merge branch 'deepspeedai:master' into master
nathon-lee May 5, 2026
e88eb3e
Merge branch 'deepspeedai:master' into master
nathon-lee May 7, 2026
196f60c
feat(zero): implement elastic checkpoint support for ZeRO-3
nathon-lee May 7, 2026
773f32e
[ZeRO-3]: Implement elastic checkpoint save and load
nathon-lee May 7, 2026
9af52c2
[ZeRO-3]: Add unit tests for elastic checkpoint
nathon-lee May 8, 2026
3b08484
fix: fix stage-3 elastic checkpoint cross-world-size save/load
nathon-lee May 8, 2026
b502ee1
fix(zero3): move restored elastic optimizer state to correct device
nathon-lee May 9, 2026
c26b837
Merge pull request #17 from nathon-lee/feat/zero3-elastic-checkpoint-…
nathon-lee May 10, 2026
526d5c2
docs(zero): document ZeRO-3 elastic checkpoint support
nathon-lee May 11, 2026
7eb7e03
Merge pull request #18 from nathon-lee/feat/zero3-elastic-checkpoint-…
nathon-lee May 11, 2026
a7d728c
Merge branch 'deepspeedai:master' into feat/zero3-elastic-checkpoint
nathon-lee May 14, 2026
0d95e79
Merge branch 'deepspeedai:master' into feat/zero3-elastic-checkpoint
nathon-lee May 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
WEIGHT_QUANTIZE_ROUNDING, \
WEIGHT_QUANTIZE_VERBOSE, \
WEIGHT_QUANTIZE_KERNEL
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS, UNIVERSAL_CHECKPOINT_INFO
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS, UNIVERSAL_CHECKPOINT_INFO, BASE_OPTIMIZER_STATE
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
from deepspeed.checkpoint.ds_to_universal import dp_index_to_str
from deepspeed.runtime.sparse_tensor import SparseTensor
Expand Down Expand Up @@ -2072,6 +2072,7 @@ def _configure_zero_optimizer(self, optimizer):
enable_sanity_checks=self.is_sanity_checks_enabled(),
cpuadam_cores_perc=self.cpuadam_cores_perc(),
save_muon_momentum_buffer_in_memory=self.zero_save_muon_momentum_buffer_in_memory(),
elastic_checkpoint=self.zero_elastic_checkpoint(),
)

else:
Expand Down Expand Up @@ -3767,11 +3768,16 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
checkpoint_folder = f'{os.path.join(load_dir, tag)}'
else:
if load_optimizer_states and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size:
raise ZeRORuntimeException("The checkpoint being loaded used a DP " \
f"world size of {self.loaded_checkpoint_dp_world_size} but the " \
f"current world size is {self.seq_dp_world_size}. Automatic adjustment " \
"of ZeRO's optimizer state partitioning with a new world size is not " \
"currently supported.")
# ZeRO Stage 3 elastic checkpoint repartitions optimizer states internally
# across arbitrary DP world sizes; skip the guard for that case.
zero3_elastic = (self.zero_optimization_stage() == ZeroStageEnum.weights
and self.zero_elastic_checkpoint())
if not zero3_elastic:
raise ZeRORuntimeException("The checkpoint being loaded used a DP " \
f"world size of {self.loaded_checkpoint_dp_world_size} but the " \
f"current world size is {self.seq_dp_world_size}. Automatic adjustment " \
"of ZeRO's optimizer state partitioning with a new world size is not " \
"currently supported.")
checkpoint_folder = None
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None:
Expand Down Expand Up @@ -3822,13 +3828,26 @@ def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode):
return zero_ckpt_names

def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names):
cur_rank = dist.get_rank(group=self.optimizer.dp_process_group)
# Load the current rank's checkpoint first to detect whether the on-disk format is
# elastic. This allows autodetection even when elastic_checkpoint=False in the config
# (e.g., loading a checkpoint that was saved with elastic_checkpoint=True).
cached = {}
if cur_rank < len(zero_ckpt_names) and zero_ckpt_names[cur_rank] is not None:
cached[cur_rank] = self.checkpoint_engine.load(zero_ckpt_names[cur_rank], map_location='cpu')
cur_optim_sd = cached[cur_rank].get(OPTIMIZER_STATE_DICT)
ckpt_is_elastic = self.zero_elastic_checkpoint() or (cur_optim_sd is not None
and BASE_OPTIMIZER_STATE in cur_optim_sd)
else:
ckpt_is_elastic = self.zero_elastic_checkpoint()

zero_sd_list = []
for i, ckpt_name in enumerate(zero_ckpt_names):
_state = None
if ckpt_name is None:
if i in cached:
_state = cached[i]
elif ckpt_name is None:
_state = {OPTIMIZER_STATE_DICT: None}
# Fully load state for current rank
elif self.zero_elastic_checkpoint() or dist.get_rank(group=self.optimizer.dp_process_group) == i:
elif ckpt_is_elastic or dist.get_rank(group=self.optimizer.dp_process_group) == i:
_state = self.checkpoint_engine.load(
ckpt_name,
map_location='cpu',
Expand Down
187 changes: 178 additions & 9 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper
from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER
from deepspeed.checkpoint.constants import (OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE,
LOSS_SCALER, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP,
GROUP_PADDINGS)
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.muon.original_muon import muon_update
from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam
Expand Down Expand Up @@ -2954,6 +2956,51 @@ def _clear_fp32_optimizer_param_groups(self):
for param_group in self.optimizer.param_groups:
param_group['params'] = []

def _get_elastic_optimizer_state(self):
    """Collect the base optimizer state of every sub-group in lean form.

    For each sub-group, the state stored under its flat fp32 partition is
    copied key by key. Tensor entries with at least one dimension are handed
    to ``self._get_lean_tensors`` together with the sub-group's partitioned
    fp16 params and paddings, which is expected to return per-param tensors
    with padding stripped. Everything else is stored unchanged: Python
    scalars, flags, and zero-dimensional tensors such as the tensor-valued
    ``step`` that ``torch.optim`` optimizers keep. A 0-dim tensor has no
    elements to distribute across parameters, so it must not be split.

    Returns:
        list[dict]: one dict per sub-group, in sub-group order. A sub-group
        without any optimizer state contributes an empty dict.
    """
    sub_group_states = []
    for sub_group_id in range(len(self.fp16_groups)):
        fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
        # Test membership before indexing so that a missing entry is never
        # created as a side effect (torch keeps optimizer state in a defaultdict).
        param_state = self.optimizer.state[fp32_param] if fp32_param in self.optimizer.state else None
        if not param_state:
            sub_group_states.append({})
            continue

        lean_state = {}
        for key, value in param_state.items():
            if torch.is_tensor(value) and value.dim() > 0:
                lean_state[key] = self._get_lean_tensors(value, self.fp16_partitioned_groups[sub_group_id],
                                                         self.groups_padding[sub_group_id])
            else:
                # Scalars and 0-dim tensors are world-size independent; keep them whole.
                lean_state[key] = value
        sub_group_states.append(lean_state)
    return sub_group_states

def _elastic_state_dict(self):
    """Build the elastic (DP-world-size-agnostic) ZeRO-3 optimizer state dict.

    The result holds the loss-scaling fields, the partition count, the flat
    fp32 partitions, the lean per-sub-group optimizer states produced by
    ``self._get_elastic_optimizer_state()``, the shared optimizer ``step``
    (only when the first param group defines one) and the group paddings.

    The fp32 flat groups are stored as they are (padded, same layout as the
    rigid format) so that zero_to_fp32.py can reconstruct model weights
    without knowing the DP world size used at save time. Only the optimizer
    states are stored in lean form.

    Returns:
        dict: the checkpoint payload for this rank.

    Raises:
        AssertionError: if the param groups disagree on their ``step`` value.
    """
    state_dict = {}
    state_dict[ZERO_STAGE] = ZeroStageEnum.weights
    state_dict[LOSS_SCALER] = self.loss_scaler
    state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
    state_dict['overflow'] = self.overflow
    state_dict[PARTITION_COUNT] = self.partition_count
    state_dict[FP32_FLAT_GROUPS] = list(self.fp32_partitioned_groups_flat)

    self._set_fp32_optimizer_param_groups()
    try:
        state_dict[BASE_OPTIMIZER_STATE] = self._get_elastic_optimizer_state()
        param_groups = self.optimizer.param_groups
        if param_groups and "step" in param_groups[0]:
            step = param_groups[0]["step"]
            assert all(pg["step"] == step for pg in param_groups), \
                "All param groups must have the same step value"
            state_dict[BASE_OPTIMIZER_STATE_STEP] = step
    finally:
        # Always undo _set_fp32_optimizer_param_groups(), even when collecting the
        # state or the step consistency check fails, so the optimizer's param
        # groups are not left populated after an aborted save.
        self._clear_fp32_optimizer_param_groups()

    state_dict[GROUP_PADDINGS] = self.groups_padding
    return state_dict

def _rigid_state_dict(self):
state_dict = {}
state_dict[ZERO_STAGE] = ZeroStageEnum.weights
Expand Down Expand Up @@ -2981,7 +3028,7 @@ def state_dict(self):
torch.save(checkpoint, "saved.pth")
"""
if self.elastic_checkpoint:
raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.")
return self._elastic_state_dict()

return self._rigid_state_dict()

Expand Down Expand Up @@ -3055,6 +3102,121 @@ def _restore_base_optimizer_state(self, all_state_dict):
else:
self.optimizer.state[p][key] = saved

def _repartition_for_current_rank(self, lean_parts_per_ckpt_rank, param):
    """Return this rank's partition of ``param`` rebuilt from checkpoint pieces.

    Args:
        lean_parts_per_ckpt_rank: padding-free tensors, one per checkpoint
            rank, that are concatenated in order to form the full parameter.
        param: the parameter being restored; supplies ``partition_numel()``
            and ``ds_numel`` for the current world size.

    Returns:
        A tensor of exactly ``param.partition_numel()`` elements: the slice
        owned by the current DP rank, zero-padded at the end when fewer real
        elements fall into this rank's slot.
    """
    merged = torch.cat(lean_parts_per_ckpt_rank)
    my_rank = dist.get_rank(group=self.dp_process_group)
    slot_numel = param.partition_numel()

    offset = my_rank * slot_numel
    shard = merged[offset:min(offset + slot_numel, param.ds_numel)]

    # Ranks whose slot extends past the end of the parameter get zero fill.
    missing = slot_numel - shard.numel()
    if missing > 0:
        filler = torch.zeros(missing, dtype=shard.dtype, device=shard.device)
        shard = torch.cat([shard, filler])
    return shard

def _restore_from_elastic_fp32_partitions(self, all_state_dict):
    """Rebuild the fp32 master partitions from an elastic checkpoint.

    The checkpoint may have been written with a different DP world size,
    which is taken to be ``len(all_state_dict)``. Each checkpoint rank's
    ``FP32_FLAT_GROUPS`` entry is a padded flat tensor per sub-group that is
    split into one chunk per parameter, using the per-param partition size
    under the checkpoint world size (ceil(ds_numel / ckpt_world_size)).
    Padding is stripped from each chunk, and the pieces are merged and
    re-sliced for the current rank by ``self._repartition_for_current_rank``.

    Args:
        all_state_dict: list with one optimizer state dict per checkpoint rank.
    """
    ckpt_world_size = len(all_state_dict)
    for sub_group_id, fp32_flat in enumerate(self.fp32_partitioned_groups_flat):
        params = self.fp16_groups[sub_group_id]
        # Each param's partition size under the checkpoint world size.
        ckpt_partition_sizes = [(p.ds_numel + ckpt_world_size - 1) // ckpt_world_size for p in params]
        # Split every rank's padded flat group once per sub-group; the split does not
        # depend on which parameter is being rebuilt, so it must not sit in the param loop.
        chunks_per_rank = [sd[FP32_FLAT_GROUPS][sub_group_id].split(ckpt_partition_sizes) for sd in all_state_dict]

        new_parts = []
        for param_idx, param in enumerate(params):
            ckpt_ps = ckpt_partition_sizes[param_idx]
            lean_parts = []
            for rank, rank_chunks in enumerate(chunks_per_rank):
                # Real (non-padding) element count held by this rank for this param.
                lean_len = max(0, min(ckpt_ps, param.ds_numel - rank * ckpt_ps))
                lean_parts.append(rank_chunks[param_idx][:lean_len])
            new_parts.append(self._repartition_for_current_rank(lean_parts, param))
        fp32_flat.data.copy_(torch.cat(new_parts).to(fp32_flat.dtype))

def _restore_elastic_optimizer_state(self, all_state_dict):
    """Rebuild per-sub-group optimizer states from an elastic checkpoint.

    Entries stored as a non-empty list of tensors are treated as lean
    per-param pieces: for every parameter the pieces of all checkpoint ranks
    are repartitioned for the current rank, concatenated, and moved to the
    device and dtype of the sub-group's flat fp32 partition. Any other entry
    is taken unchanged from the first checkpoint rank. Sub-groups whose first
    rank state is empty are left untouched.

    When the checkpoint carries a shared optimizer step, it is written into
    every param group of the wrapped optimizer.

    Args:
        all_state_dict: list with one optimizer state dict per checkpoint rank.
    """
    first_sd = all_state_dict[0]
    step = None
    if BASE_OPTIMIZER_STATE_STEP in first_sd:
        expected_step = first_sd[BASE_OPTIMIZER_STATE_STEP]
        assert all(sd[BASE_OPTIMIZER_STATE_STEP] == expected_step
                   for sd in all_state_dict), "Checkpoint ranks have inconsistent step values"
        step = expected_step

    for sub_group_id in range(len(self.fp16_groups)):
        fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
        params = self.fp16_groups[sub_group_id]
        per_rank_states = [sd[BASE_OPTIMIZER_STATE][sub_group_id] for sd in all_state_dict]
        reference = per_rank_states[0]
        if not reference:
            continue

        restored = {}
        for key, sample in reference.items():
            holds_lean_tensors = isinstance(sample, list) and len(sample) > 0 and torch.is_tensor(sample[0])
            if not holds_lean_tensors:
                restored[key] = sample
                continue

            shards = []
            for param_idx, param in enumerate(params):
                pieces = [rank_state[key][param_idx] for rank_state in per_rank_states]
                shards.append(self._repartition_for_current_rank(pieces, param))
            # Match the flat fp32 partition's device and dtype so the fused optimizer
            # does not raise "tensor not on same device" during step().
            restored[key] = torch.cat(shards).to(device=fp32_param.device, dtype=fp32_param.dtype)
        self.optimizer.state[fp32_param] = restored

    if step is not None:
        for param_group in self.optimizer.param_groups:
            param_group['step'] = step

def _elastic_load_state_dict(self, all_state_dict, load_optimizer_states=True, load_from_fp32_weights=False):
    """Load a ZeRO-3 elastic checkpoint into this optimizer.

    The checkpoint may have been saved with a different DP world size;
    the merge-and-repartition helpers handle the mismatch.

    Args:
        all_state_dict: list with one optimizer state dict per checkpoint rank.
            Entry 0 supplies the loss scaler, ``dynamic_loss_scale`` and
            ``overflow`` values.
        load_optimizer_states: when true, the base optimizer state is rebuilt
            via ``self._restore_elastic_optimizer_state``.
        load_from_fp32_weights: when true, the fp32 master partitions are
            rebuilt from the checkpoint's fp32 flat groups; otherwise
            ``self._restore_from_bit16_weights()`` is called.
    """
    sd0 = all_state_dict[0]
    self.loss_scaler = sd0[LOSS_SCALER]
    self.dynamic_loss_scale = sd0['dynamic_loss_scale']
    self.overflow = sd0['overflow']

    if load_optimizer_states:
        # The fp32 param groups are populated only for the duration of the restore.
        self._set_fp32_optimizer_param_groups()
        self._restore_elastic_optimizer_state(all_state_dict)
        self._clear_fp32_optimizer_param_groups()

    if self.swap_optimizer:
        # Purge the swapper's state, then walk every sub-group through
        # prepare / reassign-or-swap-out / release.
        # NOTE(review): presumably this replaces swapped state created at init with
        # the freshly loaded one — confirm against the swapper implementation.
        self.optimizer_swapper.purge_state()
        timer_names = set()
        self._partition_all_parameters()
        for sub_group_id, group in enumerate(self.fp16_groups):
            self._prepare_sub_group(sub_group_id, timer_names)
            self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
            self._release_sub_group(sub_group_id, timer_names)
        self._post_step(timer_names)

    if load_from_fp32_weights:
        self._restore_from_elastic_fp32_partitions(all_state_dict)
    else:
        self._restore_from_bit16_weights()

    # Sync fp16 partitions from the (now-updated) fp32 master weights.
    for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
        fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
        # Skip sub-groups whose fp32 partition has no elements.
        if sum(fp32_param.size()) > 0:
            fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
            fp16_param.data.copy_(fp32_param.data)

    # Re-point every partitioned fp16 param at its slice of the updated flat buffer.
    for sub_group_id in range(len(self.fp16_partitioned_groups_flat)):
        updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id],
                                        self.fp16_partitioned_groups[sub_group_id])
        for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
            partitioned_param.data = q.data

def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict[LOSS_SCALER]
Expand Down Expand Up @@ -3099,7 +3261,6 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
partitioned_param.data = q.data

# TODO: Support different/changing load/save DP degree.
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
Expand Down Expand Up @@ -3132,11 +3293,19 @@ def load_state_dict(self,
optimizer.load_state_dict(checkpoint['optimizer'])
"""

if self.elastic_checkpoint:
raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.")

if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
return

# Detect elastic vs. rigid format by the presence of BASE_OPTIMIZER_STATE.
# Elastic checkpoints store lean per-param tensors and support arbitrary DP world sizes.
# Use the current rank's own state dict for detection; other slots may be None when the
# engine used lazy-loading (only fetched the current rank's file).
my_rank = dist.get_rank(group=self.dp_process_group)
own_sd = state_dict_list[my_rank] if my_rank < len(state_dict_list) else None
ckpt_is_elastic = own_sd is not None and BASE_OPTIMIZER_STATE in own_sd
if ckpt_is_elastic:
self._elastic_load_state_dict(state_dict_list, load_optimizer_states, load_from_fp32_weights)
else:
self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)],
load_optimizer_states=load_optimizer_states)
Expand All @@ -3153,9 +3322,9 @@ def load_state_dict(self,
if local_rank != rank_end:
dist.send(tensor=load_serial, dst=rank + 1)

if len(self.persistent_parameters) > 0:
self.persistent_parameters[0].partition(self.persistent_parameters)
# self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather
if len(self.persistent_parameters) > 0:
self.persistent_parameters[0].partition(self.persistent_parameters)
# self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather

def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder)
Expand Down
Loading