@@ -50,13 +50,13 @@ size_t TensorBuffer::Size() const { return size_; }
5050// Tensor implementation
5151Tensor::Tensor (const std::vector<int64_t > &dims, DataType dtype, Device device) : dims_(dims), dtype_(dtype) {
5252 num_elements_ = std::accumulate (dims.begin (), dims.end (), 1 , std::multiplies<int64_t >());
53- buffer_ = std::make_shared<TensorBuffer>(device, DTypeSize (dtype) * num_elements_);
53+ buffer_ = std::make_shared<TensorBuffer>(device, kDataTypeToSize . at (dtype) * num_elements_);
5454}
5555
5656Tensor::Tensor (const Tensor &tensor, size_t offset, const std::vector<int64_t > &dims)
5757 : buffer_(tensor.buffer_), offset_(tensor.offset_ + offset), dims_(dims),
5858 num_elements_ (std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int64_t>())), dtype_(tensor.dtype_) {
59- CHECK_LE (offset_ + DTypeSize (dtype_) * num_elements_, buffer_->Size ());
59+ CHECK_LE (offset_ + kDataTypeToSize . at (dtype_) * num_elements_, buffer_->Size ());
6060}
6161
6262Tensor::Tensor (const float *data, const std::vector<int64_t > &dims, DataType dtype, Device device)
@@ -65,7 +65,7 @@ Tensor::Tensor(const float *data, const std::vector<int64_t> &dims, DataType dty
6565 // TODO(dcj): support more datatype
6666 CHECK (dtype == DataType::kFLOAT32 );
6767
68- buffer_ = std::make_shared<TensorBuffer>(device, DTypeSize (dtype) * num_elements_);
68+ buffer_ = std::make_shared<TensorBuffer>(device, kDataTypeToSize . at (dtype) * num_elements_);
6969
7070 core::DeviceGuard guard (device);
7171 auto *impl = core::GetDeviceGuardImpl (device.type ());
@@ -96,7 +96,7 @@ void *Tensor::DataPtr() { return reinterpret_cast<uint8_t *>(buffer_->DataPtr())
9696
9797const void *Tensor::DataPtr () const { return reinterpret_cast <const uint8_t *>(buffer_->DataPtr ()) + offset_; }
9898
99- size_t Tensor::SizeInBytes () const { return DTypeSize (dtype_) * num_elements_; }
99+ size_t Tensor::SizeInBytes () const { return kDataTypeToSize . at (dtype_) * num_elements_; }
100100
101101const std::vector<int64_t > &Tensor::Dims () const { return dims_; }
102102
0 commit comments