 #include <cstdint>
-#include <fcntl.h>
 #include <memory>
 #include <numeric>
 #include <tuple>
 
 #include "glog/logging.h"
 
+#include "infini_train/include/autograd/linear.h"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
 
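The newly included autograd/linear.h header supplies the two aggregates used by the reworked LinearBackward below. As a minimal sketch only, here is what it plausibly declares, with every field name inferred from the member accesses further down in this diff; the authoritative definitions live in infini_train/include/autograd/linear.h:

    // Sketch only: fields inferred from usage in this file, not the real header.
    #include <cstdint>
    #include <vector>

    namespace infini_train::autograd {

    struct LinearMeta {
        std::vector<int64_t> input_dims; // forward input shape, saved at forward time
        int64_t in_features = 0;
        int64_t out_features = 0;
        bool transpose = false;          // true => weight stored as [out_features, in_features]
        bool has_bias = false;
    };

    struct LinearGradFlags {
        bool input = true;  // compute grad_input
        bool weight = true; // compute grad_weight
        bool bias = true;   // compute grad_bias
    };

    } // namespace infini_train::autograd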
@@ -70,6 +70,7 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
     const int64_t k = input_dims[input_dims.size() - 1];
     CHECK_EQ(k, other_dims[other_dims.size() - 2]);
     const int64_t n = other_dims[other_dims.size() - 1];
+
     CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
     CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]);
 
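The CHECK_EQs in this hunk pin grad_output to the forward output's shape. A worked example with assumed dimensions:

    // Illustration with assumed shapes: forward computes C = A * B with
    // A = input [m, k] = [32, 64] and B = other [k, n] = [64, 128],
    // so C is [32, 128]. The backward kernel therefore requires
    // grad_output's trailing dims to be [m, n] = [32, 128], which is
    // exactly what CHECK_EQ(m, ...) and CHECK_EQ(n, ...) above assert.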
@@ -147,8 +148,9 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
 
 // TODO(dcj): support linear without bias later
 std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
-LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
-               int64_t out_features, const std::shared_ptr<Tensor> &grad_output, const bool bias) {
+LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
+               infini_train::autograd::LinearMeta meta, const std::shared_ptr<Tensor> &grad_output,
+               infini_train::autograd::LinearGradFlags grad_flags) {
     /*
        transpose: grad_input = grad_output * weight
                   grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]
@@ -160,32 +162,46 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
                   grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
        grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
     */
+    const auto &input_dims = meta.input_dims;
+    const auto in_features = meta.in_features;
+    const auto out_features = meta.out_features;
+    const auto transpose = meta.transpose;
+    const auto bias = meta.has_bias;
+    const auto compute_grad_input = grad_flags.input;
+    const auto compute_grad_weight = grad_flags.weight;
+    const auto compute_grad_bias = grad_flags.bias;
 
-    const auto &input_dims = input->Dims();
     CHECK_GE(input_dims.size(), 2);
-    const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies<int64_t>{});
-    const int64_t in_features = *input_dims.rbegin();
 
-    const auto &weight_dims = weight->Dims();
-    CHECK_EQ(weight_dims.size(), 2);
-    CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]);
-    CHECK_EQ(out_features, weight_dims[transpose ? 0 : 1]);
+    std::vector<int64_t> weight_dims
+        = transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
 
-    auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
-    auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
+    std::shared_ptr<Tensor> grad_input = nullptr;
+    std::shared_ptr<Tensor> grad_weight = nullptr;
     std::shared_ptr<Tensor> grad_bias = nullptr;
-    if (bias) {
-        grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
+
+    if (compute_grad_input) {
+        CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)";
+        grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
+        if (transpose) {
+            grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
+        } else {
+            grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
+        }
     }
 
-    if (transpose) {
-        grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
-        grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
-    } else {
-        grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
-        grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
+    if (compute_grad_weight) {
+        CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)";
+        grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
+        if (transpose) {
+            grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
+        } else {
+            grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
+        }
     }
 
-    if (bias) {
+
+    if (compute_grad_bias && bias) {
+        grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
         grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
     }
 
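A hypothetical call site for the new signature, with all concrete values illustrative: when the weight is frozen, only grad_input is requested, so the forward input need not be saved and may legitimately be nullptr, matching the CHECK messages in the hunk above:

    // Assumed usage sketch, not code from this commit.
    infini_train::autograd::LinearMeta meta;
    meta.input_dims = {32, 64}; // saved forward input shape
    meta.in_features = 64;
    meta.out_features = 128;
    meta.transpose = false;     // weight stored as [in_features, out_features]
    meta.has_bias = true;

    infini_train::autograd::LinearGradFlags flags;
    flags.input = true;   // dL/dx is needed upstream
    flags.weight = false; // frozen weight: skip its grad, input may be dropped
    flags.bias = false;

    auto [grad_input, grad_weight, grad_bias]
        = LinearBackward(/*input=*/nullptr, weight, meta, grad_output, flags);
    // grad_weight and grad_bias come back as nullptr here; only grad_input
    // is allocated and filled.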