├── .gitattributes ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── labels.yml ├── pull_request_template.md └── workflows │ ├── ci.yml │ └── publish.yml ├── .gitignore ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── assets ├── explore_exploit_lr_schedule.png ├── gradient_centralization.png ├── linear_lr_warmup.png ├── norm_loss.png ├── positive_negative_momentum.png └── stable_weight_decay.png ├── docs ├── .readthedocs.yaml ├── base.md ├── changelogs │ ├── v2.10.0.md │ ├── v2.10.1.md │ ├── v2.11.0.md │ ├── v2.11.1.md │ ├── v2.11.2.md │ ├── v2.12.0.md │ ├── v2.7.0.md │ ├── v2.8.0.md │ ├── v2.9.0.md │ ├── v2.9.1.md │ ├── v3.0.0.md │ ├── v3.0.1.md │ ├── v3.0.2.md │ ├── v3.1.0.md │ ├── v3.1.1.md │ ├── v3.1.2.md │ ├── v3.2.0.md │ ├── v3.3.0.md │ ├── v3.3.1.md │ ├── v3.3.2.md │ ├── v3.3.3.md │ ├── v3.3.4.md │ ├── v3.4.0.md │ ├── v3.4.1.md │ ├── v3.4.2.md │ ├── v3.5.0.md │ ├── v3.5.1.md │ ├── v3.6.0.md │ └── v3.6.1.md ├── index.md ├── javascripts │ └── tables.js ├── loss.md ├── lr_scheduler.md ├── optimizer.md ├── qa.md ├── util.md ├── visualization.md └── visualizations │ ├── rastrigin_ADOPT.png │ ├── rastrigin_APOLLO.png │ ├── rastrigin_ASGD.png │ ├── rastrigin_AccSGD.png │ ├── rastrigin_AdEMAMix.png │ ├── rastrigin_AdaBelief.png │ ├── rastrigin_AdaBound.png │ ├── rastrigin_AdaDelta.png │ ├── rastrigin_AdaFactor.png │ ├── rastrigin_AdaGC.png │ ├── rastrigin_AdaHessian.png │ ├── rastrigin_AdaMax.png │ ├── rastrigin_AdaMod.png │ ├── rastrigin_AdaNorm.png │ ├── rastrigin_AdaPNM.png │ ├── rastrigin_AdaShift.png │ ├── rastrigin_AdaSmooth.png │ ├── rastrigin_AdaTAM.png │ ├── rastrigin_Adai.png │ ├── rastrigin_Adalite.png │ ├── rastrigin_Adam.png │ ├── rastrigin_AdamG.png │ ├── rastrigin_AdamMini.png │ ├── rastrigin_AdamP.png │ ├── rastrigin_AdamS.png │ ├── rastrigin_AdamW.png │ ├── rastrigin_Adan.png │ ├── rastrigin_AggMo.png │ ├── rastrigin_Aida.png │ ├── rastrigin_AliG.png │ ├── rastrigin_Amos.png │ ├── rastrigin_ApolloDQN.png │ ├── rastrigin_AvaGrad.png │ ├── rastrigin_BSAM.png │ ├── rastrigin_CAME.png │ ├── rastrigin_DAdaptAdaGrad.png │ ├── rastrigin_DAdaptAdam.png │ ├── rastrigin_DAdaptAdan.png │ ├── rastrigin_DAdaptLion.png │ ├── rastrigin_DAdaptSGD.png │ ├── rastrigin_DiffGrad.png │ ├── rastrigin_EXAdam.png │ ├── rastrigin_FAdam.png │ ├── rastrigin_FOCUS.png │ ├── rastrigin_FTRL.png │ ├── rastrigin_Fira.png │ ├── rastrigin_Fromage.png │ ├── rastrigin_GaLore.png │ ├── rastrigin_Grams.png │ ├── rastrigin_Gravity.png │ ├── rastrigin_GrokFastAdamW.png │ ├── rastrigin_Kate.png │ ├── rastrigin_Kron.png │ ├── rastrigin_LARS.png │ ├── rastrigin_LaProp.png │ ├── rastrigin_Lamb.png │ ├── rastrigin_Lion.png │ ├── rastrigin_MADGRAD.png │ ├── rastrigin_MARS.png │ ├── rastrigin_MSVAG.png │ ├── rastrigin_Nero.png │ ├── rastrigin_NovoGrad.png │ ├── rastrigin_PAdam.png │ ├── rastrigin_PID.png │ ├── rastrigin_PNM.png │ ├── rastrigin_Prodigy.png │ ├── rastrigin_QHAdam.png │ ├── rastrigin_QHM.png │ ├── rastrigin_RACS.png │ ├── rastrigin_RAdam.png │ ├── rastrigin_Ranger.png │ ├── rastrigin_Ranger21.png │ ├── rastrigin_Ranger25.png │ ├── rastrigin_SCION.png │ ├── rastrigin_SCIONLight.png │ ├── rastrigin_SGD.png │ ├── rastrigin_SGDP.png │ ├── rastrigin_SGDSaI.png │ ├── rastrigin_SGDW.png │ ├── rastrigin_SM3.png │ ├── rastrigin_SOAP.png │ ├── rastrigin_SPAM.png │ ├── rastrigin_SRMM.png │ ├── rastrigin_SWATS.png │ ├── rastrigin_ScalableShampoo.png │ ├── rastrigin_ScheduleFreeAdamW.png │ ├── 
rastrigin_ScheduleFreeRAdam.png │ ├── rastrigin_ScheduleFreeSGD.png │ ├── rastrigin_Shampoo.png │ ├── rastrigin_SignSGD.png │ ├── rastrigin_SimplifiedAdEMAMix.png │ ├── rastrigin_SophiaH.png │ ├── rastrigin_StableAdamW.png │ ├── rastrigin_StableSPAM.png │ ├── rastrigin_TAM.png │ ├── rastrigin_Tiger.png │ ├── rastrigin_VSGD.png │ ├── rastrigin_Yogi.png │ ├── rosenbrock_ADOPT.png │ ├── rosenbrock_APOLLO.png │ ├── rosenbrock_ASGD.png │ ├── rosenbrock_AccSGD.png │ ├── rosenbrock_AdEMAMix.png │ ├── rosenbrock_AdaBelief.png │ ├── rosenbrock_AdaBound.png │ ├── rosenbrock_AdaDelta.png │ ├── rosenbrock_AdaFactor.png │ ├── rosenbrock_AdaGC.png │ ├── rosenbrock_AdaHessian.png │ ├── rosenbrock_AdaMax.png │ ├── rosenbrock_AdaMod.png │ ├── rosenbrock_AdaNorm.png │ ├── rosenbrock_AdaPNM.png │ ├── rosenbrock_AdaShift.png │ ├── rosenbrock_AdaSmooth.png │ ├── rosenbrock_AdaTAM.png │ ├── rosenbrock_Adai.png │ ├── rosenbrock_Adalite.png │ ├── rosenbrock_Adam.png │ ├── rosenbrock_AdamG.png │ ├── rosenbrock_AdamMini.png │ ├── rosenbrock_AdamP.png │ ├── rosenbrock_AdamS.png │ ├── rosenbrock_AdamW.png │ ├── rosenbrock_Adan.png │ ├── rosenbrock_AggMo.png │ ├── rosenbrock_Aida.png │ ├── rosenbrock_AliG.png │ ├── rosenbrock_Amos.png │ ├── rosenbrock_ApolloDQN.png │ ├── rosenbrock_AvaGrad.png │ ├── rosenbrock_BSAM.png │ ├── rosenbrock_CAME.png │ ├── rosenbrock_DAdaptAdaGrad.png │ ├── rosenbrock_DAdaptAdam.png │ ├── rosenbrock_DAdaptAdan.png │ ├── rosenbrock_DAdaptLion.png │ ├── rosenbrock_DAdaptSGD.png │ ├── rosenbrock_DiffGrad.png │ ├── rosenbrock_EXAdam.png │ ├── rosenbrock_FAdam.png │ ├── rosenbrock_FOCUS.png │ ├── rosenbrock_FTRL.png │ ├── rosenbrock_Fira.png │ ├── rosenbrock_Fromage.png │ ├── rosenbrock_GaLore.png │ ├── rosenbrock_Grams.png │ ├── rosenbrock_Gravity.png │ ├── rosenbrock_GrokFastAdamW.png │ ├── rosenbrock_Kate.png │ ├── rosenbrock_Kron.png │ ├── rosenbrock_LARS.png │ ├── rosenbrock_LaProp.png │ ├── rosenbrock_Lamb.png │ ├── rosenbrock_Lion.png │ ├── rosenbrock_MADGRAD.png │ ├── rosenbrock_MARS.png │ ├── rosenbrock_MSVAG.png │ ├── rosenbrock_Nero.png │ ├── rosenbrock_NovoGrad.png │ ├── rosenbrock_PAdam.png │ ├── rosenbrock_PID.png │ ├── rosenbrock_PNM.png │ ├── rosenbrock_Prodigy.png │ ├── rosenbrock_QHAdam.png │ ├── rosenbrock_QHM.png │ ├── rosenbrock_RACS.png │ ├── rosenbrock_RAdam.png │ ├── rosenbrock_Ranger.png │ ├── rosenbrock_Ranger21.png │ ├── rosenbrock_Ranger25.png │ ├── rosenbrock_SCION.png │ ├── rosenbrock_SCIONLight.png │ ├── rosenbrock_SGD.png │ ├── rosenbrock_SGDP.png │ ├── rosenbrock_SGDSaI.png │ ├── rosenbrock_SGDW.png │ ├── rosenbrock_SM3.png │ ├── rosenbrock_SOAP.png │ ├── rosenbrock_SPAM.png │ ├── rosenbrock_SRMM.png │ ├── rosenbrock_SWATS.png │ ├── rosenbrock_ScalableShampoo.png │ ├── rosenbrock_ScheduleFreeAdamW.png │ ├── rosenbrock_ScheduleFreeRAdam.png │ ├── rosenbrock_ScheduleFreeSGD.png │ ├── rosenbrock_Shampoo.png │ ├── rosenbrock_SignSGD.png │ ├── rosenbrock_SimplifiedAdEMAMix.png │ ├── rosenbrock_SophiaH.png │ ├── rosenbrock_StableAdamW.png │ ├── rosenbrock_StableSPAM.png │ ├── rosenbrock_TAM.png │ ├── rosenbrock_Tiger.png │ ├── rosenbrock_VSGD.png │ └── rosenbrock_Yogi.png ├── examples ├── __init__.py ├── pytorch_lightning_example.py ├── pytorch_lightning_manual_backward_example.py └── visualize_optimizers.py ├── hubconf.py ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml ├── pytorch_optimizer ├── __init__.py ├── base │ ├── __init__.py │ ├── exception.py │ ├── optimizer.py │ ├── scheduler.py │ └── type.py ├── loss │ ├── __init__.py │ ├── bi_tempered.py │ ├── 
cross_entropy.py │ ├── dice.py │ ├── f1.py │ ├── focal.py │ ├── jaccard.py │ ├── ldam.py │ ├── lovasz.py │ └── tversky.py ├── lr_scheduler │ ├── __init__.py │ ├── chebyshev.py │ ├── cosine_anealing.py │ ├── experimental │ │ ├── __init__.py │ │ └── deberta_v3_lr_scheduler.py │ ├── linear_warmup.py │ ├── proportion.py │ ├── rex.py │ └── wsd.py └── optimizer │ ├── __init__.py │ ├── a2grad.py │ ├── adabelief.py │ ├── adabound.py │ ├── adadelta.py │ ├── adafactor.py │ ├── adagc.py │ ├── adahessian.py │ ├── adai.py │ ├── adalite.py │ ├── adam_mini.py │ ├── adamax.py │ ├── adamg.py │ ├── adamod.py │ ├── adamp.py │ ├── adams.py │ ├── adamw.py │ ├── adan.py │ ├── adanorm.py │ ├── adapnm.py │ ├── adashift.py │ ├── adasmooth.py │ ├── ademamix.py │ ├── adopt.py │ ├── agc.py │ ├── aggmo.py │ ├── aida.py │ ├── alig.py │ ├── amos.py │ ├── apollo.py │ ├── avagrad.py │ ├── came.py │ ├── dadapt.py │ ├── demo.py │ ├── diffgrad.py │ ├── exadam.py │ ├── experimental │ ├── __init__.py │ └── ranger25.py │ ├── fadam.py │ ├── fira.py │ ├── focus.py │ ├── fp16.py │ ├── fromage.py │ ├── ftrl.py │ ├── galore.py │ ├── galore_utils.py │ ├── gradient_centralization.py │ ├── grams.py │ ├── gravity.py │ ├── grokfast.py │ ├── kate.py │ ├── lamb.py │ ├── laprop.py │ ├── lars.py │ ├── lion.py │ ├── lomo.py │ ├── lookahead.py │ ├── madgrad.py │ ├── mars.py │ ├── msvag.py │ ├── muon.py │ ├── nero.py │ ├── novograd.py │ ├── orthograd.py │ ├── padam.py │ ├── pcgrad.py │ ├── pid.py │ ├── pnm.py │ ├── prodigy.py │ ├── psgd.py │ ├── psgd_utils.py │ ├── qhadam.py │ ├── qhm.py │ ├── racs.py │ ├── radam.py │ ├── ranger.py │ ├── ranger21.py │ ├── rotograd.py │ ├── sam.py │ ├── schedulefree.py │ ├── scion.py │ ├── sgd.py │ ├── shampoo.py │ ├── shampoo_utils.py │ ├── sm3.py │ ├── snsm.py │ ├── soap.py │ ├── sophia.py │ ├── spam.py │ ├── srmm.py │ ├── swats.py │ ├── tam.py │ ├── tiger.py │ ├── trac.py │ ├── utils.py │ └── yogi.py ├── requirements-dev.txt ├── requirements-docs.txt ├── requirements.txt └── tests ├── __init__.py ├── conftest.py ├── constants.py ├── test_base.py ├── test_create_optimizer.py ├── test_general_optimizer_parameters.py ├── test_gradients.py ├── test_load_modules.py ├── test_loss_functions.py ├── test_lr_scheduler_parameters.py ├── test_lr_schedulers.py ├── test_optimizer_parameters.py ├── test_optimizer_variants.py ├── test_optimizers.py ├── test_utils.py ├── test_wrapper_optimizers.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @kozistr -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: kozistr 7 | 8 | --- 9 | 10 | ## Describe the bug 11 | 12 | A clear and concise description of what the bug is. 13 | 14 | ## To Reproduce 15 | 16 | * OS : (e.g. Linux, Windows, MacOS) 17 | * PyTorch version : (e.g. 2.0.1, 1.13, >=1.8, <1.10) 18 | * Python version : (e.g. 3.8, 3.11) 19 | * pytorch-optimizer version : (e.g. 3.3.0) 20 | * reproducible codes : please share your reproducible codes, scripts, or links. 
If sharing the full code is complicated, a hand-written minimal snippet that reproduces the bug is enough! 21 | 22 | Here's an [example](https://github.com/kozistr/pytorch_optimizer/issues/305#issue-2721453417). 23 | 24 | ## Log 25 | 26 | Attach the complete log here! (highlighted text or screenshots are welcome) 27 | 28 | ## Expected behavior 29 | 30 | A clear and concise description of what you expected to happen. 31 | 32 | ## Additional context 33 | 34 | Add any other context about the problem here. 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: feature request 6 | assignees: kozistr 7 | 8 | --- 9 | 10 | ## Paper or Code 11 | 12 | here! 13 | -------------------------------------------------------------------------------- /.github/labels.yml: -------------------------------------------------------------------------------- 1 | XS: 2 | name: size/XS 3 | lines: 0 4 | color: 3CBF00 5 | S: 6 | name: size/S 7 | lines: 20 8 | color: 5D9801 9 | M: 10 | name: size/M 11 | lines: 100 12 | color: 7F7203 13 | L: 14 | name: size/L 15 | lines: 250 16 | color: A14C05 17 | XL: 18 | name: size/XL 19 | lines: 500 20 | color: C32607 21 | XXL: 22 | name: size/XXL 23 | lines: 1000 24 | color: E50009 25 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | Remove this part when you open the PR 3 | 4 | Here's a checklist before opening the Pull Request! 5 | 6 | 1. PR title convention: [Type of PR] [Summary] (e.g. [Feature] Implement AdamP optimizer) 7 | 2. Attach as much information as you can. It helps the reviewers a lot :) 8 | 3. Make sure the code is runnable and compatible. 9 | 4. If your PR is not ready yet, open it as a Draft PR. 10 | 5. Run `make format` and `make check` before opening the PR. 11 | 6. Or ask the maintainer for help with code style and test cases. 12 | --- 13 | 14 | ## Problem (Why?) 15 | 16 | _What problem are you trying to solve?_ 17 | 18 | ## Solution (What/How?) 19 | 20 | _How did you solve the problem?
Please provide a complete description and explanation!_ 21 | 22 | ## Other changes (bug fixes, small refactors) 23 | 24 | _Are there any small changes?_ 25 | 26 | ## Notes 27 | 28 | _Please note any questions, helps or contexts what maintainer(s) should know_ 29 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | OMP_NUM_THREADS: 2 11 | MKL_NUM_THREADS: 2 12 | 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ['3.12'] 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | cache: 'pip' 27 | - name: Install dependencies 28 | run: pip --disable-pip-version-check install --no-compile -r requirements-dev.txt 29 | - name: Check lint 30 | run: make check 31 | - name: Check test 32 | env: 33 | PYTHONDONTWRITEBYTECODE: 1 34 | run: make test 35 | - name: Check codecov 36 | uses: codecov/codecov-action@v4 37 | with: 38 | token: ${{ secrets.CODECOV_TOKEN }} 39 | directory: ./ 40 | files: ./coverage.xml 41 | env_vars: OS,PYTHON 42 | fail_ci_if_error: true 43 | verbose: false 44 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | release: 10 | name: Create Release 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | - name: Get the version 16 | id: get_version 17 | run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//} 18 | - name: Create release 19 | id: create_release 20 | uses: actions/create-release@v1 21 | env: 22 | GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} 23 | with: 24 | tag_name: ${{ github.ref }} 25 | release_name: pytorch-optimizer ${{ github.ref }} 26 | body_path: docs/changelogs/${{ steps.get_version.outputs.VERSION }}.md 27 | draft: false 28 | prerelease: false 29 | deploy: 30 | name: Deploy 31 | needs: release 32 | runs-on: ubuntu-latest 33 | steps: 34 | - uses: actions/checkout@v4 35 | - name: Setup Python 3.12 36 | uses: actions/setup-python@v5 37 | with: 38 | python-version: 3.12 39 | cache: 'pip' 40 | - name: Install dependencies 41 | run: | 42 | pip --disable-pip-version-check install --no-compile poetry 43 | pip --disable-pip-version-check install --no-compile -r requirements.txt 44 | - name: Publish package to PyPI 45 | env: 46 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 47 | run: | 48 | poetry config pypi-token.pypi $PYPI_TOKEN 49 | poetry publish --build 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | \#*\# 3 | .\#* 4 | .idea/ 5 | .DS_Store 6 | ._* 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | build/ 11 | *.egg-info/ 12 | target/ 13 | htmlcov/ 14 | .coverage 15 | .coverage.* 16 | .cache 17 | coverage.xml 18 | *.cover 19 | .pytest_cache/ 20 | .mypy_cache/ 21 | .env 22 | venv/ 23 | env/ 24 | _build/ 25 | [._]*.s[a-v][a-z] 26 | [._]*.sw[a-p] 27 | [._]s[a-rt-v][a-z] 28 | [._]ss[a-gi-z] 29 | [._]sw[a-p] 30 | .netrwhist 31 
| .vscode/* 32 | .history 33 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Kim 5 | given-names: Hyeongchan 6 | orcid: https://orcid.org/0000-0002-1729-0580 7 | title: "pytorch_optimizer: optimizer & lr scheduler & loss function collections in PyTorch" 8 | version: 2.12.0 9 | date-released: 2021-09-21 10 | url: "https://github.com/kozistr/pytorch_optimizer" -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | *This guideline is very much a WIP* 2 | 3 | Contributions to `pytorch-optimizer` for code, documentation, and tests are always welcome! 4 | 5 | # Coding Style 6 | 7 | Currently, `black` and `ruff` are used to format & lint the code. Here are the [lint options](https://github.com/kozistr/pytorch_optimizer/blob/main/pyproject.toml#L69). 8 | Alternatively, you can simply run `make format` and `make check` from the project root. 9 | 10 | You can create the development environment with `make init`, or install the pip packages manually. 11 | 12 | A few differences from the default `black` style are: 13 | 14 | 1. the line length is **119** characters. 15 | 2. **single quotes** are preferred over double quotes. 16 | 17 | If that feels like too much overhead, feel free to ask the maintainer to fix the lint issues for you! 18 | 19 | # Documentation 20 | 21 | The docstring style is `reST` (not Google or NumPy style), and the documentation is built & deployed automatically via `readthedocs`. You can find an example [here](https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/optimizer/adamp.py#L14). 22 | 23 | # Test 24 | 25 | You can run the tests with `make test` from the project root! 26 | 27 | # Question 28 | 29 | If you have any questions about contributing, please ask in Issues, Discussions, or directly in a PR :) 30 | 31 | Thank you!
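For reference, a typical end-to-end local workflow might look like the following. It simply collects the Makefile targets mentioned above in one place (assuming you work from the project root; adjust the virtual-environment handling to your own setup):

```shell
# one-time setup: installs poetry, isort, black, ruff, pytest, pytest-cov and the project dependencies
make init

# after making changes: auto-format the code, then run the linters
make format
make check

# run the test suite with coverage before opening a PR
make test
```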
32 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: init format test check requirements visualize docs 2 | 3 | init: 4 | python -m pip install -q -U poetry isort black ruff pytest pytest-cov 5 | python -m poetry install --dev 6 | 7 | format: 8 | isort --profile black -l 119 pytorch_optimizer examples tests hubconf.py 9 | black -S -l 119 pytorch_optimizer examples tests hubconf.py 10 | 11 | check: 12 | black -S -l 119 --check pytorch_optimizer examples tests hubconf.py 13 | ruff check pytorch_optimizer examples tests hubconf.py 14 | 15 | test: 16 | python -m pytest -p no:pastebin -p no:nose -p no:doctest -sv -vv --cov=pytorch_optimizer --cov-report=xml ./tests 17 | 18 | requirements: 19 | poetry export -f requirements.txt --output requirements.txt --without-hashes 20 | poetry export -f requirements.txt --output requirements-dev.txt --without-hashes --with dev 21 | 22 | visualize: 23 | python -m examples.visualize_optimizers 24 | 25 | docs: 26 | mkdocs serve 27 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | all versions maybe. 6 | 7 | ## Reporting a Vulnerability 8 | 9 | this project heavily depends on the `torch` package. 10 | -------------------------------------------------------------------------------- /assets/explore_exploit_lr_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/assets/explore_exploit_lr_schedule.png -------------------------------------------------------------------------------- /assets/gradient_centralization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/assets/gradient_centralization.png -------------------------------------------------------------------------------- /assets/linear_lr_warmup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/assets/linear_lr_warmup.png -------------------------------------------------------------------------------- /assets/norm_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/assets/norm_loss.png -------------------------------------------------------------------------------- /assets/positive_negative_momentum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/assets/positive_negative_momentum.png -------------------------------------------------------------------------------- /assets/stable_weight_decay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/assets/stable_weight_decay.png -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | 8 | mkdocs: 9 | configuration: mkdocs.yml 10 | fail_on_warning: false 11 | 12 | python: 13 | install: 14 | - requirements: requirements-docs.txt 15 | -------------------------------------------------------------------------------- /docs/base.md: -------------------------------------------------------------------------------- 1 | # Base 2 | 3 | ::: pytorch_optimizer.base.optimizer.BaseOptimizer 4 | :docstring: 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/changelogs/v2.10.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement Amos optimizer (#174) 6 | * [An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale](https://arxiv.org/abs/2210.11693) 7 | * Implement SignSGD optimizer (#176) (thanks to @i404788) 8 | * [Compressed Optimisation for Non-Convex Problems](https://arxiv.org/abs/1802.04434) 9 | * Implement AdaHessian optimizer (#176) (thanks to @i404788) 10 | * [An Adaptive Second Order Optimizer for Machine Learning](https://arxiv.org/abs/2006.00719) 11 | * Implement SophiaH optimizer (#173, #176) (thanks to @i404788) 12 | * [A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://arxiv.org/abs/2305.14342) 13 | * Implement re-usable functions to compute hessian in `BaseOptimizer` (#176, #177) (thanks to @i404788) 14 | * two types of distribution are supported (`gaussian`, `rademacher`). 15 | * Support `AdamD` variant for AdaHessian optimizer (#177) 16 | 17 | ### Diff 18 | 19 | [2.9.1...2.10.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.9.1...v2.10.0) 20 | -------------------------------------------------------------------------------- /docs/changelogs/v2.10.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement Prodigy optimizer (#183) 6 | * [An Expeditiously Adaptive Parameter-Free Learner](https://arxiv.org/abs/2306.06101) 7 | 8 | ### Fix 9 | 10 | * `perturb` isn't multiplied by `-step_size` in SWATS optimizer. (#179) 11 | * `chebyshev step` has size of `T` while the permutation is `2^T`. 
(#168, #181) 12 | 13 | ### Diff 14 | 15 | [2.10.0...2.10.1](https://github.com/kozistr/pytorch_optimizer/compare/v2.10.0...v2.10.1) 16 | -------------------------------------------------------------------------------- /docs/changelogs/v2.11.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement PAdam optimizer (#186) 6 | * [Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks](https://arxiv.org/abs/1806.06763) 7 | * Implement LOMO optimizer (#188) 8 | * [Full Parameter Fine-tuning for Large Language Models with Limited Resources](https://arxiv.org/abs/2306.09782) 9 | * Implement loss functions (#189) 10 | * BCELoss 11 | * BCEFocalLoss 12 | * FocalLoss : [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) 13 | * FocalCosineLoss : [Data-Efficient Deep Learning Method for Image Classification Using Data Augmentation, Focal Cosine Loss, and Ensemble](https://arxiv.org/abs/2007.07805) 14 | * DiceLoss : [Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations](https://arxiv.org/abs/1707.03237v3) 15 | * LDAMLoss : [Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss](https://arxiv.org/abs/1906.07413) 16 | * JaccardLoss 17 | * BiTemperedLogisticLoss : [Robust Bi-Tempered Logistic Loss Based on Bregman Divergences](https://arxiv.org/abs/1906.03361) 18 | 19 | ### Diff 20 | 21 | [2.10.1...2.11.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.10.1...v2.11.0) 22 | -------------------------------------------------------------------------------- /docs/changelogs/v2.11.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement Tiger optimizer (#192) 6 | * [A Tight-fisted Optimizer](https://github.com/bojone/tiger/blob/main/README_en.md) 7 | * Implement CAME optimizer (#196) 8 | * [Confidence-guided Adaptive Memory Efficient Optimization](https://aclanthology.org/2023.acl-long.243/) 9 | * Implement loss functions (#198) 10 | * Tversky Loss : [Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/abs/1706.05721) 11 | * Focal Tversky Loss 12 | * Lovasz Hinge Loss : [The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks](https://arxiv.org/abs/1705.08790) 13 | 14 | ### Diff 15 | 16 | [2.11.0...2.11.1](https://github.com/kozistr/pytorch_optimizer/compare/v2.11.0...v2.11.1) 17 | -------------------------------------------------------------------------------- /docs/changelogs/v2.11.2.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement DAdaptLion optimizer (#203) 6 | * [Lion with D-Adaptation](https://github.com/facebookresearch/dadaptation/blob/main/dadaptation/dadapt_lion.py) 7 | 8 | ### Fix 9 | 10 | * Fix Lookahead optimizer (#200, #201, #202) 11 | * When using PyTorch Lightning which expects your optimiser to be a subclass of `Optimizer`. 
12 | * Fix default `rectify` to `False` in `AdaBelief` optimizer (#203) 13 | 14 | ### Test 15 | 16 | * Add `DynamicLossScaler` test case 17 | 18 | ### Docs 19 | 20 | * Highlight the code blocks 21 | * Fix pepy badges 22 | 23 | ### Contributions 24 | 25 | thanks to @georg-wolflein 26 | 27 | ### Diff 28 | 29 | [2.11.1...2.11.2](https://github.com/kozistr/pytorch_optimizer/compare/v2.11.1...v2.11.2) 30 | -------------------------------------------------------------------------------- /docs/changelogs/v2.12.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Support `bitsandbytes` optimizer. (#211) 6 | * now, you can install with `pip3 install pytorch-optimizer[bitsandbytes]` 7 | * supports 8 bnb optimizers. 8 | * `bnb_adagrad8bit`, `bnb_adam8bit`, `bnb_adamw8bit`, `bnb_lion8bit`, `bnb_lamb8bit`, `bnb_lars8bit`, `bnb_rmsprop8bit`, `bnb_sgd8bit`. 9 | 10 | ### Docs 11 | 12 | * Introduce `mkdocs` with `material` theme. (#204, #206) 13 | * documentation : https://pytorch-optimizers.readthedocs.io/en/latest/ 14 | 15 | ### Diff 16 | 17 | [2.11.2...2.12.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.11.2...v2.12.0) 18 | -------------------------------------------------------------------------------- /docs/changelogs/v2.7.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `AdaNorm` optimizer (#133) 6 | * [AdaNorm: Adaptive Gradient Norm Correction based Optimizer for CNNs](https://arxiv.org/abs/2210.06364) 7 | * Implement `RotoGrad` optimizer (#124, #134) 8 | * [RotoGrad: Gradient Homogenization in Multitask Learning](https://arxiv.org/abs/2103.02631) 9 | * Implement `D-Adapt Adan` optimizer (#134) 10 | * Support `AdaNorm` variant (#133, #134) 11 | * AdaBelief 12 | * AdamP 13 | * AdamS 14 | * AdaPNM 15 | * diffGrad 16 | * Lamb 17 | * RAdam 18 | * Ranger 19 | * Adan 20 | * Support `AMSGrad` variant (#133, #134) 21 | * diffGrad 22 | * AdaFactor 23 | * Support `degenerated_to_sgd` (#133) 24 | * Ranger 25 | * Lamb 26 | 27 | ### Refactor 28 | 29 | * Rename `adamd_debias_term` to `adam_debias` (#133) 30 | * Merge the rectified version with the original (#133) 31 | * diffRGrad + diffGrad -> diffGrad 32 | * RaLamb + Lamb -> Lamb 33 | * now you can simply use with `rectify=True` 34 | 35 | ### Bug 36 | 37 | * Fix `previous_grad` deepcopy issue in Adan optimizer (#134) 38 | 39 | ### Diff 40 | 41 | [2.6.1...2.7.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.6.1...v2.7.0) 42 | -------------------------------------------------------------------------------- /docs/changelogs/v2.8.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement A2Grad optimizer (#136) 6 | * [Optimal Adaptive and Accelerated Stochastic Gradient Descent](https://arxiv.org/abs/1810.00553) 7 | * Implement Accelerated SGD optimizer (#137) 8 | * [Accelerating Stochastic Gradient Descent For Least Squares Regression](https://arxiv.org/abs/1704.08227) 9 | * Implement Adaptive SGD optimizer (#139) 10 | * [Adaptive Gradient Descent without Descent](https://arxiv.org/abs/1910.09529) 11 | * Implement SGDW optimizer (#139) 12 | * [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) 13 | * Implement Yogi optimizer (#140) 14 | * [Adaptive Methods for Nonconvex 
Optimization](https://papers.nips.cc/paper_files/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) 15 | * Implement SWATS optimizer (#141) 16 | * [Improving Generalization Performance by Switching from Adam to SGD](https://arxiv.org/abs/1712.07628) 17 | * Implement Fromage optimizer (#142) 18 | * [On the distance between two neural networks and the stability of learning](https://arxiv.org/abs/2002.03432) 19 | * Implement MSVAG optimizer (#143) 20 | * [Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients](https://arxiv.org/abs/1705.07774) 21 | * Implement AdaMod optimizer (#144) 22 | * [An Adaptive and Momental Bound Method for Stochastic Learning](https://arxiv.org/abs/1910.12249) 23 | * Implement AggMo optimizer (#145) 24 | * [Aggregated Momentum: Stability Through Passive Damping](https://arxiv.org/abs/1804.00325) 25 | * Implement QHAdam, QHM optimizers (#146) 26 | * [Quasi-hyperbolic momentum and Adam for deep learning](https://arxiv.org/abs/1810.06801) 27 | * Implement PID optimizer (#147) 28 | * [A PID Controller Approach for Stochastic Optimization of Deep Networks](http://www4.comp.polyu.edu.hk/~cslzhang/paper/CVPR18_PID.pdf) 29 | 30 | ### Bug 31 | 32 | * Fix `update` in Lion optimizer (#135) 33 | * Fix `momentum_buffer` in SGDP optimizer (#139) 34 | 35 | ### Diff 36 | 37 | [2.7.0...2.8.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.7.0...v2.8.0) 38 | -------------------------------------------------------------------------------- /docs/changelogs/v2.9.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement AdaMax optimizer (#148) 6 | * A variant of Adam based on the infinity norm 7 | * Implement Gravity optimizer (#151) 8 | * [a Kinematic Approach on Optimization in Deep Learning](https://arxiv.org/abs/2101.09192) 9 | * Implement AdaSmooth optimizer (#153) 10 | * [An Adaptive Learning Rate Method based on Effective Ratio](https://arxiv.org/abs/2204.00825v1) 11 | * Implement SRMM optimizer (#154) 12 | * [Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates](https://arxiv.org/abs/2201.01652) 13 | * Implement AvaGrad optimizer (#155) 14 | * [Domain-independent Dominance of Adaptive Methods](https://arxiv.org/abs/1912.01823) 15 | * Implement AdaShift optimizer (#157) 16 | * [Decorrelation and Convergence of Adaptive Learning Rate Methods](https://arxiv.org/abs/1810.00143v4) 17 | * Upgrade to D-Adaptation v3 (#158, #159) 18 | * Implement AdaDelta optimizer (#160) 19 | * [An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701v1) 20 | 21 | ### Docs 22 | 23 | * Fix readthedocs build issue (#156) 24 | * Move citations into table (#156) 25 | 26 | ### Refactor 27 | 28 | * Refactor validation logic (#149, #150) 29 | * Rename `amsbound`, `amsgrad` terms into `ams_bound` (#149) 30 | * Return gradient instead of the parameter, AGC. (#149) 31 | * Refactor duplicates (e.g. 
rectified step size, AMSBound, AdamD, AdaNorm, weight decay) into re-usable functions (#150) 32 | * Move `pytorch_optimizer.experimental` under `pytorch_optimizer.*.experimental` 33 | 34 | ### Diff 35 | 36 | [2.8.0...2.9.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.8.0...v2.9.0) 37 | -------------------------------------------------------------------------------- /docs/changelogs/v2.9.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Fix 4 | 5 | * fix weight decay in Ranger21 (#170) 6 | 7 | ## Diff 8 | 9 | [2.9.0...2.9.1](https://github.com/kozistr/pytorch_optimizer/compare/v2.9.0...v2.9.1) 10 | -------------------------------------------------------------------------------- /docs/changelogs/v3.0.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | The major version is updated! (`v2.12.0` -> `v3.0.0`) (#164) 4 | 5 | Many optimizers, learning rate schedulers, and objective functions are in `pytorch-optimizer`. 6 | Currently, `pytorch-optimizer` supports **67 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions**, and reached about 4 ~ 50K downloads / month (peak is 75K downloads / month)! 7 | 8 | The reason for updating the major version from `v2` to `v3` is that I think it's a good time to ship the recent implementations (the last update was about 7 months ago) and plan to pivot to new concepts like training utilities while maintaining the original features (e.g. optimizers). 9 | Also, rich test cases, benchmarks, and examples are on the list! 10 | 11 | Finally, thanks for using the `pytorch-optimizer`, and feel free to make any requests :) 12 | 13 | ### Feature 14 | 15 | * Implement `REX` lr scheduler. (#217, #222) 16 | * [Revisiting Budgeted Training with an Improved Schedule](https://arxiv.org/abs/2107.04197) 17 | * Implement `Aida` optimizer. (#220, #221) 18 | * [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273) 19 | * Implement `WSAM` optimizer. (#213, #216) 20 | * [Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term](https://arxiv.org/abs/2305.15817) 21 | * Implement `GaLore` optimizer. (#224, #228) 22 | * [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507) 23 | * Implement `Adalite` optimizer. (#225, #229) 24 | * Implement `bSAM` optimizer. (#212, #233) 25 | * [SAM as an Optimal Relaxation of Bayes](https://arxiv.org/abs/2210.01620) 26 | * Implement `Schedule-Free` optimizer. (#230, #233) 27 | * [Schedule-Free optimizers](https://github.com/facebookresearch/schedule_free) 28 | * Implement `EMCMC`. (#231, #233) 29 | * [Entropy-MCMC: Sampling from flat basins with ease](https://www.semanticscholar.org/paper/Entropy-MCMC%3A-Sampling-from-Flat-Basins-with-Ease-Li-Zhang/fd95de3f24fc4f955a6fe5719d38d1d06136e0cd) 30 | 31 | ### Fix 32 | 33 | * Fix SRMM to allow operation beyond memory_length. (#227) 34 | 35 | ### Dependency 36 | 37 | * Drop `Python 3.7` support officially. (#221) 38 | * Please check the [README](https://github.com/kozistr/pytorch_optimizer?tab=readme-ov-file#getting-started). 39 | * Update `bitsandbytes` to `0.43.0`. (#228) 40 | 41 | ### Docs 42 | 43 | * Add missing parameters in `Ranger21 optimizer` document. (#214, #215) 44 | * Fix `WSAM` optimizer paper link. 
(#219) 45 | 46 | ## Contributions 47 | 48 | thanks to @sdbds, @i404788 49 | 50 | ## Diff 51 | 52 | * from the previous major version : [2.0.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.0.0...v3.0.0) 53 | * from the previous version: [2.12.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.12.0...v3.0.0) 54 | -------------------------------------------------------------------------------- /docs/changelogs/v3.0.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `FAdam` optimizer. (#241, #242) 6 | * [Adam is a natural gradient optimizer using diagonal empirical Fisher information](https://arxiv.org/abs/2405.12807) 7 | * Tweak `AdaFactor` optimizer. (#236, #243) 8 | * support not-using-first-momentum when beta1 is not given 9 | * default dtype for first momentum to `bfloat16` 10 | * clip second momentum to 0.999 11 | * Implement `GrokFast` optimizer. (#244, #245) 12 | * [Accelerated Grokking by Amplifying Slow Gradients](https://arxiv.org/abs/2405.20233) 13 | 14 | ### Bug 15 | 16 | * Wrong typing of reg_noise. (#239, #240) 17 | * Lookahead`s param_groups attribute is not loaded from checkpoint. (#237, #238) 18 | 19 | ## Contributions 20 | 21 | thanks to @michaldyczko 22 | -------------------------------------------------------------------------------- /docs/changelogs/v3.0.2.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `WSD` LR Scheduler. (#247, #248) 6 | * [Warmup-Stable-Decay LR Scheduler](https://arxiv.org/abs/2404.06395) 7 | * Add more Pytorch built-in lr schedulers. (#248) 8 | * Implement `Kate` optimizer. (#249, #251) 9 | * [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648) 10 | * Implement `StableAdamW` optimizer. (#250, #252) 11 | * [Stable and low-precision training for large-scale vision-language models](https://arxiv.org/abs/2304.13013) 12 | * Implement `AdamMini` optimizer. (#246, #253) 13 | * [Use Fewer Learning Rates To Gain More](https://arxiv.org/abs/2406.16793) 14 | 15 | ### Refactor 16 | 17 | * Refactor `Chebyschev` lr scheduler modules. (#248) 18 | * Rename `get_chebyshev_lr` to `get_chebyshev_lr_lambda`. 19 | * Rename `get_chebyshev_schedule` to `get_chebyshev_perm_steps`. 20 | * Call `get_chebyshev_schedule` function to get `LamdbaLR` scheduler object. 21 | * Refactor with `ScheduleType`. (#248) 22 | -------------------------------------------------------------------------------- /docs/changelogs/v3.1.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `AdaLomo` optimizer. (#258) 6 | * [Low-memory Optimization with Adaptive Learning Rate](https://arxiv.org/abs/2310.10195) 7 | * Support `Q-GaLore` optimizer. (#258) 8 | * [Q-GaLore: Quantized GaLore with INT4 Projection and Layer-Adaptive Low-Rank Gradients.](https://arxiv.org/abs/2407.08296) 9 | * you can use by `optimizer = load_optimizer('q_galore_adamw8bit')` 10 | * Support more bnb optimizers. (#258) 11 | * `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`. 12 | * Improve `power_iteration()` speed up to 40%. (#259) 13 | * Improve `reg_noise()` (E-MCMC) speed up to 120%. (#260) 14 | * Support `disable_lr_scheduler` parameter for `Ranger21` optimizer to disable built-in learning rate scheduler. 
(#261) 15 | 16 | ### Refactor 17 | 18 | * Refactor `AdamMini` optimizer. (#258) 19 | * Deprecate optional dependency, `bitsandbytes`. (#258) 20 | * Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258) 21 | * Refactor `shampoo_utils.py`. (#259) 22 | * Add `debias`, `debias_adam` methods in `BaseOptimizer`. (#261) 23 | * Refactor to use `BaseOptimizer` only, not inherit multiple classes. (#261) 24 | 25 | ### Bug 26 | 27 | * Fix several bugs in `AdamMini` optimizer. (#257) 28 | 29 | ## Contributions 30 | 31 | thanks to @sdbds 32 | -------------------------------------------------------------------------------- /docs/changelogs/v3.1.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `TRAC` optimizer. (#263) 6 | * [Fast TRAC: A Parameter-Free Optimizer for Lifelong Reinforcement Learning](https://arxiv.org/abs/2405.16642) 7 | * Support `AdamW` optimizer via `create_optimizer()`. (#263) 8 | * Implement `AdamG` optimizer. (#264, #265) 9 | * [Towards Stability of Parameter-free Optimization](https://arxiv.org/abs/2405.04376) 10 | 11 | ### Bug 12 | 13 | * Handle the optimizers that only take the `model` instead of the parameters in `create_optimizer()`. (#263) 14 | * Move the variable to the same device with the parameter. (#266, #267) 15 | -------------------------------------------------------------------------------- /docs/changelogs/v3.1.2.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `AdEMAMix` optimizer. (#272) 6 | * [THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER](https://arxiv.org/pdf/2409.03137) 7 | 8 | ### Bug 9 | 10 | * Add `**kwargs` to the parameters for dummy placeholder. (#270, #271) 11 | -------------------------------------------------------------------------------- /docs/changelogs/v3.2.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `SOAP` optimizer. (#275) 6 | * [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321) 7 | * Support `AdEMAMix` variants. (#276) 8 | * `bnb_ademamix8bit`, `bnb_ademamix32bit`, `bnb_paged_ademamix8bit`, `bnb_paged_ademamix32bit` 9 | * Support 8/4bit, fp8 optimizers. (#208, #281) 10 | * `torchao_adamw8bit`, `torchao_adamw4bit`, `torchao_adamwfp8`. 11 | * Support a module-name-level (e.g. `LayerNorm`) weight decay exclusion for `get_optimizer_parameters`. (#282, #283) 12 | * Implement `CPUOffloadOptimizer`, which offloads optimizer to CPU for single-GPU training. (#284) 13 | * Support a regex-based filter for searching names of optimizers, lr schedulers, and loss functions. 14 | 15 | ### Bug 16 | 17 | * Fix `should_grokfast` condition when initialization. (#279, #280) 18 | 19 | ### Contributions 20 | 21 | thanks to @Vectorrent 22 | -------------------------------------------------------------------------------- /docs/changelogs/v3.3.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Support `PaLM` variant for `ScheduleFreeAdamW` optimizer. (#286, #288) 6 | * you can use this feature by setting `use_palm` to `True`. 7 | * Implement `ADOPT` optimizer. (#289, #290) 8 | * [Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853) 9 | * Implement `FTRL` optimizer. 
(#291) 10 | * [Follow The Regularized Leader](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf) 11 | * Implement `Cautious optimizer` feature. (#294) 12 | * [Improving Training with One Line of Code](https://arxiv.org/pdf/2411.16085v1) 13 | * you can use it by setting `cautious=True` for `Lion`, `AdaFactor` and `AdEMAMix` optimizers. 14 | * Improve the stability of `ADOPT` optimizer. (#294) 15 | * [Note](https://github.com/iShohei220/adopt?tab=readme-ov-file#update-on-nov-22-2024) 16 | * Support a new projection type `random` for `GaLoreProjector`. (#294) 17 | * Implement `DeMo` optimizer. (#300, #301) 18 | * [Decoupled Momentum Optimization](https://arxiv.org/abs/2411.19870) 19 | * Implement `Muon` optimizer. (#302) 20 | * [MomentUm Orthogonalized by Newton-schulz](https://github.com/KellerJordan/Muon) 21 | * Implement `ScheduleFreeRAdam` optimizer. (#304) 22 | * Implement `LaProp` optimizer. (#304) 23 | * [Separating Momentum and Adaptivity in Adam](https://arxiv.org/abs/2002.04839) 24 | * Support `Cautious` variant to `LaProp`, `AdamP`, `Adopt` optimizers. (#304). 25 | 26 | ### Refactor 27 | 28 | * Big refactoring, removing direct import from `pytorch_optimizer.*`. 29 | * I removed some methods not to directly import from it from `pytorch_optimzier.*` because they're probably not used frequently and actually not an optimizer rather utils only used for specific optimizers. 30 | * `pytorch_optimizer.[Shampoo stuff]` -> `pytorch_optimizer.optimizers.shampoo_utils.[Shampoo stuff]`. 31 | * `shampoo_utils` like `Graft`, `BlockPartitioner`, `PreConditioner`, etc. You can check the details [here](https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/optimizer/shampoo_utils.py). 32 | * `pytorch_optimizer.GaLoreProjector` -> `pytorch_optimizer.optimizers.galore.GaLoreProjector`. 33 | * `pytorch_optimizer.gradfilter_ema` -> `pytorch_optimizer.optimizers.grokfast.gradfilter_ema`. 34 | * `pytorch_optimizer.gradfilter_ma` -> `pytorch_optimizer.optimizers.grokfast.gradfilter_ma`. 35 | * `pytorch_optimizer.l2_projection` -> `pytorch_optimizer.optimizers.alig.l2_projection`. 36 | * `pytorch_optimizer.flatten_grad` -> `pytorch_optimizer.optimizers.pcgrad.flatten_grad`. 37 | * `pytorch_optimizer.un_flatten_grad` -> `pytorch_optimizer.optimizers.pcgrad.un_flatten_grad`. 38 | * `pytorch_optimizer.reduce_max_except_dim` -> `pytorch_optimizer.optimizers.sm3.reduce_max_except_dim`. 39 | * `pytorch_optimizer.neuron_norm` -> `pytorch_optimizer.optimizers.nero.neuron_norm`. 40 | * `pytorch_optimizer.neuron_mean` -> `pytorch_optimizer.optimizers.nero.neuron_mean`. 41 | 42 | ### Docs 43 | 44 | * Add more visualizations. (#297) 45 | 46 | ### Bug 47 | 48 | * Add optimizer parameter to `PolyScheduler` constructor. (#295) 49 | 50 | ### Contributions 51 | 52 | thanks to @tanganke 53 | -------------------------------------------------------------------------------- /docs/changelogs/v3.3.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Support `Cautious` variant to `AdaShift` optimizer. (#310) 6 | * Save the state of the `Lookahead` optimizer too. (#310) 7 | * Implement `APOLLO` optimizer. 
(#311, #312) 8 | * [SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270) 9 | * Rename the `Apollo` (`An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization`) optimizer name to `ApolloDQN` not to overlap with the new optimizer name `APOLLO`. (#312) 10 | * Implement `MARS` optimizer. (#313, #314) 11 | * [Unleashing the Power of Variance Reduction for Training Large Models](https://arxiv.org/abs/2411.10438) 12 | * Support `Cautious` variant to `MARS` optimizer. (#314) 13 | 14 | ### Bug 15 | 16 | * Fix `bias_correction` in `AdamG` optimizer. (#305, #308) 17 | * Fix a potential bug when loading the state for `Lookahead` optimizer. (#306, #310) 18 | 19 | ### Docs 20 | 21 | * Add more visualizations. (#310, #314) 22 | 23 | ### Contributions 24 | 25 | thanks to @Vectorrent 26 | -------------------------------------------------------------------------------- /docs/changelogs/v3.3.2.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `SGDSaI` optimizer. (#315, #316) 6 | * [No More Adam: Learning Rate Scaling at Initialization is All You Need](https://arxiv.org/abs/2412.11768) 7 | 8 | ### Bug 9 | 10 | * Clone `exp_avg` before calling `apply_cautious` not to mask `exp_avg`. (#316) 11 | -------------------------------------------------------------------------------- /docs/changelogs/v3.3.3.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `Grams` optimizer. (#317, #318) 6 | * [Grams: Gradient Descent with Adaptive Momentum Scaling](https://arxiv.org/abs/2412.17107) 7 | * Support `stable_adamw` variant for `ADOPT` and `AdEMAMix` optimizer. (#321) 8 | * `optimizer = ADOPT(model.parameters(), ..., stable_adamw=True)` 9 | * Implement an experimental optimizer `Ranger25` (not tested). (#321) 10 | * mixing `ADOPT + AdEMAMix + StableAdamW + Cautious + RAdam` optimizers. 11 | * Implement `OrthoGrad` optimizer. (#321) 12 | * [Grokking at the Edge of Numerical Stability](https://arxiv.org/abs/2501.04697) 13 | * Support `Adam-Atan2` feature for `Prodigy` optimizer when `eps` is None. (#321) 14 | * [Scaling Exponents Across Parameterizations and Optimizers](https://arxiv.org/abs/2407.05872) 15 | -------------------------------------------------------------------------------- /docs/changelogs/v3.3.4.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Support `OrthoGrad` feature for `create_optimizer()`. (#324) 6 | * Enhanced flexibility for the `optimizer` parameter in `Lookahead`, `TRAC`, and `OrthoGrad` optimizers. (#324) 7 | * Now supports both torch.optim.Optimizer instances and classes 8 | * You can now use `Lookahead` optimizer in two ways. 9 | * `Lookahead(AdamW(model.parameters(), lr=1e-3), k=5, alpha=0.5)` 10 | * `Lookahead(AdamW, k=5, alpha=0.5, params=model.parameters())` 11 | * Implement `SPAM` optimizer. (#324) 12 | * [Spike-Aware Adam with Momentum Reset for Stable LLM Training](https://arxiv.org/abs/2501.06842) 13 | * Implement `TAM`, and `AdaTAM` optimizers. 
(#325) 14 | * [Torque-Aware Momentum](https://arxiv.org/abs/2412.18790) 15 | -------------------------------------------------------------------------------- /docs/changelogs/v3.4.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `FOCUS` optimizer. (#330, #331) 6 | * [First Order Concentrated Updating Scheme](https://arxiv.org/abs/2501.12243) 7 | * Implement `PSGD Kron` optimizer. (#336, #337) 8 | * [preconditioned stochastic gradient descent w/ Kron pre-conditioner](https://arxiv.org/abs/1512.04202) 9 | * Implement `EXAdam` optimizer. (#338, #339) 10 | * [The Power of Adaptive Cross-Moments](https://arxiv.org/abs/2412.20302) 11 | 12 | ### Update 13 | 14 | * Support `OrthoGrad` variant to `Ranger25`. (#332) 15 | * `Ranger25` optimizer is my experimental-crafted optimizer, which mixes lots of optimizer variants such as `ADOPT` + `AdEMAMix` + `Cautious` + `StableAdamW` + `Adam-Atan2` + `OrthoGrad`. 16 | 17 | ### Fix 18 | 19 | * Add the missing `state` property in `OrthoGrad` optimizer. (#326, #327) 20 | * Add the missing `state_dict`, and `load_state_dict` methods to `TRAC` and `OrthoGrad` optimizers. (#332) 21 | * Skip when the gradient is sparse in `OrthoGrad` optimizer. (#332) 22 | * Support alternative precision training in `SOAP` optimizer. (#333) 23 | * Store SOAP condition matrices as the dtype of their parameters. (#335) 24 | 25 | ### Contributions 26 | 27 | thanks to @Vectorrent, @kylevedder 28 | -------------------------------------------------------------------------------- /docs/changelogs/v3.4.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Support `GCSAM` optimizer. (#343, #344) 6 | * [Gradient Centralized Sharpness Aware Minimization](https://arxiv.org/abs/2501.11584) 7 | * you can use it from `SAM` optimizer by setting `use_gc=True`. 8 | * Support `LookSAM` optimizer. (#343, #344) 9 | * [Towards Efficient and Scalable Sharpness-Aware Minimization](https://arxiv.org/abs/2203.02714) 10 | 11 | ### Update 12 | 13 | * Support alternative precision training for `Shampoo` optimizer. (#339) 14 | * Add more features to and tune `Ranger25` optimizer. (#340) 15 | * `AGC` + `Lookahead` variants 16 | * change default beta1, beta2 to 0.95 and 0.98 respectively 17 | * Skip adding `Lookahead` wrapper in case of `Ranger*` optimizers, which already have it in `create_optimizer()`. (#340) 18 | * Improved optimizer visualization. (#345) 19 | * Rename `pytorch_optimizer.optimizer.gc` to `pytorch_optimizer.optimizer.gradient_centralization` to avoid possible conflict with Python built-in function `gc`. (#349) 20 | 21 | ### Bug 22 | 23 | * Fix to update exp_avg_sq after calculating the denominator in `ADOPT` optimizer. (#346, #347) 24 | 25 | ### Docs 26 | 27 | * Update the visualizations. (#340) 28 | 29 | ### Contributions 30 | 31 | thanks to @AidinHamedi 32 | -------------------------------------------------------------------------------- /docs/changelogs/v3.4.2.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `SCION` optimizer. (#348, #352) 6 | * [Training Deep Learning Models with Norm-Constrained LMOs](https://arxiv.org/abs/2502.07529) 7 | 8 | ### Update 9 | 10 | * Update ScheduleFreeSGD, AdamW, RAdam optimizers with the latest. (#351, #353) 11 | * Remove `use_palm` variant in ScheduleFree optimizer due to instability. 
(#353) 12 | * Ranger25 optimizer. (#353) 13 | 14 | ### Fix 15 | 16 | * Remove `weight decouple` parameter in ScheduleFree optimizers. (#351, #353) 17 | 18 | ### Docs 19 | 20 | * Fix `AliG` optimizer visualization. (#350) 21 | 22 | ### Contributions 23 | 24 | thanks to @AidinHamedi, @hatonosuke 25 | -------------------------------------------------------------------------------- /docs/changelogs/v3.5.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Support `StableSPAM` optimizer. (#358, #359) 6 | * [How to Train in 4-Bit More Stably than 16-Bit Adam](https://arxiv.org/abs/2502.17055?) 7 | * Support `ScheduleFreeWrapper`. (#334, #360) 8 | * Implement `AdaGC` optimizer. (#364, #366) 9 | * [Improving Training Stability for Large Language Model Pretraining](https://arxiv.org/abs/2502.11034) 10 | * Implement `Simplified-Ademamix` optimizer. (#364, #366) 11 | * [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431) 12 | * Support `Ackley` function for testing optimization algorithms. 13 | 14 | ### Update 15 | 16 | * Update Muon optimizer. (#355, #356) 17 | * support decoupled weight decay. 18 | * adjust default hyperparameters the same as the original implementation. 19 | * support adjusted lr from the Moonlight. you can use it by setting `use_adjusted_lr=True`. 20 | * Tune the performance of the coupled Newton iteration method by 5% increase. (#360) 21 | * Update `SCION` optimizer. (#361) 22 | * add `scale` parameter. 23 | * update `get_lmo_direction`. 24 | 25 | ### Fix 26 | 27 | * bias_correction2 in ScheduleFreeRAdam optimizer. (#354) 28 | * potential bug in SPAM optimizer. (#365) 29 | * initialize the `z` state within the `step()` of the ScheduleFreeWrapper. (#363, #366) 30 | -------------------------------------------------------------------------------- /docs/changelogs/v3.5.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `ScionLight` optimizer. (#369) 6 | 7 | ### Update 8 | 9 | * Update `SCION` optimizer based on the official implementation. (#369) 10 | 11 | ### Fix 12 | 13 | * Correct the learning rate ratio in `Muon` optimizer properly. (#371, #372, #373) 14 | -------------------------------------------------------------------------------- /docs/changelogs/v3.6.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `Fira` optimizer. (#376) 6 | * [Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623) 7 | * Implement `RACS` and `Alice` optimizers. (#376) 8 | * [Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension](https://arxiv.org/abs/2502.07752) 9 | * Implement `VSGD` optimizer. (#377, #378) 10 | * [Variational Stochastic Gradient Descent for Deep Neural Networks](https://openreview.net/forum?id=xu4ATNjcdy) 11 | * Enable training with complex parameters. (#370, #380) 12 | * will raise `NoComplexParameterError` for unsupported optimizers, due to its design or not-yet-implemented. 13 | * Support `maximize` parameter. (#370, #380) 14 | * `maximize`: maximize the objective with respect to the params, instead of minimizing. 15 | * Implement `copy_stochastic()` method. (#381) 16 | 17 | ### Update 18 | 19 | * Support 2D< Tensor for `RACS` and `Alice` optimizers. 
-------------------------------------------------------------------------------- /docs/changelogs/v3.6.0.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement `Fira` optimizer. (#376) 6 | * [Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623) 7 | * Implement `RACS` and `Alice` optimizers. (#376) 8 | * [Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension](https://arxiv.org/abs/2502.07752) 9 | * Implement `VSGD` optimizer. (#377, #378) 10 | * [Variational Stochastic Gradient Descent for Deep Neural Networks](https://openreview.net/forum?id=xu4ATNjcdy) 11 | * Enable training with complex parameters. (#370, #380) 12 | * a `NoComplexParameterError` is raised for optimizers that do not support complex parameters, either by design or because support is not yet implemented. 13 | * Support `maximize` parameter. (#370, #380) 14 | * `maximize`: maximize the objective with respect to the parameters instead of minimizing it (see the usage sketch below). 15 | * Implement `copy_stochastic()` method. (#381) 16 | 17 | ### Update 18 | 19 | * Support tensors with more than two dimensions in the `RACS` and `Alice` optimizers. (#380) 20 | * Remove the auxiliary variants from the default parameters of the optimizers and change the names of the related states and parameters. (#380) 21 | * `use_gc`, `adanorm`, `cautious`, `stable_adamw`, and `adam_debias` are affected. 22 | * You can still use these variants by passing the parameters via `**kwargs`. 23 | * Notably, for the `adanorm` variant, you need to pass the `adanorm` parameter (and `adanorm_r` for the `r` option), and the state name changes from `exp_avg_norm` to `exp_avg_adanorm`. 24 | * Refactor the `reset()` method into `init_group()` in the `BaseOptimizer` class. (#380) 25 | * Refactor the `SAM` optimizer family. (#380) 26 | * Gather the `AdamP` and `SGDP` utilities into `pytorch_optimizer.optimizer.adamp.*`. (#381) 27 | * `pytorch_optimizer.optimizer.sgdp.SGDP` to `pytorch_optimizer.optimizer.adamp.SGDP` 28 | * `pytorch_optimizer.optimizer.util.projection` to `pytorch_optimizer.optimizer.adamp.projection` 29 | * `pytorch_optimizer.optimizer.util.cosine_similarity_by_view` to `pytorch_optimizer.optimizer.adamp.cosine_similarity_by_view` 30 | * Remove `channel_view()` and `layer_view()` from `pytorch_optimizer.optimizer.util`. (#381) 31 | 32 | ### Fix 33 | 34 | * Fix shape mismatch issues in the `GaLore` projection for the `reverse_std`, `right`, and `full` projection types. (#376) 35 | -------------------------------------------------------------------------------- /docs/changelogs/v3.6.1.md: -------------------------------------------------------------------------------- 1 | ## Change Log 2 | 3 | ### Feature 4 | 5 | * Implement more cooldown types for the WSD learning rate scheduler. (#382, #386) 6 | * Implement `AdamWSN` optimizer. (#387, #389) 7 | * [Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees](https://arxiv.org/abs/2411.07120) 8 | 9 | ### Fix 10 | 11 | * Fix to use the `momentum buffer` instead of the gradient when calculating the LMO. (#385) 12 |
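As a quick illustration of the `maximize` flag introduced in v3.6.0, here is a hedged sketch. The choice of `AdamP` and the toy "reward" objective are assumptions made for illustration only; any optimizer in the library that accepts the flag should behave the same way.

```python
# Hedged sketch: with maximize=True the optimizer takes gradient-ascent steps,
# i.e. it increases the objective instead of decreasing it.
import torch
from pytorch_optimizer import AdamP

policy = torch.nn.Linear(8, 2)
optimizer = AdamP(policy.parameters(), lr=1e-3, maximize=True)

reward = policy(torch.randn(32, 8)).mean()  # a quantity we want to increase
reward.backward()
optimizer.step()  # ascends because maximize=True
optimizer.zero_grad()
```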
-------------------------------------------------------------------------------- /docs/javascripts/tables.js: -------------------------------------------------------------------------------- 1 | app.location$.subscribe(function() { 2 | var tables = document.querySelectorAll("article table") 3 | tables.forEach(function(table) { 4 | new Tablesort(table) 5 | }) 6 | }) 7 | -------------------------------------------------------------------------------- /docs/loss.md: -------------------------------------------------------------------------------- 1 | # Loss Function 2 | 3 | ::: pytorch_optimizer.bi_tempered_logistic_loss 4 | :docstring: 5 | 6 | ::: pytorch_optimizer.BiTemperedLogisticLoss 7 | :docstring: 8 | :members: 9 | 10 | ::: pytorch_optimizer.BinaryBiTemperedLogisticLoss 11 | :docstring: 12 | :members: 13 | 14 | ::: pytorch_optimizer.BCELoss 15 | :docstring: 16 | :members: 17 | 18 | ::: pytorch_optimizer.SoftF1Loss 19 | :docstring: 20 | :members: 21 | 22 | ::: pytorch_optimizer.FocalLoss 23 | :docstring: 24 | :members: 25 | 26 | ::: pytorch_optimizer.FocalCosineLoss 27 | :docstring: 28 | :members: 29 | 30 | ::: pytorch_optimizer.BCEFocalLoss 31 | :docstring: 32 | :members: 33 | 34 | ::: pytorch_optimizer.FocalTverskyLoss 35 | :docstring: 36 | :members: 37 | 38 | ::: pytorch_optimizer.soft_jaccard_score 39 | :docstring: 40 | 41 | ::: pytorch_optimizer.JaccardLoss 42 | :docstring: 43 | :members: 44 | 45 | ::: pytorch_optimizer.LDAMLoss 46 | :docstring: 47 | :members: 48 | 49 | ::: pytorch_optimizer.LovaszHingeLoss 50 | :docstring: 51 | :members: 52 | 53 | ::: pytorch_optimizer.TverskyLoss 54 | :docstring: 55 | :members: 56 | -------------------------------------------------------------------------------- /docs/lr_scheduler.md: -------------------------------------------------------------------------------- 1 | # Learning Rate Scheduler 2 | 3 | ::: pytorch_optimizer.deberta_v3_large_lr_scheduler 4 | :docstring: 5 | :members: 6 | 7 | ::: pytorch_optimizer.get_chebyshev_schedule 8 | :docstring: 9 | :members: 10 | 11 | ::: pytorch_optimizer.get_wsd_schedule 12 | :docstring: 13 | :members: 14 | 15 | ::: pytorch_optimizer.CosineAnnealingWarmupRestarts 16 | :docstring: 17 | :members: 18 | 19 | ::: pytorch_optimizer.LinearScheduler 20 | :docstring: 21 | :members: 22 | 23 | ::: pytorch_optimizer.CosineScheduler 24 | :docstring: 25 | :members: 26 | 27 | ::: pytorch_optimizer.PolyScheduler 28 | :docstring: 29 | :members: 30 | 31 | ::: pytorch_optimizer.ProportionScheduler 32 | :docstring: 33 | :members: 34 | 35 | ::: pytorch_optimizer.REXScheduler 36 | :docstring: 37 | :members: 38 | -------------------------------------------------------------------------------- /docs/qa.md: -------------------------------------------------------------------------------- 1 | # Frequently asked questions 2 | 3 | ## Q1) The SophiaH and AdaHessian optimizers give ```RuntimeError: ~ tensors does not require grad and does not have a grad_fn``` in `compute_hutchinson_hessian()`. 4 | 5 | `create_graph` must be set to `True` when calling `backward()`. Here's [an example](https://github.com/kozistr/pytorch_optimizer/issues/194#issuecomment-1723167466), and a minimal sketch is included at the end of this page. 6 | 7 | ## Q2) A memory leak happens when using the SophiaH or AdaHessian optimizers. 8 | 9 | `torch.autograd.grad` with complex gradient flows sometimes leads to memory leaks, and you might encounter an OOM issue. [related issue](https://github.com/kozistr/pytorch_optimizer/issues/278) 10 | 11 | ## Q3) How to run visualizations? 12 | 13 | Run `make visualize` or `python3 -m examples.visualize_optimizers` from the project root. 14 |
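To make the Q1 answer concrete, here is a minimal, hedged sketch; the toy model, data, and hyperparameters are made up for illustration and are not taken from the linked issue.

```python
# Hedged sketch for Q1: SophiaH / AdaHessian build Hessian estimates via
# torch.autograd, so backward() must keep the graph with create_graph=True.
import torch
from pytorch_optimizer import SophiaH

model = torch.nn.Linear(16, 1)
optimizer = SophiaH(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward(create_graph=True)  # required for the Hutchinson Hessian estimate
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```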
-------------------------------------------------------------------------------- /docs/util.md: -------------------------------------------------------------------------------- 1 | # Utilization 2 | 3 | ::: pytorch_optimizer.get_supported_optimizers 4 | :docstring: 5 | :members: 6 | 7 | ::: pytorch_optimizer.get_supported_lr_schedulers 8 | :docstring: 9 | :members: 10 | 11 | ::: pytorch_optimizer.get_supported_loss_functions 12 | :docstring: 13 | :members: 14 | 15 | ::: pytorch_optimizer.optimizer.utils.CPUOffloadOptimizer 16 | :docstring: 17 | :members: 18 | 19 | ::: pytorch_optimizer.optimizer.utils.is_valid_parameters 20 | :docstring: 21 | :members: 22 | 23 | ::: pytorch_optimizer.optimizer.utils.has_overflow 24 | :docstring: 25 | :members: 26 | 27 | ::: pytorch_optimizer.optimizer.utils.to_real 28 | :docstring: 29 | :members: 30 | 31 | ::: pytorch_optimizer.optimizer.utils.normalize_gradient 32 | :docstring: 33 | :members: 34 | 35 | ::: pytorch_optimizer.optimizer.utils.clip_grad_norm 36 | :docstring: 37 | :members: 38 | 39 | ::: pytorch_optimizer.optimizer.utils.unit_norm 40 | :docstring: 41 | :members: 42 | 43 | ::: pytorch_optimizer.optimizer.utils.disable_running_stats 44 | :docstring: 45 | :members: 46 | 47 | ::: pytorch_optimizer.optimizer.utils.enable_running_stats 48 | :docstring: 49 | :members: 50 | 51 | ::: pytorch_optimizer.optimizer.utils.get_global_gradient_norm 52 | :docstring: 53 | :members: 54 | 55 | ::: pytorch_optimizer.optimizer.utils.reg_noise 56 | :docstring: 57 | :members: 58 | 59 | ::: pytorch_optimizer.optimizer.utils.copy_stochastic 60 | :docstring: 61 | :members: 62 | -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_ADOPT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_ADOPT.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_APOLLO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_APOLLO.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_ASGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_ASGD.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AccSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AccSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdEMAMix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdEMAMix.png --------------------------------------------------------------------------------
/docs/visualizations/rastrigin_AdaBelief.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaBelief.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaBound.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaBound.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaDelta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaDelta.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaFactor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaFactor.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaGC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaGC.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaHessian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaHessian.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaMax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaMax.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaMod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaMod.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaNorm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaNorm.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaPNM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaPNM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaShift.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaShift.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaSmooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaSmooth.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdaTAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdaTAM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Adai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Adai.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Adalite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Adalite.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Adam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Adam.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdamG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdamG.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdamMini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdamMini.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdamP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdamP.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdamS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdamS.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AdamW.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AdamW.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Adan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Adan.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AggMo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AggMo.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Aida.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Aida.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AliG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AliG.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Amos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Amos.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_ApolloDQN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_ApolloDQN.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_AvaGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_AvaGrad.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_BSAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_BSAM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_CAME.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_CAME.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_DAdaptAdaGrad.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_DAdaptAdaGrad.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_DAdaptAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_DAdaptAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_DAdaptAdan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_DAdaptAdan.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_DAdaptLion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_DAdaptLion.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_DAdaptSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_DAdaptSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_DiffGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_DiffGrad.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_EXAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_EXAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_FAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_FAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_FOCUS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_FOCUS.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_FTRL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_FTRL.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Fira.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Fira.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Fromage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Fromage.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_GaLore.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_GaLore.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Grams.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Grams.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Gravity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Gravity.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_GrokFastAdamW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_GrokFastAdamW.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Kate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Kate.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Kron.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Kron.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_LARS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_LARS.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_LaProp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_LaProp.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Lamb.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Lamb.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Lion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Lion.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_MADGRAD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_MADGRAD.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_MARS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_MARS.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_MSVAG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_MSVAG.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Nero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Nero.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_NovoGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_NovoGrad.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_PAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_PAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_PID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_PID.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_PNM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_PNM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Prodigy.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Prodigy.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_QHAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_QHAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_QHM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_QHM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_RACS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_RACS.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_RAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_RAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Ranger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Ranger.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Ranger21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Ranger21.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Ranger25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Ranger25.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SCION.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SCION.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SCIONLight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SCIONLight.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SGD.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SGD.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SGDP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SGDP.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SGDSaI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SGDSaI.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SGDW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SGDW.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SM3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SM3.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SOAP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SOAP.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SPAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SPAM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SRMM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SRMM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SWATS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SWATS.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_ScalableShampoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_ScalableShampoo.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_ScheduleFreeAdamW.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_ScheduleFreeAdamW.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_ScheduleFreeRAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_ScheduleFreeRAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_ScheduleFreeSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_ScheduleFreeSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Shampoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Shampoo.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SignSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SignSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SimplifiedAdEMAMix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SimplifiedAdEMAMix.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_SophiaH.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_SophiaH.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_StableAdamW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_StableAdamW.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_StableSPAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_StableSPAM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_TAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_TAM.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Tiger.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Tiger.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_VSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_VSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rastrigin_Yogi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rastrigin_Yogi.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_ADOPT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_ADOPT.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_APOLLO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_APOLLO.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_ASGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_ASGD.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AccSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AccSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdEMAMix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdEMAMix.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaBelief.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaBelief.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaBound.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaBound.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaDelta.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaDelta.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaFactor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaFactor.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaGC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaGC.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaHessian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaHessian.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaMax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaMax.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaMod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaMod.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaNorm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaNorm.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaPNM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaPNM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaShift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaShift.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaSmooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaSmooth.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdaTAM.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdaTAM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Adai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Adai.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Adalite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Adalite.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Adam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Adam.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdamG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdamG.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdamMini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdamMini.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdamP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdamP.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdamS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdamS.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AdamW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AdamW.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Adan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Adan.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AggMo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AggMo.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Aida.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Aida.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AliG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AliG.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Amos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Amos.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_ApolloDQN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_ApolloDQN.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_AvaGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_AvaGrad.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_BSAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_BSAM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_CAME.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_CAME.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_DAdaptAdaGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_DAdaptAdaGrad.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_DAdaptAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_DAdaptAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_DAdaptAdan.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_DAdaptAdan.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_DAdaptLion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_DAdaptLion.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_DAdaptSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_DAdaptSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_DiffGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_DiffGrad.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_EXAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_EXAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_FAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_FAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_FOCUS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_FOCUS.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_FTRL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_FTRL.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Fira.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Fira.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Fromage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Fromage.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_GaLore.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_GaLore.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Grams.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Grams.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Gravity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Gravity.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_GrokFastAdamW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_GrokFastAdamW.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Kate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Kate.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Kron.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Kron.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_LARS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_LARS.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_LaProp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_LaProp.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Lamb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Lamb.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Lion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Lion.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_MADGRAD.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_MADGRAD.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_MARS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_MARS.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_MSVAG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_MSVAG.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Nero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Nero.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_NovoGrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_NovoGrad.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_PAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_PAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_PID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_PID.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_PNM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_PNM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Prodigy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Prodigy.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_QHAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_QHAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_QHM.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_QHM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_RACS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_RACS.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_RAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_RAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Ranger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Ranger.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Ranger21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Ranger21.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Ranger25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Ranger25.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SCION.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SCION.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SCIONLight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SCIONLight.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SGD.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SGDP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SGDP.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SGDSaI.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SGDSaI.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SGDW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SGDW.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SM3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SM3.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SOAP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SOAP.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SPAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SPAM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SRMM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SRMM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SWATS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SWATS.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_ScalableShampoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_ScalableShampoo.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_ScheduleFreeAdamW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_ScheduleFreeAdamW.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_ScheduleFreeRAdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_ScheduleFreeRAdam.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_ScheduleFreeSGD.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_ScheduleFreeSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Shampoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Shampoo.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SignSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SignSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SimplifiedAdEMAMix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SimplifiedAdEMAMix.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_SophiaH.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_SophiaH.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_StableAdamW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_StableAdamW.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_StableSPAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_StableSPAM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_TAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_TAM.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Tiger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Tiger.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_VSGD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_VSGD.png -------------------------------------------------------------------------------- /docs/visualizations/rosenbrock_Yogi.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/docs/visualizations/rosenbrock_Yogi.png -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/examples/__init__.py -------------------------------------------------------------------------------- /examples/pytorch_lightning_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from torch import nn 6 | from torch.optim import AdamW 7 | from torch.utils.data import DataLoader 8 | from torchvision.datasets import MNIST 9 | from torchvision.transforms import ToTensor 10 | 11 | from pytorch_optimizer import Lookahead 12 | 13 | 14 | class LitAutoEncoder(pl.LightningModule): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3)) 19 | self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28)) 20 | 21 | def training_step(self, batch, batch_idx): 22 | x, y = batch 23 | x = x.view(x.size(0), -1) 24 | 25 | z = self.encoder(x) 26 | x_hat = self.decoder(z) 27 | 28 | loss = nn.functional.mse_loss(x_hat, x) 29 | 30 | self.log('train_loss', loss) 31 | 32 | return loss 33 | 34 | def configure_optimizers(self): 35 | return Lookahead(AdamW(self.parameters(), lr=1e-3), k=5, alpha=0.5) 36 | 37 | 38 | def main(): 39 | train_dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) 40 | train_loader = DataLoader(train_dataset) 41 | 42 | autoencoder = LitAutoEncoder() 43 | autoencoder.train() 44 | 45 | if torch.cuda.is_available(): 46 | autoencoder.cuda() 47 | 48 | trainer = pl.Trainer(limit_train_batches=100, max_epochs=1) 49 | trainer.fit(model=autoencoder, train_dataloaders=train_loader) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /examples/pytorch_lightning_manual_backward_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from torchvision.datasets import MNIST 8 | from torchvision.transforms import ToTensor 9 | 10 | from pytorch_optimizer import SophiaH 11 | 12 | 13 | class LitAutoEncoder(pl.LightningModule): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3)) 18 | self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28)) 19 | 20 | self.automatic_optimization = False 21 | 22 | def training_step(self, batch, batch_idx): 23 | opt = self.optimizers() 24 | opt.zero_grad() 25 | 26 | x, y = batch 27 | x = x.view(x.size(0), -1) 28 | 29 | z = self.encoder(x) 30 | x_hat = self.decoder(z) 31 | 32 | loss = nn.functional.mse_loss(x_hat, x) 33 | 34 | self.manual_backward(loss, create_graph=True) 35 | opt.step() 36 | 37 | self.log('train_loss', loss) 38 | 39 | return loss 40 | 41 | def configure_optimizers(self): 42 | return SophiaH(self.parameters()) 43 | 44 | 45 | def main(): 46 | train_dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) 47 | 
train_loader = DataLoader(train_dataset) 48 | 49 | autoencoder = LitAutoEncoder() 50 | autoencoder.train() 51 | 52 | if torch.cuda.is_available(): 53 | autoencoder.cuda() 54 | 55 | trainer = pl.Trainer(limit_train_batches=100, max_epochs=1) 56 | trainer.fit(model=autoencoder, train_dataloaders=train_loader) 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | from functools import partial as _partial 2 | from functools import update_wrapper as _update_wrapper 3 | 4 | from pytorch_optimizer import get_supported_lr_schedulers as _get_supported_lr_schedulers 5 | from pytorch_optimizer import get_supported_optimizers as _get_supported_optimizers 6 | from pytorch_optimizer import load_lr_scheduler as _load_lr_scheduler 7 | from pytorch_optimizer import load_optimizer as _load_optimizer 8 | 9 | dependencies = ['torch'] 10 | 11 | for _optimizer in _get_supported_optimizers(): 12 | name: str = _optimizer.__name__ 13 | _func = _partial(_load_optimizer, optimizer=name) 14 | _update_wrapper(_func, _optimizer) 15 | for n in (name, name.lower(), name.upper()): 16 | globals()[n] = _func 17 | 18 | for _scheduler in _get_supported_lr_schedulers(): 19 | name: str = _scheduler.__name__ 20 | _func = _partial(_load_lr_scheduler, lr_scheduler=name) 21 | _update_wrapper(_func, _scheduler) 22 | for n in (name, name.lower(), name.upper()): 23 | globals()[n] = _func 24 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: pytorch-optimizer 2 | site_description: 'optimizer & lr scheduler & loss function collections in PyTorch' 3 | repo_name: 'kozistr/pytorch-optimizer' 4 | repo_url: 'https://github.com/kozistr/pytorch_optimizer' 5 | nav: 6 | - index.md 7 | - base.md 8 | - optimizer.md 9 | - lr_scheduler.md 10 | - loss.md 11 | - util.md 12 | - visualization.md 13 | - ... | changelogs/*.md 14 | - qa.md 15 | theme: 16 | name: material 17 | highlightjs: true 18 | extra_javascript: 19 | - 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML' 20 | - https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js 21 | - javascripts/tables.js 22 | plugins: 23 | - search 24 | - awesome-pages 25 | - mkdocstrings: 26 | handlers: 27 | python: 28 | options: 29 | # https://mkdocstrings.github.io/python/usage/configuration/general/ 30 | show_root_heading: true 31 | show_root_full_path: false 32 | show_root_members_full_path: false 33 | # show_symbol_type_toc: true 34 | allow_inspection: true 35 | show_bases: true 36 | show_source: true 37 | docstring_style: sphinx 38 | markdown_extensions: 39 | - admonition 40 | - pymdownx.arithmatex 41 | - pymdownx.betterem: 42 | smart_enable: all 43 | - pymdownx.caret 44 | - pymdownx.critic 45 | - pymdownx.details 46 | - pymdownx.emoji: 47 | emoji_generator: !!python/name:pymdownx.emoji.to_svg 48 | - pymdownx.inlinehilite 49 | - pymdownx.magiclink 50 | - pymdownx.mark 51 | - pymdownx.smartsymbols 52 | - pymdownx.superfences 53 | - pymdownx.tasklist: 54 | custom_checkbox: true 55 | - pymdownx.tilde 56 | - mdx_truly_sane_lists 57 | - markdown_include.include: 58 | base_path: . 
59 | -------------------------------------------------------------------------------- /pytorch_optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | from pytorch_optimizer.loss import ( 3 | BCEFocalLoss, 4 | BCELoss, 5 | BinaryBiTemperedLogisticLoss, 6 | BiTemperedLogisticLoss, 7 | DiceLoss, 8 | FocalCosineLoss, 9 | FocalLoss, 10 | FocalTverskyLoss, 11 | JaccardLoss, 12 | LDAMLoss, 13 | LovaszHingeLoss, 14 | SoftF1Loss, 15 | TverskyLoss, 16 | bi_tempered_logistic_loss, 17 | get_supported_loss_functions, 18 | soft_dice_score, 19 | soft_jaccard_score, 20 | ) 21 | from pytorch_optimizer.lr_scheduler import ( 22 | ConstantLR, 23 | CosineAnnealingLR, 24 | CosineAnnealingWarmRestarts, 25 | CosineAnnealingWarmupRestarts, 26 | CosineScheduler, 27 | CyclicLR, 28 | LinearScheduler, 29 | MultiplicativeLR, 30 | MultiStepLR, 31 | OneCycleLR, 32 | PolyScheduler, 33 | ProportionScheduler, 34 | REXScheduler, 35 | StepLR, 36 | deberta_v3_large_lr_scheduler, 37 | get_chebyshev_perm_steps, 38 | get_chebyshev_schedule, 39 | get_supported_lr_schedulers, 40 | get_wsd_schedule, 41 | load_lr_scheduler, 42 | ) 43 | from pytorch_optimizer.optimizer import ( 44 | ADOPT, 45 | APOLLO, 46 | ASGD, 47 | BSAM, 48 | CAME, 49 | FOCUS, 50 | FTRL, 51 | GSAM, 52 | LARS, 53 | LOMO, 54 | MADGRAD, 55 | MARS, 56 | MSVAG, 57 | PID, 58 | PNM, 59 | QHM, 60 | RACS, 61 | SAM, 62 | SCION, 63 | SGDP, 64 | SGDW, 65 | SM3, 66 | SOAP, 67 | SPAM, 68 | SRMM, 69 | SWATS, 70 | TAM, 71 | TRAC, 72 | VSGD, 73 | WSAM, 74 | A2Grad, 75 | AccSGD, 76 | AdaBelief, 77 | AdaBound, 78 | AdaDelta, 79 | AdaFactor, 80 | AdaGC, 81 | AdaHessian, 82 | Adai, 83 | Adalite, 84 | AdaLOMO, 85 | AdaMax, 86 | AdamG, 87 | AdamMini, 88 | AdaMod, 89 | AdamP, 90 | AdamS, 91 | AdamW, 92 | AdamWSN, 93 | Adan, 94 | AdaNorm, 95 | AdaPNM, 96 | AdaShift, 97 | AdaSmooth, 98 | AdaTAM, 99 | AdEMAMix, 100 | AggMo, 101 | Aida, 102 | Alice, 103 | AliG, 104 | Amos, 105 | ApolloDQN, 106 | AvaGrad, 107 | DAdaptAdaGrad, 108 | DAdaptAdam, 109 | DAdaptAdan, 110 | DAdaptLion, 111 | DAdaptSGD, 112 | DeMo, 113 | DiffGrad, 114 | DynamicLossScaler, 115 | EXAdam, 116 | FAdam, 117 | Fira, 118 | Fromage, 119 | GaLore, 120 | Grams, 121 | Gravity, 122 | GrokFastAdamW, 123 | Kate, 124 | Kron, 125 | Lamb, 126 | LaProp, 127 | Lion, 128 | Lookahead, 129 | LookSAM, 130 | Muon, 131 | Nero, 132 | NovoGrad, 133 | OrthoGrad, 134 | PAdam, 135 | PCGrad, 136 | Prodigy, 137 | QHAdam, 138 | RAdam, 139 | Ranger, 140 | Ranger21, 141 | Ranger25, 142 | RotoGrad, 143 | SafeFP16Optimizer, 144 | ScalableShampoo, 145 | ScheduleFreeAdamW, 146 | ScheduleFreeRAdam, 147 | ScheduleFreeSGD, 148 | ScheduleFreeWrapper, 149 | SCIONLight, 150 | SGDSaI, 151 | Shampoo, 152 | SignSGD, 153 | SimplifiedAdEMAMix, 154 | SophiaH, 155 | StableAdamW, 156 | StableSPAM, 157 | Tiger, 158 | Yogi, 159 | agc, 160 | centralize_gradient, 161 | create_optimizer, 162 | get_optimizer_parameters, 163 | get_supported_optimizers, 164 | load_ao_optimizer, 165 | load_bnb_optimizer, 166 | load_optimizer, 167 | load_q_galore_optimizer, 168 | ) 169 | from pytorch_optimizer.optimizer.utils import ( 170 | CPUOffloadOptimizer, 171 | clip_grad_norm, 172 | copy_stochastic, 173 | disable_running_stats, 174 | enable_running_stats, 175 | get_global_gradient_norm, 176 | normalize_gradient, 177 | unit_norm, 178 | ) 179 | -------------------------------------------------------------------------------- /pytorch_optimizer/base/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/pytorch_optimizer/base/__init__.py -------------------------------------------------------------------------------- /pytorch_optimizer/base/exception.py: -------------------------------------------------------------------------------- 1 | class NoSparseGradientError(Exception): 2 | """Raised when the gradient is sparse gradient. 3 | 4 | :param optimizer_name: str. optimizer name. 5 | :param note: str. special conditions to note (default ''). 6 | """ 7 | 8 | def __init__(self, optimizer_name: str, note: str = ''): 9 | self.note: str = ' ' if not note else f' w/ {note} ' 10 | self.message: str = f'{optimizer_name}{self.note}does not support sparse gradient.' 11 | super().__init__(self.message) 12 | 13 | 14 | class ZeroParameterSizeError(Exception): 15 | """Raised when the parameter size is 0.""" 16 | 17 | def __init__(self): 18 | self.message: str = 'parameter size is 0' 19 | super().__init__(self.message) 20 | 21 | 22 | class NoClosureError(Exception): 23 | """Raised when there's no closure function.""" 24 | 25 | def __init__(self, optimizer_name: str, note: str = ''): 26 | self.message: str = f'{optimizer_name} requires closure.{note}' 27 | super().__init__(self.message) 28 | 29 | 30 | class NegativeLRError(Exception): 31 | """Raised when learning rate is negative.""" 32 | 33 | def __init__(self, lr: float, lr_type: str = ''): 34 | self.note: str = lr_type if lr_type else 'learning rate' 35 | self.message: str = f'{self.note} must be positive. ({lr} > 0)' 36 | super().__init__(self.message) 37 | 38 | 39 | class NegativeStepError(Exception): 40 | """Raised when step is negative.""" 41 | 42 | def __init__(self, num_steps: int, step_type: str = ''): 43 | self.note: str = step_type if step_type else 'step' 44 | self.message: str = f'{self.note} must be positive. ({num_steps} > 0)' 45 | super().__init__(self.message) 46 | 47 | 48 | class NoComplexParameterError(Exception): 49 | """Raised when the dtype of the parameter is complex. 50 | 51 | :param optimizer_name: str. optimizer name. 52 | :param note: str. special conditions to note (default ''). 53 | """ 54 | 55 | def __init__(self, optimizer_name: str, note: str = ''): 56 | self.note: str = ' ' if not note else f' w/ {note} ' 57 | self.message: str = f'{optimizer_name}{self.note}does not support complex parameter.' 58 | super().__init__(self.message) 59 | -------------------------------------------------------------------------------- /pytorch_optimizer/base/scheduler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List 3 | 4 | from torch.optim import Optimizer 5 | 6 | from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError 7 | 8 | 9 | class BaseLinearWarmupScheduler(ABC): 10 | r"""BaseLinearWarmupScheduler class. 11 | 12 | The LR Scheduler class based on this class has linear warmup strategy. 13 | 14 | :param optimizer: Optimizer. It will set learning rate to all trainable parameters in optimizer. 15 | :param t_max: int. total steps to train. 16 | :param max_lr: float. maximum lr. 17 | :param min_lr: float. minimum lr. 18 | :param init_lr: float. initial lr. 19 | :param warmup_steps: int. steps to warm-up. 
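    Example (an illustrative sketch; `FlatScheduler` and every hyper-parameter value below are arbitrary
    illustrations, not part of the library)::

        import torch

        class FlatScheduler(BaseLinearWarmupScheduler):
            # after the linear warmup phase, simply hold the learning rate at max_lr
            def _step(self) -> float:
                return self.max_lr

        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
        scheduler = FlatScheduler(optimizer, t_max=100, max_lr=1e-1, warmup_steps=10)

        for _ in range(100):
            # ... forward pass, loss.backward(), optimizer.step() ...
            scheduler.step()  # warms up linearly for 10 steps, then returns _step()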
20 | """ 21 | 22 | def __init__( 23 | self, 24 | optimizer: Optimizer, 25 | t_max: int, 26 | max_lr: float, 27 | min_lr: float = 0.0, 28 | init_lr: float = 0.0, 29 | warmup_steps: int = 0, 30 | ): 31 | self.optimizer = optimizer 32 | self.total_steps = t_max 33 | self.max_lr = max_lr 34 | self.min_lr = min_lr 35 | self.init_lr = init_lr 36 | self.warmup_steps = warmup_steps 37 | 38 | self.step_t: int = 0 39 | self.base_lrs: List[float] = [] 40 | 41 | # record current value in self._last_lr to match API from torch.optim.lr_scheduler 42 | self.last_lr: List[float] = [init_lr] 43 | 44 | self.validate_parameters() 45 | 46 | self._init_lr() 47 | 48 | def validate_parameters(self): 49 | if self.min_lr < 0: 50 | raise NegativeLRError(self.min_lr, 'min_lr') 51 | 52 | if self.max_lr < 0: 53 | raise NegativeLRError(self.max_lr, 'max_lr') 54 | 55 | if self.init_lr < 0: 56 | raise NegativeLRError(self.init_lr, 'init_lr') 57 | 58 | if self.total_steps < 0: 59 | raise NegativeStepError(self.total_steps, 't_max') 60 | 61 | if self.warmup_steps < 0: 62 | raise NegativeStepError(self.warmup_steps, 'warmup_steps') 63 | 64 | def _init_lr(self): 65 | self.base_lrs = [] 66 | for param_group in self.optimizer.param_groups: 67 | param_group['lr'] = self.min_lr 68 | self.base_lrs.append(self.min_lr) 69 | 70 | def step(self): 71 | if self.step_t < self.warmup_steps: 72 | value = self.init_lr + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps 73 | elif self.step_t == self.warmup_steps: 74 | value = self.max_lr 75 | else: 76 | value = self._step() 77 | 78 | self.step_t += 1 79 | 80 | if self.optimizer is not None: 81 | for param_group in self.optimizer.param_groups: 82 | param_group['lr'] = value 83 | 84 | self.last_lr = [value] 85 | 86 | return value 87 | 88 | @abstractmethod 89 | def _step(self) -> float: # pragma: no cover 90 | raise NotImplementedError 91 | 92 | def get_lr(self) -> float: 93 | return self.last_lr[0] 94 | -------------------------------------------------------------------------------- /pytorch_optimizer/base/type.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Iterable, Literal, Optional, Tuple, Type, Union 2 | 3 | import torch 4 | from torch.optim import Optimizer 5 | from torch.optim.lr_scheduler import LRScheduler 6 | 7 | CLOSURE = Optional[Callable[[], float]] 8 | LOSS = Optional[float] 9 | BETAS = Union[Tuple[float, float], Tuple[float, float, float], Tuple[None, float]] 10 | DEFAULTS = Dict 11 | GROUP = Dict 12 | PARAMETERS = Optional[Union[Iterable[GROUP], Iterable[torch.Tensor]]] 13 | STATE = Dict 14 | OPTIMIZER = Type[Optimizer] 15 | OPTIMIZER_INSTANCE_OR_CLASS = Union[OPTIMIZER, Optimizer] 16 | SCHEDULER = Type[LRScheduler] 17 | 18 | HUTCHINSON_G = Literal['gaussian', 'rademacher'] 19 | CLASS_MODE = Literal['binary', 'multiclass', 'multilabel'] 20 | 21 | DATA_FORMAT = Literal['channels_first', 'channels_last'] 22 | -------------------------------------------------------------------------------- /pytorch_optimizer/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | from typing import Dict, List, Optional, Sequence, Set, Union 3 | 4 | from torch import nn 5 | 6 | from pytorch_optimizer.loss.bi_tempered import ( 7 | BinaryBiTemperedLogisticLoss, 8 | BiTemperedLogisticLoss, 9 | bi_tempered_logistic_loss, 10 | ) 11 | from pytorch_optimizer.loss.cross_entropy import BCELoss 12 | from pytorch_optimizer.loss.dice import DiceLoss, 
soft_dice_score 13 | from pytorch_optimizer.loss.f1 import SoftF1Loss 14 | from pytorch_optimizer.loss.focal import BCEFocalLoss, FocalCosineLoss, FocalLoss, FocalTverskyLoss 15 | from pytorch_optimizer.loss.jaccard import JaccardLoss, soft_jaccard_score 16 | from pytorch_optimizer.loss.ldam import LDAMLoss 17 | from pytorch_optimizer.loss.lovasz import LovaszHingeLoss 18 | from pytorch_optimizer.loss.tversky import TverskyLoss 19 | 20 | LOSS_FUNCTION_LIST: List = [ 21 | BCELoss, 22 | BCEFocalLoss, 23 | FocalLoss, 24 | SoftF1Loss, 25 | DiceLoss, 26 | LDAMLoss, 27 | FocalCosineLoss, 28 | JaccardLoss, 29 | BiTemperedLogisticLoss, 30 | BinaryBiTemperedLogisticLoss, 31 | TverskyLoss, 32 | FocalTverskyLoss, 33 | LovaszHingeLoss, 34 | ] 35 | LOSS_FUNCTIONS: Dict[str, nn.Module] = { 36 | str(loss_function.__name__).lower(): loss_function for loss_function in LOSS_FUNCTION_LIST 37 | } 38 | 39 | 40 | def get_supported_loss_functions(filters: Optional[Union[str, List[str]]] = None) -> List[str]: 41 | r"""Return list of available loss function names, sorted alphabetically. 42 | 43 | :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will 44 | return the whole list. 45 | """ 46 | if filters is None: 47 | return sorted(LOSS_FUNCTIONS.keys()) 48 | 49 | include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] 50 | 51 | filtered_list: Set[str] = set() 52 | for include_filter in include_filters: 53 | filtered_list.update(fnmatch.filter(LOSS_FUNCTIONS.keys(), include_filter)) 54 | 55 | return sorted(filtered_list) 56 | -------------------------------------------------------------------------------- /pytorch_optimizer/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.functional import binary_cross_entropy 4 | 5 | 6 | class BCELoss(nn.Module): 7 | r"""binary cross entropy with label smoothing + probability input. 8 | 9 | :param label_smooth: float. Smoothness constant for dice coefficient (a). 10 | :param eps: float. epsilon. 11 | :param reduction: str. type of reduction. 12 | """ 13 | 14 | def __init__(self, label_smooth: float = 0.0, eps: float = 1e-6, reduction: str = 'mean'): 15 | super().__init__() 16 | self.label_smooth = label_smooth 17 | self.eps = eps 18 | self.reduction = reduction 19 | 20 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 21 | if self.training and self.label_smooth > 0.0: 22 | y_true = (1.0 - self.label_smooth) * y_true + self.label_smooth / y_pred.size(-1) 23 | y_pred = torch.clamp(y_pred, self.eps, 1.0 - self.eps) 24 | return binary_cross_entropy(y_pred, y_true, reduction=self.reduction) 25 | -------------------------------------------------------------------------------- /pytorch_optimizer/loss/f1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SoftF1Loss(nn.Module): 6 | r"""Soft-F1 loss. 7 | 8 | :param beta: float. f-beta. 9 | :param eps: float. epsilon. 
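    Example (an illustrative sketch; the batch shape and random targets below are arbitrary)::

        import torch

        criterion = SoftF1Loss(beta=1.0)

        logits = torch.randn(16, 5, requires_grad=True)
        y_pred = torch.sigmoid(logits)                 # the loss expects probabilities in [0, 1]
        y_true = torch.randint(0, 2, (16, 5)).float()

        loss = criterion(y_pred, y_true)
        loss.backward()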
10 | """ 11 | 12 | def __init__(self, beta: float = 1.0, eps: float = 1e-6): 13 | super().__init__() 14 | self.beta = beta 15 | self.eps = eps 16 | 17 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 18 | tp = (y_true * y_pred).sum().float() 19 | fn = ((1 - y_true) * y_pred).sum().float() 20 | fp = (y_true * (1 - y_pred)).sum().float() 21 | 22 | p = tp / (tp + fp + self.eps) 23 | r = tp / (tp + fn + self.eps) 24 | 25 | f1 = (1 + self.beta ** 2) * (p * r) / ((self.beta ** 2) * p + r + self.eps) # fmt: skip 26 | f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1) 27 | 28 | return 1.0 - f1.mean() 29 | -------------------------------------------------------------------------------- /pytorch_optimizer/loss/focal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.functional import ( 4 | binary_cross_entropy_with_logits, 5 | cosine_embedding_loss, 6 | cross_entropy, 7 | normalize, 8 | one_hot, 9 | ) 10 | 11 | from pytorch_optimizer.loss.cross_entropy import BCELoss 12 | from pytorch_optimizer.loss.tversky import TverskyLoss 13 | 14 | 15 | class FocalLoss(nn.Module): 16 | r"""Focal loss function w/ logit input. 17 | 18 | :param alpha: float. alpha. 19 | :param gamma: float. gamma. 20 | """ 21 | 22 | def __init__(self, alpha: float = 1.0, gamma: float = 2.0): 23 | super().__init__() 24 | self.alpha = alpha 25 | self.gamma = gamma 26 | 27 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 28 | bce_loss = binary_cross_entropy_with_logits(y_pred, y_true, reduction='none') 29 | pt = torch.exp(-bce_loss) 30 | focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss 31 | return focal_loss.mean() 32 | 33 | 34 | class FocalCosineLoss(nn.Module): 35 | r"""Focal Cosine loss function w/ logit input. 36 | 37 | :param alpha: float. alpha. 38 | :param gamma: float. gamma. 39 | :param focal_weight: float. weight of focal loss. 40 | :param reduction: str. type of reduction. 41 | """ 42 | 43 | def __init__(self, alpha: float = 1.0, gamma: float = 2.0, focal_weight: float = 0.1, reduction: str = 'mean'): 44 | super().__init__() 45 | self.alpha = alpha 46 | self.gamma = gamma 47 | self.focal_weight = focal_weight 48 | self.reduction = reduction 49 | 50 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 51 | cosine_loss = cosine_embedding_loss( 52 | y_pred, 53 | one_hot(y_true, num_classes=y_pred.size(-1)), 54 | torch.tensor([1], device=y_true.device), 55 | reduction=self.reduction, 56 | ) 57 | 58 | ce_loss = cross_entropy(normalize(y_pred), y_true, reduction='none') 59 | pt = torch.exp(-ce_loss) 60 | focal_loss = (self.alpha * (1 - pt) ** self.gamma * ce_loss).mean() 61 | 62 | return cosine_loss + self.focal_weight * focal_loss 63 | 64 | 65 | class BCEFocalLoss(nn.Module): 66 | r"""BCEFocal loss function w/ probability input. 67 | 68 | :param alpha: float. alpha. 69 | :param gamma: float. gamma. 70 | :param label_smooth: float. Smoothness constant for dice coefficient (a). 71 | :param eps: float. epsilon. 72 | :param reduction: str. type of reduction. 
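    Example (an illustrative sketch; shapes and hyper-parameters are arbitrary, and the sigmoid is applied
    by hand because this loss takes probabilities rather than logits)::

        import torch

        criterion = BCEFocalLoss(alpha=0.25, gamma=2.0)

        y_pred = torch.sigmoid(torch.randn(8, 1))      # probability input
        y_true = torch.randint(0, 2, (8, 1)).float()

        loss = criterion(y_pred, y_true)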
73 | """ 74 | 75 | def __init__( 76 | self, 77 | alpha: float = 0.25, 78 | gamma: float = 2.0, 79 | label_smooth: float = 0.0, 80 | eps: float = 1e-6, 81 | reduction: str = 'mean', 82 | ): 83 | super().__init__() 84 | self.alpha = alpha 85 | self.gamma = gamma 86 | self.reduction = reduction 87 | 88 | self.bce = BCELoss(label_smooth=label_smooth, eps=eps, reduction='none') 89 | 90 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 91 | bce_loss = self.bce(y_pred, y_true) 92 | focal_loss = ( 93 | y_true * self.alpha * (1.0 - y_pred) ** self.gamma * bce_loss 94 | + (1.0 - y_true) ** self.gamma * bce_loss 95 | ) # fmt: skip 96 | 97 | return focal_loss.mean() if self.reduction == 'mean' else focal_loss.sum() 98 | 99 | 100 | class FocalTverskyLoss(nn.Module): 101 | r"""Focal Tversky Loss w/ logits input. 102 | 103 | :param alpha: float. alpha. 104 | :param beta: float. beta. 105 | :param gamma: float. gamma. 106 | :param smooth: float. smooth factor. 107 | """ 108 | 109 | def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, smooth: float = 1e-6): 110 | super().__init__() 111 | self.gamma = gamma 112 | 113 | self.tversky = TverskyLoss(alpha, beta, smooth) 114 | 115 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 116 | return self.tversky(y_pred, y_true) ** self.gamma 117 | -------------------------------------------------------------------------------- /pytorch_optimizer/loss/ldam.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn.functional import cross_entropy 7 | 8 | 9 | class LDAMLoss(nn.Module): 10 | r"""LDAM Loss. 11 | 12 | :param num_class_list: List[int]. list of number of class. 13 | :param max_m: float. max margin (`C` term in the paper). 14 | :param weight: Optional[torch.Tensor]. class weight. 15 | :param s: float. scaler. 
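    Example (an illustrative sketch; the class counts, logits, and labels below are made up)::

        import torch

        # num_class_list holds the per-class sample counts of the (imbalanced) training set
        criterion = LDAMLoss(num_class_list=[1000, 100, 10], max_m=0.5, s=30.0)

        y_pred = torch.randn(8, 3)                     # raw logits, one column per class
        y_true = torch.randint(0, 3, (8,))             # integer class indices

        loss = criterion(y_pred, y_true)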
16 | """ 17 | 18 | def __init__( 19 | self, num_class_list: List[int], max_m: float = 0.5, weight: Optional[torch.Tensor] = None, s: float = 30.0 20 | ): 21 | super().__init__() 22 | 23 | cls_num_list = np.asarray(num_class_list) 24 | m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) 25 | m_list *= max_m / np.max(m_list) 26 | 27 | self.m_list = torch.FloatTensor(m_list).unsqueeze(0) 28 | self.weight = weight 29 | self.s = s 30 | 31 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 32 | index = torch.zeros_like(y_pred, dtype=torch.bool) 33 | index.scatter_(1, y_true.view(-1, 1), 1) 34 | 35 | batch_m = torch.matmul(self.m_list.to(index.device), index.float().transpose(0, 1)) 36 | batch_m = batch_m.view((-1, 1)) 37 | x_m = y_pred - batch_m 38 | 39 | output = torch.where(index, x_m, y_pred) 40 | return cross_entropy(self.s * output, y_true, weight=self.weight) 41 | -------------------------------------------------------------------------------- /pytorch_optimizer/loss/lovasz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.functional import relu 4 | 5 | 6 | def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor: 7 | r"""Compute gradient of the Lovasz extension w.r.t sorted errors.""" 8 | p = len(gt_sorted) 9 | gts = gt_sorted.sum() 10 | intersection = gts - gt_sorted.float().cumsum(0) 11 | union = gts + (1 - gt_sorted).float().cumsum(0) 12 | jaccard = 1.0 - intersection / union 13 | if p > 1: # cover 1-pixel case 14 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 15 | return jaccard 16 | 17 | 18 | def lovasz_hinge_flat(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 19 | r"""Binary Lovasz hinge loss. 20 | 21 | :param y_pred: torch.Tensor. 22 | :param y_true: torch.Tensor. 23 | """ 24 | y_pred = y_pred.view(-1) 25 | y_true = y_true.view(-1) 26 | 27 | signs = 2.0 * y_true.float() - 1.0 28 | 29 | errors = 1.0 - y_pred * signs 30 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 31 | 32 | grad = lovasz_grad(y_true[perm]) 33 | 34 | return torch.dot(relu(errors_sorted), grad) 35 | 36 | 37 | class LovaszHingeLoss(nn.Module): 38 | r"""Binary Lovasz hinge loss. 39 | 40 | :param per_image: bool. compute the loss per image instead of per batch. 41 | """ 42 | 43 | def __init__(self, per_image: bool = True): 44 | super().__init__() 45 | self.per_image = per_image 46 | 47 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 48 | if not self.per_image: 49 | return lovasz_hinge_flat(y_pred, y_true) 50 | return sum(lovasz_hinge_flat(y_p, y_t) for y_p, y_t in zip(y_pred, y_true)) / y_pred.size()[0] 51 | -------------------------------------------------------------------------------- /pytorch_optimizer/loss/tversky.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class TverskyLoss(nn.Module): 6 | r"""Tversky Loss w/ logits input. 7 | 8 | :param alpha: float. alpha. 9 | :param beta: float. beta. 10 | :param smooth: float. smooth factor. 
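    Example (an illustrative sketch; the segmentation-style shapes are arbitrary, and raw logits are passed
    because the loss applies the sigmoid internally)::

        import torch

        criterion = TverskyLoss(alpha=0.5, beta=0.5)

        y_pred = torch.randn(4, 1, 32, 32)             # raw logits for a binary mask
        y_true = torch.randint(0, 2, (4, 1, 32, 32)).float()

        loss = criterion(y_pred, y_true)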
11 | """ 12 | 13 | def __init__(self, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1e-6): 14 | super().__init__() 15 | self.alpha = alpha 16 | self.beta = beta 17 | self.smooth = smooth 18 | 19 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 20 | y_pred = torch.sigmoid(y_pred) 21 | 22 | y_pred = y_pred.view(-1) 23 | y_true = y_true.view(-1) 24 | 25 | tp = (y_pred * y_true).sum() 26 | fp = ((1.0 - y_true) * y_pred).sum() 27 | fn = (y_true * (1.0 - y_pred)).sum() 28 | 29 | loss = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth) 30 | 31 | return 1.0 - loss 32 | -------------------------------------------------------------------------------- /pytorch_optimizer/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | import fnmatch 3 | from enum import Enum 4 | from typing import Dict, List, Optional, Sequence, Set, Union 5 | 6 | from torch.optim.lr_scheduler import ( 7 | ConstantLR, 8 | CosineAnnealingLR, 9 | CosineAnnealingWarmRestarts, 10 | CyclicLR, 11 | MultiplicativeLR, 12 | MultiStepLR, 13 | OneCycleLR, 14 | StepLR, 15 | ) 16 | 17 | from pytorch_optimizer.base.type import SCHEDULER 18 | from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_perm_steps, get_chebyshev_schedule 19 | from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts 20 | from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler 21 | from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler 22 | from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler 23 | from pytorch_optimizer.lr_scheduler.rex import REXScheduler 24 | from pytorch_optimizer.lr_scheduler.wsd import get_wsd_schedule 25 | 26 | 27 | class SchedulerType(Enum): 28 | CONSTANT = 'constant' 29 | LINEAR = 'linear' 30 | PROPORTION = 'proportion' 31 | STEP = 'step' 32 | MULTI_STEP = 'multi_step' 33 | MULTIPLICATIVE = 'multiplicative' 34 | CYCLIC = 'cyclic' 35 | ONE_CYCLE = 'one_cycle' 36 | COSINE = 'cosine' 37 | POLY = 'poly' 38 | COSINE_ANNEALING = 'cosine_annealing' 39 | COSINE_ANNEALING_WITH_WARM_RESTART = 'cosine_annealing_with_warm_restart' 40 | COSINE_ANNEALING_WITH_WARMUP = 'cosine_annealing_with_warmup' 41 | CHEBYSHEV = 'chebyshev' 42 | REX = 'rex' 43 | WARMUP_STABLE_DECAY = 'warmup_stable_decay' 44 | 45 | def __str__(self) -> str: 46 | return self.value 47 | 48 | 49 | LR_SCHEDULER_LIST: Dict = { 50 | SchedulerType.CONSTANT: ConstantLR, 51 | SchedulerType.STEP: StepLR, 52 | SchedulerType.MULTI_STEP: MultiStepLR, 53 | SchedulerType.CYCLIC: CyclicLR, 54 | SchedulerType.MULTIPLICATIVE: MultiplicativeLR, 55 | SchedulerType.ONE_CYCLE: OneCycleLR, 56 | SchedulerType.COSINE: CosineScheduler, 57 | SchedulerType.POLY: PolyScheduler, 58 | SchedulerType.LINEAR: LinearScheduler, 59 | SchedulerType.PROPORTION: ProportionScheduler, 60 | SchedulerType.COSINE_ANNEALING: CosineAnnealingLR, 61 | SchedulerType.COSINE_ANNEALING_WITH_WARMUP: CosineAnnealingWarmupRestarts, 62 | SchedulerType.COSINE_ANNEALING_WITH_WARM_RESTART: CosineAnnealingWarmRestarts, 63 | SchedulerType.CHEBYSHEV: get_chebyshev_schedule, 64 | SchedulerType.REX: REXScheduler, 65 | SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule, 66 | } 67 | LR_SCHEDULERS: Dict[str, SCHEDULER] = { 68 | str(lr_scheduler_name).lower(): lr_scheduler for lr_scheduler_name, lr_scheduler in LR_SCHEDULER_LIST.items() 69 | } 70 | 71 
| 72 | def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER: 73 | lr_scheduler: str = lr_scheduler.lower() 74 | 75 | if lr_scheduler not in LR_SCHEDULERS: 76 | raise NotImplementedError(f'[-] not implemented lr_scheduler : {lr_scheduler}') 77 | 78 | return LR_SCHEDULERS[lr_scheduler] 79 | 80 | 81 | def get_supported_lr_schedulers(filters: Optional[Union[str, List[str]]] = None) -> List[str]: 82 | r"""Return list of available lr scheduler names, sorted alphabetically. 83 | 84 | :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will 85 | return the whole list. 86 | """ 87 | if filters is None: 88 | return sorted(LR_SCHEDULERS.keys()) 89 | 90 | include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] 91 | 92 | filtered_list: Set[str] = set() 93 | for include_filter in include_filters: 94 | filtered_list.update(fnmatch.filter(LR_SCHEDULERS.keys(), include_filter)) 95 | 96 | return sorted(filtered_list) 97 | -------------------------------------------------------------------------------- /pytorch_optimizer/lr_scheduler/chebyshev.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | from torch.optim import Optimizer 5 | from torch.optim.lr_scheduler import LambdaLR, LRScheduler 6 | 7 | 8 | def get_chebyshev_steps(num_epochs: int, small_m: float = 0.05, big_m: float = 1.0) -> np.ndarray: 9 | r"""Chebyshev steps. 10 | 11 | gamma_{t} = (M + m) / 2.0 - (M - m) * cos ((t - 0.5) * pi / T) / 2, where t = 1, ..., T 12 | 13 | :param num_epochs: int. stands for 'T' notation. 14 | :param small_m: float. stands for 'm' notation. 15 | :param big_m: float. stands for 'M' notation. 16 | :return: np.array. chebyshev_steps. 17 | """ 18 | c, r = (big_m + small_m) / 2.0, (big_m - small_m) / 2.0 19 | thetas = (np.arange(num_epochs) + 0.5) * np.pi / num_epochs # epoch starts from 0, so +0.5 instead of -0.5 20 | 21 | return 1.0 / (c - r * np.cos(thetas)) 22 | 23 | 24 | def get_chebyshev_permutation(num_epochs: int) -> np.ndarray: 25 | r"""Fractal chebyshev permutation. 26 | 27 | sigma_{2T} := interlace(sigma_{T}, 2T + 1 - sigma_{T}), where 28 | interlace([a_{1}, ..., a_{n}], [b_{1}, ..., b_{n}]) := [a_{1}, b_{1}, ..., n_{1}, b_{n}] 29 | 30 | :param num_epochs: int. number of epochs. 31 | """ 32 | perm = np.array([0]) 33 | while len(perm) < num_epochs: 34 | perm = np.vstack([perm, 2 * len(perm) - 1 - perm]).T.flatten() 35 | return perm 36 | 37 | 38 | def get_chebyshev_perm_steps(num_epochs: int) -> np.ndarray: 39 | r"""Get Chebyshev schedules. 40 | 41 | :param num_epochs: int. number of total epochs. 42 | """ 43 | steps: np.ndarray = get_chebyshev_steps(num_epochs) 44 | perm: np.ndarray = get_chebyshev_permutation(num_epochs - 2) 45 | return steps[perm] 46 | 47 | 48 | def get_chebyshev_lr_lambda(epoch: int, num_epochs: int, is_warmup: bool = False) -> float: 49 | r"""Get chebyshev learning rate ratio. 50 | 51 | :param epoch: int. current epochs. 52 | :param num_epochs: int. number of total epochs. 53 | :param is_warmup: bool. whether warm-up stage or not. 
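    Example (an illustrative sketch; in practice this lambda is consumed through `get_chebyshev_schedule`
    defined later in this module, and the model, optimizer, and epoch count below are arbitrary)::

        import torch

        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

        lr_scheduler = get_chebyshev_schedule(optimizer, num_epochs=90)

        for _ in range(90):
            # ... train for one epoch ...
            lr_scheduler.step()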
54 | """ 55 | if is_warmup: 56 | return 1.0 57 | 58 | epoch_power: int = np.power(2, int(np.log2(num_epochs - 1)) + 1) if num_epochs > 1 else 1 59 | scheduler = get_chebyshev_perm_steps(epoch_power) 60 | 61 | idx: int = epoch - 2 62 | if idx < 0: 63 | idx = 0 64 | elif idx > len(scheduler) - 1: 65 | idx = len(scheduler) - 1 66 | 67 | chebyshev_value: float = scheduler[idx] 68 | 69 | return chebyshev_value 70 | 71 | 72 | def get_chebyshev_schedule( 73 | optimizer: Optimizer, num_epochs: int, is_warmup: bool = False, last_epoch: int = -1 74 | ) -> LRScheduler: 75 | r"""Get chebyshev learning rate scheduler. 76 | 77 | :param optimizer: Optimizer. the optimizer for which to schedule the learning rate. 78 | :param num_epochs: int. number of total epochs. 79 | :param is_warmup: bool. whether warm-up stage or not. 80 | :param last_epoch: int. the index of the last epoch when resuming training. 81 | """ 82 | lr_scheduler = partial(get_chebyshev_lr_lambda, num_epochs=num_epochs, is_warmup=is_warmup) 83 | 84 | return LambdaLR(optimizer, lr_scheduler, last_epoch) 85 | -------------------------------------------------------------------------------- /pytorch_optimizer/lr_scheduler/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/pytorch_optimizer/lr_scheduler/experimental/__init__.py -------------------------------------------------------------------------------- /pytorch_optimizer/lr_scheduler/experimental/deberta_v3_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from pytorch_optimizer.base.type import PARAMETERS 4 | 5 | 6 | def deberta_v3_large_lr_scheduler( 7 | model: nn.Module, 8 | layer_low_threshold: int = 195, 9 | layer_middle_threshold: int = 323, 10 | head_param_start: int = 390, 11 | base_lr: float = 2e-5, 12 | head_lr: float = 1e-4, 13 | wd: float = 1e-2, 14 | ) -> PARAMETERS: 15 | """DeBERTa-v3 large layer-wise lr scheduler. 16 | 17 | Reference : https://github.com/gilfernandes/commonlit. 18 | 19 | :param model: nn.Module. model. based on Huggingface Transformers. 20 | :param layer_low_threshold: int. start of the 12 layers. 21 | :param layer_middle_threshold: int. end of the 24 layers. 22 | :param head_param_start: int. where the backbone ends (head starts). 23 | :param base_lr: float. base lr. 24 | :param head_lr: float. head_lr. 25 | :param wd: float. weight decay. 
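    Example (an illustrative sketch; it assumes the `transformers` package is available, and the checkpoint
    name and optimizer choice below are arbitrary)::

        from torch.optim import AdamW
        from transformers import AutoModel

        model = AutoModel.from_pretrained('microsoft/deberta-v3-large')

        # layer-wise lr / weight-decay groups for the backbone, plus a larger lr for the head
        parameters = deberta_v3_large_lr_scheduler(model, base_lr=2e-5, head_lr=1e-4)
        optimizer = AdamW(parameters)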
26 | """ 27 | named_parameters = list(model.named_parameters()) 28 | 29 | backbone_parameters = named_parameters[:head_param_start] 30 | head_parameters = named_parameters[head_param_start:] 31 | 32 | head_group = [params for (_, params) in head_parameters] 33 | 34 | parameters = [{'params': head_group, 'lr': head_lr}] 35 | 36 | for layer_num, (name, params) in enumerate(backbone_parameters): 37 | weight_decay: float = 0.0 if ('bias' in name) or ('LayerNorm.weight' in name) else wd 38 | 39 | lr = base_lr / 2.5 # 2e-5 40 | if layer_num >= layer_middle_threshold: 41 | lr = base_lr / 0.5 # 1e-4 42 | elif layer_num >= layer_low_threshold: 43 | lr = base_lr 44 | 45 | parameters.append({'params': params, 'weight_decay': weight_decay, 'lr': lr}) 46 | 47 | return parameters 48 | -------------------------------------------------------------------------------- /pytorch_optimizer/lr_scheduler/linear_warmup.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | from pytorch_optimizer.base.scheduler import BaseLinearWarmupScheduler 6 | 7 | 8 | class LinearScheduler(BaseLinearWarmupScheduler): 9 | r"""Linear LR Scheduler w/ linear warmup.""" 10 | 11 | def _step(self) -> float: 12 | return self.max_lr + (self.min_lr - self.max_lr) * (self.step_t - self.warmup_steps) / ( 13 | self.total_steps - self.warmup_steps 14 | ) 15 | 16 | 17 | class CosineScheduler(BaseLinearWarmupScheduler): 18 | r"""Cosine LR Scheduler w/ linear warmup.""" 19 | 20 | def _step(self) -> float: 21 | phase: float = (self.step_t - self.warmup_steps) / (self.total_steps - self.warmup_steps) * math.pi 22 | return self.min_lr + (self.max_lr - self.min_lr) * (np.cos(phase) + 1.0) / 2.0 23 | 24 | 25 | class PolyScheduler(BaseLinearWarmupScheduler): 26 | r"""Poly LR Scheduler. 27 | 28 | :param poly_order: float. lr scheduler decreases with steps. 29 | """ 30 | 31 | def __init__(self, optimizer, poly_order: float = 0.5, **kwargs): 32 | self.poly_order = poly_order 33 | 34 | if poly_order <= 0: 35 | raise ValueError(f'[-] poly_order must be positive. {poly_order}') 36 | 37 | super().__init__(optimizer, **kwargs) 38 | 39 | def _step(self) -> float: 40 | return self.min_lr + (self.max_lr - self.min_lr) * (self.step_t - self.warmup_steps) ** self.poly_order 41 | -------------------------------------------------------------------------------- /pytorch_optimizer/lr_scheduler/proportion.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class ProportionScheduler: 5 | r"""ProportionScheduler (Rho Scheduler of GSAM). 6 | 7 | This scheduler outputs a value that evolves proportional to lr_scheduler. 8 | 9 | :param lr_scheduler: learning rate scheduler. 10 | :param max_lr: float. maximum lr. 11 | :param min_lr: float. minimum lr. 12 | :param max_value: float. maximum of rho. 13 | :param min_value: float. minimum of rho. 
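    Example (an illustrative sketch of the GSAM-style pairing; every value below is an arbitrary choice)::

        import torch

        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

        # rho moves from max_value towards min_value as the lr decays from max_lr to min_lr
        rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=1e-1, min_lr=0.0, max_value=2.0, min_value=0.02)

        for _ in range(100):
            lr_scheduler.step()
            rho = rho_scheduler.step()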
14 | """ 15 | 16 | def __init__( 17 | self, lr_scheduler, max_lr: float, min_lr: float = 0.0, max_value: float = 2.0, min_value: float = 2.0 18 | ): 19 | self.lr_scheduler = lr_scheduler 20 | self.max_lr = max_lr 21 | self.min_lr = min_lr 22 | self.max_value = max_value 23 | self.min_value = min_value 24 | 25 | self.step_t: int = 0 26 | self.last_lr: List[float] = [] 27 | 28 | self.step() 29 | 30 | def get_lr(self) -> float: 31 | return self.last_lr[0] 32 | 33 | def step(self) -> float: 34 | self.step_t += 1 35 | 36 | if hasattr(self.lr_scheduler, 'last_lr'): 37 | lr = self.lr_scheduler.last_lr[0] 38 | else: 39 | lr = self.lr_scheduler.optimizer.param_groups[0]['lr'] 40 | 41 | if self.max_lr > self.min_lr: 42 | value = self.min_value + (self.max_value - self.min_value) * (lr - self.min_lr) / ( 43 | self.max_lr - self.min_lr 44 | ) 45 | else: 46 | value = self.max_value 47 | 48 | self.last_lr = [value] 49 | 50 | return value 51 | -------------------------------------------------------------------------------- /pytorch_optimizer/lr_scheduler/rex.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from torch.optim import Optimizer 4 | from torch.optim.lr_scheduler import LRScheduler 5 | 6 | 7 | class REXScheduler(LRScheduler): 8 | r"""Revisiting Budgeted Training with an Improved Schedule. 9 | 10 | :param optimizer: Optimizer. wrapped optimizer instance. 11 | :param total_steps: int. number of steps to optimize. 12 | :param max_lr: float. max lr. 13 | :param min_lr: float. min lr. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | optimizer: Optimizer, 19 | total_steps: int, 20 | max_lr: float = 1.0, 21 | min_lr: float = 0.0, 22 | ): 23 | self.total_steps = total_steps 24 | self.max_lr = max_lr 25 | self.min_lr = min_lr 26 | 27 | self.step_t: int = 0 28 | self.base_lrs: List[float] = [] 29 | 30 | # record current value in self._last_lr to match API from torch.optim.lr_scheduler 31 | self.last_lr: List[float] = [self.max_lr] 32 | 33 | super().__init__(optimizer) 34 | 35 | self.init_lr() 36 | 37 | def init_lr(self) -> None: 38 | self.base_lrs = [] 39 | for param_group in self.optimizer.param_groups: 40 | param_group['lr'] = self.min_lr 41 | self.base_lrs.append(self.min_lr) 42 | 43 | def get_lr(self) -> float: 44 | return self.last_lr[0] 45 | 46 | def get_linear_lr(self) -> float: 47 | if self.step_t >= self.total_steps: 48 | return self.min_lr 49 | 50 | progress: float = self.step_t / self.total_steps 51 | 52 | return self.min_lr + (self.max_lr - self.min_lr) * ((1.0 - progress) / (1.0 - progress / 2.0)) 53 | 54 | def step(self, epoch: Optional[int] = None) -> float: 55 | value: float = self.get_linear_lr() 56 | 57 | self.step_t += 1 58 | 59 | if self.optimizer is not None: 60 | for param_group in self.optimizer.param_groups: 61 | param_group['lr'] = value 62 | 63 | self.last_lr = [value] 64 | 65 | return value 66 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/agc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_optimizer.optimizer.utils import unit_norm 4 | 5 | 6 | def agc( 7 | p: torch.Tensor, grad: torch.Tensor, agc_eps: float = 1e-3, agc_clip_val: float = 1e-2, eps: float = 1e-6 8 | ) -> torch.Tensor: 9 | r"""Clip gradient values in excess of the unit wise norm. 10 | 11 | :param p: torch.Tensor. parameter. 12 | :param grad: torch.Tensor, gradient. 13 | :param agc_eps: float. 
agc epsilon to clip the norm of parameter. 14 | :param agc_clip_val: float. norm clip. 15 | :param eps: float. simple stop from div by zero and no relation to standard optimizer eps. 16 | """ 17 | max_norm = unit_norm(p).clamp_min_(agc_eps).mul_(agc_clip_val) 18 | g_norm = unit_norm(grad).clamp_min_(eps) 19 | 20 | clipped_grad = grad * (max_norm / g_norm) 21 | 22 | return torch.where(g_norm > max_norm, clipped_grad, grad) 23 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/aggmo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_optimizer.base.exception import NoSparseGradientError 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 6 | 7 | 8 | class AggMo(BaseOptimizer): 9 | r"""Aggregated Momentum: Stability Through Passive Damping. 10 | 11 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 12 | :param lr: float. learning rate. 13 | :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. 14 | :param weight_decay: float. weight decay (L2 penalty). 15 | :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. 16 | :param fixed_decay: bool. fix weight decay. 17 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | params: PARAMETERS, 23 | lr: float = 1e-3, 24 | betas: BETAS = (0.0, 0.9, 0.99), 25 | weight_decay: float = 0.0, 26 | weight_decouple: bool = False, 27 | fixed_decay: bool = False, 28 | maximize: bool = False, 29 | **kwargs, 30 | ): 31 | self.validate_learning_rate(lr) 32 | self.validate_betas(betas) 33 | self.validate_non_negative(weight_decay, 'weight_decay') 34 | 35 | self.maximize = maximize 36 | 37 | defaults: DEFAULTS = { 38 | 'lr': lr, 39 | 'betas': betas, 40 | 'weight_decay': weight_decay, 41 | 'weight_decouple': weight_decouple, 42 | 'fixed_decay': fixed_decay, 43 | } 44 | 45 | super().__init__(params, defaults) 46 | 47 | def __str__(self) -> str: 48 | return 'AggMo' 49 | 50 | def init_group(self, group: GROUP, **kwargs) -> None: 51 | for p in group['params']: 52 | if p.grad is None: 53 | continue 54 | 55 | grad = p.grad 56 | if grad.is_sparse: 57 | raise NoSparseGradientError(str(self)) 58 | 59 | state = self.state[p] 60 | 61 | if len(state) == 0: 62 | state['momentum_buffer'] = {beta: torch.zeros_like(p) for beta in group['betas']} 63 | 64 | @torch.no_grad() 65 | def step(self, closure: CLOSURE = None) -> LOSS: 66 | loss: LOSS = None 67 | if closure is not None: 68 | with torch.enable_grad(): 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | if 'step' not in group: 73 | self.init_group(group) 74 | group['step'] = 1 75 | else: 76 | group['step'] += 1 77 | 78 | betas = group['betas'] 79 | 80 | for p in group['params']: 81 | if p.grad is None: 82 | continue 83 | 84 | grad = p.grad 85 | 86 | self.maximize_gradient(grad, maximize=self.maximize) 87 | 88 | state = self.state[p] 89 | 90 | self.apply_weight_decay( 91 | p=p, 92 | grad=grad, 93 | lr=group['lr'], 94 | weight_decay=group['weight_decay'], 95 | weight_decouple=group['weight_decouple'], 96 | fixed_decay=group['fixed_decay'], 97 | ) 98 | 99 | for beta in betas: 100 | buf = state['momentum_buffer'][beta] 101 | buf.mul_(beta).add_(grad) 102 | 
103 | p.add_(buf, alpha=-group['lr'] / len(betas)) 104 | 105 | return loss 106 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/pytorch_optimizer/optimizer/experimental/__init__.py -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/focus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_optimizer.base.exception import NoSparseGradientError 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 6 | 7 | 8 | class FOCUS(BaseOptimizer): 9 | r"""First Order Concentrated Updating Scheme. 10 | 11 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 12 | :param lr: float. learning rate. 13 | :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. 14 | :param gamma: float. control the strength of the attraction. 15 | :param weight_decay: float. weight decay (L2 penalty). 16 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | params: PARAMETERS, 22 | lr: float = 1e-2, 23 | betas: BETAS = (0.9, 0.999), 24 | gamma: float = 0.1, 25 | weight_decay: float = 0.0, 26 | maximize: bool = False, 27 | **kwargs, 28 | ): 29 | self.validate_learning_rate(lr) 30 | self.validate_betas(betas) 31 | self.validate_range(gamma, 'gamma', 0.0, 1.0, '[)') 32 | self.validate_non_negative(weight_decay, 'weight_decay') 33 | 34 | self.maximize = maximize 35 | 36 | defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'gamma': gamma, 'weight_decay': weight_decay} 37 | 38 | super().__init__(params, defaults) 39 | 40 | def __str__(self) -> str: 41 | return 'FOCUS' 42 | 43 | def init_group(self, group: GROUP, **kwargs) -> None: 44 | for p in group['params']: 45 | if p.grad is None: 46 | continue 47 | 48 | grad = p.grad 49 | if grad.is_sparse: 50 | raise NoSparseGradientError(str(self)) 51 | 52 | state = self.state[p] 53 | 54 | if len(state) == 0: 55 | state['exp_avg'] = torch.zeros_like(p) 56 | state['pbar'] = torch.zeros_like(p) 57 | 58 | @torch.no_grad() 59 | def step(self, closure: CLOSURE = None) -> LOSS: 60 | loss: LOSS = None 61 | if closure is not None: 62 | with torch.enable_grad(): 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | if 'step' not in group: 67 | self.init_group(group) 68 | group['step'] = 1 69 | else: 70 | group['step'] += 1 71 | 72 | beta1, beta2 = group['betas'] 73 | 74 | bias_correction2: float = self.debias(beta2, group['step']) 75 | 76 | weight_decay: float = group['weight_decay'] 77 | 78 | for p in group['params']: 79 | if p.grad is None: 80 | continue 81 | 82 | grad = p.grad 83 | 84 | self.maximize_gradient(grad, maximize=self.maximize) 85 | 86 | state = self.state[p] 87 | 88 | exp_avg, pbar = state['exp_avg'], state['pbar'] 89 | 90 | p, grad, exp_avg, pbar = self.view_as_real(p, grad, exp_avg, pbar) 91 | 92 | exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) 93 | pbar.mul_(beta2).add_(p, alpha=1.0 - beta2) 94 | 95 | pbar_hat = pbar / bias_correction2 96 | 97 | if weight_decay > 0.0: 98 | 
p.add_(pbar_hat, alpha=-group['lr'] * weight_decay) 99 | 100 | update = (p - pbar_hat).sign_().mul_(group['gamma']).add_(torch.sign(exp_avg)) 101 | 102 | p.add_(update, alpha=-group['lr']) 103 | 104 | return loss 105 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/fromage.py: -------------------------------------------------------------------------------- 1 | """Copyright (C) 2020 Jeremy Bernstein, Arash Vahdat, Yisong Yue & Ming-Yu Liu. All rights reserved. 2 | 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/). 4 | """ 5 | 6 | import math 7 | from typing import Optional 8 | 9 | import torch 10 | 11 | from pytorch_optimizer.base.exception import NoSparseGradientError 12 | from pytorch_optimizer.base.optimizer import BaseOptimizer 13 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 14 | 15 | 16 | class Fromage(BaseOptimizer): 17 | r"""On the distance between two neural networks and the stability of learning. 18 | 19 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 20 | :param lr: float. learning rate. 21 | :param p_bound: Optional[float]. Restricts the optimisation to a bounded set. A value of 2.0 restricts parameter 22 | norms to lie within 2x their initial norms. This regularises the model class. 23 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 24 | """ 25 | 26 | def __init__( 27 | self, params: PARAMETERS, lr: float = 1e-2, p_bound: Optional[float] = None, maximize: bool = False, **kwargs 28 | ): 29 | self.validate_learning_rate(lr) 30 | 31 | self.p_bound = p_bound 32 | self.maximize = maximize 33 | 34 | defaults: DEFAULTS = {'lr': lr} 35 | 36 | super().__init__(params, defaults) 37 | 38 | def __str__(self) -> str: 39 | return 'Fromage' 40 | 41 | def init_group(self, group: GROUP, **kwargs) -> None: 42 | for p in group['params']: 43 | if p.grad is None: 44 | continue 45 | 46 | grad = p.grad 47 | if grad.is_sparse: 48 | raise NoSparseGradientError(str(self)) 49 | 50 | state = self.state[p] 51 | 52 | if len(state) == 0 and self.p_bound is not None: 53 | state['max'] = p.norm().mul_(self.p_bound) 54 | 55 | @torch.no_grad() 56 | def step(self, closure: CLOSURE = None) -> LOSS: 57 | loss: LOSS = None 58 | if closure is not None: 59 | with torch.enable_grad(): 60 | loss = closure() 61 | 62 | for group in self.param_groups: 63 | if 'step' not in group: 64 | self.init_group(group) 65 | group['step'] = 1 66 | else: 67 | group['step'] += 1 68 | 69 | pre_factor: float = math.sqrt(1 + group['lr'] ** 2) 70 | 71 | for p in group['params']: 72 | if p.grad is None: 73 | continue 74 | 75 | grad = p.grad 76 | 77 | self.maximize_gradient(grad, maximize=self.maximize) 78 | 79 | state = self.state[p] 80 | 81 | p, grad = self.view_as_real(p, grad) 82 | 83 | p_norm, g_norm = p.norm(), grad.norm() 84 | 85 | if p_norm > 0.0 and g_norm > 0.0: 86 | p.add_(grad * (p_norm / g_norm), alpha=-group['lr']) 87 | else: 88 | p.add_(grad, alpha=-group['lr']) 89 | 90 | p.div_(pre_factor) 91 | 92 | if self.p_bound is not None: 93 | p_norm = p.norm() 94 | if p_norm > state['max']: 95 | p.mul_(state['max']).div_(p_norm) 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/ftrl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from 
pytorch_optimizer.base.exception import NoSparseGradientError 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 6 | 7 | 8 | class FTRL(BaseOptimizer): 9 | r"""Follow The Regularized Leader. 10 | 11 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 12 | :param lr: float. learning rate. 13 | :param lr_power: float. controls how the learning rate decreases during training. use zero for a fixed learning 14 | rate. 15 | :param beta: float. beta value in the paper. 16 | :param lambda_1: float. L1 regularization parameter. 17 | :param lambda_2: float. L2 regularization parameter. 18 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | params: PARAMETERS, 24 | lr: float = 1e-3, 25 | lr_power: float = -0.5, 26 | beta: float = 0.0, 27 | lambda_1: float = 0.0, 28 | lambda_2: float = 0.0, 29 | maximize: bool = False, 30 | **kwargs, 31 | ): 32 | self.validate_learning_rate(lr) 33 | self.validate_non_negative(beta, 'beta') 34 | self.validate_non_positive(lr_power, 'lr_power') 35 | self.validate_non_negative(lambda_1, 'lambda_1') 36 | self.validate_non_negative(lambda_2, 'lambda_2') 37 | 38 | self.maximize = maximize 39 | 40 | defaults: DEFAULTS = {'lr': lr, 'lr_power': lr_power, 'beta': beta, 'lambda_1': lambda_1, 'lambda_2': lambda_2} 41 | 42 | super().__init__(params, defaults) 43 | 44 | def __str__(self) -> str: 45 | return 'FTRL' 46 | 47 | def init_group(self, group: GROUP, **kwargs) -> None: 48 | for p in group['params']: 49 | if p.grad is None: 50 | continue 51 | 52 | grad = p.grad 53 | if grad.is_sparse: 54 | raise NoSparseGradientError(str(self)) 55 | 56 | state = self.state[p] 57 | 58 | if len(state) == 0: 59 | state['z'] = torch.zeros_like(p) 60 | state['n'] = torch.zeros_like(p) 61 | 62 | @torch.no_grad() 63 | def step(self, closure: CLOSURE = None) -> LOSS: 64 | loss: LOSS = None 65 | if closure is not None: 66 | with torch.enable_grad(): 67 | loss = closure() 68 | 69 | for group in self.param_groups: 70 | if 'step' not in group: 71 | self.init_group(group) 72 | group['step'] = 1 73 | else: 74 | group['step'] += 1 75 | 76 | for p in group['params']: 77 | if p.grad is None: 78 | continue 79 | 80 | grad = p.grad 81 | 82 | self.maximize_gradient(grad, maximize=self.maximize) 83 | 84 | state = self.state[p] 85 | 86 | z, n = state['z'], state['n'] 87 | 88 | p, grad, z, n = self.view_as_real(p, grad, z, n) 89 | 90 | grad_p2 = grad.pow(2) 91 | 92 | sigma = (n + grad_p2).pow_(-group['lr_power']).sub_(n.pow(-group['lr_power'])).div_(group['lr']) 93 | 94 | z.add_(grad).sub_(sigma.mul(p)) 95 | n.add_(grad_p2) 96 | 97 | update = z.sign().mul_(group['lambda_1']).sub_(z) 98 | update.div_((group['beta'] + n.sqrt()).div_(group['lr']).add_(group['lambda_2'])) 99 | 100 | p.copy_(update) 101 | p.masked_fill_(z.abs() < group['lambda_1'], 0.0) 102 | 103 | return loss 104 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/gradient_centralization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def centralize_gradient(grad: torch.Tensor, gc_conv_only: bool = False) -> None: 5 | r"""Gradient Centralization (GC). 6 | 7 | :param grad: torch.Tensor. gradient. 8 | :param gc_conv_only: bool. 'False' for both conv & fc layers. 
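Example (a small self-contained check; the tensor shape is arbitrary)::

    import torch

    from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient

    grad = torch.randn(8, 3, 3, 3)  # e.g. the gradient of a conv weight

    centralize_gradient(grad, gc_conv_only=False)  # subtracts the per-output-channel mean in place

    assert torch.allclose(grad.mean(dim=(1, 2, 3)), torch.zeros(8), atol=1e-6)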
9 | """ 10 | size: int = grad.dim() 11 | if (gc_conv_only and size > 3) or (not gc_conv_only and size > 1): 12 | grad.add_(-grad.mean(dim=tuple(range(1, size)), keepdim=True)) 13 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/gravity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_optimizer.base.exception import NoSparseGradientError 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 6 | 7 | 8 | class Gravity(BaseOptimizer): 9 | r"""a Kinematic Approach on Optimization in Deep Learning. 10 | 11 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 12 | :param lr: float. learning rate. 13 | :param alpha: float. alpha controls the V initialization. 14 | :param beta: float. beta will be used to compute running average of V. 15 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | params: PARAMETERS, 21 | lr: float = 1e-2, 22 | alpha: float = 0.01, 23 | beta: float = 0.9, 24 | maximize: bool = False, 25 | **kwargs, 26 | ): 27 | self.validate_learning_rate(lr) 28 | self.validate_range(alpha, 'alpha', 0.0, 1.0) 29 | self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]') 30 | 31 | self.maximize = maximize 32 | 33 | defaults: DEFAULTS = {'lr': lr, 'alpha': alpha, 'beta': beta} 34 | 35 | super().__init__(params, defaults) 36 | 37 | def __str__(self) -> str: 38 | return 'Gravity' 39 | 40 | def init_group(self, group: GROUP, **kwargs) -> None: 41 | for p in group['params']: 42 | if p.grad is None: 43 | continue 44 | 45 | grad = p.grad 46 | if grad.is_sparse: 47 | raise NoSparseGradientError(str(self)) 48 | 49 | state = self.state[p] 50 | 51 | if len(state) == 0: 52 | state['v'] = torch.empty_like(p).normal_(mean=0.0, std=group['alpha'] / group['lr']) 53 | 54 | @torch.no_grad() 55 | def step(self, closure: CLOSURE = None) -> LOSS: 56 | loss: LOSS = None 57 | if closure is not None: 58 | with torch.enable_grad(): 59 | loss = closure() 60 | 61 | for group in self.param_groups: 62 | if 'step' not in group: 63 | self.init_group(group) 64 | group['step'] = 1 65 | else: 66 | group['step'] += 1 67 | 68 | beta_t: float = (group['beta'] * group['step'] + 1) / (group['step'] + 2) 69 | 70 | for p in group['params']: 71 | if p.grad is None: 72 | continue 73 | 74 | grad = p.grad 75 | 76 | self.maximize_gradient(grad, maximize=self.maximize) 77 | 78 | state = self.state[p] 79 | 80 | v = state['v'] 81 | 82 | p, grad, v = self.view_as_real(p, grad, v) 83 | 84 | m = 1.0 / grad.abs().max() 85 | zeta = grad / (1.0 + (grad / m) ** 2) 86 | 87 | v.mul_(beta_t).add_(zeta, alpha=1.0 - beta_t) 88 | 89 | p.add_(v, alpha=-group['lr']) 90 | 91 | return loss 92 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/kate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_optimizer.base.exception import NoSparseGradientError 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 6 | 7 | 8 | class Kate(BaseOptimizer): 9 | r"""Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad. 
10 | 11 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 12 | :param lr: float. learning rate. 13 | :param delta: float. delta. 0.0 or 1e-8. 14 | :param weight_decay: float. weight decay (L2 penalty). 15 | :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. 16 | :param fixed_decay: bool. fix weight decay. 17 | :param eps: float. epsilon value. 18 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | params: PARAMETERS, 24 | lr: float = 1e-3, 25 | delta: float = 0.0, 26 | weight_decay: float = 0.0, 27 | weight_decouple: bool = True, 28 | fixed_decay: bool = False, 29 | eps: float = 1e-8, 30 | maximize: bool = False, 31 | **kwargs, 32 | ): 33 | self.validate_learning_rate(lr) 34 | self.validate_range(delta, 'delta', 0.0, 1.0, '[)') 35 | self.validate_non_negative(weight_decay, 'weight_decay') 36 | self.validate_non_negative(eps, 'eps') 37 | 38 | self.maximize = maximize 39 | 40 | defaults: DEFAULTS = { 41 | 'lr': lr, 42 | 'delta': delta, 43 | 'weight_decay': weight_decay, 44 | 'weight_decouple': weight_decouple, 45 | 'fixed_decay': fixed_decay, 46 | 'eps': eps, 47 | } 48 | 49 | super().__init__(params, defaults) 50 | 51 | def __str__(self) -> str: 52 | return 'Kate' 53 | 54 | def init_group(self, group: GROUP, **kwargs) -> None: 55 | for p in group['params']: 56 | if p.grad is None: 57 | continue 58 | 59 | grad = p.grad 60 | if grad.is_sparse: 61 | raise NoSparseGradientError(str(self)) 62 | 63 | state = self.state[p] 64 | 65 | if len(state) == 0: 66 | state['m'] = torch.zeros_like(p) 67 | state['b'] = torch.zeros_like(p) 68 | 69 | @torch.no_grad() 70 | def step(self, closure: CLOSURE = None) -> LOSS: 71 | loss: LOSS = None 72 | if closure is not None: 73 | with torch.enable_grad(): 74 | loss = closure() 75 | 76 | for group in self.param_groups: 77 | if 'step' not in group: 78 | self.init_group(group) 79 | group['step'] = 1 80 | else: 81 | group['step'] += 1 82 | 83 | for p in group['params']: 84 | if p.grad is None: 85 | continue 86 | 87 | grad = p.grad 88 | 89 | self.maximize_gradient(grad, maximize=self.maximize) 90 | 91 | state = self.state[p] 92 | 93 | m, b = state['m'], state['b'] 94 | 95 | p, grad, m, b = self.view_as_real(p, grad, m, b) 96 | 97 | self.apply_weight_decay( 98 | p=p, 99 | grad=p.grad, 100 | lr=group['lr'], 101 | weight_decay=group['weight_decay'], 102 | weight_decouple=group['weight_decouple'], 103 | fixed_decay=group['fixed_decay'], 104 | ) 105 | 106 | grad_p2 = grad.pow(2) 107 | 108 | b.mul_(b).add_(grad_p2).add_(group['eps']) 109 | m.mul_(m).add_(grad_p2, alpha=group['delta']).add_(grad_p2 / b).sqrt_() 110 | 111 | update = m.mul(grad).div_(b) 112 | 113 | p.add_(update, alpha=-group['lr']) 114 | 115 | b.sqrt_() 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/msvag.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_optimizer.base.exception import NoSparseGradientError 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 6 | 7 | 8 | class MSVAG(BaseOptimizer): 9 | r"""Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients. 10 | 11 | :param params: PARAMETERS. 
iterable of parameters to optimize or dicts defining parameter groups. 12 | :param lr: float. learning rate. 13 | :param beta: float. Moving average (momentum) constant (scalar tensor or float value). 14 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | params: PARAMETERS, 20 | lr: float = 1e-2, 21 | beta: float = 0.9, 22 | maximize: bool = False, 23 | **kwargs, 24 | ): 25 | self.validate_learning_rate(lr) 26 | self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]') 27 | 28 | self.maximize = maximize 29 | 30 | defaults: DEFAULTS = {'lr': lr, 'beta': beta} 31 | 32 | super().__init__(params, defaults) 33 | 34 | def __str__(self) -> str: 35 | return 'MSVAG' 36 | 37 | def init_group(self, group: GROUP, **kwargs) -> None: 38 | for p in group['params']: 39 | if p.grad is None: 40 | continue 41 | 42 | grad = p.grad 43 | if grad.is_sparse: 44 | raise NoSparseGradientError(str(self)) 45 | 46 | state = self.state[p] 47 | 48 | if len(state) == 0: 49 | state['exp_avg'] = torch.zeros_like(p) 50 | state['exp_avg_sq'] = torch.zeros_like(p) 51 | state['s'] = torch.zeros_like(p) 52 | 53 | @staticmethod 54 | def get_rho(beta_power: float, beta: float) -> float: 55 | r"""Get rho.""" 56 | rho: float = (1.0 - beta_power ** 2) * (1.0 - beta) ** 2 # fmt: skip 57 | rho /= (1.0 - beta) * (1.0 - beta_power) ** 2 58 | return min(rho, 0.9999) 59 | 60 | @torch.no_grad() 61 | def step(self, closure: CLOSURE = None) -> LOSS: 62 | loss: LOSS = None 63 | if closure is not None: 64 | with torch.enable_grad(): 65 | loss = closure() 66 | 67 | for group in self.param_groups: 68 | if 'step' not in group: 69 | self.init_group(group) 70 | group['step'] = 1 71 | else: 72 | group['step'] += 1 73 | 74 | beta: float = group['beta'] 75 | beta_power: float = beta ** group['step'] 76 | 77 | for p in group['params']: 78 | if p.grad is None: 79 | continue 80 | 81 | grad = p.grad 82 | 83 | self.maximize_gradient(grad, maximize=self.maximize) 84 | 85 | state = self.state[p] 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | 89 | p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq) 90 | 91 | exp_avg.mul_(beta).add_(grad, alpha=1.0 - beta) 92 | exp_avg_sq.mul_(beta).addcmul_(grad, grad, value=1.0 - beta) 93 | 94 | m = exp_avg.div(beta_power) 95 | v = exp_avg_sq.div(beta_power) 96 | 97 | rho: float = self.get_rho(beta_power, beta) 98 | 99 | m_p2 = m.pow(2) 100 | s = (v - m_p2).div_(1.0 - rho) 101 | 102 | factor = m_p2.div(m_p2 + rho * s) 103 | torch.nan_to_num(factor, nan=0.0, out=factor) 104 | factor.clamp_(0.0, 1.0) 105 | 106 | p.add_(m * factor, alpha=-group['lr']) 107 | 108 | return loss 109 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/orthograd.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict 2 | 3 | import torch 4 | from torch.optim import Optimizer 5 | 6 | from pytorch_optimizer.base.optimizer import BaseOptimizer 7 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE 8 | 9 | 10 | class OrthoGrad(BaseOptimizer): 11 | r"""Grokking at the Edge of Numerical Stability. 12 | 13 | A wrapper optimizer that projects gradients to be orthogonal to the current parameters before performing an update. 14 | 15 | :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. base optimizer. 
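Example (a minimal sketch; it assumes the wrapper accepts an already-constructed optimizer instance, as the ``OPTIMIZER_INSTANCE_OR_CLASS`` type suggests, and uses this package's ``AdamW`` as the base optimizer)::

    import torch

    from pytorch_optimizer.optimizer import AdamW
    from pytorch_optimizer.optimizer.orthograd import OrthoGrad

    model = torch.nn.Linear(4, 1)
    optimizer = OrthoGrad(AdamW(model.parameters(), lr=1e-3))

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()  # each gradient is projected orthogonal to its parameter, then AdamW updates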
16 | """ 17 | 18 | def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None: 19 | self._optimizer_step_pre_hooks: Dict[int, Callable] = {} 20 | self._optimizer_step_post_hooks: Dict[int, Callable] = {} 21 | self.eps: float = 1e-30 22 | 23 | self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs) 24 | 25 | self.defaults: DEFAULTS = self.optimizer.defaults 26 | 27 | def __str__(self) -> str: 28 | return 'OrthoGrad' 29 | 30 | @property 31 | def param_groups(self): 32 | return self.optimizer.param_groups 33 | 34 | @property 35 | def state(self) -> STATE: 36 | return self.optimizer.state 37 | 38 | def state_dict(self) -> STATE: 39 | return self.optimizer.state_dict() 40 | 41 | def load_state_dict(self, state_dict: STATE) -> None: 42 | self.optimizer.load_state_dict(state_dict) 43 | 44 | @torch.no_grad() 45 | def zero_grad(self, set_to_none: bool = True) -> None: 46 | self.optimizer.zero_grad(set_to_none=set_to_none) 47 | 48 | def init_group(self, group: GROUP, **kwargs) -> None: 49 | pass 50 | 51 | @torch.no_grad() 52 | def apply_orthogonal_gradients(self, params) -> None: 53 | for p in params: 54 | if p.grad is None or p.grad.is_sparse or torch.is_complex(p): 55 | continue 56 | 57 | w = p.view(-1) 58 | g = p.grad.view(-1) 59 | 60 | proj = torch.dot(w, g).div_(torch.dot(w, w).add_(self.eps)) 61 | g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj) 62 | g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(self.eps))) 63 | 64 | p.grad.copy_(g_ortho_scaled.view_as(p.grad)) 65 | 66 | @torch.no_grad() 67 | def step(self, closure: CLOSURE = None) -> LOSS: 68 | for group in self.param_groups: 69 | self.apply_orthogonal_gradients(group['params']) 70 | return self.optimizer.step(closure) 71 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/psgd_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from torch.linalg import vector_norm 5 | 6 | 7 | def damped_pair_vg(g: torch.Tensor, damp: float = 2 ** -13) -> Tuple[torch.Tensor, torch.Tensor]: # fmt: skip 8 | r"""Get damped pair v and g. 9 | 10 | Instead of return (v, g), it returns pair (v, g + sqrt(eps)*mean(abs(g))*v) 11 | such that the covariance matrix of the modified g is lower bound by eps * (mean(abs(g)))**2 * I 12 | This should damp the pre-conditioner to encourage numerical stability. 13 | The default amount of damping is 2**(-13), slightly smaller than sqrt(eps('single')). 14 | 15 | If v is integrated out, let's just use the modified g; 16 | If hvp is used, recommend to use L2 regularization to lower bound the Hessian, although this method also works. 17 | 18 | Please check example 19 | https://github.com/lixilinx/psgd_torch/blob/master/misc/psgd_with_finite_precision_arithmetic.py 20 | for the rationale to set default damping level to 2**(-13). 21 | """ 22 | v = torch.randn_like(g) 23 | return v, g + damp * torch.mean(torch.abs(g)) * v 24 | 25 | 26 | def norm_lower_bound(a: torch.Tensor) -> torch.Tensor: 27 | r"""Get a cheap lower bound for the spectral norm of A. 28 | 29 | Numerical results on random matrices with a wide range of distributions and sizes suggest, 30 | norm(A) <= sqrt(2) * norm_lower_bound(A) 31 | Looks to be a very tight lower bound. 
32 | """ 33 | max_abs = torch.max(torch.abs(a)) 34 | if max_abs <= 0: 35 | return max_abs 36 | 37 | a.div_(max_abs) 38 | 39 | aa = torch.real(a * a.conj()) 40 | value0, i = torch.max(torch.sum(aa, dim=0), 0) 41 | value1, j = torch.max(torch.sum(aa, dim=1), 0) 42 | 43 | if value0 > value1: 44 | x = a[:, i].conj() @ a 45 | return max_abs * vector_norm((x / vector_norm(x)) @ a.H) 46 | 47 | x = a @ a[j].conj() 48 | return max_abs * vector_norm(a.H @ (x / vector_norm(x))) 49 | 50 | 51 | def woodbury_identity(inv_a: torch.Tensor, u: torch.Tensor, v: torch.Tensor) -> None: 52 | r"""Get the Woodbury identity. 53 | 54 | inv(A + U * V) = inv(A) - inv(A) * U * inv(I + V * inv(A) * U) * V * inv(A) 55 | 56 | with inplace update of inv_a. 57 | 58 | Note that using the Woodbury identity multiple times could accumulate numerical errors. 59 | """ 60 | inv_au = inv_a @ u 61 | v_inv_au = v @ inv_au 62 | 63 | ident = torch.eye(v_inv_au.shape[0], dtype=v_inv_au.dtype, device=v_inv_au.device) 64 | inv_a.sub_(inv_au @ torch.linalg.solve(ident + v_inv_au, v @ inv_a)) 65 | 66 | 67 | def triu_with_diagonal_and_above(a: torch.Tensor) -> torch.Tensor: 68 | r"""Get triu with diagonal and above. 69 | 70 | It is useful as for a small A, the R of QR decomposition qr(I + A) is about I + triu(A, 0) + triu(A, 1) 71 | """ 72 | return torch.triu(a, diagonal=0) + torch.triu(a, diagonal=1) 73 | 74 | 75 | def update_precondition_dense( 76 | q: torch.Tensor, dxs: List[torch.Tensor], dgs: List[torch.Tensor], step: float = 0.01, eps: float = 1.2e-38 77 | ) -> torch.Tensor: 78 | r"""Update dense pre-conditioner P = Q^T * Q. 79 | 80 | :param q: torch.Tensor. Cholesky factor of pre-conditioner with positive diagonal entries. 81 | :param dxs: List[torch.Tensor]. list of perturbations of parameters. 82 | :param dgs: List[torch.Tensor]. list of perturbations of gradients. 83 | :param step: float. update step size normalized to range [0, 1]. 84 | :param eps: float. an offset to avoid division by zero. 85 | """ 86 | dx = torch.cat([torch.reshape(x, [-1, 1]) for x in dxs]) 87 | dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs]) 88 | 89 | a = q.mm(dg) 90 | b = torch.linalg.solve_triangular(q.t(), dx, upper=False) 91 | 92 | grad = torch.triu(a.mm(a.t()) - b.mm(b.t())) 93 | 94 | return q - (step / norm_lower_bound(grad).add_(eps)) * grad.mm(q) 95 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/qhm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_optimizer.base.exception import NoSparseGradientError 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 6 | 7 | 8 | class QHM(BaseOptimizer): 9 | r"""Quasi-hyperbolic momentum (QHM) optimization algorithm. 10 | 11 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 12 | :param lr: float. learning rate. 13 | :param momentum: float. momentum factor. 14 | :param nu: float. immediate discount factor used to estimate the gradient and its square. 15 | :param weight_decay: float. weight decay (L2 penalty). 16 | :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. 17 | :param fixed_decay: bool. fix weight decay. 18 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 
19 | """ 20 | 21 | def __init__( 22 | self, 23 | params: PARAMETERS, 24 | lr: float = 1e-3, 25 | momentum: float = 0.0, 26 | nu: float = 1.0, 27 | weight_decay: float = 0.0, 28 | weight_decouple: bool = False, 29 | fixed_decay: bool = False, 30 | maximize: bool = False, 31 | **kwargs, 32 | ): 33 | self.validate_learning_rate(lr) 34 | self.validate_range(momentum, 'momentum', 0.0, 1.0) 35 | self.validate_non_negative(weight_decay, 'weight_decay') 36 | self.validate_nus(nu) 37 | 38 | self.maximize = maximize 39 | 40 | defaults: DEFAULTS = { 41 | 'lr': lr, 42 | 'momentum': momentum, 43 | 'nu': nu, 44 | 'weight_decay': weight_decay, 45 | 'weight_decouple': weight_decouple, 46 | 'fixed_decay': fixed_decay, 47 | } 48 | 49 | super().__init__(params, defaults) 50 | 51 | def __str__(self) -> str: 52 | return 'QHM' 53 | 54 | def init_group(self, group: GROUP, **kwargs) -> None: 55 | for p in group['params']: 56 | if p.grad is None: 57 | continue 58 | 59 | grad = p.grad 60 | if grad.is_sparse: 61 | raise NoSparseGradientError(str(self)) 62 | 63 | state = self.state[p] 64 | 65 | if len(state) == 0: 66 | state['momentum_buffer'] = torch.zeros_like(p) 67 | 68 | @torch.no_grad() 69 | def step(self, closure: CLOSURE = None) -> LOSS: 70 | loss: LOSS = None 71 | if closure is not None: 72 | with torch.enable_grad(): 73 | loss = closure() 74 | 75 | for group in self.param_groups: 76 | if 'step' not in group: 77 | self.init_group(group) 78 | group['step'] = 1 79 | else: 80 | group['step'] += 1 81 | 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | 86 | grad = p.grad 87 | 88 | self.maximize_gradient(grad, maximize=self.maximize) 89 | 90 | state = self.state[p] 91 | 92 | buf = state['momentum_buffer'] 93 | 94 | p, grad, buf = self.view_as_real(p, grad, buf) 95 | 96 | self.apply_weight_decay( 97 | p=p, 98 | grad=grad, 99 | lr=group['lr'], 100 | weight_decay=group['weight_decay'], 101 | weight_decouple=group['weight_decouple'], 102 | fixed_decay=group['fixed_decay'], 103 | ) 104 | 105 | buf.mul_(group['momentum']).add_(grad, alpha=1.0 - group['momentum']) 106 | 107 | p.add_(buf, alpha=-group['lr'] * group['nu']) 108 | p.add_(grad, alpha=-group['lr'] * (1.0 - group['nu'])) 109 | 110 | return loss 111 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/srmm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | 5 | from pytorch_optimizer.base.exception import NoComplexParameterError, NoSparseGradientError 6 | from pytorch_optimizer.base.optimizer import BaseOptimizer 7 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 8 | 9 | 10 | class SRMM(BaseOptimizer): 11 | """Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates. 12 | 13 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 14 | :param lr: float. learning rate. 15 | :param beta: float. adaptivity weight. 16 | :param memory_length: Optional[int]. internal memory length for moving average. None for no refreshing. 17 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | params: PARAMETERS, 23 | lr: float = 0.01, 24 | beta: float = 0.5, 25 | memory_length: Optional[int] = 100, 26 | maximize: bool = False, 27 | **kwargs, 28 | ): 29 | self.validate_learning_rate(lr) 30 | self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]') 31 | 32 | self.maximize = maximize 33 | 34 | defaults: DEFAULTS = {'lr': lr, 'beta': beta, 'memory_length': memory_length} 35 | 36 | super().__init__(params, defaults) 37 | 38 | self.base_lrs: List[float] = [group['lr'] for group in self.param_groups] 39 | 40 | def __str__(self) -> str: 41 | return 'SRMM' 42 | 43 | def init_group(self, group: GROUP, **kwargs) -> None: 44 | for p in group['params']: 45 | if p.grad is None: 46 | continue 47 | 48 | grad = p.grad 49 | if grad.is_sparse: 50 | raise NoSparseGradientError(str(self)) 51 | 52 | if torch.is_complex(p): 53 | raise NoComplexParameterError(str(self)) 54 | 55 | state = self.state[p] 56 | 57 | if len(state) == 0: 58 | state['mov_avg_grad'] = torch.zeros_like(grad) 59 | state['mov_avg_param'] = torch.zeros_like(grad) 60 | 61 | @torch.no_grad() 62 | def step(self, closure: CLOSURE = None) -> LOSS: 63 | loss: LOSS = None 64 | if closure is not None: 65 | with torch.enable_grad(): 66 | loss = closure() 67 | 68 | for group in self.param_groups: 69 | if 'step' not in group: 70 | self.init_group(group) 71 | group['step'] = 1 72 | else: 73 | group['step'] += 1 74 | 75 | w_t: float = ( 76 | (group['step'] % (group['memory_length'] if group['memory_length'] is not None else 1)) + 1 77 | ) ** -group['beta'] 78 | 79 | for p in group['params']: 80 | if p.grad is None: 81 | continue 82 | 83 | grad = p.grad 84 | 85 | self.maximize_gradient(grad, maximize=self.maximize) 86 | 87 | state = self.state[p] 88 | 89 | mov_avg_grad, mov_avg_param = state['mov_avg_grad'], state['mov_avg_param'] 90 | 91 | mov_avg_grad.mul_(1.0 - w_t).add_(grad, alpha=w_t) 92 | mov_avg_param.mul_(1.0 - w_t).add_(p, alpha=w_t) 93 | 94 | mov_avg_param.add_(mov_avg_grad, alpha=-group['lr']) 95 | 96 | p.copy_(mov_avg_param) 97 | 98 | return loss 99 | -------------------------------------------------------------------------------- /pytorch_optimizer/optimizer/tiger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_optimizer.base.exception import NoSparseGradientError 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS 6 | 7 | 8 | class Tiger(BaseOptimizer): 9 | r"""A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious. 10 | 11 | :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 12 | :param lr: float. learning rate. 13 | :param beta: float. coefficients used for computing running averages of gradient and the squared hessian trace. 14 | :param weight_decay: float. weight decay (L2 penalty). 15 | :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. 16 | :param fixed_decay: bool. fix weight decay. 17 | :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | params: PARAMETERS, 23 | lr: float = 1e-3, 24 | beta: float = 0.965, 25 | weight_decay: float = 0.01, 26 | weight_decouple: bool = True, 27 | fixed_decay: bool = False, 28 | maximize: bool = False, 29 | **kwargs, 30 | ): 31 | self.validate_learning_rate(lr) 32 | self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[)') 33 | self.validate_non_negative(weight_decay, 'weight_decay') 34 | 35 | self.maximize = maximize 36 | 37 | defaults: DEFAULTS = { 38 | 'lr': lr, 39 | 'beta': beta, 40 | 'weight_decay': weight_decay, 41 | 'weight_decouple': weight_decouple, 42 | 'fixed_decay': fixed_decay, 43 | } 44 | 45 | super().__init__(params, defaults) 46 | 47 | def __str__(self) -> str: 48 | return 'Tiger' 49 | 50 | def init_group(self, group: GROUP, **kwargs) -> None: 51 | for p in group['params']: 52 | if p.grad is None: 53 | continue 54 | 55 | grad = p.grad 56 | if grad.is_sparse: 57 | raise NoSparseGradientError(str(self)) 58 | 59 | state = self.state[p] 60 | 61 | if len(state) == 0: 62 | state['exp_avg'] = torch.zeros_like(grad) 63 | 64 | @torch.no_grad() 65 | def step(self, closure: CLOSURE = None) -> LOSS: 66 | loss: LOSS = None 67 | if closure is not None: 68 | with torch.enable_grad(): 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | if 'step' not in group: 73 | self.init_group(group) 74 | group['step'] = 1 75 | else: 76 | group['step'] += 1 77 | 78 | beta = group['beta'] 79 | 80 | for p in group['params']: 81 | if p.grad is None: 82 | continue 83 | 84 | grad = p.grad 85 | 86 | self.maximize_gradient(grad, maximize=self.maximize) 87 | 88 | state = self.state[p] 89 | 90 | self.apply_weight_decay( 91 | p=p, 92 | grad=grad, 93 | lr=group['lr'], 94 | weight_decay=group['weight_decay'], 95 | weight_decouple=group['weight_decouple'], 96 | fixed_decay=group['fixed_decay'], 97 | ) 98 | 99 | exp_avg = state['exp_avg'] 100 | exp_avg.mul_(beta).add_(grad, alpha=1.0 - beta) 101 | 102 | p.add_( 103 | torch.sign(exp_avg) if not torch.is_complex(exp_avg) else torch.sgn(exp_avg), alpha=-group['lr'] 104 | ) 105 | 106 | return loss 107 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | 3 | black==24.8.0 ; python_version == "3.8" 4 | black==25.1.0 ; python_version >= "3.9" 5 | click==8.1.8 ; python_version >= "3.8" 6 | colorama==0.4.6 ; python_version >= "3.8" and (sys_platform == "win32" or platform_system == "Windows") 7 | coverage[toml]==7.6.1 ; python_version == "3.8" 8 | coverage[toml]==7.8.2 ; python_version >= "3.9" 9 | exceptiongroup==1.3.0 ; python_version < "3.11" and python_version >= "3.8" 10 | filelock==3.16.1 ; python_version == "3.8" 11 | filelock==3.18.0 ; python_version >= "3.9" 12 | fsspec==2025.3.0 ; python_version == "3.8" 13 | fsspec==2025.5.1 ; python_version >= "3.9" 14 | iniconfig==2.1.0 ; python_version >= "3.8" 15 | isort==5.13.2 ; python_version == "3.8" 16 | isort==6.0.1 ; python_version >= "3.9" 17 | jinja2==3.1.6 ; python_version >= "3.8" 18 | markupsafe==2.1.5 ; python_version == "3.8" 19 | markupsafe==3.0.2 ; python_version >= "3.9" 20 | mpmath==1.3.0 ; python_version >= "3.8" 21 | mypy-extensions==1.1.0 ; python_version >= "3.8" 22 | networkx==3.1 ; python_version == "3.8" 23 | networkx==3.2.1 ; python_version >= "3.9" 24 | numpy==1.24.4 ; python_version == "3.8" 25 | numpy==2.0.2 ; python_version >= "3.9" 26 | 
packaging==25.0 ; python_version >= "3.8" 27 | pathspec==0.12.1 ; python_version >= "3.8" 28 | platformdirs==4.3.6 ; python_version == "3.8" 29 | platformdirs==4.3.8 ; python_version >= "3.9" 30 | pluggy==1.5.0 ; python_version == "3.8" 31 | pluggy==1.6.0 ; python_version >= "3.9" 32 | pytest-cov==5.0.0 ; python_version >= "3.8" 33 | pytest==8.3.5 ; python_version >= "3.8" 34 | ruff==0.11.12 ; python_version >= "3.8" 35 | setuptools==80.9.0 ; python_version >= "3.12" 36 | sympy==1.13.3 ; python_version == "3.8" 37 | sympy==1.14.0 ; python_version >= "3.9" 38 | tomli==2.2.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8" 39 | torch==2.4.1+cpu ; python_version == "3.8" 40 | torch==2.7.0+cpu ; python_version >= "3.9" 41 | typing-extensions==4.13.2 ; python_version >= "3.8" 42 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | --index-url https://pypi.org/simple 2 | --extra-index-url https://download.pytorch.org/whl/cpu 3 | numpy<2.0 4 | torch==2.6.0 5 | mkdocs==1.6.1 6 | mkdocs-material==9.5.45 7 | pymdown-extensions==10.12 8 | mkdocstrings-python==1.12.2 9 | markdown-include==0.8.1 10 | mdx_truly_sane_lists==1.3 11 | mkdocs-awesome-pages-plugin==2.9.3 12 | griffe==1.5.1 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | 3 | filelock==3.16.1 ; python_version == "3.8" 4 | filelock==3.18.0 ; python_version >= "3.9" 5 | fsspec==2025.3.0 ; python_version == "3.8" 6 | fsspec==2025.5.1 ; python_version >= "3.9" 7 | jinja2==3.1.6 ; python_version >= "3.8" 8 | markupsafe==2.1.5 ; python_version == "3.8" 9 | markupsafe==3.0.2 ; python_version >= "3.9" 10 | mpmath==1.3.0 ; python_version >= "3.8" 11 | networkx==3.1 ; python_version == "3.8" 12 | networkx==3.2.1 ; python_version >= "3.9" 13 | numpy==1.24.4 ; python_version == "3.8" 14 | numpy==2.0.2 ; python_version >= "3.9" 15 | setuptools==80.9.0 ; python_version >= "3.12" 16 | sympy==1.13.3 ; python_version == "3.8" 17 | sympy==1.14.0 ; python_version >= "3.9" 18 | torch==2.4.1+cpu ; python_version == "3.8" 19 | torch==2.7.0+cpu ; python_version >= "3.9" 20 | typing-extensions==4.13.2 ; python_version >= "3.8" 21 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/pytorch_optimizer/030d303b8f3ef506b4aa96d975d2234d13afa5f5/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | 8 | @pytest.fixture(scope='session') 9 | def environment(num_samples: int = 100, dims: int = 2, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]: 10 | torch.manual_seed(42) 11 | rng = np.random.RandomState(seed) 12 | 13 | x = rng.randn(num_samples, dims) * 2 14 | 15 | # center the first N/2 points at (-2, -2) 16 | mid: int = num_samples // 2 17 | x[:mid, :] = x[:mid, :] - 2 * np.ones((mid, dims)) 18 | 19 | # center the last N/2 points at (2, 2) 20 | x[mid:, :] = x[mid:, :] + 2 * np.ones((mid, dims)) 21 | 22 | # 
labels: first N/2 are 0, last N/2 are 1 23 | y = np.array([0] * mid + [1] * mid).reshape(100, 1) 24 | 25 | return torch.Tensor(x), torch.Tensor(y) 26 | -------------------------------------------------------------------------------- /tests/test_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from pytorch_optimizer.base.optimizer import BaseOptimizer 5 | from tests.utils import simple_parameter 6 | 7 | 8 | def test_set_hessian(): 9 | param = simple_parameter() 10 | 11 | param_groups = [{'params': param}] 12 | hessian = [torch.zeros(2, 1)] 13 | 14 | with pytest.raises(ValueError): 15 | BaseOptimizer.set_hessian(param_groups, {'dummy': param}, hessian) 16 | 17 | 18 | def test_compute_hutchinson_hessian(): 19 | with pytest.raises(NotImplementedError): 20 | BaseOptimizer.compute_hutchinson_hessian({}, {}, distribution='dummy') 21 | 22 | 23 | def test_validate_boundary(): 24 | x: float = -1.0 25 | 26 | with pytest.raises(ValueError): 27 | BaseOptimizer.validate_boundary(x, -2.0, bound_type='upper') 28 | 29 | with pytest.raises(ValueError): 30 | BaseOptimizer.validate_boundary(x, 1.0, bound_type='lower') 31 | 32 | 33 | @pytest.mark.parametrize('range_type', ['[]', '[)', '(]', '()']) 34 | def test_validate_range(range_type): 35 | with pytest.raises(ValueError): 36 | BaseOptimizer.validate_range(-1.0, 'x', 0.0, 1.0, range_type=range_type) 37 | 38 | 39 | def test_non_positive(): 40 | with pytest.raises(ValueError): 41 | BaseOptimizer.validate_non_positive(1.0, 'asdf') 42 | 43 | 44 | def test_mod(): 45 | with pytest.raises(ValueError): 46 | BaseOptimizer.validate_mod(10, 3) 47 | 48 | 49 | def test_maximize_gradient(): 50 | grad = torch.ones((1,)) 51 | expected = -grad 52 | 53 | BaseOptimizer.maximize_gradient(grad, True) 54 | 55 | torch.testing.assert_close(grad, expected) 56 | -------------------------------------------------------------------------------- /tests/test_create_optimizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pytorch_optimizer.optimizer import create_optimizer, load_optimizer 4 | from tests.constants import VALID_OPTIMIZER_NAMES 5 | from tests.utils import Example 6 | 7 | 8 | @pytest.mark.parametrize('use_lookahead', [True, False]) 9 | @pytest.mark.parametrize('use_orthograd', [True, False]) 10 | @pytest.mark.parametrize('optimizer_name', VALID_OPTIMIZER_NAMES) 11 | def test_create_optimizer(use_lookahead, use_orthograd, optimizer_name): 12 | if optimizer_name in ('adamw', 'adam', 'sgd', 'demo'): 13 | pytest.skip(f'skip {optimizer_name}') 14 | 15 | if use_lookahead and use_orthograd: 16 | pytest.skip() 17 | 18 | kwargs = {'eps': 1e-8, 'k': 7} 19 | if optimizer_name == 'ranger21': 20 | kwargs.update({'num_iterations': 1}) 21 | elif optimizer_name == 'bsam': 22 | kwargs.update({'num_data': 1}) 23 | elif optimizer_name == 'demo': 24 | kwargs = {} 25 | 26 | create_optimizer( 27 | Example(), 28 | optimizer_name=optimizer_name, 29 | use_lookahead=use_lookahead, 30 | use_orthograd=use_orthograd, 31 | **kwargs, 32 | ) 33 | 34 | 35 | def test_bnb_optimizer(): 36 | with pytest.raises(ImportError): 37 | load_optimizer('bnb_adamw8bit') 38 | 39 | 40 | def test_q_galore_optimizer(): 41 | with pytest.raises(ImportError): 42 | load_optimizer('q_galore_adamw8bit') 43 | 44 | 45 | def test_torchao_optimizer(): 46 | with pytest.raises(ImportError): 47 | load_optimizer('torchao_adamw4bit') 48 | 
-------------------------------------------------------------------------------- /tests/test_load_modules.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pytorch_optimizer.loss import get_supported_loss_functions 4 | from pytorch_optimizer.lr_scheduler import get_supported_lr_schedulers, load_lr_scheduler 5 | from pytorch_optimizer.optimizer import get_supported_optimizers, load_optimizer 6 | from tests.constants import ( 7 | INVALID_LR_SCHEDULER_NAMES, 8 | INVALID_OPTIMIZER_NAMES, 9 | VALID_LR_SCHEDULER_NAMES, 10 | VALID_OPTIMIZER_NAMES, 11 | ) 12 | 13 | 14 | @pytest.mark.parametrize('valid_optimizer_names', VALID_OPTIMIZER_NAMES) 15 | def test_load_optimizer_valid(valid_optimizer_names): 16 | load_optimizer(valid_optimizer_names) 17 | 18 | 19 | @pytest.mark.parametrize('invalid_optimizer_names', INVALID_OPTIMIZER_NAMES) 20 | def test_load_optimizer_invalid(invalid_optimizer_names): 21 | with pytest.raises(NotImplementedError): 22 | load_optimizer(invalid_optimizer_names) 23 | 24 | 25 | @pytest.mark.parametrize('valid_lr_scheduler_names', VALID_LR_SCHEDULER_NAMES) 26 | def test_load_lr_scheduler_valid(valid_lr_scheduler_names): 27 | load_lr_scheduler(valid_lr_scheduler_names) 28 | 29 | 30 | @pytest.mark.parametrize('invalid_lr_scheduler_names', INVALID_LR_SCHEDULER_NAMES) 31 | def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): 32 | with pytest.raises(NotImplementedError): 33 | load_lr_scheduler(invalid_lr_scheduler_names) 34 | 35 | 36 | def test_get_supported_optimizers(): 37 | assert len(get_supported_optimizers()) == 105 38 | assert len(get_supported_optimizers('adam*')) == 9 39 | assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 12 40 | 41 | 42 | def test_get_supported_lr_schedulers(): 43 | assert len(get_supported_lr_schedulers()) == 16 44 | assert len(get_supported_lr_schedulers('cosine*')) == 4 45 | assert len(get_supported_lr_schedulers(['cosine*', '*warm*'])) == 5 46 | 47 | 48 | def test_get_supported_loss_functions(): 49 | assert len(get_supported_loss_functions()) == 13 50 | assert len(get_supported_loss_functions('*focal*')) == 4 51 | assert len(get_supported_loss_functions(['*focal*', 'bce*'])) == 5 52 | -------------------------------------------------------------------------------- /tests/test_lr_scheduler_parameters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError 5 | from pytorch_optimizer.lr_scheduler import get_chebyshev_perm_steps 6 | from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts 7 | from pytorch_optimizer.lr_scheduler.linear_warmup import PolyScheduler 8 | from pytorch_optimizer.optimizer import AdamW 9 | from tests.utils import Example 10 | 11 | 12 | def test_cosine_annealing_warmup_restarts_params(): 13 | optimizer = AdamW(Example().parameters()) 14 | 15 | with pytest.raises(ValueError) as error_info: 16 | CosineAnnealingWarmupRestarts( 17 | optimizer=optimizer, 18 | first_cycle_steps=10, 19 | warmup_steps=20, 20 | ) 21 | 22 | assert str(error_info.value) == '[-] warmup_steps must be smaller than first_cycle_steps. 
20 < 10' 23 | 24 | min_lr: float = 1e-6 25 | first_cycle_steps: int = 5 26 | lr_scheduler = CosineAnnealingWarmupRestarts( 27 | optimizer=optimizer, 28 | min_lr=min_lr, 29 | first_cycle_steps=first_cycle_steps, 30 | warmup_steps=0, 31 | ) 32 | lr_scheduler.step_in_cycle = -1 33 | expected_max_lr: float = round(lr_scheduler.get_lr()[0], 6) 34 | np.testing.assert_almost_equal(min_lr, expected_max_lr) 35 | 36 | for _ in range(first_cycle_steps + 1): 37 | lr_scheduler.step(epoch=None) 38 | 39 | 40 | def test_linear_warmup_lr_scheduler_params(): 41 | optimizer = AdamW(Example().parameters()) 42 | 43 | with pytest.raises(ValueError) as error_info: 44 | PolyScheduler(poly_order=-1, optimizer=optimizer, t_max=1, max_lr=1) 45 | 46 | assert str(error_info.value) == '[-] poly_order must be positive. -1' 47 | 48 | with pytest.raises(NegativeLRError): 49 | PolyScheduler(optimizer=optimizer, t_max=1, max_lr=-1) 50 | 51 | with pytest.raises(NegativeLRError): 52 | PolyScheduler(optimizer=optimizer, t_max=1, max_lr=1, min_lr=-1) 53 | 54 | with pytest.raises(NegativeLRError): 55 | PolyScheduler(optimizer=optimizer, t_max=1, max_lr=1, min_lr=1, init_lr=-1) 56 | 57 | with pytest.raises(NegativeStepError): 58 | PolyScheduler(optimizer=optimizer, t_max=-1, max_lr=1, min_lr=1, init_lr=1) 59 | 60 | with pytest.raises(NegativeStepError): 61 | PolyScheduler(optimizer=optimizer, t_max=1, max_lr=1, min_lr=1, init_lr=1, warmup_steps=-1) 62 | 63 | 64 | def test_chebyshev_params(): 65 | with pytest.raises(IndexError): 66 | get_chebyshev_perm_steps(0) 67 | -------------------------------------------------------------------------------- /tests/test_optimizer_variants.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from tests.constants import ( 6 | ADAMD_SUPPORTED_OPTIMIZERS, 7 | ADANORM_SUPPORTED_OPTIMIZERS, 8 | COPT_SUPPORTED_OPTIMIZERS, 9 | STABLE_ADAMW_SUPPORTED_OPTIMIZERS, 10 | ) 11 | from tests.utils import build_model, ids, simple_parameter, tensor_to_numpy 12 | 13 | 14 | @pytest.mark.parametrize('optimizer_config', ADANORM_SUPPORTED_OPTIMIZERS, ids=ids) 15 | def test_adanorm_optimizer(optimizer_config, environment): 16 | x_data, y_data = environment 17 | model, loss_fn = build_model() 18 | 19 | optimizer_class, config, num_iterations = optimizer_config 20 | 21 | optimizer = optimizer_class(model.parameters(), **config, adanorm=True) 22 | 23 | init_loss, loss = np.inf, np.inf 24 | for _ in range(num_iterations): 25 | optimizer.zero_grad() 26 | 27 | y_pred = model(x_data) 28 | loss = loss_fn(y_pred, y_data) 29 | 30 | if init_loss == np.inf: 31 | init_loss = loss 32 | 33 | loss.backward() 34 | 35 | optimizer.step() 36 | 37 | assert tensor_to_numpy(init_loss) > 1.75 * tensor_to_numpy(loss) 38 | 39 | 40 | @pytest.mark.parametrize('optimizer_config', ADANORM_SUPPORTED_OPTIMIZERS, ids=ids) 41 | def test_adanorm_variant(optimizer_config): 42 | param = simple_parameter(True) 43 | param.grad = torch.ones(1, 1) 44 | 45 | optimizer_class, config = optimizer_config[:2] 46 | 47 | optimizer = optimizer_class([param], adanorm=True) 48 | optimizer.step() 49 | 50 | param.grad = torch.zeros(1, 1) 51 | optimizer.step() 52 | 53 | 54 | @pytest.mark.parametrize('optimizer_config', ADAMD_SUPPORTED_OPTIMIZERS, ids=ids) 55 | def test_adamd_variant(optimizer_config, environment): 56 | x_data, y_data = environment 57 | model, loss_fn = build_model() 58 | 59 | optimizer_class, config, num_iterations = optimizer_config 60 | 61 | optimizer = 
optimizer_class(model.parameters(), **config, adam_debias=True) 62 | 63 | init_loss, loss = np.inf, np.inf 64 | for _ in range(num_iterations): 65 | optimizer.zero_grad() 66 | 67 | y_pred = model(x_data) 68 | loss = loss_fn(y_pred, y_data) 69 | 70 | if init_loss == np.inf: 71 | init_loss = loss 72 | 73 | loss.backward(create_graph=optimizer_class.__name__ in ('AdaHessian',)) 74 | 75 | optimizer.step() 76 | 77 | assert tensor_to_numpy(init_loss) > 2.0 * tensor_to_numpy(loss) 78 | 79 | 80 | @pytest.mark.parametrize('optimizer_config', COPT_SUPPORTED_OPTIMIZERS, ids=ids) 81 | def test_cautious_variant(optimizer_config, environment): 82 | x_data, y_data = environment 83 | model, loss_fn = build_model() 84 | 85 | optimizer_class, config, num_iterations = optimizer_config 86 | 87 | optimizer = optimizer_class(model.parameters(), **config, cautious=True) 88 | 89 | init_loss, loss = np.inf, np.inf 90 | for _ in range(num_iterations): 91 | optimizer.zero_grad() 92 | 93 | y_pred = model(x_data) 94 | loss = loss_fn(y_pred, y_data) 95 | 96 | if init_loss == np.inf: 97 | init_loss = loss 98 | 99 | loss.backward() 100 | 101 | optimizer.step() 102 | 103 | assert tensor_to_numpy(init_loss) > 1.5 * tensor_to_numpy(loss) 104 | 105 | 106 | @pytest.mark.parametrize('optimizer_config', STABLE_ADAMW_SUPPORTED_OPTIMIZERS, ids=ids) 107 | def test_stable_adamw_variant(optimizer_config, environment): 108 | x_data, y_data = environment 109 | model, loss_fn = build_model() 110 | 111 | optimizer_class, config, num_iterations = optimizer_config 112 | 113 | optimizer = optimizer_class(model.parameters(), **config) 114 | 115 | init_loss, loss = np.inf, np.inf 116 | for _ in range(num_iterations): 117 | optimizer.zero_grad() 118 | 119 | y_pred = model(x_data) 120 | loss = loss_fn(y_pred, y_data) 121 | 122 | if init_loss == np.inf: 123 | init_loss = loss 124 | 125 | loss.backward() 126 | 127 | optimizer.step() 128 | 129 | assert tensor_to_numpy(init_loss) > 1.5 * tensor_to_numpy(loss) 130 | --------------------------------------------------------------------------------