Skip to content

Commit 733ad19

Browse files
Chamberlain0w0 authored and kilinchange committed
fix: save device/ccl impl in ProcessGroup
1 parent bf2eae5 commit 733ad19

2 files changed

Lines changed: 54 additions & 65 deletions

File tree

infini_train/include/nn/parallel/process_group.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ namespace infini_train {
1515
class Tensor;
1616
namespace core {
1717
class CclComm;
18+
class CclImpl;
19+
class DeviceGuardImpl;
1820
class Stream;
1921
} // namespace core
2022
namespace nn {
@@ -89,6 +91,10 @@ class ProcessGroup {
8991
bool is_main_process_ = false;
9092
Device::DeviceType backend_ = Device::DeviceType::kInvalid;
9193

94+
// Save impl for convenience
95+
core::DeviceGuardImpl *runtime_impl_ = nullptr;
96+
core::CclImpl *ccl_impl_ = nullptr;
97+
9298
std::vector<std::unique_ptr<core::CclComm>> comms_;
9399
std::vector<std::unique_ptr<core::Stream>> comm_streams_;
94100
std::unordered_map<int, core::CclComm *> device_comm_map_; // device_index : comm

infini_train/src/nn/parallel/process_group.cc

Lines changed: 48 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ ProcessGroup::ProcessGroup(int world_size, const std::string &name) : world_size
3636

3737
ProcessGroup::ProcessGroup(Device::DeviceType backend, const std::string &process_group_name,
3838
const std::vector<int> &ranks)
39-
: backend_(backend), world_size_(ranks.size()), name_(process_group_name) {
39+
: backend_(backend), runtime_impl_(core::GetDeviceGuardImpl(backend)), ccl_impl_(core::GetCclImpl(backend)),
40+
world_size_(ranks.size()), name_(process_group_name) {
4041
CHECK_GT(world_size_, 0);
4142
if (global::GetNnodes() == 1 && global::GetNprocPerNode() == 1) {
4243
InitSingleProcess(ranks);
@@ -59,9 +60,8 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
5960
comms_.clear();
6061
comms_.reserve(world_size_);
6162

62-
auto *ccl_impl = core::GetCclImpl(backend_);
6363
std::vector<core::CclComm *> comm_ptrs(static_cast<size_t>(world_size_), nullptr);
64-
ccl_impl->CommInitAll(comm_ptrs.data(), world_size_, ranks.data());
64+
ccl_impl_->CommInitAll(comm_ptrs.data(), world_size_, ranks.data());
6565

6666
for (int i = 0; i < ranks.size(); ++i) {
6767
auto *comm_raw = comm_ptrs[static_cast<size_t>(i)];
@@ -77,15 +77,13 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
7777
}
7878

7979
void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
80-
auto *ccl_impl = core::GetCclImpl(backend_);
81-
8280
int n_threads = global::GetNthreadPerProc();
8381
int global_proc_rank = global::GetGlobalProcRank();
8482
int lower_rank = global_proc_rank * n_threads;
8583
int upper_rank = (global_proc_rank + 1) * n_threads;
8684

8785
core::CclUniqueId *unique_id_raw = nullptr;
88-
ccl_impl->GetUniqueId(&unique_id_raw);
86+
ccl_impl_->GetUniqueId(&unique_id_raw);
8987
std::unique_ptr<core::CclUniqueId> unique_id(unique_id_raw);
9088

9189
int min_rank = std::ranges::min(ranks);
@@ -106,7 +104,7 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
106104

107105
core::CclComm *comm_raw = nullptr;
108106
int group_rank = std::distance(ranks.begin(), it);
109-
ccl_impl->CommInitRank(&comm_raw, world_size_, *unique_id, group_rank);
107+
ccl_impl_->CommInitRank(&comm_raw, world_size_, *unique_id, group_rank);
110108
CHECK_NOTNULL(comm_raw);
111109
comms_.emplace_back(comm_raw);
112110

@@ -120,10 +118,9 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
120118
void ProcessGroup::InitStreams() {
121119
for (const auto &device : devices_) {
122120
core::DeviceGuard guard(device);
123-
auto *impl = core::GetDeviceGuardImpl(device.type());
124121
int low, high;
125-
impl->GetStreamPriorityRange(&low, &high);
126-
auto *stream = CreateOwnedStream(impl, device, high, comm_streams_);
122+
runtime_impl_->GetStreamPriorityRange(&low, &high);
123+
auto *stream = CreateOwnedStream(runtime_impl_, device, high, comm_streams_);
127124
device_stream_map_[device.index()] = stream;
128125
}
129126
}
@@ -132,18 +129,16 @@ std::shared_ptr<Work> ProcessGroup::AllReduce(const std::shared_ptr<Tensor> &ten
132129
bool async_op) const {
133130
auto device = tensor->GetDevice();
134131
core::DeviceGuard guard(device);
135-
auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
136-
auto *ccl_impl = core::GetCclImpl(device.type());
137-
auto *compute_stream = runtime_impl->GetStream(device);
132+
auto *compute_stream = runtime_impl_->GetStream(device);
138133
auto *comm_stream = device_stream_map_.at(device.index());
139134
auto comm = device_comm_map_.at(device.index());
140135

141136
auto work = std::make_shared<Work>(device, comm);
142-
runtime_impl->EventRecord(work->ready_event(), compute_stream);
143-
runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
144-
ccl_impl->AllReduce(tensor->DataPtr(), tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), reduce_op, comm,
145-
comm_stream);
146-
runtime_impl->EventRecord(work->done_event(), comm_stream);
137+
runtime_impl_->EventRecord(work->ready_event(), compute_stream);
138+
runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
139+
ccl_impl_->AllReduce(tensor->DataPtr(), tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), reduce_op, comm,
140+
comm_stream);
141+
runtime_impl_->EventRecord(work->done_event(), comm_stream);
147142

148143
if (async_op) {
149144
return work;
@@ -157,17 +152,15 @@ std::shared_ptr<Work> ProcessGroup::AllGather(const std::shared_ptr<Tensor> &out
157152
const std::shared_ptr<Tensor> &input, bool async_op) const {
158153
auto device = input->GetDevice();
159154
core::DeviceGuard guard(device);
160-
auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
161-
auto *ccl_impl = core::GetCclImpl(device.type());
162-
auto *compute_stream = runtime_impl->GetStream(device);
155+
auto *compute_stream = runtime_impl_->GetStream(device);
163156
auto *comm_stream = device_stream_map_.at(device.index());
164157
auto comm = device_comm_map_.at(device.index());
165158

166159
auto work = std::make_shared<Work>(device, comm);
167-
runtime_impl->EventRecord(work->ready_event(), compute_stream);
168-
runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
169-
ccl_impl->AllGather(input->DataPtr(), output->DataPtr(), input->NumElements(), input->Dtype(), comm, comm_stream);
170-
runtime_impl->EventRecord(work->done_event(), comm_stream);
160+
runtime_impl_->EventRecord(work->ready_event(), compute_stream);
161+
runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
162+
ccl_impl_->AllGather(input->DataPtr(), output->DataPtr(), input->NumElements(), input->Dtype(), comm, comm_stream);
163+
runtime_impl_->EventRecord(work->done_event(), comm_stream);
171164

172165
if (async_op) {
173166
return work;
@@ -182,18 +175,16 @@ std::shared_ptr<Work> ProcessGroup::ReduceScatter(const std::shared_ptr<Tensor>
182175
function::ReduceOpType reduce_op, bool async_op) const {
183176
auto device = input->GetDevice();
184177
core::DeviceGuard guard(device);
185-
auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
186-
auto *ccl_impl = core::GetCclImpl(device.type());
187-
auto *compute_stream = runtime_impl->GetStream(device);
178+
auto *compute_stream = runtime_impl_->GetStream(device);
188179
auto *comm_stream = device_stream_map_.at(device.index());
189180
auto comm = device_comm_map_.at(device.index());
190181

191182
auto work = std::make_shared<Work>(device, comm);
192-
runtime_impl->EventRecord(work->ready_event(), compute_stream);
193-
runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
194-
ccl_impl->ReduceScatter(input->DataPtr(), output->DataPtr(), output->NumElements(), input->Dtype(), reduce_op, comm,
195-
comm_stream);
196-
runtime_impl->EventRecord(work->done_event(), comm_stream);
183+
runtime_impl_->EventRecord(work->ready_event(), compute_stream);
184+
runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
185+
ccl_impl_->ReduceScatter(input->DataPtr(), output->DataPtr(), output->NumElements(), input->Dtype(), reduce_op,
186+
comm, comm_stream);
187+
runtime_impl_->EventRecord(work->done_event(), comm_stream);
197188

198189
if (async_op) {
199190
return work;
@@ -208,21 +199,19 @@ std::shared_ptr<Work> ProcessGroup::Send(std::vector<std::shared_ptr<Tensor>> te
208199
CHECK_GT(tensors.size(), 0);
209200
auto device = tensors[0]->GetDevice();
210201
core::DeviceGuard guard(device);
211-
auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
212-
auto *ccl_impl = core::GetCclImpl(device.type());
213-
auto *compute_stream = runtime_impl->GetStream(device);
202+
auto *compute_stream = runtime_impl_->GetStream(device);
214203
auto *comm_stream = device_stream_map_.at(device.index());
215204
auto comm = device_comm_map_.at(device.index());
216205

217206
auto work = std::make_shared<Work>(device, comm);
218-
runtime_impl->EventRecord(work->ready_event(), compute_stream);
219-
runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
207+
runtime_impl_->EventRecord(work->ready_event(), compute_stream);
208+
runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
220209
for (const auto &tensor : tensors) {
221210
CHECK_NOTNULL(tensor);
222211
CHECK_EQ(device, tensor->GetDevice());
223-
ccl_impl->Send(tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), dest_rank, comm, comm_stream);
212+
ccl_impl_->Send(tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), dest_rank, comm, comm_stream);
224213
}
225-
runtime_impl->EventRecord(work->done_event(), comm_stream);
214+
runtime_impl_->EventRecord(work->done_event(), comm_stream);
226215

227216
if (async_op) {
228217
return work;
@@ -237,21 +226,19 @@ std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> te
237226
CHECK_GT(tensors.size(), 0);
238227
auto device = tensors[0]->GetDevice();
239228
core::DeviceGuard guard(device);
240-
auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
241-
auto *ccl_impl = core::GetCclImpl(device.type());
242-
auto *compute_stream = runtime_impl->GetStream(device);
229+
auto *compute_stream = runtime_impl_->GetStream(device);
243230
auto *comm_stream = device_stream_map_.at(device.index());
244231
auto comm = device_comm_map_.at(device.index());
245232

246233
auto work = std::make_shared<Work>(device, comm);
247-
runtime_impl->EventRecord(work->ready_event(), compute_stream);
248-
runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
234+
runtime_impl_->EventRecord(work->ready_event(), compute_stream);
235+
runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
249236
for (const auto &tensor : tensors) {
250237
CHECK_NOTNULL(tensor);
251238
CHECK_EQ(device, tensor->GetDevice());
252-
ccl_impl->Recv(tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), src_rank, comm, comm_stream);
239+
ccl_impl_->Recv(tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), src_rank, comm, comm_stream);
253240
}
254-
runtime_impl->EventRecord(work->done_event(), comm_stream);
241+
runtime_impl_->EventRecord(work->done_event(), comm_stream);
255242

256243
if (async_op) {
257244
return work;
@@ -275,7 +262,7 @@ ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensor
275262
outputs.push_back(std::make_shared<Tensor>(input_tensor->Dims(), input_tensor->Dtype(), device));
276263
}
277264
devices.push_back(device);
278-
streams.push_back(core::GetDeviceGuardImpl(device.type())->GetStream(device));
265+
streams.push_back(runtime_impl_->GetStream(device));
279266
comms.push_back(device_comm_map_.at(device.index()));
280267
}
281268

@@ -288,15 +275,14 @@ ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensor
288275
}
289276
CHECK_NE(root, -1) << "Root not found in input devices";
290277

291-
auto *ccl_impl = core::GetCclImpl(devices[0].type());
292278
core::CclGroupGuard ccl_group_guard(devices[0].type());
293279
for (size_t i = 0; i < devices.size(); ++i) {
294280
core::DeviceGuard guard(devices[i]);
295281
for (size_t j = 0; j < input_tensors.size(); ++j) {
296282
const auto &input_tensor = input_tensors[j];
297283
const void *send_buffer = (static_cast<int>(i) == root ? input_tensor->DataPtr() : nullptr);
298-
ccl_impl->Broadcast(send_buffer, outputs[i * input_tensors.size() + j]->DataPtr(),
299-
input_tensor->NumElements(), input_tensor->Dtype(), root, comms[i], streams[i]);
284+
ccl_impl_->Broadcast(send_buffer, outputs[i * input_tensors.size() + j]->DataPtr(),
285+
input_tensor->NumElements(), input_tensor->Dtype(), root, comms[i], streams[i]);
300286
}
301287
}
302288

@@ -317,7 +303,7 @@ ProcessGroup::ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<T
317303
}
318304
for (size_t i = 0; i < grads.size(); ++i) {
319305
devices.push_back(grads[i][0]->GetDevice());
320-
streams.push_back(core::GetDeviceGuardImpl(devices[i].type())->GetStream(devices[i]));
306+
streams.push_back(runtime_impl_->GetStream(devices[i]));
321307
comms.push_back(device_comm_map_.at(devices[i].index()));
322308
}
323309

@@ -330,14 +316,13 @@ ProcessGroup::ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<T
330316
}
331317
CHECK_NE(root, -1) << "Destination device not found in grads group";
332318

333-
auto *ccl_impl = core::GetCclImpl(devices[0].type());
334319
core::CclGroupGuard ccl_group_guard(devices[0].type());
335320
for (size_t i = 0; i < grads.size(); ++i) {
336321
core::DeviceGuard guard(devices[i]);
337322
for (size_t j = 0; j < grads[i].size(); ++j) {
338323
const auto &grad = grads[i][j];
339-
ccl_impl->Reduce(grad->DataPtr(), outputs[j]->DataPtr(), grad->NumElements(), grad->Dtype(),
340-
function::ReduceOpType::kSum, root, comms[i], streams[i]);
324+
ccl_impl_->Reduce(grad->DataPtr(), outputs[j]->DataPtr(), grad->NumElements(), grad->Dtype(),
325+
function::ReduceOpType::kSum, root, comms[i], streams[i]);
341326
}
342327
}
343328

@@ -357,19 +342,18 @@ std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter(const std::shared_ptr
357342
src_rank = static_cast<int>(i);
358343
}
359344
outputs.push_back(std::make_shared<Tensor>(split_tensors[i]->Dims(), split_tensors[i]->Dtype(), devices[i]));
360-
streams.push_back(core::GetDeviceGuardImpl(devices[i].type())->GetStream(devices[i]));
345+
streams.push_back(runtime_impl_->GetStream(devices[i]));
361346
comms.push_back(device_comm_map_.at(devices[i].index()));
362347
}
363348
CHECK_NE(src_rank, -1) << "Source device not found in input devices";
364349

365-
auto *ccl_impl = core::GetCclImpl(devices[0].type());
366350
core::CclGroupGuard ccl_group_guard(devices[0].type());
367351
for (size_t i = 0; i < devices.size(); ++i) {
368352
core::DeviceGuard guard(devices[i]);
369-
ccl_impl->Send(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), tensor->Dtype(), i,
370-
comms[src_rank], streams[src_rank]);
371-
ccl_impl->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), tensor->Dtype(), src_rank, comms[i],
372-
streams[i]);
353+
ccl_impl_->Send(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), tensor->Dtype(), i,
354+
comms[src_rank], streams[src_rank]);
355+
ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), tensor->Dtype(), src_rank, comms[i],
356+
streams[i]);
373357
}
374358
return outputs;
375359
}
@@ -390,7 +374,7 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<T
390374
if (device == destination) {
391375
dest_rank = static_cast<int>(i);
392376
}
393-
streams.push_back(core::GetDeviceGuardImpl(device.type())->GetStream(device));
377+
streams.push_back(runtime_impl_->GetStream(device));
394378
comms.push_back(device_comm_map_.at(device.index()));
395379
devices.push_back(device);
396380
total_dim += tensors[i]->Dims()[dim];
@@ -401,7 +385,6 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<T
401385
auto output = std::make_shared<Tensor>(out_dims, dtype, destination);
402386
CHECK_NE(dest_rank, -1) << "Destination device not found in input tensors's devices";
403387

404-
auto *ccl_impl = core::GetCclImpl(devices[0].type());
405388
core::CclGroupGuard ccl_group_guard(devices[0].type());
406389
int64_t offset = 0;
407390
for (size_t i = 0; i < num_devices; ++i) {
@@ -410,8 +393,8 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<T
410393
size_t num_elements = tensor->NumElements();
411394
void *send_ptr = tensor->DataPtr();
412395
auto *recv_ptr = static_cast<int8_t *>(output->DataPtr()) + offset;
413-
ccl_impl->Send(send_ptr, num_elements, dtype, dest_rank, comms[i], streams[i]);
414-
ccl_impl->Recv(recv_ptr, num_elements, dtype, i, comms[dest_rank], streams[dest_rank]);
396+
ccl_impl_->Send(send_ptr, num_elements, dtype, dest_rank, comms[i], streams[i]);
397+
ccl_impl_->Recv(recv_ptr, num_elements, dtype, i, comms[dest_rank], streams[dest_rank]);
415398
offset += tensor->SizeInBytes();
416399
}
417400
return output;

0 commit comments

Comments (0)