├── .github └── workflows │ ├── download-emnist.yml │ ├── example-cifar10.yml │ ├── example-emnist.yml │ ├── example-plots.yml │ ├── pythonpackage.yml │ └── sphinx-gh-pages.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── _templates │ └── versions.html ├── build.sh ├── conf.py ├── gh-pages-redirect.html ├── index.rst ├── make.bat ├── manual_warmup.rst ├── radam_warmup.rst └── untuned_warmup.rst ├── examples ├── cifar10 │ ├── README.md │ ├── download.py │ ├── figs │ │ ├── fig-history-adamw-w-radam-warmup.png │ │ ├── fig-resnet-sgd-scores-best.png │ │ ├── fig-resnet-sgd-scores-mean-std.png │ │ ├── fig-scores-adamax-w-warmup-lr0-1.png │ │ ├── fig-scores-adamax-w-warmup-vs-sgd.png │ │ ├── fig-scores-adamw-w-warmup-edf0-99-vs-radamw.png │ │ ├── fig-scores-adamw-w-warmup-edf0-99-vs-sgd.png │ │ ├── fig-scores-adamw-w-warmup-vs-sgd.png │ │ ├── fig-scores-amsgradw-w-warmup-vs-sgd.png │ │ ├── fig-scores-nadamw-w-warmup-vs-sgd.png │ │ └── fig-scores-no-warmup.png │ ├── main.py │ └── plot.py ├── emnist │ ├── README.md │ ├── download.py │ ├── figs │ │ ├── accuracy.png │ │ ├── fig-history-adamw-w-radam-warmup.png │ │ └── learning_rate.png │ ├── main.py │ └── plot.py └── plots │ ├── README.md │ ├── effective_warmup_period.py │ ├── figs │ ├── warmup_period.png │ └── warmup_schedule.png │ └── warmup_schedule.py ├── pyproject.toml ├── pytorch_warmup ├── __init__.py ├── base.py ├── radam.py └── untuned.py └── test ├── __init__.py ├── test_base.py ├── test_radam.py └── test_untuned.py /.github/workflows/download-emnist.yml: -------------------------------------------------------------------------------- 1 | name: Download emnist 2 | 3 | on: [push] 4 | 5 | jobs: 6 | download: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | max-parallel: 8 11 | matrix: 12 | torchvision-version: [0.10.1, 0.11.3, 0.12.0, 0.13.1, 0.14.1, 0.15.2, 0.16.2, 0.17.2, 0.18.1, 0.19.1, 0.20.1] 13 | include: 14 | - pytorch-version: 1.9.1 15 | torchvision-version: 0.10.1 16 | - pytorch-version: 1.10.2 17 | torchvision-version: 0.11.3 18 | - pytorch-version: 1.11.0 19 | torchvision-version: 0.12.0 20 | - pytorch-version: 1.12.1 21 | torchvision-version: 0.13.1 22 | - pytorch-version: 1.13.1 23 | torchvision-version: 0.14.1 24 | - pytorch-version: 2.0.1 25 | torchvision-version: 0.15.2 26 | - pytorch-version: 2.1.2 27 | torchvision-version: 0.16.2 28 | - pytorch-version: 2.2.2 29 | torchvision-version: 0.17.2 30 | - pytorch-version: 2.3.1 31 | torchvision-version: 0.18.1 32 | - pytorch-version: 2.4.1 33 | torchvision-version: 0.19.1 34 | - pytorch-version: 2.5.1 35 | torchvision-version: 0.20.1 36 | steps: 37 | - uses: actions/checkout@v4 38 | - name: Set up Python 3.9 39 | uses: actions/setup-python@v5 40 | with: 41 | python-version: 3.9 42 | - name: Install dependencies 43 | run: | 44 | python -m pip install --upgrade pip 45 | pip install 'numpy<2' 46 | pip install torch==${{ matrix.pytorch-version }}+cpu -f https://download.pytorch.org/whl/torch 47 | pip install torchvision==${{ matrix.torchvision-version }}+cpu -f https://download.pytorch.org/whl/torchvision 48 | pip install setuptools 49 | pip install requests 50 | - name: Install package 51 | run: python -m pip install . 
52 | - name: Download EMNIST dataset 53 | run: python examples/emnist/download.py 54 | - name: Extract EMNIST dataset 55 | run: python examples/emnist/main.py --epochs 0 56 | -------------------------------------------------------------------------------- /.github/workflows/example-cifar10.yml: -------------------------------------------------------------------------------- 1 | name: Example cifar10 2 | 3 | on: [push] 4 | 5 | jobs: 6 | train: 7 | 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | max-parallel: 8 11 | matrix: 12 | python-version: [3.9, '3.10', 3.11, 3.12] 13 | os: [macos-latest, windows-latest, ubuntu-latest] 14 | include: 15 | - pytorch-version: 2.3.1 16 | torchvision-version: 0.18.1 17 | - pytorch-option: '+cpu' 18 | - pytorch-option: '' 19 | os: macos-latest 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install 'numpy<2' 31 | pip install torch==${{ matrix.pytorch-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torch 32 | pip install torchvision==${{ matrix.torchvision-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torchvision 33 | pip install setuptools 34 | pip install requests 35 | pip install tqdm 36 | - name: Install package 37 | run: python -m pip install . 38 | - name: Download a ResNet implementation 39 | run: | 40 | cd examples/cifar10/ 41 | python download.py 42 | - name: Train a ResNet20 model on CIFAR10 dataset 43 | run: python examples/cifar10/main.py --epochs 1 --no-progress --no-gpu 44 | -------------------------------------------------------------------------------- /.github/workflows/example-emnist.yml: -------------------------------------------------------------------------------- 1 | name: Example emnist 2 | 3 | on: [push] 4 | 5 | jobs: 6 | train: 7 | 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | max-parallel: 8 11 | matrix: 12 | python-version: [3.9, '3.10', 3.11, 3.12] 13 | os: [macos-latest, windows-latest, ubuntu-latest] 14 | include: 15 | - pytorch-version: 1.9.1 16 | torchvision-version: 0.10.1 17 | python-version: 3.9 18 | - pytorch-version: 1.9.0 19 | torchvision-version: 0.10.0 20 | python-version: 3.9 21 | os: macos-latest 22 | - pytorch-version: 1.11.0 23 | torchvision-version: 0.12.0 24 | python-version: '3.10' 25 | - pytorch-version: 2.0.1 26 | torchvision-version: 0.15.2 27 | python-version: 3.11 28 | - pytorch-version: 2.2.2 29 | torchvision-version: 0.17.2 30 | python-version: 3.12 31 | - pytorch-option: '+cpu' 32 | - pytorch-option: '' 33 | os: macos-latest 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | - name: Set up Python ${{ matrix.python-version }} 38 | uses: actions/setup-python@v5 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | - name: Install dependencies 42 | run: | 43 | python -m pip install --upgrade pip 44 | pip install 'numpy<2' 45 | pip install torch==${{ matrix.pytorch-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torch 46 | pip install torchvision==${{ matrix.torchvision-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torchvision 47 | pip install setuptools 48 | pip install requests 49 | - name: Install package 50 | run: python -m pip install . 
51 | - name: Download EMNIST dataset 52 | run: python examples/emnist/download.py 53 | - name: Train a model on EMNIST dataset 54 | run: python examples/emnist/main.py --epochs 1 --no-gpu 55 | -------------------------------------------------------------------------------- /.github/workflows/example-plots.yml: -------------------------------------------------------------------------------- 1 | name: Example plots 2 | 3 | on: [push] 4 | 5 | jobs: 6 | plot: 7 | 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | max-parallel: 8 11 | matrix: 12 | python-version: [3.9, '3.10', 3.11, 3.12] 13 | os: [macos-latest, windows-latest, ubuntu-latest] 14 | include: 15 | - pytorch-version: 1.9.1 16 | python-version: 3.9 17 | - pytorch-version: 1.11.0 18 | python-version: '3.10' 19 | - pytorch-version: 2.0.1 20 | python-version: 3.11 21 | - pytorch-version: 2.2.2 22 | python-version: 3.12 23 | - pytorch-option: '+cpu' 24 | - pytorch-option: '' 25 | os: macos-latest 26 | 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Set up Python ${{ matrix.python-version }} 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install 'numpy<2' 37 | pip install torch==${{ matrix.pytorch-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torch 38 | pip install matplotlib 39 | pip install setuptools 40 | - name: Install package 41 | run: python -m pip install . 42 | - name: Preparation 43 | run: mkdir artifact 44 | - name: Plot warmup period 45 | run: | 46 | cd artifact 47 | python ../examples/plots/effective_warmup_period.py --output png 48 | - name: Plot warmup schedule 49 | run: | 50 | cd artifact 51 | python ../examples/plots/warmup_schedule.py --output png 52 | - uses: actions/upload-artifact@v4 53 | with: 54 | name: artifact_${{ matrix.os }}_${{ matrix.python-version }} 55 | path: artifact 56 | -------------------------------------------------------------------------------- /.github/workflows/pythonpackage.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | max-parallel: 8 11 | matrix: 12 | python-version: [3.9, '3.10', 3.11, 3.12] 13 | os: [macos-latest, windows-latest, ubuntu-latest] 14 | pytorch-release: [earliest, latest] 15 | include: 16 | - pytorch-version: 1.9.1 17 | python-version: 3.9 18 | pytorch-release: earliest 19 | - pytorch-version: 1.11.0 20 | python-version: '3.10' 21 | pytorch-release: earliest 22 | - pytorch-version: 2.0.1 23 | python-version: 3.11 24 | pytorch-release: earliest 25 | - pytorch-version: 2.2.2 26 | python-version: 3.12 27 | pytorch-release: earliest 28 | - pytorch-version: 2.5.1 29 | pytorch-release: latest 30 | - pytorch-option: '+cpu' 31 | - pytorch-option: '' 32 | os: macos-latest 33 | 34 | steps: 35 | - uses: actions/checkout@v4 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v5 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | - name: Install dependencies 41 | run: | 42 | python -m pip install --upgrade pip 43 | pip install 'numpy<2' 44 | pip install torch==${{ matrix.pytorch-version }}${{ matrix.pytorch-option }} -f https://download.pytorch.org/whl/torch 45 | - name: Lint with flake8 46 | run: | 47 | pip install flake8 48 | # stop the build if there are Python syntax errors or undefined names 49 | 
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 50 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 51 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 52 | - name: Test with pytest (naked) 53 | run: | 54 | pip install pytest 55 | pytest test -s -vv 56 | - name: Test with pytest (wrapped) 57 | run: pytest test -s -vv 58 | env: 59 | WRAPPED_LR: "1" 60 | if: matrix.pytorch-release == 'latest' 61 | - name: Package with build 62 | run: | 63 | pip install setuptools build 64 | python -m build 65 | -------------------------------------------------------------------------------- /.github/workflows/sphinx-gh-pages.yml: -------------------------------------------------------------------------------- 1 | name: Sphinx GitHub Pages 2 | 3 | on: 4 | # Runs on pushes targeting the default branch 5 | push: 6 | branches: ["master"] 7 | 8 | # Allows you to run this workflow manually from the Actions tab 9 | workflow_dispatch: 10 | 11 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 12 | permissions: 13 | contents: read 14 | pages: write 15 | id-token: write 16 | 17 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 18 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 19 | concurrency: 20 | group: "pages" 21 | cancel-in-progress: false 22 | 23 | jobs: 24 | # Build job 25 | build: 26 | runs-on: ubuntu-latest 27 | steps: 28 | - name: Checkout 29 | uses: actions/checkout@v4 30 | with: 31 | fetch-depth: 0 32 | - name: Set up Python 33 | uses: actions/setup-python@v5 34 | with: 35 | python-version: '3.10' 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install 'numpy<2' 40 | pip install torch --index-url https://download.pytorch.org/whl/cpu 41 | pip install sphinx sphinxcontrib-katex sphinx-copybutton sphinx-multiversion sphinx-rtd-theme 42 | - name: Sphinx Build 43 | run: bash docs/build.sh 44 | - name: Upload artifact 45 | uses: actions/upload-pages-artifact@v3 46 | with: 47 | path: build/html 48 | 49 | # Deployment job 50 | deploy: 51 | environment: 52 | name: github-pages 53 | url: ${{ steps.deployment.outputs.page_url }} 54 | runs-on: ubuntu-latest 55 | needs: build 56 | steps: 57 | - name: Deploy to GitHub Pages 58 | id: deployment 59 | uses: actions/deploy-pages@v4 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | # Example data 128 | examples/*/data 129 | examples/*/output* 130 | examples/cifar10/resnet.py 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019-2024 Takenori Yamamoto 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include test/__init__.py 2 | include examples/plots/README.md 3 | include examples/plots/effective_warmup_period.py 4 | include examples/plots/warmup_schedule.py 5 | include examples/plots/figs/*.png 6 | include examples/emnist/README.md 7 | include examples/emnist/download.py 8 | include examples/emnist/main.py 9 | include examples/emnist/plot.py 10 | include examples/emnist/figs/*.png 11 | include examples/cifar10/README.md 12 | include examples/cifar10/download.py 13 | include examples/cifar10/main.py 14 | include examples/cifar10/plot.py 15 | include examples/cifar10/figs/*.png 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A PyTorch Extension for Learning Rate Warmup 2 | 3 | This library contains PyTorch implementations of the warmup schedules described in [On the adequacy of untuned warmup for adaptive optimization](https://arxiv.org/abs/1910.04209). 4 | 5 |

<p align="center"><img src="https://github.com/Tony-Y/pytorch_warmup/raw/master/examples/plots/figs/warmup_schedule.png" alt="Warmup schedule" width="400"/></p>

6 | 7 | [![Python package](https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg)](https://github.com/Tony-Y/pytorch_warmup/) 8 | [![PyPI version shields.io](https://img.shields.io/pypi/v/pytorch-warmup.svg)](https://pypi.python.org/pypi/pytorch-warmup/) 9 | [![PyPI license](https://img.shields.io/pypi/l/pytorch-warmup.svg)](https://github.com/Tony-Y/pytorch_warmup/blob/master/LICENSE) 10 | [![Python versions](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org) 11 | 12 | ## Installation 13 | 14 | Make sure you have Python 3.9+ and PyTorch 1.9+ or 2.x. Then, run the following command in the project directory: 15 | 16 | ```shell 17 | python -m pip install . 18 | ``` 19 | 20 | or install the latest version from the Python Package Index: 21 | 22 | ```shell 23 | pip install -U pytorch_warmup 24 | ``` 25 | 26 | ## Examples 27 | 28 | * [CIFAR10](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/cifar10) - 29 | A sample script to train a ResNet model on the CIFAR10 dataset using an optimization algorithm with a warmup schedule. 30 | Its README presents ResNet20 results obtained using each of AdamW, NAdamW, AMSGradW, and AdaMax 31 | together with each of various warmup schedules. 32 | In addition, there is a ResNet performance comparison (up to ResNet110) obtained using the SGD algorithm 33 | with a linear warmup schedule. 34 | * [EMNIST](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/emnist) - 35 | A sample script to train a CNN model on the EMNIST dataset using the AdamW algorithm with a warmup schedule. 36 | Its README presents a result obtained using the AdamW algorithm with each of the untuned linear and exponential warmup, 37 | and the RAdam warmup. 38 | * [Plots](https://github.com/Tony-Y/pytorch_warmup/tree/master/examples/plots) - 39 | A script to plot effective warmup periods as a function of β₂, and warmup schedules over time. 40 | 41 | ## Usage 42 | 43 | The [documentation](https://tony-y.github.io/pytorch_warmup/master/) provides more detailed information on this library, unseen below. 44 | 45 | ### Sample Codes 46 | 47 | The scheduled learning rate is dampened by the multiplication of the warmup factor: 48 | 49 |

<p align="center"><img src="https://github.com/Tony-Y/pytorch_warmup/raw/master/examples/emnist/figs/learning_rate.png" alt="Learning rate" width="400"/></p>
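The following is a small, self-contained sketch (not part of the library) of what this dampening means, using the warmup factors defined later in this README: `w(t) = min(1, t / warmup_period)` for the linear warmup and `w(t) = 1 - exp(-t / warmup_period)` for the exponential warmup. The learning rate actually applied at step `t` is the scheduled learning rate multiplied by `w(t)`.

```python
import math

def linear_warmup_factor(step, warmup_period):
    # w(t) = min(1, t / warmup_period)
    return min(1.0, step / warmup_period)

def exponential_warmup_factor(step, warmup_period):
    # w(t) = 1 - exp(-t / warmup_period)
    return 1.0 - math.exp(-step / warmup_period)

# Illustration only: dampen a constant scheduled LR of 0.001 with a 2000-step linear warmup.
base_lr = 0.001
for step in (1, 100, 500, 1000, 2000, 5000):
    print(step, base_lr * linear_warmup_factor(step, warmup_period=2000))
```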

50 | 51 | #### Approach 1 52 | 53 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach1_chaining.ipynb) 54 | 55 | When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used 56 | together with `Adam` or its variant (`AdamW`, `NAdam`, etc.) as follows: 57 | 58 | ```python 59 | import torch 60 | import pytorch_warmup as warmup 61 | 62 | optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01) 63 | # This sample code uses the AdamW optimizer. 64 | num_steps = len(dataloader) * num_epochs 65 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps) 66 | # The LR schedule initialization resets the initial LR of the optimizer. 67 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) 68 | # The warmup schedule initialization dampens the initial LR of the optimizer. 69 | for epoch in range(1,num_epochs+1): 70 | for batch in dataloader: 71 | optimizer.zero_grad() 72 | loss = ... 73 | loss.backward() 74 | optimizer.step() 75 | with warmup_scheduler.dampening(): 76 | lr_scheduler.step() 77 | ``` 78 | 79 | > [!Warning] 80 | > Note that the warmup schedule must not be initialized before the initialization of the learning rate schedule. 81 | 82 | If you want to use the learning rate schedule *chaining*, which is supported for PyTorch 1.4 or above, you may simply write a code of learning rate schedulers as a suite of the `with` statement: 83 | 84 | ```python 85 | lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 86 | lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1) 87 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) 88 | for epoch in range(1,num_epochs+1): 89 | for batch in dataloader: 90 | ... 91 | optimizer.step() 92 | with warmup_scheduler.dampening(): 93 | lr_scheduler1.step() 94 | lr_scheduler2.step() 95 | ``` 96 | 97 | If you want to start the learning rate schedule after the end of the linear warmup, delay it by the warmup period: 98 | 99 | ```python 100 | warmup_period = 2000 101 | num_steps = len(dataloader) * num_epochs - warmup_period 102 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps) 103 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period) 104 | for epoch in range(1,num_epochs+1): 105 | for batch in dataloader: 106 | ... 107 | optimizer.step() 108 | with warmup_scheduler.dampening(): 109 | if warmup_scheduler.last_step + 1 >= warmup_period: 110 | lr_scheduler.step() 111 | ``` 112 | 113 | #### Approach 2 114 | 115 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach2_chaining.ipynb) 116 | 117 | When the learning rate schedule uses the epoch number, the warmup schedule can be used as follows: 118 | 119 | ```python 120 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs//3], gamma=0.1) 121 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) 122 | for epoch in range(1,num_epochs+1): 123 | for i, batch in enumerate(dataloader): 124 | optimizer.zero_grad() 125 | loss = ... 
126 | loss.backward() 127 | optimizer.step() 128 | if i < len(dataloader)-1: 129 | with warmup_scheduler.dampening(): 130 | pass 131 | with warmup_scheduler.dampening(): 132 | lr_scheduler.step() 133 | ``` 134 | 135 | This code can be rewritten more compactly: 136 | 137 | ```python 138 | for epoch in range(1,num_epochs+1): 139 | for i, batch in enumerate(dataloader): 140 | optimizer.zero_grad() 141 | loss = ... 142 | loss.backward() 143 | optimizer.step() 144 | with warmup_scheduler.dampening(): 145 | if i + 1 == len(dataloader): 146 | lr_scheduler.step() 147 | ``` 148 | 149 | #### Approach 3 150 | 151 | When you use `CosineAnnealingWarmRestarts`, the warmup schedule can be used as follows: 152 | 153 | ```python 154 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2) 155 | warmup_period = 2000 156 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period) 157 | iters = len(dataloader) 158 | warmup_epochs = ... # for example, (warmup_period + iters - 1) // iters 159 | for epoch in range(epochs+warmup_epochs): 160 | for i, batch in enumerate(dataloader): 161 | optimizer.zero_grad() 162 | loss = ... 163 | loss.backward() 164 | optimizer.step() 165 | with warmup_scheduler.dampening(): 166 | if epoch >= warmup_epochs: 167 | lr_scheduler.step(epoch-warmup_epochs + i / iters) 168 | ``` 169 | 170 | ### Warmup Schedules 171 | 172 | #### Manual Warmup 173 | 174 | In `LinearWarmup` and `ExponentialWarmup`, the warmup factor `w(t)` depends on the warmup period that must manually be specified. 175 | 176 | ##### Linear 177 | 178 | `w(t) = min(1, t / warmup_period)` 179 | 180 | ```python 181 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=2000) 182 | ``` 183 | 184 | For details please refer to [LinearWarmup](https://tony-y.github.io/pytorch_warmup/master/manual_warmup.html#pytorch_warmup.base.LinearWarmup) in the documentation. 185 | 186 | ##### Exponential 187 | 188 | `w(t) = 1 - exp(-t / warmup_period)` 189 | 190 | ```python 191 | warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=1000) 192 | ``` 193 | 194 | For details please refer to [ExponentialWarmup](https://tony-y.github.io/pytorch_warmup/master/manual_warmup.html#pytorch_warmup.base.ExponentialWarmup) in the documentation. 195 | 196 | #### Untuned Warmup 197 | 198 | In `UntunedLinearWarmup` and `UntunedExponentialWarmup`, the warmup period is determined by a function of Adam's `beta2` parameter. 199 | 200 | ##### Linear 201 | 202 | `warmup_period = 2 / (1 - beta2)` 203 | 204 | ```python 205 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) 206 | ``` 207 | 208 | For details please refer to [UntunedLinearWarmup](https://tony-y.github.io/pytorch_warmup/master/untuned_warmup.html#pytorch_warmup.untuned.UntunedLinearWarmup) in the documentation. 209 | 210 | ##### Exponential 211 | 212 | `warmup_period = 1 / (1 - beta2)` 213 | 214 | ```python 215 | warmup_scheduler = warmup.UntunedExponentialWarmup(optimizer) 216 | ``` 217 | 218 | For details please refer to [UntunedExponentialWarmup](https://tony-y.github.io/pytorch_warmup/master/untuned_warmup.html#pytorch_warmup.untuned.UntunedExponentialWarmup) in the documentation. 219 | 220 | #### RAdam Warmup 221 | 222 | In `RAdamWarmup`, the warmup factor `w(t)` is a complicated function depending on Adam's `beta2` parameter. 
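For intuition, here is a rough sketch (ours, not the library's implementation in `pytorch_warmup/radam.py`) of the variance rectification term defined in the RAdam paper, on which this warmup factor is based; the handling of the first few steps, where the term is undefined, is simplified here:

```python
import math

def radam_rectification(step, beta2):
    # rho_inf and rho_t as defined in the RAdam paper (Liu et al., 2019).
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step
    rho_t = rho_inf - 2.0 * step * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        # The rectification is undefined in this region; see radam.py for
        # how the library actually treats these early steps.
        return 0.0
    num = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
    den = (rho_inf - 4.0) * (rho_inf - 2.0) * rho_t
    return math.sqrt(num / den)
```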
223 |
224 | ```python
225 | warmup_scheduler = warmup.RAdamWarmup(optimizer)
226 | ```
227 |
228 | For details please refer to [RAdamWarmup](https://tony-y.github.io/pytorch_warmup/master/radam_warmup.html#pytorch_warmup.radam.RAdamWarmup) in the documentation, or
229 | "[On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)."
230 |
231 | ### Apex's Adam
232 |
233 | The Apex library provides an Adam optimizer tuned for CUDA devices, [FusedAdam](https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedAdam). The FusedAdam optimizer can be used together with any one of the warmup schedules above. For example:
234 |
235 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_FusedAdam.ipynb)
236 |
237 | ```python
238 | optimizer = apex.optimizers.FusedAdam(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
239 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
240 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
241 | ```
242 |
243 | ### Compiled Optimizers
244 |
245 | [Benchmarking results](https://dev-discuss.pytorch.org/t/performance-comparison-between-torch-compile-and-apex-optimizers/2023)
246 | show that the compiled Adam outperforms Apex's Adam.
247 |
248 | > [!Warning]
249 | > PyTorch 2.3 or later is required for using the compiled optimizer with
250 | > a warmup scheduler and/or LR schedulers.
251 | > PyTorch-Warmup 0.2 or earlier is incompatible with the compiled optimizer.
252 |
253 | You can compile the Adam optimizer as follows:
254 |
255 | ```python
256 | model = model.to(device)
257 | optimizer = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.001).to(device))
258 | opt_step = torch.compile(optimizer.step, mode="reduce-overhead")
259 | ```
260 |
261 | > [!Important]
262 | > Wrap the learning rate in a `Tensor`, or `torch.compile` will recompile
263 | > as the value of the learning rate changes.
264 |
265 | Then, the compiled version `opt_step` has to be invoked instead of `optimizer.step`:
266 |
267 | ```python
268 | for epoch in range(1,num_epochs+1):
269 |     for batch in dataloader:
270 |         optimizer.zero_grad()
271 |         loss = ...
272 |         loss.backward()
273 |         opt_step()
274 |         with warmup_scheduler.dampening():
275 |             lr_scheduler.step()
276 | ```
277 |
278 | You can also compile other built-in optimizers in the way shown above.
279 |
280 | > [!Note]
281 | > When using the compiled SGD with momentum, its momentum buffer needs
282 | > to be initialized manually. You can find sample code in the CIFAR10 example.
283 |
284 | In practice, you may compile it together with other PyTorch code as follows:
285 |
286 | ```python
287 | @torch.compile(mode="reduce-overhead")
288 | def train_iter_fn(batch):
289 |     optimizer.zero_grad()
290 |     loss = ...
291 |     loss.backward()
292 |     optimizer.step()
293 |
294 | for epoch in range(1,num_epochs+1):
295 |     for batch in dataloader:
296 |         train_iter_fn(batch)
297 |         with warmup_scheduler.dampening():
298 |             lr_scheduler.step()
299 | ```
300 |
301 | `torch.compile` skips `lr_scheduler.step` even if it were invoked within `train_iter_fn`.
302 | Likewise, you should not compile `warmup_scheduler.dampening`.
303 | You may also use `torch.compiler.disable` to have `torch.compile` skip a function 304 | updating the learning rate as follows: 305 | 306 | ```python 307 | @torch.compiler.disable 308 | def update_lr_fn(): 309 | with warmup_scheduler.dampening(): 310 | lr_scheduler.step() 311 | 312 | @torch.compile(mode="reduce-overhead") 313 | def train_iter_fn(batch): 314 | optimizer.zero_grad() 315 | loss = ... 316 | loss.backward() 317 | optimizer.step() 318 | update_lr_fn() 319 | 320 | for epoch in range(1,num_epochs+1): 321 | for batch in dataloader: 322 | train_iter_fn(batch) 323 | ``` 324 | 325 | ## License 326 | 327 | MIT License 328 | 329 | © 2019-2025 Takenori Yamamoto 330 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_templates/versions.html: -------------------------------------------------------------------------------- 1 | {%- if current_version %} 2 |
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
3 |   <span class="rst-current-version" data-toggle="rst-current-version">
4 |     <span class="fa fa-book"> Other Versions</span>
5 |     v: {{ current_version.name }}
6 |     <span class="fa fa-caret-down"></span>
7 |   </span>
8 |   <div class="rst-other-versions">
9 |     {%- if versions.tags %}
10 |     <dl>
11 |       <dt>Tags</dt>
12 |       {%- for item in versions.tags %}
13 |       <dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
14 |       {%- endfor %}
15 |     </dl>
16 |     {%- endif %}
17 |     {%- if versions.branches %}
18 |     <dl>
19 |       <dt>Branches</dt>
20 |       {%- for item in versions.branches %}
21 |       <dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
22 |       {%- endfor %}
23 |     </dl>
24 |     {%- endif %}
25 |   </div>
26 | </div>
27 | {%- endif %} 28 | -------------------------------------------------------------------------------- /docs/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | # 3 | # In the project root directry, run this script to build the HTML document: 4 | # >> ./docs/build.sh 5 | # 6 | 7 | export SMV_BRANCH_WHITELIST="^$(git branch --show-current)$" 8 | 9 | SOURCEDIR=docs 10 | BUILDDIR=build/html 11 | 12 | sphinx-multiversion $SOURCEDIR $BUILDDIR 13 | cp $SOURCEDIR/gh-pages-redirect.html $BUILDDIR/index.html 14 | touch $BUILDDIR/.nojekyll 15 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | import os 17 | import sys 18 | sys.path.insert(0, os.path.abspath('../')) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'PyTorch Warmup' 24 | copyright = '2019-2024, Takenori Yamamoto' 25 | author = 'Takenori Yamamoto' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.autosummary', 36 | 'sphinx.ext.doctest', 37 | 'sphinx.ext.intersphinx', 38 | 'sphinx.ext.todo', 39 | 'sphinx.ext.coverage', 40 | 'sphinx.ext.napoleon', 41 | 'sphinx.ext.viewcode', 42 | 'sphinxcontrib.katex', 43 | 'sphinx_copybutton', 44 | 'sphinx_multiversion', 45 | ] 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # List of patterns, relative to source directory, that match files and 51 | # directories to ignore when looking for source files. 52 | # This pattern also affects html_static_path and html_extra_path. 53 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 54 | 55 | 56 | # -- Options for HTML output ------------------------------------------------- 57 | 58 | # The theme to use for HTML and HTML Help pages. See the documentation for 59 | # a list of builtin themes. 60 | # 61 | html_theme = 'sphinx_rtd_theme' 62 | 63 | # Add any paths that contain custom static files (such as style sheets) here, 64 | # relative to this directory. They are copied after the builtin static files, 65 | # so a file named "default.css" will overwrite the builtin "default.css". 66 | # html_static_path = ['_static'] 67 | 68 | # Copybutton settings 69 | copybutton_prompt_text = r">>> |\.\.\. 
|\$ " 70 | copybutton_prompt_is_regexp = True 71 | 72 | # Multiversion settings 73 | smv_tag_whitelist = r'^v(?!0\.1\.0)\d+\.\d+\.\d+$' 74 | if "SMV_BRANCH_WHITELIST" in os.environ: 75 | smv_branch_whitelist = os.environ["SMV_BRANCH_WHITELIST"] 76 | -------------------------------------------------------------------------------- /docs/gh-pages-redirect.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Redirecting to master branch 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. PyTorch Warmup documentation master file, created by 2 | sphinx-quickstart on Thu Oct 31 14:00:43 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to PyTorch Warmup's documentation! 7 | ========================================== 8 | 9 | This library contains PyTorch implementations of the warmup schedules described in 10 | `On the adequacy of untuned warmup for adaptive optimization 11 | `_. 12 | 13 | .. image:: https://github.com/Tony-Y/pytorch_warmup/raw/master/examples/plots/figs/warmup_schedule.png 14 | :alt: Warmup schedule 15 | :width: 400 16 | :align: center 17 | 18 | .. image:: https://github.com/Tony-Y/pytorch_warmup/workflows/Python%20package/badge.svg 19 | :alt: Python package 20 | :target: https://github.com/Tony-Y/pytorch_warmup/ 21 | 22 | .. image:: https://img.shields.io/pypi/v/pytorch-warmup.svg 23 | :alt: PyPI version shields.io 24 | :target: https://pypi.python.org/pypi/pytorch-warmup/ 25 | 26 | .. image:: https://img.shields.io/pypi/l/pytorch-warmup.svg 27 | :alt: PyPI license 28 | :target: https://github.com/Tony-Y/pytorch_warmup/blob/master/LICENSE 29 | 30 | .. image:: https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue 31 | :alt: Python versions 32 | :target: https://www.python.org 33 | 34 | Installation 35 | ------------ 36 | 37 | Make sure you have Python 3.9+ and PyTorch 1.9+ or 2.x. Then, install the latest version from the Python Package Index: 38 | 39 | .. code-block:: shell 40 | 41 | pip install -U pytorch_warmup 42 | 43 | Examples 44 | -------- 45 | 46 | .. image:: https://colab.research.google.com/assets/colab-badge.svg 47 | :alt: Open In Colab 48 | :target: https://colab.research.google.com/github/Tony-Y/colab-notebooks/blob/master/PyTorch_Warmup_Approach1_chaining.ipynb 49 | 50 | * `CIFAR10 `_ - 51 | A sample script to train a ResNet model on the CIFAR10 dataset using an optimization algorithm with a warmup schedule. 52 | Its README presents ResNet20 results obtained using each of AdamW, NAdamW, AMSGradW, and AdaMax 53 | together with each of various warmup schedules. 54 | In addition, there is a ResNet performance comparison (up to ResNet110) obtained using the SGD algorithm 55 | with a linear warmup schedule. 56 | 57 | * `EMNIST `_ - 58 | A sample script to train a CNN model on the EMNIST dataset using the AdamW algorithm with a warmup schedule. 59 | Its README presents a result obtained using the AdamW algorithm with each of the untuned linear and exponential warmup, 60 | and the RAdam warmup. 61 | 62 | * `Plots `_ - 63 | A script to plot effective warmup periods as a function of :math:`\beta_{2}`, and warmup schedules over time. 
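For instance, the plot scripts can be run from the project root; ``--output png`` writes the figures as PNG files into the current working directory (the repository's ``Example plots`` workflow runs them this way from a scratch directory):

.. code-block:: shell

   python examples/plots/effective_warmup_period.py --output png
   python examples/plots/warmup_schedule.py --output png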
64 | 65 | Usage 66 | ----- 67 | 68 | When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used 69 | together with :class:`Adam` or its variant (:class:`AdamW`, :class:`NAdam`, etc.) as follows: 70 | 71 | .. code-block:: python 72 | 73 | import torch 74 | import pytorch_warmup as warmup 75 | 76 | optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01) 77 | # This sample code uses the AdamW optimizer. 78 | num_steps = len(dataloader) * num_epochs 79 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps) 80 | # The LR schedule initialization resets the initial LR of the optimizer. 81 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) 82 | # The warmup schedule initialization dampens the initial LR of the optimizer. 83 | for epoch in range(1,num_epochs+1): 84 | for batch in dataloader: 85 | optimizer.zero_grad() 86 | loss = ... 87 | loss.backward() 88 | optimizer.step() 89 | with warmup_scheduler.dampening(): 90 | lr_scheduler.step() 91 | 92 | .. warning:: 93 | Note that the warmup schedule must not be initialized before the initialization of the learning rate schedule. 94 | 95 | Other approaches can be found in `README `_. 96 | 97 | .. toctree:: 98 | :maxdepth: 2 99 | :caption: Contents: 100 | 101 | manual_warmup 102 | untuned_warmup 103 | radam_warmup 104 | 105 | 106 | Indices and tables 107 | ================== 108 | 109 | * :ref:`genindex` 110 | * :ref:`modindex` 111 | * :ref:`search` 112 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/manual_warmup.rst: -------------------------------------------------------------------------------- 1 | Manual Warmup 2 | ============== 3 | 4 | .. automodule:: pytorch_warmup.base 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/radam_warmup.rst: -------------------------------------------------------------------------------- 1 | RAdam Warmup 2 | ============ 3 | 4 | .. automodule:: pytorch_warmup.radam 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/untuned_warmup.rst: -------------------------------------------------------------------------------- 1 | Untuned Warmup 2 | ============== 3 | 4 | .. 
automodule:: pytorch_warmup.untuned 5 | :members: 6 | -------------------------------------------------------------------------------- /examples/cifar10/README.md: -------------------------------------------------------------------------------- 1 | # CIFAR10 Example 2 | 3 | Requirements: PyTorch 1.12+ or 2.x, `pytorch_warmup`, `torchvision`, and `tqdm`. 4 | 5 | > [!Warning] 6 | > The NAdamW optimization algorithm requires PyTorch 2.1 or later. 7 | > The RAdamW optimization algorithm requires PyTorch 2.3 or later. 8 | > So, we recommend you use PyTorch 2.3.1 or later. 9 | 10 | ## Results 11 | 12 | * The ResNet20 architecture is employed from 13 | [a ResNet implementation for CIFAR10](https://github.com/akamaster/pytorch_resnet_cifar10). 14 | * The initial learning rate $\alpha$ is $10^{-1}$ or $10^{-2}$. 15 | * Various optimization algorithms are used for comparison: SGD, AdaMax, AdamW, AMSGradW, NAdamW, and RAdamW. 16 | The momentum factor of SGD is set to $0.9$. The exponential decay rates of Adam variants are set as 17 | $\beta_{1} = 0.9$ and $\beta_{2} = 0.999$. 18 | * Various warmup schedules are used for comparison. 19 | The untuned linear and exponential warmup are labeled as Linear and Expo, respectively. 20 | The RAdam warmup is labeled as RAdam. 21 | The linear and exponential warmup with a warmup period and a constant $\tau$ of *n* thousands 22 | are labeled as Linear-*n*k and Expo-*n*k, respectively. 23 | * 5 random seeds are used to sample top-1 scores on the testing set. 24 | The mean scores are shown in figures below. The error bar indicates the standard deviation. 25 | The error band is used for reference purpose. 26 | 27 | ### No Warmup 28 | 29 |

<p align="center">
30 |   <img src="figs/fig-scores-no-warmup.png" alt="Top-1 Error"/><br/>
31 |   <i>Top-1 errors of models trained by each optimization algorithm without warmup.</i>
32 | </p>

33 | 34 | ### AdamW 35 | 36 |

<p align="center">
37 |   <img src="figs/fig-scores-adamw-w-warmup-vs-sgd.png" alt="Top-1 Error"/><br/>
38 |   <i>Top-1 errors of models trained by AdamW with each warmup schedule.</i>
39 | </p>

40 | 41 | ### AMSGradW 42 | 43 |

<p align="center">
44 |   <img src="figs/fig-scores-amsgradw-w-warmup-vs-sgd.png" alt="Top-1 Error"/><br/>
45 |   <i>Top-1 errors of models trained by AMSGradW with each warmup schedule.</i>
46 | </p>

47 | 48 | ### NAdamW 49 | 50 |

<p align="center">
51 |   <img src="figs/fig-scores-nadamw-w-warmup-vs-sgd.png" alt="Top-1 Error"/><br/>
52 |   <i>Top-1 errors of models trained by NAdamW with each warmup schedule.</i>
53 | </p>

54 | 55 | ### AdaMax 56 | 57 |

<p align="center">
58 |   <img src="figs/fig-scores-adamax-w-warmup-vs-sgd.png" alt="Top-1 Error"/><br/>
59 |   <i>Top-1 errors of models trained by AdaMax with each warmup schedule for α = 0.01.</i>
60 | </p>

61 | 62 |

<p align="center">
63 |   <img src="figs/fig-scores-adamax-w-warmup-lr0-1.png" alt="Top-1 Error"/><br/>
64 |   <i>Top-1 errors of models trained by AdaMax with each warmup schedule for α = 0.1.</i>
65 | </p>

66 | 67 | ### AdamW for a smaller β₂ 68 | 69 |

<p align="center">
70 |   <img src="figs/fig-scores-adamw-w-warmup-edf0-99-vs-radamw.png" alt="Top-1 Error"/><br/>
71 |   <i>Top-1 errors of models trained by AdamW/RAdamW without warmup or AdamW with each warmup schedule for β₂ = 0.99.</i>
72 | </p>

73 | 74 |

<p align="center">
75 |   <img src="figs/fig-scores-adamw-w-warmup-edf0-99-vs-sgd.png" alt="Top-1 Error"/><br/>
76 |   <i>Top-1 errors of models trained by AdamW with each warmup schedule for β₂ = 0.99.</i>
77 | </p>

78 | 79 | ## Download ResNet for CIFAR10 80 | 81 | Run the Python script `download.py` to download 82 | [a ResNet implementation for CIFAR10](https://github.com/akamaster/pytorch_resnet_cifar10): 83 | 84 | ```shell 85 | python download.py 86 | ``` 87 | 88 | This script shows download progress: 89 | 90 | ``` 91 | Downloading https://.../resnet.py to ./resnet.py 92 | 100.0% 93 | ``` 94 | 95 | ## Train ResNet20 Models 96 | 97 | Run the Python script `main.py` to train a ResNet20 model on the CIFAR10 dataset using the SGD or AdamW algorithm. 98 | 99 | > [!Note] 100 | > Use `--workers` option to set the number of dataloader workers 101 | > for a better performance in a GPU training. 102 | > The optimal number depends on your environment. 103 | > For example, 2 and 4 for an MPS and CUDA device, respectively, 104 | > but it should not be more than the number of available performance cores. 105 | > Note that the initial epoch takes more time than a later epoch if the number of workers is greater than 0. 106 | 107 | The training log and the test evaluations are saved to files in the directory specified by `--output` option: 108 | 109 | * `history.csv` - The training log. 110 | * `evaluation.csv` - The test evaluations during training. 111 | * `cifar10_resnet20.pt` - The best model (saved optionally). 112 | 113 | You can visualize the training result using the Python script `plot.py`: 114 | 115 | ``` 116 | python plot.py [path to the directory] 117 | ``` 118 | 119 | This plot script requires `pandas` and `matplotlib`. 120 | 121 |
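For example, to render a training history and save the figure as a PNG file (using the `--output` option listed in Usage below), assuming the run was saved to `output_adamw_radam` as in the AdamW examples later in this README:

```
python plot.py --output png output_adamw_radam
```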

<p align="center">
122 |   <img src="figs/fig-history-adamw-w-radam-warmup.png" alt="Training History"/><br/>
123 |   <i>A training result for the AdamW algorithm with the RAdam warmup.</i>
124 | </p>

125 | 126 | ### SGD 127 | 128 | Train a ResNet20 model using the SGD algorithm: 129 | 130 | ``` 131 | python main.py --output output_sgd 132 | ``` 133 | 134 | ### AdamW 135 | 136 | #### No Warmup 137 | 138 | Train a ResNet20 model using AdamW without warmup: 139 | 140 | ``` 141 | python main.py --algorithm adamw --output output_adamw_none 142 | ``` 143 | 144 | #### Untuned Exponential Warmup 145 | 146 | Train a ResNet20 model using AdamW with the *Untuned Exponential Warmup* schedule: 147 | 148 | ``` 149 | python main.py --algorithm adamw --warmup exponential --output output_adamw_expo 150 | ``` 151 | 152 | #### Untuned Linear Warmup 153 | 154 | Train a ResNet20 model using AdamW with the *Untuned Linear Warmup* schedule: 155 | 156 | ``` 157 | python main.py --algorithm adamw --warmup linear --output output_adamw_linear 158 | ``` 159 | 160 | #### RAdam Warmup 161 | 162 | Train a ResNet20 model using AdamW with the *RAdam Warmup* schedule: 163 | 164 | ``` 165 | python main.py --algorithm adamw --warmup radam --output output_adamw_radam 166 | ``` 167 | 168 | #### Expo-5k Warmup 169 | 170 | Train a ResNet20 model using AdamW with the *Expo-5k Warmup* schedule: 171 | 172 | ``` 173 | python main.py --algorithm adamw --warmup exponential --warmup-period 5000 --output output_adamw_expo-5k 174 | ``` 175 | 176 | #### Linear-10k Warmup 177 | 178 | Train a ResNet20 model using AdamW with the *Linear-10k Warmup* schedule: 179 | 180 | ``` 181 | python main.py --algorithm adamw --warmup linear --warmup-period 10000 --output output_adamw_linear-10k 182 | ``` 183 | 184 | ## Usage 185 | 186 | ``` 187 | usage: main.py [-h] [-r ARCH] [-b BS] [-c BS] [-e NE] [-m M [M ...]] 188 | [-a ALGO] [-l LR] [-d WD] [-g B2] [-w WU] [-t TAU] 189 | [-n NW] [-s S] [-i I] [-o PATH] [--save-model] 190 | [--no-progress] [--no-gpu] 191 | 192 | PyTorch CIFAR10 Example 193 | 194 | options: 195 | -h, --help show this help message and exit 196 | -r ARCH, --arch ARCH ResNet architecture for CIFAR10: resnet20 | 197 | resnet32 | resnet44 | resnet56 | resnet110 | 198 | resnet1202 (default: resnet20) 199 | -b BS, --batch-size BS 200 | input batch size for training (default: 128) 201 | -c BS, --test-batch-size BS 202 | input batch size for testing (default: 1000) 203 | -e NE, --epochs NE number of epochs to train (default: 186) 204 | -m M [M ...], --milestones M [M ...] 205 | MultiStepLR's milestones in epoch (default: 206 | [81, 122]) 207 | -a ALGO, --algorithm ALGO 208 | optimization algorithm: sgd | adamw | amsgradw 209 | | nadamw | adamax | radamw (default: sgd) 210 | -l LR, --lr LR base learning rate (default: 0.1) 211 | -d WD, --weight-decay WD 212 | weight decay (default: 0.0001) 213 | -g B2, --beta2 B2 Adam's beta2 parameter (default: 0.999) 214 | -w WU, --warmup WU warmup schedule: linear | exponential | radam 215 | | none (default: none) 216 | -t TAU, --warmup-period TAU 217 | linear warmup period or exponential warmup 218 | constant. Set 0 to use the untuned linear or 219 | exponential warmup. (default: 0) 220 | -n NW, --workers NW number of dataloader workers for GPU training 221 | (default: 0) 222 | -s S, --seed S random seed (default: 1) 223 | -i I, --log-interval I 224 | how many batches to wait before logging 225 | training status 226 | -o PATH, --output PATH 227 | path to output directory (default: output) 228 | --save-model for saving the best model 229 | --no-progress disable progress bar 230 | --no-gpu disable GPU training. As default, an MPS or 231 | CUDA device will be used if available. 
232 | --compile optimize PyTorch code using TorchDynamo, 233 | AOTAutograd, and TorchInductor 234 | ``` 235 | 236 | ``` 237 | usage: plot.py [-h] [--output {none,png,pdf}] PATH 238 | 239 | Training History Plot 240 | 241 | positional arguments: 242 | PATH path to the output directory of the training script 243 | 244 | options: 245 | -h, --help show this help message and exit 246 | --output {none,png,pdf} 247 | output file type (default: none) 248 | ``` 249 | 250 | ## Supplemental Information 251 | 252 | The model is trained for 186 epochs and the learning rate decays at the 81-th and the 122-th epochs by 0.1. 253 | The weight decay rate is $10^{-4}$. Batch size is 128. 254 | Random cropping and random horizontal flipping are applied to training data. 255 | PyTorch 2.3.1 is employed only for use of the RAdamW algorithm, otherwise PyTorch 2.1.2. 256 | A single P100 GPU is used to accelerate computations. 257 | 258 | The tables below present the top-1 errors depicted in the figures above. 259 | 260 | ### No Warmup 261 | 262 | | Optimizer | α=0.1 | α=0.01 | 263 | | --------- | --------------:| --------------:| 264 | | SGD | `8.23 ± 0.25` | `10.70 ± 0.14` | 265 | | AdaMax | `15.27 ± 0.28` | `8.54 ± 0.26` | 266 | | AdamW | `13.15 ± 0.75` | `9.06 ± 0.42` | 267 | | AMSGradW | `12.56 ± 0.39` | `9.55 ± 0.41` | 268 | | NAdamW | `10.12 ± 0.32` | `8.92 ± 0.29` | 269 | | RAdamW | `8.91 ± 0.25` | `8.82 ± 0.31` | 270 | 271 | ### AdamW 272 | 273 | | Warmup | α=0.1 | α=0.01 | 274 | | ---------- | -------------:| -------------:| 275 | | RAdam | `8.86 ± 0.10` | `8.60 ± 0.14` | 276 | | Expo | `8.67 ± 0.29` | `8.64 ± 0.15` | 277 | | Linear | `8.73 ± 0.21` | `8.81 ± 0.06` | 278 | | Expo-5k | `8.64 ± 0.04` | `8.58 ± 0.28` | 279 | | Linear-10k | `8.52 ± 0.11` | `8.55 ± 0.24` | 280 | 281 | ### AMSGradW 282 | 283 | | Warmup | α=0.1 | α=0.01 | 284 | | ---------- | -------------:| -------------:| 285 | | RAdam | `9.01 ± 0.12` | `9.25 ± 0.23` | 286 | | Expo | `8.93 ± 0.26` | `9.12 ± 0.26` | 287 | | Linear | `9.00 ± 0.28` | `9.16 ± 0.16` | 288 | | Expo-5k | `8.62 ± 0.09` | `8.90 ± 0.12` | 289 | | Linear-10k | `8.62 ± 0.14` | `8.95 ± 0.20` | 290 | 291 | ### NAdamW 292 | 293 | | Warmup | α=0.1 | α=0.01 | 294 | | ---------- | -------------:| -------------:| 295 | | RAdam | `8.72 ± 0.14` | `8.59 ± 0.16` | 296 | | Expo | `8.60 ± 0.28` | `8.51 ± 0.16` | 297 | | Linear | `8.54 ± 0.15` | `8.69 ± 0.19` | 298 | | Expo-5k | `8.43 ± 0.19` | `8.49 ± 0.07` | 299 | | Linear-10k | `8.31 ± 0.20` | `8.56 ± 0.08` | 300 | 301 | ### AdaMax 302 | 303 | | Warmup | α=0.1 | α=0.01 | 304 | | ---------- | --------------:| -------------:| 305 | | None | `15.27 ± 0.28` | `8.54 ± 0.26` | 306 | | RAdam | `14.04 ± 0.33` | `8.42 ± 0.15` | 307 | | Expo | `13.87 ± 0.31` | `8.23 ± 0.09` | 308 | | Linear | `13.77 ± 0.27` | `8.15 ± 0.15` | 309 | | Expo-5k | `13.31 ± 0.19` | `8.17 ± 0.17` | 310 | | Linear-10k | `13.45 ± 0.34` | `8.18 ± 0.19` | 311 | 312 | ### Smaller β₂ 313 | 314 | The exponential decay rates of Adam variants are set as $\beta_{1} = 0.9$ and $\beta_{2} = 0.99$. 
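For reference, one way to launch such a run with the options listed in Usage above (the output directory name here is only illustrative):

```
python main.py --algorithm adamw --beta2 0.99 --warmup linear --warmup-period 2000 --output output_adamw_b2-0.99_linear-2k
```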
315 | 316 | #### No Warmup 317 | | Optimizer | α=0.1 | α=0.01 | 318 | | --------- | --------------:| -------------:| 319 | | AdamW | `14.68 ± 0.95` | `9.14 ± 0.25` | 320 | | RAdamW | `10.41 ± 0.33` | `9.13 ± 0.21` | 321 | 322 | #### AdamW with Warmup 323 | | Warmup | α=0.1 | α=0.01 | 324 | | ---------- | --------------:| -------------:| 325 | | RAdam | `10.71 ± 0.49` | `9.11 ± 0.13` | 326 | | Expo | `10.30 ± 0.20` | `9.04 ± 0.28` | 327 | | Linear | `10.34 ± 0.33` | `8.97 ± 0.10` | 328 | | Expo-1k | `9.34 ± 0.20` | `8.98 ± 0.29` | 329 | | Linear-2k | `9.28 ± 0.27` | `8.84 ± 0.23` | 330 | | Expo-5k | `8.60 ± 0.19` | `8.63 ± 0.23` | 331 | | Linear-10k | `8.53 ± 0.09` | `8.43 ± 0.14` | 332 | 333 | ## ResNet Performance Comparison 334 | 335 | We employ the ResNet20, ResNet32, ResNet44, ResNet56, and ResNet110 architecture for comparison. 336 | The SGD with momentum is used as the optimization algorithm. 337 | The momentum factor is set to $0.9$. 338 | The learning rate is $10^{-1}$. 339 | We employ a linear warmup schedule to improve the top-1 score. 340 | The warmup period is set to 1,000. 341 | PyTorch 2.4.0 is used for model training of this performance comparison. 342 | 5 random seeds are used for sampling top-1 scores. 343 | The other implementation details are described in Supplemental Information above. 344 | 345 |
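For reference, a run in this comparison can be launched along the following lines (SGD and α = 0.1 are the defaults; the architecture and output directory are only illustrative):

```
python main.py --arch resnet110 --warmup linear --warmup-period 1000 --output output_resnet110_linear-1k
```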

<p align="center">
346 |   <img src="figs/fig-resnet-sgd-scores-best.png" alt="Top-1 Error"/><br/>
347 |   <i>The best top-1 errors of ResNet models trained by the SGD algorithm without warmup or with the Linear-1k warmup.</i>
348 | </p>

349 | 350 |

<p align="center">
351 |   <img src="figs/fig-resnet-sgd-scores-mean-std.png" alt="Top-1 Error"/><br/>
352 |   <i>Top-1 errors of ResNet models trained by the SGD algorithm without warmup or with the Linear-1k warmup.
353 |   This bar chart presents the mean values. The error bar indicates the standard deviation.</i>
354 | </p>

355 | 356 | ### SGD 357 | 358 | The top-1 errors are shown as mean ± std. 359 | 360 | | Architecture | No Warmup | Linear-1k Warmup | 361 | | ------------ | -------------:| ----------------:| 362 | | ResNet20 | `8.11 ± 0.16` | `8.06 ± 0.21` | 363 | | ResNet32 | `7.93 ± 0.62` | `7.27 ± 0.21` | 364 | | ResNet44 | `7.47 ± 0.50` | `6.99 ± 0.19` | 365 | | ResNet56 | `7.53 ± 1.01` | `6.80 ± 0.12` | 366 | | ResNet110 | `7.76 ± 0.69` | `6.28 ± 0.10` | 367 | 368 | © 2024-2025 Takenori Yamamoto -------------------------------------------------------------------------------- /examples/cifar10/download.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.utils import download_url 2 | 3 | # A downloadable URL of resnet.py in a GitHub repo: 4 | # https://github.com/akamaster/pytorch_resnet_cifar10 5 | url = 'https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10/refs/heads/master/resnet.py' 6 | md5 = '9dc255cf8dc64c8b47c2b109c5b28d07' 7 | 8 | download_url(url, root='./', filename=None, md5=md5) 9 | -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-history-adamw-w-radam-warmup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-history-adamw-w-radam-warmup.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-resnet-sgd-scores-best.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-resnet-sgd-scores-best.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-resnet-sgd-scores-mean-std.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-resnet-sgd-scores-mean-std.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-scores-adamax-w-warmup-lr0-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamax-w-warmup-lr0-1.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-scores-adamax-w-warmup-vs-sgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamax-w-warmup-vs-sgd.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-scores-adamw-w-warmup-edf0-99-vs-radamw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamw-w-warmup-edf0-99-vs-radamw.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-scores-adamw-w-warmup-edf0-99-vs-sgd.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamw-w-warmup-edf0-99-vs-sgd.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-scores-adamw-w-warmup-vs-sgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-adamw-w-warmup-vs-sgd.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-scores-amsgradw-w-warmup-vs-sgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-amsgradw-w-warmup-vs-sgd.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-scores-nadamw-w-warmup-vs-sgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-nadamw-w-warmup-vs-sgd.png -------------------------------------------------------------------------------- /examples/cifar10/figs/fig-scores-no-warmup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/cifar10/figs/fig-scores-no-warmup.png -------------------------------------------------------------------------------- /examples/cifar10/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torchvision import datasets, transforms 9 | from tqdm.auto import tqdm 10 | import pytorch_warmup as warmup 11 | 12 | try: 13 | import resnet 14 | except ImportError: 15 | sys.exit('Download resnet.py from https://github.com/akamaster/pytorch_resnet_cifar10') 16 | 17 | import torch.backends.cudnn as cudnn 18 | cudnn.benchmark = True 19 | 20 | 21 | architecture_names = ['resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 22 | algorithm_names = ['sgd', 'adamw', 'amsgradw', 'nadamw', 'adamax', 'radamw'] 23 | warmup_names = ['linear', 'exponential', 'radam', 'none'] 24 | 25 | 26 | def check_pytorch_version(algorithm): 27 | major, minor, patch = map(int, torch.__version__.split('+')[0].split('.')) 28 | if major == 0 or (major == 1 and minor < 12): 29 | sys.exit('This script requires PyTorch 1.12+ or 2.x.') 30 | 31 | if algorithm == 'nadamw' and (major == 1 or (major == 2 and minor < 1)): 32 | sys.exit('[Error] The NAdamW optimization algorithm requires PyTorch 2.1 or later.') 33 | elif algorithm == 'radamw' and (major == 1 or (major == 2 and minor < 3)): 34 | sys.exit('[Error] The RAdamW optimization algorithm requires PyTorch 2.3 or later.') 35 | 36 | 37 | def get_lr(args, optimizer): 38 | lr = optimizer.param_groups[0]['lr'] 39 | return lr.item() if args.compile else lr 40 | 41 | 42 | def train_iter_loss_fn(optimizer, model, data, target): 43 | optimizer.zero_grad() 44 | output = model(data) 45 | loss = 
F.cross_entropy(output, target) 46 | loss.backward() 47 | optimizer.step() 48 | return loss 49 | 50 | 51 | def update_lr_fn(lr_scheduler, warmup_scheduler): 52 | with warmup_scheduler.dampening(): 53 | lr_scheduler.step() 54 | 55 | 56 | def train(args, model, device, train_loader, optimizer, lr_scheduler, 57 | warmup_scheduler, epoch, history): 58 | since = time.time() 59 | model.train() 60 | progress = tqdm(total=len(train_loader), disable=args.no_progress) 61 | progress.set_description(f"[train] Epoch {epoch}") 62 | train_loss = 0 63 | for batch_idx, (data, target) in enumerate(train_loader): 64 | lr = get_lr(args, optimizer) 65 | data, target = data.to(device), target.to(device) 66 | loss = train_iter_loss_fn(optimizer, model, data, target) 67 | update_lr_fn(lr_scheduler, warmup_scheduler) 68 | loss = loss.item() 69 | train_loss += loss 70 | batch_step = batch_idx + 1 71 | if batch_step % args.log_interval == 0: 72 | step = warmup_scheduler.last_step 73 | history.write(f'{epoch},{step},{loss:g},{lr:g}\n') 74 | if not progress.disable: 75 | if batch_step % 10 == 0: 76 | progress.set_postfix_str(f'loss={loss:.2f}, lr={lr:5.4f}') 77 | progress.update() 78 | progress.close() 79 | 80 | train_loss /= len(train_loader) 81 | elapsed = time.time() - since 82 | print(f'[train] Epoch {epoch}: Elapsed Time: {elapsed:.3f} sec, ' + 83 | f'Ave. Loss: {train_loss:.4f}') 84 | 85 | 86 | def test_iter_loss_fn(model, data, target): 87 | output = model(data) 88 | loss = F.cross_entropy(output, target, reduction='sum') 89 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max of unnormalized logits 90 | correct = pred.eq(target.view_as(pred)).sum() 91 | return loss, correct 92 | 93 | 94 | @torch.inference_mode() 95 | def test(args, model, device, test_loader, epoch, evaluation): 96 | since = time.time() 97 | model.eval() 98 | progress = tqdm(test_loader, disable=args.no_progress) 99 | progress.set_description(f"[test] Epoch {epoch}") 100 | test_loss = 0 101 | correct = 0 102 | for data, target in progress: 103 | data, target = data.to(device), target.to(device) 104 | batch_loss, batch_correct = test_iter_loss_fn(model, data, target) 105 | test_loss += batch_loss.item() # sum up batch loss 106 | correct += batch_correct.item() # sum up batch correct 107 | 108 | test_loss /= len(test_loader.dataset) 109 | test_acc = 100. * correct / len(test_loader.dataset) 110 | elapsed = time.time() - since 111 | print(f'[test] Epoch {epoch}: Elapsed Time: {elapsed:.3f} sec, ' + 112 | f'Ave. 
Loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%') 113 | evaluation.write(f'{epoch},{test_loss:g},{test_acc:.2f}\n') 114 | evaluation.flush() 115 | return test_acc 116 | 117 | 118 | def gpu_device(): 119 | if torch.backends.mps.is_available(): 120 | return torch.device('mps') 121 | elif torch.cuda.is_available(): 122 | return torch.device('cuda') 123 | else: 124 | return torch.device('cpu') 125 | 126 | 127 | def dataloader_options(device, workers): 128 | if device.type == 'cpu': 129 | return {} 130 | 131 | kwargs = dict(num_workers=workers, pin_memory=True) 132 | if workers > 0: 133 | if device.type == 'mps': 134 | kwargs.update(dict(multiprocessing_context="forkserver", persistent_workers=True)) 135 | else: 136 | kwargs.update(dict(persistent_workers=True)) 137 | return kwargs 138 | 139 | 140 | def optimization_algorithm(args, model, device): 141 | name = args.algorithm 142 | lr = torch.tensor(args.lr).to(device) if args.compile else args.lr 143 | kwargs = dict(lr=lr, weight_decay=args.weight_decay) 144 | if name == 'sgd': 145 | kwargs['momentum'] = 0.9 146 | else: 147 | kwargs['betas'] = (0.9, args.beta2) 148 | 149 | if name == 'sgd': 150 | return optim.SGD(model.parameters(), **kwargs) 151 | elif name == 'adamw': 152 | return optim.AdamW(model.parameters(), **kwargs) 153 | elif name == 'amsgradw': 154 | return optim.AdamW(model.parameters(), amsgrad=True, **kwargs) 155 | elif name == 'nadamw': 156 | return optim.NAdam(model.parameters(), decoupled_weight_decay=True, **kwargs) 157 | elif name == 'adamax': 158 | return optim.Adamax(model.parameters(), **kwargs) 159 | elif name == 'radamw': 160 | return optim.RAdam(model.parameters(), decoupled_weight_decay=True, **kwargs) 161 | else: 162 | raise ValueError(f'unknown optimization algorithm: {name}') 163 | 164 | 165 | def warmup_schedule(optimizer, name, period): 166 | if name == 'linear': 167 | if period == 0: 168 | return warmup.UntunedLinearWarmup(optimizer) 169 | else: 170 | return warmup.LinearWarmup(optimizer, period) 171 | elif name == 'exponential': 172 | if period == 0: 173 | return warmup.UntunedExponentialWarmup(optimizer) 174 | else: 175 | return warmup.ExponentialWarmup(optimizer, period) 176 | elif name == 'radam': 177 | return warmup.RAdamWarmup(optimizer) 178 | elif name == 'none': 179 | return warmup.LinearWarmup(optimizer, 1) 180 | else: 181 | raise ValueError(f'unknown warmup schedule: {name}') 182 | 183 | 184 | def compile_functions(): 185 | global train_iter_loss_fn 186 | global test_iter_loss_fn 187 | train_iter_loss_fn = torch.compile(train_iter_loss_fn, mode="reduce-overhead") 188 | test_iter_loss_fn = torch.compile(test_iter_loss_fn, mode="reduce-overhead") 189 | 190 | 191 | def init_momentum_buffer(optimizer): 192 | for group in optimizer.param_groups: 193 | if group["momentum"] != 0: 194 | for p in group["params"]: 195 | state = optimizer.state[p] 196 | if state.get("momentum_buffer") is None: 197 | state["momentum_buffer"] = torch.zeros_like(p.data) 198 | 199 | 200 | def main(args=None): 201 | # Training settings 202 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example') 203 | parser.add_argument('-r', '--arch', type=str, default='resnet20', metavar='ARCH', 204 | choices=architecture_names, 205 | help='ResNet architecture for CIFAR10: ' + 206 | ' | '.join(architecture_names) + ' (default: resnet20)') 207 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='BS', 208 | help='input batch size for training (default: 128)') 209 | parser.add_argument('-c', '--test-batch-size', 
type=int, default=1000, metavar='BS', 210 | help='input batch size for testing (default: 1000)') 211 | parser.add_argument('-e', '--epochs', type=int, default=186, metavar='NE', 212 | help='number of epochs to train (default: 186)') 213 | parser.add_argument('-m', '--milestones', type=int, nargs='+', default=[81, 122], metavar='M', 214 | help="MultiStepLR's milestones in epoch (default: [81, 122])") 215 | parser.add_argument('-a', '--algorithm', type=str, default='sgd', metavar='ALGO', 216 | choices=algorithm_names, 217 | help='optimization algorithm: ' + 218 | ' | '.join(algorithm_names) + ' (default: sgd)') 219 | parser.add_argument('-l', '--lr', type=float, default=0.1, metavar='LR', 220 | help='base learning rate (default: 0.1)') 221 | parser.add_argument('-d', '--weight-decay', type=float, default=0.0001, metavar='WD', 222 | help='weight decay (default: 0.0001)') 223 | parser.add_argument('-g', '--beta2', type=float, default=0.999, metavar='B2', 224 | help="Adam's beta2 parameter (default: 0.999)") 225 | parser.add_argument('-w', '--warmup', type=str, default='none', metavar='WU', 226 | choices=warmup_names, 227 | help='warmup schedule: ' + 228 | ' | '.join(warmup_names) + ' (default: none)') 229 | parser.add_argument('-t', '--warmup-period', type=int, default=0, metavar='TAU', 230 | help='linear warmup period or exponential warmup constant. ' + 231 | 'Set 0 to use the untuned linear or exponential warmup. (default: 0)') 232 | parser.add_argument('-n', '--workers', type=int, default=0, metavar='NW', 233 | help='number of dataloader workers for GPU training (default: 0)') 234 | parser.add_argument('-s', '--seed', type=int, default=1, metavar='S', 235 | help='random seed (default: 1)') 236 | parser.add_argument('-i', '--log-interval', type=int, default=10, metavar='I', 237 | help='how many batches to wait before logging training status') 238 | parser.add_argument('-o', '--output', default='output', metavar='PATH', 239 | help='path to output directory (default: output)') 240 | parser.add_argument('--save-model', action='store_true', default=False, 241 | help='for saving the best model') 242 | parser.add_argument('--no-progress', action='store_true', default=False, 243 | help='disable progress bar') 244 | parser.add_argument('--no-gpu', action='store_true', default=False, 245 | help='disable GPU training. 
' + 246 | 'As default, an MPS or CUDA device will be used if available.') 247 | parser.add_argument('--compile', action='store_true', default=False, 248 | help='optimize PyTorch code using TorchDynamo, AOTAutograd, and TorchInductor') 249 | args = parser.parse_args(args) 250 | 251 | check_pytorch_version(args.algorithm) 252 | 253 | print(args) 254 | device = torch.device('cpu') if args.no_gpu else gpu_device() 255 | print(f'Device: {device.type}') 256 | 257 | torch.manual_seed(args.seed) 258 | 259 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 260 | std=[0.229, 0.224, 0.225]) 261 | kwargs = dataloader_options(device, args.workers) 262 | train_loader = torch.utils.data.DataLoader( 263 | datasets.CIFAR10( 264 | 'data', train=True, download=True, 265 | transform=transforms.Compose([ 266 | transforms.RandomHorizontalFlip(), 267 | transforms.RandomCrop(32, 4), 268 | transforms.ToTensor(), 269 | normalize, 270 | ])), 271 | batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs) 272 | test_loader = torch.utils.data.DataLoader( 273 | datasets.CIFAR10( 274 | 'data', train=False, 275 | transform=transforms.Compose([ 276 | transforms.ToTensor(), 277 | normalize, 278 | ])), 279 | batch_size=args.test_batch_size, shuffle=False, **kwargs) 280 | 281 | output_dir = args.output 282 | try: 283 | os.makedirs(output_dir, exist_ok=False) 284 | except FileExistsError: 285 | sys.exit(f'[Error] File exists: {output_dir}') 286 | 287 | history = open(os.path.join(output_dir, 'history.csv'), 'w') 288 | history.write('epoch,step,loss,lr\n') 289 | 290 | evaluation = open(os.path.join(output_dir, 'evaluation.csv'), 'w') 291 | evaluation.write('epoch,loss,accuracy\n') 292 | 293 | model = resnet.__dict__[args.arch]().to(device) 294 | 295 | optimizer = optimization_algorithm(args, model, device) 296 | 297 | steps_per_epoch = len(train_loader) 298 | lr_scheduler = optim.lr_scheduler.MultiStepLR( 299 | optimizer, 300 | milestones=[i * steps_per_epoch for i in args.milestones]) 301 | 302 | warmup_scheduler = warmup_schedule(optimizer, 303 | name=args.warmup, 304 | period=args.warmup_period) 305 | 306 | if args.compile: 307 | if args.algorithm == 'sgd': 308 | init_momentum_buffer(optimizer) 309 | compile_functions() 310 | 311 | best_acc = 0.0 312 | best_epoch = 0 313 | print() 314 | for epoch in range(1, args.epochs + 1): 315 | train(args, model, device, train_loader, optimizer, lr_scheduler, 316 | warmup_scheduler, epoch, history) 317 | cur_acc = test(args, model, device, test_loader, epoch, evaluation) 318 | 319 | if cur_acc > best_acc: 320 | best_acc = cur_acc 321 | best_epoch = epoch 322 | if args.save_model: 323 | torch.save(model.state_dict(), os.path.join(output_dir, f"cifar10_{args.arch}.pt")) 324 | 325 | print(f"The best accuracy: {best_acc:.2f}% (epoch {best_epoch})") 326 | 327 | history.close() 328 | evaluation.close() 329 | 330 | 331 | if __name__ == '__main__': 332 | main() 333 | -------------------------------------------------------------------------------- /examples/cifar10/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def main(args=None): 8 | parser = argparse.ArgumentParser(description='Training History Plot') 9 | parser.add_argument('path', metavar='PATH', 10 | help='path to the output directory of the training script') 11 | parser.add_argument('--output', type=str, default='none', 12 | choices=['none', 'png', 'pdf'], 13 | 
help='output file type (default: none)') 14 | args = parser.parse_args(args) 15 | 16 | output_dir = args.path 17 | 18 | df_hist = pd.read_csv(os.path.join(output_dir, "history.csv")) 19 | df_eval = pd.read_csv(os.path.join(output_dir, "evaluation.csv")) 20 | 21 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 8), layout="constrained") 22 | 23 | df_hist.plot(x="step", y="lr", ax=ax1, legend=False) 24 | ax1.set_xlabel('Iteration') 25 | ax1.set_ylabel('Learning rate') 26 | 27 | min_loss = df_hist.loss.min() 28 | df_hist.plot(x="step", y="loss", ylim=(0, 1) if min_loss < 0.5 else None, ax=ax2, legend=False) 29 | ax2.set_xlabel('Iteration') 30 | ax2.set_ylabel('Loss') 31 | 32 | df_eval.plot(x="epoch", y="accuracy", ylim=(0, 100), ax=ax3, legend=False) 33 | ax3.set_xlabel('Epoch') 34 | ax3.set_ylabel('Accuracy (%)') 35 | max_idx = df_eval.accuracy.argmax() 36 | max_epoch = df_eval.epoch[max_idx] 37 | max_acc = df_eval.accuracy[max_idx] 38 | ax3.axvline(x=max_epoch, color='red', ls=':') 39 | ax3.text(x=max_epoch, y=5, s=f'Max: {max_acc:.2f}% @ {max_epoch} ', ha='right', va='bottom') 40 | ax3.plot([0], [max_acc]) 41 | 42 | if args.output == 'none': 43 | plt.show() 44 | else: 45 | file_path = os.path.join(output_dir, f'fig_history.{args.output}') 46 | plt.savefig(file_path) 47 | print(file_path) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /examples/emnist/README.md: -------------------------------------------------------------------------------- 1 | # EMNIST Example 2 | 3 | Requirements: `pytorch_warmup` and `torchvision`. 4 | 5 | ## Results 6 | 7 |

8 | [figure: Accuracy (figs/accuracy.png)] 9 | Test accuracy over time for each warmup schedule. 10 | 11 | 12 | 13 | [figure: Learning rate (figs/learning_rate.png)] 14 | Learning rate over time for each warmup schedule. 15 |
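The four curves above can presumably be reproduced with one training run per warmup schedule, for example by calling the training script's `main()` in a loop. A sketch, assuming the working directory is `examples/emnist` and the dataset has already been downloaded:

```python
# One run per warmup schedule; each run writes its logs to output_<schedule>/.
from main import main  # examples/emnist/main.py

for schedule in ['linear', 'exponential', 'radam', 'none']:
    main(['--warmup', schedule])
```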

16 | 17 | ## Download EMNIST Dataset 18 | 19 | Run the Python script `download.py` to download the EMNIST dataset: 20 | 21 | ```shell 22 | python download.py 23 | ``` 24 | 25 | This script shows download progress: 26 | 27 | ``` 28 | Downloading zip archive 29 | Downloading https://.../EMNIST/gzip.zip to data/EMNIST/raw/gzip.zip 30 | 100.0% 31 | ``` 32 | 33 | ## Train CNN Models 34 | 35 | Run the Python script `main.py` to train a CNN model on the EMNIST dataset using the AdamW algorithm. 36 | 37 | > [!Note] 38 | > Use `--workers` option to set the number of dataloader workers 39 | > for a better performance in a GPU training. 40 | > The optimal number depends on your environment. 41 | > For example, 2 and 4 for an MPS and CUDA device, respectively, 42 | > but it should not be more than the number of available performance cores. 43 | > Note that the initial epoch takes more time than a later epoch if the number of workers is greater than 0. 44 | 45 | The training log and the test evaluations are saved to files in a directory named `output_[warmup schedule name]`: 46 | 47 | * `history.csv` - The training log. 48 | * `evaluation.csv` - The test evaluations during training. 49 | * `emnist_cnn.pt` - The current model (saved optionally). 50 | 51 | You can visualize the training result using the Python script `plot.py`: 52 | 53 | ``` 54 | python plot.py [path to the directory] 55 | ``` 56 | 57 | This plot script requires `pandas` and `matplotlib`. 58 | 59 |

60 | [figure: Training History (figs/fig-history-adamw-w-radam-warmup.png)] 61 | A training result for the AdamW algorithm with the RAdam warmup. 62 |
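The figure above can be regenerated from a finished run with `plot.py`; a minimal sketch, assuming the RAdam-warmup run wrote its logs to `output_radam`:

```python
# Renders the learning-rate, loss, and accuracy panels and writes
# output_radam/fig_history.png (see plot.py below).
from plot import main as plot_main  # examples/emnist/plot.py

plot_main(['output_radam', '--output', 'png'])
```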

63 | 64 | ### Untuned Linear Warmup 65 | 66 | Train a CNN model with the *Untuned Linear Warmup* schedule: 67 | 68 | ``` 69 | python main.py --warmup linear 70 | ``` 71 | 72 | ### Untuned Exponential Warmup 73 | 74 | Train a CNN model with the *Untuned Exponential Warmup* schedule: 75 | 76 | ``` 77 | python main.py --warmup exponential 78 | ``` 79 | 80 | ### RAdam Warmup 81 | 82 | Train a CNN model with the *RAdam Warmup* schedule: 83 | 84 | ``` 85 | python main.py --warmup radam 86 | ``` 87 | 88 | ### No Warmup 89 | 90 | Train a CNN model without warmup: 91 | 92 | ``` 93 | python main.py --warmup none 94 | ``` 95 | 96 | > [!Warning] 97 | > You may have a very different result from one shown in the figure 98 | > because a training without warmup can become significantly unstable at a very early stage. 99 | > The observed accuracies at the last epoch are 2%, 78%, 86%, etc. 100 | > The figure's result was obtained on Apple M1 Pro chip without GPU acceleration. 101 | 102 | ## Usage 103 | 104 | ``` 105 | usage: main.py [-h] [--batch-size N] [--test-batch-size N] [--epochs N] 106 | [--lr LR] [--lr-min LM] [--wd WD] [--beta2 B2] [--seed S] 107 | [--log-interval N] 108 | [--warmup {linear,exponential,radam,none}] [--workers N] 109 | [--save-model] [--no-gpu] 110 | 111 | PyTorch EMNIST Example 112 | 113 | options: 114 | -h, --help show this help message and exit 115 | --batch-size N input batch size for training (default: 64) 116 | --test-batch-size N input batch size for testing (default: 1000) 117 | --epochs N number of epochs to train (default: 10) 118 | --lr LR base learning rate (default: 0.01) 119 | --lr-min LM minimum learning rate (default: 1e-5) 120 | --wd WD weight decay (default: 0.01) 121 | --beta2 B2 Adam's beta2 parameter (default: 0.999) 122 | --seed S random seed (default: 1) 123 | --log-interval N how many batches to wait before logging training 124 | status 125 | --warmup {linear,exponential,radam,none} 126 | warmup schedule 127 | --workers N number of dataloader workers for GPU training 128 | (default: 0) 129 | --save-model for saving the current model 130 | --no-gpu disable GPU training. As default, an MPS or CUDA 131 | device will be used if available. 
132 | ``` 133 | 134 | ``` 135 | usage: plot.py [-h] [--output {none,png,pdf}] PATH 136 | 137 | Training History Plot 138 | 139 | positional arguments: 140 | PATH path to the output directory of the training script 141 | 142 | options: 143 | -h, --help show this help message and exit 144 | --output {none,png,pdf} 145 | output file type (default: none) 146 | ``` 147 | 148 | © 2024 Takenori Yamamoto -------------------------------------------------------------------------------- /examples/emnist/download.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torchvision.datasets.utils import download_url 3 | import os 4 | 5 | raw_folder = 'data/EMNIST/raw' 6 | 7 | url = 'https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip' 8 | md5 = "58c8d27c78d21e728a6bc7b3cc06412e" 9 | 10 | version_numbers = list(map(int, torchvision.__version__.split('+')[0].split('.'))) 11 | if version_numbers[0] == 0 and version_numbers[1] < 10: 12 | filename = "emnist.zip" 13 | else: 14 | filename = None 15 | 16 | os.makedirs(raw_folder, exist_ok=True) 17 | 18 | # download files 19 | print('Downloading zip archive') 20 | download_url(url, root=raw_folder, filename=filename, md5=md5) 21 | -------------------------------------------------------------------------------- /examples/emnist/figs/accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/emnist/figs/accuracy.png -------------------------------------------------------------------------------- /examples/emnist/figs/fig-history-adamw-w-radam-warmup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/emnist/figs/fig-history-adamw-w-radam-warmup.png -------------------------------------------------------------------------------- /examples/emnist/figs/learning_rate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/emnist/figs/learning_rate.png -------------------------------------------------------------------------------- /examples/emnist/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | 8 | import pytorch_warmup as warmup 9 | import os 10 | import sys 11 | import time 12 | 13 | 14 | class Net(nn.Module): 15 | def __init__(self): 16 | super(Net, self).__init__() 17 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 18 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 19 | self.fc1 = nn.Linear(4*4*50, 500) 20 | self.fc2 = nn.Linear(500, 47) 21 | 22 | def forward(self, x): 23 | x = F.relu(self.conv1(x)) 24 | x = F.max_pool2d(x, 2, 2) 25 | x = F.relu(self.conv2(x)) 26 | x = F.max_pool2d(x, 2, 2) 27 | x = x.view(-1, 4*4*50) 28 | x = F.relu(self.fc1(x)) 29 | x = self.fc2(x) 30 | return F.log_softmax(x, dim=1) 31 | 32 | 33 | def train(args, model, device, train_loader, optimizer, lr_scheduler, 34 | warmup_scheduler, epoch, history): 35 | since = time.time() 36 | model.train() 37 | for batch_idx, (data, target) in enumerate(train_loader): 38 | lr = optimizer.param_groups[0]['lr'] 39 | data, target = 
data.to(device), target.to(device) 40 | optimizer.zero_grad() 41 | output = model(data) 42 | loss = F.nll_loss(output, target) 43 | loss.backward() 44 | optimizer.step() 45 | with warmup_scheduler.dampening(): 46 | lr_scheduler.step() 47 | if (batch_idx+1) % args.log_interval == 0: 48 | loss = loss.item() 49 | step = warmup_scheduler.last_step 50 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} LR: {:.6f}'.format( 51 | epoch, (batch_idx+1) * len(data), len(train_loader) * len(data), 52 | 100. * (batch_idx+1) / len(train_loader), loss, lr)) 53 | history.write(f'{epoch},{step},{loss:g},{lr:g}\n') 54 | print('Train Elapsed Time: {:.3f} sec'.format(time.time()-since)) 55 | 56 | 57 | def test(args, model, device, test_loader, epoch, evaluation): 58 | since = time.time() 59 | model.eval() 60 | test_loss = 0 61 | correct = 0 62 | with torch.no_grad(): 63 | for data, target in test_loader: 64 | data, target = data.to(device), target.to(device) 65 | output = model(data) 66 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 67 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 68 | correct += pred.eq(target.view_as(pred)).sum().item() 69 | 70 | test_loss /= len(test_loader.dataset) 71 | test_acc = 100. * correct / len(test_loader.dataset) 72 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( 73 | test_loss, correct, len(test_loader.dataset), test_acc)) 74 | evaluation.write(f'{epoch},{test_loss:g},{test_acc:.2f}\n') 75 | evaluation.flush() 76 | print('Test Elapsed Time: {:.3f} sec\n'.format(time.time()-since)) 77 | 78 | 79 | def mps_is_available(): 80 | try: 81 | return torch.backends.mps.is_available() 82 | except AttributeError: 83 | return False 84 | 85 | 86 | def gpu_device(): 87 | if torch.cuda.is_available(): 88 | return torch.device('cuda') 89 | elif mps_is_available(): 90 | return torch.device('mps') 91 | else: 92 | return torch.device('cpu') 93 | 94 | 95 | def dataloader_options(device, workers): 96 | if device.type == 'cpu': 97 | return {} 98 | 99 | kwargs = dict(num_workers=workers, pin_memory=True) 100 | if workers > 0: 101 | if device.type == 'mps': 102 | kwargs.update(dict(multiprocessing_context="forkserver", persistent_workers=True)) 103 | else: 104 | kwargs.update(dict(persistent_workers=True)) 105 | return kwargs 106 | 107 | 108 | def warmup_schedule(optimizer, name): 109 | if name == 'linear': 110 | return warmup.UntunedLinearWarmup(optimizer) 111 | elif name == 'exponential': 112 | return warmup.UntunedExponentialWarmup(optimizer) 113 | elif name == 'radam': 114 | return warmup.RAdamWarmup(optimizer) 115 | elif name == 'none': 116 | return warmup.LinearWarmup(optimizer, 1) 117 | 118 | 119 | def main(args=None): 120 | # Training settings 121 | parser = argparse.ArgumentParser(description='PyTorch EMNIST Example') 122 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 123 | help='input batch size for training (default: 64)') 124 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 125 | help='input batch size for testing (default: 1000)') 126 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 127 | help='number of epochs to train (default: 10)') 128 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 129 | help='base learning rate (default: 0.01)') 130 | parser.add_argument('--lr-min', type=float, default=1e-5, metavar='LM', 131 | help='minimum learning rate (default: 1e-5)') 132 | 
parser.add_argument('--wd', type=float, default=0.01, metavar='WD', 133 | help='weight decay (default: 0.01)') 134 | parser.add_argument('--beta2', type=float, default=0.999, metavar='B2', 135 | help="Adam's beta2 parameter (default: 0.999)") 136 | parser.add_argument('--seed', type=int, default=1, metavar='S', 137 | help='random seed (default: 1)') 138 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 139 | help='how many batches to wait before logging training status') 140 | parser.add_argument('--warmup', type=str, default='linear', 141 | choices=['linear', 'exponential', 'radam', 'none'], 142 | help='warmup schedule') 143 | parser.add_argument('--workers', type=int, default=0, metavar='N', 144 | help='number of dataloader workers for GPU training (default: 0)') 145 | parser.add_argument('--save-model', action='store_true', default=False, 146 | help='for saving the current model') 147 | parser.add_argument('--no-gpu', action='store_true', default=False, 148 | help='disable GPU training. ' + 149 | 'As default, an MPS or CUDA device will be used if available.') 150 | args = parser.parse_args(args) 151 | 152 | print(args) 153 | device = torch.device('cpu') if args.no_gpu else gpu_device() 154 | print(f'Device: {device.type}') 155 | 156 | torch.manual_seed(args.seed) 157 | 158 | kwargs = dataloader_options(device, args.workers) 159 | train_loader = torch.utils.data.DataLoader( 160 | datasets.EMNIST('data', 'balanced', train=True, download=True, 161 | transform=transforms.Compose([ 162 | transforms.ToTensor(), 163 | transforms.Normalize((0.1751,), (0.3332,)) 164 | ])), 165 | batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs) 166 | test_loader = torch.utils.data.DataLoader( 167 | datasets.EMNIST('data', 'balanced', train=False, 168 | transform=transforms.Compose([ 169 | transforms.ToTensor(), 170 | transforms.Normalize((0.1751,), (0.3332,)) 171 | ])), 172 | batch_size=args.test_batch_size, shuffle=False, **kwargs) 173 | 174 | output_dir = f'output_{args.warmup}' 175 | try: 176 | os.makedirs(output_dir, exist_ok=False) 177 | except FileExistsError: 178 | sys.exit(f'[Error] File exists: {output_dir}') 179 | 180 | history = open(os.path.join(output_dir, 'history.csv'), 'w') 181 | history.write('epoch,step,loss,lr\n') 182 | 183 | evaluation = open(os.path.join(output_dir, 'evaluation.csv'), 'w') 184 | evaluation.write('epoch,loss,accuracy\n') 185 | 186 | model = Net().to(device) 187 | 188 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, 189 | betas=(0.9, args.beta2), 190 | weight_decay=args.wd) 191 | num_steps = len(train_loader) * args.epochs 192 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR( 193 | optimizer, T_max=num_steps, eta_min=args.lr_min) 194 | warmup_scheduler = warmup_schedule(optimizer, args.warmup) 195 | 196 | for epoch in range(1, args.epochs + 1): 197 | train(args, model, device, train_loader, optimizer, lr_scheduler, 198 | warmup_scheduler, epoch, history) 199 | test(args, model, device, test_loader, epoch, evaluation) 200 | 201 | if args.save_model: 202 | torch.save(model.state_dict(), os.path.join(output_dir, "emnist_cnn.pt")) 203 | 204 | history.close() 205 | evaluation.close() 206 | 207 | 208 | if __name__ == '__main__': 209 | main() 210 | -------------------------------------------------------------------------------- /examples/emnist/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 
| 6 | 7 | def main(args=None): 8 | parser = argparse.ArgumentParser(description='Training History Plot') 9 | parser.add_argument('path', metavar='PATH', 10 | help='path to the output directory of the training script') 11 | parser.add_argument('--output', type=str, default='none', 12 | choices=['none', 'png', 'pdf'], 13 | help='output file type (default: none)') 14 | args = parser.parse_args(args) 15 | 16 | output_dir = args.path 17 | 18 | df_hist = pd.read_csv(os.path.join(output_dir, "history.csv")) 19 | df_eval = pd.read_csv(os.path.join(output_dir, "evaluation.csv")) 20 | 21 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 8), layout="constrained") 22 | 23 | df_hist.plot(x="step", y="lr", ax=ax1, legend=False) 24 | ax1.set_xlabel('Iteration') 25 | ax1.set_ylabel('Learning rate') 26 | 27 | min_loss = df_hist.loss.min() 28 | df_hist.plot(x="step", y="loss", ylim=(0, 1) if min_loss < 0.8 else None, ax=ax2, legend=False) 29 | ax2.set_xlabel('Iteration') 30 | ax2.set_ylabel('Loss') 31 | 32 | df_eval.plot(x="epoch", y="accuracy", marker="o", ax=ax3, legend=False) 33 | ax3.set_xlabel('Epoch') 34 | ax3.set_ylabel('Accuracy (%)') 35 | max_idx = df_eval.accuracy.argmax() 36 | max_epoch = df_eval.epoch[max_idx] 37 | max_acc = df_eval.accuracy[max_idx] 38 | min_acc = df_eval.accuracy.min() 39 | ax3.axvline(x=max_epoch, color='red', ls=':') 40 | ax3.text(x=max_epoch, y=min_acc, s=f'Max: {max_acc:.2f}% @ {max_epoch} ', ha='right', va='bottom') 41 | ax3.plot([0], [max_acc]) 42 | 43 | if args.output == 'none': 44 | plt.show() 45 | else: 46 | file_path = os.path.join(output_dir, f'fig_history.{args.output}') 47 | plt.savefig(file_path) 48 | print(file_path) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /examples/plots/README.md: -------------------------------------------------------------------------------- 1 | # Plots 2 | 3 | Requirements: `pytorch_warmup` and `matplotlib`. 4 | 5 | ## Effective Warmup Period 6 | 7 |

8 | [figure: Warmup period (figs/warmup_period.png)] 9 | Effective warmup periods of RAdam and rule-of-thumb warmup schedules, as a function of β₂. 10 |
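As a quick numerical cross-check of the plot above, the rule-of-thumb effective warmup periods at β₂ = 0.999 can be evaluated with the same closed forms used by `effective_warmup_period.py` (shown further below):

```python
# Effective warmup periods for beta2 = 0.999, following the
# untuned_linear_period / untuned_exponential_period formulas below.
import numpy as np

beta2 = 0.999
print(1.0 / (1.0 - beta2) - 0.5)              # untuned linear:      ~999.5
print(1.0 / (np.exp(1.0 - beta2) - 1.0))      # untuned exponential: ~999.5
```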

11 | 12 | Run the Python script `effective_warmup_period.py` to display the figure above: 13 | 14 | ```shell 15 | python effective_warmup_period.py 16 | ``` 17 | 18 | ### Usage 19 | 20 | ``` 21 | usage: effective_warmup_period.py [-h] [--output {none,png,pdf}] 22 | 23 | Effective warmup period 24 | 25 | options: 26 | -h, --help show this help message and exit 27 | --output {none,png,pdf} 28 | Output file type (default: none) 29 | ``` 30 | 31 | ## Warmup Schedule 32 | 33 |

34 | [figure: Warmup schedule (figs/warmup_schedule.png)] 35 | RAdam and rule-of-thumb warmup schedules over time for β₂ = 0.999. 36 |
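For a concrete point on the curves above, the untuned warmup factors at iteration t = 1000 for β₂ = 0.999 follow from the closed forms documented in `pytorch_warmup.untuned`; a small spot-check, separate from the plotting script:

```python
# Untuned warmup factors at t = 1000 for beta2 = 0.999.
import math

beta2, t = 0.999, 1000
print(min(1.0, (1.0 - beta2) / 2.0 * t))      # untuned linear:      0.5
print(1.0 - math.exp(-(1.0 - beta2) * t))     # untuned exponential: ~0.632
```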

37 | 38 | Run the Python script `warmup_schedule.py` to show up the figure above: 39 | 40 | ```shell 41 | python warmup_schedule.py 42 | ``` 43 | 44 | ### Usage 45 | 46 | ``` 47 | usage: warmup_schedule.py [-h] [--output {none,png,pdf}] 48 | 49 | Warmup schedule 50 | 51 | options: 52 | -h, --help show this help message and exit 53 | --output {none,png,pdf} 54 | Output file type (default: none) 55 | ``` 56 | 57 | © 2024 Takenori Yamamoto -------------------------------------------------------------------------------- /examples/plots/effective_warmup_period.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | from pytorch_warmup import rho_fn, rho_inf_fn, get_offset 4 | import numpy as np 5 | 6 | 7 | def untuned_exponential_period(beta2): 8 | return 1.0 / (np.exp(1.0 - beta2) - 1.0) 9 | 10 | 11 | def untuned_linear_period(beta2): 12 | return 1.0 / (1.0 - beta2) - 0.5 13 | 14 | 15 | def warmup_factor(step, beta2, rho_inf, offset): 16 | rho = rho_fn(step+offset, beta2, rho_inf) 17 | numerator = (rho - 4) * (rho - 2) * rho_inf 18 | denominator = (rho_inf - 4) * (rho_inf - 2) * rho 19 | return np.sqrt(numerator/denominator) 20 | 21 | 22 | def radam_period_fn(beta2): 23 | rho_inf = rho_inf_fn(beta2) 24 | offset = get_offset(beta2, rho_inf) 25 | steps = np.arange(1, 101) 26 | w = warmup_factor(steps, beta2, rho_inf, offset) 27 | total_sum = np.sum(1-w) 28 | t = 1 29 | while True: 30 | steps = np.arange(100*t+1, 100*(t+1)+1) 31 | w = warmup_factor(steps, beta2, rho_inf, offset) 32 | partial_sum = np.sum(1-w) 33 | if partial_sum < 0.1: 34 | break 35 | total_sum += partial_sum 36 | t += 1 37 | return total_sum 38 | 39 | 40 | def radam_period(beta2): 41 | return [radam_period_fn(x) for x in beta2] 42 | 43 | 44 | parser = argparse.ArgumentParser(description='Effective warmup period') 45 | parser.add_argument('--output', type=str, default='none', 46 | choices=['none', 'png', 'pdf'], 47 | help='Output file type (default: none)') 48 | args = parser.parse_args() 49 | 50 | beta2 = np.arange(0.99, 0.9999, 0.0001) 51 | plt.plot(beta2, untuned_exponential_period(beta2), label='Untuned Exponential') 52 | plt.plot(beta2, untuned_linear_period(beta2), linestyle=':', label='Untuned Linear') 53 | plt.plot(beta2, radam_period(beta2), linestyle='--', label='RAdam') 54 | plt.xlim(0.990, 1.00) 55 | plt.ylim(100, 10000) 56 | plt.yscale('log') 57 | plt.legend() 58 | plt.title('Effective Warmup Period') 59 | plt.xlabel(r'$\beta_{2}$') 60 | plt.ylabel(r'${\cal T}(\omega)$') 61 | if args.output == 'none': 62 | plt.show() 63 | else: 64 | plt.savefig(f'warmup_period.{args.output}') 65 | -------------------------------------------------------------------------------- /examples/plots/figs/warmup_period.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/plots/figs/warmup_period.png -------------------------------------------------------------------------------- /examples/plots/figs/warmup_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/examples/plots/figs/warmup_schedule.png -------------------------------------------------------------------------------- /examples/plots/warmup_schedule.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from pytorch_warmup import RAdamWarmup, UntunedExponentialWarmup, UntunedLinearWarmup 5 | 6 | 7 | def get_rates(warmup_cls, beta2, max_step): 8 | rates = [] 9 | p = torch.nn.Parameter(torch.arange(10, dtype=torch.float32)) 10 | optimizer = torch.optim.Adam([{'params': p}], lr=1.0, betas=(0.9, beta2)) 11 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0) 12 | warmup_scheduler = warmup_cls(optimizer) 13 | for step in range(1, max_step+1): 14 | rates.append(optimizer.param_groups[0]['lr']) 15 | optimizer.zero_grad() 16 | optimizer.step() 17 | lr_scheduler.step() 18 | warmup_scheduler.dampen() 19 | return rates 20 | 21 | 22 | parser = argparse.ArgumentParser(description='Warmup schedule') 23 | parser.add_argument('--output', type=str, default='none', 24 | choices=['none', 'png', 'pdf'], 25 | help='Output file type (default: none)') 26 | args = parser.parse_args() 27 | 28 | beta2 = 0.999 29 | max_step = 3000 30 | 31 | plt.plot(range(1, max_step+1), get_rates(RAdamWarmup, beta2, max_step), label='RAdam') 32 | plt.plot(range(1, max_step+1), get_rates(UntunedExponentialWarmup, beta2, max_step), label='Untuned Exponential') 33 | plt.plot(range(1, max_step+1), get_rates(UntunedLinearWarmup, beta2, max_step), label='Untuned Linear') 34 | plt.legend() 35 | plt.title('Warmup Schedule') 36 | plt.xlabel('Iteration') 37 | plt.ylabel(r'Warmup factor $(\omega_t)$') 38 | if args.output == 'none': 39 | plt.show() 40 | else: 41 | plt.savefig(f'warmup_schedule.{args.output}') 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [project] 7 | name = "pytorch-warmup" 8 | description = "A PyTorch Extension for Learning Rate Warmup" 9 | readme = "README.md" 10 | license = {file = "LICENSE"} 11 | authors = [ 12 | { name = "Takenori Yamamoto", email = "yamamoto.takenory@gmail.com" } 13 | ] 14 | classifiers = [ 15 | "Development Status :: 5 - Production/Stable", 16 | "Intended Audience :: Developers", 17 | "Intended Audience :: Education", 18 | "Intended Audience :: Science/Research", 19 | "Topic :: Scientific/Engineering", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 21 | "Topic :: Software Development", 22 | "Topic :: Software Development :: Libraries", 23 | "Topic :: Software Development :: Libraries :: Python Modules", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Operating System :: OS Independent", 31 | "License :: OSI Approved :: MIT License", 32 | ] 33 | requires-python = ">=3.9" 34 | dependencies = ["torch>=1.9"] 35 | dynamic = ["version"] 36 | 37 | 38 | [project.urls] 39 | "Homepage" = "https://github.com/Tony-Y/pytorch_warmup" 40 | "Bug Reports" = "https://github.com/Tony-Y/pytorch_warmup/issues" 41 | 42 | 43 | [tool.setuptools] 44 | packages = ["pytorch_warmup"] 45 | 46 | 47 | [tool.setuptools.dynamic] 48 | version = { attr = "pytorch_warmup.__version__" } 49 | 
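Before the package sources below, a minimal usage sketch assembled from the `Example` blocks in the docstrings that follow; the toy model and loss are placeholders, not part of the package:

```python
import torch
import pytorch_warmup as warmup

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, betas=(0.9, 0.999))
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)  # warmup period = 2 / (1 - beta2)

for _ in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(16, 10)).pow(2).mean()  # placeholder loss
    loss.backward()
    optimizer.step()
    with warmup_scheduler.dampening():  # dampen the LR set by the cosine schedule
        lr_scheduler.step()
```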
-------------------------------------------------------------------------------- /pytorch_warmup/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseWarmup, LinearWarmup, ExponentialWarmup 2 | from .untuned import UntunedLinearWarmup, UntunedExponentialWarmup 3 | from .radam import RAdamWarmup, rho_fn, rho_inf_fn, get_offset 4 | 5 | __version__ = "0.3.0.dev0" 6 | 7 | __all__ = [ 8 | 'BaseWarmup', 9 | 'LinearWarmup', 10 | 'ExponentialWarmup', 11 | 'UntunedLinearWarmup', 12 | 'UntunedExponentialWarmup', 13 | 'RAdamWarmup', 14 | 'rho_fn', 15 | 'rho_inf_fn', 16 | 'get_offset', 17 | ] 18 | -------------------------------------------------------------------------------- /pytorch_warmup/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from contextlib import contextmanager 3 | from torch import Tensor 4 | from torch.optim import Optimizer 5 | 6 | 7 | def _check_optimizer(optimizer): 8 | if not isinstance(optimizer, Optimizer): 9 | raise TypeError('{} ({}) is not an Optimizer.'.format( 10 | optimizer, type(optimizer).__name__)) 11 | 12 | 13 | def _get_lr(group): 14 | if isinstance(group['lr'], Tensor): 15 | return group['lr'].clone().detach() 16 | else: 17 | return group['lr'] 18 | 19 | 20 | def _set_lr(group, lr): 21 | if isinstance(group['lr'], Tensor): 22 | group['lr'].copy_(lr) 23 | else: 24 | group['lr'] = lr 25 | 26 | 27 | class BaseWarmup: 28 | """Base class for all warmup schedules. 29 | 30 | The learning rate :math:`\\alpha_{t}` is dampened by multiplying it by 31 | the warmup factor :math:`\\omega_{t} \\in [0, 1]` at each iteration :math:`t`. 32 | Thus, the modified learning rate 33 | 34 | .. math:: 35 | \\hat \\alpha_{t} = \\alpha_{t} \\cdot \\omega_{t} 36 | 37 | is used by the optimizer. 38 | 39 | Args: 40 | optimizer (Optimizer): Wrapped optimizer. 41 | warmup_params (list): Warmup parameters. 42 | last_step (int): The index of last step. Default: -1. 43 | """ 44 | 45 | def __init__(self, optimizer, warmup_params, last_step=-1): 46 | self.optimizer = optimizer 47 | self.warmup_params = warmup_params 48 | self.last_step = last_step 49 | self.lrs = [_get_lr(group) for group in self.optimizer.param_groups] 50 | self.dampen() 51 | 52 | def state_dict(self): 53 | """Returns the state of the warmup scheduler as a :class:`dict`. 54 | 55 | It contains an entry for every variable in :attr:`self.__dict__` which 56 | is not the optimizer. 57 | """ 58 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 59 | 60 | def load_state_dict(self, state_dict): 61 | """Loads the warmup scheduler's state. 62 | 63 | Args: 64 | state_dict (dict): Warmup scheduler state. Should be an object returned 65 | from a call to :meth:`state_dict`. 66 | """ 67 | self.__dict__.update(state_dict) 68 | 69 | def dampen(self, step=None): 70 | """Dampens the learning rate. 71 | 72 | It is not recommended to explicitly call this method for PyTorch 1.4.0 or later. 73 | Please use the :meth:`dampening` context manager that calls this method correctly. 74 | 75 | Args: 76 | step (int): The index of current step. Default: ``None``. 
77 | """ 78 | if step is None: 79 | step = self.last_step + 1 80 | self.last_step = step 81 | 82 | for group, params in zip(self.optimizer.param_groups, self.warmup_params): 83 | omega = self.warmup_factor(step, **params) 84 | group['lr'] *= omega 85 | 86 | @contextmanager 87 | def dampening(self): 88 | """Dampens the learning rate after calling the :meth:`step` method of the learning 89 | rate scheduler. 90 | 91 | The :meth:`step` method calls must be placed in a suite of the ``with`` statement having 92 | the :meth:`dampening` context manager. 93 | 94 | Examples: 95 | >>> # For no LR scheduler 96 | >>> with warmup_scheduler.dampening(): 97 | >>> pass 98 | 99 | >>> # For a single LR scheduler 100 | >>> with warmup_scheduler.dampening(): 101 | >>> lr_scheduler.step() 102 | 103 | >>> # To chain two LR schedulers 104 | >>> with warmup_scheduler.dampening(): 105 | >>> lr_scheduler1.step() 106 | >>> lr_scheduler2.step() 107 | 108 | >>> # To delay an LR scheduler 109 | >>> iteration = warmup_scheduler.last_step + 1 110 | >>> with warmup_scheduler.dampening(): 111 | >>> if iteration >= warmup_period: 112 | >>> lr_scheduler.step() 113 | """ 114 | for group, lr in zip(self.optimizer.param_groups, self.lrs): 115 | _set_lr(group, lr) 116 | yield 117 | self.lrs = [_get_lr(group) for group in self.optimizer.param_groups] 118 | self.dampen() 119 | 120 | def warmup_factor(self, step, **params): 121 | """Returns the warmup factor :math:`\\omega_{t}` at an iteration :math:`t`. 122 | 123 | :meth:`dampen` uses this method to get the warmup factor for each parameter group. 124 | It is unnecessary to explicitly call this method. 125 | 126 | Args: 127 | step (int): The index of current step. 128 | params (dict): The warmup parameters. For details, refer to the arguments of 129 | each subclass method. 130 | """ 131 | raise NotImplementedError 132 | 133 | 134 | def get_warmup_params(warmup_period, group_count): 135 | if isinstance(warmup_period, list): 136 | if len(warmup_period) != group_count: 137 | raise ValueError( 138 | 'The size of warmup_period ({}) does not match the size of param_groups ({}).'.format( 139 | len(warmup_period), group_count)) 140 | for x in warmup_period: 141 | if not isinstance(x, int): 142 | raise TypeError( 143 | 'An element in warmup_period, {}, is not an int.'.format( 144 | type(x).__name__)) 145 | if x <= 0: 146 | raise ValueError( 147 | 'An element in warmup_period must be a positive integer, but is {}.'.format(x)) 148 | warmup_params = [dict(warmup_period=x) for x in warmup_period] 149 | elif isinstance(warmup_period, int): 150 | if warmup_period <= 0: 151 | raise ValueError( 152 | 'warmup_period must be a positive integer, but is {}.'.format(warmup_period)) 153 | warmup_params = [dict(warmup_period=warmup_period) 154 | for _ in range(group_count)] 155 | else: 156 | raise TypeError('{} ({}) is not a list nor an int.'.format( 157 | warmup_period, type(warmup_period).__name__)) 158 | return warmup_params 159 | 160 | 161 | class LinearWarmup(BaseWarmup): 162 | """Linear warmup schedule. 163 | 164 | The linear warmup schedule uses the warmup factor 165 | 166 | .. math:: 167 | \\omega_{t}^{\\rm linear, \\tau} = \\min \\left\\{ 1, \\frac{1}{\\tau} \\cdot t \\right\\} 168 | 169 | at each iteration :math:`t`, where :math:`\\tau` is the warmup period. 170 | 171 | Args: 172 | optimizer (Optimizer): Wrapped optimizer. :class:`RAdam` is not suitable because of the 173 | warmup redundancy. 174 | warmup_period (int or list[int]): The warmup period :math:`\\tau`. 
175 | last_step (int): The index of last step. Default: -1. 176 | 177 | Example: 178 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) 179 | >>> warmup_scheduler = LinearWarmup(optimizer, warmup_period=2000) 180 | >>> for batch in dataloader: 181 | >>> optimizer.zero_grad() 182 | >>> loss = ... 183 | >>> loss.backward() 184 | >>> optimizer.step() 185 | >>> with warmup_scheduler.dampening(): 186 | >>> lr_scheduler.step() 187 | 188 | Warning: 189 | The warmup schedule must not be initialized before the initialization of the learning rate schedule. 190 | """ 191 | 192 | def __init__(self, optimizer, warmup_period, last_step=-1): 193 | _check_optimizer(optimizer) 194 | group_count = len(optimizer.param_groups) 195 | warmup_params = get_warmup_params(warmup_period, group_count) 196 | super().__init__(optimizer, warmup_params, last_step) 197 | 198 | def warmup_factor(self, step, warmup_period): 199 | """Returns the warmup factor :math:`\\omega_{t}^{\\rm linear, \\tau}` at an iteration :math:`t`. 200 | 201 | Args: 202 | step (int): The index of current step. 203 | warmup_period (int): The warmup period :math:`\\tau`. 204 | """ 205 | return min(1.0, (step+1) / warmup_period) 206 | 207 | 208 | class ExponentialWarmup(BaseWarmup): 209 | """Exponential warmup schedule. 210 | 211 | The exponential warmup schedule uses the warmup factor 212 | 213 | .. math:: 214 | \\omega_{t}^{\\rm expo, \\tau} = 1 - \\exp \\left( - \\frac{1}{\\tau} \\cdot t \\right) 215 | 216 | at each iteration :math:`t`, where the constant :math:`\\tau` is analogous to 217 | a linear warmup period. 218 | 219 | Args: 220 | optimizer (Optimizer): Wrapped optimizer. :class:`RAdam` is not suitable because of the 221 | warmup redundancy. 222 | warmup_period (int or list[int]): The constant :math:`\\tau` analogous to a linear warmup period. 223 | last_step (int): The index of last step. Default: -1. 224 | 225 | Example: 226 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) 227 | >>> warmup_scheduler = ExponentialWarmup(optimizer, warmup_period=1000) 228 | >>> for batch in dataloader: 229 | >>> optimizer.zero_grad() 230 | >>> loss = ... 231 | >>> loss.backward() 232 | >>> optimizer.step() 233 | >>> with warmup_scheduler.dampening(): 234 | >>> lr_scheduler.step() 235 | 236 | Warning: 237 | The warmup schedule must not be initialized before the initialization of the learning rate schedule. 238 | """ 239 | 240 | def __init__(self, optimizer, warmup_period, last_step=-1): 241 | _check_optimizer(optimizer) 242 | group_count = len(optimizer.param_groups) 243 | warmup_params = get_warmup_params(warmup_period, group_count) 244 | super().__init__(optimizer, warmup_params, last_step) 245 | 246 | def warmup_factor(self, step, warmup_period): 247 | """Returns the warmup factor :math:`\\omega_{t}^{\\rm expo, \\tau}` at an iteration :math:`t`. 248 | 249 | Args: 250 | step (int): The index of current step. 251 | warmup_period (int): The constant :math:`\\tau` analogous to a linear warmup period. 252 | """ 253 | return 1.0 - math.exp(-(step+1) / warmup_period) 254 | -------------------------------------------------------------------------------- /pytorch_warmup/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | from .base import BaseWarmup, _check_optimizer 3 | 4 | 5 | def rho_inf_fn(beta2): 6 | """Returns the constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. 7 | 8 | Args: 9 | beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. 
10 | """ 11 | return 2.0 / (1 - beta2) - 1 12 | 13 | 14 | def rho_fn(t, beta2, rho_inf): 15 | """Returns the value of the function of the RAdam algorithm, :math:`\\rho_{t}`, 16 | at an iteration :math:`t`. 17 | 18 | Args: 19 | t (int): The iteration :math:`t`. 20 | beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. 21 | rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. 22 | """ 23 | b2t = beta2 ** t 24 | rho_t = rho_inf - 2 * t * b2t / (1 - b2t) 25 | return rho_t 26 | 27 | 28 | def get_offset(beta2, rho_inf): 29 | """Returns the minimal offset :math:`\\delta`. 30 | 31 | Args: 32 | beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. 33 | rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. 34 | """ 35 | if not beta2 > 0.6: 36 | raise ValueError('beta2 ({}) must be greater than 0.6'.format(beta2)) 37 | offset = 1 38 | while True: 39 | if rho_fn(offset, beta2, rho_inf) > 4: 40 | return offset 41 | offset += 1 42 | 43 | 44 | class RAdamWarmup(BaseWarmup): 45 | """RAdam warmup schedule. 46 | 47 | This warmup scheme is described in 48 | `On the adequacy of untuned warmup for adaptive optimization 49 | `_. 50 | 51 | The RAdam algorithm uses the warmup factor 52 | 53 | .. math:: 54 | \\omega_{t}^{\\rm RAdam} = \\sqrt{ \\frac{ \\ 55 | ( \\rho_{t} - 4 ) ( \\rho_{t} - 2 ) \\rho_{\\infty} }{ \\ 56 | ( \\rho_{\\infty} - 4) (\\rho_{\\infty} - 2 ) \\rho_{t} } } 57 | 58 | at each iteration :math:`t` for :math:`\\rho_{t} > 4`, where 59 | 60 | .. math:: 61 | \\rho_{\\infty} = \\frac{ 2 }{ 1 - \\beta_{2} } - 1 62 | 63 | and 64 | 65 | .. math:: 66 | \\rho_{t} = \\rho_{\\infty} - \\frac{ 2 t \\cdot \\beta_{2}^{t} }{ 1 - \\beta_{2}^{t} } 67 | 68 | where :math:`\\beta_{2}` is the second discount factor of Adam. In the RAdam warmup schedule, 69 | the minimal offset :math:`\\delta` is chosen such that :math:`\\rho_{\\delta} > 4`, and then 70 | :math:`\\omega_{t+\\delta-1}^{\\rm RAdam}` is employed as the warmup factor at each iteration :math:`t`. 71 | For all practically relevant values of :math:`\\beta_{2}` (:math:`0.8 < \\beta_{2} \\le 1`), 72 | :math:`\\delta \\le 5` as deduced from Fact 3.1 of the paper. 73 | 74 | Args: 75 | optimizer (Optimizer): Adam optimizer or its variant: 76 | :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`. 77 | :class:`RAdam` is not suitable because of the warmup redundancy. This warmup 78 | schedule makes no sense for :class:`Adamax` and, in principle, the AMSGrad variant of 79 | :class:`Adam` and :class:`AdamW` as discussed in Note below. In practice, this warmup 80 | schedule improves the performance of the AMSGrad variant like that of the vanilla Adam. 81 | last_step (int): The index of last step. Default: -1. 82 | 83 | Note: 84 | This warmup schedule employs the same warmup factor for all variants of Adam. However, 85 | according to the RAdam theory, 86 | :class:`Adamax` and the AMSGrad variant of :class:`Adam` and :class:`AdamW` should 87 | have a different warmup factor because its :math:`\\psi(\\cdot)` function is different from one of the 88 | vanilla Adam, where :math:`\\psi(\\cdot)` specifies how the adaptive learning rate at :math:`t` is 89 | calculated. The RAdam theory derives the warmup factor :math:`\\omega_{t}` from 90 | :math:`\\psi(\\cdot)`. For gradients :math:`\\left\\{ g_{i} \\right\\}` viewed as i.i.d. normal random 91 | variables, 92 | 93 | .. 
centered:: 94 | :math:`\\omega_{t} = \\sqrt{ C_{\\rm var} / {\\rm Var}\\left[ \\psi(g_{1}, \\dots, g_{t}) \\right] }` 95 | 96 | where 97 | 98 | .. centered:: 99 | :math:`C_{\\rm var} = \\inf_{t} {\\rm Var}\\left[ \\psi(g_{1}, \\dots, g_{t}) \\right]`. 100 | 101 | (For details please refer to `On the Variance of the Adaptive Learning Rate and Beyond 102 | `_.) 103 | 104 | The variance hypothesis of the RAdam theory has become questionable 105 | since Ma and Yarats' paper pointed out that the adaptive learning rate may not be the best medium 106 | of analysis for understanding the role of warmup in Adam. 107 | 108 | Example: 109 | >>> optimizer = AdamW(...) 110 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...) 111 | >>> warmup_scheduler = RAdamWarmup(optimizer) 112 | >>> for batch in dataloader: 113 | >>> optimizer.zero_grad() 114 | >>> loss = ... 115 | >>> loss.backward() 116 | >>> optimizer.step() 117 | >>> with warmup_scheduler.dampening(): 118 | >>> lr_scheduler.step() 119 | 120 | Warning: 121 | The warmup schedule must not be initialized before the initialization of the learning rate schedule. 122 | """ 123 | 124 | def __init__(self, optimizer, last_step=-1): 125 | _check_optimizer(optimizer) 126 | warmup_params = [ 127 | dict( 128 | beta2=x['betas'][1], 129 | rho_inf=rho_inf_fn(x['betas'][1]), 130 | ) 131 | for x in optimizer.param_groups 132 | ] 133 | for x in warmup_params: 134 | x['offset'] = get_offset(**x) 135 | super().__init__(optimizer, warmup_params, last_step) 136 | 137 | def warmup_factor(self, step, beta2, rho_inf, offset): 138 | """Returns the warmup factor :math:`\\omega_{t+\\delta-1}^{\\rm RAdam}` at an iteration :math:`t`. 139 | 140 | Args: 141 | step (int): The index of current step. 142 | beta2 (float): The second discount factor of Adam, :math:`\\beta_{2}`. 143 | rho_inf (float): The constant of the RAdam algorithm, :math:`\\rho_{\\infty}`. 144 | offset (int): The minimal offset :math:`\\delta`. 145 | """ 146 | rho = rho_fn(step+offset, beta2, rho_inf) 147 | numerator = (rho - 4) * (rho - 2) * rho_inf 148 | denominator = (rho_inf - 4) * (rho_inf - 2) * rho 149 | return math.sqrt(numerator/denominator) 150 | -------------------------------------------------------------------------------- /pytorch_warmup/untuned.py: -------------------------------------------------------------------------------- 1 | from .base import LinearWarmup, ExponentialWarmup, _check_optimizer 2 | 3 | 4 | class UntunedLinearWarmup(LinearWarmup): 5 | """Untuned linear warmup schedule for Adam. 6 | 7 | This warmup scheme is described in 8 | `On the adequacy of untuned warmup for adaptive optimization 9 | `_. 10 | 11 | The untuned linear warmup schedule uses the warmup factor 12 | 13 | .. math:: 14 | \\omega_{t}^{\\rm linear, untuned} = \\min \\left\\{ 1, \\frac{1 - \\beta_{2}}{2} \\cdot t \\right\\} 15 | 16 | at each iteration :math:`t`, where :math:`\\beta_{2}` is the second discount factor of Adam. 17 | In practice, :math:`\\omega_{t}^{\\rm linear, untuned}` is calculated as 18 | :math:`\\omega_{t}^{\\rm linear, \\tau}` with :math:`\\tau = \\frac{2}{1 - \\beta_{2}}`. 19 | 20 | Note: 21 | The effective warmup period is defined as 22 | 23 | .. centered:: 24 | :math:`{\\cal T}(\\omega) = \\sum_{t = 1}^{\\infty} \\left( 1 - \\omega_{t} \\right)` 25 | 26 | for a warmup schedule :math:`\\omega = \\left\\{ \\omega_{t} \\right\\}_{t=1}^{\\infty}`. 27 | The warmup period :math:`\\tau` is deduced from solving approximately the rough equivalence: 28 | 29 | .. 
centered::
30 | :math:`{\\cal T}(\\omega^{\\rm expo, untuned}) \\approx {\\cal T}(\\omega^{{\\rm linear},
31 | \\tau}) \\approx \\frac{\\tau}{2}`.
32 |
33 | Args:
34 | optimizer (Optimizer): Adam optimizer or its variant:
35 | :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`.
36 | :class:`RAdam` is not suitable because of the warmup redundancy. This warmup
37 | schedule makes no sense for :class:`Adamax` as discussed in the Note below.
38 | last_step (int): The index of last step. Default: -1.
39 |
40 | Note:
41 | This warmup schedule employs the same warmup period :math:`\\tau` for all variants of Adam. However,
42 | :class:`Adamax` should in principle need no linear warmup because it needs no exponential warmup.
43 | For further details please refer to the Note in the documentation of :class:`UntunedExponentialWarmup`.
44 | In practice, a linear warmup may slightly improve AdaMax's performance because the initial update step
45 | is the same as that of the Adam optimizer.
46 |
47 | Example:
48 | >>> optimizer = AdamW(...)
49 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...)
50 | >>> warmup_scheduler = UntunedLinearWarmup(optimizer)
51 | >>> for batch in dataloader:
52 | >>> optimizer.zero_grad()
53 | >>> loss = ...
54 | >>> loss.backward()
55 | >>> optimizer.step()
56 | >>> with warmup_scheduler.dampening():
57 | >>> lr_scheduler.step()
58 |
59 | Warning:
60 | The warmup schedule must not be initialized before the initialization of the learning rate schedule.
61 | """
62 |
63 | def __init__(self, optimizer, last_step=-1):
64 | _check_optimizer(optimizer)
65 |
66 | def warmup_period_fn(beta2):
67 | return int(2.0 / (1.0-beta2))
68 | warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups]
69 | super().__init__(optimizer, warmup_period, last_step)
70 |
71 |
72 | class UntunedExponentialWarmup(ExponentialWarmup):
73 | """Untuned exponential warmup schedule for Adam.
74 |
75 | This warmup scheme is described in
76 | `On the adequacy of untuned warmup for adaptive optimization
77 | `_.
78 |
79 | The untuned exponential warmup schedule uses the warmup factor
80 |
81 | .. math::
82 | \\omega_{t}^{\\rm expo, untuned} = 1 - \\exp \\left( - (1 - \\beta_{2}) \\cdot t \\right)
83 |
84 | at each iteration :math:`t`, where :math:`\\beta_{2}` is the second discount factor of Adam.
85 | In practice, :math:`\\omega_{t}^{\\rm expo, untuned}` is calculated as
86 | :math:`\\omega_{t}^{\\rm expo, \\tau}` with :math:`\\tau = \\frac{1}{1 - \\beta_{2}}`.
87 |
88 | Note:
89 | The constant :math:`\\tau` is derived from the intuition that
90 | the warmup factor should be roughly equivalent to Adam's second moment bias correction term,
91 | :math:`1 - \\beta_{2}^{t}`.
92 |
93 | Note:
94 | The effective warmup period is defined as
95 |
96 | .. centered::
97 | :math:`{\\cal T}(\\omega) = \\sum_{t = 1}^{\\infty} \\left( 1 - \\omega_{t} \\right)`
98 |
99 | for a warmup schedule :math:`\\omega = \\left\\{ \\omega_{t} \\right\\}_{t=1}^{\\infty}`.
100 | The constant :math:`\\tau` of the untuned exponential warmup schedule is roughly equivalent to
101 | its effective warmup period:
102 |
103 | .. centered::
104 | :math:`{\\cal T}(\\omega^{\\rm expo, untuned}) = 1 / \\left( \\exp( 1 - \\beta_{2}) - 1 \\right) \\approx \\tau`
105 |
106 | for :math:`\\beta_{2}` near 1. The rough equivalence is also achieved for an exponential warmup schedule
107 | if its :math:`\\tau` is large enough, for example, :math:`\\tau \\ge 1`.
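
As a concrete cross-check of the two untuned schedules documented in this file, the following standalone sketch (not part of untuned.py; it merely re-implements the warmup factors quoted above, and the variable names are illustrative rather than library API) evaluates both schedules for beta2 = 0.7, the value exercised by the test suite later in this document. With this beta2, the untuned linear schedule reduces to a warmup period of 6 steps and the untuned exponential schedule to a time constant of 3 steps:

import math

beta2 = 0.7  # second discount factor of Adam, as used in test_untuned.py

# Warmup periods derived exactly as in UntunedLinearWarmup and UntunedExponentialWarmup:
tau_linear = int(2.0 / (1.0 - beta2))  # -> 6
tau_expo = int(1.0 / (1.0 - beta2))    # -> 3

for t in range(1, 8):
    omega_linear = min(1.0, t / tau_linear)     # omega_t^{linear, tau}
    omega_expo = 1.0 - math.exp(-t / tau_expo)  # omega_t^{expo, tau}
    print(t, round(omega_linear, 4), round(omega_expo, 4))

With a base learning rate of 0.5, these factors reproduce the dampened learning rates asserted in test_untuned.py (0.5 * step / 6 during the linear warmup and 0.5 * (1 - exp(-step / 3)) for the exponential warmup).
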
108 |
109 | Args:
110 | optimizer (Optimizer): Adam optimizer or its variant:
111 | :class:`Adam`, :class:`AdamW`, :class:`SparseAdam`, or :class:`NAdam`.
112 | :class:`RAdam` is not suitable because of the warmup redundancy. This warmup
113 | schedule makes no sense for :class:`Adamax` as discussed in the Note below.
114 | last_step (int): The index of last step. Default: -1.
115 |
116 | Note:
117 | This warmup schedule employs the same constant :math:`\\tau` for all variants of Adam. However,
118 | :class:`Adamax` should in principle need no warmup because :class:`Adamax` is derived by employing
119 | an :math:`L^{p}` norm update rule and letting :math:`p \\rightarrow \\infty`, and the second moment bias
120 | correction term is :math:`1-\\beta_{2}^{pt}`, to which the warmup factor must be roughly equivalent
121 | in the derivation of this warmup schedule. In practice, an exponential warmup may slightly improve AdaMax's
122 | performance because the initial update step is the same as that of the Adam optimizer.
123 |
124 | Example:
125 | >>> optimizer = AdamW(...)
126 | >>> lr_scheduler = CosineAnnealingLR(optimizer, ...)
127 | >>> warmup_scheduler = UntunedExponentialWarmup(optimizer)
128 | >>> for batch in dataloader:
129 | >>> optimizer.zero_grad()
130 | >>> loss = ...
131 | >>> loss.backward()
132 | >>> optimizer.step()
133 | >>> with warmup_scheduler.dampening():
134 | >>> lr_scheduler.step()
135 |
136 | Warning:
137 | The warmup schedule must not be initialized before the initialization of the learning rate schedule.
138 | """
139 |
140 | def __init__(self, optimizer, last_step=-1):
141 | _check_optimizer(optimizer)
142 |
143 | def warmup_period_fn(beta2):
144 | return int(1.0 / (1.0-beta2))
145 | warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups]
146 | super().__init__(optimizer, warmup_period, last_step)
147 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony-Y/pytorch_warmup/ed598d3aefc232aeec7f342e039e06c2e0056251/test/__init__.py
--------------------------------------------------------------------------------
/test/test_base.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import math
4 | import torch
5 | from torch import Tensor
6 | import pytorch_warmup as warmup
7 |
8 |
9 | def _test_state_dict(self, warmup_scheduler, constructor):
10 | warmup_scheduler_copy = constructor()
11 | warmup_scheduler_copy.load_state_dict(warmup_scheduler.state_dict())
12 | for key in warmup_scheduler.__dict__.keys():
13 | if key not in ['optimizer']:
14 | print(key, warmup_scheduler.__dict__[key])
15 | self.assertAlmostEqual(warmup_scheduler.__dict__[key],
16 | warmup_scheduler_copy.__dict__[key])
17 |
18 |
19 | def _test_optimizer(self, warmup_class):
20 | with self.assertRaises(TypeError, msg='optimizer type') as cm:
21 | warmup_class(optimizer=0, warmup_period=5)
22 | self.assertEqual(str(cm.exception), '0 (int) is not an Optimizer.')
23 |
24 |
25 | def _test_get_warmup_params(self, optimizer, warmup_class):
26 | with self.assertRaises(ValueError, msg='warmup_period size') as cm:
27 | warmup_class(optimizer, warmup_period=[5])
28 | self.assertEqual(str(cm.exception), 'The size of warmup_period (1) does not match the size of param_groups (2).')
29 |
30 | with self.assertRaises(TypeError, msg='warmup_period element type') as cm:
31 |
warmup_class(optimizer, warmup_period=[5.0, 10.0]) 32 | self.assertEqual(str(cm.exception), 'An element in warmup_period, float, is not an int.') 33 | 34 | with self.assertRaises(ValueError, msg='warmup_period element range') as cm: 35 | warmup_class(optimizer, warmup_period=[5, 0]) 36 | self.assertEqual(str(cm.exception), 'An element in warmup_period must be a positive integer, but is 0.') 37 | 38 | with self.assertRaises(ValueError, msg='warmup_period range') as cm: 39 | warmup_class(optimizer, warmup_period=0) 40 | self.assertEqual(str(cm.exception), 'warmup_period must be a positive integer, but is 0.') 41 | 42 | with self.assertRaises(TypeError, msg='warmup_period type') as cm: 43 | warmup_class(optimizer, warmup_period=5.0) 44 | self.assertEqual(str(cm.exception), '5.0 (float) is not a list nor an int.') 45 | 46 | 47 | def _wrap(value): 48 | return torch.tensor(value) 49 | 50 | 51 | def _unwrap(value): 52 | assert isinstance(value, Tensor) 53 | assert value.dim() == 0 54 | return value.item() 55 | 56 | 57 | def _identity(value): 58 | return value 59 | 60 | 61 | def _assert_naked(value): 62 | assert not isinstance(value, Tensor) 63 | return value 64 | 65 | 66 | _set_lr, _get_lr = (_wrap, _unwrap) if 'WRAPPED_LR' in os.environ else (_identity, _assert_naked) 67 | 68 | 69 | class TestBase(unittest.TestCase): 70 | 71 | def setUp(self): 72 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 73 | 74 | def test_linear(self): 75 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 76 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 77 | optimizer = torch.optim.SGD([ 78 | {'params': [p1]}, 79 | {'params': [p2], 'lr': _set_lr(0.1)} 80 | ], lr=_set_lr(0.5)) 81 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0) 82 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=5) 83 | print() 84 | for step in range(1, 11): 85 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups] 86 | print(f'{step} {lr}') 87 | if step < 5: 88 | self.assertAlmostEqual(lr[0], 0.5 * step / 5) 89 | self.assertAlmostEqual(lr[1], 0.1 * step / 5) 90 | else: 91 | self.assertAlmostEqual(lr[0], 0.5) 92 | self.assertAlmostEqual(lr[1], 0.1) 93 | optimizer.zero_grad() 94 | optimizer.step() 95 | with warmup_scheduler.dampening(): 96 | lr_scheduler.step() 97 | 98 | _test_state_dict(self, warmup_scheduler, 99 | lambda: warmup.LinearWarmup(optimizer, warmup_period=10)) 100 | 101 | _test_optimizer(self, warmup.LinearWarmup) 102 | 103 | _test_get_warmup_params(self, optimizer, warmup.LinearWarmup) 104 | 105 | def test_exponential(self): 106 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 107 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 108 | optimizer = torch.optim.SGD([ 109 | {'params': [p1]}, 110 | {'params': [p2], 'lr': _set_lr(0.1)} 111 | ], lr=_set_lr(0.5)) 112 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0) 113 | warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=5) 114 | print() 115 | for step in range(1, 11): 116 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups] 117 | print(f'{step} {lr}') 118 | self.assertAlmostEqual(lr[0], 0.5 * (1 - math.exp(-step / 5))) 119 | self.assertAlmostEqual(lr[1], 0.1 * (1 - math.exp(-step / 5))) 120 | optimizer.zero_grad() 121 | optimizer.step() 122 | with warmup_scheduler.dampening(): 123 | lr_scheduler.step() 124 | 
125 | _test_state_dict(self, warmup_scheduler, 126 | lambda: warmup.ExponentialWarmup(optimizer, warmup_period=10)) 127 | 128 | _test_optimizer(self, warmup.ExponentialWarmup) 129 | 130 | _test_get_warmup_params(self, optimizer, warmup.ExponentialWarmup) 131 | 132 | def test_linear_chaining(self): 133 | def preparation(): 134 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 135 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 136 | optimizer = torch.optim.SGD([ 137 | {'params': [p1]}, 138 | {'params': [p2], 'lr': _set_lr(0.1)} 139 | ], lr=_set_lr(0.5)) 140 | lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 141 | lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1) 142 | return optimizer, lr_scheduler1, lr_scheduler2 143 | 144 | # First record undampened lrs 145 | optimizer, lr_scheduler1, lr_scheduler2 = preparation() 146 | lrs = [] 147 | print() 148 | print('Undampened:') 149 | for step in range(1, 11): 150 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups] 151 | print(f'{step} {lr}') 152 | lrs.append(lr) 153 | optimizer.zero_grad() 154 | optimizer.step() 155 | lr_scheduler1.step() 156 | lr_scheduler2.step() 157 | 158 | optimizer, lr_scheduler1, lr_scheduler2 = preparation() 159 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=5) 160 | print('Dampened:') 161 | for step in range(1, 11): 162 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups] 163 | print(f'{step} {lr}') 164 | if step < 5: 165 | self.assertAlmostEqual(lr[0], lrs[step-1][0] * step / 5) 166 | self.assertAlmostEqual(lr[1], lrs[step-1][1] * step / 5) 167 | else: 168 | self.assertAlmostEqual(lr[0], lrs[step-1][0]) 169 | self.assertAlmostEqual(lr[1], lrs[step-1][1]) 170 | optimizer.zero_grad() 171 | optimizer.step() 172 | with warmup_scheduler.dampening(): 173 | lr_scheduler1.step() 174 | lr_scheduler2.step() 175 | 176 | _test_state_dict(self, warmup_scheduler, 177 | lambda: warmup.LinearWarmup(optimizer, warmup_period=10)) 178 | 179 | def test_linear_wo_lr_scheduler(self): 180 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 181 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 182 | optimizer = torch.optim.SGD([ 183 | {'params': [p1]}, 184 | {'params': [p2], 'lr': _set_lr(0.1)} 185 | ], lr=_set_lr(0.5)) 186 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=5) 187 | print() 188 | for step in range(1, 11): 189 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups] 190 | print(f'{step} {lr}') 191 | if step < 5: 192 | self.assertAlmostEqual(lr[0], 0.5 * step / 5) 193 | self.assertAlmostEqual(lr[1], 0.1 * step / 5) 194 | else: 195 | self.assertAlmostEqual(lr[0], 0.5) 196 | self.assertAlmostEqual(lr[1], 0.1) 197 | optimizer.zero_grad() 198 | optimizer.step() 199 | with warmup_scheduler.dampening(): 200 | pass 201 | 202 | _test_state_dict(self, warmup_scheduler, 203 | lambda: warmup.LinearWarmup(optimizer, warmup_period=10)) 204 | -------------------------------------------------------------------------------- /test/test_radam.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import math 3 | import torch 4 | import pytorch_warmup as warmup 5 | 6 | from .test_base import _test_state_dict, _set_lr, _get_lr 7 | from .test_untuned import _test_optimizer 8 | 9 | 10 | # The expression of each warmup factor 11 | # offset: 6 12 | # beta2: 7/10 13 | # 
rho_inf: 17/3 14 | ewf = { 15 | 1: 3*math.sqrt(55795775461652507765)/126458170465, 16 | 2: 3*math.sqrt(118975550786574877912913153615)/2269508758199815, 17 | 3: 3*math.sqrt(364009685132320107701159663753)/2992977113632385, 18 | 4: 3*math.sqrt(258572826689968392763003617038979)/68225651323259287, 19 | 5: 3*math.sqrt(668289519821298522824847043230807053)/3138599717744915303, 20 | 6: 3*math.sqrt(60431582784117573249154657184784100939048735)/27879860688339331112605, 21 | 7: 3*math.sqrt(3668869686599344602586804992292010752258094185)/207030521845988349697045, 22 | 8: 3*math.sqrt(38610293903545800493859693214989542002076301625518651)/648745826249577848268496415, 23 | 9: 3*math.sqrt(12026080263946093429637752207887183661294840713819813)/353035787321509409011021039, 24 | 10: 3*math.sqrt(456865593113897246792694328842050091932272202605586577311)/67546148486329926220639511801, 25 | } 26 | 27 | 28 | class TestRAdam(unittest.TestCase): 29 | 30 | def setUp(self): 31 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 32 | 33 | def test_radam(self): 34 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 35 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 36 | optimizer = torch.optim.Adam([ 37 | {'params': [p1]}, 38 | {'params': [p2], 'lr': _set_lr(0.1)} 39 | ], lr=_set_lr(0.5), betas=(0.9, 0.7)) 40 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0) 41 | warmup_scheduler = warmup.RAdamWarmup(optimizer) 42 | print() 43 | for step in range(1, 11): 44 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups] 45 | print(f'{step} {lr}') 46 | self.assertAlmostEqual(lr[0], 0.5 * ewf[step]) 47 | self.assertAlmostEqual(lr[1], 0.1 * ewf[step]) 48 | optimizer.zero_grad() 49 | optimizer.step() 50 | with warmup_scheduler.dampening(): 51 | lr_scheduler.step() 52 | 53 | _test_state_dict(self, warmup_scheduler, 54 | lambda: warmup.RAdamWarmup(optimizer)) 55 | 56 | _test_optimizer(self, warmup.RAdamWarmup) 57 | -------------------------------------------------------------------------------- /test/test_untuned.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import math 3 | import torch 4 | import pytorch_warmup as warmup 5 | 6 | from .test_base import _test_state_dict, _set_lr, _get_lr 7 | 8 | 9 | def _test_optimizer(self, warmup_class): 10 | with self.assertRaises(TypeError, msg='optimizer type') as cm: 11 | warmup_class(optimizer=0) 12 | self.assertEqual(str(cm.exception), '0 (int) is not an Optimizer.') 13 | 14 | 15 | class TestUntuned(unittest.TestCase): 16 | 17 | def setUp(self): 18 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 19 | 20 | def test_untuned_linear(self): 21 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 22 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 23 | optimizer = torch.optim.Adam([ 24 | {'params': [p1]}, 25 | {'params': [p2], 'lr': _set_lr(0.1)} 26 | ], lr=_set_lr(0.5), betas=(0.9, 0.7)) 27 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0) 28 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) 29 | print() 30 | for step in range(1, 11): 31 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups] 32 | print(f'{step} {lr}') 33 | if step < 6: 34 | self.assertAlmostEqual(lr[0], 0.5 * step / 6) 35 | self.assertAlmostEqual(lr[1], 0.1 * step / 6) 36 
| else: 37 | self.assertAlmostEqual(lr[0], 0.5) 38 | self.assertAlmostEqual(lr[1], 0.1) 39 | optimizer.zero_grad() 40 | optimizer.step() 41 | with warmup_scheduler.dampening(): 42 | lr_scheduler.step() 43 | 44 | _test_state_dict(self, warmup_scheduler, 45 | lambda: warmup.UntunedLinearWarmup(optimizer)) 46 | 47 | _test_optimizer(self, warmup.UntunedLinearWarmup) 48 | 49 | def test_untuned_exponential(self): 50 | p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 51 | p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device)) 52 | optimizer = torch.optim.Adam([ 53 | {'params': [p1]}, 54 | {'params': [p2], 'lr': _set_lr(0.1)} 55 | ], lr=_set_lr(0.5), betas=(0.9, 0.7)) 56 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0) 57 | warmup_scheduler = warmup.UntunedExponentialWarmup(optimizer) 58 | print() 59 | for step in range(1, 11): 60 | lr = [_get_lr(x['lr']) for x in optimizer.param_groups] 61 | print(f'{step} {lr}') 62 | self.assertAlmostEqual(lr[0], 0.5 * (1 - math.exp(-step / 3))) 63 | self.assertAlmostEqual(lr[1], 0.1 * (1 - math.exp(-step / 3))) 64 | optimizer.zero_grad() 65 | optimizer.step() 66 | with warmup_scheduler.dampening(): 67 | lr_scheduler.step() 68 | 69 | _test_state_dict(self, warmup_scheduler, 70 | lambda: warmup.UntunedExponentialWarmup(optimizer)) 71 | 72 | _test_optimizer(self, warmup.UntunedExponentialWarmup) 73 | --------------------------------------------------------------------------------
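
To make the hard-coded constants in test_radam.py above easier to audit, here is a standalone numerical sketch. It is not part of the test suite; it re-implements the closed-form expressions shown in pytorch_warmup/radam.py (rho_inf_fn, rho_fn, get_offset, and RAdamWarmup.warmup_factor) instead of importing them, and it reproduces the first few exact warmup factors in the ewf table for beta2 = 0.7, i.e. rho_inf = 17/3 and offset = 6:

import math

beta2 = 0.7
rho_inf = 2.0 / (1.0 - beta2) - 1.0  # = 17/3, as in rho_inf_fn

def rho(t):
    # rho_t, as in rho_fn
    b2t = beta2 ** t
    return rho_inf - 2.0 * t * b2t / (1.0 - b2t)

# Minimal offset delta with rho(delta) > 4, as in get_offset.
offset = 1
while not rho(offset) > 4:
    offset += 1
assert offset == 6

def warmup_factor(step):
    # omega at step + offset, as in RAdamWarmup.warmup_factor
    r = rho(step + offset)
    return math.sqrt((r - 4) * (r - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * r))

# The factor applied at the t-th displayed step is omega_{t + offset - 1},
# i.e. warmup_factor(t - 1); it should match ewf[t] up to floating-point rounding.
for t in range(1, 4):
    print(t, warmup_factor(t - 1))  # ~0.17720, ~0.45595, ~0.60475

test_radam.py then asserts that the Adam learning rates 0.5 and 0.1 are dampened by exactly these factors at each step.
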