├── LICENSE ├── .gitignore ├── README.md ├── haiku_basics.ipynb ├── flax_mnist.ipynb ├── flax_basics.ipynb ├── jax_basic.ipynb └── basic_transformer.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Deterministic Algorithms Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Jax-Journey 2 | 3 | ## A First Introduction 4 | 5 | **Blog** : [Getting Started with Jax](https://roberttlange.github.io/posts/2020/03/blog-post-10/) by Robert Tjarko Lange 6 | 7 | **Notebook** : [jax_basic.ipynb](https://github.com/deterministic-algorithms-lab/Jax-Journey/blob/main/jax_basic.ipynb) 8 | 9 | ## Functional Programming Intro 10 | 11 | **Blog** : [Functional JS 2](https://medium.com/dailyjs/functional-js-2-functions-duh-70bf22f87bb8) , [Functional JS 3](https://medium.com/dailyjs/functional-js-3-state-89d8cc9ebc9e) by Krzysztof Czernek 12 | 13 | ## How it works 14 | 15 | **A nice comment on Jax** : https://github.com/google/jax/issues/196#issuecomment-451671635 16 | 17 | **Blog** : [From PyTorch to Jax : Towards NN frameworks that purify stateful code](https://sjmielke.com/jax-purify.htm) by Sabrina J. Mielke 18 | 19 | **Notebook** : [making_functions_pure.ipynb](https://github.com/deterministic-algorithms-lab/Jax-Journey/blob/main/making_functions_pure.ipynb) 20 | 21 | ## Flax Intro 22 | 23 | **Blog** : [Documentation Tutorial](https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html) 24 | 25 | **Notebook** : [Basics-MLP](https://github.com/deterministic-algorithms-lab/Jax-Journey/blob/main/flax_basics.ipynb) , [Basics-MNIST](https://github.com/deterministic-algorithms-lab/Jax-Journey/blob/main/flax_mnist.ipynb) 26 | 27 | ### Note : Now is the best time to read [Jax-The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) 28 | 29 | ## Haiku Intro 30 | 31 | **Notebook** : [haiku basics with lstm](https://github.com/deterministic-algorithms-lab/Jax-Journey/blob/main/haiku_basics.ipynb) 32 | 33 | **Blog** : [From PyTorch to Jax : Towards NN frameworks that purify stateful code](https://sjmielke.com/jax-purify.htm) by Sabrina J. Mielke 34 | 35 | ## Basic Transformer Using Haiku 36 | 37 | **Notebook** : [basic_transformer.ipynb](https://github.com/deterministic-algorithms-lab/Jax-Journey/blob/main/basic_transformer.ipynb) 38 | 39 | **Blog** : [FineTuning Transformers with Jax+Haiku](https://www.pragmatic.ml/finetuning-transformers-with-jax-and-haiku/) by Madison May 40 | 41 | ## Jax - A deeper Dive 42 | 43 | The following tutorials form a nice starting point for a deeper dive into Jax mechanics : 44 | 45 | 1. [On Automatic Differentiation in Jax](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) 46 | 47 | 2. [How Jax Represents your functions internally : Jaxprs](https://jax.readthedocs.io/en/latest/jaxpr.html#understanding-jaxprs) 48 | 49 | 3. 
[How Jax Primitives Work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) 50 | -------------------------------------------------------------------------------- /haiku_basics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "haiku-basics.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyOsfFY4x7BGS34MXZh7RqbD", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "uOCnw0vnehKT" 32 | }, 33 | "source": [ 34 | "!pip install git+https://github.com/deepmind/dm-haiku" 35 | ], 36 | "execution_count": null, 37 | "outputs": [] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "metadata": { 42 | "id": "7vZsBj6BekmL" 43 | }, 44 | "source": [ 45 | "import haiku as hk\n", 46 | "import jax.numpy as jnp\n", 47 | "import jax" 48 | ], 49 | "execution_count": null, 50 | "outputs": [] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "OlOX_u66b9y1" 56 | }, 57 | "source": [ 58 | "# LSTM" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": { 64 | "id": "7S0xBoBtjXoX" 65 | }, 66 | "source": [ 67 | "**Specialities :**\n", 68 | "\n", 69 | "* ```name``` argument in ```__init__``` . Module must call ```super().__init__()``` with its ```name``` .\n", 70 | "* ```__call__``` can take any arguments, return any. \n", 71 | "* Only single function(```__call__```) for both ```init_fn()``` and ```apply_fn()```. \n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "metadata": { 77 | "id": "Qs9VPPhiS9_V" 78 | }, 79 | "source": [ 80 | "class HaikuLSTMCell(hk.Module):\n", 81 | " def __init__(self, in_dim, out_dim, name=None):\n", 82 | " super().__init__(name=name or \"lstmcell\")\n", 83 | " self.in_dim = in_dim\n", 84 | " self.out_dim = out_dim\n", 85 | " \n", 86 | " def __call__(self, inputs, h, c):\n", 87 | " weights_ih = hk.get_parameter(\"weight_ih\", \n", 88 | " (4*self.out_dim, self.in_dim),\n", 89 | " init = hk.initializers.UniformScaling())\n", 90 | " weights_hh = hk.get_parameter(\"weights_hh\",\n", 91 | " (4*self.out_dim, self.out_dim),\n", 92 | " init=hk.initializers.UniformScaling())\n", 93 | " bias = hk.get_parameter(\"bias\",\n", 94 | " (4*self.out_dim,),\n", 95 | " init = hk.initializers.Constant(0.0))\n", 96 | " \n", 97 | " ifgo = weights_ih @ inputs + weights_hh @ h + bias\n", 98 | " i, f, g, o = jnp.split(ifgo, indices_or_sections=4, axis=-1)\n", 99 | " \n", 100 | " i = jax.nn.sigmoid(i)\n", 101 | " f = jax.nn.sigmoid(f)\n", 102 | " g = jnp.tanh(g)\n", 103 | " o = jax.nn.sigmoid(o)\n", 104 | "\n", 105 | " new_c = f*c + i * g\n", 106 | " new_h = o*jnp.tanh(new_c)\n", 107 | " \n", 108 | " return (new_h, new_c)" 109 | ], 110 | "execution_count": null, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "U8rKZnrAm4IB" 117 | }, 118 | "source": [ 119 | "The following code is even more like PyTorch :\n", 120 | "* You can define all submodules and parameters inside ```__init__()```. All things that happen inside the function that is sent into ```hk.transform()``` will be well traced, and are valid. 
" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "metadata": { 126 | "id": "K1zykpR3efQa" 127 | }, 128 | "source": [ 129 | "class HaikuLSTMLM(hk.Module):\n", 130 | " def __init__(self, vocab_size, dim, name=None):\n", 131 | " super().__init__(name=name or \"lstmlm\")\n", 132 | " _c0 = hk.get_parameter(name=\"c_0\",\n", 133 | " shape = (dim,),\n", 134 | " \n", 135 | " init = hk.initializers.TruncatedNormal(stddev=0.1))\n", 136 | " self.hc_0 = (jnp.tanh(_c0), _c0)\n", 137 | " self.embeddings = hk.Embed(vocab_size, dim)\n", 138 | " self.cell = HaikuLSTMCell(dim, dim)\n", 139 | " \n", 140 | " def forward(self, seq, hc):\n", 141 | " loss = 0\n", 142 | " for idx in seq:\n", 143 | " loss -= jax.nn.log_softmax(self.embeddings.embeddings@hc[0])[idx]\n", 144 | " hc = self.cell(self.embeddings(idx), *hc)\n", 145 | " return loss, hc" 146 | ], 147 | "execution_count": null, 148 | "outputs": [] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": { 153 | "id": "OkHZw77241zx" 154 | }, 155 | "source": [ 156 | "* It doesn't matter to ```hk.transform()``` where the submodules are defined, as long as they are defined within the function that is being transformed so that they can be purified. So both the above and below definition are valid and equivalent. \n", 157 | "\n", 158 | "* The second way allows you to make model sizes dependent on inputs received in ```forward()```\n", 159 | "\n", 160 | "* The ```forward()``` function need not be named as it is , and can have any other name. Some poeple use ```__call__```, instead. We are able to use syntax like in line 16 below, (```hk.Embed(.. , ..)(idx)```) because the processing done ```hk.Embed``` is defined in its ```__call__```, rather than forward. Had it been defined in ```forward()```, we'd have to call ```hk.Embed(.. , ..).forward(idx)``` instead. 
" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "RIFcJvXD5hNN" 167 | }, 168 | "source": [ 169 | "class HaikuLSTMLM(hk.Module):\n", 170 | " def __init__(self, vocab_size, dim, name=None):\n", 171 | " super().__init__(name=name or \"lstmlm\")\n", 172 | " _c0 = hk.get_parameter(name=\"c_0\",\n", 173 | " shape = (dim,),\n", 174 | " init = hk.initializers.TruncatedNormal(stddev=0.1))\n", 175 | " self.hc_0 = (jnp.tanh(_c0), _c0)\n", 176 | " self.vocab_size=vocab_size\n", 177 | " self.dim = dim\n", 178 | " self.cell = HaikuLSTMCell(dim, dim)\n", 179 | " \n", 180 | " def forward(self, seq, hc):\n", 181 | " loss = 0\n", 182 | " for idx in seq:\n", 183 | " loss -= jax.nn.log_softmax(hk.Embed(self.vocab_size, self.dim).embeddings@hc[0])[idx]\n", 184 | " hc = self.cell(hk.Embed(self.vocab_size, self.dim)(idx), *hc)\n", 185 | " return loss, hc" 186 | ], 187 | "execution_count": null, 188 | "outputs": [] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "metadata": { 193 | "id": "e2KjbdiWqV_7" 194 | }, 195 | "source": [ 196 | "def impure_forward_fn(vocab_size, dim, seq, hc=None):\n", 197 | " lm = HaikuLSTMLM(vocab_size, dim)\n", 198 | " return lm.forward(seq, hc if hc else lm.hc_0)" 199 | ], 200 | "execution_count": null, 201 | "outputs": [] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "metadata": { 206 | "id": "4A48Q-7UrUiw" 207 | }, 208 | "source": [ 209 | "init_fn, nojit_pure_forward_fn = hk.transform(impure_forward_fn)\n", 210 | "pure_forward_fn = jax.jit(nojit_pure_forward_fn)" 211 | ], 212 | "execution_count": null, 213 | "outputs": [] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": { 218 | "id": "Z-TPkGpxvdxA" 219 | }, 220 | "source": [ 221 | "* ```init_fn()``` takes in two types of arguments. First is the random key and second are the inputs to be sent to the function that was transformed. It returns the nested params.\n", 222 | "\n", 223 | "* ```nojit_pure_forward_function``` takes in three types of arguments. First is the ```params``` returned by ```init_fn()``` and second is the ```rng``` key and third are the arguments to the function that was transformed. Same ```rng``` key will give same result on same inputs. It returns the same things that are returned by ```impure_forward_fn()``` . 
" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "metadata": { 229 | "id": "qOt9BFBju9J-" 230 | }, 231 | "source": [ 232 | "rng = jax.random.PRNGKey(0)\n", 233 | "params = init_fn(rng, vocab_size = 20, dim = 10, seq=jnp.array([0]))" 234 | ], 235 | "execution_count": null, 236 | "outputs": [] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "metadata": { 241 | "id": "MHVDpHkyw4-g" 242 | }, 243 | "source": [ 244 | "print(params)" 245 | ], 246 | "execution_count": null, 247 | "outputs": [] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "metadata": { 252 | "id": "Ku-lknH6xFhu" 253 | }, 254 | "source": [ 255 | "loss, hc = nojit_pure_forward_fn(params, rng, vocab_size = 20, dim=10, seq=jnp.array([0]))" 256 | ], 257 | "execution_count": null, 258 | "outputs": [] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "metadata": { 263 | "colab": { 264 | "base_uri": "https://localhost:8080/" 265 | }, 266 | "id": "IyQly5IryrIe", 267 | "outputId": "e46cc3c0-9183-47c9-a4e1-8b36bc40b9b7" 268 | }, 269 | "source": [ 270 | "print(loss, hc)" 271 | ], 272 | "execution_count": null, 273 | "outputs": [ 274 | { 275 | "output_type": "stream", 276 | "text": [ 277 | "2.9562287 (DeviceArray([ 0.19030678, -0.04981524, -0.1435111 , 0.14797553,\n", 278 | " 0.01645921, -0.01669403, 0.11530687, -0.10629394,\n", 279 | " -0.02137115, 0.07460269], dtype=float32), DeviceArray([ 0.37595972, -0.08241095, -0.2591579 , 0.3729893 ,\n", 280 | " 0.0248227 , -0.03331303, 0.19235653, -0.24751279,\n", 281 | " -0.04453837, 0.15290585], dtype=float32))\n" 282 | ], 283 | "name": "stdout" 284 | } 285 | ] 286 | } 287 | ] 288 | } -------------------------------------------------------------------------------- /flax_mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "flax_mnist.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyMAMrSyi2SMsFF6tOZIv9Kf", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "colab": { 32 | "base_uri": "https://localhost:8080/" 33 | }, 34 | "id": "612x-KQNev6L", 35 | "outputId": "f45e417c-ed3c-4af1-82be-232e6af172c1" 36 | }, 37 | "source": [ 38 | "# Install ml-collections & latest Flax version from Github.\n", 39 | "!pip install -q ml-collections git+https://github.com/google/flax" 40 | ], 41 | "execution_count": null, 42 | "outputs": [ 43 | { 44 | "output_type": "stream", 45 | "text": [ 46 | "\u001b[?25l\r\u001b[K |███▊ | 10kB 14.9MB/s eta 0:00:01\r\u001b[K |███████▍ | 20kB 19.8MB/s eta 0:00:01\r\u001b[K |███████████ | 30kB 22.6MB/s eta 0:00:01\r\u001b[K |██████████████▉ | 40kB 16.8MB/s eta 0:00:01\r\u001b[K |██████████████████▌ | 51kB 10.7MB/s eta 0:00:01\r\u001b[K |██████████████████████▏ | 61kB 12.0MB/s eta 0:00:01\r\u001b[K |█████████████████████████▉ | 71kB 10.9MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▋ | 81kB 11.9MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 92kB 6.3MB/s \n", 47 | "\u001b[?25h Building wheel for flax (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" 48 | ], 49 | "name": "stdout" 50 | } 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": { 56 | "id": "DaPx_QXfe_q4" 57 | }, 58 | "source": [ 59 | "ML Collections is a library of collections(like normal python ```collections``` module) specialised for ML. The repo can be viewed [here](https://github.com/google/ml_collections)." 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "metadata": { 65 | "id": "96r859yCeb5l" 66 | }, 67 | "source": [ 68 | "import ml_collections\n", 69 | "\n", 70 | "def get_config():\n", 71 | " config = ml_collections.ConfigDict()\n", 72 | " \n", 73 | " config.learning_rate = 0.1\n", 74 | " config.momentum = 0.9\n", 75 | " config.batch_size = 128\n", 76 | " config.num_epochs = 10\n", 77 | "\n", 78 | " return config" 79 | ], 80 | "execution_count": null, 81 | "outputs": [] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "id": "4RbQ2kMRnP5e" 87 | }, 88 | "source": [ 89 | "# Imports" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "metadata": { 95 | "id": "EaDe0Bimff15" 96 | }, 97 | "source": [ 98 | "from absl import logging\n", 99 | "import flax\n", 100 | "import jax.numpy as jnp\n", 101 | "from matplotlib import pyplot as plt\n", 102 | "import numpy as np\n", 103 | "import tensorflow_datasets as tfds\n", 104 | "\n", 105 | "logging.set_verbosity(logging.INFO)" 106 | ], 107 | "execution_count": null, 108 | "outputs": [] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "metadata": { 113 | "id": "0pdBQBGhfy-p" 114 | }, 115 | "source": [ 116 | "# Helper functions for images.\n", 117 | "\n", 118 | "def show_img(img, ax=None, title=None):\n", 119 | " \"\"\"Shows a single image.\"\"\"\n", 120 | " if ax is None:\n", 121 | " ax = plt.gca()\n", 122 | " ax.imshow(img[..., 0], cmap='gray')\n", 123 | " ax.set_xticks([])\n", 124 | " ax.set_yticks([])\n", 125 | " if title:\n", 126 | " ax.set_title(title)\n", 127 | "\n", 128 | "def show_img_grid(imgs, titles):\n", 129 | " \"\"\"Shows a grid of images.\"\"\"\n", 130 | " n = int(np.ceil(len(imgs)**.5))\n", 131 | " _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))\n", 132 | " for i, (img, title) in enumerate(zip(imgs, titles)):\n", 133 | " show_img(img, axs[i // n][i % n], title)" 134 | ], 135 | "execution_count": null, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "metadata": { 141 | "id": "5wo9zShEf6ts" 142 | }, 143 | "source": [ 144 | "# Local imports from current directory will auto reload.\n", 145 | "# Any changes you make to local files will appear automatically.\n", 146 | "%load_ext autoreload\n", 147 | "%autoreload 2" 148 | ], 149 | "execution_count": null, 150 | "outputs": [] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "metadata": { 155 | "id": "3iwWpea0gGdZ" 156 | }, 157 | "source": [ 158 | "config = get_config()" 159 | ], 160 | "execution_count": null, 161 | "outputs": [] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": { 166 | "id": "eaed4ugFiLNx" 167 | }, 168 | "source": [ 169 | "* ```tfds.as_numpy()``` takes in a dataset to a python generator, that generates numpy matrices here. \n", 170 | "\n", 171 | "* ```tf.DatasetBuilder.as_dataset()``` builds an input pipeline(taking care of all batch size, device etc.) using ```tf.data.Dataset```(s). The ```tf.data.Dataset```(s) correspond to the ```nn.Dataset``` of PyTorch." 
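For datasets too large to materialize with ```batch_size=-1```, the same builder can stream mini-batches through a standard ```tf.data``` pipeline instead (an illustrative sketch, assuming a ```ds_builder``` like the one constructed in the next cell; in PyTorch terms, a ```tf.data.Dataset``` plays the role of ```torch.utils.data.Dataset``` plus its loader):

```python
ds = ds_builder.as_dataset(split='train')        # un-batched tf.data.Dataset
ds = ds.shuffle(10_000).batch(128).prefetch(1)   # shuffle / batch / overlap with compute
for batch in tfds.as_numpy(ds):                  # yields dicts of numpy arrays
    images = batch['image'] / 255.0              # normalise on the fly
    labels = batch['label']
```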
172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "metadata": { 177 | "id": "nDNet9fwgOEW" 178 | }, 179 | "source": [ 180 | "def get_datasets():\n", 181 | " ds_builder = tfds.builder('mnist')\n", 182 | " ds_builder.download_and_prepare()\n", 183 | " train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n", 184 | " test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n", 185 | " #print(test_ds) #Each dataset has data in different format, so do check.Here the structure is a dict {'image':np array of all images, 'label': all labels}\n", 186 | " #print(train_ds['image'][0]) #Prints first image. The values are 0 to 255..\n", 187 | " train_ds['image'] = jnp.float32(train_ds['image']) /255 \n", 188 | " test_ds['image'] = jnp.float32(test_ds['image']) / 255\n", 189 | " return train_ds, test_ds" 190 | ], 191 | "execution_count": null, 192 | "outputs": [] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "metadata": { 197 | "id": "ONPZKDz0m5aW" 198 | }, 199 | "source": [ 200 | "train_ds, test_ds = get_datasets()" 201 | ], 202 | "execution_count": null, 203 | "outputs": [] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "metadata": { 208 | "id": "QgJcjvfqm8y4" 209 | }, 210 | "source": [ 211 | "show_img_grid(\n", 212 | " [train_ds['image'][idx] for idx in range(25)],\n", 213 | " [f'label={train_ds[\"label\"][idx]}' for idx in range(25)],\n", 214 | ")" 215 | ], 216 | "execution_count": null, 217 | "outputs": [] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": { 222 | "id": "cMKtqEuAov-s" 223 | }, 224 | "source": [ 225 | "# Model" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "metadata": { 231 | "id": "auCejTm_ql1C" 232 | }, 233 | "source": [ 234 | "from flax import linen as nn\n", 235 | "from flax import optim\n", 236 | "from flax.metrics import tensorboard\n", 237 | "import numpy as onp\n", 238 | "from jax import random\n", 239 | "import jax" 240 | ], 241 | "execution_count": null, 242 | "outputs": [] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "metadata": { 247 | "id": "wIqCGxm5oz4d" 248 | }, 249 | "source": [ 250 | "class CNN(nn.Module):\n", 251 | " @nn.compact\n", 252 | " def __call__(self, x):\n", 253 | " x = nn.Conv(features=32, kernel_size=(3,3))(x)\n", 254 | " x = nn.relu(x)\n", 255 | " x = nn.avg_pool(x, window_shape=(2,2), strides=(2,2))\n", 256 | " x = nn.Conv(features=64, kernel_size=(3,3))(x)\n", 257 | " x = nn.relu(x)\n", 258 | " x = nn.avg_pool(x, window_shape=(2,2), strides=(2,3))\n", 259 | " x = x.reshape((x.shape[0], -1))\n", 260 | " x = nn.Dense(features=256)(x)\n", 261 | " x = nn.relu(x)\n", 262 | " x = nn.Dense(features=10)(x)\n", 263 | " x = nn.log_softmax(x)\n", 264 | " return x" 265 | ], 266 | "execution_count": null, 267 | "outputs": [] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "metadata": { 272 | "id": "17NZsG4pqzPg" 273 | }, 274 | "source": [ 275 | "key = random.PRNGKey(0)\n", 276 | "key1, key2 = random.split(key)\n", 277 | "x = random.normal(key1, (1, 28, 28, 1))\n", 278 | "\n", 279 | "model = CNN()\n", 280 | "params = model.init(key2, x)\n", 281 | "print(params) #To check dictionary structure.. whether variables are there, etc." 
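# (Added aside, not in the original notebook) If the full dump above is too verbose,
# a more compact structure check maps jnp.shape over the parameter pytree and counts leaves:
print(jax.tree_map(jnp.shape, params))
print(sum(p.size for p in jax.tree_leaves(params)), 'parameters in total')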
282 | ], 283 | "execution_count": null, 284 | "outputs": [] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": { 289 | "id": "XxlG23AgtxnW" 290 | }, 291 | "source": [ 292 | "# Optimizer" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "metadata": { 298 | "id": "sk6kUQs4tmyo" 299 | }, 300 | "source": [ 301 | "optimizer_def = optim.Momentum(learning_rate=config.learning_rate, \n", 302 | " beta=config.momentum)\n", 303 | "optimizer = optimizer_def.create(params)" 304 | ], 305 | "execution_count": null, 306 | "outputs": [] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": { 311 | "id": "WQFafZjjnp1s" 312 | }, 313 | "source": [ 314 | "# Training" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": { 320 | "id": "kpSDfgJevg1S" 321 | }, 322 | "source": [ 323 | "##Loss Funtion" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "metadata": { 329 | "id": "j6tYKrhVwRf2" 330 | }, 331 | "source": [ 332 | "def cross_entropy_loss(labels,logits):\n", 333 | " return -jnp.mean(jnp.sum(labels*logits, axis=-1))" 334 | ], 335 | "execution_count": null, 336 | "outputs": [] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "metadata": { 341 | "id": "1Z1ECDuzwvMY" 342 | }, 343 | "source": [ 344 | "max_classes=10\n", 345 | "def onehot(label):\n", 346 | " x = (label[...,None]==jnp.arange(0,max_classes)[None])\n", 347 | " return x.astype(jnp.float32)" 348 | ], 349 | "execution_count": null, 350 | "outputs": [] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "metadata": { 355 | "id": "_IcIm-CmvgC7" 356 | }, 357 | "source": [ 358 | "def loss_fn(params, batch): #Can input any number of arguments.\n", 359 | " logits = CNN().apply(params, batch['image']) #We are not constrained to use the same model as before.\n", 360 | " loss = cross_entropy_loss(onehot(batch['label']), logits)\n", 361 | " return loss, logits #Can output at most two values" 362 | ], 363 | "execution_count": null, 364 | "outputs": [] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": { 369 | "id": "ZxSZHTpk858v" 370 | }, 371 | "source": [ 372 | "## Metric Calculation" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "metadata": { 378 | "id": "SAe5DQt99CaY" 379 | }, 380 | "source": [ 381 | "def compute_metric(logits, labels):\n", 382 | " loss = cross_entropy_loss(logits, onehot(labels))\n", 383 | " accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n", 384 | " metrics = {\n", 385 | " 'loss' : loss,\n", 386 | " 'accuracy' : accuracy,\n", 387 | " }\n", 388 | " return metrics" 389 | ], 390 | "execution_count": null, 391 | "outputs": [] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": { 396 | "id": "i8HgXZ9UvSZK" 397 | }, 398 | "source": [ 399 | "## Single Step Training" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": { 405 | "id": "xM-h7BX56yI7" 406 | }, 407 | "source": [ 408 | "The ```has_aux=True``` below is necessary to indicate that the ```loss_fn``` returns two values, first of which is output of mathematical operation and second is auxillary data. The inability to print any the abstractions adopted by ```JAX``` are very nicely explained [here]( https://github.com/google/jax/issues/196#issuecomment-451671635 ) ." 
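A tiny self-contained example of the ```has_aux=True``` convention (illustrative, not from the notebook): the wrapped function must return ```(value, aux)```, and ```value_and_grad``` then returns ```((value, aux), grad)```:

```python
def toy_loss(w, x):
    y = w * x                       # auxiliary output we also want back
    return jnp.sum(y ** 2), y       # (scalar value to differentiate, aux)

(value, aux), dw = jax.value_and_grad(toy_loss, has_aux=True)(2.0, jnp.arange(3.0))
# value = sum((2*x)**2) = 20.0, aux = [0., 2., 4.], dw = 2*w*sum(x**2) = 20.0
```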
409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "metadata": { 414 | "id": "L7s5bBYbvRhb" 415 | }, 416 | "source": [ 417 | "@jax.jit\n", 418 | "def train_step(optimizer, batch):\n", 419 | " grad_n_val_fn = jax.value_and_grad(loss_fn, has_aux=True) #By default, gradients will be calculated w.r.t the first argument of loss_fn only. \n", 420 | " (loss, logits), grad = grad_n_val_fn(optimizer.target, batch)\n", 421 | " optimizer = optimizer.apply_gradient(grad)\n", 422 | " \n", 423 | " #print(loss) #Not able to get value of loss directly. \n", 424 | " #Can't print values inside jit compiled functions and others nested,inside it, yet.\n", 425 | " return optimizer, compute_metric(logits, batch['label'])" 426 | ], 427 | "execution_count": null, 428 | "outputs": [] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": { 433 | "id": "Xe-MMzbIz5Nc" 434 | }, 435 | "source": [ 436 | "## Epoch Training" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": { 442 | "id": "ML4mqrmi0skB" 443 | }, 444 | "source": [ 445 | "### Setting up data loading" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "metadata": { 451 | "id": "G-_Yht9Qz4yQ" 452 | }, 453 | "source": [ 454 | "train_ds_size = len(train_ds['image'])\n", 455 | "steps_per_epoch = train_ds_size//config.batch_size\n", 456 | "\n", 457 | "perms = random.permutation(key, len(train_ds['image']))\n", 458 | "perms = perms[:steps_per_epoch*config.batch_size]\n", 459 | "perms = perms.reshape((steps_per_epoch, config.batch_size))" 460 | ], 461 | "execution_count": null, 462 | "outputs": [] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": { 467 | "id": "na-jvDS50xmW" 468 | }, 469 | "source": [ 470 | "### Training loop" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "metadata": { 476 | "colab": { 477 | "base_uri": "https://localhost:8080/" 478 | }, 479 | "id": "7YYm33df0w6N", 480 | "outputId": "72888316-0d2d-431a-96c1-3200127c18ec" 481 | }, 482 | "source": [ 483 | "metrics = []\n", 484 | "for perm in perms:\n", 485 | " batch = {k: v[perm] for k,v in train_ds.items()} #batch is a dictionary/pytree here\n", 486 | " optimizer, metric = train_step(optimizer, batch)\n", 487 | " metrics.append(metric)\n", 488 | "\n", 489 | "metrics = jax.device_get(metrics) #Get metrics from device into CPU as numpy arrays\n", 490 | "mean_metrics = {k : onp.mean([metric[k] for metric in metrics]) #Averaging metrics of all batches, while\n", 491 | " for k in metrics[0]} #Looping over all types of metrics\n", 492 | "print(mean_metrics) #Can print outside any jit-ted functions" 493 | ], 494 | "execution_count": null, 495 | "outputs": [ 496 | { 497 | "output_type": "stream", 498 | "text": [ 499 | "{'accuracy': 0.9871461, 'loss': 0.04197854}\n" 500 | ], 501 | "name": "stdout" 502 | } 503 | ] 504 | } 505 | ] 506 | } -------------------------------------------------------------------------------- /flax_basics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "flax_basics.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyO4WBy4mQKMS6frPDJVLuQQ", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 
27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "twYeooj3pMsh", 32 | "colab": { 33 | "base_uri": "https://localhost:8080/" 34 | }, 35 | "outputId": "0c97bb2b-5ed1-492c-905f-bef3a7c143a5" 36 | }, 37 | "source": [ 38 | "!pip install --upgrade -q pip jax jaxlib\n", 39 | "# Install Flax at head:\n", 40 | "!pip install --upgrade -q git+https://github.com/google/flax.git" 41 | ], 42 | "execution_count": 1, 43 | "outputs": [ 44 | { 45 | "output_type": "stream", 46 | "text": [ 47 | "\u001b[K |████████████████████████████████| 1.5MB 5.5MB/s \n", 48 | "\u001b[K |████████████████████████████████| 522kB 36.3MB/s \n", 49 | "\u001b[?25h Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 50 | " Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n" 51 | ], 52 | "name": "stdout" 53 | } 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "oi5V7X-SofKn" 60 | }, 61 | "source": [ 62 | "import jax\n", 63 | "from typing import Any, Callable, Sequence, Optional\n", 64 | "from jax import lax, random, numpy as jnp\n", 65 | "import flax\n", 66 | "from flax.core import freeze, unfreeze\n", 67 | "from flax import linen as nn\n", 68 | "\n", 69 | "from jax.config import config\n", 70 | "config.enable_omnistaging() # Linen requires enabling omnistaging" 71 | ], 72 | "execution_count": 2, 73 | "outputs": [] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": { 78 | "id": "LVWgTDepoi_9" 79 | }, 80 | "source": [ 81 | "* Class attributes are attributes of class specified outside any function. \n", 82 | "* They are same for all instances of the class.\n", 83 | "* In below syntax, ```features``` is not a class attribute. In the ```__init__()``` of parent class, it will be initialized. It is different for different objects, and must be provided during creation of object.\n", 84 | "\n" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "metadata": { 90 | "id": "C-ewBn6KpArI" 91 | }, 92 | "source": [ 93 | "class ExplicitMLP(nn.Module):\n", 94 | " features: Sequence[int]\n", 95 | "\n", 96 | " def setup(self):\n", 97 | " '''\n", 98 | " This function is called automatically after __postinit__() function. \n", 99 | " Here we can register submodules, variables, parameters you will need in your model.\n", 100 | " '''\n", 101 | " self.layers = [nn.Dense(feat) for feat in self.features]\n", 102 | " \n", 103 | " def __call__(self, inputs):\n", 104 | " '''\n", 105 | " Is called whenever inputs are sent in the model.apply()\n", 106 | " It doesn't matter whether inputs contain params or not. Don't think about it.\n", 107 | " This function just need specifies the flow.\n", 108 | " '''\n", 109 | " x = inputs\n", 110 | " for i, lyr in enumerate(self.layers):\n", 111 | " x = lyr(x)\n", 112 | " if i!=len(self.layers)-1:\n", 113 | " x = nn.relu(x)\n", 114 | " return x" 115 | ], 116 | "execution_count": 5, 117 | "outputs": [] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": { 122 | "id": "bT9o9scrzDAI" 123 | }, 124 | "source": [ 125 | "* In the above class, ```model.layers``` won't be accessible from outside the class. It seems like these layers come into existence only when ```model.apply()``` is called.\n", 126 | "\n", 127 | "* Below is an example of a neat trick done by flax. If you would like to modify/define the initialisation procedure for a module, at first sight it looks like you will have to pass in and maintain what method to use outside of class(like with ```params```). 
But, what flax does is that it recognizes that the initialisation method is basically just a combination of function and a random key, so, it will allow you to store and maintain the function part inside the class! (You can do so for functions, but not for shared state.) And this function will take the random key+ shapes etc. as its input and produce deterministic output based on that, which will be used to provide the initial parameters. " 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "metadata": { 133 | "id": "Hf5t9bejp32H" 134 | }, 135 | "source": [ 136 | "key = random.PRNGKey(0)\n", 137 | "key1, key2 = random.split(key, 2)\n", 138 | "x = random.normal(key1, (4,4)) #First dimension will automatically be interpretted as batch-dimension. No need to use vmap.\n", 139 | "\n", 140 | "model = ExplicitMLP(features=[3,4,5])\n", 141 | "params = model.init(key2, x) #Would go on init-ing all the internal layers too." 142 | ], 143 | "execution_count": null, 144 | "outputs": [] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "4jZXovvJ6a-W" 150 | }, 151 | "source": [ 152 | "The ```model.apply()``` below, would have to call each of its sub-layer as specified in ```__call__``` function above. Before calling each of it's sub layers, it sets that specific layer's params properly and would also set various flags that would make sure that you can only use ```__call__``` from inside ```model.apply()``` or ```model.init()```." 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "metadata": { 158 | "colab": { 159 | "base_uri": "https://localhost:8080/" 160 | }, 161 | "id": "lTEtLzQ4p7Lm", 162 | "outputId": "0d4ae6dc-38fa-43e1-d2df-de3981eff66d" 163 | }, 164 | "source": [ 165 | "y = model.apply(params, x) #Can't do y = model((params,x))\n", 166 | "\n", 167 | "print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(params)))\n", 168 | "print('output shape:\\n', y.shape)" 169 | ], 170 | "execution_count": null, 171 | "outputs": [ 172 | { 173 | "output_type": "stream", 174 | "text": [ 175 | "initialized parameter shapes:\n", 176 | " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", 177 | "output shape:\n", 178 | " (4, 5)\n" 179 | ], 180 | "name": "stdout" 181 | } 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": { 187 | "id": "NwVwTH5q61CD" 188 | }, 189 | "source": [ 190 | "Below is another easier method for specifying the flow of steps in the model. We define as well as use the layers directly, specifying only what to pass to it. " 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "metadata": { 196 | "id": "nnro68LUyNdD" 197 | }, 198 | "source": [ 199 | "class SimpleMLP(nn.Module):\n", 200 | " features: Sequence[int]\n", 201 | "\n", 202 | " @nn.compact\n", 203 | " def __call__(self, inputs):\n", 204 | " x = inputs\n", 205 | " for i, feat in enumerate(self.features):\n", 206 | " x = nn.Dense(feat, name=f'layers_{i}')(x) #No need to do init/apply etc. 
as we are in @nn.compact\n", 207 | " if i!=len(self.features)-1:\n", 208 | " x=nn.relu(x)\n", 209 | " return x " 210 | ], 211 | "execution_count": null, 212 | "outputs": [] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "metadata": { 217 | "colab": { 218 | "base_uri": "https://localhost:8080/" 219 | }, 220 | "id": "RCZOpPAW1c4M", 221 | "outputId": "dd1bf173-df06-4e3a-f1e1-27d442b862c0" 222 | }, 223 | "source": [ 224 | "key = random.PRNGKey(0)\n", 225 | "key1, key2 = random.split(key, 2)\n", 226 | "x = random.uniform(key1, (4,4))\n", 227 | "\n", 228 | "model = SimpleMLP([4, 3, 5])\n", 229 | "params = model.init(key2,x)\n", 230 | "y = model.apply(params, x)\n", 231 | "\n", 232 | "print('initialised parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(params)))\n", 233 | "print('output shape:\\n', y.shape)" 234 | ], 235 | "execution_count": null, 236 | "outputs": [ 237 | { 238 | "output_type": "stream", 239 | "text": [ 240 | "initialised parameter shapes:\n", 241 | " {'params': {'layers_0': {'bias': (4,), 'kernel': (4, 4)}, 'layers_1': {'bias': (3,), 'kernel': (4, 3)}, 'layers_2': {'bias': (5,), 'kernel': (3, 5)}}}\n", 242 | "output shape:\n", 243 | " (4, 5)\n" 244 | ], 245 | "name": "stdout" 246 | } 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": { 252 | "id": "N_n3I4ie-EEk" 253 | }, 254 | "source": [ 255 | "Compact notation for defining computation models from scratch, using mathematical operations(only) alongside defining any parameters that the model has. The ```self.param()``` behave differently based on whether ```__call__``` has been called by ```init()``` or ```apply()```." 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "metadata": { 261 | "id": "bM-yvsTp3kLO" 262 | }, 263 | "source": [ 264 | "class SimpleDense(nn.Module):\n", 265 | " features: int\n", 266 | " kernel_init: Callable = nn.initializers.lecun_normal()\n", 267 | " bias_init: Callable = nn.initializers.zeros\n", 268 | "\n", 269 | " @nn.compact\n", 270 | " def __call__(self, inputs):\n", 271 | " kernel = self.param('kernel',\n", 272 | " self.kernel_init,\n", 273 | " (inputs.shape[-1], self.features))\n", 274 | " y = jnp.dot(inputs, kernel)\n", 275 | " bias = self.param('bias',\n", 276 | " self.bias_init,\n", 277 | " (self.features, ))\n", 278 | " y = y+bias\n", 279 | " return y " 280 | ], 281 | "execution_count": null, 282 | "outputs": [] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "metadata": { 287 | "colab": { 288 | "base_uri": "https://localhost:8080/" 289 | }, 290 | "id": "cB8JzfOk8uqz", 291 | "outputId": "70c70680-034a-4705-b890-c29456bdbe3a" 292 | }, 293 | "source": [ 294 | "key = random.PRNGKey(0)\n", 295 | "key1, key2 = random.split(key, 2)\n", 296 | "x = random.uniform(key1, (4,4))\n", 297 | "\n", 298 | "model = SimpleDense(features=3)\n", 299 | "params = model.init(key2, x)\n", 300 | "y = model.apply(params, x)\n", 301 | "\n", 302 | "print('initialised parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(params)))\n", 303 | "print('output shape:\\n', y.shape)" 304 | ], 305 | "execution_count": null, 306 | "outputs": [ 307 | { 308 | "output_type": "stream", 309 | "text": [ 310 | "initialised parameter shapes:\n", 311 | " {'params': {'bias': (3,), 'kernel': (4, 3)}}\n", 312 | "output shape:\n", 313 | " (4, 3)\n" 314 | ], 315 | "name": "stdout" 316 | } 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": { 322 | "id": "tzhNWbVQ_A2e" 323 | }, 324 | "source": [ 325 | "If the above model is implemented using ```setup()``` way, it won't be able to 
fill in the blank below as no input is available in ```setup()``` function." 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "metadata": { 331 | "id": "sSqCyMcC9Nff" 332 | }, 333 | "source": [ 334 | "class SimpleDense(nn.Module):\n", 335 | " features: int\n", 336 | " kernel_init: Callable = nn.initializers.lecun_normal()\n", 337 | " bias_init: Callable = nn.initializers.zeros\n", 338 | "\n", 339 | " def setup(self):\n", 340 | " self.kernel = self.param('kernel',\n", 341 | " self.kernel_init,\n", 342 | " (___________, self.features))\n", 343 | " bias = self.param('bias',\n", 344 | " self.bias_init,\n", 345 | " (self.features, ))\n", 346 | " @nn.compact\n", 347 | " def __call__(self, inputs):\n", 348 | " y = jnp.dot(inputs, self.kernel)+self.bias\n", 349 | " return y " 350 | ], 351 | "execution_count": null, 352 | "outputs": [] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": { 357 | "id": "5humUngRBYN1" 358 | }, 359 | "source": [ 360 | "* Following code shows how to define variables for a model, apart from its parameters. \n", 361 | "* The variables, like parameters, are stored in a tree. \n", 362 | "* And like parameters, are handled outside the class.\n", 363 | "* To define a variable, specify the entire path from root to the final variable. Here we have specified ```('batch_stats', 'mean')```.\n", 364 | "* Due to ```@nn.compact()``` the variables and parameters are only initalised and defined once. but all the operations specified are performed every time ```model.apply()``` is called." 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "metadata": { 370 | "id": "6cRejY3wBT-l" 371 | }, 372 | "source": [ 373 | "class BiasAdderWithRunningMean(nn.Module):\n", 374 | " decay: float = 0.99\n", 375 | "\n", 376 | " @nn.compact\n", 377 | " def __call__(self, x):\n", 378 | " is_initialized = self.has_variable('batch_stats', 'mean')\n", 379 | " ra_mean = self.variable('batch_stats', 'mean', #variable entire path name\n", 380 | " lambda s: jnp.zeros(s), #initialization function\n", 381 | " x.shape[1:]) #input to initialization function\n", 382 | " mean = ra_mean.value\n", 383 | " bias = self.param('bias', \n", 384 | " lambda rng, shape : jnp.zeros(shape), #Since it's a parameter, its lambda function must take rng and shape both. \n", 385 | " x.shape[1:])\n", 386 | " \n", 387 | " if is_initialized:\n", 388 | " ra_mean.value = self.decay * ra_mean.value\\\n", 389 | " + (1.0-self.decay)*jnp.mean(x, axis=0, keepdims=True)\n", 390 | "\n", 391 | " return x - ra_mean.value + bias " 392 | ], 393 | "execution_count": null, 394 | "outputs": [] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": { 399 | "id": "3lXEgMugKHDF" 400 | }, 401 | "source": [ 402 | "* The ```model.apply()``` call has been modified below. You must specify the mutable parameters of the model, and receive them in the output. \n", 403 | "\n", 404 | "* The variable ```y``` still contains, the value returned by the ```__call__``` function defined above.\n", 405 | "\n", 406 | "* ```model.init()``` returns all the initialized parameters, i.e., variables and params, both. All those are sent into the ```apply()``` call. (And hence they don't need to be initialised again in ```__call__```. )\n", 407 | "\n", 408 | "* Although the ```model.apply()``` returns updated variables, but still ```params_n_variables``` has the same old variables. As variables need to be handled outside the class too; so the variables in the ```params_n_variables``` need to be updated here too. 
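Putting this together, a cleaned-up sketch of the update loop run in the next cell (illustrative only; note that the next cell prints ```updated_state``` where ```updated_variables``` is meant, and advances the key with ```key2+i``` rather than the more idiomatic ```random.split```):

```python
key = random.PRNGKey(0)
key1, key2 = random.split(key)
x = random.uniform(key1, (5,))

model = BiasAdderWithRunningMean(decay=0.99)
variables = model.init(key2, x)          # contains both 'params' and 'batch_stats'

for _ in range(10):
    key2, subkey = random.split(key2)
    x = random.normal(subkey, (5,))
    y, updated_variables = model.apply(variables, x, mutable=['batch_stats'])
    # Merge the refreshed running statistics back next to the unchanged params.
    variables = freeze({**unfreeze(variables), **unfreeze(updated_variables)})
```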
" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "metadata": { 414 | "id": "un7GMwwbIIok" 415 | }, 416 | "source": [ 417 | "key = random.PRNGKey(0)\n", 418 | "key1, key2 = random.split(key, 2)\n", 419 | "x = random.uniform(key1, (5,))\n", 420 | "\n", 421 | "model = BiasAdderWithRunningMean(decay=0.99)\n", 422 | "params_n_variables = model.init(key2, x)\n", 423 | "print(params_n_variables)\n", 424 | "\n", 425 | "for i in range(10):\n", 426 | " x = random.normal(key2+i, (5,))\n", 427 | " \n", 428 | " y, updated_variables = model.apply(params_n_variables, x, mutable=['batch_stats'])\n", 429 | "\n", 430 | " old_variables, params = params_n_variables.pop('params') #remaining tree is first output and popped part(params) is the second\n", 431 | " params_n_variables = freeze({'params':params, **updated_variables}) #New tree being made from the available components\n", 432 | " \n", 433 | " print(updated_state)\n", 434 | "\n", 435 | "print('initialised parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(params)))\n", 436 | "print('output shape:\\n', y.shape)" 437 | ], 438 | "execution_count": null, 439 | "outputs": [] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": { 444 | "id": "w-bhdiKjUagA" 445 | }, 446 | "source": [ 447 | "#Optimizers in flax" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": { 453 | "id": "MBK-P4S5VyzV" 454 | }, 455 | "source": [ 456 | "The parameters of the model are stored in the optimizer and are available in ```optimizer.target``` ." 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "metadata": { 462 | "id": "A-ivwsUbJTcY" 463 | }, 464 | "source": [ 465 | "from flax import optim\n", 466 | "optimizer_def = optim.GradientDescent(learning_rate=0.01)\n", 467 | "optimizer = optimizer_def.create(params) #These params are stored within the class of optimizer and need not be handled outside.\n", 468 | "loss_grad_fn = jax.value_and_grad(loss) " 469 | ], 470 | "execution_count": null, 471 | "outputs": [] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "metadata": { 476 | "id": "quwD4QjRU79N" 477 | }, 478 | "source": [ 479 | "for i in range(101):\n", 480 | " loss_val, grad = loss_grad_fn(optimizer.target)\n", 481 | " optimizer = optimizer.apply_gradient(grad)" 482 | ], 483 | "execution_count": null, 484 | "outputs": [] 485 | } 486 | ] 487 | } -------------------------------------------------------------------------------- /jax_basic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "jax-basic.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyPehf0/MS0woXFWJe5XDjUT", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "oXjPlcOX2G3I" 32 | }, 33 | "source": [ 34 | "A notebook for this [blog](https://roberttlange.github.io/posts/2020/03/blog-post-10/) with additional notes. Implements MLP and CNN in ```JAX```. It is suggested to read that blog side-by-side. 
" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "metadata": { 40 | "id": "G5Q-brUn68dX" 41 | }, 42 | "source": [ 43 | "%matplotlib inline\n", 44 | "%config InlineBackend.figure_format = 'retina'\n", 45 | "\n", 46 | "import numpy as onp\n", 47 | "import jax.numpy as np\n", 48 | "from jax import grad, jit, vmap, value_and_grad\n", 49 | "from jax import random\n", 50 | "\n", 51 | "# Generate key which is used to generate random numbers\n", 52 | "key = random.PRNGKey(1) #A key is always an nd-array of size (2,) " 53 | ], 54 | "execution_count": null, 55 | "outputs": [] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "metadata": { 60 | "id": "U1I15wsT7C_H" 61 | }, 62 | "source": [ 63 | "# Generate a random matrix\n", 64 | "x = random.uniform(key, (1000, 1000))\n", 65 | "# Compare running times of 3 different matrix multiplications\n", 66 | "%time y = onp.dot(x, x)\n", 67 | "%time y = np.dot(x, x); print(y)\n", 68 | "%time y = np.dot(x, x).block_until_ready()" 69 | ], 70 | "execution_count": null, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "id": "-SbZO63W7_st" 77 | }, 78 | "source": [ 79 | "The above is due to [Asyncronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "id": "i5YUeW4V7QPR" 86 | }, 87 | "source": [ 88 | "def ReLU(x):\n", 89 | " \"\"\" Rectified Linear Unit (ReLU) activation function \"\"\"\n", 90 | " return np.maximum(0, x)\n", 91 | "\n", 92 | "jit_ReLU = jit(ReLU) " 93 | ], 94 | "execution_count": null, 95 | "outputs": [] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": { 100 | "id": "uLDbzLKu9HZQ" 101 | }, 102 | "source": [ 103 | "JIT a simple python function using numpy to make it faster. Normally, each operation has its own kernel which are dispatched to GPU, one by one. If we have a sequence of operations, we can use the ```@jit decorator / jit()``` to compile multiple operations together using XLA." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "metadata": { 109 | "colab": { 110 | "base_uri": "https://localhost:8080/" 111 | }, 112 | "id": "dm0Oo83u8TdW", 113 | "outputId": "4025d3d7-6a90-4a54-f951-ff7926162976" 114 | }, 115 | "source": [ 116 | "%time out = ReLU(x).block_until_ready()\n", 117 | "\n", 118 | "# Call jitted version to compile for evaluation time!\n", 119 | "%time jit_ReLU(x).block_until_ready() #First time call will cause compilation, and may take longer.\n", 120 | "%time out = jit_ReLU(x).block_until_ready()" 121 | ], 122 | "execution_count": null, 123 | "outputs": [ 124 | { 125 | "output_type": "stream", 126 | "text": [ 127 | "CPU times: user 60.6 ms, sys: 0 ns, total: 60.6 ms\n", 128 | "Wall time: 61.1 ms\n", 129 | "CPU times: user 26 ms, sys: 841 µs, total: 26.8 ms\n", 130 | "Wall time: 25.8 ms\n", 131 | "CPU times: user 1.21 ms, sys: 0 ns, total: 1.21 ms\n", 132 | "Wall time: 696 µs\n" 133 | ], 134 | "name": "stdout" 135 | } 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "pIMKKfKd98KT" 142 | }, 143 | "source": [ 144 | "The ```grad()``` function takes as input a function ```f``` and returns the function ``` f' ``` . This ```f'``` should be ```jit()```-ted again." 
145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "metadata": { 150 | "colab": { 151 | "base_uri": "https://localhost:8080/" 152 | }, 153 | "id": "JpKDUNP38hGD", 154 | "outputId": "f51a44c0-51a2-4e21-bbc6-27328acec43b" 155 | }, 156 | "source": [ 157 | "def FiniteDiffGrad(x):\n", 158 | " \"\"\" Compute the finite difference derivative approx for the ReLU\"\"\"\n", 159 | " return np.array((ReLU(x + 1e-3) - ReLU(x - 1e-3)) / (2 * 1e-3))\n", 160 | "\n", 161 | "# Compare the Jax gradient with a finite difference approximation\n", 162 | "print(\"Jax Grad: \", jit(grad(jit(ReLU)))(2.))\n", 163 | "print(\"FD Gradient:\", FiniteDiffGrad(2.))" 164 | ], 165 | "execution_count": null, 166 | "outputs": [ 167 | { 168 | "output_type": "stream", 169 | "text": [ 170 | "Jax Grad: 1.0\n", 171 | "FD Gradient: 0.99998707\n" 172 | ], 173 | "name": "stdout" 174 | } 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": { 180 | "id": "VlG3QboG-dRO" 181 | }, 182 | "source": [ 183 | "**vmap** - makes batching as easy as never before. While in PyTorch one always has to be careful over which dimension you want to perform computations, vmap lets you simply write your computations for a single sample case and afterwards wrap it to make it batch compatible. " 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "metadata": { 189 | "id": "bfS5MpeR96Jj" 190 | }, 191 | "source": [ 192 | "batch_dim = 32\n", 193 | "feature_dim = 100\n", 194 | "hidden_dim = 512\n", 195 | "\n", 196 | "# Generate a batch of vectors to process\n", 197 | "X = random.normal(key, (batch_dim, feature_dim))\n", 198 | "\n", 199 | "# Generate Gaussian weights and biases\n", 200 | "params = [random.normal(key, (hidden_dim, feature_dim)),\n", 201 | " random.normal(key, (hidden_dim, ))]\n", 202 | "\n", 203 | "def relu_layer(params, x):\n", 204 | " \"\"\" Simple ReLu layer for single sample \"\"\"\n", 205 | " return ReLU(np.dot(params[0], x) + params[1])\n", 206 | "\n", 207 | "def batch_version_relu_layer(params, x):\n", 208 | " \"\"\" Error prone batch version \"\"\"\n", 209 | " return ReLU(np.dot(X, params[0].T) + params[1])\n", 210 | "\n", 211 | "def vmap_relu_layer(params, x):\n", 212 | " \"\"\" vmap version of the ReLU layer \"\"\"\n", 213 | " return jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))\n", 214 | "\n", 215 | "out = np.stack([relu_layer(params, X[i, :]) for i in range(X.shape[0])])\n", 216 | "out = batch_version_relu_layer(params, X)\n", 217 | "out = vmap_relu_layer(params, X)" 218 | ], 219 | "execution_count": null, 220 | "outputs": [] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": { 225 | "id": "bBgq7oJIKGC5" 226 | }, 227 | "source": [ 228 | "```vmap``` wraps the ```relu_layer``` function and takes as an input the axis over which to batch the inputs. In our case the first input to ```relu_layer``` are the parameters which are the same for the entire batch [```(None)```]. The second input is the feature vector, ```x```. We have stacked the vectors into a matrix such that our input has dimensions ```(batch_dim, feature_dim)```. We therefore need to provide ```vmap``` with batch dimension ```(0)``` in order to properly parallelize the computations. ```out_axes``` then specifies how to stack the individual samples' outputs. In order to keep things consistent, we choose the first dimension to remain the batch dimension." 
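One detail worth flagging (an added note): as written in the cell above, ```vmap_relu_layer``` is a ```def``` that ignores its arguments and returns the wrapped function, so ```out = vmap_relu_layer(params, X)``` stores a function rather than activations. The intended pattern is the plain assignment:

```python
# Wrap once: params are broadcast to every sample (None), x is batched along axis 0.
vmap_relu_layer = jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))
out = vmap_relu_layer(params, X)   # shape (batch_dim, hidden_dim) = (32, 512)
```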
229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": { 234 | "id": "ZAlWJJKWMI8q" 235 | }, 236 | "source": [ 237 | "## MLP" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "metadata": { 243 | "id": "rTUvezvFKnOl" 244 | }, 245 | "source": [ 246 | "from jax.scipy.special import logsumexp\n", 247 | "from jax.experimental import optimizers\n", 248 | "\n", 249 | "import torch\n", 250 | "from torchvision import datasets, transforms\n", 251 | "\n", 252 | "import time" 253 | ], 254 | "execution_count": null, 255 | "outputs": [] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "metadata": { 260 | "id": "zJrb8CtwMTIg" 261 | }, 262 | "source": [ 263 | "batch_size = 100\n", 264 | "\n", 265 | "train_loader = torch.utils.data.DataLoader(\n", 266 | " datasets.MNIST('../data', train=True, download=True,\n", 267 | " transform=transforms.Compose([\n", 268 | " transforms.ToTensor(),\n", 269 | " transforms.Normalize((0.1307,), (0.3081,))\n", 270 | " ])),\n", 271 | " batch_size=batch_size, shuffle=True)\n", 272 | "\n", 273 | "test_loader = torch.utils.data.DataLoader(\n", 274 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 275 | " transforms.ToTensor(),\n", 276 | " transforms.Normalize((0.1307,), (0.3081,))\n", 277 | " ])),\n", 278 | " batch_size=batch_size, shuffle=True)" 279 | ], 280 | "execution_count": null, 281 | "outputs": [] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "metadata": { 286 | "colab": { 287 | "base_uri": "https://localhost:8080/" 288 | }, 289 | "id": "xZeZBPooOZLH", 290 | "outputId": "5ac0521b-5c56-4c62-d155-07198de28c3d" 291 | }, 292 | "source": [ 293 | "print(key)\n", 294 | "split = random.split(key, 5) #Can be split into any number of parts. New keys, along new axis\n", 295 | "print(split) \n", 296 | "print(random.split(split[0])) #Can only split \"keys\", i.e. , nd-array of size (2,)" 297 | ], 298 | "execution_count": null, 299 | "outputs": [ 300 | { 301 | "output_type": "stream", 302 | "text": [ 303 | "[0 1]\n", 304 | "[[3243370355 1344208528]\n", 305 | " [ 532076793 2354449600]\n", 306 | " [1813813011 1313272271]\n", 307 | " [3522235465 4107438537]\n", 308 | " [1531693580 2391939978]]\n", 309 | "[[1467608531 2825924092]\n", 310 | " [ 757006082 1868645737]]\n" 311 | ], 312 | "name": "stdout" 313 | } 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "id": "cqSnsYLyR72N" 320 | }, 321 | "source": [ 322 | "Since ```JAX``` offers only a functional programming interface, we can't write classes corresponding to modules, in ```JAX``` . We must write a function for initialization, and forward pass instead. 
" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "metadata": { 328 | "id": "Qy3Pytv4NW5g" 329 | }, 330 | "source": [ 331 | "def initialize_mlp(sizes, key):\n", 332 | " \"\"\" Initialize the weights of all layers of a linear layer network \"\"\"\n", 333 | "\n", 334 | " keys = random.split(key, len(sizes))\n", 335 | " \n", 336 | " # Initialize a single layer with Gaussian weights - helper function\n", 337 | " def initialize_layer(m, n, key, scale=1e-2):\n", 338 | " w_key, b_key = random.split(key)\n", 339 | " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", 340 | " \n", 341 | " return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", 342 | "\n", 343 | "\n", 344 | "layer_sizes = [784, 512, 512, 10]\n", 345 | "\n", 346 | "# Return a list of tuples of layer weights\n", 347 | "params = initialize_mlp(layer_sizes, key)\n" 348 | ], 349 | "execution_count": null, 350 | "outputs": [] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "id": "MtVhCfKlSnLQ" 356 | }, 357 | "source": [ 358 | "The forward passs functions should take as input all the parameters(```params```) of the model, and the input(```in_array```) to it. Usually, we make a dictionary of all the parameters, so that the function can access them easily." 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "metadata": { 364 | "id": "yxHYcJ8fQSoJ" 365 | }, 366 | "source": [ 367 | "def forward_pass(params, in_array):\n", 368 | " \"\"\" \n", 369 | " Compute the forward pass for each example individually.\n", 370 | " Inputs : params: List of tuples. Tuples must be as required by relu_layer.\n", 371 | " in_array: Input array as needed by relu_layer.\n", 372 | " \"\"\"\n", 373 | " activations = in_array\n", 374 | "\n", 375 | " # Loop over the ReLU hidden layers\n", 376 | " for w, b in params[:-1]:\n", 377 | " activations = relu_layer([w, b], activations)\n", 378 | "\n", 379 | " # Perform final trafo to logits\n", 380 | " final_w, final_b = params[-1]\n", 381 | " logits = np.dot(final_w, activations) + final_b #Feel free to use any jit-numpy operations in your functions, anywhere.\n", 382 | "\n", 383 | " return logits - logsumexp(logits) #Just simple softmax, it is. \n", 384 | "\n", 385 | "# Make a batched version of the `predict` function\n", 386 | "batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)\n" 387 | ], 388 | "execution_count": null, 389 | "outputs": [] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "metadata": { 394 | "id": "ZoCF9Cu2R5Fy" 395 | }, 396 | "source": [ 397 | "def one_hot(x, k, dtype=np.float32):\n", 398 | " \"\"\"Create a one-hot encoding of x of size k \"\"\"\n", 399 | " return np.array(x[:, None] == np.arange(k), dtype)\n", 400 | "\n", 401 | "def loss(params, in_arrays, targets):\n", 402 | " \"\"\" \n", 403 | " Compute the multi-class cross-entropy loss.\n", 404 | "\n", 405 | " Inputs : params: list of model parameters as accepted by forward_pass\n", 406 | " in_arrays: input_array as accepted by forward_pass\n", 407 | " targets: jit-numpy array containing one hot targets\n", 408 | " \"\"\"\n", 409 | " preds = batch_forward(params, in_arrays)\n", 410 | " return -np.sum(preds * targets) #Cross Entropy Loss. 
Divide by 784 to average.\n", 411 | "\n", 412 | "def accuracy(params, data_loader):\n", 413 | " \"\"\" Compute the accuracy for a provided dataloader \"\"\"\n", 414 | " acc_total = 0\n", 415 | " num_classes = 10\n", 416 | "\n", 417 | " for batch_idx, (data, target) in enumerate(data_loader):\n", 418 | " images = np.array(data).reshape(data.size(0), 28*28) #Need to make PyTorch tensors, into jit-numpy arrays\n", 419 | " targets = one_hot(np.array(target), num_classes)\n", 420 | "\n", 421 | " target_class = np.argmax(targets, axis=1)\n", 422 | " predicted_class = np.argmax(batch_forward(params, images), axis=1)\n", 423 | " acc_total += np.sum(predicted_class == target_class)\n", 424 | " return acc_total/len(data_loader.dataset)" 425 | ], 426 | "execution_count": null, 427 | "outputs": [] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "metadata": { 432 | "colab": { 433 | "base_uri": "https://localhost:8080/" 434 | }, 435 | "id": "OzuAhdz6UQlN", 436 | "outputId": "9ec82b1d-0077-4501-f0ba-603ca56ce437" 437 | }, 438 | "source": [ 439 | "x = np.arange(3)\n", 440 | "print(x.shape)\n", 441 | "print(x[None, :].shape)\n", 442 | "print(x[:,None].shape)\n", 443 | "print(x+x[None,:])\n", 444 | "print(x[None,:]+x[:,None])\n", 445 | "print(x+x[:,None])" 446 | ], 447 | "execution_count": null, 448 | "outputs": [ 449 | { 450 | "output_type": "stream", 451 | "text": [ 452 | "(3,)\n", 453 | "(1, 3)\n", 454 | "(3, 1)\n", 455 | "[[0 2 4]]\n", 456 | "[[0 1 2]\n", 457 | " [1 2 3]\n", 458 | " [2 3 4]]\n", 459 | "[[0 1 2]\n", 460 | " [1 2 3]\n", 461 | " [2 3 4]]\n" 462 | ], 463 | "name": "stdout" 464 | } 465 | ] 466 | }, 467 | { 468 | "cell_type": "markdown", 469 | "metadata": { 470 | "id": "dzzVCvqpcWqw" 471 | }, 472 | "source": [ 473 | "```value_and_grad(fn)``` returns a function that takes same arguments(```x```) as ```fn``` and returns both the return value(```fn(x)```) of ```fn``` and its gradient(```fn'(x)```), as a tuple. \n", 474 | "\n", 475 | "The optimizer below stores its data(parameters and hyperparameters) in ```opt_state``` and its functionality is defined in ```opt_update()```, ```opt_init()``` and ```get_params()``` . Notice how there is no class. It would be better to put all 4 things in a dicionary, hence." 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "metadata": { 481 | "id": "39Pivo8eUk9X" 482 | }, 483 | "source": [ 484 | "@jit\n", 485 | "def update(params, x, y, opt_state):\n", 486 | " \"\"\" \n", 487 | " Compute the gradient for a batch and update the parameters\n", 488 | "\n", 489 | " Inputs : params: list of model parameters as accepted by loss function (in turn by forward_pass)\n", 490 | " x: input as accepted by loss_function(in turn by forward_pass)\n", 491 | " y: jit-numpy array containing one hot targets(as required by loss function)\n", 492 | " opt_state: as required by opt_update\n", 493 | " Returns : \n", 494 | " updated parameters, current optimizer state, computed value\n", 495 | " \"\"\"\n", 496 | " value, grads = value_and_grad(loss)(params, x, y)\n", 497 | " opt_state = opt_update(0, grads, opt_state) #opt_update is a function, not a variable, hence is available in this scope, although not defined here.\n", 498 | " return get_params(opt_state), opt_state, value #The first argument to the opt_update function is the optimizer step number.\n", 499 | "\n", 500 | "# Defining an optimizer in Jax\n", 501 | "step_size = 1e-3\n", 502 | "opt_init, opt_update, get_params = optimizers.adam(step_size)\n", 503 | "opt_state = opt_init(params) #All the updatable parameters. 
First opt_state needs to be obtained this way, always.\n", 504 | "\n", 505 | "num_epochs = 10\n", 506 | "num_classes = 10" 507 | ], 508 | "execution_count": null, 509 | "outputs": [] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": { 514 | "id": "tnZioQ7tlGzx" 515 | }, 516 | "source": [ 517 | "Notice how in all the above code, each function tries to make sure that its input fits well with the functions that it is calling. And this leads to a hierarchical structure, in stark comparison to the step-wise structure of PyTorch code. " 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "metadata": { 523 | "colab": { 524 | "base_uri": "https://localhost:8080/" 525 | }, 526 | "id": "nY3bya5GeQBF", 527 | "outputId": "0e85d070-9a6b-4784-c3e4-8fe676b4197e" 528 | }, 529 | "source": [ 530 | "def run_mnist_training_loop(num_epochs, opt_state, net_type=\"MLP\"):\n", 531 | " \"\"\" Implements a learning loop over epochs. \"\"\"\n", 532 | "\n", 533 | " # Initialize placeholder for logging\n", 534 | " log_acc_train, log_acc_test, train_loss = [], [], []\n", 535 | "\n", 536 | " # Get the initial set of parameters\n", 537 | " params = get_params(opt_state) #Assumes all parameters are updatable. Otherwise send as argument in this function. \n", 538 | "\n", 539 | " # Get initial accuracy after random init\n", 540 | " train_acc = accuracy(params, train_loader)\n", 541 | " test_acc = accuracy(params, test_loader)\n", 542 | " log_acc_train.append(train_acc)\n", 543 | " log_acc_test.append(test_acc)\n", 544 | "\n", 545 | " # Loop over the training epochs\n", 546 | " for epoch in range(num_epochs):\n", 547 | " start_time = time.time()\n", 548 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 549 | " if net_type == \"MLP\":\n", 550 | " # Flatten the image into 784-sized vectors for the MLP\n", 551 | " x = np.array(data).reshape(data.size(0), 28*28)\n", 552 | " elif net_type == \"CNN\":\n", 553 | " # No flattening of the input required for the CNN\n", 554 | " x = np.array(data)\n", 555 | " y = one_hot(np.array(target), num_classes)\n", 556 | " params, opt_state, loss = update(params, x, y, opt_state)\n", 557 | " train_loss.append(loss)\n", 558 | "\n", 559 | " epoch_time = time.time() - start_time\n", 560 | " train_acc = accuracy(params, train_loader)\n", 561 | " test_acc = accuracy(params, test_loader)\n", 562 | " log_acc_train.append(train_acc)\n", 563 | " log_acc_test.append(test_acc)\n", 564 | " print(\"Epoch {} | T: {:0.2f} | Train A: {:0.3f} | Test A: {:0.3f}\".format(epoch+1, epoch_time,\n", 565 | " train_acc, test_acc))\n", 566 | "\n", 567 | " return train_loss, log_acc_train, log_acc_test\n", 568 | "\n", 569 | "\n", 570 | "train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,\n", 571 | " opt_state,\n", 572 | " net_type=\"MLP\")\n" 573 | ], 574 | "execution_count": null, 575 | "outputs": [ 576 | { 577 | "output_type": "stream", 578 | "text": [ 579 | "Epoch 1 | T: 16.56 | Train A: 0.973 | Test A: 0.968\n", 580 | "Epoch 2 | T: 15.61 | Train A: 0.984 | Test A: 0.974\n", 581 | "Epoch 3 | T: 15.55 | Train A: 0.990 | Test A: 0.979\n", 582 | "Epoch 4 | T: 15.63 | Train A: 0.993 | Test A: 0.981\n", 583 | "Epoch 5 | T: 15.41 | Train A: 0.992 | Test A: 0.978\n", 584 | "Epoch 6 | T: 15.13 | Train A: 0.997 | Test A: 0.982\n", 585 | "Epoch 7 | T: 15.27 | Train A: 0.996 | Test A: 0.980\n", 586 | "Epoch 8 | T: 15.93 | Train A: 0.996 | Test A: 0.980\n", 587 | "Epoch 9 | T: 15.80 | Train A: 0.995 | Test A: 0.981\n", 588 | "Epoch 10 | T: 15.36 | Train A: 0.997 | Test A: 
0.982\n" 589 | ], 590 | "name": "stdout" 591 | } 592 | ] 593 | }, 594 | { 595 | "cell_type": "markdown", 596 | "metadata": { 597 | "id": "wYmvwp9RloXc" 598 | }, 599 | "source": [ 600 | "# CNN" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "metadata": { 606 | "id": "S8h_tPpakkU7" 607 | }, 608 | "source": [ 609 | "from jax.experimental import stax\n", 610 | "from jax.experimental.stax import (BatchNorm, Conv, Dense, Flatten,\n", 611 | " Relu, LogSoftmax)" 612 | ], 613 | "execution_count": null, 614 | "outputs": [] 615 | }, 616 | { 617 | "cell_type": "markdown", 618 | "metadata": { 619 | "id": "QOMdDw6hnlsy" 620 | }, 621 | "source": [ 622 | "The ```init_fun()``` below takes the ```key``` and the shape of input as its arguments. It returns the output shape and the randomly assigned parameters. \n", 623 | "\n", 624 | "The ```conv_net()``` function takes ```params``` and input of the shape specified in second argument of ```init_fun()``` and returns the result of the convolution operations specified in ```stax.serial()```. Note that if it is a function that returns ```f(x)``` , you can quickly make another one to get ```f'(x)``` ." 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "metadata": { 630 | "id": "STD7OZT3l1JT" 631 | }, 632 | "source": [ 633 | "init_fun, conv_net = stax.serial(Conv(32, (5, 5), (2, 2), padding=\"SAME\"), #First argument is number of out channels, second is filter shape, third stride.\n", 634 | " BatchNorm(), Relu,\n", 635 | " Conv(32, (5, 5), (2, 2), padding=\"SAME\"),\n", 636 | " BatchNorm(), Relu,\n", 637 | " Conv(10, (3, 3), (2, 2), padding=\"SAME\"),\n", 638 | " BatchNorm(), Relu,\n", 639 | " Conv(10, (3, 3), (2, 2), padding=\"SAME\"), Relu,\n", 640 | " Flatten,\n", 641 | " Dense(num_classes), #Only final size needs to be specified !! \n", 642 | " LogSoftmax)\n", 643 | "\n", 644 | "output_shape, params = init_fun(key, (batch_size, 1, 28, 28))" 645 | ], 646 | "execution_count": null, 647 | "outputs": [] 648 | }, 649 | { 650 | "cell_type": "markdown", 651 | "metadata": { 652 | "id": "U_zUMJ4YqFO5" 653 | }, 654 | "source": [ 655 | "Various types of initializations can also be specified for each layer. See [here](https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#serial) for default initializations of each layer. 
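For example (a sketch, not part of the notebook; `custom_init_fun` / `custom_net` are illustrative names), a layer's initializers can be overridden by passing `W_init` / `b_init` when constructing it:

```python
from jax.nn.initializers import glorot_normal, normal

# A Dense layer with explicitly chosen initializers instead of the stax defaults.
custom_init_fun, custom_net = stax.serial(
    Dense(num_classes, W_init=glorot_normal(), b_init=normal(1e-6)),
    LogSoftmax)

out_shape, custom_params = custom_init_fun(key, (batch_size, 784))
```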
" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "metadata": { 661 | "id": "7kH0B_kXpJ2o" 662 | }, 663 | "source": [ 664 | "def accuracy(params, data_loader):\n", 665 | " \"\"\" Compute the accuracy for the CNN case (no flattening of input)\"\"\"\n", 666 | " acc_total = 0\n", 667 | " for batch_idx, (data, target) in enumerate(data_loader):\n", 668 | " images = np.array(data)\n", 669 | " targets = one_hot(np.array(target), num_classes)\n", 670 | "\n", 671 | " target_class = np.argmax(targets, axis=1)\n", 672 | " predicted_class = np.argmax(conv_net(params, images), axis=1)\n", 673 | " acc_total += np.sum(predicted_class == target_class)\n", 674 | " return acc_total/len(data_loader.dataset)\n", 675 | "\n", 676 | "def loss(params, images, targets):\n", 677 | " preds = conv_net(params, images)\n", 678 | " return -np.sum(preds * targets)" 679 | ], 680 | "execution_count": null, 681 | "outputs": [] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "metadata": { 686 | "id": "DSD2aU40w2Ut" 687 | }, 688 | "source": [ 689 | "step_size = 1e-3\n", 690 | "opt_init, opt_update, get_params = optimizers.adam(step_size)\n", 691 | "opt_state = opt_init(params)\n", 692 | "num_epochs = 10\n", 693 | "\n", 694 | "train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,\n", 695 | " opt_state,\n", 696 | " net_type=\"CNN\")" 697 | ], 698 | "execution_count": null, 699 | "outputs": [] 700 | } 701 | ] 702 | } -------------------------------------------------------------------------------- /basic_transformer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "basic_transformer.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "jAFyvoEV1WRB", 10 | "Ga6G28Qe1j0v", 11 | "RseynpJDHYtj", 12 | "J8vTrTw6c_OR" 13 | ], 14 | "authorship_tag": "ABX9TyP5VPOjDaBymJwBxJP5Gz8R", 15 | "include_colab_link": true 16 | }, 17 | "kernelspec": { 18 | "name": "python3", 19 | "display_name": "Python 3" 20 | }, 21 | "accelerator": "GPU" 22 | }, 23 | "cells": [ 24 | { 25 | "cell_type": "markdown", 26 | "metadata": { 27 | "id": "view-in-github", 28 | "colab_type": "text" 29 | }, 30 | "source": [ 31 | "\"Open" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "metadata": { 37 | "id": "wMwOcI23zjM-" 38 | }, 39 | "source": [ 40 | "!pip install git+https://github.com/deepmind/dm-haiku\n", 41 | "!pip install transformers\n", 42 | "!pip install git+git://github.com/deepmind/optax.git" 43 | ], 44 | "execution_count": null, 45 | "outputs": [] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "metadata": { 50 | "id": "cincO7Umzsu1" 51 | }, 52 | "source": [ 53 | "import haiku as hk\n", 54 | "import jax.numpy as jnp\n", 55 | "import jax\n", 56 | "\n", 57 | "from jax import jit\n", 58 | "from jax.random import PRNGKey\n", 59 | "import numpy as np" 60 | ], 61 | "execution_count": null, 62 | "outputs": [] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": { 67 | "id": "2LD9FiTbzyjU" 68 | }, 69 | "source": [ 70 | "#Transformers-Classification Using pre-trained weights from RoBERTa" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "id": "FU2MeYIv0BHT" 77 | }, 78 | "source": [ 79 | "## Embedding Layers" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "id": "I7w2UjDAzxcq" 86 | }, 87 | "source": [ 88 | "from transformers import RobertaModel\n", 89 | "\n", 90 | "class Embedding(hk.Module):\n", 91 | " def __init__(self, config):\n", 92 | " 
super().__init__()\n", 93 | " self.config = config\n", 94 | "\n", 95 | " def __call__(self, token_ids, training=False):\n", 96 | " \"\"\"\n", 97 | " token_ids: ints of shape (batch, n_seq)\n", 98 | " \"\"\"\n", 99 | " word_embeddings = self.config['pretrained']['embeddings/word_embeddings']\n", 100 | " \n", 101 | " flat_token_ids = jnp.reshape(token_ids, [-1])\n", 102 | " \n", 103 | " flat_token_embeddings = hk.Embed(vocab_size=word_embeddings.shape[0],\n", 104 | " embed_dim=word_embeddings.shape[1],\n", 105 | " w_init=hk.initializers.Constant(word_embeddings))(flat_token_ids)\n", 106 | "\n", 107 | " token_embeddings = jnp.reshape(flat_token_embeddings, [token_ids.shape[0], -1, word_embeddings.shape[1]])\n", 108 | " \n", 109 | " embeddings = token_embeddings + PositionEmbeddings(self.config)()\n", 110 | "\n", 111 | " embeddings = hk.LayerNorm(axis=-1,\n", 112 | " create_scale=True,\n", 113 | " create_offset=True,\n", 114 | " scale_init=hk.initializers.Constant(self.config['pretrained']['embeddings/LayerNorm/gamma']),\n", 115 | " offset_init=hk.initializers.Constant(self.config['pretrained']['embeddings/LayerNorm/beta']))(embeddings)\n", 116 | " if training:\n", 117 | " embeddings = hk.dropout(hk.next_rng_key(),\n", 118 | " rate=self.config['embed_dropout_rate'],\n", 119 | " x=embeddings)\n", 120 | " \n", 121 | " return embeddings" 122 | ], 123 | "execution_count": null, 124 | "outputs": [] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "metadata": { 129 | "id": "kAoqTp6Uz7ta" 130 | }, 131 | "source": [ 132 | "class PositionEmbeddings(hk.Module):\n", 133 | " \"\"\"\n", 134 | " A position embedding of size [max_seq_leq, word_embedding_dim]\n", 135 | " \"\"\"\n", 136 | " def __init__(self, config):\n", 137 | " super().__init__()\n", 138 | " self.config = config\n", 139 | " self.offset = 2\n", 140 | "\n", 141 | " def __call__(self):\n", 142 | " pretrained_position_embeddings = self.config['pretrained']['embeddings/position_embeddings']\n", 143 | "\n", 144 | " position_weights = hk.get_parameter(\"position_embeddings\",\n", 145 | " pretrained_position_embeddings.shape,\n", 146 | " init=hk.initializers.Constant(pretrained_position_embeddings))\n", 147 | " \n", 148 | " start = self.offset\n", 149 | " end = self.offset+self.config['max_length']\n", 150 | " \n", 151 | " return position_weights[start:end]" 152 | ], 153 | "execution_count": null, 154 | "outputs": [] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": { 159 | "id": "tZgxynPM0a5G" 160 | }, 161 | "source": [ 162 | "## Tokenizer and Utilities for Downloading and Extracting pre-trained weights" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "metadata": { 168 | "id": "Iu1HMppK0Fzh" 169 | }, 170 | "source": [ 171 | "from io import BytesIO\n", 172 | "from functools import lru_cache\n", 173 | "\n", 174 | "import joblib\n", 175 | "import requests\n", 176 | "\n", 177 | "from transformers import RobertaModel, RobertaTokenizer\n", 178 | "\n", 179 | "huggingface_roberta = RobertaModel.from_pretrained('roberta-base', output_hidden_states=True)\n", 180 | "\n", 181 | "huggingface_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n" 182 | ], 183 | "execution_count": null, 184 | "outputs": [] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "metadata": { 189 | "id": "kFw7QoSj0IRe" 190 | }, 191 | "source": [ 192 | "def postprocess_key(key):\n", 193 | " key = key.replace('model/featurizer/bert/', '')\n", 194 | " key = key.replace(':0', '')\n", 195 | " key = key.replace('self/', '')\n", 196 | " return key" 197 | ], 
198 | "execution_count": null, 199 | "outputs": [] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "metadata": { 204 | "id": "eEr4QWMj0L6O" 205 | }, 206 | "source": [ 207 | "@lru_cache()\n", 208 | "def get_pretrained_weights():\n", 209 | " # We'll use the weight dictionary from the Roberta encoder at \n", 210 | " # https://github.com/IndicoDataSolutions/finetune\n", 211 | " remote_url = \"https://bendropbox.s3.amazonaws.com/roberta/roberta-model-sm-v2.jl\"\n", 212 | " weights = joblib.load(BytesIO(requests.get(remote_url).content))\n", 213 | "\n", 214 | " weights = {\n", 215 | " postprocess_key(key): value\n", 216 | " for key, value in weights.items()\n", 217 | " }\n", 218 | "\n", 219 | " input_embeddings = huggingface_roberta.get_input_embeddings()\n", 220 | " weights['embeddings/word_embeddings'] = input_embeddings.weight.detach().numpy()\n", 221 | "\n", 222 | " return weights\n" 223 | ], 224 | "execution_count": null, 225 | "outputs": [] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "metadata": { 230 | "id": "IsfSZH7C0YY6" 231 | }, 232 | "source": [ 233 | "class Scope(object):\n", 234 | " \"\"\"\n", 235 | " A tiny utility to help make looking up into our dictionary cleaner.\n", 236 | " There's no haiku magic here.\n", 237 | " \"\"\"\n", 238 | " def __init__(self, weights, prefix):\n", 239 | " self.weights = weights\n", 240 | " self.prefix = prefix\n", 241 | "\n", 242 | " def __getitem__(self, key):\n", 243 | " return self.weights[self.prefix + key]" 244 | ], 245 | "execution_count": null, 246 | "outputs": [] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": { 251 | "id": "BMhZBNCE0nH8" 252 | }, 253 | "source": [ 254 | "##Running the Embedding Layers" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "metadata": { 260 | "id": "G5DPXgub0jIu" 261 | }, 262 | "source": [ 263 | "sample_text = \"This was a flower of evil.\"\n", 264 | "\n", 265 | "\n", 266 | "config = {'pretrained' : get_pretrained_weights(),\n", 267 | " 'max_length' : 512,\n", 268 | " 'embed_dropout_rate' : 0.1\n", 269 | " }\n", 270 | "\n", 271 | "encoded = huggingface_tokenizer.batch_encode_plus([sample_text, sample_text],\n", 272 | " padding='max_length',\n", 273 | " max_length=config['max_length'])\n", 274 | "\n", 275 | "sample_tokens = encoded['input_ids']" 276 | ], 277 | "execution_count": null, 278 | "outputs": [] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "metadata": { 283 | "id": "xdEjN5KY0leO" 284 | }, 285 | "source": [ 286 | "\n", 287 | "\n", 288 | "def embed_fn(tokens, training=False) :\n", 289 | " embedding = Embedding(config)(tokens)\n", 290 | " return embedding\n", 291 | "\n", 292 | "rng = PRNGKey(42)\n", 293 | "embed = hk.transform(embed_fn, apply_rng=True)\n", 294 | "sample_tokens = np.asarray(sample_tokens)\n", 295 | "params = embed.init(rng, sample_tokens, training=False)\n", 296 | "embedded_tokens = jit(embed.apply)(params, rng, sample_tokens, training=False)" 297 | ], 298 | "execution_count": null, 299 | "outputs": [] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": { 304 | "id": "jNRxih2y0vI2" 305 | }, 306 | "source": [ 307 | "## Transformer Block" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "metadata": { 313 | "id": "Tz6FnSL_00Md" 314 | }, 315 | "source": [ 316 | "class TransformerBlock(hk.Module):\n", 317 | "\n", 318 | " def __init__(self, config, layer_num):\n", 319 | " super().__init__()\n", 320 | " self.config = config\n", 321 | " self.n = layer_num\n", 322 | "\n", 323 | " def __call__(self, x, mask, training = False):\n", 324 | "\n", 
325 | " scope = Scope(\n", 326 | " self.config['pretrained'], f'encoder/layer_{self.n}/'\n", 327 | " )\n", 328 | "\n", 329 | " attention_output = MultiHeadAttention(self.config,\n", 330 | " self.n)(x, mask, training=training)\n", 331 | " \n", 332 | " residual = attention_output+x\n", 333 | "\n", 334 | " attention_output = hk.LayerNorm(axis=-1,\n", 335 | " create_scale=True,\n", 336 | " create_offset=True,\n", 337 | " scale_init=hk.initializers.Constant(scope['attention/output/LayerNorm/gamma']),\n", 338 | " offset_init=hk.initializers.Constant(scope['attention/output/LayerNorm/beta']),)(residual)\n", 339 | "\n", 340 | " mlp_output = TransformerMLP(self.config, self.n)(attention_output, training=training)\n", 341 | "\n", 342 | " output_residual = mlp_output+attention_output\n", 343 | "\n", 344 | " layer_output = hk.LayerNorm(axis=-1,\n", 345 | " create_scale=True,\n", 346 | " create_offset=True,\n", 347 | " scale_init=hk.initializers.Constant(scope['output/LayerNorm/gamma']),\n", 348 | " offset_init=hk.initializers.Constant(scope['output/LayerNorm/beta']))(output_residual)\n", 349 | " \n", 350 | " return layer_output" 351 | ], 352 | "execution_count": null, 353 | "outputs": [] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": { 358 | "id": "bD18sx8U01d8" 359 | }, 360 | "source": [ 361 | "## Multi-Head Attention" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "metadata": { 367 | "id": "cYPItvV700_9" 368 | }, 369 | "source": [ 370 | "class MultiHeadAttention(hk.Module):\n", 371 | " def __init__(self, config, layer_num):\n", 372 | " super().__init__()\n", 373 | " self.config = config\n", 374 | " self.n = layer_num\n", 375 | "\n", 376 | " def _split_into_heads(self, x):\n", 377 | " return jnp.reshape(x, [x.shape[0], x.shape[1], self.config['n_heads'], x.shape[2]//self.config['n_heads']])\n", 378 | "\n", 379 | " def __call__(self, x, mask, training=False):\n", 380 | " \n", 381 | " scope = Scope(self.config['pretrained'], f'encoder/layer_{self.n}/attention/')\n", 382 | "\n", 383 | " queries = hk.Linear(output_size=self.config['hidden_size'],\n", 384 | " w_init=hk.initializers.Constant(scope['query/kernel']),\n", 385 | " b_init=hk.initializers.Constant(scope['query/bias']))(x)\n", 386 | " \n", 387 | " keys = hk.Linear(output_size=self.config['hidden_size'],\n", 388 | " w_init=hk.initializers.Constant(scope['key/kernel']),\n", 389 | " b_init=hk.initializers.Constant(scope['key/bias']))(x)\n", 390 | " \n", 391 | " values = hk.Linear(output_size=self.config['hidden_size'],\n", 392 | " w_init=hk.initializers.Constant(scope['value/kernel']),\n", 393 | " b_init=hk.initializers.Constant(scope['value/bias']))(x)\n", 394 | " \n", 395 | " queries = self._split_into_heads(queries)\n", 396 | " keys = self._split_into_heads(keys)\n", 397 | " values = self._split_into_heads(values)\n", 398 | "\n", 399 | " attention_logits = jnp.einsum('bsnh,btnh->bnst', queries, keys)\n", 400 | " attention_logits /= np.sqrt(queries.shape[-1])\n", 401 | "\n", 402 | " attention_logits += jnp.reshape(mask*-2**32, [mask.shape[0],1,1,mask.shape[1]])\n", 403 | " attention_weights = jax.nn.softmax(attention_logits, axis=-1)\n", 404 | " per_head_attention_output = jnp.einsum('btnh,bnst->bsnh', values, attention_weights)\n", 405 | " attention_output = jnp.reshape(per_head_attention_output, [per_head_attention_output.shape[0], per_head_attention_output.shape[1], -1])\n", 406 | "\n", 407 | " attention_output = hk.Linear(output_size=self.config['hidden_size'],\n", 408 | " 
w_init=hk.initializers.Constant(scope['output/dense/kernel']),\n", 409 | " b_init=hk.initializers.Constant(scope['output/dense/bias']))(attention_output)\n", 410 | " \n", 411 | " if training:\n", 412 | " attention_output = hk.dropout(rng=hk.next_rng_key(),\n", 413 | " rate=self.config['attention_drop_rate'],\n", 414 | " x=attention_output)\n", 415 | " \n", 416 | " return attention_output" 417 | ], 418 | "execution_count": null, 419 | "outputs": [] 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "metadata": { 424 | "id": "xH1rORmz08oU" 425 | }, 426 | "source": [ 427 | "## Transformer MLP" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "metadata": { 433 | "id": "ehX4oUHY1ALU" 434 | }, 435 | "source": [ 436 | "def gelu(x):\n", 437 | " return x*0.5*(1.0+jax.scipy.special.erf(x / jnp.sqrt(2.0)))\n", 438 | "\n", 439 | "class TransformerMLP(hk.Module):\n", 440 | "\n", 441 | " def __init__(self, config, layer_num):\n", 442 | " super().__init__()\n", 443 | " self.config = config\n", 444 | " self.n = layer_num\n", 445 | "\n", 446 | " def __call__(self, x, training=False):\n", 447 | "\n", 448 | " scope = Scope(self.config['pretrained'], f'encoder/layer_{self.n}/')\n", 449 | "\n", 450 | " intermediate_output = hk.Linear(output_size=self.config['intermediate_size'],\n", 451 | " w_init=hk.initializers.Constant(scope['intermediate/dense/kernel']),\n", 452 | " b_init=hk.initializers.Constant(scope['intermediate/dense/bias']))(x)\n", 453 | "\n", 454 | " intermediate_output = gelu(intermediate_output)\n", 455 | "\n", 456 | " output = hk.Linear(output_size=self.config['hidden_size'],\n", 457 | " w_init=hk.initializers.Constant(scope['output/dense/kernel']),\n", 458 | " b_init=hk.initializers.Constant(scope['output/dense/bias']))(intermediate_output)\n", 459 | " \n", 460 | " if training:\n", 461 | " output = hk.dropout(rng=hk.next_rng_key(),\n", 462 | " rate=self.config['fully_connected_drop_rate'],\n", 463 | " x=output)\n", 464 | " \n", 465 | " return output" 466 | ], 467 | "execution_count": null, 468 | "outputs": [] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": { 473 | "id": "xblfTQfw1Jh8" 474 | }, 475 | "source": [ 476 | "## Confg and Getting Features from the model" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "metadata": { 482 | "id": "QFkccsZI1DpG" 483 | }, 484 | "source": [ 485 | "class RobertaFeaturizer(hk.Module):\n", 486 | " def __init__(self, config):\n", 487 | " super().__init__()\n", 488 | " self.config = config\n", 489 | "\n", 490 | " def __call__(self, token_ids, training=False):\n", 491 | " x = Embedding(self.config)(token_ids, training=training)\n", 492 | " mask = (token_ids==self.config['mask_id']).astype(jnp.float32)\n", 493 | " for layer_num in range(self.config['n_layers']):\n", 494 | " x = TransformerBlock(config, layer_num=layer_num)(x,mask,training)\n", 495 | " return x" 496 | ], 497 | "execution_count": null, 498 | "outputs": [] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "metadata": { 503 | "colab": { 504 | "base_uri": "https://localhost:8080/" 505 | }, 506 | "id": "6CYk3wgQ1Ihs", 507 | "outputId": "99ab8e37-64b2-4f47-f640-a0a8d5f0ba0e" 508 | }, 509 | "source": [ 510 | "config = {\n", 511 | " 'pretrained' : config['pretrained'], \n", 512 | " 'max_length' : config['max_length'], \n", 513 | " 'embed_dropout_rate' : 0.1,\n", 514 | " 'fully_connected_drop_rate' : 0.1,\n", 515 | " 'attention_drop_rate' : 0.1,\n", 516 | " 'hidden_size' : 768,\n", 517 | " 'intermediate_size' : 3072,\n", 518 | " 'n_heads' : 12,\n", 519 | " 'n_layers' : 
12,\n", 520 | " 'mask_id' : 1,\n", 521 | " 'weight_stddev' : 0.02,\n", 522 | " \n", 523 | " 'n_classes' : 2,\n", 524 | " 'classifier_drop_rate' : 0.1,\n", 525 | " 'learning_rate' : 1e-5,\n", 526 | " 'max_grad_norm' : 1.0,\n", 527 | " 'l2' : 0.1,\n", 528 | " 'n_epochs' : 5,\n", 529 | " 'batch_size' : 4\n", 530 | " }\n", 531 | "\n", 532 | "def featurizer_fn(tokens, training=False):\n", 533 | " contextual_embeddings = RobertaFeaturizer(config)(tokens, training=training)\n", 534 | " return contextual_embeddings\n", 535 | "\n", 536 | "rng = PRNGKey(42)\n", 537 | "roberta = hk.transform(featurizer_fn)\n", 538 | "sample_tokens = np.asarray(sample_tokens)\n", 539 | "params = roberta.init(rng, sample_tokens, training=False)\n", 540 | "contextual_embeddings = jit(roberta.apply)(params, rng, sample_tokens)\n", 541 | "print(contextual_embeddings.shape)" 542 | ], 543 | "execution_count": null, 544 | "outputs": [ 545 | { 546 | "output_type": "stream", 547 | "text": [ 548 | "(2, 512, 768)\n" 549 | ], 550 | "name": "stdout" 551 | } 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "metadata": { 557 | "id": "WtJaXolI-IL2" 558 | }, 559 | "source": [ 560 | "## Getting Data" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "metadata": { 566 | "id": "1OQUUQ9B-MEu" 567 | }, 568 | "source": [ 569 | "import tensorflow_datasets as tfds\n", 570 | "\n", 571 | "def load_dataset(split, training, batch_size, n_epochs=1, n_examples=None):\n", 572 | " ds = tfds.load(\"imdb_reviews\", \n", 573 | " split=f\"{split}[:{n_examples}]\").cache().repeat(n_epochs)\n", 574 | " \n", 575 | " if training:\n", 576 | " ds = ds.shuffle(10*batch_size, seed=0)\n", 577 | " \n", 578 | " ds = ds.batch(batch_size)\n", 579 | "\n", 580 | " return tfds.as_numpy(ds)" 581 | ], 582 | "execution_count": null, 583 | "outputs": [] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "metadata": { 588 | "id": "Vb7W-Jf-_KfI" 589 | }, 590 | "source": [ 591 | "n_examples = 25000\n", 592 | "train = load_dataset('train', training=True, batch_size=4, n_epochs=config['n_epochs'],n_examples=n_examples)" 593 | ], 594 | "execution_count": null, 595 | "outputs": [] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "metadata": { 600 | "id": "Lrd1yBJwAsEt" 601 | }, 602 | "source": [ 603 | "def encode_batch(batch_text):\n", 604 | " batch_text = [\n", 605 | " text[:512].decode('utf-8') if isinstance(text, bytes) else text[:512]\n", 606 | " for text in batch_text\n", 607 | " ]\n", 608 | " \n", 609 | " token_ids = huggingface_tokenizer.batch_encode_plus(batch_text,\n", 610 | " padding='max_length',\n", 611 | " max_length=config['max_length'],\n", 612 | " )['input_ids']\n", 613 | " \n", 614 | " return np.asarray(token_ids)" 615 | ], 616 | "execution_count": null, 617 | "outputs": [] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "metadata": { 622 | "id": "jAFyvoEV1WRB" 623 | }, 624 | "source": [ 625 | "## The classifier" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "metadata": { 631 | "id": "ce1B55SL1R2z" 632 | }, 633 | "source": [ 634 | "class RobertaClassifier(hk.Module):\n", 635 | " def __init__(self, config):\n", 636 | " super().__init__()\n", 637 | " self.config = config\n", 638 | "\n", 639 | " def __call__(self, token_ids, training=False):\n", 640 | " sequence_features = RobertaFeaturizer(self.config)(token_ids=token_ids, training=training)\n", 641 | "\n", 642 | " clf_state = sequence_features[:,0,:]\n", 643 | "\n", 644 | " if training:\n", 645 | " clf_state = hk.dropout(rng=hk.next_rng_key(),\n", 646 | " 
rate=self.config['classifier_drop_rate'],\n", 647 | " x=clf_state)\n", 648 | " \n", 649 | " clf_logits = hk.Linear(output_size=self.config['n_classes'],\n", 650 | " w_init=hk.initializers.TruncatedNormal(self.config['weight_stddev']))(clf_state)\n", 651 | "\n", 652 | " return clf_logits" 653 | ], 654 | "execution_count": null, 655 | "outputs": [] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "metadata": { 660 | "id": "Ga6G28Qe1j0v" 661 | }, 662 | "source": [ 663 | "## Running the Classifier" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "metadata": { 669 | "id": "K7bZRSRH1o_R" 670 | }, 671 | "source": [ 672 | "def roberta_classification_fn(batch_token_ids, training):\n", 673 | " logits = RobertaClassifier(config)(batch_token_ids, training=training)\n", 674 | " return logits\n", 675 | "\n", 676 | "rng = jax.random.PRNGKey(42)\n", 677 | "roberta_classifier = hk.transform(roberta_classification_fn) \n", 678 | "\n", 679 | "params = roberta_classifier.init(rng, \n", 680 | " batch_token_ids=encode_batch(['sample sentence', 'Another one!']),\n", 681 | " training=True)\n" 682 | ], 683 | "execution_count": null, 684 | "outputs": [] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "metadata": { 689 | "id": "0isUhSMSD-ZR" 690 | }, 691 | "source": [ 692 | "```roberta_classifier.init()``` and ```roberta_classifier.apply()``` are pure functions now. So, can be composed to gether and used with other functions. " 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "metadata": { 698 | "id": "lcbdVKceDlyl" 699 | }, 700 | "source": [ 701 | "def loss(params, rng, batch_token_ids, batch_labels):\n", 702 | " logits = roberta_classifier.apply(params, rng, batch_token_ids, training=True)\n", 703 | " labels = hk.one_hot(batch_labels, config['n_classes'])\n", 704 | " softmax_xent = -jnp.sum(labels*jax.nn.log_softmax(logits))\n", 705 | " softmax_xent /= labels.shape[0]\n", 706 | " return softmax_xent\n", 707 | "\n", 708 | "@jax.jit\n", 709 | "def accuracy(params, rng, batch_token_ids, batch_labels):\n", 710 | " logits = roberta_classifier.apply(params, rng, batch_token_ids, training=False)\n", 711 | " return jnp.mean(jnp.argmax(logits, axis=-1)==batch_labels)\n", 712 | "\n", 713 | "@jax.jit\n", 714 | "def update(params, rng, opt_state, batch_token_ids, batch_labels):\n", 715 | " batch_loss, grad = jax.value_and_grad(loss)(params, rng, batch_token_ids, batch_labels)\n", 716 | " updates, opt_state = opt.update(grad, opt_state)\n", 717 | " new_params = optax.apply_updates(params, updates)\n", 718 | " return new_params, opt_state, batch_loss" 719 | ], 720 | "execution_count": null, 721 | "outputs": [] 722 | }, 723 | { 724 | "cell_type": "markdown", 725 | "metadata": { 726 | "id": "RseynpJDHYtj" 727 | }, 728 | "source": [ 729 | "## Defining Learning rate scheduler and Optimizer" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "metadata": { 735 | "id": "xjHY4tKiHX8A" 736 | }, 737 | "source": [ 738 | "import optax" 739 | ], 740 | "execution_count": null, 741 | "outputs": [] 742 | }, 743 | { 744 | "cell_type": "markdown", 745 | "metadata": { 746 | "id": "OJJ1-bzCYYru" 747 | }, 748 | "source": [ 749 | "The below way of defining a functionality allows you to tie together namespaces with functions.(Or \"wrap\" a function in a namespace consisting of variables defined in the outer function).\n", 750 | "\n", 751 | "Here, ```warmup_percentage``` and ```total_steps``` act as if they were variables in a class with a function ```lr_schedule()```. 
The ```lr_schedule()``` function can access them, freely. " 752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "metadata": { 757 | "id": "lvGKeVarIE2K" 758 | }, 759 | "source": [ 760 | "def make_lr_schedule(warmup_percentage, total_steps):\n", 761 | " \n", 762 | " def lr_schedule(step):\n", 763 | " percent_complete = step/total_steps\n", 764 | " \n", 765 | " #0 or 1 based on whether we are before peak\n", 766 | " before_peak = jax.lax.convert_element_type((percent_complete<=warmup_percentage),\n", 767 | " np.float32)\n", 768 | " #Factor for scaling learning rate\n", 769 | " scale = ( before_peak*(percent_complete/warmup_percentage)\n", 770 | " + (1-before_peak) ) * (1-percent_complete)\n", 771 | " \n", 772 | " return scale\n", 773 | " \n", 774 | " return lr_schedule" 775 | ], 776 | "execution_count": null, 777 | "outputs": [] 778 | }, 779 | { 780 | "cell_type": "code", 781 | "metadata": { 782 | "id": "HI5QJ_lLXeg2" 783 | }, 784 | "source": [ 785 | "total_steps = config['n_epochs']*(n_examples//config['batch_size'])\n", 786 | "\n", 787 | "lr_schedule = make_lr_schedule(warmup_percentage=0.1, total_steps=total_steps)" 788 | ], 789 | "execution_count": null, 790 | "outputs": [] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "metadata": { 795 | "id": "MclJF-XRX1Mi" 796 | }, 797 | "source": [ 798 | "opt = optax.chain(\n", 799 | " optax.clip_by_global_norm(config['max_grad_norm']),\n", 800 | " optax.adam(learning_rate=config['learning_rate']),\n", 801 | " optax.scale_by_schedule(lr_schedule),\n", 802 | ")\n", 803 | "opt_state = opt.init(params)" 804 | ], 805 | "execution_count": null, 806 | "outputs": [] 807 | }, 808 | { 809 | "cell_type": "markdown", 810 | "metadata": { 811 | "id": "J8vTrTw6c_OR" 812 | }, 813 | "source": [ 814 | "## Utility for Measuring Performance" 815 | ] 816 | }, 817 | { 818 | "cell_type": "code", 819 | "metadata": { 820 | "id": "VZ-UokInc-jh" 821 | }, 822 | "source": [ 823 | "def measure_current_performance(params, n_examples=None, splits=('train', 'test')):\n", 824 | " if 'train' in splits:\n", 825 | " train_eval = load_dataset('train', training=False, batch_size=25, n_examples=n_examples)\n", 826 | "\n", 827 | " train_accuracy = np.mean([accuracy(params, rng, \n", 828 | " encode_batch(train_eval_batch['text']),\n", 829 | " train_eval_batch['label'])\n", 830 | " for train_eval_batch in train_eval])\n", 831 | " \n", 832 | " print(f\"\\t Train validation acc: {train_accuracy:.3f}\")\n", 833 | "\n", 834 | " if 'test' in splits:\n", 835 | " test_eval = load_dataset('test', training=False, batch_size=25, n_examples=n_examples)\n", 836 | "\n", 837 | " test_accuracy = np.mean([accuracy(params, rng, \n", 838 | " encode_batch(test_eval_batch['text']),\n", 839 | " test_eval_batch['label'])\n", 840 | " for test_eval_batch in test_eval])\n", 841 | " \n", 842 | " print(f\"\\t Test validation accuracy: {test_accuracy:.3f}\")" 843 | ], 844 | "execution_count": null, 845 | "outputs": [] 846 | }, 847 | { 848 | "cell_type": "markdown", 849 | "metadata": { 850 | "id": "mM_37Nq_fArJ" 851 | }, 852 | "source": [ 853 | "## Training Loop" 854 | ] 855 | }, 856 | { 857 | "cell_type": "markdown", 858 | "metadata": { 859 | "id": "ygZLLAAJnmPN" 860 | }, 861 | "source": [ 862 | "###For running on a different dataset : " 863 | ] 864 | }, 865 | { 866 | "cell_type": "markdown", 867 | "metadata": { 868 | "id": "u3HPLdlnrZTB" 869 | }, 870 | "source": [ 871 | "**In the cell below :**\n", 872 | "\n", 873 | "* Change Line 1 to enumerate any data set returning batches of actual text(can have emojis 
too), with their integer labels. For example, ```train_batch['text']``` can be a list(or any other iterable) ```['My name is Jeevesh.', 'I live at your house.']``` with ```train_batch['labels']``` as another list ```[1,2]```.\n", 874 | "\n", 875 | "* Change ```n_classes``` in config.\n", 876 | "\n", 877 | "* Change tokenizer/provide vocabulary to add new tokens for additional languages, using ```huggingface_tokenizer.add_tokens()``` .\n", 878 | "\n", 879 | "* Rest remains same." 880 | ] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "metadata": { 885 | "id": "7XJXY72DfCd6" 886 | }, 887 | "source": [ 888 | "for step, train_batch in enumerate(train):\n", 889 | " \n", 890 | " if step%100==0:\n", 891 | " print(f'[Step {step}]')\n", 892 | " if step%1000==0 and step!=0:\n", 893 | " measure_current_performance(params, n_examples=100)\n", 894 | " print(\"Here\")\n", 895 | " batch_token_ids = encode_batch(train_batch['text'])\n", 896 | " batch_labels = train_batch['label']\n", 897 | " params, opt_state, batch_loss = update(params, rng, opt_state, batch_token_ids, batch_labels)" 898 | ], 899 | "execution_count": null, 900 | "outputs": [] 901 | } 902 | ] 903 | } --------------------------------------------------------------------------------