├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── gradnorm.png
├── gradnorm_pytorch
│   ├── __init__.py
│   ├── gradnorm_pytorch.py
│   └── mocks.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## GradNorm - Pytorch
4 |
5 | A practical implementation of GradNorm, Gradient Normalization for Adaptive Loss Balancing, in Pytorch
6 |
7 | I am increasingly coming across neural network architectures that require more than 3 auxiliary losses, so this package aims to easily handle loss balancing in a distributed setting, gradient accumulation, etc. I am also open to incorporating any follow-up research; just let me know in the issues.
8 |
9 | It will be dog-fooded for SoundStream, MagViT2, as well as MetNet3
10 |
11 | ## Appreciation
12 |
13 | - StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
14 |
15 | ## Install
16 |
17 | ```bash
18 | $ pip install gradnorm-pytorch
19 | ```
20 |
21 | ## Usage
22 |
23 | ```python
24 | import torch
25 |
26 | from gradnorm_pytorch import (
27 | GradNormLossWeighter,
28 | MockNetworkWithMultipleLosses
29 | )
30 |
31 | # a mock network with multiple discriminator losses
32 |
33 | network = MockNetworkWithMultipleLosses(
34 | dim = 512,
35 | num_losses = 4
36 | )
37 |
38 | # backbone shared parameter
39 |
40 | backbone_parameter = network.backbone[-1].weight
41 |
42 | # grad norm based loss weighter
43 |
44 | loss_weighter = GradNormLossWeighter(
45 | num_losses = 4,
46 | learning_rate = 1e-4,
47 |     restoring_force_alpha = 0.,                  # 0. is perfectly balanced losses, while anything greater than 0 accounts for the relative training rate of each loss - the higher the value, the stronger the restoring force. in the paper, they go as high as 3
48 | grad_norm_parameters = backbone_parameter
49 | )
50 |
51 | # mock input
52 |
53 | mock_input = torch.randn(2, 512)
54 | losses, backbone_output_activations = network(mock_input)
55 |
56 | # backwards with the loss weights
57 | # the loss weights will be updated on each backward call, based on the gradnorm algorithm
58 |
59 | loss_weighter.backward(losses, retain_graph = True)
60 |
61 | # if you would like the loss weights to be updated using gradient norms taken w.r.t. activations instead, just do the following
62 |
63 | loss_weighter.backward(losses, backbone_output_activations)
64 | ```
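
The weighted losses are backpropagated into the network parameters, so a regular optimizer step follows as usual. Below is a minimal sketch of a full training step; the `Adam` optimizer and its learning rate are placeholders, not part of the package.

```python
optimizer = torch.optim.Adam(network.parameters(), lr = 3e-4)

for _ in range(100):
    mock_input = torch.randn(2, 512)
    losses, _ = network(mock_input)

    # backward the gradnorm-weighted losses and update the learned loss weights
    loss_weighter.backward(losses, retain_graph = True)

    optimizer.step()
    optimizer.zero_grad()
```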
65 |
66 | You can also switch it to basic static loss weighting, in case you want to run experiments against fixed weighting.
67 |
68 | ```python
69 | loss_weighter = GradNormLossWeighter(
70 | loss_weights = [1., 10., 5., 2.],
71 | ...,
72 | frozen = True
73 | )
74 |
75 | # or you can also freeze it when invoking the instance
76 |
77 | loss_weighter.backward(..., freeze = True)
78 | ```
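
Losses may also be passed in as a dictionary, keyed by the `loss_names` given on initialization. A small sketch; `generator_loss` and `discriminator_loss` below stand in for your own scalar losses.

```python
loss_weighter = GradNormLossWeighter(
    num_losses = 2,
    loss_names = ('generator', 'discriminator'),
    learning_rate = 1e-4,
    grad_norm_parameters = backbone_parameter
)

loss_weighter.backward({
    'generator': generator_loss,
    'discriminator': discriminator_loss
}, retain_graph = True)
```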
79 |
80 | For use with 🤗 Huggingface Accelerate, just pass the `Accelerator` instance in through the `accelerator` keyword on initialization
81 |
82 | ex.
83 |
84 | ```python
85 | accelerator = Accelerator()
86 |
87 | network = accelerator.prepare(network)
88 |
89 | loss_weighter = GradNormLossWeighter(
90 | ...,
91 | accelerator = accelerator
92 | )
93 |
94 | # backwards will now use accelerator
95 | ```
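
For gradient accumulation, the `scale` and `grad_step` keywords on `backward` can be used; a sketch of how this might look, with `accum_steps` chosen arbitrarily as 4 micro-batches:

```python
accum_steps = 4

for i in range(accum_steps):
    losses, _ = network(torch.randn(2, 512))

    loss_weighter.backward(
        losses,
        scale = 1. / accum_steps,            # scales both the weighted loss and the gradnorm loss
        grad_step = (i + 1) == accum_steps,  # only step the loss weights on the last micro-batch
        retain_graph = True
    )

# then step your network optimizer as usual
```

At any point, the current weights can be inspected through the `loss_weights` buffer (e.g. `loss_weighter.loss_weights`), which is renormalized after each update to preserve its initial L1 norm.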
96 |
97 | ## Todo
98 |
99 | - [x] take care of gradient accumulation
100 | - [ ] handle sets of loss weights
101 | - [ ] handle freezing of some loss weights, but not others
102 | - [ ] allow for a prior weighting, accounted for when calculating gradient targets
103 |
104 | ## Citations
105 |
106 | ```bibtex
107 | @article{Chen2017GradNormGN,
108 | title = {GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks},
109 | author = {Zhao Chen and Vijay Badrinarayanan and Chen-Yu Lee and Andrew Rabinovich},
110 | journal = {ArXiv},
111 | year = {2017},
112 | volume = {abs/1711.02257},
113 | url = {https://api.semanticscholar.org/CorpusID:4703661}
114 | }
115 | ```
116 |
--------------------------------------------------------------------------------
/gradnorm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gradnorm-pytorch/4f793905161b471b0d63c3e4629d265c908e1dfd/gradnorm.png
--------------------------------------------------------------------------------
/gradnorm_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from gradnorm_pytorch.gradnorm_pytorch import GradNormLossWeighter
2 | from gradnorm_pytorch.mocks import MockNetworkWithMultipleLosses
3 |
--------------------------------------------------------------------------------
/gradnorm_pytorch/gradnorm_pytorch.py:
--------------------------------------------------------------------------------
1 | from functools import cache, partial
2 |
3 | import torch
4 | import torch.distributed as dist
5 | from torch.autograd import grad
6 | import torch.nn.functional as F
7 | from torch import nn, einsum, Tensor
8 | from torch.nn import Module, ModuleList, Parameter
9 |
10 | from einops import rearrange, repeat
11 |
12 | from accelerate import Accelerator
13 |
14 | from beartype import beartype
15 | from beartype.door import is_bearable
16 | from beartype.typing import Optional, Union, List, Dict, Tuple, NamedTuple
17 |
18 | # helper functions
19 |
20 | def exists(v):
21 | return v is not None
22 |
23 | def default(v, d):
24 | return v if exists(v) else d
25 |
26 | # tensor helpers
27 |
28 | def l1norm(t, dim = -1):
29 | return F.normalize(t, p = 1, dim = dim)
30 |
31 | # distributed helpers
32 |
33 | @cache
34 | def is_distributed():
35 | return dist.is_initialized() and dist.get_world_size() > 1
36 |
37 | def maybe_distributed_mean(t):
38 | if not is_distributed():
39 | return t
40 |
41 | dist.all_reduce(t)
42 | t = t / dist.get_world_size()
43 | return t
44 |
45 | # main class
46 |
47 | class GradNormLossWeighter(Module):
48 | @beartype
49 | def __init__(
50 | self,
51 | *,
52 | num_losses: Optional[int] = None,
53 | loss_weights: Optional[Union[
54 | List[float],
55 | Tensor
56 | ]] = None,
57 | loss_names: Optional[Tuple[str, ...]] = None,
58 | learning_rate = 1e-4,
59 | restoring_force_alpha = 0.,
60 | grad_norm_parameters: Optional[Parameter] = None,
61 | accelerator: Optional[Accelerator] = None,
62 | frozen = False,
63 | initial_losses_decay = 1.,
64 | update_after_step = 0.,
65 | update_every = 1.
66 | ):
67 | super().__init__()
68 | assert exists(num_losses) or exists(loss_weights)
69 |
70 | if exists(loss_weights):
71 | if isinstance(loss_weights, list):
72 | loss_weights = torch.tensor(loss_weights)
73 |
74 | num_losses = default(num_losses, loss_weights.numel())
75 | else:
76 | loss_weights = torch.ones((num_losses,), dtype = torch.float32)
77 |
78 | assert len(loss_weights) == num_losses
79 | assert num_losses > 1, 'only makes sense if you have multiple losses'
80 | assert loss_weights.ndim == 1, 'loss weights must be 1 dimensional'
81 |
82 | self.accelerator = accelerator
83 | self.num_losses = num_losses
84 | self.frozen = frozen
85 |
86 | self.loss_names = loss_names
87 | assert not exists(loss_names) or len(loss_names) == num_losses
88 |
89 | assert restoring_force_alpha >= 0.
90 |
91 | self.alpha = restoring_force_alpha
92 | self.has_restoring_force = self.alpha > 0
93 |
94 | self._grad_norm_parameters = [grad_norm_parameters] # hack
95 |
96 | # loss weights, either learned or static
97 |
98 | self.register_buffer('loss_weights', loss_weights)
99 |
100 | self.learning_rate = learning_rate
101 |
102 | # initial loss
103 | # if initial loss decay set to less than 1, will EMA smooth the initial loss
104 |
105 | assert 0 <= initial_losses_decay <= 1.
106 | self.initial_losses_decay = initial_losses_decay
107 |
108 | self.register_buffer('initial_losses', torch.zeros(num_losses))
109 |
110 | # for renormalizing loss weights at end
111 |
112 | self.register_buffer('loss_weights_sum', self.loss_weights.sum())
113 |
114 | # for gradient accumulation
115 |
116 | self.register_buffer('loss_weights_grad', torch.zeros_like(loss_weights), persistent = False)
117 |
118 | # step, for maybe having schedules etc
119 |
120 | self.register_buffer('step', torch.tensor(0.))
121 |
122 | # can update less frequently, to save on compute
123 |
124 | self.update_after_step = update_after_step
125 | self.update_every = update_every
126 |
127 | self.register_buffer('initted', torch.tensor(False))
128 |
129 | @property
130 | def grad_norm_parameters(self):
131 | return self._grad_norm_parameters[0]
132 |
133 | def backward(self, *args, **kwargs):
134 | return self.forward(*args, **kwargs)
135 |
136 | @beartype
137 | def forward(
138 | self,
139 | losses: Union[
140 | Dict[str, Tensor],
141 | List[Tensor],
142 | Tuple[Tensor],
143 | Tensor
144 | ],
145 |         activations: Optional[Tensor] = None,  # in the paper, they used the grad norm with respect to the parameters of the last shared (penultimate) backbone layer. but this could also be activations (say a shared image being fed to multiple discriminators)
146 | freeze = False, # can optionally freeze the learnable loss weights on forward
147 | scale = 1.,
148 | grad_step = True,
149 | **backward_kwargs
150 | ):
151 | # backward functions dependent on whether using hf accelerate or not
152 |
153 | backward = self.accelerator.backward if exists(self.accelerator) else lambda l, **kwargs: l.backward(**kwargs)
154 | backward = partial(backward, **backward_kwargs)
155 |
156 | # increment step
157 |
158 | step = self.step.item()
159 |
160 | self.step.add_(int(self.training and grad_step))
161 |
162 |         # losses can be passed in as a Dict[str, Tensor]; they will be ordered according to the `loss_names` passed in on init
163 |
164 | if isinstance(losses, tuple) and hasattr(losses, '_asdict'):
165 | losses = losses._asdict()
166 |
167 | if isinstance(losses, dict):
168 | assert exists(self.loss_names)
169 | input_loss_names = set(losses.keys())
170 | assert input_loss_names == set(self.loss_names), f'expect losses named {self.loss_names} but received {input_loss_names}'
171 |
172 | losses = [losses[name] for name in self.loss_names]
173 |
174 | # validate that all the losses are a single scalar
175 |
176 | assert all([loss.numel() == 1 for loss in losses])
177 |
178 | # cast losses to tensor form
179 |
180 | if isinstance(losses, (list, tuple)):
181 | losses = torch.stack(losses)
182 |
183 | # auto move gradnorm module to the device of the losses
184 |
185 | if self.initted.device != losses.device:
186 | self.to(losses.device)
187 |
188 | assert losses.ndim == 1, 'losses must be 1 dimensional'
189 | assert losses.numel() == self.num_losses, f'you instantiated with {self.num_losses} losses but passed in {losses.numel()} losses'
190 |
191 | total_weighted_loss = (losses * self.loss_weights.detach()).sum()
192 |
193 | backward(total_weighted_loss * scale, **{**backward_kwargs, 'retain_graph': not freeze})
194 |
195 | # handle base frozen case, so one can freeze the weights after a certain number of steps, or just to a/b test against learned gradnorm loss weights
196 |
197 | if (
198 | self.frozen or \
199 | freeze or \
200 | not self.training or \
201 | step < self.update_after_step or \
202 | (step % self.update_every) != 0
203 | ):
204 | return total_weighted_loss
205 |
206 | # store initial loss
207 |
208 | if self.has_restoring_force:
209 | if not self.initted.item():
210 | initial_losses = maybe_distributed_mean(losses)
211 | self.initial_losses.copy_(initial_losses)
212 | self.initted.copy_(True)
213 |
214 | elif self.initial_losses_decay < 1.:
215 | meaned_losses = maybe_distributed_mean(losses)
216 | self.initial_losses.lerp_(meaned_losses, 1. - self.initial_losses_decay)
217 |
218 | # determine which tensor to get grad norm from
219 |
220 | grad_norm_tensor = default(activations, self.grad_norm_parameters)
221 |
222 |         assert exists(grad_norm_tensor), 'you need to either set `grad_norm_parameters` on init or pass in `activations` on backward'
223 |
224 | grad_norm_tensor.requires_grad_()
225 |
226 | # get grad norm with respect to each loss
227 |
228 | grad_norms = []
229 | loss_weights = self.loss_weights.clone()
230 | loss_weights = Parameter(loss_weights)
231 |
232 | for weight, loss in zip(loss_weights, losses):
233 | gradients, = grad(weight * loss, grad_norm_tensor, create_graph = True, retain_graph = True)
234 |
235 | grad_norm = gradients.norm(p = 2)
236 | grad_norms.append(grad_norm)
237 |
238 | grad_norms = torch.stack(grad_norms)
239 |
240 | # main algorithm for loss balancing
241 |
242 | grad_norm_average = maybe_distributed_mean(grad_norms.mean())
243 |
244 | if self.has_restoring_force:
245 | loss_ratio = losses.detach() / self.initial_losses
246 |
247 | relative_training_rate = l1norm(loss_ratio) * self.num_losses
248 |
249 | gradient_target = (grad_norm_average * (relative_training_rate ** self.alpha)).detach()
250 | else:
251 | gradient_target = repeat(grad_norm_average, ' -> l', l = self.num_losses).detach()
252 |
253 | grad_norm_loss = F.l1_loss(grad_norms, gradient_target)
254 |
255 | backward(grad_norm_loss * scale)
256 |
257 | # accumulate gradients
258 |
259 | self.loss_weights_grad.add_(loss_weights.grad)
260 |
261 | if not grad_step:
262 | return
263 |
264 | # manually take a single gradient step
265 |
266 | updated_loss_weights = loss_weights - self.loss_weights_grad * self.learning_rate
267 |
268 | renormalized_loss_weights = l1norm(updated_loss_weights) * self.loss_weights_sum
269 |
270 | self.loss_weights.copy_(renormalized_loss_weights)
271 |
272 | self.loss_weights_grad.zero_()
273 |
--------------------------------------------------------------------------------
/gradnorm_pytorch/mocks.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | class MockNetworkWithMultipleLosses(nn.Module):
4 | def __init__(
5 | self,
6 | dim,
7 | num_losses = 2
8 | ):
9 | super().__init__()
10 | self.backbone = nn.Sequential(
11 | nn.Linear(dim, dim),
12 | nn.SiLU(),
13 | nn.Linear(dim, dim)
14 | )
15 |
16 | self.discriminators = nn.ModuleList([
17 | nn.Linear(dim, 1) for _ in range(num_losses)
18 | ])
19 |
20 | def forward(self, x):
21 | backbone_output = self.backbone(x)
22 |
23 | losses = []
24 |
25 | for discr in self.discriminators:
26 | loss = discr(backbone_output)
27 | losses.append(loss.mean())
28 |
29 | return losses, backbone_output
30 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'gradnorm-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.0.26',
7 | license='MIT',
8 | description = 'GradNorm - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/gradnorm-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'loss balancing',
17 | 'gradient normalization'
18 | ],
19 | install_requires=[
20 | 'accelerate',
21 | 'beartype',
22 | 'einops>=0.7.0',
23 | 'torch>=2.0'
24 | ],
25 | classifiers=[
26 | 'Development Status :: 4 - Beta',
27 | 'Intended Audience :: Developers',
28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
29 | 'License :: OSI Approved :: MIT License',
30 | 'Programming Language :: Python :: 3.6',
31 | ],
32 | )
33 |
--------------------------------------------------------------------------------