├── assets ├── mcmc_1d.png ├── mcmc_2d.png ├── mcmc_2dv2.png ├── mcmc_2d_contour.png ├── variational_2d.png └── nn_regression_mcmc.png ├── jax_bayes ├── variational │ ├── __init__.py │ ├── families.py │ └── variational_family.py ├── __init__.py ├── mcmc │ ├── __init__.py │ ├── utils.py │ ├── sampler.py │ └── sampler_fns.py └── utils.py ├── requirements.txt ├── examples ├── examples_requirements.txt ├── README.md ├── deep │ ├── nn_regression │ │ ├── mlp_regression.py │ │ ├── mlp_regression_mcmc.py │ │ └── mlp_regression_var.py │ ├── mnist │ │ └── mnist_mcmc.py │ ├── cifar10 │ │ ├── cifar10.ipynb │ │ └── cifar10_mcmc.ipynb │ └── nmt │ │ └── attention_nmt.ipynb └── shallow │ ├── mcmc_1d.py │ ├── variational_2d.py │ └── mcmc_2d.py ├── LICENSE ├── .gitignore ├── setup.py └── README.md /assets/mcmc_1d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jamesvuc/jax-bayes/HEAD/assets/mcmc_1d.png -------------------------------------------------------------------------------- /assets/mcmc_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jamesvuc/jax-bayes/HEAD/assets/mcmc_2d.png -------------------------------------------------------------------------------- /assets/mcmc_2dv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jamesvuc/jax-bayes/HEAD/assets/mcmc_2dv2.png -------------------------------------------------------------------------------- /assets/mcmc_2d_contour.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jamesvuc/jax-bayes/HEAD/assets/mcmc_2d_contour.png -------------------------------------------------------------------------------- /assets/variational_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jamesvuc/jax-bayes/HEAD/assets/variational_2d.png -------------------------------------------------------------------------------- /assets/nn_regression_mcmc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jamesvuc/jax-bayes/HEAD/assets/nn_regression_mcmc.png -------------------------------------------------------------------------------- /jax_bayes/variational/__init__.py: -------------------------------------------------------------------------------- 1 | from .families import diagonal_mvn_fns 2 | 3 | from .variational_family import variational_family -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.9.0 2 | numpy>=1.18.0 3 | opt-einsum>=3.3.0 4 | protobuf>=3.12.4 5 | scipy>=1.5.2 6 | six>=1.15.0 7 | tqdm>=4.48.2 8 | -------------------------------------------------------------------------------- /examples/examples_requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | dm-haiku==0.0.2 3 | kiwisolver==1.2.0 4 | matplotlib==3.3.0 5 | pandas==1.1.0 6 | Pillow==7.2.0 7 | protobuf==3.12.4 8 | scipy==1.5.2 9 | seaborn==0.10.1 10 | six==1.15.0 11 | tqdm==4.48.2 12 | -------------------------------------------------------------------------------- /jax_bayes/__init__.py: -------------------------------------------------------------------------------- 1 | """ 
jax-bayes is a bayesian inference library for JAX """
 2 | 
 3 | from jax_bayes import mcmc
 4 | from jax_bayes import variational
 5 | from jax_bayes import utils
 6 | 
 7 | __version__ = "0.1.1"
 8 | 
 9 | __all__ = (
10 |     "mcmc",
11 |     "variational",
12 |     "utils"
13 | )
--------------------------------------------------------------------------------
/jax_bayes/mcmc/__init__.py:
--------------------------------------------------------------------------------
 1 | from .sampler_fns import langevin_fns
 2 | from .sampler_fns import mala_fns
 3 | from .sampler_fns import rk_langevin_fns
 4 | from .sampler_fns import hmc_fns
 5 | from .sampler_fns import rms_langevin_fns
 6 | from .sampler_fns import rms_mala_fns
 7 | from .sampler_fns import rwmh_fns
 8 | 
 9 | from .utils import blackbox_mcmc, init_distributions
10 | 
11 | from .sampler import sampler, SamplerState, SamplerKeys
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2020 James Vuckovic
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/jax_bayes/utils.py:
--------------------------------------------------------------------------------
 1 | import jax
 2 | import jax.numpy as jnp
 3 | 
 4 | def certainty_acc(pp, targets, cert_threshold=0.5):
 5 |     """ Calculates the accuracy-at-certainty from the predictive probabilities pp
 6 |     on the targets.
 7 | 
 8 |     Args:
 9 |         pp: (batch_size, n_classes) array of probabilities
10 |         targets: (batch_size,) array of label class indices
11 |         cert_threshold: (float) minimum probability for making a prediction
12 | 
13 |     Returns:
14 |         accuracy at certainty, indices of those prediction instances for which
15 |         the model is certain.
16 |     """
17 |     preds = jnp.argmax(pp, axis=1)
18 |     pred_probs = jnp.max(pp, axis=1)
19 | 
20 |     certain_idxs = pred_probs >= cert_threshold
21 |     acc_at_certainty = jnp.mean(targets[certain_idxs] == preds[certain_idxs])
22 | 
23 |     return acc_at_certainty, certain_idxs
24 | 
25 | @jax.jit
26 | @jax.vmap
27 | def entropy(p):
28 |     """ computes discrete Shannon entropy.
29 |         p: (n_classes,) array of probabilities corresponding to each class
30 |     """
31 |     p += 1e-12 #tolerance to avoid nans while ensuring 0log(0) = 0
32 |     return - jnp.sum(p * jnp.log(p))
33 | 
34 | def confidence_bands(y, sample_axis=-1):
35 |     """ Computes confidence bands for samples y.
36 | 
37 |     Args:
38 |         y: array of samples.
39 | 
40 |     Returns:
41 |         the lower and upper confidence bands (mean -/+ one standard deviation) along the axes not equal to sample_axis.
42 |     """
43 |     m = jnp.mean(y, axis=sample_axis)
44 |     s = jnp.std(y, axis=sample_axis)
45 |     return m - s, m + s
46 | 
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | # Examples
 3 | 
 4 | We provide 2 levels of examples:
 5 | 1. `examples/shallow`: These demonstrate the different samplers on 1- and 2-dimensional probability distributions (i.e. not neural networks).
 6 | 
 7 | 2. `examples/deep`: These showcase how jax-bayes can be used for deep Bayesian ML. The goal is to allow one to compare different inference techniques applied to some standard problems in ML:
 8 |     1. neural network regression
 9 |     2. MNIST
10 |     3. CIFAR10
11 |     4. Neural Machine Translation
12 | 
13 | Some of these are nontrivial to implement with current Bayesian methods.
14 | 
15 | *current status*:
16 | 
17 | example | optimization | MCMC | VI
18 | :--:|:--:|:--:|:--:
19 | nn regression | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
20 | MNIST | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
21 | CIFAR10 | :heavy_check_mark: | :heavy_check_mark:(broken)
22 | NMT | :heavy_check_mark:
23 | 
24 | 
25 | 
26 | ## Visualizations
27 | Here are visualizations from some of the examples:
28 | 
29 | ### `examples/shallow/mcmc_1d.py`
30 | ![](https://github.com/jamesvuc/jax-bayes/blob/master/assets/mcmc_1d.png "1d MCMC")
31 | 
32 | ### `examples/shallow/mcmc_2d.py`
33 | 
34 | ![](https://github.com/jamesvuc/jax-bayes/blob/master/assets/mcmc_2d.png "2d MCMC")
35 | 
36 | ### `examples/shallow/variational_2d.py`
37 | ![](https://github.com/jamesvuc/jax-bayes/blob/master/assets/variational_2d.png "2d variational")
38 | 
39 | ### `examples/deep/nn_regression/mlp_regression_mcmc.py`
40 | ![](https://github.com/jamesvuc/jax-bayes/blob/master/assets/nn_regression_mcmc.png "nn regression MCMC")
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | #  Usually these files are written by a python script from a template
30 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | coverage-report/ 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | venv_*/ 94 | 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | .DS_Store 110 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) James Vuckovic. All rights reserved. 2 | # Licensed under the MIT license. 3 | """Setup for pip package. 4 | 5 | Adapted from https://github.com/deepmind/dm-haiku/blob/master/setup.py 6 | """ 7 | 8 | from setuptools import find_namespace_packages 9 | from setuptools import setup 10 | 11 | 12 | def _get_version(): 13 | with open('jax_bayes/__init__.py') as fp: 14 | for line in fp: 15 | if line.startswith('__version__'): 16 | g = {} 17 | exec(line, g) # pylint: disable=exec-used 18 | return g['__version__'] 19 | raise ValueError('`__version__` not defined in `jax_bayes/__init__.py`') 20 | 21 | 22 | def _parse_requirements(requirements_txt_path): 23 | with open(requirements_txt_path) as fp: 24 | return fp.read().splitlines() 25 | 26 | 27 | _VERSION = _get_version() 28 | 29 | EXTRA_PACKAGES = { 30 | 'jax': ['jax>=0.1.74'], 31 | 'jaxlib': ['jaxlib>=0.1.51'], 32 | } 33 | 34 | setup( 35 | name='jax-bayes', 36 | version=_VERSION, 37 | url='https://github.com/jamesvuc/jax-bayes', 38 | license='MIT', 39 | author='James Vuckovic', 40 | description='jax-bayes is a bayesian inference library for JAX.', 41 | long_description=open('README.md').read(), 42 | long_description_content_type='text/markdown', 43 | author_email='james@jamesvuckovic.com', 44 | # Contained modules and scripts. 45 | packages=find_namespace_packages(exclude=['*_test.py']), 46 | install_requires=_parse_requirements('requirements.txt'), 47 | extras_require=EXTRA_PACKAGES, 48 | # tests_require=_parse_requirements('requirements-test.txt'), 49 | requires_python='>=3.6', 50 | include_package_data=True, 51 | zip_safe=False, 52 | # PyPI package information. 
53 |     classifiers=[
54 |         'Intended Audience :: Developers',
55 |         'Intended Audience :: Education',
56 |         'Intended Audience :: Science/Research',
57 |         'License :: OSI Approved :: MIT License',
58 |         'Programming Language :: Python :: 3',
59 |         'Programming Language :: Python :: 3.6',
60 |         'Programming Language :: Python :: 3.7',
61 |         'Topic :: Scientific/Engineering :: Mathematics',
62 |         'Topic :: Software Development :: Libraries :: Python Modules',
63 |         'Topic :: Software Development :: Libraries',
64 |     ],
65 | )
--------------------------------------------------------------------------------
/jax_bayes/mcmc/utils.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | from tqdm import tqdm
 3 | 
 4 | import jax
 5 | from jax import grad, value_and_grad
 6 | 
 7 | 
 8 | centered_uniform = \
 9 |     lambda *args, **kwargs: jax.random.uniform(*args, **kwargs) - 0.5
10 | init_distributions = dict(normal=jax.random.normal,
11 |                           uniform=centered_uniform)
12 | 
13 | #redefine elementwise_grad operation
14 | elementwise_grad = lambda f: jax.vmap(grad(f))
15 | elementwise_value_and_grad = lambda f: jax.vmap(value_and_grad(f))
16 | 
17 | def blackbox_mcmc(
18 |     logprob,
19 |     x0,
20 |     sampler_fn,
21 |     num_iters=1000,
22 |     proposal_iters=1,
23 |     seed=None,
24 |     recompute_grad=False,
25 |     use_jit=True,
26 |     **sampler_args
27 | ):
28 |     """ A single-function black-box sampler abstracting the various pieces
29 |     of the functional sampler methodologies.
30 | 
31 |     Args:
32 |         logprob: a callable logprob(x) that returns the unnormalized
33 |             log probability of x.
34 |         x0: array of initial sample(s)
35 |         sampler_fn: a sampler_fn using the @sampler decorator (see sampler.py)
36 |         num_iters: number of iterations
37 |         proposal_iters: number of times to compute the proposal w/ new gradients
38 |         seed: seed for the keys
39 |         recompute_grad: boolean for whether to recompute the gradients at the
40 |             proposal (needed when the propose/accept steps use them, e.g. MALA, HMC)
41 |         use_jit: boolean for jitting the update step (set False for debugging).
42 | 
43 |     Returns:
44 |         approximate samples according to logprob using the sampler_fn.
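
    Example (an illustrative sketch mirroring examples/shallow/mcmc_1d.py; the
    Gaussian logprob below is just a stand-in target, not part of this module):

        import jax.numpy as jnp
        from jax_bayes.mcmc import blackbox_mcmc, langevin_fns

        logprob = lambda z: -0.5 * z ** 2  # unnormalized standard-normal log-density
        samples = blackbox_mcmc(
            logprob, jnp.array(0.0), langevin_fns,
            num_iters=1000, seed=0, num_samples=1000,
            step_size=1e-3, init_dist='normal', init_stddev=1.0,
        )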
45 | """ 46 | 47 | g = elementwise_value_and_grad(logprob) 48 | 49 | seed = int(time.time() * 1000) if seed is None else seed 50 | init_key = jax.random.PRNGKey(seed) 51 | 52 | sampler = sampler_fn(init_key, **sampler_args) 53 | sampler_state, sampler_keys = sampler.init(x0) 54 | 55 | def _step(i, state, keys): 56 | x = sampler.get_params(state) 57 | fx, dx = g(x) 58 | 59 | prop_state, keys = sampler.propose(i, dx, state, keys) 60 | x_prop = sampler.get_params(prop_state) 61 | if recompute_grad: 62 | fx_prop, dx_prop = g(x_prop) 63 | for _ in range(max(proposal_iters-1, 0)): 64 | prop_state, keys = sampler.propose(i, dx_prop, prop_state, keys) 65 | x_prop = sampler.get_params(prop_state) 66 | fx_prop, dx_prop = g(x_prop) 67 | if proposal_iters > 1: 68 | prop_state, keys = sampler.propose(i, dx_prop, prop_state, keys, is_final=True) 69 | x_prop = sampler.get_params(prop_state) 70 | fx_prop, dx_prop = g(x_prop) 71 | else: 72 | fx_prop, dx_prop = fx, dx 73 | 74 | accept_idxs, keys = sampler.accept( 75 | i, fx, fx_prop, dx, state, dx_prop, prop_state, keys 76 | ) 77 | 78 | state, keys = sampler.update( 79 | i, accept_idxs, dx, state, dx, prop_state, keys 80 | ) 81 | return state, keys 82 | 83 | if use_jit: 84 | _step = jax.jit(_step) 85 | 86 | for i in tqdm(range(num_iters)): 87 | # if callback: callback(x, i, dx) 88 | sampler_state, sampler_keys = _step(i, sampler_state, sampler_keys) 89 | 90 | return sampler.get_params(sampler_state) 91 | -------------------------------------------------------------------------------- /examples/deep/nn_regression/mlp_regression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.random.seed(0) 3 | 4 | import haiku as hk 5 | 6 | import jax.numpy as jnp 7 | from jax.experimental import optimizers 8 | import jax 9 | 10 | from tqdm import tqdm, trange 11 | from matplotlib import pyplot as plt 12 | 13 | def build_dataset(): 14 | n_train, n_test, d = 200, 100, 1 15 | xlims = [-1.0, 5.0] 16 | x_train = np.random.rand(n_train, d) * (xlims[1] - xlims[0]) + xlims[0] 17 | x_test = np.random.rand(n_test, d) * (xlims[1] - xlims[0]) + xlims[0] 18 | 19 | target_func = lambda t: (np.log(t + 100.0) * np.sin(1.0 * np.pi*t)) + 0.1 * t 20 | 21 | y_train = target_func(x_train) 22 | y_test = target_func(x_test) 23 | 24 | y_train += np.random.randn(*x_train.shape) * (1.0 * (x_train + 2.0)**0.5) 25 | 26 | return (x_train, y_train), (x_test, y_test) 27 | 28 | 29 | def net_fn(x): 30 | 31 | mlp = hk.Sequential([ 32 | hk.Linear(128, w_init=hk.initializers.RandomNormal(stddev=5.0), 33 | b_init=hk.initializers.RandomNormal(stddev=5.0)), 34 | jnp.tanh, 35 | hk.Linear(1, w_init=hk.initializers.RandomNormal(stddev=5.0), 36 | b_init=hk.initializers.RandomNormal(stddev=5.0)) 37 | ]) 38 | 39 | return mlp(x) 40 | 41 | def main(): 42 | # ======= Setup ======= 43 | xy_train, xy_test = build_dataset() 44 | (x_train, y_train), (x_test, y_test) = xy_train, xy_test 45 | 46 | lr = 1e-3 47 | reg = 0.0 48 | lik_var = 0.5 49 | 50 | net = hk.transform(net_fn) 51 | opt_init, opt_update, opt_get_params = optimizers.sgd(lr) 52 | 53 | def logprob(params, xy): 54 | """ log posterior logP(params | xy), assuming 55 | P(params) ~ N(0,eta) 56 | P(y|x, params) ~ N(f(x;params), lik_var) 57 | """ 58 | x, y = xy 59 | preds = net.apply(params, None, x) 60 | log_prior = - reg * sum(jnp.sum(jnp.square(p)) 61 | for p in jax.tree_leaves(params)) 62 | log_lik = - jnp.mean(jnp.square(preds - y)) / lik_var 63 | return log_lik + log_prior 64 | 65 | #minimize the - 
logprob to find MAP 66 | loss = lambda params, xy: - logprob(params, xy) 67 | 68 | @jax.jit 69 | def train_step(i, opt_state, batch): 70 | params = opt_get_params(opt_state) 71 | dx = jax.grad(loss)(params, batch) 72 | opt_state = opt_update(i, dx, opt_state) 73 | return opt_state 74 | 75 | # ======= Training ====== 76 | 77 | # initialization 78 | params = net.init(jax.random.PRNGKey(42), x_train) 79 | opt_state = opt_init(params) 80 | 81 | #do the optimization 82 | for step in trange(2000): 83 | if step % 200 == 0: 84 | params = opt_get_params(opt_state) 85 | train_loss = loss(params, xy_train) 86 | test_acc = loss(params, xy_test) 87 | print(f"step = {step}" 88 | f" | train loss = {train_loss:.3f}" 89 | f" | test loss = {test_acc:.3f}") 90 | 91 | opt_state = train_step(step, opt_state, xy_train) 92 | 93 | params = opt_get_params(opt_state) 94 | 95 | # ========= Plotting ========= 96 | plot_inputs = np.linspace(-1, 10, num=600).reshape(-1,1) 97 | outputs = net.apply(params, None, plot_inputs) 98 | 99 | f, ax = plt.subplots(1) 100 | 101 | ax.plot(x_train.ravel(), y_train.ravel(), 'bx', color='green') 102 | ax.plot(x_test.ravel(), y_test.ravel(), 'bx', color='red') 103 | ax.plot(plot_inputs, outputs, alpha=1) 104 | 105 | plt.show() 106 | 107 | if __name__ == '__main__': 108 | main() -------------------------------------------------------------------------------- /jax_bayes/variational/families.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | import numpy as np 7 | 8 | from .variational_family import variational_family 9 | 10 | def elbo_reparam(logprob, samples, var_approx, var_params): 11 | return jnp.mean(logprob(samples) - var_approx(samples, var_params)) 12 | 13 | def gaussian_elbo_reparam(logprob, samples, var_params): 14 | return jnp.mean(logprob(samples)) + diag_mvn_entropy(var_params) 15 | 16 | def elbo_noscore(logprob, samples, var_approx, var_params): 17 | var_params = jax.lax.stop_gradient(var_params) 18 | return - jnp.mean(logprob(samples) - var_approx(samples, var_params)) 19 | 20 | def diag_mvn_logpdf(x, mean, diag_cov): 21 | """ Returns the log_pdf of x under a MVN with diagonal covariance without 22 | storing the full covariance for O(N) storage instead of O(N^2). 23 | """ 24 | n = mean.shape[-1] 25 | y = x - mean 26 | tmp = jnp.einsum('...i,i->...i', y, 1./diag_cov) 27 | return (-1/2 * jnp.einsum('...i,...i->...', y, tmp) 28 | - n/2 * np.log(2*np.pi) - jnp.log(diag_cov).sum()/2.) 29 | 30 | def diag_mvn_entropy(logcov): 31 | d = logcov.shape[0] 32 | return 0.5 * d * (1.0 + np.log(2*np.pi)) + jnp.sum(logcov) 33 | 34 | @variational_family 35 | def diagonal_mvn_fns(base_key, mean_stddev=1.0, init_sigma=1.0, eps=1e-6): 36 | """ Constructs functions for a VariationalFamily object using a 37 | diagonal multivariate normal (equivalent to mean-field) variational family. 38 | 39 | Args: 40 | base_key: jax.random.PRNGKey used to seed the randomness for the algorithm 41 | mean_stddev: standard deviation of mean parameter initialization 42 | init_sigma: starting standard deviation of the mean field parameters 43 | eps: tolerance for the log-covariance 44 | 45 | Returns: 46 | init, sample, evaluate, get_samples, next_key, entropy, base_key 47 | functions for the VariationalFamily (to be tree-ified in the decorator). 48 | Entropy can be a dummy function, the others are needed for VI. 
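
    Example (an illustrative sketch following examples/shallow/variational_2d.py;
    the 2-d zero vector and the sample count are placeholders):

        key = jax.random.PRNGKey(0)
        vf = diagonal_mvn_fns(key, mean_stddev=0.1)
        var_params, var_keys = vf.init(jnp.zeros(2))
        samples_state, _ = vf.sample(0, 100, var_keys, var_params)
        samples = vf.get_samples(samples_state)
        logq = vf.evaluate(samples_state, var_params)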
49 | """ 50 | def init(x0, key): 51 | next_key, key = jax.random.split(key) 52 | mean = jax.random.normal(key, x0.shape) * mean_stddev 53 | logcov = jnp.zeros_like(x0) + math.log(init_sigma) 54 | 55 | return (mean, logcov), next_key 56 | 57 | #TODO: Remove arg 'i' 58 | def sample(i, num_samples, key, params): 59 | """ sample from q( |params) """ 60 | key, next_key = jax.random.split(key) 61 | mean, logcov = params 62 | 63 | shape = (num_samples,) + mean.shape 64 | Z = jax.random.normal(key, shape) 65 | 66 | return Z * jnp.exp(logcov) + mean, next_key 67 | 68 | def evaluate(inputs, params): 69 | """ evaluate logq( |params) """ 70 | mean, logcov = params 71 | mean, logcov = mean.reshape(-1), logcov.reshape(-1) 72 | inputs = inputs.reshape(inputs.shape[0], -1) 73 | 74 | cov = jnp.exp(logcov) + eps 75 | return diag_mvn_logpdf(inputs, mean, cov) 76 | 77 | def get_samples(samples): 78 | return samples 79 | 80 | def next_key(key): 81 | _, new_key = jax.random.split(key) 82 | return new_key 83 | 84 | def entropy(params): 85 | _, logcov = params 86 | return diag_mvn_entropy(logcov) 87 | 88 | return init, sample, evaluate, get_samples, next_key, entropy, base_key 89 | -------------------------------------------------------------------------------- /examples/shallow/mcmc_1d.py: -------------------------------------------------------------------------------- 1 | #mcmc_1d.py 2 | import jax.numpy as jnp 3 | import jax.scipy.stats.norm as norm 4 | import jax 5 | 6 | from copy import copy, deepcopy 7 | import itertools, math 8 | import time 9 | from matplotlib import pyplot as plt 10 | import seaborn as sns 11 | 12 | from jax_bayes.mcmc import ( 13 | langevin_fns, 14 | mala_fns, 15 | rk_langevin_fns, 16 | hmc_fns, 17 | rms_langevin_fns, 18 | rwmh_fns 19 | ) 20 | from jax_bayes.mcmc import blackbox_mcmc as bb_mcmc 21 | 22 | @jax.jit 23 | def bimodal_logprob(z): 24 | return jnp.log(jnp.sin(z)**2) + jnp.log(jnp.sin(2*z)**2) + norm.logpdf(z) 25 | 26 | def main(): 27 | 28 | #====== Setup ======= 29 | 30 | n_iters, n_samples = 1000, 1000 31 | seed = 0 32 | init_vals = jnp.array(0.0) 33 | 34 | allsamps = [] 35 | logprob = bimodal_logprob 36 | 37 | #====== Tests ======= 38 | 39 | t = time.time() 40 | print('running 1d tests ...') 41 | samps = bb_mcmc( 42 | logprob, init_vals, langevin_fns, num_iters=n_iters, 43 | seed=seed, num_samples=n_samples, step_size=1e-3, 44 | init_dist='normal', init_stddev=1.0 45 | ) 46 | print('done langevin in', time.time()-t,'\n') 47 | allsamps.append(samps) 48 | 49 | t = time.time() 50 | samps = bb_mcmc( 51 | logprob, init_vals, mala_fns, num_iters=n_iters, 52 | seed=seed, num_samples=n_samples, step_size=1e-3, 53 | init_dist='normal', init_stddev=1.0, recompute_grad=True 54 | ) 55 | print('done MALA in', time.time()-t,'\n') 56 | allsamps.append(samps) 57 | 58 | t = time.time() 59 | samps = bb_mcmc( 60 | logprob, init_vals, rk_langevin_fns, num_iters=n_iters, 61 | seed=seed, num_samples=n_samples, step_size=1e-3, 62 | init_dist='normal', init_stddev=1.0, recompute_grad=True 63 | ) 64 | print('done langevin_RK in', time.time()-t,'\n') 65 | allsamps.append(samps) 66 | 67 | 68 | t = time.time() 69 | samps = bb_mcmc( 70 | logprob, init_vals, hmc_fns, num_iters=n_iters, 71 | proposal_iters= 5, seed=seed, num_samples=n_samples, step_size=1e-2, 72 | init_dist='normal', init_stddev=1.0, recompute_grad=True 73 | ) 74 | print('done HMC in', time.time()-t,'\n') 75 | allsamps.append(samps) 76 | 77 | t = time.time() 78 | samps = bb_mcmc( 79 | logprob, init_vals, rms_langevin_fns, 
num_iters=n_iters, 80 | seed = seed, num_samples = n_samples, step_size=5e-3, 81 | init_dist='normal', init_stddev=1.0, beta=0.99 82 | ) 83 | print('done rms in', time.time()-t,'\n') 84 | allsamps.append(samps) 85 | 86 | t = time.time() 87 | samps = bb_mcmc( 88 | logprob, init_vals, rwmh_fns, num_iters=n_iters, 89 | seed=seed, num_samples=n_samples, step_size=0.05, 90 | init_dist='normal', init_stddev=1.0, recompute_grad=True, 91 | ) 92 | print('done rwmh in', time.time()-t,'\n') 93 | allsamps.append(samps) 94 | 95 | #====== Plotting ======= 96 | 97 | lims = [-5,5] 98 | names = [ 99 | 'langevin', 100 | 'MALA', 101 | 'langevin_RK', 102 | 'HMC', 103 | 'RMS langevin', 104 | 'RWMH' 105 | ] 106 | cols = 2 107 | rows = math.ceil(len(names) / cols) 108 | idxs = itertools.product(range(rows), range(cols)) 109 | f, axes = plt.subplots(rows, cols, sharex=True, figsize=(12, 8)) 110 | for i, (name, samps, (r,c)) in enumerate(zip(names, allsamps, idxs)): 111 | print(samps.shape) 112 | sns.distplot(samps, bins=500, kde=False, ax=axes[r,c]) 113 | axb = axes[r,c].twinx() 114 | axb.scatter(samps, jnp.ones(len(samps)), alpha=0.1, marker='x', color='red') 115 | 116 | zs = jnp.linspace(*lims, num=250) 117 | axc = axes[r,c].twinx() 118 | axc.plot(zs, jnp.exp(bimodal_logprob(zs)), color='orange') 119 | 120 | axes[r,c].set_xlim(*lims) 121 | title = name 122 | axes[r,c].set_title(title) 123 | 124 | axes[r,c].set_yticks([]) 125 | axb.set_yticks([]) 126 | axc.set_yticks([]) 127 | 128 | plt.show() 129 | 130 | 131 | if __name__ == '__main__': 132 | main() -------------------------------------------------------------------------------- /examples/shallow/variational_2d.py: -------------------------------------------------------------------------------- 1 | """ Example for Black Box Variational Inference (BBVI) 2 | 3 | Adapted from https://github.com/HIPS/autograd/blob/master/examples/black_box_svi.py 4 | """ 5 | 6 | from matplotlib import pyplot as plt 7 | from tqdm import trange 8 | 9 | import jax 10 | import numpy as onp 11 | import jax.numpy as jnp 12 | from jax.experimental import optimizers 13 | 14 | import jax.scipy.stats.norm as norm 15 | import jax.scipy.stats.multivariate_normal as mvn 16 | 17 | from jax_bayes.variational import diagonal_mvn_fns, elbo_reparam 18 | 19 | @jax.jit 20 | @jax.vmap 21 | def logprob(z): 22 | x, y = z[0], z[1] 23 | y_density = norm.logpdf(y, 0, 1.35) 24 | x_density = norm.logpdf(x, 0, jnp.exp(y)) 25 | 26 | return x_density + y_density 27 | 28 | def plot_isocontours(ax, func, xlimits=[-2, 2], ylimits=[-4, 2], numticks=101): 29 | x = jnp.linspace(*xlimits, num=numticks) 30 | y = jnp.linspace(*ylimits, num=numticks) 31 | X, Y = jnp.meshgrid(x, y) 32 | zs = func(jnp.concatenate([jnp.atleast_2d(X.ravel()), 33 | jnp.atleast_2d(Y.ravel())]).T) 34 | Z = zs.reshape(X.shape) 35 | Z = jax.lax.stop_gradient(Z) 36 | ax.contour(X, Y, Z) 37 | ax.set_yticks([]) 38 | ax.set_xticks([]) 39 | 40 | def main(): 41 | key = jax.random.PRNGKey(1) 42 | vf = diagonal_mvn_fns(key, mean_stddev=0.1) 43 | 44 | x0 = jnp.zeros(2) 45 | var_params, var_keys = vf.init(x0) 46 | 47 | lr = 1e-3 48 | opt_init, opt_update, opt_get_params = optimizers.adam(lr) 49 | opt_state = opt_init(var_params) 50 | 51 | f, axes = plt.subplots(2, figsize=(7, 7)) 52 | f.subplots_adjust(bottom=0.05, top=0.95, hspace=0.3) 53 | pts, hist = [], [] 54 | def callback(i, fx, params, *args): 55 | pts.append(i) 56 | hist.append(-fx) 57 | 58 | ax = axes[0] 59 | ax.cla() 60 | ax.plot(pts, hist) 61 | 62 | ax = axes[1] 63 | ax.cla() 64 | 
ax.set_title(f"i={i}") 65 | plot_isocontours(ax, lambda z:jnp.exp(logprob(z))) 66 | 67 | mean, logcov = params 68 | cov = cov=jnp.diag(jnp.exp(logcov)) + 1e-6 69 | logq = lambda z:jnp.exp(mvn.logpdf(z, mean, cov)) 70 | logq = jax.vmap(logq) 71 | plot_isocontours(ax,logq) 72 | 73 | plt.pause(1.0/30) 74 | plt.draw() 75 | 76 | num_samples = 1000 77 | 78 | def elbo(p, keys): 79 | samples_state, _ = vf.sample(0, num_samples, keys, p) 80 | samples = vf.get_samples(samples_state) 81 | 82 | #this is using the default reparameterization trick 83 | # return jnp.mean(logprob(samples) - vf.evaluate(samples_state, p)) 84 | 85 | #this uses the fact that the vf (a diagonal MVN) has a closed-form entropy 86 | return jnp.mean(logprob(samples)) + vf.entropy(p) 87 | 88 | params = opt_get_params(opt_state) 89 | print(f'elbo before:{elbo(params, var_keys):.5f}') 90 | 91 | @jax.jit 92 | def bbvi_step(i, opt_state, var_keys): 93 | params = opt_get_params(opt_state) 94 | var_keys = vf.next_key(var_keys) #generate one key to use now 95 | next_keys = vf.next_key(var_keys) #generate one to return 96 | 97 | obj = lambda p: - elbo(p, var_keys) 98 | 99 | loss, dlambda = jax.value_and_grad(obj)(params) 100 | opt_state = opt_update(i, dlambda, opt_state) 101 | 102 | return opt_state, next_keys, loss 103 | 104 | #do the optimization loop 105 | for i in trange(5000): 106 | opt_state, var_keys, loss = bbvi_step(i, opt_state, var_keys) 107 | if i % 10 == 0: 108 | callback(i, loss, vf.get_params(opt_get_params(opt_state))) 109 | 110 | params = opt_get_params(opt_state) 111 | print(f'elbo after {elbo(params, var_keys):.5f}') 112 | plt.show() 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /examples/shallow/mcmc_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | 3 | import jax 4 | import jax.numpy as np 5 | import jax.scipy.stats.multivariate_normal as mvn 6 | 7 | import itertools, math 8 | import time 9 | from matplotlib import pyplot as plt 10 | import seaborn as sns 11 | 12 | from jax_bayes.mcmc import ( 13 | langevin_fns, 14 | mala_fns, 15 | rk_langevin_fns, 16 | hmc_fns, 17 | rms_langevin_fns, 18 | rwmh_fns 19 | ) 20 | from jax_bayes.mcmc import blackbox_mcmc as bb_mcmc 21 | 22 | def make_logprob(): 23 | 24 | mus = np.array([[0.0, 0.0], 25 | [2.5, 4.0], 26 | [4.0, 0.0]]) 27 | 28 | sigmas = np.array([ [[1.0, 0.0], 29 | [0.0, 2.0]], 30 | 31 | [[2.0, -1.0], 32 | [-1.0, 1.0]], 33 | 34 | [[1.0, 0.1], 35 | [0.1, 2.0]] ]) 36 | 37 | @jax.jit 38 | def _logprob(z): 39 | return np.log(mvn.pdf(z, mean=mus[0], cov=sigmas[0]) + \ 40 | mvn.pdf(z, mean=mus[1], cov=sigmas[1]) + \ 41 | mvn.pdf(z, mean=mus[2], cov=sigmas[2])) 42 | 43 | return _logprob 44 | 45 | 46 | def main(): 47 | #====== Setup ======= 48 | n_iters, n_samples, d = 2000, 2000, 2 49 | key = jax.random.PRNGKey(1) 50 | init_vals = np.array([2.0, 2.0]) 51 | 52 | 53 | logprob = make_logprob() 54 | allsamps = [] 55 | 56 | #====== Tests ======= 57 | 58 | t = time.time() 59 | print('running 2d tests ...') 60 | samps = bb_mcmc( 61 | logprob, init_vals, langevin_fns, num_iters=n_iters, 62 | num_samples=n_samples, seed=0, step_size=0.05, 63 | init_dist='uniform', init_stddev=5.0) 64 | print('done langevin in', time.time()-t,'\n') 65 | allsamps.append(samps) 66 | 67 | t = time.time() 68 | samps = bb_mcmc( 69 | logprob, init_vals, mala_fns, num_iters=n_iters, 70 | num_samples=n_samples, seed=0, step_size=0.05, 71 | 
init_dist='uniform', init_stddev=5.0, recompute_grad=True) 72 | print('done MALA in', time.time()-t,'\n') 73 | allsamps.append(samps) 74 | 75 | 76 | t = time.time() 77 | samps = bb_mcmc( 78 | logprob, init_vals, rk_langevin_fns, num_iters=n_iters, 79 | num_samples=n_samples, seed=0, step_size=0.05, 80 | init_dist='uniform', init_stddev=5.0, recompute_grad=True) 81 | print('done langevin_RK in', time.time()-t,'\n') 82 | allsamps.append(samps) 83 | 84 | t = time.time() 85 | samps = bb_mcmc( 86 | logprob, init_vals, hmc_fns, num_iters=n_iters//5, 87 | proposal_iters = 5, num_samples=n_samples, seed=0, step_size=0.05, 88 | init_dist='uniform', init_stddev=5.0, recompute_grad=True) 89 | print('done HMC in', time.time()-t,'\n') 90 | allsamps.append(samps) 91 | 92 | t = time.time() 93 | samps = bb_mcmc(logprob, init_vals, rms_langevin_fns, num_iters=n_iters, 94 | num_samples=n_samples, seed=0, step_size=1e-3, #1e-3 95 | beta=0.99, eps=1e-5, 96 | init_dist='uniform', init_stddev=5.0) 97 | print('done rms_langevin in', time.time()-t,'\n') 98 | allsamps.append(samps) 99 | 100 | 101 | t = time.time() 102 | samps = bb_mcmc(logprob, init_vals, rwmh_fns, num_iters=n_iters, 103 | num_samples=n_samples, seed=0, step_size=0.05, 104 | init_dist='uniform', init_stddev=5.0, recompute_grad=True) 105 | print('done RW MH in' , time.time()-t,'\n') 106 | allsamps.append(samps) 107 | 108 | 109 | #====== Plotting ======= 110 | init_vals = jax.random.uniform(key, (n_samples,d)) * 5.0 111 | 112 | pts = onp.linspace(-7, 7, 1000) 113 | X, Y = onp.meshgrid(pts, pts) 114 | pos = onp.empty(X.shape + (2,)) 115 | pos[:, :, 0] = X 116 | pos[:, :, 1] = Y 117 | Z = onp.exp(jax.vmap(logprob)(pos)) 118 | 119 | """ for the solo contour plot 120 | f, ax = plt.subplots() 121 | ax.contour(X, Y, Z,) 122 | ax.set_xlim(-2, 6) 123 | ax.set_ylim(-5, 7) 124 | plt.show() 125 | """ 126 | 127 | names = ['langevin', 'MALA', 'langevin_RK', 'HMC', 'RMS_langevin', 'RWMH'] 128 | 129 | cols = 2 130 | rows = math.ceil(len(names) / cols) 131 | idxs = itertools.product(range(rows), range(cols)) 132 | f, axes = plt.subplots(rows, cols, sharex=True, figsize=(12, 9)) 133 | 134 | for i, (name, samps, (r,c)) in enumerate(zip(names, allsamps, idxs)): 135 | print(name) 136 | row = i // 3 137 | col = i % 3 138 | ax = axes[r, c] 139 | 140 | # ax.contour(X, Y, Z, alpha=0.5, cmap='Oranges') 141 | # ax.hist2d(samps[:,0], samps[:,1], alpha=0.5, bins=25) 142 | sns.kdeplot(samps[:,0], samps[:,1], shade=True, shade_lowest=True, 143 | cmap='Blues', ax=ax) 144 | ax.set_title(name) 145 | ax.set_xlim(-2, 6) 146 | ax.set_ylim(-5, 7) 147 | ax.set_xticks([]) 148 | ax.set_yticks([]) 149 | 150 | plt.show() 151 | 152 | if __name__ == '__main__': 153 | main() 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /examples/deep/nn_regression/mlp_regression_mcmc.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | np.random.seed(0) 5 | 6 | import haiku as hk 7 | 8 | import jax.numpy as jnp 9 | from jax.experimental import optimizers 10 | import jax 11 | 12 | from tqdm import tqdm, trange 13 | from matplotlib import pyplot as plt 14 | 15 | from jax_bayes.utils import confidence_bands 16 | from jax_bayes.mcmc import ( 17 | # langevin_fns, 18 | mala_fns, 19 | # hmc_fns, 20 | ) 21 | 22 | #could use any of the samplers modulo hyperparameters 23 | # sampler_fns = hmc_fns 24 | # sampler_fns = langevin_fns 25 | sampler_fns = mala_fns 26 | 27 | def 
build_dataset(): 28 | n_train, n_test, d = 200, 100, 1 29 | xlims = [-1.0, 5.0] 30 | x_train = np.random.rand(n_train, d) * (xlims[1] - xlims[0]) + xlims[0] 31 | x_test = np.random.rand(n_test, d) * (xlims[1] - xlims[0]) + xlims[0] 32 | 33 | target_func = lambda t: (np.log(t + 100.0) * np.sin(1.0 * np.pi*t)) + 0.1 * t 34 | 35 | y_train = target_func(x_train) 36 | y_test = target_func(x_test) 37 | 38 | y_train += np.random.randn(*x_train.shape) * (1.0 * (x_train + 2.0)**0.5) 39 | 40 | return (x_train, y_train), (x_test, y_test) 41 | 42 | def net_fn(x): 43 | 44 | mlp = hk.Sequential([ 45 | hk.Linear(128, w_init=hk.initializers.Constant(0), 46 | b_init=hk.initializers.Constant(0)), 47 | jnp.tanh, 48 | hk.Linear(1, w_init=hk.initializers.Constant(0), 49 | b_init=hk.initializers.Constant(0)) 50 | ]) 51 | 52 | return mlp(x) 53 | 54 | def main(): 55 | # ======= Setup ======= 56 | xy_train, xy_test = build_dataset() 57 | (x_train, y_train), (x_test, y_test) = xy_train, xy_test 58 | 59 | # lr = 1e-3 60 | # reg = 0.1 61 | # lik_var = 0.5 62 | 63 | # lr = 1e-1 64 | lr = 1e-4 65 | reg = 0.1 66 | lik_var = 0.5 67 | 68 | net = hk.transform(net_fn) 69 | key = jax.random.PRNGKey(0) 70 | 71 | sampler_init, sampler_propose, sampler_accept, sampler_update, sampler_get_params = \ 72 | sampler_fns(key, num_samples=10, step_size=lr, init_stddev=5.0) 73 | 74 | 75 | def logprob(params, xy): 76 | """ log posterior, assuming 77 | P(params) ~ N(0,eta) 78 | P(y|x, params) ~ N(f(x;params), lik_var) 79 | """ 80 | x, y = xy 81 | 82 | preds = net.apply(params, None, x) 83 | log_prior = - reg * sum(jnp.sum(jnp.square(p)) 84 | for p in jax.tree_leaves(params)) 85 | log_lik = - jnp.mean(jnp.square(preds - y)) / lik_var 86 | return log_lik + log_prior 87 | 88 | @jax.jit 89 | def sampler_step(i, state, keys, batch): 90 | # print(state) 91 | # input() 92 | params = sampler_get_params(state) 93 | logp = lambda params:logprob(params, batch) 94 | fx, dx = jax.vmap(jax.value_and_grad(logp))(params) 95 | 96 | fx_prop, dx_prop = fx, dx 97 | # fx_prop, prop_state, dx_prop, new_keys = fx, state, dx, keys 98 | prop_state, keys = sampler_propose(i, dx, state, keys) 99 | 100 | # for RK-langevin and MALA --- recompute gradients 101 | prop_params = sampler_get_params(prop_state) 102 | fx_prop, dx_prop = jax.vmap(jax.value_and_grad(logp))(prop_params) 103 | 104 | # for HMC 105 | # prop_state, dx_prop, keys = state, dx, keys 106 | # for j in range(5): #5 iterations of the leapfrog integrator 107 | # prop_state, keys = \ 108 | # sampler_propose(i, dx_prop, prop_state, keys) 109 | 110 | # prop_params = sampler_get_params(prop_state) 111 | # fx_prop, dx_prop = jax.vmap(jax.value_and_grad(logp))(prop_params) 112 | 113 | accept_idxs, keys = sampler_accept( 114 | i, fx, fx_prop, dx, state, dx_prop, prop_state, keys 115 | ) 116 | state, keys = sampler_update( 117 | i, accept_idxs, dx, state, dx_prop, prop_state, keys 118 | ) 119 | 120 | 121 | return state, keys 122 | 123 | # ======= Sampling ====== 124 | 125 | # initialization 126 | params = net.init(jax.random.PRNGKey(42), x_train) 127 | sampler_state, sampler_keys = sampler_init(params) 128 | 129 | #do the sampling 130 | for step in trange(5000): 131 | # if step % 250 == 0: 132 | if False: 133 | sampler_params = sampler_get_params(sampler_state) 134 | logp = lambda params:logprob(params, xy_train) 135 | train_logp = jnp.mean(jax.vmap(logp)(sampler_params)) 136 | logp = lambda params:logprob(params, xy_test ) 137 | test_logp = jnp.mean(jax.vmap(logp)(sampler_params)) 138 | print(f"step = {step}" 
139 | f" | train logp = {train_logp:.3f}" 140 | f" | test logp = {test_logp:.3f}") 141 | 142 | sampler_state, sampler_keys = \ 143 | sampler_step(step, sampler_state, sampler_keys, xy_train) 144 | 145 | 146 | sampler_params = sampler_get_params(sampler_state) 147 | 148 | # ========= Plotting ======== 149 | plot_inputs = np.linspace(-1, 10, num=600).reshape(-1,1) 150 | outputs = jax.vmap(net.apply, in_axes=(0, None, None))(sampler_params, None, plot_inputs) 151 | 152 | lower, upper = confidence_bands(outputs.squeeze(-1).T) 153 | 154 | f, ax = plt.subplots(1) 155 | 156 | ax.plot(x_train.ravel(), y_train.ravel(), 'x', color='green') 157 | ax.plot(x_test.ravel(), y_test.ravel(), 'x', color='red') 158 | for i in range(outputs.shape[0]): 159 | ax.plot(plot_inputs, outputs[i], alpha=0.25) 160 | ax.plot(plot_inputs, np.mean(outputs[:, :, 0].T, axis=1), color='black', 161 | linewidth=1.0) 162 | ax.fill_between(plot_inputs.squeeze(-1), lower, upper, alpha=0.75) 163 | 164 | ax.set_ylim(-10, 15) 165 | ax.set_xticks([]) 166 | ax.set_yticks([]) 167 | 168 | plt.show() 169 | 170 | if __name__ == '__main__': 171 | main() -------------------------------------------------------------------------------- /jax_bayes/variational/variational_family.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import functools 3 | 4 | import jax 5 | from jax._src.util import partial, safe_zip, safe_map, unzip2 6 | from jax._src.tree_util import tree_flatten, tree_unflatten, register_pytree_node 7 | 8 | from ..mcmc.sampler import SamplerKeys, SamplerState 9 | 10 | map = safe_map 11 | zip = safe_zip 12 | 13 | VariationalParams = namedtuple( 14 | "VariationalParams", 15 | ["packed_state", "tree_def", "subtree_defs"] 16 | ) 17 | register_pytree_node( 18 | VariationalParams, 19 | lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)), 20 | lambda data, xs: VariationalParams(xs[0], data[0], data[1]) 21 | ) 22 | 23 | class VariationalFamily: 24 | init = None 25 | sample = None 26 | evaluate = None 27 | get_samples = None 28 | get_params = None 29 | get_params = None 30 | next_key = None 31 | entropy = None 32 | 33 | def variational_family(var_maker): 34 | """Decorator to make an optimizer defined for arrays generalize to containers. 35 | 36 | With this decorator, you can write variational_family functions that 37 | each operate only on single arrays, and convert them to corresponding 38 | functions that operate on pytrees of parameters. See the optimizers defined in 39 | optimizers.py for examples. 40 | 41 | Note: The variational families produced by this function are limited to modelling 42 | (at most) block-diagonal dependence with one block per leaf node in the pytree. 43 | This is used when we sum the log-probabilities (e.g. in tree_evaluate). 44 | 45 | Args: 46 | var_maker: a function that returns an ``(init, sample, evaluate, get_samples, 47 | next_key, entropy, init_key)`` tuple of functions that might only work 48 | with ndarrays. 49 | 50 | Returns: 51 | A ``VariationalFamily object`` that collects tree-ified versions of the 52 | above functions that work on arbitrary pytrees. 53 | 54 | The VariationalParams pytree type used by the returned functions is isomorphic 55 | to ``ParameterPytree (OptStatePytree ndarray)``, but may store the state 56 | instead as e.g. a partially-flattened data structure for performance. 
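
    For a concrete instance, see ``diagonal_mvn_fns`` in ``families.py``, which is
    written as

        @variational_family
        def diagonal_mvn_fns(base_key, mean_stddev=1.0, init_sigma=1.0, eps=1e-6):
            ...
            return init, sample, evaluate, get_samples, next_key, entropy, base_key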
57 | """ 58 | 59 | @functools.wraps(var_maker) 60 | def tree_var_maker(*args, **kwargs): 61 | init, sample, evaluate, get_samples, next_key, entropy, init_key = \ 62 | var_maker(*args, **kwargs) 63 | 64 | @functools.wraps(init) 65 | def tree_init(x0_tree): 66 | x0_flat, tree = tree_flatten(x0_tree) 67 | initial_keys = jax.random.split(init_key, len(x0_flat)) 68 | initial_params, initial_keys = unzip2(init(x0, k) for x0, k in \ 69 | zip(x0_flat, initial_keys)) 70 | params_flat, subtrees = unzip2(map(tree_flatten, initial_params)) 71 | return VariationalParams(params_flat, tree, subtrees), SamplerKeys(initial_keys) 72 | 73 | @functools.wraps(sample) 74 | def tree_sample(i, num_samples, tree_keys, var_params): 75 | params_flat, tree, subtrees = var_params 76 | params = map(tree_unflatten, subtrees, params_flat) 77 | keys, keys_meta = tree_flatten(tree_keys) 78 | samples, new_keys = unzip2(map(partial(sample, i, num_samples), keys, params)) 79 | samples_flat, subtrees2 = unzip2(map(tree_flatten, samples)) 80 | return SamplerState(samples_flat, tree, subtrees2), SamplerKeys(new_keys) 81 | 82 | @functools.wraps(evaluate) 83 | def tree_evaluate(inputs, var_params): 84 | """ this assumes each factor is independent (i.e. block-diagonal) """ 85 | params_flat, tree, subtrees = var_params 86 | params = map(tree_unflatten, subtrees, params_flat) 87 | 88 | #inputs is a also a pytree with the same structure as var_params 89 | inputs_flat, tree2, subtrees2 = inputs 90 | inputs = map(tree_unflatten, subtrees2, inputs_flat) 91 | 92 | if tree2 != tree: 93 | msg = ("evaluate update function was passed a inputs tree that did " 94 | "not match the parameter tree structure with which it was " 95 | "initialized: parameter tree {} and inputs tree {}.") 96 | raise TypeError(msg.format(tree, tree2)) 97 | 98 | logprob = sum(evaluate(x, p) for x, p in zip(inputs, params)) 99 | return logprob 100 | 101 | @functools.wraps(get_samples) 102 | def tree_get_samples(var_state): 103 | states_flat, tree, subtrees = var_state 104 | states = map(tree_unflatten, subtrees, states_flat) 105 | samples = map(get_samples, states) 106 | return tree_unflatten(tree, samples) 107 | 108 | tree_get_params = tree_get_samples 109 | 110 | @functools.wraps(next_key) 111 | def tree_next_key(tree_keys): 112 | keys, keys_meta = tree_flatten(tree_keys) 113 | new_keys = [next_key(key) for key in keys] 114 | return SamplerKeys(new_keys) 115 | 116 | @functools.wraps(entropy) 117 | def tree_entropy(var_params): 118 | params_flat, tree, subtrees = var_params 119 | params = map(tree_unflatten, subtrees, params_flat) 120 | 121 | ent = sum(entropy(p) for p in params) 122 | return ent 123 | 124 | var_family = VariationalFamily() 125 | var_family.init = tree_init 126 | var_family.sample = tree_sample 127 | var_family.evaluate = tree_evaluate 128 | var_family.get_samples = tree_get_samples 129 | var_family.get_params = tree_get_params 130 | var_family.next_key = tree_next_key 131 | var_family.entropy = tree_entropy 132 | 133 | return var_family 134 | 135 | return tree_var_maker 136 | -------------------------------------------------------------------------------- /examples/deep/nn_regression/mlp_regression_var.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.random.seed(0) 3 | 4 | import haiku as hk 5 | 6 | import jax.numpy as jnp 7 | from jax.experimental import optimizers 8 | import jax 9 | 10 | from tqdm import tqdm, trange 11 | from matplotlib import pyplot as plt 12 | 13 | from 
jax_bayes.variational import diagonal_mvn_fns 14 | from jax_bayes.utils import confidence_bands 15 | 16 | def build_dataset(): 17 | n_train, n_test, d = 100, 100, 1 18 | xlims = [-1.0, 5.0] 19 | x_train = np.random.rand(n_train, d) * (xlims[1] - xlims[0]) + xlims[0] 20 | x_test = np.random.rand(n_test, d) * (xlims[1] - xlims[0]) + xlims[0] 21 | 22 | target_func = lambda t: (np.log(t + 100.0) * np.sin(1.0 * np.pi*t)) + 0.1 * t 23 | 24 | y_train = target_func(x_train) 25 | y_test = target_func(x_test) 26 | 27 | y_train += np.random.randn(*x_train.shape) * (1.0 * (x_train + 2.0)**0.5) 28 | 29 | return (x_train, y_train), (x_test, y_test) 30 | 31 | def net_fn(x): 32 | sig = 4.0 33 | rbf = lambda x: jnp.exp(-x**2)#deep basis function model 34 | activation = rbf 35 | # activation = jnp.tanh 36 | mlp = hk.Sequential([ 37 | hk.Linear(128, w_init=hk.initializers.Constant(0), 38 | b_init=hk.initializers.Constant(0)), 39 | activation, 40 | hk.Linear(1, w_init=hk.initializers.Constant(0), 41 | b_init=hk.initializers.Constant(0)) 42 | ]) 43 | return mlp(x) 44 | 45 | def main(): 46 | # ======= Setup ======= 47 | xy_train, xy_test = build_dataset() 48 | (x_train, y_train), (x_test, y_test) = xy_train, xy_test 49 | 50 | lr = 5e-2 51 | reg = 0.1 #this is it 52 | lik_var = 0.1 53 | 54 | net = hk.transform(net_fn) 55 | params = net.init(jax.random.PRNGKey(42), x_train) 56 | 57 | seed = 0 58 | key = jax.random.PRNGKey(seed) 59 | vf = diagonal_mvn_fns(key, init_sigma = 0.1) 60 | var_params, var_keys = vf.init(params) 61 | 62 | opt_init, opt_update, opt_get_params = optimizers.adam(lr) 63 | opt_state = opt_init(var_params) 64 | 65 | @jax.jit 66 | def logprob(params, xy): 67 | """ log posterior, assuming 68 | P(params) ~ N(0,eta) 69 | P(y|x, params) ~ N(f(x;params), lik_var) 70 | """ 71 | x, y = xy 72 | preds = net.apply(params, None, x) 73 | log_prior = - reg * sum(jnp.sum(jnp.square(p)) 74 | for p in jax.tree_leaves(params)) 75 | log_lik = - jnp.mean(jnp.square(preds - y)) / lik_var 76 | return log_lik + log_prior 77 | 78 | num_samples = 50 79 | 80 | @jax.jit 81 | def bbvi_step(i, opt_state, var_keys, batch): 82 | var_params = opt_get_params(opt_state) 83 | logp = lambda p: logprob(p, batch) 84 | logp = jax.vmap(logp) 85 | 86 | var_keys = vf.next_key(var_keys) #generate one key to use now 87 | next_keys = vf.next_key(var_keys) #generate one to return 88 | 89 | def elbo(p, keys): 90 | samples_state, _ = vf.sample(0, num_samples, keys, p) 91 | samples = vf.get_samples(samples_state) 92 | 93 | # 'stick the landing' ELBO estimator see https://arxiv.org/pdf/1703.09194.pdf 94 | return jnp.mean(logp(samples) - 95 | vf.evaluate(samples_state, jax.lax.stop_gradient(p))) 96 | 97 | obj = lambda p: - elbo(p, var_keys) 98 | 99 | loss, dlambda = jax.value_and_grad(obj)(var_params) 100 | opt_state = opt_update(i, dlambda, opt_state) 101 | return opt_state, loss, next_keys 102 | 103 | # ======== Optimization ========= 104 | 105 | hist = [] 106 | for step in trange(2000): 107 | if step % 250 == 0: 108 | var_params = opt_get_params(opt_state) 109 | samples_state, _ = vf.sample(0, num_samples, var_keys, var_params) 110 | param_samples = vf.get_samples(samples_state) 111 | 112 | logp = lambda params:logprob(params, xy_train) 113 | train_logp = jnp.mean(jax.vmap(logp)(param_samples)) 114 | 115 | _elbo = jnp.mean(jax.vmap(logp)(param_samples) - 116 | vf.evaluate(samples_state, var_params)) 117 | 118 | logp = lambda params:logprob(params, xy_test) 119 | test_logp = jnp.mean(jax.vmap(logp)(param_samples)) 120 | print(f"step = {step}" 
121 | f" | train logp = {train_logp:.3f}" 122 | f" | test logp = {test_logp:.3f}" 123 | f" | train elbo = {_elbo:.3f}") 124 | 125 | opt_state, loss, var_keys = bbvi_step(step, opt_state, var_keys, xy_train) 126 | hist.append(loss) 127 | 128 | # generate the final samples 129 | var_params = opt_get_params(opt_state) 130 | samples_state, _ = vf.sample(0, 10, var_keys, var_params) 131 | param_samples = vf.get_samples(samples_state) 132 | 133 | logp = lambda params:logprob(params, xy_train) 134 | final_logp = jnp.mean(jax.vmap(logp)(param_samples)) 135 | print(f'final logp = {final_logp:.3f}') 136 | 137 | # =========== Plotting =========== 138 | plot_inputs = np.linspace(-1, 10, num=600).reshape(-1,1) 139 | outputs = jax.vmap(lambda params: net.apply(params, None, plot_inputs))(param_samples) 140 | 141 | lower, upper = confidence_bands(outputs.squeeze(-1).T) 142 | 143 | f, axes = plt.subplots(2) 144 | 145 | ax = axes[0] 146 | ax.plot(hist) 147 | 148 | ax = axes[1] 149 | ax.plot(x_train.ravel(), y_train.ravel(), 'bx', color='green') 150 | ax.plot(x_test.ravel(), y_test.ravel(), 'bx', color='red') 151 | for i in range(outputs.shape[0]): 152 | ax.plot(plot_inputs, outputs[i], alpha=0.25) 153 | ax.plot(plot_inputs, np.mean(outputs[:, :, 0].T, axis=1), color='black', linewidth=1.0) 154 | ax.fill_between(plot_inputs.squeeze(-1), lower, upper, alpha=0.75) 155 | 156 | plt.show() 157 | 158 | if __name__ == '__main__': 159 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # jax-bayes 2 | High-dimensional Bayesian inference with Python and Jax. 3 | 4 | ## Overview 5 | jax-bayes is designed to accelerate research in high-dimensional Bayesian inference, specifically for deep neural networks. It is built on [Jax](https://github.com/google/jax). 6 | 7 | ***NOTE: the `jax_bayes.mcmc` api was updated on 02/05/2022 to version 0.1.0 and is not backwards compatible with the previous version 0.0.1. The changes are minor, and they fix a significant bug. See [this PR](https://github.com/jamesvuc/jax-bayes/pull/1) for more details.*** 8 | 9 | jax-bayes supports two different methods for sampling from high-dimensional distributions: 10 | - **Markov Chain Monte Carlo** (MCMC) which iterates a Markov chain which has an invariant distribution (approximately) equal to the target distribution 11 | - **Variational Inference** (VI): which finds the closest (in some sense) distribution in a parameterized family of distributions to the target distribution. 12 | 13 | jax-bayes allows you to **"bring your own JAX-based network to the Bayesian ML party"** by providing samplers that operate on arbitrary data structures of JAX arrays and JAX transformations. 
You can also define your own sampler in terms of JAX arrays and lift it to a general-purpose sampler (using the same approach as in [``jax.experimental.optimizers``](https://jax.readthedocs.io/en/latest/_modules/jax/experimental/optimizers.html))
14 | 
15 | ### Quickstart
16 | You can easily modify this [Haiku quickstart example](https://github.com/deepmind/dm-haiku#quickstart) to support Bayesian inference:
17 | ```python
18 | # ---- From the Haiku Quickstart ----
19 | import jax.numpy as jnp
20 | import haiku as hk
21 | 
22 | def softmax_cross_entropy(logits, labels):
23 |   one_hot = jax.nn.one_hot(labels, logits.shape[-1])
24 |   return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)
25 | 
26 | def logprob_fn(batch):
27 | 
28 |   mlp = hk.Sequential([
29 |       hk.Linear(300), jax.nn.relu,
30 |       hk.Linear(100), jax.nn.relu,
31 |       hk.Linear(10),
32 |   ])
33 |   logits = mlp(batch['images'])
34 |   return - jnp.mean(softmax_cross_entropy(logits, batch['labels']))
35 | 
36 | logprob = hk.transform(logprob_fn)
37 | 
38 | # ---- With jax-bayes ----
39 | 
40 | #instantiate the sampler
41 | key = jax.random.PRNGKey(0)
42 | from jax_bayes.mcmc import langevin_fns
43 | init, propose, accept, update, get_params = langevin_fns(key, step_size=1e-3)
44 | 
45 | #define the mcmc step
46 | @jax.jit
47 | def mcmc_step(state, keys, batch):
48 |   params = get_params(state)
49 |   batch_logprob = lambda p: logprob.apply(p, None, batch)
50 | 
51 |   #use vmap + grad to compute per-sample gradients
52 |   g = jax.vmap(jax.grad(batch_logprob))(params)
53 | 
54 |   #omitting some unused arguments for this example
55 |   propose_state, new_keys = propose(g, state, keys, ...)
56 |   accept_idxs, new_keys = accept(g, state, ..., propose_state, ...) # not necessary for the langevin algorithm
57 |   next_state, new_keys = update(accept_idxs, state, propose_state, new_keys, ...)
58 | 
59 |   return next_state, new_keys
60 | 
61 | #initialize the sampler state
62 | params = logprob.init(jax.random.PRNGKey(1), next(dataset))
63 | sampler_state, keys = init(params)
64 | 
65 | #run the mcmc algorithm
66 | for i in range(1000):
67 |   sampler_state, keys = mcmc_step(sampler_state, keys, next(dataset))
68 | 
69 | # extract your samples
70 | sampled_params = get_params(sampler_state)
71 | ```
72 | 
73 | ### Logits != Uncertainty
74 | Sometimes we want our neural networks to say "I don't know" (think self-driving cars, machine translation, etc.) but, as illustrated in [this paper](http://proceedings.mlr.press/v48/gal16.pdf) or [`examples/deep/mnist`](https://github.com/jamesvuc/jax-bayes/tree/master/examples/deep/mnist), the logits of a neural network should not serve as a substitute for uncertainty. This library allows you to model *weight uncertainty* about the data by sampling from the posterior rather than optimizing it. You can also take advantage of Occam's razor and other benefits of Bayesian statistics.
75 | 
76 | ## Installation
77 | jax-bayes requires jax>=0.1.74 and jaxlib>=0.1.51 as separate dependencies, since jaxlib needs to be [installed](https://github.com/google/jax#pip-installation) for the accelerator (CPU / GPU / TPU).
78 | 
79 | Assuming you have jax + jaxlib installed, install via pip:
80 | ```
81 | pip install git+https://github.com/jamesvuc/jax-bayes
82 | ```
83 | 
84 | ## Package Description
85 | - ``jax_bayes.mcmc`` contains the MCMC functionality. It provides:
86 |     - ``jax_bayes.mcmc.sampler`` which is the decorator that "tree-ifies" a sampler's methods.
A sampler is defined as a callable returning a tuple of functions
87 | ```python
88 | def sampler(*args, **kwargs):
89 |   ...
90 |   return init, log_proposal, propose, update, get_params
91 | ```
92 | where the returned functions have specific signatures.
93 |     - A bunch of samplers:
94 |         - ``jax_bayes.mcmc.langevin_fns`` (Unadjusted Langevin Algorithm)
95 |         - ``jax_bayes.mcmc.mala_fns`` (Metropolis Adjusted Langevin Algorithm)
96 |         - ``jax_bayes.mcmc.rk_langevin_fns`` (stochastic Runge-Kutta solver for the continuous-time Langevin dynamics)
97 |         - ``jax_bayes.mcmc.hmc_fns`` (Hamiltonian Monte Carlo algorithm)
98 |         - ``jax_bayes.mcmc.rms_langevin_fns`` (preconditioned Langevin algorithm using the smoothed root-mean-square estimate of the gradient as the preconditioner matrix (like RMSProp))
99 |         - ``jax_bayes.mcmc.rwmh_fns`` (Random Walk Metropolis-Hastings algorithm)
100 |     - ``jax_bayes.mcmc.blackbox_mcmc`` wraps a given sampler into a "black-box" function suitable for sampling from simple densities (e.g. without sampling batches).
101 | - ``jax_bayes.variational`` contains the variational inference functionality. It provides:
102 |     - ``jax_bayes.variational.variational_family`` which is a decorator that tree-ifies the variational family's methods. A variational family is defined as a callable returning a tuple of functions
103 | ```python
104 | def variational_family(*args, **kwargs):
105 |   ...
106 |   return init, sample, evaluate, get_samples, next_key, entropy, init_key
107 | ```
108 | where the returned functions have specific signatures. The returned object is not, however, a tree-ified collection of functions but a class that contains these functions.
109 |     - ``jax_bayes.variational.diagonal_mvn_fns`` (diagonal multivariate Gaussian family)
110 | 
111 | ## Examples
112 | We have provided some diverse examples, some of which are under active development --- see ``examples/`` for more details. At a high level, we provide:
113 | 1. Shallow examples for sampling from regular probability distributions using MCMC and VI.
114 | 2. Deep examples for doing deep Bayesian ML (mostly with Colab)
115 |     1. Neural Network Regression
116 |     2. MNIST with 300-100-10 MLP
117 |     3. CIFAR10 with a CNN
118 |     4. 
Attention-based RNN Neural Machine Translation
119 | 
120 | *Note: If you are familiar with ML and are looking to learn JAX, these examples also include regular (non-Bayesian) ML versions that are relatively self-contained.*
121 | 
122 | mcmc | nn regression
123 | :-------------------------:|:-------------------------:
124 | ![](https://github.com/jamesvuc/jax-bayes/blob/master/assets/mcmc_2d.png "2d MCMC") | ![](https://github.com/jamesvuc/jax-bayes/blob/master/assets/nn_regression_mcmc.png "NN regression")
125 | 
126 | 
127 | 
128 | 
--------------------------------------------------------------------------------
/examples/deep/mnist/mnist_mcmc.py:
--------------------------------------------------------------------------------
1 | import haiku as hk
2 | 
3 | import jax.numpy as jnp
4 | from jax.experimental import optimizers
5 | import jax
6 | 
7 | import jax_bayes
8 | from jax_bayes.mcmc import mala_fns
9 | 
10 | import sys, os, math, time
11 | import numpy as np
12 | from tqdm import trange
13 | 
14 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
15 | import tensorflow_datasets as tfds
16 | 
17 | from matplotlib import pyplot as plt
18 | 
19 | # load the data and create the model
20 | def load_dataset(split, is_training, batch_size):
21 |     ds = tfds.load('mnist:3.*.*', split=split).cache().repeat()
22 |     if is_training:
23 |         ds = ds.shuffle(10 * batch_size, seed=0)
24 |     ds = ds.batch(batch_size)
25 |     return iter(tfds.as_numpy(ds))
26 | 
27 | def net_fn(batch):
28 |     """ Standard LeNet-300-100 MLP """
29 |     x = batch["image"].astype(jnp.float32) / 255.
30 | 
31 |     # we initialize the model with zeros since we're going to construct initial
32 |     # samples for the weights with additive Gaussian noise
33 |     sig = 0.0
34 |     mlp = hk.Sequential([
35 |         hk.Flatten(),
36 |         hk.Linear(300, w_init=hk.initializers.RandomNormal(stddev=sig),
37 |                   b_init=hk.initializers.RandomNormal(stddev=sig)),
38 |         jax.nn.relu,
39 |         hk.Linear(100, w_init=hk.initializers.RandomNormal(stddev=sig),
40 |                   b_init=hk.initializers.RandomNormal(stddev=sig)),
41 |         jax.nn.relu,
42 |         hk.Linear(10, w_init=hk.initializers.RandomNormal(stddev=sig),
43 |                   b_init=hk.initializers.RandomNormal(stddev=sig))
44 |     ])
45 | 
46 |     return mlp(x)
47 | 
48 | 
49 | def main():
50 |     # hyperparameters
51 |     lr = 5e-3
52 |     # lr = 1e-3
53 |     reg = 1e-4
54 |     num_samples = 100 # number of samples to approximate the posterior
55 |     init_stddev = 5.0 # initial samples are drawn from N(0, init_stddev^2) around the init params
56 |     train_batch_size = 1_000
57 |     eval_batch_size = 10_000
58 | 
59 |     # instantiate the model --- same as the regular case
60 |     net = hk.transform(net_fn)
61 | 
62 |     # build the sampler instead of an optimizer
63 |     sampler_fns = mala_fns
64 |     seed = 0
65 |     key = jax.random.PRNGKey(seed)
66 |     sampler_init, sampler_propose, sampler_accept, sampler_update, sampler_get_params = \
67 |         sampler_fns(key, num_samples=num_samples, step_size=lr, init_stddev=init_stddev,
68 |                     noise_scale=0.1)
69 | 
70 | 
71 |     # the loss is the same as in the regular case! This is because regular ML minimizes the
72 |     # negative log-posterior, where logP(params | data) = logP(data | params) + logP(params) + constant,
73 |     # i.e. it computes the MAP estimate.
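    # note: with the L2 penalty used below, `reg * l2_loss` plays the role of the negative
    # log-prior -- up to scaling and an additive constant it is the log-density of a
    # zero-mean Gaussian prior on the weights with precision `reg` -- so minimizing `loss`
    # is MAP estimation, and `logprob = -loss` further down is the unnormalized
    # log-posterior that the sampler draws from.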
74 |     def loss(params, batch):
75 |         logits = net.apply(params, jax.random.PRNGKey(0), batch)
76 |         labels = jax.nn.one_hot(batch['label'], 10)
77 |         l2_loss = 0.5 * sum(jnp.sum(jnp.square(p))
78 |                             for p in jax.tree_leaves(params))
79 |         softmax_crossent = - jnp.mean(labels * jax.nn.log_softmax(logits))
80 | 
81 |         return softmax_crossent + reg * l2_loss
82 | 
83 |     # the log-probability is the negative of the loss
84 |     logprob = lambda p, b: - loss(p, b)
85 | 
86 |     @jax.jit
87 |     def accuracy(params, batch):
88 |         # auto-vectorize over the samples of params! only in JAX...
89 |         pred_fn = jax.vmap(net.apply, in_axes=(0, None, None))
90 | 
91 |         # this is crucial --- integrate (i.e. average) out the parameters
92 |         all_logits = pred_fn(params, None, batch)
93 |         probs = jnp.mean(jax.nn.softmax(all_logits, axis=-1), axis=0)
94 | 
95 |         return jnp.mean(jnp.argmax(probs, axis=-1) == batch['label'])
96 | 
97 |     # build the mcmc step. This is like the optimization step, but for sampling
98 |     @jax.jit
99 |     def mcmc_step(i, state, keys, batch):
100 |         # extract parameters
101 |         params = sampler_get_params(state)
102 |         # rvals = sampler_get_params(state, idx=1)
103 | 
104 |         # form a partial eval of logprob on the data
105 |         logp = lambda p: logprob(p, batch) # can make this 1-line?
106 | 
107 |         # evaluate *per-sample* gradients
108 |         fx, dx = jax.vmap(jax.value_and_grad(logp))(params)
109 | 
110 |         # generate proposal states for the Markov chains
111 |         prop_state, new_keys = sampler_propose(i, dx, state, keys)
112 | 
113 |         # evaluate the log-prob and gradients at the proposed states (needed for the accept stage)
114 |         prop_params = sampler_get_params(prop_state)
115 |         fx_prop, dx_prop = jax.vmap(jax.value_and_grad(logp))(prop_params)
116 | 
117 |         # generate the acceptance indices from the Metropolis-Hastings
118 |         # accept-reject step
119 |         accept_idxs, keys = sampler_accept(
120 |             i, fx, fx_prop, dx, state, dx_prop, prop_state, keys
121 |         )
122 | 
123 |         # update the sampler state based on the acceptance indices
124 |         state, keys = sampler_update(
125 |             i, accept_idxs, dx, state, dx_prop, prop_state, keys
126 |         )
127 | 
128 |         return fx, state, new_keys
129 | 
130 |     # load the data into memory and create batch iterators
131 |     train_batches = load_dataset("train", is_training=True, batch_size=train_batch_size)
132 |     val_batches = load_dataset("train", is_training=False, batch_size=eval_batch_size)
133 |     test_batches = load_dataset("test", is_training=False, batch_size=eval_batch_size)
134 | 
135 | 
136 |     # get a single sample of the params using the normal net.init(...)
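    # note: net.init only consumes the batch to infer parameter shapes; the result is a
    # single pytree of weights, which the sampler will expand into `num_samples` samples.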
137 | params = net.init(jax.random.PRNGKey(42), next(train_batches)) 138 | 139 | # get a SamplerState object with `num_samples` params along dimension 0 140 | # generated by adding Gaussian noise (see sampler_fns(..., init_dist='normal')) 141 | sampler_state, sampler_keys = sampler_init(params) 142 | 143 | # iterate the the Markov chain 144 | for step in trange(2_501): 145 | train_logprobs, sampler_state, sampler_keys = \ 146 | mcmc_step(step, sampler_state, sampler_keys, next(train_batches)) 147 | 148 | if step % 500 == 0: 149 | params = sampler_get_params(sampler_state) 150 | val_acc = accuracy(params, next(val_batches)) 151 | test_acc = accuracy(params, next(test_batches)) 152 | print(f"step = {step}" 153 | f" | val acc = {val_acc:.3f}" 154 | f" | test acc = {test_acc:.3f}") 155 | 156 | def posterior_predictive(params, batch): 157 | pred_fn = lambda p:net.apply(p, None, batch) 158 | pred_fn = jax.vmap(pred_fn) 159 | 160 | logit_samples = pred_fn(params) # n_samples x batch_size x n_classes 161 | pred_samples = jnp.argmax(logit_samples, axis=-1) #n_samples x batch_size 162 | 163 | n_classes = logit_samples.shape[-1] 164 | batch_size = logit_samples.shape[1] 165 | probs = np.zeros((batch_size, n_classes)) 166 | for c in range(n_classes): 167 | idxs = pred_samples == c 168 | probs[:,c] = idxs.sum(axis=0) 169 | 170 | return probs / probs.sum(axis=1, keepdims=True) 171 | 172 | 173 | def do_analysis(): 174 | test_data = next(test_batches) 175 | pred_fn = jax.vmap(net.apply, in_axes=(0, None, None)) 176 | 177 | all_test_logits = pred_fn(params, None, test_data) 178 | probs = jnp.mean(jax.nn.softmax(all_test_logits, axis=-1), axis=0) 179 | correct_preds_mask = jnp.argmax(probs, axis=-1) == test_data['label'] 180 | 181 | # pp = posterior_predictive(params, test_data) 182 | pp = probs 183 | entropies = jax_bayes.utils.entropy(pp) 184 | 185 | correct_ent = entropies[correct_preds_mask] 186 | incorrect_ent = entropies[~correct_preds_mask] 187 | 188 | mean_correct_ent = jnp.mean(correct_ent) 189 | mean_incorrect_ent = jnp.mean(incorrect_ent) 190 | 191 | plt.hist(correct_ent, alpha=0.3, label='correct', density=True) 192 | plt.hist(incorrect_ent, alpha=0.3, label='incorrect', density=True) 193 | plt.axvline(x=mean_correct_ent, color='blue', label='mean correct') 194 | plt.axvline(x=mean_incorrect_ent, color='orange', label='mean incorrect') 195 | plt.legend() 196 | plt.xlabel("entropy") 197 | plt.ylabel("histogram density") 198 | plt.title("posterior predictive entropy of correct vs incorrect predictions") 199 | plt.show() 200 | 201 | do_analysis() 202 | plt.show() 203 | 204 | 205 | if __name__ == '__main__': 206 | main() -------------------------------------------------------------------------------- /jax_bayes/mcmc/sampler.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import functools 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax._src.util import partial, safe_zip, safe_map, unzip2 7 | from jax._src.tree_util import ( 8 | tree_flatten, 9 | tree_unflatten, 10 | register_pytree_node, 11 | ) 12 | 13 | from jax.tree_util import tree_map, tree_leaves 14 | 15 | 16 | map = safe_map 17 | zip = safe_zip 18 | 19 | def sum_tree_leaves(tree): 20 | return sum(leaf for leaf in tree_leaves(tree)) 21 | 22 | 23 | # from jax.experimental.optimizers.py: 24 | # The implementation here basically works by flattening pytrees. 
There are two
25 | # levels of pytrees to think about: the pytree of params, which we can think of
26 | # as defining an "outer pytree", and a pytree produced by applying init_fun to
27 | # each leaf of the params pytree, which we can think of as the "inner pytrees".
28 | # Since pytrees can be flattened, that structure is isomorphic to a list of
29 | # lists (with no further nesting).
30 | 
31 | SamplerState = namedtuple("SamplerState",
32 |                           ["packed_state", "tree_def", "subtree_defs"])
33 | state_flatten_fn = lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs))
34 | state_unflatten_fn = lambda data, xs: SamplerState(xs[0], data[0], data[1])
35 | register_pytree_node(SamplerState, state_flatten_fn, state_unflatten_fn)
36 | 
37 | SamplerKeys = namedtuple("SamplerKeys", ["keys"])
38 | key_flatten_fn = lambda xs: ((xs.keys,), (1,)) # make the (1,) an empty tuple
39 | key_unflatten_fn = lambda data, xs: SamplerKeys(xs[0])
40 | register_pytree_node(SamplerKeys, key_flatten_fn, key_unflatten_fn)
41 | 
42 | SamplerFns = namedtuple("SamplerFns",
43 |                         ['init', 'propose', 'accept', 'update', 'get_params']
44 | )
45 | 
46 | def check_equal_pytrees(tree1, tree2, prefix=""):
47 |     if tree1 != tree2: # compares the tree defs of the two trees
48 |         msg = (prefix +
49 |                "Passed a gradient tree that did "
50 |                "not match the parameter tree structure with which it was "
51 |                "initialized: parameter tree {} and grad tree {}.")
52 |         raise TypeError(msg.format(tree2, tree1))
53 | 
54 | 
55 | def sampler(sampler_builder):
56 |     """Decorator to make a sampler defined for arrays generalize to containers.
57 | 
58 |     With this decorator, you can write init, propose, log_proposal, update, and get_params
59 |     functions that each operate only on single arrays, and convert (or "tree-ify") them to
60 |     corresponding functions that operate on pytrees of parameters. See the samplers defined
61 |     in sampler_fns.py for examples.
62 | 
63 |     Args:
64 |         sampler_builder: a function that returns an ``(init, propose, log_proposal, update,
65 |             get_params, base_key)`` tuple of functions that might only work with ndarrays.
66 | 
67 |     Returns:
68 |         A ``SamplerFns`` namedtuple ``(init, propose, accept, update, get_params)`` of
69 |         functions that work on arbitrary pytrees.
70 |     """
71 |     @functools.wraps(sampler_builder)
72 |     def tree_sampler_builder(*args, **kwargs):
73 |         init, propose, log_proposal, update, get_params, init_key = sampler_builder(*args, **kwargs)
74 | 
75 |         @functools.wraps(init)
76 |         def tree_init(x0_tree):
77 |             x0_flat, tree = tree_flatten(x0_tree) # tree is the treedef
78 |             initial_keys = jax.random.split(init_key, len(x0_flat))
79 |             initial_states, initial_keys = unzip2(init(x0, k) for x0, k in \
80 |                                                   zip(x0_flat, initial_keys))
81 |             states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
82 |             return SamplerState(states_flat, tree, subtrees), SamplerKeys(initial_keys)
83 | 
84 |         @functools.wraps(propose)
85 |         # def tree_propose(i, grad_tree, samp_state, samp_keys):
86 |         def tree_propose(i, grad_tree, samp_state, samp_keys, **kwargs):
87 |             states_flat, tree, subtrees = samp_state
88 |             keys = samp_keys
89 |             grad_flat, tree2 = tree_flatten(grad_tree)
90 |             if tree2 != tree: # compares the tree defs of the two trees
91 |                 msg = ("sampler propose function was passed a gradient tree that did "
92 |                        "not match the parameter tree structure with which it was "
93 |                        "initialized: parameter tree {} and grad tree {}.")
94 |                 raise TypeError(msg.format(tree, tree2))
95 |             states = map(tree_unflatten, subtrees, states_flat)
96 |             keys, keys_meta = tree_flatten(keys)
97 | 
98 |             _propose = lambda g, x, k: propose(i, g, x, k, **kwargs)
99 |             new_states, new_keys = unzip2(map(_propose, grad_flat, states, keys))
100 |             new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
101 |             for subtree, subtree2 in zip(subtrees, subtrees2):
102 |                 if subtree2 != subtree:
103 |                     msg = ("sampler propose function produced an output structure that "
104 |                            "did not match its input structure: input {} and output {}.")
105 |                     raise TypeError(msg.format(subtree, subtree2))
106 | 
107 |             return SamplerState(new_states_flat, tree, subtrees), SamplerKeys(new_keys)
108 | 
109 |         def tree_accept(
110 |             i,
111 |             logp,
112 |             logp_prop,
113 |             grad_tree,
114 |             state,
115 |             prop_grad_tree,
116 |             prop_state,
117 |             samp_keys
118 |         ):
119 | 
120 |             states_flat, tree, subtrees = state
121 |             grad_flat, tree2 = tree_flatten(grad_tree)
122 | 
123 |             prop_states_flat, prop_tree, prop_subtrees = prop_state
124 |             prop_grad_flat, prop_tree2 = tree_flatten(prop_grad_tree)
125 | 
126 |             # compute logQ(xprop|x)
127 |             logq_prop = sum(
128 |                 map(partial(log_proposal, i),
129 |                     grad_flat, states_flat, prop_grad_flat, prop_states_flat
130 |                 )
131 |             )
132 |             # compute logQ(x|xprop)
133 |             logq_x = sum(
134 |                 map(partial(log_proposal, i),
135 |                     prop_grad_flat, prop_states_flat, grad_flat, states_flat
136 |                 )
137 |             )
138 | 
139 |             # compute log_alpha = log(P(xprop)/P(x) * Q(x|xprop)/Q(xprop|x))
140 |             log_alpha = logp_prop + logq_x - logp - logq_prop
141 | 
142 |             # split a single key from the tree keys for the global acceptance step
143 |             samp_keys_flat, samp_keys_meta = tree_flatten(samp_keys)
144 |             global_key, next_key = jax.random.split(samp_keys_flat[0])
145 |             samp_keys_flat[0] = next_key
146 | 
147 |             # perform the global acceptance step (one accept/reject decision per sample,
148 |             # not one per leaf of the pytree)
149 |             U = jax.random.uniform(global_key, (logp.shape[0],))
150 |             accept_idxs = jnp.log(U) < log_alpha
151 | 
152 |             return accept_idxs, SamplerKeys(samp_keys_flat)
153 | 
154 |         @functools.wraps(update)
155 |         def tree_update(
156 |             i,
157 |             accept_idxs,
158 |             grad_tree,
159 |             samp_state,
160 |             prop_grad_tree,
161 |             prop_samp_state,
162 |             samp_keys
163 |         ):
164 |             """
165 |             logp_x and logp_xprop are (N,)
arrays (i.e. arrays of scalars) 166 | of log probs to be passed to every call of update. 167 | """ 168 | keys = samp_keys 169 | 170 | states_flat, tree, subtrees = samp_state 171 | grad_flat, tree2 = tree_flatten(grad_tree) 172 | # compare state grad tree and state tree 173 | check_equal_pytrees(tree, tree2, prefix='state tree vs grad tree: ') 174 | 175 | prop_states_flat, prop_tree, prop_subtrees = prop_samp_state 176 | prop_grad_flat, prop_tree2 = tree_flatten(prop_grad_tree) 177 | # compare proposal tree and state tree 178 | check_equal_pytrees(prop_tree, tree, prefix='prop state tree vs grad tree: ') 179 | 180 | # compare proposal grad tree and state grad tree 181 | check_equal_pytrees(prop_tree2, prop_tree, prefix='prop state tree vs prop grad tree: ') 182 | 183 | states = map(tree_unflatten, subtrees, states_flat) 184 | prop_states = map(tree_unflatten, prop_subtrees, prop_states_flat) 185 | keys, keys_meta = tree_flatten(keys) 186 | new_states, new_keys= unzip2( 187 | map(partial(update, i, accept_idxs), 188 | grad_flat, states, prop_grad_flat, prop_states, keys) 189 | ) 190 | new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states)) 191 | for subtree, subtree2 in zip(subtrees, subtrees2): 192 | if subtree2 != subtree: 193 | msg = ("sampler update function produced an output structure that " 194 | "did not match its input structure: input {} and output {}.") 195 | raise TypeError(msg.format(subtree, subtree2)) 196 | 197 | return SamplerState(new_states_flat, tree, subtrees), SamplerKeys(new_keys) 198 | 199 | @functools.wraps(get_params) 200 | def tree_get_params(samp_state, **kwargs): 201 | states_flat, tree, subtrees = samp_state 202 | states = map(tree_unflatten, subtrees, states_flat) 203 | params = (get_params(s, **kwargs) for s in states) 204 | return tree_unflatten(tree, params) 205 | 206 | return SamplerFns(tree_init, tree_propose, tree_accept, tree_update, tree_get_params) 207 | return tree_sampler_builder -------------------------------------------------------------------------------- /examples/deep/cifar10/cifar10.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "cifar10.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "CLATnUvpdftH", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "# jax-bayes CIFAR10 Example --- Traditional ML Approach\n", 25 | "\n", 26 | "## Set up the environment" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "UfMSaNHlceB7", 33 | "colab_type": "code", 34 | "colab": { 35 | "base_uri": "https://localhost:8080/", 36 | "height": 1000 37 | }, 38 | "outputId": "d04f814f-8140-4a5c-f67a-d3035258bc14" 39 | }, 40 | "source": [ 41 | "#see https://github.com/google/jax#pip-installation\n", 42 | "!pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl\n", 43 | "!pip install --upgrade jax\n", 44 | "!pip install git+https://github.com/deepmind/dm-haiku\n", 45 | "!pip install git+https://github.com/jamesvuc/jax-bayes" 46 | ], 47 | "execution_count": 2, 48 | "outputs": [ 49 | { 50 | "output_type": "stream", 51 | "text": [ 52 | "Collecting jaxlib==0.1.51\n", 53 | "\u001b[?25l Downloading 
https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl (71.5MB)\n", 54 | "\u001b[K |████████████████████████████████| 71.5MB 42kB/s \n", 55 | "\u001b[?25hRequirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.51) (1.4.1)\n", 56 | "Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.51) (1.18.5)\n", 57 | "Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.51) (0.9.0)\n", 58 | "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jaxlib==0.1.51) (1.15.0)\n", 59 | "Installing collected packages: jaxlib\n", 60 | " Found existing installation: jaxlib 0.1.52\n", 61 | " Uninstalling jaxlib-0.1.52:\n", 62 | " Successfully uninstalled jaxlib-0.1.52\n", 63 | "Successfully installed jaxlib-0.1.51\n", 64 | "Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.75)\n", 65 | "Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax) (1.18.5)\n", 66 | "Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax) (0.9.0)\n", 67 | "Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax) (3.3.0)\n", 68 | "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax) (1.15.0)\n", 69 | "Collecting git+https://github.com/deepmind/dm-haiku\n", 70 | " Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-qx61eemy\n", 71 | " Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-qx61eemy\n", 72 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from dm-haiku==0.0.2) (0.9.0)\n", 73 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.6/dist-packages (from dm-haiku==0.0.2) (1.18.5)\n", 74 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py>=0.7.1->dm-haiku==0.0.2) (1.15.0)\n", 75 | "Building wheels for collected packages: dm-haiku\n", 76 | " Building wheel for dm-haiku (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 77 | " Created wheel for dm-haiku: filename=dm_haiku-0.0.2-cp36-none-any.whl size=289739 sha256=0ea4611f09ee7534f77a37f5f875814f9437bb2aa72d43f19d3b69d4892aabfb\n", 78 | " Stored in directory: /tmp/pip-ephem-wheel-cache-gsov__2x/wheels/97/0f/e9/17f34e377f8d4060fa88a7e82bee5d8afbf7972384768a5499\n", 79 | "Successfully built dm-haiku\n", 80 | "Installing collected packages: dm-haiku\n", 81 | "Successfully installed dm-haiku-0.0.2\n", 82 | "Collecting git+https://github.com/jamesvuc/jax-bayes\n", 83 | " Cloning https://github.com/jamesvuc/jax-bayes to /tmp/pip-req-build-tbzmaa7c\n", 84 | " Running command git clone -q https://github.com/jamesvuc/jax-bayes /tmp/pip-req-build-tbzmaa7c\n", 85 | "Requirement already satisfied: absl-py>=0.9.0 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (0.9.0)\n", 86 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (1.18.5)\n", 87 | "Requirement already satisfied: opt-einsum>=3.3.0 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (3.3.0)\n", 88 | "Requirement already satisfied: protobuf>=3.12.4 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (3.12.4)\n", 89 | "Collecting scipy>=1.5.2\n", 90 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/2b/a8/f4c66eb529bb252d50e83dbf2909c6502e2f857550f22571ed8556f62d95/scipy-1.5.2-cp36-cp36m-manylinux1_x86_64.whl (25.9MB)\n", 91 | "\u001b[K |████████████████████████████████| 25.9MB 117kB/s \n", 92 | "\u001b[?25hRequirement already satisfied: six>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (1.15.0)\n", 93 | "Collecting tqdm>=4.48.2\n", 94 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/28/7e/281edb5bc3274dfb894d90f4dbacfceaca381c2435ec6187a2c6f329aed7/tqdm-4.48.2-py2.py3-none-any.whl (68kB)\n", 95 | "\u001b[K |████████████████████████████████| 71kB 8.3MB/s \n", 96 | "\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.12.4->jax-bayes==0.0.1) (49.2.0)\n", 97 | "Building wheels for collected packages: jax-bayes\n", 98 | " Building wheel for jax-bayes (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 99 | " Created wheel for jax-bayes: filename=jax_bayes-0.0.1-cp36-none-any.whl size=1009734 sha256=ce9211265ff46056ed79baedb149a7f3d5420fb2b4234ecbcfc5d73695119b9f\n", 100 | " Stored in directory: /tmp/pip-ephem-wheel-cache-38e51wxr/wheels/31/65/d6/bcf4b5e84c6f6f176e73850145875e806569759c23081b4446\n", 101 | "Successfully built jax-bayes\n", 102 | "\u001b[31mERROR: tensorflow 2.3.0 has requirement scipy==1.4.1, but you'll have scipy 1.5.2 which is incompatible.\u001b[0m\n", 103 | "\u001b[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.\u001b[0m\n", 104 | "Installing collected packages: scipy, tqdm, jax-bayes\n", 105 | " Found existing installation: scipy 1.4.1\n", 106 | " Uninstalling scipy-1.4.1:\n", 107 | " Successfully uninstalled scipy-1.4.1\n", 108 | " Found existing installation: tqdm 4.41.1\n", 109 | " Uninstalling tqdm-4.41.1:\n", 110 | " Successfully uninstalled tqdm-4.41.1\n", 111 | "Successfully installed jax-bayes-0.0.1 scipy-1.5.2 tqdm-4.48.2\n" 112 | ], 113 | "name": "stdout" 114 | } 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "metadata": { 120 | "id": "l70DSI0ajQJq", 121 | "colab_type": "code", 122 | "colab": {} 123 | }, 124 | "source": [ 125 | "import haiku as hk\n", 126 | "\n", 127 | "import jax.numpy as jnp\n", 128 | "from jax.experimental import optimizers\n", 129 | "import jax\n", 130 | "\n", 131 | "import sys, os, math, time\n", 132 | "import numpy as np\n", 133 | "\n", 134 | "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' \n", 135 | "import tensorflow_datasets as tfds" 136 | ], 137 | "execution_count": 3, 138 | "outputs": [] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": { 143 | "id": "B686kNdCzFEP", 144 | "colab_type": "text" 145 | }, 146 | "source": [ 147 | "## Build the dataset loader and CNN" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "metadata": { 153 | "id": "AzgaUa2owIqg", 154 | "colab_type": "code", 155 | "colab": {} 156 | }, 157 | "source": [ 158 | "def load_dataset(split, is_training, batch_size, repeat=True, seed=0):\n", 159 | " if repeat:\n", 160 | " ds = tfds.load('cifar10', split=split).cache().repeat()\n", 161 | " else:\n", 162 | " ds = tfds.load('cifar10', split=split).cache()\n", 163 | " if is_training:\n", 164 | " ds = ds.shuffle(10 * batch_size, seed=seed)\n", 165 | " ds = ds.batch(batch_size)\n", 166 | " return tfds.as_numpy(ds)\n", 167 | "\n", 168 | "# build a 32-32-64-32 CNN with max-pooling \n", 169 | "# followed by a 128-10-n_classes MLP\n", 170 | "class Net(hk.Module):\n", 171 | " def __init__(self, dropout=0.1, n_classes=10):\n", 172 | " super(Net, self).__init__()\n", 173 | " self.conv_stage = hk.Sequential([\n", 174 | " #block 1\n", 175 | " hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME'), \n", 176 | " jax.nn.relu, \n", 177 | " hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),\n", 178 | " # block 2\n", 179 | " hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME'), \n", 180 | " jax.nn.relu, \n", 181 | " hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),\n", 182 | " # block 3\n", 183 | " hk.Conv2D(64, kernel_shape=3, stride=1, padding='SAME'), \n", 184 | " jax.nn.relu, \n", 185 | " hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),\n", 186 | " # block 4\n", 187 | " hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME')\n", 188 | " ])\n", 189 | "\n", 190 | " self.mlp_stage = hk.Sequential([\n", 191 | " hk.Flatten(),\n", 192 | " 
hk.Linear(128), \n", 193 | " jax.nn.relu, \n", 194 | " hk.Linear(n_classes)\n", 195 | " ])\n", 196 | "\n", 197 | " self.p_dropout = dropout\n", 198 | "\n", 199 | " def __call__(self, x, use_dropout=True):\n", 200 | " x = self.conv_stage(x)\n", 201 | " \n", 202 | " dropout_rate = self.p_dropout if use_dropout else 0.0\n", 203 | " x = hk.dropout(hk.next_rng_key(), dropout_rate, x)\n", 204 | "\n", 205 | " return self.mlp_stage(x)\n", 206 | "\n", 207 | "# standard normalization constants\n", 208 | "mean_norm = jnp.array([[0.4914, 0.4822, 0.4465]])\n", 209 | "std_norm = jnp.array([[0.247, 0.243, 0.261]])\n", 210 | "\n", 211 | "#define the net-function \n", 212 | "def net_fn(batch, use_dropout):\n", 213 | " net = Net(dropout=0.0)\n", 214 | " x = batch['image']/255.0\n", 215 | " x = (x - mean_norm) / std_norm\n", 216 | " return net(x, use_dropout)" 217 | ], 218 | "execution_count": 4, 219 | "outputs": [] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "metadata": { 224 | "id": "eo-Gypdbo0wY", 225 | "colab_type": "code", 226 | "colab": {} 227 | }, 228 | "source": [ 229 | "# hyperparameters\n", 230 | "lr = 1e-3\n", 231 | "reg = 1e-4\n", 232 | "\n", 233 | "# instantiate the network\n", 234 | "net = hk.transform(net_fn)\n", 235 | "\n", 236 | "# build the optimizer\n", 237 | "opt_init, opt_update, opt_get_params = optimizers.rmsprop(lr)\n", 238 | "\n", 239 | "# standard L2-regularized crossentropy loss function\n", 240 | "def loss(params, rng, batch):\n", 241 | " logits = net.apply(params, rng, batch, use_dropout=True)\n", 242 | " labels = jax.nn.one_hot(batch['label'], 10)\n", 243 | "\n", 244 | " l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) \n", 245 | " for p in jax.tree_leaves(params))\n", 246 | " softmax_crossent = - jnp.mean(labels * jax.nn.log_softmax(logits))\n", 247 | "\n", 248 | " return softmax_crossent + reg * l2_loss\n", 249 | "\n", 250 | "@jax.jit\n", 251 | "def accuracy(params, batch):\n", 252 | " preds = net.apply(params, jax.random.PRNGKey(101), batch, use_dropout=False)\n", 253 | " return jnp.mean(jnp.argmax(preds, axis=-1) == batch['label'])\n", 254 | "\n", 255 | "@jax.jit\n", 256 | "def train_step(i, opt_state, rng, batch):\n", 257 | "\tparams = opt_get_params(opt_state)\n", 258 | "\tfx, dx = jax.value_and_grad(loss)(params, rng, batch)\n", 259 | "\topt_state = opt_update(i, dx, opt_state)\n", 260 | "\treturn fx, opt_state" 261 | ], 262 | "execution_count": 5, 263 | "outputs": [] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": { 268 | "id": "rfv4Mldkdt40", 269 | "colab_type": "text" 270 | }, 271 | "source": [ 272 | "## Load the Initialization, Val and Test Batches & Do the Optimization" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "metadata": { 278 | "id": "NYeEembgpgSk", 279 | "colab_type": "code", 280 | "colab": {} 281 | }, 282 | "source": [ 283 | "init_batches = load_dataset(\"train\", is_training=True, batch_size=256)\n", 284 | "val_batches = load_dataset(\"train\", is_training=False, batch_size=1_000)\n", 285 | "test_batches = load_dataset(\"test\", is_training=False, batch_size=1_000)" 286 | ], 287 | "execution_count": null, 288 | "outputs": [] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "metadata": { 293 | "id": "-MDzIu4uxmeD", 294 | "colab_type": "code", 295 | "colab": { 296 | "base_uri": "https://localhost:8080/", 297 | "height": 397 298 | }, 299 | "outputId": "240f3443-5529-4e12-9ad8-7af870738e3d" 300 | }, 301 | "source": [ 302 | "%%time\n", 303 | "\n", 304 | "# intialize the paramaeters\n", 305 | "params = 
net.init(jax.random.PRNGKey(42), next(init_batches), use_dropout=True)\n", 306 | "opt_state = opt_init(params)\n", 307 | "\n", 308 | "# initialize a key for the dropout\n", 309 | "rng = jax.random.PRNGKey(2)\n", 310 | "\n", 311 | "for epoch in range(100):\n", 312 | "\t #generate a shuffled epoch of training data\n", 313 | " train_batches = load_dataset(\"train\", is_training=True,\n", 314 | " batch_size=256, repeat=False, seed=epoch)\n", 315 | " \n", 316 | " for batch in train_batches:\n", 317 | " # run an optimization step\n", 318 | " train_loss, opt_state = train_step(epoch, opt_state, rng, batch)\n", 319 | " \n", 320 | " # make more rng for the dropout\n", 321 | " rng, _ = jax.random.split(rng)\n", 322 | "\t\n", 323 | " if epoch % 5 == 0:\n", 324 | " params = opt_get_params(opt_state)\n", 325 | " val_acc = accuracy(params, next(val_batches))\n", 326 | " test_acc = accuracy(params, next(test_batches))\n", 327 | " print(f\"epoch = {epoch}\"\n", 328 | " f\" | train loss = {train_loss:.4f}\"\n", 329 | " f\" | val acc = {val_acc:.3f}\"\n", 330 | " f\" | test acc = {test_acc:.3f}\")" 331 | ], 332 | "execution_count": 7, 333 | "outputs": [ 334 | { 335 | "output_type": "stream", 336 | "text": [ 337 | "epoch = 0 | train loss = 0.1405 | val acc = 0.489 | test acc = 0.515\n", 338 | "epoch = 5 | train loss = 0.0659 | val acc = 0.788 | test acc = 0.688\n", 339 | "epoch = 10 | train loss = 0.0596 | val acc = 0.818 | test acc = 0.669\n", 340 | "epoch = 15 | train loss = 0.0554 | val acc = 0.896 | test acc = 0.702\n", 341 | "epoch = 20 | train loss = 0.0598 | val acc = 0.880 | test acc = 0.646\n", 342 | "epoch = 25 | train loss = 0.0547 | val acc = 0.939 | test acc = 0.709\n", 343 | "epoch = 30 | train loss = 0.0504 | val acc = 0.966 | test acc = 0.714\n", 344 | "epoch = 35 | train loss = 0.0502 | val acc = 0.953 | test acc = 0.705\n", 345 | "epoch = 40 | train loss = 0.0637 | val acc = 0.954 | test acc = 0.723\n", 346 | "epoch = 45 | train loss = 0.0494 | val acc = 0.957 | test acc = 0.718\n", 347 | "epoch = 50 | train loss = 0.0472 | val acc = 0.952 | test acc = 0.731\n", 348 | "epoch = 55 | train loss = 0.0458 | val acc = 0.972 | test acc = 0.717\n", 349 | "epoch = 60 | train loss = 0.0503 | val acc = 0.952 | test acc = 0.730\n", 350 | "epoch = 65 | train loss = 0.0490 | val acc = 0.962 | test acc = 0.705\n", 351 | "epoch = 70 | train loss = 0.0554 | val acc = 0.959 | test acc = 0.695\n", 352 | "epoch = 75 | train loss = 0.0488 | val acc = 0.973 | test acc = 0.716\n", 353 | "epoch = 80 | train loss = 0.0479 | val acc = 0.976 | test acc = 0.726\n", 354 | "epoch = 85 | train loss = 0.0499 | val acc = 0.963 | test acc = 0.728\n", 355 | "epoch = 90 | train loss = 0.0565 | val acc = 0.947 | test acc = 0.722\n", 356 | "epoch = 95 | train loss = 0.0491 | val acc = 0.963 | test acc = 0.725\n", 357 | "CPU times: user 30min 24s, sys: 11min 53s, total: 42min 17s\n", 358 | "Wall time: 23min 15s\n" 359 | ], 360 | "name": "stdout" 361 | } 362 | ] 363 | } 364 | ] 365 | } -------------------------------------------------------------------------------- /jax_bayes/mcmc/sampler_fns.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax.example_libraries.optimizers import make_schedule 6 | 7 | from .sampler import sampler 8 | from .utils import init_distributions 9 | 10 | centered_uniform = \ 11 | lambda *args, **kwargs: jax.random.uniform(*args, **kwargs) - 0.5 12 | init_distributions = 
dict(normal=jax.random.normal, 13 | uniform=centered_uniform) 14 | 15 | def match_dims(src, target, start_dim=1): 16 | """ 17 | returns an array with the data from 'src' and same number of dims 18 | as 'target' by padding with empty dimensions starting at 'start_dim' 19 | """ 20 | new_dims = tuple(range(start_dim, len(target.shape))) 21 | return jnp.expand_dims(src, new_dims) 22 | 23 | @sampler 24 | def langevin_fns( 25 | base_key, 26 | num_samples=10, 27 | step_size=1e-3, 28 | noise_scale=1.0, 29 | init_stddev=0.1, 30 | init_dist='normal' 31 | ): 32 | """Constructs sampler functions for the Unadjusted Langevin Algorithm. 33 | See e.g. https://arxiv.org/pdf/1605.01559.pdf. 34 | 35 | Args: 36 | base_key: jax.random.PRNGKey 37 | num_samples: number of samples to initialize; either > 0 or -1. 38 | If num_samples == -1, assumes that the initial samples are 39 | already constructed. 40 | step_size: float or callable w/ signature step_size(i) 41 | init_stdev: nonnegative float standard deviation for initialization 42 | of initial samples (ignored if num_samples == -1) 43 | init_dist: str in ['normal', 'centered_uniform'] to sample perturbations 44 | for the initial distribution 45 | 46 | Returns: 47 | sampler function tuple (init, propose, update, get_params) 48 | """ 49 | 50 | step_size = make_schedule(step_size) 51 | noise_scale = make_schedule(noise_scale) 52 | 53 | if isinstance(init_dist, str): 54 | init_dist = init_distributions[init_dist] 55 | 56 | def log_proposal(*args): 57 | return jnp.zeros(num_samples) 58 | 59 | def init(x0, key): 60 | init_key, next_key = jax.random.split(key) 61 | if num_samples == -1: 62 | return x0, next_key 63 | x = init_dist(init_key, (num_samples,) + x0.shape) 64 | return x0 + init_stddev * x, next_key 65 | 66 | def propose(i, g, x, key, **kwargs): 67 | key, next_key = jax.random.split(key) 68 | Z = jax.random.normal(key, x.shape) 69 | return x + step_size(i) * g + \ 70 | jnp.sqrt(2 * step_size(i) * noise_scale(i)) * Z, next_key 71 | 72 | def update(i, accept_idxs, g, x, gprop, xprop, key): 73 | key, next_key = jax.random.split(key) 74 | return xprop, next_key 75 | 76 | def get_params(x): 77 | return x 78 | 79 | return init, propose, log_proposal, update, get_params, base_key 80 | 81 | @sampler 82 | def mala_fns( 83 | base_key, 84 | num_samples=10, 85 | step_size=1e-3, 86 | init_stddev=0.1, 87 | noise_scale=1.0, 88 | init_dist='normal' 89 | ): 90 | """Constructs sampler functions for the Metropolis Adjusted Langevin Algorithm. 91 | See e.g. http://probability.ca/jeff/ftpdir/lang.pdf 92 | 93 | Args: 94 | base_key: jax.random.PRNGKey 95 | num_samples: number of samples to initialize; either > 0 or -1. 96 | If num_samples == -1, assumes that the initial samples are 97 | already constructed. 
98 |         step_size: float or callable w/ signature step_size(i)
99 |         init_stddev: nonnegative float standard deviation for initialization
100 |             of initial samples (ignored if num_samples == -1)
101 |         init_dist: str in ['normal', 'centered_uniform'] to sample perturbations
102 |             for the initial distribution
103 | 
104 |     Returns:
105 |         sampler function tuple (init, propose, update, get_params)
106 |     """
107 |     step_size = make_schedule(step_size)
108 |     noise_scale = make_schedule(noise_scale)
109 | 
110 |     if isinstance(init_dist, str):
111 |         init_dist = init_distributions[init_dist]
112 | 
113 |     def log_proposal(i, g, x, gprop, xprop): # grads come first
114 |         # computes log q(xprop|x); the proposal is Gaussian with mean x + step_size * g and variance 2 * step_size * noise_scale^2
115 |         x, = x
116 |         xprop, = xprop
117 |         return - 0.5 * jnp.sum(jnp.square(xprop - x - step_size(i) * g) \
118 |             / (2 * step_size(i) * noise_scale(i)**2))
119 |     log_proposal = jax.vmap(log_proposal, in_axes=(None, 0, 0, 0, 0))
120 | 
121 |     def init(x0, key):
122 |         init_key, next_key = jax.random.split(key)
123 |         if num_samples == -1:
124 |             return x0, next_key
125 |         x = init_dist(init_key, (num_samples,) + x0.shape)
126 |         return x0 + init_stddev * x, next_key
127 | 
128 |     def propose(i, g, x, key, **kwargs):
129 |         key, next_key = jax.random.split(key)
130 |         Z = jax.random.normal(key, x.shape)
131 |         return x + step_size(i) * g + jnp.sqrt(2 * step_size(i)) * noise_scale(i) * Z, next_key
132 | 
133 |     def update(i, accept_idxs, g, x, gprop, xprop, key):
134 |         """ if the state had additional data, you would need to accept it too """
135 |         accept_idxs = match_dims(accept_idxs, x)
136 |         mask = accept_idxs.astype(jnp.float32)
137 | 
138 |         xnext = x * (1.0 - mask) + xprop * mask
139 |         return xnext, key
140 | 
141 |     def get_params(x):
142 |         return x
143 | 
144 |     return init, propose, log_proposal, update, get_params, base_key
145 | 
146 | @sampler
147 | def rk_langevin_fns(
148 |     base_key,
149 |     num_samples=10,
150 |     step_size=1e-3,
151 |     init_stddev=0.1,
152 |     init_dist='normal'
153 | ):
154 |     """Constructs sampler functions for a stochastic Runge-Kutta integrator of the
155 |     continuous-time Langevin dynamics.
156 | 
157 |     See e.g. https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_method_(SDE).
158 | 
159 |     One step of the integration is computed as 0.5*(K1 + K2) where K1, K2
160 |     are the two 'knots' of the integrator. K1 is computed in propose(...) and
161 |     K2 is computed in update(...) since we need to re-compute gradients.
162 | 
163 |     Args:
164 |         base_key: jax.random.PRNGKey
165 |         num_samples: number of samples to initialize; either > 0 or -1.
166 |             If num_samples == -1, assumes that the initial samples are
167 |             already constructed.
168 | step_size: float or callable w/ signature step_size(i) 169 | init_stdev: nonnegative float standard deviation for initialization 170 | of initial samples (ignored if num_samples == -1) 171 | init_dist: str in ['normal', 'centered_uniform'] to sample perturbations 172 | for the initial distribution 173 | 174 | Returns: 175 | sampler function tuple (init, propose, update, get_params) 176 | """ 177 | step_size = make_schedule(step_size) 178 | if isinstance(init_dist, str): 179 | init_dist = init_distributions[init_dist] 180 | 181 | def log_proposal(*args): 182 | return jnp.zeros(num_samples) 183 | 184 | def init(x0, key): 185 | init_key, next_key = jax.random.split(key) 186 | if num_samples == -1: 187 | return x0, next_key 188 | x = init_dist(init_key, (num_samples,) + x0.shape) 189 | return x0 + init_stddev * x, next_key 190 | 191 | def propose(i, g, x, key, **kwargs): 192 | h = step_size(i) 193 | root_h = math.sqrt(h) 194 | 195 | w_key, s_key, next_key = jax.random.split(key, 3) 196 | W = jax.random.normal(w_key, x.shape) * root_h 197 | S = jax.random.bernoulli(s_key, 0.5, (x.shape[0],)) * 2 - 1 198 | S = match_dims(S, x) 199 | 200 | K1 = x + h * g + (W - root_h * S) * math.sqrt(2) 201 | return K1, key 202 | 203 | def update(i, accept_idxs, g, x, gprop, xprop, key): 204 | h = step_size(i) 205 | root_h = math.sqrt(h) 206 | 207 | w_key, s_key, next_key = jax.random.split(key, 3) 208 | W = jax.random.normal(w_key, x.shape) * root_h 209 | S = jax.random.bernoulli(s_key, 0.5, (x.shape[0],)) * 2 - 1 210 | S = match_dims(S, x) 211 | 212 | K2 = h * gprop + (W + root_h * S) * math.sqrt(2) 213 | return 0.5 * x + 0.5 * (xprop + K2), next_key 214 | 215 | def get_params(x): 216 | return x 217 | 218 | return init, propose, log_proposal, update, get_params, base_key 219 | 220 | @sampler 221 | def hmc_fns( 222 | base_key, 223 | num_samples=10, 224 | step_size=1e-3, 225 | noise_scale=1.0, 226 | init_stddev=0.1, 227 | init_dist='normal' 228 | ): 229 | """Constructs sampler functions for the Hamiltonia Monte Carlo algorithm. 230 | See e.g. http://probability.ca/jeff/ftpdir/lang.pdf 231 | 232 | Args: 233 | base_key: jax.random.PRNGKey 234 | num_samples: number of samples to initialize; either > 0 or -1. 235 | If num_samples == -1, assumes that the initial samples are 236 | already constructed. 
237 | step_size: float or callable w/ signature step_size(i) 238 | init_stdev: nonnegative float standard deviation for initialization 239 | of initial samples (ignored if num_samples == -1) 240 | init_dist: str in ['normal', 'centered_uniform'] to sample perturbations 241 | for the initial distribution 242 | 243 | Returns: 244 | sampler function tuple (init, propose, update, get_params) 245 | """ 246 | step_size = make_schedule(step_size) 247 | noise_scale = make_schedule(noise_scale) 248 | if isinstance(init_dist, str): 249 | init_dist = init_distributions[init_dist] 250 | 251 | def dot_product(x, y): 252 | return jnp.sum(x * y) 253 | 254 | def log_proposal(i, g, x, gprop, xprop): #grads come first 255 | #computes log q(xprop|x) 256 | xprop, rprop = xprop 257 | return 0.5 * noise_scale(i) * dot_product(rprop, rprop) 258 | log_proposal = jax.vmap(log_proposal, in_axes=(None, 0, 0, 0, 0)) 259 | 260 | def init(x0, key): 261 | init_key, r_key, next_key = jax.random.split(key, 3) 262 | if num_samples == -1: 263 | r = jax.random.normal(r_key, x0.shape) 264 | return (x0, r*noise_scale(0)), next_key 265 | x = init_dist(init_key, (num_samples,) + x0.shape) 266 | r = jax.random.normal(r_key, (num_samples,) + x0.shape) 267 | return (x0 + init_stddev * x, r * noise_scale(0)), next_key 268 | 269 | def propose(i, g, x, key, is_final=False): 270 | """ 271 | iterate this several times for multistep leapfrog integrator. 272 | is_final is used for the final update of leapfrog integrator, 273 | which only modifies rprop and not xprop. 274 | """ 275 | next_key = key 276 | x, r = x 277 | rprop = r + 0.5 * step_size(i) * g 278 | if is_final: 279 | xprop = x 280 | else: 281 | xprop = x + step_size(i) * rprop 282 | 283 | return (xprop, rprop), next_key 284 | 285 | def update(i, accept_idxs, g, x, gprop, xprop, key): 286 | u_key, r_key, next_key = jax.random.split(key, 3) 287 | 288 | x, r = x 289 | xprop, rprop = xprop 290 | 291 | accept_idxs = match_dims(accept_idxs, x) 292 | mask = accept_idxs.astype(jnp.float32) 293 | xnext = x * (1.0 - mask) + xprop * mask 294 | 295 | #this is for the first step of the leapfrog integrator 296 | rnext = jax.random.normal(r_key, x.shape) * noise_scale(i) 297 | return (xnext, rnext), next_key 298 | 299 | def get_params(x): 300 | return x[0] 301 | 302 | return init, propose, log_proposal, update, get_params, base_key 303 | 304 | @sampler 305 | def rms_langevin_fns( 306 | base_key, 307 | num_samples=10, 308 | step_size=1e-3, 309 | noise_scale=1.0, 310 | beta=0.9, 311 | eps=1e-9, 312 | init_stddev=0.1, 313 | init_dist='normal' 314 | ): 315 | """Constructs sampler functions for the RMS-preconditioned Langevin algorithm. 316 | See e.g. https://arxiv.org/pdf/1512.07666.pdf 317 | 318 | Args: 319 | base_key: jax.random.PRNGKey 320 | num_samples: number of samples to initialize; either > 0 or -1. 321 | If num_samples == -1, assumes that the initial samples are 322 | already constructed. 
323 | step_size: float or callable w/ signature step_size(i) 324 | init_stdev: nonnegative float standard deviation for initialization 325 | of initial samples (ignored if num_samples == -1) 326 | init_dist: str in ['normal', 'centered_uniform'] to sample perturbations 327 | for the initial distribution 328 | 329 | Returns: 330 | sampler function tuple (init, propose, update, get_params) 331 | """ 332 | step_size = make_schedule(step_size) 333 | noise_scale = make_schedule(noise_scale) 334 | 335 | if isinstance(init_dist, str): 336 | init_dist = init_distributions[init_dist] 337 | 338 | def log_proposal(*args): #grads come first 339 | return jnp.zeros(num_samples) 340 | 341 | def init(x0, key): 342 | init_key, next_key = jax.random.split(key) 343 | if num_samples == -1: 344 | r = jnp.zeros_like(x0) 345 | return (x0, r), next_key 346 | x = init_dist(init_key, (num_samples,) + x0.shape) 347 | r = jnp.zeros_like(x) 348 | return (x0 + init_stddev * x, r), next_key 349 | 350 | def propose(i, g, x, key, **kwargs): 351 | key, next_key = jax.random.split(key) 352 | x, r = x 353 | Z = jax.random.normal(key, x.shape) 354 | 355 | r = beta * r + (1. - beta) * jnp.square(g) 356 | 357 | ss = step_size(i) / (jnp.sqrt(r) + eps) 358 | xprop = x + ss * g + jnp.sqrt(2 * ss) * noise_scale(i) * Z 359 | 360 | return (xprop, r), next_key 361 | 362 | def update(i, accept_idxs, g, x, gprop, xprop, key): 363 | key, next_key = jax.random.split(key) 364 | return xprop, next_key 365 | 366 | def get_params(x): 367 | return x[0] 368 | 369 | return init, propose, log_proposal, update, get_params, base_key 370 | 371 | @sampler 372 | def rms_mala_fns( 373 | base_key, 374 | num_samples=10, 375 | step_size=1e-3, 376 | noise_scale=1.0, 377 | beta=0.9, 378 | eps=1e-9, 379 | init_stddev=0.1, 380 | init_dist='normal' 381 | ): 382 | """Constructs sampler functions for the Metropolis Adjusted Langevin Algorithm. 383 | See e.g. http://probability.ca/jeff/ftpdir/lang.pdf 384 | 385 | Args: 386 | base_key: jax.random.PRNGKey 387 | num_samples: number of samples to initialize; either > 0 or -1. 388 | If num_samples == -1, assumes that the initial samples are 389 | already constructed. 
390 | step_size: float or callable w/ signature step_size(i) 391 | init_stdev: nonnegative float standard deviation for initialization 392 | of initial samples (ignored if num_samples == -1) 393 | init_dist: str in ['normal', 'centered_uniform'] to sample perturbations 394 | for the initial distribution 395 | 396 | Returns: 397 | sampler function tuple (init, propose, update, get_params) 398 | """ 399 | step_size = make_schedule(step_size) 400 | noise_scale = make_schedule(noise_scale) 401 | 402 | if isinstance(init_dist, str): 403 | init_dist = init_distributions[init_dist] 404 | 405 | def log_proposal(i, g, x, gprop, xprop): #grads come first 406 | #computes log q(xprop|x) 407 | x,r = x 408 | xprop,rprop = xprop 409 | ss = step_size(i) / (jnp.sqrt(r) + eps) 410 | return - 0.5 * jnp.sum(jnp.square(xprop - x - ss * g) \ 411 | / 2 * ss * noise_scale(i)**2) 412 | log_proposal = jax.vmap(log_proposal, in_axes=(None, 0, 0, 0, 0)) 413 | 414 | def init(x0, key): 415 | init_key, next_key = jax.random.split(key) 416 | if num_samples == -1: 417 | r = jnp.zeros_like(x0) 418 | return (x0, r), next_key 419 | x = init_dist(init_key, (num_samples,) + x0.shape) 420 | r = jnp.zeros_like(x) 421 | return (x0 + init_stddev * x, r), next_key 422 | 423 | def propose(i, g, x, key, **kwargs): 424 | key, next_key = jax.random.split(key) 425 | x, r = x 426 | Z = jax.random.normal(key, x.shape) 427 | 428 | r = beta * r + (1. - beta) * jnp.square(g) 429 | 430 | ss = step_size(i) / (jnp.sqrt(r) + eps) 431 | xprop = x + ss * g + jnp.sqrt(2 * ss) * noise_scale(i) * Z 432 | 433 | return (xprop, r), next_key 434 | 435 | def update(i, accept_idxs, g, x, gprop, xprop, key): 436 | """ if the state had additional data, you would need to accept them too""" 437 | x, r = x 438 | xprop, rprop = xprop 439 | 440 | accept_idxs = match_dims(accept_idxs, x) 441 | mask = accept_idxs.astype(jnp.float32) 442 | 443 | xnext = x * (1.0 - mask) + xprop * mask 444 | rnext = r * (1.0 - mask) + rprop * mask 445 | return (xnext, rnext), key 446 | 447 | def get_params(x): 448 | return x[0] 449 | 450 | return init, propose, log_proposal, update, get_params, base_key 451 | 452 | @sampler 453 | def rwmh_fns( 454 | base_key, 455 | num_samples=10, 456 | step_size=1e-3, 457 | init_stddev=0.1, 458 | init_dist='normal' 459 | ): 460 | """Constructs sampler functions for Random Walk Metropolis Hastings. 461 | See e.g. https://arxiv.org/pdf/1504.01896.pdf 462 | 463 | Args: 464 | base_key: jax.random.PRNGKey 465 | num_samples: number of samples to initialize; either > 0 or -1. 466 | If num_samples == -1, assumes that the initial samples are 467 | already constructed. 
468 | step_size: float or callable w/ signature step_size(i) 469 | init_stdev: nonnegative float standard deviation for initialization 470 | of initial samples (ignored if num_samples == -1) 471 | init_dist: str in ['normal', 'centered_uniform'] to sample perturbations 472 | for the initial distribution 473 | 474 | Returns: 475 | sampler function tuple (init, propose, update, get_params) 476 | """ 477 | step_size = make_schedule(step_size) 478 | if isinstance(init_dist, str): 479 | init_dist = init_distributions[init_dist] 480 | 481 | def log_proposal(*args): 482 | # in this case, the proposal is symmetric, so this is correct 483 | return jnp.zeros(num_samples) 484 | 485 | def init(x0, key): 486 | init_key, next_key = jax.random.split(key) 487 | if num_samples == -1: 488 | return x0, next_key 489 | x = init_dist(init_key, (num_samples,) + x0.shape) 490 | return x0 + init_stddev * x, next_key 491 | 492 | def propose(i, g, x, key, **kwargs): 493 | key, next_key = jax.random.split(key) 494 | Z = jax.random.normal(key, x.shape) 495 | return x + step_size(i) * Z, next_key 496 | 497 | def update(i, accept_idxs, g, x, gprop, xprop, key): 498 | key, next_key = jax.random.split(key) 499 | 500 | accept_idxs = match_dims(accept_idxs, x) 501 | mask = accept_idxs.astype(jnp.float32) 502 | xnext = x * (1.0 - mask) + xprop * mask 503 | return xnext, next_key 504 | 505 | def get_params(x): 506 | return x 507 | 508 | return init, propose, log_proposal, update, get_params, base_key -------------------------------------------------------------------------------- /examples/deep/cifar10/cifar10_mcmc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "cifar10_mcmc.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "FuajMtOoy3xC", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "# jax-bayes CIFAR10 Example --- Bayesian MCMC Approach\n", 25 | "\n", 26 | "## Set Up the Environment" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "UfMSaNHlceB7", 33 | "colab_type": "code", 34 | "colab": { 35 | "base_uri": "https://localhost:8080/", 36 | "height": 795 37 | }, 38 | "outputId": "ad97595e-f01d-490d-c9e1-51061a76cdf5" 39 | }, 40 | "source": [ 41 | "#see https://github.com/google/jax#pip-installation\n", 42 | "!pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl\n", 43 | "!pip install --upgrade jax\n", 44 | "!pip install git+https://github.com/deepmind/dm-haiku\n", 45 | "!pip install git+https://github.com/jamesvuc/jax-bayes" 46 | ], 47 | "execution_count": 3, 48 | "outputs": [ 49 | { 50 | "output_type": "stream", 51 | "text": [ 52 | "Collecting jaxlib==0.1.51\n", 53 | " Using cached https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl\n", 54 | "Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.51) (1.5.2)\n", 55 | "Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.51) (0.9.0)\n", 56 | "Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from 
jaxlib==0.1.51) (1.18.5)\n", 57 | "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jaxlib==0.1.51) (1.15.0)\n", 58 | "Installing collected packages: jaxlib\n", 59 | " Found existing installation: jaxlib 0.1.51\n", 60 | " Uninstalling jaxlib-0.1.51:\n", 61 | " Successfully uninstalled jaxlib-0.1.51\n", 62 | "Successfully installed jaxlib-0.1.51\n", 63 | "Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.75)\n", 64 | "Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax) (1.18.5)\n", 65 | "Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax) (0.9.0)\n", 66 | "Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax) (3.3.0)\n", 67 | "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax) (1.15.0)\n", 68 | "Collecting git+https://github.com/deepmind/dm-haiku\n", 69 | " Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-yf2rs_hb\n", 70 | " Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-yf2rs_hb\n", 71 | "Requirement already satisfied (use --upgrade to upgrade): dm-haiku==0.0.2 from git+https://github.com/deepmind/dm-haiku in /usr/local/lib/python3.6/dist-packages\n", 72 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from dm-haiku==0.0.2) (0.9.0)\n", 73 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.6/dist-packages (from dm-haiku==0.0.2) (1.18.5)\n", 74 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py>=0.7.1->dm-haiku==0.0.2) (1.15.0)\n", 75 | "Building wheels for collected packages: dm-haiku\n", 76 | " Building wheel for dm-haiku (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 77 | " Created wheel for dm-haiku: filename=dm_haiku-0.0.2-cp36-none-any.whl size=289739 sha256=3b8458b694f0318292ff7f1ef1f8a08f8166e3141772ec2afc50f2464a55d1b0\n", 78 | " Stored in directory: /tmp/pip-ephem-wheel-cache-nu9w8nn8/wheels/97/0f/e9/17f34e377f8d4060fa88a7e82bee5d8afbf7972384768a5499\n", 79 | "Successfully built dm-haiku\n", 80 | "Collecting git+https://github.com/jamesvuc/jax-bayes\n", 81 | " Cloning https://github.com/jamesvuc/jax-bayes to /tmp/pip-req-build-2qkv8e8a\n", 82 | " Running command git clone -q https://github.com/jamesvuc/jax-bayes /tmp/pip-req-build-2qkv8e8a\n", 83 | "Requirement already satisfied (use --upgrade to upgrade): jax-bayes==0.0.1 from git+https://github.com/jamesvuc/jax-bayes in /usr/local/lib/python3.6/dist-packages\n", 84 | "Requirement already satisfied: absl-py>=0.9.0 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (0.9.0)\n", 85 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (1.18.5)\n", 86 | "Requirement already satisfied: opt-einsum>=3.3.0 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (3.3.0)\n", 87 | "Requirement already satisfied: protobuf>=3.12.4 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (3.12.4)\n", 88 | "Requirement already satisfied: scipy>=1.5.2 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (1.5.2)\n", 89 | "Requirement already satisfied: six>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (1.15.0)\n", 90 | "Requirement already satisfied: tqdm>=4.48.2 in /usr/local/lib/python3.6/dist-packages (from jax-bayes==0.0.1) (4.48.2)\n", 91 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.12.4->jax-bayes==0.0.1) (49.2.0)\n", 92 | "Building wheels for collected packages: jax-bayes\n", 93 | " Building wheel for jax-bayes (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 94 | " Created wheel for jax-bayes: filename=jax_bayes-0.0.1-cp36-none-any.whl size=1009734 sha256=eddebc139a0a210a3d1cbe62944d13d03eb84b8574fa2ade56d50e429f9cb824\n", 95 | " Stored in directory: /tmp/pip-ephem-wheel-cache-5kvj05ns/wheels/31/65/d6/bcf4b5e84c6f6f176e73850145875e806569759c23081b4446\n", 96 | "Successfully built jax-bayes\n" 97 | ], 98 | "name": "stdout" 99 | } 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "metadata": { 105 | "id": "l70DSI0ajQJq", 106 | "colab_type": "code", 107 | "colab": {} 108 | }, 109 | "source": [ 110 | "import haiku as hk\n", 111 | "\n", 112 | "import jax.numpy as jnp\n", 113 | "from jax.experimental import optimizers\n", 114 | "import jax\n", 115 | "\n", 116 | "import jax_bayes\n", 117 | "\n", 118 | "import sys, os, math, time\n", 119 | "import numpy as np\n", 120 | "\n", 121 | "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' \n", 122 | "import tensorflow_datasets as tfds" 123 | ], 124 | "execution_count": 4, 125 | "outputs": [] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": { 130 | "id": "B686kNdCzFEP", 131 | "colab_type": "text" 132 | }, 133 | "source": [ 134 | "## Build the dataset loader and CNN" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "metadata": { 140 | "id": "AzgaUa2owIqg", 141 | "colab_type": "code", 142 | "colab": {} 143 | }, 144 | "source": [ 145 | "def load_dataset(split, is_training, batch_size, repeat=True, seed=0):\n", 146 | " if repeat:\n", 147 | " ds = tfds.load('cifar10', split=split).cache().repeat()\n", 148 | " else:\n", 149 | " ds = tfds.load('cifar10', split=split).cache()\n", 150 | " if is_training:\n", 151 | " ds = ds.shuffle(10 * batch_size, seed=seed)\n", 152 | " ds = ds.batch(batch_size)\n", 153 | " return tfds.as_numpy(ds)\n", 154 | "\n", 155 | "# build a 32-32-64-32 CNN with max-pooling \n", 156 | "# followed by a 128-10-n_classes MLP\n", 157 | "class Net(hk.Module):\n", 158 | " def __init__(self, dropout=0.1, n_classes=10):\n", 159 | " super(Net, self).__init__()\n", 160 | " self.conv_stage = hk.Sequential([\n", 161 | " #block 1\n", 162 | " hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME'), \n", 163 | " jax.nn.relu, \n", 164 | " hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),\n", 165 | " # block 2\n", 166 | " hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME'), \n", 167 | " jax.nn.relu, \n", 168 | " hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),\n", 169 | " # block 3\n", 170 | " hk.Conv2D(64, kernel_shape=3, stride=1, padding='SAME'), \n", 171 | " jax.nn.relu, \n", 172 | " hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),\n", 173 | " # block 4\n", 174 | " hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME')\n", 175 | " ])\n", 176 | "\n", 177 | " self.mlp_stage = hk.Sequential([\n", 178 | " hk.Flatten(),\n", 179 | " hk.Linear(128), \n", 180 | " jax.nn.relu, \n", 181 | " hk.Linear(n_classes)\n", 182 | " ])\n", 183 | "\n", 184 | " self.p_dropout = dropout\n", 185 | "\n", 186 | " def __call__(self, x, use_dropout=True):\n", 187 | " x = self.conv_stage(x)\n", 188 | " \n", 189 | " dropout_rate = self.p_dropout if use_dropout else 0.0\n", 190 | " x = hk.dropout(hk.next_rng_key(), dropout_rate, x)\n", 191 | "\n", 192 | " return self.mlp_stage(x)\n", 193 | "\n", 194 | "# standard normalization constants\n", 195 | "mean_norm = jnp.array([[0.4914, 0.4822, 0.4465]])\n", 196 | "std_norm = jnp.array([[0.247, 0.243, 0.261]])\n", 197 | "\n", 198 | "#define the net-function \n", 199 | 
"def net_fn(batch, use_dropout):\n", 200 | " net = Net(dropout=0.0)\n", 201 | " x = batch['image']/255.0\n", 202 | " x = (x - mean_norm) / std_norm\n", 203 | " return net(x, use_dropout)" 204 | ], 205 | "execution_count": 5, 206 | "outputs": [] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": { 211 | "id": "0jAvwWBY6D9Q", 212 | "colab_type": "text" 213 | }, 214 | "source": [ 215 | "## Build the Loss, Metrics, and MCMC step" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "metadata": { 221 | "id": "NZva_QuKwO_2", 222 | "colab_type": "code", 223 | "colab": {} 224 | }, 225 | "source": [ 226 | "# hyperparameters\n", 227 | "# lr = 1e-2\n", 228 | "lr_initial = 1e-2\n", 229 | "lr_final = 1e-3\n", 230 | "decay_start = 100\n", 231 | "decay_steps = 100\n", 232 | "decay_schedule = jax.experimental.optimizers.polynomial_decay(lr_initial, decay_steps, lr_final, power=1.0)\n", 233 | "lr = lambda t: jax.lax.cond(t < decay_start,\n", 234 | " lambda s: lr_initial,\n", 235 | " lambda s: decay_schedule(s - decay_start),\n", 236 | " t)\n", 237 | "\n", 238 | "\n", 239 | "reg = 1e-4\n", 240 | "num_samples = 5\n", 241 | "#for this example, we're going to use the jax initializers to sample the initial \n", 242 | "# distribution, so we will use init_stddev = 0.0\n", 243 | "init_stddev = 0.0 \n", 244 | "\n", 245 | "# instantiate the network\n", 246 | "net = hk.transform(net_fn)\n", 247 | "\n", 248 | "# build the sampler\n", 249 | "key = jax.random.PRNGKey(0)\n", 250 | "sampler_init, sampler_propose, sampler_update, sampler_get_params = \\\n", 251 | " jax_bayes.mcmc.rms_langevin_fns(key, num_samples=-1, step_size=lr, \n", 252 | " init_stddev=init_stddev)\n", 253 | "\n", 254 | "# standard regularized crossentropy loss function, which is the \n", 255 | "# negative unnormalized log-posterior \n", 256 | "def loss(params, rng, batch):\n", 257 | " logits = net.apply(params, rng, batch, use_dropout=True)\n", 258 | " labels = jax.nn.one_hot(batch['label'], 10)\n", 259 | "\n", 260 | " l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) \n", 261 | " for p in jax.tree_leaves(params))\n", 262 | " softmax_crossent = - jnp.mean(labels * jax.nn.log_softmax(logits))\n", 263 | "\n", 264 | " return softmax_crossent + reg * l2_loss\n", 265 | "\n", 266 | "logprob = lambda p,k,b : - loss(p, k, b)\n", 267 | "\n", 268 | "@jax.jit\n", 269 | "def accuracy(params, batch):\n", 270 | " pred_fn = lambda p:net.apply(p, jax.random.PRNGKey(101), batch, use_dropout=False)\n", 271 | " pred_fn = jax.vmap(pred_fn)\n", 272 | " preds = jnp.mean(pred_fn(params), axis=0)\n", 273 | " return jnp.mean(jnp.argmax(preds, axis=-1) == batch['label'])\n", 274 | "\n", 275 | "# the data loss will help us monitor the Markov chain's progress without worrying\n", 276 | "# about the effects of regularization.\n", 277 | "def data_loss(params, batch):\n", 278 | " logits = net.apply(params, jax.random.PRNGKey(0), batch, use_dropout=False)\n", 279 | " labels = jax.nn.one_hot(batch['label'], 10)\n", 280 | " softmax_crossent = - jnp.mean(labels * jax.nn.log_softmax(logits))\n", 281 | " return softmax_crossent\n", 282 | "data_loss = jax.vmap(data_loss, in_axes=(0, None))\n", 283 | "\n", 284 | "@jax.jit\n", 285 | "def mcmc_step(i, sampler_state, sampler_keys, rng, batch):\n", 286 | " params = sampler_get_params(sampler_state)\n", 287 | " logp = lambda p,k: logprob(p, k, batch)\n", 288 | " fx, dx = jax.vmap(jax.value_and_grad(logp))(params, rng)\n", 289 | "\n", 290 | " sampler_prop_state, new_keys = sampler_propose(i, dx, sampler_state, \n", 291 | " 
sampler_keys)\n", 292 | "\n", 293 | " fx_prop, dx_prop = fx, dx\n", 294 | "\n", 295 | " sampler_state, new_keys = sampler_update(i, fx, fx_prop, \n", 296 | " dx, sampler_state, \n", 297 | " dx_prop, sampler_prop_state, \n", 298 | " new_keys)\n", 299 | " \n", 300 | " return jnp.mean(fx), sampler_state, new_keys" 301 | ], 302 | "execution_count": 13, 303 | "outputs": [] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": { 308 | "id": "GDH0DnMy7Zc_", 309 | "colab_type": "text" 310 | }, 311 | "source": [ 312 | "## Load Batch iterators & Do the MCMC inference" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "metadata": { 318 | "id": "xTAPjsUi8lJF", 319 | "colab_type": "code", 320 | "colab": {} 321 | }, 322 | "source": [ 323 | "init_batches = load_dataset(\"train\", is_training=True, batch_size=512)\n", 324 | "val_batches = load_dataset(\"train\", is_training=False, batch_size=2_000)\n", 325 | "test_batches = load_dataset(\"test\", is_training=False, batch_size=2_000)" 326 | ], 327 | "execution_count": 10, 328 | "outputs": [] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "metadata": { 333 | "id": "BJPiQfaRGWAc", 334 | "colab_type": "code", 335 | "colab": { 336 | "base_uri": "https://localhost:8080/", 337 | "height": 881 338 | }, 339 | "outputId": "0d7df7bc-9807-4928-ea80-c707897b92eb" 340 | }, 341 | "source": [ 342 | "#Use the vmap-over-keys trick to sample a highly anisotropic initial distribution\n", 343 | "init_batch = next(init_batches)\n", 344 | "keys = jax.random.split(jax.random.PRNGKey(1), num_samples)\n", 345 | "init_param_samples = jax.vmap(lambda k:net.init(k, init_batch, use_dropout=True))(keys)\n", 346 | "sampler_state, sampler_keys = sampler_init(init_param_samples)\n", 347 | "\n", 348 | "# generate RNGs for the dropout\n", 349 | "rngs = jax.random.split(jax.random.PRNGKey(2), num_samples)\n", 350 | "\n", 351 | "for epoch in range(250):\n", 352 | " #generate a shuffled epoch of training data\n", 353 | " train_batches = load_dataset(\"train\", is_training=True,\n", 354 | " batch_size=128, repeat=False, seed=epoch)\n", 355 | " \n", 356 | " start = time.time()\n", 357 | " for batch in train_batches:\n", 358 | " # run an MCMC step\n", 359 | " train_logprob, sampler_state, sampler_keys = \\\n", 360 | " mcmc_step(epoch, sampler_state, sampler_keys, rngs, batch)\n", 361 | " \n", 362 | " # make more rngs for the dropout\n", 363 | " rngs = jax.random.split(rngs[0], num_samples)\n", 364 | " epoch_time = time.time() - start\n", 365 | "\n", 366 | " if epoch % 5 == 0:\n", 367 | " # compute val and test accuracy, and the sampler-average data loss\n", 368 | " params = sampler_get_params(sampler_state)\n", 369 | " val_acc = accuracy(params, next(val_batches))\n", 370 | " test_acc = accuracy(params, next(test_batches))\n", 371 | " _data_loss = jnp.mean(data_loss(params, next(val_batches)))\n", 372 | " print(f\"epoch = {epoch}\"\n", 373 | " f\" | time per epoch {epoch_time:.3f}\"\n", 374 | " f\" | data loss = {_data_loss:.3e}\"\n", 375 | " f\" | val acc = {val_acc:.3f}\"\n", 376 | " f\" | test acc = {test_acc:.3f}\")" 377 | ], 378 | "execution_count": 14, 379 | "outputs": [ 380 | { 381 | "output_type": "stream", 382 | "text": [ 383 | "epoch = 0 | time per epoch 43.130 | data loss = 1.126e+17 | val acc = 0.189 | test acc = 0.193\n", 384 | "epoch = 5 | time per epoch 35.306 | data loss = 1.354e+15 | val acc = 0.333 | test acc = 0.346\n", 385 | "epoch = 10 | time per epoch 35.097 | data loss = 3.986e+14 | val acc = 0.351 | test acc = 0.350\n", 386 | "epoch = 15 | time per 
epoch 34.981 | data loss = 1.929e+14 | val acc = 0.388 | test acc = 0.368\n", 387 | "epoch = 20 | time per epoch 34.972 | data loss = 1.142e+14 | val acc = 0.392 | test acc = 0.399\n", 388 | "epoch = 25 | time per epoch 34.980 | data loss = 7.258e+13 | val acc = 0.412 | test acc = 0.413\n", 389 | "epoch = 30 | time per epoch 34.928 | data loss = 5.112e+13 | val acc = 0.438 | test acc = 0.399\n", 390 | "epoch = 35 | time per epoch 34.901 | data loss = 3.738e+13 | val acc = 0.442 | test acc = 0.416\n", 391 | "epoch = 40 | time per epoch 34.915 | data loss = 3.082e+13 | val acc = 0.455 | test acc = 0.404\n", 392 | "epoch = 45 | time per epoch 34.886 | data loss = 2.601e+13 | val acc = 0.458 | test acc = 0.432\n", 393 | "epoch = 50 | time per epoch 34.899 | data loss = 1.859e+13 | val acc = 0.483 | test acc = 0.446\n", 394 | "epoch = 55 | time per epoch 34.876 | data loss = 1.997e+13 | val acc = 0.508 | test acc = 0.429\n", 395 | "epoch = 60 | time per epoch 34.876 | data loss = 1.398e+13 | val acc = 0.479 | test acc = 0.422\n", 396 | "epoch = 65 | time per epoch 34.840 | data loss = 1.359e+13 | val acc = 0.470 | test acc = 0.416\n", 397 | "epoch = 70 | time per epoch 34.826 | data loss = 1.226e+13 | val acc = 0.465 | test acc = 0.417\n", 398 | "epoch = 75 | time per epoch 34.958 | data loss = 9.665e+12 | val acc = 0.503 | test acc = 0.447\n", 399 | "epoch = 80 | time per epoch 34.830 | data loss = 1.243e+13 | val acc = 0.495 | test acc = 0.441\n", 400 | "epoch = 85 | time per epoch 34.898 | data loss = 6.606e+12 | val acc = 0.483 | test acc = 0.426\n", 401 | "epoch = 90 | time per epoch 34.837 | data loss = 6.185e+12 | val acc = 0.475 | test acc = 0.400\n", 402 | "epoch = 95 | time per epoch 34.866 | data loss = 4.987e+12 | val acc = 0.524 | test acc = 0.454\n", 403 | "epoch = 100 | time per epoch 34.846 | data loss = 6.212e+12 | val acc = 0.478 | test acc = 0.428\n", 404 | "epoch = 105 | time per epoch 34.804 | data loss = 4.210e+12 | val acc = 0.507 | test acc = 0.445\n", 405 | "epoch = 110 | time per epoch 34.832 | data loss = 3.411e+12 | val acc = 0.501 | test acc = 0.437\n", 406 | "epoch = 115 | time per epoch 34.879 | data loss = 4.890e+12 | val acc = 0.446 | test acc = 0.396\n", 407 | "epoch = 120 | time per epoch 34.803 | data loss = 3.994e+12 | val acc = 0.522 | test acc = 0.451\n", 408 | "epoch = 125 | time per epoch 34.815 | data loss = 4.236e+12 | val acc = 0.532 | test acc = 0.472\n", 409 | "epoch = 130 | time per epoch 34.806 | data loss = 3.272e+12 | val acc = 0.514 | test acc = 0.445\n", 410 | "epoch = 135 | time per epoch 34.820 | data loss = 4.463e+12 | val acc = 0.534 | test acc = 0.444\n", 411 | "epoch = 140 | time per epoch 34.740 | data loss = 2.526e+12 | val acc = 0.548 | test acc = 0.456\n", 412 | "epoch = 145 | time per epoch 34.780 | data loss = 1.986e+12 | val acc = 0.532 | test acc = 0.456\n", 413 | "epoch = 150 | time per epoch 34.808 | data loss = 2.967e+12 | val acc = 0.568 | test acc = 0.473\n", 414 | "epoch = 155 | time per epoch 34.785 | data loss = 2.285e+12 | val acc = 0.550 | test acc = 0.432\n", 415 | "epoch = 160 | time per epoch 34.759 | data loss = 3.253e+12 | val acc = 0.538 | test acc = 0.439\n", 416 | "epoch = 165 | time per epoch 34.752 | data loss = 1.835e+12 | val acc = 0.573 | test acc = 0.462\n", 417 | "epoch = 170 | time per epoch 34.803 | data loss = 2.220e+12 | val acc = 0.574 | test acc = 0.465\n", 418 | "epoch = 175 | time per epoch 34.823 | data loss = 1.506e+12 | val acc = 0.576 | test acc = 0.476\n", 419 | "epoch = 180 | time per epoch 
34.855 | data loss = 2.284e+12 | val acc = 0.599 | test acc = 0.448\n", 420 | "epoch = 185 | time per epoch 34.823 | data loss = 2.918e+12 | val acc = 0.573 | test acc = 0.464\n", 421 | "epoch = 190 | time per epoch 34.790 | data loss = 2.949e+12 | val acc = 0.551 | test acc = 0.454\n", 422 | "epoch = 195 | time per epoch 34.758 | data loss = 2.196e+12 | val acc = 0.575 | test acc = 0.457\n", 423 | "epoch = 200 | time per epoch 34.800 | data loss = 2.651e+12 | val acc = 0.576 | test acc = 0.469\n", 424 | "epoch = 205 | time per epoch 34.862 | data loss = 3.374e+12 | val acc = 0.562 | test acc = 0.443\n", 425 | "epoch = 210 | time per epoch 34.877 | data loss = 3.457e+12 | val acc = 0.586 | test acc = 0.459\n", 426 | "epoch = 215 | time per epoch 34.841 | data loss = 2.016e+12 | val acc = 0.603 | test acc = 0.463\n", 427 | "epoch = 220 | time per epoch 34.778 | data loss = 1.775e+12 | val acc = 0.589 | test acc = 0.456\n", 428 | "epoch = 225 | time per epoch 34.809 | data loss = 1.632e+12 | val acc = 0.607 | test acc = 0.477\n", 429 | "epoch = 230 | time per epoch 34.948 | data loss = 1.829e+12 | val acc = 0.575 | test acc = 0.448\n", 430 | "epoch = 235 | time per epoch 34.810 | data loss = 1.301e+12 | val acc = 0.571 | test acc = 0.457\n", 431 | "epoch = 240 | time per epoch 34.725 | data loss = 1.250e+12 | val acc = 0.586 | test acc = 0.448\n", 432 | "epoch = 245 | time per epoch 34.781 | data loss = 1.900e+12 | val acc = 0.592 | test acc = 0.451\n" 433 | ], 434 | "name": "stdout" 435 | } 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": { 441 | "id": "DVmiAoypRMOP", 442 | "colab_type": "text" 443 | }, 444 | "source": [ 445 | "**Note**: This example highlights how Bayesian ML and regular ML are very different. \n", 446 | "\n", 447 | "- We know a lot less about efficient inference than we do optimization.\n", 448 | "- Accuracy of around 45% (vs 70% for the optimization approach) is only a bit worse than current SoTA algorithms for this architecture (see e.g. [This paper](https://arxiv.org/pdf/1709.01180.pdf)). 
More hyperparameter tuning could probably close this gap.\n", 449 | "- In fact many MCMC papers do not evaluate on CIFAR10 (preferring to use MNIST, where we can easily achieve >96%)\n", 450 | "- There are several factors that contribute to MCMC's increased difficulty:\n", 451 | " - stochastic gradients\n", 452 | " - dependence on hyperparameters\n", 453 | " - regularization techniques\n", 454 | " - probabilistic algorithms are generally more subtle" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "metadata": { 460 | "id": "83Sp_OeSLBf_", 461 | "colab_type": "code", 462 | "colab": { 463 | "base_uri": "https://localhost:8080/", 464 | "height": 35 465 | }, 466 | "outputId": "e4c03c7a-fa70-4fec-c631-d8f8d6672aec" 467 | }, 468 | "source": [ 469 | "def posterior_predictive(params, batch):\n", 470 | " \"\"\"computes the posterior_predictive P(class = c | inputs, params) using a histogram\n", 471 | " \"\"\"\n", 472 | " pred_fn = lambda p:net.apply(p, jax.random.PRNGKey(0), batch, use_dropout=False) \n", 473 | " pred_fn = jax.vmap(pred_fn)\n", 474 | "\n", 475 | " logit_samples = pred_fn(params) # n_samples x batch_size x n_classes\n", 476 | " pred_samples = jnp.argmax(logit_samples, axis=-1) #n_samples x batch_size\n", 477 | "\n", 478 | " n_classes = logit_samples.shape[-1]\n", 479 | " batch_size = logit_samples.shape[1]\n", 480 | " probs = np.zeros((batch_size, n_classes))\n", 481 | " for c in range(n_classes):\n", 482 | " idxs = pred_samples == c\n", 483 | " probs[:,c] = idxs.sum(axis=0)\n", 484 | "\n", 485 | " return probs / probs.sum(axis=1, keepdims=True)\n", 486 | "\n", 487 | "params = sampler_get_params(sampler_state)\n", 488 | "print('Final predictive entropy', jnp.mean(jax_bayes.utils.entropy(posterior_predictive(params, next(test_batches)))))" 489 | ], 490 | "execution_count": 20, 491 | "outputs": [ 492 | { 493 | "output_type": "stream", 494 | "text": [ 495 | "Final predictive entropy 1.3844115\n" 496 | ], 497 | "name": "stdout" 498 | } 499 | ] 500 | } 501 | ] 502 | } -------------------------------------------------------------------------------- /examples/deep/nmt/attention_nmt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "attention_nmt.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "BxYOQbY85DHu", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "# NMT Example --- Traditional ML Approach\n", 25 | "\n", 26 | "Adapted from https://www.tensorflow.org/tutorials/text/nmt_with_attention\n", 27 | "\n", 28 | "## Set Up Environment" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "metadata": { 34 | "id": "WyAXZf8VLu5T", 35 | "colab_type": "code", 36 | "colab": { 37 | "base_uri": "https://localhost:8080/", 38 | "height": 328 39 | }, 40 | "outputId": "97938d44-5621-4ff9-8de8-587d8576c97f" 41 | }, 42 | "source": [ 43 | "!pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl\n", 44 | "!pip install --upgrade jax\n", 45 | "!pip install git+https://github.com/deepmind/dm-haiku\n", 46 | "\n", 47 | "import tensorflow as tf\n", 48 | "\n", 49 | "import matplotlib.pyplot as plt\n", 50 | "import matplotlib.ticker as ticker\n", 51 | "from sklearn.model_selection 
import train_test_split\n", 52 | "\n", 53 | "import jax\n", 54 | "import jax.numpy as jnp\n", 55 | "from jax.experimental import optimizers\n", 56 | "\n", 57 | "import haiku as hk\n", 58 | "\n", 59 | "import unicodedata\n", 60 | "import re\n", 61 | "import numpy as np\n", 62 | "import os\n", 63 | "import io\n", 64 | "import time" 65 | ], 66 | "execution_count": 2, 67 | "outputs": [ 68 | { 69 | "output_type": "stream", 70 | "text": [ 71 | "Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.75)\n", 72 | "Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax) (1.18.5)\n", 73 | "Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax) (0.9.0)\n", 74 | "Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax) (3.3.0)\n", 75 | "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax) (1.15.0)\n", 76 | "Collecting git+https://github.com/deepmind/dm-haiku\n", 77 | " Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-n8f8jxjj\n", 78 | " Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-n8f8jxjj\n", 79 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from dm-haiku==0.0.2) (0.9.0)\n", 80 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.6/dist-packages (from dm-haiku==0.0.2) (1.18.5)\n", 81 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py>=0.7.1->dm-haiku==0.0.2) (1.15.0)\n", 82 | "Building wheels for collected packages: dm-haiku\n", 83 | " Building wheel for dm-haiku (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 84 | " Created wheel for dm-haiku: filename=dm_haiku-0.0.2-cp36-none-any.whl size=289739 sha256=31a1f3bf7c0bc62f063c1630283257dac52b38679e60d2ef754b5cf2192cf32c\n", 85 | " Stored in directory: /tmp/pip-ephem-wheel-cache-0la00c1v/wheels/97/0f/e9/17f34e377f8d4060fa88a7e82bee5d8afbf7972384768a5499\n", 86 | "Successfully built dm-haiku\n", 87 | "Installing collected packages: dm-haiku\n", 88 | "Successfully installed dm-haiku-0.0.2\n" 89 | ], 90 | "name": "stdout" 91 | } 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": { 97 | "id": "kgl1qQ_CNMVV", 98 | "colab_type": "text" 99 | }, 100 | "source": [ 101 | "## Dataset Processing & NLP-specific Functions" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "metadata": { 107 | "id": "EQ2Q5nr0MCt2", 108 | "colab_type": "code", 109 | "colab": { 110 | "base_uri": "https://localhost:8080/", 111 | "height": 52 112 | }, 113 | "outputId": "0f187b18-b5a9-4dad-f7b9-7c3e997aea86" 114 | }, 115 | "source": [ 116 | "# Download the file\n", 117 | "path_to_zip = tf.keras.utils.get_file(\n", 118 | " 'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',\n", 119 | " extract=True)\n", 120 | "\n", 121 | "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" 122 | ], 123 | "execution_count": 3, 124 | "outputs": [ 125 | { 126 | "output_type": "stream", 127 | "text": [ 128 | "Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\n", 129 | "2646016/2638744 [==============================] - 0s 0us/step\n" 130 | ], 131 | "name": "stdout" 132 | } 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "metadata": { 138 | "id": "u6u6RLNsMfCP", 139 | "colab_type": "code", 140 | "colab": {} 141 | }, 142 | "source": [ 143 | "# ========= DATA PROCESSING =============\n", 144 | "# Converts the unicode file to ascii\n", 145 | "def unicode_to_ascii(s):\n", 146 | " return ''.join(c for c in unicodedata.normalize('NFD', s)\n", 147 | " if unicodedata.category(c) != 'Mn')\n", 148 | "\n", 149 | "\n", 150 | "def preprocess_sentence(w):\n", 151 | " w = unicode_to_ascii(w.lower().strip())\n", 152 | "\n", 153 | " # creating a space between a word and the punctuation following it\n", 154 | " # eg: \"he is a boy.\" => \"he is a boy .\"\n", 155 | " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", 156 | " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", 157 | " w = re.sub(r'[\" \"]+', \" \", w)\n", 158 | "\n", 159 | " # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n", 160 | " w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n", 161 | "\n", 162 | " w = w.strip()\n", 163 | "\n", 164 | " # adding a start and an end token to the sentence\n", 165 | " # so that the model knows when to start and stop predicting.\n", 166 | " w = '<start> ' + w + ' <end>'\n", 167 | " return w\n", 168 | "\n", 169 | "# 1. Remove the accents\n", 170 | "# 2. Clean the sentences\n", 171 | "# 3. 
Return word pairs in the format: [ENGLISH, SPANISH]\n", 172 | "def create_dataset(path, num_examples):\n", 173 | " lines = io.open(path, encoding='UTF-8').read().strip().split('\\n')\n", 174 | "\n", 175 | " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", 176 | "\n", 177 | " return zip(*word_pairs)\n", 178 | "\n", 179 | "def tokenize(lang):\n", 180 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(\n", 181 | " filters='')\n", 182 | " lang_tokenizer.fit_on_texts(lang)\n", 183 | "\n", 184 | " tensor = lang_tokenizer.texts_to_sequences(lang)\n", 185 | "\n", 186 | " tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,\n", 187 | " padding='post')\n", 188 | "\n", 189 | " return tensor, lang_tokenizer\n", 190 | "\n", 191 | "def load_dataset(path, num_examples=None):\n", 192 | " # creating cleaned input, output pairs\n", 193 | " targ_lang, inp_lang = create_dataset(path, num_examples)\n", 194 | "\n", 195 | " input_tensor, inp_lang_tokenizer = tokenize(inp_lang)\n", 196 | " target_tensor, targ_lang_tokenizer = tokenize(targ_lang)\n", 197 | "\n", 198 | " return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer" 199 | ], 200 | "execution_count": 4, 201 | "outputs": [] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "metadata": { 206 | "id": "ru2m3ZMbM_W6", 207 | "colab_type": "code", 208 | "colab": { 209 | "base_uri": "https://localhost:8080/", 210 | "height": 35 211 | }, 212 | "outputId": "4abc7f1f-3ce2-43da-ab63-a8a26ad5a2ca" 213 | }, 214 | "source": [ 215 | "# Try experimenting with the size of that dataset\n", 216 | "num_examples = 30000\n", 217 | "# num_examples = -1\n", 218 | "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)\n", 219 | "\n", 220 | "# Calculate max_length of the target tensors\n", 221 | "max_length_targ, max_length_inp = target_tensor.shape[1], input_tensor.shape[1]\n", 222 | "\n", 223 | "# Creating training and validation sets using an 80-20 split\n", 224 | "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", 225 | "\n", 226 | "# Show length\n", 227 | "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))" 228 | ], 229 | "execution_count": 52, 230 | "outputs": [ 231 | { 232 | "output_type": "stream", 233 | "text": [ 234 | "24000 24000 6000 6000\n" 235 | ], 236 | "name": "stdout" 237 | } 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "metadata": { 243 | "id": "BOrNnUfHNiX7", 244 | "colab_type": "code", 245 | "colab": { 246 | "base_uri": "https://localhost:8080/", 247 | "height": 35 248 | }, 249 | "outputId": "4de4fa17-6657-4dee-d41d-a97ec97f413e" 250 | }, 251 | "source": [ 252 | "#make the dataset\n", 253 | "BUFFER_SIZE = len(input_tensor_train)\n", 254 | "BATCH_SIZE = 64\n", 255 | "steps_per_epoch = len(input_tensor_train)//BATCH_SIZE\n", 256 | "embedding_dim = 256\n", 257 | "units = 1024\n", 258 | "vocab_inp_size = len(inp_lang.word_index)+1\n", 259 | "vocab_tar_size = len(targ_lang.word_index)+1\n", 260 | "\n", 261 | "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", 262 | "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)\n", 263 | "\n", 264 | "import tensorflow_datasets as tfds\n", 265 | "dataset = tfds.as_numpy(dataset)\n", 266 | "\n", 267 | "example_input_batch, example_target_batch = next(iter(dataset))\n", 268 | 
"example_input_batch.shape, example_target_batch.shape" 269 | ], 270 | "execution_count": 53, 271 | "outputs": [ 272 | { 273 | "output_type": "execute_result", 274 | "data": { 275 | "text/plain": [ 276 | "((64, 16), (64, 11))" 277 | ] 278 | }, 279 | "metadata": { 280 | "tags": [] 281 | }, 282 | "execution_count": 53 283 | } 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": { 289 | "id": "2-i6hvE5NwBN", 290 | "colab_type": "text" 291 | }, 292 | "source": [ 293 | "## Define the Encoder-Decoder Model\n", 294 | "\n", 295 | "This is a standard encoder-decoder architecture with attentional decoding. See the paper https://arxiv.org/pdf/1409.0473.pdf for details. The attention mechanism allows the model to selectively *attend* to the encoded inputs, allowing the model to focus on the most important inputs in the source language for each prediction in the target language.\n", 296 | "\n", 297 | "We use a GRU-based recurrent model with scaled dot-product attention (which is different from the paper above). " 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "metadata": { 303 | "id": "9p8N9KJUN80w", 304 | "colab_type": "code", 305 | "colab": {} 306 | }, 307 | "source": [ 308 | "class Encoder(hk.Module):\n", 309 | " def __init__(self, vocab_size, d_model):\n", 310 | " super(Encoder, self).__init__()\n", 311 | " #is it better to keep the embedding outside?\n", 312 | " self.embedding = hk.Embed(vocab_size=vocab_size, embed_dim=d_model)\n", 313 | " self.gru = hk.GRU(hidden_size=d_model, \n", 314 | " w_i_init=hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 315 | " w_h_init=hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 316 | " b_init=hk.initializers.Constant(0.0))\n", 317 | " \n", 318 | " def initial_state(self, batch_size):\n", 319 | " return self.gru.initial_state(batch_size)\n", 320 | " \n", 321 | " def __call__(self, tokens, init_state):\n", 322 | " inputs = self.embedding(tokens)\n", 323 | " return hk.dynamic_unroll(self.gru, inputs, init_state)\n", 324 | "\n", 325 | "class ScaledDotAttention(hk.Module):\n", 326 | " \"\"\" Implements single-headed scaled dot-product attention \"\"\" \n", 327 | " def __init__(self, d_model):\n", 328 | " super(ScaledDotAttention, self).__init__()\n", 329 | " self.W_Q = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 330 | " b_init = hk.initializers.Constant(0.0))\n", 331 | " self.W_K = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 332 | " b_init = hk.initializers.Constant(0.0))\n", 333 | " self.W_V = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 334 | " b_init = hk.initializers.Constant(0.0))\n", 335 | " self.d_model = d_model\n", 336 | " self.root_d_model = np.sqrt(self.d_model)\n", 337 | " \n", 338 | " def __call__(self, Q, K, V):\n", 339 | " #apply linear projections to the Queries, Keys, and Values\n", 340 | " Q = self.W_Q(Q)\n", 341 | " K = self.W_K(K)\n", 342 | " V = self.W_V(V)\n", 343 | "\n", 344 | " #batch-dimension last...this is weird\n", 345 | " scores = jnp.einsum('...bd,tbd->...tb', Q, K)/self.root_d_model\n", 346 | "\n", 347 | " #normalize the scores\n", 348 | " probs = jax.nn.softmax(scores, axis=-2)\n", 349 | " \n", 350 | " #average the values w.r.t. 
the probs\n", 351 | " return jnp.einsum('...tb,tbd->...bd', probs, V)\n", 352 | " \n", 353 | "class BhadanauAttention(hk.Module):\n", 354 | " def __init__(self, d_model):\n", 355 | " super(BhadanauAttention, self).__init__()\n", 356 | " self.W_Q = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 357 | " b_init = hk.initializers.Constant(0.0))\n", 358 | " self.W_K = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 359 | " b_init = hk.initializers.Constant(0.0))\n", 360 | " self.W_score = hk.Linear(1, w_init = hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 361 | " b_init = hk.initializers.Constant(0.0))\n", 362 | "\n", 363 | " def __call__(self, Q, K, V):\n", 364 | " Q = jnp.expand_dims(Q, 0)\n", 365 | " #project the inputs\n", 366 | " Q = self.W_Q(Q)\n", 367 | " K = self.W_K(K)\n", 368 | "\n", 369 | " # compute the scores using the Bhadanau attention mechanism\n", 370 | " scores = self.W_score(jnp.tanh(Q + K))\n", 371 | "\n", 372 | " # normalize the scores into probs\n", 373 | " probs = jax.nn.softmax(scores, axis=0) #0 is time axis\n", 374 | "\n", 375 | " # average the values w.r.t. the probs\n", 376 | " return jnp.einsum('tbd,tbd->bd', probs, V)\n", 377 | "\n", 378 | "class Decoder(hk.Module):\n", 379 | " def __init__(self, attn, vocab_size, d_model):\n", 380 | " super(Decoder, self).__init__()\n", 381 | " self.embedding = hk.Embed(vocab_size=vocab_size, embed_dim=d_model)\n", 382 | " self.attn = attn\n", 383 | " self.gru = hk.GRU(hidden_size=d_model, \n", 384 | " w_i_init=hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 385 | " w_h_init=hk.initializers.VarianceScaling(1.0, \"fan_avg\", \"uniform\"),\n", 386 | " b_init=hk.initializers.Constant(0.0))\n", 387 | " \n", 388 | " self.proj = hk.Linear(vocab_size)\n", 389 | " \n", 390 | " def initial_state(self, batch_size):\n", 391 | " return self.gru.initial_state(batch_size)\n", 392 | "\n", 393 | " def __call__(self, tokens, enc_outputs, hidden_state):\n", 394 | " \"\"\" do attention with queries = hidden state, keys = enc_outputs, \n", 395 | " values = enc_outputs to select the most 'relevant' encoded outputs \n", 396 | " to the hidden state.\"\"\"\n", 397 | " \n", 398 | " # hidden_state = np.expand_dims(hidden_state, 0)\n", 399 | " ctx_vector = self.attn(hidden_state, enc_outputs, enc_outputs)\n", 400 | "\n", 401 | " # embed the tokens with the target embedding\n", 402 | " inputs = self.embedding(tokens)\n", 403 | "\n", 404 | " # concat the ctx_vector to the embeddings\n", 405 | " inputs = jnp.concatenate([ctx_vector, inputs], axis=-1)\n", 406 | "\n", 407 | " #apply the decoder to the context + inputs\n", 408 | " outputs, hidden_state = self.gru(inputs, hidden_state)\n", 409 | "\n", 410 | " # project outputs into logit space and return (logits, hidden_state)\n", 411 | " return self.proj(outputs), hidden_state\n", 412 | " " 413 | ], 414 | "execution_count": 47, 415 | "outputs": [] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": { 420 | "id": "ZjNxAYsGOyHH", 421 | "colab_type": "text" 422 | }, 423 | "source": [ 424 | "## Define the Encoder and Decoder 'Forward' functions\n", 425 | "\n", 426 | "We define these separately since we need to run the encoder once and the decoder multiple times for autoregressive decoding." 
427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "metadata": { 432 | "id": "9OVEJ2RDQ3Rw", 433 | "colab_type": "code", 434 | "colab": {} 435 | }, 436 | "source": [ 437 | "def encoder_fn(input_seqs):\n", 438 | " \"\"\" assumes input_seqs is time-first\n", 439 | " args:\n", 440 | " input_seqs: an input sequence of tokens\n", 441 | " \n", 442 | " returns:\n", 443 | " a tuple of arrays for the encoded outputs and the final hidden state of the encoder \n", 444 | " \"\"\"\n", 445 | " \n", 446 | " encoder = Encoder(vocab_size=vocab_inp_size, d_model = embedding_dim)\n", 447 | " batch_size = input_seqs.shape[1]\n", 448 | "\n", 449 | " #initialize the hidden state\n", 450 | " enc_initial_state = encoder.initial_state(batch_size)\n", 451 | " \n", 452 | " #apply the encoder to the full sequence using hk.dynamic_unroll(...)\n", 453 | " enc_outputs, enc_hidden = encoder(input_seqs, enc_initial_state)\n", 454 | "\n", 455 | " return enc_outputs, enc_hidden\n", 456 | "\n", 457 | "def decoder_fn(dec_inputs, hidden_state, enc_outputs):\n", 458 | " \"\"\" assumes dec_inputs are time-first \"\"\"\n", 459 | " attn = ScaledDotAttention(d_model = embedding_dim)\n", 460 | " # attn = BhadanauAttention(d_model = embedding_dim) # uncomment for Bhadanau attention\n", 461 | " \n", 462 | " decoder = Decoder(attn, vocab_size = vocab_tar_size, d_model = embedding_dim)\n", 463 | "\n", 464 | " # apply the decoder to a single input (i.e. not unrolled) since we need \n", 465 | " # to autoregressively generate the translation.\n", 466 | " outputs, hidden_state = decoder(dec_inputs, enc_outputs, hidden_state)\n", 467 | "\n", 468 | " return outputs, hidden_state\n", 469 | "\n", 470 | "def init_params(key, batch):\n", 471 | " test_inputs, test_targets = batch\n", 472 | "\n", 473 | " #transpose inputs to be time-first\n", 474 | " test_inputs = test_inputs.transpose(1,0)\n", 475 | " test_targets = test_targets.transpose(1,0)\n", 476 | "\n", 477 | " encoder = hk.transform(encoder_fn, apply_rng = True)\n", 478 | " enc_params = encoder.init(jax.random.PRNGKey(42), test_inputs)\n", 479 | " enc_outputs, enc_hiddens = encoder.apply(enc_params, jax.random.PRNGKey(0), test_inputs)\n", 480 | "\n", 481 | " decoder = hk.transform(decoder_fn, apply_rng = True)\n", 482 | " dec_params = decoder.init(jax.random.PRNGKey(42), test_targets[0], \n", 483 | " enc_hiddens, enc_outputs)\n", 484 | "\n", 485 | " return enc_params, dec_params" 486 | ], 487 | "execution_count": 48, 488 | "outputs": [] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": { 493 | "id": "EFL1BomTPLMX", 494 | "colab_type": "text" 495 | }, 496 | "source": [ 497 | "## Define the Loss Function and Train Step" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "metadata": { 503 | "id": "u-x8FbdUTHk0", 504 | "colab_type": "code", 505 | "colab": {} 506 | }, 507 | "source": [ 508 | "lr = 1e-3\n", 509 | "opt_init, opt_update, opt_get_params = optimizers.adam(lr)\n", 510 | "\n", 511 | "def masked_crossent(logits, targets):\n", 512 | " one_hot_labels = jax.nn.one_hot(targets, vocab_tar_size)\n", 513 | "\n", 514 | " # since we have padded the batch with 0s, to make them uniform length, \n", 515 | " # we need to mask out the padding tokens, which are index-zero tokens\n", 516 | " mask = jnp.expand_dims(targets > 0,1)\n", 517 | "\n", 518 | " #do masked mean, ensuring that length-zero batches don't give nan.\n", 519 | " denom = jnp.max(jnp.array([jnp.sum(mask), 1]))\n", 520 | " crossent = - jnp.sum(one_hot_labels * jax.nn.log_softmax(logits) * mask) / 
denom\n", 521 | "\n", 522 | " return crossent\n", 523 | "\n", 524 | "def loss(params, batch):\n", 525 | " enc_params, dec_params = params\n", 526 | "\n", 527 | " input_batch, target_batch = batch\n", 528 | "\n", 529 | " #transpose batch to be time-first\n", 530 | " input_batch = jnp.transpose(input_batch, (1,0))\n", 531 | " target_batch = jnp.transpose(target_batch, (1,0))\n", 532 | "\n", 533 | " #encode the batch once\n", 534 | " enc_outputs, enc_hidden = encoder.apply(enc_params, jax.random.PRNGKey(0), input_batch)\n", 535 | "\n", 536 | " #initalize the decoder's hidden state to be the encoder's hidden state\n", 537 | " dec_hidden = enc_hidden\n", 538 | "\n", 539 | " #start predicting with the token\n", 540 | " dec_input = jnp.array([targ_lang.word_index['']] * BATCH_SIZE)\n", 541 | "\n", 542 | " t_max = target_batch.shape[0]\n", 543 | " loss = 0.0\n", 544 | " for t in range(1, t_max):\n", 545 | " # iterate through the targets\n", 546 | " targets = target_batch[t]\n", 547 | "\n", 548 | " # compute logits over target vocabulary for the current word (targets)\n", 549 | " logits, dec_hidden = decoder.apply(dec_params, jax.random.PRNGKey(0), \n", 550 | " dec_input, dec_hidden, enc_outputs)\n", 551 | "\n", 552 | " # accumulate the loss\n", 553 | " loss += masked_crossent(logits, targets)\n", 554 | " \n", 555 | " # use teacher forcing by providing the ground-truth input to the model at each timestep\n", 556 | " dec_input = targets\n", 557 | "\n", 558 | " return loss / t_max\n", 559 | "\n", 560 | "\n", 561 | "@jax.jit\n", 562 | "def train_step(i, opt_state, batch):\n", 563 | " params = opt_get_params(opt_state)\n", 564 | " # batch_loss_fn = lambda p: loss(p, batch)\n", 565 | " # fx, dx = jax.value_and_grad(batch_loss_fn)(params)\n", 566 | " fx, dx = jax.value_and_grad(loss)(params, batch)\n", 567 | " opt_state = opt_update(i, dx, opt_state)\n", 568 | " return fx, opt_state\n", 569 | "\n", 570 | "\n", 571 | "def eval_step(params, sentence, max_len=32):\n", 572 | " \"\"\" decodes a single input sentence, provided as a string \"\"\"\n", 573 | " enc_params, dec_params = params\n", 574 | "\n", 575 | " # tokenize input string\n", 576 | " sentence = preprocess_sentence(sentence)\n", 577 | " inputs = [inp_lang.word_index[token] for token in sentence.split(' ')]\n", 578 | " inputs = np.expand_dims(jnp.array(inputs), 1)\n", 579 | "\n", 580 | " # encode the inputs\n", 581 | " enc_outputs, enc_hidden = encoder.apply(enc_params, jax.random.PRNGKey(0), inputs)\n", 582 | "\n", 583 | " # initialize the decoder's hidden state with the encoder's hidden state\n", 584 | " dec_hidden = enc_hidden\n", 585 | "\n", 586 | " #start predicting with the token\n", 587 | " dec_input = jnp.array([targ_lang.word_index['']] * 1)\n", 588 | "\n", 589 | " result = []\n", 590 | " for t in range(1, max_len):\n", 591 | " # compute the logits for the current token\n", 592 | " logits, dec_hidden = decoder.apply(dec_params, jax.random.PRNGKey(0), \n", 593 | " dec_input, dec_hidden, enc_outputs)\n", 594 | "\n", 595 | " # greedy-decode the prediction\n", 596 | " pred_idx = int(jnp.argmax(logits))\n", 597 | " result.append(targ_lang.index_word[pred_idx])\n", 598 | "\n", 599 | " #if the decoder says 'stop', return\n", 600 | " if targ_lang.index_word[pred_idx] == '':\n", 601 | " break\n", 602 | " \n", 603 | " #otherwise, the prediction becomes the input (for autogregressive decoding)\n", 604 | " dec_input = jnp.array([pred_idx])\n", 605 | " \n", 606 | " return \" \".join(result) + '.'" 607 | ], 608 | "execution_count": 49, 609 | "outputs": 
[] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": { 614 | "id": "C7daxYUNPRqg", 615 | "colab_type": "text" 616 | }, 617 | "source": [ 618 | "## Do the training" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "metadata": { 624 | "id": "ylxox6CEjOZS", 625 | "colab_type": "code", 626 | "colab": { 627 | "base_uri": "https://localhost:8080/", 628 | "height": 190 629 | }, 630 | "outputId": "cf41a742-5940-4977-c5aa-0b9426212424" 631 | }, 632 | "source": [ 633 | "init_key = jax.random.PRNGKey(0)\n", 634 | "params = init_params(init_key, next(dataset))\n", 635 | "opt_state = opt_init(params)\n", 636 | "\n", 637 | "train_dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train))\\\n", 638 | " .shuffle(BUFFER_SIZE)\n", 639 | "\n", 640 | "for epoch in range(10):\n", 641 | " epoch_loss = 0.0\n", 642 | " dataset_iter = tfds.as_numpy(train_dataset.batch(BATCH_SIZE, drop_remainder=True))\n", 643 | " \n", 644 | " start = time.time()\n", 645 | " for b, batch in enumerate(dataset_iter):\n", 646 | " train_loss, opt_state = train_step(b, opt_state, batch)\n", 647 | " epoch_loss += train_loss\n", 648 | " \n", 649 | " print(f\"epoch = {epoch}\",\n", 650 | " f\" | train loss = {epoch_loss / (b + 1):.5f}\",\n", 651 | " f\" | time per epoch = {time.time() - start:.2f}s\")" 652 | ], 653 | "execution_count": 54, 654 | "outputs": [ 655 | { 656 | "output_type": "stream", 657 | "text": [ 658 | "epoch = 0 | train loss = 1.84770 | time per epoch = 13.29s\n", 659 | "epoch = 1 | train loss = 1.17798 | time per epoch = 5.35s\n", 660 | "epoch = 2 | train loss = 0.92113 | time per epoch = 5.38s\n", 661 | "epoch = 3 | train loss = 0.74348 | time per epoch = 5.38s\n", 662 | "epoch = 4 | train loss = 0.60680 | time per epoch = 5.41s\n", 663 | "epoch = 5 | train loss = 0.50736 | time per epoch = 5.46s\n", 664 | "epoch = 6 | train loss = 0.42633 | time per epoch = 5.46s\n", 665 | "epoch = 7 | train loss = 0.34549 | time per epoch = 5.42s\n", 666 | "epoch = 8 | train loss = 0.28708 | time per epoch = 5.41s\n", 667 | "epoch = 9 | train loss = 0.24023 | time per epoch = 5.40s\n" 668 | ], 669 | "name": "stdout" 670 | } 671 | ] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "metadata": { 676 | "id": "dX4YhirSPTwy", 677 | "colab_type": "text" 678 | }, 679 | "source": [ 680 | "## Evaluate on some sample sentences\n", 681 | "\n", 682 | "Note: this is a simple model trained on a subset of the data. The translations are not perfect (below are some reasonable outputs)." 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "metadata": { 688 | "id": "pWxX3fT7nZz-", 689 | "colab_type": "code", 690 | "colab": { 691 | "base_uri": "https://localhost:8080/", 692 | "height": 69 693 | }, 694 | "outputId": "2c9f312a-fa34-4703-e46e-f01689ad0498" 695 | }, 696 | "source": [ 697 | "params = opt_get_params(opt_state)\n", 698 | "print(eval_step(params, u'hace mucho calor aqui.'))\n", 699 | "print(eval_step(params, u'hola!'))\n", 700 | "print(eval_step(params, u'¿cómo estás?'))" 701 | ], 702 | "execution_count": 63, 703 | "outputs": [ 704 | { 705 | "output_type": "stream", 706 | "text": [ 707 | "it s hot here . .\n", 708 | "hello ! .\n", 709 | "how are you ? .\n" 710 | ], 711 | "name": "stdout" 712 | } 713 | ] 714 | } 715 | ] 716 | } --------------------------------------------------------------------------------