├── lr_scheduler
    ├── __init__.py
    ├── lr_scheduler.py
    ├── warmup_lr_scheduler.py
    ├── reduce_lr_on_plateau_lr_scheduler.py
    ├── warmup_reduce_lr_on_plateau_scheduler.py
    ├── transformer_lr_scheduler.py
    └── tri_stage_lr_scheduler.py
├── images
    ├── WarmupLRScheduler.png
    ├── TriStageLRScheduler.png
    ├── TransformerLRScheduler.png
    ├── ReduceLROnPlateauScheduler.png
    └── WarmupReduceLROnPlateauScheduler.png
├── setup.py
├── LICENSE
├── .gitignore
└── README.md

/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/images/WarmupLRScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/WarmupLRScheduler.png
--------------------------------------------------------------------------------
/images/TriStageLRScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/TriStageLRScheduler.png
--------------------------------------------------------------------------------
/images/TransformerLRScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/TransformerLRScheduler.png
--------------------------------------------------------------------------------
/images/ReduceLROnPlateauScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/ReduceLROnPlateauScheduler.png
--------------------------------------------------------------------------------
/images/WarmupReduceLROnPlateauScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/WarmupReduceLROnPlateauScheduler.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(name='lr_scheduler',
      version='0.0.1',
      description='PyTorch implementation of some learning rate schedulers for deep learning researchers.',
      url='https://github.com/sooftware/pytorch-lr-scheduler',
      author='Soohwan Kim',
      author_email='sh951011@gmail.com',
      packages=['lr_scheduler'],
      )
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Soohwan Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/lr_scheduler/lr_scheduler.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 Soohwan Kim
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from torch.optim.lr_scheduler import _LRScheduler


class LearningRateScheduler(_LRScheduler):
    r"""
    Provides the interface of a learning rate scheduler.

    Note:
        Do not use this class directly; use one of the subclasses.
    """
    def __init__(self, optimizer, lr):
        # _LRScheduler.__init__ is not called here; subclasses update the
        # learning rate directly through set_lr().
        self.optimizer = optimizer
        self.lr = lr

    def step(self, *args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def set_lr(optimizer, lr):
        # Apply the same learning rate to every parameter group.
        for g in optimizer.param_groups:
            g['lr'] = lr

    def get_lr(self):
        # Return the learning rate of the first parameter group.
        for g in self.optimizer.param_groups:
            return g['lr']
--------------------------------------------------------------------------------
/lr_scheduler/warmup_lr_scheduler.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 Soohwan Kim
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
from typing import Optional
from torch.optim import Optimizer

from lr_scheduler.lr_scheduler import LearningRateScheduler


class WarmupLRScheduler(LearningRateScheduler):
    """
    Warmup learning rate until `warmup_steps`.

    Args:
        optimizer (Optimizer): wrapped optimizer.
        init_lr (float): Initial learning rate.
        peak_lr (float): Maximum learning rate.
        warmup_steps (int): Warmup the learning rate linearly for the first N updates.
    """
    def __init__(
            self,
            optimizer: Optimizer,
            init_lr: float,
            peak_lr: float,
            warmup_steps: int,
    ) -> None:
        super(WarmupLRScheduler, self).__init__(optimizer, init_lr)
        self.init_lr = init_lr
        if warmup_steps != 0:
            warmup_rate = peak_lr - init_lr
            self.warmup_rate = warmup_rate / warmup_steps
        else:
            self.warmup_rate = 0
        self.update_steps = 1
        self.lr = init_lr
        self.warmup_steps = warmup_steps

    def step(self, val_loss: Optional[torch.FloatTensor] = None):
        # Increase the learning rate linearly while still in the warmup phase;
        # afterwards the learning rate is left unchanged.
        if self.update_steps < self.warmup_steps:
            lr = self.init_lr + self.warmup_rate * self.update_steps
            self.set_lr(self.optimizer, lr)
            self.lr = lr
        self.update_steps += 1
        return self.lr
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

.DS_Store
*.bin
*.zip
*.idea
venv/*
*.pyc
.idea
.ipynb_checkpoints
.DS_Store/*

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/lr_scheduler/reduce_lr_on_plateau_lr_scheduler.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 Soohwan Kim
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from torch.optim import Optimizer

from lr_scheduler.lr_scheduler import LearningRateScheduler


class ReduceLROnPlateauScheduler(LearningRateScheduler):
    r"""
    Reduce learning rate when a metric has stopped improving. Models often benefit from reducing the learning rate by
    a factor of 2-10 once learning stagnates. This scheduler tracks the validation loss and, if it rises compared
    with the previous call for `patience` consecutive calls, multiplies the learning rate by `factor`.

    Args:
        optimizer (Optimizer): Optimizer.
        lr (float): Initial learning rate.
        patience (int): Number of epochs with no improvement after which learning rate will be reduced.
        factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
    """
    def __init__(
            self,
            optimizer: Optimizer,
            lr: float,
            patience: int = 1,
            factor: float = 0.3,
    ) -> None:
        super(ReduceLROnPlateauScheduler, self).__init__(optimizer, lr)
        self.lr = lr
        self.patience = patience
        self.factor = factor
        # Sentinel "previous" loss used for the first comparison.
        self.val_loss = 100.0
        self.count = 0

    def step(self, val_loss: float):
        if val_loss is not None:
            # Compare against the previous validation loss (not the best one seen so far).
            if self.val_loss < val_loss:
                self.count += 1
                self.val_loss = val_loss
            else:
                self.count = 0
                self.val_loss = val_loss

            if self.patience == self.count:
                self.count = 0
                self.lr *= self.factor
                self.set_lr(self.optimizer, self.lr)

        return self.lr
--------------------------------------------------------------------------------
/lr_scheduler/warmup_reduce_lr_on_plateau_scheduler.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 Soohwan Kim
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from torch.optim import Optimizer
from typing import Optional

from lr_scheduler.lr_scheduler import LearningRateScheduler
from lr_scheduler.reduce_lr_on_plateau_lr_scheduler import ReduceLROnPlateauScheduler
from lr_scheduler.warmup_lr_scheduler import WarmupLRScheduler


class WarmupReduceLROnPlateauScheduler(LearningRateScheduler):
    r"""
    Warm up the learning rate linearly for `warmup_steps` updates, then reduce it on plateau.

    Args:
        optimizer (Optimizer): wrapped optimizer.
        init_lr (float): Initial learning rate.
        peak_lr (float): Maximum learning rate.
        warmup_steps (int): Warmup the learning rate linearly for the first N updates.
        patience (int): Number of epochs with no improvement after which learning rate will be reduced.
        factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
    """
    def __init__(
            self,
            optimizer: Optimizer,
            init_lr: float,
            peak_lr: float,
            warmup_steps: int,
            patience: int = 1,
            factor: float = 0.3,
    ) -> None:
        super(WarmupReduceLROnPlateauScheduler, self).__init__(optimizer, init_lr)
        self.warmup_steps = warmup_steps
        self.update_steps = 0
        self.warmup_rate = (peak_lr - init_lr) / self.warmup_steps \
            if self.warmup_steps != 0 else 0
        self.schedulers = [
            WarmupLRScheduler(
                optimizer=optimizer,
                init_lr=init_lr,
                peak_lr=peak_lr,
                warmup_steps=warmup_steps,
            ),
            ReduceLROnPlateauScheduler(
                optimizer=optimizer,
                lr=peak_lr,
                patience=patience,
                factor=factor,
            ),
        ]

    def _decide_stage(self):
        if self.update_steps < self.warmup_steps:
            return 0, self.update_steps
        else:
            return 1, None

    def step(self, val_loss: Optional[float] = None):
        stage, steps_in_stage = self._decide_stage()

        # Stage 0 delegates to the warmup scheduler (val_loss is ignored);
        # stage 1 delegates to the plateau scheduler.
        if stage == 0:
            self.schedulers[0].step()
        elif stage == 1:
            self.schedulers[1].step(val_loss)

        # Every call, with or without val_loss, advances the step counter.
        self.update_steps += 1

        return self.get_lr()
--------------------------------------------------------------------------------
/lr_scheduler/transformer_lr_scheduler.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 Soohwan Kim
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
import torch
from typing import Optional
from torch.optim import Optimizer

from lr_scheduler.lr_scheduler import LearningRateScheduler


class TransformerLRScheduler(LearningRateScheduler):
    r"""
    Transformer learning rate scheduler inspired by "Attention Is All You Need":
    linear warmup to `peak_lr`, followed by exponential decay and a constant `final_lr`.

    Args:
        optimizer (Optimizer): Optimizer.
        init_lr (float): Initial learning rate.
        peak_lr (float): Maximum learning rate.
        final_lr (float): Final learning rate.
        final_lr_scale (float): Scale relative to `peak_lr` that the decay stage approaches at its end.
        warmup_steps (int): Warmup the learning rate linearly for the first N updates.
        decay_steps (int): Number of updates in the decay stage.
    """
    def __init__(
            self,
            optimizer: Optimizer,
            init_lr: float,
            peak_lr: float,
            final_lr: float,
            final_lr_scale: float,
            warmup_steps: int,
            decay_steps: int,
    ) -> None:
        assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
        assert isinstance(decay_steps, int), "decay_steps should be integer type"

        super(TransformerLRScheduler, self).__init__(optimizer, init_lr)
        self.final_lr = final_lr
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps

        self.warmup_rate = self.peak_lr / self.warmup_steps
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps

        self.init_lr = init_lr
        self.update_steps = 0

    def _decide_stage(self):
        if self.update_steps < self.warmup_steps:
            return 0, self.update_steps

        if self.warmup_steps <= self.update_steps < self.warmup_steps + self.decay_steps:
            return 1, self.update_steps - self.warmup_steps

        return 2, None

    def step(self, val_loss: Optional[torch.FloatTensor] = None):
        self.update_steps += 1
        stage, steps_in_stage = self._decide_stage()

        if stage == 0:
            # Linear warmup from 0 towards peak_lr (init_lr is not used here).
            self.lr = self.update_steps * self.warmup_rate
        elif stage == 1:
            # Exponential decay from peak_lr towards peak_lr * final_lr_scale.
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 2:
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")

        self.set_lr(self.optimizer, self.lr)

        return self.lr
--------------------------------------------------------------------------------
/lr_scheduler/tri_stage_lr_scheduler.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 Soohwan Kim
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
import torch
from typing import Optional
from torch.optim import Optimizer

from lr_scheduler.lr_scheduler import LearningRateScheduler


class TriStageLRScheduler(LearningRateScheduler):
    r"""
    Tri-Stage Learning Rate Scheduler. Implements the learning rate schedule used in "SpecAugment".

    Args:
        optimizer (Optimizer): Optimizer.
        init_lr (float): Initial learning rate.
        peak_lr (float): Maximum learning rate.
        final_lr (float): Final learning rate.
        init_lr_scale (float): Initial learning rate scale; training starts at `init_lr * init_lr_scale`.
        final_lr_scale (float): Final learning rate scale; the decay stage ends at `peak_lr * final_lr_scale`.
        warmup_steps (int): Warmup the learning rate linearly for the first N updates.
        hold_steps (int): Hold the learning rate at `peak_lr` for the next N updates.
        decay_steps (int): Decay the learning rate exponentially over the next N updates.
        total_steps (int): Total steps in training.
    """
    def __init__(
            self,
            optimizer: Optimizer,
            init_lr: float,
            peak_lr: float,
            final_lr: float,
            init_lr_scale: float,
            final_lr_scale: float,
            warmup_steps: int,
            hold_steps: int,
            decay_steps: int,
            total_steps: int,
    ):
        assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
        assert isinstance(total_steps, int), "total_steps should be integer type"

        super(TriStageLRScheduler, self).__init__(optimizer, init_lr)
        self.init_lr = init_lr
        self.init_lr *= init_lr_scale
        self.final_lr = final_lr
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.hold_steps = hold_steps
        self.decay_steps = decay_steps

        self.warmup_rate = (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps

        self.lr = self.init_lr
        self.update_steps = 0

    def _decide_stage(self):
        if self.update_steps < self.warmup_steps:
            # warmup stage
            return 0, self.update_steps

        offset = self.warmup_steps

        if self.update_steps < offset + self.hold_steps:
            # hold stage
            return 1, self.update_steps - offset

        offset += self.hold_steps

        if self.update_steps <= offset + self.decay_steps:
            # decay stage
            return 2, self.update_steps - offset

        offset += self.decay_steps

        # constant final stage
        return 3, self.update_steps - offset

    def step(self, val_loss: Optional[torch.FloatTensor] = None):
        stage, steps_in_stage = self._decide_stage()

        if stage == 0:
            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
        elif stage == 1:
            self.lr = self.peak_lr
        elif stage == 2:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 3:
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")

        self.set_lr(self.optimizer, self.lr)
        self.update_steps += 1

        return self.lr
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-lr-scheduler
PyTorch implementation of some learning rate schedulers for deep learning researchers.

## Usage

### [`WarmupReduceLROnPlateauScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/warmup_reduce_lr_on_plateau_scheduler.py)

- Visualize

- Example code
```python
import torch

from lr_scheduler.warmup_reduce_lr_on_plateau_scheduler import WarmupReduceLROnPlateauScheduler

if __name__ == '__main__':
    max_epochs, steps_in_epoch = 10, 10000

    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = torch.optim.Adam(model, 1e-10)

    scheduler = WarmupReduceLROnPlateauScheduler(
        optimizer,
        init_lr=1e-10,
        peak_lr=1e-4,
        warmup_steps=30000,
        patience=1,
        factor=0.3,
    )

    for epoch in range(max_epochs):
        for timestep in range(steps_in_epoch):
            ...
            ...
            # Warm up on the global step count, not the per-epoch timestep.
            if epoch * steps_in_epoch + timestep < scheduler.warmup_steps:
                scheduler.step()

        val_loss = validate()  # placeholder for your own validation routine
        scheduler.step(val_loss)
```

### [`TransformerLRScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/transformer_lr_scheduler.py)

- Visualize

- Example code

```python
import torch

from lr_scheduler.transformer_lr_scheduler import TransformerLRScheduler

if __name__ == '__main__':
    max_epochs, steps_in_epoch = 10, 10000

    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = torch.optim.Adam(model, 1e-10)

    scheduler = TransformerLRScheduler(
        optimizer=optimizer,
        init_lr=1e-10,
        peak_lr=0.1,
        final_lr=1e-4,
        final_lr_scale=0.05,
        warmup_steps=3000,
        decay_steps=17000,
    )

    for epoch in range(max_epochs):
        for timestep in range(steps_in_epoch):
            ...
            ...
            scheduler.step()
```

### [`TriStageLRScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/tri_stage_lr_scheduler.py)

- Visualize

- Example code

```python
import torch

from lr_scheduler.tri_stage_lr_scheduler import TriStageLRScheduler

if __name__ == '__main__':
    max_epochs, steps_in_epoch = 10, 10000

    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = torch.optim.Adam(model, 1e-10)

    scheduler = TriStageLRScheduler(
        optimizer,
        init_lr=1e-10,
        peak_lr=1e-4,
        final_lr=1e-7,
        init_lr_scale=0.01,
        final_lr_scale=0.05,
        warmup_steps=30000,
        hold_steps=70000,
        decay_steps=100000,
        total_steps=200000,
    )

    for epoch in range(max_epochs):
        for timestep in range(steps_in_epoch):
            ...
            ...
            scheduler.step()
```

### [`ReduceLROnPlateauScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/reduce_lr_on_plateau_lr_scheduler.py)

- Visualize

- Example code

```python
import torch

from lr_scheduler.reduce_lr_on_plateau_lr_scheduler import ReduceLROnPlateauScheduler

if __name__ == '__main__':
    max_epochs, steps_in_epoch = 10, 10000

    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = torch.optim.Adam(model, 1e-4)

    # `lr` is a required argument of ReduceLROnPlateauScheduler.
    scheduler = ReduceLROnPlateauScheduler(optimizer, lr=1e-4, patience=1, factor=0.3)

    for epoch in range(max_epochs):
        for timestep in range(steps_in_epoch):
            ...
            ...

        val_loss = validate()  # placeholder for your own validation routine
        scheduler.step(val_loss)
```

### [`WarmupLRScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/warmup_lr_scheduler.py)

- Visualize

- Example code

```python
import torch

from lr_scheduler.warmup_lr_scheduler import WarmupLRScheduler

if __name__ == '__main__':
    max_epochs, steps_in_epoch = 10, 10000

    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = torch.optim.Adam(model, 1e-10)

    scheduler = WarmupLRScheduler(
        optimizer,
        init_lr=1e-10,
        peak_lr=1e-4,
        warmup_steps=4000,
    )

    for epoch in range(max_epochs):
        for timestep in range(steps_in_epoch):
            ...
            ...
            scheduler.step()
```

## Installation
```bash
git clone git@github.com:sooftware/pytorch-lr-scheduler.git
cd pytorch-lr-scheduler
pip install .
```

## Troubleshooting and Contributing
If you have any questions, bug reports, or feature requests, please [open an issue](https://github.com/sooftware/pytorch-lr-scheduler/issues) on GitHub.

I appreciate any kind of feedback or contribution. Feel free to proceed with small issues like bug fixes and documentation improvements. For major contributions and new features, please discuss them with the collaborators in the corresponding issues.

## Code Style
I follow [PEP-8](https://www.python.org/dev/peps/pep-0008/) for code style. Docstring style matters in particular, because docstrings are used to generate documentation.

## License
This project is licensed under the MIT License; see the [LICENSE](https://github.com/sooftware/pytorch-lr-scheduler/blob/master/LICENSE) file for details.
--------------------------------------------------------------------------------
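The tri-stage arithmetic in `TriStageLRScheduler.step()` can be checked without PyTorch or an optimizer. Below is a minimal pure-Python sketch that mirrors that per-step logic. `tri_stage_lr` is a hypothetical helper, not part of the `lr_scheduler` package. The `step` argument is assumed to equal the number of `step()` calls already made, because the class reads `update_steps` before incrementing it.

```python
import math


def tri_stage_lr(
    step: int,
    init_lr: float,
    peak_lr: float,
    final_lr: float,
    init_lr_scale: float,
    final_lr_scale: float,
    warmup_steps: int,
    hold_steps: int,
    decay_steps: int,
) -> float:
    """Hypothetical helper mirroring TriStageLRScheduler.step() for a given step count."""
    # The scheduler starts from init_lr scaled by init_lr_scale.
    start_lr = init_lr * init_lr_scale
    warmup_rate = (peak_lr - start_lr) / warmup_steps if warmup_steps != 0 else 0
    # Chosen so that the decay reaches peak_lr * final_lr_scale after decay_steps steps.
    decay_factor = -math.log(final_lr_scale) / decay_steps

    # Stage 0: linear warmup from start_lr towards peak_lr.
    if step < warmup_steps:
        return start_lr + warmup_rate * step
    offset = warmup_steps

    # Stage 1: hold at peak_lr.
    if step < offset + hold_steps:
        return peak_lr
    offset += hold_steps

    # Stage 2: exponential decay from peak_lr.
    if step <= offset + decay_steps:
        return peak_lr * math.exp(-decay_factor * (step - offset))

    # Stage 3: constant final_lr.
    return final_lr


if __name__ == "__main__":
    # Small illustrative values (not taken from the README example).
    config = dict(init_lr=1.0, peak_lr=1.0, final_lr=0.01, init_lr_scale=0.01,
                  final_lr_scale=0.05, warmup_steps=10, hold_steps=10, decay_steps=10)
    for s in (0, 5, 10, 20, 25, 30, 31):
        print(s, round(tri_stage_lr(s, **config), 4))
```

With these illustrative values, the decay stage ends at `peak_lr * final_lr_scale` = 0.05, and the next step drops to `final_lr` = 0.01. The class has the same discontinuity whenever `final_lr` differs from `peak_lr * final_lr_scale`. Choosing the two consistently avoids the jump.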