Commit 73a832f

fix: remove unnecessary changes

1 parent 4d970c4

10 files changed, 23 additions & 49 deletions

infini_train/include/datatype.h

Lines changed: 5 additions & 1 deletion

@@ -98,7 +98,11 @@ enum class DataType : int8_t {
     kFLOAT64,
 };
 
-size_t DTypeSize(DataType data_type);
+inline const std::unordered_map<DataType, size_t> kDataTypeToSize = {
+    {DataType::kUINT8, 1},    {DataType::kINT8, 1},    {DataType::kUINT16, 2},  {DataType::kINT16, 2},
+    {DataType::kUINT32, 4},   {DataType::kINT32, 4},   {DataType::kUINT64, 8},  {DataType::kINT64, 8},
+    {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8},
+};
 
 extern const std::unordered_map<DataType, std::string> kDataTypeToDesc;
 
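For context, a minimal standalone sketch of the header-only pattern this hunk adopts: an inline constant map (C++17 inline variables) replacing a function defined in a .cc file. The three-enumerator DataType below is an abridged stand-in, not the project's full enum; this snippet is not part of the commit.

// Standalone sketch, assuming C++17.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_map>

enum class DataType : int8_t { kINT32, kFLOAT32, kFLOAT64 }; // abridged stand-in

inline const std::unordered_map<DataType, size_t> kDataTypeToSize = {
    {DataType::kINT32, 4}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8},
};

int main() {
    // .at() throws std::out_of_range for an unmapped dtype, whereas the
    // DTypeSize() switch removed below fell through to `return 0; // unreachable`.
    std::cout << kDataTypeToSize.at(DataType::kFLOAT64) << "\n"; // prints 8
}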

infini_train/include/dispatcher.h

Lines changed: 0 additions & 2 deletions

@@ -8,8 +8,6 @@
 
 #include "infini_train/include/autocast.h"
 #include "infini_train/include/device.h"
-// FIXEM(dcj): should not include this
-#include "infini_train/include/dtype_dispatch.h"
 #ifdef PROFILE_MODE
 #include "infini_train/include/profiler.h"
 #endif

infini_train/src/core/runtime/cuda/cuda_dispatch.h

Lines changed: 3 additions & 3 deletions

@@ -1,11 +1,11 @@
 #pragma once
 
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-
 #include <utility>
 #include <vector>
 
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
 #include "infini_train/include/core/backend_type_map.h"
 #include "infini_train/include/dtype_dispatch.h"
 
infini_train/src/datatype.cc

Lines changed: 1 addition & 31 deletions

@@ -166,36 +166,6 @@ BF16 &BF16::operator++() {
 // -----------------------------------------------------------------------------
 // DataType metadata
 // -----------------------------------------------------------------------------
-size_t DTypeSize(DataType data_type) {
-    switch (data_type) {
-    case DataType::kUINT8:
-        return 1;
-    case DataType::kINT8:
-        return 1;
-    case DataType::kUINT16:
-        return 2;
-    case DataType::kINT16:
-        return 2;
-    case DataType::kUINT32:
-        return 4;
-    case DataType::kINT32:
-        return 4;
-    case DataType::kUINT64:
-        return 8;
-    case DataType::kINT64:
-        return 8;
-    case DataType::kBFLOAT16:
-        return 2;
-    case DataType::kFLOAT16:
-        return 2;
-    case DataType::kFLOAT32:
-        return 4;
-    case DataType::kFLOAT64:
-        return 8;
-    }
-    return 0; // unreachable
-}
-
 const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
     {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"},
     {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"},
@@ -234,7 +204,7 @@ DataType PromoteDataTypes(DataType a, DataType b) {
     }
 
     // Rule 3: same category — wider wins
-    return DTypeSize(a) >= DTypeSize(b) ? a : b;
+    return kDataTypeToSize.at(a) >= kDataTypeToSize.at(b) ? a : b;
 }
 
 } // namespace infini_train
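A worked instance of Rule 3 ("wider wins"), assuming both inputs are already in the same category (the earlier rules of PromoteDataTypes, not shown, handle cross-category promotion). Wider() is a hypothetical name for illustration only:

// Standalone sketch, not part of this commit.
#include <cstddef>
#include <cstdint>
#include <unordered_map>

enum class DataType : int8_t { kFLOAT16, kFLOAT32, kFLOAT64 }; // abridged stand-in

const std::unordered_map<DataType, size_t> kDataTypeToSize = {
    {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8},
};

DataType Wider(DataType a, DataType b) {
    return kDataTypeToSize.at(a) >= kDataTypeToSize.at(b) ? a : b;
}

int main() {
    // fp32 (4 bytes) vs fp64 (8 bytes): the 8-byte type wins.
    return Wider(DataType::kFLOAT32, DataType::kFLOAT64) == DataType::kFLOAT64 ? 0 : 1;
}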

infini_train/src/kernels/cuda/concat.cu

Lines changed: 1 addition & 0 deletions

@@ -90,6 +90,7 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
     const int64_t num_inputs = static_cast<int64_t>(inputs.size());
     const int64_t K_total = out_dims[dim];
 
+    // offsets records the sum of Ks
     // offsets[i] = sum_{j < i} K_j
     std::vector<int64_t> host_offsets(num_inputs + 1, 0);
     for (int64_t i = 0; i < num_inputs; ++i) { host_offsets[i + 1] = host_offsets[i] + Ks[i]; }
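The comment this hunk adds describes an exclusive prefix sum over the per-input extents along the concat dimension. A host-side sketch with made-up extents (not kernel code from this commit):

// Standalone sketch: exclusive prefix sum over hypothetical Ks.
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    std::vector<int64_t> Ks = {3, 5, 2}; // hypothetical per-input extents along `dim`
    std::vector<int64_t> offsets(Ks.size() + 1, 0);
    for (size_t i = 0; i < Ks.size(); ++i) { offsets[i + 1] = offsets[i] + Ks[i]; }
    // offsets == {0, 3, 8, 10}: input i occupies [offsets[i], offsets[i+1]) of the output dim.
    return offsets.back() == 10 ? 0 : 1;
}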

infini_train/src/kernels/cuda/elementwise.cu

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
 #include "infini_train/include/core/runtime/device_guard.h"
 #include "infini_train/include/dispatcher.h"
+#include "infini_train/include/dtype_dispatch.h"
 #include "infini_train/include/tensor.h"
 
 #include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h"

infini_train/src/nn/parallel/ddp/distributed_optimizer.cc

Lines changed: 2 additions & 2 deletions

@@ -64,8 +64,8 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
     const size_t piece_numel = local_end - local_start;
     CHECK_GT(piece_numel, 0);
 
-    const size_t param_piece_offset_bytes = local_start * DTypeSize(bucket_param->Dtype());
-    const size_t grad_piece_offset_bytes = local_start * DTypeSize(bucket_grad->Dtype());
+    const size_t param_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_param->Dtype());
+    const size_t grad_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_grad->Dtype());
 
     auto param_piece = std::make_shared<Tensor>(*bucket_param, param_piece_offset_bytes,
                                                 std::vector<int64_t>{static_cast<int64_t>(piece_numel)});
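The same element-to-byte offset conversion recurs in the remaining files of this commit. A tiny sketch of the arithmetic, with made-up numbers:

// Standalone sketch: a shard starting at element `local_start` of an fp32
// buffer begins at byte local_start * element size.
#include <cassert>
#include <cstddef>

int main() {
    const size_t local_start = 1024; // shard start, in elements (hypothetical)
    const size_t elem_size = 4;      // e.g. kDataTypeToSize.at(DataType::kFLOAT32)
    assert(local_start * elem_size == 4096); // byte offset handed to the Tensor view
    return 0;
}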

infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc

Lines changed: 2 additions & 2 deletions

@@ -36,7 +36,7 @@ std::shared_ptr<Tensor> AllocateFlatBuffer(size_t num_elements, DataType data_ty
 
 std::shared_ptr<Tensor> GetBufferView(const std::shared_ptr<Tensor> buffer, size_t start_in_elements,
                                       const std::vector<int64_t> &dims) {
-    return std::make_shared<Tensor>(*buffer, start_in_elements * DTypeSize(buffer->Dtype()), dims);
+    return std::make_shared<Tensor>(*buffer, start_in_elements * kDataTypeToSize.at(buffer->Dtype()), dims);
 };
 
 std::vector<std::shared_ptr<Tensor>> ShardBuffer(const std::shared_ptr<Tensor> buffer, size_t ddp_world_size) {
@@ -451,7 +451,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype)
         // Remap param/grad pointers
         if (param_buffer_) {
             // FIXME(zbl): change tensor buffer
-            param->SetData(*param_buffer_, param_start_index * DTypeSize(param_buffer_->Dtype()), true);
+            param->SetData(*param_buffer_, param_start_index * kDataTypeToSize.at(param_buffer_->Dtype()), true);
         }
 
         auto grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims());

infini_train/src/nn/parallel/ddp/reducer.cc

Lines changed: 4 additions & 4 deletions

@@ -18,7 +18,7 @@ namespace {
 void CopyGradToBucket(const std::shared_ptr<Tensor> &grad, const std::shared_ptr<Tensor> &flat,
                       size_t dst_elem_offset) {
     CHECK(grad && flat);
-    const size_t element_size_in_bytes = DTypeSize(grad->Dtype());
+    const size_t element_size_in_bytes = kDataTypeToSize.at(grad->Dtype());
     const size_t bytes = grad->NumElements() * element_size_in_bytes;
     char *dst = static_cast<char *>(flat->DataPtr()) + dst_elem_offset * element_size_in_bytes;
     const void *src = grad->DataPtr();
@@ -33,7 +33,7 @@ void CopyGradToBucket(const std::shared_ptr<Tensor> &grad, const std::shared_ptr
 void CopyBucketToGrad(const std::shared_ptr<Tensor> &flat, const std::shared_ptr<Tensor> &grad,
                       size_t src_elem_offset) {
     CHECK(grad && flat);
-    const size_t element_size_in_bytes = DTypeSize(grad->Dtype());
+    const size_t element_size_in_bytes = kDataTypeToSize.at(grad->Dtype());
     const size_t bytes = grad->NumElements() * element_size_in_bytes;
     const char *src = static_cast<const char *>(flat->DataPtr()) + src_elem_offset * element_size_in_bytes;
     void *dst = grad->DataPtr();
@@ -48,7 +48,7 @@ void CopyBucketToGrad(const std::shared_ptr<Tensor> &flat, const std::shared_ptr
 std::shared_ptr<Tensor> MakeGradView(const std::shared_ptr<Tensor> &contents, size_t offset_elems,
                                      const std::vector<int64_t> &dims) {
     // Return a view of contents (same chunk of memory)
-    auto view = std::make_shared<Tensor>(*contents, offset_elems * DTypeSize(contents->Dtype()), dims);
+    auto view = std::make_shared<Tensor>(*contents, offset_elems * kDataTypeToSize.at(contents->Dtype()), dims);
     return view;
 }
 } // namespace
@@ -118,7 +118,7 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
     }
     auto &state = it->second;
 
-    const size_t element_size_in_bytes = DTypeSize(tensor->Dtype());
+    const size_t element_size_in_bytes = kDataTypeToSize.at(tensor->Dtype());
     const size_t bytes = tensor->NumElements() * element_size_in_bytes;
     const size_t cap = bucket_size_limits[state.limit_idx];
 
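A hedged CPU-only sketch of the byte-level packing CopyGradToBucket performs, with std::memcpy standing in for whatever copy path the project actually uses on device:

// Standalone sketch: pack a grad's elements into a flat bucket at an
// element offset, converting elements to bytes via the element size.
#include <cstddef>
#include <cstring>
#include <vector>

int main() {
    std::vector<float> grad = {1.f, 2.f, 3.f};
    std::vector<float> flat(8, 0.f);                  // flattened bucket
    const size_t dst_elem_offset = 2;
    const size_t bytes = grad.size() * sizeof(float); // NumElements * element size
    char *dst = reinterpret_cast<char *>(flat.data()) + dst_elem_offset * sizeof(float);
    std::memcpy(dst, grad.data(), bytes);             // flat == {0,0,1,2,3,0,0,0}
    return flat[2] == 1.f ? 0 : 1;
}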

infini_train/src/tensor.cc

Lines changed: 4 additions & 4 deletions

@@ -50,13 +50,13 @@ size_t TensorBuffer::Size() const { return size_; }
 // Tensor implementation
 Tensor::Tensor(const std::vector<int64_t> &dims, DataType dtype, Device device) : dims_(dims), dtype_(dtype) {
     num_elements_ = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int64_t>());
-    buffer_ = std::make_shared<TensorBuffer>(device, DTypeSize(dtype) * num_elements_);
+    buffer_ = std::make_shared<TensorBuffer>(device, kDataTypeToSize.at(dtype) * num_elements_);
 }
 
 Tensor::Tensor(const Tensor &tensor, size_t offset, const std::vector<int64_t> &dims)
     : buffer_(tensor.buffer_), offset_(tensor.offset_ + offset), dims_(dims),
       num_elements_(std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int64_t>())), dtype_(tensor.dtype_) {
-    CHECK_LE(offset_ + DTypeSize(dtype_) * num_elements_, buffer_->Size());
+    CHECK_LE(offset_ + kDataTypeToSize.at(dtype_) * num_elements_, buffer_->Size());
 }
 
 Tensor::Tensor(const float *data, const std::vector<int64_t> &dims, DataType dtype, Device device)
@@ -65,7 +65,7 @@ Tensor::Tensor(const float *data, const std::vector<int64_t> &dims, DataType dty
     // TODO(dcj): support more datatype
     CHECK(dtype == DataType::kFLOAT32);
 
-    buffer_ = std::make_shared<TensorBuffer>(device, DTypeSize(dtype) * num_elements_);
+    buffer_ = std::make_shared<TensorBuffer>(device, kDataTypeToSize.at(dtype) * num_elements_);
 
     core::DeviceGuard guard(device);
     auto *impl = core::GetDeviceGuardImpl(device.type());
@@ -96,7 +96,7 @@ void *Tensor::DataPtr() { return reinterpret_cast<uint8_t *>(buffer_->DataPtr())
 
 const void *Tensor::DataPtr() const { return reinterpret_cast<const uint8_t *>(buffer_->DataPtr()) + offset_; }
 
-size_t Tensor::SizeInBytes() const { return DTypeSize(dtype_) * num_elements_; }
+size_t Tensor::SizeInBytes() const { return kDataTypeToSize.at(dtype_) * num_elements_; }
 
 const std::vector<int64_t> &Tensor::Dims() const { return dims_; }
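The view constructor's CHECK_LE guards against a view overrunning its parent buffer. A minimal sketch of that arithmetic, with made-up sizes:

// Standalone sketch: a view at byte offset `offset` over `n` elements of
// size `elem_size` must satisfy offset + elem_size * n <= buffer size.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
    const size_t buffer_bytes = 64;           // parent TensorBuffer::Size()
    const std::vector<int64_t> dims = {2, 4}; // view shape
    const size_t elem_size = 4;               // e.g. fp32
    const size_t offset = 16;                 // view's byte offset into the buffer
    const int64_t n = std::accumulate(dims.begin(), dims.end(), int64_t{1},
                                      std::multiplies<int64_t>());
    assert(offset + elem_size * static_cast<size_t>(n) <= buffer_bytes); // mirrors CHECK_LE
    return 0;
}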
