@@ -6,7 +6,7 @@ using namespace madspace;
66using namespace madspace ::gpu;
77using namespace madspace ::kernels;
88
9- std::pair<void *, Tensor> GpuDevice::allocate (std::size_t size) const {
9+ std::pair<void *, Tensor> GpuDevice::allocate (std::size_t size, AllocHint hint ) const {
1010 activate ();
1111 void * ptr;
1212 check_error (gpuMalloc (&ptr, size));
@@ -25,19 +25,19 @@ void GpuDevice::memcpy(void* to, void* from, std::size_t size) const {
2525
2626void GpuDevice::tensor_copy (const Tensor& source, Tensor& target) const {
2727 activate ();
28- AsyncGpuDevice (*this , gpuStreamPerThread).tensor_copy (source, target);
28+ AsyncGpuDevice (*this , gpuStreamPerThread, 0 ).tensor_copy (source, target);
2929 check_error (gpuStreamSynchronize (gpuStreamPerThread));
3030}
3131
3232void GpuDevice::tensor_zero (Tensor& tensor) const {
3333 activate ();
34- AsyncGpuDevice (*this , gpuStreamPerThread).tensor_zero (tensor);
34+ AsyncGpuDevice (*this , gpuStreamPerThread, 0 ).tensor_zero (tensor);
3535 check_error (gpuStreamSynchronize (gpuStreamPerThread));
3636}
3737
3838void GpuDevice::tensor_add (const Tensor& source, Tensor& target) const {
3939 activate ();
40- AsyncGpuDevice (*this , gpuStreamPerThread).tensor_add (source, target);
40+ AsyncGpuDevice (*this , gpuStreamPerThread, 0 ).tensor_add (source, target);
4141 check_error (gpuStreamSynchronize (gpuStreamPerThread));
4242}
4343
@@ -65,8 +65,9 @@ MemPool::MemPool(
6565 auto & pool = _pools.at (pool_index);
6666 std::size_t word_count = (size + 7 ) / 8 ;
6767 pool.parent_tensor = Tensor (DataType::dt_float, {word_count}, device);
68- pool.size = word_count * 8 ;
68+ pool.capacity = word_count * 8 ;
6969 pool.needed_size = word_count * 8 ;
70+ // println("create pool {} {}", pool_index, pool.size);
7071 }
7172}
7273
@@ -83,22 +84,29 @@ MemPool::~MemPool() {
8384
8485std::pair<void *, Tensor> MemPool::allocate (std::size_t pool_index, std::size_t size) {
8586 if (pool_index >= _pools.size ()) {
86- _pools.resize (pool_index);
87+ _pools.resize (pool_index + 1 );
8788 }
8889 PoolItem& pool = _pools.at (pool_index);
8990 if (auto search = pool.free_pointers .find (size);
9091 search != pool.free_pointers .end ()) {
91- std::pair<void *, Tensor> ret = *search->second ;
92+ std::pair<void *, Tensor> ret = search->second ;
93+ _allocs[ret.first ] = {
94+ .pool_index = pool_index,
95+ .size = size,
96+ .parent_tensor = ret.second ,
97+ };
98+ // println("reuse {} {} {}", ret.first, pool_index, size);
9299 pool.free_pointers .erase (search);
93100 return ret;
94- } else if (pool.capacity - pool.size >= size) {
101+ } else if (pool.parent_tensor && pool. capacity - pool.size >= size) {
95102 void * ptr = &static_cast <uint8_t *>(pool.parent_tensor .data ())[pool.size ];
96103 pool.size = (pool.size + size + 7 ) / 8 * 8 ;
97104 _allocs[ptr] = {
98105 .pool_index = pool_index,
99106 .size = size,
100107 .parent_tensor = pool.parent_tensor ,
101108 };
109+ // println("pooled {} {} {} {} {}", ptr, pool_index, size, pool.size, pool.capacity);
102110 return {ptr, pool.parent_tensor };
103111 } else {
104112 void * ptr;
@@ -108,25 +116,28 @@ std::pair<void*, Tensor> MemPool::allocate(std::size_t pool_index, std::size_t s
108116 .size = size,
109117 .parent_tensor = Tensor (),
110118 };
119+ // println("alloc {} {} {}", ptr, pool_index, size);
111120 pool.needed_size += (size + 7 ) / 8 * 8 ;
112121 return {ptr, Tensor ()};
113122 }
114123}
115124
116125void MemPool::free (void * ptr) {
117- auto search = _allocs.find (ptr) if (search == _allocs.end ()) {
126+ auto search = _allocs.find (ptr);
127+ if (search == _allocs.end ()) {
118128 throw std::runtime_error (" address was not allocated using this pool" );
119129 }
120130 auto & alloc = search->second ;
121131 _pools.at (alloc.pool_index )
122- .free_pointers .emplace (alloc.size , {ptr, alloc.parent_tensor });
132+ .free_pointers .emplace (alloc.size , std::pair<void *, Tensor>{ptr, alloc.parent_tensor });
133+ // println("free {} {} {}", ptr, alloc.pool_index, alloc.size);
123134 _allocs.erase (search);
124135}
125136
126137std::vector<std::pair<std::size_t , std::size_t >> MemPool::total_sizes () const {
127138 std::vector<std::pair<std::size_t , std::size_t >> ret;
128139 ret.reserve (_pools.size ());
129- for (std::size_t index = 0 ; PoolItem & pool : _pools) {
140+ for (std::size_t index = 0 ; auto & pool : _pools) {
130141 if (pool.needed_size > 0 ) {
131142 ret.push_back ({index, pool.needed_size });
132143 }
@@ -137,17 +148,46 @@ std::vector<std::pair<std::size_t, std::size_t>> MemPool::total_sizes() const {
137148
138149std::pair<void *, Tensor>
139150AsyncGpuDevice::allocate (std::size_t size, AllocHint hint) const {
140- if (_mem_pool != nullptr && hint != AllocHint::normal) {
141- return _mem_pool->allocate (static_cast <std::size_t >(hint) - 1 , size);
151+ if (_mem_pool) {
152+ std::size_t pool_index;
153+ switch (hint) {
154+ case AllocHint::normal:
155+ throw std::runtime_error (" allocation without hint" );
156+ case AllocHint::output:
157+ pool_index = 0 ;
158+ break ;
159+ case AllocHint::local:
160+ pool_index = 3 + 3 * _stream_index;
161+ break ;
162+ case AllocHint::temporary:
163+ pool_index = 4 + 3 * _stream_index;
164+ break ;
165+ case AllocHint::input_grad:
166+ pool_index = 1 ;
167+ break ;
168+ case AllocHint::local_grad:
169+ pool_index = 5 + 3 * _stream_index;
170+ break ;
171+ case AllocHint::global_grad:
172+ pool_index = 2 ;
173+ break ;
174+ }
175+ return _mem_pool->allocate (pool_index, size);
142176 } else {
143- _device.allocate (size, hint);
144- // void* ptr;
145- // check_error(gpuMallocAsync(&ptr, size, _stream));
146- // return {ptr, Tensor()};
177+ // _device.allocate(size, hint);
178+ void * ptr;
179+ check_error (gpuMallocAsync (&ptr, size, _stream));
180+ return {ptr, Tensor ()};
147181 }
148182}
149183
150- void AsyncGpuDevice::free (void * ptr) const { check_error (gpuFreeAsync (ptr, _stream)); }
184+ void AsyncGpuDevice::free (void * ptr) const {
185+ if (_mem_pool) {
186+ _mem_pool->free (ptr);
187+ } else {
188+ check_error (gpuFreeAsync (ptr, _stream));
189+ }
190+ }
151191
152192void AsyncGpuDevice::memcpy (void * to, void * from, std::size_t size) const {
153193 check_error (gpuMemcpyAsync (to, from, size, gpuMemcpyDefault, _stream));
@@ -170,13 +210,21 @@ void AsyncGpuDevice::tensor_copy(const Tensor& source, Tensor& target) const {
170210
171211void AsyncGpuDevice::tensor_zero (Tensor& tensor) const {
172212 if (tensor.dtype () == DataType::dt_float) {
173- tensor_foreach_dynamic<kernel_zero<GpuTypes>, 1 , 1 >(
174- {&tensor}, {&tensor}, tensor.size (0 ), *this
175- );
213+ if (tensor.is_contiguous ()) {
214+ gpuMemsetAsync (tensor.data (), 0 , tensor.byte_size (), _stream);
215+ } else {
216+ tensor_foreach_dynamic<kernel_zero<GpuTypes>, 1 , 1 >(
217+ {&tensor}, {&tensor}, tensor.size (0 ), *this
218+ );
219+ }
176220 } else if (tensor.dtype () == DataType::dt_int) {
177- tensor_foreach_dynamic<kernel_zero_int<GpuTypes>, 1 , 1 >(
178- {&tensor}, {&tensor}, tensor.size (0 ), *this
179- );
221+ if (tensor.is_contiguous ()) {
222+ gpuMemsetAsync (tensor.data (), 0 , tensor.byte_size (), _stream);
223+ } else {
224+ tensor_foreach_dynamic<kernel_zero_int<GpuTypes>, 1 , 1 >(
225+ {&tensor}, {&tensor}, tensor.size (0 ), *this
226+ );
227+ }
180228 } else {
181229 throw std::runtime_error (" invalid dtype in zero" );
182230 }
0 commit comments