@@ -425,20 +425,27 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<T
425425
426426ProcessGroupFactory *ProcessGroupFactory::Instance () {
427427 // NOTE(zbl): Instance() with no arguments only gets initialized instance with a certain backend
428- std::lock_guard<std::mutex> lock (g_process_group_factory_mutex);
429428 auto &instance = g_process_group_factory_instance;
430429 if (instance == nullptr ) {
431- LOG (FATAL) << " ProcessGroupFactory is not initialized with backend. "
432- << " Call ProcessGroupFactory::Instance(backend) first." ;
430+ std::lock_guard<std::mutex> lock (g_process_group_factory_mutex);
431+ if (instance == nullptr ) {
432+ LOG (FATAL) << " ProcessGroupFactory is not initialized with backend. "
433+ << " Call ProcessGroupFactory::Instance(backend) first." ;
434+ }
433435 }
434436 return instance.get ();
435437}
436438
437439ProcessGroupFactory *ProcessGroupFactory::Instance (Device::DeviceType backend) {
438- std::lock_guard<std::mutex> lock (g_process_group_factory_mutex);
439440 auto &instance = g_process_group_factory_instance;
440441 if (instance == nullptr ) {
441- instance.reset (new ProcessGroupFactory (backend));
442+ std::lock_guard<std::mutex> lock (g_process_group_factory_mutex);
443+ if (instance == nullptr ) {
444+ instance.reset (new ProcessGroupFactory (backend));
445+ } else if (instance->backend_ != backend) {
446+ LOG (FATAL) << " ProcessGroupFactory backend mismatch. initialized=" << static_cast <int >(instance->backend_ )
447+ << " , requested=" << static_cast <int >(backend);
448+ }
442449 } else if (instance->backend_ != backend) {
443450 LOG (FATAL) << " ProcessGroupFactory backend mismatch. initialized=" << static_cast <int >(instance->backend_ )
444451 << " , requested=" << static_cast <int >(backend);
0 commit comments