├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── gradnorm.png
├── gradnorm_pytorch
│   ├── __init__.py
│   ├── gradnorm_pytorch.py
│   └── mocks.py
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------

# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## GradNorm - Pytorch

A practical implementation of GradNorm, Gradient Normalization for Adaptive Loss Balancing, in Pytorch

Increasingly coming across neural network architectures that require more than 3 auxiliary losses, so this installable package aims to handle loss balancing easily, including in distributed settings and with gradient accumulation. Also open to incorporating any follow-up research; just let me know in the issues.

Will be dog-fooded for SoundStream, MagViT2, as well as MetNet3

## Appreciation

- StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research

## Install

```bash
$ pip install gradnorm-pytorch
```

## Usage

```python
import torch

from gradnorm_pytorch import (
    GradNormLossWeighter,
    MockNetworkWithMultipleLosses
)

# a mock network with multiple discriminator losses

network = MockNetworkWithMultipleLosses(
    dim = 512,
    num_losses = 4
)

# backbone shared parameter

backbone_parameter = network.backbone[-1].weight

# grad norm based loss weighter

loss_weighter = GradNormLossWeighter(
    num_losses = 4,
    learning_rate = 1e-4,
    restoring_force_alpha = 0.,                  # 0. means perfectly balanced losses, while anything greater than 0. accounts for the relative training rates of each loss. in the paper, they go as high as 3.
    grad_norm_parameters = backbone_parameter
)

# mock input

mock_input = torch.randn(2, 512)
losses, backbone_output_activations = network(mock_input)

# backwards with the loss weights
# will update the loss weights on each backward based on the gradnorm algorithm

loss_weighter.backward(losses, retain_graph = True)

# if you would like to update the loss weights with respect to activations instead, just do the following

loss_weighter.backward(losses, backbone_output_activations)
```

You can also switch it to basic static loss weighting, in case you want to run experiments against fixed weighting.

```python
loss_weighter = GradNormLossWeighter(
    loss_weights = [1., 10., 5., 2.],
    ...,
    frozen = True
)

# or you can also freeze it on invoking the instance

loss_weighter.backward(..., freeze = True)
```

For use with 🤗 Huggingface Accelerate, just pass the `Accelerator` instance in through the `accelerator` keyword argument on initialization, ex.

```python
accelerator = Accelerator()

network = accelerator.prepare(network)

loss_weighter = GradNormLossWeighter(
    ...,
    accelerator = accelerator
)

# backwards will now use accelerator
```

## Todo

- [x] take care of gradient accumulation
- [ ] handle sets of loss weights
- [ ] handle freezing of some loss weights, but not others
- [ ] allow for a prior weighting, accounted for when calculating gradient targets

## Citations

```bibtex
@article{Chen2017GradNormGN,
    title   = {GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks},
    author  = {Zhao Chen and Vijay Badrinarayanan and Chen-Yu Lee and Andrew Rabinovich},
    journal = {ArXiv},
    year    = {2017},
    volume  = {abs/1711.02257},
    url     = {https://api.semanticscholar.org/CorpusID:4703661}
}
```
--------------------------------------------------------------------------------
/gradnorm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gradnorm-pytorch/4f793905161b471b0d63c3e4629d265c908e1dfd/gradnorm.png
--------------------------------------------------------------------------------
/gradnorm_pytorch/__init__.py:
--------------------------------------------------------------------------------
from gradnorm_pytorch.gradnorm_pytorch import GradNormLossWeighter
from gradnorm_pytorch.mocks import MockNetworkWithMultipleLosses
--------------------------------------------------------------------------------
/gradnorm_pytorch/gradnorm_pytorch.py:
--------------------------------------------------------------------------------
from functools import cache, partial

import torch
import torch.distributed as dist
from torch.autograd import grad
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList, Parameter

from einops import rearrange, repeat

from accelerate import Accelerator

from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, Union, List, Dict, Tuple, NamedTuple

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# tensor helpers

def l1norm(t, dim = -1):
    return F.normalize(t, p = 1, dim = dim)

# distributed helpers

@cache
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def maybe_distributed_mean(t):
    if not is_distributed():
        return t

    dist.all_reduce(t)
    t = t / dist.get_world_size()
    return t

# main class

class GradNormLossWeighter(Module):
    @beartype
    def __init__(
        self,
        *,
        num_losses: Optional[int] = None,
        loss_weights: Optional[Union[
            List[float],
            Tensor
        ]] = None,
        loss_names: Optional[Tuple[str, ...]] = None,
        learning_rate = 1e-4,
        restoring_force_alpha = 0.,
        grad_norm_parameters: Optional[Parameter] = None,
        accelerator: Optional[Accelerator] = None,
        frozen = False,
        initial_losses_decay = 1.,
        update_after_step = 0.,
        update_every = 1.
    ):
        super().__init__()
        assert exists(num_losses) or exists(loss_weights)

        if exists(loss_weights):
            if isinstance(loss_weights, list):
                loss_weights = torch.tensor(loss_weights)

            num_losses = default(num_losses, loss_weights.numel())
        else:
            loss_weights = torch.ones((num_losses,), dtype = torch.float32)

        assert len(loss_weights) == num_losses
        assert num_losses > 1, 'only makes sense if you have multiple losses'
        assert loss_weights.ndim == 1, 'loss weights must be 1 dimensional'

        self.accelerator = accelerator
        self.num_losses = num_losses
        self.frozen = frozen

        self.loss_names = loss_names
        assert not exists(loss_names) or len(loss_names) == num_losses

        assert restoring_force_alpha >= 0.

        self.alpha = restoring_force_alpha
        self.has_restoring_force = self.alpha > 0

        self._grad_norm_parameters = [grad_norm_parameters] # hack

        # loss weights, either learned or static

        self.register_buffer('loss_weights', loss_weights)

        self.learning_rate = learning_rate

        # initial loss
        # if initial loss decay set to less than 1, will EMA smooth the initial loss

        assert 0 <= initial_losses_decay <= 1.
        self.initial_losses_decay = initial_losses_decay

        self.register_buffer('initial_losses', torch.zeros(num_losses))

        # for renormalizing loss weights at end

        self.register_buffer('loss_weights_sum', self.loss_weights.sum())

        # for gradient accumulation

        self.register_buffer('loss_weights_grad', torch.zeros_like(loss_weights), persistent = False)

        # step, for maybe having schedules etc

        self.register_buffer('step', torch.tensor(0.))

        # can update less frequently, to save on compute

        self.update_after_step = update_after_step
        self.update_every = update_every

        self.register_buffer('initted', torch.tensor(False))

    @property
    def grad_norm_parameters(self):
        return self._grad_norm_parameters[0]

    def backward(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @beartype
    def forward(
        self,
        losses: Union[
            Dict[str, Tensor],
            List[Tensor],
            Tuple[Tensor, ...],
            Tensor
        ],
        activations: Optional[Tensor] = None,   # in the paper, they used the grad norm of penultimate parameters from a backbone layer. but this could also be activations (say a shared image being fed to multiple discriminators)
        freeze = False,                         # can optionally freeze the learnable loss weights on forward
        scale = 1.,
        grad_step = True,
        **backward_kwargs
    ):
        # backward function depends on whether hf accelerate is being used or not

        backward = self.accelerator.backward if exists(self.accelerator) else lambda l, **kwargs: l.backward(**kwargs)
        backward = partial(backward, **backward_kwargs)

        # increment step

        step = self.step.item()

        self.step.add_(int(self.training and grad_step))

        # losses can be passed in as a Dict[str, Tensor], and will be ordered by the `loss_names` passed in on init

        if isinstance(losses, tuple) and hasattr(losses, '_asdict'):
            losses = losses._asdict()

        if isinstance(losses, dict):
            assert exists(self.loss_names)
            input_loss_names = set(losses.keys())
            assert input_loss_names == set(self.loss_names), f'expect losses named {self.loss_names} but received {input_loss_names}'

            losses = [losses[name] for name in self.loss_names]

        # validate that all the losses are a single scalar

        assert all([loss.numel() == 1 for loss in losses])

        # cast losses to tensor form

        if isinstance(losses, (list, tuple)):
            losses = torch.stack(losses)

        # auto move gradnorm module to the device of the losses

        if self.initted.device != losses.device:
            self.to(losses.device)

        assert losses.ndim == 1, 'losses must be 1 dimensional'
        assert losses.numel() == self.num_losses, f'you instantiated with {self.num_losses} losses but passed in {losses.numel()} losses'

        total_weighted_loss = (losses * self.loss_weights.detach()).sum()

        backward(total_weighted_loss * scale, **{**backward_kwargs, 'retain_graph': not freeze})

        # handle base frozen case, so one can freeze the weights after a certain number of steps, or just to a/b test against learned gradnorm loss weights

        if (
            self.frozen or \
            freeze or \
            not self.training or \
            step < self.update_after_step or \
            (step % self.update_every) != 0
        ):
            return total_weighted_loss

        # store initial loss

        if self.has_restoring_force:
            if not self.initted.item():
                initial_losses = maybe_distributed_mean(losses)
                self.initial_losses.copy_(initial_losses)
                self.initted.copy_(True)

            elif self.initial_losses_decay < 1.:
                meaned_losses = maybe_distributed_mean(losses)
                self.initial_losses.lerp_(meaned_losses, 1. - self.initial_losses_decay)

        # determine which tensor to get grad norm from

        grad_norm_tensor = default(activations, self.grad_norm_parameters)

        assert exists(grad_norm_tensor), 'you need to either set `grad_norm_parameters` on init or `activations` on backwards'

        grad_norm_tensor.requires_grad_()

        # get grad norm with respect to each loss

        grad_norms = []
        loss_weights = self.loss_weights.clone()
        loss_weights = Parameter(loss_weights)

        for weight, loss in zip(loss_weights, losses):
            gradients, = grad(weight * loss, grad_norm_tensor, create_graph = True, retain_graph = True)

            grad_norm = gradients.norm(p = 2)
            grad_norms.append(grad_norm)

        grad_norms = torch.stack(grad_norms)

        # main algorithm for loss balancing

        grad_norm_average = maybe_distributed_mean(grad_norms.mean())

        if self.has_restoring_force:
            loss_ratio = losses.detach() / self.initial_losses

            relative_training_rate = l1norm(loss_ratio) * self.num_losses

            gradient_target = (grad_norm_average * (relative_training_rate ** self.alpha)).detach()
        else:
            gradient_target = repeat(grad_norm_average, ' -> l', l = self.num_losses).detach()

        grad_norm_loss = F.l1_loss(grad_norms, gradient_target)

        backward(grad_norm_loss * scale)

        # accumulate gradients

        self.loss_weights_grad.add_(loss_weights.grad)

        if not grad_step:
            return

        # manually take a single gradient step

        updated_loss_weights = loss_weights - self.loss_weights_grad * self.learning_rate

        renormalized_loss_weights = l1norm(updated_loss_weights) * self.loss_weights_sum

        self.loss_weights.copy_(renormalized_loss_weights)

        self.loss_weights_grad.zero_()
--------------------------------------------------------------------------------
/gradnorm_pytorch/mocks.py:
--------------------------------------------------------------------------------
from torch import nn

class MockNetworkWithMultipleLosses(nn.Module):
    def __init__(
        self,
        dim,
        num_losses = 2
    ):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

        self.discriminators = nn.ModuleList([
            nn.Linear(dim, 1) for _ in range(num_losses)
        ])

    def forward(self, x):
        backbone_output = self.backbone(x)

        losses = []

        for discr in self.discriminators:
            loss = discr(backbone_output)
            losses.append(loss.mean())

        return losses, backbone_output
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name = 'gradnorm-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.0.26',
    license='MIT',
    description = 'GradNorm - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/gradnorm-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'loss balancing',
        'gradient normalization'
    ],
    install_requires=[
        'accelerate',
        'beartype',
        'einops>=0.7.0',
        'torch>=2.0'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
--------------------------------------------------------------------------------
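
As a closing usage sketch: the weighter only maintains and updates its own loss weights, so the network's parameters still need an ordinary optimizer. The loop below shows one way to combine the two, reusing `MockNetworkWithMultipleLosses` from `gradnorm_pytorch.mocks`; the `Adam` optimizer, random data, batch size, and step count are illustrative assumptions rather than anything the package prescribes.

```python
# hypothetical end-to-end training loop - Adam, the random data, and the step count
# are assumptions for illustration; only the weighter and the mock network come from the package

import torch
from torch.optim import Adam

from gradnorm_pytorch import (
    GradNormLossWeighter,
    MockNetworkWithMultipleLosses
)

network = MockNetworkWithMultipleLosses(
    dim = 512,
    num_losses = 4
)

# shared backbone parameter whose gradient norms get balanced

backbone_parameter = network.backbone[-1].weight

loss_weighter = GradNormLossWeighter(
    num_losses = 4,
    learning_rate = 1e-4,
    restoring_force_alpha = 0.,
    grad_norm_parameters = backbone_parameter
)

# the weighter updates its loss weights internally, so the optimizer only covers the network parameters

optimizer = Adam(network.parameters(), lr = 3e-4)

for _ in range(100):
    data = torch.randn(2, 512)

    losses, _ = network(data)

    # weighted backward pass; the gradnorm update to the loss weights happens inside this call

    loss_weighter.backward(losses, retain_graph = True)

    optimizer.step()
    optimizer.zero_grad()
```

Because the gradnorm step for the loss weights is taken manually inside `backward`, no separate optimizer or scheduler is needed for them.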