├── src
└── torchjd
│ ├── py.typed
│ ├── aggregation
│ ├── _utils
│ │ ├── __init__.py
│ │ ├── str.py
│ │ ├── non_differentiable.py
│ │ ├── check_dependencies.py
│ │ ├── pref_vector.py
│ │ ├── gramian.py
│ │ └── dual_cone.py
│ ├── _sum.py
│ ├── _mean.py
│ ├── _random.py
│ ├── _imtl_g.py
│ ├── _flattening.py
│ ├── _constant.py
│ ├── _pcgrad.py
│ ├── _trimmed_mean.py
│ ├── _aggregator_bases.py
│ ├── _weighting_bases.py
│ ├── _mgda.py
│ ├── _graddrop.py
│ ├── _config.py
│ └── _krum.py
│ ├── autojac
│ ├── _transform
│ │ ├── __init__.py
│ │ ├── _materialize.py
│ │ ├── _init.py
│ │ ├── _select.py
│ │ ├── _ordered_set.py
│ │ ├── _accumulate.py
│ │ ├── _stack.py
│ │ ├── _grad.py
│ │ ├── _diagonalize.py
│ │ └── _differentiate.py
│ ├── __init__.py
│ └── _utils.py
│ ├── __init__.py
│ └── autogram
│ ├── _gramian_accumulator.py
│ ├── __init__.py
│ ├── _edge_registry.py
│ ├── _gramian_computer.py
│ └── _gramian_utils.py
├── tests
├── doc
│ ├── __init__.py
│ ├── test_backward.py
│ ├── test_aggregation.py
│ └── test_autogram.py
├── plots
│ └── __init__.py
├── speed
│ ├── __init__.py
│ ├── autogram
│ │ ├── __init__.py
│ │ └── grad_vs_jac_vs_gram.py
│ └── utils.py
├── unit
│ ├── __init__.py
│ ├── autogram
│ │ ├── __init__.py
│ │ ├── test_gramian_utils.py
│ │ └── test_edge_registry.py
│ ├── autojac
│ │ ├── __init__.py
│ │ └── _transform
│ │ │ ├── __init__.py
│ │ │ ├── test_init.py
│ │ │ ├── test_select.py
│ │ │ ├── test_accumulate.py
│ │ │ ├── test_stack.py
│ │ │ └── test_diagonalize.py
│ ├── aggregation
│ │ ├── __init__.py
│ │ ├── _utils
│ │ │ ├── __init__.py
│ │ │ ├── test_pref_vector.py
│ │ │ └── test_dual_cone.py
│ │ ├── test_aggregator_bases.py
│ │ ├── test_random.py
│ │ ├── test_aligned_mtl.py
│ │ ├── test_sum.py
│ │ ├── test_mean.py
│ │ ├── test_imtl_g.py
│ │ ├── test_config.py
│ │ ├── _inputs.py
│ │ ├── test_cagrad.py
│ │ ├── test_pcgrad.py
│ │ ├── test_trimmed_mean.py
│ │ ├── test_mgda.py
│ │ ├── test_nash_mtl.py
│ │ ├── test_dualproj.py
│ │ ├── test_krum.py
│ │ ├── test_upgrad.py
│ │ ├── test_graddrop.py
│ │ └── test_constant.py
│ └── test_deprecations.py
├── utils
│ ├── __init__.py
│ ├── contexts.py
│ ├── dict_assertions.py
│ └── tensors.py
├── device.py
└── conftest.py
├── docs
├── source
│ ├── icons
│ │ ├── favicon.ico
│ │ ├── favicon-16x16.png
│ │ ├── favicon-32x32.png
│ │ ├── apple-touch-icon.png
│ │ ├── android-chrome-192x192.png
│ │ ├── android-chrome-512x512.png
│ │ ├── site.webmanifest
│ │ ├── TorchJD_logo.svg
│ │ ├── TorchJD_text.svg
│ │ └── TorchJD_text_dark.svg
│ ├── _static
│ │ ├── logo-dark-mode.png
│ │ └── logo-light-mode.png
│ ├── docs
│ │ ├── autojac
│ │ │ ├── backward.rst
│ │ │ ├── mtl_backward.rst
│ │ │ └── index.rst
│ │ ├── autogram
│ │ │ ├── engine.rst
│ │ │ └── index.rst
│ │ └── aggregation
│ │ │ ├── config.rst
│ │ │ ├── graddrop.rst
│ │ │ ├── nash_mtl.rst
│ │ │ ├── flattening.rst
│ │ │ ├── trimmed_mean.rst
│ │ │ ├── sum.rst
│ │ │ ├── krum.rst
│ │ │ ├── mean.rst
│ │ │ ├── mgda.rst
│ │ │ ├── imtl_g.rst
│ │ │ ├── cagrad.rst
│ │ │ ├── pcgrad.rst
│ │ │ ├── random.rst
│ │ │ ├── upgrad.rst
│ │ │ ├── constant.rst
│ │ │ ├── dualproj.rst
│ │ │ ├── aligned_mtl.rst
│ │ │ └── index.rst
│ ├── installation.md
│ ├── _templates
│ │ └── page.html
│ ├── examples
│ │ ├── rnn.rst
│ │ ├── partial_jd.rst
│ │ ├── index.rst
│ │ ├── basic_usage.rst
│ │ ├── mtl.rst
│ │ ├── iwmtl.rst
│ │ ├── amp.rst
│ │ ├── monitoring.rst
│ │ └── lightning_integration.rst
│ └── index.rst
├── Makefile
└── make.bat
├── .github
└── workflows
│ ├── release.yml
│ ├── build-deploy-docs.yml
│ └── tests.yml
├── .gitignore
└── .pre-commit-config.yaml
/src/torchjd/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/doc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/plots/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/speed/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/speed/autogram/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/autogram/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/autojac/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/autojac/_transform/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/icons/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/favicon.ico
--------------------------------------------------------------------------------
/docs/source/icons/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/favicon-16x16.png
--------------------------------------------------------------------------------
/docs/source/icons/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/favicon-32x32.png
--------------------------------------------------------------------------------
/docs/source/_static/logo-dark-mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/_static/logo-dark-mode.png
--------------------------------------------------------------------------------
/docs/source/icons/apple-touch-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/apple-touch-icon.png
--------------------------------------------------------------------------------
/docs/source/_static/logo-light-mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/_static/logo-light-mode.png
--------------------------------------------------------------------------------
/docs/source/docs/autojac/backward.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | backward
4 | ========
5 |
6 | .. autofunction:: torchjd.autojac.backward
7 |
--------------------------------------------------------------------------------
/docs/source/icons/android-chrome-192x192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/android-chrome-192x192.png
--------------------------------------------------------------------------------
/docs/source/icons/android-chrome-512x512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/android-chrome-512x512.png
--------------------------------------------------------------------------------
/docs/source/docs/autojac/mtl_backward.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | mtl_backward
4 | ============
5 |
6 | .. autofunction:: torchjd.autojac.mtl_backward
7 |
--------------------------------------------------------------------------------
/docs/source/docs/autogram/engine.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Engine
4 | ======
5 |
6 | .. autoclass:: torchjd.autogram.Engine
7 | :members:
8 | :undoc-members:
9 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/config.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | ConFIG
4 | ======
5 |
6 | .. autoclass:: torchjd.aggregation.ConFIG
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
--------------------------------------------------------------------------------
/docs/source/docs/autogram/index.rst:
--------------------------------------------------------------------------------
1 | autogram
2 | ========
3 |
4 |
5 | .. automodule:: torchjd.autogram
6 | :members:
7 |
8 | .. toctree::
9 | :hidden:
10 | :maxdepth: 1
11 |
12 | engine.rst
13 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/graddrop.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | GradDrop
4 | ========
5 |
6 | .. autoclass:: torchjd.aggregation.GradDrop
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/nash_mtl.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Nash-MTL
4 | ========
5 |
6 | .. autoclass:: torchjd.aggregation.NashMTL
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/flattening.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Flattening
4 | ==========
5 |
6 | .. autoclass:: torchjd.aggregation.Flattening
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/trimmed_mean.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Trimmed Mean
4 | ============
5 |
6 | .. autoclass:: torchjd.aggregation.TrimmedMean
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
--------------------------------------------------------------------------------
/docs/source/docs/autojac/index.rst:
--------------------------------------------------------------------------------
1 | autojac
2 | =======
3 |
4 | .. automodule:: torchjd.autojac
5 | :members:
6 |
7 |
8 | .. toctree::
9 | :hidden:
10 | :maxdepth: 1
11 |
12 | backward.rst
13 | mtl_backward.rst
14 |
--------------------------------------------------------------------------------
/docs/source/icons/site.webmanifest:
--------------------------------------------------------------------------------
1 | {"name":"","short_name":"","icons":[{"src":"/android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"/android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ffffff","background_color":"#ffffff","display":"standalone"}
2 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/sum.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Sum
4 | ===
5 |
6 | .. autoclass:: torchjd.aggregation.Sum
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.SumWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/krum.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Krum
4 | ====
5 |
6 | .. autoclass:: torchjd.aggregation.Krum
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.KrumWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/mean.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Mean
4 | ====
5 |
6 | .. autoclass:: torchjd.aggregation.Mean
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.MeanWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/mgda.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | MGDA
4 | ====
5 |
6 | .. autoclass:: torchjd.aggregation.MGDA
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.MGDAWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/imtl_g.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | IMTL-G
4 | ======
5 |
6 | .. autoclass:: torchjd.aggregation.IMTLG
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.IMTLGWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/cagrad.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | CAGrad
4 | ======
5 |
6 | .. autoclass:: torchjd.aggregation.CAGrad
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.CAGradWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/pcgrad.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | PCGrad
4 | ======
5 |
6 | .. autoclass:: torchjd.aggregation.PCGrad
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.PCGradWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/random.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Random
4 | ======
5 |
6 | .. autoclass:: torchjd.aggregation.Random
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.RandomWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/upgrad.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | UPGrad
4 | ======
5 |
6 | .. autoclass:: torchjd.aggregation.UPGrad
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.UPGradWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/constant.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Constant
4 | ========
5 |
6 | .. autoclass:: torchjd.aggregation.Constant
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.ConstantWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/dualproj.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | DualProj
4 | ========
5 |
6 | .. autoclass:: torchjd.aggregation.DualProj
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.DualProjWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/aligned_mtl.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | Aligned-MTL
4 | ===========
5 |
6 | .. autoclass:: torchjd.aggregation.AlignedMTL
7 | :members:
8 | :undoc-members:
9 | :exclude-members: forward
10 |
11 | .. autoclass:: torchjd.aggregation.AlignedMTLWeighting
12 | :members:
13 | :undoc-members:
14 | :exclude-members: forward
15 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_utils/str.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 |
4 | def vector_to_str(vector: Tensor) -> str:
5 | """
6 | Transforms a Tensor of the form `tensor([1.23456, 1.0, ...])` into a string of the form
7 | `1.23, 1., ...`.
8 | """
9 |
10 | weights_str = ", ".join(["{:.2f}".format(value).rstrip("0") for value in vector])
11 | return weights_str
12 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/__init__.py:
--------------------------------------------------------------------------------
1 | from ._accumulate import Accumulate
2 | from ._aggregate import Aggregate
3 | from ._base import Composition, Conjunction, RequirementError, Transform
4 | from ._diagonalize import Diagonalize
5 | from ._grad import Grad
6 | from ._init import Init
7 | from ._jac import Jac
8 | from ._ordered_set import OrderedSet
9 | from ._select import Select
10 | from ._stack import Stack
11 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_utils/non_differentiable.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor, nn
2 |
3 |
4 | class NonDifferentiableError(RuntimeError):
5 | def __init__(self, module: nn.Module):
6 | super().__init__(f"Trying to differentiate through {module}, which is not differentiable.")
7 |
8 |
9 | def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...] | Tensor) -> None:
10 | raise NonDifferentiableError(module)
11 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This package enables Jacobian descent through the ``backward`` and ``mtl_backward`` functions, which
3 | are meant to replace the usual call to ``loss.backward()`` (or ``torch.autograd.backward``) in gradient
4 | descent. To combine the rows of the Jacobian, an aggregator from the ``aggregation`` package has to
5 | be used.
6 | """
7 |
8 | from ._backward import backward
9 | from ._mtl_backward import mtl_backward
10 |
--------------------------------------------------------------------------------
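To make the package docstring above concrete, here is a minimal usage sketch; it mirrors the example exercised in `tests/doc/test_backward.py` later in this dump, so the expected gradient value comes from that test.

```python
import torch

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward

param = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = torch.tensor([-1.0, 1.0]) @ param  # first objective
y2 = (param**2).sum()                   # second objective

# Instead of calling (y1 + y2).backward(), differentiate both objectives jointly
# and let UPGrad aggregate the rows of the Jacobian into a single gradient.
backward([y1, y2], UPGrad())
print(param.grad)  # tensor([0.5000, 2.5000]) according to tests/doc/test_backward.py
```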
/tests/unit/test_deprecations.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | # deprecated since 2025-08-18
5 | def test_deprecate_imports_from_torchjd():
6 | with pytest.deprecated_call():
7 | from torchjd import backward # noqa: F401
8 |
9 | with pytest.deprecated_call():
10 | from torchjd import mtl_backward # noqa: F401
11 |
12 | with pytest.raises(ImportError):
13 | from torchjd import something_that_does_not_exist # noqa: F401
14 |
--------------------------------------------------------------------------------
/tests/device.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | try:
6 | _device_str = os.environ["PYTEST_TORCH_DEVICE"]
7 | except KeyError:
8 | _device_str = "cpu" # Default to cpu if environment variable not set
9 |
10 | if _device_str != "cuda:0" and _device_str != "cpu":
11 | raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}")
12 |
13 | if _device_str == "cuda:0" and not torch.cuda.is_available():
14 | raise ValueError('Requested device "cuda:0" but cuda is not available.')
15 |
16 | DEVICE = torch.device(_device_str)
17 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_utils/check_dependencies.py:
--------------------------------------------------------------------------------
1 | from importlib.util import find_spec
2 |
3 |
4 | class OptionalDepsNotInstalledError(ModuleNotFoundError):
5 | pass
6 |
7 |
8 | def check_dependencies_are_installed(dependency_names: list[str]) -> None:
9 | """
10 | Check that the required dependencies are installed.
11 |
12 | This can be useful for Aggregators whose dependencies are optional when installing torchjd.
13 | """
14 |
15 | if any(find_spec(name) is None for name in dependency_names):
16 | raise OptionalDepsNotInstalledError()
17 |
--------------------------------------------------------------------------------
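A purely illustrative sketch of how this helper might be called by an aggregator with optional dependencies; the dependency name `cvxpy` below is only a placeholder, not necessarily one of torchjd's actual extras.

```python
from torchjd.aggregation._utils.check_dependencies import (
    OptionalDepsNotInstalledError,
    check_dependencies_are_installed,
)

try:
    # Hypothetical optional dependency; real aggregators pass their own list.
    check_dependencies_are_installed(["cvxpy"])
except OptionalDepsNotInstalledError:
    print("Optional dependencies are missing; install the corresponding extra.")
```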
/tests/utils/contexts.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Generator
2 | from contextlib import AbstractContextManager, contextmanager
3 | from typing import Any, TypeAlias
4 |
5 | import torch
6 | from device import DEVICE
7 |
8 | ExceptionContext: TypeAlias = AbstractContextManager[Exception | None]
9 |
10 |
11 | @contextmanager
12 | def fork_rng(seed: int = 0) -> Generator[Any, None, None]:
13 | devices = [DEVICE] if DEVICE.type == "cuda" else []
14 | with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx:
15 | torch.manual_seed(seed)
16 | yield ctx
17 |
--------------------------------------------------------------------------------
/tests/utils/dict_assertions.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Hashable
2 | from typing import TypeVar
3 |
4 | from torch import Tensor
5 | from torch.testing import assert_close
6 |
7 | _KeyType = TypeVar("_KeyType", bound=Hashable)
8 |
9 |
10 | def assert_tensor_dicts_are_close(d1: dict[_KeyType, Tensor], d2: dict[_KeyType, Tensor]) -> None:
11 | """
12 | Check that two dictionaries of tensors are close enough. Note that this does not require the
13 | keys to have the same ordering.
14 | """
15 |
16 | assert d1.keys() == d2.keys()
17 |
18 | for key in d1:
19 | assert_close(d1[key], d2[key])
20 |
--------------------------------------------------------------------------------
/docs/source/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | ```{include} ../../README.md
4 | :start-after:
5 | :end-before:
6 | ```
7 |
8 | Note that `torchjd` requires Python 3.10, 3.11, 3.12, 3.13 or 3.14 and `torch>=2.0`.
9 |
10 | Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default
11 | when installing `torchjd`. To install them, you can use:
12 | ```
13 | pip install torchjd[cagrad]
14 | ```
15 | ```
16 | pip install torchjd[nash_mtl]
17 | ```
18 |
19 | To install `torchjd` with all of its optional dependencies, you can also use:
20 | ```
21 | pip install torchjd[full]
22 | ```
23 |
--------------------------------------------------------------------------------
/tests/doc/test_backward.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the test of the backward usage example, with a verification of the value of the
3 | obtained `.grad` field.
4 | """
5 |
6 | from torch.testing import assert_close
7 |
8 |
9 | def test_backward():
10 | import torch
11 |
12 | from torchjd.aggregation import UPGrad
13 | from torchjd.autojac import backward
14 |
15 | param = torch.tensor([1.0, 2.0], requires_grad=True)
16 | # Compute arbitrary quantities that are function of param
17 | y1 = torch.tensor([-1.0, 1.0]) @ param
18 | y2 = (param**2).sum()
19 |
20 | backward([y1, y2], UPGrad())
21 |
22 | assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)
23 |
--------------------------------------------------------------------------------
/tests/speed/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | def noop():
8 | pass
9 |
10 |
11 | def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor:
12 | init_fn()
13 |
14 | times = []
15 | for _ in range(n_runs):
16 | pre_fn()
17 | start = time.perf_counter()
18 | fn()
19 | post_fn()
20 | elapsed_time = time.perf_counter() - start
21 | times.append(elapsed_time)
22 |
23 | return torch.tensor(times)
24 |
25 |
26 | def print_times(name: str, times: Tensor) -> None:
27 | print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f})")
28 | print(times)
29 | print()
30 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?= -W --keep-going -n
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | env:
8 | PYTHON_VERSION: 3.14
9 |
10 | jobs:
11 | pypi-publish:
12 | name: Publish to PyPI
13 | environment: release
14 | runs-on: ubuntu-latest
15 | permissions:
16 | # IMPORTANT: this permission is mandatory for trusted publishing
17 | id-token: write
18 | steps:
19 | - name: Checkout repository
20 | uses: actions/checkout@v4
21 |
22 | - name: Set up uv
23 | uses: astral-sh/setup-uv@v5
24 | with:
25 | python-version: ${{ env.PYTHON_VERSION }}
26 |
27 | - name: Build
28 | run: uv build
29 |
30 | - name: Publish package distributions to PyPI
31 | run: uv publish -v
32 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_aggregator_bases.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | from contextlib import nullcontext as does_not_raise
3 |
4 | from pytest import mark, raises
5 | from utils.contexts import ExceptionContext
6 | from utils.tensors import randn_
7 |
8 | from torchjd.aggregation import Aggregator
9 |
10 |
11 | @mark.parametrize(
12 | ["shape", "expectation"],
13 | [
14 | ([], raises(ValueError)),
15 | ([1], raises(ValueError)),
16 | ([1, 2], does_not_raise()),
17 | ([1, 2, 3], raises(ValueError)),
18 | ([1, 2, 3, 4], raises(ValueError)),
19 | ],
20 | )
21 | def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext):
22 | with expectation:
23 | Aggregator._check_is_matrix(randn_(shape))
24 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_sum.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from ._aggregator_bases import WeightedAggregator
5 | from ._weighting_bases import Matrix, Weighting
6 |
7 |
8 | class Sum(WeightedAggregator):
9 | """
10 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums the rows of the input
11 | matrices.
12 | """
13 |
14 | def __init__(self):
15 | super().__init__(weighting=SumWeighting())
16 |
17 |
18 | class SumWeighting(Weighting[Matrix]):
19 | r"""
20 | :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
21 | :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
22 | """
23 |
24 | def forward(self, matrix: Tensor) -> Tensor:
25 | device = matrix.device
26 | dtype = matrix.dtype
27 | weights = torch.ones(matrix.shape[0], device=device, dtype=dtype)
28 | return weights
29 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/_materialize.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | def materialize(
8 | optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
9 | ) -> tuple[Tensor, ...]:
10 | """
11 | Transforms a sequence of optional tensors by replacing each None with a tensor of zeros of the same
12 | shape as the corresponding input. Returns the obtained sequence as a tuple.
13 |
14 | Note that the name "materialize" comes from the flag `materialize_grads` from
15 | `torch.autograd.grad`, which will be available in future torch releases.
16 | """
17 |
18 | tensors = []
19 | for optional_tensor, input in zip(optional_tensors, inputs):
20 | if optional_tensor is None:
21 | tensors.append(torch.zeros_like(input))
22 | else:
23 | tensors.append(optional_tensor)
24 | return tuple(tensors)
25 |
--------------------------------------------------------------------------------
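A small sketch of the behavior described in the docstring above, assuming two arbitrary inputs where the first gradient is missing:

```python
import torch

from torchjd.autojac._transform._materialize import materialize

a = torch.zeros(2, 3)
b = torch.zeros(4)

# The None entry is replaced by zeros shaped like the corresponding input.
grads = materialize([None, torch.ones(4)], inputs=[a, b])
assert grads[0].shape == (2, 3) and torch.equal(grads[1], torch.ones(4))
```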
/src/torchjd/autojac/_transform/_init.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Set
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from ._base import RequirementError, TensorDict, Transform
7 |
8 |
9 | class Init(Transform):
10 | """
11 | Transform from {} returning Gradients filled with ones for each of the provided values.
12 |
13 | :param values: Tensors for which Gradients must be returned.
14 | """
15 |
16 | def __init__(self, values: Set[Tensor]):
17 | self.values = values
18 |
19 | def __call__(self, input: TensorDict) -> TensorDict:
20 | return {value: torch.ones_like(value) for value in self.values}
21 |
22 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
23 | if not input_keys == set():
24 | raise RequirementError(
25 | f"The input_keys should be the empty set. Found input_keys {input_keys}."
26 | )
27 | return set(self.values)
28 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/src/torchjd/__init__.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from warnings import warn as _warn
3 |
4 | from .autojac import backward as _backward
5 | from .autojac import mtl_backward as _mtl_backward
6 |
7 | _deprecated_items: dict[str, tuple[str, Callable]] = {
8 | "backward": ("autojac", _backward),
9 | "mtl_backward": ("autojac", _mtl_backward),
10 | }
11 |
12 |
13 | def __getattr__(name: str) -> Callable:
14 | """
15 | If an attribute is not found in the module's dictionary and its name is in _deprecated_items,
16 | then import it with a warning.
17 | """
18 | if name in _deprecated_items:
19 | _warn(
20 | f"Importing `{name}` from `torchjd` is deprecated. Please import it from "
21 | f"`{_deprecated_items[name][0]}` instead.",
22 | DeprecationWarning,
23 | stacklevel=2,
24 | )
25 | return _deprecated_items[name][1]
26 | raise AttributeError(f"module {__name__} has no attribute {name}")
27 |
--------------------------------------------------------------------------------
/docs/source/docs/aggregation/index.rst:
--------------------------------------------------------------------------------
1 | aggregation
2 | ===========
3 |
4 | .. automodule:: torchjd.aggregation
5 | :no-members:
6 |
7 | Abstract base classes
8 | ---------------------
9 |
10 | .. autoclass:: torchjd.aggregation.Aggregator
11 | :members:
12 | :undoc-members:
13 | :exclude-members: forward
14 |
15 | .. autoclass:: torchjd.aggregation.Weighting
16 | :members:
17 | :undoc-members:
18 | :exclude-members: forward
19 |
20 | .. autoclass:: torchjd.aggregation.GeneralizedWeighting
21 | :members:
22 | :undoc-members:
23 | :exclude-members: forward
24 |
25 |
26 | .. toctree::
27 | :hidden:
28 | :maxdepth: 1
29 |
30 | upgrad.rst
31 | aligned_mtl.rst
32 | cagrad.rst
33 | config.rst
34 | constant.rst
35 | dualproj.rst
36 | flattening.rst
37 | graddrop.rst
38 | imtl_g.rst
39 | krum.rst
40 | mean.rst
41 | mgda.rst
42 | nash_mtl.rst
43 | pcgrad.rst
44 | random.rst
45 | sum.rst
46 | trimmed_mean.rst
47 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_mean.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from ._aggregator_bases import WeightedAggregator
5 | from ._weighting_bases import Matrix, Weighting
6 |
7 |
8 | class Mean(WeightedAggregator):
9 | """
10 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input
11 | matrices.
12 | """
13 |
14 | def __init__(self):
15 | super().__init__(weighting=MeanWeighting())
16 |
17 |
18 | class MeanWeighting(Weighting[Matrix]):
19 | r"""
20 | :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
21 | :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
22 | \mathbb{R}^m`.
23 | """
24 |
25 | def forward(self, matrix: Tensor) -> Tensor:
26 | device = matrix.device
27 | dtype = matrix.dtype
28 | m = matrix.shape[0]
29 | weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype)
30 | return weights
31 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/_utils/test_pref_vector.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext as does_not_raise
2 |
3 | from pytest import mark, raises
4 | from torch import Tensor
5 | from utils.contexts import ExceptionContext
6 | from utils.tensors import ones_
7 |
8 | from torchjd.aggregation._mean import MeanWeighting
9 | from torchjd.aggregation._utils.pref_vector import pref_vector_to_weighting
10 |
11 |
12 | @mark.parametrize(
13 | ["pref_vector", "expectation"],
14 | [
15 | (None, does_not_raise()),
16 | (ones_([]), raises(ValueError)),
17 | (ones_([0]), does_not_raise()),
18 | (ones_([1]), does_not_raise()),
19 | (ones_([5]), does_not_raise()),
20 | (ones_([1, 1]), raises(ValueError)),
21 | (ones_([1, 1, 1]), raises(ValueError)),
22 | ],
23 | )
24 | def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):
25 | with expectation:
26 | _ = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
27 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/_select.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Set
2 |
3 | from torch import Tensor
4 |
5 | from ._base import RequirementError, TensorDict, Transform
6 |
7 |
8 | class Select(Transform):
9 | """
10 | Transform returning a subset of the provided TensorDict.
11 |
12 | :param keys: The keys that should be included in the returned subset.
13 | """
14 |
15 | def __init__(self, keys: Set[Tensor]):
16 | self.keys = keys
17 |
18 | def __call__(self, tensor_dict: TensorDict) -> TensorDict:
19 | output = {key: tensor_dict[key] for key in self.keys}
20 | return type(tensor_dict)(output)
21 |
22 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
23 | keys = set(self.keys)
24 | if not keys.issubset(input_keys):
25 | raise RequirementError(
26 | f"The input_keys should be a super set of the keys to select. Found input_keys "
27 | f"{input_keys} and keys to select {keys}."
28 | )
29 | return keys
30 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_random.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 | from torch import Tensor
3 |
4 | from torchjd.aggregation import Random
5 |
6 | from ._asserts import assert_expected_structure, assert_strongly_stationary
7 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
8 |
9 | scaled_pairs = [(Random(), matrix) for matrix in scaled_matrices]
10 | typical_pairs = [(Random(), matrix) for matrix in typical_matrices]
11 | non_strong_pairs = [(Random(), matrix) for matrix in non_strong_matrices]
12 |
13 |
14 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
15 | def test_expected_structure(aggregator: Random, matrix: Tensor):
16 | assert_expected_structure(aggregator, matrix)
17 |
18 |
19 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
20 | def test_strongly_stationary(aggregator: Random, matrix: Tensor):
21 | assert_strongly_stationary(aggregator, matrix)
22 |
23 |
24 | def test_representations():
25 | A = Random()
26 | assert repr(A) == "Random()"
27 | assert str(A) == "Random"
28 |
--------------------------------------------------------------------------------
/docs/source/_templates/page.html:
--------------------------------------------------------------------------------
1 | {#
2 | Adds a canonical link to the stable version of the documentation to each built html page.
3 | The reason to do that is that search engines require it:
4 | https://developers.google.com/search/docs/crawling-indexing/consolidate-duplicate-urls
5 |
6 | This template overrides the page.html furo template
7 | (https://github.com/pradyunsg/furo/blob/main/src/furo/theme/furo/page.html) by adding an extra
8 | line to its <head> section, with the appropriate canonical link. Note that there is no guarantee
9 | that the furo theme keeps using a file named page.html, so this could silently break with future
10 | updates of furo. See https://pradyunsg.me/furo/customisation/injecting/ and
11 | https://github.com/pradyunsg/furo/discussions/248 for more information.
12 | #}
13 | {% extends "!page.html" %}
14 | {% block extrahead %}
15 |
16 |
17 |
18 | {% endblock %}
19 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_random.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 |
5 | from ._aggregator_bases import WeightedAggregator
6 | from ._weighting_bases import Matrix, Weighting
7 |
8 |
9 | class Random(WeightedAggregator):
10 | """
11 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of
12 | the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of
13 | Random Weighting: A Litmus Test for Multi-Task Learning
14 | `_.
15 | """
16 |
17 | def __init__(self):
18 | super().__init__(RandomWeighting())
19 |
20 |
21 | class RandomWeighting(Weighting[Matrix]):
22 | """
23 | :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
24 | at each call.
25 | """
26 |
27 | def forward(self, matrix: Tensor) -> Tensor:
28 | random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
29 | weights = F.softmax(random_vector, dim=-1)
30 | return weights
31 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_utils/pref_vector.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 | from torchjd.aggregation._constant import ConstantWeighting
4 | from torchjd.aggregation._weighting_bases import Matrix, Weighting
5 |
6 | from .str import vector_to_str
7 |
8 |
9 | def pref_vector_to_weighting(
10 | pref_vector: Tensor | None, default: Weighting[Matrix]
11 | ) -> Weighting[Matrix]:
12 | """
13 | Returns the weighting associated with a given preference vector, with a fallback to a default
14 | weighting if the preference vector is None.
15 | """
16 |
17 | if pref_vector is None:
18 | return default
19 | else:
20 | if pref_vector.ndim != 1:
21 | raise ValueError(
22 | "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
23 | f"{pref_vector.ndim}`."
24 | )
25 | return ConstantWeighting(pref_vector)
26 |
27 |
28 | def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
29 | """Returns a suffix string containing the representation of the optional preference vector."""
30 |
31 | if pref_vector is None:
32 | return ""
33 | else:
34 | return f"([{vector_to_str(pref_vector)}])"
35 |
--------------------------------------------------------------------------------
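A short sketch of the two helpers defined above, using an arbitrary 1D preference vector:

```python
import torch

from torchjd.aggregation._mean import MeanWeighting
from torchjd.aggregation._utils.pref_vector import (
    pref_vector_to_str_suffix,
    pref_vector_to_weighting,
)

# With no preference vector, the default weighting is returned unchanged.
default = MeanWeighting()
assert pref_vector_to_weighting(None, default=default) is default

# A 1D preference vector is wrapped into a ConstantWeighting.
pref = torch.tensor([1.0, 2.0, 3.0])
weighting = pref_vector_to_weighting(pref, default=default)
print(type(weighting).__name__)         # ConstantWeighting
print(pref_vector_to_str_suffix(pref))  # ([1., 2., 3.])
```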
/tests/doc/test_aggregation.py:
--------------------------------------------------------------------------------
1 | """This file contains the test corresponding to the usage example of Aggregator and Weighting."""
2 |
3 | import torch
4 | from torch.testing import assert_close
5 |
6 |
7 | def test_aggregation_and_weighting():
8 | from torch import tensor
9 |
10 | from torchjd.aggregation import UPGrad, UPGradWeighting
11 |
12 | aggregator = UPGrad()
13 | jacobian = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
14 | aggregation = aggregator(jacobian)
15 |
16 | assert_close(aggregation, tensor([0.2929, 1.9004, 1.9004]), rtol=0, atol=1e-4)
17 |
18 | weighting = UPGradWeighting()
19 | gramian = jacobian @ jacobian.T
20 | weights = weighting(gramian)
21 |
22 | assert_close(weights, tensor([1.1109, 0.7894]), rtol=0, atol=1e-4)
23 |
24 |
25 | def test_generalized_weighting():
26 | from torch import ones
27 |
28 | from torchjd.aggregation import Flattening, UPGradWeighting
29 |
30 | weighting = Flattening(UPGradWeighting())
31 | # Generate a generalized Gramian filled with ones, for the sake of the example
32 | generalized_gramian = ones((2, 3, 3, 2))
33 | weights = weighting(generalized_gramian)
34 |
35 | assert_close(weights, torch.full((2, 3), 0.1667), rtol=0, atol=1e-4)
36 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_aligned_mtl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytest import mark
3 | from torch import Tensor
4 |
5 | from torchjd.aggregation import AlignedMTL
6 |
7 | from ._asserts import assert_expected_structure, assert_permutation_invariant
8 | from ._inputs import scaled_matrices, typical_matrices
9 |
10 | scaled_pairs = [(AlignedMTL(), matrix) for matrix in scaled_matrices]
11 | typical_pairs = [(AlignedMTL(), matrix) for matrix in typical_matrices]
12 |
13 |
14 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
15 | def test_expected_structure(aggregator: AlignedMTL, matrix: Tensor):
16 | assert_expected_structure(aggregator, matrix)
17 |
18 |
19 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
20 | def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor):
21 | assert_permutation_invariant(aggregator, matrix)
22 |
23 |
24 | def test_representations():
25 | A = AlignedMTL(pref_vector=None)
26 | assert repr(A) == "AlignedMTL(pref_vector=None)"
27 | assert str(A) == "AlignedMTL"
28 |
29 | A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
30 | assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]))"
31 | assert str(A) == "AlignedMTL([1., 2., 3.])"
32 |
--------------------------------------------------------------------------------
/src/torchjd/autogram/_gramian_accumulator.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from torch import Tensor
4 |
5 |
6 | class GramianAccumulator:
7 | """
8 | Efficiently accumulates the Gramian of the Jacobian during reverse-mode differentiation.
9 |
10 | Jacobians from multiple graph paths to the same parameter are first summed to obtain the full
11 | Jacobian w.r.t. a parameter, then its Gramian is computed and accumulated, over parameters, into
12 | the total Gramian matrix. Intermediate matrices are discarded immediately to save memory.
13 | """
14 |
15 | def __init__(self) -> None:
16 | self._gramian: Optional[Tensor] = None
17 |
18 | def reset(self) -> None:
19 | self._gramian = None
20 |
21 | def accumulate_gramian(self, gramian: Tensor) -> None:
22 | if self._gramian is not None:
23 | self._gramian.add_(gramian)
24 | else:
25 | self._gramian = gramian
26 |
27 | @property
28 | def gramian(self) -> Optional[Tensor]:
29 | """
30 | Get the Gramian matrix accumulated so far.
31 |
32 | :returns: Accumulated Gramian matrix of shape (batch_size, batch_size) or None if nothing
33 | was accumulated yet.
34 | """
35 |
36 | return self._gramian
37 |
--------------------------------------------------------------------------------
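A minimal sketch of how the accumulator defined above behaves, assuming two per-parameter Jacobians that share the batch dimension:

```python
import torch

from torchjd.autogram._gramian_accumulator import GramianAccumulator

accumulator = GramianAccumulator()
assert accumulator.gramian is None  # nothing accumulated yet

# Per-parameter Jacobians of shape [batch_size, num_params_i].
j1 = torch.randn(16, 10)
j2 = torch.randn(16, 3)

# Each parameter contributes its own Gramian; they are summed in place.
accumulator.accumulate_gramian(j1 @ j1.T)
accumulator.accumulate_gramian(j2 @ j2.T)
print(accumulator.gramian.shape)  # torch.Size([16, 16])

accumulator.reset()
assert accumulator.gramian is None
```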
/tests/doc/test_autogram.py:
--------------------------------------------------------------------------------
1 | """This file contains tests for the usage examples related to autogram."""
2 |
3 |
4 | def test_engine():
5 | import torch
6 | from torch.nn import Linear, MSELoss, ReLU, Sequential
7 | from torch.optim import SGD
8 |
9 | from torchjd.aggregation import UPGradWeighting
10 | from torchjd.autogram import Engine
11 |
12 | # Generate data (8 batches of 16 examples of dim 5) for the sake of the example
13 | inputs = torch.randn(8, 16, 5)
14 | targets = torch.randn(8, 16)
15 |
16 | model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
17 | optimizer = SGD(model.parameters())
18 |
19 | criterion = MSELoss(reduction="none") # Important to use reduction="none"
20 | weighting = UPGradWeighting()
21 |
22 | # Create the engine before the backward pass, and only once.
23 | engine = Engine(model, batch_dim=0)
24 |
25 | for input, target in zip(inputs, targets):
26 | output = model(input).squeeze(dim=1) # shape: [16]
27 | losses = criterion(output, target) # shape: [16]
28 |
29 | optimizer.zero_grad()
30 | gramian = engine.compute_gramian(losses) # shape: [16, 16]
31 | weights = weighting(gramian) # shape: [16]
32 | losses.backward(weights)
33 | optimizer.step()
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # uv
2 | uv.lock
3 |
4 | # Jupyter Notebooks
5 | *.ipynb
6 |
7 | # PyCharm
8 | .idea/
9 |
10 | # Generated images
11 | images/
12 |
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # Jupyter Notebook
68 | .ipynb_checkpoints
69 |
70 | # IPython
71 | profile_default/
72 | ipython_config.py
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
78 | __pypackages__/
79 |
80 | # Environments
81 | .env
82 | .venv
83 | env/
84 | venv/
85 | ENV/
86 | env.bak/
87 | venv.bak/
88 |
89 | # mypy
90 | .mypy_cache/
91 | .dmypy.json
92 | dmypy.json
93 |
94 | # macOS file
95 | .DS_Store
96 |
97 | # VS Code folder
98 | .vscode/
99 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_sum.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 | from torch import Tensor
3 |
4 | from torchjd.aggregation import Sum
5 |
6 | from ._asserts import (
7 | assert_expected_structure,
8 | assert_linear_under_scaling,
9 | assert_permutation_invariant,
10 | assert_strongly_stationary,
11 | )
12 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
13 |
14 | scaled_pairs = [(Sum(), matrix) for matrix in scaled_matrices]
15 | typical_pairs = [(Sum(), matrix) for matrix in typical_matrices]
16 | non_strong_pairs = [(Sum(), matrix) for matrix in non_strong_matrices]
17 |
18 |
19 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
20 | def test_expected_structure(aggregator: Sum, matrix: Tensor):
21 | assert_expected_structure(aggregator, matrix)
22 |
23 |
24 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
25 | def test_permutation_invariant(aggregator: Sum, matrix: Tensor):
26 | assert_permutation_invariant(aggregator, matrix)
27 |
28 |
29 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
30 | def test_linear_under_scaling(aggregator: Sum, matrix: Tensor):
31 | assert_linear_under_scaling(aggregator, matrix)
32 |
33 |
34 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
35 | def test_strongly_stationary(aggregator: Sum, matrix: Tensor):
36 | assert_strongly_stationary(aggregator, matrix)
37 |
38 |
39 | def test_representations():
40 | A = Sum()
41 | assert repr(A) == "Sum()"
42 | assert str(A) == "Sum"
43 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_mean.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 | from torch import Tensor
3 |
4 | from torchjd.aggregation import Mean
5 |
6 | from ._asserts import (
7 | assert_expected_structure,
8 | assert_linear_under_scaling,
9 | assert_permutation_invariant,
10 | assert_strongly_stationary,
11 | )
12 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
13 |
14 | scaled_pairs = [(Mean(), matrix) for matrix in scaled_matrices]
15 | typical_pairs = [(Mean(), matrix) for matrix in typical_matrices]
16 | non_strong_pairs = [(Mean(), matrix) for matrix in non_strong_matrices]
17 |
18 |
19 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
20 | def test_expected_structure(aggregator: Mean, matrix: Tensor):
21 | assert_expected_structure(aggregator, matrix)
22 |
23 |
24 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
25 | def test_permutation_invariant(aggregator: Mean, matrix: Tensor):
26 | assert_permutation_invariant(aggregator, matrix)
27 |
28 |
29 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
30 | def test_linear_under_scaling(aggregator: Mean, matrix: Tensor):
31 | assert_linear_under_scaling(aggregator, matrix)
32 |
33 |
34 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
35 | def test_strongly_stationary(aggregator: Mean, matrix: Tensor):
36 | assert_strongly_stationary(aggregator, matrix)
37 |
38 |
39 | def test_representations():
40 | A = Mean()
41 | assert repr(A) == "Mean()"
42 | assert str(A) == "Mean"
43 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_imtl_g.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from ._aggregator_bases import GramianWeightedAggregator
5 | from ._utils.non_differentiable import raise_non_differentiable_error
6 | from ._weighting_bases import PSDMatrix, Weighting
7 |
8 |
9 | class IMTLG(GramianWeightedAggregator):
10 | """
11 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` generalizing the method described in
12 | `Towards Impartial Multi-task Learning `_.
13 | This generalization, defined formally in `Jacobian Descent For Multi-Objective Optimization
14 | `_, supports matrices with some linearly dependent rows.
15 | """
16 |
17 | def __init__(self):
18 | super().__init__(IMTLGWeighting())
19 |
20 | # This prevents computing gradients that can be very wrong.
21 | self.register_full_backward_pre_hook(raise_non_differentiable_error)
22 |
23 |
24 | class IMTLGWeighting(Weighting[PSDMatrix]):
25 | """
26 | :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
27 | :class:`~torchjd.aggregation.IMTLG`.
28 | """
29 |
30 | def forward(self, gramian: Tensor) -> Tensor:
31 | d = torch.sqrt(torch.diagonal(gramian))
32 | v = torch.linalg.pinv(gramian) @ d
33 | v_sum = v.sum()
34 |
35 | if v_sum.abs() < 1e-12:
36 | weights = torch.zeros_like(v)
37 | else:
38 | weights = v / v_sum
39 |
40 | return weights
41 |
--------------------------------------------------------------------------------
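A tiny numeric sketch of the weighting defined above, assuming the simplest possible case of an identity Gramian (so the result is easy to verify by hand):

```python
import torch

from torchjd.aggregation import IMTLG, IMTLGWeighting

# With an identity Gramian, the diagonal square roots are all ones and
# pinv(G) is the identity, so the weights reduce to a uniform 1/m.
gramian = torch.eye(2)
weights = IMTLGWeighting()(gramian)
print(weights)  # tensor([0.5000, 0.5000])

# The aggregator applies these weights to the rows of the matrix.
aggregation = IMTLG()(torch.eye(2))
print(aggregation)  # tensor([0.5000, 0.5000]) as well: 0.5 * row_1 + 0.5 * row_2
```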
/src/torchjd/aggregation/_utils/gramian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 |
5 | def compute_gramian(matrix: Tensor) -> Tensor:
6 | """
7 | Computes the `Gramian matrix `_ of a given matrix.
8 | """
9 |
10 | return matrix @ matrix.T
11 |
12 |
13 | def normalize(gramian: Tensor, eps: float) -> Tensor:
14 | """
15 | Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.
16 |
17 | If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the
18 | sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
19 | therefore `G` divided by the sum of its diagonal elements.
20 | """
21 | squared_frobenius_norm = gramian.diagonal().sum()
22 | if squared_frobenius_norm < eps:
23 | return torch.zeros_like(gramian)
24 | else:
25 | return gramian / squared_frobenius_norm
26 |
27 |
28 | def regularize(gramian: Tensor, eps: float) -> Tensor:
29 | """
30 | Adds a regularization term to the gramian to enforce positive definiteness.
31 |
32 | Because of numerical errors, `gramian` might have slightly negative eigenvalue(s). Adding a
33 | regularization term which is a small proportion of the identity matrix ensures that the gramian
34 | is positive definite.
35 | """
36 |
37 | regularization_matrix = eps * torch.eye(
38 | gramian.shape[0], dtype=gramian.dtype, device=gramian.device
39 | )
40 | return gramian + regularization_matrix
41 |
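A small sketch of the relation stated in the docstrings above (hypothetical values, using only the helpers defined in this file): the trace of ``G = A A^T`` equals the squared Frobenius norm of ``A``, so ``normalize`` divides ``G`` by that trace.

    import torch
    from torchjd.aggregation._utils.gramian import compute_gramian, normalize

    A = torch.randn(3, 5)
    G = compute_gramian(A)  # A @ A.T, of shape [3, 3]
    # trace(G) equals ||A||_F^2, so the normalized Gramian is G / ||A||_F^2.
    assert torch.allclose(G.diagonal().sum(), A.norm() ** 2)
    assert torch.allclose(normalize(G, eps=1e-12), G / A.norm() ** 2)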
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_flattening.py:
--------------------------------------------------------------------------------
1 | from math import prod
2 |
3 | from torch import Tensor
4 |
5 | from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting
6 | from torchjd.autogram._gramian_utils import reshape_gramian
7 |
8 |
9 | class Flattening(GeneralizedWeighting):
10 | """
11 | :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting` flattening the generalized
12 | Gramian into a square matrix, extracting a vector of weights from it using a
13 | :class:`~torchjd.aggregation._weighting_bases.Weighting`, and returning the reshaped tensor of
14 | weights.
15 |
16 | For instance, when applied to a generalized Gramian of shape ``[2, 3, 3, 2]``, it would flatten
17 | it into a square Gramian matrix of shape ``[6, 6]``, apply the weighting on it to get a vector
18 | of weights of shape ``[6]``, and then return this vector reshaped into a matrix of shape
19 | ``[2, 3]``.
20 |
21 | :param weighting: The weighting to apply to the Gramian matrix.
22 | """
23 |
24 | def __init__(self, weighting: Weighting[PSDMatrix]):
25 | super().__init__()
26 | self.weighting = weighting
27 |
28 | def forward(self, generalized_gramian: Tensor) -> Tensor:
29 | k = generalized_gramian.ndim // 2
30 | shape = generalized_gramian.shape[:k]
31 | m = prod(shape)
32 | square_gramian = reshape_gramian(generalized_gramian, [m])
33 | weights_vector = self.weighting(square_gramian)
34 | weights = weights_vector.reshape(shape)
35 | return weights
36 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/_ordered_set.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections import OrderedDict
4 | from collections.abc import Hashable, Iterable, Iterator, MutableSet
5 | from typing import TypeVar
6 |
7 | _T = TypeVar("_T", bound=Hashable)
8 |
9 |
10 | class OrderedSet(MutableSet[_T]):
11 | """Ordered collection of distinct elements."""
12 |
13 | def __init__(self, elements: Iterable[_T]):
14 | super().__init__()
15 | self.ordered_dict = OrderedDict[_T, None]([(element, None) for element in elements])
16 |
17 | def difference_update(self, elements: set[_T]) -> None:
18 | """Removes all specified elements from the OrderedSet."""
19 |
20 | for element in elements:
21 | self.discard(element)
22 |
23 | def add(self, element: _T) -> None:
24 | """Adds the specified element to the OrderedSet."""
25 |
26 | self.ordered_dict[element] = None
27 |
28 | def __add__(self, other: OrderedSet[_T]) -> OrderedSet[_T]:
29 | """Creates a new OrderedSet with the elements of self followed by the elements of other."""
30 |
31 | return OrderedSet([*self, *other])
32 |
33 | def discard(self, value: _T) -> None:
34 | if value in self:
35 | del self.ordered_dict[value]
36 |
37 | def __iter__(self) -> Iterator[_T]:
38 | return self.ordered_dict.__iter__()
39 |
40 | def __len__(self) -> int:
41 | return len(self.ordered_dict)
42 |
43 | def __contains__(self, element: object) -> bool:
44 | return element in self.ordered_dict
45 |
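A quick sketch of the intended behaviour (using only the methods defined above): insertion order is preserved, duplicates are dropped, and ``+`` concatenates while keeping the first occurrence of each element.

    from torchjd.autojac._transform._ordered_set import OrderedSet

    s = OrderedSet([3, 1, 2, 1])  # duplicates collapsed, order of first occurrence kept
    s.add(4)
    s.discard(1)
    assert list(s) == [3, 2, 4]
    assert list(OrderedSet([1]) + OrderedSet([2, 1])) == [1, 2]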
--------------------------------------------------------------------------------
/src/torchjd/autogram/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The autogram package provides an engine to efficiently compute the Gramian of the Jacobian of a
3 | tensor of outputs (generally losses) with respect to some modules' parameters. This Gramian contains
4 | all the inner products between pairs of gradients, and is thus a sufficient statistic for most
5 | weighting methods. The algorithm is formally defined in Section 6 of `Jacobian Descent For
6 | Multi-Objective Optimization `_.
7 |
8 | Due to computing the Gramian iteratively over the layers, without ever having to store the full
9 | Jacobian in memory, this method is much more memory-efficient than
10 | :doc:`autojac <../autojac/index>`, often making it much faster as well. Note that we're still working
11 | on making autogram faster and more memory-efficient, and its interface may change in future
12 | releases.
13 |
14 | The list of Weightings compatible with ``autogram`` is:
15 |
16 | * :class:`~torchjd.aggregation.UPGradWeighting`
17 | * :class:`~torchjd.aggregation.AlignedMTLWeighting`
18 | * :class:`~torchjd.aggregation.CAGradWeighting`
19 | * :class:`~torchjd.aggregation.ConstantWeighting`
20 | * :class:`~torchjd.aggregation.DualProjWeighting`
21 | * :class:`~torchjd.aggregation.IMTLGWeighting`
22 | * :class:`~torchjd.aggregation.KrumWeighting`
23 | * :class:`~torchjd.aggregation.MeanWeighting`
24 | * :class:`~torchjd.aggregation.MGDAWeighting`
25 | * :class:`~torchjd.aggregation.PCGradWeighting`
26 | * :class:`~torchjd.aggregation.RandomWeighting`
27 | * :class:`~torchjd.aggregation.SumWeighting`
28 | """
29 |
30 | from ._engine import Engine
31 |
--------------------------------------------------------------------------------
/tests/utils/tensors.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from device import DEVICE
5 | from torch import nn
6 | from torch.utils._pytree import PyTree, tree_map
7 | from utils.architectures import get_in_out_shapes
8 | from utils.contexts import fork_rng
9 |
10 | # Curried calls to torch functions that require a device, so that the device is automatically fixed
11 | # for code written in the tests without affecting code written in src (which is what
12 | # torch.set_default_device or a too broad `with torch.device(DEVICE)` context would have done).
13 |
14 | empty_ = partial(torch.empty, device=DEVICE)
15 | eye_ = partial(torch.eye, device=DEVICE)
16 | ones_ = partial(torch.ones, device=DEVICE)
17 | rand_ = partial(torch.rand, device=DEVICE)
18 | randint_ = partial(torch.randint, device=DEVICE)
19 | randn_ = partial(torch.randn, device=DEVICE)
20 | randperm_ = partial(torch.randperm, device=DEVICE)
21 | tensor_ = partial(torch.tensor, device=DEVICE)
22 | zeros_ = partial(torch.zeros, device=DEVICE)
23 |
24 |
25 | def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[PyTree, PyTree]:
26 | input_shapes, output_shapes = get_in_out_shapes(model)
27 | with fork_rng(seed=0):
28 | inputs = _make_tensors(batch_size, input_shapes)
29 | targets = _make_tensors(batch_size, output_shapes)
30 |
31 | return inputs, targets
32 |
33 |
34 | def _make_tensors(batch_size: int, tensor_shapes: PyTree) -> PyTree:
35 | def is_leaf(s):
36 | return isinstance(s, tuple) and all([isinstance(e, int) for e in s])
37 |
38 | return tree_map(lambda s: randn_((batch_size,) + s), tensor_shapes, is_leaf=is_leaf)
39 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_imtl_g.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 | from torch import Tensor
3 | from torch.testing import assert_close
4 | from utils.tensors import ones_, zeros_
5 |
6 | from torchjd.aggregation import IMTLG
7 |
8 | from ._asserts import (
9 | assert_expected_structure,
10 | assert_non_differentiable,
11 | assert_permutation_invariant,
12 | )
13 | from ._inputs import scaled_matrices, typical_matrices
14 |
15 | scaled_pairs = [(IMTLG(), matrix) for matrix in scaled_matrices]
16 | typical_pairs = [(IMTLG(), matrix) for matrix in typical_matrices]
17 | requires_grad_pairs = [(IMTLG(), ones_(3, 5, requires_grad=True))]
18 |
19 |
20 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
21 | def test_expected_structure(aggregator: IMTLG, matrix: Tensor):
22 | assert_expected_structure(aggregator, matrix)
23 |
24 |
25 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
26 | def test_permutation_invariant(aggregator: IMTLG, matrix: Tensor):
27 | assert_permutation_invariant(aggregator, matrix)
28 |
29 |
30 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
31 | def test_non_differentiable(aggregator: IMTLG, matrix: Tensor):
32 | assert_non_differentiable(aggregator, matrix)
33 |
34 |
35 | def test_imtlg_zero():
36 | """
37 |     Tests that IMTLG correctly returns the 0 vector in the special case where the input matrix
38 |     consists only of zeros.
39 | """
40 |
41 | A = IMTLG()
42 | J = zeros_(2, 3)
43 | assert_close(A(J), zeros_(3))
44 |
45 |
46 | def test_representations():
47 | A = IMTLG()
48 | assert repr(A) == "IMTLG()"
49 | assert str(A) == "IMTLG"
50 |
--------------------------------------------------------------------------------
/docs/source/examples/rnn.rst:
--------------------------------------------------------------------------------
1 | Recurrent Neural Network (RNN)
2 | ==============================
3 |
4 | When training recurrent neural networks for sequence modelling, we can easily obtain one loss per
5 | element of the output sequences. If the gradients of these losses are likely to conflict, Jacobian
6 | descent can be leveraged to enhance optimization.
7 |
8 | .. code-block:: python
9 | :emphasize-lines: 5-6, 10, 17, 20
10 |
11 | import torch
12 | from torch.nn import RNN
13 | from torch.optim import SGD
14 |
15 | from torchjd.aggregation import UPGrad
16 | from torchjd.autojac import backward
17 |
18 | rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
19 | optimizer = SGD(rnn.parameters(), lr=0.1)
20 | aggregator = UPGrad()
21 |
22 | inputs = torch.randn(8, 5, 3, 10) # 8 batches of 3 sequences of length 5 and of dim 10.
23 | targets = torch.randn(8, 5, 3, 20) # 8 batches of 3 sequences of length 5 and of dim 20.
24 |
25 | for input, target in zip(inputs, targets):
26 | output, _ = rnn(input) # output is of shape [5, 3, 20].
27 | losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.
28 |
29 | optimizer.zero_grad()
30 | backward(losses, aggregator, parallel_chunk_size=1)
31 | optimizer.step()
32 |
33 | .. note::
34 | At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
35 | ``torch.nn.RNN`` when running on CUDA (see `this issue
36 |         `_ for more info), so we advise setting the
37 | ``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve performance, you can
38 | check whether ``parallel_chunk_size=None`` (maximal parallelization) works on your side.
39 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/_accumulate.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 | from ._base import TensorDict, Transform
4 |
5 |
6 | class Accumulate(Transform):
7 | """
8 |     Transform from Gradients to {} (the empty TensorDict), accumulating each gradient into the
9 |     ``grad`` field of its corresponding key.
10 | """
11 |
12 | def __call__(self, gradients: TensorDict) -> TensorDict:
13 | for key in gradients.keys():
14 | _check_expects_grad(key)
15 | if hasattr(key, "grad") and key.grad is not None:
16 | key.grad += gradients[key]
17 | else:
18 | # We clone the value because we do not want subsequent accumulations to also affect
19 | # this value (in case it is still used outside). We do not detach from the
20 | # computation graph because the value can have grad_fn that we want to keep track of
21 | # (in case it was obtained via create_graph=True and a differentiable aggregator).
22 | key.grad = gradients[key].clone()
23 |
24 | return {}
25 |
26 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
27 | return set()
28 |
29 |
30 | def _check_expects_grad(tensor: Tensor) -> None:
31 | if not _expects_grad(tensor):
32 | raise ValueError(
33 |             "Cannot populate the .grad field of a Tensor that does not satisfy: "
34 | "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`."
35 | )
36 |
37 |
38 | def _expects_grad(tensor: Tensor) -> bool:
39 | """
40 | Determines whether a Tensor expects its .grad attribute to be populated.
41 | See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information.
42 | """
43 |
44 | return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)
45 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v6.0.0
4 | hooks:
5 | - id: trailing-whitespace # Trim trailing whitespace at the end of lines.
6 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline.
7 | - id: check-added-large-files # Prevent giant files from being committed.
8 | - id: check-case-conflict # Check for files that would conflict in case-insensitive filesystems.
9 | - id: check-docstring-first # Check a common error of defining a docstring after code.
10 | - id: check-merge-conflict # Check for files that contain merge conflict strings.
11 |
12 | - repo: https://github.com/PyCQA/flake8
13 | rev: 7.3.0
14 | hooks:
15 | - id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually.
16 | args: [
17 |           '--ignore=E501,E203,W503,E402', # Ignore line-too-long (E501), whitespace-before-':' (E203), line-break-before-binary-operator (W503) and module-level-import-not-at-top-of-file (E402) problems.
18 | '--per-file-ignores=*/__init__.py:F401', # Ignore module imported but unused problems in __init__.py files.
19 | ]
20 |
21 | - repo: https://github.com/pycqa/isort
22 | rev: 6.1.0
23 | hooks:
24 | - id: isort # Sort imports.
25 | args: [
26 | --multi-line=3,
27 | --line-length=100,
28 | --trailing-comma,
29 | --force-grid-wrap=0,
30 | --use-parentheses,
31 | --ensure-newline-before-comments,
32 | ]
33 |
34 | - repo: https://github.com/psf/black-pre-commit-mirror
35 | rev: 25.9.0
36 | hooks:
37 | - id: black # Format code.
38 | args: [--line-length=100]
39 |
40 | ci:
41 | autoupdate_commit_msg: 'chore: Update pre-commit hooks'
42 | autoupdate_schedule: quarterly
43 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytest import mark
3 | from torch import Tensor
4 | from utils.tensors import ones_
5 |
6 | from torchjd.aggregation import ConFIG
7 |
8 | from ._asserts import (
9 | assert_expected_structure,
10 | assert_linear_under_scaling,
11 | assert_non_differentiable,
12 | assert_permutation_invariant,
13 | )
14 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
15 |
16 | scaled_pairs = [(ConFIG(), matrix) for matrix in scaled_matrices]
17 | typical_pairs = [(ConFIG(), matrix) for matrix in typical_matrices]
18 | non_strong_pairs = [(ConFIG(), matrix) for matrix in non_strong_matrices]
19 | requires_grad_pairs = [(ConFIG(), ones_(3, 5, requires_grad=True))]
20 |
21 |
22 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
23 | def test_expected_structure(aggregator: ConFIG, matrix: Tensor):
24 | assert_expected_structure(aggregator, matrix)
25 |
26 |
27 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
28 | def test_permutation_invariant(aggregator: ConFIG, matrix: Tensor):
29 | assert_permutation_invariant(aggregator, matrix)
30 |
31 |
32 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
33 | def test_linear_under_scaling(aggregator: ConFIG, matrix: Tensor):
34 | assert_linear_under_scaling(aggregator, matrix)
35 |
36 |
37 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
38 | def test_non_differentiable(aggregator: ConFIG, matrix: Tensor):
39 | assert_non_differentiable(aggregator, matrix)
40 |
41 |
42 | def test_representations():
43 | A = ConFIG()
44 | assert repr(A) == "ConFIG(pref_vector=None)"
45 | assert str(A) == "ConFIG"
46 |
47 | A = ConFIG(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
48 | assert repr(A) == "ConFIG(pref_vector=tensor([1., 2., 3.]))"
49 | assert str(A) == "ConFIG([1., 2., 3.])"
50 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/_inputs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from device import DEVICE
3 | from utils.tensors import zeros_
4 |
5 | from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler
6 |
7 | _normal_dims = [
8 | (1, 1, 1),
9 | (4, 3, 1),
10 | (4, 3, 2),
11 | (4, 3, 3),
12 | (9, 11, 5),
13 | (9, 11, 9),
14 | ]
15 |
16 | _zero_dims = [
17 | (1, 1, 0),
18 | (4, 3, 0),
19 | (9, 11, 0),
20 | ]
21 |
22 | _stationarity_dims = [
23 | (20, 10, 10),
24 | (20, 10, 5),
25 | (20, 10, 1),
26 | (20, 100, 1),
27 | (20, 100, 19),
28 | ]
29 |
30 | _scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15]
31 |
32 | _rng = torch.Generator(device=DEVICE).manual_seed(0)
33 |
34 | matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _normal_dims]
35 | zero_matrices = [zeros_([m, n]) for m, n, _ in _zero_dims]
36 | strong_matrices = [StrongSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
37 | strictly_weak_matrices = [StrictlyWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
38 | non_weak_matrices = [NonWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
39 |
40 | scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
41 |
42 | non_strong_matrices = strictly_weak_matrices + non_weak_matrices
43 | typical_matrices = zero_matrices + matrices + strong_matrices + non_strong_matrices
44 |
45 | scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2]
46 | typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix.shape[0] >= 2]
47 |
48 | # It seems that NashMTL does not work for matrices with 1 row, so we make different matrices for it.
49 | _nashmtl_dims = [
50 | (3, 1, 1),
51 | (4, 3, 1),
52 | (4, 3, 2),
53 | (4, 3, 3),
54 | (9, 11, 5),
55 | (9, 11, 9),
56 | ]
57 | nash_mtl_matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _nashmtl_dims]
58 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_constant.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 | from ._aggregator_bases import WeightedAggregator
4 | from ._utils.str import vector_to_str
5 | from ._weighting_bases import Matrix, Weighting
6 |
7 |
8 | class Constant(WeightedAggregator):
9 | """
10 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of
11 | the rows of the provided matrix, with constant, pre-determined weights.
12 |
13 | :param weights: The weights associated to the rows of the input matrices.
14 | """
15 |
16 | def __init__(self, weights: Tensor):
17 | super().__init__(weighting=ConstantWeighting(weights=weights))
18 | self._weights = weights
19 |
20 | def __repr__(self) -> str:
21 | return f"{self.__class__.__name__}(weights={repr(self._weights)})"
22 |
23 | def __str__(self) -> str:
24 | weights_str = vector_to_str(self._weights)
25 | return f"{self.__class__.__name__}([{weights_str}])"
26 |
27 |
28 | class ConstantWeighting(Weighting[Matrix]):
29 | """
30 | :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
31 | weights.
32 |
33 | :param weights: The weights to return at each call.
34 | """
35 |
36 | def __init__(self, weights: Tensor):
37 | if weights.dim() != 1:
38 | raise ValueError(
39 | "Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = "
40 | f"{weights.shape}`."
41 | )
42 |
43 | super().__init__()
44 | self.weights = weights
45 |
46 | def forward(self, matrix: Tensor) -> Tensor:
47 | self._check_matrix_shape(matrix)
48 | return self.weights
49 |
50 | def _check_matrix_shape(self, matrix: Tensor) -> None:
51 | if matrix.shape[0] != len(self.weights):
52 | raise ValueError(
53 | f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
54 | f"weights). Found `matrix` with {matrix.shape[0]} rows."
55 | )
56 |
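A minimal usage sketch (assuming ``Constant`` is exported from ``torchjd.aggregation``, as its documentation page and tests suggest): the aggregation is simply the weighted sum of the rows of the input matrix.

    import torch
    from torchjd.aggregation import Constant

    A = Constant(weights=torch.tensor([0.25, 0.75]))
    matrix = torch.tensor([[4.0, 0.0], [0.0, 8.0]])
    # 0.25 * [4, 0] + 0.75 * [0, 8] = [1, 6]
    assert torch.allclose(A(matrix), torch.tensor([1.0, 6.0]))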
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_pcgrad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from ._aggregator_bases import GramianWeightedAggregator
5 | from ._utils.non_differentiable import raise_non_differentiable_error
6 | from ._weighting_bases import PSDMatrix, Weighting
7 |
8 |
9 | class PCGrad(GramianWeightedAggregator):
10 | """
11 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of
12 | `Gradient Surgery for Multi-Task Learning `_.
13 | """
14 |
15 | def __init__(self):
16 | super().__init__(PCGradWeighting())
17 |
18 | # This prevents running into a RuntimeError due to modifying stored tensors in place.
19 | self.register_full_backward_pre_hook(raise_non_differentiable_error)
20 |
21 |
22 | class PCGradWeighting(Weighting[PSDMatrix]):
23 | """
24 | :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
25 | :class:`~torchjd.aggregation.PCGrad`.
26 | """
27 |
28 | def forward(self, gramian: Tensor) -> Tensor:
29 |         # Do all computations on the CPU to avoid transferring data between CPU and GPU at each iteration
30 | device = gramian.device
31 | dtype = gramian.dtype
32 | cpu = torch.device("cpu")
33 | gramian = gramian.to(device=cpu)
34 |
35 | dimension = gramian.shape[0]
36 | weights = torch.zeros(dimension, device=cpu, dtype=dtype)
37 |
38 | for i in range(dimension):
39 | permutation = torch.randperm(dimension)
40 | current_weights = torch.zeros(dimension, device=cpu, dtype=dtype)
41 | current_weights[i] = 1.0
42 |
43 | for j in permutation:
44 | if j == i:
45 | continue
46 |
47 | # Compute the inner product between g_i^{PC} and g_j
48 | inner_product = gramian[j] @ current_weights
49 |
50 | if inner_product < 0.0:
51 | current_weights[j] -= inner_product / (gramian[j, j])
52 |
53 | weights = weights + current_weights
54 |
55 | return weights.to(device)
56 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_cagrad.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext as does_not_raise
2 |
3 | from pytest import mark, raises
4 | from torch import Tensor
5 | from utils.contexts import ExceptionContext
6 | from utils.tensors import ones_
7 |
8 | from torchjd.aggregation import CAGrad
9 |
10 | from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable
11 | from ._inputs import scaled_matrices, typical_matrices
12 |
13 | scaled_pairs = [(CAGrad(c=0.5), matrix) for matrix in scaled_matrices]
14 | typical_pairs = [(CAGrad(c=0.5), matrix) for matrix in typical_matrices]
15 | requires_grad_pairs = [(CAGrad(c=0.5), ones_(3, 5, requires_grad=True))]
16 | non_conflicting_pairs_1 = [(CAGrad(c=1.0), matrix) for matrix in typical_matrices]
17 | non_conflicting_pairs_2 = [(CAGrad(c=2.0), matrix) for matrix in typical_matrices]
18 |
19 |
20 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
21 | def test_expected_structure(aggregator: CAGrad, matrix: Tensor):
22 | assert_expected_structure(aggregator, matrix)
23 |
24 |
25 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
26 | def test_non_differentiable(aggregator: CAGrad, matrix: Tensor):
27 | assert_non_differentiable(aggregator, matrix)
28 |
29 |
30 | @mark.parametrize(["aggregator", "matrix"], non_conflicting_pairs_1 + non_conflicting_pairs_2)
31 | def test_non_conflicting(aggregator: CAGrad, matrix: Tensor):
32 |     """Tests that CAGrad is non-conflicting when c >= 1 (this property need not hold when c < 1)."""
33 | assert_non_conflicting(aggregator, matrix)
34 |
35 |
36 | @mark.parametrize(
37 | ["c", "expectation"],
38 | [
39 | (-5.0, raises(ValueError)),
40 | (-1.0, raises(ValueError)),
41 | (0.0, does_not_raise()),
42 | (1.0, does_not_raise()),
43 | (50.0, does_not_raise()),
44 | ],
45 | )
46 | def test_c_check(c: float, expectation: ExceptionContext):
47 | with expectation:
48 | _ = CAGrad(c=c)
49 |
50 |
51 | def test_representations():
52 | A = CAGrad(c=0.5, norm_eps=0.0001)
53 | assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)"
54 | assert str(A) == "CAGrad0.5"
55 |
--------------------------------------------------------------------------------
/docs/source/examples/partial_jd.rst:
--------------------------------------------------------------------------------
1 | Partial Jacobian Descent for IWRM
2 | =================================
3 |
4 | This example demonstrates how to perform Partial Jacobian Descent using TorchJD. This technique
5 | minimizes a vector of per-instance losses by resolving conflicts based only on a submatrix of the
6 | Jacobian — specifically, the portion corresponding to a selected subset of the model's parameters.
7 | This approach offers a trade-off between the precision of the aggregation decision and the
8 | computational cost associated with computing the Gramian of the full Jacobian. For a complete,
9 | non-partial version, see the :doc:`IWRM ` example.
10 |
11 | In this example, our model consists of three ``Linear`` layers separated by ``ReLU`` layers. We
12 | perform the partial descent by considering only the parameters of the last two ``Linear`` layers. By
13 | doing this, we avoid computing the Jacobian and its Gramian with respect to the parameters of the
14 | first ``Linear`` layer, thereby reducing memory usage and computation time.
15 |
16 | .. code-block:: python
17 | :emphasize-lines: 16-18
18 |
19 | import torch
20 | from torch.nn import Linear, MSELoss, ReLU, Sequential
21 | from torch.optim import SGD
22 |
23 | from torchjd.aggregation import UPGradWeighting
24 | from torchjd.autogram import Engine
25 |
26 | X = torch.randn(8, 16, 10)
27 | Y = torch.randn(8, 16)
28 |
29 | model = Sequential(Linear(10, 8), ReLU(), Linear(8, 5), ReLU(), Linear(5, 1))
30 | loss_fn = MSELoss(reduction="none")
31 |
32 | weighting = UPGradWeighting()
33 |
34 | # Create the autogram engine that will compute the Gramian of the
35 |     # Jacobian with respect to the last two Linear layers' parameters.
36 | engine = Engine(model[2:], batch_dim=0)
37 |
38 | params = model.parameters()
39 | optimizer = SGD(params, lr=0.1)
40 |
41 | for x, y in zip(X, Y):
42 | y_hat = model(x).squeeze(dim=1) # shape: [16]
43 | losses = loss_fn(y_hat, y) # shape: [16]
44 | optimizer.zero_grad()
45 | gramian = engine.compute_gramian(losses)
46 | weights = weighting(gramian)
47 | losses.backward(weights)
48 | optimizer.step()
49 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_trimmed_mean.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from ._aggregator_bases import Aggregator
5 |
6 |
7 | class TrimmedMean(Aggregator):
8 | """
9 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning,
10 | that trims the most extreme values of the input matrix, before averaging its rows, as defined in
11 | `Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates
12 | `_.
13 |
14 | :param trim_number: The number of maximum and minimum values to remove from each column of the
15 | input matrix (note that ``2 * trim_number`` values are removed from each column).
16 | """
17 |
18 | def __init__(self, trim_number: int):
19 | super().__init__()
20 | if trim_number < 0:
21 | raise ValueError(
22 |                 "Parameter `trim_number` should be a non-negative integer. Found `trim_number = "
23 | f"{trim_number}`."
24 | )
25 | self.trim_number = trim_number
26 |
27 | def forward(self, matrix: Tensor) -> Tensor:
28 | self._check_is_matrix(matrix)
29 | self._check_matrix_has_enough_rows(matrix)
30 |
31 | n_rows = matrix.shape[0]
32 | n_remaining = n_rows - 2 * self.trim_number
33 |
34 | sorted_matrix, _ = torch.sort(matrix, dim=0)
35 | trimmed = torch.narrow(sorted_matrix, dim=0, start=self.trim_number, length=n_remaining)
36 | vector = trimmed.mean(dim=0)
37 | return vector
38 |
39 | def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
40 | min_rows = 1 + 2 * self.trim_number
41 | n_rows = matrix.shape[0]
42 | if n_rows < min_rows:
43 | raise ValueError(
44 | f"Parameter `matrix` should be a matrix of at least {min_rows} rows "
45 | f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`."
46 | )
47 |
48 | def __repr__(self) -> str:
49 | return f"{self.__class__.__name__}(trim_number={self.trim_number})"
50 |
51 | def __str__(self) -> str:
52 | return f"TM{self.trim_number}"
53 |
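For illustration (a small sketch based on the behaviour described above): with ``trim_number=1``, the largest and the smallest entry of each column are dropped before the column is averaged.

    import torch
    from torchjd.aggregation import TrimmedMean

    A = TrimmedMean(trim_number=1)
    matrix = torch.tensor([[1.0, 10.0], [2.0, -50.0], [100.0, 20.0]])
    # Per column, the min and the max are trimmed, leaving [2.0] and [10.0] to be averaged.
    assert torch.allclose(A(matrix), torch.tensor([2.0, 10.0]))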
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_pcgrad.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 | from torch import Tensor
3 | from torch.testing import assert_close
4 | from utils.tensors import ones_, randn_
5 |
6 | from torchjd.aggregation import PCGrad
7 | from torchjd.aggregation._pcgrad import PCGradWeighting
8 | from torchjd.aggregation._upgrad import UPGradWeighting
9 | from torchjd.aggregation._utils.gramian import compute_gramian
10 |
11 | from ._asserts import assert_expected_structure, assert_non_differentiable
12 | from ._inputs import scaled_matrices, typical_matrices
13 |
14 | scaled_pairs = [(PCGrad(), matrix) for matrix in scaled_matrices]
15 | typical_pairs = [(PCGrad(), matrix) for matrix in typical_matrices]
16 | requires_grad_pairs = [(PCGrad(), ones_(3, 5, requires_grad=True))]
17 |
18 |
19 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
20 | def test_expected_structure(aggregator: PCGrad, matrix: Tensor):
21 | assert_expected_structure(aggregator, matrix)
22 |
23 |
24 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
25 | def test_non_differentiable(aggregator: PCGrad, matrix: Tensor):
26 | assert_non_differentiable(aggregator, matrix)
27 |
28 |
29 | @mark.parametrize(
30 | "shape",
31 | [
32 | (2, 5),
33 | (2, 7),
34 | (2, 9),
35 | (2, 15),
36 | (2, 27),
37 | (2, 68),
38 | (2, 102),
39 | (2, 57),
40 | (2, 1200),
41 | (2, 11100),
42 | ],
43 | )
44 | def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]):
45 | """
46 | Tests that UPGradWeighting of a SumWeighting is equivalent to PCGradWeighting for matrices of 2
47 | rows.
48 | """
49 |
50 | matrix = randn_(shape)
51 | gramian = compute_gramian(matrix)
52 |
53 | pc_grad_weighting = PCGradWeighting()
54 | upgrad_sum_weighting = UPGradWeighting(
55 | ones_((2,)), norm_eps=0.0, reg_eps=0.0, solver="quadprog"
56 | )
57 |
58 | result = pc_grad_weighting(gramian)
59 | expected = upgrad_sum_weighting(gramian)
60 |
61 | assert_close(result, expected, atol=4e-04, rtol=0.0)
62 |
63 |
64 | def test_representations():
65 | A = PCGrad()
66 | assert repr(A) == "PCGrad()"
67 | assert str(A) == "PCGrad"
68 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_trimmed_mean.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext as does_not_raise
2 |
3 | from pytest import mark, raises
4 | from torch import Tensor
5 | from utils.contexts import ExceptionContext
6 | from utils.tensors import ones_
7 |
8 | from torchjd.aggregation import TrimmedMean
9 |
10 | from ._asserts import assert_expected_structure, assert_permutation_invariant
11 | from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows
12 |
13 | scaled_pairs = [(TrimmedMean(trim_number=1), matrix) for matrix in scaled_matrices_2_plus_rows]
14 | typical_pairs = [(TrimmedMean(trim_number=1), matrix) for matrix in typical_matrices_2_plus_rows]
15 |
16 |
17 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
18 | def test_expected_structure(aggregator: TrimmedMean, matrix: Tensor):
19 | assert_expected_structure(aggregator, matrix)
20 |
21 |
22 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
23 | def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor):
24 | assert_permutation_invariant(aggregator, matrix)
25 |
26 |
27 | @mark.parametrize(
28 | ["trim_number", "expectation"],
29 | [
30 | (-5, raises(ValueError)),
31 | (-1, raises(ValueError)),
32 | (0, does_not_raise()),
33 | (1, does_not_raise()),
34 | (5, does_not_raise()),
35 | ],
36 | )
37 | def test_trim_number_check(trim_number: int, expectation: ExceptionContext):
38 | with expectation:
39 | _ = TrimmedMean(trim_number=trim_number)
40 |
41 |
42 | @mark.parametrize(
43 | ["n_rows", "trim_number", "expectation"],
44 | [
45 | (1, 0, does_not_raise()),
46 | (1, 1, raises(ValueError)),
47 | (10, 0, does_not_raise()),
48 | (10, 4, does_not_raise()),
49 | (10, 5, raises(ValueError)),
50 | ],
51 | )
52 | def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext):
53 | matrix = ones_([n_rows, 5])
54 | aggregator = TrimmedMean(trim_number=trim_number)
55 |
56 | with expectation:
57 | _ = aggregator(matrix)
58 |
59 |
60 | def test_representations():
61 | aggregator = TrimmedMean(trim_number=2)
62 | assert repr(aggregator) == "TrimmedMean(trim_number=2)"
63 | assert str(aggregator) == "TM2"
64 |
--------------------------------------------------------------------------------
/src/torchjd/autogram/_edge_registry.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | from torch.autograd.graph import GradientEdge
4 |
5 |
6 | class EdgeRegistry:
7 | """
8 | Tracks `GradientEdge`s and provides a way to efficiently compute a minimally sufficient subset
9 | of leaf edges that are reachable from some given `GradientEdge`s.
10 | """
11 |
12 | def __init__(self) -> None:
13 | self._edges: set[GradientEdge] = set()
14 |
15 | def reset(self) -> None:
16 | self._edges = set()
17 |
18 | def register(self, edge: GradientEdge) -> None:
19 | """
20 | Track the provided edge.
21 |
22 | :param edge: Edge to track.
23 | """
24 | self._edges.add(edge)
25 |
26 | def get_leaf_edges(self, roots: set[GradientEdge]) -> set[GradientEdge]:
27 | """
28 | Compute a minimal subset of edges that yields the same differentiation graph traversal: the
29 | leaf edges. Specifically, this removes edges that are reachable from other edges in the
30 | differentiation graph, avoiding the need to keep gradients in memory for all edges
31 | simultaneously.
32 |
33 | :param roots: Roots of the graph traversal. Modified in-place.
34 | :returns: Minimal subset of leaf edges.
35 | """
36 |
37 | nodes_to_traverse = deque((child, root) for root in roots for child in _next_edges(root))
38 | result = {root for root in roots if root in self._edges}
39 |
40 | excluded = roots
41 | while nodes_to_traverse:
42 | node, origin = nodes_to_traverse.popleft()
43 | if node in self._edges:
44 | result.add(node)
45 | result.discard(origin)
46 | origin = node
47 | for child in _next_edges(node):
48 | if child not in excluded:
49 | nodes_to_traverse.append((child, origin))
50 | excluded.add(child)
51 | return result
52 |
53 |
54 | def _next_edges(edge: GradientEdge) -> list[GradientEdge]:
55 | """
56 | Get the next edges in the autograd graph from the given edge.
57 |
58 | :param edge: The current edge.
59 | """
60 | return [GradientEdge(child, nr) for child, nr in edge.node.next_functions if child is not None]
61 |
--------------------------------------------------------------------------------
/tests/unit/autojac/_transform/test_init.py:
--------------------------------------------------------------------------------
1 | from pytest import raises
2 | from utils.dict_assertions import assert_tensor_dicts_are_close
3 | from utils.tensors import tensor_
4 |
5 | from torchjd.autojac._transform import Init, RequirementError
6 |
7 |
8 | def test_single_input():
9 | """
10 | Tests that when there is a single key to initialize, the Init transform creates a TensorDict
11 | whose value is a tensor full of ones, of the same shape as its key.
12 | """
13 |
14 | key = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
15 | input = {}
16 |
17 | init = Init({key})
18 |
19 | output = init(input)
20 | expected_output = {key: tensor_([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])}
21 |
22 | assert_tensor_dicts_are_close(output, expected_output)
23 |
24 |
25 | def test_multiple_inputs():
26 | """
27 | Tests that when there are several keys to initialize, the Init transform creates a TensorDict
28 | whose values are tensors full of ones, of the same shape as their corresponding keys.
29 | """
30 |
31 | key1 = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
32 | key2 = tensor_([1.0, 3.0, 5.0])
33 | input = {}
34 |
35 | init = Init({key1, key2})
36 |
37 | output = init(input)
38 | expected = {
39 | key1: tensor_([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
40 | key2: tensor_([1.0, 1.0, 1.0]),
41 | }
42 | assert_tensor_dicts_are_close(output, expected)
43 |
44 |
45 | def test_conjunction_of_inits_is_init():
46 | """
47 | Tests that the conjunction of 2 Init transforms is equivalent to a single Init transform with
48 | multiple keys.
49 | """
50 |
51 | x1 = tensor_(5.0)
52 | x2 = tensor_(6.0)
53 | input = {}
54 |
55 | init1 = Init({x1})
56 | init2 = Init({x2})
57 | conjunction_of_inits = init1 | init2
58 | init = Init({x1, x2})
59 |
60 | output = conjunction_of_inits(input)
61 | expected_output = init(input)
62 |
63 | assert_tensor_dicts_are_close(output, expected_output)
64 |
65 |
66 | def test_check_keys():
67 | """Tests that the `check_keys` method works correctly: the input_keys should be empty."""
68 |
69 | key = tensor_([1.0])
70 | init = Init({key})
71 |
72 | output_keys = init.check_keys(set())
73 | assert output_keys == {key}
74 |
75 | with raises(RequirementError):
76 | init.check_keys({key})
77 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/_stack.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from ._base import TensorDict, Transform
7 | from ._materialize import materialize
8 |
9 |
10 | class Stack(Transform):
11 | """
12 | Transform from something to Jacobians, applying several transforms to the same input, and
13 | combining the results (by stacking) into a single TensorDict.
14 |
15 | The set of keys of the resulting dict is the union of the sets of keys of the input dicts.
16 |
17 | :param transforms: The transforms to apply. They should all be from the same thing and map to
18 | Gradients. Their outputs may have different sets of keys. If a key is absent in some output
19 | dicts, the corresponding stacked tensor is filled with zeroes at the positions corresponding
20 | to those dicts.
21 | """
22 |
23 | def __init__(self, transforms: Sequence[Transform]):
24 | self.transforms = transforms
25 |
26 | def __call__(self, input: TensorDict) -> TensorDict:
27 | results = [transform(input) for transform in self.transforms]
28 | result = _stack(results)
29 | return result
30 |
31 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
32 | return {key for transform in self.transforms for key in transform.check_keys(input_keys)}
33 |
34 |
35 | def _stack(gradient_dicts: list[TensorDict]) -> TensorDict:
36 |     # It is important to deduplicate the keys before computing their associated stacked tensors.
37 |     # Otherwise, some computations would be duplicated. Therefore, we first compute unique_keys,
38 |     # and only then do we compute the stacked tensors.
39 | union: TensorDict = {}
40 | for d in gradient_dicts:
41 | union |= d
42 | unique_keys = union.keys()
43 | result = {key: _stack_one_key(gradient_dicts, key) for key in unique_keys}
44 | return result
45 |
46 |
47 | def _stack_one_key(gradient_dicts: list[TensorDict], input: Tensor) -> Tensor:
48 | """Makes the stacked tensor corresponding to a given key, from a list of tensor dicts."""
49 |
50 | optional_gradients = [gradients.get(input, None) for gradients in gradient_dicts]
51 | gradients = materialize(optional_gradients, [input] * len(optional_gradients))
52 | jacobian = torch.stack(gradients, dim=0)
53 | return jacobian
54 |
--------------------------------------------------------------------------------
/tests/unit/autojac/_transform/test_select.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytest import raises
3 | from utils.dict_assertions import assert_tensor_dicts_are_close
4 | from utils.tensors import tensor_
5 |
6 | from torchjd.autojac._transform import RequirementError, Select
7 |
8 |
9 | def test_partition():
10 | """
11 |     Tests that the Select transform works correctly by applying to a TensorDict 2 different Selects
12 |     whose key sets form a partition of the keys of the TensorDict.
13 | """
14 |
15 | key1 = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
16 | key2 = tensor_([1.0, 3.0, 5.0])
17 | key3 = tensor_(2.0)
18 | value1 = torch.ones_like(key1)
19 | value2 = torch.ones_like(key2)
20 | value3 = torch.ones_like(key3)
21 | input = {key1: value1, key2: value2, key3: value3}
22 |
23 | select1 = Select({key1, key2})
24 | select2 = Select({key3})
25 |
26 | output1 = select1(input)
27 | expected_output1 = {key1: value1, key2: value2}
28 |
29 | assert_tensor_dicts_are_close(output1, expected_output1)
30 |
31 | output2 = select2(input)
32 | expected_output2 = {key3: value3}
33 |
34 | assert_tensor_dicts_are_close(output2, expected_output2)
35 |
36 |
37 | def test_conjunction_of_selects_is_select():
38 | """
39 | Tests that the conjunction of 2 Select transforms is equivalent to directly using a Select with
40 | the union of the keys of the 2 Selects.
41 | """
42 |
43 | x1 = tensor_(5.0)
44 | x2 = tensor_(6.0)
45 | x3 = tensor_(7.0)
46 | input = {x1: torch.ones_like(x1), x2: torch.ones_like(x2), x3: torch.ones_like(x3)}
47 |
48 | select1 = Select({x1})
49 | select2 = Select({x2})
50 | conjunction_of_selects = select1 | select2
51 | select = Select({x1, x2})
52 |
53 | output = conjunction_of_selects(input)
54 | expected_output = select(input)
55 |
56 | assert_tensor_dicts_are_close(output, expected_output)
57 |
58 |
59 | def test_check_keys():
60 | """
61 | Tests that the `check_keys` method works correctly: the set of keys to select should be a subset
62 | of the set of required_keys.
63 | """
64 |
65 | key1 = tensor_([1.0])
66 | key2 = tensor_([2.0])
67 | key3 = tensor_([3.0])
68 |
69 | output_keys = Select({key1, key2}).check_keys({key1, key2, key3})
70 | assert output_keys == {key1, key2}
71 |
72 | with raises(RequirementError):
73 | Select({key1, key2}).check_keys({key1})
74 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_mgda.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 | from torch import Tensor
3 | from torch.testing import assert_close
4 | from utils.tensors import ones_, randn_
5 |
6 | from torchjd.aggregation import MGDA
7 | from torchjd.aggregation._mgda import MGDAWeighting
8 | from torchjd.aggregation._utils.gramian import compute_gramian
9 |
10 | from ._asserts import (
11 | assert_expected_structure,
12 | assert_non_conflicting,
13 | assert_permutation_invariant,
14 | )
15 | from ._inputs import scaled_matrices, typical_matrices
16 |
17 | scaled_pairs = [(MGDA(), matrix) for matrix in scaled_matrices]
18 | typical_pairs = [(MGDA(), matrix) for matrix in typical_matrices]
19 |
20 |
21 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
22 | def test_expected_structure(aggregator: MGDA, matrix: Tensor):
23 | assert_expected_structure(aggregator, matrix)
24 |
25 |
26 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
27 | def test_non_conflicting(aggregator: MGDA, matrix: Tensor):
28 | assert_non_conflicting(aggregator, matrix)
29 |
30 |
31 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
32 | def test_permutation_invariant(aggregator: MGDA, matrix: Tensor):
33 | assert_permutation_invariant(aggregator, matrix)
34 |
35 |
36 | @mark.parametrize(
37 | "shape",
38 | [
39 | (5, 7),
40 | (9, 37),
41 | (2, 14),
42 | (32, 114),
43 | (50, 100),
44 | ],
45 | )
46 | def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]):
47 | matrix = randn_(shape)
48 | gramian = compute_gramian(matrix)
49 |
50 | weighting = MGDAWeighting(epsilon=1e-05, max_iters=1000)
51 | weights = weighting(gramian)
52 |
53 | output_direction = gramian @ weights # Stationarity
54 | lamb = -weights @ output_direction # Complementary slackness
55 | mu = output_direction + lamb
56 |
57 | # Primal feasibility
58 | positive_weights = weights[weights >= 0]
59 | assert_close(positive_weights.norm(), weights.norm())
60 |
61 | weights_sum = weights.sum()
62 | assert_close(weights_sum, ones_([]))
63 |
64 | # Dual feasibility
65 | positive_mu = mu[mu >= 0]
66 | assert_close(positive_mu.norm(), mu.norm(), atol=1e-02, rtol=0.0)
67 |
68 |
69 | def test_representations():
70 | A = MGDA(epsilon=0.001, max_iters=100)
71 | assert repr(A) == "MGDA(epsilon=0.001, max_iters=100)"
72 | assert str(A) == "MGDA"
73 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/_grad.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from ._differentiate import Differentiate
7 | from ._ordered_set import OrderedSet
8 |
9 |
10 | class Grad(Differentiate):
11 | """
12 | Transform from Gradients to Gradients, computing the gradient of each output element with
13 |     respect to each input tensor, and applying the linear transformations represented by the
14 |     provided grad_outputs to the results.
15 |
16 | :param outputs: Tensors to differentiate.
17 | :param inputs: Tensors with respect to which we differentiate.
18 | :param retain_graph: If False, the graph used to compute the grads will be freed. Defaults to
19 | False.
20 | :param create_graph: If True, graph of the derivative will be constructed, allowing to compute
21 | higher order derivative products. Defaults to False.
22 |
23 | .. note:: The order of outputs and inputs only matters because we have no guarantee that
24 |         torch.autograd.grad is *exactly* equivariant to permutations of the inputs and invariant to
25 |         permutations of the outputs (together with their corresponding grad_outputs).
26 | """
27 |
28 | def __init__(
29 | self,
30 | outputs: OrderedSet[Tensor],
31 | inputs: OrderedSet[Tensor],
32 | retain_graph: bool = False,
33 | create_graph: bool = False,
34 | ):
35 | super().__init__(outputs, inputs, retain_graph, create_graph)
36 |
37 | def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
38 | """
39 | Computes the gradient of each output element with respect to each input tensor, and applies
40 | the linear transformations represented by the grad_outputs to the results.
41 |
42 | Returns one gradient per input, corresponding to the sum of the scaled gradients with
43 | respect to this input.
44 |
45 | :param grad_outputs: The sequence of tensors to scale the obtained gradients with. Its
46 | length should be equal to the length of ``outputs``. Each grad_output should have the
47 | same shape as the corresponding output.
48 | """
49 |
50 | if len(self.inputs) == 0:
51 | return tuple()
52 |
53 | if len(self.outputs) == 0:
54 | return tuple(torch.zeros_like(input) for input in self.inputs)
55 |
56 | grads = self._get_vjp(grad_outputs, self.retain_graph)
57 | return grads
58 |
--------------------------------------------------------------------------------
/tests/unit/autogram/test_gramian_utils.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 | from torch.testing import assert_close
3 | from utils.forward_backwards import compute_gramian
4 | from utils.tensors import randn_
5 |
6 | from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian
7 |
8 |
9 | @mark.parametrize(
10 | ["original_shape", "target_shape"],
11 | [
12 | ([], []),
13 | ([], [1, 1]),
14 | ([1], []),
15 | ([12], [2, 3, 2]),
16 | ([12], [4, 3]),
17 | ([12], [12]),
18 | ([4, 3], [12]),
19 | ([4, 3], [2, 3, 2]),
20 | ([4, 3], [3, 4]),
21 | ([4, 3], [4, 3]),
22 | ([6, 7, 9], [378]),
23 | ([6, 7, 9], [9, 42]),
24 | ([6, 7, 9], [2, 7, 27]),
25 | ([6, 7, 9], [6, 7, 9]),
26 | ],
27 | )
28 | def test_reshape_gramian(original_shape: list[int], target_shape: list[int]):
29 | """Tests that reshape_gramian is such that compute_gramian is equivariant to a reshape."""
30 |
31 | original_matrix = randn_(original_shape + [2])
32 | target_matrix = original_matrix.reshape(target_shape + [2])
33 |
34 | original_gramian = compute_gramian(original_matrix)
35 | target_gramian = compute_gramian(target_matrix)
36 |
37 | reshaped_gramian = reshape_gramian(original_gramian, target_shape)
38 |
39 | assert_close(reshaped_gramian, target_gramian)
40 |
41 |
42 | @mark.parametrize(
43 | ["shape", "source", "destination"],
44 | [
45 | ([], [], []),
46 | ([1], [0], [0]),
47 | ([1], [], []),
48 | ([1, 1], [], []),
49 | ([1, 1], [1], [0]),
50 | ([6, 7], [1], [0]),
51 | ([3, 1], [0, 1], [1, 0]),
52 | ([1, 1, 1], [], []),
53 | ([3, 2, 5], [], []),
54 | ([1, 1, 1], [2], [0]),
55 | ([3, 2, 5], [1], [2]),
56 | ([2, 2, 3], [0, 2], [1, 0]),
57 | ([2, 2, 3], [0, 2, 1], [1, 0, 2]),
58 | ],
59 | )
60 | def test_movedim_gramian(shape: list[int], source: list[int], destination: list[int]):
61 | """Tests that movedim_gramian is such that compute_gramian is equivariant to a movedim."""
62 |
63 | original_matrix = randn_(shape + [2])
64 | target_matrix = original_matrix.movedim(source, destination)
65 |
66 | original_gramian = compute_gramian(original_matrix)
67 | target_gramian = compute_gramian(target_matrix)
68 |
69 | moveddim_gramian = movedim_gramian(original_gramian, source, destination)
70 |
71 | assert_close(moveddim_gramian, target_gramian)
72 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_utils/dual_cone.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import numpy as np
4 | import torch
5 | from qpsolvers import solve_qp
6 | from torch import Tensor
7 |
8 |
9 | def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
10 | """
11 | Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
12 | rows of a matrix whose Gramian is provided.
13 |
14 | :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
15 | :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
16 | :param solver: The quadratic programming solver to use.
17 | :return: A tensor of projection weights with the same shape as `U`.
18 | """
19 |
20 | G_ = _to_array(G)
21 | U_ = _to_array(U)
22 |
23 | W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_)
24 |
25 | return torch.as_tensor(W, device=G.device, dtype=G.dtype)
26 |
27 |
28 | def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray:
29 | r"""
30 | Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
31 | given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies
32 | `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1].
33 |
34 | By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program:
35 | minimize v^T G v
36 | subject to u \preceq v
37 |
38 | Reference:
39 | [1] `Jacobian Descent For Multi-Objective Optimization `_.
40 |
41 | :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to
42 | project.
43 | :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`. It must be
44 | symmetric and positive definite.
45 | :param solver: The quadratic programming solver to use.
46 | """
47 |
48 | m = G.shape[0]
49 | w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver)
50 |
51 | if w is None: # This may happen when G has large values.
52 | raise ValueError("Failed to solve the quadratic programming problem.")
53 |
54 | return w
55 |
56 |
57 | def _to_array(tensor: Tensor) -> np.ndarray:
58 | """Transforms a tensor into a numpy array with float64 dtype."""
59 |
60 | return tensor.cpu().detach().numpy().astype(np.float64)
61 |
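A sanity-check sketch of the projection above (hypothetical example, requires the ``quadprog`` solver to be installed): with ``G`` equal to the identity, the quadratic program reduces to minimizing ``||v||^2`` subject to ``v >= u``, whose solution is the element-wise positive part of ``u``.

    import torch
    from torchjd.aggregation._utils.dual_cone import project_weights

    u = torch.tensor([0.5, -2.0, 1.0])
    w = project_weights(u, torch.eye(3), solver="quadprog")
    # With G = I, projecting onto the dual cone just clamps the negative weights to zero.
    assert torch.allclose(w, torch.relu(u), atol=1e-6)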
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_nash_mtl.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 | from torch import Tensor
3 | from torch.testing import assert_close
4 | from utils.tensors import ones_, randn_
5 |
6 | from torchjd.aggregation import NashMTL
7 |
8 | from ._asserts import assert_expected_structure, assert_non_differentiable
9 | from ._inputs import nash_mtl_matrices
10 |
11 |
12 | def _make_aggregator(matrix: Tensor) -> NashMTL:
13 | return NashMTL(n_tasks=matrix.shape[0])
14 |
15 |
16 | standard_pairs = [(_make_aggregator(matrix), matrix) for matrix in nash_mtl_matrices]
17 | requires_grad_pairs = [(NashMTL(n_tasks=3), ones_(3, 5, requires_grad=True))]
18 |
19 |
20 | # Note that as opposed to most aggregators, the expected structure is only tested with non-scaled
21 | # matrices, and with matrices of > 1 row. Otherwise, NashMTL fails.
22 | @mark.filterwarnings(
23 | "ignore:Solution may be inaccurate.",
24 | "ignore:You are solving a parameterized problem that is not DPP.",
25 | )
26 | @mark.parametrize(["aggregator", "matrix"], standard_pairs)
27 | def test_expected_structure(aggregator: NashMTL, matrix: Tensor):
28 | assert_expected_structure(aggregator, matrix)
29 |
30 |
31 | @mark.filterwarnings("ignore:You are solving a parameterized problem that is not DPP.")
32 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
33 | def test_non_differentiable(aggregator: NashMTL, matrix: Tensor):
34 | assert_non_differentiable(aggregator, matrix)
35 |
36 |
37 | @mark.filterwarnings("ignore:You are solving a parameterized problem that is not DPP.")
38 | def test_nash_mtl_reset():
39 | """
40 | Tests that the reset method of NashMTL correctly resets its internal state, by verifying that
41 | the result is the same after reset as it is right after instantiation.
42 |
43 | To ensure that the aggregations are not all the same, we create different matrices to aggregate.
44 | """
45 |
46 | matrices = [randn_(3, 5) for _ in range(4)]
47 | aggregator = NashMTL(n_tasks=3, update_weights_every=3)
48 | expecteds = [aggregator(matrix) for matrix in matrices]
49 |
50 | aggregator.reset()
51 | results = [aggregator(matrix) for matrix in matrices]
52 |
53 | for result, expected in zip(results, expecteds):
54 | assert_close(result, expected)
55 |
56 |
57 | def test_representations():
58 | A = NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)
59 | assert repr(A) == "NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)"
60 | assert str(A) == "NashMTL"
61 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_dualproj.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytest import mark
3 | from torch import Tensor
4 | from utils.tensors import ones_
5 |
6 | from torchjd.aggregation import DualProj
7 |
8 | from ._asserts import (
9 | assert_expected_structure,
10 | assert_non_conflicting,
11 | assert_non_differentiable,
12 | assert_permutation_invariant,
13 | assert_strongly_stationary,
14 | )
15 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
16 |
17 | scaled_pairs = [(DualProj(), matrix) for matrix in scaled_matrices]
18 | typical_pairs = [(DualProj(), matrix) for matrix in typical_matrices]
19 | non_strong_pairs = [(DualProj(), matrix) for matrix in non_strong_matrices]
20 | requires_grad_pairs = [(DualProj(), ones_(3, 5, requires_grad=True))]
21 |
22 |
23 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
24 | def test_expected_structure(aggregator: DualProj, matrix: Tensor):
25 | assert_expected_structure(aggregator, matrix)
26 |
27 |
28 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
29 | def test_non_conflicting(aggregator: DualProj, matrix: Tensor):
30 | assert_non_conflicting(aggregator, matrix, atol=5e-05, rtol=5e-05)
31 |
32 |
33 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
34 | def test_permutation_invariant(aggregator: DualProj, matrix: Tensor):
35 | assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07)
36 |
37 |
38 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
39 | def test_strongly_stationary(aggregator: DualProj, matrix: Tensor):
40 | assert_strongly_stationary(aggregator, matrix, threshold=3e-03)
41 |
42 |
43 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
44 | def test_non_differentiable(aggregator: DualProj, matrix: Tensor):
45 | assert_non_differentiable(aggregator, matrix)
46 |
47 |
48 | def test_representations():
49 | A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
50 | assert (
51 | repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"
52 | )
53 | assert str(A) == "DualProj"
54 |
55 | A = DualProj(
56 | pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"),
57 | norm_eps=0.0001,
58 | reg_eps=0.0001,
59 | solver="quadprog",
60 | )
61 | assert (
62 | repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, "
63 | "solver='quadprog')"
64 | )
65 | assert str(A) == "DualProj([1., 2., 3.])"
66 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_krum.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext as does_not_raise
2 |
3 | from pytest import mark, raises
4 | from torch import Tensor
5 | from utils.contexts import ExceptionContext
6 | from utils.tensors import ones_
7 |
8 | from torchjd.aggregation import Krum
9 |
10 | from ._asserts import assert_expected_structure
11 | from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows
12 |
13 | scaled_pairs = [(Krum(n_byzantine=1), matrix) for matrix in scaled_matrices_2_plus_rows]
14 | typical_pairs = [(Krum(n_byzantine=1), matrix) for matrix in typical_matrices_2_plus_rows]
15 |
16 |
17 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
18 | def test_expected_structure(aggregator: Krum, matrix: Tensor):
19 | assert_expected_structure(aggregator, matrix)
20 |
21 |
22 | @mark.parametrize(
23 | ["n_byzantine", "expectation"],
24 | [
25 | (-5, raises(ValueError)),
26 | (-1, raises(ValueError)),
27 | (0, does_not_raise()),
28 | (1, does_not_raise()),
29 | (5, does_not_raise()),
30 | ],
31 | )
32 | def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext):
33 | with expectation:
34 | _ = Krum(n_byzantine=n_byzantine, n_selected=1)
35 |
36 |
37 | @mark.parametrize(
38 | ["n_selected", "expectation"],
39 | [
40 | (-5, raises(ValueError)),
41 | (-1, raises(ValueError)),
42 | (0, raises(ValueError)),
43 | (1, does_not_raise()),
44 | (5, does_not_raise()),
45 | ],
46 | )
47 | def test_n_selected_check(n_selected: int, expectation: ExceptionContext):
48 | with expectation:
49 | _ = Krum(n_byzantine=1, n_selected=n_selected)
50 |
51 |
52 | @mark.parametrize(
53 | ["n_byzantine", "n_selected", "n_rows", "expectation"],
54 | [
55 | (1, 1, 3, raises(ValueError)),
56 | (1, 1, 4, does_not_raise()),
57 | (1, 4, 4, does_not_raise()),
58 | (12, 4, 14, raises(ValueError)),
59 | (12, 4, 15, does_not_raise()),
60 | (12, 15, 15, does_not_raise()),
61 | (12, 16, 15, raises(ValueError)),
62 | ],
63 | )
64 | def test_matrix_shape_check(
65 | n_byzantine: int, n_selected: int, n_rows: int, expectation: ExceptionContext
66 | ):
67 | aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected)
68 | matrix = ones_([n_rows, 5])
69 |
70 | with expectation:
71 | _ = aggregator(matrix)
72 |
73 |
74 | def test_representations():
75 | A = Krum(n_byzantine=1, n_selected=2)
76 | assert repr(A) == "Krum(n_byzantine=1, n_selected=2)"
77 | assert str(A) == "Krum1-2"
78 |
--------------------------------------------------------------------------------
/docs/source/examples/index.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
4 | This section contains some usage examples for TorchJD.
5 |
6 | - :doc:`Basic Usage <basic_usage>` provides a toy example using :doc:`torchjd.backward
7 | <../docs/autojac/backward>` to make a step of Jacobian descent with the :doc:`UPGrad
8 | <../docs/aggregation/upgrad>` aggregator.
9 | - :doc:`Instance-Wise Risk Minimization (IWRM) <iwrm>` provides an example in which we minimize the
10 | vector of per-instance losses, using stochastic sub-Jacobian descent (SSJD). It is compared to the
11 | usual minimization of the average loss, called empirical risk minimization (ERM), using stochastic
12 | gradient descent (SGD).
13 | - :doc:`Partial Jacobian Descent for IWRM <partial_jd>` provides an example in which we minimize the
14 | vector of per-instance losses using stochastic sub-Jacobian descent, similar to our :doc:`IWRM <iwrm>`
15 | example. However, this method bases the aggregation decision on the Jacobian of the losses with respect
16 | to **only a subset** of the model's parameters, offering a trade-off between computational cost and
17 | aggregation precision.
18 | - :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian
19 | descent is used to optimize the vector of per-task losses of a multi-task model, using the
20 | dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
21 | - :doc:`Instance-Wise Multi-Task Learning (IWMTL) <iwmtl>` shows how to combine multi-task learning
22 | with instance-wise risk minimization: one loss per task and per element of the batch, using the
23 | :doc:`autogram.Engine <../docs/autogram/engine>` and a :doc:`GeneralizedWeighting
24 | <../docs/aggregation/index>`.
25 | - :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training,
26 | with one loss per output sequence element.
27 | - :doc:`Monitoring Aggregations <monitoring>` shows how to monitor the aggregation performed by the
28 | aggregator, to check if Jacobian descent is prescribed for your use-case.
29 | - :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
30 | TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
31 | ``LightningModule`` optimized by Jacobian descent.
32 | - :doc:`Automatic Mixed Precision <amp>` shows how to combine mixed precision training with TorchJD.
33 |
34 | .. toctree::
35 | :hidden:
36 |
37 | basic_usage.rst
38 | iwrm.rst
39 | partial_jd.rst
40 | mtl.rst
41 | iwmtl.rst
42 | rnn.rst
43 | monitoring.rst
44 | lightning_integration.rst
45 | amp.rst
46 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import random as rand
2 | from contextlib import nullcontext
3 |
4 | import torch
5 | from device import DEVICE
6 | from pytest import RaisesExc, fixture, mark
7 | from torch import Tensor
8 | from utils.architectures import ModuleFactory
9 |
10 | from torchjd.aggregation import Aggregator, Weighting
11 |
12 |
13 | @fixture(autouse=True)
14 | def fix_randomness() -> None:
15 | rand.seed(0)
16 | torch.manual_seed(0)
17 |
18 | # Only force to use deterministic algorithms on CPU.
19 | # This is because the CI currently runs only on CPU, so we don't really need perfect
20 | # reproducibility on GPU. We also use GPU to benchmark algorithms, and we would rather have them
21 | # use non-deterministic but faster algorithms.
22 | if DEVICE.type == "cpu":
23 | torch.use_deterministic_algorithms(True)
24 |
25 |
26 | def pytest_addoption(parser):
27 | parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
28 |
29 |
30 | def pytest_configure(config):
31 | config.addinivalue_line("markers", "slow: mark test as slow to run")
32 | config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda")
33 |
34 |
35 | def pytest_collection_modifyitems(config, items):
36 | skip_slow = mark.skip(reason="Slow test. Use --runslow to run it.")
37 | xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}")
38 | for item in items:
39 | if "slow" in item.keywords and not config.getoption("--runslow"):
40 | item.add_marker(skip_slow)
41 | if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"):
42 | item.add_marker(xfail_cuda)
43 |
44 |
45 | def pytest_make_parametrize_id(config, val, argname):
46 | MAX_SIZE = 40
47 | optional_string = None # Returning None means using pytest's way of making the string
48 |
49 | if isinstance(val, (Aggregator, ModuleFactory, Weighting)):
50 | optional_string = str(val)
51 | elif isinstance(val, Tensor):
52 | optional_string = "T" + str(list(val.shape)) # T to indicate that it's a tensor
53 | elif isinstance(val, (tuple, list, set)) and len(val) < 20:
54 | optional_string = str(val)
55 | elif isinstance(val, RaisesExc):
56 | optional_string = " or ".join([f"{exc.__name__}" for exc in val.expected_exceptions])
57 | elif isinstance(val, nullcontext):
58 | optional_string = "does_not_raise()"
59 |
60 | if isinstance(optional_string, str) and len(optional_string) > MAX_SIZE:
61 | optional_string = optional_string[: MAX_SIZE - 3] + "+++" # Can't use dots with pytest
62 |
63 | return optional_string
64 |
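65 |
66 | # Usage sketch (hypothetical test names) of the markers registered above:
67 | #
68 | #     from pytest import mark
69 | #
70 | #     @mark.slow
71 | #     def test_expensive_case():
72 | #         ...  # skipped unless pytest is invoked with --runslow
73 | #
74 | #     @mark.xfail_if_cuda
75 | #     def test_known_gpu_limitation():
76 | #         ...  # marked as an expected failure when DEVICE is a cuda device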
--------------------------------------------------------------------------------
/docs/source/icons/TorchJD_logo.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/torchjd/autogram/_gramian_computer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 | from torch.utils._pytree import PyTree
6 |
7 | from torchjd.autogram._jacobian_computer import JacobianComputer
8 |
9 |
10 | class GramianComputer(ABC):
11 | @abstractmethod
12 | def __call__(
13 | self,
14 | rg_outputs: tuple[Tensor, ...],
15 | grad_outputs: tuple[Tensor, ...],
16 | args: tuple[PyTree, ...],
17 | kwargs: dict[str, PyTree],
18 | ) -> Optional[Tensor]:
19 | """Compute what we can for a module and optionally return the gramian if it's ready."""
20 |
21 | def track_forward_call(self) -> None:
22 | """Track that the module's forward was called. Necessary in some implementations."""
23 |
24 | def reset(self):
25 | """Reset state if any. Necessary in some implementations."""
26 |
27 |
28 | class JacobianBasedGramianComputer(GramianComputer, ABC):
29 | def __init__(self, jacobian_computer: JacobianComputer):
30 | self.jacobian_computer = jacobian_computer
31 |
32 | @staticmethod
33 | def _to_gramian(jacobian: Tensor) -> Tensor:
34 | return jacobian @ jacobian.T
35 |
36 |
37 | class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
38 | """
39 | Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
40 | the gramian.
41 | """
42 |
43 | def __init__(self, jacobian_computer: JacobianComputer):
44 | super().__init__(jacobian_computer)
45 | self.remaining_counter = 0
46 | self.summed_jacobian: Optional[Tensor] = None
47 |
48 | def reset(self) -> None:
49 | self.remaining_counter = 0
50 | self.summed_jacobian = None
51 |
52 | def track_forward_call(self) -> None:
53 | self.remaining_counter += 1
54 |
55 | def __call__(
56 | self,
57 | rg_outputs: tuple[Tensor, ...],
58 | grad_outputs: tuple[Tensor, ...],
59 | args: tuple[PyTree, ...],
60 | kwargs: dict[str, PyTree],
61 | ) -> Optional[Tensor]:
62 | """Compute what we can for a module and optionally return the gramian if it's ready."""
63 |
64 | jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
65 |
66 | if self.summed_jacobian is None:
67 | self.summed_jacobian = jacobian_matrix
68 | else:
69 | self.summed_jacobian += jacobian_matrix
70 |
71 | self.remaining_counter -= 1
72 |
73 | if self.remaining_counter == 0:
74 | gramian = self._to_gramian(self.summed_jacobian)
75 | del self.summed_jacobian
76 | return gramian
77 | else:
78 | return None
79 |
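80 |
81 | # Rough usage sketch (with hypothetical placeholder arguments): for a module whose forward is
82 | # called twice in one graph, the gramian is only returned once both backward contributions have
83 | # been accumulated into the summed Jacobian.
84 | #
85 | #     computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
86 | #     computer.track_forward_call()  # first forward call of the module
87 | #     computer.track_forward_call()  # second forward call of the module
88 | #     assert computer(rg_outputs_1, grad_outputs_1, args_1, kwargs_1) is None  # still waiting
89 | #     gramian = computer(rg_outputs_2, grad_outputs_2, args_2, kwargs_2)  # both seen: gramian ready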
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_upgrad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytest import mark
3 | from torch import Tensor
4 | from utils.tensors import ones_
5 |
6 | from torchjd.aggregation import UPGrad
7 |
8 | from ._asserts import (
9 | assert_expected_structure,
10 | assert_linear_under_scaling,
11 | assert_non_conflicting,
12 | assert_non_differentiable,
13 | assert_permutation_invariant,
14 | assert_strongly_stationary,
15 | )
16 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
17 |
18 | scaled_pairs = [(UPGrad(), matrix) for matrix in scaled_matrices]
19 | typical_pairs = [(UPGrad(), matrix) for matrix in typical_matrices]
20 | non_strong_pairs = [(UPGrad(), matrix) for matrix in non_strong_matrices]
21 | requires_grad_pairs = [(UPGrad(), ones_(3, 5, requires_grad=True))]
22 |
23 |
24 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
25 | def test_expected_structure(aggregator: UPGrad, matrix: Tensor):
26 | assert_expected_structure(aggregator, matrix)
27 |
28 |
29 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
30 | def test_non_conflicting(aggregator: UPGrad, matrix: Tensor):
31 | assert_non_conflicting(aggregator, matrix, atol=3e-04, rtol=3e-04)
32 |
33 |
34 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
35 | def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor):
36 | assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=4e-07, rtol=4e-07)
37 |
38 |
39 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
40 | def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor):
41 | assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=3e-02, rtol=3e-02)
42 |
43 |
44 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
45 | def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor):
46 | assert_strongly_stationary(aggregator, matrix, threshold=5e-03)
47 |
48 |
49 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
50 | def test_non_differentiable(aggregator: UPGrad, matrix: Tensor):
51 | assert_non_differentiable(aggregator, matrix)
52 |
53 |
54 | def test_representations():
55 | A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
56 | assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"
57 | assert str(A) == "UPGrad"
58 |
59 | A = UPGrad(
60 | pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"),
61 | norm_eps=0.0001,
62 | reg_eps=0.0001,
63 | solver="quadprog",
64 | )
65 | assert (
66 | repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, "
67 | "solver='quadprog')"
68 | )
69 | assert str(A) == "UPGrad([1., 2., 3.])"
70 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_aggregator_bases.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from torch import Tensor, nn
4 |
5 | from ._utils.gramian import compute_gramian
6 | from ._weighting_bases import Matrix, PSDMatrix, Weighting
7 |
8 |
9 | class Aggregator(nn.Module, ABC):
10 | r"""
11 | Abstract base class for all aggregators. It has the role of aggregating matrices of dimension
12 | :math:`m \times n` into row vectors of dimension :math:`n`.
13 | """
14 |
15 | def __init__(self):
16 | super().__init__()
17 |
18 | @staticmethod
19 | def _check_is_matrix(matrix: Tensor) -> None:
20 | if len(matrix.shape) != 2:
21 | raise ValueError(
22 | "Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = "
23 | f"{matrix.shape}`."
24 | )
25 |
26 | @abstractmethod
27 | def forward(self, matrix: Tensor) -> Tensor:
28 | """Computes the aggregation from the input matrix."""
29 |
30 | # Override to make type hints and documentation more specific
31 | def __call__(self, matrix: Tensor) -> Tensor:
32 | """Computes the aggregation from the input matrix and applies all registered hooks."""
33 |
34 | return super().__call__(matrix)
35 |
36 | def __repr__(self) -> str:
37 | return f"{self.__class__.__name__}()"
38 |
39 | def __str__(self) -> str:
40 | return f"{self.__class__.__name__}"
41 |
42 |
43 | class WeightedAggregator(Aggregator):
44 | """
45 | Aggregator that combines the rows of the input jacobian matrix with weights given by applying a
46 | Weighting to it.
47 |
48 | :param weighting: The object responsible for extracting the vector of weights from the matrix.
49 | """
50 |
51 | def __init__(self, weighting: Weighting[Matrix]):
52 | super().__init__()
53 | self.weighting = weighting
54 |
55 | @staticmethod
56 | def combine(matrix: Tensor, weights: Tensor) -> Tensor:
57 | """
58 | Aggregates a matrix by making a linear combination of its rows, using the provided vector of
59 | weights.
60 | """
61 |
62 | vector = weights @ matrix
63 | return vector
64 |
65 | def forward(self, matrix: Tensor) -> Tensor:
66 | self._check_is_matrix(matrix)
67 | weights = self.weighting(matrix)
68 | vector = self.combine(matrix, weights)
69 | return vector
70 |
71 |
72 | class GramianWeightedAggregator(WeightedAggregator):
73 | """
74 | WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
75 | Weighting to it.
76 |
77 | :param weighting: The object responsible for extracting the vector of weights from the gramian.
78 | """
79 |
80 | def __init__(self, weighting: Weighting[PSDMatrix]):
81 | super().__init__(weighting << compute_gramian)
82 |
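83 |
84 | # Minimal sketch of how these base classes compose: an equal-weights Weighting (for illustration
85 | # only) plugged into a WeightedAggregator.
86 | #
87 | #     import torch
88 | #
89 | #     class _EqualWeighting(Weighting[Matrix]):
90 | #         def forward(self, matrix: Tensor) -> Tensor:
91 | #             m = matrix.shape[0]
92 | #             return torch.full([m], 1.0 / m, dtype=matrix.dtype, device=matrix.device)
93 | #
94 | #     aggregator = WeightedAggregator(_EqualWeighting())
95 | #     vector = aggregator(torch.ones(3, 5))  # linear combination of the 3 rows with weights 1/3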
--------------------------------------------------------------------------------
/.github/workflows/build-deploy-docs.yml:
--------------------------------------------------------------------------------
1 | name: Build and Deploy Documentation
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | tags:
7 | - 'v[0-9]*.[0-9]*.[0-9]*'
8 |
9 | env:
10 | UV_NO_SYNC: 1
11 | PYTHON_VERSION: '3.14'
12 |
13 | jobs:
14 | build-deploy-doc:
15 | name: Build & deploy doc
16 | environment: prod-documentation
17 | runs-on: ubuntu-latest
18 | permissions:
19 | contents: write
20 | steps:
21 | - name: Checkout repository
22 | uses: actions/checkout@v4
23 |
24 | - name: Set up uv
25 | uses: astral-sh/setup-uv@v5
26 | with:
27 | python-version: ${{ env.PYTHON_VERSION }}
28 |
29 | - name: Install dependencies (default with full options & doc)
30 | run: uv pip install --python-version=${{ env.PYTHON_VERSION }} -e '.[full]' --group doc
31 |
32 | - name: Determine deployment folder
33 | id: deploy_folder
34 | run: |
35 | echo "Determining deployment folder..."
36 | if [[ "${{ github.ref }}" == refs/tags/* ]]; then
37 | echo "Deploying to target ${{ github.ref_name }}"
38 | echo "DEPLOY_DIR=${{ github.ref_name }}" >> $GITHUB_OUTPUT
39 | echo "TORCHJD_VERSION=${{ github.ref_name }}" >> $GITHUB_OUTPUT
40 | else
41 | echo "Deploying to target latest"
42 | echo "DEPLOY_DIR=latest" >> $GITHUB_OUTPUT
43 | echo "TORCHJD_VERSION=main" >> $GITHUB_OUTPUT
44 | fi
45 |
46 | - name: Build Documentation
47 | working-directory: docs
48 | run: uv run make dirhtml
49 | env:
50 | TORCHJD_VERSION: ${{ steps.deploy_folder.outputs.TORCHJD_VERSION }}
51 |
52 | - name: Deploy to DEPLOY_DIR of TorchJD/documentation
53 | uses: peaceiris/actions-gh-pages@v4
54 | with:
55 | deploy_key: ${{ secrets.PROD_DOCUMENTATION_DEPLOY_KEY }}
56 | publish_dir: docs/build/dirhtml
57 | destination_dir: ${{ steps.deploy_folder.outputs.DEPLOY_DIR }}
58 | external_repository: TorchJD/documentation
59 | publish_branch: main
60 |
61 | - name: Kill ssh-agent
62 | # See: https://github.com/peaceiris/actions-gh-pages/issues/909
63 | run: killall ssh-agent
64 |
65 | - name: Deploy to stable of TorchJD/documentation
66 | if: startsWith(github.ref, 'refs/tags/')
67 | uses: peaceiris/actions-gh-pages@v4
68 | with:
69 | deploy_key: ${{ secrets.PROD_DOCUMENTATION_DEPLOY_KEY }}
70 | publish_dir: docs/build/dirhtml
71 | destination_dir: stable
72 | external_repository: TorchJD/documentation
73 | publish_branch: main
74 |
75 | - name: Add documentation link to summary
76 | run: |
77 | echo "### 📄 [View Deployed Documentation](https://torchjd.github.io/documentation/${{ steps.deploy_folder.outputs.DEPLOY_DIR }})" >> $GITHUB_STEP_SUMMARY
78 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/_diagonalize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from ._base import RequirementError, TensorDict, Transform
5 | from ._ordered_set import OrderedSet
6 |
7 |
8 | class Diagonalize(Transform):
9 | """
10 | Transform diagonalizing Gradients into Jacobians.
11 |
12 | The first dimension of the returned Jacobians will be equal to the total number of elements in
13 | the tensors of the input tensor dict. The exact behavior of the diagonalization is best
14 | explained by some examples.
15 |
16 | Example 1:
17 | The input is one tensor of shape [3] and of value [1 2 3].
18 | The output Jacobian will be:
19 | [[1 0 0]
20 | [0 2 0]
21 | [0 0 3]]
22 |
23 | Example 2:
24 | The input is one tensor of shape [2, 2] and of value [[4 5] [6 7]].
25 | The output Jacobian will be:
26 | [[[4 0] [0 0]]
27 | [[0 5] [0 0]]
28 | [[0 0] [6 0]]
29 | [[0 0] [0 7]]]
30 |
31 | Example 3:
32 | The input is two tensors, of shapes [3] and [2, 2] and of values [1 2 3] and [[4 5] [6 7]].
33 | If the key_order has the tensor of shape [3] appear first and the one of shape [2, 2] appear
34 | second, the output Jacobians will be:
35 | [[1 0 0]
36 | [0 2 0]
37 | [0 0 3]
38 | [0 0 0]
39 | [0 0 0]
40 | [0 0 0]
41 | [0 0 0]] and
42 | [[[0 0] [0 0]]
43 | [[0 0] [0 0]]
44 | [[0 0] [0 0]]
45 | [[4 0] [0 0]]
46 | [[0 5] [0 0]]
47 | [[0 0] [6 0]]
48 | [[0 0] [0 7]]]
49 |
50 | :param key_order: The order in which the keys are represented in the rows of the output
51 | Jacobians.
52 | """
53 |
54 | def __init__(self, key_order: OrderedSet[Tensor]):
55 | self.key_order = key_order
56 | self.indices: list[tuple[int, int]] = []
57 | begin = 0
58 | for tensor in self.key_order:
59 | end = begin + tensor.numel()
60 | self.indices.append((begin, end))
61 | begin = end
62 |
63 | def __call__(self, tensors: TensorDict) -> TensorDict:
64 | flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order]
65 | diagonal_matrix = torch.cat(flattened_considered_values).diag()
66 | diagonalized_tensors = {
67 | key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape)
68 | for (begin, end), key in zip(self.indices, self.key_order)
69 | }
70 | return diagonalized_tensors
71 |
72 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
73 | if not set(self.key_order) == input_keys:
74 | raise RequirementError(
75 | f"The input_keys must match the key_order. Found input_keys {input_keys} and"
76 | f"key_order {self.key_order}."
77 | )
78 | return input_keys
79 |
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_graddrop.py:
--------------------------------------------------------------------------------
1 | import re
2 | from contextlib import nullcontext as does_not_raise
3 |
4 | import torch
5 | from pytest import mark, raises
6 | from torch import Tensor
7 | from utils.contexts import ExceptionContext
8 | from utils.tensors import ones_
9 |
10 | from torchjd.aggregation import GradDrop
11 |
12 | from ._asserts import assert_expected_structure, assert_non_differentiable
13 | from ._inputs import scaled_matrices, typical_matrices
14 |
15 | scaled_pairs = [(GradDrop(), matrix) for matrix in scaled_matrices]
16 | typical_pairs = [(GradDrop(), matrix) for matrix in typical_matrices]
17 | requires_grad_pairs = [(GradDrop(), ones_(3, 5, requires_grad=True))]
18 |
19 |
20 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
21 | def test_expected_structure(aggregator: GradDrop, matrix: Tensor):
22 | assert_expected_structure(aggregator, matrix)
23 |
24 |
25 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
26 | def test_non_differentiable(aggregator: GradDrop, matrix: Tensor):
27 | assert_non_differentiable(aggregator, matrix)
28 |
29 |
30 | @mark.parametrize(
31 | ["leak_shape", "expectation"],
32 | [
33 | ([], raises(ValueError)),
34 | ([0], does_not_raise()),
35 | ([1], does_not_raise()),
36 | ([10], does_not_raise()),
37 | ([0, 0], raises(ValueError)),
38 | ([0, 1], raises(ValueError)),
39 | ([1, 1], raises(ValueError)),
40 | ([1, 1, 1], raises(ValueError)),
41 | ([1, 1, 1, 1], raises(ValueError)),
42 | ([1, 1, 1, 1, 1], raises(ValueError)),
43 | ],
44 | )
45 | def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext):
46 | leak = ones_(leak_shape)
47 | with expectation:
48 | _ = GradDrop(leak=leak)
49 |
50 |
51 | @mark.parametrize(
52 | ["leak_shape", "n_rows", "expectation"],
53 | [
54 | ([0], 0, does_not_raise()),
55 | ([1], 1, does_not_raise()),
56 | ([5], 5, does_not_raise()),
57 | ([0], 1, raises(ValueError)),
58 | ([1], 0, raises(ValueError)),
59 | ([4], 5, raises(ValueError)),
60 | ([5], 4, raises(ValueError)),
61 | ],
62 | )
63 | def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: ExceptionContext):
64 | matrix = ones_([n_rows, 5])
65 | leak = ones_(leak_shape)
66 | aggregator = GradDrop(leak=leak)
67 |
68 | with expectation:
69 | _ = aggregator(matrix)
70 |
71 |
72 | def test_representations():
73 | A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu"))
74 | assert re.match(
75 | r"GradDrop\(f=, leak=tensor\(\[0\., 1\.\]\)\)",
76 | repr(A),
77 | )
78 |
79 | assert str(A) == "GradDrop([0., 1.])"
80 |
81 | A = GradDrop()
82 | assert re.match(r"GradDrop\(f=, leak=None\)", repr(A))
83 | assert str(A) == "GradDrop"
84 |
--------------------------------------------------------------------------------
/docs/source/examples/basic_usage.rst:
--------------------------------------------------------------------------------
1 | Basic Usage
2 | ===========
3 |
4 | This example shows how to use TorchJD to perform an iteration of Jacobian descent on a regression
5 | model with two objectives. In this example, a batch of inputs is forwarded through the model and two
6 | corresponding batches of labels are used to compute two losses. These losses are then backwarded
7 | through the model. The obtained Jacobian matrix, consisting of the gradients of the two losses with
8 | respect to the parameters, is then aggregated using :doc:`UPGrad <../docs/aggregation/upgrad>`, and
9 | the parameters are updated using the resulting aggregation.
10 |
11 |
12 |
13 | Import several classes from ``torch`` and ``torchjd``:
14 |
15 | .. code-block:: python
16 |
17 | import torch
18 | from torch.nn import Linear, MSELoss, ReLU, Sequential
19 | from torch.optim import SGD
20 |
21 | from torchjd import autojac
22 | from torchjd.aggregation import UPGrad
23 |
24 | Define the model and the optimizer, as usual:
25 |
26 | .. code-block:: python
27 |
28 | model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
29 | optimizer = SGD(model.parameters(), lr=0.1)
30 |
31 | Define the aggregator that will be used to combine the Jacobian matrix:
32 |
33 | .. code-block:: python
34 |
35 | aggregator = UPGrad()
36 |
37 | In essence, :doc:`UPGrad <../docs/aggregation/upgrad>` projects each gradient onto the dual cone of
38 | the rows of the Jacobian and averages the results. This ensures that locally, no loss will be
39 | negatively affected by the update.
40 |
41 | Now that everything is defined, we can train the model. Define the input and the associated target:
42 |
43 | .. code-block:: python
44 |
45 | input = torch.randn(16, 10) # Batch of 16 random input vectors of length 10
46 | target1 = torch.randn(16) # First batch of 16 targets
47 | target2 = torch.randn(16) # Second batch of 16 targets
48 |
49 | Here, we generate fake inputs and labels for the sake of the example.
50 |
51 | We can now compute the losses associated to each element of the batch.
52 |
53 | .. code-block:: python
54 |
55 | loss_fn = MSELoss()
56 | output = model(input)
57 | loss1 = loss_fn(output[:, 0], target1)
58 | loss2 = loss_fn(output[:, 1], target2)
59 |
60 | The last steps are similar to gradient descent-based optimization, but using the two losses.
61 |
62 | Reset the ``.grad`` field of each model parameter:
63 |
64 | .. code-block:: python
65 |
66 | optimizer.zero_grad()
67 |
68 | Perform the Jacobian descent backward pass:
69 |
70 | .. code-block:: python
71 |
72 | autojac.backward([loss1, loss2], aggregator)
73 |
74 | This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
75 | Jacobian matrix.
76 |
77 | Update each parameter based on its ``.grad`` field, using the ``optimizer``:
78 |
79 | .. code-block:: python
80 |
81 | optimizer.step()
82 |
83 | The model's parameters have been updated!
84 |
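85 | To make the non-conflicting property of :doc:`UPGrad <../docs/aggregation/upgrad>` more
86 | concrete, here is a small self-contained sketch (independent of the model above) that aggregates
87 | a toy Jacobian whose two rows conflict, and checks that the result does not conflict with either
88 | row:
89 |
90 | .. code-block:: python
91 |
92 |     import torch
93 |     from torchjd.aggregation import UPGrad
94 |
95 |     jacobian = torch.tensor([[1.0, 0.0], [-0.5, 1.0]])  # the two rows have a negative inner product
96 |     aggregation = UPGrad()(jacobian)
97 |     print(jacobian @ aggregation)  # both entries should be non-negative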
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_weighting_bases.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC, abstractmethod
4 | from collections.abc import Callable
5 | from typing import Annotated, Generic, TypeVar
6 |
7 | from torch import Tensor, nn
8 |
9 | _T = TypeVar("_T", contravariant=True)
10 | _FnInputT = TypeVar("_FnInputT")
11 | _FnOutputT = TypeVar("_FnOutputT")
12 | Matrix = Annotated[Tensor, "ndim=2"]
13 | PSDMatrix = Annotated[Matrix, "Positive semi-definite"]
14 |
15 |
16 | class Weighting(Generic[_T], nn.Module, ABC):
17 | r"""
18 | Abstract base class for all weighting methods. It has the role of extracting a vector of weights
19 | of dimension :math:`m` from some statistic of a matrix of dimension :math:`m \times n`,
20 | generally its Gramian, of dimension :math:`m \times m`.
21 | """
22 |
23 | def __init__(self):
24 | super().__init__()
25 |
26 | @abstractmethod
27 | def forward(self, stat: _T) -> Tensor:
28 | """Computes the vector of weights from the input stat."""
29 |
30 | # Override to make type hints and documentation more specific
31 | def __call__(self, stat: _T) -> Tensor:
32 | """Computes the vector of weights from the input stat and applies all registered hooks."""
33 |
34 | return super().__call__(stat)
35 |
36 | def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]:
37 | return _Composition(self, fn)
38 |
39 | __lshift__ = _compose
40 |
41 |
42 | class _Composition(Weighting[_T]):
43 | """
44 | Weighting that composes a Weighting with a function, so that the Weighting is applied to the
45 | output of the function.
46 | """
47 |
48 | def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]):
49 | super().__init__()
50 | self.fn = fn
51 | self.weighting = weighting
52 |
53 | def forward(self, stat: _T) -> Tensor:
54 | return self.weighting(self.fn(stat))
55 |
56 |
57 | class GeneralizedWeighting(nn.Module, ABC):
58 | r"""
59 | Abstract base class for all weightings that operate on generalized Gramians. It has the role of
60 | extracting a tensor of weights of dimension :math:`m_1 \times \dots \times m_k` from a
61 | generalized Gramian of dimension
62 | :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`.
63 | """
64 |
65 | def __init__(self):
66 | super().__init__()
67 |
68 | @abstractmethod
69 | def forward(self, generalized_gramian: Tensor) -> Tensor:
70 | """Computes the vector of weights from the input generalized Gramian."""
71 |
72 | # Override to make type hints and documentation more specific
73 | def __call__(self, generalized_gramian: Tensor) -> Tensor:
74 | """
75 | Computes the tensor of weights from the input generalized Gramian and applies all registered
76 | hooks.
77 | """
78 |
79 | return super().__call__(generalized_gramian)
80 |
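81 |
82 | # Minimal sketch of the composition operator, using a hypothetical `SomeGramianWeighting` that
83 | # operates on PSD matrices:
84 | #
85 | #     from torchjd.aggregation._utils.gramian import compute_gramian
86 | #
87 | #     psd_weighting: Weighting[PSDMatrix] = SomeGramianWeighting()
88 | #     matrix_weighting: Weighting[Matrix] = psd_weighting << compute_gramian
89 | #     # matrix_weighting(matrix) is equivalent to psd_weighting(compute_gramian(matrix))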
--------------------------------------------------------------------------------
/docs/source/examples/mtl.rst:
--------------------------------------------------------------------------------
1 | Multi-Task Learning (MTL)
2 | =========================
3 |
4 | In the context of multi-task learning, multiple tasks are performed simultaneously on a common
5 | input. Typically, a feature extractor is applied to the input to obtain a shared representation,
6 | useful for all tasks. Then, task-specific heads are applied to these features to obtain each task's
7 | result. A loss can then be computed for each task. Fundamentally, multi-task learning is a
8 | multi-objective optimization problem in which we minimize the vector of task losses.
9 |
10 | A common trick to train multi-task models is to cast the problem as single-objective, by minimizing
11 | a weighted sum of the losses. This works well in some cases, but sometimes conflict among tasks can
12 | make the optimization of the shared parameters very hard. Besides, the weight associated to each
13 | loss can be considered as a hyper-parameter. Finding their optimal value is generally expensive.
14 |
15 | Alternatively, the vector of losses can be directly minimized using Jacobian descent. The following
16 | example shows how to use TorchJD to train a very simple multi-task model with two regression tasks.
17 | For the sake of the example, we generate a fake dataset consisting of 8 batches of 16 random input
18 | vectors of dimension 10, and their corresponding scalar labels for both tasks.
19 |
20 |
21 | .. code-block:: python
22 | :emphasize-lines: 5-6, 19, 33
23 |
24 | import torch
25 | from torch.nn import Linear, MSELoss, ReLU, Sequential
26 | from torch.optim import SGD
27 |
28 | from torchjd.aggregation import UPGrad
29 | from torchjd.autojac import mtl_backward
30 |
31 | shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
32 | task1_module = Linear(3, 1)
33 | task2_module = Linear(3, 1)
34 | params = [
35 | *shared_module.parameters(),
36 | *task1_module.parameters(),
37 | *task2_module.parameters(),
38 | ]
39 |
40 | loss_fn = MSELoss()
41 | optimizer = SGD(params, lr=0.1)
42 | aggregator = UPGrad()
43 |
44 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
45 | task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
46 | task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task
47 |
48 | for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
49 | features = shared_module(input)
50 | output1 = task1_module(features)
51 | output2 = task2_module(features)
52 | loss1 = loss_fn(output1, target1)
53 | loss2 = loss_fn(output2, target2)
54 |
55 | optimizer.zero_grad()
56 | mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
57 | optimizer.step()
58 |
59 | .. note::
60 | In this example, the Jacobian is only with respect to the shared parameters. The task-specific
61 | parameters are simply updated via the gradient of their task's loss with respect to them.
62 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_mgda.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from ._aggregator_bases import GramianWeightedAggregator
5 | from ._weighting_bases import PSDMatrix, Weighting
6 |
7 |
8 | class MGDA(GramianWeightedAggregator):
9 | r"""
10 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation
11 | step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization
12 | `_. The implementation is
13 | based on Algorithm 2 of `Multi-Task Learning as Multi-Objective Optimization
14 | `_.
15 |
16 | :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization.
17 | :param max_iters: The maximum number of iterations of the optimization loop.
18 | """
19 |
20 | def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
21 | super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters))
22 | self._epsilon = epsilon
23 | self._max_iters = max_iters
24 |
25 | def __repr__(self) -> str:
26 | return f"{self.__class__.__name__}(epsilon={self._epsilon}, max_iters={self._max_iters})"
27 |
28 |
29 | class MGDAWeighting(Weighting[PSDMatrix]):
30 | r"""
31 | :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
32 | :class:`~torchjd.aggregation.MGDA`.
33 |
34 | :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization.
35 | :param max_iters: The maximum number of iterations of the optimization loop.
36 | """
37 |
38 | def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
39 | super().__init__()
40 | self.epsilon = epsilon
41 | self.max_iters = max_iters
42 |
43 | def forward(self, gramian: Tensor) -> Tensor:
44 | """
45 | This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective
46 | Optimization
47 | `_.
48 | """
49 | device = gramian.device
50 | dtype = gramian.dtype
51 |
52 | alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0]
53 | for i in range(self.max_iters):
54 | t = torch.argmin(gramian @ alpha)
55 | e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype)
56 | e_t[t] = 1.0
57 | a = alpha @ (gramian @ e_t)
58 | b = alpha @ (gramian @ alpha)
59 | c = e_t @ (gramian @ e_t)
60 | if c <= a:
61 | gamma = 1.0
62 | elif b <= a:
63 | gamma = 0.0
64 | else:
65 | gamma = (b - a) / (b + c - 2 * a) # type: ignore[assignment]
66 | alpha = (1 - gamma) * alpha + gamma * e_t
67 | if gamma < self.epsilon:
68 | break
69 | return alpha
70 |
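71 |
72 | # Rough usage sketch: MGDA aggregates a matrix by finding the convex combination of its rows with
73 | # (approximately) minimal norm.
74 | #
75 | #     aggregator = MGDA(epsilon=0.001, max_iters=100)
76 | #     vector = aggregator(torch.tensor([[1.0, 0.0], [0.0, 1.0]]))  # approximately [0.5, 0.5]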
--------------------------------------------------------------------------------
/docs/source/examples/iwmtl.rst:
--------------------------------------------------------------------------------
1 | Instance-Wise Multi-Task Learning (IWMTL)
2 | =========================================
3 |
4 | When training a model with multiple tasks, the gradients of the individual tasks are likely to
5 | conflict. This is particularly true when looking at the individual (per-sample) gradients.
6 | The :doc:`autogram engine <../docs/autogram/engine>` can be used to efficiently compute the Gramian
7 | of the Jacobian of the matrix of per-sample and per-task losses. Weights can then be extracted from
8 | this Gramian to reweight the gradients and resolve conflicts entirely.
9 |
10 | The following example shows how to do that.
11 |
12 | .. code-block:: python
13 | :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
14 |
15 | import torch
16 | from torch.nn import Linear, MSELoss, ReLU, Sequential
17 | from torch.optim import SGD
18 |
19 | from torchjd.aggregation import Flattening, UPGradWeighting
20 | from torchjd.autogram import Engine
21 |
22 | shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
23 | task1_module = Linear(3, 1)
24 | task2_module = Linear(3, 1)
25 | params = [
26 | *shared_module.parameters(),
27 | *task1_module.parameters(),
28 | *task2_module.parameters(),
29 | ]
30 |
31 | optimizer = SGD(params, lr=0.1)
32 | mse = MSELoss(reduction="none")
33 | weighting = Flattening(UPGradWeighting())
34 | engine = Engine(shared_module, batch_dim=0)
35 |
36 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
37 | task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
38 | task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task
39 |
40 | for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
41 | features = shared_module(input) # shape: [16, 3]
42 | out1 = task1_module(features).squeeze(1) # shape: [16]
43 | out2 = task2_module(features).squeeze(1) # shape: [16]
44 |
45 | # Compute the matrix of losses: one loss per element of the batch and per task
46 | losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2]
47 |
48 | # Compute the gramian (inner products between pairs of gradients of the losses)
49 | gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16]
50 |
51 | # Obtain the weights that lead to no conflict between reweighted gradients
52 | weights = weighting(gramian) # shape: [16, 2]
53 |
54 | optimizer.zero_grad()
55 | # Do the standard backward pass, but weighted using the obtained weights
56 | losses.backward(weights)
57 | optimizer.step()
58 |
59 | .. note::
60 | In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a
61 | 4D tensor rather than a matrix, and a
62 | :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as
63 | :class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of
64 | weights from it. More information about ``GeneralizedWeighting`` can be found in the
65 | :doc:`../../docs/aggregation/index` page.
66 |
--------------------------------------------------------------------------------
/docs/source/examples/amp.rst:
--------------------------------------------------------------------------------
1 | Automatic Mixed Precision (AMP)
2 | ===============================
3 |
4 | In some cases, to save memory and reduce computation time, you may want to use `automatic mixed
5 | precision <https://pytorch.org/docs/stable/amp.html>`_. Since the
6 | `torch.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.amp.GradScaler>`_ class already
7 | works on multiple losses, it's pretty straightforward to combine TorchJD and AMP. As usual, the
8 | forward pass should be wrapped within a `torch.autocast
9 | <https://pytorch.org/docs/stable/amp.html#torch.autocast>`_ context, and as usual, the loss (in our
10 | case, the losses) should preferably be scaled with a `GradScaler
11 | <https://pytorch.org/docs/stable/amp.html#torch.amp.GradScaler>`_ to avoid gradient underflow. The
12 | following example shows the resulting code for a multi-task learning use-case.
13 |
14 | .. code-block:: python
15 | :emphasize-lines: 2, 17, 27, 34, 36-38
16 |
17 | import torch
18 | from torch.amp import GradScaler
19 | from torch.nn import Linear, MSELoss, ReLU, Sequential
20 | from torch.optim import SGD
21 |
22 | from torchjd.aggregation import UPGrad
23 | from torchjd.autojac import mtl_backward
24 |
25 | shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
26 | task1_module = Linear(3, 1)
27 | task2_module = Linear(3, 1)
28 | params = [
29 | *shared_module.parameters(),
30 | *task1_module.parameters(),
31 | *task2_module.parameters(),
32 | ]
33 | scaler = GradScaler(device="cpu")
34 | loss_fn = MSELoss()
35 | optimizer = SGD(params, lr=0.1)
36 | aggregator = UPGrad()
37 |
38 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
39 | task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
40 | task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task
41 |
42 | for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
43 | with torch.autocast(device_type="cpu", dtype=torch.float16):
44 | features = shared_module(input)
45 | output1 = task1_module(features)
46 | output2 = task2_module(features)
47 | loss1 = loss_fn(output1, target1)
48 | loss2 = loss_fn(output2, target2)
49 |
50 | scaled_losses = scaler.scale([loss1, loss2])
51 | optimizer.zero_grad()
52 | mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
53 | scaler.step(optimizer)
54 | scaler.update()
55 |
56 | .. hint::
57 | Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For
58 | those operations, the tensors saved for the backward pass will also be of ``float16`` type.
59 | However, the Jacobian computed by ``mtl_backward`` will be of type ``float32``, so the ``.grad``
60 | fields of the model parameters will also be of type ``float32``. This is in line with the
61 | behavior of PyTorch, which would also compute all gradients in ``float32`` type.
62 |
63 | .. note::
64 | :doc:`torchjd.backward <../docs/autojac/backward>` can be similarly combined with AMP.
65 |
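66 | As a quick sanity check of the hint above, one can inspect (sketch, placed after the
67 | ``mtl_backward`` call in the loop) the dtype of a shared parameter's gradient:
68 |
69 | .. code-block:: python
70 |
71 |     print(shared_module[0].weight.grad.dtype)  # expected: torch.float32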
--------------------------------------------------------------------------------
/src/torchjd/autogram/_gramian_utils.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 |
4 | def reshape_gramian(gramian: Tensor, half_shape: list[int]) -> Tensor:
5 | """
6 | Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions
7 | must be done from the left, while the reshape of the second half must be done from the right.
8 |
9 | :param gramian: Gramian to reshape. Can be a generalized Gramian.
10 | :param half_shape: First half of the target shape; the shape of the output is therefore
11 | `half_shape + half_shape[::-1]`.
12 | """
13 |
14 | # Example 1: `gramian` of shape [4, 3, 2, 2, 3, 4] and `half_shape` of [8, 3]:
15 | # [4, 3, 2, 2, 3, 4] -(movedim)-> [4, 3, 2, 4, 3, 2] -(reshape)-> [8, 3, 8, 3] -(movedim)->
16 | # [8, 3, 3, 8]
17 | #
18 | # Example 2: `gramian` of shape [24, 24] and `half_shape` of [4, 3, 2]:
19 | # [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4]
20 |
21 | return _revert_last_dims(_revert_last_dims(gramian).reshape(half_shape + half_shape))
22 |
23 |
24 | def _revert_last_dims(gramian: Tensor) -> Tensor:
25 | """Inverts the order of the last half of the dimensions of the input generalized Gramian."""
26 |
27 | half_ndim = gramian.ndim // 2
28 | last_dims = [half_ndim + i for i in range(half_ndim)]
29 | return gramian.movedim(last_dims, last_dims[::-1])
30 |
31 |
32 | def movedim_gramian(gramian: Tensor, half_source: list[int], half_destination: list[int]) -> Tensor:
33 | """
34 | Moves the dimensions of a Gramian from some source dimensions to destination dimensions. This
35 | must be done simultaneously on the first half of the dimensions and on the second half of the
36 | dimensions reversed.
37 |
38 | :param gramian: Gramian to reshape. Can be a generalized Gramian.
39 | :param half_source: Source dimensions, which should be in the range [-gramian.ndim//2,
40 | gramian.ndim//2[. Its elements should be unique.
41 | :param half_destination: Destination dimensions, which should be in the range
42 | [-gramian.ndim//2, gramian.ndim//2[. It should have the same size as `half_source`, and its
43 | elements should be unique.
44 | """
45 |
46 | # Example: `gramian` of shape [4, 3, 2, 2, 3, 4], `half_source` of [-2, 2] and
47 | # `half_destination` of [0, 1]:
48 | # - `half_source_` will be [1, 2] and `half_destination_` will be [0, 1]
49 | # - `source` will be [1, 2, 4, 3] and `destination` will be [0, 1, 5, 4]
50 | # - The `moved_gramian` will be of shape [3, 2, 4, 4, 2, 3]
51 |
52 | # Map everything to the range [0, gramian.ndim//2[
53 | half_ndim = gramian.ndim // 2
54 | half_source_ = [i if 0 <= i else i + half_ndim for i in half_source]
55 | half_destination_ = [i if 0 <= i else i + half_ndim for i in half_destination]
56 |
57 | # Mirror the half source and the half destination and use the result to move the dimensions of
58 | # the gramian
59 | last_dim = gramian.ndim - 1
60 | source = half_source_ + [last_dim - i for i in half_source_]
61 | destination = half_destination_ + [last_dim - i for i in half_destination_]
62 | moved_gramian = gramian.movedim(source, destination)
63 | return moved_gramian
64 |
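65 |
66 | # Shape-only sketch of the two helpers (random values, only the shapes matter):
67 | #
68 | #     import torch
69 | #
70 | #     g = torch.randn(24, 24)
71 | #     g2 = reshape_gramian(g, [4, 3, 2])  # shape [4, 3, 2, 2, 3, 4]
72 | #     g3 = movedim_gramian(g2, [0], [2])  # shape [3, 2, 4, 4, 2, 3]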
--------------------------------------------------------------------------------
/tests/unit/aggregation/test_constant.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext as does_not_raise
2 |
3 | import torch
4 | from pytest import mark, raises
5 | from torch import Tensor
6 | from utils.contexts import ExceptionContext
7 | from utils.tensors import ones_, tensor_
8 |
9 | from torchjd.aggregation import Constant
10 |
11 | from ._asserts import (
12 | assert_expected_structure,
13 | assert_linear_under_scaling,
14 | assert_strongly_stationary,
15 | )
16 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
17 |
18 |
19 | def _make_aggregator(matrix: Tensor) -> Constant:
20 | n_rows = matrix.shape[0]
21 | weights = tensor_([1.0 / n_rows] * n_rows, dtype=matrix.dtype)
22 | return Constant(weights)
23 |
24 |
25 | scaled_pairs = [(_make_aggregator(matrix), matrix) for matrix in scaled_matrices]
26 | typical_pairs = [(_make_aggregator(matrix), matrix) for matrix in typical_matrices]
27 | non_strong_pairs = [(_make_aggregator(matrix), matrix) for matrix in non_strong_matrices]
28 |
29 |
30 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
31 | def test_expected_structure(aggregator: Constant, matrix: Tensor):
32 | assert_expected_structure(aggregator, matrix)
33 |
34 |
35 | @mark.parametrize(["aggregator", "matrix"], typical_pairs)
36 | def test_linear_under_scaling(aggregator: Constant, matrix: Tensor):
37 | assert_linear_under_scaling(aggregator, matrix)
38 |
39 |
40 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
41 | def test_strongly_stationary(aggregator: Constant, matrix: Tensor):
42 | assert_strongly_stationary(aggregator, matrix)
43 |
44 |
45 | @mark.parametrize(
46 | ["weights_shape", "expectation"],
47 | [
48 | ([], raises(ValueError)),
49 | ([0], does_not_raise()),
50 | ([1], does_not_raise()),
51 | ([10], does_not_raise()),
52 | ([0, 0], raises(ValueError)),
53 | ([0, 1], raises(ValueError)),
54 | ([1, 1], raises(ValueError)),
55 | ([1, 1, 1], raises(ValueError)),
56 | ([1, 1, 1, 1], raises(ValueError)),
57 | ([1, 1, 1, 1, 1], raises(ValueError)),
58 | ],
59 | )
60 | def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext):
61 | weights = ones_(weights_shape)
62 | with expectation:
63 | _ = Constant(weights=weights)
64 |
65 |
66 | @mark.parametrize(
67 | ["weights_shape", "n_rows", "expectation"],
68 | [
69 | ([0], 0, does_not_raise()),
70 | ([1], 1, does_not_raise()),
71 | ([5], 5, does_not_raise()),
72 | ([0], 1, raises(ValueError)),
73 | ([1], 0, raises(ValueError)),
74 | ([4], 5, raises(ValueError)),
75 | ([5], 4, raises(ValueError)),
76 | ],
77 | )
78 | def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: ExceptionContext):
79 | matrix = ones_([n_rows, 5])
80 | weights = ones_(weights_shape)
81 | aggregator = Constant(weights)
82 |
83 | with expectation:
84 | _ = aggregator(matrix)
85 |
86 |
87 | def test_representations():
88 | A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
89 | assert repr(A) == "Constant(weights=tensor([1., 2.]))"
90 | assert str(A) == "Constant([1., 2.])"
91 |
--------------------------------------------------------------------------------
/docs/source/examples/monitoring.rst:
--------------------------------------------------------------------------------
1 | Monitoring aggregations
2 | =======================
3 |
4 | The :doc:`Aggregator <../docs/aggregation/index>` class is a subclass of :class:`torch.nn.Module`.
5 | This allows registering hooks, which can be used to monitor some information about aggregations.
6 | The following code example demonstrates registering a hook to compute and print the cosine
7 | similarity between the aggregation performed by :doc:`UPGrad <../docs/aggregation/upgrad>` and the
8 | average of the gradients, and another hook to compute and print the weights of the weighting of
9 | :doc:`UPGrad <../docs/aggregation/upgrad>`.
10 |
11 | Updating the parameters of the model with the average gradient is equivalent to using gradient
12 | descent on the average of the losses. Observing a cosine similarity smaller than 1 means that
13 | Jacobian descent is doing something different from gradient descent. With
14 | :doc:`UPGrad <../docs/aggregation/upgrad>`, this happens when the original gradients conflict (i.e.
15 | they have a negative inner product).
16 |
17 | .. code-block:: python
18 | :emphasize-lines: 9-11, 13-18, 33-34
19 |
20 | import torch
21 | from torch.nn import Linear, MSELoss, ReLU, Sequential
22 | from torch.nn.functional import cosine_similarity
23 | from torch.optim import SGD
24 |
25 | from torchjd.aggregation import UPGrad
26 | from torchjd.autojac import mtl_backward
27 |
28 | def print_weights(_, __, weights: torch.Tensor) -> None:
29 | """Prints the extracted weights."""
30 | print(f"Weights: {weights}")
31 |
32 | def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
33 | """Prints the cosine similarity between the aggregation and the average gradient."""
34 | matrix = inputs[0]
35 | gd_output = matrix.mean(dim=0)
36 | similarity = cosine_similarity(aggregation, gd_output, dim=0)
37 | print(f"Cosine similarity: {similarity.item():.4f}")
38 |
39 | shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
40 | task1_module = Linear(3, 1)
41 | task2_module = Linear(3, 1)
42 | params = [
43 | *shared_module.parameters(),
44 | *task1_module.parameters(),
45 | *task2_module.parameters(),
46 | ]
47 |
48 | loss_fn = MSELoss()
49 | optimizer = SGD(params, lr=0.1)
50 | aggregator = UPGrad()
51 |
52 | aggregator.weighting.register_forward_hook(print_weights)
53 | aggregator.register_forward_hook(print_gd_similarity)
54 |
55 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
56 | task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
57 | task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task
58 |
59 | for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
60 | features = shared_module(input)
61 | output1 = task1_module(features)
62 | output2 = task2_module(features)
63 | loss1 = loss_fn(output1, target1)
64 | loss2 = loss_fn(output2, target2)
65 |
66 | optimizer.zero_grad()
67 | mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
68 | optimizer.step()
69 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_graddrop.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from ._aggregator_bases import Aggregator
7 | from ._utils.non_differentiable import raise_non_differentiable_error
8 |
9 |
10 | def _identity(P: Tensor) -> Tensor:
11 | return P
12 |
13 |
14 | class GradDrop(Aggregator):
15 | """
16 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination
17 | steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
18 | Optimizing Deep Multitask Models with Gradient Sign Dropout
19 | `_.
20 |
21 | :param f: The function to apply to the Gradient Positive Sign Purity. It should be monotonically
22 | increasing. Defaults to identity.
23 | :param leak: The tensor of leak values, determining how much each row is allowed to leak
24 | through. Defaults to None, which means no leak.
25 | """
26 |
27 | def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
28 | if leak is not None and leak.dim() != 1:
29 | raise ValueError(
30 | "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
31 | f"{leak.shape}`."
32 | )
33 |
34 | super().__init__()
35 | self.f = f
36 | self.leak = leak
37 |
38 | # This prevents computing gradients that can be very wrong.
39 | self.register_full_backward_pre_hook(raise_non_differentiable_error)
40 |
41 | def forward(self, matrix: Tensor) -> Tensor:
42 | self._check_is_matrix(matrix)
43 | self._check_matrix_has_enough_rows(matrix)
44 |
45 | if matrix.shape[0] == 0 or matrix.shape[1] == 0:
46 | return torch.zeros(matrix.shape[1], dtype=matrix.dtype, device=matrix.device)
47 |
48 | leak = self.leak if self.leak is not None else torch.zeros_like(matrix[:, 0])
49 |
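        # P is the Gradient Positive Sign Purity of each coordinate: 1 if every row is positive
        # there, 0 if every row is negative, and 0.5 if positive and negative contributions cancel
        # out. Each coordinate then keeps its positive entries with probability f(P) and its
        # negative entries otherwise (via the uniform samples U and the masks M_i), and `leak`
        # optionally lets a fraction of each row bypass its mask.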
50 | P = 0.5 * (torch.ones_like(matrix[0]) + matrix.sum(dim=0) / matrix.abs().sum(dim=0))
51 | fP = self.f(P)
52 | U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device)
53 |
54 | vector = torch.zeros_like(matrix[0])
55 | for i in range(len(matrix)):
56 | M_i = (fP > U) * (matrix[i] > 0) + (fP < U) * (matrix[i] < 0)
57 | vector += (leak[i] + (1 - leak[i]) * M_i) * matrix[i]
58 |
59 | return vector
60 |
61 | def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
62 | n_rows = matrix.shape[0]
63 | if self.leak is not None and n_rows != len(self.leak):
64 | raise ValueError(
65 | f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the "
66 | f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`."
67 | )
68 |
69 | def __repr__(self) -> str:
70 | return f"{self.__class__.__name__}(f={repr(self.f)}, leak={repr(self.leak)})"
71 |
72 | def __str__(self) -> str:
73 | if self.leak is None:
74 | leak_str = ""
75 | else:
76 | leak_str = f"([{', '.join(['{:.2f}'.format(l_).rstrip('0') for l_ in self.leak])}])"
77 | return f"GradDrop{leak_str}"
78 |
--------------------------------------------------------------------------------
/tests/unit/autogram/test_edge_registry.py:
--------------------------------------------------------------------------------
1 | from torch.autograd.graph import get_gradient_edge
2 | from utils.tensors import randn_
3 |
4 | from torchjd.autogram._edge_registry import EdgeRegistry
5 |
6 |
7 | def test_all_edges_are_leaves1():
8 | """Tests that get_leaf_edges works correctly when all edges are already leaves."""
9 |
10 | a = randn_([3, 4], requires_grad=True)
11 | b = randn_([4], requires_grad=True)
12 | c = randn_([3], requires_grad=True)
13 |
14 | d = (a @ b) + c
15 |
16 | edge_registry = EdgeRegistry()
17 | for tensor in [a, b, c]:
18 | edge_registry.register(get_gradient_edge(tensor))
19 |
20 | expected_leaves = {get_gradient_edge(tensor) for tensor in [a, b, c]}
21 | leaves = edge_registry.get_leaf_edges({get_gradient_edge(d)})
22 | assert leaves == expected_leaves
23 |
24 |
25 | def test_all_edges_are_leaves2():
26 | """
27 | Tests that get_leaf_edges works correctly when all edges are already leaves of the graph of
28 | edges leading to them, but are not leaves of the autograd graph.
29 | """
30 |
31 | a = randn_([3, 4], requires_grad=True)
32 | b = randn_([4], requires_grad=True)
33 | c = randn_([4], requires_grad=True)
34 | d = randn_([4], requires_grad=True)
35 |
36 | e = a * b
37 | f = e + c
38 | g = f + d
39 |
40 | edge_registry = EdgeRegistry()
41 | for tensor in [e, g]:
42 | edge_registry.register(get_gradient_edge(tensor))
43 |
44 | expected_leaves = {get_gradient_edge(tensor) for tensor in [e, g]}
45 | leaves = edge_registry.get_leaf_edges({get_gradient_edge(e), get_gradient_edge(g)})
46 | assert leaves == expected_leaves
47 |
48 |
49 | def test_some_edges_are_not_leaves1():
50 | """Tests that get_leaf_edges works correctly when some edges are leaves and some are not."""
51 |
52 | a = randn_([3, 4], requires_grad=True)
53 | b = randn_([4], requires_grad=True)
54 | c = randn_([4], requires_grad=True)
55 | d = randn_([4], requires_grad=True)
56 |
57 | e = a * b
58 | f = e + c
59 | g = f + d
60 |
61 | edge_registry = EdgeRegistry()
62 | for tensor in [a, b, c, d, e, f, g]:
63 | edge_registry.register(get_gradient_edge(tensor))
64 |
65 | expected_leaves = {get_gradient_edge(tensor) for tensor in [a, b, c, d]}
66 | leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)})
67 | assert leaves == expected_leaves
68 |
69 |
70 | def test_some_edges_are_not_leaves2():
71 | """
72 | Tests that get_leaf_edges works correctly when some edges are leaves and some are not. This
73 |     time, not all tensors in the graph are registered, so not all leaves of the graph have to be
74 | returned.
75 | """
76 |
77 | a = randn_([3, 4], requires_grad=True)
78 | b = randn_([4], requires_grad=True)
79 | c = randn_([4], requires_grad=True)
80 | d = randn_([4], requires_grad=True)
81 |
82 | e = a * b
83 | f = e + c
84 | g = f + d
85 |
86 | edge_registry = EdgeRegistry()
87 | for tensor in [a, c, d, e, g]:
88 | edge_registry.register(get_gradient_edge(tensor))
89 |
90 | expected_leaves = {get_gradient_edge(tensor) for tensor in [a, c, d]}
91 | leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)})
92 | assert leaves == expected_leaves
93 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_transform/_differentiate.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from collections.abc import Sequence
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 | from ._base import RequirementError, TensorDict, Transform
8 | from ._materialize import materialize
9 | from ._ordered_set import OrderedSet
10 |
11 |
12 | class Differentiate(Transform, ABC):
13 | """
14 | Abstract base class for transforms responsible for differentiating some outputs with respect to
15 | some inputs.
16 |
17 | :param outputs: Tensors to differentiate.
18 | :param inputs: Tensors with respect to which we differentiate.
19 | :param retain_graph: If False, the graph used to compute the grads will be freed.
20 |     :param create_graph: If True, the graph of the derivative will be constructed, allowing the
21 |         computation of higher-order derivative products.
22 |
23 | .. note:: The order of outputs and inputs only matters because we have no guarantee that
24 | torch.autograd.grad is *exactly* equivariant to input permutations and invariant to output
25 | (with their corresponding grad_output) permutations.
26 | """
27 |
28 | def __init__(
29 | self,
30 | outputs: OrderedSet[Tensor],
31 | inputs: OrderedSet[Tensor],
32 | retain_graph: bool,
33 | create_graph: bool,
34 | ):
35 | self.outputs = list(outputs)
36 | self.inputs = list(inputs)
37 | self.retain_graph = retain_graph
38 | self.create_graph = create_graph
39 |
40 | def __call__(self, tensors: TensorDict) -> TensorDict:
41 | tensor_outputs = [tensors[output] for output in self.outputs]
42 |
43 | differentiated_tuple = self._differentiate(tensor_outputs)
44 | new_differentiations = dict(zip(self.inputs, differentiated_tuple))
45 | return type(tensors)(new_differentiations)
46 |
47 | @abstractmethod
48 | def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
49 | """
50 | Abstract method for differentiating the outputs with respect to the inputs, and applying the
51 | linear transformations represented by the tensor_outputs to the results.
52 |
53 | The implementation of this method should define what kind of differentiation is performed:
54 | whether gradients, Jacobians, etc. are computed, and what the dimension of the
55 | tensor_outputs should be.
56 | """
57 |
58 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
59 | outputs = set(self.outputs)
60 |         if outputs != input_keys:
61 |             raise RequirementError(
62 |                 f"The input_keys must match the expected outputs. Found input_keys {input_keys} "
63 |                 f"and outputs {outputs}."
64 | )
65 | return set(self.inputs)
66 |
67 | def _get_vjp(self, grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]:
68 | optional_grads = torch.autograd.grad(
69 | self.outputs,
70 | self.inputs,
71 | grad_outputs=grad_outputs,
72 | retain_graph=retain_graph,
73 | create_graph=self.create_graph,
74 | allow_unused=True,
75 | )
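        # With allow_unused=True, inputs that the outputs do not depend on get a None gradient;
        # materialize turns these Nones into actual tensors (based on the corresponding inputs).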
76 | grads = materialize(optional_grads, inputs=self.inputs)
77 | return grads
78 |
--------------------------------------------------------------------------------
/docs/source/examples/lightning_integration.rst:
--------------------------------------------------------------------------------
1 | PyTorch Lightning Integration
2 | =============================
3 |
4 | To use Jacobian descent with TorchJD in a :class:`~lightning.pytorch.core.LightningModule`, you need
5 | to turn off automatic optimization by setting ``automatic_optimization`` to ``False``, and to
6 | customize the ``training_step`` method so that it calls the appropriate TorchJD function
7 | (:doc:`backward <../docs/autojac/backward>` or :doc:`mtl_backward <../docs/autojac/mtl_backward>`).
8 |
9 | The following code example demonstrates a basic multi-task learning setup using a
10 | :class:`~lightning.pytorch.core.LightningModule` that will call :doc:`mtl_backward
11 | <../docs/autojac/mtl_backward>` at each training iteration.
12 |
13 | .. code-block:: python
14 | :emphasize-lines: 9-10, 18, 32
15 |
16 | import torch
17 | from lightning import LightningModule, Trainer
18 | from lightning.pytorch.utilities.types import OptimizerLRScheduler
19 | from torch.nn import Linear, ReLU, Sequential
20 | from torch.nn.functional import mse_loss
21 | from torch.optim import Adam
22 | from torch.utils.data import DataLoader, TensorDataset
23 |
24 | from torchjd.aggregation import UPGrad
25 | from torchjd.autojac import mtl_backward
26 |
27 | class Model(LightningModule):
28 | def __init__(self):
29 | super().__init__()
30 | self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
31 | self.task1_head = Linear(3, 1)
32 | self.task2_head = Linear(3, 1)
33 | self.automatic_optimization = False
34 |
35 | def training_step(self, batch, batch_idx) -> None:
36 | input, target1, target2 = batch
37 |
38 | features = self.feature_extractor(input)
39 | output1 = self.task1_head(features)
40 | output2 = self.task2_head(features)
41 |
42 | loss1 = mse_loss(output1, target1)
43 | loss2 = mse_loss(output2, target2)
44 |
45 | opt = self.optimizers()
46 | opt.zero_grad()
47 | mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
48 | opt.step()
49 |
50 | def configure_optimizers(self) -> OptimizerLRScheduler:
51 | optimizer = Adam(self.parameters(), lr=1e-3)
52 | return optimizer
53 |
54 | model = Model()
55 |
56 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
57 | task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
58 | task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task
59 |
60 | dataset = TensorDataset(inputs, task1_targets, task2_targets)
61 | train_loader = DataLoader(dataset)
62 | trainer = Trainer(
63 | accelerator="cpu",
64 | max_epochs=1,
65 | enable_checkpointing=False,
66 | logger=False,
67 | enable_progress_bar=False,
68 | )
69 |
70 | trainer.fit(model=model, train_dataloaders=train_loader)
71 |
72 | .. warning::
73 | This will not handle automatic scaling in low-precision settings. There is currently no easy
74 | fix.
75 |
76 | .. warning::
77 | TorchJD is incompatible with compiled models, so you must ensure that your model is not
78 | compiled.
79 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | |
4 |
5 | .. image:: _static/logo-dark-mode.png
6 | :width: 400
7 | :alt: torchjd
8 | :align: center
9 | :class: only-dark, no-scaled-link
10 |
11 | .. image:: _static/logo-light-mode.png
12 | :width: 400
13 | :alt: torchjd
14 | :align: center
15 | :class: only-light, no-scaled-link
16 |
17 | |
18 |
19 | TorchJD is a library enabling Jacobian descent with PyTorch, to train neural networks with multiple
20 | objectives. It is based on the theory from `Jacobian Descent For Multi-Objective Optimization
21 | <https://arxiv.org/abs/2403.09810>`_ and several other related publications.
22 |
23 | The main purpose is to jointly optimize multiple objectives without combining them into a single
24 | scalar loss. When the objectives are conflicting, this can be the key to a successful and stable
25 | optimization. To get started, check out our :doc:`basic usage example
26 | <examples/basic_usage>`.
27 |
28 | Gradient descent relies on gradients to optimize a single objective. Jacobian descent takes this
29 | idea a step further, using the Jacobian to optimize multiple objectives. An important component of
30 | Jacobian descent is the aggregator, which maps the Jacobian to an optimization step. In the page
31 | :doc:`Aggregation <docs/aggregation/index>`, we provide an overview of the various aggregators
32 | available in TorchJD, and their corresponding weightings.
33 |
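As a minimal illustration (with an arbitrary, hypothetical Jacobian), an aggregator is simply a
module that maps a matrix whose rows are the per-objective gradients to a single update direction:

.. code-block:: python

    import torch
    from torchjd.aggregation import UPGrad

    aggregator = UPGrad()
    jacobian = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])  # one row per objective
    update = aggregator(jacobian)  # a single direction combining both rows
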
34 | A straightforward application of Jacobian descent is multi-task learning, in which the vector of
35 | per-task losses has to be minimized. To start using TorchJD for multi-task learning, follow our
36 | :doc:`MTL example <examples/mtl>`.
37 |
38 | Another more interesting application is to consider separately the loss of each element in the
39 | batch. This is what we define as :doc:`Instance-Wise Risk Minimization ` (IWRM).
40 |
41 | The Gramian-based algorithm provides a very efficient alternative way of performing Jacobian
42 | descent. It consists of computing
43 | the Gramian of the Jacobian iteratively during the backward pass (without ever storing the full
44 | Jacobian in memory), weighting the losses using the information of the Gramian, and then computing
45 | the gradient of the obtained weighted loss. The iterative computation of the Gramian corresponds to
46 | Algorithm 3 of
47 | `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/abs/2403.09810>`_. The
48 | documentation and usage example of this algorithm are provided in
49 | :doc:`autogram.Engine <docs/autogram/engine>`.
50 |
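To make the Gramian-based approach concrete, here is a minimal sketch of the underlying math only
(it is not how the engine is used): it materializes the Jacobian explicitly, which is precisely
what the autogram engine avoids, and it uses uniform weights as a stand-in for a Gramian-based
weighting.

.. code-block:: python

    import torch

    params = torch.randn(3, requires_grad=True)
    losses = torch.stack([(params ** 2).sum(), ((params - 1.0) ** 2).sum()])

    # Jacobian of the losses with respect to the parameters (one row per loss).
    jacobian = torch.stack(
        [torch.autograd.grad(loss, params, retain_graph=True)[0] for loss in losses]
    )
    gramian = jacobian @ jacobian.T  # what the engine computes without storing the Jacobian

    # A weighting would normally derive these from the Gramian; uniform weights as a placeholder.
    weights = torch.full((2,), 0.5)
    (weights @ losses).backward()  # gradient of the weighted loss, equal to jacobian.T @ weights
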
51 | The original use case of the autogram engine is to compute the Gramian of the Jacobian very
52 | efficiently for :doc:`IWRM `. Another direct application arises when considering one loss per
53 | element of the batch and per task, in the context of multi-task learning. We call this
54 | :doc:`Instance-Wise Risk Multi-Task Learning <examples/iwmtl>` (IWMTL).
55 |
56 | TorchJD is open-source, under MIT License. The source code is available on
57 | `GitHub <https://github.com/TorchJD/torchjd>`_.
58 |
59 | .. toctree::
60 | :caption: Getting Started
61 | :hidden:
62 |
63 | installation.md
64 | examples/index.rst
65 |
66 | .. toctree::
67 | :caption: API Reference
68 | :hidden:
69 |
70 | docs/autogram/index.rst
71 | docs/autojac/index.rst
72 | docs/aggregation/index.rst
73 |
--------------------------------------------------------------------------------
/tests/unit/autojac/_transform/test_accumulate.py:
--------------------------------------------------------------------------------
1 | from pytest import mark, raises
2 | from utils.dict_assertions import assert_tensor_dicts_are_close
3 | from utils.tensors import ones_, tensor_, zeros_
4 |
5 | from torchjd.autojac._transform import Accumulate
6 |
7 |
8 | def test_single_accumulation():
9 | """
10 | Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run
11 | once.
12 | """
13 |
14 | key1 = zeros_([], requires_grad=True)
15 | key2 = zeros_([1], requires_grad=True)
16 | key3 = zeros_([2, 3], requires_grad=True)
17 | value1 = ones_([])
18 | value2 = ones_([1])
19 | value3 = ones_([2, 3])
20 | input = {key1: value1, key2: value2, key3: value3}
21 |
22 | accumulate = Accumulate()
23 |
24 | output = accumulate(input)
25 | expected_output = {}
26 |
27 | assert_tensor_dicts_are_close(output, expected_output)
28 |
29 | grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad}
30 | expected_grads = {key1: value1, key2: value2, key3: value3}
31 |
32 | assert_tensor_dicts_are_close(grads, expected_grads)
33 |
34 |
35 | @mark.parametrize("iterations", [1, 2, 4, 10, 13])
36 | def test_multiple_accumulation(iterations: int):
37 | """
38 | Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run
39 | `iterations` times.
40 | """
41 |
42 | key1 = zeros_([], requires_grad=True)
43 | key2 = zeros_([1], requires_grad=True)
44 | key3 = zeros_([2, 3], requires_grad=True)
45 | value1 = ones_([])
46 | value2 = ones_([1])
47 | value3 = ones_([2, 3])
48 | input = {key1: value1, key2: value2, key3: value3}
49 |
50 | accumulate = Accumulate()
51 |
52 | for i in range(iterations):
53 | accumulate(input)
54 |
55 | grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad}
56 | expected_grads = {
57 | key1: iterations * value1,
58 | key2: iterations * value2,
59 | key3: iterations * value3,
60 | }
61 |
62 | assert_tensor_dicts_are_close(grads, expected_grads)
63 |
64 |
65 | def test_no_requires_grad_fails():
66 | """
67 | Tests that the Accumulate transform raises an error when it tries to populate a .grad of a
68 | tensor that does not require grad.
69 | """
70 |
71 | key = zeros_([1], requires_grad=False)
72 | value = ones_([1])
73 | input = {key: value}
74 |
75 | accumulate = Accumulate()
76 |
77 | with raises(ValueError):
78 | accumulate(input)
79 |
80 |
81 | def test_no_leaf_and_no_retains_grad_fails():
82 | """
83 | Tests that the Accumulate transform raises an error when it tries to populate a .grad of a
84 | tensor that is not a leaf and that does not retain grad.
85 | """
86 |
87 | key = tensor_([1.0], requires_grad=True) * 2
88 | value = ones_([1])
89 | input = {key: value}
90 |
91 | accumulate = Accumulate()
92 |
93 | with raises(ValueError):
94 | accumulate(input)
95 |
96 |
97 | def test_check_keys():
98 | """Tests that the `check_keys` method works correctly."""
99 |
100 | key = tensor_([1.0], requires_grad=True)
101 | accumulate = Accumulate()
102 |
103 | output_keys = accumulate.check_keys({key})
104 | assert output_keys == set()
105 |
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_config.py:
--------------------------------------------------------------------------------
1 | # The code of this file was partly adapted from
2 | # https://github.com/tum-pbs/ConFIG/tree/main/conflictfree.
3 | # It is therefore also subject to the following license.
4 | #
5 | # MIT License
6 | #
7 | # Copyright (c) 2024 TUM Physics-based Simulation
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in all
17 | # copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | # SOFTWARE.
26 |
27 |
28 | import torch
29 | from torch import Tensor
30 |
31 | from ._aggregator_bases import Aggregator
32 | from ._sum import SumWeighting
33 | from ._utils.non_differentiable import raise_non_differentiable_error
34 | from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
35 |
36 |
37 | class ConFIG(Aggregator):
38 | """
39 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Equation 2 of `ConFIG:
40 | Towards Conflict-free Training of Physics Informed Neural Networks
41 | `_.
42 |
43 | :param pref_vector: The preference vector used to weight the rows. If not provided, defaults to
44 | equal weights of 1.
45 |
46 | .. note::
47 | This implementation was adapted from the `official implementation
48 |         <https://github.com/tum-pbs/ConFIG>`_.
49 | """
50 |
51 | def __init__(self, pref_vector: Tensor | None = None):
52 | super().__init__()
53 | self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting())
54 | self._pref_vector = pref_vector
55 |
56 | # This prevents computing gradients that can be very wrong.
57 | self.register_full_backward_pre_hook(raise_non_differentiable_error)
58 |
59 | def forward(self, matrix: Tensor) -> Tensor:
60 | weights = self.weighting(matrix)
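        # Each row of `matrix` is normalized to unit norm below (rows of zeros map to zero instead
        # of NaN). The pseudo-inverse then gives the direction whose inner products with the unit
        # rows best match `weights` (in the least-squares sense), and the result is rescaled by the
        # sum of the projections of the rows onto that direction.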
61 | units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
62 | best_direction = torch.linalg.pinv(units) @ weights
63 |
64 | unit_target_vector = torch.nn.functional.normalize(best_direction, dim=0)
65 |
66 | length = torch.sum(matrix @ unit_target_vector)
67 |
68 | return length * unit_target_vector
69 |
70 | def __repr__(self) -> str:
71 | return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
72 |
73 | def __str__(self) -> str:
74 | return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"
75 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | pull_request:
5 | workflow_dispatch:
6 | schedule:
7 | - cron: '41 16 * * *' # Every day at 16:41 UTC (to avoid high load at exact hour values).
8 |
9 | env:
10 | UV_NO_SYNC: 1
11 | PYTHON_VERSION: 3.14
12 |
13 | jobs:
14 | tests-full-install:
15 | name: Run tests with full install
16 | runs-on: ${{ matrix.os }}
17 | strategy:
18 | fail-fast: false # Ensure matrix jobs keep running even if one fails
19 | matrix:
20 | python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
21 | os: [ubuntu-latest, macOS-latest, windows-latest]
22 |
23 | steps:
24 | - uses: actions/checkout@v4
25 | - name: Set up uv
26 | uses: astral-sh/setup-uv@v5
27 | with:
28 | python-version: ${{ matrix.python-version }}
29 | - name: Install default (with full options) and test dependencies
30 | run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test
31 | - name: Run unit and doc tests with coverage report
32 | run: uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml
33 | - name: Upload results to Codecov
34 | uses: codecov/codecov-action@v4
35 | with:
36 | token: ${{ secrets.CODECOV_TOKEN }}
37 |
38 | tests-default-install:
39 | name: Run (most) tests with default install
40 | runs-on: ubuntu-latest
41 | steps:
42 | - uses: actions/checkout@v4
43 | - name: Set up uv
44 | uses: astral-sh/setup-uv@v5
45 | with:
46 | python-version: ${{ env.PYTHON_VERSION }}
47 | - name: Install default (without any option) and test dependencies
48 | run: uv pip install --python-version=${{ env.PYTHON_VERSION }} -e . --group test
49 | - name: Run unit and doc tests with coverage report
50 | run: |
51 | uv run pytest -W error tests/unit tests/doc \
52 | --ignore tests/unit/aggregation/test_cagrad.py \
53 | --ignore tests/unit/aggregation/test_nash_mtl.py \
54 | --ignore tests/doc/test_aggregation.py \
55 | --cov=src --cov-report=xml
56 | - name: Upload results to Codecov
57 | uses: codecov/codecov-action@v4
58 | with:
59 | token: ${{ secrets.CODECOV_TOKEN }}
60 |
61 | build-doc:
62 | name: Build doc
63 | runs-on: ubuntu-latest
64 | steps:
65 | - name: Checkout repository
66 | uses: actions/checkout@v4
67 |
68 | - name: Set up uv
69 | uses: astral-sh/setup-uv@v5
70 | with:
71 | python-version: ${{ env.PYTHON_VERSION }}
72 |
73 | - name: Install dependencies (default with full options & doc)
74 | run: uv pip install --python-version=${{ env.PYTHON_VERSION }} -e '.[full]' --group doc
75 |
76 | - name: Build Documentation
77 | working-directory: docs
78 | run: uv run make dirhtml
79 |
80 | mypy:
81 | name: Run mypy
82 | runs-on: ubuntu-latest
83 | steps:
84 | - name: Checkout repository
85 | uses: actions/checkout@v4
86 |
87 | - name: Set up uv
88 | uses: astral-sh/setup-uv@v5
89 | with:
90 | python-version: ${{ env.PYTHON_VERSION }}
91 |
92 | - name: Install dependencies (default with full options & check)
93 | run: uv pip install --python-version=${{ env.PYTHON_VERSION }} -e '.[full]' --group check
94 |
95 | - name: Run mypy
96 | run: uv run mypy src/torchjd
97 |
--------------------------------------------------------------------------------
/tests/unit/autojac/_transform/test_stack.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 |
3 | import torch
4 | from torch import Tensor
5 | from utils.dict_assertions import assert_tensor_dicts_are_close
6 | from utils.tensors import ones_, tensor_, zeros_
7 |
8 | from torchjd.autojac._transform import Stack, Transform
9 | from torchjd.autojac._transform._base import TensorDict
10 |
11 |
12 | class FakeGradientsTransform(Transform):
13 | """Transform that produces gradients filled with ones, for testing purposes."""
14 |
15 | def __init__(self, keys: Iterable[Tensor]):
16 | self.keys = set(keys)
17 |
18 | def __call__(self, input: TensorDict) -> TensorDict:
19 | return {key: torch.ones_like(key) for key in self.keys}
20 |
21 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
22 | return self.keys
23 |
24 |
25 | def test_single_key():
26 | """
27 | Tests that the Stack transform correctly stacks gradients into a jacobian, in a very simple
28 | example with 2 transforms sharing the same key.
29 | """
30 |
31 | key = zeros_([3, 4])
32 | input = {}
33 |
34 | transform = FakeGradientsTransform([key])
35 | stack = Stack([transform, transform])
36 |
37 | output = stack(input)
38 | expected_output = {key: ones_([2, 3, 4])}
39 |
40 | assert_tensor_dicts_are_close(output, expected_output)
41 |
42 |
43 | def test_disjoint_key_sets():
44 | """
45 | Tests that the Stack transform correctly stacks gradients into a jacobian, in an example where
46 | the output key sets of all of its transforms are disjoint. The missing values should be replaced
47 | by zeros.
48 | """
49 |
50 | key1 = zeros_([1, 2])
51 | key2 = zeros_([3])
52 | input = {}
53 |
54 | transform1 = FakeGradientsTransform([key1])
55 | transform2 = FakeGradientsTransform([key2])
56 | stack = Stack([transform1, transform2])
57 |
58 | output = stack(input)
59 | expected_output = {
60 | key1: tensor_([[[1.0, 1.0]], [[0.0, 0.0]]]),
61 | key2: tensor_([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]),
62 | }
63 |
64 | assert_tensor_dicts_are_close(output, expected_output)
65 |
66 |
67 | def test_overlapping_key_sets():
68 | """
69 | Tests that the Stack transform correctly stacks gradients into a jacobian, in an example where
70 |     the output key sets of all of its transforms are overlapping (non-empty intersection, but not
71 | equal). The missing values should be replaced by zeros.
72 | """
73 |
74 | key1 = zeros_([1, 2])
75 | key2 = zeros_([3])
76 | key3 = zeros_([4])
77 | input = {}
78 |
79 | transform12 = FakeGradientsTransform([key1, key2])
80 | transform23 = FakeGradientsTransform([key2, key3])
81 | stack = Stack([transform12, transform23])
82 |
83 | output = stack(input)
84 | expected_output = {
85 | key1: tensor_([[[1.0, 1.0]], [[0.0, 0.0]]]),
86 | key2: tensor_([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
87 | key3: tensor_([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]),
88 | }
89 |
90 | assert_tensor_dicts_are_close(output, expected_output)
91 |
92 |
93 | def test_empty():
94 | """Tests that the Stack transform correctly handles an empty list of transforms."""
95 |
96 | stack = Stack([])
97 | input = {}
98 | output = stack(input)
99 | expected_output = {}
100 |
101 | assert_tensor_dicts_are_close(output, expected_output)
102 |
--------------------------------------------------------------------------------
/src/torchjd/autojac/_utils.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from collections.abc import Iterable, Sequence
3 | from typing import cast
4 |
5 | from torch import Tensor
6 | from torch.autograd.graph import Node
7 |
8 | from ._transform import OrderedSet
9 |
10 |
11 | def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None:
12 | if not (parallel_chunk_size is None or parallel_chunk_size > 0):
13 | raise ValueError(
14 | "`parallel_chunk_size` should be `None` or greater than `0`. (got "
15 | f"{parallel_chunk_size})"
16 | )
17 |
18 |
19 | def as_checked_ordered_set(
20 | tensors: Sequence[Tensor] | Tensor, variable_name: str
21 | ) -> OrderedSet[Tensor]:
22 | if isinstance(tensors, Tensor):
23 | tensors = [tensors]
24 |
25 | original_length = len(tensors)
26 | output = OrderedSet(tensors)
27 |
28 | if len(output) != original_length:
29 | raise ValueError(f"`{variable_name}` should contain unique elements.")
30 |
31 |     return output
32 |
33 |
34 | def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]:
35 | """
36 | Gets the leaves of the autograd graph of all specified ``tensors``.
37 |
38 | :param tensors: Tensors from which the graph traversal should start. They should all require
39 | grad and not be leaves.
40 | :param excluded: Tensors whose grad_fn should be excluded from the graph traversal. They should
41 | all require grad and not be leaves.
42 |
43 | """
44 |
45 | if any([tensor.grad_fn is None for tensor in tensors]):
46 | raise ValueError("All `tensors` should have a `grad_fn`.")
47 |
48 | if any([tensor.grad_fn is None for tensor in excluded]):
49 | raise ValueError("All `excluded` tensors should have a `grad_fn`.")
50 |
51 | accumulate_grads = _get_descendant_accumulate_grads(
52 | roots=cast(OrderedSet[Node], OrderedSet([tensor.grad_fn for tensor in tensors])),
53 | excluded_nodes=cast(set[Node], {tensor.grad_fn for tensor in excluded}),
54 | )
55 |
56 | # accumulate_grads contains instances of AccumulateGrad, which contain a `variable` field.
57 | # They cannot be typed as such because AccumulateGrad is not public.
58 | leaves = OrderedSet([g.variable for g in accumulate_grads]) # type: ignore[attr-defined]
59 |
60 | return leaves
61 |
62 |
63 | def _get_descendant_accumulate_grads(
64 | roots: OrderedSet[Node], excluded_nodes: set[Node]
65 | ) -> OrderedSet[Node]:
66 | """
67 | Gets the AccumulateGrad descendants of the specified nodes.
68 |
69 | :param roots: Root nodes from which the graph traversal should start.
70 | :param excluded_nodes: Nodes excluded from the graph traversal.
71 | """
72 |
73 | excluded_nodes = set(excluded_nodes) # Re-instantiate set to avoid modifying input
74 | result: OrderedSet[Node] = OrderedSet([])
75 | roots.difference_update(excluded_nodes)
76 | nodes_to_traverse = deque(roots)
77 |
78 | # This implementation more or less follows what is advised in
79 | # https://discuss.pytorch.org/t/autograd-graph-traversal/213658 and what was suggested in
80 | # https://github.com/TorchJD/torchjd/issues/216.
81 | while nodes_to_traverse:
82 | node = nodes_to_traverse.popleft() # Breadth-first
83 |
84 | if node.__class__.__name__ == "AccumulateGrad":
85 | result.add(node)
86 |
87 | for child, _ in node.next_functions:
88 | if child is not None and child not in excluded_nodes:
89 | nodes_to_traverse.append(child) # Append to the right
90 | excluded_nodes.add(child)
91 |
92 | return result
93 |
--------------------------------------------------------------------------------
/docs/source/icons/TorchJD_text.svg:
--------------------------------------------------------------------------------
1 | (SVG markup omitted: TorchJD text logo)
--------------------------------------------------------------------------------
/tests/unit/aggregation/_utils/test_dual_cone.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytest import mark, raises
4 | from torch.testing import assert_close
5 | from utils.tensors import rand_, randn_
6 |
7 | from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights
8 |
9 |
10 | @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
11 | def test_solution_weights(shape: tuple[int, int]):
12 | r"""
13 |     Tests that `project_weights` returns valid weights corresponding to the projection onto the
14 | dual cone of a matrix with the specified shape.
15 |
16 | Validation is performed by verifying that the solution satisfies the `KKT conditions
17 |     <https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_ for the
18 | quadratic program that projects vectors onto the dual cone of a matrix. Specifically, the
19 | solution should satisfy the equivalent set of conditions described in Lemma 4 of [1].
20 |
21 | Let `u` be a vector of weights and `G` a positive semi-definite matrix. Consider the quadratic
22 | problem of minimizing `v^T G v` subject to `u \preceq v`.
23 |
24 | Then `w` is a solution if and only if it satisfies the following three conditions:
25 | 1. **Dual feasibility:** `u \preceq w`
26 | 2. **Primal feasibility:** `0 \preceq G w`
27 | 3. **Complementary slackness:** `u^T G w = w^T G w`
28 |
29 | Reference:
30 |     [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/abs/2403.09810>`_.
31 | """
32 |
33 | J = randn_(shape)
34 | G = J @ J.T
35 | u = rand_(shape[0])
36 |
37 | w = project_weights(u, G, "quadprog")
38 | dual_gap = w - u
39 |
40 | # Dual feasibility
41 | dual_gap_positive_part = dual_gap[dual_gap >= 0.0]
42 | assert_close(dual_gap_positive_part.norm(), dual_gap.norm(), atol=1e-05, rtol=0)
43 |
44 | primal_gap = G @ w
45 |
46 | # Primal feasibility
47 | primal_gap_positive_part = primal_gap[primal_gap >= 0]
48 | assert_close(primal_gap_positive_part.norm(), primal_gap.norm(), atol=1e-04, rtol=0)
49 |
50 | # Complementary slackness
51 | slackness = dual_gap @ primal_gap
52 | assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
53 |
54 |
55 | @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)])
56 | @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4])
57 | def test_scale_invariant(shape: tuple[int, int], scaling: float):
58 | """
59 |     Tests that `project_weights` is invariant under scaling.
60 | """
61 |
62 | J = randn_(shape)
63 | G = J @ J.T
64 | u = rand_(shape[0])
65 |
66 | w = project_weights(u, G, "quadprog")
67 | w_scaled = project_weights(u, scaling * G, "quadprog")
68 |
69 | assert_close(w_scaled, w)
70 |
71 |
72 | @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)])
73 | def test_tensorization_shape(shape: tuple[int, ...]):
74 | """
75 |     Tests that applying `project_weights` to a tensor is equivalent to applying it to the tensor
76 |     reshaped as a matrix and then reshaping the result back to the original tensor's shape.
77 | """
78 |
79 | matrix = randn_([shape[-1], shape[-1]])
80 | U_tensor = randn_(shape)
81 | U_matrix = U_tensor.reshape([-1, shape[-1]])
82 |
83 | G = matrix @ matrix.T
84 |
85 | W_tensor = project_weights(U_tensor, G, "quadprog")
86 | W_matrix = project_weights(U_matrix, G, "quadprog")
87 |
88 | assert_close(W_matrix.reshape(shape), W_tensor)
89 |
90 |
91 | def test_project_weight_vector_failure():
92 | """Tests that `_project_weight_vector` raises an error when the input G has too large values."""
93 |
94 | large_J = np.random.randn(10, 100) * 1e5
95 | large_G = large_J @ large_J.T
96 | with raises(ValueError):
97 | _project_weight_vector(np.ones(10), large_G, "quadprog")
98 |
--------------------------------------------------------------------------------
/tests/unit/autojac/_transform/test_diagonalize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytest import raises
3 | from utils.dict_assertions import assert_tensor_dicts_are_close
4 | from utils.tensors import tensor_
5 |
6 | from torchjd.autojac._transform import Diagonalize, OrderedSet, RequirementError
7 |
8 |
9 | def test_single_input():
10 | """Tests that the Diagonalize transform works when given a single input."""
11 |
12 | key = tensor_([1.0, 2.0, 3.0])
13 | value = torch.ones_like(key)
14 | input = {key: value}
15 |
16 | diag = Diagonalize(OrderedSet([key]))
17 |
18 | output = diag(input)
19 | expected_output = {key: tensor_([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])}
20 |
21 | assert_tensor_dicts_are_close(output, expected_output)
22 |
23 |
24 | def test_multiple_inputs():
25 | """Tests that the Diagonalize transform works when given multiple inputs."""
26 |
27 | key1 = tensor_([[1.0, 2.0], [4.0, 5.0]])
28 | key2 = tensor_([1.0, 3.0, 5.0])
29 | key3 = tensor_(1.0)
30 | value1 = torch.ones_like(key1)
31 | value2 = torch.ones_like(key2)
32 | value3 = torch.ones_like(key3)
33 | input = {key1: value1, key2: value2, key3: value3}
34 |
35 | diag = Diagonalize(OrderedSet([key1, key2, key3]))
36 |
37 | output = diag(input)
38 | expected_output = {
39 | key1: tensor_(
40 | [
41 | [[1.0, 0.0], [0.0, 0.0]],
42 | [[0.0, 1.0], [0.0, 0.0]],
43 | [[0.0, 0.0], [1.0, 0.0]],
44 | [[0.0, 0.0], [0.0, 1.0]],
45 | [[0.0, 0.0], [0.0, 0.0]],
46 | [[0.0, 0.0], [0.0, 0.0]],
47 | [[0.0, 0.0], [0.0, 0.0]],
48 | [[0.0, 0.0], [0.0, 0.0]],
49 | ],
50 | ),
51 | key2: tensor_(
52 | [
53 | [0.0, 0.0, 0.0],
54 | [0.0, 0.0, 0.0],
55 | [0.0, 0.0, 0.0],
56 | [0.0, 0.0, 0.0],
57 | [1.0, 0.0, 0.0],
58 | [0.0, 1.0, 0.0],
59 | [0.0, 0.0, 1.0],
60 | [0.0, 0.0, 0.0],
61 | ],
62 | ),
63 | key3: tensor_(
64 | [
65 | 0.0,
66 | 0.0,
67 | 0.0,
68 | 0.0,
69 | 0.0,
70 | 0.0,
71 | 0.0,
72 | 1.0,
73 | ],
74 | ),
75 | }
76 |
77 | assert_tensor_dicts_are_close(output, expected_output)
78 |
79 |
80 | def test_permute_order():
81 | """
82 | Tests that the Diagonalize transform outputs a permuted mapping when its keys are permuted.
83 | """
84 |
85 | key1 = tensor_(2.0)
86 | key2 = tensor_(1.0)
87 | value1 = torch.ones_like(key1)
88 | value2 = torch.ones_like(key2)
89 | input = {key1: value1, key2: value2}
90 |
91 | permuted_diag = Diagonalize(OrderedSet([key2, key1]))
92 | diag = Diagonalize(OrderedSet([key1, key2]))
93 |
94 | permuted_output = permuted_diag(input)
95 | output = {key1: permuted_output[key2], key2: permuted_output[key1]} # un-permute
96 | expected_output = diag(input)
97 |
98 | assert_tensor_dicts_are_close(output, expected_output)
99 |
100 |
101 | def test_check_keys():
102 | """
103 | Tests that the `check_keys` method works correctly. The input_keys must match the stored
104 | considered keys.
105 | """
106 |
107 | key1 = tensor_([1.0])
108 | key2 = tensor_([1.0])
109 | diag = Diagonalize(OrderedSet([key1]))
110 |
111 | output_keys = diag.check_keys({key1})
112 | assert output_keys == {key1}
113 |
114 | with raises(RequirementError):
115 | diag.check_keys(set())
116 |
117 | with raises(RequirementError):
118 | diag.check_keys({key1, key2})
119 |
--------------------------------------------------------------------------------
/docs/source/icons/TorchJD_text_dark.svg:
--------------------------------------------------------------------------------
1 | (SVG markup omitted: TorchJD text logo, dark variant)
--------------------------------------------------------------------------------
/src/torchjd/aggregation/_krum.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 |
5 | from ._aggregator_bases import GramianWeightedAggregator
6 | from ._weighting_bases import PSDMatrix, Weighting
7 |
8 |
9 | class Krum(GramianWeightedAggregator):
10 | """
11 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning,
12 | as defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent
13 | `_.
14 |
15 | :param n_byzantine: The number of rows of the input matrix that can come from an adversarial
16 | source.
17 | :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1.
18 | """
19 |
20 | def __init__(self, n_byzantine: int, n_selected: int = 1):
21 | self._n_byzantine = n_byzantine
22 | self._n_selected = n_selected
23 | super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected))
24 |
25 | def __repr__(self) -> str:
26 | return (
27 | f"{self.__class__.__name__}(n_byzantine={self._n_byzantine}, n_selected="
28 | f"{self._n_selected})"
29 | )
30 |
31 | def __str__(self) -> str:
32 | return f"Krum{self._n_byzantine}-{self._n_selected}"
33 |
34 |
35 | class KrumWeighting(Weighting[PSDMatrix]):
36 | """
37 | :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
38 | :class:`~torchjd.aggregation.Krum`.
39 |
40 | :param n_byzantine: The number of rows of the input matrix that can come from an adversarial
41 | source.
42 | :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1.
43 | """
44 |
45 | def __init__(self, n_byzantine: int, n_selected: int = 1):
46 | super().__init__()
47 | if n_byzantine < 0:
48 | raise ValueError(
49 | "Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = "
50 | f"{n_byzantine}`."
51 | )
52 |
53 | if n_selected < 1:
54 | raise ValueError(
55 | "Parameter `n_selected` should be a positive integer. Found `n_selected = "
56 | f"{n_selected}`."
57 | )
58 |
59 | self.n_byzantine = n_byzantine
60 | self.n_selected = n_selected
61 |
62 | def forward(self, gramian: Tensor) -> Tensor:
63 | self._check_matrix_shape(gramian)
64 | gradient_norms_squared = torch.diagonal(gramian)
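        # Squared pairwise distances between gradients, recovered from the Gramian alone:
        # ||g_i - g_j||^2 = ||g_i||^2 + ||g_j||^2 - 2 <g_i, g_j>.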
65 | distances_squared = (
66 | gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian
67 | )
68 | distances = torch.sqrt(distances_squared)
69 |
70 | n_closest = gramian.shape[0] - self.n_byzantine - 2
71 | smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False)
72 | smallest_distances_excluding_self = smallest_distances[:, 1:]
73 | scores = smallest_distances_excluding_self.sum(dim=1)
74 |
75 | _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
76 | one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0])
77 | weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected
78 |
79 | return weights
80 |
81 | def _check_matrix_shape(self, gramian: Tensor) -> None:
82 | min_rows = self.n_byzantine + 3
83 | if gramian.shape[0] < min_rows:
84 | raise ValueError(
85 | f"Parameter `gramian` should have at least {min_rows} rows (n_byzantine + 3). Found"
86 | f" `gramian` with {gramian.shape[0]} rows."
87 | )
88 |
89 | if gramian.shape[0] < self.n_selected:
90 | raise ValueError(
91 | f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). "
92 | f"Found `gramian` with {gramian.shape[0]} rows."
93 | )
94 |
--------------------------------------------------------------------------------
/tests/speed/autogram/grad_vs_jac_vs_gram.py:
--------------------------------------------------------------------------------
1 | import gc
2 |
3 | import torch
4 | from device import DEVICE
5 | from utils.architectures import (
6 | AlexNet,
7 | Cifar10Model,
8 | FreeParam,
9 | GroupNormMobileNetV3Small,
10 | InstanceNormMobileNetV2,
11 | InstanceNormResNet18,
12 | ModuleFactory,
13 | NoFreeParam,
14 | SqueezeNet,
15 | WithTransformerLarge,
16 | )
17 | from utils.forward_backwards import (
18 | autograd_forward_backward,
19 | autograd_gramian_forward_backward,
20 | autogram_forward_backward,
21 | autojac_forward_backward,
22 | make_mse_loss_fn,
23 | )
24 | from utils.tensors import make_inputs_and_targets
25 |
26 | from tests.speed.utils import print_times, time_call
27 | from torchjd.aggregation import Mean
28 | from torchjd.autogram import Engine
29 |
30 | PARAMETRIZATIONS = [
31 | (ModuleFactory(WithTransformerLarge), 8),
32 | (ModuleFactory(FreeParam), 64),
33 | (ModuleFactory(NoFreeParam), 64),
34 | (ModuleFactory(Cifar10Model), 64),
35 | (ModuleFactory(AlexNet), 8),
36 | (ModuleFactory(InstanceNormResNet18), 16),
37 | (ModuleFactory(GroupNormMobileNetV3Small), 16),
38 | (ModuleFactory(SqueezeNet), 4),
39 | (ModuleFactory(InstanceNormMobileNetV2), 2),
40 | ]
41 |
42 |
43 | def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
44 | model = factory()
45 | inputs, targets = make_inputs_and_targets(model, batch_size)
46 | loss_fn = make_mse_loss_fn(targets)
47 |
48 | A = Mean()
49 | W = A.weighting
50 |
51 | print(
52 | f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A}" f" on {DEVICE}."
53 | )
54 |
55 | def fn_autograd():
56 | autograd_forward_backward(model, inputs, loss_fn)
57 |
58 | def init_fn_autograd():
59 | torch.cuda.empty_cache()
60 | gc.collect()
61 | fn_autograd()
62 |
63 | def fn_autograd_gramian():
64 | autograd_gramian_forward_backward(model, inputs, loss_fn, W)
65 |
66 | def init_fn_autograd_gramian():
67 | torch.cuda.empty_cache()
68 | gc.collect()
69 | fn_autograd_gramian()
70 |
71 | def fn_autojac():
72 | autojac_forward_backward(model, inputs, loss_fn, A)
73 |
74 | def init_fn_autojac():
75 | torch.cuda.empty_cache()
76 | gc.collect()
77 | fn_autojac()
78 |
79 | def fn_autogram():
80 | autogram_forward_backward(model, inputs, loss_fn, engine, W)
81 |
82 | def init_fn_autogram():
83 | torch.cuda.empty_cache()
84 | gc.collect()
85 | fn_autogram()
86 |
87 | def optionally_cuda_sync():
88 | if str(DEVICE).startswith("cuda"):
89 | torch.cuda.synchronize()
90 |
91 | def pre_fn():
92 | model.zero_grad()
93 | optionally_cuda_sync()
94 |
95 | def post_fn():
96 | optionally_cuda_sync()
97 |
98 | n_runs = 10
99 | autograd_times = time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs)
100 | print_times("autograd", autograd_times)
101 |
102 | autograd_gramian_times = time_call(
103 | fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs
104 | )
105 | print_times("autograd gramian", autograd_gramian_times)
106 |
107 | autojac_times = time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs)
108 | print_times("autojac", autojac_times)
109 |
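    # `engine` is only created at this point, but fn_autogram (defined above) captures it by
    # closure, so it just needs to exist before time_call first invokes fn_autogram below.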
110 | engine = Engine(model, batch_dim=0)
111 | autogram_times = time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)
112 | print_times("autogram", autogram_times)
113 |
114 |
115 | def main():
116 | for factory, batch_size in PARAMETRIZATIONS:
117 | compare_autograd_autojac_and_autogram_speed(factory, batch_size)
118 | print("\n")
119 |
120 |
121 | if __name__ == "__main__":
122 | # To test this on cuda, add the following environment variables when running this:
123 | # CUBLAS_WORKSPACE_CONFIG=:4096:8;PYTEST_TORCH_DEVICE=cuda:0
124 | main()
125 |
--------------------------------------------------------------------------------