├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── lion.png
├── lion_pytorch
│   ├── __init__.py
│   ├── cautious_lion.py
│   ├── foreach.py
│   ├── lion_pytorch.py
│   └── triton.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## 🦁 Lion - Pytorch
4 |
5 | 🦁 Lion, Evo**L**ved S**i**gn M**o**me**n**tum, a new optimizer discovered by Google Brain that is purportedly better than Adam(w), in Pytorch. This is nearly a straight copy from here, with a few minor modifications.
6 |
7 | It is so simple that we may as well make it accessible and put it to use asap, so everyone can train some great models, if it really works 🤞
8 |
9 | ### Instructions
10 | - Learning rate and weight decay: the authors write in Section 5 - `Based on our experience, a suitable learning rate for Lion is typically 3-10x smaller than that for AdamW. Since the effective weight decay is lr * λ, the value of decoupled weight decay λ used for Lion is 3-10x larger than that for AdamW in order to maintain a similar strength.` The initial value, peak value, and end value in the learning rate schedule should be changed ***simultaneously*** with the same ratio compared to AdamW, [evidenced by a researcher](https://github.com/lucidrains/lion-pytorch/discussions/1#discussioncomment-5239900). A small conversion sketch follows this list.
11 |
12 | - Learning rate schedule: the authors use the same learning rate schedule for Lion as AdamW in the paper. Nevertheless, they observe a larger gain when using a cosine decay schedule to train ViT, compared to a reciprocal square-root schedule.
13 |
14 | - β1 and β2: the authors write in Section 5 - `The default values for β1 and β2 in AdamW are set as 0.9 and 0.999, respectively, with an ε of 1e−8, while in Lion, the default values for β1 and β2 are discovered through the program search process and set as 0.9 and 0.99, respectively.` Similar to how people reduce β2 to 0.99 or smaller and increase ε to 1e-6 in AdamW to improve stability, using `β1=0.95, β2=0.98` in Lion can also be helpful in mitigating instability during training, as suggested by the authors. This was corroborated by a researcher.
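
A minimal sketch of this conversion, with illustrative AdamW numbers that are not from the paper:

```python
import torch
from torch import nn
from lion_pytorch import Lion

model = nn.Linear(10, 1)  # stand-in for your model

# hypothetical AdamW baseline: lr = 3e-4, weight decay = 0.1
adamw_lr, adamw_wd = 3e-4, 0.1

# 3x rule of thumb: shrink the learning rate and grow the decoupled weight decay,
# so the effective weight decay lr * wd stays roughly the same;
# remember to scale the initial / peak / end values of the lr schedule by the same ratio
lion_lr = adamw_lr / 3
lion_wd = adamw_wd * 3

opt = Lion(model.parameters(), lr = lion_lr, weight_decay = lion_wd, betas = (0.9, 0.99))
```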
15 |
16 | ### Updates
17 | - Update: seems to work for my local enwik8 autoregressive language modeling.
18 |
19 | - Update 2: in my experiments, it seems much worse than Adam if the learning rate is held constant.
20 |
21 | - Update 3: dividing the learning rate by 3, I am seeing better early results than with Adam. Maybe Adam has been dethroned, after nearly a decade.
22 |
23 | - Update 4: using the 10x smaller learning rate rule of thumb from the paper resulted in the worst run. So I guess it still takes a bit of tuning.
24 |
25 | A summary of the previous updates: as shown in the experiments above, Lion with a 3x smaller learning rate beats Adam. It still takes a bit of tuning, as a 10x smaller learning rate leads to a worse result.
26 |
27 | - Update 5: so far, hearing all positive results for language modeling when done right. Also heard positive results for significant text-to-image training, although it takes a bit of tuning. The negative results seem to come from problems and architectures outside of what was evaluated in the paper - RL, feedforward networks, weird hybrid architectures with LSTMs + convolutions, etc. Negative anecdata also suggests this technique is sensitive to batch size and the amount of data / augmentation. It is still to be determined what the optimal learning rate schedule is, and whether cooldown affects results. Also, interestingly, there was a positive result at open-clip, which became negative as the model size was scaled up (but this may be resolvable).
28 |
29 | - Update 6: open clip issue [resolved by the author](https://github.com/mlfoundations/open_clip/pull/432#issuecomment-1457323237), by setting a higher initial temperature.
30 |
31 | - Update 7: I would only recommend this optimizer in the setting of large batch sizes (64 or above).
32 |
33 | ## Install
34 |
35 | ```bash
36 | $ pip install lion-pytorch
37 | ```
38 | Alternatively, using conda:
39 | ```bash
40 | $ conda install lion-pytorch
41 | ```
42 |
43 | ## Usage
44 |
45 | ```python
46 | # toy model
47 |
48 | import torch
49 | from torch import nn
50 |
51 | model = nn.Linear(10, 1)
52 |
53 | # import Lion and instantiate with parameters
54 |
55 | from lion_pytorch import Lion
56 |
57 | opt = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
58 |
59 | # forward and backwards
60 |
61 | loss = model(torch.randn(10))
62 | loss.backward()
63 |
64 | # optimizer step
65 |
66 | opt.step()
67 | opt.zero_grad()
68 | ```
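
The Lion constructor in this repository (see `lion_pytorch/lion_pytorch.py` later in this listing) also exposes a `decoupled_weight_decay` flag, which simply divides the supplied weight decay by the initial learning rate (following the Schaipp post cited below). A minimal sketch:

```python
opt = Lion(
    model.parameters(),
    lr = 1e-4,
    weight_decay = 1e-2,
    decoupled_weight_decay = True  # internally divides weight_decay by the initial lr
)
```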
69 |
70 | To use a fused kernel for updating the parameters, first `pip install triton -U --pre`, then
71 |
72 | ```python
73 | opt = Lion(
74 | model.parameters(),
75 | lr=1e-4,
76 | weight_decay=1e-2,
77 | use_triton=True # set this to True to use cuda kernel w/ Triton lang (Tillet et al)
78 | )
79 | ```
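
Note that the Triton path runs a CUDA kernel: the wrapper in `lion_pytorch/triton.py` asserts that the parameters, gradients, and momentum buffers are all CUDA tensors, so keep the default `use_triton = False` for CPU-only runs.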
80 |
81 | ## Appreciation
82 |
83 | - Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
84 |
85 | ## Citations
86 |
87 | ```bibtex
88 | @misc{https://doi.org/10.48550/arxiv.2302.06675,
89 | url = {https://arxiv.org/abs/2302.06675},
90 | author = {Chen, Xiangning and Liang, Chen and Huang, Da and Real, Esteban and Wang, Kaiyuan and Liu, Yao and Pham, Hieu and Dong, Xuanyi and Luong, Thang and Hsieh, Cho-Jui and Lu, Yifeng and Le, Quoc V.},
91 | title = {Symbolic Discovery of Optimization Algorithms},
92 | publisher = {arXiv},
93 | year = {2023}
94 | }
95 | ```
96 |
97 | ```bibtex
98 | @article{Tillet2019TritonAI,
99 | title = {Triton: an intermediate language and compiler for tiled neural network computations},
100 | author = {Philippe Tillet and H. Kung and D. Cox},
101 | journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
102 | year = {2019}
103 | }
104 | ```
105 |
106 | ```bibtex
107 | @misc{Schaipp2024,
108 | author = {Fabian Schaipp},
109 | url = {https://fabian-sp.github.io/posts/2024/02/decoupling/}
110 | }
111 | ```
112 |
113 | ```bibtex
114 | @inproceedings{Liang2024CautiousOI,
115 | title = {Cautious Optimizers: Improving Training with One Line of Code},
116 | author = {Kaizhao Liang and Lizhang Chen and Bo Liu and Qiang Liu},
117 | year = {2024},
118 | url = {https://api.semanticscholar.org/CorpusID:274234738}
119 | }
120 | ```
121 |
--------------------------------------------------------------------------------
/lion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/lion-pytorch/6a74fdc0ba572ab5683dc0270c66c20ecbc02d09/lion.png
--------------------------------------------------------------------------------
/lion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from lion_pytorch.lion_pytorch import Lion
2 |
--------------------------------------------------------------------------------
/lion_pytorch/cautious_lion.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Tuple, Callable
3 |
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 |
7 | # functions
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | # class
13 |
14 | class Lion(Optimizer):
15 | def __init__(
16 | self,
17 | params,
18 | lr: float = 1e-4,
19 | betas: Tuple[float, float] = (0.9, 0.99),
20 | weight_decay: float = 0.0,
21 | cautious_factor: float = 0.,
22 | decoupled_weight_decay: bool = False,
23 | ):
24 | assert lr > 0.
25 | assert all([0. <= beta <= 1. for beta in betas])
26 | assert 0. <= cautious_factor <= 1.
27 |
28 | self._init_lr = lr
29 | self.decoupled_wd = decoupled_weight_decay
30 |
31 | defaults = dict(
32 | lr = lr,
33 | betas = betas,
34 | weight_decay = weight_decay,
35 | cautious_factor = cautious_factor
36 | )
37 |
38 | super().__init__(params, defaults)
39 |
40 | @torch.no_grad()
41 | def step(
42 | self,
43 | closure: Callable | None = None
44 | ):
45 |
46 | loss = None
47 | if exists(closure):
48 | with torch.enable_grad():
49 | loss = closure()
50 |
51 | for group in self.param_groups:
52 | for p in filter(lambda p: exists(p.grad), group['params']):
53 |
54 | grad, lr, wd, cautious_factor, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], group['cautious_factor'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr
55 |
56 | # maybe decoupled weight decay
57 |
58 | if decoupled_wd:
59 | wd /= init_lr
60 |
61 | # init state - exponential moving average of gradient values
62 |
63 | if len(state) == 0:
64 | state['exp_avg'] = torch.zeros_like(p)
65 |
66 | exp_avg = state['exp_avg']
67 |
68 | # stepweight decay
69 |
70 | p.data.mul_(1. - lr * wd)
71 |
72 | # weight update
73 |
74 | update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1. - beta1).sign_()
75 |
76 | # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085
77 |
78 | if cautious_factor < 1.:
79 | align_mask = (update * grad) > 0
80 | scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
81 | scale /= scale.mean().clamp(min = 1e-5)
82 | update.mul_(scale)
83 |
84 | # update params
85 |
86 | p.add_(update, alpha = -lr)
87 |
88 | # decay the momentum running average coefficient
89 |
90 | exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)
91 |
92 | return loss
93 |
--------------------------------------------------------------------------------
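
A brief usage sketch for this cautious variant. Assumption for illustration: it is imported from its module path directly, since `lion_pytorch/__init__.py` above only re-exports the base `Lion`.

```python
import torch
from torch import nn
from lion_pytorch.cautious_lion import Lion

model = nn.Linear(10, 1)

# cautious_factor < 1 downweights update entries whose sign disagrees with the current gradient;
# the default of 0 masks them out entirely (algorithm 2 of the cautious optimizer paper)
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, cautious_factor = 0.5)

loss = model(torch.randn(10)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```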
/lion_pytorch/foreach.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Tuple, Callable
3 |
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 |
7 | # functions
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | # class
13 |
14 | class Lion(Optimizer):
15 | def __init__(
16 | self,
17 | params,
18 | lr: float = 1e-4,
19 | betas: Tuple[float, float] = (0.9, 0.99),
20 | weight_decay: float = 0.0,
21 | decoupled_weight_decay: bool = False
22 | ):
23 | assert lr > 0.
24 | assert all([0. <= beta <= 1. for beta in betas])
25 | assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions'
26 |
27 | self._init_lr = lr
28 | self.decoupled_wd = decoupled_weight_decay
29 |
30 | defaults = dict(
31 | lr = lr,
32 | betas = betas,
33 | weight_decay = weight_decay
34 | )
35 |
36 | super().__init__(params, defaults)
37 |
38 | @torch.no_grad()
39 | def step(
40 | self,
41 | closure: Callable | None = None
42 | ):
43 |
44 | loss = None
45 | if exists(closure):
46 | with torch.enable_grad():
47 | loss = closure()
48 |
49 | for group in self.param_groups:
50 |
51 | lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr
52 |
53 | # maybe decoupled weight decay
54 |
55 | if decoupled_wd:
56 | wd /= init_lr
57 |
58 | # accumulate List[Tensor] for foreach inplace updates
59 |
60 | params = []
61 | grads = []
62 | exp_avgs = []
63 |
64 | for p in filter(lambda p: exists(p.grad), group['params']):
65 |
66 | grad, state = p.grad, self.state[p]
67 |
68 | # init state - exponential moving average of gradient values
69 |
70 | if len(state) == 0:
71 | state['exp_avg'] = torch.zeros_like(p)
72 |
73 | exp_avg = state['exp_avg']
74 |
75 | params.append(p)
76 | grads.append(grad)
77 | exp_avgs.append(exp_avg)
78 |
79 | # stepweight decay
80 |
81 | if wd > 0.:
82 | torch._foreach_mul_(params, 1. - lr * wd)
83 |
84 | # weight update
85 |
86 | updates = [t.clone() for t in exp_avgs]
87 | torch._foreach_lerp_(updates, grads, 1. - beta1)
88 | torch._foreach_sign_(updates)
89 |
90 | torch._foreach_add_(params, updates, alpha = -lr)
91 |
92 | # decay momentum running average
93 |
94 | torch._foreach_lerp_(exp_avgs, grads, 1. - beta2)
95 |
96 | return loss
97 |
--------------------------------------------------------------------------------
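
This variant batches the Lion update across every parameter in a group with `torch._foreach_*` ops. A brief usage sketch, assuming a direct import from the module path and a torch version that provides the `_foreach_*` functions the constructor checks for:

```python
import torch
from torch import nn
from lion_pytorch.foreach import Lion

model = nn.Linear(10, 1)
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2)

model(torch.randn(10)).sum().backward()
opt.step()  # one pass over the whole parameter list per group via torch._foreach_* ops
opt.zero_grad()
```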
/lion_pytorch/lion_pytorch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Tuple, Callable
3 |
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 |
7 | # functions
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | # update functions
13 |
14 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
15 | # stepweight decay
16 |
17 | p.data.mul_(1. - lr * wd)
18 |
19 | # weight update
20 |
21 | update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1. - beta1).sign_()
22 | p.add_(update, alpha = -lr)
23 |
24 | # decay the momentum running average coefficient
25 |
26 | exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)
27 |
28 | # class
29 |
30 | class Lion(Optimizer):
31 | def __init__(
32 | self,
33 | params,
34 | lr: float = 1e-4,
35 | betas: Tuple[float, float] = (0.9, 0.99),
36 | weight_decay: float = 0.0,
37 | use_triton: bool = False,
38 | decoupled_weight_decay: bool = False,
39 | ):
40 | assert lr > 0.
41 | assert all([0. <= beta <= 1. for beta in betas])
42 |
43 | self._init_lr = lr
44 | self.decoupled_wd = decoupled_weight_decay
45 |
46 | defaults = dict(
47 | lr = lr,
48 | betas = betas,
49 | weight_decay = weight_decay
50 | )
51 |
52 | super().__init__(params, defaults)
53 |
54 | self.update_fn = update_fn
55 |
56 | if use_triton:
57 | from lion_pytorch.triton import update_fn as triton_update_fn
58 | self.update_fn = triton_update_fn
59 |
60 | @torch.no_grad()
61 | def step(
62 | self,
63 | closure: Callable | None = None
64 | ):
65 |
66 | loss = None
67 | if exists(closure):
68 | with torch.enable_grad():
69 | loss = closure()
70 |
71 | for group in self.param_groups:
72 | for p in filter(lambda p: exists(p.grad), group['params']):
73 |
74 | grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr
75 |
76 | # maybe decoupled weight decay
77 |
78 | if decoupled_wd:
79 | wd /= init_lr
80 |
81 | # init state - exponential moving average of gradient values
82 |
83 | if len(state) == 0:
84 | state['exp_avg'] = torch.zeros_like(p)
85 |
86 | exp_avg = state['exp_avg']
87 |
88 | self.update_fn(
89 | p,
90 | grad,
91 | exp_avg,
92 | lr,
93 | wd,
94 | beta1,
95 | beta2
96 | )
97 |
98 | return loss
99 |
--------------------------------------------------------------------------------
/lion_pytorch/triton.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | try:
4 | import triton
5 | import triton.language as tl
6 | except ImportError as e:
7 |     print('triton is not installed, please install by running `pip install "triton>=2.2.0"`')
8 | exit()
9 |
10 | # triton cuda kernel
11 |
12 | @triton.autotune(configs = [
13 | triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
14 | triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
15 | ], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr'])
16 | @triton.jit
17 | def update_fn_kernel(
18 | p_ptr,
19 | grad_ptr,
20 | exp_avg_ptr,
21 | lr,
22 | wd,
23 | beta1,
24 | beta2,
25 | n_elements,
26 | BLOCK_SIZE: tl.constexpr,
27 | ):
28 | pid = tl.program_id(axis = 0)
29 |
30 | block_start = pid * BLOCK_SIZE
31 | offsets = block_start + tl.arange(0, BLOCK_SIZE)
32 |
33 | mask = offsets < n_elements
34 |
35 | # offsetted pointers
36 |
37 | offset_p_ptr = p_ptr + offsets
38 | offset_grad_ptr = grad_ptr + offsets
39 | offset_exp_avg_ptr = exp_avg_ptr + offsets
40 |
41 | # load
42 |
43 | p = tl.load(offset_p_ptr, mask = mask)
44 | grad = tl.load(offset_grad_ptr, mask = mask)
45 | exp_avg = tl.load(offset_exp_avg_ptr, mask = mask)
46 |
47 | # stepweight decay
48 |
49 | p = p * (1 - lr * wd)
50 |
51 | # diff between momentum running average and grad
52 |
53 | diff = exp_avg - grad
54 |
55 | # weight update
56 |
57 | update = diff * beta1 + grad
58 |
59 |     # torch.sign, fused with the -lr step: -lr where update > 0, +lr where update < 0, no change where update == 0
60 |
61 | can_update = update != 0
62 | update_sign = tl.where(update > 0, -lr, lr)
63 |
64 | p = p + update_sign * can_update
65 |
66 | # decay the momentum running average coefficient
67 |
68 | exp_avg = diff * beta2 + grad
69 |
70 | # store new params and momentum running average coefficient
71 |
72 | tl.store(offset_p_ptr, p, mask = mask)
73 | tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
74 |
75 | def update_fn(
76 | p: torch.Tensor,
77 | grad: torch.Tensor,
78 | exp_avg: torch.Tensor,
79 | lr: float,
80 | wd: float,
81 | beta1: float,
82 | beta2: float
83 | ):
84 | assert all([t.is_cuda for t in (p, grad, exp_avg)])
85 | n_elements = p.numel()
86 |
87 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
88 |
89 | update_fn_kernel[grid](
90 | p,
91 | grad,
92 | exp_avg,
93 | lr,
94 | wd,
95 | beta1,
96 | beta2,
97 | n_elements
98 | )
99 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'lion-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.2.3',
7 | license='MIT',
8 | description = 'Lion Optimizer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/lion-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'optimizers'
17 | ],
18 | install_requires=[
19 | 'torch>=1.6'
20 | ],
21 | classifiers=[
22 | 'Development Status :: 4 - Beta',
23 | 'Intended Audience :: Developers',
24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Programming Language :: Python :: 3.6',
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------