├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── adan-pseudocode.png
├── adan_pytorch
│   ├── __init__.py
│   └── adan.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <img src="./adan-pseudocode.png" alt="Adan pseudocode"></img>
2 | 
3 | ## Adan - Pytorch
4 |
5 | Implementation of the [Adan](https://arxiv.org/abs/2208.06677) (ADAptive Nesterov momentum algorithm) optimizer in Pytorch.
6 |
7 | Explanation from Davis Blalock
8 |
9 | Official Adan code
10 |
11 | ## Install
12 |
13 | ```bash
14 | $ pip install adan-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | ```python
20 | from adan_pytorch import Adan
21 |
22 | # mock model
23 |
24 | import torch
25 | from torch import nn
26 |
27 | model = torch.nn.Sequential(
28 | nn.Linear(16, 16),
29 | nn.GELU()
30 | )
31 |
32 | # instantiate Adan with model parameters
33 |
34 | optim = Adan(
35 | model.parameters(),
36 | lr = 1e-3, # learning rate (can be much higher than Adam, up to 5-10x)
37 |     betas = (0.02, 0.08, 0.01), # betas 1, 2 and 3 as described in the paper - the author reports results are most sensitive to beta3 tuning
38 |     weight_decay = 0.02         # the author reports a weight decay of 0.02 as optimal
39 | )
40 |
41 | # train
42 |
43 | for _ in range(10):
44 | loss = model(torch.randn(16)).sum()
45 | loss.backward()
46 | optim.step()
47 | optim.zero_grad()
48 |
49 | ```
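50 | 
51 | The optimizer also accepts an optional `restart_cond` - a callable checked every step against the per-parameter state dict (which holds `step`, `m`, `v`, `n` and `prev_grad`); when it returns `True`, the momentum statistics are reset from the current gradient (see `adan_pytorch/adan.py`). A minimal sketch, reusing the `model` from above and a hypothetical schedule that restarts every 1000 steps:
52 | 
53 | ```python
54 | from adan_pytorch import Adan
55 | 
56 | # hypothetical restart schedule - reset the running averages every 1000 steps
57 | def restart_every_1000(state):
58 |     return state['step'] > 0 and state['step'] % 1000 == 0
59 | 
60 | optim = Adan(
61 |     model.parameters(),
62 |     lr = 1e-3,
63 |     restart_cond = restart_every_1000
64 | )
65 | ```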
66 | 
67 | ## Citations
68 | 
69 | ```bibtex
70 | @article{Xie2022AdanAN,
71 |     title   = {Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models},
72 |     author  = {Xingyu Xie and Pan Zhou and Huan Li and Zhouchen Lin and Shuicheng Yan},
73 |     journal = {ArXiv},
74 |     year    = {2022},
75 |     volume  = {abs/2208.06677}
76 | }
77 | ```
78 | 
--------------------------------------------------------------------------------
/adan-pseudocode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/Adan-pytorch/77328b32395cebb42d8d1d0bd7c3dd9bb8f61b8a/adan-pseudocode.png
--------------------------------------------------------------------------------
/adan_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from adan_pytorch.adan import Adan
2 |
--------------------------------------------------------------------------------
/adan_pytorch/adan.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim import Optimizer
4 |
5 | def exists(val):
6 | return val is not None
7 |
8 | class Adan(Optimizer):
9 | def __init__(
10 | self,
11 | params,
12 | lr = 1e-3,
13 | betas = (0.02, 0.08, 0.01),
14 | eps = 1e-8,
15 | weight_decay = 0,
16 | restart_cond: callable = None
17 | ):
18 | assert len(betas) == 3
19 |
20 | defaults = dict(
21 | lr = lr,
22 | betas = betas,
23 | eps = eps,
24 | weight_decay = weight_decay,
25 | restart_cond = restart_cond
26 | )
27 |
28 | super().__init__(params, defaults)
29 |
30 | def step(self, closure = None):
31 | loss = None
32 |
33 | if exists(closure):
34 | loss = closure()
35 |
36 | for group in self.param_groups:
37 |
38 | lr = group['lr']
39 | beta1, beta2, beta3 = group['betas']
40 | weight_decay = group['weight_decay']
41 | eps = group['eps']
42 | restart_cond = group['restart_cond']
43 |
44 | for p in group['params']:
45 | if not exists(p.grad):
46 | continue
47 |
48 | data, grad = p.data, p.grad.data
49 | assert not grad.is_sparse
50 |
51 | state = self.state[p]
52 |
53 | if len(state) == 0:
54 | state['step'] = 0
55 | state['prev_grad'] = torch.zeros_like(grad)
56 | state['m'] = torch.zeros_like(grad)
57 | state['v'] = torch.zeros_like(grad)
58 | state['n'] = torch.zeros_like(grad)
59 |
60 | step, m, v, n, prev_grad = state['step'], state['m'], state['v'], state['n'], state['prev_grad']
61 |
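62 |                 # adan update rule - note that in this file each beta is the
63 |                 # EMA update rate (adam-style conventions would use 1 - beta):
64 |                 #   m <- (1 - beta1) * m + beta1 * g
65 |                 #   v <- (1 - beta2) * v + beta2 * (g - g_prev)
66 |                 #   n <- (1 - beta3) * n + beta3 * (g + (1 - beta2) * (g - g_prev))^2
67 |                 #   p <- (p - lr / (sqrt(n_hat) + eps) * (m_hat + (1 - beta2) * v_hat)) / (1 + weight_decay * lr)
68 |                 # where g_prev is the previous gradient and hats denote the bias-corrected estimates below
69 | 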
70 |                 # main algorithm
71 | 
72 |                 m.mul_(1 - beta1).add_(grad, alpha = beta1)
73 | 
74 |                 grad_diff = grad - prev_grad
75 | 
76 |                 v.mul_(1 - beta2).add_(grad_diff, alpha = beta2)
77 | 
78 |                 next_n = (grad + (1 - beta2) * grad_diff) ** 2
79 | 
80 |                 n.mul_(1 - beta3).add_(next_n, alpha = beta3)
81 | 
82 |                 # bias correction terms
83 | 
84 |                 step += 1
85 | 
86 |                 correct_m, correct_v, correct_n = map(lambda beta: 1 / (1 - (1 - beta) ** step), (beta1, beta2, beta3))
87 | 
88 |                 # gradient step, with decoupled weight decay applied as a division
89 | 
90 |                 def grad_step_(data, m, v, n):
91 |                     weighted_step_size = lr / (n * correct_n).sqrt().add_(eps)
92 | 
93 |                     denom = 1 + weight_decay * lr
94 | 
95 |                     data.addcmul_(weighted_step_size, (m * correct_m + (1 - beta2) * v * correct_v), value = -1.).div_(denom)
96 | 
97 |                 grad_step_(data, m, v, n)
98 | 
99 |                 # restart condition
100 | 
101 |                 if exists(restart_cond) and restart_cond(state):
102 |                     m.copy_(grad)
103 |                     v.zero_()
104 |                     n.copy_(grad ** 2)
105 | 
106 |                     grad_step_(data, m, v, n)
107 | 
108 |                 # store the current gradient and the incremented step for the next iteration
109 | 
110 |                 prev_grad.copy_(grad)
111 |                 state['step'] = step
112 | 
113 |         return loss
114 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'adan-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.1.0',
7 | license='MIT',
8 | description = 'Adan - (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/Adan-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'optimizer',
17 | ],
18 | install_requires=[
19 | 'torch>=1.6',
20 | ],
21 | classifiers=[
22 | 'Development Status :: 4 - Beta',
23 | 'Intended Audience :: Developers',
24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Programming Language :: Python :: 3.6',
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------