From d9ec2ed890f1c0fcc7b660c628bb06a49fb8f7c0 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 9 Apr 2026 07:39:33 +0000 Subject: [PATCH] remove gradient in state.positions after mace forward --- torch_sim/models/mace.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index e13e3ce8..a371ec2e 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -336,6 +336,8 @@ def forward( # noqa: C901 if stress is not None: results["stress"] = stress.detach() + # Detach positions to prevent gradients from flowing back + state.positions = state.positions.detach() return results