Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions csrc/compile/deepcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,17 @@ void start_backward(bool update)
for (auto& it : executors) { it.second->startBackward(update); }
}

void end_backward(const c10::IValue& deps, long graph_id)
void end_backward(const c10::IValue& deps, long graph_id, bool release_reduce_buckets)
{
auto executor = getExecutor<CustomOpExecutor>(graph_id, executors);
executor->endBackward();
if (release_reduce_buckets) {
// reduce_buckets is shared across graph executors, so release it once
// after the final backward graph has flushed its pending reductions.
reduce_buckets->clear();
}
}

void end_backward_meta(const c10::IValue& deps, long graph_id) {}
void end_backward_meta(const c10::IValue& deps, long graph_id, bool release_reduce_buckets) {}

} // namespace dc
5 changes: 4 additions & 1 deletion csrc/compile/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ TORCH_LIBRARY(dc, m)
m.def("wait_reload(Tensor a, int id, int id) -> Tensor");
m.def("offload_parameter(Tensor a, int id, int id) -> ()");
m.def("reload_parameter(Tensor a, int id, int id) -> ()");
m.def("end_backward(Any deps, int graph_id) -> ()");
m.def("end_backward(Any deps, int graph_id, bool release_reduce_buckets) -> ()");

m.def("test_call(Tensor a) -> Tensor");
}
Expand Down Expand Up @@ -92,6 +92,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("init", &dc::init, "Set the process group");
m.def("cleanup", &dc::cleanup, "Cleanup the process group");
m.def("register_param", &dc::register_param, "Register a parameter");
m.def("update_param_grad_buffer",
&dc::update_param_grad_buffer,
"Update a registered parameter grad buffer");
m.def("register_graph_z1",
&dc::register_graph_z1,
"Register graph with a list of ds parameter ids");
Expand Down
5 changes: 5 additions & 0 deletions csrc/compile/z1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,9 @@ void register_param(long ds_id,
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, false, offset, false);
}

void update_param_grad_buffer(long ds_id, at::Tensor grad_buffer, int64_t offset)
{
param_registry->updateGradBuffer(ds_id, grad_buffer, offset);
}

} // namespace dc
1 change: 1 addition & 0 deletions csrc/compile/z1.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ void register_param(long ds_id,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
int64_t offset);
void update_param_grad_buffer(long ds_id, at::Tensor grad_buffer, int64_t offset);
} // namespace dc
4 changes: 2 additions & 2 deletions csrc/compile/z3.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ void reload_parameter(at::Tensor tensor, long graph_id, long id);
void offload_parameter(at::Tensor tensor, long graph_id, long id);
void reload_parameter_meta(at::Tensor tensor, long graph_id, long id);
void offload_parameter_meta(at::Tensor tensor, long graph_id, long id);
void end_backward(const c10::IValue& deps, long graph_id);
void end_backward_meta(const c10::IValue& deps, long graph_id);
void end_backward(const c10::IValue& deps, long graph_id, bool release_reduce_buckets);
void end_backward_meta(const c10::IValue& deps, long graph_id, bool release_reduce_buckets);
} // namespace dc
11 changes: 11 additions & 0 deletions csrc/includes/deepcompile.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ class DSParam {
return ds_tensor_;
}
at::Tensor getGradBuffer() const { return grad_buffer_; }
void setGradBuffer(at::Tensor grad_buffer, int64_t offset)
{
grad_buffer_ = grad_buffer;
offset_ = offset;
}
bool isPartitioned() const { return partitioned_; }
int64_t getOffset() const { return offset_; }
void setPersistent(bool persistent) { persistent_ = persistent; }
Expand Down Expand Up @@ -395,6 +400,12 @@ class DSParamRegistry {
valid_[ds_id] = false;
}

void updateGradBuffer(long ds_id, at::Tensor grad_buffer, int64_t offset)
{
if (grad_buffer.numel() > 0) { grad_buffer.zero_(); }
params_.at(ds_id).setGradBuffer(grad_buffer, offset);
}

void registerGatheredParam(long ds_id, at::Tensor ds_tensor)
{
gathered_params_.emplace(ds_id, ds_tensor);
Expand Down
10 changes: 8 additions & 2 deletions deepspeed/compile/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@ def get_output_node(graph: Graph):
raise ValueError("No output node found")


def add_end_backward(graph: Graph, graph_id: int):
def should_release_reduce_buckets(graph_order, graph_id: int) -> bool:
backward_graph_ids = [g_id for g_id, needs_backward in graph_order if needs_backward]
return not backward_graph_ids or graph_id == backward_graph_ids[0]


def add_end_backward(graph: Graph, graph_id: int, release_reduce_buckets: bool = True):
reduce_nodes = [n for n in graph.nodes if n.target == torch.ops.dc.reduce_grad.default]
if len(reduce_nodes) == 0:
return

with graph.inserting_before(get_output_node(graph)):
graph.create_node("call_function", torch.ops.dc.end_backward.default, (reduce_nodes, graph_id))
graph.create_node("call_function", torch.ops.dc.end_backward.default,
(reduce_nodes, graph_id, release_reduce_buckets))


def replace_reduce_outputs_with_none(graph: Graph):
Expand Down
187 changes: 140 additions & 47 deletions deepspeed/compile/init_z1.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,67 @@
WARMUP = 5


def _empty_grad_buffer(param):
return torch.empty([0], dtype=param.dtype, device=param.device)


def _build_partition_grad_views(optimizer, group_idx):
missing = object()
original_all_grad_tensors = optimizer.all_grad_tensors.get(group_idx, missing)
optimizer.all_grad_tensors[group_idx] = optimizer.get_all_grad_tensors(optimizer.params_in_partition[group_idx],
optimizer.gradient_accumulation_dtype)
try:
return optimizer.get_flat_partition(optimizer.params_in_partition[group_idx],
optimizer.first_offset[group_idx],
optimizer.partition_size[group_idx],
dtype=optimizer.gradient_accumulation_dtype,
device=get_accelerator().current_device_name(),
param_group_idx=group_idx,
return_tensor_list=True)
finally:
if original_all_grad_tensors is missing:
optimizer.all_grad_tensors.pop(group_idx, None)
else:
optimizer.all_grad_tensors[group_idx] = original_all_grad_tensors


def _build_flat_partition_grad_views(optimizer, group_idx):
partition_size = int(optimizer.partition_size[group_idx])
dtype = optimizer.gradient_accumulation_dtype
device = get_accelerator().current_device_name()
flat_buffer = torch.zeros(partition_size, dtype=dtype, device=device)

views = []
current_size = 0
for i, tensor in enumerate(optimizer.params_in_partition[group_idx]):
num_elements = tensor.numel()
tensor_offset = 0

if i == 0 and optimizer.first_offset[group_idx] > 0:
tensor_offset = int(optimizer.first_offset[group_idx])
num_elements -= tensor_offset

if num_elements > partition_size - current_size:
num_elements = partition_size - current_size

if num_elements <= 0:
continue

view = flat_buffer.narrow(0, current_size, int(num_elements))
if tensor_offset == 0 and num_elements == tensor.numel():
view = view.view(tensor.shape)
views.append(view)
current_size += int(num_elements)

if current_size >= partition_size:
break

if current_size < partition_size:
views.append(flat_buffer.narrow(0, current_size, partition_size - current_size))

return flat_buffer, views


def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_z2=False):

optimizer = engine.optimizer
Expand All @@ -26,53 +87,85 @@ def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_
dc = get_deepcompile_handle()
dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size())

grad_buffer = {}

# Save original all_grad_tensors state as we temporarily modify it
original_all_grad_tensors = optimizer.all_grad_tensors.copy() if hasattr(optimizer, 'all_grad_tensors') else {}

for i, group in enumerate(optimizer.bit16_groups):
# Temporarily populate all_grad_tensors for get_flat_partition call
# This is needed because get_flat_partition accesses all_grad_tensors[param_group_idx][i]
# but it's empty during initialization
if i not in optimizer.all_grad_tensors or optimizer.all_grad_tensors[i] is None:
optimizer.all_grad_tensors[i] = optimizer.get_all_grad_tensors(optimizer.params_in_partition[i],
optimizer.gradient_accumulation_dtype)

grad_buffer[i] = optimizer.get_flat_partition(optimizer.params_in_partition[i],
optimizer.first_offset[i],
optimizer.partition_size[i],
dtype=optimizer.gradient_accumulation_dtype,
device=get_accelerator().current_device_name(),
param_group_idx=i,
return_tensor_list=True)
grad_buffer[i] = [p.clone().detach() for p in grad_buffer[i]] # Maybe not necessary

index_in_partition = 0
first_in_partition = True
for p in group:
param_id = optimizer.get_param_id(p)
p.param_id = param_id
in_partition = optimizer.is_param_in_current_partition[param_id]

if in_partition:
buf = grad_buffer[i][index_in_partition]
offset = optimizer.first_offset[i] if first_in_partition else 0
# print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf={buf.shape} partition_offset={offset}")
dc.register_param(p.param_id, p.shape, p, buf, int(offset))
index_in_partition += 1
first_in_partition = False
else:
# print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf=None")
dc.register_param(p.param_id, p.shape, p, torch.empty([0], dtype=p.dtype, device=p.device), 0)

# Restore original all_grad_tensors state
optimizer.all_grad_tensors = original_all_grad_tensors

def set_grad_buffer():
optimizer.averaged_gradients = copy.copy(grad_buffer)

add_pre_backward_hook(set_grad_buffer)
if use_z2:
grad_buffer = {}
for i, group in enumerate(optimizer.bit16_groups):
grad_buffer[i] = [p.clone().detach() for p in _build_partition_grad_views(optimizer, i)]

index_in_partition = 0
first_in_partition = True
for p in group:
param_id = optimizer.get_param_id(p)
p.param_id = param_id
in_partition = optimizer.is_param_in_current_partition[param_id]

if in_partition:
buf = grad_buffer[i][index_in_partition]
offset = optimizer.first_offset[i] if first_in_partition else 0
dc.register_param(p.param_id, p.shape, p, buf, int(offset))
index_in_partition += 1
first_in_partition = False
else:
dc.register_param(p.param_id, p.shape, p, _empty_grad_buffer(p), 0)

def set_z2_grad_buffer(_is_gradient_accumulation_boundary):
optimizer.averaged_gradients = copy.copy(grad_buffer)

add_pre_backward_hook(set_z2_grad_buffer)
else:
grad_buffer_metadata = {}

for i, group in enumerate(optimizer.bit16_groups):
grad_buffer_metadata[i] = []
first_in_partition = True
for p in group:
param_id = optimizer.get_param_id(p)
p.param_id = param_id
in_partition = optimizer.is_param_in_current_partition[param_id]

if in_partition:
offset = optimizer.first_offset[i] if first_in_partition else 0
grad_buffer_metadata[i].append((p.param_id, p, int(offset)))
dc.register_param(p.param_id, p.shape, p, _empty_grad_buffer(p), 0)
first_in_partition = False
else:
dc.register_param(p.param_id, p.shape, p, _empty_grad_buffer(p), 0)

optimizer._deepcompile_z1_grad_buffer_metadata = grad_buffer_metadata
optimizer._deepcompile_z1_current_grad_buffers = {}
optimizer._deepcompile_z1_current_flat_grad_buffers = {}

def set_z1_grad_buffer(is_gradient_accumulation_boundary):
if not is_gradient_accumulation_boundary:
release_grad_buffer()
optimizer.averaged_gradients = {}
return

current_grad_buffers = {}
current_flat_grad_buffers = {}
for group_idx in range(len(optimizer.bit16_groups)):
flat_grad_buffer, group_grad_buffers = _build_flat_partition_grad_views(optimizer, group_idx)
current_flat_grad_buffers[group_idx] = flat_grad_buffer
current_grad_buffers[group_idx] = group_grad_buffers
for (param_id, _, offset), grad_buffer in zip(grad_buffer_metadata[group_idx], group_grad_buffers):
dc.update_param_grad_buffer(param_id, grad_buffer, offset)
optimizer._deepcompile_z1_current_flat_grad_buffers = current_flat_grad_buffers
optimizer._deepcompile_z1_current_grad_buffers = current_grad_buffers
optimizer.averaged_gradients = current_grad_buffers

def release_grad_buffer(group_idx=None):
group_indices = grad_buffer_metadata.keys() if group_idx is None else [group_idx]
for idx in group_indices:
for param_id, param, _ in grad_buffer_metadata[idx]:
dc.update_param_grad_buffer(param_id, _empty_grad_buffer(param), 0)
if idx in optimizer._deepcompile_z1_current_grad_buffers:
optimizer._deepcompile_z1_current_grad_buffers[idx] = None
if idx in optimizer._deepcompile_z1_current_flat_grad_buffers:
optimizer._deepcompile_z1_current_flat_grad_buffers[idx] = None

optimizer._deepcompile_z1_release_grad_buffers = release_grad_buffer

add_pre_backward_hook(set_z1_grad_buffer)

if schedule is None:
schedule = []
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/compile/init_z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):

if use_opt:

def set_grad_buffer():
def set_grad_buffer(_is_gradient_accumulation_boundary):
for i, sub_group in enumerate(optimizer.fp16_groups):
optimizer.averaged_gradients[i] = [
optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
Expand Down
12 changes: 7 additions & 5 deletions deepspeed/compile/passes/zero1_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from torch.fx import GraphModule

from ..util import get_deepcompile_handle
from ..fx import add_postprocess, move_primals_to_head, _make_node_meta, add_end_backward, replace_reduce_outputs_with_none
from ..fx import (add_postprocess, move_primals_to_head, _make_node_meta, add_end_backward,
replace_reduce_outputs_with_none, should_release_reduce_buckets)

NAME = "zero1_compile"

Expand All @@ -27,7 +28,8 @@ def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_ma
return gm


def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule:
def add_z1_reduce_bw(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]],
param_manager) -> GraphModule:

graph = gm.graph
pm = param_manager[graph_id]
Expand All @@ -50,7 +52,7 @@ def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModu

gm.graph = move_primals_to_head(graph)

add_end_backward(gm.graph, graph_id)
add_end_backward(gm.graph, graph_id, should_release_reduce_buckets(graph_order, graph_id))
replace_reduce_outputs_with_none(gm.graph)

return gm
Expand All @@ -59,12 +61,12 @@ def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModu
def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
if bwd:
return add_z1_reduce_bw(gm, graph_id, param_manager)
return add_z1_reduce_bw(gm, graph_id, graph_order, param_manager)
return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=False)


def add_z2_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
if bwd:
return add_z1_reduce_bw(gm, graph_id, param_manager)
return add_z1_reduce_bw(gm, graph_id, graph_order, param_manager)
return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=True)
5 changes: 3 additions & 2 deletions deepspeed/compile/passes/zero3_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from torch.fx import Graph, Node, GraphModule

from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head, add_end_backward, replace_reduce_outputs_with_none
from ..fx import (add_postprocess, _make_node_meta, get_output_node, move_primals_to_head, add_end_backward,
replace_reduce_outputs_with_none, should_release_reduce_buckets)
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule

Expand Down Expand Up @@ -209,7 +210,7 @@ def add_z3_gather_release_bw(gm: GraphModule,
0, # unused
debug_log=debug_log)

add_end_backward(gm.graph, graph_id)
add_end_backward(gm.graph, graph_id, should_release_reduce_buckets(graph_order, graph_id))
replace_reduce_outputs_with_none(gm.graph)

return gm
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/compile/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def add_pre_backward_hook(hook):
def deepcompile_backward_prologue(is_gradient_accumulation_boundary):

for hook in pre_backward_hooks:
hook()
hook(is_gradient_accumulation_boundary)

dc = get_deepcompile_handle()
dc.start_backward(is_gradient_accumulation_boundary)
Expand Down
Loading
Loading