55#include < iostream>
66#include < map>
77
8- #ifdef USE_CUDA
9- #include < cuda_runtime.h>
10- #endif
11-
128#include " glog/logging.h"
139
14- #ifdef USE_CUDA
15- #include " infini_train/include/common/cuda/common_cuda.h"
16- #endif
1710#include " infini_train/include/core/device_guard.h"
1811#include " infini_train/include/device.h"
1912
20- #include " infini_train/src/core/cuda/cuda_stream.h"
21-
2213namespace infini_train {
2314namespace {
2415inline std::string GetCurrentTimestamp () {
@@ -46,54 +37,39 @@ int GetRank(Device::DeviceType device) {
4637 return impl->GetDevice ().index ();
4738}
4839
49- #ifdef USE_CUDA
50- cudaStream_t GetCudaStream () {
51- int device_id = GetRank (Device::DeviceType::kCUDA );
52- // TODO(zbl): support multi-stream on single device
53- auto device = Device (Device::DeviceType::kCUDA , static_cast <int8_t >(device_id));
54- return dynamic_cast <infini_train::core::cuda::CudaStream *>(
55- core::GetDeviceGuardImpl (device.type ())->GetStream (device))
56- ->cuda_stream ();
57- }
58- #endif
59-
6040void Profiler::StartRecord (const std::string &name, Device::DeviceType device) {
6141 if (g_profiling_depth++ > 0 ) {
6242 return ;
6343 }
6444 cpu_timing_map_[name] = std::chrono::high_resolution_clock::now ();
6545
66- switch (device) {
67- case Device::DeviceType::kCPU :
68- break ;
69- #ifdef USE_CUDA
70- case Device::DeviceType::kCUDA : {
71- auto it = cuda_timing_map_.find (name);
72- if (it != cuda_timing_map_.end ()) {
73- // Make sure there are no conflicts
74- CUDA_CHECK (cudaEventDestroy (reinterpret_cast <cudaEvent_t>(it->second .start )));
75- CUDA_CHECK (cudaEventDestroy (reinterpret_cast <cudaEvent_t>(it->second .stop )));
76- cuda_timing_map_.erase (it);
77- }
78-
79- cudaEvent_t start, stop;
80- cudaStream_t stream = GetCudaStream ();
81- CUDA_CHECK (cudaEventCreate (&start));
82- CUDA_CHECK (cudaEventCreate (&stop));
83-
84- // Make sure the compute stream has done waiting, and ready for the execution of next op
85- CUDA_CHECK (cudaStreamSynchronize (stream));
86- // Start record after waiting
87- cpu_timing_map_[name] = std::chrono::high_resolution_clock::now ();
88- CUDA_CHECK (cudaEventRecord (start, stream));
89- cuda_timing_map_[name] = {reinterpret_cast <void *>(start), reinterpret_cast <void *>(stop)};
90- break ;
46+ if (device == Device::DeviceType::kCPU ) {
47+ return ;
9148 }
92- #endif
93- default :
94- LOG (FATAL) << " Unsupported device type." ;
95- break ;
49+
50+ auto *impl = core::GetDeviceGuardImpl (device);
51+ const int device_id = impl->GetDevice ().index ();
52+ auto current_device = Device (device, static_cast <int8_t >(device_id));
53+ auto *stream = impl->GetStream (current_device);
54+
55+ auto it = device_timing_map_.find (name);
56+ if (it != device_timing_map_.end ()) {
57+ impl->EventDestroy (it->second .start );
58+ impl->EventDestroy (it->second .stop );
59+ device_timing_map_.erase (it);
9660 }
61+
62+ core::Event *start = nullptr ;
63+ core::Event *stop = nullptr ;
64+ impl->EventCreate (&start);
65+ impl->EventCreate (&stop);
66+
67+ // Make sure the compute stream has done waiting, and ready for the execution of next op
68+ impl->SynchronizeStream (stream);
69+ // Start record after waiting
70+ cpu_timing_map_[name] = std::chrono::high_resolution_clock::now ();
71+ impl->EventRecord (start, stream);
72+ device_timing_map_[name] = {start, stop};
9773}
9874
9975void Profiler::EndRecord (const std::string &name, Device::DeviceType device) {
@@ -105,44 +81,28 @@ void Profiler::EndRecord(const std::string &name, Device::DeviceType device) {
10581 std::string device_str = " cpu" ;
10682 int rank = GetRank (device);
10783
108- switch (device) {
109- case Device::DeviceType::kCPU :
110- break ;
111- #ifdef USE_CUDA
112- case Device::DeviceType::kCUDA : {
113- auto it = cuda_timing_map_.find (name);
114- if (it != cuda_timing_map_.end ()) {
115- auto event_pair = it->second ;
116- cudaEvent_t start = reinterpret_cast <cudaEvent_t>(event_pair.start );
117- cudaEvent_t stop = reinterpret_cast <cudaEvent_t>(event_pair.stop );
118- cudaStream_t stream = GetCudaStream ();
119- CUDA_CHECK (cudaEventRecord (stop, stream));
120- CUDA_CHECK (cudaEventSynchronize (stop));
121- float elapsed_ms = 0 .f ;
122- CUDA_CHECK (cudaEventElapsedTime (&elapsed_ms, start, stop));
123- device_us = static_cast <int64_t >(elapsed_ms * 1000 );
124- CUDA_CHECK (cudaEventDestroy (start));
125- CUDA_CHECK (cudaEventDestroy (stop));
126- cuda_timing_map_.erase (it);
127-
128- cudaMemPool_t pool;
129- size_t peak_bytes = 0 ;
130- if (cudaDeviceGetDefaultMemPool (&pool, rank) == cudaSuccess
131- && cudaMemPoolGetAttribute (pool, cudaMemPoolAttrReservedMemHigh, &peak_bytes) == cudaSuccess) {
132- peak_mem_mb = static_cast <int64_t >(peak_bytes) / (1024 * 1024 );
133- } else {
134- LOG (FATAL) << " cudaMemPool not supported." ;
135- }
136- device_str = " cuda:" + std::to_string (rank);
137- } else {
84+ if (device != Device::DeviceType::kCPU ) {
85+ auto *impl = core::GetDeviceGuardImpl (device);
86+ auto current_device = Device (device, static_cast <int8_t >(rank));
87+ auto *stream = impl->GetStream (current_device);
88+
89+ auto it = device_timing_map_.find (name);
90+ if (it == device_timing_map_.end ()) {
13891 LOG (FATAL) << " Start time of " + name + " is not recorded." ;
13992 }
140- break ;
141- }
142- #endif
143- default :
144- LOG (FATAL) << " Unsupported device type." ;
145- break ;
93+
94+ auto event_pair = it->second ;
95+ impl->EventRecord (event_pair.stop , stream);
96+ impl->EventSynchronize (event_pair.stop );
97+ device_us = static_cast <int64_t >(impl->EventElapsedTime (event_pair.start , event_pair.stop ) * 1000 .0f );
98+ impl->EventDestroy (event_pair.start );
99+ impl->EventDestroy (event_pair.stop );
100+ device_timing_map_.erase (it);
101+
102+ auto [peak_used_mb, peak_reserved_mb] = impl->GetMemPoolPeakMB (current_device);
103+ (void )peak_used_mb;
104+ peak_mem_mb = static_cast <int64_t >(peak_reserved_mb);
105+ device_str = current_device.ToString ();
146106 }
147107
148108 auto cpu_start = cpu_timing_map_[name];
@@ -171,9 +131,16 @@ void Profiler::Reset() {
171131void Profiler::SetTag (const std::string &tag) { current_tag_ = tag; }
172132
173133void Profiler::ReportGroupedByRank (std::function<std::ostream &(int64_t )> get_os, SortBy sort_by) const {
134+ std::vector<KernelCallRecord> records_snapshot;
135+ {
136+ // Prevent call_records_ from being modified by other threads
137+ std::lock_guard<std::mutex> lock (mtx_);
138+ records_snapshot = call_records_;
139+ }
140+
174141 std::map<int64_t , std::map<std::string, std::map<std::string, KernelProfileInfo>>> grouped_stats;
175142
176- for (const auto &rec : call_records_ ) {
143+ for (const auto &rec : records_snapshot ) {
177144 auto &entry = grouped_stats[rec.rank ][rec.tag ][rec.name ];
178145 entry.host_total_us += rec.host_us ;
179146 entry.device_total_us += rec.device_us ;
@@ -193,7 +160,7 @@ void Profiler::ReportGroupedByRank(std::function<std::ostream &(int64_t)> get_os
193160
194161 // Peak memory usage by tag
195162 int64_t tag_peak_mb = 0 ;
196- for (const auto &rec : call_records_ ) {
163+ for (const auto &rec : records_snapshot ) {
197164 if (rec.rank == rank && rec.tag == tag) {
198165 tag_peak_mb = std::max (tag_peak_mb, rec.max_device_mem_usage_mb );
199166 }
@@ -283,9 +250,16 @@ void Profiler::Report(const std::string &file_prefix, SortBy sort_by) const {
283250}
284251
285252void Profiler::PrintRecordsGroupedByRank (std::function<std::ostream &(int64_t )> get_os) const {
253+ std::vector<KernelCallRecord> records_snapshot;
254+ {
255+ // Prevent call_records_ from being modified by other threads
256+ std::lock_guard<std::mutex> lock (mtx_);
257+ records_snapshot = call_records_;
258+ }
259+
286260 std::map<int64_t , std::map<std::string, std::vector<const KernelCallRecord *>>> grouped;
287261
288- for (const auto &rec : call_records_ ) { grouped[rec.rank ][rec.tag ].push_back (&rec); }
262+ for (const auto &rec : records_snapshot ) { grouped[rec.rank ][rec.tag ].push_back (&rec); }
289263
290264 for (const auto &[rank, tag_map] : grouped) {
291265 std::ostream &os = get_os (rank);
0 commit comments