
Commit b6aa83c

feat: implement dependency analysis, reduce redundant computation in backward propagation
1 parent 6100feb · commit b6aa83c

2 files changed

Lines changed: 17 additions & 14 deletions
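
The core idea of this commit: during the forward pass, every operation that consumes a tensor produced by another Function registers itself with that producer, so each Function knows how many gradient contributions to expect. During backpropagation, a Function then runs its Backward() exactly once, after all expected contributions have arrived, instead of firing and re-propagating once per incoming path. Below is a minimal, self-contained sketch of this reference-counting scheme; the Node type, its members, and the scalar gradients are illustrative assumptions, not infini_train's actual API.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Minimal sketch of reference-counted backward scheduling (illustration only).
struct Node {
    std::string name;
    float grad = 0.0f;                          // accumulated incoming gradient
    int dependencies_number = 0;                // consumers registered during forward
    int dependencies_reached = 0;               // consumers that have reported back
    std::vector<std::shared_ptr<Node>> inputs;  // producers to propagate to

    void BackwardPartial(float g) {
        grad += g;                              // accumulate instead of overwriting
        ++dependencies_reached;
        // Fire exactly once: only when every registered consumer has reported.
        // A root with no consumers (dependencies_number == 0) fires immediately.
        if (dependencies_reached == dependencies_number || dependencies_number == 0) {
            std::cout << name << " runs Backward with grad " << grad << "\n";
            for (auto &input : inputs) {
                input->BackwardPartial(grad);   // stand-in for the real chain rule
            }
        }
    }
};

int main() {
    auto x = std::make_shared<Node>();
    x->name = "x";
    auto a = std::make_shared<Node>();
    a->name = "a = f(x)";
    auto b = std::make_shared<Node>();
    b->name = "b = g(x)";
    auto loss = std::make_shared<Node>();
    loss->name = "loss = a + b";

    // Forward-pass bookkeeping: each consumer registers with its inputs' producers.
    a->inputs = {x};
    ++x->dependencies_number;
    b->inputs = {x};
    ++x->dependencies_number;
    loss->inputs = {a, b};
    ++a->dependencies_number;
    ++b->dependencies_number;

    // x runs its backward once, with the summed gradient from both paths,
    // rather than once per path.
    loss->BackwardPartial(1.0f);
    return 0;
}

With the pre-change behavior (state reset after every firing), x would instead fire once per incoming path and everything upstream of it would be recomputed each time, which is the redundant work the commit title refers to.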


infini_train/include/autograd/function.h

Lines changed: 3 additions & 1 deletion
@@ -24,13 +24,15 @@ class Function : public std::enable_shared_from_this<Function> {
     std::vector<std::shared_ptr<Tensor>> Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
     virtual void BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int idx);
 
-    void ResetState();
+    void IncreaseDependenciesNumber();
 
 protected:
     std::vector<std::shared_ptr<Tensor>> saved_tensors_;
 
 private:
     std::vector<std::pair<std::shared_ptr<Function>, int>> next_functions_;
+    int dependencies_number_ = 0;
+    int dependencies_reached_ = 0;
     int grad_outputs_reached_ = 0;
     std::vector<std::shared_ptr<Tensor>> grad_outputs_;
     const std::string type_;
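
The two new counters play a different role from the existing grad_outputs_reached_: grad_outputs_reached_ tracks how many of this Function's output slots have received a gradient, while dependencies_number_ / dependencies_reached_ track how many downstream consumers registered during the forward pass and how many of them have reported back. The sketch below is an editor's paraphrase of the readiness gate these counters feed in BackwardPartial (see the function.cc diff further down); the helper name and parameters are assumptions, not project API.

#include <cstddef>

// Assumed reading of the new readiness check; parameter names mirror the members.
bool ReadyToRunBackward(int grad_outputs_reached, std::size_t num_grad_outputs,
                        int dependencies_reached, int dependencies_number) {
    // Every output slot holds a gradient, and every registered consumer has
    // delivered its contribution (or none were registered, as for the loss root).
    return (grad_outputs_reached == static_cast<int>(num_grad_outputs))
           && (dependencies_reached == dependencies_number || dependencies_number == 0);
}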

infini_train/src/autograd/function.cc

Lines changed: 14 additions & 13 deletions
@@ -64,6 +64,9 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
             next_functions_.emplace_back(std::make_shared<AccumulateGrad>(input_tensor->grad()), 0);
         } else {
             next_functions_.emplace_back(input_tensor->grad_fn(), input_tensor->output_idx());
+            if (input_tensor->grad_fn()) {
+                input_tensor->grad_fn()->IncreaseDependenciesNumber();
+            }
         }
         output_requires_grad |= input_tensor->requires_grad();
     }
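
This hunk builds the dependency count during the forward pass: whenever an operation consumes an input produced by another Function, it bumps that producer's dependencies_number_ through IncreaseDependenciesNumber(). The grad_fn() null check matters because a tensor can reach the else branch without a producing Function (for example, an input that does not require grad), in which case there is nothing to register with.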
@@ -84,10 +87,16 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
 
 void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int grad_output_idx) {
     LOG(INFO) << "start backward_partial of function: " << type_;
-    CHECK(!grad_outputs_[grad_output_idx]);
-    grad_outputs_[grad_output_idx] = grad_output;
-    ++grad_outputs_reached_;
-    if (grad_outputs_reached_ == grad_outputs_.size()) {
+    if (!grad_outputs_[grad_output_idx]) {
+        grad_outputs_[grad_output_idx] = grad_output;
+        ++grad_outputs_reached_;
+    } else {
+        auto accumulate_function = std::make_shared<AccumulateGrad>(grad_outputs_[grad_output_idx]);
+        accumulate_function->BackwardPartial(grad_output, 0);
+    }
+    ++dependencies_reached_;
+    if (grad_outputs_reached_ == grad_outputs_.size()
+        && (dependencies_reached_ == dependencies_number_ || dependencies_number_ == 0)) {
         auto grad_inputs = Backward(grad_outputs_);
         CHECK_EQ(grad_inputs.size(), next_functions_.size());
         for (int idx = 0; idx < grad_inputs.size(); ++idx) {
@@ -97,15 +106,7 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int grad_output_idx) {
                 next_function->BackwardPartial(grad_input, output_idx);
             }
         }
-        // When a tensor is consumed by multiple operations, there may be multiple backward paths. As a result, the same
-        // function might be reused during backpropagation. Therefore, it's necessary to clear the state specific to
-        // this backward path.
-        ResetState();
     }
 }
-
-void Function::ResetState() {
-    grad_outputs_reached_ = 0;
-    for (auto &tensor : grad_outputs_) { tensor.reset(); }
-}
+void Function::IncreaseDependenciesNumber() { ++dependencies_number_; }
 } // namespace infini_train::autograd
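
Taken together, for a Function F whose single output feeds two downstream ops A and B: during the forward pass, A's and B's Apply each call F's IncreaseDependenciesNumber(), so dependencies_number_ ends at 2. During backward, the first gradient to arrive fills output slot 0 and bumps both counters, but the gate is not yet satisfied, so F waits. The second gradient finds slot 0 occupied and is folded into the stored tensor through a temporary AccumulateGrad; dependencies_reached_ now equals dependencies_number_, so Backward() runs exactly once on the summed gradient and propagates to F's own next_functions_. Under the old code, each arrival would have filled the freshly reset slot, fired Backward(), and re-propagated upstream, the repeated work the deleted ResetState comment hinted at and that this commit eliminates.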
