@@ -1043,18 +1043,18 @@ GpuRuntime::GpuRuntime(const Function& function_arg, ContextPtr context) :
10431043 return handle;
10441044 },
10451045 [](gpurandGenerator_t handle) { check_error (gpurandDestroyGenerator (handle)); }
1046- ) {
1046+ ),
1047+ _prev_caches (context->thread_pool (), []() { return TensorVec{}; }) {
10471048 if (context->device ()->device_type () != GpuDevice::gpu_device_type) {
10481049 throw std::runtime_error (" Context has incompatible device" );
10491050 }
10501051 auto & gpu_device = *static_cast <const GpuDevice*>(_context->device ());
10511052 gpu_device.activate ();
10521053
1053- cudaMemPool_t pool;
1054- check_error (cudaDeviceGetMemPool (&pool, 0 ));
1055- uint64_t thresh = UINT64_MAX;
1056- check_error (cudaMemPoolSetAttribute (pool, cudaMemPoolAttrReleaseThreshold, &thresh));
1057-
1054+ // cudaMemPool_t pool;
1055+ // check_error(cudaDeviceGetMemPool(&pool, 0));
1056+ // uint64_t thresh = UINT64_MAX;
1057+ // check_error(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &thresh));
10581058
10591059 Function function = sort_breadth_first (function_arg);
10601060
@@ -1352,7 +1352,7 @@ TensorVec GpuRuntime::run(const TensorVec& inputs) {
13521352 check_error (gpuStreamWaitEvent (main_stream, events.at (event)));
13531353 }
13541354 update_pool_size_cache (mem_pool.total_sizes ());
1355- mem_pool.reset (main_stream);
1355+ update_cached_tensors ( mem_pool.reset (main_stream) );
13561356 TensorVec outputs;
13571357 for (auto index : _output_indices) {
13581358 outputs.push_back (locals[index]);
@@ -1421,7 +1421,7 @@ std::tuple<TensorVec, TensorVec, std::vector<bool>> GpuRuntime::run_with_grad(
14211421 check_error (gpuStreamWaitEvent (main_stream, events.at (event)));
14221422 }
14231423 update_pool_size_cache (mem_pool.total_sizes ());
1424- mem_pool.reset (main_stream);
1424+ update_cached_tensors ( mem_pool.reset (main_stream) );
14251425 TensorVec outputs;
14261426 for (auto index : _output_indices) {
14271427 outputs.push_back (locals[index]);
@@ -1479,7 +1479,7 @@ GpuRuntime::run_backward(
14791479 check_error(gpuStreamWaitEvent(main_stream, events.at(event)));
14801480 }*/
14811481 update_pool_size_cache (mem_pool.total_sizes ());
1482- mem_pool.reset (main_stream);
1482+ update_cached_tensors ( mem_pool.reset (main_stream) );
14831483 std::vector<std::tuple<std::string, Tensor>> global_grads;
14841484 for (auto & [name, index] : _grad_global_indices) {
14851485 global_grads.push_back ({name, local_grads[index]});
@@ -1488,11 +1488,21 @@ GpuRuntime::run_backward(
14881488 return {{local_grads.begin (), local_grads.begin () + _input_count}, global_grads};
14891489}
14901490
// Diff view of GpuRuntime::load_pool_size_cache: `-` lines are the removed
// version, `+` lines are the replacement, unmarked lines are unchanged context.
//
// Old behavior: returned (pool index, size) pairs copied from the range
// [cache->begin(), cache->end()) of the loaded pool-size cache.
//
// New behavior: returns one (pool_index, size, Tensor) tuple per entry of the
// loaded pool-size cache, or an empty vector when the loaded cache pointer is
// null. The Tensor slot is the tensor stored at `pool_index` in
// `_prev_caches.get()` when that tensor tests true and is_only_reference()
// returns true; in every other case it is a default-constructed Tensor.
//
// NOTE(review): `_prev_caches` is constructed from context->thread_pool(), so
// get() presumably yields a TensorVec private to the calling thread — confirm.
1491- std::vector<std::pair <std::size_t , std::size_t >> GpuRuntime::load_pool_size_cache () {
1491+ std::vector<std::tuple <std::size_t , std::size_t , Tensor >> GpuRuntime::load_pool_size_cache () {
14921492 auto cache = _pool_size_cache.load ();
1493- std::vector<std::pair <std::size_t , std::size_t >> ret;
1493+ std::vector<std::tuple <std::size_t , std::size_t , Tensor >> ret;
14941494 if (cache) {
1495- ret = {cache->begin (), cache->end ()};
1495+ auto & thread_prev_caches = _prev_caches.get ();
1496+ for (auto [pool_index, size] : *cache) {
// Stays default-constructed unless a reusable previous tensor is found below.
1497+ Tensor new_cache;
// Pool indices beyond the end of the stored vector have no previous tensor.
1498+ if (pool_index < thread_prev_caches.size ()) {
1499+ Tensor& prev_cache = thread_prev_caches.at (pool_index);
// NOTE(review): is_only_reference() presumably means no other holder of the
// underlying buffer remains, making reuse safe — confirm against Tensor.
1500+ if (prev_cache && prev_cache.is_only_reference ()) {
1501+ new_cache = prev_cache;
1502+ }
1503+ }
1504+ ret.push_back ({pool_index, size, new_cache});
1505+ }
14961506 }
14971507 return ret;
14981508}
@@ -1503,11 +1513,25 @@ void GpuRuntime::update_pool_size_cache(const std::vector<std::pair<std::size_t,
15031513 std::make_shared<std::unordered_map<std::size_t , std::size_t >>(*cache) :
15041514 std::make_shared<std::unordered_map<std::size_t , std::size_t >>();
15051515 for (auto [pool_index, size] : total_sizes) {
1506- (*new_cache)[pool_index] = std::max ((*new_cache)[pool_index], size);
1516+ auto & cache_size = (*new_cache)[pool_index];
1517+ if (size > cache_size) {
1518+ // if the cache needs to be resized, add some padding to prevent frequent resizing
1519+ cache_size = size * 4 / 3 ;
1520+ }
15071521 }
15081522 _pool_size_cache.store (new_cache);
15091523}
15101524
// Diff view: every line of this function is newly added (`+`).
//
// Records each (pool_index, tensor) pair from `tensors` into the vector
// returned by `_prev_caches.get()`, so that slot `pool_index` holds `tensor`.
// The vector is grown on demand; slots that are created by the resize but not
// assigned here are value-initialized by std::vector::resize.
// In this file it is called with the result of mem_pool.reset(main_stream) in
// run, run_with_grad and run_backward.
//
// NOTE(review): `_prev_caches.get()` presumably returns storage private to the
// calling thread (it is constructed from context->thread_pool()) — confirm.
1525+ void GpuRuntime::update_cached_tensors (const std::vector<std::pair<std::size_t , Tensor>>& tensors) {
1526+ auto & thread_prev_caches = _prev_caches.get ();
1527+ for (auto & [pool_index, tensor] : tensors) {
// Grow first so the at() below cannot throw for a not-yet-seen pool index.
1528+ if (pool_index >= thread_prev_caches.size ()) {
1529+ thread_prev_caches.resize (pool_index + 1 );
1530+ }
// Overwrites whatever tensor was previously stored for this pool index.
1531+ thread_prev_caches.at (pool_index) = tensor;
1532+ }
1533+ }
1534+
15111535extern " C" Runtime*
15121536build_runtime (const Function& function, ContextPtr context, bool concurrent) {
15131537 return new GpuRuntime (function, context);
0 commit comments