├── .gitignore
├── LICENSE
├── README.md
├── cosine_annealing_warmup
│   ├── __init__.py
│   └── scheduler.py
├── requirements.txt
├── setup.py
└── src
    ├── plot001.png
    └── plot002.png

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# notebook
*.ipynb
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Naoki Katsura

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Cosine Annealing with Warmup for PyTorch

## News
- 2020/12/22 : An update is coming soon...
- 2020/12/24 : Merry Christmas! Released version 2.0. The previous version is [here (branch: 1.0)](https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/tree/1.0).
- 2021/06/04 : This package can now be installed with pip.

## Installation
```bash
pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'
```

## Args
- optimizer (Optimizer): Wrapped optimizer.
- first_cycle_steps (int): First cycle step size.
- cycle_mult (float): Cycle steps magnification. Default: 1.
- max_lr (float): First cycle's max learning rate. Default: 0.1.
- min_lr (float): Min learning rate. Default: 0.001.
- warmup_steps (int): Linear warmup step size. Default: 0.
- gamma (float): Decrease rate of max learning rate by cycle. Default: 1.
- last_epoch (int): The index of the last epoch. Default: -1.
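
Concretely, within one cycle the schedule is a linear ramp from `min_lr` to `max_lr` over the first `warmup_steps` steps, followed by half a cosine back down to `min_lr`. The sketch below simply restates the formula from `get_lr()` in `scheduler.py`, where `t` is the number of `scheduler.step()` calls since the cycle began:

```python
import math

def lr_at(t, cycle_steps=200, warmup_steps=50, max_lr=0.1, min_lr=0.001):
    """Learning rate at step t of one cycle (same formula as get_lr())."""
    if t < warmup_steps:
        # linear warmup: min_lr -> max_lr
        return min_lr + (max_lr - min_lr) * t / warmup_steps
    # cosine annealing: max_lr -> min_lr over the rest of the cycle
    progress = (t - warmup_steps) / (cycle_steps - warmup_steps)
    return min_lr + (max_lr - min_lr) * (1 + math.cos(math.pi * progress)) / 2

print(lr_at(0), lr_at(50), lr_at(125))  # 0.001, 0.1, ~0.0505
```

Across restarts, the peak decays as `max_lr * gamma**cycle`, and the post-warmup portion of each cycle is stretched by `cycle_mult` while the warmup length stays fixed.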

## Example
```python
import torch.optim as optim
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

model = ...
# Note: the lr passed to the optimizer is overwritten;
# the scheduler re-initializes it to min_lr.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
scheduler = CosineAnnealingWarmupRestarts(optimizer,
                                          first_cycle_steps=200,
                                          cycle_mult=1.0,
                                          max_lr=0.1,
                                          min_lr=0.001,
                                          warmup_steps=50,
                                          gamma=1.0)
for epoch in range(n_epoch):
    train()
    valid()
    scheduler.step()
```

- case1 : `CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=500, cycle_mult=1.0, max_lr=0.1, min_lr=0.001, warmup_steps=100, gamma=1.0)`
![example1](./src/plot001.png "example1")
- case2 : `CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=200, cycle_mult=1.0, max_lr=0.1, min_lr=0.001, warmup_steps=50, gamma=0.5)`
![example2](./src/plot002.png "example2")
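
## Stepping per iteration
`first_cycle_steps` and `warmup_steps` are counted in `scheduler.step()` calls, so the scheduler can just as well be stepped once per batch. `step()` also accepts an explicit, possibly fractional, epoch, in the style of PyTorch's built-in `CosineAnnealingWarmRestarts`. A minimal sketch, assuming `iters_per_epoch = len(train_loader)` and a `train_one_batch()` helper (both stand-ins, not part of this package):

```python
for epoch in range(n_epoch):
    for i in range(iters_per_epoch):
        train_one_batch()
        # With fractional epochs, first_cycle_steps and warmup_steps
        # are measured in epochs rather than in individual step() calls.
        scheduler.step(epoch + i / iters_per_epoch)
```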
--------------------------------------------------------------------------------
/cosine_annealing_warmup/__init__.py:
--------------------------------------------------------------------------------
from .scheduler import CosineAnnealingWarmupRestarts

__all__ = [
    'CosineAnnealingWarmupRestarts',
]
--------------------------------------------------------------------------------
/cosine_annealing_warmup/scheduler.py:
--------------------------------------------------------------------------------
import math

import torch
from torch.optim.lr_scheduler import _LRScheduler


class CosineAnnealingWarmupRestarts(_LRScheduler):
    """Cosine annealing scheduler with linear warmup and warm restarts.

    optimizer (Optimizer): Wrapped optimizer.
    first_cycle_steps (int): First cycle step size.
    cycle_mult (float): Cycle steps magnification. Default: 1.
    max_lr (float): First cycle's max learning rate. Default: 0.1.
    min_lr (float): Min learning rate. Default: 0.001.
    warmup_steps (int): Linear warmup step size. Default: 0.
    gamma (float): Decrease rate of max learning rate by cycle. Default: 1.
    last_epoch (int): The index of the last epoch. Default: -1.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 first_cycle_steps: int,
                 cycle_mult: float = 1.,
                 max_lr: float = 0.1,
                 min_lr: float = 0.001,
                 warmup_steps: int = 0,
                 gamma: float = 1.,
                 last_epoch: int = -1
                 ):
        assert warmup_steps < first_cycle_steps, \
            "warmup_steps must be smaller than first_cycle_steps"

        self.first_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle_mult = cycle_mult                # cycle steps magnification
        self.base_max_lr = max_lr                   # first max learning rate
        self.max_lr = max_lr                        # max learning rate in the current cycle
        self.min_lr = min_lr                        # min learning rate
        self.warmup_steps = warmup_steps            # warmup step size
        self.gamma = gamma                          # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps    # step size of the current cycle
        self.cycle = 0                              # cycle count
        self.step_in_cycle = last_epoch             # step index within the current cycle

        super().__init__(optimizer, last_epoch)

        # set the learning rate of every param group to min_lr
        self.init_lr()

    def init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self):
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # linear warmup: interpolate from base_lr (= min_lr) up to max_lr
            return [(self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr
                    for base_lr in self.base_lrs]
        else:
            # cosine annealing from max_lr back down to base_lr over the rest of the cycle
            return [base_lr + (self.max_lr - base_lr)
                    * (1 + math.cos(math.pi * (self.step_in_cycle - self.warmup_steps)
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # cycle finished: restart, growing the post-warmup portion by cycle_mult
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # solve for the cycle index n that the given epoch falls into
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = int(self.first_cycle_steps * self.cycle_mult ** n)  # keep step counts integral
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
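

# A minimal smoke test: an illustrative sketch, not part of the package API.
# It drives the scheduler with a single dummy parameter and prints the learning
# rate at the cycle boundaries, so warmup, annealing, the restart, and the
# gamma-decayed peak are all visible.
if __name__ == "__main__":
    dummy = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([dummy], lr=0.1)  # this lr is overwritten to min_lr
    scheduler = CosineAnnealingWarmupRestarts(optimizer,
                                              first_cycle_steps=200,
                                              max_lr=0.1, min_lr=0.001,
                                              warmup_steps=50, gamma=0.5)
    for step in range(400):
        if step in (0, 50, 199, 200, 250):
            # expected (approx.): 0.001, 0.1, 0.001, 0.001, 0.05
            print(step, optimizer.param_groups[0]['lr'])
        scheduler.step()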
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.7
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(
    name="cosine_annealing_warmup",
    version="2.0",
    author="Naoki Katsura",
    packages=['cosine_annealing_warmup'],
    description="Cosine Annealing with Warmup for PyTorch",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",  # the README is markdown
    install_requires=['torch>=1.7'],  # mirrors requirements.txt
)
--------------------------------------------------------------------------------
/src/plot001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/katsura-jp/pytorch-cosine-annealing-with-warmup/12d03c07553aedd3d9e9155e2b3e31ce8c64081a/src/plot001.png
--------------------------------------------------------------------------------
/src/plot002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/katsura-jp/pytorch-cosine-annealing-with-warmup/12d03c07553aedd3d9e9155e2b3e31ce8c64081a/src/plot002.png
--------------------------------------------------------------------------------