Skip to content

Commit e281dea

Browse files
Chamberlain0w0 authored and kilinchange committed
fix: fix nccl error in process_group and seg fault in profiler
1 parent 73a2bad commit e281dea

5 files changed

Lines changed: 76 additions & 108 deletions

File tree

infini_train/include/profiler.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
#include <map>
77
#include <mutex>
88
#include <string>
9+
#include <vector>
910

1011
#include "glog/logging.h"
1112

1213
#include "infini_train/include/device.h"
1314

1415
namespace infini_train {
16+
namespace core {
17+
class Event;
18+
}
1519

1620
inline thread_local int g_profiling_depth = 0;
1721

@@ -80,20 +84,18 @@ class Profiler {
8084
void ReportGroupedByRank(std::function<std::ostream &(int64_t)> get_os, SortBy sort_by) const;
8185
void PrintRecordsGroupedByRank(std::function<std::ostream &(int64_t)> get_os) const;
8286

83-
std::mutex mtx_;
87+
mutable std::mutex mtx_;
8488
std::vector<KernelCallRecord> call_records_;
8589
std::string current_tag_ = "Untagged";
8690

8791
// thread-local tracking
8892
thread_local static inline std::map<std::string, std::chrono::high_resolution_clock::time_point> cpu_timing_map_;
8993

90-
#ifdef USE_CUDA
9194
struct EventPair {
92-
void *start;
93-
void *stop;
95+
core::Event *start = nullptr;
96+
core::Event *stop = nullptr;
9497
};
9598

96-
thread_local static inline std::map<std::string, EventPair> cuda_timing_map_;
97-
#endif
99+
thread_local static inline std::map<std::string, EventPair> device_timing_map_;
98100
};
99101
} // namespace infini_train

infini_train/src/core/cuda/cuda_guard_impl.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ void CudaGuardImpl::SynchronizeDevice(Device device) const {
202202
SetDevice(original_device);
203203
}
204204

205+
void CudaGuardImpl::SynchronizeStream(Stream *stream) const {
206+
auto cuda_stream = GetCudaStream(stream);
207+
CUDA_CHECK(cudaStreamSynchronize(cuda_stream));
208+
}
209+
205210
// blas
206211
BlasHandle *CudaGuardImpl::GetBlasHandle(Device device) const {
207212
CheckCudaDevice(device);

infini_train/src/core/cuda/cuda_guard_impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class CudaGuardImpl final : public DeviceGuardImpl {
5555

5656
// sync
5757
void SynchronizeDevice(Device device) const override;
58+
void SynchronizeStream(Stream *stream) const override;
5859

5960
// blas
6061
BlasHandle *GetBlasHandle(Device device) const override;

infini_train/src/nn/parallel/process_group.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,6 @@ ProcessGroup::~ProcessGroup() {
3838
if (is_main_process_) {
3939
core::GetCclImpl(backend_)->CleanupUniqueIdFile(name_);
4040
}
41-
42-
auto *impl = core::GetDeviceGuardImpl(backend_);
43-
for (auto &s : comm_streams_) {
44-
if (s) {
45-
impl->DestroyStream(s.get());
46-
}
47-
}
48-
49-
auto *ccl_impl = core::GetCclImpl(backend_);
50-
for (auto &c : comms_) {
51-
if (c) {
52-
ccl_impl->CommDestroy(c.get());
53-
}
54-
}
5541
}
5642

5743
void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {

infini_train/src/profiler.cc

Lines changed: 62 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,11 @@
55
#include <iostream>
66
#include <map>
77

8-
#ifdef USE_CUDA
9-
#include <cuda_runtime.h>
10-
#endif
11-
128
#include "glog/logging.h"
139

14-
#ifdef USE_CUDA
15-
#include "infini_train/include/common/cuda/common_cuda.h"
16-
#endif
1710
#include "infini_train/include/core/device_guard.h"
1811
#include "infini_train/include/device.h"
1912

20-
#include "infini_train/src/core/cuda/cuda_stream.h"
21-
2213
namespace infini_train {
2314
namespace {
2415
inline std::string GetCurrentTimestamp() {
@@ -46,54 +37,39 @@ int GetRank(Device::DeviceType device) {
4637
return impl->GetDevice().index();
4738
}
4839

49-
#ifdef USE_CUDA
50-
cudaStream_t GetCudaStream() {
51-
int device_id = GetRank(Device::DeviceType::kCUDA);
52-
// TODO(zbl): support multi-stream on single device
53-
auto device = Device(Device::DeviceType::kCUDA, static_cast<int8_t>(device_id));
54-
return dynamic_cast<infini_train::core::cuda::CudaStream *>(
55-
core::GetDeviceGuardImpl(device.type())->GetStream(device))
56-
->cuda_stream();
57-
}
58-
#endif
59-
6040
void Profiler::StartRecord(const std::string &name, Device::DeviceType device) {
6141
if (g_profiling_depth++ > 0) {
6242
return;
6343
}
6444
cpu_timing_map_[name] = std::chrono::high_resolution_clock::now();
6545

66-
switch (device) {
67-
case Device::DeviceType::kCPU:
68-
break;
69-
#ifdef USE_CUDA
70-
case Device::DeviceType::kCUDA: {
71-
auto it = cuda_timing_map_.find(name);
72-
if (it != cuda_timing_map_.end()) {
73-
// Make sure there are no conflicts
74-
CUDA_CHECK(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(it->second.start)));
75-
CUDA_CHECK(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(it->second.stop)));
76-
cuda_timing_map_.erase(it);
77-
}
78-
79-
cudaEvent_t start, stop;
80-
cudaStream_t stream = GetCudaStream();
81-
CUDA_CHECK(cudaEventCreate(&start));
82-
CUDA_CHECK(cudaEventCreate(&stop));
83-
84-
// Make sure the compute stream has done waiting, and ready for the execution of next op
85-
CUDA_CHECK(cudaStreamSynchronize(stream));
86-
// Start record after waiting
87-
cpu_timing_map_[name] = std::chrono::high_resolution_clock::now();
88-
CUDA_CHECK(cudaEventRecord(start, stream));
89-
cuda_timing_map_[name] = {reinterpret_cast<void *>(start), reinterpret_cast<void *>(stop)};
90-
break;
46+
if (device == Device::DeviceType::kCPU) {
47+
return;
9148
}
92-
#endif
93-
default:
94-
LOG(FATAL) << "Unsupported device type.";
95-
break;
49+
50+
auto *impl = core::GetDeviceGuardImpl(device);
51+
const int device_id = impl->GetDevice().index();
52+
auto current_device = Device(device, static_cast<int8_t>(device_id));
53+
auto *stream = impl->GetStream(current_device);
54+
55+
auto it = device_timing_map_.find(name);
56+
if (it != device_timing_map_.end()) {
57+
impl->EventDestroy(it->second.start);
58+
impl->EventDestroy(it->second.stop);
59+
device_timing_map_.erase(it);
9660
}
61+
62+
core::Event *start = nullptr;
63+
core::Event *stop = nullptr;
64+
impl->EventCreate(&start);
65+
impl->EventCreate(&stop);
66+
67+
// Make sure the compute stream has done waiting, and ready for the execution of next op
68+
impl->SynchronizeStream(stream);
69+
// Start record after waiting
70+
cpu_timing_map_[name] = std::chrono::high_resolution_clock::now();
71+
impl->EventRecord(start, stream);
72+
device_timing_map_[name] = {start, stop};
9773
}
9874

9975
void Profiler::EndRecord(const std::string &name, Device::DeviceType device) {
@@ -105,44 +81,28 @@ void Profiler::EndRecord(const std::string &name, Device::DeviceType device) {
10581
std::string device_str = "cpu";
10682
int rank = GetRank(device);
10783

108-
switch (device) {
109-
case Device::DeviceType::kCPU:
110-
break;
111-
#ifdef USE_CUDA
112-
case Device::DeviceType::kCUDA: {
113-
auto it = cuda_timing_map_.find(name);
114-
if (it != cuda_timing_map_.end()) {
115-
auto event_pair = it->second;
116-
cudaEvent_t start = reinterpret_cast<cudaEvent_t>(event_pair.start);
117-
cudaEvent_t stop = reinterpret_cast<cudaEvent_t>(event_pair.stop);
118-
cudaStream_t stream = GetCudaStream();
119-
CUDA_CHECK(cudaEventRecord(stop, stream));
120-
CUDA_CHECK(cudaEventSynchronize(stop));
121-
float elapsed_ms = 0.f;
122-
CUDA_CHECK(cudaEventElapsedTime(&elapsed_ms, start, stop));
123-
device_us = static_cast<int64_t>(elapsed_ms * 1000);
124-
CUDA_CHECK(cudaEventDestroy(start));
125-
CUDA_CHECK(cudaEventDestroy(stop));
126-
cuda_timing_map_.erase(it);
127-
128-
cudaMemPool_t pool;
129-
size_t peak_bytes = 0;
130-
if (cudaDeviceGetDefaultMemPool(&pool, rank) == cudaSuccess
131-
&& cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &peak_bytes) == cudaSuccess) {
132-
peak_mem_mb = static_cast<int64_t>(peak_bytes) / (1024 * 1024);
133-
} else {
134-
LOG(FATAL) << "cudaMemPool not supported.";
135-
}
136-
device_str = "cuda:" + std::to_string(rank);
137-
} else {
84+
if (device != Device::DeviceType::kCPU) {
85+
auto *impl = core::GetDeviceGuardImpl(device);
86+
auto current_device = Device(device, static_cast<int8_t>(rank));
87+
auto *stream = impl->GetStream(current_device);
88+
89+
auto it = device_timing_map_.find(name);
90+
if (it == device_timing_map_.end()) {
13891
LOG(FATAL) << "Start time of " + name + " is not recorded.";
13992
}
140-
break;
141-
}
142-
#endif
143-
default:
144-
LOG(FATAL) << "Unsupported device type.";
145-
break;
93+
94+
auto event_pair = it->second;
95+
impl->EventRecord(event_pair.stop, stream);
96+
impl->EventSynchronize(event_pair.stop);
97+
device_us = static_cast<int64_t>(impl->EventElapsedTime(event_pair.start, event_pair.stop) * 1000.0f);
98+
impl->EventDestroy(event_pair.start);
99+
impl->EventDestroy(event_pair.stop);
100+
device_timing_map_.erase(it);
101+
102+
auto [peak_used_mb, peak_reserved_mb] = impl->GetMemPoolPeakMB(current_device);
103+
(void)peak_used_mb;
104+
peak_mem_mb = static_cast<int64_t>(peak_reserved_mb);
105+
device_str = current_device.ToString();
146106
}
147107

148108
auto cpu_start = cpu_timing_map_[name];
@@ -171,9 +131,16 @@ void Profiler::Reset() {
171131
void Profiler::SetTag(const std::string &tag) { current_tag_ = tag; }
172132

173133
void Profiler::ReportGroupedByRank(std::function<std::ostream &(int64_t)> get_os, SortBy sort_by) const {
134+
std::vector<KernelCallRecord> records_snapshot;
135+
{
136+
// Prevent call_records_ from being modified by other threads
137+
std::lock_guard<std::mutex> lock(mtx_);
138+
records_snapshot = call_records_;
139+
}
140+
174141
std::map<int64_t, std::map<std::string, std::map<std::string, KernelProfileInfo>>> grouped_stats;
175142

176-
for (const auto &rec : call_records_) {
143+
for (const auto &rec : records_snapshot) {
177144
auto &entry = grouped_stats[rec.rank][rec.tag][rec.name];
178145
entry.host_total_us += rec.host_us;
179146
entry.device_total_us += rec.device_us;
@@ -193,7 +160,7 @@ void Profiler::ReportGroupedByRank(std::function<std::ostream &(int64_t)> get_os
193160

194161
// Peak memory usage by tag
195162
int64_t tag_peak_mb = 0;
196-
for (const auto &rec : call_records_) {
163+
for (const auto &rec : records_snapshot) {
197164
if (rec.rank == rank && rec.tag == tag) {
198165
tag_peak_mb = std::max(tag_peak_mb, rec.max_device_mem_usage_mb);
199166
}
@@ -283,9 +250,16 @@ void Profiler::Report(const std::string &file_prefix, SortBy sort_by) const {
283250
}
284251

285252
void Profiler::PrintRecordsGroupedByRank(std::function<std::ostream &(int64_t)> get_os) const {
253+
std::vector<KernelCallRecord> records_snapshot;
254+
{
255+
// Prevent call_records_ from being modified by other threads
256+
std::lock_guard<std::mutex> lock(mtx_);
257+
records_snapshot = call_records_;
258+
}
259+
286260
std::map<int64_t, std::map<std::string, std::vector<const KernelCallRecord *>>> grouped;
287261

288-
for (const auto &rec : call_records_) { grouped[rec.rank][rec.tag].push_back(&rec); }
262+
for (const auto &rec : records_snapshot) { grouped[rec.rank][rec.tag].push_back(&rec); }
289263

290264
for (const auto &[rank, tag_map] : grouped) {
291265
std::ostream &os = get_os(rank);

0 commit comments

Comments (0)