├── .gitignore
├── README.md
└── warmup_scheduler
    ├── __init__.py
    └── scheduler.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-warmup-cosine-lr

Paper: Bag of Tricks for Image Classification with Convolutional Neural Networks (https://arxiv.org/abs/1812.01187)

![Figure_1](https://user-images.githubusercontent.com/33244972/61711191-6bf9b900-ad8e-11e9-85f0-e6c55fbc5bc6.png)


## Usage

~~~
python warmup_scheduler/scheduler.py
~~~

## Import

~~~
import torch
from warmup_scheduler.scheduler import GradualWarmupScheduler

v = torch.zeros(10)
optim = torch.optim.SGD([v], lr=0.01)
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 100, eta_min=0, last_epoch=-1)
scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=5, after_scheduler=cosine_scheduler)
for epoch in range(1, 100):
    scheduler.step(epoch)
~~~

## Note

The total number of epochs appears in two places and must be kept in sync:

for epoch in range(1, **max_epoch**):

cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, **max_epoch**, eta_min=0, last_epoch=-1)

**To change the number of epochs, change every highlighted value**, as in the sketch below.
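A minimal sketch that keeps both places in sync through a single variable (the name `max_epoch` is illustrative, not part of the library):

~~~
import torch
from warmup_scheduler.scheduler import GradualWarmupScheduler

max_epoch = 100  # single source of truth for the epoch count

v = torch.zeros(10)
optim = torch.optim.SGD([v], lr=0.01)
# Both the cosine period and the loop bound derive from max_epoch.
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epoch, eta_min=0, last_epoch=-1)
scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=5, after_scheduler=cosine_scheduler)
for epoch in range(1, max_epoch):
    scheduler.step(epoch)
~~~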
--------------------------------------------------------------------------------
/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
from warmup_scheduler.scheduler import GradualWarmupScheduler
--------------------------------------------------------------------------------
/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch
import matplotlib.pyplot as plt


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up the learning rate from base_lr to base_lr * multiplier
    over total_epoch epochs, then hand control to after_scheduler."""

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        if multiplier < 1.0:
            raise ValueError('multiplier should be >= 1.0.')
        self.multiplier = multiplier
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Rebase the follow-up scheduler so it decays from the warmed-up LR.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        # Linear warmup: interpolate from base_lr to base_lr * multiplier.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                for base_lr in self.base_lrs]

    def step(self, epoch=None, metrics=None):
        if isinstance(self.after_scheduler, ReduceLROnPlateau):
            # ReduceLROnPlateau has no get_lr(), so apply the warmup manually,
            # then step the plateau scheduler on the supplied validation metric.
            self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
            if self.last_epoch <= self.total_epoch:
                warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                             for base_lr in self.base_lrs]
                for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                    param_group['lr'] = lr
            else:
                self.after_scheduler.step(metrics)
        elif self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
        else:
            return super(GradualWarmupScheduler, self).step(epoch)


if __name__ == '__main__':
    # Plot the LR schedule: 5 epochs of linear warmup (x8), then cosine annealing.
    v = torch.zeros(10)
    optim = torch.optim.SGD([v], lr=0.01)
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 100, eta_min=0, last_epoch=-1)
    scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=5, after_scheduler=cosine_scheduler)
    epochs = []
    lrs = []
    for epoch in range(1, 100):
        scheduler.step(epoch)
        epochs.append(epoch)
        lrs.append(optim.param_groups[0]['lr'])
        print(epoch, optim.param_groups[0]['lr'])

    plt.plot(epochs, lrs)
    plt.xlabel('epoch')
    plt.ylabel('learning rate')
    plt.show()
--------------------------------------------------------------------------------
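As a follow-up usage sketch: because `step()` accepts a `metrics` argument, `GradualWarmupScheduler` can also hand off to `ReduceLROnPlateau` after warmup. This example is illustrative only; the `val_loss` values are placeholders, not real measurements:

~~~
import torch
from warmup_scheduler.scheduler import GradualWarmupScheduler

v = torch.zeros(10)
optim = torch.optim.SGD([v], lr=0.01)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=3)
scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=5, after_scheduler=plateau)
for epoch in range(1, 100):
    val_loss = 1.0 / epoch  # placeholder for a real validation loss
    scheduler.step(epoch, metrics=val_loss)
~~~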