├── .github └── workflows │ ├── docs.yml │ ├── publish_to_pypi.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── lbfgs_benchmark.py └── proximal_gradient_benchmark.py ├── docs ├── Makefile ├── README.md ├── _static │ └── css │ │ └── custom.css ├── api.rst ├── basics.rst ├── changelog.rst ├── conf.py ├── constrained.rst ├── developer.rst ├── fixed_point.rst ├── implicit_diff.rst ├── index.rst ├── line_search.rst ├── linear_system_solvers.rst ├── non_smooth.rst ├── nonlinear_least_squares.rst ├── notebooks │ ├── deep_learning │ │ ├── adversarial_training.ipynb │ │ ├── adversarial_training.md │ │ ├── resnet_flax.ipynb │ │ ├── resnet_flax.md │ │ ├── resnet_haiku.ipynb │ │ ├── resnet_haiku.md │ │ └── thumbnails │ │ │ ├── adversarial_training.png │ │ │ ├── resnet_flax.png │ │ │ └── resnet_haiku.png │ ├── distributed │ │ ├── custom_loop_pjit_example.ipynb │ │ ├── custom_loop_pjit_example.md │ │ ├── custom_loop_pmap_example.ipynb │ │ ├── custom_loop_pmap_example.md │ │ └── thumbnails │ │ │ ├── plot_custom_loop_pjit_example.png │ │ │ └── plot_custom_loop_pmap_example.png │ ├── implicit_diff │ │ ├── dataset_distillation.ipynb │ │ ├── dataset_distillation.md │ │ ├── maml.ipynb │ │ ├── maml.md │ │ └── thumbnails │ │ │ ├── maml.png │ │ │ └── plot_dataset_distillation.png │ ├── index.rst │ └── perturbed_optimizers │ │ ├── perturbed_optimizers.ipynb │ │ ├── perturbed_optimizers.md │ │ └── thumbnails │ │ └── perturbations.png ├── objective_and_loss.rst ├── perturbations.rst ├── quadratic_programming.rst ├── requirements.txt ├── root_finding.rst ├── stochastic.rst └── unconstrained.rst ├── examples ├── README.rst ├── constrained │ ├── README.rst │ ├── binary_kernel_svm_with_intercept.py │ ├── multiclass_linear_svm.py │ └── nmf.py ├── deep_learning │ ├── README.rst │ ├── distributed_flax_imagenet.py │ ├── haiku_vae.py │ └── plot_sgd_solvers.py ├── fixed_point │ ├── README.rst │ ├── deep_equilibrium_model.py │ ├── plot_anderson_accelerate_gd.py │ ├── plot_anderson_wrapper_cd.py │ └── plot_picard_ode.py ├── implicit_diff │ ├── README.rst │ ├── lasso_implicit_diff.py │ ├── ridge_reg_implicit_diff.py │ └── sparse_coding.py └── requirements.txt ├── jaxopt ├── __init__.py ├── _src │ ├── __init__.py │ ├── anderson.py │ ├── anderson_wrapper.py │ ├── armijo_sgd.py │ ├── backtracking_linesearch.py │ ├── base.py │ ├── bfgs.py │ ├── bisection.py │ ├── block_cd.py │ ├── broyden.py │ ├── cd_qp.py │ ├── cond.py │ ├── cvxpy_wrapper.py │ ├── eq_qp.py │ ├── fixed_point_iteration.py │ ├── gauss_newton.py │ ├── gradient_descent.py │ ├── hager_zhang_linesearch.py │ ├── implicit_diff.py │ ├── isotonic.py │ ├── iterative_refinement.py │ ├── lbfgs.py │ ├── lbfgsb.py │ ├── levenberg_marquardt.py │ ├── linear_operator.py │ ├── linear_solve.py │ ├── linesearch_util.py │ ├── loop.py │ ├── loss.py │ ├── mirror_descent.py │ ├── nonlinear_cg.py │ ├── objective.py │ ├── optax_wrapper.py │ ├── osqp.py │ ├── perturbations.py │ ├── polyak_sgd.py │ ├── projected_gradient.py │ ├── projection.py │ ├── prox.py │ ├── proximal_gradient.py │ ├── scipy_wrappers.py │ ├── test_util.py │ ├── tree_util.py │ └── zoom_linesearch.py ├── base.py ├── cond.py ├── implicit_diff.py ├── isotonic.py ├── linear_solve.py ├── loop.py ├── loss.py ├── objective.py ├── perturbations.py ├── projection.py ├── prox.py ├── tree_util.py └── version.py ├── pylintrc ├── requirements.txt ├── requirements_test.txt ├── setup.py └── tests ├── anderson_test.py ├── anderson_wrapper_test.py ├── 
armijo_sgd_test.py ├── backtracking_linesearch_test.py ├── base_test.py ├── bfgs_test.py ├── bisection_test.py ├── block_cd_test.py ├── broyden_test.py ├── cd_qp_test.py ├── common_test.py ├── cond_test.py ├── cvxpy_wrapper_test.py ├── eq_qp_test.py ├── fixed_point_iteration_test.py ├── gauss_newton_test.py ├── gradient_descent_test.py ├── hager_zhang_linesearch_test.py ├── implicit_diff_test.py ├── import_test.py ├── isotonic_test.py ├── iterative_refinement_test.py ├── lbfgs_test.py ├── lbfgsb_test.py ├── levenberg_marquardt_test.py ├── linear_operator_test.py ├── linear_solve_test.py ├── linesearch_common_test.py ├── loop_test.py ├── loss_test.py ├── mirror_descent_test.py ├── nonlinear_cg_test.py ├── optax_wrapper_test.py ├── osqp_test.py ├── perturbations_test.py ├── polyak_sgd_test.py ├── projected_gradient_test.py ├── projection_test.py ├── prox_test.py ├── proximal_gradient_test.py ├── scipy_wrappers_test.py ├── tree_util_test.py └── zoom_linesearch_test.py /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | 7 | jobs: 8 | build-and-deploy: 9 | name: "Build and deploy documentation" 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: "actions/checkout@v3" 13 | - name: Set up Python 3.11 14 | uses: "actions/setup-python@v4" 15 | with: 16 | python-version: 3.11 17 | cache: 'pip' 18 | - run: pip install -r docs/requirements.txt 19 | - name: Build documentation 20 | run: cd docs && make html 21 | - uses: cpina/github-action-push-to-another-repository@main 22 | env: 23 | SSH_DEPLOY_KEY: ${{ secrets.SSH_DEPLOY_KEY }} 24 | with: 25 | source-directory: 'docs/_build/html' 26 | destination-github-username: 'jaxopt' 27 | destination-repository-name: 'jaxopt.github.io' 28 | user-email: jaxopt@google.com 29 | target-branch: main 30 | target-directory: dev 31 | -------------------------------------------------------------------------------- /.github/workflows/publish_to_pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: __token__ 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | build-and-test: 11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 12 | runs-on: "${{ matrix.os }}" 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.10", "3.11", "3.12", "3.13"] 17 | os: [ubuntu-latest] 18 | 19 | steps: 20 | - 
uses: "actions/checkout@v3" 21 | - uses: "actions/setup-python@v4" 22 | with: 23 | python-version: "${{ matrix.python-version }}" 24 | cache: 'pip' 25 | - name: Install dependencies 26 | run: | 27 | set -xe 28 | pip install --upgrade pip setuptools wheel 29 | pip install -r requirements.txt 30 | pip install -r requirements_test.txt 31 | shell: bash 32 | - name: Build 33 | run: | 34 | set -xe 35 | python -VV 36 | python setup.py install 37 | shell: bash 38 | - name: Run tests 39 | timeout-minutes: 60 40 | run: | 41 | set -xe 42 | python -VV 43 | python -c "import jax; print('jax', jax.__version__)" 44 | python -c "import jaxlib; print('jaxlib', jaxlib.__version__)" 45 | pytest tests 46 | shell: bash 47 | 48 | 49 | build-and-test-docs: 50 | name: "Build documentation" 51 | runs-on: ubuntu-latest 52 | steps: 53 | - name: Cancel previous 54 | uses: styfle/cancel-workflow-action@0.11.0 55 | with: 56 | access_token: ${{ github.token }} 57 | if: ${{github.ref != 'refs/heads/main'}} 58 | - uses: "actions/checkout@v3" 59 | - name: Set up Python 3.11 60 | uses: "actions/setup-python@v4" 61 | with: 62 | python-version: 3.11 63 | cache: 'pip' 64 | - name: Install dependencies 65 | run: | 66 | set -xe 67 | pip install --upgrade pip setuptools wheel 68 | pip install -r docs/requirements.txt 69 | - name: Build documentation 70 | run: | 71 | set -xe 72 | python -VV 73 | cd docs && make clean && make html 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | *.egg-info 4 | *.whl 5 | build/ 6 | dist/ 7 | .ipynb_checkpoints 8 | .DS_Store 9 | .mypy_cache/ 10 | .pytype/ 11 | docs/_build 12 | docs/notebooks/.ipynb_checkpoints/ 13 | docs/_autosummary 14 | docs/modules/generated 15 | docs/auto_examples 16 | .idea 17 | .vscode 18 | venv/ 19 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Install the pre-commit hooks below with 2 | # 'pre-commit install' 3 | 4 | # Auto-update the version of the hooks with 5 | # 'pre-commit autoupdate' 6 | 7 | # Run the hooks on all files with 8 | # 'pre-commit run --all' 9 | 10 | 11 | repos: 12 | - repo: https://github.com/mwouts/jupytext 13 | rev: v1.15.1 14 | hooks: 15 | - id: jupytext 16 | args: [--sync] 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. 
Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JAXopt 2 | 3 | [**Status**](#status) 4 | | [**Installation**](#installation) 5 | | [**Documentation**](https://jaxopt.github.io) 6 | | [**Examples**](https://github.com/google/jaxopt/tree/main/examples) 7 | | [**Cite us**](#citeus) 8 | 9 | Hardware accelerated, batchable and differentiable optimizers in 10 | [JAX](https://github.com/google/jax). 11 | 12 | - **Hardware accelerated:** our implementations run on GPU and TPU, in addition 13 | to CPU. 14 | - **Batchable:** multiple instances of the same optimization problem can be 15 | automatically vectorized using JAX's vmap. 16 | - **Differentiable:** optimization problem solutions can be differentiated with 17 | respect to their inputs either implicitly or via autodiff of unrolled 18 | algorithm iterations. 19 | 20 | ## Status 21 | 22 | JAXopt is no longer maintained nor developed. Alternatives may be found on the 23 | JAX [website](https://docs.jax.dev/en/latest/). Some of its features (like 24 | losses, projections, lbfgs optimizer) have been ported into 25 | [optax](https://github.com/google-deepmind/optax). We are sincerely grateful for 26 | all the community contributions the project has garnered over the years. 27 | 28 | ## Installation 29 | 30 | To install the latest release of JAXopt, use the following command: 31 | 32 | ```bash 33 | $ pip install jaxopt 34 | ``` 35 | 36 | To install the **development** version, use the following command instead: 37 | 38 | ```bash 39 | $ pip install git+https://github.com/google/jaxopt 40 | ``` 41 | 42 | Alternatively, it can be installed from sources with the following command: 43 | 44 | ```bash 45 | $ python setup.py install 46 | ``` 47 | 48 | ## Cite us 49 | 50 | Our implicit differentiation framework is described in this 51 | [paper](https://arxiv.org/abs/2105.15183). To cite it: 52 | 53 | ``` 54 | @article{jaxopt_implicit_diff, 55 | title={Efficient and Modular Implicit Differentiation}, 56 | author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy 57 | and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian 58 | and Vert, Jean-Philippe}, 59 | journal={arXiv preprint arXiv:2105.15183}, 60 | year={2021} 61 | } 62 | ``` 63 | 64 | ## Disclaimer 65 | 66 | JAXopt was an open source project maintained by a dedicated team in Google 67 | Research. It is not an official Google product. 68 | 69 | -------------------------------------------------------------------------------- /benchmarks/lbfgs_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Benchmark LBFGS implementation.""" 16 | 17 | 18 | from absl import app 19 | from absl import flags 20 | 21 | from sklearn import datasets 22 | 23 | import jax.numpy as jnp 24 | import jaxopt 25 | 26 | import numpy as onp 27 | 28 | import matplotlib.pyplot as plt 29 | 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | flags.DEFINE_integer("maxiter", default=30, help="Max # of iterations.") 34 | flags.DEFINE_integer("n_samples", default=10000, help="Number of samples.") 35 | flags.DEFINE_integer("n_features", default=200, help="Number of features.") 36 | flags.DEFINE_string("task", "binary_logreg", "Task to benchmark.") 37 | 38 | 39 | def binary_logreg(linesearch): 40 | X, y = datasets.make_classification(n_samples=FLAGS.n_samples, 41 | n_features=FLAGS.n_features, 42 | n_classes=2, 43 | n_informative=3, 44 | random_state=0) 45 | data = (X, y) 46 | fun = jaxopt.objective.binary_logreg 47 | init = jnp.zeros(X.shape[1]) 48 | lbfgs = jaxopt.LBFGS(fun=fun, linesearch=linesearch) 49 | state = lbfgs.init_state(init, data=data) 50 | errors = onp.zeros(FLAGS.maxiter) 51 | params = init 52 | 53 | for it in range(FLAGS.maxiter): 54 | params, state = lbfgs.update(params, state, data=data) 55 | errors[it] = state.error 56 | 57 | return errors 58 | 59 | 60 | def multiclass_logreg(linesearch): 61 | X, y = datasets.make_classification(n_samples=FLAGS.n_samples, 62 | n_features=FLAGS.n_features, 63 | n_classes=5, 64 | n_informative=5, 65 | random_state=0) 66 | data = (X, y) 67 | fun = jaxopt.objective.multiclass_logreg 68 | init = jnp.zeros((X.shape[1], 5)) 69 | lbfgs = jaxopt.LBFGS(fun=fun, linesearch=linesearch) 70 | state = lbfgs.init_state(init, data=data) 71 | errors = onp.zeros(FLAGS.maxiter) 72 | params = init 73 | 74 | for it in range(FLAGS.maxiter): 75 | params, state = lbfgs.update(params, state, data=data) 76 | errors[it] = state.error 77 | 78 | return errors 79 | 80 | 81 | def run_binary_logreg(): 82 | errors_backtracking = binary_logreg("backtracking") 83 | errors_zoom = binary_logreg("zoom") 84 | 85 | plt.figure() 86 | plt.plot(jnp.arange(FLAGS.maxiter), errors_backtracking, label="backtracking") 87 | plt.plot(jnp.arange(FLAGS.maxiter), errors_zoom, label="zoom") 88 | plt.xlabel("Iterations") 89 | plt.ylabel("Gradient error") 90 | plt.yscale("log") 91 | plt.legend(loc="best") 92 | plt.show() 93 | 94 | 95 | def run_multiclass_logreg(): 96 | errors_backtracking = multiclass_logreg("backtracking") 97 | errors_zoom = multiclass_logreg("zoom") 98 | 99 | plt.figure() 100 | plt.plot(jnp.arange(FLAGS.maxiter), errors_backtracking, label="backtracking") 101 | plt.plot(jnp.arange(FLAGS.maxiter), errors_zoom, label="zoom") 102 | plt.xlabel("Iterations") 103 | plt.ylabel("Gradient error") 104 | plt.yscale("log") 105 | plt.legend(loc="best") 106 | plt.show() 107 | 108 | 109 | def main(argv): 110 | if len(argv) > 1: 111 | raise app.UsageError("Too many command-line arguments.") 112 | 113 | print("n_samples:", FLAGS.n_samples) 114 | print("n_features:", FLAGS.n_features) 115 | print("maxiter:", FLAGS.maxiter) 116 | print("task:", FLAGS.task) 117 | print() 118 | 119 | if FLAGS.task == "binary_logreg": 120 | run_binary_logreg() 121 | elif FLAGS.task == "multiclass_logreg": 122 | run_multiclass_logreg() 123 | else: 124 | raise ValueError("Invalid task name.") 125 | 126 | 127 | if __name__ == '__main__': 128 | app.run(main) 129 | 
-------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | clean: 23 | rm -rf $(BUILDDIR)/* 24 | rm -rf auto_examples/ 25 | rm -rf _autosummary/ 26 | 27 | html-noplot: 28 | $(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html 29 | @echo 30 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 31 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Where to find the docs 2 | 3 | The JAXopt documentation can be found here: 4 | https://jaxopt.github.io 5 | 6 | # How to build the docs 7 | 8 | 1. Install the requirements using `pip install -r docs/requirements.txt` 9 | 2. Make sure `pandoc` is installed 10 | 3. Run the make script `make html` or `make html-noplot` for building without running the examples. 11 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | 2 | div.topic { 3 | padding: 0.5rem; 4 | background-color: #eee; 5 | margin-bottom: 1rem; 6 | border-radius: 0.25rem; 7 | border: 1px solid #CCC; 8 | } 9 | 10 | div.topic p { 11 | margin-bottom: 0.25rem; 12 | } 13 | 14 | div.topic dd { 15 | margin-bottom: 0.25rem; 16 | } 17 | 18 | p.topic-title { 19 | font-weight: bold; 20 | margin-bottom: 0.5rem; 21 | } 22 | 23 | div.topic > ul.simple { 24 | margin-bottom: 0.25rem; 25 | } 26 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API at a glance 2 | =============== 3 | 4 | Optimization 5 | ------------ 6 | 7 | Unconstrained 8 | ~~~~~~~~~~~~~ 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | jaxopt.BFGS 14 | jaxopt.GradientDescent 15 | jaxopt.LBFGS 16 | jaxopt.ScipyMinimize 17 | jaxopt.NonlinearCG 18 | 19 | Constrained 20 | ~~~~~~~~~~~ 21 | 22 | .. autosummary:: 23 | :toctree: _autosummary 24 | 25 | jaxopt.LBFGSB 26 | jaxopt.MirrorDescent 27 | jaxopt.ProjectedGradient 28 | jaxopt.ScipyBoundedMinimize 29 | 30 | Quadratic programming 31 | ~~~~~~~~~~~~~~~~~~~~~ 32 | 33 | .. autosummary:: 34 | :toctree: _autosummary 35 | 36 | jaxopt.BoxCDQP 37 | jaxopt.BoxOSQP 38 | jaxopt.CvxpyQP 39 | jaxopt.EqualityConstrainedQP 40 | jaxopt.OSQP 41 | 42 | Non-smooth 43 | ~~~~~~~~~~ 44 | 45 | .. autosummary:: 46 | :toctree: _autosummary 47 | 48 | jaxopt.ProximalGradient 49 | jaxopt.BlockCoordinateDescent 50 | 51 | Stochastic 52 | ~~~~~~~~~~ 53 | 54 | .. 
autosummary:: 55 | :toctree: _autosummary 56 | 57 | jaxopt.ArmijoSGD 58 | jaxopt.OptaxSolver 59 | jaxopt.PolyakSGD 60 | 61 | Loss functions 62 | ~~~~~~~~~~~~~~ 63 | 64 | .. autosummary:: 65 | :toctree: _autosummary 66 | 67 | jaxopt.loss.binary_logistic_loss 68 | jaxopt.loss.binary_sparsemax_loss 69 | jaxopt.loss.binary_hinge_loss 70 | jaxopt.loss.binary_perceptron_loss 71 | jaxopt.loss.sparse_plus 72 | jaxopt.loss.sparse_sigmoid 73 | jaxopt.loss.huber_loss 74 | jaxopt.loss.multiclass_logistic_loss 75 | jaxopt.loss.multiclass_sparsemax_loss 76 | jaxopt.loss.multiclass_hinge_loss 77 | jaxopt.loss.multiclass_perceptron_loss 78 | 79 | Linear system solving 80 | --------------------- 81 | 82 | .. autosummary:: 83 | :toctree: _autosummary 84 | 85 | jaxopt.linear_solve.solve_lu 86 | jaxopt.linear_solve.solve_cholesky 87 | jaxopt.linear_solve.solve_cg 88 | jaxopt.linear_solve.solve_normal_cg 89 | jaxopt.linear_solve.solve_gmres 90 | jaxopt.linear_solve.solve_bicgstab 91 | jaxopt.IterativeRefinement 92 | 93 | Nonlinear least squares 94 | ----------------------- 95 | 96 | .. autosummary:: 97 | :toctree: _autosummary 98 | 99 | jaxopt.GaussNewton 100 | jaxopt.LevenbergMarquardt 101 | 102 | Root finding 103 | ------------ 104 | 105 | .. autosummary:: 106 | :toctree: _autosummary 107 | 108 | jaxopt.Bisection 109 | jaxopt.Broyden 110 | jaxopt.ScipyRootFinding 111 | 112 | Fixed point resolution 113 | ---------------------- 114 | 115 | .. autosummary:: 116 | :toctree: _autosummary 117 | 118 | jaxopt.FixedPointIteration 119 | jaxopt.AndersonAcceleration 120 | jaxopt.AndersonWrapper 121 | 122 | Implicit differentiation 123 | ------------------------ 124 | 125 | .. autosummary:: 126 | :toctree: _autosummary 127 | 128 | jaxopt.implicit_diff.custom_root 129 | jaxopt.implicit_diff.custom_fixed_point 130 | jaxopt.implicit_diff.root_jvp 131 | jaxopt.implicit_diff.root_vjp 132 | 133 | Line search 134 | ----------- 135 | 136 | .. autosummary:: 137 | :toctree: _autosummary 138 | 139 | jaxopt.BacktrackingLineSearch 140 | jaxopt.HagerZhangLineSearch 141 | 142 | 143 | Perturbed optimizers 144 | -------------------- 145 | 146 | .. autosummary:: 147 | :toctree: _autosummary 148 | 149 | jaxopt.perturbations.make_perturbed_argmax 150 | jaxopt.perturbations.make_perturbed_max 151 | jaxopt.perturbations.make_perturbed_fun 152 | jaxopt.perturbations.Gumbel 153 | jaxopt.perturbations.Normal 154 | 155 | 156 | 157 | Isotonic regression 158 | ------------------- 159 | 160 | .. autosummary:: 161 | :toctree: _autosummary 162 | 163 | 164 | jaxopt.isotonic.isotonic_l2_pav 165 | 166 | 167 | Tree utilities 168 | -------------- 169 | 170 | .. autosummary:: 171 | :toctree: _autosummary 172 | 173 | jaxopt.tree_util.tree_add 174 | jaxopt.tree_util.tree_sub 175 | jaxopt.tree_util.tree_mul 176 | jaxopt.tree_util.tree_div 177 | jaxopt.tree_util.tree_scalar_mul 178 | jaxopt.tree_util.tree_add_scalar_mul 179 | jaxopt.tree_util.tree_vdot 180 | jaxopt.tree_util.tree_sum 181 | jaxopt.tree_util.tree_l2_norm 182 | jaxopt.tree_util.tree_zeros_like 183 | 184 | -------------------------------------------------------------------------------- /docs/basics.rst: -------------------------------------------------------------------------------- 1 | Basics 2 | ====== 3 | 4 | This section describes useful concepts across all JAXopt. 5 | 6 | Pytrees 7 | ------- 8 | 9 | `Pytrees `_ are an essential 10 | concept in JAX and JAXopt. They can be thought as a generalization of vectors. 
11 | They are a way to structure parameters or weights using tuples and 12 | dictionaries. Many solvers in JAXopt have native support for pytrees. 13 | 14 | Double precision 15 | ---------------- 16 | 17 | JAX uses single (32-bit) floating precision by default. However, for some 18 | algorithms, this may not be enough. Double (64-bit) floating precision can be 19 | enabled by adding the following at the beginning of the file:: 20 | 21 | jax.config.update("jax_enable_x64", True) 22 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Configuration file for the Sphinx documentation builder. 16 | # 17 | # This file only contains a selection of the most common options. For a full 18 | # list see the documentation: 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 20 | 21 | # -- Path setup -------------------------------------------------------------- 22 | 23 | # If extensions (or modules to document with autodoc) are in another directory, 24 | # add these directories to sys.path here. If the directory is relative to the 25 | # documentation root, use os.path.abspath to make it absolute, like shown here. 26 | # 27 | import os 28 | import sys 29 | sys.path.insert(0, os.path.abspath('..')) 30 | 31 | from jaxopt.version import __version__ 32 | 33 | 34 | # -- Project information ----------------------------------------------------- 35 | 36 | project = 'JAXopt' 37 | copyright = '2021-2022, the JAXopt authors' 38 | author = 'JAXopt authors' 39 | 40 | # The full version, including alpha/beta/rc tags 41 | release = __version__ 42 | 43 | 44 | # -- General configuration --------------------------------------------------- 45 | 46 | # Add any Sphinx extension module names here, as strings. They can be 47 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 48 | # ones. 
49 | extensions = [ 50 | 'sphinx.ext.napoleon', # napoleon on top of autodoc: https://stackoverflow.com/a/66930447 might correct some warnings 51 | 'sphinx.ext.autodoc', 52 | 'sphinx.ext.autosummary', 53 | 'sphinx.ext.intersphinx', 54 | 'sphinx.ext.mathjax', 55 | 'sphinx.ext.viewcode', 56 | 'matplotlib.sphinxext.plot_directive', 57 | 'sphinx_autodoc_typehints', 58 | 'myst_nb', 59 | "sphinx_remove_toctrees", 60 | 'sphinx_rtd_theme', 61 | 'sphinx_gallery.gen_gallery', 62 | 'sphinx_copybutton', 63 | ] 64 | 65 | sphinx_gallery_conf = { 66 | 'examples_dirs': '../examples', # path to your example scripts 67 | 'gallery_dirs': 'auto_examples', # path to where to save gallery generated output 68 | 'ignore_pattern': r'_test\.py', # no gallery for test of examples 69 | "doc_module": "jaxopt", 70 | "backreferences_dir": os.path.join("modules", "generated"), 71 | } 72 | 73 | 74 | source_suffix = ['.rst', '.ipynb', '.md'] 75 | 76 | autosummary_generate = True 77 | autodoc_default_options = {"members": True, "inherited-members": True} 78 | 79 | master_doc = 'index' 80 | 81 | autodoc_typehints = 'description' 82 | 83 | # Add any paths that contain templates here, relative to this directory. 84 | templates_path = ['_templates'] 85 | 86 | # List of patterns, relative to source directory, that match files and 87 | # directories to ignore when looking for source files. 88 | # This pattern also affects html_static_path and html_extra_path. 89 | exclude_patterns = [ 90 | 'build/html', 91 | 'build/jupyter_execute', 92 | 'README.md', 93 | '_build', 94 | '**.ipynb_checkpoints', 95 | # Ignore markdown source for notebooks; myst-nb builds from the ipynb 96 | 'notebooks/deep_learning/*.md', 97 | 'notebooks/distributed/*.md', 98 | 'notebooks/implicit_diff/*.md'] 99 | 100 | 101 | # -- Options for HTML output ------------------------------------------------- 102 | 103 | # The theme to use for HTML and HTML Help pages. See the documentation for 104 | # a list of builtin themes. 105 | # 106 | 107 | html_theme = 'sphinx_rtd_theme' 108 | html_logo = '' 109 | html_favicon = '' 110 | 111 | # Add any paths that contain custom static files (such as style sheets) here, 112 | # relative to this directory. They are copied after the builtin static files, 113 | # so a file named "default.css" will overwrite the builtin "default.css". 114 | html_static_path = ['_static'] 115 | # These paths are either relative to html_static_path 116 | # or fully qualified paths (eg. https://...) 117 | html_css_files = [ 118 | 'css/custom.css', 119 | ] 120 | html_context = { 121 | "display_github": True, # Integrate GitHub 122 | "github_user": "google", # Username 123 | "github_repo": "jaxopt", # Repo name 124 | "github_version": "main", # Version 125 | "conf_py_path": "/docs/", # Path in the checkout to the docs root 126 | } 127 | 128 | 129 | # -- Options for myst ---------------------------------------------- 130 | nb_execution_mode = "force" 131 | nb_execution_allow_errors = False 132 | nb_execution_fail_on_error = True # Requires https://github.com/executablebooks/MyST-NB/pull/296 133 | myst_enable_extensions = ['dollarmath'] # To display maths in notebook 134 | 135 | # Notebook cell execution timeout; defaults to 30. 136 | nb_execution_timeout = 100 137 | 138 | # List of patterns, relative to source directory, that match notebook 139 | # files that will not be executed. 
140 | nb_execution_excludepatterns = [ 141 | # Slow notebook 142 | 'notebooks/deep_learning/*.*', 143 | 'notebooks/distributed/*.*', 144 | 'notebooks/implicit_diff/dataset_distillation.*', 145 | 'notebooks/implicit_diff/maml.*', 146 | ] 147 | -------------------------------------------------------------------------------- /docs/constrained.rst: -------------------------------------------------------------------------------- 1 | .. _constrained_optim: 2 | 3 | Constrained optimization 4 | ======================== 5 | 6 | This section is concerned with problems of the form 7 | 8 | .. math:: 9 | 10 | \min_{x} f(x, \theta) \textrm{ subject to } x \in \mathcal{C}(\upsilon), 11 | 12 | where :math:`f(x, \theta)` is differentiable (almost everywhere), :math:`x` are 13 | the parameters with respect to which the function is minimized, :math:`\theta` 14 | are optional additional arguments, :math:`\mathcal{C}(\upsilon)` is a convex 15 | set and :math:`\upsilon` are parameter the convex set may depend on. 16 | 17 | Projected gradient 18 | ------------------ 19 | 20 | .. autosummary:: 21 | :toctree: _autosummary 22 | 23 | jaxopt.ProjectedGradient 24 | 25 | Instantiating and running the solver 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | To solve constrained optimization problems, we can use projected gradient 29 | descent, which is gradient descent with an additional projection onto the 30 | constraint set. Constraints are specified by setting the ``projection`` 31 | argument. For instance, non-negativity constraints can be specified using 32 | :func:`projection_non_negative `:: 33 | 34 | from jaxopt import ProjectedGradient 35 | from jaxopt.projection import projection_non_negative 36 | 37 | pg = ProjectedGradient(fun=fun, projection=projection_non_negative) 38 | pg_sol = pg.run(w_init, data=(X, y)).params 39 | 40 | Numerous projections are available, see below. 41 | 42 | Specifying projection parameters 43 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 44 | 45 | Some projections have a hyperparameter that can be specified. For 46 | instance, the hyperparameter of :func:`projection_l2_ball 47 | ` is the radius of the :math:`L_2` ball. 48 | This can be passed using the ``hyperparams_proj`` argument of ``run``:: 49 | 50 | from jaxopt.projection import projection_l2_ball 51 | 52 | radius = 1.0 53 | pg = ProjectedGradient(fun=fun, projection=projection_l2_ball) 54 | pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params 55 | 56 | .. topic:: Examples 57 | 58 | * :ref:`sphx_glr_auto_examples_constrained_binary_kernel_svm_with_intercept.py` 59 | 60 | Differentiation 61 | ~~~~~~~~~~~~~~~ 62 | 63 | In some applications, it is useful to differentiate the solution of the solver 64 | with respect to some hyperparameters. Continuing the previous example, we can 65 | now differentiate the solution w.r.t. ``radius``:: 66 | 67 | def solution(radius): 68 | pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=True) 69 | return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params 70 | 71 | print(jax.jacobian(solution)(radius)) 72 | 73 | Under the hood, we use the implicit function theorem if ``implicit_diff=True`` 74 | and autodiff of unrolled iterations if ``implicit_diff=False``. See the 75 | :ref:`implicit differentiation ` section for more details. 76 | 77 | Projections 78 | ~~~~~~~~~~~ 79 | 80 | The Euclidean projection onto :math:`\mathcal{C}(\upsilon)` is: 81 | 82 | .. 
math:: 83 | 84 | \text{proj}_{\mathcal{C}}(x', \upsilon) := 85 | \underset{x}{\text{argmin}} ~ ||x' - x||^2 \textrm{ subject to } x \in \mathcal{C}(\upsilon). 86 | 87 | The following operators are available. 88 | 89 | .. autosummary:: 90 | :toctree: _autosummary 91 | 92 | jaxopt.projection.projection_non_negative 93 | jaxopt.projection.projection_box 94 | jaxopt.projection.projection_simplex 95 | jaxopt.projection.projection_sparse_simplex 96 | jaxopt.projection.projection_l1_sphere 97 | jaxopt.projection.projection_l1_ball 98 | jaxopt.projection.projection_l2_sphere 99 | jaxopt.projection.projection_l2_ball 100 | jaxopt.projection.projection_linf_ball 101 | jaxopt.projection.projection_hyperplane 102 | jaxopt.projection.projection_halfspace 103 | jaxopt.projection.projection_affine_set 104 | jaxopt.projection.projection_polyhedron 105 | jaxopt.projection.projection_box_section 106 | jaxopt.projection.projection_transport 107 | jaxopt.projection.projection_birkhoff 108 | 109 | Projections always have two arguments: the input to be projected and the 110 | parameters of the convex set. 111 | 112 | Mirror descent 113 | -------------- 114 | 115 | .. autosummary:: 116 | :toctree: _autosummary 117 | 118 | jaxopt.MirrorDescent 119 | 120 | Kullback-Leibler projections 121 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 122 | 123 | The Kullback-Leibler projection onto :math:`\mathcal{C}(\upsilon)` is: 124 | 125 | .. math:: 126 | 127 | \text{proj}_{\mathcal{C}}(x', \upsilon) := 128 | \underset{x}{\text{argmin}} ~ \text{KL}(x, \exp(x')) \textrm{ subject to } x \in \mathcal{C}(\upsilon). 129 | 130 | The following operators are available. 131 | 132 | .. autosummary:: 133 | :toctree: _autosummary 134 | 135 | jaxopt.projection.kl_projection_transport 136 | jaxopt.projection.kl_projection_birkhoff 137 | 138 | Box constraints 139 | --------------- 140 | 141 | For optimization with box constraints, in addition to projected gradient 142 | descent, we can use our SciPy wrapper. 143 | 144 | 145 | .. autosummary:: 146 | :toctree: _autosummary 147 | 148 | jaxopt.ScipyBoundedMinimize 149 | jaxopt.LBFGSB 150 | 151 | This example shows how to apply non-negativity constraints, which can 152 | be achieved by setting box constraints :math:`[0, \infty)`:: 153 | 154 | from jaxopt import ScipyBoundedMinimize 155 | 156 | w_init = jnp.zeros(n_features) 157 | lbfgsb = ScipyBoundedMinimize(fun=fun, method="l-bfgs-b") 158 | lower_bounds = jnp.zeros_like(w_init) 159 | upper_bounds = jnp.ones_like(w_init) * jnp.inf 160 | bounds = (lower_bounds, upper_bounds) 161 | lbfgsb_sol = lbfgsb.run(w_init, bounds=bounds, data=(X, y)).params 162 | -------------------------------------------------------------------------------- /docs/developer.rst: -------------------------------------------------------------------------------- 1 | 2 | Development 3 | =========== 4 | 5 | Documentation 6 | ------------- 7 | 8 | To rebuild the documentation, install several packages:: 9 | 10 | pip install -r docs/requirements.txt 11 | 12 | And then run from the ``docs`` directory:: 13 | 14 | make html 15 | 16 | This can take a long time because it executes many of the examples; 17 | if you'd prefer to build the docs without executing the notebooks, you can run:: 18 | 19 | make html-noplot 20 | 21 | You can then see the generated documentation in ``docs/_build/html/index.html``. 
22 | 23 | 24 | 25 | Update notebooks 26 | ++++++++++++++++ 27 | 28 | We use `jupytext `_ to maintain two synced copies of the notebooks 29 | in ``docs/notebooks``: one in ``ipynb`` format, and one in ``md`` format. The advantage of the former 30 | is that it can be opened and executed directly in Colab; the advantage of the latter is that 31 | it makes it much easier to track diffs within version control. 32 | 33 | Editing ipynb 34 | +++++++++++++ 35 | 36 | For making large changes that substantially modify code and outputs, it is easiest to 37 | edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface, 38 | open http://colab.research.google.com and ``Upload`` from your local repo. 39 | Update it as needed, ``Run all cells`` then ``Download ipynb``. 40 | You may want to test that it executes properly, using ``sphinx-build`` as explained above. 41 | 42 | Editing md 43 | ++++++++++ 44 | 45 | For making smaller changes to the text content of the notebooks, it is easiest to edit the 46 | ``.md`` versions using a text editor. 47 | 48 | Syncing notebooks 49 | +++++++++++++++++ 50 | 51 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions 52 | using `jupytext `_. For example, to sync the files inside the ``docs/notebooks/deep_learning/``, run the command:: 53 | 54 | jupytext --sync docs/notebooks/deep_learning/*.* 55 | 56 | 57 | Be sure to use the version of jupytext specified in 58 | `.pre-commit-config.yaml `_. 59 | 60 | Alternatively, you can use the `pre-commit `_ framework to run this 61 | on all staged files in your git repository, automatically using the correct jupytext version:: 62 | 63 | pre-commit run jupytext 64 | 65 | See the pre-commit framework documentation for information on how to set your local git 66 | environment to execute this automatically. 67 | 68 | Creating new notebooks 69 | ++++++++++++++++++++++ 70 | 71 | If you are adding a new notebook to the documentation and would like to use the ``jupytext --sync`` 72 | command discussed here, you can set up your notebook for jupytext by using the following command:: 73 | 74 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb 75 | 76 | 77 | This works by adding a ``"jupytext"`` metadata field to the notebook file which specifies the 78 | desired formats, and which the ``jupytext --sync`` command recognizes when invoked. 79 | 80 | Notebooks within the sphinx build 81 | +++++++++++++++++++++++++++++++++ 82 | 83 | We exclude some notebooks from the build, e.g., because they contain long computations. 84 | See ``exclude_patterns`` in `conf.py `_. 85 | -------------------------------------------------------------------------------- /docs/implicit_diff.rst: -------------------------------------------------------------------------------- 1 | .. _implicit_diff: 2 | 3 | Implicit differentiation 4 | ======================== 5 | 6 | Argmin differentiation 7 | ---------------------- 8 | 9 | Argmin differentiation is the task of differentiating a minimization problem's 10 | solution with respect to its inputs. Namely, given 11 | 12 | .. math:: 13 | 14 | x^\star(\theta) := \underset{x}{\text{argmin}} f(x, \theta), 15 | 16 | we would like to compute the Jacobian :math:`\partial x^\star(\theta)`. This 17 | is usually done either by implicit differentiation or by autodiff through an 18 | algorithm's unrolled iterates. 19 | 20 | 21 | JAXopt solvers 22 | -------------- 23 | 24 | All solvers in JAXopt support implicit differentiation **out-of-the-box**. 
25 | Most solvers have an ``implicit_diff=True|False`` option. When set to ``False``, 26 | autodiff of unrolled iterates is used instead of implicit differentiation. 27 | 28 | Using the ridge regression example from the :ref:`unconstrained optimization 29 | ` section, we can write:: 30 | 31 | def ridge_reg_objective(params, l2reg, X, y): 32 | residuals = jnp.dot(X, params) - y 33 | return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2) 34 | 35 | def ridge_reg_solution(l2reg, X, y): 36 | gd = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500, implicit_diff=True) 37 | return gd.run(init_params, l2reg=l2reg, X=X, y=y).params 38 | 39 | Now, ``ridge_reg_solution`` is differentiable just like any other JAX function. 40 | Since ``ridge_reg_solution`` outputs a vector, we can compute its Jacobian:: 41 | 42 | print(jax.jacobian(ridge_reg_solution, argnums=0)(l2reg, X, y)) 43 | 44 | where ``argnums=0`` specifies that we want to differentiate with respect to ``l2reg``. 45 | 46 | We can also compose ``ridge_reg_solution`` with other functions:: 47 | 48 | def validation_loss(l2reg): 49 | sol = ridge_reg_solution(l2reg, X_train, y_train) 50 | residuals = jnp.dot(X_val, sol) - y_val 51 | return jnp.mean(residuals ** 2) 52 | 53 | print(jax.grad(validation_loss)(l2reg)) 54 | 55 | .. topic:: Examples 56 | 57 | * :doc:`/notebooks/implicit_diff/dataset_distillation` 58 | * :doc:`/notebooks/implicit_diff/maml` 59 | * :ref:`sphx_glr_auto_examples_implicit_diff_lasso_implicit_diff.py` 60 | * :ref:`sphx_glr_auto_examples_implicit_diff_sparse_coding.py` 61 | 62 | Custom solvers 63 | -------------- 64 | 65 | .. autosummary:: 66 | :toctree: _autosummary 67 | 68 | jaxopt.implicit_diff.custom_root 69 | jaxopt.implicit_diff.custom_fixed_point 70 | 71 | JAXopt also provides the ``custom_root`` and ``custom_fixed_point`` decorators, 72 | for easily adding implicit differentiation on top of any existing solver. 73 | 74 | .. topic:: Examples 75 | 76 | * :ref:`sphx_glr_auto_examples_implicit_diff_ridge_reg_implicit_diff.py` 77 | 78 | JVPs and VJPs 79 | ------------- 80 | 81 | Finally, we also provide lower-level routines for computing the JVPs and VJPs 82 | of roots of functions. 83 | 84 | .. autosummary:: 85 | :toctree: _autosummary 86 | 87 | jaxopt.implicit_diff.root_jvp 88 | jaxopt.implicit_diff.root_vjp 89 | 90 | .. topic:: References: 91 | 92 | * `Efficient and Modular Implicit Differentiation 93 | `_, 94 | Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert. 95 | ArXiv preprint. 96 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/google/jaxopt/tree/master/docs 2 | 3 | JAXopt 4 | ====== 5 | 6 | Hardware accelerated, batchable and differentiable optimizers in 7 | `JAX `_. 8 | 9 | - **Hardware accelerated:** our implementations run on GPU and TPU, in addition 10 | to CPU. 11 | - **Batchable:** multiple instances of the same optimization problem can be 12 | automatically vectorized using JAX's vmap. 13 | - **Differentiable:** optimization problem solutions can be differentiated with 14 | respect to their inputs either implicitly or via autodiff of unrolled 15 | algorithm iterations. 16 | 17 | 18 | Status 19 | ------ 20 | 21 | JAXopt is no longer maintained nor developed. Alternatives may be found on the 22 | JAX `website `_. 
Some of its features (like 23 | losses, projections, lbfgs optimizer) have been ported into `optax 24 | `_. We are sincerely grateful for all 25 | the community contributions the project has garnered over the years. 26 | 27 | 28 | Installation 29 | ------------ 30 | 31 | To install the latest release of JAXopt, use the following command:: 32 | 33 | pip install jaxopt 34 | 35 | To install the **development** version, use the following command instead:: 36 | 37 | pip install git+https://github.com/google/jaxopt 38 | 39 | Alternatively, it can be be installed from sources with the following command:: 40 | 41 | python setup.py install 42 | 43 | .. toctree:: 44 | :maxdepth: 1 45 | :caption: Documentation 46 | 47 | basics 48 | unconstrained 49 | constrained 50 | quadratic_programming 51 | non_smooth 52 | stochastic 53 | root_finding 54 | fixed_point 55 | nonlinear_least_squares 56 | linear_system_solvers 57 | implicit_diff 58 | objective_and_loss 59 | line_search 60 | perturbations 61 | 62 | .. toctree:: 63 | :maxdepth: 1 64 | :caption: API 65 | 66 | api 67 | 68 | .. toctree:: 69 | :maxdepth: 2 70 | :caption: Examples 71 | 72 | notebooks/index 73 | auto_examples/index 74 | 75 | .. toctree:: 76 | :maxdepth: 1 77 | :caption: About 78 | 79 | Authors 80 | changelog 81 | Source code 82 | Issue tracker 83 | developer 84 | 85 | Support 86 | ------- 87 | 88 | If you are having issues, please let us know by filing an issue on our 89 | `issue tracker `_. 90 | 91 | License 92 | ------- 93 | 94 | JAXopt is licensed under the Apache 2.0 License. 95 | 96 | 97 | Citing 98 | ------ 99 | 100 | If this software is useful for you, please consider citing 101 | `the paper `_ that describes 102 | its implicit differentiation framework: 103 | 104 | .. code-block:: bibtex 105 | 106 | @article{jaxopt_implicit_diff, 107 | title={Efficient and Modular Implicit Differentiation}, 108 | author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy 109 | and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian 110 | and Vert, Jean-Philippe}, 111 | journal={arXiv preprint arXiv:2105.15183}, 112 | year={2021} 113 | } 114 | 115 | 116 | Indices and tables 117 | ================== 118 | 119 | * :ref:`genindex` 120 | -------------------------------------------------------------------------------- /docs/line_search.rst: -------------------------------------------------------------------------------- 1 | Line search 2 | =========== 3 | 4 | Given current parameters :math:`x_k` and a descent direction :math:`p_k`, 5 | the goal of a line search method is to find a step size :math:`\alpha_k` 6 | such that the one-dimensional function 7 | 8 | .. math:: 9 | 10 | \varphi(\alpha_k) \triangleq f(x_k + \alpha_k p_k) 11 | 12 | is minimized or at least a sufficient decrease is guaranteed. 13 | 14 | Sufficient decrease and curvature conditions 15 | -------------------------------------------- 16 | 17 | Exactly minimizing :math:`\varphi` is often computationally costly. 18 | Instead, it is often preferred to search for :math:`\alpha_k` satisfying certain conditions. 19 | One example of these conditions are the **Wolfe conditions** 20 | 21 | .. math:: 22 | 23 | \begin{aligned} 24 | f(x_k + \alpha_k p_k) &\le f(x_k) + c_1 \alpha_k \nabla f(x_k)^\top p_k \\ 25 | \nabla f(x_k + \alpha_k p_k)^\top p_k &\ge c_2 \nabla f(x_k)^\top p_k 26 | \end{aligned} 27 | 28 | where :math:`0 < c_1 < c_2 < 1`. These conditions are explained in greater detail in 29 | Nocedal and Wright, see equations (3.6a) and (3.6b) there. 
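As a purely illustrative aside (not part of the JAXopt API), the two conditions are cheap to check numerically once the objective and its gradient are available. A minimal sketch, assuming a differentiable scalar-valued ``fun`` and the customary constants :math:`c_1 = 10^{-4}` and :math:`c_2 = 0.9`::

    import jax
    import jax.numpy as jnp

    def satisfies_wolfe(fun, x, p, alpha, c1=1e-4, c2=0.9):
      # phi(alpha) = fun(x + alpha * p)
      f0, g0 = jax.value_and_grad(fun)(x)
      f1, g1 = jax.value_and_grad(fun)(x + alpha * p)
      slope0 = jnp.vdot(g0, p)  # directional derivative phi'(0), negative for a descent direction
      sufficient_decrease = f1 <= f0 + c1 * alpha * slope0
      curvature = jnp.vdot(g1, p) >= c2 * slope0
      return sufficient_decrease & curvature

In practice such checks are carried out internally by the line search algorithms presented below.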
30 | 31 | A step size may satisfy the Wolfe conditions without being particularly close 32 | to a minimizer of :math:`\varphi` (Nocedal and Wright, Figure 3.5). The 33 | curvature condition in the second equation can be modified to force the step 34 | size to lie in at least a broad neighborhood of a stationary point of 35 | :math:`\varphi`. Combined with the sufficient decrease condition in the first 36 | equation, these are known as the **strong Wolfe conditions** 37 | 38 | .. math:: 39 | 40 | \begin{aligned} 41 | f(x_k + \alpha_k p_k) &\le f(x_k) + c_1 \alpha_k \nabla f(x_k)^\top p_k \\ 42 | |\nabla f(x_k + \alpha_k p_k)^\top p_k| &\le c_2 |\nabla f(x_k)^\top p_k| 43 | \end{aligned} 44 | 45 | where again :math:`0 < c_1 < c_2 < 1`. See Nocedal and Wright, equations (3.7a) and (3.7b). 46 | 47 | Algorithms 48 | ---------- 49 | 50 | .. autosummary:: 51 | :toctree: _autosummary 52 | 53 | jaxopt.BacktrackingLineSearch 54 | jaxopt.HagerZhangLineSearch 55 | 56 | The :class:`BacktrackingLineSearch ` algorithm 57 | iteratively reduces the step size by some decrease factor until the conditions 58 | above are satisfied. Example:: 59 | 60 | ls = BacktrackingLineSearch(fun=fun, maxiter=20, condition="strong-wolfe", 61 | decrease_factor=0.8) 62 | stepsize, state = ls.run(init_stepsize=1.0, params=params, 63 | descent_direction=descent_direction, 64 | value=value, grad=grad) 65 | 66 | where 67 | 68 | * ``init_stepsize`` is the first step size value to try, 69 | * ``params`` are the current parameters :math:`x_k`, 70 | * ``descent_direction`` is the provided descent direction :math:`p_k` (optional, defaults to :math:`-\nabla f(x_k)`), 71 | * ``value`` is the current value :math:`f(x_k)` (optional, recomputed if not provided), 72 | * ``grad`` is the current gradient :math:`\nabla f(x_k)` (optional, recomputed if not provided), 73 | 74 | The returned ``state`` contains useful information such as ``state.params``, 75 | which contains :math:`x_k + \alpha_k p_k` and ``state.grad``, which contains 76 | :math:`\nabla f(x_k + \alpha_k p_k)`. 77 | 78 | .. topic:: References: 79 | 80 | * Numerical Optimization, Jorge Nocedal and Stephen Wright, Second edition. 81 | -------------------------------------------------------------------------------- /docs/linear_system_solvers.rst: -------------------------------------------------------------------------------- 1 | Linear system solving 2 | ===================== 3 | 4 | This section is concerned with solving problems of the form 5 | 6 | .. math:: 7 | 8 | Ax = b 9 | 10 | with unknown :math:`x` for a linear operator :math:`A` and vector :math:`b`. 11 | 12 | Indirect solvers 13 | ---------------- 14 | 15 | .. autosummary:: 16 | :toctree: _autosummary 17 | 18 | jaxopt.linear_solve.solve_cg 19 | jaxopt.linear_solve.solve_normal_cg 20 | jaxopt.linear_solve.solve_gmres 21 | jaxopt.linear_solve.solve_bicgstab 22 | 23 | 24 | Indirect solvers iteratively solve the linear system up to some precision. 25 | Example:: 26 | 27 | from jaxopt import linear_solve 28 | import numpy as onp 29 | 30 | onp.random.seed(42) 31 | A = onp.random.randn(3, 3) 32 | b = onp.random.randn(3) 33 | 34 | def matvec_A(x): 35 | return jnp.dot(A, x) 36 | 37 | sol = linear_solve.solve_normal_cg(matvec_A, b, tol=1e-5) 38 | print(sol) 39 | 40 | sol = linear_solve.solve_gmres(matvec_A, b, tol=1e-5) 41 | print(sol) 42 | 43 | sol = linear_solve.solve_bicgstab(matvec_A, b, tol=1e-5) 44 | print(sol) 45 | 46 | The above solvers support ridge regularization with the ``ridge`` option. 
47 | They can be *warm-started* using the ``init`` option. 48 | Other options, such as ``tol`` or ``maxiter``, are also supported. 49 | 50 | Direct solvers 51 | -------------- 52 | 53 | .. autosummary:: 54 | :toctree: _autosummary 55 | 56 | jaxopt.linear_solve.solve_lu 57 | jaxopt.linear_solve.solve_cholesky 58 | 59 | 60 | Direct solvers are based on matrix decompositions. 61 | They need to materialize the matrix in memory. 62 | 63 | Example:: 64 | 65 | from jaxopt import linear_solve 66 | import numpy as onp 67 | 68 | onp.random.seed(42) 69 | A = onp.random.randn(3, 3) 70 | b = onp.random.randn(3) 71 | 72 | def matvec_A(x): 73 | return jnp.dot(A, x) 74 | 75 | sol = linear_solve.solve_lu(matvec_A, b) 76 | print(sol) 77 | 78 | 79 | Iterative refinement 80 | -------------------- 81 | 82 | .. autosummary:: 83 | :toctree: _autosummary 84 | 85 | jaxopt.IterativeRefinement 86 | 87 | `Iterative refinement `_ 88 | is a meta-algorithm for solving the linear system ``Ax = b`` based on 89 | a provided linear system solver. Our implementation is a slight generalization 90 | of the standard algorithm. It starts with :math:`(r_0, x_0) = (b, 0)` and 91 | iterates 92 | 93 | .. math:: 94 | 95 | \begin{aligned} 96 | x &= \text{solution of } \bar{A} x = r_{t-1}\\ 97 | x_t &= x_{t-1} + x\\ 98 | r_t &= b - A x_t 99 | \end{aligned} 100 | 101 | where :math:`\bar{A}` is some approximation of A, with preferably 102 | better preconditioning than A. By default, we use 103 | :math:`\bar{A} = A`, which is the standard iterative refinement algorithm. 104 | This method has the advantage of converging even if the solve step is 105 | inaccurate. This is particularly useful for ill-posed problems. 106 | Example:: 107 | 108 | from functools import partial 109 | import jax.numpy as jnp 110 | import numpy as onp 111 | from jaxopt import IterativeRefinement 112 | from jaxopt.linear_solve import solve_gmres 113 | 114 | # ill-conditioned linear system 115 | A = jnp.array([[3.9, 1.65], [6.845, 2.9]]) 116 | b = jnp.array([5.5, 9.7]) 117 | print(f"Condition number: {onp.linalg.cond(A):.0f}") 118 | # Condition number: 4647 119 | 120 | ridge = 1e-2 121 | tol = 1e-7 122 | 123 | x = solve_gmres(lambda x: jnp.dot(A, x), b, tol=tol) 124 | print(f"GMRES only error: {jnp.linalg.norm(A @ x - b):.7f}") 125 | # GMRES only error: nan 126 | 127 | solve_gmres_ridge = partial(solve_gmres, ridge=ridge) 128 | 129 | x_ridge = solve_gmres_ridge(lambda x: jnp.dot(A, x), b, tol=tol) 130 | print(f"GMRES+ridge error: {jnp.linalg.norm(A @ x_ridge - b):.7f}") 131 | # GMRES+ridge error: 0.0333328 132 | 133 | solver = IterativeRefinement(solve=solve_gmres_ridge, 134 | tol=tol, maxiter=100) 135 | x_refined, state = solver.run(init_params=None, params_A=A, b=b) 136 | print(f"Iterative Refinement error: {jnp.linalg.norm(A @ x_refined - b):.7f}") 137 | # Iterative Refinement error: 0.0000000 138 | -------------------------------------------------------------------------------- /docs/non_smooth.rst: -------------------------------------------------------------------------------- 1 | Non-smooth optimization 2 | ======================= 3 | 4 | This section is concerned with problems of the form 5 | 6 | .. 
math:: 7 | 8 | \min_{x} f(x, \theta) + g(x, \lambda) 9 | 10 | where :math:`f(x, \theta)` is differentiable (almost everywhere), 11 | :math:`x` are the parameters with respect to which the function is minimized, 12 | :math:`\theta` are optional extra arguments, 13 | :math:`g(x, \lambda)` is possibly non-smooth, 14 | and :math:`\lambda` are extra parameters :math:`g` may depend on. 15 | 16 | 17 | Proximal gradient 18 | ----------------- 19 | 20 | .. autosummary:: 21 | :toctree: _autosummary 22 | 23 | jaxopt.ProximalGradient 24 | 25 | Instantiating and running the solver 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | Proximal gradient is a generalization of :ref:`projected gradient descent 29 | `. The non-smooth term :math:`g` above is specified by 30 | setting the corresponding proximal operator, which is achieved using the 31 | ``prox`` attribute of :class:`ProximalGradient `. 32 | 33 | For instance, suppose we want to solve the following optimization problem 34 | 35 | .. math:: 36 | 37 | \min_{w} \frac{1}{2n} ||Xw - y||^2 + \text{l1reg} \cdot ||w||_1 38 | 39 | which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot ||w||_1`. The 40 | corresponding ``prox`` operator is :func:`prox_lasso `. 41 | We can therefore write:: 42 | 43 | from jaxopt import ProximalGradient 44 | from jaxopt.prox import prox_lasso 45 | 46 | def least_squares(w, data): 47 | X, y = data 48 | residuals = jnp.dot(X, w) - y 49 | return jnp.mean(residuals ** 2) 50 | 51 | l1reg = 1.0 52 | pg = ProximalGradient(fun=least_squares, prox=prox_lasso) 53 | pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params 54 | 55 | Note that :func:`prox_lasso ` has a hyperparameter 56 | ``l1reg``, which controls the :math:`L_1` regularization strength. As shown 57 | above, we can specify it in the ``run`` method using the ``hyperparams_prox`` 58 | argument. The remaining arguments are passed to the objective function, here 59 | ``least_squares``. 60 | 61 | Numerous proximal operators are available, see below. 62 | 63 | Differentiation 64 | ~~~~~~~~~~~~~~~ 65 | 66 | In some applications, it is useful to differentiate the solution of the solver 67 | with respect to some hyperparameters. Continuing the previous example, we can 68 | now differentiate the solution w.r.t. ``l1reg``:: 69 | 70 | def solution(l1reg): 71 | pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True) 72 | return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params 73 | 74 | print(jax.jacobian(solution)(l1reg)) 75 | 76 | Under the hood, we use the implicit function theorem if ``implicit_diff=True`` 77 | and autodiff of unrolled iterations if ``implicit_diff=False``. See the 78 | :ref:`implicit differentiation ` section for more details. 79 | 80 | .. topic:: Examples 81 | 82 | * :ref:`sphx_glr_auto_examples_implicit_diff_lasso_implicit_diff.py` 83 | * :ref:`sphx_glr_auto_examples_implicit_diff_sparse_coding.py` 84 | 85 | .. _block_coordinate_descent: 86 | 87 | Block coordinate descent 88 | ------------------------ 89 | 90 | .. autosummary:: 91 | :toctree: _autosummary 92 | 93 | jaxopt.BlockCoordinateDescent 94 | 95 | Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with 96 | :ref:`composite linear objective functions `. 
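The example below assumes a feature matrix ``X``, a target vector ``y`` and ``n_features`` are already in scope; a purely illustrative synthetic setup (made-up shapes and seed) could be::

    import jax
    import jax.numpy as jnp
    from jaxopt import BlockCoordinateDescent

    # Hypothetical random data, only to make the example below self-contained.
    key_X, key_y = jax.random.split(jax.random.PRNGKey(0))
    n_samples, n_features = 100, 10
    X = jax.random.normal(key_X, (n_samples, n_features))
    y = jax.random.normal(key_y, (n_samples,))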
97 | 98 | Example:: 99 | 100 | from jaxopt import objective 101 | from jaxopt import prox 102 | from jaxopt import BlockCoordinateDescent 103 | l1reg = 1.0 104 | w_init = jnp.zeros(n_features) 105 | bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso) 106 | lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params 107 | 108 | .. topic:: Examples 109 | 110 | * :ref:`sphx_glr_auto_examples_constrained_multiclass_linear_svm.py` 111 | * :ref:`sphx_glr_auto_examples_constrained_nmf.py` 112 | 113 | Proximal operators 114 | ------------------ 115 | 116 | Proximal gradient and block coordinate descent do not access :math:`g(x, \lambda)` 117 | directly but instead require its associated proximal operator. It is defined as: 118 | 119 | .. math:: 120 | 121 | \text{prox}_{g}(x', \lambda, \eta) := 122 | \underset{x}{\text{argmin}} ~ \frac{1}{2} ||x' - x||^2 + \eta g(x, \lambda). 123 | 124 | The following operators are available. 125 | 126 | .. autosummary:: 127 | :toctree: _autosummary 128 | 129 | jaxopt.prox.make_prox_from_projection 130 | jaxopt.prox.prox_none 131 | jaxopt.prox.prox_lasso 132 | jaxopt.prox.prox_non_negative_lasso 133 | jaxopt.prox.prox_elastic_net 134 | jaxopt.prox.prox_group_lasso 135 | jaxopt.prox.prox_ridge 136 | jaxopt.prox.prox_non_negative_ridge 137 | -------------------------------------------------------------------------------- /docs/nonlinear_least_squares.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _nonlinear_least_squares: 3 | 4 | Nonlinear least squares 5 | ======================= 6 | 7 | This section is concerned with problems of the form 8 | 9 | .. math:: 10 | 11 | \min_{x} \frac{1}{2} ||\textbf{r}(x, \theta)||^2, 12 | 13 | where :math:`\textbf{r}` is a residual function, :math:`x` are the 14 | parameters with respect to which the function is minimized, and :math:`\theta` 15 | are optional additional arguments. 16 | 17 | Gauss-Newton 18 | ------------ 19 | 20 | .. autosummary:: 21 | :toctree: _autosummary 22 | 23 | jaxopt.GaussNewton 24 | 25 | We can use the Gauss-Newton method, which is the standard approach for nonlinear least squares problems. 26 | 27 | Update equation 28 | ~~~~~~~~~~~~~~~ 29 | 30 | The following equation is solved for every iteration to find the update to the 31 | parameters: 32 | 33 | .. math:: 34 | \mathbf{J^T} \mathbf{J} h_{gn} = - \mathbf{J^T} \mathbf{r} 35 | 36 | where :math:`\mathbf{J}` is the Jacobian of the residual function w.r.t. 37 | parameters. 38 | 39 | Instantiating and running the solver 40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 41 | 42 | As an example, let us see how to minimize the Rosenbrock residual function:: 43 | 44 | from jaxopt import GaussNewton 45 | 46 | def rosenbrock(x): 47 | return jnp.array([10 * (x[1] - x[0]**2), (1 - x[0])]) 48 | 49 | gn = GaussNewton(residual_fun=rosenbrock) 50 | gn_sol = gn.run(x_init).params 51 | 52 | 53 | The residual function may take additional arguments, for example for fitting a double exponential:: 54 | 55 | def double_exponential(x, x_data, y_data): 56 | return y_data - (x[0] * jnp.exp(-x[2] * x_data) + x[1] * jnp.exp( 57 | -x[3] * x_data)) 58 | 59 | gn = GaussNewton(residual_fun=double_exponential) 60 | gn_sol = gn.run(x_init, x_data=x_data, y_data=y_data).params 61 | 62 | Differentiation 63 | ~~~~~~~~~~~~~~~ 64 | 65 | In some applications, it is useful to differentiate the solution of the solver 66 | with respect to some hyperparameters. Continuing the previous example, we can 67 | now differentiate the solution w.r.t.
``y``:: 68 | 69 | def solution(y): 70 | gn = GaussNewton(residual_fun=double_exponential) 71 | return gn.run(x_init, x_data, y).params 72 | 73 | print(jax.jacobian(solution)(y_data)) 74 | 75 | Under the hood, we use the implicit function theorem if ``implicit_diff=True`` 76 | and autodiff of unrolled iterations if ``implicit_diff=False``. See the 77 | :ref:`implicit differentiation ` section for more details. 78 | 79 | Levenberg-Marquardt 80 | ------------------- 81 | 82 | .. autosummary:: 83 | :toctree: _autosummary 84 | 85 | jaxopt.LevenbergMarquardt 86 | 87 | We can also use the Levenberg-Marquardt method, which is a more advanced method compared to Gauss-Newton, in 88 | that it regularizes the update equation. It helps in cases where the Gauss-Newton method fails to converge. 89 | 90 | Update equation 91 | ~~~~~~~~~~~~~~~ 92 | 93 | The following equation is solved for every iteration to find the update to the 94 | parameters: 95 | 96 | .. math:: 97 | (\mathbf{J^T} \mathbf{J} + \mu\mathbf{I}) h_{lm} = - \mathbf{J^T} \mathbf{r} 98 | 99 | where :math:`\mathbf{J}` is the Jacobian of the residual function w.r.t. 100 | parameters and :math:`\mu` is the damping parameter. 101 | -------------------------------------------------------------------------------- /docs/notebooks/deep_learning/thumbnails/adversarial_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/deep_learning/thumbnails/adversarial_training.png -------------------------------------------------------------------------------- /docs/notebooks/deep_learning/thumbnails/resnet_flax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/deep_learning/thumbnails/resnet_flax.png -------------------------------------------------------------------------------- /docs/notebooks/deep_learning/thumbnails/resnet_haiku.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/deep_learning/thumbnails/resnet_haiku.png -------------------------------------------------------------------------------- /docs/notebooks/distributed/thumbnails/plot_custom_loop_pjit_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/distributed/thumbnails/plot_custom_loop_pjit_example.png -------------------------------------------------------------------------------- /docs/notebooks/distributed/thumbnails/plot_custom_loop_pmap_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/distributed/thumbnails/plot_custom_loop_pmap_example.png -------------------------------------------------------------------------------- /docs/notebooks/implicit_diff/thumbnails/maml.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/implicit_diff/thumbnails/maml.png --------------------------------------------------------------------------------
/docs/notebooks/implicit_diff/thumbnails/plot_dataset_distillation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/implicit_diff/thumbnails/plot_dataset_distillation.png -------------------------------------------------------------------------------- /docs/notebooks/index.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _notebook_gallery: 3 | 4 | Notebook gallery 5 | ================ 6 | 7 | 8 | Deep learning 9 | ------------- 10 | 11 | 12 | .. raw:: html 13 | 14 |
15 | 16 |
17 | 18 | .. only:: html 19 | 20 | .. figure:: /notebooks/deep_learning/thumbnails/resnet_flax.png 21 | :alt: Resnet example with Flax and JAXopt. 22 | 23 | :doc:`/notebooks/deep_learning/resnet_flax` 24 | 25 | .. raw:: html 26 | 27 |
28 | 29 |
30 | 31 | .. only:: html 32 | 33 | .. figure:: /notebooks/deep_learning/thumbnails/resnet_haiku.png 34 | :alt: Resnet example with Haiku and JAXopt. 35 | 36 | :doc:`/notebooks/deep_learning/resnet_haiku` 37 | 38 | .. raw:: html 39 | 40 |
41 | 42 |
43 | 44 | .. only:: html 45 | 46 | .. figure:: /notebooks/deep_learning/thumbnails/adversarial_training.png 47 | :alt: Adversarial Training. 48 | 49 | :doc:`/notebooks/deep_learning/adversarial_training` 50 | 51 | .. raw:: html 52 | 53 |
54 | 55 |
56 | 57 | 58 | 59 | 60 | Implicit Differentiation 61 | ------------------------ 62 | 63 | 64 | .. raw:: html 65 | 66 |
67 | 68 |
69 | 70 | .. only:: html 71 | 72 | .. figure:: /notebooks/implicit_diff/thumbnails/plot_dataset_distillation.png 73 | :alt: Dataset distillation example with JAXopt. 74 | 75 | :doc:`/notebooks/implicit_diff/dataset_distillation` 76 | 77 | .. raw:: html 78 | 79 |
80 | 81 | 82 | .. raw:: html 83 | 84 |
85 | 86 | .. only:: html 87 | 88 | .. figure:: /notebooks/implicit_diff/thumbnails/maml.png 89 | :alt: Few-shot Adaptation with Model Agnostic Meta-Learning (MAML) 90 | 91 | :doc:`/notebooks/implicit_diff/maml` 92 | 93 | .. raw:: html 94 | 95 |
96 | 97 |
98 | 99 | 100 | 101 | Distributed Optimization 102 | ------------------------ 103 | 104 | 105 | .. raw:: html 106 | 107 |
108 | 109 |
110 | 111 | .. only:: html 112 | 113 | .. figure:: /notebooks/distributed/thumbnails/plot_custom_loop_pjit_example.png 114 | :alt: `jax.experimental.pjit` example using JAXopt. 115 | 116 | :doc:`/notebooks/distributed/custom_loop_pjit_example` 117 | 118 | .. raw:: html 119 | 120 |
121 | 122 | 123 | .. raw:: html 124 | 125 |
126 | 127 | .. only:: html 128 | 129 | .. figure:: /notebooks/distributed/thumbnails/plot_custom_loop_pmap_example.png 130 | :alt: `jax.pmap` example using JAXopt. 131 | 132 | :doc:`/notebooks/distributed/custom_loop_pmap_example` 133 | 134 | 135 | .. raw:: html 136 | 137 |
138 | 139 |
140 | 141 | 142 | Perturbed optimizers 143 | -------------------- 144 | 145 | 146 | .. raw:: html 147 | 148 |
149 | 150 |
151 | 152 | .. only:: html 153 | 154 | .. figure:: /notebooks/perturbed_optimizers/thumbnails/perturbations.png 155 | :alt: Perturbed optimizers with JAXopt. 156 | 157 | :doc:`/notebooks/perturbed_optimizers/perturbed_optimizers` 158 | 159 | .. raw:: html 160 | 161 |
162 | 163 |
164 | 165 | -------------------------------------------------------------------------------- /docs/notebooks/perturbed_optimizers/thumbnails/perturbations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/perturbed_optimizers/thumbnails/perturbations.png -------------------------------------------------------------------------------- /docs/objective_and_loss.rst: -------------------------------------------------------------------------------- 1 | Loss and objective functions 2 | ============================ 3 | 4 | Loss functions 5 | -------------- 6 | 7 | Regression 8 | ~~~~~~~~~~ 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | jaxopt.loss.huber_loss 14 | 15 | Regression losses are of the form ``loss(float: target, float: pred) -> float``, 16 | where ``target`` is the ground-truth and ``pred`` is the model's output. 17 | 18 | Binary classification 19 | ~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. autosummary:: 22 | :toctree: _autosummary 23 | 24 | jaxopt.loss.binary_logistic_loss 25 | jaxopt.loss.binary_sparsemax_loss 26 | jaxopt.loss.binary_hinge_loss 27 | jaxopt.loss.binary_perceptron_loss 28 | 29 | Binary classification losses are of the form ``loss(int: label, float: score) -> float``, 30 | where ``label`` is the ground-truth (``0`` or ``1``) and ``score`` is the model's output. 31 | 32 | The following utility functions are useful for the binary sparsemax loss. 33 | 34 | .. autosummary:: 35 | :toctree: _autosummary 36 | 37 | jaxopt.loss.sparse_plus 38 | jaxopt.loss.sparse_sigmoid 39 | 40 | Multiclass classification 41 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 42 | 43 | .. autosummary:: 44 | :toctree: _autosummary 45 | 46 | jaxopt.loss.multiclass_logistic_loss 47 | jaxopt.loss.multiclass_sparsemax_loss 48 | jaxopt.loss.multiclass_hinge_loss 49 | jaxopt.loss.multiclass_perceptron_loss 50 | 51 | Multiclass classification losses are of the form ``loss(int: label, jnp.ndarray: scores) -> float``, 52 | where ``label`` is the ground-truth (between ``0`` and ``n_classes - 1``) and 53 | ``scores`` is an array of size ``n_classes``. 54 | 55 | Applying loss functions on a batch 56 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 57 | 58 | All loss functions above are pointwise, meaning that they operate on a single sample. Use ``jax.vmap(loss)`` 59 | followed by a reduction such as ``jnp.mean`` or ``jnp.sum`` to use on a batch. 60 | 61 | Objective functions 62 | ------------------- 63 | 64 | .. _composite_linear_functions: 65 | 66 | Composite linear functions 67 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 68 | 69 | .. autosummary:: 70 | :toctree: _autosummary 71 | 72 | jaxopt.objective.least_squares 73 | jaxopt.objective.binary_logreg 74 | jaxopt.objective.multiclass_logreg 75 | jaxopt.objective.multiclass_linear_svm_dual 76 | 77 | Composite linear objective functions can be used with 78 | :ref:`block coordinate descent `. 79 | 80 | Other functions 81 | ~~~~~~~~~~~~~~~ 82 | 83 | .. 
autosummary:: 84 | :toctree: _autosummary 85 | 86 | jaxopt.objective.ridge_regression 87 | jaxopt.objective.multiclass_logreg_with_intercept 88 | jaxopt.objective.l2_multiclass_logreg 89 | jaxopt.objective.l2_multiclass_logreg_with_intercept 90 | -------------------------------------------------------------------------------- /docs/perturbations.rst: -------------------------------------------------------------------------------- 1 | Perturbed optimization 2 | ====================== 3 | 4 | The perturbed optimization module allows to transform a non-smooth function such as a max or arg-max into a differentiable function using random perturbations. This is useful for optimization algorithms that require differentiability, such as gradient descent (e.g. see :doc:`Notebook ` on perturbed optimizers). 5 | 6 | 7 | Max perturbations 8 | ----------------- 9 | 10 | Consider a maximum function of the form: 11 | 12 | .. math:: 13 | 14 | F(\theta) = \max_{y \in \mathcal{C}} \langle y, \theta\rangle\,, 15 | 16 | where :math:`\mathcal{C}` is a convex set. 17 | 18 | 19 | 20 | .. autosummary:: 21 | :toctree: _autosummary 22 | 23 | jaxopt.perturbations.make_perturbed_max 24 | 25 | 26 | 27 | 28 | The function :meth:`jaxopt.perturbations.make_perturbed_max` transforms the function :math:`F` into a the following differentiable function using random perturbations: 29 | 30 | 31 | .. math:: 32 | 33 | F_{\varepsilon}(\theta) = \mathbb{E}\left[ F(\theta + \varepsilon Z) \right]\,, 34 | 35 | where :math:`Z` is a random variable. The distribution of this random variable can be specified through the keyword argument ``noise``. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate. 36 | 37 | 38 | Argmax perturbations 39 | -------------------- 40 | 41 | Consider an arg-max function of the form: 42 | 43 | .. math:: 44 | 45 | y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta\rangle\,, 46 | 47 | where :math:`\mathcal{C}` is a convex set. 48 | 49 | 50 | The function :meth:`jaxopt.perturbations.make_perturbed_argmax` transforms the function :math:`y^\star` into a the following differentiable function using random perturbations: 51 | 52 | 53 | .. math:: 54 | 55 | y_{\varepsilon}^*(\theta) = \mathbb{E}\left[ \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle \right]\,, 56 | 57 | where :math:`Z` is a random variable. The distribution of this random variable can be specified through the keyword argument ``noise``. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate. 58 | 59 | 60 | .. autosummary:: 61 | :toctree: _autosummary 62 | 63 | jaxopt.perturbations.make_perturbed_argmax 64 | 65 | 66 | Scalar perturbations 67 | -------------------- 68 | 69 | Consider any function, :math:`f` that is not necessarily differentiable, e.g. piecewise-constant of the form: 70 | 71 | .. math:: 72 | 73 | f(\theta) = g(y^*(\theta))\,, 74 | 75 | where :math:`\mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta\rangle`` and :math:`\mathcal{C}` is a convex set. 76 | 77 | 78 | The function :meth:`jaxopt.perturbations.make_perturbed_fun` transforms the function :math:`f` into a the following differentiable function using random perturbations: 79 | 80 | .. 
math:: 81 | 82 | f_{\varepsilon}(\theta) = \mathbb{E}\left[ f(\theta + \varepsilon Z) \right]\,, 83 | 84 | where :math:`Z` is a random variable. The distribution of this random variable can be specified through the keyword argument ``noise``. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate. This can be particulary useful in the example given above, when :math:`f` is only defined on the discrete set, not its convex hull, i.e. 85 | 86 | .. math:: 87 | 88 | f_{\varepsilon}(\theta) = \mathbb{E}\left[ g(\mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle) \right]\,, 89 | 90 | 91 | .. autosummary:: 92 | :toctree: _autosummary 93 | 94 | jaxopt.perturbations.make_perturbed_fun 95 | 96 | 97 | Noise distributions 98 | ------------------- 99 | 100 | The functions :meth:`jaxopt.perturbations.make_perturbed_max`, :meth:`jaxopt.perturbations.make_perturbed_argmax` and :meth:`jaxopt.perturbations.make_perturbed_fun` take a keyword argument ``noise`` that specifies the distribution of random perturbations. Pre-defined distributions for this argument are the following: 101 | 102 | .. autosummary:: 103 | :toctree: _autosummary 104 | 105 | jaxopt.perturbations.Normal 106 | jaxopt.perturbations.Gumbel 107 | 108 | 109 | 110 | 111 | .. topic:: References 112 | 113 | Berthet, Q., Blondel, M., Teboul, O., Cuturi, M., Vert, J. P., & Bach, F. (2020). `Learning with differentiable pertubed optimizers `_. Advances in neural information processing systems, 33. 114 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.3.4 2 | pandoc>=1.0.2 3 | sphinx>=3.5.1 4 | sphinx_rtd_theme>=0.5.1 5 | sphinx_autodoc_typehints>=1.11.1 6 | ipython>=7.20.0 7 | ipykernel>=5.5.0 8 | sphinx-gallery>=0.9.0 9 | sphinx_copybutton>=0.4.0 10 | sphinx-remove-toctrees>=0.0.3 11 | jupyter-sphinx>=0.3.2 12 | myst-nb 13 | tensorflow-datasets 14 | tensorflow 15 | dm-haiku 16 | flax 17 | jupytext 18 | scikit-learn -------------------------------------------------------------------------------- /docs/root_finding.rst: -------------------------------------------------------------------------------- 1 | .. _root_finding: 2 | 3 | Root finding 4 | ============ 5 | 6 | This section is concerned with root finding, that is finding :math:`x` such 7 | that :math:`F(x, \theta) = 0`. 8 | 9 | Bisection 10 | --------- 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | 15 | jaxopt.Bisection 16 | 17 | Bisection is a suitable algorithm when :math:`F(x, \theta)` is one-dimensional 18 | in :math:`x`. 19 | 20 | Instantiating and running the solver 21 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 22 | 23 | First, let us consider the case :math:`F(x)`, i.e., without extra argument 24 | :math:`\theta`. The ``Bisection`` class requires a bracketing interval 25 | :math:`[\text{lower}, \text{upper}]`` such that :math:`F(\text{lower})` and 26 | :math:`F(\text{upper})` have opposite signs, meaning that a root is contained 27 | in this interval as long as :math:`F` is continuous. For instance, suppose 28 | that we want to find the root of :math:`F(x) = x^3 - x - 2`. We have 29 | :math:`F(1) = -2` and :math:`F(2) = 4`. 
Since the function is continuous, there 30 | must be a :math:`x` between 1 and 2 such that :math:`F(x) = 0`:: 31 | 32 | from jaxopt import Bisection 33 | 34 | def F(x): 35 | return x ** 3 - x - 2 36 | 37 | bisec = Bisection(optimality_fun=F, lower=1, upper=2) 38 | print(bisec.run().params) 39 | 40 | ``Bisection`` successfully finds the root ``x = 1.521``. 41 | Notice that ``Bisection`` does not require an initialization, 42 | since the bracketing interval is sufficient. 43 | 44 | Differentiation 45 | ~~~~~~~~~~~~~~~ 46 | 47 | Now, let us consider the case :math:`F(x, \theta)`. For instance, suppose that 48 | ``F`` takes an additional argument ``factor``. We can easily differentiate 49 | with respect to ``factor``:: 50 | 51 | def F(x, factor): 52 | return factor * x ** 3 - x - 2 53 | 54 | def root(factor): 55 | bisec = Bisection(optimality_fun=F, lower=1, upper=2) 56 | return bisec.run(factor=factor).params 57 | 58 | # Derivative of root with respect to factor at 2.0. 59 | print(jax.grad(root)(2.0)) 60 | 61 | Under the hood, we use the implicit function theorem in order to differentiate the root. 62 | See the :ref:`implicit differentiation ` section for more details. 63 | 64 | Scipy wrapper 65 | ------------- 66 | 67 | .. autosummary:: 68 | :toctree: _autosummary 69 | 70 | jaxopt.ScipyRootFinding 71 | 72 | 73 | Broyden's method 74 | ---------------- 75 | 76 | .. autosummary:: 77 | :toctree: _autosummary 78 | 79 | jaxopt.Broyden 80 | 81 | Broyden's method is an iterative algorithm suitable for nonlinear root equations in any dimension. 82 | It is a quasi-Newton method (like L-BFGS), meaning that it uses an approximation of the Jacobian matrix 83 | at each iteration. 84 | The approximation is updated at each iteration with a rank-one update. 85 | This makes the approximation easy to invert using the Sherman-Morrison formula, provided that it does not use too many 86 | updates. 87 | One can control the number of updates with the ``history_size`` argument. 88 | Furthermore, Broyden's method uses a line search to ensure the rank-one updates are stable. 89 | 90 | Example:: 91 | 92 | import jax.numpy as jnp 93 | from jaxopt import Broyden 94 | 95 | def F(x): 96 | return x ** 3 - x - 2 97 | 98 | broyden = Broyden(fun=F) 99 | print(broyden.run(jnp.array(1.0)).params) 100 | 101 | 102 | For implicit differentiation:: 103 | 104 | import jax 105 | import jax.numpy as jnp 106 | from jaxopt import Broyden 107 | 108 | def F(x, factor): 109 | return factor * x ** 3 - x - 2 110 | 111 | def root(factor): 112 | broyden = Broyden(fun=F) 113 | return broyden.run(jnp.array(1.0), factor=factor).params 114 | 115 | # Derivative of root with respect to factor at 2.0. 116 | print(jax.grad(root)(2.0)) 117 | -------------------------------------------------------------------------------- /docs/stochastic.rst: -------------------------------------------------------------------------------- 1 | Stochastic optimization 2 | ======================= 3 | 4 | This section is concerned with problems of the form 5 | 6 | .. math:: 7 | 8 | \min_{x} \mathbb{E}_{D}[f(x, \theta, D)], 9 | 10 | where :math:`f(x, \theta, D)` is differentiable (almost everywhere), :math:`x` 11 | are the parameters with respect to which the function is minimized, 12 | :math:`\theta` are optional fixed extra arguments and :math:`D` is a random 13 | variable (typically a mini-batch). 14 | 15 | 16 | .. 
topic:: Examples 17 | 18 | * :doc:`/notebooks/deep_learning/resnet_haiku` 19 | * :doc:`/notebooks/deep_learning/resnet_flax` 20 | * :ref:`sphx_glr_auto_examples_deep_learning_haiku_vae.py` 21 | * :ref:`sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py` 22 | 23 | 24 | Defining an objective function 25 | ------------------------------ 26 | 27 | Objective functions must contain a ``data`` argument corresponding to :math:`D` above. 28 | 29 | Example:: 30 | 31 | def ridge_reg_objective(params, l2reg, data): 32 | X, y = data 33 | residuals = jnp.dot(X, params) - y 34 | return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2) 35 | 36 | Data iterator 37 | ------------- 38 | 39 | Sampling realizations of the random variable :math:`D` can be done using an iterator. 40 | 41 | Example:: 42 | 43 | def data_iterator(): 44 | for _ in range(n_iter): 45 | perm = rng.permutation(n_samples)[:batch_size] 46 | yield (X[perm], y[perm]) 47 | 48 | Solvers 49 | ------- 50 | 51 | .. autosummary:: 52 | :toctree: _autosummary 53 | 54 | jaxopt.ArmijoSGD 55 | jaxopt.OptaxSolver 56 | jaxopt.PolyakSGD 57 | 58 | Optax solvers 59 | ~~~~~~~~~~~~~ 60 | 61 | `Optax `_ solvers can be used in JAXopt using 62 | :class:`OptaxSolver <jaxopt.OptaxSolver>`. Here's an example with Adam:: 63 | 64 | from jaxopt import OptaxSolver 65 | 66 | opt = optax.adam(learning_rate) 67 | solver = OptaxSolver(opt=opt, fun=ridge_reg_objective, maxiter=1000) 68 | 69 | See `common optimizers 70 | `_ in the 71 | optax documentation for a list of available stochastic solvers. 72 | 73 | Adaptive solvers 74 | ~~~~~~~~~~~~~~~~ 75 | 76 | Adaptive solvers update the step size at each iteration dynamically. 77 | An example is :class:`PolyakSGD <jaxopt.PolyakSGD>`, a solver 78 | which computes step sizes adaptively using function values. 79 | 80 | Another example is :class:`ArmijoSGD <jaxopt.ArmijoSGD>`, a solver 81 | that uses an Armijo line search. 82 | 83 | For convergence guarantees to hold, these two algorithms 84 | require the interpolation hypothesis: 85 | the global optimum over :math:`D` must also be a global optimum 86 | for any finite sample of :math:`D`. 87 | This is typically achieved by overparametrized models (e.g. neural networks) 88 | in classification tasks with separable classes, or on regression tasks without noise. 89 | 90 | Run iterator vs. manual loop 91 | ---------------------------- 92 | 93 | The following:: 94 | 95 | iterator = data_iterator() 96 | solver.run_iterator(init_params, iterator, l2reg=l2reg) 97 | 98 | is equivalent to:: 99 | 100 | iterator = data_iterator() 101 | state = solver.init_state(init_params, l2reg=l2reg) 102 | params = init_params 103 | for _ in range(maxiter): 104 | data = next(iterator) 105 | params, state = solver.update(params, state, l2reg=l2reg, data=data) 106 | -------------------------------------------------------------------------------- /docs/unconstrained.rst: -------------------------------------------------------------------------------- 1 | .. _unconstrained_optim: 2 | 3 | Unconstrained optimization 4 | ========================== 5 | 6 | This section is concerned with problems of the form 7 | 8 | .. math:: 9 | 10 | \min_{x} f(x, \theta) 11 | 12 | where :math:`f(x, \theta)` is differentiable (almost everywhere), :math:`x` 13 | are the parameters with respect to which the function is minimized and 14 | :math:`\theta` are optional extra arguments.
15 | 16 | Defining an objective function 17 | ------------------------------ 18 | 19 | Objective functions must always include as first argument the variables with 20 | respect to which the function is minimized. The function can also contain extra 21 | arguments. 22 | 23 | The following illustrates how to express the ridge regression objective:: 24 | 25 | def ridge_reg_objective(params, l2reg, X, y): 26 | residuals = jnp.dot(X, params) - y 27 | return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2) 28 | 29 | The model parameters ``params`` correspond to :math:`x` while ``l2reg``, ``X`` 30 | and ``y`` correspond to the extra arguments :math:`\theta` in the mathematical 31 | notation above. 32 | 33 | Solvers 34 | ------- 35 | 36 | .. autosummary:: 37 | :toctree: _autosummary 38 | 39 | jaxopt.BFGS 40 | jaxopt.GradientDescent 41 | jaxopt.LBFGS 42 | jaxopt.ScipyMinimize 43 | jaxopt.NonlinearCG 44 | 45 | Instantiating and running the solver 46 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 47 | 48 | Continuing the ridge regression example above, gradient descent can be 49 | instantiated and run as follows:: 50 | 51 | solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=maxiter) 52 | res = solver.run(init_params, l2reg=l2reg, X=X, y=y) 53 | 54 | # Alternatively, we could have used one of these solvers as well: 55 | # solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500) 56 | # solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500) 57 | # solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500) 58 | 59 | Unpacking results 60 | ~~~~~~~~~~~~~~~~~ 61 | 62 | Note that ``res`` has the form ``NamedTuple(params, state)``, where ``params`` 63 | are the approximate solution found by the solver and ``state`` contains 64 | solver-specific information about convergence. 65 | 66 | Because ``res`` is a ``NamedTuple``, we can unpack it as:: 67 | 68 | params, state = res 69 | print(params, state) 70 | 71 | Alternatively, we can also access attributes directly:: 72 | 73 | print(res.params, res.state) 74 | -------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | .. _general_examples: 2 | 3 | Example gallery 4 | =============== 5 | 6 | To clone the repository and the examples, please run:: 7 | 8 | $ git clone https://github.com/google/jaxopt.git 9 | 10 | or download this `zip file `_. 11 | 12 | To install the libraries that the examples depend on, please run:: 13 | 14 | $ pip install -r examples/requirements.txt 15 | -------------------------------------------------------------------------------- /examples/constrained/README.rst: -------------------------------------------------------------------------------- 1 | .. _constrained_examples: 2 | 3 | Constrained optimization 4 | ------------------------ 5 | 6 | -------------------------------------------------------------------------------- /examples/constrained/multiclass_linear_svm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Multiclass linear SVM (without intercept). 17 | ========================================== 18 | 19 | This quadratic program can be solved either with OSQP or with block coordinate descent. 20 | 21 | Reference: 22 | 23 | Crammer, K. and Singer, Y., 2001. On the algorithmic implementation of multiclass kernel-based vector machines. 24 | Journal of machine learning research, 2(Dec), pp.265-292. 25 | """ 26 | 27 | from absl import app 28 | from absl import flags 29 | 30 | import jax 31 | import jax.numpy as jnp 32 | 33 | from jaxopt import BlockCoordinateDescent 34 | from jaxopt import OSQP 35 | from jaxopt import objective 36 | from jaxopt import projection 37 | from jaxopt import prox 38 | 39 | from sklearn import datasets 40 | from sklearn import preprocessing 41 | from sklearn import svm 42 | 43 | 44 | flags.DEFINE_float("tol", 1e-5, "Tolerance of solvers.") 45 | flags.DEFINE_float("l2reg", 1000., "Regularization parameter. Must be positive.") 46 | flags.DEFINE_integer("num_samples", 20, "Size of train set.") 47 | flags.DEFINE_integer("num_features", 5, "Features dimension.") 48 | flags.DEFINE_integer("num_classes", 3, "Number of classes.") 49 | flags.DEFINE_bool("verbose", False, "Verbosity.") 50 | FLAGS = flags.FLAGS 51 | 52 | 53 | def multiclass_linear_svm_skl(X, y, l2reg): 54 | print("Solve multiclass SVM with sklearn.svm.LinearSVC:") 55 | svc = svm.LinearSVC(loss="hinge", dual=True, multi_class="crammer_singer", 56 | C=1.0 / l2reg, fit_intercept=False, 57 | tol=FLAGS.tol, max_iter=100*1000).fit(X, y) 58 | return svc.coef_.T 59 | 60 | 61 | def multiclass_linear_svm_bcd(X, Y, l2reg): 62 | print("Block coordinate descent solution:") 63 | 64 | # Set up parameters. 65 | block_prox = prox.make_prox_from_projection(projection.projection_simplex) 66 | fun = objective.multiclass_linear_svm_dual 67 | data = (X, Y) 68 | beta_init = jnp.ones((X.shape[0], Y.shape[-1])) / Y.shape[-1] 69 | 70 | # Run solver. 71 | bcd = BlockCoordinateDescent(fun=fun, block_prox=block_prox, 72 | maxiter=10*1000, tol=FLAGS.tol) 73 | sol = bcd.run(beta_init, hyperparams_prox=None, l2reg=FLAGS.l2reg, data=data) 74 | return sol.params 75 | 76 | 77 | def multiclass_linear_svm_osqp(X, Y, l2reg): 78 | # We solve the problem 79 | # 80 | # minimize 0.5/l2reg beta X X.T beta - (1. - Y)^T beta - 1./l2reg (Y^T X) X^T beta 81 | # under beta >= 0 82 | # sum_i beta_i = 1 83 | # 84 | print("OSQP solution solution:") 85 | 86 | def matvec_Q(X, beta): 87 | return 1./l2reg * jnp.dot(X, jnp.dot(X.T, beta)) 88 | 89 | linear_part = - (1. 
- Y) - 1./l2reg * jnp.dot(X, jnp.dot(X.T, Y)) 90 | 91 | def matvec_A(_, beta): 92 | return jnp.sum(beta, axis=-1) 93 | 94 | def matvec_G(_, beta): 95 | return -beta 96 | 97 | b = jnp.ones(X.shape[0]) 98 | h = jnp.zeros_like(Y) 99 | 100 | osqp = OSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, matvec_G=matvec_G, tol=FLAGS.tol, maxiter=10*1000) 101 | hyper_params = dict(params_obj=(X, linear_part), 102 | params_eq=(None, b), 103 | params_ineq=(None, h)) 104 | 105 | sol, _ = osqp.run(init_params=None, **hyper_params) 106 | return sol.primal 107 | 108 | 109 | def main(argv): 110 | del argv 111 | 112 | # Generate data. 113 | num_samples = FLAGS.num_samples 114 | num_features = FLAGS.num_features 115 | num_classes = FLAGS.num_classes 116 | 117 | X, y = datasets.make_classification(n_samples=num_samples, n_features=num_features, 118 | n_informative=3, n_classes=num_classes, random_state=0) 119 | X = preprocessing.Normalizer().fit_transform(X) 120 | Y = preprocessing.LabelBinarizer().fit_transform(y) 121 | Y = jnp.array(Y) 122 | 123 | l2reg = FLAGS.l2reg 124 | 125 | # Compare against sklearn. 126 | W_osqp = multiclass_linear_svm_osqp(X, Y, l2reg) 127 | W_fit_osqp = jnp.dot(X.T, (Y - W_osqp)) / l2reg 128 | print(W_fit_osqp) 129 | print() 130 | 131 | W_bcd = multiclass_linear_svm_bcd(X, Y, l2reg) 132 | W_fit_bcd = jnp.dot(X.T, (Y - W_bcd)) / l2reg 133 | print(W_fit_bcd) 134 | print() 135 | 136 | W_skl = multiclass_linear_svm_skl(X, y, l2reg) 137 | print(W_skl) 138 | print() 139 | 140 | 141 | if __name__ == "__main__": 142 | jax.config.update("jax_platform_name", "cpu") 143 | app.run(main) 144 | -------------------------------------------------------------------------------- /examples/constrained/nmf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Non-negative matrix factorizaton (NMF) using alternating minimization. 17 | ====================================================================== 18 | """ 19 | 20 | from absl import app 21 | from absl import flags 22 | 23 | import jax.numpy as jnp 24 | 25 | from jaxopt import BlockCoordinateDescent 26 | from jaxopt import objective 27 | from jaxopt import prox 28 | 29 | import numpy as onp 30 | 31 | from sklearn import datasets 32 | 33 | 34 | flags.DEFINE_string("penalty", "l2", "Regularization type.") 35 | flags.DEFINE_float("gamma", 1.0, "Regularization strength.") 36 | FLAGS = flags.FLAGS 37 | 38 | 39 | def nnreg(U, V_init, X, maxiter=150): 40 | """Regularized non-negative regression. 
41 | 42 | We solve:: 43 | 44 | min_{V >= 0} mean((U V^T - X) ** 2) + 0.5 * gamma * ||V||^2_2 45 | 46 | or 47 | 48 | min_{V >= 0} mean((U V^T - X) ** 2) + gamma * ||V||_1 49 | """ 50 | if FLAGS.penalty == "l2": 51 | block_prox = prox.prox_non_negative_ridge 52 | elif FLAGS.penalty == "l1": 53 | block_prox = prox.prox_non_negative_lasso 54 | else: 55 | raise ValueError("Invalid penalty.") 56 | 57 | bcd = BlockCoordinateDescent(fun=objective.least_squares, 58 | block_prox=block_prox, 59 | maxiter=maxiter) 60 | sol = bcd.run(init_params=V_init.T, hyperparams_prox=FLAGS.gamma, data=(U, X)) 61 | return sol.params.T # approximate solution V 62 | 63 | 64 | def reconstruction_error(U, V, X): 65 | """Computes (unregularized) reconstruction error.""" 66 | UV = jnp.dot(U, V.T) 67 | return 0.5 * jnp.mean((UV - X) ** 2) 68 | 69 | 70 | def nmf(U_init, V_init, X, maxiter=10): 71 | """NMF by alternating minimization. 72 | 73 | We solve 74 | 75 | min_{U >= 0, V>= 0} ||U V^T - X||^2 + 0.5 * gamma * (||U||^2_2 + ||V||^2_2) 76 | 77 | or 78 | 79 | min_{U >= 0, V>= 0} ||U V^T - X||^2 + gamma * (||U||_1 + ||V||_1) 80 | """ 81 | U, V = U_init, V_init 82 | 83 | error = reconstruction_error(U, V, X) 84 | print(f"STEP: 0; Error: {error:.3f}") 85 | print() 86 | 87 | for step in range(1, maxiter + 1): 88 | print(f"STEP: {step}") 89 | 90 | V = nnreg(U, V, X, maxiter=150) 91 | error = reconstruction_error(U, V, X) 92 | print(f"Error: {error:.3f} (V update)") 93 | 94 | U = nnreg(V, U, X.T, maxiter=150) 95 | error = reconstruction_error(U, V, X) 96 | print(f"Error: {error:.3f} (U update)") 97 | print() 98 | 99 | 100 | def main(argv): 101 | del argv 102 | 103 | # Prepare data. 104 | X, _ = datasets.load_diabetes(return_X_y=True) 105 | X = jnp.sqrt(X ** 2) 106 | 107 | n_samples = X.shape[0] 108 | n_features = X.shape[1] 109 | n_components = 10 110 | 111 | rng = onp.random.RandomState(0) 112 | U = jnp.array(rng.rand(n_samples, n_components)) 113 | V = jnp.array(rng.rand(n_features, n_components)) 114 | 115 | # Run the algorithm. 116 | print("penalty:", FLAGS.penalty) 117 | print("gamma", FLAGS.gamma) 118 | print() 119 | 120 | nmf(U, V, X, maxiter=30) 121 | 122 | if __name__ == "__main__": 123 | app.run(main) 124 | -------------------------------------------------------------------------------- /examples/deep_learning/README.rst: -------------------------------------------------------------------------------- 1 | .. _deep_learning_examples: 2 | 3 | Deep learning 4 | ------------- 5 | 6 | -------------------------------------------------------------------------------- /examples/fixed_point/README.rst: -------------------------------------------------------------------------------- 1 | .. _fixed_point_examples: 2 | 3 | Fixed point resolution 4 | ---------------------- 5 | 6 | -------------------------------------------------------------------------------- /examples/fixed_point/plot_anderson_accelerate_gd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r""" 16 | Anderson acceleration of gradient descent. 17 | ========================================== 18 | 19 | For a strictly convex function f, :math:`\nabla f(x)=0` implies that :math:`x` 20 | is the global optimum :math:`f`. 21 | 22 | Consequently the fixed point of :math:`T(x)=x-\eta\nabla f(x)` is the optimum of 23 | :math:`f`. 24 | 25 | Note that repeated application of the operator :math:`T` coincides exactlty with 26 | gradient descent with constant step size :math:`\eta`. 27 | 28 | Hence, as any other fixed point iteration, gradient descent can benefit from 29 | Anderson acceleration. Here, we choose :math:`f` as the objective function 30 | of ridge regression on some dummy dataset. Anderson acceleration reaches the 31 | optimal parameters within few iterations, whereas gradient descent is slower. 32 | 33 | Here `m` denotes the history size, and `K` the frequency of Anderson updates. 34 | """ 35 | 36 | import jax 37 | import jax.numpy as jnp 38 | 39 | import matplotlib.pyplot as plt 40 | from sklearn import datasets 41 | 42 | from jaxopt import AndersonAcceleration 43 | from jaxopt import FixedPointIteration 44 | 45 | from jaxopt import objective 46 | from jaxopt.tree_util import tree_scalar_mul, tree_sub 47 | 48 | jax.config.update("jax_platform_name", "cpu") 49 | 50 | 51 | # retrieve intermediate iterates. 52 | def run_all(solver, w_init, *args, **kwargs): 53 | state = solver.init_state(w_init, *args, **kwargs) 54 | sol = w_init 55 | sols, errors = [], [] 56 | 57 | for _ in range(solver.maxiter): 58 | sol, state = solver.update(sol, state, *args, **kwargs) 59 | sols.append(sol) 60 | errors.append(state.error) 61 | 62 | return jnp.stack(sols, axis=0), errors 63 | 64 | 65 | # dummy dataset 66 | X, y = datasets.make_regression(n_samples=100, n_features=10, random_state=0) 67 | ridge_regression_grad = jax.grad(objective.ridge_regression) 68 | 69 | # gradient step: x - grad_x f(x) with f the cost of learning task 70 | # the fixed point of this mapping verifies grad_x f(x) = 0 71 | # i.e the fixed point is an optimum 72 | def T(params, eta, l2reg, data): 73 | g = ridge_regression_grad(params, l2reg, data) 74 | step = tree_scalar_mul(eta, g) 75 | return tree_sub(params, step) 76 | 77 | w_init = jnp.zeros(X.shape[1]) # null vector 78 | eta = 1e-1 # small step size 79 | l2reg = 0. 
# no regularization 80 | tol = 1e-5 81 | maxiter = 80 82 | aa = AndersonAcceleration(T, history_size=5, mixing_frequency=1, maxiter=maxiter, ridge=5e-5, tol=tol) 83 | aam = AndersonAcceleration(T, history_size=5, mixing_frequency=5, maxiter=maxiter, ridge=5e-5, tol=tol) 84 | fpi = FixedPointIteration(T, maxiter=maxiter, tol=tol) 85 | 86 | aa_sols, aa_errors = run_all(aa, w_init, eta, l2reg, (X, y)) 87 | aam_sols, aam_errors = run_all(aam, w_init, eta, l2reg, (X, y)) 88 | fp_sols, fp_errors = run_all(fpi, w_init, eta, l2reg, (X, y)) 89 | 90 | sol = aa_sols[-1] 91 | print(f'Error={aa_errors[-1]:.6f} at parameters {sol}') 92 | print(f'At this point the gradient {ridge_regression_grad(sol, l2reg, (X,y))} is close to zero vector so we found the minimum.') 93 | 94 | fig = plt.figure(figsize=(10, 12)) 95 | fig.suptitle('Trajectory in parameter space') 96 | spec = fig.add_gridspec(ncols=2, nrows=3, hspace=0.3) 97 | 98 | # Plot trajectory in parameter space (8 dimensions) 99 | for i in range(4): 100 | ax = fig.add_subplot(spec[i//2, i%2]) 101 | ax.plot(fp_sols[:,i], fp_sols[:,2*i+1], '-', linewidth=4., label="Gradient Descent") 102 | ax.plot(aa_sols[:,i], aa_sols[:,2*i+1], 'v', markersize=12, label="Anderson Accelerated GD (m=5, K=1)") 103 | ax.plot(aam_sols[:,i], aam_sols[:,2*i+1], '*', markersize=8, label="Anderson Accelerated GD (m=5, K=5)") 104 | ax.set_xlabel(f'$x_{{{2*i+1}}}$') 105 | ax.set_ylabel(f'$x_{{{2*i+2}}}$') 106 | if i == 0: 107 | ax.legend(loc='upper left', bbox_to_anchor=(0.75, 1.38), 108 | ncol=1, fancybox=True, shadow=True) 109 | ax.axis('equal') 110 | 111 | # Plot error as function of iteration num 112 | ax = fig.add_subplot(spec[2, :]) 113 | iters = jnp.arange(len(aa_errors)) 114 | ax.plot(iters, fp_errors, linewidth=4., label='Gradient Descent Error') 115 | ax.plot(iters, aa_errors, linewidth=4., label='Anderson Accelerated GD Error (m=5, K=1)') 116 | ax.plot(iters, aam_errors, linewidth=4., label='Anderson Accelerated GD Error (m=5, K=5)') 117 | ax.set_xlabel('Iteration num') 118 | ax.set_ylabel('Error') 119 | ax.set_yscale('log') 120 | ax.legend() 121 | plt.show() 122 | 123 | -------------------------------------------------------------------------------- /examples/fixed_point/plot_anderson_wrapper_cd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r""" 16 | Anderson acceleration of block coordinate descent. 17 | ================================================== 18 | 19 | Block coordinate descent converges to a fixed point. It can therefore be 20 | accelerated with Anderson acceleration. 21 | 22 | Here `m` denotes the history size, and `K` the frequency of Anderson updates. 23 | 24 | Bertrand, Q. and Massias, M. 25 | Anderson acceleration of coordinate descent. 26 | AISTATS, 2021. 
27 | """ 28 | 29 | import jax 30 | import jax.numpy as jnp 31 | 32 | from jaxopt import AndersonWrapper 33 | from jaxopt import BlockCoordinateDescent 34 | 35 | from jaxopt import objective 36 | from jaxopt import prox 37 | 38 | import matplotlib.pyplot as plt 39 | from sklearn import datasets 40 | 41 | jax.config.update("jax_platform_name", "cpu") 42 | jax.config.update("jax_enable_x64", True) 43 | 44 | 45 | # retrieve intermediate iterates. 46 | def run_all(solver, w_init, *args, **kwargs): 47 | state = solver.init_state(w_init, *args, **kwargs) 48 | sol = w_init 49 | sols, errors = [sol], [state.error] 50 | for _ in range(solver.maxiter): 51 | sol, state = solver.update(sol, state, *args, **kwargs) 52 | sols.append(sol) 53 | errors.append(state.error) 54 | return jnp.stack(sols, axis=0), errors 55 | 56 | 57 | X, y = datasets.make_regression(n_samples=10, n_features=8, random_state=1) 58 | fun = objective.least_squares # fun(params, data) 59 | l1reg = 10.0 60 | data = (X, y) 61 | 62 | w_init = jnp.zeros(X.shape[1]) 63 | maxiter = 80 64 | 65 | bcd = BlockCoordinateDescent(fun, block_prox=prox.prox_lasso, maxiter=maxiter, tol=1e-6) 66 | history_size = 5 67 | aa = AndersonWrapper(bcd, history_size=history_size, mixing_frequency=1, ridge=1e-4) 68 | aam = AndersonWrapper(bcd, history_size=history_size, mixing_frequency=history_size, ridge=1e-4) 69 | 70 | aa_sols, aa_errors = run_all(aa, w_init, hyperparams_prox=l1reg, data=data) 71 | aam_sols, aam_errors = run_all(aam, w_init, hyperparams_prox=l1reg, data=data) 72 | bcd_sols, bcd_errors = run_all(bcd, w_init, hyperparams_prox=l1reg, data=data) 73 | 74 | print(f'Error={aa_errors[-1]:.6f} at parameters {aa_sols[-1]} for Anderson (m=5, K=1)') 75 | print(f'Error={aam_errors[-1]:.6f} at parameters {aam_sols[-1]} for Anderson (m=5, K=5)') 76 | print(f'Error={bcd_errors[-1]:.6f} at parameters {bcd_sols[-1]} for Block CD') 77 | 78 | fig = plt.figure(figsize=(10, 12)) 79 | fig.suptitle('Least Square linear regression with Lasso penalty') 80 | spec = fig.add_gridspec(ncols=2, nrows=3, hspace=0.3) 81 | 82 | # Plot trajectory in parameter space (8 dimensions) 83 | for i in range(4): 84 | ax = fig.add_subplot(spec[i//2, i%2]) 85 | ax.plot(bcd_sols[:,i], bcd_sols[:,2*i+1], '--', label="Coordinate Descent") 86 | ax.plot(aa_sols[:,i], aa_sols[:,2*i+1], '--', label="Anderson Accelerated CD (m=5, K=1)") 87 | ax.plot(aam_sols[:,i], aam_sols[:,2*i+1], '--', label="Anderson Accelerated CD (m=5, K=5)") 88 | ax.set_xlabel(f'$x_{{{2*i+1}}}$') 89 | ax.set_ylabel(f'$x_{{{2*i+2}}}$') 90 | if i == 0: 91 | ax.legend(loc='upper left', bbox_to_anchor=(0.75, 1.38), 92 | ncol=1, fancybox=True, shadow=True) 93 | ax.axis('equal') 94 | 95 | # Plot error as function of iteration num 96 | ax = fig.add_subplot(spec[2, :]) 97 | iters = jnp.arange(len(aa_errors)) 98 | ax.plot(iters, bcd_errors, '-o', label='Coordinate Descent Error') 99 | ax.plot(iters, aa_errors, '-o', label='Anderson Accelerated CD Error (m=5, K=1)') 100 | ax.plot(iters, aam_errors, '-o', label='Anderson Accelerated CD Error (m=5, K=5)') 101 | ax.set_xlabel('Iteration num') 102 | ax.set_ylabel('Error') 103 | ax.set_yscale('log') 104 | ax.legend() 105 | plt.show() 106 | 107 | -------------------------------------------------------------------------------- /examples/fixed_point/plot_picard_ode.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file 
except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r""" 16 | Anderson acceleration in application to Picard–Lindelöf theorem. 17 | ================================================================ 18 | 19 | Thanks to the `Picard–Lindelöf theorem, 20 | `_ we can 21 | reduce differential equation solving to fixed point computations and simple 22 | integration. More precisely consider the ODE: 23 | 24 | .. math:: 25 | 26 | y'(t)=f(t,y(t)) 27 | 28 | of some time-dependant dynamic 29 | :math:`f:\mathbb{R}\times\mathbb{R}^d\rightarrow\mathbb{R}^d` and initial 30 | conditions :math:`y(0)=y_0`. Then :math:`y` is the fixed point of the following 31 | map: 32 | 33 | .. math:: 34 | 35 | y(t)=T(y)(t)\mathrel{\mathop:}=y_0+\int_0^t f(s,y(s))\mathrm{d}s 36 | 37 | Then we can define the sequence of functions :math:`(\phi_k)` with 38 | :math:`\phi_0=0` recursively as follows: 39 | 40 | .. math:: 41 | 42 | \phi_{k+1}(t)=T(\phi_k)(t)\mathrel{\mathop:} = 43 | y_0+\int_0^t f(s,\phi_k(s))\mathrm{d}s 44 | 45 | Such sequence converges to the solution of the ODE, i.e., 46 | :math:`\lim_{k\rightarrow\infty}\phi_k=y`. 47 | 48 | In this example we choose :math:`f(t,y(t))=1+y(t)^2`. We know that the 49 | analytical solution is :math:`y(t)=\tan{t}` , which we use as a ground truth to 50 | evaluate our numerical scheme. 51 | We used ``scipy.integrate.cumulative_trapezoid`` to perform 52 | integration, but any other integration method can be used. 53 | """ 54 | 55 | 56 | import jax 57 | import jax.numpy as jnp 58 | 59 | from jaxopt import AndersonAcceleration 60 | 61 | 62 | import numpy as np 63 | import matplotlib.pyplot as plt 64 | from matplotlib.pyplot import cm 65 | import scipy.integrate 66 | 67 | jax.config.update("jax_platform_name", "cpu") 68 | 69 | 70 | # Solve the differential equation y'(t)=1+t^2, with solution y(t) = tan(t) 71 | def f(ti, phi): 72 | return 1 + phi ** 2 73 | 74 | def T(phi_cur, ti, y0, dx): 75 | """Fixed point iteration in the Picard method. 76 | See: https://en.wikipedia.org/wiki/Picard%E2%80%93Lindel%C3%B6f_theorem""" 77 | f_phi = f(ti, phi_cur) 78 | phi_next = scipy.integrate.cumulative_trapezoid(f_phi, initial=y0, dx=dx) 79 | return phi_next 80 | 81 | y0 = 0 82 | num_interpolating_points = 100 83 | t0 = jnp.array(0.) 
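# Note on the discretization: num_interpolating_points = 100 samples on
# [t0, tmax] give num_interpolating_points - 1 = 99 sub-intervals, hence the
# grid spacing dx defined below. Because cumulative_trapezoid is called with
# initial=y0 inside T, its output has the same length as its input, so T maps
# a length-100 vector phi to another length-100 vector, as a fixed-point
# iteration requires.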
84 | tmax = 0.9 * (jnp.pi / 2) # stop before pi/2 to ensure convergence 85 | dx = (tmax - t0) / (num_interpolating_points-1) 86 | phi0 = jnp.zeros(num_interpolating_points) 87 | ti = np.linspace(t0, tmax, num_interpolating_points) 88 | 89 | sols = [phi0] 90 | aa = AndersonAcceleration(T, history_size=5, maxiter=50, ridge=1e-5, jit=False) 91 | state = aa.init_state(phi0, ti, y0, dx) 92 | sol = phi0 93 | sols.append(sol) 94 | for k in range(aa.maxiter): 95 | sol, state = aa.update(phi0, state, ti, y0, dx) 96 | sols.append(sol) 97 | res = sols[-1] - np.tan(ti) 98 | print(f'Error of {jnp.linalg.norm(res)} with ground truth tan(t)') 99 | 100 | 101 | # vizualize the first 8 iterates to make the figure easier to read 102 | sols = sols[4:12] 103 | fig = plt.figure(figsize=(8,4)) 104 | ax = fig.add_subplot(1, 1, 1) 105 | 106 | colors = cm.plasma(np.linspace(0, 1, len(sols))) 107 | for k, (sol, c) in enumerate(zip(sols, colors)): 108 | desc = rf'$\phi_{k}$' if k > 0 else rf'$\phi_0=0$' 109 | ax.plot(ti, sol, '+', c=c, label=desc) 110 | ax.plot(ti, np.tan(ti), '-', c='green', label=r'$y(t)=\tan{(t)}$ (ground truth)') 111 | 112 | ax.legend() 113 | props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) 114 | formula = rf'$\phi_{{k+1}}(t)=\phi_0+\int_0^{{{tmax/2:.2f}\pi}} f(t,\phi_{{k}}(t))\mathrm{{d}}t$' 115 | ax.text(0.42, 0.85, formula, transform=ax.transAxes, fontsize=14, verticalalignment='top', bbox=props) 116 | fig.suptitle('Anderson acceleration for ODE solving') 117 | plt.show() 118 | -------------------------------------------------------------------------------- /examples/implicit_diff/README.rst: -------------------------------------------------------------------------------- 1 | .. _implicit_diff_examples: 2 | 3 | Implicit differentiation 4 | ------------------------ 5 | 6 | -------------------------------------------------------------------------------- /examples/implicit_diff/lasso_implicit_diff.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Implicit differentiation of lasso. 17 | ================================== 18 | """ 19 | 20 | from absl import app 21 | from absl import flags 22 | 23 | import jax.numpy as jnp 24 | 25 | from jaxopt import BlockCoordinateDescent 26 | from jaxopt import objective 27 | from jaxopt import OptaxSolver 28 | from jaxopt import prox 29 | from jaxopt import ProximalGradient 30 | import optax 31 | 32 | from sklearn import datasets 33 | from sklearn import model_selection 34 | from sklearn import preprocessing 35 | 36 | flags.DEFINE_bool("unrolling", False, "Whether to use unrolling.") 37 | flags.DEFINE_string("solver", "bcd", "Solver to use (bcd or pg).") 38 | FLAGS = flags.FLAGS 39 | 40 | 41 | def outer_objective(theta, init_inner, data): 42 | """Validation loss.""" 43 | X_tr, X_val, y_tr, y_val = data 44 | # We use the bijective mapping lam = jnp.exp(theta) to ensure positivity. 
45 | lam = jnp.exp(theta) 46 | 47 | if FLAGS.solver == "pg": 48 | solver = ProximalGradient( 49 | fun=objective.least_squares, 50 | prox=prox.prox_lasso, 51 | implicit_diff=not FLAGS.unrolling, 52 | maxiter=500) 53 | elif FLAGS.solver == "bcd": 54 | solver = BlockCoordinateDescent( 55 | fun=objective.least_squares, 56 | block_prox=prox.prox_lasso, 57 | implicit_diff=not FLAGS.unrolling, 58 | maxiter=500) 59 | else: 60 | raise ValueError("Unknown solver.") 61 | 62 | # The format is run(init_params, hyperparams_prox, *args, **kwargs) 63 | # where *args and **kwargs are passed to `fun`. 64 | w_fit = solver.run(init_inner, lam, (X_tr, y_tr)).params 65 | 66 | y_pred = jnp.dot(X_val, w_fit) 67 | loss_value = jnp.mean((y_pred - y_val) ** 2) 68 | 69 | # We return w_fit as auxiliary data. 70 | # Auxiliary data is stored in the optimizer state (see below). 71 | return loss_value, w_fit 72 | 73 | 74 | def main(argv): 75 | del argv 76 | 77 | print("Solver:", FLAGS.solver) 78 | print("Unrolling:", FLAGS.unrolling) 79 | 80 | # Prepare data. 81 | X, y = datasets.load_diabetes(return_X_y=True) 82 | X = preprocessing.normalize(X) 83 | # data = (X_tr, X_val, y_tr, y_val) 84 | data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0) 85 | 86 | # Initialize solver. 87 | solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True) 88 | theta = 1.0 89 | init_w = jnp.zeros(X.shape[1]) 90 | state = solver.init_state(theta, init_inner=init_w, data=data) 91 | 92 | # Run outer loop. 93 | for _ in range(10): 94 | theta, state = solver.update(params=theta, state=state, init_inner=init_w, 95 | data=data) 96 | # The auxiliary data returned by the outer loss is stored in the state. 97 | init_w = state.aux 98 | print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.") 99 | 100 | if __name__ == "__main__": 101 | app.run(main) 102 | -------------------------------------------------------------------------------- /examples/implicit_diff/ridge_reg_implicit_diff.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Implicit differentiation of ridge regression. 
17 | ============================================= 18 | """ 19 | 20 | from absl import app 21 | import jax 22 | import jax.numpy as jnp 23 | from jaxopt import implicit_diff 24 | from jaxopt import linear_solve 25 | from jaxopt import OptaxSolver 26 | import optax 27 | from sklearn import datasets 28 | from sklearn import model_selection 29 | from sklearn import preprocessing 30 | 31 | 32 | def ridge_objective(params, l2reg, data): 33 | """Ridge objective function.""" 34 | X_tr, y_tr = data 35 | residuals = jnp.dot(X_tr, params) - y_tr 36 | return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2) 37 | 38 | 39 | @implicit_diff.custom_root(jax.grad(ridge_objective)) 40 | def ridge_solver(init_params, l2reg, data): 41 | """Solve ridge regression by conjugate gradient.""" 42 | X_tr, y_tr = data 43 | 44 | def matvec(u): 45 | return jnp.dot(X_tr.T, jnp.dot(X_tr, u)) 46 | 47 | return linear_solve.solve_cg(matvec=matvec, 48 | b=jnp.dot(X_tr.T, y_tr), 49 | ridge=len(y_tr) * l2reg, 50 | init=init_params, 51 | maxiter=20) 52 | 53 | 54 | # Perhaps confusingly, theta is a parameter of the outer objective, 55 | # but l2reg = jnp.exp(theta) is an hyper-parameter of the inner objective. 56 | def outer_objective(theta, init_inner, data): 57 | """Validation loss.""" 58 | X_tr, X_val, y_tr, y_val = data 59 | # We use the bijective mapping l2reg = jnp.exp(theta) 60 | # both to optimize in log-space and to ensure positivity. 61 | l2reg = jnp.exp(theta) 62 | w_fit = ridge_solver(init_inner, l2reg, (X_tr, y_tr)) 63 | y_pred = jnp.dot(X_val, w_fit) 64 | loss_value = jnp.mean((y_pred - y_val) ** 2) 65 | # We return w_fit as auxiliary data. 66 | # Auxiliary data is stored in the optimizer state (see below). 67 | return loss_value, w_fit 68 | 69 | 70 | def main(argv): 71 | del argv 72 | 73 | # Prepare data. 74 | X, y = datasets.load_diabetes(return_X_y=True) 75 | X = preprocessing.normalize(X) 76 | # data = (X_tr, X_val, y_tr, y_val) 77 | data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0) 78 | 79 | # Initialize solver. 80 | solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True) 81 | theta = 1.0 82 | init_w = jnp.zeros(X.shape[1]) 83 | state = solver.init_state(theta, init_inner=init_w, data=data) 84 | 85 | # Run outer loop. 86 | for _ in range(50): 87 | theta, state = solver.update(params=theta, state=state, init_inner=init_w, 88 | data=data) 89 | # The auxiliary data returned by the outer loss is stored in the state. 90 | init_w = state.aux 91 | print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.") 92 | 93 | if __name__ == "__main__": 94 | app.run(main) 95 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | dm-haiku>=0.0.4 2 | flax>=0.3.4 3 | optax>=0.0.9 4 | scikit-learn>=0.24.1 5 | tensorflow-datasets>=4.4.0 6 | tqdm>=4.62 7 | -------------------------------------------------------------------------------- /jaxopt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import warnings 16 | 17 | from jaxopt import implicit_diff 18 | from jaxopt import isotonic 19 | from jaxopt import loss 20 | from jaxopt import objective 21 | from jaxopt import projection 22 | from jaxopt import prox 23 | 24 | from jaxopt._src.anderson import AndersonAcceleration 25 | from jaxopt._src.anderson_wrapper import AndersonWrapper 26 | from jaxopt._src.armijo_sgd import ArmijoSGD 27 | from jaxopt._src.base import OptStep 28 | from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch 29 | from jaxopt._src.bfgs import BFGS 30 | from jaxopt._src.bisection import Bisection 31 | from jaxopt._src.block_cd import BlockCoordinateDescent 32 | from jaxopt._src.broyden import Broyden 33 | from jaxopt._src.cd_qp import BoxCDQP 34 | from jaxopt._src.cvxpy_wrapper import CvxpyQP 35 | from jaxopt._src.eq_qp import EqualityConstrainedQP 36 | from jaxopt._src.fixed_point_iteration import FixedPointIteration 37 | from jaxopt._src.gauss_newton import GaussNewton 38 | from jaxopt._src.gradient_descent import GradientDescent 39 | from jaxopt._src.hager_zhang_linesearch import HagerZhangLineSearch 40 | from jaxopt._src.iterative_refinement import IterativeRefinement 41 | from jaxopt._src.lbfgs import LBFGS 42 | from jaxopt._src.lbfgsb import LBFGSB 43 | from jaxopt._src.levenberg_marquardt import LevenbergMarquardt 44 | from jaxopt._src.mirror_descent import MirrorDescent 45 | from jaxopt._src.nonlinear_cg import NonlinearCG 46 | from jaxopt._src.optax_wrapper import OptaxSolver 47 | from jaxopt._src.osqp import BoxOSQP 48 | from jaxopt._src.osqp import OSQP 49 | from jaxopt._src.polyak_sgd import PolyakSGD 50 | from jaxopt._src.projected_gradient import ProjectedGradient 51 | from jaxopt._src.proximal_gradient import ProximalGradient 52 | from jaxopt._src.scipy_wrappers import ScipyBoundedLeastSquares 53 | from jaxopt._src.scipy_wrappers import ScipyBoundedMinimize 54 | from jaxopt._src.scipy_wrappers import ScipyLeastSquares 55 | from jaxopt._src.scipy_wrappers import ScipyMinimize 56 | from jaxopt._src.scipy_wrappers import ScipyRootFinding 57 | from jaxopt._src.zoom_linesearch import ZoomLineSearch 58 | 59 | warnings.warn( 60 | "JAXopt is no longer maintained. See https://docs.jax.dev/en/latest/ for" 61 | " alternatives.", 62 | DeprecationWarning, 63 | ) 64 | -------------------------------------------------------------------------------- /jaxopt/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jaxopt/_src/cd_qp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of coordinate descent for box-constrained QPs.""" 16 | 17 | from typing import Callable 18 | from typing import NamedTuple 19 | from typing import Optional 20 | from typing import Union 21 | 22 | from dataclasses import dataclass 23 | 24 | import jax 25 | import jax.numpy as jnp 26 | 27 | from jaxopt._src import base 28 | from jaxopt._src import projection 29 | 30 | 31 | class BoxCDQPState(NamedTuple): 32 | """Named tuple containing state information.""" 33 | iter_num: int 34 | error: float 35 | 36 | 37 | def fori_loop_body_fun(i, tup): 38 | x, Q, c, l, u, error = tup 39 | # i-th element of the gradient 40 | g_i = jnp.dot(Q[i], x) + c[i] 41 | # i-th diagonal element of the Hessian 42 | h_i = Q[i, i] 43 | # Newton-update and avoid division by zero 44 | update = jnp.where(h_i == 0, 0, g_i / h_i) 45 | # Newton-update + clipping to satisfy the box constraint 46 | x_i_new = jnp.clip(x[i] - update, l[i], u[i]) 47 | delta_i = x_i_new - x[i] 48 | # Cumulated error 49 | error += jnp.abs(delta_i) 50 | x = x.at[i].set(x_i_new) 51 | return x, Q, c, l, u, error 52 | 53 | 54 | @dataclass(eq=False) 55 | class BoxCDQP(base.IterativeSolver): 56 | """Coordinate descent solver for box-constrained QPs. 57 | 58 | This solver minimizes:: 59 | 60 | 0.5 <x, Qx> + <c, x> subject to l <= x <= u 61 | 62 | Attributes: 63 | maxiter: maximum number of coordinate descent iterations. 64 | tol: tolerance to use. 65 | verbose: whether to print information on every iteration or not. 66 | 67 | implicit_diff: whether to enable implicit diff or autodiff of unrolled 68 | iterations. 69 | implicit_diff_solve: the linear system solver to use. 70 | 71 | jit: whether to JIT-compile the optimization loop (default: True). 72 | unroll: whether to unroll the optimization loop (default: "auto"). 73 | """ 74 | maxiter: int = 500 75 | tol: float = 1e-4 76 | verbose: Union[bool, int] = False 77 | implicit_diff: bool = True 78 | implicit_diff_solve: Optional[Callable] = None 79 | jit: bool = True 80 | unroll: base.AutoOrBoolean = "auto" 81 | 82 | def init_state(self, 83 | init_params: jnp.ndarray, 84 | params_obj: Optional[base.ArrayPair] = None, 85 | params_ineq: Optional[base.ArrayPair] = None) -> BoxCDQPState: 86 | """Initialize the solver state. 87 | 88 | Args: 89 | init_params: array containing the initial parameters. 90 | params_obj: Tuple of arrays ``(Q, c)``. 91 | params_ineq: Tuple of arrays ``(l, u)``. 92 | Returns: 93 | state 94 | """ 95 | del params_obj, params_ineq # Not used.
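    # The fresh state starts at iteration 0 with an infinite error, so the
    # stopping criterion (error <= tol) cannot be satisfied before at least
    # one coordinate descent epoch has run.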
96 | return BoxCDQPState(iter_num=jnp.asarray(0), 97 | error=jnp.asarray(jnp.inf)) 98 | 99 | def update(self, 100 | params: jnp.ndarray, 101 | state: NamedTuple, 102 | params_obj: base.ArrayPair, 103 | params_ineq: base.ArrayPair) -> base.OptStep: 104 | """Performs one epoch of coordinate descent. 105 | 106 | Args: 107 | params: array containing the parameters. 108 | state: named tuple containing the solver state. 109 | params_obj: Tuple of arrays ``(Q, c)``. 110 | params_ineq: Tuple of arrays ``(l, u)``. 111 | Returns: 112 | (params, state) 113 | """ 114 | Q, c = params_obj 115 | l, u = params_ineq 116 | 117 | init = (params, Q, c, l, u, 0) 118 | 119 | # todo: ability to permute coordinate order. 120 | params, _, _, _, _, error = jax.lax.fori_loop(lower=0, 121 | upper=params.shape[0], 122 | body_fun=fori_loop_body_fun, 123 | init_val=init) 124 | 125 | state = BoxCDQPState(iter_num=state.iter_num + 1, error=error) 126 | 127 | if self.verbose: 128 | self.log_info(state) 129 | return base.OptStep(params=params, state=state) 130 | 131 | def _fixed_point_fun(self, 132 | sol: jnp.ndarray, 133 | params_obj: base.ArrayPair, 134 | params_ineq: base.ArrayPair) -> jnp.ndarray: 135 | Q, c = params_obj 136 | l, u = params_ineq 137 | grad = jnp.dot(Q, sol) + c 138 | return projection.projection_box(sol - grad, (l, u)) 139 | 140 | def optimality_fun(self, 141 | sol: jnp.ndarray, 142 | params_obj: base.ArrayPair, 143 | params_ineq: base.ArrayPair) -> jnp.ndarray: 144 | return self._fixed_point_fun(sol, params_obj, params_ineq) - sol 145 | 146 | def __post_init__(self): 147 | super().__post_init__() 148 | -------------------------------------------------------------------------------- /jaxopt/_src/cond.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Branching utilities.""" 16 | 17 | import jax 18 | 19 | def cond(cond, if_fun, else_fun, *operands, jit=True): 20 | """Wrapper to avoid having the condition to be compiled if not wanted.""" 21 | if not jit: 22 | with jax.disable_jit(): 23 | return jax.lax.cond(cond, if_fun, else_fun, *operands) 24 | return jax.lax.cond(cond, if_fun, else_fun, *operands) -------------------------------------------------------------------------------- /jaxopt/_src/fixed_point_iteration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of the fixed point iteration method in JAX.""" 16 | 17 | from typing import Any 18 | from typing import Callable 19 | from typing import NamedTuple 20 | from typing import Optional 21 | from typing import Union 22 | 23 | from dataclasses import dataclass 24 | 25 | import jax.numpy as jnp 26 | 27 | from jaxopt._src import base 28 | from jaxopt._src.tree_util import tree_l2_norm, tree_sub 29 | 30 | 31 | class FixedPointState(NamedTuple): 32 | """Named tuple containing state information. 33 | Attributes: 34 | iter_num: iteration number 35 | error: residuals of current estimate 36 | aux: auxiliary output of fixed_point_fun when has_aux=True 37 | """ 38 | iter_num: int 39 | error: float 40 | aux: Optional[Any] = None 41 | num_fun_eval: int = 0 42 | 43 | 44 | @dataclass(eq=False) 45 | class FixedPointIteration(base.IterativeSolver): 46 | """Fixed point iteration method. 47 | Attributes: 48 | fixed_point_fun: a function ``fixed_point_fun(x, *args, **kwargs)`` 49 | returning a pytree with the same structure and type as x. 50 | The function should fulfill the Banach fixed-point theorem's assumptions. 51 | Otherwise convergence is not guaranteed. 52 | maxiter: maximum number of iterations. 53 | tol: tolerance (stopping criterion) 54 | has_aux: whether fixed_point_fun returns additional data. (default: False) 55 | if True, the fixed point is computed only with respect to the first element of the 56 | sequence returned. Other elements are carried during computation. 57 | verbose: whether to print information on every iteration or not. 58 | 59 | implicit_diff: whether to enable implicit diff or autodiff of unrolled 60 | iterations. 61 | implicit_diff_solve: the linear system solver to use. 62 | 63 | jit: whether to JIT-compile the optimization loop (default: True). 64 | unroll: whether to unroll the optimization loop (default: "auto") 65 | References: 66 | https://en.wikipedia.org/wiki/Fixed-point_iteration 67 | """ 68 | fixed_point_fun: Callable 69 | maxiter: int = 100 70 | tol: float = 1e-5 71 | has_aux: bool = False 72 | verbose: Union[bool, int] = False 73 | implicit_diff: bool = True 74 | implicit_diff_solve: Optional[Callable] = None 75 | jit: bool = True 76 | unroll: base.AutoOrBoolean = "auto" 77 | 78 | def init_state(self, 79 | init_params, 80 | *args, 81 | **kwargs) -> FixedPointState: 82 | """Initialize the solver state. 83 | 84 | Args: 85 | init_params: initial guess of the fixed point, pytree 86 | *args: additional positional arguments to be passed to ``optimality_fun``. 87 | **kwargs: additional keyword arguments to be passed to ``optimality_fun``. 88 | Returns: 89 | state 90 | """ 91 | return FixedPointState(iter_num=jnp.asarray(0), 92 | error=jnp.asarray(jnp.inf), 93 | aux=None, 94 | num_fun_eval=jnp.asarray(0, base.NUM_EVAL_DTYPE) 95 | ) 96 | 97 | def update(self, 98 | params: Any, 99 | state: NamedTuple, 100 | *args, 101 | **kwargs) -> base.OptStep: 102 | """Performs one iteration of the fixed point iteration method. 103 | Args: 104 | params: pytree containing the parameters. 105 | state: named tuple containing the solver state. 106 | *args: additional positional arguments to be passed to 107 | ``fixed_point_fun``. 108 | **kwargs: additional keyword arguments to be passed to 109 | ``fixed_point_fun``.
110 | Returns: 111 | (params, state) 112 | """ 113 | next_params, aux = self._fun(params, *args, **kwargs) 114 | error = tree_l2_norm(tree_sub(next_params, params)) 115 | next_state = FixedPointState(iter_num=state.iter_num + 1, 116 | error=error, 117 | aux=aux, 118 | num_fun_eval=state.num_fun_eval + 1) 119 | 120 | if self.verbose: 121 | self.log_info( 122 | next_state, 123 | error_name="Distance btw Iterates" 124 | ) 125 | return base.OptStep(params=next_params, state=next_state) 126 | 127 | def optimality_fun(self, params, *args, **kwargs): 128 | """Optimality function mapping compatible with ``@custom_root``.""" 129 | new_params, _ = self._fun(params, *args, **kwargs) 130 | return tree_sub(new_params, params) 131 | 132 | def __post_init__(self): 133 | super().__post_init__() 134 | 135 | if self.has_aux: 136 | self._fun = self.fixed_point_fun 137 | else: 138 | self._fun = lambda *a, **kw: (self.fixed_point_fun(*a, **kw), None) 139 | 140 | self.reference_signature = self.fixed_point_fun 141 | -------------------------------------------------------------------------------- /jaxopt/_src/gradient_descent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of gradient descent in JAX.""" 16 | 17 | from typing import Any 18 | from typing import NamedTuple 19 | 20 | from dataclasses import dataclass 21 | 22 | from jaxopt._src import base 23 | from jaxopt._src.proximal_gradient import ProximalGradient, ProxGradState 24 | 25 | 26 | @dataclass(eq=False) 27 | class GradientDescent(ProximalGradient): 28 | """Gradient Descent solver. 29 | 30 | Attributes: 31 | fun: a smooth function of the form ``fun(parameters, *args, **kwargs)``, 32 | where ``parameters`` are the model parameters w.r.t. which we minimize 33 | the function and the rest are fixed auxiliary parameters. 34 | value_and_grad: whether ``fun`` just returns the value (False) or both 35 | the value and gradient (True). 36 | has_aux: whether ``fun`` outputs auxiliary data or not. 37 | If ``has_aux`` is False, ``fun`` is expected to be 38 | scalar-valued. 39 | If ``has_aux`` is True, then we have one of the following 40 | two cases. 41 | If ``value_and_grad`` is False, the output should be 42 | ``value, aux = fun(...)``. 43 | If ``value_and_grad == True``, the output should be 44 | ``(value, aux), grad = fun(...)``. 45 | At each iteration of the algorithm, the auxiliary outputs are stored 46 | in ``state.aux``. 47 | 48 | stepsize: a stepsize to use (if <= 0, use backtracking line search), or a 49 | callable specifying the **positive** stepsize to use at each iteration. 50 | maxiter: maximum number of proximal gradient descent iterations. 51 | maxls: maximum number of iterations to use in the line search. 52 | tol: tolerance to use. 53 | 54 | acceleration: whether to use acceleration (also known as FISTA) or not. 
55 | verbose: whether to print information on every iteration or not. 56 | 57 | implicit_diff: whether to enable implicit diff or autodiff of unrolled 58 | iterations. 59 | implicit_diff_solve: the linear system solver to use. 60 | 61 | jit: whether to JIT-compile the optimization loop (default: True). 62 | unroll: whether to unroll the optimization loop (default: "auto"). 63 | """ 64 | 65 | def init_state(self, 66 | init_params: Any, 67 | *args, 68 | **kwargs) -> ProxGradState: 69 | """Initialize the solver state. 70 | 71 | Args: 72 | init_params: pytree containing the initial parameters. 73 | *args: additional positional arguments to be passed to ``fun``. 74 | **kwargs: additional keyword arguments to be passed to ``fun``. 75 | Returns: 76 | state 77 | """ 78 | return super().init_state(init_params, None, *args, **kwargs) 79 | 80 | def update( 81 | self, params: Any, state: ProxGradState, *args, **kwargs 82 | ) -> base.OptStep: 83 | """Performs one iteration of gradient descent. 84 | 85 | Args: 86 | params: pytree containing the parameters. 87 | state: named tuple containing the solver state. 88 | *args: additional positional arguments to be passed to ``fun``. 89 | **kwargs: additional keyword arguments to be passed to ``fun``. 90 | Returns: 91 | (params, state) 92 | """ 93 | return super().update(params, state, None, *args, **kwargs) 94 | 95 | def optimality_fun(self, params, *args, **kwargs): 96 | """Optimality function mapping compatible with ``@custom_root``.""" 97 | return self._grad_fun(params, *args, **kwargs) 98 | 99 | def __post_init__(self): 100 | super().__post_init__() 101 | self.reference_signature = self.fun 102 | -------------------------------------------------------------------------------- /jaxopt/_src/isotonic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Isotonic Regression.""" 16 | 17 | import warnings 18 | import numpy as onp 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | 23 | # pylint: disable=g-import-not-at-top 24 | try: 25 | from numba import njit 26 | 27 | NUMBA_AVAILABLE = True 28 | except ImportError: 29 | NUMBA_AVAILABLE = False 30 | # If Numba is not available, we define a dummy 'njit' function. 31 | 32 | def njit(func): 33 | return func 34 | 35 | 36 | @njit 37 | def _isotonic_l2_pav_numba(y): 38 | n = y.shape[0] 39 | target = onp.arange(n) 40 | c = onp.ones(n) 41 | sums = onp.zeros(n) 42 | sol = onp.zeros(n) 43 | 44 | # target describes a list of blocks. At any time, if [i..j] (inclusive) is 45 | # an active block, then target[i] := j and target[j] := i. 46 | 47 | for i in range(n): 48 | sol[i] = y[i] 49 | sums[i] = y[i] 50 | 51 | i = 0 52 | while i < n: 53 | k = target[i] + 1 54 | if k == n: 55 | break 56 | if sol[i] > sol[k]: 57 | i = k 58 | continue 59 | sum_y = sums[i] 60 | sum_c = c[i] 61 | while True: 62 | # We are within an increasing subsequence. 
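      # Pool-adjacent-violators step: accumulate the sums and counts of the
      # blocks in the violating (increasing) run so that, once the run ends,
      # the whole merged block can be assigned their average sum_y / sum_c.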
63 | prev_y = sol[k] 64 | sum_y += sums[k] 65 | sum_c += c[k] 66 | k = target[k] + 1 67 | if k == n or prev_y > sol[k]: 68 | # Non-singleton increasing subsequence is finished, 69 | # update first entry. 70 | sol[i] = sum_y / sum_c 71 | sums[i] = sum_y 72 | c[i] = sum_c 73 | target[i] = k - 1 74 | target[k - 1] = i 75 | if i > 0: 76 | # Backtrack if we can. This makes the algorithm 77 | # single-pass and ensures O(n) complexity. 78 | i = target[i - 1] 79 | # Otherwise, restart from the same point. 80 | break 81 | 82 | # Reconstruct the solution. 83 | i = 0 84 | while i < n: 85 | k = target[i] + 1 86 | sol[i + 1 : k] = sol[i] 87 | i = k 88 | return sol.astype(y.dtype) 89 | 90 | 91 | @jax.custom_jvp 92 | def _isotonic_l2_pav(y): 93 | if not NUMBA_AVAILABLE: 94 | warnings.warn( 95 | "Numba could not be imported. Code will run much more slowly." 96 | " To install, run 'pip install numba'." 97 | ) 98 | # Define the expected shape & dtype of output. 99 | shape_dtype = jax.ShapeDtypeStruct(shape=y.shape, dtype=y.dtype) 100 | sol = jax.pure_callback( 101 | _isotonic_l2_pav_numba, shape_dtype, y, vmap_method="sequential" 102 | ) 103 | return sol 104 | 105 | 106 | def isotonic_l2_pav(y, y_min=-jnp.inf, y_max=jnp.inf, increasing=True): 107 | r"""Solves an isotonic regression problem using PAV. 108 | 109 | Args: 110 | y: input to isotonic regression, a 1d-array. 111 | 112 | y_min : Lower bound on the lowest predicted value. 113 | y_max : Upper bound on the highest predicted value 114 | 115 | increasing : Order of the constraints: 116 | If True, it solves :math:`\mathop{\mathrm{arg\,min}}_{v_1 \leq ... \leq v_n} \|v - y\|^2`. 117 | If False, it solves :math:`\mathop{\mathrm{arg\,min}}_{v_1 \geq ... \geq v_n} \|v - y\|^2`. 118 | 119 | Returns: 120 | The solution, an array of the same size as y. 121 | """ 122 | sign = -1 if increasing else 1 123 | sol = _isotonic_l2_pav(y * sign) * sign 124 | sol = jnp.clip(sol, y_min, y_max) 125 | return sol 126 | 127 | 128 | def _jvp_isotonic_l2_jax_pav(solution, vector, eps=1e-8): 129 | x = solution 130 | mask = jnp.pad(jnp.absolute(jnp.diff(x)) <= eps, (1, 0)) 131 | ar = jnp.arange(x.size) 132 | inds_start = jnp.where(mask == 0, ar, +jnp.inf).sort() 133 | one_hot_start = jax.nn.one_hot(inds_start, len(vector)) 134 | A = jnp.cumsum(one_hot_start, axis=-1) 135 | A = jnp.append(jnp.diff(A[::-1], axis=0)[::-1], A[-1].reshape(1, -1), axis=0) 136 | B = A.copy() 137 | return (((B.T * (B @ vector)).T) / (A.sum(1, keepdims=True) + 1e-8)).sum(0) 138 | 139 | 140 | @_isotonic_l2_pav.defjvp 141 | def _isotonic_l2_pav_jvp(primals, tangents): 142 | """Jacobian-vector product of isotonic_l2_pav. 143 | 144 | See Section 5 of 145 | Fast Differentiable Sorting and Ranking 146 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga 147 | ICML 2020 arXiv:2002.08871 148 | """ 149 | (y, ) = primals 150 | (vector, ) = tangents 151 | primal_out = _isotonic_l2_pav(y) 152 | tangent_out = _jvp_isotonic_l2_jax_pav(primal_out, vector) 153 | return primal_out, tangent_out 154 | -------------------------------------------------------------------------------- /jaxopt/_src/linear_operator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Interface for linear operators.""" 15 | 16 | import functools 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | from jaxopt.tree_util import tree_map 21 | 22 | 23 | class DenseLinearOperator: 24 | """General operator for dense matrices. 25 | 26 | Attributes: 27 | pytree: pytree of dense matrices. 28 | 29 | Each leaf of ``pytree`` must be a 2D matrix. 30 | """ 31 | 32 | def __init__(self, pytree): 33 | self.pytree = pytree 34 | 35 | def __call__(self, x): 36 | return self.matvec(x) 37 | 38 | def matvec(self, x): 39 | return tree_map(jnp.dot, self.pytree, x) 40 | 41 | def rmatvec(self, _, y): 42 | return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y) 43 | 44 | def matvec_and_rmatvec(self, x, y): 45 | return self.matvec(x), self.rmatvec(x, y) 46 | 47 | def normal_matvec(self, x): 48 | """Computes A^T A x.""" 49 | return self.rmatvec(x, self.matvec(x)) 50 | 51 | def diag(self): 52 | diags_only = tree_map(jnp.diag, self.pytree) 53 | return diags_only 54 | 55 | def columns_l2_norms(self, squared=False): 56 | def col_norm(w): 57 | col_norms = jnp.sum(jnp.square(w), axis=0) 58 | if not squared: 59 | col_norms = jnp.sqrt(col_norms) 60 | return col_norms 61 | return tree_map(col_norm, self.pytree) 62 | 63 | 64 | class FunctionalLinearOperator: 65 | 66 | def __init__(self, fun, params): 67 | self.fun = functools.partial(fun, params) 68 | 69 | def __call__(self, x): 70 | return self.matvec(x) 71 | 72 | def matvec(self, x): 73 | return self.fun(x) 74 | 75 | def rmatvec(self, x, y): 76 | return self.matvec_and_rmatvec(x, y)[1] 77 | 78 | def matvec_and_rmatvec(self, x, y): 79 | matvec_x, vjp = jax.vjp(self.matvec, x) 80 | rmatvec_y, = vjp(y) 81 | return matvec_x, rmatvec_y 82 | 83 | def normal_matvec(self, x): 84 | """Computes A^T A x from matvec(x) = A x.""" 85 | matvec_x, vjp = jax.vjp(self.matvec, x) 86 | return vjp(matvec_x)[0] 87 | 88 | 89 | def _make_linear_operator(matvec): 90 | if matvec is None: 91 | return DenseLinearOperator 92 | else: 93 | return functools.partial(FunctionalLinearOperator, matvec) 94 | -------------------------------------------------------------------------------- /jaxopt/_src/linesearch_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Line searches utilities.""" 16 | 17 | from jax import numpy as jnp 18 | from jaxopt._src import base 19 | from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch 20 | from jaxopt._src.hager_zhang_linesearch import HagerZhangLineSearch 21 | from jaxopt._src.zoom_linesearch import ZoomLineSearch 22 | 23 | 24 | def _setup_linesearch( 25 | linesearch, 26 | fun, 27 | value_and_grad, 28 | has_aux, 29 | maxlsiter, 30 | max_stepsize, 31 | jit, 32 | unroll, 33 | verbose, 34 | ): 35 | """Instantiate linesearch.""" 36 | 37 | available_linesearches = ["backtracking", "zoom", "hager-zhang"] 38 | if linesearch == "backtracking": 39 | linesearch_solver = BacktrackingLineSearch( 40 | fun=fun, 41 | value_and_grad=value_and_grad, 42 | has_aux=has_aux, 43 | maxiter=maxlsiter, 44 | max_stepsize=max_stepsize, 45 | jit=jit, 46 | unroll=unroll, 47 | verbose=verbose, 48 | ) 49 | elif linesearch == "zoom": 50 | linesearch_solver = ZoomLineSearch( 51 | fun=fun, 52 | value_and_grad=value_and_grad, 53 | has_aux=has_aux, 54 | maxiter=maxlsiter, 55 | max_stepsize=max_stepsize, 56 | jit=jit, 57 | unroll=unroll, 58 | verbose=verbose, 59 | ) 60 | elif linesearch == "hager-zhang": 61 | # NOTE(vroulet): max_stepsize has no effect in HZ 62 | linesearch_solver = HagerZhangLineSearch( 63 | fun=fun, 64 | value_and_grad=value_and_grad, 65 | has_aux=has_aux, 66 | maxiter=maxlsiter, 67 | jit=jit, 68 | unroll=unroll, 69 | verbose=verbose, 70 | ) 71 | elif isinstance(linesearch, base.IterativeLineSearch): 72 | linesearch_solver = linesearch 73 | else: 74 | raise ValueError( 75 | f"Linesearch {linesearch} not available/tested. " 76 | f"Available linesearches: {available_linesearches}" 77 | ) 78 | return linesearch_solver 79 | 80 | 81 | def _init_stepsize( 82 | strategy, max_stepsize, min_stepsize, increase_factor, stepsize 83 | ): 84 | """Set stepsize at the start of the linesearch from previous guess.""" 85 | available_strategies = ["max", "current", "increase"] 86 | if strategy == "max": 87 | init_stepsize = max_stepsize 88 | elif strategy == "current": 89 | init_stepsize = stepsize 90 | elif strategy == "increase": 91 | init_stepsize = jnp.where( 92 | stepsize <= min_stepsize, 93 | # If stepsize became too small, we restart it. 94 | max_stepsize, 95 | # Else, we increase a bit the previous one. 96 | stepsize * increase_factor, 97 | ) 98 | else: 99 | raise ValueError( 100 | f"Strategy {strategy} not available/tested. " 101 | f"Available linesearches: {available_strategies}" 102 | ) 103 | return init_stepsize 104 | -------------------------------------------------------------------------------- /jaxopt/_src/loop.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Loop utilities.""" 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | 21 | def _while_loop_scan(cond_fun, body_fun, init_val, max_iter): 22 | """Scan-based implementation (jit ok, reverse-mode autodiff ok).""" 23 | def _iter(val): 24 | next_val = body_fun(val) 25 | next_cond = cond_fun(next_val) 26 | return next_val, next_cond 27 | 28 | def _fun(tup, it): 29 | val, cond = tup 30 | # When cond is met, we start doing no-ops. 31 | return jax.lax.cond(cond, _iter, lambda x: (x, False), val), it 32 | 33 | init = (init_val, cond_fun(init_val)) 34 | return jax.lax.scan(_fun, init, None, length=max_iter)[0][0] 35 | 36 | 37 | def _while_loop_python(cond_fun, body_fun, init_val, maxiter): 38 | """Python based implementation (no jit, reverse-mode autodiff ok).""" 39 | val = init_val 40 | for _ in range(maxiter): 41 | cond = cond_fun(val) 42 | if not cond: 43 | # When condition is met, break (not jittable). 44 | break 45 | val = body_fun(val) 46 | return val 47 | 48 | 49 | def _while_loop_lax(cond_fun, body_fun, init_val, maxiter): 50 | """lax.while_loop based implementation (jit by default, no reverse-mode).""" 51 | def _cond_fun(_val): 52 | it, val = _val 53 | return jnp.logical_and(cond_fun(val), it <= maxiter - 1) 54 | 55 | def _body_fun(_val): 56 | it, val = _val 57 | val = body_fun(val) 58 | return it+1, val 59 | 60 | return jax.lax.while_loop(_cond_fun, _body_fun, (0, init_val))[1] 61 | 62 | 63 | def while_loop(cond_fun, body_fun, init_val, maxiter, unroll=False, jit=False): 64 | """A while loop with a bounded number of iterations.""" 65 | 66 | if unroll: 67 | if jit: 68 | fun = _while_loop_scan 69 | else: 70 | fun = _while_loop_python 71 | else: 72 | if jit: 73 | fun = _while_loop_lax 74 | else: 75 | raise ValueError("unroll=False and jit=False cannot be used together") 76 | 77 | if jit and fun is not _while_loop_lax: 78 | # jit of a lax while_loop is redundant, and this jit would only 79 | # constrain maxiter to be static where it is not required. 80 | fun = jax.jit(fun, static_argnums=(0, 1, 3)) 81 | 82 | return fun(cond_fun, body_fun, init_val, maxiter) 83 | -------------------------------------------------------------------------------- /jaxopt/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jaxopt._src.base import AutoOrBoolean 16 | from jaxopt._src.base import IterativeSolver 17 | from jaxopt._src.base import LinearOperator 18 | from jaxopt._src.base import OptStep 19 | from jaxopt._src.base import StochasticSolver 20 | from jaxopt._src.base import KKTSolution 21 | -------------------------------------------------------------------------------- /jaxopt/cond.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.cond import cond 16 | -------------------------------------------------------------------------------- /jaxopt/implicit_diff.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.implicit_diff import custom_root 16 | from jaxopt._src.implicit_diff import custom_fixed_point 17 | from jaxopt._src.implicit_diff import root_jvp 18 | from jaxopt._src.implicit_diff import root_vjp 19 | -------------------------------------------------------------------------------- /jaxopt/isotonic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.isotonic import isotonic_l2_pav -------------------------------------------------------------------------------- /jaxopt/linear_solve.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.linear_solve import solve_lu 16 | from jaxopt._src.linear_solve import solve_cholesky 17 | from jaxopt._src.linear_solve import solve_qr 18 | from jaxopt._src.linear_solve import solve_inv 19 | from jaxopt._src.linear_solve import solve_cg 20 | from jaxopt._src.linear_solve import solve_normal_cg 21 | from jaxopt._src.linear_solve import solve_gmres 22 | from jaxopt._src.linear_solve import solve_bicgstab 23 | from jaxopt._src.iterative_refinement import solve_iterative_refinement 24 | -------------------------------------------------------------------------------- /jaxopt/loop.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.loop import while_loop 16 | -------------------------------------------------------------------------------- /jaxopt/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jaxopt._src.loss import binary_logistic_loss 16 | from jaxopt._src.loss import binary_sparsemax_loss, sparse_plus, sparse_sigmoid 17 | from jaxopt._src.loss import huber_loss 18 | from jaxopt._src.loss import make_fenchel_young_loss 19 | from jaxopt._src.loss import multiclass_logistic_loss 20 | from jaxopt._src.loss import multiclass_sparsemax_loss 21 | from jaxopt._src.loss import binary_hinge_loss 22 | from jaxopt._src.loss import binary_perceptron_loss 23 | from jaxopt._src.loss import multiclass_hinge_loss 24 | from jaxopt._src.loss import multiclass_perceptron_loss 25 | -------------------------------------------------------------------------------- /jaxopt/objective.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.objective import CompositeLinearFunction 16 | 17 | from jaxopt._src.objective import least_squares 18 | from jaxopt._src.objective import ridge_regression 19 | 20 | from jaxopt._src.objective import binary_logreg 21 | 22 | from jaxopt._src.objective import multiclass_logreg 23 | from jaxopt._src.objective import multiclass_logreg_with_intercept 24 | from jaxopt._src.objective import l2_multiclass_logreg 25 | from jaxopt._src.objective import l2_multiclass_logreg_with_intercept 26 | 27 | from jaxopt._src.objective import multiclass_linear_svm_dual 28 | -------------------------------------------------------------------------------- /jaxopt/perturbations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.perturbations import Gumbel 16 | from jaxopt._src.perturbations import Normal 17 | from jaxopt._src.perturbations import make_perturbed_argmax 18 | from jaxopt._src.perturbations import make_perturbed_max 19 | from jaxopt._src.perturbations import make_perturbed_fun 20 | -------------------------------------------------------------------------------- /jaxopt/projection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.projection import projection_non_negative 16 | from jaxopt._src.projection import projection_box 17 | from jaxopt._src.projection import projection_hypercube 18 | from jaxopt._src.projection import projection_simplex 19 | from jaxopt._src.projection import projection_sparse_simplex 20 | from jaxopt._src.projection import projection_l1_sphere 21 | from jaxopt._src.projection import projection_l1_ball 22 | from jaxopt._src.projection import projection_l2_sphere 23 | from jaxopt._src.projection import projection_l2_ball 24 | from jaxopt._src.projection import projection_linf_ball 25 | from jaxopt._src.projection import projection_hyperplane 26 | from jaxopt._src.projection import projection_halfspace 27 | from jaxopt._src.projection import projection_affine_set 28 | from jaxopt._src.projection import projection_polyhedron 29 | from jaxopt._src.projection import projection_box_section 30 | from jaxopt._src.projection import projection_transport 31 | from jaxopt._src.projection import projection_birkhoff 32 | from jaxopt._src.projection import kl_projection_transport 33 | from jaxopt._src.projection import kl_projection_birkhoff 34 | -------------------------------------------------------------------------------- /jaxopt/prox.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.prox import make_prox_from_projection 16 | from jaxopt._src.prox import prox_none 17 | from jaxopt._src.prox import prox_lasso 18 | from jaxopt._src.prox import prox_non_negative_lasso 19 | from jaxopt._src.prox import prox_elastic_net 20 | from jaxopt._src.prox import prox_group_lasso 21 | from jaxopt._src.prox import prox_ridge 22 | from jaxopt._src.prox import prox_non_negative_ridge 23 | -------------------------------------------------------------------------------- /jaxopt/tree_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jaxopt._src.tree_util import broadcast_pytrees 16 | from jaxopt._src.tree_util import tree_map 17 | from jaxopt._src.tree_util import tree_reduce 18 | from jaxopt._src.tree_util import tree_add 19 | from jaxopt._src.tree_util import tree_sub 20 | from jaxopt._src.tree_util import tree_mul 21 | from jaxopt._src.tree_util import tree_scalar_mul 22 | from jaxopt._src.tree_util import tree_add_scalar_mul 23 | from jaxopt._src.tree_util import tree_dot 24 | from jaxopt._src.tree_util import tree_vdot 25 | from jaxopt._src.tree_util import tree_vdot_real 26 | from jaxopt._src.tree_util import tree_div 27 | from jaxopt._src.tree_util import tree_sum 28 | from jaxopt._src.tree_util import tree_l2_norm 29 | from jaxopt._src.tree_util import tree_where 30 | from jaxopt._src.tree_util import tree_zeros_like 31 | from jaxopt._src.tree_util import tree_ones_like 32 | from jaxopt._src.tree_util import tree_negative 33 | from jaxopt._src.tree_util import tree_inf_norm 34 | from jaxopt._src.tree_util import tree_conj 35 | from jaxopt._src.tree_util import tree_real 36 | from jaxopt._src.tree_util import tree_imag -------------------------------------------------------------------------------- /jaxopt/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """JAXopt version.""" 16 | 17 | __version__ = "0.8.5" 18 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | # This pylintrc file is taken from JAX 2 | # https://github.com/google/jax/blob/main/pylintrc 3 | [MASTER] 4 | 5 | # A comma-separated list of package or module names from where C extensions may 6 | # be loaded. Extensions are loading into the active Python interpreter and may 7 | # run arbitrary code 8 | extension-pkg-whitelist=numpy 9 | 10 | 11 | [MESSAGES CONTROL] 12 | 13 | # Disable the message, report, category or checker with the given id(s). You 14 | # can either give multiple identifiers separated by comma (,) or put this 15 | # option multiple times (only on the command line, not in the configuration 16 | # file where it should appear only once).You can also use "--disable=all" to 17 | # disable everything first and then reenable specific checks. For example, if 18 | # you want to run only the similarities checker, you can use "--disable=all 19 | # --enable=similarities". 
If you want to run only the classes checker, but have 20 | # no Warning level messages displayed, use"--disable=all --enable=classes 21 | # --disable=W" 22 | disable=missing-docstring, 23 | too-many-locals, 24 | invalid-name, 25 | redefined-outer-name, 26 | redefined-builtin, 27 | protected-name, 28 | no-else-return, 29 | fixme, 30 | protected-access, 31 | too-many-arguments, 32 | blacklisted-name, 33 | too-few-public-methods, 34 | unnecessary-lambda, 35 | 36 | 37 | # Enable the message, report, category or checker with the given id(s). You can 38 | # either give multiple identifier separated by comma (,) or put this option 39 | # multiple time (only on the command line, not in the configuration file where 40 | # it should appear only once). See also the "--disable" option for examples. 41 | enable=c-extension-no-member 42 | 43 | 44 | [FORMAT] 45 | 46 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 47 | # tab). 48 | indent-string=" " 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax>=0.2.18 2 | jaxlib>=0.1.69 3 | numpy>=1.18.4 4 | scipy>=1.0.0 5 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.7.0 2 | cvxpy>=1.1.11 3 | optax>=0.0.9 4 | pytest-xdist 5 | scikit-learn>=0.24.1 6 | cvxopt 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Setup script for JAXopt.""" 16 | 17 | import os 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | 22 | folder = os.path.dirname(__file__) 23 | version_path = os.path.join(folder, "jaxopt", "version.py") 24 | 25 | __version__ = None 26 | with open(version_path) as f: 27 | exec(f.read(), globals()) 28 | 29 | req_path = os.path.join(folder, "requirements.txt") 30 | install_requires = [] 31 | if os.path.exists(req_path): 32 | with open(req_path) as fp: 33 | install_requires = [line.strip() for line in fp] 34 | 35 | readme_path = os.path.join(folder, "README.md") 36 | readme_contents = "" 37 | if os.path.exists(readme_path): 38 | with open(readme_path) as fp: 39 | readme_contents = fp.read().strip() 40 | 41 | setup( 42 | name="jaxopt", 43 | version=__version__, 44 | description="Hardware accelerated, batchable and differentiable optimizers in JAX.", 45 | author="Google LLC", 46 | author_email="no-reply@google.com", 47 | url="https://github.com/google/jaxopt", 48 | long_description=readme_contents, 49 | long_description_content_type="text/markdown", 50 | license="Apache 2.0", 51 | packages=find_packages(), 52 | package_data={}, 53 | install_requires=install_requires, 54 | classifiers=[ 55 | "Intended Audience :: Science/Research", 56 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 57 | "License :: OSI Approved :: Apache Software License", 58 | "Programming Language :: Python :: 3", 59 | "Programming Language :: Python :: 3.10", 60 | "Programming Language :: Python :: 3.11", 61 | "Programming Language :: Python :: 3.12", 62 | "Programming Language :: Python :: 3.13", 63 | ], 64 | keywords="optimization, root finding, implicit differentiation, jax", 65 | requires_python=">=3.10", 66 | ) 67 | -------------------------------------------------------------------------------- /tests/anderson_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from absl.testing import absltest 17 | 18 | import jax.numpy as jnp 19 | from jax.test_util import check_grads 20 | import optax 21 | 22 | from jaxopt import objective 23 | 24 | from jaxopt import prox 25 | from jaxopt._src import test_util 26 | 27 | from jaxopt import AndersonWrapper 28 | from jaxopt import BlockCoordinateDescent 29 | from jaxopt import OptaxSolver 30 | from jaxopt import PolyakSGD 31 | from jaxopt import ProximalGradient 32 | 33 | from sklearn import datasets 34 | 35 | 36 | class AndersonWrapperTest(test_util.JaxoptTestCase): 37 | 38 | def test_proximal_gradient_wrapper(self): 39 | """Baseline test on simple optimizer.""" 40 | X, y = datasets.make_regression(n_samples=100, n_features=20, random_state=0) 41 | fun = objective.least_squares 42 | lam = 10.0 43 | data = (X, y) 44 | w_init = jnp.zeros(X.shape[1]) 45 | tol = 1e-3 46 | maxiter = 1000 47 | pg = ProximalGradient(fun=fun, prox=prox.prox_lasso, maxiter=maxiter, tol=tol, 48 | acceleration=False) 49 | aw = AndersonWrapper(pg, history_size=15) 50 | aw_params, awpg_info = aw.run(w_init, hyperparams_prox=lam, data=data) 51 | self.assertLess(awpg_info.error, tol) 52 | 53 | def test_mixing_frequency_polyak(self): 54 | """Test mixing_frequency by accelerating PolyakSGD.""" 55 | X, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3, 56 | n_informative=3, random_state=0) 57 | data = (X, y) 58 | l2reg = 100.0 59 | # fun(params, data) 60 | fun = objective.l2_multiclass_logreg_with_intercept 61 | n_classes = len(jnp.unique(y)) 62 | 63 | W_init = jnp.zeros((X.shape[1], n_classes)) 64 | b_init = jnp.zeros(n_classes) 65 | pytree_init = (W_init, b_init) 66 | 67 | opt = PolyakSGD(fun=fun, max_stepsize=0.01, tol=0.05, momentum=False) 68 | history_size = 5 69 | aw = AndersonWrapper(opt, history_size=history_size, mixing_frequency=1) 70 | aw_params, aw_state = aw.run(pytree_init, l2reg=l2reg, data=data) 71 | self.assertLess(aw_state.error, 0.05) 72 | 73 | def test_optax_restart(self): 74 | """Test Optax optimizer.""" 75 | X, y = datasets.make_classification(n_samples=100, n_features=20, n_classes=3, 76 | n_informative=3, random_state=0) 77 | data = (X, y) 78 | l2reg = 100.0 79 | # fun(params, data) 80 | fun = objective.l2_multiclass_logreg_with_intercept 81 | n_classes = len(jnp.unique(y)) 82 | 83 | W_init = jnp.zeros((X.shape[1], n_classes)) 84 | b_init = jnp.zeros(n_classes) 85 | pytree_init = (W_init, b_init) 86 | 87 | tol = 1e-2 88 | opt = OptaxSolver(opt=optax.sgd(1e-2, momentum=0.8), fun=fun, maxiter=1000, tol=0) 89 | aw = AndersonWrapper(opt, history_size=3, ridge=1e-3) 90 | params, infos = aw.run(pytree_init, l2reg=l2reg, data=data) 91 | 92 | # Check optimality conditions. 93 | error = opt.l2_optimality_error(params, l2reg=l2reg, data=data) 94 | self.assertLessEqual(error, tol) 95 | 96 | def test_block_cd_restart(self): 97 | """Accelerate Block CD.""" 98 | X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) 99 | 100 | # Setup parameters. 101 | fun = objective.least_squares # fun(params, data) 102 | l1reg = 10.0 103 | data = (X, y) 104 | 105 | # Initialize. 106 | w_init = jnp.zeros(X.shape[1]) 107 | tol = 5e-4 108 | maxiter = 100 109 | bcd = BlockCoordinateDescent(fun=fun, block_prox=prox.prox_lasso, tol=tol, maxiter=maxiter) 110 | aw = AndersonWrapper(bcd, history_size=3, ridge=1e-5) 111 | params, state = aw.run(init_params=w_init, hyperparams_prox=l1reg, data=data) 112 | 113 | # Check optimality conditions. 
114 | self.assertLess(state.error, tol) 115 | 116 | def test_wrapper_grad(self): 117 | """Test gradient of wrapper.""" 118 | data_train = datasets.make_regression(n_samples=100, n_features=3, random_state=0) 119 | fun = objective.least_squares 120 | lam = 10.0 121 | w_init = jnp.zeros(data_train[0].shape[1]) 122 | tol = 1e-5 123 | maxiter = 1000 # large number of updates 124 | pg = ProximalGradient(fun=fun, prox=prox.prox_lasso, maxiter=maxiter, tol=tol, 125 | acceleration=False) 126 | aw = AndersonWrapper(pg, history_size=5) 127 | data_val = datasets.make_regression(n_samples=100, n_features=3, random_state=0) 128 | 129 | def solve_run(lam): 130 | aw_params = aw.run(w_init, lam, data_train).params 131 | loss = fun(aw_params, data=data_val) 132 | return loss 133 | 134 | check_grads(solve_run, args=(lam,), order=1, modes=['rev'], eps=2e-2) 135 | 136 | def solve_run(lam): 137 | aw_params = aw.run(w_init, hyperparams_prox=lam, data=data_train).params 138 | loss = fun(aw_params, data=data_val) 139 | return loss 140 | 141 | check_grads(solve_run, args=(lam,), order=1, modes=['rev'], eps=2e-2) 142 | 143 | if __name__ == '__main__': 144 | # Uncomment the line below in order to run in float64. 145 | # config.update("jax_enable_x64", True) 146 | absltest.main() 147 | -------------------------------------------------------------------------------- /tests/bisection_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | from jaxopt import projection 21 | from jaxopt import Bisection 22 | from jaxopt._src import test_util 23 | 24 | import numpy as onp 25 | 26 | 27 | # optimality_fun(params, hyperparams, data) 28 | def _optimality_fun_proj_simplex(tau, x, s): 29 | # optimality_fun(tau, x, s) is a decreasing function of tau on 30 | # [lower, upper] since the derivative w.r.t. tau is negative. 
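# Concretely, d/dtau of sum_i max(x_i - tau, 0) equals -(number of i with x_i > tau) <= 0,
# so the optimality function is nonincreasing in tau over the bracket.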
31 | return jnp.sum(jnp.maximum(x - tau, 0)) - s 32 | 33 | 34 | def _threshold_proj_simplex(bisect, x, s=1.0): 35 | return bisect.run(None, x, s).params 36 | 37 | 38 | def _projection_simplex_bisect(bisect, x, s=1.0): 39 | return jnp.maximum(x - _threshold_proj_simplex(bisect, x, s), 0) 40 | 41 | 42 | def _projection_simplex_bisect_setup(x, s=1.0): 43 | # tau = max(x) => tau >= x_i for all i 44 | # => x_i - tau <= 0 for all i 45 | # => maximum(x_i - tau, 0) = 0 for all i 46 | # => optimality_fun(tau, x, s) = -s <= 0 47 | upper = jax.lax.stop_gradient(jnp.max(x)) 48 | 49 | # tau' = min(x) => tau' <= x_i for all i 50 | # => 0 <= x_i - tau' for all_i 51 | # => maximum(x_i - tau', 0) >= 0 52 | # => optimality_fun(tau, x, s) >= 0 53 | # where tau = tau' - s / len(x) 54 | lower = jax.lax.stop_gradient(jnp.min(x)) - s / len(x) 55 | 56 | return Bisection(optimality_fun=_optimality_fun_proj_simplex, 57 | lower=lower, upper=upper, check_bracket=False) 58 | 59 | 60 | class BisectionTest(test_util.JaxoptTestCase): 61 | 62 | def test_bisect(self): 63 | rng = onp.random.RandomState(0) 64 | 65 | _projection_simplex_bisect_jitted = jax.jit( 66 | _projection_simplex_bisect, static_argnums=0) 67 | 68 | for _ in range(10): 69 | x = jnp.array(rng.randn(50).astype(onp.float32)) 70 | bisect = _projection_simplex_bisect_setup(x) 71 | p = projection.projection_simplex(x) 72 | p2 = _projection_simplex_bisect(bisect, x) 73 | p3 = _projection_simplex_bisect_jitted(bisect, x) 74 | self.assertArraysAllClose(p, p2, atol=1e-4) 75 | self.assertArraysAllClose(p, p3, atol=1e-4) 76 | 77 | J = jax.jacrev(projection.projection_simplex)(x) 78 | J2 = jax.jacrev(_projection_simplex_bisect, argnums=1)(bisect, x) 79 | J3 = jax.jacrev(_projection_simplex_bisect_jitted, argnums=1)(bisect, x) 80 | self.assertArraysAllClose(J, J2, atol=1e-5) 81 | self.assertArraysAllClose(J, J3, atol=1e-5) 82 | 83 | def test_bisect_wrong_lower_bracket(self): 84 | rng = onp.random.RandomState(0) 85 | x = jnp.array(rng.randn(5).astype(onp.float32)) 86 | s = 1.0 87 | upper = jnp.max(x) 88 | bisect = Bisection(optimality_fun=_optimality_fun_proj_simplex, 89 | lower=upper, upper=upper) 90 | self.assertRaises(ValueError, bisect.run, None, x, s) 91 | 92 | def test_bisect_wrong_upper_bracket(self): 93 | rng = onp.random.RandomState(0) 94 | x = jnp.array(rng.randn(5).astype(onp.float32)) 95 | s = 1.0 96 | lower = jnp.min(x) - s / len(x) 97 | bisect = Bisection(optimality_fun=_optimality_fun_proj_simplex, 98 | lower=lower, upper=lower) 99 | self.assertRaises(ValueError, bisect.run, None, x, s) 100 | 101 | def test_grad_of_value_and_grad(self): 102 | # See https://github.com/google/jaxopt/issues/141 103 | 104 | def bisect(x): 105 | b = _projection_simplex_bisect_setup(x) 106 | return _projection_simplex_bisect(b, x)[0] 107 | 108 | def bisect_val(x): 109 | val, _ = jax.value_and_grad(bisect)(x) 110 | return val 111 | 112 | rng = onp.random.RandomState(0) 113 | x = jnp.array(rng.randn(5).astype(onp.float32)) 114 | g1 = jax.grad(bisect)(x) 115 | g2 = jax.grad(bisect_val)(x) 116 | self.assertArraysAllClose(g1, g2) 117 | 118 | def test_edge(self): 119 | def F(x): 120 | return x - 0.5 121 | 122 | bisec = Bisection(optimality_fun=F, lower=0.0, upper=1.0) 123 | # The solution is found on the first iteration. 
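# With lower=0.0 and upper=1.0, the first midpoint is 0.5, which is exactly the root of F.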
124 | self.assertEqual(bisec.run().params, 0.5) 125 | 126 | 127 | if __name__ == '__main__': 128 | absltest.main() 129 | -------------------------------------------------------------------------------- /tests/cd_qp_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | from jaxopt import BoxCDQP 21 | from jaxopt._src import test_util 22 | 23 | import numpy as onp 24 | 25 | 26 | def _cd_qp(Q, c, l, u, tol, maxiter, verbose=0): 27 | """Pure NumPy implementation for test purposes.""" 28 | x = onp.zeros(Q.shape[0]) 29 | 30 | for it in range(maxiter): 31 | error = 0 32 | 33 | for i in range(len(x)): 34 | g_i = onp.dot(Q[i], x) + c[i] 35 | h_i = Q[i, i] 36 | 37 | if h_i == 0: 38 | continue 39 | 40 | x_i_new = onp.clip(x[i] - g_i / h_i, l[i], u[i]) 41 | delta_i = x_i_new - x[i] 42 | error += onp.abs(delta_i) 43 | x[i] = x_i_new 44 | 45 | if verbose: 46 | print(it + 1, error) 47 | 48 | if error <= tol: 49 | break 50 | 51 | return x 52 | 53 | 54 | class CD_QP_Test(test_util.JaxoptTestCase): 55 | 56 | def setUp(self): 57 | rng = onp.random.RandomState(0) 58 | num_dim = 5 59 | M = rng.randn(num_dim, num_dim) 60 | self.Q = onp.dot(M, M.T) 61 | self.c = rng.randn(num_dim) 62 | self.l = rng.randn(num_dim) 63 | self.u = self.l + 5 * rng.rand(num_dim) 64 | self.params_obj = (self.Q, self.c) 65 | self.params_ineq = (self.l, self.u) 66 | 67 | def test_forward(self): 68 | sol_numpy = _cd_qp(self.Q, self.c, self.l, self.u, 69 | tol=1e-3, maxiter=100, verbose=0) 70 | 71 | # Manual loop 72 | params = jnp.zeros_like(sol_numpy) 73 | 74 | cdqp = BoxCDQP() 75 | state = cdqp.init_state(params) 76 | 77 | for _ in range(5): 78 | params, state = cdqp.update(params, state, params_obj=self.params_obj, 79 | params_ineq=self.params_ineq) 80 | 81 | self.assertAllClose(state.error, 0.0) 82 | self.assertAllClose(params, sol_numpy) 83 | 84 | # Run call. 85 | params = jnp.zeros_like(sol_numpy) 86 | params, state = cdqp.run(params, params_obj=self.params_obj, 87 | params_ineq=self.params_ineq) 88 | self.assertAllClose(state.error, 0.0) 89 | self.assertAllClose(params, sol_numpy) 90 | 91 | def test_backward(self): 92 | cdqp = BoxCDQP(implicit_diff=True) 93 | cdqp2 = BoxCDQP(implicit_diff=False) 94 | init_params = jnp.zeros(self.Q.shape[0]) 95 | 96 | def wrapper(c): 97 | params_obj = (self.Q, c) 98 | return cdqp.run(init_params, params_obj=params_obj, 99 | params_ineq=self.params_ineq).params 100 | 101 | def wrapper2(c): 102 | params_obj = (self.Q, c) 103 | return cdqp2.run(init_params, params_obj=params_obj, 104 | params_ineq=self.params_ineq).params 105 | 106 | J = jax.jacobian(wrapper)(self.c) 107 | J2 = jax.jacobian(wrapper2)(self.c) 108 | self.assertAllClose(J, J2) 109 | 110 | 111 | if __name__ == '__main__': 112 | # Uncomment the line below in order to run in float64. 
113 | #jax.config.update("jax_enable_x64", True) 114 | absltest.main() 115 | -------------------------------------------------------------------------------- /tests/cond_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as onp 21 | 22 | from jaxopt._src.cond import cond 23 | from jaxopt._src import test_util 24 | 25 | 26 | class CondTest(test_util.JaxoptTestCase): 27 | 28 | @parameterized.product(jit=[False, True]) 29 | def test_cond(self, jit): 30 | def true_fun(x): 31 | return x 32 | def false_fun(x): 33 | return jnp.zeros_like(x) 34 | 35 | def my_relu(x): 36 | return cond(jnp.sum(x)>0, true_fun, false_fun, x, jit=jit) 37 | 38 | if jit: 39 | x = onp.array([1.]) 40 | else: 41 | x = jnp.array([1.]) 42 | self.assertEqual(jax.nn.relu(x), my_relu(x)) 43 | 44 | if __name__ == '__main__': 45 | absltest.main() 46 | -------------------------------------------------------------------------------- /tests/cvxpy_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """CVXPY tests.""" 16 | 17 | from absl.testing import absltest 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as onp 22 | 23 | from jaxopt import projection 24 | from jaxopt import CvxpyQP 25 | from jaxopt._src import test_util 26 | 27 | 28 | class CvxpyQPTest(test_util.JaxoptTestCase): 29 | 30 | def _check_derivative_Q_c_A_b(self, solver, params, Q, c, A, b): 31 | def fun(Q, c, A, b): 32 | try: 33 | params_ineq = params["params_ineq"] 34 | except KeyError: 35 | params_ineq = None 36 | 37 | Q = 0.5 * (Q + Q.T) 38 | 39 | hyperparams = dict(params_obj=(Q, c), 40 | params_eq=(A, b), 41 | params_ineq=params_ineq) 42 | 43 | # reduce the primal variables to a scalar value for test purpose. 44 | return jnp.sum(solver.run(None, **hyperparams).params[0]) 45 | 46 | # Derivative w.r.t. A. 
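# Each block below compares jax.grad taken through the QP solver against a central
# finite difference computed along a random direction of unit norm.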
47 | rng = onp.random.RandomState(0) 48 | V = rng.rand(*A.shape) 49 | V /= onp.sqrt(onp.sum(V ** 2)) 50 | eps = 1e-4 51 | deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b)) 52 | deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps) 53 | self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) 54 | 55 | # Derivative w.r.t. b. 56 | v = rng.rand(*b.shape) 57 | v /= onp.sqrt(onp.sum(v ** 2)) 58 | eps = 1e-4 59 | deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b)) 60 | deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps) 61 | self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) 62 | 63 | # Derivative w.r.t. Q 64 | W = rng.rand(*Q.shape) 65 | W /= onp.sqrt(onp.sum(W ** 2)) 66 | eps = 1e-4 67 | deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b)) 68 | deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps) 69 | self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) 70 | 71 | # Derivative w.r.t. c 72 | w = rng.rand(*c.shape) 73 | w /= onp.sqrt(onp.sum(w ** 2)) 74 | eps = 1e-4 75 | deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b)) 76 | deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps) 77 | self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) 78 | 79 | def test_qp_eq_and_ineq(self): 80 | Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) 81 | c = jnp.array([1.0, 1.0]) 82 | A = jnp.array([[1.0, 1.0]]) 83 | b = jnp.array([1.0]) 84 | G = jnp.array([[-1.0, 0.0], [0.0, -1.0]]) 85 | h = jnp.array([0.0, 0.0]) 86 | qp = CvxpyQP() 87 | hyperparams = dict(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)) 88 | sol = qp.run(None, **hyperparams).params 89 | self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0, atol=1e-4) 90 | self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b) 91 | 92 | def test_projection_simplex(self): 93 | def _projection_simplex_qp(x, s=1.0): 94 | Q = jnp.eye(len(x)) 95 | A = jnp.array([jnp.ones_like(x)]) 96 | b = jnp.array([s]) 97 | G = -jnp.eye(len(x)) 98 | h = jnp.zeros_like(x) 99 | hyperparams = dict(params_obj=(Q, -x), params_eq=(A, b), 100 | params_ineq=(G, h)) 101 | 102 | qp = CvxpyQP() 103 | # Returns the primal solution only. 104 | return qp.run(None, **hyperparams).params[0] 105 | 106 | rng = onp.random.RandomState(0) 107 | x = jnp.array(rng.randn(10).astype(onp.float32)) 108 | p = projection.projection_simplex(x) 109 | p2 = _projection_simplex_qp(x) 110 | self.assertArraysAllClose(p, p2) 111 | J = jax.jacrev(projection.projection_simplex)(x) 112 | J2 = jax.jacrev(_projection_simplex_qp)(x) 113 | self.assertArraysAllClose(J, J2, atol=1e-5) 114 | 115 | 116 | if __name__ == '__main__': 117 | absltest.main() 118 | -------------------------------------------------------------------------------- /tests/gauss_newton_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
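The tests below fit two nonlinear models with GaussNewton. A smaller sketch of the same residual-function API, with an exponential-decay model and data invented purely for illustration:

import jax.numpy as jnp
from jaxopt import GaussNewton

def residual_fun(params, x, y):
  # Residuals of the model y ~ params[0] * exp(-params[1] * x).
  return y - params[0] * jnp.exp(-params[1] * x)

x = jnp.linspace(0.0, 4.0, 20)
y = 2.0 * jnp.exp(-0.7 * x)

gn = GaussNewton(residual_fun=residual_fun, maxiter=30, tol=1e-6)
params, state = gn.run(jnp.array([1.0, 1.0]), x, y)  # Extra positional args are passed to residual_fun.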
14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | from jaxopt import GaussNewton 22 | from jaxopt._src import test_util 23 | 24 | import numpy as onp 25 | 26 | 27 | def _enzyme_reaction_residual_model(coeffs, x, y): 28 | return y - coeffs[0] * x / (coeffs[1] + x) 29 | 30 | 31 | def _enzyme_reaction_residual_model_jac(coeffs, x, y, eps=1e-5): 32 | """Return the numerical Jacobian.""" 33 | gn = GaussNewton( 34 | residual_fun=_enzyme_reaction_residual_model, 35 | maxiter=100, 36 | tol=1.0e-6) 37 | 38 | # Sets eps only at idx, the rest is zero 39 | eps_at = lambda idx: onp.array([int(i == idx)*eps for i in range(len(x))]) 40 | 41 | res1 = jnp.zeros((len(coeffs), len(x))) 42 | res2 = jnp.zeros((len(coeffs), len(x))) 43 | for i in range(len(x)): 44 | res1 = res1.at[:,i].set(gn.run(coeffs, x + eps_at(i), y).params) 45 | res2 = res2.at[:,i].set(gn.run(coeffs, x - eps_at(i), y).params) 46 | 47 | twoeps = 2 * eps 48 | return (res1 - res2) / twoeps 49 | 50 | 51 | def _city_temperature_residual_model(coeffs, x, y): 52 | return y - (coeffs[0] * jnp.sin(x * coeffs[1] + coeffs[2]) + coeffs[3]) 53 | 54 | 55 | class GaussNewtonTest(test_util.JaxoptTestCase): 56 | 57 | def setUp(self): 58 | super().setUp() 59 | 60 | self.substrate_conc = onp.array( 61 | [0.038, 0.194, .425, .626, 1.253, 2.500, 3.740]) 62 | self.rate_data = onp.array( 63 | [0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317]) 64 | self.init_enzyme_reaction_coeffs = onp.array([0.1, 0.1]) 65 | 66 | self.months = onp.arange(1, 13) 67 | self.temperature_record = onp.array([ 68 | 61.0, 65.0, 72.0, 78.0, 85.0, 90.0, 92.0, 92.0, 88.0, 81.0, 72.0, 63.0 69 | ]) 70 | self.init_temperature_record_coeffs = onp.array([10, 0.5, 10.5, 50]) 71 | 72 | def test_aux_true(self): 73 | gn = GaussNewton(lambda x: (x**2, True), has_aux=True, maxiter=2) 74 | x_init = jnp.arange(2.) 75 | _, state = gn.run(x_init) 76 | self.assertEqual(state.aux, True) 77 | 78 | # Example taken from "Probability, Statistics and Estimation" by Mathieu ROUAUD. 79 | # The algorithm is detailed and applied to the biology experiment discussed in 80 | # page 84 with the uncertainties on the estimated values. 81 | def test_enzyme_reaction_parameter_fit(self): 82 | gn = GaussNewton( 83 | residual_fun=_enzyme_reaction_residual_model, 84 | maxiter=100, 85 | tol=1.0e-6) 86 | optimize_info = gn.run( 87 | self.init_enzyme_reaction_coeffs, 88 | self.substrate_conc, 89 | self.rate_data) 90 | 91 | self.assertArraysAllClose(optimize_info.params, 92 | onp.array([0.36183689, 0.55626653]), 93 | rtol=1e-7, atol=1e-7) 94 | 95 | @parameterized.product(implicit_diff=[True, False]) 96 | def test_enzyme_reaction_implicit_diff(self, implicit_diff): 97 | jac_num = _enzyme_reaction_residual_model_jac( 98 | self.init_enzyme_reaction_coeffs, self.substrate_conc, self.rate_data) 99 | 100 | gn = GaussNewton( 101 | residual_fun=_enzyme_reaction_residual_model, 102 | tol=1.0e-6, 103 | maxiter=10, 104 | implicit_diff=implicit_diff) 105 | 106 | def wrapper(substrate_conc): 107 | return gn.run( 108 | self.init_enzyme_reaction_coeffs, 109 | substrate_conc, 110 | self.rate_data).params 111 | jac_custom = jax.jacrev(wrapper)(self.substrate_conc) 112 | 113 | self.assertArraysAllClose(jac_num, jac_custom, atol=1e-2) 114 | 115 | # Example 7 from "SOLVING NONLINEAR LEAST-SQUARES PROBLEMS WITH THE 116 | # GAUSS-NEWTON AND LEVENBERG-MARQUARDT METHODS" by ALFONSO CROEZE et al. 
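# Fits y ~ c0 * sin(c1 * x + c2) + c3 to the monthly temperature record defined in setUp.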
117 | def test_temperature_record_four_parameter_fit(self): 118 | gn = GaussNewton( 119 | residual_fun=_city_temperature_residual_model, 120 | tol=1.0e-6) 121 | optimize_info = gn.run( 122 | self.init_temperature_record_coeffs, 123 | self.months, 124 | self.temperature_record) 125 | 126 | # Checking against the expected values 127 | self.assertArraysAllClose( 128 | optimize_info.params, 129 | onp.array([16.63994555, 0.46327812, 10.85228919, 76.19086103]), 130 | rtol=1e-6, atol=1e-5) 131 | 132 | def test_scalar_output_fun(self): 133 | gn = GaussNewton( 134 | residual_fun=lambda x: x @ x, 135 | tol=1e-1,) 136 | x_init = jnp.ones((2,)) 137 | x_opt, _ = gn.run(x_init) 138 | 139 | self.assertAllClose(x_opt, jnp.zeros((2,)), atol=1e0) 140 | 141 | 142 | if __name__ == '__main__': 143 | absltest.main() 144 | -------------------------------------------------------------------------------- /tests/hager_zhang_linesearch_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | from jaxopt import HagerZhangLineSearch 22 | from jaxopt import objective 23 | from jaxopt._src import test_util 24 | from jaxopt.tree_util import tree_scalar_mul 25 | from jaxopt.tree_util import tree_vdot 26 | 27 | import numpy as onp 28 | 29 | from sklearn import datasets 30 | 31 | 32 | class HagerZhangLinesearchTest(test_util.JaxoptTestCase): 33 | 34 | def _check_conditions_satisfied( 35 | self, 36 | c1, 37 | c2, 38 | stepsize, 39 | initial_value, 40 | initial_grad, 41 | final_state): 42 | self.assertTrue(jnp.all(final_state.done)) 43 | self.assertFalse(jnp.any(final_state.failed)) 44 | 45 | descent_direction = tree_scalar_mul(-1, initial_grad) 46 | sufficient_decrease = jnp.all( 47 | final_state.value <= initial_value + 48 | c1 * stepsize * tree_vdot(final_state.grad, descent_direction)) 49 | self.assertTrue(sufficient_decrease) 50 | 51 | new_gd_vdot = tree_vdot(final_state.grad, descent_direction) 52 | gd_vdot = tree_vdot(initial_grad, descent_direction) 53 | curvature = jnp.all(new_gd_vdot >= c2 * gd_vdot) 54 | self.assertTrue(curvature) 55 | 56 | def test_hager_zhang_linesearch(self): 57 | x, y = datasets.make_classification( 58 | n_samples=10, n_features=5, n_classes=2, 59 | n_informative=3, random_state=0) 60 | data = (x, y) 61 | fun = objective.binary_logreg 62 | 63 | rng = onp.random.RandomState(0) 64 | w_init = rng.randn(x.shape[1]) 65 | initial_grad = jax.grad(fun)(w_init, data=data) 66 | initial_value = fun(w_init, data=data) 67 | 68 | # Manual loop. 
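# (The line search is first driven by hand through init_state/update; the run() API is exercised next.)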
69 | ls = HagerZhangLineSearch(fun=fun) 70 | stepsize = 1.0 71 | state = ls.init_state( 72 | init_stepsize=1.0, params=w_init, fun_kwargs={"data": data} 73 | ) 74 | stepsize, state = ls.update(stepsize=stepsize, state=state, params=w_init, 75 | fun_kwargs={"data": data}) 76 | 77 | # Call to run. 78 | ls = HagerZhangLineSearch(fun=fun, maxiter=20) 79 | stepsize, state = ls.run( 80 | init_stepsize=1.0, params=w_init, fun_kwargs={"data": data} 81 | ) 82 | self._check_conditions_satisfied( 83 | ls.c1, ls.c2, stepsize, initial_value, initial_grad, state) 84 | 85 | # Call to run with value_and_grad=True. 86 | ls = HagerZhangLineSearch(fun=jax.value_and_grad(fun), 87 | maxiter=20, 88 | value_and_grad=True) 89 | stepsize, state = ls.run( 90 | init_stepsize=1.0, params=w_init, fun_kwargs={"data": data} 91 | ) 92 | self._check_conditions_satisfied( 93 | ls.c1, ls.c2, stepsize, initial_value, initial_grad, state) 94 | 95 | # Failed linesearch (high c1 ensures convergence condition is not met). 96 | ls = HagerZhangLineSearch(fun=fun, maxiter=20, c1=2.) 97 | _, state = ls.run( 98 | init_stepsize=1.0, params=w_init, fun_kwargs={"data": data} 99 | ) 100 | self.assertTrue(jnp.all(state.failed)) 101 | self.assertFalse(jnp.any(state.done)) 102 | 103 | @parameterized.product(val=[onp.inf, onp.nan]) 104 | def test_hager_zhang_linesearch_non_finite(self, val): 105 | 106 | def fun(x): 107 | result = jnp.where(x > 4., val, (x - 2)**2) 108 | grad = jnp.where(x > 4., onp.nan, 2 * (x - 2.)) 109 | return result, grad 110 | x_init = -0.001 111 | 112 | ls = HagerZhangLineSearch(fun=fun, value_and_grad=True, jit=False) 113 | stepsize = 1.25 114 | state = ls.init_state(init_stepsize=1.25, params=x_init) 115 | 116 | stepsize, state = ls.update(stepsize=stepsize, state=state, params=x_init) 117 | # Should work around the Nan/Inf regions and provide a reasonable step size. 118 | self.assertTrue(state.done) 119 | 120 | 121 | if __name__ == '__main__': 122 | # Uncomment the line below in order to run in float64. 123 | # jax.config.update("jax_enable_x64", True) 124 | absltest.main() 125 | -------------------------------------------------------------------------------- /tests/import_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
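The import tests below check that each public symbol is reachable both as an attribute of its jaxopt submodule and through a direct from-import. For instance, both access paths here should resolve to the same function:

import jaxopt
from jaxopt.projection import projection_simplex

assert jaxopt.projection.projection_simplex is projection_simplex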
14 | 15 | from absl.testing import absltest 16 | 17 | import jaxopt 18 | from jaxopt._src import test_util 19 | 20 | 21 | class ImportTest(test_util.JaxoptTestCase): 22 | 23 | def test_implicit_diff(self): 24 | jaxopt.implicit_diff.root_vjp 25 | from jaxopt.implicit_diff import root_vjp 26 | 27 | def test_isotonic(self): 28 | jaxopt.isotonic.isotonic_l2_pav 29 | from jaxopt.isotonic import isotonic_l2_pav 30 | 31 | def test_prox(self): 32 | jaxopt.prox.prox_none 33 | from jaxopt.prox import prox_none 34 | 35 | def test_projection(self): 36 | jaxopt.projection.projection_simplex 37 | from jaxopt.projection import projection_simplex 38 | 39 | def test_tree_util(self): 40 | from jaxopt.tree_util import tree_vdot 41 | 42 | def test_linear_solve(self): 43 | from jaxopt.linear_solve import solve_lu 44 | 45 | def test_base(self): 46 | from jaxopt.base import LinearOperator 47 | 48 | def test_perturbations(self): 49 | from jaxopt.perturbations import make_perturbed_argmax 50 | 51 | def test_loss(self): 52 | jaxopt.loss.binary_logistic_loss 53 | from jaxopt.loss import binary_logistic_loss 54 | 55 | def test_objective(self): 56 | jaxopt.objective.least_squares 57 | from jaxopt.objective import least_squares 58 | 59 | def test_loop(self): 60 | from jaxopt.loop import while_loop 61 | 62 | 63 | if __name__ == '__main__': 64 | absltest.main() 65 | -------------------------------------------------------------------------------- /tests/isotonic_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
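The tests below compare isotonic_l2_pav against scikit-learn and check its derivatives. Basic usage, with an arbitrary input vector:

import jax.numpy as jnp
from jaxopt.isotonic import isotonic_l2_pav

y = jnp.array([4.0, 2.0, 8.0, 6.0])
out_inc = isotonic_l2_pav(y)                        # Nondecreasing L2 fit.
out_dec = isotonic_l2_pav(y, increasing=False)      # Nonincreasing L2 fit.
out_box = isotonic_l2_pav(y, y_min=3.0, y_max=7.0)  # Fit clipped to [y_min, y_max].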
14 | 15 | """Tests for Isotonic Regression.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | from jax.test_util import check_grads 24 | from jaxopt.isotonic import isotonic_l2_pav 25 | from jaxopt._src import test_util 26 | from sklearn import isotonic 27 | 28 | 29 | class IsotonicPavTest(test_util.JaxoptTestCase): 30 | """Tests for PAV in JAX.""" 31 | 32 | def test_output_shape_and_dtype(self, n=10): 33 | """Verifies the shapes and dtypes of output.""" 34 | y = jax.random.normal(jax.random.PRNGKey(0), (n,)) 35 | output = isotonic_l2_pav(y) 36 | self.assertEqual(output.shape, y.shape) 37 | self.assertEqual(output.dtype, y.dtype) 38 | 39 | @parameterized.product(increasing=[True, False]) 40 | def test_compare_with_sklearn(self, increasing, n=10): 41 | """Compares the output with the one of sklearn.""" 42 | y = jax.random.normal(jax.random.PRNGKey(0), (n,)) 43 | output = isotonic_l2_pav(y, increasing=increasing) 44 | output_sklearn = jnp.array(isotonic.isotonic_regression(y, increasing=increasing)) 45 | self.assertArraysAllClose(output, output_sklearn) 46 | y_sort = y.sort() 47 | y_min = y_sort[2] 48 | y_max = y_sort[n-5] 49 | output = isotonic_l2_pav(y, y_min=y_min, y_max=y_max, increasing=increasing) 50 | output_sklearn = jnp.array(isotonic.isotonic_regression(y, y_min=y_min.item(), 51 | y_max=y_max.item(), increasing=increasing)) 52 | self.assertArraysAllClose(output, output_sklearn) 53 | 54 | @parameterized.product(increasing=[True, False]) 55 | def test_gradient(self, increasing, n=10): 56 | """Checks the gradient with finite differences.""" 57 | # Absolute error of test fails for large values of y. 58 | y = 0.1*jax.random.normal(jax.random.PRNGKey(0), (n,)) 59 | 60 | def loss(y): 61 | return (isotonic_l2_pav(y**3, increasing=increasing) 62 | + isotonic_l2_pav(y, increasing=increasing) ** 2).mean() 63 | 64 | check_grads(loss, (y,), order=2) 65 | 66 | def test_gradient_min_max(self, n=10): 67 | """Checks the gradient with finite differences.""" 68 | y = jax.random.normal(jax.random.PRNGKey(0), (n,)) 69 | y_sort = y.sort() 70 | y_min = y_sort[2] 71 | y_max = y_sort[n-5] 72 | def loss(y): 73 | return (isotonic_l2_pav(y**3, y_min=y_min, y_max=y_max) 74 | + isotonic_l2_pav(y, y_min=y_min, y_max=y_max) ** 2).mean() 75 | 76 | check_grads(loss, (y,), order=2) 77 | 78 | def test_vmap(self, n_features=10, n_batches=16): 79 | """Verifies vmap.""" 80 | y = jax.random.normal(jax.random.PRNGKey(0), (n_batches, n_features)) 81 | isotonic_l2_pav_vmap = jax.vmap(isotonic_l2_pav) 82 | for i in range(n_batches): 83 | self.assertArraysAllClose(isotonic_l2_pav_vmap(y)[i], isotonic_l2_pav(y[i])) 84 | 85 | if __name__ == '__main__': 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /tests/iterative_refinement_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from functools import partial 16 | 17 | from absl.testing import absltest 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | from jax.test_util import check_grads 22 | 23 | from jaxopt import linear_solve 24 | from jaxopt import IterativeRefinement 25 | from jaxopt._src import test_util 26 | 27 | import numpy as onp 28 | 29 | 30 | class IterativeRefinementTest(test_util.JaxoptTestCase): 31 | 32 | def test_simple_system(self): 33 | onp.random.seed(0) 34 | n = 20 35 | A = onp.random.rand(n, n) 36 | b = onp.random.randn(n) 37 | 38 | low_acc = 1e-1 39 | high_acc = 1e-5 40 | 41 | # Heavily regularized low acuracy solver. 42 | inner_solver = partial(linear_solve.solve_gmres, tol=low_acc, ridge=1e-3) 43 | 44 | solver = IterativeRefinement(solve=inner_solver, tol=high_acc, maxiter=10) 45 | x, state = solver.run(None, A, b) 46 | self.assertLess(state.error, high_acc) 47 | 48 | x_approx = inner_solver(lambda x: jnp.dot(A, x), b) 49 | error_inner_solver = solver.l2_optimality_error(x_approx, A, b) 50 | # High accuracy solution obtained from low accuracy solver. 51 | self.assertLess(state.error, error_inner_solver) 52 | 53 | def test_ill_posed_problem(self): 54 | onp.random.seed(0) 55 | n = 10 56 | e = 5 57 | 58 | # duplicated rows. 59 | A = onp.random.rand(e, n) 60 | A = jnp.concatenate([A, A], axis=0) 61 | b = onp.random.randn(e) 62 | b = jnp.concatenate([b, b], axis=0) 63 | 64 | low_acc = 1e-1 65 | high_acc = 1e-3 66 | 67 | # Heavily regularized low acuracy solver. 68 | inner_solver = partial(linear_solve.solve_gmres, tol=low_acc, ridge=5e-2) 69 | 70 | solver = IterativeRefinement(solve=inner_solver, tol=high_acc, maxiter=30) 71 | x, state = solver.run(init_params=None, A=A, b=b) 72 | self.assertLess(state.error, high_acc) 73 | 74 | x_approx = inner_solver(lambda x: jnp.dot(A, x), b) 75 | error_inner_solver = solver.l2_optimality_error(x_approx, A, b) 76 | # High accuracy solution obtained from low accuracy solver. 77 | self.assertLess(state.error, error_inner_solver) 78 | 79 | def test_perturbed_system(self): 80 | onp.random.seed(0) 81 | n = 20 82 | 83 | A = onp.random.rand(n, n) # invertible matrix (with high probability). 84 | 85 | noise = onp.random.randn(n, n) 86 | sigma = 0.05 87 | A_bar = A + sigma * noise # perturbed system. 88 | 89 | expected = onp.random.randn(n) 90 | b = A @ expected # unperturbed target. 91 | 92 | high_acc = 1e-3 93 | solver = IterativeRefinement(matvec_A=None, matvec_A_bar=jnp.dot, 94 | tol=high_acc, maxiter=100) 95 | x, state = solver.run(init_params=None, A=A, b=b, A_bar=A_bar) 96 | self.assertLess(state.error, high_acc) 97 | self.assertArraysAllClose(x, expected, rtol=5e-2) 98 | 99 | def test_implicit_diff(self): 100 | onp.random.seed(17) 101 | n = 20 102 | A = onp.random.rand(n, n) 103 | b = onp.random.randn(n) 104 | 105 | low_acc = 1e-1 106 | high_acc = 1e-5 107 | 108 | # Heavily regularized low acuracy solver. 
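# Iterative refinement repeatedly solves A d = r for the residual r = b - A x and updates
# x <- x + d, so a crude (here heavily regularized) inner solver can still reach high accuracy.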
109 | inner_solver = partial(linear_solve.solve_gmres, tol=low_acc, ridge=1e-3) 110 | solver = IterativeRefinement(solve=inner_solver, tol=high_acc, maxiter=10) 111 | 112 | def solve_run(A, b): 113 | x, state = solver.run(init_params=None, A=A, b=b) 114 | return x 115 | 116 | check_grads(solve_run, args=(A, b), order=1, modes=['rev'], eps=1e-3) 117 | 118 | def test_warm_start(self): 119 | onp.random.seed(0) 120 | n = 20 121 | A = onp.random.rand(n, n) 122 | b = onp.random.randn(n) 123 | 124 | init_x = onp.random.randn(n) 125 | 126 | high_acc = 1e-5 127 | 128 | solver = IterativeRefinement(tol=high_acc, maxiter=10) 129 | x, state = solver.run(init_x, A, b) 130 | self.assertLess(state.error, high_acc) 131 | 132 | 133 | if __name__ == "__main__": 134 | jax.config.update("jax_enable_x64", False) # low precision environment. 135 | absltest.main() 136 | -------------------------------------------------------------------------------- /tests/linear_operator_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Linear Operator tests.""" 15 | 16 | from absl.testing import absltest 17 | 18 | import jax.numpy as jnp 19 | import numpy as onp 20 | 21 | from jaxopt._src.linear_operator import FunctionalLinearOperator 22 | from jaxopt._src import test_util 23 | 24 | 25 | class LinearOperatorTest(test_util.JaxoptTestCase): 26 | 27 | def test_matvec_and_rmatvec(self): 28 | rng = onp.random.RandomState(0) 29 | A = rng.randn(5, 4) 30 | matvec = lambda A,x: jnp.dot(A, x) 31 | x = rng.randn(4) 32 | y = rng.randn(5) 33 | linop_A = FunctionalLinearOperator(matvec, A) 34 | mv_A, rmv_A = linop_A.matvec_and_rmatvec(x, y) 35 | self.assertArraysAllClose(mv_A, jnp.dot(A, x)) 36 | self.assertArraysAllClose(rmv_A, jnp.dot(A.T, y)) 37 | 38 | 39 | if __name__ == '__main__': 40 | absltest.main() 41 | -------------------------------------------------------------------------------- /tests/linesearch_common_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
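The test below drives the internal helpers _setup_linesearch and _init_stepsize directly; the same line searches are also exposed as public classes. A sketch with BacktrackingLineSearch on a toy quadratic, where the objective and starting point are invented for illustration:

import jax.numpy as jnp
from jaxopt import BacktrackingLineSearch

def fun(w):
  return jnp.sum((w - 1.0) ** 2)

w = jnp.zeros(3)
ls = BacktrackingLineSearch(fun=fun, maxiter=20)
stepsize, state = ls.run(init_stepsize=1.0, params=w)  # Step size accepted by the line-search conditions.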
14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | import jax 18 | import jax.numpy as jnp 19 | from jaxopt._src import test_util 20 | from jaxopt._src.linesearch_util import _init_stepsize 21 | from jaxopt._src.linesearch_util import _setup_linesearch 22 | 23 | 24 | class LinesearchTest(test_util.JaxoptTestCase): 25 | @parameterized.product( 26 | linesearch=["zoom", "backtracking", "hager-zhang"], 27 | use_gradient=[False, True], 28 | ) 29 | def test_linesearch_complex_variables(self, linesearch, use_gradient): 30 | """Test that optimization over complex variable z = x + jy matches equivalent real case""" 31 | 32 | W = jnp.array([[1, -2], [3, 4], [-4 + 2j, 5 - 3j], [-2 - 2j, 6]]) 33 | 34 | def C2R(z): 35 | return jnp.stack((z.real, z.imag)) if z is not None else None 36 | 37 | def R2C(x): 38 | return x[..., 0, :] + 1j * x[..., 1, :] 39 | 40 | def f(z): 41 | return W @ z 42 | 43 | def loss_complex(z): 44 | return jnp.sum(jnp.abs(f(z)) ** 1.5) 45 | 46 | def loss_real(zR): 47 | return loss_complex(R2C(zR)) 48 | 49 | z0 = jnp.array([1 - 1j, 0 + 1j]) 50 | 51 | common_args = dict( 52 | value_and_grad=False, 53 | has_aux=False, 54 | maxlsiter=3, 55 | max_stepsize=1, 56 | jit=True, 57 | unroll=False, 58 | verbose=False, 59 | ) 60 | 61 | ls_R = _setup_linesearch( 62 | linesearch=linesearch, 63 | fun=loss_real, 64 | **common_args, 65 | ) 66 | 67 | ls_C = _setup_linesearch( 68 | linesearch=linesearch, 69 | fun=loss_complex, 70 | **common_args, 71 | ) 72 | 73 | ls_state = _init_stepsize( 74 | strategy="increase", 75 | max_stepsize=1e-1, 76 | min_stepsize=1e-3, 77 | increase_factor=2.0, 78 | stepsize=1e-2, 79 | ) 80 | 81 | descent_direction = ( 82 | -jnp.conj(jax.grad(loss_complex)(z0)) if use_gradient else None 83 | ) 84 | 85 | stepsize_R, _ = ls_R.run( 86 | ls_state, params=C2R(z0), descent_direction=C2R(descent_direction) 87 | ) 88 | stepsize_C, _ = ls_C.run( 89 | ls_state, params=z0, descent_direction=descent_direction 90 | ) 91 | 92 | self.assertArraysAllClose(stepsize_R, stepsize_C) 93 | 94 | 95 | if __name__ == "__main__": 96 | absltest.main() 97 | -------------------------------------------------------------------------------- /tests/loop_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
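loop.while_loop, tested below, mirrors jax.lax.while_loop but adds a maxiter cap and an optional unrolled mode, which is what makes reverse-mode differentiation possible in these tests. A small sketch:

from jaxopt import loop

# Keep doubling val until it exceeds 100, but never run more than maxiter iterations.
result = loop.while_loop(cond_fun=lambda val: val < 100.0,
                         body_fun=lambda val: 2.0 * val,
                         init_val=1.0, maxiter=10,
                         unroll=True, jit=True)
# Here result == 128.0, reached after 7 iterations.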
14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | from jaxopt import loop 22 | from jaxopt._src import test_util 23 | 24 | 25 | class LoopTest(test_util.JaxoptTestCase): 26 | 27 | @parameterized.product(unroll=[True, False], jit=[True, False]) 28 | def test_while_loop(self, unroll, jit): 29 | def my_pow(x, y): 30 | def body_fun(val): 31 | return val * x 32 | def cond_fun(val): 33 | return True 34 | return loop.while_loop(cond_fun=cond_fun, body_fun=body_fun, init_val=1.0, 35 | maxiter=y, unroll=unroll, jit=jit) 36 | 37 | if not unroll and not jit: 38 | self.assertRaises(ValueError, my_pow, 3, 4) 39 | return 40 | 41 | self.assertEqual(my_pow(3, 4), pow(3, 4)) 42 | 43 | if unroll: 44 | # unroll=False uses lax.while_loop, whichs is not differentiable. 45 | self.assertEqual(jax.grad(my_pow)(3.0, 4), 46 | jax.grad(jnp.power)(3.0, 4)) 47 | 48 | @parameterized.product(unroll=[True, False], jit=[True, False]) 49 | def test_while_loop_stopped(self, unroll, jit): 50 | def my_pow(x, y, max_val): 51 | def body_fun(val): 52 | return val * x 53 | def cond_fun(val): 54 | return val < max_val 55 | return loop.while_loop(cond_fun=cond_fun, body_fun=body_fun, init_val=1.0, 56 | maxiter=y, unroll=unroll, jit=jit) 57 | 58 | if not unroll and not jit: 59 | self.assertRaises(ValueError, my_pow, 3, 4, max_val=81) 60 | return 61 | 62 | # We asked for pow(3, 6) but due to max_val, we get pow(3, 4). 63 | self.assertEqual(my_pow(3, 6, max_val=81), pow(3, 4)) 64 | 65 | if unroll: 66 | self.assertEqual(jax.grad(my_pow)(3.0, 6, max_val=81), 67 | jax.grad(jnp.power)(3.0, 4)) 68 | 69 | 70 | if __name__ == '__main__': 71 | absltest.main() 72 | -------------------------------------------------------------------------------- /tests/nonlinear_cg_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
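The tests below run NonlinearCG with different update rules and line searches, over pytree-valued and complex-valued parameters. A minimal dense example with invented data:

import jax.numpy as jnp
from jaxopt import NonlinearCG

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, 1.0])

def fun(w, data):
  X, y = data
  return jnp.mean((jnp.dot(X, w) - y) ** 2)

cg = NonlinearCG(fun=fun, method="polak-ribiere", linesearch="zoom",
                 maxiter=100, tol=1e-3)
params, state = cg.run(jnp.zeros(2), data=(X, y))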
14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | 18 | import jax.random 19 | import jax.numpy as jnp 20 | 21 | import numpy as onp 22 | 23 | import jaxopt 24 | from jaxopt import NonlinearCG 25 | from jaxopt import objective 26 | from jaxopt._src import test_util 27 | from sklearn import datasets 28 | 29 | # Uncomment this line to test in x64 30 | # jax.config.update('jax_enable_x64', True) 31 | 32 | def get_random_pytree(): 33 | key = jax.random.PRNGKey(1213) 34 | 35 | def rn(key, l=3): 36 | return 0.05 * jnp.array(onp.random.normal(size=(10,))) 37 | 38 | def _get_random_pytree(curr_depth=0, max_depth=3): 39 | r = onp.random.uniform() 40 | if curr_depth == max_depth or r <= 0.2: # leaf 41 | return rn(key) 42 | elif curr_depth <= 1 or r <= 0.7: # list 43 | return [ 44 | _get_random_pytree(curr_depth=curr_depth + 45 | 1, max_depth=max_depth) 46 | for _ in range(2) 47 | ] 48 | else: # dict 49 | return { 50 | str(_): _get_random_pytree( 51 | curr_depth=curr_depth + 1, max_depth=max_depth 52 | ) 53 | for _ in range(2) 54 | } 55 | return [rn(key), {'a': rn(key), 'b': rn(key)}, _get_random_pytree()] 56 | 57 | 58 | class NonlinearCGTest(test_util.JaxoptTestCase): 59 | 60 | def test_arbitrary_pytree(self): 61 | def loss(w, data): 62 | X, y = data 63 | _w = jnp.concatenate(jax.tree_util.tree_leaves(w)) 64 | return ((jnp.dot(X, _w) - y) ** 2).mean() 65 | 66 | w = get_random_pytree() 67 | f_w = jnp.concatenate(jax.tree_util.tree_leaves(w)) 68 | X, y = datasets.make_classification(n_samples=15, n_features=f_w.shape[-1], 69 | n_classes=2, n_informative=3, 70 | random_state=0) 71 | data = (X, y) 72 | cg_model = NonlinearCG(fun=loss, tol=1e-2, maxiter=300, 73 | method="polak-ribiere") 74 | w_fit, info = cg_model.run(w, data=data) 75 | self.assertLessEqual(info.error, 5e-2) 76 | 77 | @parameterized.product( 78 | method=["fletcher-reeves", "polak-ribiere", "hestenes-stiefel"], 79 | linesearch=[ 80 | "backtracking", 81 | "zoom", 82 | jaxopt.BacktrackingLineSearch( 83 | objective.binary_logreg, decrease_factor=0.5 84 | ), 85 | ], 86 | linesearch_init=["max", "current", "increase"], 87 | ) 88 | def test_binary_logreg(self, method, linesearch, linesearch_init): 89 | X, y = datasets.make_classification( 90 | n_samples=10, n_features=5, n_classes=2, n_informative=3, random_state=0 91 | ) 92 | data = (X, y) 93 | fun = objective.binary_logreg 94 | 95 | w_init = jnp.zeros(X.shape[1]) 96 | cg_model = NonlinearCG( 97 | fun=fun, 98 | tol=1e-3, 99 | maxiter=100, 100 | method=method, 101 | linesearch=linesearch, 102 | linesearch_init=linesearch_init, 103 | ) 104 | 105 | # Test with positional argument. 106 | w_fit, info = cg_model.run(w_init, data) 107 | 108 | # Check optimality conditions. 109 | self.assertLessEqual(info.error, 5e-2) 110 | 111 | # Compare against sklearn. 
112 | w_skl = test_util.logreg_skl(X, y, 1e-6, fit_intercept=False, 113 | multiclass=False) 114 | self.assertArraysAllClose(w_fit, w_skl, atol=5e-2) 115 | 116 | @parameterized.product( 117 | linesearch=['zoom', 'backtracking', 'hager-zhang'], 118 | method=['hestenes-stiefel', 'polak-ribiere', 'fletcher-reeves'] 119 | ) 120 | def test_complex(self, method, linesearch): 121 | """Test that optimization over complex variable z = x + jy matches equivalent real case""" 122 | 123 | W = jnp.array( 124 | [[1, - 2], 125 | [3, 4], 126 | [-4 + 2j, 5 - 3j], 127 | [-2 - 2j, 6]] 128 | ) 129 | 130 | def C2R(z): 131 | return jnp.stack((z.real, z.imag)) 132 | 133 | def R2C(x): 134 | return x[..., 0, :] + 1j * x[..., 1, :] 135 | 136 | def f(z): 137 | return W @ z 138 | 139 | def loss_complex(z): 140 | return jnp.sum(jnp.abs(f(z)) ** 1.5) 141 | 142 | def loss_real(zR): 143 | return loss_complex(R2C(zR)) 144 | 145 | z0 = jnp.array([1 - 1j, 0 + 1j]) 146 | xy0 = jnp.stack((z0.real, z0.imag)) 147 | 148 | solver_C = NonlinearCG(fun=loss_complex, maxiter=5, 149 | maxls=3, method=method, linesearch=linesearch) 150 | solver_R = NonlinearCG(fun=loss_real, maxiter=5, 151 | maxls=3, method=method, linesearch=linesearch) 152 | sol_C, _ = solver_C.run(z0) 153 | sol_R, _ = solver_R.run(C2R(z0)) 154 | # NOTE(vroulet): there is a slight loss of precision between real 155 | # and complex cases (observable for any linesearch with jax.enable_x64 156 | tol = 5*1e-15 if jax.config.jax_enable_x64 else 5*1e-6 157 | self.assertArraysAllClose(sol_C, R2C(sol_R), atol=tol, rtol=tol) 158 | 159 | 160 | if __name__ == '__main__': 161 | absltest.main() 162 | -------------------------------------------------------------------------------- /tests/projected_gradient_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
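The tests below exercise ProjectedGradient with several projections. A minimal non-negative least-squares sketch with invented data:

import jax.numpy as jnp
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])

def fun(w, data):
  X, y = data
  return jnp.mean((jnp.dot(X, w) - y) ** 2)

pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
params, state = pg.run(jnp.zeros(2), data=(X, y))  # projection_non_negative needs no hyperparameters.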
14 | 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | from jaxopt import objective 23 | from jaxopt import projection 24 | from jaxopt import ProjectedGradient 25 | from jaxopt import ScipyBoundedMinimize 26 | from jaxopt._src import test_util 27 | 28 | import numpy as onp 29 | 30 | 31 | N_CALLS = 0 32 | 33 | class ProjectedGradientTest(test_util.JaxoptTestCase): 34 | 35 | def test_non_negative_least_squares(self): 36 | rng = onp.random.RandomState(0) 37 | X = rng.randn(10, 5) 38 | w = rng.rand(5) 39 | y = jnp.dot(X, w) 40 | fun = objective.least_squares 41 | w_init = jnp.zeros_like(w) 42 | 43 | pg = ProjectedGradient(fun=fun, 44 | projection=projection.projection_non_negative) 45 | pg_sol = pg.run(w_init, data=(X, y)).params 46 | 47 | lbfgsb = ScipyBoundedMinimize(fun=fun, method="l-bfgs-b") 48 | lower_bounds = jnp.zeros_like(w_init) 49 | upper_bounds = jnp.ones_like(w_init) * jnp.inf 50 | bounds = (lower_bounds, upper_bounds) 51 | lbfgsb_sol = lbfgsb.run(w_init, bounds=bounds, data=(X, y)).params 52 | 53 | self.assertArraysAllClose(pg_sol, lbfgsb_sol, atol=1e-2) 54 | 55 | def test_projected_gradient_l2_ball(self): 56 | rng = onp.random.RandomState(0) 57 | X = rng.randn(10, 5) 58 | w = rng.rand(5) 59 | y = jnp.dot(X, w) 60 | fun = objective.least_squares 61 | w_init = jnp.zeros_like(w) 62 | 63 | pg = ProjectedGradient(fun=fun, 64 | projection=projection.projection_l2_ball) 65 | pg_sol = pg.run(w_init, hyperparams_proj=1.0, data=(X, y)).params 66 | self.assertLess(jnp.sqrt(jnp.sum(pg_sol ** 2)), 1.0) 67 | 68 | def test_projected_gradient_l2_ball_manual_loop(self): 69 | rng = onp.random.RandomState(0) 70 | X = rng.randn(10, 5) 71 | w = rng.rand(5) 72 | y = jnp.dot(X, w) 73 | fun = objective.least_squares 74 | params = jnp.zeros_like(w) 75 | 76 | pg = ProjectedGradient(fun=fun, 77 | projection=projection.projection_l2_ball) 78 | 79 | state = pg.init_state(params) 80 | 81 | for _ in range(10): 82 | params, state = pg.update(params, state, hyperparams_proj=1.0, data=(X, y)) 83 | 84 | self.assertLess(jnp.sqrt(jnp.sum(params ** 2)), 1.0) 85 | 86 | def test_projected_gradient_implicit_diff(self): 87 | rng = onp.random.RandomState(0) 88 | X = rng.randn(10, 5) 89 | w = rng.rand(5) 90 | y = jnp.dot(X, w) 91 | fun = objective.least_squares 92 | w_init = jnp.zeros_like(w) 93 | 94 | def solution(radius): 95 | pg = ProjectedGradient(fun=fun, 96 | projection=projection.projection_l2_ball) 97 | return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params 98 | 99 | eps = 1e-4 100 | J = jax.jacobian(solution)(0.1) 101 | J2 = (solution(0.1 + eps) - solution(0.1 - eps)) / (2 * eps) 102 | self.assertArraysAllClose(J, J2, atol=1e-2) 103 | 104 | def test_polyhedron_projection(self): 105 | def f(x): 106 | return x[0]**2-x[1]**2 107 | 108 | A = jnp.array([[0, 0]]) 109 | b = jnp.array([0]) 110 | G = jnp.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]]) 111 | h = jnp.array([-1, 1, 1, 0, 0]) 112 | hyperparams = (A, b, G, h) 113 | 114 | proj = projection.projection_polyhedron 115 | pg = ProjectedGradient(fun=f, projection=proj, jit=False) 116 | sol, state = pg.run(init_params=jnp.array([0.,1.]), hyperparams_proj=hyperparams) 117 | self.assertLess(state.error, pg.tol) 118 | 119 | @parameterized.product(n_iter=[10]) 120 | def test_n_calls(self, n_iter): 121 | """Test whether the number of function calls 122 | is equal to the number of iterations + 1 in the 123 | no linesearch case, where the complexity is 
linear.""" 124 | def fun(x): 125 | global N_CALLS 126 | N_CALLS += 1 127 | return x[0]**2-x[1]**2 128 | 129 | A = jnp.array([[0, 0]]) 130 | b = jnp.array([0]) 131 | G = jnp.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]]) 132 | h = jnp.array([-1, 1, 1, 0, 0]) 133 | hyperparams = (A, b, G, h) 134 | 135 | proj = projection.projection_polyhedron 136 | pg = ProjectedGradient(fun=fun, projection=proj, jit=False, maxiter=n_iter, tol=1e-10, stepsize=1.0) 137 | sol, state = pg.run(init_params=jnp.array([0.,1.]), hyperparams_proj=hyperparams) 138 | self.assertEqual(N_CALLS, n_iter) 139 | 140 | 141 | if __name__ == '__main__': 142 | # Uncomment the line below in order to run in float64. 143 | # jax.config.update("jax_enable_x64", True) 144 | absltest.main() 145 | --------------------------------------------------------------------------------