├── src └── torchjd │ ├── py.typed │ ├── aggregation │ ├── _utils │ │ ├── __init__.py │ │ ├── str.py │ │ ├── non_differentiable.py │ │ ├── check_dependencies.py │ │ ├── pref_vector.py │ │ ├── gramian.py │ │ └── dual_cone.py │ ├── _sum.py │ ├── _mean.py │ ├── _random.py │ ├── _imtl_g.py │ ├── _flattening.py │ ├── _constant.py │ ├── _pcgrad.py │ ├── _trimmed_mean.py │ ├── _aggregator_bases.py │ ├── _weighting_bases.py │ ├── _mgda.py │ ├── _graddrop.py │ ├── _config.py │ └── _krum.py │ ├── autojac │ ├── _transform │ │ ├── __init__.py │ │ ├── _materialize.py │ │ ├── _init.py │ │ ├── _select.py │ │ ├── _ordered_set.py │ │ ├── _accumulate.py │ │ ├── _stack.py │ │ ├── _grad.py │ │ ├── _diagonalize.py │ │ └── _differentiate.py │ ├── __init__.py │ └── _utils.py │ ├── __init__.py │ └── autogram │ ├── _gramian_accumulator.py │ ├── __init__.py │ ├── _edge_registry.py │ ├── _gramian_computer.py │ └── _gramian_utils.py ├── tests ├── doc │ ├── __init__.py │ ├── test_backward.py │ ├── test_aggregation.py │ └── test_autogram.py ├── plots │ └── __init__.py ├── speed │ ├── __init__.py │ ├── autogram │ │ ├── __init__.py │ │ └── grad_vs_jac_vs_gram.py │ └── utils.py ├── unit │ ├── __init__.py │ ├── autogram │ │ ├── __init__.py │ │ ├── test_gramian_utils.py │ │ └── test_edge_registry.py │ ├── autojac │ │ ├── __init__.py │ │ └── _transform │ │ │ ├── __init__.py │ │ │ ├── test_init.py │ │ │ ├── test_select.py │ │ │ ├── test_accumulate.py │ │ │ ├── test_stack.py │ │ │ └── test_diagonalize.py │ ├── aggregation │ │ ├── __init__.py │ │ ├── _utils │ │ │ ├── __init__.py │ │ │ ├── test_pref_vector.py │ │ │ └── test_dual_cone.py │ │ ├── test_aggregator_bases.py │ │ ├── test_random.py │ │ ├── test_aligned_mtl.py │ │ ├── test_sum.py │ │ ├── test_mean.py │ │ ├── test_imtl_g.py │ │ ├── test_config.py │ │ ├── _inputs.py │ │ ├── test_cagrad.py │ │ ├── test_pcgrad.py │ │ ├── test_trimmed_mean.py │ │ ├── test_mgda.py │ │ ├── test_nash_mtl.py │ │ ├── test_dualproj.py │ │ ├── test_krum.py │ │ ├── test_upgrad.py │ │ ├── test_graddrop.py │ │ └── test_constant.py │ └── test_deprecations.py ├── utils │ ├── __init__.py │ ├── contexts.py │ ├── dict_assertions.py │ └── tensors.py ├── device.py └── conftest.py ├── docs ├── source │ ├── icons │ │ ├── favicon.ico │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ ├── apple-touch-icon.png │ │ ├── android-chrome-192x192.png │ │ ├── android-chrome-512x512.png │ │ ├── site.webmanifest │ │ ├── TorchJD_logo.svg │ │ ├── TorchJD_text.svg │ │ └── TorchJD_text_dark.svg │ ├── _static │ │ ├── logo-dark-mode.png │ │ └── logo-light-mode.png │ ├── docs │ │ ├── autojac │ │ │ ├── backward.rst │ │ │ ├── mtl_backward.rst │ │ │ └── index.rst │ │ ├── autogram │ │ │ ├── engine.rst │ │ │ └── index.rst │ │ └── aggregation │ │ │ ├── config.rst │ │ │ ├── graddrop.rst │ │ │ ├── nash_mtl.rst │ │ │ ├── flattening.rst │ │ │ ├── trimmed_mean.rst │ │ │ ├── sum.rst │ │ │ ├── krum.rst │ │ │ ├── mean.rst │ │ │ ├── mgda.rst │ │ │ ├── imtl_g.rst │ │ │ ├── cagrad.rst │ │ │ ├── pcgrad.rst │ │ │ ├── random.rst │ │ │ ├── upgrad.rst │ │ │ ├── constant.rst │ │ │ ├── dualproj.rst │ │ │ ├── aligned_mtl.rst │ │ │ └── index.rst │ ├── installation.md │ ├── _templates │ │ └── page.html │ ├── examples │ │ ├── rnn.rst │ │ ├── partial_jd.rst │ │ ├── index.rst │ │ ├── basic_usage.rst │ │ ├── mtl.rst │ │ ├── iwmtl.rst │ │ ├── amp.rst │ │ ├── monitoring.rst │ │ └── lightning_integration.rst │ └── index.rst ├── Makefile └── make.bat ├── .github └── workflows │ ├── release.yml │ ├── build-deploy-docs.yml │ └── tests.yml ├── .gitignore └── 
.pre-commit-config.yaml /src/torchjd/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/doc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/plots/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/speed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/speed/autogram/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/autogram/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/autojac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/aggregation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/aggregation/_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/autojac/_transform/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/icons/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/favicon.ico -------------------------------------------------------------------------------- /docs/source/icons/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/favicon-16x16.png -------------------------------------------------------------------------------- /docs/source/icons/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/favicon-32x32.png -------------------------------------------------------------------------------- /docs/source/_static/logo-dark-mode.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/_static/logo-dark-mode.png -------------------------------------------------------------------------------- /docs/source/icons/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/apple-touch-icon.png -------------------------------------------------------------------------------- /docs/source/_static/logo-light-mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/_static/logo-light-mode.png -------------------------------------------------------------------------------- /docs/source/docs/autojac/backward.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | backward 4 | ======== 5 | 6 | .. autofunction:: torchjd.autojac.backward 7 | -------------------------------------------------------------------------------- /docs/source/icons/android-chrome-192x192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/android-chrome-192x192.png -------------------------------------------------------------------------------- /docs/source/icons/android-chrome-512x512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchJD/torchjd/HEAD/docs/source/icons/android-chrome-512x512.png -------------------------------------------------------------------------------- /docs/source/docs/autojac/mtl_backward.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | mtl_backward 4 | ============ 5 | 6 | .. autofunction:: torchjd.autojac.mtl_backward 7 | -------------------------------------------------------------------------------- /docs/source/docs/autogram/engine.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Engine 4 | ====== 5 | 6 | .. autoclass:: torchjd.autogram.Engine 7 | :members: 8 | :undoc-members: 9 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/config.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | ConFIG 4 | ====== 5 | 6 | .. autoclass:: torchjd.aggregation.ConFIG 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | -------------------------------------------------------------------------------- /docs/source/docs/autogram/index.rst: -------------------------------------------------------------------------------- 1 | autogram 2 | ======== 3 | 4 | 5 | .. automodule:: torchjd.autogram 6 | :members: 7 | 8 | .. toctree:: 9 | :hidden: 10 | :maxdepth: 1 11 | 12 | engine.rst 13 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/graddrop.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | GradDrop 4 | ======== 5 | 6 | .. 
autoclass:: torchjd.aggregation.GradDrop 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/nash_mtl.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Nash-MTL 4 | ======== 5 | 6 | .. autoclass:: torchjd.aggregation.NashMTL 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/flattening.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Flattening 4 | ========== 5 | 6 | .. autoclass:: torchjd.aggregation.Flattening 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/trimmed_mean.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Trimmed Mean 4 | ============ 5 | 6 | .. autoclass:: torchjd.aggregation.TrimmedMean 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | -------------------------------------------------------------------------------- /docs/source/docs/autojac/index.rst: -------------------------------------------------------------------------------- 1 | autojac 2 | ======= 3 | 4 | .. automodule:: torchjd.autojac 5 | :members: 6 | 7 | 8 | .. toctree:: 9 | :hidden: 10 | :maxdepth: 1 11 | 12 | backward.rst 13 | mtl_backward.rst 14 | -------------------------------------------------------------------------------- /docs/source/icons/site.webmanifest: -------------------------------------------------------------------------------- 1 | {"name":"","short_name":"","icons":[{"src":"/android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"/android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ffffff","background_color":"#ffffff","display":"standalone"} 2 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/sum.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Sum 4 | === 5 | 6 | .. autoclass:: torchjd.aggregation.Sum 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.SumWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/krum.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Krum 4 | ==== 5 | 6 | .. autoclass:: torchjd.aggregation.Krum 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.KrumWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/mean.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Mean 4 | ==== 5 | 6 | .. autoclass:: torchjd.aggregation.Mean 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. 
autoclass:: torchjd.aggregation.MeanWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/mgda.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | MGDA 4 | ==== 5 | 6 | .. autoclass:: torchjd.aggregation.MGDA 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.MGDAWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/imtl_g.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | IMTL-G 4 | ====== 5 | 6 | .. autoclass:: torchjd.aggregation.IMTLG 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.IMTLGWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/cagrad.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | CAGrad 4 | ====== 5 | 6 | .. autoclass:: torchjd.aggregation.CAGrad 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.CAGradWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/pcgrad.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | PCGrad 4 | ====== 5 | 6 | .. autoclass:: torchjd.aggregation.PCGrad 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.PCGradWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/random.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Random 4 | ====== 5 | 6 | .. autoclass:: torchjd.aggregation.Random 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.RandomWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/upgrad.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | UPGrad 4 | ====== 5 | 6 | .. autoclass:: torchjd.aggregation.UPGrad 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.UPGradWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/constant.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Constant 4 | ======== 5 | 6 | .. autoclass:: torchjd.aggregation.Constant 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. 
autoclass:: torchjd.aggregation.ConstantWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/dualproj.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | DualProj 4 | ======== 5 | 6 | .. autoclass:: torchjd.aggregation.DualProj 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.DualProjWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/aligned_mtl.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | Aligned-MTL 4 | =========== 5 | 6 | .. autoclass:: torchjd.aggregation.AlignedMTL 7 | :members: 8 | :undoc-members: 9 | :exclude-members: forward 10 | 11 | .. autoclass:: torchjd.aggregation.AlignedMTLWeighting 12 | :members: 13 | :undoc-members: 14 | :exclude-members: forward 15 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_utils/str.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | def vector_to_str(vector: Tensor) -> str: 5 | """ 6 | Transforms a Tensor of the form `tensor([1.23456, 1.0, ...])` into a string of the form 7 | `1.23, 1., ...`. 8 | """ 9 | 10 | weights_str = ", ".join(["{:.2f}".format(value).rstrip("0") for value in vector]) 11 | return weights_str 12 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/__init__.py: -------------------------------------------------------------------------------- 1 | from ._accumulate import Accumulate 2 | from ._aggregate import Aggregate 3 | from ._base import Composition, Conjunction, RequirementError, Transform 4 | from ._diagonalize import Diagonalize 5 | from ._grad import Grad 6 | from ._init import Init 7 | from ._jac import Jac 8 | from ._ordered_set import OrderedSet 9 | from ._select import Select 10 | from ._stack import Stack 11 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_utils/non_differentiable.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | 3 | 4 | class NonDifferentiableError(RuntimeError): 5 | def __init__(self, module: nn.Module): 6 | super().__init__(f"Trying to differentiate through {module}, which is not differentiable.") 7 | 8 | 9 | def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...] | Tensor) -> None: 10 | raise NonDifferentiableError(module) 11 | -------------------------------------------------------------------------------- /src/torchjd/autojac/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package enables Jacobian descent, through the ``backward`` and ``mtl_backward`` functions, which 3 | are meant to replace the call to ``torch.backward`` or ``loss.backward`` in gradient descent. To 4 | combine the information of the Jacobian, an aggregator from the ``aggregation`` package has to be 5 | used. 
6 | """ 7 | 8 | from ._backward import backward 9 | from ._mtl_backward import mtl_backward 10 | -------------------------------------------------------------------------------- /tests/unit/test_deprecations.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | # deprecated since 2025-08-18 5 | def test_deprecate_imports_from_torchjd(): 6 | with pytest.deprecated_call(): 7 | from torchjd import backward # noqa: F401 8 | 9 | with pytest.deprecated_call(): 10 | from torchjd import mtl_backward # noqa: F401 11 | 12 | with pytest.raises(ImportError): 13 | from torchjd import something_that_does_not_exist # noqa: F401 14 | -------------------------------------------------------------------------------- /tests/device.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | try: 6 | _device_str = os.environ["PYTEST_TORCH_DEVICE"] 7 | except KeyError: 8 | _device_str = "cpu" # Default to cpu if environment variable not set 9 | 10 | if _device_str != "cuda:0" and _device_str != "cpu": 11 | raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}") 12 | 13 | if _device_str == "cuda:0" and not torch.cuda.is_available(): 14 | raise ValueError('Requested device "cuda:0" but cuda is not available.') 15 | 16 | DEVICE = torch.device(_device_str) 17 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_utils/check_dependencies.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | 3 | 4 | class OptionalDepsNotInstalledError(ModuleNotFoundError): 5 | pass 6 | 7 | 8 | def check_dependencies_are_installed(dependency_names: list[str]) -> None: 9 | """ 10 | Check that the required dependencies are installed. 11 | 12 | This can be useful for Aggregators whose dependencies are optional when installing torchjd. 13 | """ 14 | 15 | if any(find_spec(name) is None for name in dependency_names): 16 | raise OptionalDepsNotInstalledError() 17 | -------------------------------------------------------------------------------- /tests/utils/contexts.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Generator 2 | from contextlib import AbstractContextManager, contextmanager 3 | from typing import Any, TypeAlias 4 | 5 | import torch 6 | from device import DEVICE 7 | 8 | ExceptionContext: TypeAlias = AbstractContextManager[Exception | None] 9 | 10 | 11 | @contextmanager 12 | def fork_rng(seed: int = 0) -> Generator[Any, None, None]: 13 | devices = [DEVICE] if DEVICE.type == "cuda" else [] 14 | with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx: 15 | torch.manual_seed(seed) 16 | yield ctx 17 | -------------------------------------------------------------------------------- /tests/utils/dict_assertions.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Hashable 2 | from typing import TypeVar 3 | 4 | from torch import Tensor 5 | from torch.testing import assert_close 6 | 7 | _KeyType = TypeVar("_KeyType", bound=Hashable) 8 | 9 | 10 | def assert_tensor_dicts_are_close(d1: dict[_KeyType, Tensor], d2: dict[_KeyType, Tensor]) -> None: 11 | """ 12 | Check that two dictionaries of tensors are close enough. Note that this does not require the 13 | keys to have the same ordering.
14 | """ 15 | 16 | assert d1.keys() == d2.keys() 17 | 18 | for key in d1: 19 | assert_close(d1[key], d2[key]) 20 | -------------------------------------------------------------------------------- /docs/source/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ```{include} ../../README.md 4 | :start-after: 5 | :end-before: 6 | ``` 7 | 8 | Note that `torchjd` requires Python 3.10, 3.11, 3.12, 3.13 or 3.14 and `torch>=2.0`. 9 | 10 | Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default 11 | when installing `torchjd`. To install them, you can use: 12 | ``` 13 | pip install torchjd[cagrad] 14 | ``` 15 | ``` 16 | pip install torchjd[nash_mtl] 17 | ``` 18 | 19 | To install `torchjd` with all of its optional dependencies, you can also use: 20 | ``` 21 | pip install torchjd[full] 22 | ``` 23 | -------------------------------------------------------------------------------- /tests/doc/test_backward.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the test of the backward usage example, with a verification of the value of the 3 | obtained `.grad` field. 4 | """ 5 | 6 | from torch.testing import assert_close 7 | 8 | 9 | def test_backward(): 10 | import torch 11 | 12 | from torchjd.aggregation import UPGrad 13 | from torchjd.autojac import backward 14 | 15 | param = torch.tensor([1.0, 2.0], requires_grad=True) 16 | # Compute arbitrary quantities that are function of param 17 | y1 = torch.tensor([-1.0, 1.0]) @ param 18 | y2 = (param**2).sum() 19 | 20 | backward([y1, y2], UPGrad()) 21 | 22 | assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) 23 | -------------------------------------------------------------------------------- /tests/speed/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def noop(): 8 | pass 9 | 10 | 11 | def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor: 12 | init_fn() 13 | 14 | times = [] 15 | for _ in range(n_runs): 16 | pre_fn() 17 | start = time.perf_counter() 18 | fn() 19 | post_fn() 20 | elapsed_time = time.perf_counter() - start 21 | times.append(elapsed_time) 22 | 23 | return torch.tensor(times) 24 | 25 | 26 | def print_times(name: str, times: Tensor) -> None: 27 | print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f}") 28 | print(times) 29 | print() 30 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= -W --keep-going -n 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | env: 8 | PYTHON_VERSION: 3.14 9 | 10 | jobs: 11 | pypi-publish: 12 | name: Publish to PyPI 13 | environment: release 14 | runs-on: ubuntu-latest 15 | permissions: 16 | # IMPORTANT: this permission is mandatory for trusted publishing 17 | id-token: write 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up uv 23 | uses: astral-sh/setup-uv@v5 24 | with: 25 | python-version: ${{ env.PYTHON_VERSION }} 26 | 27 | - name: Build 28 | run: uv build 29 | 30 | - name: Publish package distributions to PyPI 31 | run: uv publish -v 32 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_aggregator_bases.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from contextlib import nullcontext as does_not_raise 3 | 4 | from pytest import mark, raises 5 | from utils.contexts import ExceptionContext 6 | from utils.tensors import randn_ 7 | 8 | from torchjd.aggregation import Aggregator 9 | 10 | 11 | @mark.parametrize( 12 | ["shape", "expectation"], 13 | [ 14 | ([], raises(ValueError)), 15 | ([1], raises(ValueError)), 16 | ([1, 2], does_not_raise()), 17 | ([1, 2, 3], raises(ValueError)), 18 | ([1, 2, 3, 4], raises(ValueError)), 19 | ], 20 | ) 21 | def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext): 22 | with expectation: 23 | Aggregator._check_is_matrix(randn_(shape)) 24 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_sum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ._aggregator_bases import WeightedAggregator 5 | from ._weighting_bases import Matrix, Weighting 6 | 7 | 8 | class Sum(WeightedAggregator): 9 | """ 10 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums the rows of the input 11 | matrices. 12 | """ 13 | 14 | def __init__(self): 15 | super().__init__(weighting=SumWeighting()) 16 | 17 | 18 | class SumWeighting(Weighting[Matrix]): 19 | r""" 20 | :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights 21 | :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. 22 | """ 23 | 24 | def forward(self, matrix: Tensor) -> Tensor: 25 | device = matrix.device 26 | dtype = matrix.dtype 27 | weights = torch.ones(matrix.shape[0], device=device, dtype=dtype) 28 | return weights 29 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_materialize.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def materialize( 8 | optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor] 9 | ) -> tuple[Tensor, ...]: 10 | """ 11 | Transforms a sequence of optional tensors by replacing each None with a tensor of zeros of the same 12 | shape as the corresponding input. Returns the obtained sequence as a tuple.
13 | 14 | Note that the name "materialize" comes from the flag `materialize_grads` from 15 | `torch.autograd.grad`, which will be available in future torch releases. 16 | """ 17 | 18 | tensors = [] 19 | for optional_tensor, input in zip(optional_tensors, inputs): 20 | if optional_tensor is None: 21 | tensors.append(torch.zeros_like(input)) 22 | else: 23 | tensors.append(optional_tensor) 24 | return tuple(tensors) 25 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_init.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Set 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from ._base import RequirementError, TensorDict, Transform 7 | 8 | 9 | class Init(Transform): 10 | """ 11 | Transform from {} returning Gradients filled with ones for each of the provided values. 12 | 13 | :param values: Tensors for which Gradients must be returned. 14 | """ 15 | 16 | def __init__(self, values: Set[Tensor]): 17 | self.values = values 18 | 19 | def __call__(self, input: TensorDict) -> TensorDict: 20 | return {value: torch.ones_like(value) for value in self.values} 21 | 22 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: 23 | if not input_keys == set(): 24 | raise RequirementError( 25 | f"The input_keys should be the empty set. Found input_keys {input_keys}." 26 | ) 27 | return set(self.values) 28 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /src/torchjd/__init__.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from warnings import warn as _warn 3 | 4 | from .autojac import backward as _backward 5 | from .autojac import mtl_backward as _mtl_backward 6 | 7 | _deprecated_items: dict[str, tuple[str, Callable]] = { 8 | "backward": ("autojac", _backward), 9 | "mtl_backward": ("autojac", _mtl_backward), 10 | } 11 | 12 | 13 | def __getattr__(name: str) -> Callable: 14 | """ 15 | If an attribute is not found in the module's dictionary and its name is in _deprecated_items, 16 | then import it with a warning. 17 | """ 18 | if name in _deprecated_items: 19 | _warn( 20 | f"Importing `{name}` from `torchjd` is deprecated. 
Please import it from " 21 | f"`{_deprecated_items[name][0]}` instead.", 22 | DeprecationWarning, 23 | stacklevel=2, 24 | ) 25 | return _deprecated_items[name][1] 26 | raise AttributeError(f"module {__name__} has no attribute {name}") 27 | -------------------------------------------------------------------------------- /docs/source/docs/aggregation/index.rst: -------------------------------------------------------------------------------- 1 | aggregation 2 | =========== 3 | 4 | .. automodule:: torchjd.aggregation 5 | :no-members: 6 | 7 | Abstract base classes 8 | --------------------- 9 | 10 | .. autoclass:: torchjd.aggregation.Aggregator 11 | :members: 12 | :undoc-members: 13 | :exclude-members: forward 14 | 15 | .. autoclass:: torchjd.aggregation.Weighting 16 | :members: 17 | :undoc-members: 18 | :exclude-members: forward 19 | 20 | .. autoclass:: torchjd.aggregation.GeneralizedWeighting 21 | :members: 22 | :undoc-members: 23 | :exclude-members: forward 24 | 25 | 26 | .. toctree:: 27 | :hidden: 28 | :maxdepth: 1 29 | 30 | upgrad.rst 31 | aligned_mtl.rst 32 | cagrad.rst 33 | config.rst 34 | constant.rst 35 | dualproj.rst 36 | flattening.rst 37 | graddrop.rst 38 | imtl_g.rst 39 | krum.rst 40 | mean.rst 41 | mgda.rst 42 | nash_mtl.rst 43 | pcgrad.rst 44 | random.rst 45 | sum.rst 46 | trimmed_mean.rst 47 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_mean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ._aggregator_bases import WeightedAggregator 5 | from ._weighting_bases import Matrix, Weighting 6 | 7 | 8 | class Mean(WeightedAggregator): 9 | """ 10 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input 11 | matrices. 12 | """ 13 | 14 | def __init__(self): 15 | super().__init__(weighting=MeanWeighting()) 16 | 17 | 18 | class MeanWeighting(Weighting[Matrix]): 19 | r""" 20 | :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights 21 | :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in 22 | \mathbb{R}^m`. 
23 | """ 24 | 25 | def forward(self, matrix: Tensor) -> Tensor: 26 | device = matrix.device 27 | dtype = matrix.dtype 28 | m = matrix.shape[0] 29 | weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype) 30 | return weights 31 | -------------------------------------------------------------------------------- /tests/unit/aggregation/_utils/test_pref_vector.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as does_not_raise 2 | 3 | from pytest import mark, raises 4 | from torch import Tensor 5 | from utils.contexts import ExceptionContext 6 | from utils.tensors import ones_ 7 | 8 | from torchjd.aggregation._mean import MeanWeighting 9 | from torchjd.aggregation._utils.pref_vector import pref_vector_to_weighting 10 | 11 | 12 | @mark.parametrize( 13 | ["pref_vector", "expectation"], 14 | [ 15 | (None, does_not_raise()), 16 | (ones_([]), raises(ValueError)), 17 | (ones_([0]), does_not_raise()), 18 | (ones_([1]), does_not_raise()), 19 | (ones_([5]), does_not_raise()), 20 | (ones_([1, 1]), raises(ValueError)), 21 | (ones_([1, 1, 1]), raises(ValueError)), 22 | ], 23 | ) 24 | def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext): 25 | with expectation: 26 | _ = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) 27 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_select.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Set 2 | 3 | from torch import Tensor 4 | 5 | from ._base import RequirementError, TensorDict, Transform 6 | 7 | 8 | class Select(Transform): 9 | """ 10 | Transform returning a subset of the provided TensorDict. 11 | 12 | :param keys: The keys that should be included in the returned subset. 13 | """ 14 | 15 | def __init__(self, keys: Set[Tensor]): 16 | self.keys = keys 17 | 18 | def __call__(self, tensor_dict: TensorDict) -> TensorDict: 19 | output = {key: tensor_dict[key] for key in self.keys} 20 | return type(tensor_dict)(output) 21 | 22 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: 23 | keys = set(self.keys) 24 | if not keys.issubset(input_keys): 25 | raise RequirementError( 26 | f"The input_keys should be a super set of the keys to select. Found input_keys " 27 | f"{input_keys} and keys to select {keys}." 
28 | ) 29 | return keys 30 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_random.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | from torch import Tensor 3 | 4 | from torchjd.aggregation import Random 5 | 6 | from ._asserts import assert_expected_structure, assert_strongly_stationary 7 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices 8 | 9 | scaled_pairs = [(Random(), matrix) for matrix in scaled_matrices] 10 | typical_pairs = [(Random(), matrix) for matrix in typical_matrices] 11 | non_strong_pairs = [(Random(), matrix) for matrix in non_strong_matrices] 12 | 13 | 14 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 15 | def test_expected_structure(aggregator: Random, matrix: Tensor): 16 | assert_expected_structure(aggregator, matrix) 17 | 18 | 19 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) 20 | def test_strongly_stationary(aggregator: Random, matrix: Tensor): 21 | assert_strongly_stationary(aggregator, matrix) 22 | 23 | 24 | def test_representations(): 25 | A = Random() 26 | assert repr(A) == "Random()" 27 | assert str(A) == "Random" 28 | -------------------------------------------------------------------------------- /docs/source/_templates/page.html: -------------------------------------------------------------------------------- 1 | {# 2 | Adds a canonical link to the stable version of the documentation to each built html page. 3 | The reason to do that is that search engines require it: 4 | https://developers.google.com/search/docs/crawling-indexing/consolidate-duplicate-urls 5 | 6 | This template overrides the page.html furo template 7 | (https://github.com/pradyunsg/furo/blob/main/src/furo/theme/furo/page.html) by adding an extra 8 | line to its head section, with the appropriate canonical link. Note that there is no guarantee 9 | that the furo theme keeps using a file named page.html, so this could silently break with future 10 | updates of furo. See https://pradyunsg.me/furo/customisation/injecting/ and 11 | https://github.com/pradyunsg/furo/discussions/248 for more information. 12 | #} 13 | {% extends "!page.html" %} 14 | {% block extrahead %} 15 | 16 | 17 | 18 | {% endblock %} 19 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | 5 | from ._aggregator_bases import WeightedAggregator 6 | from ._weighting_bases import Matrix, Weighting 7 | 8 | 9 | class Random(WeightedAggregator): 10 | """ 11 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of 12 | the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of 13 | Random Weighting: A Litmus Test for Multi-Task Learning 14 | `_. 15 | """ 16 | 17 | def __init__(self): 18 | super().__init__(RandomWeighting()) 19 | 20 | 21 | class RandomWeighting(Weighting[Matrix]): 22 | """ 23 | :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights 24 | at each call.
25 | """ 26 | 27 | def forward(self, matrix: Tensor) -> Tensor: 28 | random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype) 29 | weights = F.softmax(random_vector, dim=-1) 30 | return weights 31 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_utils/pref_vector.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from torchjd.aggregation._constant import ConstantWeighting 4 | from torchjd.aggregation._weighting_bases import Matrix, Weighting 5 | 6 | from .str import vector_to_str 7 | 8 | 9 | def pref_vector_to_weighting( 10 | pref_vector: Tensor | None, default: Weighting[Matrix] 11 | ) -> Weighting[Matrix]: 12 | """ 13 | Returns the weighting associated to a given preference vector, with a fallback to a default 14 | weighting if the preference vector is None. 15 | """ 16 | 17 | if pref_vector is None: 18 | return default 19 | else: 20 | if pref_vector.ndim != 1: 21 | raise ValueError( 22 | "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " 23 | f"{pref_vector.ndim}`." 24 | ) 25 | return ConstantWeighting(pref_vector) 26 | 27 | 28 | def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: 29 | """Returns a suffix string containing the representation of the optional preference vector.""" 30 | 31 | if pref_vector is None: 32 | return "" 33 | else: 34 | return f"([{vector_to_str(pref_vector)}])" 35 | -------------------------------------------------------------------------------- /tests/doc/test_aggregation.py: -------------------------------------------------------------------------------- 1 | """This file contains the test corresponding to the usage example of Aggregator and Weighting.""" 2 | 3 | import torch 4 | from torch.testing import assert_close 5 | 6 | 7 | def test_aggregation_and_weighting(): 8 | from torch import tensor 9 | 10 | from torchjd.aggregation import UPGrad, UPGradWeighting 11 | 12 | aggregator = UPGrad() 13 | jacobian = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) 14 | aggregation = aggregator(jacobian) 15 | 16 | assert_close(aggregation, tensor([0.2929, 1.9004, 1.9004]), rtol=0, atol=1e-4) 17 | 18 | weighting = UPGradWeighting() 19 | gramian = jacobian @ jacobian.T 20 | weights = weighting(gramian) 21 | 22 | assert_close(weights, tensor([1.1109, 0.7894]), rtol=0, atol=1e-4) 23 | 24 | 25 | def test_generalized_weighting(): 26 | from torch import ones 27 | 28 | from torchjd.aggregation import Flattening, UPGradWeighting 29 | 30 | weighting = Flattening(UPGradWeighting()) 31 | # Generate a generalized Gramian filled with ones, for the sake of the example 32 | generalized_gramian = ones((2, 3, 3, 2)) 33 | weights = weighting(generalized_gramian) 34 | 35 | assert_close(weights, torch.full((2, 3), 0.1667), rtol=0, atol=1e-4) 36 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_aligned_mtl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytest import mark 3 | from torch import Tensor 4 | 5 | from torchjd.aggregation import AlignedMTL 6 | 7 | from ._asserts import assert_expected_structure, assert_permutation_invariant 8 | from ._inputs import scaled_matrices, typical_matrices 9 | 10 | scaled_pairs = [(AlignedMTL(), matrix) for matrix in scaled_matrices] 11 | typical_pairs = [(AlignedMTL(), matrix) for matrix in typical_matrices] 12 | 13 | 14 | 
@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 15 | def test_expected_structure(aggregator: AlignedMTL, matrix: Tensor): 16 | assert_expected_structure(aggregator, matrix) 17 | 18 | 19 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 20 | def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor): 21 | assert_permutation_invariant(aggregator, matrix) 22 | 23 | 24 | def test_representations(): 25 | A = AlignedMTL(pref_vector=None) 26 | assert repr(A) == "AlignedMTL(pref_vector=None)" 27 | assert str(A) == "AlignedMTL" 28 | 29 | A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu")) 30 | assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]))" 31 | assert str(A) == "AlignedMTL([1., 2., 3.])" 32 | -------------------------------------------------------------------------------- /src/torchjd/autogram/_gramian_accumulator.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor 4 | 5 | 6 | class GramianAccumulator: 7 | """ 8 | Efficiently accumulates the Gramian of the Jacobian during reverse-mode differentiation. 9 | 10 | Jacobians from multiple graph paths to the same parameter are first summed to obtain the full 11 | Jacobian w.r.t. a parameter, then its Gramian is computed and accumulated, over parameters, into 12 | the total Gramian matrix. Intermediate matrices are discarded immediately to save memory. 13 | """ 14 | 15 | def __init__(self) -> None: 16 | self._gramian: Optional[Tensor] = None 17 | 18 | def reset(self) -> None: 19 | self._gramian = None 20 | 21 | def accumulate_gramian(self, gramian: Tensor) -> None: 22 | if self._gramian is not None: 23 | self._gramian.add_(gramian) 24 | else: 25 | self._gramian = gramian 26 | 27 | @property 28 | def gramian(self) -> Optional[Tensor]: 29 | """ 30 | Get the Gramian matrix accumulated so far. 31 | 32 | :returns: Accumulated Gramian matrix of shape (batch_size, batch_size) or None if nothing 33 | was accumulated yet. 34 | """ 35 | 36 | return self._gramian 37 | -------------------------------------------------------------------------------- /tests/doc/test_autogram.py: -------------------------------------------------------------------------------- 1 | """This file contains tests for the usage examples related to autogram.""" 2 | 3 | 4 | def test_engine(): 5 | import torch 6 | from torch.nn import Linear, MSELoss, ReLU, Sequential 7 | from torch.optim import SGD 8 | 9 | from torchjd.aggregation import UPGradWeighting 10 | from torchjd.autogram import Engine 11 | 12 | # Generate data (8 batches of 16 examples of dim 5) for the sake of the example 13 | inputs = torch.randn(8, 16, 5) 14 | targets = torch.randn(8, 16) 15 | 16 | model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) 17 | optimizer = SGD(model.parameters()) 18 | 19 | criterion = MSELoss(reduction="none") # Important to use reduction="none" 20 | weighting = UPGradWeighting() 21 | 22 | # Create the engine before the backward pass, and only once. 
23 | engine = Engine(model, batch_dim=0) 24 | 25 | for input, target in zip(inputs, targets): 26 | output = model(input).squeeze(dim=1) # shape: [16] 27 | losses = criterion(output, target) # shape: [16] 28 | 29 | optimizer.zero_grad() 30 | gramian = engine.compute_gramian(losses) # shape: [16, 16] 31 | weights = weighting(gramian) # shape: [16] 32 | losses.backward(weights) 33 | optimizer.step() 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # uv 2 | uv.lock 3 | 4 | # Jupyter Notebooks 5 | *.ipynb 6 | 7 | # PyCharm 8 | .idea/ 9 | 10 | # Generated images 11 | images/ 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # Jupyter Notebook 68 | .ipynb_checkpoints 69 | 70 | # IPython 71 | profile_default/ 72 | ipython_config.py 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 78 | __pypackages__/ 79 | 80 | # Environments 81 | .env 82 | .venv 83 | env/ 84 | venv/ 85 | ENV/ 86 | env.bak/ 87 | venv.bak/ 88 | 89 | # mypy 90 | .mypy_cache/ 91 | .dmypy.json 92 | dmypy.json 93 | 94 | # macOS file 95 | .DS_Store 96 | 97 | # VS Code folder 98 | .vscode/ 99 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_sum.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | from torch import Tensor 3 | 4 | from torchjd.aggregation import Sum 5 | 6 | from ._asserts import ( 7 | assert_expected_structure, 8 | assert_linear_under_scaling, 9 | assert_permutation_invariant, 10 | assert_strongly_stationary, 11 | ) 12 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices 13 | 14 | scaled_pairs = [(Sum(), matrix) for matrix in scaled_matrices] 15 | typical_pairs = [(Sum(), matrix) for matrix in typical_matrices] 16 | non_strong_pairs = [(Sum(), matrix) for matrix in non_strong_matrices] 17 | 18 | 19 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 20 | def test_expected_structure(aggregator: Sum, matrix: Tensor): 21 | assert_expected_structure(aggregator, matrix) 22 | 23 | 24 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 25 | def test_permutation_invariant(aggregator: Sum, matrix: Tensor): 26 | assert_permutation_invariant(aggregator, matrix) 27 | 28 | 29 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 30 | def test_linear_under_scaling(aggregator: Sum, matrix: Tensor): 31 | assert_linear_under_scaling(aggregator, matrix) 32 | 33 | 34 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) 35 | def test_strongly_stationary(aggregator: 
Sum, matrix: Tensor): 36 | assert_strongly_stationary(aggregator, matrix) 37 | 38 | 39 | def test_representations(): 40 | A = Sum() 41 | assert repr(A) == "Sum()" 42 | assert str(A) == "Sum" 43 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_mean.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | from torch import Tensor 3 | 4 | from torchjd.aggregation import Mean 5 | 6 | from ._asserts import ( 7 | assert_expected_structure, 8 | assert_linear_under_scaling, 9 | assert_permutation_invariant, 10 | assert_strongly_stationary, 11 | ) 12 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices 13 | 14 | scaled_pairs = [(Mean(), matrix) for matrix in scaled_matrices] 15 | typical_pairs = [(Mean(), matrix) for matrix in typical_matrices] 16 | non_strong_pairs = [(Mean(), matrix) for matrix in non_strong_matrices] 17 | 18 | 19 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 20 | def test_expected_structure(aggregator: Mean, matrix: Tensor): 21 | assert_expected_structure(aggregator, matrix) 22 | 23 | 24 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 25 | def test_permutation_invariant(aggregator: Mean, matrix: Tensor): 26 | assert_permutation_invariant(aggregator, matrix) 27 | 28 | 29 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 30 | def test_linear_under_scaling(aggregator: Mean, matrix: Tensor): 31 | assert_linear_under_scaling(aggregator, matrix) 32 | 33 | 34 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) 35 | def test_strongly_stationary(aggregator: Mean, matrix: Tensor): 36 | assert_strongly_stationary(aggregator, matrix) 37 | 38 | 39 | def test_representations(): 40 | A = Mean() 41 | assert repr(A) == "Mean()" 42 | assert str(A) == "Mean" 43 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_imtl_g.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ._aggregator_bases import GramianWeightedAggregator 5 | from ._utils.non_differentiable import raise_non_differentiable_error 6 | from ._weighting_bases import PSDMatrix, Weighting 7 | 8 | 9 | class IMTLG(GramianWeightedAggregator): 10 | """ 11 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` generalizing the method described in 12 | `Towards Impartial Multi-task Learning `_. 13 | This generalization, defined formally in `Jacobian Descent For Multi-Objective Optimization 14 | `_, supports matrices with some linearly dependent rows. 15 | """ 16 | 17 | def __init__(self): 18 | super().__init__(IMTLGWeighting()) 19 | 20 | # This prevents computing gradients that can be very wrong. 21 | self.register_full_backward_pre_hook(raise_non_differentiable_error) 22 | 23 | 24 | class IMTLGWeighting(Weighting[PSDMatrix]): 25 | """ 26 | :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of 27 | :class:`~torchjd.aggregation.IMTLG`.
28 | """ 29 | 30 | def forward(self, gramian: Tensor) -> Tensor: 31 | d = torch.sqrt(torch.diagonal(gramian)) 32 | v = torch.linalg.pinv(gramian) @ d 33 | v_sum = v.sum() 34 | 35 | if v_sum.abs() < 1e-12: 36 | weights = torch.zeros_like(v) 37 | else: 38 | weights = v / v_sum 39 | 40 | return weights 41 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_utils/gramian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def compute_gramian(matrix: Tensor) -> Tensor: 6 | """ 7 | Computes the `Gramian matrix `_ of a given matrix. 8 | """ 9 | 10 | return matrix @ matrix.T 11 | 12 | 13 | def normalize(gramian: Tensor, eps: float) -> Tensor: 14 | """ 15 | Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`. 16 | 17 | If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the 18 | sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is 19 | therefore `G` divided by the sum of its diagonal elements. 20 | """ 21 | squared_frobenius_norm = gramian.diagonal().sum() 22 | if squared_frobenius_norm < eps: 23 | return torch.zeros_like(gramian) 24 | else: 25 | return gramian / squared_frobenius_norm 26 | 27 | 28 | def regularize(gramian: Tensor, eps: float) -> Tensor: 29 | """ 30 | Adds a regularization term to the gramian to enforce positive definiteness. 31 | 32 | Because of numerical errors, `gramian` might have slightly negative eigenvalue(s). Adding a 33 | regularization term which is a small proportion of the identity matrix ensures that the gramian 34 | is positive definite. 35 | """ 36 | 37 | regularization_matrix = eps * torch.eye( 38 | gramian.shape[0], dtype=gramian.dtype, device=gramian.device 39 | ) 40 | return gramian + regularization_matrix 41 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_flattening.py: -------------------------------------------------------------------------------- 1 | from math import prod 2 | 3 | from torch import Tensor 4 | 5 | from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting 6 | from torchjd.autogram._gramian_utils import reshape_gramian 7 | 8 | 9 | class Flattening(GeneralizedWeighting): 10 | """ 11 | :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting` flattening the generalized 12 | Gramian into a square matrix, extracting a vector of weights from it using a 13 | :class:`~torchjd.aggregation._weighting_bases.Weighting`, and returning the reshaped tensor of 14 | weights. 15 | 16 | For instance, when applied to a generalized Gramian of shape ``[2, 3, 3, 2]``, it would flatten 17 | it into a square Gramian matrix of shape ``[6, 6]``, apply the weighting on it to get a vector 18 | of weights of shape ``[6]``, and then return this vector reshaped into a matrix of shape 19 | ``[2, 3]``. 20 | 21 | :param weighting: The weighting to apply to the Gramian matrix. 
22 | """ 23 | 24 | def __init__(self, weighting: Weighting[PSDMatrix]): 25 | super().__init__() 26 | self.weighting = weighting 27 | 28 | def forward(self, generalized_gramian: Tensor) -> Tensor: 29 | k = generalized_gramian.ndim // 2 30 | shape = generalized_gramian.shape[:k] 31 | m = prod(shape) 32 | square_gramian = reshape_gramian(generalized_gramian, [m]) 33 | weights_vector = self.weighting(square_gramian) 34 | weights = weights_vector.reshape(shape) 35 | return weights 36 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_ordered_set.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import OrderedDict 4 | from collections.abc import Hashable, Iterable, Iterator, MutableSet 5 | from typing import TypeVar 6 | 7 | _T = TypeVar("_T", bound=Hashable) 8 | 9 | 10 | class OrderedSet(MutableSet[_T]): 11 | """Ordered collection of distinct elements.""" 12 | 13 | def __init__(self, elements: Iterable[_T]): 14 | super().__init__() 15 | self.ordered_dict = OrderedDict[_T, None]([(element, None) for element in elements]) 16 | 17 | def difference_update(self, elements: set[_T]) -> None: 18 | """Removes all specified elements from the OrderedSet.""" 19 | 20 | for element in elements: 21 | self.discard(element) 22 | 23 | def add(self, element: _T) -> None: 24 | """Adds the specified element to the OrderedSet.""" 25 | 26 | self.ordered_dict[element] = None 27 | 28 | def __add__(self, other: OrderedSet[_T]) -> OrderedSet[_T]: 29 | """Creates a new OrderedSet with the elements of self followed by the elements of other.""" 30 | 31 | return OrderedSet([*self, *other]) 32 | 33 | def discard(self, value: _T) -> None: 34 | if value in self: 35 | del self.ordered_dict[value] 36 | 37 | def __iter__(self) -> Iterator[_T]: 38 | return self.ordered_dict.__iter__() 39 | 40 | def __len__(self) -> int: 41 | return len(self.ordered_dict) 42 | 43 | def __contains__(self, element: object) -> bool: 44 | return element in self.ordered_dict 45 | -------------------------------------------------------------------------------- /src/torchjd/autogram/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The autogram package provides an engine to efficiently compute the Gramian of the Jacobian of a 3 | tensor of outputs (generally losses) with respect to some modules' parameters. This Gramian contains 4 | all the inner products between pairs of gradients, and is thus a sufficient statistic for most 5 | weighting methods. The algorithm is formally defined in Section 6 of `Jacobian Descent For 6 | Multi-Objective Optimization `_. 7 | 8 | Due to computing the Gramian iteratively over the layers, without ever having to store the full 9 | Jacobian in memory, this method is much more memory-efficient than 10 | :doc:`autojac <../autojac/index>`, which often makes it much faster. Note that we're still working 11 | on making autogram faster and more memory-efficient, and its interface may change in future 12 | releases.
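A minimal sketch of the intended workflow (here, ``model`` and ``losses`` are assumed to be an ``nn.Module`` and a vector of losses computed from a batch, as in the IWRM examples of the documentation):

.. code-block:: python

    from torchjd.aggregation import UPGradWeighting
    from torchjd.autogram import Engine

    engine = Engine(model, batch_dim=0)  # Track the parameters of the model's modules.
    weighting = UPGradWeighting()

    gramian = engine.compute_gramian(losses)  # Gramian of the Jacobian of the losses.
    weights = weighting(gramian)  # One weight per loss, extracted from the Gramian.
    losses.backward(weights)  # Usual backward pass, seeded with the obtained weights.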
13 | 14 | The list of Weightings compatible with ``autogram`` is: 15 | 16 | * :class:`~torchjd.aggregation.UPGradWeighting` 17 | * :class:`~torchjd.aggregation.AlignedMTLWeighting` 18 | * :class:`~torchjd.aggregation.CAGradWeighting` 19 | * :class:`~torchjd.aggregation.ConstantWeighting` 20 | * :class:`~torchjd.aggregation.DualProjWeighting` 21 | * :class:`~torchjd.aggregation.IMTLGWeighting` 22 | * :class:`~torchjd.aggregation.KrumWeighting` 23 | * :class:`~torchjd.aggregation.MeanWeighting` 24 | * :class:`~torchjd.aggregation.MGDAWeighting` 25 | * :class:`~torchjd.aggregation.PCGradWeighting` 26 | * :class:`~torchjd.aggregation.RandomWeighting` 27 | * :class:`~torchjd.aggregation.SumWeighting` 28 | """ 29 | 30 | from ._engine import Engine 31 | -------------------------------------------------------------------------------- /tests/utils/tensors.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from device import DEVICE 5 | from torch import nn 6 | from torch.utils._pytree import PyTree, tree_map 7 | from utils.architectures import get_in_out_shapes 8 | from utils.contexts import fork_rng 9 | 10 | # Curried calls to torch functions that require a device so that we automatically fix the device 11 | # for code written in the tests, while not affecting code written in src (what 12 | # torch.set_default_device or what a too large `with torch.device(DEVICE)` context would have done). 13 | 14 | empty_ = partial(torch.empty, device=DEVICE) 15 | eye_ = partial(torch.eye, device=DEVICE) 16 | ones_ = partial(torch.ones, device=DEVICE) 17 | rand_ = partial(torch.rand, device=DEVICE) 18 | randint_ = partial(torch.randint, device=DEVICE) 19 | randn_ = partial(torch.randn, device=DEVICE) 20 | randperm_ = partial(torch.randperm, device=DEVICE) 21 | tensor_ = partial(torch.tensor, device=DEVICE) 22 | zeros_ = partial(torch.zeros, device=DEVICE) 23 | 24 | 25 | def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[PyTree, PyTree]: 26 | input_shapes, output_shapes = get_in_out_shapes(model) 27 | with fork_rng(seed=0): 28 | inputs = _make_tensors(batch_size, input_shapes) 29 | targets = _make_tensors(batch_size, output_shapes) 30 | 31 | return inputs, targets 32 | 33 | 34 | def _make_tensors(batch_size: int, tensor_shapes: PyTree) -> PyTree: 35 | def is_leaf(s): 36 | return isinstance(s, tuple) and all([isinstance(e, int) for e in s]) 37 | 38 | return tree_map(lambda s: randn_((batch_size,) + s), tensor_shapes, is_leaf=is_leaf) 39 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_imtl_g.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | from torch import Tensor 3 | from torch.testing import assert_close 4 | from utils.tensors import ones_, zeros_ 5 | 6 | from torchjd.aggregation import IMTLG 7 | 8 | from ._asserts import ( 9 | assert_expected_structure, 10 | assert_non_differentiable, 11 | assert_permutation_invariant, 12 | ) 13 | from ._inputs import scaled_matrices, typical_matrices 14 | 15 | scaled_pairs = [(IMTLG(), matrix) for matrix in scaled_matrices] 16 | typical_pairs = [(IMTLG(), matrix) for matrix in typical_matrices] 17 | requires_grad_pairs = [(IMTLG(), ones_(3, 5, requires_grad=True))] 18 | 19 | 20 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 21 | def test_expected_structure(aggregator: IMTLG, matrix: Tensor): 22 | 
assert_expected_structure(aggregator, matrix) 23 | 24 | 25 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 26 | def test_permutation_invariant(aggregator: IMTLG, matrix: Tensor): 27 | assert_permutation_invariant(aggregator, matrix) 28 | 29 | 30 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) 31 | def test_non_differentiable(aggregator: IMTLG, matrix: Tensor): 32 | assert_non_differentiable(aggregator, matrix) 33 | 34 | 35 | def test_imtlg_zero(): 36 | """ 37 | Tests that IMTLG correctly returns the 0 vector in the special case where input matrix only 38 | consists of zeros. 39 | """ 40 | 41 | A = IMTLG() 42 | J = zeros_(2, 3) 43 | assert_close(A(J), zeros_(3)) 44 | 45 | 46 | def test_representations(): 47 | A = IMTLG() 48 | assert repr(A) == "IMTLG()" 49 | assert str(A) == "IMTLG" 50 | -------------------------------------------------------------------------------- /docs/source/examples/rnn.rst: -------------------------------------------------------------------------------- 1 | Recurrent Neural Network (RNN) 2 | ============================== 3 | 4 | When training recurrent neural networks for sequence modelling, we can easily obtain one loss per 5 | element of the output sequences. If the gradients of these losses are likely to conflict, Jacobian 6 | descent can be leveraged to enhance optimization. 7 | 8 | .. code-block:: python 9 | :emphasize-lines: 5-6, 10, 17, 20 10 | 11 | import torch 12 | from torch.nn import RNN 13 | from torch.optim import SGD 14 | 15 | from torchjd.aggregation import UPGrad 16 | from torchjd.autojac import backward 17 | 18 | rnn = RNN(input_size=10, hidden_size=20, num_layers=2) 19 | optimizer = SGD(rnn.parameters(), lr=0.1) 20 | aggregator = UPGrad() 21 | 22 | inputs = torch.randn(8, 5, 3, 10) # 8 batches of 3 sequences of length 5 and of dim 10. 23 | targets = torch.randn(8, 5, 3, 20) # 8 batches of 3 sequences of length 5 and of dim 20. 24 | 25 | for input, target in zip(inputs, targets): 26 | output, _ = rnn(input) # output is of shape [5, 3, 20]. 27 | losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element. 28 | 29 | optimizer.zero_grad() 30 | backward(losses, aggregator, parallel_chunk_size=1) 31 | optimizer.step() 32 | 33 | .. note:: 34 | At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and 35 | ``torch.nn.RNN`` when running on CUDA (see `this issue 36 | `_ for more info), so we advise to set the 37 | ``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve performance, you can 38 | check whether ``parallel_chunk_size=None`` (maximal parallelization) works on your side. 39 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_accumulate.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from ._base import TensorDict, Transform 4 | 5 | 6 | class Accumulate(Transform): 7 | """ 8 | Transform from Gradients to {} that accumulates gradients with respect to keys into their 9 | ``grad`` field. 10 | """ 11 | 12 | def __call__(self, gradients: TensorDict) -> TensorDict: 13 | for key in gradients.keys(): 14 | _check_expects_grad(key) 15 | if hasattr(key, "grad") and key.grad is not None: 16 | key.grad += gradients[key] 17 | else: 18 | # We clone the value because we do not want subsequent accumulations to also affect 19 | # this value (in case it is still used outside). 
We do not detach from the 20 | # computation graph because the value can have grad_fn that we want to keep track of 21 | # (in case it was obtained via create_graph=True and a differentiable aggregator). 22 | key.grad = gradients[key].clone() 23 | 24 | return {} 25 | 26 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: 27 | return set() 28 | 29 | 30 | def _check_expects_grad(tensor: Tensor) -> None: 31 | if not _expects_grad(tensor): 32 | raise ValueError( 33 | "Cannot populate the .grad field of a Tensor that does not satisfy:" 34 | "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." 35 | ) 36 | 37 | 38 | def _expects_grad(tensor: Tensor) -> bool: 39 | """ 40 | Determines whether a Tensor expects its .grad attribute to be populated. 41 | See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information. 42 | """ 43 | 44 | return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad) 45 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v6.0.0 4 | hooks: 5 | - id: trailing-whitespace # Trim trailing whitespace at the end of lines. 6 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline. 7 | - id: check-added-large-files # Prevent giant files from being committed. 8 | - id: check-case-conflict # Check for files that would conflict in case-insensitive filesystems. 9 | - id: check-docstring-first # Check a common error of defining a docstring after code. 10 | - id: check-merge-conflict # Check for files that contain merge conflict strings. 11 | 12 | - repo: https://github.com/PyCQA/flake8 13 | rev: 7.3.0 14 | hooks: 15 | - id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually. 16 | args: [ 17 | '--ignore=E501,E203,W503,E402', # Ignore line length problems, space after colon problems, line break occurring before a binary operator problems, module level import not at top of file problems. 18 | '--per-file-ignores=*/__init__.py:F401', # Ignore module imported but unused problems in __init__.py files. 19 | ] 20 | 21 | - repo: https://github.com/pycqa/isort 22 | rev: 6.1.0 23 | hooks: 24 | - id: isort # Sort imports. 25 | args: [ 26 | --multi-line=3, 27 | --line-length=100, 28 | --trailing-comma, 29 | --force-grid-wrap=0, 30 | --use-parentheses, 31 | --ensure-newline-before-comments, 32 | ] 33 | 34 | - repo: https://github.com/psf/black-pre-commit-mirror 35 | rev: 25.9.0 36 | hooks: 37 | - id: black # Format code. 
38 | args: [--line-length=100] 39 | 40 | ci: 41 | autoupdate_commit_msg: 'chore: Update pre-commit hooks' 42 | autoupdate_schedule: quarterly 43 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytest import mark 3 | from torch import Tensor 4 | from utils.tensors import ones_ 5 | 6 | from torchjd.aggregation import ConFIG 7 | 8 | from ._asserts import ( 9 | assert_expected_structure, 10 | assert_linear_under_scaling, 11 | assert_non_differentiable, 12 | assert_permutation_invariant, 13 | ) 14 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices 15 | 16 | scaled_pairs = [(ConFIG(), matrix) for matrix in scaled_matrices] 17 | typical_pairs = [(ConFIG(), matrix) for matrix in typical_matrices] 18 | non_strong_pairs = [(ConFIG(), matrix) for matrix in non_strong_matrices] 19 | requires_grad_pairs = [(ConFIG(), ones_(3, 5, requires_grad=True))] 20 | 21 | 22 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 23 | def test_expected_structure(aggregator: ConFIG, matrix: Tensor): 24 | assert_expected_structure(aggregator, matrix) 25 | 26 | 27 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 28 | def test_permutation_invariant(aggregator: ConFIG, matrix: Tensor): 29 | assert_permutation_invariant(aggregator, matrix) 30 | 31 | 32 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 33 | def test_linear_under_scaling(aggregator: ConFIG, matrix: Tensor): 34 | assert_linear_under_scaling(aggregator, matrix) 35 | 36 | 37 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) 38 | def test_non_differentiable(aggregator: ConFIG, matrix: Tensor): 39 | assert_non_differentiable(aggregator, matrix) 40 | 41 | 42 | def test_representations(): 43 | A = ConFIG() 44 | assert repr(A) == "ConFIG(pref_vector=None)" 45 | assert str(A) == "ConFIG" 46 | 47 | A = ConFIG(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu")) 48 | assert repr(A) == "ConFIG(pref_vector=tensor([1., 2., 3.]))" 49 | assert str(A) == "ConFIG([1., 2., 3.])" 50 | -------------------------------------------------------------------------------- /tests/unit/aggregation/_inputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from device import DEVICE 3 | from utils.tensors import zeros_ 4 | 5 | from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler 6 | 7 | _normal_dims = [ 8 | (1, 1, 1), 9 | (4, 3, 1), 10 | (4, 3, 2), 11 | (4, 3, 3), 12 | (9, 11, 5), 13 | (9, 11, 9), 14 | ] 15 | 16 | _zero_dims = [ 17 | (1, 1, 0), 18 | (4, 3, 0), 19 | (9, 11, 0), 20 | ] 21 | 22 | _stationarity_dims = [ 23 | (20, 10, 10), 24 | (20, 10, 5), 25 | (20, 10, 1), 26 | (20, 100, 1), 27 | (20, 100, 19), 28 | ] 29 | 30 | _scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15] 31 | 32 | _rng = torch.Generator(device=DEVICE).manual_seed(0) 33 | 34 | matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _normal_dims] 35 | zero_matrices = [zeros_([m, n]) for m, n, _ in _zero_dims] 36 | strong_matrices = [StrongSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims] 37 | strictly_weak_matrices = [StrictlyWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims] 38 | non_weak_matrices = [NonWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims] 39 | 40 | scaled_matrices = [scale * matrix for scale in _scales for matrix 
in matrices] 41 | 42 | non_strong_matrices = strictly_weak_matrices + non_weak_matrices 43 | typical_matrices = zero_matrices + matrices + strong_matrices + non_strong_matrices 44 | 45 | scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2] 46 | typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix.shape[0] >= 2] 47 | 48 | # It seems that NashMTL does not work for matrices with 1 row, so we make different matrices for it. 49 | _nashmtl_dims = [ 50 | (3, 1, 1), 51 | (4, 3, 1), 52 | (4, 3, 2), 53 | (4, 3, 3), 54 | (9, 11, 5), 55 | (9, 11, 9), 56 | ] 57 | nash_mtl_matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _nashmtl_dims] 58 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_constant.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from ._aggregator_bases import WeightedAggregator 4 | from ._utils.str import vector_to_str 5 | from ._weighting_bases import Matrix, Weighting 6 | 7 | 8 | class Constant(WeightedAggregator): 9 | """ 10 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of 11 | the rows of the provided matrix, with constant, pre-determined weights. 12 | 13 | :param weights: The weights associated to the rows of the input matrices. 14 | """ 15 | 16 | def __init__(self, weights: Tensor): 17 | super().__init__(weighting=ConstantWeighting(weights=weights)) 18 | self._weights = weights 19 | 20 | def __repr__(self) -> str: 21 | return f"{self.__class__.__name__}(weights={repr(self._weights)})" 22 | 23 | def __str__(self) -> str: 24 | weights_str = vector_to_str(self._weights) 25 | return f"{self.__class__.__name__}([{weights_str}])" 26 | 27 | 28 | class ConstantWeighting(Weighting[Matrix]): 29 | """ 30 | :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined 31 | weights. 32 | 33 | :param weights: The weights to return at each call. 34 | """ 35 | 36 | def __init__(self, weights: Tensor): 37 | if weights.dim() != 1: 38 | raise ValueError( 39 | "Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = " 40 | f"{weights.shape}`." 41 | ) 42 | 43 | super().__init__() 44 | self.weights = weights 45 | 46 | def forward(self, matrix: Tensor) -> Tensor: 47 | self._check_matrix_shape(matrix) 48 | return self.weights 49 | 50 | def _check_matrix_shape(self, matrix: Tensor) -> None: 51 | if matrix.shape[0] != len(self.weights): 52 | raise ValueError( 53 | f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified " 54 | f"weights). Found `matrix` with {matrix.shape[0]} rows." 55 | ) 56 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_pcgrad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ._aggregator_bases import GramianWeightedAggregator 5 | from ._utils.non_differentiable import raise_non_differentiable_error 6 | from ._weighting_bases import PSDMatrix, Weighting 7 | 8 | 9 | class PCGrad(GramianWeightedAggregator): 10 | """ 11 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of 12 | `Gradient Surgery for Multi-Task Learning `_. 
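Roughly, the ``PCGradWeighting`` below works directly on the Gramian of the matrix: each row starts with a unit weight on itself and, iterating over the other rows in a random order, removes the conflicting component whenever the inner product with such a row is negative; the resulting per-row weights are then summed.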
13 | """ 14 | 15 | def __init__(self): 16 | super().__init__(PCGradWeighting()) 17 | 18 | # This prevents running into a RuntimeError due to modifying stored tensors in place. 19 | self.register_full_backward_pre_hook(raise_non_differentiable_error) 20 | 21 | 22 | class PCGradWeighting(Weighting[PSDMatrix]): 23 | """ 24 | :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of 25 | :class:`~torchjd.aggregation.PCGrad`. 26 | """ 27 | 28 | def forward(self, gramian: Tensor) -> Tensor: 29 | # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration 30 | device = gramian.device 31 | dtype = gramian.dtype 32 | cpu = torch.device("cpu") 33 | gramian = gramian.to(device=cpu) 34 | 35 | dimension = gramian.shape[0] 36 | weights = torch.zeros(dimension, device=cpu, dtype=dtype) 37 | 38 | for i in range(dimension): 39 | permutation = torch.randperm(dimension) 40 | current_weights = torch.zeros(dimension, device=cpu, dtype=dtype) 41 | current_weights[i] = 1.0 42 | 43 | for j in permutation: 44 | if j == i: 45 | continue 46 | 47 | # Compute the inner product between g_i^{PC} and g_j 48 | inner_product = gramian[j] @ current_weights 49 | 50 | if inner_product < 0.0: 51 | current_weights[j] -= inner_product / (gramian[j, j]) 52 | 53 | weights = weights + current_weights 54 | 55 | return weights.to(device) 56 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_cagrad.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as does_not_raise 2 | 3 | from pytest import mark, raises 4 | from torch import Tensor 5 | from utils.contexts import ExceptionContext 6 | from utils.tensors import ones_ 7 | 8 | from torchjd.aggregation import CAGrad 9 | 10 | from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable 11 | from ._inputs import scaled_matrices, typical_matrices 12 | 13 | scaled_pairs = [(CAGrad(c=0.5), matrix) for matrix in scaled_matrices] 14 | typical_pairs = [(CAGrad(c=0.5), matrix) for matrix in typical_matrices] 15 | requires_grad_pairs = [(CAGrad(c=0.5), ones_(3, 5, requires_grad=True))] 16 | non_conflicting_pairs_1 = [(CAGrad(c=1.0), matrix) for matrix in typical_matrices] 17 | non_conflicting_pairs_2 = [(CAGrad(c=2.0), matrix) for matrix in typical_matrices] 18 | 19 | 20 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 21 | def test_expected_structure(aggregator: CAGrad, matrix: Tensor): 22 | assert_expected_structure(aggregator, matrix) 23 | 24 | 25 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) 26 | def test_non_differentiable(aggregator: CAGrad, matrix: Tensor): 27 | assert_non_differentiable(aggregator, matrix) 28 | 29 | 30 | @mark.parametrize(["aggregator", "matrix"], non_conflicting_pairs_1 + non_conflicting_pairs_2) 31 | def test_non_conflicting(aggregator: CAGrad, matrix: Tensor): 32 | """Tests that CAGrad is non-conflicting when c >= 1 (it should not hold when c < 1).""" 33 | assert_non_conflicting(aggregator, matrix) 34 | 35 | 36 | @mark.parametrize( 37 | ["c", "expectation"], 38 | [ 39 | (-5.0, raises(ValueError)), 40 | (-1.0, raises(ValueError)), 41 | (0.0, does_not_raise()), 42 | (1.0, does_not_raise()), 43 | (50.0, does_not_raise()), 44 | ], 45 | ) 46 | def test_c_check(c: float, expectation: ExceptionContext): 47 | with expectation: 48 | _ = CAGrad(c=c) 49 | 50 | 51 | def test_representations(): 52 | A = 
CAGrad(c=0.5, norm_eps=0.0001) 53 | assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)" 54 | assert str(A) == "CAGrad0.5" 55 | -------------------------------------------------------------------------------- /docs/source/examples/partial_jd.rst: -------------------------------------------------------------------------------- 1 | Partial Jacobian Descent for IWRM 2 | ================================= 3 | 4 | This example demonstrates how to perform Partial Jacobian Descent using TorchJD. This technique 5 | minimizes a vector of per-instance losses by resolving conflict only based on a submatrix of the 6 | Jacobian — specifically, the portion corresponding to a selected subset of the model's parameters. 7 | This approach offers a trade-off between the precision of the aggregation decision and the 8 | computational cost associated with computing the Gramian of the full Jacobian. For a complete, 9 | non-partial version, see the :doc:`IWRM ` example. 10 | 11 | In this example, our model consists of three ``Linear`` layers separated by ``ReLU`` layers. We 12 | perform the partial descent by considering only the parameters of the last two ``Linear`` layers. By 13 | doing this, we avoid computing the Jacobian and its Gramian with respect to the parameters of the 14 | first ``Linear`` layer, thereby reducing memory usage and computation time. 15 | 16 | .. code-block:: python 17 | :emphasize-lines: 16-18 18 | 19 | import torch 20 | from torch.nn import Linear, MSELoss, ReLU, Sequential 21 | from torch.optim import SGD 22 | 23 | from torchjd.aggregation import UPGradWeighting 24 | from torchjd.autogram import Engine 25 | 26 | X = torch.randn(8, 16, 10) 27 | Y = torch.randn(8, 16) 28 | 29 | model = Sequential(Linear(10, 8), ReLU(), Linear(8, 5), ReLU(), Linear(5, 1)) 30 | loss_fn = MSELoss(reduction="none") 31 | 32 | weighting = UPGradWeighting() 33 | 34 | # Create the autogram engine that will compute the Gramian of the 35 | # Jacobian with respect to the two last Linear layers' parameters. 36 | engine = Engine(model[2:], batch_dim=0) 37 | 38 | params = model.parameters() 39 | optimizer = SGD(params, lr=0.1) 40 | 41 | for x, y in zip(X, Y): 42 | y_hat = model(x).squeeze(dim=1) # shape: [16] 43 | losses = loss_fn(y_hat, y) # shape: [16] 44 | optimizer.zero_grad() 45 | gramian = engine.compute_gramian(losses) 46 | weights = weighting(gramian) 47 | losses.backward(weights) 48 | optimizer.step() 49 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_trimmed_mean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ._aggregator_bases import Aggregator 5 | 6 | 7 | class TrimmedMean(Aggregator): 8 | """ 9 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, 10 | that trims the most extreme values of the input matrix, before averaging its rows, as defined in 11 | `Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates 12 | `_. 13 | 14 | :param trim_number: The number of maximum and minimum values to remove from each column of the 15 | input matrix (note that ``2 * trim_number`` values are removed from each column). 16 | """ 17 | 18 | def __init__(self, trim_number: int): 19 | super().__init__() 20 | if trim_number < 0: 21 | raise ValueError( 22 | "Parameter `trim_number` should be a non-negative integer. Found `trim_number` = " 23 | f"{trim_number}`." 
24 | ) 25 | self.trim_number = trim_number 26 | 27 | def forward(self, matrix: Tensor) -> Tensor: 28 | self._check_is_matrix(matrix) 29 | self._check_matrix_has_enough_rows(matrix) 30 | 31 | n_rows = matrix.shape[0] 32 | n_remaining = n_rows - 2 * self.trim_number 33 | 34 | sorted_matrix, _ = torch.sort(matrix, dim=0) 35 | trimmed = torch.narrow(sorted_matrix, dim=0, start=self.trim_number, length=n_remaining) 36 | vector = trimmed.mean(dim=0) 37 | return vector 38 | 39 | def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: 40 | min_rows = 1 + 2 * self.trim_number 41 | n_rows = matrix.shape[0] 42 | if n_rows < min_rows: 43 | raise ValueError( 44 | f"Parameter `matrix` should be a matrix of at least {min_rows} rows " 45 | f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`." 46 | ) 47 | 48 | def __repr__(self) -> str: 49 | return f"{self.__class__.__name__}(trim_number={self.trim_number})" 50 | 51 | def __str__(self) -> str: 52 | return f"TM{self.trim_number}" 53 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_pcgrad.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | from torch import Tensor 3 | from torch.testing import assert_close 4 | from utils.tensors import ones_, randn_ 5 | 6 | from torchjd.aggregation import PCGrad 7 | from torchjd.aggregation._pcgrad import PCGradWeighting 8 | from torchjd.aggregation._upgrad import UPGradWeighting 9 | from torchjd.aggregation._utils.gramian import compute_gramian 10 | 11 | from ._asserts import assert_expected_structure, assert_non_differentiable 12 | from ._inputs import scaled_matrices, typical_matrices 13 | 14 | scaled_pairs = [(PCGrad(), matrix) for matrix in scaled_matrices] 15 | typical_pairs = [(PCGrad(), matrix) for matrix in typical_matrices] 16 | requires_grad_pairs = [(PCGrad(), ones_(3, 5, requires_grad=True))] 17 | 18 | 19 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 20 | def test_expected_structure(aggregator: PCGrad, matrix: Tensor): 21 | assert_expected_structure(aggregator, matrix) 22 | 23 | 24 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) 25 | def test_non_differentiable(aggregator: PCGrad, matrix: Tensor): 26 | assert_non_differentiable(aggregator, matrix) 27 | 28 | 29 | @mark.parametrize( 30 | "shape", 31 | [ 32 | (2, 5), 33 | (2, 7), 34 | (2, 9), 35 | (2, 15), 36 | (2, 27), 37 | (2, 68), 38 | (2, 102), 39 | (2, 57), 40 | (2, 1200), 41 | (2, 11100), 42 | ], 43 | ) 44 | def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]): 45 | """ 46 | Tests that UPGradWeighting of a SumWeighting is equivalent to PCGradWeighting for matrices of 2 47 | rows. 
48 | """ 49 | 50 | matrix = randn_(shape) 51 | gramian = compute_gramian(matrix) 52 | 53 | pc_grad_weighting = PCGradWeighting() 54 | upgrad_sum_weighting = UPGradWeighting( 55 | ones_((2,)), norm_eps=0.0, reg_eps=0.0, solver="quadprog" 56 | ) 57 | 58 | result = pc_grad_weighting(gramian) 59 | expected = upgrad_sum_weighting(gramian) 60 | 61 | assert_close(result, expected, atol=4e-04, rtol=0.0) 62 | 63 | 64 | def test_representations(): 65 | A = PCGrad() 66 | assert repr(A) == "PCGrad()" 67 | assert str(A) == "PCGrad" 68 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_trimmed_mean.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as does_not_raise 2 | 3 | from pytest import mark, raises 4 | from torch import Tensor 5 | from utils.contexts import ExceptionContext 6 | from utils.tensors import ones_ 7 | 8 | from torchjd.aggregation import TrimmedMean 9 | 10 | from ._asserts import assert_expected_structure, assert_permutation_invariant 11 | from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows 12 | 13 | scaled_pairs = [(TrimmedMean(trim_number=1), matrix) for matrix in scaled_matrices_2_plus_rows] 14 | typical_pairs = [(TrimmedMean(trim_number=1), matrix) for matrix in typical_matrices_2_plus_rows] 15 | 16 | 17 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 18 | def test_expected_structure(aggregator: TrimmedMean, matrix: Tensor): 19 | assert_expected_structure(aggregator, matrix) 20 | 21 | 22 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 23 | def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor): 24 | assert_permutation_invariant(aggregator, matrix) 25 | 26 | 27 | @mark.parametrize( 28 | ["trim_number", "expectation"], 29 | [ 30 | (-5, raises(ValueError)), 31 | (-1, raises(ValueError)), 32 | (0, does_not_raise()), 33 | (1, does_not_raise()), 34 | (5, does_not_raise()), 35 | ], 36 | ) 37 | def test_trim_number_check(trim_number: int, expectation: ExceptionContext): 38 | with expectation: 39 | _ = TrimmedMean(trim_number=trim_number) 40 | 41 | 42 | @mark.parametrize( 43 | ["n_rows", "trim_number", "expectation"], 44 | [ 45 | (1, 0, does_not_raise()), 46 | (1, 1, raises(ValueError)), 47 | (10, 0, does_not_raise()), 48 | (10, 4, does_not_raise()), 49 | (10, 5, raises(ValueError)), 50 | ], 51 | ) 52 | def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext): 53 | matrix = ones_([n_rows, 5]) 54 | aggregator = TrimmedMean(trim_number=trim_number) 55 | 56 | with expectation: 57 | _ = aggregator(matrix) 58 | 59 | 60 | def test_representations(): 61 | aggregator = TrimmedMean(trim_number=2) 62 | assert repr(aggregator) == "TrimmedMean(trim_number=2)" 63 | assert str(aggregator) == "TM2" 64 | -------------------------------------------------------------------------------- /src/torchjd/autogram/_edge_registry.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | from torch.autograd.graph import GradientEdge 4 | 5 | 6 | class EdgeRegistry: 7 | """ 8 | Tracks `GradientEdge`s and provides a way to efficiently compute a minimally sufficient subset 9 | of leaf edges that are reachable from some given `GradientEdge`s. 
10 | """ 11 | 12 | def __init__(self) -> None: 13 | self._edges: set[GradientEdge] = set() 14 | 15 | def reset(self) -> None: 16 | self._edges = set() 17 | 18 | def register(self, edge: GradientEdge) -> None: 19 | """ 20 | Track the provided edge. 21 | 22 | :param edge: Edge to track. 23 | """ 24 | self._edges.add(edge) 25 | 26 | def get_leaf_edges(self, roots: set[GradientEdge]) -> set[GradientEdge]: 27 | """ 28 | Compute a minimal subset of edges that yields the same differentiation graph traversal: the 29 | leaf edges. Specifically, this removes edges that are reachable from other edges in the 30 | differentiation graph, avoiding the need to keep gradients in memory for all edges 31 | simultaneously. 32 | 33 | :param roots: Roots of the graph traversal. Modified in-place. 34 | :returns: Minimal subset of leaf edges. 35 | """ 36 | 37 | nodes_to_traverse = deque((child, root) for root in roots for child in _next_edges(root)) 38 | result = {root for root in roots if root in self._edges} 39 | 40 | excluded = roots 41 | while nodes_to_traverse: 42 | node, origin = nodes_to_traverse.popleft() 43 | if node in self._edges: 44 | result.add(node) 45 | result.discard(origin) 46 | origin = node 47 | for child in _next_edges(node): 48 | if child not in excluded: 49 | nodes_to_traverse.append((child, origin)) 50 | excluded.add(child) 51 | return result 52 | 53 | 54 | def _next_edges(edge: GradientEdge) -> list[GradientEdge]: 55 | """ 56 | Get the next edges in the autograd graph from the given edge. 57 | 58 | :param edge: The current edge. 59 | """ 60 | return [GradientEdge(child, nr) for child, nr in edge.node.next_functions if child is not None] 61 | -------------------------------------------------------------------------------- /tests/unit/autojac/_transform/test_init.py: -------------------------------------------------------------------------------- 1 | from pytest import raises 2 | from utils.dict_assertions import assert_tensor_dicts_are_close 3 | from utils.tensors import tensor_ 4 | 5 | from torchjd.autojac._transform import Init, RequirementError 6 | 7 | 8 | def test_single_input(): 9 | """ 10 | Tests that when there is a single key to initialize, the Init transform creates a TensorDict 11 | whose value is a tensor full of ones, of the same shape as its key. 12 | """ 13 | 14 | key = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 15 | input = {} 16 | 17 | init = Init({key}) 18 | 19 | output = init(input) 20 | expected_output = {key: tensor_([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])} 21 | 22 | assert_tensor_dicts_are_close(output, expected_output) 23 | 24 | 25 | def test_multiple_inputs(): 26 | """ 27 | Tests that when there are several keys to initialize, the Init transform creates a TensorDict 28 | whose values are tensors full of ones, of the same shape as their corresponding keys. 29 | """ 30 | 31 | key1 = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 32 | key2 = tensor_([1.0, 3.0, 5.0]) 33 | input = {} 34 | 35 | init = Init({key1, key2}) 36 | 37 | output = init(input) 38 | expected = { 39 | key1: tensor_([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), 40 | key2: tensor_([1.0, 1.0, 1.0]), 41 | } 42 | assert_tensor_dicts_are_close(output, expected) 43 | 44 | 45 | def test_conjunction_of_inits_is_init(): 46 | """ 47 | Tests that the conjunction of 2 Init transforms is equivalent to a single Init transform with 48 | multiple keys. 
49 | """ 50 | 51 | x1 = tensor_(5.0) 52 | x2 = tensor_(6.0) 53 | input = {} 54 | 55 | init1 = Init({x1}) 56 | init2 = Init({x2}) 57 | conjunction_of_inits = init1 | init2 58 | init = Init({x1, x2}) 59 | 60 | output = conjunction_of_inits(input) 61 | expected_output = init(input) 62 | 63 | assert_tensor_dicts_are_close(output, expected_output) 64 | 65 | 66 | def test_check_keys(): 67 | """Tests that the `check_keys` method works correctly: the input_keys should be empty.""" 68 | 69 | key = tensor_([1.0]) 70 | init = Init({key}) 71 | 72 | output_keys = init.check_keys(set()) 73 | assert output_keys == {key} 74 | 75 | with raises(RequirementError): 76 | init.check_keys({key}) 77 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_stack.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from ._base import TensorDict, Transform 7 | from ._materialize import materialize 8 | 9 | 10 | class Stack(Transform): 11 | """ 12 | Transform from something to Jacobians, applying several transforms to the same input, and 13 | combining the results (by stacking) into a single TensorDict. 14 | 15 | The set of keys of the resulting dict is the union of the sets of keys of the input dicts. 16 | 17 | :param transforms: The transforms to apply. They should all be from the same thing and map to 18 | Gradients. Their outputs may have different sets of keys. If a key is absent in some output 19 | dicts, the corresponding stacked tensor is filled with zeroes at the positions corresponding 20 | to those dicts. 21 | """ 22 | 23 | def __init__(self, transforms: Sequence[Transform]): 24 | self.transforms = transforms 25 | 26 | def __call__(self, input: TensorDict) -> TensorDict: 27 | results = [transform(input) for transform in self.transforms] 28 | result = _stack(results) 29 | return result 30 | 31 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: 32 | return {key for transform in self.transforms for key in transform.check_keys(input_keys)} 33 | 34 | 35 | def _stack(gradient_dicts: list[TensorDict]) -> TensorDict: 36 | # It is important to first remove duplicate keys before computing their associated 37 | # stacked tensor. Otherwise, some computations would be duplicated. Therefore, we first compute 38 | # unique_keys, and only then, we compute the stacked tensors. 
39 | union: TensorDict = {} 40 | for d in gradient_dicts: 41 | union |= d 42 | unique_keys = union.keys() 43 | result = {key: _stack_one_key(gradient_dicts, key) for key in unique_keys} 44 | return result 45 | 46 | 47 | def _stack_one_key(gradient_dicts: list[TensorDict], input: Tensor) -> Tensor: 48 | """Makes the stacked tensor corresponding to a given key, from a list of tensor dicts.""" 49 | 50 | optional_gradients = [gradients.get(input, None) for gradients in gradient_dicts] 51 | gradients = materialize(optional_gradients, [input] * len(optional_gradients)) 52 | jacobian = torch.stack(gradients, dim=0) 53 | return jacobian 54 | -------------------------------------------------------------------------------- /tests/unit/autojac/_transform/test_select.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytest import raises 3 | from utils.dict_assertions import assert_tensor_dicts_are_close 4 | from utils.tensors import tensor_ 5 | 6 | from torchjd.autojac._transform import RequirementError, Select 7 | 8 | 9 | def test_partition(): 10 | """ 11 | Tests that the Select transform works correctly by applying 2 different Selects to a TensorDict, 12 | whose keys form a partition of the keys of the TensorDict. 13 | """ 14 | 15 | key1 = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 16 | key2 = tensor_([1.0, 3.0, 5.0]) 17 | key3 = tensor_(2.0) 18 | value1 = torch.ones_like(key1) 19 | value2 = torch.ones_like(key2) 20 | value3 = torch.ones_like(key3) 21 | input = {key1: value1, key2: value2, key3: value3} 22 | 23 | select1 = Select({key1, key2}) 24 | select2 = Select({key3}) 25 | 26 | output1 = select1(input) 27 | expected_output1 = {key1: value1, key2: value2} 28 | 29 | assert_tensor_dicts_are_close(output1, expected_output1) 30 | 31 | output2 = select2(input) 32 | expected_output2 = {key3: value3} 33 | 34 | assert_tensor_dicts_are_close(output2, expected_output2) 35 | 36 | 37 | def test_conjunction_of_selects_is_select(): 38 | """ 39 | Tests that the conjunction of 2 Select transforms is equivalent to directly using a Select with 40 | the union of the keys of the 2 Selects. 41 | """ 42 | 43 | x1 = tensor_(5.0) 44 | x2 = tensor_(6.0) 45 | x3 = tensor_(7.0) 46 | input = {x1: torch.ones_like(x1), x2: torch.ones_like(x2), x3: torch.ones_like(x3)} 47 | 48 | select1 = Select({x1}) 49 | select2 = Select({x2}) 50 | conjunction_of_selects = select1 | select2 51 | select = Select({x1, x2}) 52 | 53 | output = conjunction_of_selects(input) 54 | expected_output = select(input) 55 | 56 | assert_tensor_dicts_are_close(output, expected_output) 57 | 58 | 59 | def test_check_keys(): 60 | """ 61 | Tests that the `check_keys` method works correctly: the set of keys to select should be a subset 62 | of the set of required_keys. 
63 | """ 64 | 65 | key1 = tensor_([1.0]) 66 | key2 = tensor_([2.0]) 67 | key3 = tensor_([3.0]) 68 | 69 | output_keys = Select({key1, key2}).check_keys({key1, key2, key3}) 70 | assert output_keys == {key1, key2} 71 | 72 | with raises(RequirementError): 73 | Select({key1, key2}).check_keys({key1}) 74 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_mgda.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | from torch import Tensor 3 | from torch.testing import assert_close 4 | from utils.tensors import ones_, randn_ 5 | 6 | from torchjd.aggregation import MGDA 7 | from torchjd.aggregation._mgda import MGDAWeighting 8 | from torchjd.aggregation._utils.gramian import compute_gramian 9 | 10 | from ._asserts import ( 11 | assert_expected_structure, 12 | assert_non_conflicting, 13 | assert_permutation_invariant, 14 | ) 15 | from ._inputs import scaled_matrices, typical_matrices 16 | 17 | scaled_pairs = [(MGDA(), matrix) for matrix in scaled_matrices] 18 | typical_pairs = [(MGDA(), matrix) for matrix in typical_matrices] 19 | 20 | 21 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 22 | def test_expected_structure(aggregator: MGDA, matrix: Tensor): 23 | assert_expected_structure(aggregator, matrix) 24 | 25 | 26 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 27 | def test_non_conflicting(aggregator: MGDA, matrix: Tensor): 28 | assert_non_conflicting(aggregator, matrix) 29 | 30 | 31 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 32 | def test_permutation_invariant(aggregator: MGDA, matrix: Tensor): 33 | assert_permutation_invariant(aggregator, matrix) 34 | 35 | 36 | @mark.parametrize( 37 | "shape", 38 | [ 39 | (5, 7), 40 | (9, 37), 41 | (2, 14), 42 | (32, 114), 43 | (50, 100), 44 | ], 45 | ) 46 | def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]): 47 | matrix = randn_(shape) 48 | gramian = compute_gramian(matrix) 49 | 50 | weighting = MGDAWeighting(epsilon=1e-05, max_iters=1000) 51 | weights = weighting(gramian) 52 | 53 | output_direction = gramian @ weights # Stationarity 54 | lamb = -weights @ output_direction # Complementary slackness 55 | mu = output_direction + lamb 56 | 57 | # Primal feasibility 58 | positive_weights = weights[weights >= 0] 59 | assert_close(positive_weights.norm(), weights.norm()) 60 | 61 | weights_sum = weights.sum() 62 | assert_close(weights_sum, ones_([])) 63 | 64 | # Dual feasibility 65 | positive_mu = mu[mu >= 0] 66 | assert_close(positive_mu.norm(), mu.norm(), atol=1e-02, rtol=0.0) 67 | 68 | 69 | def test_representations(): 70 | A = MGDA(epsilon=0.001, max_iters=100) 71 | assert repr(A) == "MGDA(epsilon=0.001, max_iters=100)" 72 | assert str(A) == "MGDA" 73 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_grad.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from ._differentiate import Differentiate 7 | from ._ordered_set import OrderedSet 8 | 9 | 10 | class Grad(Differentiate): 11 | """ 12 | Transform from Gradients to Gradients, computing the gradient of each output element with 13 | respect to each input tensor, and applying the linear transformations represented by the 14 | provided grad_outputs to the results. 15 | 16 | :param outputs: Tensors to differentiate.
17 | :param inputs: Tensors with respect to which we differentiate. 18 | :param retain_graph: If False, the graph used to compute the grads will be freed. Defaults to 19 | False. 20 | :param create_graph: If True, graph of the derivative will be constructed, allowing to compute 21 | higher order derivative products. Defaults to False. 22 | 23 | .. note:: The order of outputs and inputs only matters because we have no guarantee that 24 | torch.autograd.grad is *exactly* equivariant to input permutations and invariant to output 25 | (with their corresponding grad_output) permutations. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | outputs: OrderedSet[Tensor], 31 | inputs: OrderedSet[Tensor], 32 | retain_graph: bool = False, 33 | create_graph: bool = False, 34 | ): 35 | super().__init__(outputs, inputs, retain_graph, create_graph) 36 | 37 | def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]: 38 | """ 39 | Computes the gradient of each output element with respect to each input tensor, and applies 40 | the linear transformations represented by the grad_outputs to the results. 41 | 42 | Returns one gradient per input, corresponding to the sum of the scaled gradients with 43 | respect to this input. 44 | 45 | :param grad_outputs: The sequence of tensors to scale the obtained gradients with. Its 46 | length should be equal to the length of ``outputs``. Each grad_output should have the 47 | same shape as the corresponding output. 48 | """ 49 | 50 | if len(self.inputs) == 0: 51 | return tuple() 52 | 53 | if len(self.outputs) == 0: 54 | return tuple(torch.zeros_like(input) for input in self.inputs) 55 | 56 | grads = self._get_vjp(grad_outputs, self.retain_graph) 57 | return grads 58 | -------------------------------------------------------------------------------- /tests/unit/autogram/test_gramian_utils.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | from torch.testing import assert_close 3 | from utils.forward_backwards import compute_gramian 4 | from utils.tensors import randn_ 5 | 6 | from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian 7 | 8 | 9 | @mark.parametrize( 10 | ["original_shape", "target_shape"], 11 | [ 12 | ([], []), 13 | ([], [1, 1]), 14 | ([1], []), 15 | ([12], [2, 3, 2]), 16 | ([12], [4, 3]), 17 | ([12], [12]), 18 | ([4, 3], [12]), 19 | ([4, 3], [2, 3, 2]), 20 | ([4, 3], [3, 4]), 21 | ([4, 3], [4, 3]), 22 | ([6, 7, 9], [378]), 23 | ([6, 7, 9], [9, 42]), 24 | ([6, 7, 9], [2, 7, 27]), 25 | ([6, 7, 9], [6, 7, 9]), 26 | ], 27 | ) 28 | def test_reshape_gramian(original_shape: list[int], target_shape: list[int]): 29 | """Tests that reshape_gramian is such that compute_gramian is equivariant to a reshape.""" 30 | 31 | original_matrix = randn_(original_shape + [2]) 32 | target_matrix = original_matrix.reshape(target_shape + [2]) 33 | 34 | original_gramian = compute_gramian(original_matrix) 35 | target_gramian = compute_gramian(target_matrix) 36 | 37 | reshaped_gramian = reshape_gramian(original_gramian, target_shape) 38 | 39 | assert_close(reshaped_gramian, target_gramian) 40 | 41 | 42 | @mark.parametrize( 43 | ["shape", "source", "destination"], 44 | [ 45 | ([], [], []), 46 | ([1], [0], [0]), 47 | ([1], [], []), 48 | ([1, 1], [], []), 49 | ([1, 1], [1], [0]), 50 | ([6, 7], [1], [0]), 51 | ([3, 1], [0, 1], [1, 0]), 52 | ([1, 1, 1], [], []), 53 | ([3, 2, 5], [], []), 54 | ([1, 1, 1], [2], [0]), 55 | ([3, 2, 5], [1], [2]), 56 | ([2, 2, 3], [0, 2], [1, 0]), 57 | ([2, 2, 
3], [0, 2, 1], [1, 0, 2]), 58 | ], 59 | ) 60 | def test_movedim_gramian(shape: list[int], source: list[int], destination: list[int]): 61 | """Tests that movedim_gramian is such that compute_gramian is equivariant to a movedim.""" 62 | 63 | original_matrix = randn_(shape + [2]) 64 | target_matrix = original_matrix.movedim(source, destination) 65 | 66 | original_gramian = compute_gramian(original_matrix) 67 | target_gramian = compute_gramian(target_matrix) 68 | 69 | moveddim_gramian = movedim_gramian(original_gramian, source, destination) 70 | 71 | assert_close(moveddim_gramian, target_gramian) 72 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_utils/dual_cone.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import numpy as np 4 | import torch 5 | from qpsolvers import solve_qp 6 | from torch import Tensor 7 | 8 | 9 | def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor: 10 | """ 11 | Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the 12 | rows of a matrix whose Gramian is provided. 13 | 14 | :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. 15 | :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. 16 | :param solver: The quadratic programming solver to use. 17 | :return: A tensor of projection weights with the same shape as `U`. 18 | """ 19 | 20 | G_ = _to_array(G) 21 | U_ = _to_array(U) 22 | 23 | W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_) 24 | 25 | return torch.as_tensor(W, device=G.device, dtype=G.dtype) 26 | 27 | 28 | def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray: 29 | r""" 30 | Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, 31 | given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies 32 | `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. 33 | 34 | By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program: 35 | minimize v^T G v 36 | subject to u \preceq v 37 | 38 | Reference: 39 | [1] `Jacobian Descent For Multi-Objective Optimization `_. 40 | 41 | :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to 42 | project. 43 | :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`. It must be 44 | symmetric and positive definite. 45 | :param solver: The quadratic programming solver to use. 46 | """ 47 | 48 | m = G.shape[0] 49 | w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver) 50 | 51 | if w is None: # This may happen when G has large values. 
52 | raise ValueError("Failed to solve the quadratic programming problem.") 53 | 54 | return w 55 | 56 | 57 | def _to_array(tensor: Tensor) -> np.ndarray: 58 | """Transforms a tensor into a numpy array with float64 dtype.""" 59 | 60 | return tensor.cpu().detach().numpy().astype(np.float64) 61 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_nash_mtl.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | from torch import Tensor 3 | from torch.testing import assert_close 4 | from utils.tensors import ones_, randn_ 5 | 6 | from torchjd.aggregation import NashMTL 7 | 8 | from ._asserts import assert_expected_structure, assert_non_differentiable 9 | from ._inputs import nash_mtl_matrices 10 | 11 | 12 | def _make_aggregator(matrix: Tensor) -> NashMTL: 13 | return NashMTL(n_tasks=matrix.shape[0]) 14 | 15 | 16 | standard_pairs = [(_make_aggregator(matrix), matrix) for matrix in nash_mtl_matrices] 17 | requires_grad_pairs = [(NashMTL(n_tasks=3), ones_(3, 5, requires_grad=True))] 18 | 19 | 20 | # Note that as opposed to most aggregators, the expected structure is only tested with non-scaled 21 | # matrices, and with matrices of > 1 row. Otherwise, NashMTL fails. 22 | @mark.filterwarnings( 23 | "ignore:Solution may be inaccurate.", 24 | "ignore:You are solving a parameterized problem that is not DPP.", 25 | ) 26 | @mark.parametrize(["aggregator", "matrix"], standard_pairs) 27 | def test_expected_structure(aggregator: NashMTL, matrix: Tensor): 28 | assert_expected_structure(aggregator, matrix) 29 | 30 | 31 | @mark.filterwarnings("ignore:You are solving a parameterized problem that is not DPP.") 32 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) 33 | def test_non_differentiable(aggregator: NashMTL, matrix: Tensor): 34 | assert_non_differentiable(aggregator, matrix) 35 | 36 | 37 | @mark.filterwarnings("ignore: You are solving a parameterized problem that is not DPP.") 38 | def test_nash_mtl_reset(): 39 | """ 40 | Tests that the reset method of NashMTL correctly resets its internal state, by verifying that 41 | the result is the same after reset as it is right after instantiation. 42 | 43 | To ensure that the aggregations are not all the same, we create different matrices to aggregate. 
44 | """ 45 | 46 | matrices = [randn_(3, 5) for _ in range(4)] 47 | aggregator = NashMTL(n_tasks=3, update_weights_every=3) 48 | expecteds = [aggregator(matrix) for matrix in matrices] 49 | 50 | aggregator.reset() 51 | results = [aggregator(matrix) for matrix in matrices] 52 | 53 | for result, expected in zip(results, expecteds): 54 | assert_close(result, expected) 55 | 56 | 57 | def test_representations(): 58 | A = NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5) 59 | assert repr(A) == "NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)" 60 | assert str(A) == "NashMTL" 61 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_dualproj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytest import mark 3 | from torch import Tensor 4 | from utils.tensors import ones_ 5 | 6 | from torchjd.aggregation import DualProj 7 | 8 | from ._asserts import ( 9 | assert_expected_structure, 10 | assert_non_conflicting, 11 | assert_non_differentiable, 12 | assert_permutation_invariant, 13 | assert_strongly_stationary, 14 | ) 15 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices 16 | 17 | scaled_pairs = [(DualProj(), matrix) for matrix in scaled_matrices] 18 | typical_pairs = [(DualProj(), matrix) for matrix in typical_matrices] 19 | non_strong_pairs = [(DualProj(), matrix) for matrix in non_strong_matrices] 20 | requires_grad_pairs = [(DualProj(), ones_(3, 5, requires_grad=True))] 21 | 22 | 23 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 24 | def test_expected_structure(aggregator: DualProj, matrix: Tensor): 25 | assert_expected_structure(aggregator, matrix) 26 | 27 | 28 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 29 | def test_non_conflicting(aggregator: DualProj, matrix: Tensor): 30 | assert_non_conflicting(aggregator, matrix, atol=5e-05, rtol=5e-05) 31 | 32 | 33 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 34 | def test_permutation_invariant(aggregator: DualProj, matrix: Tensor): 35 | assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07) 36 | 37 | 38 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) 39 | def test_strongly_stationary(aggregator: DualProj, matrix: Tensor): 40 | assert_strongly_stationary(aggregator, matrix, threshold=3e-03) 41 | 42 | 43 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) 44 | def test_non_differentiable(aggregator: DualProj, matrix: Tensor): 45 | assert_non_differentiable(aggregator, matrix) 46 | 47 | 48 | def test_representations(): 49 | A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") 50 | assert ( 51 | repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" 52 | ) 53 | assert str(A) == "DualProj" 54 | 55 | A = DualProj( 56 | pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), 57 | norm_eps=0.0001, 58 | reg_eps=0.0001, 59 | solver="quadprog", 60 | ) 61 | assert ( 62 | repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " 63 | "solver='quadprog')" 64 | ) 65 | assert str(A) == "DualProj([1., 2., 3.])" 66 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_krum.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as 
does_not_raise 2 | 3 | from pytest import mark, raises 4 | from torch import Tensor 5 | from utils.contexts import ExceptionContext 6 | from utils.tensors import ones_ 7 | 8 | from torchjd.aggregation import Krum 9 | 10 | from ._asserts import assert_expected_structure 11 | from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows 12 | 13 | scaled_pairs = [(Krum(n_byzantine=1), matrix) for matrix in scaled_matrices_2_plus_rows] 14 | typical_pairs = [(Krum(n_byzantine=1), matrix) for matrix in typical_matrices_2_plus_rows] 15 | 16 | 17 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 18 | def test_expected_structure(aggregator: Krum, matrix: Tensor): 19 | assert_expected_structure(aggregator, matrix) 20 | 21 | 22 | @mark.parametrize( 23 | ["n_byzantine", "expectation"], 24 | [ 25 | (-5, raises(ValueError)), 26 | (-1, raises(ValueError)), 27 | (0, does_not_raise()), 28 | (1, does_not_raise()), 29 | (5, does_not_raise()), 30 | ], 31 | ) 32 | def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext): 33 | with expectation: 34 | _ = Krum(n_byzantine=n_byzantine, n_selected=1) 35 | 36 | 37 | @mark.parametrize( 38 | ["n_selected", "expectation"], 39 | [ 40 | (-5, raises(ValueError)), 41 | (-1, raises(ValueError)), 42 | (0, raises(ValueError)), 43 | (1, does_not_raise()), 44 | (5, does_not_raise()), 45 | ], 46 | ) 47 | def test_n_selected_check(n_selected: int, expectation: ExceptionContext): 48 | with expectation: 49 | _ = Krum(n_byzantine=1, n_selected=n_selected) 50 | 51 | 52 | @mark.parametrize( 53 | ["n_byzantine", "n_selected", "n_rows", "expectation"], 54 | [ 55 | (1, 1, 3, raises(ValueError)), 56 | (1, 1, 4, does_not_raise()), 57 | (1, 4, 4, does_not_raise()), 58 | (12, 4, 14, raises(ValueError)), 59 | (12, 4, 15, does_not_raise()), 60 | (12, 15, 15, does_not_raise()), 61 | (12, 16, 15, raises(ValueError)), 62 | ], 63 | ) 64 | def test_matrix_shape_check( 65 | n_byzantine: int, n_selected: int, n_rows: int, expectation: ExceptionContext 66 | ): 67 | aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected) 68 | matrix = ones_([n_rows, 5]) 69 | 70 | with expectation: 71 | _ = aggregator(matrix) 72 | 73 | 74 | def test_representations(): 75 | A = Krum(n_byzantine=1, n_selected=2) 76 | assert repr(A) == "Krum(n_byzantine=1, n_selected=2)" 77 | assert str(A) == "Krum1-2" 78 | -------------------------------------------------------------------------------- /docs/source/examples/index.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | This section contains some usage examples for TorchJD. 5 | 6 | - :doc:`Basic Usage <basic_usage>` provides a toy example using :doc:`torchjd.backward 7 | <../docs/autojac/backward>` to make a step of Jacobian descent with the :doc:`UPGrad 8 | <../docs/aggregation/upgrad>` aggregator. 9 | - :doc:`Instance-Wise Risk Minimization (IWRM) <iwrm>` provides an example in which we minimize the 10 | vector of per-instance losses, using stochastic sub-Jacobian descent (SSJD). It is compared to the 11 | usual minimization of the average loss, called empirical risk minimization (ERM), using stochastic 12 | gradient descent (SGD). 13 | - :doc:`Partial Jacobian Descent for IWRM <partial_jd>` provides an example in which we minimize the 14 | vector of per-instance losses using stochastic sub-Jacobian descent, similar to our :doc:`IWRM <iwrm>` 15 | example. 
However, this method bases the aggregation decision on the Jacobian of the losses with respect 16 | to **only a subset** of the model's parameters, offering a trade-off between computational cost and 17 | aggregation precision. 18 | - :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian 19 | descent is used to optimize the vector of per-task losses of a multi-task model, using the 20 | dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`. 21 | - :doc:`Instance-Wise Multi-Task Learning (IWMTL) <iwmtl>` shows how to combine multi-task learning 22 | with instance-wise risk minimization: one loss per task and per element of the batch, using the 23 | :doc:`autogram.Engine <../docs/autogram/engine>` and a :doc:`GeneralizedWeighting 24 | <../docs/aggregation/index>`. 25 | - :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training, 26 | with one loss per output sequence element. 27 | - :doc:`Monitoring Aggregations <monitoring>` shows how to monitor the aggregation performed by the 28 | aggregator, to check if Jacobian descent is prescribed for your use-case. 29 | - :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine 30 | TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task 31 | ``LightningModule`` optimized by Jacobian descent. 32 | - :doc:`Automatic Mixed Precision <amp>` shows how to combine mixed precision training with TorchJD. 33 | 34 | .. toctree:: 35 | :hidden: 36 | 37 | basic_usage.rst 38 | iwrm.rst 39 | partial_jd.rst 40 | mtl.rst 41 | iwmtl.rst 42 | rnn.rst 43 | monitoring.rst 44 | lightning_integration.rst 45 | amp.rst 46 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import random as rand 2 | from contextlib import nullcontext 3 | 4 | import torch 5 | from device import DEVICE 6 | from pytest import RaisesExc, fixture, mark 7 | from torch import Tensor 8 | from utils.architectures import ModuleFactory 9 | 10 | from torchjd.aggregation import Aggregator, Weighting 11 | 12 | 13 | @fixture(autouse=True) 14 | def fix_randomness() -> None: 15 | rand.seed(0) 16 | torch.manual_seed(0) 17 | 18 | # Only force to use deterministic algorithms on CPU. 19 | # This is because the CI currently runs only on CPU, so we don't really need perfect 20 | # reproducibility on GPU. We also use GPU to benchmark algorithms, and we would rather have them 21 | # use non-deterministic but faster algorithms. 22 | if DEVICE.type == "cpu": 23 | torch.use_deterministic_algorithms(True) 24 | 25 | 26 | def pytest_addoption(parser): 27 | parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") 28 | 29 | 30 | def pytest_configure(config): 31 | config.addinivalue_line("markers", "slow: mark test as slow to run") 32 | config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda") 33 | 34 | 35 | def pytest_collection_modifyitems(config, items): 36 | skip_slow = mark.skip(reason="Slow test. 
Use --runslow to run it.") 37 | xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}") 38 | for item in items: 39 | if "slow" in item.keywords and not config.getoption("--runslow"): 40 | item.add_marker(skip_slow) 41 | if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"): 42 | item.add_marker(xfail_cuda) 43 | 44 | 45 | def pytest_make_parametrize_id(config, val, argname): 46 | MAX_SIZE = 40 47 | optional_string = None # Returning None means using pytest's way of making the string 48 | 49 | if isinstance(val, (Aggregator, ModuleFactory, Weighting)): 50 | optional_string = str(val) 51 | elif isinstance(val, Tensor): 52 | optional_string = "T" + str(list(val.shape)) # T to indicate that it's a tensor 53 | elif isinstance(val, (tuple, list, set)) and len(val) < 20: 54 | optional_string = str(val) 55 | elif isinstance(val, RaisesExc): 56 | optional_string = " or ".join([f"{exc.__name__}" for exc in val.expected_exceptions]) 57 | elif isinstance(val, nullcontext): 58 | optional_string = "does_not_raise()" 59 | 60 | if isinstance(optional_string, str) and len(optional_string) > MAX_SIZE: 61 | optional_string = optional_string[: MAX_SIZE - 3] + "+++" # Can't use dots with pytest 62 | 63 | return optional_string 64 | -------------------------------------------------------------------------------- /docs/source/icons/TorchJD_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 16 | 35 | 37 | 44 | 51 | 58 | 65 | 66 | -------------------------------------------------------------------------------- /src/torchjd/autogram/_gramian_computer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | from torch.utils._pytree import PyTree 6 | 7 | from torchjd.autogram._jacobian_computer import JacobianComputer 8 | 9 | 10 | class GramianComputer(ABC): 11 | @abstractmethod 12 | def __call__( 13 | self, 14 | rg_outputs: tuple[Tensor, ...], 15 | grad_outputs: tuple[Tensor, ...], 16 | args: tuple[PyTree, ...], 17 | kwargs: dict[str, PyTree], 18 | ) -> Optional[Tensor]: 19 | """Compute what we can for a module and optionally return the gramian if it's ready.""" 20 | 21 | def track_forward_call(self) -> None: 22 | """Track that the module's forward was called. Necessary in some implementations.""" 23 | 24 | def reset(self): 25 | """Reset state if any. Necessary in some implementations.""" 26 | 27 | 28 | class JacobianBasedGramianComputer(GramianComputer, ABC): 29 | def __init__(self, jacobian_computer): 30 | self.jacobian_computer = jacobian_computer 31 | 32 | @staticmethod 33 | def _to_gramian(jacobian: Tensor) -> Tensor: 34 | return jacobian @ jacobian.T 35 | 36 | 37 | class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): 38 | """ 39 | Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning 40 | the gramian. 
41 | """ 42 | 43 | def __init__(self, jacobian_computer: JacobianComputer): 44 | super().__init__(jacobian_computer) 45 | self.remaining_counter = 0 46 | self.summed_jacobian: Optional[Tensor] = None 47 | 48 | def reset(self) -> None: 49 | self.remaining_counter = 0 50 | self.summed_jacobian = None 51 | 52 | def track_forward_call(self) -> None: 53 | self.remaining_counter += 1 54 | 55 | def __call__( 56 | self, 57 | rg_outputs: tuple[Tensor, ...], 58 | grad_outputs: tuple[Tensor, ...], 59 | args: tuple[PyTree, ...], 60 | kwargs: dict[str, PyTree], 61 | ) -> Optional[Tensor]: 62 | """Compute what we can for a module and optionally return the gramian if it's ready.""" 63 | 64 | jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs) 65 | 66 | if self.summed_jacobian is None: 67 | self.summed_jacobian = jacobian_matrix 68 | else: 69 | self.summed_jacobian += jacobian_matrix 70 | 71 | self.remaining_counter -= 1 72 | 73 | if self.remaining_counter == 0: 74 | gramian = self._to_gramian(self.summed_jacobian) 75 | del self.summed_jacobian 76 | return gramian 77 | else: 78 | return None 79 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_upgrad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytest import mark 3 | from torch import Tensor 4 | from utils.tensors import ones_ 5 | 6 | from torchjd.aggregation import UPGrad 7 | 8 | from ._asserts import ( 9 | assert_expected_structure, 10 | assert_linear_under_scaling, 11 | assert_non_conflicting, 12 | assert_non_differentiable, 13 | assert_permutation_invariant, 14 | assert_strongly_stationary, 15 | ) 16 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices 17 | 18 | scaled_pairs = [(UPGrad(), matrix) for matrix in scaled_matrices] 19 | typical_pairs = [(UPGrad(), matrix) for matrix in typical_matrices] 20 | non_strong_pairs = [(UPGrad(), matrix) for matrix in non_strong_matrices] 21 | requires_grad_pairs = [(UPGrad(), ones_(3, 5, requires_grad=True))] 22 | 23 | 24 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 25 | def test_expected_structure(aggregator: UPGrad, matrix: Tensor): 26 | assert_expected_structure(aggregator, matrix) 27 | 28 | 29 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 30 | def test_non_conflicting(aggregator: UPGrad, matrix: Tensor): 31 | assert_non_conflicting(aggregator, matrix, atol=3e-04, rtol=3e-04) 32 | 33 | 34 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 35 | def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor): 36 | assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=4e-07, rtol=4e-07) 37 | 38 | 39 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 40 | def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor): 41 | assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=3e-02, rtol=3e-02) 42 | 43 | 44 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) 45 | def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor): 46 | assert_strongly_stationary(aggregator, matrix, threshold=5e-03) 47 | 48 | 49 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) 50 | def test_non_differentiable(aggregator: UPGrad, matrix: Tensor): 51 | assert_non_differentiable(aggregator, matrix) 52 | 53 | 54 | def test_representations(): 55 | A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") 56 | assert 
repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" 57 | assert str(A) == "UPGrad" 58 | 59 | A = UPGrad( 60 | pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), 61 | norm_eps=0.0001, 62 | reg_eps=0.0001, 63 | solver="quadprog", 64 | ) 65 | assert ( 66 | repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " 67 | "solver='quadprog')" 68 | ) 69 | assert str(A) == "UPGrad([1., 2., 3.])" 70 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_aggregator_bases.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from torch import Tensor, nn 4 | 5 | from ._utils.gramian import compute_gramian 6 | from ._weighting_bases import Matrix, PSDMatrix, Weighting 7 | 8 | 9 | class Aggregator(nn.Module, ABC): 10 | r""" 11 | Abstract base class for all aggregators. It has the role of aggregating matrices of dimension 12 | :math:`m \times n` into row vectors of dimension :math:`n`. 13 | """ 14 | 15 | def __init__(self): 16 | super().__init__() 17 | 18 | @staticmethod 19 | def _check_is_matrix(matrix: Tensor) -> None: 20 | if len(matrix.shape) != 2: 21 | raise ValueError( 22 | "Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = " 23 | f"{matrix.shape}`." 24 | ) 25 | 26 | @abstractmethod 27 | def forward(self, matrix: Tensor) -> Tensor: 28 | """Computes the aggregation from the input matrix.""" 29 | 30 | # Override to make type hints and documentation more specific 31 | def __call__(self, matrix: Tensor) -> Tensor: 32 | """Computes the aggregation from the input matrix and applies all registered hooks.""" 33 | 34 | return super().__call__(matrix) 35 | 36 | def __repr__(self) -> str: 37 | return f"{self.__class__.__name__}()" 38 | 39 | def __str__(self) -> str: 40 | return f"{self.__class__.__name__}" 41 | 42 | 43 | class WeightedAggregator(Aggregator): 44 | """ 45 | Aggregator that combines the rows of the input jacobian matrix with weights given by applying a 46 | Weighting to it. 47 | 48 | :param weighting: The object responsible for extracting the vector of weights from the matrix. 49 | """ 50 | 51 | def __init__(self, weighting: Weighting[Matrix]): 52 | super().__init__() 53 | self.weighting = weighting 54 | 55 | @staticmethod 56 | def combine(matrix: Tensor, weights: Tensor) -> Tensor: 57 | """ 58 | Aggregates a matrix by making a linear combination of its rows, using the provided vector of 59 | weights. 60 | """ 61 | 62 | vector = weights @ matrix 63 | return vector 64 | 65 | def forward(self, matrix: Tensor) -> Tensor: 66 | self._check_is_matrix(matrix) 67 | weights = self.weighting(matrix) 68 | vector = self.combine(matrix, weights) 69 | return vector 70 | 71 | 72 | class GramianWeightedAggregator(WeightedAggregator): 73 | """ 74 | WeightedAggregator that computes the gramian of the input jacobian matrix before applying a 75 | Weighting to it. 76 | 77 | :param weighting: The object responsible for extracting the vector of weights from the gramian. 
78 | """ 79 | 80 | def __init__(self, weighting: Weighting[PSDMatrix]): 81 | super().__init__(weighting << compute_gramian) 82 | -------------------------------------------------------------------------------- /.github/workflows/build-deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy Documentation 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | tags: 7 | - 'v[0-9]*.[0-9]*.[0-9]*' 8 | 9 | env: 10 | UV_NO_SYNC: 1 11 | PYTHON_VERSION: '3.14' 12 | 13 | jobs: 14 | build-deploy-doc: 15 | name: Build & deploy doc 16 | environment: prod-documentation 17 | runs-on: ubuntu-latest 18 | permissions: 19 | contents: write 20 | steps: 21 | - name: Checkout repository 22 | uses: actions/checkout@v4 23 | 24 | - name: Set up uv 25 | uses: astral-sh/setup-uv@v5 26 | with: 27 | python-version: ${{ env.PYTHON_VERSION }} 28 | 29 | - name: Install dependencies (default with full options & doc) 30 | run: uv pip install --python-version=${{ env.PYTHON_VERSION }} -e '.[full]' --group doc 31 | 32 | - name: Determine deployment folder 33 | id: deploy_folder 34 | run: | 35 | echo "Determining deployment folder..." 36 | if [[ "${{ github.ref }}" == refs/tags/* ]]; then 37 | echo "Deploying to target ${{ github.ref_name }}" 38 | echo "DEPLOY_DIR=${{ github.ref_name }}" >> $GITHUB_OUTPUT 39 | echo "TORCHJD_VERSION=${{ github.ref_name }}" >> $GITHUB_OUTPUT 40 | else 41 | echo "Deploying to target latest" 42 | echo "DEPLOY_DIR=latest" >> $GITHUB_OUTPUT 43 | echo "TORCHJD_VERSION=main" >> $GITHUB_OUTPUT 44 | fi 45 | 46 | - name: Build Documentation 47 | working-directory: docs 48 | run: uv run make dirhtml 49 | env: 50 | TORCHJD_VERSION: ${{ steps.deploy_folder.outputs.TORCHJD_VERSION }} 51 | 52 | - name: Deploy to DEPLOY_DIR of TorchJD/documentation 53 | uses: peaceiris/actions-gh-pages@v4 54 | with: 55 | deploy_key: ${{ secrets.PROD_DOCUMENTATION_DEPLOY_KEY }} 56 | publish_dir: docs/build/dirhtml 57 | destination_dir: ${{ steps.deploy_folder.outputs.DEPLOY_DIR }} 58 | external_repository: TorchJD/documentation 59 | publish_branch: main 60 | 61 | - name: Kill ssh-agent 62 | # See: https://github.com/peaceiris/actions-gh-pages/issues/909 63 | run: killall ssh-agent 64 | 65 | - name: Deploy to stable of TorchJD/documentation 66 | if: startsWith(github.ref, 'refs/tags/') 67 | uses: peaceiris/actions-gh-pages@v4 68 | with: 69 | deploy_key: ${{ secrets.PROD_DOCUMENTATION_DEPLOY_KEY }} 70 | publish_dir: docs/build/dirhtml 71 | destination_dir: stable 72 | external_repository: TorchJD/documentation 73 | publish_branch: main 74 | 75 | - name: Add documentation link to summary 76 | run: | 77 | echo "### 📄 [View Deployed Documentation](https://torchjd.github.io/documentation/${{ steps.deploy_folder.outputs.DEPLOY_DIR }})" >> $GITHUB_STEP_SUMMARY 78 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_diagonalize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ._base import RequirementError, TensorDict, Transform 5 | from ._ordered_set import OrderedSet 6 | 7 | 8 | class Diagonalize(Transform): 9 | """ 10 | Transform diagonalizing Gradients into Jacobians. 11 | 12 | The first dimension of the returned Jacobians will be equal to the total number of elements in 13 | the tensors of the input tensor dict. The exact behavior of the diagonalization is best 14 | explained by some examples. 
15 | 16 | Example 1: 17 | The input is one tensor of shape [3] and of value [1 2 3]. 18 | The output Jacobian will be: 19 | [[1 0 0] 20 | [0 2 0] 21 | [0 0 3]] 22 | 23 | Example 2: 24 | The input is one tensor of shape [2, 2] and of value [[4 5] [6 7]]. 25 | The output Jacobian will be: 26 | [[[4 0] [0 0]] 27 | [[0 5] [0 0]] 28 | [[0 0] [6 0]] 29 | [[0 0] [0 7]]] 30 | 31 | Example 3: 32 | The input is two tensors, of shapes [3] and [2, 2] and of values [1 2 3] and [[4 5] [6 7]]. 33 | If the key_order has the tensor of shape [3] appear first and the one of shape [2, 2] appear 34 | second, the output Jacobians will be: 35 | [[1 0 0] 36 | [0 2 0] 37 | [0 0 3] 38 | [0 0 0] 39 | [0 0 0] 40 | [0 0 0] 41 | [0 0 0]] and 42 | [[[0 0] [0 0]] 43 | [[0 0] [0 0]] 44 | [[0 0] [0 0]] 45 | [[4 0] [0 0]] 46 | [[0 5] [0 0]] 47 | [[0 0] [6 0]] 48 | [[0 0] [0 7]]] 49 | 50 | :param key_order: The order in which the keys are represented in the rows of the output 51 | Jacobians. 52 | """ 53 | 54 | def __init__(self, key_order: OrderedSet[Tensor]): 55 | self.key_order = key_order 56 | self.indices: list[tuple[int, int]] = [] 57 | begin = 0 58 | for tensor in self.key_order: 59 | end = begin + tensor.numel() 60 | self.indices.append((begin, end)) 61 | begin = end 62 | 63 | def __call__(self, tensors: TensorDict) -> TensorDict: 64 | flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order] 65 | diagonal_matrix = torch.cat(flattened_considered_values).diag() 66 | diagonalized_tensors = { 67 | key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape) 68 | for (begin, end), key in zip(self.indices, self.key_order) 69 | } 70 | return diagonalized_tensors 71 | 72 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: 73 | if not set(self.key_order) == input_keys: 74 | raise RequirementError( 75 | f"The input_keys must match the key_order. Found input_keys {input_keys} and" 76 | f"key_order {self.key_order}." 
77 | ) 78 | return input_keys 79 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_graddrop.py: -------------------------------------------------------------------------------- 1 | import re 2 | from contextlib import nullcontext as does_not_raise 3 | 4 | import torch 5 | from pytest import mark, raises 6 | from torch import Tensor 7 | from utils.contexts import ExceptionContext 8 | from utils.tensors import ones_ 9 | 10 | from torchjd.aggregation import GradDrop 11 | 12 | from ._asserts import assert_expected_structure, assert_non_differentiable 13 | from ._inputs import scaled_matrices, typical_matrices 14 | 15 | scaled_pairs = [(GradDrop(), matrix) for matrix in scaled_matrices] 16 | typical_pairs = [(GradDrop(), matrix) for matrix in typical_matrices] 17 | requires_grad_pairs = [(GradDrop(), ones_(3, 5, requires_grad=True))] 18 | 19 | 20 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 21 | def test_expected_structure(aggregator: GradDrop, matrix: Tensor): 22 | assert_expected_structure(aggregator, matrix) 23 | 24 | 25 | @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) 26 | def test_non_differentiable(aggregator: GradDrop, matrix: Tensor): 27 | assert_non_differentiable(aggregator, matrix) 28 | 29 | 30 | @mark.parametrize( 31 | ["leak_shape", "expectation"], 32 | [ 33 | ([], raises(ValueError)), 34 | ([0], does_not_raise()), 35 | ([1], does_not_raise()), 36 | ([10], does_not_raise()), 37 | ([0, 0], raises(ValueError)), 38 | ([0, 1], raises(ValueError)), 39 | ([1, 1], raises(ValueError)), 40 | ([1, 1, 1], raises(ValueError)), 41 | ([1, 1, 1, 1], raises(ValueError)), 42 | ([1, 1, 1, 1, 1], raises(ValueError)), 43 | ], 44 | ) 45 | def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext): 46 | leak = ones_(leak_shape) 47 | with expectation: 48 | _ = GradDrop(leak=leak) 49 | 50 | 51 | @mark.parametrize( 52 | ["leak_shape", "n_rows", "expectation"], 53 | [ 54 | ([0], 0, does_not_raise()), 55 | ([1], 1, does_not_raise()), 56 | ([5], 5, does_not_raise()), 57 | ([0], 1, raises(ValueError)), 58 | ([1], 0, raises(ValueError)), 59 | ([4], 5, raises(ValueError)), 60 | ([5], 4, raises(ValueError)), 61 | ], 62 | ) 63 | def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: ExceptionContext): 64 | matrix = ones_([n_rows, 5]) 65 | leak = ones_(leak_shape) 66 | aggregator = GradDrop(leak=leak) 67 | 68 | with expectation: 69 | _ = aggregator(matrix) 70 | 71 | 72 | def test_representations(): 73 | A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu")) 74 | assert re.match( 75 | r"GradDrop\(f=<function _identity at 0x.*>, leak=tensor\(\[0\., 1\.\]\)\)", 76 | repr(A), 77 | ) 78 | 79 | assert str(A) == "GradDrop([0., 1.])" 80 | 81 | A = GradDrop() 82 | assert re.match(r"GradDrop\(f=<function _identity at 0x.*>, leak=None\)", repr(A)) 83 | assert str(A) == "GradDrop" 84 | -------------------------------------------------------------------------------- /docs/source/examples/basic_usage.rst: -------------------------------------------------------------------------------- 1 | Basic Usage 2 | =========== 3 | 4 | This example shows how to use TorchJD to perform an iteration of Jacobian descent on a regression 5 | model with two objectives. In this example, a batch of inputs is forwarded through the model and two 6 | corresponding batches of labels are used to compute two losses. These losses are then backwarded 7 | through the model. 
The obtained Jacobian matrix, consisting of the gradients of the two losses with 8 | respect to the parameters, is then aggregated using :doc:`UPGrad <../docs/aggregation/upgrad>`, and 9 | the parameters are updated using the resulting aggregation. 10 | 11 | 12 | 13 | Import several classes from ``torch`` and ``torchjd``: 14 | 15 | .. code-block:: python 16 | 17 | import torch 18 | from torch.nn import Linear, MSELoss, ReLU, Sequential 19 | from torch.optim import SGD 20 | 21 | from torchjd import autojac 22 | from torchjd.aggregation import UPGrad 23 | 24 | Define the model and the optimizer, as usual: 25 | 26 | .. code-block:: python 27 | 28 | model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2)) 29 | optimizer = SGD(model.parameters(), lr=0.1) 30 | 31 | Define the aggregator that will be used to combine the Jacobian matrix: 32 | 33 | .. code-block:: python 34 | 35 | aggregator = UPGrad() 36 | 37 | In essence, :doc:`UPGrad <../docs/aggregation/upgrad>` projects each gradient onto the dual cone of 38 | the rows of the Jacobian and averages the results. This ensures that locally, no loss will be 39 | negatively affected by the update. 40 | 41 | Now that everything is defined, we can train the model. Define the input and the associated target: 42 | 43 | .. code-block:: python 44 | 45 | input = torch.randn(16, 10) # Batch of 16 random input vectors of length 10 46 | target1 = torch.randn(16) # First batch of 16 targets 47 | target2 = torch.randn(16) # Second batch of 16 targets 48 | 49 | Here, we generate fake inputs and labels for the sake of the example. 50 | 51 | We can now compute the losses associated to each element of the batch. 52 | 53 | .. code-block:: python 54 | 55 | loss_fn = MSELoss() 56 | output = model(input) 57 | loss1 = loss_fn(output[:, 0], target1) 58 | loss2 = loss_fn(output[:, 1], target2) 59 | 60 | The last steps are similar to gradient descent-based optimization, but using the two losses. 61 | 62 | Reset the ``.grad`` field of each model parameter: 63 | 64 | .. code-block:: python 65 | 66 | optimizer.zero_grad() 67 | 68 | Perform the Jacobian descent backward pass: 69 | 70 | .. code-block:: python 71 | 72 | autojac.backward([loss1, loss2], aggregator) 73 | 74 | This will populate the ``.grad`` field of each model parameter with the corresponding aggregated 75 | Jacobian matrix. 76 | 77 | Update each parameter based on its ``.grad`` field, using the ``optimizer``: 78 | 79 | .. code-block:: python 80 | 81 | optimizer.step() 82 | 83 | The model's parameters have been updated! 84 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_weighting_bases.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from collections.abc import Callable 5 | from typing import Annotated, Generic, TypeVar 6 | 7 | from torch import Tensor, nn 8 | 9 | _T = TypeVar("_T", contravariant=True) 10 | _FnInputT = TypeVar("_FnInputT") 11 | _FnOutputT = TypeVar("_FnOutputT") 12 | Matrix = Annotated[Tensor, "ndim=2"] 13 | PSDMatrix = Annotated[Matrix, "Positive semi-definite"] 14 | 15 | 16 | class Weighting(Generic[_T], nn.Module, ABC): 17 | r""" 18 | Abstract base class for all weighting methods. It has the role of extracting a vector of weights 19 | of dimension :math:`m` from some statistic of a matrix of dimension :math:`m \times n`, 20 | generally its Gramian, of dimension :math:`m \times m`. 
21 | """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | @abstractmethod 27 | def forward(self, stat: _T) -> Tensor: 28 | """Computes the vector of weights from the input stat.""" 29 | 30 | # Override to make type hints and documentation more specific 31 | def __call__(self, stat: _T) -> Tensor: 32 | """Computes the vector of weights from the input stat and applies all registered hooks.""" 33 | 34 | return super().__call__(stat) 35 | 36 | def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]: 37 | return _Composition(self, fn) 38 | 39 | __lshift__ = _compose 40 | 41 | 42 | class _Composition(Weighting[_T]): 43 | """ 44 | Weighting that composes a Weighting with a function, so that the Weighting is applied to the 45 | output of the function. 46 | """ 47 | 48 | def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]): 49 | super().__init__() 50 | self.fn = fn 51 | self.weighting = weighting 52 | 53 | def forward(self, stat: _T) -> Tensor: 54 | return self.weighting(self.fn(stat)) 55 | 56 | 57 | class GeneralizedWeighting(nn.Module, ABC): 58 | r""" 59 | Abstract base class for all weightings that operate on generalized Gramians. It has the role of 60 | extracting a tensor of weights of dimension :math:`m_1 \times \dots \times m_k` from a 61 | generalized Gramian of dimension 62 | :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`. 63 | """ 64 | 65 | def __init__(self): 66 | super().__init__() 67 | 68 | @abstractmethod 69 | def forward(self, generalized_gramian: Tensor) -> Tensor: 70 | """Computes the vector of weights from the input generalized Gramian.""" 71 | 72 | # Override to make type hints and documentation more specific 73 | def __call__(self, generalized_gramian: Tensor) -> Tensor: 74 | """ 75 | Computes the tensor of weights from the input generalized Gramian and applies all registered 76 | hooks. 77 | """ 78 | 79 | return super().__call__(generalized_gramian) 80 | -------------------------------------------------------------------------------- /docs/source/examples/mtl.rst: -------------------------------------------------------------------------------- 1 | Multi-Task Learning (MTL) 2 | ========================= 3 | 4 | In the context of multi-task learning, multiple tasks are performed simultaneously on a common 5 | input. Typically, a feature extractor is applied to the input to obtain a shared representation, 6 | useful for all tasks. Then, task-specific heads are applied to these features to obtain each task's 7 | result. A loss can then be computed for each task. Fundamentally, multi-task learning is a 8 | multi-objective optimization problem in which we minimize the vector of task losses. 9 | 10 | A common trick to train multi-task models is to cast the problem as single-objective, by minimizing 11 | a weighted sum of the losses. This works well in some cases, but sometimes conflict among tasks can 12 | make the optimization of the shared parameters very hard. Besides, the weight associated to each 13 | loss can be considered as a hyper-parameter. Finding their optimal value is generally expensive. 14 | 15 | Alternatively, the vector of losses can be directly minimized using Jacobian descent. The following 16 | example shows how to use TorchJD to train a very simple multi-task model with two regression tasks. 17 | For the sake of the example, we generate a fake dataset consisting of 8 batches of 16 random input 18 | vectors of dimension 10, and their corresponding scalar labels for both tasks. 
19 | 20 | 21 | .. code-block:: python 22 | :emphasize-lines: 5-6, 19, 33 23 | 24 | import torch 25 | from torch.nn import Linear, MSELoss, ReLU, Sequential 26 | from torch.optim import SGD 27 | 28 | from torchjd.aggregation import UPGrad 29 | from torchjd.autojac import mtl_backward 30 | 31 | shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) 32 | task1_module = Linear(3, 1) 33 | task2_module = Linear(3, 1) 34 | params = [ 35 | *shared_module.parameters(), 36 | *task1_module.parameters(), 37 | *task2_module.parameters(), 38 | ] 39 | 40 | loss_fn = MSELoss() 41 | optimizer = SGD(params, lr=0.1) 42 | aggregator = UPGrad() 43 | 44 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 45 | task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task 46 | task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task 47 | 48 | for input, target1, target2 in zip(inputs, task1_targets, task2_targets): 49 | features = shared_module(input) 50 | output1 = task1_module(features) 51 | output2 = task2_module(features) 52 | loss1 = loss_fn(output1, target1) 53 | loss2 = loss_fn(output2, target2) 54 | 55 | optimizer.zero_grad() 56 | mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) 57 | optimizer.step() 58 | 59 | .. note:: 60 | In this example, the Jacobian is only with respect to the shared parameters. The task-specific 61 | parameters are simply updated via the gradient of their task's loss with respect to them. 62 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_mgda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ._aggregator_bases import GramianWeightedAggregator 5 | from ._weighting_bases import PSDMatrix, Weighting 6 | 7 | 8 | class MGDA(GramianWeightedAggregator): 9 | r""" 10 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation 11 | step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization 12 | `_. The implementation is 13 | based on Algorithm 2 of `Multi-Task Learning as Multi-Objective Optimization 14 | `_. 15 | 16 | :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization. 17 | :param max_iters: The maximum number of iterations of the optimization loop. 18 | """ 19 | 20 | def __init__(self, epsilon: float = 0.001, max_iters: int = 100): 21 | super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters)) 22 | self._epsilon = epsilon 23 | self._max_iters = max_iters 24 | 25 | def __repr__(self) -> str: 26 | return f"{self.__class__.__name__}(epsilon={self._epsilon}, max_iters={self._max_iters})" 27 | 28 | 29 | class MGDAWeighting(Weighting[PSDMatrix]): 30 | r""" 31 | :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of 32 | :class:`~torchjd.aggregation.MGDA`. 33 | 34 | :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization. 35 | :param max_iters: The maximum number of iterations of the optimization loop. 36 | """ 37 | 38 | def __init__(self, epsilon: float = 0.001, max_iters: int = 100): 39 | super().__init__() 40 | self.epsilon = epsilon 41 | self.max_iters = max_iters 42 | 43 | def forward(self, gramian: Tensor) -> Tensor: 44 | """ 45 | This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective 46 | Optimization 47 | `_. 
48 | """ 49 | device = gramian.device 50 | dtype = gramian.dtype 51 | 52 | alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0] 53 | for i in range(self.max_iters): 54 | t = torch.argmin(gramian @ alpha) 55 | e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype) 56 | e_t[t] = 1.0 57 | a = alpha @ (gramian @ e_t) 58 | b = alpha @ (gramian @ alpha) 59 | c = e_t @ (gramian @ e_t) 60 | if c <= a: 61 | gamma = 1.0 62 | elif b <= a: 63 | gamma = 0.0 64 | else: 65 | gamma = (b - a) / (b + c - 2 * a) # type: ignore[assignment] 66 | alpha = (1 - gamma) * alpha + gamma * e_t 67 | if gamma < self.epsilon: 68 | break 69 | return alpha 70 | -------------------------------------------------------------------------------- /docs/source/examples/iwmtl.rst: -------------------------------------------------------------------------------- 1 | Instance-Wise Multi-Task Learning (IWMTL) 2 | ========================================= 3 | 4 | When training a model with multiple tasks, the gradients of the individual tasks are likely to 5 | conflict. This is particularly true when looking at the individual (per-sample) gradients. 6 | The :doc:`autogram engine <../docs/autogram/engine>` can be used to efficiently compute the Gramian 7 | of the Jacobian of the matrix of per-sample and per-task losses. Weights can then be extracted from 8 | this Gramian to reweight the gradients and resolve conflict entirely. 9 | 10 | The following example shows how to do that. 11 | 12 | .. code-block:: python 13 | :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42 14 | 15 | import torch 16 | from torch.nn import Linear, MSELoss, ReLU, Sequential 17 | from torch.optim import SGD 18 | 19 | from torchjd.aggregation import Flattening, UPGradWeighting 20 | from torchjd.autogram import Engine 21 | 22 | shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) 23 | task1_module = Linear(3, 1) 24 | task2_module = Linear(3, 1) 25 | params = [ 26 | *shared_module.parameters(), 27 | *task1_module.parameters(), 28 | *task2_module.parameters(), 29 | ] 30 | 31 | optimizer = SGD(params, lr=0.1) 32 | mse = MSELoss(reduction="none") 33 | weighting = Flattening(UPGradWeighting()) 34 | engine = Engine(shared_module, batch_dim=0) 35 | 36 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 37 | task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task 38 | task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task 39 | 40 | for input, target1, target2 in zip(inputs, task1_targets, task2_targets): 41 | features = shared_module(input) # shape: [16, 3] 42 | out1 = task1_module(features).squeeze(1) # shape: [16] 43 | out2 = task2_module(features).squeeze(1) # shape: [16] 44 | 45 | # Compute the matrix of losses: one loss per element of the batch and per task 46 | losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2] 47 | 48 | # Compute the gramian (inner products between pairs of gradients of the losses) 49 | gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16] 50 | 51 | # Obtain the weights that lead to no conflict between reweighted gradients 52 | weights = weighting(gramian) # shape: [16, 2] 53 | 54 | optimizer.zero_grad() 55 | # Do the standard backward pass, but weighted using the obtained weights 56 | losses.backward(weights) 57 | optimizer.step() 58 | 59 | .. note:: 60 | In this example, the tensor of losses is a matrix rather than a vector. 
The gramian is thus a 61 | 4D tensor rather than a matrix, and a 62 | :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as 63 | :class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of 64 | weights from it. More information about ``GeneralizedWeighting`` can be found in the 65 | :doc:`../../docs/aggregation/index` page. 66 | -------------------------------------------------------------------------------- /docs/source/examples/amp.rst: -------------------------------------------------------------------------------- 1 | Automatic Mixed Precision (AMP) 2 | =============================== 3 | 4 | In some cases, to save memory and reduce computation time, you may want to use `automatic mixed 5 | precision `_. Since the 6 | `torch.amp.GradScaler `_ class already 7 | works on multiple losses, it's pretty straightforward to combine TorchJD and AMP. As usual, the 8 | forward pass should be wrapped within a `torch.autocast 9 | `_ context, and as usual, the loss (in our 10 | case, the losses) should preferably be scaled with a `GradScaler 11 | `_ to avoid gradient underflow. The 12 | following example shows the resulting code for a multi-task learning use-case. 13 | 14 | .. code-block:: python 15 | :emphasize-lines: 2, 17, 27, 34, 36-38 16 | 17 | import torch 18 | from torch.amp import GradScaler 19 | from torch.nn import Linear, MSELoss, ReLU, Sequential 20 | from torch.optim import SGD 21 | 22 | from torchjd.aggregation import UPGrad 23 | from torchjd.autojac import mtl_backward 24 | 25 | shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) 26 | task1_module = Linear(3, 1) 27 | task2_module = Linear(3, 1) 28 | params = [ 29 | *shared_module.parameters(), 30 | *task1_module.parameters(), 31 | *task2_module.parameters(), 32 | ] 33 | scaler = GradScaler(device="cpu") 34 | loss_fn = MSELoss() 35 | optimizer = SGD(params, lr=0.1) 36 | aggregator = UPGrad() 37 | 38 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 39 | task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task 40 | task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task 41 | 42 | for input, target1, target2 in zip(inputs, task1_targets, task2_targets): 43 | with torch.autocast(device_type="cpu", dtype=torch.float16): 44 | features = shared_module(input) 45 | output1 = task1_module(features) 46 | output2 = task2_module(features) 47 | loss1 = loss_fn(output1, target1) 48 | loss2 = loss_fn(output2, target2) 49 | 50 | scaled_losses = scaler.scale([loss1, loss2]) 51 | optimizer.zero_grad() 52 | mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator) 53 | scaler.step(optimizer) 54 | scaler.update() 55 | 56 | .. hint:: 57 | Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For 58 | those operations, the tensors saved for the backward pass will also be of ``float16`` type. 59 | However, the Jacobian computed by ``mtl_backward`` will be of type ``float32``, so the ``.grad`` 60 | fields of the model parameters will also be of type ``float32``. This is in line with the 61 | behavior of PyTorch, that would also compute all gradients in ``float32`` type. 62 | 63 | .. note:: 64 | :doc:`torchjd.backward <../docs/autojac/backward>` can be similarly combined with AMP. 
65 | -------------------------------------------------------------------------------- /src/torchjd/autogram/_gramian_utils.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | def reshape_gramian(gramian: Tensor, half_shape: list[int]) -> Tensor: 5 | """ 6 | Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions 7 | must be done from the left, while the reshape of the second half must be done from the right. 8 | 9 | :param gramian: Gramian to reshape. Can be a generalized Gramian. 10 | :param half_shape: First half of the target shape; the shape of the output is therefore 11 | `half_shape + half_shape[::-1]`. 12 | """ 13 | 14 | # Example 1: `gramian` of shape [4, 3, 2, 2, 3, 4] and `half_shape` of [8, 3]: 15 | # [4, 3, 2, 2, 3, 4] -(movedim)-> [4, 3, 2, 4, 3, 2] -(reshape)-> [8, 3, 8, 3] -(movedim)-> 16 | # [8, 3, 3, 8] 17 | # 18 | # Example 2: `gramian` of shape [24, 24] and `half_shape` of [4, 3, 2]: 19 | # [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4] 20 | 21 | return _revert_last_dims(_revert_last_dims(gramian).reshape(half_shape + half_shape)) 22 | 23 | 24 | def _revert_last_dims(gramian: Tensor) -> Tensor: 25 | """Inverts the order of the last half of the dimensions of the input generalized Gramian.""" 26 | 27 | half_ndim = gramian.ndim // 2 28 | last_dims = [half_ndim + i for i in range(half_ndim)] 29 | return gramian.movedim(last_dims, last_dims[::-1]) 30 | 31 | 32 | def movedim_gramian(gramian: Tensor, half_source: list[int], half_destination: list[int]) -> Tensor: 33 | """ 34 | Moves the dimensions of a Gramian from some source dimensions to destination dimensions. This 35 | must be done simultaneously on the first half of the dimensions and on the second half of the 36 | dimensions reversed. 37 | 38 | :param gramian: Gramian whose dimensions should be moved. Can be a generalized Gramian. 39 | :param half_source: Source dimensions, that should be in the range [-gramian.ndim//2, 40 | gramian.ndim//2[. Its elements should be unique. 41 | :param half_destination: Destination dimensions, that should be in the range 42 | [-gramian.ndim//2, gramian.ndim//2[. It should have the same size as `half_source`, and its 43 | elements should be unique. 
44 | """ 45 | 46 | # Example: `gramian` of shape [4, 3, 2, 2, 3, 4], `half_source` of [-2, 2] and 47 | # `half_destination` of [0, 1]: 48 | # - `half_source_` will be [1, 2] and `half_destination_` will be [0, 1] 49 | # - `source` will be [1, 2, 4, 3] and `destination` will be [0, 1, 5, 4] 50 | # - The `moved_gramian` will be of shape [3, 2, 4, 4, 2, 3] 51 | 52 | # Map everything to the range [0, gramian.ndim//2[ 53 | half_ndim = gramian.ndim // 2 54 | half_source_ = [i if 0 <= i else i + half_ndim for i in half_source] 55 | half_destination_ = [i if 0 <= i else i + half_ndim for i in half_destination] 56 | 57 | # Mirror the half source and the half destination and use the result to move the dimensions of 58 | # the gramian 59 | last_dim = gramian.ndim - 1 60 | source = half_source_ + [last_dim - i for i in half_source_] 61 | destination = half_destination_ + [last_dim - i for i in half_destination_] 62 | moved_gramian = gramian.movedim(source, destination) 63 | return moved_gramian 64 | -------------------------------------------------------------------------------- /tests/unit/aggregation/test_constant.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as does_not_raise 2 | 3 | import torch 4 | from pytest import mark, raises 5 | from torch import Tensor 6 | from utils.contexts import ExceptionContext 7 | from utils.tensors import ones_, tensor_ 8 | 9 | from torchjd.aggregation import Constant 10 | 11 | from ._asserts import ( 12 | assert_expected_structure, 13 | assert_linear_under_scaling, 14 | assert_strongly_stationary, 15 | ) 16 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices 17 | 18 | 19 | def _make_aggregator(matrix: Tensor) -> Constant: 20 | n_rows = matrix.shape[0] 21 | weights = tensor_([1.0 / n_rows] * n_rows, dtype=matrix.dtype) 22 | return Constant(weights) 23 | 24 | 25 | scaled_pairs = [(_make_aggregator(matrix), matrix) for matrix in scaled_matrices] 26 | typical_pairs = [(_make_aggregator(matrix), matrix) for matrix in typical_matrices] 27 | non_strong_pairs = [(_make_aggregator(matrix), matrix) for matrix in non_strong_matrices] 28 | 29 | 30 | @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) 31 | def test_expected_structure(aggregator: Constant, matrix: Tensor): 32 | assert_expected_structure(aggregator, matrix) 33 | 34 | 35 | @mark.parametrize(["aggregator", "matrix"], typical_pairs) 36 | def test_linear_under_scaling(aggregator: Constant, matrix: Tensor): 37 | assert_linear_under_scaling(aggregator, matrix) 38 | 39 | 40 | @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) 41 | def test_strongly_stationary(aggregator: Constant, matrix: Tensor): 42 | assert_strongly_stationary(aggregator, matrix) 43 | 44 | 45 | @mark.parametrize( 46 | ["weights_shape", "expectation"], 47 | [ 48 | ([], raises(ValueError)), 49 | ([0], does_not_raise()), 50 | ([1], does_not_raise()), 51 | ([10], does_not_raise()), 52 | ([0, 0], raises(ValueError)), 53 | ([0, 1], raises(ValueError)), 54 | ([1, 1], raises(ValueError)), 55 | ([1, 1, 1], raises(ValueError)), 56 | ([1, 1, 1, 1], raises(ValueError)), 57 | ([1, 1, 1, 1, 1], raises(ValueError)), 58 | ], 59 | ) 60 | def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext): 61 | weights = ones_(weights_shape) 62 | with expectation: 63 | _ = Constant(weights=weights) 64 | 65 | 66 | @mark.parametrize( 67 | ["weights_shape", "n_rows", "expectation"], 68 | [ 69 | ([0], 0, 
does_not_raise()), 70 | ([1], 1, does_not_raise()), 71 | ([5], 5, does_not_raise()), 72 | ([0], 1, raises(ValueError)), 73 | ([1], 0, raises(ValueError)), 74 | ([4], 5, raises(ValueError)), 75 | ([5], 4, raises(ValueError)), 76 | ], 77 | ) 78 | def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: ExceptionContext): 79 | matrix = ones_([n_rows, 5]) 80 | weights = ones_(weights_shape) 81 | aggregator = Constant(weights) 82 | 83 | with expectation: 84 | _ = aggregator(matrix) 85 | 86 | 87 | def test_representations(): 88 | A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) 89 | assert repr(A) == "Constant(weights=tensor([1., 2.]))" 90 | assert str(A) == "Constant([1., 2.])" 91 | -------------------------------------------------------------------------------- /docs/source/examples/monitoring.rst: -------------------------------------------------------------------------------- 1 | Monitoring aggregations 2 | ======================= 3 | 4 | The :doc:`Aggregator <../docs/aggregation/index>` class is a subclass of :class:`torch.nn.Module`. 5 | This allows registering hooks, which can be used to monitor some information about aggregations. 6 | The following code example demonstrates registering a hook to compute and print the cosine 7 | similarity between the aggregation performed by :doc:`UPGrad <../docs/aggregation/upgrad>` and the 8 | average of the gradients, and another hook to compute and print the weights of the weighting of 9 | :doc:`UPGrad <../docs/aggregation/upgrad>`. 10 | 11 | Updating the parameters of the model with the average gradient is equivalent to using gradient 12 | descent on the average of the losses. Observing a cosine similarity smaller than 1 means that 13 | Jacobian descent is doing something different than gradient descent. With 14 | :doc:`UPGrad <../docs/aggregation/upgrad>`, this happens when the original gradients conflict (i.e. 15 | they have a negative inner product). 16 | 17 | .. 
code-block:: python 18 | :emphasize-lines: 9-11, 13-18, 33-34 19 | 20 | import torch 21 | from torch.nn import Linear, MSELoss, ReLU, Sequential 22 | from torch.nn.functional import cosine_similarity 23 | from torch.optim import SGD 24 | 25 | from torchjd.aggregation import UPGrad 26 | from torchjd.autojac import mtl_backward 27 | 28 | def print_weights(_, __, weights: torch.Tensor) -> None: 29 | """Prints the extracted weights.""" 30 | print(f"Weights: {weights}") 31 | 32 | def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None: 33 | """Prints the cosine similarity between the aggregation and the average gradient.""" 34 | matrix = inputs[0] 35 | gd_output = matrix.mean(dim=0) 36 | similarity = cosine_similarity(aggregation, gd_output, dim=0) 37 | print(f"Cosine similarity: {similarity.item():.4f}") 38 | 39 | shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) 40 | task1_module = Linear(3, 1) 41 | task2_module = Linear(3, 1) 42 | params = [ 43 | *shared_module.parameters(), 44 | *task1_module.parameters(), 45 | *task2_module.parameters(), 46 | ] 47 | 48 | loss_fn = MSELoss() 49 | optimizer = SGD(params, lr=0.1) 50 | aggregator = UPGrad() 51 | 52 | aggregator.weighting.register_forward_hook(print_weights) 53 | aggregator.register_forward_hook(print_gd_similarity) 54 | 55 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 56 | task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task 57 | task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task 58 | 59 | for input, target1, target2 in zip(inputs, task1_targets, task2_targets): 60 | features = shared_module(input) 61 | output1 = task1_module(features) 62 | output2 = task2_module(features) 63 | loss1 = loss_fn(output1, target1) 64 | loss2 = loss_fn(output2, target2) 65 | 66 | optimizer.zero_grad() 67 | mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) 68 | optimizer.step() 69 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_graddrop.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from ._aggregator_bases import Aggregator 7 | from ._utils.non_differentiable import raise_non_differentiable_error 8 | 9 | 10 | def _identity(P: Tensor) -> Tensor: 11 | return P 12 | 13 | 14 | class GradDrop(Aggregator): 15 | """ 16 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination 17 | steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign: 18 | Optimizing Deep Multitask Models with Gradient Sign Dropout 19 | `_. 20 | 21 | :param f: The function to apply to the Gradient Positive Sign Purity. It should be monotonically 22 | increasing. Defaults to identity. 23 | :param leak: The tensor of leak values, determining how much each row is allowed to leak 24 | through. Defaults to None, which means no leak. 25 | """ 26 | 27 | def __init__(self, f: Callable = _identity, leak: Tensor | None = None): 28 | if leak is not None and leak.dim() != 1: 29 | raise ValueError( 30 | "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = " 31 | f"{leak.shape}`." 32 | ) 33 | 34 | super().__init__() 35 | self.f = f 36 | self.leak = leak 37 | 38 | # This prevents computing gradients that can be very wrong. 
39 | self.register_full_backward_pre_hook(raise_non_differentiable_error) 40 | 41 | def forward(self, matrix: Tensor) -> Tensor: 42 | self._check_is_matrix(matrix) 43 | self._check_matrix_has_enough_rows(matrix) 44 | 45 | if matrix.shape[0] == 0 or matrix.shape[1] == 0: 46 | return torch.zeros(matrix.shape[1], dtype=matrix.dtype, device=matrix.device) 47 | 48 | leak = self.leak if self.leak is not None else torch.zeros_like(matrix[:, 0]) 49 | 50 | P = 0.5 * (torch.ones_like(matrix[0]) + matrix.sum(dim=0) / matrix.abs().sum(dim=0)) 51 | fP = self.f(P) 52 | U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device) 53 | 54 | vector = torch.zeros_like(matrix[0]) 55 | for i in range(len(matrix)): 56 | M_i = (fP > U) * (matrix[i] > 0) + (fP < U) * (matrix[i] < 0) 57 | vector += (leak[i] + (1 - leak[i]) * M_i) * matrix[i] 58 | 59 | return vector 60 | 61 | def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: 62 | n_rows = matrix.shape[0] 63 | if self.leak is not None and n_rows != len(self.leak): 64 | raise ValueError( 65 | f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the " 66 | f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`." 67 | ) 68 | 69 | def __repr__(self) -> str: 70 | return f"{self.__class__.__name__}(f={repr(self.f)}, leak={repr(self.leak)})" 71 | 72 | def __str__(self) -> str: 73 | if self.leak is None: 74 | leak_str = "" 75 | else: 76 | leak_str = f"([{', '.join(['{:.2f}'.format(l_).rstrip('0') for l_ in self.leak])}])" 77 | return f"GradDrop{leak_str}" 78 | -------------------------------------------------------------------------------- /tests/unit/autogram/test_edge_registry.py: -------------------------------------------------------------------------------- 1 | from torch.autograd.graph import get_gradient_edge 2 | from utils.tensors import randn_ 3 | 4 | from torchjd.autogram._edge_registry import EdgeRegistry 5 | 6 | 7 | def test_all_edges_are_leaves1(): 8 | """Tests that get_leaf_edges works correctly when all edges are already leaves.""" 9 | 10 | a = randn_([3, 4], requires_grad=True) 11 | b = randn_([4], requires_grad=True) 12 | c = randn_([3], requires_grad=True) 13 | 14 | d = (a @ b) + c 15 | 16 | edge_registry = EdgeRegistry() 17 | for tensor in [a, b, c]: 18 | edge_registry.register(get_gradient_edge(tensor)) 19 | 20 | expected_leaves = {get_gradient_edge(tensor) for tensor in [a, b, c]} 21 | leaves = edge_registry.get_leaf_edges({get_gradient_edge(d)}) 22 | assert leaves == expected_leaves 23 | 24 | 25 | def test_all_edges_are_leaves2(): 26 | """ 27 | Tests that get_leaf_edges works correctly when all edges are already leaves of the graph of 28 | edges leading to them, but are not leaves of the autograd graph. 
29 | """ 30 | 31 | a = randn_([3, 4], requires_grad=True) 32 | b = randn_([4], requires_grad=True) 33 | c = randn_([4], requires_grad=True) 34 | d = randn_([4], requires_grad=True) 35 | 36 | e = a * b 37 | f = e + c 38 | g = f + d 39 | 40 | edge_registry = EdgeRegistry() 41 | for tensor in [e, g]: 42 | edge_registry.register(get_gradient_edge(tensor)) 43 | 44 | expected_leaves = {get_gradient_edge(tensor) for tensor in [e, g]} 45 | leaves = edge_registry.get_leaf_edges({get_gradient_edge(e), get_gradient_edge(g)}) 46 | assert leaves == expected_leaves 47 | 48 | 49 | def test_some_edges_are_not_leaves1(): 50 | """Tests that get_leaf_edges works correctly when some edges are leaves and some are not.""" 51 | 52 | a = randn_([3, 4], requires_grad=True) 53 | b = randn_([4], requires_grad=True) 54 | c = randn_([4], requires_grad=True) 55 | d = randn_([4], requires_grad=True) 56 | 57 | e = a * b 58 | f = e + c 59 | g = f + d 60 | 61 | edge_registry = EdgeRegistry() 62 | for tensor in [a, b, c, d, e, f, g]: 63 | edge_registry.register(get_gradient_edge(tensor)) 64 | 65 | expected_leaves = {get_gradient_edge(tensor) for tensor in [a, b, c, d]} 66 | leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)}) 67 | assert leaves == expected_leaves 68 | 69 | 70 | def test_some_edges_are_not_leaves2(): 71 | """ 72 | Tests that get_leaf_edges works correctly when some edges are leaves and some are not. This 73 | time, not all tensors in the graph are registered, so not all leaves of the graph have to be 74 | returned. 75 | """ 76 | 77 | a = randn_([3, 4], requires_grad=True) 78 | b = randn_([4], requires_grad=True) 79 | c = randn_([4], requires_grad=True) 80 | d = randn_([4], requires_grad=True) 81 | 82 | e = a * b 83 | f = e + c 84 | g = f + d 85 | 86 | edge_registry = EdgeRegistry() 87 | for tensor in [a, c, d, e, g]: 88 | edge_registry.register(get_gradient_edge(tensor)) 89 | 90 | expected_leaves = {get_gradient_edge(tensor) for tensor in [a, c, d]} 91 | leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)}) 92 | assert leaves == expected_leaves 93 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_transform/_differentiate.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Sequence 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from ._base import RequirementError, TensorDict, Transform 8 | from ._materialize import materialize 9 | from ._ordered_set import OrderedSet 10 | 11 | 12 | class Differentiate(Transform, ABC): 13 | """ 14 | Abstract base class for transforms responsible for differentiating some outputs with respect to 15 | some inputs. 16 | 17 | :param outputs: Tensors to differentiate. 18 | :param inputs: Tensors with respect to which we differentiate. 19 | :param retain_graph: If False, the graph used to compute the grads will be freed. 20 | :param create_graph: If True, graph of the derivative will be constructed, allowing to compute 21 | higher order derivative products. 22 | 23 | .. note:: The order of outputs and inputs only matters because we have no guarantee that 24 | torch.autograd.grad is *exactly* equivariant to input permutations and invariant to output 25 | (with their corresponding grad_output) permutations. 
26 | """ 27 | 28 | def __init__( 29 | self, 30 | outputs: OrderedSet[Tensor], 31 | inputs: OrderedSet[Tensor], 32 | retain_graph: bool, 33 | create_graph: bool, 34 | ): 35 | self.outputs = list(outputs) 36 | self.inputs = list(inputs) 37 | self.retain_graph = retain_graph 38 | self.create_graph = create_graph 39 | 40 | def __call__(self, tensors: TensorDict) -> TensorDict: 41 | tensor_outputs = [tensors[output] for output in self.outputs] 42 | 43 | differentiated_tuple = self._differentiate(tensor_outputs) 44 | new_differentiations = dict(zip(self.inputs, differentiated_tuple)) 45 | return type(tensors)(new_differentiations) 46 | 47 | @abstractmethod 48 | def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]: 49 | """ 50 | Abstract method for differentiating the outputs with respect to the inputs, and applying the 51 | linear transformations represented by the tensor_outputs to the results. 52 | 53 | The implementation of this method should define what kind of differentiation is performed: 54 | whether gradients, Jacobians, etc. are computed, and what the dimension of the 55 | tensor_outputs should be. 56 | """ 57 | 58 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: 59 | outputs = set(self.outputs) 60 | if not outputs == input_keys: 61 | raise RequirementError( 62 | f"The input_keys must match the expected outputs. Found input_keys {input_keys} and" 63 | f"outputs {outputs}." 64 | ) 65 | return set(self.inputs) 66 | 67 | def _get_vjp(self, grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]: 68 | optional_grads = torch.autograd.grad( 69 | self.outputs, 70 | self.inputs, 71 | grad_outputs=grad_outputs, 72 | retain_graph=retain_graph, 73 | create_graph=self.create_graph, 74 | allow_unused=True, 75 | ) 76 | grads = materialize(optional_grads, inputs=self.inputs) 77 | return grads 78 | -------------------------------------------------------------------------------- /docs/source/examples/lightning_integration.rst: -------------------------------------------------------------------------------- 1 | PyTorch Lightning Integration 2 | ============================= 3 | 4 | To use Jacobian descent with TorchJD in a :class:`~lightning.pytorch.core.LightningModule`, you need 5 | to turn off automatic optimization by setting ``automatic_optimization`` to ``False`` and to 6 | customize the ``training_step`` method to make it call the appropriate TorchJD method 7 | (:doc:`backward <../docs/autojac/backward>` or :doc:`mtl_backward <../docs/autojac/mtl_backward>`). 8 | 9 | The following code example demonstrates a basic multi-task learning setup using a 10 | :class:`~lightning.pytorch.core.LightningModule` that will call :doc:`mtl_backward 11 | <../docs/autojac/mtl_backward>` at each training iteration. 12 | 13 | .. 
code-block:: python 14 | :emphasize-lines: 9-10, 18, 32 15 | 16 | import torch 17 | from lightning import LightningModule, Trainer 18 | from lightning.pytorch.utilities.types import OptimizerLRScheduler 19 | from torch.nn import Linear, ReLU, Sequential 20 | from torch.nn.functional import mse_loss 21 | from torch.optim import Adam 22 | from torch.utils.data import DataLoader, TensorDataset 23 | 24 | from torchjd.aggregation import UPGrad 25 | from torchjd.autojac import mtl_backward 26 | 27 | class Model(LightningModule): 28 | def __init__(self): 29 | super().__init__() 30 | self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) 31 | self.task1_head = Linear(3, 1) 32 | self.task2_head = Linear(3, 1) 33 | self.automatic_optimization = False 34 | 35 | def training_step(self, batch, batch_idx) -> None: 36 | input, target1, target2 = batch 37 | 38 | features = self.feature_extractor(input) 39 | output1 = self.task1_head(features) 40 | output2 = self.task2_head(features) 41 | 42 | loss1 = mse_loss(output1, target1) 43 | loss2 = mse_loss(output2, target2) 44 | 45 | opt = self.optimizers() 46 | opt.zero_grad() 47 | mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) 48 | opt.step() 49 | 50 | def configure_optimizers(self) -> OptimizerLRScheduler: 51 | optimizer = Adam(self.parameters(), lr=1e-3) 52 | return optimizer 53 | 54 | model = Model() 55 | 56 | inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 57 | task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task 58 | task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task 59 | 60 | dataset = TensorDataset(inputs, task1_targets, task2_targets) 61 | train_loader = DataLoader(dataset) 62 | trainer = Trainer( 63 | accelerator="cpu", 64 | max_epochs=1, 65 | enable_checkpointing=False, 66 | logger=False, 67 | enable_progress_bar=False, 68 | ) 69 | 70 | trainer.fit(model=model, train_dataloaders=train_loader) 71 | 72 | .. warning:: 73 | This will not handle automatic scaling in low-precision settings. There is currently no easy 74 | fix. 75 | 76 | .. warning:: 77 | TorchJD is incompatible with compiled models, so you must ensure that your model is not 78 | compiled. 79 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | | 4 | 5 | .. image:: _static/logo-dark-mode.png 6 | :width: 400 7 | :alt: torchjd 8 | :align: center 9 | :class: only-dark, no-scaled-link 10 | 11 | .. image:: _static/logo-light-mode.png 12 | :width: 400 13 | :alt: torchjd 14 | :align: center 15 | :class: only-light, no-scaled-link 16 | 17 | | 18 | 19 | TorchJD is a library enabling Jacobian descent with PyTorch, to train neural networks with multiple 20 | objectives. It is based on the theory from `Jacobian Descent For Multi-Objective Optimization 21 | `_ and several other related publications. 22 | 23 | The main purpose is to jointly optimize multiple objectives without combining them into a single 24 | scalar loss. When the objectives are conflicting, this can be the key to a successful and stable 25 | optimization. To get started, check out our :doc:`basic usage example 26 | `. 27 | 28 | Gradient descent relies on gradients to optimize a single objective. Jacobian descent takes this 29 | idea a step further, using the Jacobian to optimize multiple objectives. 
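Concretely, if :math:`f` maps the parameters :math:`\theta` to the vector of objectives, one step of Jacobian descent can be sketched (illustrative notation only, not the exact formulation from the paper) as

.. math::
   \theta \leftarrow \theta - \eta \, \mathcal{A}\bigl(Jf(\theta)\bigr),

where :math:`Jf(\theta)` stacks the gradients of the individual objectives, :math:`\eta` is a learning rate, and :math:`\mathcal{A}` is an aggregator mapping the Jacobian to an update direction.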
An important component of 30 | Jacobian descent is the aggregator, which maps the Jacobian to an optimization step. In the page 31 | :doc:`Aggregation `, we provide an overview of the various aggregators 32 | available in TorchJD, and their corresponding weightings. 33 | 34 | A straightforward application of Jacobian descent is multi-task learning, in which the vector of 35 | per-task losses has to be minimized. To start using TorchJD for multi-task learning, follow our 36 | :doc:`MTL example `. 37 | 38 | Another more interesting application is to consider separately the loss of each element in the 39 | batch. This is what we define as :doc:`Instance-Wise Risk Minimization ` (IWRM). 40 | 41 | The Gramian-based Jacobian descent algorithm provides a very efficient alternative way of 42 | performing Jacobian descent. It consists in computing 43 | the Gramian of the Jacobian iteratively during the backward pass (without ever storing the full 44 | Jacobian in memory), weighting the losses using the information of the Gramian, and then computing 45 | the gradient of the obtained weighted loss. The iterative computation of the Gramian corresponds to 46 | Algorithm 3 of 47 | `Jacobian Descent For Multi-Objective Optimization `_. The 48 | documentation and usage example of this algorithm are provided in 49 | :doc:`autogram.Engine `. 50 | 51 | The original usage of the autogram engine is to compute the Gramian of the Jacobian very efficiently 52 | for :doc:`IWRM `. Another direct application is when considering one loss per element 53 | of the batch and per task, in the context of multi-task learning. We call this 54 | :doc:`Instance-Wise Risk Multi-Task Learning ` (IWMTL). 55 | 56 | TorchJD is open-source, under the MIT License. The source code is available on 57 | `GitHub `_. 58 | 59 | .. toctree:: 60 | :caption: Getting Started 61 | :hidden: 62 | 63 | installation.md 64 | examples/index.rst 65 | 66 | .. toctree:: 67 | :caption: API Reference 68 | :hidden: 69 | 70 | docs/autogram/index.rst 71 | docs/autojac/index.rst 72 | docs/aggregation/index.rst 73 | -------------------------------------------------------------------------------- /tests/unit/autojac/_transform/test_accumulate.py: -------------------------------------------------------------------------------- 1 | from pytest import mark, raises 2 | from utils.dict_assertions import assert_tensor_dicts_are_close 3 | from utils.tensors import ones_, tensor_, zeros_ 4 | 5 | from torchjd.autojac._transform import Accumulate 6 | 7 | 8 | def test_single_accumulation(): 9 | """ 10 | Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run 11 | once.
12 | """ 13 | 14 | key1 = zeros_([], requires_grad=True) 15 | key2 = zeros_([1], requires_grad=True) 16 | key3 = zeros_([2, 3], requires_grad=True) 17 | value1 = ones_([]) 18 | value2 = ones_([1]) 19 | value3 = ones_([2, 3]) 20 | input = {key1: value1, key2: value2, key3: value3} 21 | 22 | accumulate = Accumulate() 23 | 24 | output = accumulate(input) 25 | expected_output = {} 26 | 27 | assert_tensor_dicts_are_close(output, expected_output) 28 | 29 | grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad} 30 | expected_grads = {key1: value1, key2: value2, key3: value3} 31 | 32 | assert_tensor_dicts_are_close(grads, expected_grads) 33 | 34 | 35 | @mark.parametrize("iterations", [1, 2, 4, 10, 13]) 36 | def test_multiple_accumulation(iterations: int): 37 | """ 38 | Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run 39 | `iterations` times. 40 | """ 41 | 42 | key1 = zeros_([], requires_grad=True) 43 | key2 = zeros_([1], requires_grad=True) 44 | key3 = zeros_([2, 3], requires_grad=True) 45 | value1 = ones_([]) 46 | value2 = ones_([1]) 47 | value3 = ones_([2, 3]) 48 | input = {key1: value1, key2: value2, key3: value3} 49 | 50 | accumulate = Accumulate() 51 | 52 | for i in range(iterations): 53 | accumulate(input) 54 | 55 | grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad} 56 | expected_grads = { 57 | key1: iterations * value1, 58 | key2: iterations * value2, 59 | key3: iterations * value3, 60 | } 61 | 62 | assert_tensor_dicts_are_close(grads, expected_grads) 63 | 64 | 65 | def test_no_requires_grad_fails(): 66 | """ 67 | Tests that the Accumulate transform raises an error when it tries to populate a .grad of a 68 | tensor that does not require grad. 69 | """ 70 | 71 | key = zeros_([1], requires_grad=False) 72 | value = ones_([1]) 73 | input = {key: value} 74 | 75 | accumulate = Accumulate() 76 | 77 | with raises(ValueError): 78 | accumulate(input) 79 | 80 | 81 | def test_no_leaf_and_no_retains_grad_fails(): 82 | """ 83 | Tests that the Accumulate transform raises an error when it tries to populate a .grad of a 84 | tensor that is not a leaf and that does not retain grad. 85 | """ 86 | 87 | key = tensor_([1.0], requires_grad=True) * 2 88 | value = ones_([1]) 89 | input = {key: value} 90 | 91 | accumulate = Accumulate() 92 | 93 | with raises(ValueError): 94 | accumulate(input) 95 | 96 | 97 | def test_check_keys(): 98 | """Tests that the `check_keys` method works correctly.""" 99 | 100 | key = tensor_([1.0], requires_grad=True) 101 | accumulate = Accumulate() 102 | 103 | output_keys = accumulate.check_keys({key}) 104 | assert output_keys == set() 105 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_config.py: -------------------------------------------------------------------------------- 1 | # The code of this file was partly adapted from 2 | # https://github.com/tum-pbs/ConFIG/tree/main/conflictfree. 3 | # It is therefore also subject to the following license. 
4 | # 5 | # MIT License 6 | # 7 | # Copyright (c) 2024 TUM Physics-based Simulation 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 26 | 27 | 28 | import torch 29 | from torch import Tensor 30 | 31 | from ._aggregator_bases import Aggregator 32 | from ._sum import SumWeighting 33 | from ._utils.non_differentiable import raise_non_differentiable_error 34 | from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting 35 | 36 | 37 | class ConFIG(Aggregator): 38 | """ 39 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Equation 2 of `ConFIG: 40 | Towards Conflict-free Training of Physics Informed Neural Networks 41 | `_. 42 | 43 | :param pref_vector: The preference vector used to weight the rows. If not provided, defaults to 44 | equal weights of 1. 45 | 46 | .. note:: 47 | This implementation was adapted from the `official implementation 48 | `_. 49 | """ 50 | 51 | def __init__(self, pref_vector: Tensor | None = None): 52 | super().__init__() 53 | self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting()) 54 | self._pref_vector = pref_vector 55 | 56 | # This prevents computing gradients that can be very wrong. 57 | self.register_full_backward_pre_hook(raise_non_differentiable_error) 58 | 59 | def forward(self, matrix: Tensor) -> Tensor: 60 | weights = self.weighting(matrix) 61 | units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0) 62 | best_direction = torch.linalg.pinv(units) @ weights 63 | 64 | unit_target_vector = torch.nn.functional.normalize(best_direction, dim=0) 65 | 66 | length = torch.sum(matrix @ unit_target_vector) 67 | 68 | return length * unit_target_vector 69 | 70 | def __repr__(self) -> str: 71 | return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" 72 | 73 | def __str__(self) -> str: 74 | return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}" 75 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | pull_request: 5 | workflow_dispatch: 6 | schedule: 7 | - cron: '41 16 * * *' # Every day at 16:41 UTC (to avoid high load at exact hour values). 
8 | 9 | env: 10 | UV_NO_SYNC: 1 11 | PYTHON_VERSION: 3.14 12 | 13 | jobs: 14 | tests-full-install: 15 | name: Run tests with full install 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | fail-fast: false # Ensure matrix jobs keep running even if one fails 19 | matrix: 20 | python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] 21 | os: [ubuntu-latest, macOS-latest, windows-latest] 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up uv 26 | uses: astral-sh/setup-uv@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - name: Install default (with full options) and test dependencies 30 | run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test 31 | - name: Run unit and doc tests with coverage report 32 | run: uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml 33 | - name: Upload results to Codecov 34 | uses: codecov/codecov-action@v4 35 | with: 36 | token: ${{ secrets.CODECOV_TOKEN }} 37 | 38 | tests-default-install: 39 | name: Run (most) tests with default install 40 | runs-on: ubuntu-latest 41 | steps: 42 | - uses: actions/checkout@v4 43 | - name: Set up uv 44 | uses: astral-sh/setup-uv@v5 45 | with: 46 | python-version: ${{ env.PYTHON_VERSION }} 47 | - name: Install default (without any option) and test dependencies 48 | run: uv pip install --python-version=${{ env.PYTHON_VERSION }} -e . --group test 49 | - name: Run unit and doc tests with coverage report 50 | run: | 51 | uv run pytest -W error tests/unit tests/doc \ 52 | --ignore tests/unit/aggregation/test_cagrad.py \ 53 | --ignore tests/unit/aggregation/test_nash_mtl.py \ 54 | --ignore tests/doc/test_aggregation.py \ 55 | --cov=src --cov-report=xml 56 | - name: Upload results to Codecov 57 | uses: codecov/codecov-action@v4 58 | with: 59 | token: ${{ secrets.CODECOV_TOKEN }} 60 | 61 | build-doc: 62 | name: Build doc 63 | runs-on: ubuntu-latest 64 | steps: 65 | - name: Checkout repository 66 | uses: actions/checkout@v4 67 | 68 | - name: Set up uv 69 | uses: astral-sh/setup-uv@v5 70 | with: 71 | python-version: ${{ env.PYTHON_VERSION }} 72 | 73 | - name: Install dependencies (default with full options & doc) 74 | run: uv pip install --python-version=${{ env.PYTHON_VERSION }} -e '.[full]' --group doc 75 | 76 | - name: Build Documentation 77 | working-directory: docs 78 | run: uv run make dirhtml 79 | 80 | mypy: 81 | name: Run mypy 82 | runs-on: ubuntu-latest 83 | steps: 84 | - name: Checkout repository 85 | uses: actions/checkout@v4 86 | 87 | - name: Set up uv 88 | uses: astral-sh/setup-uv@v5 89 | with: 90 | python-version: ${{ env.PYTHON_VERSION }} 91 | 92 | - name: Install dependencies (default with full options & check) 93 | run: uv pip install --python-version=${{ env.PYTHON_VERSION }} -e '.[full]' --group check 94 | 95 | - name: Run mypy 96 | run: uv run mypy src/torchjd 97 | -------------------------------------------------------------------------------- /tests/unit/autojac/_transform/test_stack.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import torch 4 | from torch import Tensor 5 | from utils.dict_assertions import assert_tensor_dicts_are_close 6 | from utils.tensors import ones_, tensor_, zeros_ 7 | 8 | from torchjd.autojac._transform import Stack, Transform 9 | from torchjd.autojac._transform._base import TensorDict 10 | 11 | 12 | class FakeGradientsTransform(Transform): 13 | """Transform that produces gradients filled with ones, for testing 
purposes.""" 14 | 15 | def __init__(self, keys: Iterable[Tensor]): 16 | self.keys = set(keys) 17 | 18 | def __call__(self, input: TensorDict) -> TensorDict: 19 | return {key: torch.ones_like(key) for key in self.keys} 20 | 21 | def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: 22 | return self.keys 23 | 24 | 25 | def test_single_key(): 26 | """ 27 | Tests that the Stack transform correctly stacks gradients into a jacobian, in a very simple 28 | example with 2 transforms sharing the same key. 29 | """ 30 | 31 | key = zeros_([3, 4]) 32 | input = {} 33 | 34 | transform = FakeGradientsTransform([key]) 35 | stack = Stack([transform, transform]) 36 | 37 | output = stack(input) 38 | expected_output = {key: ones_([2, 3, 4])} 39 | 40 | assert_tensor_dicts_are_close(output, expected_output) 41 | 42 | 43 | def test_disjoint_key_sets(): 44 | """ 45 | Tests that the Stack transform correctly stacks gradients into a jacobian, in an example where 46 | the output key sets of all of its transforms are disjoint. The missing values should be replaced 47 | by zeros. 48 | """ 49 | 50 | key1 = zeros_([1, 2]) 51 | key2 = zeros_([3]) 52 | input = {} 53 | 54 | transform1 = FakeGradientsTransform([key1]) 55 | transform2 = FakeGradientsTransform([key2]) 56 | stack = Stack([transform1, transform2]) 57 | 58 | output = stack(input) 59 | expected_output = { 60 | key1: tensor_([[[1.0, 1.0]], [[0.0, 0.0]]]), 61 | key2: tensor_([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), 62 | } 63 | 64 | assert_tensor_dicts_are_close(output, expected_output) 65 | 66 | 67 | def test_overlapping_key_sets(): 68 | """ 69 | Tests that the Stack transform correctly stacks gradients into a jacobian, in an example where 70 | the output key sets all of its transforms are overlapping (non-empty intersection, but not 71 | equal). The missing values should be replaced by zeros. 72 | """ 73 | 74 | key1 = zeros_([1, 2]) 75 | key2 = zeros_([3]) 76 | key3 = zeros_([4]) 77 | input = {} 78 | 79 | transform12 = FakeGradientsTransform([key1, key2]) 80 | transform23 = FakeGradientsTransform([key2, key3]) 81 | stack = Stack([transform12, transform23]) 82 | 83 | output = stack(input) 84 | expected_output = { 85 | key1: tensor_([[[1.0, 1.0]], [[0.0, 0.0]]]), 86 | key2: tensor_([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), 87 | key3: tensor_([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]), 88 | } 89 | 90 | assert_tensor_dicts_are_close(output, expected_output) 91 | 92 | 93 | def test_empty(): 94 | """Tests that the Stack transform correctly handles an empty list of transforms.""" 95 | 96 | stack = Stack([]) 97 | input = {} 98 | output = stack(input) 99 | expected_output = {} 100 | 101 | assert_tensor_dicts_are_close(output, expected_output) 102 | -------------------------------------------------------------------------------- /src/torchjd/autojac/_utils.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from collections.abc import Iterable, Sequence 3 | from typing import cast 4 | 5 | from torch import Tensor 6 | from torch.autograd.graph import Node 7 | 8 | from ._transform import OrderedSet 9 | 10 | 11 | def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None: 12 | if not (parallel_chunk_size is None or parallel_chunk_size > 0): 13 | raise ValueError( 14 | "`parallel_chunk_size` should be `None` or greater than `0`. 
(got " 15 | f"{parallel_chunk_size})" 16 | ) 17 | 18 | 19 | def as_checked_ordered_set( 20 | tensors: Sequence[Tensor] | Tensor, variable_name: str 21 | ) -> OrderedSet[Tensor]: 22 | if isinstance(tensors, Tensor): 23 | tensors = [tensors] 24 | 25 | original_length = len(tensors) 26 | output = OrderedSet(tensors) 27 | 28 | if len(output) != original_length: 29 | raise ValueError(f"`{variable_name}` should contain unique elements.") 30 | 31 | return OrderedSet(tensors) 32 | 33 | 34 | def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]: 35 | """ 36 | Gets the leaves of the autograd graph of all specified ``tensors``. 37 | 38 | :param tensors: Tensors from which the graph traversal should start. They should all require 39 | grad and not be leaves. 40 | :param excluded: Tensors whose grad_fn should be excluded from the graph traversal. They should 41 | all require grad and not be leaves. 42 | 43 | """ 44 | 45 | if any([tensor.grad_fn is None for tensor in tensors]): 46 | raise ValueError("All `tensors` should have a `grad_fn`.") 47 | 48 | if any([tensor.grad_fn is None for tensor in excluded]): 49 | raise ValueError("All `excluded` tensors should have a `grad_fn`.") 50 | 51 | accumulate_grads = _get_descendant_accumulate_grads( 52 | roots=cast(OrderedSet[Node], OrderedSet([tensor.grad_fn for tensor in tensors])), 53 | excluded_nodes=cast(set[Node], {tensor.grad_fn for tensor in excluded}), 54 | ) 55 | 56 | # accumulate_grads contains instances of AccumulateGrad, which contain a `variable` field. 57 | # They cannot be typed as such because AccumulateGrad is not public. 58 | leaves = OrderedSet([g.variable for g in accumulate_grads]) # type: ignore[attr-defined] 59 | 60 | return leaves 61 | 62 | 63 | def _get_descendant_accumulate_grads( 64 | roots: OrderedSet[Node], excluded_nodes: set[Node] 65 | ) -> OrderedSet[Node]: 66 | """ 67 | Gets the AccumulateGrad descendants of the specified nodes. 68 | 69 | :param roots: Root nodes from which the graph traversal should start. 70 | :param excluded_nodes: Nodes excluded from the graph traversal. 71 | """ 72 | 73 | excluded_nodes = set(excluded_nodes) # Re-instantiate set to avoid modifying input 74 | result: OrderedSet[Node] = OrderedSet([]) 75 | roots.difference_update(excluded_nodes) 76 | nodes_to_traverse = deque(roots) 77 | 78 | # This implementation more or less follows what is advised in 79 | # https://discuss.pytorch.org/t/autograd-graph-traversal/213658 and what was suggested in 80 | # https://github.com/TorchJD/torchjd/issues/216. 
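# The loop below is a breadth-first traversal of the autograd graph through `next_functions`: each
# popped node is collected if it is an AccumulateGrad node (i.e. it corresponds to a leaf tensor),
# and its children are enqueued unless they are excluded; enqueued children are added to
# `excluded_nodes` so that they are never scheduled twice.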
81 | while nodes_to_traverse: 82 | node = nodes_to_traverse.popleft() # Breadth-first 83 | 84 | if node.__class__.__name__ == "AccumulateGrad": 85 | result.add(node) 86 | 87 | for child, _ in node.next_functions: 88 | if child is not None and child not in excluded_nodes: 89 | nodes_to_traverse.append(child) # Append to the right 90 | excluded_nodes.add(child) 91 | 92 | return result 93 | -------------------------------------------------------------------------------- /docs/source/icons/TorchJD_text.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | TorchJD 76 | -------------------------------------------------------------------------------- /tests/unit/aggregation/_utils/test_dual_cone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytest import mark, raises 4 | from torch.testing import assert_close 5 | from utils.tensors import rand_, randn_ 6 | 7 | from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights 8 | 9 | 10 | @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) 11 | def test_solution_weights(shape: tuple[int, int]): 12 | r""" 13 | Tests that `_project_weights` returns valid weights corresponding to the projection onto the 14 | dual cone of a matrix with the specified shape. 15 | 16 | Validation is performed by verifying that the solution satisfies the `KKT conditions 17 | `_ for the 18 | quadratic program that projects vectors onto the dual cone of a matrix. Specifically, the 19 | solution should satisfy the equivalent set of conditions described in Lemma 4 of [1]. 20 | 21 | Let `u` be a vector of weights and `G` a positive semi-definite matrix. Consider the quadratic 22 | problem of minimizing `v^T G v` subject to `u \preceq v`. 23 | 24 | Then `w` is a solution if and only if it satisfies the following three conditions: 25 | 1. **Dual feasibility:** `u \preceq w` 26 | 2. **Primal feasibility:** `0 \preceq G w` 27 | 3. **Complementary slackness:** `u^T G w = w^T G w` 28 | 29 | Reference: 30 | [1] `Jacobian Descent For Multi-Objective Optimization `_. 31 | """ 32 | 33 | J = randn_(shape) 34 | G = J @ J.T 35 | u = rand_(shape[0]) 36 | 37 | w = project_weights(u, G, "quadprog") 38 | dual_gap = w - u 39 | 40 | # Dual feasibility 41 | dual_gap_positive_part = dual_gap[dual_gap >= 0.0] 42 | assert_close(dual_gap_positive_part.norm(), dual_gap.norm(), atol=1e-05, rtol=0) 43 | 44 | primal_gap = G @ w 45 | 46 | # Primal feasibility 47 | primal_gap_positive_part = primal_gap[primal_gap >= 0] 48 | assert_close(primal_gap_positive_part.norm(), primal_gap.norm(), atol=1e-04, rtol=0) 49 | 50 | # Complementary slackness 51 | slackness = dual_gap @ primal_gap 52 | assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) 53 | 54 | 55 | @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) 56 | @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) 57 | def test_scale_invariant(shape: tuple[int, int], scaling: float): 58 | """ 59 | Tests that `_project_weights` is invariant under scaling. 
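In other words, multiplying the Gramian by a positive scalar should leave the projected weights unchanged.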
60 | """ 61 | 62 | J = randn_(shape) 63 | G = J @ J.T 64 | u = rand_(shape[0]) 65 | 66 | w = project_weights(u, G, "quadprog") 67 | w_scaled = project_weights(u, scaling * G, "quadprog") 68 | 69 | assert_close(w_scaled, w) 70 | 71 | 72 | @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) 73 | def test_tensorization_shape(shape: tuple[int, ...]): 74 | """ 75 | Tests that applying `_project_weights` on a tensor is equivalent to applying it on the tensor 76 | reshaped as matrix and to reshape the result back to the original tensor's shape. 77 | """ 78 | 79 | matrix = randn_([shape[-1], shape[-1]]) 80 | U_tensor = randn_(shape) 81 | U_matrix = U_tensor.reshape([-1, shape[-1]]) 82 | 83 | G = matrix @ matrix.T 84 | 85 | W_tensor = project_weights(U_tensor, G, "quadprog") 86 | W_matrix = project_weights(U_matrix, G, "quadprog") 87 | 88 | assert_close(W_matrix.reshape(shape), W_tensor) 89 | 90 | 91 | def test_project_weight_vector_failure(): 92 | """Tests that `_project_weight_vector` raises an error when the input G has too large values.""" 93 | 94 | large_J = np.random.randn(10, 100) * 1e5 95 | large_G = large_J @ large_J.T 96 | with raises(ValueError): 97 | _project_weight_vector(np.ones(10), large_G, "quadprog") 98 | -------------------------------------------------------------------------------- /tests/unit/autojac/_transform/test_diagonalize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytest import raises 3 | from utils.dict_assertions import assert_tensor_dicts_are_close 4 | from utils.tensors import tensor_ 5 | 6 | from torchjd.autojac._transform import Diagonalize, OrderedSet, RequirementError 7 | 8 | 9 | def test_single_input(): 10 | """Tests that the Diagonalize transform works when given a single input.""" 11 | 12 | key = tensor_([1.0, 2.0, 3.0]) 13 | value = torch.ones_like(key) 14 | input = {key: value} 15 | 16 | diag = Diagonalize(OrderedSet([key])) 17 | 18 | output = diag(input) 19 | expected_output = {key: tensor_([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])} 20 | 21 | assert_tensor_dicts_are_close(output, expected_output) 22 | 23 | 24 | def test_multiple_inputs(): 25 | """Tests that the Diagonalize transform works when given multiple inputs.""" 26 | 27 | key1 = tensor_([[1.0, 2.0], [4.0, 5.0]]) 28 | key2 = tensor_([1.0, 3.0, 5.0]) 29 | key3 = tensor_(1.0) 30 | value1 = torch.ones_like(key1) 31 | value2 = torch.ones_like(key2) 32 | value3 = torch.ones_like(key3) 33 | input = {key1: value1, key2: value2, key3: value3} 34 | 35 | diag = Diagonalize(OrderedSet([key1, key2, key3])) 36 | 37 | output = diag(input) 38 | expected_output = { 39 | key1: tensor_( 40 | [ 41 | [[1.0, 0.0], [0.0, 0.0]], 42 | [[0.0, 1.0], [0.0, 0.0]], 43 | [[0.0, 0.0], [1.0, 0.0]], 44 | [[0.0, 0.0], [0.0, 1.0]], 45 | [[0.0, 0.0], [0.0, 0.0]], 46 | [[0.0, 0.0], [0.0, 0.0]], 47 | [[0.0, 0.0], [0.0, 0.0]], 48 | [[0.0, 0.0], [0.0, 0.0]], 49 | ], 50 | ), 51 | key2: tensor_( 52 | [ 53 | [0.0, 0.0, 0.0], 54 | [0.0, 0.0, 0.0], 55 | [0.0, 0.0, 0.0], 56 | [0.0, 0.0, 0.0], 57 | [1.0, 0.0, 0.0], 58 | [0.0, 1.0, 0.0], 59 | [0.0, 0.0, 1.0], 60 | [0.0, 0.0, 0.0], 61 | ], 62 | ), 63 | key3: tensor_( 64 | [ 65 | 0.0, 66 | 0.0, 67 | 0.0, 68 | 0.0, 69 | 0.0, 70 | 0.0, 71 | 0.0, 72 | 1.0, 73 | ], 74 | ), 75 | } 76 | 77 | assert_tensor_dicts_are_close(output, expected_output) 78 | 79 | 80 | def test_permute_order(): 81 | """ 82 | Tests that the Diagonalize transform outputs a permuted mapping when its keys are permuted. 
83 | """ 84 | 85 | key1 = tensor_(2.0) 86 | key2 = tensor_(1.0) 87 | value1 = torch.ones_like(key1) 88 | value2 = torch.ones_like(key2) 89 | input = {key1: value1, key2: value2} 90 | 91 | permuted_diag = Diagonalize(OrderedSet([key2, key1])) 92 | diag = Diagonalize(OrderedSet([key1, key2])) 93 | 94 | permuted_output = permuted_diag(input) 95 | output = {key1: permuted_output[key2], key2: permuted_output[key1]} # un-permute 96 | expected_output = diag(input) 97 | 98 | assert_tensor_dicts_are_close(output, expected_output) 99 | 100 | 101 | def test_check_keys(): 102 | """ 103 | Tests that the `check_keys` method works correctly. The input_keys must match the stored 104 | considered keys. 105 | """ 106 | 107 | key1 = tensor_([1.0]) 108 | key2 = tensor_([1.0]) 109 | diag = Diagonalize(OrderedSet([key1])) 110 | 111 | output_keys = diag.check_keys({key1}) 112 | assert output_keys == {key1} 113 | 114 | with raises(RequirementError): 115 | diag.check_keys(set()) 116 | 117 | with raises(RequirementError): 118 | diag.check_keys({key1, key2}) 119 | -------------------------------------------------------------------------------- /docs/source/icons/TorchJD_text_dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | TorchJD 79 | -------------------------------------------------------------------------------- /src/torchjd/aggregation/_krum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | 5 | from ._aggregator_bases import GramianWeightedAggregator 6 | from ._weighting_bases import PSDMatrix, Weighting 7 | 8 | 9 | class Krum(GramianWeightedAggregator): 10 | """ 11 | :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, 12 | as defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent 13 | `_. 14 | 15 | :param n_byzantine: The number of rows of the input matrix that can come from an adversarial 16 | source. 17 | :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. 18 | """ 19 | 20 | def __init__(self, n_byzantine: int, n_selected: int = 1): 21 | self._n_byzantine = n_byzantine 22 | self._n_selected = n_selected 23 | super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected)) 24 | 25 | def __repr__(self) -> str: 26 | return ( 27 | f"{self.__class__.__name__}(n_byzantine={self._n_byzantine}, n_selected=" 28 | f"{self._n_selected})" 29 | ) 30 | 31 | def __str__(self) -> str: 32 | return f"Krum{self._n_byzantine}-{self._n_selected}" 33 | 34 | 35 | class KrumWeighting(Weighting[PSDMatrix]): 36 | """ 37 | :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of 38 | :class:`~torchjd.aggregation.Krum`. 39 | 40 | :param n_byzantine: The number of rows of the input matrix that can come from an adversarial 41 | source. 42 | :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. 43 | """ 44 | 45 | def __init__(self, n_byzantine: int, n_selected: int = 1): 46 | super().__init__() 47 | if n_byzantine < 0: 48 | raise ValueError( 49 | "Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = " 50 | f"{n_byzantine}`." 51 | ) 52 | 53 | if n_selected < 1: 54 | raise ValueError( 55 | "Parameter `n_selected` should be a positive integer. Found `n_selected = " 56 | f"{n_selected}`." 
57 | ) 58 | 59 | self.n_byzantine = n_byzantine 60 | self.n_selected = n_selected 61 | 62 | def forward(self, gramian: Tensor) -> Tensor: 63 | self._check_matrix_shape(gramian) 64 | gradient_norms_squared = torch.diagonal(gramian) 65 | distances_squared = ( 66 | gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian 67 | ) 68 | distances = torch.sqrt(distances_squared) 69 | 70 | n_closest = gramian.shape[0] - self.n_byzantine - 2 71 | smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False) 72 | smallest_distances_excluding_self = smallest_distances[:, 1:] 73 | scores = smallest_distances_excluding_self.sum(dim=1) 74 | 75 | _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False) 76 | one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0]) 77 | weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected 78 | 79 | return weights 80 | 81 | def _check_matrix_shape(self, gramian: Tensor) -> None: 82 | min_rows = self.n_byzantine + 3 83 | if gramian.shape[0] < min_rows: 84 | raise ValueError( 85 | f"Parameter `gramian` should have at least {min_rows} rows (n_byzantine + 3). Found" 86 | f" `gramian` with {gramian.shape[0]} rows." 87 | ) 88 | 89 | if gramian.shape[0] < self.n_selected: 90 | raise ValueError( 91 | f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). " 92 | f"Found `gramian` with {gramian.shape[0]} rows." 93 | ) 94 | -------------------------------------------------------------------------------- /tests/speed/autogram/grad_vs_jac_vs_gram.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import torch 4 | from device import DEVICE 5 | from utils.architectures import ( 6 | AlexNet, 7 | Cifar10Model, 8 | FreeParam, 9 | GroupNormMobileNetV3Small, 10 | InstanceNormMobileNetV2, 11 | InstanceNormResNet18, 12 | ModuleFactory, 13 | NoFreeParam, 14 | SqueezeNet, 15 | WithTransformerLarge, 16 | ) 17 | from utils.forward_backwards import ( 18 | autograd_forward_backward, 19 | autograd_gramian_forward_backward, 20 | autogram_forward_backward, 21 | autojac_forward_backward, 22 | make_mse_loss_fn, 23 | ) 24 | from utils.tensors import make_inputs_and_targets 25 | 26 | from tests.speed.utils import print_times, time_call 27 | from torchjd.aggregation import Mean 28 | from torchjd.autogram import Engine 29 | 30 | PARAMETRIZATIONS = [ 31 | (ModuleFactory(WithTransformerLarge), 8), 32 | (ModuleFactory(FreeParam), 64), 33 | (ModuleFactory(NoFreeParam), 64), 34 | (ModuleFactory(Cifar10Model), 64), 35 | (ModuleFactory(AlexNet), 8), 36 | (ModuleFactory(InstanceNormResNet18), 16), 37 | (ModuleFactory(GroupNormMobileNetV3Small), 16), 38 | (ModuleFactory(SqueezeNet), 4), 39 | (ModuleFactory(InstanceNormMobileNetV2), 2), 40 | ] 41 | 42 | 43 | def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int): 44 | model = factory() 45 | inputs, targets = make_inputs_and_targets(model, batch_size) 46 | loss_fn = make_mse_loss_fn(targets) 47 | 48 | A = Mean() 49 | W = A.weighting 50 | 51 | print( 52 | f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A}" f" on {DEVICE}." 
53 | ) 54 | 55 | def fn_autograd(): 56 | autograd_forward_backward(model, inputs, loss_fn) 57 | 58 | def init_fn_autograd(): 59 | torch.cuda.empty_cache() 60 | gc.collect() 61 | fn_autograd() 62 | 63 | def fn_autograd_gramian(): 64 | autograd_gramian_forward_backward(model, inputs, loss_fn, W) 65 | 66 | def init_fn_autograd_gramian(): 67 | torch.cuda.empty_cache() 68 | gc.collect() 69 | fn_autograd_gramian() 70 | 71 | def fn_autojac(): 72 | autojac_forward_backward(model, inputs, loss_fn, A) 73 | 74 | def init_fn_autojac(): 75 | torch.cuda.empty_cache() 76 | gc.collect() 77 | fn_autojac() 78 | 79 | def fn_autogram(): 80 | autogram_forward_backward(model, inputs, loss_fn, engine, W) 81 | 82 | def init_fn_autogram(): 83 | torch.cuda.empty_cache() 84 | gc.collect() 85 | fn_autogram() 86 | 87 | def optionally_cuda_sync(): 88 | if str(DEVICE).startswith("cuda"): 89 | torch.cuda.synchronize() 90 | 91 | def pre_fn(): 92 | model.zero_grad() 93 | optionally_cuda_sync() 94 | 95 | def post_fn(): 96 | optionally_cuda_sync() 97 | 98 | n_runs = 10 99 | autograd_times = time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs) 100 | print_times("autograd", autograd_times) 101 | 102 | autograd_gramian_times = time_call( 103 | fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs 104 | ) 105 | print_times("autograd gramian", autograd_gramian_times) 106 | 107 | autojac_times = time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs) 108 | print_times("autojac", autojac_times) 109 | 110 | engine = Engine(model, batch_dim=0) 111 | autogram_times = time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs) 112 | print_times("autogram", autogram_times) 113 | 114 | 115 | def main(): 116 | for factory, batch_size in PARAMETRIZATIONS: 117 | compare_autograd_autojac_and_autogram_speed(factory, batch_size) 118 | print("\n") 119 | 120 | 121 | if __name__ == "__main__": 122 | # To test this on cuda, add the following environment variables when running this: 123 | # CUBLAS_WORKSPACE_CONFIG=:4096:8;PYTEST_TORCH_DEVICE=cuda:0 124 | main() 125 | --------------------------------------------------------------------------------