@@ -171,6 +171,16 @@ void op_matmul(
171171 bias.reset (device);
172172}
173173
// Device kernel that fills the leading entries of a 1-D double view with ones.
//
// Parameters:
//   batch_size - number of leading elements of `output` to write.
//   output     - 1-D GpuTensorView<double, 1, true> that receives the values.
//
// Each thread derives a single index from its block size, block index and
// thread index (x dimension only) and writes at most one element.
174+ __global__ void kernel_one (
175+ std::size_t batch_size,
176+ GpuTensorView<double , 1 , true > output
177+ ) {
// Flat index of this thread along x across all blocks.
178+ me_int_t i = blockDim .x * blockIdx .x + threadIdx .x ;
// Guard: threads whose index is at or beyond batch_size write nothing.
// NOTE(review): me_int_t is declared elsewhere; if it is a signed type this
// comparison mixes signed and unsigned operands - confirm that is intended.
179+ if (i < batch_size) {
180+ output[i] = 1 .;
181+ }
182+ }
183+
174184void backward_op_matmul (
175185 const GpuRuntime::Instruction& instruction,
176186 TensorVec& locals,
@@ -268,11 +278,12 @@ void backward_op_matmul(
268278
269279 // compute bias_grad += sum_i output_grad_ij
270280 Tensor ones (DataType::dt_float, {batch_size}, device, AllocHint::temporary);
271- thrust::fill_n (
272- thrust_par.on (stream),
273- thrust::device_pointer_cast (static_cast <double *>(ones.data ())),
281+ launch_kernel (
282+ kernel_one,
283+ batch_size,
284+ device.stream (),
274285 batch_size,
275- 1.0
286+ ones. view < double , 1 >()
276287 );
277288 check_error (gpublasDgemv (
278289 handle,
@@ -973,7 +984,9 @@ void op_histogram(
973984class SyncTracker {
974985public:
// Builds a tracker for `stream_count` streams. The sync state is held as a flat
// matrix with stream_count * stream_count entries.
// The removed ('-') line initialised every entry to true; the added ('+') lines
// size the matrix and then call reset() to set the initial state.
975986 SyncTracker (std::size_t stream_count) :
976- _stream_count (stream_count), _sync_matrix(stream_count * stream_count, true ) {}
987+ _stream_count (stream_count), _sync_matrix(stream_count * stream_count) {
988+ reset ();
989+ }
977990
978991 bool is_in_sync_with (std::size_t this_stream, std::size_t other_stream) const {
979992 return _sync_matrix.at (this_stream * _stream_count + other_stream);
@@ -993,7 +1006,13 @@ public:
9931006 }
9941007 }
9951008 }
996- void reset () { std::fill (_sync_matrix.begin (), _sync_matrix.end (), true ); }
// Resets the sync matrix so that entry (i, j) is true only when i == j, i.e.
// every diagonal entry is true and every off-diagonal entry is false.
// Entries are addressed as i * _stream_count + j, the same flat indexing that
// is_in_sync_with() uses for (this_stream, other_stream).
1009+ void reset () {
1010+ for (std::size_t i = 0 ; i < _stream_count; ++i) {
1011+ for (std::size_t j = 0 ; j < _stream_count; ++j) {
// .at() is bounds-checked; the assigned value is the boolean result of i == j.
1012+ _sync_matrix.at (i * _stream_count + j) = i == j;
1013+ }
1014+ }
1015+ }
9971016
9981017private:
9991018 std::size_t _stream_count;
@@ -1309,7 +1328,8 @@ TensorVec GpuRuntime::run(const TensorVec& inputs) {
13091328 gpu_device.activate ();
13101329 auto locals = _locals_init;
13111330 std::copy (inputs.begin (), inputs.end (), locals.begin ());
1312- MemPool mem_pool (gpu_device, load_pool_size_cache ());
1331+ gpuStream_t main_stream = streams.at (0 );
1332+ MemPool mem_pool (gpu_device, load_pool_size_cache (), main_stream);
13131333
13141334 // println("----");
13151335 for (auto & instr : _instructions) {
@@ -1328,16 +1348,16 @@ TensorVec GpuRuntime::run(const TensorVec& inputs) {
13281348 check_error (gpuEventRecord (events.at (instr.record_event ), stream));
13291349 }
13301350 }
1331- gpuStream_t main_stream = streams.at (0 );
13321351 for (auto event : _wait_events) {
13331352 check_error (gpuStreamWaitEvent (main_stream, events.at (event)));
13341353 }
1354+ update_pool_size_cache (mem_pool.total_sizes ());
1355+ mem_pool.reset (main_stream);
13351356 TensorVec outputs;
13361357 for (auto index : _output_indices) {
13371358 outputs.push_back (locals[index]);
13381359 }
13391360 check_error (gpuStreamSynchronize (main_stream));
1340- update_pool_size_cache (mem_pool.total_sizes ());
13411361 return outputs;
13421362}
13431363
@@ -1356,7 +1376,8 @@ std::tuple<TensorVec, TensorVec, std::vector<bool>> GpuRuntime::run_with_grad(
13561376 std::copy (
13571377 input_requires_grad.begin (), input_requires_grad.end (), requires_grad.begin ()
13581378 );
1359- MemPool mem_pool (gpu_device, load_pool_size_cache ());
1379+ gpuStream_t main_stream = streams.at (0 );
1380+ MemPool mem_pool (gpu_device, load_pool_size_cache (), main_stream);
13601381
13611382 for (auto [instr, instr_eval_grad] : zip (_instructions, eval_grad)) {
13621383 gpuStream_t stream = streams.at (instr.stream );
@@ -1396,16 +1417,16 @@ std::tuple<TensorVec, TensorVec, std::vector<bool>> GpuRuntime::run_with_grad(
13961417 check_error (gpuEventRecord (events.at (instr.record_event ), stream));
13971418 }
13981419 }
1399- gpuStream_t main_stream = streams.at (0 );
14001420 for (auto event : _wait_events) {
14011421 check_error (gpuStreamWaitEvent (main_stream, events.at (event)));
14021422 }
1423+ update_pool_size_cache (mem_pool.total_sizes ());
1424+ mem_pool.reset (main_stream);
14031425 TensorVec outputs;
14041426 for (auto index : _output_indices) {
14051427 outputs.push_back (locals[index]);
14061428 }
14071429 check_error (gpuStreamSynchronize (main_stream));
1408- update_pool_size_cache (mem_pool.total_sizes ());
14091430 return {outputs, locals, eval_grad};
14101431}
14111432
@@ -1424,8 +1445,8 @@ GpuRuntime::run_backward(
14241445 for (auto [index, grad] : zip (_output_indices, output_grads)) {
14251446 local_grads[index] = grad;
14261447 }
1427- MemPool mem_pool (gpu_device, load_pool_size_cache ());
14281448 gpuStream_t main_stream = streams.at (0 );
1449+ MemPool mem_pool (gpu_device, load_pool_size_cache (), main_stream);
14291450 for (auto [instr, instr_eval_grad] :
14301451 zip (std::views::reverse (_instructions), std::views::reverse (eval_grad))) {
14311452 /* gpuStream_t stream = streams.at(instr.stream);
@@ -1457,12 +1478,13 @@ GpuRuntime::run_backward(
14571478 /* for (auto event : _backward_wait_events) {
14581479 check_error(gpuStreamWaitEvent(main_stream, events.at(event)));
14591480 }*/
1481+ update_pool_size_cache (mem_pool.total_sizes ());
1482+ mem_pool.reset (main_stream);
14601483 std::vector<std::tuple<std::string, Tensor>> global_grads;
14611484 for (auto & [name, index] : _grad_global_indices) {
14621485 global_grads.push_back ({name, local_grads[index]});
14631486 }
14641487 check_error (gpuStreamSynchronize (main_stream));
1465- update_pool_size_cache (mem_pool.total_sizes ());
14661488 return {{local_grads.begin (), local_grads.begin () + _input_count}, global_grads};
14671489}
14681490
0 commit comments