├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── lion.png
├── lion_pytorch
│   ├── __init__.py
│   ├── cautious_lion.py
│   ├── foreach.py
│   ├── lion_pytorch.py
│   └── triton.py
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------


# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------


## 🦁 Lion - Pytorch

🦁 Lion, Evo**L**ved S**i**gn M**o**me**n**tum, a new optimizer discovered by Google Brain that is purportedly better than Adam(w), in Pytorch. This is nearly a straight copy of the original implementation, with a few minor modifications.

It is so simple that we may as well make it accessible and get it used asap by everyone to train some great models, if it really works 🤞

### Instructions
- Learning rate and weight decay: the authors write in Section 5 - `Based on our experience, a suitable learning rate for Lion is typically 3-10x smaller than that for AdamW. Since the effective weight decay is lr * λ, the value of decoupled weight decay λ used for Lion is 3-10x larger than that for AdamW in order to maintain a similar strength.` The initial value, peak value, and end value in the learning rate schedule should be changed ***simultaneously*** with the same ratio compared to AdamW, [evidenced by a researcher](https://github.com/lucidrains/lion-pytorch/discussions/1#discussioncomment-5239900). A small illustrative sketch of this conversion follows right after this list.

- Learning rate schedule: the authors use the same learning rate schedule for Lion as for AdamW in the paper. Nevertheless, they observe a larger gain when using a cosine decay schedule to train ViT, compared to a reciprocal square-root schedule.

- β1 and β2: the authors write in Section 5 - `The default values for β1 and β2 in AdamW are set as 0.9 and 0.999, respectively, with an ε of 1e−8, while in Lion, the default values for β1 and β2 are discovered through the program search process and set as 0.9 and 0.99, respectively.` Similar to how people reduce β2 to 0.99 or smaller and increase ε to 1e-6 in AdamW to improve stability, using `β1=0.95, β2=0.98` in Lion can also be helpful in mitigating instability during training, as suggested by the authors. This was corroborated by a researcher.
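To make the conversion concrete, here is a small illustrative sketch. The AdamW values and the factor of 3 below are placeholders, not recommendations - only the 3-10x ratio comes from the paper:

```python
from torch import nn

from lion_pytorch import Lion

model = nn.Linear(10, 1)

# hypothetical AdamW settings you already trust
adamw_lr = 3e-4
adamw_wd = 1e-2

# rule of thumb from Section 5: learning rate 3-10x smaller, decoupled weight decay 3-10x larger,
# so that the effective weight decay lr * λ keeps roughly the same strength
factor = 3

opt = Lion(
    model.parameters(),
    lr=adamw_lr / factor,             # 1e-4
    weight_decay=adamw_wd * factor,   # 3e-2
    betas=(0.9, 0.99)                 # Lion defaults
)
```

Remember to scale the initial, peak, and end values of the learning rate schedule by the same ratio as well.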
### Updates
- Update: it seems to work for my local enwik8 autoregressive language modeling.

- Update 2: in my experiments, it seems much worse than Adam if the learning rate is held constant.

- Update 3: dividing the learning rate by 3, I am seeing better early results than with Adam. Maybe Adam has been dethroned, after nearly a decade.

- Update 4: using the 10x smaller learning rate rule of thumb from the paper resulted in the worst run. So I guess it still takes a bit of tuning.

A summarization of the previous updates: as shown in the experiments, Lion with a 3x smaller learning rate beats Adam. It still takes a bit of tuning, as a 10x smaller learning rate leads to a worse result.

- Update 5: so far hearing all positive results for language modeling, when done right. Also heard positive results for significant text-to-image training, although it takes a bit of tuning. The negative results seem to be with problems and architectures outside of what was evaluated in the paper - RL, feedforward networks, weird hybrid architectures with LSTMs + convolutions, etc. Negative anecdata also suggests this technique is sensitive to batch size and the amount of data / augmentation. It is still to be determined what the optimal learning rate schedule is, and whether cooldown affects results. Interestingly, there is also a positive result at open_clip, which became negative as the model size was scaled up (but this may be resolvable).

- Update 6: open_clip issue [resolved by the author](https://github.com/mlfoundations/open_clip/pull/432#issuecomment-1457323237), by setting a higher initial temperature.

- Update 7: I would only recommend this optimizer in the setting of large batch sizes (64 or above).
## Install

```bash
$ pip install lion-pytorch
```

## Usage

```python
# toy model

import torch
from torch import nn

model = nn.Linear(10, 1)

# import Lion and instantiate with parameters

from lion_pytorch import Lion

opt = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

# forward and backwards

loss = model(torch.randn(10))
loss.backward()

# optimizer step

opt.step()
opt.zero_grad()
```

To use a fused kernel for updating the parameters, first `pip install triton -U --pre`, then

```python
opt = Lion(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-2,
    use_triton=True # set this to True to use cuda kernel w/ Triton lang (Tillet et al)
)
```
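The repository also ships a cautious variant (`lion_pytorch/cautious_lion.py`, following the cautious optimizers paper cited below) and a `torch._foreach_*`-based variant (`lion_pytorch/foreach.py`). Every constructor additionally accepts a `decoupled_weight_decay` flag, which divides the weight decay by the initial learning rate internally (see the Schaipp reference below). A minimal sketch of how these might be used - the hyperparameter values here are illustrative only:

```python
import torch
from torch import nn

# cautious variant - scales down update components whose sign disagrees with the current gradient
from lion_pytorch.cautious_lion import Lion as CautiousLion

# foreach variant - same update, batched through torch._foreach_* ops; constructed the same way
# (it takes no cautious_factor / use_triton arguments)
from lion_pytorch.foreach import Lion as ForeachLion
# opt = ForeachLion(model.parameters(), lr=1e-4, weight_decay=1e-2, decoupled_weight_decay=True)

model = nn.Linear(10, 1)

opt = CautiousLion(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-2,
    cautious_factor=0.,          # 0. zeroes misaligned update components, 1. turns the behavior off
    decoupled_weight_decay=True  # weight decay is divided by the initial lr internally
)

loss = model(torch.randn(10))
loss.backward()
opt.step()
opt.zero_grad()
```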
## Appreciation

- Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research

## Citations

```bibtex
@misc{https://doi.org/10.48550/arxiv.2302.06675,
  url = {https://arxiv.org/abs/2302.06675},
  author = {Chen, Xiangning and Liang, Chen and Huang, Da and Real, Esteban and Wang, Kaiyuan and Liu, Yao and Pham, Hieu and Dong, Xuanyi and Luong, Thang and Hsieh, Cho-Jui and Lu, Yifeng and Le, Quoc V.},
  title = {Symbolic Discovery of Optimization Algorithms},
  publisher = {arXiv},
  year = {2023}
}
```

```bibtex
@article{Tillet2019TritonAI,
  title = {Triton: an intermediate language and compiler for tiled neural network computations},
  author = {Philippe Tillet and H. Kung and D. Cox},
  journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
  year = {2019}
}
```

```bibtex
@misc{Schaipp2024,
  author = {Fabian Schaipp},
  url = {https://fabian-sp.github.io/posts/2024/02/decoupling/}
}
```

```bibtex
@inproceedings{Liang2024CautiousOI,
  title = {Cautious Optimizers: Improving Training with One Line of Code},
  author = {Kaizhao Liang and Lizhang Chen and Bo Liu and Qiang Liu},
  year = {2024},
  url = {https://api.semanticscholar.org/CorpusID:274234738}
}
```

--------------------------------------------------------------------------------
/lion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/lion-pytorch/6a74fdc0ba572ab5683dc0270c66c20ecbc02d09/lion.png

--------------------------------------------------------------------------------
/lion_pytorch/__init__.py:
--------------------------------------------------------------------------------
from lion_pytorch.lion_pytorch import Lion

--------------------------------------------------------------------------------
/lion_pytorch/cautious_lion.py:
--------------------------------------------------------------------------------
from __future__ import annotations
from typing import Tuple, Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        cautious_factor: float = 0.,
        decoupled_weight_decay: bool = False,
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert 0. <= cautious_factor <= 1.

        self._init_lr = lr
        self.decoupled_wd = decoupled_weight_decay

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay,
            cautious_factor = cautious_factor
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, cautious_factor, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], group['cautious_factor'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr

                # maybe decoupled weight decay

                if decoupled_wd:
                    wd /= init_lr

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                # stepweight decay

                p.data.mul_(1. - lr * wd)

                # weight update

                update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1. - beta1).sign_()

                # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085

                if cautious_factor < 1.:
                    align_mask = (update * grad) > 0
                    scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
                    scale /= scale.mean().clamp(min = 1e-5)
                    update.mul_(scale)

                # update params

                p.add_(update, alpha = -lr)

                # decay the momentum running average coefficient

                exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)

        return loss
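To make the cautious scaling above concrete, here is the same masking applied to a toy tensor in isolation - the numbers are arbitrary and this snippet is not part of the package:

```python
import torch

update = torch.tensor([1., -1., 1.])      # sign of the interpolated momentum
grad   = torch.tensor([0.5, 0.2, -0.3])   # current gradient
cautious_factor = 0.

align_mask = (update * grad) > 0          # True only where update and grad agree in sign
scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
scale /= scale.mean().clamp(min = 1e-5)   # renormalize so the average scale stays ~1
print(update * scale)                     # tensor([3., -0., 0.]) - only the aligned component survives
```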
--------------------------------------------------------------------------------
/lion_pytorch/foreach.py:
--------------------------------------------------------------------------------
from __future__ import annotations
from typing import Tuple, Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        decoupled_weight_decay: bool = False
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions'

        self._init_lr = lr
        self.decoupled_wd = decoupled_weight_decay

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr

            # maybe decoupled weight decay

            if decoupled_wd:
                wd /= init_lr

            # accumulate List[Tensor] for foreach inplace updates

            params = []
            grads = []
            exp_avgs = []

            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, state = p.grad, self.state[p]

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                params.append(p)
                grads.append(grad)
                exp_avgs.append(exp_avg)

            # stepweight decay

            if wd > 0.:
                torch._foreach_mul_(params, 1. - lr * wd)

            # weight update

            updates = [t.clone() for t in exp_avgs]
            torch._foreach_lerp_(updates, grads, 1. - beta1)
            torch._foreach_sign_(updates)

            torch._foreach_add_(params, updates, alpha = -lr)

            # decay momentum running average

            torch._foreach_lerp_(exp_avgs, grads, 1. - beta2)

        return loss

--------------------------------------------------------------------------------
/lion_pytorch/lion_pytorch.py:
--------------------------------------------------------------------------------
from __future__ import annotations
from typing import Tuple, Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# update functions

def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    # stepweight decay

    p.data.mul_(1. - lr * wd)

    # weight update

    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1. - beta1).sign_()
    p.add_(update, alpha = -lr)

    # decay the momentum running average coefficient

    exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False,
        decoupled_weight_decay: bool = False,
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])

        self._init_lr = lr
        self.decoupled_wd = decoupled_weight_decay

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

        self.update_fn = update_fn

        if use_triton:
            from lion_pytorch.triton import update_fn as triton_update_fn
            self.update_fn = triton_update_fn

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr

                # maybe decoupled weight decay

                if decoupled_wd:
                    wd /= init_lr

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                self.update_fn(
                    p,
                    grad,
                    exp_avg,
                    lr,
                    wd,
                    beta1,
                    beta2
                )

        return loss
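The Triton kernel in `lion_pytorch/triton.py` (next file) is intended as a fused, drop-in replacement for the reference `update_fn` above. A quick parity check one might run on a CUDA machine with triton installed - this script is illustrative and not part of the repository:

```python
import torch

from lion_pytorch.lion_pytorch import update_fn as ref_update_fn
from lion_pytorch.triton import update_fn as triton_update_fn

torch.manual_seed(0)

p_ref = torch.randn(4096, device = 'cuda')
grad = torch.randn(4096, device = 'cuda')
m_ref = torch.randn(4096, device = 'cuda')

p_tri, m_tri = p_ref.clone(), m_ref.clone()

# both functions update the parameter and momentum buffers in place
ref_update_fn(p_ref, grad, m_ref, lr = 1e-4, wd = 1e-2, beta1 = 0.9, beta2 = 0.99)
triton_update_fn(p_tri, grad, m_tri, lr = 1e-4, wd = 1e-2, beta1 = 0.9, beta2 = 0.99)

print((p_ref - p_tri).abs().max())  # should be ~0, up to rare sign flips of near-zero updates
print((m_ref - m_tri).abs().max())  # should be ~0 (floating point rounding only)
```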
--------------------------------------------------------------------------------
/lion_pytorch/triton.py:
--------------------------------------------------------------------------------
import torch

try:
    import triton
    import triton.language as tl
except ImportError as e:
    print('triton is not installed, please install by running `pip install triton>=2.2.0`')
    exit()

# triton cuda kernel

@triton.autotune(configs = [
    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr'])
@triton.jit
def update_fn_kernel(
    p_ptr,
    grad_ptr,
    exp_avg_ptr,
    lr,
    wd,
    beta1,
    beta2,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis = 0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    mask = offsets < n_elements

    # offsetted pointers

    offset_p_ptr = p_ptr + offsets
    offset_grad_ptr = grad_ptr + offsets
    offset_exp_avg_ptr = exp_avg_ptr + offsets

    # load

    p = tl.load(offset_p_ptr, mask = mask)
    grad = tl.load(offset_grad_ptr, mask = mask)
    exp_avg = tl.load(offset_exp_avg_ptr, mask = mask)

    # stepweight decay

    p = p * (1 - lr * wd)

    # diff between momentum running average and grad

    diff = exp_avg - grad

    # weight update

    update = diff * beta1 + grad

    # torch.sign

    can_update = update != 0
    update_sign = tl.where(update > 0, -lr, lr)

    p = p + update_sign * can_update

    # decay the momentum running average coefficient

    exp_avg = diff * beta2 + grad

    # store new params and momentum running average coefficient

    tl.store(offset_p_ptr, p, mask = mask)
    tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)

def update_fn(
    p: torch.Tensor,
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    lr: float,
    wd: float,
    beta1: float,
    beta2: float
):
    assert all([t.is_cuda for t in (p, grad, exp_avg)])
    n_elements = p.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    update_fn_kernel[grid](
        p,
        grad,
        exp_avg,
        lr,
        wd,
        beta1,
        beta2,
        n_elements
    )

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'lion-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.2.3',
  license='MIT',
  description = 'Lion Optimizer - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/lion-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'optimizers'
  ],
  install_requires=[
    'torch>=1.6'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

--------------------------------------------------------------------------------