├── .github
│   └── workflows
│       ├── download-emnist.yml
│       ├── example-cifar10.yml
│       ├── example-emnist.yml
│       ├── example-plots.yml
│       ├── pythonpackage.yml
│       └── sphinx-gh-pages.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── docs
│   ├── Makefile
│   ├── _templates
│   │   └── versions.html
│   ├── build.sh
│   ├── conf.py
│   ├── gh-pages-redirect.html
│   ├── index.rst
│   ├── make.bat
│   ├── manual_warmup.rst
│   ├── radam_warmup.rst
│   └── untuned_warmup.rst
├── examples
│   ├── cifar10
│   │   ├── README.md
│   │   ├── download.py
│   │   ├── figs
│   │   │   ├── fig-history-adamw-w-radam-warmup.png
│   │   │   ├── fig-resnet-sgd-scores-best.png
│   │   │   ├── fig-resnet-sgd-scores-mean-std.png
│   │   │   ├── fig-scores-adamax-w-warmup-lr0-1.png
│   │   │   ├── fig-scores-adamax-w-warmup-vs-sgd.png
│   │   │   ├── fig-scores-adamw-w-warmup-edf0-99-vs-radamw.png
│   │   │   ├── fig-scores-adamw-w-warmup-edf0-99-vs-sgd.png
│   │   │   ├── fig-scores-adamw-w-warmup-vs-sgd.png
│   │   │   ├── fig-scores-amsgradw-w-warmup-vs-sgd.png
│   │   │   ├── fig-scores-nadamw-w-warmup-vs-sgd.png
│   │   │   └── fig-scores-no-warmup.png
│   │   ├── main.py
│   │   └── plot.py
│   ├── emnist
│   │   ├── README.md
│   │   ├── download.py
│   │   ├── figs
│   │   │   ├── accuracy.png
│   │   │   ├── fig-history-adamw-w-radam-warmup.png
│   │   │   └── learning_rate.png
│   │   ├── main.py
│   │   └── plot.py
│   └── plots
│       ├── README.md
│       ├── effective_warmup_period.py
│       ├── figs
│       │   ├── warmup_period.png
│       │   └── warmup_schedule.png
│       └── warmup_schedule.py
├── pyproject.toml
├── pytorch_warmup
│   ├── __init__.py
│   ├── base.py
│   ├── radam.py
│   └── untuned.py
└── test
    ├── __init__.py
    ├── test_base.py
    ├── test_radam.py
    └── test_untuned.py
/.github/workflows/download-emnist.yml:
--------------------------------------------------------------------------------
1 | name: Download emnist
2 |
3 | on: [push]
4 |
5 | jobs:
6 | download:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | max-parallel: 8
11 | matrix:
12 | torchvision-version: [0.10.1, 0.11.3, 0.12.0, 0.13.1, 0.14.1, 0.15.2, 0.16.2, 0.17.2, 0.18.1, 0.19.1, 0.20.1]
13 | include:
14 | - pytorch-version: 1.9.1
15 | torchvision-version: 0.10.1
16 | - pytorch-version: 1.10.2
17 | torchvision-version: 0.11.3
18 | - pytorch-version: 1.11.0
19 | torchvision-version: 0.12.0
20 | - pytorch-version: 1.12.1
21 | torchvision-version: 0.13.1
22 | - pytorch-version: 1.13.1
23 | torchvision-version: 0.14.1
24 | - pytorch-version: 2.0.1
25 | torchvision-version: 0.15.2
26 | - pytorch-version: 2.1.2
27 | torchvision-version: 0.16.2
28 | - pytorch-version: 2.2.2
29 | torchvision-version: 0.17.2
30 | - pytorch-version: 2.3.1
31 | torchvision-version: 0.18.1
32 | - pytorch-version: 2.4.1
33 | torchvision-version: 0.19.1
34 | - pytorch-version: 2.5.1
35 | torchvision-version: 0.20.1
36 | steps:
37 | - uses: actions/checkout@v4
38 | - name: Set up Python 3.9
39 | uses: actions/setup-python@v5
40 | with:
41 | python-version: 3.9
42 | - name: Install dependencies
43 | run: |
44 | python -m pip install --upgrade pip
45 | pip install 'numpy<2'
46 | pip install torch==${{ matrix.pytorch-version }}+cpu -f https://download.pytorch.org/whl/torch
47 | pip install torchvision==${{ matrix.torchvision-version }}+cpu -f https://download.pytorch.org/whl/torchvision
48 | pip install setuptools
49 | pip install requests
50 | - name: Install package
51 | run: python -m pip install .
52 | - name: Download EMNIST dataset
53 | run: python examples/emnist/download.py
54 | - name: Extract EMNIST dataset
55 | run: python examples/emnist/main.py --epochs 0
56 |
--------------------------------------------------------------------------------
/.github/workflows/example-cifar10.yml:
--------------------------------------------------------------------------------
1 | name: Example cifar10
2 |
3 | on: [push]
4 |
5 | jobs:
6 | train:
7 |
8 | runs-on: ${{ matrix.os }}
9 | strategy:
10 | max-parallel: 8
11 | matrix:
12 | python-version: [3.9, '3.10', 3.11, 3.12]
13 | os: [macos-latest, windows-latest, ubuntu-latest]
14 | include:
15 | - pytorch-version: 2.3.1
16 | torchvision-version: 0.18.1
17 | - pytorch-option: '+cpu'
18 | - pytorch-option: ''
19 | os: macos-latest
20 |
21 | steps:
22 | - uses: actions/checkout@v4
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v5
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | pip install 'numpy<2'
31 | pip install torch==${{ matrix.pytorch-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torch
32 | pip install torchvision==${{ matrix.torchvision-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torchvision
33 | pip install setuptools
34 | pip install requests
35 | pip install tqdm
36 | - name: Install package
37 | run: python -m pip install .
38 | - name: Download a ResNet implementation
39 | run: |
40 | cd examples/cifar10/
41 | python download.py
42 | - name: Train a ResNet20 model on CIFAR10 dataset
43 | run: python examples/cifar10/main.py --epochs 1 --no-progress --no-gpu
44 |
--------------------------------------------------------------------------------
/.github/workflows/example-emnist.yml:
--------------------------------------------------------------------------------
1 | name: Example emnist
2 |
3 | on: [push]
4 |
5 | jobs:
6 | train:
7 |
8 | runs-on: ${{ matrix.os }}
9 | strategy:
10 | max-parallel: 8
11 | matrix:
12 | python-version: [3.9, '3.10', 3.11, 3.12]
13 | os: [macos-latest, windows-latest, ubuntu-latest]
14 | include:
15 | - pytorch-version: 1.9.1
16 | torchvision-version: 0.10.1
17 | python-version: 3.9
18 | - pytorch-version: 1.9.0
19 | torchvision-version: 0.10.0
20 | python-version: 3.9
21 | os: macos-latest
22 | - pytorch-version: 1.11.0
23 | torchvision-version: 0.12.0
24 | python-version: '3.10'
25 | - pytorch-version: 2.0.1
26 | torchvision-version: 0.15.2
27 | python-version: 3.11
28 | - pytorch-version: 2.2.2
29 | torchvision-version: 0.17.2
30 | python-version: 3.12
31 | - pytorch-option: '+cpu'
32 | - pytorch-option: ''
33 | os: macos-latest
34 |
35 | steps:
36 | - uses: actions/checkout@v4
37 | - name: Set up Python ${{ matrix.python-version }}
38 | uses: actions/setup-python@v5
39 | with:
40 | python-version: ${{ matrix.python-version }}
41 | - name: Install dependencies
42 | run: |
43 | python -m pip install --upgrade pip
44 | pip install 'numpy<2'
45 | pip install torch==${{ matrix.pytorch-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torch
46 | pip install torchvision==${{ matrix.torchvision-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torchvision
47 | pip install setuptools
48 | pip install requests
49 | - name: Install package
50 | run: python -m pip install .
51 | - name: Download EMNIST dataset
52 | run: python examples/emnist/download.py
53 | - name: Train a model on EMNIST dataset
54 | run: python examples/emnist/main.py --epochs 1 --no-gpu
55 |
--------------------------------------------------------------------------------
/.github/workflows/example-plots.yml:
--------------------------------------------------------------------------------
1 | name: Example plots
2 |
3 | on: [push]
4 |
5 | jobs:
6 | plot:
7 |
8 | runs-on: ${{ matrix.os }}
9 | strategy:
10 | max-parallel: 8
11 | matrix:
12 | python-version: [3.9, '3.10', 3.11, 3.12]
13 | os: [macos-latest, windows-latest, ubuntu-latest]
14 | include:
15 | - pytorch-version: 1.9.1
16 | python-version: 3.9
17 | - pytorch-version: 1.11.0
18 | python-version: '3.10'
19 | - pytorch-version: 2.0.1
20 | python-version: 3.11
21 | - pytorch-version: 2.2.2
22 | python-version: 3.12
23 | - pytorch-option: '+cpu'
24 | - pytorch-option: ''
25 | os: macos-latest
26 |
27 | steps:
28 | - uses: actions/checkout@v4
29 | - name: Set up Python ${{ matrix.python-version }}
30 | uses: actions/setup-python@v5
31 | with:
32 | python-version: ${{ matrix.python-version }}
33 | - name: Install dependencies
34 | run: |
35 | python -m pip install --upgrade pip
36 | pip install 'numpy<2'
37 | pip install torch==${{ matrix.pytorch-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torch
38 | pip install matplotlib
39 | pip install setuptools
40 | - name: Install package
41 | run: python -m pip install .
42 | - name: Preparation
43 | run: mkdir artifact
44 | - name: Plot warmup period
45 | run: |
46 | cd artifact
47 | python ../examples/plots/effective_warmup_period.py --output png
48 | - name: Plot warmup schedule
49 | run: |
50 | cd artifact
51 | python ../examples/plots/warmup_schedule.py --output png
52 | - uses: actions/upload-artifact@v4
53 | with:
54 | name: artifact_${{ matrix.os }}_${{ matrix.python-version }}
55 | path: artifact
56 |
--------------------------------------------------------------------------------
/.github/workflows/pythonpackage.yml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ${{ matrix.os }}
9 | strategy:
10 | max-parallel: 8
11 | matrix:
12 | python-version: [3.9, '3.10', 3.11, 3.12]
13 | os: [macos-latest, windows-latest, ubuntu-latest]
14 | pytorch-release: [earliest, latest]
15 | include:
16 | - pytorch-version: 1.9.1
17 | python-version: 3.9
18 | pytorch-release: earliest
19 | - pytorch-version: 1.11.0
20 | python-version: '3.10'
21 | pytorch-release: earliest
22 | - pytorch-version: 2.0.1
23 | python-version: 3.11
24 | pytorch-release: earliest
25 | - pytorch-version: 2.2.2
26 | python-version: 3.12
27 | pytorch-release: earliest
28 | - pytorch-version: 2.5.1
29 | pytorch-release: latest
30 | - pytorch-option: '+cpu'
31 | - pytorch-option: ''
32 | os: macos-latest
33 |
34 | steps:
35 | - uses: actions/checkout@v4
36 | - name: Set up Python ${{ matrix.python-version }}
37 | uses: actions/setup-python@v5
38 | with:
39 | python-version: ${{ matrix.python-version }}
40 | - name: Install dependencies
41 | run: |
42 | python -m pip install --upgrade pip
43 | pip install 'numpy<2'
44 | pip install torch==${{ matrix.pytorch-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torch
45 | - name: Lint with flake8
46 | run: |
47 | pip install flake8
48 | # stop the build if there are Python syntax errors or undefined names
49 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
50 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
51 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
52 | - name: Test with pytest (naked)
53 | run: |
54 | pip install pytest
55 | pytest test -s -vv
56 | - name: Test with pytest (wrapped)
57 | run: pytest test -s -vv
58 | env:
59 | WRAPPED_LR: "1"
60 | if: matrix.pytorch-release == 'latest'
61 | - name: Package with build
62 | run: |
63 | pip install setuptools build
64 | python -m build
65 |
--------------------------------------------------------------------------------
/.github/workflows/sphinx-gh-pages.yml:
--------------------------------------------------------------------------------
1 | name: Sphinx GitHub Pages
2 |
3 | on:
4 | # Runs on pushes targeting the default branch
5 | push:
6 | branches: ["master"]
7 |
8 | # Allows you to run this workflow manually from the Actions tab
9 | workflow_dispatch:
10 |
11 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
12 | permissions:
13 | contents: read
14 | pages: write
15 | id-token: write
16 |
17 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
18 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
19 | concurrency:
20 | group: "pages"
21 | cancel-in-progress: false
22 |
23 | jobs:
24 | # Build job
25 | build:
26 | runs-on: ubuntu-latest
27 | steps:
28 | - name: Checkout
29 | uses: actions/checkout@v4
30 | with:
31 | fetch-depth: 0
32 | - name: Set up Python
33 | uses: actions/setup-python@v5
34 | with:
35 | python-version: '3.10'
36 | - name: Install dependencies
37 | run: |
38 | python -m pip install --upgrade pip
39 | pip install 'numpy<2'
40 | pip install torch --index-url https://download.pytorch.org/whl/cpu
41 | pip install sphinx sphinxcontrib-katex sphinx-copybutton sphinx-multiversion sphinx-rtd-theme
42 | - name: Sphinx Build
43 | run: bash docs/build.sh
44 | - name: Upload artifact
45 | uses: actions/upload-pages-artifact@v3
46 | with:
47 | path: build/html
48 |
49 | # Deployment job
50 | deploy:
51 | environment:
52 | name: github-pages
53 | url: ${{ steps.deployment.outputs.page_url }}
54 | runs-on: ubuntu-latest
55 | needs: build
56 | steps:
57 | - name: Deploy to GitHub Pages
58 | id: deployment
59 | uses: actions/deploy-pages@v4
60 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
127 | # Example data
128 | examples/*/data
129 | examples/*/output*
130 | examples/cifar10/resnet.py
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019-2024 Takenori Yamamoto
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include test/__init__.py
2 | include examples/plots/README.md
3 | include examples/plots/effective_warmup_period.py
4 | include examples/plots/warmup_schedule.py
5 | include examples/plots/figs/*.png
6 | include examples/emnist/README.md
7 | include examples/emnist/download.py
8 | include examples/emnist/main.py
9 | include examples/emnist/plot.py
10 | include examples/emnist/figs/*.png
11 | include examples/cifar10/README.md
12 | include examples/cifar10/download.py
13 | include examples/cifar10/main.py
14 | include examples/cifar10/plot.py
15 | include examples/cifar10/figs/*.png
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A PyTorch Extension for Learning Rate Warmup
2 |
3 | This library contains PyTorch implementations of the warmup schedules described in [On the adequacy of untuned warmup for adaptive optimization](https://arxiv.org/abs/1910.04209).
4 |
5 |
6 |
7 | [![Python package](https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg)](https://github.com/Tony-Y/pytorch_warmup/)
8 | [![PyPI version shields.io](https://img.shields.io/pypi/v/pytorch-warmup.svg)](https://pypi.python.org/pypi/pytorch-warmup/)
9 | [![PyPI license](https://img.shields.io/pypi/l/pytorch-warmup.svg)](https://github.com/Tony-Y/pytorch_warmup/blob/master/LICENSE)
10 | [![Python versions](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org)
11 |
12 | ## Installation
13 |
14 | Make sure you have Python 3.9+ and PyTorch 1.9+ or 2.x. Then, run the following command in the project directory:
15 |
16 | ```shell
17 | python -m pip install .
18 | ```
19 |
20 | or install the latest version from the Python Package Index:
21 |
22 | ```shell
23 | pip install -U pytorch_warmup
24 | ```
25 |
26 | ## Examples
27 |
28 | * [CIFAR10](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/cifar10) -
29 | A sample script to train a ResNet model on the CIFAR10 dataset using an optimization algorithm with a warmup schedule.
30 | Its README presents ResNet20 results obtained using each of AdamW, NAdamW, AMSGradW, and AdaMax
31 | together with each of several warmup schedules.
32 | In addition, there is a ResNet performance comparison (up to ResNet110) obtained using the SGD algorithm
33 | with a linear warmup schedule.
34 | * [EMNIST](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/emnist) -
35 | A sample script to train a CNN model on the EMNIST dataset using the AdamW algorithm with a warmup schedule.
36 | Its README presents a result obtained using the AdamW algorithm with each of the untuned linear and exponential warmup,
37 | and the RAdam warmup.
38 | * [Plots](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/plots) -
39 | A script to plot effective warmup periods as a function of β₂, and warmup schedules over time.
40 |
41 | ## Usage
42 |
43 | The [documentation](https://tony-y.github.io/pytorch_warmup/master/) provides more detailed information on this library beyond what is covered below.
44 |
45 | ### Sample Codes
46 |
47 | The scheduled learning rate is dampened by multiplication with the warmup factor `w(t)`; in effect, the learning rate used by the optimizer at each step is `w(t)` times the rate set by the LR scheduler.
48 |
49 |
50 |
51 | #### Approach 1
52 |
53 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach1_chaining.ipynb)
54 |
55 | When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used
56 | together with `Adam` or its variant (`AdamW`, `NAdam`, etc.) as follows:
57 |
58 | ```python
59 | import torch
60 | import pytorch_warmup as warmup
61 |
62 | optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
63 | # This sample code uses the AdamW optimizer.
64 | num_steps = len(dataloader) * num_epochs
65 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
66 | # The LR schedule initialization resets the initial LR of the optimizer.
67 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
68 | # The warmup schedule initialization dampens the initial LR of the optimizer.
69 | for epoch in range(1,num_epochs+1):
70 | for batch in dataloader:
71 | optimizer.zero_grad()
72 | loss = ...
73 | loss.backward()
74 | optimizer.step()
75 | with warmup_scheduler.dampening():
76 | lr_scheduler.step()
77 | ```
78 |
79 | > [!Warning]
80 | > Note that the warmup schedule must not be initialized before the initialization of the learning rate schedule.
81 |
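Conceptually, `dampening()` restores the undampened learning rates on entry so that the wrapped LR scheduler operates on the raw values, and re-applies the warmup factor on exit. A minimal sketch of this behavior (illustrative only, not the library's actual implementation; the `DampeningSketch` class and the fixed linear warmup are assumptions made for this example):

```python
from contextlib import contextmanager

class DampeningSketch:
    def __init__(self, optimizer, warmup_period=2000):
        self.optimizer = optimizer
        self.warmup_period = warmup_period
        self.last_step = 0
        # remember the undampened learning rates, then dampen the initial LR
        self._lrs = [group['lr'] for group in optimizer.param_groups]
        self._advance()

    def _advance(self):
        self.last_step += 1
        w = min(1.0, self.last_step / self.warmup_period)  # linear warmup factor
        for group, lr in zip(self.optimizer.param_groups, self._lrs):
            group['lr'] = lr * w

    @contextmanager
    def dampening(self):
        # expose the undampened LRs so that chained LR schedulers see the raw values
        for group, lr in zip(self.optimizer.param_groups, self._lrs):
            group['lr'] = lr
        yield
        # record the (possibly rescheduled) LRs, then advance and re-dampen
        self._lrs = [group['lr'] for group in self.optimizer.param_groups]
        self._advance()
```
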
82 | If you want to use the learning rate schedule *chaining*, which is supported for PyTorch 1.4 or above, you can simply invoke the learning rate schedulers inside the suite of the `with` statement:
83 |
84 | ```python
85 | lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
86 | lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
87 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
88 | for epoch in range(1,num_epochs+1):
89 | for batch in dataloader:
90 | ...
91 | optimizer.step()
92 | with warmup_scheduler.dampening():
93 | lr_scheduler1.step()
94 | lr_scheduler2.step()
95 | ```
96 |
97 | If you want to start the learning rate schedule after the end of the linear warmup, delay it by the warmup period:
98 |
99 | ```python
100 | warmup_period = 2000
101 | num_steps = len(dataloader) * num_epochs - warmup_period
102 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
103 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period)
104 | for epoch in range(1,num_epochs+1):
105 | for batch in dataloader:
106 | ...
107 | optimizer.step()
108 | with warmup_scheduler.dampening():
109 | if warmup_scheduler.last_step + 1 >= warmup_period:
110 | lr_scheduler.step()
111 | ```
112 |
113 | #### Approach 2
114 |
115 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach2_chaining.ipynb)
116 |
117 | When the learning rate schedule uses the epoch number, the warmup schedule can be used as follows:
118 |
119 | ```python
120 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs//3], gamma=0.1)
121 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
122 | for epoch in range(1,num_epochs+1):
123 | for i, batch in enumerate(dataloader):
124 | optimizer.zero_grad()
125 | loss = ...
126 | loss.backward()
127 | optimizer.step()
128 | if i < len(dataloader)-1:
129 | with warmup_scheduler.dampening():
130 | pass
131 | with warmup_scheduler.dampening():
132 | lr_scheduler.step()
133 | ```
134 |
135 | This code can be rewritten more compactly:
136 |
137 | ```python
138 | for epoch in range(1,num_epochs+1):
139 | for i, batch in enumerate(dataloader):
140 | optimizer.zero_grad()
141 | loss = ...
142 | loss.backward()
143 | optimizer.step()
144 | with warmup_scheduler.dampening():
145 | if i + 1 == len(dataloader):
146 | lr_scheduler.step()
147 | ```
148 |
149 | #### Approach 3
150 |
151 | When you use `CosineAnnealingWarmRestarts`, the warmup schedule can be used as follows:
152 |
153 | ```python
154 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
155 | warmup_period = 2000
156 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period)
157 | iters = len(dataloader)
158 | warmup_epochs = ... # for example, (warmup_period + iters - 1) // iters
159 | for epoch in range(epochs+warmup_epochs):
160 | for i, batch in enumerate(dataloader):
161 | optimizer.zero_grad()
162 | loss = ...
163 | loss.backward()
164 | optimizer.step()
165 | with warmup_scheduler.dampening():
166 | if epoch >= warmup_epochs:
167 | lr_scheduler.step(epoch-warmup_epochs + i / iters)
168 | ```
169 |
170 | ### Warmup Schedules
171 |
172 | #### Manual Warmup
173 |
174 | In `LinearWarmup` and `ExponentialWarmup`, the warmup factor `w(t)` depends on the warmup period, which must be specified manually.
175 |
176 | ##### Linear
177 |
178 | `w(t) = min(1, t / warmup_period)`
179 |
180 | ```python
181 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=2000)
182 | ```
183 |
184 | For details please refer to [LinearWarmup](https://tony-y.github.io/pytorch_warmup/master/manual_warmup.html#pytorch_warmup.base.LinearWarmup) in the documentation.
185 |
186 | ##### Exponential
187 |
188 | `w(t) = 1 - exp(-t / warmup_period)`
189 |
190 | ```python
191 | warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=1000)
192 | ```
193 |
194 | For details please refer to [ExponentialWarmup](https://tony-y.github.io/pytorch_warmup/master/manual_warmup.html#pytorch_warmup.base.ExponentialWarmup) in the documentation.
195 |
196 | #### Untuned Warmup
197 |
198 | In `UntunedLinearWarmup` and `UntunedExponentialWarmup`, the warmup period is determined by a function of Adam's `beta2` parameter.
199 |
200 | ##### Linear
201 |
202 | `warmup_period = 2 / (1 - beta2)`
203 |
204 | ```python
205 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
206 | ```
207 |
208 | For details please refer to [UntunedLinearWarmup](https://tony-y.github.io/pytorch_warmup/master/untuned_warmup.html#pytorch_warmup.untuned.UntunedLinearWarmup) in the documentation.
209 |
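For instance, with the common default `betas=(0.9, 0.999)`, this rule gives a warmup period of `2 / (1 - 0.999) = 2000` steps, so the untuned linear warmup is effectively a manual `LinearWarmup` with that period (a sketch reusing `optimizer` and `warmup` from the snippets above):

```python
beta2 = 0.999                             # Adam's second-moment decay rate
warmup_period = round(2 / (1 - beta2))    # = 2000 steps
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=warmup_period)
# ... effectively equivalent to warmup.UntunedLinearWarmup(optimizer) for this beta2
```
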
210 | ##### Exponential
211 |
212 | `warmup_period = 1 / (1 - beta2)`
213 |
214 | ```python
215 | warmup_scheduler = warmup.UntunedExponentialWarmup(optimizer)
216 | ```
217 |
218 | For details please refer to [UntunedExponentialWarmup](https://tony-y.github.io/pytorch_warmup/master/untuned_warmup.html#pytorch_warmup.untuned.UntunedExponentialWarmup) in the documentation.
219 |
220 | #### RAdam Warmup
221 |
222 | In `RAdamWarmup`, the warmup factor `w(t)` is a complicated function depending on Adam's `beta2` parameter.
223 |
224 | ```python
225 | warmup_scheduler = warmup.RAdamWarmup(optimizer)
226 | ```
227 |
228 | For details please refer to [RAdamWarmup](https://tony-y.github.io/pytorch_warmup/master/radam_warmup.html#pytorch_warmup.radam.RAdamWarmup) in the documentation, or
229 | "[On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)."
230 |
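For reference, the rectification term defined in that paper, on which this warmup factor is based, can be computed as follows (a sketch of the paper's formulas, not the library's code; the library's exact handling of the first few steps may differ):

```python
import math

def radam_warmup_factor(step, beta2=0.999):
    # rho_inf and rho_t as defined in "On the Variance of the Adaptive Learning Rate and Beyond"
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    rho_t = rho_inf - 2.0 * step * beta2 ** step / (1.0 - beta2 ** step)
    if rho_t <= 4.0:
        return 0.0  # the variance rectification is undefined for the earliest steps
    numerator = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
    denominator = (rho_inf - 4.0) * (rho_inf - 2.0) * rho_t
    return math.sqrt(numerator / denominator)
```
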
231 | ### Apex's Adam
232 |
233 | The Apex library provides an Adam optimizer tuned for CUDA devices, [FusedAdam](https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedAdam). The FusedAdam optimizer can be used together with any one of the warmup schedules above. For example:
234 |
235 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_FusedAdam.ipynb)
236 |
237 | ```python
238 | optimizer = apex.optimizers.FusedAdam(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
239 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
240 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
241 | ```
242 |
243 | ### Compiled Optimizers
244 |
245 | [Benchmarking results](https://dev-discuss.pytorch.org/t/performance-comparison-between-torch-compile-and-apex-optimizers/2023)
246 | show that the compiled Adam outperforms Apex's Adam.
247 |
248 | > [!Warning]
249 | > PyTorch 2.3 or later is required for using the compiled optimizer with
250 | > a warmup scheduler and/or LR schedulers.
251 | > PyTorch-Warmup 0.2 or earlier is incompatible with the compiled optimizer.
252 |
253 | You can compile the Adam optimizer as follows:
254 |
255 | ```python
256 | model = model.to(device)
257 | optimizer = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.001).to(device))
258 | opt_step = torch.compile(optimizer.step, mode="reduce-overhead")
259 | ```
260 |
261 | > [!Important]
262 | > Wrap the learning rate in a `Tensor`, or `torch.compile` will recompile
263 | > as the value of the learning rate changes.
264 |
265 | Then, the compiled version `opt_step` has to be invoked instead of `optimizer.step`:
266 |
267 | ```python
268 | for epoch in range(1,num_epochs+1):
269 | for batch in dataloader:
270 | optimizer.zero_grad()
271 | loss = ...
272 | loss.backward()
273 | opt_step()
274 | with warmup_scheduler.dampening():
275 | lr_scheduler.step()
276 | ```
277 |
278 | You can also compile other built-in optimizers in the way shown above.
279 |
280 | > [!Note]
281 | > When using the compiled SGD with momentum, its momentum buffer needs
282 | > to be initialized manually. You can find sample code in the CIFAR10 example; a minimal sketch is shown below.
283 |
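A minimal sketch of one way to do this, reusing `model` and `device` from the snippet above (the zero-initialized buffers are an assumption that matches SGD's first-step behavior; see the CIFAR10 example for the project's actual code):

```python
optimizer = torch.optim.SGD(model.parameters(), lr=torch.tensor(0.1).to(device), momentum=0.9)
# Pre-populate the momentum buffers so that the compiled step never sees empty optimizer state.
for group in optimizer.param_groups:
    for p in group['params']:
        optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)
opt_step = torch.compile(optimizer.step, mode="reduce-overhead")
```
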
284 | In practice, you may compile it together with other PyTorch code as follows:
285 |
286 | ```python
287 | @torch.compile(mode="reduce-overhead")
288 | def train_iter_fn(batch):
289 | optimizer.zero_grad()
290 | loss = ...
291 | loss.backward()
292 | optimizer.step()
293 |
294 | for epoch in range(1,num_epochs+1):
295 | for batch in dataloader:
296 | train_iter_fn(batch)
297 | with warmup_scheduler.dampening():
298 | lr_scheduler.step()
299 | ```
300 |
301 | `torch.compile` skips `lr_scheduler.step` even if it is invoked within `train_iter_fn`.
302 | Likewise, you should not compile `warmup_scheduler.dampening`.
303 | You may also use `torch.compiler.disable` to have `torch.compile` skip a function
304 | that updates the learning rate, as follows:
305 |
306 | ```python
307 | @torch.compiler.disable
308 | def update_lr_fn():
309 | with warmup_scheduler.dampening():
310 | lr_scheduler.step()
311 |
312 | @torch.compile(mode="reduce-overhead")
313 | def train_iter_fn(batch):
314 | optimizer.zero_grad()
315 | loss = ...
316 | loss.backward()
317 | optimizer.step()
318 | update_lr_fn()
319 |
320 | for epoch in range(1,num_epochs+1):
321 | for batch in dataloader:
322 | train_iter_fn(batch)
323 | ```
324 |
325 | ## License
326 |
327 | MIT License
328 |
329 | © 2019-2025 Takenori Yamamoto
330 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_templates/versions.html:
--------------------------------------------------------------------------------
1 | {%- if current_version %}
2 |
3 |
4 | Other Versions
5 | v: {{ current_version.name }}
6 |
7 |
8 |
9 | {%- if versions.tags %}
10 |
11 | Tags
12 | {%- for item in versions.tags %}
13 | {{ item.name }}
14 | {%- endfor %}
15 |
16 | {%- endif %}
17 | {%- if versions.branches %}
18 |
19 | Branches
20 | {%- for item in versions.branches %}
21 | {{ item.name }}
22 | {%- endfor %}
23 |
24 | {%- endif %}
25 |
26 |
27 | {%- endif %}
28 |
--------------------------------------------------------------------------------
/docs/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -ex
2 | #
3 | # In the project root directory, run this script to build the HTML document:
4 | # >> ./docs/build.sh
5 | #
6 |
7 | export SMV_BRANCH_WHITELIST="^$(git branch --show-current)$"
8 |
9 | SOURCEDIR=docs
10 | BUILDDIR=build/html
11 |
12 | sphinx-multiversion $SOURCEDIR $BUILDDIR
13 | cp $SOURCEDIR/gh-pages-redirect.html $BUILDDIR/index.html
14 | touch $BUILDDIR/.nojekyll
15 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 | import os
17 | import sys
18 | sys.path.insert(0, os.path.abspath('../'))
19 |
20 |
21 | # -- Project information -----------------------------------------------------
22 |
23 | project = 'PyTorch Warmup'
24 | copyright = '2019-2024, Takenori Yamamoto'
25 | author = 'Takenori Yamamoto'
26 |
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 | 'sphinx.ext.autodoc',
35 | 'sphinx.ext.autosummary',
36 | 'sphinx.ext.doctest',
37 | 'sphinx.ext.intersphinx',
38 | 'sphinx.ext.todo',
39 | 'sphinx.ext.coverage',
40 | 'sphinx.ext.napoleon',
41 | 'sphinx.ext.viewcode',
42 | 'sphinxcontrib.katex',
43 | 'sphinx_copybutton',
44 | 'sphinx_multiversion',
45 | ]
46 |
47 | # Add any paths that contain templates here, relative to this directory.
48 | templates_path = ['_templates']
49 |
50 | # List of patterns, relative to source directory, that match files and
51 | # directories to ignore when looking for source files.
52 | # This pattern also affects html_static_path and html_extra_path.
53 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
54 |
55 |
56 | # -- Options for HTML output -------------------------------------------------
57 |
58 | # The theme to use for HTML and HTML Help pages. See the documentation for
59 | # a list of builtin themes.
60 | #
61 | html_theme = 'sphinx_rtd_theme'
62 |
63 | # Add any paths that contain custom static files (such as style sheets) here,
64 | # relative to this directory. They are copied after the builtin static files,
65 | # so a file named "default.css" will overwrite the builtin "default.css".
66 | # html_static_path = ['_static']
67 |
68 | # Copybutton settings
69 | copybutton_prompt_text = r">>> |\.\.\. |\$ "
70 | copybutton_prompt_is_regexp = True
71 |
72 | # Multiversion settings
73 | smv_tag_whitelist = r'^v(?!0\.1\.0)\d+\.\d+\.\d+$'
74 | if "SMV_BRANCH_WHITELIST" in os.environ:
75 | smv_branch_whitelist = os.environ["SMV_BRANCH_WHITELIST"]
76 |
--------------------------------------------------------------------------------
/docs/gh-pages-redirect.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Redirecting to master branch
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. PyTorch Warmup documentation master file, created by
2 | sphinx-quickstart on Thu Oct 31 14:00:43 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to PyTorch Warmup's documentation!
7 | ==========================================
8 |
9 | This library contains PyTorch implementations of the warmup schedules described in
10 | `On the adequacy of untuned warmup for adaptive optimization
11 | <https://arxiv.org/abs/1910.04209>`_.
12 |
13 | .. image:: https://github.com/Tony-Y/pytorch_warmup/raw/master/examples/plots/figs/warmup_schedule.png
14 | :alt: Warmup schedule
15 | :width: 400
16 | :align: center
17 |
18 | .. image:: https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg
19 | :alt: Python package
20 | :target: https://github.com/Tony-Y/pytorch_warmup/
21 |
22 | .. image:: https://img.shields.io/pypi/v/pytorch-warmup.svg
23 | :alt: PyPI version shields.io
24 | :target: https://pypi.python.org/pypi/pytorch-warmup/
25 |
26 | .. image:: https://img.shields.io/pypi/l/pytorch-warmup.svg
27 | :alt: PyPI license
28 | :target: https://github.com/Tony-Y/pytorch_warmup/blob/master/LICENSE
29 |
30 | .. image:: https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue
31 | :alt: Python versions
32 | :target: https://www.python.org
33 |
34 | Installation
35 | ------------
36 |
37 | Make sure you have Python 3.9+ and PyTorch 1.9+ or 2.x. Then, install the latest version from the Python Package Index:
38 |
39 | .. code-block:: shell
40 |
41 | pip install -U pytorch_warmup
42 |
43 | Examples
44 | --------
45 |
46 | .. image:: https://colab.research.google.com/assets/colab-badge.svg
47 | :alt: Open In Colab
48 | :target: https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach1_chaining.ipynb
49 |
50 | * `CIFAR10 <https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/cifar10>`_ -
51 | A sample script to train a ResNet model on the CIFAR10 dataset using an optimization algorithm with a warmup schedule.
52 | Its README presents ResNet20 results obtained using each of AdamW, NAdamW, AMSGradW, and AdaMax
53 | together with each of several warmup schedules.
54 | In addition, there is a ResNet performance comparison (up to ResNet110) obtained using the SGD algorithm
55 | with a linear warmup schedule.
56 |
57 | * `EMNIST <https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/emnist>`_ -
58 | A sample script to train a CNN model on the EMNIST dataset using the AdamW algorithm with a warmup schedule.
59 | Its README presents a result obtained using the AdamW algorithm with each of the untuned linear and exponential warmup,
60 | and the RAdam warmup.
61 |
62 | * `Plots <https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/plots>`_ -
63 | A script to plot effective warmup periods as a function of :math:`\beta_{2}`, and warmup schedules over time.
64 |
65 | Usage
66 | -----
67 |
68 | When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used
69 | together with :class:`Adam` or its variant (:class:`AdamW`, :class:`NAdam`, etc.) as follows:
70 |
71 | .. code-block:: python
72 |
73 | import torch
74 | import pytorch_warmup as warmup
75 |
76 | optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
77 | # This sample code uses the AdamW optimizer.
78 | num_steps = len(dataloader) * num_epochs
79 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
80 | # The LR schedule initialization resets the initial LR of the optimizer.
81 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
82 | # The warmup schedule initialization dampens the initial LR of the optimizer.
83 | for epoch in range(1,num_epochs+1):
84 | for batch in dataloader:
85 | optimizer.zero_grad()
86 | loss = ...
87 | loss.backward()
88 | optimizer.step()
89 | with warmup_scheduler.dampening():
90 | lr_scheduler.step()
91 |
92 | .. warning::
93 | Note that the warmup schedule must not be initialized before the initialization of the learning rate schedule.
94 |
95 | Other approaches can be found in `README <https://github.com/Tony-Y/pytorch_warmup/blob/master/README.md>`_.
96 |
97 | .. toctree::
98 | :maxdepth: 2
99 | :caption: Contents:
100 |
101 | manual_warmup
102 | untuned_warmup
103 | radam_warmup
104 |
105 |
106 | Indices and tables
107 | ==================
108 |
109 | * :ref:`genindex`
110 | * :ref:`modindex`
111 | * :ref:`search`
112 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/manual_warmup.rst:
--------------------------------------------------------------------------------
1 | Manual Warmup
2 | ==============
3 |
4 | .. automodule:: pytorch_warmup.base
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/radam_warmup.rst:
--------------------------------------------------------------------------------
1 | RAdam Warmup
2 | ============
3 |
4 | .. automodule:: pytorch_warmup.radam
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/untuned_warmup.rst:
--------------------------------------------------------------------------------
1 | Untuned Warmup
2 | ==============
3 |
4 | .. automodule:: pytorch_warmup.untuned
5 | :members:
6 |
--------------------------------------------------------------------------------
/examples/cifar10/README.md:
--------------------------------------------------------------------------------
1 | # CIFAR10 Example
2 |
3 | Requirements: PyTorch 1.12+ or 2.x, `pytorch_warmup`, `torchvision`, and `tqdm`.
4 |
5 | > [!Warning]
6 | > The NAdamW optimization algorithm requires PyTorch 2.1 or later.
7 | > The RAdamW optimization algorithm requires PyTorch 2.3 or later.
8 | > So, we recommend you use PyTorch 2.3.1 or later.
9 |
10 | ## Results
11 |
12 | * The ResNet20 architecture is employed from
13 | [a ResNet implementation for CIFAR10](https://github.com/akamaster/pytorch_resnet_cifar10).
14 | * The initial learning rate $\alpha$ is $10^{-1}$ or $10^{-2}$.
15 | * Various optimization algorithms are used for comparison: SGD, AdaMax, AdamW, AMSGradW, NAdamW, and RAdamW.
16 | The momentum factor of SGD is set to $0.9$. The exponential decay rates of Adam variants are set as
17 | $\beta_{1} = 0.9$ and $\beta_{2} = 0.999$.
18 | * Various warmup schedules are used for comparison.
19 | The untuned linear and exponential warmup are labeled as Linear and Expo, respectively.
20 | The RAdam warmup is labeled as RAdam.
21 | The linear warmup with a warmup period of *n* thousand and the exponential warmup with a constant $\tau$
22 | of *n* thousand are labeled as Linear-*n*k and Expo-*n*k, respectively.
23 | * 5 random seeds are used to sample top-1 scores on the test set.
24 | The mean scores are shown in the figures below. The error bar indicates the standard deviation.
25 | The error band is used for reference purposes.
26 |
27 | ### No Warmup
28 |
29 |
30 |
31 | Top-1 errors of models trained by each optimization algorithm without warmup.
32 |
33 |
34 | ### AdamW
35 |
36 |
37 |
38 | Top-1 errors of models trained by AdamW with each warmup schedule.
39 |
40 |
41 | ### AMSGradW
42 |
43 |
44 |
45 | Top-1 errors of models trained by AMSGradW with each warmup schedule.
46 |
47 |
48 | ### NAdamW
49 |
50 |
51 |
52 | Top-1 errors of models trained by NAdamW with each warmup schedule.
53 |
54 |
55 | ### AdaMax
56 |
57 |
58 |
59 | Top-1 errors of models trained by AdaMax with each warmup schedule for α = 0.01.
60 |
61 |
62 |
63 |
64 | Top-1 errors of models trained by AdaMax with each warmup schedule for α = 0.1.
65 |
66 |
67 | ### AdamW for a smaller β₂
68 |
69 |
70 |
71 | Top-1 errors of models trained by AdamW/RAdamW without warmup or AdamW with each warmup schedule for β₂ = 0.99.
72 |
73 |
74 |
75 |
76 | Top-1 errors of models trained by AdamW with each warmup schedule for β₂ = 0.99.
77 |
78 |
79 | ## Download ResNet for CIFAR10
80 |
81 | Run the Python script `download.py` to download
82 | [a ResNet implementation for CIFAR10](https://github.com/akamaster/pytorch_resnet_cifar10):
83 |
84 | ```shell
85 | python download.py
86 | ```
87 |
88 | This script shows download progress:
89 |
90 | ```
91 | Downloading https://.../resnet.py to ./resnet.py
92 | 100.0%
93 | ```
94 |
95 | ## Train ResNet20 Models
96 |
97 | Run the Python script `main.py` to train a ResNet20 model on the CIFAR10 dataset using the SGD or AdamW algorithm.
98 |
99 | > [!Note]
100 | > Use the `--workers` option to set the number of dataloader workers
101 | > for better performance in GPU training.
102 | > The optimal number depends on your environment:
103 | > for example, 2 for an MPS device and 4 for a CUDA device,
104 | > but it should not exceed the number of available performance cores.
105 | > Note that the initial epoch takes longer than later epochs if the number of workers is greater than 0.
106 |
107 | The training log and the test evaluations are saved to files in the directory specified by the `--output` option:
108 |
109 | * `history.csv` - The training log.
110 | * `evaluation.csv` - The test evaluations during training.
111 | * `cifar10_resnet20.pt` - The best model (saved optionally).
112 |
113 | You can visualize the training result using the Python script `plot.py`:
114 |
115 | ```
116 | python plot.py [path to the directory]
117 | ```
118 |
119 | This plot script requires `pandas` and `matplotlib`.
120 |
121 |
122 |
123 | A training result for the AdamW algorithm with the RAdam warmup.
124 |
125 |
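If you prefer to inspect the logs directly instead of using `plot.py`, the CSV files can be read with `pandas` (a sketch; the output directory name is hypothetical and the exact column names depend on what `main.py` writes, so check them first):

```python
import pandas as pd
import matplotlib.pyplot as plt

history = pd.read_csv('output_adamw_radam/history.csv')  # hypothetical output directory
print(history.columns.tolist())        # inspect the logged columns
history.plot(x=history.columns[0])     # quick look at all logged curves
plt.show()
```
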
126 | ### SGD
127 |
128 | Train a ResNet20 model using the SGD algorithm:
129 |
130 | ```
131 | python main.py --output output_sgd
132 | ```
133 |
134 | ### AdamW
135 |
136 | #### No Warmup
137 |
138 | Train a ResNet20 model using AdamW without warmup:
139 |
140 | ```
141 | python main.py --algorithm adamw --output output_adamw_none
142 | ```
143 |
144 | #### Untuned Exponential Warmup
145 |
146 | Train a ResNet20 model using AdamW with the *Untuned Exponential Warmup* schedule:
147 |
148 | ```
149 | python main.py --algorithm adamw --warmup exponential --output output_adamw_expo
150 | ```
151 |
152 | #### Untuned Linear Warmup
153 |
154 | Train a ResNet20 model using AdamW with the *Untuned Linear Warmup* schedule:
155 |
156 | ```
157 | python main.py --algorithm adamw --warmup linear --output output_adamw_linear
158 | ```
159 |
160 | #### RAdam Warmup
161 |
162 | Train a ResNet20 model using AdamW with the *RAdam Warmup* schedule:
163 |
164 | ```
165 | python main.py --algorithm adamw --warmup radam --output output_adamw_radam
166 | ```
167 |
168 | #### Expo-5k Warmup
169 |
170 | Train a ResNet20 model using AdamW with the *Expo-5k Warmup* schedule:
171 |
172 | ```
173 | python main.py --algorithm adamw --warmup exponential --warmup-period 5000 --output output_adamw_expo-5k
174 | ```
175 |
176 | #### Linear-10k Warmup
177 |
178 | Train a ResNet20 model using AdamW with the *Linear-10k Warmup* schedule:
179 |
180 | ```
181 | python main.py --algorithm adamw --warmup linear --warmup-period 10000 --output output_adamw_linear-10k
182 | ```
183 |
184 | ## Usage
185 |
186 | ```
187 | usage: main.py [-h] [-r ARCH] [-b BS] [-c BS] [-e NE] [-m M [M ...]]
188 | [-a ALGO] [-l LR] [-d WD] [-g B2] [-w WU] [-t TAU]
189 | [-n NW] [-s S] [-i I] [-o PATH] [--save-model]
190 | [--no-progress] [--no-gpu]
191 |
192 | PyTorch CIFAR10 Example
193 |
194 | options:
195 | -h, --help show this help message and exit
196 | -r ARCH, --arch ARCH ResNet architecture for CIFAR10: resnet20 |
197 | resnet32 | resnet44 | resnet56 | resnet110 |
198 | resnet1202 (default: resnet20)
199 | -b BS, --batch-size BS
200 | input batch size for training (default: 128)
201 | -c BS, --test-batch-size BS
202 | input batch size for testing (default: 1000)
203 | -e NE, --epochs NE number of epochs to train (default: 186)
204 | -m M [M ...], --milestones M [M ...]
205 | MultiStepLR's milestones in epoch (default:
206 | [81, 122])
207 | -a ALGO, --algorithm ALGO
208 | optimization algorithm: sgd | adamw | amsgradw
209 | | nadamw | adamax | radamw (default: sgd)
210 | -l LR, --lr LR base learning rate (default: 0.1)
211 | -d WD, --weight-decay WD
212 | weight decay (default: 0.0001)
213 | -g B2, --beta2 B2 Adam's beta2 parameter (default: 0.999)
214 | -w WU, --warmup WU warmup schedule: linear | exponential | radam
215 | | none (default: none)
216 | -t TAU, --warmup-period TAU
217 | linear warmup period or exponential warmup
218 | constant. Set 0 to use the untuned linear or
219 | exponential warmup. (default: 0)
220 | -n NW, --workers NW number of dataloader workers for GPU training
221 | (default: 0)
222 | -s S, --seed S random seed (default: 1)
223 | -i I, --log-interval I
224 | how many batches to wait before logging
225 | training status
226 | -o PATH, --output PATH
227 | path to output directory (default: output)
228 | --save-model for saving the best model
229 | --no-progress disable progress bar
230 | --no-gpu disable GPU training. As default, an MPS or
231 | CUDA device will be used if available.
232 | --compile optimize PyTorch code using TorchDynamo,
233 | AOTAutograd, and TorchInductor
234 | ```
235 |
236 | ```
237 | usage: plot.py [-h] [--output {none,png,pdf}] PATH
238 |
239 | Training History Plot
240 |
241 | positional arguments:
242 | PATH path to the output directory of the training script
243 |
244 | options:
245 | -h, --help show this help message and exit
246 | --output {none,png,pdf}
247 | output file type (default: none)
248 | ```
249 |
250 | ## Supplemental Information
251 |
252 | The model is trained for 186 epochs, and the learning rate decays by a factor of 0.1 at the 81st and 122nd epochs.
253 | The weight decay rate is $10^{-4}$. The batch size is 128.
254 | Random cropping and random horizontal flipping are applied to training data.
255 | PyTorch 2.3.1 is employed only for the RAdamW algorithm; otherwise, PyTorch 2.1.2 is used.
256 | A single P100 GPU is used to accelerate computations.
257 |
258 | The tables below present the top-1 errors depicted in the figures above.
259 |
260 | ### No Warmup
261 |
262 | | Optimizer | α=0.1 | α=0.01 |
263 | | --------- | --------------:| --------------:|
264 | | SGD | `8.23 ± 0.25` | `10.70 ± 0.14` |
265 | | AdaMax | `15.27 ± 0.28` | `8.54 ± 0.26` |
266 | | AdamW | `13.15 ± 0.75` | `9.06 ± 0.42` |
267 | | AMSGradW | `12.56 ± 0.39` | `9.55 ± 0.41` |
268 | | NAdamW | `10.12 ± 0.32` | `8.92 ± 0.29` |
269 | | RAdamW | `8.91 ± 0.25` | `8.82 ± 0.31` |
270 |
271 | ### AdamW
272 |
273 | | Warmup | α=0.1 | α=0.01 |
274 | | ---------- | -------------:| -------------:|
275 | | RAdam | `8.86 ± 0.10` | `8.60 ± 0.14` |
276 | | Expo | `8.67 ± 0.29` | `8.64 ± 0.15` |
277 | | Linear | `8.73 ± 0.21` | `8.81 ± 0.06` |
278 | | Expo-5k | `8.64 ± 0.04` | `8.58 ± 0.28` |
279 | | Linear-10k | `8.52 ± 0.11` | `8.55 ± 0.24` |
280 |
281 | ### AMSGradW
282 |
283 | | Warmup | α=0.1 | α=0.01 |
284 | | ---------- | -------------:| -------------:|
285 | | RAdam | `9.01 ± 0.12` | `9.25 ± 0.23` |
286 | | Expo | `8.93 ± 0.26` | `9.12 ± 0.26` |
287 | | Linear | `9.00 ± 0.28` | `9.16 ± 0.16` |
288 | | Expo-5k | `8.62 ± 0.09` | `8.90 ± 0.12` |
289 | | Linear-10k | `8.62 ± 0.14` | `8.95 ± 0.20` |
290 |
291 | ### NAdamW
292 |
293 | | Warmup | α=0.1 | α=0.01 |
294 | | ---------- | -------------:| -------------:|
295 | | RAdam | `8.72 ± 0.14` | `8.59 ± 0.16` |
296 | | Expo | `8.60 ± 0.28` | `8.51 ± 0.16` |
297 | | Linear | `8.54 ± 0.15` | `8.69 ± 0.19` |
298 | | Expo-5k | `8.43 ± 0.19` | `8.49 ± 0.07` |
299 | | Linear-10k | `8.31 ± 0.20` | `8.56 ± 0.08` |
300 |
301 | ### AdaMax
302 |
303 | | Warmup | α=0.1 | α=0.01 |
304 | | ---------- | --------------:| -------------:|
305 | | None | `15.27 ± 0.28` | `8.54 ± 0.26` |
306 | | RAdam | `14.04 ± 0.33` | `8.42 ± 0.15` |
307 | | Expo | `13.87 ± 0.31` | `8.23 ± 0.09` |
308 | | Linear | `13.77 ± 0.27` | `8.15 ± 0.15` |
309 | | Expo-5k | `13.31 ± 0.19` | `8.17 ± 0.17` |
310 | | Linear-10k | `13.45 ± 0.34` | `8.18 ± 0.19` |
311 |
312 | ### Smaller β₂
313 |
314 | The exponential decay rates of Adam variants are set as $\beta_{1} = 0.9$ and $\beta_{2} = 0.99$.
315 |
316 | #### No Warmup
317 | | Optimizer | α=0.1 | α=0.01 |
318 | | --------- | --------------:| -------------:|
319 | | AdamW | `14.68 ± 0.95` | `9.14 ± 0.25` |
320 | | RAdamW | `10.41 ± 0.33` | `9.13 ± 0.21` |
321 |
322 | #### AdamW with Warmup
323 | | Warmup | α=0.1 | α=0.01 |
324 | | ---------- | --------------:| -------------:|
325 | | RAdam | `10.71 ± 0.49` | `9.11 ± 0.13` |
326 | | Expo | `10.30 ± 0.20` | `9.04 ± 0.28` |
327 | | Linear | `10.34 ± 0.33` | `8.97 ± 0.10` |
328 | | Expo-1k | `9.34 ± 0.20` | `8.98 ± 0.29` |
329 | | Linear-2k | `9.28 ± 0.27` | `8.84 ± 0.23` |
330 | | Expo-5k | `8.60 ± 0.19` | `8.63 ± 0.23` |
331 | | Linear-10k | `8.53 ± 0.09` | `8.43 ± 0.14` |
332 |
333 | ## ResNet Performance Comparison
334 |
335 | We employ the ResNet20, ResNet32, ResNet44, ResNet56, and ResNet110 architectures for comparison.
336 | The SGD with momentum is used as the optimization algorithm.
337 | The momentum factor is set to $0.9$.
338 | The learning rate is $10^{-1}$.
339 | We employ a linear warmup schedule to improve the top-1 score.
340 | The warmup period is set to 1,000.
341 | PyTorch 2.4.0 is used for model training in this performance comparison.
342 | 5 random seeds are used for sampling top-1 scores.
343 | The other implementation details are described in Supplemental Information above.
344 |
345 |
346 |
347 | The best top-1 errors of ResNet models trained by the SGD algorithm without warmup or with the Linear-1k warmup.
348 |
349 |
350 |
351 |
352 | Top-1 errors of ResNet models trained by the SGD algorithm without warmup or with the Linear-1k warmup.
353 | This bar chart presents the mean values. The error bar indicates the standard deviation.
354 |
355 |
356 | ### SGD
357 |
358 | The top-1 errors are shown as mean ± std.
359 |
360 | | Architecture | No Warmup | Linear-1k Warmup |
361 | | ------------ | -------------:| ----------------:|
362 | | ResNet20 | `8.11 ± 0.16` | `8.06 ± 0.21` |
363 | | ResNet32 | `7.93 ± 0.62` | `7.27 ± 0.21` |
364 | | ResNet44 | `7.47 ± 0.50` | `6.99 ± 0.19` |
365 | | ResNet56 | `7.53 ± 1.01` | `6.80 ± 0.12` |
366 | | ResNet110 | `7.76 ± 0.69` | `6.28 ± 0.10` |
367 |
368 | © 2024-2025 Takenori Yamamoto
--------------------------------------------------------------------------------
/examples/cifar10/download.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets.utils import download_url
2 |
3 | # A downloadable URL of resnet.py in a GitHub repo:
4 | # https://github.com/akamaster/pytorch_resnet_cifar10
5 | url = 'https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10/refs/heads/master/resnet.py'
6 | md5 = '9dc255cf8dc64c8b47c2b109c5b28d07'
7 |
8 | download_url(url, root='./', filename=None, md5=md5)
9 |
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-history-adamw-w-radam-warmup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-history-adamw-w-radam-warmup.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-resnet-sgd-scores-best.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-resnet-sgd-scores-best.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-resnet-sgd-scores-mean-std.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-resnet-sgd-scores-mean-std.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-scores-adamax-w-warmup-lr0-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamax-w-warmup-lr0-1.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-scores-adamax-w-warmup-vs-sgd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamax-w-warmup-vs-sgd.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-scores-adamw-w-warmup-edf0-99-vs-radamw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamw-w-warmup-edf0-99-vs-radamw.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-scores-adamw-w-warmup-edf0-99-vs-sgd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamw-w-warmup-edf0-99-vs-sgd.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-scores-adamw-w-warmup-vs-sgd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamw-w-warmup-vs-sgd.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-scores-amsgradw-w-warmup-vs-sgd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-amsgradw-w-warmup-vs-sgd.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-scores-nadamw-w-warmup-vs-sgd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-nadamw-w-warmup-vs-sgd.png
--------------------------------------------------------------------------------
/examples/cifar10/figs/fig-scores-no-warmup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-no-warmup.png
--------------------------------------------------------------------------------
/examples/cifar10/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import time
4 | import argparse
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | from torchvision import datasets, transforms
9 | from tqdm.auto import tqdm
10 | import pytorch_warmup as warmup
11 |
12 | try:
13 | import resnet
14 | except ImportError:
15 | sys.exit('Download resnet.py from https://github.com/akamaster/pytorch_resnet_cifar10')
16 |
17 | import torch.backends.cudnn as cudnn
18 | cudnn.benchmark = True
19 |
20 |
21 | architecture_names = ['resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
22 | algorithm_names = ['sgd', 'adamw', 'amsgradw', 'nadamw', 'adamax', 'radamw']
23 | warmup_names = ['linear', 'exponential', 'radam', 'none']
24 |
25 |
26 | def check_pytorch_version(algorithm):
27 | major, minor, patch = map(int, torch.__version__.split('+')[0].split('.'))
28 | if major == 0 or (major == 1 and minor < 12):
29 | sys.exit('This script requires PyTorch 1.12+ or 2.x.')
30 |
31 | if algorithm == 'nadamw' and (major == 1 or (major == 2 and minor < 1)):
32 | sys.exit('[Error] The NAdamW optimization algorithm requires PyTorch 2.1 or later.')
33 | elif algorithm == 'radamw' and (major == 1 or (major == 2 and minor < 3)):
34 | sys.exit('[Error] The RAdamW optimization algorithm requires PyTorch 2.3 or later.')
35 |
36 |
37 | def get_lr(args, optimizer):
38 | lr = optimizer.param_groups[0]['lr']
39 | return lr.item() if args.compile else lr
40 |
41 |
42 | def train_iter_loss_fn(optimizer, model, data, target):
43 | optimizer.zero_grad()
44 | output = model(data)
45 | loss = F.cross_entropy(output, target)
46 | loss.backward()
47 | optimizer.step()
48 | return loss
49 |
50 |
51 | def update_lr_fn(lr_scheduler, warmup_scheduler):
52 | with warmup_scheduler.dampening():
53 | lr_scheduler.step()
54 |
55 |
56 | def train(args, model, device, train_loader, optimizer, lr_scheduler,
57 | warmup_scheduler, epoch, history):
58 | since = time.time()
59 | model.train()
60 | progress = tqdm(total=len(train_loader), disable=args.no_progress)
61 | progress.set_description(f"[train] Epoch {epoch}")
62 | train_loss = 0
63 | for batch_idx, (data, target) in enumerate(train_loader):
64 | lr = get_lr(args, optimizer)
65 | data, target = data.to(device), target.to(device)
66 | loss = train_iter_loss_fn(optimizer, model, data, target)
67 | update_lr_fn(lr_scheduler, warmup_scheduler)
68 | loss = loss.item()
69 | train_loss += loss
70 | batch_step = batch_idx + 1
71 | if batch_step % args.log_interval == 0:
72 | step = warmup_scheduler.last_step
73 | history.write(f'{epoch},{step},{loss:g},{lr:g}\n')
74 | if not progress.disable:
75 | if batch_step % 10 == 0:
76 | progress.set_postfix_str(f'loss={loss:.2f}, lr={lr:5.4f}')
77 | progress.update()
78 | progress.close()
79 |
80 | train_loss /= len(train_loader)
81 | elapsed = time.time() - since
82 | print(f'[train] Epoch {epoch}: Elapsed Time: {elapsed:.3f} sec, ' +
83 | f'Ave. Loss: {train_loss:.4f}')
84 |
85 |
86 | def test_iter_loss_fn(model, data, target):
87 | output = model(data)
88 | loss = F.cross_entropy(output, target, reduction='sum')
89 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max of unnormalized logits
90 | correct = pred.eq(target.view_as(pred)).sum()
91 | return loss, correct
92 |
93 |
94 | @torch.inference_mode()
95 | def test(args, model, device, test_loader, epoch, evaluation):
96 | since = time.time()
97 | model.eval()
98 | progress = tqdm(test_loader, disable=args.no_progress)
99 | progress.set_description(f"[test] Epoch {epoch}")
100 | test_loss = 0
101 | correct = 0
102 | for data, target in progress:
103 | data, target = data.to(device), target.to(device)
104 | batch_loss, batch_correct = test_iter_loss_fn(model, data, target)
105 | test_loss += batch_loss.item() # sum up batch loss
106 | correct += batch_correct.item() # sum up batch correct
107 |
108 | test_loss /= len(test_loader.dataset)
109 | test_acc = 100. * correct / len(test_loader.dataset)
110 | elapsed = time.time() - since
111 | print(f'[test] Epoch {epoch}: Elapsed Time: {elapsed:.3f} sec, ' +
112 | f'Ave. Loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%')
113 | evaluation.write(f'{epoch},{test_loss:g},{test_acc:.2f}\n')
114 | evaluation.flush()
115 | return test_acc
116 |
117 |
118 | def gpu_device():
119 | if torch.backends.mps.is_available():
120 | return torch.device('mps')
121 | elif torch.cuda.is_available():
122 | return torch.device('cuda')
123 | else:
124 | return torch.device('cpu')
125 |
126 |
127 | def dataloader_options(device, workers):
128 | if device.type == 'cpu':
129 | return {}
130 |
131 | kwargs = dict(num_workers=workers, pin_memory=True)
132 | if workers > 0:
133 | if device.type == 'mps':
134 | kwargs.update(dict(multiprocessing_context="forkserver", persistent_workers=True))
135 | else:
136 | kwargs.update(dict(persistent_workers=True))
137 | return kwargs
138 |
139 |
140 | def optimization_algorithm(args, model, device):
141 | name = args.algorithm
142 | lr = torch.tensor(args.lr).to(device) if args.compile else args.lr
143 | kwargs = dict(lr=lr, weight_decay=args.weight_decay)
144 | if name == 'sgd':
145 | kwargs['momentum'] = 0.9
146 | else:
147 | kwargs['betas'] = (0.9, args.beta2)
148 |
149 | if name == 'sgd':
150 | return optim.SGD(model.parameters(), **kwargs)
151 | elif name == 'adamw':
152 | return optim.AdamW(model.parameters(), **kwargs)
153 | elif name == 'amsgradw':
154 | return optim.AdamW(model.parameters(), amsgrad=True, **kwargs)
155 | elif name == 'nadamw':
156 | return optim.NAdam(model.parameters(), decoupled_weight_decay=True, **kwargs)
157 | elif name == 'adamax':
158 | return optim.Adamax(model.parameters(), **kwargs)
159 | elif name == 'radamw':
160 | return optim.RAdam(model.parameters(), decoupled_weight_decay=True, **kwargs)
161 | else:
162 | raise ValueError(f'unknown optimization algorithm: {name}')
163 |
164 |
165 | def warmup_schedule(optimizer, name, period):
166 | if name == 'linear':
167 | if period == 0:
168 | return warmup.UntunedLinearWarmup(optimizer)
169 | else:
170 | return warmup.LinearWarmup(optimizer, period)
171 | elif name == 'exponential':
172 | if period == 0:
173 | return warmup.UntunedExponentialWarmup(optimizer)
174 | else:
175 | return warmup.ExponentialWarmup(optimizer, period)
176 | elif name == 'radam':
177 | return warmup.RAdamWarmup(optimizer)
178 | elif name == 'none':
179 | return warmup.LinearWarmup(optimizer, 1)
180 | else:
181 | raise ValueError(f'unknown warmup schedule: {name}')
182 |
183 |
184 | def compile_functions():
185 | global train_iter_loss_fn
186 | global test_iter_loss_fn
187 | train_iter_loss_fn = torch.compile(train_iter_loss_fn, mode="reduce-overhead")
188 | test_iter_loss_fn = torch.compile(test_iter_loss_fn, mode="reduce-overhead")
189 |
190 |
191 | def init_momentum_buffer(optimizer):
192 | for group in optimizer.param_groups:
193 | if group["momentum"] != 0:
194 | for p in group["params"]:
195 | state = optimizer.state[p]
196 | if state.get("momentum_buffer") is None:
197 | state["momentum_buffer"] = torch.zeros_like(p.data)
198 |
199 |
200 | def main(args=None):
201 | # Training settings
202 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example')
203 | parser.add_argument('-r', '--arch', type=str, default='resnet20', metavar='ARCH',
204 | choices=architecture_names,
205 | help='ResNet architecture for CIFAR10: ' +
206 | ' | '.join(architecture_names) + ' (default: resnet20)')
207 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='BS',
208 | help='input batch size for training (default: 128)')
209 | parser.add_argument('-c', '--test-batch-size', type=int, default=1000, metavar='BS',
210 | help='input batch size for testing (default: 1000)')
211 | parser.add_argument('-e', '--epochs', type=int, default=186, metavar='NE',
212 | help='number of epochs to train (default: 186)')
213 | parser.add_argument('-m', '--milestones', type=int, nargs='+', default=[81, 122], metavar='M',
214 | help="MultiStepLR's milestones in epoch (default: [81, 122])")
215 | parser.add_argument('-a', '--algorithm', type=str, default='sgd', metavar='ALGO',
216 | choices=algorithm_names,
217 | help='optimization algorithm: ' +
218 | ' | '.join(algorithm_names) + ' (default: sgd)')
219 | parser.add_argument('-l', '--lr', type=float, default=0.1, metavar='LR',
220 | help='base learning rate (default: 0.1)')
221 | parser.add_argument('-d', '--weight-decay', type=float, default=0.0001, metavar='WD',
222 | help='weight decay (default: 0.0001)')
223 | parser.add_argument('-g', '--beta2', type=float, default=0.999, metavar='B2',
224 | help="Adam's beta2 parameter (default: 0.999)")
225 | parser.add_argument('-w', '--warmup', type=str, default='none', metavar='WU',
226 | choices=warmup_names,
227 | help='warmup schedule: ' +
228 | ' | '.join(warmup_names) + ' (default: none)')
229 | parser.add_argument('-t', '--warmup-period', type=int, default=0, metavar='TAU',
230 | help='linear warmup period or exponential warmup constant. ' +
231 | 'Set 0 to use the untuned linear or exponential warmup. (default: 0)')
232 | parser.add_argument('-n', '--workers', type=int, default=0, metavar='NW',
233 | help='number of dataloader workers for GPU training (default: 0)')
234 | parser.add_argument('-s', '--seed', type=int, default=1, metavar='S',
235 | help='random seed (default: 1)')
236 | parser.add_argument('-i', '--log-interval', type=int, default=10, metavar='I',
237 | help='how many batches to wait before logging training status')
238 | parser.add_argument('-o', '--output', default='output', metavar='PATH',
239 | help='path to output directory (default: output)')
240 | parser.add_argument('--save-model', action='store_true', default=False,
241 | help='for saving the best model')
242 | parser.add_argument('--no-progress', action='store_true', default=False,
243 | help='disable progress bar')
244 | parser.add_argument('--no-gpu', action='store_true', default=False,
245 | help='disable GPU training. ' +
246 | 'As default, an MPS or CUDA device will be used if available.')
247 | parser.add_argument('--compile', action='store_true', default=False,
248 | help='optimize PyTorch code using TorchDynamo, AOTAutograd, and TorchInductor')
249 | args = parser.parse_args(args)
250 |
251 | check_pytorch_version(args.algorithm)
252 |
253 | print(args)
254 | device = torch.device('cpu') if args.no_gpu else gpu_device()
255 | print(f'Device: {device.type}')
256 |
257 | torch.manual_seed(args.seed)
258 |
259 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
260 | std=[0.229, 0.224, 0.225])
261 | kwargs = dataloader_options(device, args.workers)
262 | train_loader = torch.utils.data.DataLoader(
263 | datasets.CIFAR10(
264 | 'data', train=True, download=True,
265 | transform=transforms.Compose([
266 | transforms.RandomHorizontalFlip(),
267 | transforms.RandomCrop(32, 4),
268 | transforms.ToTensor(),
269 | normalize,
270 | ])),
271 | batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
272 | test_loader = torch.utils.data.DataLoader(
273 | datasets.CIFAR10(
274 | 'data', train=False,
275 | transform=transforms.Compose([
276 | transforms.ToTensor(),
277 | normalize,
278 | ])),
279 | batch_size=args.test_batch_size, shuffle=False, **kwargs)
280 |
281 | output_dir = args.output
282 | try:
283 | os.makedirs(output_dir, exist_ok=False)
284 | except FileExistsError:
285 | sys.exit(f'[Error] File exists: {output_dir}')
286 |
287 | history = open(os.path.join(output_dir, 'history.csv'), 'w')
288 | history.write('epoch,step,loss,lr\n')
289 |
290 | evaluation = open(os.path.join(output_dir, 'evaluation.csv'), 'w')
291 | evaluation.write('epoch,loss,accuracy\n')
292 |
293 | model = resnet.__dict__[args.arch]().to(device)
294 |
295 | optimizer = optimization_algorithm(args, model, device)
296 |
297 | steps_per_epoch = len(train_loader)
298 | lr_scheduler = optim.lr_scheduler.MultiStepLR(
299 | optimizer,
300 | milestones=[i * steps_per_epoch for i in args.milestones])
301 |
302 | warmup_scheduler = warmup_schedule(optimizer,
303 | name=args.warmup,
304 | period=args.warmup_period)
305 |
306 | if args.compile:
307 | if args.algorithm == 'sgd':
308 | init_momentum_buffer(optimizer)
309 | compile_functions()
310 |
311 | best_acc = 0.0
312 | best_epoch = 0
313 | print()
314 | for epoch in range(1, args.epochs + 1):
315 | train(args, model, device, train_loader, optimizer, lr_scheduler,
316 | warmup_scheduler, epoch, history)
317 | cur_acc = test(args, model, device, test_loader, epoch, evaluation)
318 |
319 | if cur_acc > best_acc:
320 | best_acc = cur_acc
321 | best_epoch = epoch
322 | if args.save_model:
323 | torch.save(model.state_dict(), os.path.join(output_dir, f"cifar10_{args.arch}.pt"))
324 |
325 | print(f"The best accuracy: {best_acc:.2f}% (epoch {best_epoch})")
326 |
327 | history.close()
328 | evaluation.close()
329 |
330 |
331 | if __name__ == '__main__':
332 | main()
333 |
--------------------------------------------------------------------------------
/examples/cifar10/plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def main(args=None):
8 | parser = argparse.ArgumentParser(description='Training History Plot')
9 | parser.add_argument('path', metavar='PATH',
10 | help='path to the output directory of the training script')
11 | parser.add_argument('--output', type=str, default='none',
12 | choices=['none', 'png', 'pdf'],
13 | help='output file type (default: none)')
14 | args = parser.parse_args(args)
15 |
16 | output_dir = args.path
17 |
18 | df_hist = pd.read_csv(os.path.join(output_dir, "history.csv"))
19 | df_eval = pd.read_csv(os.path.join(output_dir, "evaluation.csv"))
20 |
21 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 8), layout="constrained")
22 |
23 | df_hist.plot(x="step", y="lr", ax=ax1, legend=False)
24 | ax1.set_xlabel('Iteration')
25 | ax1.set_ylabel('Learning rate')
26 |
27 | min_loss = df_hist.loss.min()
28 | df_hist.plot(x="step", y="loss", ylim=(0, 1) if min_loss < 0.5 else None, ax=ax2, legend=False)
29 | ax2.set_xlabel('Iteration')
30 | ax2.set_ylabel('Loss')
31 |
32 | df_eval.plot(x="epoch", y="accuracy", ylim=(0, 100), ax=ax3, legend=False)
33 | ax3.set_xlabel('Epoch')
34 | ax3.set_ylabel('Accuracy (%)')
35 | max_idx = df_eval.accuracy.argmax()
36 | max_epoch = df_eval.epoch[max_idx]
37 | max_acc = df_eval.accuracy[max_idx]
38 | ax3.axvline(x=max_epoch, color='red', ls=':')
39 | ax3.text(x=max_epoch, y=5, s=f'Max: {max_acc:.2f}% @ {max_epoch} ', ha='right', va='bottom')
40 | ax3.plot([0], [max_acc])
41 |
42 | if args.output == 'none':
43 | plt.show()
44 | else:
45 | file_path = os.path.join(output_dir, f'fig_history.{args.output}')
46 | plt.savefig(file_path)
47 | print(file_path)
48 |
49 |
50 | if __name__ == '__main__':
51 | main()
52 |
--------------------------------------------------------------------------------
/examples/emnist/README.md:
--------------------------------------------------------------------------------
1 | # EMNIST Example
2 |
3 | Requirements: `pytorch_warmup` and `torchvision`.
4 |
5 | ## Results
6 |
7 |
8 |
9 | *Figure:* Test accuracy over time for each warmup schedule.
10 |
11 |
12 |
13 |
14 | *Figure:* Learning rate over time for each warmup schedule.
15 |
16 |
17 | ## Download EMNIST Dataset
18 |
19 | Run the Python script `download.py` to download the EMNIST dataset:
20 |
21 | ```shell
22 | python download.py
23 | ```
24 |
25 | This script shows download progress:
26 |
27 | ```
28 | Downloading zip archive
29 | Downloading https://.../EMNIST/gzip.zip to data/EMNIST/raw/gzip.zip
30 | 100.0%
31 | ```
32 |
33 | ## Train CNN Models
34 |
35 | Run the Python script `main.py` to train a CNN model on the EMNIST dataset using the AdamW algorithm.
36 |
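Stripped to its essentials, the wiring inside `main.py` is AdamW plus cosine annealing plus one of the warmup schedules below. The following is a rough sketch rather than a verbatim excerpt; it assumes `model`, `train_loader`, and `epochs` are defined as in the script.

```python
import torch.nn.functional as F
import torch.optim as optim
import pytorch_warmup as warmup

# AdamW with the script's default hyperparameters.
optimizer = optim.AdamW(model.parameters(), lr=0.01, betas=(0.9, 0.999), weight_decay=0.01)

# Cosine annealing is stepped once per batch, so T_max is the total number of iterations.
num_steps = len(train_loader) * epochs
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps, eta_min=1e-5)

# Any of the warmup schedules described below can be substituted here.
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

for data, target in train_loader:  # one epoch
    optimizer.zero_grad()
    loss = F.nll_loss(model(data), target)
    loss.backward()
    optimizer.step()
    with warmup_scheduler.dampening():
        lr_scheduler.step()
```
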
37 | > [!Note]
38 | > Use the `--workers` option to set the number of dataloader workers
39 | > for better performance in GPU training.
40 | > The optimal number depends on your environment.
41 | > For example, 2 for an MPS device and 4 for a CUDA device are reasonable choices,
42 | > but it should not exceed the number of available performance cores.
43 | > Note that the initial epoch takes more time than later epochs if the number of workers is greater than 0.
44 |
45 | The training log and the test evaluations are saved to files in a directory named `output_[warmup schedule name]`:
46 |
47 | * `history.csv` - The training log.
48 | * `evaluation.csv` - The test evaluations during training.
49 | * `emnist_cnn.pt` - The current model (saved optionally).
50 |
51 | You can visualize the training result using the Python script `plot.py`:
52 |
53 | ```
54 | python plot.py [path to the directory]
55 | ```
56 |
57 | This plot script requires `pandas` and `matplotlib`.
58 |
59 |
60 |
61 | *Figure:* A training result for the AdamW algorithm with the RAdam warmup.
62 |
63 |
64 | ### Untuned Linear Warmup
65 |
66 | Train a CNN model with the *Untuned Linear Warmup* schedule:
67 |
68 | ```
69 | python main.py --warmup linear
70 | ```
71 |
72 | ### Untuned Exponential Warmup
73 |
74 | Train a CNN model with the *Untuned Exponential Warmup* schedule:
75 |
76 | ```
77 | python main.py --warmup exponential
78 | ```
79 |
80 | ### RAdam Warmup
81 |
82 | Train a CNN model with the *RAdam Warmup* schedule:
83 |
84 | ```
85 | python main.py --warmup radam
86 | ```
87 |
88 | ### No Warmup
89 |
90 | Train a CNN model without warmup:
91 |
92 | ```
93 | python main.py --warmup none
94 | ```
95 |
96 | > [!Warning]
97 | > You may obtain a result very different from the one shown in the figure
98 | > because training without warmup can become significantly unstable at a very early stage.
99 | > Observed accuracies at the last epoch include 2%, 78%, and 86%, among other values.
100 | > The figure's result was obtained on an Apple M1 Pro chip without GPU acceleration.
101 |
102 | ## Usage
103 |
104 | ```
105 | usage: main.py [-h] [--batch-size N] [--test-batch-size N] [--epochs N]
106 | [--lr LR] [--lr-min LM] [--wd WD] [--beta2 B2] [--seed S]
107 | [--log-interval N]
108 | [--warmup {linear,exponential,radam,none}] [--workers N]
109 | [--save-model] [--no-gpu]
110 |
111 | PyTorch EMNIST Example
112 |
113 | options:
114 | -h, --help show this help message and exit
115 | --batch-size N input batch size for training (default: 64)
116 | --test-batch-size N input batch size for testing (default: 1000)
117 | --epochs N number of epochs to train (default: 10)
118 | --lr LR base learning rate (default: 0.01)
119 | --lr-min LM minimum learning rate (default: 1e-5)
120 | --wd WD weight decay (default: 0.01)
121 | --beta2 B2 Adam's beta2 parameter (default: 0.999)
122 | --seed S random seed (default: 1)
123 | --log-interval N how many batches to wait before logging training
124 | status
125 | --warmup {linear,exponential,radam,none}
126 | warmup schedule
127 | --workers N number of dataloader workers for GPU training
128 | (default: 0)
129 | --save-model for saving the current model
130 | --no-gpu disable GPU training. As default, an MPS or CUDA
131 | device will be used if available.
132 | ```
133 |
134 | ```
135 | usage: plot.py [-h] [--output {none,png,pdf}] PATH
136 |
137 | Training History Plot
138 |
139 | positional arguments:
140 | PATH path to the output directory of the training script
141 |
142 | options:
143 | -h, --help show this help message and exit
144 | --output {none,png,pdf}
145 | output file type (default: none)
146 | ```
147 |
148 | © 2024 Takenori Yamamoto
--------------------------------------------------------------------------------
/examples/emnist/download.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torchvision.datasets.utils import download_url
3 | import os
4 |
5 | raw_folder = 'data/EMNIST/raw'
6 |
7 | url = 'https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
8 | md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
9 |
10 | version_numbers = list(map(int, torchvision.__version__.split('+')[0].split('.')))
11 | if version_numbers[0] == 0 and version_numbers[1] < 10:
12 | filename = "emnist.zip"
13 | else:
14 | filename = None
15 |
16 | os.makedirs(raw_folder, exist_ok=True)
17 |
18 | # download files
19 | print('Downloading zip archive')
20 | download_url(url, root=raw_folder, filename=filename, md5=md5)
21 |
--------------------------------------------------------------------------------
/examples/emnist/figs/accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/emnist/figs/accuracy.png
--------------------------------------------------------------------------------
/examples/emnist/figs/fig-history-adamw-w-radam-warmup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/emnist/figs/fig-history-adamw-w-radam-warmup.png
--------------------------------------------------------------------------------
/examples/emnist/figs/learning_rate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/emnist/figs/learning_rate.png
--------------------------------------------------------------------------------
/examples/emnist/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torchvision import datasets, transforms
7 |
8 | import pytorch_warmup as warmup
9 | import os
10 | import sys
11 | import time
12 |
13 |
14 | class Net(nn.Module):
15 | def __init__(self):
16 | super(Net, self).__init__()
17 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
18 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
19 | self.fc1 = nn.Linear(4*4*50, 500)
20 | self.fc2 = nn.Linear(500, 47)
21 |
22 | def forward(self, x):
23 | x = F.relu(self.conv1(x))
24 | x = F.max_pool2d(x, 2, 2)
25 | x = F.relu(self.conv2(x))
26 | x = F.max_pool2d(x, 2, 2)
27 | x = x.view(-1, 4*4*50)
28 | x = F.relu(self.fc1(x))
29 | x = self.fc2(x)
30 | return F.log_softmax(x, dim=1)
31 |
32 |
33 | def train(args, model, device, train_loader, optimizer, lr_scheduler,
34 | warmup_scheduler, epoch, history):
35 | since = time.time()
36 | model.train()
37 | for batch_idx, (data, target) in enumerate(train_loader):
38 | lr = optimizer.param_groups[0]['lr']
39 | data, target = data.to(device), target.to(device)
40 | optimizer.zero_grad()
41 | output = model(data)
42 | loss = F.nll_loss(output, target)
43 | loss.backward()
44 | optimizer.step()
45 | with warmup_scheduler.dampening():
46 | lr_scheduler.step()
47 | if (batch_idx+1) % args.log_interval == 0:
48 | loss = loss.item()
49 | step = warmup_scheduler.last_step
50 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} LR: {:.6f}'.format(
51 | epoch, (batch_idx+1) * len(data), len(train_loader) * len(data),
52 | 100. * (batch_idx+1) / len(train_loader), loss, lr))
53 | history.write(f'{epoch},{step},{loss:g},{lr:g}\n')
54 | print('Train Elapsed Time: {:.3f} sec'.format(time.time()-since))
55 |
56 |
57 | def test(args, model, device, test_loader, epoch, evaluation):
58 | since = time.time()
59 | model.eval()
60 | test_loss = 0
61 | correct = 0
62 | with torch.no_grad():
63 | for data, target in test_loader:
64 | data, target = data.to(device), target.to(device)
65 | output = model(data)
66 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
67 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
68 | correct += pred.eq(target.view_as(pred)).sum().item()
69 |
70 | test_loss /= len(test_loader.dataset)
71 | test_acc = 100. * correct / len(test_loader.dataset)
72 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
73 | test_loss, correct, len(test_loader.dataset), test_acc))
74 | evaluation.write(f'{epoch},{test_loss:g},{test_acc:.2f}\n')
75 | evaluation.flush()
76 | print('Test Elapsed Time: {:.3f} sec\n'.format(time.time()-since))
77 |
78 |
79 | def mps_is_available():
80 | try:
81 | return torch.backends.mps.is_available()
82 | except AttributeError:
83 | return False
84 |
85 |
86 | def gpu_device():
87 | if torch.cuda.is_available():
88 | return torch.device('cuda')
89 | elif mps_is_available():
90 | return torch.device('mps')
91 | else:
92 | return torch.device('cpu')
93 |
94 |
95 | def dataloader_options(device, workers):
96 | if device.type == 'cpu':
97 | return {}
98 |
99 | kwargs = dict(num_workers=workers, pin_memory=True)
100 | if workers > 0:
101 | if device.type == 'mps':
102 | kwargs.update(dict(multiprocessing_context="forkserver", persistent_workers=True))
103 | else:
104 | kwargs.update(dict(persistent_workers=True))
105 | return kwargs
106 |
107 |
108 | def warmup_schedule(optimizer, name):
109 | if name == 'linear':
110 | return warmup.UntunedLinearWarmup(optimizer)
111 | elif name == 'exponential':
112 | return warmup.UntunedExponentialWarmup(optimizer)
113 | elif name == 'radam':
114 | return warmup.RAdamWarmup(optimizer)
115 | elif name == 'none':
116 | return warmup.LinearWarmup(optimizer, 1)
117 |
118 |
119 | def main(args=None):
120 | # Training settings
121 | parser = argparse.ArgumentParser(description='PyTorch EMNIST Example')
122 | parser.add_argument('--batch-size', type=int, default=64, metavar='N',
123 | help='input batch size for training (default: 64)')
124 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
125 | help='input batch size for testing (default: 1000)')
126 | parser.add_argument('--epochs', type=int, default=10, metavar='N',
127 | help='number of epochs to train (default: 10)')
128 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
129 | help='base learning rate (default: 0.01)')
130 | parser.add_argument('--lr-min', type=float, default=1e-5, metavar='LM',
131 | help='minimum learning rate (default: 1e-5)')
132 | parser.add_argument('--wd', type=float, default=0.01, metavar='WD',
133 | help='weight decay (default: 0.01)')
134 | parser.add_argument('--beta2', type=float, default=0.999, metavar='B2',
135 | help="Adam's beta2 parameter (default: 0.999)")
136 | parser.add_argument('--seed', type=int, default=1, metavar='S',
137 | help='random seed (default: 1)')
138 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
139 | help='how many batches to wait before logging training status')
140 | parser.add_argument('--warmup', type=str, default='linear',
141 | choices=['linear', 'exponential', 'radam', 'none'],
142 | help='warmup schedule')
143 | parser.add_argument('--workers', type=int, default=0, metavar='N',
144 | help='number of dataloader workers for GPU training (default: 0)')
145 | parser.add_argument('--save-model', action='store_true', default=False,
146 | help='for saving the current model')
147 | parser.add_argument('--no-gpu', action='store_true', default=False,
148 | help='disable GPU training. ' +
149 | 'As default, an MPS or CUDA device will be used if available.')
150 | args = parser.parse_args(args)
151 |
152 | print(args)
153 | device = torch.device('cpu') if args.no_gpu else gpu_device()
154 | print(f'Device: {device.type}')
155 |
156 | torch.manual_seed(args.seed)
157 |
158 | kwargs = dataloader_options(device, args.workers)
159 | train_loader = torch.utils.data.DataLoader(
160 | datasets.EMNIST('data', 'balanced', train=True, download=True,
161 | transform=transforms.Compose([
162 | transforms.ToTensor(),
163 | transforms.Normalize((0.1751,), (0.3332,))
164 | ])),
165 | batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
166 | test_loader = torch.utils.data.DataLoader(
167 | datasets.EMNIST('data', 'balanced', train=False,
168 | transform=transforms.Compose([
169 | transforms.ToTensor(),
170 | transforms.Normalize((0.1751,), (0.3332,))
171 | ])),
172 | batch_size=args.test_batch_size, shuffle=False, **kwargs)
173 |
174 | output_dir = f'output_{args.warmup}'
175 | try:
176 | os.makedirs(output_dir, exist_ok=False)
177 | except FileExistsError:
178 | sys.exit(f'[Error] File exists: {output_dir}')
179 |
180 | history = open(os.path.join(output_dir, 'history.csv'), 'w')
181 | history.write('epoch,step,loss,lr\n')
182 |
183 | evaluation = open(os.path.join(output_dir, 'evaluation.csv'), 'w')
184 | evaluation.write('epoch,loss,accuracy\n')
185 |
186 | model = Net().to(device)
187 |
188 | optimizer = optim.AdamW(model.parameters(), lr=args.lr,
189 | betas=(0.9, args.beta2),
190 | weight_decay=args.wd)
191 | num_steps = len(train_loader) * args.epochs
192 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
193 | optimizer, T_max=num_steps, eta_min=args.lr_min)
194 | warmup_scheduler = warmup_schedule(optimizer, args.warmup)
195 |
196 | for epoch in range(1, args.epochs + 1):
197 | train(args, model, device, train_loader, optimizer, lr_scheduler,
198 | warmup_scheduler, epoch, history)
199 | test(args, model, device, test_loader, epoch, evaluation)
200 |
201 | if args.save_model:
202 | torch.save(model.state_dict(), os.path.join(output_dir, "emnist_cnn.pt"))
203 |
204 | history.close()
205 | evaluation.close()
206 |
207 |
208 | if __name__ == '__main__':
209 | main()
210 |
--------------------------------------------------------------------------------
/examples/emnist/plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def main(args=None):
8 | parser = argparse.ArgumentParser(description='Training History Plot')
9 | parser.add_argument('path', metavar='PATH',
10 | help='path to the output directory of the training script')
11 | parser.add_argument('--output', type=str, default='none',
12 | choices=['none', 'png', 'pdf'],
13 | help='output file type (default: none)')
14 | args = parser.parse_args(args)
15 |
16 | output_dir = args.path
17 |
18 | df_hist = pd.read_csv(os.path.join(output_dir, "history.csv"))
19 | df_eval = pd.read_csv(os.path.join(output_dir, "evaluation.csv"))
20 |
21 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 8), layout="constrained")
22 |
23 | df_hist.plot(x="step", y="lr", ax=ax1, legend=False)
24 | ax1.set_xlabel('Iteration')
25 | ax1.set_ylabel('Learning rate')
26 |
27 | min_loss = df_hist.loss.min()
28 | df_hist.plot(x="step", y="loss", ylim=(0, 1) if min_loss < 0.8 else None, ax=ax2, legend=False)
29 | ax2.set_xlabel('Iteration')
30 | ax2.set_ylabel('Loss')
31 |
32 | df_eval.plot(x="epoch", y="accuracy", marker="o", ax=ax3, legend=False)
33 | ax3.set_xlabel('Epoch')
34 | ax3.set_ylabel('Accuracy (%)')
35 | max_idx = df_eval.accuracy.argmax()
36 | max_epoch = df_eval.epoch[max_idx]
37 | max_acc = df_eval.accuracy[max_idx]
38 | min_acc = df_eval.accuracy.min()
39 | ax3.axvline(x=max_epoch, color='red', ls=':')
40 | ax3.text(x=max_epoch, y=min_acc, s=f'Max: {max_acc:.2f}% @ {max_epoch} ', ha='right', va='bottom')
41 | ax3.plot([0], [max_acc])
42 |
43 | if args.output == 'none':
44 | plt.show()
45 | else:
46 | file_path = os.path.join(output_dir, f'fig_history.{args.output}')
47 | plt.savefig(file_path)
48 | print(file_path)
49 |
50 |
51 | if __name__ == '__main__':
52 | main()
53 |
--------------------------------------------------------------------------------
/examples/plots/README.md:
--------------------------------------------------------------------------------
1 | # Plots
2 |
3 | Requirements: `pytorch_warmup` and `matplotlib`.
4 |
5 | ## Effective Warmup Period
6 |
7 |
8 |
9 | *Figure:* Effective warmup periods of RAdam and rule-of-thumb warmup schedules, as a function of β₂.
10 |
11 |
12 | Run the Python script `effective_warmup_period.py` to produce the figure above:
13 |
14 | ```shell
15 | python effective_warmup_period.py
16 | ```
17 |
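The rule-of-thumb curves in this figure follow simple closed forms, which `effective_warmup_period.py` evaluates directly. As a quick sanity check, the sketch below reproduces the untuned values at β₂ = 0.999 (both come out to roughly 999.5 iterations):

```python
import math

def untuned_exponential_period(beta2):
    # T(ω) = Σ_{t≥1} exp(-(1 - β₂) t) = 1 / (exp(1 - β₂) - 1)
    return 1.0 / (math.exp(1.0 - beta2) - 1.0)

def untuned_linear_period(beta2):
    # T(ω) = Σ_{t≤τ} (1 - t/τ) = τ/2 - 1/2 with τ = 2 / (1 - β₂)
    return 1.0 / (1.0 - beta2) - 0.5

beta2 = 0.999
print(untuned_exponential_period(beta2))  # ≈ 999.5
print(untuned_linear_period(beta2))       # 999.5
```
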
18 | ### Usage
19 |
20 | ```
21 | usage: effective_warmup_period.py [-h] [--output {none,png,pdf}]
22 |
23 | Effective warmup period
24 |
25 | options:
26 | -h, --help show this help message and exit
27 | --output {none,png,pdf}
28 | Output file type (default: none)
29 | ```
30 |
31 | ## Warmup Schedule
32 |
33 |
34 |
35 | *Figure:* RAdam and rule-of-thumb warmup schedules over time for β₂ = 0.999.
36 |
37 |
38 | Run the Python script `warmup_schedule.py` to produce the figure above:
39 |
40 | ```shell
41 | python warmup_schedule.py
42 | ```
43 |
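For reference, the RAdam curve can also be evaluated directly from the helper functions exported by `pytorch_warmup`, without stepping an optimizer. This is only a minimal sketch that mirrors the expression in `pytorch_warmup/radam.py`:

```python
import math
from pytorch_warmup import rho_fn, rho_inf_fn, get_offset

beta2 = 0.999
rho_inf = rho_inf_fn(beta2)
offset = get_offset(beta2, rho_inf)

def radam_warmup_factor(step):
    # Same expression as RAdamWarmup.warmup_factor, with `step` counted from 0
    # as the library does internally.
    rho = rho_fn(step + offset, beta2, rho_inf)
    numerator = (rho - 4) * (rho - 2) * rho_inf
    denominator = (rho_inf - 4) * (rho_inf - 2) * rho
    return math.sqrt(numerator / denominator)

for step in (0, 99, 999, 2999):
    print(step + 1, radam_warmup_factor(step))  # iteration, warmup factor ω
```
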
44 | ### Usage
45 |
46 | ```
47 | usage: warmup_schedule.py [-h] [--output {none,png,pdf}]
48 |
49 | Warmup schedule
50 |
51 | options:
52 | -h, --help show this help message and exit
53 | --output {none,png,pdf}
54 | Output file type (default: none)
55 | ```
56 |
57 | © 2024 Takenori Yamamoto
--------------------------------------------------------------------------------
/examples/plots/effective_warmup_period.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib.pyplot as plt
3 | from pytorch_warmup import rho_fn, rho_inf_fn, get_offset
4 | import numpy as np
5 |
6 |
7 | def untuned_exponential_period(beta2):
8 | return 1.0 / (np.exp(1.0 - beta2) - 1.0)
9 |
10 |
11 | def untuned_linear_period(beta2):
12 | return 1.0 / (1.0 - beta2) - 0.5
13 |
14 |
15 | def warmup_factor(step, beta2, rho_inf, offset):
16 | rho = rho_fn(step+offset, beta2, rho_inf)
17 | numerator = (rho - 4) * (rho - 2) * rho_inf
18 | denominator = (rho_inf - 4) * (rho_inf - 2) * rho
19 | return np.sqrt(numerator/denominator)
20 |
21 |
22 | def radam_period_fn(beta2):
23 | rho_inf = rho_inf_fn(beta2)
24 | offset = get_offset(beta2, rho_inf)
25 | steps = np.arange(1, 101)
26 | w = warmup_factor(steps, beta2, rho_inf, offset)
27 | total_sum = np.sum(1-w)
28 | t = 1
29 | while True:
30 | steps = np.arange(100*t+1, 100*(t+1)+1)
31 | w = warmup_factor(steps, beta2, rho_inf, offset)
32 | partial_sum = np.sum(1-w)
33 | if partial_sum < 0.1:
34 | break
35 | total_sum += partial_sum
36 | t += 1
37 | return total_sum
38 |
39 |
40 | def radam_period(beta2):
41 | return [radam_period_fn(x) for x in beta2]
42 |
43 |
44 | parser = argparse.ArgumentParser(description='Effective warmup period')
45 | parser.add_argument('--output', type=str, default='none',
46 | choices=['none', 'png', 'pdf'],
47 | help='Output file type (default: none)')
48 | args = parser.parse_args()
49 |
50 | beta2 = np.arange(0.99, 0.9999, 0.0001)
51 | plt.plot(beta2, untuned_exponential_period(beta2), label='Untuned Exponential')
52 | plt.plot(beta2, untuned_linear_period(beta2), linestyle=':', label='Untuned Linear')
53 | plt.plot(beta2, radam_period(beta2), linestyle='--', label='RAdam')
54 | plt.xlim(0.990, 1.00)
55 | plt.ylim(100, 10000)
56 | plt.yscale('log')
57 | plt.legend()
58 | plt.title('Effective Warmup Period')
59 | plt.xlabel(r'$\beta_{2}$')
60 | plt.ylabel(r'${\cal T}(\omega)$')
61 | if args.output == 'none':
62 | plt.show()
63 | else:
64 | plt.savefig(f'warmup_period.{args.output}')
65 |
--------------------------------------------------------------------------------
/examples/plots/figs/warmup_period.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/plots/figs/warmup_period.png
--------------------------------------------------------------------------------
/examples/plots/figs/warmup_schedule.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/plots/figs/warmup_schedule.png
--------------------------------------------------------------------------------
/examples/plots/warmup_schedule.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib.pyplot as plt
3 | import torch
4 | from pytorch_warmup import RAdamWarmup, UntunedExponentialWarmup, UntunedLinearWarmup
5 |
6 |
7 | def get_rates(warmup_cls, beta2, max_step):
8 | rates = []
9 | p = torch.nn.Parameter(torch.arange(10, dtype=torch.float32))
10 | optimizer = torch.optim.Adam([{'params': p}], lr=1.0, betas=(0.9, beta2))
11 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
12 | warmup_scheduler = warmup_cls(optimizer)
13 | for step in range(1, max_step+1):
14 | rates.append(optimizer.param_groups[0]['lr'])
15 | optimizer.zero_grad()
16 | optimizer.step()
17 | lr_scheduler.step()
18 | warmup_scheduler.dampen()
19 | return rates
20 |
21 |
22 | parser = argparse.ArgumentParser(description='Warmup schedule')
23 | parser.add_argument('--output', type=str, default='none',
24 | choices=['none', 'png', 'pdf'],
25 | help='Output file type (default: none)')
26 | args = parser.parse_args()
27 |
28 | beta2 = 0.999
29 | max_step = 3000
30 |
31 | plt.plot(range(1, max_step+1), get_rates(RAdamWarmup, beta2, max_step), label='RAdam')
32 | plt.plot(range(1, max_step+1), get_rates(UntunedExponentialWarmup, beta2, max_step), label='Untuned Exponential')
33 | plt.plot(range(1, max_step+1), get_rates(UntunedLinearWarmup, beta2, max_step), label='Untuned Linear')
34 | plt.legend()
35 | plt.title('Warmup Schedule')
36 | plt.xlabel('Iteration')
37 | plt.ylabel(r'Warmup factor $(\omega_t)$')
38 | if args.output == 'none':
39 | plt.show()
40 | else:
41 | plt.savefig(f'warmup_schedule.{args.output}')
42 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 |
6 | [project]
7 | name = "pytorch-warmup"
8 | description = "A PyTorch Extension for Learning Rate Warmup"
9 | readme = "README.md"
10 | license = {file = "LICENSE"}
11 | authors = [
12 | { name = "Takenori Yamamoto", email = "yamamoto.takenory@gmail.com" }
13 | ]
14 | classifiers = [
15 | "Development Status :: 5 - Production/Stable",
16 | "Intended Audience :: Developers",
17 | "Intended Audience :: Education",
18 | "Intended Audience :: Science/Research",
19 | "Topic :: Scientific/Engineering",
20 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
21 | "Topic :: Software Development",
22 | "Topic :: Software Development :: Libraries",
23 | "Topic :: Software Development :: Libraries :: Python Modules",
24 | "Programming Language :: Python :: 3",
25 | "Programming Language :: Python :: 3.9",
26 | "Programming Language :: Python :: 3.10",
27 | "Programming Language :: Python :: 3.11",
28 | "Programming Language :: Python :: 3.12",
29 | "Programming Language :: Python :: 3 :: Only",
30 | "Operating System :: OS Independent",
31 | "License :: OSI Approved :: MIT License",
32 | ]
33 | requires-python = ">=3.9"
34 | dependencies = ["torch>=1.9"]
35 | dynamic = ["version"]
36 |
37 |
38 | [project.urls]
39 | "Homepage" = "https://github.com/Tony-Y/pytorch_warmup"
40 | "Bug Reports" = "https://github.com/Tony-Y/pytorch_warmup/issues"
41 |
42 |
43 | [tool.setuptools]
44 | packages = ["pytorch_warmup"]
45 |
46 |
47 | [tool.setuptools.dynamic]
48 | version = { attr = "pytorch_warmup.__version__" }
49 |
--------------------------------------------------------------------------------
/pytorch_warmup/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseWarmup, LinearWarmup, ExponentialWarmup
2 | from .untuned import UntunedLinearWarmup, UntunedExponentialWarmup
3 | from .radam import RAdamWarmup, rho_fn, rho_inf_fn, get_offset
4 |
5 | __version__ = "0.3.0.dev0"
6 |
7 | __all__ = [
8 | 'BaseWarmup',
9 | 'LinearWarmup',
10 | 'ExponentialWarmup',
11 | 'UntunedLinearWarmup',
12 | 'UntunedExponentialWarmup',
13 | 'RAdamWarmup',
14 | 'rho_fn',
15 | 'rho_inf_fn',
16 | 'get_offset',
17 | ]
18 |
--------------------------------------------------------------------------------
/pytorch_warmup/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | from contextlib import contextmanager
3 | from torch import Tensor
4 | from torch.optim import Optimizer
5 |
6 |
7 | def _check_optimizer(optimizer):
8 | if not isinstance(optimizer, Optimizer):
9 | raise TypeError('{} ({}) is not an Optimizer.'.format(
10 | optimizer, type(optimizer).__name__))
11 |
12 |
13 | def _get_lr(group):
14 | if isinstance(group['lr'], Tensor):
15 | return group['lr'].clone().detach()
16 | else:
17 | return group['lr']
18 |
19 |
20 | def _set_lr(group, lr):
21 | if isinstance(group['lr'], Tensor):
22 | group['lr'].copy_(lr)
23 | else:
24 | group['lr'] = lr
25 |
26 |
27 | class BaseWarmup:
28 | """Base class for all warmup schedules.
29 |
30 | The learning rate :math:`\\alpha_{t}` is dampened by multiplying it by
31 | the warmup factor :math:`\\omega_{t} \\in [0, 1]` at each iteration :math:`t`.
32 | Thus, the modified learning rate
33 |
34 | .. math::
35 | \\hat \\alpha_{t} = \\alpha_{t} \\cdot \\omega_{t}
36 |
37 | is used by the optimizer.
38 |
39 | Args:
40 | optimizer (Optimizer): Wrapped optimizer.
41 | warmup_params (list): Warmup parameters.
42 | last_step (int): The index of last step. Default: -1.
43 | """
44 |
45 | def __init__(self, optimizer, warmup_params, last_step=-1):
46 | self.optimizer = optimizer
47 | self.warmup_params = warmup_params
48 | self.last_step = last_step
49 | self.lrs = [_get_lr(group) for group in self.optimizer.param_groups]
50 | self.dampen()
51 |
52 | def state_dict(self):
53 | """Returns the state of the warmup scheduler as a :class:`dict`.
54 |
55 | It contains an entry for every variable in :attr:`self.__dict__` which
56 | is not the optimizer.
57 | """
58 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
59 |
60 | def load_state_dict(self, state_dict):
61 | """Loads the warmup scheduler's state.
62 |
63 | Args:
64 | state_dict (dict): Warmup scheduler state. Should be an object returned
65 | from a call to :meth:`state_dict`.
66 | """
67 | self.__dict__.update(state_dict)
68 |
69 | def dampen(self, step=None):
70 | """Dampens the learning rate.
71 |
72 | It is not recommended to explicitly call this method for PyTorch 1.4.0 or later.
73 | Please use the :meth:`dampening` context manager that calls this method correctly.
74 |
75 | Args:
76 | step (int): The index of current step. Default: ``None``.
77 | """
78 | if step is None:
79 | step = self.last_step + 1
80 | self.last_step = step
81 |
82 | for group, params in zip(self.optimizer.param_groups, self.warmup_params):
83 | omega = self.warmup_factor(step, **params)
84 | group['lr'] *= omega
85 |
86 | @contextmanager
87 | def dampening(self):
88 | """Dampens the learning rate after calling the :meth:`step` method of the learning
89 | rate scheduler.
90 |
91 | The :meth:`step` method calls must be placed in a suite of the ``with`` statement having
92 | the :meth:`dampening` context manager.
93 |
94 | Examples:
95 | >>> # For no LR scheduler
96 | >>> with warmup_scheduler.dampening():
97 | >>> pass
98 |
99 | >>> # For a single LR scheduler
100 | >>> with warmup_scheduler.dampening():
101 | >>> lr_scheduler.step()
102 |
103 | >>> # To chain two LR schedulers
104 | >>> with warmup_scheduler.dampening():
105 | >>> lr_scheduler1.step()
106 | >>> lr_scheduler2.step()
107 |
108 | >>> # To delay an LR scheduler
109 | >>> iteration = warmup_scheduler.last_step + 1
110 | >>> with warmup_scheduler.dampening():
111 | >>> if iteration >= warmup_period:
112 | >>> lr_scheduler.step()
113 | """
114 | for group, lr in zip(self.optimizer.param_groups, self.lrs):
115 | _set_lr(group, lr)
116 | yield
117 | self.lrs = [_get_lr(group) for group in self.optimizer.param_groups]
118 | self.dampen()
119 |
120 | def warmup_factor(self, step, **params):
121 | """Returns the warmup factor :math:`\\omega_{t}` at an iteration :math:`t`.
122 |
123 | :meth:`dampen` uses this method to get the warmup factor for each parameter group.
124 | It is unnecessary to explicitly call this method.
125 |
126 | Args:
127 | step (int): The index of current step.
128 | params (dict): The warmup parameters. For details, refer to the arguments of
129 | each subclass method.
130 | """
131 | raise NotImplementedError
132 |
133 |
134 | def get_warmup_params(warmup_period, group_count):
135 | if isinstance(warmup_period, list):
136 | if len(warmup_period) != group_count:
137 | raise ValueError(
138 | 'The size of warmup_period ({}) does not match the size of param_groups ({}).'.format(
139 | len(warmup_period), group_count))
140 | for x in warmup_period:
141 | if not isinstance(x, int):
142 | raise TypeError(
143 | 'An element in warmup_period, {}, is not an int.'.format(
144 | type(x).__name__))
145 | if x <= 0:
146 | raise ValueError(
147 | 'An element in warmup_period must be a positive integer, but is {}.'.format(x))
148 | warmup_params = [dict(warmup_period=x) for x in warmup_period]
149 | elif isinstance(warmup_period, int):
150 | if warmup_period <= 0:
151 | raise ValueError(
152 | 'warmup_period must be a positive integer, but is {}.'.format(warmup_period))
153 | warmup_params = [dict(warmup_period=warmup_period)
154 | for _ in range(group_count)]
155 | else:
156 | raise TypeError('{} ({}) is not a list nor an int.'.format(
157 | warmup_period, type(warmup_period).__name__))
158 | return warmup_params
159 |
160 |
161 | class LinearWarmup(BaseWarmup):
162 | """Linear warmup schedule.
163 |
164 | The linear warmup schedule uses the warmup factor
165 |
166 | .. math::
167 | \\omega_{t}^{\\rm linear, \\tau} = \\min \\left\\{ 1, \\frac{1}{\\tau} \\cdot t \\right\\}
168 |
169 | at each iteration :math:`t`, where :math:`\\tau` is the warmup period.
170 |
171 | Args:
172 | optimizer (Optimizer): Wrapped optimizer. :class:`RAdam` is not suitable because of the
173 | warmup redundancy.
174 | warmup_period (int or list[int]): The warmup period :math:`\\tau`.
175 | last_step (int): The index of last step. Default: -1.
176 |
177 | Example:
178 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...)
179 | >>> warmup_scheduler = LinearWarmup(optimizer, warmup_period=2000)
180 | >>> for batch in dataloader:
181 | >>> optimizer.zero_grad()
182 | >>> loss = ...
183 | >>> loss.backward()
184 | >>> optimizer.step()
185 | >>> with warmup_scheduler.dampening():
186 | >>> lr_scheduler.step()
187 |
188 | Warning:
189 | The warmup schedule must not be initialized before the initialization of the learning rate schedule.
190 | """
191 |
192 | def __init__(self, optimizer, warmup_period, last_step=-1):
193 | _check_optimizer(optimizer)
194 | group_count = len(optimizer.param_groups)
195 | warmup_params = get_warmup_params(warmup_period, group_count)
196 | super().__init__(optimizer, warmup_params, last_step)
197 |
198 | def warmup_factor(self, step, warmup_period):
199 | """Returns the warmup factor :math:`\\omega_{t}^{\\rm linear, \\tau}` at an iteration :math:`t`.
200 |
201 | Args:
202 | step (int): The index of current step.
203 | warmup_period (int): The warmup period :math:`\\tau`.
204 | """
205 | return min(1.0, (step+1) / warmup_period)
206 |
207 |
208 | class ExponentialWarmup(BaseWarmup):
209 | """Exponential warmup schedule.
210 |
211 | The exponential warmup schedule uses the warmup factor
212 |
213 | .. math::
214 | \\omega_{t}^{\\rm expo, \\tau} = 1 - \\exp \\left( - \\frac{1}{\\tau} \\cdot t \\right)
215 |
216 | at each iteration :math:`t`, where the constant :math:`\\tau` is analogous to
217 | a linear warmup period.
218 |
219 | Args:
220 | optimizer (Optimizer): Wrapped optimizer. :class:`RAdam` is not suitable because of the
221 | warmup redundancy.
222 | warmup_period (int or list[int]): The constant :math:`\\tau` analogous to a linear warmup period.
223 | last_step (int): The index of last step. Default: -1.
224 |
225 | Example:
226 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...)
227 | >>> warmup_scheduler = ExponentialWarmup(optimizer, warmup_period=1000)
228 | >>> for batch in dataloader:
229 | >>> optimizer.zero_grad()
230 | >>> loss = ...
231 | >>> loss.backward()
232 | >>> optimizer.step()
233 | >>> with warmup_scheduler.dampening():
234 | >>> lr_scheduler.step()
235 |
236 | Warning:
237 | The warmup schedule must not be initialized before the initialization of the learning rate schedule.
238 | """
239 |
240 | def __init__(self, optimizer, warmup_period, last_step=-1):
241 | _check_optimizer(optimizer)
242 | group_count = len(optimizer.param_groups)
243 | warmup_params = get_warmup_params(warmup_period, group_count)
244 | super().__init__(optimizer, warmup_params, last_step)
245 |
246 | def warmup_factor(self, step, warmup_period):
247 | """Returns the warmup factor :math:`\\omega_{t}^{\\rm expo, \\tau}` at an iteration :math:`t`.
248 |
249 | Args:
250 | step (int): The index of current step.
251 | warmup_period (int): The constant :math:`\\tau` analogous to a linear warmup period.
252 | """
253 | return 1.0 - math.exp(-(step+1) / warmup_period)
254 |
--------------------------------------------------------------------------------
/pytorch_warmup/radam.py:
--------------------------------------------------------------------------------
1 | import math
2 | from .base import BaseWarmup, _check_optimizer
3 |
4 |
5 | def rho_inf_fn(beta2):
6 | """Returns the constant of the RAdam algorithm, :math:`\\rho_{\\infty}`.
7 |
8 | Args:
9 | beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`.
10 | """
11 | return 2.0 / (1 - beta2) - 1
12 |
13 |
14 | def rho_fn(t, beta2, rho_inf):
15 | """Returns the value of the function of the RAdam algorithm, :math:`\\rho_{t}`,
16 | at an iteration :math:`t`.
17 |
18 | Args:
19 | t (int): The iteration :math:`t`.
20 | beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`.
21 | rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`.
22 | """
23 | b2t = beta2 ** t
24 | rho_t = rho_inf - 2 * t * b2t / (1 - b2t)
25 | return rho_t
26 |
27 |
28 | def get_offset(beta2, rho_inf):
29 | """Returns the minimal offset :math:`\\delta`.
30 |
31 | Args:
32 | beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`.
33 | rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`.
34 | """
35 | if not beta2 > 0.6:
36 | raise ValueError('beta2 ({}) must be greater than 0.6'.format(beta2))
37 | offset = 1
38 | while True:
39 | if rho_fn(offset, beta2, rho_inf) > 4:
40 | return offset
41 | offset += 1
42 |
43 |
44 | class RAdamWarmup(BaseWarmup):
45 | """RAdam warmup schedule.
46 |
47 | This warmup scheme is described in
48 | `On the adequacy of untuned warmup for adaptive optimization
49 |     <https://arxiv.org/abs/1910.04209>`_.
50 |
51 | The RAdam algorithm uses the warmup factor
52 |
53 | .. math::
54 | \\omega_{t}^{\\rm RAdam} = \\sqrt{ \\frac{ \\
55 | ( \\rho_{t} - 4 ) ( \\rho_{t} - 2 ) \\rho_{\\infty} }{ \\
56 | ( \\rho_{\\infty} - 4) (\\rho_{\\infty} - 2 ) \\rho_{t} } }
57 |
58 | at each iteration :math:`t` for :math:`\\rho_{t} > 4`, where
59 |
60 | .. math::
61 | \\rho_{\\infty} = \\frac{ 2 }{ 1 - \\beta_{2} } - 1
62 |
63 | and
64 |
65 | .. math::
66 | \\rho_{t} = \\rho_{\\infty} - \\frac{ 2 t \\cdot \\beta_{2}^{t} }{ 1 - \\beta_{2}^{t} }
67 |
68 | where :math:`\\beta_{2}` is the second discount factor of Adam. In the RAdam warmup schedule,
69 | the minimal offset :math:`\\delta` is chosen such that :math:`\\rho_{\\delta} > 4`, and then
70 | :math:`\\omega_{t+\\delta-1}^{\\rm RAdam}` is employed as the warmup factor at each iteration :math:`t`.
71 | For all practically relevant values of :math:`\\beta_{2}` (:math:`0.8 < \\beta_{2} \\le 1`),
72 | :math:`\\delta \\le 5` as deduced from Fact 3.1 of the paper.
73 |
74 | Args:
75 | optimizer (Optimizer): Adam optimizer or its variant:
76 | :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`.
77 | :class:`RAdam` is not suitable because of the warmup redundancy. This warmup
78 | schedule makes no sense for :class:`Adamax` and, in principle, the AMSGrad variant of
79 | :class:`Adam` and :class:`AdamW` as discussed in Note below. In practice, this warmup
80 | schedule improves the performance of the AMSGrad variant like that of the vanilla Adam.
81 | last_step (int): The index of last step. Default: -1.
82 |
83 | Note:
84 | This warmup schedule employs the same warmup factor for all variants of Adam. However,
85 | according to the RAdam theory,
86 | :class:`Adamax` and the AMSGrad variant of :class:`Adam` and :class:`AdamW` should
87 | have a different warmup factor because its :math:`\\psi(\\cdot)` function is different from one of the
88 | vanilla Adam, where :math:`\\psi(\\cdot)` specifies how the adaptive learning rate at :math:`t` is
89 | calculated. The RAdam theory derives the warmup factor :math:`\\omega_{t}` from
90 | :math:`\\psi(\\cdot)`. For gradients :math:`\\left\\{ g_{i} \\right\\}` viewed as i.i.d. normal random
91 | variables,
92 |
93 | .. centered::
94 | :math:`\\omega_{t} = \\sqrt{ C_{\\rm var} / {\\rm Var}\\left[ \\psi(g_{1}, \\dots, g_{t}) \\right] }`
95 |
96 | where
97 |
98 | .. centered::
99 | :math:`C_{\\rm var} = \\inf_{t} {\\rm Var}\\left[ \\psi(g_{1}, \\dots, g_{t}) \\right]`.
100 |
101 | (For details please refer to `On the Variance of the Adaptive Learning Rate and Beyond
102 |         <https://arxiv.org/abs/1908.03265>`_.)
103 |
104 | The variance hypothesis of the RAdam theory has become questionable
105 | since Ma and Yarats' paper pointed out that the adaptive learning rate may not be the best medium
106 | of analysis for understanding the role of warmup in Adam.
107 |
108 | Example:
109 | >>> optimizer = AdamW(...)
110 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...)
111 | >>> warmup_scheduler = RAdamWarmup(optimizer)
112 | >>> for batch in dataloader:
113 | >>> optimizer.zero_grad()
114 | >>> loss = ...
115 | >>> loss.backward()
116 | >>> optimizer.step()
117 | >>> with warmup_scheduler.dampening():
118 | >>> lr_scheduler.step()
119 |
120 | Warning:
121 | The warmup schedule must not be initialized before the initialization of the learning rate schedule.
122 | """
123 |
124 | def __init__(self, optimizer, last_step=-1):
125 | _check_optimizer(optimizer)
126 | warmup_params = [
127 | dict(
128 | beta2=x['betas'][1],
129 | rho_inf=rho_inf_fn(x['betas'][1]),
130 | )
131 | for x in optimizer.param_groups
132 | ]
133 | for x in warmup_params:
134 | x['offset'] = get_offset(**x)
135 | super().__init__(optimizer, warmup_params, last_step)
136 |
137 | def warmup_factor(self, step, beta2, rho_inf, offset):
138 | """Returns the warmup factor :math:`\\omega_{t+\\delta-1}^{\\rm RAdam}` at an iteration :math:`t`.
139 |
140 | Args:
141 | step (int): The index of current step.
142 | beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`.
143 | rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`.
144 | offset (int): The minimal offset :math:`\\delta`.
145 | """
146 | rho = rho_fn(step+offset, beta2, rho_inf)
147 | numerator = (rho - 4) * (rho - 2) * rho_inf
148 | denominator = (rho_inf - 4) * (rho_inf - 2) * rho
149 | return math.sqrt(numerator/denominator)
150 |
--------------------------------------------------------------------------------
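
A quick numerical sanity check of the formulas documented in RAdamWarmup. The sketch below is standalone and uses illustrative helper names (rho_inf_of, rho_of, minimal_offset, radam_warmup_factor), not the module's API; it re-derives rho_inf, rho_t, the minimal offset delta, and the warmup factor for beta2 = 0.999, confirming that delta <= 5 as stated above.

import math


def rho_inf_of(beta2):
    # rho_inf = 2 / (1 - beta2) - 1
    return 2.0 / (1.0 - beta2) - 1.0


def rho_of(t, beta2, r_inf):
    # rho_t = rho_inf - 2 t beta2^t / (1 - beta2^t)
    b2t = beta2 ** t
    return r_inf - 2.0 * t * b2t / (1.0 - b2t)


def minimal_offset(beta2, r_inf):
    # Smallest delta with rho_delta > 4 (delta <= 5 for 0.8 < beta2 <= 1).
    delta = 1
    while rho_of(delta, beta2, r_inf) <= 4.0:
        delta += 1
    return delta


def radam_warmup_factor(t, beta2):
    # omega_{t + delta - 1}^{RAdam} as in the RAdamWarmup docstring.
    r_inf = rho_inf_of(beta2)
    delta = minimal_offset(beta2, r_inf)
    rho = rho_of(t + delta - 1, beta2, r_inf)
    numerator = (rho - 4.0) * (rho - 2.0) * r_inf
    denominator = (r_inf - 4.0) * (r_inf - 2.0) * rho
    return math.sqrt(numerator / denominator)


beta2 = 0.999
print(minimal_offset(beta2, rho_inf_of(beta2)))   # 5
print([round(radam_warmup_factor(t, beta2), 4) for t in (1, 10, 100)])  # increases toward 1
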
/pytorch_warmup/untuned.py:
--------------------------------------------------------------------------------
1 | from .base import LinearWarmup, ExponentialWarmup, _check_optimizer
2 |
3 |
4 | class UntunedLinearWarmup(LinearWarmup):
5 | """Untuned linear warmup schedule for Adam.
6 |
7 | This warmup scheme is described in
8 | `On the adequacy of untuned warmup for adaptive optimization
9 | <https://arxiv.org/abs/1910.04209>`_.
10 |
11 | The untuned linear warmup schedule uses the warmup factor
12 |
13 | .. math::
14 | \\omega_{t}^{\\rm linear, untuned} = \\min \\left\\{ 1, \\frac{1 - \\beta_{2}}{2} \\cdot t \\right\\}
15 |
16 | at each iteration :math:`t`, where :math:`\\beta_{2}` is the second discount factor of Adam.
17 | In practice, :math:`\\omega_{t}^{\\rm linear, untuned}` is calculated as
18 | :math:`\\omega_{t}^{\\rm linear, \\tau}` with :math:`\\tau = \\frac{2}{1 - \\beta_{2}}`.
19 |
20 | Note:
21 | The effective warmup period is defined as
22 |
23 | .. centered::
24 | :math:`{\\cal T}(\\omega) = \\sum_{t = 1}^{\\infty} \\left( 1 - \\omega_{t} \\right)`
25 |
26 | for a warmup schedule :math:`\\omega = \\left\\{ \\omega_{t} \\right\\}_{t=1}^{\\infty}`.
27 | The warmup period :math:`\\tau` is deduced by approximately solving the rough equivalence:
28 |
29 | .. centered::
30 | :math:`{\\cal T}(\\omega^{\\rm expo, untuned}) \\approx {\\cal T}(\\omega^{{\\rm linear},
31 | \\tau}) \\approx \\frac{\\tau}{2}`.
32 |
33 | Args:
34 | optimizer (Optimizer): Adam optimizer or its variant:
35 | :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`.
36 | :class:`RAdam` is not suitable because of the warmup redundancy. This warmup
37 | schedule makes no sense for :class:`Adamax`, as discussed in the Note below.
38 | last_step (int): The index of last step. Default: -1.
39 |
40 | Note:
41 | This warmup schedule employs the same warmup period :math:`\\tau` for all variants of Adam. However,
42 | :class:`Adamax` should in principle need no linear warmup because it needs no exponential warmup.
43 | For further details, please refer to the Note in the documentation of :class:`UntunedExponentialWarmup`.
44 | In practice, a linear warmup may slightly improve AdaMax's performance because its initial update step
45 | is the same as that of the Adam optimizer.
46 |
47 | Example:
48 | >>> optimizer = AdamW(...)
49 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...)
50 | >>> warmup_scheduler = UntunedLinearWarmup(optimizer)
51 | >>> for batch in dataloader:
52 | >>> optimizer.zero_grad()
53 | >>> loss = ...
54 | >>> loss.backward()
55 | >>> optimizer.step()
56 | >>> with warmup_scheduler.dampening():
57 | >>> lr_scheduler.step()
58 |
59 | Warning:
60 | The warmup schedule must not be initialized before the initialization of the learning rate schedule.
61 | """
62 |
63 | def __init__(self, optimizer, last_step=-1):
64 | _check_optimizer(optimizer)
65 |
66 | def warmup_period_fn(beta2):
67 | return int(2.0 / (1.0-beta2))
68 | warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups]
69 | super().__init__(optimizer, warmup_period, last_step)
70 |
71 |
72 | class UntunedExponentialWarmup(ExponentialWarmup):
73 | """Untuned exponential warmup schedule for Adam.
74 |
75 | This warmup scheme is described in
76 | `On the adequacy of untuned warmup for adaptive optimization
77 | <https://arxiv.org/abs/1910.04209>`_.
78 |
79 | The untuned exponential warmup schedule uses the warmup factor
80 |
81 | .. math::
82 | \\omega_{t}^{\\rm expo, untuned} = 1 - \\exp \\left( - (1 - \\beta_{2}) \\cdot t \\right)
83 |
84 | at each iteration :math:`t`, where :math:`\\beta_{2}` is the second discount factor of Adam.
85 | In practice, :math:`\\omega_{t}^{\\rm expo, untuned}` is calculated as
86 | :math:`\\omega_{t}^{\\rm expo, \\tau}` with :math:`\\tau = \\frac{1}{1 - \\beta_{2}}`.
87 |
88 | Note:
89 | The constant :math:`\\tau` is derived from the intuition that
90 | the warmup factor should be roughly equivalent to Adam's second moment bias correction term,
91 | :math:`1 - \\beta_{2}^{t}`.
92 |
93 | Note:
94 | The effective warmup period is defined as
95 |
96 | .. centered::
97 | :math:`{\\cal T}(\\omega) = \\sum_{t = 1}^{\\infty} \\left( 1 - \\omega_{t} \\right)`
98 |
99 | for a warmup schedule :math:`\\omega = \\left\\{ \\omega_{t} \\right\\}_{t=1}^{\\infty}`.
100 | The constant :math:`\\tau` of the untuned exponential warmup schedule is roughly equivalent to
101 | its effective warmup period:
102 |
103 | .. centered::
104 | :math:`{\\cal T}(\\omega^{\\rm expo, untuned}) = 1 / \\left( \\exp( 1 - \\beta_{2}) - 1 \\right) \\approx \\tau`
105 |
106 | for :math:`\\beta_{2}` near 1. The rough equivalence is also achieved for an exponential warmup schedule
107 | if its :math:`\\tau` is large enough, for example, :math:`\\tau \\ge 1`.
108 |
109 | Args:
110 | optimizer (Optimizer): Adam optimizer or its variant:
111 | :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`.
112 | :class:`RAdam` is not suitable because of the warmup redundancy. This warmup
113 | schedule makes no sense for :class:`Adamax`, as discussed in the Note below.
114 | last_step (int): The index of last step. Default: -1.
115 |
116 | Note:
117 | This warmup schedule employs the same constant :math:`\\tau` for all variants of Adam. However,
118 | :class:`Adamax` should in principle need no warmup because :class:`Adamax` is derived by employing
119 | an :math:`L^{p}` norm update rule and letting :math:`p \\rightarrow \\infty`: the second moment bias
120 | correction term, :math:`1-\\beta_{2}^{pt}`, to which the warmup factor must be roughly equivalent in this
121 | derivation, tends to 1 as :math:`p \\rightarrow \\infty`, so no damping remains. In practice, an exponential
122 | warmup may slightly improve AdaMax's performance because its initial update step is the same as that of Adam.
123 |
124 | Example:
125 | >>> optimizer = AdamW(...)
126 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...)
127 | >>> warmup_scheduler = UntunedExponentialWarmup(optimizer)
128 | >>> for batch in dataloader:
129 | >>> optimizer.zero_grad()
130 | >>> loss = ...
131 | >>> loss.backward()
132 | >>> optimizer.step()
133 | >>> with warmup_scheduler.dampening():
134 | >>> lr_scheduler.step()
135 |
136 | Warning:
137 | The warmup schedule must not be initialized before the initialization of the learning rate schedule.
138 | """
139 |
140 | def __init__(self, optimizer, last_step=-1):
141 | _check_optimizer(optimizer)
142 |
143 | def warmup_period_fn(beta2):
144 | return int(1.0 / (1.0-beta2))
145 | warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups]
146 | super().__init__(optimizer, warmup_period, last_step)
147 |
--------------------------------------------------------------------------------
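
The effective-warmup-period relations quoted in the Notes above can be verified numerically. The following standalone sketch (illustrative names, not package code) sums 1 - omega_t for the untuned linear and exponential schedules with beta2 = 0.999 and compares the totals with tau/2 and tau.

import math


def effective_warmup_period(omega, n_terms=100_000):
    # T(omega) = sum_{t >= 1} (1 - omega_t), truncated after n_terms terms.
    return sum(1.0 - omega(t) for t in range(1, n_terms + 1))


beta2 = 0.999
tau_linear = 2.0 / (1.0 - beta2)   # warmup period of the untuned linear schedule (~2000)
tau_expo = 1.0 / (1.0 - beta2)     # constant of the untuned exponential schedule (~1000)


def linear(t):
    return min(1.0, t / tau_linear)


def expo(t):
    return 1.0 - math.exp(-t / tau_expo)


print(effective_warmup_period(linear), tau_linear / 2.0)   # ~999.5 vs ~1000.0
print(effective_warmup_period(expo), tau_expo)             # ~999.5 vs ~1000.0
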
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/test/__init__.py
--------------------------------------------------------------------------------
/test/test_base.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import math
4 | import torch
5 | from torch import Tensor
6 | import pytorch_warmup as warmup
7 |
8 |
9 | def _test_state_dict(self, warmup_scheduler, constructor):
10 | warmup_scheduler_copy = constructor()
11 | warmup_scheduler_copy.load_state_dict(warmup_scheduler.state_dict())
12 | for key in warmup_scheduler.__dict__.keys():
13 | if key not in ['optimizer']:
14 | print(key, warmup_scheduler.__dict__[key])
15 | self.assertAlmostEqual(warmup_scheduler.__dict__[key],
16 | warmup_scheduler_copy.__dict__[key])
17 |
18 |
19 | def _test_optimizer(self, warmup_class):
20 | with self.assertRaises(TypeError, msg='optimizer type') as cm:
21 | warmup_class(optimizer=0, warmup_period=5)
22 | self.assertEqual(str(cm.exception), '0 (int) is not an Optimizer.')
23 |
24 |
25 | def _test_get_warmup_params(self, optimizer, warmup_class):
26 | with self.assertRaises(ValueError, msg='warmup_period size') as cm:
27 | warmup_class(optimizer, warmup_period=[5])
28 | self.assertEqual(str(cm.exception), 'The size of warmup_period (1) does not match the size of param_groups (2).')
29 |
30 | with self.assertRaises(TypeError, msg='warmup_period element type') as cm:
31 | warmup_class(optimizer, warmup_period=[5.0, 10.0])
32 | self.assertEqual(str(cm.exception), 'An element in warmup_period, float, is not an int.')
33 |
34 | with self.assertRaises(ValueError, msg='warmup_period element range') as cm:
35 | warmup_class(optimizer, warmup_period=[5, 0])
36 | self.assertEqual(str(cm.exception), 'An element in warmup_period must be a positive integer, but is 0.')
37 |
38 | with self.assertRaises(ValueError, msg='warmup_period range') as cm:
39 | warmup_class(optimizer, warmup_period=0)
40 | self.assertEqual(str(cm.exception), 'warmup_period must be a positive integer, but is 0.')
41 |
42 | with self.assertRaises(TypeError, msg='warmup_period type') as cm:
43 | warmup_class(optimizer, warmup_period=5.0)
44 | self.assertEqual(str(cm.exception), '5.0 (float) is not a list nor an int.')
45 |
46 |
47 | def _wrap(value):
48 | return torch.tensor(value)
49 |
50 |
51 | def _unwrap(value):
52 | assert isinstance(value, Tensor)
53 | assert value.dim() == 0
54 | return value.item()
55 |
56 |
57 | def _identity(value):
58 | return value
59 |
60 |
61 | def _assert_naked(value):
62 | assert not isinstance(value, Tensor)
63 | return value
64 |
65 |
66 | _set_lr, _get_lr = (_wrap, _unwrap) if 'WRAPPED_LR' in os.environ else (_identity, _assert_naked)
67 |
68 |
69 | class TestBase(unittest.TestCase):
70 |
71 | def setUp(self):
72 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
73 |
74 | def test_linear(self):
75 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
76 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
77 | optimizer = torch.optim.SGD([
78 | {'params': [p1]},
79 | {'params': [p2], 'lr': _set_lr(0.1)}
80 | ], lr=_set_lr(0.5))
81 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
82 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=5)
83 | print()
84 | for step in range(1, 11):
85 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups]
86 | print(f'{step} {lr}')
87 | if step < 5:
88 | self.assertAlmostEqual(lr[0], 0.5 * step / 5)
89 | self.assertAlmostEqual(lr[1], 0.1 * step / 5)
90 | else:
91 | self.assertAlmostEqual(lr[0], 0.5)
92 | self.assertAlmostEqual(lr[1], 0.1)
93 | optimizer.zero_grad()
94 | optimizer.step()
95 | with warmup_scheduler.dampening():
96 | lr_scheduler.step()
97 |
98 | _test_state_dict(self, warmup_scheduler,
99 | lambda: warmup.LinearWarmup(optimizer, warmup_period=10))
100 |
101 | _test_optimizer(self, warmup.LinearWarmup)
102 |
103 | _test_get_warmup_params(self, optimizer, warmup.LinearWarmup)
104 |
105 | def test_exponential(self):
106 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
107 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
108 | optimizer = torch.optim.SGD([
109 | {'params': [p1]},
110 | {'params': [p2], 'lr': _set_lr(0.1)}
111 | ], lr=_set_lr(0.5))
112 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
113 | warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=5)
114 | print()
115 | for step in range(1, 11):
116 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups]
117 | print(f'{step} {lr}')
118 | self.assertAlmostEqual(lr[0], 0.5 * (1 - math.exp(-step / 5)))
119 | self.assertAlmostEqual(lr[1], 0.1 * (1 - math.exp(-step / 5)))
120 | optimizer.zero_grad()
121 | optimizer.step()
122 | with warmup_scheduler.dampening():
123 | lr_scheduler.step()
124 |
125 | _test_state_dict(self, warmup_scheduler,
126 | lambda: warmup.ExponentialWarmup(optimizer, warmup_period=10))
127 |
128 | _test_optimizer(self, warmup.ExponentialWarmup)
129 |
130 | _test_get_warmup_params(self, optimizer, warmup.ExponentialWarmup)
131 |
132 | def test_linear_chaining(self):
133 | def preparation():
134 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
135 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
136 | optimizer = torch.optim.SGD([
137 | {'params': [p1]},
138 | {'params': [p2], 'lr': _set_lr(0.1)}
139 | ], lr=_set_lr(0.5))
140 | lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
141 | lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
142 | return optimizer, lr_scheduler1, lr_scheduler2
143 |
144 | # First record undampened lrs
145 | optimizer, lr_scheduler1, lr_scheduler2 = preparation()
146 | lrs = []
147 | print()
148 | print('Undampened:')
149 | for step in range(1, 11):
150 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups]
151 | print(f'{step} {lr}')
152 | lrs.append(lr)
153 | optimizer.zero_grad()
154 | optimizer.step()
155 | lr_scheduler1.step()
156 | lr_scheduler2.step()
157 |
158 | optimizer, lr_scheduler1, lr_scheduler2 = preparation()
159 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=5)
160 | print('Dampened:')
161 | for step in range(1, 11):
162 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups]
163 | print(f'{step} {lr}')
164 | if step < 5:
165 | self.assertAlmostEqual(lr[0], lrs[step-1][0] * step / 5)
166 | self.assertAlmostEqual(lr[1], lrs[step-1][1] * step / 5)
167 | else:
168 | self.assertAlmostEqual(lr[0], lrs[step-1][0])
169 | self.assertAlmostEqual(lr[1], lrs[step-1][1])
170 | optimizer.zero_grad()
171 | optimizer.step()
172 | with warmup_scheduler.dampening():
173 | lr_scheduler1.step()
174 | lr_scheduler2.step()
175 |
176 | _test_state_dict(self, warmup_scheduler,
177 | lambda: warmup.LinearWarmup(optimizer, warmup_period=10))
178 |
179 | def test_linear_wo_lr_scheduler(self):
180 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
181 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
182 | optimizer = torch.optim.SGD([
183 | {'params': [p1]},
184 | {'params': [p2], 'lr': _set_lr(0.1)}
185 | ], lr=_set_lr(0.5))
186 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=5)
187 | print()
188 | for step in range(1, 11):
189 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups]
190 | print(f'{step} {lr}')
191 | if step < 5:
192 | self.assertAlmostEqual(lr[0], 0.5 * step / 5)
193 | self.assertAlmostEqual(lr[1], 0.1 * step / 5)
194 | else:
195 | self.assertAlmostEqual(lr[0], 0.5)
196 | self.assertAlmostEqual(lr[1], 0.1)
197 | optimizer.zero_grad()
198 | optimizer.step()
199 | with warmup_scheduler.dampening():
200 | pass
201 |
202 | _test_state_dict(self, warmup_scheduler,
203 | lambda: warmup.LinearWarmup(optimizer, warmup_period=10))
204 |
--------------------------------------------------------------------------------
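
The _set_lr/_get_lr pair above stores each learning rate as a 0-dim tensor when the WRAPPED_LR environment variable is set, and asserts plain Python floats otherwise, so the same assertions cover both representations (presumably to exercise tensor learning-rate support). Below is a sketch of running the suite in both modes; the module path test.test_base is assumed from the repository layout.

import os
import subprocess
import sys

# Default mode: learning rates must stay plain Python floats.
subprocess.run([sys.executable, '-m', 'unittest', 'test.test_base', '-v'], check=True)

# Wrapped mode: learning rates are stored as 0-dim tensors.
env = dict(os.environ, WRAPPED_LR='1')
subprocess.run([sys.executable, '-m', 'unittest', 'test.test_base', '-v'], check=True, env=env)
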
/test/test_radam.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import math
3 | import torch
4 | import pytorch_warmup as warmup
5 |
6 | from .test_base import _test_state_dict, _set_lr, _get_lr
7 | from .test_untuned import _test_optimizer
8 |
9 |
10 | # The expression of each warmup factor
11 | # offset: 6
12 | # beta2: 7/10
13 | # rho_inf: 17/3
14 | ewf = {
15 | 1: 3*math.sqrt(55795775461652507765)/126458170465,
16 | 2: 3*math.sqrt(118975550786574877912913153615)/2269508758199815,
17 | 3: 3*math.sqrt(364009685132320107701159663753)/2992977113632385,
18 | 4: 3*math.sqrt(258572826689968392763003617038979)/68225651323259287,
19 | 5: 3*math.sqrt(668289519821298522824847043230807053)/3138599717744915303,
20 | 6: 3*math.sqrt(60431582784117573249154657184784100939048735)/27879860688339331112605,
21 | 7: 3*math.sqrt(3668869686599344602586804992292010752258094185)/207030521845988349697045,
22 | 8: 3*math.sqrt(38610293903545800493859693214989542002076301625518651)/648745826249577848268496415,
23 | 9: 3*math.sqrt(12026080263946093429637752207887183661294840713819813)/353035787321509409011021039,
24 | 10: 3*math.sqrt(456865593113897246792694328842050091932272202605586577311)/67546148486329926220639511801,
25 | }
26 |
27 |
28 | class TestRAdam(unittest.TestCase):
29 |
30 | def setUp(self):
31 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
32 |
33 | def test_radam(self):
34 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
35 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
36 | optimizer = torch.optim.Adam([
37 | {'params': [p1]},
38 | {'params': [p2], 'lr': _set_lr(0.1)}
39 | ], lr=_set_lr(0.5), betas=(0.9, 0.7))
40 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
41 | warmup_scheduler = warmup.RAdamWarmup(optimizer)
42 | print()
43 | for step in range(1, 11):
44 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups]
45 | print(f'{step} {lr}')
46 | self.assertAlmostEqual(lr[0], 0.5 * ewf[step])
47 | self.assertAlmostEqual(lr[1], 0.1 * ewf[step])
48 | optimizer.zero_grad()
49 | optimizer.step()
50 | with warmup_scheduler.dampening():
51 | lr_scheduler.step()
52 |
53 | _test_state_dict(self, warmup_scheduler,
54 | lambda: warmup.RAdamWarmup(optimizer))
55 |
56 | _test_optimizer(self, warmup.RAdamWarmup)
57 |
--------------------------------------------------------------------------------
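
The exact surd expressions in ewf above are the warmup factors of the RAdam schedule evaluated in exact arithmetic for beta2 = 7/10, per the comment in the file. A rough floating-point cross-check (standalone, not part of the test suite) reproduces the constants noted there, rho_inf = 17/3 and offset = 6, and the same factors up to rounding.

import math

beta2 = 0.7
rho_inf = 2.0 / (1.0 - beta2) - 1.0   # 17/3


def rho(t):
    return rho_inf - 2.0 * t * beta2 ** t / (1.0 - beta2 ** t)


offset = next(t for t in range(1, 100) if rho(t) > 4.0)   # 6


def omega(step):
    r = rho(step + offset - 1)
    return math.sqrt((r - 4.0) * (r - 2.0) * rho_inf
                     / ((rho_inf - 4.0) * (rho_inf - 2.0) * r))


print(offset)      # 6
print(omega(1))    # ~0.1772, i.e. float(ewf[1]) up to rounding
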
/test/test_untuned.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import math
3 | import torch
4 | import pytorch_warmup as warmup
5 |
6 | from .test_base import _test_state_dict, _set_lr, _get_lr
7 |
8 |
9 | def _test_optimizer(self, warmup_class):
10 | with self.assertRaises(TypeError, msg='optimizer type') as cm:
11 | warmup_class(optimizer=0)
12 | self.assertEqual(str(cm.exception), '0 (int) is not an Optimizer.')
13 |
14 |
15 | class TestUntuned(unittest.TestCase):
16 |
17 | def setUp(self):
18 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
19 |
20 | def test_untuned_linear(self):
21 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
22 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
23 | optimizer = torch.optim.Adam([
24 | {'params': [p1]},
25 | {'params': [p2], 'lr': _set_lr(0.1)}
26 | ], lr=_set_lr(0.5), betas=(0.9, 0.7))
27 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
28 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
29 | print()
30 | for step in range(1, 11):
31 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups]
32 | print(f'{step} {lr}')
33 | if step < 6:
34 | self.assertAlmostEqual(lr[0], 0.5 * step / 6)
35 | self.assertAlmostEqual(lr[1], 0.1 * step / 6)
36 | else:
37 | self.assertAlmostEqual(lr[0], 0.5)
38 | self.assertAlmostEqual(lr[1], 0.1)
39 | optimizer.zero_grad()
40 | optimizer.step()
41 | with warmup_scheduler.dampening():
42 | lr_scheduler.step()
43 |
44 | _test_state_dict(self, warmup_scheduler,
45 | lambda: warmup.UntunedLinearWarmup(optimizer))
46 |
47 | _test_optimizer(self, warmup.UntunedLinearWarmup)
48 |
49 | def test_untuned_exponential(self):
50 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
51 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
52 | optimizer = torch.optim.Adam([
53 | {'params': [p1]},
54 | {'params': [p2], 'lr': _set_lr(0.1)}
55 | ], lr=_set_lr(0.5), betas=(0.9, 0.7))
56 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
57 | warmup_scheduler = warmup.UntunedExponentialWarmup(optimizer)
58 | print()
59 | for step in range(1, 11):
60 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups]
61 | print(f'{step} {lr}')
62 | self.assertAlmostEqual(lr[0], 0.5 * (1 - math.exp(-step / 3)))
63 | self.assertAlmostEqual(lr[1], 0.1 * (1 - math.exp(-step / 3)))
64 | optimizer.zero_grad()
65 | optimizer.step()
66 | with warmup_scheduler.dampening():
67 | lr_scheduler.step()
68 |
69 | _test_state_dict(self, warmup_scheduler,
70 | lambda: warmup.UntunedExponentialWarmup(optimizer))
71 |
72 | _test_optimizer(self, warmup.UntunedExponentialWarmup)
73 |
--------------------------------------------------------------------------------
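
For reference, the expected values asserted in TestUntuned follow directly from the untuned warmup periods for betas=(0.9, 0.7): the linear schedule uses int(2 / (1 - 0.7)) = 6 steps, and the exponential schedule uses the constant int(1 / (1 - 0.7)) = 3. A standalone sketch of the expected learning rates (illustrative, not test code):

import math

beta2 = 0.7
linear_period = int(2.0 / (1.0 - beta2))   # 6, as computed in UntunedLinearWarmup
expo_period = int(1.0 / (1.0 - beta2))     # 3, as computed in UntunedExponentialWarmup

base_lr = 0.5
for step in range(1, 11):
    linear_lr = base_lr * min(1.0, step / linear_period)       # 0.5 * step / 6, capped at 0.5
    expo_lr = base_lr * (1.0 - math.exp(-step / expo_period))  # 0.5 * (1 - exp(-step / 3))
    print(step, round(linear_lr, 4), round(expo_lr, 4))
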