├── .gitignore ├── LICENSE ├── README.md ├── setup.py └── torch_poly_lr_decay ├── __init__.py ├── run.py └── torch_poly_lr_decay.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Park Chun Myong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-polynomial-lr-decay 2 | Polynomial Learning Rate Decay Scheduler for PyTorch 3 | 4 | This scheduler is frequently used in many DL paper. But there is no official implementation in PyTorch. So I propose this code. 5 | 6 | ## Install 7 | 8 | ``` 9 | $ pip install git+https://github.com/cmpark0126/pytorch-polynomial-lr-decay.git 10 | ``` 11 | 12 | ## Usage 13 | 14 | ```python 15 | from torch_poly_lr_decay import PolynomialLRDecay 16 | 17 | scheduler_poly_lr_decay = PolynomialLRDecay(optim, max_decay_steps=100, end_learning_rate=0.0001, power=2.0) 18 | 19 | for epoch in range(train_epoch): 20 | scheduler_poly_lr_decay.step() # you can handle step as epoch number 21 | ... 22 | ``` 23 | 24 | or 25 | 26 | ```python 27 | from torch_poly_lr_decay import PolynomialLRDecay 28 | 29 | scheduler_poly_lr_decay = PolynomialLRDecay(optim, max_decay_steps=100, end_learning_rate=0.0001, power=2.0) 30 | 31 | ... 32 | 33 | for batch_idx, (inputs, targets) in enumerate(trainloader): 34 | scheduler_poly_lr_decay.step() # also, you can handle step as each iter number 35 | ``` 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open('README.md') as f: 4 | long_description=f.read() 5 | 6 | setup( 7 | name='torch_poly_lr_decay', 8 | version='0.0.1', 9 | author='Chunmyong Park', 10 | description='Polynomial Learning Rate Decay Scheduler for PyTorch', 11 | long_description=long_description, 12 | long_description_content_type='text/markdown', 13 | maintainer='Chunmyong Park', 14 | zip_safe=False, 15 | packages=['torch_poly_lr_decay'], 16 | install_requires=['torch'], 17 | ) 18 | -------------------------------------------------------------------------------- /torch_poly_lr_decay/__init__.py: -------------------------------------------------------------------------------- 1 | name = 'torchlars' 2 | 3 | from .torch_poly_lr_decay import PolynomialLRDecay 4 | -------------------------------------------------------------------------------- /torch_poly_lr_decay/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_poly_lr_decay import PolynomialLRDecay 4 | 5 | 6 | if __name__ == '__main__': 7 | v = torch.zeros(10) 8 | optim = torch.optim.SGD([v], lr=0.01) 9 | scheduler = PolynomialLRDecay(optim, max_decay_steps=19, end_learning_rate=0.0001, power=2.0) 10 | 11 | for epoch in range(1, 20): 12 | scheduler.step(epoch) 13 | 14 | print(epoch, optim.param_groups[0]['lr']) 15 | -------------------------------------------------------------------------------- /torch_poly_lr_decay/torch_poly_lr_decay.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | class PolynomialLRDecay(_LRScheduler): 4 | """Polynomial learning rate decay until step reach to max_decay_step 5 | 6 | Args: 7 | optimizer (Optimizer): Wrapped optimizer. 8 | max_decay_steps: after this step, we stop decreasing learning rate 9 | end_learning_rate: scheduler stoping learning rate decay, value of learning rate must be this value 10 | power: The power of the polynomial. 11 | """ 12 | 13 | def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0): 14 | if max_decay_steps <= 1.: 15 | raise ValueError('max_decay_steps should be greater than 1.') 16 | self.max_decay_steps = max_decay_steps 17 | self.end_learning_rate = end_learning_rate 18 | self.power = power 19 | self.last_step = 0 20 | super().__init__(optimizer) 21 | 22 | def get_lr(self): 23 | if self.last_step > self.max_decay_steps: 24 | return [self.end_learning_rate for _ in self.base_lrs] 25 | 26 | return [(base_lr - self.end_learning_rate) * 27 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) + 28 | self.end_learning_rate for base_lr in self.base_lrs] 29 | 30 | def step(self, step=None): 31 | if step is None: 32 | step = self.last_step + 1 33 | self.last_step = step if step != 0 else 1 34 | if self.last_step <= self.max_decay_steps: 35 | decay_lrs = [(base_lr - self.end_learning_rate) * 36 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) + 37 | self.end_learning_rate for base_lr in self.base_lrs] 38 | for param_group, lr in zip(self.optimizer.param_groups, decay_lrs): 39 | param_group['lr'] = lr 40 | --------------------------------------------------------------------------------