@@ -106,11 +106,11 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
 std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
                                       bool transpose, const std::shared_ptr<Tensor> &bias) {
     /*
-        !transpose: output = input * weight + bias
-        output[*, out_features] = input[*, in_features] * weight[in_features, out_features] + bias[out_features]
-
         transpose: output = input * weight^T + bias
         output[*, out_features] = input[*, in_features] * weight[out_features, in_features]^T + bias[out_features]
+
+        !transpose: output = input * weight + bias
+        output[*, out_features] = input[*, in_features] * weight[in_features, out_features] + bias[out_features]
     */
 
     const auto &input_dims = input->Dims();
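Context for the hunk below: the scalar triple loop is replaced with Eigen expressions. The following standalone sketch shows the same math with plain Eigen types in place of Tensor::EigenMatrix()/EigenVector(); the row-major layout and the [rows, cols] shapes are assumptions inferred from the shape comments above, not confirmed by this diff.

// Sketch of the transpose-path forward: output = input * weight^T + bias.
#include <Eigen/Dense>
#include <iostream>

using RowMajorMatrixXf = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

int main() {
    const int bs = 2, in_features = 3, out_features = 4;
    RowMajorMatrixXf input = RowMajorMatrixXf::Random(bs, in_features);
    // transpose == true: weight is stored as [out_features, in_features].
    RowMajorMatrixXf weight = RowMajorMatrixXf::Random(out_features, in_features);
    Eigen::RowVectorXf bias = Eigen::RowVectorXf::Random(out_features);

    // output[bs, out_features] = input * weight^T, then bias broadcast over rows.
    RowMajorMatrixXf output = input * weight.transpose();
    output.rowwise() += bias;

    std::cout << output << "\n";
    return 0;
}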
@@ -130,24 +130,32 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
     auto output_dims = input_dims;
     *output_dims.rbegin() = out_features;
     auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);
-    for (int64_t i = 0; i < bs; ++i) {
-        for (int64_t j = 0; j < out_features; ++j) {
-            auto *data_ptr = static_cast<float *>(output->DataPtr()) + i * out_features + j;
-            *data_ptr = 0.0f;
-            for (int64_t k = 0; k < in_features; ++k) {
-                *data_ptr += reinterpret_cast<const float *>(input->DataPtr())[i * in_features + k]
-                             * reinterpret_cast<const float *>(
-                                   weight->DataPtr())[transpose ? j * in_features + k : k * out_features + j];
-            }
-            *data_ptr += reinterpret_cast<const float *>(bias->DataPtr())[j];
-        }
+
+    if (transpose) {
+        output->EigenMatrix() = input->EigenMatrix() * weight->EigenMatrix().transpose();
+    } else {
+        output->EigenMatrix() = input->EigenMatrix() * weight->EigenMatrix();
     }
+    output->EigenMatrix().rowwise() += bias->EigenVector();
+
     return output;
 }
 
 std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
 LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
                int64_t out_features, const std::shared_ptr<Tensor> &grad_output) {
+    /*
+        transpose: grad_input = grad_output * weight
+        grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]
+        grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]
+        grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
+
+        !transpose: grad_input = grad_output * weight^T
+        grad_input[*, in_features] = grad_output[*, out_features] * weight[in_features, out_features]^T
+        grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
+        grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
+    */
+
     const auto &input_dims = input->Dims();
     CHECK_GE(input_dims.size(), 2);
     const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies<int64_t>{});
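The bs computed in the line above folds every leading dimension of the input into a single batch dimension, so both kernels can treat input as a 2-D [bs, in_features] matrix. A minimal standalone sketch of that flattening arithmetic (the example shape is hypothetical):

// For input dims [2, 3, 5], all dims except the last multiply into bs = 6.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    const std::vector<int64_t> input_dims{2, 3, 5};  // e.g. [batch, seq_len, in_features]
    const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), int64_t{1},
                                       std::multiplies<int64_t>{});
    const int64_t in_features = *input_dims.rbegin();
    std::cout << "bs=" << bs << " in_features=" << in_features << "\n";  // bs=6 in_features=5
    return 0;
}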
@@ -160,28 +168,17 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
 
     auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
     auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
-    grad_weight->Fill<float>(0.0f);
     auto grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
-    grad_bias->Fill<float>(0.0f);
-
-    for (int64_t i = 0; i < bs; ++i) {
-        for (int64_t j = 0; j < in_features; ++j) {
-            const auto input_idx = i * in_features + j;
-            auto *data_ptr = static_cast<float *>(grad_input->DataPtr()) + input_idx;
-            *data_ptr = 0.0f;
-            for (int64_t k = 0; k < out_features; ++k) {
-                const auto weight_idx = transpose ? k * in_features + j : j * out_features + k;
-                const auto grad = reinterpret_cast<const float *>(grad_output->DataPtr())[i * out_features + k];
-                *data_ptr += grad * reinterpret_cast<const float *>(weight->DataPtr())[weight_idx];
-                static_cast<float *>(grad_weight->DataPtr())[weight_idx]
-                    += grad * reinterpret_cast<const float *>(input->DataPtr())[input_idx];
-            }
-        }
-        for (int64_t k = 0; k < out_features; ++k) {
-            static_cast<float *>(grad_bias->DataPtr())[k]
-                += reinterpret_cast<const float *>(grad_output->DataPtr())[i * out_features + k];
-        }
+
+    if (transpose) {
+        grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
+        grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
+    } else {
+        grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
+        grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
     }
+    grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
+
     return {grad_input, grad_weight, grad_bias};
 }
 } // namespace infini_train::kernels::cpu
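As a sanity check on the new backward, the Eigen formulas can be verified against finite differences. Below is a standalone sketch for the transpose path, using plain Eigen types in place of the Tensor plumbing and a sum-of-outputs loss so that grad_output is all ones; all names here are hypothetical and not part of the diff.

// Central-difference spot check of grad_weight = grad_output^T * input.
#include <Eigen/Dense>
#include <iostream>

using RowMajorMatrixXf = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

int main() {
    const int bs = 4, in_features = 3, out_features = 2;
    RowMajorMatrixXf input = RowMajorMatrixXf::Random(bs, in_features);
    RowMajorMatrixXf weight = RowMajorMatrixXf::Random(out_features, in_features);

    // Loss L = sum(input * W^T); the bias is omitted since it is constant in W.
    auto loss = [&](const RowMajorMatrixXf &w) { return (input * w.transpose()).sum(); };

    // Analytic gradient with dY = ones[bs, out_features]: dW = dY^T * X.
    RowMajorMatrixXf grad_output = RowMajorMatrixXf::Ones(bs, out_features);
    RowMajorMatrixXf grad_weight = grad_output.transpose() * input;

    // Numeric gradient for a single entry of W.
    const float eps = 1e-3f;
    RowMajorMatrixXf wp = weight, wm = weight;
    wp(0, 0) += eps;
    wm(0, 0) -= eps;
    const float numeric = (loss(wp) - loss(wm)) / (2.0f * eps);

    std::cout << "analytic=" << grad_weight(0, 0) << " numeric=" << numeric << "\n";
    return 0;
}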