@@ -64,6 +64,9 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shar
             next_functions_.emplace_back(std::make_shared<AccumulateGrad>(input_tensor->grad()), 0);
         } else {
             next_functions_.emplace_back(input_tensor->grad_fn(), input_tensor->output_idx());
+            if (input_tensor->grad_fn()) {
+                input_tensor->grad_fn()->IncreaseDependenciesNumber();
+            }
         }
         output_requires_grad |= input_tensor->requires_grad();
     }
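The three added lines in this hunk make graph construction count, for each producing function, how many consumer edges will later send it a gradient. A minimal standalone sketch of that bookkeeping follows; Producer and ConsumeOutputOf are hypothetical names invented for illustration, not part of the repository, and only the single-output case is modeled.

#include <iostream>
#include <memory>

// Hypothetical stand-in for a grad_fn; only the dependency counter matters here.
struct Producer {
    int dependencies_number = 0;
    void IncreaseDependenciesNumber() { ++dependencies_number; }
};

// Plays the role of Apply() as seen from one consuming operation: each op that
// reads the producer's output registers exactly one backward edge.
void ConsumeOutputOf(const std::shared_ptr<Producer> &grad_fn) {
    if (grad_fn) { // mirrors the null check added in the hunk above
        grad_fn->IncreaseDependenciesNumber();
    }
}

int main() {
    auto x_fn = std::make_shared<Producer>();
    ConsumeOutputOf(x_fn);                          // e.g. y = x * a
    ConsumeOutputOf(x_fn);                          // e.g. z = x + b
    std::cout << x_fn->dependencies_number << "\n"; // prints 2: two gradients will arrive
    return 0;
}

During backward, that count is what the condition added in the next hunk compares against.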
@@ -84,10 +87,16 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shar
 
 void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int grad_output_idx) {
     LOG(INFO) << "start backward_partial of function: " << type_;
-    CHECK(!grad_outputs_[grad_output_idx]);
-    grad_outputs_[grad_output_idx] = grad_output;
-    ++grad_outputs_reached_;
-    if (grad_outputs_reached_ == grad_outputs_.size()) {
+    if (!grad_outputs_[grad_output_idx]) {
+        grad_outputs_[grad_output_idx] = grad_output;
+        ++grad_outputs_reached_;
+    } else {
+        auto accumulate_function = std::make_shared<AccumulateGrad>(grad_outputs_[grad_output_idx]);
+        accumulate_function->BackwardPartial(grad_output, 0);
+    }
+    ++dependencies_reached_;
+    if (grad_outputs_reached_ == grad_outputs_.size()
+        && (dependencies_reached_ == dependencies_number_ || dependencies_number_ == 0)) {
         auto grad_inputs = Backward(grad_outputs_);
         CHECK_EQ(grad_inputs.size(), next_functions_.size());
         for (int idx = 0; idx < grad_inputs.size(); ++idx) {
@@ -97,15 +106,7 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int g
                 next_function->BackwardPartial(grad_input, output_idx);
             }
         }
-        // When a tensor is consumed by multiple operations, there may be multiple backward paths. As a result, the same
-        // function might be reused during backpropagation. Therefore, it's necessary to clear the state specific to
-        // this backward path.
-        ResetState();
     }
 }
-
-void Function::ResetState() {
-    grad_outputs_reached_ = 0;
-    for (auto &tensor : grad_outputs_) { tensor.reset(); }
-}
+void Function::IncreaseDependenciesNumber() { ++dependencies_number_; }
 } // namespace infini_train::autograd
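Taken together, the second and third hunks replace the old per-path ResetState() with dependency counting: a gradient arriving at an already-filled output slot is now folded in through AccumulateGrad instead of tripping a CHECK, and Backward() only fires once every output slot is filled and every registered consumer has reported, with dependencies_number_ == 0 covering the root of the backward pass. The sketch below is an editor's illustration of that gating logic under simplifying assumptions (one output slot, scalar gradients, a made-up Node type); it is not the repository's implementation.

#include <iostream>

// Toy single-output autograd node.
struct Node {
    int dependencies_number = 0;  // consumers registered during the forward pass
    int dependencies_reached = 0; // gradients that have arrived so far
    double grad_accum = 0.0;      // accumulated gradient for the one output slot

    void BackwardPartial(double grad) {
        grad_accum += grad; // accumulate repeat arrivals instead of asserting
        ++dependencies_reached;
        // Run the real backward only after every consumer has contributed;
        // a zero dependency count corresponds to the loss/root node.
        if (dependencies_reached == dependencies_number || dependencies_number == 0) {
            std::cout << "run Backward with total grad = " << grad_accum << "\n";
        }
    }
};

int main() {
    Node x;                    // imagine x feeding two ops, e.g. y = x + x
    x.dependencies_number = 2; // registered once per consumer, as in the first hunk
    x.BackwardPartial(1.0);    // first path arrives: nothing fires yet
    x.BackwardPartial(1.0);    // second path arrives: prints "run Backward with total grad = 2"
    return 0;
}

Because each function now waits for all of its paths and runs Backward() at most once per backward pass, the per-path reset that ResetState() provided is presumably no longer needed, which is why the final hunk removes it.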