├── lr_scheduler
│   ├── __init__.py
│   ├── lr_scheduler.py
│   ├── warmup_lr_scheduler.py
│   ├── reduce_lr_on_plateau_lr_scheduler.py
│   ├── warmup_reduce_lr_on_plateau_scheduler.py
│   ├── transformer_lr_scheduler.py
│   └── tri_stage_lr_scheduler.py
├── images
│   ├── WarmupLRScheduler.png
│   ├── TriStageLRScheduler.png
│   ├── TransformerLRScheduler.png
│   ├── ReduceLROnPlateauScheduler.png
│   └── WarmupReduceLROnPlateauScheduler.png
├── setup.py
├── LICENSE
├── .gitignore
└── README.md
/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/images/WarmupLRScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/WarmupLRScheduler.png
--------------------------------------------------------------------------------
/images/TriStageLRScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/TriStageLRScheduler.png
--------------------------------------------------------------------------------
/images/TransformerLRScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/TransformerLRScheduler.png
--------------------------------------------------------------------------------
/images/ReduceLROnPlateauScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/ReduceLROnPlateauScheduler.png
--------------------------------------------------------------------------------
/images/WarmupReduceLROnPlateauScheduler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/WarmupReduceLROnPlateauScheduler.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name='lr_scheduler',
4 | version='0.0.1',
5 | description='PyTorch implementation of some learning rate schedulers for deep learning researchers.',
6 | url='https://github.com/sooftware/pytorch-lr-scheduler',
7 | author='Soohwan Kim',
8 | author_email='sh951011@gmail.com',
9 | packages=['lr_scheduler'],
10 | )
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Soohwan Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lr_scheduler/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from torch.optim.lr_scheduler import _LRScheduler
24 |
25 | class LearningRateScheduler(_LRScheduler):
26 | r"""
27 |     Provides an interface for learning rate schedulers.
28 |
29 | Note:
30 |         Do not use this class directly; use one of the subclasses.
31 | """
32 | def __init__(self, optimizer, lr):
33 | self.optimizer = optimizer
34 | self.lr = lr
35 |
36 | def step(self, *args, **kwargs):
37 | raise NotImplementedError
38 |
39 | @staticmethod
40 | def set_lr(optimizer, lr):
41 | for g in optimizer.param_groups:
42 | g['lr'] = lr
43 |
44 | def get_lr(self):
45 | for g in self.optimizer.param_groups:
46 | return g['lr']
47 |
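The base class above only fixes the interface: a concrete scheduler implements `step()` and uses `set_lr`/`get_lr` to read and write the optimizer's learning rate. As a minimal illustration (a hypothetical scheduler, not part of this repository), a subclass might look like this:

```python
import torch

from lr_scheduler.lr_scheduler import LearningRateScheduler


class ExponentialDecayLRScheduler(LearningRateScheduler):
    """Hypothetical example: multiply the learning rate by ``gamma`` on every step."""

    def __init__(self, optimizer, lr, gamma=0.9):
        super().__init__(optimizer, lr)
        self.gamma = gamma

    def step(self):
        # Decay the stored lr, push it into every param group, and return it.
        self.lr *= self.gamma
        self.set_lr(self.optimizer, self.lr)
        return self.lr


if __name__ == '__main__':
    params = [torch.nn.Parameter(torch.randn(2, 2))]
    optimizer = torch.optim.Adam(params, lr=1e-3)
    scheduler = ExponentialDecayLRScheduler(optimizer, lr=1e-3, gamma=0.9)
    print([scheduler.step() for _ in range(3)])  # roughly [9.0e-04, 8.1e-04, 7.29e-04]
```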
--------------------------------------------------------------------------------
/lr_scheduler/warmup_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 | from typing import Optional
25 | from torch.optim import Optimizer
26 |
27 | from lr_scheduler.lr_scheduler import LearningRateScheduler
28 |
29 |
30 | class WarmupLRScheduler(LearningRateScheduler):
31 | """
32 |     Linearly increases the learning rate from ``init_lr`` to ``peak_lr`` over ``warmup_steps`` updates.
33 | 
34 |     Args:
35 |         optimizer (Optimizer): wrapped optimizer.
36 |         init_lr (float): Initial learning rate. peak_lr (float): Peak learning rate. warmup_steps (int): Number of warmup steps.
37 | """
38 | def __init__(
39 | self,
40 | optimizer: Optimizer,
41 | init_lr: float,
42 | peak_lr: float,
43 | warmup_steps: int,
44 | ) -> None:
45 | super(WarmupLRScheduler, self).__init__(optimizer, init_lr)
46 | self.init_lr = init_lr
47 | if warmup_steps != 0:
48 | warmup_rate = peak_lr - init_lr
49 | self.warmup_rate = warmup_rate / warmup_steps
50 | else:
51 | self.warmup_rate = 0
52 | self.update_steps = 1
53 | self.lr = init_lr
54 | self.warmup_steps = warmup_steps
55 |
56 | def step(self, val_loss: Optional[torch.FloatTensor] = None):
57 | if self.update_steps < self.warmup_steps:
58 | lr = self.init_lr + self.warmup_rate * self.update_steps
59 | self.set_lr(self.optimizer, lr)
60 | self.lr = lr
61 | self.update_steps += 1
62 | return self.lr
63 |
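A quick sanity check of the ramp (a sketch, not part of the repository): with ``init_lr=0.0``, ``peak_lr=1.0`` and ``warmup_steps=10``, each update adds ``(peak_lr - init_lr) / warmup_steps = 0.1`` until the internal step counter reaches ``warmup_steps``:

```python
import torch

from lr_scheduler.warmup_lr_scheduler import WarmupLRScheduler

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = torch.optim.Adam(params, lr=0.0)
scheduler = WarmupLRScheduler(optimizer, init_lr=0.0, peak_lr=1.0, warmup_steps=10)

lrs = [scheduler.step() for _ in range(10)]
print(lrs)  # [0.1, 0.2, ..., 0.9, 0.9] -- the ramp stops once the step counter reaches warmup_steps
```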
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | .DS_Store
10 | *.bin
11 | *.zip
12 | *.idea
13 | venv/*
14 | *.pyc
15 | .idea
16 | .ipynb_checkpoints
17 | .DS_Store/*
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | pip-wheel-metadata/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
--------------------------------------------------------------------------------
/lr_scheduler/reduce_lr_on_plateau_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from torch.optim import Optimizer
24 | 
25 |
26 | from lr_scheduler.lr_scheduler import LearningRateScheduler
27 |
28 |
29 | class ReduceLROnPlateauScheduler(LearningRateScheduler):
30 | r"""
31 | Reduce learning rate when a metric has stopped improving. Models often benefit from reducing the learning rate by
32 | a factor of 2-10 once learning stagnates. This scheduler reads a metrics quantity and if no improvement is seen
33 | for a ‘patience’ number of epochs, the learning rate is reduced.
34 |
35 | Args:
36 | optimizer (Optimizer): Optimizer.
37 | lr (float): Initial learning rate.
38 | patience (int): Number of epochs with no improvement after which learning rate will be reduced.
39 | factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
40 | """
41 | def __init__(
42 | self,
43 | optimizer: Optimizer,
44 | lr: float,
45 | patience: int = 1,
46 | factor: float = 0.3,
47 | ) -> None:
48 | super(ReduceLROnPlateauScheduler, self).__init__(optimizer, lr)
49 | self.lr = lr
50 | self.patience = patience
51 | self.factor = factor
52 |         self.val_loss = float('inf')  # no validation loss observed yet
53 | self.count = 0
54 |
55 | def step(self, val_loss: float):
56 | if val_loss is not None:
57 | if self.val_loss < val_loss:
58 | self.count += 1
59 | self.val_loss = val_loss
60 | else:
61 | self.count = 0
62 | self.val_loss = val_loss
63 |
64 | if self.patience == self.count:
65 | self.count = 0
66 | self.lr *= self.factor
67 | self.set_lr(self.optimizer, self.lr)
68 |
69 | return self.lr
70 |
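A small sketch (not part of the repository) of when the reduction triggers: with ``patience=1``, a single epoch whose validation loss is worse than the previous one scales the learning rate by ``factor``:

```python
import torch

from lr_scheduler.reduce_lr_on_plateau_lr_scheduler import ReduceLROnPlateauScheduler

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = torch.optim.Adam(params, lr=1e-4)
scheduler = ReduceLROnPlateauScheduler(optimizer, lr=1e-4, patience=1, factor=0.3)

# Loss improves twice, then gets worse twice; each non-improving epoch scales the LR by 0.3.
for val_loss in [1.00, 0.80, 0.90, 0.95]:
    print(scheduler.step(val_loss))  # -> 1e-4, 1e-4, 3e-5, 9e-6 (approximately)
```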
--------------------------------------------------------------------------------
/lr_scheduler/warmup_reduce_lr_on_plateau_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from torch.optim import Optimizer
24 | from typing import Optional
25 |
26 | from lr_scheduler.lr_scheduler import LearningRateScheduler
27 | from lr_scheduler.reduce_lr_on_plateau_lr_scheduler import ReduceLROnPlateauScheduler
28 | from lr_scheduler.warmup_lr_scheduler import WarmupLRScheduler
29 |
30 |
31 | class WarmupReduceLROnPlateauScheduler(LearningRateScheduler):
32 | r"""
33 | Warmup learning rate until `warmup_steps` and reduce learning rate on plateau after.
34 |
35 | Args:
36 | optimizer (Optimizer): wrapped optimizer.
37 | init_lr (float): Initial learning rate.
38 | peak_lr (float): Maximum learning rate.
39 | warmup_steps (int): Warmup the learning rate linearly for the first N updates.
40 | patience (int): Number of epochs with no improvement after which learning rate will be reduced.
41 | factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
42 | """
43 | def __init__(
44 | self,
45 | optimizer: Optimizer,
46 | init_lr: float,
47 | peak_lr: float,
48 | warmup_steps: int,
49 | patience: int = 1,
50 | factor: float = 0.3,
51 | ) -> None:
52 | super(WarmupReduceLROnPlateauScheduler, self).__init__(optimizer, init_lr)
53 | self.warmup_steps = warmup_steps
54 | self.update_steps = 0
55 | self.warmup_rate = (peak_lr - init_lr) / self.warmup_steps \
56 | if self.warmup_steps != 0 else 0
57 | self.schedulers = [
58 | WarmupLRScheduler(
59 | optimizer=optimizer,
60 | init_lr=init_lr,
61 | peak_lr=peak_lr,
62 | warmup_steps=warmup_steps,
63 | ),
64 | ReduceLROnPlateauScheduler(
65 | optimizer=optimizer,
66 | lr=peak_lr,
67 | patience=patience,
68 | factor=factor,
69 | ),
70 | ]
71 |
72 | def _decide_stage(self):
73 | if self.update_steps < self.warmup_steps:
74 | return 0, self.update_steps
75 | else:
76 | return 1, None
77 |
78 | def step(self, val_loss: Optional[float] = None):
79 | stage, steps_in_stage = self._decide_stage()
80 |
81 | if stage == 0:
82 | self.schedulers[0].step()
83 | elif stage == 1:
84 | self.schedulers[1].step(val_loss)
85 |
86 | self.update_steps += 1
87 |
88 | return self.get_lr()
89 |
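In short, the composite scheduler delegates: calls to ``step()`` during the first ``warmup_steps`` updates drive the warmup stage, and calls to ``step(val_loss)`` afterwards drive the plateau stage, matching the usage shown in the README. A compact sketch (not part of the repository):

```python
import torch

from lr_scheduler.warmup_reduce_lr_on_plateau_scheduler import WarmupReduceLROnPlateauScheduler

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = torch.optim.Adam(params, lr=1e-10)
scheduler = WarmupReduceLROnPlateauScheduler(
    optimizer, init_lr=1e-10, peak_lr=1e-4, warmup_steps=5, patience=1, factor=0.3,
)

for _ in range(5):           # warmup stage: one step() per training update, no loss needed
    scheduler.step()
for val_loss in [0.9, 1.1]:  # plateau stage: one step(val_loss) per validation
    scheduler.step(val_loss)
print(scheduler.get_lr())    # ~3e-5: the non-improving validation triggered a 0.3x reduction
```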
--------------------------------------------------------------------------------
/lr_scheduler/transformer_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import math
24 | import torch
25 | from typing import Optional
26 | from torch.optim import Optimizer
27 |
28 | from lr_scheduler.lr_scheduler import LearningRateScheduler
29 |
30 |
31 | class TransformerLRScheduler(LearningRateScheduler):
32 | r"""
33 | Transformer Learning Rate Scheduler proposed in "Attention Is All You Need"
34 |
35 | Args:
36 | optimizer (Optimizer): Optimizer.
37 | init_lr (float): Initial learning rate.
38 | peak_lr (float): Maximum learning rate.
39 | final_lr (float): Final learning rate.
40 |         final_lr_scale (float): Final learning rate scale.
41 |         warmup_steps (int): Warmup the learning rate linearly for the first N updates.
42 |         decay_steps (int): Number of steps in the exponential-decay stage.
43 | """
44 | def __init__(
45 | self,
46 | optimizer: Optimizer,
47 | init_lr: float,
48 | peak_lr: float,
49 | final_lr: float,
50 | final_lr_scale: float,
51 | warmup_steps: int,
52 | decay_steps: int,
53 | ) -> None:
54 |         assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
55 |         assert isinstance(decay_steps, int), "decay_steps should be integer type"
56 |
57 | super(TransformerLRScheduler, self).__init__(optimizer, init_lr)
58 | self.final_lr = final_lr
59 | self.peak_lr = peak_lr
60 | self.warmup_steps = warmup_steps
61 | self.decay_steps = decay_steps
62 |
63 | self.warmup_rate = self.peak_lr / self.warmup_steps
64 | self.decay_factor = -math.log(final_lr_scale) / self.decay_steps
65 |
66 | self.init_lr = init_lr
67 | self.update_steps = 0
68 |
69 | def _decide_stage(self):
70 | if self.update_steps < self.warmup_steps:
71 | return 0, self.update_steps
72 |
73 | if self.warmup_steps <= self.update_steps < self.warmup_steps + self.decay_steps:
74 | return 1, self.update_steps - self.warmup_steps
75 |
76 | return 2, None
77 |
78 | def step(self, val_loss: Optional[torch.FloatTensor] = None):
79 | self.update_steps += 1
80 | stage, steps_in_stage = self._decide_stage()
81 |
82 | if stage == 0:
83 | self.lr = self.update_steps * self.warmup_rate
84 | elif stage == 1:
85 | self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
86 | elif stage == 2:
87 | self.lr = self.final_lr
88 | else:
89 | raise ValueError("Undefined stage")
90 |
91 | self.set_lr(self.optimizer, self.lr)
92 |
93 | return self.lr
94 |
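As implemented above, the schedule has three stages: a linear ramp to ``peak_lr`` over ``warmup_steps``, an exponential decay controlled by ``final_lr_scale`` over ``decay_steps``, and a constant ``final_lr`` afterwards. A short sketch (not part of the repository) tracing the stage boundaries:

```python
import torch

from lr_scheduler.transformer_lr_scheduler import TransformerLRScheduler

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = torch.optim.Adam(params, lr=1e-10)
scheduler = TransformerLRScheduler(
    optimizer,
    init_lr=1e-10,
    peak_lr=0.1,
    final_lr=1e-4,
    final_lr_scale=0.05,
    warmup_steps=100,
    decay_steps=900,
)

lrs = [scheduler.step() for _ in range(1100)]
print(max(lrs))  # 0.1 -- peak_lr is reached at the end of warmup
print(lrs[-1])   # 1e-4 -- held at final_lr once warmup_steps + decay_steps have passed
```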
--------------------------------------------------------------------------------
/lr_scheduler/tri_stage_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import math
24 | import torch
25 | from typing import Optional
26 | from torch.optim import Optimizer
27 |
28 | from lr_scheduler.lr_scheduler import LearningRateScheduler
29 |
30 |
31 | class TriStageLRScheduler(LearningRateScheduler):
32 | r"""
33 |     Tri-Stage Learning Rate Scheduler. Implements the learning rate schedule described in "SpecAugment".
34 |
35 | Args:
36 | optimizer (Optimizer): Optimizer.
37 | init_lr (float): Initial learning rate.
38 | peak_lr (float): Maximum learning rate.
39 | final_lr (float): Final learning rate.
40 | init_lr_scale (float): Initial learning rate scale.
41 | final_lr_scale (float): Final learning rate scale.
42 | warmup_steps (int): Warmup the learning rate linearly for the first N updates.
43 |         hold_steps (int): Hold the learning rate at the peak value for N updates.
44 |         decay_steps (int): Decay the learning rate exponentially over N updates.
45 |         total_steps (int): Total steps in training (not used directly by this scheduler).
46 | """
47 | def __init__(
48 | self,
49 | optimizer: Optimizer,
50 | init_lr: float,
51 | peak_lr: float,
52 | final_lr: float,
53 | init_lr_scale: float,
54 | final_lr_scale: float,
55 | warmup_steps: int,
56 | hold_steps: int,
57 | decay_steps: int,
58 | total_steps: int,
59 | ):
60 |         assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
61 |         assert isinstance(total_steps, int), "total_steps should be integer type"
62 |
63 | super(TriStageLRScheduler, self).__init__(optimizer, init_lr)
64 | self.init_lr = init_lr
65 | self.init_lr *= init_lr_scale
66 | self.final_lr = final_lr
67 | self.peak_lr = peak_lr
68 | self.warmup_steps = warmup_steps
69 | self.hold_steps = hold_steps
70 | self.decay_steps = decay_steps
71 |
72 | self.warmup_rate = (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0
73 | self.decay_factor = -math.log(final_lr_scale) / self.decay_steps
74 |
75 | self.lr = self.init_lr
76 | self.update_steps = 0
77 |
78 | def _decide_stage(self):
79 | if self.update_steps < self.warmup_steps:
80 | return 0, self.update_steps
81 |
82 | offset = self.warmup_steps
83 |
84 | if self.update_steps < offset + self.hold_steps:
85 | return 1, self.update_steps - offset
86 |
87 | offset += self.hold_steps
88 |
89 | if self.update_steps <= offset + self.decay_steps:
90 | # decay stage
91 | return 2, self.update_steps - offset
92 |
93 | offset += self.decay_steps
94 |
95 | return 3, self.update_steps - offset
96 |
97 | def step(self, val_loss: Optional[torch.FloatTensor] = None):
98 | stage, steps_in_stage = self._decide_stage()
99 |
100 | if stage == 0:
101 | self.lr = self.init_lr + self.warmup_rate * steps_in_stage
102 | elif stage == 1:
103 | self.lr = self.peak_lr
104 | elif stage == 2:
105 | self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
106 | elif stage == 3:
107 | self.lr = self.final_lr
108 | else:
109 | raise ValueError("Undefined stage")
110 |
111 | self.set_lr(self.optimizer, self.lr)
112 | self.update_steps += 1
113 |
114 | return self.lr
115 |
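A short sketch (not part of the repository) walking through the three stages with small step counts; with ``final_lr_scale`` chosen so that ``peak_lr * final_lr_scale == final_lr``, the decay stage ends exactly at ``final_lr``:

```python
import torch

from lr_scheduler.tri_stage_lr_scheduler import TriStageLRScheduler

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = torch.optim.Adam(params, lr=1e-10)
scheduler = TriStageLRScheduler(
    optimizer,
    init_lr=1e-4,
    peak_lr=1e-3,
    final_lr=1e-5,
    init_lr_scale=0.01,
    final_lr_scale=0.01,
    warmup_steps=10,
    hold_steps=20,
    decay_steps=30,
    total_steps=60,
)

lrs = [scheduler.step() for _ in range(70)]
print(lrs[0])   # 1e-6 -- warmup starts from init_lr * init_lr_scale
print(lrs[15])  # 1e-3 -- hold stage keeps the LR at peak_lr
print(lrs[-1])  # 1e-5 -- after warmup + hold + decay the LR stays at final_lr
```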
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-lr-scheduler
2 | PyTorch implementation of some learning rate schedulers for deep learning researchers.
3 |
4 | ## Usage
5 |
6 | ### [`WarmupReduceLROnPlateauScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/warmup_reduce_lr_on_plateau_scheduler.py)
7 |
8 | - Visualize
9 |
10 | ![WarmupReduceLROnPlateauScheduler](https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/WarmupReduceLROnPlateauScheduler.png)
11 |
12 | - Example code
13 | ```python
14 | import torch
15 |
16 | from lr_scheduler.warmup_reduce_lr_on_plateau_scheduler import WarmupReduceLROnPlateauScheduler
17 |
18 | if __name__ == '__main__':
19 | max_epochs, steps_in_epoch = 10, 10000
20 |
21 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
22 | optimizer = torch.optim.Adam(model, 1e-10)
23 |
24 | scheduler = WarmupReduceLROnPlateauScheduler(
25 | optimizer,
26 | init_lr=1e-10,
27 | peak_lr=1e-4,
28 | warmup_steps=30000,
29 | patience=1,
30 | factor=0.3,
31 | )
32 |
33 | for epoch in range(max_epochs):
34 | for timestep in range(steps_in_epoch):
35 | ...
36 | ...
37 |             if timestep < scheduler.warmup_steps:
38 | scheduler.step()
39 |
40 | val_loss = validate()
41 | scheduler.step(val_loss)
42 | ```
43 |
44 | ### [`TransformerLRScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/transformer_lr_scheduler.py)
45 |
46 | - Visualize
47 |
48 | ![TransformerLRScheduler](https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/TransformerLRScheduler.png)
49 |
50 | - Example code
51 |
52 | ```python
53 | import torch
54 |
55 | from lr_scheduler.transformer_lr_scheduler import TransformerLRScheduler
56 |
57 | if __name__ == '__main__':
58 | max_epochs, steps_in_epoch = 10, 10000
59 |
60 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
61 | optimizer = torch.optim.Adam(model, 1e-10)
62 |
63 | scheduler = TransformerLRScheduler(
64 | optimizer=optimizer,
65 | init_lr=1e-10,
66 | peak_lr=0.1,
67 | final_lr=1e-4,
68 | final_lr_scale=0.05,
69 | warmup_steps=3000,
70 | decay_steps=17000,
71 | )
72 |
73 | for epoch in range(max_epochs):
74 | for timestep in range(steps_in_epoch):
75 | ...
76 | ...
77 | scheduler.step()
78 | ```
79 |
80 | ### [`TriStageLRScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/tri_stage_lr_scheduler.py)
81 |
82 | - Visualize
83 |
84 | ![TriStageLRScheduler](https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/TriStageLRScheduler.png)
85 |
86 | - Example code
87 |
88 | ```python
89 | import torch
90 |
91 | from lr_scheduler.tri_stage_lr_scheduler import TriStageLRScheduler
92 |
93 | if __name__ == '__main__':
94 | max_epochs, steps_in_epoch = 10, 10000
95 |
96 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
97 | optimizer = torch.optim.Adam(model, 1e-10)
98 |
99 | scheduler = TriStageLRScheduler(
100 | optimizer,
101 | init_lr=1e-10,
102 | peak_lr=1e-4,
103 | final_lr=1e-7,
104 | init_lr_scale=0.01,
105 | final_lr_scale=0.05,
106 | warmup_steps=30000,
107 | hold_steps=70000,
108 | decay_steps=100000,
109 | total_steps=200000,
110 | )
111 |
112 | for epoch in range(max_epochs):
113 | for timestep in range(steps_in_epoch):
114 | ...
115 | ...
116 | scheduler.step()
117 | ```
118 |
119 | ### [`ReduceLROnPlateauScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/reduce_lr_on_plateau_lr_scheduler.py)
120 |
121 | - Visualize
122 |
123 | ![ReduceLROnPlateauScheduler](https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/ReduceLROnPlateauScheduler.png)
124 |
125 | - Example code
126 |
127 | ```python
128 | import torch
129 |
130 | from lr_scheduler.reduce_lr_on_plateau_lr_scheduler import ReduceLROnPlateauScheduler
131 |
132 | if __name__ == '__main__':
133 | max_epochs, steps_in_epoch = 10, 10000
134 |
135 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
136 | optimizer = torch.optim.Adam(model, 1e-4)
137 |
138 |     scheduler = ReduceLROnPlateauScheduler(optimizer, lr=1e-4, patience=1, factor=0.3)
139 |
140 | for epoch in range(max_epochs):
141 | for timestep in range(steps_in_epoch):
142 | ...
143 | ...
144 |
145 | val_loss = validate()
146 | scheduler.step(val_loss)
147 | ```
148 |
149 |
150 |
151 | ### [`WarmupLRScheduler`](https://github.com/sooftware/pytorch-lr-scheduler/blob/main/lr_scheduler/warmup_lr_scheduler.py)
152 |
153 | - Visualize
154 |
155 | ![WarmupLRScheduler](https://raw.githubusercontent.com/sooftware/pytorch-lr-scheduler/HEAD/images/WarmupLRScheduler.png)
156 |
157 | - Example code
158 |
159 | ```python
160 | import torch
161 |
162 | from lr_scheduler.warmup_lr_scheduler import WarmupLRScheduler
163 |
164 | if __name__ == '__main__':
165 | max_epochs, steps_in_epoch = 10, 10000
166 |
167 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
168 | optimizer = torch.optim.Adam(model, 1e-10)
169 |
170 | scheduler = WarmupLRScheduler(
171 | optimizer,
172 | init_lr=1e-10,
173 | peak_lr=1e-4,
174 | warmup_steps=4000,
175 | )
176 |
177 | for epoch in range(max_epochs):
178 | for timestep in range(steps_in_epoch):
179 | ...
180 | ...
181 | scheduler.step()
182 | ```
183 |
184 |
185 | ## Installation
186 | ```bash
187 | git clone git@github.com:sooftware/pytorch-lr-scheduler.git
188 | cd pytorch-lr-scheduler
189 | pip install .
190 | ```
191 |
192 | ## Troubleshoots and Contributing
193 | If you have any questions, bug reports, or feature requests, please [open an issue](https://github.com/sooftware/pytorch-lr-scheduler/issues) on GitHub.
194 |
195 | I appreciate any kind of feedback or contribution. Feel free to proceed directly with small issues like bug fixes and documentation improvements. For major contributions and new features, please discuss them with the collaborators in the corresponding issues first.
196 |
197 | ## Code Style
198 | I follow [PEP 8](https://www.python.org/dev/peps/pep-0008/) for code style. The docstring style matters in particular, since documentation is generated from it.
199 |
200 | ## License
201 | This project is licensed under the MIT License - see the [LICENSE](https://github.com/sooftware/pytorch-lr-scheduler/blob/master/LICENSE) file for details.
202 |
--------------------------------------------------------------------------------