@@ -165,10 +165,21 @@ class Tensor;
165165
166166enum class DeviceType { cpu, cuda, hip };
167167
168+ enum class AllocHint {
169+ normal,
170+ output,
171+ local,
172+ temporary,
173+ input_grad,
174+ local_grad,
175+ global_grad,
176+ };
177+
168178class Device {
169179public:
170180 virtual ~Device () = default ;
171- virtual std::pair<void *, Tensor> allocate (std::size_t size) const = 0;
181+ virtual std::pair<void *, Tensor>
182+ allocate (std::size_t size, AllocHint hint) const = 0 ;
172183 virtual void free (void * ptr) const = 0;
173184 virtual void memcpy (void * to, void * from, std::size_t size) const = 0;
174185 virtual void tensor_copy (const Tensor& source, Tensor& target) const = 0;
@@ -199,19 +210,30 @@ class Tensor {
199210
200211 Tensor (Tensor&& other) noexcept : impl(other.impl) { other.impl = nullptr ; }
201212
202- Tensor (DataType dtype, const Sizes& shape) : Tensor(dtype, shape, cpu_device()) {}
213+ Tensor (DataType dtype, const Sizes& shape, AllocHint hint = AllocHint::normal) :
214+ Tensor (dtype, shape, cpu_device(), hint) {}
203215
204- Tensor (DataType dtype, const Sizes& shape, DevicePtr device) :
216+ Tensor (
217+ DataType dtype,
218+ const Sizes& shape,
219+ DevicePtr device,
220+ AllocHint hint = AllocHint::normal
221+ ) :
205222 impl (new TensorImpl{dtype, shape, device}) {
206223 auto size = init_stride ();
207- allocate (size, *device);
224+ allocate (size, *device, hint );
208225 }
209226
210227 template <typename D>
211- Tensor (DataType dtype, const Sizes& shape, const D& device) :
228+ Tensor (
229+ DataType dtype,
230+ const Sizes& shape,
231+ const D& device,
232+ AllocHint hint = AllocHint::normal
233+ ) :
212234 impl (new TensorImpl{dtype, shape, device.device_ptr ()}) {
213235 auto size = init_stride ();
214- allocate (size, device);
236+ allocate (size, device, hint );
215237 }
216238
217239 Tensor (
@@ -276,21 +298,21 @@ class Tensor {
276298 }) {}
277299
278300 template <ScalarType T>
279- Tensor (T value, DevicePtr device) :
301+ Tensor (T value, DevicePtr device, AllocHint hint = AllocHint::normal ) :
280302 impl (new TensorImpl{
281303 std::is_same_v<T, me_int_t > ? DataType::dt_int : DataType::dt_float,
282304 {1 },
283305 device
284306 }) {
285307 auto size = init_stride ();
286- allocate (size, *device);
308+ allocate (size, *device, hint );
287309 device->memcpy (impl->data , &value, sizeof (value));
288310 if (std::is_same_v<T, me_int_t > && value >= 0 ) {
289311 impl->batch_sizes .push_back (value);
290312 }
291313 }
292314
293- Tensor (TensorValue value, DevicePtr device) :
315+ Tensor (TensorValue value, DevicePtr device, AllocHint hint = AllocHint::normal ) :
294316 impl (new TensorImpl{
295317 std::visit (
296318 Overloaded{
@@ -309,7 +331,7 @@ class Tensor {
309331 device
310332 }) {
311333 auto size = init_stride ();
312- allocate (size, *device);
334+ allocate (size, *device, hint );
313335 std::visit (
314336 [&](auto & vec) { device->memcpy (impl->data , vec.data (), size); },
315337 std::get<1 >(value)
@@ -510,35 +532,41 @@ class Tensor {
510532 void add (const Tensor& source) { add (source, *impl->device ); }
511533
512534 template <typename D>
513- Tensor copy (const D& device) const {
535+ Tensor copy (const D& device, AllocHint hint = AllocHint::normal ) const {
514536 check_impl ();
515- Tensor tensor (impl->dtype , impl->shape , impl->device );
537+ Tensor tensor (impl->dtype , impl->shape , impl->device , hint );
516538 device.tensor_copy (*this , tensor);
517539 return tensor;
518540 }
519- Tensor copy () const { return copy (*impl->device ); }
541+ Tensor copy (AllocHint hint = AllocHint::normal) const {
542+ return copy (*impl->device , hint);
543+ }
520544
521545 bool is_contiguous () const { return impl->contiguous_dims == impl->shape .size (); }
522546
523547 std::size_t contiguous_dims () const { return impl->contiguous_dims ; }
524548
525549 template <typename D>
526- Tensor contiguous (const D& device) const {
550+ Tensor contiguous (const D& device, AllocHint hint = AllocHint::normal ) const {
527551 check_impl ();
528- return is_contiguous () ? *this : copy (device);
552+ return is_contiguous () ? *this : copy (device, hint );
529553 }
530554
531- Tensor contiguous () const { return contiguous (*impl->device ); }
555+ Tensor contiguous (AllocHint hint = AllocHint::normal) const {
556+ return contiguous (*impl->device , hint);
557+ }
532558
533559 template <typename D>
534- Tensor contiguous (std::size_t batch_size, const D& device) const {
560+ Tensor contiguous (
561+ std::size_t batch_size, const D& device, AllocHint hint = AllocHint::normal
562+ ) const {
535563 check_impl ();
536564 if (size (0 ) == batch_size) {
537- return contiguous (device);
565+ return contiguous (device, hint );
538566 } else if (size (0 ) == 1 ) {
539567 auto shape = impl->shape ;
540568 shape[0 ] = batch_size;
541- Tensor tensor (impl->dtype , shape, impl->device );
569+ Tensor tensor (impl->dtype , shape, impl->device , hint );
542570 device.tensor_copy (*this , tensor);
543571 return tensor;
544572 } else {
@@ -597,8 +625,8 @@ class Tensor {
597625 }
598626
599627 template <typename D>
600- void allocate (std::size_t size, const D& device) {
601- auto [data, parent] = device.allocate (size);
628+ void allocate (std::size_t size, const D& device, AllocHint hint ) {
629+ auto [data, parent] = device.allocate (size, hint );
602630 impl->data = data;
603631 if (parent) {
604632 impl->owns_data = false ;
0 commit comments