├── docs
├── .gitignore
├── _static
│ └── css
│ │ └── custom.css
├── images
│ └── examples
│ │ ├── lbfgs.png
│ │ ├── mnist.png
│ │ ├── ogda.png
│ │ ├── flax_optax.png
│ │ ├── lookahead.png
│ │ ├── adversarial.png
│ │ ├── contrib
│ │ │ ├── sam.png
│ │ │ ├── reduce_on_plateau.png
│ │ │ └── ademamix_rosenbrock.png
│ │ ├── perturbations.png
│ │ ├── cifar10_resnet.png
│ │ ├── tiny_shakespeare.png
│ │ └── linear_assignment_problem.png
├── api
│ ├── assignment.rst
│ ├── combining_optimizers.rst
│ ├── perturbations.rst
│ ├── apply_updates.rst
│ ├── experimental.rst
│ ├── optimizer_wrappers.rst
│ ├── stochastic_gradient_estimators.rst
│ ├── optimizer_schedules.rst
│ ├── optimizers.rst
│ ├── projections.rst
│ ├── contrib.rst
│ ├── losses.rst
│ └── utilities.rst
├── pyproject.toml
├── Makefile
├── index.rst
└── ext
│ └── coverage_check.py
├── .github
├── workflows
│ ├── mlc_config.json
│ └── pypi-publish.yml
└── check_license_headers.py
├── examples
├── README.md
├── contrib
│ └── README.md
└── pyproject.toml
├── .gitignore
├── readthedocs.yml
├── optax
├── experimental
│ └── __init__.py
├── schedules
│ ├── inject.py
│ ├── _join_test.py
│ ├── _join.py
│ └── __init__.py
├── assignment
│ ├── __init__.py
│ └── _hungarian_algorithm_test.py
├── perturbations
│ └── __init__.py
├── second_order
│ ├── __init__.py
│ └── _deprecated.py
├── _src
│ ├── combine.py
│ ├── constrain.py
│ ├── wrappers.py
│ ├── schedule.py
│ ├── deprecations.py
│ ├── update_test.py
│ ├── factorized_test.py
│ ├── float64_test.py
│ ├── update.py
│ ├── test_utils.py
│ ├── sharding_test.py
│ └── numerics.py
├── optax_test.py
├── CHANGELOG.md
├── projections
│ └── __init__.py
├── losses
│ ├── _smoothing.py
│ ├── _fenchel_young.py
│ ├── _smoothing_test.py
│ ├── __init__.py
│ ├── _fenchel_young_test.py
│ └── _self_supervised_test.py
├── transforms
│ ├── _layouts_test.py
│ ├── _layouts.py
│ ├── _constraining.py
│ ├── _monitoring_test.py
│ ├── _adding_test.py
│ ├── _freezing.py
│ ├── __init__.py
│ └── _constraining_test.py
├── tree
│ └── __init__.py
├── contrib
│ ├── _complex_valued_test.py
│ ├── __init__.py
│ ├── _sam_test.py
│ ├── _mechanic_test.py
│ ├── _privacy_test.py
│ ├── _complex_valued.py
│ └── _cocob.py
└── tree_utils
│ ├── _random.py
│ ├── _random_test.py
│ ├── __init__.py
│ └── _casting_test.py
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── test.sh
└── pyproject.toml
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _build
2 |
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | div.math {
2 | flex-direction: row;
3 | }
4 |
--------------------------------------------------------------------------------
/docs/images/examples/lbfgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/lbfgs.png
--------------------------------------------------------------------------------
/docs/images/examples/mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/mnist.png
--------------------------------------------------------------------------------
/docs/images/examples/ogda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/ogda.png
--------------------------------------------------------------------------------
/docs/images/examples/flax_optax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/flax_optax.png
--------------------------------------------------------------------------------
/docs/images/examples/lookahead.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/lookahead.png
--------------------------------------------------------------------------------
/docs/images/examples/adversarial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/adversarial.png
--------------------------------------------------------------------------------
/docs/images/examples/contrib/sam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/contrib/sam.png
--------------------------------------------------------------------------------
/docs/images/examples/perturbations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/perturbations.png
--------------------------------------------------------------------------------
/docs/images/examples/cifar10_resnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/cifar10_resnet.png
--------------------------------------------------------------------------------
/docs/images/examples/tiny_shakespeare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/tiny_shakespeare.png
--------------------------------------------------------------------------------
/.github/workflows/mlc_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "timeout": "20s",
3 | "retryCount": 5,
4 | "aliveStatusCodes": [
5 | 429
6 | ]
7 | }
8 |
--------------------------------------------------------------------------------
/docs/images/examples/contrib/reduce_on_plateau.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/contrib/reduce_on_plateau.png
--------------------------------------------------------------------------------
/docs/images/examples/linear_assignment_problem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/linear_assignment_problem.png
--------------------------------------------------------------------------------
/docs/images/examples/contrib/ademamix_rosenbrock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/optax/HEAD/docs/images/examples/contrib/ademamix_rosenbrock.png
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | This directory contains examples using the optax library.
4 |
5 | ```{toctree}
6 | :glob:
7 | :maxdepth: 1
8 |
9 | *
10 | ```
11 |
--------------------------------------------------------------------------------
/examples/contrib/README.md:
--------------------------------------------------------------------------------
1 | # Contrib Examples
2 |
3 | Examples that make use of the `optax.contrib` module.
4 |
5 | ```{toctree}
6 | :glob:
7 | :maxdepth: 1
8 |
9 | *
10 | ```
11 |
--------------------------------------------------------------------------------
/docs/api/assignment.rst:
--------------------------------------------------------------------------------
1 | Assignment problem
2 | ==================
3 |
4 | .. currentmodule:: optax.assignment
5 |
6 | .. autosummary::
7 | hungarian_algorithm
8 |
9 |
10 | Hungarian algorithm
11 | ~~~~~~~~~~~~~~~~~~~
12 | .. autofunction:: hungarian_algorithm
13 |
--------------------------------------------------------------------------------
/docs/api/combining_optimizers.rst:
--------------------------------------------------------------------------------
1 | Combining Optimizers
2 | =====================
3 |
4 | .. currentmodule:: optax
5 |
6 | .. autosummary::
7 | chain
8 | named_chain
9 | partition
10 |
11 | Chain
12 | ~~~~~
13 | .. autofunction:: chain
14 | .. autofunction:: named_chain
15 |
16 | Partition
17 | ~~~~~~~~~
18 | .. autofunction:: partition
19 | .. autoclass:: PartitionState
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Building and releasing library:
2 | *.egg-info
3 | *.pyc
4 | *.so
5 | build/
6 | dist/
7 | venv/
8 | _testing/
9 |
10 | # Building the documentation
11 | docs/_autosummary
12 | docs/_collections
13 | docs/modules/generated
14 | docs/sg_execution_times.rst
15 |
16 | # Mac OS
17 | .DS_Store
18 |
19 | # Python tools
20 | .mypy_cache/
21 | .pytype/
22 | .ipynb_checkpoints
23 |
24 | # Editors
25 | .idea
26 | .vscode
27 |
28 |
--------------------------------------------------------------------------------
/docs/api/perturbations.rst:
--------------------------------------------------------------------------------
1 | Perturbations
2 | =============
3 |
4 | .. currentmodule:: optax.perturbations
5 |
6 | .. autosummary::
7 | make_perturbed_fun
8 | Gumbel
9 | Normal
10 |
11 |
12 | Gumbel noise
13 | ~~~~~~~~~~~~
14 | .. autoclass:: Gumbel
15 |
16 | Make perturbed function
17 | ~~~~~~~~~~~~~~~~~~~~~~~
18 | .. autofunction:: make_perturbed_fun
19 |
20 | Normal noise
21 | ~~~~~~~~~~~~
22 | .. autoclass:: Normal
23 |
24 |
--------------------------------------------------------------------------------
/docs/api/apply_updates.rst:
--------------------------------------------------------------------------------
1 | Apply Updates
2 | =============
3 |
4 | .. currentmodule:: optax
5 |
6 | .. autosummary::
7 | apply_updates
8 | incremental_update
9 | periodic_update
10 |
11 | Apply updates
12 | ~~~~~~~~~~~~~~~~~
13 | .. autofunction:: apply_updates
14 |
15 | Incremental update
16 | ~~~~~~~~~~~~~~~~~~
17 | .. autofunction:: incremental_update
18 |
19 | Periodic update
20 | ~~~~~~~~~~~~~~~
21 | .. autofunction:: periodic_update
22 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | version: 2
5 |
6 | build:
7 | os: ubuntu-22.04
8 | tools:
9 | python: "3.11"
10 |
11 | sphinx:
12 | builder: html
13 | configuration: docs/conf.py
14 | fail_on_warning: false
15 |
16 | python:
17 | install:
18 | # Equivalent to 'pip install .'
19 | - method: pip
20 | path: .
21 | # Equivalent to 'pip install .[docs]'
22 | - method: pip
23 | path: .
24 | extra_requirements:
25 | - docs
26 |
--------------------------------------------------------------------------------
/docs/api/experimental.rst:
--------------------------------------------------------------------------------
1 | 🧪 Experimental
2 | ===============
3 |
4 | Experimental features subject to change before being graduated into `optax`.
5 |
6 | .. currentmodule:: optax.experimental
7 |
8 | .. autosummary::
9 | microbatching.microbatch
10 | microbatching.micro_vmap
11 | microbatching.micro_grad
12 | microbatching.AccumulationType
13 | microbatching.Accumulator
14 |
15 | .. currentmodule:: optax.experimental.microbatching
16 |
17 | Microbatching
18 | ~~~~~~~~~~~~~
19 | .. autofunction:: microbatch
20 | .. autofunction:: micro_vmap
21 | .. autofunction:: micro_grad
22 | .. autofunction:: reshape_batch_axis
23 | .. autoclass:: AccumulationType
24 | :members:
25 | .. autoclass:: Accumulator
26 |
--------------------------------------------------------------------------------
/docs/pyproject.toml:
--------------------------------------------------------------------------------
1 | # The pyproject.toml is here to override ruff linter checks in this directory
2 | # changes from base pyproject.toml:
3 | # - allow E501 line too long
4 |
5 | [tool.ruff]
6 | line-length = 80
7 |
8 | [tool.ruff.lint]
9 | select = [
10 | "F",
11 | "E",
12 | "W291", # whitespace at the end of the line
13 | "B023", # pylint's cell-var-over-loop, closures capturing variables in loop
14 | ]
15 | ignore = [
16 | "E731", # lambdas are allowed
17 | "F401", # allow unused imports
18 | "E501", # allow line-too-long in notebooks
19 | "E402", # allow modules not at top of file
20 | "E741", # allow "l" as a variable name
21 | "E703", # allow semicolons (for jupyter notebooks)
22 | ]
23 |
--------------------------------------------------------------------------------
/examples/pyproject.toml:
--------------------------------------------------------------------------------
1 | # The pyproject.toml is here to override ruff linter checks in this directory
2 | # changes from base pyproject.toml:
3 | # - allow E501 line too long
4 |
5 | [tool.ruff]
6 | line-length = 80
7 |
8 | [tool.ruff.lint]
9 | select = [
10 | "F",
11 | "E",
12 | "W291", # whitespace at the end of the line
13 | "B023", # pylint's cell-var-over-loop, closures capturing variables in loop
14 | ]
15 | ignore = [
16 | "E731", # lambdas are allowed
17 | "F401", # allow unused imports
18 | "E501", # allow line-too-long in notebooks
19 | "E402", # allow modules not at top of file
20 | "E741", # allow "l" as a variable name
21 | "E703", # allow semicolons (for jupyter notebooks)
22 | ]
23 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
21 | clean:
22 | rm -rf $(BUILDDIR)/* auto_examples _collections modules
23 |
24 | html-noplot:
25 | $(SPHINXBUILD) -D plot_gallery=0 -D nb_execution_mode=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
26 | @echo
27 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
28 |
--------------------------------------------------------------------------------
/optax/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Experimental optax modules."""
17 |
18 | from . import _aggregating as aggregating
19 | from .microbatching import microbatch
20 |
21 |
22 | __all__ = [
23 | 'aggregating',
24 | 'microbatch',
25 | ]
26 |
--------------------------------------------------------------------------------
/optax/schedules/inject.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Import stub."""
16 |
17 | # TODO(mtthss): delete this file asap.
18 |
19 | from optax.schedules import _inject
20 |
21 | InjectHyperparamsState = _inject.InjectHyperparamsState
22 | inject_hyperparams = _inject.inject_hyperparams
23 |
--------------------------------------------------------------------------------
/optax/assignment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The assignment sub-package."""
16 |
17 | # pylint:disable=g-importing-member
18 |
19 | from optax.assignment._hungarian_algorithm import base_hungarian_algorithm
20 | from optax.assignment._hungarian_algorithm import hungarian_algorithm
21 |
--------------------------------------------------------------------------------
/optax/perturbations/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The perturbations sub-package."""
16 |
17 | # pylint: disable=g-importing-member
18 |
19 | from optax.perturbations._make_pert import Gumbel
20 | from optax.perturbations._make_pert import make_perturbed_fun
21 | from optax.perturbations._make_pert import Normal
22 |
--------------------------------------------------------------------------------
/optax/second_order/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The second order optimization sub-package."""
16 |
17 | # pylint: disable=g-importing-member
18 |
19 | from optax.second_order._deprecated import fisher_diag
20 | from optax.second_order._deprecated import hessian_diag
21 | from optax.second_order._deprecated import hvp
22 |
--------------------------------------------------------------------------------
/optax/_src/combine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Flexibly compose gradient transformations."""
16 |
17 | from optax.transforms import _combining
18 |
19 | chain = _combining.chain
20 | named_chain = _combining.named_chain
21 | partition = _combining.partition
22 | multi_transform = _combining.partition # for backwards compatibility
23 | MultiTransformState = _combining.PartitionState
24 |
--------------------------------------------------------------------------------
/docs/api/optimizer_wrappers.rst:
--------------------------------------------------------------------------------
1 | Optimizer Wrappers
2 | ====================
3 |
4 | .. currentmodule:: optax
5 |
6 | .. autosummary::
7 | apply_if_finite
8 | ApplyIfFiniteState
9 | flatten
10 | lookahead
11 | LookaheadParams
12 | LookaheadState
13 | masked
14 | MaskedState
15 | MultiSteps
16 | MultiStepsState
17 | ShouldSkipUpdateFunction
18 | skip_large_updates
19 | skip_not_finite
20 |
21 |
22 | Apply if finite
23 | ~~~~~~~~~~~~~~~
24 | .. autofunction:: apply_if_finite
25 | .. autoclass:: ApplyIfFiniteState
26 |
27 | Flatten
28 | ~~~~~~~~
29 | .. autofunction:: flatten
30 |
31 | Lookahead
32 | ~~~~~~~~~
33 | .. autofunction:: lookahead
34 | .. autoclass:: LookaheadParams
35 | .. autoclass:: LookaheadState
36 |
37 | Masked update
38 | ~~~~~~~~~~~~~
39 | .. autofunction:: masked
40 | .. autoclass:: MaskedState
41 |
42 | Multi-step update
43 | ~~~~~~~~~~~~~~~~~
44 | .. autoclass:: MultiSteps
45 | .. autoclass:: MultiStepsState
46 | .. autoclass:: ShouldSkipUpdateFunction
47 | .. autofunction:: skip_large_updates
48 | .. autofunction:: skip_not_finite
49 |
--------------------------------------------------------------------------------
/optax/_src/constrain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Gradient transformations used to enforce specific constraints."""
16 |
17 | from optax.transforms import _constraining
18 |
19 | keep_params_nonnegative = _constraining.keep_params_nonnegative
20 | NonNegativeParamsState = _constraining.NonNegativeParamsState
21 | zero_nans = _constraining.zero_nans
22 | ZeroNansState = _constraining.ZeroNansState
23 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Install the pre-commit hooks below with
2 | # 'pre-commit install'
3 |
4 | # Auto-update the version of the hooks with
5 | # 'pre-commit autoupdate'
6 |
7 | # Run the hooks on all files with
8 | # 'pre-commit run --all'
9 |
10 | repos:
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v5.0.0
13 | hooks:
14 | - id: check-ast
15 | - id: check-merge-conflict
16 | - id: check-toml
17 | - id: check-yaml
18 | - id: end-of-file-fixer
19 | files: \.(py|md)$
20 | - id: trailing-whitespace
21 | files: \.(py|md|ipynb)$
22 | - id: debug-statements
23 | files: \.py$ # only include python files
24 | - id: mixed-line-ending # remove CR
25 | args: [--fix=lf]
26 |
27 | - repo: local
28 | hooks:
29 | - id: check-license
30 | name: Check license headers
31 | entry: python ./.github/check_license_headers.py
32 | language: python
33 | pass_filenames: false # only run this check once globally per repo
34 | always_run: true
35 |
36 | - repo: https://github.com/astral-sh/ruff-pre-commit
37 | rev: v0.9.10
38 | hooks:
39 | - id: ruff
40 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | name: pypi
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | permissions:
11 | id-token: write
12 | steps:
13 | - uses: actions/checkout@v4
14 | - name: Set up Python
15 | uses: actions/setup-python@v4
16 | with:
17 | python-version: '3.x'
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --upgrade pip
21 | pip install setuptools wheel twine build
22 | - name: Check consistency between the package version and release tag
23 | run: |
24 | pip install .
25 | RELEASE_VER=${GITHUB_REF#refs/*/}
26 | PACKAGE_VER="v`python -c 'import optax; print(optax.__version__)'`"
27 | if [ $RELEASE_VER != $PACKAGE_VER ]
28 | then
29 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1
30 | fi
31 | - name: Build
32 | run: |
33 | python -m build
34 | - name: Publish package distributions to PyPI
35 | uses: pypa/gh-action-pypi-publish@release/v1
36 |
--------------------------------------------------------------------------------
/docs/api/stochastic_gradient_estimators.rst:
--------------------------------------------------------------------------------
1 | Stochastic Gradient Estimators and Control Variates
2 | ===================================================
3 |
4 | .. warning::
5 | This module has been deprecated and will be removed in optax 0.2.7
6 |
7 | .. currentmodule:: optax.monte_carlo
8 |
9 | .. autosummary::
10 | control_delta_method
11 | control_variates_jacobians
12 | measure_valued_jacobians
13 | moving_avg_baseline
14 | pathwise_jacobians
15 | score_function_jacobians
16 |
17 |
18 | Control delta method
19 | ~~~~~~~~~~~~~~~~~~~~
20 | .. autofunction:: control_delta_method
21 |
22 | Control variates Jacobians
23 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
24 | .. autofunction:: control_variates_jacobians
25 |
26 | Moving average baseline
27 | ~~~~~~~~~~~~~~~~~~~~~~~
28 | .. autofunction:: moving_avg_baseline
29 |
30 | Measure valued Jacobians
31 | ~~~~~~~~~~~~~~~~~~~~~~~~
32 | .. autofunction:: measure_valued_jacobians
33 |
34 | Pathwise Jacobians
35 | ~~~~~~~~~~~~~~~~~~
36 | .. autofunction:: pathwise_jacobians
37 |
38 | Score function Jacobians
39 | ~~~~~~~~~~~~~~~~~~~~~~~~
40 | .. autofunction:: score_function_jacobians
41 |
--------------------------------------------------------------------------------
/optax/optax_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for importing optax."""
16 |
17 | from absl.testing import absltest
18 | import optax
19 | from optax import transforms
20 |
21 |
22 | class OptaxTest(absltest.TestCase):
23 | """Test optax can be imported correctly."""
24 |
25 | def test_import(self):
26 | self.assertTrue(hasattr(optax, 'GradientTransformation'))
27 | self.assertTrue(hasattr(transforms, 'partition'))
28 |
29 |
30 | if __name__ == '__main__':
31 | absltest.main()
32 |
--------------------------------------------------------------------------------
/optax/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7 |
8 | ## [Unreleased]
9 |
10 | ### Added
11 |
12 | - CHANGELOG.md file added to track notable changes.
13 |
14 | ### Changed
15 |
16 | - Test classes now inherit from `absltest.TestCase` or `parameterized.TestCase`
17 | instead of `chex.TestCase` as part of our effort to remove the `chex`
18 | dependency. This means that Chex test variants (with/without `jit`, with/without
19 | `device_put`, with `pmap`) are no longer tested. We decided it was sufficient to
20 | use `jit` throughout the tests. There is already test coverage on both CPU and
21 | accelerators, and `pmap` is deprecated.
22 | - Classification losses (`poly_loss_cross_entropy`,
23 | `ctc_loss_with_forward_probs`, `ctc_loss`, `sigmoid_focal_loss`) and regression
24 | losses (`huber_loss`, `cosine_similarity`, `cosine_distance`) no longer support
25 | positional args for hyperparameter-like inputs.
26 |
27 | ### Removed
28 |
29 | - Stochastic gradient estimators (à la REINFORCE) and control variate methods.
30 | See the `monte_carlo` folder in optax 0.1.8 if you are interested.
31 | - Removed optax._src.transform.cast_tree and optax._src.utils.cast_tree. Use
32 | optax.tree.cast from now on.
33 |
--------------------------------------------------------------------------------
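The two API notes above are easiest to see in code. The sketch below is illustrative only: the parameter name `delta` and the argument order of `optax.tree.cast` are assumptions based on the public optax API rather than on this changelog.

```python
# Illustrative sketch of the changelog entries above (assumed signatures).
import jax.numpy as jnp
import optax

predictions = jnp.array([0.5, 1.5, -0.2])
targets = jnp.array([0.0, 1.0, 0.0])

# Hyperparameter-like inputs such as `delta` are now passed by keyword.
loss = optax.losses.huber_loss(predictions, targets, delta=1.0)

# `optax.tree.cast` replaces the removed `cast_tree` helpers.
params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
params_bf16 = optax.tree.cast(params, jnp.bfloat16)
```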
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Testing
26 |
27 | Please make sure that your PR passes all tests by running `bash test.sh` on your
28 | local machine. You can also run only the tests affected by your code
29 | changes, but you will need to select them manually.
30 |
31 | ## Community Guidelines
32 |
33 | This project follows [Google's Open Source Community
34 | Guidelines](https://opensource.google.com/conduct/).
35 |
--------------------------------------------------------------------------------
/optax/projections/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """The projections sub-package."""
17 |
18 | # pylint: disable=g-importing-member
19 |
20 | from optax.projections._projections import projection_box
21 | from optax.projections._projections import projection_halfspace
22 | from optax.projections._projections import projection_hypercube
23 | from optax.projections._projections import projection_hyperplane
24 | from optax.projections._projections import projection_l1_ball
25 | from optax.projections._projections import projection_l1_sphere
26 | from optax.projections._projections import projection_l2_ball
27 | from optax.projections._projections import projection_l2_sphere
28 | from optax.projections._projections import projection_linf_ball
29 | from optax.projections._projections import projection_non_negative
30 | from optax.projections._projections import projection_simplex
31 | from optax.projections._projections import projection_vector
32 |
--------------------------------------------------------------------------------
/optax/losses/_smoothing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Smoothing functions."""
16 |
17 | import jax
18 | import jax.numpy as jnp
19 | from optax._src import utils
20 |
21 |
22 | def smooth_labels(
23 | labels: jax.typing.ArrayLike,
24 | alpha: jax.typing.ArrayLike,
25 | ) -> jax.Array:
26 | """Apply label smoothing.
27 |
28 | Label smoothing is often used in combination with a cross-entropy loss.
29 | Smoothed labels favor small logit gaps, and it has been shown that this can
30 | provide better model calibration by preventing overconfident predictions.
31 |
32 | Args:
33 | labels: One hot labels to be smoothed.
34 | alpha: The smoothing factor.
35 |
36 | Returns:
37 | a smoothed version of the one hot input labels.
38 |
39 | References:
40 |     Muller et al, `When does label smoothing help?
41 |     <https://arxiv.org/abs/1906.02629>`_, 2019
42 | """
43 | utils.check_subdtype(labels, jnp.floating)
44 | num_categories = labels.shape[-1]
45 | return (1.0 - alpha) * labels + alpha / num_categories
46 |
--------------------------------------------------------------------------------
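The docstring above spells out the smoothing rule `(1.0 - alpha) * labels + alpha / num_categories`. A minimal worked example, assuming `smooth_labels` is exported as `optax.losses.smooth_labels`:

```python
# Worked example of the label-smoothing rule shown above.
import jax.numpy as jnp
import optax

one_hot = jnp.array([[0.0, 0.0, 1.0]])
smoothed = optax.losses.smooth_labels(one_hot, alpha=0.1)
# (1 - 0.1) * [0, 0, 1] + 0.1 / 3  ->  approximately [[0.033, 0.033, 0.933]]
```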
/optax/_src/wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Transformation wrappers."""
16 |
17 | from optax.transforms import _accumulation
18 | from optax.transforms import _conditionality
19 | from optax.transforms import _layouts
20 | from optax.transforms import _masking
21 |
22 |
23 | apply_if_finite = _conditionality.apply_if_finite
24 | ApplyIfFiniteState = _conditionality.ApplyIfFiniteState
25 | ConditionFn = _conditionality.ConditionFn
26 | conditionally_mask = _conditionality.conditionally_mask
27 | conditionally_transform = _conditionality.conditionally_transform
28 | ConditionallyMaskState = _conditionality.ConditionallyMaskState
29 | ConditionallyTransformState = _conditionality.ConditionallyTransformState
30 | flatten = _layouts.flatten
31 | masked = _masking.masked
32 | MaskedNode = _masking.MaskedNode
33 | MaskedState = _masking.MaskedState
34 | MultiSteps = _accumulation.MultiSteps
35 | MultiStepsState = _accumulation.MultiStepsState
36 | ShouldSkipUpdateFunction = _accumulation.ShouldSkipUpdateFunction
37 | skip_not_finite = _accumulation.skip_not_finite
38 | skip_large_updates = _accumulation.skip_large_updates
39 |
--------------------------------------------------------------------------------
/optax/schedules/_join_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods in `join.py`."""
16 |
17 | from absl.testing import absltest
18 | import numpy as np
19 | from optax.schedules import _join
20 | from optax.schedules import _schedule
21 |
22 |
23 | class JoinTest(absltest.TestCase):
24 |
25 | def test_join_schedules(self):
26 | my_schedule = _join.join_schedules(
27 | schedules=[
28 | _schedule.constant_schedule(1.0),
29 | _schedule.constant_schedule(2.0),
30 | _schedule.constant_schedule(1.0),
31 | ],
32 | boundaries=[3, 6],
33 | )
34 | np.testing.assert_allclose(1.0, my_schedule(0), atol=0.0)
35 | np.testing.assert_allclose(1.0, my_schedule(1), atol=0.0)
36 | np.testing.assert_allclose(1.0, my_schedule(2), atol=0.0)
37 | np.testing.assert_allclose(2.0, my_schedule(3), atol=0.0)
38 | np.testing.assert_allclose(2.0, my_schedule(4), atol=0.0)
39 | np.testing.assert_allclose(2.0, my_schedule(5), atol=0.0)
40 | np.testing.assert_allclose(1.0, my_schedule(6), atol=0.0)
41 |
42 |
43 | if __name__ == "__main__":
44 | absltest.main()
45 |
--------------------------------------------------------------------------------
/optax/schedules/_join.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities to join schedules."""
16 |
17 | from typing import Sequence
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | from optax._src import base
22 |
23 |
24 | def join_schedules(
25 | schedules: Sequence[base.Schedule], boundaries: Sequence[int]
26 | ) -> base.Schedule:
27 | """Sequentially apply multiple schedules.
28 |
29 | Args:
30 | schedules: A list of callables (expected to be optax schedules). Each
31 | schedule will receive a step count indicating the number of steps since
32 | the previous boundary transition.
33 | boundaries: A list of integers (of length one less than schedules) that
34 | indicate when to transition between schedules.
35 |
36 | Returns:
37 | schedule: A function that maps step counts to values.
38 | """
39 |
40 | def schedule(step: jax.typing.ArrayLike) -> jax.typing.ArrayLike:
41 | output = schedules[0](step)
42 | for boundary, schedule in zip(boundaries, schedules[1:]):
43 | output = jnp.where(step < boundary, output, schedule(step - boundary))
44 | return output
45 |
46 | return schedule
47 |
--------------------------------------------------------------------------------
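As the docstring notes, every schedule after the first receives the step count relative to the previous boundary. A small warmup-then-constant sketch, assuming the usual `linear_schedule(init_value, end_value, transition_steps)` and `constant_schedule(value)` signatures:

```python
# Sketch of join_schedules: linear warmup for 100 steps, then a constant rate.
# Past the boundary, the second schedule is evaluated at `step - 100`.
import optax

warmup = optax.schedules.linear_schedule(
    init_value=0.0, end_value=1e-3, transition_steps=100)
constant = optax.schedules.constant_schedule(1e-3)
lr_schedule = optax.schedules.join_schedules(
    schedules=[warmup, constant], boundaries=[100])

lr_schedule(0)    # 0.0
lr_schedule(100)  # 1e-3 (the constant schedule sees step 0 here)
```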
/docs/api/optimizer_schedules.rst:
--------------------------------------------------------------------------------
1 | Optimizer Schedules
2 | =====================
3 |
4 | .. currentmodule:: optax.schedules
5 |
6 | .. autosummary::
7 | constant_schedule
8 | cosine_decay_schedule
9 | cosine_onecycle_schedule
10 | exponential_decay
11 | join_schedules
12 | linear_onecycle_schedule
13 | linear_schedule
14 | piecewise_constant_schedule
15 | piecewise_interpolate_schedule
16 | polynomial_schedule
17 | sgdr_schedule
18 | warmup_constant_schedule
19 | warmup_cosine_decay_schedule
20 | warmup_exponential_decay_schedule
21 | Schedule
22 | InjectHyperparamsState
23 | inject_hyperparams
24 |
25 |
26 | .. autoclass:: Schedule
27 |
28 | Constant schedule
29 | ~~~~~~~~~~~~~~~~~
30 | .. autofunction:: constant_schedule
31 | .. autofunction:: warmup_constant_schedule
32 |
33 | Cosine decay schedule
34 | ~~~~~~~~~~~~~~~~~~~~~
35 | .. autofunction:: cosine_decay_schedule
36 | .. autofunction:: cosine_onecycle_schedule
37 | .. autofunction:: warmup_cosine_decay_schedule
38 |
39 | Exponential decay schedule
40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
41 | .. autofunction:: exponential_decay
42 | .. autofunction:: warmup_exponential_decay_schedule
43 |
44 | Join schedules
45 | ~~~~~~~~~~~~~~
46 | .. autofunction:: join_schedules
47 |
48 | Inject hyperparameters
49 | ~~~~~~~~~~~~~~~~~~~~~~
50 | .. autofunction:: inject_hyperparams
51 | .. autoclass:: InjectHyperparamsState
52 |
53 | Linear schedules
54 | ~~~~~~~~~~~~~~~~
55 | .. autofunction:: linear_onecycle_schedule
56 | .. autofunction:: linear_schedule
57 |
58 | Piecewise schedules
59 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
60 | .. autofunction:: piecewise_constant_schedule
61 | .. autofunction:: piecewise_interpolate_schedule
62 |
63 | Polynomial schedules
64 | ~~~~~~~~~~~~~~~~~~~~
65 | .. autofunction:: polynomial_schedule
66 |
67 | Reduce on plateau
68 | ~~~~~~~~~~~~~~~~~
69 | .. autofunction:: optax.contrib.reduce_on_plateau
70 |
71 | Warm restarts
72 | ~~~~~~~~~~~~~
73 | .. autofunction:: sgdr_schedule
74 |
--------------------------------------------------------------------------------
/optax/_src/schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Hyper-parameters Schedules.
16 |
17 | Schedules may be used to anneal the value of a hyper-parameter over time; for
18 | instance, they may be used to anneal the learning rate used to update an agent's
19 | parameters or the exploration factor used to select actions.
20 | """
21 |
22 | from optax import schedules
23 |
24 |
25 | # TODO(mtthss): remove schedule aliases from flat namespaces after user updates.
26 | constant_schedule = schedules.constant_schedule
27 | cosine_decay_schedule = schedules.cosine_decay_schedule
28 | cosine_onecycle_schedule = schedules.cosine_onecycle_schedule
29 | exponential_decay = schedules.exponential_decay
30 | inject_hyperparams = schedules.inject_hyperparams
31 | InjectHyperparamsState = schedules.InjectHyperparamsState
32 | join_schedules = schedules.join_schedules
33 | linear_onecycle_schedule = schedules.linear_onecycle_schedule
34 | linear_schedule = schedules.linear_schedule
35 | piecewise_constant_schedule = schedules.piecewise_constant_schedule
36 | piecewise_interpolate_schedule = schedules.piecewise_interpolate_schedule
37 | polynomial_schedule = schedules.polynomial_schedule
38 | sgdr_schedule = schedules.sgdr_schedule
39 | warmup_constant_schedule = schedules.warmup_constant_schedule
40 | warmup_cosine_decay_schedule = schedules.warmup_cosine_decay_schedule
41 | warmup_exponential_decay_schedule = schedules.warmup_exponential_decay_schedule
42 |
--------------------------------------------------------------------------------
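To illustrate the module docstring above, a schedule can be passed directly wherever an optimizer expects a scalar learning rate. A minimal sketch, assuming the usual `exponential_decay(init_value, transition_steps, decay_rate)` signature:

```python
# Annealing a learning rate by passing a schedule to an optimizer.
import optax

lr = optax.schedules.exponential_decay(
    init_value=1e-2, transition_steps=1_000, decay_rate=0.9)

# Optax optimizers accept a schedule wherever a scalar learning rate is expected.
optimizer = optax.sgd(learning_rate=lr)
```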
/optax/schedules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities for creating schedules."""
16 |
17 | # pylint: disable=g-importing-member
18 |
19 | from optax._src.base import Schedule
20 | from optax._src.base import StatefulSchedule
21 | from optax.schedules._inject import inject_hyperparams
22 | from optax.schedules._inject import inject_stateful_hyperparams
23 | from optax.schedules._inject import InjectHyperparamsState
24 | from optax.schedules._inject import InjectStatefulHyperparamsState
25 | from optax.schedules._inject import WrappedSchedule
26 | from optax.schedules._join import join_schedules
27 | from optax.schedules._schedule import constant_schedule
28 | from optax.schedules._schedule import cosine_decay_schedule
29 | from optax.schedules._schedule import cosine_onecycle_schedule
30 | from optax.schedules._schedule import exponential_decay
31 | from optax.schedules._schedule import linear_onecycle_schedule
32 | from optax.schedules._schedule import linear_schedule
33 | from optax.schedules._schedule import piecewise_constant_schedule
34 | from optax.schedules._schedule import piecewise_interpolate_schedule
35 | from optax.schedules._schedule import polynomial_schedule
36 | from optax.schedules._schedule import sgdr_schedule
37 | from optax.schedules._schedule import warmup_constant_schedule
38 | from optax.schedules._schedule import warmup_cosine_decay_schedule
39 | from optax.schedules._schedule import warmup_exponential_decay_schedule
40 |
--------------------------------------------------------------------------------
/optax/transforms/_layouts_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods in `optax.transforms._layouts.py`."""
16 |
17 | from absl.testing import absltest
18 | import jax.numpy as jnp
19 | from optax._src import alias
20 | from optax._src import test_utils
21 | from optax._src import update
22 | from optax.transforms import _layouts
23 |
24 |
25 | class LayoutsTest(absltest.TestCase):
26 |
27 | def test_flatten(self):
28 | def init_params():
29 | return (jnp.array(2.0), jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
30 |
31 | per_step_updates = (
32 | jnp.array(1.0),
33 | jnp.array([500.0, 5.0]),
34 | jnp.array([300.0, 3.0]),
35 | )
36 |
37 | # First calculate new params without flattening
38 | optax_sgd_params = init_params()
39 | sgd = alias.sgd(1e-2, 0.0)
40 | state_sgd = sgd.init(optax_sgd_params)
41 | updates_sgd, _ = sgd.update(per_step_updates, state_sgd)
42 | sgd_params_no_flatten = update.apply_updates(optax_sgd_params, updates_sgd)
43 |
44 | # And now calculate new params with flattening
45 | optax_sgd_params = init_params()
46 | sgd = _layouts.flatten(sgd)
47 |
48 | state_sgd = sgd.init(optax_sgd_params)
49 | updates_sgd, _ = sgd.update(per_step_updates, state_sgd)
50 | sgd_params_flatten = update.apply_updates(optax_sgd_params, updates_sgd)
51 |
52 | # Test that both give the same result
53 | test_utils.assert_trees_all_close(
54 | sgd_params_no_flatten, sgd_params_flatten, atol=1e-7, rtol=1e-7
55 | )
56 |
57 |
58 | if __name__ == "__main__":
59 | absltest.main()
60 |
--------------------------------------------------------------------------------
/optax/losses/_fenchel_young.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Fenchel-Young losses."""
16 |
17 | from typing import Any, Protocol
18 |
19 | import jax
20 | import jax.numpy as jnp
21 |
22 |
23 | class MaxFun(Protocol):
24 |
25 | def __call__(self, scores, *args, **kwargs: Any) -> jax.typing.ArrayLike:
26 | ...
27 |
28 |
29 | def make_fenchel_young_loss(max_fun: MaxFun):
30 | """Creates a Fenchel-Young loss from a max function.
31 |
32 | Args:
33 | max_fun: the max function on which the Fenchel-Young loss is built.
34 |
35 | Returns:
36 | A Fenchel-Young loss function with the same signature.
37 |
38 | Examples:
39 | Given a max function, e.g., the log sum exp, you can construct a
40 | Fenchel-Young loss easily as follows:
41 |
42 | >>> from jax.scipy.special import logsumexp
43 | >>> fy_loss = optax.losses.make_fenchel_young_loss(max_fun=logsumexp)
44 |
45 | Reference:
46 |     Blondel et al. `Learning with Fenchel-Young Losses
47 |     <https://arxiv.org/abs/1901.02324>`_, 2020
48 |
49 | .. warning::
50 |     The resulting loss accepts an arbitrary number of leading dimensions,
51 |     with the loss operating over the last dimension. The jaxopt version of
52 |     this function would instead flatten any input into a single big 1D vector.
53 | """
54 |
55 | vdot_last_dim = jnp.vectorize(jnp.vdot, signature="(n),(n)->()")
56 | max_fun_last_dim = jnp.vectorize(max_fun, signature="(n)->()")
57 |
58 | def fenchel_young_loss(scores, targets, *args, **kwargs):
59 | max_value = max_fun_last_dim(scores, *args, **kwargs)
60 | return max_value - vdot_last_dim(targets, scores)
61 |
62 | return fenchel_young_loss
63 |
--------------------------------------------------------------------------------
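To make the warning above concrete, the loss built from `logsumexp` maps inputs of shape `(..., n)` to one value per leading index; with one-hot targets it coincides with softmax cross-entropy over the last axis. A small batched sketch (values are illustrative):

```python
# Batched use of the Fenchel-Young loss built from logsumexp: the loss is
# applied over the last axis, yielding one value per leading (batch) index.
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import optax

fy_loss = optax.losses.make_fenchel_young_loss(max_fun=logsumexp)

scores = jnp.array([[2.0, 1.0, 0.5],
                    [0.0, 1.0, 3.0]])   # shape (batch=2, n=3)
targets = jnp.array([[1.0, 0.0, 0.0],
                     [0.0, 0.0, 1.0]])  # one-hot targets

losses = fy_loss(scores, targets)       # shape (2,)
```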
/optax/tree/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The tree_utils sub-package."""
16 |
17 | # pylint: disable=g-importing-member
18 |
19 | from optax.tree_utils import _casting
20 | from optax.tree_utils import _random
21 | from optax.tree_utils import _state_utils
22 | from optax.tree_utils import _tree_math
23 |
24 | cast = _casting.tree_cast
25 | cast_like = _casting.tree_cast_like
26 | dtype = _casting.tree_dtype
27 | random_like = _random.tree_random_like
28 | split_key_like = _random.tree_split_key_like
29 | unwrap_random_key_data = _random.tree_unwrap_random_key_data
30 | get = _state_utils.tree_get
31 | get_all_with_path = _state_utils.tree_get_all_with_path
32 | map_params = _state_utils.tree_map_params
33 | set = _state_utils.tree_set # pylint: disable=redefined-builtin
34 | add = _tree_math.tree_add
35 | add_scale = _tree_math.tree_add_scale
36 | allclose = _tree_math.tree_allclose
37 | batch_shape = _tree_math.tree_batch_shape
38 | bias_correction = _tree_math.tree_bias_correction
39 | clip = _tree_math.tree_clip
40 | conj = _tree_math.tree_conj
41 | div = _tree_math.tree_div
42 | full_like = _tree_math.tree_full_like
43 | max = _tree_math.tree_max # pylint: disable=redefined-builtin
44 | min = _tree_math.tree_min # pylint: disable=redefined-builtin
45 | mul = _tree_math.tree_mul
46 | norm = _tree_math.tree_norm
47 | ones_like = _tree_math.tree_ones_like
48 | real = _tree_math.tree_real
49 | scale = _tree_math.tree_scale
50 | size = _tree_math.tree_size
51 | sub = _tree_math.tree_sub
52 | sum = _tree_math.tree_sum # pylint: disable=redefined-builtin
53 | update_infinity_moment = _tree_math.tree_update_infinity_moment
54 | update_moment = _tree_math.tree_update_moment
55 | update_moment_per_elem_norm = _tree_math.tree_update_moment_per_elem_norm
56 | vdot = _tree_math.tree_vdot
57 | where = _tree_math.tree_where
58 | zeros_like = _tree_math.tree_zeros_like
59 |
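A brief editorial sketch of how these aliases are used (illustrative values; the reduction performed by ``norm`` is the library's default):

    import jax.numpy as jnp
    import optax

    a = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
    b = {'w': jnp.array([0.1, 0.1]), 'b': jnp.array(1.0)}

    total = optax.tree.add(a, b)      # leaf-wise sum, same structure as the inputs
    zeros = optax.tree.zeros_like(a)  # tree of zeros matching shapes and dtypes
    norm = optax.tree.norm(a)         # scalar norm computed over all leaves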
--------------------------------------------------------------------------------
/docs/api/optimizers.rst:
--------------------------------------------------------------------------------
1 | Optimizers
2 | ==========
3 |
4 | .. currentmodule:: optax
5 |
6 | .. autosummary::
7 | adabelief
8 | adadelta
9 | adan
10 | adafactor
11 | adagrad
12 | adam
13 | adamw
14 | adamax
15 | adamaxw
16 | amsgrad
17 | fromage
18 | lamb
19 | lars
20 | lbfgs
21 | lion
22 | nadam
23 | nadamw
24 | noisy_sgd
25 | novograd
26 | optimistic_gradient_descent
27 | optimistic_adam_v2
28 | polyak_sgd
29 | radam
30 | rmsprop
31 | sgd
32 | sign_sgd
33 | signum
34 | sm3
35 | yogi
36 |
37 |
38 | AdaBelief
39 | ~~~~~~~~~
40 | .. autofunction:: adabelief
41 |
42 | AdaDelta
43 | ~~~~~~~~~
44 | .. autofunction:: adadelta
45 |
46 | Adan
47 | ~~~~
48 | .. autofunction:: adan
49 |
50 | AdaGrad
51 | ~~~~~~~
52 | .. autofunction:: adagrad
53 |
54 | AdaFactor
55 | ~~~~~~~~~
56 | .. autofunction:: adafactor
57 |
58 | Adam
59 | ~~~~
60 | .. autofunction:: adam
61 |
62 | Adamax
63 | ~~~~~~
64 | .. autofunction:: adamax
65 |
66 | AdamaxW
67 | ~~~~~~~
68 | .. autofunction:: adamaxw
69 |
70 | AdamW
71 | ~~~~~
72 | .. autofunction:: adamw
73 |
74 | AMSGrad
75 | ~~~~~~~
76 | .. autofunction:: amsgrad
77 |
78 | Fromage
79 | ~~~~~~~
80 | .. autofunction:: fromage
81 |
82 | Lamb
83 | ~~~~
84 | .. autofunction:: lamb
85 |
86 | Lars
87 | ~~~~
88 | .. autofunction:: lars
89 |
90 | LBFGS
91 | ~~~~~
92 | .. autofunction:: lbfgs
93 |
94 | Lion
95 | ~~~~
96 | .. autofunction:: lion
97 |
98 | Nadam
99 | ~~~~~
100 | .. autofunction:: nadam
101 |
102 | NadamW
103 | ~~~~~~
104 | .. autofunction:: nadamw
105 |
106 | Noisy SGD
107 | ~~~~~~~~~
108 | .. autofunction:: noisy_sgd
109 |
110 | Novograd
111 | ~~~~~~~~
112 | .. autofunction:: novograd
113 |
114 | Optimistic GD
115 | ~~~~~~~~~~~~~
116 | .. autofunction:: optimistic_gradient_descent
117 |
118 | Optimistic Adam
119 | ~~~~~~~~~~~~~~~
120 | .. autofunction:: optimistic_adam_v2
121 |
122 | Polyak step-size SGD
123 | ~~~~~~~~~~~~~~~~~~~~
124 | .. autofunction:: polyak_sgd
125 |
126 | RAdam
127 | ~~~~~
128 | .. autofunction:: radam
129 |
130 | RMSProp
131 | ~~~~~~~
132 | .. autofunction:: rmsprop
133 |
134 | RProp
135 | ~~~~~
136 | .. autofunction:: rprop
137 |
138 | SGD
139 | ~~~
140 | .. autofunction:: sgd
141 |
142 | SignSGD
143 | ~~~~~~~
144 | .. autofunction:: sign_sgd
145 |
146 | Signum
147 | ~~~~~~
148 | .. autofunction:: signum
149 |
150 | SM3
151 | ~~~
152 | .. autofunction:: sm3
153 |
154 | Yogi
155 | ~~~~
156 | .. autofunction:: yogi
157 |
--------------------------------------------------------------------------------
/optax/losses/_smoothing_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for smoothing functions in `optax.losses._smoothing.py`."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as np
21 | from optax.losses import _smoothing
22 |
23 |
24 | class SmoothLabelsTest(absltest.TestCase):
25 |
26 | def setUp(self):
27 | super().setUp()
28 | self.ts = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32)
29 | # compute expected outputs in numpy.
30 | self.exp_alpha_zero = self.ts
31 | self.exp_alpha_zero_point_one = 0.9 * self.ts + 0.1 / self.ts.shape[-1]
32 | self.exp_alpha_one = jnp.ones_like(self.ts) / self.ts.shape[-1]
33 |
34 | def test_scalar(self):
35 | """Tests for a full batch."""
36 | np.testing.assert_allclose(
37 | jax.jit(_smoothing.smooth_labels)(self.ts[0], 0.0),
38 | self.exp_alpha_zero[0],
39 | atol=1e-4,
40 | )
41 | np.testing.assert_allclose(
42 | jax.jit(_smoothing.smooth_labels)(self.ts[0], 0.1),
43 | self.exp_alpha_zero_point_one[0],
44 | atol=1e-4,
45 | )
46 | np.testing.assert_allclose(
47 | jax.jit(_smoothing.smooth_labels)(self.ts[0], 1.0),
48 | self.exp_alpha_one[0],
49 | atol=1e-4,
50 | )
51 |
52 | def test_batched(self):
53 | """Tests for a full batch."""
54 | np.testing.assert_allclose(
55 | jax.jit(_smoothing.smooth_labels)(self.ts, 0.0),
56 | self.exp_alpha_zero,
57 | atol=1e-4,
58 | )
59 | np.testing.assert_allclose(
60 | jax.jit(_smoothing.smooth_labels)(self.ts, 0.1),
61 | self.exp_alpha_zero_point_one,
62 | atol=1e-4,
63 | )
64 | np.testing.assert_allclose(
65 | jax.jit(_smoothing.smooth_labels)(self.ts, 1.0),
66 | self.exp_alpha_one,
67 | atol=1e-4,
68 | )
69 |
70 |
71 | if __name__ == '__main__':
72 | absltest.main()
73 |
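For reference, the expected values computed in ``setUp`` encode the label-smoothing rule ``(1 - alpha) * labels + alpha / num_classes``; a small usage sketch with illustrative inputs:

    import jax.numpy as jnp
    import optax

    labels = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
    smoothed = optax.losses.smooth_labels(labels, 0.1)
    # Each row becomes 0.9 * labels + 0.1 / 3, matching `exp_alpha_zero_point_one`.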
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/google-deepmind/optax/tree/main/docs
2 |
3 | =====
4 | Optax
5 | =====
6 |
7 | Optax is a gradient processing and optimization library for JAX. It is designed
8 | to facilitate research by providing building blocks that can be recombined in
9 | custom ways in order to optimize parametric models such as, but not limited to,
10 | deep neural networks.
11 |
12 | Our goals are to
13 |
14 | * Provide readable, well-tested, efficient implementations of core components.
15 | * Improve researcher productivity by making it possible to combine low-level
16 |   ingredients into custom optimizers (or other gradient processing components).
17 | * Accelerate adoption of new ideas by making it easy for anyone to contribute.
18 |
19 | We favor focusing on small composable building blocks that can be effectively
20 | combined into custom solutions. Others may build more complicated abstractions
21 | on top of these basic components. Whenever reasonable, implementations
22 | prioritize readability and code that matches the standard equations over reuse.
23 |
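As a small illustration of this philosophy (an editorial sketch rather than an official recipe), a custom optimizer can be assembled by chaining gradient transformations::

    import optax

    # Clip gradients, rescale them with Adam statistics, then apply a step size.
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.scale_by_adam(),
        optax.scale(-1e-3),
    )
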
24 | Installation
25 | ------------
26 |
27 | The latest release of Optax can be installed from
28 | `PyPI <https://pypi.org/project/optax/>`_ using::
29 |
30 | pip install optax
31 |
32 | You may also install directly from GitHub, using the following command. This
33 | can be used to obtain the most recent version of Optax::
34 |
35 | pip install git+https://github.com/google-deepmind/optax.git
36 |
37 | Note that Optax is built on top of JAX.
38 | See `here <https://jax.readthedocs.io/en/latest/installation.html>`_
39 | for instructions on installing JAX.
40 |
41 |
42 | .. toctree::
43 | :hidden:
44 |
45 | getting_started
46 |
47 | gallery
48 |
49 | development
50 |
51 |
52 | .. toctree::
53 | :hidden:
54 | :caption: 📖 Reference
55 | :maxdepth: 2
56 |
57 | api/assignment
58 | api/optimizers
59 | api/transformations
60 | api/combining_optimizers
61 | api/optimizer_wrappers
62 | api/optimizer_schedules
63 | api/apply_updates
64 | api/perturbations
65 | api/projections
66 | api/losses
67 | api/stochastic_gradient_estimators
68 | api/utilities
69 | api/contrib
70 | api/experimental
71 |
72 |
73 | Support
74 | -------
75 |
76 | If you encounter issues with this software, please let us know by filing an issue on our `issue tracker <https://github.com/google-deepmind/optax/issues>`_. We are also happy to receive bug fixes and other contributions. For more information on how to contribute, please see the :doc:`development guide <development>`.
77 |
78 |
79 | License
80 | -------
81 |
82 | Optax is licensed under the `Apache 2.0 License <http://www.apache.org/licenses/LICENSE-2.0>`_.
83 |
84 |
85 | Indices and Tables
86 | ==================
87 |
88 | * :ref:`genindex`
89 |
--------------------------------------------------------------------------------
/optax/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The losses sub-package."""
16 |
17 | # pylint:disable=g-importing-member
18 |
19 | from optax.losses._classification import convex_kl_divergence
20 | from optax.losses._classification import ctc_loss
21 | from optax.losses._classification import ctc_loss_with_forward_probs
22 | from optax.losses._classification import generalized_kl_divergence
23 | from optax.losses._classification import hinge_loss
24 | from optax.losses._classification import kl_divergence
25 | from optax.losses._classification import kl_divergence_with_log_targets
26 | from optax.losses._classification import multiclass_hinge_loss
27 | from optax.losses._classification import multiclass_perceptron_loss
28 | from optax.losses._classification import multiclass_sparsemax_loss
29 | from optax.losses._classification import perceptron_loss
30 | from optax.losses._classification import poly_loss_cross_entropy
31 | from optax.losses._classification import safe_softmax_cross_entropy
32 | from optax.losses._classification import sigmoid_binary_cross_entropy
33 | from optax.losses._classification import sigmoid_focal_loss
34 | from optax.losses._classification import softmax_cross_entropy
35 | # pylint: disable=line-too-long
36 | from optax.losses._classification import softmax_cross_entropy_with_integer_labels # noqa: E501
37 | # pylint: enable=line-too-long
38 | from optax.losses._classification import sparsemax_loss
39 | from optax.losses._fenchel_young import make_fenchel_young_loss
40 | from optax.losses._ranking import ranking_softmax_loss
41 | from optax.losses._regression import cosine_distance
42 | from optax.losses._regression import cosine_similarity
43 | from optax.losses._regression import huber_loss
44 | from optax.losses._regression import l2_loss
45 | from optax.losses._regression import log_cosh
46 | from optax.losses._regression import squared_error
47 | from optax.losses._segmentation import binary_dice_loss
48 | from optax.losses._segmentation import dice_loss
49 | from optax.losses._segmentation import multiclass_generalized_dice_loss
50 | from optax.losses._self_supervised import ntxent
51 | from optax.losses._self_supervised import triplet_margin_loss
52 | from optax.losses._smoothing import smooth_labels
53 |
--------------------------------------------------------------------------------
/optax/transforms/_layouts.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Wrappers changing the layouts of the tensors that transforms operate on."""
16 |
17 | import jax
18 | import jax.numpy as jnp
19 | import numpy as np
20 | from optax._src import base
21 |
22 |
23 | def flatten(
24 | inner: base.GradientTransformation,
25 | ) -> base.GradientTransformationExtraArgs:
26 | """Flattens parameters and gradients for init and update of inner transform.
27 |
28 | This can reduce the overhead of performing many calculations on lots of small
29 | variables, at the cost of slightly increased memory usage.
30 |
31 | Args:
32 | inner: Inner transformation to flatten inputs for.
33 |
34 | Returns:
35 | New :class:`optax.GradientTransformationExtraArgs`
36 | """
37 |
38 | inner = base.with_extra_args_support(inner)
39 |
40 | def _flatten(params):
41 | """Flattens and concatenates all tensors in params to a single vector."""
42 | params, _ = jax.tree.flatten(params)
43 | return jnp.concatenate([jnp.reshape(param, [-1]) for param in params])
44 |
45 | def _unflatten(updates, flat):
46 | """Extracts tensors from flat, using the structure and shapes of params."""
47 | updates_flat, treedef = jax.tree.flatten(updates)
48 | offsets = []
49 | for update in updates_flat:
50 | size = np.size(update)
51 | if offsets:
52 | offsets.append(size + offsets[-1])
53 | else:
54 | offsets.append(size)
55 | del offsets[-1]
56 | flat_split = jnp.split(flat, offsets)
57 | reshaped = [
58 | jnp.reshape(flat_update, update.shape)
59 | for flat_update, update in zip(flat_split, updates_flat)
60 | ]
61 | return jax.tree.unflatten(treedef, reshaped)
62 |
63 | def init_fn(params):
64 | flat = _flatten(params)
65 | return inner.init(flat)
66 |
67 | def update_fn(updates, state, params=None, **extra_args):
68 | if params is not None:
69 | params = _flatten(params)
70 | updates_flat, state = inner.update(
71 | _flatten(updates), state, params, **extra_args
72 | )
73 | updates = _unflatten(updates, updates_flat)
74 | return updates, state
75 |
76 | return base.GradientTransformationExtraArgs(init_fn, update_fn)
77 |
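A minimal usage sketch (not part of the module; parameter shapes are arbitrary) showing how the wrapper is typically applied through the public alias ``optax.flatten``:

    import jax
    import jax.numpy as jnp
    import optax

    params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((3,))}
    opt = optax.flatten(optax.adam(learning_rate=1e-3))

    state = opt.init(params)                       # inner Adam sees one flat vector
    grads = jax.tree.map(jnp.ones_like, params)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)  # updates keep the original structure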
--------------------------------------------------------------------------------
/optax/losses/_fenchel_young_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for fenchel young loss in `_fenchel_young.py`."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 | import jax.numpy as jnp
20 | from jax.scipy.special import logsumexp # pylint: disable=g-importing-member
21 |
22 | from optax._src import test_utils
23 | from optax.losses import _classification
24 | from optax.losses import _fenchel_young
25 |
26 |
27 | def one_hot_argmax(inputs: jnp.ndarray) -> jnp.ndarray:
28 | """An argmax one-hot function for arbitrary shapes."""
29 | inputs_flat = jnp.reshape(inputs, (-1))
30 | flat_one_hot = jax.nn.one_hot(jnp.argmax(inputs_flat), inputs_flat.shape[0])
31 | return jnp.reshape(flat_one_hot, inputs.shape)
32 |
33 |
34 | class FenchelYoungTest(absltest.TestCase):
35 |
36 | def test_fenchel_young_reg(self):
37 | # Checks the behavior of the Fenchel-Young loss.
38 | fy_loss = jax.jit(_fenchel_young.make_fenchel_young_loss(logsumexp))
39 | rng = jax.random.key(0)
40 | rngs = jax.random.split(rng, 2)
41 | theta_true = jax.random.uniform(rngs[0], (8, 5))
42 | y_true = jax.vmap(jax.nn.softmax)(theta_true)
43 | theta_random = jax.random.uniform(rngs[1], (8, 5))
44 | y_random = jax.vmap(jax.nn.softmax)(theta_random)
45 | grad_random = jax.vmap(jax.grad(fy_loss))(theta_random, y_true)
46 | # Checks that the gradient of the loss takes the correct form.
47 | test_utils.assert_trees_all_close(grad_random, y_random - y_true, rtol=1e-4)
48 | y_one_hot = jax.vmap(one_hot_argmax)(theta_true)
49 | int_one_hot = jnp.where(y_one_hot == 1.)[1]
50 | loss_one_hot = jax.vmap(fy_loss)(theta_random, y_one_hot)
51 | log_loss = jax.vmap(
52 | _classification.softmax_cross_entropy_with_integer_labels)(
53 | theta_random, int_one_hot)
54 | # Checks that the FY loss associated to logsumexp is correct.
55 | test_utils.assert_trees_all_close(loss_one_hot, log_loss, rtol=1e-4)
56 | # Checks that vmapping or not is equivalent.
57 | loss_one_hot_no_vmap = fy_loss(theta_random, y_one_hot)
58 | test_utils.assert_trees_all_close(
59 | loss_one_hot, loss_one_hot_no_vmap, rtol=1e-4)
60 |
61 |
62 | if __name__ == "__main__":
63 | absltest.main()
64 |
--------------------------------------------------------------------------------
/docs/api/projections.rst:
--------------------------------------------------------------------------------
1 | Projections
2 | ===========
3 |
4 | .. currentmodule:: optax.projections
5 |
6 | Projections can be used to perform constrained optimization.
7 | The Euclidean projection onto a set :math:`\mathcal{C}` is:
8 |
9 | .. math::
10 |
11 | \text{proj}_{\mathcal{C}}(u) :=
12 | \underset{v}{\text{argmin}} ~ \|u - v\|^2_2 \textrm{ subject to } v \in \mathcal{C}.
13 |
14 | For instance, here is an example of how to project parameters onto the non-negative orthant::
15 |
16 | >>> import optax
17 | >>> import jax
18 | >>> import jax.numpy as jnp
19 | >>> num_weights = 2
20 | >>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
21 | >>> ys = jnp.array([0.5, 0.8])
22 | >>> optimizer = optax.adam(learning_rate=1e-3)
23 | >>> params = {'w': jnp.zeros(num_weights)}
24 | >>> opt_state = optimizer.init(params)
25 | >>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
26 | >>> grads = jax.grad(loss)(params, xs, ys)
27 | >>> updates, opt_state = optimizer.update(grads, opt_state)
28 | >>> params = optax.apply_updates(params, updates)
29 | >>> params = optax.projections.projection_non_negative(params)
30 |
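As a concrete instance of the definition above, projecting onto the probability simplex returns the closest point with non-negative entries summing to one (a short illustrative sketch)::

    x = jnp.array([2.5, 0.3, -1.0])
    p = optax.projections.projection_simplex(x)
    # p has non-negative entries and sums to 1 (up to numerical precision).
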
31 | Available projections
32 | ~~~~~~~~~~~~~~~~~~~~~
33 | .. autosummary::
34 | projection_box
35 | projection_hypercube
36 | projection_l1_ball
37 | projection_l1_sphere
38 | projection_l2_ball
39 | projection_l2_sphere
40 | projection_linf_ball
41 | projection_non_negative
42 | projection_simplex
43 | projection_vector
44 | projection_hyperplane
45 | projection_halfspace
46 |
47 | Projection onto a box
48 | ~~~~~~~~~~~~~~~~~~~~~
49 | .. autofunction:: projection_box
50 |
51 | Projection onto a hypercube
52 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
53 | .. autofunction:: projection_hypercube
54 |
55 | Projection onto the L1 ball
56 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
57 | .. autofunction:: projection_l1_ball
58 |
59 | Projection onto the L1 sphere
60 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
61 | .. autofunction:: projection_l1_sphere
62 |
63 | Projection onto the L2 ball
64 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
65 | .. autofunction:: projection_l2_ball
66 |
67 | Projection onto the L2 sphere
68 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
69 | .. autofunction:: projection_l2_sphere
70 |
71 | Projection onto the L-infinity ball
72 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
73 | .. autofunction:: projection_linf_ball
74 |
75 | Projection onto the non-negative orthant
76 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
77 | .. autofunction:: projection_non_negative
78 |
79 | Projection onto a simplex
80 | ~~~~~~~~~~~~~~~~~~~~~~~~~
81 | .. autofunction:: projection_simplex
82 |
83 | Projection onto a vector
84 | ~~~~~~~~~~~~~~~~~~~~~~~~
85 | .. autofunction:: projection_vector
86 |
87 | Projection onto a hyperplane
88 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
89 | .. autofunction:: projection_hyperplane
90 |
91 | Projection onto a halfspace
92 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
93 | .. autofunction:: projection_halfspace
94 |
--------------------------------------------------------------------------------
/optax/contrib/_complex_valued_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods in `complex_valued.py`."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | from optax._src import transform
23 | from optax._src import update
24 | from optax.contrib import _complex_valued
25 |
26 |
27 | def _loss_fun_complex_to_real(z):
28 | return (z.conj() * z).real.sum()
29 |
30 |
31 | def _loss_fun_real_to_real(params):
32 | x, y = params
33 | return _loss_fun_complex_to_real(x + y * 1j)
34 |
35 |
36 | class ComplexValuedTest(parameterized.TestCase):
37 |
38 | @parameterized.named_parameters([
39 | ('adam', transform.scale_by_adam),
40 | ('param_block_norm', transform.scale_by_param_block_norm),
41 | ])
42 | def test_split_real_and_imaginary(self, scaler_constr):
43 |
44 | def do_update(loss_fun, optimizer, params, opt_state):
45 | loss, grads = jax.value_and_grad(loss_fun)(params)
46 | # Complex gradients need to be conjugated before being added to parameters
47 | grads = jax.tree.map(lambda x: x.conj(), grads)
48 | updates, opt_state = jax.jit(optimizer.update)(
49 | grads, opt_state, params
50 | )
51 | params = update.apply_updates(params, updates)
52 | return loss, grads, params, opt_state
53 |
54 | x = jnp.array([[0.1, 0.2, 0.3], [-0.1, -0.2, -0.3]])
55 | y = jnp.array([[0.5, -0.5, 0], [0.1, 0.3, -0.2]])
56 | z = x + y * 1j
57 |
58 | optimizer = scaler_constr()
59 | optimizer_complex = _complex_valued.split_real_and_imaginary(optimizer)
60 | opt_state = jax.jit(optimizer.init)((x, y))
61 | opt_state_complex = jax.jit(optimizer_complex.init)(z)
62 |
63 | # Check that the loss, the gradients, and the parameters are the same for
64 | # real-to-real and complex-to-real loss functions in each step
65 | for _ in range(3):
66 | loss, (gx, gy), (x, y), opt_state = do_update(
67 | _loss_fun_real_to_real, optimizer, (x, y), opt_state
68 | )
69 | loss_complex, gz, z, opt_state_complex = do_update(
70 | _loss_fun_complex_to_real, optimizer_complex, z, opt_state_complex
71 | )
72 | rtol = 1e-6
73 | np.testing.assert_allclose(loss, loss_complex, rtol=rtol)
74 | np.testing.assert_allclose(gx + gy * 1j, gz, rtol=rtol)
75 | np.testing.assert_allclose(x + y * 1j, z, rtol=rtol)
76 |
77 |
78 | if __name__ == '__main__':
79 | absltest.main()
80 |
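An editorial sketch of the wrapper exercised by this test (values are arbitrary): a complex parameter is optimized by transparently splitting it into real and imaginary parts.

    import jax
    import jax.numpy as jnp
    import optax

    z = jnp.array([1.0 + 2.0j, -0.5 + 0.3j])
    opt = optax.contrib.split_real_and_imaginary(optax.adam(1e-3))
    state = opt.init(z)

    loss = lambda z: (z.conj() * z).real.sum()
    grads = jax.grad(loss)(z).conj()  # conjugate the gradient, as in the test above
    updates, state = opt.update(grads, state, z)
    z = optax.apply_updates(z, updates)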
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | function cleanup {
17 | deactivate
18 | rm -r "${TEMP_DIR}"
19 | }
20 | trap cleanup EXIT
21 |
22 | REPO_DIR=$(pwd)
23 | TEMP_DIR=$(mktemp --directory)
24 |
25 | set -o errexit
26 | set -o nounset
27 | set -o pipefail
28 |
29 | # Install deps in a virtual env.
30 | python3 -m venv "${TEMP_DIR}/test_venv"
31 | source "${TEMP_DIR}/test_venv/bin/activate"
32 |
33 | # Run the linter first to check lint errors quickly
34 | python3 -m pip install --quiet --upgrade pip uv
35 | python3 -m uv pip install --quiet pre-commit
36 | pre-commit run -a
37 |
38 | # Install dependencies.
39 | python3 -m uv pip install --quiet --upgrade pip setuptools wheel
40 | python3 -m uv pip install --quiet --upgrade flake8 pytest-xdist pylint pylint-exit
41 | python3 -m uv pip install --quiet --editable ".[test]"
42 |
43 | # Install the requested JAX version
44 | if [ -z "${JAX_VERSION-}" ]; then
45 | : # use version installed in requirements above
46 | elif [ "$JAX_VERSION" = "newest" ]; then
47 | python3 -m uv pip install --quiet --upgrade jax jaxlib
48 | elif [ "$JAX_VERSION" = "nightly" ]; then
49 | python3 -m uv pip install --quiet --upgrade --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
50 | else
51 | python3 -m uv pip install --quiet "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"
52 | fi
53 |
54 | # Ensure optax was not installed by one of the dependencies above,
55 | # since if it is, the tests below will be run against that version instead of
56 | # the branch build.
57 | python3 -m uv pip uninstall optax
58 |
59 | # Lint with flake8.
60 | python3 -m flake8 --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
61 |
62 | # Lint with pylint.
63 | pylint optax
64 |
65 | # Build the package.
66 | python3 -m uv pip install --quiet build
67 | python3 -m build
68 | python3 -m pip wheel --no-deps dist/optax-*.tar.gz --wheel-dir "${TEMP_DIR}"
69 | python3 -m pip install --quiet "${TEMP_DIR}/optax-"*.whl
70 |
71 | # Check types with pytype.
72 | python3 -m pip install --quiet pytype
73 | pytype "optax" -j auto --keep-going --disable import-error
74 |
75 | # Run tests using pytest.
76 | # Change directory to avoid importing the package from repo root.
77 | cd "${TEMP_DIR}"
78 | python3 -m pytest --numprocesses auto --pyargs optax
79 | #python3 -m pytest --numprocesses 8 --pyargs optax
80 | cd "${REPO_DIR}"
81 |
82 | # Build Sphinx docs.
83 | python3 -m uv pip install --quiet --editable ".[docs]"
84 | cd docs
85 | make html
86 | make doctest # run doctests
87 | cd ..
88 |
89 | echo "All tests passed. Congrats!"
90 |
--------------------------------------------------------------------------------
/docs/ext/coverage_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Asserts all public symbols are covered in the docs."""
16 |
17 | from collections.abc import Mapping
18 | import inspect
19 | import types
20 | from typing import Any, Sequence, Tuple
21 |
22 | import optax
23 | from sphinx import application
24 | from sphinx import builders
25 | from sphinx import errors
26 |
27 |
28 | def find_internal_python_modules(
29 | root_module: types.ModuleType,
30 | ) -> Sequence[Tuple[str, types.ModuleType]]:
31 | """Returns `(name, module)` for all Optax submodules under `root_module`."""
32 | modules = set([(root_module.__name__, root_module)])
33 | visited = set()
34 | to_visit = [root_module]
35 |
36 | while to_visit:
37 | mod = to_visit.pop()
38 | visited.add(mod)
39 |
40 | for name in dir(mod):
41 | obj = getattr(mod, name)
42 | if inspect.ismodule(obj) and obj not in visited:
43 | if obj.__name__.startswith("optax"):
44 | if "_src" not in obj.__name__:
45 | to_visit.append(obj)
46 | modules.add((obj.__name__, obj))
47 |
48 | return sorted(modules)
49 |
50 |
51 | def optax_public_symbols():
52 | """Collect all optax public symbols."""
53 | names = set()
54 | for module_name, module in find_internal_python_modules(optax):
55 | for name in module.__all__:
56 | names.add(module_name + "." + name)
57 | return names
58 |
59 |
60 | class OptaxCoverageCheck(builders.Builder):
61 | """Builder that checks all public symbols are included."""
62 |
63 | name = "coverage_check"
64 |
65 | def get_outdated_docs(self) -> str:
66 | return "coverage_check"
67 |
68 | def write(self, *ignored: Any) -> None: # pylint: disable=overridden-final-method
69 | pass
70 |
71 | def finish(self) -> None:
72 | documented_objects = frozenset(self.env.domaindata["py"]["objects"]) # pytype: disable=attribute-error
73 | undocumented_objects = set(optax_public_symbols()) - documented_objects
74 | if undocumented_objects:
75 | undocumented_objects = tuple(sorted(undocumented_objects))
76 | raise errors.SphinxError(
77 | "All public symbols must be included in our documentation, did you "
78 | "forget to add an entry to `api.rst`?\n"
79 | f"Undocumented symbols: {undocumented_objects}")
80 |
81 | def get_target_uri(self, docname, typ=None):
82 | raise NotImplementedError
83 |
84 | def write_doc(self, docname, doctree):
85 | raise NotImplementedError
86 |
87 |
88 | def setup(app: application.Sphinx) -> Mapping[str, Any]:
89 | app.add_builder(OptaxCoverageCheck)
90 | return {"version": optax.__version__, "parallel_read_safe": True}
91 |
--------------------------------------------------------------------------------
/optax/_src/deprecations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Deprecation helpers."""
16 |
17 | import functools
18 | from typing import Any, Callable, Optional
19 | import warnings
20 |
21 |
22 | # Module __getattr__ factory that warns if deprecated names are used.
23 | #
24 | # Example usage:
25 | # from optax.contrib.dpsgd import dpsgd as _deprecated_dpsgd
26 | #
27 | # _deprecations = {
28 | # # Added Apr 2024:
29 | # "dpsgd": (
30 | # "optax.dpsgd is deprecated. Use optax.contrib.dpsgd instead.",
31 | # _deprecated_dpsgd,
32 | # ),
33 | # }
34 | #
35 | # from optax._src.deprecations import deprecation_getattr as _deprecation_getattr # pylint: disable=line-too-long # noqa: E501
36 | # __getattr__ = _deprecation_getattr(__name__, _deprecations)
37 | # del _deprecation_getattr
38 |
39 |
40 | # Note that type checkers such as Pytype will not know about the deprecated
41 | # names. If it is desirable that a deprecated name is known to the type checker,
42 | # add:
43 | # import typing
44 | # if typing.TYPE_CHECKING:
45 | # from optax.contrib import dpsgd
46 | # del typing
47 |
48 |
49 | def deprecation_getattr(module, deprecations):
50 | def _getattr(name):
51 | if name in deprecations:
52 | message, fn = deprecations[name]
53 | if fn is None: # Is the deprecation accelerated?
54 | raise AttributeError(message)
55 | warnings.warn(message, DeprecationWarning, stacklevel=2)
56 | return fn
57 | raise AttributeError(f'module {module!r} has no attribute {name!r}')
58 |
59 | return _getattr
60 |
61 |
62 | def warn_deprecated_function(
63 | fun: Callable[..., Any],
64 | replacement: Optional[str] = None,
65 | version_removed: Optional[str] = None,
66 | ) -> Callable[..., Any]:
67 | """A decorator to mark a function definition as deprecated.
68 |
69 | Args:
70 | fun: the deprecated function.
71 | replacement: name of the function to be used instead.
72 | version_removed: version of optax in which the function was/will be removed.
73 |
74 | Returns:
75 | The wrapped function.
76 |
77 | Example usage:
78 | >>> @functools.partial(warn_deprecated_function, replacement='g')
79 | ... def f(a, b):
80 | ... return a + b
81 | """
82 | if hasattr(fun, '__name__'):
83 | warning_message = f'The function {fun.__name__} is deprecated.'
84 | else:
85 | warning_message = 'The function is deprecated.'
86 | if replacement:
87 | warning_message += f' Please use {replacement} instead.'
88 | if version_removed:
89 | warning_message += (
90 | f' This function will be/was removed in optax {version_removed}.'
91 | )
92 |
93 | @functools.wraps(fun)
94 | def new_fun(*args, **kwargs):
95 | warnings.warn(warning_message, category=DeprecationWarning, stacklevel=2)
96 | return fun(*args, **kwargs)
97 |
98 | return new_fun
99 |
--------------------------------------------------------------------------------
/docs/api/contrib.rst:
--------------------------------------------------------------------------------
1 | 🔧 Contrib
2 | ===============
3 |
4 | Algorithms or wrappers that don't (yet) meet the :ref:`inclusion_criteria` or
5 | are not supported by the main library.
6 |
7 | .. currentmodule:: optax.contrib
8 |
9 | .. autosummary::
10 | acprop
11 | ademamix
12 | adopt
13 | simplified_ademamix
14 | cocob
15 | COCOBState
16 | dadapt_adamw
17 | DAdaptAdamWState
18 | differentially_private_aggregate
19 | DifferentiallyPrivateAggregateState
20 | dog
21 | DoGState
22 | dowg
23 | DoWGState
24 | dpsgd
25 | mechanize
26 | MechanicState
27 | momo
28 | MomoState
29 | momo_adam
30 | MomoAdamState
31 | muon
32 | MuonState
33 | prodigy
34 | ProdigyState
35 | sam
36 | SAMState
37 | schedule_free
38 | schedule_free_adamw
39 | schedule_free_eval_params
40 | schedule_free_sgd
41 | ScheduleFreeState
42 | sophia
43 | SophiaState
44 | split_real_and_imaginary
45 | SplitRealAndImaginaryState
46 |
47 | AdEMAMix
48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
49 | .. autofunction:: ademamix
50 | .. autofunction:: scale_by_ademamix
51 | .. autoclass:: ScaleByAdemamixState
52 |
53 | Simplified AdEMAMix
54 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
55 | .. autofunction:: simplified_ademamix
56 | .. autofunction:: scale_by_simplified_ademamix
57 | .. autoclass:: ScaleBySimplifiedAdEMAMixState
58 |
59 | ADOPT
60 | ~~~~~
61 | .. autofunction:: adopt
62 | .. autofunction:: scale_by_adopt
63 |
64 | Asynchronous-centering-Prop
65 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
66 | .. autofunction:: acprop
67 | .. autofunction:: scale_by_acprop
68 |
69 | Complex-valued Optimization
70 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
71 | .. autofunction:: split_real_and_imaginary
72 | .. autoclass:: SplitRealAndImaginaryState
73 |
74 | Continuous coin betting
75 | ~~~~~~~~~~~~~~~~~~~~~~~
76 | .. autofunction:: cocob
77 | .. autoclass:: COCOBState
78 |
79 | D-adaptation
80 | ~~~~~~~~~~~~
81 | .. autofunction:: dadapt_adamw
82 | .. autoclass:: DAdaptAdamWState
83 |
84 | Differentially Private Aggregate
85 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
86 | .. autofunction:: differentially_private_aggregate
87 | .. autoclass:: DifferentiallyPrivateAggregateState
88 | .. autofunction:: dpsgd
89 |
90 | Distance over Gradients
91 | ~~~~~~~~~~~~~~~~~~~~~~~
92 | .. autofunction:: dog
93 | .. autoclass:: DoGState
94 | .. autofunction:: dowg
95 | .. autoclass:: DoWGState
96 |
97 | Mechanize
98 | ~~~~~~~~~
99 | .. autofunction:: mechanize
100 | .. autoclass:: MechanicState
101 |
102 | Momo
103 | ~~~~
104 | .. autofunction:: momo
105 | .. autoclass:: MomoState
106 | .. autofunction:: momo_adam
107 | .. autoclass:: MomoAdamState
108 |
109 | Muon
110 | ~~~~
111 | .. autofunction:: muon
112 | .. autofunction:: scale_by_muon
113 | .. autoclass:: MuonState
114 |
115 | Prodigy
116 | ~~~~~~~
117 | .. autofunction:: prodigy
118 | .. autoclass:: ProdigyState
119 |
120 | Schedule-Free
121 | ~~~~~~~~~~~~~
122 | .. autofunction:: schedule_free
123 | .. autofunction:: schedule_free_adamw
124 | .. autofunction:: schedule_free_eval_params
125 | .. autofunction:: schedule_free_sgd
126 | .. autoclass:: ScheduleFreeState
127 |
128 | Sharpness aware minimization
129 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
130 | .. autofunction:: sam
131 | .. autoclass:: SAMState
132 |
133 | Sophia
134 | ~~~~~~
135 | .. autofunction:: hutchinson_estimator_diag_hessian
136 | .. autoclass:: HutchinsonState
137 | .. autofunction:: sophia
138 | .. autoclass:: SophiaState
139 |
--------------------------------------------------------------------------------
/optax/_src/update_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods in `update.py`."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | from optax._src import test_utils
22 | from optax._src import update
23 |
24 |
25 | class UpdateTest(parameterized.TestCase):
26 |
27 | def test_apply_updates(self):
28 | params = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
29 | grads = jax.tree.map(lambda t: 2 * t, params)
30 | exp_params = jax.tree.map(lambda t: 3 * t, params)
31 | new_params = jax.jit(update.apply_updates)(params, grads)
32 |
33 | test_utils.assert_trees_all_close(
34 | exp_params, new_params, atol=1e-10, rtol=1e-5)
35 |
36 | def test_apply_updates_mixed_precision(self):
37 | params = (
38 | {'a': jnp.ones((3, 2), dtype=jnp.bfloat16)},
39 | jnp.ones((1,), dtype=jnp.bfloat16),
40 | )
41 | grads = jax.tree.map(lambda t: (2 * t).astype(jnp.float32), params)
42 | new_params = jax.jit(update.apply_updates)(params, grads)
43 |
44 | for leaf in jax.tree.leaves(new_params):
45 | assert leaf.dtype == jnp.bfloat16
46 |
47 | def test_incremental_update(self):
48 | params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
49 | params_2 = jax.tree.map(lambda t: 2 * t, params_1)
50 | exp_params = jax.tree.map(lambda t: 1.5 * t, params_1)
51 | new_params = jax.jit(update.incremental_update)(
52 | params_2, params_1, 0.5
53 | )
54 |
55 | test_utils.assert_trees_all_close(
56 | exp_params, new_params, atol=1e-10, rtol=1e-5)
57 |
58 | def test_periodic_update(self):
59 | params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
60 | params_2 = jax.tree.map(lambda t: 2 * t, params_1)
61 |
62 | update_period = 5
63 | update_fn = jax.jit(update.periodic_update)
64 |
65 | for j in range(3):
66 | for i in range(1, update_period):
67 | new_params = update_fn(
68 | params_2, params_1, j * update_period + i, update_period
69 | )
70 | test_utils.assert_trees_all_close(
71 | params_1, new_params, atol=1e-10, rtol=1e-5)
72 |
73 | new_params = update_fn(
74 | params_2, params_1, (j + 1) * update_period, update_period
75 | )
76 | test_utils.assert_trees_all_close(
77 | params_2, new_params, atol=1e-10, rtol=1e-5)
78 |
79 | @parameterized.named_parameters(
80 | {'testcase_name': 'apply_updates', 'operation': update.apply_updates},
81 | {
82 | 'testcase_name': 'incremental_update',
83 | 'operation': lambda x, y: update.incremental_update(x, y, 1),
84 | },
85 | )
86 | def test_none_argument(self, operation):
87 | x = jnp.array([1.0, 2.0, 3.0])
88 | operation(None, x)
89 |
90 |
91 | if __name__ == '__main__':
92 | absltest.main()
93 |
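For reference, a sketch of the semantics these tests check (illustrative values): ``incremental_update`` interpolates between trees, and ``periodic_update`` copies the new tree only when the step count hits a multiple of the period.

    import jax.numpy as jnp
    import optax

    old = {'w': jnp.zeros(3)}
    new = {'w': jnp.ones(3)}

    ema = optax.incremental_update(new, old, 0.5)  # 0.5 * new + 0.5 * old
    kept = optax.periodic_update(new, old, 3, 5)   # step 3 of period 5: returns `old`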
--------------------------------------------------------------------------------
/optax/contrib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contributed optimizers in Optax."""
16 |
17 | # pylint: disable=g-importing-member
18 |
19 | from optax.contrib._acprop import acprop
20 | from optax.contrib._acprop import scale_by_acprop
21 | from optax.contrib._ademamix import ademamix
22 | from optax.contrib._ademamix import scale_by_ademamix
23 | from optax.contrib._ademamix import scale_by_simplified_ademamix
24 | from optax.contrib._ademamix import ScaleByAdemamixState
25 | from optax.contrib._ademamix import ScaleBySimplifiedAdEMAMixState
26 | from optax.contrib._ademamix import simplified_ademamix
27 | from optax.contrib._adopt import adopt
28 | from optax.contrib._adopt import scale_by_adopt
29 | from optax.contrib._cocob import cocob
30 | from optax.contrib._cocob import COCOBState
31 | from optax.contrib._cocob import scale_by_cocob
32 | from optax.contrib._complex_valued import split_real_and_imaginary
33 | from optax.contrib._complex_valued import SplitRealAndImaginaryState
34 | from optax.contrib._dadapt_adamw import dadapt_adamw
35 | from optax.contrib._dadapt_adamw import DAdaptAdamWState
36 | from optax.contrib._dog import dog
37 | from optax.contrib._dog import DoGState
38 | from optax.contrib._dog import dowg
39 | from optax.contrib._dog import DoWGState
40 | from optax.contrib._mechanic import MechanicState
41 | from optax.contrib._mechanic import mechanize
42 | from optax.contrib._momo import momo
43 | from optax.contrib._momo import momo_adam
44 | from optax.contrib._momo import MomoAdamState
45 | from optax.contrib._momo import MomoState
46 | from optax.contrib._muon import muon
47 | from optax.contrib._muon import MuonDimensionNumbers
48 | from optax.contrib._muon import MuonState
49 | from optax.contrib._muon import scale_by_muon
50 | from optax.contrib._privacy import differentially_private_aggregate
51 | from optax.contrib._privacy import DifferentiallyPrivateAggregateState
52 | from optax.contrib._privacy import dpsgd
53 | from optax.contrib._prodigy import prodigy
54 | from optax.contrib._prodigy import ProdigyState
55 | from optax.contrib._reduce_on_plateau import reduce_on_plateau
56 | from optax.contrib._reduce_on_plateau import ReduceLROnPlateauState
57 | from optax.contrib._sam import normalize
58 | from optax.contrib._sam import NormalizeState
59 | from optax.contrib._sam import sam
60 | from optax.contrib._sam import SAMState
61 | from optax.contrib._schedule_free import schedule_free
62 | from optax.contrib._schedule_free import schedule_free_adamw
63 | from optax.contrib._schedule_free import schedule_free_eval_params
64 | from optax.contrib._schedule_free import schedule_free_sgd
65 | from optax.contrib._schedule_free import ScheduleFreeState
66 | from optax.contrib._sophia import hutchinson_estimator_diag_hessian
67 | from optax.contrib._sophia import HutchinsonState
68 | from optax.contrib._sophia import sophia
69 | from optax.contrib._sophia import SophiaState
70 |
--------------------------------------------------------------------------------
/optax/transforms/_constraining.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Gradient transformations used to enforce specific constraints."""
16 |
17 | from typing import Any, NamedTuple
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | from optax._src import base
22 |
23 |
24 | NonNegativeParamsState = base.EmptyState
25 |
26 |
27 | def keep_params_nonnegative() -> base.GradientTransformation:
28 | """Modifies the updates to keep parameters non-negative, i.e. >= 0.
29 |
30 | This transformation ensures that parameters after the update will be
31 | larger than or equal to zero.
32 | In a chain of transformations, this should be the last one.
33 |
34 | Returns:
35 | A :class:`optax.GradientTransformation` object.
36 |
37 | .. warning::
38 | The transformation expects input params to be non-negative.
39 | When params are negative, the transformed updates will move them to 0.
40 | """
41 |
42 | def init_fn(params):
43 | del params
44 | return NonNegativeParamsState()
45 |
46 | def update_fn(updates, state, params):
47 | if params is None:
48 | raise ValueError(base.NO_PARAMS_MSG)
49 |
50 | updates = jax.tree.map(
51 | lambda p, u: None if p is None else jnp.where((p + u) < 0.0, -p, u),
52 | params,
53 | updates,
54 | is_leaf=lambda x: x is None,
55 | )
56 | return updates, state
57 |
58 | return base.GradientTransformation(init_fn, update_fn)
59 |
60 |
61 | class ZeroNansState(NamedTuple):
62 | """Contains a tree.
63 |
64 | The entry `found_nan` has the same tree structure as that of the parameters.
65 | Each leaf is a single boolean which contains True iff a NaN was detected in
66 | the corresponding parameter array at the last call to `update`.
67 | """
68 |
69 | found_nan: Any
70 |
71 |
72 | def zero_nans() -> base.GradientTransformation:
73 | """A transformation which replaces NaNs with 0.
74 |
75 | The state of the transformation has the same tree structure as that of the
76 | parameters. Each leaf is a single boolean set to True iff a NaN was detected
77 | in the update for the corresponding parameter at the last call to ``update``.
78 | This state is not used by the transformation internally, but lets users be
79 | aware when NaNs have been zeroed out.
80 |
81 | Returns:
82 | A :class:`optax.GradientTransformation`.
83 | """
84 |
85 | def init_fn(params):
86 | return ZeroNansState(
87 | found_nan=jax.tree.map(
88 | lambda p: jnp.array(False, dtype=jnp.bool_), params
89 | )
90 | )
91 |
92 | def update_fn(updates, opt_state, params=None):
93 | del params, opt_state
94 | opt_state = ZeroNansState(
95 | found_nan=jax.tree.map(lambda p: jnp.any(jnp.isnan(p)), updates)
96 | )
97 | updates = jax.tree.map(
98 | lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates
99 | )
100 | return updates, opt_state
101 |
102 | return base.GradientTransformation(init=init_fn, update=update_fn)
103 |
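An editorial usage sketch (hypothetical values): both transformations are exposed at the top level and are typically combined with an optimizer via ``optax.chain``, with ``keep_params_nonnegative`` placed last as the docstring advises.

    import jax.numpy as jnp
    import optax

    opt = optax.chain(
        optax.zero_nans(),                # replace NaNs in incoming gradients with 0
        optax.adam(learning_rate=1e-2),
        optax.keep_params_nonnegative(),  # clamp updates so params stay >= 0
    )

    params = {'w': jnp.array([0.1, 0.2])}
    state = opt.init(params)
    grads = {'w': jnp.array([jnp.nan, 1.0])}
    updates, state = opt.update(grads, state, params)  # params needed by the constraint
    params = optax.apply_updates(params, updates)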
--------------------------------------------------------------------------------
/docs/api/losses.rst:
--------------------------------------------------------------------------------
1 | Losses
2 | ======
3 |
4 | .. currentmodule:: optax.losses
5 |
6 | .. autosummary::
7 | binary_dice_loss
8 | convex_kl_divergence
9 | cosine_distance
10 | cosine_similarity
11 | ctc_loss
12 | ctc_loss_with_forward_probs
13 | dice_loss
14 | generalized_kl_divergence
15 | hinge_loss
16 | huber_loss
17 | kl_divergence
18 | kl_divergence_with_log_targets
19 | l2_loss
20 | log_cosh
21 | make_fenchel_young_loss
22 | multiclass_generalized_dice_loss
23 | multiclass_hinge_loss
24 | multiclass_perceptron_loss
25 | multiclass_sparsemax_loss
26 | ntxent
27 | perceptron_loss
28 | poly_loss_cross_entropy
29 | ranking_softmax_loss
30 | safe_softmax_cross_entropy
31 | sigmoid_binary_cross_entropy
32 | sigmoid_focal_loss
33 | smooth_labels
34 | softmax_cross_entropy
35 | softmax_cross_entropy_with_integer_labels
36 | sparsemax_loss
37 | squared_error
38 | triplet_margin_loss
39 |
40 |
41 | Convex Kullback Leibler divergence
42 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
43 | .. autofunction:: convex_kl_divergence
44 |
45 | Cosine distance
46 | ~~~~~~~~~~~~~~~
47 | .. autofunction:: cosine_distance
48 |
49 | Cosine similarity
50 | ~~~~~~~~~~~~~~~~~
51 | .. autofunction:: cosine_similarity
52 |
53 | Connectionist temporal classification loss
54 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
55 | .. autofunction:: ctc_loss
56 | .. autofunction:: ctc_loss_with_forward_probs
57 |
58 | Dice loss
59 | ~~~~~~~~~
60 | .. autofunction:: dice_loss
61 | .. autofunction:: multiclass_generalized_dice_loss
62 | .. autofunction:: binary_dice_loss
63 |
64 | Fenchel Young loss
65 | ~~~~~~~~~~~~~~~~~~
66 | .. autofunction:: make_fenchel_young_loss
67 |
68 | Generalized Kullback-Leibler divergence
69 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
70 | .. autofunction:: generalized_kl_divergence
71 |
72 | Hinge loss
73 | ~~~~~~~~~~
74 | .. autofunction:: hinge_loss
75 | .. autofunction:: multiclass_hinge_loss
76 |
77 | Huber loss
78 | ~~~~~~~~~~
79 | .. autofunction:: huber_loss
80 |
81 | Kullback-Leibler divergence
82 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
83 | .. autofunction:: kl_divergence
84 | .. autofunction:: kl_divergence_with_log_targets
85 |
86 | L2 Squared loss
87 | ~~~~~~~~~~~~~~~
88 | .. autofunction:: squared_error
89 | .. autofunction:: l2_loss
90 |
91 | Log hyperbolic cosine loss
92 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
93 | .. autofunction:: log_cosh
94 |
95 | Normalized temperature scaled cross-entropy (NT-Xent) loss
96 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
97 | .. autofunction:: ntxent
98 |
99 | Poly loss cross-entropy
100 | ~~~~~~~~~~~~~~~~~~~~~~~
101 | .. autofunction:: poly_loss_cross_entropy
102 |
103 | Perceptron
104 | ~~~~~~~~~~~
105 | .. autofunction:: perceptron_loss
106 | .. autofunction:: multiclass_perceptron_loss
107 |
108 | Ranking softmax loss
109 | ~~~~~~~~~~~~~~~~~~~~
110 | .. autofunction:: ranking_softmax_loss
111 |
112 | Sigmoid binary cross-entropy
113 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
114 | .. autofunction:: sigmoid_binary_cross_entropy
115 |
116 | Sigmoid focal loss
117 | ~~~~~~~~~~~~~~~~~~
118 | .. autofunction:: sigmoid_focal_loss
119 |
120 | Smoothing labels
121 | ~~~~~~~~~~~~~~~~
122 | .. autofunction:: smooth_labels
123 |
124 | Soft-max cross-entropy
125 | ~~~~~~~~~~~~~~~~~~~~~~
126 | .. autofunction:: safe_softmax_cross_entropy
127 | .. autofunction:: softmax_cross_entropy
128 | .. autofunction:: softmax_cross_entropy_with_integer_labels
129 |
130 | Sparsemax
131 | ~~~~~~~~~
132 | .. autofunction:: sparsemax_loss
133 | .. autofunction:: multiclass_sparsemax_loss
134 |
135 | Triplet margin loss
136 | ~~~~~~~~~~~~~~~~~~~
137 | .. autofunction:: triplet_margin_loss
138 |
--------------------------------------------------------------------------------
/optax/contrib/_sam_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for the SAM optimizer in `sam.py`."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | from optax._src import alias
22 | from optax._src import combine
23 | from optax._src import numerics
24 | from optax._src import test_utils
25 | from optax._src import update
26 | from optax.contrib import _sam
27 | import optax.tree
28 |
29 | _BASE_OPTIMIZERS_UNDER_TEST = [
30 | {'base_opt_name': 'sgd', 'base_opt_kwargs': {'learning_rate': 1e-3}},
31 | ]
32 | _ADVERSARIAL_OPTIMIZERS_UNDER_TEST = [
33 | {'adv_opt_name': 'sgd', 'adv_opt_kwargs': {'learning_rate': 1e-5}},
34 | {'adv_opt_name': 'adam', 'adv_opt_kwargs': {'learning_rate': 1e-4}},
35 | ]
36 |
37 |
38 | def _setup_parabola(dtype):
39 | """Quadratic function as an optimization target."""
40 | initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
41 | final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)
42 |
43 | @jax.grad
44 | def get_updates(params):
45 | return jnp.sum(numerics.abs_sq(params - final_params))
46 |
47 | return initial_params, final_params, get_updates
48 |
49 |
50 | class SAMTest(parameterized.TestCase):
51 |
52 | @parameterized.product(
53 | _BASE_OPTIMIZERS_UNDER_TEST,
54 | _ADVERSARIAL_OPTIMIZERS_UNDER_TEST,
55 | sync_period=(2,),
56 | target=(_setup_parabola,),
57 | dtype=('float32',),
58 | opaque_mode=(False, True),
59 | )
60 | def test_optimization(
61 | self,
62 | base_opt_name,
63 | base_opt_kwargs,
64 | adv_opt_name,
65 | adv_opt_kwargs,
66 | sync_period,
67 | target,
68 | dtype,
69 | opaque_mode,
70 | ):
71 | dtype = jnp.dtype(dtype)
72 | base_opt = getattr(alias, base_opt_name)(**base_opt_kwargs)
73 | adv_opt = combine.chain(
74 | _sam.normalize(), getattr(alias, adv_opt_name)(**adv_opt_kwargs)
75 | )
76 | opt = _sam.sam(
77 | base_opt, adv_opt, sync_period=sync_period, opaque_mode=opaque_mode
78 | )
79 | initial_params, final_params, get_updates = target(dtype)
80 |
81 | if opaque_mode:
82 | update_kwargs = {'grad_fn': lambda p, _: get_updates(p)}
83 | else:
84 | update_kwargs = {}
85 |
86 | @jax.jit
87 | def step(params, state):
88 | updates = get_updates(params)
89 | updates, state = opt.update(updates, state, params, **update_kwargs)
90 | params = update.apply_updates(params, updates)
91 | return params, state
92 |
93 | params = initial_params
94 | state = opt.init(params)
95 | # A no-op change, to verify that tree map works.
96 | state = optax.tree.map_params(opt, lambda v: v, state)
97 |
98 | for _ in range(25000 * sync_period):
99 | params, state = step(params, state)
100 |
101 | test_utils.assert_trees_all_close(
102 | params, final_params, rtol=3e-2, atol=3e-2)
103 |
104 |
105 | if __name__ == '__main__':
106 | absltest.main()
107 |
--------------------------------------------------------------------------------
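
For readers unfamiliar with the wrapper being tested, here is a minimal usage sketch outside the test harness, using the public `optax.contrib.sam` and `optax.contrib.normalize` entry points; the learning rates and `sync_period` below are illustrative values, not recommendations.

import jax
import jax.numpy as jnp
import optax
import optax.contrib

def loss_fn(params):
  return jnp.sum(params ** 2)

# The base optimizer takes the outer step; the adversarial optimizer computes
# the (normalized) ascent direction used for the SAM perturbation.
base_opt = optax.sgd(learning_rate=1e-3)
adv_opt = optax.chain(optax.contrib.normalize(), optax.sgd(learning_rate=1e-5))
opt = optax.contrib.sam(base_opt, adv_opt, sync_period=2)

params = jnp.array([1.0, -2.0, 3.0])
state = opt.init(params)

@jax.jit
def step(params, state):
  grads = jax.grad(loss_fn)(params)
  updates, state = opt.update(grads, state, params)
  return optax.apply_updates(params, updates), state

for _ in range(10):
  params, state = step(params, state)
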
/optax/_src/factorized_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods in `factorized.py`."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | from optax._src import factorized
22 | from optax._src import test_utils
23 | from optax.transforms import _accumulation
24 |
25 |
26 | class FactorizedTest(parameterized.TestCase):
27 |
28 | def setUp(self):
29 | super().setUp()
30 | self.init_params = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
31 | self.per_step_updates = (jnp.array([500.0, 5.0]), jnp.array([300.0, 3.0]))
32 |
33 | def test_scale_by_factored_rms(self):
34 | params = self.init_params
35 |
36 | scaler = factorized.scale_by_factored_rms()
37 | init_fn = jax.jit(scaler.init)
38 | transform_fn = jax.jit(scaler.update)
39 |
40 | state = init_fn(params)
41 | test_utils.assert_tree_all_finite(state)
42 |
43 | updates, state = transform_fn(self.per_step_updates, state, params)
44 | test_utils.assert_tree_all_finite((params, updates, state))
45 | test_utils.assert_trees_all_equal_shapes(params, updates)
46 |
47 | @parameterized.product(
48 | factorized_dims=(True, False), dtype=('bfloat16', 'float32')
49 | )
50 | def test_preserve_dtype(self, factorized_dims: bool, dtype: str):
51 | """Test that the optimizer returns updates of same dtype as params."""
52 | dtype = jnp.dtype(dtype)
53 | opt = factorized.scale_by_factored_rms()
54 | fun = lambda x: jnp.sum(x**2)
55 |
56 | if factorized_dims:
57 |       # The updates are factored only for large enough parameters; the
58 |       # default min_dim_size_to_factor is 128, so we use 129 here.
59 | params = jnp.ones((129, 129), dtype=dtype)
60 | else:
61 | params = jnp.array([1.0, 2.0], dtype=dtype)
62 | grads = jax.grad(fun)(params)
63 | state = jax.jit(opt.init)(params)
64 | updates, _ = jax.jit(opt.update)(grads, state, params)
65 | self.assertEqual(updates.dtype, params.dtype)
66 |
67 | @parameterized.product(
68 | factorized_dims=(True, False), dtype=('bfloat16', 'float32')
69 | )
70 | def test_gradient_accumulation(self, factorized_dims, dtype):
71 | """Test that the optimizers can safely be used with optax.MultiSteps."""
72 | # Checks if https://github.com/google-deepmind/optax/issues/377 is fixed.
73 | dtype = jnp.dtype(dtype)
74 | base_opt = factorized.scale_by_factored_rms()
75 | opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4)
76 |
77 | fun = lambda x: jnp.sum(x**2)
78 |
79 | if factorized_dims:
80 |       # The updates are factored only for large enough parameters; the
81 |       # default min_dim_size_to_factor is 128, so we use 129 here.
82 | params = jnp.ones((129, 129), dtype=dtype)
83 | else:
84 | params = jnp.array([1.0, 2.0], dtype=dtype)
85 | grads = jax.grad(fun)(params)
86 | state = jax.jit(opt.init)(params)
87 | updates, _ = jax.jit(opt.update)(grads, state, params)
88 | test_utils.assert_trees_all_equal(updates, jnp.zeros_like(grads))
89 |
90 |
91 | if __name__ == '__main__':
92 | absltest.main()
93 |
--------------------------------------------------------------------------------
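
A short sketch of the accumulation pattern exercised above, assuming the public aliases `optax.scale_by_factored_rms` and `optax.MultiSteps`; it illustrates that non-sync steps emit all-zero updates.

import jax
import jax.numpy as jnp
import optax

# Second-moment statistics are factored only for dimensions of size at least
# min_dim_size_to_factor (128 by default), hence the 129x129 parameter below.
opt = optax.MultiSteps(optax.scale_by_factored_rms(), every_k_schedule=4)

params = jnp.ones((129, 129))
state = opt.init(params)
grads = jax.grad(lambda x: jnp.sum(x ** 2))(params)

# The first accumulation step emits all-zero updates; real updates are only
# emitted once every_k_schedule gradients have been accumulated.
updates, state = opt.update(grads, state, params)
print(bool(jnp.all(updates == 0)))  # True
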
/.github/check_license_headers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Check all Python files in the optax directory for license header."""
16 |
17 | from concurrent import futures
18 | import logging
19 | import pathlib
20 | import pprint
21 | import re
22 | import sys
23 |
24 | logger = logging.getLogger(pathlib.Path(__file__).name)
25 | logging.basicConfig(format=logging.BASIC_FORMAT)
26 | logger.setLevel("DEBUG")
27 |
28 | # pylint: disable=line-too-long
29 | LICENSE_PATTERN = (
30 | "(# (pylint|coding).*\n)*"
31 | "# Copyright 20[0-9][0-9] DeepMind Technologies Limited. All Rights Reserved.\n" # noqa: E501
32 | "#\n"
33 | "# Licensed under the Apache License, Version 2.0 \\(the \"License\"\\);\n"
34 | "# you may not use this file except in compliance with the License.\n"
35 | "# You may obtain a copy of the License at\n"
36 | "#\n"
37 | "# http://www.apache.org/licenses/LICENSE-2.0\n"
38 | "#\n"
39 | "# Unless required by applicable law or agreed to in writing, software\n"
40 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n"
41 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" # noqa: E501
42 | "# See the License for the specific language governing permissions and\n"
43 | "# limitations under the License.\n"
44 | "# ==============================================================================\n" # noqa: E501
45 | ".*"
46 | )
47 | # pylint: enable=line-too-long
48 |
49 | LICENSE_TEMPLATE = """
50 | # Copyright 20XX DeepMind Technologies Limited. All Rights Reserved.
51 | #
52 | # Licensed under the Apache License, Version 2.0 (the "License");
53 | # you may not use this file except in compliance with the License.
54 | # You may obtain a copy of the License at
55 | #
56 | # http://www.apache.org/licenses/LICENSE-2.0
57 | #
58 | # Unless required by applicable law or agreed to in writing, software
59 | # distributed under the License is distributed on an "AS IS" BASIS,
60 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61 | # See the License for the specific language governing permissions and
62 | # limitations under the License.
63 | # ==============================================================================
64 | """
65 |
66 | EXCLUDE_LIST = []
67 |
68 |
69 | def _check_license_header(fname):
70 | if fname in EXCLUDE_LIST:
71 | return True
72 | try:
73 | source = pathlib.Path(fname).read_text()
74 | return re.match(LICENSE_PATTERN, source) is not None
75 | except UnicodeDecodeError:
76 | return True
77 |
78 | if __name__ == "__main__":
79 | # check all Python files in the optax directory for license header
80 | source_files = list(pathlib.Path("./optax").glob("**/*.py"))
81 | with futures.ThreadPoolExecutor(max_workers=32) as executor:
82 | results = dict(zip(source_files,
83 | executor.map(_check_license_header, source_files)))
84 | failed_files = [str(fname) for fname, status in results.items() if not status]
85 | if failed_files:
86 | logger.error(
87 | "Files:\n%s\ndon't have the proper license. Please include this license"
88 | " template at the top of your file:\n%s", pprint.pformat(failed_files),
89 | LICENSE_TEMPLATE)
90 | sys.exit(1) # non-success return
91 |
--------------------------------------------------------------------------------
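
A note on usage: the script is meant to be executed from the repository root so that the `./optax` glob resolves. One possible invocation from Python is sketched below, using the script path shown above.

import subprocess

# Exits with status 1 (raising CalledProcessError here) if any file under
# ./optax is missing the expected license header.
subprocess.run(["python", ".github/check_license_headers.py"], check=True)
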
/optax/transforms/_monitoring_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for monitoring and debugging in optax."""
17 |
18 | from absl.testing import absltest
19 |
20 | import jax
21 | import jax.numpy as jnp
22 | import numpy as np
23 | from optax import tree
24 | from optax._src import alias
25 | from optax._src import update
26 | from optax.transforms import _clipping
27 | from optax.transforms import _combining
28 | from optax.transforms import _monitoring
29 |
30 |
31 | class MonitoringTest(absltest.TestCase):
32 |
33 | def test_snapshot(self):
34 | """Tests that snapshot stores the correct values."""
35 |
36 | def f(x):
37 | return jnp.sum(x**2)
38 |
39 | opt_before_clip = _combining.chain(
40 | alias.sgd(learning_rate=0.1, momentum=0.9),
41 | _monitoring.snapshot('norm_before_clip', tree.norm),
42 | )
43 | opt = _combining.chain(opt_before_clip, _clipping.clip_by_global_norm(0.05))
44 |
45 | params = jnp.array([1.0, 2.0, 3.0])
46 | state_aux = opt_before_clip.init(params)
47 | state = opt.init(params)
48 |
49 | # Testing for two steps to observe behavior not only after initialization
50 | # but also after the first update.
51 | for step in range(1, 3):
52 | grads = jax.grad(f)(params)
53 | updates_before_clip, state_aux = opt_before_clip.update(grads, state_aux)
54 | updates, state = opt.update(grads, state)
55 | params = update.apply_updates(params, updates)
56 | with self.subTest(f'norms equal at {step=}'):
57 | got = tree.get(state, 'norm_before_clip')
58 | expected = tree.norm(updates_before_clip)
59 | np.testing.assert_allclose(got, expected)
60 |
61 | def test_monitor(self):
62 | """Tests that monitor stores the correct values."""
63 |
64 | def f(x):
65 | return jnp.sum(x**2)
66 |
67 | ema_decay = 0.9
68 | debias = True
69 | opt_before_clip = _combining.chain(
70 | alias.sgd(learning_rate=0.1, momentum=0.9),
71 | _monitoring.monitor({
72 | 'norm_before_clip': tree.norm,
73 | 'norm_before_clip_ema': _monitoring.measure_with_ema(
74 | tree.norm, ema_decay, debias
75 | ),
76 | }),
77 | )
78 | opt = _combining.chain(opt_before_clip, _clipping.clip_by_global_norm(0.05))
79 |
80 | params = jnp.array([1.0, 2.0, 3.0])
81 | state_aux = opt_before_clip.init(params)
82 | state = opt.init(params)
83 |
84 | ema_norm = 0.0
85 | for step in range(1, 4):
86 | grads = jax.grad(f)(params)
87 | updates_before_clip, state_aux = opt_before_clip.update(grads, state_aux)
88 | updates, state = opt.update(grads, state)
89 | params = update.apply_updates(params, updates)
90 | norm_before_clip = tree.norm(updates_before_clip)
91 | with self.subTest(f'norms equal at {step=}'):
92 | np.testing.assert_allclose(
93 | tree.get(state, 'norm_before_clip'), norm_before_clip
94 | )
95 |
96 | ema_norm = ema_decay * ema_norm + (1 - ema_decay) * norm_before_clip
97 | ema_norm_debiased = ema_norm / (1 - ema_decay**step)
98 | with self.subTest(f'ema norms equal at {step=}'):
99 | np.testing.assert_allclose(
100 | tree.get(state, 'norm_before_clip_ema'),
101 | ema_norm_debiased,
102 | rtol=1e-5,
103 | )
104 |
105 |
106 | if __name__ == '__main__':
107 | absltest.main()
108 |
--------------------------------------------------------------------------------
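
A minimal sketch of the `snapshot` pattern exercised above, mirroring the test's import of the private `_monitoring` module (a public alias may also exist, but is not assumed here).

import jax
import jax.numpy as jnp
import optax
from optax.transforms import _monitoring

opt = optax.chain(
    optax.sgd(learning_rate=0.1, momentum=0.9),
    _monitoring.snapshot('norm_before_clip', optax.tree.norm),
    optax.clip_by_global_norm(0.05),
)
params = jnp.array([1.0, 2.0, 3.0])
state = opt.init(params)

grads = jax.grad(lambda x: jnp.sum(x ** 2))(params)
updates, state = opt.update(grads, state)

# The snapshotted value can be read back from the optimizer state by name.
print(optax.tree.get(state, 'norm_before_clip'))
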
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core >=3.2,<4"]
3 | build-backend = "flit_core.buildapi"
4 |
5 | [project]
6 | name = "optax"
7 | dynamic = ["version"]
8 | description = "A gradient processing and optimization library in JAX."
9 | readme = "README.md"
10 | license = { file = "LICENSE" }
11 | requires-python = ">=3.10"
12 | authors = [
13 | {name = "Google DeepMind", email = "optax-dev@google.com"},
14 | ]
15 | keywords = [
16 | "python",
17 | "machine learning",
18 | "reinforcement-learning"
19 | ]
20 | classifiers = [
21 | "Environment :: Console",
22 | "Programming Language :: Python",
23 | "Intended Audience :: Developers",
24 | "Operating System :: OS Independent",
25 | "Programming Language :: Python :: 3",
26 | "Intended Audience :: Science/Research",
27 | "Development Status :: 4 - Beta",
28 | "License :: OSI Approved :: Apache Software License",
29 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
30 | "Topic :: Software Development :: Libraries :: Python Modules",
31 | ]
32 | dependencies = [
33 | "absl-py>=0.7.1",
34 | "chex>=0.1.87",
35 | # Keep jax, jaxlib versions in sync with .github/workflows/tests.yml
36 | "jax>=0.5.3",
37 | "jaxlib>=0.5.3",
38 | "numpy>=1.18.0",
39 | ]
40 |
41 | [project.urls]
42 | homepage = "https://github.com/google-deepmind/optax"
43 | repository = "https://github.com/google-deepmind/optax"
44 | documentation = "https://optax.readthedocs.io/"
45 |
46 | [project.optional-dependencies]
47 | test = [
48 | "flax>=0.5.3",
49 | "scipy>=1.7.1",
50 | "scikit-learn"
51 | ]
52 |
53 | docs = [
54 | "sphinx>=6.0.0",
55 | "sphinx-book-theme>=1.0.1", # Older versions fail to pin pydata-sphinx-theme
56 | "sphinxcontrib-katex",
57 | "sphinx-autodoc-typehints",
58 | "ipython>=8.8.0", # 8.7.0 has ipython3 lexer error
59 | "myst-nb>=1.0.0",
60 | "matplotlib>=3.5.0",
61 | "sphinx-gallery>=0.14.0",
62 | "sphinx-collections>=0.0.1",
63 | "flax",
64 | "sphinx_contributors",
65 | "setuptools",
66 | ]
67 |
68 | [tool.setuptools.packages.find]
69 | include = ["README.md", "LICENSE"]
70 | exclude = ["*_test.py"]
71 |
72 | [tool.ruff]
73 | line-length = 80
74 |
75 | [tool.ruff.lint]
76 | select = [
77 | "F",
78 | "E",
79 | "W291", # whitespace at the end of the line
80 | "B023", # pylint's cell-var-over-loop, closures capturing variables in loop
81 | ]
82 | ignore = [
83 | "E731", # lambdas are allowed
84 | "F401", # allow unused imports
85 | "E402", # allow modules not at top of file
86 | "E741", # allow "l" as a variable name
87 | "E703", # allow semicolons (for jupyter notebooks)
88 | ]
89 |
90 | [tool.pylint.messages_control]
91 | disable = [
92 | "bad-indentation",
93 | "unknown-option-value",
94 | "invalid-name",
95 | "missing-function-docstring",
96 | "missing-class-docstring",
97 | "missing-module-docstring",
98 | "no-member",
99 | "too-many-locals",
100 | "too-many-positional-arguments",
101 | "no-else-return",
102 | "line-too-long",
103 | "too-many-arguments",
104 | "no-value-for-parameter",
105 | "duplicate-code",
106 | "unused-argument",
107 | "too-few-public-methods",
108 | "wrong-import-order",
109 | "unused-import",
110 | "wrong-import-position",
111 | "unnecessary-lambda-assignment",
112 | "too-many-lines",
113 | "too-many-statements",
114 | "deprecated-class",
115 | "redefined-builtin",
116 | "used-before-assignment",
117 | "undefined-variable",
118 | "protected-access",
119 | "not-callable",
120 | "redefined-outer-name",
121 | "too-many-instance-attributes",
122 | "missing-final-newline",
123 | "too-many-public-methods",
124 | "import-error",
125 | ]
126 |
127 | # We include pyink to allow external contributors to optionally, but easily,
128 | # pass Google-internal formatting checks.
129 | [tool.pyink]
130 | pyink = true # false would mean black is used directly without pyink features
131 | pyink-use-majority-quotes = true
132 | pyink-indentation = 2
133 | line-length = 80
134 | include = '\.pyi?$'
135 |
--------------------------------------------------------------------------------
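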
/optax/tree_utils/_random.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities to generate random pytrees."""
16 |
17 | from collections.abc import Callable
18 | import inspect
19 | from typing import Optional, Union
20 |
21 | import chex
22 | import jax
23 | from optax._src import base
24 |
25 |
26 | def tree_split_key_like(
27 | rng_key: base.PRNGKey, target_tree: chex.ArrayTree
28 | ) -> chex.ArrayTree:
29 | """Split keys to match structure of target tree.
30 |
31 | Args:
32 | rng_key: the key to split.
33 | target_tree: the tree whose structure to match.
34 |
35 | Returns:
36 | a tree of rng keys.
37 | """
38 | tree_def = jax.tree.structure(target_tree)
39 | keys = jax.random.split(rng_key, tree_def.num_leaves)
40 | return jax.tree.unflatten(tree_def, keys)
41 |
42 |
43 | def tree_random_like(
44 | rng_key: base.PRNGKey,
45 | target_tree: chex.ArrayTree,
46 | sampler: Union[
47 | Callable[[base.PRNGKey, base.Shape, jax.typing.DTypeLike],
48 | jax.typing.ArrayLike],
49 | Callable[[base.PRNGKey, base.Shape, jax.typing.DTypeLike,
50 | jax.sharding.Sharding],
51 | jax.typing.ArrayLike]] = jax.random.normal,
52 | dtype: Optional[chex.ArrayDType] = None,
53 | ) -> chex.ArrayTree:
54 | """Create tree with random entries of the same shape as target tree.
55 |
56 | Args:
57 | rng_key: the key for the random number generator.
58 | target_tree: the tree whose structure to match. Leaves must be arrays.
59 | sampler: the noise sampling function, by default ``jax.random.normal``.
60 | dtype: the desired dtype for the random numbers, passed to ``sampler``. If
61 | None, the dtype of the target tree is used if possible.
62 |
63 | Returns:
64 | a random tree with the same structure as ``target_tree``, whose leaves have
65 | distribution ``sampler``.
66 |
67 | .. warning::
68 | The possible dtypes may be limited by the sampler, for example
69 | ``jax.random.rademacher`` only supports integer dtypes and will raise an
70 | error if the dtype of the target tree is not an integer or if the dtype
71 | is not of integer type.
72 |
73 | .. versionadded:: 0.2.1
74 | """
75 | keys_tree = tree_split_key_like(rng_key, target_tree)
76 | sampler_ = sampler
77 | if "out_sharding" not in inspect.signature(sampler).parameters:
78 | sampler_ = lambda key, shape, dtype, *, out_sharding: sampler( # pylint: disable=unnecessary-lambda
79 | key, shape, dtype) # pytype: disable=wrong-arg-count
80 | return jax.tree.map(
81 | # pytype: disable=wrong-keyword-args
82 | lambda leaf, key: sampler_(key, leaf.shape, dtype or leaf.dtype,
83 | out_sharding=jax.typeof(leaf).sharding),
84 | # pytype: enable=wrong-keyword-args
85 | target_tree,
86 | keys_tree,
87 | )
88 |
89 |
90 | def tree_unwrap_random_key_data(input_tree: chex.ArrayTree) -> chex.ArrayTree:
91 | """Unwrap random.key objects in a tree for numerical comparison.
92 |
93 | Args:
94 | input_tree: a tree of arrays and random.key objects.
95 |
96 | Returns:
97 | a tree of arrays and random.key_data objects.
98 | """
99 | def _unwrap_random_key_data(x):
100 | if (isinstance(x, jax.Array)
101 | and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)):
102 | return jax.random.key_data(x)
103 | return x
104 |
105 | return jax.tree.map(_unwrap_random_key_data, input_tree)
106 |
--------------------------------------------------------------------------------
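
A short usage sketch of the two helpers above, via the `optax.tree_utils` alias used elsewhere in the test suite; the parameter tree is arbitrary.

import jax
import jax.numpy as jnp
from optax import tree_utils as otu

params = {'w': jnp.zeros((2, 3)), 'b': jnp.zeros((3,))}
key = jax.random.key(0)

# One independent key per leaf, with the same tree structure as `params`.
keys = otu.tree_split_key_like(key, params)

# Gaussian noise with the same shapes (and, by default, dtypes) as `params`.
noise = otu.tree_random_like(key, params, sampler=jax.random.normal)
print(jax.tree.map(jnp.shape, noise))  # {'b': (3,), 'w': (2, 3)}
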
/optax/contrib/_mechanic_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Specific tests for `mechanic.py`, see `common_test.py` for usual tests."""
16 |
17 | from typing import NamedTuple
18 |
19 | from absl.testing import absltest
20 | import jax
21 | import jax.numpy as jnp
22 | import numpy as np
23 | from optax._src import base
24 | from optax._src import test_utils
25 | from optax._src import update
26 | from optax.contrib import _mechanic
27 | import optax.tree
28 |
29 |
30 | class OptimizerTestState(NamedTuple):
31 | """Inner optimizer state for the Mechanic tests."""
32 |
33 | aggregate_grads: base.Params
34 |
35 |
36 | def _test_optimizer(step_size: float) -> base.GradientTransformation:
37 | """Inner optimizer for the Mechanic tests."""
38 |
39 | # Use SGD for simplicity but add non-trivial optimizer state so that the
40 |   # handling of the inner optimizer state by mechanic can be tested.
41 | def init_fn(params):
42 | aggregate_grads = jax.tree.map(jnp.zeros_like, params)
43 | return OptimizerTestState(aggregate_grads)
44 |
45 | def update_fn(updates, state, params):
46 | # The test optimizer does not use the parameters, but we check that they
47 | # have been passed correctly.
48 | test_utils.assert_trees_all_equal_shapes(updates, params)
49 | aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
50 | updates = jax.tree.map(lambda u: step_size * u, updates)
51 | return updates, OptimizerTestState(aggregate_grads)
52 |
53 | return base.GradientTransformation(init_fn, update_fn)
54 |
55 |
56 | class MechanicTest(absltest.TestCase):
57 |
58 | def setUp(self):
59 | super().setUp()
60 | self.grads = {'x': np.array(2.0), 'y': np.array(-2.0)}
61 | self.initial_params = {'x': np.array(3.0), 'y': np.array(-3.0)}
62 |
63 | def loop(self, optimizer, num_steps, params
64 | ) -> tuple[base.Params, _mechanic.MechanicState]:
65 | """Performs a given number of optimizer steps."""
66 | init_fn, update_fn = optimizer
67 | step = jax.jit(update_fn)
68 | opt_state = jax.jit(init_fn)(params)
69 |
70 | # A no-op change, to verify that tree map works.
71 | opt_state = optax.tree.map_params(init_fn, lambda v: v, opt_state)
72 |
73 | for _ in range(num_steps):
74 | updates, opt_state = step(self.grads, opt_state, params)
75 | print(updates)
76 | params = update.apply_updates(params, updates)
77 |
78 | return params, opt_state
79 |
80 | def test_mechanized(self):
81 | params = self.initial_params
82 | num_betas = 6
83 |
84 | inner_optimizer = _test_optimizer(-0.1)
85 | optimizer = _mechanic.mechanize(
86 | inner_optimizer,
87 | weight_decay=1e-2,
88 | eps=1e-10,
89 | s_init=1e-8,
90 | num_betas=num_betas,
91 | )
92 |
93 | final_params, final_state = self.loop(
94 | optimizer=optimizer, num_steps=1, params=params
95 | )
96 | expected_m = np.array([1.0e-10] * num_betas)
97 | expected_v = np.array([0.0] * num_betas)
98 | expected_s = np.array([1.6666667e-09] * num_betas)
99 |
100 | test_utils.assert_trees_all_close(expected_m, final_state.m)
101 | test_utils.assert_trees_all_close(expected_v, final_state.v)
102 | test_utils.assert_trees_all_close(expected_s, final_state.s)
103 | test_utils.assert_trees_all_close(final_params, params)
104 | test_utils.assert_tree_all_finite((final_params, final_state))
105 |
106 |
107 | if __name__ == '__main__':
108 | absltest.main()
109 |
--------------------------------------------------------------------------------
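
Outside the test harness, the wrapper is typically used through `optax.contrib.mechanize` with its default hyperparameters, roughly as sketched below (the inner learning rate is illustrative).

import jax
import jax.numpy as jnp
import optax
import optax.contrib

# mechanize learns a scale for the updates proposed by the inner optimizer.
opt = optax.contrib.mechanize(optax.sgd(learning_rate=1.0))

params = jnp.array([1.0, -2.0, 3.0])
state = opt.init(params)

grads = jax.grad(lambda x: jnp.sum(x ** 2))(params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)
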
/optax/_src/float64_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests that types are preserved by the `update` calls when jax_enbable_x64."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 |
22 | from optax import transforms
23 |
24 | from optax._src import alias
25 | from optax._src import base
26 | from optax._src import transform
27 | from optax._src import update
28 |
29 |
30 | ALL_MODULES = [
31 | ('identity', base.identity, {}),
32 | ('clip', transforms.clip, {'max_delta': 1.0}),
33 | ('clip_by_global_norm', transforms.clip_by_global_norm, {'max_norm': 1.0}),
34 | ('trace', transforms.trace, {'decay': 0.5, 'nesterov': False}),
35 | ('trace_with_nesterov', transforms.trace, {'decay': 0.5, 'nesterov': True}),
36 | ('scale_by_rss', transform.scale_by_rss, {}),
37 | ('scale_by_rms', transform.scale_by_rms, {}),
38 | ('scale_by_stddev', transform.scale_by_stddev, {}),
39 | ('scale_by_adam', transform.scale_by_adam, {}),
40 | ('scale', transform.scale, {'step_size': 3.0}),
41 | (
42 | 'add_decayed_weights',
43 | transforms.add_decayed_weights,
44 | {'weight_decay': 0.1},
45 | ),
46 | (
47 | 'scale_by_schedule',
48 | transform.scale_by_schedule,
49 | {'step_size_fn': lambda x: x * 0.1},
50 | ),
51 | ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}),
52 | ('add_noise', transforms.add_noise, {'eta': 1.0, 'gamma': 0.1, 'key': 0}),
53 | ('apply_every_k', transform.apply_every, {}),
54 | ('adagrad', alias.adagrad, {'learning_rate': 0.1}),
55 | ('adam', alias.adam, {'learning_rate': 0.1}),
56 | ('adamw', alias.adamw, {'learning_rate': 0.1}),
57 | ('fromage', alias.fromage, {'learning_rate': 0.1}),
58 | ('lamb', alias.lamb, {'learning_rate': 0.1}),
59 | ('noisy_sgd', alias.noisy_sgd, {'learning_rate': 0.1, 'key': 0}),
60 | ('rmsprop', alias.rmsprop, {'learning_rate': 0.1}),
61 | ('sgd', alias.sgd, {'learning_rate': 0.1}),
62 | ('sign_sgd', alias.sgd, {'learning_rate': 0.1}),
63 | ]
64 |
65 |
66 | class Float64Test(parameterized.TestCase):
67 |
68 | def _assert_dtype_equals(self, tree1, tree2):
69 | tree1_types = jax.tree.map(lambda t: t.dtype, tree1)
70 | tree2_types = jax.tree.map(lambda t: t.dtype, tree2)
71 | self.assertEqual(tree1_types, tree2_types)
72 |
73 | @parameterized.named_parameters(ALL_MODULES)
74 | def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs):
75 | jax.config.update('jax_enable_x64', True)
76 | initial_params = (
77 | jnp.array([1.0, 2.0], dtype=jnp.float32),
78 | jnp.array([3.0, 4.0], dtype=jnp.float64),
79 | )
80 | updates = (
81 | jnp.array([10.0, 21.0], dtype=jnp.float32),
82 | jnp.array([33.0, 42.0], dtype=jnp.float64),
83 | )
84 | scaler = transform_constr(**transform_kwargs)
85 | init_fn = jax.jit(scaler.init)
86 | update_fn = jax.jit(scaler.update)
87 |
88 | initial_state = init_fn(initial_params)
89 | updates, new_state = update_fn(
90 | updates, initial_state, params=initial_params
91 | )
92 | new_params = update.apply_updates(initial_params, updates)
93 |
94 | self._assert_dtype_equals(initial_state, new_state)
95 | self._assert_dtype_equals(initial_params, new_params)
96 | jax.config.update('jax_enable_x64', False)
97 |
98 |
99 | if __name__ == '__main__':
100 | absltest.main()
101 |
--------------------------------------------------------------------------------
/optax/transforms/_adding_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods in `optax.transforms._adding.py`."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 | import jax.numpy as jnp
20 | from optax._src import test_utils
21 | from optax.transforms import _adding
22 |
23 | STEPS = 50
24 |
25 |
26 | class AddingTest(absltest.TestCase):
27 |
28 | def setUp(self):
29 | super().setUp()
30 | self.init_params = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
31 | self.per_step_updates = (jnp.array([500.0, 5.0]), jnp.array([300.0, 3.0]))
32 |
33 | def test_add_decayed_weights(self):
34 | # Define a transform that add decayed weights.
35 | # We can define a mask either as a pytree, or as a function that
36 | # returns the pytree. Below we define the pytree directly.
37 | mask = (True, {"a": True, "b": False})
38 | tx = _adding.add_decayed_weights(0.1, mask=mask)
39 | # Define input updates and weights.
40 | updates = (
41 | jnp.zeros((2,), dtype=jnp.float32),
42 | {
43 | "a": jnp.zeros((2,), dtype=jnp.float32),
44 | "b": jnp.zeros((2,), dtype=jnp.float32),
45 | },
46 | )
47 | weights = (
48 | jnp.ones((2,), dtype=jnp.float32),
49 | {
50 | "a": jnp.ones((2,), dtype=jnp.float32),
51 | "b": jnp.ones((2,), dtype=jnp.float32),
52 | },
53 | )
54 | # This mask means that we will add decayed weights to the first two
55 | # terms in the input updates, but not to the last element.
56 | expected_tx_updates = (
57 | 0.1*jnp.ones((2,), dtype=jnp.float32),
58 | {
59 | "a": 0.1*jnp.ones((2,), dtype=jnp.float32),
60 | "b": jnp.zeros((2,), dtype=jnp.float32),
61 | },
62 | )
63 | # Apply transform
64 | state = tx.init(weights)
65 | transform_fn = jax.jit(tx.update)
66 | new_updates, _ = transform_fn(updates, state, weights)
67 | # Assert output as expected.
68 | test_utils.assert_trees_all_close(new_updates, expected_tx_updates)
69 |
70 | def test_add_noise_has_correct_variance_scaling(self):
71 | # Prepare to compare noise with a rescaled unit-variance substitute.
72 | eta = 0.3
73 | gamma = 0.55
74 | key = 314
75 | noise = _adding.add_noise(eta, gamma, key)
76 | noise_unit = _adding.add_noise(1.0, 0.0, key)
77 |
78 | params = self.init_params
79 | state = noise.init(params)
80 | state_unit = noise_unit.init(params)
81 |
82 | # Check the noise itself by adding it to zeros.
83 | updates = jax.tree.map(jnp.zeros_like, params)
84 |
85 | for i in range(1, STEPS + 1):
86 | updates_i, state = jax.jit(noise.update)(updates, state)
87 | updates_i_unit, state_unit = noise_unit.update(updates, state_unit)
88 |
89 | scale = jnp.sqrt(eta / i**gamma)
90 |
91 | updates_i_rescaled = jax.tree.map(
92 | lambda g, s=scale: g * s, updates_i_unit
93 | )
94 |
95 | test_utils.assert_trees_all_close(
96 | updates_i, updates_i_rescaled, rtol=1e-4)
97 |
98 | def test_none_argument(self):
99 | weights = (
100 | jnp.ones((2,), dtype=jnp.float32),
101 | {
102 | "a": jnp.ones((2,), dtype=jnp.float32),
103 | "b": jnp.ones((2,), dtype=jnp.float32),
104 | },
105 | )
106 | tf = _adding.add_decayed_weights(0.1, mask=None)
107 | tf.update(None, 0, weights)
108 |
109 |
110 | if __name__ == "__main__":
111 | absltest.main()
112 |
--------------------------------------------------------------------------------
/optax/transforms/_freezing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilites for freezing parameters."""
17 |
18 | from typing import Union
19 |
20 | import chex
21 | import jax
22 |
23 | from optax._src import base
24 | # pylint: disable=g-importing-member
25 | from optax.transforms._combining import partition
26 | from optax.transforms._masking import masked
27 | # pylint: enable=g-importing-member
28 |
29 |
30 | def freeze(mask: Union[bool, chex.ArrayTree]) -> base.GradientTransformation:
31 | """Create a transformation that zeros out gradient updates for `mask=True`.
32 |
33 |   This essentially freezes (i.e., holds constant) the masked parameters.
34 |
35 | The mask must be static (i.e., not dependent on runtime values or updated
36 | during training) and can be:
37 |
38 | - a single boolean (or 0-d JAX bool array), causing every parameter to be
39 | either all-frozen (True) or all-trainable (False), or
40 | - a PyTree of booleans matching the structure of the parameters, where
41 | each leaf indicates whether that specific parameter leaf should be
42 | frozen (True) or left unchanged (False).
43 |
44 | Args:
45 | mask: A boolean prefix tree mask indicating which parameters to freeze.
46 |
47 | Example:
48 | >>> import jax.numpy as jnp
49 | >>> from optax import freeze
50 | >>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
51 | >>> mask = {'a': True, 'b': False} # Freeze 'a', train 'b'
52 | >>> freezer = freeze(mask)
53 |
54 | Returns:
55 | An Optax `GradientTransformation` which applies `set_to_zero()` wherever
56 | `mask==True`, and leaves other gradients intact.
57 |
58 | .. seealso::
59 | :func:`optax.selective_transform` : For partitioning updates
60 | so only un-frozen parameters are optimized.
61 | """
62 | return masked(base.set_to_zero(), mask)
63 |
64 |
65 | def selective_transform(
66 | optimizer: base.GradientTransformation,
67 | *, # force kw-only arguments to show this is a freeze and not allow mask
68 | freeze_mask: Union[bool, chex.ArrayTree],
69 | ) -> base.GradientTransformation:
70 | """Partition updates so that only un-frozen parameters are optimized.
71 |
72 | Example:
73 | >>> import jax.numpy as jnp
74 | >>> from optax import selective_transform
75 | >>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
76 | >>> mask = {'a': True, 'b': False} # Freeze 'a', train 'b'
77 | >>> selective_opt = selective_transform(optax.adam(1e-3), freeze_mask=mask)
78 |
79 | Args:
80 | optimizer: The inner Optax optimizer to apply to unfrozen leaves.
81 | freeze_mask: A *static* mask (i.e., not dependent on runtime values or
82 | updated during training). It can be either:
83 |
84 | - a scalar bool (or 0-d JAX bool array) to freeze everything (True) or
85 | nothing (False)
86 | - a PyTree of booleans mirroring the parameter tree, marking each leaf
87 | to freeze (True) or train (False).
88 |
89 | Returns:
90 | A `GradientTransformation` that routes each parameter leaf through:
91 |
92 | - the given `optimizer` if its mask is False (“train”),
93 | - `set_to_zero()` if its mask is True (“freeze”).
94 |
95 | .. seealso::
96 | :func:`optax.freeze` : For simply zeroing out gradients
97 | according to a mask.
98 | """
99 |
100 | def label_fn(params: base.PyTree):
101 | del params
102 | return jax.tree.map(lambda m: "freeze" if m else "train", freeze_mask)
103 |
104 | return partition(
105 | {"train": optimizer, "freeze": base.set_to_zero()},
106 | param_labels=label_fn,
107 | )
108 |
--------------------------------------------------------------------------------
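
Expanding the docstring examples above into a full update step (using the public aliases the docstrings themselves reference, `optax.selective_transform` and `optax.adam`); frozen leaves receive zero updates and therefore keep their initial values.

import jax
import jax.numpy as jnp
import optax

params = {'a': jnp.ones(1), 'b': jnp.ones(2)}
mask = {'a': True, 'b': False}  # Freeze 'a', train 'b'.

opt = optax.selective_transform(optax.adam(1e-3), freeze_mask=mask)
state = opt.init(params)

grads = jax.grad(lambda p: jnp.sum(p['a'] ** 2) + jnp.sum(p['b'] ** 2))(params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)

print(params['a'])  # unchanged: frozen leaves get zero updates
print(params['b'])  # moved by adam
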
/optax/second_order/_deprecated.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Deprecated second order utilities kept for backward compatibility.
16 | """
17 |
18 | import abc
19 | import functools
20 | from typing import Any, Protocol
21 |
22 | import jax
23 | from jax import flatten_util
24 | import jax.numpy as jnp
25 | from optax._src.deprecations import warn_deprecated_function # pylint: disable=g-importing-member
26 |
27 |
28 | def _ravel(p: Any) -> jax.Array:
29 | return flatten_util.ravel_pytree(p)[0]
30 |
31 |
32 | class LossFn(Protocol):
33 | """A loss function to be optimized."""
34 |
35 | @abc.abstractmethod
36 | def __call__(
37 | self, params: Any, inputs: jax.Array, targets: jax.Array
38 | ) -> jax.Array:
39 | ...
40 |
41 |
42 | @functools.partial(warn_deprecated_function, version_removed='0.2.9')
43 | def hvp(
44 | loss: LossFn,
45 | v: jax.Array,
46 | params: Any,
47 | inputs: jax.Array,
48 | targets: jax.Array,
49 | ) -> jax.Array:
50 | """Performs an efficient vector-Hessian (of `loss`) product.
51 |
52 | .. deprecated: 0.2.7. This function will be removed in 0.2.9
53 |
54 | Args:
55 | loss: the loss function.
56 | v: a vector of size `ravel(params)`.
57 | params: model parameters.
58 | inputs: inputs at which `loss` is evaluated.
59 | targets: targets at which `loss` is evaluated.
60 |
61 | Returns:
62 | An Array corresponding to the product of `v` and the Hessian of `loss`
63 | evaluated at `(params, inputs, targets)`.
64 | """
65 | _, unravel_fn = flatten_util.ravel_pytree(params)
66 | loss_fn = lambda p: loss(p, inputs, targets)
67 | return jax.jvp(jax.grad(loss_fn), [params], [unravel_fn(v)])[1]
68 |
69 |
70 | @functools.partial(warn_deprecated_function, version_removed='0.2.9')
71 | def hessian_diag(
72 | loss: LossFn,
73 | params: Any,
74 | inputs: jax.Array,
75 | targets: jax.Array,
76 | ) -> jax.Array:
77 | """Computes the diagonal hessian of `loss` at (`inputs`, `targets`).
78 |
79 | .. deprecated: 0.2.7. This function will be removed in 0.2.9
80 |
81 | Args:
82 | loss: the loss function.
83 | params: model parameters.
84 | inputs: inputs at which `loss` is evaluated.
85 | targets: targets at which `loss` is evaluated.
86 |
87 | Returns:
88 |     An Array corresponding to the diagonal of the Hessian of `loss`
89 |     evaluated at `(params, inputs, targets)`.
90 | """
91 | vs = jnp.eye(_ravel(params).size)
92 | comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets)))
93 | return jax.vmap(comp)(vs)
94 |
95 |
96 | @functools.partial(warn_deprecated_function, version_removed='0.2.9')
97 | def fisher_diag(
98 | negative_log_likelihood: LossFn,
99 | params: Any,
100 | inputs: jax.Array,
101 | targets: jax.Array,
102 | ) -> jax.Array:
103 | """Computes the diagonal of the (observed) Fisher information matrix.
104 |
105 | .. deprecated: 0.2.7. This function will be removed in 0.2.9
106 |
107 | Args:
108 | negative_log_likelihood: the negative log likelihood function with expected
109 | signature `loss = fn(params, inputs, targets)`.
110 | params: model parameters.
111 | inputs: inputs at which `negative_log_likelihood` is evaluated.
112 | targets: targets at which `negative_log_likelihood` is evaluated.
113 |
114 | Returns:
115 |     An Array corresponding to the diagonal of the (observed) Fisher information
116 |     matrix of `negative_log_likelihood` evaluated at `(params, inputs, targets)`.
117 | """
118 | return jnp.square(
119 | _ravel(jax.grad(negative_log_likelihood)(params, inputs, targets))
120 | )
121 |
--------------------------------------------------------------------------------
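
These helpers are deprecated (scheduled for removal in 0.2.9), but for completeness here is a small sketch of `hvp` on a least-squares loss, accessed through the `optax.second_order` subpackage; calling it emits a DeprecationWarning.

import jax.numpy as jnp
from optax import second_order

# Least-squares loss; its Hessian w.r.t. `params` is inputs.T @ inputs.
def loss(params, inputs, targets):
  return 0.5 * jnp.sum((inputs @ params - targets) ** 2)

params = jnp.array([1.0, -1.0])
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.array([0.0, 1.0])
v = jnp.array([1.0, 0.0])

hv = second_order.hvp(loss, v, params, inputs, targets)  # DeprecationWarning
print(hv)                     # same as (inputs.T @ inputs) @ v
print(inputs.T @ inputs @ v)
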
/optax/_src/update.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Apply transformed gradient updates to parameters."""
16 |
17 | import jax
18 | import jax.numpy as jnp
19 | from optax._src import base
20 |
21 |
22 | def apply_updates(params: base.Params, updates: base.Updates) -> base.Params:
23 | """Applies an update to the corresponding parameters.
24 |
25 |   This is a utility function that applies an update to a set of parameters, and
26 |   then returns the updated parameters to the caller. As an example, the update
27 |   may be a gradient transformed by a sequence of `GradientTransformations`. This
28 | function is exposed for convenience, but it just adds updates and parameters;
29 | you may also apply updates to parameters manually, using `jax.tree.map`
30 | (e.g. if you want to manipulate updates in custom ways before applying them).
31 |
32 | Args:
33 | params: a tree of parameters.
34 | updates: a tree of updates, the tree structure and the shape of the leaf
35 | nodes must match that of `params`.
36 |
37 | Returns:
38 | Updated parameters, with same structure, shape and type as `params`.
39 | """
40 | return jax.tree.map(
41 | lambda p, u: (
42 | None if p is None else jnp.asarray(p + u).astype(jnp.asarray(p).dtype)
43 | ),
44 | params,
45 | updates,
46 | is_leaf=lambda x: x is None,
47 | )
48 |
49 |
50 | def incremental_update(
51 | new_tensors: base.Params, old_tensors: base.Params,
52 | step_size: jax.typing.ArrayLike) -> base.Params:
53 | """Incrementally update parameters via polyak averaging.
54 |
55 | Polyak averaging tracks an (exponential moving) average of the past
56 | parameters of a model, for use at test/evaluation time.
57 |
58 | Args:
59 | new_tensors: the latest value of the tensors.
60 | old_tensors: a moving average of the values of the tensors.
61 | step_size: the step_size used to update the polyak average on each step.
62 |
63 | Returns:
64 | an updated moving average `step_size*new+(1-step_size)*old` of the params.
65 |
66 | References:
67 | [Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046)
68 | """
69 | return jax.tree.map(
70 | lambda new, old: (
71 | None if new is None else step_size * new + (1.0 - step_size) * old
72 | ),
73 | new_tensors,
74 | old_tensors,
75 | is_leaf=lambda x: x is None,
76 | )
77 |
78 |
79 | def periodic_update(
80 | new_tensors: base.Params,
81 | old_tensors: base.Params,
82 | steps: jax.typing.ArrayLike, # int
83 | update_period: jax.typing.ArrayLike, # int
84 | ) -> base.Params:
85 | """Periodically update all parameters with new values.
86 |
87 | A slow copy of a model's parameters, updated every K actual updates, can be
88 | used to implement forms of self-supervision (in supervised learning), or to
89 | stabilize temporal difference learning updates (in reinforcement learning).
90 |
91 | Args:
92 | new_tensors: the latest value of the tensors.
93 | old_tensors: a slow copy of the model's parameters.
94 | steps: number of update steps on the "online" network.
95 | update_period: every how many steps to update the "target" network.
96 |
97 | Returns:
98 | a slow copy of the model's parameters, updated every `update_period` steps.
99 |
100 | References:
101 | [Grill et al., 2020](https://arxiv.org/abs/2006.07733)
102 | [Mnih et al., 2015](https://arxiv.org/abs/1312.5602)
103 | """
104 | return jax.lax.cond(
105 | jnp.mod(steps, update_period) == 0,
106 | lambda _: new_tensors,
107 | lambda _: old_tensors,
108 | None,
109 | )
110 |
--------------------------------------------------------------------------------
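
A small end-to-end sketch tying the three helpers above together in a training loop; the model and hyperparameters are placeholders.

import jax
import jax.numpy as jnp
import optax

def loss(p):
  return jnp.sum(p['w'] ** 2) + p['b'] ** 2

params = {'w': jnp.ones(2), 'b': jnp.array(1.0)}
avg_params = params       # Polyak average, for use at evaluation time.
target_params = params    # slow copy, refreshed periodically.

opt = optax.sgd(learning_rate=0.1)
state = opt.init(params)

for step in range(1, 201):
  grads = jax.grad(loss)(params)
  updates, state = opt.update(grads, state)
  params = optax.apply_updates(params, updates)
  avg_params = optax.incremental_update(params, avg_params, step_size=0.01)
  target_params = optax.periodic_update(
      params, target_params, steps=step, update_period=100)
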
/optax/assignment/_hungarian_algorithm_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for the Hungarian algorithm."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import jax.random as jrd
22 | from optax import assignment
23 | import scipy
24 |
25 |
26 | class HungarianAlgorithmTest(parameterized.TestCase):
27 |
28 | @parameterized.product(
29 | fn=[assignment.hungarian_algorithm, assignment.base_hungarian_algorithm],
30 | n=[0, 1, 2, 4, 8, 16],
31 | m=[0, 1, 2, 4, 8, 16],
32 | )
33 | def test_hungarian_algorithm(self, fn, n, m):
34 | key = jrd.key(0)
35 | costs = jrd.normal(key, (n, m))
36 |
37 | i, j = fn(costs)
38 |
39 | r = min(costs.shape)
40 |
41 | with self.subTest('i has correct shape'):
42 | assert i.shape == (r,)
43 |
44 | with self.subTest('j has correct shape'):
45 | assert j.shape == (r,)
46 |
47 | with self.subTest('i has correct dtype'):
48 | assert jnp.issubdtype(i.dtype, jnp.integer)
49 |
50 | with self.subTest('j has correct dtype'):
51 | assert jnp.issubdtype(j.dtype, jnp.integer)
52 |
53 | with self.subTest('each element of i is non-negative'):
54 | assert jnp.all(0 <= i)
55 |
56 | with self.subTest('each element of j is non-negative'):
57 | assert jnp.all(0 <= j)
58 |
59 | with self.subTest('each element of i is less than the number of rows'):
60 | assert (i < costs.shape[0]).all()
61 |
62 | with self.subTest('each element of j is less than the number of columns'):
63 | assert (j < costs.shape[1]).all()
64 |
65 | x = jnp.zeros(costs.shape[0], int).at[i].add(1)
66 |
67 | with self.subTest('all elements of i lie in the valid range'):
68 | assert x.sum() == r
69 |
70 | with self.subTest('no two elements of i are equal'):
71 | assert (x <= 1).all()
72 |
73 | y = jnp.zeros(costs.shape[1], int).at[j].add(1)
74 |
75 | with self.subTest('all elements of j lie in the valid range'):
76 | assert y.sum() == r
77 |
78 | with self.subTest('no two elements of j are equal'):
79 | assert (y <= 1).all()
80 |
81 | cost_optax = costs[i, j].sum()
82 |
83 | i_scipy, j_scipy = scipy.optimize.linear_sum_assignment(costs)
84 | cost_scipy = costs[i_scipy, j_scipy].sum()
85 |
86 | with self.subTest('cost matches that obtained by scipy'):
87 | assert jnp.isclose(cost_optax, cost_scipy)
88 |
89 | @parameterized.product(
90 | fn=[assignment.hungarian_algorithm, assignment.base_hungarian_algorithm],
91 | k=[0, 1, 2, 4],
92 | n=[0, 1, 2, 4],
93 | m=[0, 1, 2, 4],
94 | )
95 | def test_hungarian_algorithm_vmap(self, fn, k, n, m):
96 | key = jrd.key(0)
97 | costs = jrd.normal(key, (k, n, m))
98 |
99 | with self.subTest('works under vmap'):
100 | i, j = jax.vmap(fn)(costs)
101 |
102 | r = min(costs.shape[1:])
103 |
104 | with self.subTest('batch i has correct shape'):
105 | assert i.shape == (k, r)
106 |
107 | with self.subTest('batch j has correct shape'):
108 | assert j.shape == (k, r)
109 |
110 | @parameterized.product(
111 | fn=[assignment.hungarian_algorithm, assignment.base_hungarian_algorithm],
112 | )
113 | def test_hungarian_algorithm_jit(self, fn):
114 | key = jrd.key(0)
115 | costs = jrd.normal(key, (20, 30))
116 |
117 | with self.subTest('works under jit'):
118 | i, j = jax.jit(fn)(costs)
119 |
120 | r = min(costs.shape)
121 |
122 | with self.subTest('i has correct shape'):
123 | assert i.shape == (r,)
124 |
125 | with self.subTest('j has correct shape'):
126 | assert j.shape == (r,)
127 |
128 |
129 | if __name__ == '__main__':
130 | absltest.main()
131 |
--------------------------------------------------------------------------------
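
A concrete call, for reference: `optax.assignment.hungarian_algorithm` takes a cost matrix and returns row and column indices of an optimal assignment (the 3x3 matrix below is made up for illustration).

import jax.numpy as jnp
from optax import assignment

# Entry (i, j) is the cost of assigning row i to column j.
costs = jnp.array([
    [4.0, 1.0, 3.0],
    [2.0, 0.0, 5.0],
    [3.0, 2.0, 2.0],
])
i, j = assignment.hungarian_algorithm(costs)
print(i, j)                # one optimal row/column pairing
print(costs[i, j].sum())   # minimal total cost, 5.0 for this matrix
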
/optax/tree_utils/_random_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods defined in optax.tree_utils._random."""
16 |
17 | from collections.abc import Callable
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | import chex
22 | import jax
23 | import jax.numpy as jnp
24 | import jax.random as jrd
25 | import numpy as np
26 | from optax import tree_utils as otu
27 | from optax._src import base
28 | from optax._src import test_utils
29 |
30 | # We consider samplers with varying input dtypes, we do not test all possible
31 | # samplers from `jax.random`.
32 | _SAMPLER_DTYPES = (
33 | {'sampler': jrd.normal, 'dtype': None},
34 | {'sampler': jrd.normal, 'dtype': 'bfloat16'},
35 | {'sampler': jrd.normal, 'dtype': 'float32'},
36 | {'sampler': jrd.rademacher, 'dtype': 'int32'},
37 | {'sampler': jrd.bits, 'dtype': 'uint32'},
38 | )
39 |
40 |
41 | def get_variable(type_var: str):
42 | """Get a variable of various shape."""
43 | if type_var == 'real_array':
44 | return jnp.asarray([1.0, 2.0])
45 | if type_var == 'complex_array':
46 | return jnp.asarray([1.0 + 1j * 2.0, 3.0 + 4j * 5.0])
47 | if type_var == 'pytree':
48 | pytree = {'k1': 1.0, 'k2': (2.0, 3.0), 'k3': jnp.asarray([4.0, 5.0])}
49 | return jax.tree.map(jnp.asarray, pytree)
50 | raise ValueError(f'Invalid type_var {type_var}')
51 |
52 |
53 | class RandomTest(parameterized.TestCase):
54 |
55 | def test_tree_split_key_like(self):
56 | rng_key = jrd.key(0)
57 | tree = {'a': jnp.zeros(2), 'b': {'c': [jnp.ones(3), jnp.zeros([4, 5])]}}
58 | keys_tree = otu.tree_split_key_like(rng_key, tree)
59 |
60 | with self.subTest('Test structure matches'):
61 | self.assertEqual(jax.tree.structure(tree), jax.tree.structure(keys_tree))
62 |
63 | with self.subTest('Test random key split'):
64 | fst = jnp.stack(jax.tree.flatten(keys_tree)[0])
65 | snd = jrd.split(rng_key, jax.tree.structure(tree).num_leaves)
66 | np.testing.assert_array_equal(otu.tree_unwrap_random_key_data(fst),
67 | otu.tree_unwrap_random_key_data(snd))
68 |
69 | @parameterized.product(
70 | _SAMPLER_DTYPES,
71 | type_var=['real_array', 'complex_array', 'pytree'],
72 | )
73 | def test_tree_random_like(
74 | self,
75 | sampler: Callable[
76 | [base.PRNGKey, base.Shape, jax.typing.DTypeLike], jax.Array
77 | ],
78 | dtype: str,
79 | type_var: str,
80 | ):
81 | """Test that tree_random_like matches its flat counterpart."""
82 | if dtype is not None:
83 | dtype = jnp.dtype(dtype)
84 | rng_key = jrd.key(0)
85 | target_tree = get_variable(type_var)
86 |
87 | rand_tree = otu.tree_random_like(
88 | rng_key, target_tree, sampler=sampler, dtype=dtype
89 | )
90 |
91 | flat_tree, tree_def = jax.tree.flatten(target_tree)
92 |
93 | with self.subTest('Test structure matches'):
94 | self.assertEqual(tree_def, jax.tree.structure(rand_tree))
95 |
96 | with self.subTest('Test tree_random_like matches flat random like'):
97 | flat_rand_tree, _ = jax.tree.flatten(rand_tree)
98 | keys = jrd.split(rng_key, tree_def.num_leaves)
99 | expected_flat_rand_tree = [
100 | sampler(key, x.shape, dtype or x.dtype)
101 | for key, x in zip(keys, flat_tree)
102 | ]
103 | test_utils.assert_trees_all_close(flat_rand_tree, expected_flat_rand_tree)
104 |
105 | with self.subTest('Test dtype are as expected'):
106 | if dtype is not None:
107 | for x in jax.tree.leaves(rand_tree):
108 | self.assertEqual(x.dtype, dtype)
109 | else:
110 | chex.assert_trees_all_equal_dtypes(rand_tree, target_tree)
111 |
112 |
113 | if __name__ == '__main__':
114 | absltest.main()
115 |
--------------------------------------------------------------------------------
/optax/_src/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities for optax tests."""
16 |
17 | import contextlib
18 | import logging
19 | import re
20 | import threading
21 |
22 | import jax
23 | import numpy as np
24 |
25 | _LOG_LIST = []
26 |
27 |
28 | class _LogsToListHandler(logging.Handler):
29 | """A handler for capturing logs programmatically without printing them."""
30 |
31 | def emit(self, record):
32 | _LOG_LIST.append(record)
33 |
34 |
35 | logger = logging.getLogger("jax")
36 | logger.addHandler(_LogsToListHandler())
37 |
38 | # We need a lock to be able to run this context manager from multiple threads.
39 | _compilation_log_lock = threading.Lock()
40 |
41 |
42 | @contextlib.contextmanager
43 | def log_compilations():
44 | """A utility for programmatically capturing JAX compilation logs."""
45 | with _compilation_log_lock, jax.log_compiles():
46 | _LOG_LIST.clear()
47 | compilation_logs = []
48 | yield compilation_logs # these will contain the compilation logs
49 | compilation_logs.extend([
50 | log for log in _LOG_LIST
51 | if re.search(r"Finished .* compilation", log.getMessage())
52 | ])
53 |
54 |
55 | def assert_trees_all_close(actual, desired, rtol=1e-6, atol=0.0, err_msg=None):
56 | """Asserts that two pytrees of arrays are close within a tolerance."""
57 | flat_a, tree_def_a = jax.tree_util.tree_flatten(actual)
58 | flat_d, tree_def_d = jax.tree_util.tree_flatten(desired)
59 | if tree_def_a != tree_def_d:
60 | raise AssertionError(
61 | f"Trees have different structures:\n{tree_def_a}\n{tree_def_d}"
62 | )
63 | for x, y in zip(flat_a, flat_d):
64 | np.testing.assert_allclose(x, y, rtol=rtol, atol=atol, err_msg=err_msg)
65 |
66 |
67 | def assert_trees_all_equal(actual, desired, err_msg=None):
68 | """Asserts that two pytrees of arrays are equal."""
69 | flat_a, tree_def_a = jax.tree_util.tree_flatten(actual)
70 | flat_d, tree_def_d = jax.tree_util.tree_flatten(desired)
71 | if tree_def_a != tree_def_d:
72 | raise AssertionError(
73 | f"Trees have different structures:\n{tree_def_a}\n{tree_def_d}"
74 | )
75 | for x, y in zip(flat_a, flat_d):
76 | np.testing.assert_array_equal(x, y, err_msg=err_msg)
77 |
78 |
79 | def assert_trees_all_equal_structs(actual, desired):
80 | """Asserts that two pytrees have the same structure."""
81 | if (jax.tree_util.tree_structure(actual) !=
82 | jax.tree_util.tree_structure(desired)):
83 | raise AssertionError(
84 | f"Trees have different structures:\n{actual}\n{desired}"
85 | )
86 |
87 |
88 | def assert_tree_all_finite(actual, err_msg=None):
89 | """Asserts that all arrays in a pytree are finite."""
90 | for x in jax.tree_util.tree_leaves(actual):
91 | if not np.all(np.isfinite(x)):
92 | raise AssertionError(f"Array {x} is not finite. {err_msg}")
93 |
94 |
95 | def assert_trees_all_equal_shapes(actual, desired, err_msg=None):
96 | """Asserts that two pytrees of arrays have the same shapes."""
97 | assert_trees_all_equal_structs(actual, desired)
98 | for x, y in zip(jax.tree_util.tree_leaves(actual),
99 | jax.tree_util.tree_leaves(desired)):
100 | if x.shape != y.shape:
101 | raise AssertionError(
102 | f"Shapes are not equal: {x.shape} != {y.shape}. {err_msg}"
103 | )
104 |
105 |
106 | def assert_trees_all_equal_dtypes(actual, desired, err_msg=None):
107 | """Asserts that two pytrees of arrays have the same dtypes."""
108 | assert_trees_all_equal_structs(actual, desired)
109 | for x, y in zip(jax.tree_util.tree_leaves(actual),
110 | jax.tree_util.tree_leaves(desired)):
111 | if x.dtype != y.dtype:
112 | raise AssertionError(
113 | f"Dtypes are not equal: {x.dtype} != {y.dtype}. {err_msg}"
114 | )
115 |
--------------------------------------------------------------------------------
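
A minimal usage sketch (not part of the repository) for the `log_compilations` helper defined above. The captured records are appended to the yielded list only when the context exits, so they should be inspected after the `with` block.

import jax
import jax.numpy as jnp

from optax._src import test_utils

with test_utils.log_compilations() as logs:
  # Compiling a fresh function inside the block should emit a record matching
  # the "Finished .* compilation" pattern used by the handler above.
  jax.jit(lambda x: x * 2.0)(jnp.ones(3))

# `logs` is populated on context exit; each entry is a logging.LogRecord.
for record in logs:
  print(record.getMessage())
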
/optax/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The transforms sub-package."""
16 |
17 | # pylint: disable=g-importing-member
18 |
19 | from optax.transforms._accumulation import ema
20 | from optax.transforms._accumulation import EmaState
21 | from optax.transforms._accumulation import MultiSteps
22 | from optax.transforms._accumulation import MultiStepsState
23 | from optax.transforms._accumulation import ShouldSkipUpdateFunction
24 | from optax.transforms._accumulation import skip_large_updates
25 | from optax.transforms._accumulation import skip_not_finite
26 | from optax.transforms._accumulation import trace
27 | from optax.transforms._accumulation import TraceState
28 | from optax.transforms._adding import add_decayed_weights
29 | from optax.transforms._adding import add_noise
30 | from optax.transforms._adding import AddNoiseState
31 | from optax.transforms._clipping import adaptive_grad_clip
32 | from optax.transforms._clipping import clip
33 | from optax.transforms._clipping import clip_by_block_rms
34 | from optax.transforms._clipping import clip_by_global_norm
35 | from optax.transforms._clipping import per_example_global_norm_clip
36 | from optax.transforms._clipping import per_example_layer_norm_clip
37 | from optax.transforms._clipping import unitwise_clip
38 | from optax.transforms._clipping import unitwise_norm
39 | from optax.transforms._combining import chain
40 | from optax.transforms._combining import named_chain
41 | from optax.transforms._combining import partition
42 | from optax.transforms._combining import PartitionState
43 | from optax.transforms._conditionality import apply_if_finite
44 | from optax.transforms._conditionality import ApplyIfFiniteState
45 | from optax.transforms._conditionality import conditionally_mask
46 | from optax.transforms._conditionality import conditionally_transform
47 | from optax.transforms._conditionality import ConditionallyMaskState
48 | from optax.transforms._conditionality import ConditionallyTransformState
49 | from optax.transforms._conditionality import ConditionFn
50 | from optax.transforms._constraining import keep_params_nonnegative
51 | from optax.transforms._constraining import NonNegativeParamsState
52 | from optax.transforms._constraining import zero_nans
53 | from optax.transforms._constraining import ZeroNansState
54 | from optax.transforms._freezing import freeze
55 | from optax.transforms._freezing import selective_transform
56 | from optax.transforms._layouts import flatten
57 | from optax.transforms._masking import masked
58 | from optax.transforms._masking import MaskedNode
59 | from optax.transforms._masking import MaskedState
60 | from optax.transforms._monitoring import measure_with_ema
61 | from optax.transforms._monitoring import monitor
62 | from optax.transforms._monitoring import MonitorState
63 | from optax.transforms._monitoring import snapshot
64 | from optax.transforms._monitoring import SnapshotState
65 |
66 | __all__ = (
67 | "adaptive_grad_clip",
68 | "add_decayed_weights",
69 | "add_noise",
70 | "AddNoiseState",
71 | "apply_if_finite",
72 | "ApplyIfFiniteState",
73 | "chain",
74 | "clip_by_block_rms",
75 | "clip_by_global_norm",
76 | "clip",
77 | "conditionally_mask",
78 | "ConditionallyMaskState",
79 | "conditionally_transform",
80 | "ConditionallyTransformState",
81 | "ema",
82 | "EmaState",
83 | "flatten",
84 | "freeze",
85 | "keep_params_nonnegative",
86 | "masked",
87 | "MaskedState",
88 | "measure_with_ema",
89 | "monitor",
90 | "MonitorState",
91 | "MultiSteps",
92 | "MultiStepsState",
93 | "named_chain",
94 | "NonNegativeParamsState",
95 | "partition",
96 | "PartitionState",
97 | "selective_transform",
98 | "ShouldSkipUpdateFunction",
99 | "skip_large_updates",
100 | "skip_not_finite",
101 | "snapshot",
102 | "SnapshotState",
103 | "trace",
104 | "TraceState",
105 | "zero_nans",
106 | "ZeroNansState",
107 | )
108 |
--------------------------------------------------------------------------------
/optax/losses/_self_supervised_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for self-supervised losses in `optax.losses._self_supervised.py`."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 |
23 | from optax.losses import _self_supervised
24 |
25 |
26 | class NtxentTest(absltest.TestCase):
27 |
28 | def setUp(self):
29 | super().setUp()
30 | self.ys = jnp.array([
31 | [-1.9540, 1.0780],
32 | [0.2380, -0.5703],
33 | [1.8745, -0.0195],
34 | [-0.6719, -1.9210],
35 | ])
36 | self.ys_2 = jnp.array([
37 | [0.0, 0.0],
38 | [0.2380, -0.5703],
39 | [1.8745, -0.0195],
40 | [-0.6719, -1.9210],
41 | ])
42 | self.ts_1 = jnp.array([0, 0, 1, 1])
43 | self.ts_2 = jnp.array([0, 0, 0, 1])
44 | # Calculated expected output
45 | self.exp_1 = jnp.array(14.01032)
46 | self.exp_2 = jnp.array(8.968544)
47 | self.exp_3 = jnp.array(9.2889)
48 |
49 | def test_batched(self):
50 | np.testing.assert_allclose(
51 | jax.jit(_self_supervised.ntxent)(self.ys, self.ts_1),
52 | self.exp_1,
53 | atol=1e-4,
54 | )
55 |
56 | np.testing.assert_allclose(
57 | jax.jit(_self_supervised.ntxent)(self.ys, self.ts_2),
58 | self.exp_2,
59 | atol=1e-4,
60 | )
61 |
62 | np.testing.assert_allclose(
63 | jax.jit(_self_supervised.ntxent)(self.ys_2, self.ts_1),
64 | self.exp_3,
65 | atol=1e-4,
66 | )
67 |
68 |
69 | class TripletMarginLossTest(parameterized.TestCase):
70 |
71 | def setUp(self):
72 | super().setUp()
73 | self.a1 = jnp.ones((2, 2))
74 | self.p1 = jnp.zeros((2, 2))
75 | self.n1 = jnp.ones((2, 2)) * 2
76 | self.a2 = jnp.zeros((2, 2))
77 | self.p2 = jnp.ones((2, 2))
78 | self.n2 = jnp.ones((2, 2)) * 2
79 |
80 | @parameterized.parameters([
81 | {
82 | 'anchor': np.ones((2, 2)),
83 | 'positive': np.zeros((2, 2)),
84 | 'negative': np.ones((2, 2)) * 2,
85 | 'margin': 1.0,
86 | },
87 | {
88 | 'anchor': np.zeros((2, 2)),
89 | 'positive': np.ones((2, 2)),
90 | 'negative': np.ones((2, 2)) * 2,
91 | 'margin': 1.0,
92 | }
93 | ])
94 | def test_batched(self, anchor, positive, negative, margin):
95 | def testing_triplet_margin_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
96 | ap_distance = jnp.sqrt(jnp.sum(jnp.power(a - p, p_norm)) + eps)
97 | an_distance = jnp.sqrt(jnp.sum(jnp.power(a - n, p_norm)) + eps)
98 | return jnp.maximum(ap_distance - an_distance + margin, 0)
99 |
100 | handmade_result = testing_triplet_margin_loss(
101 | a=anchor, p=positive, n=negative, margin=margin
102 | )
103 | result = jax.jit(_self_supervised.triplet_margin_loss)(
104 | anchor, positive, negative
105 | )
106 | np.testing.assert_allclose(result, handmade_result, atol=1e-4)
107 |
108 | @parameterized.parameters([
109 | {
110 | 'anchor': np.ones((2, 2)),
111 | 'positive': np.zeros((2, 2)),
112 | 'negative': np.ones((2, 2)) * 2,
113 | },
114 | ])
115 | def test_vmap(self, anchor, positive, negative):
116 | original_loss = _self_supervised.triplet_margin_loss(anchor, positive,
117 | negative)
118 | anchor_batched = anchor.reshape(1, *anchor.shape)
119 | positive_batched = positive.reshape(1, *positive.shape)
120 | negative_batched = negative.reshape(1, *negative.shape)
121 | vmap_loss = jax.jit(
122 | jax.vmap(_self_supervised.triplet_margin_loss, in_axes=(0, 0, 0)))(
123 | anchor_batched, positive_batched, negative_batched)
124 |     np.testing.assert_allclose(vmap_loss.flatten(), original_loss.flatten(),
125 |                                atol=1e-4)
126 |
127 |
128 | if __name__ == '__main__':
129 | absltest.main()
130 |
--------------------------------------------------------------------------------
/optax/contrib/_privacy_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods in `privacy.py`."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | from optax._src import test_utils
22 | from optax.contrib import _privacy
23 |
24 |
25 | class DifferentiallyPrivateAggregateTest(parameterized.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | self.batch_size = 8
30 | self.params = {
31 | 'key_a': (jnp.zeros((2, 3, 4)), jnp.zeros([])),
32 | 'key_b': jnp.zeros((6, 7)),
33 | }
34 | # Example `i`'s grads are full of `i`s. Important to include 0 to ensure
35 | # there are no divisions by 0 (e.g. in norm clipping)
36 | a = jnp.arange(self.batch_size)
37 | self.per_eg_grads = jax.tree.map(
38 | lambda p: jnp.moveaxis(
39 | a * jnp.ones(p.shape + (self.batch_size,)), -1, 0
40 | ),
41 | self.params,
42 | )
43 |
44 | def test_no_privacy(self):
45 | """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD."""
46 | dp_agg = _privacy.differentially_private_aggregate(
47 | l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, key=0
48 | )
49 | state = dp_agg.init(self.params)
50 | update_fn = jax.jit(dp_agg.update)
51 | mean_grads = jax.tree.map(lambda g: g.mean(0), self.per_eg_grads)
52 |
53 | for _ in range(3):
54 | updates, state = update_fn(self.per_eg_grads, state)
55 | test_utils.assert_trees_all_close(updates, mean_grads)
56 |
57 | @parameterized.parameters(0.5, 10.0, 20.0, 40.0, 80.0)
58 | def test_clipping_norm(self, l2_norm_clip):
59 | dp_agg = _privacy.differentially_private_aggregate(
60 | l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, key=42
61 | )
62 | state = dp_agg.init(self.params)
63 | update_fn = jax.jit(dp_agg.update)
64 |
65 | # Shape of the three arrays below is (self.batch_size, )
66 | norms = [
67 | jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
68 | for g in jax.tree.leaves(self.per_eg_grads)
69 | ]
70 | global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
71 | divisors = jnp.maximum(global_norms / l2_norm_clip, 1.0)
72 | # Since the values of all the parameters are the same within each example,
73 | # we can easily compute what the values should be:
74 | expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors)
75 | expected_tree = jax.tree.map(
76 | lambda p: jnp.broadcast_to(expected_val, p.shape), self.params
77 | )
78 |
79 | for _ in range(3):
80 | updates, state = update_fn(self.per_eg_grads, state, self.params)
81 | test_utils.assert_trees_all_close(updates, expected_tree, rtol=2e-7)
82 |
83 | @parameterized.parameters((3.0, 2.0), (1.0, 5.0), (100.0, 4.0), (1.0, 90.0))
84 | def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
85 | """Standard dev. of noise should be l2_norm_clip * noise_multiplier."""
86 | dp_agg = _privacy.differentially_private_aggregate(
87 | l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, key=1337
88 | )
89 | state = dp_agg.init(self.params)
90 | update_fn = jax.jit(dp_agg.update)
91 | expected_std = l2_norm_clip * noise_multiplier
92 |
93 | grads = [jnp.ones((1, 100, 100))] # batch size 1
94 | for _ in range(3):
95 | updates, state = update_fn(grads, state)
96 | test_utils.assert_trees_all_close(
97 | expected_std, jnp.std(updates[0]), atol=0.1 * expected_std
98 | )
99 |
100 | def test_aggregated_updates_as_input_fails(self):
101 | """Expect per-example gradients as input to this transform."""
102 | dp_agg = _privacy.differentially_private_aggregate(
103 | l2_norm_clip=0.1, noise_multiplier=1.1, key=2021
104 | )
105 | state = dp_agg.init(self.params)
106 | mean_grads = jax.tree.map(lambda g: g.mean(0), self.per_eg_grads)
107 | with self.assertRaises(ValueError):
108 | dp_agg.update(mean_grads, state, self.params)
109 |
110 |
111 | if __name__ == '__main__':
112 | absltest.main()
113 |
--------------------------------------------------------------------------------
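
As the tests above illustrate, `differentially_private_aggregate` expects per-example gradients (every leaf carries a leading batch axis) rather than already-averaged gradients. A short illustrative sketch of the intended call pattern, with made-up shapes:

import jax.numpy as jnp
from optax import contrib

params = {'w': jnp.zeros((3,))}
# Per-example gradients: the leading axis is the batch dimension (here 4).
per_example_grads = {'w': jnp.ones((4, 3))}

dp_agg = contrib.differentially_private_aggregate(
    l2_norm_clip=1.0, noise_multiplier=1.1, key=0)
state = dp_agg.init(params)

# Returns clipped, noised and averaged gradients plus the new state.
updates, state = dp_agg.update(per_example_grads, state)
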
/optax/contrib/_complex_valued.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Complex-valued optimization.
16 |
17 | When using `split_real_and_imaginary` to wrap an optimizer, we split the complex
18 | parameters and updates into pairs of real ones before sending them to the
19 | `update` of the wrapped optimizer, and merge the pairs of transformed real
20 | updates into complex ones afterward. In this way, optimizers on complex
21 | parameters behave the same way as if they were running on two real parameters.
22 |
23 | Note that the convention of conjugate for complex gradients in JAX is different
24 | from that in PyTorch and other frameworks, and we need to manually conjugate the
25 | gradients between `jax.grad` and `optimizer.update`.
26 |
27 | See details at https://github.com/deepmind/optax/issues/196
28 | """
29 |
30 | from typing import NamedTuple, Union
31 |
32 | import jax
33 | import jax.numpy as jnp
34 | from optax._src import base
35 |
36 |
37 | # NOTE(dsuo): Opt out of using the new `jax.pmap` implementation. There is
38 | # a C++ failure in windowing that needs to be resolved.
39 | jax.config.update('jax_pmap_shmap_merge', False)
40 |
41 |
42 | class SplitRealAndImaginaryArrays(NamedTuple):
43 | """A pair of real arrays split from a complex array."""
44 |
45 | real: jax.typing.ArrayLike
46 | imaginary: jax.typing.ArrayLike
47 |
48 |
49 | def _complex_to_real_pair(
50 | x: jax.typing.ArrayLike,
51 | ) -> Union[jax.typing.ArrayLike, SplitRealAndImaginaryArrays]:
52 | """Splits a complex array into a `SplitRealAndImaginaryArrays`.
53 |
54 | Args:
55 | x: The input array, can be complex or real.
56 |
57 | Returns:
58 | `SplitRealAndImaginaryArrays` if the input is a complex array. If the
59 | input is a real array, it is passed through unmodified.
60 | """
61 | if jnp.iscomplexobj(x):
62 | return SplitRealAndImaginaryArrays(x.real, x.imag)
63 | else:
64 | return x
65 |
66 |
67 | def _real_pair_to_complex(
68 | x: Union[jax.typing.ArrayLike, SplitRealAndImaginaryArrays],
69 | ) -> jax.typing.ArrayLike:
70 | """Merges a `SplitRealAndImaginaryArrays` into a complex array.
71 |
72 | Args:
73 | x: The input `SplitRealAndImaginaryArrays` or array.
74 |
75 | Returns:
76 | A complex array obtained from the real and imaginary parts of the
77 | `SplitRealAndImaginaryArrays`. If the input is not a
78 | `SplitRealAndImaginaryArrays`, it is passed through unmodified.
79 | """
80 | if isinstance(x, SplitRealAndImaginaryArrays):
81 | return x.real + x.imaginary * 1j
82 | else:
83 | return x
84 |
85 |
86 | class SplitRealAndImaginaryState(NamedTuple):
87 | """Maintains the inner transformation state for `split_real_and_imaginary`."""
88 |
89 | inner_state: base.OptState
90 |
91 |
92 | def split_real_and_imaginary(
93 | inner: base.GradientTransformation,
94 | ) -> base.GradientTransformation:
95 | """Splits the real and imaginary components of complex updates into two.
96 |
97 | The inner transformation processes real parameters and updates, and the
98 | pairs of transformed real updates are merged into complex updates.
99 |
100 | Parameters and updates that are real before splitting are passed through
101 | unmodified.
102 |
103 | Args:
104 | inner: The inner transformation.
105 |
106 | Returns:
107 | An `optax.GradientTransformation`.
108 | """
109 |
110 | def init_fn(params):
111 | params = jax.tree.map(_complex_to_real_pair, params)
112 | inner_state = inner.init(params)
113 | return SplitRealAndImaginaryState(inner_state)
114 |
115 | def update_fn(updates, state, params=None):
116 | inner_state = state.inner_state
117 | updates = jax.tree.map(_complex_to_real_pair, updates)
118 | params = jax.tree.map(_complex_to_real_pair, params)
119 | updates, inner_state = inner.update(updates, inner_state, params)
120 | updates = jax.tree.map(
121 | _real_pair_to_complex,
122 | updates,
123 | is_leaf=lambda x: isinstance(x, SplitRealAndImaginaryArrays),
124 | )
125 | return updates, SplitRealAndImaginaryState(inner_state)
126 |
127 | return base.GradientTransformation(init_fn, update_fn)
128 |
--------------------------------------------------------------------------------
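
A small usage sketch (not part of the repository) of `split_real_and_imaginary` wrapping a standard optimizer for complex parameters, following the conjugation note in the module docstring above:

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.array([1.0 + 2.0j, 3.0 - 1.0j])}
loss = lambda p: jnp.sum(jnp.abs(p['w']) ** 2)

opt = optax.contrib.split_real_and_imaginary(optax.adam(1e-1))
state = opt.init(params)

grads = jax.grad(loss)(params)
# Per the module docstring, JAX's convention for complex gradients differs
# from PyTorch's, so conjugate before handing them to the optimizer.
grads = jax.tree.map(jnp.conj, grads)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)
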
/optax/tree_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The tree_utils sub-package."""
16 |
17 | import functools
18 | import typing
19 |
20 | # pylint: disable=g-importing-member
21 |
22 | from optax.tree_utils._casting import tree_cast
23 | from optax.tree_utils._casting import tree_cast_like
24 | from optax.tree_utils._casting import tree_dtype
25 | from optax.tree_utils._random import tree_random_like
26 | from optax.tree_utils._random import tree_split_key_like
27 | from optax.tree_utils._random import tree_unwrap_random_key_data
28 | from optax.tree_utils._state_utils import NamedTupleKey
29 | from optax.tree_utils._state_utils import tree_get
30 | from optax.tree_utils._state_utils import tree_get_all_with_path
31 | from optax.tree_utils._state_utils import tree_map_params
32 | from optax.tree_utils._state_utils import tree_set
33 | from optax.tree_utils._tree_math import tree_add
34 | from optax.tree_utils._tree_math import tree_add_scale
35 | from optax.tree_utils._tree_math import tree_allclose
36 | from optax.tree_utils._tree_math import tree_batch_shape
37 | from optax.tree_utils._tree_math import tree_bias_correction
38 | from optax.tree_utils._tree_math import tree_clip
39 | from optax.tree_utils._tree_math import tree_conj
40 | from optax.tree_utils._tree_math import tree_div
41 | from optax.tree_utils._tree_math import tree_full_like
42 | from optax.tree_utils._tree_math import tree_max
43 | from optax.tree_utils._tree_math import tree_min
44 | from optax.tree_utils._tree_math import tree_mul
45 | from optax.tree_utils._tree_math import tree_norm
46 | from optax.tree_utils._tree_math import tree_ones_like
47 | from optax.tree_utils._tree_math import tree_real
48 | from optax.tree_utils._tree_math import tree_scale
49 | from optax.tree_utils._tree_math import tree_size
50 | from optax.tree_utils._tree_math import tree_sub
51 | from optax.tree_utils._tree_math import tree_sum
52 | from optax.tree_utils._tree_math import tree_update_infinity_moment
53 | from optax.tree_utils._tree_math import tree_update_moment
54 | from optax.tree_utils._tree_math import tree_update_moment_per_elem_norm
55 | from optax.tree_utils._tree_math import tree_vdot
56 | from optax.tree_utils._tree_math import tree_where
57 | from optax.tree_utils._tree_math import tree_zeros_like
58 |
59 | _deprecations = {
60 | # Added Mar 2025
61 | 'tree_scalar_mul': (
62 | ('optax.tree_utils.tree_scalar_mul is deprecated: use'
63 | ' optax.tree_utils.tree_scale (optax v0.2.5 or newer).'),
64 | tree_scale,
65 | ),
66 | 'tree_add_scalar_mul': (
67 |         ('optax.tree_utils.tree_add_scalar_mul is deprecated: use'
68 |          ' optax.tree_utils.tree_add_scale (optax v0.2.5 or newer).'),
69 | tree_add_scale,
70 | ),
71 | # Added May 2025
72 | 'tree_l1_norm': (
73 | ('optax.tree_utils.tree_l1_norm is deprecated: use'
74 | ' optax.tree_utils.tree_norm(..., ord=1) (optax v0.2.5 or newer).'),
75 | functools.partial(tree_norm, ord=1),
76 | ),
77 | 'tree_l2_norm': (
78 | ('optax.tree_utils.tree_l2_norm is deprecated: use'
79 | ' optax.tree_utils.tree_norm (optax v0.2.5 or newer).'),
80 | functools.partial(tree_norm, ord=2),
81 | ),
82 | 'tree_linf_norm': (
83 | ('optax.tree_utils.tree_linf_norm is deprecated: use'
84 | ' optax.tree_utils.tree_norm(..., ord=jnp.inf)'
85 | ' (optax v0.2.5 or newer).'),
86 | functools.partial(tree_norm, ord='inf'),
87 | ),
88 | }
89 |
90 | # pylint: disable=g-import-not-at-top
91 | # pylint: disable=g-bad-import-order
92 | if typing.TYPE_CHECKING:
93 | tree_scalar_mul = tree_scale
94 | tree_add_scalar_mul = tree_add_scale
95 | tree_l1_norm = functools.partial(tree_norm, ord=1)
96 | tree_l2_norm = tree_norm
97 | tree_linf_norm = functools.partial(tree_norm, ord='inf')
98 |
99 | else:
100 | # pylint: disable=line-too-long
101 | from optax._src.deprecations import deprecation_getattr as _deprecation_getattr # noqa: E501
102 | # pylint: enable=line-too-long
103 |
104 | __getattr__ = _deprecation_getattr(__name__, _deprecations)
105 | del _deprecation_getattr
106 | # pylint: enable=g-bad-import-order
107 | # pylint: enable=g-import-not-at-top
108 |
--------------------------------------------------------------------------------
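
The deprecation table above maps the old tree helpers to their replacements. A brief migration sketch follows; the argument order shown for `tree_scale` and `tree_add_scale` is assumed to match the old `tree_scalar_mul(scalar, tree)` and `tree_add_scalar_mul(tree_x, scalar, tree_y)` conventions.

import jax.numpy as jnp
import optax.tree_utils as otu

tree = {'a': jnp.array([3.0, 4.0]), 'b': jnp.array(2.0)}

otu.tree_norm(tree)               # replaces tree_l2_norm
otu.tree_norm(tree, ord=1)        # replaces tree_l1_norm
otu.tree_norm(tree, ord=jnp.inf)  # replaces tree_linf_norm
otu.tree_scale(0.5, tree)             # replaces tree_scalar_mul
otu.tree_add_scale(tree, 0.5, tree)   # replaces tree_add_scalar_mul
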
/optax/tree_utils/_casting_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for tree utilities on data types."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | from optax import tree_utils as otu
23 |
24 |
25 | class CastingTest(parameterized.TestCase):
26 |
27 | @parameterized.parameters([
28 | (jnp.float32, [1.3, 2.001, 3.6], [-3.3], [1.3, 2.001, 3.6], [-3.3]),
29 | (jnp.float32, [1.3, 2.001, 3.6], [-3], [1.3, 2.001, 3.6], [-3.0]),
30 | (jnp.int32, [1.3, 2.001, 3.6], [-3.3], [1, 2, 3], [-3]),
31 | (jnp.int32, [1.3, 2.001, 3.6], [-3], [1, 2, 3], [-3]),
32 | (None, [1.123, 2.33], [0.0], [1.123, 2.33], [0.0]),
33 | (None, [1, 2, 3], [0.0], [1, 2, 3], [0.0]),
34 | ])
35 | def test_tree_cast(self, dtype, b, c, new_b, new_c):
36 | def _build_tree(val1, val2):
37 | dict_tree = {'a': {'b': jnp.array(val1)}, 'c': jnp.array(val2)}
38 | return jax.tree.map(lambda x: x, dict_tree)
39 |
40 | tree = _build_tree(b, c)
41 | tree = otu.tree_cast(tree, dtype=dtype)
42 | jax.tree.map(np.testing.assert_array_equal, tree, _build_tree(new_b, new_c))
43 |
44 | @parameterized.parameters([
45 | (jnp.float16, [1.3, 2.001, 3.6], [-3]),
46 | (jnp.bfloat16, [1.3, 2.001, 3.6], [-3.3]),
47 | (jnp.float32, [1.3, 2.001, 3.6], [-3.3]),
48 | (jnp.int32, [1.3, 2.001, 3.6], [-3]),
49 | ])
50 | def test_tree_cast_like(self, dtype, b, c):
51 | def _build_tree(val1, val2):
52 | dict_tree = {'a': {'b': jnp.array(val1)}, 'c': jnp.array(val2)}
53 | return jax.tree.map(lambda x: x, dict_tree)
54 |
55 | tree = _build_tree(b, c)
56 | target_tree = _build_tree(b, c)
57 | target_tree = jax.tree.map(lambda x: x.astype(dtype), target_tree)
58 | tree = otu.tree_cast_like(tree, target_tree)
59 | jax.tree.map(np.testing.assert_array_equal, tree, target_tree)
60 |
61 | def test_tree_dtype(self):
62 |     """Test fetching data type of a tree."""
63 |
64 | with self.subTest('Check that it returns the right dtype'):
65 | tree = {
66 | 'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
67 | 'c': jnp.array(2.0, dtype=jnp.float32),
68 | }
69 | dtype = otu.tree_dtype(tree)
70 | self.assertEqual(dtype, jnp.float32)
71 |
72 | with self.subTest('Check that it raises an error if dtypes differ'):
73 | tree = {
74 | 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)},
75 | 'c': jnp.array(2.0, dtype=jnp.float32),
76 | }
77 | self.assertRaises(ValueError, otu.tree_dtype, tree)
78 |
79 | tree = {
80 | 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)},
81 | 'c': jnp.array(2.0, dtype=jnp.float32),
82 | }
83 |
84 | with self.subTest('Check that it works with lowest common dtype'):
85 | dtype = otu.tree_dtype(tree, 'lowest')
86 | self.assertEqual(dtype, jnp.bfloat16)
87 |
88 | with self.subTest('Check that it works with highest common dtype'):
89 | dtype = otu.tree_dtype(tree, 'highest')
90 | self.assertEqual(dtype, jnp.float32)
91 |
92 | tree = {
93 | 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)},
94 | 'c': jnp.array(2.0, dtype=jnp.float16),
95 | }
96 |
97 | with self.subTest('Check that it works when promoting mixed dtype'):
98 | dtype = otu.tree_dtype(tree, 'promote')
99 | self.assertEqual(dtype, jnp.float32)
100 |
101 | with self.subTest(
102 |         'Check that it raises an error if dtypes cannot be promoted to'
103 |         ' one another'
104 | ):
105 | self.assertRaises(ValueError, otu.tree_dtype, tree, 'lowest')
106 | self.assertRaises(ValueError, otu.tree_dtype, tree, 'highest')
107 |
108 | @parameterized.named_parameters(
109 | {'testcase_name': 'empty_dict', 'tree': {}},
110 | {'testcase_name': 'empty_list', 'tree': []},
111 | {'testcase_name': 'empty_tuple', 'tree': ()},
112 | {'testcase_name': 'empty_none', 'tree': None},
113 | )
114 | def test_tree_dtype_utilities_with_empty_trees(self, tree):
115 | """Test tree data type utilities on empty trees."""
116 | default_dtype = jnp.asarray(1.0).dtype
117 |
118 | with self.subTest('Check tree_dtype works with empty trees.'):
119 | dtype = otu.tree_dtype(tree)
120 | self.assertEqual(dtype, default_dtype)
121 |
122 |
123 | if __name__ == '__main__':
124 | absltest.main()
125 |
--------------------------------------------------------------------------------
/optax/contrib/_cocob.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Backpropagating variant of the COntinuous COin Betting stochastic algorithm.
16 |
17 | COCOB is a contributed optimizer implemented from Algorithm 2 of "Training Deep
18 | Networks without Learning Rates Through Coin Betting" by Francesco Orabona and
19 | Tatiana Tommasi.
20 | """
21 | from collections.abc import Callable
22 | from typing import Any, NamedTuple, Optional, Union
23 |
24 | import jax
25 | import jax.numpy as jnp
26 | from optax._src import base
27 | from optax._src import combine
28 | from optax._src import transform
29 | import optax.tree
30 |
31 |
32 | class COCOBState(NamedTuple):
33 | """State for COntinuous COin Betting."""
34 |
35 | init_particles: base.Updates
36 | cumulative_gradients: base.Updates
37 | scale: base.Updates
38 | subgradients: base.Updates
39 | reward: base.Updates
40 |
41 |
42 | def scale_by_cocob(
43 | alpha: jax.typing.ArrayLike = 100.0, eps: jax.typing.ArrayLike = 1e-8
44 | ) -> base.GradientTransformation:
45 | """Rescale updates according to the COntinuous COin Betting algorithm.
46 |
47 | See :func:`optax.contrib.cocob` for more details.
48 |
49 | Args:
50 | alpha: fraction to bet parameter of the COCOB optimizer
51 | eps: jitter term to avoid dividing by 0
52 |
53 | Returns:
54 | A `GradientTransformation` object.
55 | """
56 |
57 | def init_fn(params):
58 | init_adapt = optax.tree.zeros_like(params)
59 | init_scale = optax.tree.ones_like(params)
60 | init_scale = optax.tree.scale(eps, init_scale)
61 | return COCOBState(
62 | init_particles=params,
63 | cumulative_gradients=init_adapt,
64 | scale=init_scale,
65 | subgradients=init_adapt,
66 | reward=init_adapt,
67 | )
68 |
69 | def update_fn(updates, state, params):
70 | init_particles, cumulative_grads, scale, subgradients, reward = state
71 |
72 | scale = jax.tree.map(
73 | lambda L, c: jnp.maximum(L, jnp.abs(c)), scale, updates
74 | )
75 | subgradients = jax.tree.map(
76 | lambda G, c: G + jnp.abs(c), subgradients, updates
77 | )
78 | reward = jax.tree.map(
79 | lambda R, c, p, p0: jnp.maximum(R - c * (p - p0), 0),
80 | reward,
81 | updates,
82 | params,
83 | init_particles,
84 | )
85 | cumulative_grads = jax.tree.map(
86 | lambda C, c: C - c, cumulative_grads, updates
87 | )
88 |
89 | new_updates = jax.tree.map(
90 | lambda p, p0, C, L, G, R: (
91 | -p + (p0 + C / (L * jnp.maximum(G + L, alpha * L)) * (L + R))
92 | ),
93 | params,
94 | init_particles,
95 | cumulative_grads,
96 | scale,
97 | subgradients,
98 | reward,
99 | )
100 |
101 | new_state = COCOBState(
102 | init_particles=init_particles,
103 | cumulative_gradients=cumulative_grads,
104 | scale=scale,
105 | subgradients=subgradients,
106 | reward=reward,
107 | )
108 | return new_updates, new_state
109 |
110 | return base.GradientTransformation(init_fn, update_fn)
111 |
112 |
113 | def cocob(
114 | learning_rate: base.ScalarOrSchedule = 1.0,
115 | alpha: jax.typing.ArrayLike = 100.0,
116 | eps: jax.typing.ArrayLike = 1e-8,
117 | weight_decay: jax.typing.ArrayLike = 0.0,
118 | mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
119 | ) -> base.GradientTransformation:
120 | """Rescale updates according to the COntinuous COin Betting algorithm.
121 |
122 | Algorithm for stochastic subgradient descent. Uses a gambling algorithm to
123 | find the minimizer of a non-smooth objective function by accessing its
124 |   subgradients. All we need is a good gambling strategy; see Algorithm 2 of the reference below.
125 |
126 | Args:
127 | learning_rate: optional learning rate to e.g. inject some scheduler
128 | alpha: fraction to bet parameter of the COCOB optimizer
129 | eps: jitter term to avoid dividing by 0
130 | weight_decay: L2 penalty
131 | mask: mask for weight decay
132 |
133 | Returns:
134 | A `GradientTransformation` object.
135 |
136 | References:
137 |     Orabona et al, `Training Deep Networks without Learning Rates Through
138 |     Coin Betting <https://arxiv.org/abs/1705.07795>`_, 2017
139 | """
140 | return combine.chain(
141 | transform.add_decayed_weights(weight_decay, mask),
142 | transform.scale_by_learning_rate(learning_rate, flip_sign=False),
143 | scale_by_cocob(alpha, eps),
144 | )
145 |
--------------------------------------------------------------------------------
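
A brief, illustrative sketch of using the `cocob` optimizer defined above as a drop-in `GradientTransformation`; note that its `update` requires the current params:

import jax
import jax.numpy as jnp
import optax

opt = optax.contrib.cocob()
params = jnp.array([1.0, -2.0, 3.0])
state = opt.init(params)

grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)
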
/docs/api/utilities.rst:
--------------------------------------------------------------------------------
1 | Utilities
2 | =========
3 |
4 | General
5 | -------
6 |
7 | .. currentmodule:: optax
8 |
9 | .. autosummary::
10 | scale_gradient
11 | value_and_grad_from_state
12 |
13 | Scale gradient
14 | ~~~~~~~~~~~~~~
15 | .. autofunction:: scale_gradient
16 |
17 | Value and grad from state
18 | ~~~~~~~~~~~~~~~~~~~~~~~~~
19 | .. autofunction:: value_and_grad_from_state
20 |
21 |
22 | Numerical Stability
23 | -------------------
24 |
25 | .. currentmodule:: optax
26 |
27 | .. autosummary::
28 | safe_increment
29 | safe_norm
30 | safe_root_mean_squares
31 |
32 | Safe increment
33 | ~~~~~~~~~~~~~~
34 | .. autofunction:: safe_increment
35 |
36 | Safe norm
37 | ~~~~~~~~~
38 | .. autofunction:: safe_norm
39 |
40 | Safe root mean squares
41 | ~~~~~~~~~~~~~~~~~~~~~~
42 | .. autofunction:: safe_root_mean_squares
43 |
44 |
45 | Linear Algebra Operators
46 | ------------------------
47 |
48 | .. currentmodule:: optax
49 |
50 | .. autosummary::
51 | matrix_inverse_pth_root
52 | power_iteration
53 | nnls
54 |
55 | Matrix inverse pth root
56 | ~~~~~~~~~~~~~~~~~~~~~~~~
57 | .. autofunction:: matrix_inverse_pth_root
58 |
59 | Power iteration
60 | ~~~~~~~~~~~~~~~
61 | .. autofunction:: power_iteration
62 |
63 | Non-negative least squares
64 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
65 | .. autofunction:: nnls
66 |
67 |
68 | Second Order Optimization
69 | -------------------------
70 |
71 | .. currentmodule:: optax.second_order
72 |
73 | .. autosummary::
74 | fisher_diag
75 | hessian_diag
76 | hvp
77 |
78 | Fisher diagonal
79 | ~~~~~~~~~~~~~~~
80 | .. autofunction:: fisher_diag
81 |
82 | Hessian diagonal
83 | ~~~~~~~~~~~~~~~~
84 | .. autofunction:: hessian_diag
85 |
86 | Hessian vector product
87 | ~~~~~~~~~~~~~~~~~~~~~~
88 | .. autofunction:: hvp
89 |
90 |
91 | Tree
92 | ----
93 |
94 | .. currentmodule:: optax.tree_utils
95 |
96 | .. autosummary::
97 | NamedTupleKey
98 | tree_add
99 | tree_add_scale
100 | tree_allclose
101 | tree_batch_shape
102 | tree_cast
103 | tree_cast_like
104 | tree_clip
105 | tree_conj
106 | tree_div
107 | tree_dtype
108 | tree_full_like
109 | tree_get
110 | tree_get_all_with_path
111 | tree_norm
112 | tree_map_params
113 | tree_max
114 | tree_min
115 | tree_mul
116 | tree_ones_like
117 | tree_random_like
118 | tree_real
119 | tree_split_key_like
120 | tree_scale
121 | tree_set
122 | tree_size
123 | tree_sub
124 | tree_sum
125 | tree_vdot
126 | tree_where
127 | tree_zeros_like
128 |
129 | NamedTupleKey
130 | ~~~~~~~~~~~~~
131 | .. autoclass:: NamedTupleKey
132 |
133 | Tree add
134 | ~~~~~~~~
135 | .. autofunction:: tree_add
136 |
137 | Tree add and scalar multiply
138 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
139 | .. autofunction:: tree_add_scale
140 |
141 | Tree all close
142 | ~~~~~~~~~~~~~~
143 | .. autofunction:: tree_allclose
144 |
145 | Tree batch reshaping
146 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
147 | .. autofunction:: tree_batch_shape
148 |
149 | Tree cast
150 | ~~~~~~~~~
151 | .. autofunction:: tree_cast
152 |
153 | Tree cast like
154 | ~~~~~~~~~~~~~~
155 | .. autofunction:: tree_cast_like
156 |
157 | Tree clip
158 | ~~~~~~~~~
159 | .. autofunction:: tree_clip
160 |
161 | Tree conjugate
162 | ~~~~~~~~~~~~~~
163 | .. autofunction:: tree_conj
164 |
165 | Tree data type
166 | ~~~~~~~~~~~~~~
167 | .. autofunction:: tree_dtype
168 |
169 | Tree full like
170 | ~~~~~~~~~~~~~~
171 | .. autofunction:: tree_full_like
172 |
173 | Tree divide
174 | ~~~~~~~~~~~
175 | .. autofunction:: tree_div
176 |
177 | Fetch single value that match a given key
178 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
179 | .. autofunction:: tree_get
180 |
181 | Fetch all values that match a given key
182 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
183 | .. autofunction:: tree_get_all_with_path
184 |
185 | Tree norm
186 | ~~~~~~~~~
187 | .. autofunction:: tree_norm
188 |
189 | Tree map parameters
190 | ~~~~~~~~~~~~~~~~~~~
191 | .. autofunction:: tree_map_params
192 |
193 | Tree max
194 | ~~~~~~~~
195 | .. autofunction:: tree_max
196 |
197 | Tree min
198 | ~~~~~~~~
199 | .. autofunction:: tree_min
200 |
201 | Tree multiply
202 | ~~~~~~~~~~~~~
203 | .. autofunction:: tree_mul
204 |
205 | Tree ones like
206 | ~~~~~~~~~~~~~~
207 | .. autofunction:: tree_ones_like
208 |
209 | Split key according to structure of a tree
210 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
211 | .. autofunction:: tree_split_key_like
212 |
213 | Tree with random values
214 | ~~~~~~~~~~~~~~~~~~~~~~~
215 | .. autofunction:: tree_random_like
216 |
217 | Tree real part
218 | ~~~~~~~~~~~~~~
219 | .. autofunction:: tree_real
220 |
221 | Tree scalar multiply
222 | ~~~~~~~~~~~~~~~~~~~~
223 | .. autofunction:: tree_scale
224 |
225 | Set values in a tree
226 | ~~~~~~~~~~~~~~~~~~~~
227 | .. autofunction:: tree_set
228 |
229 | Tree size
230 | ~~~~~~~~~
231 | .. autofunction:: tree_size
232 |
233 | Tree subtract
234 | ~~~~~~~~~~~~~
235 | .. autofunction:: tree_sub
236 |
237 | Tree sum
238 | ~~~~~~~~
239 | .. autofunction:: tree_sum
240 |
241 | Tree inner product
242 | ~~~~~~~~~~~~~~~~~~
243 | .. autofunction:: tree_vdot
244 |
245 | Tree where
246 | ~~~~~~~~~~
247 | .. autofunction:: tree_where
248 |
249 | Tree zeros like
250 | ~~~~~~~~~~~~~~~
251 | .. autofunction:: tree_zeros_like
252 |
--------------------------------------------------------------------------------
/optax/transforms/_constraining_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for methods in `optax.transforms._constraining.py`."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 | import jax.numpy as jnp
20 | from optax._src import combine
21 | from optax._src import test_utils
22 | from optax._src import transform
23 | from optax._src import update
24 | from optax.transforms import _accumulation
25 | from optax.transforms import _constraining
26 |
27 |
28 | STEPS = 50
29 | LR = 1e-2
30 |
31 |
32 | class ConstraintsTest(absltest.TestCase):
33 |
34 | def test_keep_params_nonnegative(self):
35 | grads = (
36 | jnp.array([500.0, -500.0, 0.0]),
37 | jnp.array([500.0, -500.0, 0.0]),
38 | jnp.array([500.0, -500.0, 0.0]),
39 | )
40 |
41 | params = (
42 | jnp.array([-1.0, -1.0, -1.0]),
43 | jnp.array([1.0, 1.0, 1.0]),
44 | jnp.array([0.0, 0.0, 0.0]),
45 | )
46 |
47 | # vanilla sgd
48 | opt = combine.chain(
49 | _accumulation.trace(decay=0, nesterov=False), transform.scale(-LR)
50 | )
51 | opt_state = opt.init(params)
52 |
53 | updates, _ = opt.update(grads, opt_state, params)
54 | new_params = update.apply_updates(params, updates)
55 |
56 | test_utils.assert_trees_all_close(
57 | new_params,
58 | (
59 | jnp.array([-6.0, 4.0, -1.0]),
60 | jnp.array([-4.0, 6.0, 1.0]),
61 | jnp.array([-5.0, 5.0, 0.0]),
62 | ),
63 | )
64 |
65 | # sgd with keeping parameters non-negative
66 | opt = combine.chain(
67 | _accumulation.trace(decay=0, nesterov=False),
68 | transform.scale(-LR),
69 | _constraining.keep_params_nonnegative(),
70 | )
71 | opt_state = opt.init(params)
72 |
73 | updates, _ = opt.update(grads, opt_state, params)
74 | new_params = update.apply_updates(params, updates)
75 |
76 | test_utils.assert_trees_all_close(
77 | new_params,
78 | (
79 | jnp.array([0.0, 4.0, 0.0]),
80 | jnp.array([0.0, 6.0, 1.0]),
81 | jnp.array([0.0, 5.0, 0.0]),
82 | ),
83 | )
84 |
85 | def test_zero_nans(self):
86 | params = (jnp.zeros([3]), jnp.zeros([3]), jnp.zeros([3]))
87 |
88 | opt = _constraining.zero_nans()
89 | opt_state = jax.jit(opt.init)(params)
90 | update_fn = jax.jit(opt.update)
91 |
92 | test_utils.assert_trees_all_close(
93 | opt_state, _constraining.ZeroNansState((jnp.array(False),) * 3)
94 | )
95 |
96 |     # Check an update with nans
97 | grads_with_nans = (
98 | jnp.ones([3]),
99 | jnp.array([1.0, float('nan'), float('nan')]),
100 | jnp.array([float('nan'), 1.0, 1.0]),
101 | )
102 | updates, opt_state = update_fn(grads_with_nans, opt_state)
103 | test_utils.assert_trees_all_close(
104 | opt_state,
105 | _constraining.ZeroNansState(
106 | (jnp.array(False), jnp.array(True), jnp.array(True))
107 | ),
108 | )
109 | test_utils.assert_trees_all_close(
110 | updates,
111 | (jnp.ones([3]), jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 1.0])),
112 | )
113 |
114 |     # Check an update with nans and infs
115 | grads_with_nans_infs = (
116 | jnp.ones([3]),
117 | jnp.array([1.0, float('nan'), float('nan')]),
118 | jnp.array([float('inf'), 1.0, 1.0]),
119 | )
120 | updates, opt_state = update_fn(grads_with_nans_infs, opt_state)
121 | test_utils.assert_trees_all_close(
122 | opt_state,
123 | _constraining.ZeroNansState(
124 | (jnp.array(False), jnp.array(True), jnp.array(False))
125 | ),
126 | )
127 | test_utils.assert_trees_all_close(
128 | updates,
129 | (
130 | jnp.ones([3]),
131 | jnp.array([1.0, 0.0, 0.0]),
132 | jnp.array([float('inf'), 1.0, 1.0]),
133 | ),
134 | )
135 |
136 | # Check an update with only good values
137 | grads = (jnp.ones([3]), jnp.ones([3]), jnp.ones([3]))
138 | updates, opt_state = update_fn(grads, opt_state)
139 | test_utils.assert_trees_all_close(
140 | opt_state,
141 | _constraining.ZeroNansState(
142 | (jnp.array(False), jnp.array(False), jnp.array(False))
143 | ),
144 | )
145 | test_utils.assert_trees_all_close(updates, grads)
146 |
147 | def test_none_arguments(self):
148 | tf = _constraining.keep_params_nonnegative()
149 | state = tf.init(jnp.array([1.0, 2.0, 3.0]))
150 | with self.assertRaises(ValueError):
151 | tf.update(jnp.array([1.0, 2.0, 3.0]), state, None)
152 |
153 |
154 | if __name__ == '__main__':
155 | absltest.main()
156 |
--------------------------------------------------------------------------------
/optax/_src/sharding_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Module for testing sharding and related behavior of the optax public API."""
17 |
18 | import os
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import jax
23 | import jax.numpy as jnp
24 | import optax
25 | from optax._src import test_utils
26 | from optax.experimental import microbatching
27 |
28 |
29 | os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
30 |
31 | OPTIMIZERS = {
32 | 'adam': optax.adam(1.0),
33 | 'sgd': optax.sgd(1.0),
34 | 'adabelief': optax.adabelief(1.0),
35 | 'adamax': optax.adamax(1.0),
36 | 'adagrad': optax.adagrad(1.0),
37 | 'adamw': optax.adamw(1.0),
38 | 'rmsprop': optax.rmsprop(1.0),
39 | # TODO(mckennar): try to incorporate linesearch into the test.
40 | 'lbfgs': optax.lbfgs(1.0, linesearch=None),
41 | 'adadelta': optax.adadelta(1.0),
42 | 'adafactor': optax.adafactor(),
43 | 'adafactor2': optax.adafactor(min_dim_size_to_factor=1),
44 | 'adan': optax.adan(1.0),
45 | 'adamaxw': optax.adamaxw(1.0),
46 | 'amsgrad': optax.amsgrad(1.0),
47 | 'fromage': optax.fromage(1.0),
48 | 'lamb': optax.lamb(1.0),
49 | 'lars': optax.lars(1.0),
50 | 'lion': optax.lion(1.0),
51 | 'nadam': optax.nadam(1.0),
52 | 'nadamw': optax.nadamw(1.0),
53 | 'noisy_sgd': optax.noisy_sgd(1.0),
54 | 'novograd': optax.novograd(1.0),
55 | 'optimistic_gradient_descent': optax.optimistic_gradient_descent(1.0),
56 | 'radam': optax.radam(1.0),
57 | 'sm3': optax.sm3(1.0),
58 | 'yogi': optax.yogi(1.0),
59 | }
60 |
61 |
62 | class ShardingTest(parameterized.TestCase):
63 |
64 | @parameterized.named_parameters(OPTIMIZERS.items())
65 | def test_init_with_abstract_input(self, optimizer):
66 | params = jax.ShapeDtypeStruct(shape=(2, 4, 8), dtype=jnp.float32)
67 | state = optimizer.init(params)
68 | self.assertIsNotNone(state)
69 |
70 | @parameterized.named_parameters(OPTIMIZERS.items())
71 | def test_state_sharding_type_init_match_update(self, optimizer):
72 | if jax.__version__ < '0.7.2':
73 | self.skipTest('Skipping sharding-in-types test')
74 | mesh = jax.make_mesh(
75 | (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,)
76 | )
77 | sharding = jax.sharding.NamedSharding(mesh, jax.P(None, 'x'))
78 |
79 | with jax.set_mesh(mesh):
80 | params = jnp.zeros((2, 8, 4), dtype=jnp.float16, out_sharding=sharding)
81 |
82 | state0 = optimizer.init(params)
83 | _, state1 = optimizer.update(params, state0, params)
84 |
85 | type0 = jax.tree.map(jax.typeof, state0)
86 | type1 = jax.tree.map(jax.typeof, state1)
87 | test_utils.assert_trees_all_equal(type0, type1)
88 |
89 | @parameterized.named_parameters(OPTIMIZERS.items())
90 | def test_state_sharding_type_preserved_with_jit(self, optimizer):
91 | if jax.__version__ < '0.7.2':
92 | self.skipTest('Skipping sharding-in-types test')
93 | mesh = jax.make_mesh(
94 | (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,)
95 | )
96 | sharding = jax.sharding.NamedSharding(mesh, jax.P(None, 'x'))
97 |
98 | with jax.set_mesh(mesh):
99 | params = jnp.zeros((2, 8, 4), dtype=jnp.float16, out_sharding=sharding)
100 |
101 | state0 = optimizer.init(params)
102 | state1 = jax.jit(optimizer.init)(params)
103 | type0 = jax.tree.map(jax.typeof, state0)
104 | type1 = jax.tree.map(jax.typeof, state1)
105 | test_utils.assert_trees_all_equal(type0, type1)
106 |
107 | _, state2 = optimizer.update(params, state0, params)
108 | _, state3 = jax.jit(optimizer.update)(params, state0, params)
109 | type2 = jax.tree.map(jax.typeof, state2)
110 | type3 = jax.tree.map(jax.typeof, state3)
111 | test_utils.assert_trees_all_equal(type2, type3)
112 |
113 | @parameterized.named_parameters(
114 | ('replicated', jax.sharding.PartitionSpec()),
115 | ('sharded', jax.sharding.PartitionSpec('x'))
116 | )
117 | def test_microbatch_with_explicit_sharding(self, spec):
118 | if jax.__version__ < '0.8.1':
119 | self.skipTest('Skipping sharding-in-types test.')
120 | mesh = jax.make_mesh(
121 | (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,)
122 | )
123 | with jax.set_mesh(mesh):
124 | sharding = jax.sharding.NamedSharding(mesh, spec)
125 |
126 | fun = lambda x: (x + 1, jnp.sum(3 * x))
127 | data = jnp.arange(16, out_sharding=sharding)
128 |
129 | strategy = (
130 | microbatching.AccumulationType.CONCAT,
131 | microbatching.AccumulationType.SUM,
132 | )
133 | microbatched_fun = microbatching.microbatch(
134 | fun, argnums=0, microbatch_size=8, accumulator=strategy
135 | )
136 |
137 | actual = microbatched_fun(data)
138 | expected = fun(data)
139 |
140 | test_utils.assert_trees_all_equal(
141 | jax.tree.map(jax.typeof, actual), jax.tree.map(jax.typeof, expected)
142 | )
143 | test_utils.assert_trees_all_equal(actual, expected)
144 |
145 |
146 | if __name__ == '__main__':
147 | absltest.main()
148 |
--------------------------------------------------------------------------------
/optax/_src/numerics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities to ensure the implementation is safe wrt numerical issues.
16 |
17 | Note that complex numbers are also supported, see
18 | https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
19 | """
20 |
21 | from typing import Optional, Union
22 |
23 | import jax
24 | import jax.numpy as jnp
25 | import numpy as np
26 |
27 |
28 | # TODO(jscholz) Promote these functions to jax core lib?
29 |
30 |
31 | def abs_sq(x: jax.typing.ArrayLike) -> jax.Array:
32 | """Returns the squared absolute value of a (maybe complex) array.
33 |
34 | For real `x`, JAX generates the same HLO from this, `jnp.square(x)`, `x * x`,
35 | or `x**2`.
36 |
37 | Args:
38 | x: a (maybe complex) array.
39 |
40 | Returns:
41 | The squared absolute value of `x`.
42 | """
43 | if not isinstance(x, (np.ndarray, jnp.ndarray)):
44 | raise ValueError(f'`abs_sq` accepts only NDarrays, got: {x}.')
45 | return (x.conj() * x).real
46 |
47 |
48 | def safe_norm(
49 | x: jax.typing.ArrayLike,
50 | min_norm: jax.typing.ArrayLike,
51 | ord: Optional[Union[int, float, str]] = None, # pylint: disable=redefined-builtin
52 | axis: Union[None, tuple[int, ...], int] = None,
53 | keepdims: bool = False,
54 | ) -> jax.Array:
55 | """Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.
56 |
57 | The gradients of `jnp.maximum(jnp.linalg.norm(x), min_norm)` at 0.0 is `NaN`,
58 | because jax will evaluate both branches of the `jnp.maximum`. This function
59 | will instead return the correct gradient of 0.0 also in such setting.
60 |
61 | Args:
62 | x: jax array.
63 | min_norm: lower bound for the returned norm.
64 | ord: {non-zero int, inf, -inf, 'fro', 'nuc'}, optional. Order of the norm.
65 | inf means numpy's inf object. The default is None.
66 | axis: {None, int, 2-tuple of ints}, optional. If axis is an integer, it
67 | specifies the axis of x along which to compute the vector norms. If axis
68 | is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
69 | norms of these matrices are computed. If axis is None then either a vector
70 | norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The
71 | default is None.
72 | keepdims: bool, optional. If this is set to True, the axes which are normed
73 | over are left in the result as dimensions with size one. With this option
74 | the result will broadcast correctly against the original x.
75 |
76 | Returns:
77 | The safe norm of the input vector, accounting for correct gradient.
78 | """
79 | norm = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=True)
80 | x = jnp.where(norm <= min_norm, jnp.ones_like(x), x)
81 | norm = jnp.squeeze(norm, axis=axis) if not keepdims else norm
82 | masked_norm = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
83 | return jnp.where(norm <= min_norm, min_norm, masked_norm)
84 |
85 |
86 | def safe_root_mean_squares(
87 | x: jax.typing.ArrayLike, min_rms: jax.typing.ArrayLike) -> jax.Array:
88 | """Returns `maximum(sqrt(mean(abs_sq(x))), min_norm)` with correct grads.
89 |
90 | The gradients of `maximum(sqrt(mean(abs_sq(x))), min_norm)` at 0.0
91 | is `NaN`, because jax will evaluate both branches of the `jnp.maximum`. This
92 | function will instead return the correct gradient of 0.0 also in such setting.
93 |
94 | Args:
95 | x: jax array.
96 | min_rms: lower bound for the returned norm.
97 |
98 | Returns:
99 | The safe RMS of the input vector, accounting for correct gradient.
100 | """
101 | rms = jnp.sqrt(jnp.mean(abs_sq(x)))
102 | x = jnp.where(rms <= min_rms, jnp.ones_like(x), x)
103 | return jnp.where(rms <= min_rms, min_rms, jnp.sqrt(jnp.mean(abs_sq(x))))
104 |
105 |
106 | def safe_increment(count: jax.typing.ArrayLike) -> jax.typing.ArrayLike:
107 | """Increments counter by one while avoiding overflow.
108 |
109 | Denote ``max_val``, ``min_val`` as the maximum, minimum, possible values for
110 | the ``dtype`` of ``count``. Normally ``max_val + 1`` would overflow to
111 | ``min_val``. This functions ensures that when ``max_val`` is reached the
112 | counter stays at ``max_val``.
113 |
114 | Args:
115 | count: a counter to be incremented.
116 |
117 | Returns:
118 | A counter incremented by 1, or ``max_val`` if the maximum value is
119 | reached.
120 |
121 | Examples:
122 | >>> import jax.numpy as jnp
123 | >>> import optax
124 | >>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))
125 | Array(2, dtype=int32)
126 | >>> optax.safe_increment(jnp.asarray(2147483647, dtype=jnp.int32))
127 | Array(2147483647, dtype=int32)
128 |
129 | .. versionadded:: 0.2.4
130 | """
131 | count_dtype = jnp.asarray(count).dtype
132 | if jnp.issubdtype(count_dtype, jnp.integer):
133 | max_value = jnp.iinfo(count_dtype).max
134 | elif jnp.issubdtype(count_dtype, jnp.floating):
135 | max_value = jnp.finfo(count_dtype).max
136 | else:
137 | raise ValueError(
138 | f'Cannot safely increment count with dtype {count_dtype},'
139 | ' valid dtypes are subdtypes of "jnp.integer" or "jnp.floating".'
140 | )
141 | max_value = jnp.array(max_value, count_dtype)
142 | one = jnp.array(1, count_dtype)
143 | return jnp.where(count < max_value, count + one, max_value)
144 |
145 |
146 | safe_int32_increment = safe_increment
147 |
--------------------------------------------------------------------------------
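
A small sketch contrasting `safe_norm` with the naive `jnp.maximum(jnp.linalg.norm(x), min_norm)` formulation discussed in the docstring above: at the origin the naive version has a `NaN` gradient, while `safe_norm` yields zeros.

import jax
import jax.numpy as jnp
import optax

naive = lambda x: jnp.maximum(jnp.linalg.norm(x), 1e-3)

print(jax.grad(naive)(jnp.zeros(3)))                                # NaNs
print(jax.grad(lambda x: optax.safe_norm(x, 1e-3))(jnp.zeros(3)))   # zeros
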