├── docs ├── .gitignore ├── _static │ └── css │ │ └── custom.css ├── images │ └── examples │ │ ├── lbfgs.png │ │ ├── mnist.png │ │ ├── ogda.png │ │ ├── flax_optax.png │ │ ├── lookahead.png │ │ ├── adversarial.png │ │ ├── contrib │ │ ├── sam.png │ │ ├── reduce_on_plateau.png │ │ └── ademamix_rosenbrock.png │ │ ├── perturbations.png │ │ ├── cifar10_resnet.png │ │ ├── tiny_shakespeare.png │ │ └── linear_assignment_problem.png ├── api │ ├── assignment.rst │ ├── combining_optimizers.rst │ ├── perturbations.rst │ ├── apply_updates.rst │ ├── experimental.rst │ ├── optimizer_wrappers.rst │ ├── stochastic_gradient_estimators.rst │ ├── optimizer_schedules.rst │ ├── optimizers.rst │ ├── projections.rst │ ├── contrib.rst │ ├── losses.rst │ └── utilities.rst ├── pyproject.toml ├── Makefile ├── index.rst └── ext │ └── coverage_check.py ├── .github ├── workflows │ ├── mlc_config.json │ └── pypi-publish.yml └── check_license_headers.py ├── examples ├── README.md ├── contrib │ └── README.md └── pyproject.toml ├── .gitignore ├── readthedocs.yml ├── optax ├── experimental │ └── __init__.py ├── schedules │ ├── inject.py │ ├── _join_test.py │ ├── _join.py │ └── __init__.py ├── assignment │ ├── __init__.py │ └── _hungarian_algorithm_test.py ├── perturbations │ └── __init__.py ├── second_order │ ├── __init__.py │ └── _deprecated.py ├── _src │ ├── combine.py │ ├── constrain.py │ ├── wrappers.py │ ├── schedule.py │ ├── deprecations.py │ ├── update_test.py │ ├── factorized_test.py │ ├── float64_test.py │ ├── update.py │ ├── test_utils.py │ ├── sharding_test.py │ └── numerics.py ├── optax_test.py ├── CHANGELOG.md ├── projections │ └── __init__.py ├── losses │ ├── _smoothing.py │ ├── _fenchel_young.py │ ├── _smoothing_test.py │ ├── __init__.py │ ├── _fenchel_young_test.py │ └── _self_supervised_test.py ├── transforms │ ├── _layouts_test.py │ ├── _layouts.py │ ├── _constraining.py │ ├── _monitoring_test.py │ ├── _adding_test.py │ ├── _freezing.py │ ├── __init__.py │ └── _constraining_test.py ├── tree │ └── __init__.py ├── contrib │ ├── _complex_valued_test.py │ ├── __init__.py │ ├── _sam_test.py │ ├── _mechanic_test.py │ ├── _privacy_test.py │ ├── _complex_valued.py │ └── _cocob.py └── tree_utils │ ├── _random.py │ ├── _random_test.py │ ├── __init__.py │ └── _casting_test.py ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── test.sh └── pyproject.toml /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | div.math { 2 | flex-direction: row; 3 | } 4 | -------------------------------------------------------------------------------- /docs/images/examples/lbfgs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/lbfgs.png -------------------------------------------------------------------------------- /docs/images/examples/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/mnist.png -------------------------------------------------------------------------------- /docs/images/examples/ogda.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/ogda.png -------------------------------------------------------------------------------- /docs/images/examples/flax_optax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/flax_optax.png -------------------------------------------------------------------------------- /docs/images/examples/lookahead.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/lookahead.png -------------------------------------------------------------------------------- /docs/images/examples/adversarial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/adversarial.png -------------------------------------------------------------------------------- /docs/images/examples/contrib/sam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/contrib/sam.png -------------------------------------------------------------------------------- /docs/images/examples/perturbations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/perturbations.png -------------------------------------------------------------------------------- /docs/images/examples/cifar10_resnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/cifar10_resnet.png -------------------------------------------------------------------------------- /docs/images/examples/tiny_shakespeare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/tiny_shakespeare.png -------------------------------------------------------------------------------- /.github/workflows/mlc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "timeout": "20s", 3 | "retryCount": 5, 4 | "aliveStatusCodes": [ 5 | 429 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /docs/images/examples/contrib/reduce_on_plateau.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/contrib/reduce_on_plateau.png -------------------------------------------------------------------------------- /docs/images/examples/linear_assignment_problem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/linear_assignment_problem.png -------------------------------------------------------------------------------- /docs/images/examples/contrib/ademamix_rosenbrock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/contrib/ademamix_rosenbrock.png 
-------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains examples using the optax library. 4 | 5 | ```{toctree} 6 | :glob: 7 | :maxdepth: 1 8 | 9 | * 10 | ``` 11 | -------------------------------------------------------------------------------- /examples/contrib/README.md: -------------------------------------------------------------------------------- 1 | # Contrib Examples 2 | 3 | Examples that make use of the `optax.contrib` module. 4 | 5 | ```{toctree} 6 | :glob: 7 | :maxdepth: 1 8 | 9 | * 10 | ``` 11 | -------------------------------------------------------------------------------- /docs/api/assignment.rst: -------------------------------------------------------------------------------- 1 | Assignment problem 2 | ================== 3 | 4 | .. currentmodule:: optax.assignment 5 | 6 | .. autosummary:: 7 | hungarian_algorithm 8 | 9 | 10 | Hungarian algorithm 11 | ~~~~~~~~~~~~~~~~~~~ 12 | .. autofunction:: hungarian_algorithm 13 | -------------------------------------------------------------------------------- /docs/api/combining_optimizers.rst: -------------------------------------------------------------------------------- 1 | Combining Optimizers 2 | ===================== 3 | 4 | .. currentmodule:: optax 5 | 6 | .. autosummary:: 7 | chain 8 | named_chain 9 | partition 10 | 11 | Chain 12 | ~~~~~ 13 | .. autofunction:: chain 14 | .. autofunction:: named_chain 15 | 16 | Partition 17 | ~~~~~~~~~ 18 | .. autofunction:: partition 19 | .. autoclass:: PartitionState 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Building and releasing library: 2 | *.egg-info 3 | *.pyc 4 | *.so 5 | build/ 6 | dist/ 7 | venv/ 8 | _testing/ 9 | 10 | # Building the documentation 11 | docs/_autosummary 12 | docs/_collections 13 | docs/modules/generated 14 | docs/sg_execution_times.rst 15 | 16 | # Mac OS 17 | .DS_Store 18 | 19 | # Python tools 20 | .mypy_cache/ 21 | .pytype/ 22 | .ipynb_checkpoints 23 | 24 | # Editors 25 | .idea 26 | .vscode 27 | 28 | -------------------------------------------------------------------------------- /docs/api/perturbations.rst: -------------------------------------------------------------------------------- 1 | Perturbations 2 | ============= 3 | 4 | .. currentmodule:: optax.perturbations 5 | 6 | .. autosummary:: 7 | make_perturbed_fun 8 | Gumbel 9 | Normal 10 | 11 | 12 | Gumbel noise 13 | ~~~~~~~~~~~~ 14 | .. autoclass:: Gumbel 15 | 16 | Make perturbed function 17 | ~~~~~~~~~~~~~~~~~~~~~~~ 18 | .. autofunction:: make_perturbed_fun 19 | 20 | Normal noise 21 | ~~~~~~~~~~~~ 22 | .. autoclass:: Normal 23 | 24 | -------------------------------------------------------------------------------- /docs/api/apply_updates.rst: -------------------------------------------------------------------------------- 1 | Apply Updates 2 | ============= 3 | 4 | .. currentmodule:: optax 5 | 6 | .. autosummary:: 7 | apply_updates 8 | incremental_update 9 | periodic_update 10 | 11 | Apply updates 12 | ~~~~~~~~~~~~~~~~~ 13 | .. autofunction:: apply_updates 14 | 15 | Incremental update 16 | ~~~~~~~~~~~~~~~~~~ 17 | .. autofunction:: incremental_update 18 | 19 | Periodic update 20 | ~~~~~~~~~~~~~~~ 21 | .. 
autofunction:: periodic_update 22 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.11" 10 | 11 | sphinx: 12 | builder: html 13 | configuration: docs/conf.py 14 | fail_on_warning: false 15 | 16 | python: 17 | install: 18 | # Equivalent to 'pip install .' 19 | - method: pip 20 | path: . 21 | # Equivalent to 'pip install .[docs]' 22 | - method: pip 23 | path: . 24 | extra_requirements: 25 | - docs 26 | -------------------------------------------------------------------------------- /docs/api/experimental.rst: -------------------------------------------------------------------------------- 1 | 🧪 Experimental 2 | =============== 3 | 4 | Experimental features subject to changes before being graduated into `optax`. 5 | 6 | .. currentmodule:: optax.experimental 7 | 8 | .. autosummary:: 9 | microbatching.microbatch 10 | microbatching.micro_vmap 11 | microbatching.micro_grad 12 | microbatching.AccumulationType 13 | microbatching.Accumulator 14 | 15 | .. currentmodule:: optax.experimental.microbatching 16 | 17 | Microbatching 18 | ~~~~~~~~~~~~~ 19 | .. autofunction:: microbatch 20 | .. autofunction:: micro_vmap 21 | .. autofunction:: micro_grad 22 | .. autofunction:: reshape_batch_axis 23 | .. autoclass:: AccumulationType 24 | :members: 25 | .. autoclass:: Accumulator 26 | -------------------------------------------------------------------------------- /docs/pyproject.toml: -------------------------------------------------------------------------------- 1 | # The pyproject.toml is here to override ruff linter checks in this directory 2 | # changes from base pyproject.toml: 3 | # - allow E501 line too long 4 | 5 | [tool.ruff] 6 | line-length = 80 7 | 8 | [tool.ruff.lint] 9 | select = [ 10 | "F", 11 | "E", 12 | "W291", # whitespace at the end of the line 13 | "B023", # pylint's cell-var-over-loop, closures capturing variables in loop 14 | ] 15 | ignore = [ 16 | "E731", # lambdas are allowed 17 | "F401", # allow unused imports 18 | "E501", # allow line-too-long in notebooks 19 | "E402", # allow modules not at top of file 20 | "E741", # allow "l" as a variable name 21 | "E703", # allow semicolons (for jupyter notebooks) 22 | ] 23 | -------------------------------------------------------------------------------- /examples/pyproject.toml: -------------------------------------------------------------------------------- 1 | # The pyproject.toml is here to override ruff linter checks in this directory 2 | # changes from base pyproject.toml: 3 | # - allow E501 line too long 4 | 5 | [tool.ruff] 6 | line-length = 80 7 | 8 | [tool.ruff.lint] 9 | select = [ 10 | "F", 11 | "E", 12 | "W291", # whitespace at the end of the line 13 | "B023", # pylint's cell-var-over-loop, closures capturing variables in loop 14 | ] 15 | ignore = [ 16 | "E731", # lambdas are allowed 17 | "F401", # allow unused imports 18 | "E501", # allow line-too-long in notebooks 19 | "E402", # allow modules not at top of file 20 | "E741", # allow "l" as a variable name 21 | "E703", # allow semicolons (for jupyter notebooks) 22 | ] 23 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # 
Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | 21 | clean: 22 | rm -rf $(BUILDDIR)/* auto_examples _collections modules 23 | 24 | html-noplot: 25 | $(SPHINXBUILD) -D plot_gallery=0 -D nb_execution_mode=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html 26 | @echo 27 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 28 | -------------------------------------------------------------------------------- /optax/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Experimental optax modules.""" 17 | 18 | from . import _aggregating as aggregating 19 | from .microbatching import microbatch 20 | 21 | 22 | __all__ = [ 23 | 'aggregating', 24 | 'microbatch', 25 | ] 26 | -------------------------------------------------------------------------------- /optax/schedules/inject.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Import stub.""" 16 | 17 | # TODO(mtthss): delete this file asap. 18 | 19 | from optax.schedules import _inject 20 | 21 | InjectHyperparamsState = _inject.InjectHyperparamsState 22 | inject_hyperparams = _inject.inject_hyperparams 23 | -------------------------------------------------------------------------------- /optax/assignment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The assignment sub-package.""" 16 | 17 | # pylint:disable=g-importing-member 18 | 19 | from optax.assignment._hungarian_algorithm import base_hungarian_algorithm 20 | from optax.assignment._hungarian_algorithm import hungarian_algorithm 21 | -------------------------------------------------------------------------------- /optax/perturbations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The perturbations sub-package.""" 16 | 17 | # pylint: disable=g-importing-member 18 | 19 | from optax.perturbations._make_pert import Gumbel 20 | from optax.perturbations._make_pert import make_perturbed_fun 21 | from optax.perturbations._make_pert import Normal 22 | -------------------------------------------------------------------------------- /optax/second_order/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """The second order optimization sub-package.""" 16 | 17 | # pylint: disable=g-importing-member 18 | 19 | from optax.second_order._deprecated import fisher_diag 20 | from optax.second_order._deprecated import hessian_diag 21 | from optax.second_order._deprecated import hvp 22 | -------------------------------------------------------------------------------- /optax/_src/combine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Flexibly compose gradient transformations.""" 16 | 17 | from optax.transforms import _combining 18 | 19 | chain = _combining.chain 20 | named_chain = _combining.named_chain 21 | partition = _combining.partition 22 | multi_transform = _combining.partition # for backwards compatibility 23 | MultiTransformState = _combining.PartitionState 24 | -------------------------------------------------------------------------------- /docs/api/optimizer_wrappers.rst: -------------------------------------------------------------------------------- 1 | Optimizer Wrappers 2 | ==================== 3 | 4 | .. currentmodule:: optax 5 | 6 | .. autosummary:: 7 | apply_if_finite 8 | ApplyIfFiniteState 9 | flatten 10 | lookahead 11 | LookaheadParams 12 | LookaheadState 13 | masked 14 | MaskedState 15 | MultiSteps 16 | MultiStepsState 17 | ShouldSkipUpdateFunction 18 | skip_large_updates 19 | skip_not_finite 20 | 21 | 22 | Apply if finite 23 | ~~~~~~~~~~~~~~~ 24 | .. autofunction:: apply_if_finite 25 | .. autoclass:: ApplyIfFiniteState 26 | 27 | Flatten 28 | ~~~~~~~~ 29 | .. autofunction:: flatten 30 | 31 | Lookahead 32 | ~~~~~~~~~ 33 | .. autofunction:: lookahead 34 | .. autoclass:: LookaheadParams 35 | .. autoclass:: LookaheadState 36 | 37 | Masked update 38 | ~~~~~~~~~~~~~ 39 | .. autofunction:: masked 40 | .. autoclass:: MaskedState 41 | 42 | Multi-step update 43 | ~~~~~~~~~~~~~~~~~ 44 | .. autoclass:: MultiSteps 45 | .. autoclass:: MultiStepsState 46 | .. autoclass:: ShouldSkipUpdateFunction 47 | .. autofunction:: skip_large_updates 48 | .. autofunction:: skip_not_finite 49 | -------------------------------------------------------------------------------- /optax/_src/constrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Gradient transformations used to enforce specific constraints.""" 16 | 17 | from optax.transforms import _constraining 18 | 19 | keep_params_nonnegative = _constraining.keep_params_nonnegative 20 | NonNegativeParamsState = _constraining.NonNegativeParamsState 21 | zero_nans = _constraining.zero_nans 22 | ZeroNansState = _constraining.ZeroNansState 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Install the pre-commit hooks below with 2 | # 'pre-commit install' 3 | 4 | # Auto-update the version of the hooks with 5 | # 'pre-commit autoupdate' 6 | 7 | # Run the hooks on all files with 8 | # 'pre-commit run --all' 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v5.0.0 13 | hooks: 14 | - id: check-ast 15 | - id: check-merge-conflict 16 | - id: check-toml 17 | - id: check-yaml 18 | - id: end-of-file-fixer 19 | files: \.(py|md)$ 20 | - id: trailing-whitespace 21 | files: \.(py|md|ipynb)$ 22 | - id: debug-statements 23 | files: \.py$ # only include python files 24 | - id: mixed-line-ending # remove CR 25 | args: [--fix=lf] 26 | 27 | - repo: local 28 | hooks: 29 | - id: check-license 30 | name: Check license headers 31 | entry: python ./.github/check_license_headers.py 32 | language: python 33 | pass_filenames: false # only run this check once globally per repo 34 | always_run: true 35 | 36 | - repo: https://github.com/astral-sh/ruff-pre-commit 37 | rev: v0.9.10 38 | hooks: 39 | - id: ruff 40 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | id-token: write 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Python 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: '3.x' 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install setuptools wheel twine build 22 | - name: Check consistency between the package version and release tag 23 | run: | 24 | pip install . 25 | RELEASE_VER=${GITHUB_REF#refs/*/} 26 | PACKAGE_VER="v`python -c 'import optax; print(optax.__version__)'`" 27 | if [ $RELEASE_VER != $PACKAGE_VER ] 28 | then 29 | echo "package ver. ($PACKAGE_VER) != release ver. 
($RELEASE_VER)"; exit 1 30 | fi 31 | - name: Build 32 | run: | 33 | python -m build 34 | - name: Publish package distributions to PyPI 35 | uses: pypa/gh-action-pypi-publish@release/v1 36 | -------------------------------------------------------------------------------- /docs/api/stochastic_gradient_estimators.rst: -------------------------------------------------------------------------------- 1 | Stochastic Gradient Estimators and Control Variates 2 | =================================================== 3 | 4 | .. warning:: 5 | This module has been deprecated and will be removed in optax 0.2.7 6 | 7 | .. currentmodule:: optax.monte_carlo 8 | 9 | .. autosummary:: 10 | control_delta_method 11 | control_variates_jacobians 12 | measure_valued_jacobians 13 | moving_avg_baseline 14 | pathwise_jacobians 15 | score_function_jacobians 16 | 17 | 18 | Control delta method 19 | ~~~~~~~~~~~~~~~~~~~~ 20 | .. autofunction:: control_delta_method 21 | 22 | Control variates Jacobians 23 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 24 | .. autofunction:: control_variates_jacobians 25 | 26 | Moving average baseline 27 | ~~~~~~~~~~~~~~~~~~~~~~~ 28 | .. autofunction:: moving_avg_baseline 29 | 30 | Measure valued Jacobians 31 | ~~~~~~~~~~~~~~~~~~~~~~~~ 32 | .. autofunction:: measure_valued_jacobians 33 | 34 | Pathwise Jacobians 35 | ~~~~~~~~~~~~~~~~~~ 36 | .. autofunction:: pathwise_jacobians 37 | 38 | Score function Jacobians 39 | ~~~~~~~~~~~~~~~~~~~~~~~~ 40 | .. autofunction:: score_function_jacobians 41 | -------------------------------------------------------------------------------- /optax/optax_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for importing optax.""" 16 | 17 | from absl.testing import absltest 18 | import optax 19 | from optax import transforms 20 | 21 | 22 | class OptaxTest(absltest.TestCase): 23 | """Test optax can be imported correctly.""" 24 | 25 | def test_import(self): 26 | self.assertTrue(hasattr(optax, 'GradientTransformation')) 27 | self.assertTrue(hasattr(transforms, 'partition')) 28 | 29 | 30 | if __name__ == '__main__': 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /optax/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | 10 | ### Added 11 | 12 | - CHANGELOG.md file added to track notable changes. 
13 | 14 | ### Changed 15 | 16 | - Test classes now inherit from `absl.TestCase` or `parameterized.TestCase` 17 | instead of `chex.TestCase` as part of our effort to remove the `chex` 18 | dependency. This means that Chex test variants (with/without `jit`, with/without 19 | `device_put`, with `pmap`) are no longer tested. We decided it was sufficient to 20 | use `jit` throughout the tests. There is already test coverage on both CPU and 21 | accelerators, and `pmap` is deprecated. 22 | - Classification losses (`poly_loss_cross_entropy`, 23 | `ctc_loss_with_forward_probs`, `ctc_loss`, `sigmoid_focal_loss`) and regression 24 | losses (`huber_loss`, `cosine_similarity`, `cosine_distance`) no longer support 25 | positional args for hyperparameter-like inputs. 26 | 27 | ### Removed 28 | 29 | - Stochastic gradient estimators à la Reinforce with control variates methods. 30 | See monte_carlo folder in optax 0.1.8 if you are interested. 31 | - Removed optax._src.transform.cast_tree and optax._src.utils.cast_tree. Use 32 | optax.tree.cast from now on. 33 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Testing 26 | 27 | Please make sure that your PR passes all tests by running `bash test.sh` on your 28 | local machine. Also, you can run only tests that are affected by your code 29 | changes, but you will need to select them manually. 30 | 31 | ## Community Guidelines 32 | 33 | This project follows [Google's Open Source Community 34 | Guidelines](https://opensource.google.com/conduct/). 35 | -------------------------------------------------------------------------------- /optax/projections/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """The projections sub-package.""" 17 | 18 | # pylint: disable=g-importing-member 19 | 20 | from optax.projections._projections import projection_box 21 | from optax.projections._projections import projection_halfspace 22 | from optax.projections._projections import projection_hypercube 23 | from optax.projections._projections import projection_hyperplane 24 | from optax.projections._projections import projection_l1_ball 25 | from optax.projections._projections import projection_l1_sphere 26 | from optax.projections._projections import projection_l2_ball 27 | from optax.projections._projections import projection_l2_sphere 28 | from optax.projections._projections import projection_linf_ball 29 | from optax.projections._projections import projection_non_negative 30 | from optax.projections._projections import projection_simplex 31 | from optax.projections._projections import projection_vector 32 | -------------------------------------------------------------------------------- /optax/losses/_smoothing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Smoothing functions.""" 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | from optax._src import utils 20 | 21 | 22 | def smooth_labels( 23 | labels: jax.typing.ArrayLike, 24 | alpha: jax.typing.ArrayLike, 25 | ) -> jax.Array: 26 | """Apply label smoothing. 27 | 28 | Label smoothing is often used in combination with a cross-entropy loss. 29 | Smoothed labels favor small logit gaps, and it has been shown that this can 30 | provide better model calibration by preventing overconfident predictions. 31 | 32 | Args: 33 | labels: One hot labels to be smoothed. 34 | alpha: The smoothing factor. 35 | 36 | Returns: 37 | a smoothed version of the one hot input labels. 38 | 39 | References: 40 | Muller et al, `When does label smoothing help? 41 | `_, 2019 42 | """ 43 | utils.check_subdtype(labels, jnp.floating) 44 | num_categories = labels.shape[-1] 45 | return (1.0 - alpha) * labels + alpha / num_categories 46 | -------------------------------------------------------------------------------- /optax/_src/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
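# --- Editor's note: illustrative example, not part of the original optax sources ---
# A quick worked example for `smooth_labels` from the preceding file,
# optax/losses/_smoothing.py: with one-hot labels [0., 1., 0.] and alpha=0.1,
# the result is (1 - 0.1) * [0., 1., 0.] + 0.1 / 3, i.e. approximately
# [0.0333, 0.9333, 0.0333], and each row still sums to 1. A minimal sketch,
# mirroring how the test file in this repository calls the function:
#
#   import jax.numpy as jnp
#   from optax.losses import _smoothing
#   _smoothing.smooth_labels(jnp.array([0.0, 1.0, 0.0]), alpha=0.1)
# ------------------------------------------------------------------------------------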
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Transformation wrappers.""" 16 | 17 | from optax.transforms import _accumulation 18 | from optax.transforms import _conditionality 19 | from optax.transforms import _layouts 20 | from optax.transforms import _masking 21 | 22 | 23 | apply_if_finite = _conditionality.apply_if_finite 24 | ApplyIfFiniteState = _conditionality.ApplyIfFiniteState 25 | ConditionFn = _conditionality.ConditionFn 26 | conditionally_mask = _conditionality.conditionally_mask 27 | conditionally_transform = _conditionality.conditionally_transform 28 | ConditionallyMaskState = _conditionality.ConditionallyMaskState 29 | ConditionallyTransformState = _conditionality.ConditionallyTransformState 30 | flatten = _layouts.flatten 31 | masked = _masking.masked 32 | MaskedNode = _masking.MaskedNode 33 | MaskedState = _masking.MaskedState 34 | MultiSteps = _accumulation.MultiSteps 35 | MultiStepsState = _accumulation.MultiStepsState 36 | ShouldSkipUpdateFunction = _accumulation.ShouldSkipUpdateFunction 37 | skip_not_finite = _accumulation.skip_not_finite 38 | skip_large_updates = _accumulation.skip_large_updates 39 | -------------------------------------------------------------------------------- /optax/schedules/_join_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for methods in `join.py`.""" 16 | 17 | from absl.testing import absltest 18 | import numpy as np 19 | from optax.schedules import _join 20 | from optax.schedules import _schedule 21 | 22 | 23 | class JoinTest(absltest.TestCase): 24 | 25 | def test_join_schedules(self): 26 | my_schedule = _join.join_schedules( 27 | schedules=[ 28 | _schedule.constant_schedule(1.0), 29 | _schedule.constant_schedule(2.0), 30 | _schedule.constant_schedule(1.0), 31 | ], 32 | boundaries=[3, 6], 33 | ) 34 | np.testing.assert_allclose(1.0, my_schedule(0), atol=0.0) 35 | np.testing.assert_allclose(1.0, my_schedule(1), atol=0.0) 36 | np.testing.assert_allclose(1.0, my_schedule(2), atol=0.0) 37 | np.testing.assert_allclose(2.0, my_schedule(3), atol=0.0) 38 | np.testing.assert_allclose(2.0, my_schedule(4), atol=0.0) 39 | np.testing.assert_allclose(2.0, my_schedule(5), atol=0.0) 40 | np.testing.assert_allclose(1.0, my_schedule(6), atol=0.0) 41 | 42 | 43 | if __name__ == "__main__": 44 | absltest.main() 45 | -------------------------------------------------------------------------------- /optax/schedules/_join.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities to join schedules.""" 16 | 17 | from typing import Sequence 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | from optax._src import base 22 | 23 | 24 | def join_schedules( 25 | schedules: Sequence[base.Schedule], boundaries: Sequence[int] 26 | ) -> base.Schedule: 27 | """Sequentially apply multiple schedules. 28 | 29 | Args: 30 | schedules: A list of callables (expected to be optax schedules). Each 31 | schedule will receive a step count indicating the number of steps since 32 | the previous boundary transition. 33 | boundaries: A list of integers (of length one less than schedules) that 34 | indicate when to transition between schedules. 35 | 36 | Returns: 37 | schedule: A function that maps step counts to values. 38 | """ 39 | 40 | def schedule(step: jax.typing.ArrayLike) -> jax.typing.ArrayLike: 41 | output = schedules[0](step) 42 | for boundary, schedule in zip(boundaries, schedules[1:]): 43 | output = jnp.where(step < boundary, output, schedule(step - boundary)) 44 | return output 45 | 46 | return schedule 47 | -------------------------------------------------------------------------------- /docs/api/optimizer_schedules.rst: -------------------------------------------------------------------------------- 1 | Optimizer Schedules 2 | ===================== 3 | 4 | .. currentmodule:: optax.schedules 5 | 6 | .. 
autosummary:: 7 | constant_schedule 8 | cosine_decay_schedule 9 | cosine_onecycle_schedule 10 | exponential_decay 11 | join_schedules 12 | linear_onecycle_schedule 13 | linear_schedule 14 | piecewise_constant_schedule 15 | piecewise_interpolate_schedule 16 | polynomial_schedule 17 | sgdr_schedule 18 | warmup_constant_schedule 19 | warmup_cosine_decay_schedule 20 | warmup_exponential_decay_schedule 21 | Schedule 22 | InjectHyperparamsState 23 | inject_hyperparams 24 | 25 | 26 | .. autoclass:: Schedule 27 | 28 | Constant schedule 29 | ~~~~~~~~~~~~~~~~~ 30 | .. autofunction:: constant_schedule 31 | .. autofunction:: warmup_constant_schedule 32 | 33 | Cosine decay schedule 34 | ~~~~~~~~~~~~~~~~~~~~~ 35 | .. autofunction:: cosine_decay_schedule 36 | .. autofunction:: cosine_onecycle_schedule 37 | .. autofunction:: warmup_cosine_decay_schedule 38 | 39 | Exponential decay schedule 40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 41 | .. autofunction:: exponential_decay 42 | .. autofunction:: warmup_exponential_decay_schedule 43 | 44 | Join schedules 45 | ~~~~~~~~~~~~~~ 46 | .. autofunction:: join_schedules 47 | 48 | Inject hyperparameters 49 | ~~~~~~~~~~~~~~~~~~~~~~ 50 | .. autofunction:: inject_hyperparams 51 | .. autoclass:: InjectHyperparamsState 52 | 53 | Linear schedules 54 | ~~~~~~~~~~~~~~~~ 55 | .. autofunction:: linear_onecycle_schedule 56 | .. autofunction:: linear_schedule 57 | 58 | Piecewise schedules 59 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 60 | .. autofunction:: piecewise_constant_schedule 61 | .. autofunction:: piecewise_interpolate_schedule 62 | 63 | Polynomial schedules 64 | ~~~~~~~~~~~~~~~~~~~~ 65 | .. autofunction:: polynomial_schedule 66 | 67 | Reduce on plateau 68 | ~~~~~~~~~~~~~~~~~ 69 | .. autofunction:: optax.contrib.reduce_on_plateau 70 | 71 | Warm restarts 72 | ~~~~~~~~~~~~~ 73 | .. autofunction:: sgdr_schedule 74 | -------------------------------------------------------------------------------- /optax/_src/schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Hyper-parameters Schedules. 16 | 17 | Schedules may be used to anneal the value of a hyper-parameter over time; for 18 | instance, they may be used to anneal the learning rate used to update an agent's 19 | parameters or the exploration factor used to select actions. 20 | """ 21 | 22 | from optax import schedules 23 | 24 | 25 | # TODO(mtthss): remove schedules alises from flat namespaces after user updates. 
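# --- Editor's note: illustrative sketch, not part of the original optax/_src/schedule.py ---
# As the docstring above says, a schedule is just a callable mapping a step
# count to a hyper-parameter value, so it can be passed directly wherever a
# constant learning rate would otherwise go. The argument values below are
# placeholders; the function names are taken from the public optax API:
#
#   import optax
#   lr = optax.warmup_cosine_decay_schedule(
#       init_value=0.0, peak_value=1e-3, warmup_steps=100, decay_steps=1_000)
#   optimizer = optax.adam(learning_rate=lr)  # lr(step) is evaluated each update step
# -------------------------------------------------------------------------------------------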
26 | constant_schedule = schedules.constant_schedule 27 | cosine_decay_schedule = schedules.cosine_decay_schedule 28 | cosine_onecycle_schedule = schedules.cosine_onecycle_schedule 29 | exponential_decay = schedules.exponential_decay 30 | inject_hyperparams = schedules.inject_hyperparams 31 | InjectHyperparamsState = schedules.InjectHyperparamsState 32 | join_schedules = schedules.join_schedules 33 | linear_onecycle_schedule = schedules.linear_onecycle_schedule 34 | linear_schedule = schedules.linear_schedule 35 | piecewise_constant_schedule = schedules.piecewise_constant_schedule 36 | piecewise_interpolate_schedule = schedules.piecewise_interpolate_schedule 37 | polynomial_schedule = schedules.polynomial_schedule 38 | sgdr_schedule = schedules.sgdr_schedule 39 | warmup_constant_schedule = schedules.warmup_constant_schedule 40 | warmup_cosine_decay_schedule = schedules.warmup_cosine_decay_schedule 41 | warmup_exponential_decay_schedule = schedules.warmup_exponential_decay_schedule 42 | -------------------------------------------------------------------------------- /optax/schedules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Utilities for creating schedules.""" 16 | 17 | # pylint: disable=g-importing-member 18 | 19 | from optax._src.base import Schedule 20 | from optax._src.base import StatefulSchedule 21 | from optax.schedules._inject import inject_hyperparams 22 | from optax.schedules._inject import inject_stateful_hyperparams 23 | from optax.schedules._inject import InjectHyperparamsState 24 | from optax.schedules._inject import InjectStatefulHyperparamsState 25 | from optax.schedules._inject import WrappedSchedule 26 | from optax.schedules._join import join_schedules 27 | from optax.schedules._schedule import constant_schedule 28 | from optax.schedules._schedule import cosine_decay_schedule 29 | from optax.schedules._schedule import cosine_onecycle_schedule 30 | from optax.schedules._schedule import exponential_decay 31 | from optax.schedules._schedule import linear_onecycle_schedule 32 | from optax.schedules._schedule import linear_schedule 33 | from optax.schedules._schedule import piecewise_constant_schedule 34 | from optax.schedules._schedule import piecewise_interpolate_schedule 35 | from optax.schedules._schedule import polynomial_schedule 36 | from optax.schedules._schedule import sgdr_schedule 37 | from optax.schedules._schedule import warmup_constant_schedule 38 | from optax.schedules._schedule import warmup_cosine_decay_schedule 39 | from optax.schedules._schedule import warmup_exponential_decay_schedule 40 | -------------------------------------------------------------------------------- /optax/transforms/_layouts_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for methods in `optax.transforms._layouts.py`.""" 16 | 17 | from absl.testing import absltest 18 | import jax.numpy as jnp 19 | from optax._src import alias 20 | from optax._src import test_utils 21 | from optax._src import update 22 | from optax.transforms import _layouts 23 | 24 | 25 | class LayoutsTest(absltest.TestCase): 26 | 27 | def test_flatten(self): 28 | def init_params(): 29 | return (jnp.array(2.0), jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) 30 | 31 | per_step_updates = ( 32 | jnp.array(1.0), 33 | jnp.array([500.0, 5.0]), 34 | jnp.array([300.0, 3.0]), 35 | ) 36 | 37 | # First calculate new params without flattening 38 | optax_sgd_params = init_params() 39 | sgd = alias.sgd(1e-2, 0.0) 40 | state_sgd = sgd.init(optax_sgd_params) 41 | updates_sgd, _ = sgd.update(per_step_updates, state_sgd) 42 | sgd_params_no_flatten = update.apply_updates(optax_sgd_params, updates_sgd) 43 | 44 | # And now calculate new params with flattening 45 | optax_sgd_params = init_params() 46 | sgd = _layouts.flatten(sgd) 47 | 48 | state_sgd = sgd.init(optax_sgd_params) 49 | updates_sgd, _ = sgd.update(per_step_updates, state_sgd) 50 | sgd_params_flatten = update.apply_updates(optax_sgd_params, updates_sgd) 51 | 52 | # Test that both give the same result 53 | test_utils.assert_trees_all_close( 54 | sgd_params_no_flatten, sgd_params_flatten, atol=1e-7, rtol=1e-7 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | absltest.main() 60 | -------------------------------------------------------------------------------- /optax/losses/_fenchel_young.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Fenchel-Young losses.""" 16 | 17 | from typing import Any, Protocol 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | 23 | class MaxFun(Protocol): 24 | 25 | def __call__(self, scores, *args, **kwargs: Any) -> jax.typing.ArrayLike: 26 | ... 27 | 28 | 29 | def make_fenchel_young_loss(max_fun: MaxFun): 30 | """Creates a Fenchel-Young loss from a max function. 31 | 32 | Args: 33 | max_fun: the max function on which the Fenchel-Young loss is built. 34 | 35 | Returns: 36 | A Fenchel-Young loss function with the same signature. 37 | 38 | Examples: 39 | Given a max function, e.g., the log sum exp, you can construct a 40 | Fenchel-Young loss easily as follows: 41 | 42 | >>> from jax.scipy.special import logsumexp 43 | >>> fy_loss = optax.losses.make_fenchel_young_loss(max_fun=logsumexp) 44 | 45 | Reference: 46 | Blondel et al. `Learning with Fenchel-Young Losses 47 | `_, 2020 48 | 49 | .. warning:: 50 | The resulting loss accepts an arbitrary number of leading dimensions 51 | with the fy_loss operating over the last dimension. 
The jaxopt version of 52 | this function would instead flatten any vector in a single big 1D vector. 53 | """ 54 | 55 | vdot_last_dim = jnp.vectorize(jnp.vdot, signature="(n),(n)->()") 56 | max_fun_last_dim = jnp.vectorize(max_fun, signature="(n)->()") 57 | 58 | def fenchel_young_loss(scores, targets, *args, **kwargs): 59 | max_value = max_fun_last_dim(scores, *args, **kwargs) 60 | return max_value - vdot_last_dim(targets, scores) 61 | 62 | return fenchel_young_loss 63 | -------------------------------------------------------------------------------- /optax/tree/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The tree_utils sub-package.""" 16 | 17 | # pylint: disable=g-importing-member 18 | 19 | from optax.tree_utils import _casting 20 | from optax.tree_utils import _random 21 | from optax.tree_utils import _state_utils 22 | from optax.tree_utils import _tree_math 23 | 24 | cast = _casting.tree_cast 25 | cast_like = _casting.tree_cast_like 26 | dtype = _casting.tree_dtype 27 | random_like = _random.tree_random_like 28 | split_key_like = _random.tree_split_key_like 29 | unwrap_random_key_data = _random.tree_unwrap_random_key_data 30 | get = _state_utils.tree_get 31 | get_all_with_path = _state_utils.tree_get_all_with_path 32 | map_params = _state_utils.tree_map_params 33 | set = _state_utils.tree_set # pylint: disable=redefined-builtin 34 | add = _tree_math.tree_add 35 | add_scale = _tree_math.tree_add_scale 36 | allclose = _tree_math.tree_allclose 37 | batch_shape = _tree_math.tree_batch_shape 38 | bias_correction = _tree_math.tree_bias_correction 39 | clip = _tree_math.tree_clip 40 | conj = _tree_math.tree_conj 41 | div = _tree_math.tree_div 42 | full_like = _tree_math.tree_full_like 43 | max = _tree_math.tree_max # pylint: disable=redefined-builtin 44 | min = _tree_math.tree_min # pylint: disable=redefined-builtin 45 | mul = _tree_math.tree_mul 46 | norm = _tree_math.tree_norm 47 | ones_like = _tree_math.tree_ones_like 48 | real = _tree_math.tree_real 49 | scale = _tree_math.tree_scale 50 | size = _tree_math.tree_size 51 | sub = _tree_math.tree_sub 52 | sum = _tree_math.tree_sum # pylint: disable=redefined-builtin 53 | update_infinity_moment = _tree_math.tree_update_infinity_moment 54 | update_moment = _tree_math.tree_update_moment 55 | update_moment_per_elem_norm = _tree_math.tree_update_moment_per_elem_norm 56 | vdot = _tree_math.tree_vdot 57 | where = _tree_math.tree_where 58 | zeros_like = _tree_math.tree_zeros_like 59 | -------------------------------------------------------------------------------- /docs/api/optimizers.rst: -------------------------------------------------------------------------------- 1 | Optimizers 2 | ========== 3 | 4 | .. currentmodule:: optax 5 | 6 | .. 
autosummary:: 7 | adabelief 8 | adadelta 9 | adan 10 | adafactor 11 | adagrad 12 | adam 13 | adamw 14 | adamax 15 | adamaxw 16 | amsgrad 17 | fromage 18 | lamb 19 | lars 20 | lbfgs 21 | lion 22 | nadam 23 | nadamw 24 | noisy_sgd 25 | novograd 26 | optimistic_gradient_descent 27 | optimistic_adam_v2 28 | polyak_sgd 29 | radam 30 | rmsprop 31 | sgd 32 | sign_sgd 33 | signum 34 | sm3 35 | yogi 36 | 37 | 38 | AdaBelief 39 | ~~~~~~~~~ 40 | .. autofunction:: adabelief 41 | 42 | AdaDelta 43 | ~~~~~~~~~ 44 | .. autofunction:: adadelta 45 | 46 | Adan 47 | ~~~~ 48 | .. autofunction:: adan 49 | 50 | AdaGrad 51 | ~~~~~~~ 52 | .. autofunction:: adagrad 53 | 54 | AdaFactor 55 | ~~~~~~~~~ 56 | .. autofunction:: adafactor 57 | 58 | Adam 59 | ~~~~ 60 | .. autofunction:: adam 61 | 62 | Adamax 63 | ~~~~~~ 64 | .. autofunction:: adamax 65 | 66 | AdamaxW 67 | ~~~~~~~ 68 | .. autofunction:: adamaxw 69 | 70 | AdamW 71 | ~~~~~ 72 | .. autofunction:: adamw 73 | 74 | AMSGrad 75 | ~~~~~~~ 76 | .. autofunction:: amsgrad 77 | 78 | Fromage 79 | ~~~~~~~ 80 | .. autofunction:: fromage 81 | 82 | Lamb 83 | ~~~~ 84 | .. autofunction:: lamb 85 | 86 | Lars 87 | ~~~~ 88 | .. autofunction:: lars 89 | 90 | LBFGS 91 | ~~~~~ 92 | .. autofunction:: lbfgs 93 | 94 | Lion 95 | ~~~~ 96 | .. autofunction:: lion 97 | 98 | Nadam 99 | ~~~~~ 100 | .. autofunction:: nadam 101 | 102 | NadamW 103 | ~~~~~~ 104 | .. autofunction:: nadamw 105 | 106 | Noisy SGD 107 | ~~~~~~~~~ 108 | .. autofunction:: noisy_sgd 109 | 110 | Novograd 111 | ~~~~~~~~ 112 | .. autofunction:: novograd 113 | 114 | Optimistic GD 115 | ~~~~~~~~~~~~~ 116 | .. autofunction:: optimistic_gradient_descent 117 | 118 | Optimistic Adam 119 | ~~~~~~~~~~~~~~~ 120 | .. autofunction:: optimistic_adam_v2 121 | 122 | Polyak step-size SGD 123 | ~~~~~~~~~~~~~~~~~~~~ 124 | .. autofunction:: polyak_sgd 125 | 126 | RAdam 127 | ~~~~~ 128 | .. autofunction:: radam 129 | 130 | RMSProp 131 | ~~~~~~~ 132 | .. autofunction:: rmsprop 133 | 134 | RProp 135 | ~~~~~ 136 | .. autofunction:: rprop 137 | 138 | SGD 139 | ~~~ 140 | .. autofunction:: sgd 141 | 142 | SignSGD 143 | ~~~~~~~ 144 | .. autofunction:: sign_sgd 145 | 146 | Signum 147 | ~~~~~~ 148 | .. autofunction:: signum 149 | 150 | SM3 151 | ~~~ 152 | .. autofunction:: sm3 153 | 154 | Yogi 155 | ~~~~ 156 | .. autofunction:: yogi 157 | -------------------------------------------------------------------------------- /optax/losses/_smoothing_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for smoothing functions in `optax.losses._smoothing.py`.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | from optax.losses import _smoothing 22 | 23 | 24 | class SmoothLabelsTest(absltest.TestCase): 25 | 26 | def setUp(self): 27 | super().setUp() 28 | self.ts = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32) 29 | # compute expected outputs in numpy. 30 | self.exp_alpha_zero = self.ts 31 | self.exp_alpha_zero_point_one = 0.9 * self.ts + 0.1 / self.ts.shape[-1] 32 | self.exp_alpha_one = jnp.ones_like(self.ts) / self.ts.shape[-1] 33 | 34 | def test_scalar(self): 35 | """Tests for a full batch.""" 36 | np.testing.assert_allclose( 37 | jax.jit(_smoothing.smooth_labels)(self.ts[0], 0.0), 38 | self.exp_alpha_zero[0], 39 | atol=1e-4, 40 | ) 41 | np.testing.assert_allclose( 42 | jax.jit(_smoothing.smooth_labels)(self.ts[0], 0.1), 43 | self.exp_alpha_zero_point_one[0], 44 | atol=1e-4, 45 | ) 46 | np.testing.assert_allclose( 47 | jax.jit(_smoothing.smooth_labels)(self.ts[0], 1.0), 48 | self.exp_alpha_one[0], 49 | atol=1e-4, 50 | ) 51 | 52 | def test_batched(self): 53 | """Tests for a full batch.""" 54 | np.testing.assert_allclose( 55 | jax.jit(_smoothing.smooth_labels)(self.ts, 0.0), 56 | self.exp_alpha_zero, 57 | atol=1e-4, 58 | ) 59 | np.testing.assert_allclose( 60 | jax.jit(_smoothing.smooth_labels)(self.ts, 0.1), 61 | self.exp_alpha_zero_point_one, 62 | atol=1e-4, 63 | ) 64 | np.testing.assert_allclose( 65 | jax.jit(_smoothing.smooth_labels)(self.ts, 1.0), 66 | self.exp_alpha_one, 67 | atol=1e-4, 68 | ) 69 | 70 | 71 | if __name__ == '__main__': 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/google-deepmind/optax/tree/main/docs 2 | 3 | ===== 4 | Optax 5 | ===== 6 | 7 | Optax is a gradient processing and optimization library for JAX. It is designed 8 | to facilitate research by providing building blocks that can be recombined in 9 | custom ways in order to optimize parametric models such as, but not limited to, 10 | deep neural networks. 11 | 12 | Our goals are to 13 | 14 | * Provide readable, well-tested, efficient implementations of core components, 15 | * Improve researcher productivity by making it possible to combine low level 16 | ingredients into custom optimizer (or other gradient processing components). 17 | * Accelerate adoption of new ideas by making it easy for anyone to contribute. 18 | 19 | We favor focusing on small composable building blocks that can be effectively 20 | combined into custom solutions. Others may build upon these basic components 21 | more complicated abstractions. Whenever reasonable, implementations prioritize 22 | readability and structuring code to match standard equations, over code reuse. 23 | 24 | Installation 25 | ------------ 26 | 27 | The latest release of Optax can be installed from 28 | `PyPI `_ using:: 29 | 30 | pip install optax 31 | 32 | You may also install directly from GitHub, using the following command. This 33 | can be used to obtain the most recent version of Optax:: 34 | 35 | pip install git+git://github.com/google-deepmind/optax.git 36 | 37 | Note that Optax is built on top of JAX. 38 | See `here `_ 39 | for instructions on installing JAX. 40 | 41 | 42 | .. 
toctree:: 43 | :hidden: 44 | 45 | getting_started 46 | 47 | gallery 48 | 49 | development 50 | 51 | 52 | .. toctree:: 53 | :hidden: 54 | :caption: 📖 Reference 55 | :maxdepth: 2 56 | 57 | api/assignment 58 | api/optimizers 59 | api/transformations 60 | api/combining_optimizers 61 | api/optimizer_wrappers 62 | api/optimizer_schedules 63 | api/apply_updates 64 | api/perturbations 65 | api/projections 66 | api/losses 67 | api/stochastic_gradient_estimators 68 | api/utilities 69 | api/contrib 70 | api/experimental 71 | 72 | 73 | Support 74 | ------- 75 | 76 | If you encounter issues with this software, please let us know by filing an issue on our `issue tracker `_. We are also happy to receive bug fixes and other contributions. For more information of how to contribute, please see the :doc:`development guide `. 77 | 78 | 79 | License 80 | ------- 81 | 82 | Optax is licensed under the `Apache 2.0 License `_. 83 | 84 | 85 | Indices and Tables 86 | ================== 87 | 88 | * :ref:`genindex` 89 | -------------------------------------------------------------------------------- /optax/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """The losses sub-package.""" 16 | 17 | # pylint:disable=g-importing-member 18 | 19 | from optax.losses._classification import convex_kl_divergence 20 | from optax.losses._classification import ctc_loss 21 | from optax.losses._classification import ctc_loss_with_forward_probs 22 | from optax.losses._classification import generalized_kl_divergence 23 | from optax.losses._classification import hinge_loss 24 | from optax.losses._classification import kl_divergence 25 | from optax.losses._classification import kl_divergence_with_log_targets 26 | from optax.losses._classification import multiclass_hinge_loss 27 | from optax.losses._classification import multiclass_perceptron_loss 28 | from optax.losses._classification import multiclass_sparsemax_loss 29 | from optax.losses._classification import perceptron_loss 30 | from optax.losses._classification import poly_loss_cross_entropy 31 | from optax.losses._classification import safe_softmax_cross_entropy 32 | from optax.losses._classification import sigmoid_binary_cross_entropy 33 | from optax.losses._classification import sigmoid_focal_loss 34 | from optax.losses._classification import softmax_cross_entropy 35 | # pylint: disable=line-too-long 36 | from optax.losses._classification import softmax_cross_entropy_with_integer_labels # noqa: E501 37 | # pylint: enable=line-too-long 38 | from optax.losses._classification import sparsemax_loss 39 | from optax.losses._fenchel_young import make_fenchel_young_loss 40 | from optax.losses._ranking import ranking_softmax_loss 41 | from optax.losses._regression import cosine_distance 42 | from optax.losses._regression import cosine_similarity 43 | from optax.losses._regression import huber_loss 44 | from optax.losses._regression import l2_loss 45 | from optax.losses._regression import log_cosh 46 | from optax.losses._regression import squared_error 47 | from optax.losses._segmentation import binary_dice_loss 48 | from optax.losses._segmentation import dice_loss 49 | from optax.losses._segmentation import multiclass_generalized_dice_loss 50 | from optax.losses._self_supervised import ntxent 51 | from optax.losses._self_supervised import triplet_margin_loss 52 | from optax.losses._smoothing import smooth_labels 53 | -------------------------------------------------------------------------------- /optax/transforms/_layouts.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Wrappers changing the layouts of the tensors that transforms operate on.""" 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | from optax._src import base 21 | 22 | 23 | def flatten( 24 | inner: base.GradientTransformation, 25 | ) -> base.GradientTransformationExtraArgs: 26 | """Flattens parameters and gradients for init and update of inner transform. 27 | 28 | This can reduce the overhead of performing many calculations on lots of small 29 | variables, at the cost of slightly increased memory usage. 30 | 31 | Args: 32 | inner: Inner transformation to flatten inputs for. 33 | 34 | Returns: 35 | New :class:`optax.GradientTransformationExtraArgs` 36 | """ 37 | 38 | inner = base.with_extra_args_support(inner) 39 | 40 | def _flatten(params): 41 | """Flattens and concatenates all tensors in params to a single vector.""" 42 | params, _ = jax.tree.flatten(params) 43 | return jnp.concatenate([jnp.reshape(param, [-1]) for param in params]) 44 | 45 | def _unflatten(updates, flat): 46 | """Extracts tensors from flat, using the structure and shapes of params.""" 47 | updates_flat, treedef = jax.tree.flatten(updates) 48 | offsets = [] 49 | for update in updates_flat: 50 | size = np.size(update) 51 | if offsets: 52 | offsets.append(size + offsets[-1]) 53 | else: 54 | offsets.append(size) 55 | del offsets[-1] 56 | flat_split = jnp.split(flat, offsets) 57 | reshaped = [ 58 | jnp.reshape(flat_update, update.shape) 59 | for flat_update, update in zip(flat_split, updates_flat) 60 | ] 61 | return jax.tree.unflatten(treedef, reshaped) 62 | 63 | def init_fn(params): 64 | flat = _flatten(params) 65 | return inner.init(flat) 66 | 67 | def update_fn(updates, state, params=None, **extra_args): 68 | if params is not None: 69 | params = _flatten(params) 70 | updates_flat, state = inner.update( 71 | _flatten(updates), state, params, **extra_args 72 | ) 73 | updates = _unflatten(updates, updates_flat) 74 | return updates, state 75 | 76 | return base.GradientTransformationExtraArgs(init_fn, update_fn) 77 | -------------------------------------------------------------------------------- /optax/losses/_fenchel_young_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for fenchel young loss in `_fenchel_young.py`.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | from jax.scipy.special import logsumexp # pylint: disable=g-importing-member 21 | 22 | from optax._src import test_utils 23 | from optax.losses import _classification 24 | from optax.losses import _fenchel_young 25 | 26 | 27 | def one_hot_argmax(inputs: jnp.ndarray) -> jnp.ndarray: 28 | """An argmax one-hot function for arbitrary shapes.""" 29 | inputs_flat = jnp.reshape(inputs, (-1)) 30 | flat_one_hot = jax.nn.one_hot(jnp.argmax(inputs_flat), inputs_flat.shape[0]) 31 | return jnp.reshape(flat_one_hot, inputs.shape) 32 | 33 | 34 | class FenchelYoungTest(absltest.TestCase): 35 | 36 | def test_fenchel_young_reg(self): 37 | # Checks the behavior of the Fenchel-Young loss. 38 | fy_loss = jax.jit(_fenchel_young.make_fenchel_young_loss(logsumexp)) 39 | rng = jax.random.key(0) 40 | rngs = jax.random.split(rng, 2) 41 | theta_true = jax.random.uniform(rngs[0], (8, 5)) 42 | y_true = jax.vmap(jax.nn.softmax)(theta_true) 43 | theta_random = jax.random.uniform(rngs[1], (8, 5)) 44 | y_random = jax.vmap(jax.nn.softmax)(theta_random) 45 | grad_random = jax.vmap(jax.grad(fy_loss))(theta_random, y_true) 46 | # Checks that the gradient of the loss takes the correct form. 47 | test_utils.assert_trees_all_close(grad_random, y_random - y_true, rtol=1e-4) 48 | y_one_hot = jax.vmap(one_hot_argmax)(theta_true) 49 | int_one_hot = jnp.where(y_one_hot == 1.)[1] 50 | loss_one_hot = jax.vmap(fy_loss)(theta_random, y_one_hot) 51 | log_loss = jax.vmap( 52 | _classification.softmax_cross_entropy_with_integer_labels)( 53 | theta_random, int_one_hot) 54 | # Checks that the FY loss associated to logsumexp is correct. 55 | test_utils.assert_trees_all_close(loss_one_hot, log_loss, rtol=1e-4) 56 | # Checks that vmapping or not is equivalent. 57 | loss_one_hot_no_vmap = fy_loss(theta_random, y_one_hot) 58 | test_utils.assert_trees_all_close( 59 | loss_one_hot, loss_one_hot_no_vmap, rtol=1e-4) 60 | 61 | 62 | if __name__ == "__main__": 63 | absltest.main() 64 | -------------------------------------------------------------------------------- /docs/api/projections.rst: -------------------------------------------------------------------------------- 1 | Projections 2 | =========== 3 | 4 | .. currentmodule:: optax.projections 5 | 6 | Projections can be used to perform constrained optimization. 7 | The Euclidean projection onto a set :math:`\mathcal{C}` is: 8 | 9 | .. math:: 10 | 11 | \text{proj}_{\mathcal{C}}(u) := 12 | \underset{v}{\text{argmin}} ~ \|u - v\|^2_2 \textrm{ subject to } v \in \mathcal{C}. 
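
For some sets the projection has a simple closed form. For the non-negative
orthant :math:`\mathcal{C} = \mathbb{R}^d_{\geq 0}`, it is just an elementwise
clipping, :math:`\text{proj}_{\mathcal{C}}(u) = \max(u, 0)`.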
13 | 14 | For instance, here is an example how we can project parameters to the non-negative orthant:: 15 | 16 | >>> import optax 17 | >>> import jax 18 | >>> import jax.numpy as jnp 19 | >>> num_weights = 2 20 | >>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]]) 21 | >>> ys = jnp.array([0.5, 0.8]) 22 | >>> optimizer = optax.adam(learning_rate=1e-3) 23 | >>> params = {'w': jnp.zeros(num_weights)} 24 | >>> opt_state = optimizer.init(params) 25 | >>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2) 26 | >>> grads = jax.grad(loss)(params, xs, ys) 27 | >>> updates, opt_state = optimizer.update(grads, opt_state) 28 | >>> params = optax.apply_updates(params, updates) 29 | >>> params = optax.projections.projection_non_negative(params) 30 | 31 | Available projections 32 | ~~~~~~~~~~~~~~~~~~~~~ 33 | .. autosummary:: 34 | projection_box 35 | projection_hypercube 36 | projection_l1_ball 37 | projection_l1_sphere 38 | projection_l2_ball 39 | projection_l2_sphere 40 | projection_linf_ball 41 | projection_non_negative 42 | projection_simplex 43 | projection_vector 44 | projection_hyperplane 45 | projection_halfspace 46 | 47 | Projection onto a box 48 | ~~~~~~~~~~~~~~~~~~~~~ 49 | .. autofunction:: projection_box 50 | 51 | Projection onto a hypercube 52 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 53 | .. autofunction:: projection_hypercube 54 | 55 | Projection onto the L1 ball 56 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 57 | .. autofunction:: projection_l1_ball 58 | 59 | Projection onto the L1 sphere 60 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 61 | .. autofunction:: projection_l1_sphere 62 | 63 | Projection onto the L2 ball 64 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 65 | .. autofunction:: projection_l2_ball 66 | 67 | Projection onto the L2 sphere 68 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 69 | .. autofunction:: projection_l2_sphere 70 | 71 | Projection onto the L-infinity ball 72 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 73 | .. autofunction:: projection_linf_ball 74 | 75 | Projection onto the non-negative orthant 76 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 77 | .. autofunction:: projection_non_negative 78 | 79 | Projection onto a simplex 80 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 81 | .. autofunction:: projection_simplex 82 | 83 | Projection onto a vector 84 | ~~~~~~~~~~~~~~~~~~~~~~~~ 85 | .. autofunction:: projection_vector 86 | 87 | Projection onto a hyperplane 88 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 89 | .. autofunction:: projection_hyperplane 90 | 91 | Projection onto a halfspace 92 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 93 | .. autofunction:: projection_halfspace 94 | -------------------------------------------------------------------------------- /optax/contrib/_complex_valued_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for methods in `complex_valued.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | from optax._src import transform 23 | from optax._src import update 24 | from optax.contrib import _complex_valued 25 | 26 | 27 | def _loss_fun_complex_to_real(z): 28 | return (z.conj() * z).real.sum() 29 | 30 | 31 | def _loss_fun_real_to_real(params): 32 | x, y = params 33 | return _loss_fun_complex_to_real(x + y * 1j) 34 | 35 | 36 | class ComplexValuedTest(parameterized.TestCase): 37 | 38 | @parameterized.named_parameters([ 39 | ('adam', transform.scale_by_adam), 40 | ('param_block_norm', transform.scale_by_param_block_norm), 41 | ]) 42 | def test_split_real_and_imaginary(self, scaler_constr): 43 | 44 | def do_update(loss_fun, optimizer, params, opt_state): 45 | loss, grads = jax.value_and_grad(loss_fun)(params) 46 | # Complex gradients need to be conjugated before being added to parameters 47 | grads = jax.tree.map(lambda x: x.conj(), grads) 48 | updates, opt_state = jax.jit(optimizer.update)( 49 | grads, opt_state, params 50 | ) 51 | params = update.apply_updates(params, updates) 52 | return loss, grads, params, opt_state 53 | 54 | x = jnp.array([[0.1, 0.2, 0.3], [-0.1, -0.2, -0.3]]) 55 | y = jnp.array([[0.5, -0.5, 0], [0.1, 0.3, -0.2]]) 56 | z = x + y * 1j 57 | 58 | optimizer = scaler_constr() 59 | optimizer_complex = _complex_valued.split_real_and_imaginary(optimizer) 60 | opt_state = jax.jit(optimizer.init)((x, y)) 61 | opt_state_complex = jax.jit(optimizer_complex.init)(z) 62 | 63 | # Check that the loss, the gradients, and the parameters are the same for 64 | # real-to-real and complex-to-real loss functions in each step 65 | for _ in range(3): 66 | loss, (gx, gy), (x, y), opt_state = do_update( 67 | _loss_fun_real_to_real, optimizer, (x, y), opt_state 68 | ) 69 | loss_complex, gz, z, opt_state_complex = do_update( 70 | _loss_fun_complex_to_real, optimizer_complex, z, opt_state_complex 71 | ) 72 | rtol = 1e-6 73 | np.testing.assert_allclose(loss, loss_complex, rtol=rtol) 74 | np.testing.assert_allclose(gx + gy * 1j, gz, rtol=rtol) 75 | np.testing.assert_allclose(x + y * 1j, z, rtol=rtol) 76 | 77 | 78 | if __name__ == '__main__': 79 | absltest.main() 80 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | function cleanup { 17 | deactivate 18 | rm -r "${TEMP_DIR}" 19 | } 20 | trap cleanup EXIT 21 | 22 | REPO_DIR=$(pwd) 23 | TEMP_DIR=$(mktemp --directory) 24 | 25 | set -o errexit 26 | set -o nounset 27 | set -o pipefail 28 | 29 | # Install deps in a virtual env. 30 | python3 -m venv "${TEMP_DIR}/test_venv" 31 | source "${TEMP_DIR}/test_venv/bin/activate" 32 | 33 | # Run the linter first to check lint errors quickly 34 | python3 -m pip install --quiet --upgrade pip uv 35 | python3 -m uv pip install --quiet pre-commit 36 | pre-commit run -a 37 | 38 | # Install dependencies. 39 | python3 -m uv pip install --quiet --upgrade pip setuptools wheel 40 | python3 -m uv pip install --quiet --upgrade flake8 pytest-xdist pylint pylint-exit 41 | python3 -m uv pip install --quiet --editable ".[test]" 42 | 43 | # Install the requested JAX version 44 | if [ -z "${JAX_VERSION-}" ]; then 45 | : # use version installed in requirements above 46 | elif [ "$JAX_VERSION" = "newest" ]; then 47 | python3 -m uv pip install --quiet --upgrade jax jaxlib 48 | elif [ "$JAX_VERSION" = "nightly" ]; then 49 | python3 -m uv pip install --quiet --upgrade --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ 50 | else 51 | python3 -m uv pip install --quiet "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" 52 | fi 53 | 54 | # Ensure optax was not installed by one of the dependencies above, 55 | # since if it is, the tests below will be run against that version instead of 56 | # the branch build. 57 | python3 -m uv pip uninstall optax 58 | 59 | # Lint with flake8. 60 | python3 -m flake8 --select=E9,F63,F7,F82,E225,E251 --show-source --statistics 61 | 62 | # Lint with pylint. 63 | pylint optax 64 | 65 | # Build the package. 66 | python3 -m uv pip install --quiet build 67 | python3 -m build 68 | python3 -m pip wheel --no-deps dist/optax-*.tar.gz --wheel-dir "${TEMP_DIR}" 69 | python3 -m pip install --quiet "${TEMP_DIR}/optax-"*.whl 70 | 71 | # Check types with pytype. 72 | python3 -m pip install --quiet pytype 73 | pytype "optax" -j auto --keep-going --disable import-error 74 | 75 | # Run tests using pytest. 76 | # Change directory to avoid importing the package from repo root. 77 | cd "${TEMP_DIR}" 78 | python3 -m pytest --numprocesses auto --pyargs optax 79 | #python3 -m pytest --numprocesses 8 --pyargs optax 80 | cd "${REPO_DIR}" 81 | 82 | # Build Sphinx docs. 83 | python3 -m uv pip install --quiet --editable ".[docs]" 84 | cd docs 85 | make html 86 | make doctest # run doctests 87 | cd .. 88 | 89 | echo "All tests passed. Congrats!" 90 | -------------------------------------------------------------------------------- /docs/ext/coverage_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Asserts all public symbols are covered in the docs.""" 16 | 17 | from collections.abc import Mapping 18 | import inspect 19 | import types 20 | from typing import Any, Sequence, Tuple 21 | 22 | import optax 23 | from sphinx import application 24 | from sphinx import builders 25 | from sphinx import errors 26 | 27 | 28 | def find_internal_python_modules( 29 | root_module: types.ModuleType, 30 | ) -> Sequence[Tuple[str, types.ModuleType]]: 31 | """Returns `(name, module)` for all Optax submodules under `root_module`.""" 32 | modules = set([(root_module.__name__, root_module)]) 33 | visited = set() 34 | to_visit = [root_module] 35 | 36 | while to_visit: 37 | mod = to_visit.pop() 38 | visited.add(mod) 39 | 40 | for name in dir(mod): 41 | obj = getattr(mod, name) 42 | if inspect.ismodule(obj) and obj not in visited: 43 | if obj.__name__.startswith("optax"): 44 | if "_src" not in obj.__name__: 45 | to_visit.append(obj) 46 | modules.add((obj.__name__, obj)) 47 | 48 | return sorted(modules) 49 | 50 | 51 | def optax_public_symbols(): 52 | """Collect all optax public symbols.""" 53 | names = set() 54 | for module_name, module in find_internal_python_modules(optax): 55 | for name in module.__all__: 56 | names.add(module_name + "." + name) 57 | return names 58 | 59 | 60 | class OptaxCoverageCheck(builders.Builder): 61 | """Builder that checks all public symbols are included.""" 62 | 63 | name = "coverage_check" 64 | 65 | def get_outdated_docs(self) -> str: 66 | return "coverage_check" 67 | 68 | def write(self, *ignored: Any) -> None: # pylint: disable=overridden-final-method 69 | pass 70 | 71 | def finish(self) -> None: 72 | documented_objects = frozenset(self.env.domaindata["py"]["objects"]) # pytype: disable=attribute-error 73 | undocumented_objects = set(optax_public_symbols()) - documented_objects 74 | if undocumented_objects: 75 | undocumented_objects = tuple(sorted(undocumented_objects)) 76 | raise errors.SphinxError( 77 | "All public symbols must be included in our documentation, did you " 78 | "forget to add an entry to `api.rst`?\n" 79 | f"Undocumented symbols: {undocumented_objects}") 80 | 81 | def get_target_uri(self, docname, typ=None): 82 | raise NotImplementedError 83 | 84 | def write_doc(self, docname, doctree): 85 | raise NotImplementedError 86 | 87 | 88 | def setup(app: application.Sphinx) -> Mapping[str, Any]: 89 | app.add_builder(OptaxCoverageCheck) 90 | return {"version": optax.__version__, "parallel_read_safe": True} 91 | -------------------------------------------------------------------------------- /optax/_src/deprecations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Deprecation helpers.""" 16 | 17 | import functools 18 | from typing import Any, Callable, Optional 19 | import warnings 20 | 21 | 22 | # Module __getattr__ factory that warns if deprecated names are used. 23 | # 24 | # Example usage: 25 | # from optax.contrib.dpsgd import dpsgd as _deprecated_dpsgd 26 | # 27 | # _deprecations = { 28 | # # Added Apr 2024: 29 | # "dpsgd": ( 30 | # "optax.dpsgd is deprecated. Use optax.contrib.dpsgd instead.", 31 | # _deprecated_dpsgd, 32 | # ), 33 | # } 34 | # 35 | # from optax._src.deprecations import deprecation_getattr as _deprecation_getattr # pylint: disable=line-too-long # noqa: E501 36 | # __getattr__ = _deprecation_getattr(__name__, _deprecations) 37 | # del _deprecation_getattr 38 | 39 | 40 | # Note that type checkers such as Pytype will not know about the deprecated 41 | # names. If it is desirable that a deprecated name is known to the type checker, 42 | # add: 43 | # import typing 44 | # if typing.TYPE_CHECKING: 45 | # from optax.contrib import dpsgd 46 | # del typing 47 | 48 | 49 | def deprecation_getattr(module, deprecations): 50 | def _getattr(name): 51 | if name in deprecations: 52 | message, fn = deprecations[name] 53 | if fn is None: # Is the deprecation accelerated? 54 | raise AttributeError(message) 55 | warnings.warn(message, DeprecationWarning, stacklevel=2) 56 | return fn 57 | raise AttributeError(f'module {module!r} has no attribute {name!r}') 58 | 59 | return _getattr 60 | 61 | 62 | def warn_deprecated_function( 63 | fun: Callable[..., Any], 64 | replacement: Optional[str] = None, 65 | version_removed: Optional[str] = None, 66 | ) -> Callable[..., Any]: 67 | """A decorator to mark a function definition as deprecated. 68 | 69 | Args: 70 | fun: the deprecated function. 71 | replacement: name of the function to be used instead. 72 | version_removed: version of optax in which the function was/will be removed. 73 | 74 | Returns: 75 | The wrapped function. 76 | 77 | Example usage: 78 | >>> @functools.partial(warn_deprecated_function, replacement='g') 79 | ... def f(a, b): 80 | ... return a + b 81 | """ 82 | if hasattr(fun, '__name__'): 83 | warning_message = f'The function {fun.__name__} is deprecated.' 84 | else: 85 | warning_message = 'The function is deprecated.' 86 | if replacement: 87 | warning_message += f' Please use {replacement} instead.' 88 | if version_removed: 89 | warning_message += ( 90 | f' This function will be/was removed in optax {version_removed}.' 91 | ) 92 | 93 | @functools.wraps(fun) 94 | def new_fun(*args, **kwargs): 95 | warnings.warn(warning_message, category=DeprecationWarning, stacklevel=2) 96 | return fun(*args, **kwargs) 97 | 98 | return new_fun 99 | -------------------------------------------------------------------------------- /docs/api/contrib.rst: -------------------------------------------------------------------------------- 1 | 🔧 Contrib 2 | =============== 3 | 4 | Algorithms or wrappers that don't meet (yet) the :ref:`inclusion_criteria` or 5 | are not supported by the main library. 6 | 7 | .. currentmodule:: optax.contrib 8 | 9 | .. 
autosummary:: 10 | acprop 11 | ademamix 12 | adopt 13 | simplified_ademamix 14 | cocob 15 | COCOBState 16 | dadapt_adamw 17 | DAdaptAdamWState 18 | differentially_private_aggregate 19 | DifferentiallyPrivateAggregateState 20 | dog 21 | DoGState 22 | dowg 23 | DoWGState 24 | dpsgd 25 | mechanize 26 | MechanicState 27 | momo 28 | MomoState 29 | momo_adam 30 | MomoAdamState 31 | muon 32 | MuonState 33 | prodigy 34 | ProdigyState 35 | sam 36 | SAMState 37 | schedule_free 38 | schedule_free_adamw 39 | schedule_free_eval_params 40 | schedule_free_sgd 41 | ScheduleFreeState 42 | sophia 43 | SophiaState 44 | split_real_and_imaginary 45 | SplitRealAndImaginaryState 46 | 47 | AdEMAMix 48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 49 | .. autofunction:: ademamix 50 | .. autofunction:: scale_by_ademamix 51 | .. autoclass:: ScaleByAdemamixState 52 | 53 | Simplified AdEMAMix 54 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 55 | .. autofunction:: simplified_ademamix 56 | .. autofunction:: scale_by_simplified_ademamix 57 | .. autoclass:: ScaleBySimplifiedAdEMAMixState 58 | 59 | ADOPT 60 | ~~~~~ 61 | .. autofunction:: adopt 62 | .. autofunction:: scale_by_adopt 63 | 64 | Asynchronous-centering-Prop 65 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 66 | .. autofunction:: acprop 67 | .. autofunction:: scale_by_acprop 68 | 69 | Complex-valued Optimization 70 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 71 | .. autofunction:: split_real_and_imaginary 72 | .. autoclass:: SplitRealAndImaginaryState 73 | 74 | Continuous coin betting 75 | ~~~~~~~~~~~~~~~~~~~~~~~ 76 | .. autofunction:: cocob 77 | .. autoclass:: COCOBState 78 | 79 | D-adaptation 80 | ~~~~~~~~~~~~ 81 | .. autofunction:: dadapt_adamw 82 | .. autoclass:: DAdaptAdamWState 83 | 84 | Differentially Private Aggregate 85 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 86 | .. autofunction:: differentially_private_aggregate 87 | .. autoclass:: DifferentiallyPrivateAggregateState 88 | .. autofunction:: dpsgd 89 | 90 | Distance over Gradients 91 | ~~~~~~~~~~~~~~~~~~~~~~~ 92 | .. autofunction:: dog 93 | .. autoclass:: DoGState 94 | .. autofunction:: dowg 95 | .. autoclass:: DoWGState 96 | 97 | Mechanize 98 | ~~~~~~~~~ 99 | .. autofunction:: mechanize 100 | .. autoclass:: MechanicState 101 | 102 | Momo 103 | ~~~~ 104 | .. autofunction:: momo 105 | .. autoclass:: MomoState 106 | .. autofunction:: momo_adam 107 | .. autoclass:: MomoAdamState 108 | 109 | Muon 110 | ~~~~ 111 | .. autofunction:: muon 112 | .. autofunction:: scale_by_muon 113 | .. autoclass:: MuonState 114 | 115 | Prodigy 116 | ~~~~~~~ 117 | .. autofunction:: prodigy 118 | .. autoclass:: ProdigyState 119 | 120 | Schedule-Free 121 | ~~~~~~~~~~~~~ 122 | .. autofunction:: schedule_free 123 | .. autofunction:: schedule_free_adamw 124 | .. autofunction:: schedule_free_eval_params 125 | .. autofunction:: schedule_free_sgd 126 | .. autoclass:: ScheduleFreeState 127 | 128 | Sharpness aware minimization 129 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 130 | .. autofunction:: sam 131 | .. autoclass:: SAMState 132 | 133 | Sophia 134 | ~~~~~~ 135 | .. autofunction:: hutchinson_estimator_diag_hessian 136 | .. autoclass:: HutchinsonState 137 | .. autofunction:: sophia 138 | .. autoclass:: SophiaState 139 | -------------------------------------------------------------------------------- /optax/_src/update_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for methods in `update.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | from optax._src import test_utils 22 | from optax._src import update 23 | 24 | 25 | class UpdateTest(parameterized.TestCase): 26 | 27 | def test_apply_updates(self): 28 | params = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) 29 | grads = jax.tree.map(lambda t: 2 * t, params) 30 | exp_params = jax.tree.map(lambda t: 3 * t, params) 31 | new_params = jax.jit(update.apply_updates)(params, grads) 32 | 33 | test_utils.assert_trees_all_close( 34 | exp_params, new_params, atol=1e-10, rtol=1e-5) 35 | 36 | def test_apply_updates_mixed_precision(self): 37 | params = ( 38 | {'a': jnp.ones((3, 2), dtype=jnp.bfloat16)}, 39 | jnp.ones((1,), dtype=jnp.bfloat16), 40 | ) 41 | grads = jax.tree.map(lambda t: (2 * t).astype(jnp.float32), params) 42 | new_params = jax.jit(update.apply_updates)(params, grads) 43 | 44 | for leaf in jax.tree.leaves(new_params): 45 | assert leaf.dtype == jnp.bfloat16 46 | 47 | def test_incremental_update(self): 48 | params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) 49 | params_2 = jax.tree.map(lambda t: 2 * t, params_1) 50 | exp_params = jax.tree.map(lambda t: 1.5 * t, params_1) 51 | new_params = jax.jit(update.incremental_update)( 52 | params_2, params_1, 0.5 53 | ) 54 | 55 | test_utils.assert_trees_all_close( 56 | exp_params, new_params, atol=1e-10, rtol=1e-5) 57 | 58 | def test_periodic_update(self): 59 | params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) 60 | params_2 = jax.tree.map(lambda t: 2 * t, params_1) 61 | 62 | update_period = 5 63 | update_fn = jax.jit(update.periodic_update) 64 | 65 | for j in range(3): 66 | for i in range(1, update_period): 67 | new_params = update_fn( 68 | params_2, params_1, j * update_period + i, update_period 69 | ) 70 | test_utils.assert_trees_all_close( 71 | params_1, new_params, atol=1e-10, rtol=1e-5) 72 | 73 | new_params = update_fn( 74 | params_2, params_1, (j + 1) * update_period, update_period 75 | ) 76 | test_utils.assert_trees_all_close( 77 | params_2, new_params, atol=1e-10, rtol=1e-5) 78 | 79 | @parameterized.named_parameters( 80 | {'testcase_name': 'apply_updates', 'operation': update.apply_updates}, 81 | { 82 | 'testcase_name': 'incremental_update', 83 | 'operation': lambda x, y: update.incremental_update(x, y, 1), 84 | }, 85 | ) 86 | def test_none_argument(self, operation): 87 | x = jnp.array([1.0, 2.0, 3.0]) 88 | operation(None, x) 89 | 90 | 91 | if __name__ == '__main__': 92 | absltest.main() 93 | -------------------------------------------------------------------------------- /optax/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contributed optimizers in Optax.""" 16 | 17 | # pylint: disable=g-importing-member 18 | 19 | from optax.contrib._acprop import acprop 20 | from optax.contrib._acprop import scale_by_acprop 21 | from optax.contrib._ademamix import ademamix 22 | from optax.contrib._ademamix import scale_by_ademamix 23 | from optax.contrib._ademamix import scale_by_simplified_ademamix 24 | from optax.contrib._ademamix import ScaleByAdemamixState 25 | from optax.contrib._ademamix import ScaleBySimplifiedAdEMAMixState 26 | from optax.contrib._ademamix import simplified_ademamix 27 | from optax.contrib._adopt import adopt 28 | from optax.contrib._adopt import scale_by_adopt 29 | from optax.contrib._cocob import cocob 30 | from optax.contrib._cocob import COCOBState 31 | from optax.contrib._cocob import scale_by_cocob 32 | from optax.contrib._complex_valued import split_real_and_imaginary 33 | from optax.contrib._complex_valued import SplitRealAndImaginaryState 34 | from optax.contrib._dadapt_adamw import dadapt_adamw 35 | from optax.contrib._dadapt_adamw import DAdaptAdamWState 36 | from optax.contrib._dog import dog 37 | from optax.contrib._dog import DoGState 38 | from optax.contrib._dog import dowg 39 | from optax.contrib._dog import DoWGState 40 | from optax.contrib._mechanic import MechanicState 41 | from optax.contrib._mechanic import mechanize 42 | from optax.contrib._momo import momo 43 | from optax.contrib._momo import momo_adam 44 | from optax.contrib._momo import MomoAdamState 45 | from optax.contrib._momo import MomoState 46 | from optax.contrib._muon import muon 47 | from optax.contrib._muon import MuonDimensionNumbers 48 | from optax.contrib._muon import MuonState 49 | from optax.contrib._muon import scale_by_muon 50 | from optax.contrib._privacy import differentially_private_aggregate 51 | from optax.contrib._privacy import DifferentiallyPrivateAggregateState 52 | from optax.contrib._privacy import dpsgd 53 | from optax.contrib._prodigy import prodigy 54 | from optax.contrib._prodigy import ProdigyState 55 | from optax.contrib._reduce_on_plateau import reduce_on_plateau 56 | from optax.contrib._reduce_on_plateau import ReduceLROnPlateauState 57 | from optax.contrib._sam import normalize 58 | from optax.contrib._sam import NormalizeState 59 | from optax.contrib._sam import sam 60 | from optax.contrib._sam import SAMState 61 | from optax.contrib._schedule_free import schedule_free 62 | from optax.contrib._schedule_free import schedule_free_adamw 63 | from optax.contrib._schedule_free import schedule_free_eval_params 64 | from optax.contrib._schedule_free import schedule_free_sgd 65 | from optax.contrib._schedule_free import ScheduleFreeState 66 | from optax.contrib._sophia import hutchinson_estimator_diag_hessian 67 | from optax.contrib._sophia import HutchinsonState 68 | from optax.contrib._sophia import sophia 69 | from 
optax.contrib._sophia import SophiaState 70 | -------------------------------------------------------------------------------- /optax/transforms/_constraining.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Gradient transformations used to enforce specific constraints.""" 16 | 17 | from typing import Any, NamedTuple 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | from optax._src import base 22 | 23 | 24 | NonNegativeParamsState = base.EmptyState 25 | 26 | 27 | def keep_params_nonnegative() -> base.GradientTransformation: 28 | """Modifies the updates to keep parameters non-negative, i.e. >= 0. 29 | 30 | This transformation ensures that parameters after the update will be 31 | larger than or equal to zero. 32 | In a chain of transformations, this should be the last one. 33 | 34 | Returns: 35 | A :class:`optax.GradientTransformation` object. 36 | 37 | .. warning:: 38 | The transformation expects input params to be non-negative. 39 | When params is negative the transformed update will move them to 0. 40 | """ 41 | 42 | def init_fn(params): 43 | del params 44 | return NonNegativeParamsState() 45 | 46 | def update_fn(updates, state, params): 47 | if params is None: 48 | raise ValueError(base.NO_PARAMS_MSG) 49 | 50 | updates = jax.tree.map( 51 | lambda p, u: None if p is None else jnp.where((p + u) < 0.0, -p, u), 52 | params, 53 | updates, 54 | is_leaf=lambda x: x is None, 55 | ) 56 | return updates, state 57 | 58 | return base.GradientTransformation(init_fn, update_fn) 59 | 60 | 61 | class ZeroNansState(NamedTuple): 62 | """Contains a tree. 63 | 64 | The entry `found_nan` has the same tree structure as that of the parameters. 65 | Each leaf is a single boolean which contains True iff a NaN was detected in 66 | the corresponding parameter array at the last call to `update`. 67 | """ 68 | 69 | found_nan: Any 70 | 71 | 72 | def zero_nans() -> base.GradientTransformation: 73 | """A transformation which replaces NaNs with 0. 74 | 75 | The state of the transformation has the same tree structure as that of the 76 | parameters. Each leaf is a single boolean which contains True iff a NaN was 77 | detected in the corresponding parameter array at the last call to ``update``. 78 | This state is not used by the transformation internally, but lets users be 79 | aware when NaNs have been zeroed out. 80 | 81 | Returns: 82 | A :class:`optax.GradientTransformation`. 
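
  Example:
    A minimal usage sketch, assuming the transform is used through its
    top-level alias ``optax.zero_nans``:

    >>> import jax.numpy as jnp
    >>> import optax
    >>> opt = optax.chain(optax.zero_nans(), optax.sgd(learning_rate=0.1))
    >>> params = jnp.array([1.0, 2.0])
    >>> state = opt.init(params)
    >>> updates, state = opt.update(jnp.array([jnp.nan, 1.0]), state)
    >>> print(jnp.isnan(updates).any())  # the NaN entry was zeroed out
    False
    >>> print(state[0].found_nan)  # and the replacement was recorded
    True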
83 | """ 84 | 85 | def init_fn(params): 86 | return ZeroNansState( 87 | found_nan=jax.tree.map( 88 | lambda p: jnp.array(False, dtype=jnp.bool_), params 89 | ) 90 | ) 91 | 92 | def update_fn(updates, opt_state, params=None): 93 | del params, opt_state 94 | opt_state = ZeroNansState( 95 | found_nan=jax.tree.map(lambda p: jnp.any(jnp.isnan(p)), updates) 96 | ) 97 | updates = jax.tree.map( 98 | lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates 99 | ) 100 | return updates, opt_state 101 | 102 | return base.GradientTransformation(init=init_fn, update=update_fn) 103 | -------------------------------------------------------------------------------- /docs/api/losses.rst: -------------------------------------------------------------------------------- 1 | Losses 2 | ====== 3 | 4 | .. currentmodule:: optax.losses 5 | 6 | .. autosummary:: 7 | binary_dice_loss 8 | convex_kl_divergence 9 | cosine_distance 10 | cosine_similarity 11 | ctc_loss 12 | ctc_loss_with_forward_probs 13 | dice_loss 14 | generalized_kl_divergence 15 | hinge_loss 16 | huber_loss 17 | kl_divergence 18 | kl_divergence_with_log_targets 19 | l2_loss 20 | log_cosh 21 | make_fenchel_young_loss 22 | multiclass_generalized_dice_loss 23 | multiclass_hinge_loss 24 | multiclass_perceptron_loss 25 | multiclass_sparsemax_loss 26 | ntxent 27 | perceptron_loss 28 | poly_loss_cross_entropy 29 | ranking_softmax_loss 30 | safe_softmax_cross_entropy 31 | sigmoid_binary_cross_entropy 32 | sigmoid_focal_loss 33 | smooth_labels 34 | softmax_cross_entropy 35 | softmax_cross_entropy_with_integer_labels 36 | sparsemax_loss 37 | squared_error 38 | triplet_margin_loss 39 | 40 | 41 | Convex Kullback Leibler divergence 42 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 43 | .. autofunction:: convex_kl_divergence 44 | 45 | Cosine distance 46 | ~~~~~~~~~~~~~~~ 47 | .. autofunction:: cosine_distance 48 | 49 | Cosine similarity 50 | ~~~~~~~~~~~~~~~~~ 51 | .. autofunction:: cosine_similarity 52 | 53 | Connectionist temporal classification loss 54 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 55 | .. autofunction:: ctc_loss 56 | .. autofunction:: ctc_loss_with_forward_probs 57 | 58 | Dice loss 59 | ~~~~~~~~~ 60 | .. autofunction:: dice_loss 61 | .. autofunction:: multiclass_generalized_dice_loss 62 | .. autofunction:: binary_dice_loss 63 | 64 | Fenchel Young loss 65 | ~~~~~~~~~~~~~~~~~~ 66 | .. autofunction:: make_fenchel_young_loss 67 | 68 | Generalized Kullback-Leibler divergence 69 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 70 | .. autofunction:: generalized_kl_divergence 71 | 72 | Hinge loss 73 | ~~~~~~~~~~ 74 | .. autofunction:: hinge_loss 75 | .. autofunction:: multiclass_hinge_loss 76 | 77 | Huber loss 78 | ~~~~~~~~~~ 79 | .. autofunction:: huber_loss 80 | 81 | Kullback-Leibler divergence 82 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 83 | .. autofunction:: kl_divergence 84 | .. autofunction:: kl_divergence_with_log_targets 85 | 86 | L2 Squared loss 87 | ~~~~~~~~~~~~~~~ 88 | .. autofunction:: squared_error 89 | .. autofunction:: l2_loss 90 | 91 | Log hyperbolic cosine loss 92 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 93 | .. autofunction:: log_cosh 94 | 95 | Normalized temperature scaled cross-entropy (NT-Xent) loss 96 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 97 | .. autofunction:: ntxent 98 | 99 | Poly loss cross-entropy 100 | ~~~~~~~~~~~~~~~~~~~~~~~ 101 | .. autofunction:: poly_loss_cross_entropy 102 | 103 | Perceptron 104 | ~~~~~~~~~~~ 105 | .. autofunction:: perceptron_loss 106 | .. 
autofunction:: multiclass_perceptron_loss 107 | 108 | Ranking softmax loss 109 | ~~~~~~~~~~~~~~~~~~~~ 110 | .. autofunction:: ranking_softmax_loss 111 | 112 | Sigmoid binary cross-entropy 113 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 114 | .. autofunction:: sigmoid_binary_cross_entropy 115 | 116 | Sigmoid focal loss 117 | ~~~~~~~~~~~~~~~~~~ 118 | .. autofunction:: sigmoid_focal_loss 119 | 120 | Smoothing labels 121 | ~~~~~~~~~~~~~~~~ 122 | .. autofunction:: smooth_labels 123 | 124 | Soft-max cross-entropy 125 | ~~~~~~~~~~~~~~~~~~~~~~ 126 | .. autofunction:: safe_softmax_cross_entropy 127 | .. autofunction:: softmax_cross_entropy 128 | .. autofunction:: softmax_cross_entropy_with_integer_labels 129 | 130 | Sparsemax 131 | ~~~~~~~~~ 132 | .. autofunction:: sparsemax_loss 133 | .. autofunction:: multiclass_sparsemax_loss 134 | 135 | Triplet margin loss 136 | ~~~~~~~~~~~~~~~~~~~ 137 | .. autofunction:: triplet_margin_loss 138 | -------------------------------------------------------------------------------- /optax/contrib/_sam_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for the SAM optimizer in `sam.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | from optax._src import alias 22 | from optax._src import combine 23 | from optax._src import numerics 24 | from optax._src import test_utils 25 | from optax._src import update 26 | from optax.contrib import _sam 27 | import optax.tree 28 | 29 | _BASE_OPTIMIZERS_UNDER_TEST = [ 30 | {'base_opt_name': 'sgd', 'base_opt_kwargs': {'learning_rate': 1e-3}}, 31 | ] 32 | _ADVERSARIAL_OPTIMIZERS_UNDER_TEST = [ 33 | {'adv_opt_name': 'sgd', 'adv_opt_kwargs': {'learning_rate': 1e-5}}, 34 | {'adv_opt_name': 'adam', 'adv_opt_kwargs': {'learning_rate': 1e-4}}, 35 | ] 36 | 37 | 38 | def _setup_parabola(dtype): 39 | """Quadratic function as an optimization target.""" 40 | initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) 41 | final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) 42 | 43 | @jax.grad 44 | def get_updates(params): 45 | return jnp.sum(numerics.abs_sq(params - final_params)) 46 | 47 | return initial_params, final_params, get_updates 48 | 49 | 50 | class SAMTest(parameterized.TestCase): 51 | 52 | @parameterized.product( 53 | _BASE_OPTIMIZERS_UNDER_TEST, 54 | _ADVERSARIAL_OPTIMIZERS_UNDER_TEST, 55 | sync_period=(2,), 56 | target=(_setup_parabola,), 57 | dtype=('float32',), 58 | opaque_mode=(False, True), 59 | ) 60 | def test_optimization( 61 | self, 62 | base_opt_name, 63 | base_opt_kwargs, 64 | adv_opt_name, 65 | adv_opt_kwargs, 66 | sync_period, 67 | target, 68 | dtype, 69 | opaque_mode, 70 | ): 71 | dtype = jnp.dtype(dtype) 72 | base_opt = getattr(alias, base_opt_name)(**base_opt_kwargs) 73 | adv_opt = combine.chain( 74 | _sam.normalize(), getattr(alias, adv_opt_name)(**adv_opt_kwargs) 75 | ) 76 | opt = _sam.sam( 77 | base_opt, adv_opt, sync_period=sync_period, opaque_mode=opaque_mode 78 | ) 79 | initial_params, final_params, get_updates = target(dtype) 80 | 81 | if opaque_mode: 82 | update_kwargs = {'grad_fn': lambda p, _: get_updates(p)} 83 | else: 84 | update_kwargs = {} 85 | 86 | @jax.jit 87 | def step(params, state): 88 | updates = get_updates(params) 89 | updates, state = opt.update(updates, state, params, **update_kwargs) 90 | params = update.apply_updates(params, updates) 91 | return params, state 92 | 93 | params = initial_params 94 | state = opt.init(params) 95 | # A no-op change, to verify that tree map works. 96 | state = optax.tree.map_params(opt, lambda v: v, state) 97 | 98 | for _ in range(25000 * sync_period): 99 | params, state = step(params, state) 100 | 101 | test_utils.assert_trees_all_close( 102 | params, final_params, rtol=3e-2, atol=3e-2) 103 | 104 | 105 | if __name__ == '__main__': 106 | absltest.main() 107 | -------------------------------------------------------------------------------- /optax/_src/factorized_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for methods in `factorized.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | from optax._src import factorized 22 | from optax._src import test_utils 23 | from optax.transforms import _accumulation 24 | 25 | 26 | class FactorizedTest(parameterized.TestCase): 27 | 28 | def setUp(self): 29 | super().setUp() 30 | self.init_params = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) 31 | self.per_step_updates = (jnp.array([500.0, 5.0]), jnp.array([300.0, 3.0])) 32 | 33 | def test_scale_by_factored_rms(self): 34 | params = self.init_params 35 | 36 | scaler = factorized.scale_by_factored_rms() 37 | init_fn = jax.jit(scaler.init) 38 | transform_fn = jax.jit(scaler.update) 39 | 40 | state = init_fn(params) 41 | test_utils.assert_tree_all_finite(state) 42 | 43 | updates, state = transform_fn(self.per_step_updates, state, params) 44 | test_utils.assert_tree_all_finite((params, updates, state)) 45 | test_utils.assert_trees_all_equal_shapes(params, updates) 46 | 47 | @parameterized.product( 48 | factorized_dims=(True, False), dtype=('bfloat16', 'float32') 49 | ) 50 | def test_preserve_dtype(self, factorized_dims: bool, dtype: str): 51 | """Test that the optimizer returns updates of same dtype as params.""" 52 | dtype = jnp.dtype(dtype) 53 | opt = factorized.scale_by_factored_rms() 54 | fun = lambda x: jnp.sum(x**2) 55 | 56 | if factorized_dims: 57 | # The updates are factored only for large enough parameters 58 | # default min_dim_size_to_factor is 128 so we use 129 here. 59 | params = jnp.ones((129, 129), dtype=dtype) 60 | else: 61 | params = jnp.array([1.0, 2.0], dtype=dtype) 62 | grads = jax.grad(fun)(params) 63 | state = jax.jit(opt.init)(params) 64 | updates, _ = jax.jit(opt.update)(grads, state, params) 65 | self.assertEqual(updates.dtype, params.dtype) 66 | 67 | @parameterized.product( 68 | factorized_dims=(True, False), dtype=('bfloat16', 'float32') 69 | ) 70 | def test_gradient_accumulation(self, factorized_dims, dtype): 71 | """Test that the optimizers can safely be used with optax.MultiSteps.""" 72 | # Checks if https://github.com/google-deepmind/optax/issues/377 is fixed. 73 | dtype = jnp.dtype(dtype) 74 | base_opt = factorized.scale_by_factored_rms() 75 | opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4) 76 | 77 | fun = lambda x: jnp.sum(x**2) 78 | 79 | if factorized_dims: 80 | # The updates are factored only for large enough parameters 81 | # default min_dim_size_to_factor is 128 so we use 129 here. 
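      # Both dimensions of the (129, 129) parameter exceed
      # min_dim_size_to_factor, so scale_by_factored_rms keeps factored
      # (row/column) second-moment statistics instead of a full-sized
      # accumulator.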
82 | params = jnp.ones((129, 129), dtype=dtype) 83 | else: 84 | params = jnp.array([1.0, 2.0], dtype=dtype) 85 | grads = jax.grad(fun)(params) 86 | state = jax.jit(opt.init)(params) 87 | updates, _ = jax.jit(opt.update)(grads, state, params) 88 | test_utils.assert_trees_all_equal(updates, jnp.zeros_like(grads)) 89 | 90 | 91 | if __name__ == '__main__': 92 | absltest.main() 93 | -------------------------------------------------------------------------------- /.github/check_license_headers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Check all Python files in the optax directory for license header.""" 16 | 17 | from concurrent import futures 18 | import logging 19 | import pathlib 20 | import pprint 21 | import re 22 | import sys 23 | 24 | logger = logging.getLogger(pathlib.Path(__file__).name) 25 | logging.basicConfig(format=logging.BASIC_FORMAT) 26 | logger.setLevel("DEBUG") 27 | 28 | # pylint: disable=line-too-long 29 | LICENSE_PATTERN = ( 30 | "(# (pylint|coding).*\n)*" 31 | "# Copyright 20[0-9][0-9] DeepMind Technologies Limited. All Rights Reserved.\n" # noqa: E501 32 | "#\n" 33 | "# Licensed under the Apache License, Version 2.0 \\(the \"License\"\\);\n" 34 | "# you may not use this file except in compliance with the License.\n" 35 | "# You may obtain a copy of the License at\n" 36 | "#\n" 37 | "# http://www.apache.org/licenses/LICENSE-2.0\n" 38 | "#\n" 39 | "# Unless required by applicable law or agreed to in writing, software\n" 40 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n" 41 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" # noqa: E501 42 | "# See the License for the specific language governing permissions and\n" 43 | "# limitations under the License.\n" 44 | "# ==============================================================================\n" # noqa: E501 45 | ".*" 46 | ) 47 | # pylint: enable=line-too-long 48 | 49 | LICENSE_TEMPLATE = """ 50 | # Copyright 20XX DeepMind Technologies Limited. All Rights Reserved. 51 | # 52 | # Licensed under the Apache License, Version 2.0 (the "License"); 53 | # you may not use this file except in compliance with the License. 54 | # You may obtain a copy of the License at 55 | # 56 | # http://www.apache.org/licenses/LICENSE-2.0 57 | # 58 | # Unless required by applicable law or agreed to in writing, software 59 | # distributed under the License is distributed on an "AS IS" BASIS, 60 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 61 | # See the License for the specific language governing permissions and 62 | # limitations under the License. 
63 | # ============================================================================== 64 | """ 65 | 66 | EXCLUDE_LIST = [] 67 | 68 | 69 | def _check_license_header(fname): 70 | if fname in EXCLUDE_LIST: 71 | return True 72 | try: 73 | source = pathlib.Path(fname).read_text() 74 | return re.match(LICENSE_PATTERN, source) is not None 75 | except UnicodeDecodeError: 76 | return True 77 | 78 | if __name__ == "__main__": 79 | # check all Python files in the optax directory for license header 80 | source_files = list(pathlib.Path("./optax").glob("**/*.py")) 81 | with futures.ThreadPoolExecutor(max_workers=32) as executor: 82 | results = dict(zip(source_files, 83 | executor.map(_check_license_header, source_files))) 84 | failed_files = [str(fname) for fname, status in results.items() if not status] 85 | if failed_files: 86 | logger.error( 87 | "Files:\n%s\ndon't have the proper license. Please include this license" 88 | " template at the top of your file:\n%s", pprint.pformat(failed_files), 89 | LICENSE_TEMPLATE) 90 | sys.exit(1) # non-success return 91 | -------------------------------------------------------------------------------- /optax/transforms/_monitoring_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for monitoring and debugging in optax.""" 17 | 18 | from absl.testing import absltest 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | from optax import tree 24 | from optax._src import alias 25 | from optax._src import update 26 | from optax.transforms import _clipping 27 | from optax.transforms import _combining 28 | from optax.transforms import _monitoring 29 | 30 | 31 | class MonitoringTest(absltest.TestCase): 32 | 33 | def test_snapshot(self): 34 | """Tests that snapshot stores the correct values.""" 35 | 36 | def f(x): 37 | return jnp.sum(x**2) 38 | 39 | opt_before_clip = _combining.chain( 40 | alias.sgd(learning_rate=0.1, momentum=0.9), 41 | _monitoring.snapshot('norm_before_clip', tree.norm), 42 | ) 43 | opt = _combining.chain(opt_before_clip, _clipping.clip_by_global_norm(0.05)) 44 | 45 | params = jnp.array([1.0, 2.0, 3.0]) 46 | state_aux = opt_before_clip.init(params) 47 | state = opt.init(params) 48 | 49 | # Testing for two steps to observe behavior not only after initialization 50 | # but also after the first update. 
51 | for step in range(1, 3): 52 | grads = jax.grad(f)(params) 53 | updates_before_clip, state_aux = opt_before_clip.update(grads, state_aux) 54 | updates, state = opt.update(grads, state) 55 | params = update.apply_updates(params, updates) 56 | with self.subTest(f'norms equal at {step=}'): 57 | got = tree.get(state, 'norm_before_clip') 58 | expected = tree.norm(updates_before_clip) 59 | np.testing.assert_allclose(got, expected) 60 | 61 | def test_monitor(self): 62 | """Tests that monitor stores the correct values.""" 63 | 64 | def f(x): 65 | return jnp.sum(x**2) 66 | 67 | ema_decay = 0.9 68 | debias = True 69 | opt_before_clip = _combining.chain( 70 | alias.sgd(learning_rate=0.1, momentum=0.9), 71 | _monitoring.monitor({ 72 | 'norm_before_clip': tree.norm, 73 | 'norm_before_clip_ema': _monitoring.measure_with_ema( 74 | tree.norm, ema_decay, debias 75 | ), 76 | }), 77 | ) 78 | opt = _combining.chain(opt_before_clip, _clipping.clip_by_global_norm(0.05)) 79 | 80 | params = jnp.array([1.0, 2.0, 3.0]) 81 | state_aux = opt_before_clip.init(params) 82 | state = opt.init(params) 83 | 84 | ema_norm = 0.0 85 | for step in range(1, 4): 86 | grads = jax.grad(f)(params) 87 | updates_before_clip, state_aux = opt_before_clip.update(grads, state_aux) 88 | updates, state = opt.update(grads, state) 89 | params = update.apply_updates(params, updates) 90 | norm_before_clip = tree.norm(updates_before_clip) 91 | with self.subTest(f'norms equal at {step=}'): 92 | np.testing.assert_allclose( 93 | tree.get(state, 'norm_before_clip'), norm_before_clip 94 | ) 95 | 96 | ema_norm = ema_decay * ema_norm + (1 - ema_decay) * norm_before_clip 97 | ema_norm_debiased = ema_norm / (1 - ema_decay**step) 98 | with self.subTest(f'ema norms equal at {step=}'): 99 | np.testing.assert_allclose( 100 | tree.get(state, 'norm_before_clip_ema'), 101 | ema_norm_debiased, 102 | rtol=1e-5, 103 | ) 104 | 105 | 106 | if __name__ == '__main__': 107 | absltest.main() 108 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "optax" 7 | dynamic = ["version"] 8 | description = "A gradient processing and optimization library in JAX." 
9 | readme = "README.md" 10 | license = { file = "LICENSE" } 11 | requires-python = ">=3.10" 12 | authors = [ 13 | {name = "Google DeepMind", email = "optax-dev@google.com"}, 14 | ] 15 | keywords = [ 16 | "python", 17 | "machine learning", 18 | "reinforcement-learning" 19 | ] 20 | classifiers = [ 21 | "Environment :: Console", 22 | "Programming Language :: Python", 23 | "Intended Audience :: Developers", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python :: 3", 26 | "Intended Audience :: Science/Research", 27 | "Development Status :: 4 - Beta", 28 | "License :: OSI Approved :: Apache Software License", 29 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 30 | "Topic :: Software Development :: Libraries :: Python Modules", 31 | ] 32 | dependencies = [ 33 | "absl-py>=0.7.1", 34 | "chex>=0.1.87", 35 | # Keep jax, jaxlib versions in sync with .github/workflows/tests.yml 36 | "jax>=0.5.3", 37 | "jaxlib>=0.5.3", 38 | "numpy>=1.18.0", 39 | ] 40 | 41 | [project.urls] 42 | homepage = "https://github.com/google-deepmind/optax" 43 | repository = "https://github.com/google-deepmind/optax" 44 | documentation = "https://optax.readthedocs.io/" 45 | 46 | [project.optional-dependencies] 47 | test = [ 48 | "flax>=0.5.3", 49 | "scipy>=1.7.1", 50 | "scikit-learn" 51 | ] 52 | 53 | docs = [ 54 | "sphinx>=6.0.0", 55 | "sphinx-book-theme>=1.0.1", # Older versions fail to pin pydata-sphinx-theme 56 | "sphinxcontrib-katex", 57 | "sphinx-autodoc-typehints", 58 | "ipython>=8.8.0", # 8.7.0 has ipython3 lexer error 59 | "myst-nb>=1.0.0", 60 | "matplotlib>=3.5.0", 61 | "sphinx-gallery>=0.14.0", 62 | "sphinx-collections>=0.0.1", 63 | "flax", 64 | "sphinx_contributors", 65 | "setuptools", 66 | ] 67 | 68 | [tool.setuptools.packages.find] 69 | include = ["README.md", "LICENSE"] 70 | exclude = ["*_test.py"] 71 | 72 | [tool.ruff] 73 | line-length = 80 74 | 75 | [tool.ruff.lint] 76 | select = [ 77 | "F", 78 | "E", 79 | "W291", # whitespace at the end of the line 80 | "B023", # pylint's cell-var-over-loop, closures capturing variables in loop 81 | ] 82 | ignore = [ 83 | "E731", # lambdas are allowed 84 | "F401", # allow unused imports 85 | "E402", # allow modules not at top of file 86 | "E741", # allow "l" as a variable name 87 | "E703", # allow semicolons (for jupyter notebooks) 88 | ] 89 | 90 | [tool.pylint.messages_control] 91 | disable = [ 92 | "bad-indentation", 93 | "unknown-option-value", 94 | "invalid-name", 95 | "missing-function-docstring", 96 | "missing-class-docstring", 97 | "missing-module-docstring", 98 | "no-member", 99 | "too-many-locals", 100 | "too-many-positional-arguments", 101 | "no-else-return", 102 | "line-too-long", 103 | "too-many-arguments", 104 | "no-value-for-parameter", 105 | "duplicate-code", 106 | "unused-argument", 107 | "too-few-public-methods", 108 | "wrong-import-order", 109 | "unused-import", 110 | "wrong-import-position", 111 | "unnecessary-lambda-assignment", 112 | "too-many-lines", 113 | "too-many-statements", 114 | "deprecated-class", 115 | "redefined-builtin", 116 | "used-before-assignment", 117 | "undefined-variable", 118 | "protected-access", 119 | "not-callable", 120 | "redefined-outer-name", 121 | "too-many-instance-attributes", 122 | "missing-final-newline", 123 | "too-many-public-methods", 124 | "import-error", 125 | ] 126 | 127 | # We include pyink to allow external contributors to optionally, but easily, 128 | # pass google internal formatting checks. 
129 | [tool.pyink] 130 | pyink = true # false would mean black is used directly without pyink features 131 | pyink-use-majority-quotes = true 132 | pyink-indentation = 2 133 | line-length = 80 134 | include = '\.pyi?$' 135 | -------------------------------------------------------------------------------- /optax/tree_utils/_random.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities to generate random pytrees.""" 16 | 17 | from collections.abc import Callable 18 | import inspect 19 | from typing import Optional, Union 20 | 21 | import chex 22 | import jax 23 | from optax._src import base 24 | 25 | 26 | def tree_split_key_like( 27 | rng_key: base.PRNGKey, target_tree: chex.ArrayTree 28 | ) -> chex.ArrayTree: 29 | """Split keys to match structure of target tree. 30 | 31 | Args: 32 | rng_key: the key to split. 33 | target_tree: the tree whose structure to match. 34 | 35 | Returns: 36 | a tree of rng keys. 37 | """ 38 | tree_def = jax.tree.structure(target_tree) 39 | keys = jax.random.split(rng_key, tree_def.num_leaves) 40 | return jax.tree.unflatten(tree_def, keys) 41 | 42 | 43 | def tree_random_like( 44 | rng_key: base.PRNGKey, 45 | target_tree: chex.ArrayTree, 46 | sampler: Union[ 47 | Callable[[base.PRNGKey, base.Shape, jax.typing.DTypeLike], 48 | jax.typing.ArrayLike], 49 | Callable[[base.PRNGKey, base.Shape, jax.typing.DTypeLike, 50 | jax.sharding.Sharding], 51 | jax.typing.ArrayLike]] = jax.random.normal, 52 | dtype: Optional[chex.ArrayDType] = None, 53 | ) -> chex.ArrayTree: 54 | """Create tree with random entries of the same shape as target tree. 55 | 56 | Args: 57 | rng_key: the key for the random number generator. 58 | target_tree: the tree whose structure to match. Leaves must be arrays. 59 | sampler: the noise sampling function, by default ``jax.random.normal``. 60 | dtype: the desired dtype for the random numbers, passed to ``sampler``. If 61 | None, the dtype of the target tree is used if possible. 62 | 63 | Returns: 64 | a random tree with the same structure as ``target_tree``, whose leaves have 65 | distribution ``sampler``. 66 | 67 | .. warning:: 68 | The possible dtypes may be limited by the sampler; for example, 69 | ``jax.random.rademacher`` only supports integer dtypes and will raise an 70 | error if the requested dtype (whether passed explicitly or inferred from 71 | the target tree) is not an integer type. 72 | 73 | ..
versionadded:: 0.2.1 74 | """ 75 | keys_tree = tree_split_key_like(rng_key, target_tree) 76 | sampler_ = sampler 77 | if "out_sharding" not in inspect.signature(sampler).parameters: 78 | sampler_ = lambda key, shape, dtype, *, out_sharding: sampler( # pylint: disable=unnecessary-lambda 79 | key, shape, dtype) # pytype: disable=wrong-arg-count 80 | return jax.tree.map( 81 | # pytype: disable=wrong-keyword-args 82 | lambda leaf, key: sampler_(key, leaf.shape, dtype or leaf.dtype, 83 | out_sharding=jax.typeof(leaf).sharding), 84 | # pytype: enable=wrong-keyword-args 85 | target_tree, 86 | keys_tree, 87 | ) 88 | 89 | 90 | def tree_unwrap_random_key_data(input_tree: chex.ArrayTree) -> chex.ArrayTree: 91 | """Unwrap random.key objects in a tree for numerical comparison. 92 | 93 | Args: 94 | input_tree: a tree of arrays and random.key objects. 95 | 96 | Returns: 97 | a tree of arrays and random.key_data objects. 98 | """ 99 | def _unwrap_random_key_data(x): 100 | if (isinstance(x, jax.Array) 101 | and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)): 102 | return jax.random.key_data(x) 103 | return x 104 | 105 | return jax.tree.map(_unwrap_random_key_data, input_tree) 106 | -------------------------------------------------------------------------------- /optax/contrib/_mechanic_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Specific tests for `mechanic.py`, see `common_test.py` for usual tests.""" 16 | 17 | from typing import NamedTuple 18 | 19 | from absl.testing import absltest 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | from optax._src import base 24 | from optax._src import test_utils 25 | from optax._src import update 26 | from optax.contrib import _mechanic 27 | import optax.tree 28 | 29 | 30 | class OptimizerTestState(NamedTuple): 31 | """Inner optimizer state for the Mechanic tests.""" 32 | 33 | aggregate_grads: base.Params 34 | 35 | 36 | def _test_optimizer(step_size: float) -> base.GradientTransformation: 37 | """Inner optimizer for the Mechanic tests.""" 38 | 39 | # Use SGD for simplicity but add non-trivial optimizer state so that the 40 | # resetting behavior of lookahead can be tested. 41 | def init_fn(params): 42 | aggregate_grads = jax.tree.map(jnp.zeros_like, params) 43 | return OptimizerTestState(aggregate_grads) 44 | 45 | def update_fn(updates, state, params): 46 | # The test optimizer does not use the parameters, but we check that they 47 | # have been passed correctly. 
48 | test_utils.assert_trees_all_equal_shapes(updates, params) 49 | aggregate_grads = update.apply_updates(state.aggregate_grads, updates) 50 | updates = jax.tree.map(lambda u: step_size * u, updates) 51 | return updates, OptimizerTestState(aggregate_grads) 52 | 53 | return base.GradientTransformation(init_fn, update_fn) 54 | 55 | 56 | class MechanicTest(absltest.TestCase): 57 | 58 | def setUp(self): 59 | super().setUp() 60 | self.grads = {'x': np.array(2.0), 'y': np.array(-2.0)} 61 | self.initial_params = {'x': np.array(3.0), 'y': np.array(-3.0)} 62 | 63 | def loop(self, optimizer, num_steps, params 64 | ) -> tuple[base.Params, _mechanic.MechanicState]: 65 | """Performs a given number of optimizer steps.""" 66 | init_fn, update_fn = optimizer 67 | step = jax.jit(update_fn) 68 | opt_state = jax.jit(init_fn)(params) 69 | 70 | # A no-op change, to verify that tree map works. 71 | opt_state = optax.tree.map_params(init_fn, lambda v: v, opt_state) 72 | 73 | for _ in range(num_steps): 74 | updates, opt_state = step(self.grads, opt_state, params) 75 | print(updates) 76 | params = update.apply_updates(params, updates) 77 | 78 | return params, opt_state 79 | 80 | def test_mechanized(self): 81 | params = self.initial_params 82 | num_betas = 6 83 | 84 | inner_optimizer = _test_optimizer(-0.1) 85 | optimizer = _mechanic.mechanize( 86 | inner_optimizer, 87 | weight_decay=1e-2, 88 | eps=1e-10, 89 | s_init=1e-8, 90 | num_betas=num_betas, 91 | ) 92 | 93 | final_params, final_state = self.loop( 94 | optimizer=optimizer, num_steps=1, params=params 95 | ) 96 | expected_m = np.array([1.0e-10] * num_betas) 97 | expected_v = np.array([0.0] * num_betas) 98 | expected_s = np.array([1.6666667e-09] * num_betas) 99 | 100 | test_utils.assert_trees_all_close(expected_m, final_state.m) 101 | test_utils.assert_trees_all_close(expected_v, final_state.v) 102 | test_utils.assert_trees_all_close(expected_s, final_state.s) 103 | test_utils.assert_trees_all_close(final_params, params) 104 | test_utils.assert_tree_all_finite((final_params, final_state)) 105 | 106 | 107 | if __name__ == '__main__': 108 | absltest.main() 109 | -------------------------------------------------------------------------------- /optax/_src/float64_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests that types are preserved by the `update` calls when jax_enable_x64 is enabled.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | from optax import transforms 23 | 24 | from optax._src import alias 25 | from optax._src import base 26 | from optax._src import transform 27 | from optax._src import update 28 | 29 | 30 | ALL_MODULES = [ 31 | ('identity', base.identity, {}), 32 | ('clip', transforms.clip, {'max_delta': 1.0}), 33 | ('clip_by_global_norm', transforms.clip_by_global_norm, {'max_norm': 1.0}), 34 | ('trace', transforms.trace, {'decay': 0.5, 'nesterov': False}), 35 | ('trace_with_nesterov', transforms.trace, {'decay': 0.5, 'nesterov': True}), 36 | ('scale_by_rss', transform.scale_by_rss, {}), 37 | ('scale_by_rms', transform.scale_by_rms, {}), 38 | ('scale_by_stddev', transform.scale_by_stddev, {}), 39 | ('scale_by_adam', transform.scale_by_adam, {}), 40 | ('scale', transform.scale, {'step_size': 3.0}), 41 | ( 42 | 'add_decayed_weights', 43 | transforms.add_decayed_weights, 44 | {'weight_decay': 0.1}, 45 | ), 46 | ( 47 | 'scale_by_schedule', 48 | transform.scale_by_schedule, 49 | {'step_size_fn': lambda x: x * 0.1}, 50 | ), 51 | ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}), 52 | ('add_noise', transforms.add_noise, {'eta': 1.0, 'gamma': 0.1, 'key': 0}), 53 | ('apply_every_k', transform.apply_every, {}), 54 | ('adagrad', alias.adagrad, {'learning_rate': 0.1}), 55 | ('adam', alias.adam, {'learning_rate': 0.1}), 56 | ('adamw', alias.adamw, {'learning_rate': 0.1}), 57 | ('fromage', alias.fromage, {'learning_rate': 0.1}), 58 | ('lamb', alias.lamb, {'learning_rate': 0.1}), 59 | ('noisy_sgd', alias.noisy_sgd, {'learning_rate': 0.1, 'key': 0}), 60 | ('rmsprop', alias.rmsprop, {'learning_rate': 0.1}), 61 | ('sgd', alias.sgd, {'learning_rate': 0.1}), 62 | ('sign_sgd', alias.sgd, {'learning_rate': 0.1}), 63 | ] 64 | 65 | 66 | class Float64Test(parameterized.TestCase): 67 | 68 | def _assert_dtype_equals(self, tree1, tree2): 69 | tree1_types = jax.tree.map(lambda t: t.dtype, tree1) 70 | tree2_types = jax.tree.map(lambda t: t.dtype, tree2) 71 | self.assertEqual(tree1_types, tree2_types) 72 | 73 | @parameterized.named_parameters(ALL_MODULES) 74 | def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs): 75 | jax.config.update('jax_enable_x64', True) 76 | initial_params = ( 77 | jnp.array([1.0, 2.0], dtype=jnp.float32), 78 | jnp.array([3.0, 4.0], dtype=jnp.float64), 79 | ) 80 | updates = ( 81 | jnp.array([10.0, 21.0], dtype=jnp.float32), 82 | jnp.array([33.0, 42.0], dtype=jnp.float64), 83 | ) 84 | scaler = transform_constr(**transform_kwargs) 85 | init_fn = jax.jit(scaler.init) 86 | update_fn = jax.jit(scaler.update) 87 | 88 | initial_state = init_fn(initial_params) 89 | updates, new_state = update_fn( 90 | updates, initial_state, params=initial_params 91 | ) 92 | new_params = update.apply_updates(initial_params, updates) 93 | 94 | self._assert_dtype_equals(initial_state, new_state) 95 | self._assert_dtype_equals(initial_params, new_params) 96 | jax.config.update('jax_enable_x64', False) 97 | 98 | 99 | if __name__ == '__main__': 100 | absltest.main() 101 | -------------------------------------------------------------------------------- /optax/transforms/_adding_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind
Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for methods in `optax.transforms._adding.py`.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | from optax._src import test_utils 21 | from optax.transforms import _adding 22 | 23 | STEPS = 50 24 | 25 | 26 | class AddingTest(absltest.TestCase): 27 | 28 | def setUp(self): 29 | super().setUp() 30 | self.init_params = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) 31 | self.per_step_updates = (jnp.array([500.0, 5.0]), jnp.array([300.0, 3.0])) 32 | 33 | def test_add_decayed_weights(self): 34 | # Define a transform that add decayed weights. 35 | # We can define a mask either as a pytree, or as a function that 36 | # returns the pytree. Below we define the pytree directly. 37 | mask = (True, {"a": True, "b": False}) 38 | tx = _adding.add_decayed_weights(0.1, mask=mask) 39 | # Define input updates and weights. 40 | updates = ( 41 | jnp.zeros((2,), dtype=jnp.float32), 42 | { 43 | "a": jnp.zeros((2,), dtype=jnp.float32), 44 | "b": jnp.zeros((2,), dtype=jnp.float32), 45 | }, 46 | ) 47 | weights = ( 48 | jnp.ones((2,), dtype=jnp.float32), 49 | { 50 | "a": jnp.ones((2,), dtype=jnp.float32), 51 | "b": jnp.ones((2,), dtype=jnp.float32), 52 | }, 53 | ) 54 | # This mask means that we will add decayed weights to the first two 55 | # terms in the input updates, but not to the last element. 56 | expected_tx_updates = ( 57 | 0.1*jnp.ones((2,), dtype=jnp.float32), 58 | { 59 | "a": 0.1*jnp.ones((2,), dtype=jnp.float32), 60 | "b": jnp.zeros((2,), dtype=jnp.float32), 61 | }, 62 | ) 63 | # Apply transform 64 | state = tx.init(weights) 65 | transform_fn = jax.jit(tx.update) 66 | new_updates, _ = transform_fn(updates, state, weights) 67 | # Assert output as expected. 68 | test_utils.assert_trees_all_close(new_updates, expected_tx_updates) 69 | 70 | def test_add_noise_has_correct_variance_scaling(self): 71 | # Prepare to compare noise with a rescaled unit-variance substitute. 72 | eta = 0.3 73 | gamma = 0.55 74 | key = 314 75 | noise = _adding.add_noise(eta, gamma, key) 76 | noise_unit = _adding.add_noise(1.0, 0.0, key) 77 | 78 | params = self.init_params 79 | state = noise.init(params) 80 | state_unit = noise_unit.init(params) 81 | 82 | # Check the noise itself by adding it to zeros. 
83 | updates = jax.tree.map(jnp.zeros_like, params) 84 | 85 | for i in range(1, STEPS + 1): 86 | updates_i, state = jax.jit(noise.update)(updates, state) 87 | updates_i_unit, state_unit = noise_unit.update(updates, state_unit) 88 | 89 | scale = jnp.sqrt(eta / i**gamma) 90 | 91 | updates_i_rescaled = jax.tree.map( 92 | lambda g, s=scale: g * s, updates_i_unit 93 | ) 94 | 95 | test_utils.assert_trees_all_close( 96 | updates_i, updates_i_rescaled, rtol=1e-4) 97 | 98 | def test_none_argument(self): 99 | weights = ( 100 | jnp.ones((2,), dtype=jnp.float32), 101 | { 102 | "a": jnp.ones((2,), dtype=jnp.float32), 103 | "b": jnp.ones((2,), dtype=jnp.float32), 104 | }, 105 | ) 106 | tf = _adding.add_decayed_weights(0.1, mask=None) 107 | tf.update(None, 0, weights) 108 | 109 | 110 | if __name__ == "__main__": 111 | absltest.main() 112 | -------------------------------------------------------------------------------- /optax/transforms/_freezing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for freezing parameters.""" 17 | 18 | from typing import Union 19 | 20 | import chex 21 | import jax 22 | 23 | from optax._src import base 24 | # pylint: disable=g-importing-member 25 | from optax.transforms._combining import partition 26 | from optax.transforms._masking import masked 27 | # pylint: enable=g-importing-member 28 | 29 | 30 | def freeze(mask: Union[bool, chex.ArrayTree]) -> base.GradientTransformation: 31 | """Create a transformation that zeros out gradient updates for `mask=True`. 32 | 33 | This essentially freezes (i.e., holds constant) the masked parameters. 34 | 35 | The mask must be static (i.e., not dependent on runtime values or updated 36 | during training) and can be: 37 | 38 | - a single boolean (or 0-d JAX bool array), causing all parameters to be 39 | either frozen (True) or trainable (False), or 40 | - a PyTree of booleans matching the structure of the parameters, where 41 | each leaf indicates whether that specific parameter leaf should be 42 | frozen (True) or left unchanged (False). 43 | 44 | Args: 45 | mask: A boolean prefix tree mask indicating which parameters to freeze. 46 | 47 | Example: 48 | >>> import jax.numpy as jnp 49 | >>> from optax import freeze 50 | >>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)} 51 | >>> mask = {'a': True, 'b': False} # Freeze 'a', train 'b' 52 | >>> freezer = freeze(mask) 53 | 54 | Returns: 55 | An Optax `GradientTransformation` which applies `set_to_zero()` wherever 56 | `mask==True`, and leaves other gradients intact. 57 | 58 | .. seealso:: 59 | :func:`optax.selective_transform` : For partitioning updates 60 | so only un-frozen parameters are optimized.
61 | """ 62 | return masked(base.set_to_zero(), mask) 63 | 64 | 65 | def selective_transform( 66 | optimizer: base.GradientTransformation, 67 | *, # force kw-only arguments to show this is a freeze and not allow mask 68 | freeze_mask: Union[bool, chex.ArrayTree], 69 | ) -> base.GradientTransformation: 70 | """Partition updates so that only un-frozen parameters are optimized. 71 | 72 | Example: 73 | >>> import jax.numpy as jnp 74 | >>> from optax import selective_transform 75 | >>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)} 76 | >>> mask = {'a': True, 'b': False} # Freeze 'a', train 'b' 77 | >>> selective_opt = selective_transform(optax.adam(1e-3), freeze_mask=mask) 78 | 79 | Args: 80 | optimizer: The inner Optax optimizer to apply to unfrozen leaves. 81 | freeze_mask: A *static* mask (i.e., not dependent on runtime values or 82 | updated during training). It can be either: 83 | 84 | - a scalar bool (or 0-d JAX bool array) to freeze everything (True) or 85 | nothing (False) 86 | - a PyTree of booleans mirroring the parameter tree, marking each leaf 87 | to freeze (True) or train (False). 88 | 89 | Returns: 90 | A `GradientTransformation` that routes each parameter leaf through: 91 | 92 | - the given `optimizer` if its mask is False (“train”), 93 | - `set_to_zero()` if its mask is True (“freeze”). 94 | 95 | .. seealso:: 96 | :func:`optax.freeze` : For simply zeroing out gradients 97 | according to a mask. 98 | """ 99 | 100 | def label_fn(params: base.PyTree): 101 | del params 102 | return jax.tree.map(lambda m: "freeze" if m else "train", freeze_mask) 103 | 104 | return partition( 105 | {"train": optimizer, "freeze": base.set_to_zero()}, 106 | param_labels=label_fn, 107 | ) 108 | -------------------------------------------------------------------------------- /optax/second_order/_deprecated.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Deprecated second order utilities kept for backward compatibility. 16 | """ 17 | 18 | import abc 19 | import functools 20 | from typing import Any, Protocol 21 | 22 | import jax 23 | from jax import flatten_util 24 | import jax.numpy as jnp 25 | from optax._src.deprecations import warn_deprecated_function # pylint: disable=g-importing-member 26 | 27 | 28 | def _ravel(p: Any) -> jax.Array: 29 | return flatten_util.ravel_pytree(p)[0] 30 | 31 | 32 | class LossFn(Protocol): 33 | """A loss function to be optimized.""" 34 | 35 | @abc.abstractmethod 36 | def __call__( 37 | self, params: Any, inputs: jax.Array, targets: jax.Array 38 | ) -> jax.Array: 39 | ... 
40 | 41 | 42 | @functools.partial(warn_deprecated_function, version_removed='0.2.9') 43 | def hvp( 44 | loss: LossFn, 45 | v: jax.Array, 46 | params: Any, 47 | inputs: jax.Array, 48 | targets: jax.Array, 49 | ) -> jax.Array: 50 | """Performs an efficient vector-Hessian (of `loss`) product. 51 | 52 | .. deprecated:: 0.2.7. This function will be removed in 0.2.9 53 | 54 | Args: 55 | loss: the loss function. 56 | v: a vector of size `ravel(params)`. 57 | params: model parameters. 58 | inputs: inputs at which `loss` is evaluated. 59 | targets: targets at which `loss` is evaluated. 60 | 61 | Returns: 62 | An Array corresponding to the product of `v` and the Hessian of `loss` 63 | evaluated at `(params, inputs, targets)`. 64 | """ 65 | _, unravel_fn = flatten_util.ravel_pytree(params) 66 | loss_fn = lambda p: loss(p, inputs, targets) 67 | return jax.jvp(jax.grad(loss_fn), [params], [unravel_fn(v)])[1] 68 | 69 | 70 | @functools.partial(warn_deprecated_function, version_removed='0.2.9') 71 | def hessian_diag( 72 | loss: LossFn, 73 | params: Any, 74 | inputs: jax.Array, 75 | targets: jax.Array, 76 | ) -> jax.Array: 77 | """Computes the diagonal of the Hessian of `loss` at (`inputs`, `targets`). 78 | 79 | .. deprecated:: 0.2.7. This function will be removed in 0.2.9 80 | 81 | Args: 82 | loss: the loss function. 83 | params: model parameters. 84 | inputs: inputs at which `loss` is evaluated. 85 | targets: targets at which `loss` is evaluated. 86 | 87 | Returns: 88 | An Array corresponding to the diagonal of the Hessian of `loss` 89 | evaluated at `(params, inputs, targets)`. 90 | """ 91 | vs = jnp.eye(_ravel(params).size) 92 | comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets))) 93 | return jax.vmap(comp)(vs) 94 | 95 | 96 | @functools.partial(warn_deprecated_function, version_removed='0.2.9') 97 | def fisher_diag( 98 | negative_log_likelihood: LossFn, 99 | params: Any, 100 | inputs: jax.Array, 101 | targets: jax.Array, 102 | ) -> jax.Array: 103 | """Computes the diagonal of the (observed) Fisher information matrix. 104 | 105 | .. deprecated:: 0.2.7. This function will be removed in 0.2.9 106 | 107 | Args: 108 | negative_log_likelihood: the negative log likelihood function with expected 109 | signature `loss = fn(params, inputs, targets)`. 110 | params: model parameters. 111 | inputs: inputs at which `negative_log_likelihood` is evaluated. 112 | targets: targets at which `negative_log_likelihood` is evaluated. 113 | 114 | Returns: 115 | An Array corresponding to the diagonal of the (observed) Fisher information 116 | matrix of `negative_log_likelihood` evaluated at `(params, inputs, targets)`. 117 | """ 118 | return jnp.square( 119 | _ravel(jax.grad(negative_log_likelihood)(params, inputs, targets)) 120 | ) 121 | -------------------------------------------------------------------------------- /optax/_src/update.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Apply transformed gradient updates to parameters.""" 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | from optax._src import base 20 | 21 | 22 | def apply_updates(params: base.Params, updates: base.Updates) -> base.Params: 23 | """Applies an update to the corresponding parameters. 24 | 25 | This is a utility function that applies an update to a set of parameters, and 26 | then returns the updated parameters to the caller. As an example, the update 27 | may be a gradient transformed by a sequence of `GradientTransformations`. This 28 | function is exposed for convenience, but it just adds updates and parameters; 29 | you may also apply updates to parameters manually, using `jax.tree.map` 30 | (e.g. if you want to manipulate updates in custom ways before applying them). 31 | 32 | Args: 33 | params: a tree of parameters. 34 | updates: a tree of updates; the tree structure and the shape of the leaf 35 | nodes must match that of `params`. 36 | 37 | Returns: 38 | Updated parameters, with same structure, shape and type as `params`. 39 | """ 40 | return jax.tree.map( 41 | lambda p, u: ( 42 | None if p is None else jnp.asarray(p + u).astype(jnp.asarray(p).dtype) 43 | ), 44 | params, 45 | updates, 46 | is_leaf=lambda x: x is None, 47 | ) 48 | 49 | 50 | def incremental_update( 51 | new_tensors: base.Params, old_tensors: base.Params, 52 | step_size: jax.typing.ArrayLike) -> base.Params: 53 | """Incrementally update parameters via Polyak averaging. 54 | 55 | Polyak averaging tracks an (exponential moving) average of the past 56 | parameters of a model, for use at test/evaluation time. 57 | 58 | Args: 59 | new_tensors: the latest value of the tensors. 60 | old_tensors: a moving average of the values of the tensors. 61 | step_size: the step_size used to update the Polyak average on each step. 62 | 63 | Returns: 64 | an updated moving average `step_size*new+(1-step_size)*old` of the params. 65 | 66 | References: 67 | [Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046) 68 | """ 69 | return jax.tree.map( 70 | lambda new, old: ( 71 | None if new is None else step_size * new + (1.0 - step_size) * old 72 | ), 73 | new_tensors, 74 | old_tensors, 75 | is_leaf=lambda x: x is None, 76 | ) 77 | 78 | 79 | def periodic_update( 80 | new_tensors: base.Params, 81 | old_tensors: base.Params, 82 | steps: jax.typing.ArrayLike, # int 83 | update_period: jax.typing.ArrayLike, # int 84 | ) -> base.Params: 85 | """Periodically update all parameters with new values. 86 | 87 | A slow copy of a model's parameters, updated every K actual updates, can be 88 | used to implement forms of self-supervision (in supervised learning), or to 89 | stabilize temporal difference learning updates (in reinforcement learning). 90 | 91 | Args: 92 | new_tensors: the latest value of the tensors. 93 | old_tensors: a slow copy of the model's parameters. 94 | steps: number of update steps on the "online" network. 95 | update_period: every how many steps to update the "target" network. 96 | 97 | Returns: 98 | a slow copy of the model's parameters, updated every `update_period` steps.
99 | 100 | References: 101 | [Grill et al., 2020](https://arxiv.org/abs/2006.07733) 102 | [Mnih et al., 2015](https://arxiv.org/abs/1312.5602) 103 | """ 104 | return jax.lax.cond( 105 | jnp.mod(steps, update_period) == 0, 106 | lambda _: new_tensors, 107 | lambda _: old_tensors, 108 | None, 109 | ) 110 | -------------------------------------------------------------------------------- /optax/assignment/_hungarian_algorithm_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for the Hungarian algorithm.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import jax.random as jrd 22 | from optax import assignment 23 | import scipy 24 | 25 | 26 | class HungarianAlgorithmTest(parameterized.TestCase): 27 | 28 | @parameterized.product( 29 | fn=[assignment.hungarian_algorithm, assignment.base_hungarian_algorithm], 30 | n=[0, 1, 2, 4, 8, 16], 31 | m=[0, 1, 2, 4, 8, 16], 32 | ) 33 | def test_hungarian_algorithm(self, fn, n, m): 34 | key = jrd.key(0) 35 | costs = jrd.normal(key, (n, m)) 36 | 37 | i, j = fn(costs) 38 | 39 | r = min(costs.shape) 40 | 41 | with self.subTest('i has correct shape'): 42 | assert i.shape == (r,) 43 | 44 | with self.subTest('j has correct shape'): 45 | assert j.shape == (r,) 46 | 47 | with self.subTest('i has correct dtype'): 48 | assert jnp.issubdtype(i.dtype, jnp.integer) 49 | 50 | with self.subTest('j has correct dtype'): 51 | assert jnp.issubdtype(j.dtype, jnp.integer) 52 | 53 | with self.subTest('each element of i is non-negative'): 54 | assert jnp.all(0 <= i) 55 | 56 | with self.subTest('each element of j is non-negative'): 57 | assert jnp.all(0 <= j) 58 | 59 | with self.subTest('each element of i is less than the number of rows'): 60 | assert (i < costs.shape[0]).all() 61 | 62 | with self.subTest('each element of j is less than the number of columns'): 63 | assert (j < costs.shape[1]).all() 64 | 65 | x = jnp.zeros(costs.shape[0], int).at[i].add(1) 66 | 67 | with self.subTest('all elements of i lie in the valid range'): 68 | assert x.sum() == r 69 | 70 | with self.subTest('no two elements of i are equal'): 71 | assert (x <= 1).all() 72 | 73 | y = jnp.zeros(costs.shape[1], int).at[j].add(1) 74 | 75 | with self.subTest('all elements of j lie in the valid range'): 76 | assert y.sum() == r 77 | 78 | with self.subTest('no two elements of j are equal'): 79 | assert (y <= 1).all() 80 | 81 | cost_optax = costs[i, j].sum() 82 | 83 | i_scipy, j_scipy = scipy.optimize.linear_sum_assignment(costs) 84 | cost_scipy = costs[i_scipy, j_scipy].sum() 85 | 86 | with self.subTest('cost matches that obtained by scipy'): 87 | assert jnp.isclose(cost_optax, cost_scipy) 88 | 89 | @parameterized.product( 90 | 
fn=[assignment.hungarian_algorithm, assignment.base_hungarian_algorithm], 91 | k=[0, 1, 2, 4], 92 | n=[0, 1, 2, 4], 93 | m=[0, 1, 2, 4], 94 | ) 95 | def test_hungarian_algorithm_vmap(self, fn, k, n, m): 96 | key = jrd.key(0) 97 | costs = jrd.normal(key, (k, n, m)) 98 | 99 | with self.subTest('works under vmap'): 100 | i, j = jax.vmap(fn)(costs) 101 | 102 | r = min(costs.shape[1:]) 103 | 104 | with self.subTest('batch i has correct shape'): 105 | assert i.shape == (k, r) 106 | 107 | with self.subTest('batch j has correct shape'): 108 | assert j.shape == (k, r) 109 | 110 | @parameterized.product( 111 | fn=[assignment.hungarian_algorithm, assignment.base_hungarian_algorithm], 112 | ) 113 | def test_hungarian_algorithm_jit(self, fn): 114 | key = jrd.key(0) 115 | costs = jrd.normal(key, (20, 30)) 116 | 117 | with self.subTest('works under jit'): 118 | i, j = jax.jit(fn)(costs) 119 | 120 | r = min(costs.shape) 121 | 122 | with self.subTest('i has correct shape'): 123 | assert i.shape == (r,) 124 | 125 | with self.subTest('j has correct shape'): 126 | assert j.shape == (r,) 127 | 128 | 129 | if __name__ == '__main__': 130 | absltest.main() 131 | -------------------------------------------------------------------------------- /optax/tree_utils/_random_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for methods defined in optax.tree_utils._random.""" 16 | 17 | from collections.abc import Callable 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import chex 22 | import jax 23 | import jax.numpy as jnp 24 | import jax.random as jrd 25 | import numpy as np 26 | from optax import tree_utils as otu 27 | from optax._src import base 28 | from optax._src import test_utils 29 | 30 | # We consider samplers with varying input dtypes, we do not test all possible 31 | # samplers from `jax.random`. 
32 | _SAMPLER_DTYPES = ( 33 | {'sampler': jrd.normal, 'dtype': None}, 34 | {'sampler': jrd.normal, 'dtype': 'bfloat16'}, 35 | {'sampler': jrd.normal, 'dtype': 'float32'}, 36 | {'sampler': jrd.rademacher, 'dtype': 'int32'}, 37 | {'sampler': jrd.bits, 'dtype': 'uint32'}, 38 | ) 39 | 40 | 41 | def get_variable(type_var: str): 42 | """Get a variable of various shape.""" 43 | if type_var == 'real_array': 44 | return jnp.asarray([1.0, 2.0]) 45 | if type_var == 'complex_array': 46 | return jnp.asarray([1.0 + 1j * 2.0, 3.0 + 4j * 5.0]) 47 | if type_var == 'pytree': 48 | pytree = {'k1': 1.0, 'k2': (2.0, 3.0), 'k3': jnp.asarray([4.0, 5.0])} 49 | return jax.tree.map(jnp.asarray, pytree) 50 | raise ValueError(f'Invalid type_var {type_var}') 51 | 52 | 53 | class RandomTest(parameterized.TestCase): 54 | 55 | def test_tree_split_key_like(self): 56 | rng_key = jrd.key(0) 57 | tree = {'a': jnp.zeros(2), 'b': {'c': [jnp.ones(3), jnp.zeros([4, 5])]}} 58 | keys_tree = otu.tree_split_key_like(rng_key, tree) 59 | 60 | with self.subTest('Test structure matches'): 61 | self.assertEqual(jax.tree.structure(tree), jax.tree.structure(keys_tree)) 62 | 63 | with self.subTest('Test random key split'): 64 | fst = jnp.stack(jax.tree.flatten(keys_tree)[0]) 65 | snd = jrd.split(rng_key, jax.tree.structure(tree).num_leaves) 66 | np.testing.assert_array_equal(otu.tree_unwrap_random_key_data(fst), 67 | otu.tree_unwrap_random_key_data(snd)) 68 | 69 | @parameterized.product( 70 | _SAMPLER_DTYPES, 71 | type_var=['real_array', 'complex_array', 'pytree'], 72 | ) 73 | def test_tree_random_like( 74 | self, 75 | sampler: Callable[ 76 | [base.PRNGKey, base.Shape, jax.typing.DTypeLike], jax.Array 77 | ], 78 | dtype: str, 79 | type_var: str, 80 | ): 81 | """Test that tree_random_like matches its flat counterpart.""" 82 | if dtype is not None: 83 | dtype = jnp.dtype(dtype) 84 | rng_key = jrd.key(0) 85 | target_tree = get_variable(type_var) 86 | 87 | rand_tree = otu.tree_random_like( 88 | rng_key, target_tree, sampler=sampler, dtype=dtype 89 | ) 90 | 91 | flat_tree, tree_def = jax.tree.flatten(target_tree) 92 | 93 | with self.subTest('Test structure matches'): 94 | self.assertEqual(tree_def, jax.tree.structure(rand_tree)) 95 | 96 | with self.subTest('Test tree_random_like matches flat random like'): 97 | flat_rand_tree, _ = jax.tree.flatten(rand_tree) 98 | keys = jrd.split(rng_key, tree_def.num_leaves) 99 | expected_flat_rand_tree = [ 100 | sampler(key, x.shape, dtype or x.dtype) 101 | for key, x in zip(keys, flat_tree) 102 | ] 103 | test_utils.assert_trees_all_close(flat_rand_tree, expected_flat_rand_tree) 104 | 105 | with self.subTest('Test dtype are as expected'): 106 | if dtype is not None: 107 | for x in jax.tree.leaves(rand_tree): 108 | self.assertEqual(x.dtype, dtype) 109 | else: 110 | chex.assert_trees_all_equal_dtypes(rand_tree, target_tree) 111 | 112 | 113 | if __name__ == '__main__': 114 | absltest.main() 115 | -------------------------------------------------------------------------------- /optax/_src/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities for optax tests.""" 16 | 17 | import contextlib 18 | import logging 19 | import re 20 | import threading 21 | 22 | import jax 23 | import numpy as np 24 | 25 | _LOG_LIST = [] 26 | 27 | 28 | class _LogsToListHandler(logging.Handler): 29 | """A handler for capturing logs programmatically without printing them.""" 30 | 31 | def emit(self, record): 32 | _LOG_LIST.append(record) 33 | 34 | 35 | logger = logging.getLogger("jax") 36 | logger.addHandler(_LogsToListHandler()) 37 | 38 | # We need a lock to be able to run this context manager from multiple threads. 39 | _compilation_log_lock = threading.Lock() 40 | 41 | 42 | @contextlib.contextmanager 43 | def log_compilations(): 44 | """A utility for programmatically capturing JAX compilation logs.""" 45 | with _compilation_log_lock, jax.log_compiles(): 46 | _LOG_LIST.clear() 47 | compilation_logs = [] 48 | yield compilation_logs # these will contain the compilation logs 49 | compilation_logs.extend([ 50 | log for log in _LOG_LIST 51 | if re.search(r"Finished .* compilation", log.getMessage()) 52 | ]) 53 | 54 | 55 | def assert_trees_all_close(actual, desired, rtol=1e-6, atol=0.0, err_msg=None): 56 | """Asserts that two pytrees of arrays are close within a tolerance.""" 57 | flat_a, tree_def_a = jax.tree_util.tree_flatten(actual) 58 | flat_d, tree_def_d = jax.tree_util.tree_flatten(desired) 59 | if tree_def_a != tree_def_d: 60 | raise AssertionError( 61 | f"Trees have different structures:\n{tree_def_a}\n{tree_def_d}" 62 | ) 63 | for x, y in zip(flat_a, flat_d): 64 | np.testing.assert_allclose(x, y, rtol=rtol, atol=atol, err_msg=err_msg) 65 | 66 | 67 | def assert_trees_all_equal(actual, desired, err_msg=None): 68 | """Asserts that two pytrees of arrays are equal.""" 69 | flat_a, tree_def_a = jax.tree_util.tree_flatten(actual) 70 | flat_d, tree_def_d = jax.tree_util.tree_flatten(desired) 71 | if tree_def_a != tree_def_d: 72 | raise AssertionError( 73 | f"Trees have different structures:\n{tree_def_a}\n{tree_def_d}" 74 | ) 75 | for x, y in zip(flat_a, flat_d): 76 | np.testing.assert_array_equal(x, y, err_msg=err_msg) 77 | 78 | 79 | def assert_trees_all_equal_structs(actual, desired): 80 | """Asserts that two pytrees have the same structure.""" 81 | if (jax.tree_util.tree_structure(actual) != 82 | jax.tree_util.tree_structure(desired)): 83 | raise AssertionError( 84 | f"Trees have different structures:\n{actual}\n{desired}" 85 | ) 86 | 87 | 88 | def assert_tree_all_finite(actual, err_msg=None): 89 | """Asserts that all arrays in a pytree are finite.""" 90 | for x in jax.tree_util.tree_leaves(actual): 91 | if not np.all(np.isfinite(x)): 92 | raise AssertionError(f"Array {x} is not finite. 
{err_msg}") 93 | 94 | 95 | def assert_trees_all_equal_shapes(actual, desired, err_msg=None): 96 | """Asserts that two pytrees of arrays have the same shapes.""" 97 | assert_trees_all_equal_structs(actual, desired) 98 | for x, y in zip(jax.tree_util.tree_leaves(actual), 99 | jax.tree_util.tree_leaves(desired)): 100 | if x.shape != y.shape: 101 | raise AssertionError( 102 | f"Shapes are not equal: {x.shape} != {y.shape}. {err_msg}" 103 | ) 104 | 105 | 106 | def assert_trees_all_equal_dtypes(actual, desired, err_msg=None): 107 | """Asserts that two pytrees of arrays have the same dtypes.""" 108 | assert_trees_all_equal_structs(actual, desired) 109 | for x, y in zip(jax.tree_util.tree_leaves(actual), 110 | jax.tree_util.tree_leaves(desired)): 111 | if x.dtype != y.dtype: 112 | raise AssertionError( 113 | f"Dtypes are not equal: {x.dtype} != {y.dtype}. {err_msg}" 114 | ) 115 | -------------------------------------------------------------------------------- /optax/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The transforms sub-package.""" 16 | 17 | # pylint: disable=g-importing-member 18 | 19 | from optax.transforms._accumulation import ema 20 | from optax.transforms._accumulation import EmaState 21 | from optax.transforms._accumulation import MultiSteps 22 | from optax.transforms._accumulation import MultiStepsState 23 | from optax.transforms._accumulation import ShouldSkipUpdateFunction 24 | from optax.transforms._accumulation import skip_large_updates 25 | from optax.transforms._accumulation import skip_not_finite 26 | from optax.transforms._accumulation import trace 27 | from optax.transforms._accumulation import TraceState 28 | from optax.transforms._adding import add_decayed_weights 29 | from optax.transforms._adding import add_noise 30 | from optax.transforms._adding import AddNoiseState 31 | from optax.transforms._clipping import adaptive_grad_clip 32 | from optax.transforms._clipping import clip 33 | from optax.transforms._clipping import clip_by_block_rms 34 | from optax.transforms._clipping import clip_by_global_norm 35 | from optax.transforms._clipping import per_example_global_norm_clip 36 | from optax.transforms._clipping import per_example_layer_norm_clip 37 | from optax.transforms._clipping import unitwise_clip 38 | from optax.transforms._clipping import unitwise_norm 39 | from optax.transforms._combining import chain 40 | from optax.transforms._combining import named_chain 41 | from optax.transforms._combining import partition 42 | from optax.transforms._combining import PartitionState 43 | from optax.transforms._conditionality import apply_if_finite 44 | from optax.transforms._conditionality import ApplyIfFiniteState 45 | from optax.transforms._conditionality import 
conditionally_mask 46 | from optax.transforms._conditionality import conditionally_transform 47 | from optax.transforms._conditionality import ConditionallyMaskState 48 | from optax.transforms._conditionality import ConditionallyTransformState 49 | from optax.transforms._conditionality import ConditionFn 50 | from optax.transforms._constraining import keep_params_nonnegative 51 | from optax.transforms._constraining import NonNegativeParamsState 52 | from optax.transforms._constraining import zero_nans 53 | from optax.transforms._constraining import ZeroNansState 54 | from optax.transforms._freezing import freeze 55 | from optax.transforms._freezing import selective_transform 56 | from optax.transforms._layouts import flatten 57 | from optax.transforms._masking import masked 58 | from optax.transforms._masking import MaskedNode 59 | from optax.transforms._masking import MaskedState 60 | from optax.transforms._monitoring import measure_with_ema 61 | from optax.transforms._monitoring import monitor 62 | from optax.transforms._monitoring import MonitorState 63 | from optax.transforms._monitoring import snapshot 64 | from optax.transforms._monitoring import SnapshotState 65 | 66 | __all__ = ( 67 | "adaptive_grad_clip", 68 | "add_decayed_weights", 69 | "add_noise", 70 | "AddNoiseState", 71 | "apply_if_finite", 72 | "ApplyIfFiniteState", 73 | "chain", 74 | "clip_by_block_rms", 75 | "clip_by_global_norm", 76 | "clip", 77 | "conditionally_mask", 78 | "ConditionallyMaskState", 79 | "conditionally_transform", 80 | "ConditionallyTransformState", 81 | "ema", 82 | "EmaState", 83 | "flatten", 84 | "freeze", 85 | "keep_params_nonnegative", 86 | "masked", 87 | "MaskedState", 88 | "measure_with_ema", 89 | "monitor", 90 | "MonitorState", 91 | "MultiSteps", 92 | "MultiStepsState", 93 | "named_chain", 94 | "NonNegativeParamsState", 95 | "partition", 96 | "PartitionState", 97 | "selective_transform", 98 | "ShouldSkipUpdateFunction", 99 | "skip_large_updates", 100 | "skip_not_finite", 101 | "snapshot", 102 | "SnapshotState", 103 | "trace", 104 | "TraceState", 105 | "zero_nans", 106 | "ZeroNansState", 107 | ) 108 | -------------------------------------------------------------------------------- /optax/losses/_self_supervised_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for self-supervised losses in `optax.losses._self_supervised.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | 23 | from optax.losses import _self_supervised 24 | 25 | 26 | class NtxentTest(absltest.TestCase): 27 | 28 | def setUp(self): 29 | super().setUp() 30 | self.ys = jnp.array([ 31 | [-1.9540, 1.0780], 32 | [0.2380, -0.5703], 33 | [1.8745, -0.0195], 34 | [-0.6719, -1.9210], 35 | ]) 36 | self.ys_2 = jnp.array([ 37 | [0.0, 0.0], 38 | [0.2380, -0.5703], 39 | [1.8745, -0.0195], 40 | [-0.6719, -1.9210], 41 | ]) 42 | self.ts_1 = jnp.array([0, 0, 1, 1]) 43 | self.ts_2 = jnp.array([0, 0, 0, 1]) 44 | # Calculated expected output 45 | self.exp_1 = jnp.array(14.01032) 46 | self.exp_2 = jnp.array(8.968544) 47 | self.exp_3 = jnp.array(9.2889) 48 | 49 | def test_batched(self): 50 | np.testing.assert_allclose( 51 | jax.jit(_self_supervised.ntxent)(self.ys, self.ts_1), 52 | self.exp_1, 53 | atol=1e-4, 54 | ) 55 | 56 | np.testing.assert_allclose( 57 | jax.jit(_self_supervised.ntxent)(self.ys, self.ts_2), 58 | self.exp_2, 59 | atol=1e-4, 60 | ) 61 | 62 | np.testing.assert_allclose( 63 | jax.jit(_self_supervised.ntxent)(self.ys_2, self.ts_1), 64 | self.exp_3, 65 | atol=1e-4, 66 | ) 67 | 68 | 69 | class TripletMarginLossTest(parameterized.TestCase): 70 | 71 | def setUp(self): 72 | super().setUp() 73 | self.a1 = jnp.ones((2, 2)) 74 | self.p1 = jnp.zeros((2, 2)) 75 | self.n1 = jnp.ones((2, 2)) * 2 76 | self.a2 = jnp.zeros((2, 2)) 77 | self.p2 = jnp.ones((2, 2)) 78 | self.n2 = jnp.ones((2, 2)) * 2 79 | 80 | @parameterized.parameters([ 81 | { 82 | 'anchor': np.ones((2, 2)), 83 | 'positive': np.zeros((2, 2)), 84 | 'negative': np.ones((2, 2)) * 2, 85 | 'margin': 1.0, 86 | }, 87 | { 88 | 'anchor': np.zeros((2, 2)), 89 | 'positive': np.ones((2, 2)), 90 | 'negative': np.ones((2, 2)) * 2, 91 | 'margin': 1.0, 92 | } 93 | ]) 94 | def test_batched(self, anchor, positive, negative, margin): 95 | def testing_triplet_margin_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6): 96 | ap_distance = jnp.sqrt(jnp.sum(jnp.power(a - p, p_norm)) + eps) 97 | an_distance = jnp.sqrt(jnp.sum(jnp.power(a - n, p_norm)) + eps) 98 | return jnp.maximum(ap_distance - an_distance + margin, 0) 99 | 100 | handmade_result = testing_triplet_margin_loss( 101 | a=anchor, p=positive, n=negative, margin=margin 102 | ) 103 | result = jax.jit(_self_supervised.triplet_margin_loss)( 104 | anchor, positive, negative 105 | ) 106 | np.testing.assert_allclose(result, handmade_result, atol=1e-4) 107 | 108 | @parameterized.parameters([ 109 | { 110 | 'anchor': np.ones((2, 2)), 111 | 'positive': np.zeros((2, 2)), 112 | 'negative': np.ones((2, 2)) * 2, 113 | }, 114 | ]) 115 | def test_vmap(self, anchor, positive, negative): 116 | original_loss = _self_supervised.triplet_margin_loss(anchor, positive, 117 | negative) 118 | anchor_batched = anchor.reshape(1, *anchor.shape) 119 | positive_batched = positive.reshape(1, *positive.shape) 120 | negative_batched = negative.reshape(1, *negative.shape) 121 | vmap_loss = jax.jit( 122 | jax.vmap(_self_supervised.triplet_margin_loss, in_axes=(0, 0, 0)))( 123 | anchor_batched, positive_batched, negative_batched) 124 | np.testing.assert_allclose(vmap_loss.flatten(), original_loss.flatten() 125 | , atol=1e-4) 126 | 127 | 128 | if __name__ == '__main__': 129 | absltest.main() 130 | 
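The tests above also double as compact usage documentation for these losses. As a minimal sketch based only on the call signatures exercised in the tests (it reuses the same private import path as the tests; the concrete numbers are illustrative):

import jax.numpy as jnp
from optax.losses import _self_supervised

# NT-Xent: embeddings of shape (batch, dim) plus integer labels -> scalar loss.
embeddings = jnp.array([[-1.95, 1.08], [0.24, -0.57], [1.87, -0.02], [-0.67, -1.92]])
labels = jnp.array([0, 0, 1, 1])
contrastive_loss = _self_supervised.ntxent(embeddings, labels)

# Triplet margin loss: anchor, positive and negative arrays of matching shape,
# using the default margin as in the tests.
anchor, positive, negative = jnp.ones((2, 2)), jnp.zeros((2, 2)), 2.0 * jnp.ones((2, 2))
triplet_loss = _self_supervised.triplet_margin_loss(anchor, positive, negative)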
-------------------------------------------------------------------------------- /optax/contrib/_privacy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for methods in `privacy.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | from optax._src import test_utils 22 | from optax.contrib import _privacy 23 | 24 | 25 | class DifferentiallyPrivateAggregateTest(parameterized.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | self.batch_size = 8 30 | self.params = { 31 | 'key_a': (jnp.zeros((2, 3, 4)), jnp.zeros([])), 32 | 'key_b': jnp.zeros((6, 7)), 33 | } 34 | # Example `i`'s grads are full of `i`s. Important to include 0 to ensure 35 | # there are no divisions by 0 (e.g. in norm clipping) 36 | a = jnp.arange(self.batch_size) 37 | self.per_eg_grads = jax.tree.map( 38 | lambda p: jnp.moveaxis( 39 | a * jnp.ones(p.shape + (self.batch_size,)), -1, 0 40 | ), 41 | self.params, 42 | ) 43 | 44 | def test_no_privacy(self): 45 | """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD.""" 46 | dp_agg = _privacy.differentially_private_aggregate( 47 | l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, key=0 48 | ) 49 | state = dp_agg.init(self.params) 50 | update_fn = jax.jit(dp_agg.update) 51 | mean_grads = jax.tree.map(lambda g: g.mean(0), self.per_eg_grads) 52 | 53 | for _ in range(3): 54 | updates, state = update_fn(self.per_eg_grads, state) 55 | test_utils.assert_trees_all_close(updates, mean_grads) 56 | 57 | @parameterized.parameters(0.5, 10.0, 20.0, 40.0, 80.0) 58 | def test_clipping_norm(self, l2_norm_clip): 59 | dp_agg = _privacy.differentially_private_aggregate( 60 | l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, key=42 61 | ) 62 | state = dp_agg.init(self.params) 63 | update_fn = jax.jit(dp_agg.update) 64 | 65 | # Shape of the three arrays below is (self.batch_size, ) 66 | norms = [ 67 | jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1) 68 | for g in jax.tree.leaves(self.per_eg_grads) 69 | ] 70 | global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0) 71 | divisors = jnp.maximum(global_norms / l2_norm_clip, 1.0) 72 | # Since the values of all the parameters are the same within each example, 73 | # we can easily compute what the values should be: 74 | expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors) 75 | expected_tree = jax.tree.map( 76 | lambda p: jnp.broadcast_to(expected_val, p.shape), self.params 77 | ) 78 | 79 | for _ in range(3): 80 | updates, state = update_fn(self.per_eg_grads, state, self.params) 81 | test_utils.assert_trees_all_close(updates, expected_tree, rtol=2e-7) 82 | 83 | @parameterized.parameters((3.0, 2.0), (1.0, 5.0), (100.0, 
4.0), (1.0, 90.0)) 84 | def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): 85 | """Standard dev. of noise should be l2_norm_clip * noise_multiplier.""" 86 | dp_agg = _privacy.differentially_private_aggregate( 87 | l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, key=1337 88 | ) 89 | state = dp_agg.init(self.params) 90 | update_fn = jax.jit(dp_agg.update) 91 | expected_std = l2_norm_clip * noise_multiplier 92 | 93 | grads = [jnp.ones((1, 100, 100))] # batch size 1 94 | for _ in range(3): 95 | updates, state = update_fn(grads, state) 96 | test_utils.assert_trees_all_close( 97 | expected_std, jnp.std(updates[0]), atol=0.1 * expected_std 98 | ) 99 | 100 | def test_aggregated_updates_as_input_fails(self): 101 | """Expect per-example gradients as input to this transform.""" 102 | dp_agg = _privacy.differentially_private_aggregate( 103 | l2_norm_clip=0.1, noise_multiplier=1.1, key=2021 104 | ) 105 | state = dp_agg.init(self.params) 106 | mean_grads = jax.tree.map(lambda g: g.mean(0), self.per_eg_grads) 107 | with self.assertRaises(ValueError): 108 | dp_agg.update(mean_grads, state, self.params) 109 | 110 | 111 | if __name__ == '__main__': 112 | absltest.main() 113 | -------------------------------------------------------------------------------- /optax/contrib/_complex_valued.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Complex-valued optimization. 16 | 17 | When using `split_real_and_imaginary` to wrap an optimizer, we split the complex 18 | parameters and updates into pairs of real ones before sending them to the 19 | `update` of the wrapped optimizer, and merge the pairs of transformed real 20 | updates into complex ones afterward. In this way, optimizers on complex 21 | parameters behave the same way as if they were running on two real parameters. 22 | 23 | Note that the convention of conjugate for complex gradients in JAX is different 24 | from that in PyTorch and other frameworks, and we need to manually conjugate the 25 | gradients between `jax.grad` and `optimizer.update`. 26 | 27 | See details at https://github.com/deepmind/optax/issues/196 28 | """ 29 | 30 | from typing import NamedTuple, Union 31 | 32 | import jax 33 | import jax.numpy as jnp 34 | from optax._src import base 35 | 36 | 37 | # NOTE(dsuo): Opt out of using the new `jax.pmap` implementation. There is 38 | # a C++ failure in windowing that needs to be resolved. 
39 | jax.config.update('jax_pmap_shmap_merge', False) 40 | 41 | 42 | class SplitRealAndImaginaryArrays(NamedTuple): 43 | """A pair of real arrays split from a complex array.""" 44 | 45 | real: jax.typing.ArrayLike 46 | imaginary: jax.typing.ArrayLike 47 | 48 | 49 | def _complex_to_real_pair( 50 | x: jax.typing.ArrayLike, 51 | ) -> Union[jax.typing.ArrayLike, SplitRealAndImaginaryArrays]: 52 | """Splits a complex array into a `SplitRealAndImaginaryArrays`. 53 | 54 | Args: 55 | x: The input array, can be complex or real. 56 | 57 | Returns: 58 | `SplitRealAndImaginaryArrays` if the input is a complex array. If the 59 | input is a real array, it is passed through unmodified. 60 | """ 61 | if jnp.iscomplexobj(x): 62 | return SplitRealAndImaginaryArrays(x.real, x.imag) 63 | else: 64 | return x 65 | 66 | 67 | def _real_pair_to_complex( 68 | x: Union[jax.typing.ArrayLike, SplitRealAndImaginaryArrays], 69 | ) -> jax.typing.ArrayLike: 70 | """Merges a `SplitRealAndImaginaryArrays` into a complex array. 71 | 72 | Args: 73 | x: The input `SplitRealAndImaginaryArrays` or array. 74 | 75 | Returns: 76 | A complex array obtained from the real and imaginary parts of the 77 | `SplitRealAndImaginaryArrays`. If the input is not a 78 | `SplitRealAndImaginaryArrays`, it is passed through unmodified. 79 | """ 80 | if isinstance(x, SplitRealAndImaginaryArrays): 81 | return x.real + x.imaginary * 1j 82 | else: 83 | return x 84 | 85 | 86 | class SplitRealAndImaginaryState(NamedTuple): 87 | """Maintains the inner transformation state for `split_real_and_imaginary`.""" 88 | 89 | inner_state: base.OptState 90 | 91 | 92 | def split_real_and_imaginary( 93 | inner: base.GradientTransformation, 94 | ) -> base.GradientTransformation: 95 | """Splits the real and imaginary components of complex updates into two. 96 | 97 | The inner transformation processes real parameters and updates, and the 98 | pairs of transformed real updates are merged into complex updates. 99 | 100 | Parameters and updates that are real before splitting are passed through 101 | unmodified. 102 | 103 | Args: 104 | inner: The inner transformation. 105 | 106 | Returns: 107 | An `optax.GradientTransformation`. 108 | """ 109 | 110 | def init_fn(params): 111 | params = jax.tree.map(_complex_to_real_pair, params) 112 | inner_state = inner.init(params) 113 | return SplitRealAndImaginaryState(inner_state) 114 | 115 | def update_fn(updates, state, params=None): 116 | inner_state = state.inner_state 117 | updates = jax.tree.map(_complex_to_real_pair, updates) 118 | params = jax.tree.map(_complex_to_real_pair, params) 119 | updates, inner_state = inner.update(updates, inner_state, params) 120 | updates = jax.tree.map( 121 | _real_pair_to_complex, 122 | updates, 123 | is_leaf=lambda x: isinstance(x, SplitRealAndImaginaryArrays), 124 | ) 125 | return updates, SplitRealAndImaginaryState(inner_state) 126 | 127 | return base.GradientTransformation(init_fn, update_fn) 128 | -------------------------------------------------------------------------------- /optax/tree_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The tree_utils sub-package.""" 16 | 17 | import functools 18 | import typing 19 | 20 | # pylint: disable=g-importing-member 21 | 22 | from optax.tree_utils._casting import tree_cast 23 | from optax.tree_utils._casting import tree_cast_like 24 | from optax.tree_utils._casting import tree_dtype 25 | from optax.tree_utils._random import tree_random_like 26 | from optax.tree_utils._random import tree_split_key_like 27 | from optax.tree_utils._random import tree_unwrap_random_key_data 28 | from optax.tree_utils._state_utils import NamedTupleKey 29 | from optax.tree_utils._state_utils import tree_get 30 | from optax.tree_utils._state_utils import tree_get_all_with_path 31 | from optax.tree_utils._state_utils import tree_map_params 32 | from optax.tree_utils._state_utils import tree_set 33 | from optax.tree_utils._tree_math import tree_add 34 | from optax.tree_utils._tree_math import tree_add_scale 35 | from optax.tree_utils._tree_math import tree_allclose 36 | from optax.tree_utils._tree_math import tree_batch_shape 37 | from optax.tree_utils._tree_math import tree_bias_correction 38 | from optax.tree_utils._tree_math import tree_clip 39 | from optax.tree_utils._tree_math import tree_conj 40 | from optax.tree_utils._tree_math import tree_div 41 | from optax.tree_utils._tree_math import tree_full_like 42 | from optax.tree_utils._tree_math import tree_max 43 | from optax.tree_utils._tree_math import tree_min 44 | from optax.tree_utils._tree_math import tree_mul 45 | from optax.tree_utils._tree_math import tree_norm 46 | from optax.tree_utils._tree_math import tree_ones_like 47 | from optax.tree_utils._tree_math import tree_real 48 | from optax.tree_utils._tree_math import tree_scale 49 | from optax.tree_utils._tree_math import tree_size 50 | from optax.tree_utils._tree_math import tree_sub 51 | from optax.tree_utils._tree_math import tree_sum 52 | from optax.tree_utils._tree_math import tree_update_infinity_moment 53 | from optax.tree_utils._tree_math import tree_update_moment 54 | from optax.tree_utils._tree_math import tree_update_moment_per_elem_norm 55 | from optax.tree_utils._tree_math import tree_vdot 56 | from optax.tree_utils._tree_math import tree_where 57 | from optax.tree_utils._tree_math import tree_zeros_like 58 | 59 | _deprecations = { 60 | # Added Mar 2025 61 | 'tree_scalar_mul': ( 62 | ('optax.tree_utils.tree_scalar_mul is deprecated: use' 63 | ' optax.tree_utils.tree_scale (optax v0.2.5 or newer).'), 64 | tree_scale, 65 | ), 66 | 'tree_add_scalar_mul': ( 67 | ('optax.tree_utils.tree_add_scalar_mul is deprecated: use' 68 | ' optax.tree_utils.tree_add_scale (optax v0.2.5 or newer).'), 69 | tree_add_scale, 70 | ), 71 | # Added May 2025 72 | 'tree_l1_norm': ( 73 | ('optax.tree_utils.tree_l1_norm is deprecated: use' 74 | ' optax.tree_utils.tree_norm(..., ord=1) (optax v0.2.5 or newer).'), 75 | functools.partial(tree_norm, ord=1), 76 | ), 77 | 'tree_l2_norm': ( 78 | ('optax.tree_utils.tree_l2_norm is deprecated: use' 79 | ' optax.tree_utils.tree_norm (optax v0.2.5 
or newer).'), 80 | functools.partial(tree_norm, ord=2), 81 | ), 82 | 'tree_linf_norm': ( 83 | ('optax.tree_utils.tree_linf_norm is deprecated: use' 84 | ' optax.tree_utils.tree_norm(..., ord=jnp.inf)' 85 | ' (optax v0.2.5 or newer).'), 86 | functools.partial(tree_norm, ord='inf'), 87 | ), 88 | } 89 | 90 | # pylint: disable=g-import-not-at-top 91 | # pylint: disable=g-bad-import-order 92 | if typing.TYPE_CHECKING: 93 | tree_scalar_mul = tree_scale 94 | tree_add_scalar_mul = tree_add_scale 95 | tree_l1_norm = functools.partial(tree_norm, ord=1) 96 | tree_l2_norm = tree_norm 97 | tree_linf_norm = functools.partial(tree_norm, ord='inf') 98 | 99 | else: 100 | # pylint: disable=line-too-long 101 | from optax._src.deprecations import deprecation_getattr as _deprecation_getattr # noqa: E501 102 | # pylint: enable=line-too-long 103 | 104 | __getattr__ = _deprecation_getattr(__name__, _deprecations) 105 | del _deprecation_getattr 106 | # pylint: enable=g-bad-import-order 107 | # pylint: enable=g-import-not-at-top 108 | -------------------------------------------------------------------------------- /optax/tree_utils/_casting_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for tree utilities on data types.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | from optax import tree_utils as otu 23 | 24 | 25 | class CastingTest(parameterized.TestCase): 26 | 27 | @parameterized.parameters([ 28 | (jnp.float32, [1.3, 2.001, 3.6], [-3.3], [1.3, 2.001, 3.6], [-3.3]), 29 | (jnp.float32, [1.3, 2.001, 3.6], [-3], [1.3, 2.001, 3.6], [-3.0]), 30 | (jnp.int32, [1.3, 2.001, 3.6], [-3.3], [1, 2, 3], [-3]), 31 | (jnp.int32, [1.3, 2.001, 3.6], [-3], [1, 2, 3], [-3]), 32 | (None, [1.123, 2.33], [0.0], [1.123, 2.33], [0.0]), 33 | (None, [1, 2, 3], [0.0], [1, 2, 3], [0.0]), 34 | ]) 35 | def test_tree_cast(self, dtype, b, c, new_b, new_c): 36 | def _build_tree(val1, val2): 37 | dict_tree = {'a': {'b': jnp.array(val1)}, 'c': jnp.array(val2)} 38 | return jax.tree.map(lambda x: x, dict_tree) 39 | 40 | tree = _build_tree(b, c) 41 | tree = otu.tree_cast(tree, dtype=dtype) 42 | jax.tree.map(np.testing.assert_array_equal, tree, _build_tree(new_b, new_c)) 43 | 44 | @parameterized.parameters([ 45 | (jnp.float16, [1.3, 2.001, 3.6], [-3]), 46 | (jnp.bfloat16, [1.3, 2.001, 3.6], [-3.3]), 47 | (jnp.float32, [1.3, 2.001, 3.6], [-3.3]), 48 | (jnp.int32, [1.3, 2.001, 3.6], [-3]), 49 | ]) 50 | def test_tree_cast_like(self, dtype, b, c): 51 | def _build_tree(val1, val2): 52 | dict_tree = {'a': {'b': jnp.array(val1)}, 'c': jnp.array(val2)} 53 | return jax.tree.map(lambda x: x, dict_tree) 54 | 55 | tree = _build_tree(b, c) 56 | target_tree = _build_tree(b, c) 57 | target_tree = jax.tree.map(lambda x: x.astype(dtype), target_tree) 58 | tree = otu.tree_cast_like(tree, target_tree) 59 | jax.tree.map(np.testing.assert_array_equal, tree, target_tree) 60 | 61 | def test_tree_dtype(self): 62 | """Test fetching the data type of a tree.""" 63 | 64 | with self.subTest('Check that it returns the right dtype'): 65 | tree = { 66 | 'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, 67 | 'c': jnp.array(2.0, dtype=jnp.float32), 68 | } 69 | dtype = otu.tree_dtype(tree) 70 | self.assertEqual(dtype, jnp.float32) 71 | 72 | with self.subTest('Check that it raises an error if dtypes differ'): 73 | tree = { 74 | 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, 75 | 'c': jnp.array(2.0, dtype=jnp.float32), 76 | } 77 | self.assertRaises(ValueError, otu.tree_dtype, tree) 78 | 79 | tree = { 80 | 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, 81 | 'c': jnp.array(2.0, dtype=jnp.float32), 82 | } 83 | 84 | with self.subTest('Check that it works with lowest common dtype'): 85 | dtype = otu.tree_dtype(tree, 'lowest') 86 | self.assertEqual(dtype, jnp.bfloat16) 87 | 88 | with self.subTest('Check that it works with highest common dtype'): 89 | dtype = otu.tree_dtype(tree, 'highest') 90 | self.assertEqual(dtype, jnp.float32) 91 | 92 | tree = { 93 | 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, 94 | 'c': jnp.array(2.0, dtype=jnp.float16), 95 | } 96 | 97 | with self.subTest('Check that it works when promoting mixed dtype'): 98 | dtype = otu.tree_dtype(tree, 'promote') 99 | self.assertEqual(dtype, jnp.float32) 100 | 101 | with self.subTest( 102 | 'Check that it raises an error if dtypes cannot be promoted to one' 103 | ' another' 104 | ): 105 | self.assertRaises(ValueError, otu.tree_dtype, tree, 'lowest') 106 | self.assertRaises(ValueError, otu.tree_dtype, tree, 'highest') 107 | 108 | @parameterized.named_parameters( 109 | 
{'testcase_name': 'empty_dict', 'tree': {}}, 110 | {'testcase_name': 'empty_list', 'tree': []}, 111 | {'testcase_name': 'empty_tuple', 'tree': ()}, 112 | {'testcase_name': 'empty_none', 'tree': None}, 113 | ) 114 | def test_tree_dtype_utilities_with_empty_trees(self, tree): 115 | """Test tree data type utilities on empty trees.""" 116 | default_dtype = jnp.asarray(1.0).dtype 117 | 118 | with self.subTest('Check tree_dtype works with empty trees.'): 119 | dtype = otu.tree_dtype(tree) 120 | self.assertEqual(dtype, default_dtype) 121 | 122 | 123 | if __name__ == '__main__': 124 | absltest.main() 125 | -------------------------------------------------------------------------------- /optax/contrib/_cocob.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Backpropagating variant of the COntinuous COin Betting stochastic algorithm. 16 | 17 | COCOB is a contributed optimizer implemented from Algorithm 2 of "Training Deep 18 | Networks without Learning Rates Through Coin Betting" by Francesco Orabona and 19 | Tatiana Tommasi. 20 | """ 21 | from collections.abc import Callable 22 | from typing import Any, NamedTuple, Optional, Union 23 | 24 | import jax 25 | import jax.numpy as jnp 26 | from optax._src import base 27 | from optax._src import combine 28 | from optax._src import transform 29 | import optax.tree 30 | 31 | 32 | class COCOBState(NamedTuple): 33 | """State for COntinuous COin Betting.""" 34 | 35 | init_particles: base.Updates 36 | cumulative_gradients: base.Updates 37 | scale: base.Updates 38 | subgradients: base.Updates 39 | reward: base.Updates 40 | 41 | 42 | def scale_by_cocob( 43 | alpha: jax.typing.ArrayLike = 100.0, eps: jax.typing.ArrayLike = 1e-8 44 | ) -> base.GradientTransformation: 45 | """Rescale updates according to the COntinuous COin Betting algorithm. 46 | 47 | See :func:`optax.contrib.cocob` for more details. 48 | 49 | Args: 50 | alpha: fraction to bet parameter of the COCOB optimizer 51 | eps: jitter term to avoid dividing by 0 52 | 53 | Returns: 54 | A `GradientTransformation` object. 
55 | """ 56 | 57 | def init_fn(params): 58 | init_adapt = optax.tree.zeros_like(params) 59 | init_scale = optax.tree.ones_like(params) 60 | init_scale = optax.tree.scale(eps, init_scale) 61 | return COCOBState( 62 | init_particles=params, 63 | cumulative_gradients=init_adapt, 64 | scale=init_scale, 65 | subgradients=init_adapt, 66 | reward=init_adapt, 67 | ) 68 | 69 | def update_fn(updates, state, params): 70 | init_particles, cumulative_grads, scale, subgradients, reward = state 71 | 72 | scale = jax.tree.map( 73 | lambda L, c: jnp.maximum(L, jnp.abs(c)), scale, updates 74 | ) 75 | subgradients = jax.tree.map( 76 | lambda G, c: G + jnp.abs(c), subgradients, updates 77 | ) 78 | reward = jax.tree.map( 79 | lambda R, c, p, p0: jnp.maximum(R - c * (p - p0), 0), 80 | reward, 81 | updates, 82 | params, 83 | init_particles, 84 | ) 85 | cumulative_grads = jax.tree.map( 86 | lambda C, c: C - c, cumulative_grads, updates 87 | ) 88 | 89 | new_updates = jax.tree.map( 90 | lambda p, p0, C, L, G, R: ( 91 | -p + (p0 + C / (L * jnp.maximum(G + L, alpha * L)) * (L + R)) 92 | ), 93 | params, 94 | init_particles, 95 | cumulative_grads, 96 | scale, 97 | subgradients, 98 | reward, 99 | ) 100 | 101 | new_state = COCOBState( 102 | init_particles=init_particles, 103 | cumulative_gradients=cumulative_grads, 104 | scale=scale, 105 | subgradients=subgradients, 106 | reward=reward, 107 | ) 108 | return new_updates, new_state 109 | 110 | return base.GradientTransformation(init_fn, update_fn) 111 | 112 | 113 | def cocob( 114 | learning_rate: base.ScalarOrSchedule = 1.0, 115 | alpha: jax.typing.ArrayLike = 100.0, 116 | eps: jax.typing.ArrayLike = 1e-8, 117 | weight_decay: jax.typing.ArrayLike = 0.0, 118 | mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, 119 | ) -> base.GradientTransformation: 120 | """Rescale updates according to the COntinuous COin Betting algorithm. 121 | 122 | Algorithm for stochastic subgradient descent. Uses a gambling algorithm to 123 | find the minimizer of a non-smooth objective function by accessing its 124 | subgradients. All we need is a good gambling strategy. See Algorithm 2 of the reference below. 125 | 126 | Args: 127 | learning_rate: optional learning rate to e.g. inject some scheduler 128 | alpha: fraction to bet parameter of the COCOB optimizer 129 | eps: jitter term to avoid dividing by 0 130 | weight_decay: L2 penalty 131 | mask: mask for weight decay 132 | 133 | Returns: 134 | A `GradientTransformation` object. 135 | 136 | References: 137 | Orabona et al, `Training Deep Networks without Learning Rates Through Coin 138 | Betting `_, 2017 139 | """ 140 | return combine.chain( 141 | transform.add_decayed_weights(weight_decay, mask), 142 | transform.scale_by_learning_rate(learning_rate, flip_sign=False), 143 | scale_by_cocob(alpha, eps), 144 | ) 145 | -------------------------------------------------------------------------------- /docs/api/utilities.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | General 5 | ------- 6 | 7 | .. currentmodule:: optax 8 | 9 | .. autosummary:: 10 | scale_gradient 11 | value_and_grad_from_state 12 | 13 | Scale gradient 14 | ~~~~~~~~~~~~~~ 15 | .. autofunction:: scale_gradient 16 | 17 | Value and grad from state 18 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 19 | .. autofunction:: value_and_grad_from_state 20 | 21 | 22 | Numerical Stability 23 | ------------------- 24 | 25 | .. currentmodule:: optax 26 | 27 | .. 
autosummary:: 28 | safe_increment 29 | safe_norm 30 | safe_root_mean_squares 31 | 32 | Safe increment 33 | ~~~~~~~~~~~~~~ 34 | .. autofunction:: safe_increment 35 | 36 | Safe norm 37 | ~~~~~~~~~ 38 | .. autofunction:: safe_norm 39 | 40 | Safe root mean squares 41 | ~~~~~~~~~~~~~~~~~~~~~~ 42 | .. autofunction:: safe_root_mean_squares 43 | 44 | 45 | Linear Algebra Operators 46 | ------------------------ 47 | 48 | .. currentmodule:: optax 49 | 50 | .. autosummary:: 51 | matrix_inverse_pth_root 52 | power_iteration 53 | nnls 54 | 55 | Matrix inverse pth root 56 | ~~~~~~~~~~~~~~~~~~~~~~~~ 57 | .. autofunction:: matrix_inverse_pth_root 58 | 59 | Power iteration 60 | ~~~~~~~~~~~~~~~ 61 | .. autofunction:: power_iteration 62 | 63 | Non-negative least squares 64 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 65 | .. autofunction:: nnls 66 | 67 | 68 | Second Order Optimization 69 | ------------------------- 70 | 71 | .. currentmodule:: optax.second_order 72 | 73 | .. autosummary:: 74 | fisher_diag 75 | hessian_diag 76 | hvp 77 | 78 | Fisher diagonal 79 | ~~~~~~~~~~~~~~~ 80 | .. autofunction:: fisher_diag 81 | 82 | Hessian diagonal 83 | ~~~~~~~~~~~~~~~~ 84 | .. autofunction:: hessian_diag 85 | 86 | Hessian vector product 87 | ~~~~~~~~~~~~~~~~~~~~~~ 88 | .. autofunction:: hvp 89 | 90 | 91 | Tree 92 | ---- 93 | 94 | .. currentmodule:: optax.tree_utils 95 | 96 | .. autosummary:: 97 | NamedTupleKey 98 | tree_add 99 | tree_add_scale 100 | tree_allclose 101 | tree_batch_shape 102 | tree_cast 103 | tree_cast_like 104 | tree_clip 105 | tree_conj 106 | tree_div 107 | tree_dtype 108 | tree_full_like 109 | tree_get 110 | tree_get_all_with_path 111 | tree_norm 112 | tree_map_params 113 | tree_max 114 | tree_min 115 | tree_mul 116 | tree_ones_like 117 | tree_random_like 118 | tree_real 119 | tree_split_key_like 120 | tree_scale 121 | tree_set 122 | tree_size 123 | tree_sub 124 | tree_sum 125 | tree_vdot 126 | tree_where 127 | tree_zeros_like 128 | 129 | NamedTupleKey 130 | ~~~~~~~~~~~~~ 131 | .. autoclass:: NamedTupleKey 132 | 133 | Tree add 134 | ~~~~~~~~ 135 | .. autofunction:: tree_add 136 | 137 | Tree add and scalar multiply 138 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 139 | .. autofunction:: tree_add_scale 140 | 141 | Tree all close 142 | ~~~~~~~~~~~~~~ 143 | .. autofunction:: tree_allclose 144 | 145 | Tree batch reshaping 146 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 147 | .. autofunction:: tree_batch_shape 148 | 149 | Tree cast 150 | ~~~~~~~~~ 151 | .. autofunction:: tree_cast 152 | 153 | Tree cast like 154 | ~~~~~~~~~~~~~~ 155 | .. autofunction:: tree_cast_like 156 | 157 | Tree clip 158 | ~~~~~~~~~ 159 | .. autofunction:: tree_clip 160 | 161 | Tree conjugate 162 | ~~~~~~~~~~~~~~ 163 | .. autofunction:: tree_conj 164 | 165 | Tree data type 166 | ~~~~~~~~~~~~~~ 167 | .. autofunction:: tree_dtype 168 | 169 | Tree full like 170 | ~~~~~~~~~~~~~~ 171 | .. autofunction:: tree_full_like 172 | 173 | Tree divide 174 | ~~~~~~~~~~~ 175 | .. autofunction:: tree_div 176 | 177 | Fetch single value that match a given key 178 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 179 | .. autofunction:: tree_get 180 | 181 | Fetch all values that match a given key 182 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 183 | .. autofunction:: tree_get_all_with_path 184 | 185 | Tree norm 186 | ~~~~~~~~~ 187 | .. autofunction:: tree_norm 188 | 189 | Tree map parameters 190 | ~~~~~~~~~~~~~~~~~~~ 191 | .. autofunction:: tree_map_params 192 | 193 | Tree max 194 | ~~~~~~~~ 195 | .. autofunction:: tree_max 196 | 197 | Tree min 198 | ~~~~~~~~ 199 | .. 
autofunction:: tree_min 200 | 201 | Tree multiply 202 | ~~~~~~~~~~~~~ 203 | .. autofunction:: tree_mul 204 | 205 | Tree ones like 206 | ~~~~~~~~~~~~~~ 207 | .. autofunction:: tree_ones_like 208 | 209 | Split key according to structure of a tree 210 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 211 | .. autofunction:: tree_split_key_like 212 | 213 | Tree with random values 214 | ~~~~~~~~~~~~~~~~~~~~~~~ 215 | .. autofunction:: tree_random_like 216 | 217 | Tree real part 218 | ~~~~~~~~~~~~~~ 219 | .. autofunction:: tree_real 220 | 221 | Tree scalar multiply 222 | ~~~~~~~~~~~~~~~~~~~~ 223 | .. autofunction:: tree_scale 224 | 225 | Set values in a tree 226 | ~~~~~~~~~~~~~~~~~~~~ 227 | .. autofunction:: tree_set 228 | 229 | Tree size 230 | ~~~~~~~~~ 231 | .. autofunction:: tree_size 232 | 233 | Tree subtract 234 | ~~~~~~~~~~~~~ 235 | .. autofunction:: tree_sub 236 | 237 | Tree sum 238 | ~~~~~~~~ 239 | .. autofunction:: tree_sum 240 | 241 | Tree inner product 242 | ~~~~~~~~~~~~~~~~~~ 243 | .. autofunction:: tree_vdot 244 | 245 | Tree where 246 | ~~~~~~~~~~ 247 | .. autofunction:: tree_where 248 | 249 | Tree zeros like 250 | ~~~~~~~~~~~~~~~ 251 | .. autofunction:: tree_zeros_like 252 | -------------------------------------------------------------------------------- /optax/transforms/_constraining_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for methods in `optax.transforms._constraining.py`.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | from optax._src import combine 21 | from optax._src import test_utils 22 | from optax._src import transform 23 | from optax._src import update 24 | from optax.transforms import _accumulation 25 | from optax.transforms import _constraining 26 | 27 | 28 | STEPS = 50 29 | LR = 1e-2 30 | 31 | 32 | class ConstraintsTest(absltest.TestCase): 33 | 34 | def test_keep_params_nonnegative(self): 35 | grads = ( 36 | jnp.array([500.0, -500.0, 0.0]), 37 | jnp.array([500.0, -500.0, 0.0]), 38 | jnp.array([500.0, -500.0, 0.0]), 39 | ) 40 | 41 | params = ( 42 | jnp.array([-1.0, -1.0, -1.0]), 43 | jnp.array([1.0, 1.0, 1.0]), 44 | jnp.array([0.0, 0.0, 0.0]), 45 | ) 46 | 47 | # vanilla sgd 48 | opt = combine.chain( 49 | _accumulation.trace(decay=0, nesterov=False), transform.scale(-LR) 50 | ) 51 | opt_state = opt.init(params) 52 | 53 | updates, _ = opt.update(grads, opt_state, params) 54 | new_params = update.apply_updates(params, updates) 55 | 56 | test_utils.assert_trees_all_close( 57 | new_params, 58 | ( 59 | jnp.array([-6.0, 4.0, -1.0]), 60 | jnp.array([-4.0, 6.0, 1.0]), 61 | jnp.array([-5.0, 5.0, 0.0]), 62 | ), 63 | ) 64 | 65 | # sgd with keeping parameters non-negative 66 | opt = combine.chain( 67 | _accumulation.trace(decay=0, nesterov=False), 68 | transform.scale(-LR), 69 | _constraining.keep_params_nonnegative(), 70 | ) 71 | opt_state = opt.init(params) 72 | 73 | updates, _ = opt.update(grads, opt_state, params) 74 | new_params = update.apply_updates(params, updates) 75 | 76 | test_utils.assert_trees_all_close( 77 | new_params, 78 | ( 79 | jnp.array([0.0, 4.0, 0.0]), 80 | jnp.array([0.0, 6.0, 1.0]), 81 | jnp.array([0.0, 5.0, 0.0]), 82 | ), 83 | ) 84 | 85 | def test_zero_nans(self): 86 | params = (jnp.zeros([3]), jnp.zeros([3]), jnp.zeros([3])) 87 | 88 | opt = _constraining.zero_nans() 89 | opt_state = jax.jit(opt.init)(params) 90 | update_fn = jax.jit(opt.update) 91 | 92 | test_utils.assert_trees_all_close( 93 | opt_state, _constraining.ZeroNansState((jnp.array(False),) * 3) 94 | ) 95 | 96 | # Check an update with nans 97 | grads_with_nans = ( 98 | jnp.ones([3]), 99 | jnp.array([1.0, float('nan'), float('nan')]), 100 | jnp.array([float('nan'), 1.0, 1.0]), 101 | ) 102 | updates, opt_state = update_fn(grads_with_nans, opt_state) 103 | test_utils.assert_trees_all_close( 104 | opt_state, 105 | _constraining.ZeroNansState( 106 | (jnp.array(False), jnp.array(True), jnp.array(True)) 107 | ), 108 | ) 109 | test_utils.assert_trees_all_close( 110 | updates, 111 | (jnp.ones([3]), jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 1.0])), 112 | ) 113 | 114 | # Check an update with nans and infs 115 | grads_with_nans_infs = ( 116 | jnp.ones([3]), 117 | jnp.array([1.0, float('nan'), float('nan')]), 118 | jnp.array([float('inf'), 1.0, 1.0]), 119 | ) 120 | updates, opt_state = update_fn(grads_with_nans_infs, opt_state) 121 | test_utils.assert_trees_all_close( 122 | opt_state, 123 | _constraining.ZeroNansState( 124 | (jnp.array(False), jnp.array(True), jnp.array(False)) 125 | ), 126 | ) 127 | test_utils.assert_trees_all_close( 128 | updates, 129 | ( 130 | jnp.ones([3]), 131 | jnp.array([1.0, 0.0, 0.0]), 132 | jnp.array([float('inf'), 1.0, 1.0]), 133 | ), 134 | ) 135 | 136 | # Check an update with only good values 137 | grads = (jnp.ones([3]), jnp.ones([3]), jnp.ones([3])) 138 | 
updates, opt_state = update_fn(grads, opt_state) 139 | test_utils.assert_trees_all_close( 140 | opt_state, 141 | _constraining.ZeroNansState( 142 | (jnp.array(False), jnp.array(False), jnp.array(False)) 143 | ), 144 | ) 145 | test_utils.assert_trees_all_close(updates, grads) 146 | 147 | def test_none_arguments(self): 148 | tf = _constraining.keep_params_nonnegative() 149 | state = tf.init(jnp.array([1.0, 2.0, 3.0])) 150 | with self.assertRaises(ValueError): 151 | tf.update(jnp.array([1.0, 2.0, 3.0]), state, None) 152 | 153 | 154 | if __name__ == '__main__': 155 | absltest.main() 156 | -------------------------------------------------------------------------------- /optax/_src/sharding_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Module for testing sharding and related behavior of the optax public API.""" 17 | 18 | import os 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | import jax 23 | import jax.numpy as jnp 24 | import optax 25 | from optax._src import test_utils 26 | from optax.experimental import microbatching 27 | 28 | 29 | os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' 30 | 31 | OPTIMIZERS = { 32 | 'adam': optax.adam(1.0), 33 | 'sgd': optax.sgd(1.0), 34 | 'adabelief': optax.adabelief(1.0), 35 | 'adamax': optax.adamax(1.0), 36 | 'adagrad': optax.adagrad(1.0), 37 | 'adamw': optax.adamw(1.0), 38 | 'rmsprop': optax.rmsprop(1.0), 39 | # TODO(mckennar): try to incorporate linesearch into the test. 
40 | 'lbfgs': optax.lbfgs(1.0, linesearch=None), 41 | 'adadelta': optax.adadelta(1.0), 42 | 'adafactor': optax.adafactor(), 43 | 'adafactor2': optax.adafactor(min_dim_size_to_factor=1), 44 | 'adan': optax.adan(1.0), 45 | 'adamaxw': optax.adamaxw(1.0), 46 | 'amsgrad': optax.amsgrad(1.0), 47 | 'fromage': optax.fromage(1.0), 48 | 'lamb': optax.lamb(1.0), 49 | 'lars': optax.lars(1.0), 50 | 'lion': optax.lion(1.0), 51 | 'nadam': optax.nadam(1.0), 52 | 'nadamw': optax.nadamw(1.0), 53 | 'noisy_sgd': optax.noisy_sgd(1.0), 54 | 'novograd': optax.novograd(1.0), 55 | 'optimistic_gradient_descent': optax.optimistic_gradient_descent(1.0), 56 | 'radam': optax.radam(1.0), 57 | 'sm3': optax.sm3(1.0), 58 | 'yogi': optax.yogi(1.0), 59 | } 60 | 61 | 62 | class ShardingTest(parameterized.TestCase): 63 | 64 | @parameterized.named_parameters(OPTIMIZERS.items()) 65 | def test_init_with_abstract_input(self, optimizer): 66 | params = jax.ShapeDtypeStruct(shape=(2, 4, 8), dtype=jnp.float32) 67 | state = optimizer.init(params) 68 | self.assertIsNotNone(state) 69 | 70 | @parameterized.named_parameters(OPTIMIZERS.items()) 71 | def test_state_sharding_type_init_match_update(self, optimizer): 72 | if jax.__version__ < '0.7.2': 73 | self.skipTest('Skipping sharding-in-types test') 74 | mesh = jax.make_mesh( 75 | (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,) 76 | ) 77 | sharding = jax.sharding.NamedSharding(mesh, jax.P(None, 'x')) 78 | 79 | with jax.set_mesh(mesh): 80 | params = jnp.zeros((2, 8, 4), dtype=jnp.float16, out_sharding=sharding) 81 | 82 | state0 = optimizer.init(params) 83 | _, state1 = optimizer.update(params, state0, params) 84 | 85 | type0 = jax.tree.map(jax.typeof, state0) 86 | type1 = jax.tree.map(jax.typeof, state1) 87 | test_utils.assert_trees_all_equal(type0, type1) 88 | 89 | @parameterized.named_parameters(OPTIMIZERS.items()) 90 | def test_state_sharding_type_preserved_with_jit(self, optimizer): 91 | if jax.__version__ < '0.7.2': 92 | self.skipTest('Skipping sharding-in-types test') 93 | mesh = jax.make_mesh( 94 | (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,) 95 | ) 96 | sharding = jax.sharding.NamedSharding(mesh, jax.P(None, 'x')) 97 | 98 | with jax.set_mesh(mesh): 99 | params = jnp.zeros((2, 8, 4), dtype=jnp.float16, out_sharding=sharding) 100 | 101 | state0 = optimizer.init(params) 102 | state1 = jax.jit(optimizer.init)(params) 103 | type0 = jax.tree.map(jax.typeof, state0) 104 | type1 = jax.tree.map(jax.typeof, state1) 105 | test_utils.assert_trees_all_equal(type0, type1) 106 | 107 | _, state2 = optimizer.update(params, state0, params) 108 | _, state3 = jax.jit(optimizer.update)(params, state0, params) 109 | type2 = jax.tree.map(jax.typeof, state2) 110 | type3 = jax.tree.map(jax.typeof, state3) 111 | test_utils.assert_trees_all_equal(type2, type3) 112 | 113 | @parameterized.named_parameters( 114 | ('replicated', jax.sharding.PartitionSpec()), 115 | ('sharded', jax.sharding.PartitionSpec('x')) 116 | ) 117 | def test_microbatch_with_explicit_sharding(self, spec): 118 | if jax.__version__ < '0.8.1': 119 | self.skipTest('Skipping sharding-in-types test.') 120 | mesh = jax.make_mesh( 121 | (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,) 122 | ) 123 | with jax.set_mesh(mesh): 124 | sharding = jax.sharding.NamedSharding(mesh, spec) 125 | 126 | fun = lambda x: (x + 1, jnp.sum(3 * x)) 127 | data = jnp.arange(16, out_sharding=sharding) 128 | 129 | strategy = ( 130 | microbatching.AccumulationType.CONCAT, 131 | microbatching.AccumulationType.SUM, 132 | ) 133 | 
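      # One accumulation mode per output of `fun`: CONCAT stitches the
      # per-microbatch first outputs back together, while SUM adds up the
      # scalar second outputs.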
microbatched_fun = microbatching.microbatch( 134 | fun, argnums=0, microbatch_size=8, accumulator=strategy 135 | ) 136 | 137 | actual = microbatched_fun(data) 138 | expected = fun(data) 139 | 140 | test_utils.assert_trees_all_equal( 141 | jax.tree.map(jax.typeof, actual), jax.tree.map(jax.typeof, expected) 142 | ) 143 | test_utils.assert_trees_all_equal(actual, expected) 144 | 145 | 146 | if __name__ == '__main__': 147 | absltest.main() 148 | -------------------------------------------------------------------------------- /optax/_src/numerics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities to ensure the implementation is safe wrt numerical issues. 16 | 17 | Note that complex numbers are also supported, see 18 | https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 19 | """ 20 | 21 | from typing import Optional, Union 22 | 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | # TODO(jscholz) Promote these functions to jax core lib? 29 | 30 | 31 | def abs_sq(x: jax.typing.ArrayLike) -> jax.Array: 32 | """Returns the squared absolute value of a (maybe complex) array. 33 | 34 | For real `x`, JAX generates the same HLO from this, `jnp.square(x)`, `x * x`, 35 | or `x**2`. 36 | 37 | Args: 38 | x: a (maybe complex) array. 39 | 40 | Returns: 41 | The squared absolute value of `x`. 42 | """ 43 | if not isinstance(x, (np.ndarray, jnp.ndarray)): 44 | raise ValueError(f'`abs_sq` accepts only NDarrays, got: {x}.') 45 | return (x.conj() * x).real 46 | 47 | 48 | def safe_norm( 49 | x: jax.typing.ArrayLike, 50 | min_norm: jax.typing.ArrayLike, 51 | ord: Optional[Union[int, float, str]] = None, # pylint: disable=redefined-builtin 52 | axis: Union[None, tuple[int, ...], int] = None, 53 | keepdims: bool = False, 54 | ) -> jax.Array: 55 | """Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients. 56 | 57 | The gradients of `jnp.maximum(jnp.linalg.norm(x), min_norm)` at 0.0 is `NaN`, 58 | because jax will evaluate both branches of the `jnp.maximum`. This function 59 | will instead return the correct gradient of 0.0 also in such setting. 60 | 61 | Args: 62 | x: jax array. 63 | min_norm: lower bound for the returned norm. 64 | ord: {non-zero int, inf, -inf, 'fro', 'nuc'}, optional. Order of the norm. 65 | inf means numpy's inf object. The default is None. 66 | axis: {None, int, 2-tuple of ints}, optional. If axis is an integer, it 67 | specifies the axis of x along which to compute the vector norms. If axis 68 | is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix 69 | norms of these matrices are computed. If axis is None then either a vector 70 | norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The 71 | default is None. 
72 | keepdims: bool, optional. If this is set to True, the axes which are normed 73 | over are left in the result as dimensions with size one. With this option 74 | the result will broadcast correctly against the original x. 75 | 76 | Returns: 77 | The safe norm of the input vector, accounting for correct gradient. 78 | """ 79 | norm = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=True) 80 | x = jnp.where(norm <= min_norm, jnp.ones_like(x), x) 81 | norm = jnp.squeeze(norm, axis=axis) if not keepdims else norm 82 | masked_norm = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) 83 | return jnp.where(norm <= min_norm, min_norm, masked_norm) 84 | 85 | 86 | def safe_root_mean_squares( 87 | x: jax.typing.ArrayLike, min_rms: jax.typing.ArrayLike) -> jax.Array: 88 | """Returns `maximum(sqrt(mean(abs_sq(x))), min_norm)` with correct grads. 89 | 90 | The gradients of `maximum(sqrt(mean(abs_sq(x))), min_norm)` at 0.0 91 | is `NaN`, because jax will evaluate both branches of the `jnp.maximum`. This 92 | function will instead return the correct gradient of 0.0 also in such setting. 93 | 94 | Args: 95 | x: jax array. 96 | min_rms: lower bound for the returned norm. 97 | 98 | Returns: 99 | The safe RMS of the input vector, accounting for correct gradient. 100 | """ 101 | rms = jnp.sqrt(jnp.mean(abs_sq(x))) 102 | x = jnp.where(rms <= min_rms, jnp.ones_like(x), x) 103 | return jnp.where(rms <= min_rms, min_rms, jnp.sqrt(jnp.mean(abs_sq(x)))) 104 | 105 | 106 | def safe_increment(count: jax.typing.ArrayLike) -> jax.typing.ArrayLike: 107 | """Increments counter by one while avoiding overflow. 108 | 109 | Denote ``max_val``, ``min_val`` as the maximum, minimum, possible values for 110 | the ``dtype`` of ``count``. Normally ``max_val + 1`` would overflow to 111 | ``min_val``. This functions ensures that when ``max_val`` is reached the 112 | counter stays at ``max_val``. 113 | 114 | Args: 115 | count: a counter to be incremented. 116 | 117 | Returns: 118 | A counter incremented by 1, or ``max_val`` if the maximum value is 119 | reached. 120 | 121 | Examples: 122 | >>> import jax.numpy as jnp 123 | >>> import optax 124 | >>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32)) 125 | Array(2, dtype=int32) 126 | >>> optax.safe_increment(jnp.asarray(2147483647, dtype=jnp.int32)) 127 | Array(2147483647, dtype=int32) 128 | 129 | .. versionadded:: 0.2.4 130 | """ 131 | count_dtype = jnp.asarray(count).dtype 132 | if jnp.issubdtype(count_dtype, jnp.integer): 133 | max_value = jnp.iinfo(count_dtype).max 134 | elif jnp.issubdtype(count_dtype, jnp.floating): 135 | max_value = jnp.finfo(count_dtype).max 136 | else: 137 | raise ValueError( 138 | f'Cannot safely increment count with dtype {count_dtype},' 139 | ' valid dtypes are subdtypes of "jnp.integer" or "jnp.floating".' 140 | ) 141 | max_value = jnp.array(max_value, count_dtype) 142 | one = jnp.array(1, count_dtype) 143 | return jnp.where(count < max_value, count + one, max_value) 144 | 145 | 146 | safe_int32_increment = safe_increment 147 | --------------------------------------------------------------------------------
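The docstrings in `numerics.py` explain why the naive `jnp.maximum(jnp.linalg.norm(x), min_norm)` has NaN gradients at zero, while `optax.safe_norm` does not. A minimal sketch of that difference, assuming only the public `optax.safe_norm` listed in `docs/api/utilities.rst` and an illustrative floor of `1e-6` (the expected outputs in the comments follow the docstring's description):

import jax
import jax.numpy as jnp
import optax

x = jnp.zeros(3)

# Naive formulation: jax evaluates both branches of the maximum, so the
# x / ||x|| term in the norm's gradient produces NaNs at x = 0.
naive_norm = lambda v: jnp.maximum(jnp.linalg.norm(v), 1e-6)
print(jax.grad(naive_norm)(x))  # expected: NaN gradients

# safe_norm masks the degenerate branch, giving a well-defined zero gradient.
print(jax.grad(lambda v: optax.safe_norm(v, 1e-6))(x))  # expected: zeros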