├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── adan-pseudocode.png
├── adan_pytorch
│   ├── __init__.py
│   └── adan.py
└── setup.py

--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------

# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

MIT License

Copyright (c) 2022 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Adan - Pytorch

Implementation of the Adan (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch.
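At a high level, Adan keeps three exponential moving averages: of the gradient, of the gradient difference between consecutive steps, and of the squared extrapolated gradient. A rough sketch of one update, with bias correction and weight decay omitted (see `adan_pytorch/adan.py` below for the exact implementation):

```python
# schematic of a single Adan step, not the actual API
m = (1 - beta1) * m + beta1 * g                                   # EMA of gradients
v = (1 - beta2) * v + beta2 * (g - g_prev)                        # EMA of gradient differences
n = (1 - beta3) * n + beta3 * (g + (1 - beta2) * (g - g_prev)) ** 2  # EMA of squared extrapolated gradient
param = param - lr * (m + (1 - beta2) * v) / (sqrt(n) + eps)
```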
Explanation from Davis Blalock

Official Adan code

## Install

```bash
$ pip install adan-pytorch
```

## Usage

```python
from adan_pytorch import Adan

# mock model

import torch
from torch import nn

model = torch.nn.Sequential(
    nn.Linear(16, 16),
    nn.GELU()
)

# instantiate Adan with model parameters

optim = Adan(
    model.parameters(),
    lr = 1e-3,                   # learning rate (can be much higher than Adam, up to 5-10x)
    betas = (0.02, 0.08, 0.01),  # beta 1-2-3 as described in paper - author says most sensitive to beta3 tuning
    weight_decay = 0.02          # weight decay 0.02 is optimal per author
)

# train

for _ in range(10):
    loss = model(torch.randn(16)).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
```

## Citations

```bibtex
@article{Xie2022AdanAN,
    title   = {Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models},
    author  = {Xingyu Xie and Pan Zhou and Huan Li and Zhouchen Lin and Shuicheng Yan},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.06677}
}
```

--------------------------------------------------------------------------------
/adan-pseudocode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/Adan-pytorch/77328b32395cebb42d8d1d0bd7c3dd9bb8f61b8a/adan-pseudocode.png

--------------------------------------------------------------------------------
/adan_pytorch/__init__.py:
--------------------------------------------------------------------------------

from adan_pytorch.adan import Adan

--------------------------------------------------------------------------------
/adan_pytorch/adan.py:
--------------------------------------------------------------------------------

import torch
from torch.optim import Optimizer

def exists(val):
    return val is not None

class Adan(Optimizer):
    def __init__(
        self,
        params,
        lr = 1e-3,
        betas = (0.02, 0.08, 0.01),
        eps = 1e-8,
        weight_decay = 0,
        restart_cond: callable = None
    ):
        assert len(betas) == 3

        defaults = dict(
            lr = lr,
            betas = betas,
            eps = eps,
            weight_decay = weight_decay,
            restart_cond = restart_cond
        )

        super().__init__(params, defaults)

    def step(self, closure = None):
        loss = None

        if exists(closure):
            loss = closure()

        for group in self.param_groups:

            lr = group['lr']
            beta1, beta2, beta3 = group['betas']
            weight_decay = group['weight_decay']
            eps = group['eps']
            restart_cond = group['restart_cond']

            for p in group['params']:
                if not exists(p.grad):
                    continue

                data, grad = p.data, p.grad.data
                assert not grad.is_sparse

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['prev_grad'] = torch.zeros_like(grad)
                    state['m'] = torch.zeros_like(grad)
                    state['v'] = torch.zeros_like(grad)
                    state['n'] = torch.zeros_like(grad)

                step, m, v, n, prev_grad = state['step'], state['m'], state['v'], state['n'], state['prev_grad']

                # main algorithm

                m.mul_(1 - beta1).add_(grad, alpha = beta1)

                grad_diff = grad - prev_grad

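                # v tracks an EMA of successive gradient differences (the Nesterov-style
                # correction), and n tracks an EMA of the squared extrapolated gradient
                # grad + (1 - beta2) * grad_diff, used below to scale the step size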
                v.mul_(1 - beta2).add_(grad_diff, alpha = beta2)

                next_n = (grad + (1 - beta2) * grad_diff) ** 2

                n.mul_(1 - beta3).add_(next_n, alpha = beta3)

                # bias correction terms

                step += 1

                correct_m, correct_v, correct_n = map(lambda beta: 1 / (1 - (1 - beta) ** step), (beta1, beta2, beta3))

                # gradient step

                def grad_step_(data, m, v, n):
                    weighted_step_size = lr / (n * correct_n).sqrt().add_(eps)

                    denom = 1 + weight_decay * lr

                    data.addcmul_(weighted_step_size, (m * correct_m + (1 - beta2) * v * correct_v), value = -1.).div_(denom)

                grad_step_(data, m, v, n)

                # restart condition

                if exists(restart_cond) and restart_cond(state):
                    m.data.copy_(grad)
                    v.zero_()
                    n.data.copy_(grad ** 2)

                    grad_step_(data, m, v, n)

                # store current gradient and incremented step for the next update

                prev_grad.copy_(grad)
                state['step'] = step

        return loss

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------

from setuptools import setup, find_packages

setup(
    name = 'adan-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.1.0',
    license='MIT',
    description = 'Adan - (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/Adan-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'optimizer',
    ],
    install_requires=[
        'torch>=1.6',
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
--------------------------------------------------------------------------------
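
A note on `restart_cond`: `Adan.step` calls it with the per-parameter state dict (keys `'step'`, `'m'`, `'v'`, `'n'`, `'prev_grad'`), and on a truthy return it re-initializes the moving averages from the current gradient before taking a second step. A minimal sketch of how one might use it — the periodic schedule here is a hypothetical choice, not something prescribed by the repo:

```python
import torch
from torch import nn
from adan_pytorch import Adan

model = nn.Linear(16, 16)

# hypothetical restart policy: reset the moving averages every 1000 steps
# (note: state['step'] holds the count *before* the current increment)

def every_1000_steps(state):
    return state['step'] > 0 and state['step'] % 1000 == 0

optim = Adan(model.parameters(), lr = 1e-3, restart_cond = every_1000_steps)
```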