├── paper
│   ├── fig
│   │   ├── fp.pdf
│   │   ├── fp.png
│   │   ├── opt.pdf
│   │   ├── opt.png
│   │   ├── ctrl.pdf
│   │   ├── ctrl.png
│   │   ├── imaml.pdf
│   │   ├── imaml.png
│   │   ├── gaussian.pdf
│   │   ├── maxent.gif
│   │   ├── maxent.pdf
│   │   ├── maxent.png
│   │   ├── overview.pdf
│   │   ├── overview.png
│   │   ├── sphere
│   │   │   ├── 0.png
│   │   │   ├── 1.png
│   │   │   ├── 2.png
│   │   │   ├── 3.png
│   │   │   ├── 4.png
│   │   │   ├── 5.png
│   │   │   ├── 6.png
│   │   │   └── 7.png
│   │   ├── vae-iter.pdf
│   │   ├── vae-time.pdf
│   │   ├── loss-comp.pdf
│   │   ├── loss-comp.png
│   │   ├── learning-obj.pdf
│   │   ├── learning-obj.png
│   │   ├── learning-reg.pdf
│   │   ├── learning-reg.png
│   │   ├── learning-rl.pdf
│   │   ├── learning-rl.png
│   │   ├── smoothed-loss.pdf
│   │   ├── smoothed-loss.png
│   │   ├── vae-samples.png
│   │   ├── control-model-free-iter.pdf
│   │   ├── control-model-free-time.pdf
│   │   ├── dcem
│   │   │   ├── cem-vis-full-space.pdf
│   │   │   └── cem-vis-latent-space.pdf
│   │   ├── control-model-based-iter.pdf
│   │   └── control-model-based-time.pdf
│   ├── chapters
│   │   ├── 3-main-table.tex
│   │   ├── 1-intro.tex
│   │   ├── 5-discussion.tex
│   │   └── 4-implementation.tex
│   ├── math_commands.tex
│   └── amor.tex
├── .gitignore
├── .gitmodules
├── CONTRIBUTING.md
├── code
│   ├── figures
│   │   ├── imaml.py
│   │   ├── fixed-point.py
│   │   ├── smoothed-loss.py
│   │   ├── loss-comp.py
│   │   ├── maxent.py
│   │   ├── ctrl.py
│   │   ├── maxent-animation.py
│   │   └── main-example.py
│   ├── evaluate_amortization_speed_vae.py
│   ├── evaluate_amortization_speed_function.py
│   ├── train-sphere.py
│   └── evaluate_amortization_speed_control.py
├── README.md
├── CODE_OF_CONDUCT.md
└── LICENSE

--------------------------------------------------------------------------------
/paper/fig (binary assets):
--------------------------------------------------------------------------------
The .pdf, .png, and .gif figures listed under /paper/fig above are binary files;
each one is served from
https://raw.githubusercontent.com/facebookresearch/amortized-optimization-tutorial/HEAD/<path>
where <path> is the file's path in the tree (e.g. paper/fig/fp.pdf).
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pdf
*.gif
*.png
!**/fig/**/*
*.out
*.toc
*.log
*.aux
*.bbl
*.blg
*.brf
*.fdb_latexmk
*.fls
*.xml
*.bcf
*.gz
.deps
.deps-a

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "code/vae_submodule"]
	path = code/vae_submodule
	url = git@github.com:YannDubs/disentangling-vae.git
[submodule "code/svg_submodule"]
	path = code/svg_submodule
	url = git@github.com:facebookresearch/svg.git

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing to amortized-optimization-tutorial
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here:

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

--------------------------------------------------------------------------------
/code/figures/imaml.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

import jax
import jax.numpy as jnp

import os
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.sans-serif": ["Computer Modern Roman"]})
plt.style.use('bmh')

fig, ax = plt.subplots(figsize=(2, 1.3), dpi=200)

N = 1000
y = np.linspace(-2.0, 2.0, N)
z = -y**3 - 10.*y
ax.plot(y, z, color='k')

I = N // 5
y0, z0 = y[I], z[I]
ax.scatter(y0, z0, color='#5499FF', lw=1, s=50, zorder=10, marker='.')
ax.text(y0, z0-3, r'$$\hat y^0_\theta$$', color='#5499FF',
        ha='right', va='top')

lams = np.linspace(0., 12., 15)
for lam in lams:
    z_ = z + (lam/2)*(y-y0)**2
    ax.plot(y, z_, color='k', alpha=0.2)

# ax.set_title('$$f(y) + {\lambda\over 2}||y-\hat y_0||_2^2$$', size=10)

# ax.set_xlabel('$$y$$')
# ax.xaxis.set_label_coords(.5, 0.01)

fig.tight_layout()
ax.set_xticks([])
ax.set_yticks([])
ax.grid(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
fname = 'imaml.pdf'
plt.savefig(fname, transparent=True)
os.system(f'pdfcrop {fname} {fname}')

--------------------------------------------------------------------------------
/code/figures/fixed-point.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

import jax
import jax.numpy as jnp

import os
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.sans-serif": ["Computer Modern Roman"]})
plt.style.use('bmh')

N = 1000
x = np.linspace(-5., 5.0, N)

fig, ax = plt.subplots(figsize=(2, 1.3), dpi=200)

y = x
ax.plot(x, y, color='k', linestyle='--', alpha=.5)

y = -2.*np.sin(x) + 0.9*x*(1+0.1*np.cos(x))**2
ax.plot(x, y, color='k')

fp = max(x[np.abs(y-x) <= 5e-3])  # Numerically find the fixed-point :)
ax.scatter([0], [0], color='#AA0000', lw=1, s=70, zorder=10, marker='*')
ax.scatter([fp], [fp], color='#AA0000', lw=1, s=70, zorder=10, marker='*')
ax.scatter([-fp], [-fp], color='#AA0000', lw=1, s=70, zorder=10, marker='*')

# ax.set_ylabel('$$g(y)$$', rotation=0, labelpad=0)
# ax.yaxis.set_label_coords(-.07, .44)
# ax.set_xlabel('$$y$$')
# ax.xaxis.set_label_coords(.5, 0.01)

fig.tight_layout()
ax.set_xticks([])
ax.set_yticks([])
ax.grid(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
fname = 'fp.pdf'
plt.savefig(fname, transparent=True)
os.system(f'pdfcrop {fname} {fname}')

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Tutorial on Amortized Optimization
This repository contains the source code for the paper
[Tutorial on amortized optimization for
learning to optimize over continuous domains](https://arxiv.org/abs/2202.00665)
by
[Brandon Amos](http://bamos.github.io).
The main LaTeX source is in [paper](./paper)
and the source code examples are in [code](./code).
The code that generates the following plots is also
in [code/figures](./code/figures):

## [main-example.py](./code/figures/main-example.py)
![](./paper/fig/opt.png?raw=true)

![](./paper/fig/learning-obj.png?raw=true)
![](./paper/fig/learning-reg.png?raw=true)

![](./paper/fig/learning-rl.png?raw=true)

## [maxent-animation.py](./code/figures/maxent-animation.py)
![](./paper/fig/maxent.gif?raw=true)

## [maxent.py](./code/figures/maxent.py)
![](./paper/fig/maxent.png?raw=true)

## [ctrl.py](./code/figures/ctrl.py)
![](./paper/fig/ctrl.png?raw=true)

## [imaml.py](./code/figures/imaml.py)
![](./paper/fig/imaml.png?raw=true)

## [fixed-point.py](./code/figures/fixed-point.py)
![](./paper/fig/fp.png?raw=true)

## [loss-comp.py](./code/figures/loss-comp.py)
![](./paper/fig/loss-comp.png?raw=true)

## [smoothed-loss.py](./code/figures/smoothed-loss.py)
![](./paper/fig/smoothed-loss.png?raw=true)

# Licensing
The source code for this tutorial, plots, and
sphere experiment is licensed under the
[CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/).

--------------------------------------------------------------------------------
/code/figures/smoothed-loss.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

import jax
import jax.numpy as jnp

import os
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.sans-serif": ["Computer Modern Roman"]})
plt.style.use('bmh')

fig, ax = plt.subplots(figsize=(2.5, 1.5), dpi=200)

def f(x):
    return np.cos(x) + 0.2*np.abs(x-np.pi/2)

N = 100
x = np.linspace(-4.*np.pi, 2*np.pi, N)
y = f(x)
ax.plot(x, y, color='k')

sigmas = [1., 1.5, 2.5]
for sigma in sigmas:
    ys = []
    # Inefficiently doing this...
    for xi in x:
        eps = sigma*np.random.randn(50000)
        yi = np.mean(f(xi+eps))
        ys.append(yi)
    ax.plot(x, ys, alpha=1., lw=2)

# ax.set_xlabel(r'$$\theta$$')
# ax.xaxis.set_label_coords(.5, 0.01)
# ax.set_ylabel(r'$${\mathcal L}(\hat y_\theta)$$', rotation=0, labelpad=0)
# ax.yaxis.set_label_coords(-.07, .44)
# ax.set_ylabel('$$y$$', rotation=0, labelpad=0)
# ax.xaxis.set_label_coords(.5, 0.01)

fig.tight_layout()
ax.set_xticks([])
ax.set_yticks([])
ax.grid(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
fname = 'smoothed-loss.pdf'
plt.savefig(fname, transparent=True)
os.system(f'pdfcrop {fname} {fname}')

--------------------------------------------------------------------------------
/code/figures/loss-comp.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

import jax
import jax.numpy as jnp

import os
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.sans-serif": ["Computer Modern Roman"]})
plt.style.use('bmh')

N = 1000
x = np.linspace(-5.0, 5.0, N)
y = np.linspace(-5.0, 5.0, N)
X, Y = np.meshgrid(x, y)
a, b = 0., 10.
Z = X**2 + Y**2 + 1.4*X*Y
Z = 1./(1.+np.exp(-Z/10.))

fig, ax = plt.subplots(figsize=(2, 1.7), dpi=200)
CS = ax.contourf(X, Y, Z, cmap='Purples', alpha=0.8)

Z = X**2 + Y**2
CS = ax.contour(X, Y, Z, colors='k', alpha=.7, linewidths=1, levels=5)

ax.scatter([0], [0], color='#AA0000', lw=1, s=50, zorder=10, marker='*')

ax.set_ylabel('$$y_1$$', rotation=0, labelpad=0)
ax.yaxis.set_label_coords(-.07, .44)
ax.set_xlabel('$$y_0$$')
ax.xaxis.set_label_coords(.5, 0.01)

ax.text(0., 1., r'$$f(y; x)$$', color='#491386',
        bbox=dict(facecolor='white', pad=0, alpha=0.9, edgecolor='none'),
        transform=ax.transAxes, ha='left', va='top')

ax.text(.3, .3, r'$$y^\star(x)$$', color='#AA0000',
        ha='left', va='bottom')

fig.tight_layout()
ax.set_xticks([])
ax.set_yticks([])
ax.grid(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
fname = 'loss-comp.pdf'
plt.savefig(fname, transparent=True)
os.system(f'pdfcrop {fname} {fname}')

--------------------------------------------------------------------------------
/paper/chapters/3-main-table.tex:
--------------------------------------------------------------------------------
% Copyright (c) Meta Platforms, Inc. and affiliates.
% \hspace*{-8mm}
\resizebox{\textwidth}{!}{
\begin{tabular}{ccccccc}
\S & Application & Objective $f$ & Domain $\gY$ & Context Space $\gX$ & Amortization model $\hat y_\theta$ & Loss $\gL$ \\ \toprule
\ref{sec:apps:avi} & VAE & $-\ELBO$ & variational posterior & data & full & $\gL_{\rm obj}$ \\
& SAVAE/IVAE & | & | & | & semi & | \\
\midrule
\ref{sec:apps:lista} & PSD & reconstruction & sparse code & data & full & $\gL_{\rm reg}$ \\
& LISTA & | & | & | & semi & | \\
\midrule
\ref{sec:apps:meta} & HyperNets & task loss & model parameters & tasks & full & $\gL_{\rm obj}$ \\
& LM & | & | & | & semi & $\gL^{\rm RL}_{\rm obj}$ \\
& MAML & | & | & | & | & $\gL_{\rm obj}$ \\
& Neural Potts & pseudo-likelihood & | & protein sequences & full & $\gL_{\rm obj}$ \\
\midrule
\ref{sec:apps:convex} & NeuralFP & FP residual & FP iterates & FP contexts & semi & $\gL_{\rm obj}^\Sigma$ \\
& HyperAA & | & | & | & | & $\gL_{\rm reg}^\Sigma$ \\
& NeuralSCS & CP residual & CP iterates & CP parameters & | & $\gL_{\rm obj}^\Sigma$ \\
& HyperDEQ & DEQ residual & DEQ iterates & DEQ parameters & | & $\gL_{\rm reg}^\Sigma$ \\
& NeuralNMF & NMF residual & factorizations & input matrices & | & $\gL_{\rm obj}^\Sigma$ \\
& RLQP & $R_{\rm RLQP}$ & QP iterates & QP parameters & | & $\gL^{\rm RL}_{\rm obj}$ \\
\midrule
\ref{sec:apps:ot} & Meta OT & dual OT cost & optimal couplings & input measures & full & $\gL_{\rm obj}$ \\
& CondOT & dual OT cost & optimal couplings & contextual information & | & $\gL_{\rm obj}$ \\
& AmorConj & $c$-transform obj & ${\rm supp}(\alpha)$ & ${\rm supp}(\beta)$ & | & $\gL_{\rm obj}$ \\
& $\gA$-SW & max-sliced dist & slices $\Theta$ & mini-batches & | & $\gL_{\rm obj}$ \\
\midrule
\ref{sec:apps:ctrl} & BC/IL & $-Q$-value & controls & state space & full & $\gL_{\rm reg}$ \\
& (D)DPG/TD3 & | & | & | & | & $\gL_{\rm obj}$ \\
& PILCO & | & | & | & | & $\gL_{\rm obj}$ \\
& POPLIN & | & | & | & full or semi & $\gL_{\rm reg}$ \\
& DCEM & | & | & | & semi & $\gL_{\rm reg}$ \\
& IAPO & | & | & | & | & $\gL_{\rm obj}$ \\
& SVG & $\D_\gQ$ or $-\gE_Q$ & control dists & | & full & $\gL_{\rm obj}$ \\
& SAC & | & | & | & | & $\gL_{\rm obj}$ \\
& GPS & | & | & | & | & $\gL_{\rm KL}$ \\
\bottomrule
\end{tabular}}

%%% Local Variables:
%%% coding: utf-8
%%% mode: latex
%%% TeX-master: "../amor-nowplain.tex"
%%% End:

--------------------------------------------------------------------------------
/code/figures/maxent.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

import jax
import jax.numpy as jnp

import os
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.sans-serif": ["Computer Modern Roman"]})
plt.style.use('bmh')

phi = jnp.array([0., .7, 4.])  # Parameters to learn

@jax.jit
def compute_dist(x, phi):
    # Compute values at the discretized points in the domain
    v = jnp.exp(-0.5*(x-phi[0])**2 + phi[1]*jnp.sin(x*phi[2]))
    dx = x[1:] - x[:-1]
    y = v/sum(v[1:]*dx)  # Normalize to be a proper distribution.
    flow_x = flow(x, y)  # Constrain the mean and variance.
    J_flow = jnp.diag(jax.jacfwd(flow)(x, y))
    flow_y = y / J_flow
    return flow_x, flow_y

@jax.jit
def mean(x, y):
    dx = x[1:] - x[:-1]
    x = x[1:]
    y = y[1:]
    return sum(x*y*dx)

@jax.jit
def std(x, y):
    mu = mean(x, y)
    dx = x[1:] - x[:-1]
    x = x[1:]
    y = y[1:]
    return jnp.sqrt(sum(((x-mu)**2)*y*dx))

@jax.jit
def entr(x, y):
    dx = x[1:] - x[:-1]
    y = y[1:]
    return -sum(y*jnp.log(y+1e-8)*dx)

@jax.jit
def flow(x, y):
    # Normalize the domain so that the distribution has
    # zero mean and identity variance.
    return (x - mean(x, y)) / std(x, y)

@jax.jit
def loss(x, phi):
    x, y = compute_dist(x, phi)
    return -entr(x, y)

dloss_dphi = jax.jit(jax.grad(loss, argnums=1))

fig, ax = plt.subplots(figsize=(2, 1.3), dpi=200)

N = 1000  # Number of discretization points in the domain

# The domain of the unprojected distribution
x_unproj = jnp.linspace(-5.0, 5.0, N)

# Plot the initialization
x, y = compute_dist(x_unproj, phi)
ax.plot(x, y, color='k', alpha=0.5)
print(f'entr={entr(x,y):.2f} (mean={mean(x, y):.2f} std={std(x,y):.2f})')

for t in range(20):
    # Take a gradient step with respect to the
    # parameters of the distribution
    phi -= dloss_dphi(x_unproj, phi)
    x, y = compute_dist(x_unproj, phi)
    ax.plot(x, y, color='k', alpha=0.2)
    print(f'entr={entr(x,y):.2f} (mean={mean(x, y):.2f} std={std(x,y):.2f})')

fig.tight_layout()
ax.set_xticks([])
ax.set_yticks([])
ax.grid(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
fname = 'maxent.pdf'
plt.savefig(fname, transparent=True)
os.system(f'pdfcrop {fname} {fname}')
-------------------------------------------------------------------------------- /paper/math_commands.tex: -------------------------------------------------------------------------------- 1 | % Copyright (c) Meta Platforms, Inc. and affiliates. 2 | %%%%% NEW MATH DEFINITIONS %%%%% 3 | 4 | \usepackage{amsmath,amsfonts,bm,mathtools,amssymb} 5 | 6 | \def\ceil#1{\lceil #1 \rceil} 7 | \def\floor#1{\lfloor #1 \rfloor} 8 | \def\1{\bm{1}} 9 | 10 | \def\gA{{\mathcal{A}}} 11 | \def\gB{{\mathcal{B}}} 12 | \def\gC{{\mathcal{C}}} 13 | \def\gD{{\mathcal{D}}} 14 | \def\gE{{\mathcal{E}}} 15 | \def\gF{{\mathcal{F}}} 16 | \def\gG{{\mathcal{G}}} 17 | \def\gH{{\mathcal{H}}} 18 | \def\gI{{\mathcal{I}}} 19 | \def\gJ{{\mathcal{J}}} 20 | \def\gK{{\mathcal{K}}} 21 | \def\gL{{\mathcal{L}}} 22 | \def\gM{{\mathcal{M}}} 23 | \def\gN{{\mathcal{N}}} 24 | \def\gO{{\mathcal{O}}} 25 | \def\gP{{\mathcal{P}}} 26 | \def\gQ{{\mathcal{Q}}} 27 | \def\gR{{\mathcal{R}}} 28 | \def\gS{{\mathcal{S}}} 29 | \def\gT{{\mathcal{T}}} 30 | \def\gU{{\mathcal{U}}} 31 | \def\gV{{\mathcal{V}}} 32 | \def\gW{{\mathcal{W}}} 33 | \def\gX{{\mathcal{X}}} 34 | \def\gY{{\mathcal{Y}}} 35 | \def\gZ{{\mathcal{Z}}} 36 | 37 | \def\sA{{\mathbb{A}}} 38 | \def\sB{{\mathbb{B}}} 39 | \def\sC{{\mathbb{C}}} 40 | \def\sD{{\mathbb{D}}} 41 | % Don't use a set called E, because this would be the same as our symbol 42 | % for expectation. 
43 | \def\sF{{\mathbb{F}}} 44 | \def\sG{{\mathbb{G}}} 45 | \def\sH{{\mathbb{H}}} 46 | \def\sI{{\mathbb{I}}} 47 | \def\sJ{{\mathbb{J}}} 48 | \def\sK{{\mathbb{K}}} 49 | \def\sL{{\mathbb{L}}} 50 | \def\sM{{\mathbb{M}}} 51 | \def\sN{{\mathbb{N}}} 52 | \def\sO{{\mathbb{O}}} 53 | \def\sP{{\mathbb{P}}} 54 | \def\sQ{{\mathbb{Q}}} 55 | \def\sR{{\mathbb{R}}} 56 | \def\sS{{\mathbb{S}}} 57 | \def\sT{{\mathbb{T}}} 58 | \def\sU{{\mathbb{U}}} 59 | \def\sV{{\mathbb{V}}} 60 | \def\sW{{\mathbb{W}}} 61 | \def\sX{{\mathbb{X}}} 62 | \def\sY{{\mathbb{Y}}} 63 | \def\sZ{{\mathbb{Z}}} 64 | 65 | 66 | % \newcommand{\E}{\mathbb{E}} 67 | \DeclareMathOperator*{\E}{\mathbb{E}} 68 | \DeclareMathOperator{\HH}{\mathbb{H}} 69 | \DeclareMathOperator{\Var}{\rm{Var}} 70 | \newcommand{\Ls}{\mathcal{L}} 71 | \newcommand{\R}{\mathbb{R}} 72 | \newcommand{\D}{\mathrm{D}} 73 | 74 | \DeclareMathOperator*{\argmax}{arg\,max} 75 | \DeclareMathOperator*{\argmin}{arg\,min} 76 | \DeclareMathOperator*{\arginf}{arg\,inf} 77 | \DeclareMathOperator*{\argsup}{arg\,sup} 78 | \DeclareMathOperator*{\minimize}{minimize} 79 | \DeclareMathOperator*{\maximize}{maximize} 80 | \DeclareMathOperator*{\subjectto}{subject\;to} 81 | \DeclareMathOperator*{\st}{s.t.} 82 | 83 | \DeclarePairedDelimiterX{\infdivx}[2]{(}{)}{% 84 | #1\;\delimsize|\delimsize|\;#2% 85 | } 86 | \newcommand{\kl}[2]{\ensuremath{{\rm D}_{\rm KL}\infdivx{#1}{#2}}\xspace} 87 | \newcommand{\dist}[2]{\ensuremath{{\rm D}\infdivx{#1}{#2}}\xspace} 88 | 89 | \DeclareMathOperator{\sign}{sign} 90 | \DeclareMathOperator{\Tr}{Tr} 91 | \DeclareMathOperator{\ELBO}{ELBO} 92 | \let\ab\allowbreak 93 | 94 | \newcommand{\ftrans}{\ensuremath{f^{\mathrm{trans}}}} 95 | \newcommand{\fodec}{\ensuremath{f^{\mathrm{odec}}}} 96 | \newcommand{\frew}{\ensuremath{f^{\mathrm{rew}}}} 97 | \newcommand{\fdec}{\ensuremath{f^{\mathrm{dec}}}} 98 | 99 | 100 | \newcommand{\xinit}{{x_{\rm init}}} 101 | \newcommand{\uinit}{{u_{\rm init}}} 102 | \newcommand{\piold}{{\pi_{\theta_{\rm old}}}} 103 | 104 | 
\newcommand{\defeq}{\vcentcolon=} 105 | \newcommand{\eqdef}{=\vcentcolon} 106 | 107 | \newcommand{\gradupdate}{\ensuremath{\mathrm{grad\_update}}} 108 | 109 | \newcommand{\stopgrad}[2]{ \underset{\tiny\mathbf{stop}(#2)}{\left\llbracket{}#1\right\rrbracket{}} } 110 | -------------------------------------------------------------------------------- /code/figures/ctrl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from matplotlib import cm 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | import os 12 | plt.rcParams.update({ 13 | "text.usetex": True, 14 | "font.family": "serif", 15 | "font.sans-serif": ["Computer Modern Roman"]}) 16 | plt.style.use('bmh') 17 | 18 | # Initial problem with x^\star 19 | N = 1000 20 | x = np.linspace(-3.8, 5.0, N) 21 | # y = -(x**2)*np.sin(x) 22 | # y = (np.cos(x)**2) #- np.abs(x) 23 | y = -x**2 + 7*np.sin(x+np.pi) 24 | y = y/8. 
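As a brief usage sketch, the macros from math_commands.tex above compose as follows in a document body (a hypothetical equation, assuming the file is `\input` in the preamble as amor.tex does):

```latex
% Hypothetical equation using \argmin, \gY, \kl, \E, and \defeq
% from math_commands.tex.
\begin{equation}
  y^\star(x) \in \argmin_{y \in \gY} f(y; x),
  \qquad
  \kl{p_\theta}{p^\star}
  \defeq \E_{x \sim p_\theta}\left[\log\frac{p_\theta(x)}{p^\star(x)}\right].
\end{equation}
```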
25 | 26 | nrow, ncol = 1, 2 27 | fig, axs = plt.subplots(nrow, ncol, figsize=(ncol*2.5,nrow*1.1), dpi=200) 28 | 29 | ax = axs[0] 30 | ax.axhline(0, color='k') 31 | ax.plot(x, y-y.min(), color='k') 32 | ustar = x[y.argmax()] 33 | ax.axvline(ustar, color='#AA0000') 34 | ax.text(0.13, 0.09, r'$$\pi^\star(x)$$', color='#AA0000', 35 | transform=ax.transAxes, ha='left', va='bottom') 36 | 37 | ax.axvline(ustar+2, color='#5499FF') 38 | ax.text(0.54, 0.09, r'$$\pi_\theta(x)$$', color='#5499FF', 39 | transform=ax.transAxes, ha='left', va='bottom') 40 | 41 | ax.arrow(x=ustar+2, y=0.5, dx=-0.5, dy=0., 42 | width=0.1, color='#5499FF', zorder=10) 43 | 44 | ax.text(0.7, 0.44, '$$Q(x, u)$$', 45 | transform=ax.transAxes, ha='left', va='bottom') 46 | 47 | ax.set_xlabel('$$u$$') 48 | ax.xaxis.set_label_coords(.5, 0.01) 49 | ax.set_title('Deterministic Policy', fontsize=12, pad=5) 50 | # ax.set_ylabel('$$Q(x, u)$$', rotation=0, labelpad=0) 51 | # ax.yaxis.set_label_coords(-.1, .44) 52 | 53 | 54 | ax = axs[1] 55 | y = np.exp(y) 56 | y -= y.min() 57 | ax.plot(x, y, color='k') #, zorder=10) 58 | ax.set_xlabel('$$u$$') 59 | ax.xaxis.set_label_coords(.5, 0.01) 60 | 61 | mu, sigma = ustar, 0.8 62 | ystar = np.exp(-.5*((x-mu)/sigma)**2) #/ (sigma*np.sqrt(2.*np.pi)) 63 | ystar = ystar * y.sum() / ystar.sum() 64 | ax.plot(x, ystar, color='#AA0000') 65 | 66 | mu, sigma = ustar+2, 1.5 67 | yhat = np.exp(-.5*((x-mu)/sigma)**2) #/ (sigma*np.sqrt(2.*np.pi)) 68 | yhat = yhat * y.sum() / yhat.sum() 69 | ax.plot(x, yhat, color='#5499FF') 70 | 71 | # I = [250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800] 72 | I = [250, 300, 350, 400, 650, 700, 750, 800] 73 | for i in I: 74 | ax.arrow(x=x[i], y=yhat[i], dx=-0.5, dy=0., 75 | width=0.05, 76 | color='#5499FF', 77 | zorder=10) 78 | 79 | ax.text(0.37, 0.74, r'$$\pi^\star(x)$$', color='#AA0000', 80 | transform=ax.transAxes, ha='left', va='bottom') 81 | ax.text(0.6, 0.45, r'$$\pi_\theta(x)$$', color='#5499FF', 82 | transform=ax.transAxes, ha='left',
va='bottom') 83 | ax.text(0., 0.43, r'$$\mathcal{Q}(x, u)$$', 84 | transform=ax.transAxes, ha='left', va='bottom') 85 | ax.axhline(0., color='k',zorder=-1) 86 | 87 | # ax.set_ylabel('$${\mathcal{Q}}(x, u)$$', rotation=0, labelpad=0) 88 | # ax.yaxis.set_label_coords(-.1, .44) 89 | 90 | fig.tight_layout() 91 | for ax in axs: 92 | ax.set_xticks([]) 93 | ax.set_yticks([]) 94 | ax.grid(False) 95 | ax.spines['top'].set_visible(False) 96 | ax.spines['right'].set_visible(False) 97 | ax.spines['bottom'].set_visible(False) 98 | ax.spines['left'].set_visible(False) 99 | 100 | ax.set_title('Stochastic Policy', fontsize=12, pad=5) 101 | 102 | fname = 'ctrl.pdf' 103 | plt.savefig(fname, transparent=True) 104 | os.system(f'pdfcrop {fname} {fname}') 105 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /code/evaluate_amortization_speed_vae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torchvision.utils import make_grid, save_image 7 | import numpy as np 8 | 9 | import argparse 10 | import os 11 | import sys 12 | 13 | sys.path.append('vae_submodule') 14 | from utils.helpers import FormatterNoDuplicate, check_bounds, set_seed 15 | from utils.visualize import Visualizer 16 | from utils.viz_helpers import get_samples 17 | from disvae.utils.modelIO import load_model, load_metadata 18 | from disvae.models.losses import get_loss_f 19 | 20 | from evaluate_amortization_speed_function import evaluate_amortization_speed 21 | 22 | 23 | from IPython.core import ultratb 24 | sys.excepthook = ultratb.FormattedTB( 25 | mode='Plain', color_scheme='Neutral', call_pdb=1) 26 | 27 | 28 | def sample_gaussian(mean, logvar): 29 | std = torch.exp(0.5 * logvar) 30 | eps = torch.randn_like(std) 31 | return mean + std * eps 32 | 33 | def unflatten_latent(z_flat): 34 | n = z_flat.shape[-1] 35 | return z_flat[...,:n//2], z_flat[...,n//2:] 36 | 37 | 38 | def estimate_elbo(x, z_flat, decoder): 39 | latent_dist = unflatten_latent(z_flat) 40 | 41 | latent_sample = sample_gaussian(*latent_dist) 42 | 43 | recon_batch = decoder(latent_sample) 44 | batch_size = x.shape[0] 45 | log_likelihood = -F.binary_cross_entropy(recon_batch, x, reduction='none').sum(dim=[1,2,3]) 46 | 47 | mean, logvar = latent_dist 48 | latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()) 49 | kl_to_prior = latent_kl.sum(dim=[-1]) 50 | 51 | assert log_likelihood.shape == kl_to_prior.shape 52 | elbo = log_likelihood - kl_to_prior 53 | return elbo 54 | 55 | 56 | def main(): 57 | model_dir = 'vae_submodule/results/VAE_mnist' 58 | meta_data = load_metadata(model_dir) 59 | model = load_model(model_dir).cuda() 60 | model.eval() # don't sample from latent: use mean 61 | dataset = meta_data['dataset'] 62 | loss_f = get_loss_f('VAE', 63 | n_data=len(dataset), 64 | device='cuda', 65 | rec_dist='bernoulli',
66 | reg_anneal=0) 67 | 68 | batch_size = 1024 69 | num_save = 15 70 | data_samples = get_samples(dataset, batch_size, idcs=[25518, 13361, 22622]).cuda() 71 | 72 | def amortization_model(data_samples): 73 | latent_dist = model.encoder(data_samples) 74 | latent_dist_flat = torch.cat(latent_dist, dim=-1) 75 | return latent_dist_flat 76 | 77 | def amortization_objective(latent_dist_flat, data_samples): 78 | elbo = estimate_elbo(data_samples, latent_dist_flat, model.decoder) 79 | return elbo 80 | 81 | iterate_history, predicted_samples = evaluate_amortization_speed( 82 | amortization_model=amortization_model, 83 | amortization_objective=amortization_objective, 84 | contexts=data_samples, 85 | tag='vae', 86 | fig_ylabel='ELBO', 87 | adam_lr=5e-3, 88 | num_iterations=2000, 89 | maximize=True, 90 | save_iterates=[0, 250, 500, 1000, 2000], 91 | num_save=num_save, 92 | ) 93 | 94 | iterate_history.append((-1, predicted_samples[:num_save])) 95 | 96 | reconstructions = [] 97 | for i, latent_dist_flat in iterate_history: 98 | latent_dist = unflatten_latent(latent_dist_flat) 99 | latent_mean = latent_dist[0] 100 | reconstructions.append(1.-model.decoder(latent_mean)) 101 | 102 | reconstructions.append(1.-data_samples[:num_save]) 103 | reconstructions = torch.cat(reconstructions, dim=0) 104 | reconstructions = F.interpolate(reconstructions, 105 | recompute_scale_factor=True, scale_factor=1.5, mode='bilinear') 106 | 107 | fname = f'vae-samples.png' 108 | save_image(reconstructions, fname, nrow=num_save) 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /code/evaluate_amortization_speed_function.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
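The `sample_gaussian` helper in the VAE script above is the reparameterization trick: a sample from N(mean, σ²) is written as mean + σ·ε with ε ~ N(0, I), so gradients can flow through mean and logvar. A minimal NumPy sketch of the same identity (the torch version above is the one actually used; this is only an illustrative check):

```python
import numpy as np

def sample_gaussian(mean, logvar, eps):
    # Reparameterization: z = mean + std * eps with std = exp(0.5 * logvar).
    return mean + np.exp(0.5 * logvar) * eps

rng = np.random.default_rng(0)
mean, logvar = 2.0, np.log(0.25)        # std = 0.5
eps = rng.standard_normal(100_000)
z = sample_gaussian(mean, logvar, eps)
print(z.mean(), z.std())                # close to 2.0 and 0.5
```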
2 | 3 | import torch 4 | import numpy as np 5 | 6 | import matplotlib.pyplot as plt 7 | plt.style.use('bmh') 8 | params = { 9 | "text.usetex" : True, 10 | "font.family" : "serif", 11 | "font.serif" : ["Computer Modern Serif"] 12 | } 13 | plt.rcParams.update(params) 14 | 15 | import os 16 | import time 17 | 18 | def evaluate_amortization_speed( 19 | amortization_model, 20 | amortization_objective, 21 | contexts, 22 | tag, 23 | fig_ylabel, 24 | adam_lr=5e-3, 25 | num_iterations=2000, 26 | maximize=False, 27 | iter_history_callback=None, 28 | save_iterates=[], 29 | num_save=8, 30 | ): 31 | times = [] 32 | n_trials = 10 33 | for i in range(n_trials+1): 34 | start_time = time.time() 35 | predicted_solutions = amortization_model(contexts) 36 | if i > 0: 37 | times.append(time.time()-start_time) 38 | 39 | amortized_objectives = amortization_objective( 40 | predicted_solutions, contexts 41 | ).cpu().detach() 42 | print(f'solution size: {predicted_solutions.shape[1]}') 43 | print('--- amortization model') 44 | print(f'average objective value: {amortized_objectives.mean():.2f}') 45 | print(f'average runtime: {np.mean(times)*1000:.2f}ms') 46 | 47 | iterates = torch.nn.Parameter(torch.zeros_like(predicted_solutions)) 48 | 49 | opt = torch.optim.Adam([iterates], lr=adam_lr) 50 | 51 | objective_history = [] 52 | times = [] 53 | iterations = [] 54 | iterate_history = [] 55 | 56 | start_time = time.time() 57 | 58 | for i in range(num_iterations+1): 59 | objectives = amortization_objective(iterates, contexts) 60 | mean_objective = objectives.mean() 61 | if maximize: 62 | mean_objective *= -1. 
63 | opt.zero_grad() 64 | mean_objective.backward() 65 | opt.step() 66 | 67 | if i % 50 == 0: 68 | iterations.append(i) 69 | times.append(time.time()-start_time) 70 | objective_history.append((objectives.mean().item(), objectives.std().item())) 71 | print(i, objectives.mean().item()) 72 | 73 | if i in save_iterates: 74 | iterate_history.append((i, iterates[:num_save].detach().clone())) 75 | 76 | 77 | times = np.array(times) 78 | 79 | figsize = (4,2) 80 | fig, ax = plt.subplots(figsize=figsize, dpi=200) 81 | objective_means, objective_stds = map(np.array, zip(*objective_history)) 82 | 83 | l, = ax.plot(iterations, objective_means) 84 | ax.axhline(amortized_objectives.mean().cpu().detach(), color='k', linestyle='--') 85 | ax.axhspan(amortized_objectives.mean()-amortized_objectives.std(), 86 | amortized_objectives.mean()+amortized_objectives.std(), color='k', alpha=0.15) 87 | ax.fill_between( 88 | iterations, objective_means-objective_stds, objective_means+objective_stds, 89 | color=l.get_color(), alpha=0.5) 90 | ax.set_xlabel('Adam Iterations') 91 | ax.set_ylabel(fig_ylabel) 92 | ax.set_xlim(0, max(iterations)) 93 | # ax.set_ylim(0, 1000) 94 | fig.tight_layout() 95 | fname = f'{tag}-iter.pdf' 96 | print(f'saving to {fname}') 97 | fig.savefig(fname, transparent=True) 98 | os.system(f'pdfcrop {fname} {fname}') 99 | 100 | fig, ax = plt.subplots(figsize=figsize, dpi=200) 101 | ax.axhline(amortized_objectives.mean(), color='k', linestyle='--') 102 | ax.axhspan(amortized_objectives.mean()-amortized_objectives.std(), 103 | amortized_objectives.mean()+amortized_objectives.std(), color='k', alpha=0.15) 104 | l, = ax.plot(times, objective_means) 105 | ax.fill_between( 106 | times, objective_means-objective_stds, objective_means+objective_stds, 107 | color=l.get_color(), alpha=0.5) 108 | ax.set_xlim(0, max(times)) 109 | # ax.set_ylim(0, 1000) 110 | ax.set_xlabel('Runtime (seconds)') 111 | ax.set_ylabel(fig_ylabel) 112 | fig.tight_layout() 113 | 114 | fname = f'{tag}-time.pdf' 
115 | print(f'saving to {fname}') 116 | fig.savefig(fname, transparent=True) 117 | os.system(f'pdfcrop {fname} {fname}') 118 | 119 | return iterate_history, predicted_solutions 120 | 121 | -------------------------------------------------------------------------------- /code/figures/maxent-animation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from matplotlib import cm 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | import shutil 12 | import os 13 | 14 | plt.rcParams.update({ 15 | "text.usetex": True, 16 | "font.family": "serif", 17 | "font.sans-serif": ["Computer Modern Roman"], 18 | "text.latex.preamble": r"\usepackage{amsfonts}", 19 | }) 20 | 21 | 22 | # We will define a 1D density parameterized by \phi to maximize over 23 | # with gradient steps, and implement this by discretizing over the domain. 24 | phi = jnp.array([0., 1.5, .7, 6.]) 25 | 26 | @jax.jit 27 | def compute_dist(x, phi): 28 | # Compute values of at the discretized points in the domain. 29 | v = jnp.exp(-0.5*((x-phi[0])/phi[1])**2 + phi[2]*jnp.sin(x*phi[3])) 30 | dx = x[1:]-x[:-1] 31 | y = v/sum(v[1:]*dx) # Normalize to be a proper distribution. 32 | flow_x = flow(x, y) # Constrain the mean and variance. 33 | 34 | # Compute the new probabilities. 
35 | J_flow = jnp.diag(jax.jacfwd(flow)(x, y)) 36 | flow_y = y / J_flow 37 | return flow_x, flow_y 38 | 39 | @jax.jit 40 | def mean(x, y): 41 | dx = x[1:]-x[:-1] 42 | x = x[1:] 43 | y = y[1:] 44 | return sum(x*y*dx) 45 | 46 | @jax.jit 47 | def std(x, y): 48 | mu = mean(x,y) 49 | dx = x[1:]-x[:-1] 50 | x = x[1:] 51 | y = y[1:] 52 | return jnp.sqrt(sum(((x-mu)**2)*y*dx)) 53 | 54 | @jax.jit 55 | def entr(x, y): 56 | dx = x[1:]-x[:-1] 57 | y = y[1:] 58 | return -sum(y*jnp.log(y+1e-8)*dx) 59 | 60 | @jax.jit 61 | def flow(x, y): 62 | # Normalize the domain so that the distribution has 63 | # zero mean and identity variance. 64 | return (x - mean(x,y)) / std(x, y) 65 | 66 | @jax.jit 67 | def loss(x, phi): 68 | x, y = compute_dist(x, phi) 69 | return -entr(x, y) 70 | 71 | 72 | # Prepare the output directory 73 | d = 'maxent-animation' 74 | if os.path.exists(d): 75 | shutil.rmtree(d) 76 | os.makedirs(d) 77 | 78 | 79 | def plot(t): 80 | nrow, ncol = 1, 2 81 | fig, axs = plt.subplots(nrow, ncol, figsize=(ncol*3, nrow*2), dpi=200, 82 | gridspec_kw={'wspace': .3, 'hspace': 0}) 83 | ax = axs[0] 84 | ax.plot(entrs, color='k') 85 | ax.set_xlabel('Updates', fontsize=10) 86 | ax.set_title(r'Entropy ($\mathbb{H}_p[X]$)', fontsize=10) 87 | ax.set_xlim(0, n_step) 88 | ax.set_ylim(1.3, 1.45) 89 | 90 | ax = axs[1] 91 | ax.plot(x, y, color='k') 92 | ax.set_ylim(0, 0.7) 93 | ax.set_xlim(-3, 3) 94 | ax.set_xlabel('$x$', fontsize=10) 95 | ax.set_title('$p(x)$', fontsize=10) 96 | 97 | for ax in axs: 98 | ax.grid(False) 99 | ax.spines['top'].set_visible(False) 100 | ax.spines['right'].set_visible(False) 101 | ax.spines['bottom'].set_visible(False) 102 | ax.spines['left'].set_visible(False) 103 | 104 | fig.suptitle( 105 | r'$\max_{p} \mathbb{H}_p[X]\; \rm{subject\; to}\; \mathbb{E}_p[X] = \mu\;\rm{and}\;\rm{Var}_p[X]=\Sigma$') 106 | fig.subplots_adjust(top=0.7) 107 | 108 | fname = f'{d}/{t:04d}.png' 109 | plt.savefig(fname, bbox_inches='tight') 110 | plt.close(fig) 111 | # os.system(f'pdfcrop 
{fname} {fname}') 112 | os.system(f'convert -trim {fname} {fname}') 113 | 114 | 115 | # jitted derivative of the loss with respect to phi 116 | dloss_dphi = jax.jit(jax.grad(loss, argnums=1)) 117 | 118 | # Number of discretization points in the domain 119 | # Decrease this to run faster 120 | N = 1000 121 | 122 | # The domain of the unprojected distribution 123 | x_unproj = jnp.linspace(-5.0, 5.0, N) 124 | 125 | entrs = [] 126 | x, y = compute_dist(x_unproj, phi) 127 | entrs.append(entr(x,y)) 128 | print(f'entr={entr(x,y):.2f} (mean={mean(x, y):.2f} std={std(x,y):.2f})') 129 | 130 | # The step size can be much larger but it's set to this for the animation. 131 | n_step = 100 132 | step_size = 0.13 133 | for t in range(n_step): 134 | # Take a gradient step with respect to the 135 | # parameters of the distribution 136 | phi -= step_size*dloss_dphi(x_unproj, phi) 137 | x, y = compute_dist(x_unproj, phi) 138 | entrs.append(entr(x,y)) 139 | print(f'entr={entr(x,y):.2f} (mean={mean(x, y):.2f} std={std(x,y):.2f})') 140 | 141 | plot(t) 142 | 143 | # By the end, we see that the entropy is the true maximal entropy 144 | # of the Gaussian of (1/2)log(2\pi)+(1/2) \approx 1.42. 145 | 146 | os.system(f'convert -delay 10 -loop 0 {d}/*.png {d}/maxent.gif') 147 | -------------------------------------------------------------------------------- /code/train-sphere.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
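`compute_dist` in the animation above rescales the density through the flow z = (x − μ)/σ via the one-dimensional change-of-variables formula p_z(z) = p_x(x)/|f′(x)|, which is what the `y / J_flow` division implements. A standalone NumPy check with a linear flow (independent of the JAX code):

```python
import math

import numpy as np

mu, sigma = 1.0, 2.0
x = np.linspace(-9.0, 11.0, 2001)
p_x = np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

# Flow z = (x - mu) / sigma has Jacobian f'(x) = 1/sigma, so
# p_z(z) = p_x(x) / |f'(x)| = sigma * p_x(x).
z = (x - mu) / sigma
p_z = p_x / (1.0 / sigma)

# p_z should now be the standard normal density evaluated at z.
p_std = np.exp(-0.5 * z**2) / math.sqrt(2 * math.pi)
err = float(np.max(np.abs(p_z - p_std)))
print(err)
```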
3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | import os 8 | 9 | import matplotlib.pyplot as plt 10 | plt.style.use('bmh') 11 | 12 | import sys 13 | from IPython.core import ultratb 14 | sys.excepthook = ultratb.FormattedTB( 15 | mode='Plain', color_scheme='Neutral', call_pdb=1) 16 | 17 | def celestial_to_euclidean(ra, dec): 18 | x = np.cos(dec)*np.cos(ra) 19 | y = np.cos(dec)*np.sin(ra) 20 | z = np.sin(dec) 21 | return x, y, z 22 | 23 | def euclidean_to_celestial(x, y, z): 24 | sindec = z 25 | cosdec = np.sqrt(x*x + y*y) 26 | sinra = y / cosdec 27 | cosra = x / cosdec 28 | ra = np.arctan2(sinra, cosra) 29 | dec = np.arctan2(sindec, cosdec) 30 | return ra, dec 31 | 32 | def euclidean_to_celestial_th(x, y, z): 33 | sindec = z 34 | cosdec = (x*x + y*y).sqrt() 35 | sinra = y / cosdec 36 | cosra = x / cosdec 37 | ra = torch.atan2(sinra, cosra) 38 | dec = torch.atan2(sindec, cosdec) 39 | return ra, dec 40 | 41 | 42 | def sphere_dist_th(x,y): 43 | if x.ndim == 1: 44 | x = x.unsqueeze(0) 45 | if y.ndim == 1: 46 | y = y.unsqueeze(0) 47 | assert x.ndim == y.ndim == 2 48 | inner = (x*y).sum(-1) 49 | return torch.arccos(inner.clamp(-1., 1.)) # clamp to avoid NaNs from rounding 50 | 51 | class c_convex(nn.Module): 52 | def __init__(self, n_components=4, gamma=0.5, seed=None): 53 | super().__init__() 54 | self.n_components = n_components 55 | self.gamma = gamma 56 | 57 | # Sample a random c-convex function 58 | if seed is not None: 59 | torch.manual_seed(seed) 60 | self.ys = torch.randn(n_components, 3) 61 | self.ys = self.ys / torch.norm(self.ys, 2, dim=-1, keepdim=True) 62 | self.alphas = .7*torch.rand(self.n_components) 63 | self.params = torch.cat((self.ys.view(-1), self.alphas.view(-1))) 64 | 65 | def forward(self, xyz): 66 | # TODO: Could be optimized 67 | cs = [] 68 | for y, alpha in zip(self.ys, self.alphas): 69 | ci = 0.5*sphere_dist_th(y, xyz)**2 + alpha 70 | cs.append(ci) 71 | cs = torch.stack(cs) 72 | if self.gamma is None or self.gamma == 0.: 73 | z = cs.min(dim=0).values 74 | else: 75 | z
= -self.gamma*(-cs/self.gamma).logsumexp(dim=0) 76 | return z 77 | 78 | 79 | seeds = [8,9,2,31,4,20,16,7] 80 | fs = [c_convex(seed=i) for i in seeds] 81 | n_params = len(fs[0].params) 82 | 83 | class AmortizedModel(nn.Module): 84 | def __init__(self, n_params): 85 | super().__init__() 86 | self.base = nn.Sequential( 87 | nn.Linear(n_params, n_hidden), 88 | nn.ReLU(inplace=True), 89 | nn.Linear(n_hidden, n_hidden), 90 | nn.ReLU(inplace=True), 91 | nn.Linear(n_hidden, 3) 92 | ) 93 | 94 | def forward(self, p): 95 | squeeze = p.ndim == 1 96 | if squeeze: 97 | p = p.unsqueeze(0) 98 | assert p.ndim == 2 99 | z = self.base(p) 100 | z = z / z.norm(dim=-1, keepdim=True) 101 | if squeeze: 102 | z = z.squeeze(0) 103 | return z 104 | 105 | n_hidden = 128 106 | torch.manual_seed(0) 107 | model = AmortizedModel(n_params=n_params) 108 | opt = torch.optim.Adam(model.parameters(), lr=5e-4) 109 | 110 | xs = [] 111 | for i in range(100): 112 | losses = [] 113 | xis = [] 114 | for f in fs: 115 | pred_opt = model(f.params) 116 | xis.append(pred_opt) 117 | losses.append(f(pred_opt)) 118 | with torch.no_grad(): 119 | xis = torch.stack(xis) 120 | xs.append(xis) 121 | loss = sum(losses) 122 | 123 | opt.zero_grad() 124 | loss.backward() 125 | opt.step() 126 | 127 | xs = torch.stack(xs, dim=1) 128 | 129 | pad = .1 130 | n_sample = 100 131 | ra = np.linspace(-np.pi+pad, np.pi-pad, n_sample) 132 | dec= np.linspace(-np.pi/2+pad, np.pi/2-pad, n_sample) 133 | ra_grid, dec_grid = np.meshgrid(ra,dec) 134 | ra_grid_flat = ra_grid.ravel() 135 | dec_grid_flat = dec_grid.ravel() 136 | x_grid, y_grid, z_grid = celestial_to_euclidean(ra_grid_flat, dec_grid_flat) 137 | 138 | p_grid = np.stack((x_grid, y_grid, z_grid), axis=-1) 139 | p_grid_th = torch.from_numpy(p_grid).float() 140 | 141 | 142 | for i, (f, xs_i) in enumerate(zip(fs, xs)): 143 | nrow, ncol = 1, 1 144 | fig, ax = plt.subplots( 145 | nrow, ncol, figsize=(3*ncol, 2*nrow), 146 | subplot_kw={'projection': 'mollweide'}, 147 | gridspec_kw = 
{'wspace':0, 'hspace':0} 148 | ) 149 | 150 | with torch.no_grad(): 151 | f_grid = f(p_grid_th).numpy() 152 | best_i = f_grid.argmin() 153 | ra_opt, dec_opt= ra_grid_flat[best_i], dec_grid_flat[best_i] 154 | 155 | f_grid = f_grid.reshape(ra_grid.shape) 156 | n_levels = 10 157 | ax.contourf(ra_grid, dec_grid, f_grid, n_levels, cmap='Purples') 158 | 159 | x,y,z = xs_i.split(1,dim=-1) 160 | ra, dec = euclidean_to_celestial_th(x,y,z) 161 | ax.plot(ra, dec, color='#5499FF', lw=3, ls=':') 162 | 163 | ax.scatter(ra_opt, dec_opt, marker='*', color='#AA0000', 164 | s=100, zorder=10) 165 | 166 | for s in ax.spines.values(): 167 | s.set_visible(False) 168 | 169 | ax.set_xticks([]) 170 | ax.set_yticks([]) 171 | ax.grid(False) 172 | 173 | fname = f'paper/fig/sphere/{i}.png' 174 | plt.savefig(fname, transparent=True) 175 | os.system(f'convert -trim {fname} {fname}') 176 | -------------------------------------------------------------------------------- /paper/amor.tex: -------------------------------------------------------------------------------- 1 | \documentclass[oneside,11pt]{book} 2 | \pagestyle{plain} 3 | \usepackage[Lenny]{fncychap} 4 | \ChNameVar{\Large} 5 | 6 | \newcommand{\BlackBox}{\rule{1.5ex}{1.5ex}} % end of proof 7 | \newenvironment{proof}{\par\noindent{\bf Proof\ }}{\hfill\BlackBox\\[2mm]} 8 | \newtheorem{theorem}{Theorem} 9 | \newtheorem{proposition}{Proposition} 10 | \newtheorem{remark}{Remark} 11 | \newtheorem{definition}{Definition} 12 | 13 | \usepackage[margin=1.5in]{geometry} 14 | 15 | \usepackage{natbib} 16 | \bibliographystyle{plainnat} 17 | 18 | \usepackage{xspace} 19 | 20 | \newcommand{\ie}{\emph{i.e.}\xspace} 21 | \newcommand{\eg}{\emph{e.g.}\xspace} 22 | \newcommand{\etc}{\emph{etc}\xspace} 23 | \newcommand{\now}{\textsc{now}} 24 | 25 | \input{math_commands.tex} 26 | 27 | \usepackage[utf8]{inputenc} % allow utf-8 input 28 | \usepackage[T1]{fontenc} % use 8-bit T1 fonts 29 | \DeclareUnicodeCharacter{1EF3}{\`y} 30 | \DeclareUnicodeCharacter{0301}{\'{e}} 
31 | 32 | \usepackage{url} % simple URL typesetting 33 | \usepackage{booktabs} % professional-quality tables 34 | \usepackage{amsfonts} % blackboard math symbols 35 | \usepackage{nicefrac} % compact symbols for 1/2, etc. 36 | \usepackage{microtype} % microtypography 37 | \usepackage{enumitem} 38 | \usepackage{slantsc} 39 | \usepackage{caption} 40 | 41 | \renewcommand\eminnershape{\itshape\scshape} 42 | 43 | \usepackage{xcolor} 44 | \definecolor{linkcolor}{RGB}{74, 102, 146} 45 | \definecolor{lightpurple}{RGB}{168, 141, 201} 46 | 47 | \usepackage[ 48 | colorlinks=true,allcolors=linkcolor,pageanchor=true, 49 | plainpages=false,pdfpagelabels,bookmarks,bookmarksnumbered, 50 | backref=page, 51 | pdfauthor={Brandon Amos}, 52 | pdftitle={Tutorial on amortized optimization}, 53 | ]{hyperref} 54 | 55 | \renewcommand*{\backref}[1]{} 56 | \renewcommand*{\backrefalt}[4]{% 57 | \ifcase #1 \or (Cited on page~#2.) 58 | \else (Cited on pages~#2.) 59 | \fi% 60 | } 61 | 62 | 63 | \usepackage[nameinlink]{cleveref} 64 | \Crefname{equation}{Eq.}{Eqs.} 65 | 66 | \newcommand\pro{\item[$+$]} 67 | \newcommand\con{\item[$-$]} 68 | 69 | \usepackage{caption} 70 | \usepackage{subcaption} 71 | \usepackage{wrapfig} 72 | \usepackage{tikz} 73 | \usetikzlibrary{arrows.meta, arrows, backgrounds, bayesnet, calc, matrix} 74 | 75 | \usepackage{listings} 76 | \definecolor{code_green}{rgb}{0,0.6,0} 77 | \definecolor{code_gray}{rgb}{0.5,0.5,0.5} 78 | \definecolor{code_purple}{rgb}{.5, .21, .68} 79 | 80 | \lstset{ % 81 | backgroundcolor=\color{white}, 82 | basicstyle=\footnotesize\ttfamily, 83 | breaklines=true, 84 | captionpos=b, 85 | commentstyle=\color{code_green}, 86 | escapeinside={\%*}{*)}, 87 | keywordstyle=\color{blue}, 88 | stringstyle=\color{code_purple}, 89 | language=Python, 90 | columns=fullflexible, 91 | numbers=left, 92 | xleftmargin=10mm, 93 | } 94 | 95 | % \lstset{basicstyle=\footnotesize\ttfamily, columns=fullflexible, language=Python, 96 | % numbers=left} 97 | 98 | \setlength\tabcolsep{4 
pt} 99 | 100 | \newcommand{\cblock}[3]{ 101 | \hspace{-1.5mm} 102 | \begin{tikzpicture} 103 | [ 104 | node/.style={square, minimum size=10mm, thick, line width=0pt}, 105 | ] 106 | \node[fill={rgb,255:red,#1;green,#2;blue,#3}] () [] {}; 107 | \end{tikzpicture}% 108 | } 109 | 110 | \begin{document} 111 | 112 | \begin{titlepage} 113 | \thispagestyle{empty} 114 | \begin{center} 115 | \textbf{\Large Tutorial on amortized optimization} \\ 116 | {\large Learning to optimize over continuous spaces} \\~\\ 117 | Brandon Amos, \emph{Meta AI} 118 | \end{center} 119 | 120 | \vspace{0.9cm} 121 | \noindent\textbf{Abstract.} \\ 122 | Optimization is a ubiquitous modeling tool and is often 123 | deployed in settings which repeatedly solve similar 124 | instances of the same problem. 125 | Amortized optimization methods use learning to predict the solutions to 126 | problems in these settings, exploiting the shared structure 127 | between similar problem instances. 128 | These methods have been crucial in variational inference 129 | and reinforcement learning and are capable of solving 130 | optimization problems many orders of magnitude faster 131 | than traditional optimization methods that do not use amortization. 132 | This tutorial presents an introduction to the amortized optimization 133 | foundations behind these advancements and overviews 134 | their applications in variational inference, sparse coding, 135 | gradient-based meta-learning, control, reinforcement learning, 136 | convex optimization, optimal transport, and deep equilibrium networks. 137 | The source code for this tutorial is available at 138 | {\footnotesize\url{https://github.com/facebookresearch/amortized-optimization-tutorial}}. 
139 | \end{titlepage} 140 | 141 | \setcounter{tocdepth}{1} 142 | \tableofcontents 143 | 144 | \include{chapters/1-intro} 145 | \include{chapters/2-foundations} 146 | \include{chapters/3-applications} 147 | \include{chapters/4-implementation} 148 | \include{chapters/5-discussion} 149 | 150 | \newpage 151 | \section*{Acknowledgments} 152 | I would like to thank 153 | Nil-Jana Akpinar, 154 | Alfredo Canziani, 155 | Samuel Cohen, 156 | Georgina Hall, 157 | Misha Khodak, 158 | Boris Knyazev, 159 | Hane Lee, 160 | Joe Marino, 161 | Maximilian Nickel, 162 | Paavo Parmas, 163 | Rajiv Sambharya, 164 | Jens Sj\"olund, 165 | Bartolomeo Stellato, 166 | Alex Terenin, 167 | Eugene Vinitsky, 168 | Atlas Wang, 169 | and 170 | Arman Zharmagambetov 171 | for insightful discussions 172 | and feedback on this tutorial. 173 | I am also grateful to the anonymous FnT reviewers 174 | who gave a significant amount of helpful and 175 | detailed feedback. 176 | 177 | {\footnotesize\bibliography{amor}} 178 | 179 | \end{document} 180 | 181 | 182 | %%% Local Variables: 183 | %%% coding: utf-8 184 | %%% mode: latex 185 | %%% TeX-master: "amor.tex" 186 | %%% LaTeX-biblatex-use-Biber: True 187 | %%% End: 188 | -------------------------------------------------------------------------------- /code/figures/main-example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from matplotlib import cm 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | import os 12 | plt.rcParams.update({ 13 | "text.usetex": True, 14 | "font.family": "serif", 15 | "font.serif": ["Computer Modern Roman"]}) 16 | plt.style.use('bmh') 17 | 18 | # Initial problem with y^\star 19 | N = 1000 20 | x = np.linspace(-5.0, 5.0, N) 21 | y = np.linspace(-10.0, 10.0, N) 22 | X, Y = np.meshgrid(x, y) 23 | Z = (Y-np.sin(X)-X*(1+0.1*np.cos(X)))**2 24 | Z = 1./(1.+np.exp(-Z/80.)) 25 | 26 | fig, ax = plt.subplots(figsize=(2,1.7), dpi=200) 27 | CS = ax.contourf(X, Y, Z, cmap='Purples') 28 | 29 | ax.text(0., 1., r'$$f(y; x)$$', color='#491386', 30 | bbox=dict(facecolor='white', pad=0, alpha=0.9, edgecolor='none'), 31 | transform=ax.transAxes, ha='left', va='top') 32 | 33 | 34 | I = np.argmin(Z, axis=0) 35 | xstar, ystar = x, y[I] 36 | ax.plot(xstar, ystar, color='#AA0000', lw=3) 37 | ax.text(.92, .8, r'$$y^\star(x)$$', color='#AA0000', 38 | transform=ax.transAxes, ha='right', va='top') 39 | 40 | ax.set_ylabel('$$y$$', rotation=0, labelpad=0) 41 | ax.yaxis.set_label_coords(-.07, .44) 42 | ax.set_xlabel('$$x$$') 43 | ax.xaxis.set_label_coords(.5, 0.01) 44 | 45 | fig.tight_layout() 46 | ax.set_xticks([]) 47 | ax.set_yticks([]) 48 | ax.grid(False) 49 | ax.spines['top'].set_visible(False) 50 | ax.spines['right'].set_visible(False) 51 | ax.spines['bottom'].set_visible(False) 52 | ax.spines['left'].set_visible(False) 53 | fname = 'opt.pdf' 54 | plt.savefig(fname, transparent=True) 55 | os.system(f'pdfcrop {fname} {fname}') 56 | 57 | # Regression loss 58 | xhat, yhat = xstar.copy(), ystar.copy() 59 | yhat = -0.5*yhat + 0.0*xhat*np.maximum(xhat, 0.) - \ 60 | 0.23*xhat*np.minimum(xhat, 0.)
61 | 62 | fig, ax = plt.subplots(figsize=(2,1.7), dpi=200) 63 | CS = ax.contourf(X, Y, Z, cmap='Purples') 64 | 65 | ax.text(0., 1., r'$$f(y; x)$$', color='#491386', 66 | bbox=dict(facecolor='white', pad=0, alpha=0.9, edgecolor='none'), 67 | transform=ax.transAxes, ha='left', va='top') 68 | 69 | I = np.argmin(Z, axis=0) 70 | xstar, ystar = x, y[I] 71 | ax.plot(xstar, ystar, color='#AA0000', lw=3) 72 | ax.text(.92, .8, r'$$y^\star(x)$$', color='#AA0000', 73 | transform=ax.transAxes, ha='right', va='top') 74 | 75 | ax.plot(xhat, yhat, color='#5499FF', lw=3) 76 | ax.text(0.3, .57, r'$$\hat y_\theta(x)$$', color='#5499FF', 77 | bbox=dict(facecolor='white', pad=0, alpha=0.6, edgecolor='none'), 78 | transform=ax.transAxes, ha='left', va='bottom') 79 | 80 | n_reg = 15 81 | pad = 35 82 | I = np.round(np.linspace(pad, len(y) - 1 - pad, n_reg)).astype(int) 83 | for idx in I: 84 | ax.plot( 85 | (xhat[idx], xstar[idx]), (yhat[idx], ystar[idx]), 86 | color='k', lw=1, solid_capstyle='round') 87 | 88 | ax.set_ylabel('$$y$$', rotation=0, labelpad=0) 89 | ax.yaxis.set_label_coords(-.07, .44) 90 | ax.set_xlabel('$$x$$') 91 | ax.xaxis.set_label_coords(.5, 0.01) 92 | ax.set_title('Regression-Based', fontsize=12, pad=0) 93 | 94 | fig.tight_layout() 95 | ax.set_xticks([]) 96 | ax.set_yticks([]) 97 | ax.grid(False) 98 | ax.spines['top'].set_visible(False) 99 | ax.spines['right'].set_visible(False) 100 | ax.spines['bottom'].set_visible(False) 101 | ax.spines['left'].set_visible(False) 102 | fname = 'learning-reg.pdf' 103 | plt.savefig(fname, transparent=True) 104 | os.system(f'pdfcrop {fname} {fname}') 105 | 106 | # Objective loss 107 | fig, ax = plt.subplots(figsize=(2,1.7), dpi=200) 108 | CS = ax.contourf(X, Y, Z, cmap='Purples') 109 | 110 | ax.plot(xstar, ystar, color='#AA0000', lw=3, ls='--') 111 | ax.plot(xhat, yhat, color='#5499FF', lw=3) 112 | 113 | I = np.round(np.linspace(pad, len(y) - 1 - pad, n_reg)).astype(int) 114 | 115 | def f(x,y): 116 | z = y-jnp.sin(x)-x*(1.+0.1*jnp.cos(x))  # parenthesized to match Z above
117 | z = z**2 118 | z = 1./(1.+jnp.exp(-z/80.)) 119 | return z 120 | 121 | df = jax.grad(f, argnums=1) 122 | 123 | for idx in I: 124 | x,y = jnp.array(xhat[idx]), jnp.array(yhat[idx]) 125 | z = f(x,y) 126 | dz = df(x,y) 127 | ax.quiver( 128 | xhat[idx], yhat[idx], 0., -dz, 129 | color='k', lw=1, scale=.2, zorder=10) #, solid_capstyle='round') 130 | 131 | ax.set_ylabel('$$y$$', rotation=0, labelpad=0) 132 | ax.yaxis.set_label_coords(-.07, .44) 133 | ax.set_xlabel('$$x$$') 134 | ax.xaxis.set_label_coords(.5, 0.01) 135 | ax.set_title('Objective-Based', fontsize=12, pad=0) 136 | 137 | fig.tight_layout() 138 | ax.set_xticks([]) 139 | ax.set_yticks([]) 140 | ax.grid(False) 141 | ax.spines['top'].set_visible(False) 142 | ax.spines['right'].set_visible(False) 143 | ax.spines['bottom'].set_visible(False) 144 | ax.spines['left'].set_visible(False) 145 | fname = 'learning-obj.pdf' 146 | plt.savefig(fname, transparent=True) 147 | os.system(f'pdfcrop {fname} {fname}') 148 | 149 | 150 | # RL loss 151 | fig, ax = plt.subplots(figsize=(2,1.5), dpi=200) 152 | CS = ax.contourf(X, Y, Z, cmap='Purples') 153 | 154 | ax.plot(xstar, ystar, color='#AA0000', lw=3, ls='--') 155 | ax.plot(xhat, yhat, color='#5499FF', lw=3) 156 | 157 | np.random.seed(2) 158 | for _ in range(20): 159 | p = np.linspace(0, 3., len(xhat)) 160 | p = p*np.flip(p) 161 | q = 0.04*np.random.randn(len(xhat)) 162 | q = np.cumsum(q, axis=-1) 163 | q = q*np.flip(q) 164 | pert = 0.3*(p+q)*np.random.randn() 165 | ax.plot(xhat, yhat+pert, color='#5499FF', lw=1, alpha=0.3) 166 | 167 | # ax.set_xlabel('$$x$$') 168 | ax.xaxis.set_label_coords(.5, 0.01) 169 | # ax.set_title('RL-Based', fontsize=12, pad=0) 170 | 171 | fig.tight_layout() 172 | ax.set_xticks([]) 173 | ax.set_yticks([]) 174 | ax.grid(False) 175 | ax.spines['top'].set_visible(False) 176 | ax.spines['right'].set_visible(False) 177 | ax.spines['bottom'].set_visible(False) 178 | ax.spines['left'].set_visible(False) 179 | fname = 'learning-rl.pdf' 180 | 
plt.savefig(fname, transparent=True) 181 | os.system(f'pdfcrop {fname} {fname}') 182 | -------------------------------------------------------------------------------- /code/evaluate_amortization_speed_control.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import argparse 8 | import os 9 | import sys 10 | import pickle as pkl 11 | import shutil 12 | from omegaconf import OmegaConf 13 | from collections import namedtuple 14 | import dmc2gym 15 | 16 | import matplotlib.pyplot as plt 17 | plt.style.use('bmh') 18 | from matplotlib import cm 19 | 20 | from multiprocessing import Process 21 | 22 | from svg.video import VideoRecorder 23 | from svg import utils, dx 24 | 25 | from evaluate_amortization_speed_function import evaluate_amortization_speed 26 | 27 | def main(): 28 | import sys 29 | from IPython.core import ultratb 30 | sys.excepthook = ultratb.FormattedTB(mode='Verbose', 31 | color_scheme='Linux', 32 | call_pdb=1) 33 | 34 | exp = torch.load('svg_submodule/trained-humanoid/latest.pt') 35 | 36 | # Clean up logging after resuming the experiment code 37 | del exp.logger 38 | os.remove('eval.csv') 39 | os.remove('train.csv') 40 | 41 | observations = collect_eval_episode(exp) 42 | 43 | # First try to predict the maximum value action 44 | 45 | def amortization_model(observations): 46 | actions, _, _ = exp.agent.actor(observations, compute_pi=False, compute_log_pi=False) 47 | return actions 48 | 49 | def amortization_objective(actions, observations, normalize=True): 50 | q1, q2 = exp.agent.critic(observations, actions) 51 | values = torch.min(q1, q2).squeeze() 52 | if normalize: 53 | values = normalize_values(values) 54 | 55 | return values 56 | 57 | with torch.no_grad(): 58 | expert_actions = amortization_model(observations) 59 | zero_actions = torch.zeros_like(expert_actions) 60 | expert_values = 
amortization_objective(expert_actions, observations, normalize=False) 61 | zero_values = amortization_objective(zero_actions, observations, normalize=False) 62 | 63 | def normalize_values(values): 64 | """normalize so that the expert value is 0 and the zero action is -1.""" 65 | norm_values = (values - expert_values) / (expert_values - zero_values) 66 | 67 | # assume we can't do better than the expert. 68 | # otherwise the optimization overfits to the inaccurate model 69 | # and value approximation. 70 | norm_values[norm_values > 0.] = 0. 71 | return norm_values 72 | 73 | 74 | evaluate_amortization_speed( 75 | amortization_model=amortization_model, 76 | amortization_objective=amortization_objective, 77 | contexts=observations, 78 | tag='control-model-free', 79 | fig_ylabel='Value', 80 | adam_lr=5e-3, 81 | num_iterations=500, 82 | maximize=True, 83 | ) 84 | 85 | # Next try to predict the solution to the short-horizon model-based 86 | # control problem. 87 | def amortization_model(observations): 88 | num_batch = observations.shape[0] 89 | action_seq, _, _ = exp.agent.dx.unroll_policy( 90 | observations, exp.agent.actor, sample=False, last_u=True) 91 | action_seq_flat = action_seq.transpose(0,1).reshape(num_batch, -1) 92 | return action_seq_flat 93 | 94 | def amortization_objective(action_seq_flat, observations, normalize=True): 95 | num_batch = action_seq_flat.shape[0] 96 | action_seq = action_seq_flat.reshape(num_batch, -1, exp.agent.action_dim).transpose(0, 1) 97 | predicted_states = exp.agent.dx.unroll(observations, action_seq[:-1]) 98 | 99 | all_obs = torch.cat((observations.unsqueeze(0), predicted_states), dim=0) 100 | xu = torch.cat((all_obs, action_seq), dim=2) 101 | dones = exp.agent.done(xu).sigmoid().squeeze(dim=2) 102 | not_dones = 1. 
- dones 103 | not_dones = utils.accum_prod(not_dones) 104 | last_not_dones = not_dones[-1] 105 | 106 | rewards = not_dones * exp.agent.rew(xu).squeeze(2) 107 | q1, q2 = exp.agent.critic(all_obs[-1], action_seq[-1]) 108 | q = torch.min(q1, q2).reshape(num_batch) 109 | rewards[-1] = last_not_dones * q 110 | 111 | rewards *= exp.agent.discount_horizon.unsqueeze(1) 112 | 113 | values = rewards.sum(dim=0) 114 | if normalize: 115 | values = normalize_values(values) 116 | return values 117 | 118 | with torch.no_grad(): 119 | # used in the normalization 120 | expert_action_seq = amortization_model(observations) 121 | zero_action_seq = torch.zeros_like(expert_action_seq) 122 | expert_values = amortization_objective(expert_action_seq, observations, normalize=False) 123 | zero_values = amortization_objective(zero_action_seq, observations, normalize=False) 124 | 125 | evaluate_amortization_speed( 126 | amortization_model=amortization_model, 127 | amortization_objective=amortization_objective, 128 | contexts=observations, 129 | tag='control-model-based', 130 | fig_ylabel='Value', 131 | adam_lr=5e-3, 132 | num_iterations=500, 133 | maximize=True, 134 | ) 135 | 136 | 137 | 138 | def collect_eval_episode(exp): 139 | device = 'cuda' 140 | exp.env.set_seed(0) 141 | obs = exp.env.reset() 142 | done = False 143 | total_reward = 0. 
144 | step = 0 145 | observations = [] 146 | while not done: 147 | if exp.cfg.normalize_obs: 148 | mu, sigma = exp.replay_buffer.get_obs_stats() 149 | obs = (obs - mu) / sigma 150 | obs = torch.FloatTensor(obs).to(device) 151 | observations.append(obs) 152 | action, _, _ = exp.agent.actor(obs, compute_pi=False, compute_log_pi=False) 153 | action = action.clamp(min=exp.env.action_space.low.min(), 154 | max=exp.env.action_space.high.max()) 155 | 156 | obs, reward, done, _ = exp.env.step(utils.to_np(action.squeeze(0))) 157 | total_reward += reward 158 | step += 1 159 | print(f'+ eval episode reward: {total_reward}') 160 | observations = torch.stack(observations, dim=0) 161 | return observations 162 | 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /paper/chapters/1-intro.tex: -------------------------------------------------------------------------------- 1 | \chapter{Introduction} 2 | \begin{figure}[t] 3 | \centering 4 | \includegraphics[width=2in]{fig/opt.pdf} 5 | \caption{Illustration of the parametric optimization problem 6 | in \cref{eq:opt}. 7 | Each context $x$ parameterizes an 8 | optimization problem that the objective $f(y; x)$ depends on. 9 | The contours show the values of the objectives where 10 | darker colors indicate higher values. 11 | The objective is then minimized over $y$ and the resulting 12 | solution $y^\star(x)$ is shown in red. 13 | In other words, each vertical slice is an optimization problem 14 | and this visualization shows a continuum of optimization problems. 
15 | } 16 | \label{fig:opt} 17 | \end{figure} 18 | This tutorial studies the use of machine learning 19 | to improve repeated solves of parametric optimization 20 | problems of the form 21 | \begin{equation} 22 | y^\star(x) \in \argmin_y f(y; x), 23 | \label{eq:opt} 24 | \end{equation} 25 | where the \emph{non-convex} objective 26 | $f: \gY\times \gX\rightarrow \R$ 27 | takes a \emph{context} or \emph{parameterization} 28 | $x\in\gX$ which can be continuous or discrete, 29 | and the \emph{continuous, unconstrained domain} of 30 | the problem is $y\in\gY=\R^n$. 31 | \Cref{eq:opt} implicitly defines a \emph{solution} 32 | $y^\star(x)\in\gY$. 33 | In most of the applications considered later in 34 | \cref{sec:apps}, $y^\star(x)$ is unique and smooth, 35 | \ie, the solution continuously changes in a 36 | connected way as the context changes, as illustrated 37 | in \cref{fig:opt}. 38 | 39 | Parametric optimization problems such as \cref{eq:opt} 40 | have been studied for decades 41 | \citep{bank1982non,fiacco1990sensitivity,shapiro2003sensitivity,klatte2006nonsmooth,bonnans2013perturbation,still2018lectures,fiacco2020mathematical} 42 | with a focus on sensitivity analysis. 43 | The general formulation in \cref{eq:opt} captures many 44 | tasks arising in physics, engineering, mathematics, control, 45 | inverse modeling, and machine learning. 46 | For example, when controlling a continuous robotic system, 47 | $\gX$ is the space of \emph{observations} or \emph{states}, 48 | \eg, angular positions and velocities describing 49 | the configuration of the system, 50 | the domain $\gY\defeq \gU$ is the \emph{control space}, 51 | \eg, torques to apply to each actuated joint, 52 | and $f(u; x)\defeq -Q(u, x)$ is the \emph{control cost} 53 | or the negated \emph{Q-value} of the state-action tuple $(x,u)$, 54 | \eg, to reach a goal location or to maximize the velocity. 
55 | For every encountered state $x$, the system is controlled 56 | by solving an optimization problem in the form of \cref{eq:opt}. 57 | While the optimization in \cref{eq:opt} is over a deterministic 58 | real-valued space $\gY=\R^n$, the formulation can also capture 59 | stochastic optimization problems as discussed in 60 | \cref{sec:extensions:sto}. For example, 61 | \Cref{sec:apps:avi} optimizes over the (real-valued) 62 | parameters of a variational distribution and 63 | \cref{sec:apps:ctrl} optimizes over the (real-valued) 64 | parameters of a stochastic policy for control and 65 | reinforcement learning. 66 | 67 | \begin{figure}[t] 68 | \centering 69 | \includegraphics[width=\textwidth]{fig/overview.pdf} 70 | \caption{An amortized optimization method learns 71 | a model $\hat y_\theta$ to predict the minimizer 72 | of an \emph{objective} $f(y;x)$ of a parameterized 73 | optimization problem, as in \cref{eq:opt}, 74 | which depends on a \emph{context} $x$. 75 | For example, in control, 76 | the context space $\gX$ is the state space of the system, 77 | \eg angular positions and velocities describing 78 | the configuration of the system, 79 | the domain $\gY\defeq\gU$ is the control space, 80 | \eg torques to apply to each actuated joint, 81 | the cost (or negated value) of a state-action 82 | pair is $f(u; x)\defeq -Q(x,u)$, and the state distribution is $p(x)$. 83 | For an encountered state $x$, 84 | many reinforcement learning policies $\pi_\theta(x)\defeq\hat y_\theta(x)$ 85 | amortize the solution to the underlying control problem 86 | with true solution $y^\star(x)$. 87 | This humanoid policy was obtained with the model-based 88 | stochastic value gradient in \citet{amos2021model}. 89 | } 90 | \label{fig:overview} 91 | \end{figure} 92 | 93 | 94 | Optimization problems such as \cref{eq:opt} quickly become a 95 | computational bottleneck in the systems they are part of.
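Concretely, the amortization idea above can be sketched in a few lines of code. This is a toy illustration under assumptions of my own (the objective $f(y;x)=(y-\sin x)^2$, its known solution $y^\star(x)=\sin x$, and the cubic-feature model are all hypothetical stand-ins, not examples from the tutorial): we fit a model $\hat y_\theta(x)$ by descending the objective-based loss $\mathcal{L}(\theta)=\mathbb{E}_x\, f(\hat y_\theta(x); x)$, so the learned model maps contexts directly to approximate solutions.

```python
# Toy sketch of objective-based amortized optimization (hypothetical
# problem, not from the tutorial): contexts x parameterize the objective
# f(y; x) = (y - sin(x))^2, whose solution is y^*(x) = sin(x).
import numpy as np

def f(y, x):
    return (y - np.sin(x))**2

def grad_f_y(y, x):  # derivative of f with respect to y
    return 2.0 * (y - np.sin(x))

def features(x):  # cubic features for a simple linear-in-theta model
    return np.stack([np.ones_like(x), x, x**2, x**3], axis=1)

def amortize(num_iters=5000, lr=0.02, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.uniform(-2.0, 2.0, size=500)  # sampled contexts
    phi = features(x)
    theta = np.zeros(4)
    for _ in range(num_iters):
        yhat = phi @ theta
        # chain rule: dL/dtheta = E_x[ df/dy * dyhat/dtheta ]
        theta -= lr * (grad_f_y(yhat, x)[:, None] * phi).mean(axis=0)
    return theta

theta = amortize()
x_test = np.linspace(-2.0, 2.0, 101)
y_pred = features(x_test) @ theta
max_err = np.abs(y_pred - np.sin(x_test)).max()
print(max_err)  # small: the model predicts y^*(x) without re-solving
```

After training, evaluating $\hat y_\theta$ on a new context is a single feature evaluation and dot product, with no per-context optimization loop; this is the structure that the tutorial's neural-network amortization models share.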
96 | These problems often do not have a closed-form 97 | analytic solution and are instead solved with 98 | approximate numerical methods which iteratively 99 | search for the solution. 100 | This computational problem has led to many specialized 101 | solvers that leverage domain-specific insights to 102 | deliver fast solves. 103 | Specialized algorithms are 104 | especially prevalent in convex optimization methods for 105 | linear programming, quadratic programming, cone programming, 106 | and control, and use theoretical insights into the problem 107 | structure to deliver empirical gains in computational 108 | efficiency and convergence 109 | \citep{boyd2004convex,nocedal2006numerical,bertsekas2015convex,bubeck2015convex,nesterov2018lectures}. 110 | 111 | Mostly separate from optimization research and algorithmic advancements, 112 | the machine learning community has focused on developing 113 | generic function approximation methods for estimating non-trivial 114 | high-dimensional mappings from data 115 | \citep{murphy2012machine,goodfellow2016deep,deisenroth2020mathematics}. 116 | Machine learning models are often used to reconstruct mappings 117 | from data, \eg, for supervised classification or regression where 118 | the targets are given by human annotations. 119 | Many computational advancements in software and hardware 120 | have been developed in recent years to make the prediction time fast: 121 | the forward pass of a neural network generating a prediction 122 | can execute in milliseconds on a graphics processing unit. 123 | 124 | \textbf{Overview.} 125 | This tutorial studies the use of machine learning models to 126 | rapidly predict the solutions to the optimization problem in 127 | \cref{eq:opt}, which is referred to as 128 | \emph{amortized optimization} or \emph{learning to optimize}.
129 | Amortized optimization methods are capable of significantly 130 | improving the computational time of 131 | classical algorithms \emph{on a focused subset of problems}. 132 | This is because the model is able to learn about the 133 | solution mapping from $x$ to $y^\star(x)$ that classical 134 | optimization methods usually do not assume access to. 135 | My goal in writing this is to explore a unified perspective 136 | on the modeling approaches of amortized optimization in 137 | \cref{sec:foundations} to help draw connections 138 | between the applications in \cref{sec:apps}, 139 | \eg, between amortized variational inference, meta-learning, 140 | policy learning for control and reinforcement learning, 141 | sparse coding, convex optimization, optimal transport, 142 | and deep equilibrium networks. 143 | These topics have historically been studied in isolation 144 | without connections between their amortization components. 145 | \Cref{sec:implementation} presents a computational tour 146 | through source code for variational inference, policy learning, 147 | and a spherical optimization problem, and 148 | \cref{sec:discussion} concludes with a discussion of 149 | challenges, limitations, open problems, and related work. 150 | 151 | \textbf{How much does amortization help?} 152 | Amortized optimization has been revolutionary for many fields, 153 | especially including variational inference and reinforcement 154 | learning. 155 | \Cref{fig:vae-performance} shows that the amortization component 156 | of a variational autoencoder trained on MNIST is \textbf{25000} 157 | times faster (0.4ms vs.~8 seconds!) than solving a batch of 158 | 1024 optimization problems from scratch to obtain a 159 | solution of the same quality. 160 | These optimization problems are solved in every training iteration 161 | and can become a significant bottleneck if they are 162 | inefficiently solved.
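The from-scratch baseline in this comparison can be sketched with a toy experiment of my own (the one-dimensional objective and the "trained model" below are illustrative assumptions, not the VAE experiment from the text): count how many gradient-descent iterations a from-scratch solve needs before it matches the quality of one amortized prediction.

```python
# Toy sketch of the amortized-vs-from-scratch comparison (hypothetical
# 1-D problem; the real comparison in the text uses a VAE on MNIST).
import numpy as np

def f(y, x):          # f(y; x) = (y - sin(x))^2 with solution y*(x) = sin(x)
    return (y - np.sin(x))**2

def grad_f(y, x):
    return 2.0 * (y - np.sin(x))

x = np.linspace(-2.0, 2.0, 1024)      # a batch of 1024 contexts

# One "forward pass" of an amortized model: here the analytic solution
# plus a small model error stands in for a trained network's prediction.
y_amortized = np.sin(x) + 0.01

# From scratch: gradient descent on y for every problem in the batch,
# run until it matches the amortized solution quality.
y = np.zeros_like(x)
iters = 0
while f(y, x).mean() > f(y_amortized, x).mean():
    y -= 0.1 * grad_f(y, x)
    iters += 1
print(iters)  # tens of iterations, vs. one amortized prediction
```

Even on this trivial convex problem the from-scratch solve needs many gradient steps per batch; on the real problems in the tutorial each iteration is itself expensive, which is where the orders-of-magnitude wall-clock gap reported above comes from.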
163 | If the model is being trained for millions of iterations, 164 | then the difference between solving the optimization problem 165 | in 0.4ms vs.~8 seconds makes the difference between the 166 | entire training process finishing in a few hours or a month. 167 | 168 | \textbf{A historical note: amortization in control and statistical inference.} 169 | Amortized optimization has arisen in many fields as a result 170 | of practical optimization problems being non-convex and 171 | lacking easily computed or closed-form solutions. 172 | Continuous control problems with linear dynamics and quadratic 173 | cost are convex and often easily solved with the linear 174 | quadratic regulator (LQR), and many non-convex extensions and 175 | iterative applications of LQR have been successful over 176 | the decades, but these become increasingly infeasible on 177 | non-trivial systems and in reinforcement learning settings 178 | where the policy often needs to be rapidly executed. 179 | For this reason, the reinforcement learning community almost 180 | exclusively amortizes control optimization problems with 181 | a learned policy \citep{sutton2018reinforcement}. 182 | Related to this throughline in control and reinforcement learning, 183 | many statistical optimization problems have closed-form 184 | solutions for known distributions such as Gaussians. 185 | For example, the original Kalman filter is defined with Gaussians 186 | and the updates take a closed form. The extended Kalman filter 187 | generalizes the distributions to non-Gaussians, but the updates 188 | are in general no longer available analytically and need to be 189 | computationally estimated. 190 | \citet{marino2018general} shows how amortization helps improve 191 | this computationally challenging step.
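The idea of a learned policy amortizing a control problem can be sketched on a toy example of my own (the quadratic value function $Q(x,u)=-(u-kx)^2$ and the linear policy are hypothetical stand-ins, not the humanoid setting from the repository's evaluation script): the policy $\pi_\theta$ is trained by gradient ascent on $\mathbb{E}_x\, Q(x, \pi_\theta(x))$, replacing a per-state maximization over actions with a single forward pass.

```python
# Toy sketch of a policy amortizing per-state action optimization
# (hypothetical quadratic Q; the optimal action is u^*(x) = k*x).
import numpy as np

k = 0.7  # coefficient of the toy Q function, unknown to the policy

def Q(x, u):
    return -(u - k * x)**2

def grad_Q_u(x, u):  # derivative of Q with respect to the action
    return -2.0 * (u - k * x)

rng = np.random.default_rng(0)
x = rng.normal(size=1000)   # sampled "states"
theta = 0.0                 # linear policy pi_theta(x) = theta * x
for _ in range(2000):
    u = theta * x
    # ascend the expected value: dE[Q]/dtheta = E[ dQ/du * x ]
    theta += 0.1 * (grad_Q_u(x, u) * x).mean()
print(theta)  # converges to k, i.e., the policy recovers u^*(x)
```

This mirrors the structure of the repository's `evaluate_amortization_speed_control.py`, where the actor plays the role of $\pi_\theta$ and the critic's minimum-of-two Q estimate plays the role of $Q$.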
192 | Both the control and statistical settings start from 193 | simple problems with analytic solutions, 194 | generalize to more challenging optimization problems 195 | whose solutions must be computationally estimated, and then 196 | recover some computational tractability with amortized optimization. 197 | 198 | 199 | %%% Local Variables: 200 | %%% coding: utf-8 201 | %%% mode: latex 202 | %%% TeX-master: "../amor-nowplain.tex" 203 | %%% LaTeX-biblatex-use-Biber: True 204 | %%% End: -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses.
23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. 
Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. 
Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. 
indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 
291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. 
This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. 
To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 
406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | 409 | -------------------------------------------------------------------------------- /paper/chapters/5-discussion.tex: -------------------------------------------------------------------------------- 1 | \chapter{Discussion} 2 | \label{sec:discussion} 3 | Many of the specialized methods discuss tradeoffs and 4 | limitations within the context of their application, 5 | and more generally papers such as 6 | \citet{chen2021learning,metz2021gradients} 7 | provide even deeper probes into general paradigms for 8 | learning to optimize. 9 | This chapter emphasizes a few additional discussion points 10 | around amortized optimization. 11 | 12 | \section{Surpassing the convergence rates of 13 | classical methods} 14 | \label{sec:convergence} 15 | Theoretical and empirical optimization research often focuses 16 | on discovering algorithms with provably strong convergence 17 | rates in general or worst-case scenarios. 18 | Many of the algorithms with the best convergence 19 | rates are used as the state-of-the-art algorithms in practice, 20 | such as momentum and acceleration methods. 21 | Amortized optimization methods can surpass the results 22 | provided by classical optimization methods because they 23 | are capable of tuning the initialization and updates 24 | to the best-case scenario within the distribution of 25 | contexts the amortization model is trained on. 26 | For example, the fully amortized models for amortized variational 27 | inference and model-free actor-critic methods for RL 28 | presented in \cref{sec:impl:eval} solve 29 | the optimization problems \emph{in constant time} with just 30 | a single prediction of the solution from the context without 31 | even looking at the objective! 32 | Further theoretical characterizations of this are provided 33 | in \citet{khodak2022learning} and related literature on 34 | algorithms with predictions.
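To make the constant-time behavior concrete, below is a minimal sketch on a toy quadratic family $f(y; x) = \tfrac{1}{2} y^\top A y - x^\top y$ whose solution map is $y^\star(x) = A^{-1} x$: a linear amortization model $\hat y = W x$ trained with an objective-based loss recovers the solution map, after which every "solve" is a single matrix-vector product. The matrix, model, and step size are all illustrative choices, not code from this tutorial's repository.

```python
import numpy as np

# Toy family of objectives f(y; x) = 0.5 * y^T A y - x^T y with a fixed
# positive-definite A, so the solution map is y*(x) = A^{-1} x.
rng = np.random.default_rng(0)
n = 5
Q = rng.normal(size=(n, n))
A = Q @ Q.T + n * np.eye(n)

# Fully amortized model: a single linear map y_hat = W @ x, trained on the
# objective-based loss E_x[f(W x; x)] with stochastic gradient descent.
W = np.zeros((n, n))
lr = 5e-3
for _ in range(2000):
    X = rng.normal(size=(n, 64))                 # batch of contexts x ~ p(x)
    grad = (A @ W @ X - X) @ X.T / X.shape[1]    # d/dW of the batch objective
    W -= lr * grad

# After training, each "solve" is one matrix-vector product.
x = rng.normal(size=n)
y_hat = W @ x
y_star = np.linalg.solve(A, x)
```

In this toy setting the learned $W$ approaches $A^{-1}$, so the amortized prediction matches the classical solve while costing only a single forward pass per context.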
35 | 36 | \section{Generalization and convergence guarantees} 37 | \label{sec:generalization} 38 | Despite the powerful successes of amortized optimization in 39 | some settings, the field struggles to bring comparable success 40 | in other domains. 41 | Although amortized optimization methods have the capacity to surpass 42 | the convergence rates of other algorithms, in practice they 43 | can deeply struggle to generalize and converge to 44 | reasonable solutions. 45 | In some deployments this inaccuracy may be acceptable if there 46 | is a quick way of checking the quality of the amortized model, 47 | \eg the residuals for fixed-point and convex problems. 48 | If that is the case, then poorly-solved instances can be flagged 49 | and re-solved with a standard solver for the problem, which 50 | may incur more computational time for that instance. 51 | \citet{sambharya2022l2ws} presents generalization bounds for learned 52 | warm-starts based on Rademacher complexity, 53 | and \citet{sambharya2023l2ws,sucker2024generalization} investigate PAC-Bayes generalization bounds. 54 | \citet{banert2021accelerated,premont2022simple} add provable 55 | convergence guarantees to semi-amortized models by guarding 56 | the update and ensuring the learned optimizer 57 | does not deviate too much from a known convergent algorithm. 58 | A practical takeaway is that some models are more likely 59 | to result in convergent and stable semi-amortized optimizers 60 | than others. 61 | For example, the semi-amortized model 62 | parameterized with gradient descent (which 63 | has some mild convergence guarantees) in \citet{finn2017model} 64 | is often more stable than the semi-amortized model parameterized 65 | by a sequential model (without many convergence guarantees) 66 | in \citet{ravi2016optimization}. 67 | Other modeling and architecture tricks such as layer 68 | normalization \citep{ba2016layernorm} help improve the 69 | stability of amortized optimization models.
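The flag-and-re-solve fallback mentioned above can be sketched for a scalar fixed-point problem, where the residual of the amortized prediction is cheap to check. The fixed-point map, tolerance, and the stand-in ``learned'' initializer below are all hypothetical illustrations.

```python
import math

def T(y, x):
    # Fixed-point map whose solution is y = sqrt(x): y <- 0.5 * (y + x / y).
    return 0.5 * (y + x / y)

def amortized_guess(x):
    # Stand-in for a learned amortized model (hypothetical): a crude linear fit.
    return 0.25 + 0.5 * x

def solve(x, tol=1e-6, max_iter=100):
    y = amortized_guess(x)
    if abs(T(y, x) - y) < tol:   # cheap residual check on the amortized solution
        return y, 'amortized'
    for _ in range(max_iter):    # fallback: re-solve with the classical iteration
        y = T(y, x)
        if abs(T(y, x) - y) < tol:
            break
    return y, 'fallback'

y, how = solve(9.0)
```

Here the crude amortized guess fails the residual check on this instance, so the instance is flagged and re-solved by the classical iteration, which still benefits from the amortized warm start.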
70 | Additionally, \citet{fahy2024greedya} investigates learning preconditioners 71 | and proves that their parameterization of the preconditioning space 72 | always results in a convergent optimizer. 73 | 74 | \section{Measuring performance} 75 | Quantifying the performance of amortization models can be even more 76 | challenging than the choice between using a regression- or 77 | objective-based loss and is often tied to 78 | the problem-specific metrics that matter most. 79 | For example, even if a method is able to attain low objective values 80 | in a few iterations, the computation may take \emph{longer} than 81 | a specialized algorithm or another amortization model that can reach 82 | the same level of accuracy, thus defeating the 83 | original goal of speeding up solves of \cref{eq:opt}. 84 | 85 | \section{Successes and limitations of amortized optimization} 86 | While amortized optimization has standout applications in 87 | variational inference, reinforcement learning, and meta-learning, 88 | it struggles to bring value in other settings. 89 | Often, learning the amortized model is computationally more 90 | expensive than solving the original optimization problems and 91 | brings instabilities into a higher-level learning or optimization 92 | process deployed on top of potentially inaccurate solutions 93 | from the amortized model. 94 | This section summarizes principles behind successful applications 95 | of amortized optimization and characterizes limitations that 96 | may arise. 97 | 98 | \subsection*{Characteristics of successful applications} 99 | \begin{itemize} 100 | \item \textbf{Objective $f(y; x)$ is smooth over the domain $\gY$ 101 | and has unique solutions $y^\star$.} 102 | With objective-based learning, non-convex objectives with 103 | few poor local optima are ideal. 104 | This behavior can be encouraged with smoothing as is 105 | often done for meta-learning and policy learning (\cref{sec:smooth}).
106 | \item \textbf{A higher-level process should tolerate sub-optimal 107 | solutions given by $\hat y$ in the beginning of training.} 108 | In variational encoders, the suboptimal bound on the likelihood 109 | is still acceptable to optimize the density model's parameters over 110 | in \cref{eq:vae-full}. 111 | And in reinforcement learning policies, 112 | a suboptimal solution to the maximum value problem is still 113 | acceptable to deploy on the system in early phases of training, 114 | and may even be desirable for the exploration induced by randomly 115 | initialized policies. 116 | \item \textbf{The context distribution $p(x)$ is well-scoped, not 117 | too broad, and covers a specialized class of sub-problems.} 118 | For example, instead of trying to amortize the solution to 119 | \emph{every} possible $\ELBO$ maximization, VAEs 120 | amortize the problem only over the dataset the density 121 | model is being trained on. 122 | And in reinforcement learning, the policy $\pi_\theta$ doesn't 123 | try to amortize the solution to \emph{every} possible control 124 | problem, but instead focuses only on amortizing the solutions 125 | to the control problems on the replay buffer of the 126 | specific MDP. 127 | \item \textbf{In semi-amortized models, parameterize the initialization 128 | and specialized components of the updates.} 129 | While semi-amortized models are a thriving research topic, 130 | the most successful applications of them do the following: 131 | \begin{enumerate} 132 | \item \textbf{Parameterize and learn the initial iterate.} 133 | MAML \citep{finn2017model} \emph{only} parameterizes the initial 134 | iterate and follows it with gradient descent steps. 135 | \citet{bai2022neural} parameterizes 136 | the initial iterate and follows it with accelerated 137 | fixed-point iterations.
138 | \item \textbf{Parameterize and learn specialized components of the 139 | updates.} In sparse coding, LISTA \citep{gregor2010learning} 140 | only parameterizes $\{F,G,\beta\}$ instead of the 141 | entire update rule. 142 | \citet{bai2022neural} only parameterizes $\alpha,\beta$ 143 | after the initial iterate, and 144 | RLQP \citep{ichnowski2021accelerating} only parameterizes $\rho$. 145 | \end{enumerate} 146 | While using a pure sequence model to update a sequence of 147 | iterations is possible and theoretically satisfying as it 148 | gives the model the power to arbitrarily update the sequence 149 | of iterates, in practice this can be unstable and severely 150 | overfit to the training instances. 151 | \citet{metz2021gradients} observes, for example, that semi-amortized 152 | recurrent sequence models induce chaotic behaviors 153 | and exploding gradients. 154 | \end{itemize} 155 | 156 | \subsection*{Limitations and failures} 157 | \begin{itemize} 158 | \item \textbf{Amortized optimization does \emph{not} magically solve otherwise 159 | intractable optimization problems!} 160 | At least not without significant insights. 161 | In most successful settings, the original optimization problem can be 162 | (semi-)tractably solved for a context $x$ with classical methods, 163 | such as standard black-box variational inference 164 | or model-predictive control methods. 165 | Intractability instead arises when repeatedly solving 166 | the optimization problem, even if a single instance can be reasonably solved, 167 | and amortized methods often thrive in these settings by rapidly solving 168 | problems with similar structure. 169 | \item \textbf{The combination of $p(x)$ and $y^\star(x)$ is too hard 170 | for a model to learn.} This could come from $p(x)$ being too 171 | large, \eg contexts of every optimization problem in the universe, 172 | or the solution $y^\star(x)$ not being smooth or predictable.
173 | $y^\star(x)$ may also not be unique, but this is perhaps easier 174 | to handle if the loss is carefully set up, \eg objective-based 175 | losses handle non-uniqueness more gracefully. 176 | \item \textbf{The domain requires accurate solutions.} 177 | Even though metrics that measure the solution quality of $\hat y$ 178 | can be defined on top of \cref{eq:opt}, amortized methods 179 | typically cannot rival the accuracy of standard algorithms 180 | used to solve the optimization problems. 181 | In these settings, amortized optimization still has the 182 | potential to uncover new foundations and algorithms 183 | for solving problems, but this is non-trivial to 184 | demonstrate successfully. 185 | From an amortization perspective, one difficulty of safety-critical 186 | model-free reinforcement learning comes from needing to 187 | ensure the amortized policy properly optimizes a 188 | value estimate that (hopefully) encodes safety-critical 189 | properties of the state-action space. 190 | \end{itemize} 191 | 192 | \section{Some open problems and under-explored directions} 193 | In most domains, introducing or significantly improving amortized 194 | optimization is extremely valuable and will likely be well-received. 195 | Beyond this, there are many under-explored directions and 196 | combinations of ideas covered in this tutorial that can 197 | be shared between the existing fields using amortized optimization, 198 | for example: 199 | 200 | \begin{enumerate} 201 | \item \textbf{Overcoming local minima with objective-based losses 202 | and connections to stochastic policies.} 203 | \Cref{sec:smooth} covered the objective smoothing by 204 | \citet{metz2019understanding,merchant2021learn2hop} 205 | to overcome suboptimal local minima in the objective. 206 | These have striking similarities to stochastic policies 207 | in reinforcement learning that also overcome local 208 | minima, \eg in \cref{eq:Q-opt-sto-exp}.
209 | The stochastic policies, such as in \citet{haarnoja2018soft}, 210 | have the desirable property of starting with a high variance 211 | and then focusing in on a low-variance solution with a 212 | penalty constraining the entropy to a fixed value. 213 | A similar method is employed in GECO \citep{rezende2018taming} 214 | that adjusts a Lagrange multiplier in the ELBO objective 215 | to achieve a target conditional log-likelihood. 216 | These tricks seem useful to generalize and apply to 217 | other amortization settings to overcome poor minima. 218 | \item \textbf{Widespread and usable amortized convex solvers.} 219 | When using off-the-shelf optimization packages such as 220 | \citet{diamond2016cvxpy,o2016conic,stellato2018osqp}, 221 | users are likely solving many similar problem instances 222 | that amortization can help improve. 223 | \citet{venkataraman2021neural,ichnowski2021accelerating} 224 | are active research directions that study adding 225 | amortization to these solvers, but they do not yet scale 226 | to the general online setting without adding 227 | too much learning overhead for the user. 228 | \item \textbf{Improving the wall-clock training time 229 | of implicit models and differentiable optimization.} 230 | Optimization problems and fixed-point problems 231 | are being integrated into machine learning models, 232 | such as with differentiable optimization 233 | \citep{domke2012generic,gould2016differentiating,amos2017optnet,amos2019differentiable,agrawal2019differentiable,lee2019meta} 234 | and deep equilibrium models 235 | \citep{bai2019deep,bai2020multiscale}. 236 | In these settings, the data distribution the model 237 | is being trained on naturally induces a distribution over 238 | contexts that seems amenable to amortization. 239 | \citet{venkataraman2021neural,bai2022neural} 240 | explore amortization in these settings, but often do not 241 | improve the wall-clock time it takes to train these models 242 | from scratch.
243 | \item \textbf{Understanding the amortization gap.} 244 | \citet{cremer2018inference} study the \emph{amortization gap} 245 | in amortized variational inference, which measures how well the 246 | amortization model approximates the true solution. 247 | This crucial concept should be analyzed in most amortized 248 | optimization settings to understand the accuracy of 249 | the amortization model. 250 | \item \textbf{Implicit differentiation and shrinkage.} 251 | \citet{chen2019modular,rajeswaran2019meta} show that penalizing 252 | the amortization objective can significantly improve the 253 | computational and memory requirements to train a semi-amortized 254 | model for meta-learning. Many of the ideas in these settings 255 | can be applied in other amortization settings, 256 | as also observed by \citet{huszar2019imaml}. 257 | \item \textbf{Distribution shift of $p(x)$ and out-of-distribution generalization.} 258 | This tutorial has assumed that $p(x)$ is fixed and remains 259 | the same through the entire training process. 260 | However, in some settings $p(x)$ may shift over time, which 261 | could come from 1) the data generating process naturally 262 | changing, or 2) a \emph{higher-level} learning process 263 | also influencing $p(x)$. 264 | Furthermore, after training on some context distribution $p(x)$, 265 | a deployed model is likely not going to be evaluated on the 266 | same distribution and should ideally be resilient 267 | to out-of-distribution samples. 268 | The out-of-distribution performance can often be measured, 269 | quantified, and reported alongside the model. 270 | Even if the amortization model fails at optimizing \cref{eq:opt}, 271 | the failure is detectable because the optimality conditions of 272 | \cref{eq:opt} or other solution quality metrics can be checked. 273 | If the solution quality isn't high enough, then a slower 274 | optimizer could potentially be used as a fallback.
275 | \item \textbf{Amortized and semi-amortized control and reinforcement learning.} 276 | Applications of semi-amortization in control and reinforcement learning 277 | covered in \cref{sec:apps:ctrl} are budding and 278 | learning sample-efficient optimal controllers is 279 | an active research area, especially in model-based settings 280 | where the dynamics model is known or approximated. 281 | \citet{amos2019dcem} shows how amortization can learn latent 282 | control spaces that are aware of the structure of the 283 | solutions to control problems. 284 | \citet{marino2020iterative} study semi-amortized methods 285 | based on gradient descent and show that they amortize 286 | the solutions better than the standard fully-amortized models. 287 | \end{enumerate} 288 | 289 | \section{Related work} 290 | \subsection{Other tutorials, reviews, and discussions 291 | on amortized optimization} 292 | My goal in writing this tutorial was to provide a perspective 293 | on existing amortized optimization methods for learning 294 | to optimize with a categorization of the 295 | modeling (fully-amortized and semi-amortized) 296 | and learning (gradient-based, objective-based, or RL-based) 297 | aspects that I have found useful and have not seen 298 | emphasized as much in the literature. 299 | The other tutorials and reviews on 300 | amortized optimization, learning to optimize, and 301 | meta-learning over continuous domains 302 | that I am aware of are excellent resources: 303 | 304 | \begin{itemize} 305 | \item \citet{chen2021learning} captures many other emerging areas 306 | of learning to optimize and discusses many other modeling paradigms 307 | and optimization methods for learning to optimize, such as 308 | plug-and-play methods \citep{venkatakrishnan2013plug,meinhardt2017learning,rick2017one,zhang2017learning}.
309 | They emphasize the key aspects and questions to tackle as a community, 310 | including model capacity, trainability, generalization, and 311 | interpretability. 312 | They propose \emph{Open-L2O} as a new benchmark for 313 | learning to optimize and review many other applications, 314 | including sparse and low-rank regression, graphical models, 315 | differential equations, quadratic optimization, inverse problems, 316 | constrained optimization, image restoration and reconstruction, 317 | medical and biological imaging, wireless communications, 318 | and seismic imaging. 319 | \item \citet{shu2017amortized} is a blog post that discusses 320 | fully-amortized models with gradient-based learning 321 | and includes applications in variational inference, 322 | meta-learning, image style transfer, 323 | and survival-based classification. 324 | \item \citet{weng2018metalearning} is a blog post 325 | with an introduction and review of meta-learning methods. 326 | After defining the problem setup, the review discusses 327 | metric-based, model-based, and optimization-based approaches, 328 | and covers approximations to the second-order derivatives 329 | that come up in MAML. 330 | \item \citet{hospedales2020meta} is a review focused on meta-learning, 331 | where they categorize meta-learning components into a 332 | meta-representation, meta-optimizer, and meta-objective. 333 | The most relevant connections to amortization here are that 334 | the meta-representation can instantiate an 335 | amortized optimization problem that is solved with the 336 | meta-optimizer. 337 | \item \citet{kim2020deep} is a dissertation on deep 338 | latent variable models for natural language 339 | and contextualizes and studies the use of amortization and 340 | semi-amortization in this setting.
341 | \item \citet{marino2021learned} is a dissertation on learned 342 | feedback and feedforward information for perception and control 343 | and contextualizes and studies the use of amortization and 344 | semi-amortization in these settings. 345 | \item \citet{monga2021algorithm} is a review on 346 | algorithm unrolling that starts with the unrolling 347 | in LISTA \citep{gregor2010learning} for amortized 348 | sparse coding, and then connects to other methods 349 | of unrolling specialized algorithms. 350 | While some unrolling methods have applications in 351 | semi-amortized models, this review also considers 352 | applications and use-cases beyond just 353 | amortized optimization. 354 | \item \citet{banert2020data} consider theoretical foundations 355 | for data-driven nonsmooth optimization and show applications 356 | in deblurring and solving inverse problems for 357 | computed tomography. 358 | \item \citet{liu2022teaching} study fully-amortized 359 | models based on deep sets \citep{zaheer2017deep} 360 | and set transformers \citep{lee2019set}. 361 | They consider regression- and objective-based losses 362 | for regression, PCA, core-set creation, and 363 | supply management for cyber-physical systems. 364 | \item \citet{vanhentenryck2025optimizationlearning} presents an overview 365 | of learned optimization methods arising in power systems, 366 | for real-time risk assessment and security-constrained optimal power flow. 367 | \end{itemize} 368 | 369 | \subsection{Amortized optimization over discrete domains} 370 | A significant generalization of \cref{eq:opt} is to optimization 371 | problems that have discrete domains, 372 | which includes combinatorial optimization 373 | and mixed discrete-continuous optimization. 
374 | I have chosen to not include these works in this tutorial 375 | because many methods for discrete optimization are significantly 376 | different from the methods considered here, as learning with 377 | derivative information often becomes impossible. 378 | Key works in discrete and combinatorial spaces include 379 | \citet{khalil2016learning,dai2017learning,jeong2019learning,bertsimas2019online,shao2021learning,bertsimas2021voice,cappart2021combinatorial}, 380 | and the surveys 381 | \citep{lodi2017learning,bengio2021machine,kotary2021end} 382 | capture a much broader view of this space. 383 | \citet{banerjee2015efficiently} consider repeated ILP solves 384 | and show applications in aircraft carrier deck scheduling and vehicle routing. 385 | For architecture search, \citet{luo2018neural} learn a continuous 386 | latent space behind the discrete architecture space. 387 | Many reinforcement learning and control methods over discrete 388 | spaces can also be seen as amortizing or semi-amortizing the 389 | discrete control problems, for example: 390 | \citet{cauligi2020learning,cauligi2021coco} use regression-based 391 | amortization to learn mixed-integer control policies. 392 | \citet{fickinger2021scalable} fine-tune the policy 393 | optimizer for every encountered state. 394 | \citet{tennenholtz2019natural,chandak2019learning,van2020q} 395 | learn latent action spaces for high-dimensional 396 | discrete action spaces with shared structure. 397 | 398 | \subsection{Learning-augmented and amortized algorithms beyond optimization} 399 | While many algorithms can be interpreted as solving 400 | optimization problems or fixed-point computations and 401 | can therefore be improved with amortized optimization, 402 | it is also fruitful to use learning to improve 403 | algorithms that have nothing to do with optimization.
404 | Some key starting references in this space include 405 | data-driven algorithm design \citep{balcan2020data}, 406 | algorithms with predictions 407 | \citep{dinitz2021faster,sakaue2022discrete,chen2022faster,khodak2022learning}, 408 | learning to prune \citep{alabi2019learning}, 409 | learning solutions to differential equations 410 | \citep{li2020fourier,poli2020hypersolvers,karniadakis2021physics,kovachki2021universal,chen2021solving,blechschmidt2021three,marwah2021parametric,berto2021neural}, 411 | learning simulators for physics \citep{grzeszczuk1998neuroanimator,ladicky2015data,he2019learning,sanchez2020learning,wiewel2019latent,usman2021machine,vinuesa2021potential}, 412 | and learning for symbolic math 413 | \citep{lample2019deep,charton2021linear,charton2021deep,drori2021neural,dascoli2022deep}. 414 | \citet{salimans2022progressive} progressively amortize a 415 | sampling process for diffusion models. 416 | \citet{schwarzschild2021can} learn recurrent neural networks 417 | to solve algorithmic problems for prefix sum, mazes, and chess. 418 | 419 | \subsection{Continuation and homotopy methods} 420 | Amortized optimization settings share a similar motivation to 421 | continuation and homotopy methods that have been studied for 422 | over four decades 423 | \citep{richter1983continuation,watson1989modern,allgower2012numerical}. 424 | These methods usually set the context space to be the 425 | interval $\gX=[0,1]$ and simultaneously solve (without learning) 426 | problems along this line. 427 | This similarity indicates that problem classes typically 428 | studied by continuation and homotopy methods could also benefit 429 | from the shared amortization models here.
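To make the warm-starting idea concrete, the following is a minimal framework-free sketch of a continuation method on a toy quadratic family (the objective here is illustrative and not from any of the cited works): the problems $f(\cdot\,; x)$ are solved sequentially for $x$ on a grid of $[0,1]$, with each solve initialized at the previous solution.

```python
# Toy continuation sketch: solve a family of problems f(y; x) for
# contexts x on a grid of [0, 1], warm-starting each solve from the
# previous solution. The quadratic family below is illustrative only.

def grad_f(y, x):
    # Gradient in y of f(y; x) = (1 - x) (y - 1)^2 + x (y - 3)^2,
    # which interpolates between two quadratics as x moves over [0, 1].
    return 2 * (1 - x) * (y - 1) + 2 * x * (y - 3)

def solve(y0, x, lr=0.1, steps=200):
    # Plain gradient descent starting from the initial iterate y0.
    y = y0
    for _ in range(steps):
        y -= lr * grad_f(y, x)
    return y

solutions = {}
y = 0.0  # initial iterate for the first problem
for i in range(11):
    x = i / 10
    y = solve(y, x)  # warm-start from the previous context's solution
    solutions[x] = y
```

Amortized optimization instead replaces the sequential warm-starting with a learned model $\hat y_\theta(x)$ that predicts the solution for any $x$ directly.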
430 | 431 | %%% Local Variables: 432 | %%% coding: utf-8 433 | %%% mode: latex 434 | %%% TeX-master: "../amor-nowplain.tex" 435 | %%% LaTeX-biblatex-use-Biber: True 436 | %%% End: -------------------------------------------------------------------------------- /paper/chapters/4-implementation.tex: -------------------------------------------------------------------------------- 1 | \chapter{Implementation and software examples} 2 | \label{sec:implementation} 3 | 4 | \begin{table}[t] 5 | \centering 6 | \caption{Dimensions for the settings considered in this section} 7 | \label{tab:sizes} 8 | \resizebox{\textwidth}{!}{ 9 | \begin{tabular}{lll}\toprule 10 | Setting & Context dimension $|\gX|$ & Solution dimension $|\gY|$ \\ \midrule 11 | VAE on MNIST (\ref{sec:impl:vaes}) & $784$ {\color{gray} (=$28\cdot 28$, MNIST digits)} & $20$ {\color{gray}(parameterizing a 10D Gaussian)} \\ 12 | Model-free control (\ref{sec:impl:model-free}) & $45$ {\color{gray}(humanoid states)} & $17$ {\color{gray} (action dimension)} \\ 13 | Model-based control (\ref{sec:impl:model-based}) & $45$ {\color{gray}(humanoid states)} & $51$ {\color{gray} (=$17\cdot 3$, short action sequence)} \\ 14 | Sphere (\ref{sec:impl:sphere}) & $16$ {\color{gray}($c$-convex function parameterizations)} & $3$ {\color{gray} (sphere)} \\ \bottomrule 15 | \end{tabular}} 16 | \end{table} 17 | 18 | Turning now to the implementation details, this section 19 | looks at how to develop and analyze amortization software. 20 | The standard and easiest approach in most settings is to use 21 | automatic differentiation software such as 22 | \citet{maclaurin2015autograd,al2016theano,abadi2016tensorflow,bezanson2017julia,agrawal2019tensorflow,paszke2019pytorch,bradbury2020jax} 23 | to parameterize and learn the amortization model. 24 | There are many open source implementations and re-implementations 25 | of the methods in \cref{sec:apps} that provide a concrete 26 | starting point for building on them.
27 | This section looks closer at three specific implementations: 28 | \cref{sec:impl:eval} evaluates the amortization components 29 | behind existing implementations of variational autoencoders 30 | (\cref{sec:apps:avi}) and control (\cref{sec:apps:ctrl}), and 31 | \cref{sec:impl:sphere} implements and trains an amortization model 32 | to optimize functions defined on a sphere. 33 | \Cref{tab:sizes} summarizes the concrete dimensions of the amortization 34 | problems considered here and \cref{sec:impl:software} 35 | concludes with other useful software references. 36 | The source code behind this section is available at 37 | \url{https://github.com/facebookresearch/amortized-optimization-tutorial}. 38 | 39 | \section{Amortization in the wild: a deeper look} 40 | \label{sec:impl:eval} 41 | 42 | This section shows code examples of how existing implementations 43 | using amortized optimization define and optimize their models 44 | for variational autoencoders (\cref{sec:impl:vaes}) 45 | and control and policy learning (\cref{sec:impl:model-free,sec:impl:model-based}). 46 | The amortization component in these systems is often part 47 | of a larger system achieving a broader task: 48 | VAEs also reconstruct the source data 49 | after amortizing the ELBO computation in \cref{eq:vae-full} 50 | and policy learning methods also estimate the 51 | $Q$-value function in \cref{sec:Q-learning}. 52 | This section scopes to the amortization components to show 53 | how they are implemented. 54 | I have also added evaluation code to the pre-trained amortization models 55 | from existing repositories and show that the amortized approximation 56 | often obtains a solution up to \textbf{25000 times} faster than 57 | solving the optimization problems from scratch on an 58 | NVIDIA Quadro GP100 GPU.
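The runtime numbers reported in this section come from timing the amortized forward pass against an iterative solver on the same batch of contexts. A minimal framework-free sketch of such a benchmarking harness (the two callables at the bottom are hypothetical stand-ins, not from the repository):

```python
import time

def benchmark(fn, *args, warmup=3, repeats=10):
    # Average wall-clock runtime of fn(*args), discarding warmup runs.
    # When timing GPU code, also synchronize the device before reading
    # the clock (e.g., torch.cuda.synchronize() in PyTorch).
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats

# Hypothetical stand-ins for an amortized model and a from-scratch solver:
amortized = lambda x: x
from_scratch = lambda x: [x for _ in range(1000)][-1]
speedup = benchmark(from_scratch, 1.0) / benchmark(amortized, 1.0)
```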
59 | 60 | \subsection{The variational autoencoder (VAE)} 61 | \label{sec:impl:vaes} 62 | 63 | This section looks at the code behind a 64 | standard VAE \citep{kingma2013auto} that follows the 65 | amortized optimization setup described in \cref{sec:apps:vae}. 66 | While there are many implementations for training and 67 | reproducing a VAE, this section will use the implementation at 68 | \url{https://github.com/YannDubs/disentangling-vae}, 69 | which builds on the code behind 70 | \citet{dupont2018learning} at 71 | \url{https://github.com/Schlumberger/joint-vae}. 72 | While the repository is focused on disentangled representations 73 | and extensions of the original VAE formulation, this 74 | section only highlights the parts corresponding to the 75 | original VAE formulation. 76 | The code uses standard PyTorch in a minimal way that 77 | allows us to easily look at the amortization components. 78 | 79 | \textbf{Training the VAE.} 80 | \Cref{lst:vae} paraphrases the relevant snippets of code to 81 | implement the main amortization problem in \cref{eq:vae-amor} 82 | for image data where the likelihood is given by a Bernoulli. 83 | \Cref{lst:vae.encoder} defines an encoder $\hat \lambda_\theta$ 84 | that predicts a solution to the ELBO implemented in \cref{lst:vae.elbo}, 85 | which is optimized in a loop over the training data (images) 86 | in \cref{lst:vae.loop}. 87 | The \path{README} in the repository contains instructions 88 | for running the training from scratch. 89 | The repository contains the binary of a model trained 90 | on the MNIST dataset \citep{lecun1998mnist}, which 91 | is evaluated next.
92 | 93 | \begin{figure}[t] 94 | \centering 95 | \vspace{-4mm} 96 | \includegraphics[width=0.49\textwidth]{fig/vae-iter.pdf} 97 | \hfill 98 | \includegraphics[width=0.49\textwidth]{fig/vae-time.pdf} \\ 99 | \cblock{0}{0}{0} Amortized encoder $\hat\lambda_\theta(x)$ --- runtime: 0.4ms 100 | \caption{ 101 | Runtime comparison between Adam and an amortized encoder $\hat\lambda_\theta$ 102 | to solve \cref{eq:elbo-opt} for a VAE on MNIST. 103 | This uses a batch of 1024 samples and was 104 | run on an unloaded NVIDIA Quadro GP100 GPU. 105 | The values are normalized so that $\lambda(x)=0$ takes a value of -1 and 106 | the optimal $\lambda^\star$ takes a value of 0. 107 | The amortized policy is approximately 108 | \textbf{25000} times faster than solving the 109 | problem from scratch. 110 | } 111 | \label{fig:vae-performance} 112 | \end{figure} 113 | 114 | \begin{figure} 115 | \centering 116 | \begin{subfigure}[b]{\textwidth} 117 | \begin{lstlisting} 118 | class Encoder(nn.Module): # From disvae.models.encoders 119 | def forward(self, x): # x is the amortization context: the original data 120 | mu_logvar = self.convnet(x) 121 | mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1) # Split 122 | return (mu, logvar) # = latent_dist or \lambda 123 | \end{lstlisting} 124 | \caption{Forward definition for the encoder $\hat \lambda_\theta(x)$. 
125 | \path{self.convnet} uses the architecture 126 | from \citet{burgess2018understanding}.} 127 | \label{lst:vae.encoder} 128 | \end{subfigure} 129 | \begin{subfigure}[b]{\textwidth} 130 | \begin{lstlisting} 131 | # From disvae.models.losses.BetaHLoss with a Bernoulli likelihood 132 | def estimate_elbo(data, latent_dist): 133 | mean, logvar = latent_dist 134 | 135 | reconstructed_batch = sample_and_decode(latent_dist) 136 | log_likelihood = -F.binary_cross_entropy( 137 | reconstructed_batch, data, reduce=False).sum(dim=[1,2,3]) 138 | 139 | # Closed-form distance to the prior 140 | latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()) 141 | kl_to_prior = latent_kl.sum(dim=[-1]) 142 | 143 | elbo = log_likelihood - kl_to_prior 144 | return elbo.mean() 145 | \end{lstlisting} 146 | \caption{Definition of the $\ELBO$ in \cref{eq:elbo}} 147 | \label{lst:vae.elbo} 148 | \end{subfigure} 149 | 150 | \begin{subfigure}[b]{\textwidth} 151 | \begin{lstlisting} 152 | model = Encoder() 153 | for batch in iter(data_loader): 154 | latent_dist = model(batch) 155 | loss = -estimate_elbo(batch, latent_dist) 156 | self.optimizer.zero_grad() 157 | loss.backward() 158 | self.optimizer.step() 159 | \end{lstlisting} 160 | \caption{Main VAE training loop for the encoder} 161 | \label{lst:vae.loop} 162 | \end{subfigure} 163 | 164 | \caption{Paraphrased PyTorch code examples of the key 165 | amortization components of a VAE from 166 | \url{https://github.com/YannDubs/disentangling-vae}.} 167 | \label{lst:vae} 168 | \end{figure} 169 | 170 | \textbf{Evaluating the VAE.} 171 | This section looks at how well the amortized encoder 172 | $\hat\lambda$ approximates the optimal 173 | encoder $\lambda^\star$ given by explicitly solving 174 | \cref{eq:elbo-opt}; the difference between them is 175 | referred to as the \emph{amortization gap} \citep{cremer2018inference}. 176 | \cref{eq:elbo-opt} can be solved with a gradient-based optimizer 177 | such as SGD or Adam \citep{kingma2014adam}.
178 | \Cref{lst:evaluation} shows the key parts of the PyTorch code 179 | for making this comparison, which can be run on the pre-trained 180 | MNIST VAE with 181 | \path{code/evaluate_amortization_speed_vae.py}. 182 | 183 | \Cref{fig:vae-performance} shows that the 184 | VAE's encoder predicts the solution to the ELBO \textbf{25000} 185 | times faster (!) than running 2k iterations of Adam on 186 | a batch of 1024 samples. 187 | This is significant as \emph{every training iteration} of 188 | the VAE requires solving \cref{eq:elbo-opt}, and a large model 189 | may need millions of training iterations to converge. 190 | Amortizing the solution makes the difference between the training 191 | code running in a few hours rather than the few months it would 192 | take if the problem were solved from scratch to the same level of optimality. 193 | Knowing only the ELBO values is not sufficient to gauge 194 | the quality of approximate variational distributions. 195 | To help understand the quality of the approximate solutions, 196 | \cref{fig:vae-samples} shows the decoded samples 197 | alongside the original data.
198 | 199 | \begin{figure}[t] 200 | \centering 201 | \resizebox{\textwidth}{!}{ 202 | \hspace*{-4mm} 203 | \begin{tikzpicture} 204 | \node[align=left,anchor=north west] {\includegraphics[width=430pt]{fig/vae-samples.png}}; 205 | \node[align=right,anchor=east] at (2mm,-6mm) (0) {0}; 206 | \node[align=right,below=10.3mm of 0.east,anchor=east] (250) {${\rm Adam}_{250}$}; 207 | \node[align=right,below=10.3mm of 250.east,anchor=east] (500) {${\rm Adam}_{500}$}; 208 | \node[align=right,below=10.3mm of 500.east,anchor=east] (1000) {${\rm Adam}_{1000}$}; 209 | \node[align=right,below=10.3mm of 1000.east,anchor=east] (2000) {${\rm Adam}_{2000}$}; 210 | \node[align=right,below=10.3mm of 2000.east,anchor=east] (amor) {$\hat\lambda_\theta(x)$}; 211 | \node[align=right,below=9.5mm of amor.east,anchor=east] (data) {Data}; 212 | \end{tikzpicture}} 213 | \caption{Decoded reconstructions of the variational distribution optimizing for the ELBO. 214 | ${\rm Adam}_n$ corresponds to the distribution from running Adam for $n$ 215 | iterations, $\hat \lambda_\theta$ is the amortized approximation, 216 | and the ground-truth data, \ie the context, is shown in the bottom row. 217 | } 218 | 219 | \label{fig:vae-samples} 220 | \end{figure} 221 | 222 | \begin{figure}[t] 223 | \begin{lstlisting} 224 | # amortization_model: maps contexts to a solution 225 | # amortization_objective: maps an iterate and contexts to the objective 226 | 227 | adam_lr, num_iterations = ... 
228 | contexts = sample_contexts() 229 | 230 | # Predict the solutions with the amortization model 231 | predicted_solutions = amortization_model(contexts) 232 | amortized_objectives = amortization_objective( 233 | predicted_solutions, contexts 234 | ) 235 | 236 | # Use Adam (or another torch optimizer) to solve for the solutions 237 | iterates = torch.nn.Parameter(torch.zeros_like(predicted_solutions)) 238 | opt = torch.optim.Adam([iterates], lr=adam_lr) 239 | 240 | for i in range(num_iterations): 241 | objectives = amortization_objective(iterates, contexts) 242 | opt.zero_grad() 243 | objectives.sum().backward() # the problems in the batch are independent 244 | opt.step() 245 | \end{lstlisting} 246 | \caption{Evaluation code for comparing the amortized prediction $\hat y$ 247 | to the true solution $y^\star$ solving \cref{eq:opt} with a 248 | gradient-based optimizer. 249 | The full instrumented version of this code is available in 250 | the repository associated with this tutorial at 251 | \texttt{\detokenize{code/evaluate_amortization_speed_function.py}}. 252 | } 253 | \label{lst:evaluation} 254 | \end{figure} 255 | 256 | \subsection{Control with a model-free value estimate} 257 | \label{sec:impl:model-free} 258 | 259 | \begin{figure}[t] 260 | \centering 261 | \includegraphics[width=0.49\textwidth]{fig/control-model-free-iter.pdf} 262 | \hfill 263 | \includegraphics[width=0.49\textwidth]{fig/control-model-free-time.pdf} \\[5mm] 264 | \cblock{0}{0}{0} Policy $\pi(x)$ --- runtime: 0.65ms 265 | \caption{ 266 | Runtime comparison between Adam and a learned policy $\pi_\theta$ 267 | to solve \cref{eq:Q-opt} on the humanoid MDP. 268 | This was evaluated as a batch on an expert trajectory with 269 | 1000 states and was run on an unloaded NVIDIA Quadro GP100 GPU. 270 | The values are normalized so that $\pi(x)=0$ takes a value of -1 and 271 | the optimal $\pi^\star$ takes a value of 0. 272 | The amortized policy is approximately 273 | 1000 times faster than solving the 274 | problem from scratch.
275 | } 276 | \label{fig:model-free-performance} 277 | \end{figure} 278 | 279 | 280 | This section dives into the training and evaluation code for learning a 281 | deterministic model-free policy $\pi_\theta: \gX\rightarrow\gY$ to amortize a model-free 282 | value estimate $Q$ for controlling the humanoid MDP from 283 | \citet{brockman2016openai} 284 | visualized in \cref{fig:overview}. 285 | This MDP has $|\gX|=45$ states (angular positions and velocities 286 | describing the state of the system) 287 | and $|\gY|=17$ actions (torques to apply to the joints). 288 | A model-free policy $\pi$ maps the state to the optimal actions 289 | that maximize the value of the system. 290 | Given a known action-conditional value estimate $Q(x,u)$, 291 | the optimal policy $\pi^\star$ solves the optimization problem in 292 | \cref{eq:Q-opt} that the learned policy $\pi$ tries to match, 293 | \eg using policy gradient in \cref{eq:dpg-loss}. 294 | 295 | The codebase behind \citet{amos2021model} at 296 | \url{https://github.com/facebookresearch/svg} 297 | contains a trained model-free policy and value estimate 298 | on the humanoid in addition to model-based components 299 | the next section will use. 300 | The full training code there involves parameterizing 301 | a stochastic policy and estimating many additional 302 | model-based components, but the basic training loop 303 | for amortizing a deterministic policy 304 | can be distilled into a form similar to 305 | \cref{lst:vae}. 306 | 307 | This section mostly focuses on evaluating the performance 308 | of the trained model-free policy in comparison 309 | to maximizing the model-free value estimate in 310 | \cref{eq:Q-opt} from scratch for every state encountered.
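Distilled to its core, this amortization trains $\pi_\theta$ by ascending $\nabla_\theta Q(x, \pi_\theta(x))$ over sampled states, as in \cref{eq:dpg-loss}. The following toy NumPy sketch, with a hypothetical quadratic $Q$ and a linear policy (much simpler than the humanoid setup), writes this loop out with the gradient computed by hand:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 2                          # toy state and action dimensions
K_star = rng.normal(size=(m, n))     # Q is maximized at u = K_star @ x

def Q(x, u):
    # Hypothetical action-conditional value estimate.
    return -np.sum((u - K_star @ x) ** 2)

W = np.zeros((m, n))                 # linear policy pi(x) = W @ x
lr = 0.02
for _ in range(1000):
    x = rng.normal(size=n)           # sample a state (the context)
    u = W @ x
    # Gradient ascent on Q(x, pi(x)): dQ/du = -2 (u - K_star @ x),
    # and the chain rule gives dQ/dW = outer(dQ/du, x).
    grad_u = -2.0 * (u - K_star @ x)
    W += lr * np.outer(grad_u, x)
```

The policy converges to the maximizer $W \approx K^\star$ without ever being given the optimal actions, mirroring the objective-based loss used for the deterministic policy.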
311 | An exhaustive evaluation of a solver for \cref{eq:Q-opt} 312 | would need to ensure that the solution is not overly 313 | adapted to a bad part of the $Q$ estimate space --- 314 | because $Q$ is also a neural network susceptible to 315 | adversarial examples, directly 316 | optimizing \cref{eq:Q-opt} may result in a policy that looks 317 | deceptively good under the $Q$ estimate but does not 318 | work well on the real system. 319 | For simplicity, this section ignores these 320 | issues and normalizes the values to $[-1,0]$ where 321 | $-1$ will correspond to the value from taking a zero 322 | action and $0$ will correspond to the value 323 | from taking the expert's action. 324 | (This is valid in this example because the zero 325 | action and expert action never coincide.) 326 | 327 | \Cref{fig:model-free-performance} shows that the 328 | amortized policy is approximately 1000 times faster 329 | than solving the problem from scratch. 330 | The $Q$ values presented there are normalized and clamped 331 | so that the expert policy has a value of zero and 332 | the zero action has a value of -1. 333 | This example can be run with 334 | \path{code/evaluate_amortization_speed_control.py}, 335 | which shares the evaluation code also used for the VAE 336 | in \cref{lst:evaluation}. 337 | 338 | \subsection{Control with a model-based value estimate} 339 | \label{sec:impl:model-based} 340 | 341 | Extending the results of \cref{sec:impl:model-free}, 342 | this section compares the trained humanoid policy 343 | from \url{https://github.com/facebookresearch/svg} 344 | to solving a short-horizon ($H=3$) model-based control 345 | optimization problem defined in \cref{eq:mpc}. 346 | The optimal action sequence $u^\star_{1:H}$ solving \cref{eq:mpc} 347 | can be approximated by interleaving 348 | a model-free policy $\pi_\theta$ with the dynamics $f$.
349 | While standard model predictive control methods are 350 | often ideal for solving for $u^\star_{1:H}$ from scratch, 351 | using Adam as a gradient-based shooting method is 352 | a reasonable baseline in this short-horizon setting. 353 | 354 | \Cref{fig:model-based-performance} shows that the 355 | amortized policy is approximately 700 times faster 356 | than solving the problem from scratch. 357 | This model-based setting has the same issues with 358 | the approximation errors in the models, and the 359 | model-based value estimate is again 360 | normalized and clamped so that the expert 361 | policy has a value of zero and the zero action has a value of -1. 362 | The source code behind this example is also available in 363 | \path{code/evaluate_amortization_speed_control.py}. 364 | 365 | \begin{figure}[t] 366 | \centering 367 | \includegraphics[width=0.49\textwidth]{fig/control-model-based-iter.pdf} 368 | \hfill 369 | \includegraphics[width=0.49\textwidth]{fig/control-model-based-time.pdf} \\[5mm] 370 | \cblock{0}{0}{0} Policy $\pi(x)$ --- runtime: 5.8ms 371 | \caption{Runtime comparison between Adam and a learned policy $\pi_\theta$ 372 | to solve a short-horizon ($H=3$) model-based control 373 | problem (\cref{eq:mpc}) on the humanoid MDP. 374 | This was evaluated as a batch on an expert trajectory with 375 | 1000 states and was run on an unloaded NVIDIA Quadro GP100 GPU. 376 | The amortized policy is approximately 377 | 700 times faster than solving the 378 | problem from scratch.
379 | } 380 | \label{fig:model-based-performance} 381 | \end{figure} 382 | 383 | \section{Training an amortization model on a sphere} 384 | \label{sec:impl:sphere} 385 | 386 | This section contains a new demonstration that applies 387 | the insights from amortized optimization to learn to solve 388 | optimization problems over spheres of the form 389 | \begin{equation} 390 | y^\star(x) \in \argmin_{y\in\gS^2} f(y; x), 391 | \label{eq:sphere-opt-con} 392 | \end{equation} 393 | where $\gS^2$ is the surface of the \emph{unit 2-sphere} 394 | embedded in $\R^3$ as $\gS^2\defeq \{y\in\R^3 \mid \|y\|_2=1\}$ 395 | and $x$ is some parameterization of the function 396 | $f: \gS^2\times\gX\rightarrow \R$. 397 | \Cref{eq:sphere-opt-con} is relevant to physical and 398 | geographical settings seeking the extreme values of a 399 | function defined on the Earth or other spaces that can 400 | be approximated with a sphere. 401 | The full source code behind this experiment is available 402 | in \path{code/train-sphere.py}. 403 | 404 | \textbf{Amortization objective.} 405 | \Cref{eq:sphere-opt-con} first needs to be transformed 406 | from a constrained optimization problem into an unconstrained 407 | one of the form \cref{eq:opt}. 408 | In this setting, one way of doing this 409 | is by using a projection: 410 | \begin{equation} 411 | y^\star(x) \in \argmin_{y\in\R^3} f(\pi_{\gS^2}(y); x), 412 | \label{eq:sphere-opt-proj} 413 | \end{equation} 414 | where $\pi_{\gS^2}: \R^3\rightarrow \gS^2$ is the 415 | Euclidean projection onto $\gS^2$, \ie, 416 | \begin{equation} 417 | \begin{aligned} 418 | \pi_{\gS^2}(x)\defeq& \argmin_{y\in\gS^2} \|y-x\|_2 \\ 419 | =& \;x/\|x\|_2. 
420 | \end{aligned} 421 | \label{eq:pi} 422 | \end{equation} 423 | 424 | \textbf{$c$-convex functions on the sphere.} 425 | A synthetic class of optimization problems defined 426 | on the sphere using the $c$-convex functions from 427 | \citet{cohen2021riemannian} can be instantiated with: 428 | \begin{equation} 429 | f(y; x) = {\textstyle \min_{\gamma}} \left\{\frac{1}{2} d(y,z_i)^2+\alpha_i\right\}_{i=1}^m 430 | \label{eq:rcpm} 431 | \end{equation} 432 | where $m$ components define the context 433 | $x=\{z_i\} \cup \{\alpha_i\}$ 434 | with $z_i\in\gS^2$ and $\alpha_i\in\R$, 435 | $d(x,y)\defeq \arccos(x^\top y)$ is the 436 | Riemannian distance on the sphere in the 437 | ambient Euclidean space, and 438 | $\min_\gamma(a_1,\ldots,a_m)\defeq -\gamma\log\sum_{i=1}^m\exp(-a_i/\gamma)$ 439 | is a soft minimization operator 440 | as proposed in \citet{cuturi2017soft}. 441 | The context distribution $p(x)$ is sampled 442 | with $z_i\sim \gU(\gS^2)$, \ie uniformly from the sphere, 443 | and $\alpha_i\sim\gN(0,\beta)$ 444 | with variance $\beta\in\R_+$. 445 | 446 | \textbf{Amortization model.} 447 | The model $\hat y_\theta: \gX\rightarrow\R^3$ 448 | is a fully-connected MLP. 449 | Predictions on the sphere for \cref{eq:sphere-opt-con} 450 | can again be obtained by composing 451 | the output with the projection 452 | $\pi_{\gS^2}\circ \hat y_\theta$. 453 | 454 | \textbf{Optimizing the gradient-based loss.} 455 | Finally, it is reasonable to optimize the 456 | gradient-based loss $\gL_{\rm obj}$ because 457 | the objective and model are tractable and 458 | easily differentiable. 459 | \Cref{fig:sphere} shows the model's predictions 460 | starting with the untrained model and finishing 461 | with the trained model, showing that this setup 462 | indeed enables us to predict the solutions to 463 | \cref{eq:sphere-opt-con} with a single neural network 464 | $\hat y_\theta(x)$ trained with the gradient-based loss.
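As a concrete reference, here is a framework-free NumPy sketch of the projection in \cref{eq:pi} and the soft-minimum objective above, using the squared-distance cost $\frac{1}{2}d^2(y,z_i)+\alpha_i$ from the $c$-convex construction (\path{code/train-sphere.py} implements the full PyTorch training loop; the helper names here are illustrative):

```python
import numpy as np

def proj_sphere(y):
    # Euclidean projection onto the unit sphere: pi(y) = y / ||y||_2.
    return y / np.linalg.norm(y, axis=-1, keepdims=True)

def softmin(a, gamma):
    # min_gamma(a_1, ..., a_m) = -gamma log sum_i exp(-a_i / gamma)
    return -gamma * np.log(np.exp(-a / gamma).sum(-1))

def f(y, z, alpha, gamma=0.1):
    # Objective on the sphere: soft minimum over 0.5 d(y, z_i)^2 + alpha_i,
    # with d(u, v) = arccos(u^T v) the great-circle distance.
    d = np.arccos(np.clip(y @ z.T, -1.0, 1.0))
    return softmin(0.5 * d**2 + alpha, gamma)

rng = np.random.default_rng(0)
m = 4
z = proj_sphere(rng.normal(size=(m, 3)))  # z_i ~ U(S^2)
alpha = rng.normal(size=m)                # alpha_i ~ N(0, beta), beta = 1
y_hat = rng.normal(size=3)                # unconstrained model output in R^3
value = f(proj_sphere(y_hat), z, alpha)   # objective of the projected prediction
```

Since every operation here is differentiable almost everywhere, the same construction in PyTorch lets $\gL_{\rm obj}$ be minimized directly by backpropagating through $f\circ\pi_{\gS^2}$.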
465 | 466 | \textbf{Summary.} 467 | $\gA_{\rm sphere}\defeq (f\circ \pi_{\gS^2}, \R^3, \gX, p(x), \hat y_\theta, \gL_{\rm obj})$ 468 | 469 | \begin{figure} 470 | \includegraphics[width=0.25\textwidth]{fig/sphere/0.png} 471 | \hspace{-2.7mm} 472 | \includegraphics[width=0.25\textwidth]{fig/sphere/1.png} 473 | \hspace{-2.7mm} 474 | \includegraphics[width=0.25\textwidth]{fig/sphere/2.png} 475 | \hspace{-2.7mm} 476 | \includegraphics[width=0.25\textwidth]{fig/sphere/3.png} \\[-.8mm] 477 | \includegraphics[width=0.25\textwidth]{fig/sphere/4.png} 478 | \hspace{-2.7mm} 479 | \includegraphics[width=0.25\textwidth]{fig/sphere/5.png} 480 | \hspace{-2.7mm} 481 | \includegraphics[width=0.25\textwidth]{fig/sphere/6.png} 482 | \hspace{-2.7mm} 483 | \includegraphics[width=0.25\textwidth]{fig/sphere/7.png} \\[-6mm] 484 | \begin{center} 485 | \cblock{73}{19}{134} $f(y; x)$ contours \; 486 | \cblock{170}{0}{0} Optimal $y^\star(x)$ \; 487 | \cblock{84}{153}{255} Predictions $\hat y_\theta(x)$ throughout training 488 | \end{center} 489 | \vspace{-3mm} 490 | \caption{Visualization of the predictions of an amortized 491 | optimization model predicting the solutions 492 | to optimization problems on the sphere.} 493 | \label{fig:sphere} 494 | \end{figure} 495 | 496 | \section{Other useful software packages} 497 | \label{sec:impl:software} 498 | 499 | Implementing semi-amortized models is usually more challenging 500 | than implementing fully-amortized models. Learning an optimization-based 501 | model that internally solves an optimization problem is 502 | not as widespread as learning a feedforward neural network.
503 | While most autodiff packages provide standalone features to implement 504 | unrolled gradient-based optimization, the following specialized 505 | packages provide crucial features that further enable the 506 | exploration of semi-amortized models: 507 | \begin{itemize} 508 | \item \href{https://github.com/cvxgrp/cvxpylayers}{cvxpylayers} 509 | \citep{agrawal2019differentiable} 510 | allows an optimization problem to be expressed in the 511 | high-level language \verb!CVXPY! \citep{diamond2016cvxpy} 512 | and exported to PyTorch, JAX, and TensorFlow 513 | as a differentiable optimization layer. 514 | \item \href{https://github.com/google/jaxopt}{jaxopt} 515 | \citep{blondel2021efficient} 516 | is a differentiable optimization library for JAX 517 | and implements many optimization settings and fixed-point 518 | computations along with their implicit derivatives. 519 | \item \href{https://github.com/facebookresearch/higher}{higher} 520 | \citep{grefenstette2019generalized} 521 | is a PyTorch library that adds differentiable higher-order 522 | optimization support with 523 | 1) monkey-patched functional \verb!torch.nn! modules, 524 | and 2) differentiable versions of \verb!torch.optim! 525 | optimizers such as Adam and SGD. 526 | This enables arbitrary torch modules and optimizers 527 | to be unrolled and used as a semi-amortized model. 528 | \item \href{https://github.com/metaopt/TorchOpt}{TorchOpt} 529 | provides a functional and differentiable optimizer in PyTorch 530 | and has higher performance than \verb!higher! in some cases. 531 | \item \href{https://github.com/pytorch/functorch}{functorch} 532 | \citep{functorch2021} is a PyTorch library providing 533 | composable function transforms for batching and 534 | derivative operations, and for creating functional 535 | versions of PyTorch modules that can be used in 536 | optimization algorithms.
537 | All of these operations may arise in the implementation 538 | of an amortized optimization method and can become computational 539 | bottlenecks if not efficiently implemented. 540 | \item \href{https://github.com/jump-dev/DiffOpt.jl}{DiffOpt.jl} 541 | provides differentiable optimization in Julia's JuMP 542 | \citep{DunningHuchetteLubin2017}. 543 | \item \href{https://github.com/tristandeleu/pytorch-meta}{Torchmeta} 544 | \citep{deleu2019torchmeta} and 545 | \href{http://learn2learn.net}{learn2learn} 546 | \citep{arnold2020learn2learn} 547 | are PyTorch libraries and collections of meta-learning 548 | algorithms that also focus on making data-loading 549 | and task definitions easy. 550 | \item \href{https://github.com/prolearner/hypertorch}{hypertorch} 551 | \citep{grazzi2020iteration} 552 | is a PyTorch package for computing hypergradients with a 553 | large focus on providing computationally efficient approximations 554 | to them. 555 | \end{itemize} 556 | 557 | %%% Local Variables: 558 | %%% coding: utf-8 559 | %%% mode: latex 560 | %%% TeX-master: "../amor.tex" 561 | %%% LaTeX-biblatex-use-Biber: True 562 | %%% End: --------------------------------------------------------------------------------