@@ -36,7 +36,8 @@ ProcessGroup::ProcessGroup(int world_size, const std::string &name) : world_size
 
 ProcessGroup::ProcessGroup(Device::DeviceType backend, const std::string &process_group_name,
                            const std::vector<int> &ranks)
-    : backend_(backend), world_size_(ranks.size()), name_(process_group_name) {
+    : backend_(backend), runtime_impl_(core::GetDeviceGuardImpl(backend)), ccl_impl_(core::GetCclImpl(backend)),
+      world_size_(ranks.size()), name_(process_group_name) {
   CHECK_GT(world_size_, 0);
   if (global::GetNnodes() == 1 && global::GetNprocPerNode() == 1) {
     InitSingleProcess(ranks);
@@ -59,9 +60,8 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
   comms_.clear();
   comms_.reserve(world_size_);
 
-  auto *ccl_impl = core::GetCclImpl(backend_);
   std::vector<core::CclComm *> comm_ptrs(static_cast<size_t>(world_size_), nullptr);
-  ccl_impl->CommInitAll(comm_ptrs.data(), world_size_, ranks.data());
+  ccl_impl_->CommInitAll(comm_ptrs.data(), world_size_, ranks.data());
 
   for (int i = 0; i < ranks.size(); ++i) {
     auto *comm_raw = comm_ptrs[static_cast<size_t>(i)];
@@ -77,15 +77,13 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
 }
 
 void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
-  auto *ccl_impl = core::GetCclImpl(backend_);
-
   int n_threads = global::GetNthreadPerProc();
   int global_proc_rank = global::GetGlobalProcRank();
   int lower_rank = global_proc_rank * n_threads;
   int upper_rank = (global_proc_rank + 1) * n_threads;
 
   core::CclUniqueId *unique_id_raw = nullptr;
-  ccl_impl->GetUniqueId(&unique_id_raw);
+  ccl_impl_->GetUniqueId(&unique_id_raw);
   std::unique_ptr<core::CclUniqueId> unique_id(unique_id_raw);
 
   int min_rank = std::ranges::min(ranks);
@@ -106,7 +104,7 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
 
     core::CclComm *comm_raw = nullptr;
     int group_rank = std::distance(ranks.begin(), it);
-    ccl_impl->CommInitRank(&comm_raw, world_size_, *unique_id, group_rank);
+    ccl_impl_->CommInitRank(&comm_raw, world_size_, *unique_id, group_rank);
     CHECK_NOTNULL(comm_raw);
     comms_.emplace_back(comm_raw);
 
@@ -120,10 +118,9 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
 void ProcessGroup::InitStreams() {
   for (const auto &device : devices_) {
     core::DeviceGuard guard(device);
-    auto *impl = core::GetDeviceGuardImpl(device.type());
     int low, high;
-    impl->GetStreamPriorityRange(&low, &high);
-    auto *stream = CreateOwnedStream(impl, device, high, comm_streams_);
+    runtime_impl_->GetStreamPriorityRange(&low, &high);
+    auto *stream = CreateOwnedStream(runtime_impl_, device, high, comm_streams_);
     device_stream_map_[device.index()] = stream;
   }
 }
@@ -132,18 +129,16 @@ std::shared_ptr<Work> ProcessGroup::AllReduce(const std::shared_ptr<Tensor> &ten
                                               bool async_op) const {
   auto device = tensor->GetDevice();
   core::DeviceGuard guard(device);
-  auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
-  auto *ccl_impl = core::GetCclImpl(device.type());
-  auto *compute_stream = runtime_impl->GetStream(device);
+  auto *compute_stream = runtime_impl_->GetStream(device);
   auto *comm_stream = device_stream_map_.at(device.index());
   auto comm = device_comm_map_.at(device.index());
 
   auto work = std::make_shared<Work>(device, comm);
-  runtime_impl->EventRecord(work->ready_event(), compute_stream);
-  runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
-  ccl_impl->AllReduce(tensor->DataPtr(), tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), reduce_op, comm,
-                      comm_stream);
-  runtime_impl->EventRecord(work->done_event(), comm_stream);
+  runtime_impl_->EventRecord(work->ready_event(), compute_stream);
+  runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
+  ccl_impl_->AllReduce(tensor->DataPtr(), tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), reduce_op, comm,
+                       comm_stream);
+  runtime_impl_->EventRecord(work->done_event(), comm_stream);
 
   if (async_op) {
     return work;
@@ -157,17 +152,15 @@ std::shared_ptr<Work> ProcessGroup::AllGather(const std::shared_ptr<Tensor> &out
                                               const std::shared_ptr<Tensor> &input, bool async_op) const {
   auto device = input->GetDevice();
   core::DeviceGuard guard(device);
-  auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
-  auto *ccl_impl = core::GetCclImpl(device.type());
-  auto *compute_stream = runtime_impl->GetStream(device);
+  auto *compute_stream = runtime_impl_->GetStream(device);
   auto *comm_stream = device_stream_map_.at(device.index());
   auto comm = device_comm_map_.at(device.index());
 
   auto work = std::make_shared<Work>(device, comm);
-  runtime_impl->EventRecord(work->ready_event(), compute_stream);
-  runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
-  ccl_impl->AllGather(input->DataPtr(), output->DataPtr(), input->NumElements(), input->Dtype(), comm, comm_stream);
-  runtime_impl->EventRecord(work->done_event(), comm_stream);
+  runtime_impl_->EventRecord(work->ready_event(), compute_stream);
+  runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
+  ccl_impl_->AllGather(input->DataPtr(), output->DataPtr(), input->NumElements(), input->Dtype(), comm, comm_stream);
+  runtime_impl_->EventRecord(work->done_event(), comm_stream);
 
   if (async_op) {
     return work;
@@ -182,18 +175,16 @@ std::shared_ptr<Work> ProcessGroup::ReduceScatter(const std::shared_ptr<Tensor>
                                                   function::ReduceOpType reduce_op, bool async_op) const {
   auto device = input->GetDevice();
   core::DeviceGuard guard(device);
-  auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
-  auto *ccl_impl = core::GetCclImpl(device.type());
-  auto *compute_stream = runtime_impl->GetStream(device);
+  auto *compute_stream = runtime_impl_->GetStream(device);
   auto *comm_stream = device_stream_map_.at(device.index());
   auto comm = device_comm_map_.at(device.index());
 
   auto work = std::make_shared<Work>(device, comm);
-  runtime_impl->EventRecord(work->ready_event(), compute_stream);
-  runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
-  ccl_impl->ReduceScatter(input->DataPtr(), output->DataPtr(), output->NumElements(), input->Dtype(), reduce_op, comm,
-                          comm_stream);
-  runtime_impl->EventRecord(work->done_event(), comm_stream);
+  runtime_impl_->EventRecord(work->ready_event(), compute_stream);
+  runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
+  ccl_impl_->ReduceScatter(input->DataPtr(), output->DataPtr(), output->NumElements(), input->Dtype(), reduce_op,
+                           comm, comm_stream);
+  runtime_impl_->EventRecord(work->done_event(), comm_stream);
 
   if (async_op) {
     return work;
@@ -208,21 +199,19 @@ std::shared_ptr<Work> ProcessGroup::Send(std::vector<std::shared_ptr<Tensor>> te
   CHECK_GT(tensors.size(), 0);
   auto device = tensors[0]->GetDevice();
   core::DeviceGuard guard(device);
-  auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
-  auto *ccl_impl = core::GetCclImpl(device.type());
-  auto *compute_stream = runtime_impl->GetStream(device);
+  auto *compute_stream = runtime_impl_->GetStream(device);
   auto *comm_stream = device_stream_map_.at(device.index());
   auto comm = device_comm_map_.at(device.index());
 
   auto work = std::make_shared<Work>(device, comm);
-  runtime_impl->EventRecord(work->ready_event(), compute_stream);
-  runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
+  runtime_impl_->EventRecord(work->ready_event(), compute_stream);
+  runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
   for (const auto &tensor : tensors) {
     CHECK_NOTNULL(tensor);
     CHECK_EQ(device, tensor->GetDevice());
-    ccl_impl->Send(tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), dest_rank, comm, comm_stream);
+    ccl_impl_->Send(tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), dest_rank, comm, comm_stream);
   }
-  runtime_impl->EventRecord(work->done_event(), comm_stream);
+  runtime_impl_->EventRecord(work->done_event(), comm_stream);
 
   if (async_op) {
     return work;
@@ -237,21 +226,19 @@ std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> te
   CHECK_GT(tensors.size(), 0);
   auto device = tensors[0]->GetDevice();
   core::DeviceGuard guard(device);
-  auto *runtime_impl = core::GetDeviceGuardImpl(device.type());
-  auto *ccl_impl = core::GetCclImpl(device.type());
-  auto *compute_stream = runtime_impl->GetStream(device);
+  auto *compute_stream = runtime_impl_->GetStream(device);
   auto *comm_stream = device_stream_map_.at(device.index());
   auto comm = device_comm_map_.at(device.index());
 
   auto work = std::make_shared<Work>(device, comm);
-  runtime_impl->EventRecord(work->ready_event(), compute_stream);
-  runtime_impl->StreamWaitEvent(comm_stream, work->ready_event(), 0);
+  runtime_impl_->EventRecord(work->ready_event(), compute_stream);
+  runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
   for (const auto &tensor : tensors) {
     CHECK_NOTNULL(tensor);
     CHECK_EQ(device, tensor->GetDevice());
-    ccl_impl->Recv(tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), src_rank, comm, comm_stream);
+    ccl_impl_->Recv(tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), src_rank, comm, comm_stream);
   }
-  runtime_impl->EventRecord(work->done_event(), comm_stream);
+  runtime_impl_->EventRecord(work->done_event(), comm_stream);
 
   if (async_op) {
     return work;
@@ -275,7 +262,7 @@ ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensor
       outputs.push_back(std::make_shared<Tensor>(input_tensor->Dims(), input_tensor->Dtype(), device));
     }
     devices.push_back(device);
-    streams.push_back(core::GetDeviceGuardImpl(device.type())->GetStream(device));
+    streams.push_back(runtime_impl_->GetStream(device));
     comms.push_back(device_comm_map_.at(device.index()));
   }
 
@@ -288,15 +275,14 @@ ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensor
   }
   CHECK_NE(root, -1) << "Root not found in input devices";
 
-  auto *ccl_impl = core::GetCclImpl(devices[0].type());
   core::CclGroupGuard ccl_group_guard(devices[0].type());
   for (size_t i = 0; i < devices.size(); ++i) {
     core::DeviceGuard guard(devices[i]);
     for (size_t j = 0; j < input_tensors.size(); ++j) {
       const auto &input_tensor = input_tensors[j];
       const void *send_buffer = (static_cast<int>(i) == root ? input_tensor->DataPtr() : nullptr);
-      ccl_impl->Broadcast(send_buffer, outputs[i * input_tensors.size() + j]->DataPtr(),
-                          input_tensor->NumElements(), input_tensor->Dtype(), root, comms[i], streams[i]);
+      ccl_impl_->Broadcast(send_buffer, outputs[i * input_tensors.size() + j]->DataPtr(),
+                           input_tensor->NumElements(), input_tensor->Dtype(), root, comms[i], streams[i]);
     }
   }
 
@@ -317,7 +303,7 @@ ProcessGroup::ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<T
   }
   for (size_t i = 0; i < grads.size(); ++i) {
     devices.push_back(grads[i][0]->GetDevice());
-    streams.push_back(core::GetDeviceGuardImpl(devices[i].type())->GetStream(devices[i]));
+    streams.push_back(runtime_impl_->GetStream(devices[i]));
     comms.push_back(device_comm_map_.at(devices[i].index()));
   }
 
@@ -330,14 +316,13 @@ ProcessGroup::ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<T
   }
   CHECK_NE(root, -1) << "Destination device not found in grads group";
 
-  auto *ccl_impl = core::GetCclImpl(devices[0].type());
   core::CclGroupGuard ccl_group_guard(devices[0].type());
   for (size_t i = 0; i < grads.size(); ++i) {
     core::DeviceGuard guard(devices[i]);
     for (size_t j = 0; j < grads[i].size(); ++j) {
       const auto &grad = grads[i][j];
-      ccl_impl->Reduce(grad->DataPtr(), outputs[j]->DataPtr(), grad->NumElements(), grad->Dtype(),
-                       function::ReduceOpType::kSum, root, comms[i], streams[i]);
+      ccl_impl_->Reduce(grad->DataPtr(), outputs[j]->DataPtr(), grad->NumElements(), grad->Dtype(),
+                        function::ReduceOpType::kSum, root, comms[i], streams[i]);
     }
   }
 
@@ -357,19 +342,18 @@ std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter(const std::shared_ptr
       src_rank = static_cast<int>(i);
     }
     outputs.push_back(std::make_shared<Tensor>(split_tensors[i]->Dims(), split_tensors[i]->Dtype(), devices[i]));
-    streams.push_back(core::GetDeviceGuardImpl(devices[i].type())->GetStream(devices[i]));
+    streams.push_back(runtime_impl_->GetStream(devices[i]));
     comms.push_back(device_comm_map_.at(devices[i].index()));
   }
   CHECK_NE(src_rank, -1) << "Source device not found in input devices";
 
-  auto *ccl_impl = core::GetCclImpl(devices[0].type());
   core::CclGroupGuard ccl_group_guard(devices[0].type());
   for (size_t i = 0; i < devices.size(); ++i) {
     core::DeviceGuard guard(devices[i]);
-    ccl_impl->Send(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), tensor->Dtype(), i,
-                   comms[src_rank], streams[src_rank]);
-    ccl_impl->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), tensor->Dtype(), src_rank, comms[i],
-                   streams[i]);
+    ccl_impl_->Send(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), tensor->Dtype(), i,
+                    comms[src_rank], streams[src_rank]);
+    ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), tensor->Dtype(), src_rank, comms[i],
+                    streams[i]);
   }
   return outputs;
 }
@@ -390,7 +374,7 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<T
     if (device == destination) {
       dest_rank = static_cast<int>(i);
     }
-    streams.push_back(core::GetDeviceGuardImpl(device.type())->GetStream(device));
+    streams.push_back(runtime_impl_->GetStream(device));
     comms.push_back(device_comm_map_.at(device.index()));
     devices.push_back(device);
     total_dim += tensors[i]->Dims()[dim];
@@ -401,7 +385,6 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<T
   auto output = std::make_shared<Tensor>(out_dims, dtype, destination);
   CHECK_NE(dest_rank, -1) << "Destination device not found in input tensors's devices";
 
-  auto *ccl_impl = core::GetCclImpl(devices[0].type());
   core::CclGroupGuard ccl_group_guard(devices[0].type());
   int64_t offset = 0;
   for (size_t i = 0; i < num_devices; ++i) {
@@ -410,8 +393,8 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<T
     size_t num_elements = tensor->NumElements();
     void *send_ptr = tensor->DataPtr();
     auto *recv_ptr = static_cast<int8_t *>(output->DataPtr()) + offset;
-    ccl_impl->Send(send_ptr, num_elements, dtype, dest_rank, comms[i], streams[i]);
-    ccl_impl->Recv(recv_ptr, num_elements, dtype, i, comms[dest_rank], streams[dest_rank]);
+    ccl_impl_->Send(send_ptr, num_elements, dtype, dest_rank, comms[i], streams[i]);
+    ccl_impl_->Recv(recv_ptr, num_elements, dtype, i, comms[dest_rank], streams[dest_rank]);
     offset += tensor->SizeInBytes();
   }
   return output;
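
The hunks above assume ProcessGroup caches its backend implementations as members instead of re-querying them per call. A minimal sketch of the corresponding declarations, where the member names come from the diff but the impl pointer types and header layout are illustrative assumptions (the project's actual header is not shown here):

// Hypothetical excerpt of process_group.h, inferred from this commit.
// runtime_impl_ and ccl_impl_ appear in the diff; the pointer types are guesses.
class ProcessGroup {
 public:
  ProcessGroup(Device::DeviceType backend, const std::string &process_group_name,
               const std::vector<int> &ranks);

 private:
  // Declared in the same order as the constructor's initializer list
  // (backend_, runtime_impl_, ccl_impl_, world_size_, name_) to avoid -Wreorder.
  Device::DeviceType backend_;
  core::DeviceGuardImpl *runtime_impl_ = nullptr;  // cached core::GetDeviceGuardImpl(backend)
  core::CclImpl *ccl_impl_ = nullptr;              // cached core::GetCclImpl(backend)
  int world_size_ = 0;
  std::string name_;
};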