├── README.md
├── fun_to_jax_ir.png
├── jax-utils
│   ├── README.md
│   ├── jax_utils
│   │   ├── set_conda_env_vars_for_jax_gpu.py
│   │   └── test_jax_installation.py
│   └── setup.py
├── tutorial.ipynb
└── tutorial_with_solutions.ipynb


/README.md:
--------------------------------------------------------------------------------
# JAX tutorial


Open in Colab


The tutorial notebook can be opened and run interactively on [Google Colab](https://colab.research.google.com/) from the badge above. The corresponding notebook with solutions to the exercises can also be run on Google Colab from [this link](https://colab.research.google.com/github/pierreglaser/jax-tutorial/blob/main/tutorial_with_solutions.ipynb). Alternatively, follow the instructions below to set up a local Python environment to run the notebook from.

## Installation Instructions


### Requirements:

- A UNIX-compliant distribution
- A `conda`-based package manager
- (Optional) for GPU support: CUDA driver libraries `>= 11.6`.

### Jax Installation (CPU)

To use a CPU-only jax, create a `conda` virtual environment containing `python` and `jax`:
```bash
conda create -n jax-tutorial python=3.9 && conda activate jax-tutorial
conda install -c conda-forge numpy scipy jax flax numpyro
```

### Jax Installation (GPU)

In all cases, you will need to install a GPU-enabled version of jax.

```bash
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

A fully functioning version of jax on GPU (i.e. one that includes working (sparse) linear algebra and deep network primitives) requires the `cudatoolkit` libraries, `cudnn`, as well as `nvcc` (a CUDA compiler).
In most cases, these libraries should already be present on your system. Alas, for research staff working on compute clusters with only user privileges, they often reside in non-standard locations.

#### If CUDA-related utilities are available in standard locations
You should be all set. Congratulations on living such a luxurious life.

#### If using properly configured modulefiles (case of the Sainsbury Wellcome Centre Compute Cluster)
Some compute environments (like the SWC compute cluster) use modulefiles to integrate specific libraries and executables with your current shell session. This removes the need for environment-variable plumbing when these libraries/executables are present in non-standard locations.

If you're a SWC staff researcher working on the SWC compute cluster, you can load the cuda/11.6 modulefile by executing:

```bash
module load cuda/11.6
```

and *voilà*.


#### If CUDA-related utilities are available in non-standard locations
If neither of the two cases above applies (for instance, with user-installed (conda) CUDA libraries or incomplete modulefiles), you will need to tell `jax` yourself where these libraries can be found.
To do so, locate the root directory containing the CUDA utilities, say, `/path/to/cuda/dir`, and run:

```bash
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda/dir
# Prepend to (rather than overwrite) any existing LD_LIBRARY_PATH. YMMV: might be lib and not lib64
export LD_LIBRARY_PATH=/path/to/cuda/dir/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
```


### Testing your installation

To test that your jax environment is properly set up, a convenience script is provided as part of this tutorial.
From the root directory of this repository, run:
```bash
python -m pip install ./jax-utils
# if on CPU:
python -m jax_utils.test_jax_installation
# if on GPU:
python -m jax_utils.test_jax_installation --gpu
```

This script will test a subset of jax features relying on different libraries and will loudly error out if some piece of software is missing.


### Installing jupyter-related utilities

To execute jupyter notebooks that use the previously set up `jax-tutorial` environment as the execution environment, either install `jupyterlab` directly in this environment:

```bash
conda install jupyterlab
```

or install `ipykernel` and register your kernel with your external jupyterlab installation:

```bash
conda install ipykernel
python -m ipykernel install --prefix=path/to/miniforge/installation/envs/ --name="jax-tutorial";
conda deactivate && conda activate
```

If you're using a GPU-powered jax and jupyterlab, and you're feeling fancy, install the jupyterlab extension `jupyterlab_nvdashboard`. It dynamically displays valuable metrics such as GPU memory usage or volatile GPU utilisation:

```bash
pip install jupyterlab_nvdashboard
```

At this point, you should be all set.
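
Optionally, before opening the notebook, you can also check from Python which backend and devices jax actually picked up. The snippet below is a minimal sketch, not part of `jax_utils`. The helper name `describe_jax_setup` is hypothetical, and the snippet only relies on the public `jax.default_backend()`, `jax.devices()` and `jax.jit` functions.

```python
import jax
import jax.numpy as jnp


def describe_jax_setup():
    """Hypothetical helper (not part of jax_utils): report the backend jax selected
    and run a tiny jit-compiled computation on it."""
    backend = jax.default_backend()  # name of the default backend, e.g. "cpu" or "gpu"
    devices = jax.devices()  # devices of the default backend
    # Compiling and running a small function exercises the selected backend end to end.
    value = jax.jit(lambda x: jnp.sum(x ** 2))(jnp.arange(3.0))
    return backend, devices, float(value)


if __name__ == "__main__":
    backend, devices, value = describe_jax_setup()
    print(f"default backend: {backend}")
    print(f"devices: {devices}")
    print(f"sum of squares of [0, 1, 2]: {value}")  # 0 + 1 + 4 = 5.0
```

If the reported backend is `cpu` on a machine where you expected a GPU, revisit the GPU installation steps above. Running `python -m jax_utils.test_jax_installation --gpu --verbose` gives a more detailed diagnosis.
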
To execute the notebook `tutorial.ipynb`, simply make sure you are in the root directory of this tutorial's repository, and execute:

```bash
jupyter lab
```
--------------------------------------------------------------------------------
/fun_to_jax_ir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pierreglaser/jax-tutorial/071c923c082e8dbd7aea54b220052283dc3dbe9d/fun_to_jax_ir.png
--------------------------------------------------------------------------------
/jax-utils/README.md:
--------------------------------------------------------------------------------
## A set of `jax`-related utilities.


Right now, this package only contains a small function that sets up the right
environment variables to get `jax` running against a `cuda` toolkit installed using
`conda`, along with a script that tests a jax installation.

--------------------------------------------------------------------------------
/jax-utils/jax_utils/set_conda_env_vars_for_jax_gpu.py:
--------------------------------------------------------------------------------
import os
import sys
from pathlib import Path


def set_environment_for_jax():
    try:
        __import__("jax")
    except ModuleNotFoundError:
        return

    if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
        return

    # XXX: this script has been written with a conda installation in mind --
    # I should maybe check whether this python belongs to a conda env.

    conda_env_bin_dir = Path(sys.executable).parent
    conda_env_dir = conda_env_bin_dir.parent
    conda_env_lib_dir = conda_env_dir / "lib"

    # TODO(piereglaser): expose cuda toolkit to jax
    os.environ["PATH"] = f"{os.environ['PATH']}:{conda_env_bin_dir}"

    if "XLA_FLAGS" not in os.environ:
        print("setting XLA_FLAGS")
        os.environ["XLA_FLAGS"] = f"--xla_gpu_cuda_data_dir={conda_env_dir}"

    if "LD_LIBRARY_PATH" not in os.environ:
        print("setting LD_LIBRARY_PATH")
        os.environ["LD_LIBRARY_PATH"] = f"{conda_env_lib_dir}"

    # Don't preallocate 90% of GPU memory as it recently led to memory leaks
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

--------------------------------------------------------------------------------
/jax-utils/jax_utils/test_jax_installation.py:
--------------------------------------------------------------------------------
import os
import subprocess
import argparse
from typing import Literal


def show_environment_variables():
    # Print the environment variables that tell jax where to look for cuda
    # runtime libraries and how to allocate GPU memory
    print(f"XLA_FLAGS={os.environ.get('XLA_FLAGS')}")
    print(f"LD_LIBRARY_PATH={os.environ.get('LD_LIBRARY_PATH')}")
    print(f"PATH={os.environ.get('PATH')}")
    print(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
    print(f"XLA_PYTHON_CLIENT_PREALLOCATE={os.environ.get('XLA_PYTHON_CLIENT_PREALLOCATE')}")
    print(f"XLA_PYTHON_CLIENT_ALLOCATOR={os.environ.get('XLA_PYTHON_CLIENT_ALLOCATOR')}")


def _test_jax_installation_unsafe(device: Literal['cpu', 'gpu']):
    try:
        import jax
        import jax.numpy as jnp
        from jax import random
    except ModuleNotFoundError as e:
        raise ValueError("jax is not installed") from e

    assert device in ['cpu', 'gpu'], f"device must be 'cpu' or 'gpu', got {device}"

    if device == "gpu":
        # Check access to the cuda compiler (ptxas)
        print("testing access to a cuda compiler...", end="")
        try:
            subprocess.check_output(["which", "ptxas"])
            # os.system("ptxas --version")
        except subprocess.CalledProcessError as e:
            raise ValueError("No cuda compiler found in $PATH") from e
        print(" OK.")

    # Check that jax can see a device of the requested type.
    print(f"checking if jax can detect a {device} device...", end="")
    # jax.devices(backend) raises a RuntimeError if the requested backend is unavailable;
    # this avoids importing device classes from jaxlib.xla_extension, which newer jaxlib
    # releases no longer expose.
    assert len(jax.devices(device)) > 0
    print(" OK.")

    # create a simple jax array
    print(f"testing array creation on {device}...", end="")
    key = random.PRNGKey(0)
    x = random.normal(key, (10,))
    print(" OK.")

    # Use specialized cuda libraries such as linear algebra solvers
    print(
        "testing use of specialized cuda libraries such as linear algebra solvers...",
        end="",
    )
    A = jnp.array([[0, 1], [1, 1], [1, 1], [2, 1]])
    _, _ = jnp.linalg.qr(A)

    A = jnp.eye(10)
    _, _ = jnp.linalg.eigh(A)

    print(" OK.")

    # Use cudnn primitives such as convolutions
    # (cudnn has to be installed separately)
    print("testing use of cudnn primitives...", end="")
    key = random.PRNGKey(0)
    x = jnp.linspace(0, 10, 500)
    y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))

    window = jnp.ones(10) / 10
    _ = jnp.convolve(y, window, mode="same")
    print(" OK.")

    print("Test done, everything seems well installed.")


def test_jax_installation(device, verbose=False):
    try:
        _test_jax_installation_unsafe(device=device)
    except Exception as e:
        print('\n')
        print('\n')
        print('##################################################################')
        print('#                                                                #')
        print('#                                                                #')
        print('#              ERROR WHILE TESTING JAX INSTALLATION              #')
        print('#                                                                #')
        print('#                                                                #')
        print('##################################################################')
        print("An error occurred during the test.")
        if not verbose:
            print("Re-run the same command with --verbose to get more information.")
        else:
            print("Here are the environment variables:")
            show_environment_variables()
            print("Here is the error message:")
            print(e)

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--gpu', action='store_true')
    args = parser.parse_args()
    test_jax_installation(device="gpu" if args.gpu else "cpu", verbose=args.verbose)

--------------------------------------------------------------------------------
/jax-utils/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
try:
    from setuptools import setup
except ImportError:
    from distutils.core import setup

description = ("Sanity-checking scripts to ensure that jax is properly installed")

dist = setup(
    name="jax-utils",
    version="0.0.0dev0",
    description=description,
    author="Pierre Glaser",
    author_email="pierreglaser@msn.com",
    license="BSD 3-Clause License",
    packages=["jax_utils"],
    install_requires=["jax"],
    classifiers=[
        "Development Status :: 4 - Beta",
        "License :: OSI Approved :: BSD License",
        "Operating System :: POSIX",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: Implementation :: CPython",
    ],
    python_requires=">=3.8",
)

--------------------------------------------------------------------------------
/tutorial_with_solutions.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "79b3ba15-2375-4efd-97de-1c5fa6e6d2d4",
   "metadata": {},
   "source": [
    "# Tutorial: JAX\n",
    "\n",
    "The cell below will install the additional required dependencies using `pip` if running the notebook on [Google Colab](https://colab.research.google.com/) - it should do nothing if running locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc17c4af-1dbb-421a-b9c8-353c835c8698",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import google.colab\n",
    "    # Running on Google Colab therefore install additional dependencies\n",
    "    !pip install flax numpyro\n",
    "except ImportError:\n",
    "    # Assume dependencies installed if not running on Colab\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a61a4ff0-537e-42cc-bdfd-2a76ca8de292",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "## JAX: Intro\n",
    "\n",
    "At a high level, jax is an extensible system for composable function transformations based on trace-specializing functional python code. These transformations include:\n",
    "\n",
    "- Just In Time Compilation\n",
    "- Automatic Differentiation\n",
    "- Automatic Vectorization\n",
    "- Single Program Multiple Data (SPMD) transformations.\n",
    "- And more...\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89bfe947-f766-4771-90cf-711f3d086e04",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "The functions on which jax can operate must:\n",
    "\n",
    "- take (collections of) tensor-like inputs and return (collections of) tensor-like outputs, which are instances of the `DeviceArray` class (similar to `np.ndarray` instances)\n",
    "- manipulate these tensors using only a set of (closed) primitives exposed in the `jax` libraries, for the most part in the `jax.numpy` and `jax.scipy` modules, which often have a direct equivalent in `numpy` or `scipy`\n",
    "- be functionally pure: running the same function with the same inputs should yield the same outputs, without side effects. Functional programming should feel natural when the code logic is the direct translation of mathematical operations, a case which is very frequent in Machine Learning.\n",
    "\n",
    "\n",
    "We will start with a few very simple examples, and expand on them."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a294589-a6ed-40f5-a7a4-242e237d7dc6",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "## Starting simple... The gaussian kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ae21ec89-b9a1-455d-953c-72532653acf5",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import jit, grad, vmap"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "664690cf-3eea-4772-ac4e-ee45260b3f21",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "The `gaussian_kernel` function is a jax implementation of the famous gaussian kernel:\n",
    "$$ k(x, y) = e^{-\\frac{\\|x - y\\|^2}{\\sigma^2}}$$\n",
    "\n",
    "This kernel is often used when implementing kernel-based methods, such as [Kernel Regression](https://en.wikipedia.org/wiki/Kernel_regression)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8eb17e5b-ad06-4919-b10b-290f6d743463",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def gaussian_kernel(x, y):\n",
    "    sigma = 2\n",
    "    z = x - y\n",
    "    return jnp.exp(-jnp.sum(jnp.square(z)) / sigma ** 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b1ea2ff-218d-4b81-ae5e-cc0c30785073",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "60ef0206-4aef-4b64-b1a0-a168213bcaf9",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(0.60653067, dtype=float32)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_input = jnp.ones((2,))  # [1., 1.]\n",
    "y_input = jnp.zeros((2,))  # [0., 0.]\n",
    "retval = gaussian_kernel(x_input, y_input)\n",
    "retval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "922cde14-db48-4379-8731-82c92fd08a5e",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "The gaussian kernel relies on the following jax primitives: `jnp.sum`, `jnp.square`, and `jnp.exp`, which are the jax analogues of `np.sum`, `np.square` and `np.exp`. Almost all `numpy` primitives are present in `jax.numpy`, including traditional linear algebra operations like `jnp.dot` or `jnp.matmul`.\n",
    "\n",
    "This function satisfies all the constraints enumerated in the introductory paragraph of this section: it takes 2 jax arrays as input, returns a scalar (0-th order jax array), and is pure (same input -> same output). We can thus apply jax transformations to it!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55861338-3dd4-40ef-9586-388d927184ca",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Jax Transformations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ade6e9ca-f328-4430-ba90-5b9bfd483b68",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "### JIT compilation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6d7ac895-e6b3-4bff-b5e9-8306b3470447",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# just-in-time compilation\n",
    "jitted_gaussian_kernel = jit(gaussian_kernel)\n",
    "# the jitted and non-jitted versions should agree (up to floating point precision)\n",
    "assert jnp.allclose(jitted_gaussian_kernel(x_input, y_input), gaussian_kernel(x_input, y_input))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb30fc49-3860-4c99-a1a7-3b375b9595a7",
   "metadata": {
    "tags": []
   },
   "source": [
    "The JIT-compiled version of `gaussian_kernel`, `jitted_gaussian_kernel`, executes the same end-to-end mathematical operations, but applies additional software and hardware optimizations during compilation to speed up computations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "97aa17d7-d822-470a-a01a-c2edffefdf0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14.5 µs ± 195 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit gaussian_kernel(x_input, y_input).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c2ffc3ea-b12f-417d-89b5-2d8cea507000",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.32 µs ± 63.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit jitted_gaussian_kernel(x_input, y_input).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75d37c83-691e-453e-8f25-eef4b9c6ec36",
   "metadata": {},
   "source": [
    "The absolute time saved here is small (roughly 12 µs per call), because the function `gaussian_kernel` is quite simple: there is not much to optimize, and most of the gain likely comes from dispatching a single compiled computation instead of several separate operations.\n",
    "We will see other use cases where the speedup yielded by jit-compilation of such functions can become significant."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c509077-fe69-4863-889f-34bcf897226f",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "### Automatic differentiation\n",
    "\n",
    "Jax can compute derivatives of functions with respect to one or several arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e3573043-703f-4c33-885a-07b17ac1a39a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([-0.30326533, -0.30326533], dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# automatic differentiation\n",
    "d_dx_gaussian_kernel = grad(gaussian_kernel, argnums=0)  # compute the partial derivative of gaussian_kernel w.r.t. x\n",
    "d_dx_gaussian_kernel(x_input, y_input)  # returns a vector of shape (2,), like x_input"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b72448b9-d6ae-410f-b4e5-90b7f3274951",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "### Automatic vectorization\n",
    "\n",
    "Last but not least, jax can also \"automatically vectorize\" a function. The vectorized function:\n",
    "- takes as input a \"batch\" of input tensors, stacked along a new dimension\n",
    "- applies the original function to all input tensors in the batch\n",
    "\n",
    "The `in_axes` argument of `vmap` specifies, for each argument, which dimension is the batch dimension; `None` means that the argument is not batched and is shared by all calls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b230cf65-3a6f-41a5-8fde-09a93e141706",
   "metadata": {},
   "outputs": [],
   "source": [
    "vmapped_gaussian_kernel = vmap(gaussian_kernel, in_axes=(0, None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9c8e506e-3891-471f-868b-27019f70d364",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([1.        , 0.60653067], dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_of_x_inputs = jnp.stack((jnp.zeros((2,)), jnp.ones((2,))))\n",
    "vmapped_gaussian_kernel(batch_of_x_inputs, y_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1159d42e-59f8-4492-a344-62ff98833536",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "### Arbitrary transformation composition\n",
    "\n",
    "Importantly, all these transformations can be composed together (!):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "231a5524-346c-45c8-af0b-ca82db740d39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[-0.        , -0.        ],\n",
       "             [-0.30326533, -0.30326533]], dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jit(vmap(grad(gaussian_kernel), in_axes=(0, None)))(batch_of_x_inputs, y_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1c86aad-47cb-4f49-a442-6a57b8dc05ac",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "With these basic building blocks in mind, we can start implementing interesting mathematical and statistical objects. As you will see, using the function transformations provided by \n",
    "jax will often benefit you in one or more of the following ways:\n",
    "\n",
    "- reduce code complexity, increase code readability\n",
    "- increase code efficiency"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41fb74c3-e527-43e0-b044-f9dcd81bf322",
   "metadata": {},
   "source": [
    "### Exercise 1: Computing the (Gaussian Kernel) gram matrix between two datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3acd4131-e729-47ed-8519-97a88604fe28",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "As a first nontrivial composition exercise, try to implement a function that computes the **gram matrix** between two datasets $X = \\{x_i\\}_{i=1}^{n}$ and $Y = \\{y_i\\}_{i=1}^{m}$, for a given kernel $k$. This gram matrix is given by:\n",
    " \n",
    " \n",
    "$$\n",
    "M = \\begin{pmatrix}\n",
    "k(x_1, y_1) & \\dots & k(x_1, y_m)\\\\\n",
    "\\vdots & \\ddots & \\vdots \\\\\n",
    "k(x_n, y_1) & \\dots & k(x_n, y_m)\n",
    "\\end{pmatrix}\n",
    "$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e53ccace-db4f-415b-ba88-c3a5d0336bd2",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "random_state = np.random.RandomState(42)\n",
    "X = random_state.randn(100, 2)\n",
    "Y = random_state.randn(200, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7294df9d-f253-442c-be38-260dfc291871",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "#### numpy-style implementations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb8ecf36-55a8-4fe0-8e71-7e15ebb2de4b",
   "metadata": {},
   "source": [
    "A naive implementation would consist of two nested for-loops that iterate over X and Y to compute each element of the matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ee4d029d-c36b-477c-9cbc-3092e41bfbf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_gram_matrix_naively(X, Y):\n",
    "    gram_matrix = np.empty((X.shape[0], Y.shape[0]))\n",
    "    for i in range(X.shape[0]):\n",
    "        for j in range(Y.shape[0]):\n",
    "            kij = gaussian_kernel(X[i], Y[j])  # note the interoperability between jax and numpy arrays within **untransformed** jax functions\n",
    "            gram_matrix[i, j] = kij\n",
    "    return gram_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aa26bae6-6835-4e55-8291-7ca3319f20b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "M_naive = compute_gram_matrix_naively(X, Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09b880d6-021a-465d-9b8b-63d9be6620d1",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "A more efficient implementation leverages numpy's broadcasting semantics to compute the entries of the gram matrix in a vectorized manner.\n",
    "Note that this requires passing explicit axis values to the reduction operations performed during the kernel computation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "437676b6-b054-427c-ba86-6ba0a0babea4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gaussian_kernel_explicit_reduction_axes(x, y):\n",
    "    sigma = 2\n",
    "    z = x - y\n",
    "    # we pass axis=-1 to prevent jnp.sum from summing over all axes when giving gaussian_kernel a stack of tensors.\n",
    "    # return jnp.exp(-jnp.sum(jnp.square(z)) / sigma ** 2)\n",
    "    return jnp.exp(-jnp.sum(jnp.square(z), axis=-1) / sigma ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "bf3a5a6c-f1c4-4ca0-bf00-f96da0f71368",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_gram_matrix_bcast_semantics(X, Y):\n",
    "    return gaussian_kernel_explicit_reduction_axes(X[:, None, :], Y[None, :, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a752c3d7-3f9a-4521-9214-b5aab5fec8c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "M_bcast = compute_gram_matrix_bcast_semantics(X, Y)\n",
    "assert np.allclose(M_naive, M_bcast)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb590f0f-1855-4e88-9deb-d29bafd77c56",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "Benchmarking these two functions highlights the dramatically higher efficiency of the latter method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d89c1d1f-3acb-4b26-a24b-3097a4e58192",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "292 ms ± 4.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit M_naive = compute_gram_matrix_naively(X, Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5a2475e2-c16b-4da4-885b-e33618a72b81",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "216 µs ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit M_bcast = compute_gram_matrix_bcast_semantics(X, Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd9cb735-f547-462b-8f82-692e6685a2d2",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    },
    "tags": []
   },
   "source": [
    "However, the latter method required:\n",
    "\n",
    "- rewriting `gaussian_kernel` to account for the case when it is given batched inputs. What if `gaussian_kernel` is replaced by a much more complex function? What if it is replaced by a third-party function that you are not familiar with?\n",
    "- relying on broadcasting semantics in `compute_gram_matrix_bcast_semantics`: although the use of such semantics was quite simple in this case, it can quickly become error-prone in complex code bases."
631 | ] 632 | }, 633 | { 634 | "cell_type": "markdown", 635 | "id": "cd07bd73-cc67-472c-a68c-47df3f789915", 636 | "metadata": { 637 | "slideshow": { 638 | "slide_type": "subslide" 639 | }, 640 | "tags": [] 641 | }, 642 | "source": [ 643 | "#### Jax-implementation" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 19, 649 | "id": "f700875e-b672-4c39-b67b-e355b01d2864", 650 | "metadata": {}, 651 | "outputs": [], 652 | "source": [ 653 | "# Exercise: combine of vmap and gaussian_kernel within compute_gram_matrix_using_vmap to compute the gram matrix between X and Y\n", 654 | "def compute_gram_matrix_using_vmap(X, Y):\n", 655 | " # raise NotImplemented\n", 656 | " return vmap(vmap(gaussian_kernel, in_axes=(None, 0)), in_axes=(0, None))(X, Y)" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 20, 662 | "id": "a143ddf1-8a5e-4ea0-a519-2cfc777f2587", 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "M_vmap = compute_gram_matrix_using_vmap(X, Y)\n", 667 | "assert np.allclose(M_vmap, M_bcast)" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 21, 673 | "id": "5acaaa09-a7a6-454e-b9e8-0f0912d8c213", 674 | "metadata": {}, 675 | "outputs": [], 676 | "source": [ 677 | "jitted_gaussian_kernel = jit(gaussian_kernel)\n", 678 | "_ = jitted_gaussian_kernel(jnp.ones((2,)), jnp.zeros((2,))).block_until_ready()" 679 | ] 680 | }, 681 | { 682 | "cell_type": "markdown", 683 | "id": "7a056ab3-6fca-4c36-8c3b-d39d3973d393", 684 | "metadata": {}, 685 | "source": [ 686 | "### Exercise 2: A jax-implementation of the Kernelized Stein Discrepancy.\n", 687 | "\n", 688 | "\n", 689 | "Further composition of `vmap`, `jit` and `grad` happens when computing the infamous `Stein` kernel, which is used to evaluate whether samples $\\{x_i\\}_{i=1}^{n}$ are distributed according to a test density $q(x) = p(x) / Z$ known up to a normalizing constant Z. 
Given a base kernel $k$ and writing $s(x) = \\nabla \\log p(x)$ for the score function (which does not depend on $Z$), the Stein kernel reads:\n", 690 | "\n", 691 | "$$\n", 692 | "k_{\\textrm{stein}}(x, y) = s(x)^\\top s(y) k(x, y) + s(x)^\\top \\nabla_y k(x, y) + \\nabla_x k(x, y)^\\top s(y) + \\textrm{div}_x(\\nabla_y k(x, y))\n", 693 | "$$\n", 694 | "\n", 695 | "which can be used to compute a measure of discrepancy between the samples and the density $p$, called the Kernelized Stein Discrepancy (or KSD):\n", 696 | "\n", 697 | "$$\n", 698 | "\\text{KSD}(X, p) = \\sqrt{ \\frac{1}{n^2} \\sum_{i=1}^n \\sum_{j=1}^n k_{\\textrm{stein}}(x_i, x_j)}\n", 699 | "$$" 700 | ] 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "id": "3516c328-d001-495b-b758-f9372697db83", 705 | "metadata": {}, 706 | "source": [ 707 | "As an exercise, compute the KSD between $X$ and $p$, using the Gaussian kernel as the base kernel and taking $p$ to be the standard normal distribution.\n", 708 | "We will place ourselves under the null hypothesis, meaning that $X$ will be sampled from $p$." 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 22, 714 | "id": "1ad2457f-0879-4640-91c2-380aa6028c28", 715 | "metadata": {}, 716 | "outputs": [], 717 | "source": [ 718 | "def p(x):\n", 719 | "    return jnp.exp(-0.5 * jnp.sum(jnp.square(x)))  # unnormalized standard normal density\n", 720 | "\n", 721 | "s = ...
# compute the score function" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": 23, 727 | "id": "72df2649-5067-482c-a116-00001b1cb4f7", 728 | "metadata": {}, 729 | "outputs": [], 730 | "source": [ 731 | "# Solution\n", 732 | "s = grad(lambda x: jnp.log(p(x)))" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": 24, 738 | "id": "21ffa12f-8032-45c1-98ec-87cc329cf19b", 739 | "metadata": {}, 740 | "outputs": [], 741 | "source": [ 742 | "# You will need the jacobian function transformation of jax, which computes the jacobian of a vector-valued function\n", 743 | "from jax import jacobian" 744 | ] 745 | }, 746 | { 747 | "cell_type": "code", 748 | "execution_count": 25, 749 | "id": "0760266f-9f58-4f21-8299-eb71f19ec634", 750 | "metadata": {}, 751 | "outputs": [], 752 | "source": [ 753 | "def stein_kernel(x, y):\n", 754 | " k = gaussian_kernel\n", 755 | " term_1 = jnp.dot(s(x), s(y)) * k(x, y)\n", 756 | " term_2 = jnp.dot(s(x), grad(k, argnums=1)(x, y))\n", 757 | " term_3 = jnp.dot(grad(k, argnums=0)(x, y), s(y))\n", 758 | " \n", 759 | " # Solution\n", 760 | " term_4 = jnp.trace(jacobian(grad(k, argnums=1), argnums=0)(x, y))\n", 761 | " \n", 762 | " return term_1 + term_2 + term_3 + term_4" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": 26, 768 | "id": "c29f9790-7b8b-4c6e-a490-8d3ecfecd25b", 769 | "metadata": {}, 770 | "outputs": [], 771 | "source": [ 772 | "def compute_gram_matrix_stein_kernel(X):\n", 773 | " return vmap(vmap(stein_kernel, in_axes=(None, 0)), in_axes=(0, None))(X, X)" 774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "execution_count": 27, 779 | "id": "0e3b506e-754c-4e7b-98d4-3202e1d46567", 780 | "metadata": {}, 781 | "outputs": [], 782 | "source": [ 783 | "def ksd(X):\n", 784 | " N = X.shape[0]\n", 785 | " return jnp.sqrt(jnp.sum(compute_gram_matrix_stein_kernel(X)) / (N **2))" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": 28, 791 | 
"id": "b732e0ce-12fd-47c3-bba1-6ef2616a3d21", 792 | "metadata": {}, 793 | "outputs": [], 794 | "source": [ 795 | "from jax import random\n", 796 | "key = random.PRNGKey(0)\n", 797 | "X = random.normal(key, (5000, 2))" 798 | ] 799 | }, 800 | { 801 | "cell_type": "markdown", 802 | "id": "ce3fd341-7705-4570-9ca0-2f27607630ae", 803 | "metadata": {}, 804 | "source": [ 805 | "Under the null hypothesis, $\\text{KSD}(X, p)$ should tend to 0 at a $1/\\sqrt{N}$ rate. Given that $X$ is actually sampled from a standard Gaussian, we should recover this regime if we implemented\n", 806 | "the KSD properly. The following cells plot $\\text{KSD}(X, p)$ against $1/\\sqrt{N}$ for different sample sizes $N$ of $X$." 807 | ] 808 | }, 809 | { 810 | "cell_type": "code", 811 | "execution_count": 29, 812 | "id": "73272e70-12d4-46f4-a460-6d4dab6fe0e4", 813 | "metadata": {}, 814 | "outputs": [], 815 | "source": [ 816 | "sample_sizes = (50, 100, 200, 300, 500, 1000, 2000, 5000)\n", 817 | "random_state = np.random.RandomState(40)\n", 818 | "\n", 819 | "ksd_vals = []\n", 820 | "for N in sample_sizes:\n", 821 | "    ksd_vals_this_iter = []\n", 822 | "    for _ in range(5):\n", 823 | "        X = random_state.randn(N, 2)\n", 824 | "        ksd_vals_this_iter.append(ksd(jnp.array(X)))\n", 825 | "    ksd_vals.append(jnp.mean(jnp.array(ksd_vals_this_iter)))" 826 | ] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "execution_count": 30, 831 | "id": "6b46607b-7aae-4690-a619-76215260b7a6", 832 | "metadata": {}, 833 | "outputs": [ 834 | { 835 | "data": { 836 | "text/plain": [ 837 | "" 838 | ] 839 | }, 840 | "execution_count": 30, 841 | "metadata": {}, 842 | "output_type": "execute_result" 843 | }, 844 | { 845 | "data": { 846 | "image/png": "\n", 847 | "text/plain": [ 848 | "
" 849 | ] 850 | }, 851 | "metadata": {}, 852 | "output_type": "display_data" 853 | } 854 | ], 855 | "source": [ 856 | "import matplotlib.pyplot as plt\n", 857 | "plt.plot(1 / jnp.sqrt(jnp.array(sample_sizes)), ksd_vals, label=\"KSD\")\n", 858 | "\n", 859 | "coef = np.polyfit(1 / jnp.sqrt(jnp.array(sample_sizes)),ksd_vals,1)\n", 860 | "poly1d_fn = np.poly1d(coef) \n", 861 | "# poly1d_fn is now a function which takes in x and returns an estimate for y\n", 862 | "\n", 863 | "plt.plot(1 / jnp.sqrt(jnp.array(sample_sizes)), poly1d_fn(jnp.array(1 / jnp.sqrt(jnp.array(sample_sizes)))), '--k', label=\"linear fit\")\n", 864 | "plt.xlabel(\"1/√N\")\n", 865 | "plt.ylabel(\"KSD(X, p)\")\n", 866 | "plt.legend()" 867 | ] 868 | }, 869 | { 870 | "cell_type": "markdown", 871 | "id": "1e11843b-d02a-4a33-b822-c0d7369ee2bb", 872 | "metadata": {}, 873 | "source": [ 874 | "## Randomness in JAX" 875 | ] 876 | }, 877 | { 878 | "cell_type": "markdown", 879 | "id": "65142534-1c18-4e40-af17-d611499a1da3", 880 | "metadata": {}, 881 | "source": [ 882 | "### Description " 883 | ] 884 | }, 885 | { 886 | "cell_type": "markdown", 887 | "id": "9a9a720f-6f12-48ff-aeeb-c3c2c9a9ccbb", 888 | "metadata": {}, 889 | "source": [ 890 | "- So far, we have conveniently evaded the notion of randomness in Jax by delegating draws from probability distributions to numpy.\n", 891 | "- Random variables can be drawn from jax. 
However, the handling of randomness illustrates the functional purity constraint that jax imposes.\n", 892 | "- This constraint prevents numpy-style random variable draws, since drawing a random variable in numpy has the side effect of modifying numpy's global random state:" 893 | ] 894 | }, 895 | { 896 | "cell_type": "code", 897 | "execution_count": 31, 898 | "id": "93f314aa-5ec6-4f67-9e70-e1215e7ff5d6", 899 | "metadata": {}, 900 | "outputs": [], 901 | "source": [ 902 | "def draw_one_random_variable():\n", 903 | "    return np.random.randn()" 904 | ] 905 | }, 906 | { 907 | "cell_type": "code", 908 | "execution_count": 32, 909 | "id": "36bdb6a9-e92c-474d-b846-8f97d462d1fa", 910 | "metadata": {}, 911 | "outputs": [ 912 | { 913 | "name": "stdout", 914 | "output_type": "stream", 915 | "text": [ 916 | "-0.07537296806232933 1.6193415953720864\n" 917 | ] 918 | } 919 | ], 920 | "source": [ 921 | "print(draw_one_random_variable(), draw_one_random_variable())  # two different outputs for the same (empty) input" 922 | ] 923 | }, 924 | { 925 | "cell_type": "markdown", 926 | "id": "8c2781b0-d006-4f43-b51f-682286dd4dac", 927 | "metadata": {}, 928 | "source": [ 929 | "`RandomState`-based randomness generation is also incompatible with functional purity, since it mutates its input:" 930 | ] 931 | }, 932 | { 933 | "cell_type": "code", 934 | "execution_count": 33, 935 | "id": "90a08a0c-b2c7-44f5-a9a3-70e4dbe72bbe", 936 | "metadata": {}, 937 | "outputs": [], 938 | "source": [ 939 | "def draw_one_random_variable_RandomState(rs):\n", 940 | "    return rs.randn()" 941 | ] 942 | }, 943 | { 944 | "cell_type": "code", 945 | "execution_count": 34, 946 | "id": "a1655007-e9ce-40ed-a57e-a49c2de6c981", 947 | "metadata": {}, 948 | "outputs": [ 949 | { 950 | "data": { 951 | "text/plain": [ 952 | "False" 953 | ] 954 | }, 955 | "execution_count": 34, 956 | "metadata": {}, 957 | "output_type": "execute_result" 958 | } 959 | ], 960 | "source": [ 961 | "rs = np.random.RandomState(42)\n", 962 |
"orig_state = rs.get_state()\n", 963 | "\n", 964 | "draw_one_random_variable_RandomState(rs)\n", 965 | "np.allclose(rs.get_state()[1], orig_state[1])" 966 | ] 967 | }, 968 | { 969 | "cell_type": "markdown", 970 | "id": "bcb152e1-01a9-4689-9bcd-64eef2d61a6c", 971 | "metadata": {}, 972 | "source": [ 973 | "Instead, `jax` generates random variables in a functionally pure manner by:\n", 974 | "- requiring an explicit `PRNGKey`, a `RandomState` equivalent characterizing some pseudorandom state, to be passed to each sampling function;\n", 975 | "- never altering this key when generating a random variable: sampling twice with the same key yields the same draw.\n", 976 | "\n" 977 | ] 978 | }, 979 | { 980 | "cell_type": "markdown", 981 | "id": "c7033363-a038-487e-bf3d-b9775c409a17", 982 | "metadata": {}, 983 | "source": [ 984 | "- The consequence of this pure randomness generation is that new keys must be repeatedly generated to \"create\" new randomness.\n", 985 | "- Generating new keys from old ones is known in the pseudorandom number generation literature as splitting; jax provides the `random.split` utility for it."
986 | ] 987 | }, 988 | { 989 | "cell_type": "code", 990 | "execution_count": 35, 991 | "id": "09d05b35-7018-4530-9b44-e66f31e08f50", 992 | "metadata": {}, 993 | "outputs": [], 994 | "source": [ 995 | "#                      key\n", 996 | "#                       |\n", 997 | "#                       | random.split(key)\n", 998 | "#                       |\n", 999 | "#           -------------------------\n", 1000 | "#           |                       |\n", 1001 | "#           |                       |\n", 1002 | "#           |                       |\n", 1003 | "#        subkey1                 subkey2\n", 1004 | "#           |                       |\n", 1005 | "#           | random.split(subkey1) | random.split(subkey2)\n", 1006 | "#           |                       |\n", 1007 | "#     -------------           -------------\n", 1008 | "#     |           |           |           |\n", 1009 | "#     |           |           |           |\n", 1010 | "#     |           |           |           |\n", 1011 | "# subkey1.1   subkey1.2   subkey2.1   subkey2.2\n", 1012 | "#" 1013 | ] 1014 | }, 1015 | { 1016 | "cell_type": "code", 1017 | "execution_count": 36, 1018 | "id": "bb7037e4-253e-47e2-b794-d4b6cb7b96ab", 1019 | "metadata": {}, 1020 | "outputs": [], 1021 | "source": [ 1022 | "from jax import random\n", 1023 | "key = random.PRNGKey(42)\n", 1024 | "\n", 1025 | "key, subkey = random.split(key)\n", 1026 | "# provide a one-shot \"RandomState\"-like subkey to generate a random variable.\n", 1027 | "x1 = random.normal(subkey)\n", 1028 | "\n", 1029 | "# Reusing subkey reproduces exactly the same draw: once consumed, a key should not be reused for new randomness\n", 1030 | "x1_same = random.normal(subkey)\n", 1031 | "assert x1_same == x1\n", 1032 | "\n", 1033 | "# You're left with an \"untouched\" key that can be split again to generate new randomness\n", 1034 | "key, subkey = random.split(key)\n", 1035 | "\n", 1036 | "x2 = random.normal(subkey)\n", 1037 | "assert x2 != x1" 1038 | ] 1039 | }, 1040 | { 1041 | "cell_type": "markdown", 1042 | "id": "42a1af8c-eec9-4e90-8a78-37651b0b6600", 1043 | "metadata": {}, 1044 | "source": [ 1045 | "### Exercise 3: re-plot the KSD values obtained in the previous exercise, this time generating random arrays natively with jax!\n" 1046 | ] 1047 | }, 1048 | { 1049 | "cell_type": "code", 1050 | "execution_count": 37, 1051 | "id":
"61430432-a145-4e2f-a6ab-301c13f4ddf5", 1052 | "metadata": {}, 1053 | "outputs": [], 1054 | "source": [ 1055 | "from jax import random\n", 1056 | "\n", 1057 | "sample_sizes = (50, 100, 200, 300, 500, 1000, 2000, 5000)\n", 1058 | "\n", 1059 | "key = random.PRNGKey(0)\n", 1060 | "\n", 1061 | "ksd_vals = []\n", 1062 | "for N in sample_sizes:\n", 1063 | "    ksd_vals_this_iter = []\n", 1064 | "    for _ in range(5):\n", 1065 | "        key, subkey = random.split(key)\n", 1066 | "        X = random.normal(subkey, (N, 2))  # draw with the fresh subkey, keep key for further splits\n", 1067 | "        ksd_vals_this_iter.append(jit(ksd)(X))\n", 1068 | "    ksd_vals.append(jnp.mean(jnp.array(ksd_vals_this_iter)))" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "code", 1073 | "execution_count": 38, 1074 | "id": "94d5ec3d-aab8-446b-a63b-58e21972779b", 1075 | "metadata": {}, 1076 | "outputs": [ 1077 | { 1078 | "data": { 1079 | "text/plain": [ 1080 | "" 1081 | ] 1082 | }, 1083 | "execution_count": 38, 1084 | "metadata": {}, 1085 | "output_type": "execute_result" 1086 | }, 1087 | { 1088 | "data": { 1089 | "image/png": "\n", 1090 | "text/plain": [ 1091 | "
" 1092 | ] 1093 | }, 1094 | "metadata": {}, 1095 | "output_type": "display_data" 1096 | } 1097 | ], 1098 | "source": [ 1099 | "import matplotlib.pyplot as plt\n", 1100 | "plt.plot(1 / jnp.sqrt(jnp.array(sample_sizes)), ksd_vals, label=\"KSD\")\n", 1101 | "\n", 1102 | "coef = np.polyfit(1 / jnp.sqrt(jnp.array(sample_sizes)),ksd_vals,1)\n", 1103 | "poly1d_fn = np.poly1d(coef) \n", 1104 | "# poly1d_fn is now a function which takes in x and returns an estimate for y\n", 1105 | "\n", 1106 | "plt.plot(1 / jnp.sqrt(jnp.array(sample_sizes)), poly1d_fn(jnp.array(1 / jnp.sqrt(jnp.array(sample_sizes)))), '--k', label=\"linear fit\")\n", 1107 | "plt.xlabel(\"1/√N\")\n", 1108 | "plt.ylabel(\"KSD(X, p)\")\n", 1109 | "plt.legend()" 1110 | ] 1111 | }, 1112 | { 1113 | "cell_type": "markdown", 1114 | "id": "8aafd0bb-3c90-423a-b969-4ace884b7acf", 1115 | "metadata": {}, 1116 | "source": [ 1117 | "## Control Flow in JAX\n", 1118 | "\n", 1119 | "Understanding control flow in `jax` requires digging a bit into jax internals:" 1120 | ] 1121 | }, 1122 | { 1123 | "cell_type": "markdown", 1124 | "id": "4534d45e-92f9-472d-a93f-8790dd79c717", 1125 | "metadata": {}, 1126 | "source": [ 1127 | "### A primer on jax internals: static shapes, functional purity" 1128 | ] 1129 | }, 1130 | { 1131 | "cell_type": "markdown", 1132 | "id": "ad38c9fe-1f83-41de-be22-945cddfc1dda", 1133 | "metadata": {}, 1134 | "source": [ 1135 | "\n", 1136 | "- Jax works by transforming a functionally pure function into a directed acyclic computation graph.\n", 1137 | "- The interpreted nature of Python makes the creation of this graph a challenge: with compiled languages, such computation graphs are typically created ahead of time, during compilation.\n", 1138 | "- In interpreted languages, the structure of the computation graph is instead \"discovered\" at runtime, when executing the program."
1139 | ] 1140 | }, 1141 | { 1142 | "cell_type": "markdown", 1143 | "id": "3551379f-4939-4459-b35c-cec5e31c201f", 1144 | "metadata": {}, 1145 | "source": [ 1146 | "- To build this graph, jax relies on Python's dynamic (duck) typing: the inputs of jax functions are substituted with special jax objects called \"tracers\".\n", 1147 | "- These tracers progressively record the operations applied to them, gathering the information necessary to build a computation graph, called a \"Jax Intermediate Representation\", or JAX IR.\n", 1148 | "- This graph can then be traversed to perform graph transformations by defining appropriate translation rules, leading to `vmap`, `jit`, `grad`, etc.\n", 1149 | "\n", 1150 | "\n", 1151 | "JAX IRs are expressed in a first-order language in A-normal form (ANF), called `jaxpr`. An example is given below (credits to the JAX team for the slides):" 1152 | ] 1153 | }, 1154 | { 1155 | "cell_type": "markdown", 1156 | "id": "e0700bcf-6928-4298-b329-a501c130b159", 1157 | "metadata": {}, 1158 | "source": [ 1159 | "![title](./fun_to_jax_ir.png)" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "markdown", 1164 | "id": "4f26f401-37e9-4c90-8ddb-4016b07d7507", 1165 | "metadata": {}, 1166 | "source": [ 1167 | "### Constructing jax-friendly control flow" 1168 | ] 1169 | }, 1170 | { 1171 | "cell_type": "markdown", 1172 | "id": "abe45ea2-fc0a-4171-97d9-f4a371fcab71", 1173 | "metadata": {}, 1174 | "source": [ 1175 | "Additionally, at the time of writing, `jaxpr`s are specialized to specific input and output shapes: these shapes are required to be static, and cannot change depending on the values of the inputs. 
Thus, the following function:" 1176 | ] 1177 | }, 1178 | { 1179 | "cell_type": "code", 1180 | "execution_count": 39, 1181 | "id": "03e7e1b7-f98b-4a58-a283-b8c6a2b67b32", 1182 | "metadata": {}, 1183 | "outputs": [], 1184 | "source": [ 1185 | "def f(x):\n", 1186 | "    if x > 0:\n", 1187 | "        return jnp.ones((2,))\n", 1188 | "    else:\n", 1189 | "        return jnp.ones((3,))" 1190 | ] 1191 | }, 1192 | { 1193 | "cell_type": "markdown", 1194 | "id": "649b55df-9635-413c-836e-ebb561ee0067", 1195 | "metadata": {}, 1196 | "source": [ 1197 | "will not yield a valid `jaxpr`, as the shape of the output depends on dynamic values (namely, the value of the input).\n", 1198 | "Note, however, that because shapes are static, the following function:" 1199 | ] 1200 | }, 1201 | { 1202 | "cell_type": "code", 1203 | "execution_count": 40, 1204 | "id": "e39493c8-5931-4508-9718-50afaf57d5f3", 1205 | "metadata": {}, 1206 | "outputs": [], 1207 | "source": [ 1208 | "def f(x):\n", 1209 | "    if x.shape[0] == 2:\n", 1210 | "        return jnp.ones((2,))\n", 1211 | "    elif x.shape[0] == 3:\n", 1212 | "        return jnp.ones((3,))\n", 1213 | "    else:\n", 1214 | "        raise ValueError" 1215 | ] 1216 | }, 1217 | { 1218 | "cell_type": "markdown", 1219 | "id": "4ad67505-d14e-46c4-8ad6-5124d3a4e28b", 1220 | "metadata": {}, 1221 | "source": [ 1222 | "will yield a valid `jaxpr`, and can thus be subject to transformations.\n" 1223 | ] 1224 | }, 1225 | { 1226 | "cell_type": "markdown", 1227 | "id": "9488908a-ba58-49bf-a932-16fc49c6c7d7", 1228 | "metadata": {}, 1229 | "source": [ 1230 | "#### Jax Conditionals\n", 1231 | "\n", 1232 | "Note, however, that even a modified version of the first function, returning outputs with identical shapes on both branches:" 1233 | ] 1234 | }, 1235 | { 1236 | "cell_type": "code", 1237 | "execution_count": 41, 1238 | "id": "a259c020-0222-499b-9e15-7dbe8c4ffc73", 1239 | "metadata": {}, 1240 | "outputs": [], 1241 | "source": [ 1242 | "def f(x):\n", 1243 | "    if x > 0:\n", 1244 | "        return jnp.ones((2,))\n", 1245 |
"    else:\n", 1246 | "        return jnp.zeros((2,))" 1247 | ] 1248 | }, 1249 | { 1250 | "cell_type": "markdown", 1251 | "id": "4973897c-0388-49e9-8ea0-bb89683dad5d", 1252 | "metadata": {}, 1253 | "source": [ 1254 | "is fundamentally incompatible with a tracing-based mechanism: within a single function call, Python executes only one branch, so a tracer would visit only one branch,\n", 1255 | "leading to an incomplete construction of the function's computational graph (under `jit`, where `x` has no concrete value, the `if` statement raises an error instead).\n", 1256 | "\n", 1257 | "To solve this problem, jax exposes special control-flow primitives to be used in lieu of Python control flow. Here is a jax-compatible rewrite of `f`,\n", 1258 | "using `jax.lax.cond`, jax's native if-statement primitive:" 1259 | ] 1260 | }, 1261 | { 1262 | "cell_type": "code", 1263 | "execution_count": 42, 1264 | "id": "0678e724-7cc2-4bdd-ac0e-5b230b0b7d15", 1265 | "metadata": {}, 1266 | "outputs": [], 1267 | "source": [ 1268 | "import jax.lax\n", 1269 | "\n", 1270 | "def f(x):\n", 1271 | "    return jax.lax.cond(x > 0, lambda: jnp.ones((2,)), lambda: jnp.zeros((2,)))" 1272 | ] 1273 | }, 1274 | { 1275 | "cell_type": "code", 1276 | "execution_count": 43, 1277 | "id": "23fde93a-338d-45c6-930b-106b4834c171", 1278 | "metadata": {}, 1279 | "outputs": [ 1280 | { 1281 | "name": "stdout", 1282 | "output_type": "stream", 1283 | "text": [ 1284 | "f(-1.)=DeviceArray([0., 0.], dtype=float32)\n", 1285 | "f(2.)=DeviceArray([1., 1.], dtype=float32)\n" 1286 | ] 1287 | } 1288 | ], 1289 | "source": [ 1290 | "print(f\"{f(-1.)=}\")\n", 1291 | "print(f\"{f(2.)=}\")" 1292 | ] 1293 | }, 1294 | { 1295 | "cell_type": "markdown", 1296 | "id": "db9f0aed-2a14-4da5-9b1e-3fd5d4a76bc3", 1297 | "metadata": {}, 1298 | "source": [ 1299 | "#### Jax for-loops" 1300 | ] 1301 | }, 1302 | { 1303 | "cell_type": "markdown", 1304 | "id": "8aada819-eabe-4c3d-8f26-ef12fcc6048e", 1305 | "metadata": {}, 1306 | "source": [ 1307 | "For `for` loops, something different is at stake. 
Indeed, consider the following (jax-valid) function:" 1308 | ] 1309 | }, 1310 | { 1311 | "cell_type": "code", 1312 | "execution_count": 44, 1313 | "id": "8ca4099f-802d-4830-8e2a-90e72f5adf52", 1314 | "metadata": {}, 1315 | "outputs": [], 1316 | "source": [ 1317 | "def f(x):\n", 1318 | "    for i in range(100):\n", 1319 | "        x = x + 1\n", 1320 | "    return x" 1321 | ] 1322 | }, 1323 | { 1324 | "cell_type": "markdown", 1325 | "id": "dd13ca3b-7119-4268-b8c5-727f90815b19", 1326 | "metadata": {}, 1327 | "source": [ 1328 | "Unlike in the if-statement case, a jax tracer will visit `f`'s entire computational graph. However, the tracer will not be aware of the loop structure:\n", 1329 | "\n", 1330 | "- Looking at the resulting computational graph, the loop will appear \"unrolled\", i.e. each iteration yields its own independent sequence of operations.\n", 1331 | "- Statically unrolling Python for-loops becomes costly as the number of iterations grows, since the size of the traced graph, and hence the compilation time, grows with it." 1332 | ] 1333 | }, 1334 | { 1335 | "cell_type": "markdown", 1336 | "id": "cdd8878b-8bc5-4411-9391-954ccf784a74", 1337 | "metadata": {}, 1338 | "source": [ 1339 | "- As you may have guessed, jax exposes a special for-loop primitive, `jax.lax.fori_loop`, which lowers the entire loop to a single loop construct (an XLA `While` operation) instead of unrolling it.\n", 1340 | "- Using `jax.lax.fori_loop`, the user conveys much more information about the structure of the loop than with a dynamic Python for loop.\n", 1341 | "- In particular, the loop-carried value returned by `body_fun` must keep the same shape (and dtype) across iterations." 1342 | ] 1343 | }, 1344 | { 1345 | "cell_type": "code", 1346 | "execution_count": 45, 1347 | "id": "72a2b700-4cdc-428b-91b1-ac3bdb4063fb", 1348 | "metadata": {}, 1349 | "outputs": [], 1350 | "source": [ 1351 | "import jax.lax\n", 1352 | "\n", 1353 | "def f_fori_loop(x):\n", 1354 | "    def body_fun(i, x):\n", 1355 | "        return x + 1\n", 1356 | "    return jax.lax.fori_loop(0, 100, body_fun, x)" 1357 | ] 1358 | }, 1359 | { 1360 | "cell_type": "code", 1361 | "execution_count": 46, 1362
| "id": "984c76be-96c0-42c7-838f-4e0478f59c93", 1363 | "metadata": {}, 1364 | "outputs": [], 1365 | "source": [ 1366 | "assert f(1) == f_fori_loop(1)" 1367 | ] 1368 | }, 1369 | { 1370 | "cell_type": "markdown", 1371 | "id": "16e2ce1a-3534-4f12-9db4-7a4991f9a7d8", 1372 | "metadata": {}, 1373 | "source": [ 1374 | "Note that `fori_loop` only returns the final iterate, not the intermediate ones. To collect all intermediate iterates, use `jax.lax.scan` instead." 1375 | ] 1376 | }, 1377 | { 1378 | "cell_type": "markdown", 1379 | "id": "cda11e2e-3f8c-4de4-a25d-a5c4d76a0a2f", 1380 | "metadata": {}, 1381 | "source": [ 1382 | "### Exercise 4: implementing a Markov Chain Monte Carlo (MCMC) algorithm\n", 1383 | "\n", 1384 | "\n", 1385 | "- A Markov Chain Monte Carlo (MCMC) algorithm can be used to approximately sample from an unnormalized probability distribution $p$.\n", 1386 | "- An MCMC algorithm repeatedly generates iterates $x_i$ by drawing them from a **Markov kernel** $k(x_{i-1}, \\cdot)$, a distribution parametrized by $x_{i-1}$ which admits $p$ as an invariant distribution:\n", 1387 | "\n", 1388 | "$$\n", 1389 | "\\int k(x, y) p(x) \\, \\mathrm{d}x = p(y)\n", 1390 | "$$\n", 1391 | "\n", 1392 | "A famous family of kernels are Metropolis-Hastings (MH) kernels, which, given some input $x$ (typically $x_{i-1}$ within a loop), are parametrized by a proposal distribution $q(\\cdot | x)$ and are described by the following probabilistic program:\n", 1393 | "\n", 1394 | "\n", 1395 | "**How to sample from $k(x, \\cdot)$**\n", 1396 | "\n", 1397 | "- generate $x' \\sim q(\\cdot|x)$\n", 1398 | "- compute $\\alpha = \\frac{p(x')q(x | x')}{p(x)q(x'|x)}$\n", 1399 | "- with probability $\\min(\\alpha, 1)$, return $x'$; otherwise, return $x$."
1400 | ] 1401 | }, 1402 | { 1403 | "cell_type": "markdown", 1404 | "id": "6f6f72cc-b615-43d2-88f9-9ae146c0283a", 1405 | "metadata": {}, 1406 | "source": [ 1407 | "The goal of this exercise is to write an efficient implementation of an MCMC algorithm targeting the density $p(x)$ given below, using a random-walk MH kernel, i.e. an MH kernel characterized by the Gaussian proposal distribution $q(\\cdot | x) = \\mathcal N(x, \\sigma^2 I_2)$. \n", 1408 | "\n", 1409 | "Importantly, this proposal is symmetric, i.e. $q(x|x')=q(x'|x)$. Thus, the computation of $\\alpha$ reduces to $p(x')/p(x)$.\n", 1410 | " \n", 1411 | "\n", 1412 | "To complete this algorithm, you will have to:\n", 1413 | "- implement the random walk kernel\n", 1414 | "- create the for-loop representing the MCMC algorithm, using the `jax.lax.scan` primitive" 1415 | ] 1416 | }, 1417 | { 1418 | "cell_type": "code", 1419 | "execution_count": 47, 1420 | "id": "f5bf9063-affe-4c2c-8e80-09949d87e2d7", 1421 | "metadata": {}, 1422 | "outputs": [], 1423 | "source": [ 1424 | "def p(x):\n", 1425 | "\n", 1426 | "    # the unnormalized density of interest\n", 1427 | "    sigma = 1.\n", 1428 | "    return jnp.exp(-0.5/sigma**2 * jnp.sum(jnp.square(x - jnp.ones((2,)))))" 1429 | ] 1430 | }, 1431 | { 1432 | "cell_type": "code", 1433 | "execution_count": 48, 1434 | "id": "6d878b9d-068e-46d9-a74b-cbff7c7c19df", 1435 | "metadata": {}, 1436 | "outputs": [], 1437 | "source": [ 1438 | "# the proposal distribution q(\\cdot | x), up to a normalizing constant\n", 1439 | "def q(x, y, sigma=1.):\n", 1440 | "    return jnp.exp(-0.5/sigma**2 * jnp.sum(jnp.square(x - y)))\n", 1441 | "\n", 1442 | "def random_walk_kernel(x, key, sigma=1):\n", 1443 | "    # 1. generate x' from q(\\cdot | x). Hint: use a shifted and scaled sample from random.normal.\n", 1444 | "    # 2. compute \\alpha\n", 1445 | "    # 3. return x or x'.\n", 1446 | "    raise NotImplementedError" 1447 | ] 1448 | }, 1449 | { 1450 | "cell_type": "markdown", 1451 | "id": "f75e6c74-b459-46a2-9310-3378b3151c40", 1452 | "metadata": {}, 1453 | "source": [ 1454 | "Careful!
There should not be any native Python `if` statements! Try running a jitted version of `random_walk_kernel`:\n", 1455 | " \n" 1456 | ] 1457 | }, 1458 | { 1459 | "cell_type": "code", 1460 | "execution_count": 49, 1461 | "id": "f51cba4e-f50b-4853-ae61-b6ac4c6e2904", 1462 | "metadata": {}, 1463 | "outputs": [], 1477 | "source": [ 1478 | "x0 = jnp.zeros((2,))\n", 1479 | "key = random.PRNGKey(0)\n", 1480 | "\n", 1481 | "x1 = random_walk_kernel(x0, key)\n", 1482 | "x1_jitted = jit(random_walk_kernel)(x0, key)\n", 1483 | "\n", 1484 | "assert jnp.allclose(x1, x1_jitted)" 1485 | ] 1486 | }, 1487 | { 1488 | "cell_type": "markdown", 1489 | "id": "43495098-33bd-4b44-bd97-a21042aafb64", 1490 | "metadata": {}, 1491 | "source": [ 1492 | "For the next step, implement a full MCMC algorithm by iteratively sampling from this kernel.\n", 1493 | "Below is an equivalent implementation using a plain Python for-loop:" 1494 | ] 1495 | }, 1496 | { 1497 | "cell_type": "code", 1498 | "execution_count": null, 1499 | "id": "945d3e58-266e-4f9d-acfd-bd2b89c27071", 1500 | "metadata": {}, 1501 | "outputs": [], 1502 | "source": [ 1503 | "def rw_mcmc(num_steps, x0, key):\n", 1504 | "    x = x0\n", 1505 | "    iterates = [x]\n", 1506 | "    for i in range(num_steps):\n", 1507 | "        key, subkey = random.split(key)\n", 1508 | "        x = random_walk_kernel(x, subkey)\n", 1509 | "        iterates.append(x)\n", 1510 | "    return iterates" 1511 | ] 1512 | }, 1513 | { 1514 | "cell_type": "markdown", 1515 | "id": "009f8a91-7032-412f-a75d-b38ab7a684ce", 1516 | "metadata": {}, 1517 | "source": [ 1518 | "Fill in the following code:" 1519 | ] 1520 | }, 1521 | { 1522 | "cell_type": "code", 1523 | "execution_count": null, 1524 | "id": "87aeeb01-a985-4e26-9e72-09bed56ccbf2", 1525 | "metadata": {}, 1526 | "outputs": [], 1527 | "source": [ 1528 | "def rw_mcmc_jax_control_flow(num_steps, x0, key):\n", 1529 | "    x = x0\n", 1530 | "    iterates = [x]\n", 1531 | "\n", 1532 | "    def step_fn(carry, input_):\n", 1533 | "        output = ...
# do an MH kernel step\n", 1534 | "        return output, output\n", 1535 | "\n", 1536 | "    # call jax.lax.scan using step_fn as its step function.\n", 1537 | "    _, iterates = jax.lax.scan(...)\n", 1538 | "    return iterates" 1539 | ] 1540 | }, 1541 | { 1542 | "cell_type": "markdown", 1543 | "id": "3f963ff5-81b2-47d1-9144-90492a3d5560", 1544 | "metadata": {}, 1545 | "source": [ 1546 | "You should notice that the version of `rw_mcmc` using `jax.lax` control flow is much faster than its Python counterpart!" 1547 | ] 1548 | }, 1549 | { 1550 | "cell_type": "markdown", 1551 | "id": "0d96037f-e74f-402c-84d1-465a485315ac", 1552 | "metadata": {}, 1553 | "source": [ 1554 | "## Extending the flexibility of jax functions with PyTrees\n", 1555 | "\n", 1556 | "\n", 1557 | "### A Primer on PyTrees \n", 1558 | "\n", 1559 | "The earlier statement that valid `jax` functions can only take `jax` arrays as inputs and return `jax` arrays was a slight simplification. Actually, jax functions can take as inputs (and return) arbitrarily nested collections of JAX arrays. These nested collections are called PyTrees.\n", 1560 | "\n", 1561 | "- Collections that define valid PyTrees by default are tuples, namedtuples, lists and dictionaries. A PyTree is characterized by the structure of its nested collections (where the number of elements in each collection is a static attribute of the PyTree structure, i.e. it cannot change across instances of this PyTree structure). 
Such collections have a tree structure (hence the name), and the leaves are the arrays.\n", 1562 | "- The structure of a PyTree is called a `PyTreeDef`." 1563 | ] 1564 | }, 1565 | { 1566 | "cell_type": "code", 1567 | "execution_count": null, 1568 | "id": "38782b7d-6956-4905-8f24-014f73681128", 1569 | "metadata": {}, 1570 | "outputs": [], 1571 | "source": [ 1572 | "# Example\n", 1573 | "some_pytree = {'a': jnp.ones((2,)), 'b': jnp.ones((2,))}" 1574 | ] 1575 | }, 1576 | { 1577 | "cell_type": "markdown", 1578 | "id": "6becb5a7-3abc-4dc0-9335-ec7aa2b72f0c", 1579 | "metadata": {}, 1580 | "source": [ 1581 | "The principles applying to jax arrays mostly generalize to PyTrees. For instance, shapes must be known statically: the following function, whose two branches return lists of arrays with different shapes, is thus invalid jax code:" 1582 | ] 1583 | }, 1584 | { 1585 | "cell_type": "code", 1586 | "execution_count": null, 1587 | "id": "c8ca3b35-d4cd-4918-9df4-76d2b93b60bd", 1588 | "metadata": {}, 1589 | "outputs": [], 1590 | "source": [ 1591 | "def f():\n", 1592 | "    return jax.lax.cond(True, lambda: [jnp.ones((2,)), jnp.ones((3,))], lambda: [jnp.ones((3,)), jnp.ones((3,))])" 1593 | ] 1594 | }, 1595 | { 1596 | "cell_type": "markdown", 1597 | "id": "0d64ce41-986e-4e50-adb8-57ca62423d59", 1598 | "metadata": {}, 1599 | "source": [ 1600 | "Functions can be differentiated or vmapped with respect to PyTree arguments:" 1601 | ] 1602 | }, 1603 | { 1604 | "cell_type": "code", 1605 | "execution_count": null, 1606 | "id": "23648821-a677-4c8c-8e48-f37c1da6f75e", 1607 | "metadata": {}, 1608 | "outputs": [], 1609 | "source": [ 1610 | "def f(tree):\n", 1611 | "    return 2 * tree['a']['c'] * tree['b']" 1612 | ] 1613 | }, 1614 | { 1615 | "cell_type": "code", 1616 | "execution_count": null, 1617 | "id": "bccf93c7-9fc3-4b4b-a449-d12461f60e04", 1618 | "metadata": {}, 1619 | "outputs": [], 1620 | "source": [ 1621 | "grad_f = grad(f)\n", 1622 | "grad_f({'a': {'c': 10.}, 'b': -2.})" 1623 | ] 1624 | }, 1625 | { 1626 | "cell_type": "code", 1627 | "execution_count": null, 1628 | "id":
"bcc247d1-6d55-4769-b83d-153fa642c36f", 1629 | "metadata": {}, 1630 | "outputs": [], 1631 | "source": [ 1632 | "vmapped_f = vmap(f)\n", 1633 | "vmapped_f({'a': {'c': 10. * jnp.ones((10,))}, 'b': -2. * jnp.arange(10)})" 1634 | ] 1635 | }, 1636 | { 1637 | "cell_type": "markdown", 1638 | "id": "da0e5217-c07d-4dc0-8f4d-b7d494b9680e", 1639 | "metadata": {}, 1640 | "source": [ 1641 | "All such operations can be done only on part of the PyTree structure of each argument:" 1642 | ] 1643 | }, 1644 | { 1645 | "cell_type": "code", 1646 | "execution_count": null, 1647 | "id": "0d9446ba-ba53-4832-b7a8-165221182cac", 1648 | "metadata": {}, 1649 | "outputs": [], 1650 | "source": [ 1651 | "vmapped_f_partial = vmap(f, in_axes=({'a': {'c': 0}, 'b': None}, )) # a bit error-prone: this is currently a jax usability weak spot.\n", 1652 | "vmapped_f_partial({'a': {'c': 10. * jnp.ones((10,))}, 'b': -2.})" 1653 | ] 1654 | }, 1655 | { 1656 | "cell_type": "markdown", 1657 | "id": "927a6a98-60d7-4ca1-b786-991a66f2b18e", 1658 | "metadata": {}, 1659 | "source": [ 1660 | " pytrees are useful to represent the internal state (e.g. the weights, the batch norm statistics, etc) of a neural network, as often, neural network have a nested/hierechical/modular structure." 1661 | ] 1662 | }, 1663 | { 1664 | "cell_type": "markdown", 1665 | "id": "3aab100e-83a9-4fc3-ac47-6b7126ed7ce0", 1666 | "metadata": {}, 1667 | "source": [ 1668 | "### Convenients PyTrees\n", 1669 | "\n", 1670 | "Two important remarks on PyTree are in order:\n", 1671 | "\n", 1672 | "- NamedTuple are PyTree. Using namedtuples often prevent extenive kwargs plumbing in functions, by gathering all \"configuration\" arguments (as opposed to actual variables) to a function within a `Config` NamedTuple Umbrella (note that all attributes of the named tuple must be tracable!)." 
1673 | ] 1674 | }, 1675 | { 1676 | "cell_type": "code", 1677 | "execution_count": null, 1678 | "id": "8036103b-2eb7-4c82-835b-287849917fde", 1679 | "metadata": {}, 1680 | "outputs": [], 1681 | "source": [ 1682 | "from typing import NamedTuple\n", 1683 | "class MyAlgorithmConfig(NamedTuple):\n", 1684 | " learning_rate: float = 0.01\n", 1685 | " regularization_val: float = 0.001\n", 1686 | " # careful! if the training loop were implemented as a jax.lax.scan loop,\n", 1687 | " # num_iter would have to be static, as it determines the output shape of the scan loop\n", 1688 | " # num_iter = 100\n", 1689 | " ..." 1690 | ] 1691 | }, 1692 | { 1693 | "cell_type": "code", 1694 | "execution_count": null, 1695 | "id": "110ec796-b937-4f85-908a-888ae62e0ae8", 1696 | "metadata": {}, 1697 | "outputs": [], 1698 | "source": [ 1699 | "def train_my_network(X, y, my_algorithm_config: MyAlgorithmConfig,\n", 1700 | " # num_iter = 100\n", 1701 | " ):\n", 1702 | " ...\n", 1703 | " \n", 1704 | "# instead of\n", 1705 | "def train_my_network(X, y, learning_rate, regularization_val, \n", 1706 | " # num_iter = 100\n", 1707 | " ):\n", 1708 | " ..." 1709 | ] 1710 | }, 1711 | { 1712 | "cell_type": "markdown", 1713 | "id": "a1d21c30-1722-47e8-afa9-b703799e0d1f", 1714 | "metadata": {}, 1715 | "source": [ 1716 | "- Additional user-defined collection types can be registered as PyTree nodes, using `jax.tree_util.register_pytree_node_class` (or `jax.tree_util.register_pytree_node`).\n", 1717 | "- In addition, `flax`, a jax-powered neural network library, exposes a `struct.PyTreeNode` helper that makes it possible to:\n", 1718 | " - automatically register dataclass-like classes as PyTrees...\n", 1719 | " - ...while manually specifying which fields are static and which are traceable!
Thanks to this, `num_iter` from the previous example can now safely be included in the config, as a static field:\n", 1720 | " " 1721 | ] 1722 | }, 1723 | { 1724 | "cell_type": "code", 1725 | "execution_count": null, 1726 | "id": "fda057f5-afaf-4a2b-b737-cc3b9f8095c2", 1727 | "metadata": {}, 1728 | "outputs": [], 1729 | "source": [ 1730 | "from flax import struct\n", 1731 | "\n", 1732 | "class MyAlgorithmConfig(struct.PyTreeNode):\n", 1733 | " learning_rate: float = 0.01\n", 1734 | " regularization_val: float = 0.001\n", 1735 | " # num_iter can be marked as static!\n", 1736 | " num_iter: int = struct.field(pytree_node=False, default=100)" 1737 | ] 1738 | }, 1739 | { 1740 | "cell_type": "markdown", 1741 | "id": "6d550e43-baa4-4e97-9eb5-e17c3c943ef3", 1742 | "metadata": {}, 1743 | "source": [ 1744 | "## A final end-to-end example: Mixture of Gaussians" 1745 | ] 1746 | }, 1747 | { 1748 | "cell_type": "markdown", 1749 | "id": "8d682d61-4542-427b-b6e3-aad6e186e17b", 1750 | "metadata": {}, 1751 | "source": [ 1752 | "### What This Example Shows" 1753 | ] 1754 | }, 1755 | { 1756 | "cell_type": "markdown", 1757 | "id": "3bcf1ebc-122d-4266-bb07-08e1b783abab", 1758 | "metadata": {}, 1759 | "source": [ 1760 | "This final example showcases the capabilities of jax on an end-to-end algorithm with a few subtleties: fitting a mixture of Gaussians with the expectation-maximization (EM) algorithm.
In particular, this example demonstrates:\n", 1761 | "\n", 1762 | "- the functional nature of jax, through the use of functional index-update semantics (`x = x.at[i].set(y)`)\n", 1763 | "- the use of user-defined PyTrees, both as inputs and as outputs of the algorithm\n", 1764 | "- top-level use of `jax.vmap` to seamlessly vectorize the algorithm over random initializations" 1765 | ] 1766 | }, 1767 | { 1768 | "cell_type": "markdown", 1769 | "id": "524f1124-831e-46d0-ab3f-3d1a617c1fb7", 1770 | "metadata": {}, 1771 | "source": [ 1772 | "### Show Me the Code" 1773 | ] 1774 | }, 1775 | { 1776 | "cell_type": "code", 1777 | "execution_count": null, 1778 | "id": "4c45c09b-fbfe-4956-9e06-71ef7ddd9d93", 1779 | "metadata": {}, 1780 | "outputs": [], 1781 | "source": [ 1782 | "from typing import NamedTuple\n", 1783 | "\n", 1784 | "import jax\n", 1785 | "import jax.numpy as jnp\n", 1786 | "from flax import struct\n", 1787 | "from jax import random\n", 1788 | "from jax import vmap\n", 1789 | "from jax.nn import log_softmax, logsumexp\n", 1790 | "from jax.tree_util import tree_map\n", 1791 | "from jax.flatten_util import ravel_pytree # type: ignore\n", 1792 | "from numpyro import distributions as np_distributions\n", 1793 | "import numpy as np\n", 1794 | "\n", 1795 | "from typing import Any\n", 1796 | "Array, Numeric, PRNGKeyArray = Any, Any, Any" 1797 | ] 1798 | }, 1799 | { 1800 | "cell_type": "code", 1801 | "execution_count": null, 1802 | "id": "12fdf9e5-cb5a-44d6-a55c-b55b077be951", 1803 | "metadata": {}, 1804 | "outputs": [], 1805 | "source": [ 1806 | "class MOGDistribution(np_distributions.Distribution):\n", 1807 | " def __init__(self, cluster_means: Array, cluster_covs: Array, cluster_props: Array):\n", 1808 | " assert len(cluster_means.shape) == 2\n", 1809 | " self._num_clusters, self._num_dims = cluster_means.shape\n", 1810 | " self.cluster_means = cluster_means\n", 1811 | "\n", 1812 | " assert len(cluster_props.shape) == 1\n", 1813 | " self.cluster_props = cluster_props\n",
1814 | "\n", 1815 | " self.cluster_covs = cluster_covs\n", 1816 | "\n", 1817 | " self._dists = np_distributions.MultivariateNormal(\n", 1818 | " cluster_means, covariance_matrix=cluster_covs\n", 1819 | " )\n", 1820 | "\n", 1821 | " super(MOGDistribution, self).__init__(\n", 1822 | " batch_shape=(), event_shape=(self._num_dims,)\n", 1823 | " )\n", 1824 | "\n", 1825 | " def log_prob(self, x: Array):\n", 1826 | " assert len(x.shape) == 1\n", 1827 | " return logsumexp(self._dists.log_prob(x), b=self.cluster_props)\n", 1828 | "\n", 1829 | " def _sample_from_cluster_idx(self, key, idx):\n", 1830 | " return tree_map(lambda d: d[idx], self._dists).sample(key)\n", 1831 | "\n", 1832 | " def sample(self, key: PRNGKeyArray, sample_shape: tuple = ()) -> Array:\n", 1833 | " if sample_shape == tuple():\n", 1834 | " sample_shape = (1,)\n", 1835 | "\n", 1836 | " key, key_latent = random.split(key)\n", 1837 | "\n", 1838 | " mn = np_distributions.Categorical(probs=self.cluster_props)\n", 1839 | " idxs = mn.sample(key_latent, sample_shape=sample_shape)\n", 1840 | " keys_observed = random.split(key, num=sample_shape[0])\n", 1841 | " return vmap(self._sample_from_cluster_idx, in_axes=(0, 0))( # type: ignore\n", 1842 | " keys_observed, idxs\n", 1843 | " )" 1844 | ] 1845 | }, 1846 | { 1847 | "cell_type": "code", 1848 | "execution_count": null, 1849 | "id": "ed3896a3-f636-40b5-96ce-6c9d1a771af2", 1850 | "metadata": {}, 1851 | "outputs": [], 1852 | "source": [ 1853 | "class MOGTrainingConfig(struct.PyTreeNode):\n", 1854 | " num_clusters: int = struct.field(pytree_node=False)\n", 1855 | " max_iter: int = struct.field(pytree_node=False)\n", 1856 | " num_inits: int = struct.field(pytree_node=False)\n", 1857 | " min_std: float = struct.field(pytree_node=True, default=0.01)\n", 1858 | " max_train_samples: int = struct.field(pytree_node=False, default=1000)\n", 1859 | " cov_reg_param: float = struct.field(pytree_node=False, default=1e-6)" 1860 | ] 1861 | }, 1862 | { 1863 | "cell_type": "code", 1864
| "execution_count": null, 1865 | "id": "ec1d8153-fd45-4cfa-a9f0-2f8633c7e257", 1866 | "metadata": {}, 1867 | "outputs": [], 1868 | "source": [ 1869 | "def _kmeans_plus_plus_init(data: Array, num_clusters: int, key: PRNGKeyArray) -> Array:\n", 1870 | " num_points, num_dim = data.shape\n", 1871 | " clusters = jnp.empty((num_clusters, num_dim))\n", 1872 | "\n", 1873 | " init_cluster_data_idx = random.choice(key, a=num_points)\n", 1874 | " init_cluster_center = data[init_cluster_data_idx, :]\n", 1875 | "\n", 1876 | " this_cluster_center = init_cluster_center\n", 1877 | " clusters = clusters.at[0, :].set(this_cluster_center)\n", 1878 | "\n", 1879 | " all_sq_dists = jnp.inf * jnp.ones((num_points, num_clusters))\n", 1880 | "\n", 1881 | " for i in range(num_clusters - 1):\n", 1882 | " sq_dists = jnp.sum(jnp.square(data - this_cluster_center), axis=1)\n", 1883 | "\n", 1884 | " all_sq_dists = all_sq_dists.at[:, i].set(sq_dists)\n", 1885 | " min_sq_dists = jnp.min(all_sq_dists, axis=1)\n", 1886 | "\n", 1887 | " key, subkey = random.split(key)\n", 1888 | " next_cluster_idx = random.categorical(\n", 1889 | " subkey, logits=jnp.log(min_sq_dists + 1e-15)\n", 1890 | " )\n", 1891 | " next_cluster_center = data[next_cluster_idx]\n", 1892 | "\n", 1893 | " this_cluster_center = next_cluster_center\n", 1894 | " clusters = clusters.at[i + 1, :].set(this_cluster_center)\n", 1895 | "\n", 1896 | " return clusters" 1897 | ] 1898 | }, 1899 | { 1900 | "cell_type": "code", 1901 | "execution_count": null, 1902 | "id": "9f7c926c-9a55-4361-8fde-30a342a82055", 1903 | "metadata": {}, 1904 | "outputs": [], 1905 | "source": [ 1906 | "class MOGResult(NamedTuple):\n", 1907 | " # don't return a MOGDistribution because numpyro Distributions objects are not\n", 1908 | " # vmap-able\n", 1909 | " min_std: float\n", 1910 | " cluster_init: Array\n", 1911 | " cluster_means: Array\n", 1912 | " cluster_covs: Array\n", 1913 | " cluster_props: Array\n", 1914 | " log_probs: Array\n", 1915 | " final_log_prob: 
Numeric\n", 1916 | " converged: bool\n", 1917 | " num_iter_convergence: Numeric\n", 1918 | "\n", 1919 | " def to_dist(self) -> MOGDistribution:\n", 1920 | " return MOGDistribution(\n", 1921 | " self.cluster_means, self.cluster_covs, self.cluster_props\n", 1922 | " )" 1923 | ] 1924 | }, 1925 | { 1926 | "cell_type": "code", 1927 | "execution_count": null, 1928 | "id": "3d41eeda-49b6-4dec-8bb4-ab01d787dc88", 1929 | "metadata": {}, 1930 | "outputs": [], 1931 | "source": [ 1932 | "def _fit_one_mog(\n", 1933 | " data: Array, num_clusters: int, min_std: float, max_iter: int,\n", 1934 | " max_train_samples: int, cov_reg_param: float, key: PRNGKeyArray\n", 1935 | ") -> MOGResult:\n", 1936 | " num_points, num_dims = data.shape\n", 1937 | "\n", 1938 | " assert len(data.shape) == 2\n", 1939 | " # TODO: kmeans++?\n", 1940 | "\n", 1941 | " if data.shape[0] > max_train_samples:\n", 1942 | " key, subkey = random.split(key)\n", 1943 | " data = random.permutation(subkey, data, axis=0)\n", 1944 | " data = data[:max_train_samples]\n", 1945 | "\n", 1946 | " # init_cluster_data_idx = random.choice(key, a=num_points, shape=(num_clusters,))\n", 1947 | "\n", 1948 | " # cluster_means = data[init_cluster_data_idx]\n", 1949 | " key, key_init = random.split(key)\n", 1950 | " init_cluster_means = _kmeans_plus_plus_init(data, num_clusters, key_init)\n", 1951 | " init_cluster_covs = jnp.stack(\n", 1952 | " [jnp.eye(num_dims) for _ in range(num_clusters)], axis=0\n", 1953 | " )\n", 1954 | "\n", 1955 | " log_cluster_props = -np.log(num_clusters) * jnp.ones((num_clusters,))\n", 1956 | "\n", 1957 | " log_prob = prev_log_prob = -jnp.inf\n", 1958 | "\n", 1959 | " iter_no = 0\n", 1960 | " assert max_iter > 0\n", 1961 | "\n", 1962 | " converged = False\n", 1963 | " num_iter_convergence = 0\n", 1964 | "\n", 1965 | " log_probs = jnp.empty((max_iter,))\n", 1966 | "\n", 1967 | " cluster_means = init_cluster_means\n", 1968 | " cluster_covs = init_cluster_covs\n", 1969 | "\n", 1970 | "\n", 1971 | " for iter_no 
in range(max_iter):\n", 1972 | " dists = np_distributions.MultivariateNormal(\n", 1973 | " cluster_means,\n", 1974 | " covariance_matrix=cluster_covs,\n", 1975 | " )\n", 1976 | " log_joint = dists.log_prob(data[:, None, :]) + log_cluster_props[None, :]\n", 1977 | "\n", 1978 | " log_prob = logsumexp(log_joint, axis=1).mean()\n", 1979 | " log_probs = log_probs.at[iter_no].set(log_prob)\n", 1980 | "\n", 1981 | " # assert log_prob - prev_log_prob > -1e-6\n", 1982 | " converged = jnp.abs(log_prob - prev_log_prob) < 1e-4\n", 1983 | " num_iter_convergence += 1 - converged\n", 1984 | "\n", 1985 | " # E-step: compute posterior p(z=k|x) (num_points, num_clusters)\n", 1986 | " log_resps = log_softmax(log_joint, axis=1)\n", 1987 | "\n", 1988 | " # M-step\n", 1989 | " # data: (num_points, dim)\n", 1990 | " # normalized_data_weights: (num_points,num_clusters)\n", 1991 | " normalized_data_weights = jax.nn.softmax(log_resps, axis=0)\n", 1992 | " cluster_means = jnp.sum(data[:, None, :] * normalized_data_weights[:, :, None], axis=0)\n", 1993 | "\n", 1994 | " log_cluster_props = jax.nn.log_softmax(jax.nn.logsumexp(log_resps, axis=0))\n", 1995 | "\n", 1996 | " def _compute_cov(mean, weights):\n", 1997 | " return (\n", 1998 | " (data - mean).T @ jnp.diag(weights) @ (data - mean) + cov_reg_param * jnp.eye(num_dims)\n", 1999 | " )\n", 2000 | "\n", 2001 | " cluster_covs = vmap(_compute_cov, in_axes=(0, 1))(cluster_means, normalized_data_weights) # type: ignore\n", 2002 | "\n", 2003 | " prev_log_prob = log_prob\n", 2004 | "\n", 2005 | " def smooth_cov(cov_mat):\n", 2006 | " from jax.numpy.linalg import eigh\n", 2007 | " eigvals, eigvecs = eigh(cov_mat)\n", 2008 | " return eigvecs @ jnp.diag(jnp.clip(jnp.real(eigvals), a_min=min_std ** 2)) @ eigvecs.T\n", 2009 | "\n", 2010 | " cluster_covs = vmap(smooth_cov)(cluster_covs)\n", 2011 | " return MOGResult(\n", 2012 | " min_std,\n", 2013 | " init_cluster_means,\n", 2014 | " cluster_means,\n", 2015 | " cluster_covs,\n", 2016 | " 
jax.nn.softmax(log_cluster_props),\n", 2017 | " log_probs,\n", 2018 | " log_prob,\n", 2019 | " converged,\n", 2020 | " num_iter_convergence,\n", 2021 | " )" 2022 | ] 2023 | }, 2024 | { 2025 | "cell_type": "code", 2026 | "execution_count": null, 2027 | "id": "2a4cb854-b195-4aa4-b3f5-6beddf8bde9e", 2028 | "metadata": {}, 2029 | "outputs": [], 2030 | "source": [ 2031 | "def fit_mog(data: Array, config: MOGTrainingConfig, key: PRNGKeyArray) -> MOGResult:\n", 2032 | " keys = random.split(key, num=config.num_inits)\n", 2033 | " vmapped_fit = vmap(_fit_one_mog, in_axes=(None, None, None, None, None, None, 0)) # type: ignore\n", 2034 | " rets = vmapped_fit(data, config.num_clusters, config.min_std, config.max_iter,\n", 2035 | " config.max_train_samples, config.cov_reg_param, keys)\n", 2036 | "\n", 2037 | " print(rets.final_log_prob) # average log-likelihood reached by each initialization\n", 2038 | " best_fit_idx = jnp.argmax(rets.final_log_prob)\n", 2039 | " return tree_map(lambda l: l[best_fit_idx], rets)\n" 2040 | ] 2041 | }, 2042 | { 2043 | "cell_type": "code", 2044 | "execution_count": null, 2045 | "id": "bdd8aac9-31fe-48fd-bb8e-77437b83ef14", 2046 | "metadata": {}, 2047 | "outputs": [], 2048 | "source": [ 2049 | "data = jnp.concatenate([\n", 2050 | " np_distributions.Normal(0, 0.25).sample(random.PRNGKey(0), (1000, 2)),\n", 2051 | " np_distributions.Normal(1, 0.25).sample(random.PRNGKey(1), (1000, 2)),\n", 2052 | " np_distributions.Normal(2, 0.25).sample(random.PRNGKey(2), (1000, 2))\n", 2053 | " ], axis=0)\n", 2054 | "\n", 2055 | "# exercise: the data has 3 clusters but only 2 are fitted here; fix the fit by increasing the number of clusters\n", 2056 | "ret = fit_mog(data, MOGTrainingConfig(2, 100, 3, 0.01), random.PRNGKey(1))" 2057 | ] 2058 | }, 2059 | { 2060 | "cell_type": "code", 2061 | "execution_count": null, 2062 | "id": "6abf29a7-4fe2-49f8-9bc2-8c3877d013d9", 2063 | "metadata": {}, 2064 | "outputs": [], 2065 | "source": [ 2066 | "simulated_data = ret.to_dist().sample(random.PRNGKey(3), sample_shape=(len(data),))" 2067 | ] 2068 | }, 2069 | { 2070 | "cell_type": "code",
2071 | "execution_count": null, 2072 | "id": "537f7b68-85bf-4659-9a90-db1c000859ef", 2073 | "metadata": {}, 2074 | "outputs": [], 2075 | "source": [ 2076 | "import matplotlib.pyplot as plt\n", 2077 | "fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5))\n", 2078 | "ax1.scatter(data[:, 0], data[:, 1], s=3)\n", 2079 | "ax1.set_title('training data')\n", 2080 | "\n", 2081 | "ax2.scatter(simulated_data[:, 0], simulated_data[:, 1], s=3)\n", 2082 | "ax2.set_title('MoG-simulated data')" 2083 | ] 2084 | }, 2085 | { 2086 | "cell_type": "markdown", 2087 | "id": "f46bfced-631c-49a9-8c02-49ed03660f24", 2088 | "metadata": {}, 2089 | "source": [ 2090 | "#### Jax and Functional Purity\n", 2091 | "\n", 2092 | "As mentioned above, jax operates only on functionally pure programs (i.e. programs without side effects).\n", 2093 | "This \"no side effect\" policy, applied to the nested structure of computer programs, implies that variables cannot be modified in place.\n", 2094 | "\n", 2095 | "- Imposing such constraints was a natural first step for building a framework as complex as jax, as it keeps the surface area for compilation small.\n", 2096 | "- Functionally pure programs have a very handy property called \"referential transparency\": any function call can be replaced by its result without changing the behavior of the program. A referentially transparent program can be extensively optimized by compilers.\n", 2097 | "\n", 2098 | "The functional purity requirement imposed by jax forces a rethinking of Python programs, especially at the control-flow level.
In addition to having static shapes, branching and iteration need to be \"purified\", e.g. by expressing them with `jax.lax` control-flow primitives such as `jax.lax.cond` and `jax.lax.scan`." 2099 | ] 2100 | } 2101 | ], 2102 | "metadata": { 2103 | "kernelspec": { 2104 | "display_name": "Python 3 (ipykernel)", 2105 | "language": "python", 2106 | "name": "python3" 2107 | }, 2108 | "language_info": { 2109 | "codemirror_mode": { 2110 | "name": "ipython", 2111 | "version": 3 2112 | }, 2113 | "file_extension": ".py", 2114 | "mimetype": "text/x-python", 2115 | "name": "python", 2116 | "nbconvert_exporter": "python", 2117 | "pygments_lexer": "ipython3", 2118 | "version": "3.9.13" 2119 | } 2120 | }, 2121 | "nbformat": 4, 2122 | "nbformat_minor": 5 2123 | } 2124 | --------------------------------------------------------------------------------