Skip to content

Commit b33cc8c

Browse files
chen2021673 and claude
committed
fix(lora): skip gradient computation for frozen parameters to reduce memory
Add needs_input_grad_ tracking in autograd Function to skip unnecessary gradient allocation and computation for frozen (requires_grad=false) parameters. For LoRA fine-tuning, this avoids allocating grad_weight tensors for all frozen base model weights, reducing peak GPU memory from ~10.7GB to ~7.7GB. Also consolidate LinearBackward loose params into LinearMeta and LinearGradFlags structs for clarity. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 898d97d commit b33cc8c

6 files changed

Lines changed: 237 additions & 122 deletions

File tree

infini_train/include/autograd/function.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class Function : public std::enable_shared_from_this<Function> {
4747

4848
protected:
4949
std::vector<std::shared_ptr<Tensor>> saved_tensors_;
50+
std::vector<bool> needs_input_grad_;
5051

5152
private:
5253
std::vector<std::pair<std::shared_ptr<Function>, int>> next_functions_;

infini_train/include/autograd/linear.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <cstdint>
34
#include <memory>
45
#include <vector>
56

@@ -10,6 +11,21 @@ class Tensor;
1011
}
1112

1213
namespace infini_train::autograd {
14+
15+
struct LinearMeta {
16+
bool transpose = false;
17+
bool has_bias = false;
18+
int64_t in_features = 0;
19+
int64_t out_features = 0;
20+
std::vector<int64_t> input_dims;
21+
};
22+
23+
struct LinearGradFlags {
24+
bool input = false;
25+
bool weight = false;
26+
bool bias = false;
27+
};
28+
1329
class Linear : public Function {
1430
public:
1531
static constexpr char kType[] = "LinearFunction";
@@ -22,7 +38,6 @@ class Linear : public Function {
2238
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
2339

2440
private:
25-
int64_t out_features_ = 0;
26-
bool bias_ = true;
41+
LinearMeta meta_;
2742
};
2843
} // namespace infini_train::autograd

infini_train/src/autograd/function.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shar
3636
}
3737
}
3838

39+
// Populate needs_input_grad_ before Forward/SetupContext so that
40+
// SetupContext can use it for saved-tensor pruning.
41+
// Must be done before NoGradGuard since it checks GradMode.
42+
if (autograd::GradMode::IsEnabled()) {
43+
needs_input_grad_.resize(input_tensors.size());
44+
for (size_t idx = 0; idx < input_tensors.size(); ++idx) {
45+
needs_input_grad_[idx] = input_tensors[idx]->requires_grad();
46+
}
47+
}
48+
3949
std::vector<std::shared_ptr<Tensor>> output_tensors;
4050
{
4151
autograd::NoGradGuard no_grad;
@@ -129,6 +139,7 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int g
129139

130140
saved_tensors_.clear();
131141
grad_outputs_.clear();
142+
needs_input_grad_.clear();
132143
grad_outputs_reached_ = 0;
133144
dependencies_reached_ = 0;
134145

infini_train/src/autograd/linear.cc

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,18 @@ void Linear::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
2020
const std::vector<std::shared_ptr<Tensor>> &) {
2121
const auto &input = input_tensors[0];
2222
const auto &weight = input_tensors[1];
23-
saved_tensors_ = {input, weight};
24-
bias_ = input_tensors.size() == 3;
25-
out_features_ = weight->Dims()[0];
23+
24+
bool need_input = needs_input_grad_.size() > 0 && needs_input_grad_[0];
25+
bool need_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
26+
27+
// grad_input needs weight, grad_weight needs input
28+
saved_tensors_ = {need_weight ? input : nullptr, need_input ? weight : nullptr};
29+
30+
meta_ = {.transpose = true,
31+
.has_bias = input_tensors.size() == 3,
32+
.in_features = weight->Dims()[1],
33+
.out_features = weight->Dims()[0],
34+
.input_dims = input->Dims()};
2635
}
2736

2837
std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
@@ -32,13 +41,20 @@ std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::sha
3241
CHECK_EQ(grad_outputs.size(), 1);
3342
const auto &grad_output = grad_outputs[0];
3443

35-
auto device = input->GetDevice().type();
44+
CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward";
45+
LinearGradFlags grad_flags = {.input = needs_input_grad_[0],
46+
.weight = needs_input_grad_.size() > 1 && needs_input_grad_[1],
47+
.bias = meta_.has_bias && needs_input_grad_.size() > 2 && needs_input_grad_[2]};
48+
49+
auto device = grad_output->GetDevice().type();
3650
auto [grad_input, grad_weight, grad_bias]
3751
= Dispatcher::Instance()
3852
.Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
39-
{device, "LinearBackward"}, input, weight, true, out_features_, grad_output, bias_);
40-
return bias_ ? std::vector<std::shared_ptr<Tensor>>{grad_input, grad_weight, grad_bias}
41-
: std::vector<std::shared_ptr<Tensor>>{grad_input, grad_weight};
42-
;
53+
{device, "LinearBackward"}, input, weight, meta_, grad_output, grad_flags);
54+
if (meta_.has_bias) {
55+
return {grad_input, grad_weight, grad_bias};
56+
} else {
57+
return {grad_input, grad_weight};
58+
}
4359
}
4460
} // namespace infini_train::autograd

infini_train/src/kernels/cpu/linear.cc

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#include <cstdint>
2-
#include <fcntl.h>
32
#include <memory>
43
#include <numeric>
54
#include <tuple>
65

76
#include "glog/logging.h"
87

8+
#include "infini_train/include/autograd/linear.h"
99
#include "infini_train/include/dispatcher.h"
1010
#include "infini_train/include/tensor.h"
1111

@@ -70,6 +70,7 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
7070
const int64_t k = input_dims[input_dims.size() - 1];
7171
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
7272
const int64_t n = other_dims[other_dims.size() - 1];
73+
7374
CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
7475
CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]);
7576

@@ -147,8 +148,9 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
147148

148149
// TODO(dcj): support linear without bias later
149150
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
150-
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
151-
int64_t out_features, const std::shared_ptr<Tensor> &grad_output, const bool bias) {
151+
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
152+
infini_train::autograd::LinearMeta meta, const std::shared_ptr<Tensor> &grad_output,
153+
infini_train::autograd::LinearGradFlags grad_flags) {
152154
/*
153155
transpose: grad_input = grad_output * weight
154156
grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]
@@ -160,32 +162,46 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
160162
grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
161163
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
162164
*/
165+
const auto &input_dims = meta.input_dims;
166+
const auto in_features = meta.in_features;
167+
const auto out_features = meta.out_features;
168+
const auto transpose = meta.transpose;
169+
const auto bias = meta.has_bias;
170+
const auto compute_grad_input = grad_flags.input;
171+
const auto compute_grad_weight = grad_flags.weight;
172+
const auto compute_grad_bias = grad_flags.bias;
163173

164-
const auto &input_dims = input->Dims();
165174
CHECK_GE(input_dims.size(), 2);
166-
const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies<int64_t>{});
167-
const int64_t in_features = *input_dims.rbegin();
168175

169-
const auto &weight_dims = weight->Dims();
170-
CHECK_EQ(weight_dims.size(), 2);
171-
CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]);
172-
CHECK_EQ(out_features, weight_dims[transpose ? 0 : 1]);
176+
std::vector<int64_t> weight_dims
177+
= transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
173178

174-
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
175-
auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
179+
std::shared_ptr<Tensor> grad_input = nullptr;
180+
std::shared_ptr<Tensor> grad_weight = nullptr;
176181
std::shared_ptr<Tensor> grad_bias = nullptr;
177-
if (bias) {
178-
grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
182+
183+
if (compute_grad_input) {
184+
CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)";
185+
grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
186+
if (transpose) {
187+
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
188+
} else {
189+
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
190+
}
179191
}
180192

181-
if (transpose) {
182-
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
183-
grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
184-
} else {
185-
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
186-
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
193+
if (compute_grad_weight) {
194+
CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)";
195+
grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
196+
if (transpose) {
197+
grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
198+
} else {
199+
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
200+
}
187201
}
188-
if (bias) {
202+
203+
if (compute_grad_bias && bias) {
204+
grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
189205
grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
190206
}
191207

0 commit comments

Comments
 (0)