` section for more details.
78 |
79 | Levenberg Marquardt
80 | -------------------
81 |
82 | .. autosummary::
83 | :toctree: _autosummary
84 |
85 | jaxopt.LevenbergMarquardt
86 |
87 | We can also use the Levenberg-Marquardt method, which is a more advanced method than Gauss-Newton in
88 | that it regularizes the update equation. This helps in cases where the Gauss-Newton method fails to converge.
89 |
90 | Update equation
91 | ~~~~~~~~~~~~~~~
92 |
93 | The following equation is solved for every iteration to find the update to the
94 | parameters:
95 |
96 | .. math::
97 | (\mathbf{J}^T \mathbf{J} + \mu\mathbf{I}) h_{lm} = - \mathbf{J}^T \mathbf{r}
98 |
99 | where :math:`\mathbf{J}` is the Jacobian of the residual function w.r.t.
100 | parameters and :math:`\mu` is the damping parameter.
101 |
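For instance, here is a minimal sketch of how the solver is typically instantiated and run on a small
curve-fitting problem; the ``residual_fun`` keyword and the ``run(init_params, *args)`` calling
convention are assumptions to be checked against the API reference::

  import jax.numpy as jnp
  from jaxopt import LevenbergMarquardt

  def residual_fun(params, x, y):
    # Residuals of the exponential model y ~ a * exp(-b * x).
    a, b = params
    return a * jnp.exp(-b * x) - y

  x = jnp.linspace(0.0, 1.0, 20)
  y = 2.0 * jnp.exp(-3.0 * x)

  lm = LevenbergMarquardt(residual_fun=residual_fun, maxiter=30)
  params, state = lm.run(jnp.array([1.0, 1.0]), x, y)
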
--------------------------------------------------------------------------------
/docs/notebooks/deep_learning/thumbnails/adversarial_training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/deep_learning/thumbnails/adversarial_training.png
--------------------------------------------------------------------------------
/docs/notebooks/deep_learning/thumbnails/resnet_flax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/deep_learning/thumbnails/resnet_flax.png
--------------------------------------------------------------------------------
/docs/notebooks/deep_learning/thumbnails/resnet_haiku.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/deep_learning/thumbnails/resnet_haiku.png
--------------------------------------------------------------------------------
/docs/notebooks/distributed/thumbnails/plot_custom_loop_pjit_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/distributed/thumbnails/plot_custom_loop_pjit_example.png
--------------------------------------------------------------------------------
/docs/notebooks/distributed/thumbnails/plot_custom_loop_pmap_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/distributed/thumbnails/plot_custom_loop_pmap_example.png
--------------------------------------------------------------------------------
/docs/notebooks/implicit_diff/thumbnails/maml.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/implicit_diff/thumbnails/maml.png
--------------------------------------------------------------------------------
/docs/notebooks/implicit_diff/thumbnails/plot_dataset_distillation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/implicit_diff/thumbnails/plot_dataset_distillation.png
--------------------------------------------------------------------------------
/docs/notebooks/index.rst:
--------------------------------------------------------------------------------
1 |
2 | .. _notebook_gallery:
3 |
4 | Notebook gallery
5 | ================
6 |
7 |
8 | Deep learning
9 | -------------
10 |
11 |
12 | .. raw:: html
13 |
14 |
15 |
16 |
17 |
18 | .. only:: html
19 |
20 | .. figure:: /notebooks/deep_learning/thumbnails/resnet_flax.png
21 | :alt: Resnet example with Flax and JAXopt.
22 |
23 | :doc:`/notebooks/deep_learning/resnet_flax`
24 |
25 | .. raw:: html
26 |
27 |
28 |
29 |
30 |
31 | .. only:: html
32 |
33 | .. figure:: /notebooks/deep_learning/thumbnails/resnet_haiku.png
34 | :alt: Resnet example with Haiku and JAXopt.
35 |
36 | :doc:`/notebooks/deep_learning/resnet_haiku`
37 |
38 | .. raw:: html
39 |
40 |
41 |
42 |
43 |
44 | .. only:: html
45 |
46 | .. figure:: /notebooks/deep_learning/thumbnails/adversarial_training.png
47 | :alt: Adversarial Training.
48 |
49 | :doc:`/notebooks/deep_learning/adversarial_training`
50 |
51 | .. raw:: html
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 | Implicit Differentiation
61 | ------------------------
62 |
63 |
64 | .. raw:: html
65 |
66 |
67 |
68 |
69 |
70 | .. only:: html
71 |
72 | .. figure:: /notebooks/implicit_diff/thumbnails/plot_dataset_distillation.png
73 | :alt: Dataset distillation example with JAXopt.
74 |
75 | :doc:`/notebooks/implicit_diff/dataset_distillation`
76 |
77 | .. raw:: html
78 |
79 |
80 |
81 |
82 | .. raw:: html
83 |
84 |
85 |
86 | .. only:: html
87 |
88 | .. figure:: /notebooks/implicit_diff/thumbnails/maml.png
89 | :alt: Few-shot Adaptation with Model Agnostic Meta-Learning (MAML)
90 |
91 | :doc:`/notebooks/implicit_diff/maml`
92 |
93 | .. raw:: html
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 | Distributed Optimization
102 | ------------------------
103 |
104 |
105 | .. raw:: html
106 |
107 |
108 |
109 |
110 |
111 | .. only:: html
112 |
113 | .. figure:: /notebooks/distributed/thumbnails/plot_custom_loop_pjit_example.png
114 | :alt: `jax.experimental.pjit` example using JAXopt.
115 |
116 | :doc:`/notebooks/distributed/custom_loop_pjit_example`
117 |
118 | .. raw:: html
119 |
120 |
121 |
122 |
123 | .. raw:: html
124 |
125 |
126 |
127 | .. only:: html
128 |
129 | .. figure:: /notebooks/distributed/thumbnails/plot_custom_loop_pmap_example.png
130 | :alt: `jax.pmap` example using JAXopt.
131 |
132 | :doc:`/notebooks/distributed/custom_loop_pmap_example`
133 |
134 |
135 | .. raw:: html
136 |
137 |
138 |
139 |
140 |
141 |
142 | Perturbed optimizers
143 | --------------------
144 |
145 |
146 | .. raw:: html
147 |
148 |
149 |
150 |
151 |
152 | .. only:: html
153 |
154 | .. figure:: /notebooks/perturbed_optimizers/thumbnails/perturbations.png
155 | :alt: Perturbed optimizers with JAXopt.
156 |
157 | :doc:`/notebooks/perturbed_optimizers/perturbed_optimizers`
158 |
159 | .. raw:: html
160 |
161 |
162 |
163 |
164 |
165 |
--------------------------------------------------------------------------------
/docs/notebooks/perturbed_optimizers/thumbnails/perturbations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/jaxopt/b12c2f3ddaa3ef71fac42d7e419adccb9dcc49a1/docs/notebooks/perturbed_optimizers/thumbnails/perturbations.png
--------------------------------------------------------------------------------
/docs/objective_and_loss.rst:
--------------------------------------------------------------------------------
1 | Loss and objective functions
2 | ============================
3 |
4 | Loss functions
5 | --------------
6 |
7 | Regression
8 | ~~~~~~~~~~
9 |
10 | .. autosummary::
11 | :toctree: _autosummary
12 |
13 | jaxopt.loss.huber_loss
14 |
15 | Regression losses are of the form ``loss(target: float, pred: float) -> float``,
16 | where ``target`` is the ground-truth and ``pred`` is the model's output.
17 |
18 | Binary classification
19 | ~~~~~~~~~~~~~~~~~~~~~
20 |
21 | .. autosummary::
22 | :toctree: _autosummary
23 |
24 | jaxopt.loss.binary_logistic_loss
25 | jaxopt.loss.binary_sparsemax_loss
26 | jaxopt.loss.binary_hinge_loss
27 | jaxopt.loss.binary_perceptron_loss
28 |
29 | Binary classification losses are of the form ``loss(label: int, score: float) -> float``,
30 | where ``label`` is the ground-truth (``0`` or ``1``) and ``score`` is the model's output.
31 |
32 | The following utility functions are useful for the binary sparsemax loss.
33 |
34 | .. autosummary::
35 | :toctree: _autosummary
36 |
37 | jaxopt.loss.sparse_plus
38 | jaxopt.loss.sparse_sigmoid
39 |
40 | Multiclass classification
41 | ~~~~~~~~~~~~~~~~~~~~~~~~~
42 |
43 | .. autosummary::
44 | :toctree: _autosummary
45 |
46 | jaxopt.loss.multiclass_logistic_loss
47 | jaxopt.loss.multiclass_sparsemax_loss
48 | jaxopt.loss.multiclass_hinge_loss
49 | jaxopt.loss.multiclass_perceptron_loss
50 |
51 | Multiclass classification losses are of the form ``loss(label: int, scores: jnp.ndarray) -> float``,
52 | where ``label`` is the ground-truth (between ``0`` and ``n_classes - 1``) and
53 | ``scores`` is an array of size ``n_classes``.
54 |
55 | Applying loss functions on a batch
56 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
57 |
58 | All loss functions above are pointwise, meaning that they operate on a single sample. To apply them to a
59 | batch, use ``jax.vmap(loss)`` followed by a reduction such as ``jnp.mean`` or ``jnp.sum``, as shown below.
60 |
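For example, here is a small sketch using ``binary_logistic_loss`` (any of the losses above can be
batched the same way)::

  import jax
  import jax.numpy as jnp
  from jaxopt import loss

  labels = jnp.array([0, 1, 1])          # one ground-truth label per sample
  logits = jnp.array([1.3, -0.4, 2.1])   # one score per sample

  batched_loss = jax.vmap(loss.binary_logistic_loss)
  mean_loss = jnp.mean(batched_loss(labels, logits))
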
61 | Objective functions
62 | -------------------
63 |
64 | .. _composite_linear_functions:
65 |
66 | Composite linear functions
67 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
68 |
69 | .. autosummary::
70 | :toctree: _autosummary
71 |
72 | jaxopt.objective.least_squares
73 | jaxopt.objective.binary_logreg
74 | jaxopt.objective.multiclass_logreg
75 | jaxopt.objective.multiclass_linear_svm_dual
76 |
77 | Composite linear objective functions can be used with
78 | :ref:`block coordinate descent `.
79 |
80 | Other functions
81 | ~~~~~~~~~~~~~~~
82 |
83 | .. autosummary::
84 | :toctree: _autosummary
85 |
86 | jaxopt.objective.ridge_regression
87 | jaxopt.objective.multiclass_logreg_with_intercept
88 | jaxopt.objective.l2_multiclass_logreg
89 | jaxopt.objective.l2_multiclass_logreg_with_intercept
90 |
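These objectives follow the ``fun(params, ...)`` convention expected by the solvers, so they can be
passed directly as the ``fun`` argument. As a small sketch, ``ridge_regression`` takes
``(params, l2reg, data)`` with ``data = (X, y)`` and can be minimized, for instance, with L-BFGS::

  import jax.numpy as jnp
  from jaxopt import LBFGS
  from jaxopt import objective

  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  y = jnp.array([1.0, 2.0, 3.0])

  solver = LBFGS(fun=objective.ridge_regression, maxiter=100)
  params, state = solver.run(jnp.zeros(X.shape[1]), l2reg=0.1, data=(X, y))
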
--------------------------------------------------------------------------------
/docs/perturbations.rst:
--------------------------------------------------------------------------------
1 | Perturbed optimization
2 | ======================
3 |
4 | The perturbed optimization module makes it possible to transform a non-smooth function, such as a max or arg-max, into a differentiable function using random perturbations. This is useful for optimization algorithms that require differentiability, such as gradient descent (e.g. see the :doc:`Notebook </notebooks/perturbed_optimizers/perturbed_optimizers>` on perturbed optimizers).
5 |
6 |
7 | Max perturbations
8 | -----------------
9 |
10 | Consider a maximum function of the form:
11 |
12 | .. math::
13 |
14 | F(\theta) = \max_{y \in \mathcal{C}} \langle y, \theta\rangle\,,
15 |
16 | where :math:`\mathcal{C}` is a convex set.
17 |
18 |
19 |
20 | .. autosummary::
21 | :toctree: _autosummary
22 |
23 | jaxopt.perturbations.make_perturbed_max
24 |
25 |
26 |
27 |
28 | The function :meth:`jaxopt.perturbations.make_perturbed_max` transforms the function :math:`F` into the following differentiable function using random perturbations:
29 |
30 |
31 | .. math::
32 |
33 | F_{\varepsilon}(\theta) = \mathbb{E}\left[ F(\theta + \varepsilon Z) \right]\,,
34 |
35 | where :math:`Z` is a random variable. The distribution of this random variable can be specified through the keyword argument ``noise``. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate.
36 |
37 |
38 | Argmax perturbations
39 | --------------------
40 |
41 | Consider an arg-max function of the form:
42 |
43 | .. math::
44 |
45 | y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta\rangle\,,
46 |
47 | where :math:`\mathcal{C}` is a convex set.
48 |
49 |
50 | The function :meth:`jaxopt.perturbations.make_perturbed_argmax` transforms the function :math:`y^\star` into the following differentiable function using random perturbations:
51 |
52 |
53 | .. math::
54 |
55 | y_{\varepsilon}^*(\theta) = \mathbb{E}\left[ \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle \right]\,,
56 |
57 | where :math:`Z` is a random variable. The distribution of this random variable can be specified through the keyword argument ``noise``. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate.
58 |
59 |
60 | .. autosummary::
61 | :toctree: _autosummary
62 |
63 | jaxopt.perturbations.make_perturbed_argmax
64 |
65 |
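As an illustrative sketch, a one-hot arg-max can be smoothed and differentiated as follows; the
``argmax_fun``, ``num_samples`` and ``sigma`` keywords, as well as the PRNG key passed as second
argument to the returned function, are assumptions to be checked against the API reference::

  import jax
  import jax.numpy as jnp
  from jaxopt import perturbations

  def argmax_one_hot(x):
    # Non-differentiable arg-max over the simplex vertices.
    return jax.nn.one_hot(jnp.argmax(x), x.shape[0])

  pert_argmax = perturbations.make_perturbed_argmax(argmax_fun=argmax_one_hot,
                                                    num_samples=1000,
                                                    sigma=0.5)

  theta = jnp.array([0.1, 0.6, 0.3])
  rng = jax.random.PRNGKey(0)
  print(pert_argmax(theta, rng))                 # smoothed one-hot vector
  print(jax.jacobian(pert_argmax)(theta, rng))   # well-defined Jacobian w.r.t. theta
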
66 | Scalar perturbations
67 | --------------------
68 |
69 | Consider any function, :math:`f` that is not necessarily differentiable, e.g. piecewise-constant of the form:
70 |
71 | .. math::
72 |
73 | f(\theta) = g(y^*(\theta))\,,
74 |
75 | where :math:`y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta\rangle` and :math:`\mathcal{C}` is a convex set.
76 |
77 |
78 | The function :meth:`jaxopt.perturbations.make_perturbed_fun` transforms the function :math:`f` into the following differentiable function using random perturbations:
79 |
80 | .. math::
81 |
82 | f_{\varepsilon}(\theta) = \mathbb{E}\left[ f(\theta + \varepsilon Z) \right]\,,
83 |
84 | where :math:`Z` is a random variable. The distribution of this random variable can be specified through the keyword argument ``noise``. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate. This can be particularly useful in the example given above, when :math:`f` is only defined on the discrete set, not its convex hull, i.e.
85 |
86 | .. math::
87 |
88 | f_{\varepsilon}(\theta) = \mathbb{E}\left[ g(\mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle) \right]\,,
89 |
90 |
91 | .. autosummary::
92 | :toctree: _autosummary
93 |
94 | jaxopt.perturbations.make_perturbed_fun
95 |
96 |
97 | Noise distributions
98 | -------------------
99 |
100 | The functions :meth:`jaxopt.perturbations.make_perturbed_max`, :meth:`jaxopt.perturbations.make_perturbed_argmax` and :meth:`jaxopt.perturbations.make_perturbed_fun` take a keyword argument ``noise`` that specifies the distribution of random perturbations. Pre-defined distributions for this argument are the following:
101 |
102 | .. autosummary::
103 | :toctree: _autosummary
104 |
105 | jaxopt.perturbations.Normal
106 | jaxopt.perturbations.Gumbel
107 |
108 |
109 |
110 |
111 | .. topic:: References
112 |
113 | Berthet, Q., Blondel, M., Teboul, O., Cuturi, M., Vert, J. P., & Bach, F. (2020). `Learning with differentiable perturbed optimizers <https://arxiv.org/abs/2002.08676>`_. Advances in neural information processing systems, 33.
114 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.3.4
2 | pandoc>=1.0.2
3 | sphinx>=3.5.1
4 | sphinx_rtd_theme>=0.5.1
5 | sphinx_autodoc_typehints>=1.11.1
6 | ipython>=7.20.0
7 | ipykernel>=5.5.0
8 | sphinx-gallery>=0.9.0
9 | sphinx_copybutton>=0.4.0
10 | sphinx-remove-toctrees>=0.0.3
11 | jupyter-sphinx>=0.3.2
12 | myst-nb
13 | tensorflow-datasets
14 | tensorflow
15 | dm-haiku
16 | flax
17 | jupytext
18 | scikit-learn
--------------------------------------------------------------------------------
/docs/root_finding.rst:
--------------------------------------------------------------------------------
1 | .. _root_finding:
2 |
3 | Root finding
4 | ============
5 |
6 | This section is concerned with root finding, that is, finding :math:`x` such
7 | that :math:`F(x, \theta) = 0`.
8 |
9 | Bisection
10 | ---------
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 |
15 | jaxopt.Bisection
16 |
17 | Bisection is a suitable algorithm when :math:`F(x, \theta)` is one-dimensional
18 | in :math:`x`.
19 |
20 | Instantiating and running the solver
21 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
22 |
23 | First, let us consider the case :math:`F(x)`, i.e., without extra argument
24 | :math:`\theta`. The ``Bisection`` class requires a bracketing interval
25 | :math:`[\text{lower}, \text{upper}]` such that :math:`F(\text{lower})` and
26 | :math:`F(\text{upper})` have opposite signs, meaning that a root is contained
27 | in this interval as long as :math:`F` is continuous. For instance, suppose
28 | that we want to find the root of :math:`F(x) = x^3 - x - 2`. We have
29 | :math:`F(1) = -2` and :math:`F(2) = 4`. Since the function is continuous, there
30 | must be an :math:`x` between 1 and 2 such that :math:`F(x) = 0`::
31 |
32 | from jaxopt import Bisection
33 |
34 | def F(x):
35 | return x ** 3 - x - 2
36 |
37 | bisec = Bisection(optimality_fun=F, lower=1, upper=2)
38 | print(bisec.run().params)
39 |
40 | ``Bisection`` successfully finds the root ``x = 1.521``.
41 | Notice that ``Bisection`` does not require an initialization,
42 | since the bracketing interval is sufficient.
43 |
44 | Differentiation
45 | ~~~~~~~~~~~~~~~
46 |
47 | Now, let us consider the case :math:`F(x, \theta)`. For instance, suppose that
48 | ``F`` takes an additional argument ``factor``. We can easily differentiate
49 | with respect to ``factor``::
50 |
51 | import jax

def F(x, factor):
52 | return factor * x ** 3 - x - 2
53 |
54 | def root(factor):
55 | bisec = Bisection(optimality_fun=F, lower=1, upper=2)
56 | return bisec.run(factor=factor).params
57 |
58 | # Derivative of root with respect to factor at 2.0.
59 | print(jax.grad(root)(2.0))
60 |
61 | Under the hood, we use the implicit function theorem in order to differentiate the root.
62 | See the :ref:`implicit differentiation ` section for more details.
63 |
64 | Scipy wrapper
65 | -------------
66 |
67 | .. autosummary::
68 | :toctree: _autosummary
69 |
70 | jaxopt.ScipyRootFinding
71 |
72 |
73 | Broyden's method
74 | ----------------
75 |
76 | .. autosummary::
77 | :toctree: _autosummary
78 |
79 | jaxopt.Broyden
80 |
81 | Broyden's method is an iterative algorithm for finding roots of nonlinear equations in any dimension.
82 | It is a quasi-Newton method (like L-BFGS), meaning that it uses an approximation of the Jacobian matrix
83 | at each iteration.
84 | The approximation is updated at each iteration with a rank-one update.
85 | This makes the approximation easy to invert using the Sherman-Morrison formula, provided that the number
86 | of updates is not too large.
87 | One can control the number of updates with the ``history_size`` argument.
88 | Furthermore, Broyden's method uses a line search to ensure the rank-one updates are stable.
89 |
90 | Example::
91 |
92 | import jax.numpy as jnp
93 | from jaxopt import Broyden
94 |
95 | def F(x):
96 | return x ** 3 - x - 2
97 |
98 | broyden = Broyden(fun=F)
99 | print(broyden.run(jnp.array(1.0)).params)
100 |
101 |
102 | For implicit differentiation::
103 |
104 | import jax
105 | import jax.numpy as jnp
106 | from jaxopt import Broyden
107 |
108 | def F(x, factor):
109 | return factor * x ** 3 - x - 2
110 |
111 | def root(factor):
112 | broyden = Broyden(fun=F)
113 | return broyden.run(jnp.array(1.0), factor=factor).params
114 |
115 | # Derivative of root with respect to factor at 2.0.
116 | print(jax.grad(root)(2.0))
117 |
--------------------------------------------------------------------------------
/docs/stochastic.rst:
--------------------------------------------------------------------------------
1 | Stochastic optimization
2 | =======================
3 |
4 | This section is concerned with problems of the form
5 |
6 | .. math::
7 |
8 | \min_{x} \mathbb{E}_{D}[f(x, \theta, D)],
9 |
10 | where :math:`f(x, \theta, D)` is differentiable (almost everywhere), :math:`x`
11 | are the parameters with respect to which the function is minimized,
12 | :math:`\theta` are optional fixed extra arguments and :math:`D` is a random
13 | variable (typically a mini-batch).
14 |
15 |
16 | .. topic:: Examples
17 |
18 | * :doc:`/notebooks/deep_learning/resnet_haiku`
19 | * :doc:`/notebooks/deep_learning/resnet_flax`
20 | * :ref:`sphx_glr_auto_examples_deep_learning_haiku_vae.py`
21 | * :ref:`sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py`
22 |
23 |
24 | Defining an objective function
25 | ------------------------------
26 |
27 | Objective functions must contain a ``data`` argument corresponding to :math:`D` above.
28 |
29 | Example::
30 |
31 | def ridge_reg_objective(params, l2reg, data):
32 | X, y = data
33 | residuals = jnp.dot(X, params) - y
34 | return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)
35 |
36 | Data iterator
37 | -------------
38 |
39 | Sampling realizations of the random variable :math:`D` can be done using an iterator.
40 |
41 | Example::
42 |
43 | def data_iterator():
44 | for _ in range(n_iter):
45 | perm = rng.permutation(n_samples)[:batch_size]
46 | yield (X[perm], y[perm])
47 |
48 | Solvers
49 | -------
50 |
51 | .. autosummary::
52 | :toctree: _autosummary
53 |
54 | jaxopt.ArmijoSGD
55 | jaxopt.OptaxSolver
56 | jaxopt.PolyakSGD
57 |
58 | Optax solvers
59 | ~~~~~~~~~~~~~
60 |
61 | `Optax <https://optax.readthedocs.io>`_ solvers can be used in JAXopt using
62 | :class:`OptaxSolver <jaxopt.OptaxSolver>`. Here's an example with Adam::
63 |
64 | import optax
from jaxopt import OptaxSolver
65 |
66 | opt = optax.adam(learning_rate)
67 | solver = OptaxSolver(opt=opt, fun=ridge_reg_objective, maxiter=1000)
68 |
69 | See the common optimizers section of the
70 | optax documentation for a list of available
71 | stochastic solvers.
72 |
73 | Adaptive solvers
74 | ~~~~~~~~~~~~~~~~
75 |
76 | Adaptive solvers update the step size at each iteration dynamically.
77 | An example is :class:`PolyakSGD <jaxopt.PolyakSGD>`, a solver
78 | which computes step sizes adaptively using function values.
79 |
80 | Another example is :class:`ArmijoSGD <jaxopt.ArmijoSGD>`, a solver
81 | that uses an Armijo line search.
82 |
83 | For convergence guarantees to hold, these two algorithms
84 | require the interpolation hypothesis:
85 | the global optimum over :math:`D` must also be a global optimum
86 | for any finite sample of :math:`D`.
87 | This is typically achieved by overparametrized models (e.g., neural networks)
88 | in classification tasks with separable classes, or in regression tasks without noise.
89 |
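For instance, a minimal sketch using :class:`PolyakSGD <jaxopt.PolyakSGD>` with the objective and data
iterator defined above (``run_iterator`` is described in the next section)::

  from jaxopt import PolyakSGD

  solver = PolyakSGD(fun=ridge_reg_objective, maxiter=1000)
  res = solver.run_iterator(init_params, data_iterator(), l2reg=l2reg)
  params = res.params
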
90 | Run iterator vs. manual loop
91 | ----------------------------
92 |
93 | The following::
94 |
95 | iterator = data_iterator()
96 | solver.run_iterator(init_params, iterator, l2reg=l2reg)
97 |
98 | is equivalent to::
99 |
100 | iterator = data_iterator()
101 | state = solver.init_state(init_params, l2reg=l2reg)
102 | params = init_params
103 | for _ in range(maxiter):
104 | data = next(iterator)
105 | params, state = solver.update(params, state, l2reg=l2reg, data=data)
106 |
--------------------------------------------------------------------------------
/docs/unconstrained.rst:
--------------------------------------------------------------------------------
1 | .. _unconstrained_optim:
2 |
3 | Unconstrained optimization
4 | ==========================
5 |
6 | This section is concerned with problems of the form
7 |
8 | .. math::
9 |
10 | \min_{x} f(x, \theta)
11 |
12 | where :math:`f(x, \theta)` is differentiable (almost everywhere), :math:`x`
13 | are the parameters with respect to which the function is minimized and
14 | :math:`\theta` are optional extra arguments.
15 |
16 | Defining an objective function
17 | ------------------------------
18 |
19 | Objective functions must always include as first argument the variables with
20 | respect to which the function is minimized. The function can also contain extra
21 | arguments.
22 |
23 | The following illustrates how to express the ridge regression objective::
24 |
25 | def ridge_reg_objective(params, l2reg, X, y):
26 | residuals = jnp.dot(X, params) - y
27 | return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)
28 |
29 | The model parameters ``params`` correspond to :math:`x` while ``l2reg``, ``X``
30 | and ``y`` correspond to the extra arguments :math:`\theta` in the mathematical
31 | notation above.
32 |
33 | Solvers
34 | -------
35 |
36 | .. autosummary::
37 | :toctree: _autosummary
38 |
39 | jaxopt.BFGS
40 | jaxopt.GradientDescent
41 | jaxopt.LBFGS
42 | jaxopt.ScipyMinimize
43 | jaxopt.NonlinearCG
44 |
45 | Instantiating and running the solver
46 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
47 |
48 | Continuing the ridge regression example above, the L-BFGS solver can be
49 | instantiated and run as follows::
50 |
51 | solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=maxiter)
52 | res = solver.run(init_params, l2reg=l2reg, X=X, y=y)
53 |
54 | # Alternatively, we could have used one of these solvers as well:
55 | # solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
56 | # solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
57 | # solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500)
58 |
59 | Unpacking results
60 | ~~~~~~~~~~~~~~~~~
61 |
62 | Note that ``res`` has the form ``NamedTuple(params, state)``, where ``params``
63 | are the approximate solution found by the solver and ``state`` contains
64 | solver-specific information about convergence.
65 |
66 | Because ``res`` is a ``NamedTuple``, we can unpack it as::
67 |
68 | params, state = res
69 | print(params, state)
70 |
71 | Alternatively, we can also access attributes directly::
72 |
73 | print(res.params, res.state)
74 |
--------------------------------------------------------------------------------
/examples/README.rst:
--------------------------------------------------------------------------------
1 | .. _general_examples:
2 |
3 | Example gallery
4 | ===============
5 |
6 | To clone the repository and the examples, please run::
7 |
8 | $ git clone https://github.com/google/jaxopt.git
9 |
10 | or download this `zip file `_.
11 |
12 | To install the libraries that the examples depend on, please run::
13 |
14 | $ pip install -r examples/requirements.txt
15 |
--------------------------------------------------------------------------------
/examples/constrained/README.rst:
--------------------------------------------------------------------------------
1 | .. _constrained_examples:
2 |
3 | Constrained optimization
4 | ------------------------
5 |
6 |
--------------------------------------------------------------------------------
/examples/constrained/multiclass_linear_svm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Multiclass linear SVM (without intercept).
17 | ==========================================
18 |
19 | This quadratic program can be solved either with OSQP or with block coordinate descent.
20 |
21 | Reference:
22 |
23 | Crammer, K. and Singer, Y., 2001. On the algorithmic implementation of multiclass kernel-based vector machines.
24 | Journal of machine learning research, 2(Dec), pp.265-292.
25 | """
26 |
27 | from absl import app
28 | from absl import flags
29 |
30 | import jax
31 | import jax.numpy as jnp
32 |
33 | from jaxopt import BlockCoordinateDescent
34 | from jaxopt import OSQP
35 | from jaxopt import objective
36 | from jaxopt import projection
37 | from jaxopt import prox
38 |
39 | from sklearn import datasets
40 | from sklearn import preprocessing
41 | from sklearn import svm
42 |
43 |
44 | flags.DEFINE_float("tol", 1e-5, "Tolerance of solvers.")
45 | flags.DEFINE_float("l2reg", 1000., "Regularization parameter. Must be positive.")
46 | flags.DEFINE_integer("num_samples", 20, "Size of train set.")
47 | flags.DEFINE_integer("num_features", 5, "Features dimension.")
48 | flags.DEFINE_integer("num_classes", 3, "Number of classes.")
49 | flags.DEFINE_bool("verbose", False, "Verbosity.")
50 | FLAGS = flags.FLAGS
51 |
52 |
53 | def multiclass_linear_svm_skl(X, y, l2reg):
54 | print("Solve multiclass SVM with sklearn.svm.LinearSVC:")
55 | svc = svm.LinearSVC(loss="hinge", dual=True, multi_class="crammer_singer",
56 | C=1.0 / l2reg, fit_intercept=False,
57 | tol=FLAGS.tol, max_iter=100*1000).fit(X, y)
58 | return svc.coef_.T
59 |
60 |
61 | def multiclass_linear_svm_bcd(X, Y, l2reg):
62 | print("Block coordinate descent solution:")
63 |
64 | # Set up parameters.
65 | block_prox = prox.make_prox_from_projection(projection.projection_simplex)
66 | fun = objective.multiclass_linear_svm_dual
67 | data = (X, Y)
68 | beta_init = jnp.ones((X.shape[0], Y.shape[-1])) / Y.shape[-1]
69 |
70 | # Run solver.
71 | bcd = BlockCoordinateDescent(fun=fun, block_prox=block_prox,
72 | maxiter=10*1000, tol=FLAGS.tol)
73 | sol = bcd.run(beta_init, hyperparams_prox=None, l2reg=FLAGS.l2reg, data=data)
74 | return sol.params
75 |
76 |
77 | def multiclass_linear_svm_osqp(X, Y, l2reg):
78 | # We solve the problem
79 | #
80 | # minimize 0.5/l2reg beta X X.T beta - (1. - Y)^T beta - 1./l2reg (Y^T X) X^T beta
81 | # under beta >= 0
82 | # sum_i beta_i = 1
83 | #
84 | print("OSQP solution solution:")
85 |
86 | def matvec_Q(X, beta):
87 | return 1./l2reg * jnp.dot(X, jnp.dot(X.T, beta))
88 |
89 | linear_part = - (1. - Y) - 1./l2reg * jnp.dot(X, jnp.dot(X.T, Y))
90 |
91 | def matvec_A(_, beta):
92 | return jnp.sum(beta, axis=-1)
93 |
94 | def matvec_G(_, beta):
95 | return -beta
96 |
97 | b = jnp.ones(X.shape[0])
98 | h = jnp.zeros_like(Y)
99 |
100 | osqp = OSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, matvec_G=matvec_G, tol=FLAGS.tol, maxiter=10*1000)
101 | hyper_params = dict(params_obj=(X, linear_part),
102 | params_eq=(None, b),
103 | params_ineq=(None, h))
104 |
105 | sol, _ = osqp.run(init_params=None, **hyper_params)
106 | return sol.primal
107 |
108 |
109 | def main(argv):
110 | del argv
111 |
112 | # Generate data.
113 | num_samples = FLAGS.num_samples
114 | num_features = FLAGS.num_features
115 | num_classes = FLAGS.num_classes
116 |
117 | X, y = datasets.make_classification(n_samples=num_samples, n_features=num_features,
118 | n_informative=3, n_classes=num_classes, random_state=0)
119 | X = preprocessing.Normalizer().fit_transform(X)
120 | Y = preprocessing.LabelBinarizer().fit_transform(y)
121 | Y = jnp.array(Y)
122 |
123 | l2reg = FLAGS.l2reg
124 |
125 | # Compare against sklearn.
126 | W_osqp = multiclass_linear_svm_osqp(X, Y, l2reg)
127 | W_fit_osqp = jnp.dot(X.T, (Y - W_osqp)) / l2reg
128 | print(W_fit_osqp)
129 | print()
130 |
131 | W_bcd = multiclass_linear_svm_bcd(X, Y, l2reg)
132 | W_fit_bcd = jnp.dot(X.T, (Y - W_bcd)) / l2reg
133 | print(W_fit_bcd)
134 | print()
135 |
136 | W_skl = multiclass_linear_svm_skl(X, y, l2reg)
137 | print(W_skl)
138 | print()
139 |
140 |
141 | if __name__ == "__main__":
142 | jax.config.update("jax_platform_name", "cpu")
143 | app.run(main)
144 |
--------------------------------------------------------------------------------
/examples/constrained/nmf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Non-negative matrix factorizaton (NMF) using alternating minimization.
17 | ======================================================================
18 | """
19 |
20 | from absl import app
21 | from absl import flags
22 |
23 | import jax.numpy as jnp
24 |
25 | from jaxopt import BlockCoordinateDescent
26 | from jaxopt import objective
27 | from jaxopt import prox
28 |
29 | import numpy as onp
30 |
31 | from sklearn import datasets
32 |
33 |
34 | flags.DEFINE_string("penalty", "l2", "Regularization type.")
35 | flags.DEFINE_float("gamma", 1.0, "Regularization strength.")
36 | FLAGS = flags.FLAGS
37 |
38 |
39 | def nnreg(U, V_init, X, maxiter=150):
40 | """Regularized non-negative regression.
41 |
42 | We solve::
43 |
44 | min_{V >= 0} mean((U V^T - X) ** 2) + 0.5 * gamma * ||V||^2_2
45 |
46 | or
47 |
48 | min_{V >= 0} mean((U V^T - X) ** 2) + gamma * ||V||_1
49 | """
50 | if FLAGS.penalty == "l2":
51 | block_prox = prox.prox_non_negative_ridge
52 | elif FLAGS.penalty == "l1":
53 | block_prox = prox.prox_non_negative_lasso
54 | else:
55 | raise ValueError("Invalid penalty.")
56 |
57 | bcd = BlockCoordinateDescent(fun=objective.least_squares,
58 | block_prox=block_prox,
59 | maxiter=maxiter)
60 | sol = bcd.run(init_params=V_init.T, hyperparams_prox=FLAGS.gamma, data=(U, X))
61 | return sol.params.T # approximate solution V
62 |
63 |
64 | def reconstruction_error(U, V, X):
65 | """Computes (unregularized) reconstruction error."""
66 | UV = jnp.dot(U, V.T)
67 | return 0.5 * jnp.mean((UV - X) ** 2)
68 |
69 |
70 | def nmf(U_init, V_init, X, maxiter=10):
71 | """NMF by alternating minimization.
72 |
73 | We solve
74 |
75 | min_{U >= 0, V>= 0} ||U V^T - X||^2 + 0.5 * gamma * (||U||^2_2 + ||V||^2_2)
76 |
77 | or
78 |
79 | min_{U >= 0, V>= 0} ||U V^T - X||^2 + gamma * (||U||_1 + ||V||_1)
80 | """
81 | U, V = U_init, V_init
82 |
83 | error = reconstruction_error(U, V, X)
84 | print(f"STEP: 0; Error: {error:.3f}")
85 | print()
86 |
87 | for step in range(1, maxiter + 1):
88 | print(f"STEP: {step}")
89 |
90 | V = nnreg(U, V, X, maxiter=150)
91 | error = reconstruction_error(U, V, X)
92 | print(f"Error: {error:.3f} (V update)")
93 |
94 | U = nnreg(V, U, X.T, maxiter=150)
95 | error = reconstruction_error(U, V, X)
96 | print(f"Error: {error:.3f} (U update)")
97 | print()
98 |
99 |
100 | def main(argv):
101 | del argv
102 |
103 | # Prepare data.
104 | X, _ = datasets.load_diabetes(return_X_y=True)
105 | X = jnp.abs(X)  # Make the data non-negative, as required by NMF.
106 |
107 | n_samples = X.shape[0]
108 | n_features = X.shape[1]
109 | n_components = 10
110 |
111 | rng = onp.random.RandomState(0)
112 | U = jnp.array(rng.rand(n_samples, n_components))
113 | V = jnp.array(rng.rand(n_features, n_components))
114 |
115 | # Run the algorithm.
116 | print("penalty:", FLAGS.penalty)
117 | print("gamma", FLAGS.gamma)
118 | print()
119 |
120 | nmf(U, V, X, maxiter=30)
121 |
122 | if __name__ == "__main__":
123 | app.run(main)
124 |
--------------------------------------------------------------------------------
/examples/deep_learning/README.rst:
--------------------------------------------------------------------------------
1 | .. _deep_learning_examples:
2 |
3 | Deep learning
4 | -------------
5 |
6 |
--------------------------------------------------------------------------------
/examples/fixed_point/README.rst:
--------------------------------------------------------------------------------
1 | .. _fixed_point_examples:
2 |
3 | Fixed point resolution
4 | ----------------------
5 |
6 |
--------------------------------------------------------------------------------
/examples/fixed_point/plot_anderson_accelerate_gd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""
16 | Anderson acceleration of gradient descent.
17 | ==========================================
18 |
19 | For a strictly convex function f, :math:`\nabla f(x)=0` implies that :math:`x`
20 | is the global optimum of :math:`f`.
21 |
22 | Consequently the fixed point of :math:`T(x)=x-\eta\nabla f(x)` is the optimum of
23 | :math:`f`.
24 |
25 | Note that repeated application of the operator :math:`T` coincides exactly with
26 | gradient descent with constant step size :math:`\eta`.
27 |
28 | Hence, as any other fixed point iteration, gradient descent can benefit from
29 | Anderson acceleration. Here, we choose :math:`f` as the objective function
30 | of ridge regression on some dummy dataset. Anderson acceleration reaches the
31 | optimal parameters within a few iterations, whereas gradient descent is slower.
32 |
33 | Here `m` denotes the history size, and `K` the frequency of Anderson updates.
34 | """
35 |
36 | import jax
37 | import jax.numpy as jnp
38 |
39 | import matplotlib.pyplot as plt
40 | from sklearn import datasets
41 |
42 | from jaxopt import AndersonAcceleration
43 | from jaxopt import FixedPointIteration
44 |
45 | from jaxopt import objective
46 | from jaxopt.tree_util import tree_scalar_mul, tree_sub
47 |
48 | jax.config.update("jax_platform_name", "cpu")
49 |
50 |
51 | # retrieve intermediate iterates.
52 | def run_all(solver, w_init, *args, **kwargs):
53 | state = solver.init_state(w_init, *args, **kwargs)
54 | sol = w_init
55 | sols, errors = [], []
56 |
57 | for _ in range(solver.maxiter):
58 | sol, state = solver.update(sol, state, *args, **kwargs)
59 | sols.append(sol)
60 | errors.append(state.error)
61 |
62 | return jnp.stack(sols, axis=0), errors
63 |
64 |
65 | # dummy dataset
66 | X, y = datasets.make_regression(n_samples=100, n_features=10, random_state=0)
67 | ridge_regression_grad = jax.grad(objective.ridge_regression)
68 |
69 | # gradient step: x - eta * grad_x f(x), with f the cost of the learning task
70 | # the fixed point of this mapping verifies grad_x f(x) = 0,
71 | # i.e. the fixed point is an optimum
72 | def T(params, eta, l2reg, data):
73 | g = ridge_regression_grad(params, l2reg, data)
74 | step = tree_scalar_mul(eta, g)
75 | return tree_sub(params, step)
76 |
77 | w_init = jnp.zeros(X.shape[1]) # null vector
78 | eta = 1e-1 # small step size
79 | l2reg = 0. # no regularization
80 | tol = 1e-5
81 | maxiter = 80
82 | aa = AndersonAcceleration(T, history_size=5, mixing_frequency=1, maxiter=maxiter, ridge=5e-5, tol=tol)
83 | aam = AndersonAcceleration(T, history_size=5, mixing_frequency=5, maxiter=maxiter, ridge=5e-5, tol=tol)
84 | fpi = FixedPointIteration(T, maxiter=maxiter, tol=tol)
85 |
86 | aa_sols, aa_errors = run_all(aa, w_init, eta, l2reg, (X, y))
87 | aam_sols, aam_errors = run_all(aam, w_init, eta, l2reg, (X, y))
88 | fp_sols, fp_errors = run_all(fpi, w_init, eta, l2reg, (X, y))
89 |
90 | sol = aa_sols[-1]
91 | print(f'Error={aa_errors[-1]:.6f} at parameters {sol}')
92 | print(f'At this point the gradient {ridge_regression_grad(sol, l2reg, (X,y))} is close to the zero vector, so we have found the minimum.')
93 |
94 | fig = plt.figure(figsize=(10, 12))
95 | fig.suptitle('Trajectory in parameter space')
96 | spec = fig.add_gridspec(ncols=2, nrows=3, hspace=0.3)
97 |
98 | # Plot trajectory in parameter space (8 dimensions)
99 | for i in range(4):
100 | ax = fig.add_subplot(spec[i//2, i%2])
101 | ax.plot(fp_sols[:,i], fp_sols[:,2*i+1], '-', linewidth=4., label="Gradient Descent")
102 | ax.plot(aa_sols[:,i], aa_sols[:,2*i+1], 'v', markersize=12, label="Anderson Accelerated GD (m=5, K=1)")
103 | ax.plot(aam_sols[:,i], aam_sols[:,2*i+1], '*', markersize=8, label="Anderson Accelerated GD (m=5, K=5)")
104 | ax.set_xlabel(f'$x_{{{2*i+1}}}$')
105 | ax.set_ylabel(f'$x_{{{2*i+2}}}$')
106 | if i == 0:
107 | ax.legend(loc='upper left', bbox_to_anchor=(0.75, 1.38),
108 | ncol=1, fancybox=True, shadow=True)
109 | ax.axis('equal')
110 |
111 | # Plot error as function of iteration num
112 | ax = fig.add_subplot(spec[2, :])
113 | iters = jnp.arange(len(aa_errors))
114 | ax.plot(iters, fp_errors, linewidth=4., label='Gradient Descent Error')
115 | ax.plot(iters, aa_errors, linewidth=4., label='Anderson Accelerated GD Error (m=5, K=1)')
116 | ax.plot(iters, aam_errors, linewidth=4., label='Anderson Accelerated GD Error (m=5, K=5)')
117 | ax.set_xlabel('Iteration num')
118 | ax.set_ylabel('Error')
119 | ax.set_yscale('log')
120 | ax.legend()
121 | plt.show()
122 |
123 |
--------------------------------------------------------------------------------
/examples/fixed_point/plot_anderson_wrapper_cd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""
16 | Anderson acceleration of block coordinate descent.
17 | ==================================================
18 |
19 | Block coordinate descent converges to a fixed point. It can therefore be
20 | accelerated with Anderson acceleration.
21 |
22 | Here `m` denotes the history size, and `K` the frequency of Anderson updates.
23 |
24 | Bertrand, Q. and Massias, M.
25 | Anderson acceleration of coordinate descent.
26 | AISTATS, 2021.
27 | """
28 |
29 | import jax
30 | import jax.numpy as jnp
31 |
32 | from jaxopt import AndersonWrapper
33 | from jaxopt import BlockCoordinateDescent
34 |
35 | from jaxopt import objective
36 | from jaxopt import prox
37 |
38 | import matplotlib.pyplot as plt
39 | from sklearn import datasets
40 |
41 | jax.config.update("jax_platform_name", "cpu")
42 | jax.config.update("jax_enable_x64", True)
43 |
44 |
45 | # retrieve intermediate iterates.
46 | def run_all(solver, w_init, *args, **kwargs):
47 | state = solver.init_state(w_init, *args, **kwargs)
48 | sol = w_init
49 | sols, errors = [sol], [state.error]
50 | for _ in range(solver.maxiter):
51 | sol, state = solver.update(sol, state, *args, **kwargs)
52 | sols.append(sol)
53 | errors.append(state.error)
54 | return jnp.stack(sols, axis=0), errors
55 |
56 |
57 | X, y = datasets.make_regression(n_samples=10, n_features=8, random_state=1)
58 | fun = objective.least_squares # fun(params, data)
59 | l1reg = 10.0
60 | data = (X, y)
61 |
62 | w_init = jnp.zeros(X.shape[1])
63 | maxiter = 80
64 |
65 | bcd = BlockCoordinateDescent(fun, block_prox=prox.prox_lasso, maxiter=maxiter, tol=1e-6)
66 | history_size = 5
67 | aa = AndersonWrapper(bcd, history_size=history_size, mixing_frequency=1, ridge=1e-4)
68 | aam = AndersonWrapper(bcd, history_size=history_size, mixing_frequency=history_size, ridge=1e-4)
69 |
70 | aa_sols, aa_errors = run_all(aa, w_init, hyperparams_prox=l1reg, data=data)
71 | aam_sols, aam_errors = run_all(aam, w_init, hyperparams_prox=l1reg, data=data)
72 | bcd_sols, bcd_errors = run_all(bcd, w_init, hyperparams_prox=l1reg, data=data)
73 |
74 | print(f'Error={aa_errors[-1]:.6f} at parameters {aa_sols[-1]} for Anderson (m=5, K=1)')
75 | print(f'Error={aam_errors[-1]:.6f} at parameters {aam_sols[-1]} for Anderson (m=5, K=5)')
76 | print(f'Error={bcd_errors[-1]:.6f} at parameters {bcd_sols[-1]} for Block CD')
77 |
78 | fig = plt.figure(figsize=(10, 12))
79 | fig.suptitle('Least Square linear regression with Lasso penalty')
80 | spec = fig.add_gridspec(ncols=2, nrows=3, hspace=0.3)
81 |
82 | # Plot trajectory in parameter space (8 dimensions)
83 | for i in range(4):
84 | ax = fig.add_subplot(spec[i//2, i%2])
85 | ax.plot(bcd_sols[:,i], bcd_sols[:,2*i+1], '--', label="Coordinate Descent")
86 | ax.plot(aa_sols[:,i], aa_sols[:,2*i+1], '--', label="Anderson Accelerated CD (m=5, K=1)")
87 | ax.plot(aam_sols[:,i], aam_sols[:,2*i+1], '--', label="Anderson Accelerated CD (m=5, K=5)")
88 | ax.set_xlabel(f'$x_{{{2*i+1}}}$')
89 | ax.set_ylabel(f'$x_{{{2*i+2}}}$')
90 | if i == 0:
91 | ax.legend(loc='upper left', bbox_to_anchor=(0.75, 1.38),
92 | ncol=1, fancybox=True, shadow=True)
93 | ax.axis('equal')
94 |
95 | # Plot error as function of iteration num
96 | ax = fig.add_subplot(spec[2, :])
97 | iters = jnp.arange(len(aa_errors))
98 | ax.plot(iters, bcd_errors, '-o', label='Coordinate Descent Error')
99 | ax.plot(iters, aa_errors, '-o', label='Anderson Accelerated CD Error (m=5, K=1)')
100 | ax.plot(iters, aam_errors, '-o', label='Anderson Accelerated CD Error (m=5, K=5)')
101 | ax.set_xlabel('Iteration num')
102 | ax.set_ylabel('Error')
103 | ax.set_yscale('log')
104 | ax.legend()
105 | plt.show()
106 |
107 |
--------------------------------------------------------------------------------
/examples/fixed_point/plot_picard_ode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""
16 | Anderson acceleration in application to Picard–Lindelöf theorem.
17 | ================================================================
18 |
19 | Thanks to the `Picard–Lindelöf theorem
20 | <https://en.wikipedia.org/wiki/Picard%E2%80%93Lindel%C3%B6f_theorem>`_, we can
21 | reduce differential equation solving to fixed point computations and simple
22 | integration. More precisely consider the ODE:
23 |
24 | .. math::
25 |
26 | y'(t)=f(t,y(t))
27 |
28 | of some time-dependent dynamics
29 | :math:`f:\mathbb{R}\times\mathbb{R}^d\rightarrow\mathbb{R}^d` and initial
30 | conditions :math:`y(0)=y_0`. Then :math:`y` is the fixed point of the following
31 | map:
32 |
33 | .. math::
34 |
35 | y(t)=T(y)(t)\mathrel{\mathop:}=y_0+\int_0^t f(s,y(s))\mathrm{d}s
36 |
37 | Then we can define the sequence of functions :math:`(\phi_k)` with
38 | :math:`\phi_0=0` recursively as follows:
39 |
40 | .. math::
41 |
42 | \phi_{k+1}(t)=T(\phi_k)(t)\mathrel{\mathop:} =
43 | y_0+\int_0^t f(s,\phi_k(s))\mathrm{d}s
44 |
45 | This sequence converges to the solution of the ODE, i.e.,
46 | :math:`\lim_{k\rightarrow\infty}\phi_k=y`.
47 |
48 | In this example we choose :math:`f(t,y(t))=1+y(t)^2`. We know that the
49 | analytical solution is :math:`y(t)=\tan{t}` , which we use as a ground truth to
50 | evaluate our numerical scheme.
51 | We used ``scipy.integrate.cumulative_trapezoid`` to perform
52 | integration, but any other integration method can be used.
53 | """
54 |
55 |
56 | import jax
57 | import jax.numpy as jnp
58 |
59 | from jaxopt import AndersonAcceleration
60 |
61 |
62 | import numpy as np
63 | import matplotlib.pyplot as plt
64 | from matplotlib.pyplot import cm
65 | import scipy.integrate
66 |
67 | jax.config.update("jax_platform_name", "cpu")
68 |
69 |
70 | # Solve the differential equation y'(t) = 1 + y(t)^2, with solution y(t) = tan(t)
71 | def f(ti, phi):
72 | return 1 + phi ** 2
73 |
74 | def T(phi_cur, ti, y0, dx):
75 | """Fixed point iteration in the Picard method.
76 | See: https://en.wikipedia.org/wiki/Picard%E2%80%93Lindel%C3%B6f_theorem"""
77 | f_phi = f(ti, phi_cur)
78 | phi_next = scipy.integrate.cumulative_trapezoid(f_phi, initial=y0, dx=dx)
79 | return phi_next
80 |
81 | y0 = 0
82 | num_interpolating_points = 100
83 | t0 = jnp.array(0.)
84 | tmax = 0.9 * (jnp.pi / 2) # stop before pi/2 to ensure convergence
85 | dx = (tmax - t0) / (num_interpolating_points-1)
86 | phi0 = jnp.zeros(num_interpolating_points)
87 | ti = np.linspace(t0, tmax, num_interpolating_points)
88 |
89 | sols = [phi0]
90 | aa = AndersonAcceleration(T, history_size=5, maxiter=50, ridge=1e-5, jit=False)
91 | state = aa.init_state(phi0, ti, y0, dx)
92 | sol = phi0
93 | sols.append(sol)
94 | for k in range(aa.maxiter):
95 | sol, state = aa.update(sol, state, ti, y0, dx)
96 | sols.append(sol)
97 | res = sols[-1] - np.tan(ti)
98 | print(f'Error of {jnp.linalg.norm(res)} with ground truth tan(t)')
99 |
100 |
101 | # visualize only 8 of the iterates (indices 4 to 11) to make the figure easier to read
102 | sols = sols[4:12]
103 | fig = plt.figure(figsize=(8,4))
104 | ax = fig.add_subplot(1, 1, 1)
105 |
106 | colors = cm.plasma(np.linspace(0, 1, len(sols)))
107 | for k, (sol, c) in enumerate(zip(sols, colors)):
108 | desc = rf'$\phi_{k}$' if k > 0 else rf'$\phi_0=0$'
109 | ax.plot(ti, sol, '+', c=c, label=desc)
110 | ax.plot(ti, np.tan(ti), '-', c='green', label=r'$y(t)=\tan{(t)}$ (ground truth)')
111 |
112 | ax.legend()
113 | props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
114 | formula = rf'$\phi_{{k+1}}(t)=\phi_0+\int_0^{{{tmax/jnp.pi:.2f}\pi}} f(s,\phi_{{k}}(s))\mathrm{{d}}s$'
115 | ax.text(0.42, 0.85, formula, transform=ax.transAxes, fontsize=14, verticalalignment='top', bbox=props)
116 | fig.suptitle('Anderson acceleration for ODE solving')
117 | plt.show()
118 |
--------------------------------------------------------------------------------
/examples/implicit_diff/README.rst:
--------------------------------------------------------------------------------
1 | .. _implicit_diff_examples:
2 |
3 | Implicit differentiation
4 | ------------------------
5 |
6 |
--------------------------------------------------------------------------------
/examples/implicit_diff/lasso_implicit_diff.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Implicit differentiation of lasso.
17 | ==================================
18 | """
19 |
20 | from absl import app
21 | from absl import flags
22 |
23 | import jax.numpy as jnp
24 |
25 | from jaxopt import BlockCoordinateDescent
26 | from jaxopt import objective
27 | from jaxopt import OptaxSolver
28 | from jaxopt import prox
29 | from jaxopt import ProximalGradient
30 | import optax
31 |
32 | from sklearn import datasets
33 | from sklearn import model_selection
34 | from sklearn import preprocessing
35 |
36 | flags.DEFINE_bool("unrolling", False, "Whether to use unrolling.")
37 | flags.DEFINE_string("solver", "bcd", "Solver to use (bcd or pg).")
38 | FLAGS = flags.FLAGS
39 |
40 |
41 | def outer_objective(theta, init_inner, data):
42 | """Validation loss."""
43 | X_tr, X_val, y_tr, y_val = data
44 | # We use the bijective mapping lam = jnp.exp(theta) to ensure positivity.
45 | lam = jnp.exp(theta)
46 |
47 | if FLAGS.solver == "pg":
48 | solver = ProximalGradient(
49 | fun=objective.least_squares,
50 | prox=prox.prox_lasso,
51 | implicit_diff=not FLAGS.unrolling,
52 | maxiter=500)
53 | elif FLAGS.solver == "bcd":
54 | solver = BlockCoordinateDescent(
55 | fun=objective.least_squares,
56 | block_prox=prox.prox_lasso,
57 | implicit_diff=not FLAGS.unrolling,
58 | maxiter=500)
59 | else:
60 | raise ValueError("Unknown solver.")
61 |
62 | # The format is run(init_params, hyperparams_prox, *args, **kwargs)
63 | # where *args and **kwargs are passed to `fun`.
64 | w_fit = solver.run(init_inner, lam, (X_tr, y_tr)).params
65 |
66 | y_pred = jnp.dot(X_val, w_fit)
67 | loss_value = jnp.mean((y_pred - y_val) ** 2)
68 |
69 | # We return w_fit as auxiliary data.
70 | # Auxiliary data is stored in the optimizer state (see below).
71 | return loss_value, w_fit
72 |
73 |
74 | def main(argv):
75 | del argv
76 |
77 | print("Solver:", FLAGS.solver)
78 | print("Unrolling:", FLAGS.unrolling)
79 |
80 | # Prepare data.
81 | X, y = datasets.load_diabetes(return_X_y=True)
82 | X = preprocessing.normalize(X)
83 | # data = (X_tr, X_val, y_tr, y_val)
84 | data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0)
85 |
86 | # Initialize solver.
87 | solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True)
88 | theta = 1.0
89 | init_w = jnp.zeros(X.shape[1])
90 | state = solver.init_state(theta, init_inner=init_w, data=data)
91 |
92 | # Run outer loop.
93 | for _ in range(10):
94 | theta, state = solver.update(params=theta, state=state, init_inner=init_w,
95 | data=data)
96 | # The auxiliary data returned by the outer loss is stored in the state.
97 | init_w = state.aux
98 | print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.")
99 |
100 | if __name__ == "__main__":
101 | app.run(main)
102 |
--------------------------------------------------------------------------------
/examples/implicit_diff/ridge_reg_implicit_diff.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Implicit differentiation of ridge regression.
17 | =============================================
18 | """
19 |
20 | from absl import app
21 | import jax
22 | import jax.numpy as jnp
23 | from jaxopt import implicit_diff
24 | from jaxopt import linear_solve
25 | from jaxopt import OptaxSolver
26 | import optax
27 | from sklearn import datasets
28 | from sklearn import model_selection
29 | from sklearn import preprocessing
30 |
31 |
32 | def ridge_objective(params, l2reg, data):
33 | """Ridge objective function."""
34 | X_tr, y_tr = data
35 | residuals = jnp.dot(X_tr, params) - y_tr
36 | return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)
37 |
38 |
39 | @implicit_diff.custom_root(jax.grad(ridge_objective))
40 | def ridge_solver(init_params, l2reg, data):
41 | """Solve ridge regression by conjugate gradient."""
42 | X_tr, y_tr = data
43 |
44 | def matvec(u):
45 | return jnp.dot(X_tr.T, jnp.dot(X_tr, u))
46 |
47 | return linear_solve.solve_cg(matvec=matvec,
48 | b=jnp.dot(X_tr.T, y_tr),
49 | ridge=len(y_tr) * l2reg,
50 | init=init_params,
51 | maxiter=20)
52 |
53 |
54 | # Perhaps confusingly, theta is a parameter of the outer objective,
55 | # but l2reg = jnp.exp(theta) is an hyper-parameter of the inner objective.
56 | def outer_objective(theta, init_inner, data):
57 | """Validation loss."""
58 | X_tr, X_val, y_tr, y_val = data
59 | # We use the bijective mapping l2reg = jnp.exp(theta)
60 | # both to optimize in log-space and to ensure positivity.
61 | l2reg = jnp.exp(theta)
62 | w_fit = ridge_solver(init_inner, l2reg, (X_tr, y_tr))
63 | y_pred = jnp.dot(X_val, w_fit)
64 | loss_value = jnp.mean((y_pred - y_val) ** 2)
65 | # We return w_fit as auxiliary data.
66 | # Auxiliary data is stored in the optimizer state (see below).
67 | return loss_value, w_fit
68 |
69 |
70 | def main(argv):
71 | del argv
72 |
73 | # Prepare data.
74 | X, y = datasets.load_diabetes(return_X_y=True)
75 | X = preprocessing.normalize(X)
76 | # data = (X_tr, X_val, y_tr, y_val)
77 | data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0)
78 |
79 | # Initialize solver.
80 | solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True)
81 | theta = 1.0
82 | init_w = jnp.zeros(X.shape[1])
83 | state = solver.init_state(theta, init_inner=init_w, data=data)
84 |
85 | # Run outer loop.
86 | for _ in range(50):
87 | theta, state = solver.update(params=theta, state=state, init_inner=init_w,
88 | data=data)
89 | # The auxiliary data returned by the outer loss is stored in the state.
90 | init_w = state.aux
91 | print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.")
92 |
93 | if __name__ == "__main__":
94 | app.run(main)
95 |
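Because ``ridge_solver`` is decorated with ``@implicit_diff.custom_root``, the hypergradient of the validation loss with respect to ``theta`` is available directly through ``jax.grad``, which is what ``OptaxSolver`` relies on internally. A minimal sketch, assuming the functions defined in the example above are in scope (the tiny all-ones arrays below are placeholders, not real data):

import jax
import jax.numpy as jnp

X_tr, y_tr = jnp.ones((8, 3)), jnp.ones(8)
X_val, y_val = jnp.ones((4, 3)), jnp.ones(4)
data = (X_tr, X_val, y_tr, y_val)
init_w = jnp.zeros(3)

# Keep only the loss value; outer_objective also returns w_fit as auxiliary data.
loss_fn = lambda theta: outer_objective(theta, init_w, data)[0]
hypergrad = jax.grad(loss_fn)(1.0)  # d(validation loss) / d(theta), via implicit diff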
--------------------------------------------------------------------------------
/examples/requirements.txt:
--------------------------------------------------------------------------------
1 | dm-haiku>=0.0.4
2 | flax>=0.3.4
3 | optax>=0.0.9
4 | scikit-learn>=0.24.1
5 | tensorflow-datasets>=4.4.0
6 | tqdm>=4.62
7 |
--------------------------------------------------------------------------------
/jaxopt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | from jaxopt import implicit_diff
18 | from jaxopt import isotonic
19 | from jaxopt import loss
20 | from jaxopt import objective
21 | from jaxopt import projection
22 | from jaxopt import prox
23 |
24 | from jaxopt._src.anderson import AndersonAcceleration
25 | from jaxopt._src.anderson_wrapper import AndersonWrapper
26 | from jaxopt._src.armijo_sgd import ArmijoSGD
27 | from jaxopt._src.base import OptStep
28 | from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch
29 | from jaxopt._src.bfgs import BFGS
30 | from jaxopt._src.bisection import Bisection
31 | from jaxopt._src.block_cd import BlockCoordinateDescent
32 | from jaxopt._src.broyden import Broyden
33 | from jaxopt._src.cd_qp import BoxCDQP
34 | from jaxopt._src.cvxpy_wrapper import CvxpyQP
35 | from jaxopt._src.eq_qp import EqualityConstrainedQP
36 | from jaxopt._src.fixed_point_iteration import FixedPointIteration
37 | from jaxopt._src.gauss_newton import GaussNewton
38 | from jaxopt._src.gradient_descent import GradientDescent
39 | from jaxopt._src.hager_zhang_linesearch import HagerZhangLineSearch
40 | from jaxopt._src.iterative_refinement import IterativeRefinement
41 | from jaxopt._src.lbfgs import LBFGS
42 | from jaxopt._src.lbfgsb import LBFGSB
43 | from jaxopt._src.levenberg_marquardt import LevenbergMarquardt
44 | from jaxopt._src.mirror_descent import MirrorDescent
45 | from jaxopt._src.nonlinear_cg import NonlinearCG
46 | from jaxopt._src.optax_wrapper import OptaxSolver
47 | from jaxopt._src.osqp import BoxOSQP
48 | from jaxopt._src.osqp import OSQP
49 | from jaxopt._src.polyak_sgd import PolyakSGD
50 | from jaxopt._src.projected_gradient import ProjectedGradient
51 | from jaxopt._src.proximal_gradient import ProximalGradient
52 | from jaxopt._src.scipy_wrappers import ScipyBoundedLeastSquares
53 | from jaxopt._src.scipy_wrappers import ScipyBoundedMinimize
54 | from jaxopt._src.scipy_wrappers import ScipyLeastSquares
55 | from jaxopt._src.scipy_wrappers import ScipyMinimize
56 | from jaxopt._src.scipy_wrappers import ScipyRootFinding
57 | from jaxopt._src.zoom_linesearch import ZoomLineSearch
58 |
59 | warnings.warn(
60 | "JAXopt is no longer maintained. See https://docs.jax.dev/en/latest/ for"
61 | " alternatives.",
62 | DeprecationWarning,
63 | )
64 |
--------------------------------------------------------------------------------
/jaxopt/_src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/jaxopt/_src/cd_qp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implementation of coordinate descent for box-constrained QPs."""
16 |
17 | from typing import Callable
18 | from typing import NamedTuple
19 | from typing import Optional
20 | from typing import Union
21 |
22 | from dataclasses import dataclass
23 |
24 | import jax
25 | import jax.numpy as jnp
26 |
27 | from jaxopt._src import base
28 | from jaxopt._src import projection
29 |
30 |
31 | class BoxCDQPState(NamedTuple):
32 | """Named tuple containing state information."""
33 | iter_num: int
34 | error: float
35 |
36 |
37 | def fori_loop_body_fun(i, tup):
38 | x, Q, c, l, u, error = tup
39 | # i-th element of the gradient
40 | g_i = jnp.dot(Q[i], x) + c[i]
41 | # i-th diagonal element of the Hessian
42 | h_i = Q[i, i]
43 | # Newton-update and avoid division by zero
44 | update = jnp.where(h_i == 0, 0, g_i / h_i)
45 | # Newton-update + clipping to satisfy the box constraint
46 | x_i_new = jnp.clip(x[i] - update, l[i], u[i])
47 | delta_i = x_i_new - x[i]
48 | # Cumulated error
49 | error += jnp.abs(delta_i)
50 | x = x.at[i].set(x_i_new)
51 | return x, Q, c, l, u, error
52 |
53 |
54 | @dataclass(eq=False)
55 | class BoxCDQP(base.IterativeSolver):
56 | """Coordinate descent solver for box-constrained QPs.
57 |
58 | This solver minimizes::
59 |
60 |     0.5 <x, Qx> + <c, x> subject to l <= x <= u
61 |
62 | Attributes:
63 | maxiter: maximum number of coordinate descent iterations.
64 | tol: tolerance to use.
65 | verbose: whether to print information on every iteration or not.
66 |
67 | implicit_diff: whether to enable implicit diff or autodiff of unrolled
68 | iterations.
69 | implicit_diff_solve: the linear system solver to use.
70 |
71 | jit: whether to JIT-compile the optimization loop (default: True).
72 | unroll: whether to unroll the optimization loop (default: "auto").
73 | """
74 | maxiter: int = 500
75 | tol: float = 1e-4
76 | verbose: Union[bool, int] = False
77 | implicit_diff: bool = True
78 | implicit_diff_solve: Optional[Callable] = None
79 | jit: bool = True
80 | unroll: base.AutoOrBoolean = "auto"
81 |
82 | def init_state(self,
83 | init_params: jnp.ndarray,
84 | params_obj: Optional[base.ArrayPair] = None,
85 | params_ineq: Optional[base.ArrayPair] = None) -> BoxCDQPState:
86 | """Initialize the solver state.
87 |
88 | Args:
89 | init_params: array containing the initial parameters.
90 | params_obj: Tuple of arrays ``(Q, c)``.
91 | params_ineq: Tuple of arrays ``(l, u)``.
92 | Returns:
93 | state
94 | """
95 | del params_obj, params_ineq # Not used.
96 | return BoxCDQPState(iter_num=jnp.asarray(0),
97 | error=jnp.asarray(jnp.inf))
98 |
99 | def update(self,
100 | params: jnp.ndarray,
101 | state: NamedTuple,
102 | params_obj: base.ArrayPair,
103 | params_ineq: base.ArrayPair) -> base.OptStep:
104 | """Performs one epoch of coordinate descent.
105 |
106 | Args:
107 | params: array containing the parameters.
108 | state: named tuple containing the solver state.
109 | params_obj: Tuple of arrays ``(Q, c)``.
110 | params_ineq: Tuple of arrays ``(l, u)``.
111 | Returns:
112 | (params, state)
113 | """
114 | Q, c = params_obj
115 | l, u = params_ineq
116 |
117 |     init = (params, Q, c, l, u, 0.0)
118 |
119 | # todo: ability to permute coordinate order.
120 | params, _, _, _, _, error = jax.lax.fori_loop(lower=0,
121 | upper=params.shape[0],
122 | body_fun=fori_loop_body_fun,
123 | init_val=init)
124 |
125 | state = BoxCDQPState(iter_num=state.iter_num + 1, error=error)
126 |
127 | if self.verbose:
128 | self.log_info(state)
129 | return base.OptStep(params=params, state=state)
130 |
131 | def _fixed_point_fun(self,
132 | sol: jnp.ndarray,
133 | params_obj: base.ArrayPair,
134 | params_ineq: base.ArrayPair) -> jnp.ndarray:
135 | Q, c = params_obj
136 | l, u = params_ineq
137 | grad = jnp.dot(Q, sol) + c
138 | return projection.projection_box(sol - grad, (l, u))
139 |
140 | def optimality_fun(self,
141 | sol: jnp.ndarray,
142 | params_obj: base.ArrayPair,
143 | params_ineq: base.ArrayPair) -> jnp.ndarray:
144 | return self._fixed_point_fun(sol, params_obj, params_ineq) - sol
145 |
146 | def __post_init__(self):
147 | super().__post_init__()
148 |
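A short usage sketch of ``BoxCDQP`` on a tiny separable QP; the ``run`` method is assumed to come from the ``base.IterativeSolver`` parent class (not shown in this file):

import jax.numpy as jnp
from jaxopt import BoxCDQP

Q = jnp.array([[2.0, 0.0], [0.0, 4.0]])  # positive definite
c = jnp.array([-2.0, -8.0])
l = jnp.array([0.0, 0.0])
u = jnp.array([1.0, 1.0])

qp = BoxCDQP(maxiter=100, tol=1e-6)
sol, state = qp.run(jnp.zeros(2), params_obj=(Q, c), params_ineq=(l, u))
# The unconstrained minimizer is [1.0, 2.0]; the box clips the second
# coordinate, so sol should be approximately [1.0, 1.0].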
--------------------------------------------------------------------------------
/jaxopt/_src/cond.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Branching utilities."""
16 |
17 | import jax
18 |
19 | def cond(cond, if_fun, else_fun, *operands, jit=True):
20 | """Wrapper to avoid having the condition to be compiled if not wanted."""
21 | if not jit:
22 | with jax.disable_jit():
23 | return jax.lax.cond(cond, if_fun, else_fun, *operands)
24 | return jax.lax.cond(cond, if_fun, else_fun, *operands)
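A minimal illustration of the wrapper above, imported here via its public re-export ``jaxopt.cond``; with ``jit=False`` the branch executes eagerly under ``jax.disable_jit``, which makes Python-side debugging easier:

import jax.numpy as jnp
from jaxopt.cond import cond

x = jnp.asarray(3.0)
y = cond(x > 0, lambda v: 2.0 * v, lambda v: -v, x, jit=False)  # -> 6.0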
--------------------------------------------------------------------------------
/jaxopt/_src/fixed_point_iteration.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implementation of the fixed point iteration method in JAX."""
16 |
17 | from typing import Any
18 | from typing import Callable
19 | from typing import NamedTuple
20 | from typing import Optional
21 | from typing import Union
22 |
23 | from dataclasses import dataclass
24 |
25 | import jax.numpy as jnp
26 |
27 | from jaxopt._src import base
28 | from jaxopt._src.tree_util import tree_l2_norm, tree_sub
29 |
30 |
31 | class FixedPointState(NamedTuple):
32 | """Named tuple containing state information.
33 | Attributes:
34 | iter_num: iteration number
35 | error: residuals of current estimate
36 | aux: auxiliary output of fixed_point_fun when has_aux=True
37 | """
38 | iter_num: int
39 | error: float
40 | aux: Optional[Any] = None
41 | num_fun_eval: int = 0
42 |
43 |
44 | @dataclass(eq=False)
45 | class FixedPointIteration(base.IterativeSolver):
46 | """Fixed point iteration method.
47 | Attributes:
48 | fixed_point_fun: a function ``fixed_point_fun(x, *args, **kwargs)``
49 |       returning a pytree with the same structure and type as x.
50 | The function should fulfill the Banach fixed-point theorem's assumptions.
51 | Otherwise convergence is not guaranteed.
52 | maxiter: maximum number of iterations.
53 | tol: tolerance (stopping criterion)
54 |     has_aux: whether fixed_point_fun returns additional data. (default: False)
55 |       If True, the fixed point is computed only with respect to the first
56 |       element of the returned pair; the auxiliary output is carried along.
57 | verbose: whether to print information on every iteration or not.
58 |
59 | implicit_diff: whether to enable implicit diff or autodiff of unrolled
60 | iterations.
61 | implicit_diff_solve: the linear system solver to use.
62 |
63 | jit: whether to JIT-compile the optimization loop (default: True).
64 | unroll: whether to unroll the optimization loop (default: "auto")
65 | References:
66 | https://en.wikipedia.org/wiki/Fixed-point_iteration
67 | """
68 | fixed_point_fun: Callable
69 | maxiter: int = 100
70 | tol: float = 1e-5
71 | has_aux: bool = False
72 | verbose: Union[bool, int] = False
73 | implicit_diff: bool = True
74 | implicit_diff_solve: Optional[Callable] = None
75 | jit: bool = True
76 | unroll: base.AutoOrBoolean = "auto"
77 |
78 | def init_state(self,
79 | init_params,
80 | *args,
81 | **kwargs) -> FixedPointState:
82 | """Initialize the solver state.
83 |
84 | Args:
85 | init_params: initial guess of the fixed point, pytree
86 | *args: additional positional arguments to be passed to ``optimality_fun``.
87 | **kwargs: additional keyword arguments to be passed to ``optimality_fun``.
88 | Returns:
89 | state
90 | """
91 | return FixedPointState(iter_num=jnp.asarray(0),
92 | error=jnp.asarray(jnp.inf),
93 | aux=None,
94 | num_fun_eval=jnp.asarray(0, base.NUM_EVAL_DTYPE)
95 | )
96 |
97 | def update(self,
98 | params: Any,
99 | state: NamedTuple,
100 | *args,
101 | **kwargs) -> base.OptStep:
102 | """Performs one iteration of the fixed point iteration method.
103 | Args:
104 | params: pytree containing the parameters.
105 | state: named tuple containing the solver state.
106 | *args: additional positional arguments to be passed to
107 | ``fixed_point_fun``.
108 | **kwargs: additional keyword arguments to be passed to
109 | ``fixed_point_fun``.
110 | Returns:
111 | (params, state)
112 | """
113 | next_params, aux = self._fun(params, *args, **kwargs)
114 | error = tree_l2_norm(tree_sub(next_params, params))
115 | next_state = FixedPointState(iter_num=state.iter_num + 1,
116 | error=error,
117 | aux=aux,
118 | num_fun_eval=state.num_fun_eval + 1)
119 |
120 | if self.verbose:
121 | self.log_info(
122 | next_state,
123 | error_name="Distance btw Iterates"
124 | )
125 | return base.OptStep(params=next_params, state=next_state)
126 |
127 | def optimality_fun(self, params, *args, **kwargs):
128 | """Optimality function mapping compatible with ``@custom_root``."""
129 | new_params, _ = self._fun(params, *args, **kwargs)
130 | return tree_sub(new_params, params)
131 |
132 | def __post_init__(self):
133 | super().__post_init__()
134 |
135 | if self.has_aux:
136 | self._fun = self.fixed_point_fun
137 | else:
138 | self._fun = lambda *a, **kw: (self.fixed_point_fun(*a, **kw), None)
139 |
140 | self.reference_signature = self.fixed_point_fun
141 |
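A quick sketch of ``FixedPointIteration``, assuming the ``run`` method provided by ``base.IterativeSolver``: iterating ``x <- cos(x)`` converges to the Dottie number (about 0.739):

import jax.numpy as jnp
from jaxopt import FixedPointIteration

fpi = FixedPointIteration(fixed_point_fun=jnp.cos, maxiter=200, tol=1e-6)
x_star, state = fpi.run(jnp.asarray(1.0))
# x_star is approximately 0.739; state.error is ||cos(x_star) - x_star||.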
--------------------------------------------------------------------------------
/jaxopt/_src/gradient_descent.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implementation of gradient descent in JAX."""
16 |
17 | from typing import Any
18 | from typing import NamedTuple
19 |
20 | from dataclasses import dataclass
21 |
22 | from jaxopt._src import base
23 | from jaxopt._src.proximal_gradient import ProximalGradient, ProxGradState
24 |
25 |
26 | @dataclass(eq=False)
27 | class GradientDescent(ProximalGradient):
28 | """Gradient Descent solver.
29 |
30 | Attributes:
31 | fun: a smooth function of the form ``fun(parameters, *args, **kwargs)``,
32 | where ``parameters`` are the model parameters w.r.t. which we minimize
33 | the function and the rest are fixed auxiliary parameters.
34 | value_and_grad: whether ``fun`` just returns the value (False) or both
35 | the value and gradient (True).
36 | has_aux: whether ``fun`` outputs auxiliary data or not.
37 | If ``has_aux`` is False, ``fun`` is expected to be
38 | scalar-valued.
39 | If ``has_aux`` is True, then we have one of the following
40 | two cases.
41 | If ``value_and_grad`` is False, the output should be
42 | ``value, aux = fun(...)``.
43 | If ``value_and_grad == True``, the output should be
44 | ``(value, aux), grad = fun(...)``.
45 | At each iteration of the algorithm, the auxiliary outputs are stored
46 | in ``state.aux``.
47 |
48 | stepsize: a stepsize to use (if <= 0, use backtracking line search), or a
49 | callable specifying the **positive** stepsize to use at each iteration.
50 | maxiter: maximum number of proximal gradient descent iterations.
51 | maxls: maximum number of iterations to use in the line search.
52 | tol: tolerance to use.
53 |
54 | acceleration: whether to use acceleration (also known as FISTA) or not.
55 | verbose: whether to print information on every iteration or not.
56 |
57 | implicit_diff: whether to enable implicit diff or autodiff of unrolled
58 | iterations.
59 | implicit_diff_solve: the linear system solver to use.
60 |
61 | jit: whether to JIT-compile the optimization loop (default: True).
62 | unroll: whether to unroll the optimization loop (default: "auto").
63 | """
64 |
65 | def init_state(self,
66 | init_params: Any,
67 | *args,
68 | **kwargs) -> ProxGradState:
69 | """Initialize the solver state.
70 |
71 | Args:
72 | init_params: pytree containing the initial parameters.
73 | *args: additional positional arguments to be passed to ``fun``.
74 | **kwargs: additional keyword arguments to be passed to ``fun``.
75 | Returns:
76 | state
77 | """
78 | return super().init_state(init_params, None, *args, **kwargs)
79 |
80 | def update(
81 | self, params: Any, state: ProxGradState, *args, **kwargs
82 | ) -> base.OptStep:
83 | """Performs one iteration of gradient descent.
84 |
85 | Args:
86 | params: pytree containing the parameters.
87 | state: named tuple containing the solver state.
88 | *args: additional positional arguments to be passed to ``fun``.
89 | **kwargs: additional keyword arguments to be passed to ``fun``.
90 | Returns:
91 | (params, state)
92 | """
93 | return super().update(params, state, None, *args, **kwargs)
94 |
95 | def optimality_fun(self, params, *args, **kwargs):
96 | """Optimality function mapping compatible with ``@custom_root``."""
97 | return self._grad_fun(params, *args, **kwargs)
98 |
99 | def __post_init__(self):
100 | super().__post_init__()
101 | self.reference_signature = self.fun
102 |
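A brief usage sketch, again relying on ``run`` from the solver base class: minimize the quadratic ``f(w) = 0.5 * ||w - b||^2``, whose minimizer is ``b``:

import jax.numpy as jnp
from jaxopt import GradientDescent

b = jnp.array([1.0, -2.0, 3.0])
fun = lambda w: 0.5 * jnp.sum((w - b) ** 2)

gd = GradientDescent(fun=fun, maxiter=100, tol=1e-6)
w_star, state = gd.run(jnp.zeros(3))
# w_star should be close to b; state.error tracks the optimality error.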
--------------------------------------------------------------------------------
/jaxopt/_src/isotonic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Isotonic Regression."""
16 |
17 | import warnings
18 | import numpy as onp
19 | import jax
20 | import jax.numpy as jnp
21 |
22 |
23 | # pylint: disable=g-import-not-at-top
24 | try:
25 | from numba import njit
26 |
27 | NUMBA_AVAILABLE = True
28 | except ImportError:
29 | NUMBA_AVAILABLE = False
30 | # If Numba is not available, we define a dummy 'njit' function.
31 |
32 | def njit(func):
33 | return func
34 |
35 |
36 | @njit
37 | def _isotonic_l2_pav_numba(y):
38 | n = y.shape[0]
39 | target = onp.arange(n)
40 | c = onp.ones(n)
41 | sums = onp.zeros(n)
42 | sol = onp.zeros(n)
43 |
44 | # target describes a list of blocks. At any time, if [i..j] (inclusive) is
45 | # an active block, then target[i] := j and target[j] := i.
46 |
47 | for i in range(n):
48 | sol[i] = y[i]
49 | sums[i] = y[i]
50 |
51 | i = 0
52 | while i < n:
53 | k = target[i] + 1
54 | if k == n:
55 | break
56 | if sol[i] > sol[k]:
57 | i = k
58 | continue
59 | sum_y = sums[i]
60 | sum_c = c[i]
61 | while True:
62 | # We are within an increasing subsequence.
63 | prev_y = sol[k]
64 | sum_y += sums[k]
65 | sum_c += c[k]
66 | k = target[k] + 1
67 | if k == n or prev_y > sol[k]:
68 | # Non-singleton increasing subsequence is finished,
69 | # update first entry.
70 | sol[i] = sum_y / sum_c
71 | sums[i] = sum_y
72 | c[i] = sum_c
73 | target[i] = k - 1
74 | target[k - 1] = i
75 | if i > 0:
76 | # Backtrack if we can. This makes the algorithm
77 | # single-pass and ensures O(n) complexity.
78 | i = target[i - 1]
79 | # Otherwise, restart from the same point.
80 | break
81 |
82 | # Reconstruct the solution.
83 | i = 0
84 | while i < n:
85 | k = target[i] + 1
86 | sol[i + 1 : k] = sol[i]
87 | i = k
88 | return sol.astype(y.dtype)
89 |
90 |
91 | @jax.custom_jvp
92 | def _isotonic_l2_pav(y):
93 | if not NUMBA_AVAILABLE:
94 | warnings.warn(
95 | "Numba could not be imported. Code will run much more slowly."
96 | " To install, run 'pip install numba'."
97 | )
98 | # Define the expected shape & dtype of output.
99 | shape_dtype = jax.ShapeDtypeStruct(shape=y.shape, dtype=y.dtype)
100 | sol = jax.pure_callback(
101 | _isotonic_l2_pav_numba, shape_dtype, y, vmap_method="sequential"
102 | )
103 | return sol
104 |
105 |
106 | def isotonic_l2_pav(y, y_min=-jnp.inf, y_max=jnp.inf, increasing=True):
107 | r"""Solves an isotonic regression problem using PAV.
108 |
109 | Args:
110 | y: input to isotonic regression, a 1d-array.
111 |
112 | y_min : Lower bound on the lowest predicted value.
113 |     y_max : Upper bound on the highest predicted value.
114 |
115 | increasing : Order of the constraints:
116 | If True, it solves :math:`\mathop{\mathrm{arg\,min}}_{v_1 \leq ... \leq v_n} \|v - y\|^2`.
117 | If False, it solves :math:`\mathop{\mathrm{arg\,min}}_{v_1 \geq ... \geq v_n} \|v - y\|^2`.
118 |
119 | Returns:
120 | The solution, an array of the same size as y.
121 | """
122 | sign = -1 if increasing else 1
123 | sol = _isotonic_l2_pav(y * sign) * sign
124 | sol = jnp.clip(sol, y_min, y_max)
125 | return sol
126 |
127 |
128 | def _jvp_isotonic_l2_jax_pav(solution, vector, eps=1e-8):
129 | x = solution
130 | mask = jnp.pad(jnp.absolute(jnp.diff(x)) <= eps, (1, 0))
131 | ar = jnp.arange(x.size)
132 | inds_start = jnp.where(mask == 0, ar, +jnp.inf).sort()
133 | one_hot_start = jax.nn.one_hot(inds_start, len(vector))
134 | A = jnp.cumsum(one_hot_start, axis=-1)
135 | A = jnp.append(jnp.diff(A[::-1], axis=0)[::-1], A[-1].reshape(1, -1), axis=0)
136 | B = A.copy()
137 | return (((B.T * (B @ vector)).T) / (A.sum(1, keepdims=True) + 1e-8)).sum(0)
138 |
139 |
140 | @_isotonic_l2_pav.defjvp
141 | def _isotonic_l2_pav_jvp(primals, tangents):
142 | """Jacobian-vector product of isotonic_l2_pav.
143 |
144 | See Section 5 of
145 | Fast Differentiable Sorting and Ranking
146 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
147 | ICML 2020 arXiv:2002.08871
148 | """
149 | (y, ) = primals
150 | (vector, ) = tangents
151 | primal_out = _isotonic_l2_pav(y)
152 | tangent_out = _jvp_isotonic_l2_jax_pav(primal_out, vector)
153 | return primal_out, tangent_out
154 |
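A small sketch of ``isotonic_l2_pav``: project a sequence onto the set of non-decreasing sequences (without Numba the pure-Python fallback above is used, just more slowly):

import jax.numpy as jnp
from jaxopt.isotonic import isotonic_l2_pav

y = jnp.array([4.0, 2.0, 3.0, 10.0])
sol = isotonic_l2_pav(y)
# Expected: [3.0, 3.0, 3.0, 10.0] -- the first two entries are pooled to
# their mean (3.0), giving a non-decreasing result.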
--------------------------------------------------------------------------------
/jaxopt/_src/linear_operator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Interface for linear operators."""
15 |
16 | import functools
17 | import jax
18 | import jax.numpy as jnp
19 |
20 | from jaxopt.tree_util import tree_map
21 |
22 |
23 | class DenseLinearOperator:
24 | """General operator for dense matrices.
25 |
26 | Attributes:
27 | pytree: pytree of dense matrices.
28 |
29 | Each leaf of ``pytree`` must be a 2D matrix.
30 | """
31 |
32 | def __init__(self, pytree):
33 | self.pytree = pytree
34 |
35 | def __call__(self, x):
36 | return self.matvec(x)
37 |
38 | def matvec(self, x):
39 | return tree_map(jnp.dot, self.pytree, x)
40 |
41 | def rmatvec(self, _, y):
42 | return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y)
43 |
44 | def matvec_and_rmatvec(self, x, y):
45 | return self.matvec(x), self.rmatvec(x, y)
46 |
47 | def normal_matvec(self, x):
48 | """Computes A^T A x."""
49 | return self.rmatvec(x, self.matvec(x))
50 |
51 | def diag(self):
52 | diags_only = tree_map(jnp.diag, self.pytree)
53 | return diags_only
54 |
55 | def columns_l2_norms(self, squared=False):
56 | def col_norm(w):
57 | col_norms = jnp.sum(jnp.square(w), axis=0)
58 | if not squared:
59 | col_norms = jnp.sqrt(col_norms)
60 | return col_norms
61 | return tree_map(col_norm, self.pytree)
62 |
63 |
64 | class FunctionalLinearOperator:
65 |
66 | def __init__(self, fun, params):
67 | self.fun = functools.partial(fun, params)
68 |
69 | def __call__(self, x):
70 | return self.matvec(x)
71 |
72 | def matvec(self, x):
73 | return self.fun(x)
74 |
75 | def rmatvec(self, x, y):
76 | return self.matvec_and_rmatvec(x, y)[1]
77 |
78 | def matvec_and_rmatvec(self, x, y):
79 | matvec_x, vjp = jax.vjp(self.matvec, x)
80 | rmatvec_y, = vjp(y)
81 | return matvec_x, rmatvec_y
82 |
83 | def normal_matvec(self, x):
84 | """Computes A^T A x from matvec(x) = A x."""
85 | matvec_x, vjp = jax.vjp(self.matvec, x)
86 | return vjp(matvec_x)[0]
87 |
88 |
89 | def _make_linear_operator(matvec):
90 | if matvec is None:
91 | return DenseLinearOperator
92 | else:
93 | return functools.partial(FunctionalLinearOperator, matvec)
94 |
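A short sketch of ``DenseLinearOperator`` (an internal helper, hence the ``jaxopt._src`` import) with a single matrix as the pytree: ``matvec`` computes ``A x``, ``rmatvec`` computes ``A^T y``, and ``normal_matvec`` computes ``A^T A x``:

import jax.numpy as jnp
from jaxopt._src.linear_operator import DenseLinearOperator

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
op = DenseLinearOperator(A)
x = jnp.array([1.0, -1.0])
y = jnp.ones(3)

Ax = op(x)                  # same as A @ x, shape (3,)
ATy = op.rmatvec(x, y)      # same as A.T @ y, shape (2,)
ATAx = op.normal_matvec(x)  # same as A.T @ (A @ x)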
--------------------------------------------------------------------------------
/jaxopt/_src/linesearch_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Line searches utilities."""
16 |
17 | from jax import numpy as jnp
18 | from jaxopt._src import base
19 | from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch
20 | from jaxopt._src.hager_zhang_linesearch import HagerZhangLineSearch
21 | from jaxopt._src.zoom_linesearch import ZoomLineSearch
22 |
23 |
24 | def _setup_linesearch(
25 | linesearch,
26 | fun,
27 | value_and_grad,
28 | has_aux,
29 | maxlsiter,
30 | max_stepsize,
31 | jit,
32 | unroll,
33 | verbose,
34 | ):
35 | """Instantiate linesearch."""
36 |
37 | available_linesearches = ["backtracking", "zoom", "hager-zhang"]
38 | if linesearch == "backtracking":
39 | linesearch_solver = BacktrackingLineSearch(
40 | fun=fun,
41 | value_and_grad=value_and_grad,
42 | has_aux=has_aux,
43 | maxiter=maxlsiter,
44 | max_stepsize=max_stepsize,
45 | jit=jit,
46 | unroll=unroll,
47 | verbose=verbose,
48 | )
49 | elif linesearch == "zoom":
50 | linesearch_solver = ZoomLineSearch(
51 | fun=fun,
52 | value_and_grad=value_and_grad,
53 | has_aux=has_aux,
54 | maxiter=maxlsiter,
55 | max_stepsize=max_stepsize,
56 | jit=jit,
57 | unroll=unroll,
58 | verbose=verbose,
59 | )
60 | elif linesearch == "hager-zhang":
61 | # NOTE(vroulet): max_stepsize has no effect in HZ
62 | linesearch_solver = HagerZhangLineSearch(
63 | fun=fun,
64 | value_and_grad=value_and_grad,
65 | has_aux=has_aux,
66 | maxiter=maxlsiter,
67 | jit=jit,
68 | unroll=unroll,
69 | verbose=verbose,
70 | )
71 | elif isinstance(linesearch, base.IterativeLineSearch):
72 | linesearch_solver = linesearch
73 | else:
74 | raise ValueError(
75 | f"Linesearch {linesearch} not available/tested. "
76 | f"Available linesearches: {available_linesearches}"
77 | )
78 | return linesearch_solver
79 |
80 |
81 | def _init_stepsize(
82 | strategy, max_stepsize, min_stepsize, increase_factor, stepsize
83 | ):
84 | """Set stepsize at the start of the linesearch from previous guess."""
85 | available_strategies = ["max", "current", "increase"]
86 | if strategy == "max":
87 | init_stepsize = max_stepsize
88 | elif strategy == "current":
89 | init_stepsize = stepsize
90 | elif strategy == "increase":
91 | init_stepsize = jnp.where(
92 | stepsize <= min_stepsize,
93 | # If stepsize became too small, we restart it.
94 | max_stepsize,
95 | # Else, we increase a bit the previous one.
96 | stepsize * increase_factor,
97 | )
98 | else:
99 | raise ValueError(
100 | f"Strategy {strategy} not available/tested. "
101 | f"Available linesearches: {available_strategies}"
102 | )
103 | return init_stepsize
104 |
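A minimal illustration of the ``"increase"`` strategy implemented above (both helpers are private, so this is only a sketch of the internal behaviour):

from jaxopt._src.linesearch_util import _init_stepsize

# Previous stepsize still healthy -> grow it by increase_factor.
s1 = _init_stepsize("increase", max_stepsize=1.0, min_stepsize=1e-6,
                    increase_factor=1.5, stepsize=0.2)   # -> 0.3
# Previous stepsize collapsed below min_stepsize -> restart from max_stepsize.
s2 = _init_stepsize("increase", max_stepsize=1.0, min_stepsize=1e-6,
                    increase_factor=1.5, stepsize=1e-7)  # -> 1.0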
--------------------------------------------------------------------------------
/jaxopt/_src/loop.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Loop utilities."""
16 |
17 | import jax
18 | import jax.numpy as jnp
19 |
20 |
21 | def _while_loop_scan(cond_fun, body_fun, init_val, max_iter):
22 | """Scan-based implementation (jit ok, reverse-mode autodiff ok)."""
23 | def _iter(val):
24 | next_val = body_fun(val)
25 | next_cond = cond_fun(next_val)
26 | return next_val, next_cond
27 |
28 | def _fun(tup, it):
29 | val, cond = tup
30 | # When cond is met, we start doing no-ops.
31 | return jax.lax.cond(cond, _iter, lambda x: (x, False), val), it
32 |
33 | init = (init_val, cond_fun(init_val))
34 | return jax.lax.scan(_fun, init, None, length=max_iter)[0][0]
35 |
36 |
37 | def _while_loop_python(cond_fun, body_fun, init_val, maxiter):
38 | """Python based implementation (no jit, reverse-mode autodiff ok)."""
39 | val = init_val
40 | for _ in range(maxiter):
41 | cond = cond_fun(val)
42 | if not cond:
43 | # When condition is met, break (not jittable).
44 | break
45 | val = body_fun(val)
46 | return val
47 |
48 |
49 | def _while_loop_lax(cond_fun, body_fun, init_val, maxiter):
50 | """lax.while_loop based implementation (jit by default, no reverse-mode)."""
51 | def _cond_fun(_val):
52 | it, val = _val
53 | return jnp.logical_and(cond_fun(val), it <= maxiter - 1)
54 |
55 | def _body_fun(_val):
56 | it, val = _val
57 | val = body_fun(val)
58 | return it+1, val
59 |
60 | return jax.lax.while_loop(_cond_fun, _body_fun, (0, init_val))[1]
61 |
62 |
63 | def while_loop(cond_fun, body_fun, init_val, maxiter, unroll=False, jit=False):
64 | """A while loop with a bounded number of iterations."""
65 |
66 | if unroll:
67 | if jit:
68 | fun = _while_loop_scan
69 | else:
70 | fun = _while_loop_python
71 | else:
72 | if jit:
73 | fun = _while_loop_lax
74 | else:
75 | raise ValueError("unroll=False and jit=False cannot be used together")
76 |
77 | if jit and fun is not _while_loop_lax:
78 | # jit of a lax while_loop is redundant, and this jit would only
79 | # constrain maxiter to be static where it is not required.
80 | fun = jax.jit(fun, static_argnums=(0, 1, 3))
81 |
82 | return fun(cond_fun, body_fun, init_val, maxiter)
83 |
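A tiny sketch of the bounded ``while_loop`` above, imported via its public re-export ``jaxopt.loop``: repeatedly double ``x`` until it exceeds 100, capped at ``maxiter`` iterations; with ``jit=True`` and ``unroll=False`` this dispatches to ``lax.while_loop``:

import jax.numpy as jnp
from jaxopt.loop import while_loop

out = while_loop(cond_fun=lambda x: x < 100,
                 body_fun=lambda x: 2.0 * x,
                 init_val=jnp.asarray(1.0),
                 maxiter=30,
                 jit=True)
# out == 128.0, the first doubling of 1.0 that is not below 100.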
--------------------------------------------------------------------------------
/jaxopt/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.base import AutoOrBoolean
16 | from jaxopt._src.base import IterativeSolver
17 | from jaxopt._src.base import LinearOperator
18 | from jaxopt._src.base import OptStep
19 | from jaxopt._src.base import StochasticSolver
20 | from jaxopt._src.base import KKTSolution
21 |
--------------------------------------------------------------------------------
/jaxopt/cond.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.cond import cond
16 |
--------------------------------------------------------------------------------
/jaxopt/implicit_diff.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.implicit_diff import custom_root
16 | from jaxopt._src.implicit_diff import custom_fixed_point
17 | from jaxopt._src.implicit_diff import root_jvp
18 | from jaxopt._src.implicit_diff import root_vjp
19 |
--------------------------------------------------------------------------------
/jaxopt/isotonic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.isotonic import isotonic_l2_pav
--------------------------------------------------------------------------------
/jaxopt/linear_solve.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.linear_solve import solve_lu
16 | from jaxopt._src.linear_solve import solve_cholesky
17 | from jaxopt._src.linear_solve import solve_qr
18 | from jaxopt._src.linear_solve import solve_inv
19 | from jaxopt._src.linear_solve import solve_cg
20 | from jaxopt._src.linear_solve import solve_normal_cg
21 | from jaxopt._src.linear_solve import solve_gmres
22 | from jaxopt._src.linear_solve import solve_bicgstab
23 | from jaxopt._src.iterative_refinement import solve_iterative_refinement
24 |
--------------------------------------------------------------------------------
/jaxopt/loop.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.loop import while_loop
16 |
--------------------------------------------------------------------------------
/jaxopt/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.loss import binary_logistic_loss
16 | from jaxopt._src.loss import binary_sparsemax_loss, sparse_plus, sparse_sigmoid
17 | from jaxopt._src.loss import huber_loss
18 | from jaxopt._src.loss import make_fenchel_young_loss
19 | from jaxopt._src.loss import multiclass_logistic_loss
20 | from jaxopt._src.loss import multiclass_sparsemax_loss
21 | from jaxopt._src.loss import binary_hinge_loss
22 | from jaxopt._src.loss import binary_perceptron_loss
23 | from jaxopt._src.loss import multiclass_hinge_loss
24 | from jaxopt._src.loss import multiclass_perceptron_loss
25 |
--------------------------------------------------------------------------------
/jaxopt/objective.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.objective import CompositeLinearFunction
16 |
17 | from jaxopt._src.objective import least_squares
18 | from jaxopt._src.objective import ridge_regression
19 |
20 | from jaxopt._src.objective import binary_logreg
21 |
22 | from jaxopt._src.objective import multiclass_logreg
23 | from jaxopt._src.objective import multiclass_logreg_with_intercept
24 | from jaxopt._src.objective import l2_multiclass_logreg
25 | from jaxopt._src.objective import l2_multiclass_logreg_with_intercept
26 |
27 | from jaxopt._src.objective import multiclass_linear_svm_dual
28 |
--------------------------------------------------------------------------------
/jaxopt/perturbations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.perturbations import Gumbel
16 | from jaxopt._src.perturbations import Normal
17 | from jaxopt._src.perturbations import make_perturbed_argmax
18 | from jaxopt._src.perturbations import make_perturbed_max
19 | from jaxopt._src.perturbations import make_perturbed_fun
20 |
--------------------------------------------------------------------------------
/jaxopt/projection.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.projection import projection_non_negative
16 | from jaxopt._src.projection import projection_box
17 | from jaxopt._src.projection import projection_hypercube
18 | from jaxopt._src.projection import projection_simplex
19 | from jaxopt._src.projection import projection_sparse_simplex
20 | from jaxopt._src.projection import projection_l1_sphere
21 | from jaxopt._src.projection import projection_l1_ball
22 | from jaxopt._src.projection import projection_l2_sphere
23 | from jaxopt._src.projection import projection_l2_ball
24 | from jaxopt._src.projection import projection_linf_ball
25 | from jaxopt._src.projection import projection_hyperplane
26 | from jaxopt._src.projection import projection_halfspace
27 | from jaxopt._src.projection import projection_affine_set
28 | from jaxopt._src.projection import projection_polyhedron
29 | from jaxopt._src.projection import projection_box_section
30 | from jaxopt._src.projection import projection_transport
31 | from jaxopt._src.projection import projection_birkhoff
32 | from jaxopt._src.projection import kl_projection_transport
33 | from jaxopt._src.projection import kl_projection_birkhoff
34 |
--------------------------------------------------------------------------------
/jaxopt/prox.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.prox import make_prox_from_projection
16 | from jaxopt._src.prox import prox_none
17 | from jaxopt._src.prox import prox_lasso
18 | from jaxopt._src.prox import prox_non_negative_lasso
19 | from jaxopt._src.prox import prox_elastic_net
20 | from jaxopt._src.prox import prox_group_lasso
21 | from jaxopt._src.prox import prox_ridge
22 | from jaxopt._src.prox import prox_non_negative_ridge
23 |
--------------------------------------------------------------------------------
/jaxopt/tree_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jaxopt._src.tree_util import broadcast_pytrees
16 | from jaxopt._src.tree_util import tree_map
17 | from jaxopt._src.tree_util import tree_reduce
18 | from jaxopt._src.tree_util import tree_add
19 | from jaxopt._src.tree_util import tree_sub
20 | from jaxopt._src.tree_util import tree_mul
21 | from jaxopt._src.tree_util import tree_scalar_mul
22 | from jaxopt._src.tree_util import tree_add_scalar_mul
23 | from jaxopt._src.tree_util import tree_dot
24 | from jaxopt._src.tree_util import tree_vdot
25 | from jaxopt._src.tree_util import tree_vdot_real
26 | from jaxopt._src.tree_util import tree_div
27 | from jaxopt._src.tree_util import tree_sum
28 | from jaxopt._src.tree_util import tree_l2_norm
29 | from jaxopt._src.tree_util import tree_where
30 | from jaxopt._src.tree_util import tree_zeros_like
31 | from jaxopt._src.tree_util import tree_ones_like
32 | from jaxopt._src.tree_util import tree_negative
33 | from jaxopt._src.tree_util import tree_inf_norm
34 | from jaxopt._src.tree_util import tree_conj
35 | from jaxopt._src.tree_util import tree_real
36 | from jaxopt._src.tree_util import tree_imag
--------------------------------------------------------------------------------
/jaxopt/version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """JAXopt version."""
16 |
17 | __version__ = "0.8.5"
18 |
--------------------------------------------------------------------------------
/pylintrc:
--------------------------------------------------------------------------------
1 | # This pylintrc file is taken from JAX
2 | # https://github.com/google/jax/blob/main/pylintrc
3 | [MASTER]
4 |
5 | # A comma-separated list of package or module names from where C extensions may
6 | # be loaded. Extensions are loading into the active Python interpreter and may
7 | # run arbitrary code
8 | extension-pkg-whitelist=numpy
9 |
10 |
11 | [MESSAGES CONTROL]
12 |
13 | # Disable the message, report, category or checker with the given id(s). You
14 | # can either give multiple identifiers separated by comma (,) or put this
15 | # option multiple times (only on the command line, not in the configuration
16 | # file where it should appear only once).You can also use "--disable=all" to
17 | # disable everything first and then reenable specific checks. For example, if
18 | # you want to run only the similarities checker, you can use "--disable=all
19 | # --enable=similarities". If you want to run only the classes checker, but have
20 | # no Warning level messages displayed, use"--disable=all --enable=classes
21 | # --disable=W"
22 | disable=missing-docstring,
23 | too-many-locals,
24 | invalid-name,
25 | redefined-outer-name,
26 | redefined-builtin,
27 | protected-name,
28 | no-else-return,
29 | fixme,
30 | protected-access,
31 | too-many-arguments,
32 | blacklisted-name,
33 | too-few-public-methods,
34 | unnecessary-lambda,
35 |
36 |
37 | # Enable the message, report, category or checker with the given id(s). You can
38 | # either give multiple identifier separated by comma (,) or put this option
39 | # multiple time (only on the command line, not in the configuration file where
40 | # it should appear only once). See also the "--disable" option for examples.
41 | enable=c-extension-no-member
42 |
43 |
44 | [FORMAT]
45 |
46 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
47 | # tab).
48 | indent-string=" "
49 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax>=0.2.18
2 | jaxlib>=0.1.69
3 | numpy>=1.18.4
4 | scipy>=1.0.0
5 |
--------------------------------------------------------------------------------
/requirements_test.txt:
--------------------------------------------------------------------------------
1 | absl-py>=0.7.0
2 | cvxpy>=1.1.11
3 | optax>=0.0.9
4 | pytest-xdist
5 | scikit-learn>=0.24.1
6 | cvxopt
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Setup script for JAXopt."""
16 |
17 | import os
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 |
22 | folder = os.path.dirname(__file__)
23 | version_path = os.path.join(folder, "jaxopt", "version.py")
24 |
25 | __version__ = None
26 | with open(version_path) as f:
27 | exec(f.read(), globals())
28 |
29 | req_path = os.path.join(folder, "requirements.txt")
30 | install_requires = []
31 | if os.path.exists(req_path):
32 | with open(req_path) as fp:
33 | install_requires = [line.strip() for line in fp]
34 |
35 | readme_path = os.path.join(folder, "README.md")
36 | readme_contents = ""
37 | if os.path.exists(readme_path):
38 | with open(readme_path) as fp:
39 | readme_contents = fp.read().strip()
40 |
41 | setup(
42 | name="jaxopt",
43 | version=__version__,
44 | description="Hardware accelerated, batchable and differentiable optimizers in JAX.",
45 | author="Google LLC",
46 | author_email="no-reply@google.com",
47 | url="https://github.com/google/jaxopt",
48 | long_description=readme_contents,
49 | long_description_content_type="text/markdown",
50 | license="Apache 2.0",
51 | packages=find_packages(),
52 | package_data={},
53 | install_requires=install_requires,
54 | classifiers=[
55 | "Intended Audience :: Science/Research",
56 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
57 | "License :: OSI Approved :: Apache Software License",
58 | "Programming Language :: Python :: 3",
59 | "Programming Language :: Python :: 3.10",
60 | "Programming Language :: Python :: 3.11",
61 | "Programming Language :: Python :: 3.12",
62 | "Programming Language :: Python :: 3.13",
63 | ],
64 | keywords="optimization, root finding, implicit differentiation, jax",
65 |     python_requires=">=3.10",
66 | )
67 |
--------------------------------------------------------------------------------
/tests/anderson_wrapper_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from absl.testing import absltest
17 |
18 | import jax.numpy as jnp
19 | from jax.test_util import check_grads
20 | import optax
21 |
22 | from jaxopt import objective
23 |
24 | from jaxopt import prox
25 | from jaxopt._src import test_util
26 |
27 | from jaxopt import AndersonWrapper
28 | from jaxopt import BlockCoordinateDescent
29 | from jaxopt import OptaxSolver
30 | from jaxopt import PolyakSGD
31 | from jaxopt import ProximalGradient
32 |
33 | from sklearn import datasets
34 |
35 |
36 | class AndersonWrapperTest(test_util.JaxoptTestCase):
37 |
38 | def test_proximal_gradient_wrapper(self):
39 | """Baseline test on simple optimizer."""
40 | X, y = datasets.make_regression(n_samples=100, n_features=20, random_state=0)
41 | fun = objective.least_squares
42 | lam = 10.0
43 | data = (X, y)
44 | w_init = jnp.zeros(X.shape[1])
45 | tol = 1e-3
46 | maxiter = 1000
47 | pg = ProximalGradient(fun=fun, prox=prox.prox_lasso, maxiter=maxiter, tol=tol,
48 | acceleration=False)
49 | aw = AndersonWrapper(pg, history_size=15)
50 | aw_params, awpg_info = aw.run(w_init, hyperparams_prox=lam, data=data)
51 | self.assertLess(awpg_info.error, tol)
52 |
53 | def test_mixing_frequency_polyak(self):
54 | """Test mixing_frequency by accelerating PolyakSGD."""
55 | X, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3,
56 | n_informative=3, random_state=0)
57 | data = (X, y)
58 | l2reg = 100.0
59 | # fun(params, data)
60 | fun = objective.l2_multiclass_logreg_with_intercept
61 | n_classes = len(jnp.unique(y))
62 |
63 | W_init = jnp.zeros((X.shape[1], n_classes))
64 | b_init = jnp.zeros(n_classes)
65 | pytree_init = (W_init, b_init)
66 |
67 | opt = PolyakSGD(fun=fun, max_stepsize=0.01, tol=0.05, momentum=False)
68 | history_size = 5
69 | aw = AndersonWrapper(opt, history_size=history_size, mixing_frequency=1)
70 | aw_params, aw_state = aw.run(pytree_init, l2reg=l2reg, data=data)
71 | self.assertLess(aw_state.error, 0.05)
72 |
73 | def test_optax_restart(self):
74 | """Test Optax optimizer."""
75 | X, y = datasets.make_classification(n_samples=100, n_features=20, n_classes=3,
76 | n_informative=3, random_state=0)
77 | data = (X, y)
78 | l2reg = 100.0
79 | # fun(params, data)
80 | fun = objective.l2_multiclass_logreg_with_intercept
81 | n_classes = len(jnp.unique(y))
82 |
83 | W_init = jnp.zeros((X.shape[1], n_classes))
84 | b_init = jnp.zeros(n_classes)
85 | pytree_init = (W_init, b_init)
86 |
87 | tol = 1e-2
88 | opt = OptaxSolver(opt=optax.sgd(1e-2, momentum=0.8), fun=fun, maxiter=1000, tol=0)
89 | aw = AndersonWrapper(opt, history_size=3, ridge=1e-3)
90 | params, infos = aw.run(pytree_init, l2reg=l2reg, data=data)
91 |
92 | # Check optimality conditions.
93 | error = opt.l2_optimality_error(params, l2reg=l2reg, data=data)
94 | self.assertLessEqual(error, tol)
95 |
96 | def test_block_cd_restart(self):
97 | """Accelerate Block CD."""
98 | X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
99 |
100 | # Setup parameters.
101 | fun = objective.least_squares # fun(params, data)
102 | l1reg = 10.0
103 | data = (X, y)
104 |
105 | # Initialize.
106 | w_init = jnp.zeros(X.shape[1])
107 | tol = 5e-4
108 | maxiter = 100
109 | bcd = BlockCoordinateDescent(fun=fun, block_prox=prox.prox_lasso, tol=tol, maxiter=maxiter)
110 | aw = AndersonWrapper(bcd, history_size=3, ridge=1e-5)
111 | params, state = aw.run(init_params=w_init, hyperparams_prox=l1reg, data=data)
112 |
113 | # Check optimality conditions.
114 | self.assertLess(state.error, tol)
115 |
116 | def test_wrapper_grad(self):
117 | """Test gradient of wrapper."""
118 | data_train = datasets.make_regression(n_samples=100, n_features=3, random_state=0)
119 | fun = objective.least_squares
120 | lam = 10.0
121 | w_init = jnp.zeros(data_train[0].shape[1])
122 | tol = 1e-5
123 | maxiter = 1000 # large number of updates
124 | pg = ProximalGradient(fun=fun, prox=prox.prox_lasso, maxiter=maxiter, tol=tol,
125 | acceleration=False)
126 | aw = AndersonWrapper(pg, history_size=5)
127 | data_val = datasets.make_regression(n_samples=100, n_features=3, random_state=0)
128 |
129 | def solve_run(lam):
130 | aw_params = aw.run(w_init, lam, data_train).params
131 | loss = fun(aw_params, data=data_val)
132 | return loss
133 |
134 | check_grads(solve_run, args=(lam,), order=1, modes=['rev'], eps=2e-2)
135 |
136 | def solve_run(lam):
137 | aw_params = aw.run(w_init, hyperparams_prox=lam, data=data_train).params
138 | loss = fun(aw_params, data=data_val)
139 | return loss
140 |
141 | check_grads(solve_run, args=(lam,), order=1, modes=['rev'], eps=2e-2)
142 |
143 | if __name__ == '__main__':
144 | # Uncomment the line below in order to run in float64.
145 | # import jax; jax.config.update("jax_enable_x64", True)
146 | absltest.main()
147 |
--------------------------------------------------------------------------------
/tests/bisection_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 |
17 | import jax
18 | import jax.numpy as jnp
19 |
20 | from jaxopt import projection
21 | from jaxopt import Bisection
22 | from jaxopt._src import test_util
23 |
24 | import numpy as onp
25 |
26 |
27 | # optimality_fun(params, hyperparams, data)
28 | def _optimality_fun_proj_simplex(tau, x, s):
29 | # optimality_fun(tau, x, s) is a decreasing function of tau on
30 | # [lower, upper] since the derivative w.r.t. tau is negative.
31 | return jnp.sum(jnp.maximum(x - tau, 0)) - s
32 |
33 |
34 | def _threshold_proj_simplex(bisect, x, s=1.0):
35 | return bisect.run(None, x, s).params
36 |
37 |
38 | def _projection_simplex_bisect(bisect, x, s=1.0):
39 | return jnp.maximum(x - _threshold_proj_simplex(bisect, x, s), 0)
40 |
41 |
42 | def _projection_simplex_bisect_setup(x, s=1.0):
43 | # tau = max(x) => tau >= x_i for all i
44 | # => x_i - tau <= 0 for all i
45 | # => maximum(x_i - tau, 0) = 0 for all i
46 | # => optimality_fun(tau, x, s) = -s <= 0
47 | upper = jax.lax.stop_gradient(jnp.max(x))
48 |
49 | # tau' = min(x) => tau' <= x_i for all i
50 | # => 0 <= x_i - tau' for all_i
51 | # => maximum(x_i - tau', 0) >= 0
52 | # => optimality_fun(tau, x, s) >= 0
53 | # where tau = tau' - s / len(x)
54 | lower = jax.lax.stop_gradient(jnp.min(x)) - s / len(x)
55 |
56 | return Bisection(optimality_fun=_optimality_fun_proj_simplex,
57 | lower=lower, upper=upper, check_bracket=False)
58 |
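# Added for exposition (not part of the original suite): the bracket
# derivation in the comments above can be sanity-checked numerically for any
# concrete x, e.g. with the hypothetical values below.
def _bracket_signs_example():
  x = jnp.array([0.3, -1.2, 0.7])
  s = 1.0
  upper = jnp.max(x)               # optimality_fun(upper, x, s) = -s <= 0
  lower = jnp.min(x) - s / len(x)  # optimality_fun(lower, x, s) >= 0
  assert _optimality_fun_proj_simplex(upper, x, s) <= 0
  assert _optimality_fun_proj_simplex(lower, x, s) >= 0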
59 |
60 | class BisectionTest(test_util.JaxoptTestCase):
61 |
62 | def test_bisect(self):
63 | rng = onp.random.RandomState(0)
64 |
65 | _projection_simplex_bisect_jitted = jax.jit(
66 | _projection_simplex_bisect, static_argnums=0)
67 |
68 | for _ in range(10):
69 | x = jnp.array(rng.randn(50).astype(onp.float32))
70 | bisect = _projection_simplex_bisect_setup(x)
71 | p = projection.projection_simplex(x)
72 | p2 = _projection_simplex_bisect(bisect, x)
73 | p3 = _projection_simplex_bisect_jitted(bisect, x)
74 | self.assertArraysAllClose(p, p2, atol=1e-4)
75 | self.assertArraysAllClose(p, p3, atol=1e-4)
76 |
77 | J = jax.jacrev(projection.projection_simplex)(x)
78 | J2 = jax.jacrev(_projection_simplex_bisect, argnums=1)(bisect, x)
79 | J3 = jax.jacrev(_projection_simplex_bisect_jitted, argnums=1)(bisect, x)
80 | self.assertArraysAllClose(J, J2, atol=1e-5)
81 | self.assertArraysAllClose(J, J3, atol=1e-5)
82 |
83 | def test_bisect_wrong_lower_bracket(self):
84 | rng = onp.random.RandomState(0)
85 | x = jnp.array(rng.randn(5).astype(onp.float32))
86 | s = 1.0
87 | upper = jnp.max(x)
88 | bisect = Bisection(optimality_fun=_optimality_fun_proj_simplex,
89 | lower=upper, upper=upper)
90 | self.assertRaises(ValueError, bisect.run, None, x, s)
91 |
92 | def test_bisect_wrong_upper_bracket(self):
93 | rng = onp.random.RandomState(0)
94 | x = jnp.array(rng.randn(5).astype(onp.float32))
95 | s = 1.0
96 | lower = jnp.min(x) - s / len(x)
97 | bisect = Bisection(optimality_fun=_optimality_fun_proj_simplex,
98 | lower=lower, upper=lower)
99 | self.assertRaises(ValueError, bisect.run, None, x, s)
100 |
101 | def test_grad_of_value_and_grad(self):
102 | # See https://github.com/google/jaxopt/issues/141
103 |
104 | def bisect(x):
105 | b = _projection_simplex_bisect_setup(x)
106 | return _projection_simplex_bisect(b, x)[0]
107 |
108 | def bisect_val(x):
109 | val, _ = jax.value_and_grad(bisect)(x)
110 | return val
111 |
112 | rng = onp.random.RandomState(0)
113 | x = jnp.array(rng.randn(5).astype(onp.float32))
114 | g1 = jax.grad(bisect)(x)
115 | g2 = jax.grad(bisect_val)(x)
116 | self.assertArraysAllClose(g1, g2)
117 |
118 | def test_edge(self):
119 | def F(x):
120 | return x - 0.5
121 |
122 | bisec = Bisection(optimality_fun=F, lower=0.0, upper=1.0)
123 | # The solution is found on the first iteration.
124 | self.assertEqual(bisec.run().params, 0.5)
125 |
126 |
127 | if __name__ == '__main__':
128 | absltest.main()
129 |
--------------------------------------------------------------------------------
/tests/cd_qp_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 |
17 | import jax
18 | import jax.numpy as jnp
19 |
20 | from jaxopt import BoxCDQP
21 | from jaxopt._src import test_util
22 |
23 | import numpy as onp
24 |
25 |
26 | def _cd_qp(Q, c, l, u, tol, maxiter, verbose=0):
27 | """Pure NumPy implementation for test purposes."""
28 | x = onp.zeros(Q.shape[0])
29 |
30 | for it in range(maxiter):
31 | error = 0
32 |
33 | for i in range(len(x)):
34 | g_i = onp.dot(Q[i], x) + c[i]
35 | h_i = Q[i, i]
36 |
37 | if h_i == 0:
38 | continue
39 |
40 | x_i_new = onp.clip(x[i] - g_i / h_i, l[i], u[i])
41 | delta_i = x_i_new - x[i]
42 | error += onp.abs(delta_i)
43 | x[i] = x_i_new
44 |
45 | if verbose:
46 | print(it + 1, error)
47 |
48 | if error <= tol:
49 | break
50 |
51 | return x
52 |
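# Added for exposition (hypothetical values, not part of the original tests):
# for a separable problem min_x 0.5 x^T Q x + c^T x over the box [0, 1]^2 the
# coordinate update x_i <- clip(x_i - g_i / Q_ii, l_i, u_i) is the exact
# per-coordinate minimizer, so _cd_qp recovers the closed-form solution.
def _cd_qp_usage_example():
  Q = onp.diag([2.0, 4.0])
  c = onp.array([-2.0, 1.0])
  x = _cd_qp(Q, c, l=onp.zeros(2), u=onp.ones(2), tol=1e-8, maxiter=10)
  # Unconstrained minimizers are 1.0 and -0.25; the box clips the latter to 0.
  onp.testing.assert_allclose(x, onp.array([1.0, 0.0]))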
53 |
54 | class CD_QP_Test(test_util.JaxoptTestCase):
55 |
56 | def setUp(self):
57 | rng = onp.random.RandomState(0)
58 | num_dim = 5
59 | M = rng.randn(num_dim, num_dim)
60 | self.Q = onp.dot(M, M.T)
61 | self.c = rng.randn(num_dim)
62 | self.l = rng.randn(num_dim)
63 | self.u = self.l + 5 * rng.rand(num_dim)
64 | self.params_obj = (self.Q, self.c)
65 | self.params_ineq = (self.l, self.u)
66 |
67 | def test_forward(self):
68 | sol_numpy = _cd_qp(self.Q, self.c, self.l, self.u,
69 | tol=1e-3, maxiter=100, verbose=0)
70 |
71 | # Manual loop
72 | params = jnp.zeros_like(sol_numpy)
73 |
74 | cdqp = BoxCDQP()
75 | state = cdqp.init_state(params)
76 |
77 | for _ in range(5):
78 | params, state = cdqp.update(params, state, params_obj=self.params_obj,
79 | params_ineq=self.params_ineq)
80 |
81 | self.assertAllClose(state.error, 0.0)
82 | self.assertAllClose(params, sol_numpy)
83 |
84 | # Run call.
85 | params = jnp.zeros_like(sol_numpy)
86 | params, state = cdqp.run(params, params_obj=self.params_obj,
87 | params_ineq=self.params_ineq)
88 | self.assertAllClose(state.error, 0.0)
89 | self.assertAllClose(params, sol_numpy)
90 |
91 | def test_backward(self):
92 | cdqp = BoxCDQP(implicit_diff=True)
93 | cdqp2 = BoxCDQP(implicit_diff=False)
94 | init_params = jnp.zeros(self.Q.shape[0])
95 |
96 | def wrapper(c):
97 | params_obj = (self.Q, c)
98 | return cdqp.run(init_params, params_obj=params_obj,
99 | params_ineq=self.params_ineq).params
100 |
101 | def wrapper2(c):
102 | params_obj = (self.Q, c)
103 | return cdqp2.run(init_params, params_obj=params_obj,
104 | params_ineq=self.params_ineq).params
105 |
106 | J = jax.jacobian(wrapper)(self.c)
107 | J2 = jax.jacobian(wrapper2)(self.c)
108 | self.assertAllClose(J, J2)
109 |
110 |
111 | if __name__ == '__main__':
112 | # Uncomment the line below in order to run in float64.
113 | #jax.config.update("jax_enable_x64", True)
114 | absltest.main()
115 |
--------------------------------------------------------------------------------
/tests/cond_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 |
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as onp
21 |
22 | from jaxopt._src.cond import cond
23 | from jaxopt._src import test_util
24 |
25 |
26 | class CondTest(test_util.JaxoptTestCase):
27 |
28 | @parameterized.product(jit=[False, True])
29 | def test_cond(self, jit):
30 | def true_fun(x):
31 | return x
32 | def false_fun(x):
33 | return jnp.zeros_like(x)
34 |
35 | def my_relu(x):
36 | return cond(jnp.sum(x)>0, true_fun, false_fun, x, jit=jit)
37 |
38 | if jit:
39 | x = onp.array([1.])
40 | else:
41 | x = jnp.array([1.])
42 | self.assertEqual(jax.nn.relu(x), my_relu(x))
43 |
44 | if __name__ == '__main__':
45 | absltest.main()
46 |
--------------------------------------------------------------------------------
/tests/cvxpy_wrapper_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """CVXPY tests."""
16 |
17 | from absl.testing import absltest
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as onp
22 |
23 | from jaxopt import projection
24 | from jaxopt import CvxpyQP
25 | from jaxopt._src import test_util
26 |
27 |
28 | class CvxpyQPTest(test_util.JaxoptTestCase):
29 |
30 | def _check_derivative_Q_c_A_b(self, solver, params, Q, c, A, b):
31 | def fun(Q, c, A, b):
32 | try:
33 | params_ineq = params["params_ineq"]
34 | except KeyError:
35 | params_ineq = None
36 |
37 | Q = 0.5 * (Q + Q.T)
38 |
39 | hyperparams = dict(params_obj=(Q, c),
40 | params_eq=(A, b),
41 | params_ineq=params_ineq)
42 |
43 | # Reduce the primal variables to a scalar value for test purposes.
44 | return jnp.sum(solver.run(None, **hyperparams).params[0])
45 |
46 | # Derivative w.r.t. A.
47 | rng = onp.random.RandomState(0)
48 | V = rng.rand(*A.shape)
49 | V /= onp.sqrt(onp.sum(V ** 2))
50 | eps = 1e-4
51 | deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b))
52 | deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps)
53 | self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
54 |
55 | # Derivative w.r.t. b.
56 | v = rng.rand(*b.shape)
57 | v /= onp.sqrt(onp.sum(v ** 2))
58 | eps = 1e-4
59 | deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b))
60 | deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps)
61 | self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
62 |
63 | # Derivative w.r.t. Q
64 | W = rng.rand(*Q.shape)
65 | W /= onp.sqrt(onp.sum(W ** 2))
66 | eps = 1e-4
67 | deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b))
68 | deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps)
69 | self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
70 |
71 | # Derivative w.r.t. c
72 | w = rng.rand(*c.shape)
73 | w /= onp.sqrt(onp.sum(w ** 2))
74 | eps = 1e-4
75 | deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b))
76 | deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps)
77 | self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
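# Exposition note (added): each block above checks a directional derivative
# by central finite differences, i.e. it verifies
#   <V, grad_A fun(Q, c, A, b)> ~ (fun(A + eps*V) - fun(A - eps*V)) / (2*eps)
# for a random unit-norm direction V, and analogously for b, Q and c.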
78 |
79 | def test_qp_eq_and_ineq(self):
80 | Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
81 | c = jnp.array([1.0, 1.0])
82 | A = jnp.array([[1.0, 1.0]])
83 | b = jnp.array([1.0])
84 | G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
85 | h = jnp.array([0.0, 0.0])
86 | qp = CvxpyQP()
87 | hyperparams = dict(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h))
88 | sol = qp.run(None, **hyperparams).params
89 | self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0, atol=1e-4)
90 | self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
91 |
92 | def test_projection_simplex(self):
93 | def _projection_simplex_qp(x, s=1.0):
94 | Q = jnp.eye(len(x))
95 | A = jnp.array([jnp.ones_like(x)])
96 | b = jnp.array([s])
97 | G = -jnp.eye(len(x))
98 | h = jnp.zeros_like(x)
99 | hyperparams = dict(params_obj=(Q, -x), params_eq=(A, b),
100 | params_ineq=(G, h))
101 |
102 | qp = CvxpyQP()
103 | # Returns the primal solution only.
104 | return qp.run(None, **hyperparams).params[0]
105 |
106 | rng = onp.random.RandomState(0)
107 | x = jnp.array(rng.randn(10).astype(onp.float32))
108 | p = projection.projection_simplex(x)
109 | p2 = _projection_simplex_qp(x)
110 | self.assertArraysAllClose(p, p2)
111 | J = jax.jacrev(projection.projection_simplex)(x)
112 | J2 = jax.jacrev(_projection_simplex_qp)(x)
113 | self.assertArraysAllClose(J, J2, atol=1e-5)
114 |
115 |
116 | if __name__ == '__main__':
117 | absltest.main()
118 |
--------------------------------------------------------------------------------
/tests/gauss_newton_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 |
18 | import jax
19 | import jax.numpy as jnp
20 |
21 | from jaxopt import GaussNewton
22 | from jaxopt._src import test_util
23 |
24 | import numpy as onp
25 |
26 |
27 | def _enzyme_reaction_residual_model(coeffs, x, y):
28 | return y - coeffs[0] * x / (coeffs[1] + x)
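# Note (added): this is the residual of the Michaelis-Menten enzyme-kinetics
# model rate = coeffs[0] * x / (coeffs[1] + x), i.e. coeffs ~ (V_max, K_m).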
29 |
30 |
31 | def _enzyme_reaction_residual_model_jac(coeffs, x, y, eps=1e-5):
32 | """Return the numerical Jacobian."""
33 | gn = GaussNewton(
34 | residual_fun=_enzyme_reaction_residual_model,
35 | maxiter=100,
36 | tol=1.0e-6)
37 |
38 | # Sets eps at index idx only; all other entries are zero.
39 | eps_at = lambda idx: onp.array([int(i == idx)*eps for i in range(len(x))])
40 |
41 | res1 = jnp.zeros((len(coeffs), len(x)))
42 | res2 = jnp.zeros((len(coeffs), len(x)))
43 | for i in range(len(x)):
44 | res1 = res1.at[:,i].set(gn.run(coeffs, x + eps_at(i), y).params)
45 | res2 = res2.at[:,i].set(gn.run(coeffs, x - eps_at(i), y).params)
46 |
47 | twoeps = 2 * eps
48 | return (res1 - res2) / twoeps
49 |
50 |
51 | def _city_temperature_residual_model(coeffs, x, y):
52 | return y - (coeffs[0] * jnp.sin(x * coeffs[1] + coeffs[2]) + coeffs[3])
53 |
54 |
55 | class GaussNewtonTest(test_util.JaxoptTestCase):
56 |
57 | def setUp(self):
58 | super().setUp()
59 |
60 | self.substrate_conc = onp.array(
61 | [0.038, 0.194, .425, .626, 1.253, 2.500, 3.740])
62 | self.rate_data = onp.array(
63 | [0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317])
64 | self.init_enzyme_reaction_coeffs = onp.array([0.1, 0.1])
65 |
66 | self.months = onp.arange(1, 13)
67 | self.temperature_record = onp.array([
68 | 61.0, 65.0, 72.0, 78.0, 85.0, 90.0, 92.0, 92.0, 88.0, 81.0, 72.0, 63.0
69 | ])
70 | self.init_temperature_record_coeffs = onp.array([10, 0.5, 10.5, 50])
71 |
72 | def test_aux_true(self):
73 | gn = GaussNewton(lambda x: (x**2, True), has_aux=True, maxiter=2)
74 | x_init = jnp.arange(2.)
75 | _, state = gn.run(x_init)
76 | self.assertEqual(state.aux, True)
77 |
78 | # Example taken from "Probability, Statistics and Estimation" by Mathieu ROUAUD.
79 | # The algorithm is detailed and applied to the biology experiment discussed on
80 | # page 84, together with the uncertainties on the estimated values.
81 | def test_enzyme_reaction_parameter_fit(self):
82 | gn = GaussNewton(
83 | residual_fun=_enzyme_reaction_residual_model,
84 | maxiter=100,
85 | tol=1.0e-6)
86 | optimize_info = gn.run(
87 | self.init_enzyme_reaction_coeffs,
88 | self.substrate_conc,
89 | self.rate_data)
90 |
91 | self.assertArraysAllClose(optimize_info.params,
92 | onp.array([0.36183689, 0.55626653]),
93 | rtol=1e-7, atol=1e-7)
94 |
95 | @parameterized.product(implicit_diff=[True, False])
96 | def test_enzyme_reaction_implicit_diff(self, implicit_diff):
97 | jac_num = _enzyme_reaction_residual_model_jac(
98 | self.init_enzyme_reaction_coeffs, self.substrate_conc, self.rate_data)
99 |
100 | gn = GaussNewton(
101 | residual_fun=_enzyme_reaction_residual_model,
102 | tol=1.0e-6,
103 | maxiter=10,
104 | implicit_diff=implicit_diff)
105 |
106 | def wrapper(substrate_conc):
107 | return gn.run(
108 | self.init_enzyme_reaction_coeffs,
109 | substrate_conc,
110 | self.rate_data).params
111 | jac_custom = jax.jacrev(wrapper)(self.substrate_conc)
112 |
113 | self.assertArraysAllClose(jac_num, jac_custom, atol=1e-2)
114 |
115 | # Example 7 from "SOLVING NONLINEAR LEAST-SQUARES PROBLEMS WITH THE
116 | # GAUSS-NEWTON AND LEVENBERG-MARQUARDT METHODS" by ALFONSO CROEZE et al.
117 | def test_temperature_record_four_parameter_fit(self):
118 | gn = GaussNewton(
119 | residual_fun=_city_temperature_residual_model,
120 | tol=1.0e-6)
121 | optimize_info = gn.run(
122 | self.init_temperature_record_coeffs,
123 | self.months,
124 | self.temperature_record)
125 |
126 | # Checking against the expected values
127 | self.assertArraysAllClose(
128 | optimize_info.params,
129 | onp.array([16.63994555, 0.46327812, 10.85228919, 76.19086103]),
130 | rtol=1e-6, atol=1e-5)
131 |
132 | def test_scalar_output_fun(self):
133 | gn = GaussNewton(
134 | residual_fun=lambda x: x @ x,
135 | tol=1e-1,)
136 | x_init = jnp.ones((2,))
137 | x_opt, _ = gn.run(x_init)
138 |
139 | self.assertAllClose(x_opt, jnp.zeros((2,)), atol=1e0)
140 |
141 |
142 | if __name__ == '__main__':
143 | absltest.main()
144 |
--------------------------------------------------------------------------------
/tests/hager_zhang_linesearch_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 |
18 | import jax
19 | import jax.numpy as jnp
20 |
21 | from jaxopt import HagerZhangLineSearch
22 | from jaxopt import objective
23 | from jaxopt._src import test_util
24 | from jaxopt.tree_util import tree_scalar_mul
25 | from jaxopt.tree_util import tree_vdot
26 |
27 | import numpy as onp
28 |
29 | from sklearn import datasets
30 |
31 |
32 | class HagerZhangLinesearchTest(test_util.JaxoptTestCase):
33 |
34 | def _check_conditions_satisfied(
35 | self,
36 | c1,
37 | c2,
38 | stepsize,
39 | initial_value,
40 | initial_grad,
41 | final_state):
42 | self.assertTrue(jnp.all(final_state.done))
43 | self.assertFalse(jnp.any(final_state.failed))
44 |
45 | descent_direction = tree_scalar_mul(-1, initial_grad)
46 | sufficient_decrease = jnp.all(
47 | final_state.value <= initial_value +
48 | c1 * stepsize * tree_vdot(final_state.grad, descent_direction))
49 | self.assertTrue(sufficient_decrease)
50 |
51 | new_gd_vdot = tree_vdot(final_state.grad, descent_direction)
52 | gd_vdot = tree_vdot(initial_grad, descent_direction)
53 | curvature = jnp.all(new_gd_vdot >= c2 * gd_vdot)
54 | self.assertTrue(curvature)
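# Exposition note (added): the two checks above correspond to the
# (approximate) Wolfe-type conditions targeted by the Hager-Zhang line
# search: a sufficient-decrease condition with constant c1 and a curvature
# condition with constant c2, both along the steepest-descent direction.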
55 |
56 | def test_hager_zhang_linesearch(self):
57 | x, y = datasets.make_classification(
58 | n_samples=10, n_features=5, n_classes=2,
59 | n_informative=3, random_state=0)
60 | data = (x, y)
61 | fun = objective.binary_logreg
62 |
63 | rng = onp.random.RandomState(0)
64 | w_init = rng.randn(x.shape[1])
65 | initial_grad = jax.grad(fun)(w_init, data=data)
66 | initial_value = fun(w_init, data=data)
67 |
68 | # Manual loop.
69 | ls = HagerZhangLineSearch(fun=fun)
70 | stepsize = 1.0
71 | state = ls.init_state(
72 | init_stepsize=1.0, params=w_init, fun_kwargs={"data": data}
73 | )
74 | stepsize, state = ls.update(stepsize=stepsize, state=state, params=w_init,
75 | fun_kwargs={"data": data})
76 |
77 | # Call to run.
78 | ls = HagerZhangLineSearch(fun=fun, maxiter=20)
79 | stepsize, state = ls.run(
80 | init_stepsize=1.0, params=w_init, fun_kwargs={"data": data}
81 | )
82 | self._check_conditions_satisfied(
83 | ls.c1, ls.c2, stepsize, initial_value, initial_grad, state)
84 |
85 | # Call to run with value_and_grad=True.
86 | ls = HagerZhangLineSearch(fun=jax.value_and_grad(fun),
87 | maxiter=20,
88 | value_and_grad=True)
89 | stepsize, state = ls.run(
90 | init_stepsize=1.0, params=w_init, fun_kwargs={"data": data}
91 | )
92 | self._check_conditions_satisfied(
93 | ls.c1, ls.c2, stepsize, initial_value, initial_grad, state)
94 |
95 | # Failed linesearch (high c1 ensures convergence condition is not met).
96 | ls = HagerZhangLineSearch(fun=fun, maxiter=20, c1=2.)
97 | _, state = ls.run(
98 | init_stepsize=1.0, params=w_init, fun_kwargs={"data": data}
99 | )
100 | self.assertTrue(jnp.all(state.failed))
101 | self.assertFalse(jnp.any(state.done))
102 |
103 | @parameterized.product(val=[onp.inf, onp.nan])
104 | def test_hager_zhang_linesearch_non_finite(self, val):
105 |
106 | def fun(x):
107 | result = jnp.where(x > 4., val, (x - 2)**2)
108 | grad = jnp.where(x > 4., onp.nan, 2 * (x - 2.))
109 | return result, grad
110 | x_init = -0.001
111 |
112 | ls = HagerZhangLineSearch(fun=fun, value_and_grad=True, jit=False)
113 | stepsize = 1.25
114 | state = ls.init_state(init_stepsize=1.25, params=x_init)
115 |
116 | stepsize, state = ls.update(stepsize=stepsize, state=state, params=x_init)
117 | # Should work around the NaN/Inf regions and provide a reasonable step size.
118 | self.assertTrue(state.done)
119 |
120 |
121 | if __name__ == '__main__':
122 | # Uncomment the line below in order to run in float64.
123 | # jax.config.update("jax_enable_x64", True)
124 | absltest.main()
125 |
--------------------------------------------------------------------------------
/tests/import_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 |
17 | import jaxopt
18 | from jaxopt._src import test_util
19 |
20 |
21 | class ImportTest(test_util.JaxoptTestCase):
22 |
23 | def test_implicit_diff(self):
24 | jaxopt.implicit_diff.root_vjp
25 | from jaxopt.implicit_diff import root_vjp
26 |
27 | def test_isotonic(self):
28 | jaxopt.isotonic.isotonic_l2_pav
29 | from jaxopt.isotonic import isotonic_l2_pav
30 |
31 | def test_prox(self):
32 | jaxopt.prox.prox_none
33 | from jaxopt.prox import prox_none
34 |
35 | def test_projection(self):
36 | jaxopt.projection.projection_simplex
37 | from jaxopt.projection import projection_simplex
38 |
39 | def test_tree_util(self):
40 | from jaxopt.tree_util import tree_vdot
41 |
42 | def test_linear_solve(self):
43 | from jaxopt.linear_solve import solve_lu
44 |
45 | def test_base(self):
46 | from jaxopt.base import LinearOperator
47 |
48 | def test_perturbations(self):
49 | from jaxopt.perturbations import make_perturbed_argmax
50 |
51 | def test_loss(self):
52 | jaxopt.loss.binary_logistic_loss
53 | from jaxopt.loss import binary_logistic_loss
54 |
55 | def test_objective(self):
56 | jaxopt.objective.least_squares
57 | from jaxopt.objective import least_squares
58 |
59 | def test_loop(self):
60 | from jaxopt.loop import while_loop
61 |
62 |
63 | if __name__ == '__main__':
64 | absltest.main()
65 |
--------------------------------------------------------------------------------
/tests/isotonic_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for Isotonic Regression."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 |
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from jax.test_util import check_grads
24 | from jaxopt.isotonic import isotonic_l2_pav
25 | from jaxopt._src import test_util
26 | from sklearn import isotonic
27 |
28 |
29 | class IsotonicPavTest(test_util.JaxoptTestCase):
30 | """Tests for PAV in JAX."""
31 |
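# Exposition note (added): isotonic_l2_pav solves the isotonic regression
# problem min_u ||u - y||^2 subject to u being monotone (increasing or
# decreasing), optionally with the solution clipped to [y_min, y_max]; the
# tests below compare it against sklearn and check its gradients.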
32 | def test_output_shape_and_dtype(self, n=10):
33 | """Verifies the shapes and dtypes of output."""
34 | y = jax.random.normal(jax.random.PRNGKey(0), (n,))
35 | output = isotonic_l2_pav(y)
36 | self.assertEqual(output.shape, y.shape)
37 | self.assertEqual(output.dtype, y.dtype)
38 |
39 | @parameterized.product(increasing=[True, False])
40 | def test_compare_with_sklearn(self, increasing, n=10):
41 | """Compares the output with the one of sklearn."""
42 | y = jax.random.normal(jax.random.PRNGKey(0), (n,))
43 | output = isotonic_l2_pav(y, increasing=increasing)
44 | output_sklearn = jnp.array(isotonic.isotonic_regression(y, increasing=increasing))
45 | self.assertArraysAllClose(output, output_sklearn)
46 | y_sort = y.sort()
47 | y_min = y_sort[2]
48 | y_max = y_sort[n-5]
49 | output = isotonic_l2_pav(y, y_min=y_min, y_max=y_max, increasing=increasing)
50 | output_sklearn = jnp.array(isotonic.isotonic_regression(y, y_min=y_min.item(),
51 | y_max=y_max.item(), increasing=increasing))
52 | self.assertArraysAllClose(output, output_sklearn)
53 |
54 | @parameterized.product(increasing=[True, False])
55 | def test_gradient(self, increasing, n=10):
56 | """Checks the gradient with finite differences."""
57 | # The absolute-error check in check_grads fails for large values of y.
58 | y = 0.1*jax.random.normal(jax.random.PRNGKey(0), (n,))
59 |
60 | def loss(y):
61 | return (isotonic_l2_pav(y**3, increasing=increasing)
62 | + isotonic_l2_pav(y, increasing=increasing) ** 2).mean()
63 |
64 | check_grads(loss, (y,), order=2)
65 |
66 | def test_gradient_min_max(self, n=10):
67 | """Checks the gradient with finite differences."""
68 | y = jax.random.normal(jax.random.PRNGKey(0), (n,))
69 | y_sort = y.sort()
70 | y_min = y_sort[2]
71 | y_max = y_sort[n-5]
72 | def loss(y):
73 | return (isotonic_l2_pav(y**3, y_min=y_min, y_max=y_max)
74 | + isotonic_l2_pav(y, y_min=y_min, y_max=y_max) ** 2).mean()
75 |
76 | check_grads(loss, (y,), order=2)
77 |
78 | def test_vmap(self, n_features=10, n_batches=16):
79 | """Verifies vmap."""
80 | y = jax.random.normal(jax.random.PRNGKey(0), (n_batches, n_features))
81 | isotonic_l2_pav_vmap = jax.vmap(isotonic_l2_pav)
82 | for i in range(n_batches):
83 | self.assertArraysAllClose(isotonic_l2_pav_vmap(y)[i], isotonic_l2_pav(y[i]))
84 |
85 | if __name__ == '__main__':
86 | absltest.main()
87 |
--------------------------------------------------------------------------------
/tests/iterative_refinement_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from functools import partial
16 |
17 | from absl.testing import absltest
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | from jax.test_util import check_grads
22 |
23 | from jaxopt import linear_solve
24 | from jaxopt import IterativeRefinement
25 | from jaxopt._src import test_util
26 |
27 | import numpy as onp
28 |
29 |
30 | class IterativeRefinementTest(test_util.JaxoptTestCase):
31 |
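# Exposition note (added): iterative refinement repeatedly solves a
# correction system with a (possibly approximate) inner solver,
#   r_k = b - A x_k,   d_k = solve(A_bar, r_k),   x_{k+1} = x_k + d_k,
# which is why a low-accuracy inner solver can still reach a high-accuracy
# solution of A x = b in the tests below.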
32 | def test_simple_system(self):
33 | onp.random.seed(0)
34 | n = 20
35 | A = onp.random.rand(n, n)
36 | b = onp.random.randn(n)
37 |
38 | low_acc = 1e-1
39 | high_acc = 1e-5
40 |
41 | # Heavily regularized, low-accuracy solver.
42 | inner_solver = partial(linear_solve.solve_gmres, tol=low_acc, ridge=1e-3)
43 |
44 | solver = IterativeRefinement(solve=inner_solver, tol=high_acc, maxiter=10)
45 | x, state = solver.run(None, A, b)
46 | self.assertLess(state.error, high_acc)
47 |
48 | x_approx = inner_solver(lambda x: jnp.dot(A, x), b)
49 | error_inner_solver = solver.l2_optimality_error(x_approx, A, b)
50 | # A high-accuracy solution obtained from a low-accuracy solver.
51 | self.assertLess(state.error, error_inner_solver)
52 |
53 | def test_ill_posed_problem(self):
54 | onp.random.seed(0)
55 | n = 10
56 | e = 5
57 |
58 | # Duplicated rows make the system rank-deficient (ill-posed).
59 | A = onp.random.rand(e, n)
60 | A = jnp.concatenate([A, A], axis=0)
61 | b = onp.random.randn(e)
62 | b = jnp.concatenate([b, b], axis=0)
63 |
64 | low_acc = 1e-1
65 | high_acc = 1e-3
66 |
67 | # Heavily regularized, low-accuracy solver.
68 | inner_solver = partial(linear_solve.solve_gmres, tol=low_acc, ridge=5e-2)
69 |
70 | solver = IterativeRefinement(solve=inner_solver, tol=high_acc, maxiter=30)
71 | x, state = solver.run(init_params=None, A=A, b=b)
72 | self.assertLess(state.error, high_acc)
73 |
74 | x_approx = inner_solver(lambda x: jnp.dot(A, x), b)
75 | error_inner_solver = solver.l2_optimality_error(x_approx, A, b)
76 | # A high-accuracy solution obtained from a low-accuracy solver.
77 | self.assertLess(state.error, error_inner_solver)
78 |
79 | def test_perturbed_system(self):
80 | onp.random.seed(0)
81 | n = 20
82 |
83 | A = onp.random.rand(n, n) # invertible matrix (with high probability).
84 |
85 | noise = onp.random.randn(n, n)
86 | sigma = 0.05
87 | A_bar = A + sigma * noise # perturbed system.
88 |
89 | expected = onp.random.randn(n)
90 | b = A @ expected # unperturbed target.
91 |
92 | high_acc = 1e-3
93 | solver = IterativeRefinement(matvec_A=None, matvec_A_bar=jnp.dot,
94 | tol=high_acc, maxiter=100)
95 | x, state = solver.run(init_params=None, A=A, b=b, A_bar=A_bar)
96 | self.assertLess(state.error, high_acc)
97 | self.assertArraysAllClose(x, expected, rtol=5e-2)
98 |
99 | def test_implicit_diff(self):
100 | onp.random.seed(17)
101 | n = 20
102 | A = onp.random.rand(n, n)
103 | b = onp.random.randn(n)
104 |
105 | low_acc = 1e-1
106 | high_acc = 1e-5
107 |
108 | # Heavily regularized, low-accuracy solver.
109 | inner_solver = partial(linear_solve.solve_gmres, tol=low_acc, ridge=1e-3)
110 | solver = IterativeRefinement(solve=inner_solver, tol=high_acc, maxiter=10)
111 |
112 | def solve_run(A, b):
113 | x, state = solver.run(init_params=None, A=A, b=b)
114 | return x
115 |
116 | check_grads(solve_run, args=(A, b), order=1, modes=['rev'], eps=1e-3)
117 |
118 | def test_warm_start(self):
119 | onp.random.seed(0)
120 | n = 20
121 | A = onp.random.rand(n, n)
122 | b = onp.random.randn(n)
123 |
124 | init_x = onp.random.randn(n)
125 |
126 | high_acc = 1e-5
127 |
128 | solver = IterativeRefinement(tol=high_acc, maxiter=10)
129 | x, state = solver.run(init_x, A, b)
130 | self.assertLess(state.error, high_acc)
131 |
132 |
133 | if __name__ == "__main__":
134 | jax.config.update("jax_enable_x64", False) # low precision environment.
135 | absltest.main()
136 |
--------------------------------------------------------------------------------
/tests/linear_operator_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Linear Operator tests."""
15 |
16 | from absl.testing import absltest
17 |
18 | import jax.numpy as jnp
19 | import numpy as onp
20 |
21 | from jaxopt._src.linear_operator import FunctionalLinearOperator
22 | from jaxopt._src import test_util
23 |
24 |
25 | class LinearOperatorTest(test_util.JaxoptTestCase):
26 |
27 | def test_matvec_and_rmatvec(self):
28 | rng = onp.random.RandomState(0)
29 | A = rng.randn(5, 4)
30 | matvec = lambda A,x: jnp.dot(A, x)
31 | x = rng.randn(4)
32 | y = rng.randn(5)
33 | linop_A = FunctionalLinearOperator(matvec, A)
34 | mv_A, rmv_A = linop_A.matvec_and_rmatvec(x, y)
35 | self.assertArraysAllClose(mv_A, jnp.dot(A, x))
36 | self.assertArraysAllClose(rmv_A, jnp.dot(A.T, y))
37 |
38 |
39 | if __name__ == '__main__':
40 | absltest.main()
41 |
--------------------------------------------------------------------------------
/tests/linesearch_common_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 | import jax
18 | import jax.numpy as jnp
19 | from jaxopt._src import test_util
20 | from jaxopt._src.linesearch_util import _init_stepsize
21 | from jaxopt._src.linesearch_util import _setup_linesearch
22 |
23 |
24 | class LinesearchTest(test_util.JaxoptTestCase):
25 | @parameterized.product(
26 | linesearch=["zoom", "backtracking", "hager-zhang"],
27 | use_gradient=[False, True],
28 | )
29 | def test_linesearch_complex_variables(self, linesearch, use_gradient):
30 | """Test that optimization over complex variable z = x + jy matches equivalent real case"""
31 |
32 | W = jnp.array([[1, -2], [3, 4], [-4 + 2j, 5 - 3j], [-2 - 2j, 6]])
33 |
34 | def C2R(z):
35 | return jnp.stack((z.real, z.imag)) if z is not None else None
36 |
37 | def R2C(x):
38 | return x[..., 0, :] + 1j * x[..., 1, :]
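# C2R/R2C map a complex vector z = x + jy to the stacked real pair (x, y)
# and back, so the real and complex problems below are equivalent and the
# resulting step sizes should match.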
39 |
40 | def f(z):
41 | return W @ z
42 |
43 | def loss_complex(z):
44 | return jnp.sum(jnp.abs(f(z)) ** 1.5)
45 |
46 | def loss_real(zR):
47 | return loss_complex(R2C(zR))
48 |
49 | z0 = jnp.array([1 - 1j, 0 + 1j])
50 |
51 | common_args = dict(
52 | value_and_grad=False,
53 | has_aux=False,
54 | maxlsiter=3,
55 | max_stepsize=1,
56 | jit=True,
57 | unroll=False,
58 | verbose=False,
59 | )
60 |
61 | ls_R = _setup_linesearch(
62 | linesearch=linesearch,
63 | fun=loss_real,
64 | **common_args,
65 | )
66 |
67 | ls_C = _setup_linesearch(
68 | linesearch=linesearch,
69 | fun=loss_complex,
70 | **common_args,
71 | )
72 |
73 | ls_state = _init_stepsize(
74 | strategy="increase",
75 | max_stepsize=1e-1,
76 | min_stepsize=1e-3,
77 | increase_factor=2.0,
78 | stepsize=1e-2,
79 | )
80 |
81 | descent_direction = (
82 | -jnp.conj(jax.grad(loss_complex)(z0)) if use_gradient else None
83 | )
84 |
85 | stepsize_R, _ = ls_R.run(
86 | ls_state, params=C2R(z0), descent_direction=C2R(descent_direction)
87 | )
88 | stepsize_C, _ = ls_C.run(
89 | ls_state, params=z0, descent_direction=descent_direction
90 | )
91 |
92 | self.assertArraysAllClose(stepsize_R, stepsize_C)
93 |
94 |
95 | if __name__ == "__main__":
96 | absltest.main()
97 |
--------------------------------------------------------------------------------
/tests/loop_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 |
18 | import jax
19 | import jax.numpy as jnp
20 |
21 | from jaxopt import loop
22 | from jaxopt._src import test_util
23 |
24 |
25 | class LoopTest(test_util.JaxoptTestCase):
26 |
27 | @parameterized.product(unroll=[True, False], jit=[True, False])
28 | def test_while_loop(self, unroll, jit):
29 | def my_pow(x, y):
30 | def body_fun(val):
31 | return val * x
32 | def cond_fun(val):
33 | return True
34 | return loop.while_loop(cond_fun=cond_fun, body_fun=body_fun, init_val=1.0,
35 | maxiter=y, unroll=unroll, jit=jit)
36 |
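# Reasoning added (presumed, not stated in the original): with unroll=False
# the loop lowers to lax.while_loop, whose body is always compiled, so
# jit=False cannot be honored and jaxopt raises a ValueError instead.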
37 | if not unroll and not jit:
38 | self.assertRaises(ValueError, my_pow, 3, 4)
39 | return
40 |
41 | self.assertEqual(my_pow(3, 4), pow(3, 4))
42 |
43 | if unroll:
44 | # unroll=False uses lax.while_loop, which is not reverse-mode differentiable.
45 | self.assertEqual(jax.grad(my_pow)(3.0, 4),
46 | jax.grad(jnp.power)(3.0, 4))
47 |
48 | @parameterized.product(unroll=[True, False], jit=[True, False])
49 | def test_while_loop_stopped(self, unroll, jit):
50 | def my_pow(x, y, max_val):
51 | def body_fun(val):
52 | return val * x
53 | def cond_fun(val):
54 | return val < max_val
55 | return loop.while_loop(cond_fun=cond_fun, body_fun=body_fun, init_val=1.0,
56 | maxiter=y, unroll=unroll, jit=jit)
57 |
58 | if not unroll and not jit:
59 | self.assertRaises(ValueError, my_pow, 3, 4, max_val=81)
60 | return
61 |
62 | # We asked for pow(3, 6) but due to max_val, we get pow(3, 4).
63 | self.assertEqual(my_pow(3, 6, max_val=81), pow(3, 4))
64 |
65 | if unroll:
66 | self.assertEqual(jax.grad(my_pow)(3.0, 6, max_val=81),
67 | jax.grad(jnp.power)(3.0, 4))
68 |
69 |
70 | if __name__ == '__main__':
71 | absltest.main()
72 |
--------------------------------------------------------------------------------
/tests/nonlinear_cg_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 |
18 | import jax.random
19 | import jax.numpy as jnp
20 |
21 | import numpy as onp
22 |
23 | import jaxopt
24 | from jaxopt import NonlinearCG
25 | from jaxopt import objective
26 | from jaxopt._src import test_util
27 | from sklearn import datasets
28 |
29 | # Uncomment this line to test in x64
30 | # jax.config.update('jax_enable_x64', True)
31 |
32 | def get_random_pytree():
33 | key = jax.random.PRNGKey(1213)
34 |
35 | def rn(key, l=3):
36 | return 0.05 * jnp.array(onp.random.normal(size=(10,)))
37 |
38 | def _get_random_pytree(curr_depth=0, max_depth=3):
39 | r = onp.random.uniform()
40 | if curr_depth == max_depth or r <= 0.2: # leaf
41 | return rn(key)
42 | elif curr_depth <= 1 or r <= 0.7: # list
43 | return [
44 | _get_random_pytree(curr_depth=curr_depth +
45 | 1, max_depth=max_depth)
46 | for _ in range(2)
47 | ]
48 | else: # dict
49 | return {
50 | str(_): _get_random_pytree(
51 | curr_depth=curr_depth + 1, max_depth=max_depth
52 | )
53 | for _ in range(2)
54 | }
55 | return [rn(key), {'a': rn(key), 'b': rn(key)}, _get_random_pytree()]
56 |
57 |
58 | class NonlinearCGTest(test_util.JaxoptTestCase):
59 |
60 | def test_arbitrary_pytree(self):
61 | def loss(w, data):
62 | X, y = data
63 | _w = jnp.concatenate(jax.tree_util.tree_leaves(w))
64 | return ((jnp.dot(X, _w) - y) ** 2).mean()
65 |
66 | w = get_random_pytree()
67 | f_w = jnp.concatenate(jax.tree_util.tree_leaves(w))
68 | X, y = datasets.make_classification(n_samples=15, n_features=f_w.shape[-1],
69 | n_classes=2, n_informative=3,
70 | random_state=0)
71 | data = (X, y)
72 | cg_model = NonlinearCG(fun=loss, tol=1e-2, maxiter=300,
73 | method="polak-ribiere")
74 | w_fit, info = cg_model.run(w, data=data)
75 | self.assertLessEqual(info.error, 5e-2)
76 |
77 | @parameterized.product(
78 | method=["fletcher-reeves", "polak-ribiere", "hestenes-stiefel"],
79 | linesearch=[
80 | "backtracking",
81 | "zoom",
82 | jaxopt.BacktrackingLineSearch(
83 | objective.binary_logreg, decrease_factor=0.5
84 | ),
85 | ],
86 | linesearch_init=["max", "current", "increase"],
87 | )
88 | def test_binary_logreg(self, method, linesearch, linesearch_init):
89 | X, y = datasets.make_classification(
90 | n_samples=10, n_features=5, n_classes=2, n_informative=3, random_state=0
91 | )
92 | data = (X, y)
93 | fun = objective.binary_logreg
94 |
95 | w_init = jnp.zeros(X.shape[1])
96 | cg_model = NonlinearCG(
97 | fun=fun,
98 | tol=1e-3,
99 | maxiter=100,
100 | method=method,
101 | linesearch=linesearch,
102 | linesearch_init=linesearch_init,
103 | )
104 |
105 | # Test with positional argument.
106 | w_fit, info = cg_model.run(w_init, data)
107 |
108 | # Check optimality conditions.
109 | self.assertLessEqual(info.error, 5e-2)
110 |
111 | # Compare against sklearn.
112 | w_skl = test_util.logreg_skl(X, y, 1e-6, fit_intercept=False,
113 | multiclass=False)
114 | self.assertArraysAllClose(w_fit, w_skl, atol=5e-2)
115 |
116 | @parameterized.product(
117 | linesearch=['zoom', 'backtracking', 'hager-zhang'],
118 | method=['hestenes-stiefel', 'polak-ribiere', 'fletcher-reeves']
119 | )
120 | def test_complex(self, method, linesearch):
121 | """Test that optimization over complex variable z = x + jy matches equivalent real case"""
122 |
123 | W = jnp.array(
124 | [[1, - 2],
125 | [3, 4],
126 | [-4 + 2j, 5 - 3j],
127 | [-2 - 2j, 6]]
128 | )
129 |
130 | def C2R(z):
131 | return jnp.stack((z.real, z.imag))
132 |
133 | def R2C(x):
134 | return x[..., 0, :] + 1j * x[..., 1, :]
135 |
136 | def f(z):
137 | return W @ z
138 |
139 | def loss_complex(z):
140 | return jnp.sum(jnp.abs(f(z)) ** 1.5)
141 |
142 | def loss_real(zR):
143 | return loss_complex(R2C(zR))
144 |
145 | z0 = jnp.array([1 - 1j, 0 + 1j])
146 | xy0 = jnp.stack((z0.real, z0.imag))
147 |
148 | solver_C = NonlinearCG(fun=loss_complex, maxiter=5,
149 | maxls=3, method=method, linesearch=linesearch)
150 | solver_R = NonlinearCG(fun=loss_real, maxiter=5,
151 | maxls=3, method=method, linesearch=linesearch)
152 | sol_C, _ = solver_C.run(z0)
153 | sol_R, _ = solver_R.run(C2R(z0))
154 | # NOTE(vroulet): there is a slight loss of precision between the real
155 | # and complex cases (observable for any linesearch when jax_enable_x64 is set).
156 | tol = 5*1e-15 if jax.config.jax_enable_x64 else 5*1e-6
157 | self.assertArraysAllClose(sol_C, R2C(sol_R), atol=tol, rtol=tol)
158 |
159 |
160 | if __name__ == '__main__':
161 | absltest.main()
162 |
--------------------------------------------------------------------------------
/tests/projected_gradient_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 |
19 | import jax
20 | import jax.numpy as jnp
21 |
22 | from jaxopt import objective
23 | from jaxopt import projection
24 | from jaxopt import ProjectedGradient
25 | from jaxopt import ScipyBoundedMinimize
26 | from jaxopt._src import test_util
27 |
28 | import numpy as onp
29 |
30 |
31 | N_CALLS = 0
32 |
33 | class ProjectedGradientTest(test_util.JaxoptTestCase):
34 |
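# Exposition note (added): projected gradient iterates
#   x_{k+1} = proj_C(x_k - eta * grad f(x_k)),
# so the tests below check both optimality and feasibility of the iterates
# (e.g. staying in the l2 ball or the non-negative orthant).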
35 | def test_non_negative_least_squares(self):
36 | rng = onp.random.RandomState(0)
37 | X = rng.randn(10, 5)
38 | w = rng.rand(5)
39 | y = jnp.dot(X, w)
40 | fun = objective.least_squares
41 | w_init = jnp.zeros_like(w)
42 |
43 | pg = ProjectedGradient(fun=fun,
44 | projection=projection.projection_non_negative)
45 | pg_sol = pg.run(w_init, data=(X, y)).params
46 |
47 | lbfgsb = ScipyBoundedMinimize(fun=fun, method="l-bfgs-b")
48 | lower_bounds = jnp.zeros_like(w_init)
49 | upper_bounds = jnp.ones_like(w_init) * jnp.inf
50 | bounds = (lower_bounds, upper_bounds)
51 | lbfgsb_sol = lbfgsb.run(w_init, bounds=bounds, data=(X, y)).params
52 |
53 | self.assertArraysAllClose(pg_sol, lbfgsb_sol, atol=1e-2)
54 |
55 | def test_projected_gradient_l2_ball(self):
56 | rng = onp.random.RandomState(0)
57 | X = rng.randn(10, 5)
58 | w = rng.rand(5)
59 | y = jnp.dot(X, w)
60 | fun = objective.least_squares
61 | w_init = jnp.zeros_like(w)
62 |
63 | pg = ProjectedGradient(fun=fun,
64 | projection=projection.projection_l2_ball)
65 | pg_sol = pg.run(w_init, hyperparams_proj=1.0, data=(X, y)).params
66 | self.assertLess(jnp.sqrt(jnp.sum(pg_sol ** 2)), 1.0)
67 |
68 | def test_projected_gradient_l2_ball_manual_loop(self):
69 | rng = onp.random.RandomState(0)
70 | X = rng.randn(10, 5)
71 | w = rng.rand(5)
72 | y = jnp.dot(X, w)
73 | fun = objective.least_squares
74 | params = jnp.zeros_like(w)
75 |
76 | pg = ProjectedGradient(fun=fun,
77 | projection=projection.projection_l2_ball)
78 |
79 | state = pg.init_state(params)
80 |
81 | for _ in range(10):
82 | params, state = pg.update(params, state, hyperparams_proj=1.0, data=(X, y))
83 |
84 | self.assertLess(jnp.sqrt(jnp.sum(params ** 2)), 1.0)
85 |
86 | def test_projected_gradient_implicit_diff(self):
87 | rng = onp.random.RandomState(0)
88 | X = rng.randn(10, 5)
89 | w = rng.rand(5)
90 | y = jnp.dot(X, w)
91 | fun = objective.least_squares
92 | w_init = jnp.zeros_like(w)
93 |
94 | def solution(radius):
95 | pg = ProjectedGradient(fun=fun,
96 | projection=projection.projection_l2_ball)
97 | return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
98 |
99 | eps = 1e-4
100 | J = jax.jacobian(solution)(0.1)
101 | J2 = (solution(0.1 + eps) - solution(0.1 - eps)) / (2 * eps)
102 | self.assertArraysAllClose(J, J2, atol=1e-2)
103 |
104 | def test_polyhedron_projection(self):
105 | def f(x):
106 | return x[0]**2-x[1]**2
107 |
108 | A = jnp.array([[0, 0]])
109 | b = jnp.array([0])
110 | G = jnp.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]])
111 | h = jnp.array([-1, 1, 1, 0, 0])
112 | hyperparams = (A, b, G, h)
113 |
114 | proj = projection.projection_polyhedron
115 | pg = ProjectedGradient(fun=f, projection=proj, jit=False)
116 | sol, state = pg.run(init_params=jnp.array([0.,1.]), hyperparams_proj=hyperparams)
117 | self.assertLess(state.error, pg.tol)
118 |
119 | @parameterized.product(n_iter=[10])
120 | def test_n_calls(self, n_iter):
121 | """Test whether the number of function calls
122 | is equal to the number of iterations + 1 in the
123 | no linesearch case, where the complexity is linear."""
124 | def fun(x):
125 | global N_CALLS
126 | N_CALLS += 1
127 | return x[0]**2-x[1]**2
128 |
129 | A = jnp.array([[0, 0]])
130 | b = jnp.array([0])
131 | G = jnp.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]])
132 | h = jnp.array([-1, 1, 1, 0, 0])
133 | hyperparams = (A, b, G, h)
134 |
135 | proj = projection.projection_polyhedron
136 | pg = ProjectedGradient(fun=fun, projection=proj, jit=False, maxiter=n_iter, tol=1e-10, stepsize=1.0)
137 | sol, state = pg.run(init_params=jnp.array([0.,1.]), hyperparams_proj=hyperparams)
138 | self.assertEqual(N_CALLS, n_iter)
139 |
140 |
141 | if __name__ == '__main__':
142 | # Uncomment the line below in order to run in float64.
143 | # jax.config.update("jax_enable_x64", True)
144 | absltest.main()
145 |
--------------------------------------------------------------------------------