From 1947a25a804d8c828bd3911981336904b414a4a4 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 29 May 2026 13:06:22 -0700 Subject: [PATCH 1/3] Fix DeepCompile ZeRO-1 grad target lifetime Signed-off-by: Masahiro Tanaka --- csrc/compile/init.cpp | 3 + csrc/compile/z1.cpp | 5 + csrc/compile/z1.h | 1 + csrc/includes/deepcompile.h | 15 ++ deepspeed/compile/init_z1.py | 187 +++++++++++++++------ deepspeed/compile/init_z3.py | 2 +- deepspeed/compile/util.py | 2 +- deepspeed/runtime/zero/stage_1_and_2.py | 31 +++- tests/unit/compile/__init__.py | 4 + tests/unit/v1/compile/test_compile_zero.py | 49 ++++++ 10 files changed, 243 insertions(+), 56 deletions(-) create mode 100644 tests/unit/compile/__init__.py diff --git a/csrc/compile/init.cpp b/csrc/compile/init.cpp index cbb03907d327..ae5d52775a02 100644 --- a/csrc/compile/init.cpp +++ b/csrc/compile/init.cpp @@ -92,6 +92,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("init", &dc::init, "Set the process group"); m.def("cleanup", &dc::cleanup, "Cleanup the process group"); m.def("register_param", &dc::register_param, "Register a parameter"); + m.def("update_param_grad_buffer", + &dc::update_param_grad_buffer, + "Update a registered parameter grad buffer"); m.def("register_graph_z1", &dc::register_graph_z1, "Register graph with a list of ds parameter ids"); diff --git a/csrc/compile/z1.cpp b/csrc/compile/z1.cpp index d0c804b7f1ad..5ebe1e3b21f4 100644 --- a/csrc/compile/z1.cpp +++ b/csrc/compile/z1.cpp @@ -127,4 +127,9 @@ void register_param(long ds_id, param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, false, offset, false); } +void update_param_grad_buffer(long ds_id, at::Tensor grad_buffer, int64_t offset) +{ + param_registry->updateGradBuffer(ds_id, grad_buffer, offset); +} + } // namespace dc diff --git a/csrc/compile/z1.h b/csrc/compile/z1.h index 1d3607a59b06..493bddfb3ab1 100644 --- a/csrc/compile/z1.h +++ b/csrc/compile/z1.h @@ -15,4 +15,5 @@ void register_param(long ds_id, at::Tensor ds_tensor, at::Tensor grad_buffer, int64_t offset); +void update_param_grad_buffer(long ds_id, at::Tensor grad_buffer, int64_t offset); } // namespace dc diff --git a/csrc/includes/deepcompile.h b/csrc/includes/deepcompile.h index 7016d4a99310..df9ee0bb01c4 100644 --- a/csrc/includes/deepcompile.h +++ b/csrc/includes/deepcompile.h @@ -291,6 +291,11 @@ class DSParam { return ds_tensor_; } at::Tensor getGradBuffer() const { return grad_buffer_; } + void setGradBuffer(at::Tensor grad_buffer, int64_t offset) + { + grad_buffer_ = grad_buffer; + offset_ = offset; + } bool isPartitioned() const { return partitioned_; } int64_t getOffset() const { return offset_; } void setPersistent(bool persistent) { persistent_ = persistent; } @@ -395,6 +400,12 @@ class DSParamRegistry { valid_[ds_id] = false; } + void updateGradBuffer(long ds_id, at::Tensor grad_buffer, int64_t offset) + { + if (grad_buffer.numel() > 0) { grad_buffer.zero_(); } + params_.at(ds_id).setGradBuffer(grad_buffer, offset); + } + void registerGatheredParam(long ds_id, at::Tensor ds_tensor) { gathered_params_.emplace(ds_id, ds_tensor); @@ -477,6 +488,10 @@ class CustomOpExecutor { // This synchronization ensures all of reduce calls are done before optimizer's step. at::cuda::stream_synchronize(rs_stream_); + + // Match ZeRO's IPG buffer lifecycle: reduction buckets are backward work buffers and + // should not stay allocated while Adam creates optimizer temporaries. + reduce_buckets_->clear(); } virtual at::Tensor reduceGrad(at::Tensor grad_tensor, long ds_id) diff --git a/deepspeed/compile/init_z1.py b/deepspeed/compile/init_z1.py index f73a1953f7e4..3e60b66d7cc6 100644 --- a/deepspeed/compile/init_z1.py +++ b/deepspeed/compile/init_z1.py @@ -15,6 +15,67 @@ WARMUP = 5 +def _empty_grad_buffer(param): + return torch.empty([0], dtype=param.dtype, device=param.device) + + +def _build_partition_grad_views(optimizer, group_idx): + missing = object() + original_all_grad_tensors = optimizer.all_grad_tensors.get(group_idx, missing) + optimizer.all_grad_tensors[group_idx] = optimizer.get_all_grad_tensors(optimizer.params_in_partition[group_idx], + optimizer.gradient_accumulation_dtype) + try: + return optimizer.get_flat_partition(optimizer.params_in_partition[group_idx], + optimizer.first_offset[group_idx], + optimizer.partition_size[group_idx], + dtype=optimizer.gradient_accumulation_dtype, + device=get_accelerator().current_device_name(), + param_group_idx=group_idx, + return_tensor_list=True) + finally: + if original_all_grad_tensors is missing: + optimizer.all_grad_tensors.pop(group_idx, None) + else: + optimizer.all_grad_tensors[group_idx] = original_all_grad_tensors + + +def _build_flat_partition_grad_views(optimizer, group_idx): + partition_size = int(optimizer.partition_size[group_idx]) + dtype = optimizer.gradient_accumulation_dtype + device = get_accelerator().current_device_name() + flat_buffer = torch.zeros(partition_size, dtype=dtype, device=device) + + views = [] + current_size = 0 + for i, tensor in enumerate(optimizer.params_in_partition[group_idx]): + num_elements = tensor.numel() + tensor_offset = 0 + + if i == 0 and optimizer.first_offset[group_idx] > 0: + tensor_offset = int(optimizer.first_offset[group_idx]) + num_elements -= tensor_offset + + if num_elements > partition_size - current_size: + num_elements = partition_size - current_size + + if num_elements <= 0: + continue + + view = flat_buffer.narrow(0, current_size, int(num_elements)) + if tensor_offset == 0 and num_elements == tensor.numel(): + view = view.view(tensor.shape) + views.append(view) + current_size += int(num_elements) + + if current_size >= partition_size: + break + + if current_size < partition_size: + views.append(flat_buffer.narrow(0, current_size, partition_size - current_size)) + + return flat_buffer, views + + def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_z2=False): optimizer = engine.optimizer @@ -26,53 +87,85 @@ def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_ dc = get_deepcompile_handle() dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size()) - grad_buffer = {} - - # Save original all_grad_tensors state as we temporarily modify it - original_all_grad_tensors = optimizer.all_grad_tensors.copy() if hasattr(optimizer, 'all_grad_tensors') else {} - - for i, group in enumerate(optimizer.bit16_groups): - # Temporarily populate all_grad_tensors for get_flat_partition call - # This is needed because get_flat_partition accesses all_grad_tensors[param_group_idx][i] - # but it's empty during initialization - if i not in optimizer.all_grad_tensors or optimizer.all_grad_tensors[i] is None: - optimizer.all_grad_tensors[i] = optimizer.get_all_grad_tensors(optimizer.params_in_partition[i], - optimizer.gradient_accumulation_dtype) - - grad_buffer[i] = optimizer.get_flat_partition(optimizer.params_in_partition[i], - optimizer.first_offset[i], - optimizer.partition_size[i], - dtype=optimizer.gradient_accumulation_dtype, - device=get_accelerator().current_device_name(), - param_group_idx=i, - return_tensor_list=True) - grad_buffer[i] = [p.clone().detach() for p in grad_buffer[i]] # Maybe not necessary - - index_in_partition = 0 - first_in_partition = True - for p in group: - param_id = optimizer.get_param_id(p) - p.param_id = param_id - in_partition = optimizer.is_param_in_current_partition[param_id] - - if in_partition: - buf = grad_buffer[i][index_in_partition] - offset = optimizer.first_offset[i] if first_in_partition else 0 - # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf={buf.shape} partition_offset={offset}") - dc.register_param(p.param_id, p.shape, p, buf, int(offset)) - index_in_partition += 1 - first_in_partition = False - else: - # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf=None") - dc.register_param(p.param_id, p.shape, p, torch.empty([0], dtype=p.dtype, device=p.device), 0) - - # Restore original all_grad_tensors state - optimizer.all_grad_tensors = original_all_grad_tensors - - def set_grad_buffer(): - optimizer.averaged_gradients = copy.copy(grad_buffer) - - add_pre_backward_hook(set_grad_buffer) + if use_z2: + grad_buffer = {} + for i, group in enumerate(optimizer.bit16_groups): + grad_buffer[i] = [p.clone().detach() for p in _build_partition_grad_views(optimizer, i)] + + index_in_partition = 0 + first_in_partition = True + for p in group: + param_id = optimizer.get_param_id(p) + p.param_id = param_id + in_partition = optimizer.is_param_in_current_partition[param_id] + + if in_partition: + buf = grad_buffer[i][index_in_partition] + offset = optimizer.first_offset[i] if first_in_partition else 0 + dc.register_param(p.param_id, p.shape, p, buf, int(offset)) + index_in_partition += 1 + first_in_partition = False + else: + dc.register_param(p.param_id, p.shape, p, _empty_grad_buffer(p), 0) + + def set_z2_grad_buffer(_is_gradient_accumulation_boundary): + optimizer.averaged_gradients = copy.copy(grad_buffer) + + add_pre_backward_hook(set_z2_grad_buffer) + else: + grad_buffer_metadata = {} + + for i, group in enumerate(optimizer.bit16_groups): + grad_buffer_metadata[i] = [] + first_in_partition = True + for p in group: + param_id = optimizer.get_param_id(p) + p.param_id = param_id + in_partition = optimizer.is_param_in_current_partition[param_id] + + if in_partition: + offset = optimizer.first_offset[i] if first_in_partition else 0 + grad_buffer_metadata[i].append((p.param_id, p, int(offset))) + dc.register_param(p.param_id, p.shape, p, _empty_grad_buffer(p), 0) + first_in_partition = False + else: + dc.register_param(p.param_id, p.shape, p, _empty_grad_buffer(p), 0) + + optimizer._deepcompile_z1_grad_buffer_metadata = grad_buffer_metadata + optimizer._deepcompile_z1_current_grad_buffers = {} + optimizer._deepcompile_z1_current_flat_grad_buffers = {} + + def set_z1_grad_buffer(is_gradient_accumulation_boundary): + if not is_gradient_accumulation_boundary: + release_grad_buffer() + optimizer.averaged_gradients = {} + return + + current_grad_buffers = {} + current_flat_grad_buffers = {} + for group_idx in range(len(optimizer.bit16_groups)): + flat_grad_buffer, group_grad_buffers = _build_flat_partition_grad_views(optimizer, group_idx) + current_flat_grad_buffers[group_idx] = flat_grad_buffer + current_grad_buffers[group_idx] = group_grad_buffers + for (param_id, _, offset), grad_buffer in zip(grad_buffer_metadata[group_idx], group_grad_buffers): + dc.update_param_grad_buffer(param_id, grad_buffer, offset) + optimizer._deepcompile_z1_current_flat_grad_buffers = current_flat_grad_buffers + optimizer._deepcompile_z1_current_grad_buffers = current_grad_buffers + optimizer.averaged_gradients = current_grad_buffers + + def release_grad_buffer(group_idx=None): + group_indices = grad_buffer_metadata.keys() if group_idx is None else [group_idx] + for idx in group_indices: + for param_id, param, _ in grad_buffer_metadata[idx]: + dc.update_param_grad_buffer(param_id, _empty_grad_buffer(param), 0) + if idx in optimizer._deepcompile_z1_current_grad_buffers: + optimizer._deepcompile_z1_current_grad_buffers[idx] = None + if idx in optimizer._deepcompile_z1_current_flat_grad_buffers: + optimizer._deepcompile_z1_current_flat_grad_buffers[idx] = None + + optimizer._deepcompile_z1_release_grad_buffers = release_grad_buffer + + add_pre_backward_hook(set_z1_grad_buffer) if schedule is None: schedule = [] diff --git a/deepspeed/compile/init_z3.py b/deepspeed/compile/init_z3.py index 2b9d404b3781..2feb78a3944a 100644 --- a/deepspeed/compile/init_z3.py +++ b/deepspeed/compile/init_z3.py @@ -72,7 +72,7 @@ def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None): if use_opt: - def set_grad_buffer(): + def set_grad_buffer(_is_gradient_accumulation_boundary): for i, sub_group in enumerate(optimizer.fp16_groups): optimizer.averaged_gradients[i] = [ optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id] diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 0fd3b6b389db..c9aed53c5a62 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -69,7 +69,7 @@ def add_pre_backward_hook(hook): def deepcompile_backward_prologue(is_gradient_accumulation_boundary): for hook in pre_backward_hooks: - hook() + hook(is_gradient_accumulation_boundary) dc = get_deepcompile_handle() dc.start_backward(is_gradient_accumulation_boundary) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 4b55c56d0929..6c3ae0afee63 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2133,6 +2133,9 @@ def step(self, closure=None): if self.overflow: see_memory_usage('After overflow before clearing gradients') self.zero_grad(set_to_none=True) + release_deepcompile_grad_buffers = getattr(self, "_deepcompile_z1_release_grad_buffers", None) + if release_deepcompile_grad_buffers is not None: + release_deepcompile_grad_buffers() if self.cpu_offload: self.reset_cpu_buffers() else: @@ -2184,13 +2187,24 @@ def step(self, closure=None): # create a flat gradients for parameters updated by this process # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors - if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1: - single_grad_partition = self.flatten_dense_tensors_aligned( - self.averaged_gradients[i], - int(self.partition_size[i])).to(self.single_partition_of_fp32_groups[i].dtype) - else: - single_grad_partition = self.flatten(self.averaged_gradients[i]).to( - self.single_partition_of_fp32_groups[i].dtype) + deepcompile_flat_grad_buffers = getattr(self, "_deepcompile_z1_current_flat_grad_buffers", None) + flat_grad_partition = None + if deepcompile_flat_grad_buffers is not None: + flat_grad_partition = deepcompile_flat_grad_buffers.get(i) + if flat_grad_partition is not None: + flat_grad_partition = flat_grad_partition.view(-1) + assert flat_grad_partition.numel() == self.partition_size[i], \ + "DeepCompile flat gradient partition has different number of elements than partition size {} {} {} {}".format( + flat_grad_partition.numel(), self.partition_size[i], i, partition_id) + + if flat_grad_partition is None: + if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1: + flat_grad_partition = self.flatten_dense_tensors_aligned(self.averaged_gradients[i], + int(self.partition_size[i])) + else: + flat_grad_partition = self.flatten(self.averaged_gradients[i]) + single_grad_partition = flat_grad_partition.to(self.single_partition_of_fp32_groups[i].dtype) + del flat_grad_partition assert single_grad_partition.numel() == self.partition_size[i], \ "averaged gradients have different number of elements that partition size {} {} {} {}".format( single_grad_partition.numel(), self.partition_size[i], i, partition_id) @@ -2201,6 +2215,9 @@ def step(self, closure=None): self.averaged_gradients[i] = None self.all_grad_tensors[i] = None + release_deepcompile_grad_buffers = getattr(self, "_deepcompile_z1_release_grad_buffers", None) + if release_deepcompile_grad_buffers is not None: + release_deepcompile_grad_buffers(i) self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) self.timers(OPTIMIZER_GRADIENTS_TIMER).stop() diff --git a/tests/unit/compile/__init__.py b/tests/unit/compile/__init__.py new file mode 100644 index 000000000000..6f5f5619004b --- /dev/null +++ b/tests/unit/compile/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/v1/compile/test_compile_zero.py b/tests/unit/v1/compile/test_compile_zero.py index 16ad12d30f13..f91c33b8b011 100644 --- a/tests/unit/v1/compile/test_compile_zero.py +++ b/tests/unit/v1/compile/test_compile_zero.py @@ -12,6 +12,7 @@ from unit.v1.compile.util import compare_loss from unit.common import DistributedTest +from unit.simple_model import SimpleModel from unit.util import bf16_required_version_check, skip_on_arch import deepspeed from deepspeed.ops.aio import AsyncIOBuilder @@ -117,6 +118,54 @@ def test(self, zero_stage, dtype, deepcompile): # Need warmup steps compare_loss(self, config_dict, dtype, iteration=10) + def test_zero1_releases_grad_buffers_after_optimizer_step(self): + if not required_torch_version(min_version=2.6): + pytest.skip("DeepCompile requires PyTorch >= v2.6") + + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + dtype = torch.float32 + hidden_dim = 10 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": 1, + }, + "compile": { + "deepcompile": True + } + } + + model = SimpleModel(hidden_dim) + engine, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + engine.compile() + + device = torch.device(get_accelerator().current_device_name()) + x = torch.randn(config_dict["train_micro_batch_size_per_gpu"], hidden_dim, device=device, dtype=dtype) + y = torch.randn_like(x) + + loss = engine(x, y) + engine.backward(loss) + + optimizer = engine.optimizer + current_grad_buffers = optimizer._deepcompile_z1_current_grad_buffers + assert current_grad_buffers + assert all(group_buffers is not None for group_buffers in current_grad_buffers.values()) + assert any(buffer.numel() > 0 for group_buffers in current_grad_buffers.values() for buffer in group_buffers) + + engine.step() + + assert all(group_buffers is None for group_buffers in optimizer._deepcompile_z1_current_grad_buffers.values()) + engine.destroy() + @pytest.mark.parametrize('dtype', [torch.float32]) @pytest.mark.parametrize('zero_stage', [3]) def test_padded_shard_handling(self, zero_stage, dtype): From 644caa2e7d281f85f18301dd91581debb8b83016 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 29 May 2026 23:28:08 -0700 Subject: [PATCH 2/3] Release DeepCompile reduce buckets after final graph Signed-off-by: Masahiro Tanaka --- csrc/compile/deepcompile.cpp | 9 +++++++-- csrc/compile/init.cpp | 2 +- csrc/compile/z3.h | 4 ++-- csrc/includes/deepcompile.h | 4 ---- deepspeed/compile/fx.py | 10 ++++++++-- deepspeed/compile/passes/zero1_compile.py | 12 +++++++----- deepspeed/compile/passes/zero3_compile.py | 5 +++-- tests/unit/v1/compile/test_compile_fx.py | 13 +++++++++++-- 8 files changed, 39 insertions(+), 20 deletions(-) diff --git a/csrc/compile/deepcompile.cpp b/csrc/compile/deepcompile.cpp index b09223b214c5..5274ab887c34 100644 --- a/csrc/compile/deepcompile.cpp +++ b/csrc/compile/deepcompile.cpp @@ -179,12 +179,17 @@ void start_backward(bool update) for (auto& it : executors) { it.second->startBackward(update); } } -void end_backward(const c10::IValue& deps, long graph_id) +void end_backward(const c10::IValue& deps, long graph_id, bool release_reduce_buckets) { auto executor = getExecutor(graph_id, executors); executor->endBackward(); + if (release_reduce_buckets) { + // reduce_buckets is shared across graph executors, so release it once + // after the final backward graph has flushed its pending reductions. + reduce_buckets->clear(); + } } -void end_backward_meta(const c10::IValue& deps, long graph_id) {} +void end_backward_meta(const c10::IValue& deps, long graph_id, bool release_reduce_buckets) {} } // namespace dc diff --git a/csrc/compile/init.cpp b/csrc/compile/init.cpp index ae5d52775a02..7ed378fd3cdd 100644 --- a/csrc/compile/init.cpp +++ b/csrc/compile/init.cpp @@ -24,7 +24,7 @@ TORCH_LIBRARY(dc, m) m.def("wait_reload(Tensor a, int id, int id) -> Tensor"); m.def("offload_parameter(Tensor a, int id, int id) -> ()"); m.def("reload_parameter(Tensor a, int id, int id) -> ()"); - m.def("end_backward(Any deps, int graph_id) -> ()"); + m.def("end_backward(Any deps, int graph_id, bool release_reduce_buckets) -> ()"); m.def("test_call(Tensor a) -> Tensor"); } diff --git a/csrc/compile/z3.h b/csrc/compile/z3.h index bc095c86cfb6..1cbd3166a4db 100644 --- a/csrc/compile/z3.h +++ b/csrc/compile/z3.h @@ -53,6 +53,6 @@ void reload_parameter(at::Tensor tensor, long graph_id, long id); void offload_parameter(at::Tensor tensor, long graph_id, long id); void reload_parameter_meta(at::Tensor tensor, long graph_id, long id); void offload_parameter_meta(at::Tensor tensor, long graph_id, long id); -void end_backward(const c10::IValue& deps, long graph_id); -void end_backward_meta(const c10::IValue& deps, long graph_id); +void end_backward(const c10::IValue& deps, long graph_id, bool release_reduce_buckets); +void end_backward_meta(const c10::IValue& deps, long graph_id, bool release_reduce_buckets); } // namespace dc diff --git a/csrc/includes/deepcompile.h b/csrc/includes/deepcompile.h index df9ee0bb01c4..7d920cad523c 100644 --- a/csrc/includes/deepcompile.h +++ b/csrc/includes/deepcompile.h @@ -488,10 +488,6 @@ class CustomOpExecutor { // This synchronization ensures all of reduce calls are done before optimizer's step. at::cuda::stream_synchronize(rs_stream_); - - // Match ZeRO's IPG buffer lifecycle: reduction buckets are backward work buffers and - // should not stay allocated while Adam creates optimizer temporaries. - reduce_buckets_->clear(); } virtual at::Tensor reduceGrad(at::Tensor grad_tensor, long ds_id) diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py index 51a2147ab7c2..232f967fa328 100644 --- a/deepspeed/compile/fx.py +++ b/deepspeed/compile/fx.py @@ -20,13 +20,19 @@ def get_output_node(graph: Graph): raise ValueError("No output node found") -def add_end_backward(graph: Graph, graph_id: int): +def should_release_reduce_buckets(graph_order, graph_id: int) -> bool: + backward_graph_ids = [g_id for g_id, needs_backward in graph_order if needs_backward] + return not backward_graph_ids or graph_id == backward_graph_ids[0] + + +def add_end_backward(graph: Graph, graph_id: int, release_reduce_buckets: bool = True): reduce_nodes = [n for n in graph.nodes if n.target == torch.ops.dc.reduce_grad.default] if len(reduce_nodes) == 0: return with graph.inserting_before(get_output_node(graph)): - graph.create_node("call_function", torch.ops.dc.end_backward.default, (reduce_nodes, graph_id)) + graph.create_node("call_function", torch.ops.dc.end_backward.default, + (reduce_nodes, graph_id, release_reduce_buckets)) def replace_reduce_outputs_with_none(graph: Graph): diff --git a/deepspeed/compile/passes/zero1_compile.py b/deepspeed/compile/passes/zero1_compile.py index c4da6ad82fa3..9e0023fab1a8 100644 --- a/deepspeed/compile/passes/zero1_compile.py +++ b/deepspeed/compile/passes/zero1_compile.py @@ -9,7 +9,8 @@ from torch.fx import GraphModule from ..util import get_deepcompile_handle -from ..fx import add_postprocess, move_primals_to_head, _make_node_meta, add_end_backward, replace_reduce_outputs_with_none +from ..fx import (add_postprocess, move_primals_to_head, _make_node_meta, add_end_backward, + replace_reduce_outputs_with_none, should_release_reduce_buckets) NAME = "zero1_compile" @@ -27,7 +28,8 @@ def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_ma return gm -def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule: +def add_z1_reduce_bw(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], + param_manager) -> GraphModule: graph = gm.graph pm = param_manager[graph_id] @@ -50,7 +52,7 @@ def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModu gm.graph = move_primals_to_head(graph) - add_end_backward(gm.graph, graph_id) + add_end_backward(gm.graph, graph_id, should_release_reduce_buckets(graph_order, graph_id)) replace_reduce_outputs_with_none(gm.graph) return gm @@ -59,12 +61,12 @@ def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModu def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule: if bwd: - return add_z1_reduce_bw(gm, graph_id, param_manager) + return add_z1_reduce_bw(gm, graph_id, graph_order, param_manager) return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=False) def add_z2_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule: if bwd: - return add_z1_reduce_bw(gm, graph_id, param_manager) + return add_z1_reduce_bw(gm, graph_id, graph_order, param_manager) return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=True) diff --git a/deepspeed/compile/passes/zero3_compile.py b/deepspeed/compile/passes/zero3_compile.py index f09a4dee2adf..c46396a5ad45 100644 --- a/deepspeed/compile/passes/zero3_compile.py +++ b/deepspeed/compile/passes/zero3_compile.py @@ -11,7 +11,8 @@ from torch.fx import Graph, Node, GraphModule from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op -from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head, add_end_backward, replace_reduce_outputs_with_none +from ..fx import (add_postprocess, _make_node_meta, get_output_node, move_primals_to_head, add_end_backward, + replace_reduce_outputs_with_none, should_release_reduce_buckets) from ..profilers.graph_profile import ProfilingInterpreter from ..list_schedule import fast_free_schedule @@ -209,7 +210,7 @@ def add_z3_gather_release_bw(gm: GraphModule, 0, # unused debug_log=debug_log) - add_end_backward(gm.graph, graph_id) + add_end_backward(gm.graph, graph_id, should_release_reduce_buckets(graph_order, graph_id)) replace_reduce_outputs_with_none(gm.graph) return gm diff --git a/tests/unit/v1/compile/test_compile_fx.py b/tests/unit/v1/compile/test_compile_fx.py index e7435efc2572..b0cbb7269179 100644 --- a/tests/unit/v1/compile/test_compile_fx.py +++ b/tests/unit/v1/compile/test_compile_fx.py @@ -7,10 +7,18 @@ import torch from torch.fx import Graph -from deepspeed.compile.fx import add_end_backward, replace_reduce_outputs_with_none, get_output_node +from deepspeed.compile.fx import (add_end_backward, replace_reduce_outputs_with_none, get_output_node, + should_release_reduce_buckets) from deepspeed.compile.util import get_deepcompile_handle, is_deepcompile_supported +def test_should_release_reduce_buckets_only_on_last_backward_graph(): + graph_order = [(11, True), (22, False), (33, True)] + + assert should_release_reduce_buckets(graph_order, 11) + assert not should_release_reduce_buckets(graph_order, 33) + + @pytest.mark.skipif(not is_deepcompile_supported(), reason="DeepCompile requires CUDA and supported PyTorch") @pytest.mark.sequential def test_end_backward_depends_on_all_reduce_nodes(): @@ -27,10 +35,11 @@ def test_end_backward_depends_on_all_reduce_nodes(): graph.lint() end_backward = next(n for n in graph.nodes if n.target == torch.ops.dc.end_backward.default) - deps, graph_id = end_backward.args + deps, graph_id, release_reduce_buckets = end_backward.args output_node = get_output_node(graph) assert graph_id == 7 + assert release_reduce_buckets is True assert list(deps) == [reduce_a, reduce_b] assert end_backward in reduce_a.users assert end_backward in reduce_b.users From eb018f3d8e038b76c531dbf6855347c75c14f6ef Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sat, 30 May 2026 00:02:39 -0700 Subject: [PATCH 3/3] Remove added DeepCompile test Signed-off-by: Masahiro Tanaka --- tests/unit/v1/compile/test_compile_fx.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/unit/v1/compile/test_compile_fx.py b/tests/unit/v1/compile/test_compile_fx.py index b0cbb7269179..c23411817e5f 100644 --- a/tests/unit/v1/compile/test_compile_fx.py +++ b/tests/unit/v1/compile/test_compile_fx.py @@ -7,18 +7,10 @@ import torch from torch.fx import Graph -from deepspeed.compile.fx import (add_end_backward, replace_reduce_outputs_with_none, get_output_node, - should_release_reduce_buckets) +from deepspeed.compile.fx import add_end_backward, replace_reduce_outputs_with_none, get_output_node from deepspeed.compile.util import get_deepcompile_handle, is_deepcompile_supported -def test_should_release_reduce_buckets_only_on_last_backward_graph(): - graph_order = [(11, True), (22, False), (33, True)] - - assert should_release_reduce_buckets(graph_order, 11) - assert not should_release_reduce_buckets(graph_order, 33) - - @pytest.mark.skipif(not is_deepcompile_supported(), reason="DeepCompile requires CUDA and supported PyTorch") @pytest.mark.sequential def test_end_backward_depends_on_all_reduce_nodes(): @@ -35,11 +27,10 @@ def test_end_backward_depends_on_all_reduce_nodes(): graph.lint() end_backward = next(n for n in graph.nodes if n.target == torch.ops.dc.end_backward.default) - deps, graph_id, release_reduce_buckets = end_backward.args + deps, graph_id, _ = end_backward.args output_node = get_output_node(graph) assert graph_id == 7 - assert release_reduce_buckets is True assert list(deps) == [reduce_a, reduce_b] assert end_backward in reduce_a.users assert end_backward in reduce_b.users