Skip to content

Commit cb1c680

Browse files
authored
feat(aggregation): Add FairGrad (#688)
* Add `FairGrad` * Add `FairGradWeighting` * Add a `fairgrad` optional dependency group (`pip install "torchjd[fairgrad]"`) backed by `scipy` * Add changelog entry
1 parent ef59f7c commit cb1c680

11 files changed

Lines changed: 253 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.
88

99
## [Unreleased]
1010

11+
### Added
12+
13+
- Added `FairGrad` and `FairGradWeighting` from [Fair Resource Allocation in Multi-Task
14+
Learning](https://arxiv.org/pdf/2402.15638).
15+
1116
### Changed
1217

1318
- **BREAKING**: Removed `numpy`, `quadprog` and `qpsolvers` from the main dependencies of `torchjd`,

NOTICES

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,34 @@ SOFTWARE.
5959

6060
-------------------------------------------------------------------------------
6161

62+
Project: fairgrad
63+
Source: https://github.com/OptMN-Lab/fairgrad
64+
Used in: src/torchjd/aggregation/_fairgrad.py
65+
66+
MIT License
67+
68+
Copyright (c) 2024 OptMN-Lab
69+
70+
Permission is hereby granted, free of charge, to any person obtaining a copy
71+
of this software and associated documentation files (the "Software"), to deal
72+
in the Software without restriction, including without limitation the rights
73+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
74+
copies of the Software, and to permit persons to whom the Software is
75+
furnished to do so, subject to the following conditions:
76+
77+
The above copyright notice and this permission notice shall be included in all
78+
copies or substantial portions of the Software.
79+
80+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
81+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
82+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
83+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
84+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
85+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
86+
SOFTWARE.
87+
88+
-------------------------------------------------------------------------------
89+
6290
Project: ConFIG
6391
Source: https://github.com/tum-pbs/ConFIG/tree/main/conflictfree
6492
Used in: src/torchjd/aggregation/_config.py
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
:hide-toc:
2+
3+
FairGrad
4+
========
5+
6+
.. autoclass:: torchjd.aggregation.FairGrad
7+
:members: __call__
8+
9+
.. autoclass:: torchjd.aggregation.FairGradWeighting
10+
:members: __call__

docs/source/docs/aggregation/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Abstract base classes
3737
constant.rst
3838
cr_mogm.rst
3939
dualproj.rst
40+
fairgrad.rst
4041
flattening.rst
4142
graddrop.rst
4243
gradvac.rst

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,17 @@ cagrad = [
119119
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
120120
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
121121
]
122+
fairgrad = [
123+
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
124+
"scipy",
125+
]
122126
full = [
123127
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
124128
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
125129
"qpsolvers>=1.0.1", # Does not work before 1.0.1
126130
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
127131
"ecos>=2.0.14", # Does not work before 2.0.14
132+
"scipy",
128133
]
129134

130135
[tool.pytest.ini_options]

src/torchjd/aggregation/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from ._constant import Constant, ConstantWeighting
6868
from ._cr_mogm import CRMOGMWeighting
6969
from ._dualproj import DualProj, DualProjWeighting
70+
from ._fairgrad import FairGrad, FairGradWeighting
7071
from ._flattening import Flattening
7172
from ._graddrop import GradDrop
7273
from ._gradvac import GradVac, GradVacWeighting
@@ -95,6 +96,8 @@
9596
"CRMOGMWeighting",
9697
"DualProj",
9798
"DualProjWeighting",
99+
"FairGrad",
100+
"FairGradWeighting",
98101
"Flattening",
99102
"GeneralizedWeighting",
100103
"GradDrop",
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Partly adapted from https://github.com/OptMN-Lab/fairgrad/blob/main/methods/weight_methods.py#L811-L825 — MIT License, Copyright (c) 2024 OptMN-Lab.
2+
# See NOTICES for the full license text.
3+
4+
from __future__ import annotations
5+
6+
import contextlib
7+
8+
import torch
9+
from torch import Tensor
10+
11+
from torchjd._mixins import _WithOptionalDeps
12+
from torchjd.linalg import PSDMatrix
13+
14+
from ._aggregator_bases import GramianWeightedAggregator
15+
from ._mixins import _NonDifferentiable
16+
from ._weighting_bases import _GramianWeighting
17+
18+
with contextlib.suppress(ImportError):
19+
import numpy as np
20+
from scipy.optimize import least_squares
21+
22+
23+
# Non-differentiable: the scipy solver operates on numpy arrays, breaking the autograd graph.
24+
class FairGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
    r"""
    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the
    weights of :class:`~torchjd.aggregation.FairGrad`, as defined in Equation 4 of `Fair Resource
    Allocation in Multi-Task Learning <https://arxiv.org/pdf/2402.15638>`_.

    For a gramian :math:`G` of size :math:`m \times m`, the weights :math:`w` are obtained by
    solving :math:`G w = w^{-1 / \alpha}` (element-wise power) in the least-squares sense, subject
    to :math:`w \geq 0` and starting from the uniform weights :math:`1 / m`.

    :param alpha: The parameter controlling the type of fairness in the alpha-fairness
        formulation. Must be non-negative. When it is ``0``, the weights are uniform.
    :param max_iters: The maximum number of iterations of the optimization loop (forwarded as
        ``max_nfev`` to ``scipy.optimize.least_squares``). If set to None, the default value of
        ``scipy.optimize.least_squares`` (``100 * m``) will be used.

    :raises ValueError: If ``alpha`` is negative.

    .. note::
        This implementation was adapted from the `official implementation
        <https://github.com/OptMN-Lab/fairgrad/blob/main/methods/weight_methods.py#L811-L825>`_.

    .. note::
        This weighting requires optional dependencies. When they are not installed, instantiating
        it raises an :class:`ImportError` with installation instructions.
        To install them, use ``pip install "torchjd[fairgrad]"``.
    """

    _REQUIRED_DEPS = ["numpy", "scipy"]
    _INSTALL_HINT = 'Install it with: pip install "torchjd[fairgrad]"'

    def __init__(self, alpha: float, max_iters: int | None = None) -> None:
        super().__init__()
        self.alpha = alpha  # Goes through the validating setter below.
        self.max_iters = max_iters

    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        m = gramian.shape[0]
        uniform = np.ones(m) / m

        if self.alpha == 0:
            # When alpha=0, the alpha-fairness formulation reduces to linear scalarization with
            # uniform weights (see Section 3 of https://arxiv.org/pdf/2402.15638).
            weight_array = uniform
        else:
            # Move to CPU and cast to float64 before handing the gramian over to NumPy: the solver
            # iterates on float64 values anyway, and this avoids failing on dtypes that NumPy
            # cannot represent (e.g. bfloat16). The result is cast back to the dtype of the
            # gramian at the end.
            gramian_array = gramian.detach().cpu().to(torch.float64).numpy()
            exponent = -1.0 / self.alpha

            def objective(x: np.ndarray) -> np.ndarray:
                # Residual of the fixed-point equation G x = x^(-1/alpha).
                return np.dot(gramian_array, x) - np.power(x, exponent)

            # The weights are constrained to be non-negative, and the search starts from the
            # uniform weights.
            res = least_squares(objective, uniform, bounds=(0, np.inf), max_nfev=self.max_iters)
            weight_array = res.x

        return torch.tensor(weight_array).to(device=gramian.device, dtype=gramian.dtype)

    @property
    def alpha(self) -> float:
        return self._alpha

    @alpha.setter
    def alpha(self, value: float) -> None:
        # The exponent -1/alpha used in forward is only meaningful for alpha > 0, and alpha = 0 is
        # handled as a special case; negative values are outside the alpha-fairness formulation.
        if value < 0.0:
            raise ValueError(
                f"Parameter `alpha` should be a non-negative float. Found `alpha = {value}`."
            )
        self._alpha = value
79+
80+
81+
class FairGrad(_NonDifferentiable, GramianWeightedAggregator):
    r"""
    :class:`~torchjd.aggregation.GramianWeightedAggregator` using the step decision of Algorithm 1
    of `Fair Resource Allocation in Multi-Task Learning <https://arxiv.org/pdf/2402.15638>`_.

    The weights are computed by a :class:`~torchjd.aggregation.FairGradWeighting` built from the
    same parameters; ``alpha`` and ``max_iters`` are read from and written to that weighting.

    :param alpha: The parameter controlling the type of fairness in the alpha-fairness
        formulation.
    :param max_iters: The maximum number of iterations of the optimization loop. If set to None,
        the default value of ``scipy.optimize.least_squares`` (``100 * m``) will be used.

    .. note::
        This aggregator requires optional dependencies. When they are not installed, instantiating
        it raises an :class:`ImportError` with installation instructions.
        To install them, use ``pip install "torchjd[fairgrad]"``.
    """

    gramian_weighting: FairGradWeighting

    def __init__(self, alpha: float, max_iters: int | None = None) -> None:
        weighting = FairGradWeighting(alpha=alpha, max_iters=max_iters)
        super().__init__(weighting)

    # ``alpha`` and ``max_iters`` are not stored on the aggregator itself: both properties
    # delegate to the underlying weighting so that the two objects always agree.

    @property
    def alpha(self) -> float:
        return self.gramian_weighting.alpha

    @alpha.setter
    def alpha(self, value: float) -> None:
        self.gramian_weighting.alpha = value

    @property
    def max_iters(self) -> int | None:
        return self.gramian_weighting.max_iters

    @max_iters.setter
    def max_iters(self, value: int | None) -> None:
        self.gramian_weighting.max_iters = value

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(alpha={self.alpha}, max_iters={self.max_iters})"

    def __str__(self) -> str:
        return f"{self.alpha}-FairGrad"

tests/plots/interactive_plotter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
CAGrad,
2121
ConFIG,
2222
DualProj,
23+
FairGrad,
2324
GradDrop,
2425
GradVac,
2526
Mean,
@@ -63,6 +64,7 @@ def main() -> None:
6364
str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5),
6465
str(ConFIG()): lambda: ConFIG(),
6566
str(DualProj()): lambda: DualProj(projector=QuadprogProjector(reg_eps=1e-7)),
67+
str(FairGrad(alpha=1.0)): lambda: FairGrad(alpha=1.0),
6668
str(GradDrop()): lambda: GradDrop(),
6769
str(GradVac()): lambda: GradVac(),
6870
str(IMTLG()): lambda: IMTLG(),
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from utils.optional_deps import skip_if_deps_not_installed
2+
3+
from torchjd.aggregation import FairGrad, FairGradWeighting, Mean
4+
5+
skip_if_deps_not_installed(FairGradWeighting)
6+
7+
from pytest import mark
8+
from torch import Tensor
9+
from utils.tensors import ones_
10+
11+
from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable
12+
from ._inputs import scaled_matrices, typical_matrices
13+
14+
# A single solver iteration (max_iters=1) already yields a finite output, which is all the
# structure and differentiability tests need. Each pair gets its own FairGrad instance.
scaled_pairs = [(FairGrad(alpha=1.0, max_iters=1), mat) for mat in scaled_matrices]
typical_pairs = [(FairGrad(alpha=1.0, max_iters=1), mat) for mat in typical_matrices]
requires_grad_pairs = [
    (FairGrad(alpha=1.0, max_iters=1), ones_(3, 5, requires_grad=True)),
]
# The non-conflicting test needs the solver to converge, so it uses a budget of max_iters=100.
non_conflicting_pairs = [
    (FairGrad(alpha=0.1, max_iters=100), mat) for mat in typical_matrices
]
22+
23+
24+
@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
def test_expected_structure(aggregator: FairGrad, matrix: Tensor) -> None:
    """Run the shared structure assertions on FairGrad for the scaled and typical matrices."""

    assert_expected_structure(aggregator, matrix)
27+
28+
29+
@mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
def test_non_differentiable(aggregator: FairGrad, matrix: Tensor) -> None:
    """Run the shared non-differentiability assertion on a matrix that requires grad."""

    assert_non_differentiable(aggregator, matrix)
32+
33+
34+
@mark.parametrize(["aggregator", "matrix"], non_conflicting_pairs)
def test_non_conflicting(aggregator: FairGrad, matrix: Tensor) -> None:
    """Run the shared non-conflicting assertion on FairGrad for the typical matrices."""

    assert_non_conflicting(aggregator, matrix)
37+
38+
39+
def test_representations() -> None:
    """Check the ``repr`` and ``str`` outputs of a FairGrad aggregator."""

    aggregator = FairGrad(alpha=0.1, max_iters=None)

    expected_repr = "FairGrad(alpha=0.1, max_iters=None)"
    expected_str = "0.1-FairGrad"

    assert repr(aggregator) == expected_repr
    assert str(aggregator) == expected_str
43+
44+
45+
def test_alpha_setter_updates_value() -> None:
    """Setting ``alpha`` on the aggregator must also update its underlying weighting."""

    aggregator = FairGrad(alpha=1.0)
    new_alpha = 2.0

    aggregator.alpha = new_alpha

    assert aggregator.alpha == new_alpha
    assert aggregator.gramian_weighting.alpha == new_alpha
50+
51+
52+
def test_max_iters_setter_updates_value() -> None:
    """Setting ``max_iters`` on the aggregator must also update its underlying weighting."""

    aggregator = FairGrad(alpha=1.0)
    new_max_iters = 50

    aggregator.max_iters = new_max_iters

    assert aggregator.max_iters == new_max_iters
    assert aggregator.gramian_weighting.max_iters == new_max_iters
57+
58+
59+
def test_alpha_zero_gives_uniform_weights() -> None:
    """With ``alpha=0``, FairGrad must aggregate every typical matrix like ``Mean`` does."""

    fairgrad = FairGrad(alpha=0.0)
    mean = Mean()

    for matrix in typical_matrices:
        actual = fairgrad(matrix)
        expected = mean(matrix)
        assert actual.allclose(expected)

tests/unit/aggregation/test_values.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch.testing import assert_close
44
from utils.optional_deps import (
55
IS_CAGRAD_AVAILABLE,
6+
IS_FAIRGRAD_AVAILABLE,
67
IS_NASH_MTL_AVAILABLE,
78
IS_QUADPROG_PROJ_AVAILABLE,
89
)
@@ -20,6 +21,8 @@
2021
ConstantWeighting,
2122
DualProj,
2223
DualProjWeighting,
24+
FairGrad,
25+
FairGradWeighting,
2326
GradDrop,
2427
GradVac,
2528
GradVacWeighting,
@@ -93,13 +96,20 @@
9396
(SumWeighting(), G_base, tensor([1.0, 1.0])),
9497
]
9598

96-
9799
if IS_QUADPROG_PROJ_AVAILABLE:
98100
AGGREGATOR_PARAMETRIZATIONS.append((DualProj(), J_base, tensor([0.5563, 1.1109, 1.1109])))
99101
AGGREGATOR_PARAMETRIZATIONS.append((UPGrad(), J_base, tensor([0.2929, 1.9004, 1.9004])))
100102
WEIGHTING_PARAMETRIZATIONS.append((DualProjWeighting(), G_base, tensor([0.6109, 0.5000])))
101103
WEIGHTING_PARAMETRIZATIONS.append((UPGradWeighting(), G_base, tensor([1.1109, 0.7894])))
102104

105+
if IS_FAIRGRAD_AVAILABLE:
106+
AGGREGATOR_PARAMETRIZATIONS.append(
107+
(FairGrad(alpha=1.0), J_base, tensor([0.0766, 0.9985, 0.9985]))
108+
)
109+
WEIGHTING_PARAMETRIZATIONS.append(
110+
(FairGradWeighting(alpha=1.0), G_base, tensor([0.5915, 0.4071]))
111+
)
112+
103113
if IS_CAGRAD_AVAILABLE:
104114
AGGREGATOR_PARAMETRIZATIONS.append((CAGrad(c=0.5), J_base, tensor([0.1835, 1.2041, 1.2041])))
105115
WEIGHTING_PARAMETRIZATIONS.append((CAGradWeighting(c=0.5), G_base, tensor([0.7041, 0.5000])))

0 commit comments

Comments
 (0)