From 8730e26343b03e4330ad5ffa0fa39c3df8821170 Mon Sep 17 00:00:00 2001 From: Haihan Jiang Date: Sun, 14 Jun 2026 01:57:50 -0700 Subject: [PATCH] docs: explain gradient descent loop options --- docs/implicit_diff.rst | 8 ++++++++ jaxopt/_src/gradient_descent.py | 8 +++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/implicit_diff.rst b/docs/implicit_diff.rst index 9e1279b6..0c9efff2 100644 --- a/docs/implicit_diff.rst +++ b/docs/implicit_diff.rst @@ -25,6 +25,14 @@ All solvers in JAXopt support implicit differentiation **out-of-the-box**. Most solvers have an ``implicit_diff=True|False`` option. When set to ``False``, autodiff of unrolled iterates is used instead of implicit differentiation. +The ``jit`` and ``unroll`` options control how iterative solvers stage their +optimization loops. With the default ``jit=True``, the solver loop is compiled +with JAX. With ``unroll=True``, loop iterations are expanded so that autodiff can +differentiate through each iteration directly; this may increase compilation time +and memory use for large ``maxiter`` values. With ``unroll=False``, solvers keep +the loop as a JAX control-flow loop. The default ``unroll="auto"`` unrolls when +``implicit_diff=False`` or ``jit=False`` and otherwise keeps the loop rolled. + Using the ridge regression example from the :ref:`unconstrained optimization <unconstrained_optim>` section, we can write:: diff --git a/jaxopt/_src/gradient_descent.py b/jaxopt/_src/gradient_descent.py index eabfdc72..8d4e5fa5 100644 --- a/jaxopt/_src/gradient_descent.py +++ b/jaxopt/_src/gradient_descent.py @@ -59,7 +59,13 @@ class GradientDescent(ProximalGradient): implicit_diff_solve: the linear system solver to use. jit: whether to JIT-compile the optimization loop (default: True). - unroll: whether to unroll the optimization loop (default: "auto"). + See the :ref:`implicit differentiation guide <implicit_diff>` for how + this interacts with loop unrolling and differentiation. 
+ unroll: whether to unroll the optimization loop (default: "auto"). When + set to "auto", the loop is unrolled when ``implicit_diff=False`` or + ``jit=False``. Unrolling enables autodiff through the solver iterations, + but can increase compilation time and memory usage for large + ``maxiter`` values. """ def init_state(self,