├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── adam-atan2.png ├── adam_atan2_pytorch ├── __init__.py ├── adam_atan2.py ├── adam_atan2_with_wasserstein_reg.py ├── adopt.py ├── adopt_atan2.py └── foreach.py └── pyproject.toml /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Adam-atan2 - Pytorch
4 | 
5 | Implementation of the proposed Adam-atan2 optimizer in Pytorch
6 | 
7 | A multi-million dollar paper out of Google DeepMind proposes a small change to the Adam update rule (using `atan2`) that removes the epsilon altogether, for numerical stability and scale invariance
8 | 
9 | This repository also contains some features for improving plasticity, drawn from the continual learning field
10 | 
11 | ## Install
12 | 
13 | ```bash
14 | $ pip install adam-atan2-pytorch
15 | ```
16 | 
17 | ## Usage
18 | 
19 | ```python
20 | import torch
21 | from torch import nn
22 | 
23 | # toy model
24 | 
25 | model = nn.Linear(10, 1)
26 | 
27 | # import AdamAtan2 and instantiate with parameters
28 | 
29 | from adam_atan2_pytorch import AdamAtan2
30 | 
31 | opt = AdamAtan2(model.parameters(), lr = 1e-4)
32 | 
33 | # forward and backwards
34 | 
35 | for _ in range(100):
36 |     loss = model(torch.randn(10))
37 |     loss.backward()
38 | 
39 |     # optimizer step
40 | 
41 |     opt.step()
42 |     opt.zero_grad()
43 | ```
44 | 
45 | ## Citations
46 | 
47 | ```bibtex
48 | @inproceedings{Everett2024ScalingEA,
49 |     title = {Scaling Exponents Across Parameterizations and Optimizers},
50 |     author = {Katie Everett and Lechao Xiao and Mitchell Wortsman and Alex Alemi and Roman Novak and Peter J. Liu and Izzeddin Gur and Jascha Narain Sohl-Dickstein and Leslie Pack Kaelbling and Jaehoon Lee and Jeffrey Pennington},
51 |     year = {2024},
52 |     url = {https://api.semanticscholar.org/CorpusID:271051056}
53 | }
54 | ```
55 | 
56 | ```bibtex
57 | @inproceedings{Kumar2023MaintainingPI,
58 |     title = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
59 |     author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
60 |     year = {2023},
61 |     url = {https://api.semanticscholar.org/CorpusID:261076021}
62 | }
63 | ```
64 | 
65 | ```bibtex
66 | @article{Lewandowski2024LearningCB,
67 |     title = {Learning Continually by Spectral Regularization},
68 |     author = {Alex Lewandowski and Saurabh Kumar and Dale Schuurmans and András György and Marlos C.
Machado}, 69 | journal = {ArXiv}, 70 | year = {2024}, 71 | volume = {abs/2406.06811}, 72 | url = {https://api.semanticscholar.org/CorpusID:270380086} 73 | } 74 | ``` 75 | 76 | ```bibtex 77 | @inproceedings{Taniguchi2024ADOPTMA, 78 | title = {ADOPT: Modified Adam Can Converge with Any \$\beta\_2\$ with the Optimal Rate}, 79 | author = {Shohei Taniguchi and Keno Harada and Gouki Minegishi and Yuta Oshima and Seong Cheol Jeong and Go Nagahara and Tomoshi Iiyama and Masahiro Suzuki and Yusuke Iwasawa and Yutaka Matsuo}, 80 | year = {2024}, 81 | url = {https://api.semanticscholar.org/CorpusID:273822148} 82 | } 83 | ``` 84 | 85 | ```bibtex 86 | @inproceedings{Liang2024CautiousOI, 87 | title = {Cautious Optimizers: Improving Training with One Line of Code}, 88 | author = {Kaizhao Liang and Lizhang Chen and Bo Liu and Qiang Liu}, 89 | year = {2024}, 90 | url = {https://api.semanticscholar.org/CorpusID:274234738} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /adam-atan2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/adam-atan2-pytorch/8f14cf50a89030edbdea5214d22ac268511b400e/adam-atan2.png -------------------------------------------------------------------------------- /adam_atan2_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from adam_atan2_pytorch.adam_atan2 import AdamAtan2 2 | from adam_atan2_pytorch.adopt_atan2 import AdoptAtan2 3 | 4 | Adam = AdamAtan2 5 | Adopt = AdoptAtan2 6 | -------------------------------------------------------------------------------- /adam_atan2_pytorch/adam_atan2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable 3 | 4 | import torch 5 | from torch import atan2, sqrt 6 | from torch.optim.optimizer import Optimizer 7 | 8 | # functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | # class 14 | 15 | class AdamAtan2(Optimizer): 16 | def __init__( 17 | self, 18 | params, 19 | lr = 1e-4, 20 | betas: tuple[float, float] = (0.9, 0.99), 21 | weight_decay = 0., 22 | regen_reg_rate = 0., 23 | decoupled_wd = False, 24 | cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085 25 | a = 1.27, 26 | b = 1. 27 | ): 28 | assert lr > 0. 29 | assert all([0. <= beta <= 1. for beta in betas]) 30 | assert weight_decay >= 0. 31 | assert regen_reg_rate >= 0. 32 | assert not (weight_decay > 0. and regen_reg_rate > 0.) 33 | assert 0. <= cautious_factor <= 1. 
34 | 35 | self._init_lr = lr 36 | self.decoupled_wd = decoupled_wd 37 | 38 | defaults = dict( 39 | lr = lr, 40 | betas = betas, 41 | a = a, 42 | b = b, 43 | weight_decay = weight_decay, 44 | regen_reg_rate = regen_reg_rate, 45 | cautious_factor = cautious_factor 46 | ) 47 | 48 | super().__init__(params, defaults) 49 | 50 | @torch.no_grad() 51 | def step( 52 | self, 53 | closure: Callable | None = None 54 | ): 55 | 56 | loss = None 57 | if exists(closure): 58 | with torch.enable_grad(): 59 | loss = closure() 60 | 61 | for group in self.param_groups: 62 | for p in filter(lambda p: exists(p.grad), group['params']): 63 | 64 | grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr 65 | 66 | # maybe decoupled weight decay 67 | 68 | if self.decoupled_wd: 69 | wd /= init_lr 70 | 71 | # weight decay 72 | 73 | if wd > 0.: 74 | p.mul_(1. - lr * wd) 75 | 76 | # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958 77 | 78 | if regen_rate > 0. and 'param_init' in state: 79 | param_init = state['param_init'] 80 | p.lerp_(param_init, lr / init_lr * regen_rate) 81 | 82 | # init state if needed 83 | 84 | if len(state) == 0: 85 | state['steps'] = 0 86 | state['exp_avg'] = torch.zeros_like(grad) 87 | state['exp_avg_sq'] = torch.zeros_like(grad) 88 | 89 | if regen_rate > 0.: 90 | state['param_init'] = p.clone() 91 | 92 | # get some of the states 93 | 94 | exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps'] 95 | 96 | steps += 1 97 | 98 | # bias corrections 99 | 100 | bias_correct1 = 1. - beta1 ** steps 101 | bias_correct2 = 1. - beta2 ** steps 102 | 103 | # decay running averages 104 | 105 | exp_avg.lerp_(grad, 1. - beta1) 106 | exp_avg_sq.lerp_(grad * grad, 1. - beta2) 107 | 108 | # the following line is the proposed change to the update rule 109 | # using atan2 instead of a division with epsilon in denominator 110 | # a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2)) 111 | 112 | den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_() 113 | update = exp_avg.mul(1. / bias_correct1).atan2_(den) 114 | 115 | # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085 116 | 117 | if cautious_factor < 1.: 118 | align_mask = (update * grad) > 0 119 | scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor) 120 | update *= (scale / scale.mean().clamp(min = 1e-5)) 121 | 122 | # update parameters 123 | 124 | p.add_(update, alpha = -lr * a) 125 | 126 | # increment steps 127 | 128 | state['steps'] = steps 129 | 130 | return loss 131 | -------------------------------------------------------------------------------- /adam_atan2_pytorch/adam_atan2_with_wasserstein_reg.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable 3 | 4 | import torch 5 | from torch import atan2, sqrt 6 | from torch.optim.optimizer import Optimizer 7 | 8 | # functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | # class 14 | 15 | class AdamAtan2(Optimizer): 16 | def __init__( 17 | self, 18 | params, 19 | lr = 1e-4, 20 | betas: tuple[float, float] = (0.9, 0.99), 21 | weight_decay = 0., 22 | regen_reg_rate = 0., 23 | decoupled_wd = False, 24 | a = 1.27, 25 | b = 1. 26 | ): 27 | assert lr > 0. 28 | assert all([0. <= beta <= 1. 
for beta in betas])
29 |         assert weight_decay >= 0.
30 |         assert regen_reg_rate >= 0.
31 |         assert not (weight_decay > 0. and regen_reg_rate > 0.)
32 | 
33 |         self._init_lr = lr
34 |         self.decoupled_wd = decoupled_wd
35 | 
36 |         defaults = dict(
37 |             lr = lr,
38 |             betas = betas,
39 |             a = a,
40 |             b = b,
41 |             weight_decay = weight_decay,
42 |             regen_reg_rate = regen_reg_rate,
43 |         )
44 | 
45 |         super().__init__(params, defaults)
46 | 
47 |     @torch.no_grad()
48 |     def step(
49 |         self,
50 |         closure: Callable | None = None
51 |     ):
52 | 
53 |         loss = None
54 |         if exists(closure):
55 |             with torch.enable_grad():
56 |                 loss = closure()
57 | 
58 |         for group in self.param_groups:
59 |             for p in filter(lambda p: exists(p.grad), group['params']):
60 | 
61 |                 grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
62 | 
63 |                 # maybe decoupled weight decay
64 | 
65 |                 if self.decoupled_wd:
66 |                     wd /= init_lr
67 | 
68 |                 # weight decay
69 | 
70 |                 if wd > 0.:
71 |                     p.mul_(1. - lr * wd)
72 | 
73 |                 # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
74 | 
75 |                 if regen_rate > 0. and 'param_init' in state:
76 |                     param_init = state['param_init']
77 | 
78 |                     shape = param_init.shape
79 | 
80 |                     # wasserstein compares using ordered statistics, iiuc
81 | 
82 |                     indices = p.flatten().sort(dim = -1).indices
83 |                     indices = indices.argsort(dim = -1)
84 | 
85 |                     target = param_init.flatten()[indices]
86 |                     target = target.reshape(shape)
87 | 
88 |                     p.lerp_(target, lr / init_lr * regen_rate)
89 | 
90 |                 # init state if needed
91 | 
92 |                 if len(state) == 0:
93 |                     state['steps'] = 0
94 |                     state['exp_avg'] = torch.zeros_like(grad)
95 |                     state['exp_avg_sq'] = torch.zeros_like(grad)
96 | 
97 |                     if regen_rate > 0.:
98 | 
99 |                         # wasserstein reg - https://arxiv.org/abs/2406.06811v1
100 |                         # initial parameters sorted for efficiency (do not rebind p, which must still receive the update below)
101 | 
102 |                         shape = p.shape
103 |                         param_init = p.flatten().sort(dim = -1).values
104 |                         param_init = param_init.reshape(shape)
105 | 
106 |                         state['param_init'] = param_init.clone()
107 | 
108 |                 # get some of the states
109 | 
110 |                 exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
111 | 
112 |                 steps += 1
113 | 
114 |                 # bias corrections
115 | 
116 |                 bias_correct1 = 1. - beta1 ** steps
117 |                 bias_correct2 = 1. - beta2 ** steps
118 | 
119 |                 # decay running averages
120 | 
121 |                 exp_avg.lerp_(grad, 1. - beta1)
122 |                 exp_avg_sq.lerp_(grad * grad, 1. - beta2)
123 | 
124 |                 # the following line is the proposed change to the update rule
125 |                 # using atan2 instead of a division with epsilon in denominator
126 |                 # a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))
127 | 
128 |                 den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
129 |                 update = exp_avg.mul(1.
/ bias_correct1).atan2_(den) 130 | 131 | # update parameters 132 | 133 | p.add_(update, alpha = -lr * a) 134 | 135 | # increment steps 136 | 137 | state['steps'] = steps 138 | 139 | return loss 140 | 141 | Adam = AdamAtan2 -------------------------------------------------------------------------------- /adam_atan2_pytorch/adopt.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable 3 | 4 | import torch 5 | from torch import atan2, sqrt 6 | from torch.optim.optimizer import Optimizer 7 | 8 | # functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | # class 14 | 15 | class Adopt(Optimizer): 16 | """ 17 | the proposed Adam substitute from University of Tokyo 18 | 19 | Algorithm 3 in https://arxiv.org/abs/2411.02853 20 | """ 21 | 22 | def __init__( 23 | self, 24 | params, 25 | lr = 1e-4, 26 | betas: tuple[float, float] = (0.9, 0.99), 27 | eps = 1e-6, 28 | weight_decay = 0., 29 | decoupled_wd = True 30 | ): 31 | assert lr > 0. 32 | assert all([0. <= beta <= 1. for beta in betas]) 33 | assert weight_decay >= 0. 34 | 35 | self._init_lr = lr 36 | self.decoupled_wd = decoupled_wd 37 | 38 | defaults = dict( 39 | lr = lr, 40 | betas = betas, 41 | eps = eps, 42 | weight_decay = weight_decay, 43 | ) 44 | 45 | super().__init__(params, defaults) 46 | 47 | @torch.no_grad() 48 | def step( 49 | self, 50 | closure: Callable | None = None 51 | ): 52 | 53 | loss = None 54 | if exists(closure): 55 | with torch.enable_grad(): 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in filter(lambda p: exists(p.grad), group['params']): 60 | 61 | grad, lr, wd, beta1, beta2, eps, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['eps'], self.state[p], self._init_lr 62 | 63 | # maybe decoupled weight decay 64 | 65 | if self.decoupled_wd: 66 | wd /= init_lr 67 | 68 | # weight decay 69 | 70 | if wd > 0.: 71 | p.mul_(1. - lr * wd) 72 | 73 | # init state if needed 74 | 75 | if len(state) == 0: 76 | state['steps'] = 0 77 | state['m'] = torch.zeros_like(grad) 78 | state['v'] = grad * grad 79 | 80 | # get some of the states 81 | 82 | m, v, steps = state['m'], state['v'], state['steps'] 83 | 84 | # for the first step do nothing 85 | 86 | if steps == 0: 87 | state['steps'] += 1 88 | continue 89 | 90 | # calculate m 91 | 92 | grad_sq = grad * grad 93 | 94 | update = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon 95 | 96 | # clip with t ^ 0.25 as in Algorithm 3 97 | 98 | clip_value = steps ** 0.25 99 | update.clamp_(min = -clip_value, max = clip_value) 100 | 101 | # update m 102 | 103 | m.lerp_(update, 1. - beta1) 104 | 105 | # then update parameters 106 | 107 | p.add_(m, alpha = -lr) 108 | 109 | # update exp grad sq (v) 110 | 111 | v.lerp_(grad_sq, 1. 
- beta2) 112 | 113 | # increment steps 114 | 115 | state['steps'] += 1 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /adam_atan2_pytorch/adopt_atan2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable 3 | 4 | import torch 5 | from torch import atan2, sqrt 6 | from torch.optim.optimizer import Optimizer 7 | 8 | # functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | # class 14 | 15 | class AdoptAtan2(Optimizer): 16 | """ 17 | the proposed Adam substitute from University of Tokyo 18 | combined with the proposed atan2 method for ridding of the eps from Google 19 | 20 | Algorithm 3 in https://arxiv.org/abs/2411.02853 21 | """ 22 | 23 | def __init__( 24 | self, 25 | params, 26 | lr = 1e-4, 27 | betas: tuple[float, float] = (0.9, 0.99), 28 | weight_decay = 0., 29 | regen_reg_rate = 0., 30 | decoupled_wd = True, 31 | cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085 32 | a = 1.27, 33 | b = 1. 34 | ): 35 | assert lr > 0. 36 | assert all([0. <= beta <= 1. for beta in betas]) 37 | assert weight_decay >= 0. 38 | assert not (weight_decay > 0. and regen_reg_rate > 0.) 39 | 40 | self._init_lr = lr 41 | self.decoupled_wd = decoupled_wd 42 | 43 | defaults = dict( 44 | lr = lr, 45 | betas = betas, 46 | a = a, 47 | b = b, 48 | weight_decay = weight_decay, 49 | regen_reg_rate = regen_reg_rate, 50 | cautious_factor = cautious_factor 51 | ) 52 | 53 | super().__init__(params, defaults) 54 | 55 | @torch.no_grad() 56 | def step( 57 | self, 58 | closure: Callable | None = None 59 | ): 60 | 61 | loss = None 62 | if exists(closure): 63 | with torch.enable_grad(): 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in filter(lambda p: exists(p.grad), group['params']): 68 | 69 | grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr 70 | 71 | # maybe decoupled weight decay 72 | 73 | if self.decoupled_wd: 74 | wd /= init_lr 75 | 76 | # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958 77 | 78 | if regen_rate > 0. and 'param_init' in state: 79 | param_init = state['param_init'] 80 | p.lerp_(param_init, lr / init_lr * regen_rate) 81 | 82 | # weight decay 83 | 84 | if wd > 0.: 85 | p.mul_(1. - lr * wd) 86 | 87 | # init state if needed 88 | 89 | if len(state) == 0: 90 | state['steps'] = 0 91 | state['m'] = torch.zeros_like(grad) 92 | state['v'] = grad * grad 93 | 94 | if regen_rate > 0.: 95 | state['param_init'] = p.clone() 96 | 97 | # get some of the states 98 | 99 | m, v, steps = state['m'], state['v'], state['steps'] 100 | 101 | # for the first step do nothing 102 | 103 | if steps == 0: 104 | state['steps'] += 1 105 | continue 106 | 107 | # calculate m 108 | 109 | grad_sq = grad * grad 110 | 111 | update = grad.atan2(b * v.sqrt()) 112 | 113 | m.lerp_(update, 1. - beta1) 114 | 115 | # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085 116 | 117 | scale = 1. 
118 | 
119 |                 if cautious_factor < 1.:
120 |                     align_mask = (update * grad) > 0
121 |                     scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
122 |                     scale /= scale.mean().clamp(min = 1e-5)
123 | 
124 |                 # then update parameters
125 | 
126 |                 p.add_(m * scale, alpha = -lr * a)
127 | 
128 |                 # update exp grad sq (v)
129 | 
130 |                 v.lerp_(grad_sq, 1. - beta2)
131 | 
132 |                 # increment steps
133 | 
134 |                 state['steps'] += 1
135 | 
136 |         return loss
137 | 
--------------------------------------------------------------------------------
/adam_atan2_pytorch/foreach.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 | 
4 | import torch
5 | from torch import atan2, sqrt, Tensor
6 | from torch.optim.optimizer import Optimizer
7 | 
8 | # functions
9 | 
10 | def exists(val):
11 |     return val is not None
12 | 
13 | def default(*args):
14 |     for arg in args:
15 |         if exists(arg):
16 |             return arg
17 |     return None
18 | 
19 | # slow foreach atan2
20 | 
21 | def slow_foreach_atan2_(nums: list[Tensor], dens: list[Tensor]):
22 |     for num, den, in zip(nums, dens):
23 |         num.atan2_(den)
24 | 
25 | # class
26 | 
27 | class AdamAtan2(Optimizer):
28 |     def __init__(
29 |         self,
30 |         params,
31 |         lr = 1e-4,
32 |         betas: tuple[float, float] = (0.9, 0.99),
33 |         weight_decay = 0.,
34 |         regen_reg_rate = 0.,
35 |         decoupled_wd = False,
36 |         a = 1.27,
37 |         b = 1.,
38 |         foreach_atan2_fn: Callable | None = None
39 |     ):
40 |         assert lr > 0.
41 |         assert all([0. <= beta <= 1. for beta in betas])
42 |         assert weight_decay >= 0.
43 |         assert regen_reg_rate >= 0.
44 |         assert not (weight_decay > 0. and regen_reg_rate > 0.)
45 |         assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'lerp', 'sqrt')]), 'this version of torch does not have the prerequisite foreach functions'
46 | 
47 |         self._init_lr = lr
48 |         self.decoupled_wd = decoupled_wd
49 | 
50 |         self._foreach_atan2_ = default(
51 |             foreach_atan2_fn,
52 |             getattr(torch, '_foreach_atan2_', None),
53 |             slow_foreach_atan2_
54 |         )
55 | 
56 |         defaults = dict(
57 |             lr = lr,
58 |             betas = betas,
59 |             a = a,
60 |             b = b,
61 |             weight_decay = weight_decay,
62 |             regen_reg_rate = regen_reg_rate
63 |         )
64 | 
65 |         super().__init__(params, defaults)
66 | 
67 |     @torch.no_grad()
68 |     def step(
69 |         self,
70 |         closure: Callable | None = None
71 |     ):
72 |         init_lr = self._init_lr
73 | 
74 |         loss = None
75 |         if exists(closure):
76 |             with torch.enable_grad():
77 |                 loss = closure()
78 | 
79 |         for group in self.param_groups:
80 | 
81 |             wd, regen_rate, lr, beta1, beta2, a, b = group['weight_decay'], group['regen_reg_rate'], group['lr'], *group['betas'], group['a'], group['b']
82 | 
83 |             has_weight_decay = wd > 0
84 | 
85 |             has_regenerative_reg = regen_rate > 0
86 | 
87 |             # maybe decoupled weight decay, applied once per group (not once per parameter)
88 | 
89 |             if self.decoupled_wd and has_weight_decay:
90 |                 wd /= init_lr
91 | 
92 |             # accumulate List[Tensor] for foreach inplace updates
93 | 
94 |             params = []
95 |             params_init = []
96 |             grads = []
97 |             grad_squared = []
98 |             exp_avgs = []
99 |             exp_avg_sqs = []
100 | 
101 |             for p in filter(lambda p: exists(p.grad), group['params']):
102 | 
103 |                 grad, state = p.grad, self.state[p]
104 | 
105 |                 # init state if needed
106 | 
107 |                 if len(state) == 0:
108 |                     state['steps'] = 0
109 |                     state['exp_avg'] = torch.zeros_like(grad)
110 |                     state['exp_avg_sq'] = torch.zeros_like(grad)
111 |                     state['param_init'] = p.clone()
112 | 
113 |                 # get some of the states
114 | 
115 |                 exp_avg, exp_avg_sq, param_init, steps = state['exp_avg'],
state['exp_avg_sq'], state['param_init'], state['steps'] 116 | 117 | steps += 1 118 | 119 | # bias corrections 120 | 121 | bias_correct1 = 1. - beta1 ** steps 122 | bias_correct2 = 1. - beta2 ** steps 123 | 124 | # append to list 125 | 126 | params.append(p) 127 | params_init.append(param_init) 128 | grads.append(grad) 129 | grad_squared.append(grad * grad) 130 | exp_avgs.append(exp_avg) 131 | exp_avg_sqs.append(exp_avg_sq) 132 | 133 | # update steps 134 | 135 | state['steps'] = steps 136 | 137 | # weight decay 138 | 139 | if has_weight_decay: 140 | torch._foreach_mul_(params, 1. - lr * wd) 141 | 142 | # regenerative regularization 143 | 144 | if has_regenerative_reg: 145 | torch._foreach_lerp_(params, params_init, lr / init_lr * regen_rate) 146 | 147 | # decay running averages 148 | 149 | torch._foreach_lerp_(exp_avgs, grads, 1. - beta1) 150 | torch._foreach_lerp_(exp_avg_sqs, grad_squared, 1. - beta2) 151 | 152 | # clone for update 153 | 154 | updates = [t.clone() for t in exp_avgs] 155 | den = [t.clone() for t in exp_avg_sqs] 156 | 157 | # calculate update atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2)) 158 | 159 | torch._foreach_mul_(updates, 1. / bias_correct1) 160 | 161 | torch._foreach_mul_(den, b * b / bias_correct2) 162 | torch._foreach_sqrt_(den) 163 | 164 | self._foreach_atan2_(updates, den) 165 | 166 | # update params 167 | 168 | torch._foreach_add_(params, updates, alpha = -lr * a) 169 | 170 | return loss 171 | 172 | Adam = AdamAtan2 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "adam-atan2-pytorch" 3 | version = "0.1.18" 4 | description = "Adam-atan2 for Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.9" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'adam', 15 | 'optimizers' 16 | ] 17 | 18 | classifiers=[ 19 | 'Development Status :: 4 - Beta', 20 | 'Intended Audience :: Developers', 21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 3.9', 24 | ] 25 | 26 | dependencies = [ 27 | "torch>=2.0", 28 | ] 29 | 30 | [project.urls] 31 | Homepage = "https://pypi.org/project/adam_atan2_pytorch/" 32 | Repository = "https://github.com/lucidrains/adam_atan2_pytorch" 33 | 34 | [project.optional-dependencies] 35 | examples = [] 36 | test = [ 37 | "pytest" 38 | ] 39 | 40 | [tool.pytest.ini_options] 41 | pythonpath = [ 42 | "." 43 | ] 44 | 45 | [build-system] 46 | requires = ["hatchling"] 47 | build-backend = "hatchling.build" 48 | 49 | [tool.rye] 50 | managed = true 51 | dev-dependencies = [] 52 | 53 | [tool.hatch.metadata] 54 | allow-direct-references = true 55 | 56 | [tool.hatch.build.targets.wheel] 57 | packages = ["adam_atan2_pytorch"] 58 | --------------------------------------------------------------------------------
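The `atan2` form used throughout these optimizers can be sanity-checked numerically. Below is a small standalone illustration (the magnitudes are arbitrary, chosen only to show the vanishing-denominator regime): with the classic division-plus-epsilon, a tiny second moment lets the epsilon dominate and crush the update, whereas `atan2` stays bounded and is invariant to a common rescaling of numerator and denominator.

```python
import torch

# pretend bias-corrected first moment and root second moment in a tiny-gradient regime
num = torch.tensor(3e-20)
den = torch.tensor(4e-20)

eps = 1e-8

# classic Adam-style division: the epsilon dominates and the update collapses toward zero
print(num / (den + eps))                  # ~3e-12

# atan2 form: bounded, and atan2(k * num, k * den) == atan2(num, den) for any k > 0
print(torch.atan2(num, den))              # ~0.6435 (= atan(3 / 4))
print(torch.atan2(1e6 * num, 1e6 * den))  # same value, demonstrating scale invariance
```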
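For completeness, `adam_atan2_pytorch/__init__.py` also exports `AdoptAtan2`, which the README does not demonstrate. A minimal usage sketch in the same style as the README, exercising the cautious-update and regenerative-regularization options defined in `adopt_atan2.py`; the toy model and hyperparameter values here are illustrative assumptions rather than recommendations.

```python
import torch
from torch import nn

from adam_atan2_pytorch import AdoptAtan2

# toy model

model = nn.Linear(10, 1)

# cautious_factor = 0. zeroes any update component that disagrees in sign with the current gradient
# regen_reg_rate > 0. gently pulls weights back toward their initial values (regenerative regularization)

opt = AdoptAtan2(
    model.parameters(),
    lr = 1e-4,
    cautious_factor = 0.,
    regen_reg_rate = 1e-4
)

for _ in range(100):
    loss = model(torch.randn(10))
    loss.backward()

    opt.step()
    opt.zero_grad()
```

Note that ADOPT deliberately skips the parameter update on the very first call to `step()` while it seeds its second-moment estimate (the `steps == 0` branch in `adopt_atan2.py`), so the loop above performs 99 effective updates.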