├── .coveragerc ├── .deepsource.toml ├── .github ├── dependabot.yml └── workflows │ ├── codeql-analysis.yml │ └── python-package.yml ├── .gitignore ├── CHANGES.rst ├── CITATION.cff ├── CONTRIBUTING.rst ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── docs ├── Makefile ├── api.rst ├── conf.py ├── contributing.rst ├── examples.rst ├── index.rst ├── make.bat ├── rastrigin_A2GradExp.png ├── rastrigin_A2GradInc.png ├── rastrigin_A2GradUni.png ├── rastrigin_AccSGD.png ├── rastrigin_AdaBelief.png ├── rastrigin_AdaBound.png ├── rastrigin_AdaMod.png ├── rastrigin_Adafactor.png ├── rastrigin_Adahessian.png ├── rastrigin_Adam.png ├── rastrigin_AdamP.png ├── rastrigin_AggMo.png ├── rastrigin_Apollo.png ├── rastrigin_DiffGrad.png ├── rastrigin_Lamb.png ├── rastrigin_LookaheadYogi.png ├── rastrigin_MADGRAD.png ├── rastrigin_NovoGrad.png ├── rastrigin_PID.png ├── rastrigin_QHAdam.png ├── rastrigin_QHM.png ├── rastrigin_RAdam.png ├── rastrigin_Ranger.png ├── rastrigin_RangerQH.png ├── rastrigin_RangerVA.png ├── rastrigin_SGD.png ├── rastrigin_SGDP.png ├── rastrigin_SGDW.png ├── rastrigin_SWATS.png ├── rastrigin_Shampoo.png ├── rastrigin_Yogi.png ├── rosenbrock_A2GradExp.png ├── rosenbrock_A2GradInc.png ├── rosenbrock_A2GradUni.png ├── rosenbrock_AccSGD.png ├── rosenbrock_AdaBelief.png ├── rosenbrock_AdaBound.png ├── rosenbrock_AdaMod.png ├── rosenbrock_Adafactor.png ├── rosenbrock_Adahessian.png ├── rosenbrock_Adam.png ├── rosenbrock_AdamP.png ├── rosenbrock_AggMo.png ├── rosenbrock_Apollo.png ├── rosenbrock_DiffGrad.png ├── rosenbrock_Lamb.png ├── rosenbrock_LookaheadYogi.png ├── rosenbrock_MADGRAD.png ├── rosenbrock_NovoGrad.png ├── rosenbrock_PID.png ├── rosenbrock_QHAdam.png ├── rosenbrock_QHM.png ├── rosenbrock_RAdam.png ├── rosenbrock_Ranger.png ├── rosenbrock_RangerQH.png ├── rosenbrock_RangerVA.png ├── rosenbrock_SGD.png ├── rosenbrock_SGDP.png ├── rosenbrock_SGDW.png ├── rosenbrock_SWATS.png ├── rosenbrock_Shampoo.png └── rosenbrock_Yogi.png ├── examples ├── mnist.py ├── requirements-examples.txt └── viz_optimizers.py ├── requirements-dev.txt ├── setup.py ├── tests ├── conftest.py ├── test_basic.py ├── test_optimizer.py ├── test_optimizer_with_nn.py └── test_param_validation.py └── torch_optimizer ├── __init__.py ├── a2grad.py ├── accsgd.py ├── adabelief.py ├── adabound.py ├── adafactor.py ├── adahessian.py ├── adamod.py ├── adamp.py ├── aggmo.py ├── apollo.py ├── diffgrad.py ├── lamb.py ├── lars.py ├── lion.py ├── lookahead.py ├── madgrad.py ├── novograd.py ├── pid.py ├── py.typed ├── qhadam.py ├── qhm.py ├── radam.py ├── sgdp.py ├── sgdw.py ├── shampoo.py ├── swats.py ├── types.py └── yogi.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = torch_optimizer 4 | omit = site-packages, .tox 5 | -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | test_patterns = ["tests/test_*.py"] 4 | 5 | exclude_patterns = [ 6 | "tests/test_optimizer.py", 7 | "docs/conf.py", 8 | "tests/utils.py", 9 | "examples/*" 10 | ] 11 | 12 | [[analyzers]] 13 | name = "python" 14 | enabled = true 15 | 16 | [analyzers.meta] 17 | runtime_version = "3.x.x" 18 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - 
package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: weekly 7 | day: "monday" 8 | time: "10:00" 9 | open-pull-requests-limit: 10 10 | ignore: 11 | - dependency-name: numpy 12 | versions: 13 | - 1.20.0 14 | - 1.20.1 15 | - dependency-name: isort 16 | versions: 17 | - 5.7.0 18 | - dependency-name: ipython 19 | versions: 20 | - 7.19.0 21 | - 7.20.0 22 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '39 10 * * 5' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | 28 | strategy: 29 | fail-fast: false 30 | matrix: 31 | language: [ 'python' ] 32 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 33 | # Learn more: 34 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 35 | 36 | steps: 37 | - name: Checkout repository 38 | uses: actions/checkout@v2 39 | 40 | # Initializes the CodeQL tools for scanning. 41 | - name: Initialize CodeQL 42 | uses: github/codeql-action/init@v1 43 | with: 44 | languages: ${{ matrix.language }} 45 | # If you wish to specify custom queries, you can do so here or in a config file. 46 | # By default, queries listed here will override any specified in a config file. 47 | # Prefix the list here with "+" to use these queries and those in the config file. 48 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 49 | 50 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 51 | # If this step fails, then you should remove it and run the build manually (see below) 52 | - name: Autobuild 53 | uses: github/codeql-action/autobuild@v1 54 | 55 | # ℹ️ Command-line programs to run using the OS shell. 
56 | # 📚 https://git.io/JvXDl 57 | 58 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 59 | # and modify them (or add more) to build your code if your project 60 | # uses a compiled language 61 | 62 | #- run: | 63 | # make bootstrap 64 | # make release 65 | 66 | - name: Perform CodeQL Analysis 67 | uses: github/codeql-action/analyze@v1 68 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: 9 | - master 10 | tags: [ 'v*' ] 11 | pull_request: 12 | branches: 13 | - master 14 | schedule: 15 | - cron: '0 6 * * 1' # Weekly Mon 6AM UTC build 16 | 17 | 18 | jobs: 19 | test: 20 | 21 | runs-on: ubuntu-latest 22 | strategy: 23 | matrix: 24 | python-version: ['3.7', '3.8', '3.9'] 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Python ${{ matrix.python-version }} 29 | uses: actions/setup-python@v2 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | python -m pip install codecov 36 | pip install -r requirements-dev.txt 37 | - name: Lint 38 | run: | 39 | make lint 40 | make checkbuild 41 | 42 | - name: Test 43 | run: | 44 | make cov 45 | codecov 46 | 47 | deploy: 48 | name: Deploy 49 | runs-on: ubuntu-latest 50 | needs: test 51 | # Run only on pushing a tag 52 | if: github.event_name == 'push' && contains(github.ref, 'refs/tags/') 53 | steps: 54 | - name: Checkout 55 | uses: actions/checkout@v2 56 | - name: Setup Python 3.8 57 | uses: actions/setup-python@v2 58 | with: 59 | python-version: 3.8 60 | - name: Install dependencies 61 | run: 62 | python -m pip install -U pip wheel twine 63 | - name: Make dists 64 | run: 65 | python setup.py sdist bdist_wheel 66 | - name: PyPI upload 67 | env: 68 | TWINE_USERNAME: __token__ 69 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 70 | run: | 71 | twine upload dist/* 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ># Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | 25 | # PyInstaller 26 | # Usually these files are written by a python script from a template 27 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
28 | *.manifest 29 | *.spec 30 | 31 | # Installer logs 32 | pip-log.txt 33 | pip-delete-this-directory.txt 34 | 35 | # Unit test / coverage reports 36 | htmlcov/ 37 | .tox/ 38 | .coverage 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | cover 43 | 44 | # Translations 45 | *.mo 46 | *.pot 47 | 48 | # Django stuff: 49 | *.log 50 | 51 | # Sphinx documentation 52 | docs/_build/ 53 | 54 | # PyBuilder 55 | target/ 56 | 57 | # PyCharm 58 | .idea 59 | 60 | .coverage.* 61 | coverage 62 | .mypy_cache/ 63 | .DS_Store 64 | tags 65 | cscope.* 66 | TODO 67 | -------------------------------------------------------------------------------- /CHANGES.rst: -------------------------------------------------------------------------------- 1 | Changes 2 | ------- 3 | 4 | 0.3.1 (YYYY-MM-DD) 5 | ------------------ 6 | * Deprecate RAdam optimizer. 7 | 8 | 0.3.0 (2021-10-30) 9 | ------------------ 10 | * Revert "Drop RAdam optimizer". 11 | 12 | 0.2.0 (2021-10-25) 13 | ------------------ 14 | * Drop RAdam optimizer since it is included in PyTorch. 15 | * Do not include tests in the installable package. 16 | * Preserve memory layout where possible. 17 | * Add MADGRAD optimizer. 18 | 19 | 0.1.0 (2021-01-01) 20 | ------------------ 21 | * Initial release. 22 | * Added support for A2GradExp, A2GradInc, A2GradUni, AccSGD, AdaBelief, 23 | AdaBound, AdaMod, Adafactor, Adahessian, AdamP, AggMo, Apollo, 24 | DiffGrad, Lamb, Lookahead, NovoGrad, PID, QHAdam, QHM, RAdam, Ranger, 25 | RangerQH, RangerVA, SGDP, SGDW, SWATS, Shampoo, Yogi. 26 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Novik 5 | given-names: Mykola 6 | orcid: https://orcid.org/0000-0002-0890-1159 7 | title: "torch-optimizer -- collection of optimization algorithms for PyTorch." 8 | version: 1.0.1 9 | date-released: 2020-01-11 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | Running Tests 5 | ------------- 6 | 7 | .. _GitHub: https://github.com/jettify/pytorch-optimizer 8 | .. _PyTorch: https://github.com/pytorch/pytorch 9 | 10 | Thanks for your interest in contributing to ``pytorch-optimizer``; there are multiple 11 | ways and places you can contribute. 12 | 13 | First of all, just clone the repository:: 14 | 15 | $ git clone git@github.com:jettify/pytorch-optimizer.git 16 | 17 | Create a virtualenv with python3.7 (older versions are not supported). For example, 18 | using *virtualenvwrapper* the commands could look like:: 19 | 20 | $ cd pytorch-optimizer 21 | $ mkvirtualenv --python=`which python3.7` pytorch-optimizer 22 | 23 | 24 | After that, please install the libraries required for development:: 25 | 26 | $ pip install -r requirements-dev.txt 27 | $ pip install -e . 28 | 29 | Congratulations, you are ready to run the test suite:: 30 | 31 | $ make cov 32 | 33 | To run an individual test, use the following command:: 34 | 35 | $ py.test -sv tests/test_basic.py -k test_name 36 | 37 | 38 | Reporting an Issue 39 | ------------------ 40 | If you have found an issue with `pytorch-optimizer`, please do 41 | not hesitate to file an issue on the GitHub_ project. When filing your 42 | issue, please make sure you can express the issue with a reproducible test 43 | case.
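For example, a minimal reproducible test case could be a short, self-contained
script along the lines of the sketch below (the optimizer, tensor shapes and
settings here are placeholders; substitute whatever triggers the problem)::

    import torch

    import torch_optimizer as optim

    # Placeholder setup: swap in the optimizer and settings that misbehave.
    x = torch.randn(4, 3, requires_grad=True)
    optimizer = optim.DiffGrad([x], lr=0.001)

    loss = (x ** 2).sum()
    loss.backward()
    optimizer.step()  # note here what you expected vs. what actually happened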
44 | 45 | When reporting an issue, please include as much information about your environment 46 | as you can. We never know what information will be pertinent when 47 | trying to narrow down the issue. Please include at least the following 48 | information: 49 | 50 | * Version of `pytorch-optimizer`, `python`. 51 | * Version of PyTorch_ if installed. 52 | * Version of CUDA if installed. 53 | * Platform you're running on (OS X, Linux). 54 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include CHANGES.rst 3 | include README.rst 4 | include Makefile 5 | graft torch_optimizer 6 | graft tests 7 | prune docs/_build 8 | include examples/mnist.py 9 | global-exclude *.py[cod] 10 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Some simple testing tasks (sorry, UNIX only). 2 | 3 | FILES := torch_optimizer tests examples setup.py 4 | 5 | 6 | 7 | flake: 8 | flake8 ${FILES} 9 | 10 | test: flake 11 | pytest -sv 12 | 13 | vtest: 14 | pytest -sv -vv 15 | 16 | checkrst: 17 | python setup.py check --restructuredtext 18 | 19 | pyroma: 20 | pyroma -d . 21 | 22 | bandit: 23 | bandit -r ./torch_optimizer 24 | 25 | mypy: 26 | mypy torch_optimizer --ignore-missing-imports 27 | 28 | checkbuild: 29 | python setup.py sdist bdist_wheel 30 | twine check dist/* 31 | 32 | cov cover coverage: 33 | pytest -sv -vv --cov=torch_optimizer --cov-report=term --cov-report=html ./tests 34 | @echo "open file://`pwd`/htmlcov/index.html" 35 | 36 | checkfmt: 37 | isort --profile black --check-only --diff $(FILES) 38 | black -l 79 --check $(FILES) 39 | 40 | lint: flake checkrst pyroma bandit checkfmt 41 | 42 | clean: 43 | rm -rf `find . -name __pycache__` 44 | rm -f `find . -type f -name '*.py[co]' ` 45 | rm -f `find . -type f -name '*~' ` 46 | rm -f `find . -type f -name '.*~' ` 47 | rm -f `find . -type f -name '@*' ` 48 | rm -f `find . -type f -name '#*#' ` 49 | rm -f `find . -type f -name '*.orig' ` 50 | rm -f `find . -type f -name '*.rej' ` 51 | rm -f .coverage 52 | rm -rf coverage 53 | rm -rf build 54 | rm -rf cover 55 | rm -rf dist 56 | rm -rf docs/_build 57 | 58 | doc: 59 | make -C docs html 60 | @echo "open file://`pwd`/docs/_build/html/index.html" 61 | 62 | black: 63 | black -l 79 setup.py torch_optimizer/ tests/ examples/ 64 | 65 | fmt: 66 | isort --profile black ${FILES} 67 | black -l 79 ${FILES} 68 | 69 | 70 | .PHONY: all flake test vtest cov clean doc 71 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | Available Optimizers 2 | ==================== 3 | 4 | .. _AccSGD: 5 | 6 | AccSGD 7 | ------ 8 | 9 | .. autoclass:: torch_optimizer.AccSGD 10 | :members: 11 | 12 | .. _AdaBound: 13 | 14 | AdaBound 15 | -------- 16 | 17 | .. autoclass:: torch_optimizer.AdaBound 18 | :members: 19 | 20 | .. _AdaMod: 21 | 22 | AdaMod 23 | ------ 24 | 25 | .. autoclass:: torch_optimizer.AdaMod 26 | :members: 27 | 28 | .. _Adafactor: 29 | 30 | Adafactor 31 | --------- 32 | 33 | .. autoclass:: torch_optimizer.Adafactor 34 | :members: 35 | 36 | .. _AdamP: 37 | 38 | AdamP 39 | ------ 40 | 41 | .. autoclass:: torch_optimizer.AdamP 42 | :members: 43 | 44 | .. _AggMo: 45 | 46 | AggMo 47 | ----- 48 | 49 | .. autoclass:: torch_optimizer.AggMo 50 | :members: 51 | 52 | .. _DiffGrad: 53 | 54 | DiffGrad 55 | -------- 56 | 57 | .. autoclass:: torch_optimizer.DiffGrad 58 | :members: 59 | 60 | .. _Lamb: 61 | 62 | Lamb 63 | ---- 64 | 65 | .. autoclass:: torch_optimizer.Lamb 66 | :members: 67 | 68 | .. _NovoGrad: 69 | 70 | NovoGrad 71 | -------- 72 | 73 | .. autoclass:: torch_optimizer.NovoGrad 74 | :members: 75 | 76 | .. _PID: 77 | 78 | PID 79 | --- 80 | 81 | .. autoclass:: torch_optimizer.PID 82 | :members: 83 | 84 | .. _QHAdam: 85 | 86 | QHAdam 87 | ------ 88 | 89 | .. autoclass:: torch_optimizer.QHAdam 90 | :members: 91 | 92 | .. _QHM: 93 | 94 | QHM 95 | --- 96 | 97 | .. autoclass:: torch_optimizer.QHM 98 | :members: 99 | 100 | .. _RAdam: 101 | 102 | RAdam 103 | ----- 104 | 105 | .. autoclass:: torch_optimizer.RAdam 106 | :members: 107 | 108 | .. _SGDP: 109 | 110 | SGDP 111 | ---- 112 | 113 | .. autoclass:: torch_optimizer.SGDP 114 | :members: 115 | 116 | .. _SGDW: 117 | 118 | SGDW 119 | ---- 120 | 121 | .. autoclass:: torch_optimizer.SGDW 122 | :members: 123 | 124 | .. _Shampoo: 125 | 126 | Shampoo 127 | ------- 128 | 129 | .. autoclass:: torch_optimizer.Shampoo 130 | :members: 131 | 132 | .. _SWATS: 133 | 134 | SWATS 135 | ----- 136 | 137 | .. autoclass:: torch_optimizer.SWATS 138 | :members: 139 | 140 | .. _Yogi: 141 | 142 | Yogi 143 | ---- 144 | 145 | .. autoclass:: torch_optimizer.Yogi 146 | :members: 147 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'pytorch-optimizer' 21 | copyright = '2020, Nikolai Novik' 22 | author = 'Nikolai Novik' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | 31 | # Sphinx extension modules 32 | extensions = [ 33 | "sphinx.ext.autodoc", 34 | "sphinx.ext.napoleon", 35 | "sphinx_autodoc_typehints", 36 | "sphinx.ext.doctest", 37 | "sphinx.ext.todo", 38 | "sphinx.ext.coverage", 39 | "sphinx.ext.mathjax", 40 | "sphinx.ext.ifconfig", 41 | "sphinx.ext.viewcode", 42 | "sphinx.ext.intersphinx", 43 | ] 44 | 45 | # Add any paths that contain templates here, relative to this directory. 46 | templates_path = ['_templates'] 47 | 48 | # List of patterns, relative to source directory, that match files and 49 | # directories to ignore when looking for source files. 50 | # This pattern also affects html_static_path and html_extra_path. 51 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 52 | 53 | # Configuration for intersphinx: refer to the Python standard library and PyTorch 54 | intersphinx_mapping = { 55 | "python": ("https://docs.python.org/3", None), 56 | "pytorch": ("https://pytorch.org/docs/stable", None), 57 | } 58 | 59 | 60 | # -- Options for HTML output ------------------------------------------------- 61 | 62 | # The theme to use for HTML and HTML Help pages. See the documentation for 63 | # a list of builtin themes. 64 | # 65 | html_theme = 'alabaster' 66 | 67 | # Add any paths that contain custom static files (such as style sheets) here, 68 | # relative to this directory. They are copied after the builtin static files, 69 | # so a file named "default.css" will overwrite the builtin "default.css". 70 | html_static_path = ['_static'] 71 | 72 | desc = 'collection of optimizers for PyTorch' 73 | html_theme_options = { 74 | 'description': desc, 75 | 'github_user': 'jettify', 76 | 'github_repo': 'pytorch-optimizer', 77 | 'github_button': True, 78 | 'github_type': 'star', 79 | 'github_banner': True, 80 | } 81 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Examples of pytorch-optimizer usage 2 | =================================== 3 | 4 | Below is a list of examples from `pytorch-optimizer/examples 5 | `_ 6 | 7 | Every example is a correct tiny python program. 8 | 9 | .. _pytorch-optimizer-examples-simple: 10 | 11 | 12 | Basic Usage 13 | ----------- 14 | 15 | Simple example that shows how to use library with MNIST dataset. 16 | 17 | .. literalinclude:: ../examples/mnist.py 18 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. pytorch-optimizer documentation master file, created by 2 | sphinx-quickstart on Thu Feb 13 21:14:16 2020. 
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to pytorch-optimizer's documentation! 7 | ============================================= 8 | 9 | **torch-optimizer** -- collection of optimizers for PyTorch_. 10 | 11 | Simple example 12 | -------------- 13 | 14 | .. code:: python 15 | 16 | import torch_optimizer as optim 17 | 18 | # model = ... 19 | optimizer = optim.DiffGrad(model.parameters(), lr=0.001) 20 | optimizer.step() 21 | 22 | 23 | Installation 24 | ------------ 25 | Installation process is simple, just:: 26 | 27 | $ pip install torch_optimizer 28 | 29 | 30 | Citation 31 | -------- 32 | Please cite original authors of optimization algorithms. If you like this 33 | package:: 34 | 35 | @software{Novik_torchoptimizers, 36 | title = {{torch-optimizer -- collection of optimization algorithms for PyTorch.}}, 37 | author = {Novik, Mykola}, 38 | year = 2020, 39 | month = 1, 40 | version = {1.0.1} 41 | } 42 | 43 | Or use github feature: "cite this repository" button. 44 | 45 | 46 | Supported Optimizers 47 | ==================== 48 | 49 | +-----------------+-------------------------------------------------------------------------------+ 50 | | | | 51 | | :ref:`AccSGD` | https://arxiv.org/abs/1803.05591 | 52 | +-----------------+-------------------------------------------------------------------------------+ 53 | | | | 54 | | :ref:`AdaBound` | https://arxiv.org/abs/1902.09843 | 55 | +-----------------+-------------------------------------------------------------------------------+ 56 | | | | 57 | | :ref:`AdaMod` | https://arxiv.org/abs/1910.12249 | 58 | +-----------------+-------------------------------------------------------------------------------+ 59 | | | | 60 | | :ref:`Adafactor`| https://arxiv.org/abs/1804.04235 | 61 | +-----------------+-------------------------------------------------------------------------------+ 62 | | | | 63 | | :ref:`AdamP` | https://arxiv.org/abs/1804.00325 | 64 | +-----------------+-------------------------------------------------------------------------------+ 65 | | | | 66 | | :ref:`AggMo` | https://arxiv.org/abs/2006.08217 | 67 | +-----------------+-------------------------------------------------------------------------------+ 68 | | | | 69 | | :ref:`DiffGrad` | https://arxiv.org/abs/1909.11015 | 70 | +-----------------+-------------------------------------------------------------------------------+ 71 | | | | 72 | | :ref:`Lamb` | https://arxiv.org/abs/1904.00962 | 73 | +-----------------+-------------------------------------------------------------------------------+ 74 | | | | 75 | | :ref:`NovoGrad` | https://arxiv.org/abs/1905.11286 | 76 | +-----------------+-------------------------------------------------------------------------------+ 77 | | | | 78 | | :ref:`PID` | https://www4.comp.polyu.edu.hk/~cslzhang/paper/CVPR18_PID.pdf | 79 | +-----------------+-------------------------------------------------------------------------------+ 80 | | | | 81 | | :ref:`QHAdam` | https://arxiv.org/abs/1810.06801 | 82 | +-----------------+-------------------------------------------------------------------------------+ 83 | | | | 84 | | :ref:`QHM` | https://arxiv.org/abs/1810.06801 | 85 | +-----------------+-------------------------------------------------------------------------------+ 86 | | | | 87 | | :ref:`RAdam` | https://arxiv.org/abs/1908.03265 | 88 | +-----------------+-------------------------------------------------------------------------------+ 89 | | | | 90 | | 
:ref:`Ranger` | https://arxiv.org/abs/1908.00700v2 | 91 | +-----------------+-------------------------------------------------------------------------------+ 92 | | | | 93 | | :ref:`RangerQH` | https://arxiv.org/abs/1908.00700v2 | 94 | +-----------------+-------------------------------------------------------------------------------+ 95 | | | | 96 | | :ref:`RangerVA` | https://arxiv.org/abs/1908.00700v2 | 97 | +-----------------+-------------------------------------------------------------------------------+ 98 | | | | 99 | | :ref:`SGDP` | https://arxiv.org/abs/2006.08217 | 100 | +-----------------+-------------------------------------------------------------------------------+ 101 | | | | 102 | | :ref:`SGDW` | https://arxiv.org/abs/1608.03983 | 103 | +-----------------+-------------------------------------------------------------------------------+ 104 | | | | 105 | | :ref:`Shampoo` | https://arxiv.org/abs/1802.09568 | 106 | +-----------------+-------------------------------------------------------------------------------+ 107 | | | | 108 | | :ref:`SWATS` | https://arxiv.org/abs/1712.07628 | 109 | +-----------------+-------------------------------------------------------------------------------+ 110 | | | | 111 | | :ref:`Yogi` | https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization | 112 | +-----------------+-------------------------------------------------------------------------------+ 113 | 114 | .. toctree:: 115 | :maxdepth: 2 116 | :caption: Contents: 117 | 118 | Contents 119 | -------- 120 | 121 | .. toctree:: 122 | :maxdepth: 2 123 | 124 | api 125 | examples 126 | contributing 127 | 128 | 129 | Indices and tables 130 | ================== 131 | 132 | * :ref:`genindex` 133 | * :ref:`modindex` 134 | * :ref:`search` 135 | 136 | .. _Python: https://www.python.org 137 | .. _PyTorch: https://github.com/pytorch/pytorch 138 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/rastrigin_A2GradExp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_A2GradExp.png -------------------------------------------------------------------------------- /docs/rastrigin_A2GradInc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_A2GradInc.png -------------------------------------------------------------------------------- /docs/rastrigin_A2GradUni.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_A2GradUni.png -------------------------------------------------------------------------------- /docs/rastrigin_AccSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_AccSGD.png -------------------------------------------------------------------------------- /docs/rastrigin_AdaBelief.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_AdaBelief.png -------------------------------------------------------------------------------- /docs/rastrigin_AdaBound.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_AdaBound.png -------------------------------------------------------------------------------- /docs/rastrigin_AdaMod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_AdaMod.png -------------------------------------------------------------------------------- /docs/rastrigin_Adafactor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_Adafactor.png -------------------------------------------------------------------------------- /docs/rastrigin_Adahessian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_Adahessian.png -------------------------------------------------------------------------------- /docs/rastrigin_Adam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_Adam.png 
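The ``Simple example`` in ``docs/index.rst`` above calls ``optimizer.step()`` on its
own; in a real training loop, gradients have to be produced by a backward pass first.
A minimal end-to-end sketch, assuming a throwaway linear model and random data used
purely for illustration::

    import torch
    import torch.nn.functional as F

    import torch_optimizer as optim

    model = torch.nn.Linear(10, 1)  # placeholder model
    data = torch.randn(16, 10)      # placeholder batch
    target = torch.randn(16, 1)

    optimizer = optim.DiffGrad(model.parameters(), lr=0.001)

    optimizer.zero_grad()
    loss = F.mse_loss(model(data), target)
    loss.backward()  # populates the .grad buffers consumed by step()
    optimizer.step()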
-------------------------------------------------------------------------------- /docs/rastrigin_AdamP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_AdamP.png -------------------------------------------------------------------------------- /docs/rastrigin_AggMo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_AggMo.png -------------------------------------------------------------------------------- /docs/rastrigin_Apollo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_Apollo.png -------------------------------------------------------------------------------- /docs/rastrigin_DiffGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_DiffGrad.png -------------------------------------------------------------------------------- /docs/rastrigin_Lamb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_Lamb.png -------------------------------------------------------------------------------- /docs/rastrigin_LookaheadYogi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_LookaheadYogi.png -------------------------------------------------------------------------------- /docs/rastrigin_MADGRAD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_MADGRAD.png -------------------------------------------------------------------------------- /docs/rastrigin_NovoGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_NovoGrad.png -------------------------------------------------------------------------------- /docs/rastrigin_PID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_PID.png -------------------------------------------------------------------------------- /docs/rastrigin_QHAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_QHAdam.png -------------------------------------------------------------------------------- /docs/rastrigin_QHM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_QHM.png 
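The ``rastrigin_LookaheadYogi.png`` plot listed above comes from composing two
optimizers: ``Lookahead`` is a wrapper that takes an already constructed base
optimizer (see the ``LookaheadYogi`` helper in ``examples/viz_optimizers.py``
below). A small sketch with a placeholder model::

    import torch

    import torch_optimizer as optim

    model = torch.nn.Linear(2, 1)  # placeholder model
    base = optim.Yogi(model.parameters(), lr=1e-2)
    optimizer = optim.Lookahead(base)  # k and alpha keep their library defaults

The wrapped optimizer is then driven exactly like any other
(``zero_grad``/``backward``/``step``).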
-------------------------------------------------------------------------------- /docs/rastrigin_RAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_RAdam.png -------------------------------------------------------------------------------- /docs/rastrigin_Ranger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_Ranger.png -------------------------------------------------------------------------------- /docs/rastrigin_RangerQH.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_RangerQH.png -------------------------------------------------------------------------------- /docs/rastrigin_RangerVA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_RangerVA.png -------------------------------------------------------------------------------- /docs/rastrigin_SGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_SGD.png -------------------------------------------------------------------------------- /docs/rastrigin_SGDP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_SGDP.png -------------------------------------------------------------------------------- /docs/rastrigin_SGDW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_SGDW.png -------------------------------------------------------------------------------- /docs/rastrigin_SWATS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_SWATS.png -------------------------------------------------------------------------------- /docs/rastrigin_Shampoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_Shampoo.png -------------------------------------------------------------------------------- /docs/rastrigin_Yogi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rastrigin_Yogi.png -------------------------------------------------------------------------------- /docs/rosenbrock_A2GradExp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_A2GradExp.png 
-------------------------------------------------------------------------------- /docs/rosenbrock_A2GradInc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_A2GradInc.png -------------------------------------------------------------------------------- /docs/rosenbrock_A2GradUni.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_A2GradUni.png -------------------------------------------------------------------------------- /docs/rosenbrock_AccSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_AccSGD.png -------------------------------------------------------------------------------- /docs/rosenbrock_AdaBelief.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_AdaBelief.png -------------------------------------------------------------------------------- /docs/rosenbrock_AdaBound.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_AdaBound.png -------------------------------------------------------------------------------- /docs/rosenbrock_AdaMod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_AdaMod.png -------------------------------------------------------------------------------- /docs/rosenbrock_Adafactor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_Adafactor.png -------------------------------------------------------------------------------- /docs/rosenbrock_Adahessian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_Adahessian.png -------------------------------------------------------------------------------- /docs/rosenbrock_Adam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_Adam.png -------------------------------------------------------------------------------- /docs/rosenbrock_AdamP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_AdamP.png -------------------------------------------------------------------------------- /docs/rosenbrock_AggMo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_AggMo.png 
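``Adahessian`` (plotted above) is a second-order method: as
``examples/viz_optimizers.py`` below does for its experiments, gradients have to be
taken with ``create_graph=True`` so the Hessian-diagonal estimate can be formed.
A minimal sketch on a toy quadratic (the tensor values and learning rate are
arbitrary)::

    import torch

    import torch_optimizer as optim

    x = torch.tensor([1.5, -0.5], requires_grad=True)
    optimizer = optim.Adahessian([x], lr=0.1)

    loss = (x ** 2).sum()
    loss.backward(create_graph=True)  # second-order information needs the graph
    optimizer.step()
    optimizer.zero_grad()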
-------------------------------------------------------------------------------- /docs/rosenbrock_Apollo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_Apollo.png -------------------------------------------------------------------------------- /docs/rosenbrock_DiffGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_DiffGrad.png -------------------------------------------------------------------------------- /docs/rosenbrock_Lamb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_Lamb.png -------------------------------------------------------------------------------- /docs/rosenbrock_LookaheadYogi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_LookaheadYogi.png -------------------------------------------------------------------------------- /docs/rosenbrock_MADGRAD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_MADGRAD.png -------------------------------------------------------------------------------- /docs/rosenbrock_NovoGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_NovoGrad.png -------------------------------------------------------------------------------- /docs/rosenbrock_PID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_PID.png -------------------------------------------------------------------------------- /docs/rosenbrock_QHAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_QHAdam.png -------------------------------------------------------------------------------- /docs/rosenbrock_QHM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_QHM.png -------------------------------------------------------------------------------- /docs/rosenbrock_RAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_RAdam.png -------------------------------------------------------------------------------- /docs/rosenbrock_Ranger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_Ranger.png 
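``Ranger``, ``RangerQH`` and ``RangerVA`` (plotted here) have no implementation
files under ``torch_optimizer/`` in the tree at the top of this listing; they
appear to be re-exported from the ``pytorch_ranger`` dependency declared in
``setup.py`` and ``requirements-dev.txt``, and are used like any other optimizer
in the collection. A sketch with a placeholder model::

    import torch

    import torch_optimizer as optim

    model = torch.nn.Linear(4, 2)  # placeholder model
    optimizer = optim.Ranger(model.parameters(), lr=1e-3)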
-------------------------------------------------------------------------------- /docs/rosenbrock_RangerQH.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_RangerQH.png -------------------------------------------------------------------------------- /docs/rosenbrock_RangerVA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_RangerVA.png -------------------------------------------------------------------------------- /docs/rosenbrock_SGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_SGD.png -------------------------------------------------------------------------------- /docs/rosenbrock_SGDP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_SGDP.png -------------------------------------------------------------------------------- /docs/rosenbrock_SGDW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_SGDW.png -------------------------------------------------------------------------------- /docs/rosenbrock_SWATS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_SWATS.png -------------------------------------------------------------------------------- /docs/rosenbrock_Shampoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_Shampoo.png -------------------------------------------------------------------------------- /docs/rosenbrock_Yogi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/docs/rosenbrock_Yogi.png -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim.lr_scheduler import StepLR 5 | from torch.utils.tensorboard import SummaryWriter 6 | from torchvision import datasets, transforms, utils 7 | 8 | import torch_optimizer as optim 9 | 10 | 11 | class Net(nn.Module): 12 | def __init__(self): 13 | super(Net, self).__init__() 14 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 15 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 16 | self.dropout1 = nn.Dropout2d(0.25) 17 | self.dropout2 = nn.Dropout2d(0.5) 18 | self.fc1 = nn.Linear(9216, 128) 19 | self.fc2 = nn.Linear(128, 10) 20 | 21 | def forward(self, x): 22 | x = self.conv1(x) 23 | x = F.relu(x) 24 | x = self.conv2(x) 25 | x = F.max_pool2d(x, 2) 26 | x = self.dropout1(x) 27 | x = torch.flatten(x, 1) 28 | x = self.fc1(x) 29 | x = F.relu(x) 30 | 
x = self.dropout2(x) 31 | x = self.fc2(x) 32 | output = F.log_softmax(x, dim=1) 33 | return output 34 | 35 | 36 | def train(conf, model, device, train_loader, optimizer, epoch, writer): 37 | model.train() 38 | for batch_idx, (data, target) in enumerate(train_loader): 39 | data, target = data.to(device), target.to(device) 40 | optimizer.zero_grad() 41 | output = model(data) 42 | loss = F.nll_loss(output, target) 43 | loss.backward() 44 | optimizer.step() 45 | if batch_idx % conf.log_interval == 0: 46 | loss = loss.item() 47 | idx = batch_idx + epoch * (len(train_loader)) 48 | writer.add_scalar("Loss/train", loss, idx) 49 | print( 50 | "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 51 | epoch, 52 | batch_idx * len(data), 53 | len(train_loader.dataset), 54 | 100.0 * batch_idx / len(train_loader), 55 | loss, 56 | ) 57 | ) 58 | 59 | 60 | def test(conf, model, device, test_loader, epoch, writer): 61 | model.eval() 62 | test_loss = 0 63 | correct = 0 64 | with torch.no_grad(): 65 | for data, target in test_loader: 66 | data, target = data.to(device), target.to(device) 67 | output = model(data) 68 | # sum up batch loss 69 | test_loss += F.nll_loss(output, target, reduction="sum").item() 70 | # get the index of the max log-probability 71 | pred = output.argmax(dim=1, keepdim=True) 72 | correct += pred.eq(target.view_as(pred)).sum().item() 73 | 74 | test_loss /= len(test_loader.dataset) 75 | fmt = "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n" 76 | print( 77 | fmt.format( 78 | test_loss, 79 | correct, 80 | len(test_loader.dataset), 81 | 100.0 * correct / len(test_loader.dataset), 82 | ) 83 | ) 84 | 85 | writer.add_scalar("Accuracy", correct, epoch) 86 | writer.add_scalar("Loss/test", test_loss, epoch) 87 | 88 | 89 | def prepare_loaders(conf, use_cuda=False): 90 | kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} 91 | train_loader = torch.utils.data.DataLoader( 92 | datasets.MNIST( 93 | "../data", 94 | train=True, 95 | download=True, 96 | transform=transforms.Compose( 97 | [ 98 | transforms.ToTensor(), 99 | transforms.Normalize((0.1307,), (0.3081,)), 100 | ] 101 | ), 102 | ), 103 | batch_size=conf.batch_size, 104 | shuffle=True, 105 | **kwargs, 106 | ) 107 | 108 | test_loader = torch.utils.data.DataLoader( 109 | datasets.MNIST( 110 | "../data", 111 | train=False, 112 | transform=transforms.Compose( 113 | [ 114 | transforms.ToTensor(), 115 | transforms.Normalize((0.1307,), (0.3081,)), 116 | ] 117 | ), 118 | ), 119 | batch_size=conf.test_batch_size, 120 | shuffle=True, 121 | **kwargs, 122 | ) 123 | return train_loader, test_loader 124 | 125 | 126 | class Config: 127 | def __init__( 128 | self, 129 | batch_size: int = 64, 130 | test_batch_size: int = 1000, 131 | epochs: int = 15, 132 | lr: float = 0.01, 133 | gamma: float = 0.7, 134 | no_cuda: bool = True, 135 | seed: int = 42, 136 | log_interval: int = 10, 137 | ): 138 | self.batch_size = batch_size 139 | self.test_batch_size = test_batch_size 140 | self.epochs = epochs 141 | self.lr = lr 142 | self.gamma = gamma 143 | self.no_cuda = no_cuda 144 | self.seed = seed 145 | self.log_interval = log_interval 146 | 147 | 148 | def main(): 149 | conf = Config() 150 | log_dir = "runs/mnist_custom_optim" 151 | print("Tensorboard: tensorboard --logdir={}".format(log_dir)) 152 | 153 | with SummaryWriter(log_dir) as writer: 154 | use_cuda = not conf.no_cuda and torch.cuda.is_available() 155 | torch.manual_seed(conf.seed) 156 | device = torch.device("cuda" if use_cuda else "cpu") 157 | train_loader, test_loader = 
prepare_loaders(conf, use_cuda) 158 | 159 | model = Net().to(device) 160 | 161 | # create grid of images and write to tensorboard 162 | images, labels = next(iter(train_loader)) 163 | img_grid = utils.make_grid(images) 164 | writer.add_image("mnist_images", img_grid) 165 | 166 | # custom optimizer from torch_optimizer package 167 | optimizer = optim.DiffGrad(model.parameters(), lr=conf.lr) 168 | 169 | scheduler = StepLR(optimizer, step_size=1, gamma=conf.gamma) 170 | for epoch in range(1, conf.epochs + 1): 171 | train(conf, model, device, train_loader, optimizer, epoch, writer) 172 | test(conf, model, device, test_loader, epoch, writer) 173 | scheduler.step() 174 | for name, param in model.named_parameters(): 175 | writer.add_histogram(name, param, epoch) 176 | writer.add_histogram("{}.grad".format(name), param.grad, epoch) 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /examples/requirements-examples.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | hyperopt==0.2.5 3 | torchvision==0.11.1 4 | matplotlib==3.4.3 5 | -------------------------------------------------------------------------------- /examples/viz_optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | from hyperopt import fmin, hp, tpe 7 | 8 | import torch_optimizer as optim 9 | 10 | plt.style.use("seaborn-white") 11 | 12 | 13 | def rosenbrock(tensor): 14 | # https://en.wikipedia.org/wiki/Test_functions_for_optimization 15 | x, y = tensor 16 | return (1 - x) ** 2 + 100 * (y - x**2) ** 2 17 | 18 | 19 | def rastrigin(tensor, lib=torch): 20 | # https://en.wikipedia.org/wiki/Test_functions_for_optimization 21 | x, y = tensor 22 | A = 10 23 | f = ( 24 | A * 2 25 | + (x**2 - A * lib.cos(x * math.pi * 2)) 26 | + (y**2 - A * lib.cos(y * math.pi * 2)) 27 | ) 28 | return f 29 | 30 | 31 | def execute_steps( 32 | func, initial_state, optimizer_class, optimizer_config, num_iter=500 33 | ): 34 | x = torch.Tensor(initial_state).requires_grad_(True) 35 | optimizer = optimizer_class([x], **optimizer_config) 36 | steps = [] 37 | steps = np.zeros((2, num_iter + 1)) 38 | steps[:, 0] = np.array(initial_state) 39 | for i in range(1, num_iter + 1): 40 | optimizer.zero_grad() 41 | f = func(x) 42 | f.backward(create_graph=True, retain_graph=True) 43 | torch.nn.utils.clip_grad_norm_(x, 1.0) 44 | optimizer.step() 45 | steps[:, i] = x.detach().numpy() 46 | return steps 47 | 48 | 49 | def objective_rastrigin(params): 50 | lr = params["lr"] 51 | optimizer_class = params["optimizer_class"] 52 | initial_state = (-2.0, 3.5) 53 | minimum = (0, 0) 54 | optimizer_config = dict(lr=lr) 55 | num_iter = 100 56 | steps = execute_steps( 57 | rastrigin, initial_state, optimizer_class, optimizer_config, num_iter 58 | ) 59 | return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2 60 | 61 | 62 | def objective_rosenbrok(params): 63 | lr = params["lr"] 64 | optimizer_class = params["optimizer_class"] 65 | minimum = (1.0, 1.0) 66 | initial_state = (-2.0, 2.0) 67 | optimizer_config = dict(lr=lr) 68 | num_iter = 100 69 | steps = execute_steps( 70 | rosenbrock, initial_state, optimizer_class, optimizer_config, num_iter 71 | ) 72 | return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2 73 | 74 | 75 | def plot_rastrigin(grad_iter, optimizer_name, lr): 
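    # Contour-plot the Rastrigin surface, overlay the optimizer trajectory
    # stored in grad_iter (shape (2, num_steps)), and save the figure as
    # docs/rastrigin_<optimizer_name>.png.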
76 | x = np.linspace(-4.5, 4.5, 250) 77 | y = np.linspace(-4.5, 4.5, 250) 78 | minimum = (0, 0) 79 | 80 | X, Y = np.meshgrid(x, y) 81 | Z = rastrigin([X, Y], lib=np) 82 | 83 | iter_x, iter_y = grad_iter[0, :], grad_iter[1, :] 84 | 85 | fig = plt.figure(figsize=(8, 8)) 86 | 87 | ax = fig.add_subplot(1, 1, 1) 88 | ax.contour(X, Y, Z, 20, cmap="jet") 89 | ax.plot(iter_x, iter_y, color="r", marker="x") 90 | ax.set_title( 91 | "Rastrigin func: {} with " 92 | "{} iterations, lr={:.6}".format(optimizer_name, len(iter_x), lr) 93 | ) 94 | plt.plot(*minimum, "gD") 95 | plt.plot(iter_x[-1], iter_y[-1], "rD") 96 | plt.savefig("docs/rastrigin_{}.png".format(optimizer_name)) 97 | 98 | 99 | def plot_rosenbrok(grad_iter, optimizer_name, lr): 100 | x = np.linspace(-2, 2, 250) 101 | y = np.linspace(-1, 3, 250) 102 | minimum = (1.0, 1.0) 103 | 104 | X, Y = np.meshgrid(x, y) 105 | Z = rosenbrock([X, Y]) 106 | 107 | iter_x, iter_y = grad_iter[0, :], grad_iter[1, :] 108 | 109 | fig = plt.figure(figsize=(8, 8)) 110 | 111 | ax = fig.add_subplot(1, 1, 1) 112 | ax.contour(X, Y, Z, 90, cmap="jet") 113 | ax.plot(iter_x, iter_y, color="r", marker="x") 114 | 115 | ax.set_title( 116 | "Rosenbrock func: {} with {} " 117 | "iterations, lr={:.6}".format(optimizer_name, len(iter_x), lr) 118 | ) 119 | plt.plot(*minimum, "gD") 120 | plt.plot(iter_x[-1], iter_y[-1], "rD") 121 | plt.savefig("docs/rosenbrock_{}.png".format(optimizer_name)) 122 | 123 | 124 | def execute_experiments( 125 | optimizers, objective, func, plot_func, initial_state, seed=1 126 | ): 127 | seed = seed 128 | for item in optimizers: 129 | optimizer_class, lr_low, lr_hi = item 130 | space = { 131 | "optimizer_class": hp.choice("optimizer_class", [optimizer_class]), 132 | "lr": hp.loguniform("lr", lr_low, lr_hi), 133 | } 134 | best = fmin( 135 | fn=objective, 136 | space=space, 137 | algo=tpe.suggest, 138 | max_evals=200, 139 | rstate=np.random.RandomState(seed), 140 | ) 141 | print(best["lr"], optimizer_class) 142 | 143 | steps = execute_steps( 144 | func, 145 | initial_state, 146 | optimizer_class, 147 | {"lr": best["lr"]}, 148 | num_iter=500, 149 | ) 150 | plot_func(steps, optimizer_class.__name__, best["lr"]) 151 | 152 | 153 | def LookaheadYogi(*a, **kw): 154 | base = optim.Yogi(*a, **kw) 155 | return optim.Lookahead(base) 156 | 157 | 158 | if __name__ == "__main__": 159 | # python examples/viz_optimizers.py 160 | 161 | # Each optimizer has tweaked search space to produce better plots and 162 | # help to converge on better lr faster. 
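# Note: the two numbers that follow each optimizer class below are the bounds
# passed to hp.loguniform, i.e. exponents, so the learning rate is sampled
# from [exp(low), exp(high)].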
163 | optimizers = [ 164 | # baselines 165 | (torch.optim.Adam, -8, 0.5), 166 | (torch.optim.SGD, -8, -1.0), 167 | # Adam based 168 | (optim.AdaBound, -8, 0.3), 169 | (optim.Adahessian, -1, 8), 170 | (optim.AdaMod, -8, 0.2), 171 | (optim.AdamP, -8, 0.2), 172 | (optim.DiffGrad, -8, 0.4), 173 | (optim.Lamb, -8, -2.9), 174 | (optim.MADGRAD, -8, 0.5), 175 | (optim.NovoGrad, -8, -1.7), 176 | (optim.RAdam, -8, 0.5), 177 | (optim.Yogi, -8, 0.1), 178 | # SGD/Momentum based 179 | (optim.AccSGD, -8, -1.4), 180 | (optim.SGDW, -8, -1.5), 181 | (optim.SGDP, -8, -1.5), 182 | (optim.PID, -8, -1.0), 183 | (optim.QHM, -6, -0.2), 184 | (optim.QHAdam, -8, 0.1), 185 | (optim.Ranger, -8, 0.1), 186 | (optim.RangerQH, -8, 0.1), 187 | (optim.RangerVA, -8, 0.1), 188 | (optim.Shampoo, -8, 0.1), 189 | (LookaheadYogi, -8, 0.1), 190 | (optim.AggMo, -8, -1.5), 191 | (optim.SWATS, -8, -1.5), 192 | (optim.Adafactor, -8, 0.5), 193 | (optim.A2GradUni, -8, 0.1), 194 | (optim.A2GradInc, -8, 0.1), 195 | (optim.A2GradExp, -8, 0.1), 196 | (optim.AdaBelief, -8, 0.1), 197 | (optim.Apollo, -8, 0.1), 198 | ] 199 | execute_experiments( 200 | optimizers, 201 | objective_rastrigin, 202 | rastrigin, 203 | plot_rastrigin, 204 | (-2.0, 3.5), 205 | ) 206 | 207 | execute_experiments( 208 | optimizers, 209 | objective_rosenbrok, 210 | rosenbrock, 211 | plot_rosenbrok, 212 | (-2.0, 2.0), 213 | ) 214 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | bandit==1.7.0 3 | black==23.3.0 4 | flake8-bugbear==21.9.2 5 | flake8==4.0.1 6 | ipdb==0.13.9 7 | isort==5.9.3 8 | mypy==0.910 9 | numpy==1.23.2 10 | pyroma==3.2 11 | pytest-cov==3.0.0 12 | pytest==6.2.5 13 | pytorch_ranger==0.1.1 14 | sphinx-autodoc-typehints==1.12.0 15 | sphinx==4.2.0 16 | torch==1.13.1 17 | twine==3.4.2 18 | wheel==0.38.1 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | from setuptools import find_packages, setup 5 | 6 | install_requires = [ 7 | "torch>=1.5.0", 8 | "pytorch_ranger>=0.1.1", 9 | ] 10 | 11 | 12 | def _read(f): 13 | with open(os.path.join(os.path.dirname(__file__), f)) as f_: 14 | return f_.read().strip() 15 | 16 | 17 | def _read_version(): 18 | regexp = re.compile(r'^__version__\W*=\W*"([\d.abrc]+)"') 19 | init_py = os.path.join( 20 | os.path.dirname(__file__), "torch_optimizer", "__init__.py" 21 | ) 22 | with open(init_py) as f: 23 | for line in f: 24 | match = regexp.match(line) 25 | if match is not None: 26 | return match.group(1) 27 | raise RuntimeError( 28 | "Cannot find version in torch_optimizer/__init__.py" 29 | ) 30 | 31 | 32 | classifiers = [ 33 | "License :: OSI Approved :: Apache Software License", 34 | "Intended Audience :: Developers", 35 | "Intended Audience :: Science/Research", 36 | "Programming Language :: Python :: 3", 37 | "Programming Language :: Python :: 3.6", 38 | "Programming Language :: Python :: 3.7", 39 | "Programming Language :: Python :: 3.8", 40 | "Operating System :: OS Independent", 41 | "Development Status :: 3 - Alpha", 42 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 43 | ] 44 | 45 | keywords = [ 46 | "torch-optimizer", 47 | "pytorch", 48 | # optimizers 49 | "accsgd", 50 | "adabound", 51 | "adamod", 52 | "diffgrad", 53 | "lamb", 54 | "lookahead", 55 | "madgrad", 56 | "novograd", 57 | "pid", 58 | 
"qhadam", 59 | "qhm", 60 | "radam", 61 | "sgdw", 62 | "yogi", 63 | "ranger", 64 | ] 65 | 66 | project_urls = { 67 | "Website": "https://github.com/jettify/pytorch-optimizer", 68 | "Documentation": "https://pytorch-optimizer.readthedocs.io", 69 | "Issues": "https://github.com/jettify/pytorch-optimizer/issues", 70 | } 71 | 72 | 73 | setup( 74 | name="torch-optimizer", 75 | version=_read_version(), 76 | description=("pytorch-optimizer"), 77 | long_description="\n\n".join((_read("README.rst"), _read("CHANGES.rst"))), 78 | long_description_content_type="text/x-rst", 79 | classifiers=classifiers, 80 | platforms=["POSIX"], 81 | author="Nikolay Novik", 82 | author_email="nickolainovik@gmail.com", 83 | url="https://github.com/jettify/pytorch-optimizer", 84 | download_url="https://pypi.org/project/torch-optimizer/", 85 | license="Apache 2", 86 | packages=find_packages(exclude=("tests",)), 87 | install_requires=install_requires, 88 | keywords=keywords, 89 | zip_safe=True, 90 | include_package_data=True, 91 | project_urls=project_urls, 92 | python_requires=">=3.6.0", 93 | ) 94 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jettify/pytorch-optimizer/19c3e41952b94f2d60db06e559ee9a1433b25e53/tests/conftest.py -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import torch_optimizer as optim 5 | 6 | 7 | def rosenbrock(tensor): 8 | x, y = tensor 9 | return (1 - x) ** 2 + 1 * (y - x**2) ** 2 10 | 11 | 12 | def quadratic(tensor): 13 | x, y = tensor 14 | a = 1.0 15 | b = 1.0 16 | return (x**2) / a + (y**2) / b 17 | 18 | 19 | def beale(tensor): 20 | x, y = tensor 21 | f = ( 22 | (1.5 - x + x * y) ** 2 23 | + (2.25 - x + x * y**2) ** 2 24 | + (2.625 - x + x * y**3) ** 2 25 | ) 26 | return f 27 | 28 | 29 | cases = [ 30 | (rosenbrock, (1.5, 1.5), (1, 1)), 31 | (quadratic, (1.5, 1.5), (0, 0)), 32 | (beale, (1.5, 1.5), (3, 0.5)), 33 | ] 34 | 35 | 36 | def ids(v): 37 | n = "{} {}".format(v[0].__name__, v[1:]) 38 | return n 39 | 40 | 41 | def build_lookahead(*a, **kw): 42 | base = optim.Yogi(*a, **kw) 43 | return optim.Lookahead(base) 44 | 45 | 46 | optimizers = [ 47 | (optim.A2GradUni, {"lips": 40, "beta": 0.0001}, 800), 48 | (optim.PID, {"lr": 0.002, "momentum": 0.8, "weight_decay": 0.0001}, 900), 49 | (optim.QHM, {"lr": 0.02, "momentum": 0.95, "nu": 1}, 900), 50 | ( 51 | optim.NovoGrad, 52 | {"lr": 2.9, "betas": (0.9, 0.999), "grad_averaging": True}, 53 | 900, 54 | ), 55 | (optim.RAdam, {"lr": 0.01, "betas": (0.9, 0.95), "eps": 1e-3}, 800), 56 | (optim.SGDW, {"lr": 0.002, "momentum": 0.91}, 900), 57 | (optim.DiffGrad, {"lr": 0.5}, 500), 58 | (optim.AdaMod, {"lr": 1.0}, 800), 59 | (optim.AdaBound, {"lr": 1.0}, 800), 60 | (optim.Yogi, {"lr": 1.0}, 500), 61 | (optim.AccSGD, {"lr": 0.015}, 800), 62 | (build_lookahead, {"lr": 1.0}, 500), 63 | (optim.QHAdam, {"lr": 1.0}, 500), 64 | (optim.AdamP, {"lr": 0.01, "betas": (0.9, 0.95), "eps": 1e-3}, 800), 65 | (optim.SGDP, {"lr": 0.002, "momentum": 0.91}, 900), 66 | (optim.AggMo, {"lr": 0.003}, 1800), 67 | (optim.SWATS, {"lr": 0.1, "amsgrad": True, "nesterov": True}, 900), 68 | (optim.Adafactor, {"lr": None, "decay_rate": -0.3, "beta1": 0.9}, 800), 69 | (optim.AdaBelief, {"lr": 1.0}, 500), 70 | (optim.Adahessian, {"lr": 0.15, 
"hessian_power": 0.6, "seed": 0}, 900), 71 | (optim.MADGRAD, {"lr": 0.02}, 500), 72 | (optim.LARS, {"lr": 0.002, "momentum": 0.91}, 900), 73 | (optim.Lion, {"lr": 0.025}, 3600), 74 | ] 75 | 76 | 77 | @pytest.mark.parametrize("case", cases, ids=ids) 78 | @pytest.mark.parametrize("optimizer_config", optimizers, ids=ids) 79 | def test_benchmark_function(case, optimizer_config): 80 | func, initial_state, min_loc = case 81 | optimizer_class, config, iterations = optimizer_config 82 | 83 | x = torch.Tensor(initial_state).requires_grad_(True) 84 | x_min = torch.Tensor(min_loc) 85 | optimizer = optimizer_class([x], **config) 86 | for _ in range(iterations): 87 | optimizer.zero_grad() 88 | f = func(x) 89 | f.backward(retain_graph=True, create_graph=True) 90 | optimizer.step() 91 | assert torch.allclose(x, x_min, atol=0.001) 92 | 93 | name = optimizer.__class__.__name__ 94 | assert name in optimizer.__repr__() 95 | -------------------------------------------------------------------------------- /tests/test_optimizer_with_nn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from torch import nn 5 | 6 | import torch_optimizer as optim 7 | 8 | 9 | def make_dataset(seed=42): 10 | rng = np.random.RandomState(seed) 11 | N = 100 12 | D = 2 13 | 14 | X = rng.randn(N, D) * 2 15 | 16 | # center the first N/2 points at (-2,-2) 17 | mid = N // 2 18 | X[:mid, :] = X[:mid, :] - 2 * np.ones((mid, D)) 19 | 20 | # center the last N/2 points at (2, 2) 21 | X[mid:, :] = X[mid:, :] + 2 * np.ones((mid, D)) 22 | 23 | # labels: first N/2 are 0, last N/2 are 1 24 | Y = np.array([0] * mid + [1] * mid).reshape(100, 1) 25 | 26 | x = torch.Tensor(X) 27 | y = torch.Tensor(Y) 28 | return x, y 29 | 30 | 31 | class LogisticRegression(nn.Module): 32 | def __init__(self): 33 | super(LogisticRegression, self).__init__() 34 | self.linear1 = nn.Linear(2, 4) 35 | self.linear2 = nn.Linear(4, 1) 36 | 37 | def forward(self, x): 38 | output = torch.relu(self.linear1(x)) 39 | output = self.linear2(output) 40 | y_pred = torch.sigmoid(output) 41 | return y_pred 42 | 43 | 44 | def ids(v): 45 | return "{} {}".format(v[0].__name__, v[1:]) 46 | 47 | 48 | def build_lookahead(*a, **kw): 49 | base = optim.Yogi(*a, **kw) 50 | return optim.Lookahead(base) 51 | 52 | 53 | optimizers = [ 54 | (build_lookahead, {"lr": 0.1, "weight_decay": 1e-3}, 200), 55 | (optim.A2GradExp, {"lips": 2.0, "beta": 1e-3}, 500), 56 | (optim.A2GradInc, {"lips": 5.0, "beta": 1e-3}, 200), 57 | (optim.A2GradUni, {"lips": 5.0, "beta": 1e-3}, 500), 58 | (optim.AccSGD, {"lr": 1.0, "weight_decay": 1e-3}, 200), 59 | (optim.AdaBelief, {"lr": 0.1, "weight_decay": 1e-3}, 200), 60 | (optim.AdaBound, {"lr": 1.5, "gamma": 0.1, "weight_decay": 1e-3}, 200), 61 | (optim.AdaMod, {"lr": 2.0, "weight_decay": 1e-3}, 200), 62 | (optim.Adafactor, {"lr": 0.004466, "weight_decay": 1e-3}, 1500), 63 | (optim.AdamP, {"lr": 0.045, "weight_decay": 1e-3}, 800), 64 | (optim.AggMo, {"lr": 0.17059, "weight_decay": 1e-3}, 1000), 65 | (optim.Apollo, {"lr": 0.1, "weight_decay": 1e-3}, 200), 66 | (optim.DiffGrad, {"lr": 0.5, "weight_decay": 1e-3}, 200), 67 | ( 68 | optim.LARS, 69 | {"lr": 1.0, "weight_decay": 1e-3, "trust_coefficient": 0.01}, 70 | 200, 71 | ), 72 | (optim.Lamb, {"lr": 0.0151, "weight_decay": 1e-3}, 1000), 73 | (optim.MADGRAD, {"lr": 1.0, "weight_decay": 1e-3}, 200), 74 | (optim.NovoGrad, {"lr": 0.01, "weight_decay": 1e-3}, 200), 75 | (optim.PID, {"lr": 0.01, "weight_decay": 1e-3, "momentum": 0.1}, 200), 
76 | (optim.QHAdam, {"lr": 0.1, "weight_decay": 1e-3}, 200), 77 | (optim.QHM, {"lr": 0.1, "weight_decay": 1e-5, "momentum": 0.2}, 200), 78 | (optim.RAdam, {"lr": 1.0, "weight_decay": 1e-3}, 200), 79 | (optim.Ranger, {"lr": 0.1, "weight_decay": 1e-3}, 200), 80 | (optim.RangerQH, {"lr": 0.0124, "weight_decay": 1e-3}, 1100), 81 | (optim.RangerVA, {"lr": 0.2214, "weight_decay": 1e-3}, 500), 82 | (optim.SGDP, {"lr": 1.0, "weight_decay": 1e-3}, 200), 83 | (optim.SGDW, {"lr": 1.0, "weight_decay": 1e-3}, 200), 84 | (optim.SWATS, {"lr": 0.703, "weight_decay": 1e-3}, 600), 85 | ( 86 | optim.Shampoo, 87 | {"lr": 0.279, "weight_decay": 1e-3, "momentum": 0.05}, 88 | 1600, 89 | ), 90 | (optim.Yogi, {"lr": 0.1, "weight_decay": 1e-3}, 200), 91 | (optim.Adahessian, {"lr": 0.1, "weight_decay": 1e-3}, 200), 92 | (optim.Lion, {"lr": 0.1, "weight_decay": 1e-3}, 200), 93 | ] 94 | 95 | 96 | @pytest.mark.parametrize("optimizer_config", optimizers, ids=ids) 97 | def test_basic_nn_modeloptimizer_config(optimizer_config): 98 | torch.manual_seed(42) 99 | x_data, y_data = make_dataset() 100 | model = LogisticRegression() 101 | 102 | loss_fn = nn.BCELoss() 103 | optimizer_class, config, iterations = optimizer_config 104 | optimizer = optimizer_class(model.parameters(), **config) 105 | init_loss = None 106 | for _ in range(iterations): 107 | y_pred = model(x_data) 108 | loss = loss_fn(y_pred, y_data) 109 | if init_loss is None: 110 | init_loss = loss 111 | optimizer.zero_grad() 112 | loss.backward(create_graph=True) 113 | optimizer.step() 114 | assert init_loss.item() > 2.0 * loss.item() 115 | -------------------------------------------------------------------------------- /tests/test_param_validation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import torch_optimizer as optim 5 | 6 | 7 | def assert_sparse_not_supported(optimizer_class, err_msg=None): 8 | param = torch.randn(1, 1).to_sparse().requires_grad_(True) 9 | grad = torch.randn(1, 1).to_sparse() 10 | param.grad = grad 11 | optimizer = optimizer_class([param]) 12 | optimizer.zero_grad() 13 | with pytest.raises(RuntimeError) as ctx: 14 | optimizer.step() 15 | 16 | msg = err_msg or "does not support sparse gradients" 17 | assert msg in str(ctx.value) 18 | 19 | 20 | no_sparse_optimizers = [ 21 | optim.AdaBound, 22 | optim.AdaMod, 23 | optim.DiffGrad, 24 | optim.Lamb, 25 | optim.NovoGrad, 26 | optim.RAdam, 27 | optim.Yogi, 28 | ] 29 | 30 | 31 | @pytest.mark.parametrize("optimizer_class", no_sparse_optimizers) 32 | def test_sparse_not_supported(optimizer_class): 33 | assert_sparse_not_supported(optimizer_class) 34 | 35 | 36 | optimizers = [ 37 | optim.AccSGD, 38 | optim.AdaBelief, 39 | optim.AdaBound, 40 | optim.AdaMod, 41 | optim.AdamP, 42 | optim.AggMo, 43 | optim.Apollo, 44 | optim.DiffGrad, 45 | optim.LARS, 46 | optim.Lamb, 47 | optim.MADGRAD, 48 | optim.NovoGrad, 49 | optim.PID, 50 | optim.QHAdam, 51 | optim.QHM, 52 | optim.RAdam, 53 | optim.SGDP, 54 | optim.SGDW, 55 | optim.SWATS, 56 | optim.Shampoo, 57 | optim.Yogi, 58 | optim.Lion, 59 | ] 60 | 61 | 62 | @pytest.mark.parametrize("optimizer_class", optimizers) 63 | def test_learning_rate(optimizer_class): 64 | lr = -0.01 65 | with pytest.raises(ValueError) as ctx: 66 | optimizer_class(None, lr=-0.01) 67 | msg = "Invalid learning rate: {}".format(lr) 68 | assert msg in str(ctx.value) 69 | 70 | 71 | eps_optimizers = [ 72 | optim.AdaBelief, 73 | optim.AdaBound, 74 | optim.AdaMod, 75 | optim.AdamP, 76 | optim.Apollo, 77 | 
optim.DiffGrad, 78 | optim.LARS, 79 | optim.Lamb, 80 | optim.MADGRAD, 81 | optim.NovoGrad, 82 | optim.QHAdam, 83 | optim.RAdam, 84 | optim.SGDP, 85 | optim.SWATS, 86 | optim.Yogi, 87 | ] 88 | 89 | 90 | @pytest.mark.parametrize("optimizer_class", eps_optimizers) 91 | def test_eps_validation(optimizer_class): 92 | eps = -0.1 93 | with pytest.raises(ValueError) as ctx: 94 | optimizer_class(None, lr=0.1, eps=eps) 95 | msg = "Invalid epsilon value: {}".format(eps) 96 | assert msg in str(ctx.value) 97 | 98 | 99 | weight_decay_optimizers = [ 100 | optim.AccSGD, 101 | optim.AdaBelief, 102 | optim.AdaBound, 103 | optim.AdaMod, 104 | optim.Adafactor, 105 | optim.AdamP, 106 | optim.AggMo, 107 | optim.Apollo, 108 | optim.DiffGrad, 109 | optim.LARS, 110 | optim.Lamb, 111 | optim.MADGRAD, 112 | optim.NovoGrad, 113 | optim.PID, 114 | optim.QHAdam, 115 | optim.QHM, 116 | optim.RAdam, 117 | optim.SGDP, 118 | optim.SGDW, 119 | optim.SWATS, 120 | optim.Shampoo, 121 | optim.Yogi, 122 | optim.Lion, 123 | ] 124 | 125 | 126 | @pytest.mark.parametrize("optimizer_class", weight_decay_optimizers) 127 | def test_weight_decay_validation(optimizer_class): 128 | weight_decay = -0.1 129 | with pytest.raises(ValueError) as ctx: 130 | optimizer_class(None, lr=0.1, weight_decay=weight_decay) 131 | msg = "Invalid weight_decay value: {}".format(weight_decay) 132 | assert msg in str(ctx.value) 133 | 134 | 135 | betas_optimizers = [ 136 | optim.AdaBelief, 137 | optim.AdaBound, 138 | optim.AdaMod, 139 | optim.AdamP, 140 | optim.DiffGrad, 141 | optim.Lamb, 142 | optim.NovoGrad, 143 | optim.QHAdam, 144 | optim.RAdam, 145 | optim.Yogi, 146 | optim.Lion, 147 | ] 148 | 149 | 150 | @pytest.mark.parametrize("optimizer_class", betas_optimizers) 151 | def test_betas_validation(optimizer_class): 152 | betas = (-1, 0.999) 153 | with pytest.raises(ValueError) as ctx: 154 | optimizer_class(None, lr=0.1, betas=(-1, 0.999)) 155 | msg = "Invalid beta parameter at index 0: {}".format(betas[0]) 156 | assert msg in str(ctx.value) 157 | 158 | betas = (0.9, -0.999) 159 | with pytest.raises(ValueError) as ctx: 160 | optimizer_class(None, lr=0.1, betas=betas) 161 | msg = "Invalid beta parameter at index 1: {}".format(betas[1]) 162 | assert msg in str(ctx.value) 163 | -------------------------------------------------------------------------------- /torch_optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | """torch-optimizer -- collection of of optimization algorithms for PyTorch. 2 | 3 | API and usage patterns are the same as `torch.optim`__ 4 | 5 | Example 6 | ------- 7 | 8 | >>> import torch_optimizer as optim 9 | # model = ... 10 | >>> optimizer = optim.DiffGrad(model.parameters(), lr=0.001) 11 | >>> optimizer.step() 12 | 13 | See documentation for full list of supported optimizers. 
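Optimizer classes can also be looked up by name (case insensitive) with the
``get`` helper defined in this module, for example:

>>> optimizer_cls = optim.get('adamp')
>>> optimizer = optimizer_cls(model.parameters(), lr=0.001)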
14 | 15 | __ https://pytorch.org/docs/stable/optim.html#module-torch.optim 16 | """ 17 | from typing import Dict, List, Type 18 | 19 | from pytorch_ranger import Ranger, RangerQH, RangerVA 20 | from torch.optim.optimizer import Optimizer 21 | 22 | from .a2grad import A2GradExp, A2GradInc, A2GradUni 23 | from .accsgd import AccSGD 24 | from .adabelief import AdaBelief 25 | from .adabound import AdaBound 26 | from .adafactor import Adafactor 27 | from .adahessian import Adahessian 28 | from .adamod import AdaMod 29 | from .adamp import AdamP 30 | from .aggmo import AggMo 31 | from .apollo import Apollo 32 | from .diffgrad import DiffGrad 33 | from .lamb import Lamb 34 | from .lars import LARS 35 | from .lion import Lion 36 | from .lookahead import Lookahead 37 | from .madgrad import MADGRAD 38 | from .novograd import NovoGrad 39 | from .pid import PID 40 | from .qhadam import QHAdam 41 | from .qhm import QHM 42 | from .radam import RAdam 43 | from .sgdp import SGDP 44 | from .sgdw import SGDW 45 | from .shampoo import Shampoo 46 | from .swats import SWATS 47 | from .yogi import Yogi 48 | 49 | __all__ = ( 50 | "A2GradExp", 51 | "A2GradInc", 52 | "A2GradUni", 53 | "AccSGD", 54 | "AdaBelief", 55 | "AdaBound", 56 | "AdaMod", 57 | "Adafactor", 58 | "Adahessian", 59 | "AdamP", 60 | "AggMo", 61 | "Apollo", 62 | "DiffGrad", 63 | "LARS", 64 | "Lamb", 65 | "Lookahead", 66 | "MADGRAD", 67 | "NovoGrad", 68 | "PID", 69 | "QHAdam", 70 | "QHM", 71 | "RAdam", 72 | "Ranger", 73 | "RangerQH", 74 | "RangerVA", 75 | "SGDP", 76 | "SGDW", 77 | "SWATS", 78 | "Shampoo", 79 | "Yogi", 80 | "Lion", 81 | # utils 82 | "get", 83 | ) 84 | __version__ = "0.3.1a0" 85 | 86 | 87 | _package_opts = [ 88 | AdaBelief, 89 | AccSGD, 90 | AdaBound, 91 | AdaMod, 92 | AdamP, 93 | AggMo, 94 | DiffGrad, 95 | LARS, 96 | Lamb, 97 | Lookahead, 98 | MADGRAD, 99 | NovoGrad, 100 | PID, 101 | QHAdam, 102 | QHM, 103 | RAdam, 104 | Ranger, 105 | RangerQH, 106 | RangerVA, 107 | SGDP, 108 | SGDW, 109 | SWATS, 110 | Shampoo, 111 | Yogi, 112 | Lion, 113 | ] # type: List[Type[Optimizer]] 114 | 115 | 116 | _NAME_OPTIM_MAP = { 117 | opt.__name__.lower(): opt for opt in _package_opts 118 | } # type: Dict[str, Type[Optimizer]] 119 | 120 | 121 | def get(name: str) -> Type[Optimizer]: 122 | r"""Returns an optimizer class from its name. Case insensitive. 123 | 124 | Args: 125 | name: the optimizer name. 126 | """ 127 | optimizer_class = _NAME_OPTIM_MAP.get(name.lower()) 128 | if optimizer_class is None: 129 | raise ValueError("Optimizer {} not found".format(name)) 130 | return optimizer_class 131 | -------------------------------------------------------------------------------- /torch_optimizer/accsgd.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from torch.optim.optimizer import Optimizer 4 | 5 | from .types import OptFloat, OptLossClosure, Params 6 | 7 | __all__ = ("AccSGD",) 8 | 9 | 10 | class AccSGD(Optimizer): 11 | r"""Implements AccSGD algorithm. 
12 | 13 | It has been proposed in `On the insufficiency of existing momentum 14 | schemes for Stochastic Optimization`__ and `Accelerating Stochastic 15 | Gradient Descent For Least Squares Regression`__ 16 | 17 | Arguments: 18 | params: iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr: learning rate (default: 1e-3) 21 | kappa: ratio of long to short step (default: 1000) 22 | xi: statistical advantage parameter (default: 10) 23 | small_const: any value <=1 (default: 0.7) 24 | weight_decay: weight decay (L2 penalty) (default: 0) 25 | 26 | Example: 27 | >>> import torch_optimizer as optim 28 | >>> optimizer = optim.AccSGD(model.parameters(), lr=0.1) 29 | >>> optimizer.zero_grad() 30 | >>> loss_fn(model(input), target).backward() 31 | >>> optimizer.step() 32 | 33 | __ https://arxiv.org/abs/1704.08227 34 | __ https://arxiv.org/abs/1803.05591 35 | 36 | Note: 37 | Reference code: https://github.com/rahulkidambi/AccSGD 38 | """ 39 | 40 | def __init__( 41 | self, 42 | params: Params, 43 | lr: float = 1e-3, 44 | kappa: float = 1000.0, 45 | xi: float = 10.0, 46 | small_const: float = 0.7, 47 | weight_decay: float = 0, 48 | ) -> None: 49 | if lr <= 0.0: 50 | raise ValueError("Invalid learning rate: {}".format(lr)) 51 | if weight_decay < 0: 52 | raise ValueError( 53 | "Invalid weight_decay value: {}".format(weight_decay) 54 | ) 55 | defaults = dict( 56 | lr=lr, 57 | kappa=kappa, 58 | xi=xi, 59 | small_const=small_const, 60 | weight_decay=weight_decay, 61 | ) 62 | super(AccSGD, self).__init__(params, defaults) 63 | 64 | def step(self, closure: OptLossClosure = None) -> OptFloat: 65 | r"""Performs a single optimization step. 66 | 67 | Arguments: 68 | closure: A closure that reevaluates the model and returns the loss. 69 | """ 70 | loss = None 71 | if closure is not None: 72 | loss = closure() 73 | 74 | for group in self.param_groups: 75 | weight_decay = group["weight_decay"] 76 | large_lr = (group["lr"] * group["kappa"]) / (group["small_const"]) 77 | alpha = 1.0 - ( 78 | (group["small_const"] * group["small_const"] * group["xi"]) 79 | / group["kappa"] 80 | ) 81 | beta = 1.0 - alpha 82 | zeta = group["small_const"] / (group["small_const"] + beta) 83 | for p in group["params"]: 84 | if p.grad is None: 85 | continue 86 | d_p = p.grad.data 87 | if weight_decay != 0: 88 | d_p.add_(p.data, alpha=weight_decay) 89 | param_state = self.state[p] 90 | if "momentum_buffer" not in param_state: 91 | param_state["momentum_buffer"] = copy.deepcopy(p.data) 92 | buf = param_state["momentum_buffer"] 93 | buf.mul_((1.0 / beta) - 1.0) 94 | buf.add_(d_p, alpha=-large_lr) 95 | buf.add_(p.data) 96 | buf.mul_(beta) 97 | 98 | p.data.add_(d_p, alpha=-group["lr"]) 99 | p.data.mul_(zeta) 100 | p.data.add_(buf, alpha=1.0 - zeta) 101 | 102 | return loss 103 | -------------------------------------------------------------------------------- /torch_optimizer/adabound.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from .types import Betas2, OptFloat, OptLossClosure, Params, State 7 | 8 | __all__ = ("AdaBound",) 9 | 10 | 11 | class AdaBound(Optimizer): 12 | r"""Implements AdaBound algorithm. 13 | 14 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of 15 | Learning Rate`__. 
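    The per-parameter Adam step size is clipped to dynamic bounds,
    lower = final_lr * (1 - 1 / (gamma * step + 1)) and
    upper = final_lr * (1 + 1 / (gamma * step)),
    which both converge to final_lr, so the behaviour moves smoothly from
    Adam towards SGD as training progresses.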
16 | 17 | Arguments: 18 | params: iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr: learning rate (default: 1e-3) 21 | betas: coefficients used for computing running averages of gradient 22 | and its square (default: (0.9, 0.999)) 23 | final_lr: final (SGD) learning rate (default: 0.1) 24 | gamma: convergence speed of the bound functions 25 | (default: 1e-3) 26 | eps: term added to the denominator to improve numerical stability 27 | (default: 1e-8) 28 | weight_decay: weight decay (L2 penalty) (default: 0) 29 | amsbound: whether to use the AMSBound variant of this algorithm 30 | 31 | Example: 32 | >>> import torch_optimizer as optim 33 | >>> optimizer = optim.AdaBound(model.parameters(), lr=0.1) 34 | >>> optimizer.zero_grad() 35 | >>> loss_fn(model(input), target).backward() 36 | >>> optimizer.step() 37 | 38 | __ https://arxiv.org/abs/1902.09843 39 | 40 | Note: 41 | Reference code: https://github.com/Luolc/AdaBound 42 | """ 43 | 44 | def __init__( 45 | self, 46 | params: Params, 47 | lr: float = 1e-3, 48 | betas: Betas2 = (0.9, 0.999), 49 | final_lr: float = 0.1, 50 | gamma: float = 1e-3, 51 | eps: float = 1e-8, 52 | weight_decay: float = 0, 53 | amsbound: bool = False, 54 | ) -> None: 55 | if lr <= 0.0: 56 | raise ValueError("Invalid learning rate: {}".format(lr)) 57 | if eps < 0.0: 58 | raise ValueError("Invalid epsilon value: {}".format(eps)) 59 | if not 0.0 <= betas[0] < 1.0: 60 | raise ValueError( 61 | "Invalid beta parameter at index 0: {}".format(betas[0]) 62 | ) 63 | if not 0.0 <= betas[1] < 1.0: 64 | raise ValueError( 65 | "Invalid beta parameter at index 1: {}".format(betas[1]) 66 | ) 67 | if final_lr < 0.0: 68 | raise ValueError( 69 | "Invalid final learning rate: {}".format(final_lr) 70 | ) 71 | if not 0.0 <= gamma < 1.0: 72 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 73 | if weight_decay < 0: 74 | raise ValueError( 75 | "Invalid weight_decay value: {}".format(weight_decay) 76 | ) 77 | defaults = dict( 78 | lr=lr, 79 | betas=betas, 80 | final_lr=final_lr, 81 | gamma=gamma, 82 | eps=eps, 83 | weight_decay=weight_decay, 84 | amsbound=amsbound, 85 | ) 86 | super(AdaBound, self).__init__(params, defaults) 87 | self.base_lrs = [group["lr"] for group in self.param_groups] 88 | 89 | def __setstate__(self, state: State) -> None: 90 | super(AdaBound, self).__setstate__(state) 91 | for group in self.param_groups: 92 | group.setdefault("amsbound", False) 93 | 94 | def step(self, closure: OptLossClosure = None) -> OptFloat: 95 | r"""Performs a single optimization step. 96 | 97 | Arguments: 98 | closure: A closure that reevaluates the model and returns the loss. 
99 | """ 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | for group, base_lr in zip(self.param_groups, self.base_lrs): 105 | for p in group["params"]: 106 | if p.grad is None: 107 | continue 108 | grad = p.grad.data 109 | if grad.is_sparse: 110 | msg = ( 111 | "AdaBound does not support sparse gradients, " 112 | "please consider SparseAdam instead" 113 | ) 114 | raise RuntimeError(msg) 115 | amsbound = group["amsbound"] 116 | 117 | state = self.state[p] 118 | 119 | # State initialization 120 | if len(state) == 0: 121 | state["step"] = 0 122 | # Exponential moving average of gradient values 123 | state["exp_avg"] = torch.zeros_like( 124 | p, memory_format=torch.preserve_format 125 | ) 126 | # Exponential moving average of squared gradient values 127 | state["exp_avg_sq"] = torch.zeros_like( 128 | p, memory_format=torch.preserve_format 129 | ) 130 | if amsbound: 131 | # Maintains max of all exp. moving avg. of 132 | # sq. grad. values 133 | state["max_exp_avg_sq"] = torch.zeros_like( 134 | p, memory_format=torch.preserve_format 135 | ) 136 | 137 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 138 | if amsbound: 139 | max_exp_avg_sq = state["max_exp_avg_sq"] 140 | beta1, beta2 = group["betas"] 141 | 142 | state["step"] += 1 143 | 144 | if group["weight_decay"] != 0: 145 | grad = grad.add(p.data, alpha=group["weight_decay"]) 146 | 147 | # Decay the first and second moment running average coefficient 148 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 149 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 150 | if amsbound: 151 | # Maintains the maximum of all 2nd moment running 152 | # avg. till now 153 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 154 | # Use the max. for normalizing running avg. of gradient 155 | denom = max_exp_avg_sq.sqrt().add_(group["eps"]) 156 | else: 157 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 158 | 159 | bias_correction1 = 1 - beta1 ** state["step"] 160 | bias_correction2 = 1 - beta2 ** state["step"] 161 | step_size = ( 162 | group["lr"] 163 | * math.sqrt(bias_correction2) 164 | / bias_correction1 165 | ) 166 | 167 | # Applies bounds on actual learning rate 168 | # lr_scheduler cannot affect final_lr, this is a workaround 169 | # to apply lr decay 170 | final_lr = group["final_lr"] * group["lr"] / base_lr 171 | lower_bound = final_lr * ( 172 | 1 - 1 / (group["gamma"] * state["step"] + 1) 173 | ) 174 | upper_bound = final_lr * ( 175 | 1 + 1 / (group["gamma"] * state["step"]) 176 | ) 177 | step_size = torch.full_like(denom, step_size) 178 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_( 179 | exp_avg 180 | ) 181 | 182 | p.data.add_(-step_size) 183 | return loss 184 | -------------------------------------------------------------------------------- /torch_optimizer/adafactor.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Dict, Tuple 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | 7 | from .types import OptFloat, OptLossClosure, Params, State 8 | 9 | Eps2 = Tuple[float, float] 10 | ParamGroup = Dict[str, Any] 11 | 12 | 13 | class Adafactor(Optimizer): 14 | """Implements Adafactor algorithm. 15 | 16 | It has been proposed in: `Adafactor: Adaptive Learning Rates with 17 | Sublinear Memory Cost`__. 
18 | 19 | Arguments: 20 | params: iterable of parameters to optimize or dicts defining 21 | parameter groups 22 | lr: external learning rate (default: None) 23 | eps2: regularization constans for square gradient 24 | and parameter scale respectively (default: (1e-30, 1e-3)) 25 | clip_threshold: threshold of root mean square of 26 | final gradient update (default: 1.0) 27 | decay_rate: coefficient used to compute running averages of square 28 | gradient (default: -0.8) 29 | beta1: coefficient used for computing running averages of gradient 30 | (default: None) 31 | weight_decay: weight decay (L2 penalty) (default: 0) 32 | scale_parameter: if true, learning rate is scaled by root mean square 33 | of parameter (default: True) 34 | relative_step: if true, time-dependent learning rate is computed 35 | instead of external learning rate (default: True) 36 | warmup_init: time-dependent learning rate computation depends on 37 | whether warm-up initialization is being used (default: False) 38 | 39 | Example: 40 | >>> import torch_optimizer as optim 41 | >>> optimizer = optim.Adafactor(model.parameters()) 42 | >>> optimizer.zero_grad() 43 | >>> loss_fn(model(input), target).backward() 44 | >>> optimizer.step() 45 | 46 | __ https://arxiv.org/abs/1804.04235 47 | 48 | Note: 49 | Reference code: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py # noqa 50 | """ 51 | 52 | def __init__( 53 | self, 54 | params: Params, 55 | lr: OptFloat = None, 56 | eps2: Eps2 = (1e-30, 1e-3), 57 | clip_threshold: float = 1.0, 58 | decay_rate: float = -0.8, 59 | beta1: OptFloat = None, 60 | weight_decay: float = 0.0, 61 | scale_parameter: bool = True, 62 | relative_step: bool = True, 63 | warmup_init: bool = False, 64 | ): 65 | if lr is not None and lr <= 0.0: 66 | raise ValueError("Invalid learning rate: {}".format(lr)) 67 | if weight_decay < 0.0: 68 | raise ValueError( 69 | "Invalid weight_decay value: {}".format(weight_decay) 70 | ) 71 | 72 | defaults = dict( 73 | lr=lr, 74 | eps2=eps2, 75 | clip_threshold=clip_threshold, 76 | decay_rate=decay_rate, 77 | beta1=beta1, 78 | weight_decay=weight_decay, 79 | scale_parameter=scale_parameter, 80 | relative_step=relative_step, 81 | warmup_init=warmup_init, 82 | ) 83 | super(Adafactor, self).__init__(params, defaults) 84 | 85 | def _get_lr(self, param_group: ParamGroup, param_state: State) -> float: 86 | rel_step_sz = param_group["lr"] 87 | if param_group["relative_step"]: 88 | min_step = ( 89 | 1e-6 * param_state["step"] 90 | if param_group["warmup_init"] 91 | else 1e-2 92 | ) 93 | rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) 94 | param_scale = 1.0 95 | if param_group["scale_parameter"]: 96 | param_scale = max(param_group["eps2"][1], param_state["RMS"]) 97 | return param_scale * rel_step_sz 98 | 99 | def _get_options( 100 | self, param_group: ParamGroup, param_shape: Tuple[int, ...] 
101 | ) -> Tuple[bool, bool]: 102 | factored = len(param_shape) >= 2 103 | use_first_moment = param_group["beta1"] is not None 104 | return factored, use_first_moment 105 | 106 | def _rms(self, tensor: torch.Tensor) -> float: 107 | return tensor.norm(2) / (tensor.numel() ** 0.5) 108 | 109 | def _approx_sq_grad( 110 | self, 111 | exp_avg_sq_row: torch.Tensor, 112 | exp_avg_sq_col: torch.Tensor, 113 | output: torch.Tensor, 114 | ) -> None: 115 | r_factor = ( 116 | (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1)) 117 | .rsqrt_() 118 | .unsqueeze(-1) 119 | ) 120 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() 121 | torch.mul(r_factor, c_factor, out=output) 122 | 123 | def step(self, closure: OptLossClosure = None) -> OptFloat: 124 | r"""Performs a single optimization step. 125 | 126 | Arguments: 127 | closure: A closure that reevaluates the model and returns the loss. 128 | """ 129 | loss = None 130 | if closure is not None: 131 | loss = closure() 132 | 133 | for group in self.param_groups: 134 | for p in group["params"]: 135 | if p.grad is None: 136 | continue 137 | grad = p.grad.data 138 | if grad.is_sparse: 139 | raise RuntimeError( 140 | "Adafactor does not support sparse gradients." 141 | ) 142 | 143 | state = self.state[p] 144 | grad_shape = grad.shape 145 | 146 | factored, use_first_moment = self._get_options( 147 | group, grad_shape 148 | ) 149 | # State Initialization 150 | if len(state) == 0: 151 | state["step"] = 0 152 | 153 | if use_first_moment: 154 | # Exponential moving average of gradient values 155 | state["exp_avg"] = torch.zeros_like( 156 | grad, memory_format=torch.preserve_format 157 | ) 158 | if factored: 159 | state["exp_avg_sq_row"] = torch.zeros( 160 | grad_shape[:-1] 161 | ).type_as(grad) 162 | state["exp_avg_sq_col"] = torch.zeros( 163 | grad_shape[:-2] + grad_shape[-1:] 164 | ).type_as(grad) 165 | else: 166 | state["exp_avg_sq"] = torch.zeros_like( 167 | grad, memory_format=torch.preserve_format 168 | ) 169 | 170 | state["RMS"] = 0 171 | 172 | state["step"] += 1 173 | state["RMS"] = self._rms(p.data) 174 | lr = self._get_lr(group, state) 175 | 176 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) 177 | update = (grad**2) + group["eps2"][0] 178 | if factored: 179 | exp_avg_sq_row = state["exp_avg_sq_row"] 180 | exp_avg_sq_col = state["exp_avg_sq_col"] 181 | 182 | exp_avg_sq_row.mul_(beta2t).add_( 183 | update.mean(dim=-1), alpha=1.0 - beta2t 184 | ) 185 | exp_avg_sq_col.mul_(beta2t).add_( 186 | update.mean(dim=-2), alpha=1.0 - beta2t 187 | ) 188 | 189 | # Approximation of exponential moving average of square 190 | # of gradient 191 | self._approx_sq_grad( 192 | exp_avg_sq_row, exp_avg_sq_col, update 193 | ) 194 | update.mul_(grad) 195 | else: 196 | exp_avg_sq = state["exp_avg_sq"] 197 | 198 | exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) 199 | torch.rsqrt(exp_avg_sq, out=update).mul_(grad) 200 | 201 | update.div_( 202 | max(1.0, self._rms(update) / group["clip_threshold"]) 203 | ) 204 | update.mul_(lr) 205 | 206 | if use_first_moment: 207 | exp_avg = state["exp_avg"] 208 | exp_avg.mul_(group["beta1"]).add_( 209 | update, alpha=1 - group["beta1"] 210 | ) 211 | update = exp_avg 212 | 213 | if group["weight_decay"] != 0: 214 | p.data.add_(p.data, alpha=-group["weight_decay"] * lr) 215 | 216 | p.data.add_(-update) 217 | 218 | return loss 219 | -------------------------------------------------------------------------------- /torch_optimizer/adahessian.py: -------------------------------------------------------------------------------- 1 | import 
math 2 | from typing import List, Optional 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | 7 | from .types import Betas2, OptFloat, OptLossClosure, Params 8 | 9 | Grads = Params 10 | 11 | __all__ = ("Adahessian",) 12 | 13 | 14 | class Adahessian(Optimizer): 15 | r"""Implements Adahessian Algorithm. 16 | It has been proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer 17 | for Machine Learning`. 18 | 19 | Arguments: 20 | params (iterable): iterable of parameters to optimize or dicts defining 21 | parameter groups 22 | lr (float, optional): learning rate (default: 0.15) 23 | betas (Tuple[float, float], optional): coefficients used for computing 24 | running averages of gradient and its square (default: (0.9, 0.999)) 25 | eps (float, optional): term added to the denominator to improve 26 | numerical stability (default: 1e-4) 27 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 28 | hessian_power (float, optional): Hessian power (default: 0.5) 29 | seed (int, optional): Random number generator seed (default: None) 30 | 31 | Example: 32 | >>> import torch_optimizer as optim 33 | >>> optimizer = optim.Adahessian(model.parameters(), lr = 1.0) 34 | >>> optimizer.zero_grad() 35 | >>> loss_fn(model(input), target).backward(create_graph=True) 36 | >>> optimizer.step() 37 | 38 | __ https://arxiv.org/abs/2006.00719 39 | 40 | Note: 41 | Reference code: https://github.com/amirgholami/adahessian 42 | """ 43 | 44 | def __init__( 45 | self, 46 | params: Params, 47 | lr: float = 0.15, 48 | betas: Betas2 = (0.9, 0.999), 49 | eps: float = 1e-4, 50 | weight_decay: float = 0, 51 | hessian_power: float = 0.5, 52 | seed: Optional[int] = None, 53 | ) -> None: 54 | if lr <= 0.0: 55 | raise ValueError("Invalid learning rate: {}".format(lr)) 56 | if eps <= 0.0: 57 | raise ValueError("Invalid epsilon value: {}".format(eps)) 58 | if not 0.0 <= betas[0] < 1.0: 59 | raise ValueError( 60 | "Invalid beta parameter at index 0: {}".format(betas[0]) 61 | ) 62 | if not 0.0 <= betas[1] < 1.0: 63 | raise ValueError( 64 | "Invalid beta parameter at index 1: {}".format(betas[1]) 65 | ) 66 | if not 0.0 <= hessian_power <= 1.0: 67 | raise ValueError( 68 | "Invalid Hessian power value: {}".format(hessian_power) 69 | ) 70 | if seed is not None: 71 | torch.manual_seed(seed) 72 | defaults = dict( 73 | lr=lr, 74 | betas=betas, 75 | eps=eps, 76 | weight_decay=weight_decay, 77 | hessian_power=hessian_power, 78 | ) 79 | super(Adahessian, self).__init__(params, defaults) 80 | 81 | def get_trace(self, params: Params, grads: Grads) -> List[torch.Tensor]: 82 | """Get an estimate of Hessian Trace. 83 | This is done by computing the Hessian vector product with a random 84 | vector v at the current gradient point, to estimate Hessian trace by 85 | computing the gradient of . 86 | :param gradsH: a list of torch variables 87 | :return: a list of torch tensors 88 | """ 89 | 90 | # Check backward was called with create_graph set to True 91 | for i, grad in enumerate(grads): 92 | if grad.grad_fn is None: 93 | msg = ( 94 | "Gradient tensor {:} does not have grad_fn. When " 95 | "calling loss.backward(), make sure the option " 96 | "create_graph is set to True." 97 | ) 98 | raise RuntimeError(msg.format(i)) 99 | 100 | v = [ 101 | 2 102 | * torch.randint_like( 103 | p, high=2, memory_format=torch.preserve_format 104 | ) 105 | - 1 106 | for p in params 107 | ] 108 | 109 | # this is for distributed setting with single node and multi-gpus, 110 | # for multi nodes setting, we have not support it yet. 
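        # v has Rademacher (+/-1) entries, so E[v * (Hv)] equals the diagonal of
        # the Hessian (Hutchinson's estimator); the absolute values of the
        # Hessian-vector products computed below serve as the per-parameter
        # curvature estimate.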
111 | hvs = torch.autograd.grad( 112 | grads, params, grad_outputs=v, only_inputs=True, retain_graph=True 113 | ) 114 | 115 | hutchinson_trace = [] 116 | for hv in hvs: 117 | param_size = hv.size() 118 | if len(param_size) <= 2: # for 0/1/2D tensor 119 | # Hessian diagonal block size is 1 here. 120 | # We use that torch.abs(hv * vi) = hv.abs() 121 | tmp_output = hv.abs() 122 | 123 | elif len(param_size) == 4: # Conv kernel 124 | # Hessian diagonal block size is 9 here: torch.sum() reduces 125 | # the dim 2/3. 126 | # We use that torch.abs(hv * vi) = hv.abs() 127 | tmp_output = torch.mean(hv.abs(), dim=[2, 3], keepdim=True) 128 | hutchinson_trace.append(tmp_output) 129 | 130 | return hutchinson_trace 131 | 132 | def step(self, closure: OptLossClosure = None) -> OptFloat: 133 | """Perform a single optimization step. 134 | 135 | Arguments: 136 | closure: A closure that reevaluates the model and returns the loss. 137 | """ 138 | loss = None 139 | if closure is not None: 140 | loss = closure() 141 | 142 | params = [] 143 | groups = [] 144 | grads = [] 145 | 146 | # Flatten groups into lists, so that 147 | # hut_traces can be called with lists of parameters 148 | # and grads 149 | for group in self.param_groups: 150 | for p in group["params"]: 151 | if p.grad is not None: 152 | params.append(p) 153 | groups.append(group) 154 | grads.append(p.grad) 155 | 156 | # get the Hessian diagonal 157 | 158 | hut_traces = self.get_trace(params, grads) 159 | 160 | for p, group, grad, hut_trace in zip( 161 | params, groups, grads, hut_traces 162 | ): 163 | state = self.state[p] 164 | 165 | # State initialization 166 | if len(state) == 0: 167 | state["step"] = 0 168 | # Exponential moving average of gradient values 169 | state["exp_avg"] = torch.zeros_like(p.data) 170 | # Exponential moving average of Hessian diagonal square values 171 | state["exp_hessian_diag_sq"] = torch.zeros_like(p.data) 172 | 173 | exp_avg, exp_hessian_diag_sq = ( 174 | state["exp_avg"], 175 | state["exp_hessian_diag_sq"], 176 | ) 177 | 178 | beta1, beta2 = group["betas"] 179 | 180 | state["step"] += 1 181 | 182 | # Decay the first and second moment running average coefficient 183 | exp_avg.mul_(beta1).add_(grad.detach_(), alpha=1 - beta1) 184 | exp_hessian_diag_sq.mul_(beta2).addcmul_( 185 | hut_trace, hut_trace, value=1 - beta2 186 | ) 187 | 188 | bias_correction1 = 1 - beta1 ** state["step"] 189 | bias_correction2 = 1 - beta2 ** state["step"] 190 | 191 | # make the square root, and the Hessian power 192 | k = group["hessian_power"] 193 | denom = ( 194 | (exp_hessian_diag_sq.sqrt() ** k) 195 | / math.sqrt(bias_correction2) ** k 196 | ).add_(group["eps"]) 197 | 198 | # make update 199 | p.data = p.data - group["lr"] * ( 200 | exp_avg / bias_correction1 / denom 201 | + group["weight_decay"] * p.data 202 | ) 203 | 204 | return loss 205 | -------------------------------------------------------------------------------- /torch_optimizer/adamod.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from .types import Betas2, OptFloat, OptLossClosure, Params 7 | 8 | __all__ = ("AdaMod",) 9 | 10 | 11 | class AdaMod(Optimizer): 12 | r"""Implements AdaMod algorithm. 13 | 14 | It has been proposed in `Adaptive and Momental Bounds for Adaptive 15 | Learning Rate Methods`__. 
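    On top of Adam's moment estimates, an exponential moving average of the
    per-parameter step sizes is maintained (smoothed by beta3), and every new
    step size is capped by that average, which damps unexpectedly large
    learning rates early in training.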
16 |
17 |     Arguments:
18 |         params: iterable of parameters to optimize or dicts defining
19 |             parameter groups
20 |         lr: learning rate (default: 1e-3)
21 |         betas: coefficients used for computing running averages of gradient
22 |             and its square (default: (0.9, 0.999))
23 |         beta3: smoothing coefficient for adaptive learning rates
24 |             (default: 0.999)
25 |         eps: term added to the denominator to improve numerical stability
26 |             (default: 1e-8)
27 |         weight_decay: weight decay (L2 penalty) (default: 0)
28 |
29 |     Example:
30 |         >>> import torch_optimizer as optim
31 |         >>> optimizer = optim.AdaMod(model.parameters(), lr=0.1)
32 |         >>> optimizer.zero_grad()
33 |         >>> loss_fn(model(input), target).backward()
34 |         >>> optimizer.step()
35 |
36 |     __ https://arxiv.org/abs/1910.12249
37 |
38 |     Note:
39 |         Reference code: https://github.com/lancopku/AdaMod
40 |     """
41 |
42 |     def __init__(
43 |         self,
44 |         params: Params,
45 |         lr: float = 1e-3,
46 |         betas: Betas2 = (0.9, 0.999),
47 |         beta3: float = 0.999,
48 |         eps: float = 1e-8,
49 |         weight_decay: float = 0,
50 |     ) -> None:
51 |         if lr <= 0.0:
52 |             raise ValueError("Invalid learning rate: {}".format(lr))
53 |         if eps < 0.0:
54 |             raise ValueError("Invalid epsilon value: {}".format(eps))
55 |         if not 0.0 <= betas[0] < 1.0:
56 |             raise ValueError(
57 |                 "Invalid beta parameter at index 0: {}".format(betas[0])
58 |             )
59 |         if not 0.0 <= betas[1] < 1.0:
60 |             raise ValueError(
61 |                 "Invalid beta parameter at index 1: {}".format(betas[1])
62 |             )
63 |         if not 0.0 <= beta3 < 1.0:
64 |             raise ValueError("Invalid beta3 parameter: {}".format(beta3))
65 |         if weight_decay < 0.0:
66 |             raise ValueError(
67 |                 "Invalid weight_decay value: {}".format(weight_decay)
68 |             )
69 |         defaults = dict(
70 |             lr=lr, betas=betas, beta3=beta3, eps=eps, weight_decay=weight_decay
71 |         )
72 |         super(AdaMod, self).__init__(params, defaults)
73 |
74 |     def step(self, closure: OptLossClosure = None) -> OptFloat:
75 |         """Performs a single optimization step.
76 |
77 |         Arguments:
78 |             closure: A closure that reevaluates the model and returns the loss.
79 | """ 80 | loss = None 81 | if closure is not None: 82 | loss = closure() 83 | 84 | for group in self.param_groups: 85 | for p in group["params"]: 86 | if p.grad is None: 87 | continue 88 | grad = p.grad.data 89 | if grad.is_sparse: 90 | msg = "AdaMod does not support sparse gradients" 91 | raise RuntimeError(msg) 92 | 93 | state = self.state[p] 94 | 95 | # State initialization 96 | if len(state) == 0: 97 | state["step"] = 0 98 | # Exponential moving average of gradient values 99 | state["exp_avg"] = torch.zeros_like( 100 | p, memory_format=torch.preserve_format 101 | ) 102 | # Exponential moving average of squared gradient values 103 | state["exp_avg_sq"] = torch.zeros_like( 104 | p, memory_format=torch.preserve_format 105 | ) 106 | # Exponential moving average of actual learning rates 107 | state["exp_avg_lr"] = torch.zeros_like( 108 | p, memory_format=torch.preserve_format 109 | ) 110 | 111 | exp_avg, exp_avg_sq, exp_avg_lr = ( 112 | state["exp_avg"], 113 | state["exp_avg_sq"], 114 | state["exp_avg_lr"], 115 | ) 116 | beta1, beta2 = group["betas"] 117 | 118 | state["step"] += 1 119 | 120 | # Decay the first and second moment running average coefficient 121 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 122 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 123 | 124 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 125 | 126 | bias_correction1 = 1 - beta1 ** state["step"] 127 | bias_correction2 = 1 - beta2 ** state["step"] 128 | step_size = ( 129 | group["lr"] 130 | * math.sqrt(bias_correction2) 131 | / bias_correction1 132 | ) 133 | 134 | if group["weight_decay"] != 0: 135 | p.data.add_( 136 | p.data, alpha=-group["weight_decay"] * group["lr"] 137 | ) 138 | 139 | # Applies momental bounds on actual learning rates 140 | step_size = torch.full_like( 141 | denom, step_size, memory_format=torch.preserve_format 142 | ) 143 | step_size.div_(denom) 144 | exp_avg_lr.mul_(group["beta3"]).add_( 145 | step_size, alpha=1 - group["beta3"] 146 | ) 147 | step_size = torch.min(step_size, exp_avg_lr) 148 | step_size.mul_(exp_avg) 149 | 150 | p.data.add_(-step_size) 151 | 152 | return loss 153 | -------------------------------------------------------------------------------- /torch_optimizer/adamp.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from .types import Betas2, OptFloat, OptLossClosure, Params 7 | 8 | __all__ = ("AdamP",) 9 | 10 | 11 | class AdamP(Optimizer): 12 | r"""Implements AdamP algorithm. 
13 | 14 | It has been proposed in `Slowing Down the Weight Norm Increase in 15 | Momentum-based Optimizers`__ 16 | 17 | Arguments: 18 | params: iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr: learning rate (default: 1e-3) 21 | betas: coefficients used for computing 22 | running averages of gradient and its square (default: (0.9, 0.999)) 23 | eps: term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay: weight decay (L2 penalty) (default: 0) 26 | delta: threhold that determines whether a set of parameters is scale 27 | invariant or not (default: 0.1) 28 | wd_ratio: relative weight decay applied on scale-invariant parameters 29 | compared to that applied on scale-variant parameters (default: 0.1) 30 | nesterov: enables Nesterov momentum (default: False) 31 | 32 | 33 | Example: 34 | >>> import torch_optimizer as optim 35 | >>> optimizer = optim.AdamP(model.parameters(), lr=0.1) 36 | >>> optimizer.zero_grad() 37 | >>> loss_fn(model(input), target).backward() 38 | >>> optimizer.step() 39 | 40 | __ https://arxiv.org/abs/2006.08217 41 | 42 | Note: 43 | Reference code: https://github.com/clovaai/AdamP 44 | """ 45 | 46 | def __init__( 47 | self, 48 | params: Params, 49 | lr: float = 1e-3, 50 | betas: Betas2 = (0.9, 0.999), 51 | eps: float = 1e-8, 52 | weight_decay: float = 0, 53 | delta: float = 0.1, 54 | wd_ratio: float = 0.1, 55 | nesterov: bool = False, 56 | ) -> None: 57 | if lr <= 0.0: 58 | raise ValueError("Invalid learning rate: {}".format(lr)) 59 | if eps < 0.0: 60 | raise ValueError("Invalid epsilon value: {}".format(eps)) 61 | if not 0.0 <= betas[0] < 1.0: 62 | raise ValueError( 63 | "Invalid beta parameter at index 0: {}".format(betas[0]) 64 | ) 65 | if not 0.0 <= betas[1] < 1.0: 66 | raise ValueError( 67 | "Invalid beta parameter at index 1: {}".format(betas[1]) 68 | ) 69 | if weight_decay < 0: 70 | raise ValueError( 71 | "Invalid weight_decay value: {}".format(weight_decay) 72 | ) 73 | if delta < 0: 74 | raise ValueError("Invalid delta value: {}".format(delta)) 75 | if wd_ratio < 0: 76 | raise ValueError("Invalid wd_ratio value: {}".format(wd_ratio)) 77 | 78 | defaults = dict( 79 | lr=lr, 80 | betas=betas, 81 | eps=eps, 82 | weight_decay=weight_decay, 83 | delta=delta, 84 | wd_ratio=wd_ratio, 85 | nesterov=nesterov, 86 | ) 87 | super(AdamP, self).__init__(params, defaults) 88 | 89 | @staticmethod 90 | def _channel_view(x): 91 | return x.view(x.size(0), -1) 92 | 93 | @staticmethod 94 | def _layer_view(x): 95 | return x.view(1, -1) 96 | 97 | @staticmethod 98 | def _cosine_similarity(x, y, eps, view_func): 99 | x = view_func(x) 100 | y = view_func(y) 101 | 102 | x_norm = x.norm(dim=1).add_(eps) 103 | y_norm = y.norm(dim=1).add_(eps) 104 | dot = (x * y).sum(dim=1) 105 | 106 | return dot.abs() / x_norm / y_norm 107 | 108 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 109 | wd = 1 110 | expand_size = [-1] + [1] * (len(p.shape) - 1) 111 | for view_func in [self._channel_view, self._layer_view]: 112 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 113 | 114 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 115 | p_n = p.data / view_func(p.data).norm(dim=1).view( 116 | expand_size 117 | ).add_(eps) 118 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view( 119 | expand_size 120 | ) 121 | wd = wd_ratio 122 | 123 | return perturb, wd 124 | 125 | return perturb, wd 126 | 127 | def step(self, closure: OptLossClosure = None) -> OptFloat: 128 | r"""Performs a 
single optimization step. 129 | 130 | Arguments: 131 | closure: A closure that reevaluates the model and returns the loss. 132 | """ 133 | loss = None 134 | if closure is not None: 135 | loss = closure() 136 | 137 | for group in self.param_groups: 138 | for p in group["params"]: 139 | if p.grad is None: 140 | continue 141 | 142 | grad = p.grad.data 143 | beta1, beta2 = group["betas"] 144 | nesterov = group["nesterov"] 145 | 146 | state = self.state[p] 147 | 148 | # State initialization 149 | if len(state) == 0: 150 | state["step"] = 0 151 | state["exp_avg"] = torch.zeros_like( 152 | p.data, memory_format=torch.preserve_format 153 | ) 154 | state["exp_avg_sq"] = torch.zeros_like( 155 | p.data, memory_format=torch.preserve_format 156 | ) 157 | 158 | # Adam 159 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 160 | 161 | state["step"] += 1 162 | bias_correction1 = 1 - beta1 ** state["step"] 163 | bias_correction2 = 1 - beta2 ** state["step"] 164 | 165 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 166 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 167 | 168 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 169 | group["eps"] 170 | ) 171 | step_size = group["lr"] / bias_correction1 172 | 173 | if nesterov: 174 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 175 | else: 176 | perturb = exp_avg / denom 177 | 178 | # Projection 179 | wd_ratio = 1 180 | if len(p.shape) > 1: 181 | perturb, wd_ratio = self._projection( 182 | p, 183 | grad, 184 | perturb, 185 | group["delta"], 186 | group["wd_ratio"], 187 | group["eps"], 188 | ) 189 | 190 | # Weight decay 191 | if group["weight_decay"] > 0: 192 | p.data.mul_( 193 | 1 - group["lr"] * group["weight_decay"] * wd_ratio 194 | ) 195 | 196 | # Step 197 | p.data.add_(perturb, alpha=-step_size) 198 | 199 | return loss 200 | -------------------------------------------------------------------------------- /torch_optimizer/aggmo.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Type, TypeVar, Union 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from .types import OptFloat, OptLossClosure, Params 7 | 8 | __all__ = ("AggMo",) 9 | 10 | 11 | T = TypeVar("T", bound="AggMo") 12 | 13 | 14 | class AggMo(Optimizer): 15 | r"""Implements Aggregated Momentum Gradient Descent. 
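    One velocity buffer is kept for every damping coefficient in betas, and
    each step applies the average of those buffers, i.e. the update is scaled
    by lr / len(betas).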
16 | 17 | It has been proposed in `Aggregated Momentum: Stability Through Passive 18 | Damping`__ 19 | 20 | Example: 21 | >>> import torch_optimizer as optim 22 | >>> optimizer = optim.AggMo(model.parameters(), lr=0.1) 23 | >>> optimizer.zero_grad() 24 | >>> loss_fn(model(input), target).backward() 25 | >>> optimizer.step() 26 | 27 | __ https://arxiv.org/abs/1804.00325 28 | 29 | Note: 30 | Reference code: https://github.com/AtheMathmo/AggMo/blob/master/aggmo.py # noqa 31 | """ 32 | 33 | def __init__( 34 | self, 35 | params: Params, 36 | lr: float = 1e-3, 37 | betas: Union[List[float], Tuple[float, ...]] = (0.0, 0.9, 0.99), 38 | weight_decay: float = 0, 39 | ) -> None: 40 | if lr <= 0.0: 41 | raise ValueError("Invalid learning rate: {}".format(lr)) 42 | 43 | for i, beta in enumerate(betas): 44 | if not 0.0 <= beta < 1.0: 45 | msg = "Invalid beta parameter at index 1: {}".format(betas[i]) 46 | raise ValueError(msg) 47 | 48 | if weight_decay < 0.0: 49 | raise ValueError( 50 | "Invalid weight_decay value: {}".format(weight_decay) 51 | ) 52 | 53 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 54 | super(AggMo, self).__init__(params, defaults) 55 | 56 | @classmethod 57 | def from_exp_form( 58 | cls: Type[T], 59 | params: Params, 60 | lr: float = 1e-3, 61 | a: float = 0.1, 62 | k: int = 3, 63 | weight_decay: float = 0, 64 | ) -> T: 65 | if lr <= 0.0: 66 | raise ValueError("Invalid parameter k: {}".format(k)) 67 | 68 | betas = [1 - a**i for i in range(k)] # type: List[float] 69 | return cls(params, lr, betas, weight_decay) 70 | 71 | def step(self, closure: OptLossClosure = None) -> OptFloat: 72 | r"""Performs a single optimization step. 73 | 74 | Arguments: 75 | closure: A closure that reevaluates the model and returns the loss. 76 | """ 77 | loss = None 78 | if closure is not None: 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | weight_decay = group["weight_decay"] 83 | betas = group["betas"] 84 | total_mom = float(len(betas)) 85 | 86 | for p in group["params"]: 87 | if p.grad is None: 88 | continue 89 | d_p = p.grad.data 90 | if weight_decay != 0: 91 | d_p.add_(p.data, alpha=weight_decay) 92 | param_state = self.state[p] 93 | if "momentum_buffer" not in param_state: 94 | param_state["momentum_buffer"] = {} 95 | for beta in betas: 96 | param_state["momentum_buffer"][ 97 | beta 98 | ] = torch.zeros_like( 99 | p.data, memory_format=torch.preserve_format 100 | ) 101 | for beta in betas: 102 | buf = param_state["momentum_buffer"][beta] 103 | buf.mul_(beta).add_(d_p) 104 | p.data.sub_(buf, alpha=group["lr"] / total_mom) 105 | return loss 106 | -------------------------------------------------------------------------------- /torch_optimizer/apollo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | from .types import OptFloat, OptLossClosure, Params 5 | 6 | 7 | class Apollo(Optimizer): 8 | r"""Implements Apollo Optimizer Algorithm. 9 | 10 | It has been proposed in `Apollo: An Adaptive Parameter-wise Diagonal 11 | Quasi-Newton Method for Nonconvex Stochastic Optimization`__. 
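    In this implementation, a diagonal approximation of the Hessian is
    refreshed each step from the change in the running gradient average and
    the previous update direction; the averaged gradient is then divided
    element-wise by that approximation, clamped to magnitude at least 1.
    When warmup > 0 the learning rate is ramped linearly from init_lr to lr
    over the first warmup steps.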
12 | 
13 |     Arguments:
14 |         params: iterable of parameters to optimize or dicts defining
15 |             parameter groups
16 |         lr: learning rate (default: 1e-2)
17 |         beta: coefficient used for computing
18 |             running averages of gradient (default: 0.9)
19 |         eps: term added to the denominator to improve
20 |             numerical stability (default: 1e-4)
21 |         warmup: number of warmup steps (default: 0)
22 |         init_lr: initial learning rate for warmup (default: 0.01)
23 |         weight_decay: weight decay (L2 penalty) (default: 0)
24 | 
25 |     Example:
26 |         >>> import torch_optimizer as optim
27 |         >>> optimizer = optim.Apollo(model.parameters(), lr=0.01)
28 |         >>> optimizer.zero_grad()
29 |         >>> loss_fn(model(input), target).backward()
30 |         >>> optimizer.step()
31 | 
32 |     __ https://arxiv.org/abs/2009.13586
33 | 
34 |     Note:
35 |         Reference code: https://github.com/XuezheMax/apollo
36 |     """
37 | 
38 |     def __init__(
39 |         self,
40 |         params: Params,
41 |         lr: float = 1e-2,
42 |         beta: float = 0.9,
43 |         eps: float = 1e-4,
44 |         warmup: int = 0,
45 |         init_lr: float = 0.01,
46 |         weight_decay: float = 0,
47 |     ):
48 |         if lr <= 0.0:
49 |             raise ValueError("Invalid learning rate: {}".format(lr))
50 |         if eps < 0.0:
51 |             raise ValueError("Invalid epsilon value: {}".format(eps))
52 |         if not 0.0 <= beta < 1.0:
53 |             raise ValueError("Invalid beta parameter: {}".format(beta))
54 |         if not 0.0 <= weight_decay:
55 |             raise ValueError(
56 |                 "Invalid weight_decay value: {}".format(weight_decay)
57 |             )
58 |         if not 0.0 <= warmup:
59 |             raise ValueError("Invalid warmup updates: {}".format(warmup))
60 |         if not 0.0 <= init_lr <= 1.0:
61 |             raise ValueError(
62 |                 "Invalid initial learning rate: {}".format(init_lr)
63 |             )
64 | 
65 |         defaults = dict(
66 |             lr=lr,
67 |             beta=beta,
68 |             eps=eps,
69 |             warmup=warmup,
70 |             init_lr=init_lr,
71 |             base_lr=lr,
72 |             weight_decay=weight_decay,
73 |         )
74 |         super(Apollo, self).__init__(params, defaults)
75 | 
76 |     def step(self, closure: OptLossClosure = None) -> OptFloat:
77 |         r"""Performs a single optimization step.
78 | 
79 |         Arguments:
80 |             closure: A closure that reevaluates the model and returns the loss.
81 |         """
82 |         loss = None
83 |         if closure is not None:
84 |             loss = closure()
85 | 
86 |         for group in self.param_groups:
87 |             for p in group["params"]:
88 |                 if p.grad is None:
89 |                     continue
90 | 
91 |                 state = self.state[p]
92 | 
93 |                 # State initialization
94 |                 if len(state) == 0:
95 |                     state["step"] = 0
96 |                     # Exponential moving average of gradient values
97 |                     state["exp_avg_grad"] = torch.zeros_like(
98 |                         p, memory_format=torch.preserve_format
99 |                     )
100 |                     # Diagonal approximation of the Hessian
101 |                     state["approx_hessian"] = torch.zeros_like(
102 |                         p, memory_format=torch.preserve_format
103 |                     )
104 |                     # Previous update direction
105 |                     state["update"] = torch.zeros_like(
106 |                         p, memory_format=torch.preserve_format
107 |                     )
108 | 
109 |                 # Calculate current lr
110 |                 if state["step"] < group["warmup"]:
111 |                     curr_lr = (group["base_lr"] - group["init_lr"]) * state[
112 |                         "step"
113 |                     ] / group["warmup"] + group["init_lr"]
114 |                 else:
115 |                     curr_lr = group["lr"]
116 | 
117 |                 # Perform optimization step
118 |                 grad = p.grad.data
119 |                 if grad.is_sparse:
120 |                     raise RuntimeError(
121 |                         "Apollo does not support sparse gradients."
122 | ) 123 | 124 | # Perform step weight decay 125 | if group["weight_decay"] != 0: 126 | grad = grad.add(p, alpha=group["weight_decay"]) 127 | 128 | beta = group["beta"] 129 | exp_avg_grad = state["exp_avg_grad"] 130 | B = state["approx_hessian"] 131 | d_p = state["update"] 132 | 133 | state["step"] += 1 134 | bias_correction = 1 - beta ** state["step"] 135 | alpha = (1 - beta) / bias_correction 136 | 137 | # Update the running average grad 138 | delta_grad = grad - exp_avg_grad 139 | exp_avg_grad.add_(delta_grad, alpha=alpha) 140 | 141 | denom = d_p.norm(p=4).add(group["eps"]) 142 | d_p.div_(denom) 143 | v_sq = d_p.mul(d_p) 144 | delta = ( 145 | delta_grad.div_(denom).mul_(d_p).sum().mul(-alpha) 146 | - B.mul(v_sq).sum() 147 | ) 148 | 149 | # Update B 150 | B.addcmul_(v_sq, delta) 151 | 152 | # calc direction of parameter updates 153 | denom = B.abs().clamp_(min=1) 154 | d_p.copy_(exp_avg_grad.div(denom)) 155 | 156 | p.data.add_(d_p, alpha=-curr_lr) 157 | 158 | return loss 159 | -------------------------------------------------------------------------------- /torch_optimizer/diffgrad.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from .types import Betas2, OptFloat, OptLossClosure, Params 7 | 8 | __all__ = ("DiffGrad",) 9 | 10 | 11 | class DiffGrad(Optimizer): 12 | r"""Implements DiffGrad algorithm. 13 | 14 | It has been proposed in `DiffGrad: An Optimization Method for 15 | Convolutional Neural Networks`__. 16 | 17 | Arguments: 18 | params: iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr: learning rate (default: 1e-3) 21 | betas: coefficients used for computing 22 | running averages of gradient and its square (default: (0.9, 0.999)) 23 | eps: term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay: weight decay (L2 penalty) (default: 0) 26 | 27 | Example: 28 | >>> import torch_optimizer as optim 29 | >>> optimizer = optim.DiffGrad(model.parameters(), lr=0.1) 30 | >>> optimizer.zero_grad() 31 | >>> loss_fn(model(input), target).backward() 32 | >>> optimizer.step() 33 | 34 | __ https://arxiv.org/abs/1909.11015 35 | 36 | Note: 37 | Reference code: https://github.com/shivram1987/diffGrad 38 | """ 39 | 40 | def __init__( 41 | self, 42 | params: Params, 43 | lr: float = 1e-3, 44 | betas: Betas2 = (0.9, 0.999), 45 | eps: float = 1e-8, 46 | weight_decay: float = 0.0, 47 | ) -> None: 48 | if lr <= 0.0: 49 | raise ValueError("Invalid learning rate: {}".format(lr)) 50 | if eps < 0.0: 51 | raise ValueError("Invalid epsilon value: {}".format(eps)) 52 | if not 0.0 <= betas[0] < 1.0: 53 | raise ValueError( 54 | "Invalid beta parameter at index 0: {}".format(betas[0]) 55 | ) 56 | if not 0.0 <= betas[1] < 1.0: 57 | raise ValueError( 58 | "Invalid beta parameter at index 1: {}".format(betas[1]) 59 | ) 60 | if weight_decay < 0.0: 61 | raise ValueError( 62 | "Invalid weight_decay value: {}".format(weight_decay) 63 | ) 64 | 65 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 66 | super(DiffGrad, self).__init__(params, defaults) 67 | 68 | def step(self, closure: OptLossClosure = None) -> OptFloat: 69 | r"""Performs a single optimization step. 70 | 71 | Arguments: 72 | closure: A closure that reevaluates the model and returns the loss. 
73 | """ 74 | loss = None 75 | if closure is not None: 76 | loss = closure() 77 | 78 | for group in self.param_groups: 79 | beta1, beta2 = group["betas"] 80 | 81 | for p in group["params"]: 82 | if p.grad is None: 83 | continue 84 | grad = p.grad.data 85 | if grad.is_sparse: 86 | msg = ( 87 | "DiffGrad does not support sparse gradients, " 88 | "please consider SparseAdam instead" 89 | ) 90 | raise RuntimeError(msg) 91 | 92 | state = self.state[p] 93 | 94 | # State initialization 95 | if len(state) == 0: 96 | state["step"] = 0 97 | # Exponential moving average of gradient values 98 | state["exp_avg"] = torch.zeros_like( 99 | p, memory_format=torch.preserve_format 100 | ) 101 | # Exponential moving average of squared gradient values 102 | state["exp_avg_sq"] = torch.zeros_like( 103 | p, memory_format=torch.preserve_format 104 | ) 105 | # Previous gradient 106 | state["previous_grad"] = torch.zeros_like( 107 | p, memory_format=torch.preserve_format 108 | ) 109 | 110 | exp_avg, exp_avg_sq, previous_grad = ( 111 | state["exp_avg"], 112 | state["exp_avg_sq"], 113 | state["previous_grad"], 114 | ) 115 | 116 | state["step"] += 1 117 | 118 | if group["weight_decay"] != 0: 119 | grad.add_(p.data, alpha=group["weight_decay"]) 120 | 121 | # Decay the first and second moment running average coefficient 122 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 123 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 124 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 125 | 126 | bias_correction1 = 1 - beta1 ** state["step"] 127 | bias_correction2 = 1 - beta2 ** state["step"] 128 | 129 | # compute diffgrad coefficient (dfc) 130 | diff = torch.abs(previous_grad - grad) 131 | dfc = torch.div(1.0, (1.0 + torch.exp(-diff))) 132 | state["previous_grad"] = grad.clone() 133 | 134 | # update momentum with dfc 135 | exp_avg1 = exp_avg * dfc 136 | 137 | step_size = ( 138 | group["lr"] 139 | * math.sqrt(bias_correction2) 140 | / bias_correction1 141 | ) 142 | 143 | p.data.addcdiv_(exp_avg1, denom, value=-step_size) 144 | 145 | return loss 146 | -------------------------------------------------------------------------------- /torch_optimizer/lamb.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from .types import Betas2, OptFloat, OptLossClosure, Params 7 | 8 | __all__ = ("Lamb",) 9 | 10 | 11 | class Lamb(Optimizer): 12 | r"""Implements Lamb algorithm. 13 | 14 | It has been proposed in `Large Batch Optimization for Deep Learning: 15 | Training BERT in 76 minutes`__. 16 | 17 | Arguments: 18 | params: iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr: learning rate (default: 1e-3) 21 | betas: coefficients used for computing 22 | running averages of gradient and its square (default: (0.9, 0.999)) 23 | eps: term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay: weight decay (L2 penalty) (default: 0) 26 | clamp_value: clamp weight_norm in (0,clamp_value) (default: 10) 27 | set to a high value to avoid it (e.g 10e3) 28 | adam: always use trust ratio = 1, which turns this 29 | into Adam. Useful for comparison purposes. 
(default: False) 30 | debias: debias adam by (1 - beta**step) (default: False) 31 | 32 | Example: 33 | >>> import torch_optimizer as optim 34 | >>> optimizer = optim.Lamb(model.parameters(), lr=0.1) 35 | >>> optimizer.zero_grad() 36 | >>> loss_fn(model(input), target).backward() 37 | >>> optimizer.step() 38 | 39 | __ https://arxiv.org/abs/1904.00962 40 | 41 | Note: 42 | Reference code: https://github.com/cybertronai/pytorch-lamb 43 | """ 44 | 45 | def __init__( 46 | self, 47 | params: Params, 48 | lr: float = 1e-3, 49 | betas: Betas2 = (0.9, 0.999), 50 | eps: float = 1e-6, 51 | weight_decay: float = 0, 52 | clamp_value: float = 10, 53 | adam: bool = False, 54 | debias: bool = False, 55 | ) -> None: 56 | if lr <= 0.0: 57 | raise ValueError("Invalid learning rate: {}".format(lr)) 58 | if eps < 0.0: 59 | raise ValueError("Invalid epsilon value: {}".format(eps)) 60 | if not 0.0 <= betas[0] < 1.0: 61 | raise ValueError( 62 | "Invalid beta parameter at index 0: {}".format(betas[0]) 63 | ) 64 | if not 0.0 <= betas[1] < 1.0: 65 | raise ValueError( 66 | "Invalid beta parameter at index 1: {}".format(betas[1]) 67 | ) 68 | if weight_decay < 0: 69 | raise ValueError( 70 | "Invalid weight_decay value: {}".format(weight_decay) 71 | ) 72 | if clamp_value < 0.0: 73 | raise ValueError("Invalid clamp value: {}".format(clamp_value)) 74 | 75 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 76 | self.clamp_value = clamp_value 77 | self.adam = adam 78 | self.debias = debias 79 | 80 | super(Lamb, self).__init__(params, defaults) 81 | 82 | def step(self, closure: OptLossClosure = None) -> OptFloat: 83 | r"""Performs a single optimization step. 84 | 85 | Arguments: 86 | closure: A closure that reevaluates the model and returns the loss. 87 | """ 88 | loss = None 89 | if closure is not None: 90 | loss = closure() 91 | 92 | for group in self.param_groups: 93 | for p in group["params"]: 94 | if p.grad is None: 95 | continue 96 | grad = p.grad.data 97 | if grad.is_sparse: 98 | msg = ( 99 | "Lamb does not support sparse gradients, " 100 | "please consider SparseAdam instead" 101 | ) 102 | raise RuntimeError(msg) 103 | 104 | state = self.state[p] 105 | 106 | # State initialization 107 | if len(state) == 0: 108 | state["step"] = 0 109 | # Exponential moving average of gradient values 110 | state["exp_avg"] = torch.zeros_like( 111 | p, memory_format=torch.preserve_format 112 | ) 113 | # Exponential moving average of squared gradient values 114 | state["exp_avg_sq"] = torch.zeros_like( 115 | p, memory_format=torch.preserve_format 116 | ) 117 | 118 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 119 | beta1, beta2 = group["betas"] 120 | 121 | state["step"] += 1 122 | 123 | # Decay the first and second moment running average coefficient 124 | # m_t 125 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 126 | # v_t 127 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 128 | 129 | # Paper v3 does not use debiasing. 130 | if self.debias: 131 | bias_correction = math.sqrt(1 - beta2 ** state["step"]) 132 | bias_correction /= 1 - beta1 ** state["step"] 133 | else: 134 | bias_correction = 1 135 | 136 | # Apply bias to lr to avoid broadcast. 
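                # The block below is LAMB's layer-wise trust ratio: the
                # Adam-style direction is rescaled by ||w|| / ||adam_step||,
                # with the weight norm clamped to `clamp_value` and a fallback
                # ratio of 1 whenever either norm is zero (or when plain
                # `adam` mode is requested).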
137 |                 step_size = group["lr"] * bias_correction
138 | 
139 |                 weight_norm = torch.norm(p.data).clamp(0, self.clamp_value)
140 | 
141 |                 adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
142 |                 if group["weight_decay"] != 0:
143 |                     adam_step.add_(p.data, alpha=group["weight_decay"])
144 | 
145 |                 adam_norm = torch.norm(adam_step)
146 |                 if weight_norm == 0 or adam_norm == 0:
147 |                     trust_ratio = 1
148 |                 else:
149 |                     trust_ratio = weight_norm / adam_norm
150 |                 state["weight_norm"] = weight_norm
151 |                 state["adam_norm"] = adam_norm
152 |                 state["trust_ratio"] = trust_ratio
153 |                 if self.adam:
154 |                     trust_ratio = 1
155 | 
156 |                 p.data.add_(adam_step, alpha=-step_size * trust_ratio)
157 | 
158 |         return loss
159 | 
--------------------------------------------------------------------------------
/torch_optimizer/lars.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 | 
4 | from .types import OptFloat, OptLossClosure, Params, State
5 | 
6 | __all__ = ("LARS",)
7 | 
8 | 
9 | class LARS(Optimizer):
10 |     r"""Extends SGD in PyTorch with LARS scaling from the paper
11 |     `Large batch training of Convolutional Networks`__.
12 | 
13 |     Arguments:
14 |         params (iterable): iterable of parameters to optimize or dicts defining
15 |             parameter groups
16 |         lr: learning rate (default: 1e-2)
17 |         momentum: momentum factor (default: 0)
18 |         dampening: dampening for momentum (default: 0)
19 |         eps: term added to the division denominator to improve
20 |             numerical stability (default: 1e-8)
21 |         weight_decay: weight decay (L2 penalty) (default: 0)
22 |         nesterov: enables Nesterov momentum (default: False)
23 |         trust_coefficient: trust coefficient for computing LR (default: 0.01)
24 | 
25 | 
26 |     Example:
27 |         >>> import torch_optimizer as optim
28 |         >>> optimizer = optim.LARS(model.parameters(), lr=0.001)
29 |         >>> optimizer.zero_grad()
30 |         >>> loss_fn(model(input), target).backward()
31 |         >>> optimizer.step()
32 | 
33 |     .. note::
34 |         The application of momentum in the SGD part is modified according to
35 |         the PyTorch standards. LARS scaling fits into the equation in the
36 |         following fashion.
37 | 
38 |         .. math::
39 |             \begin{aligned}
40 |                 g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
41 |                 v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
42 |                 p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
43 |             \end{aligned}
44 | 
45 |         where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta`
46 |         denote the parameters, gradient, velocity, momentum, and weight decay
47 |         respectively. The :math:`lars_lr` is defined by Eq. 6 in the paper.
48 |         The Nesterov version is analogously modified.
49 | 
50 |     .. warning::
51 |         Parameters with weight decay set to 0 will automatically be excluded
52 |         from layer-wise LR scaling. This is to ensure consistency with papers
53 |         like SimCLR and BYOL.
54 | 55 | 56 | __ https://arxiv.org/pdf/1708.03888.pdf 57 | 58 | Note: 59 | Reference code: https://github.com/PyTorchLightning/lightning-bolts/ 60 | """ 61 | 62 | def __init__( 63 | self, 64 | params: Params, 65 | lr: float = 1e-2, 66 | momentum: float = 0.0, 67 | dampening: float = 0.0, 68 | weight_decay: float = 0.0, 69 | nesterov: bool = False, 70 | trust_coefficient: float = 0.01, 71 | eps: float = 1e-8, 72 | ): 73 | if lr <= 0.0: 74 | raise ValueError("Invalid learning rate: {}".format(lr)) 75 | if eps < 0.0: 76 | raise ValueError("Invalid epsilon value: {}".format(eps)) 77 | if momentum < 0.0: 78 | raise ValueError("Invalid momentum value: {}".format(momentum)) 79 | if dampening < 0.0: 80 | raise ValueError("Invalid dampening value: {}".format(dampening)) 81 | if weight_decay < 0.0: 82 | raise ValueError( 83 | "Invalid weight_decay value: {}".format(weight_decay) 84 | ) 85 | if trust_coefficient < 0.0: 86 | raise ValueError( 87 | "Invalid trust_coefficient value: {}".format(trust_coefficient) 88 | ) 89 | 90 | defaults = dict( 91 | lr=lr, 92 | momentum=momentum, 93 | dampening=dampening, 94 | weight_decay=weight_decay, 95 | nesterov=nesterov, 96 | trust_coefficient=trust_coefficient, 97 | eps=eps, 98 | ) 99 | if nesterov and (momentum <= 0 or dampening != 0): 100 | raise ValueError( 101 | "Nesterov momentum requires a momentum and zero dampening" 102 | ) 103 | 104 | super().__init__(params, defaults) 105 | 106 | def __setstate__(self, state: State) -> None: 107 | super().__setstate__(state) 108 | 109 | for group in self.param_groups: 110 | group.setdefault("nesterov", False) 111 | 112 | @torch.no_grad() 113 | def step(self, closure: OptLossClosure = None) -> OptFloat: 114 | r"""Performs a single optimization step. 115 | 116 | Arguments: 117 | closure: A closure that reevaluates the model and returns the loss. 
118 |         """
119 |         loss = None
120 |         if closure is not None:
121 |             with torch.enable_grad():
122 |                 loss = closure()
123 | 
124 |         # exclude scaling for params with 0 weight decay
125 |         for group in self.param_groups:
126 |             weight_decay = group["weight_decay"]
127 |             momentum = group["momentum"]
128 |             dampening = group["dampening"]
129 |             nesterov = group["nesterov"]
130 | 
131 |             for p in group["params"]:
132 |                 if p.grad is None:
133 |                     continue
134 | 
135 |                 d_p = p.grad
136 |                 p_norm = torch.norm(p.data)
137 |                 g_norm = torch.norm(p.grad.data)
138 | 
139 |                 # lars scaling + weight decay part
140 |                 if weight_decay != 0:
141 |                     if p_norm != 0 and g_norm != 0:
142 |                         lars_lr = p_norm / (
143 |                             g_norm + p_norm * weight_decay + group["eps"]
144 |                         )
145 |                         lars_lr *= group["trust_coefficient"]
146 | 
147 |                         d_p = d_p.add(p, alpha=weight_decay)
148 |                         d_p *= lars_lr
149 | 
150 |                 if momentum != 0:
151 |                     param_state = self.state[p]
152 |                     if "momentum_buffer" not in param_state:
153 |                         buf = param_state["momentum_buffer"] = torch.clone(
154 |                             d_p
155 |                         ).detach()
156 |                     else:
157 |                         buf = param_state["momentum_buffer"]
158 |                         buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
159 |                     if nesterov:
160 |                         d_p = d_p.add(buf, alpha=momentum)
161 |                     else:
162 |                         d_p = buf
163 | 
164 |                 p.add_(d_p, alpha=-group["lr"])
165 | 
166 |         return loss
167 | 
--------------------------------------------------------------------------------
/torch_optimizer/lion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 | 
4 | from .types import Betas2, OptFloat, OptLossClosure, Params
5 | 
6 | __all__ = ("Lion",)
7 | 
8 | 
9 | class Lion(Optimizer):
10 |     r"""Implements Lion algorithm.
11 | 
12 |     Adapted from https://github.com/google/automl/tree/master/lion
13 | 
14 |     The Lion - EvoLved Sign Momentum - algorithm was proposed in
15 |     https://arxiv.org/pdf/2302.06675.pdf.
16 |     Lion aims to be more memory efficient than Adam by only tracking momentum.
17 | 
18 |     Caveats: As detailed in the paper, Lion requires a smaller learning rate
19 |     lr, and larger decoupled weight decay to maintain effective weight decay
20 |     strength. Also, the gain of Lion increases with the batch size.
21 |     Furthermore, Lion was not found to outperform AdamW on some large language
22 |     and text/image datasets.
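Since only a single momentum buffer is kept, the whole update reduces to a fixed-size sign step; a minimal standalone sketch of the rule implemented in `step` further down (toy tensors, hand-picked hyperparameters, not part of the class itself):

import torch

lr, beta1, beta2, wd = 1e-4, 0.9, 0.99, 0.01
p, grad = torch.randn(4), torch.randn(4)
exp_avg = torch.zeros_like(p)

p.mul_(1 - lr * wd)                               # decoupled weight decay
update = exp_avg * beta1 + grad * (1 - beta1)     # interpolate momentum and grad
p.add_(torch.sign(update), alpha=-lr)             # take a fixed-size signed step
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)   # update the single momentum buffer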
23 | 
24 |     Arguments:
25 |         params: iterable of parameters to optimize or dicts defining
26 |             parameter groups
27 |         lr: learning rate (default: 1e-4)
28 |         betas: coefficients used for computing the update and the
29 |             running average of the gradient (default: (0.9, 0.99))
30 |         weight_decay: weight decay (L2 penalty) (default: 0)
31 | 
32 |     Example:
33 |         >>> import torch_optimizer as optim
34 |         >>> optimizer = optim.Lion(model.parameters(), lr=0.001)
35 |         >>> optimizer.zero_grad()
36 |         >>> loss_fn(model(input), target).backward()
37 |         >>> optimizer.step()
38 |     """
39 | 
40 |     def __init__(
41 |         self,
42 |         params: Params,
43 |         lr: float = 1e-4,
44 |         betas: Betas2 = (0.9, 0.99),
45 |         weight_decay: float = 0.0,
46 |     ):
47 |         if lr <= 0.0:
48 |             raise ValueError("Invalid learning rate: {}".format(lr))
49 |         if not 0.0 <= betas[0] < 1.0:
50 |             raise ValueError(
51 |                 "Invalid beta parameter at index 0: {}".format(betas[0])
52 |             )
53 |         if not 0.0 <= betas[1] < 1.0:
54 |             raise ValueError(
55 |                 "Invalid beta parameter at index 1: {}".format(betas[1])
56 |             )
57 |         if weight_decay < 0:
58 |             raise ValueError(
59 |                 "Invalid weight_decay value: {}".format(weight_decay)
60 |             )
61 |         defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
62 |         super().__init__(params, defaults)
63 | 
64 |     @torch.no_grad()
65 |     def step(self, closure: OptLossClosure = None) -> OptFloat:
66 |         r"""Performs a single optimization step.
67 | 
68 |         Arguments:
69 |             closure: A closure that reevaluates the model and returns the loss.
70 |         """
71 |         loss = None
72 |         if closure is not None:
73 |             with torch.enable_grad():
74 |                 loss = closure()
75 | 
76 |         for group in self.param_groups:
77 |             for p in group["params"]:
78 |                 if p.grad is None:
79 |                     continue
80 | 
81 |                 # Perform stepweight decay
82 |                 p.data.mul_(1 - group["lr"] * group["weight_decay"])
83 | 
84 |                 grad = p.grad
85 |                 state = self.state[p]
86 |                 # State initialization
87 |                 if len(state) == 0:
88 |                     # Exponential moving average of gradient values
89 |                     state["exp_avg"] = torch.zeros_like(p)
90 | 
91 |                 exp_avg = state["exp_avg"]
92 |                 beta1, beta2 = group["betas"]
93 | 
94 |                 # Weight update
95 |                 update = exp_avg * beta1 + grad * (1 - beta1)
96 |                 p.add_(torch.sign(update), alpha=-group["lr"])
97 |                 # Decay the momentum running average coefficient
98 |                 exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
99 | 
100 |         return loss
101 | 
--------------------------------------------------------------------------------
/torch_optimizer/lookahead.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Any, Dict
3 | 
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 | 
7 | from .types import OptFloat, OptLossClosure, State
8 | 
9 | __all__ = ("Lookahead",)
10 | 
11 | 
12 | class Lookahead(Optimizer):
13 |     r"""Implements Lookahead optimization algorithm.
14 | 
15 |     It has been proposed in `Lookahead Optimizer: k steps forward, 1
16 |     step back`__
17 | 
18 |     Arguments:
19 |         optimizer: base inner optimizer, such as Yogi, DiffGrad or Adam.
20 |         k: number of lookahead steps (default: 5)
21 |         alpha: linear interpolation factor. 1.0 recovers the inner optimizer.
22 |             (default: 0.5)
23 | 
24 |     Example:
25 |         >>> import torch_optimizer as optim
26 |         >>> yogi = optim.Yogi(model.parameters(), lr=0.1)
27 |         >>> optimizer = optim.Lookahead(yogi, k=5, alpha=0.5)
28 |         >>> optimizer.zero_grad()
29 |         >>> loss_fn(model(input), target).backward()
30 |         >>> optimizer.step()
31 | 
32 |     __ https://arxiv.org/abs/1907.08610
33 | 
34 |     Note:
35 |         Reference code: https://github.com/alphadl/lookahead.pytorch
36 |     """
37 | 
38 |     def __init__(
39 |         self, optimizer: Optimizer, k: int = 5, alpha: float = 0.5
40 |     ) -> None:
41 |         if k < 0.0:
42 |             raise ValueError("Invalid number of lookahead steps: {}".format(k))
43 |         if alpha < 0:
44 |             raise ValueError(
45 |                 "Invalid linear interpolation factor: {}".format(alpha)
46 |             )
47 | 
48 |         self.optimizer = optimizer
49 |         self.k = k
50 |         self.alpha = alpha
51 |         self.param_groups = self.optimizer.param_groups
52 |         self.state = defaultdict(dict)
53 |         self.fast_state = self.optimizer.state
54 |         for group in self.param_groups:
55 |             group["counter"] = 0
56 |         self.defaults = {"k": k, "alpha": alpha, **optimizer.defaults}
57 | 
58 |     def _update(self, group: Dict[str, Any]) -> None:
59 |         for fast in group["params"]:
60 |             param_state = self.state[fast]
61 |             if "slow_param" not in param_state:
62 |                 param_state["slow_param"] = torch.clone(fast.data).detach()
63 | 
64 |             slow = param_state["slow_param"]
65 |             fast.data.mul_(self.alpha).add_(slow, alpha=1.0 - self.alpha)
66 |             slow.data.copy_(fast)
67 | 
68 |     def step(self, closure: OptLossClosure = None) -> OptFloat:
69 |         r"""Performs a single optimization step.
70 | 
71 |         Arguments:
72 |             closure: A closure that reevaluates the model and returns the loss.
73 |         """
74 |         loss = self.optimizer.step(closure=closure)
75 |         for group in self.param_groups:
76 |             if group["counter"] == 0:
77 |                 self._update(group)
78 |             group["counter"] += 1
79 |             group["counter"] %= self.k
80 |         return loss
81 | 
82 |     def state_dict(self) -> State:
83 |         r"""Returns the state of the optimizer as a :class:`dict`.
84 | 
85 |         It contains two entries:
86 |             * state - a dict holding current optimization state. Its content
87 |                 differs between optimizer classes.
88 |             * param_groups - a dict containing all parameter groups
89 |         """
90 |         slow_state_dict = super(Lookahead, self).state_dict()
91 |         fast_state_dict = self.optimizer.state_dict()
92 |         fast_state = fast_state_dict["state"]
93 |         param_groups = fast_state_dict["param_groups"]
94 |         return {
95 |             "fast_state": fast_state,
96 |             "slow_state": slow_state_dict["state"],
97 |             "param_groups": param_groups,
98 |         }
99 | 
100 |     def load_state_dict(self, state_dict: State) -> None:
101 |         r"""Loads the optimizer state.
102 | 
103 |         Arguments:
104 |             state_dict: optimizer state. Should be an object returned
105 |                 from a call to :meth:`state_dict`.
106 | """ 107 | slow_state_dict = { 108 | "state": state_dict["slow_state"], 109 | "param_groups": state_dict["param_groups"], 110 | } 111 | fast_state_dict = { 112 | "state": state_dict["fast_state"], 113 | "param_groups": state_dict["param_groups"], 114 | } 115 | super(Lookahead, self).load_state_dict(slow_state_dict) 116 | self.optimizer.load_state_dict(fast_state_dict) 117 | self.fast_state = self.optimizer.state 118 | 119 | def zero_grad(self, set_to_none: bool = False) -> None: 120 | r"""Clears the gradients of all optimized :class:`torch.Tensor` s.""" 121 | self.optimizer.zero_grad(set_to_none) 122 | 123 | def __repr__(self) -> str: 124 | base_str = self.optimizer.__repr__() 125 | format_string = self.__class__.__name__ + " (" 126 | format_string += "\n" 127 | format_string += "k: {}\n".format(self.k) 128 | format_string += "alpha: {}\n".format(self.alpha) 129 | format_string += base_str 130 | format_string += "\n" 131 | format_string += ")" 132 | return format_string 133 | -------------------------------------------------------------------------------- /torch_optimizer/madgrad.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | import torch.optim 6 | 7 | from .types import Params 8 | 9 | __all__ = ("MADGRAD",) 10 | 11 | 12 | class MADGRAD(torch.optim.Optimizer): 13 | r"""Implements MADGRAD algorithm. 14 | 15 | It has been proposed in `Adaptivity without Compromise: A Momentumized, 16 | Adaptive, Dual Averaged Gradient Method for Stochastic Optimization`__ 17 | 18 | Arguments: 19 | params (iterable): 20 | Iterable of parameters to optimize 21 | or dicts defining parameter groups. 22 | lr (float): 23 | Learning rate (default: 1e-2). 24 | momentum (float): 25 | Momentum value in the range [0,1) (default: 0.9). 26 | weight_decay (float): 27 | Weight decay, i.e. a L2 penalty (default: 0). 28 | eps (float): 29 | Term added to the denominator outside of the root operation 30 | to improve numerical stability. (default: 1e-6). 
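Before the usage example below, a rough sketch of what the dense, zero-momentum branch of MADGRAD's `step` accumulates (toy tensors; `lamb` plays the role of the step-count-dependent weight computed in the real code, and `x0` is only a placeholder for the recovered initial point):

import math
import torch

lr, eps, k = 1e-2, 1e-6, 0
grad = torch.randn(6)
grad_sum_sq, s = torch.zeros(6), torch.zeros(6)

lamb = (lr + eps) * math.pow(k + 1, 0.5)       # step-count dependent weight
grad_sum_sq.addcmul_(grad, grad, value=lamb)   # weighted sum of squared grads
s.add_(grad, alpha=lamb)                       # weighted sum of grads
rms = grad_sum_sq.pow(1 / 3).add_(eps)         # cube root, unlike Adam's sqrt
x0 = torch.zeros(6)                            # initial point (placeholder)
p_new = x0.addcdiv(s, rms, value=-1)           # dual-averaging style step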
31 | 
32 |     Example:
33 |         >>> import torch_optimizer as optim
34 |         >>> optimizer = optim.MADGRAD(model.parameters(), lr=0.1)
35 |         >>> optimizer.zero_grad()
36 |         >>> loss_fn(model(input), target).backward()
37 |         >>> optimizer.step()
38 | 
39 |     __ https://arxiv.org/abs/2101.11075
40 | 
41 |     Note:
42 |         Reference code: https://github.com/facebookresearch/madgrad
43 |     """
44 | 
45 |     def __init__(
46 |         self,
47 |         params: Params,
48 |         lr: float = 1e-2,
49 |         momentum: float = 0.9,
50 |         weight_decay: float = 0.0,
51 |         eps: float = 1e-6,
52 |     ):
53 |         if momentum < 0 or momentum >= 1:
54 |             raise ValueError("Invalid momentum value: {}".format(momentum))
55 |         if lr <= 0.0:
56 |             raise ValueError("Invalid learning rate: {}".format(lr))
57 |         if weight_decay < 0:
58 |             raise ValueError(
59 |                 "Invalid weight_decay value: {}".format(weight_decay)
60 |             )
61 |         if eps < 0.0:
62 |             raise ValueError("Invalid epsilon value: {}".format(eps))
63 | 
64 |         defaults = dict(
65 |             lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, k=0
66 |         )
67 |         super().__init__(params, defaults)
68 | 
69 |         for group in self.param_groups:
70 |             for p in group["params"]:
71 |                 state = self.state[p]
72 | 
73 |                 state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
74 |                 state["s"] = torch.zeros_like(p.data).detach()
75 |                 if momentum != 0:
76 |                     state["x0"] = torch.clone(p.data).detach()
77 | 
78 |     def step(
79 |         self, closure: Optional[Callable[[], float]] = None
80 |     ) -> Optional[float]:
81 |         r"""Performs a single optimization step.
82 | 
83 |         Arguments:
84 |             closure: A closure that reevaluates the model and returns the loss.
85 |         """
86 |         loss = None
87 |         if closure is not None:
88 |             loss = closure()
89 | 
90 |         for group in self.param_groups:
91 |             eps = group["eps"]
92 |             k = group["k"]
93 |             lr = group["lr"] + eps
94 |             decay = group["weight_decay"]
95 |             momentum = group["momentum"]
96 | 
97 |             ck = 1 - momentum
98 |             lamb = lr * math.pow(k + 1, 0.5)
99 | 
100 |             for p in group["params"]:
101 |                 if p.grad is None:
102 |                     continue
103 |                 grad = p.grad.data
104 |                 state = self.state[p]
105 | 
106 |                 if momentum != 0.0 and grad.is_sparse:
107 |                     raise RuntimeError(
108 |                         "momentum != 0 is not compatible with "
109 |                         "sparse gradients"
110 |                     )
111 | 
112 |                 grad_sum_sq = state["grad_sum_sq"]
113 |                 s = state["s"]
114 | 
115 |                 # Apply weight decay
116 |                 if decay != 0:
117 |                     if grad.is_sparse:
118 |                         raise RuntimeError(
119 |                             "weight_decay option is not "
120 |                             "compatible with sparse gradients"
121 |                         )
122 | 
123 |                     grad.add_(p.data, alpha=decay)
124 | 
125 |                 if grad.is_sparse:
126 |                     grad = grad.coalesce()
127 |                     grad_val = grad._values()
128 | 
129 |                     p_masked = p.sparse_mask(grad)
130 |                     grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
131 |                     s_masked = s.sparse_mask(grad)
132 | 
133 |                     # Compute x_0 from other known quantities
134 |                     rms_masked_vals = (
135 |                         grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
136 |                     )
137 |                     x0_masked_vals = p_masked._values().addcdiv(
138 |                         s_masked._values(), rms_masked_vals, value=1
139 |                     )
140 | 
141 |                     # Dense + sparse op
142 |                     grad_sq = grad * grad
143 |                     grad_sum_sq.add_(grad_sq, alpha=lamb)
144 |                     grad_sum_sq_masked.add_(grad_sq, alpha=lamb)
145 | 
146 |                     rms_masked_vals = (
147 |                         grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
148 |                     )
149 | 
150 |                     s.add_(grad, alpha=lamb)
151 |                     s_masked._values().add_(grad_val, alpha=lamb)
152 | 
153 |                     # update masked copy of p
154 |                     p_kp1_masked_vals = x0_masked_vals.addcdiv(
155 |                         s_masked._values(), rms_masked_vals, value=-1
156 |                     )
157 |                     # Copy updated masked p to dense p using an add operation
158
| p_masked._values().add_(p_kp1_masked_vals, alpha=-1) 159 | p.data.add_(p_masked, alpha=-1) 160 | else: 161 | if momentum == 0: 162 | # Compute x_0 from other known quantities 163 | rms = grad_sum_sq.pow(1 / 3).add_(eps) 164 | x0 = p.data.addcdiv(s, rms, value=1) 165 | else: 166 | x0 = state["x0"] 167 | 168 | # Accumulate second moments 169 | grad_sum_sq.addcmul_(grad, grad, value=lamb) 170 | rms = grad_sum_sq.pow(1 / 3).add_(eps) 171 | 172 | # Update s 173 | s.data.add_(grad, alpha=lamb) 174 | 175 | # Step 176 | if momentum == 0: 177 | p.data.copy_(x0.addcdiv(s, rms, value=-1)) 178 | else: 179 | z = x0.addcdiv(s, rms, value=-1) 180 | 181 | # p is a moving average of z 182 | p.data.mul_(1 - ck).add_(z, alpha=ck) 183 | 184 | group["k"] = group["k"] + 1 185 | return loss 186 | -------------------------------------------------------------------------------- /torch_optimizer/novograd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | from .types import Betas2, OptFloat, OptLossClosure, Params 5 | 6 | __all__ = ("NovoGrad",) 7 | 8 | 9 | class NovoGrad(Optimizer): 10 | r"""Implements Novograd optimization algorithm. 11 | 12 | It has been proposed in `Stochastic Gradient Methods with Layer-wise 13 | Adaptive Moments for Training of Deep Networks`__. 14 | 15 | Arguments: 16 | params: iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr: learning rate (default: 1e-3) 19 | betas: coefficients used for computing 20 | running averages of gradient and its square (default: (0.95, 0)) 21 | eps: term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay: weight decay (L2 penalty) (default: 0) 24 | grad_averaging: gradient averaging (default: False) 25 | amsgrad: whether to use the AMSGrad variant of this 26 | algorithm from the paper `On the Convergence of Adam and Beyond` 27 | (default: False) 28 | 29 | Example: 30 | >>> import torch_optimizer as optim 31 | >>> optimizer = optim.Yogi(model.parameters(), lr=0.1) 32 | >>> optimizer.zero_grad() 33 | >>> loss_fn(model(input), target).backward() 34 | >>> scheduler = StepLR(optimizer, step_size=1, gamma=0.7) 35 | >>> optimizer.step() 36 | >>> scheduler.step() 37 | 38 | __ https://arxiv.org/abs/1905.11286 39 | 40 | Note: 41 | Reference code: https://github.com/NVIDIA/DeepLearningExamples 42 | """ 43 | 44 | def __init__( 45 | self, 46 | params: Params, 47 | lr: float = 1e-3, 48 | betas: Betas2 = (0.95, 0), 49 | eps: float = 1e-8, 50 | weight_decay: float = 0, 51 | grad_averaging: bool = False, 52 | amsgrad: bool = False, 53 | ): 54 | if lr <= 0.0: 55 | raise ValueError("Invalid learning rate: {}".format(lr)) 56 | if eps < 0.0: 57 | raise ValueError("Invalid epsilon value: {}".format(eps)) 58 | if not 0.0 <= betas[0] < 1.0: 59 | raise ValueError( 60 | "Invalid beta parameter at index 0: {}".format(betas[0]) 61 | ) 62 | if not 0.0 <= betas[1] < 1.0: 63 | raise ValueError( 64 | "Invalid beta parameter at index 1: {}".format(betas[1]) 65 | ) 66 | if weight_decay < 0: 67 | raise ValueError( 68 | "Invalid weight_decay value: {}".format(weight_decay) 69 | ) 70 | defaults = dict( 71 | lr=lr, 72 | betas=betas, 73 | eps=eps, 74 | weight_decay=weight_decay, 75 | grad_averaging=grad_averaging, 76 | amsgrad=amsgrad, 77 | ) 78 | 79 | super(NovoGrad, self).__init__(params, defaults) 80 | 81 | def __setstate__(self, state: dict) -> None: 82 | super(NovoGrad, self).__setstate__(state) 83 | for group in 
self.param_groups: 84 | group.setdefault("amsgrad", False) 85 | 86 | def step(self, closure: OptLossClosure = None) -> OptFloat: 87 | r"""Performs a single optimization step. 88 | 89 | Arguments: 90 | closure: A closure that reevaluates the model and returns the loss. 91 | """ 92 | loss = None 93 | if closure is not None: 94 | loss = closure() 95 | 96 | for group in self.param_groups: 97 | for p in group["params"]: 98 | if p.grad is None: 99 | continue 100 | grad = p.grad.data 101 | if grad.is_sparse: 102 | msg = ( 103 | "NovoGrad does not support sparse gradients, " 104 | "please consider SparseAdam instead" 105 | ) 106 | raise RuntimeError(msg) 107 | amsgrad = group["amsgrad"] 108 | 109 | state = self.state[p] 110 | 111 | # State initialization 112 | if len(state) == 0: 113 | state["step"] = 0 114 | # Exponential moving average of gradient values 115 | state["exp_avg"] = torch.zeros_like( 116 | p.data, memory_format=torch.preserve_format 117 | ) 118 | # Exponential moving average of squared gradient values 119 | state["exp_avg_sq"] = torch.zeros([]).to( 120 | state["exp_avg"].device 121 | ) 122 | if amsgrad: 123 | # Maintains max of all exp. moving avg. of sq. 124 | # grad. values 125 | state["max_exp_avg_sq"] = torch.zeros([]).to( 126 | state["exp_avg"].device 127 | ) 128 | 129 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 130 | if amsgrad: 131 | max_exp_avg_sq = state["max_exp_avg_sq"] 132 | beta1, beta2 = group["betas"] 133 | 134 | state["step"] += 1 135 | 136 | norm = torch.sum(torch.pow(grad, 2)) 137 | 138 | if exp_avg_sq == 0: 139 | exp_avg_sq.copy_(norm) 140 | else: 141 | exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2) 142 | 143 | if amsgrad: 144 | # Maintains the maximum of all 2nd moment running avg. 145 | # till now 146 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 147 | # Use the max. for normalizing running avg. of gradient 148 | denom = max_exp_avg_sq.sqrt().add_(group["eps"]) 149 | else: 150 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 151 | 152 | grad.div_(denom) 153 | if group["weight_decay"] != 0: 154 | grad.add_(p.data, alpha=group["weight_decay"]) 155 | if group["grad_averaging"]: 156 | grad.mul_(1 - beta1) 157 | exp_avg.mul_(beta1).add_(grad) 158 | 159 | p.data.add_(exp_avg, alpha=-group["lr"]) 160 | 161 | return loss 162 | -------------------------------------------------------------------------------- /torch_optimizer/pid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | from .types import OptFloat, OptLossClosure, Params 5 | 6 | 7 | class PID(Optimizer): 8 | r"""Implements PID optimization algorithm. 9 | 10 | It has been proposed in `A PID Controller Approach for Stochastic 11 | Optimization of Deep Networks`__. 
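Looking back at the NovoGrad step above: unlike Adam, the second moment is a single scalar per parameter tensor, so each layer's gradient is normalized by (an EMA of) its own squared norm before momentum is applied. A tiny standalone sketch of that normalization (toy values, illustrative beta2):

import torch

beta2, eps = 0.25, 1e-8
grad = torch.randn(10)
exp_avg_sq = torch.zeros([])        # scalar state, one per parameter tensor

norm = torch.sum(grad ** 2)
exp_avg_sq = norm if exp_avg_sq == 0 else exp_avg_sq * beta2 + norm * (1 - beta2)
normalized = grad / (exp_avg_sq.sqrt() + eps)   # layer-wise normalized gradient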
12 | 13 | Arguments: 14 | params: iterable of parameters to optimize or dicts defining 15 | parameter groups 16 | lr: learning rate (default: 1e-3) 17 | momentum: momentum factor (default: 0.0) 18 | weight_decay: weight decay (L2 penalty) (default: 0.0) 19 | dampening: dampening for momentum (default: 0.0) 20 | derivative: D part of the PID (default: 10.0) 21 | integral: I part of the PID (default: 5.0) 22 | 23 | Example: 24 | >>> import torch_optimizer as optim 25 | >>> optimizer = optim.PID(model.parameters(), lr=0.001, momentum=0.1) 26 | >>> optimizer.zero_grad() 27 | >>> loss_fn(model(input), target).backward() 28 | >>> optimizer.step() 29 | 30 | __ http://www4.comp.polyu.edu.hk/~cslzhang/paper/CVPR18_PID.pdf 31 | 32 | Note: 33 | Reference code: https://github.com/tensorboy/PIDOptimizer 34 | """ 35 | 36 | def __init__( 37 | self, 38 | params: Params, 39 | lr: float = 1e-3, 40 | momentum: float = 0.0, 41 | dampening: float = 0, 42 | weight_decay: float = 0.0, 43 | integral: float = 5.0, 44 | derivative: float = 10.0, 45 | ) -> None: 46 | defaults = dict( 47 | lr=lr, 48 | momentum=momentum, 49 | dampening=dampening, 50 | weight_decay=weight_decay, 51 | integral=integral, 52 | derivative=derivative, 53 | ) 54 | if lr <= 0.0: 55 | raise ValueError("Invalid learning rate: {}".format(lr)) 56 | if momentum < 0.0: 57 | raise ValueError("Invalid momentum value: {}".format(momentum)) 58 | if weight_decay < 0.0: 59 | raise ValueError( 60 | "Invalid weight_decay value: {}".format(weight_decay) 61 | ) 62 | if integral < 0.0: 63 | raise ValueError("Invalid PID integral value: {}".format(integral)) 64 | if derivative < 0.0: 65 | raise ValueError( 66 | "Invalid PID derivative value: {}".format(derivative) 67 | ) 68 | 69 | super(PID, self).__init__(params, defaults) 70 | 71 | def step(self, closure: OptLossClosure = None) -> OptFloat: 72 | r"""Performs a single optimization step. 73 | 74 | Arguments: 75 | closure: A closure that reevaluates the model and returns the loss. 
76 |         """
77 |         loss = None
78 |         if closure is not None:
79 |             loss = closure()
80 | 
81 |         for group in self.param_groups:
82 |             weight_decay = group["weight_decay"]
83 |             momentum = group["momentum"]
84 |             dampening = group["dampening"]
85 |             integral = group["integral"]
86 |             derivative = group["derivative"]
87 |             for p in group["params"]:
88 |                 if p.grad is None:
89 |                     continue
90 |                 d_p = p.grad.data
91 |                 if weight_decay != 0:
92 |                     d_p.add_(p.data, alpha=weight_decay)
93 |                 if momentum != 0:
94 |                     param_state = self.state[p]
95 |                     if "i_buffer" not in param_state:
96 |                         i_buf = param_state["i_buffer"] = torch.zeros_like(
97 |                             p, memory_format=torch.preserve_format
98 |                         )
99 |                         i_buf.mul_(momentum).add_(d_p)
100 |                     else:
101 |                         i_buf = param_state["i_buffer"]
102 |                         i_buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
103 |                     if "grad_buffer" not in param_state:
104 |                         g_buf = param_state["grad_buffer"] = torch.zeros_like(
105 |                             p, memory_format=torch.preserve_format
106 |                         )
107 |                         g_buf = d_p
108 | 
109 |                         d_buf = param_state["d_buffer"] = torch.zeros_like(
110 |                             p, memory_format=torch.preserve_format
111 |                         )
112 |                         d_buf.mul_(momentum).add_(d_p - g_buf)
113 |                     else:
114 |                         d_buf = param_state["d_buffer"]
115 |                         g_buf = param_state["grad_buffer"]
116 |                         d_buf.mul_(momentum).add_(
117 |                             d_p - g_buf, alpha=1 - momentum
118 |                         )
119 |                     self.state[p]["grad_buffer"] = d_p.clone()
120 | 
121 |                     d_p = d_p.add_(i_buf, alpha=integral).add_(
122 |                         d_buf, alpha=derivative
123 |                     )
124 |                 p.data.add_(d_p, alpha=-group["lr"])
125 |         return loss
126 | 
--------------------------------------------------------------------------------
/torch_optimizer/py.typed:
--------------------------------------------------------------------------------
1 | # placeholder
2 | 
--------------------------------------------------------------------------------
/torch_optimizer/qhadam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 | 
4 | from .types import Betas2, Nus2, OptFloat, OptLossClosure, Params
5 | 
6 | __all__ = ("QHAdam",)
7 | 
8 | 
9 | class QHAdam(Optimizer):
10 |     r"""Implements the QHAdam optimization algorithm.
11 | 
12 |     It has been proposed in `Quasi-hyperbolic momentum and Adam for deep
        learning`__.
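The defining piece of QHAdam is the nu-weighted mix of the raw gradient with the debiased moment estimates; with nus=(1.0, 1.0) it collapses back to Adam. A minimal sketch of the mixing used in `step` further down (toy tensors, illustrative values):

import torch

nu1, nu2, eps = 0.7, 1.0, 1e-8
d_p = torch.randn(5)            # current gradient
exp_avg = torch.randn(5)        # (debiased) first-moment estimate
exp_avg_sq = torch.rand(5)      # (debiased) second-moment estimate

numer = nu1 * exp_avg + (1 - nu1) * d_p
denom = (nu2 * exp_avg_sq + (1 - nu2) * d_p ** 2).sqrt() + eps
update = numer / denom          # applied as p -= lr * update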
13 | 14 | Arguments: 15 | params: iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr: learning rate (default: 1e-3) 18 | betas: coefficients used for computing 19 | running averages of gradient and its square (default: (0.9, 0.999)) 20 | nus: immediate discount factors used to estimate the gradient and its 21 | square (default: (1.0, 1.0)) 22 | eps: term added to the denominator to improve 23 | numerical stability (default: 1e-8) 24 | weight_decay: weight decay (L2 penalty) (default: 0) 25 | decouple_weight_decay: whether to decouple the weight 26 | decay from the gradient-based optimization step (default: False) 27 | 28 | Example: 29 | >>> import torch_optimizer as optim 30 | >>> optimizer = optim.QHAdam(model.parameters(), lr=0.1) 31 | >>> optimizer.zero_grad() 32 | >>> loss_fn(model(input), target).backward() 33 | >>> optimizer.step() 34 | 35 | __ https://arxiv.org/abs/1810.06801 36 | 37 | Note: 38 | Reference code: https://github.com/facebookresearch/qhoptim 39 | """ 40 | 41 | def __init__( 42 | self, 43 | params: Params, 44 | lr: float = 1e-3, 45 | betas: Betas2 = (0.9, 0.999), 46 | nus: Nus2 = (1.0, 1.0), 47 | weight_decay: float = 0.0, 48 | decouple_weight_decay: bool = False, 49 | eps: float = 1e-8, 50 | ): 51 | if lr <= 0.0: 52 | raise ValueError("Invalid learning rate: {}".format(lr)) 53 | if eps < 0.0: 54 | raise ValueError("Invalid epsilon value: {}".format(eps)) 55 | if not 0.0 <= betas[0] < 1.0: 56 | raise ValueError( 57 | "Invalid beta parameter at index 0: {}".format(betas[0]) 58 | ) 59 | if not 0.0 <= betas[1] < 1.0: 60 | raise ValueError( 61 | "Invalid beta parameter at index 1: {}".format(betas[1]) 62 | ) 63 | if weight_decay < 0: 64 | raise ValueError( 65 | "Invalid weight_decay value: {}".format(weight_decay) 66 | ) 67 | 68 | defaults = { 69 | "lr": lr, 70 | "betas": betas, 71 | "nus": nus, 72 | "weight_decay": weight_decay, 73 | "decouple_weight_decay": decouple_weight_decay, 74 | "eps": eps, 75 | } 76 | super(QHAdam, self).__init__(params, defaults) 77 | 78 | def step(self, closure: OptLossClosure = None) -> OptFloat: 79 | """Performs a single optimization step. 80 | 81 | Arguments: 82 | closure: A closure that reevaluates the model and returns the loss. 
83 | """ 84 | loss = None 85 | if closure is not None: 86 | loss = closure() 87 | 88 | for group in self.param_groups: 89 | lr = group["lr"] 90 | beta1, beta2 = group["betas"] 91 | nu1, nu2 = group["nus"] 92 | weight_decay = group["weight_decay"] 93 | decouple_weight_decay = group["decouple_weight_decay"] 94 | eps = group["eps"] 95 | 96 | for p in group["params"]: 97 | if p.grad is None: 98 | continue 99 | 100 | d_p = p.grad.data 101 | if d_p.is_sparse: 102 | raise RuntimeError( 103 | "QHAdam does not support sparse gradients, " 104 | "please consider SparseAdam instead" 105 | ) 106 | 107 | state = self.state[p] 108 | 109 | if weight_decay != 0: 110 | if decouple_weight_decay: 111 | p.data.mul_(1 - lr * weight_decay) 112 | else: 113 | d_p.add_(p.data, alpha=weight_decay) 114 | 115 | d_p_sq = d_p.mul(d_p) 116 | 117 | if len(state) == 0: 118 | state["beta1_weight"] = 0.0 119 | state["beta2_weight"] = 0.0 120 | state["exp_avg"] = torch.zeros_like( 121 | p.data, memory_format=torch.preserve_format 122 | ) 123 | state["exp_avg_sq"] = torch.zeros_like( 124 | p.data, memory_format=torch.preserve_format 125 | ) 126 | 127 | state["beta1_weight"] = 1.0 + beta1 * state["beta1_weight"] 128 | state["beta2_weight"] = 1.0 + beta2 * state["beta2_weight"] 129 | 130 | beta1_weight = state["beta1_weight"] 131 | beta2_weight = state["beta2_weight"] 132 | exp_avg = state["exp_avg"] 133 | exp_avg_sq = state["exp_avg_sq"] 134 | 135 | beta1_adj = 1.0 - (1.0 / beta1_weight) 136 | beta2_adj = 1.0 - (1.0 / beta2_weight) 137 | exp_avg.mul_(beta1_adj).add_(d_p, alpha=1.0 - beta1_adj) 138 | exp_avg_sq.mul_(beta2_adj).add_(d_p_sq, alpha=1.0 - beta2_adj) 139 | 140 | avg_grad = exp_avg.mul(nu1) 141 | if nu1 != 1.0: 142 | avg_grad.add_(d_p, alpha=1.0 - nu1) 143 | 144 | avg_grad_rms = exp_avg_sq.mul(nu2) 145 | if nu2 != 1.0: 146 | avg_grad_rms.add_(d_p_sq, alpha=1.0 - nu2) 147 | avg_grad_rms.sqrt_() 148 | if eps != 0.0: 149 | avg_grad_rms.add_(eps) 150 | 151 | p.data.addcdiv_(avg_grad, avg_grad_rms, value=-lr) 152 | 153 | return loss 154 | -------------------------------------------------------------------------------- /torch_optimizer/qhm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | from .types import OptFloat, OptLossClosure, Params 5 | 6 | __all__ = ("QHM",) 7 | 8 | 9 | class QHM(Optimizer): 10 | GRAD = "grad" 11 | DIRECT = "direct" 12 | 13 | r"""Implements quasi-hyperbolic momentum (QHM) optimization algorithm. 14 | 15 | It has been proposed in `Quasi-hyperbolic momentum and Adam for deep 16 | learning`__. 
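QHM's update is a convex combination of the plain gradient step and the momentum step, weighted by nu; nu = 1 recovers momentum SGD (in the EMA convention used here) and nu = 0 plain SGD. A standalone sketch mirroring `step` below (toy tensors, illustrative hyperparameters):

import torch

lr, momentum, nu = 0.1, 0.999, 0.7
p, d_p = torch.randn(3), torch.randn(3)
buf = torch.zeros_like(p)

buf.mul_(momentum).add_(d_p, alpha=1.0 - momentum)   # EMA of gradients
p.add_(buf, alpha=-lr * nu)                          # momentum contribution
p.add_(d_p, alpha=-lr * (1.0 - nu))                  # plain-gradient contribution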
17 | 18 | Arguments: 19 | params: iterable of parameters to optimize or dicts defining 20 | parameter groups 21 | lr: learning rate (default: 1e-3) 22 | momentum: momentum factor (:math:`\beta` from the paper) 23 | nu: immediate discount factor (:math:`\nu` from the paper) 24 | weight_decay: weight decay (L2 regularization coefficient, times two) 25 | (default: 0.0) 26 | weight_decay_type: method of applying the weight decay: 27 | ``"grad"`` for accumulation in the gradient 28 | (same as :class:`torch.optim.SGD`) or 29 | ``"direct"`` for direct application to the parameters 30 | (default: ``"grad"``) 31 | 32 | Example: 33 | >>> import torch_optimizer as optim 34 | >>> optimizer = optim.QHM(model.parameters(), lr=0.1, momentum=0.9) 35 | >>> optimizer.zero_grad() 36 | >>> loss_fn(model(input), target).backward() 37 | >>> optimizer.step() 38 | 39 | 40 | __ https://arxiv.org/abs/1810.06801 41 | 42 | Note: 43 | Reference code: https://github.com/facebookresearch/qhoptim 44 | """ 45 | 46 | def __init__( 47 | self, 48 | params: Params, 49 | lr: float = 1e-3, 50 | momentum: float = 0.0, 51 | nu: float = 0.7, 52 | weight_decay: float = 0.0, 53 | weight_decay_type: str = "grad", 54 | ) -> None: 55 | if lr <= 0.0: 56 | raise ValueError("Invalid learning rate: {}".format(lr)) 57 | if momentum < 0.0: 58 | raise ValueError("Invalid momentum value: {}".format(momentum)) 59 | if weight_decay < 0.0: 60 | raise ValueError( 61 | "Invalid weight_decay value: {}".format(weight_decay) 62 | ) 63 | if weight_decay_type not in (self.GRAD, self.DIRECT): 64 | _type = weight_decay_type 65 | msg = "Invalid weight_decay_type value: {}".format(_type) 66 | raise ValueError(msg) 67 | 68 | defaults = { 69 | "lr": lr, 70 | "momentum": momentum, 71 | "nu": nu, 72 | "weight_decay": weight_decay, 73 | "weight_decay_type": weight_decay_type, 74 | } 75 | super(QHM, self).__init__(params, defaults) 76 | 77 | def step(self, closure: OptLossClosure = None) -> OptFloat: 78 | """Performs a single optimization step. 79 | 80 | Arguments: 81 | closure: A closure that reevaluates the model and returns the loss. 
82 | """ 83 | loss = None 84 | if closure is not None: 85 | loss = closure() 86 | 87 | for group in self.param_groups: 88 | lr, nu, momentum = group["lr"], group["nu"], group["momentum"] 89 | weight_decay, weight_decay_type = ( 90 | group["weight_decay"], 91 | group["weight_decay_type"], 92 | ) 93 | 94 | for p in group["params"]: 95 | if p.grad is None: 96 | continue 97 | d_p = p.grad.data 98 | param_state = self.state[p] 99 | 100 | if weight_decay != 0: 101 | if weight_decay_type == self.GRAD: 102 | d_p.add_(p.data, alpha=weight_decay) 103 | else: 104 | p.data.mul_(1.0 - lr * weight_decay) 105 | 106 | if len(param_state) == 0: 107 | param_state["momentum_buffer"] = torch.zeros_like( 108 | p.data, memory_format=torch.preserve_format 109 | ) 110 | 111 | momentum_buffer = param_state["momentum_buffer"] 112 | momentum_buffer.mul_(momentum).add_(d_p, alpha=1.0 - momentum) 113 | 114 | p.data.add_(momentum_buffer, alpha=-lr * nu) 115 | p.data.add_(d_p, alpha=-lr * (1.0 - nu)) 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /torch_optimizer/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | 7 | from .types import Betas2, OptFloat, OptLossClosure, Params 8 | 9 | __all__ = ("RAdam",) 10 | 11 | 12 | class RAdam(Optimizer): 13 | r"""Implements RAdam optimization algorithm. 14 | 15 | Note: 16 | Deprecated, please use version provided by PyTorch_. 17 | 18 | It has been proposed in `On the Variance of the Adaptive Learning 19 | Rate and Beyond`__. 20 | 21 | Arguments: 22 | params: iterable of parameters to optimize or dicts defining 23 | parameter groups 24 | lr: learning rate (default: 1e-3) 25 | betas: coefficients used for computing 26 | running averages of gradient and its square (default: (0.9, 0.999)) 27 | eps: term added to the denominator to improve 28 | numerical stability (default: 1e-8) 29 | weight_decay: weight decay (L2 penalty) (default: 0) 30 | 31 | Example: 32 | >>> import torch_optimizer as optim 33 | >>> optimizer = optim.RAdam(model.parameters(), lr=0.1) 34 | >>> optimizer.zero_grad() 35 | >>> loss_fn(model(input), target).backward() 36 | >>> optimizer.step() 37 | 38 | __ https://arxiv.org/abs/1908.03265 39 | 40 | Note: 41 | Reference code: https://github.com/LiyuanLucasLiu/RAdam 42 | """ 43 | 44 | def __init__( 45 | self, 46 | params: Params, 47 | lr: float = 1e-3, 48 | betas: Betas2 = (0.9, 0.999), 49 | eps: float = 1e-8, 50 | weight_decay: float = 0, 51 | ) -> None: 52 | warnings.warn( 53 | "RAdam optimizer is deprecated, since it is included " 54 | "in pytorch natively.", 55 | DeprecationWarning, 56 | stacklevel=2, 57 | ) 58 | if lr <= 0.0: 59 | raise ValueError("Invalid learning rate: {}".format(lr)) 60 | if eps < 0.0: 61 | raise ValueError("Invalid epsilon value: {}".format(eps)) 62 | if not 0.0 <= betas[0] < 1.0: 63 | raise ValueError( 64 | "Invalid beta parameter at index 0: {}".format(betas[0]) 65 | ) 66 | if not 0.0 <= betas[1] < 1.0: 67 | raise ValueError( 68 | "Invalid beta parameter at index 1: {}".format(betas[1]) 69 | ) 70 | if weight_decay < 0: 71 | raise ValueError( 72 | "Invalid weight_decay value: {}".format(weight_decay) 73 | ) 74 | 75 | if ( 76 | isinstance(params, (list, tuple)) 77 | and len(params) > 0 78 | and isinstance(params[0], dict) 79 | ): 80 | for param in params: 81 | if "betas" in param and ( 82 | param["betas"][0] != betas[0] 83 | or 
param["betas"][1] != betas[1] 84 | ): 85 | param["buffer"] = [[None, None, None] for _ in range(10)] 86 | 87 | defaults = dict( 88 | lr=lr, 89 | betas=betas, 90 | eps=eps, 91 | weight_decay=weight_decay, 92 | buffer=[[None, None, None] for _ in range(10)], 93 | ) 94 | super(RAdam, self).__init__(params, defaults) 95 | 96 | def __setstate__(self, state): 97 | super(RAdam, self).__setstate__(state) 98 | 99 | def step(self, closure: OptLossClosure = None) -> OptFloat: 100 | r"""Performs a single optimization step. 101 | 102 | Arguments: 103 | closure: A closure that reevaluates the model and returns the loss. 104 | """ 105 | 106 | loss = None 107 | if closure is not None: 108 | loss = closure() 109 | 110 | for group in self.param_groups: 111 | lr = group["lr"] 112 | weight_decay = group["weight_decay"] 113 | beta1, beta2 = group["betas"] 114 | eps = group["eps"] 115 | 116 | for p in group["params"]: 117 | if p.grad is None: 118 | continue 119 | grad = p.grad.data.float() 120 | if grad.is_sparse: 121 | msg = ( 122 | "RAdam does not support sparse gradients, " 123 | "please consider SparseAdam instead" 124 | ) 125 | raise RuntimeError(msg) 126 | 127 | p_data_fp32 = p.data.float() 128 | 129 | state = self.state[p] 130 | 131 | if len(state) == 0: 132 | state["step"] = 0 133 | state["exp_avg"] = torch.zeros_like( 134 | p_data_fp32, memory_format=torch.preserve_format 135 | ) 136 | state["exp_avg_sq"] = torch.zeros_like( 137 | p_data_fp32, memory_format=torch.preserve_format 138 | ) 139 | else: 140 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) 141 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as( 142 | p_data_fp32 143 | ) 144 | 145 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 146 | 147 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 148 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 149 | 150 | state["step"] += 1 151 | buffered = group["buffer"][int(state["step"] % 10)] 152 | if state["step"] == buffered[0]: 153 | N_sma, step_size = buffered[1], buffered[2] 154 | else: 155 | buffered[0] = state["step"] 156 | beta2_t = beta2 ** state["step"] 157 | N_sma_max = 2 / (1 - beta2) - 1 158 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / ( 159 | 1 - beta2_t 160 | ) 161 | buffered[1] = N_sma 162 | 163 | # more conservative since it's an approximated value 164 | if N_sma >= 5: 165 | step_size = ( 166 | lr 167 | * math.sqrt( 168 | (1 - beta2_t) 169 | * (N_sma - 4) 170 | / (N_sma_max - 4) 171 | * (N_sma - 2) 172 | / N_sma 173 | * N_sma_max 174 | / (N_sma_max - 2) 175 | ) 176 | / (1 - beta1 ** state["step"]) 177 | ) 178 | else: 179 | step_size = lr / (1 - beta1 ** state["step"]) 180 | buffered[2] = step_size 181 | 182 | if weight_decay != 0: 183 | p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr) 184 | 185 | # more conservative since it's an approximated value 186 | if N_sma >= 5: 187 | denom = exp_avg_sq.sqrt().add_(eps) 188 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) 189 | else: 190 | p_data_fp32.add_(exp_avg, alpha=-step_size) 191 | 192 | p.data.copy_(p_data_fp32) 193 | 194 | return loss 195 | -------------------------------------------------------------------------------- /torch_optimizer/sgdp.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from .types import OptFloat, OptLossClosure, Params 7 | 8 | __all__ = ("SGDP",) 9 | 10 | 11 | class SGDP(Optimizer): 12 | r"""Implements SGDP algorithm. 
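Stepping back to the RAdam update just above: the buffered `N_sma` term estimates the effective length of the second-moment average and switches the adaptive denominator off while that estimate is still unreliable. A small sketch of that bookkeeping (hand-picked beta2 and step count):

import math

beta2, step = 0.999, 4
beta2_t = beta2 ** step
n_sma_max = 2 / (1 - beta2) - 1
n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
use_adaptive = n_sma >= 5   # False this early, so a momentum-only step is taken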
13 | 14 | It has been proposed in `Slowing Down the Weight Norm Increase in 15 | Momentum-based Optimizers`__ 16 | 17 | Arguments: 18 | params: iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr: learning rate (default: 1e-3) 21 | momentum: momentum factor (default: 0) 22 | dampening: dampening for momentum (default: 0) 23 | eps: term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay: weight decay (L2 penalty) (default: 0) 26 | delta: threhold that determines whether a set of parameters is scale 27 | invariant or not (default: 0.1) 28 | wd_ratio: relative weight decay applied on scale-invariant parameters 29 | compared to that applied on scale-variant parameters (default: 0.1) 30 | nesterov: enables Nesterov momentum (default: False) 31 | 32 | 33 | Example: 34 | >>> import torch_optimizer as optim 35 | >>> optimizer = optim.SGDP(model.parameters(), lr=0.1) 36 | >>> optimizer.zero_grad() 37 | >>> loss_fn(model(input), target).backward() 38 | >>> optimizer.step() 39 | 40 | __ https://arxiv.org/abs/2006.08217 41 | 42 | Note: 43 | Reference code: https://github.com/clovaai/AdamP 44 | """ 45 | 46 | def __init__( 47 | self, 48 | params: Params, 49 | lr: float = 1e-3, 50 | momentum: float = 0, 51 | dampening: float = 0, 52 | eps: float = 1e-8, 53 | weight_decay: float = 0, 54 | delta: float = 0.1, 55 | wd_ratio: float = 0.1, 56 | nesterov: bool = False, 57 | ) -> None: 58 | if lr <= 0.0: 59 | raise ValueError("Invalid learning rate: {}".format(lr)) 60 | if eps < 0.0: 61 | raise ValueError("Invalid epsilon value: {}".format(eps)) 62 | if momentum < 0.0: 63 | raise ValueError("Invalid momentum value: {}".format(momentum)) 64 | if dampening < 0.0: 65 | raise ValueError("Invalid dampening value: {}".format(dampening)) 66 | if weight_decay < 0: 67 | raise ValueError( 68 | "Invalid weight_decay value: {}".format(weight_decay) 69 | ) 70 | if delta < 0: 71 | raise ValueError("Invalid delta value: {}".format(delta)) 72 | if wd_ratio < 0: 73 | raise ValueError("Invalid wd_ratio value: {}".format(wd_ratio)) 74 | 75 | defaults = dict( 76 | lr=lr, 77 | momentum=momentum, 78 | dampening=dampening, 79 | eps=eps, 80 | weight_decay=weight_decay, 81 | delta=delta, 82 | wd_ratio=wd_ratio, 83 | nesterov=nesterov, 84 | ) 85 | super(SGDP, self).__init__(params, defaults) 86 | 87 | @staticmethod 88 | def _channel_view(x): 89 | return x.view(x.size(0), -1) 90 | 91 | @staticmethod 92 | def _layer_view(x): 93 | return x.view(1, -1) 94 | 95 | @staticmethod 96 | def _cosine_similarity(x, y, eps, view_func): 97 | x = view_func(x) 98 | y = view_func(y) 99 | 100 | x_norm = x.norm(dim=1).add_(eps) 101 | y_norm = y.norm(dim=1).add_(eps) 102 | dot = (x * y).sum(dim=1) 103 | 104 | return dot.abs() / x_norm / y_norm 105 | 106 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 107 | wd = 1 108 | expand_size = [-1] + [1] * (len(p.shape) - 1) 109 | for view_func in [self._channel_view, self._layer_view]: 110 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 111 | 112 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 113 | p_n = p.data / view_func(p.data).norm(dim=1).view( 114 | expand_size 115 | ).add_(eps) 116 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view( 117 | expand_size 118 | ) 119 | wd = wd_ratio 120 | 121 | return perturb, wd 122 | 123 | return perturb, wd 124 | 125 | def step(self, closure: OptLossClosure = None) -> OptFloat: 126 | r"""Performs a single optimization step. 
127 | 128 | Arguments: 129 | closure: A closure that reevaluates the model and returns the loss. 130 | """ 131 | loss = None 132 | if closure is not None: 133 | loss = closure() 134 | 135 | for group in self.param_groups: 136 | weight_decay = group["weight_decay"] 137 | momentum = group["momentum"] 138 | dampening = group["dampening"] 139 | nesterov = group["nesterov"] 140 | 141 | for p in group["params"]: 142 | if p.grad is None: 143 | continue 144 | 145 | grad = p.grad.data 146 | state = self.state[p] 147 | 148 | # State initialization 149 | if len(state) == 0: 150 | state["momentum"] = torch.zeros_like( 151 | p.data, memory_format=torch.preserve_format 152 | ) 153 | 154 | # SGD 155 | buf = state["momentum"] 156 | buf.mul_(momentum).add_(grad, alpha=1 - dampening) 157 | if nesterov: 158 | d_p = grad + momentum * buf 159 | else: 160 | d_p = buf 161 | 162 | # Projection 163 | wd_ratio = 1 164 | if len(p.shape) > 1: 165 | d_p, wd_ratio = self._projection( 166 | p, 167 | grad, 168 | d_p, 169 | group["delta"], 170 | group["wd_ratio"], 171 | group["eps"], 172 | ) 173 | 174 | # Weight decay 175 | if weight_decay != 0: 176 | p.data.mul_( 177 | 1 178 | - group["lr"] 179 | * group["weight_decay"] 180 | * wd_ratio 181 | / (1 - momentum) 182 | ) 183 | 184 | # Step 185 | p.data.add_(d_p, alpha=-group["lr"]) 186 | 187 | return loss 188 | -------------------------------------------------------------------------------- /torch_optimizer/sgdw.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | from .types import OptFloat, OptLossClosure, Params, State 5 | 6 | __all__ = ("SGDW",) 7 | 8 | 9 | class SGDW(Optimizer): 10 | r"""Implements SGDW algorithm. 11 | 12 | It has been proposed in `Decoupled Weight Decay Regularization`__. 
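In the SGDP step above, weight decay is applied as a multiplicative shrink, and the shrink is damped by wd_ratio for tensors the projection flagged as scale invariant. A small illustrative helper for the effective factor (hypothetical name, not part of the package):

def sgdp_decay_factor(lr: float, weight_decay: float, momentum: float,
                      wd_ratio: float, scale_invariant: bool) -> float:
    # Illustrative only: the factor p.data is multiplied by in the step above.
    ratio = wd_ratio if scale_invariant else 1.0
    return 1.0 - lr * weight_decay * ratio / (1.0 - momentum)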
13 | 14 | Arguments: 15 | params: iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr: learning rate (default: 1e-3) 18 | momentum: momentum factor (default: 0) 19 | weight_decay: weight decay (L2 penalty) (default: 0) 20 | dampening: dampening for momentum (default: 0) 21 | nesterov: enables Nesterov momentum (default: False) 22 | 23 | Example: 24 | >>> import torch_optimizer as optim 25 | >>> optimizer = optim.SGDW(model.parameters(), lr=0.1, momentum=0.9) 26 | >>> optimizer.zero_grad() 27 | >>> loss_fn(model(input), target).backward() 28 | >>> optimizer.step() 29 | 30 | __ https://arxiv.org/abs/1711.05101 31 | 32 | Note: 33 | Reference code: https://github.com/pytorch/pytorch/pull/22466 34 | """ 35 | 36 | def __init__( 37 | self, 38 | params: Params, 39 | lr: float = 1e-3, 40 | momentum: float = 0.0, 41 | dampening: float = 0.0, 42 | weight_decay: float = 0.0, 43 | nesterov: bool = False, 44 | ) -> None: 45 | if lr <= 0.0: 46 | raise ValueError("Invalid learning rate: {}".format(lr)) 47 | if momentum < 0.0: 48 | raise ValueError("Invalid momentum value: {}".format(momentum)) 49 | if dampening < 0.0: 50 | raise ValueError("Invalid dampening value: {}".format(dampening)) 51 | if weight_decay < 0.0: 52 | raise ValueError( 53 | "Invalid weight_decay value: {}".format(weight_decay) 54 | ) 55 | 56 | defaults = dict( 57 | lr=lr, 58 | momentum=momentum, 59 | dampening=dampening, 60 | weight_decay=weight_decay, 61 | nesterov=nesterov, 62 | ) 63 | if nesterov and (momentum <= 0 or dampening != 0): 64 | raise ValueError( 65 | "Nesterov momentum requires a momentum and zero dampening" 66 | ) 67 | super(SGDW, self).__init__(params, defaults) 68 | 69 | def __setstate__(self, state: State) -> None: 70 | super(SGDW, self).__setstate__(state) 71 | for group in self.param_groups: 72 | group.setdefault("nesterov", False) 73 | 74 | def step(self, closure: OptLossClosure = None) -> OptFloat: 75 | """Performs a single optimization step. 76 | 77 | Arguments: 78 | closure: A closure that reevaluates the model and returns the loss. 
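The decoupled rule from the paper referenced above shrinks the weights directly with the learning rate instead of folding the decay into the gradient that feeds the momentum buffer. A minimal sketch of that rule without momentum (illustrative only, not the package's code):

import torch

def decoupled_decay_step(p: torch.Tensor, grad: torch.Tensor,
                         lr: float, weight_decay: float) -> None:
    # Illustrative only: plain gradient step followed by decoupled decay.
    p.add_(grad, alpha=-lr)
    p.mul_(1.0 - lr * weight_decay)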
79 | """ 80 | loss = None 81 | if closure is not None: 82 | loss = closure() 83 | 84 | for group in self.param_groups: 85 | weight_decay = group["weight_decay"] 86 | momentum = group["momentum"] 87 | dampening = group["dampening"] 88 | nesterov = group["nesterov"] 89 | 90 | for p in group["params"]: 91 | if p.grad is None: 92 | continue 93 | d_p = p.grad.data 94 | 95 | if p.grad.is_sparse: 96 | msg = ( 97 | "SGDW does not support sparse gradients, " 98 | "please consider SparseAdam instead" 99 | ) 100 | raise RuntimeError(msg) 101 | 102 | if momentum != 0: 103 | param_state = self.state[p] 104 | if "momentum_buffer" not in param_state: 105 | buf = param_state["momentum_buffer"] = torch.clone( 106 | d_p 107 | ).detach() 108 | else: 109 | buf = param_state["momentum_buffer"] 110 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 111 | if nesterov: 112 | d_p = d_p.add(momentum, buf) 113 | else: 114 | d_p = buf 115 | 116 | # Apply momentum 117 | p.data.add_(d_p, alpha=-group["lr"]) 118 | 119 | # Apply weight decay 120 | if weight_decay != 0: 121 | p.data.add_(weight_decay, alpha=-group["lr"]) 122 | return loss 123 | -------------------------------------------------------------------------------- /torch_optimizer/shampoo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | from .types import OptFloat, OptLossClosure, Params 5 | 6 | 7 | def _matrix_power(matrix: torch.Tensor, power: float) -> torch.Tensor: 8 | # use CPU for svd for speed up 9 | device = matrix.device 10 | matrix = matrix.cpu() 11 | u, s, v = torch.svd(matrix) 12 | return (u @ s.pow_(power).diag() @ v.t()).to(device) 13 | 14 | 15 | class Shampoo(Optimizer): 16 | r"""Implements Shampoo Optimizer Algorithm. 17 | 18 | It has been proposed in `Shampoo: Preconditioned Stochastic Tensor 19 | Optimization`__. 
20 | 21 | Arguments: 22 | params: iterable of parameters to optimize or dicts defining 23 | parameter groups 24 | lr: learning rate (default: 1e-3) 25 | momentum: momentum factor (default: 0) 26 | weight_decay: weight decay (L2 penalty) (default: 0) 27 | epsilon: epsilon added to each mat_gbar_j for numerical stability 28 | (default: 1e-4) 29 | update_freq: update frequency to compute inverse (default: 1) 30 | 31 | Example: 32 | >>> import torch_optimizer as optim 33 | >>> optimizer = optim.Shampoo(model.parameters(), lr=0.01) 34 | >>> optimizer.zero_grad() 35 | >>> loss_fn(model(input), target).backward() 36 | >>> optimizer.step() 37 | 38 | __ https://arxiv.org/abs/1802.09568 39 | 40 | Note: 41 | Reference code: https://github.com/moskomule/shampoo.pytorch 42 | """ 43 | 44 | def __init__( 45 | self, 46 | params: Params, 47 | lr: float = 1e-1, 48 | momentum: float = 0.0, 49 | weight_decay: float = 0.0, 50 | epsilon: float = 1e-4, 51 | update_freq: int = 1, 52 | ): 53 | if lr <= 0.0: 54 | raise ValueError("Invalid learning rate: {}".format(lr)) 55 | if momentum < 0.0: 56 | raise ValueError("Invalid momentum value: {}".format(momentum)) 57 | if weight_decay < 0.0: 58 | raise ValueError( 59 | "Invalid weight_decay value: {}".format(weight_decay) 60 | ) 61 | if epsilon < 0.0: 62 | raise ValueError("Invalid momentum value: {}".format(momentum)) 63 | if update_freq < 1: 64 | raise ValueError("Invalid momentum value: {}".format(momentum)) 65 | 66 | defaults = dict( 67 | lr=lr, 68 | momentum=momentum, 69 | weight_decay=weight_decay, 70 | epsilon=epsilon, 71 | update_freq=update_freq, 72 | ) 73 | super(Shampoo, self).__init__(params, defaults) 74 | 75 | def step(self, closure: OptLossClosure = None) -> OptFloat: 76 | """Performs a single optimization step. 77 | 78 | Arguments: 79 | closure: A closure that reevaluates the model and returns the loss. 
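Because every preconditioner refresh calls _matrix_power (an SVD), update_freq is the main cost knob: raising it amortizes the inverse-root computation over several steps at the price of a slightly stale preconditioner. An illustrative configuration, with example values only:

>>> import torch_optimizer as optim
>>> optimizer = optim.Shampoo(
...     model.parameters(), lr=0.1, epsilon=1e-4, update_freq=50
... )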
80 | """ 81 | loss = None 82 | if closure is not None: 83 | loss = closure() 84 | 85 | for group in self.param_groups: 86 | for p in group["params"]: 87 | if p.grad is None: 88 | continue 89 | grad = p.grad.data 90 | order = grad.ndimension() 91 | original_size = grad.size() 92 | state = self.state[p] 93 | momentum = group["momentum"] 94 | weight_decay = group["weight_decay"] 95 | if len(state) == 0: 96 | state["step"] = 0 97 | if momentum > 0: 98 | state["momentum_buffer"] = grad.clone() 99 | for dim_id, dim in enumerate(grad.size()): 100 | # precondition matrices 101 | state["precond_{}".format(dim_id)] = group[ 102 | "epsilon" 103 | ] * torch.eye(dim, out=grad.new(dim, dim)) 104 | state[ 105 | "inv_precond_{dim_id}".format(dim_id=dim_id) 106 | ] = grad.new(dim, dim).zero_() 107 | 108 | if momentum > 0: 109 | grad.mul_(1 - momentum).add_( 110 | state["momentum_buffer"], alpha=momentum 111 | ) 112 | 113 | if weight_decay > 0: 114 | grad.add_(p.data, alpha=group["weight_decay"]) 115 | 116 | # See Algorithm 2 for detail 117 | for dim_id, dim in enumerate(grad.size()): 118 | precond = state["precond_{}".format(dim_id)] 119 | inv_precond = state["inv_precond_{}".format(dim_id)] 120 | 121 | # mat_{dim_id}(grad) 122 | grad = grad.transpose_(0, dim_id).contiguous() 123 | transposed_size = grad.size() 124 | grad = grad.view(dim, -1) 125 | 126 | grad_t = grad.t() 127 | precond.add_(grad @ grad_t) 128 | if state["step"] % group["update_freq"] == 0: 129 | inv_precond.copy_(_matrix_power(precond, -1 / order)) 130 | 131 | if dim_id == order - 1: 132 | # finally 133 | grad = grad_t @ inv_precond 134 | # grad: (-1, last_dim) 135 | grad = grad.view(original_size) 136 | else: 137 | # if not final 138 | grad = inv_precond @ grad 139 | # grad (dim, -1) 140 | grad = grad.view(transposed_size) 141 | 142 | state["step"] += 1 143 | state["momentum_buffer"] = grad 144 | p.data.add_(grad, alpha=-group["lr"]) 145 | 146 | return loss 147 | -------------------------------------------------------------------------------- /torch_optimizer/swats.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | from .types import Betas2, OptFloat, OptLossClosure, Params, State 5 | 6 | __all__ = ("SWATS",) 7 | 8 | 9 | class SWATS(Optimizer): 10 | r"""Implements SWATS Optimizer Algorithm. 11 | It has been proposed in `Improving Generalization Performance by 12 | Switching from Adam to SGD`__. 
13 | 14 | Arguments: 15 | params: iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr: learning rate (default: 1e-2) 18 | betas: coefficients used for computing 19 | running averages of gradient and its square (default: (0.9, 0.999)) 20 | eps: term added to the denominator to improve 21 | numerical stability (default: 1e-3) 22 | weight_decay: weight decay (L2 penalty) (default: 0) 23 | amsgrad: whether to use the AMSGrad variant of this 24 | algorithm from the paper `On the Convergence of Adam and Beyond` 25 | (default: False) 26 | nesterov: enables Nesterov momentum (default: False) 27 | 28 | 29 | Example: 30 | >>> import torch_optimizer as optim 31 | >>> optimizer = optim.SWATS(model.parameters(), lr=0.01) 32 | >>> optimizer.zero_grad() 33 | >>> loss_fn(model(input), target).backward() 34 | >>> optimizer.step() 35 | 36 | __ https://arxiv.org/pdf/1712.07628.pdf 37 | 38 | Note: 39 | Reference code: https://github.com/Mrpatekful/swats 40 | """ 41 | 42 | def __init__( 43 | self, 44 | params: Params, 45 | lr: float = 1e-3, 46 | betas: Betas2 = (0.9, 0.999), 47 | eps: float = 1e-3, 48 | weight_decay: float = 0, 49 | amsgrad: bool = False, 50 | nesterov: bool = False, 51 | ): 52 | if not 0.0 <= lr: 53 | raise ValueError("Invalid learning rate: {}".format(lr)) 54 | if not 0.0 <= eps: 55 | raise ValueError("Invalid epsilon value: {}".format(eps)) 56 | if not 0.0 <= betas[0] < 1.0: 57 | raise ValueError( 58 | "Invalid beta parameter at index 0: {}".format(betas[0]) 59 | ) 60 | if not 0.0 <= betas[1] < 1.0: 61 | raise ValueError( 62 | "Invalid beta parameter at index 1: {}".format(betas[1]) 63 | ) 64 | if weight_decay < 0: 65 | raise ValueError( 66 | "Invalid weight_decay value: {}".format(weight_decay) 67 | ) 68 | defaults = dict( 69 | lr=lr, 70 | betas=betas, 71 | eps=eps, 72 | phase="ADAM", 73 | weight_decay=weight_decay, 74 | amsgrad=amsgrad, 75 | nesterov=nesterov, 76 | ) 77 | 78 | super().__init__(params, defaults) 79 | 80 | def __setstate__(self, state: State) -> None: 81 | super().__setstate__(state) 82 | for group in self.param_groups: 83 | group.setdefault("amsgrad", False) 84 | group.setdefault("nesterov", False) 85 | 86 | def step(self, closure: OptLossClosure = None) -> OptFloat: 87 | r"""Performs a single optimization step. 88 | 89 | Arguments: 90 | closure: A closure that reevaluates the model and returns the loss. 91 | """ 92 | loss = None 93 | if closure is not None: 94 | loss = closure() 95 | 96 | for group in self.param_groups: 97 | for w in group["params"]: 98 | if w.grad is None: 99 | continue 100 | grad = w.grad.data 101 | 102 | if grad.is_sparse: 103 | raise RuntimeError( 104 | "Adam does not support sparse gradients, " 105 | "please consider SparseAdam instead" 106 | ) 107 | 108 | amsgrad = group["amsgrad"] 109 | 110 | state = self.state[w] 111 | 112 | # state initialization 113 | if len(state) == 0: 114 | state["step"] = 0 115 | # exponential moving average of gradient values 116 | state["exp_avg"] = torch.zeros_like( 117 | w.data, memory_format=torch.preserve_format 118 | ) 119 | # exponential moving average of squared gradient values 120 | state["exp_avg_sq"] = torch.zeros_like( 121 | w.data, memory_format=torch.preserve_format 122 | ) 123 | # moving average for the non-orthogonal projection scaling 124 | state["exp_avg2"] = w.new(1).fill_(0) 125 | if amsgrad: 126 | # maintains max of all exp. moving avg. 127 | # of sq. grad. 
values 128 | state["max_exp_avg_sq"] = torch.zeros_like( 129 | w.data, memory_format=torch.preserve_format 130 | ) 131 | 132 | exp_avg, exp_avg2, exp_avg_sq = ( 133 | state["exp_avg"], 134 | state["exp_avg2"], 135 | state["exp_avg_sq"], 136 | ) 137 | 138 | if amsgrad: 139 | max_exp_avg_sq = state["max_exp_avg_sq"] 140 | beta1, beta2 = group["betas"] 141 | 142 | state["step"] += 1 143 | 144 | if group["weight_decay"] != 0: 145 | grad.add_(w.data, alpha=group["weight_decay"]) 146 | 147 | # if its SGD phase, take an SGD update and continue 148 | if group["phase"] == "SGD": 149 | if "momentum_buffer" not in state: 150 | buf = state["momentum_buffer"] = torch.clone( 151 | grad 152 | ).detach() 153 | else: 154 | buf = state["momentum_buffer"] 155 | buf.mul_(beta1).add_(grad) 156 | grad = buf 157 | 158 | grad.mul_(1 - beta1) 159 | if group["nesterov"]: 160 | grad.add_(buf, alpha=beta1) 161 | 162 | w.data.add_(grad, alpha=-group["lr"]) 163 | continue 164 | 165 | # decay the first and second moment running average coefficient 166 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 167 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 168 | if amsgrad: 169 | # maintains the maximum of all 2nd 170 | # moment running avg. till now 171 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 172 | # use the max. for normalizing running avg. of gradient 173 | denom = max_exp_avg_sq.sqrt().add_(group["eps"]) 174 | else: 175 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 176 | 177 | bias_correction1 = 1 - beta1 ** state["step"] 178 | bias_correction2 = 1 - beta2 ** state["step"] 179 | step_size = ( 180 | group["lr"] * (bias_correction2**0.5) / bias_correction1 181 | ) 182 | 183 | p = -step_size * (exp_avg / denom) 184 | w.data.add_(p) 185 | 186 | p_view = p.view(-1) 187 | pg = p_view.dot(grad.view(-1)) 188 | 189 | if pg != 0: 190 | # the non-orthognal scaling estimate 191 | scaling = p_view.dot(p_view) / -pg 192 | exp_avg2.mul_(beta2).add_(scaling, alpha=1 - beta2) 193 | 194 | # bias corrected exponential average 195 | corrected_exp_avg = exp_avg2 / bias_correction2 196 | 197 | # checking criteria of switching to SGD training 198 | if ( 199 | state["step"] > 1 200 | and corrected_exp_avg.allclose(scaling, rtol=1e-6) 201 | and corrected_exp_avg > 0 202 | ): 203 | group["phase"] = "SGD" 204 | group["lr"] = corrected_exp_avg.item() 205 | return loss 206 | -------------------------------------------------------------------------------- /torch_optimizer/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union 2 | 3 | from torch import Tensor 4 | 5 | Params = Union[Iterable[Tensor], Iterable[Dict[str, Any]]] 6 | 7 | LossClosure = Callable[[], float] 8 | OptLossClosure = Optional[LossClosure] 9 | Betas2 = Tuple[float, float] 10 | State = Dict[str, Any] 11 | OptFloat = Optional[float] 12 | Nus2 = Tuple[float, float] 13 | -------------------------------------------------------------------------------- /torch_optimizer/yogi.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.optim.optimizer import Optimizer 6 | 7 | from .types import Betas2, OptFloat, OptLossClosure, Params 8 | 9 | __all__ = ("Yogi",) 10 | 11 | 12 | class Yogi(Optimizer): 13 | r"""Implements Yogi Optimizer Algorithm. 14 | It has been proposed in `Adaptive methods for Nonconvex Optimization`__. 
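The switching logic at the end of the SWATS step above projects the Adam step onto the gradient to estimate an equivalent SGD learning rate, and moves the group to the SGD phase once the bias-corrected running estimate of that rate stabilizes. A condensed illustrative sketch of the check (names are mine, not the package's API):

import torch

def swats_switch(adam_step: torch.Tensor, grad: torch.Tensor,
                 exp_avg2: torch.Tensor, beta2: float, step: int):
    # Illustrative only: estimate the SGD learning rate and test the criterion.
    pg = adam_step.view(-1).dot(grad.view(-1))
    if pg == 0:
        return None
    gamma = adam_step.view(-1).dot(adam_step.view(-1)) / -pg
    exp_avg2.mul_(beta2).add_(gamma, alpha=1 - beta2)   # running estimate
    corrected = exp_avg2 / (1 - beta2 ** step)          # bias correction
    if step > 1 and corrected.allclose(gamma, rtol=1e-6) and corrected > 0:
        return corrected.item()                         # lr for the SGD phase
    return None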
15 | 16 | Arguments: 17 | params: iterable of parameters to optimize or dicts defining 18 | parameter groups 19 | lr: learning rate (default: 1e-2) 20 | betas: coefficients used for computing 21 | running averages of gradient and its square (default: (0.9, 0.999)) 22 | eps: term added to the denominator to improve 23 | numerical stability (default: 0.001) 24 | initial_accumulator: initial values for first and 25 | second moments (default: 1e-6) 26 | weight_decay: weight decay (L2 penalty) (default: 0) 27 | 28 | Example: 29 | >>> import torch_optimizer as optim 30 | >>> optimizer = optim.Yogi(model.parameters(), lr=0.01) 31 | >>> optimizer.zero_grad() 32 | >>> loss_fn(model(input), target).backward() 33 | >>> optimizer.step() 34 | 35 | __ https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization # noqa 36 | 37 | Note: 38 | Reference code: https://github.com/4rtemi5/Yogi-Optimizer_Keras 39 | """ 40 | 41 | def __init__( 42 | self, 43 | params: Params, 44 | lr: float = 1e-2, 45 | betas: Betas2 = (0.9, 0.999), 46 | eps: float = 1e-3, 47 | initial_accumulator: float = 1e-6, 48 | weight_decay: float = 0, 49 | ) -> None: 50 | if lr <= 0.0: 51 | raise ValueError("Invalid learning rate: {}".format(lr)) 52 | if eps < 0.0: 53 | raise ValueError("Invalid epsilon value: {}".format(eps)) 54 | if not 0.0 <= betas[0] < 1.0: 55 | raise ValueError( 56 | "Invalid beta parameter at index 0: {}".format(betas[0]) 57 | ) 58 | if not 0.0 <= betas[1] < 1.0: 59 | raise ValueError( 60 | "Invalid beta parameter at index 1: {}".format(betas[1]) 61 | ) 62 | if weight_decay < 0: 63 | raise ValueError( 64 | "Invalid weight_decay value: {}".format(weight_decay) 65 | ) 66 | 67 | defaults = dict( 68 | lr=lr, 69 | betas=betas, 70 | eps=eps, 71 | initial_accumulator=initial_accumulator, 72 | weight_decay=weight_decay, 73 | ) 74 | super(Yogi, self).__init__(params, defaults) 75 | 76 | def step(self, closure: OptLossClosure = None) -> OptFloat: 77 | r"""Performs a single optimization step. 78 | 79 | Arguments: 80 | closure: A closure that reevaluates the model and returns the loss. 
81 | """ 82 | loss = None 83 | if closure is not None: 84 | loss = closure() 85 | 86 | for group in self.param_groups: 87 | for p in group["params"]: 88 | if p.grad is None: 89 | continue 90 | grad = p.grad.data 91 | if grad.is_sparse: 92 | raise RuntimeError( 93 | "Yogi does not support sparse gradients, " 94 | "please consider SparseAdam instead" 95 | ) 96 | 97 | state = self.state[p] 98 | 99 | # State initialization 100 | # Followed from official implementation in tensorflow addons: 101 | # https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/yogi.py#L118 # noqa 102 | # For more details refer to the discussion: 103 | # https://github.com/jettify/pytorch-optimizer/issues/77 104 | if len(state) == 0: 105 | state["step"] = 0 106 | # Exponential moving average of gradient values 107 | state["exp_avg"] = nn.init.constant_( 108 | torch.empty_like( 109 | p.data, memory_format=torch.preserve_format 110 | ), 111 | group["initial_accumulator"], 112 | ) 113 | # Exponential moving average of squared gradient values 114 | state["exp_avg_sq"] = nn.init.constant_( 115 | torch.empty_like( 116 | p.data, memory_format=torch.preserve_format 117 | ), 118 | group["initial_accumulator"], 119 | ) 120 | 121 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 122 | beta1, beta2 = group["betas"] 123 | 124 | state["step"] += 1 125 | bias_correction1 = 1 - beta1 ** state["step"] 126 | bias_correction2 = 1 - beta2 ** state["step"] 127 | 128 | if group["weight_decay"] != 0: 129 | grad = grad.add(p.data, alpha=group["weight_decay"]) 130 | 131 | # Decay the first and second moment running average coefficient 132 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 133 | 134 | grad_squared = grad.mul(grad) 135 | 136 | exp_avg_sq.addcmul_( 137 | torch.sign(exp_avg_sq - grad_squared), 138 | grad_squared, 139 | value=-(1 - beta2), 140 | ) 141 | 142 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 143 | group["eps"] 144 | ) 145 | step_size = group["lr"] / bias_correction1 146 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 147 | 148 | return loss 149 | --------------------------------------------------------------------------------