@@ -168,7 +168,7 @@ enum class DeviceType { cpu, cuda, hip };
168168class Device {
169169public:
170170 virtual ~Device () = default ;
171- virtual void * allocate (std::size_t size) const = 0;
171+ virtual std::pair< void *, Tensor> allocate (std::size_t size) const = 0;
172172 virtual void free (void * ptr) const = 0;
173173 virtual void memcpy (void * to, void * from, std::size_t size) const = 0;
174174 virtual void tensor_copy (const Tensor& source, Tensor& target) const = 0;
@@ -204,14 +204,14 @@ class Tensor {
204204 Tensor (DataType dtype, const Sizes& shape, DevicePtr device) :
205205 impl (new TensorImpl{dtype, shape, device}) {
206206 auto size = init_stride ();
207- impl-> data = device-> allocate (size);
207+ allocate (size, *device );
208208 }
209209
210210 template <typename D>
211211 Tensor (DataType dtype, const Sizes& shape, const D& device) :
212212 impl (new TensorImpl{dtype, shape, device.device_ptr ()}) {
213213 auto size = init_stride ();
214- impl-> data = device. allocate (size);
214+ allocate (size, device );
215215 }
216216
217217 Tensor (
@@ -283,7 +283,7 @@ class Tensor {
283283 device
284284 }) {
285285 auto size = init_stride ();
286- impl-> data = device-> allocate (size);
286+ allocate (size, *device );
287287 device->memcpy (impl->data , &value, sizeof (value));
288288 if (std::is_same_v<T, me_int_t > && value >= 0 ) {
289289 impl->batch_sizes .push_back (value);
@@ -309,7 +309,7 @@ class Tensor {
309309 device
310310 }) {
311311 auto size = init_stride ();
312- impl-> data = device-> allocate (size);
312+ allocate (size, *device );
313313 std::visit (
314314 [&](auto & vec) { device->memcpy (impl->data , vec.data (), size); },
315315 std::get<1 >(value)
@@ -596,6 +596,16 @@ class Tensor {
596596 }
597597 }
598598
599+ template <typename D>
600+ void allocate (std::size_t size, const D& device) {
601+ auto [data, parent] = device.allocate (size);
602+ impl->data = data;
603+ if (parent) {
604+ impl->owns_data = false ;
605+ impl->data_owner = parent.impl ;
606+ }
607+ }
608+
599609 TensorImpl* impl;
600610};
601611
0 commit comments