From 2c15ad9f7b2a74bc01102a168ed61c5fdc7c495e Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 29 May 2026 18:54:24 -0700 Subject: [PATCH] compile: normalize ZeRO-3 DeepCompile grad dtype Signed-off-by: Masahiro Tanaka --- csrc/compile/z3.cpp | 53 ++++++++++++++------- csrc/compile/z3.h | 3 +- csrc/includes/deepcompile.h | 36 ++++++++++---- deepspeed/compile/init_z3.py | 17 ++++++- tests/unit/compile/test_zero3_grad_dtype.py | 30 ++++++++++++ 5 files changed, 110 insertions(+), 29 deletions(-) create mode 100644 tests/unit/compile/test_zero3_grad_dtype.py diff --git a/csrc/compile/z3.cpp b/csrc/compile/z3.cpp index fdc146b4ec02..153c6156adad 100644 --- a/csrc/compile/z3.cpp +++ b/csrc/compile/z3.cpp @@ -244,12 +244,15 @@ class Z3CustomOpExecutor : public CustomOpExecutor { blockCopyEvents(scalar_type); - // Calculate temporary buffer size for accumulated gradients + // Calculate temporary buffer size for accumulated gradients or + // communication/storage dtype mismatches. int64_t tmp_recv_numel = 0; for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { - if (has_acc_grad_.at(t.getDSId())) { - tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel(); - } + auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer(); + int64_t recv_numel = recv_buf.numel(); + bool use_tmp_recv = recv_numel > 0 && (has_acc_grad_.at(t.getDSId()) || + recv_buf.scalar_type() != scalar_type); + if (use_tmp_recv) { tmp_recv_numel += recv_numel; } } // Allocate temporary buffer if needed @@ -268,37 +271,51 @@ class Z3CustomOpExecutor : public CustomOpExecutor { for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer(); bool acc_grad = has_acc_grad_.at(t.getDSId()); + int64_t recv_numel = recv_buf.numel(); + bool use_tmp_recv = + recv_numel > 0 && (acc_grad || recv_buf.scalar_type() != scalar_type); - if (acc_grad) { + if (use_tmp_recv) { recv_buf = - tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())}); + tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_numel)}); } ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(), recv_buf.data_ptr(), - recv_buf.numel(), + recv_numel, get_nccl_data_type(scalar_type), getReductionOp(), nccl_comm_, rs_stream_); if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); } - if (acc_grad) { offset += recv_buf.numel(); } + if (use_tmp_recv) { offset += recv_numel; } } ncclGroupEnd(); - // Handle gradient accumulation with temporary buffer + // Move temporary receive results into the ZeRO grad buffer. { at::cuda::CUDAStreamGuard guard(rs_stream_); int64_t offset = 0; for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer(); bool acc_grad = has_acc_grad_.at(t.getDSId()); - - if (acc_grad) { - auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer(); - recv_buf.add_(tmp_recv_buf.index( - {torch::indexing::Slice(offset, offset + recv_buf.numel())})); - offset += recv_buf.numel(); + int64_t recv_numel = recv_buf.numel(); + bool use_tmp_recv = + recv_numel > 0 && (acc_grad || recv_buf.scalar_type() != scalar_type); + + if (use_tmp_recv) { + auto reduced_slice = + tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_numel)}); + if (reduced_slice.scalar_type() != recv_buf.scalar_type()) { + reduced_slice = reduced_slice.to(recv_buf.scalar_type()); + } + if (acc_grad) { + recv_buf.add_(reduced_slice); + } else { + recv_buf.copy_(reduced_slice, true); + } + offset += recv_numel; } has_acc_grad_[t.getDSId()] = true; } @@ -462,9 +479,11 @@ void register_z3_param(long ds_id, const std::vector& ds_shape, at::Tensor ds_tensor, at::Tensor grad_buffer, - bool persistent) + bool persistent, + std::optional expected_grad_dtype) { - param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent); + param_registry->registerParam( + ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent, expected_grad_dtype); if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); } // Validate that padded shard sizes are uniform across ranks at registration time diff --git a/csrc/compile/z3.h b/csrc/compile/z3.h index bc095c86cfb6..b0192df95fc3 100644 --- a/csrc/compile/z3.h +++ b/csrc/compile/z3.h @@ -20,7 +20,8 @@ void register_z3_param(long ds_id, const std::vector& ds_shape, at::Tensor ds_tensor, at::Tensor grad_buffer, - bool persistent); + bool persistent, + std::optional expected_grad_dtype); at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id, diff --git a/csrc/includes/deepcompile.h b/csrc/includes/deepcompile.h index 7016d4a99310..142e98603123 100644 --- a/csrc/includes/deepcompile.h +++ b/csrc/includes/deepcompile.h @@ -256,9 +256,9 @@ class DSParam { at::Tensor ds_tensor, at::Tensor grad_buffer, bool partitioned, - int64_t offset, // for Z1 - bool persistent // for Z3 - ) + int64_t offset, // for Z1 + bool persistent, // for Z3 + std::optional expected_grad_dtype = std::nullopt) : id_(id), shape_(std::move(ds_shape)), ds_tensor_(ds_tensor), @@ -266,7 +266,8 @@ class DSParam { grad_buffer_(grad_buffer), partitioned_(partitioned), offset_(offset), - persistent_(persistent) + persistent_(persistent), + expected_grad_dtype_(expected_grad_dtype) { } @@ -295,6 +296,7 @@ class DSParam { int64_t getOffset() const { return offset_; } void setPersistent(bool persistent) { persistent_ = persistent; } bool isPersistent() const { return persistent_; } + std::optional getExpectedGradDtype() const { return expected_grad_dtype_; } void offload() { @@ -365,6 +367,7 @@ class DSParam { bool partitioned_; int64_t offset_; // for Z1 bool persistent_; // for Z3 + std::optional expected_grad_dtype_; mutable bool is_reloaded = false; std::optional offload_stream_; @@ -384,14 +387,20 @@ class DSParamRegistry { at::Tensor ds_tensor, at::Tensor grad_buffer, bool partitioned, - int64_t offset, // for Z1 - bool persistent // for Z3 - ) + int64_t offset, // for Z1 + bool persistent, // for Z3 + std::optional expected_grad_dtype = std::nullopt) { grad_buffer.zero_(); - params_.emplace( - ds_id, - DSParam(ds_id, ds_shape, ds_tensor, grad_buffer, partitioned, offset, persistent)); + params_.emplace(ds_id, + DSParam(ds_id, + ds_shape, + ds_tensor, + grad_buffer, + partitioned, + offset, + persistent, + expected_grad_dtype)); valid_[ds_id] = false; } @@ -483,6 +492,13 @@ class CustomOpExecutor { { int world_size = process_group_->getSize(); const DSParam& param = param_registry_->getParam(ds_id); + const auto expected_grad_dtype = param.getExpectedGradDtype(); + // Match PyTorch's leaf grad accumulation dtype before bucket selection: + // https://docs.pytorch.org/docs/main/generated/torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.html#torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.grad_dtype + if (expected_grad_dtype.has_value() && + grad_tensor.scalar_type() != expected_grad_dtype.value()) { + grad_tensor = grad_tensor.to(expected_grad_dtype.value()); + } const auto scalar_type = grad_tensor.scalar_type(); std::shared_ptr reduce_bucket = reduce_buckets_->getBuffer(scalar_type); diff --git a/deepspeed/compile/init_z3.py b/deepspeed/compile/init_z3.py index 2b9d404b3781..f3e8570cf9f9 100644 --- a/deepspeed/compile/init_z3.py +++ b/deepspeed/compile/init_z3.py @@ -17,6 +17,20 @@ WARMUP = 5 +_MISSING = object() + + +def _resolve_expected_grad_dtype(param): + # Match PyTorch's leaf grad accumulation contract. grad_dtype can be a + # dtype, or None to allow any incoming gradient dtype: + # https://docs.pytorch.org/docs/main/generated/torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.html#torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.grad_dtype + grad_dtype = getattr(param, "grad_dtype", _MISSING) + if grad_dtype is None: + return None + if grad_dtype is not _MISSING: + return grad_dtype + return param.dtype + def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None): @@ -56,7 +70,8 @@ def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None): # Disable persistent param p.ds_persist = False - dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist) + dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist, + _resolve_expected_grad_dtype(p)) if schedule is None: schedule = [] diff --git a/tests/unit/compile/test_zero3_grad_dtype.py b/tests/unit/compile/test_zero3_grad_dtype.py new file mode 100644 index 000000000000..a2fdaae89d9d --- /dev/null +++ b/tests/unit/compile/test_zero3_grad_dtype.py @@ -0,0 +1,30 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.compile.init_z3 import _resolve_expected_grad_dtype + + +def test_missing_grad_dtype_attribute_falls_back_to_param_dtype(): + + class FakeParam: + dtype = torch.bfloat16 + + assert _resolve_expected_grad_dtype(FakeParam()) is torch.bfloat16 + + +def test_explicit_none_grad_dtype_allows_raw_grad_dtype(): + param = torch.empty((2, 3), dtype=torch.bfloat16) + param.grad_dtype = None + + assert _resolve_expected_grad_dtype(param) is None + + +def test_explicit_grad_dtype_is_preserved(): + param = torch.empty((2, 3), dtype=torch.bfloat16) + param.grad_dtype = torch.float32 + + assert _resolve_expected_grad_dtype(param) is torch.float32