├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── make.bat
│   ├── requirements.txt
│   └── source
│       ├── _static
│       │   └── custom.css
│       ├── api
│       │   ├── index.rst
│       │   ├── minimize-bfgs.rst
│       │   ├── minimize-cg.rst
│       │   ├── minimize-dogleg.rst
│       │   ├── minimize-lbfgs.rst
│       │   ├── minimize-newton-cg.rst
│       │   ├── minimize-newton-exact.rst
│       │   ├── minimize-trust-exact.rst
│       │   ├── minimize-trust-krylov.rst
│       │   └── minimize-trust-ncg.rst
│       ├── conf.py
│       ├── examples
│       │   └── index.rst
│       ├── index.rst
│       ├── install.rst
│       └── user_guide
│           └── index.rst
├── examples
│   ├── constrained_optimization_adversarial_examples.ipynb
│   ├── rosen_minimize.ipynb
│   ├── scipy_benchmark.py
│   └── train_mnist_Minimizer.py
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
├── tests
│   ├── __init__.py
│   ├── test_imports.py
│   └── torchmin
│       ├── __init__.py
│       ├── test_leastsquares.py
│       └── test_minimize_constr.py
└── torchmin
    ├── __init__.py
    ├── benchmarks.py
    ├── bfgs.py
    ├── cg.py
    ├── function.py
    ├── line_search.py
    ├── lstsq
    │   ├── __init__.py
    │   ├── cg.py
    │   ├── common.py
    │   ├── least_squares.py
    │   ├── linear_operator.py
    │   ├── lsmr.py
    │   └── trf.py
    ├── minimize.py
    ├── minimize_constr.py
    ├── newton.py
    ├── optim
    │   ├── __init__.py
    │   ├── minimizer.py
    │   └── scipy_minimizer.py
    └── trustregion
        ├── __init__.py
        ├── base.py
        ├── dogleg.py
        ├── exact.py
        ├── krylov.py
        └── ncg.py

/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 | 
3 | # Tell RTD which build image to use and which Python to install
4 | build:
5 |   os: ubuntu-22.04
6 |   tools:
7 |     python: "3.8"
8 | 
9 | # Build from the docs/ directory with Sphinx
10 | sphinx:
11 |   configuration: docs/source/conf.py
12 | 
13 | # Explicitly set the version of Python and its requirements
14 | python:
15 |   install:
16 |     - requirements: docs/requirements.txt
17 |     - requirements: requirements.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Reuben Feinman
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Minimize 2 | 3 | For the most up-to-date information on pytorch-minimize, see the docs site: [pytorch-minimize.readthedocs.io](https://pytorch-minimize.readthedocs.io/) 4 | 5 | Pytorch-minimize represents a collection of utilities for minimizing multivariate functions in PyTorch. 6 | It is inspired heavily by SciPy's `optimize` module and MATLAB's [Optimization Toolbox](https://www.mathworks.com/products/optimization.html). 7 | Unlike SciPy and MATLAB, which use numerical approximations of function derivatives, pytorch-minimize uses _real_ first- and second-order derivatives, computed seamlessly behind the scenes with autograd. 8 | Both CPU and CUDA are supported. 9 | 10 | __Author__: Reuben Feinman 11 | 12 | __At a glance:__ 13 | 14 | ```python 15 | import torch 16 | from torchmin import minimize 17 | 18 | def rosen(x): 19 | return torch.sum(100*(x[..., 1:] - x[..., :-1]**2)**2 20 | + (1 - x[..., :-1])**2) 21 | 22 | # initial point 23 | x0 = torch.tensor([1., 8.]) 24 | 25 | # Select from the following methods: 26 | # ['bfgs', 'l-bfgs', 'cg', 'newton-cg', 'newton-exact', 27 | # 'trust-ncg', 'trust-krylov', 'trust-exact', 'dogleg'] 28 | 29 | # BFGS 30 | result = minimize(rosen, x0, method='bfgs') 31 | 32 | # Newton Conjugate Gradient 33 | result = minimize(rosen, x0, method='newton-cg') 34 | 35 | # Newton Exact 36 | result = minimize(rosen, x0, method='newton-exact') 37 | ``` 38 | 39 | __Solvers:__ BFGS, L-BFGS, Conjugate Gradient (CG), Newton Conjugate Gradient (NCG), Newton Exact, Dogleg, Trust-Region Exact, Trust-Region NCG, Trust-Region GLTR (Krylov) 40 | 41 | __Examples:__ See the [Rosenbrock minimization notebook](https://github.com/rfeinman/pytorch-minimize/blob/master/examples/rosen_minimize.ipynb) for a demonstration of function minimization with a handful of different algorithms. 42 | 43 | __Install with pip:__ 44 | 45 | pip install pytorch-minimize 46 | 47 | __Install from source:__ 48 | 49 | git clone https://github.com/rfeinman/pytorch-minimize.git 50 | cd pytorch-minimize 51 | pip install -e . 52 | 53 | ## Motivation 54 | Although PyTorch offers many routines for stochastic optimization, utilities for deterministic optimization are scarce; only L-BFGS is included in the `optim` package, and it's modified for mini-batch training. 55 | 56 | MATLAB and SciPy are industry standards for deterministic optimization. 57 | These libraries have a comprehensive set of routines; however, automatic differentiation is not supported.* 58 | Therefore, the user must provide explicit 1st- and 2nd-order gradients (if they are known) or use finite-difference approximations. 59 | 60 | The motivation for pytorch-minimize is to offer a set of tools for deterministic optimization with automatic gradients and GPU acceleration. 61 | 62 | __ 63 | 64 | *MATLAB offers minimal autograd support via the Deep Learning Toolbox, but the integration is not seamless: data must be converted to "dlarray" structures, and only a [subset of functions](https://www.mathworks.com/help/deeplearning/ug/list-of-functions-with-dlarray-support.html) are supported. 65 | Furthermore, derivatives must still be constructed and provided as function handles. 66 | Pytorch-minimize uses autograd to compute derivatives behind the scenes, so all you provide is an objective function. 
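To make this concrete, here is a minimal side-by-side sketch, adapted from the `examples/scipy_benchmark.py` script included in this repository: the pytorch-minimize call needs only the objective, while the SciPy call must be given an explicit gradient via `jac` (or fall back to finite-difference approximations).

```python
import torch
from scipy import optimize
from torchmin import minimize
from torchmin.benchmarks import rosen

# scipy optimizers work in double precision, so use float64 for a fair comparison
x0 = torch.randn(100, dtype=torch.float64)

# pytorch-minimize: derivatives are computed behind the scenes with autograd
result = minimize(rosen, x0, method='bfgs', tol=1e-5)

# scipy.optimize: a gradient callable must be supplied (or approximated numerically)
result_scipy = optimize.minimize(optimize.rosen, x0.numpy(), method='bfgs',
                                 jac=optimize.rosen_der, tol=1e-5)
```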
67 | 
68 | ## Library
69 | 
70 | The pytorch-minimize library includes solvers for general-purpose function minimization (unconstrained & constrained), as well as for nonlinear least squares problems.
71 | 
72 | ### 1. Unconstrained Minimizers
73 | 
74 | The following solvers are available for _unconstrained_ minimization:
75 | 
76 | - __BFGS/L-BFGS.__ BFGS is a canonical quasi-Newton method for unconstrained optimization. I've implemented both the standard BFGS and the "limited memory" L-BFGS. For smaller scale problems where memory is not a concern, BFGS should be significantly faster than L-BFGS (especially on CUDA) since it avoids Python for loops and instead uses pure torch.
77 | 
78 | - __Conjugate Gradient (CG).__ The conjugate gradient algorithm is a generalization of linear conjugate gradient to nonlinear optimization problems. Pytorch-minimize includes an implementation of the Polak-Ribière CG algorithm described in Nocedal & Wright (2006) chapter 5.2.
79 | 
80 | - __Newton Conjugate Gradient (NCG).__ The Newton-Raphson method is a staple of unconstrained optimization. Although computing full Hessian matrices with PyTorch's reverse-mode automatic differentiation can be costly, computing Hessian-vector products is cheap, and it also saves a lot of memory. The Conjugate Gradient (CG) variant of Newton's method is an effective solution for unconstrained minimization with Hessian-vector products. I've implemented a lightweight NewtonCG minimizer that uses HVPs for the linear inverse sub-problems.
81 | 
82 | - __Newton Exact.__ In some cases, we may prefer a more precise variant of the Newton-Raphson method at the cost of additional complexity. I've also implemented an "exact" variant of Newton's method that computes the full Hessian matrix and uses Cholesky factorization for linear inverse sub-problems. When Cholesky fails (i.e. the Hessian is not positive definite), the solver resorts to one of two options as specified by the user: 1) take the steepest descent direction (default), or 2) solve the Newton step with LU factorization.
83 | 
84 | - __Trust-Region Newton Conjugate Gradient.__ Description coming soon.
85 | 
86 | - __Trust-Region Newton Generalized Lanczos (Krylov).__ Description coming soon.
87 | 
88 | - __Trust-Region Exact.__ Description coming soon.
89 | 
90 | - __Dogleg.__ Description coming soon.
91 | 
92 | To access the unconstrained minimizer interface, use the following import statement:
93 | 
94 |     from torchmin import minimize
95 | 
96 | Use the argument `method` to specify which of the aforementioned solvers should be applied.
97 | 
98 | ### 2. Constrained Minimizers
99 | 
100 | The following solvers are available for _constrained_ minimization:
101 | 
102 | - __Trust-Region Constrained Algorithm.__ Pytorch-minimize includes a single constrained minimization routine based on SciPy's 'trust-constr' method. The algorithm accepts generalized nonlinear constraints and variable boundaries via the "constr" and "bounds" arguments. For equality constrained problems, it is an implementation of the Byrd-Omojokun Trust-Region SQP method. When inequality constraints are imposed, the trust-region interior point method is used. NOTE: The current trust-region constrained minimizer is not a custom implementation, but rather a wrapper for SciPy's `optimize.minimize` routine. It uses autograd behind the scenes to build jacobian & hessian callables before invoking scipy. Inputs and objectives should use torch tensors like other pytorch-minimize routines.
CUDA is supported but not recommended; data will be moved back-and-forth between GPU/CPU. 103 | 104 | To access the constrained minimizer interface, use the following import statement: 105 | 106 | from torchmin import minimize_constr 107 | 108 | ### 3. Nonlinear Least Squares 109 | 110 | The library also includes specialized solvers for nonlinear least squares problems. 111 | These solvers revolve around the Gauss-Newton method, a modification of Newton's method tailored to the lstsq setting. 112 | The least squares interface can be imported as follows: 113 | 114 | from torchmin import least_squares 115 | 116 | The least_squares function is heavily motivated by scipy's `optimize.least_squares`. 117 | Much of the scipy code was borrowed directly (all rights reserved) and ported from numpy to torch. 118 | Rather than have the user provide a jacobian function, in the new interface, jacobian-vector products are computed behind the scenes with autograd. 119 | At the moment, only the Trust Region Reflective ("trf") method is implemented, and bounds are not yet supported. 120 | 121 | ## Examples 122 | 123 | The [Rosenbrock minimization tutorial](https://github.com/rfeinman/pytorch-minimize/blob/master/examples/rosen_minimize.ipynb) demonstrates how to use pytorch-minimize to find the minimum of a scalar-valued function of multiple variables using various optimization strategies. 124 | 125 | In addition, the [SciPy benchmark](https://github.com/rfeinman/pytorch-minimize/blob/master/examples/scipy_benchmark.py) provides a comparison of pytorch-minimize solvers to their analogous solvers from the `scipy.optimize` library. 126 | For those transitioning from scipy, this script will help get a feel for the design of the current library. 127 | Unlike scipy, jacobian and hessian functions need not be provided to pytorch-minimize solvers, and numerical approximations are never used. 128 | 129 | For constrained optimization, the [adversarial examples tutorial](https://github.com/rfeinman/pytorch-minimize/blob/master/examples/constrained_optimization_adversarial_examples.ipynb) demonstrates how to use the trust-region constrained routine to generate an optimal adversarial perturbation given a constraint on the perturbation norm. 130 | 131 | ## Optimizer API 132 | 133 | As an alternative to the functional API, pytorch-minimize also includes an "optimizer" API based on the `torch.optim.Optimizer` class. 134 | To access the optimizer class, import as follows: 135 | 136 | from torchmin import Minimizer 137 | 138 | ## Citing this work 139 | 140 | If you use pytorch-minimize for academic research, you may cite the library as follows: 141 | 142 | ``` 143 | @misc{Feinman2021, 144 | author = {Feinman, Reuben}, 145 | title = {Pytorch-minimize: a library for numerical optimization with autograd}, 146 | publisher = {GitHub}, 147 | year = {2021}, 148 | url = {https://github.com/rfeinman/pytorch-minimize}, 149 | } 150 | ``` 151 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==3.5.3 2 | jinja2<3.1 3 | sphinx_rtd_theme==0.5.2 4 | readthedocs-sphinx-search==0.3.2 5 | -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | .wy-table-responsive table td { 2 | white-space: normal; 3 | } -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | API Documentation 3 | ================= 4 | 5 | .. currentmodule:: torchmin 6 | 7 | 8 | Functional API 9 | ============== 10 | 11 | The functional API provides an interface similar to those of SciPy's :mod:`optimize` module and MATLAB's ``fminunc``/``fmincon`` routines. Parameters are provided as a single torch Tensor, and an :class:`OptimizeResult` instance is returned that includes the optimized parameter value as well as other useful information (e.g. final function value, parameter gradient, etc.). 12 | 13 | There are 3 core utilities in the functional API, designed for 3 unique 14 | numerical optimization problems. 15 | 16 | 17 | **Unconstrained minimization** 18 | 19 | .. autosummary:: 20 | :toctree: generated 21 | 22 | minimize 23 | 24 | The :func:`minimize` function is a general utility for *unconstrained* minimization. It implements a number of different routines based on Newton and Quasi-Newton methods for numerical optimization. The following methods are supported, accessed via the `method` argument: 25 | 26 | .. toctree:: 27 | 28 | minimize-bfgs 29 | minimize-lbfgs 30 | minimize-cg 31 | minimize-newton-cg 32 | minimize-newton-exact 33 | minimize-dogleg 34 | minimize-trust-ncg 35 | minimize-trust-exact 36 | minimize-trust-krylov 37 | 38 | 39 | **Constrained minimization** 40 | 41 | .. 
autosummary::
42 |     :toctree: generated
43 | 
44 |     minimize_constr
45 | 
46 | The :func:`minimize_constr` function is a general utility for *constrained* minimization. Algorithms for constrained minimization use Newton and Quasi-Newton methods on the KKT conditions of the constrained optimization problem.
47 | 
48 | .. note::
49 |     The :func:`minimize_constr` function is currently in early beta. Unlike :func:`minimize`, which uses a custom, pure PyTorch backend, the constrained solver is a wrapper for SciPy's 'trust-constr' minimization method. CUDA tensors are supported, but CUDA will only be used for function and gradient evaluation, with the remaining solver computations performed on CPU (with numpy arrays).
50 | 
51 | 
52 | **Nonlinear least-squares**
53 | 
54 | .. autosummary::
55 |     :toctree: generated
56 | 
57 |     least_squares
58 | 
59 | The :func:`least_squares` function is a specialized utility for nonlinear least-squares minimization problems. Algorithms for least-squares revolve around the Gauss-Newton method, a modification of Newton's method tailored to residual sum-of-squares (RSS) optimization. The following methods are currently supported:
60 | 
61 | - Trust-region reflective
62 | - Dogleg - COMING SOON
63 | - Gauss-Newton line search - COMING SOON
64 | 
65 | 
66 | 
67 | Optimizer API
68 | ==============
69 | 
70 | The optimizer API provides an alternative interface based on PyTorch's :mod:`optim` module. This interface follows the design of PyTorch optimizers and will be familiar to those migrating from torch.
71 | 
72 | .. autosummary::
73 |     :toctree: generated
74 | 
75 |     Minimizer
76 |     Minimizer.step
77 | 
78 | The :class:`Minimizer` class inherits from :class:`torch.optim.Optimizer` and constructs an object that holds the state of the provided variables. Unlike the functional API, which expects parameters to be a single Tensor, parameters can be passed to :class:`Minimizer` as iterables of Tensors. The class serves as a wrapper for :func:`torchmin.minimize()` and can use any of its methods (selected via the `method` argument) to perform unconstrained minimization.
79 | 
80 | .. autosummary::
81 |     :toctree: generated
82 | 
83 |     ScipyMinimizer
84 |     ScipyMinimizer.step
85 | 
86 | Although the :class:`Minimizer` class will be sufficient for most problems where torch optimizers would be used, it does not support constraints. Another optimizer is provided, :class:`ScipyMinimizer`, which supports parameter bounds and linear/nonlinear constraint functions. This optimizer is a wrapper for :func:`scipy.optimize.minimize`. When using bound constraints, `bounds` are passed as iterables with the same length as `params`, i.e. one bound specification per parameter Tensor.
87 | 
--------------------------------------------------------------------------------
/docs/source/api/minimize-bfgs.rst:
--------------------------------------------------------------------------------
1 | minimize(method='bfgs')
2 | ----------------------------------------
3 | 
4 | .. autofunction:: torchmin.bfgs._minimize_bfgs
--------------------------------------------------------------------------------
/docs/source/api/minimize-cg.rst:
--------------------------------------------------------------------------------
1 | minimize(method='cg')
2 | ----------------------------------------
3 | 
4 | ..
autofunction:: torchmin.cg._minimize_cg -------------------------------------------------------------------------------- /docs/source/api/minimize-dogleg.rst: -------------------------------------------------------------------------------- 1 | minimize(method='dogleg') 2 | ---------------------------------------- 3 | 4 | .. autofunction:: torchmin.trustregion._minimize_dogleg -------------------------------------------------------------------------------- /docs/source/api/minimize-lbfgs.rst: -------------------------------------------------------------------------------- 1 | minimize(method='l-bfgs') 2 | ---------------------------------------- 3 | 4 | .. autofunction:: torchmin.bfgs._minimize_lbfgs -------------------------------------------------------------------------------- /docs/source/api/minimize-newton-cg.rst: -------------------------------------------------------------------------------- 1 | minimize(method='newton-cg') 2 | ---------------------------------------- 3 | 4 | .. autofunction:: torchmin.newton._minimize_newton_cg -------------------------------------------------------------------------------- /docs/source/api/minimize-newton-exact.rst: -------------------------------------------------------------------------------- 1 | minimize(method='newton-exact') 2 | ---------------------------------------- 3 | 4 | .. autofunction:: torchmin.newton._minimize_newton_exact -------------------------------------------------------------------------------- /docs/source/api/minimize-trust-exact.rst: -------------------------------------------------------------------------------- 1 | minimize(method='trust-exact') 2 | ---------------------------------------- 3 | 4 | .. autofunction:: torchmin.trustregion._minimize_trust_exact -------------------------------------------------------------------------------- /docs/source/api/minimize-trust-krylov.rst: -------------------------------------------------------------------------------- 1 | minimize(method='trust-krylov') 2 | ---------------------------------------- 3 | 4 | .. autofunction:: torchmin.trustregion._minimize_trust_krylov -------------------------------------------------------------------------------- /docs/source/api/minimize-trust-ncg.rst: -------------------------------------------------------------------------------- 1 | minimize(method='trust-ncg') 2 | ---------------------------------------- 3 | 4 | .. autofunction:: torchmin.trustregion._minimize_trust_ncg -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../../')) 16 | 17 | import torchmin 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'pytorch-minimize' 23 | copyright = '2021, Reuben Feinman' 24 | author = 'Reuben Feinman' 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = '0.1.0-beta' 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | import sphinx_rtd_theme 36 | 37 | extensions = [ 38 | 'sphinx.ext.autodoc', 39 | 'sphinx.ext.autosummary', 40 | 'sphinx.ext.doctest', 41 | 'sphinx.ext.intersphinx', 42 | 'sphinx.ext.todo', 43 | 'sphinx.ext.coverage', 44 | 'sphinx.ext.mathjax', 45 | 'sphinx.ext.napoleon', 46 | 'sphinx.ext.viewcode', 47 | 'sphinx.ext.autosectionlabel', 48 | 'sphinx_rtd_theme' 49 | ] 50 | 51 | # autosectionlabel throws warnings if section names are duplicated. 52 | # The following tells autosectionlabel to not throw a warning for 53 | # duplicated section names that are in different documents. 54 | autosectionlabel_prefix_document = True 55 | 56 | 57 | # Add any paths that contain templates here, relative to this directory. 58 | templates_path = ['_templates'] 59 | 60 | # List of patterns, relative to source directory, that match files and 61 | # directories to ignore when looking for source files. 62 | # This pattern also affects html_static_path and html_extra_path. 63 | exclude_patterns = [] 64 | 65 | # ==== Customizations ==== 66 | 67 | # Disable displaying type annotations, these can be very verbose 68 | autodoc_typehints = 'none' 69 | 70 | # build the templated autosummary files 71 | autosummary_generate = True 72 | #numpydoc_show_class_members = False 73 | 74 | # Enable overriding of function signatures in the first line of the docstring. 75 | #autodoc_docstring_signature = True 76 | 77 | 78 | 79 | # -- Options for HTML output ------------------------------------------------- 80 | 81 | # The theme to use for HTML and HTML Help pages. See the documentation for 82 | # a list of builtin themes. 83 | # 84 | #html_theme = 'alabaster' 85 | html_theme = 'sphinx_rtd_theme' # addition 86 | 87 | # Add any paths that contain custom static files (such as style sheets) here, 88 | # relative to this directory. They are copied after the builtin static files, 89 | # so a file named "default.css" will overwrite the builtin "default.css". 90 | html_static_path = ['_static'] 91 | 92 | 93 | # ==== Customizations ==== 94 | 95 | # Called automatically by Sphinx, making this `conf.py` an "extension". 96 | def setup(app): 97 | # at the moment, we use custom.css to specify a maximum with for tables, 98 | # such as those generated by autosummary. 99 | app.add_css_file('custom.css') 100 | -------------------------------------------------------------------------------- /docs/source/examples/index.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ========= 3 | 4 | The examples site is in active development. Check back soon for more complete examples of how to use pytorch-minimize. 5 | 6 | Unconstrained minimization 7 | --------------------------- 8 | 9 | .. 
code-block:: python
10 | 
11 |     import torch
12 |     from torchmin import minimize
13 |     from torchmin.benchmarks import rosen
14 |     # initial point
15 |     x0 = torch.randn(100, device='cpu')
16 | 
17 |     # BFGS
18 |     result = minimize(rosen, x0, method='bfgs')
19 | 
20 |     # Newton Conjugate Gradient
21 |     result = minimize(rosen, x0, method='newton-cg')
22 | 
23 | Constrained minimization
24 | ---------------------------
25 | 
26 | For constrained optimization, the `adversarial examples tutorial <https://github.com/rfeinman/pytorch-minimize/blob/master/examples/constrained_optimization_adversarial_examples.ipynb>`_ demonstrates how to use trust-region constrained optimization to generate an optimal adversarial perturbation given a constraint on the perturbation norm.
27 | 
28 | Nonlinear least-squares
29 | ---------------------------
30 | 
31 | Coming soon.
32 | 
33 | 
34 | Scipy benchmark
35 | ---------------------------
36 | 
37 | The `SciPy benchmark <https://github.com/rfeinman/pytorch-minimize/blob/master/examples/scipy_benchmark.py>`_ provides a comparison of pytorch-minimize solvers to their analogous solvers from the :mod:`scipy.optimize` module.
38 | For those transitioning from scipy, this script will help get a feel for the design of the current library.
39 | Unlike scipy, jacobian and hessian functions need not be provided to pytorch-minimize solvers, and numerical approximations are never used.
40 | 
41 | 
42 | Minimizer (optimizer API)
43 | ---------------------------
44 | 
45 | Another way to use the optimization tools from pytorch-minimize is via :class:`torchmin.Minimizer`, a pytorch Optimizer class. For a demo on how to use the Minimizer class, see the `MNIST classifier <https://github.com/rfeinman/pytorch-minimize/blob/master/examples/train_mnist_Minimizer.py>`_ tutorial.
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Pytorch-minimize
2 | ================
3 | 
4 | Pytorch-minimize is a library for numerical optimization with automatic differentiation and GPU acceleration. It implements a number of canonical techniques for deterministic (or "full-batch") optimization not offered in the :mod:`torch.optim` module. The library is inspired heavily by SciPy's :mod:`optimize` module and MATLAB's `Optimization Toolbox <https://www.mathworks.com/products/optimization.html>`_. Unlike SciPy and MATLAB, which use numerical approximations of derivatives that are slow and often inaccurate, pytorch-minimize uses *real* first- and second-order derivatives, computed seamlessly behind the scenes with autograd. Both CPU and CUDA are supported.
5 | 
6 | :Author: Reuben Feinman
7 | :Version: 0.0.1
8 | 
9 | Pytorch-minimize is currently in Beta; expect the API to change before a first official release. Some of the source code was taken directly from SciPy and ported to PyTorch. As such, here is their copyright notice:
10 | 
11 | Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. All rights reserved.
12 | 
13 | 
14 | Table of Contents
15 | =================
16 | 
17 | .. toctree::
18 |     :maxdepth: 2
19 | 
20 |     install
21 | 
22 | .. toctree::
23 |     :maxdepth: 2
24 | 
25 |     user_guide/index
26 | 
27 | .. toctree::
28 |     :maxdepth: 2
29 | 
30 |     api/index
31 | 
32 | .. toctree::
33 |     :maxdepth: 2
34 | 
35 |     examples/index
36 | 
--------------------------------------------------------------------------------
/docs/source/install.rst:
--------------------------------------------------------------------------------
1 | Install
2 | ===========
3 | 
4 | To install pytorch-minimize, users may either 1) install the official PyPI release via pip, or 2) install a *bleeding edge* distribution from source.
5 | 6 | 7 | **Install via pip (official PyPI release)**:: 8 | 9 | pip install pytorch-minimize 10 | 11 | **Install from source (bleeding edge)**:: 12 | 13 | # clone the latest master to any location 14 | git clone https://github.com/rfeinman/pytorch-minimize.git 15 | 16 | # cd to the root directory and install the package with pip 17 | cd pytorch-minimize 18 | pip install -e . 19 | 20 | 21 | **PyTorch requirement** 22 | 23 | This library uses latest features from the actively-developed :mod:`torch.linalg` module. For maximum performance, users should install pytorch>=1.9, as it includes some new items not available in prior releases (e.g. `torch.linalg.cholesky_ex `_). Pytorch-minimize will automatically use these features when available. 24 | -------------------------------------------------------------------------------- /docs/source/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | =========== 2 | User Guide 3 | =========== 4 | 5 | .. currentmodule:: torchmin 6 | 7 | Using the :func:`minimize` function 8 | ------------------------------------ 9 | 10 | Coming soon. -------------------------------------------------------------------------------- /examples/constrained_optimization_adversarial_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "dried-niagara", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%matplotlib inline\n", 11 | "import matplotlib.pylab as plt\n", 12 | "import torch\n", 13 | "import torch.nn as nn\n", 14 | "import torch.nn.functional as F\n", 15 | "import torch.optim as optim\n", 16 | "from torch.utils.data import DataLoader\n", 17 | "from torchvision import transforms, datasets\n", 18 | "\n", 19 | "from torchmin import minimize_constr" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "whole-fifty", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "device = torch.device('cuda:0')\n", 30 | "\n", 31 | "root = '/path/to/data' # fill in torchvision dataset path\n", 32 | "train_data = datasets.MNIST(root, train=True, transform=transforms.ToTensor())\n", 33 | "train_loader = DataLoader(train_data, batch_size=128, shuffle=True)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "closed-interview", 39 | "metadata": {}, 40 | "source": [ 41 | "# Train CNN classifier" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "id": "following-knowing", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "def CNN():\n", 52 | " return nn.Sequential(\n", 53 | " nn.Conv2d(1, 10, kernel_size=5),\n", 54 | " nn.SiLU(),\n", 55 | " nn.AvgPool2d(2),\n", 56 | " nn.Conv2d(10, 20, kernel_size=5),\n", 57 | " nn.SiLU(),\n", 58 | " nn.AvgPool2d(2),\n", 59 | " nn.Dropout(0.2),\n", 60 | " nn.Flatten(1),\n", 61 | " nn.Linear(320, 50),\n", 62 | " nn.Dropout(0.2),\n", 63 | " nn.Linear(50, 10),\n", 64 | " nn.LogSoftmax(1)\n", 65 | " )" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "id": "accessory-killer", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "epoch 1 - loss: 0.4923\n", 79 | "epoch 2 - loss: 0.1428\n", 80 | "epoch 3 - loss: 0.1048\n", 81 | "epoch 4 - loss: 0.0883\n", 82 | "epoch 5 - loss: 0.0754\n", 83 | "epoch 6 - loss: 0.0672\n", 84 | "epoch 7 - loss: 0.0626\n", 85 | "epoch 8 - loss: 
0.0578\n", 86 | "epoch 9 - loss: 0.0524\n", 87 | "epoch 10 - loss: 0.0509\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "torch.manual_seed(382)\n", 93 | "net = CNN().to(device)\n", 94 | "optimizer = optim.Adam(net.parameters())\n", 95 | "for epoch in range(10):\n", 96 | " epoch_loss = 0\n", 97 | " for (x, y) in train_loader:\n", 98 | " x = x.to(device, non_blocking=True)\n", 99 | " y = y.to(device, non_blocking=True)\n", 100 | " logits = net(x)\n", 101 | " loss = F.nll_loss(logits, y)\n", 102 | " optimizer.zero_grad(set_to_none=True)\n", 103 | " loss.backward()\n", 104 | " optimizer.step()\n", 105 | " epoch_loss += loss.item() * x.size(0)\n", 106 | " print('epoch %2d - loss: %0.4f' % (epoch+1, epoch_loss / len(train_data)))" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "id": "therapeutic-elimination", 112 | "metadata": {}, 113 | "source": [ 114 | "# set up adversarial example environment" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 6, 120 | "id": "developing-afghanistan", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# evaluation mode settings\n", 125 | "net = net.requires_grad_(False).eval()\n", 126 | "\n", 127 | "# move net to CPU\n", 128 | "# Note: using CUDA-based inputs and objectives is allowed\n", 129 | "# but inefficient with trust-constr, as the data will be\n", 130 | "# moved back-and-forth from CPU\n", 131 | "net = net.cpu()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 7, 137 | "id": "mighty-realtor", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "def nll_objective(x, y):\n", 142 | " assert x.numel() == 28**2\n", 143 | " assert y.numel() == 1\n", 144 | " x = x.view(1, 1, 28, 28)\n", 145 | " return F.nll_loss(net(x), y.view(1))" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 8, 151 | "id": "better-nerve", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "# select a random image from the dataset\n", 156 | "torch.manual_seed(338)\n", 157 | "x, y = next(iter(train_loader))\n", 158 | "img = x[0]\n", 159 | "label = y[0]" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 9, 165 | "id": "presidential-astrology", 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "tensor(1.4663e-05)" 172 | ] 173 | }, 174 | "execution_count": 9, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "nll_objective(img, label)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 10, 186 | "id": "independent-slovenia", 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# minimization objective for adversarial examples\n", 191 | "# goal is to maximize NLL of perturbed image (image + perturbation)\n", 192 | "fn = lambda eps: - nll_objective(img + eps, label)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 18, 198 | "id": "bacterial-champagne", 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "# plotting utility\n", 203 | "\n", 204 | "def plot_distortion(img, eps, y):\n", 205 | " assert img.numel() == 28**2\n", 206 | " assert eps.numel() == 28**2\n", 207 | " img = img.view(28, 28)\n", 208 | " img_ = img + eps.view(28, 28)\n", 209 | " fig, axes = plt.subplots(1,2,figsize=(4,2))\n", 210 | " for i, x in enumerate((img, img_)):\n", 211 | " axes[i].imshow(x.cpu(), cmap=plt.cm.binary)\n", 212 | " axes[i].set_xticks([])\n", 213 | " 
axes[i].set_yticks([])\n", 214 | " axes[i].set_title('nll: %0.4f' % nll_objective(x, y))\n", 215 | " plt.show()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "id": "ambient-thread", 221 | "metadata": {}, 222 | "source": [ 223 | "# craft adversarial example\n", 224 | "\n", 225 | "We will use our constrained optimizer to find the optimal unit-norm purturbation $\\epsilon$ \n", 226 | "\n", 227 | "\\begin{equation}\n", 228 | "\\max_{\\epsilon} NLL(x + \\epsilon) \\quad \\text{s.t.} \\quad ||\\epsilon|| = 1\n", 229 | "\\end{equation}" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 13, 235 | "id": "surprised-symposium", 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "torch.manual_seed(227)\n", 240 | "eps0 = torch.randn_like(img)\n", 241 | "eps0 /= eps0.norm()" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 14, 247 | "id": "missing-bargain", 248 | "metadata": {}, 249 | "outputs": [ 250 | { 251 | "data": { 252 | "text/plain": [ 253 | "-2.2291887944447808e-05" 254 | ] 255 | }, 256 | "execution_count": 14, 257 | "metadata": {}, 258 | "output_type": "execute_result" 259 | } 260 | ], 261 | "source": [ 262 | "fn(eps0).item()" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 15, 268 | "id": "miniature-fight", 269 | "metadata": { 270 | "scrolled": true 271 | }, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "`xtol` termination condition is satisfied.\n", 278 | "Number of iterations: 32, function evaluations: 50, CG iterations: 52, optimality: 1.02e-04, constraint violation: 0.00e+00, execution time: 0.57 s.\n" 279 | ] 280 | } 281 | ], 282 | "source": [ 283 | "res = minimize_constr(\n", 284 | " fn, eps0, \n", 285 | " max_iter=100,\n", 286 | " constr=dict(\n", 287 | " fun=lambda x: x.square().sum(), \n", 288 | " lb=1, ub=1\n", 289 | " ),\n", 290 | " disp=1\n", 291 | ")" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 16, 297 | "id": "wanted-journal", 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "tensor(1.)\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "eps = res.x\n", 310 | "print(eps.norm())" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 19, 316 | "id": "spanish-wright", 317 | "metadata": { 318 | "scrolled": true 319 | }, 320 | "outputs": [ 321 | { 322 | "data": { 323 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPEAAACHCAYAAADHsL/VAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOOUlEQVR4nO2de4xV1RXGvyXKGwVkBDsRBiFSA5QmQJsoKg8fRSEytdSmQCltItgYYgAjoaFgbXnUxCb9g2hMFCehtYrQClWIMRBBQyoJCExiQQggwQoDhfASVHb/uHdO916ds889973nfr9kkrVY++yz7xzWnPXdvc8+YowBISRcrqn0AAghhcEkJiRwmMSEBA6TmJDAYRITEjhMYkICpyqTWETGisgxyz8sIvdWckyk+PA6F4eqTOJCkAwrReRU9ucPIiKe9hNE5BMRuSgiW0RkQK59iUhD9piL2T7uVX3/VESOiMgFEfmbiPQuzaeuPdJcZxHpKCJrs38kjIiMVfGeIvKqiJzI/iy1Yv1F5Lz6MSIyPxtfpGKXROSqiPQp4cd3aHdJDOAxAFMAjADwHQCTAMxuq2H2F70OwGIAvQHsBPDXFH39BcAuADcC+DWAtSJSl+17KIAXAcwA0BfARQCrCv94JEvO1znLdgDTAfy7jdgfAXQF0ADgewBmiMgsADDGHDXGdG/9ATAcwFUAb2bjy1R8JYCtxpiWwj9ijhhjKvID4DCABQD2ADiLTPJ0zsbGAjim2t6bY78fAnjM8n8JYEdM28cAfGj53QBcAvDtpL4A3AbgMoAeVnwbgDlZexmAP1uxQQCu2O1r4acarrM67hiAserfWgCMtvxFALbFHL8EwJaYmAA4CGBmOX/Hlb4T/xjADwAMROav6c+TDhCRMSJyxtNkKICPLf/j7L8ltjXGXEDmIgxtK676GgrgkDHmnCdu930QmSS+zTP29kqlr3MuiLKHxbT7GYBXY2J3IVN1vVnAOFJT6ST+kzHmuDHmNIANAL6bdIAxZrsxpqenSXdk/uK3chZA9xi9pNu2tu+RQ19pj9XxWqLS1zmJTQAWikgPERkM4BfIlNcOItKapGtj+pkJYK0x5nweY8ibSiexrU8uInNhCuU8gOst/3oA50223klo29r+XEzc7ivtsTpeS1T6OicxFxkZdQDA35H5ruNYG+1mAnizrSQVkS4ApiL+Ll0yKp3EpaAZmS87WhmR/bfEtiLSDRnt2txWXPXVDOBWEenhidt93wqgE4D9KT4LiSfNdfZijDltjJlmjOlnjBmKTF78026TQ5L+EMBpAFvzGUMhtMckbgIwT0TqReRbAOYDWB3Tdj2AYSLyiIh0BvAbAHuMMZ8k9WWM2Q9gN4AlItJZRBqR0XutemgNgMkiclf2j8NvAaxTGprkT5rrDBHplL3GANAxe80kGxskIjeKSAcRmYjMF56/U100AjgDYEvMKWYCaMqzEiiI4JI4mxQ+zfEiMrprL4B9AP6R/bfW45tFZBoAGGNOAngEwO8B/AfA9wH8JNe+sm1HZY9dAeBH2T5hjGkGMAeZZD6BjBb+VV4fugYp5nXO8i9kSuZ6AJuzduuagJHZfs4BWA5gWvb62cQmqYjUAxiPzB+WsiMV+MNBCCkiwd2JCSEuTGJCAodJTEjgMIkJCRwmMSGBc22axn369DENDQ0lGgrJh8OHD6OlpSWfpYZt0qtXL1NfX1+s7oIlv9WbfgqZCWpubm4xxtS1FUuVxA0NDdi5c2feAyHFZ9SoUUXtr76+HuvWrcupbTmmJ0uRTJU6byG/ryFDhhyJi7GcJiRwUt2JSW2T5k6Spq3vrqf7KdWd2ddvuaqBfO/UvBMTEjhMYkICh+U0KQm6BLVLxTTlaZq2hZT7+jy2X6ySvlRfBPJOTEjgMIkJCRwmMSGBQ01MysI11xTnfuHTlTqWpq1PE2vy1fRJmjhfrc07MSGBwyQmJHCYxIQEDjUxyRnf3K+O+zRmkvaz+03Sub6233zzTU7naMvPdbwdOnSIjQHpvgvgsktCahQmMSGBw3Ka5I2vZNZlpl1WJpXTV69ejWxdEieV9HHn1P2mmY7S+EroUi0p9cE7MSGBwyQmJHCYxIQEDjUx8ZLmEUJbg2rdeO21//uvlqRrv/7668i2dWwShehRn6+1te3rWLl2InHGU/IzEEJKCpOYkMBp1+X0jh07IvvTTz91YnbJBgCzZs0q+vl1nzNmzHD8cePGFf2c5USXknYJbZfPOqbR18KeVtJTTLq83rNnT2QfO3bMiX311VeOv2TJkshOmlLySQObKVOmOP6DDz7o+L4thYtVevNOTEjgMIkJCRwmMSGBUxWa+PLly5H99NNPOzGtZdOwd+/eyP7ss8+8bUsxFbB69WrHf/vttx3/tddei+yxY8cW/fzFJs20jObSpUuR/fzzzzuxgwcPOr6tg5O0q33syZMnnZg+1tbpSVNXtg72aeKNGzc6/vvvv+/4y5Yti+zRo0d7z5kvvBMTEjhMYkICh0lMSOBURBNv2LDB8desWRPZr7/+ermHUzZOnDjh+J9//nmFRpI7ue7WAbga9L333nNi9jV/5513nJhvBw59Dt9yTh3zPW6YtONGro9Oam2tdbl9zZN0eJrHNZ3jcm5JCKlKmMSEBE7ZyunNmzdH9syZM53YmTNninKOAQMGOL6vZOrYsaPjv/LKK7Ft9dTQc889F9n29FgSd955p+Pfd999OR9bjehyddu2bZG9aNEiJ3b27NnI1mWlr2TWJfLNN98c21b3q4995plnYs+5fft2x29qaops37JQHRs+fLjj29NKpXrCiXdiQgKHSUxI4DCJCQmcsmli+6v3QjTwnDlzIruurs6JPfXUU47fvXv3vM9jo7XMyy+/HNnHjx/PuR/9XUCfPn0KG1iJiNvNI0nLnjp1KrJtDazRjynq7y6mTp0a2b169XJi06dPd/zOnTvHjk9fN9+OHJq33norsltaWpyYbzfOhx56yPFvuOGGyC5k2aoP3okJCRwmMSGBwyQmJHAqsuxSa6IRI0ZEtta5+rG1gQMHRnanTp1KMDrgwIEDjv/GG284vk8H2zrcfgwNKM0WQKUmzYu6bfQc7ZAhQyK7d+/eTmz+/PmO369fv9h+9Lys/YhjmmWN+tFUWwMD7pLY6667zol17do1sh9//HEn1tjYGHvOYmlgDe/EhAQOk5iQwClbOT1y5MjI1iXyE088Ua5h5MTs2bMdf+vWrTkfu3LlysjWpVZ75/bbb4/sBQsWOLFHH300spPeI2yXzHpZq25r95U0xWSzePFix7d3zQSALl26RLaejpo3b15k690u07wjWcP3ExNSozCJCQkcJjEhgVM2TWzrJduuFuxpo0KWhU6cOLEIowkDPUUyaNCgyB48eLATs/Wq1q5aN9q+fouDnmKy40kvDj99+nRkX7hwwYnZGhgAunXrFtl6isl+c4eeAtOfxY4nLfWkJiakRmESExI4TGJCAqcq3gBRDdg7MO7evbtyA2kn+Lbc0drQt/wwSRPb88haa2st+9FHH0W2XnbZs2
dPx7e3b9JjsHVvmqWeSZrY7iuNPuadmJDAYRITEjg1W07rqYBdu3bl1c+wYcMc337CpdbwvZjbfnJNP8XmK5HTLNHU5ao9TQS4U0x9+/Z1Yrr0tkvoW265xYnZ4/ftHgK4vwctG9IsE/XBOzEhgcMkJiRwmMSEBE7NamK97G7VqlU5H2trJL3rx0033VTYwKqAvHddtPSgfsOGvaxR78iiteGXX34ZOxbt2+exd74E/l+Xv/vuu5GtNbDW5T169IjspUuXOjHf9x6l2r3DB+/EhAQOk5iQwGESExI4NauJC9l5cvLkyZFt7+LY3knSe/b8qda9tl7V2lVj96PnXbXWtueN9Tn1Fjy2DrZ3yQSAc+fOOb79Bkt7900g/g0ZbY037jggeclmrvBOTEjgMIkJCZyaKaf15uAffPBBzsfaOzUCwPLly4syptDxTZ/oUtGewtFPBWl8O0TqqSG79NYvCrefWgLckvnixYtO7I477nB8e8dT38vBk57Iso/Nd1llErwTExI4TGJCAodJTEjg1Iwm1rpWv1nARr+cfNq0aY5vL8mrJZKmSHy7bNgx/Siir9+kNyjYx+o3iehppCtXrkS2XpL5wAMPOL59jfV47WkurYmTdty04QvVCCEAmMSEBE+7Kqf379/v+C+88EJk26VUEnpz+0mTJhU2sMCIKwGTSkXb9/2+fdMw2tdtjx496vhr166NbL3pv29Ter0jy4QJExzftxOJXUInTRvZ4+cUEyGkTZjEhAQOk5iQwGlXmviee+5x/C+++CKvflasWFGM4bQLfC/x9r0Yzadzk/rxLWucPn2647e0tMSOXetp+ymnhQsXOjG9M6aPUmnbfPUz78SEBA6TmJDAYRITEjjBaeJ9+/ZFtr3DBgCcOnUq537sOWQAuP/++yNb7+RAMiQtE8x1GWHS43uHDh2K7CeffNKJ+V4Ar/t99tlnHf/uu++ObH2NfS8LT9Lw+VIsbc07MSGBwyQmJHCCK6dfeumlyD5y5Eje/dTV1Tl+Q0ND3n21Z3wlsl6OaJeZhZScGzdujGwtkfQ5fcsjdcncv3//yNalrN483vdZ8i2DS7WRPO/EhAQOk5iQwGESExI4Va+J169f7/hNTU159fPwww87fq09XpiGuOV/SY8Q2nE9ZRPXDgC2bNni+Js2bYpsrXN9fY0bN86JjRkzxvFt3Zs0bVSsXSrTLKXM+0V2eR1FCKkamMSEBA6TmJDAqXpNXCzmzp3r+PpNAqRtfJrOp+HSxLR+tpdPptlNcurUqbHnBJLfPOE7TyngbpeEEABMYkKCp+rL6cbGRq9PykeaKaakY23Gjx/v9X3nzDWWllIskeSyS0JImzCJCQkcJjEhgSNpdISInASQ//N/pBQMMMbUJTfLDV7jqiX2OqdKYkJI9cFympDAYRITEjhMYkICh0lMSOAwiQkJHCYxIYHDJCYkcJjEhAQOk5iQwPkvxBBGQYvNKc4AAAAASUVORK5CYII=\n", 324 | "text/plain": [ 325 | "
" 326 | ] 327 | }, 328 | "metadata": {}, 329 | "output_type": "display_data" 330 | } 331 | ], 332 | "source": [ 333 | "plot_distortion(img.detach(), eps, label)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "id": "varying-commission", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [] 343 | } 344 | ], 345 | "metadata": { 346 | "kernelspec": { 347 | "display_name": "Python 3", 348 | "language": "python", 349 | "name": "python3" 350 | }, 351 | "language_info": { 352 | "codemirror_mode": { 353 | "name": "ipython", 354 | "version": 3 355 | }, 356 | "file_extension": ".py", 357 | "mimetype": "text/x-python", 358 | "name": "python", 359 | "nbconvert_exporter": "python", 360 | "pygments_lexer": "ipython3", 361 | "version": "3.8.8" 362 | } 363 | }, 364 | "nbformat": 4, 365 | "nbformat_minor": 5 366 | } -------------------------------------------------------------------------------- /examples/scipy_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | A comparison of pytorch-minimize solvers to the analogous solvers from 3 | scipy.optimize. 4 | 5 | Pytorch-minimize uses autograd to compute 1st- and 2nd-order derivatives 6 | implicitly, therefore derivative functions need not be provided or known. 7 | In contrast, scipy.optimize requires that they be provided, or else it will 8 | use imprecise numerical approximations. For fair comparison I am providing 9 | derivative functions to scipy.optimize in this script. In general, however, 10 | we will not have access to these functions, so applications of scipy.optimize 11 | are far more limited. 12 | 13 | """ 14 | import torch 15 | from torchmin import minimize 16 | from torchmin.benchmarks import rosen 17 | from scipy import optimize 18 | 19 | # Many scipy optimizers convert the data to double-precision, so 20 | # we will use double precision in torch for a fair comparison 21 | torch.set_default_dtype(torch.float64) 22 | 23 | 24 | def print_header(title, num_breaks=1): 25 | print('\n'*num_breaks + '='*50) 26 | print(' '*20 + title) 27 | print('='*50 + '\n') 28 | 29 | 30 | def main(): 31 | torch.manual_seed(991) 32 | x0 = torch.randn(100) 33 | x0_np = x0.numpy() 34 | 35 | print('\ninitial loss: %0.4f\n' % rosen(x0)) 36 | 37 | 38 | # ---- BFGS ---- 39 | print_header('BFGS') 40 | 41 | print('-'*19 + ' pytorch ' + '-'*19) 42 | res = minimize(rosen, x0, method='bfgs', tol=1e-5, disp=True) 43 | 44 | print('\n' + '-'*20 + ' scipy ' + '-'*20) 45 | res = optimize.minimize( 46 | optimize.rosen, x0_np, 47 | method='bfgs', 48 | jac=optimize.rosen_der, 49 | tol=1e-5, 50 | options=dict(disp=True) 51 | ) 52 | 53 | 54 | # ---- Newton CG ---- 55 | print_header('Newton-CG') 56 | 57 | print('-'*19 + ' pytorch ' + '-'*19) 58 | res = minimize(rosen, x0, method='newton-cg', tol=1e-5, disp=True) 59 | 60 | print('\n' + '-'*20 + ' scipy ' + '-'*20) 61 | res = optimize.minimize( 62 | optimize.rosen, x0_np, 63 | method='newton-cg', 64 | jac=optimize.rosen_der, 65 | hessp=optimize.rosen_hess_prod, 66 | tol=1e-5, 67 | options=dict(disp=True) 68 | ) 69 | 70 | 71 | # ---- Newton Exact ---- 72 | # NOTE: Scipy does not have a precise analogue to "newton-exact," but they 73 | # have something very close called "trust-exact." Like newton-exact, 74 | # trust-exact also uses Cholesky factorization of the explicit Hessian 75 | # matrix. 
However, whereas newton-exact first computes the newton direction 76 | # and then uses line search to determine a step size, trust-exact first 77 | # specifies a step size boundary and then solves for the optimal newton 78 | # step within this boundary (a constrained optimization problem). 79 | 80 | print_header('Newton-Exact') 81 | 82 | print('-'*19 + ' pytorch ' + '-'*19) 83 | res = minimize(rosen, x0, method='newton-exact', tol=1e-5, disp=True) 84 | 85 | print('\n' + '-'*20 + ' scipy ' + '-'*20) 86 | res = optimize.minimize( 87 | optimize.rosen, x0_np, 88 | method='trust-exact', 89 | jac=optimize.rosen_der, 90 | hess=optimize.rosen_hess, 91 | options=dict(gtol=1e-5, disp=True) 92 | ) 93 | 94 | print() 95 | 96 | 97 | if __name__ == '__main__': 98 | main() -------------------------------------------------------------------------------- /examples/train_mnist_Minimizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision import datasets 7 | 8 | from torchmin import Minimizer 9 | 10 | 11 | def MLPClassifier(input_size, hidden_sizes, num_classes): 12 | layers = [] 13 | for i, hidden_size in enumerate(hidden_sizes): 14 | layers.append(nn.Linear(input_size, hidden_size)) 15 | layers.append(nn.ReLU()) 16 | input_size = hidden_size 17 | layers.append(nn.Linear(input_size, num_classes)) 18 | layers.append(nn.LogSoftmax(-1)) 19 | 20 | return nn.Sequential(*layers) 21 | 22 | 23 | @torch.no_grad() 24 | def evaluate(model): 25 | train_output = model(X_train) 26 | test_output = model(X_test) 27 | train_loss = F.nll_loss(train_output, y_train) 28 | test_loss = F.nll_loss(test_output, y_test) 29 | print('Loss (cross-entropy):\n train: {:.4f} - test: {:.4f}'.format(train_loss, test_loss)) 30 | train_accuracy = (train_output.argmax(-1) == y_train).float().mean() 31 | test_accuracy = (test_output.argmax(-1) == y_test).float().mean() 32 | print('Accuracy:\n train: {:.4f} - test: {:.4f}'.format(train_accuracy, test_accuracy)) 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--mnist_root', type=str, required=True, 38 | help='root path for the MNIST dataset') 39 | parser.add_argument('--method', type=str, default='newton-cg', 40 | help='optimization method to use') 41 | parser.add_argument('--device', type=str, default='cpu', 42 | help='device to use for training') 43 | parser.add_argument('--quiet', action='store_true', 44 | help='whether to train in quiet mode (no loss printing)') 45 | parser.add_argument('--plot_weight', action='store_true', 46 | help='whether to plot the learned weights') 47 | args = parser.parse_args() 48 | 49 | device = torch.device(args.device) 50 | 51 | 52 | # -------------------------------------------- 53 | # Load MNIST dataset 54 | # -------------------------------------------- 55 | 56 | train_data = datasets.MNIST(args.mnist_root, train=True) 57 | X_train = (train_data.data.float().view(-1, 784) / 255.).to(device) 58 | y_train = train_data.targets.to(device) 59 | 60 | test_data = datasets.MNIST(args.mnist_root, train=False) 61 | X_test = (test_data.data.float().view(-1, 784) / 255.).to(device) 62 | y_test = test_data.targets.to(device) 63 | 64 | 65 | # -------------------------------------------- 66 | # Initialize model 67 | # -------------------------------------------- 68 | mlp = MLPClassifier(784, hidden_sizes=[50], num_classes=10) 69 | 
mlp = mlp.to(device) 70 | 71 | print('-------- Initial evaluation ---------') 72 | evaluate(mlp) 73 | 74 | 75 | # -------------------------------------------- 76 | # Fit model with Minimizer 77 | # -------------------------------------------- 78 | optimizer = Minimizer(mlp.parameters(), 79 | method=args.method, 80 | tol=1e-6, 81 | max_iter=200, 82 | disp=0 if args.quiet else 2) 83 | 84 | def closure(): 85 | optimizer.zero_grad() 86 | output = mlp(X_train) 87 | loss = F.nll_loss(output, y_train) 88 | # loss.backward() <-- do not call backward! 89 | return loss 90 | 91 | loss = optimizer.step(closure) 92 | 93 | # -------------------------------------------- 94 | # Evaluate fitted model 95 | # -------------------------------------------- 96 | print('-------- Final evaluation ---------') 97 | evaluate(mlp) 98 | 99 | if args.plot_weight: 100 | weight = mlp[0].weight.data.cpu().view(-1, 28, 28) 101 | vmin, vmax = weight.min(), weight.max() 102 | fig, axes = plt.subplots(4, 4, figsize=(6, 6)) 103 | axes = axes.ravel() 104 | for i in range(len(axes)): 105 | axes[i].matshow(weight[i], cmap='gray', vmin=0.5 * vmin, vmax=0.5 * vmax) 106 | axes[i].set_xticks(()) 107 | axes[i].set_yticks(()) 108 | plt.show() -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=46.4.0", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.0 2 | scipy>=1.6 3 | torch>=1.9.0 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | version = attr: torchmin.__version__ 3 | long_description = file: README.md 4 | long_description_content_type = text/markdown -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # find packages 4 | packages = find_packages(exclude=("tests", "tests.*")) 5 | 6 | # setup 7 | setup( 8 | name='pytorch-minimize', 9 | description='Newton and Quasi-Newton optimization with PyTorch', 10 | url='https://pytorch-minimize.readthedocs.io', 11 | author='Reuben Feinman', 12 | author_email='reuben.feinman@nyu.edu', 13 | license='MIT Licence', 14 | packages=packages, 15 | zip_safe=False, 16 | install_requires=[ 17 | 'numpy>=1.18.0', 18 | 'scipy>=1.6', 19 | 'torch>=1.9.0' 20 | ], 21 | python_requires=">=3.7", 22 | classifiers=[ 23 | "Programming Language :: Python :: 3", 24 | "License :: OSI Approved :: MIT License", 25 | "Operating System :: OS Independent", 26 | ] 27 | ) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rfeinman/pytorch-minimize/6b85e5be8ad1e7689591c02bb73dc66515520045/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_imports.py: -------------------------------------------------------------------------------- 1 | def test_import_packages(): 2 | """Test that importing works.""" 3 | import torchmin 4 | 
from torchmin import minimize 5 | from torchmin import Minimizer 6 | -------------------------------------------------------------------------------- /tests/torchmin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rfeinman/pytorch-minimize/6b85e5be8ad1e7689591c02bb73dc66515520045/tests/torchmin/__init__.py -------------------------------------------------------------------------------- /tests/torchmin/test_leastsquares.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torchmin import minimize 5 | 6 | torch.manual_seed(42) 7 | N = 100 8 | D = 7 9 | M = 5 10 | X = torch.randn(N, D) 11 | Y = torch.randn(N, M) 12 | trueB = torch.linalg.inv(X.T @ X) @ X.T @ Y 13 | all_methods = [ 14 | 'bfgs', 'l-bfgs', 'cg', 'newton-cg', 'newton-exact', 15 | 'trust-ncg', 'trust-krylov', 'trust-exact', 'dogleg'] 16 | 17 | 18 | @pytest.mark.parametrize('method', all_methods) 19 | def test_minimize(method): 20 | """Test least-squares problem on unconstrained minimizers.""" 21 | B0 = torch.zeros(D, M) 22 | 23 | def leastsquares_obj(B): 24 | return torch.sum((Y - X @ B) ** 2) 25 | 26 | result = minimize(leastsquares_obj, B0, method=method) 27 | torch.testing.assert_close(trueB, result.x, rtol=1e-4, atol=1e-4) 28 | -------------------------------------------------------------------------------- /tests/torchmin/test_minimize_constr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torchmin import minimize, minimize_constr 5 | from torchmin.benchmarks import rosen 6 | 7 | 8 | def test_rosen(): 9 | """Test Rosenbrock problem with constraints.""" 10 | 11 | x0 = torch.tensor([1., 8.]) 12 | res = minimize( 13 | rosen, x0, 14 | method='l-bfgs', 15 | options=dict(line_search='strong-wolfe'), 16 | max_iter=50, 17 | disp=0 18 | ) 19 | 20 | 21 | # Test inactive constraints 22 | 23 | res_constrained_sum = minimize_constr( 24 | rosen, x0, 25 | constr=dict(fun=lambda x: x.sum(), ub=10.), 26 | max_iter=50, 27 | disp=0 28 | ) 29 | torch.testing.assert_close( 30 | res.x, res_constrained_sum.x, rtol=1e-2, atol=1e-2) 31 | 32 | res_constrained_norm = minimize_constr( 33 | rosen, x0, 34 | constr=dict(fun=lambda x: x.square().sum(), ub=10.), 35 | max_iter=50, 36 | disp=0 37 | ) 38 | torch.testing.assert_close( 39 | res.x, res_constrained_norm.x, rtol=1e-2, atol=1e-2) 40 | 41 | 42 | # Test active constraints 43 | 44 | res_constrained_sum = minimize_constr( 45 | rosen, x0, 46 | constr=dict(fun=lambda x: x.sum(), ub=1.), 47 | max_iter=50, 48 | disp=0 49 | ) 50 | assert res_constrained_sum.x.sum() <= 1. 51 | res_constrained_norm = minimize_constr( 52 | rosen, x0, 53 | constr=dict(fun=lambda x: x.square().sum(), ub=1.), 54 | max_iter=50, 55 | disp=0 56 | ) 57 | assert res_constrained_norm.x.square().sum() <= 1. 
58 | -------------------------------------------------------------------------------- /torchmin/__init__.py: -------------------------------------------------------------------------------- 1 | from .minimize import minimize 2 | from .minimize_constr import minimize_constr 3 | from .lstsq import least_squares 4 | from .optim import Minimizer, ScipyMinimizer 5 | 6 | __all__ = ['minimize', 'minimize_constr', 'least_squares', 7 | 'Minimizer', 'ScipyMinimizer'] 8 | 9 | __version__ = "0.0.2" -------------------------------------------------------------------------------- /torchmin/benchmarks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | __all__ = ['rosen', 'rosen_der', 'rosen_hess', 'rosen_hess_prod'] 4 | 5 | 6 | # ============================= 7 | # Rosenbrock function 8 | # ============================= 9 | 10 | 11 | def rosen(x, reduce=True): 12 | val = 100. * (x[...,1:] - x[...,:-1]**2)**2 + (1 - x[...,:-1])**2 13 | if reduce: 14 | return val.sum() 15 | else: 16 | # don't reduce batch dimensions 17 | return val.sum(-1) 18 | 19 | 20 | def rosen_der(x): 21 | xm = x[..., 1:-1] 22 | xm_m1 = x[..., :-2] 23 | xm_p1 = x[..., 2:] 24 | der = torch.zeros_like(x) 25 | der[..., 1:-1] = (200 * (xm - xm_m1**2) - 26 | 400 * (xm_p1 - xm**2) * xm - 2 * (1 - xm)) 27 | der[..., 0] = -400 * x[..., 0] * (x[..., 1] - x[..., 0]**2) - 2 * (1 - x[..., 0]) 28 | der[..., -1] = 200 * (x[..., -1] - x[..., -2]**2) 29 | return der 30 | 31 | 32 | def rosen_hess(x): 33 | H = torch.diag_embed(-400*x[..., :-1], 1) - \ 34 | torch.diag_embed(400*x[..., :-1], -1) 35 | diagonal = torch.zeros_like(x) 36 | diagonal[..., 0] = 1200*x[..., 0].square() - 400*x[..., 1] + 2 37 | diagonal[..., -1] = 200 38 | diagonal[..., 1:-1] = 202 + 1200*x[..., 1:-1].square() - 400*x[..., 2:] 39 | H.diagonal(dim1=-2, dim2=-1).add_(diagonal) 40 | return H 41 | 42 | 43 | def rosen_hess_prod(x, p): 44 | Hp = torch.zeros_like(x) 45 | Hp[..., 0] = (1200 * x[..., 0]**2 - 400 * x[..., 1] + 2) * p[..., 0] - \ 46 | 400 * x[..., 0] * p[..., 1] 47 | Hp[..., 1:-1] = (-400 * x[..., :-2] * p[..., :-2] + 48 | (202 + 1200 * x[..., 1:-1]**2 - 400 * x[..., 2:]) * p[..., 1:-1] - 49 | 400 * x[..., 1:-1] * p[..., 2:]) 50 | Hp[..., -1] = -400 * x[..., -2] * p[..., -2] + 200*p[..., -1] 51 | return Hp -------------------------------------------------------------------------------- /torchmin/bfgs.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | from torch import Tensor 4 | from scipy.optimize import OptimizeResult 5 | 6 | from .function import ScalarFunction 7 | from .line_search import strong_wolfe 8 | 9 | try: 10 | from scipy.optimize.optimize import _status_message 11 | except ImportError: 12 | from scipy.optimize._optimize import _status_message 13 | 14 | class HessianUpdateStrategy(ABC): 15 | def __init__(self): 16 | self.n_updates = 0 17 | 18 | @abstractmethod 19 | def solve(self, grad): 20 | pass 21 | 22 | @abstractmethod 23 | def _update(self, s, y, rho_inv): 24 | pass 25 | 26 | def update(self, s, y): 27 | rho_inv = y.dot(s) 28 | if rho_inv <= 1e-10: 29 | # curvature is negative; do not update 30 | return 31 | self._update(s, y, rho_inv) 32 | self.n_updates += 1 33 | 34 | 35 | class L_BFGS(HessianUpdateStrategy): 36 | def __init__(self, x, history_size=100): 37 | super().__init__() 38 | self.y = [] 39 | self.s = [] 40 | self.rho = [] 41 | self.H_diag = 1. 
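        # pre-allocated buffer for the step coefficients of the L-BFGS two-loop recursion in solve()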
42 | self.alpha = x.new_empty(history_size) 43 | self.history_size = history_size 44 | 45 | def solve(self, grad): 46 | mem_size = len(self.y) 47 | d = grad.neg() 48 | for i in reversed(range(mem_size)): 49 | self.alpha[i] = self.s[i].dot(d) * self.rho[i] 50 | d.add_(self.y[i], alpha=-self.alpha[i]) 51 | d.mul_(self.H_diag) 52 | for i in range(mem_size): 53 | beta_i = self.y[i].dot(d) * self.rho[i] 54 | d.add_(self.s[i], alpha=self.alpha[i] - beta_i) 55 | 56 | return d 57 | 58 | def _update(self, s, y, rho_inv): 59 | if len(self.y) == self.history_size: 60 | self.y.pop(0) 61 | self.s.pop(0) 62 | self.rho.pop(0) 63 | self.y.append(y) 64 | self.s.append(s) 65 | self.rho.append(rho_inv.reciprocal()) 66 | self.H_diag = rho_inv / y.dot(y) 67 | 68 | 69 | class BFGS(HessianUpdateStrategy): 70 | def __init__(self, x, inverse=True): 71 | super().__init__() 72 | self.inverse = inverse 73 | if inverse: 74 | self.I = torch.eye(x.numel(), device=x.device, dtype=x.dtype) 75 | self.H = self.I.clone() 76 | else: 77 | self.B = torch.eye(x.numel(), device=x.device, dtype=x.dtype) 78 | 79 | def solve(self, grad): 80 | if self.inverse: 81 | return torch.matmul(self.H, grad.neg()) 82 | else: 83 | return torch.cholesky_solve(grad.neg().unsqueeze(1), 84 | torch.linalg.cholesky(self.B)).squeeze(1) 85 | 86 | def _update(self, s, y, rho_inv): 87 | rho = rho_inv.reciprocal() 88 | if self.inverse: 89 | if self.n_updates == 0: 90 | self.H.mul_(rho_inv / y.dot(y)) 91 | R = torch.addr(self.I, s, y, alpha=-rho) 92 | torch.addr( 93 | torch.linalg.multi_dot((R, self.H, R.t())), 94 | s, s, alpha=rho, out=self.H) 95 | else: 96 | if self.n_updates == 0: 97 | self.B.mul_(rho * y.dot(y)) 98 | Bs = torch.mv(self.B, s) 99 | self.B.addr_(y, y, alpha=rho) 100 | self.B.addr_(Bs, Bs, alpha=-1./s.dot(Bs)) 101 | 102 | 103 | @torch.no_grad() 104 | def _minimize_bfgs_core( 105 | fun, x0, lr=1., low_mem=False, history_size=100, inv_hess=True, 106 | max_iter=None, line_search='strong-wolfe', gtol=1e-5, xtol=1e-9, 107 | normp=float('inf'), callback=None, disp=0, return_all=False): 108 | """Minimize a multivariate function with BFGS or L-BFGS. 109 | 110 | We choose from BFGS/L-BFGS with the `low_mem` argument. 111 | 112 | Parameters 113 | ---------- 114 | fun : callable 115 | Scalar objective function to minimize 116 | x0 : Tensor 117 | Initialization point 118 | lr : float 119 | Step size for parameter updates. If using line search, this will be 120 | used as the initial step size for the search. 121 | low_mem : bool 122 | Whether to use L-BFGS, the "low memory" variant of the BFGS algorithm. 123 | history_size : int 124 | History size for L-BFGS hessian estimates. Ignored if `low_mem=False`. 125 | inv_hess : bool 126 | Whether to parameterize the inverse hessian vs. the hessian with BFGS. 127 | Ignored if `low_mem=True` (L-BFGS always parameterizes the inverse). 128 | max_iter : int, optional 129 | Maximum number of iterations to perform. Defaults to 200 * x0.numel() 130 | line_search : str 131 | Line search specifier. Currently the available options are 132 | {'none', 'strong-wolfe'}. 133 | gtol : float 134 | Termination tolerance on 1st-order optimality (gradient norm). 135 | xtol : float 136 | Termination tolerance on function/parameter changes. 137 | normp : Number or str 138 | The norm type to use for termination conditions. Can be any value 139 | supported by `torch.norm` p argument. 140 | callback : callable, optional 141 | Function to call after each iteration with the current parameter 142 | state, e.g. ``callback(x)``. 
143 | disp : int or bool 144 | Display (verbosity) level. Set to >0 to print status messages. 145 | return_all : bool, optional 146 | Set to True to return a list of the best solution at each of the 147 | iterations. 148 | 149 | Returns 150 | ------- 151 | result : OptimizeResult 152 | Result of the optimization routine. 153 | """ 154 | lr = float(lr) 155 | disp = int(disp) 156 | if max_iter is None: 157 | max_iter = x0.numel() * 200 158 | if low_mem and not inv_hess: 159 | raise ValueError('inv_hess=False is not available for L-BFGS.') 160 | 161 | # construct scalar objective function 162 | sf = ScalarFunction(fun, x0.shape) 163 | closure = sf.closure 164 | if line_search == 'strong-wolfe': 165 | dir_evaluate = sf.dir_evaluate 166 | 167 | # compute initial f(x) and f'(x) 168 | x = x0.detach().view(-1).clone(memory_format=torch.contiguous_format) 169 | f, g, _, _ = closure(x) 170 | if disp > 1: 171 | print('initial fval: %0.4f' % f) 172 | if return_all: 173 | allvecs = [x] 174 | 175 | # initial settings 176 | if low_mem: 177 | hess = L_BFGS(x, history_size) 178 | else: 179 | hess = BFGS(x, inv_hess) 180 | d = g.neg() 181 | t = min(1., g.norm(p=1).reciprocal()) * lr 182 | n_iter = 0 183 | 184 | # BFGS iterations 185 | for n_iter in range(1, max_iter+1): 186 | 187 | # ================================== 188 | # compute Quasi-Newton direction 189 | # ================================== 190 | 191 | if n_iter > 1: 192 | d = hess.solve(g) 193 | 194 | # directional derivative 195 | gtd = g.dot(d) 196 | 197 | # check if directional derivative is below tolerance 198 | if gtd > -xtol: 199 | warnflag = 4 200 | msg = 'A non-descent direction was encountered.' 201 | break 202 | 203 | # ====================== 204 | # update parameter 205 | # ====================== 206 | 207 | if line_search == 'none': 208 | # no line search, move with fixed-step 209 | x_new = x + d.mul(t) 210 | f_new, g_new, _, _ = closure(x_new) 211 | elif line_search == 'strong-wolfe': 212 | # Determine step size via strong-wolfe line search 213 | f_new, g_new, t, ls_evals = \ 214 | strong_wolfe(dir_evaluate, x, t, d, f, g, gtd) 215 | x_new = x + d.mul(t) 216 | else: 217 | raise ValueError('invalid line_search option {}.'.format(line_search)) 218 | 219 | if disp > 1: 220 | print('iter %3d - fval: %0.4f' % (n_iter, f_new)) 221 | if return_all: 222 | allvecs.append(x_new) 223 | if callback is not None: 224 | callback(x_new) 225 | 226 | # ================================ 227 | # update hessian approximation 228 | # ================================ 229 | 230 | s = x_new.sub(x) 231 | y = g_new.sub(g) 232 | 233 | hess.update(s, y) 234 | 235 | # ========================================= 236 | # check conditions and update buffers 237 | # ========================================= 238 | 239 | # convergence by insufficient progress 240 | if (s.norm(p=normp) <= xtol) | ((f_new - f).abs() <= xtol): 241 | warnflag = 0 242 | msg = _status_message['success'] 243 | break 244 | 245 | # update state 246 | f[...] = f_new 247 | x.copy_(x_new) 248 | g.copy_(g_new) 249 | t = lr 250 | 251 | # convergence by 1st-order optimality 252 | if g.norm(p=normp) <= gtol: 253 | warnflag = 0 254 | msg = _status_message['success'] 255 | break 256 | 257 | # precision loss; exit 258 | if ~f.isfinite(): 259 | warnflag = 2 260 | msg = _status_message['pr_loss'] 261 | break 262 | 263 | else: 264 | # if we get to the end, the maximum num. 
iterations was reached 265 | warnflag = 1 266 | msg = _status_message['maxiter'] 267 | 268 | if disp: 269 | print(msg) 270 | print(" Current function value: %f" % f) 271 | print(" Iterations: %d" % n_iter) 272 | print(" Function evaluations: %d" % sf.nfev) 273 | result = OptimizeResult(fun=f, x=x.view_as(x0), grad=g.view_as(x0), 274 | status=warnflag, success=(warnflag==0), 275 | message=msg, nit=n_iter, nfev=sf.nfev) 276 | if not low_mem: 277 | if inv_hess: 278 | result['hess_inv'] = hess.H.view(2 * x0.shape) 279 | else: 280 | result['hess'] = hess.B.view(2 * x0.shape) 281 | if return_all: 282 | result['allvecs'] = allvecs 283 | 284 | return result 285 | 286 | 287 | def _minimize_bfgs( 288 | fun, x0, lr=1., inv_hess=True, max_iter=None, 289 | line_search='strong-wolfe', gtol=1e-5, xtol=1e-9, 290 | normp=float('inf'), callback=None, disp=0, return_all=False): 291 | """Minimize a multivariate function with BFGS 292 | 293 | Parameters 294 | ---------- 295 | fun : callable 296 | Scalar objective function to minimize. 297 | x0 : Tensor 298 | Initialization point. 299 | lr : float 300 | Step size for parameter updates. If using line search, this will be 301 | used as the initial step size for the search. 302 | inv_hess : bool 303 | Whether to parameterize the inverse hessian vs. the hessian with BFGS. 304 | max_iter : int, optional 305 | Maximum number of iterations to perform. Defaults to 306 | ``200 * x0.numel()``. 307 | line_search : str 308 | Line search specifier. Currently the available options are 309 | {'none', 'strong-wolfe'}. 310 | gtol : float 311 | Termination tolerance on 1st-order optimality (gradient norm). 312 | xtol : float 313 | Termination tolerance on function/parameter changes. 314 | normp : Number or str 315 | The norm type to use for termination conditions. Can be any value 316 | supported by :func:`torch.norm`. 317 | callback : callable, optional 318 | Function to call after each iteration with the current parameter 319 | state, e.g. ``callback(x)``. 320 | disp : int or bool 321 | Display (verbosity) level. Set to >0 to print status messages. 322 | return_all : bool, optional 323 | Set to True to return a list of the best solution at each of the 324 | iterations. 325 | 326 | Returns 327 | ------- 328 | result : OptimizeResult 329 | Result of the optimization routine. 330 | """ 331 | return _minimize_bfgs_core( 332 | fun, x0, lr, low_mem=False, inv_hess=inv_hess, max_iter=max_iter, 333 | line_search=line_search, gtol=gtol, xtol=xtol, 334 | normp=normp, callback=callback, disp=disp, return_all=return_all) 335 | 336 | 337 | def _minimize_lbfgs( 338 | fun, x0, lr=1., history_size=100, max_iter=None, 339 | line_search='strong-wolfe', gtol=1e-5, xtol=1e-9, 340 | normp=float('inf'), callback=None, disp=0, return_all=False): 341 | """Minimize a multivariate function with L-BFGS 342 | 343 | Parameters 344 | ---------- 345 | fun : callable 346 | Scalar objective function to minimize. 347 | x0 : Tensor 348 | Initialization point. 349 | lr : float 350 | Step size for parameter updates. If using line search, this will be 351 | used as the initial step size for the search. 352 | history_size : int 353 | History size for L-BFGS hessian estimates. 354 | max_iter : int, optional 355 | Maximum number of iterations to perform. Defaults to 356 | ``200 * x0.numel()``. 357 | line_search : str 358 | Line search specifier. Currently the available options are 359 | {'none', 'strong-wolfe'}. 360 | gtol : float 361 | Termination tolerance on 1st-order optimality (gradient norm). 
362 | xtol : float 363 | Termination tolerance on function/parameter changes. 364 | normp : Number or str 365 | The norm type to use for termination conditions. Can be any value 366 | supported by :func:`torch.norm`. 367 | callback : callable, optional 368 | Function to call after each iteration with the current parameter 369 | state, e.g. ``callback(x)``. 370 | disp : int or bool 371 | Display (verbosity) level. Set to >0 to print status messages. 372 | return_all : bool, optional 373 | Set to True to return a list of the best solution at each of the 374 | iterations. 375 | 376 | Returns 377 | ------- 378 | result : OptimizeResult 379 | Result of the optimization routine. 380 | """ 381 | return _minimize_bfgs_core( 382 | fun, x0, lr, low_mem=True, history_size=history_size, 383 | max_iter=max_iter, line_search=line_search, gtol=gtol, xtol=xtol, 384 | normp=normp, callback=callback, disp=disp, return_all=return_all) -------------------------------------------------------------------------------- /torchmin/cg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.optimize import OptimizeResult 3 | 4 | from .function import ScalarFunction 5 | from .line_search import strong_wolfe 6 | 7 | try: 8 | from scipy.optimize.optimize import _status_message 9 | except ImportError: 10 | from scipy.optimize._optimize import _status_message 11 | 12 | dot = lambda u,v: torch.dot(u.view(-1), v.view(-1)) 13 | 14 | 15 | @torch.no_grad() 16 | def _minimize_cg(fun, x0, max_iter=None, gtol=1e-5, normp=float('inf'), 17 | callback=None, disp=0, return_all=False): 18 | """Minimize a scalar function of one or more variables using 19 | nonlinear conjugate gradient. 20 | 21 | The algorithm is described in Nocedal & Wright (2006) chapter 5.2. 22 | 23 | Parameters 24 | ---------- 25 | fun : callable 26 | Scalar objective function to minimize. 27 | x0 : Tensor 28 | Initialization point. 29 | max_iter : int 30 | Maximum number of iterations to perform. Defaults to 31 | ``200 * x0.numel()``. 32 | gtol : float 33 | Termination tolerance on 1st-order optimality (gradient norm). 34 | normp : float 35 | The norm type to use for termination conditions. Can be any value 36 | supported by :func:`torch.norm`. 37 | callback : callable, optional 38 | Function to call after each iteration with the current parameter 39 | state, e.g. ``callback(x)`` 40 | disp : int or bool 41 | Display (verbosity) level. Set to >0 to print status messages. 42 | return_all : bool, optional 43 | Set to True to return a list of the best solution at each of the 44 | iterations. 45 | 46 | """ 47 | disp = int(disp) 48 | if max_iter is None: 49 | max_iter = x0.numel() * 200 50 | 51 | # Construct scalar objective function 52 | sf = ScalarFunction(fun, x_shape=x0.shape) 53 | closure = sf.closure 54 | dir_evaluate = sf.dir_evaluate 55 | 56 | # initialize 57 | x = x0.detach().flatten() 58 | f, g, _, _ = closure(x) 59 | if disp > 1: 60 | print('initial fval: %0.4f' % f) 61 | if return_all: 62 | allvecs = [x] 63 | d = g.neg() 64 | grad_norm = g.norm(p=normp) 65 | old_f = f + g.norm() / 2 # Sets the initial step guess to dx ~ 1 66 | 67 | for niter in range(1, max_iter + 1): 68 | # delta/gtd 69 | delta = dot(g, g) 70 | gtd = dot(g, d) 71 | 72 | # compute initial step guess based on (f - old_f) / gtd 73 | t0 = torch.clamp(2.02 * (f - old_f) / gtd, max=1.0) 74 | if t0 <= 0: 75 | warnflag = 4 76 | msg = 'Initial step guess is negative.' 
77 | break 78 | old_f = f 79 | 80 | # buffer to store next direction vector 81 | cached_step = [None] 82 | 83 | def polak_ribiere_powell_step(t, g_next): 84 | y = g_next - g 85 | beta = torch.clamp(dot(y, g_next) / delta, min=0) 86 | d_next = -g_next + d.mul(beta) 87 | torch.norm(g_next, p=normp, out=grad_norm) 88 | return t, d_next 89 | 90 | def descent_condition(t, f_next, g_next): 91 | # Polak-Ribiere+ needs an explicit check of a sufficient 92 | # descent condition, which is not guaranteed by strong Wolfe. 93 | cached_step[:] = polak_ribiere_powell_step(t, g_next) 94 | t, d_next = cached_step 95 | 96 | # Accept step if it leads to convergence. 97 | cond1 = grad_norm <= gtol 98 | 99 | # Accept step if sufficient descent condition applies. 100 | cond2 = dot(d_next, g_next) <= -0.01 * dot(g_next, g_next) 101 | 102 | return cond1 | cond2 103 | 104 | # Perform CG step 105 | f, g, t, ls_evals = \ 106 | strong_wolfe(dir_evaluate, x, t0, d, f, g, gtd, 107 | c2=0.4, extra_condition=descent_condition) 108 | 109 | # Update x and then update d (in that order) 110 | x = x + d.mul(t) 111 | if t == cached_step[0]: 112 | # Reuse already computed results if possible 113 | d = cached_step[1] 114 | else: 115 | d = polak_ribiere_powell_step(t, g)[1] 116 | 117 | if disp > 1: 118 | print('iter %3d - fval: %0.4f' % (niter, f)) 119 | if return_all: 120 | allvecs.append(x) 121 | if callback is not None: 122 | callback(x) 123 | 124 | # check optimality 125 | if grad_norm <= gtol: 126 | warnflag = 0 127 | msg = _status_message['success'] 128 | break 129 | 130 | else: 131 | # if we get to the end, the maximum iterations was reached 132 | warnflag = 1 133 | msg = _status_message['maxiter'] 134 | 135 | if disp: 136 | print("%s%s" % ("Warning: " if warnflag != 0 else "", msg)) 137 | print(" Current function value: %f" % f) 138 | print(" Iterations: %d" % niter) 139 | print(" Function evaluations: %d" % sf.nfev) 140 | 141 | result = OptimizeResult(fun=f, x=x.view_as(x0), grad=g.view_as(x0), 142 | status=warnflag, success=(warnflag == 0), 143 | message=msg, nit=niter, nfev=sf.nfev) 144 | if return_all: 145 | result['allvecs'] = allvecs 146 | return result -------------------------------------------------------------------------------- /torchmin/function.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from torch import Tensor 3 | from collections import namedtuple 4 | import torch 5 | import torch.autograd as autograd 6 | from torch._vmap_internals import _vmap 7 | 8 | from .optim.minimizer import Minimizer 9 | 10 | __all__ = ['ScalarFunction', 'VectorFunction'] 11 | 12 | 13 | 14 | # scalar function result (value) 15 | sf_value = namedtuple('sf_value', ['f', 'grad', 'hessp', 'hess']) 16 | 17 | # directional evaluate result 18 | de_value = namedtuple('de_value', ['f', 'grad']) 19 | 20 | # vector function result (value) 21 | vf_value = namedtuple('vf_value', ['f', 'jacp', 'jac']) 22 | 23 | 24 | @torch.jit.script 25 | class JacobianLinearOperator(object): 26 | def __init__(self, 27 | x: Tensor, 28 | f: Tensor, 29 | gf: Optional[Tensor] = None, 30 | gx: Optional[Tensor] = None, 31 | symmetric: bool = False) -> None: 32 | self.x = x 33 | self.f = f 34 | self.gf = gf 35 | self.gx = gx 36 | self.symmetric = symmetric 37 | # tensor-like properties 38 | self.shape = (f.numel(), x.numel()) 39 | self.dtype = x.dtype 40 | self.device = x.device 41 | 42 | def mv(self, v: Tensor) -> Tensor: 43 | if self.symmetric: 44 | return self.rmv(v) 45 | assert 
v.shape == self.x.shape 46 | gx, gf = self.gx, self.gf 47 | assert (gx is not None) and (gf is not None) 48 | outputs: List[Tensor] = [gx] 49 | inputs: List[Tensor] = [gf] 50 | grad_outputs: List[Optional[Tensor]] = [v] 51 | jvp = autograd.grad(outputs, inputs, grad_outputs, retain_graph=True)[0] 52 | if jvp is None: 53 | raise Exception 54 | return jvp 55 | 56 | def rmv(self, v: Tensor) -> Tensor: 57 | assert v.shape == self.f.shape 58 | outputs: List[Tensor] = [self.f] 59 | inputs: List[Tensor] = [self.x] 60 | grad_outputs: List[Optional[Tensor]] = [v] 61 | vjp = autograd.grad(outputs, inputs, grad_outputs, retain_graph=True)[0] 62 | if vjp is None: 63 | raise Exception 64 | return vjp 65 | 66 | 67 | def jacobian_linear_operator(x, f, symmetric=False): 68 | if symmetric: 69 | # Use vector-jacobian product (more efficient) 70 | gf = gx = None 71 | else: 72 | # Apply the "double backwards" trick to get true 73 | # jacobian-vector product 74 | with torch.enable_grad(): 75 | gf = torch.zeros_like(f, requires_grad=True) 76 | gx = autograd.grad(f, x, gf, create_graph=True)[0] 77 | return JacobianLinearOperator(x, f, gf, gx, symmetric) 78 | 79 | 80 | 81 | class ScalarFunction(object): 82 | """Scalar-valued objective function with autograd backend. 83 | 84 | This class provides a general-purpose objective wrapper which will 85 | compute first- and second-order derivatives via autograd as specified 86 | by the parameters of __init__. 87 | """ 88 | def __new__(cls, fun, x_shape, hessp=False, hess=False, twice_diffable=True): 89 | if isinstance(fun, Minimizer): 90 | assert fun._hessp == hessp 91 | assert fun._hess == hess 92 | return fun 93 | return super(ScalarFunction, cls).__new__(cls) 94 | 95 | def __init__(self, fun, x_shape, hessp=False, hess=False, twice_diffable=True): 96 | self._fun = fun 97 | self._x_shape = x_shape 98 | self._hessp = hessp 99 | self._hess = hess 100 | self._I = None 101 | self._twice_diffable = twice_diffable 102 | self.nfev = 0 103 | 104 | def fun(self, x): 105 | if x.shape != self._x_shape: 106 | x = x.view(self._x_shape) 107 | f = self._fun(x) 108 | if f.numel() != 1: 109 | raise RuntimeError('ScalarFunction was supplied a function ' 110 | 'that does not return scalar outputs.') 111 | self.nfev += 1 112 | 113 | return f 114 | 115 | def closure(self, x): 116 | """Evaluate the function, gradient, and hessian/hessian-product 117 | 118 | This method represents the core function call. It is used for 119 | computing newton/quasi newton directions, etc. 120 | """ 121 | x = x.detach().requires_grad_(True) 122 | with torch.enable_grad(): 123 | f = self.fun(x) 124 | grad = autograd.grad(f, x, create_graph=self._hessp or self._hess)[0] 125 | if (self._hessp or self._hess) and grad.grad_fn is None: 126 | raise RuntimeError('A 2nd-order derivative was requested but ' 127 | 'the objective is not twice-differentiable.') 128 | hessp = None 129 | hess = None 130 | if self._hessp: 131 | hessp = jacobian_linear_operator(x, grad, symmetric=self._twice_diffable) 132 | if self._hess: 133 | if self._I is None: 134 | self._I = torch.eye(x.numel(), dtype=x.dtype, device=x.device) 135 | hvp = lambda v: autograd.grad(grad, x, v, retain_graph=True)[0] 136 | hess = _vmap(hvp)(self._I) 137 | 138 | return sf_value(f=f.detach(), grad=grad.detach(), hessp=hessp, hess=hess) 139 | 140 | def dir_evaluate(self, x, t, d): 141 | """Evaluate a direction and step size. 142 | 143 | We define a separate "directional evaluate" function to be used 144 | for strong-wolfe line search. 
Only the function value and gradient 145 | are needed for this use case, so we avoid computational overhead. 146 | """ 147 | x = x + d.mul(t) 148 | x = x.detach().requires_grad_(True) 149 | with torch.enable_grad(): 150 | f = self.fun(x) 151 | grad = autograd.grad(f, x)[0] 152 | 153 | return de_value(f=float(f), grad=grad) 154 | 155 | 156 | class VectorFunction(object): 157 | """Vector-valued objective function with autograd backend.""" 158 | def __init__(self, fun, x_shape, jacp=False, jac=False): 159 | self._fun = fun 160 | self._x_shape = x_shape 161 | self._jacp = jacp 162 | self._jac = jac 163 | self._I = None 164 | self.nfev = 0 165 | 166 | def fun(self, x): 167 | if x.shape != self._x_shape: 168 | x = x.view(self._x_shape) 169 | f = self._fun(x) 170 | if f.dim() == 0: 171 | raise RuntimeError('VectorFunction expected vector outputs but ' 172 | 'received a scalar.') 173 | elif f.dim() > 1: 174 | f = f.view(-1) 175 | self.nfev += 1 176 | 177 | return f 178 | 179 | def closure(self, x): 180 | x = x.detach().requires_grad_(True) 181 | with torch.enable_grad(): 182 | f = self.fun(x) 183 | jacp = None 184 | jac = None 185 | if self._jacp: 186 | jacp = jacobian_linear_operator(x, f) 187 | if self._jac: 188 | if self._I is None: 189 | self._I = torch.eye(f.numel(), dtype=x.dtype, device=x.device) 190 | vjp = lambda v: autograd.grad(f, x, v, retain_graph=True)[0] 191 | jac = _vmap(vjp)(self._I) 192 | 193 | return vf_value(f=f.detach(), jacp=jacp, jac=jac) -------------------------------------------------------------------------------- /torchmin/line_search.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | from torch.optim.lbfgs import _strong_wolfe, _cubic_interpolate 4 | from scipy.optimize import minimize_scalar 5 | 6 | __all__ = ['strong_wolfe', 'brent', 'backtracking'] 7 | 8 | 9 | def _strong_wolfe_extra( 10 | obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, 11 | tolerance_change=1e-9, max_ls=25, extra_condition=None): 12 | """A modified variant of pytorch's strong-wolfe line search that supports 13 | an "extra_condition" argument (callable). 14 | 15 | This is required for methods such as Conjugate Gradient (polak-ribiere) 16 | where the strong wolfe conditions do not guarantee that we have a 17 | descent direction. 18 | 19 | Code borrowed from pytorch:: 20 | Copyright (c) 2016 Facebook, Inc. 21 | All rights reserved. 
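A usage note for this module: the public `strong_wolfe` wrapper defined further down expects a directional evaluator with the signature `fun(x, t, d) -> (f, g)`, which is exactly what `ScalarFunction.dir_evaluate` provides. A minimal sketch with an illustrative quartic objective (not from the library):

```python
import torch
from torchmin.function import ScalarFunction
from torchmin.line_search import strong_wolfe

fun = lambda x: (x ** 4).sum()              # illustrative objective
x = torch.tensor([1.5, -2.0])
sf = ScalarFunction(fun, x_shape=x.shape)
f, g, _, _ = sf.closure(x)                  # value and gradient at x
d = g.neg()                                 # steepest-descent direction
f_new, g_new, t, n_evals = strong_wolfe(sf.dir_evaluate, x, 1.0, d, f, g)
```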
22 | """ 23 | # ported from https://github.com/torch/optim/blob/master/lswolfe.lua 24 | if extra_condition is None: 25 | extra_condition = lambda *args: True 26 | d_norm = d.abs().max() 27 | g = g.clone(memory_format=torch.contiguous_format) 28 | # evaluate objective and gradient using initial step 29 | f_new, g_new = obj_func(x, t, d) 30 | ls_func_evals = 1 31 | gtd_new = g_new.dot(d) 32 | 33 | # bracket an interval containing a point satisfying the Wolfe criteria 34 | t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd 35 | done = False 36 | ls_iter = 0 37 | while ls_iter < max_ls: 38 | # check conditions 39 | if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): 40 | bracket = [t_prev, t] 41 | bracket_f = [f_prev, f_new] 42 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 43 | bracket_gtd = [gtd_prev, gtd_new] 44 | break 45 | 46 | if abs(gtd_new) <= -c2 * gtd and extra_condition(t, f_new, g_new): 47 | bracket = [t] 48 | bracket_f = [f_new] 49 | bracket_g = [g_new] 50 | done = True 51 | break 52 | 53 | if gtd_new >= 0: 54 | bracket = [t_prev, t] 55 | bracket_f = [f_prev, f_new] 56 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 57 | bracket_gtd = [gtd_prev, gtd_new] 58 | break 59 | 60 | # interpolate 61 | min_step = t + 0.01 * (t - t_prev) 62 | max_step = t * 10 63 | tmp = t 64 | t = _cubic_interpolate( 65 | t_prev, 66 | f_prev, 67 | gtd_prev, 68 | t, 69 | f_new, 70 | gtd_new, 71 | bounds=(min_step, max_step)) 72 | 73 | # next step 74 | t_prev = tmp 75 | f_prev = f_new 76 | g_prev = g_new.clone(memory_format=torch.contiguous_format) 77 | gtd_prev = gtd_new 78 | f_new, g_new = obj_func(x, t, d) 79 | ls_func_evals += 1 80 | gtd_new = g_new.dot(d) 81 | ls_iter += 1 82 | 83 | # reached max number of iterations? 84 | if ls_iter == max_ls: 85 | bracket = [0, t] 86 | bracket_f = [f, f_new] 87 | bracket_g = [g, g_new] 88 | 89 | # zoom phase: we now have a point satisfying the criteria, or 90 | # a bracket around it. We refine the bracket until we find the 91 | # exact point satisfying the criteria 92 | insuf_progress = False 93 | # find high and low points in bracket 94 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) 95 | while not done and ls_iter < max_ls: 96 | # line-search bracket is so small 97 | if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: 98 | break 99 | 100 | # compute new trial value 101 | t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0], 102 | bracket[1], bracket_f[1], bracket_gtd[1]) 103 | 104 | # test that we are making sufficient progress: 105 | # in case `t` is so close to boundary, we mark that we are making 106 | # insufficient progress, and if 107 | # + we have made insufficient progress in the last step, or 108 | # + `t` is at one of the boundary, 109 | # we will move `t` to a position which is `0.1 * len(bracket)` 110 | # away from the nearest boundary point. 
111 | eps = 0.1 * (max(bracket) - min(bracket)) 112 | if min(max(bracket) - t, t - min(bracket)) < eps: 113 | # interpolation close to boundary 114 | if insuf_progress or t >= max(bracket) or t <= min(bracket): 115 | # evaluate at 0.1 away from boundary 116 | if abs(t - max(bracket)) < abs(t - min(bracket)): 117 | t = max(bracket) - eps 118 | else: 119 | t = min(bracket) + eps 120 | insuf_progress = False 121 | else: 122 | insuf_progress = True 123 | else: 124 | insuf_progress = False 125 | 126 | # Evaluate new point 127 | f_new, g_new = obj_func(x, t, d) 128 | ls_func_evals += 1 129 | gtd_new = g_new.dot(d) 130 | ls_iter += 1 131 | 132 | if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: 133 | # Armijo condition not satisfied or not lower than lowest point 134 | bracket[high_pos] = t 135 | bracket_f[high_pos] = f_new 136 | bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) 137 | bracket_gtd[high_pos] = gtd_new 138 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) 139 | else: 140 | if abs(gtd_new) <= -c2 * gtd and extra_condition(t, f_new, g_new): 141 | # Wolfe conditions satisfied 142 | done = True 143 | elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: 144 | # old high becomes new low 145 | bracket[high_pos] = bracket[low_pos] 146 | bracket_f[high_pos] = bracket_f[low_pos] 147 | bracket_g[high_pos] = bracket_g[low_pos] 148 | bracket_gtd[high_pos] = bracket_gtd[low_pos] 149 | 150 | # new point becomes new low 151 | bracket[low_pos] = t 152 | bracket_f[low_pos] = f_new 153 | bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) 154 | bracket_gtd[low_pos] = gtd_new 155 | 156 | # return stuff 157 | t = bracket[low_pos] 158 | f_new = bracket_f[low_pos] 159 | g_new = bracket_g[low_pos] 160 | return f_new, g_new, t, ls_func_evals 161 | 162 | 163 | def strong_wolfe(fun, x, t, d, f, g, gtd=None, **kwargs): 164 | """ 165 | Expects `fun` to take arguments {x, t, d} and return {f(x1), f'(x1)}, 166 | where x1 is the new location after taking a step from x in direction d 167 | with step size t. 168 | """ 169 | if gtd is None: 170 | gtd = g.mul(d).sum() 171 | 172 | # use python floats for scalars as per torch.optim.lbfgs 173 | f, t = float(f), float(t) 174 | 175 | if 'extra_condition' in kwargs: 176 | f, g, t, ls_nevals = _strong_wolfe_extra( 177 | fun, x.view(-1), t, d.view(-1), f, g.view(-1), gtd, **kwargs) 178 | else: 179 | # in theory we shouldn't need to use pytorch's native _strong_wolfe 180 | # since the custom implementation above is equivalent with 181 | # extra_codition=None. But we will keep this in case they make any 182 | # changes. 183 | f, g, t, ls_nevals = _strong_wolfe( 184 | fun, x.view(-1), t, d.view(-1), f, g.view(-1), gtd, **kwargs) 185 | 186 | # convert back to torch scalar 187 | f = torch.as_tensor(f, dtype=x.dtype, device=x.device) 188 | 189 | return f, g.view_as(x), t, ls_nevals 190 | 191 | 192 | def brent(fun, x, d, bounds=(0,10)): 193 | """ 194 | Expects `fun` to take arguments {x} and return {f(x)} 195 | """ 196 | def line_obj(t): 197 | return float(fun(x + t * d)) 198 | res = minimize_scalar(line_obj, bounds=bounds, method='bounded') 199 | return res.x 200 | 201 | 202 | def backtracking(fun, x, t, d, f, g, mu=0.1, decay=0.98, max_ls=500, tmin=1e-5): 203 | """ 204 | Expects `fun` to take arguments {x, t, d} and return {f(x1), x1}, 205 | where x1 is the new location after taking a step from x in direction d 206 | with step size t. 
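In symbols, the acceptance test implemented below (a restatement of the condition in the code, with `mu` the sufficient-decrease parameter) is:

.. math:: f(x_{\text{new}}) \le f(x) + \mu \, \nabla f(x)^\top (x_{\text{new}} - x)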
207 | 208 | We use a generalized variant of the armijo condition that supports 209 | arbitrary step functions x' = step(x,t,d). When step(x,t,d) = x + t * d 210 | it is equivalent to the standard condition. 211 | """ 212 | x_new = x 213 | f_new = f 214 | success = False 215 | for i in range(max_ls): 216 | f_new, x_new = fun(x, t, d) 217 | if f_new <= f + mu * g.mul(x_new-x).sum(): 218 | success = True 219 | break 220 | if t <= tmin: 221 | warnings.warn('step size has reached the minimum threshold.') 222 | break 223 | t = t.mul(decay) 224 | else: 225 | warnings.warn('backtracking did not converge.') 226 | 227 | return x_new, f_new, t, success -------------------------------------------------------------------------------- /torchmin/lstsq/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module represents a pytorch re-implementation of scipy's 3 | `scipy.optimize._lsq` module. Some of the code is borrowed directly 4 | from the scipy library (all rights reserved). 5 | """ 6 | from .least_squares import least_squares -------------------------------------------------------------------------------- /torchmin/lstsq/cg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .linear_operator import aslinearoperator, TorchLinearOperator 4 | 5 | 6 | def cg(A, b, x0=None, max_iter=None, tol=1e-5): 7 | if max_iter is None: 8 | max_iter = 20 * b.numel() 9 | if x0 is None: 10 | x = torch.zeros_like(b) 11 | r = b.clone() 12 | else: 13 | x = x0.clone() 14 | r = b - A.mv(x) 15 | p = r.clone() 16 | rs = r.dot(r) 17 | rs_new = b.new_tensor(0.) 18 | alpha = b.new_tensor(0.) 19 | for n_iter in range(1, max_iter+1): 20 | Ap = A.mv(p) 21 | torch.div(rs, p.dot(Ap), out=alpha) 22 | x.add_(p, alpha=alpha) 23 | r.sub_(Ap, alpha=alpha) 24 | torch.dot(r, r, out=rs_new) 25 | p.mul_(rs_new / rs).add_(r) 26 | if n_iter % 10 == 0: 27 | r_norm = rs.sqrt() 28 | if r_norm < tol: 29 | break 30 | rs.copy_(rs_new, non_blocking=True) 31 | 32 | return x 33 | 34 | 35 | def cgls(A, b, alpha=0., **kwargs): 36 | A = aslinearoperator(A) 37 | m, n = A.shape 38 | Atb = A.rmv(b) 39 | AtA = TorchLinearOperator(shape=(n,n), 40 | matvec=lambda x: A.rmv(A.mv(x)) + alpha * x, 41 | rmatvec=None) 42 | return cg(AtA, Atb, **kwargs) -------------------------------------------------------------------------------- /torchmin/lstsq/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.sparse.linalg import LinearOperator 4 | 5 | from .linear_operator import TorchLinearOperator 6 | 7 | EPS = torch.finfo(float).eps 8 | 9 | 10 | def in_bounds(x, lb, ub): 11 | """Check if a point lies within bounds.""" 12 | return torch.all((x >= lb) & (x <= ub)) 13 | 14 | 15 | def find_active_constraints(x, lb, ub, rtol=1e-10): 16 | """Determine which constraints are active in a given point. 17 | The threshold is computed using `rtol` and the absolute value of the 18 | closest bound. 19 | Returns 20 | ------- 21 | active : ndarray of int with shape of x 22 | Each component shows whether the corresponding constraint is active: 23 | * 0 - a constraint is not active. 24 | * -1 - a lower bound is active. 25 | * 1 - a upper bound is active. 
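For illustration, with the default `rtol` a point sitting exactly on its lower bound is flagged -1 and one on its upper bound +1 (the numbers below are made up):

```python
import torch
from torchmin.lstsq.common import find_active_constraints

x = torch.tensor([0.0, 0.5, 1.0])
lb = torch.zeros(3)
ub = torch.ones(3)
print(find_active_constraints(x, lb, ub))  # tensor([-1,  0,  1])
```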
26 | """ 27 | active = torch.zeros_like(x, dtype=torch.long) 28 | 29 | if rtol == 0: 30 | active[x <= lb] = -1 31 | active[x >= ub] = 1 32 | return active 33 | 34 | lower_dist = x - lb 35 | upper_dist = ub - x 36 | lower_threshold = rtol * lb.abs().clamp(1, None) 37 | upper_threshold = rtol * ub.abs().clamp(1, None) 38 | 39 | lower_active = (lb.isfinite() & 40 | (lower_dist <= torch.minimum(upper_dist, lower_threshold))) 41 | active[lower_active] = -1 42 | 43 | upper_active = (ub.isfinite() & 44 | (upper_dist <= torch.minimum(lower_dist, upper_threshold))) 45 | active[upper_active] = 1 46 | 47 | return active 48 | 49 | 50 | def make_strictly_feasible(x, lb, ub, rstep=1e-10): 51 | """Shift a point to the interior of a feasible region. 52 | Each element of the returned vector is at least at a relative distance 53 | `rstep` from the closest bound. If ``rstep=0`` then `np.nextafter` is used. 54 | """ 55 | x_new = x.clone() 56 | 57 | active = find_active_constraints(x, lb, ub, rstep) 58 | lower_mask = torch.eq(active, -1) 59 | upper_mask = torch.eq(active, 1) 60 | 61 | if rstep == 0: 62 | torch.nextafter(lb[lower_mask], ub[lower_mask], out=x_new[lower_mask]) 63 | torch.nextafter(ub[upper_mask], lb[upper_mask], out=x_new[upper_mask]) 64 | else: 65 | x_new[lower_mask] = lb[lower_mask].add(lb[lower_mask].abs().clamp(1,None), alpha=rstep) 66 | x_new[upper_mask] = ub[upper_mask].sub(ub[upper_mask].abs().clamp(1,None), alpha=rstep) 67 | 68 | tight_bounds = (x_new < lb) | (x_new > ub) 69 | x_new[tight_bounds] = 0.5 * (lb[tight_bounds] + ub[tight_bounds]) 70 | 71 | return x_new 72 | 73 | 74 | def solve_lsq_trust_region(n, m, uf, s, V, Delta, initial_alpha=None, 75 | rtol=0.01, max_iter=10): 76 | """Solve a trust-region problem arising in least-squares minimization. 77 | This function implements a method described by J. J. More [1]_ and used 78 | in MINPACK, but it relies on a single SVD of Jacobian instead of series 79 | of Cholesky decompositions. Before running this function, compute: 80 | ``U, s, VT = svd(J, full_matrices=False)``. 81 | """ 82 | def phi_and_derivative(alpha, suf, s, Delta): 83 | """Function of which to find zero. 84 | It is defined as "norm of regularized (by alpha) least-squares 85 | solution minus `Delta`". Refer to [1]_. 86 | """ 87 | denom = s.pow(2) + alpha 88 | p_norm = (suf / denom).norm() 89 | phi = p_norm - Delta 90 | phi_prime = -(suf.pow(2) / denom.pow(3)).sum() / p_norm 91 | return phi, phi_prime 92 | 93 | def set_alpha(alpha_lower, alpha_upper): 94 | new_alpha = (alpha_lower * alpha_upper).sqrt() 95 | return new_alpha.clamp_(0.001 * alpha_upper, None) 96 | 97 | suf = s * uf 98 | 99 | # Check if J has full rank and try Gauss-Newton step. 100 | eps = torch.finfo(s.dtype).eps 101 | full_rank = m >= n and s[-1] > eps * m * s[0] 102 | 103 | if full_rank: 104 | p = -V.mv(uf / s) 105 | if p.norm() <= Delta: 106 | return p, 0.0, 0 107 | phi, phi_prime = phi_and_derivative(0., suf, s, Delta) 108 | alpha_lower = -phi / phi_prime 109 | else: 110 | alpha_lower = s.new_tensor(0.) 
111 | 112 | alpha_upper = suf.norm() / Delta 113 | 114 | if initial_alpha is None or not full_rank and initial_alpha == 0: 115 | alpha = set_alpha(alpha_lower, alpha_upper) 116 | else: 117 | alpha = initial_alpha.clone() 118 | 119 | for it in range(max_iter): 120 | # if alpha is outside of bounds, set new value (5.5)(a) 121 | alpha = torch.where((alpha < alpha_lower) | (alpha > alpha_upper), 122 | set_alpha(alpha_lower, alpha_upper), 123 | alpha) 124 | 125 | # compute new phi and phi' (5.5)(b) 126 | phi, phi_prime = phi_and_derivative(alpha, suf, s, Delta) 127 | 128 | # if phi is negative, update our upper bound (5.5)(b) 129 | alpha_upper = torch.where(phi < 0, alpha, alpha_upper) 130 | 131 | # update lower bound (5.5)(b) 132 | ratio = phi / phi_prime 133 | alpha_lower.clamp_(alpha-ratio, None) 134 | 135 | # compute new alpha (5.5)(c) 136 | alpha.addcdiv_((phi + Delta) * ratio, Delta, value=-1) 137 | 138 | if phi.abs() < rtol * Delta: 139 | break 140 | 141 | p = -V.mv(suf / (s.pow(2) + alpha)) 142 | 143 | # Make the norm of p equal to Delta, p is changed only slightly during 144 | # this. It is done to prevent p lie outside the trust region (which can 145 | # cause problems later). 146 | p.mul_(Delta / p.norm()) 147 | 148 | return p, alpha, it + 1 149 | 150 | 151 | def right_multiplied_operator(J, d): 152 | """Return J diag(d) as LinearOperator.""" 153 | if isinstance(J, LinearOperator): 154 | if torch.is_tensor(d): 155 | d = d.data.cpu().numpy() 156 | return LinearOperator(J.shape, 157 | matvec=lambda x: J.matvec(np.ravel(x) * d), 158 | matmat=lambda X: J.matmat(X * d[:, np.newaxis]), 159 | rmatvec=lambda x: d * J.rmatvec(x)) 160 | elif isinstance(J, TorchLinearOperator): 161 | return TorchLinearOperator(J.shape, 162 | matvec=lambda x: J.matvec(x.view(-1) * d), 163 | rmatvec=lambda x: d * J.rmatvec(x)) 164 | else: 165 | raise ValueError('Expected J to be a LinearOperator or ' 166 | 'TorchLinearOperator but found {}'.format(type(J))) 167 | 168 | 169 | def build_quadratic_1d(J, g, s, diag=None, s0=None): 170 | """Parameterize a multivariate quadratic function along a line. 171 | 172 | The resulting univariate quadratic function is given as follows: 173 | :: 174 | f(t) = 0.5 * (s0 + s*t).T * (J.T*J + diag) * (s0 + s*t) + 175 | g.T * (s0 + s*t) 176 | """ 177 | v = J.mv(s) 178 | a = v.dot(v) 179 | if diag is not None: 180 | a += s.dot(s * diag) 181 | a *= 0.5 182 | 183 | b = g.dot(s) 184 | 185 | if s0 is not None: 186 | u = J.mv(s0) 187 | b += u.dot(v) 188 | c = 0.5 * u.dot(u) + g.dot(s0) 189 | if diag is not None: 190 | b += s.dot(s0 * diag) 191 | c += 0.5 * s0.dot(s0 * diag) 192 | return a, b, c 193 | else: 194 | return a, b 195 | 196 | 197 | def minimize_quadratic_1d(a, b, lb, ub, c=0): 198 | """Minimize a 1-D quadratic function subject to bounds. 199 | 200 | The free term `c` is 0 by default. Bounds must be finite. 201 | """ 202 | t = [lb, ub] 203 | if a != 0: 204 | extremum = -0.5 * b / a 205 | if lb < extremum < ub: 206 | t.append(extremum) 207 | t = a.new_tensor(t) 208 | y = t * (a * t + b) + c 209 | min_index = torch.argmin(y) 210 | return t[min_index], y[min_index] 211 | 212 | 213 | def evaluate_quadratic(J, g, s, diag=None): 214 | """Compute values of a quadratic function arising in least squares. 215 | The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s. 
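As a quick numerical sanity check of that formula against a dense computation (shapes are arbitrary):

```python
import torch
from torchmin.lstsq.common import evaluate_quadratic

J = torch.randn(5, 3)
g = torch.randn(3)
s = torch.randn(3)
q = evaluate_quadratic(J, g, s)                 # 0.5 * s^T (J^T J) s + g^T s
q_ref = 0.5 * s @ (J.T @ J) @ s + g @ s
torch.testing.assert_close(q, q_ref)
```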
216 | """ 217 | if s.dim() == 1: 218 | Js = J.mv(s) 219 | q = Js.dot(Js) 220 | if diag is not None: 221 | q += s.dot(s * diag) 222 | else: 223 | Js = J.matmul(s.T) 224 | q = Js.square().sum(0) 225 | if diag is not None: 226 | q += (diag * s.square()).sum(1) 227 | 228 | l = s.matmul(g) 229 | 230 | return 0.5 * q + l 231 | 232 | def solve_trust_region_2d(B, g, Delta): 233 | """Solve a general trust-region problem in 2 dimensions. 234 | The problem is reformulated as a 4th order algebraic equation, 235 | the solution of which is found by numpy.roots. 236 | """ 237 | try: 238 | L = torch.linalg.cholesky(B) 239 | p = - torch.cholesky_solve(g.unsqueeze(1), L).squeeze(1) 240 | if p.dot(p) <= Delta**2: 241 | return p, True 242 | except RuntimeError as exc: 243 | if not 'cholesky' in exc.args[0]: 244 | raise 245 | 246 | # move things to numpy 247 | device = B.device 248 | dtype = B.dtype 249 | B = B.data.cpu().numpy() 250 | g = g.data.cpu().numpy() 251 | Delta = float(Delta) 252 | 253 | a = B[0, 0] * Delta**2 254 | b = B[0, 1] * Delta**2 255 | c = B[1, 1] * Delta**2 256 | d = g[0] * Delta 257 | f = g[1] * Delta 258 | 259 | coeffs = np.array([-b + d, 2 * (a - c + f), 6 * b, 2 * (-a + c + f), -b - d]) 260 | t = np.roots(coeffs) # Can handle leading zeros. 261 | t = np.real(t[np.isreal(t)]) 262 | 263 | p = Delta * np.vstack((2 * t / (1 + t**2), (1 - t**2) / (1 + t**2))) 264 | value = 0.5 * np.sum(p * B.dot(p), axis=0) + np.dot(g, p) 265 | p = p[:, np.argmin(value)] 266 | 267 | # convert back to torch 268 | p = torch.tensor(p, device=device, dtype=dtype) 269 | 270 | return p, False 271 | 272 | 273 | def update_tr_radius(Delta, actual_reduction, predicted_reduction, 274 | step_norm, bound_hit): 275 | """Update the radius of a trust region based on the cost reduction. 276 | """ 277 | if predicted_reduction > 0: 278 | ratio = actual_reduction / predicted_reduction 279 | elif predicted_reduction == actual_reduction == 0: 280 | ratio = 1 281 | else: 282 | ratio = 0 283 | 284 | if ratio < 0.25: 285 | Delta = 0.25 * step_norm 286 | elif ratio > 0.75 and bound_hit: 287 | Delta *= 2.0 288 | 289 | return Delta, ratio 290 | 291 | 292 | def check_termination(dF, F, dx_norm, x_norm, ratio, ftol, xtol): 293 | """Check termination condition for nonlinear least squares.""" 294 | ftol_satisfied = dF < ftol * F and ratio > 0.25 295 | xtol_satisfied = dx_norm < xtol * (xtol + x_norm) 296 | 297 | if ftol_satisfied and xtol_satisfied: 298 | return 4 299 | elif ftol_satisfied: 300 | return 2 301 | elif xtol_satisfied: 302 | return 3 303 | else: 304 | return None -------------------------------------------------------------------------------- /torchmin/lstsq/least_squares.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic interface for nonlinear least-squares minimization. 3 | """ 4 | from warnings import warn 5 | import numbers 6 | import torch 7 | 8 | from .trf import trf 9 | from .common import EPS, in_bounds, make_strictly_feasible 10 | 11 | __all__ = ['least_squares'] 12 | 13 | 14 | TERMINATION_MESSAGES = { 15 | -1: "Improper input parameters status returned from `leastsq`", 16 | 0: "The maximum number of function evaluations is exceeded.", 17 | 1: "`gtol` termination condition is satisfied.", 18 | 2: "`ftol` termination condition is satisfied.", 19 | 3: "`xtol` termination condition is satisfied.", 20 | 4: "Both `ftol` and `xtol` termination conditions are satisfied." 
21 | } 22 | 23 | 24 | def prepare_bounds(bounds, x0): 25 | n = x0.shape[0] 26 | def process(b): 27 | if isinstance(b, numbers.Number): 28 | return x0.new_full((n,), b) 29 | elif isinstance(b, torch.Tensor): 30 | if b.dim() == 0: 31 | return x0.new_full((n,), b) 32 | return b 33 | else: 34 | raise ValueError 35 | 36 | lb, ub = [process(b) for b in bounds] 37 | 38 | return lb, ub 39 | 40 | 41 | def check_tolerance(ftol, xtol, gtol, method): 42 | def check(tol, name): 43 | if tol is None: 44 | tol = 0 45 | elif tol < EPS: 46 | warn("Setting `{}` below the machine epsilon ({:.2e}) effectively " 47 | "disables the corresponding termination condition." 48 | .format(name, EPS)) 49 | return tol 50 | 51 | ftol = check(ftol, "ftol") 52 | xtol = check(xtol, "xtol") 53 | gtol = check(gtol, "gtol") 54 | 55 | if method == "lm" and (ftol < EPS or xtol < EPS or gtol < EPS): 56 | raise ValueError("All tolerances must be higher than machine epsilon " 57 | "({:.2e}) for method 'lm'.".format(EPS)) 58 | elif ftol < EPS and xtol < EPS and gtol < EPS: 59 | raise ValueError("At least one of the tolerances must be higher than " 60 | "machine epsilon ({:.2e}).".format(EPS)) 61 | 62 | return ftol, xtol, gtol 63 | 64 | 65 | def check_x_scale(x_scale, x0): 66 | if isinstance(x_scale, str) and x_scale == 'jac': 67 | return x_scale 68 | try: 69 | x_scale = torch.as_tensor(x_scale) 70 | valid = x_scale.isfinite().all() and x_scale.gt(0).all() 71 | except (ValueError, TypeError): 72 | valid = False 73 | 74 | if not valid: 75 | raise ValueError("`x_scale` must be 'jac' or array_like with " 76 | "positive numbers.") 77 | 78 | if x_scale.dim() == 0: 79 | x_scale = x0.new_full(x0.shape, x_scale) 80 | 81 | if x_scale.shape != x0.shape: 82 | raise ValueError("Inconsistent shapes between `x_scale` and `x0`.") 83 | 84 | return x_scale 85 | 86 | 87 | def least_squares( 88 | fun, x0, bounds=None, method='trf', ftol=1e-8, xtol=1e-8, 89 | gtol=1e-8, x_scale=1.0, tr_solver='lsmr', tr_options=None, 90 | max_nfev=None, verbose=0): 91 | r"""Solve a nonlinear least-squares problem with bounds on the variables. 92 | 93 | Given the residual function 94 | :math:`f: \mathcal{R}^n \rightarrow \mathcal{R}^m`, `least_squares` 95 | finds a local minimum of the residual sum-of-squares (RSS) objective: 96 | 97 | .. math:: 98 | x^* = \underset{x}{\operatorname{arg\,min\,}} 99 | \frac{1}{2} ||f(x)||_2^2 \quad \text{subject to} \quad lb \leq x \leq ub 100 | 101 | The solution is found using variants of the Gauss-Newton method, a 102 | modification of Newton's method tailored to RSS problems. 103 | 104 | Parameters 105 | ---------- 106 | fun : callable 107 | Function which computes the vector of residuals, with the signature 108 | ``fun(x)``. The argument ``x`` passed to this 109 | function is a Tensor of shape (n,) (never a scalar, even for n=1). 110 | It must allocate and return a 1-D Tensor of shape (m,) or a scalar. 111 | x0 : Tensor or float 112 | Initial guess on independent variables, with shape (n,). If 113 | float, it will be treated as a 1-D Tensor with one element. 114 | bounds : 2-tuple of Tensor, optional 115 | Lower and upper bounds on independent variables. Defaults to no bounds. 116 | Each Tensor must match the size of `x0` or be a scalar, in the latter 117 | case a bound will be the same for all variables. Use ``inf`` with 118 | an appropriate sign to disable bounds on all or some variables. 119 | method : str, optional 120 | Algorithm to perform minimization. Default is 'trf'. 
121 | 122 | * 'trf' : Trust Region Reflective algorithm, particularly suitable 123 | for large sparse problems with bounds. Generally robust method. 124 | * 'dogbox' : COMING SOON. dogleg algorithm with rectangular trust regions, 125 | typical use case is small problems with bounds. Not recommended 126 | for problems with rank-deficient Jacobian. 127 | ftol : float or None, optional 128 | Tolerance for termination by the change of the cost function. The 129 | optimization process is stopped when ``dF < ftol * F``, 130 | and there was an adequate agreement between a local quadratic model and 131 | the true model in the last step. If None, the termination by this 132 | condition is disabled. Default is 1e-8. 133 | xtol : float or None, optional 134 | Tolerance for termination by the change of the independent variables. 135 | Termination occurs when ``norm(dx) < xtol * (xtol + norm(x))``. 136 | If None, the termination by this condition is disabled. Default is 1e-8. 137 | gtol : float or None, optional 138 | Tolerance for termination by the norm of the gradient. Default is 1e-8. 139 | The exact condition depends on `method` used: 140 | 141 | * For 'trf' : ``norm(g_scaled, ord=inf) < gtol``, where 142 | ``g_scaled`` is the value of the gradient scaled to account for 143 | the presence of the bounds [STIR]_. 144 | * For 'dogbox' : ``norm(g_free, ord=inf) < gtol``, where 145 | ``g_free`` is the gradient with respect to the variables which 146 | are not in the optimal state on the boundary. 147 | x_scale : Tensor or 'jac', optional 148 | Characteristic scale of each variable. Setting `x_scale` is equivalent 149 | to reformulating the problem in scaled variables ``xs = x / x_scale``. 150 | An alternative view is that the size of a trust region along jth 151 | dimension is proportional to ``x_scale[j]``. Improved convergence may 152 | be achieved by setting `x_scale` such that a step of a given size 153 | along any of the scaled variables has a similar effect on the cost 154 | function. If set to 'jac', the scale is iteratively updated using the 155 | inverse norms of the columns of the Jacobian matrix (as described in 156 | [JJMore]_). 157 | max_nfev : None or int, optional 158 | Maximum number of function evaluations before the termination. 159 | Defaults to 100 * n. 160 | tr_solver : str, optional 161 | Method for solving trust-region subproblems. 162 | 163 | * 'exact' is suitable for not very large problems with dense 164 | Jacobian matrices. The computational complexity per iteration is 165 | comparable to a singular value decomposition of the Jacobian 166 | matrix. 167 | * 'lsmr' is suitable for problems with sparse and large Jacobian 168 | matrices. It uses an iterative procedure for finding a solution 169 | of a linear least-squares problem and only requires matrix-vector 170 | product evaluations. 171 | tr_options : dict, optional 172 | Keyword options passed to trust-region solver. 173 | 174 | * ``tr_solver='exact'``: `tr_options` are ignored. 175 | * ``tr_solver='lsmr'``: options for `scipy.sparse.linalg.lsmr`. 176 | Additionally, ``method='trf'`` supports 'regularize' option 177 | (bool, default is True), which adds a regularization term to the 178 | normal equation, which improves convergence if the Jacobian is 179 | rank-deficient [Byrd]_ (eq. 3.4). 180 | verbose : int, optional 181 | Level of algorithm's verbosity. 182 | 183 | * 0 : work silently (default). 184 | * 1 : display a termination report. 185 | * 2 : display progress during iterations. 
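A minimal end-to-end sketch of a bounded call (the exponential model and synthetic data below are illustrative, not part of the library):

```python
import torch
from torchmin import least_squares

t_data = torch.linspace(0, 1, 20)
y_data = 2.0 * torch.exp(-1.3 * t_data)

def residuals(p):
    # p = (amplitude, decay rate); returns the m-dimensional residual vector
    return p[0] * torch.exp(-p[1] * t_data) - y_data

p0 = torch.tensor([1.0, 1.0])
res = least_squares(residuals, p0, bounds=(0., 10.), method='trf')
print(res.x)  # expected to approach (2.0, 1.3)
```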
186 | 187 | Returns 188 | ------- 189 | result : OptimizeResult 190 | Result of the optimization routine. 191 | 192 | References 193 | ---------- 194 | .. [STIR] M. A. Branch, T. F. Coleman, and Y. Li, "A Subspace, Interior, 195 | and Conjugate Gradient Method for Large-Scale Bound-Constrained 196 | Minimization Problems," SIAM Journal on Scientific Computing, 197 | Vol. 21, Number 1, pp 1-23, 1999. 198 | .. [Byrd] R. H. Byrd, R. B. Schnabel and G. A. Shultz, "Approximate 199 | solution of the trust region problem by minimization over 200 | two-dimensional subspaces", Math. Programming, 40, pp. 247-263, 201 | 1988. 202 | .. [JJMore] J. J. More, "The Levenberg-Marquardt Algorithm: Implementation 203 | and Theory," Numerical Analysis, ed. G. A. Watson, Lecture 204 | Notes in Mathematics 630, Springer Verlag, pp. 105-116, 1977. 205 | 206 | """ 207 | if tr_options is None: 208 | tr_options = {} 209 | 210 | if method not in ['trf', 'dogbox']: 211 | raise ValueError("`method` must be 'trf' or 'dogbox'.") 212 | 213 | if tr_solver not in ['exact', 'lsmr', 'cgls']: 214 | raise ValueError("`tr_solver` must be one of {'exact', 'lsmr', 'cgls'}.") 215 | 216 | if verbose not in [0, 1, 2]: 217 | raise ValueError("`verbose` must be in [0, 1, 2].") 218 | 219 | if bounds is None: 220 | bounds = (-float('inf'), float('inf')) 221 | elif not (isinstance(bounds, (tuple, list)) and len(bounds) == 2): 222 | raise ValueError("`bounds` must be a tuple/list with 2 elements.") 223 | 224 | if max_nfev is not None and max_nfev <= 0: 225 | raise ValueError("`max_nfev` must be None or positive integer.") 226 | 227 | # initial point 228 | x0 = torch.atleast_1d(x0) 229 | if torch.is_complex(x0): 230 | raise ValueError("`x0` must be real.") 231 | elif x0.dim() > 1: 232 | raise ValueError("`x0` must have at most 1 dimension.") 233 | 234 | # bounds 235 | lb, ub = prepare_bounds(bounds, x0) 236 | if lb.shape != x0.shape or ub.shape != x0.shape: 237 | raise ValueError("Inconsistent shapes between bounds and `x0`.") 238 | elif torch.any(lb >= ub): 239 | raise ValueError("Each lower bound must be strictly less than each " 240 | "upper bound.") 241 | elif not in_bounds(x0, lb, ub): 242 | raise ValueError("`x0` is infeasible.") 243 | 244 | # x_scale 245 | x_scale = check_x_scale(x_scale, x0) 246 | 247 | # tolerance 248 | ftol, xtol, gtol = check_tolerance(ftol, xtol, gtol, method) 249 | 250 | if method == 'trf': 251 | x0 = make_strictly_feasible(x0, lb, ub) 252 | 253 | def fun_wrapped(x): 254 | return torch.atleast_1d(fun(x)) 255 | 256 | # check function 257 | f0 = fun_wrapped(x0) 258 | if f0.dim() != 1: 259 | raise ValueError("`fun` must return at most 1-d array_like. 
" 260 | "f0.shape: {0}".format(f0.shape)) 261 | elif not f0.isfinite().all(): 262 | raise ValueError("Residuals are not finite in the initial point.") 263 | 264 | initial_cost = 0.5 * f0.dot(f0) 265 | 266 | if isinstance(x_scale, str) and x_scale == 'jac': 267 | raise ValueError("x_scale='jac' can't be used when `jac` " 268 | "returns LinearOperator.") 269 | 270 | if method == 'trf': 271 | result = trf(fun_wrapped, x0, f0, lb, ub, ftol, xtol, gtol, 272 | max_nfev, x_scale, tr_solver, tr_options.copy(), verbose) 273 | elif method == 'dogbox': 274 | raise NotImplementedError("'dogbox' method not yet implemented") 275 | # if tr_solver == 'lsmr' and 'regularize' in tr_options: 276 | # warn("The keyword 'regularize' in `tr_options` is not relevant " 277 | # "for 'dogbox' method.") 278 | # tr_options = tr_options.copy() 279 | # del tr_options['regularize'] 280 | # result = dogbox(fun_wrapped, x0, f0, lb, ub, ftol, xtol, gtol, 281 | # max_nfev, x_scale, tr_solver, tr_options, verbose) 282 | else: 283 | raise ValueError("`method` must be 'trf' or 'dogbox'.") 284 | 285 | result.message = TERMINATION_MESSAGES[result.status] 286 | result.success = result.status > 0 287 | 288 | if verbose >= 1: 289 | print(result.message) 290 | print("Function evaluations {0}, initial cost {1:.4e}, final cost " 291 | "{2:.4e}, first-order optimality {3:.2e}." 292 | .format(result.nfev, initial_cost, result.cost, 293 | result.optimality)) 294 | 295 | return result -------------------------------------------------------------------------------- /torchmin/lstsq/linear_operator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | from torch._vmap_internals import _vmap 4 | 5 | 6 | def jacobian_dense(fun, x, vectorize=True): 7 | x = x.detach().requires_grad_(True) 8 | return autograd.functional.jacobian(fun, x, vectorize=vectorize) 9 | 10 | 11 | def jacobian_linop(fun, x, return_f=False): 12 | x = x.detach().requires_grad_(True) 13 | with torch.enable_grad(): 14 | f = fun(x) 15 | 16 | # vector-jacobian product 17 | def vjp(v): 18 | v = v.view_as(f) 19 | vjp, = autograd.grad(f, x, v, retain_graph=True) 20 | return vjp.view(-1) 21 | 22 | # jacobian-vector product 23 | gf = torch.zeros_like(f, requires_grad=True) 24 | with torch.enable_grad(): 25 | gx, = autograd.grad(f, x, gf, create_graph=True) 26 | def jvp(v): 27 | v = v.view_as(x) 28 | jvp, = autograd.grad(gx, gf, v, retain_graph=True) 29 | return jvp.view(-1) 30 | 31 | jac = TorchLinearOperator((f.numel(), x.numel()), matvec=jvp, rmatvec=vjp) 32 | 33 | if return_f: 34 | return jac, f.detach() 35 | return jac 36 | 37 | 38 | class TorchLinearOperator(object): 39 | """Linear operator defined in terms of user-specified operations.""" 40 | def __init__(self, shape, matvec, rmatvec): 41 | self.shape = shape 42 | self._matvec = matvec 43 | self._rmatvec = rmatvec 44 | 45 | def matvec(self, x): 46 | return self._matvec(x) 47 | 48 | def rmatvec(self, x): 49 | return self._rmatvec(x) 50 | 51 | def matmat(self, X): 52 | try: 53 | return _vmap(self.matvec)(X.T).T 54 | except: 55 | return torch.hstack([self.matvec(col).view(-1,1) for col in X.T]) 56 | 57 | def transpose(self): 58 | new_shape = (self.shape[1], self.shape[0]) 59 | return type(self)(new_shape, self._rmatvec, self._matvec) 60 | 61 | mv = matvec 62 | rmv = rmatvec 63 | matmul = matmat 64 | t = transpose 65 | T = property(transpose) 66 | 67 | 68 | def aslinearoperator(A): 69 | if isinstance(A, TorchLinearOperator): 70 | return A 71 | 
elif isinstance(A, torch.Tensor): 72 | assert A.dim() == 2 73 | return TorchLinearOperator(A.shape, matvec=A.mv, rmatvec=A.T.mv) 74 | else: 75 | raise ValueError('Input must be either a Tensor or TorchLinearOperator') -------------------------------------------------------------------------------- /torchmin/lstsq/lsmr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code modified from scipy.sparse.linalg.lsmr 3 | 4 | Copyright (C) 2010 David Fong and Michael Saunders 5 | """ 6 | import torch 7 | 8 | from .linear_operator import aslinearoperator 9 | 10 | 11 | def _sym_ortho(a, b, out): 12 | torch.hypot(a, b, out=out[2]) 13 | torch.div(a, out[2], out=out[0]) 14 | torch.div(b, out[2], out=out[1]) 15 | return out 16 | 17 | 18 | @torch.no_grad() 19 | def lsmr(A, b, damp=0., atol=1e-6, btol=1e-6, conlim=1e8, maxiter=None, 20 | x0=None, check_nonzero=True): 21 | """Iterative solver for least-squares problems. 22 | 23 | lsmr solves the system of linear equations ``Ax = b``. If the system 24 | is inconsistent, it solves the least-squares problem ``min ||b - Ax||_2``. 25 | ``A`` is a rectangular matrix of dimension m-by-n, where all cases are 26 | allowed: m = n, m > n, or m < n. ``b`` is a vector of length m. 27 | The matrix A may be dense or sparse (usually sparse). 28 | 29 | Parameters 30 | ---------- 31 | A : {matrix, sparse matrix, ndarray, LinearOperator} 32 | Matrix A in the linear system. 33 | Alternatively, ``A`` can be a linear operator which can 34 | produce ``Ax`` and ``A^H x`` using, e.g., 35 | ``scipy.sparse.linalg.LinearOperator``. 36 | b : array_like, shape (m,) 37 | Vector ``b`` in the linear system. 38 | damp : float 39 | Damping factor for regularized least-squares. `lsmr` solves 40 | the regularized least-squares problem:: 41 | min ||(b) - ( A )x|| 42 | ||(0) (damp*I) ||_2 43 | where damp is a scalar. If damp is None or 0, the system 44 | is solved without regularization. 45 | atol, btol : float, optional 46 | Stopping tolerances. `lsmr` continues iterations until a 47 | certain backward error estimate is smaller than some quantity 48 | depending on atol and btol. Let ``r = b - Ax`` be the 49 | residual vector for the current approximate solution ``x``. 50 | If ``Ax = b`` seems to be consistent, ``lsmr`` terminates 51 | when ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``. 52 | Otherwise, lsmr terminates when ``norm(A^H r) <= 53 | atol * norm(A) * norm(r)``. If both tolerances are 1.0e-6 (say), 54 | the final ``norm(r)`` should be accurate to about 6 55 | digits. (The final ``x`` will usually have fewer correct digits, 56 | depending on ``cond(A)`` and the size of LAMBDA.) If `atol` 57 | or `btol` is None, a default value of 1.0e-6 will be used. 58 | Ideally, they should be estimates of the relative error in the 59 | entries of ``A`` and ``b`` respectively. For example, if the entries 60 | of ``A`` have 7 correct digits, set ``atol = 1e-7``. This prevents 61 | the algorithm from doing unnecessary work beyond the 62 | uncertainty of the input data. 63 | conlim : float, optional 64 | `lsmr` terminates if an estimate of ``cond(A)`` exceeds 65 | `conlim`. For compatible systems ``Ax = b``, conlim could be 66 | as large as 1.0e+12 (say). For least-squares problems, 67 | `conlim` should be less than 1.0e+8. If `conlim` is None, the 68 | default value is 1e+8. Maximum precision can be obtained by 69 | setting ``atol = btol = conlim = 0``, but the number of 70 | iterations may then be excessive. 
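# Illustrative sketch (not part of the source file): per the docstring above,
# `A` may be a linear operator. torchmin ships TorchLinearOperator for this
# (defined in linear_operator.py just above), and aslinearoperator() wraps a
# plain 2-D tensor on the fly. Dimensions below are arbitrary.
import torch
from torchmin.lstsq.linear_operator import TorchLinearOperator, aslinearoperator

A = torch.randn(6, 3, dtype=torch.float64)
op = TorchLinearOperator(A.shape, matvec=A.mv, rmatvec=A.T.mv)
v = torch.randn(3, dtype=torch.float64)
u = torch.randn(6, dtype=torch.float64)
assert torch.allclose(op.mv(v), A.mv(v))       # forward product, A @ v
assert torch.allclose(op.rmv(u), A.T.mv(u))    # adjoint product, A.T @ u
assert torch.allclose(op.T.mv(u), A.T.mv(u))   # .T swaps matvec and rmatvec
assert torch.allclose(aslinearoperator(A).mv(v), op.mv(v))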
71 | maxiter : int, optional 72 | `lsmr` terminates if the number of iterations reaches 73 | `maxiter`. The default is ``maxiter = min(m, n)``. For 74 | ill-conditioned systems, a larger value of `maxiter` may be 75 | needed. 76 | x0 : array_like, shape (n,), optional 77 | Initial guess of ``x``, if None zeros are used. 78 | 79 | Returns 80 | ------- 81 | x : ndarray of float 82 | Least-square solution returned. 83 | itn : int 84 | Number of iterations used. 85 | 86 | """ 87 | A = aslinearoperator(A) 88 | b = torch.atleast_1d(b) 89 | if b.dim() > 1: 90 | b = b.squeeze() 91 | eps = torch.finfo(b.dtype).eps 92 | damp = torch.as_tensor(damp, dtype=b.dtype, device=b.device) 93 | ctol = 1 / conlim if conlim > 0 else 0. 94 | m, n = A.shape 95 | if maxiter is None: 96 | maxiter = min(m, n) 97 | 98 | u = b.clone() 99 | normb = b.norm() 100 | if x0 is None: 101 | x = b.new_zeros(n) 102 | beta = normb.clone() 103 | else: 104 | x = torch.atleast_1d(x0).clone() 105 | u.sub_(A.matvec(x)) 106 | beta = u.norm() 107 | 108 | if beta > 0: 109 | u.div_(beta) 110 | v = A.rmatvec(u) 111 | alpha = v.norm() 112 | else: 113 | v = b.new_zeros(n) 114 | alpha = b.new_tensor(0) 115 | 116 | v = torch.where(alpha > 0, v / alpha, v) 117 | 118 | # Initialize variables for 1st iteration. 119 | 120 | zetabar = alpha * beta 121 | alphabar = alpha.clone() 122 | rho = b.new_tensor(1) 123 | rhobar = b.new_tensor(1) 124 | cbar = b.new_tensor(1) 125 | sbar = b.new_tensor(0) 126 | 127 | h = v.clone() 128 | hbar = b.new_zeros(n) 129 | 130 | # Initialize variables for estimation of ||r||. 131 | 132 | betadd = beta.clone() 133 | betad = b.new_tensor(0) 134 | rhodold = b.new_tensor(1) 135 | tautildeold = b.new_tensor(0) 136 | thetatilde = b.new_tensor(0) 137 | zeta = b.new_tensor(0) 138 | d = b.new_tensor(0) 139 | 140 | # Initialize variables for estimation of ||A|| and cond(A) 141 | 142 | normA2 = alpha.square() 143 | maxrbar = b.new_tensor(0) 144 | minrbar = b.new_tensor(0.99 * torch.finfo(b.dtype).max) 145 | normA = normA2.sqrt() 146 | condA = b.new_tensor(1) 147 | normx = b.new_tensor(0) 148 | normar = b.new_tensor(0) 149 | normr = b.new_tensor(0) 150 | 151 | # extra buffers (added by Reuben) 152 | c = b.new_tensor(0) 153 | s = b.new_tensor(0) 154 | chat = b.new_tensor(0) 155 | shat = b.new_tensor(0) 156 | alphahat = b.new_tensor(0) 157 | ctildeold = b.new_tensor(0) 158 | stildeold = b.new_tensor(0) 159 | rhotildeold = b.new_tensor(0) 160 | rhoold = b.new_tensor(0) 161 | rhobarold = b.new_tensor(0) 162 | zetaold = b.new_tensor(0) 163 | thetatildeold = b.new_tensor(0) 164 | betaacute = b.new_tensor(0) 165 | betahat = b.new_tensor(0) 166 | betacheck = b.new_tensor(0) 167 | taud = b.new_tensor(0) 168 | 169 | 170 | # Main iteration loop. 171 | for itn in range(1, maxiter+1): 172 | 173 | # Perform the next step of the bidiagonalization to obtain the 174 | # next beta, u, alpha, v. These satisfy the relations 175 | # beta*u = a*v - alpha*u, 176 | # alpha*v = A'*u - beta*v. 177 | 178 | u.mul_(-alpha).add_(A.matvec(v)) 179 | torch.norm(u, out=beta) 180 | 181 | if (not check_nonzero) or beta > 0: 182 | # check_nonzero option provides a means to avoid the GPU-CPU 183 | # synchronization of a `beta > 0` check. For most cases 184 | # beta == 0 is unlikely, but use this option with caution. 185 | u.div_(beta) 186 | v.mul_(-beta).add_(A.rmatvec(u)) 187 | torch.norm(v, out=alpha) 188 | v = torch.where(alpha > 0, v / alpha, v) 189 | 190 | # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}. 
191 | 192 | _sym_ortho(alphabar, damp, out=(chat, shat, alphahat)) 193 | 194 | # Use a plane rotation (Q_i) to turn B_i to R_i 195 | 196 | rhoold.copy_(rho, non_blocking=True) 197 | _sym_ortho(alphahat, beta, out=(c, s, rho)) 198 | thetanew = torch.mul(s, alpha) 199 | torch.mul(c, alpha, out=alphabar) 200 | 201 | # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar 202 | 203 | rhobarold.copy_(rhobar, non_blocking=True) 204 | zetaold.copy_(zeta, non_blocking=True) 205 | thetabar = sbar * rho 206 | rhotemp = cbar * rho 207 | _sym_ortho(cbar * rho, thetanew, out=(cbar, sbar, rhobar)) 208 | torch.mul(cbar, zetabar, out=zeta) 209 | zetabar.mul_(-sbar) 210 | 211 | # Update h, h_hat, x. 212 | 213 | hbar.mul_(-thetabar * rho).div_(rhoold * rhobarold) 214 | hbar.add_(h) 215 | x.addcdiv_(zeta * hbar, rho * rhobar) 216 | h.mul_(-thetanew).div_(rho) 217 | h.add_(v) 218 | 219 | # Estimate of ||r||. 220 | 221 | # Apply rotation Qhat_{k,2k+1}. 222 | torch.mul(chat, betadd, out=betaacute) 223 | torch.mul(-shat, betadd, out=betacheck) 224 | 225 | # Apply rotation Q_{k,k+1}. 226 | torch.mul(c, betaacute, out=betahat) 227 | torch.mul(-s, betaacute, out=betadd) 228 | 229 | # Apply rotation Qtilde_{k-1}. 230 | # betad = betad_{k-1} here. 231 | 232 | thetatildeold.copy_(thetatilde, non_blocking=True) 233 | _sym_ortho(rhodold, thetabar, out=(ctildeold, stildeold, rhotildeold)) 234 | torch.mul(stildeold, rhobar, out=thetatilde) 235 | torch.mul(ctildeold, rhobar, out=rhodold) 236 | betad.mul_(-stildeold).addcmul_(ctildeold, betahat) 237 | 238 | # betad = betad_k here. 239 | # rhodold = rhod_k here. 240 | 241 | tautildeold.mul_(-thetatildeold).add_(zetaold).div_(rhotildeold) 242 | torch.div(zeta - thetatilde * tautildeold, rhodold, out=taud) 243 | d.addcmul_(betacheck, betacheck) 244 | torch.sqrt(d + (betad - taud).square() + betadd.square(), out=normr) 245 | 246 | # Estimate ||A||. 247 | normA2.addcmul_(beta, beta) 248 | torch.sqrt(normA2, out=normA) 249 | normA2.addcmul_(alpha, alpha) 250 | 251 | # Estimate cond(A). 252 | torch.max(maxrbar, rhobarold, out=maxrbar) 253 | if itn > 1: 254 | torch.min(minrbar, rhobarold, out=minrbar) 255 | 256 | 257 | # ------- Test for convergence -------- 258 | 259 | if itn % 10 == 0: 260 | 261 | # Compute norms for convergence testing. 262 | torch.abs(zetabar, out=normar) 263 | torch.norm(x, out=normx) 264 | torch.div(torch.max(maxrbar, rhotemp), torch.min(minrbar, rhotemp), 265 | out=condA) 266 | 267 | # Now use these norms to estimate certain other quantities, 268 | # some of which will be small near a solution. 269 | test1 = normr / normb 270 | test2 = normar / (normA * normr + eps) 271 | test3 = 1 / (condA + eps) 272 | t1 = test1 / (1 + normA * normx / normb) 273 | rtol = btol + atol * normA * normx / normb 274 | 275 | # The first 3 tests guard against extremely small values of 276 | # atol, btol or ctol. (The user may have set any or all of 277 | # the parameters atol, btol, conlim to 0.) 278 | # The effect is equivalent to the normAl tests using 279 | # atol = eps, btol = eps, conlim = 1/eps. 280 | 281 | # The second 3 tests allow for tolerances set by the user. 
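# Minimal usage sketch for lsmr() as defined in this file (the data below are
# made up): an overdetermined dense system solved in the least-squares sense.
# A dense tensor is accepted for `A` because aslinearoperator() wraps it.
import torch
from torchmin.lstsq.lsmr import lsmr

A = torch.randn(50, 8, dtype=torch.float64)
x_true = torch.randn(8, dtype=torch.float64)
b = A.mv(x_true) + 0.01 * torch.randn(50, dtype=torch.float64)
x, itn = lsmr(A, b, atol=1e-10, btol=1e-10)
# x approximately minimizes ||A x - b||_2; the normal-equation residual
# A^T (A x - b) should be near zero for this well-conditioned problem.
print(itn, A.T.mv(A.mv(x) - b).norm())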
282 | 283 | stop = ((1 + test3 <= 1) | (1 + test2 <= 1) | (1 + t1 <= 1) 284 | | (test3 <= ctol) | (test2 <= atol) | (test1 <= rtol)) 285 | 286 | if stop: 287 | break 288 | 289 | return x, itn -------------------------------------------------------------------------------- /torchmin/lstsq/trf.py: -------------------------------------------------------------------------------- 1 | """Trust Region Reflective algorithm for least-squares optimization. 2 | """ 3 | import torch 4 | import numpy as np 5 | from scipy.optimize import OptimizeResult 6 | from scipy.optimize._lsq.common import (print_header_nonlinear, 7 | print_iteration_nonlinear) 8 | 9 | from .cg import cgls 10 | from .lsmr import lsmr 11 | from .linear_operator import jacobian_linop, jacobian_dense 12 | from .common import (right_multiplied_operator, build_quadratic_1d, 13 | minimize_quadratic_1d, evaluate_quadratic, 14 | solve_trust_region_2d, check_termination, 15 | update_tr_radius, solve_lsq_trust_region) 16 | 17 | 18 | def trf(fun, x0, f0, lb, ub, ftol, xtol, gtol, max_nfev, x_scale, 19 | tr_solver, tr_options, verbose): 20 | # For efficiency, it makes sense to run the simplified version of the 21 | # algorithm when no bounds are imposed. We decided to write the two 22 | # separate functions. It violates the DRY principle, but the individual 23 | # functions are kept the most readable. 24 | if lb.isneginf().all() and ub.isposinf().all(): 25 | return trf_no_bounds( 26 | fun, x0, f0, ftol, xtol, gtol, max_nfev, x_scale, 27 | tr_solver, tr_options, verbose) 28 | else: 29 | raise NotImplementedError('trf with bounds not currently supported.') 30 | 31 | 32 | def trf_no_bounds(fun, x0, f0=None, ftol=1e-8, xtol=1e-8, gtol=1e-8, 33 | max_nfev=None, x_scale=1.0, tr_solver='lsmr', 34 | tr_options=None, verbose=0): 35 | if max_nfev is None: 36 | max_nfev = x0.numel() * 100 37 | if tr_options is None: 38 | tr_options = {} 39 | assert tr_solver in ['exact', 'lsmr', 'cgls'] 40 | if tr_solver == 'exact': 41 | jacobian = jacobian_dense 42 | else: 43 | jacobian = jacobian_linop 44 | 45 | x = x0.clone() 46 | if f0 is None: 47 | f = fun(x) 48 | else: 49 | f = f0 50 | f_true = f.clone() 51 | J = jacobian(fun, x) 52 | nfev = njev = 1 53 | m, n = J.shape 54 | 55 | cost = 0.5 * f.dot(f) 56 | g = J.T.mv(f) 57 | 58 | scale = x_scale 59 | Delta = (x0 / scale).norm() 60 | if Delta == 0: 61 | Delta.fill_(1.) 62 | 63 | if tr_solver != 'exact': 64 | damp = tr_options.pop('damp', 1e-4) 65 | regularize = tr_options.pop('regularize', False) 66 | reg_term = 0. 67 | 68 | alpha = x0.new_tensor(0.) 
# "Levenberg-Marquardt" parameter 69 | termination_status = None 70 | iteration = 0 71 | step_norm = None 72 | actual_reduction = None 73 | 74 | if verbose == 2: 75 | print_header_nonlinear() 76 | 77 | while True: 78 | g_norm = g.norm(np.inf) 79 | if g_norm < gtol: 80 | termination_status = 1 81 | 82 | if verbose == 2: 83 | print_iteration_nonlinear(iteration, nfev, cost, actual_reduction, 84 | step_norm, g_norm) 85 | 86 | if termination_status is not None or nfev == max_nfev: 87 | break 88 | 89 | d = scale 90 | g_h = d * g 91 | 92 | if tr_solver == 'exact': 93 | J_h = J * d 94 | U, s, V = torch.linalg.svd(J_h, full_matrices=False) 95 | V = V.T 96 | uf = U.T.mv(f) 97 | else: 98 | J_h = right_multiplied_operator(J, d) 99 | 100 | if regularize: 101 | a, b = build_quadratic_1d(J_h, g_h, -g_h) 102 | to_tr = Delta / g_h.norm() 103 | ag_value = minimize_quadratic_1d(a, b, 0, to_tr)[1] 104 | reg_term = -ag_value / Delta**2 105 | 106 | damp_full = (damp**2 + reg_term)**0.5 107 | if tr_solver == 'lsmr': 108 | gn_h = lsmr(J_h, f, damp=damp_full, **tr_options)[0] 109 | elif tr_solver == 'cgls': 110 | gn_h = cgls(J_h, f, alpha=damp_full, max_iter=min(m,n), **tr_options) 111 | else: 112 | raise RuntimeError 113 | S = torch.vstack((g_h, gn_h)).T # [n,2] 114 | # Dispatch qr to CPU so long as pytorch/pytorch#22573 is not fixed 115 | S = torch.linalg.qr(S.cpu(), mode='reduced')[0].to(S.device) # [n,2] 116 | JS = J_h.matmul(S) # [m,2] 117 | B_S = JS.T.matmul(JS) # [2,2] 118 | g_S = S.T.mv(g_h) # [2] 119 | 120 | actual_reduction = -1 121 | while actual_reduction <= 0 and nfev < max_nfev: 122 | if tr_solver == 'exact': 123 | step_h, alpha, n_iter = solve_lsq_trust_region( 124 | n, m, uf, s, V, Delta, initial_alpha=alpha) 125 | else: 126 | p_S, _ = solve_trust_region_2d(B_S, g_S, Delta) 127 | step_h = S.matmul(p_S) 128 | 129 | predicted_reduction = -evaluate_quadratic(J_h, g_h, step_h) 130 | step = d * step_h 131 | x_new = x + step 132 | f_new = fun(x_new) 133 | nfev += 1 134 | 135 | step_h_norm = step_h.norm() 136 | 137 | if not f_new.isfinite().all(): 138 | Delta = 0.25 * step_h_norm 139 | continue 140 | 141 | # Usual trust-region step quality estimation. 
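# Rough usage sketch for trf_no_bounds() (illustrative only; this routine is
# normally reached through the least-squares driver rather than called
# directly). The residual function and data are made up; x_scale is passed
# explicitly as a tensor here, and with tr_solver='lsmr' the Jacobian comes
# from autograd via jacobian_linop.
import torch
from torchmin.lstsq.trf import trf_no_bounds

def residuals(x):
    # fit a * exp(b * t) to the synthetic target 2 * exp(0.5 * t)
    t = torch.linspace(0, 1, 20, dtype=x.dtype)
    return x[0] * torch.exp(x[1] * t) - 2.0 * torch.exp(0.5 * t)

x0 = torch.tensor([1.0, 0.0], dtype=torch.float64)
res = trf_no_bounds(residuals, x0, x_scale=torch.ones_like(x0),
                    tr_solver='lsmr')
print(res.x, res.cost, res.status)   # x should approach (2.0, 0.5)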
142 | cost_new = 0.5 * f_new.dot(f_new) 143 | actual_reduction = cost - cost_new 144 | 145 | Delta_new, ratio = update_tr_radius( 146 | Delta, actual_reduction, predicted_reduction, 147 | step_h_norm, step_h_norm > 0.95 * Delta) 148 | 149 | step_norm = step.norm() 150 | termination_status = check_termination( 151 | actual_reduction, cost, step_norm, x.norm(), ratio, ftol, xtol) 152 | if termination_status is not None: 153 | break 154 | 155 | alpha *= Delta / Delta_new 156 | Delta = Delta_new 157 | 158 | if actual_reduction > 0: 159 | x, f, cost = x_new, f_new, cost_new 160 | f_true.copy_(f) 161 | J = jacobian(fun, x) 162 | g = J.T.mv(f) 163 | njev += 1 164 | else: 165 | step_norm = 0 166 | actual_reduction = 0 167 | 168 | iteration += 1 169 | 170 | if termination_status is None: 171 | termination_status = 0 172 | 173 | active_mask = torch.zeros_like(x) 174 | return OptimizeResult( 175 | x=x, cost=cost, fun=f_true, jac=J, grad=g, optimality=g_norm, 176 | active_mask=active_mask, nfev=nfev, njev=njev, 177 | status=termination_status) 178 | -------------------------------------------------------------------------------- /torchmin/minimize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .bfgs import _minimize_bfgs, _minimize_lbfgs 4 | from .cg import _minimize_cg 5 | from .newton import _minimize_newton_cg, _minimize_newton_exact 6 | from .trustregion import (_minimize_trust_exact, _minimize_dogleg, 7 | _minimize_trust_ncg, _minimize_trust_krylov) 8 | 9 | _tolerance_keys = { 10 | 'l-bfgs': 'gtol', 11 | 'bfgs': 'gtol', 12 | 'cg': 'gtol', 13 | 'newton-cg': 'xtol', 14 | 'newton-exact': 'xtol', 15 | 'dogleg': 'gtol', 16 | 'trust-ncg': 'gtol', 17 | 'trust-exact': 'gtol', 18 | 'trust-krylov': 'gtol' 19 | } 20 | 21 | 22 | def minimize( 23 | fun, x0, method, max_iter=None, tol=None, options=None, callback=None, 24 | disp=0, return_all=False): 25 | """Minimize a scalar function of one or more variables. 26 | 27 | .. note:: 28 | This is a general-purpose minimizer that calls one of the available 29 | routines based on a supplied `method` argument. 30 | 31 | Parameters 32 | ---------- 33 | fun : callable 34 | Scalar objective function to minimize. 35 | x0 : Tensor 36 | Initialization point. 37 | method : str 38 | The minimization routine to use. Should be one of 39 | 40 | - 'bfgs' 41 | - 'l-bfgs' 42 | - 'cg' 43 | - 'newton-cg' 44 | - 'newton-exact' 45 | - 'dogleg' 46 | - 'trust-ncg' 47 | - 'trust-exact' 48 | - 'trust-krylov' 49 | 50 | At the moment, method must be specified; there is no default. 51 | max_iter : int, optional 52 | Maximum number of iterations to perform. If unspecified, this will 53 | be set to the default of the selected method. 54 | tol : float 55 | Tolerance for termination. For detailed control, use solver-specific 56 | options. 57 | options : dict, optional 58 | A dictionary of keyword arguments to pass to the selected minimization 59 | routine. 60 | callback : callable, optional 61 | Function to call after each iteration with the current parameter 62 | state, e.g. ``callback(x)``. 63 | disp : int or bool 64 | Display (verbosity) level. Set to >0 to print status messages. 65 | return_all : bool, optional 66 | Set to True to return a list of the best solution at each of the 67 | iterations. 68 | 69 | Returns 70 | ------- 71 | result : OptimizeResult 72 | Result of the optimization routine. 
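# Brief usage sketch to complement the docstring above (the objective is made
# up): a smooth convex function solved with 'newton-cg', exercising `tol`,
# `callback`, and `return_all`.
import torch
from torchmin import minimize

def fun(x):
    return torch.log(torch.cosh(x - 1.0)).sum() + 0.5 * (x ** 2).sum()

history = []
res = minimize(fun, torch.zeros(3), method='newton-cg',
               tol=1e-9, callback=history.append, return_all=True)
print(res.x, res.fun, res.nit)          # solution, objective value, iterations
print(len(history), len(res.allvecs))   # per-iteration states recorded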
73 | 74 | """ 75 | x0 = torch.as_tensor(x0) 76 | method = method.lower() 77 | assert method in ['bfgs', 'l-bfgs', 'cg', 'newton-cg', 'newton-exact', 78 | 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'] 79 | if options is None: 80 | options = {} 81 | if tol is not None: 82 | options.setdefault(_tolerance_keys[method], tol) 83 | options.setdefault('max_iter', max_iter) 84 | options.setdefault('callback', callback) 85 | options.setdefault('disp', disp) 86 | options.setdefault('return_all', return_all) 87 | 88 | if method == 'bfgs': 89 | return _minimize_bfgs(fun, x0, **options) 90 | elif method == 'l-bfgs': 91 | return _minimize_lbfgs(fun, x0, **options) 92 | elif method == 'cg': 93 | return _minimize_cg(fun, x0, **options) 94 | elif method == 'newton-cg': 95 | return _minimize_newton_cg(fun, x0, **options) 96 | elif method == 'newton-exact': 97 | return _minimize_newton_exact(fun, x0, **options) 98 | elif method == 'dogleg': 99 | return _minimize_dogleg(fun, x0, **options) 100 | elif method == 'trust-ncg': 101 | return _minimize_trust_ncg(fun, x0, **options) 102 | elif method == 'trust-exact': 103 | return _minimize_trust_exact(fun, x0, **options) 104 | elif method == 'trust-krylov': 105 | return _minimize_trust_krylov(fun, x0, **options) 106 | else: 107 | raise RuntimeError('invalid method "{}" encountered.'.format(method)) -------------------------------------------------------------------------------- /torchmin/minimize_constr.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numbers 3 | import torch 4 | import numpy as np 5 | from scipy.optimize import minimize, Bounds, NonlinearConstraint 6 | from scipy.sparse.linalg import LinearOperator 7 | 8 | _constr_keys = {'fun', 'lb', 'ub', 'jac', 'hess', 'hessp', 'keep_feasible'} 9 | _bounds_keys = {'lb', 'ub', 'keep_feasible'} 10 | 11 | 12 | def _build_obj(f, x0): 13 | numel = x0.numel() 14 | 15 | def to_tensor(x): 16 | return torch.tensor(x, dtype=x0.dtype, device=x0.device).view_as(x0) 17 | 18 | def f_with_jac(x): 19 | x = to_tensor(x).requires_grad_(True) 20 | with torch.enable_grad(): 21 | fval = f(x) 22 | grad, = torch.autograd.grad(fval, x) 23 | return fval.detach().cpu().numpy(), grad.view(-1).cpu().numpy() 24 | 25 | def f_hess(x): 26 | x = to_tensor(x).requires_grad_(True) 27 | with torch.enable_grad(): 28 | fval = f(x) 29 | grad, = torch.autograd.grad(fval, x, create_graph=True) 30 | def matvec(p): 31 | p = to_tensor(p) 32 | hvp, = torch.autograd.grad(grad, x, p, retain_graph=True) 33 | return hvp.view(-1).cpu().numpy() 34 | return LinearOperator((numel, numel), matvec=matvec) 35 | 36 | return f_with_jac, f_hess 37 | 38 | 39 | def _build_constr(constr, x0): 40 | assert isinstance(constr, dict) 41 | assert set(constr.keys()).issubset(_constr_keys) 42 | assert 'fun' in constr 43 | assert 'lb' in constr or 'ub' in constr 44 | if 'lb' not in constr: 45 | constr['lb'] = -np.inf 46 | if 'ub' not in constr: 47 | constr['ub'] = np.inf 48 | f_ = constr['fun'] 49 | numel = x0.numel() 50 | 51 | def to_tensor(x): 52 | return torch.tensor(x, dtype=x0.dtype, device=x0.device).view_as(x0) 53 | 54 | def f(x): 55 | x = to_tensor(x) 56 | return f_(x).cpu().numpy() 57 | 58 | def f_jac(x): 59 | x = to_tensor(x) 60 | if 'jac' in constr: 61 | grad = constr['jac'](x) 62 | else: 63 | x.requires_grad_(True) 64 | with torch.enable_grad(): 65 | grad, = torch.autograd.grad(f_(x), x) 66 | return grad.view(-1).cpu().numpy() 67 | 68 | def f_hess(x, v): 69 | x = to_tensor(x) 70 | if 'hess' in 
constr: 71 | hess = constr['hess'](x) 72 | return v[0] * hess.view(numel, numel).cpu().numpy() 73 | elif 'hessp' in constr: 74 | def matvec(p): 75 | p = to_tensor(p) 76 | hvp = constr['hessp'](x, p) 77 | return v[0] * hvp.view(-1).cpu().numpy() 78 | return LinearOperator((numel, numel), matvec=matvec) 79 | else: 80 | x.requires_grad_(True) 81 | with torch.enable_grad(): 82 | if 'jac' in constr: 83 | grad = constr['jac'](x) 84 | else: 85 | grad, = torch.autograd.grad(f_(x), x, create_graph=True) 86 | def matvec(p): 87 | p = to_tensor(p) 88 | if grad.grad_fn is None: 89 | # If grad_fn is None, then grad is constant wrt x, and hess is 0. 90 | hvp = torch.zeros_like(grad) 91 | else: 92 | hvp, = torch.autograd.grad(grad, x, p, retain_graph=True) 93 | return v[0] * hvp.view(-1).cpu().numpy() 94 | return LinearOperator((numel, numel), matvec=matvec) 95 | 96 | return NonlinearConstraint( 97 | fun=f, lb=constr['lb'], ub=constr['ub'], 98 | jac=f_jac, hess=f_hess, 99 | keep_feasible=constr.get('keep_feasible', False)) 100 | 101 | 102 | def _check_bound(val, x0): 103 | if isinstance(val, numbers.Number): 104 | return np.full(x0.numel(), val) 105 | elif isinstance(val, torch.Tensor): 106 | assert val.numel() == x0.numel() 107 | return val.detach().cpu().numpy().flatten() 108 | elif isinstance(val, np.ndarray): 109 | assert val.size == x0.numel() 110 | return val.flatten() 111 | else: 112 | raise ValueError('Bound value has unrecognized format.') 113 | 114 | 115 | def _build_bounds(bounds, x0): 116 | assert isinstance(bounds, dict) 117 | assert set(bounds.keys()).issubset(_bounds_keys) 118 | assert 'lb' in bounds or 'ub' in bounds 119 | lb = _check_bound(bounds.get('lb', -np.inf), x0) 120 | ub = _check_bound(bounds.get('ub', np.inf), x0) 121 | keep_feasible = bounds.get('keep_feasible', False) 122 | 123 | return Bounds(lb, ub, keep_feasible) 124 | 125 | 126 | @torch.no_grad() 127 | def minimize_constr( 128 | f, x0, constr=None, bounds=None, max_iter=None, tol=None, callback=None, 129 | disp=0, **kwargs): 130 | """Minimize a scalar function of one or more variables subject to 131 | bounds and/or constraints. 132 | 133 | .. note:: 134 | This is a wrapper for SciPy's 135 | `'trust-constr' `_ 136 | method. It uses autograd behind the scenes to build jacobian & hessian 137 | callables before invoking scipy. Inputs and objectivs should use 138 | PyTorch tensors like other routines. CUDA is supported; however, 139 | data will be transferred back-and-forth between GPU/CPU. 140 | 141 | Parameters 142 | ---------- 143 | f : callable 144 | Scalar objective function to minimize. 145 | x0 : Tensor 146 | Initialization point. 147 | constr : dict, optional 148 | Constraint specifications. Should be a dictionary with the 149 | following fields: 150 | 151 | * fun (callable) - Constraint function 152 | * lb (Tensor or float, optional) - Constraint lower bounds 153 | * ub : (Tensor or float, optional) - Constraint upper bounds 154 | 155 | One of either `lb` or `ub` must be provided. When `lb` == `ub` it is 156 | interpreted as an equality constraint. 157 | bounds : dict, optional 158 | Bounds on variables. Should a dictionary with at least one 159 | of the following fields: 160 | 161 | * lb (Tensor or float) - Lower bounds 162 | * ub (Tensor or float) - Upper bounds 163 | 164 | Bounds of `-inf`/`inf` are interpreted as no bound. When `lb` == `ub` 165 | it is interpreted as an equality constraint. 166 | max_iter : int, optional 167 | Maximum number of iterations to perform. 
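# Illustrative sketch for minimize_constr() (values are arbitrary): a norm-ball
# constraint plus box bounds. Everything stays in torch tensors; SciPy's
# trust-constr runs underneath, with Jacobians and Hessians built by autograd.
import torch
from torchmin.minimize_constr import minimize_constr

def f(x):
    return ((x - 2.0) ** 2).sum()

res = minimize_constr(
    f, torch.zeros(3),
    constr=dict(fun=lambda x: x.square().sum(), ub=1.0),   # ||x||^2 <= 1
    bounds=dict(lb=0.0, ub=5.0),                           # 0 <= x_i <= 5
    max_iter=200,
)
print(res.x, res.fun)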
If unspecified, this will 168 | be set to the default of the selected method. 169 | tol : float, optional 170 | Tolerance for termination. For detailed control, use solver-specific 171 | options. 172 | callback : callable, optional 173 | Function to call after each iteration with the current parameter 174 | state, e.g. ``callback(x)``. 175 | disp : int 176 | Level of algorithm's verbosity: 177 | 178 | * 0 : work silently (default). 179 | * 1 : display a termination report. 180 | * 2 : display progress during iterations. 181 | * 3 : display progress during iterations (more complete report). 182 | **kwargs 183 | Additional keyword arguments passed to SciPy's trust-constr solver. 184 | See options `here `_. 185 | 186 | Returns 187 | ------- 188 | result : OptimizeResult 189 | Result of the optimization routine. 190 | 191 | """ 192 | if max_iter is None: 193 | max_iter = 1000 194 | x0 = x0.detach() 195 | if x0.is_cuda: 196 | warnings.warn('GPU is not recommended for trust-constr. ' 197 | 'Data will be moved back-and-forth from CPU.') 198 | 199 | # handle callbacks 200 | if callback is not None: 201 | callback_ = callback 202 | callback = lambda x, state: callback_( 203 | torch.tensor(x, dtype=x0.dtype, device=x0.device).view_as(x0), state) 204 | 205 | # handle bounds 206 | if bounds is not None: 207 | bounds = _build_bounds(bounds, x0) 208 | 209 | # build objective function (and hessian) 210 | f_with_jac, f_hess = _build_obj(f, x0) 211 | 212 | # build constraints 213 | if constr is not None: 214 | constraints = [_build_constr(constr, x0)] 215 | else: 216 | constraints = [] 217 | 218 | # optimize 219 | x0_np = x0.cpu().numpy().flatten().copy() 220 | result = minimize( 221 | f_with_jac, x0_np, method='trust-constr', jac=True, 222 | hess=f_hess, callback=callback, tol=tol, 223 | bounds=bounds, 224 | constraints=constraints, 225 | options=dict(verbose=int(disp), maxiter=max_iter, **kwargs) 226 | ) 227 | 228 | # convert the important things to torch tensors 229 | for key in ['fun', 'grad', 'x']: 230 | result[key] = torch.tensor(result[key], dtype=x0.dtype, device=x0.device) 231 | result['x'] = result['x'].view_as(x0) 232 | 233 | return result 234 | 235 | -------------------------------------------------------------------------------- /torchmin/newton.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import OptimizeResult 2 | from scipy.sparse.linalg import eigsh 3 | from torch import Tensor 4 | import torch 5 | 6 | from .function import ScalarFunction 7 | from .line_search import strong_wolfe 8 | 9 | try: 10 | from scipy.optimize.optimize import _status_message 11 | except ImportError: 12 | from scipy.optimize._optimize import _status_message 13 | 14 | _status_message['cg_warn'] = "Warning: CG iterations didn't converge. The " \ 15 | "Hessian is not positive definite." 16 | 17 | 18 | def _cg_iters(grad, hess, max_iter, normp=1): 19 | """A CG solver specialized for the NewtonCG sub-problem. 20 | 21 | Derived from Algorithm 7.1 of "Numerical Optimization (2nd Ed.)" 22 | (Nocedal & Wright, 2006; pp. 
169) 23 | """ 24 | # Get the most efficient dot product method for this problem 25 | if grad.dim() == 1: 26 | # standard dot product 27 | dot = torch.dot 28 | elif grad.dim() == 2: 29 | # batched dot product 30 | dot = lambda u,v: torch.bmm(u.unsqueeze(1), v.unsqueeze(2)).view(-1,1) 31 | else: 32 | # generalized dot product that supports batch inputs 33 | dot = lambda u,v: u.mul(v).sum(-1, keepdim=True) 34 | 35 | g_norm = grad.norm(p=normp) 36 | tol = g_norm * g_norm.sqrt().clamp(0, 0.5) 37 | eps = torch.finfo(grad.dtype).eps 38 | n_iter = 0 # TODO: remove? 39 | maxiter_reached = False 40 | 41 | # initialize state and iterate 42 | x = torch.zeros_like(grad) 43 | r = grad.clone() 44 | p = grad.neg() 45 | rs = dot(r, r) 46 | for n_iter in range(max_iter): 47 | if r.norm(p=normp) < tol: 48 | break 49 | Bp = hess.mv(p) 50 | curv = dot(p, Bp) 51 | curv_sum = curv.sum() 52 | if curv_sum < 0: 53 | # hessian is not positive-definite 54 | if n_iter == 0: 55 | # if first step, fall back to steepest descent direction 56 | # (scaled by Rayleigh quotient) 57 | x = grad.mul(rs / curv) 58 | #x = grad.neg() 59 | break 60 | elif curv_sum <= 3 * eps: 61 | break 62 | alpha = rs / curv 63 | x.addcmul_(alpha, p) 64 | r.addcmul_(alpha, Bp) 65 | rs_new = dot(r, r) 66 | p.mul_(rs_new / rs).sub_(r) 67 | rs = rs_new 68 | else: 69 | # curvature keeps increasing; bail 70 | maxiter_reached = True 71 | 72 | return x, n_iter, maxiter_reached 73 | 74 | 75 | @torch.no_grad() 76 | def _minimize_newton_cg( 77 | fun, x0, lr=1., max_iter=None, cg_max_iter=None, 78 | twice_diffable=True, line_search='strong-wolfe', xtol=1e-5, 79 | normp=1, callback=None, disp=0, return_all=False): 80 | """Minimize a scalar function of one or more variables using the 81 | Newton-Raphson method, with Conjugate Gradient for the linear inverse 82 | sub-problem. 83 | 84 | Parameters 85 | ---------- 86 | fun : callable 87 | Scalar objective function to minimize. 88 | x0 : Tensor 89 | Initialization point. 90 | lr : float 91 | Step size for parameter updates. If using line search, this will be 92 | used as the initial step size for the search. 93 | max_iter : int, optional 94 | Maximum number of iterations to perform. Defaults to 95 | ``200 * x0.numel()``. 96 | cg_max_iter : int, optional 97 | Maximum number of iterations for CG subproblem. Recommended to 98 | leave this at the default of ``20 * x0.numel()``. 99 | twice_diffable : bool 100 | Whether to assume the function is twice continuously differentiable. 101 | If True, hessian-vector products will be much faster. 102 | line_search : str 103 | Line search specifier. Currently the available options are 104 | {'none', 'strong_wolfe'}. 105 | xtol : float 106 | Average relative error in solution `xopt` acceptable for 107 | convergence. 108 | normp : Number or str 109 | The norm type to use for termination conditions. Can be any value 110 | supported by :func:`torch.norm`. 111 | callback : callable, optional 112 | Function to call after each iteration with the current parameter 113 | state, e.g. ``callback(x)``. 114 | disp : int or bool 115 | Display (verbosity) level. Set to >0 to print status messages. 116 | return_all : bool 117 | Set to True to return a list of the best solution at each of the 118 | iterations. 119 | 120 | Returns 121 | ------- 122 | result : OptimizeResult 123 | Result of the optimization routine. 
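# Sketch: the solver-specific options documented above are reached by passing
# an `options` dict to minimize(), which forwards them as keyword arguments
# (the objective here is made up).
import torch
from torchmin import minimize

def fun(x):
    return torch.log1p(x ** 2).sum() + ((x - 3.0) ** 2).sum()

res = minimize(fun, torch.zeros(5), method='newton-cg',
               options=dict(line_search='strong-wolfe', cg_max_iter=50,
                            xtol=1e-8, twice_diffable=True))
print(res.nit, res.ncg)   # outer Newton iterations and cumulative CG iterations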
124 | """ 125 | lr = float(lr) 126 | disp = int(disp) 127 | xtol = x0.numel() * xtol 128 | if max_iter is None: 129 | max_iter = x0.numel() * 200 130 | if cg_max_iter is None: 131 | cg_max_iter = x0.numel() * 20 132 | 133 | # construct scalar objective function 134 | sf = ScalarFunction(fun, x0.shape, hessp=True, twice_diffable=twice_diffable) 135 | closure = sf.closure 136 | if line_search == 'strong-wolfe': 137 | dir_evaluate = sf.dir_evaluate 138 | 139 | # initial settings 140 | x = x0.detach().clone(memory_format=torch.contiguous_format) 141 | f, g, hessp, _ = closure(x) 142 | if disp > 1: 143 | print('initial fval: %0.4f' % f) 144 | if return_all: 145 | allvecs = [x] 146 | ncg = 0 # number of cg iterations 147 | n_iter = 0 148 | 149 | # begin optimization loop 150 | for n_iter in range(1, max_iter + 1): 151 | 152 | # ============================================================ 153 | # Compute a search direction pk by applying the CG method to 154 | # H_f(xk) p = - J_f(xk) starting from 0. 155 | # ============================================================ 156 | 157 | # Compute search direction with conjugate gradient (GG) 158 | d, cg_iters, cg_fail = _cg_iters(g, hessp, cg_max_iter, normp) 159 | ncg += cg_iters 160 | 161 | if cg_fail: 162 | warnflag = 3 163 | msg = _status_message['cg_warn'] 164 | break 165 | 166 | # ===================================================== 167 | # Perform variable update (with optional line search) 168 | # ===================================================== 169 | 170 | if line_search == 'none': 171 | update = d.mul(lr) 172 | x = x + update 173 | elif line_search == 'strong-wolfe': 174 | # strong-wolfe line search 175 | _, _, t, ls_nevals = strong_wolfe(dir_evaluate, x, lr, d, f, g) 176 | update = d.mul(t) 177 | x = x + update 178 | else: 179 | raise ValueError('invalid line_search option {}.'.format(line_search)) 180 | 181 | # re-evaluate function 182 | f, g, hessp, _ = closure(x) 183 | 184 | if disp > 1: 185 | print('iter %3d - fval: %0.4f' % (n_iter, f)) 186 | if callback is not None: 187 | callback(x) 188 | if return_all: 189 | allvecs.append(x) 190 | 191 | # ========================== 192 | # check for convergence 193 | # ========================== 194 | 195 | if update.norm(p=normp) <= xtol: 196 | warnflag = 0 197 | msg = _status_message['success'] 198 | break 199 | 200 | if not f.isfinite(): 201 | warnflag = 3 202 | msg = _status_message['nan'] 203 | break 204 | 205 | else: 206 | # if we get to the end, the maximum num. iterations was reached 207 | warnflag = 1 208 | msg = _status_message['maxiter'] 209 | 210 | if disp: 211 | print(msg) 212 | print(" Current function value: %f" % f) 213 | print(" Iterations: %d" % n_iter) 214 | print(" Function evaluations: %d" % sf.nfev) 215 | print(" CG iterations: %d" % ncg) 216 | result = OptimizeResult(fun=f, x=x.view_as(x0), grad=g.view_as(x0), 217 | status=warnflag, success=(warnflag==0), 218 | message=msg, nit=n_iter, nfev=sf.nfev, ncg=ncg) 219 | if return_all: 220 | result['allvecs'] = allvecs 221 | return result 222 | 223 | 224 | 225 | @torch.no_grad() 226 | def _minimize_newton_exact( 227 | fun, x0, lr=1., max_iter=None, line_search='strong-wolfe', xtol=1e-5, 228 | normp=1, tikhonov=0., handle_npd='grad', callback=None, disp=0, 229 | return_all=False): 230 | """Minimize a scalar function of one or more variables using the 231 | Newton-Raphson method. 232 | 233 | This variant uses an "exact" Newton routine based on Cholesky factorization 234 | of the explicit Hessian matrix. 
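# Sketch: selecting this exact-Newton routine through minimize() and forwarding
# the `tikhonov` and `handle_npd` options described below. The objective is a
# made-up non-convex function whose Hessian is indefinite at the starting
# point, so the `handle_npd` fallback is actually exercised.
import torch
from torchmin import minimize

def fun(x):
    return torch.cos(x).sum() + 0.1 * (x ** 2).sum()

res = minimize(fun, torch.ones(3), method='newton-exact',
               options=dict(tikhonov=1e-6, handle_npd='cauchy'))
print(res.x, res.nfail)   # nfail counts iterations with a non-PD Hessian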
235 | 236 | Parameters 237 | ---------- 238 | fun : callable 239 | Scalar objective function to minimize. 240 | x0 : Tensor 241 | Initialization point. 242 | lr : float 243 | Step size for parameter updates. If using line search, this will be 244 | used as the initial step size for the search. 245 | max_iter : int, optional 246 | Maximum number of iterations to perform. Defaults to 247 | ``200 * x0.numel()``. 248 | line_search : str 249 | Line search specifier. Currently the available options are 250 | {'none', 'strong_wolfe'}. 251 | xtol : float 252 | Average relative error in solution `xopt` acceptable for 253 | convergence. 254 | normp : Number or str 255 | The norm type to use for termination conditions. Can be any value 256 | supported by :func:`torch.norm`. 257 | tikhonov : float 258 | Optional diagonal regularization (Tikhonov) parameter for the Hessian. 259 | handle_npd : str 260 | Mode for handling non-positive definite hessian matrices. Can be one 261 | of the following: 262 | 263 | * 'grad' : use steepest descent direction (gradient) 264 | * 'lu' : solve the inverse hessian with LU factorization 265 | * 'eig' : use symmetric eigendecomposition to determine a 266 | diagonal regularization parameter 267 | callback : callable, optional 268 | Function to call after each iteration with the current parameter 269 | state, e.g. ``callback(x)``. 270 | disp : int or bool 271 | Display (verbosity) level. Set to >0 to print status messages. 272 | return_all : bool 273 | Set to True to return a list of the best solution at each of the 274 | iterations. 275 | 276 | Returns 277 | ------- 278 | result : OptimizeResult 279 | Result of the optimization routine. 280 | """ 281 | lr = float(lr) 282 | disp = int(disp) 283 | xtol = x0.numel() * xtol 284 | if max_iter is None: 285 | max_iter = x0.numel() * 200 286 | 287 | # Construct scalar objective function 288 | sf = ScalarFunction(fun, x0.shape, hess=True) 289 | closure = sf.closure 290 | if line_search == 'strong-wolfe': 291 | dir_evaluate = sf.dir_evaluate 292 | 293 | # initial settings 294 | x = x0.detach().view(-1).clone(memory_format=torch.contiguous_format) 295 | f, g, _, hess = closure(x) 296 | if tikhonov > 0: 297 | hess.diagonal().add_(tikhonov) 298 | if disp > 1: 299 | print('initial fval: %0.4f' % f) 300 | if return_all: 301 | allvecs = [x] 302 | nfail = 0 303 | n_iter = 0 304 | 305 | # begin optimization loop 306 | for n_iter in range(1, max_iter + 1): 307 | 308 | # ================================================== 309 | # Compute a search direction d by solving 310 | # H_f(x) d = - J_f(x) 311 | # with the true Hessian and Cholesky factorization 312 | # =================================================== 313 | 314 | # Compute search direction with Cholesky solve 315 | L, info = torch.linalg.cholesky_ex(hess) 316 | 317 | if info == 0: 318 | d = torch.cholesky_solve(g.neg().unsqueeze(1), L).squeeze(1) 319 | else: 320 | nfail += 1 321 | if handle_npd == 'lu': 322 | d = torch.linalg.solve(hess, g.neg()) 323 | elif handle_npd in ['grad', 'cauchy']: 324 | d = g.neg() 325 | if handle_npd == 'cauchy': 326 | # cauchy point for a trust radius of delta=1. 327 | # equivalent to 'grad' with a scaled lr 328 | gnorm = g.norm(p=2) 329 | scale = 1 / gnorm 330 | gHg = g.dot(hess.mv(g)) 331 | if gHg > 0: 332 | scale *= torch.clamp_(gnorm.pow(3) / gHg, max=1) 333 | d *= scale 334 | elif handle_npd == 'eig': 335 | # this setting is experimental! use with caution 336 | # TODO: why use the factor 1.5 here? 
Seems to work best 337 | eig0 = eigsh(hess.cpu().numpy(), k=1, which="SA", tol=1e-4)[0].item() 338 | tau = max(1e-3 - 1.5 * eig0, 0) 339 | hess.diagonal().add_(tau) 340 | L = torch.linalg.cholesky(hess) 341 | d = torch.cholesky_solve(g.neg().unsqueeze(1), L).squeeze(1) 342 | else: 343 | raise RuntimeError('invalid handle_npd encountered.') 344 | 345 | 346 | # ===================================================== 347 | # Perform variable update (with optional line search) 348 | # ===================================================== 349 | 350 | if line_search == 'none': 351 | update = d.mul(lr) 352 | x = x + update 353 | elif line_search == 'strong-wolfe': 354 | # strong-wolfe line search 355 | _, _, t, ls_nevals = strong_wolfe(dir_evaluate, x, lr, d, f, g) 356 | update = d.mul(t) 357 | x = x + update 358 | else: 359 | raise ValueError('invalid line_search option {}.'.format(line_search)) 360 | 361 | # =================================== 362 | # Re-evaluate func/Jacobian/Hessian 363 | # =================================== 364 | 365 | f, g, _, hess = closure(x) 366 | if tikhonov > 0: 367 | hess.diagonal().add_(tikhonov) 368 | 369 | if disp > 1: 370 | print('iter %3d - fval: %0.4f - info: %d' % (n_iter, f, info)) 371 | if callback is not None: 372 | callback(x) 373 | if return_all: 374 | allvecs.append(x) 375 | 376 | # ========================== 377 | # check for convergence 378 | # ========================== 379 | 380 | if update.norm(p=normp) <= xtol: 381 | warnflag = 0 382 | msg = _status_message['success'] 383 | break 384 | 385 | if not f.isfinite(): 386 | warnflag = 3 387 | msg = _status_message['nan'] 388 | break 389 | 390 | else: 391 | # if we get to the end, the maximum num. iterations was reached 392 | warnflag = 1 393 | msg = _status_message['maxiter'] 394 | 395 | if disp: 396 | print(msg) 397 | print(" Current function value: %f" % f) 398 | print(" Iterations: %d" % n_iter) 399 | print(" Function evaluations: %d" % sf.nfev) 400 | result = OptimizeResult(fun=f, x=x.view_as(x0), grad=g.view_as(x0), 401 | hess=hess.view(2 * x0.shape), 402 | status=warnflag, success=(warnflag==0), 403 | message=msg, nit=n_iter, nfev=sf.nfev, nfail=nfail) 404 | if return_all: 405 | result['allvecs'] = allvecs 406 | return result 407 | -------------------------------------------------------------------------------- /torchmin/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .minimizer import Minimizer 2 | from .scipy_minimizer import ScipyMinimizer -------------------------------------------------------------------------------- /torchmin/optim/minimizer.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import torch 3 | from torch.optim import Optimizer 4 | 5 | 6 | class LinearOperator: 7 | """A generic linear operator to use with Minimizer""" 8 | def __init__(self, matvec, shape, dtype=torch.float, device=None): 9 | self.rmv = matvec 10 | self.mv = matvec 11 | self.shape = shape 12 | self.dtype = dtype 13 | self.device = device 14 | 15 | 16 | class Minimizer(Optimizer): 17 | """A general-purpose PyTorch optimizer for unconstrained function 18 | minimization. 19 | 20 | .. warning:: 21 | This optimizer doesn't support per-parameter options and parameter 22 | groups (there can be only one). 23 | 24 | .. warning:: 25 | Right now all parameters have to be on a single device. This will be 26 | improved in the future. 
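# Sketch: full-batch fitting of a tiny linear model with Minimizer (data and
# model are made up). Extra keyword arguments are forwarded to
# torchmin.minimize(); note that the closure must NOT call backward(), as the
# step() docstring further below explains.
import torch
from torchmin.optim import Minimizer

X = torch.randn(64, 3)
y = X @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(64)
model = torch.nn.Linear(3, 1)
opt = Minimizer(model.parameters(), method='l-bfgs', tol=1e-6, max_iter=100)

def closure():
    opt.zero_grad()
    return ((model(X).squeeze(-1) - y) ** 2).mean()

final_loss = opt.step(closure)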
27 | 28 | Parameters 29 | ---------- 30 | params : iterable 31 | An iterable of :class:`torch.Tensor` s. Specifies what Tensors 32 | should be optimized. 33 | method : str 34 | Minimization method (algorithm) to use. Must be one of the methods 35 | offered in :func:`torchmin.minimize()`. Defaults to 'bfgs'. 36 | **minimize_kwargs : dict 37 | Additional keyword arguments that will be passed to 38 | :func:`torchmin.minimize()`. 39 | 40 | """ 41 | def __init__(self, 42 | params, 43 | method='bfgs', 44 | **minimize_kwargs): 45 | assert isinstance(method, str) 46 | method_ = method.lower() 47 | 48 | self._hessp = self._hess = False 49 | if method_ in ['bfgs', 'l-bfgs', 'cg']: 50 | pass 51 | elif method_ in ['newton-cg', 'trust-ncg', 'trust-krylov']: 52 | self._hessp = True 53 | elif method_ in ['newton-exact', 'dogleg', 'trust-exact']: 54 | self._hess = True 55 | else: 56 | raise ValueError('Unknown method {}'.format(method)) 57 | 58 | defaults = dict(method=method_, **minimize_kwargs) 59 | super().__init__(params, defaults) 60 | 61 | if len(self.param_groups) != 1: 62 | raise ValueError("Minimizer doesn't support per-parameter options") 63 | 64 | self._nfev = [0] 65 | self._params = self.param_groups[0]['params'] 66 | self._numel_cache = None 67 | self._closure = None 68 | self._result = None 69 | 70 | @property 71 | def nfev(self): 72 | return self._nfev[0] 73 | 74 | def _numel(self): 75 | if self._numel_cache is None: 76 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) 77 | return self._numel_cache 78 | 79 | def _gather_flat_param(self): 80 | params = [] 81 | for p in self._params: 82 | if p.data.is_sparse: 83 | p = p.data.to_dense().view(-1) 84 | else: 85 | p = p.data.view(-1) 86 | params.append(p) 87 | return torch.cat(params) 88 | 89 | def _gather_flat_grad(self): 90 | grads = [] 91 | for p in self._params: 92 | if p.grad is None: 93 | g = p.new_zeros(p.numel()) 94 | elif p.grad.is_sparse: 95 | g = p.grad.to_dense().view(-1) 96 | else: 97 | g = p.grad.view(-1) 98 | grads.append(g) 99 | return torch.cat(grads) 100 | 101 | def _set_flat_param(self, value): 102 | offset = 0 103 | for p in self._params: 104 | numel = p.numel() 105 | p.copy_(value[offset:offset+numel].view_as(p)) 106 | offset += numel 107 | assert offset == self._numel() 108 | 109 | def closure(self, x): 110 | from torchmin.function import sf_value 111 | 112 | assert self._closure is not None 113 | self._set_flat_param(x) 114 | with torch.enable_grad(): 115 | f = self._closure() 116 | f.backward(create_graph=self._hessp or self._hess) 117 | grad = self._gather_flat_grad() 118 | 119 | grad_out = grad.detach().clone() 120 | hessp = None 121 | hess = None 122 | if self._hessp or self._hess: 123 | grad_accum = grad.detach().clone() 124 | def hvp(v): 125 | assert v.shape == grad.shape 126 | grad.backward(gradient=v, retain_graph=True) 127 | output = self._gather_flat_grad().detach() - grad_accum 128 | grad_accum.add_(output) 129 | return output 130 | 131 | numel = self._numel() 132 | if self._hessp: 133 | hessp = LinearOperator(hvp, shape=(numel, numel), 134 | dtype=grad.dtype, device=grad.device) 135 | if self._hess: 136 | eye = torch.eye(numel, dtype=grad.dtype, device=grad.device) 137 | hess = torch.zeros(numel, numel, dtype=grad.dtype, device=grad.device) 138 | for i in range(numel): 139 | hess[i] = hvp(eye[i]) 140 | 141 | return sf_value(f=f.detach(), grad=grad_out.detach(), hessp=hessp, hess=hess) 142 | 143 | def dir_evaluate(self, x, t, d): 144 | from torchmin.function import de_value 145 | 
146 | self._set_flat_param(x + d.mul(t)) 147 | with torch.enable_grad(): 148 | f = self._closure() 149 | f.backward() 150 | grad = self._gather_flat_grad() 151 | self._set_flat_param(x) 152 | 153 | return de_value(f=float(f), grad=grad) 154 | 155 | @torch.no_grad() 156 | def step(self, closure): 157 | """Perform an optimization step. 158 | 159 | The function "closure" should have a slightly different 160 | form vs. the PyTorch standard: namely, it should not include any 161 | `backward()` calls. Backward steps will be performed internally 162 | by the optimizer. 163 | 164 | >>> def closure(): 165 | >>> optimizer.zero_grad() 166 | >>> output = model(input) 167 | >>> loss = loss_fn(output, target) 168 | >>> # loss.backward() <-- skip this step! 169 | >>> return loss 170 | 171 | Parameters 172 | ---------- 173 | closure : callable 174 | A function that re-evaluates the model and returns the loss. 175 | 176 | """ 177 | from torchmin.minimize import minimize 178 | 179 | # sanity check 180 | assert len(self.param_groups) == 1 181 | 182 | # overwrite closure 183 | closure_ = closure 184 | def closure(): 185 | self._nfev[0] += 1 186 | return closure_() 187 | self._closure = closure 188 | 189 | # get initial value 190 | x0 = self._gather_flat_param() 191 | 192 | # perform parameter update 193 | kwargs = {k:v for k,v in self.param_groups[0].items() if k != 'params'} 194 | self._result = minimize(self, x0, **kwargs) 195 | 196 | # set final value 197 | self._set_flat_param(self._result.x) 198 | 199 | return self._result.fun -------------------------------------------------------------------------------- /torchmin/optim/scipy_minimizer.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import numpy as np 3 | import torch 4 | from functools import reduce 5 | from torch.optim import Optimizer 6 | from scipy import optimize 7 | from torch._vmap_internals import _vmap 8 | from torch.autograd.functional import (_construct_standard_basis_for, 9 | _grad_postprocess, _tuple_postprocess, 10 | _as_tuple) 11 | 12 | 13 | def _build_bounds(bounds, params, numel_total): 14 | if len(bounds) != len(params): 15 | raise ValueError('bounds must be an iterable with same length as params') 16 | 17 | lb = np.full(numel_total, -np.inf) 18 | ub = np.full(numel_total, np.inf) 19 | keep_feasible = np.zeros(numel_total, dtype=np.bool) 20 | 21 | def process_bound(x, numel): 22 | if isinstance(x, torch.Tensor): 23 | assert x.numel() == numel 24 | return x.view(-1).detach().cpu().numpy() 25 | elif isinstance(x, np.ndarray): 26 | assert x.size == numel 27 | return x.flatten() 28 | elif isinstance(x, (bool, numbers.Number)): 29 | return x 30 | else: 31 | raise ValueError('invalid bound value.') 32 | 33 | offset = 0 34 | for bound, p in zip(bounds, params): 35 | numel = p.numel() 36 | if bound is None: 37 | offset += numel 38 | continue 39 | if not isinstance(bound, (list, tuple)) and len(bound) in [2,3]: 40 | raise ValueError('elements of "bounds" must each be a ' 41 | 'list/tuple of length 2 or 3') 42 | if bound[0] is None and bound[1] is None: 43 | raise ValueError('either lower or upper bound must be defined.') 44 | if bound[0] is not None: 45 | lb[offset:offset + numel] = process_bound(bound[0], numel) 46 | if bound[1] is not None: 47 | ub[offset:offset + numel] = process_bound(bound[1], numel) 48 | if len(bound) == 3: 49 | keep_feasible[offset:offset + numel] = process_bound(bound[2], numel) 50 | offset += numel 51 | 52 | return optimize.Bounds(lb, ub, keep_feasible) 53 | 
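# Sketch: ScipyMinimizer (defined below) with per-parameter bounds in the
# format parsed by _build_bounds above: one entry per parameter, each either
# None or (lb, ub[, keep_feasible]). Unlike Minimizer, the closure here follows
# the standard PyTorch convention and calls backward() itself. Note that
# _build_bounds uses np.bool, which NumPy removed in 1.24, so this bounds path
# assumes an older NumPy release.
import torch
from torchmin.optim import ScipyMinimizer

w = torch.nn.Parameter(torch.randn(3))
b = torch.nn.Parameter(torch.zeros(1))
opt = ScipyMinimizer([w, b], method='l-bfgs-b',
                     bounds=[(0.0, 1.0), None])   # clamp w to [0, 1], leave b free

def closure():
    opt.zero_grad()
    loss = (w.sum() - 1.0) ** 2 + (b ** 2).sum()
    loss.backward()
    return loss

opt.step(closure)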
54 | 55 | def _jacobian(inputs, outputs): 56 | """A modified variant of torch.autograd.functional.jacobian for 57 | pre-computed outputs 58 | 59 | This is only used for nonlinear parameter constraints (if provided) 60 | """ 61 | is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian") 62 | is_outputs_tuple, outputs = _as_tuple(outputs, "outputs", "jacobian") 63 | 64 | output_numels = tuple(output.numel() for output in outputs) 65 | grad_outputs = _construct_standard_basis_for(outputs, output_numels) 66 | with torch.enable_grad(): 67 | flat_outputs = tuple(output.reshape(-1) for output in outputs) 68 | 69 | def vjp(grad_output): 70 | vj = list(torch.autograd.grad(flat_outputs, inputs, grad_output, allow_unused=True)) 71 | for el_idx, vj_el in enumerate(vj): 72 | if vj_el is not None: 73 | continue 74 | vj[el_idx] = torch.zeros_like(inputs[el_idx]) 75 | return tuple(vj) 76 | 77 | jacobians_of_flat_output = _vmap(vjp)(grad_outputs) 78 | 79 | jacobian_input_output = [] 80 | for jac, input_i in zip(jacobians_of_flat_output, inputs): 81 | jacobian_input_i_output = [] 82 | for jac, output_j in zip(jac.split(output_numels, dim=0), outputs): 83 | jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape) 84 | jacobian_input_i_output.append(jacobian_input_i_output_j) 85 | jacobian_input_output.append(jacobian_input_i_output) 86 | 87 | jacobian_output_input = tuple(zip(*jacobian_input_output)) 88 | 89 | jacobian_output_input = _grad_postprocess(jacobian_output_input, create_graph=False) 90 | return _tuple_postprocess(jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)) 91 | 92 | 93 | class ScipyMinimizer(Optimizer): 94 | """A PyTorch optimizer for constrained & unconstrained function 95 | minimization. 96 | 97 | .. note:: 98 | This optimizer is a wrapper for :func:`scipy.optimize.minimize`. 99 | It uses autograd behind the scenes to build jacobian & hessian 100 | callables before invoking scipy. Inputs and objectivs should use 101 | PyTorch tensors like other routines. CUDA is supported; however, 102 | data will be transferred back-and-forth between GPU/CPU. 103 | 104 | .. warning:: 105 | This optimizer doesn't support per-parameter options and parameter 106 | groups (there can be only one). 107 | 108 | .. warning:: 109 | Right now all parameters have to be on a single device. This will be 110 | improved in the future. 111 | 112 | Parameters 113 | ---------- 114 | params : iterable 115 | An iterable of :class:`torch.Tensor` s. Specifies what Tensors 116 | should be optimized. 117 | method : str 118 | One of the various optimization methods offered in scipy minimize. 119 | Defaults to 'bfgs'. 120 | bounds : iterable, optional 121 | An iterable of :class:`torch.Tensor` s or :class:`float` s with same 122 | length as `params`. Specifies boundaries for each parameter. 123 | constraints : dict, optional 124 | TODO 125 | tol : float, optional 126 | TODO 127 | options : dict, optional 128 | TODO 129 | 130 | """ 131 | def __init__(self, 132 | params, 133 | method='bfgs', 134 | bounds=None, 135 | constraints=(), # experimental feature! 
use with caution 136 | tol=None, 137 | options=None): 138 | assert isinstance(method, str) 139 | method = method.lower() 140 | defaults = dict( 141 | method=method, 142 | bounds=bounds, 143 | constraints=constraints, 144 | tol=tol, 145 | options=options) 146 | super().__init__(params, defaults) 147 | 148 | if len(self.param_groups) != 1: 149 | raise ValueError("Minimize doesn't support per-parameter options " 150 | "(parameter groups)") 151 | if constraints != () and method != 'trust-constr': 152 | raise NotImplementedError("Constraints only currently supported for " 153 | "method='trust-constr'.") 154 | 155 | self._params = self.param_groups[0]['params'] 156 | self._param_bounds = self.param_groups[0]['bounds'] 157 | self._numel_cache = None 158 | self._bounds_cache = None 159 | self._result = None 160 | 161 | def _numel(self): 162 | if self._numel_cache is None: 163 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) 164 | return self._numel_cache 165 | 166 | def _bounds(self): 167 | if self._param_bounds is None: 168 | return None 169 | if self._bounds_cache is None: 170 | self._bounds_cache = _build_bounds(self._param_bounds, self._params, 171 | self._numel()) 172 | return self._bounds_cache 173 | 174 | def _gather_flat_param(self): 175 | views = [] 176 | for p in self._params: 177 | if p.data.is_sparse: 178 | view = p.data.to_dense().view(-1) 179 | else: 180 | view = p.data.view(-1) 181 | views.append(view) 182 | return torch.cat(views, 0) 183 | 184 | def _gather_flat_grad(self): 185 | views = [] 186 | for p in self._params: 187 | if p.grad is None: 188 | view = p.new_zeros(p.numel()) 189 | elif p.grad.is_sparse: 190 | view = p.grad.to_dense().view(-1) 191 | else: 192 | view = p.grad.view(-1) 193 | views.append(view) 194 | return torch.cat(views, 0) 195 | 196 | def _set_flat_param(self, value): 197 | offset = 0 198 | for p in self._params: 199 | numel = p.numel() 200 | # view as to avoid deprecated pointwise semantics 201 | p.copy_(value[offset:offset + numel].view_as(p)) 202 | offset += numel 203 | assert offset == self._numel() 204 | 205 | def _build_constraints(self, constraints): 206 | assert isinstance(constraints, dict) 207 | assert 'fun' in constraints 208 | assert 'lb' in constraints or 'ub' in constraints 209 | 210 | to_tensor = lambda x: self._params[0].new_tensor(x) 211 | to_array = lambda x: x.cpu().numpy() 212 | fun_ = constraints['fun'] 213 | lb = constraints.get('lb', -np.inf) 214 | ub = constraints.get('ub', np.inf) 215 | strict = constraints.get('keep_feasible', False) 216 | lb = to_array(lb) if torch.is_tensor(lb) else lb 217 | ub = to_array(ub) if torch.is_tensor(ub) else ub 218 | strict = to_array(strict) if torch.is_tensor(strict) else strict 219 | 220 | def fun(x): 221 | self._set_flat_param(to_tensor(x)) 222 | return to_array(fun_()) 223 | 224 | def jac(x): 225 | self._set_flat_param(to_tensor(x)) 226 | with torch.enable_grad(): 227 | output = fun_() 228 | 229 | # this is now a tuple of tensors, one per parameter, each with 230 | # shape (num_outputs, *param_shape). 231 | J_seq = _jacobian(inputs=tuple(self._params), outputs=output) 232 | 233 | # flatten and stack the tensors along dim 1 to get our full matrix 234 | J = torch.cat([elt.view(output.numel(), -1) for elt in J_seq], 1) 235 | 236 | return to_array(J) 237 | 238 | return optimize.NonlinearConstraint(fun, lb, ub, jac=jac, keep_feasible=strict) 239 | 240 | @torch.no_grad() 241 | def step(self, closure): 242 | """Perform an optimization step. 
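# Sketch of the experimental `constraints` option (method must be
# 'trust-constr'). Note that, unlike minimize_constr, the constraint `fun`
# here takes no arguments: it is evaluated on the current parameter values,
# as _build_constraints above shows. The problem below is made up.
import torch
from torchmin.optim import ScipyMinimizer

p = torch.nn.Parameter(torch.tensor([0.5, 0.5]))
opt = ScipyMinimizer([p], method='trust-constr',
                     constraints=dict(fun=lambda: p.square().sum(), ub=1.0))

def closure():
    opt.zero_grad()
    loss = ((p - 2.0) ** 2).sum()
    loss.backward()
    return loss

opt.step(closure)   # p should end up near the unit circle, roughly (0.707, 0.707)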
243 | 244 | Parameters 245 | ---------- 246 | closure : callable 247 | A function that re-evaluates the model and returns the loss. 248 | See the `closure instructions 249 | `_ 250 | from PyTorch Optimizer docs for areference on how to construct 251 | this callable. 252 | """ 253 | # sanity check 254 | assert len(self.param_groups) == 1 255 | 256 | # functions to convert numpy -> torch and torch -> numpy 257 | to_tensor = lambda x: self._params[0].new_tensor(x) 258 | to_array = lambda x: x.cpu().numpy() 259 | 260 | # optimizer settings 261 | group = self.param_groups[0] 262 | method = group['method'] 263 | bounds = self._bounds() 264 | constraints = group['constraints'] 265 | tol = group['tol'] 266 | options = group['options'] 267 | 268 | # build constraints (if provided) 269 | if constraints != (): 270 | constraints = self._build_constraints(constraints) 271 | 272 | # build objective 273 | def fun(x): 274 | x = to_tensor(x) 275 | self._set_flat_param(x) 276 | with torch.enable_grad(): 277 | loss = closure() 278 | grad = self._gather_flat_grad() 279 | return float(loss), to_array(grad) 280 | 281 | # initial value (numpy array) 282 | x0 = to_array(self._gather_flat_param()) 283 | 284 | # optimize 285 | self._result = optimize.minimize( 286 | fun, x0, method=method, jac=True, bounds=bounds, 287 | constraints=constraints, tol=tol, options=options 288 | ) 289 | 290 | # set final param 291 | self._set_flat_param(to_tensor(self._result.x)) 292 | 293 | return to_tensor(self._result.fun) 294 | -------------------------------------------------------------------------------- /torchmin/trustregion/__init__.py: -------------------------------------------------------------------------------- 1 | from .ncg import _minimize_trust_ncg 2 | from .exact import _minimize_trust_exact 3 | from .dogleg import _minimize_dogleg 4 | from .krylov import _minimize_trust_krylov -------------------------------------------------------------------------------- /torchmin/trustregion/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trust-region optimization. 3 | 4 | Code ported from SciPy to PyTorch 5 | 6 | Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. 7 | All rights reserved. 8 | """ 9 | from abc import ABC, abstractmethod 10 | import torch 11 | from torch.linalg import norm 12 | from scipy.optimize import OptimizeResult 13 | 14 | from ..function import ScalarFunction 15 | from ..optim.minimizer import Minimizer 16 | 17 | try: 18 | from scipy.optimize.optimize import _status_message 19 | except ImportError: 20 | from scipy.optimize._optimize import _status_message 21 | 22 | status_messages = ( 23 | _status_message['success'], 24 | _status_message['maxiter'], 25 | 'A bad approximation caused failure to predict improvement.', 26 | 'A linalg error occurred, such as a non-psd Hessian.', 27 | ) 28 | 29 | 30 | class BaseQuadraticSubproblem(ABC): 31 | """ 32 | Base/abstract class defining the quadratic model for trust-region 33 | minimization. Child classes must implement the ``solve`` method and 34 | ``hess_prod`` property. 
35 | """ 36 | def __init__(self, x, closure): 37 | # evaluate closure 38 | f, g, hessp, hess = closure(x) 39 | 40 | self._x = x 41 | self._f = f 42 | self._g = g 43 | self._h = hessp if self.hess_prod else hess 44 | self._g_mag = None 45 | self._cauchy_point = None 46 | self._newton_point = None 47 | 48 | # buffer for boundaries computation 49 | self._tab = x.new_empty(2) 50 | 51 | def __call__(self, p): 52 | return self.fun + self.jac.dot(p) + 0.5 * p.dot(self.hessp(p)) 53 | 54 | @property 55 | def fun(self): 56 | """Value of objective function at current iteration.""" 57 | return self._f 58 | 59 | @property 60 | def jac(self): 61 | """Value of Jacobian of objective function at current iteration.""" 62 | return self._g 63 | 64 | @property 65 | def hess(self): 66 | """Value of Hessian of objective function at current iteration.""" 67 | if self.hess_prod: 68 | raise Exception('class {} does not have ' 69 | 'method `hess`'.format(type(self))) 70 | return self._h 71 | 72 | def hessp(self, p): 73 | """Value of Hessian-vector product at current iteration for a 74 | particular vector ``p``. 75 | 76 | Note: ``self._h`` is either a Tensor or a LinearOperator. In either 77 | case, it has a method ``mv()``. 78 | """ 79 | return self._h.mv(p) 80 | 81 | @property 82 | def jac_mag(self): 83 | """Magnitude of jacobian of objective function at current iteration.""" 84 | if self._g_mag is None: 85 | self._g_mag = norm(self.jac) 86 | return self._g_mag 87 | 88 | def get_boundaries_intersections(self, z, d, trust_radius): 89 | """ 90 | Solve the scalar quadratic equation ||z + t d|| == trust_radius. 91 | This is like a line-sphere intersection. 92 | Return the two values of t, sorted from low to high. 93 | """ 94 | a = d.dot(d) 95 | b = 2 * z.dot(d) 96 | c = z.dot(z) - trust_radius**2 97 | sqrt_discriminant = torch.sqrt(b*b - 4*a*c) 98 | 99 | # The following calculation is mathematically equivalent to: 100 | # ta = (-b - sqrt_discriminant) / (2*a) 101 | # tb = (-b + sqrt_discriminant) / (2*a) 102 | # but produces smaller round off errors. 103 | aux = b + torch.copysign(sqrt_discriminant, b) 104 | self._tab[0] = -aux / (2*a) 105 | self._tab[1] = -2*c / aux 106 | return self._tab.sort()[0] 107 | 108 | @abstractmethod 109 | def solve(self, trust_radius): 110 | pass 111 | 112 | @property 113 | @abstractmethod 114 | def hess_prod(self): 115 | """A property that must be set by every sub-class indicating whether 116 | to use full hessian matrix or hessian-vector products.""" 117 | pass 118 | 119 | 120 | def _minimize_trust_region(fun, x0, subproblem=None, initial_trust_radius=1., 121 | max_trust_radius=1000., eta=0.15, gtol=1e-4, 122 | max_iter=None, disp=False, return_all=False, 123 | callback=None): 124 | """ 125 | Minimization of scalar function of one or more variables using a 126 | trust-region algorithm. 127 | 128 | Options for the trust-region algorithm are: 129 | initial_trust_radius : float 130 | Initial trust radius. 131 | max_trust_radius : float 132 | Never propose steps that are longer than this value. 133 | eta : float 134 | Trust region related acceptance stringency for proposed steps. 135 | gtol : float 136 | Gradient norm must be less than `gtol` 137 | before successful termination. 138 | max_iter : int 139 | Maximum number of iterations to perform. 140 | disp : bool 141 | If True, print convergence message. 142 | 143 | This function is called by :func:`torchmin.minimize`. 144 | It is not supposed to be called directly. 
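Each proposed step ``p`` is judged by the ratio of actual to predicted reduction,

    .. math::
        \rho = \frac{f(x) - f(x + p)}{m(0) - m(p)}

    where :math:`m` is the quadratic model of the subproblem. A step is accepted when :math:`\rho > \eta` (``eta``); the trust radius is shrunk by a factor of 4 when :math:`\rho < 1/4` and doubled (capped at ``max_trust_radius``) when :math:`\rho > 3/4` and the step reaches the region boundary.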
145 | """ 146 | if subproblem is None: 147 | raise ValueError('A subproblem solving strategy is required for ' 148 | 'trust-region methods') 149 | if not (0 <= eta < 0.25): 150 | raise Exception('invalid acceptance stringency') 151 | if max_trust_radius <= 0: 152 | raise Exception('the max trust radius must be positive') 153 | if initial_trust_radius <= 0: 154 | raise ValueError('the initial trust radius must be positive') 155 | if initial_trust_radius >= max_trust_radius: 156 | raise ValueError('the initial trust radius must be less than the ' 157 | 'max trust radius') 158 | 159 | # Input check/pre-process 160 | disp = int(disp) 161 | if max_iter is None: 162 | max_iter = x0.numel() * 200 163 | 164 | # Construct scalar objective function 165 | hessp = subproblem.hess_prod 166 | sf = ScalarFunction(fun, x0.shape, hessp=hessp, hess=not hessp) 167 | closure = sf.closure 168 | 169 | # init the search status 170 | warnflag = 1 # maximum iterations flag 171 | k = 0 172 | 173 | # initialize the search 174 | trust_radius = torch.as_tensor(initial_trust_radius, 175 | dtype=x0.dtype, device=x0.device) 176 | x = x0.detach().flatten() 177 | if return_all: 178 | allvecs = [x] 179 | 180 | # initial subproblem 181 | m = subproblem(x, closure) 182 | 183 | # search for the function min 184 | # do not even start if the gradient is small enough 185 | while k < max_iter: 186 | 187 | # Solve the sub-problem. 188 | # This gives us the proposed step relative to the current position 189 | # and it tells us whether the proposed step 190 | # has reached the trust region boundary or not. 191 | try: 192 | p, hits_boundary = m.solve(trust_radius) 193 | except RuntimeError as exc: 194 | # TODO: catch general linalg error like np.linalg.linalg.LinAlgError 195 | if 'singular' in exc.args[0]: 196 | warnflag = 3 197 | break 198 | else: 199 | raise 200 | 201 | # calculate the predicted value at the proposed point 202 | predicted_value = m(p) 203 | 204 | # define the local approximation at the proposed point 205 | x_proposed = x + p 206 | m_proposed = subproblem(x_proposed, closure) 207 | 208 | # evaluate the ratio defined in equation (4.4) 209 | actual_reduction = m.fun - m_proposed.fun 210 | predicted_reduction = m.fun - predicted_value 211 | if predicted_reduction <= 0: 212 | warnflag = 2 213 | break 214 | rho = actual_reduction / predicted_reduction 215 | 216 | # update the trust radius according to the actual/predicted ratio 217 | if rho < 0.25: 218 | trust_radius = trust_radius.mul(0.25) 219 | elif rho > 0.75 and hits_boundary: 220 | trust_radius = torch.clamp(2*trust_radius, max=max_trust_radius) 221 | 222 | # if the ratio is high enough then accept the proposed step 223 | if rho > eta: 224 | x = x_proposed 225 | m = m_proposed 226 | elif isinstance(sf, Minimizer): 227 | # if we are using a Minimizer as our ScalarFunction then we 228 | # need to re-compute the previous state because it was 229 | # overwritten during the call `subproblem(x_proposed, closure)` 230 | m = subproblem(x, closure) 231 | 232 | # append the best guess, call back, increment the iteration count 233 | if return_all: 234 | allvecs.append(x.clone()) 235 | if callback is not None: 236 | callback(x.clone()) 237 | k += 1 238 | 239 | # verbosity check 240 | if disp > 1: 241 | print('iter %d - fval: %0.4f' % (k, m.fun)) 242 | 243 | # check if the gradient is small enough to stop 244 | if m.jac_mag < gtol: 245 | warnflag = 0 246 | break 247 | 248 | # print some stuff if requested 249 | if disp: 250 | msg = status_messages[warnflag] 251 | if warnflag != 
0: 252 | msg = 'Warning: ' + msg 253 | print(msg) 254 | print(" Current function value: %f" % m.fun) 255 | print(" Iterations: %d" % k) 256 | print(" Function evaluations: %d" % sf.nfev) 257 | # print(" Gradient evaluations: %d" % sf.ngev) 258 | # print(" Hessian evaluations: %d" % (sf.nhev + nhessp[0])) 259 | 260 | result = OptimizeResult(x=x.view_as(x0), fun=m.fun, grad=m.jac.view_as(x0), 261 | success=(warnflag == 0), status=warnflag, 262 | nfev=sf.nfev, nit=k, message=status_messages[warnflag]) 263 | 264 | if not subproblem.hess_prod: 265 | result['hess'] = m.hess.view(2 * x0.shape) 266 | 267 | if return_all: 268 | result['allvecs'] = allvecs 269 | 270 | return result -------------------------------------------------------------------------------- /torchmin/trustregion/dogleg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dog-leg trust-region optimization. 3 | 4 | Code ported from SciPy to PyTorch 5 | 6 | Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. 7 | All rights reserved. 8 | """ 9 | import torch 10 | from torch.linalg import norm 11 | 12 | from .base import _minimize_trust_region, BaseQuadraticSubproblem 13 | 14 | 15 | def _minimize_dogleg( 16 | fun, x0, **trust_region_options): 17 | """Minimization of scalar function of one or more variables using 18 | the dog-leg trust-region algorithm. 19 | 20 | .. warning:: 21 | The Hessian is required to be positive definite at all times; 22 | otherwise this algorithm will fail. 23 | 24 | Parameters 25 | ---------- 26 | fun : callable 27 | Scalar objective function to minimize 28 | x0 : Tensor 29 | Initialization point 30 | initial_trust_radius : float 31 | Initial trust-region radius. 32 | max_trust_radius : float 33 | Maximum value of the trust-region radius. No steps that are longer 34 | than this value will be proposed. 35 | eta : float 36 | Trust region related acceptance stringency for proposed steps. 37 | gtol : float 38 | Gradient norm must be less than `gtol` before successful 39 | termination. 40 | 41 | Returns 42 | ------- 43 | result : OptimizeResult 44 | Result of the optimization routine. 45 | 46 | References 47 | ---------- 48 | .. [1] Jorge Nocedal and Stephen Wright, 49 | Numerical Optimization, second edition, 50 | Springer-Verlag, 2006, page 73. 51 | 52 | """ 53 | return _minimize_trust_region(fun, x0, 54 | subproblem=DoglegSubproblem, 55 | **trust_region_options) 56 | 57 | 58 | class DoglegSubproblem(BaseQuadraticSubproblem): 59 | """Quadratic subproblem solved by the dogleg method""" 60 | hess_prod = False 61 | 62 | def cauchy_point(self): 63 | """ 64 | The Cauchy point is minimal along the direction of steepest descent. 65 | """ 66 | if self._cauchy_point is None: 67 | g = self.jac 68 | Bg = self.hessp(g) 69 | self._cauchy_point = -(g.dot(g) / g.dot(Bg)) * g 70 | return self._cauchy_point 71 | 72 | def newton_point(self): 73 | """ 74 | The Newton point is a global minimum of the approximate function. 75 | """ 76 | if self._newton_point is None: 77 | p = -torch.cholesky_solve(self.jac.view(-1,1), 78 | torch.linalg.cholesky(self.hess)) 79 | self._newton_point = p.view(-1) 80 | return self._newton_point 81 | 82 | def solve(self, trust_radius): 83 | """Solve quadratic subproblem""" 84 | 85 | # Compute the Newton point. 86 | # This is the optimum for the quadratic model function. 87 | # If it is inside the trust radius then return this point. 
88 | p_best = self.newton_point() 89 | if norm(p_best) < trust_radius: 90 | hits_boundary = False 91 | return p_best, hits_boundary 92 | 93 | # Compute the Cauchy point. 94 | # This is the predicted optimum along the direction of steepest descent. 95 | p_u = self.cauchy_point() 96 | 97 | # If the Cauchy point is outside the trust region, 98 | # then return the point where the path intersects the boundary. 99 | p_u_norm = norm(p_u) 100 | if p_u_norm >= trust_radius: 101 | p_boundary = p_u * (trust_radius / p_u_norm) 102 | hits_boundary = True 103 | return p_boundary, hits_boundary 104 | 105 | # Compute the intersection of the trust region boundary 106 | # and the line segment connecting the Cauchy and Newton points. 107 | # This requires solving a quadratic equation. 108 | # ||p_u + t*(p_best - p_u)||**2 == trust_radius**2 109 | # Solve this for positive time t using the quadratic formula. 110 | _, tb = self.get_boundaries_intersections(p_u, p_best - p_u, 111 | trust_radius) 112 | p_boundary = p_u + tb * (p_best - p_u) 113 | hits_boundary = True 114 | return p_boundary, hits_boundary -------------------------------------------------------------------------------- /torchmin/trustregion/krylov.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO: this module is not yet complete. It is not ready for use. 3 | """ 4 | import numpy as np 5 | from scipy.linalg import eigh_tridiagonal, get_lapack_funcs 6 | import torch 7 | 8 | from .base import _minimize_trust_region, BaseQuadraticSubproblem 9 | 10 | 11 | def _minimize_trust_krylov(fun, x0, **trust_region_options): 12 | """Minimization of scalar function of one or more variables using 13 | the GLTR Krylov subspace trust-region algorithm. 14 | 15 | .. warning:: 16 | This minimizer is in early stages and has not been rigorously 17 | tested. It may change in the near future. 18 | 19 | Parameters 20 | ---------- 21 | fun : callable 22 | Scalar objective function to minimize. 23 | x0 : Tensor 24 | Initialization point. 25 | initial_tr_radius : float 26 | Initial trust-region radius. 27 | max_tr_radius : float 28 | Maximum value of the trust-region radius. No steps that are longer 29 | than this value will be proposed. 30 | eta : float 31 | Trust region related acceptance stringency for proposed steps. 32 | gtol : float 33 | Gradient norm must be less than ``gtol`` before successful 34 | termination. 35 | 36 | Returns 37 | ------- 38 | result : OptimizeResult 39 | Result of the optimization routine. 40 | 41 | Notes 42 | ----- 43 | This trust-region solver is based on the GLTR algorithm as 44 | described in [1]_ and [2]_. 45 | 46 | References 47 | ---------- 48 | .. [1] F. Lenders, C. Kirches, and A. Potschka, "trlib: A vector-free 49 | implementation of the GLTR method for...", 50 | arXiv:1611.04718. 51 | .. [2] N. Gould, S. Lucidi, M. Roma, P. Toint: “Solving the Trust-Region 52 | Subproblem using the Lanczos Method”, 53 | SIAM J. Optim., 9(2), 504–525, 1999. 54 | .. [3] J. Nocedal and S. Wright, "Numerical optimization", 55 | Springer Science & Business Media. pp. 83-91, 2006. 56 | 57 | """ 58 | return _minimize_trust_region(fun, x0, 59 | subproblem=KrylovSubproblem, 60 | **trust_region_options) 61 | 62 | 63 | class KrylovSubproblem(BaseQuadraticSubproblem): 64 | """The GLTR trust region sub-problem defined on an expanding 65 | Krylov subspace. 66 | 67 | Based on the implementation of GLTR described in [1]_. 68 | 69 | References 70 | ---------- 71 | .. [1] F. Lenders, C. Kirches, and A. 
Potschka, "trlib: A vector-free 72 | implementation of the GLTR method for...", 73 | arXiv:1611.04718. 74 | .. [2] N. Gould, S. Lucidi, M. Roma, P. Toint: “Solving the Trust-Region 75 | Subproblem using the Lanczos Method”, 76 | SIAM J. Optim., 9(2), 504–525, 1999. 77 | .. [3] J. Nocedal and S. Wright, "Numerical optimization", 78 | Springer Science & Business Media. pp. 83-91, 2006. 79 | """ 80 | hess_prod = True 81 | max_lanczos = None 82 | max_ms_iters = 500 # max iterations of the Moré-Sorensen loop 83 | 84 | def __init__(self, x, fun, k_easy=0.1, k_hard=0.2, tol=1e-5, ortho=True, 85 | debug=False): 86 | super().__init__(x, fun) 87 | self.eps = torch.finfo(x.dtype).eps 88 | self.k_easy = k_easy 89 | self.k_hard = k_hard 90 | self.tol = tol 91 | self.ortho = ortho 92 | self._debug = debug 93 | 94 | def tridiag_subproblem(self, Ta, Tb, tr_radius): 95 | """Solve the GLTR tridiagonal subproblem. 96 | 97 | Based on Algorithm 5.2 of [2]_. We factorize as follows: 98 | 99 | .. math:: 100 | T + lambd * I = LDL^T 101 | 102 | Where `D` is diagonal and `L` unit (lower) bi-diagonal. 103 | """ 104 | device, dtype = Ta.device, Ta.dtype 105 | 106 | # convert to numpy 107 | Ta = Ta.cpu().numpy() 108 | Tb = Tb.cpu().numpy() 109 | tr_radius = float(tr_radius) 110 | 111 | # right hand side 112 | rhs = np.zeros_like(Ta) 113 | rhs[0] = - float(self.jac_mag) 114 | 115 | # get LAPACK routines for factorizing and solving sym-PD tridiagonal 116 | ptsv, pttrs = get_lapack_funcs(('ptsv', 'pttrs'), (Ta, Tb, rhs)) 117 | 118 | eig0 = None 119 | lambd_lb = 0. 120 | lambd = 0. 121 | for _ in range(self.max_ms_iters): 122 | lambd = max(lambd, lambd_lb) 123 | 124 | # factor T + lambd * I = LDL^T and solve LDL^T p = rhs 125 | d, e, p, info = ptsv(Ta + lambd, Tb, rhs) 126 | assert info >= 0 # sanity check 127 | if info > 0: 128 | assert eig0 is None # sanity check; should only happen once 129 | # estimate smallest eigenvalue and continue 130 | eig0 = eigh_tridiagonal( 131 | Ta, Tb, eigvals_only=True, select='i', 132 | select_range=(0,0), lapack_driver='stebz').item() 133 | lambd_lb = max(1e-3 - eig0, 0) 134 | continue 135 | 136 | p_norm = np.linalg.norm(p) 137 | if p_norm < tr_radius: 138 | # TODO: add extra checks 139 | status = 0 140 | break 141 | elif abs(p_norm - tr_radius) / tr_radius <= self.k_easy: 142 | status = 1 143 | break 144 | 145 | # solve LDL^T q = p and compute 146 | v, info = pttrs(d, e, p) 147 | q_norm2 = v.dot(p) 148 | 149 | # update lambd 150 | lambd += (p_norm**2 / q_norm2) * (p_norm - tr_radius) / tr_radius 151 | else: 152 | status = -1 153 | 154 | p = torch.tensor(p, device=device, dtype=dtype) 155 | 156 | return p, status, lambd 157 | 158 | def solve(self, tr_radius): 159 | g = self.jac 160 | gamma_0 = self.jac_mag 161 | n, = g.shape 162 | m = n if self.max_lanczos is None else min(n, self.max_lanczos) 163 | dtype = g.dtype 164 | device = g.device 165 | h_best = None 166 | error_best = float('inf') 167 | 168 | # Lanczos Q matrix buffer 169 | Q = torch.zeros(m, n, dtype=dtype, device=device) 170 | Q[0] = g / gamma_0 171 | 172 | # Lanczos T matrix buffers 173 | # a and b are the diagonal and off-diagonal entries of T, respectively 174 | a = torch.zeros(m, dtype=dtype, device=device) 175 | b = torch.zeros(m, dtype=dtype, device=device) 176 | 177 | # first lanczos iteration 178 | r = self.hessp(Q[0]) 179 | torch.dot(Q[0], r, out=a[0]) 180 | r.sub_(Q[0] * a[0]) 181 | torch.linalg.norm(r, out=b[0]) 182 | if b[0] < self.eps: 183 | raise RuntimeError('initial beta is zero.') 184 | 185 | # remaining 
iterations 186 | for i in range(1, m): 187 | torch.div(r, b[i-1], out=Q[i]) 188 | r = self.hessp(Q[i]) 189 | r.sub_(Q[i-1] * b[i-1]) 190 | torch.dot(Q[i], r, out=a[i]) 191 | r.sub_(Q[i] * a[i]) 192 | if self.ortho: 193 | # Re-orthogonalize with Gram-Schmidt 194 | r.addmv_(Q[:i+1].T, Q[:i+1].mv(r), alpha=-1) 195 | torch.linalg.norm(r, out=b[i]) 196 | if b[i] < self.eps: 197 | # This should never occur when self.ortho=True 198 | raise RuntimeError('reducible T matrix encountered.') 199 | 200 | # GLTR sub-problem 201 | h, status, lambd = self.tridiag_subproblem(a[:i+1], b[:i], tr_radius) 202 | 203 | if status >= 0: 204 | # convergence check; see Algorithm 1 of [1]_ and 205 | # Algorithm 5.1 of [2]_. Equivalent to the following: 206 | # p = Q[:i+1].T.mv(h) 207 | # error = torch.linalg.norm(self.hessp(p) + lambd * p + g) 208 | error = b[i] * h[-1].abs() 209 | if self._debug: 210 | print('iter %3d - status: %d - lambd: %0.4e - error: %0.4e' 211 | % (i+1, status, lambd, error)) 212 | if error < error_best: 213 | # we've found a new best 214 | hits_boundary = status != 0 215 | h_best = h 216 | error_best = error 217 | if error_best <= self.tol: 218 | break 219 | 220 | elif self._debug: 221 | print('iter %3d - status: %d - lambd: %0.4e' % 222 | (i+1, status, lambd)) 223 | 224 | if h_best is None: 225 | # TODO: what should we do here? 226 | raise RuntimeError('gltr solution not found') 227 | 228 | # project h back to R^n 229 | p_best = Q[:i+1].T.mv(h_best) 230 | 231 | return p_best, hits_boundary 232 | -------------------------------------------------------------------------------- /torchmin/trustregion/ncg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Newton-CG trust-region optimization. 3 | 4 | Code ported from SciPy to PyTorch 5 | 6 | Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. 7 | All rights reserved. 8 | """ 9 | import torch 10 | from torch.linalg import norm 11 | 12 | from .base import _minimize_trust_region, BaseQuadraticSubproblem 13 | 14 | 15 | def _minimize_trust_ncg( 16 | fun, x0, **trust_region_options): 17 | """Minimization of scalar function of one or more variables using 18 | the Newton conjugate gradient trust-region algorithm. 19 | 20 | Parameters 21 | ---------- 22 | fun : callable 23 | Scalar objective function to minimize. 24 | x0 : Tensor 25 | Initialization point. 26 | initial_trust_radius : float 27 | Initial trust-region radius. 28 | max_trust_radius : float 29 | Maximum value of the trust-region radius. No steps that are longer 30 | than this value will be proposed. 31 | eta : float 32 | Trust region related acceptance stringency for proposed steps. 33 | gtol : float 34 | Gradient norm must be less than ``gtol`` before successful 35 | termination. 36 | 37 | Returns 38 | ------- 39 | result : OptimizeResult 40 | Result of the optimization routine. 41 | 42 | Notes 43 | ----- 44 | This is algorithm (7.2) of Nocedal and Wright 2nd edition. 45 | Only the function that computes the Hessian-vector product is required. 46 | The Hessian itself is not required, and the Hessian does 47 | not need to be positive semidefinite. 48 | 49 | """ 50 | return _minimize_trust_region(fun, x0, 51 | subproblem=CGSteihaugSubproblem, 52 | **trust_region_options) 53 | 54 | 55 | class CGSteihaugSubproblem(BaseQuadraticSubproblem): 56 | """Quadratic subproblem solved by a conjugate gradient method""" 57 | hess_prod = True 58 | 59 | def solve(self, trust_radius): 60 | """Solve the subproblem using a conjugate gradient method. 
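The iteration follows the Steihaug conjugate-gradient scheme: a boundary step is returned as soon as a direction of nonpositive curvature is encountered or the CG iterate leaves the trust region; otherwise the solve stops at an interior point once the residual norm falls below the tolerance

    .. math::
        \min(0.5, \sqrt{\lVert g \rVert}) \, \lVert g \rVert

    where :math:`g` is the gradient at the current iterate.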
61 | 62 | Parameters 63 | ---------- 64 | trust_radius : float 65 | We are allowed to wander only this far away from the origin. 66 | 67 | Returns 68 | ------- 69 | p : Tensor 70 | The proposed step. 71 | hits_boundary : bool 72 | True if the proposed step is on the boundary of the trust region. 73 | 74 | """ 75 | 76 | # get the norm of jacobian and define the origin 77 | p_origin = torch.zeros_like(self.jac) 78 | 79 | # define a default tolerance 80 | tolerance = self.jac_mag * self.jac_mag.sqrt().clamp(max=0.5) 81 | 82 | # Stop the method if the search direction 83 | # is a direction of nonpositive curvature. 84 | if self.jac_mag < tolerance: 85 | hits_boundary = False 86 | return p_origin, hits_boundary 87 | 88 | # init the state for the first iteration 89 | z = p_origin 90 | r = self.jac 91 | d = -r 92 | 93 | # Search for the min of the approximation of the objective function. 94 | while True: 95 | 96 | # do an iteration 97 | Bd = self.hessp(d) 98 | dBd = d.dot(Bd) 99 | if dBd <= 0: 100 | # Look at the two boundary points. 101 | # Find both values of t to get the boundary points such that 102 | # ||z + t d|| == trust_radius 103 | # and then choose the one with the predicted min value. 104 | ta, tb = self.get_boundaries_intersections(z, d, trust_radius) 105 | pa = z + ta * d 106 | pb = z + tb * d 107 | p_boundary = torch.where(self(pa).lt(self(pb)), pa, pb) 108 | hits_boundary = True 109 | return p_boundary, hits_boundary 110 | r_squared = r.dot(r) 111 | alpha = r_squared / dBd 112 | z_next = z + alpha * d 113 | if norm(z_next) >= trust_radius: 114 | # Find t >= 0 to get the boundary point such that 115 | # ||z + t d|| == trust_radius 116 | ta, tb = self.get_boundaries_intersections(z, d, trust_radius) 117 | p_boundary = z + tb * d 118 | hits_boundary = True 119 | return p_boundary, hits_boundary 120 | r_next = r + alpha * Bd 121 | r_next_squared = r_next.dot(r_next) 122 | if r_next_squared.sqrt() < tolerance: 123 | hits_boundary = False 124 | return z_next, hits_boundary 125 | beta_next = r_next_squared / r_squared 126 | d_next = -r_next + beta_next * d 127 | 128 | # update the state for the next iteration 129 | z = z_next 130 | r = r_next 131 | d = d_next --------------------------------------------------------------------------------
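For orientation, the following is a minimal usage sketch of the closure-based `step(closure)` interface implemented by the scipy-backed optimizer earlier in this listing. It assumes the class is exported as `ScipyMinimizer` under `torchmin.optim` (the export name, the toy problem, and the solver choice are illustrative assumptions, not taken from the sources above); the closure contract itself follows the `step` implementation shown: parameters are flattened, the closure is evaluated under `torch.enable_grad()`, and the gathered `.grad` buffers are handed to `scipy.optimize.minimize` with `jac=True`.

```python
import torch
from torchmin.optim import ScipyMinimizer  # assumed export name; adjust if the class is exposed differently

# Toy least-squares problem: minimize ||A x - b||^2 over x.
torch.manual_seed(0)
A = torch.randn(10, 3)
b = torch.randn(10)
x = torch.zeros(3, requires_grad=True)

# Everything is forwarded to scipy.optimize.minimize with jac=True,
# so any gradient-based SciPy solver name is accepted here.
opt = ScipyMinimizer([x], method='l-bfgs-b', tol=1e-9)

def closure():
    # step() gathers gradients from the parameters' .grad buffers,
    # so the closure must zero them and call backward() itself.
    opt.zero_grad()
    loss = (A.mv(x) - b).pow(2).sum()
    loss.backward()
    return loss

final_loss = opt.step(closure)  # a single call runs the full SciPy solve
print(final_loss.item(), x.detach())
```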