@@ -973,20 +973,20 @@ public:
973973 _stream_count (stream_count), _sync_matrix(stream_count * stream_count, true ) {}
974974
975975 bool is_in_sync_with (std::size_t this_stream, std::size_t other_stream) const {
976- return _sync_matrix.at (this_stream * stream_count + other_stream);
976+ return _sync_matrix.at (this_stream * _stream_count + other_stream);
977977 }
978978 void desynchronize (std::size_t this_stream) {
979979 for (std::size_t other_stream = 0 ; other_stream < _stream_count;
980- ++_other_stream ) {
980+ ++other_stream ) {
981981 if (this_stream != other_stream) {
982- _sync_matrix.at (other_stream * stream_count + this_stream) = false ;
982+ _sync_matrix.at (other_stream * _stream_count + this_stream) = false ;
983983 }
984984 }
985985 }
986986 void synchronize (std::size_t this_stream, std::size_t other_stream) {
987987 for (std::size_t i = 0 ; i < _stream_count; ++i) {
988988 if (is_in_sync_with (other_stream, i)) {
989- _sync_matrix.at (this_stream * stream_count + i) = true ;
989+ _sync_matrix.at (this_stream * _stream_count + i) = true ;
990990 }
991991 }
992992 }
@@ -995,14 +995,15 @@ public:
995995private:
996996 std::size_t _stream_count;
997997 std::vector<bool > _sync_matrix;
998- }
998+ };
999999
10001000} // namespace
10011001
10021002GpuRuntime::GpuRuntime (const Function& function, ContextPtr context) :
10031003 _context(context),
1004- _input_count (function.inputs().size()) _gpublas_handle(
1005- context.thread_pool(),
1004+ _input_count(function.inputs().size()),
1005+ _gpublas_handle(
1006+ context->thread_pool (),
10061007 []() {
10071008 gpublasHandle_t handle;
10081009 check_error (gpublasCreate (&handle));
@@ -1011,7 +1012,7 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
10111012 [](gpublasHandle_t handle) { check_error (gpublasDestroy (handle)); }
10121013 ),
10131014 _gpurand_generator (
1014- context. thread_pool(),
1015+ context-> thread_pool (),
10151016 []() {
10161017 gpurandGenerator_t handle;
10171018 check_error (gpurandCreateGenerator (&handle, GPURAND_RNG_PSEUDO_DEFAULT));
@@ -1035,52 +1036,38 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
10351036 std::size_t stream_count = 0 , event_count = 0 , backward_event_count = 0 ;
10361037 for (auto & instr : function.instructions ()) {
10371038 if (instr.stream_index >= stream_count) {
1038- stream_count = instr.stream_index ;
1039+ stream_count = instr.stream_index + 1 ;
10391040 }
10401041 }
10411042 SyncTracker sync_tracker (stream_count);
10421043 std::vector<int > local_source_streams (function.locals ().size (), -1 );
10431044 SizeVec last_stream_instrs (stream_count);
1044- nested_vector2<std::size_t > local_consumer_streams (function.locals ().size ());
10451045 nested_vector2<std::size_t > backward_wait_events (function.instructions ().size ());
10461046 std::vector<int > backward_record_events (function.instructions ().size (), -1 );
10471047
1048- for (auto & instr : function.instructions ()) {
1049- if (instr.stream_index >= stream_count) {
1050- stream_count = instr.stream_index + 1 ;
1051- }
1052- }
1053-
1054- auto update_sync = [&](std::size_t local_index,
1055- std::size_t stream_index,
1056- SizeVec& wait_events,
1057- auto get_event) {
1058- int source_stream = local_source_streams.at (local_index);
1059- auto & consumer_streams = local_consumer_streams.at (local_index);
1060- if (std::find (consumer_streams.begin (), consumer_streams.end (), stream_index) ==
1061- consumer_streams.end ()) {
1062- consumer_streams.push_back (stream_index);
1063- }
1064- if (!sync_tracker.is_in_sync_with (stream_index, source_stream)) {
1065- wait_events.push_back (get_event (source_stream));
1066- sync_tracker.synchronize (stream_index, source_stream);
1067- }
1068- } auto get_event_backward = [&](std::size_t source_stream) -> int {
1069- int & event = backward_record_events.at (last_stream_instrs.at (source_stream));
1070- if (event == -1 ) {
1071- event = backward_event_count;
1072- ++backward_event_count;
1073- }
1074- return event;
1075- };
1048+ auto update_sync_backward =
1049+ [&](std::size_t local_index, std::size_t stream_index, SizeVec& wait_events) {
1050+ int source_stream = local_source_streams.at (local_index);
1051+ if (source_stream == -1 ) {
1052+ return ;
1053+ }
1054+ if (!sync_tracker.is_in_sync_with (stream_index, source_stream)) {
1055+ int & event =
1056+ backward_record_events.at (last_stream_instrs.at (source_stream));
1057+ if (event == -1 ) {
1058+ event = backward_event_count;
1059+ ++backward_event_count;
1060+ }
1061+ wait_events.push_back (event);
1062+ sync_tracker.synchronize (stream_index, source_stream);
1063+ }
1064+ };
10761065
10771066 for (std::size_t instr_index = 0 ; auto [instr, bw_wait_events] :
10781067 zip (std::views::reverse (function.instructions ()),
10791068 std::views::reverse (backward_wait_events))) {
10801069 for (auto & out : instr.outputs ) {
1081- update_sync (
1082- out.local_index , instr.stream_index , bw_wait_events, get_event_backward
1083- );
1070+ update_sync_backward (out.local_index , instr.stream_index , bw_wait_events);
10841071 }
10851072 for (auto & in : instr.inputs ) {
10861073 local_source_streams.at (in.local_index ) = instr.stream_index ;
@@ -1090,22 +1077,35 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
10901077 ++instr_index;
10911078 }
10921079 for (auto & in : function.inputs ()) {
1093- update_sync (in.local_index , 0 , _wait_events, get_event_forward );
1080+ update_sync_backward (in.local_index , 0 , _backward_wait_events );
10941081 }
10951082
10961083 sync_tracker.reset ();
1097- local_source_streams.clear ();
1098- last_stream_instrs.clear ();
1099- local_consumer_streams.clear ();
1100-
1101- auto get_event_forward = [&](std::size_t source_stream) -> int {
1102- int & event =
1103- _instructions.at (last_stream_instrs.at (source_stream)).record_event ;
1104- if (event == -1 ) {
1105- event = event_count;
1106- ++event_count;
1084+ std::fill (local_source_streams.begin (), local_source_streams.end (), -1 );
1085+ nested_vector2<std::size_t > local_consumer_streams (function.locals ().size ());
1086+
1087+ auto update_sync = [&](std::size_t local_index,
1088+ std::size_t stream_index,
1089+ SizeVec& wait_events) {
1090+ int source_stream = local_source_streams.at (local_index);
1091+ if (source_stream == -1 ) {
1092+ return ;
1093+ }
1094+ auto & consumer_streams = local_consumer_streams.at (local_index);
1095+ if (std::find (consumer_streams.begin (), consumer_streams.end (), stream_index) ==
1096+ consumer_streams.end ()) {
1097+ consumer_streams.push_back (stream_index);
1098+ }
1099+ if (!sync_tracker.is_in_sync_with (stream_index, source_stream)) {
1100+ int & event =
1101+ _instructions.at (last_stream_instrs.at (source_stream)).record_event ;
1102+ if (event == -1 ) {
1103+ event = event_count;
1104+ ++event_count;
1105+ }
1106+ wait_events.push_back (event);
1107+ sync_tracker.synchronize (stream_index, source_stream);
11071108 }
1108- return event;
11091109 };
11101110
11111111 std::vector<bool > is_input (function.locals ().size ());
@@ -1141,9 +1141,7 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
11411141 if (in.type .batch_size != BatchSize::one) {
11421142 batch_size_index = in.local_index ;
11431143 }
1144- update_sync (
1145- in.local_index , instr.stream_index , bw_wait_events, get_event_forward
1146- );
1144+ update_sync (in.local_index , instr.stream_index , bw_wait_events);
11471145 }
11481146 SizeVec output_indices;
11491147 std::vector<DataType> output_dtypes;
@@ -1161,6 +1159,7 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
11611159 }
11621160
11631161 sync_tracker.desynchronize (instr.stream_index );
1162+ last_stream_instrs.at (instr.stream_index ) = _instructions.size ();
11641163 _instructions.push_back ({
11651164 instr.instruction ->opcode (),
11661165 input_indices,
@@ -1176,16 +1175,18 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
11761175 wait_events,
11771176 -1 ,
11781177 bw_wait_events,
1179- bw_record_events ,
1178+ bw_record_event ,
11801179 });
11811180
1182- auto locals_to_free = last_use.local_indices (instr_index);
1181+ SizeVec locals_to_free = last_use.local_indices (instr_index);
11831182 free_queue.insert (
11841183 free_queue.end (), locals_to_free.begin (), locals_to_free.end ()
11851184 );
11861185 free_queue.erase (
11871186 std::remove_if (
1188- free_queue.begin (), free_queue.end (), [&](std::size_t local_index) {
1187+ free_queue.begin (),
1188+ free_queue.end (),
1189+ [&](std::size_t local_index) {
11891190 for (std::size_t consumer_stream :
11901191 local_consumer_streams.at (local_index)) {
11911192 if (!sync_tracker.is_in_sync_with (
@@ -1213,7 +1214,8 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
12131214 );
12141215 return true ;
12151216 }
1216- )
1217+ ),
1218+ free_queue.end ()
12171219 );
12181220
12191221 ++instr_index;
@@ -1252,11 +1254,11 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
12521254
12531255 for (auto & out : function.outputs ()) {
12541256 _output_indices.push_back (out.local_index );
1255- update_sync (out.local_index , 0 , _wait_events, get_event_forward );
1257+ update_sync (out.local_index , 0 , _wait_events);
12561258 }
12571259
12581260 _streams = ThreadResource<std::vector<gpuStream_t>>(
1259- context. thread_pool (),
1261+ context-> thread_pool (),
12601262 [stream_count]() {
12611263 std::vector<gpuStream_t> streams (stream_count);
12621264 for (auto & item : streams) {
@@ -1272,7 +1274,7 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
12721274 );
12731275 std::size_t max_event_count = std::max (event_count, backward_event_count);
12741276 _events = ThreadResource<std::vector<gpuEvent_t>>(
1275- context. thread_pool (),
1277+ context-> thread_pool (),
12761278 [max_event_count]() {
12771279 std::vector<gpuEvent_t> events (max_event_count);
12781280 for (auto & item : events) {
@@ -1388,7 +1390,7 @@ std::tuple<TensorVec, TensorVec, std::vector<bool>> GpuRuntime::run_with_grad(
13881390 for (auto index : _output_indices) {
13891391 outputs.push_back (locals[index]);
13901392 }
1391- check_error (gpuStreamSynchronize (main_stream);
1393+ check_error (gpuStreamSynchronize (main_stream)) ;
13921394 return {outputs, locals, eval_grad};
13931395}
13941396
@@ -1411,7 +1413,7 @@ GpuRuntime::run_backward(
14111413 gpuStream_t main_stream = streams.at (0 );
14121414 for (auto [instr, instr_eval_grad] :
14131415 zip (std::views::reverse (_instructions), std::views::reverse (eval_grad))) {
1414- /* gpuStream_t stream = streams.at(instr.backward_stream );
1416+ /* gpuStream_t stream = streams.at(instr.stream );
14151417 for (auto event : instr.backward_wait_events) {
14161418 check_error(gpuStreamWaitEvent(stream, events.at(event)));
14171419 }*/
@@ -1432,8 +1434,8 @@ GpuRuntime::run_backward(
14321434#include " runtime_backward_mixin.h"
14331435 }
14341436 }
1435- /* if (instr.backward_record_event) {
1436- check_error(gpuEventRecord(instr.backward_record_event, stream));
1437+ /* if (instr.backward_record_event != -1 ) {
1438+ check_error(gpuEventRecord(events.at( instr.backward_record_event) , stream));
14371439 }*/
14381440 }
14391441 /* for (auto event : _backward_wait_events) {
0 commit comments