├── JAX_Tutorial.ipynb └── README.md /JAX_Tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "JAX Tutorial.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "code", 19 | "metadata": { 20 | "id": "12TM5qL4i3RE", 21 | "colab_type": "code", 22 | "outputId": "6ba219a8-f1ca-44e9-ec8b-a09c18d02d7b", 23 | "colab": { 24 | "base_uri": "https://localhost:8080/", 25 | "height": 34 26 | } 27 | }, 28 | "source": [ 29 | "%tensorflow_version 2.x" 30 | ], 31 | "execution_count": 0, 32 | "outputs": [ 33 | { 34 | "output_type": "stream", 35 | "text": [ 36 | "TensorFlow 2.x selected.\n" 37 | ], 38 | "name": "stdout" 39 | } 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "metadata": { 45 | "id": "lRmDy8DXi71S", 46 | "colab_type": "code", 47 | "outputId": "676bd8b4-79d8-40e0-82be-263c3b7668d1", 48 | "colab": { 49 | "base_uri": "https://localhost:8080/", 50 | "height": 326 51 | } 52 | }, 53 | "source": [ 54 | "!pip install --upgrade jax" 55 | ], 56 | "execution_count": 0, 57 | "outputs": [ 58 | { 59 | "output_type": "stream", 60 | "text": [ 61 | "Collecting jax\n", 62 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/50/f4/d90107c22334c267ccb64e0ea8039018a4740b5dfad1576dd868aac45254/jax-0.1.59.tar.gz (270kB)\n", 63 | "\r\u001b[K |█▏ | 10kB 26.8MB/s eta 0:00:01\r\u001b[K |██▍ | 20kB 6.3MB/s eta 0:00:01\r\u001b[K |███▋ | 30kB 8.8MB/s eta 0:00:01\r\u001b[K |████▉ | 40kB 5.8MB/s eta 0:00:01\r\u001b[K |██████ | 51kB 4.8MB/s eta 0:00:01\r\u001b[K |███████▎ | 61kB 5.7MB/s eta 0:00:01\r\u001b[K |████████▌ | 71kB 6.2MB/s eta 0:00:01\r\u001b[K |█████████▊ | 81kB 7.0MB/s eta 0:00:01\r\u001b[K |███████████ | 92kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████▏ | 102kB 
7.7MB/s eta 0:00:01\r\u001b[K |█████████████▎ | 112kB 7.7MB/s eta 0:00:01\r\u001b[K |██████████████▌ | 122kB 7.7MB/s eta 0:00:01\r\u001b[K |███████████████▊ | 133kB 7.7MB/s eta 0:00:01\r\u001b[K |█████████████████ | 143kB 7.7MB/s eta 0:00:01\r\u001b[K |██████████████████▏ | 153kB 7.7MB/s eta 0:00:01\r\u001b[K |███████████████████▍ | 163kB 7.7MB/s eta 0:00:01\r\u001b[K |████████████████████▋ | 174kB 7.7MB/s eta 0:00:01\r\u001b[K |█████████████████████▉ | 184kB 7.7MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 194kB 7.7MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 204kB 7.7MB/s eta 0:00:01\r\u001b[K |█████████████████████████▍ | 215kB 7.7MB/s eta 0:00:01\r\u001b[K |██████████████████████████▋ | 225kB 7.7MB/s eta 0:00:01\r\u001b[K |███████████████████████████▉ | 235kB 7.7MB/s eta 0:00:01\r\u001b[K |█████████████████████████████ | 245kB 7.7MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▎ | 256kB 7.7MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▌| 266kB 7.7MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 276kB 7.7MB/s \n", 64 | "\u001b[?25hRequirement already satisfied, skipping upgrade: numpy>=1.12 in /tensorflow-2.1.0/python3.6 (from jax) (1.18.1)\n", 65 | "Requirement already satisfied, skipping upgrade: absl-py in /tensorflow-2.1.0/python3.6 (from jax) (0.9.0)\n", 66 | "Requirement already satisfied, skipping upgrade: opt_einsum in /tensorflow-2.1.0/python3.6 (from jax) (3.1.0)\n", 67 | "Requirement already satisfied, skipping upgrade: six in /tensorflow-2.1.0/python3.6 (from absl-py->jax) (1.14.0)\n", 68 | "Building wheels for collected packages: jax\n", 69 | " Building wheel for jax (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 70 | " Created wheel for jax: filename=jax-0.1.59-cp36-none-any.whl size=314120 sha256=01cc42fb312dfc7360d576e78b1ee72e2349653ea42e201d4b59ddefdc786256\n", 71 | " Stored in directory: /root/.cache/pip/wheels/d5/08/51/4cf5b10be26e86c533c2b577a93f7ec8b317bf02a7bb010b8a\n", 72 | "Successfully built jax\n", 73 | "Installing collected packages: jax\n", 74 | " Found existing installation: jax 0.1.58\n", 75 | " Uninstalling jax-0.1.58:\n", 76 | " Successfully uninstalled jax-0.1.58\n", 77 | "Successfully installed jax-0.1.59\n" 78 | ], 79 | "name": "stdout" 80 | } 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "id": "OpRORJ6tkelE", 87 | "colab_type": "text" 88 | }, 89 | "source": [ 90 | "# JAX 1. NumPy Wrapper" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "metadata": { 96 | "id": "6ewm5mgFi80g", 97 | "colab_type": "code", 98 | "outputId": "84d8b936-a546-4189-c8dd-1b787ab26529", 99 | "colab": { 100 | "base_uri": "https://localhost:8080/", 101 | "height": 34 102 | } 103 | }, 104 | "source": [ 105 | "import numpy as np\n", 106 | "\n", 107 | "x = np.ones((5000, 5000))\n", 108 | "y = np.arange(5000)\n", 109 | "\n", 110 | "%timeit z = np.sin(x) + np.cos(y)" 111 | ], 112 | "execution_count": 0, 113 | "outputs": [ 114 | { 115 | "output_type": "stream", 116 | "text": [ 117 | "1 loop, best of 3: 401 ms per loop\n" 118 | ], 119 | "name": "stdout" 120 | } 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "metadata": { 126 | "id": "lsjQwMS5jL9K", 127 | "colab_type": "code", 128 | "outputId": "2de97392-96a6-4f71-bc80-e5e3906bf6a2", 129 | "colab": { 130 | "base_uri": "https://localhost:8080/", 131 | "height": 34 132 | } 133 | }, 134 | "source": [ 135 | "import jax.numpy as jnp\n", 136 | "x = jnp.ones((5000, 5000))\n", 137 | "y = jnp.arange(5000)\n", 138 | "# JAX dispatches work asynchronously: block_until_ready() makes %timeit wait for the result, not just the launch.\n", 139 | "%timeit z = (jnp.sin(x) + jnp.cos(y)).block_until_ready()" 140 | ], 141 | "execution_count": 0, 142 | "outputs": [ 143 | { 144 | "output_type": "stream", 145 | "text": 
[ 146 | "100 loops, best of 3: 2.15 ms per loop\n" 147 | ], 148 | "name": "stdout" 149 | } 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": { 155 | "id": "2s74k_3ekx5r", 156 | "colab_type": "text" 157 | }, 158 | "source": [ 159 | "# JAX 2. JIT Compiler" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "metadata": { 165 | "id": "VhD5QzxNjTQo", 166 | "colab_type": "code", 167 | "colab": {} 168 | }, 169 | "source": [ 170 | "from jax import jit\n", 171 | "import tensorflow as tf\n", 172 | "# The same sin + cos computation three ways: NumPy (fn), JAX compiled with @jit (fn_jit), TensorFlow graph via @tf.function (fn_tf2).\n", 173 | "def fn(x, y):\n", 174 | " z = np.sin(x)\n", 175 | " w = np.cos(y)\n", 176 | " return z + w\n", 177 | "\n", 178 | "@jit\n", 179 | "def fn_jit(x, y):\n", 180 | " z = jnp.sin(x)\n", 181 | " w = jnp.cos(y)\n", 182 | " return z + w\n", 183 | "\n", 184 | "@tf.function\n", 185 | "def fn_tf2(x, y):\n", 186 | " z = tf.sin(x)\n", 187 | " w = tf.cos(y)\n", 188 | " return z + w" 189 | ], 190 | "execution_count": 0, 191 | "outputs": [] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "metadata": { 196 | "id": "cl1PIhYsnV2q", 197 | "colab_type": "code", 198 | "outputId": "4e9f46f1-ebed-4e29-abd9-19df4a0cfde6", 199 | "colab": { 200 | "base_uri": "https://localhost:8080/", 201 | "height": 34 202 | } 203 | }, 204 | "source": [ 205 | "x = np.ones((5000, 5000))\n", 206 | "y = np.ones((5000, 5000))\n", 207 | "%timeit fn(x, y)" 208 | ], 209 | "execution_count": 0, 210 | "outputs": [ 211 | { 212 | "output_type": "stream", 213 | "text": [ 214 | "1 loop, best of 3: 780 ms per loop\n" 215 | ], 216 | "name": "stdout" 217 | } 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "metadata": { 223 | "id": "_ogfbuO_nTaY", 224 | "colab_type": "code", 225 | "outputId": "423d77a4-76a1-4ff2-e674-8855fe21b346", 226 | "colab": { 227 | "base_uri": "https://localhost:8080/", 228 | "height": 34 229 | } 230 | }, 231 | "source": [ 232 | "jx = jnp.ones((5000, 5000))\n", 233 | "jy = jnp.ones((5000, 5000))  # JAX runs asynchronously; block_until_ready() below times the computation, not only its dispatch\n", 234 | "%timeit fn_jit(jx, jy).block_until_ready()" 235 | ], 236 | "execution_count": 
0, 237 | "outputs": [ 238 | { 239 | "output_type": "stream", 240 | "text": [ 241 | "100 loops, best of 3: 2.12 ms per loop\n" 242 | ], 243 | "name": "stdout" 244 | } 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "metadata": { 250 | "id": "Avtiy3VPncSS", 251 | "colab_type": "code", 252 | "outputId": "6aaf00f1-6d5d-46f8-95d6-1aeba9ff5ace", 253 | "colab": { 254 | "base_uri": "https://localhost:8080/", 255 | "height": 71 256 | } 257 | }, 258 | "source": [ 259 | "tx = tf.ones((5000, 5000))\n", 260 | "ty = tf.ones((5000, 5000))\n", 261 | "%timeit fn_tf2(tx, ty)" 262 | ], 263 | "execution_count": 0, 264 | "outputs": [ 265 | { 266 | "output_type": "stream", 267 | "text": [ 268 | "The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached.\n", 269 | "1000 loops, best of 3: 3.36 ms per loop\n" 270 | ], 271 | "name": "stdout" 272 | } 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": { 278 | "id": "PzodSnzgs1Eu", 279 | "colab_type": "text" 280 | }, 281 | "source": [ 282 | "# JAX 3. 
grad" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "id": "QLyN4sF-ookp", 289 | "colab_type": "code", 290 | "colab": {} 291 | }, 292 | "source": [ 293 | "from jax import grad\n", 294 | "# sin(x) / x is 0/0 at x = 0, so the function and its derivatives evaluate to nan there.\n", 295 | "@jit\n", 296 | "def simple_fun(x):\n", 297 | " return jnp.sin(x) / x" 298 | ], 299 | "execution_count": 0, 300 | "outputs": [] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "metadata": { 305 | "id": "YPQ_q3J_qo6e", 306 | "colab_type": "code", 307 | "colab": {} 308 | }, 309 | "source": [ 310 | "grad_simple_fun = jit(grad(simple_fun))" 311 | ], 312 | "execution_count": 0, 313 | "outputs": [] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "metadata": { 318 | "id": "Hmx864Wrqqkn", 319 | "colab_type": "code", 320 | "outputId": "0099c3f6-5242-4825-f67c-3d3108e812dc", 321 | "colab": { 322 | "base_uri": "https://localhost:8080/", 323 | "height": 34 324 | } 325 | }, 326 | "source": [ 327 | "%timeit grad_simple_fun(1.0).block_until_ready()" 328 | ], 329 | "execution_count": 0, 330 | "outputs": [ 331 | { 332 | "output_type": "stream", 333 | "text": [ 334 | "1000 loops, best of 3: 1.22 ms per loop\n" 335 | ], 336 | "name": "stdout" 337 | } 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "metadata": { 343 | "id": "3NN2Rqr0r-zE", 344 | "colab_type": "code", 345 | "outputId": "0143afee-155e-46e9-c11c-48e1117cb8e3", 346 | "colab": { 347 | "base_uri": "https://localhost:8080/", 348 | "height": 187 349 | } 350 | }, 351 | "source": [ 352 | "x_range = jnp.arange(10, dtype=jnp.float32)\n", 353 | "[grad_simple_fun(xi) for xi in x_range]" 354 | ], 355 | "execution_count": 0, 356 | "outputs": [ 357 | { 358 | "output_type": "execute_result", 359 | "data": { 360 | "text/plain": [ 361 | "[DeviceArray(nan, dtype=float32),\n", 362 | " DeviceArray(-0.30116874, dtype=float32),\n", 363 | " DeviceArray(-0.43539774, dtype=float32),\n", 364 | " DeviceArray(-0.3456775, dtype=float32),\n", 365 | " DeviceArray(-0.11611074, dtype=float32),\n", 366 | " DeviceArray(0.09508941, 
dtype=float32),\n", 367 | " DeviceArray(0.16778992, dtype=float32),\n", 368 | " DeviceArray(0.09429243, dtype=float32),\n", 369 | " DeviceArray(-0.03364623, dtype=float32),\n", 370 | " DeviceArray(-0.10632458, dtype=float32)]" 371 | ] 372 | }, 373 | "metadata": { 374 | "tags": [] 375 | }, 376 | "execution_count": 77 377 | } 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "PcZCALcmqrlk", 384 | "colab_type": "code", 385 | "colab": {} 386 | }, 387 | "source": [ 388 | "grad_grad_simple_fun = jit(grad(grad(simple_fun)))" 389 | ], 390 | "execution_count": 0, 391 | "outputs": [] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "metadata": { 396 | "id": "g8kZ-K-Erkfl", 397 | "colab_type": "code", 398 | "outputId": "56669dd6-44dd-48ff-dd2d-71d61342cb30", 399 | "colab": { 400 | "base_uri": "https://localhost:8080/", 401 | "height": 71 402 | } 403 | }, 404 | "source": [ 405 | "%timeit grad_grad_simple_fun(1.0).block_until_ready()" 406 | ], 407 | "execution_count": 0, 408 | "outputs": [ 409 | { 410 | "output_type": "stream", 411 | "text": [ 412 | "The slowest run took 93.35 times longer than the fastest. 
This could mean that an intermediate result is being cached.\n", 413 | "1 loop, best of 3: 3.19 ms per loop\n" 414 | ], 415 | "name": "stdout" 416 | } 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "metadata": { 422 | "id": "lNaREYVErlkH", 423 | "colab_type": "code", 424 | "outputId": "a19b3fd4-f856-49c1-ef7c-dbe7d8ec2307", 425 | "colab": { 426 | "base_uri": "https://localhost:8080/", 427 | "height": 34 428 | } 429 | }, 430 | "source": [ 431 | "grad_grad_simple_fun(1.0)" 432 | ], 433 | "execution_count": 0, 434 | "outputs": [ 435 | { 436 | "output_type": "execute_result", 437 | "data": { 438 | "text/plain": [ 439 | "DeviceArray(-0.23913354, dtype=float32)" 440 | ] 441 | }, 442 | "metadata": { 443 | "tags": [] 444 | }, 445 | "execution_count": 73 446 | } 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "metadata": { 452 | "id": "TmyVUJrAsGBZ", 453 | "colab_type": "code", 454 | "outputId": "dd01de27-f132-4921-d390-585c75668445", 455 | "colab": { 456 | "base_uri": "https://localhost:8080/", 457 | "height": 187 458 | } 459 | }, 460 | "source": [ 461 | "x_range = jnp.arange(10, dtype=jnp.float32)\n", 462 | "[grad_grad_simple_fun(xi) for xi in x_range]" 463 | ], 464 | "execution_count": 0, 465 | "outputs": [ 466 | { 467 | "output_type": "execute_result", 468 | "data": { 469 | "text/plain": [ 470 | "[DeviceArray(nan, dtype=float32),\n", 471 | " DeviceArray(-0.23913354, dtype=float32),\n", 472 | " DeviceArray(-0.01925094, dtype=float32),\n", 473 | " DeviceArray(0.18341166, dtype=float32),\n", 474 | " DeviceArray(0.247256, dtype=float32),\n", 475 | " DeviceArray(0.1537491, dtype=float32),\n", 476 | " DeviceArray(-0.00936072, dtype=float32),\n", 477 | " DeviceArray(-0.12079593, dtype=float32),\n", 478 | " DeviceArray(-0.11525822, dtype=float32),\n", 479 | " DeviceArray(-0.02216326, dtype=float32)]" 480 | ] 481 | }, 482 | "metadata": { 483 | "tags": [] 484 | }, 485 | "execution_count": 78 486 | } 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | 
"metadata": { 492 | "id": "21PcHaMFrmnN", 493 | "colab_type": "code", 494 | "colab": {} 495 | }, 496 | "source": [ 497 | "" 498 | ], 499 | "execution_count": 0, 500 | "outputs": [] 501 | } 502 | ] 503 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # jax-tutorial 2 | 3 | ### google JAX Tutorial 4 | --------------------------------------------------------------------------------