Skip to content

Commit a1bea05

Browse files
Chamberlain0w0 authored and kilinchange committed
fix: add EventFlag enum, fix mutex usage in ProcessGroupFactory::Instance()
1 parent f33fc2b commit a1bea05

11 files changed

Lines changed: 47 additions & 15 deletions

File tree

infini_train/include/core/runtime/device_guard.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ class DeviceGuardImpl {
9494

9595
virtual void EventCreate(Event **event) const;
9696

97-
virtual void EventCreateWithFlags(Event **event, uint32_t flags) const;
97+
virtual void EventCreateWithFlags(Event **event, EventFlag flags) const;
9898

9999
virtual void EventDestroy(Event *event) const;
100100

infini_train/include/core/runtime/runtime_common.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,13 @@ class Stream {
2222
virtual ~Stream() = default;
2323
};
2424

25+
enum class EventFlag : uint32_t {
26+
kDefault = 0x0,
27+
kBlockingSync = 0x1,
28+
kDisableTiming = 0x2,
29+
kInterprocess = 0x4,
30+
};
31+
2532
// Generic runtime status for backend-agnostic control flow.
2633
#define INFINI_TRAIN_RUNTIME_STATUS_LIST(X) \
2734
X(kSuccess, 0) \

infini_train/src/core/runtime/cpu/cpu_guard_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ void CpuGuardImpl::EventCreate(Event **event) const {
6565
*event = nullptr;
6666
}
6767

68-
void CpuGuardImpl::EventCreateWithFlags(Event **event, uint32_t flags) const {
68+
void CpuGuardImpl::EventCreateWithFlags(Event **event, EventFlag flags) const {
6969
CHECK_NOTNULL(event);
7070
LOG(WARNING) << "CpuGuardImpl::EventCreateWithFlags is not supported. Returning nullptr event.";
7171
*event = nullptr;

infini_train/src/core/runtime/cpu/cpu_guard_impl.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ class CpuGuardImpl final : public DeviceGuardImpl {
3333
// Event management (explicitly unsupported for now)
3434
void EventCreate(Event **event) const override;
3535

36-
void EventCreateWithFlags(Event **event, uint32_t flags) const override;
36+
void EventCreateWithFlags(Event **event, EventFlag flags) const override;
3737

3838
void EventDestroy(Event *event) const override;
3939

infini_train/src/core/runtime/cuda/cuda_guard_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,7 @@ void CudaGuardImpl::GetStreamPriorityRange(int *low, int *high) const {
139139
// event
140140
void CudaGuardImpl::EventCreate(Event **event) const { *event = new CudaEvent(); }
141141

142-
void CudaGuardImpl::EventCreateWithFlags(Event **event, uint32_t flags) const { *event = new CudaEvent(flags); }
142+
void CudaGuardImpl::EventCreateWithFlags(Event **event, EventFlag flags) const { *event = new CudaEvent(flags); }
143143

144144
void CudaGuardImpl::EventDestroy(Event *event) const {
145145
if (event == nullptr) {

infini_train/src/core/runtime/cuda/cuda_guard_impl.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ class CudaGuardImpl final : public DeviceGuardImpl {
4141
// event
4242
void EventCreate(Event **event) const override;
4343

44-
void EventCreateWithFlags(Event **event, uint32_t flags) const override;
44+
void EventCreateWithFlags(Event **event, EventFlag flags) const override;
4545

4646
void EventDestroy(Event *event) const override;
4747

infini_train/src/core/runtime/cuda/cuda_runtime_common.cc

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,26 @@
33
#include "infini_train/include/common/cuda/common_cuda.h"
44

55
namespace infini_train::core::cuda {
6+
namespace {
7+
uint32_t ToCudaEventFlags(EventFlag flags) {
8+
switch (flags) {
9+
case EventFlag::kDefault:
10+
return cudaEventDefault;
11+
case EventFlag::kBlockingSync:
12+
return cudaEventBlockingSync;
13+
case EventFlag::kDisableTiming:
14+
return cudaEventDisableTiming;
15+
case EventFlag::kInterprocess:
16+
// CUDA requires cudaEventDisableTiming with interprocess events.
17+
return cudaEventInterprocess | cudaEventDisableTiming;
18+
default:
19+
LOG(FATAL) << "Unsupported EventFlag value: " << static_cast<uint32_t>(flags);
20+
}
21+
return cudaEventDefault;
22+
}
23+
} // namespace
624

7-
CudaEvent::CudaEvent(uint32_t flags) { CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags)); }
25+
CudaEvent::CudaEvent(EventFlag flags) { CUDA_CHECK(cudaEventCreateWithFlags(&event_, ToCudaEventFlags(flags))); }
826

927
CudaEvent::~CudaEvent() {
1028
if (event_ != nullptr) {

infini_train/src/core/runtime/cuda/cuda_runtime_common.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ namespace infini_train::core::cuda {
1515

1616
class CudaEvent final : public Event {
1717
public:
18-
explicit CudaEvent(uint32_t flags = cudaEventDefault);
18+
explicit CudaEvent(EventFlag flags = EventFlag::kDefault);
1919
~CudaEvent() override;
2020

2121
cudaEvent_t cuda_event() const;

infini_train/src/core/runtime/device_guard.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ void DeviceGuardImpl::GetStreamPriorityRange(int *, int *) const {
4343

4444
void DeviceGuardImpl::EventCreate(Event **) const { LOG(FATAL) << "DeviceGuardImpl::EventCreate is not implemented."; }
4545

46-
void DeviceGuardImpl::EventCreateWithFlags(Event **, uint32_t) const {
46+
void DeviceGuardImpl::EventCreateWithFlags(Event **, EventFlag) const {
4747
LOG(FATAL) << "DeviceGuardImpl::EventCreateWithFlags is not implemented.";
4848
}
4949

infini_train/src/nn/parallel/process_group.cc

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -425,20 +425,27 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<T
425425

426426
ProcessGroupFactory *ProcessGroupFactory::Instance() {
427427
// NOTE(zbl): Instance() with no arguments only gets initialized instance with a certain backend
428-
std::lock_guard<std::mutex> lock(g_process_group_factory_mutex);
429428
auto &instance = g_process_group_factory_instance;
430429
if (instance == nullptr) {
431-
LOG(FATAL) << "ProcessGroupFactory is not initialized with backend. "
432-
<< "Call ProcessGroupFactory::Instance(backend) first.";
430+
std::lock_guard<std::mutex> lock(g_process_group_factory_mutex);
431+
if (instance == nullptr) {
432+
LOG(FATAL) << "ProcessGroupFactory is not initialized with backend. "
433+
<< "Call ProcessGroupFactory::Instance(backend) first.";
434+
}
433435
}
434436
return instance.get();
435437
}
436438

437439
ProcessGroupFactory *ProcessGroupFactory::Instance(Device::DeviceType backend) {
438-
std::lock_guard<std::mutex> lock(g_process_group_factory_mutex);
439440
auto &instance = g_process_group_factory_instance;
440441
if (instance == nullptr) {
441-
instance.reset(new ProcessGroupFactory(backend));
442+
std::lock_guard<std::mutex> lock(g_process_group_factory_mutex);
443+
if (instance == nullptr) {
444+
instance.reset(new ProcessGroupFactory(backend));
445+
} else if (instance->backend_ != backend) {
446+
LOG(FATAL) << "ProcessGroupFactory backend mismatch. initialized=" << static_cast<int>(instance->backend_)
447+
<< ", requested=" << static_cast<int>(backend);
448+
}
442449
} else if (instance->backend_ != backend) {
443450
LOG(FATAL) << "ProcessGroupFactory backend mismatch. initialized=" << static_cast<int>(instance->backend_)
444451
<< ", requested=" << static_cast<int>(backend);

0 commit comments

Comments
 (0)