diff --git a/infini_train/include/autograd/function.h b/infini_train/include/autograd/function.h
index 569de48c..c31242a0 100644
--- a/infini_train/include/autograd/function.h
+++ b/infini_train/include/autograd/function.h
@@ -24,13 +24,15 @@ class Function : public std::enable_shared_from_this<Function> {
     std::vector<std::shared_ptr<Tensor>> Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
     virtual void BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int idx);
-    void ResetState();
+    void IncreaseDependenciesNumber();
 
 protected:
     std::vector<std::shared_ptr<Tensor>> saved_tensors_;
 
 private:
     std::vector<std::pair<std::shared_ptr<Function>, int>> next_functions_;
+    int dependencies_number_ = 0;
+    int dependencies_reached_ = 0;
     int grad_outputs_reached_ = 0;
     std::vector<std::shared_ptr<Tensor>> grad_outputs_;
     const std::string type_;
diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc
index 99da9b95..51e56582 100644
--- a/infini_train/src/autograd/function.cc
+++ b/infini_train/src/autograd/function.cc
@@ -64,6 +64,9 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
             next_functions_.emplace_back(std::make_shared<AccumulateGrad>(input_tensor->grad()), 0);
         } else {
             next_functions_.emplace_back(input_tensor->grad_fn(), input_tensor->output_idx());
+            if (input_tensor->grad_fn()) {
+                input_tensor->grad_fn()->IncreaseDependenciesNumber();
+            }
         }
         output_requires_grad |= input_tensor->requires_grad();
     }
@@ -84,10 +87,16 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
 
 void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int grad_output_idx) {
     LOG(INFO) << "start backward_partial of function: " << type_;
-    CHECK(!grad_outputs_[grad_output_idx]);
-    grad_outputs_[grad_output_idx] = grad_output;
-    ++grad_outputs_reached_;
-    if (grad_outputs_reached_ == grad_outputs_.size()) {
+    if (!grad_outputs_[grad_output_idx]) {
+        grad_outputs_[grad_output_idx] = grad_output;
+        ++grad_outputs_reached_;
+    } else {
+        auto accumulate_function = std::make_shared<AccumulateGrad>(grad_outputs_[grad_output_idx]);
+        accumulate_function->BackwardPartial(grad_output, 0);
+    }
+    ++dependencies_reached_;
+    if (grad_outputs_reached_ == grad_outputs_.size()
+        && (dependencies_reached_ == dependencies_number_ || dependencies_number_ == 0)) {
         auto grad_inputs = Backward(grad_outputs_);
         CHECK_EQ(grad_inputs.size(), next_functions_.size());
         for (int idx = 0; idx < grad_inputs.size(); ++idx) {
@@ -97,15 +106,7 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int g
                 next_function->BackwardPartial(grad_input, output_idx);
             }
         }
-        // When a tensor is consumed by multiple operations, there may be multiple backward paths. As a result, the same
-        // function might be reused during backpropagation. Therefore, it's necessary to clear the state specific to
-        // this backward path.
-        ResetState();
     }
 }
-
-void Function::ResetState() {
-    grad_outputs_reached_ = 0;
-    for (auto &tensor : grad_outputs_) { tensor.reset(); }
-}
+void Function::IncreaseDependenciesNumber() { ++dependencies_number_; }
 } // namespace infini_train::autograd
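
Note on the change above: instead of resetting per-path state after each backward invocation, the patch counts during Apply how many consumers each producing function has (dependencies_number_), and BackwardPartial now waits until every consumer has reported a gradient (dependencies_reached_) before running Backward, accumulating repeated gradients for the same output slot instead of rejecting them with CHECK. The sketch below is a minimal standalone illustration of that dependency-counting idea, not InfiniTrain code; the names Node, backward_partial, and on_ready are invented for the example, and the grad_outputs_reached_ bookkeeping is omitted for brevity.

    // Minimal sketch (assumed simplification, not the real Function class) of the
    // dependency-counting scheme: a node only runs its own backward step once every
    // consumer registered during the forward pass has delivered a gradient, and
    // gradients arriving on the same slot are summed rather than rejected.
    #include <functional>
    #include <iostream>
    #include <vector>

    struct Node {
        int dependencies_number = 0;   // how many consumers will call backward_partial()
        int dependencies_reached = 0;  // how many already did
        std::vector<double> grad_slots;                              // one accumulated gradient per forward output
        std::function<void(const std::vector<double> &)> on_ready;  // stand-in for "run Backward"

        void backward_partial(double grad, int slot) {
            grad_slots[slot] += grad;  // accumulate instead of failing on a second arrival
            ++dependencies_reached;
            // Fire only when every registered consumer has contributed, or when the node
            // has no registered consumers at all (e.g. the node where backward starts).
            if (dependencies_reached == dependencies_number || dependencies_number == 0) {
                on_ready(grad_slots);
            }
        }
    };

    int main() {
        // One output consumed by two downstream ops: backward must run exactly once,
        // after both 0.5 contributions have been summed into the single slot.
        Node n;
        n.dependencies_number = 2;
        n.grad_slots.assign(1, 0.0);
        n.on_ready = [](const std::vector<double> &g) { std::cout << "Backward with grad " << g[0] << "\n"; };
        n.backward_partial(0.5, 0);  // first consumer: nothing fires yet
        n.backward_partial(0.5, 0);  // second consumer: prints "Backward with grad 1"
        return 0;
    }

The dependencies_number_ == 0 escape hatch in the patched condition presumably covers functions with no registered consumers, such as the one producing the tensor from which backward is started, so they can still fire on their first (and only) incoming gradient.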