diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index e13e3ce8..a371ec2e 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -336,6 +336,8 @@ def forward( # noqa: C901 if stress is not None: results["stress"] = stress.detach() + # Detach positions to prevent gradients from flowing back + state.positions = state.positions.detach() return results