
Commit a8e82a0

1 parent e0b0d5c commit a8e82a0

4 files changed

Lines changed: 10 additions & 8 deletions


latest/docs/aggregation/aligned_mtl/index.html

Lines changed: 2 additions & 1 deletion
@@ -311,7 +311,8 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
 </dl>
 <div class="admonition note">
 <p class="admonition-title">Note</p>
-<p>This implementation was adapted from the <a class="reference external" href="https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned">official implementation</a>.</p>
+<p>This implementation was adapted from the official implementation of SamsungLabs/MTL,
+which is not available anymore at the time of writing.</p>
 </div>
 </dd></dl>

latest/docs/aggregation/mgda/index.html

Lines changed: 5 additions & 4 deletions
@@ -295,10 +295,11 @@
 <h1>MGDA<a class="headerlink" href="#mgda" title="Link to this heading">¶</a></h1>
 <dl class="py class">
 <dt class="sig sig-object py" id="torchjd.aggregation.MGDA">
-<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">MGDA</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">epsilon</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_iters</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_mgda.py#L10-L28"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.MGDA" title="Link to this definition"></a></dt>
+<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">MGDA</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">epsilon</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_iters</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_mgda.py#L10-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.MGDA" title="Link to this definition"></a></dt>
 <dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> performing the gradient aggregation
-step of <a class="reference external" href="https://www.sciencedirect.com/science/article/pii/S1631073X12000738">Multiple-gradient descent algorithm (MGDA) for multiobjective optimization</a>. The implementation is
-based on Algorithm 2 of <a class="reference external" href="https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf">Multi-Task Learning as Multi-Objective Optimization</a>.</p>
+step of <a class="reference external" href="https://comptes-rendus.academie-sciences.fr/mathematique/articles/10.1016/j.crma.2012.03.014/">Multiple-gradient descent algorithm (MGDA) for multiobjective optimization</a>.
+The implementation is based on Algorithm 2 of <a class="reference external" href="https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf">Multi-Task Learning as Multi-Objective
+Optimization</a>.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Parameters<span class="colon">:</span></dt>
 <dd class="field-odd"><ul class="simple">
@@ -311,7 +312,7 @@ <h1>MGDA<a class="headerlink" href="#mgda" title="Link to this heading">¶</a></
 
 <dl class="py class">
 <dt class="sig sig-object py" id="torchjd.aggregation.MGDAWeighting">
-<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">MGDAWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">epsilon</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_iters</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_mgda.py#L31-L71"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.MGDAWeighting" title="Link to this definition"></a></dt>
+<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">MGDAWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">epsilon</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_iters</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_mgda.py#L32-L72"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.MGDAWeighting" title="Link to this definition"></a></dt>
 <dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
 <a class="reference internal" href="#torchjd.aggregation.MGDA" title="torchjd.aggregation.MGDA"><code class="xref py py-class docutils literal notranslate"><span class="pre">MGDA</span></code></a>.</p>
 <dl class="field-list simple">
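
For context on what the MGDA aggregator in the hunk above computes: at each step it finds the convex combination of the task gradients with minimal norm (Algorithm 2 of Sener & Koltun solves this with a Frank-Wolfe procedure whose two-gradient subproblem has a closed form). Below is a minimal, dependency-free sketch of that two-gradient special case, assuming gradients given as plain Python lists; the names `min_norm_weights` and `aggregate` are illustrative and are not part of the torchjd API.

```python
def min_norm_weights(g1, g2):
    """Weights (w1, w2) minimizing ||w1*g1 + w2*g2|| over the simplex
    (w1 + w2 = 1, both non-negative); closed form for two gradients."""
    diff = [a - b for a, b in zip(g1, g2)]  # g1 - g2
    denom = sum(d * d for d in diff)        # ||g1 - g2||^2
    if denom == 0.0:
        return 0.5, 0.5                     # identical gradients: any split works
    # Unconstrained minimizer w1 = ((g2 - g1) . g2) / ||g1 - g2||^2, clipped to [0, 1]
    w1 = -sum(d * b for d, b in zip(diff, g2)) / denom
    w1 = max(0.0, min(1.0, w1))
    return w1, 1.0 - w1


def aggregate(g1, g2):
    """Min-norm convex combination of the two task gradients."""
    w1, w2 = min_norm_weights(g1, g2)
    return [w1 * a + w2 * b for a, b in zip(g1, g2)]
```

For the conflicting gradients `[1, 1]` and `[-1, 1]`, the weights come out as `(0.5, 0.5)` and the aggregated direction is `[0, 1]`: the opposed first components cancel while the shared second component survives, which is the behavior MGDA is designed for.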

latest/examples/lightning_integration/index.html

Lines changed: 2 additions & 2 deletions
@@ -293,12 +293,12 @@
 <article role="main" id="furo-main-content">
 <section id="pytorch-lightning-integration">
 <h1>PyTorch Lightning Integration<a class="headerlink" href="#pytorch-lightning-integration" title="Link to this heading">¶</a></h1>
-<p>To use Jacobian descent with TorchJD in a <a class="reference external" href="https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule" title="(in PyTorch Lightning v2.6.0)"><code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></a>, you need
+<p>To use Jacobian descent with TorchJD in a <a class="reference external" href="https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule" title="(in PyTorch Lightning v2.6.1)"><code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></a>, you need
 to turn off automatic optimization by setting <code class="docutils literal notranslate"><span class="pre">automatic_optimization</span></code> to <code class="docutils literal notranslate"><span class="pre">False</span></code> and to
 customize the <code class="docutils literal notranslate"><span class="pre">training_step</span></code> method to make it call the appropriate TorchJD method
 (<a class="reference internal" href="../../docs/autojac/backward/"><span class="doc">backward</span></a> or <a class="reference internal" href="../../docs/autojac/mtl_backward/"><span class="doc">mtl_backward</span></a>).</p>
 <p>The following code example demonstrates a basic multi-task learning setup using a
-<a class="reference external" href="https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule" title="(in PyTorch Lightning v2.6.0)"><code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></a> that will call <a class="reference internal" href="../../docs/autojac/mtl_backward/"><span class="doc">mtl_backward</span></a> at each training iteration.</p>
+<a class="reference external" href="https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule" title="(in PyTorch Lightning v2.6.1)"><code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></a> that will call <a class="reference internal" href="../../docs/autojac/mtl_backward/"><span class="doc">mtl_backward</span></a> at each training iteration.</p>
 <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
 <span class="kn">from</span><span class="w"> </span><span class="nn">lightning</span><span class="w"> </span><span class="kn">import</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">Trainer</span>
 <span class="kn">from</span><span class="w"> </span><span class="nn">lightning.pytorch.utilities.types</span><span class="w"> </span><span class="kn">import</span> <span class="n">OptimizerLRScheduler</span>
