├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── adam-atan2.png
├── adam_atan2_pytorch
│   ├── __init__.py
│   ├── adam_atan2.py
│   ├── adam_atan2_with_wasserstein_reg.py
│   ├── adopt.py
│   ├── adopt_atan2.py
│   └── foreach.py
└── pyproject.toml
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Adam-atan2 - Pytorch
4 |
5 | Implementation of the proposed Adam-atan2 optimizer in Pytorch
6 |
7 | A multi-million dollar paper out of Google DeepMind proposes a small change to the Adam update rule (using `atan2`) to remove the epsilon altogether, gaining numerical stability and scale invariance.
8 |
9 | This repository also includes a few features for improving plasticity, drawn from the continual learning literature.
10 |
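In short, the usual `exp_avg / (sqrt(exp_avg_sq) + eps)` step becomes `a * atan2(exp_avg, b * sqrt(exp_avg_sq))`, which is bounded and never divides by zero. Below is a minimal sketch of the two updates on toy tensors (illustrative values only; `a` and `b` follow the defaults used in this repository):

```python
import torch

# toy bias-corrected first and second moment estimates
m = torch.randn(5)
v = torch.rand(5)

a, b = 1.27, 1.  # scaling constants, matching the defaults in this repository

# classic Adam: divide by the root of the second moment, with an epsilon to avoid division by zero
update_adam = m / (v.sqrt() + 1e-8)

# adam-atan2: replace the division with atan2 - bounded, epsilon-free, scale invariant
update_atan2 = a * torch.atan2(m, b * v.sqrt())
```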
11 | ## Install
12 |
13 | ```bash
14 | $ pip install adam-atan2-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | ```python
20 | import torch
21 | from torch import nn
22 |
23 | # toy model
24 |
25 | model = nn.Linear(10, 1)
26 |
27 | # import AdamAtan2 and instantiate with parameters
28 |
29 | from adam_atan2_pytorch import AdamAtan2
30 |
31 | opt = AdamAtan2(model.parameters(), lr = 1e-4)
32 |
33 | # forward and backwards
34 |
35 | for _ in range(100):
36 | loss = model(torch.randn(10))
37 | loss.backward()
38 |
39 | # optimizer step
40 |
41 | opt.step()
42 | opt.zero_grad()
43 | ```
44 |
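Besides `AdamAtan2`, the package exports `AdoptAtan2` (the ADOPT variant with the same atan2 update). Both optimizers expose plasticity-related knobs: `regen_reg_rate` for regenerative regularization and `cautious_factor` for cautious updates. A brief sketch with purely illustrative values:

```python
from torch import nn

from adam_atan2_pytorch import AdoptAtan2

model = nn.Linear(10, 1)

# ADOPT + atan2, with regenerative regularization and cautious updates enabled
opt = AdoptAtan2(
    model.parameters(),
    lr = 1e-4,
    regen_reg_rate = 1e-3,  # gently pulls weights back towards their initialization
    cautious_factor = 0.    # zeroes out update components not aligned with the gradient
)
```

Setting `cautious_factor = 0.` recovers the cautious optimizer behavior cited below; leaving it at the default `1.` disables it.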
45 | ## Citations
46 |
47 | ```bibtex
48 | @inproceedings{Everett2024ScalingEA,
49 | title = {Scaling Exponents Across Parameterizations and Optimizers},
50 | author = {Katie Everett and Lechao Xiao and Mitchell Wortsman and Alex Alemi and Roman Novak and Peter J. Liu and Izzeddin Gur and Jascha Narain Sohl-Dickstein and Leslie Pack Kaelbling and Jaehoon Lee and Jeffrey Pennington},
51 | year = {2024},
52 | url = {https://api.semanticscholar.org/CorpusID:271051056}
53 | }
54 | ```
55 |
56 | ```bibtex
57 | @inproceedings{Kumar2023MaintainingPI,
58 | title = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
59 | author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
60 | year = {2023},
61 | url = {https://api.semanticscholar.org/CorpusID:261076021}
62 | }
63 | ```
64 |
65 | ```bibtex
66 | @article{Lewandowski2024LearningCB,
67 | title = {Learning Continually by Spectral Regularization},
68 | author = {Alex Lewandowski and Saurabh Kumar and Dale Schuurmans and András György and Marlos C. Machado},
69 | journal = {ArXiv},
70 | year = {2024},
71 | volume = {abs/2406.06811},
72 | url = {https://api.semanticscholar.org/CorpusID:270380086}
73 | }
74 | ```
75 |
76 | ```bibtex
77 | @inproceedings{Taniguchi2024ADOPTMA,
78 | title = {ADOPT: Modified Adam Can Converge with Any $\beta_2$ with the Optimal Rate},
79 | author = {Shohei Taniguchi and Keno Harada and Gouki Minegishi and Yuta Oshima and Seong Cheol Jeong and Go Nagahara and Tomoshi Iiyama and Masahiro Suzuki and Yusuke Iwasawa and Yutaka Matsuo},
80 | year = {2024},
81 | url = {https://api.semanticscholar.org/CorpusID:273822148}
82 | }
83 | ```
84 |
85 | ```bibtex
86 | @inproceedings{Liang2024CautiousOI,
87 | title = {Cautious Optimizers: Improving Training with One Line of Code},
88 | author = {Kaizhao Liang and Lizhang Chen and Bo Liu and Qiang Liu},
89 | year = {2024},
90 | url = {https://api.semanticscholar.org/CorpusID:274234738}
91 | }
92 | ```
93 |
--------------------------------------------------------------------------------
/adam-atan2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/adam-atan2-pytorch/8f14cf50a89030edbdea5214d22ac268511b400e/adam-atan2.png
--------------------------------------------------------------------------------
/adam_atan2_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from adam_atan2_pytorch.adam_atan2 import AdamAtan2
2 | from adam_atan2_pytorch.adopt_atan2 import AdoptAtan2
3 |
4 | Adam = AdamAtan2
5 | Adopt = AdoptAtan2
6 |
--------------------------------------------------------------------------------
/adam_atan2_pytorch/adam_atan2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 |
4 | import torch
5 | from torch import atan2, sqrt
6 | from torch.optim.optimizer import Optimizer
7 |
8 | # functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | # class
14 |
15 | class AdamAtan2(Optimizer):
16 | def __init__(
17 | self,
18 | params,
19 | lr = 1e-4,
20 | betas: tuple[float, float] = (0.9, 0.99),
21 | weight_decay = 0.,
22 | regen_reg_rate = 0.,
23 | decoupled_wd = False,
24 | cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085
25 | a = 1.27,
26 | b = 1.
27 | ):
28 | assert lr > 0.
29 | assert all([0. <= beta <= 1. for beta in betas])
30 | assert weight_decay >= 0.
31 | assert regen_reg_rate >= 0.
32 | assert not (weight_decay > 0. and regen_reg_rate > 0.)
33 | assert 0. <= cautious_factor <= 1.
34 |
35 | self._init_lr = lr
36 | self.decoupled_wd = decoupled_wd
37 |
38 | defaults = dict(
39 | lr = lr,
40 | betas = betas,
41 | a = a,
42 | b = b,
43 | weight_decay = weight_decay,
44 | regen_reg_rate = regen_reg_rate,
45 | cautious_factor = cautious_factor
46 | )
47 |
48 | super().__init__(params, defaults)
49 |
50 | @torch.no_grad()
51 | def step(
52 | self,
53 | closure: Callable | None = None
54 | ):
55 |
56 | loss = None
57 | if exists(closure):
58 | with torch.enable_grad():
59 | loss = closure()
60 |
61 | for group in self.param_groups:
62 | for p in filter(lambda p: exists(p.grad), group['params']):
63 |
64 | grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
65 |
66 | # maybe decoupled weight decay
67 |
68 | if self.decoupled_wd:
69 | wd /= init_lr
70 |
71 | # weight decay
72 |
73 | if wd > 0.:
74 | p.mul_(1. - lr * wd)
75 |
76 | # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
77 |
78 | if regen_rate > 0. and 'param_init' in state:
79 | param_init = state['param_init']
80 | p.lerp_(param_init, lr / init_lr * regen_rate)
81 |
82 | # init state if needed
83 |
84 | if len(state) == 0:
85 | state['steps'] = 0
86 | state['exp_avg'] = torch.zeros_like(grad)
87 | state['exp_avg_sq'] = torch.zeros_like(grad)
88 |
89 | if regen_rate > 0.:
90 | state['param_init'] = p.clone()
91 |
92 | # get some of the states
93 |
94 | exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
95 |
96 | steps += 1
97 |
98 | # bias corrections
99 |
100 | bias_correct1 = 1. - beta1 ** steps
101 | bias_correct2 = 1. - beta2 ** steps
102 |
103 | # decay running averages
104 |
105 | exp_avg.lerp_(grad, 1. - beta1)
106 | exp_avg_sq.lerp_(grad * grad, 1. - beta2)
107 |
108 | # the following line is the proposed change to the update rule
109 | # using atan2 instead of a division with epsilon in denominator
110 | # a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))
111 |
112 | den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
113 | update = exp_avg.mul(1. / bias_correct1).atan2_(den)
114 |
115 | # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085
116 |
117 | if cautious_factor < 1.:
118 | align_mask = (update * grad) > 0
119 | scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
120 | update *= (scale / scale.mean().clamp(min = 1e-5))
121 |
122 | # update parameters
123 |
124 | p.add_(update, alpha = -lr * a)
125 |
126 | # increment steps
127 |
128 | state['steps'] = steps
129 |
130 | return loss
131 |
--------------------------------------------------------------------------------
/adam_atan2_pytorch/adam_atan2_with_wasserstein_reg.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 |
4 | import torch
5 | from torch import atan2, sqrt
6 | from torch.optim.optimizer import Optimizer
7 |
8 | # functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | # class
14 |
15 | class AdamAtan2(Optimizer):
16 | def __init__(
17 | self,
18 | params,
19 | lr = 1e-4,
20 | betas: tuple[float, float] = (0.9, 0.99),
21 | weight_decay = 0.,
22 | regen_reg_rate = 0.,
23 | decoupled_wd = False,
24 | a = 1.27,
25 | b = 1.
26 | ):
27 | assert lr > 0.
28 | assert all([0. <= beta <= 1. for beta in betas])
29 | assert weight_decay >= 0.
30 | assert regen_reg_rate >= 0.
31 | assert not (weight_decay > 0. and regen_reg_rate > 0.)
32 |
33 | self._init_lr = lr
34 | self.decoupled_wd = decoupled_wd
35 |
36 | defaults = dict(
37 | lr = lr,
38 | betas = betas,
39 | a = a,
40 | b = b,
41 | weight_decay = weight_decay,
42 | regen_reg_rate = regen_reg_rate,
43 | )
44 |
45 | super().__init__(params, defaults)
46 |
47 | @torch.no_grad()
48 | def step(
49 | self,
50 | closure: Callable | None = None
51 | ):
52 |
53 | loss = None
54 | if exists(closure):
55 | with torch.enable_grad():
56 | loss = closure()
57 |
58 | for group in self.param_groups:
59 | for p in filter(lambda p: exists(p.grad), group['params']):
60 |
61 | grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
62 |
63 | # maybe decoupled weight decay
64 |
65 | if self.decoupled_wd:
66 | wd /= init_lr
67 |
68 | # weight decay
69 |
70 | if wd > 0.:
71 | p.mul_(1. - lr * wd)
72 |
73 | # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
74 |
75 | if regen_rate > 0. and 'param_init' in state:
76 | param_init = state['param_init']
77 |
78 | shape = param_init.shape
79 |
80 |                     # the 1d wasserstein / optimal transport coupling matches values by rank (order statistics), so each weight is pulled towards the stored initial value of the same rank
81 |
82 | indices = p.flatten().sort(dim = -1).indices
83 | indices = indices.argsort(dim = -1)
84 |
85 | target = param_init.flatten()[indices]
86 | target = target.reshape(shape)
87 |
88 | p.lerp_(target, lr / init_lr * regen_rate)
89 |
90 | # init state if needed
91 |
92 | if len(state) == 0:
93 | state['steps'] = 0
94 | state['exp_avg'] = torch.zeros_like(grad)
95 | state['exp_avg_sq'] = torch.zeros_like(grad)
96 |
97 | if regen_rate > 0.:
98 |
99 | # wasserstein reg - https://arxiv.org/abs/2406.06811v1
100 | # initial parameters sorted for efficiency
101 |
102 | shape = p.shape
103 |                         sorted_values = p.flatten().sort(dim = -1).values
104 |                         # keep `p` bound to the actual parameter, so the in-place update below still applies to it
105 |
106 |                         state['param_init'] = sorted_values.reshape(shape).clone()
107 |
108 | # get some of the states
109 |
110 | exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
111 |
112 | steps += 1
113 |
114 | # bias corrections
115 |
116 | bias_correct1 = 1. - beta1 ** steps
117 | bias_correct2 = 1. - beta2 ** steps
118 |
119 | # decay running averages
120 |
121 | exp_avg.lerp_(grad, 1. - beta1)
122 | exp_avg_sq.lerp_(grad * grad, 1. - beta2)
123 |
124 | # the following line is the proposed change to the update rule
125 | # using atan2 instead of a division with epsilon in denominator
126 | # a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))
127 |
128 | den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
129 | update = exp_avg.mul(1. / bias_correct1).atan2_(den)
130 |
131 | # update parameters
132 |
133 | p.add_(update, alpha = -lr * a)
134 |
135 | # increment steps
136 |
137 | state['steps'] = steps
138 |
139 | return loss
140 |
141 | Adam = AdamAtan2
--------------------------------------------------------------------------------
/adam_atan2_pytorch/adopt.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 |
4 | import torch
5 | from torch import atan2, sqrt
6 | from torch.optim.optimizer import Optimizer
7 |
8 | # functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | # class
14 |
15 | class Adopt(Optimizer):
16 | """
17 |     the proposed Adam substitute from the University of Tokyo
18 |
19 | Algorithm 3 in https://arxiv.org/abs/2411.02853
20 | """
21 |
22 | def __init__(
23 | self,
24 | params,
25 | lr = 1e-4,
26 | betas: tuple[float, float] = (0.9, 0.99),
27 | eps = 1e-6,
28 | weight_decay = 0.,
29 | decoupled_wd = True
30 | ):
31 | assert lr > 0.
32 | assert all([0. <= beta <= 1. for beta in betas])
33 | assert weight_decay >= 0.
34 |
35 | self._init_lr = lr
36 | self.decoupled_wd = decoupled_wd
37 |
38 | defaults = dict(
39 | lr = lr,
40 | betas = betas,
41 | eps = eps,
42 | weight_decay = weight_decay,
43 | )
44 |
45 | super().__init__(params, defaults)
46 |
47 | @torch.no_grad()
48 | def step(
49 | self,
50 | closure: Callable | None = None
51 | ):
52 |
53 | loss = None
54 | if exists(closure):
55 | with torch.enable_grad():
56 | loss = closure()
57 |
58 | for group in self.param_groups:
59 | for p in filter(lambda p: exists(p.grad), group['params']):
60 |
61 | grad, lr, wd, beta1, beta2, eps, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['eps'], self.state[p], self._init_lr
62 |
63 | # maybe decoupled weight decay
64 |
65 | if self.decoupled_wd:
66 | wd /= init_lr
67 |
68 | # weight decay
69 |
70 | if wd > 0.:
71 | p.mul_(1. - lr * wd)
72 |
73 | # init state if needed
74 |
75 | if len(state) == 0:
76 | state['steps'] = 0
77 | state['m'] = torch.zeros_like(grad)
78 | state['v'] = grad * grad
79 |
80 | # get some of the states
81 |
82 | m, v, steps = state['m'], state['v'], state['steps']
83 |
84 |                 # on the first step, v has just been initialized to grad * grad - skip the parameter update, as in Algorithm 3
85 |
86 | if steps == 0:
87 | state['steps'] += 1
88 | continue
89 |
90 | # calculate m
91 |
92 | grad_sq = grad * grad
93 |
94 | update = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon
95 |
96 | # clip with t ^ 0.25 as in Algorithm 3
97 |
98 | clip_value = steps ** 0.25
99 | update.clamp_(min = -clip_value, max = clip_value)
100 |
101 | # update m
102 |
103 | m.lerp_(update, 1. - beta1)
104 |
105 | # then update parameters
106 |
107 | p.add_(m, alpha = -lr)
108 |
109 | # update exp grad sq (v)
110 |
111 | v.lerp_(grad_sq, 1. - beta2)
112 |
113 | # increment steps
114 |
115 | state['steps'] += 1
116 |
117 | return loss
118 |
--------------------------------------------------------------------------------
/adam_atan2_pytorch/adopt_atan2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 |
4 | import torch
5 | from torch import atan2, sqrt
6 | from torch.optim.optimizer import Optimizer
7 |
8 | # functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | # class
14 |
15 | class AdoptAtan2(Optimizer):
16 | """
17 |     the proposed Adam substitute from the University of Tokyo,
18 |     combined with Google's proposed atan2 method for doing away with the eps
19 |
20 | Algorithm 3 in https://arxiv.org/abs/2411.02853
21 | """
22 |
23 | def __init__(
24 | self,
25 | params,
26 | lr = 1e-4,
27 | betas: tuple[float, float] = (0.9, 0.99),
28 | weight_decay = 0.,
29 | regen_reg_rate = 0.,
30 | decoupled_wd = True,
31 | cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085
32 | a = 1.27,
33 | b = 1.
34 | ):
35 | assert lr > 0.
36 | assert all([0. <= beta <= 1. for beta in betas])
37 | assert weight_decay >= 0.
38 | assert not (weight_decay > 0. and regen_reg_rate > 0.)
39 |
40 | self._init_lr = lr
41 | self.decoupled_wd = decoupled_wd
42 |
43 | defaults = dict(
44 | lr = lr,
45 | betas = betas,
46 | a = a,
47 | b = b,
48 | weight_decay = weight_decay,
49 | regen_reg_rate = regen_reg_rate,
50 | cautious_factor = cautious_factor
51 | )
52 |
53 | super().__init__(params, defaults)
54 |
55 | @torch.no_grad()
56 | def step(
57 | self,
58 | closure: Callable | None = None
59 | ):
60 |
61 | loss = None
62 | if exists(closure):
63 | with torch.enable_grad():
64 | loss = closure()
65 |
66 | for group in self.param_groups:
67 | for p in filter(lambda p: exists(p.grad), group['params']):
68 |
69 | grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
70 |
71 | # maybe decoupled weight decay
72 |
73 | if self.decoupled_wd:
74 | wd /= init_lr
75 |
76 | # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
77 |
78 | if regen_rate > 0. and 'param_init' in state:
79 | param_init = state['param_init']
80 | p.lerp_(param_init, lr / init_lr * regen_rate)
81 |
82 | # weight decay
83 |
84 | if wd > 0.:
85 | p.mul_(1. - lr * wd)
86 |
87 | # init state if needed
88 |
89 | if len(state) == 0:
90 | state['steps'] = 0
91 | state['m'] = torch.zeros_like(grad)
92 | state['v'] = grad * grad
93 |
94 | if regen_rate > 0.:
95 | state['param_init'] = p.clone()
96 |
97 | # get some of the states
98 |
99 | m, v, steps = state['m'], state['v'], state['steps']
100 |
101 |                 # on the first step, v has just been initialized to grad * grad - skip the parameter update, as in Algorithm 3
102 |
103 | if steps == 0:
104 | state['steps'] += 1
105 | continue
106 |
107 | # calculate m
108 |
109 | grad_sq = grad * grad
110 |
111 | update = grad.atan2(b * v.sqrt())
112 |
113 | m.lerp_(update, 1. - beta1)
114 |
115 | # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085
116 |
117 | scale = 1.
118 |
119 | if cautious_factor < 1.:
120 | align_mask = (update * grad) > 0
121 | scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
122 | scale /= scale.mean().clamp(min = 1e-5)
123 |
124 | # then update parameters
125 |
126 | p.add_(m * scale, alpha = -lr * a)
127 |
128 | # update exp grad sq (v)
129 |
130 | v.lerp_(grad_sq, 1. - beta2)
131 |
132 | # increment steps
133 |
134 | state['steps'] += 1
135 |
136 | return loss
137 |
--------------------------------------------------------------------------------
/adam_atan2_pytorch/foreach.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 |
4 | import torch
5 | from torch import atan2, sqrt, Tensor
6 | from torch.optim.optimizer import Optimizer
7 |
8 | # functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | def default(*args):
14 | for arg in args:
15 | if exists(arg):
16 | return arg
17 | return None
18 |
19 | # slow foreach atan2
20 |
21 | def slow_foreach_atan2_(nums: list[Tensor], dens: list[Tensor]):
22 |     for num, den in zip(nums, dens):
23 | num.atan2_(den)
24 |
25 | # class
26 |
27 | class AdamAtan2(Optimizer):
28 | def __init__(
29 | self,
30 | params,
31 | lr = 1e-4,
32 | betas: tuple[float, float] = (0.9, 0.99),
33 | weight_decay = 0.,
34 | regen_reg_rate = 0.,
35 | decoupled_wd = False,
36 | a = 1.27,
37 | b = 1.,
38 | foreach_atan2_fn: Callable | None = None
39 | ):
40 | assert lr > 0.
41 | assert all([0. <= beta <= 1. for beta in betas])
42 | assert weight_decay >= 0.
43 | assert regen_reg_rate >= 0.
44 | assert not (weight_decay > 0. and regen_reg_rate > 0.)
45 | assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'lerp', 'sqrt')]), 'this version of torch does not have the prerequisite foreach functions'
46 |
47 | self._init_lr = lr
48 | self.decoupled_wd = decoupled_wd
49 |
50 | self._foreach_atan2_ = default(
51 | foreach_atan2_fn,
52 | getattr(torch, '_foreach_atan2_', None),
53 | slow_foreach_atan2_
54 | )
55 |
56 | defaults = dict(
57 | lr = lr,
58 | betas = betas,
59 | a = a,
60 | b = b,
61 | weight_decay = weight_decay,
62 | regen_reg_rate = regen_reg_rate
63 | )
64 |
65 | super().__init__(params, defaults)
66 |
67 | @torch.no_grad()
68 | def step(
69 | self,
70 | closure: Callable | None = None
71 | ):
72 | init_lr = self._init_lr
73 |
74 | loss = None
75 | if exists(closure):
76 | with torch.enable_grad():
77 | loss = closure()
78 |
79 | for group in self.param_groups:
80 |
81 | wd, regen_rate, lr, beta1, beta2, a, b = group['weight_decay'], group['regen_reg_rate'], group['lr'], *group['betas'], group['a'], group['b']
82 |
83 | has_weight_decay = wd > 0
84 |
85 | has_regenerative_reg = regen_rate > 0
86 |
87 | # accumulate List[Tensor] for foreach inplace updates
88 |
89 | params = []
90 | params_init = []
91 | grads = []
92 | grad_squared = []
93 | exp_avgs = []
94 | exp_avg_sqs = []
95 |
96 | for p in filter(lambda p: exists(p.grad), group['params']):
97 |
98 | grad, state = p.grad, self.state[p]
99 |
100 |                 # maybe decoupled weight decay - re-derive from the group each time so `wd` is not divided repeatedly across parameters
101 |
102 |                 if self.decoupled_wd and has_weight_decay:
103 |                     wd = group['weight_decay'] / init_lr
104 |
105 | # init state if needed
106 |
107 | if len(state) == 0:
108 | state['steps'] = 0
109 | state['exp_avg'] = torch.zeros_like(grad)
110 | state['exp_avg_sq'] = torch.zeros_like(grad)
111 | state['param_init'] = p.clone()
112 |
113 | # get some of the states
114 |
115 | exp_avg, exp_avg_sq, param_init, steps = state['exp_avg'], state['exp_avg_sq'], state['param_init'], state['steps']
116 |
117 | steps += 1
118 |
119 | # bias corrections
120 |
121 | bias_correct1 = 1. - beta1 ** steps
122 | bias_correct2 = 1. - beta2 ** steps
123 |
124 | # append to list
125 |
126 | params.append(p)
127 | params_init.append(param_init)
128 | grads.append(grad)
129 | grad_squared.append(grad * grad)
130 | exp_avgs.append(exp_avg)
131 | exp_avg_sqs.append(exp_avg_sq)
132 |
133 | # update steps
134 |
135 | state['steps'] = steps
136 |
137 | # weight decay
138 |
139 | if has_weight_decay:
140 | torch._foreach_mul_(params, 1. - lr * wd)
141 |
142 | # regenerative regularization
143 |
144 | if has_regenerative_reg:
145 | torch._foreach_lerp_(params, params_init, lr / init_lr * regen_rate)
146 |
147 | # decay running averages
148 |
149 | torch._foreach_lerp_(exp_avgs, grads, 1. - beta1)
150 | torch._foreach_lerp_(exp_avg_sqs, grad_squared, 1. - beta2)
151 |
152 | # clone for update
153 |
154 | updates = [t.clone() for t in exp_avgs]
155 | den = [t.clone() for t in exp_avg_sqs]
156 |
157 | # calculate update atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))
158 |
159 | torch._foreach_mul_(updates, 1. / bias_correct1)
160 |
161 | torch._foreach_mul_(den, b * b / bias_correct2)
162 | torch._foreach_sqrt_(den)
163 |
164 | self._foreach_atan2_(updates, den)
165 |
166 | # update params
167 |
168 | torch._foreach_add_(params, updates, alpha = -lr * a)
169 |
170 | return loss
171 |
172 | Adam = AdamAtan2
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "adam-atan2-pytorch"
3 | version = "0.1.18"
4 | description = "Adam-atan2 for Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.9"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'adam',
15 | 'optimizers'
16 | ]
17 |
18 | classifiers=[
19 | 'Development Status :: 4 - Beta',
20 | 'Intended Audience :: Developers',
21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
22 | 'License :: OSI Approved :: MIT License',
23 | 'Programming Language :: Python :: 3.9',
24 | ]
25 |
26 | dependencies = [
27 | "torch>=2.0",
28 | ]
29 |
30 | [project.urls]
31 | Homepage = "https://pypi.org/project/adam_atan2_pytorch/"
32 | Repository = "https://github.com/lucidrains/adam_atan2_pytorch"
33 |
34 | [project.optional-dependencies]
35 | examples = []
36 | test = [
37 | "pytest"
38 | ]
39 |
40 | [tool.pytest.ini_options]
41 | pythonpath = [
42 | "."
43 | ]
44 |
45 | [build-system]
46 | requires = ["hatchling"]
47 | build-backend = "hatchling.build"
48 |
49 | [tool.rye]
50 | managed = true
51 | dev-dependencies = []
52 |
53 | [tool.hatch.metadata]
54 | allow-direct-references = true
55 |
56 | [tool.hatch.build.targets.wheel]
57 | packages = ["adam_atan2_pytorch"]
58 |
--------------------------------------------------------------------------------