├── readme_pics ├── jax_logo.png ├── going_deeper.jpg ├── tpu_jit_speed.PNG └── api_onion_structure.jpg ├── .gitignore ├── LICENCE ├── README.md ├── Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb └── Tutorial_4_Flax_Zero2Hero_Colab.ipynb /readme_pics/jax_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/get-started-with-JAX/HEAD/readme_pics/jax_logo.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm IDE 2 | .idea 3 | __pycache__ 4 | 5 | # Jupyter notebook checkpoints 6 | .ipynb_checkpoints 7 | 8 | -------------------------------------------------------------------------------- /readme_pics/going_deeper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/get-started-with-JAX/HEAD/readme_pics/going_deeper.jpg -------------------------------------------------------------------------------- /readme_pics/tpu_jit_speed.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/get-started-with-JAX/HEAD/readme_pics/tpu_jit_speed.PNG -------------------------------------------------------------------------------- /readme_pics/api_onion_structure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/get-started-with-JAX/HEAD/readme_pics/api_onion_structure.jpg -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Aleksa Gordić 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated 
documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Get started with JAX! :computer: :zap: 2 | 3 | The goal of this repo is to make it easier to get started with [JAX](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Haiku](https://github.com/deepmind/dm-haiku)! 4 | 5 | `JAX` ecosystem is becoming an increasingly popular alternative to `PyTorch` and `TensorFlow`. :sunglasses: 6 | 7 |
8 |
9 | 10 |

11 | 12 |

13 | 14 |
15 |
16 | 17 | *Note: I'm only going to recommend content that I've personally analyzed and found useful here. 18 | If you want a comprehensive list, check out the [awesome-jax repo](https://github.com/n2cholas/awesome-jax).* 19 | 20 | ## Table of Contents 21 | * [Machine Learning with JAX](#my-machine-learning-with-jax-tutorials) 22 | + [Tutorial #1: From Zero to Hero](#tutorial-1-from-zero-to-hero) 23 | + [Tutorial #2: From Hero to Hero Pro+](#tutorial-2-from-hero-to-heropro) 24 | + [Tutorial #3: Coding a Neural Network from Scratch in Pure JAX](#tutorial-3-building-a-neural-network-from-scratch) 25 | + [Tutorial #4: Flax From Zero to Hero](#tutorial-4-machine-learning-with-flax---from-zero-to-hero) 26 | + [Tutorial #5: Haiku From Zero to Hero (coming soon)](#tutorial-5-coming-up-machine-learning-with-haiku---from-zero-to-hero) 27 | * [Other useful JAX resources](#other-useful-content) 28 | 29 | ## My Machine Learning with JAX Tutorials 30 | 31 | *Tip on how to use notebooks: just open the notebook directly in Google Colab 32 | (you'll see a button on top of the Jupyter file which will direct you to Colab). 33 | This way you can avoid having to set up the Python env! (This was especially convenient for me since I'm on Windows, which is still not supported.)* 34 | 35 | ### Tutorial #1: From Zero to Hero 36 | 37 | In this video, we start from the basics and then gradually dig into the nitty-gritty details 38 | of `jit`, `grad`, `vmap`, and various other idiosyncrasies of JAX. 39 | 40 | [YouTube Video (Tutorial #1)](https://youtu.be/SstuvS-tVc0)
41 | [Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_1_JAX_Zero2Hero_Colab.ipynb)
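For a quick taste of the three transformations named above, here is a minimal sketch. It is not taken from the video or the notebook, and the toy function `sum_of_squares` is made up purely for illustration.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # A simple scalar-valued function of a vector (hypothetical example)
    return jnp.sum(x ** 2)


# grad: derivative of the scalar output w.r.t. x, which here is 2 * x
grad_fn = jax.grad(sum_of_squares)

# jit: compile the function with XLA; results match the un-jitted version
fast_fn = jax.jit(sum_of_squares)

# vmap: map the single-example function over a leading batch axis
batched_fn = jax.vmap(sum_of_squares)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2.0 * x])  # shape (2, 3)

value = fast_fn(x)
gradient = grad_fn(x)
batch_values = batched_fn(batch)
```

The transformations compose, so something like `jax.jit(jax.vmap(jax.grad(sum_of_squares)))` is also valid.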
42 | 43 |

44 | JAX from zero to hero! 46 |

47 | 48 | ### Tutorial #2: From Hero to HeroPro+ 49 | 50 | In this video, we learn all the additional components needed to train ML models (such as NNs) on multiple machines! 51 | We'll train a simple MLP model, and we'll even train an ML model on 8 TPU cores! 52 | 53 | [YouTube Video (Tutorial #2)](https://www.youtube.com/watch?v=CQQaifxuFcs)
54 | [Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_2_JAX_HeroPro%2B_Colab.ipynb)
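As a rough sketch of what running on several devices can look like, the snippet below uses `jax.pmap` with a collective `psum`. This is an assumption about the general approach rather than code from the notebook: the function names are hypothetical, and on a machine with a single device it simply runs with a device axis of size 1.

```python
import jax
import jax.numpy as jnp

# Number of local devices, e.g. 8 on a TPU runtime with 8 cores, 1 on a plain CPU
n_devices = jax.local_device_count()


def scale_and_sum(x):
    # Runs once per device on that device's shard of the input
    return jnp.sum(2.0 * x)


# The leading axis must equal the number of devices being mapped over
data = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

per_device = jax.pmap(scale_and_sum)(data)  # shape (n_devices,)


def normalized(x):
    # psum adds up the per-device sums across the named device axis
    total = jax.lax.psum(jnp.sum(x), axis_name="devices")
    return x / total


shares = jax.pmap(normalized, axis_name="devices")(data + 1.0)
```

The same `psum` pattern is what makes it possible to average gradients across devices in data-parallel training.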
55 | 56 |

57 | JAX from Hero to HeroPro+! 59 |

60 | 61 | ### Tutorial #3: Building a Neural Network from Scratch 62 | 63 | Watch me code a Neural Network from scratch in this 3rd video of the JAX tutorial series! :partying_face: 64 | 65 | In this video, I build an [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) and train it as a classifier on MNIST 66 | using PyTorch's data loader (although it's trivial to use a more complex dataset) - all this in "pure" JAX (no Flax/Haiku/Optax). 67 | 68 | I then do some additional analysis: 69 | * Visualize the MLP's learned weights 70 | * Visualize embeddings of a batch of images using t-SNE 71 | * Analyze whether we have too many dead ReLU neurons in our network 72 | 73 | [YouTube Video (Tutorial #3)](https://www.youtube.com/watch?v=6_PqUPxRmjY)
74 | [Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb) (Note: I'll soon refactor it but I'll link the original)
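The core of the "pure" JAX approach can be condensed into the sketch below. It follows the general shape of the notebook (list-of-pairs parameters, log-softmax output, `vmap` for batching) but is a simplified illustration, not the notebook's exact code: the names are shortened, the loss sums over classes before averaging over the batch, and the input is a random stand-in for MNIST images.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def init_mlp(layer_widths, key, scale=0.01):
    # One (weights, biases) pair per layer, each with its own PRNG key
    params = []
    keys = jax.random.split(key, len(layer_widths) - 1)
    for n_in, n_out, k in zip(layer_widths[:-1], layer_widths[1:], keys):
        w_key, b_key = jax.random.split(k)
        params.append((scale * jax.random.normal(w_key, (n_out, n_in)),
                       scale * jax.random.normal(b_key, (n_out,))))
    return params


def predict(params, x):
    # Forward pass for a single flattened image; returns log-probabilities
    hidden, (w_last, b_last) = params[:-1], params[-1]
    for w, b in hidden:
        x = jax.nn.relu(w @ x + b)
    logits = w_last @ x + b_last
    return logits - logsumexp(logits)  # log-softmax


batched_predict = jax.vmap(predict, in_axes=(None, 0))


def loss_fn(params, imgs, one_hot_lbls):
    # Cross-entropy: negative log-probability of the true class, batch-averaged
    log_probs = batched_predict(params, imgs)
    return -jnp.mean(jnp.sum(log_probs * one_hot_lbls, axis=1))


@jax.jit
def sgd_step(params, imgs, one_hot_lbls, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(params, imgs, one_hot_lbls)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return loss, new_params


key = jax.random.PRNGKey(0)
params = init_mlp([784, 512, 256, 10], key)
imgs = jax.random.normal(key, (16, 784))  # random stand-in for a batch of images
lbls = jax.nn.one_hot(jnp.arange(16) % 10, 10)
loss, params = sgd_step(params, imgs, lbls)
log_probs = batched_predict(params, imgs)
```

Because the parameters are a plain nested list, `jax.tree_util.tree_map` can apply the SGD update to every weight and bias in one call.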
75 | 76 |

77 | Building a Neural Network from Scratch in pure JAX! 79 |

80 | 81 | --- 82 | 83 | ### Tutorial #4: Machine Learning with Flax - From Zero to Hero 84 | 85 | In this video, I cover everything you need to know to get started with [Flax](https://github.com/google/flax)! 86 | 87 | We cover `init`, `apply`, and `TrainState`, as well as idiosyncrasies like the usage of the `mutable` and `rngs` keywords. 88 | 89 | [YouTube Video (Tutorial #4)](https://www.youtube.com/watch?v=5eUSmJvK8WA)
90 | [Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_4_Flax_Zero2Hero_Colab.ipynb)
91 | 92 |

93 | Flax from Zero to Hero! 95 |

96 | 97 | --- 98 | 99 | ### Tutorial #5 (coming up): Machine Learning with Haiku - From Zero to Hero 100 | 101 | todo 102 | 103 | ## Other useful content 104 | 105 | Aside from the [official docs](https://jax.readthedocs.io/), here are some resources that helped me. 106 | 107 | ### Videos 108 | 109 | * [Introduction to JAX](https://www.youtube.com/watch?v=0mVmRHMaOJ4&ab_channel=GoogleCloudTech) (gives a very high-level overview) 110 | * [JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas](https://www.youtube.com/watch?v=z-WSrQDXkuM&ab_channel=Enthought) (many more details) 111 | * [NeurIPS 2020: JAX Ecosystem Meetup](https://www.youtube.com/watch?v=iDxJxIyzSiM&t=1s&ab_channel=DeepMind) (the DeepMind team on the ecosystem of libs around JAX) 112 | * [Introduction to JAX for Machine Learning and More](https://www.youtube.com/watch?v=QkmKfzxbCLQ&ab_channel=UWaterlooDataScience) (nice, hands-on workshop) 113 | * [Day 1 Talks: JAX, Flax & Transformers | HuggingFace](https://www.youtube.com/watch?v=fuAyUQcVzTY&ab_channel=HuggingFace) (all 4 talks are good) 114 | * [Day 2 Talks: JAX, Flax & Transformers | HuggingFace](https://www.youtube.com/watch?v=__eG63ZP_5g&ab_channel=HuggingFace) (only the first 2 talks are relevant) 115 | 116 | ### Blogs 117 | 118 | * [Using JAX to accelerate our research | DeepMind](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) (similar info to the NeurIPS 2020 video) 119 | * [You don't know JAX | Colin Raffel](https://colinraffel.com/blog/you-don-t-know-jax.html) 120 | 121 | ## Acknowledgements 122 | 123 | * The notebooks were heavily inspired by the official [JAX](https://jax.readthedocs.io/), [Flax](https://flax.readthedocs.io/en/latest/), and [Haiku](https://dm-haiku.readthedocs.io/en/latest/) docs. 
124 | 125 | ## Citation 126 | 127 | If you find this content useful, please cite the following: 128 | 129 | ``` 130 | @misc{Gordic2021GetStartedWithJAX, 131 | author = {Gordić, Aleksa}, 132 | title = {Get started with JAX}, 133 | year = {2021}, 134 | publisher = {GitHub}, 135 | journal = {GitHub repository}, 136 | howpublished = {\url{https://github.com/gordicaleksa/get-started-with-JAX}}, 137 | } 138 | ``` 139 | 140 | ## Connect With Me 141 | 142 | If you'd love to have some more AI-related content in your life :nerd_face:, consider: 143 | * Subscribing to my YouTube channel [The AI Epiphany](https://www.youtube.com/c/TheAiEpiphany) :bell: 144 | * Follow me on [LinkedIn](https://www.linkedin.com/in/aleksagordic/) and [Twitter](https://twitter.com/gordic_aleksa) :bulb: 145 | * Follow me on [Medium](https://gordicaleksa.medium.com/) :books: :heart: 146 | * Join the [Discord](https://discord.gg/peBrCpheKE) community! :family: 147 | 148 | ## Licence 149 | 150 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/gordicaleksa/get-started-with-JAX/blob/master/LICENCE) -------------------------------------------------------------------------------- /Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Tutorial 3: JAX - Building a Neural Network from Scratch.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyMpTL6XC+tcxSqZ2FePUhlZ", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": 
"markdown", 33 | "metadata": { 34 | "id": "XZuyP-M3KPUR" 35 | }, 36 | "source": [ 37 | "# MLP training on MNIST" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "8-SzJ0NTKRP1" 44 | }, 45 | "source": [ 46 | "import numpy as np\n", 47 | "import jax.numpy as jnp\n", 48 | "from jax.scipy.special import logsumexp\n", 49 | "import jax\n", 50 | "from jax import jit, vmap, pmap, grad, value_and_grad\n", 51 | "\n", 52 | "from torchvision.datasets import MNIST\n", 53 | "from torch.utils.data import DataLoader" 54 | ], 55 | "execution_count": 1, 56 | "outputs": [] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "colab": { 62 | "base_uri": "https://localhost:8080/" 63 | }, 64 | "id": "G4NrxSVjKt8f", 65 | "outputId": "6bb8bef6-3098-4fd5-8ffe-62f4b0b1aa79" 66 | }, 67 | "source": [ 68 | "seed = 0\n", 69 | "mnist_img_size = (28, 28)\n", 70 | "\n", 71 | "def init_MLP(layer_widths, parent_key, scale=0.01):\n", 72 | "\n", 73 | " params = []\n", 74 | " keys = jax.random.split(parent_key, num=len(layer_widths)-1)\n", 75 | "\n", 76 | " for in_width, out_width, key in zip(layer_widths[:-1], layer_widths[1:], keys):\n", 77 | " weight_key, bias_key = jax.random.split(key)\n", 78 | " params.append([\n", 79 | " scale*jax.random.normal(weight_key, shape=(out_width, in_width)),\n", 80 | " scale*jax.random.normal(bias_key, shape=(out_width,))\n", 81 | " ]\n", 82 | " )\n", 83 | "\n", 84 | " return params\n", 85 | "\n", 86 | "# test\n", 87 | "key = jax.random.PRNGKey(seed)\n", 88 | "MLP_params = init_MLP([784, 512, 256, 10], key)\n", 89 | "print(jax.tree_map(lambda x: x.shape, MLP_params))" 90 | ], 91 | "execution_count": 4, 92 | "outputs": [ 93 | { 94 | "output_type": "stream", 95 | "name": "stdout", 96 | "text": [ 97 | "[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]\n" 98 | ] 99 | } 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "metadata": { 105 | "colab": { 106 | "base_uri": "https://localhost:8080/" 107 | }, 108 | 
"id": "U_z7eLxINv9x", 109 | "outputId": "e9909f9f-6778-4977-91f1-f5b14dd9ecd4" 110 | }, 111 | "source": [ 112 | "def MLP_predict(params, x):\n", 113 | " hidden_layers = params[:-1]\n", 114 | "\n", 115 | " activation = x\n", 116 | " for w, b in hidden_layers:\n", 117 | " activation = jax.nn.relu(jnp.dot(w, activation) + b)\n", 118 | "\n", 119 | " w_last, b_last = params[-1]\n", 120 | " logits = jnp.dot(w_last, activation) + b_last\n", 121 | "\n", 122 | " # log(exp(o1)) - log(sum(exp(o1), exp(o2), ..., exp(o10)))\n", 123 | " # log( exp(o1) / sum(...) )\n", 124 | " return logits - logsumexp(logits)\n", 125 | "\n", 126 | "# tests\n", 127 | "\n", 128 | "# test single example\n", 129 | "\n", 130 | "dummy_img_flat = np.random.randn(np.prod(mnist_img_size))\n", 131 | "print(dummy_img_flat.shape)\n", 132 | "\n", 133 | "prediction = MLP_predict(MLP_params, dummy_img_flat)\n", 134 | "print(prediction.shape)\n", 135 | "\n", 136 | "# test batched function\n", 137 | "batched_MLP_predict = vmap(MLP_predict, in_axes=(None, 0))\n", 138 | "\n", 139 | "dummy_imgs_flat = np.random.randn(16, np.prod(mnist_img_size))\n", 140 | "print(dummy_imgs_flat.shape)\n", 141 | "predictions = batched_MLP_predict(MLP_params, dummy_imgs_flat)\n", 142 | "print(predictions.shape)" 143 | ], 144 | "execution_count": 5, 145 | "outputs": [ 146 | { 147 | "output_type": "stream", 148 | "name": "stdout", 149 | "text": [ 150 | "(784,)\n", 151 | "(10,)\n", 152 | "(16, 784)\n", 153 | "(16, 10)\n" 154 | ] 155 | } 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "metadata": { 161 | "colab": { 162 | "base_uri": "https://localhost:8080/" 163 | }, 164 | "id": "5pPM1dZ4QyYe", 165 | "outputId": "3317666b-e167-46b7-8cf4-b8592adc065a" 166 | }, 167 | "source": [ 168 | "def custom_transform(x):\n", 169 | " return np.ravel(np.array(x, dtype=np.float32))\n", 170 | "\n", 171 | "def custom_collate_fn(batch):\n", 172 | " transposed_data = list(zip(*batch))\n", 173 | "\n", 174 | " labels = 
np.array(transposed_data[1])\n", 175 | "    imgs = np.stack(transposed_data[0])\n", 176 | "\n", 177 | "    return imgs, labels\n", 178 | "\n", 179 | "batch_size = 128\n", 180 | "train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)\n", 181 | "test_dataset = MNIST(root='test_mnist', train=False, download=True, transform=custom_transform)\n", 182 | "\n", 183 | "train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)\n", 184 | "test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)\n", 185 | "\n", 186 | "# test\n", 187 | "batch_data = next(iter(train_loader))\n", 188 | "imgs = batch_data[0]\n", 189 | "lbls = batch_data[1]\n", 190 | "print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)\n", 191 | "\n", 192 | "# optimization - loading the whole dataset into memory\n", 193 | "train_images = jnp.array(train_dataset.data).reshape(len(train_dataset), -1)\n", 194 | "train_lbls = jnp.array(train_dataset.targets)\n", 195 | "\n", 196 | "test_images = jnp.array(test_dataset.data).reshape(len(test_dataset), -1)\n", 197 | "test_lbls = jnp.array(test_dataset.targets)" 198 | ], 199 | "execution_count": null, 200 | "outputs": [ 201 | { 202 | "output_type": "stream", 203 | "name": "stdout", 204 | "text": [ 205 | "(128, 784) float32 (128,) int64\n" 206 | ] 207 | } 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "metadata": { 213 | "id": "YQEYcSNzVeim" 214 | }, 215 | "source": [ 216 | "num_epochs = 5\n", 217 | "\n", 218 | "def loss_fn(params, imgs, gt_lbls):\n", 219 | "    predictions = batched_MLP_predict(params, imgs)\n", 220 | "\n", 221 | "    return -jnp.mean(predictions * gt_lbls)\n", 222 | "\n", 223 | "def accuracy(params, dataset_imgs, dataset_lbls):\n", 224 | "    pred_classes = jnp.argmax(batched_MLP_predict(params, dataset_imgs), axis=1)\n", 225 | "    return jnp.mean(dataset_lbls == pred_classes)\n", 226 | "\n", 227 | 
"@jit\n", 228 | "def update(params, imgs, gt_lbls, lr=0.01):\n", 229 | "    loss, grads = value_and_grad(loss_fn)(params, imgs, gt_lbls)\n", 230 | "\n", 231 | "    return loss, jax.tree_util.tree_map(lambda p, g: p - lr*g, params, grads)\n", 232 | "\n", 233 | "# Create an MLP\n", 234 | "MLP_params = init_MLP([np.prod(mnist_img_size), 512, 256, len(MNIST.classes)], key)\n", 235 | "\n", 236 | "for epoch in range(num_epochs):\n", 237 | "\n", 238 | "    for cnt, (imgs, lbls) in enumerate(train_loader):\n", 239 | "\n", 240 | "        gt_labels = jax.nn.one_hot(lbls, len(MNIST.classes))\n", 241 | "\n", 242 | "        loss, MLP_params = update(MLP_params, imgs, gt_labels)\n", 243 | "\n", 244 | "        if cnt % 50 == 0:\n", 245 | "            print(loss)\n", 246 | "\n", 247 | "    print(f'Epoch {epoch}, train acc = {accuracy(MLP_params, train_images, train_lbls)} test acc = {accuracy(MLP_params, test_images, test_lbls)}')\n" 248 | ], 249 | "execution_count": null, 250 | "outputs": [] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "metadata": { 255 | "colab": { 256 | "base_uri": "https://localhost:8080/", 257 | "height": 316 258 | }, 259 | "id": "YmdBRBvU1wuA", 260 | "outputId": "efcfa75e-d0bb-4f16-9fb2-e85e82a53bcf" 261 | }, 262 | "source": [ 263 | "imgs, lbls = next(iter(test_loader))\n", 264 | "img = imgs[0].reshape(mnist_img_size)\n", 265 | "gt_lbl = lbls[0]\n", 266 | "print(img.shape)\n", 267 | "\n", 268 | "import matplotlib.pyplot as plt\n", 269 | "\n", 270 | "pred = jnp.argmax(MLP_predict(MLP_params, np.ravel(img)))\n", 271 | "print('pred', pred)\n", 272 | "print('gt', gt_lbl)\n", 273 | "\n", 274 | "plt.imshow(img); plt.show()" 275 | ], 276 | "execution_count": null, 277 | "outputs": [ 278 | { 279 | "output_type": "stream", 280 | "name": "stdout", 281 | "text": [ 282 | "(28, 28)\n", 283 | "pred 7\n", 284 | "gt 7\n" 285 | ] 286 | }, 287 | { 288 | "output_type": "display_data", 289 | "data": { 290 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANiklEQVR4nO3df4wc9XnH8c8n/kV8QGtDcF3j4ISQqE4aSHWBRNDKESUFImSiJBRLtVyJ5lALElRRW0QVBalVSlEIok0aySluHESgaQBhJTSNa6W1UKljg4yxgdaEmsau8QFOaxPAP/DTP24cHXD7vWNndmft5/2SVrs7z87Oo/F9PLMzO/t1RAjA8e9tbTcAoD8IO5AEYQeSIOxAEoQdSGJ6Pxc207PiBA31c5FAKq/qZzoYBzxRrVbYbV8s6XZJ0yT9bUTcXHr9CRrSeb6wziIBFGyIdR1rXe/G254m6auSLpG0WNIy24u7fT8AvVXnM/u5kp6OiGci4qCkeyQtbaYtAE2rE/YFkn4y7vnOatrr2B6xvcn2pkM6UGNxAOro+dH4iFgZEcMRMTxDs3q9OAAd1An7LkkLxz0/vZoGYADVCftGSWfZfpftmZKulLSmmbYANK3rU28Rcdj2tZL+SWOn3lZFxLbGOgPQqFrn2SPiQUkPNtQLgB7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGoN2Wx7h6T9kl6TdDgihptoCkDzaoW98rGIeKGB9wHQQ+zGA0nUDXtI+oHtR2yPTPQC2yO2N9nedEgHai4OQLfq7sZfEBG7bJ8maa3tpyJi/fgXRMRKSSsl6WTPjZrLA9ClWlv2iNhV3Y9Kul/SuU00BaB5XYfd9pDtk44+lvRxSVubagxAs+rsxs+TdL/to+/zrYj4fiNdAWhc12GPiGcknd1gLwB6iFNvQBKEHUiCsANJEHYgCcIOJNHEhTApvPjZj3asvXP508V5nxqdV6wfPDCjWF9wd7k+e+dLHWtHNj9RnBd5sGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zz5Ff/xH3+pY+9TQT8szn1lz4UvK5R2HX+5Yu/35j9Vc+LHrR6NndKwN3foLxXmnr3uk6XZax5YdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRP8GaTnZc+M8X9i35TXpZ58+r2PthQ+W/8+c82R5Hf/0V1ysz/zg/xbrt3zgvo61i97+SnHe7718YrH+idmdr5Wv65U4WKxvODBUrC854VDXy37P964u1t87srHr927ThlinfbF3wj8otuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATXs0/R0Hc2FGr13vvkerPrr39pScfan5+/qLzsfy3/5v0tS97TRUdTM/2VI8X60Jbdxfop6+8t1n91Zuff25+9o/xb/MejSbfstlfZHrW9ddy0ubbX2t5e3c/pbZsA6prKbvw3JF38hmk3SFoXEWdJWlc9BzDAJg17RKyXtPcNk5dKWl09Xi3p8ob7AtCwbj+zz4uIox+onpPUcTAz2yOSRiTpBM3ucnEA6qp9ND7GrqTpeKVHRKyMiOGIGJ6hWXUXB6BL3YZ9j+35klTdjzbXEoBe6DbsayStqB6vkPRAM+0A6JVJP7Pbvltjv1x+qu2dkr4g6WZJ37Z9laRnJV3RyyZRdvi5PR1rQ/d2rknSa5O899B3Xuyio2bs+b2PFuvvn1n+8/3S3vd1rC36u2eK8x4uVo9Nk4Y9IpZ1KB2bv0IBJMXXZYEkCDuQBGEHkiDsQBKEHUiCS1zRmulnLCzWv3LjV4r1GZ5WrP/D7b/ZsXbK7oeL8x6P2LIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBK
cZ0drnvrDBcX6h2eVh7LedrA8HPXcJ15+yz0dz9iyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGdHTx34xIc71h799G2TzF0eQej3r7uuWH/7v/1okvfPhS07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBeXb01H9f0nl7cqLL59GX/ddFxfrs7z9WrEexms+kW3bbq2yP2t46btpNtnfZ3lzdLu1tmwDqmspu/DckXTzB9Nsi4pzq9mCzbQFo2qRhj4j1kvb2oRcAPVTnAN21trdUu/lzOr3I9ojtTbY3HdKBGosDUEe3Yf+apDMlnSNpt6RbO70wIlZGxHBEDM+Y5MIGAL3TVdgjYk9EvBYRRyR9XdK5zbYFoGldhd32/HFPPylpa6fXAhgMk55nt323pCWSTrW9U9IXJC2xfY7GTmXukHR1D3vEAHvbSScV68t//aGOtX1HXi3OO/rFdxfrsw5sLNbxepOGPSKWTTD5jh70AqCH+LoskARhB5Ig7EAShB1IgrADSXCJK2rZftP7i/Xvnvo3HWtLt3+qOO+sBzm11iS27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZUfR/v/ORYn3Lb/9Vsf7jw4c61l76y9OL887S7mIdbw1bdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgvPsyU1f8MvF+vWf//tifZbLf0JXPra8Y+0d/8j16v3Elh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA8+3HO08v/xGd/d2ex/pkTXyzW79p/WrE+7/OdtydHinOiaZNu2W0vtP1D20/Y3mb7umr6XNtrbW+v7uf0vl0A3ZrKbvxhSZ+LiMWSPiLpGtuLJd0gaV1EnCVpXfUcwICaNOwRsTsiHq0e75f0pKQFkpZKWl29bLWky3vVJID63tJndtuLJH1I0gZJ8yLi6I+EPSdpXod5RiSNSNIJmt1tnwBqmvLReNsnSrpX0vURsW98LSJCUkw0X0SsjIjhiBieoVm1mgXQvSmF3fYMjQX9roi4r5q8x/b8qj5f0mhvWgTQhEl3421b0h2SnoyIL48rrZG0QtLN1f0DPekQ9Zz9vmL5z067s9bbf/WLnynWf/Gxh2u9P5ozlc/s50taLulx25uraTdqLOTftn2VpGclXdGbFgE0YdKwR8RDktyhfGGz7QDoFb4uCyRB2IEkCDuQBGEHkiDsQBJc4nocmLb4vR1rI/fU+/rD4lXXFOuL7vz3Wu+P/mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ79OPDUH3T+Yd/LZu/rWJuK0//lYPkFMeEPFGEAsWUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4z34MePWyc4v1dZfdWqgy5BbGsGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSSmMj77QknflDRPUkhaGRG3275J0mclPV+99MaIeLBXjWb2P+dPK9bfOb37c+l37T+tWJ+xr3w9O1ezHzum8qWaw5I+FxGP2j5J0iO211a12yLiS71rD0BTpjI++25Ju6vH+20/KWlBrxsD0Ky39Jnd9iJJH5K0oZp0re0ttlfZnvC3kWyP2N5ke9MhHajVLIDuTTnstk+UdK+k6yNin6SvSTpT0jka2/JP+AXtiFgZEcMRMTxDsxpoGUA3phR22zM0FvS7IuI+SYqIPRHxWkQckfR1SeWrNQC0atKw27akOyQ9GRFfHjd9/riXfVLS1ubbA9CUqRyNP1/SckmP295cTbtR0jLb52js7MsOSVf3pEPU8hcvLi7WH/6tRcV67H68wW7QpqkcjX9IkicocU4dOIbwDTogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Drl7sufGeb6wb8sDstkQ67Qv9k50qpwtO5AFYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dfz7Lafl/TsuEmnSnqhbw28NYPa26D2JdFbt5rs7YyIeMdEhb6G/U0LtzdFxHBrDRQMam+D2pd
Eb93qV2/sxgNJEHYgibbDvrLl5ZcMam+D2pdEb93qS2+tfmYH0D9tb9kB9AlhB5JoJey2L7b9H7aftn1DGz10YnuH7cdtb7a9qeVeVtketb113LS5ttfa3l7dTzjGXku93WR7V7XuNtu+tKXeFtr+oe0nbG+zfV01vdV1V+irL+ut75/ZbU+T9J+SLpK0U9JGScsi4om+NtKB7R2ShiOi9S9g2P4NSS9J+mZEfKCadoukvRFxc/Uf5ZyI+JMB6e0mSS+1PYx3NVrR/PHDjEu6XNLvqsV1V+jrCvVhvbWxZT9X0tMR8UxEHJR0j6SlLfQx8CJivaS9b5i8VNLq6vFqjf2x9F2H3gZCROyOiEerx/slHR1mvNV1V+irL9oI+wJJPxn3fKcGa7z3kPQD24/YHmm7mQnMi4jd1ePnJM1rs5kJTDqMdz+9YZjxgVl33Qx/XhcH6N7sgoj4NUmXSLqm2l0dSDH2GWyQzp1OaRjvfplgmPGfa3PddTv8eV1thH2XpIXjnp9eTRsIEbGruh+VdL8GbyjqPUdH0K3uR1vu5+cGaRjviYYZ1wCsuzaHP28j7BslnWX7XbZnSrpS0poW+ngT20PVgRPZHpL0cQ3eUNRrJK2oHq+Q9ECLvbzOoAzj3WmYcbW87lof/jwi+n6TdKnGjsj/WNKfttFDh77eLemx6rat7d4k3a2x3bpDGju2cZWkUyStk7Rd0j9LmjtAvd0p6XFJWzQWrPkt9XaBxnbRt0jaXN0ubXvdFfrqy3rj67JAEhygA5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk/h9BCfQTVPflJQAAAABJRU5ErkJggg==\n", 291 | "text/plain": [ 292 | "
" 293 | ] 294 | }, 295 | "metadata": { 296 | "needs_background": "light" 297 | } 298 | } 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": { 304 | "id": "TwgI3fZbKRqM" 305 | }, 306 | "source": [ 307 | "# Visualizations" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "metadata": { 313 | "colab": { 314 | "base_uri": "https://localhost:8080/", 315 | "height": 299 316 | }, 317 | "id": "jddJj8zo4D1e", 318 | "outputId": "fb157d1c-4fbe-45a5-c84d-6abe38355a5e" 319 | }, 320 | "source": [ 321 | "w = MLP_params[0][0]\n", 322 | "print(w.shape)\n", 323 | "\n", 324 | "w_single = w[500, :].reshape(mnist_img_size)\n", 325 | "print(w_single.shape)\n", 326 | "plt.imshow(w_single); plt.show()" 327 | ], 328 | "execution_count": null, 329 | "outputs": [ 330 | { 331 | "output_type": "stream", 332 | "name": "stdout", 333 | "text": [ 334 | "(512, 784)\n", 335 | "(28, 28)\n" 336 | ] 337 | }, 338 | { 339 | "output_type": "display_data", 340 | "data": { 341 | "image/png": "\n", 342 | "text/plain": [ 343 | "
" 344 | ] 345 | }, 346 | "metadata": { 347 | "needs_background": "light" 348 | } 349 | } 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "metadata": { 355 | "colab": { 356 | "base_uri": "https://localhost:8080/", 357 | "height": 484 358 | }, 359 | "id": "AZxm7G3j4iOS", 360 | "outputId": "521c3ad2-147d-4076-eea0-6537f32dafa0" 361 | }, 362 | "source": [ 363 | "# todo: visualize embeddings using t-SNE\n", 364 | "\n", 365 | "from sklearn.manifold import TSNE\n", 366 | "\n", 367 | "def fetch_activations(params, x):\n", 368 | " hidden_layers = params[:-1]\n", 369 | "\n", 370 | " activation = x\n", 371 | " for w, b in hidden_layers:\n", 372 | " activation = jax.nn.relu(jnp.dot(w, activation) + b)\n", 373 | "\n", 374 | " return activation\n", 375 | "\n", 376 | "batched_fetch_activations = vmap(fetch_activations, in_axes=(None, 0))\n", 377 | "imgs, lbls = next(iter(test_loader))\n", 378 | "\n", 379 | "batch_activations = batched_fetch_activations(MLP_params, imgs)\n", 380 | "print(batch_activations.shape) # (128, 2)\n", 381 | "\n", 382 | "t_sne_embeddings = TSNE(n_components=2, perplexity=30,).fit_transform(batch_activations)\n", 383 | "cora_label_to_color_map = {0: \"red\", 1: \"blue\", 2: \"green\", 3: \"orange\", 4: \"yellow\", 5: \"pink\", 6: \"gray\"}\n", 384 | "\n", 385 | "for class_id in range(10):\n", 386 | " plt.scatter(t_sne_embeddings[lbls == class_id, 0], t_sne_embeddings[lbls == class_id, 1], s=20, color=cora_label_to_color_map[class_id])\n", 387 | "plt.show()" 388 | ], 389 | "execution_count": null, 390 | "outputs": [ 391 | { 392 | "output_type": "stream", 393 | "name": "stdout", 394 | "text": [ 395 | "(128, 256)\n" 396 | ] 397 | }, 398 | { 399 | "output_type": "error", 400 | "ename": "KeyError", 401 | "evalue": "ignored", 402 | "traceback": [ 403 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 404 | "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", 405 | "\u001b[0;32m\u001b[0m in 
\u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mclass_id\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt_sne_embeddings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlbls\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mclass_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_sne_embeddings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlbls\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mclass_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcora_label_to_color_map\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mclass_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 406 | "\u001b[0;31mKeyError\u001b[0m: 7" 407 | ] 408 | }, 409 | { 410 | "output_type": "display_data", 411 | "data": { 412 | "image/png": "\n", 413 | "text/plain": [ 414 | "
" 415 | ] 416 | }, 417 | "metadata": { 418 | "needs_background": "light" 419 | } 420 | } 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "metadata": { 426 | "colab": { 427 | "base_uri": "https://localhost:8080/" 428 | }, 429 | "id": "MHL27HumNgwf", 430 | "outputId": "d44b1e9c-33d6-4dd3-cf05-b3a1e19f0194" 431 | }, 432 | "source": [ 433 | "# todo: dead neurons\n", 434 | "\n", 435 | "def fetch_activations2(params, x):\n", 436 | " hidden_layers = params[:-1]\n", 437 | " collector = []\n", 438 | "\n", 439 | " activation = x\n", 440 | " for w, b in hidden_layers:\n", 441 | " activation = jax.nn.relu(jnp.dot(w, activation) + b)\n", 442 | " collector.append(activation)\n", 443 | "\n", 444 | " return collector\n", 445 | "\n", 446 | "batched_fetch_activations2 = vmap(fetch_activations2, in_axes=(None, 0))\n", 447 | "\n", 448 | "imgs, lbls = next(iter(test_loader))\n", 449 | "\n", 450 | "MLP_params2 = init_MLP([np.prod(mnist_img_size), 512, 256, len(MNIST.classes)], key)\n", 451 | "\n", 452 | "batch_activations = batched_fetch_activations2(MLP_params2, imgs)\n", 453 | "print(batch_activations[1].shape) # (128, 512/256)\n", 454 | "\n", 455 | "dead_neurons = [np.ones(act.shape[1:]) for act in batch_activations]\n", 456 | "\n", 457 | "for layer_id, activations in enumerate(batch_activations):\n", 458 | " dead_neurons[layer_id] = np.logical_and(dead_neurons[layer_id], (activations == 0).all(axis=0))\n", 459 | "\n", 460 | "for layers in dead_neurons:\n", 461 | " print(np.sum(layers))" 462 | ], 463 | "execution_count": null, 464 | "outputs": [ 465 | { 466 | "output_type": "stream", 467 | "name": "stdout", 468 | "text": [ 469 | "(128, 256)\n", 470 | "0\n", 471 | "7\n" 472 | ] 473 | } 474 | ] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "metadata": { 479 | "id": "jMmOX-VSKTjQ" 480 | }, 481 | "source": [ 482 | "# Parallelization" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "metadata": { 488 | "id": "1aCkdHuhKUqV" 489 | }, 490 | "source": [ 491 | "" 492 
| ], 493 | "execution_count": null, 494 | "outputs": [] 495 | } 496 | ] 497 | } -------------------------------------------------------------------------------- /Tutorial_4_Flax_Zero2Hero_Colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "[](https://deepnote.com/launch?url=https%3A%2F%2Fgithub.com%2Fgordicaleksa%2Fget-started-with-JAX%2Fblob%2Fmain%2FTutorial_4_Flax_Zero2Hero_Colab.ipynb)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "\"Open" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "TbMr3-5oun69" 22 | }, 23 | "source": [ 24 | "# Flax: From Zero to Hero!\n", 25 | "\n", 26 | "This notebook heavily relies on the [official Flax docs](https://flax.readthedocs.io/en/latest/) and [examples](https://github.com/google/flax/blob/main/examples/) + some additional code/modifications, comments/notes, etc." 
27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "C1qve53yeof5" 33 | }, 34 | "source": [ 35 | "### Enter Flax - the basics ❤️\n", 36 | "\n", 37 | "Before you jump into the Flax world I strongly recommend you check out my JAX tutorials, as I won't be covering the details of JAX here.\n", 38 | "\n", 39 | "* (Tutorial 1) ML with JAX: From Zero to Hero ([video](https://www.youtube.com/watch?v=SstuvS-tVc0), [notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_1_JAX_Zero2Hero_Colab.ipynb))\n", 40 | "* (Tutorial 2) ML with JAX: from Hero to Hero Pro+ ([video](https://www.youtube.com/watch?v=CQQaifxuFcs), [notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_2_JAX_HeroPro%2B_Colab.ipynb))\n", 41 | "* (Tutorial 3) ML with JAX: Coding a Neural Network from Scratch in Pure JAX ([video](https://www.youtube.com/watch?v=6_PqUPxRmjY), [notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb))\n", 42 | "\n", 43 | "That out of the way - let's start with the basics!" 
44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 1, 49 | "metadata": { 50 | "id": "GHcasJkggdZN" 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "# Install Flax and JAX\n", 55 | "!pip install --upgrade -q \"jax[cuda11_cudnn805]\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n", 56 | "!pip install --upgrade -q git+https://github.com/google/flax.git\n", 57 | "!pip install --upgrade -q git+https://github.com/deepmind/dm-haiku # Haiku is here just for comparison purposes" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": { 64 | "id": "HmVx7EjigrEZ" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "import jax\n", 69 | "from jax import lax, random, numpy as jnp\n", 70 | "\n", 71 | "# NN lib built on top of JAX developed by Google Research (Brain team)\n", 72 | "# Flax was \"designed for flexibility\" hence the name (Flexibility + JAX -> Flax)\n", 73 | "import flax\n", 74 | "from flax.core import freeze, unfreeze\n", 75 | "from flax import linen as nn # nn notation also used in PyTorch and in Flax's older API\n", 76 | "from flax.training import train_state # a useful dataclass to keep train state\n", 77 | "\n", 78 | "# DeepMind's NN JAX lib - just for comparison purposes, we're not learning Haiku here\n", 79 | "import haiku as hk \n", 80 | "\n", 81 | "# JAX optimizers - a separate lib developed by DeepMind\n", 82 | "import optax\n", 83 | "\n", 84 | "# Flax doesn't have its own data loading functions - we'll be using PyTorch dataloaders\n", 85 | "from torchvision.datasets import MNIST\n", 86 | "from torch.utils.data import DataLoader\n", 87 | "\n", 88 | "# Python libs\n", 89 | "import functools # useful utilities for functional programs\n", 90 | "from typing import Any, Callable, Sequence, Optional\n", 91 | "\n", 92 | "# Other important 3rd party libs\n", 93 | "import numpy as np\n", 94 | "import matplotlib.pyplot as plt" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | 
"metadata": { 100 | "id": "aSDyQLgOesZp" 101 | }, 102 | "source": [ 103 | "The goal of this notebook is to get you started with Flax!\n", 104 | "\n", 105 | "I'll only cover the most essential parts of Flax (and Optax) - just as much as needed to get you started with training NNs!" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "id": "y1kdq0P_g7LU" 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "# Let's start with the simplest model possible: a single feed-forward layer (linear regression)\n", 117 | "model = nn.Dense(features=5)\n", 118 | "\n", 119 | "# All of the Flax NN layers inherit from the Module class (similarly to PyTorch)\n", 120 | "print(nn.Dense.__bases__)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": { 126 | "id": "ux9Okie5PWpw" 127 | }, 128 | "source": [ 129 | "So how can we do inference with this simple model? 2 steps: init and apply!" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 4, 135 | "metadata": { 136 | "id": "QViTvJhFite2" 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "# Step 1: init\n", 141 | "seed = 23\n", 142 | "key1, key2 = random.split(random.PRNGKey(seed))\n", 143 | "x = random.normal(key1, (10,)) # create a dummy input, a 10-dimensional random vector\n", 144 | "\n", 145 | "# Initialization call - this gives us the actual model weights \n", 146 | "# (remember JAX handles state externally!)\n", 147 | "y, params = model.init_with_output(key2, x) \n", 148 | "print(y)\n", 149 | "print(jax.tree_util.tree_map(lambda x: x.shape, params)) # the top-level jax.tree_map alias is deprecated\n", 150 | "\n", 151 | "# Note1: automatic shape inference\n", 152 | "# Note2: immutable structure (hence FrozenDict)\n", 153 | "# Note3: init_with_output if you care, for whatever reason, about the output here" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": { 160 | "id": 
"b3yFAqeTjdLj" 161 | }, 162 | "outputs": [], 163 | "source": [ 
164 | "# Step 2: apply\n", 165 | "y = model.apply(params, x) # this is how you run prediction in Flax, state is external!\n", 166 | "print(y)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "id": "31O_mx-Smalq" 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "try:\n", 178 | " y = model(x) # this doesn't work anymore (bye bye PyTorch syntax)\n", 179 | "except Exception as e:\n", 180 | " print(e)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": { 187 | "id": "fQYyv76sCJ25" 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "# todo: a small coding exercise - let's contrast Flax with Haiku" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": { 198 | "cellView": "form", 199 | "id": "UWr3hdpmFBng" 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "#@title Haiku vs Flax solution\n", 204 | "model = hk.transform(lambda x: hk.Linear(output_size=5)(x))\n", 205 | "\n", 206 | "seed = 23\n", 207 | "key1, key2 = random.split(random.PRNGKey(seed))\n", 208 | "x = random.normal(key1, (10,)) # create a dummy input, a 10-dimensional random vector\n", 209 | "\n", 210 | "params = model.init(key2, x)\n", 211 | "out = model.apply(params, None, x)\n", 212 | "print(out)\n", 213 | "\n", 214 | "print(hk.Linear.__bases__)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": { 220 | "id": "wWBxTShUiLzW" 221 | }, 222 | "source": [ 223 | "All of this might (initially!) 
be overwhelming if you're used to a stateful, object-oriented paradigm.\n", 224 | "\n", 225 | "What Flax offers is high performance and flexibility (similarly to JAX).\n", 226 | "\n", 227 | "Here are some [benchmark numbers](https://github.com/huggingface/transformers/tree/master/examples/flax/text-classification) from the HuggingFace team.\n", 228 | "\n", 229 | "![image.png]()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "id": "eUBYtd40krx1" 236 | }, 237 | "source": [ 238 | "Now that we have an answer to \"why should I learn Flax?\" - let's start our descent into Flaxlandia!\n", 239 | "\n", 240 | "### A toy example 🚚 - training a linear regression model\n", 241 | "\n", 242 | "We'll first implement a pure-JAX approach and then we'll do it the Flax-way." 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": { 249 | "id": "53-TXcbYkt9D" 250 | }, 251 | "outputs": [], 252 | "source": [ 253 | "# Defining a toy dataset\n", 254 | "n_samples = 150\n", 255 | "x_dim = 2 # putting small numbers here so that we can visualize the data easily\n", 256 | "y_dim = 1\n", 257 | "noise_amplitude = 0.1\n", 258 | "\n", 259 | "# Generate (random) ground truth W and b\n", 260 | "# Note: we could get W, b from a randomly initialized nn.Dense here, being closer to JAX for now \n", 261 | "key, w_key, b_key = random.split(random.PRNGKey(seed), num=3)\n", 262 | "W = random.normal(w_key, (x_dim, y_dim)) # weight\n", 263 | "b = random.normal(b_key, (y_dim,)) # bias\n", 264 | "\n", 265 | "# This is the structure that Flax expects (recall from the previous section!)\n", 266 | "true_params = freeze({'params': {'bias': b, 'kernel': W}})\n", 267 | "\n", 268 | "# Generate samples with additional noise\n", 269 | "key, x_key, noise_key = random.split(key, num=3)\n", 270 | "xs = random.normal(x_key, (n_samples, x_dim))\n", 271 | "ys = jnp.dot(xs, W) + b\n", 272 | "ys += noise_amplitude * random.normal(noise_key, 
(n_samples, y_dim))\n", 273 | "print(f'xs shape = {xs.shape} ; ys shape = {ys.shape}')" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 9, 279 | "metadata": { 280 | "colab": { 281 | "base_uri": "https://localhost:8080/", 282 | "height": 266 283 | }, 284 | "id": "lc4-xoIapKCs", 285 | "outputId": "52656571-0aa5-4c6f-f522-83b0158c1b97" 286 | }, 287 | "outputs": [ 288 | { 289 | "data": { 290 | "text/plain": [ 291 | "" 292 | ] 293 | }, 294 | "execution_count": 9, 295 | "metadata": {}, 296 | "output_type": "execute_result" 297 | }, 298 | { 299 | "data": { 300 | "image/png": "\n", 301 | "text/plain": [ 302 | "
" 303 | ] 304 | }, 305 | "metadata": { 306 | "needs_background": "light" 307 | }, 308 | "output_type": "display_data" 309 | } 310 | ], 311 | "source": [ 312 | "# Let's visualize our data (becoming one with the data paradigm <3)\n", 313 | "fig = plt.figure()\n", 314 | "ax = fig.add_subplot(111, projection='3d')\n", 315 | "assert xs.shape[-1] == 2 and ys.shape[-1] == 1 # low dimensional data so that we can plot it\n", 316 | "ax.scatter(xs[:, 0], xs[:, 1], zs=ys)\n", 317 | "\n", 318 | "# todo: exercise - let's show that our data lies on the 2D plane embedded in 3D\n", 319 | "# option 1: analytic approach\n", 320 | "# option 2: data-driven approach" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 22, 326 | "metadata": { 327 | "id": "mKiCOyoikxcM" 328 | }, 329 | "outputs": [], 330 | "source": [ 331 | "def make_mse_loss(xs, ys):\n", 332 | " \n", 333 | " def mse_loss(params):\n", 334 | " \"\"\"Gives the value of the loss on the (xs, ys) dataset for the given model (params).\"\"\"\n", 335 | " \n", 336 | " # Define the squared loss for a single pair (x,y)\n", 337 | " def squared_error(x, y):\n", 338 | " pred = model.apply(params, x)\n", 339 | " # Inner because 'y' could have in general more than 1 dims\n", 340 | " return jnp.inner(y-pred, y-pred) / 2.0\n", 341 | "\n", 342 | " # Batched version via vmap\n", 343 | " return jnp.mean(jax.vmap(squared_error)(xs, ys), axis=0)\n", 344 | "\n", 345 | " return jax.jit(mse_loss) # and finally we jit the result (mse_loss is a pure function)\n", 346 | "\n", 347 | "mse_loss = make_mse_loss(xs, ys)\n", 348 | "value_and_grad_fn = jax.value_and_grad(mse_loss)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": { 355 | "id": "phLYjH5ZkzLn" 356 | }, 357 | "outputs": [], 358 | "source": [ 359 | "# Let's reuse the simple feed-forward layer since it trivially implements linear regression\n", 360 | "model = nn.Dense(features=y_dim)\n", 361 | "params = model.init(key, 
xs)\n", 362 | "print(f'Initial params = {params}')\n", 363 | "\n", 364 | "# Let's set some reasonable hyperparams\n", 365 | "lr = 0.3\n", 366 | "epochs = 20\n", 367 | "log_period_epoch = 5\n", 368 | "\n", 369 | "print('-' * 50)\n", 370 | "for epoch in range(epochs):\n", 371 | " loss, grads = value_and_grad_fn(params)\n", 372 | " # SGD (closer to JAX again, but we'll progressively go towards how stuff is done in Flax)\n", 373 | " # Note: jax.tree_multimap was removed from JAX; tree_map accepts multiple trees\n", 373 | " params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)\n", 374 | "\n", 375 | " if epoch % log_period_epoch == 0:\n", 376 | " print(f'epoch {epoch}, loss = {loss}')\n", 377 | "\n", 378 | "print('-' * 50)\n", 379 | "print(f'Learned params = {params}')\n", 380 | "print(f'Gt params = {true_params}')" 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": { 386 | "id": "rvy6Oow2lLHu" 387 | }, 388 | "source": [ 389 | "Now let's do the same thing but this time with dedicated optimizers!\n", 390 | "\n", 391 | "Enter DeepMind's Optax! ❤️🔥" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": { 398 | "id": "5hhcFZ7UlCov" 399 | }, 400 | "outputs": [], 401 | "source": [ 402 | "opt_sgd = optax.sgd(learning_rate=lr)\n", 403 | "opt_state = opt_sgd.init(params) # always the same pattern - handling state externally\n", 404 | "print(opt_state)\n", 405 | "# todo: exercise - compare Adam's and SGD's states" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": { 412 | "id": "t_EHHjy_lFGN" 413 | }, 414 | "outputs": [], 415 | "source": [ 416 | "params = model.init(key, xs) # let's start with fresh params again\n", 417 | "\n", 418 | "for epoch in range(epochs):\n", 419 | " loss, grads = value_and_grad_fn(params)\n", 420 | " updates, opt_state = opt_sgd.update(grads, opt_state) # arbitrary optim logic!\n", 421 | " params = optax.apply_updates(params, updates)\n", 422 | "\n", 423 | " if epoch % log_period_epoch == 0:\n", 424 | " 
print(f'epoch {epoch}, loss = {loss}')\n", 425 | "\n", 426 | "# Note 1: as expected we get the same loss values\n", 427 | "# Note 2: we'll later see more concise ways to handle all of these state components (hint: TrainState)" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": { 433 | "id": "QF1gAYSzxQ1R" 434 | }, 435 | "source": [ 436 | "In this toy SGD example Optax may not seem that useful but it's very powerful.\n", 437 | "\n", 438 | "You can build arbitrary optimizers with arbitrary hyperparam schedules, chaining, param freezing, etc. You can check the [official docs here](https://optax.readthedocs.io/en/latest/)." 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 8, 444 | "metadata": { 445 | "cellView": "form", 446 | "id": "rKbis5O0KQYH" 447 | }, 448 | "outputs": [], 449 | "source": [ 450 | "#@title Optax Advanced Examples\n", 451 | "# This cell won't \"compile\" (no ml_collections package) and serves just as an example\n", 452 | "\n", 453 | "# Example from Flax (ImageNet example)\n", 454 | "# https://github.com/google/flax/blob/main/examples/imagenet/train.py#L88\n", 455 | "def create_learning_rate_fn(\n", 456 | " config: ml_collections.ConfigDict,\n", 457 | " base_learning_rate: float,\n", 458 | " steps_per_epoch: int):\n", 459 | " \"\"\"Create learning rate schedule.\"\"\"\n", 460 | " warmup_fn = optax.linear_schedule(\n", 461 | " init_value=0., end_value=base_learning_rate,\n", 462 | " transition_steps=config.warmup_epochs * steps_per_epoch)\n", 463 | " cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1)\n", 464 | " cosine_fn = optax.cosine_decay_schedule(\n", 465 | " init_value=base_learning_rate,\n", 466 | " decay_steps=cosine_epochs * steps_per_epoch)\n", 467 | " schedule_fn = optax.join_schedules(\n", 468 | " schedules=[warmup_fn, cosine_fn],\n", 469 | " boundaries=[config.warmup_epochs * steps_per_epoch])\n", 470 | " return schedule_fn\n", 471 | "\n", 472 | "tx = optax.sgd(\n", 473 | " 
learning_rate=learning_rate_fn,\n", 474 | " momentum=config.momentum,\n", 475 | " nesterov=True,\n", 476 | ")\n", 477 | "\n", 478 | "# Example from Haiku (ImageNet example)\n", 479 | "# https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py#L116\n", 480 | "def make_optimizer() -> optax.GradientTransformation:\n", 481 | " \"\"\"SGD with nesterov momentum and a custom lr schedule.\"\"\"\n", 482 | " return optax.chain(\n", 483 | " optax.trace(\n", 484 | " decay=FLAGS.optimizer_momentum,\n", 485 | " nesterov=FLAGS.optimizer_use_nesterov),\n", 486 | " optax.scale_by_schedule(lr_schedule), optax.scale(-1))" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": { 492 | "id": "WFAeHIEwL0ZH" 493 | }, 494 | "source": [ 495 | "Now let's go beyond these extremely simple models!" 496 | ] 497 | }, 498 | { 499 | "cell_type": "markdown", 500 | "metadata": { 501 | "id": "7_33y-bTl6bd" 502 | }, 503 | "source": [ 504 | "### Creating custom models ⭐" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": null, 510 | "metadata": { 511 | "id": "JOrJHqTSl75M" 512 | }, 513 | "outputs": [], 514 | "source": [ 515 | "class MLP(nn.Module):\n", 516 | " num_neurons_per_layer: Sequence[int] # data field (nn.Module is Python's dataclass)\n", 517 | "\n", 518 | " def setup(self): # because dataclass is implicitly using the __init__ function... 
:')\n", 519 | " self.layers = [nn.Dense(n) for n in self.num_neurons_per_layer]\n", 520 | "\n", 521 | " def __call__(self, x):\n", 522 | " activation = x\n", 523 | " for i, layer in enumerate(self.layers):\n", 524 | " activation = layer(activation)\n", 525 | " if i != len(self.layers) - 1:\n", 526 | " activation = nn.relu(activation)\n", 527 | " return activation\n", 528 | "\n", 529 | "x_key, init_key = random.split(random.PRNGKey(seed))\n", 530 | "\n", 531 | "model = MLP(num_neurons_per_layer=[16, 8, 1]) # define an MLP model\n", 532 | "x = random.uniform(x_key, (4,4)) # dummy input\n", 533 | "params = model.init(init_key, x) # initialize via init\n", 534 | "y = model.apply(params, x) # do a forward pass via apply\n", 535 | "\n", 536 | "print(jax.tree_util.tree_map(jnp.shape, params))\n", 537 | "print(f'Output: {y}')\n", 538 | "\n", 539 | "# todo: exercise - use @nn.compact pattern instead\n", 540 | "# todo: check out https://realpython.com/python-data-classes/" 541 | ] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": { 546 | "id": "TEhC-WdPnAYp" 547 | }, 548 | "source": [ 549 | "Great! 
\n", 550 | "\n", 551 | "Now that we know how to build more complex models let's dive deeper and understand how the 'nn.Dense' module itself is designed.\n", 552 | "\n", 553 | "#### Introducing \"param\"" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": null, 559 | "metadata": { 560 | "id": "Z9YhSgxjnBQg" 561 | }, 562 | "outputs": [], 563 | "source": [ 564 | "class MyDenseImp(nn.Module):\n", 565 | " num_neurons: int\n", 566 | " weight_init: Callable = nn.initializers.lecun_normal()\n", 567 | " bias_init: Callable = nn.initializers.zeros\n", 568 | "\n", 569 | " @nn.compact\n", 570 | " def __call__(self, x):\n", 571 | " weight = self.param('weight', # parameter name (as it will appear in the FrozenDict)\n", 572 | " self.weight_init, # initialization function, RNG passed implicitly through init fn\n", 573 | " (x.shape[-1], self.num_neurons)) # shape info\n", 574 | " bias = self.param('bias', self.bias_init, (self.num_neurons,))\n", 575 | "\n", 576 | " return jnp.dot(x, weight) + bias\n", 577 | "\n", 578 | "x_key, init_key = random.split(random.PRNGKey(seed))\n", 579 | "\n", 580 | "model = MyDenseImp(num_neurons=3) # initialize the model\n", 581 | "x = random.uniform(x_key, (4,4)) # dummy input\n", 582 | "params = model.init(init_key, x) # initialize via init\n", 583 | "y = model.apply(params, x) # do a forward pass via apply\n", 584 | "\n", 585 | "print(jax.tree_util.tree_map(jnp.shape, params))\n", 586 | "print(f'Output: {y}')\n", 587 | "\n", 588 | "# todo: exercise - check out the source code:\n", 589 | "# https://github.com/google/flax/blob/main/flax/linen/linear.py\n", 590 | "# https://github.com/google/jax/blob/main/jax/_src/nn/initializers.py#L150 <- to see why lecun_normal() vs zeros (no brackets)" 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": null, 596 | "metadata": { 597 | "id": "AqCPhl9fBI_Z" 598 | }, 599 | "outputs": [], 600 | "source": [ 601 | "from inspect import signature\n", 602 | "\n", 603 | "# You can 
see it expects a PRNG key and it is passed implicitly through the init fn (same for zeros)\n", 604 | "print(signature(nn.initializers.lecun_normal()))" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": { 610 | "id": "MWB8HvLHn6g0" 611 | }, 612 | "source": [ 613 | "So far we've only seen **trainable** params. \n", 614 | "\n", 615 | "ML models often have variables which are part of the state but are not optimized via gradient descent.\n", 616 | "\n", 617 | "Let's see how we can handle them using a simple (and contrived) example!\n", 618 | "\n", 619 | "#### Introducing \"variable\"\n", 620 | "\n", 621 | "*Note on terminology: variable is a broader term and it includes both params (trainable variables) and non-trainable vars.*" 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": null, 627 | "metadata": { 628 | "id": "oGE6qTHHngYh" 629 | }, 630 | "outputs": [], 631 | "source": [ 632 | "class BiasAdderWithRunningMean(nn.Module):\n", 633 | " decay: float = 0.99\n", 634 | "\n", 635 | " @nn.compact\n", 636 | " def __call__(self, x):\n", 637 | " is_initialized = self.has_variable('batch_stats', 'ema')\n", 638 | "\n", 639 | " # 'batch_stats' is not an arbitrary name!\n", 640 | " # Flax uses that name in its implementation of BatchNorm (hard-coded, probably not the best of designs?)\n", 641 | " ema = self.variable('batch_stats', 'ema', lambda shape: jnp.zeros(shape), x.shape[1:])\n", 642 | "\n", 643 | " # self.param will by default add this variable to 'params' collection (vs 'batch_stats' above)\n", 644 | " # Again, some idiosyncrasies here: we need to pass a key even though we don't actually use it...\n", 645 | " bias = self.param('bias', lambda key, shape: jnp.zeros(shape), x.shape[1:])\n", 646 | "\n", 647 | " if is_initialized:\n", 648 | " # self.variable returns a reference hence .value\n", 649 | " ema.value = self.decay * ema.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)\n", 650 | "\n", 651 | " 
return x - ema.value + bias\n", 652 | "\n", 653 | "x_key, init_key = random.split(random.PRNGKey(seed))\n", 654 | "\n", 655 | "model = BiasAdderWithRunningMean()\n", 656 | "x = random.uniform(x_key, (10,4)) # dummy input\n", 657 | "variables = model.init(init_key, x)\n", 658 | "print(f'Multiple collections = {variables}') # we can now see a new collection 'batch_stats'\n", 659 | "\n", 660 | "# We have to use mutable since regular params are not modified during the forward\n", 661 | "# pass, but these variables are. We can't keep state internally (because JAX) so we have to return it.\n", 662 | "y, updated_non_trainable_params = model.apply(variables, x, mutable=['batch_stats'])\n", 663 | "print(updated_non_trainable_params)" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": null, 669 | "metadata": { 670 | "id": "PuzwVt8RoHvY" 671 | }, 672 | "outputs": [], 673 | "source": [ 674 | "# Let's see how we could train such a model!\n", 675 | "def update_step(opt, apply_fn, x, opt_state, params, non_trainable_params):\n", 676 | "\n", 677 | " def loss_fn(params):\n", 678 | " y, updated_non_trainable_params = apply_fn(\n", 679 | " {'params': params, **non_trainable_params}, \n", 680 | " x, mutable=list(non_trainable_params.keys()))\n", 681 | " \n", 682 | " loss = ((x - y) ** 2).sum() # not doing anything really, just for demo purposes\n", 683 | "\n", 684 | " return loss, updated_non_trainable_params\n", 685 | "\n", 686 | " (loss, non_trainable_params), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)\n", 687 | " updates, opt_state = opt.update(grads, opt_state)\n", 688 | " params = optax.apply_updates(params, updates)\n", 689 | " \n", 690 | " return opt_state, params, non_trainable_params # all of these represent the state - ugly, for now\n", 691 | "\n", 692 | "model = BiasAdderWithRunningMean()\n", 693 | "x = jnp.ones((10,4)) # dummy input, using ones because it's easier to see what's going on\n", 694 | "\n", 695 | "variables = 
model.init(random.PRNGKey(seed), x)\n", 696 | "non_trainable_params, params = variables.pop('params')\n", 697 | "del variables # delete variables to avoid wasting resources (this pattern is used in the official code)\n", 698 | "\n", 699 | "sgd_opt = optax.sgd(learning_rate=0.1) # originally you'll see them use the 'tx' naming (from opTaX)\n", 700 | "opt_state = sgd_opt.init(params)\n", 701 | "\n", 702 | "for _ in range(3):\n", 703 | " # We'll later see how TrainState abstraction will make this step much more elegant!\n", 704 | " opt_state, params, non_trainable_params = update_step(sgd_opt, model.apply, x, opt_state, params, non_trainable_params)\n", 705 | " print(non_trainable_params)" 706 | ] 707 | }, 708 | { 709 | "cell_type": "markdown", 710 | "metadata": { 711 | "id": "gzWUq5vBrWMe" 712 | }, 713 | "source": [ 714 | "Let's go a level up in abstraction again now that we understand params and variables!\n", 715 | "\n", 716 | "Certain layers like BatchNorm will use variables in the background.\n", 717 | "\n", 718 | "Let's see a last example that is conceptually as complicated as it gets when it comes to Flax's idiosyncrasies, and high-level at the same time." 
719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": null, 724 | "metadata": { 725 | "id": "rDw2986orY0a" 726 | }, 727 | "outputs": [], 728 | "source": [ 729 | "class DDNBlock(nn.Module):\n", 730 | " \"\"\"Dense, dropout + batchnorm combo.\n", 731 | "\n", 732 | " Contains trainable variables (params), non-trainable variables (batch stats),\n", 733 | " and stochasticity in the forward pass (because of dropout).\n", 734 | " \"\"\"\n", 735 | " num_neurons: int\n", 736 | " training: bool\n", 737 | "\n", 738 | " @nn.compact\n", 739 | " def __call__(self, x):\n", 740 | " x = nn.Dense(self.num_neurons)(x)\n", 741 | " x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)\n", 742 | " x = nn.BatchNorm(use_running_average=not self.training)(x)\n", 743 | " return x\n", 744 | "\n", 745 | "key1, key2, key3, key4 = random.split(random.PRNGKey(seed), 4)\n", 746 | "\n", 747 | "model = DDNBlock(num_neurons=3, training=True)\n", 748 | "x = random.uniform(key1, (3,4,4))\n", 749 | "\n", 750 | "# New: because of Dropout we now have to include its unique key - kinda weird, but you get used to it\n", 751 | "variables = model.init({'params': key2, 'dropout': key3}, x)\n", 752 | "print(variables)\n", 753 | "\n", 754 | "# And same here, everything else remains the same as the previous example\n", 755 | "y, non_trainable_params = model.apply(variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])\n", 756 | "\n", 757 | "# Let's run these model variables during \"evaluation\":\n", 758 | "eval_model = DDNBlock(num_neurons=3, training=False)\n", 759 | "# Because training=False we don't have stochasticity in the forward pass neither do we update the stats\n", 760 | "y = eval_model.apply(variables, x)" 761 | ] 762 | }, 763 | { 764 | "cell_type": "markdown", 765 | "metadata": { 766 | "id": "Ys1y-yM8vzT8" 767 | }, 768 | "source": [ 769 | "### A fully-fledged CNN on MNIST example in Flax! 
💥\n", 770 | "\n", 771 | "Modified the official MNIST example here: https://github.com/google/flax/tree/main/examples/mnist\n", 772 | "\n", 773 | "We'll be using PyTorch dataloading instead of TFDS.\n", 774 | "\n", 775 | "Let's start by defining a model:" 776 | ] 777 | }, 778 | { 779 | "cell_type": "code", 780 | "execution_count": 3, 781 | "metadata": { 782 | "id": "MD8t9K2Nv0yC" 783 | }, 784 | "outputs": [], 785 | "source": [ 786 | "class CNN(nn.Module): # lots of hardcoding, but it serves a purpose for a simple demo\n", 787 | " @nn.compact\n", 788 | " def __call__(self, x):\n", 789 | " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", 790 | " x = nn.relu(x)\n", 791 | " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", 792 | " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", 793 | " x = nn.relu(x)\n", 794 | " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", 795 | " x = x.reshape((x.shape[0], -1)) # flatten\n", 796 | " x = nn.Dense(features=256)(x)\n", 797 | " x = nn.relu(x)\n", 798 | " x = nn.Dense(features=10)(x)\n", 799 | " x = nn.log_softmax(x)\n", 800 | " return x" 801 | ] 802 | }, 803 | { 804 | "cell_type": "markdown", 805 | "metadata": { 806 | "id": "rVgWLMhiSAYv" 807 | }, 808 | "source": [ 809 | "Let's add the data loading support in PyTorch!\n", 810 | "\n", 811 | "I'll be reusing code from [tutorial #3](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb):" 812 | ] 813 | }, 814 | { 815 | "cell_type": "code", 816 | "execution_count": 4, 817 | "metadata": { 818 | "id": "UZ-og2UOUUWD" 819 | }, 820 | "outputs": [], 821 | "source": [ 822 | "def custom_transform(x):\n", 823 | " # A couple of modifications here compared to tutorial #3 since we're using a CNN\n", 824 | " # Input: (28, 28) uint8 [0, 255] torch.Tensor, Output: (28, 28, 1) float32 [0, 1] np array\n", 825 | " return np.expand_dims(np.array(x, dtype=np.float32), axis=2) / 255.\n", 826 | "\n", 827 | "def 
custom_collate_fn(batch):\n", 828 | " \"\"\"Provides us with batches of numpy arrays and not PyTorch's tensors.\"\"\"\n", 829 | " transposed_data = list(zip(*batch))\n", 830 | "\n", 831 | " labels = np.array(transposed_data[1])\n", 832 | " imgs = np.stack(transposed_data[0])\n", 833 | "\n", 834 | " return imgs, labels\n", 835 | "\n", 836 | "mnist_img_size = (28, 28, 1)\n", 837 | "batch_size = 128\n", 838 | "\n", 839 | "train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)\n", 840 | "test_dataset = MNIST(root='test_mnist', train=False, download=True, transform=custom_transform)\n", 841 | "\n", 842 | "train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)\n", 843 | "test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)\n", 844 | "\n", 845 | "# optimization - loading the whole dataset into memory\n", 846 | "train_images = jnp.array(train_dataset.data)\n", 847 | "train_lbls = jnp.array(train_dataset.targets)\n", 848 | "\n", 849 | "# np.expand_dims is to convert shape from (10000, 28, 28) -> (10000, 28, 28, 1)\n", 850 | "# We don't have to do this for training images because custom_transform does it for us.\n", 851 | "test_images = np.expand_dims(jnp.array(test_dataset.data), axis=3)\n", 852 | "test_lbls = jnp.array(test_dataset.targets)" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "execution_count": 5, 858 | "metadata": { 859 | "colab": { 860 | "base_uri": "https://localhost:8080/", 861 | "height": 282 862 | }, 863 | "id": "2HeXX51NU0k6", 864 | "outputId": "43dad5bf-20c2-4c5a-9705-12b2e422f915" 865 | }, 866 | "outputs": [ 867 | { 868 | "name": "stdout", 869 | "output_type": "stream", 870 | "text": [ 871 | "7\n" 872 | ] 873 | }, 874 | { 875 | "data": { 876 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANiklEQVR4nO3df4wc9XnH8c8n/kV8QGtDcF3j4ISQqE4aSHWBRNDKESUFImSiJBRLtVyJ5lALElRRW0QVBalVSlEIok0aySluHESgaQBhJTSNa6W1UKljg4yxgdaEmsau8QFOaxPAP/DTP24cHXD7vWNndmft5/2SVrs7z87Oo/F9PLMzO/t1RAjA8e9tbTcAoD8IO5AEYQeSIOxAEoQdSGJ6Pxc207PiBA31c5FAKq/qZzoYBzxRrVbYbV8s6XZJ0yT9bUTcXHr9CRrSeb6wziIBFGyIdR1rXe/G254m6auSLpG0WNIy24u7fT8AvVXnM/u5kp6OiGci4qCkeyQtbaYtAE2rE/YFkn4y7vnOatrr2B6xvcn2pkM6UGNxAOro+dH4iFgZEcMRMTxDs3q9OAAd1An7LkkLxz0/vZoGYADVCftGSWfZfpftmZKulLSmmbYANK3rU28Rcdj2tZL+SWOn3lZFxLbGOgPQqFrn2SPiQUkPNtQLgB7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGoN2Wx7h6T9kl6TdDgihptoCkDzaoW98rGIeKGB9wHQQ+zGA0nUDXtI+oHtR2yPTPQC2yO2N9nedEgHai4OQLfq7sZfEBG7bJ8maa3tpyJi/fgXRMRKSSsl6WTPjZrLA9ClWlv2iNhV3Y9Kul/SuU00BaB5XYfd9pDtk44+lvRxSVubagxAs+rsxs+TdL/to+/zrYj4fiNdAWhc12GPiGcknd1gLwB6iFNvQBKEHUiCsANJEHYgCcIOJNHEhTApvPjZj3asvXP508V5nxqdV6wfPDCjWF9wd7k+e+dLHWtHNj9RnBd5sGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zz5Ff/xH3+pY+9TQT8szn1lz4UvK5R2HX+5Yu/35j9Vc+LHrR6NndKwN3foLxXmnr3uk6XZax5YdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRP8GaTnZc+M8X9i35TXpZ58+r2PthQ+W/8+c82R5Hf/0V1ysz/zg/xbrt3zgvo61i97+SnHe7718YrH+idmdr5Wv65U4WKxvODBUrC854VDXy37P964u1t87srHr927ThlinfbF3wj8otuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATXs0/R0Hc2FGr13vvkerPrr39pScfan5+/qLzsfy3/5v0tS97TRUdTM/2VI8X60Jbdxfop6+8t1n91Zuff25+9o/xb/MejSbfstlfZHrW9ddy0ubbX2t5e3c/pbZsA6prKbvw3JF38hmk3SFoXEWdJWlc9BzDAJg17RKyXtPcNk5dKWl09Xi3p8ob7AtCwbj+zz4uIox+onpPUcTAz2yOSRiTpBM3ucnEA6qp9ND7GrqTpeKVHRKyMiOGIGJ6hWXUXB6BL3YZ9j+35klTdjzbXEoBe6DbsayStqB6vkPRAM+0A6JVJP7Pbvltjv1x+qu2dkr4g6WZJ37Z9laRnJV3RyyZRdvi5PR1rQ/d2rknSa5O899B3Xuyio2bs+b2PFuvvn1n+8/3S3vd1rC36u2eK8x4uVo9Nk4Y9IpZ1KB2bv0IBJMXXZYEkCDuQBGEHkiDsQBKEHUiCS1zRmulnLCzWv3LjV4r1GZ5WrP/D7b/ZsXbK7oeL8x6P2LIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBK
cZ0drnvrDBcX6h2eVh7LedrA8HPXcJ15+yz0dz9iyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGdHTx34xIc71h799G2TzF0eQej3r7uuWH/7v/1okvfPhS07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBeXb01H9f0nl7cqLL59GX/ddFxfrs7z9WrEexms+kW3bbq2yP2t46btpNtnfZ3lzdLu1tmwDqmspu/DckXTzB9Nsi4pzq9mCzbQFo2qRhj4j1kvb2oRcAPVTnAN21trdUu/lzOr3I9ojtTbY3HdKBGosDUEe3Yf+apDMlnSNpt6RbO70wIlZGxHBEDM+Y5MIGAL3TVdgjYk9EvBYRRyR9XdK5zbYFoGldhd32/HFPPylpa6fXAhgMk55nt323pCWSTrW9U9IXJC2xfY7GTmXukHR1D3vEAHvbSScV68t//aGOtX1HXi3OO/rFdxfrsw5sLNbxepOGPSKWTTD5jh70AqCH+LoskARhB5Ig7EAShB1IgrADSXCJK2rZftP7i/Xvnvo3HWtLt3+qOO+sBzm11iS27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZUfR/v/ORYn3Lb/9Vsf7jw4c61l76y9OL887S7mIdbw1bdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgvPsyU1f8MvF+vWf//tifZbLf0JXPra8Y+0d/8j16v3Elh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA8+3HO08v/xGd/d2ex/pkTXyzW79p/WrE+7/OdtydHinOiaZNu2W0vtP1D20/Y3mb7umr6XNtrbW+v7uf0vl0A3ZrKbvxhSZ+LiMWSPiLpGtuLJd0gaV1EnCVpXfUcwICaNOwRsTsiHq0e75f0pKQFkpZKWl29bLWky3vVJID63tJndtuLJH1I0gZJ8yLi6I+EPSdpXod5RiSNSNIJmt1tnwBqmvLReNsnSrpX0vURsW98LSJCUkw0X0SsjIjhiBieoVm1mgXQvSmF3fYMjQX9roi4r5q8x/b8qj5f0mhvWgTQhEl3421b0h2SnoyIL48rrZG0QtLN1f0DPekQ9Zz9vmL5z067s9bbf/WLnynWf/Gxh2u9P5ozlc/s50taLulx25uraTdqLOTftn2VpGclXdGbFgE0YdKwR8RDktyhfGGz7QDoFb4uCyRB2IEkCDuQBGEHkiDsQBJc4nocmLb4vR1rI/fU+/rD4lXXFOuL7vz3Wu+P/mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ79OPDUH3T+Yd/LZu/rWJuK0//lYPkFMeEPFGEAsWUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4z34MePWyc4v1dZfdWqgy5BbGsGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSSmMj77QknflDRPUkhaGRG3275J0mclPV+99MaIeLBXjWb2P+dPK9bfOb37c+l37T+tWJ+xr3w9O1ezHzum8qWaw5I+FxGP2j5J0iO211a12yLiS71rD0BTpjI++25Ju6vH+20/KWlBrxsD0Ky39Jnd9iJJH5K0oZp0re0ttlfZnvC3kWyP2N5ke9MhHajVLIDuTTnstk+UdK+k6yNin6SvSTpT0jka2/JP+AXtiFgZEcMRMTxDsxpoGUA3phR22zM0FvS7IuI+SYqIPRHxWkQckfR1SeWrNQC0atKw27akOyQ9GRFfHjd9/riXfVLS1ubbA9CUqRyNP1/SckmP295cTbtR0jLb52js7MsOSVf3pEPU8hcvLi7WH/6tRcV67H68wW7QpqkcjX9IkicocU4dOIbwDTogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Drl7sufGeb6wb8sDstkQ67Qv9k50qpwtO5AFYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dfz7Lafl/TsuEmnSnqhbw28NYPa26D2JdFbt5rs7YyIeMdEhb6G/U0LtzdFxHBrDRQMam+D2pd
Eb93qV2/sxgNJEHYgibbDvrLl5ZcMam+D2pdEb93qS2+tfmYH0D9tb9kB9AlhB5JoJey2L7b9H7aftn1DGz10YnuH7cdtb7a9qeVeVtketb113LS5ttfa3l7dTzjGXku93WR7V7XuNtu+tKXeFtr+oe0nbG+zfV01vdV1V+irL+ut75/ZbU+T9J+SLpK0U9JGScsi4om+NtKB7R2ShiOi9S9g2P4NSS9J+mZEfKCadoukvRFxc/Uf5ZyI+JMB6e0mSS+1PYx3NVrR/PHDjEu6XNLvqsV1V+jrCvVhvbWxZT9X0tMR8UxEHJR0j6SlLfQx8CJivaS9b5i8VNLq6vFqjf2x9F2H3gZCROyOiEerx/slHR1mvNV1V+irL9oI+wJJPxn3fKcGa7z3kPQD24/YHmm7mQnMi4jd1ePnJM1rs5kJTDqMdz+9YZjxgVl33Qx/XhcH6N7sgoj4NUmXSLqm2l0dSDH2GWyQzp1OaRjvfplgmPGfa3PddTv8eV1thH2XpIXjnp9eTRsIEbGruh+VdL8GbyjqPUdH0K3uR1vu5+cGaRjviYYZ1wCsuzaHP28j7BslnWX7XbZnSrpS0poW+ngT20PVgRPZHpL0cQ3eUNRrJK2oHq+Q9ECLvbzOoAzj3WmYcbW87lof/jwi+n6TdKnGjsj/WNKfttFDh77eLemx6rat7d4k3a2x3bpDGju2cZWkUyStk7Rd0j9LmjtAvd0p6XFJWzQWrPkt9XaBxnbRt0jaXN0ubXvdFfrqy3rj67JAEhygA5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk/h9BCfQTVPflJQAAAABJRU5ErkJggg==\n", 877 | "text/plain": [ 878 | "
" 879 | ] 880 | }, 881 | "metadata": { 882 | "needs_background": "light" 883 | }, 884 | "output_type": "display_data" 885 | } 886 | ], 887 | "source": [ 888 | "# Visualize a single image\n", 889 | "imgs, lbls = next(iter(test_loader))\n", 890 | "img = imgs[0].reshape(mnist_img_size)[:, :, 0]\n", 891 | "gt_lbl = lbls[0]\n", 892 | "\n", 893 | "print(gt_lbl)\n", 894 | "plt.imshow(img); plt.show()" 895 | ] 896 | }, 897 | { 898 | "cell_type": "markdown", 899 | "metadata": { 900 | "id": "TsGPQKx0SPL-" 901 | }, 902 | "source": [ 903 | "Great - we have our data pipeline ready and the model architecture defined.\n", 904 | "\n", 905 | "Now let's define core training functions:" 906 | ] 907 | }, 908 | { 909 | "cell_type": "code", 910 | "execution_count": 6, 911 | "metadata": { 912 | "id": "qD8ztbEsVM43" 913 | }, 914 | "outputs": [], 915 | "source": [ 916 | "@jax.jit\n", 917 | "def train_step(state, imgs, gt_labels):\n", 918 | " def loss_fn(params):\n", 919 | " logits = CNN().apply({'params': params}, imgs)\n", 920 | " one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)\n", 921 | " loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))\n", 922 | " return loss, logits\n", 923 | " \n", 924 | " (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)\n", 925 | " state = state.apply_gradients(grads=grads) # this is the whole update now! 
concise!\n", 926 | " metrics = compute_metrics(logits=logits, gt_labels=gt_labels) # duplicating loss calculation but it's a bit cleaner\n", 927 | " return state, metrics\n", 928 | "\n", 929 | "@jax.jit\n", 930 | "def eval_step(state, imgs, gt_labels):\n", 931 | " logits = CNN().apply({'params': state.params}, imgs)\n", 932 | " return compute_metrics(logits=logits, gt_labels=gt_labels)" 933 | ] 934 | }, 935 | { 936 | "cell_type": "code", 937 | "execution_count": 7, 938 | "metadata": { 939 | "id": "v5VblVs2VWxo" 940 | }, 941 | "outputs": [], 942 | "source": [ 943 | "def train_one_epoch(state, dataloader, epoch):\n", 944 | " \"\"\"Train for 1 epoch on the training set.\"\"\"\n", 945 | " batch_metrics = []\n", 946 | " for cnt, (imgs, labels) in enumerate(dataloader):\n", 947 | " state, metrics = train_step(state, imgs, labels)\n", 948 | " batch_metrics.append(metrics)\n", 949 | "\n", 950 | " # Aggregate the metrics\n", 951 | " batch_metrics_np = jax.device_get(batch_metrics) # pull from the accelerator onto host (CPU)\n", 952 | " epoch_metrics_np = {\n", 953 | " k: np.mean([metrics[k] for metrics in batch_metrics_np])\n", 954 | " for k in batch_metrics_np[0]\n", 955 | " }\n", 956 | "\n", 957 | " return state, epoch_metrics_np\n", 958 | "\n", 959 | "def evaluate_model(state, test_imgs, test_lbls):\n", 960 | " \"\"\"Evaluate on the test set.\"\"\"\n", 961 | " metrics = eval_step(state, test_imgs, test_lbls)\n", 962 | " metrics = jax.device_get(metrics) # pull from the accelerator onto host (CPU)\n", 963 | " metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics) # np.ndarray -> Python scalar\n", 964 | " return metrics" 965 | ] 966 | }, 967 | { 968 | "cell_type": "code", 969 | "execution_count": 8, 970 | "metadata": { 971 | "id": "xiV5yiA4BKEk" 972 | }, 973 | "outputs": [], 974 | "source": [ 975 | "# This one will keep things nice and tidy compared to our previous examples\n", 976 | "def create_train_state(key, learning_rate, momentum):\n", 977 | " cnn = CNN()\n", 978 | " 
params = cnn.init(key, jnp.ones([1, *mnist_img_size]))['params']\n", 979 | " sgd_opt = optax.sgd(learning_rate, momentum)\n", 980 | " # TrainState is a simple built-in wrapper class that makes things a bit cleaner\n", 981 | " return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=sgd_opt)\n", 982 | "\n", 983 | "def compute_metrics(*, logits, gt_labels):\n", 984 | " one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)\n", 985 | "\n", 986 | " loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))\n", 987 | " accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)\n", 988 | "\n", 989 | " metrics = {\n", 990 | " 'loss': loss,\n", 991 | " 'accuracy': accuracy,\n", 992 | " }\n", 993 | " return metrics" 994 | ] 995 | }, 996 | { 997 | "cell_type": "code", 998 | "execution_count": 9, 999 | "metadata": { 1000 | "colab": { 1001 | "base_uri": "https://localhost:8080/" 1002 | }, 1003 | "id": "s8EFriHnVcJO", 1004 | "outputId": "cb40714f-6150-44d6-e1e0-290b72a23eda" 1005 | }, 1006 | "outputs": [ 1007 | { 1008 | "name": "stdout", 1009 | "output_type": "stream", 1010 | "text": [ 1011 | "Train epoch: 1, loss: 0.2903152406215668, accuracy: 91.86198115348816\n", 1012 | "Test epoch: 1, loss: 44.35035705566406, accuracy: 94.77999806404114\n", 1013 | "Train epoch: 2, loss: 0.058339256793260574, accuracy: 98.23551177978516\n", 1014 | "Test epoch: 2, loss: 17.13631820678711, accuracy: 97.33999967575073\n" 1015 | ] 1016 | } 1017 | ], 1018 | "source": [ 1019 | "# Finally, let's define the high-level training/val loops\n", 1020 | "seed = 0 # needless to say these should be in a config or defined as flags\n", 1021 | "learning_rate = 0.1\n", 1022 | "momentum = 0.9\n", 1023 | "num_epochs = 2\n", 1024 | "batch_size = 32\n", 1025 | "\n", 1026 | "state = create_train_state(jax.random.PRNGKey(seed), learning_rate, momentum) # named `state` so it does not shadow the train_state module used in create_train_state\n", 1027 | "\n", 1028 | "for epoch in range(1, num_epochs + 1):\n", 1029 | " state, train_metrics = 
train_one_epoch(state, train_loader, epoch)\n", 1030 | " print(f\"Train epoch: {epoch}, loss: {train_metrics['loss']}, accuracy: {train_metrics['accuracy'] * 100}\")\n", 1031 | "\n", 1032 | " test_metrics = evaluate_model(state, test_images, test_lbls)\n", 1033 | " print(f\"Test epoch: {epoch}, loss: {test_metrics['loss']}, accuracy: {test_metrics['accuracy'] * 100}\")\n", 1034 | "\n", 1035 | "# todo: exercise - how would we go about adding dropout? What about BatchNorm? What would have to change?" 1036 | ] 1037 | }, 1038 | { 1039 | "cell_type": "markdown", 1040 | "metadata": { 1041 | "id": "6U-BIjQ1v4ff" 1042 | }, 1043 | "source": [ 1044 | "Bonus point: a walk-through of the \"non-toy\", distributed ImageNet CNN training example.\n", 1045 | "\n", 1046 | "Head over to https://github.com/google/flax/tree/main/examples/imagenet\n", 1047 | "\n", 1048 | "You'll keep seeing the same pattern/structure in all official Flax examples." 1049 | ] 1050 | }, 1051 | { 1052 | "cell_type": "markdown", 1053 | "metadata": { 1054 | "id": "6Q4C2M2tv_0J" 1055 | }, 1056 | "source": [ 1057 | "### Further learning resources 📚\n", 1058 | "\n", 1059 | "Aside from the [official docs](https://flax.readthedocs.io/en/latest/) and [examples](https://github.com/google/flax/tree/main/examples), I found [HuggingFace's Flax examples](https://github.com/huggingface/transformers/tree/master/examples/flax) and the resources from their [\"community week\"](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects) useful as well.\n", 1060 | "\n", 1061 | "Finally, the [source code](https://github.com/google/flax) is also your friend, as the library is still evolving." 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "markdown", 1066 | "metadata": { 1067 | "id": "T5DqxlZ-SD3e" 1068 | }, 1069 | "source": [ 1070 | "### Connect with me ❤️\n", 1071 | "\n", 1072 | "Last but not least, I regularly post AI-related stuff (paper summaries, AI news, etc.) 
on my Twitter/LinkedIn. We also have an ever-increasing Discord community (1600+ members at the time of writing this). If you care about any of these, I encourage you to connect! \n", 1073 | "\n", 1074 | "Social: 
\n", 1075 | "💼 LinkedIn - https://www.linkedin.com/in/aleksagordic/
\n", 1076 | "🐦 Twitter - https://twitter.com/gordic_aleksa
\n", 1077 | "👨‍👩‍👧‍👦 Discord - https://discord.gg/peBrCpheKE
\n", 1078 | "🙏 Patreon - https://www.patreon.com/theaiepiphany
\n", 1079 | "\n", 1080 | "Content:
\n", 1081 | "📺 YouTube - https://www.youtube.com/c/TheAIEpiphany/
\n", 1082 | "📚 Medium - https://gordicaleksa.medium.com/
\n", 1083 | "💻 GitHub - https://github.com/gordicaleksa
\n", 1084 | "📢 AI Newsletter - https://aiepiphany.substack.com/
" 1085 | ] 1086 | } 1087 | ], 1088 | "metadata": { 1089 | "accelerator": "GPU", 1090 | "colab": { 1091 | "collapsed_sections": [], 1092 | "name": "Tutorial 4: Flax Zero2Hero.ipynb", 1093 | "provenance": [] 1094 | }, 1095 | "kernelspec": { 1096 | "display_name": "Python 3 (ipykernel)", 1097 | "language": "python", 1098 | "name": "python3" 1099 | }, 1100 | "language_info": { 1101 | "codemirror_mode": { 1102 | "name": "ipython", 1103 | "version": 3 1104 | }, 1105 | "file_extension": ".py", 1106 | "mimetype": "text/x-python", 1107 | "name": "python", 1108 | "nbconvert_exporter": "python", 1109 | "pygments_lexer": "ipython3", 1110 | "version": "3.9.0" 1111 | } 1112 | }, 1113 | "nbformat": 4, 1114 | "nbformat_minor": 1 1115 | } 1116 | --------------------------------------------------------------------------------