Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions csrc/compile/z3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,15 @@ class Z3CustomOpExecutor : public CustomOpExecutor {

blockCopyEvents(scalar_type);

// Calculate temporary buffer size for accumulated gradients
// Calculate temporary buffer size for accumulated gradients or
// communication/storage dtype mismatches.
int64_t tmp_recv_numel = 0;
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
if (has_acc_grad_.at(t.getDSId())) {
tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel();
}
auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
int64_t recv_numel = recv_buf.numel();
bool use_tmp_recv = recv_numel > 0 && (has_acc_grad_.at(t.getDSId()) ||
recv_buf.scalar_type() != scalar_type);
if (use_tmp_recv) { tmp_recv_numel += recv_numel; }
}

// Allocate temporary buffer if needed
Expand All @@ -268,37 +271,51 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
bool acc_grad = has_acc_grad_.at(t.getDSId());
int64_t recv_numel = recv_buf.numel();
bool use_tmp_recv =
recv_numel > 0 && (acc_grad || recv_buf.scalar_type() != scalar_type);

if (acc_grad) {
if (use_tmp_recv) {
recv_buf =
tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())});
tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_numel)});
}

ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(),
recv_buf.data_ptr(),
recv_buf.numel(),
recv_numel,
get_nccl_data_type(scalar_type),
getReductionOp(),
nccl_comm_,
rs_stream_);
if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); }

if (acc_grad) { offset += recv_buf.numel(); }
if (use_tmp_recv) { offset += recv_numel; }
}
ncclGroupEnd();

// Handle gradient accumulation with temporary buffer
// Move temporary receive results into the ZeRO grad buffer.
{
at::cuda::CUDAStreamGuard guard(rs_stream_);
int64_t offset = 0;
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
bool acc_grad = has_acc_grad_.at(t.getDSId());

if (acc_grad) {
auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
recv_buf.add_(tmp_recv_buf.index(
{torch::indexing::Slice(offset, offset + recv_buf.numel())}));
offset += recv_buf.numel();
int64_t recv_numel = recv_buf.numel();
bool use_tmp_recv =
recv_numel > 0 && (acc_grad || recv_buf.scalar_type() != scalar_type);

if (use_tmp_recv) {
auto reduced_slice =
tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_numel)});
if (reduced_slice.scalar_type() != recv_buf.scalar_type()) {
reduced_slice = reduced_slice.to(recv_buf.scalar_type());
}
if (acc_grad) {
recv_buf.add_(reduced_slice);
} else {
recv_buf.copy_(reduced_slice, true);
}
offset += recv_numel;
}
has_acc_grad_[t.getDSId()] = true;
}
Expand Down Expand Up @@ -462,9 +479,11 @@ void register_z3_param(long ds_id,
const std::vector<int64_t>& ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool persistent)
bool persistent,
std::optional<at::ScalarType> expected_grad_dtype)
{
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
param_registry->registerParam(
ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent, expected_grad_dtype);
if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }

// Validate that padded shard sizes are uniform across ranks at registration time
Expand Down
3 changes: 2 additions & 1 deletion csrc/compile/z3.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ void register_z3_param(long ds_id,
const std::vector<int64_t>& ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool persistent);
bool persistent,
std::optional<at::ScalarType> expected_grad_dtype);
at::Tensor allgather_param(at::Tensor param_tensor,
long graph_id,
long ds_id,
Expand Down
36 changes: 26 additions & 10 deletions csrc/includes/deepcompile.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,18 @@ class DSParam {
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool partitioned,
int64_t offset, // for Z1
bool persistent // for Z3
)
int64_t offset, // for Z1
bool persistent, // for Z3
std::optional<at::ScalarType> expected_grad_dtype = std::nullopt)
: id_(id),
shape_(std::move(ds_shape)),
ds_tensor_(ds_tensor),
ds_dtype_(ds_tensor.scalar_type()),
grad_buffer_(grad_buffer),
partitioned_(partitioned),
offset_(offset),
persistent_(persistent)
persistent_(persistent),
expected_grad_dtype_(expected_grad_dtype)
{
}

Expand Down Expand Up @@ -295,6 +296,7 @@ class DSParam {
int64_t getOffset() const { return offset_; }
void setPersistent(bool persistent) { persistent_ = persistent; }
bool isPersistent() const { return persistent_; }
std::optional<at::ScalarType> getExpectedGradDtype() const { return expected_grad_dtype_; }

void offload()
{
Expand Down Expand Up @@ -365,6 +367,7 @@ class DSParam {
bool partitioned_;
int64_t offset_; // for Z1
bool persistent_; // for Z3
std::optional<at::ScalarType> expected_grad_dtype_;
mutable bool is_reloaded = false;

std::optional<at::cuda::CUDAStream> offload_stream_;
Expand All @@ -384,14 +387,20 @@ class DSParamRegistry {
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool partitioned,
int64_t offset, // for Z1
bool persistent // for Z3
)
int64_t offset, // for Z1
bool persistent, // for Z3
std::optional<at::ScalarType> expected_grad_dtype = std::nullopt)
{
grad_buffer.zero_();
params_.emplace(
ds_id,
DSParam(ds_id, ds_shape, ds_tensor, grad_buffer, partitioned, offset, persistent));
params_.emplace(ds_id,
DSParam(ds_id,
ds_shape,
ds_tensor,
grad_buffer,
partitioned,
offset,
persistent,
expected_grad_dtype));
valid_[ds_id] = false;
}

Expand Down Expand Up @@ -483,6 +492,13 @@ class CustomOpExecutor {
{
int world_size = process_group_->getSize();
const DSParam& param = param_registry_->getParam(ds_id);
const auto expected_grad_dtype = param.getExpectedGradDtype();
// Match PyTorch's leaf grad accumulation dtype before bucket selection:
// https://docs.pytorch.org/docs/main/generated/torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.html#torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.grad_dtype
if (expected_grad_dtype.has_value() &&
grad_tensor.scalar_type() != expected_grad_dtype.value()) {
grad_tensor = grad_tensor.to(expected_grad_dtype.value());
}
const auto scalar_type = grad_tensor.scalar_type();
std::shared_ptr<ReduceBucket> reduce_bucket = reduce_buckets_->getBuffer(scalar_type);

Expand Down
17 changes: 16 additions & 1 deletion deepspeed/compile/init_z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@

WARMUP = 5

_MISSING = object()


def _resolve_expected_grad_dtype(param):
# Match PyTorch's leaf grad accumulation contract. grad_dtype can be a
# dtype, or None to allow any incoming gradient dtype:
# https://docs.pytorch.org/docs/main/generated/torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.html#torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.grad_dtype
grad_dtype = getattr(param, "grad_dtype", _MISSING)
if grad_dtype is None:
return None
if grad_dtype is not _MISSING:
return grad_dtype
return param.dtype


def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):

Expand Down Expand Up @@ -56,7 +70,8 @@ def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):

# Disable persistent param
p.ds_persist = False
dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)
dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist,
_resolve_expected_grad_dtype(p))

if schedule is None:
schedule = []
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/compile/test_zero3_grad_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed.compile.init_z3 import _resolve_expected_grad_dtype


def test_missing_grad_dtype_attribute_falls_back_to_param_dtype():

class FakeParam:
dtype = torch.bfloat16

assert _resolve_expected_grad_dtype(FakeParam()) is torch.bfloat16


def test_explicit_none_grad_dtype_allows_raw_grad_dtype():
param = torch.empty((2, 3), dtype=torch.bfloat16)
param.grad_dtype = None

assert _resolve_expected_grad_dtype(param) is None


def test_explicit_grad_dtype_is_preserved():
param = torch.empty((2, 3), dtype=torch.bfloat16)
param.grad_dtype = torch.float32

assert _resolve_expected_grad_dtype(param) is torch.float32
Loading