Skip to content

Commit c48884d

Browse files
fix: reset next_functions when backward completes
1 parent b5d6a6b commit c48884d

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

infini_train/src/autograd/function.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shar
5757
auto output_tensors = Forward(input_tensors);
5858
SetupContext(input_tensors, output_tensors);
5959

60+
next_functions_.clear();
6061
bool output_requires_grad = false;
6162
for (int idx = 0; idx < input_tensors.size(); ++idx) {
6263
const auto &input_tensor = input_tensors[idx];
@@ -107,5 +108,6 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int g
107108
void Function::ResetState() {
108109
grad_outputs_reached_ = 0;
109110
for (auto &tensor : grad_outputs_) { tensor.reset(); }
111+
for (auto &[next_function, idx] : next_functions_) { next_function.reset(); }
110112
}
111113
} // namespace infini_train::autograd

0 commit comments

Comments
 (0)