3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "third_party/gflags"]
path = third_party/gflags
url = git@github.com:gflags/gflags.git
[submodule "third_party/eigen"]
path = third_party/eigen
url = https://gitlab.com/libeigen/eigen.git
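
Note: after checking out this change, the new dependency still has to be fetched, e.g. with git submodule update --init --recursive (or git submodule update --init third_party/eigen to fetch Eigen alone).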
31 changes: 24 additions & 7 deletions CMakeLists.txt
@@ -21,13 +21,20 @@ set(WITH_GTEST OFF CACHE BOOL "Disable glog finding system gtest" FORCE)
add_subdirectory(third_party/glog)
include_directories(${glog_SOURCE_DIR}/src)

# Add eigen
find_package(OpenMP REQUIRED)
# find_package(OpenBLAS REQUIRED)
# include_directories(${OpenBLAS_INCLUDE_DIR})
add_subdirectory(third_party/eigen)
include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen)
# add_definitions(-DEIGEN_USE_BLAS)

if(USE_CUDA)
add_compile_definitions(USE_CUDA=1)
enable_language(CUDA)
include(FindCUDAToolkit)

# enable CUDA-related compilation options
# set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CMAKE_INCLUDE_PATH} -Xcompiler -fPIC --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")

include_directories(${PROJECT_SOURCE_DIR})
@@ -37,25 +44,35 @@ if(USE_CUDA)

add_library(infini_train STATIC ${SRC})
set_target_properties(infini_train PROPERTIES CUDA_ARCHITECTURES "70;80")
target_link_libraries(infini_train glog gflags CUDA::cudart CUDA::cublas)
target_link_libraries(infini_train glog gflags CUDA::cudart CUDA::cublas Eigen3::Eigen)

# Examples
add_executable(mnist example/mnist/main.cc example/mnist/dataset.cc example/mnist/net.cc)
target_link_libraries(mnist glog gflags infini_train)
target_link_libraries(mnist glog gflags infini_train Eigen3::Eigen)

add_executable(gpt2 example/gpt2/main.cc example/gpt2/dataset.cc example/gpt2/net.cc)
target_link_libraries(gpt2 glog gflags infini_train)
target_link_libraries(gpt2 glog gflags infini_train Eigen3::Eigen)
else()
include_directories(${PROJECT_SOURCE_DIR})
file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)

add_library(infini_train STATIC ${SRC})
target_link_libraries(infini_train glog gflags)
target_link_libraries(infini_train glog gflags Eigen3::Eigen)

# Examples
add_executable(mnist example/mnist/main.cc example/mnist/dataset.cc example/mnist/net.cc)
target_link_libraries(mnist glog gflags infini_train)
target_link_libraries(mnist glog gflags infini_train Eigen3::Eigen)

add_executable(gpt2 example/gpt2/main.cc example/gpt2/dataset.cc example/gpt2/net.cc)
target_link_libraries(gpt2 glog gflags infini_train)
target_link_libraries(gpt2 glog gflags infini_train Eigen3::Eigen)

# OpenBLAS
# target_link_libraries(infini_train ${OpenBLAS_LIBRARIES})
# target_link_libraries(mnist ${OpenBLAS_LIBRARIES})
# target_link_libraries(gpt2 ${OpenBLAS_LIBRARIES})

# OpenMP
target_link_libraries(infini_train OpenMP::OpenMP_CXX)
target_link_libraries(mnist OpenMP::OpenMP_CXX)
target_link_libraries(gpt2 OpenMP::OpenMP_CXX)
endif()
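
The CMake block above links every target against OpenMP::OpenMP_CXX; Eigen picks that up and multi-threads its large matrix products. A minimal standalone sketch of how that interaction can be checked — not part of this PR, file name and matrix sizes are made up:

// check_eigen_omp.cc -- hypothetical helper, not part of this PR.
// Eigen parallelizes GEMM via OpenMP when built with OpenMP support, which is
// what linking OpenMP::OpenMP_CXX arranges; setNbThreads caps the thread count.
#include <cstdio>

#include "Eigen/Dense"

int main() {
    Eigen::setNbThreads(4); // 0 would restore the OpenMP default
    std::printf("Eigen GEMM threads: %d\n", Eigen::nbThreads());

    const Eigen::MatrixXf a = Eigen::MatrixXf::Random(512, 512);
    const Eigen::MatrixXf b = Eigen::MatrixXf::Random(512, 512);
    const Eigen::MatrixXf c = a * b; // runs multi-threaded when OpenMP is enabled
    std::printf("c(0, 0) = %f\n", c(0, 0));
    return 0;
}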
4 changes: 4 additions & 0 deletions infini_train/include/tensor.h
@@ -7,6 +7,7 @@
#include <random>
#include <vector>

#include "Eigen/Dense"
#include "glog/logging.h"

#include "infini_train/include/device.h"
@@ -69,6 +70,9 @@ class Tensor : public std::enable_shared_from_this<Tensor> {

template <typename T> void Fill(T value);

Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> EigenMatrix();
Eigen::Map<Eigen::Matrix<float, 1, Eigen::Dynamic, Eigen::RowMajor>> EigenVector();

// TODO(dcj): return shared_ptr<Tensor> instead of Tensor later
Tensor To(Device device);

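
The two accessors declared above return Eigen::Map views of the tensor's existing storage rather than copies. A self-contained sketch of the Map idiom — buffer and shapes invented here for illustration:

// map_sketch.cc -- illustration only; std::vector stands in for Tensor's buffer.
#include <vector>

#include "Eigen/Dense"

using RowMajorMatrixXf = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

int main() {
    std::vector<float> buffer(6, 1.0f);                  // plays the role of Tensor::DataPtr()
    Eigen::Map<RowMajorMatrixXf> m(buffer.data(), 2, 3); // 2x3 view, no copy
    m *= 2.0f;                                           // writes through the map mutate the buffer
    return buffer[5] == 2.0f ? 0 : 1;                    // exits 0: the buffer itself changed
}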
65 changes: 31 additions & 34 deletions infini_train/src/kernels/cpu/linear.cc
@@ -106,11 +106,11 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
bool transpose, const std::shared_ptr<Tensor> &bias) {
/*
transpose: output = input * weight^T + bias
output[*, out_features] = input[*, in_features] * weight[out_features, in_features]^T + bias[out_features]

!transpose: output = input * weight + bias
output[*, out_features] = input[*, in_features] * weight[in_features, out_features] + bias[out_features]
*/

const auto &input_dims = input->Dims();
@@ -130,24 +130,32 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
auto output_dims = input_dims;
*output_dims.rbegin() = out_features;
auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);
    for (int64_t i = 0; i < bs; ++i) {
        for (int64_t j = 0; j < out_features; ++j) {
            auto *data_ptr = static_cast<float *>(output->DataPtr()) + i * out_features + j;
            *data_ptr = 0.0f;
            for (int64_t k = 0; k < in_features; ++k) {
                *data_ptr += reinterpret_cast<const float *>(input->DataPtr())[i * in_features + k]
                             * reinterpret_cast<const float *>(
                                 weight->DataPtr())[transpose ? j * in_features + k : k * out_features + j];
            }
            *data_ptr += reinterpret_cast<const float *>(bias->DataPtr())[j];
        }
    }

    if (transpose) {
        output->EigenMatrix() = input->EigenMatrix() * weight->EigenMatrix().transpose();
    } else {
        output->EigenMatrix() = input->EigenMatrix() * weight->EigenMatrix();
    }
    output->EigenMatrix().rowwise() += bias->EigenVector();

return output;
}

std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
int64_t out_features, const std::shared_ptr<Tensor> &grad_output) {
/*
transpose: grad_input = grad_output * weight
grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]
grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)

!transpose: grad_input = grad_output * weight^T
grad_input[*, in_features] = grad_output[*, out_features] * weight[in_features, out_features]^T
grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
*/

const auto &input_dims = input->Dims();
CHECK_GE(input_dims.size(), 2);
const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies<int64_t>{});
@@ -160,28 +168,17 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso

auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
grad_weight->Fill<float>(0.0f);
auto grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
grad_bias->Fill<float>(0.0f);

    for (int64_t i = 0; i < bs; ++i) {
        for (int64_t j = 0; j < in_features; ++j) {
            const auto input_idx = i * in_features + j;
            auto *data_ptr = static_cast<float *>(grad_input->DataPtr()) + input_idx;
            *data_ptr = 0.0f;
            for (int64_t k = 0; k < out_features; ++k) {
                const auto weight_idx = transpose ? k * in_features + j : j * out_features + k;
                const auto grad = reinterpret_cast<const float *>(grad_output->DataPtr())[i * out_features + k];
                *data_ptr += grad * reinterpret_cast<const float *>(weight->DataPtr())[weight_idx];
                static_cast<float *>(grad_weight->DataPtr())[weight_idx]
                    += grad * reinterpret_cast<const float *>(input->DataPtr())[input_idx];
            }
        }
        for (int64_t k = 0; k < out_features; ++k) {
            static_cast<float *>(grad_bias->DataPtr())[k]
                += reinterpret_cast<const float *>(grad_output->DataPtr())[i * out_features + k];
        }
    }

    if (transpose) {
        grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
        grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
    } else {
        grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
        grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
    }
    grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();

return {grad_input, grad_weight, grad_bias};
}
} // namespace infini_train::kernels::cpu
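
The Eigen rewrite above maps one-to-one onto the formulas in the comment blocks. A hedged standalone sanity check for the !transpose case, written directly against Eigen — names and sizes invented here, not taken from the PR:

// linear_backward_check.cc -- illustration only, mirrors the !transpose branch.
#include <cassert>

#include "Eigen/Dense"

int main() {
    const int bs = 4, in_features = 3, out_features = 2;
    const Eigen::MatrixXf input = Eigen::MatrixXf::Random(bs, in_features);
    const Eigen::MatrixXf weight = Eigen::MatrixXf::Random(in_features, out_features);
    const Eigen::MatrixXf grad_output = Eigen::MatrixXf::Random(bs, out_features);

    // grad_input[*, in_features] = grad_output[*, out_features] * weight^T
    const Eigen::MatrixXf grad_input = grad_output * weight.transpose();
    // grad_weight[in_features, out_features] = input^T * grad_output
    const Eigen::MatrixXf grad_weight = input.transpose() * grad_output;
    // grad_bias[out_features] = grad_output.sum(axis=0)
    const Eigen::RowVectorXf grad_bias = grad_output.colwise().sum();

    assert(grad_input.rows() == bs && grad_input.cols() == in_features);
    assert(grad_weight.rows() == in_features && grad_weight.cols() == out_features);
    assert(grad_bias.size() == out_features);
    return 0;
}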
14 changes: 14 additions & 0 deletions infini_train/src/tensor.cc
@@ -10,6 +10,8 @@
#ifdef USE_CUDA
#include "cuda_runtime_api.h"
#endif

#include "Eigen/Dense"
#include "glog/logging.h"

#include "infini_train/include/autograd/elementwise.h"
@@ -124,6 +126,18 @@ template <typename T> void Tensor::Fill(T value) {

template void Tensor::Fill<float>(float);

Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> Tensor::EigenMatrix() {
    const int64_t bs = std::accumulate(dims_.rbegin() + 1, dims_.rend(), 1, std::multiplies<int64_t>());
    return Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
        reinterpret_cast<float *>(DataPtr()), bs, *dims_.rbegin());
}

Eigen::Map<Eigen::Matrix<float, 1, Eigen::Dynamic, Eigen::RowMajor>> Tensor::EigenVector() {
    CHECK_EQ(dims_.size(), 1);
    return Eigen::Map<Eigen::Matrix<float, 1, Eigen::Dynamic, Eigen::RowMajor>>(reinterpret_cast<float *>(DataPtr()), 1,
                                                                                dims_[0]);
}

Tensor Tensor::To(Device device) {
if (device == buffer_->GetDevice()) {
auto new_tensor = Tensor(*this, offset_, dims_);
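
Taken together, the accessors make kernels read like the math. A hypothetical call site using only APIs visible in this diff — shapes chosen for illustration:

// Sketch only; assumes the Tensor(dims, dtype) constructor and Fill<T> shown
// elsewhere in this diff. Every entry of y ends up as 1.0 * 0.5 * 3 = 1.5f.
auto x = std::make_shared<Tensor>(std::vector<int64_t>{4, 3}, DataType::kFLOAT32);
auto w = std::make_shared<Tensor>(std::vector<int64_t>{3, 2}, DataType::kFLOAT32);
auto y = std::make_shared<Tensor>(std::vector<int64_t>{4, 2}, DataType::kFLOAT32);
x->Fill<float>(1.0f);
w->Fill<float>(0.5f);
y->EigenMatrix() = x->EigenMatrix() * w->EigenMatrix(); // [4,3] x [3,2] -> [4,2]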
1 change: 1 addition & 0 deletions third_party/eigen
Submodule eigen added at 68f4e5