-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathdistributed_data_parallel.cc
More file actions
205 lines (178 loc) · 8.01 KB
/
distributed_data_parallel.cc
File metadata and controls
205 lines (178 loc) · 8.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "infini_train/include/autograd/function_hook.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/tensor.h"
namespace infini_train::nn::parallel {
namespace {
// Key under which the wrapped user module is stored in, and looked up from, `modules_`.
constexpr char kModuleName[]{"module"};
}  // namespace
/**
 * Wraps `module` for data-parallel training over the data-parallel process group of `rank`.
 *
 * Validates `ddp_config.zero_stage` (1 or 2; 3 aborts as not implemented). Checks that every parameter and
 * buffer of `module` lives on device index `rank.thread_rank()`. Then sets up exactly one gradient-sync path:
 *   - use_distributed_optimizer:  param/grad buffers + bucket groups (BuildParamAndGradBuffers) and
 *                                 backward hooks (RegisterBackwardHooks);
 *   - gradient_bucketing_enabled: a Reducer over size-capped buckets;
 *   - neither:                    one averaging all-reduce post-accumulate hook per parameter.
 *
 * @param module     Module to wrap; stored in modules_ under kModuleName.
 * @param rank       Selects the data-parallel process group (by global rank) and the expected device index.
 * @param ddp_config DDP options; copied into ddp_config_, where zero_stage is reset to 1 when
 *                   use_distributed_optimizer is false.
 */
DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> module, const Rank &rank,
                                                 const DistributedDataParallelConfig ddp_config)
    : ddp_config_(ddp_config),
      ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) {
    CHECK(ddp_config_.zero_stage >= 1 && ddp_config_.zero_stage <= 3)
        << "DistributedDataParallel: zero_stage must be in 1/2/3.";
    if (ddp_config_.zero_stage >= 3) {
        LOG(FATAL) << "DistributedDataParallel: ZeRO-3 is not implemented yet.";
    }
    // Without the distributed optimizer only stage 1 is meaningful. Warn only when a higher stage was
    // actually requested: after the CHECK above `zero_stage >= 1` always holds, so testing `>= 1` here
    // would emit this warning on every run.
    if (!ddp_config_.use_distributed_optimizer && ddp_config_.zero_stage > 1) {
        LOG(WARNING) << "DistributedDataParallel: zero_stage is ignored because "
                        "use_distributed_optimizer is false.";
        ddp_config_.zero_stage = 1;
    }

    // Plain per-parameter all-reduce is used only when neither bucketing strategy is enabled.
    const bool per_param_allreduce =
        !ddp_config_.gradient_bucketing_enabled && !ddp_config_.use_distributed_optimizer;
    for (const auto &param : module->Parameters()) {
        CHECK_EQ(param->GetDevice().index(), rank.thread_rank())
            << "All parameters must be on the same device as the module";
        if (per_param_allreduce) {
            auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(
                function::ReduceOpType::kAvg, ddp_pg_);
            param->RegisterPostAccumulateGradHook(std::move(hook));
        }
    }
    for (const auto &buffer : module->Buffers()) {
        CHECK_EQ(buffer->GetDevice().index(), rank.thread_rank())
            << "All buffers must be on the same device as the module";
    }

    modules_[kModuleName] = std::move(module);

    if (ddp_config_.use_distributed_optimizer) {
        BuildParamAndGradBuffers();
        RegisterBackwardHooks();
    } else if (ddp_config_.gradient_bucketing_enabled) {
        // Bucket assignment: the first bucket uses its own (typically smaller) cap, the rest the normal cap.
        auto params = modules_.at(kModuleName)->Parameters();
        const size_t first_cap_bytes = ddp_config_.first_bucket_cap_mb * kBytesPerMB;
        const size_t normal_cap_bytes = ddp_config_.normal_bucket_cap_mb * kBytesPerMB;
        std::vector<size_t> bucket_size_limits = {first_cap_bytes, normal_cap_bytes};
        auto bucket_indices = ComputeBucketAssignmentBySize(params, bucket_size_limits);
        // Pass the normalized member config so the Reducer sees the same zero_stage as the rest of DDP.
        reducer_ = std::make_shared<Reducer>(params, bucket_indices, ddp_config_);
        reducer_->AttachHooksToParameters();
    }
}
/**
 * Builds one ParamAndGradBuffer per distinct (param dtype, grad dtype) pair of trainable parameters. It then
 * partitions those buffers into bucket groups and records which bucket group owns each parameter.
 *
 * Side effects: rebuilds param_grad_buffers_, bucket_groups_ and param_to_bucket_group_. Aborts if a
 * parameter ends up in more than one bucket group.
 */
void DistributedDataParallel::BuildParamAndGradBuffers() {
    // (param_dtype, grad_dtype)
    using DTypePair = std::pair<DataType, DataType>;
    std::map<DTypePair, std::vector<std::shared_ptr<Tensor>>> dtype_to_params;
    for (const auto &param : modules_.at(kModuleName)->Parameters()) {
        if (!param->requires_grad()) {
            continue;
        }
        const auto param_dtype = param->Dtype();
        // If the param has no grad tensor yet, assume its grad will use the param's dtype.
        const auto grad_dtype = param->grad() ? param->grad()->Dtype() : param_dtype;
        dtype_to_params[{param_dtype, grad_dtype}].push_back(param);
    }

    param_grad_buffers_.clear();
    param_grad_buffers_.reserve(dtype_to_params.size());
    for (auto &[dtypes, param_list] : dtype_to_params) {
        if (param_list.empty()) {
            continue;
        }
        const auto &[param_dtype, grad_dtype] = dtypes;
        // At this point, zero_stage is already aligned with use_distributed_optimizer.
        param_grad_buffers_.push_back(
            std::make_shared<ParamAndGradBuffer>(param_list, param_dtype, grad_dtype, ddp_pg_, ddp_config_));
    }

    // TODO(zbl): option for disable bucketing
    bucket_groups_ = PartitionBuckets(param_grad_buffers_, /*force_single_bucket_group=*/false);

    if (ddp_config_.use_distributed_optimizer && ddp_config_.overlap_param_gather) {
        // Link every bucket group to the one before it for param gather. Counting down with `i > 1` keeps
        // the unsigned index from wrapping to SIZE_MAX when bucket_groups_ is empty (no trainable params).
        for (size_t i = bucket_groups_.size(); i > 1; --i) {
            bucket_groups_[i - 1]->SetNextParamGatherBucketGroup(bucket_groups_[i - 2]);
        }
    }

    param_to_bucket_group_.clear();
    for (const auto &group : bucket_groups_) {
        for (const auto &bucket : group->buckets()) {
            for (const auto &param : bucket->params()) {
                const bool inserted = param_to_bucket_group_.emplace(param.get(), group).second;
                if (!inserted) {
                    LOG(FATAL) << "Parameter appears in more than one bucket group.";
                }
            }
        }
    }

    LOG(INFO) << "DDP BuildParamAndGradBuffers: "
              << "dtype_groups=" << dtype_to_params.size() << ", param_grad_buffers=" << param_grad_buffers_.size()
              << ", bucket_groups=" << bucket_groups_.size();
}
void DistributedDataParallel::RegisterBackwardHooks() {
if (ddp_config_.zero_stage >= 2) {
auto &module = modules_.at(kModuleName);
for (auto ¶m : module->Parameters()) {
if (!param->requires_grad()) {
continue;
}
auto it = param_to_bucket_group_.find(param.get());
if (it == param_to_bucket_group_.end()) {
continue;
}
std::weak_ptr<ParamAndGradBucketGroup> weak_group = it->second;
param->SetGradAccumulateBypass(
[weak_group, param](const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate) {
if (auto group = weak_group.lock()) {
group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate);
if (group->config().overlap_grad_reduce) {
group->RegisterGradReady(param);
}
return true;
}
return false;
});
}
return;
}
class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook {
public:
DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr<Tensor> param)
: ddp_(ddp), param_(param) {}
void operator()(const std::shared_ptr<Tensor> &) override {
if (auto param = param_.lock()) {
ddp_->OnGradReady(param);
}
}
private:
DistributedDataParallel *ddp_;
std::weak_ptr<Tensor> param_;
};
auto &module = modules_.at(kModuleName);
for (auto ¶m : module->Parameters()) {
if (!param->requires_grad()) {
continue;
}
auto hook = std::make_unique<DDPPostAccumulateHook>(this, param);
param->RegisterPostAccumulateGradHook(std::move(hook));
}
}
void DistributedDataParallel::OnGradReady(const std::shared_ptr<Tensor> ¶m) {
auto it = param_to_bucket_group_.find(param.get());
if (it != param_to_bucket_group_.end()) {
CHECK(param->requires_grad());
if (ddp_config_.overlap_grad_reduce && (ddp_config_.zero_stage < 2)) {
CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True";
}
if (ddp_config_.overlap_grad_reduce) {
it->second->RegisterGradReady(param);
}
}
}
/**
 * Runs the wrapped module on `input_tensors` and prepares gradient synchronization for the next backward.
 *
 * After the forward pass it calls the Reducer's PrepareForBackward() (bucketing path). On the
 * distributed-optimizer path, it calls RebindGradViews() on every param/grad buffer.
 *
 * @param input_tensors Inputs forwarded unchanged to the wrapped module.
 * @return The wrapped module's outputs.
 */
std::vector<std::shared_ptr<Tensor>>
DistributedDataParallel::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
    // .at(): the module is always registered by the constructor; never default-insert a null entry.
    auto outputs = (*modules_.at(kModuleName))(input_tensors);
    if (reducer_) {
        reducer_->PrepareForBackward();
    }
    if (ddp_config_.use_distributed_optimizer) {
        // Bind by reference: copying each shared_ptr would cost an atomic inc/dec per buffer per step.
        for (const auto &buffer : param_grad_buffers_) {
            buffer->RebindGradViews();
        }
    }
    return outputs;
}
} // namespace infini_train::nn::parallel