├── README.md └── official-tutorials ├── .ipynb_checkpoints ├── 1.1-快速开始-checkpoint.ipynb ├── 1.2-以JAX的方式思考-checkpoint.ipynb ├── 1.3-JAX锋芒毕露-checkpoint.ipynb ├── 1.4.5-JAX中的伪随机数-checkpoint.ipynb └── 2.3-异步调度-checkpoint.ipynb ├── GettingStarted ├── .ipynb_checkpoints │ └── 1.3-JAX锋芒毕露-checkpoint.ipynb ├── 1.1-快速开始.ipynb ├── 1.2-以JAX的方式思考.ipynb ├── 1.3-JAX锋芒毕露.ipynb └── Tutorial:Jax101 │ ├── .ipynb_checkpoints │ ├── 1.4.1-加速版Numpy——JAX -checkpoint.ipynb │ ├── 1.4.2-JAX的即时编译-checkpoint.ipynb │ ├── 1.4.3-JAX的自动向量化-checkpoint.ipynb │ └── 1.4.4-JAX中的高级自动微分-checkpoint.ipynb │ ├── 1.4.1-加速版Numpy——JAX .ipynb │ ├── 1.4.2-JAX的即时编译.ipynb │ ├── 1.4.3-JAX的自动向量化.ipynb │ ├── 1.4.4-JAX中的高级自动微分.ipynb │ └── 1.4.5-JAX中的伪随机数.ipynb └── ReferenceDocumentation └── 2.3-异步调度.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # JAX Chinese Tutorial 2 | 3 | Original English tutorial: [ReadTheDocs](https://jax.readthedocs.io/en/latest) 4 | 5 | Under construction... 6 | Work in Progress... 7 | 8 | ## Table of Contents 9 | 10 | * Getting Started 11 | * [JAX Quickstart](https://github.com/rasin-tsukuba/JAX_chinese_tutorial/blob/main/official-tutorials/GettingStarted/1.1-%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B.ipynb) 12 | * [Thinking in JAX](https://github.com/rasin-tsukuba/JAX_chinese_tutorial/blob/main/official-tutorials/GettingStarted/1.2-%E4%BB%A5JAX%E7%9A%84%E6%96%B9%E5%BC%8F%E6%80%9D%E8%80%83.ipynb) 13 | * [JAX - The Sharp Bits](https://github.com/rasin-tsukuba/JAX_chinese_tutorial/blob/main/official-tutorials/GettingStarted/1.3-JAX%E9%94%8B%E8%8A%92%E6%AF%95%E9%9C%B2.ipynb) 14 | * Tutorial: JAX 101 15 | * [JAX As Accelerated NumPy](https://github.com/rasin-tsukuba/JAX_chinese_tutorial/blob/main/official-tutorials/GettingStarted/Tutorial:Jax101/1.4.1-%E5%8A%A0%E9%80%9F%E7%89%88Numpy%E2%80%94%E2%80%94JAX%20.ipynb) 16 | * [Just-In-Time Compilation with JAX](https://github.com/rasin-tsukuba/JAX_tutorial_Chinese_version/blob/main/official-tutorials/GettingStarted/Tutorial:Jax101/1.4.2-JAX%E7%9A%84%E5%8D%B3%E6%97%B6%E7%BC%96%E8%AF%91.ipynb) 17 | * 
[Automatic Vectorization in JAX](https://github.com/rasin-tsukuba/JAX_tutorial_Chinese_version/blob/main/official-tutorials/GettingStarted/Tutorial:Jax101/1.4.3-JAX%E7%9A%84%E8%87%AA%E5%8A%A8%E5%90%91%E9%87%8F%E5%8C%96.ipynb) 18 | * [Advanced Automatic Differentiation in JAX]() 19 | * [Pseudo Random Numbers in JAX](https://github.com/rasin-tsukuba/JAX_chinese_tutorial/blob/main/official-tutorials/GettingStarted/Tutorial:Jax101/1.4.5-JAX%E4%B8%AD%E7%9A%84%E4%BC%AA%E9%9A%8F%E6%9C%BA%E6%95%B0.ipynb) 20 | * Working with Pytrees 21 | * Example: ML model parameters 22 | * Custom pytree nodes 23 | * Common pytree gotchas and patterns 24 | * Parallel Evaluation in JAX 25 | * Stateful Computations 26 | * Reference Documentation 27 | * JAX Frequently Asked Questions (FAQ) 28 | * Transformations 29 | * [Asynchronous Dispatch](https://github.com/rasin-tsukuba/JAX_chinese_tutorial/blob/main/official-tutorials/ReferenceDocumentation/2.3-%E5%BC%82%E6%AD%A5%E8%B0%83%E5%BA%A6.ipynb) 30 | * Understanding Jaxprs 31 | * Convolutions in JAX 32 | * Pytrees 33 | * Type Promotion Semantics 34 | * JAX Errors 35 | * JAX Glossary of Terms 36 | * Change Log 37 | * Advanced JAX Tutorials 38 | * The `Autodiff` Cookbook 39 | * Autobatching Log-Densities Example 40 | * Training a Simple Neural Network with TensorFlow Data Loading 41 | * Custom Derivative Rules for JAX-Transformable Python Functions 42 | * How JAX Primitives Work 43 | * Writing Custom Jaxpr Interpreters in JAX 44 | * Training a Simple Neural Network with PyTorch Data Loading 45 | * XLA in Python 46 | * MAML Tutorial with JAX 47 | * Generative Modeling by Estimating Gradients of the Data Distribution in JAX -------------------------------------------------------------------------------- /official-tutorials/.ipynb_checkpoints/1.1-快速开始-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "loved-situation", 6 | "metadata": {}, 7 | "source": [ 8 | "# JAX Quickstart\n", 9 | "\n", 10 | "`JAX` is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "checked-auditor", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import jax.numpy as jnp\n", 21 | "from jax import grad, jit, vmap\n", 22 | "from jax import random" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "id": "sporting-basics", 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": "stdout", 33 | 
"output_type": "stream", 34 | "text": [ 35 | "[-0.3721109 0.26423115 -0.18252768 -0.7368197 -0.44030377 -0.1521442\n", 36 | " -0.67135346 -0.5908641 0.73168886 0.5673026 ]\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "key = random.PRNGKey(0)\n", 42 | "x = random.normal(key, (10, ))\n", 43 | "print(x)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 8, 49 | "id": "effective-buffalo", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "9.62 ms ± 194 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "size = 3000\n", 62 | "x = random.normal(key, (size, size), dtype=jnp.float32)\n", 63 | "%timeit jnp.dot(x, x.T).block_until_ready()" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 9, 69 | "id": "confidential-broadcast", 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "27.5 ms ± 51.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "import numpy as np\n", 82 | "x = np.random.normal(size=(size, size)).astype(np.float32)\n", 83 | "%timeit jnp.dot(x, x.T).block_until_ready()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 10, 89 | "id": "hungry-friendly", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "9.75 ms ± 281 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "from jax import device_put\n", 102 | "\n", 103 | "x = np.random.normal(size=(size, size)).astype(np.float32)\n", 104 | "x = device_put(x)\n", 105 | "%timeit jnp.dot(x, x.T).block_until_ready()" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 11, 111 | "id": "composed-indication", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "56.4 ms ± 544 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "x = np.random.normal(size=(size, size)).astype(np.float32)\n", 124 | "%timeit np.dot(x, x.T)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "several-cardiff", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [] 134 | } 135 | ], 136 | "metadata": { 137 | "kernelspec": { 138 | "display_name": "Python 3", 139 | "language": "python", 140 | "name": "python3" 141 | }, 142 | "language_info": { 143 | "codemirror_mode": { 144 | "name": "ipython", 145 | "version": 3 146 | }, 147 | "file_extension": ".py", 148 | "mimetype": "text/x-python", 149 | "name": "python", 150 | "nbconvert_exporter": "python", 151 | "pygments_lexer": "ipython3", 152 | "version": "3.9.2" 153 | } 154 | }, 155 | "nbformat": 4, 156 | "nbformat_minor": 5 157 | } 158 | -------------------------------------------------------------------------------- /official-tutorials/.ipynb_checkpoints/1.2-以JAX的方式思考-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /official-tutorials/.ipynb_checkpoints/1.3-JAX锋芒毕露-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 
2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /official-tutorials/.ipynb_checkpoints/1.4.5-JAX中的伪随机数-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /official-tutorials/.ipynb_checkpoints/2.3-异步调度-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/.ipynb_checkpoints/1.3-JAX锋芒毕露-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "iraqi-rough", 6 | "metadata": {}, 7 | "source": [ 8 | "# 🔪 JAX - The Sharp Bits 🔪\n", 9 | "\n", 10 | "> Authors: levskaya@ mattjj@\n", 11 | "> \n", 12 | "> When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale” (a soul of pure functional programming).\n", 13 | "\n", 14 | "JAX is a language for expressing and composing transformations of numerical programs. JAX can also compile numerical programs for the CPU or for accelerators (GPU or TPU). JAX works well for many numerical and scientific programs, but only if they are written under certain constraints, which are described below.\n", 15 | "\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "id": "brief-ensemble", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import numpy as np\n", 26 | "from jax import grad, jit\n", 27 | "from jax import lax\n", 28 | "from jax import random\n", 29 | "import jax \n", 30 | "import jax.numpy as jnp\n", 31 | "import matplotlib as mpl\n", 32 | "from matplotlib import pyplot as plt\n", 33 | "from matplotlib import rcParams\n", 34 | "\n", 35 | "rcParams['image.interpolation'] = 'nearest'\n", 36 | "rcParams['image.cmap'] = 'viridis'\n", 37 | 
"rcParams['axes.grid'] = False" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "intended-commander", 43 | "metadata": {}, 44 | "source": [ 45 | "# 🔪纯函数\n", 46 | "\n", 47 | "JAX变换和编译仅适用于功能纯净的Python函数:所有输入数据均通过函数参数传递,所有结果均通过函数结果输出。如果使用相同的输入调用纯函数,则总会返回相同的结果。\n", 48 | "\n", 49 | "这是一些功能上并非纯函数的示例,对于这些函数,JAX的行为不同于Python解释器。注意,JAX系统不能保证这些行为。使用JAX的正确方法是使用功能纯粹的纯Python函数。" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "id": "collective-cabin", 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Executing function\n", 63 | "First call: 4.0\n", 64 | "Second call: 5.0\n", 65 | "Executing function\n", 66 | "Third call, different type: [5.]\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "def impure_print_side_effect(x):\n", 72 | " print(\"Executing function\") # 这就是一种副作用\n", 73 | " return x\n", 74 | "\n", 75 | "# The side-effects appear during the first run \n", 76 | "# 在第一次执行时就出现副作用\n", 77 | "print(\"First call: \", jit(impure_print_side_effect)(4.))\n", 78 | "\n", 79 | "# Subsequent runs with parameters of same type and shape may not show the side-effect\n", 80 | "# This is because JAX now invokes a cached compilation of the function\n", 81 | "# 使用相同类型和形状的参数进行的后续运行可能不会显示副作用\n", 82 | "# 这是因为JAX现在调用了该函数的缓存编译\n", 83 | "print(\"Second call: \", jit(impure_print_side_effect)(5.))\n", 84 | "\n", 85 | "# JAX re-runs the Python function when the type or shape of the argument changes\n", 86 | "# 当参数的类型或者形状更改时,JAX重新运行Python函数\n", 87 | "print(\"Third call, different type: \", jit(impure_print_side_effect)(jnp.array([5.])))\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 3, 93 | "id": "funky-tennis", 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "First call: 4.0\n", 101 | "Second call: 5.0\n", 102 | "Third call, different type: [14.]\n" 103 | ] 104 | } 105 | ], 106 | 
"source": [ 107 | "g = 0.\n", 108 | "def impure_uses_globals(x):\n", 109 | " return x + g\n", 110 | "\n", 111 | "# JAX captures the value of the global during the first run\n", 112 | "# JAX在第一次执行时捕捉到全局变量\n", 113 | "print (\"First call: \", jit(impure_uses_globals)(4.))\n", 114 | "\n", 115 | "g = 10. # 更新全局变量\n", 116 | "\n", 117 | "# Subsequent runs may silently use the cached value of the globals\n", 118 | "# 以下的结果将会默认使用缓存的全局变量\n", 119 | "print (\"Second call: \", jit(impure_uses_globals)(5.))\n", 120 | "\n", 121 | "# JAX re-runs the Python function when the type or shape of the argument changes\n", 122 | "# This will end up reading the latest value of the global\n", 123 | "# 参数的类型或者形状更改时,JAX重新运行Python函数\n", 124 | "# 这样将会重新读取最新的全局变量值\n", 125 | "print (\"Third call, different type: \", jit(impure_uses_globals)(jnp.array([4.])))" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 4, 131 | "id": "pleasant-athletics", 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "First call: 4.0\n", 139 | "Saved global: Tracedwith\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "g = 0.\n", 145 | "def impure_saves_global(x):\n", 146 | " global g\n", 147 | " g = x\n", 148 | " return x\n", 149 | "\n", 150 | "# JAX runs once the transformed function with special Traced values for arguments\n", 151 | "# JAX运行带有参数的特殊跟踪值转换后的函数\n", 152 | "print (\"First call: \", jit(impure_saves_global)(4.))\n", 153 | "print (\"Saved global: \", g) # 保存的全局变量带有内部的JAX值" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "id": "engaging-eating", 159 | "metadata": {}, 160 | "source": [ 161 | "即时Python函数实际上在内部使用有状态的对象,只要它不读取或写入外部状态,它在功能上也可以是纯函数:" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 5, 167 | "id": "stopped-album", 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "50.0\n" 175 | ] 
176 | } 177 | ], 178 | "source": [ 179 | "def pure_uses_internal_state(x):\n", 180 | "  state = dict(even=0, odd=0)\n", 181 | "  for i in range(10):\n", 182 | "    state['even' if i % 2 == 0 else 'odd'] += x\n", 183 | "  return state['even'] + state['odd']\n", 184 | "\n", 185 | "print(jit(pure_uses_internal_state)(5.))" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "id": "placed-holiday", 191 | "metadata": {}, 192 | "source": [ 193 | "It is not recommended to use iterators in any JAX function you want to `jit`, or in any control-flow primitive. An iterator is a Python object that introduces state in order to retrieve the next element, so it is incompatible with JAX's functional programming model. The code below shows some incorrect attempts to use iterators with JAX. Most of them raise an error, but some give unexpected results:" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 6, 199 | "id": "attempted-complexity", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "45\n", 207 | "0\n" 208 | ] 209 | }, 210 | { 211 | "ename": "TypeError", 212 | "evalue": "Value with type is not a valid JAX type", 213 | "output_type": "error", 214 | "traceback": [ 215 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 216 | "\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)", 217 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# 报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0mlax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcond\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m 
\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miter_operand\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 218 | "\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Value with type is not a valid JAX type\n\nThe stack trace above excludes JAX-internal frames.\nThe following is the original exception that occurred, unmodified.\n\n--------------------", 219 | "\nThe above exception was the direct cause of the following exception:\n", 220 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 221 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0miter_operand\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# 报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0mlax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcond\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miter_operand\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 222 | 
"\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 223 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py\u001b[0m in \u001b[0;36mcond\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 712\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_cond_with_per_branch_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mba\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 713\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 714\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0m_cond\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 715\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 716\u001b[0m def _cond_with_per_branch_args(pred,\n", 224 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py\u001b[0m in \u001b[0;36m_cond\u001b[0;34m(pred, true_fun, false_fun, operand)\u001b[0m\n\u001b[1;32m 682\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 683\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mops_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moperand\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 684\u001b[0;31m \u001b[0mops_avals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_abstractify\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 685\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 686\u001b[0m jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(\n", 225 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/util.py\u001b[0m in \u001b[0;36msafe_map\u001b[0;34m(f, *args)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32massert\u001b[0m 
\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'length mismatch: {}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0munzip2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 226 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py\u001b[0m in \u001b[0;36m_abstractify\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_abstractify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mraise_to_shaped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
112\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_typecheck_param\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg_required\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 227 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mget_aval\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 919\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 920\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 921\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mconcrete_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 922\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 923\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 228 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mconcrete_aval\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 911\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__jax_array__'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 912\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mconcrete_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__jax_array__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 913\u001b[0;31m raise TypeError(f\"Value {repr(x)} with type {type(x)} is not a valid JAX 
\"\n\u001b[0m\u001b[1;32m 914\u001b[0m \"type\")\n\u001b[1;32m 915\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 229 | "\u001b[0;31mTypeError\u001b[0m: Value with type is not a valid JAX type" 230 | ] 231 | } 232 | ], 233 | "source": [ 234 | "import jax.numpy as jnp\n", 235 | "import jax.lax as lax\n", 236 | "from jax import make_jaxpr\n", 237 | "\n", 238 | "# lax.fori_loop\n", 239 | "array = jnp.arange(10)\n", 240 | "print(lax.fori_loop(0, 10, lambda i, x: x+array[i], 0)) # 预期结果是45\n", 241 | "iterator = iter(range(10))\n", 242 | "print(lax.fori_loop(0, 10, lambda i, x:x+next(iterator), 0)) #意外结果是0\n", 243 | "\n", 244 | "#lax.scan\n", 245 | "def func11(arr, extra):\n", 246 | " ones = jnp.ones(arr.shape)\n", 247 | " def body(carry, aelems):\n", 248 | " ae1, ae2 = aelems\n", 249 | " return (carry + ae1 * ae2 + extra, carry)\n", 250 | " return lax.scan(body, 0., (arr, ones))\n", 251 | "make_jaxpr(func11)(jnp.arange(16), 5.)\n", 252 | "\n", 253 | "#lax.cond\n", 254 | "array_operand = jnp.array([0.])\n", 255 | "lax.cond(True, lambda x: x+1, lambda x:x-1, array_operand)\n", 256 | "iter_operand = iter(range(10))\n", 257 | "# 报错\n", 258 | "lax.cond(True, lambda x: next(x)+1, lambda x:next(x)-1, iter_operand)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "id": "moving-brisbane", 264 | "metadata": {}, 265 | "source": [ 266 | "## 🔪 就地更新\n", 267 | "\n", 268 | "在Numpy中我们习惯这么做:" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 7, 274 | "id": "unavailable-country", 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "original array:\n", 282 | "[[0. 0. 0.]\n", 283 | " [0. 0. 0.]\n", 284 | " [0. 0. 0.]]\n", 285 | "updated array:\n", 286 | "[[0. 0. 0.]\n", 287 | " [1. 1. 1.]\n", 288 | " [0. 0. 
0.]]\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "numpy_array = np.zeros((3, 3), dtype=np.float32)\n", 294 | "print(\"original array:\")\n", 295 | "print(numpy_array)\n", 296 | "\n", 297 | "# In place, mutating update\n", 298 | "# (the array itself is modified)\n", 299 | "numpy_array[1, :] = 1.0\n", 300 | "print(\"updated array:\")\n", 301 | "print(numpy_array)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "id": "weird-comfort", 307 | "metadata": {}, 308 | "source": [ 309 | "If we try to update a JAX array in place, however, we get an error! (☉_☉)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 8, 315 | "id": "suited-trainer", 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "Exception '' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n", 328 | "\n", 329 | "# In place update of JAX's array\n", 330 | "# will yield an error!\n", 331 | "try:\n", 332 | "  jax_array[1, :] = 1.0\n", 333 | "except Exception as e:\n", 334 | "  print(\"Exception {}\".format(e))" 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "id": "nutritional-testing", 340 | "metadata": {}, 341 | "source": [ 342 | "### What gives?!\n", 343 | "\n", 344 | "Allowing variables to be mutated in place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program.\n", 345 | "\n", 346 | "Instead, JAX offers several functional update functions: [index_update](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update), [index_add](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add), [index_min](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_min.html#jax.ops.index_min), [index_max](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_max.html#jax.ops.index_max), and the 
[index](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html#jax.ops.index) helper.\n", 347 | "\n", 348 | "⚠️ Inside `jit`-compiled code and inside `lax.while_loop` or `lax.fori_loop`, the size of a slice cannot be a function of argument values, only a function of argument shapes; the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation." 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 9, 354 | "id": "matched-morrison", 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [ 358 | "from jax.ops import index, index_add, index_update" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "id": "growing-closure", 364 | "metadata": {}, 365 | "source": [ 366 | "### index_update\n", 367 | "\n", 368 | "If the **input values** of **index_update** are not reused, `jit`-compiled code will perform these operations in place:" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 11, 374 | "id": "selective-brake", 375 | "metadata": {}, 376 | "outputs": [ 377 | { 378 | "name": "stdout", 379 | "output_type": "stream", 380 | "text": [ 381 | "original array:\n", 382 | "[[0. 0. 0.]\n", 383 | " [0. 0. 0.]\n", 384 | " [0. 0. 0.]]\n", 385 | "old array unchanged:\n", 386 | "[[0. 0. 0.]\n", 387 | " [0. 0. 0.]\n", 388 | " [0. 0. 0.]]\n", 389 | "new array:\n", 390 | "[[0. 0. 0.]\n", 391 | " [1. 1. 1.]\n", 392 | " [0. 0. 
0.]]\n" 393 | ] 394 | } 395 | ], 396 | "source": [ 397 | "jax_array = jnp.zeros((3, 3))\n", 398 | "print(\"original array:\")\n", 399 | "print(jax_array)\n", 400 | "\n", 401 | "new_jax_array = index_update(jax_array, index[1, :], 1.)\n", 402 | "\n", 403 | "print(\"old array unchanged:\")\n", 404 | "print(jax_array)\n", 405 | "\n", 406 | "print(\"new array:\")\n", 407 | "print(new_jax_array)" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "id": "electoral-myanmar", 413 | "metadata": {}, 414 | "source": [ 415 | "### index_add\n", 416 | "\n", 417 | "If the **input values** of **index_add** are not reused, `jit`-compiled code will perform these operations in place:\n" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 12, 423 | "id": "moving-institute", 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "original array:\n", 431 | "[[1. 1. 1. 1. 1. 1.]\n", 432 | " [1. 1. 1. 1. 1. 1.]\n", 433 | " [1. 1. 1. 1. 1. 1.]\n", 434 | " [1. 1. 1. 1. 1. 1.]\n", 435 | " [1. 1. 1. 1. 1. 1.]]\n", 436 | "new array post-addition:\n", 437 | "[[1. 1. 1. 8. 8. 8.]\n", 438 | " [1. 1. 1. 1. 1. 1.]\n", 439 | " [1. 1. 1. 8. 8. 8.]\n", 440 | " [1. 1. 1. 1. 1. 1.]\n", 441 | " [1. 1. 1. 8. 8. 
8.]]\n" 442 | ] 443 | } 444 | ], 445 | "source": [ 446 | "print(\"original array:\")\n", 447 | "jax_array = jnp.ones((5, 6))\n", 448 | "print(jax_array)\n", 449 | "\n", 450 | "new_jax_array = index_add(jax_array, index[::2, 3:], 7)\n", 451 | "print(\"new array post-addition:\")\n", 452 | "print(new_jax_array)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "id": "desirable-motivation", 458 | "metadata": {}, 459 | "source": [ 460 | "## 🔪 Out-of-Bounds Indexing\n", 461 | "\n", 462 | "In Numpy, you are used to an error being raised when you index an array outside of its bounds, like this:\n" 463 | ] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "execution_count": 13, 468 | "id": "fiscal-grant", 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "name": "stdout", 473 | "output_type": "stream", 474 | "text": [ 475 | "Exception index 11 is out of bounds for axis 0 with size 10\n" 476 | ] 477 | } 478 | ], 479 | "source": [ 480 | "try:\n", 481 | "  np.arange(10)[11]\n", 482 | "except Exception as e:\n", 483 | "  print(\"Exception {}\".format(e))\n" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "id": "identical-central", 489 | "metadata": {}, 490 | "source": [ 491 | "However, raising an error from code running on an accelerator is much harder. Therefore, JAX does not raise an error; instead, the index is clamped to the bounds of the array, which in this example means the last value of the array is returned:" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 14, 497 | "id": "periodic-catalog", 498 | "metadata": {}, 499 | "outputs": [ 500 | { 501 | "data": { 502 | "text/plain": [ 503 | "DeviceArray(9, dtype=int32)" 504 | ] 505 | }, 506 | "execution_count": 14, 507 | "metadata": {}, 508 | "output_type": "execute_result" 509 | } 510 | ], 511 | "source": [ 512 | "jnp.arange(10)[11]\n" 513 | ] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "id": "initial-auditor", 518 | "metadata": {}, 519 | "source": [ 520 | "Note that because of this behavior, `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs, whereas Numpy would raise an error.\n", 521 | "\n", 522 | "## 🔪 Random Numbers\n", 523 | "\n", 524 | "> If all scientific papers whose results are in doubt because of bad `rand()`s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist.\n", 525 | ">\n", 526 | "> -- Numerical Recipes\n", 527 | "\n", 528 | "### RNGs (Random Number Generators) and State\n", 529 
| "\n", 530 | "You are probably used to the stateful pseudorandom number generators (PRNGs) in NumPy and other libraries, which helpfully hide many details and hand you a ready source of pseudorandomness:\n", 531 | "\n" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": 15, 537 | "id": "hourly-fighter", 538 | "metadata": {}, 539 | "outputs": [ 540 | { 541 | "name": "stdout", 542 | "output_type": "stream", 543 | "text": [ 544 | "0.801690546086135\n", 545 | "0.17685059614480447\n", 546 | "0.09584388259855248\n" 547 | ] 548 | } 549 | ], 550 | "source": [ 551 | "print(np.random.random())\n", 552 | "print(np.random.random())\n", 553 | "print(np.random.random())" 554 | ] 555 | }, 556 | { 557 | "cell_type": "markdown", 558 | "id": "stuffed-meaning", 559 | "metadata": {}, 560 | "source": [ 561 | "Under the hood, NumPy uses the Mersenne Twister PRNG to power its" 562 | ] 563 | } 564 | ], 565 | "metadata": { 566 | "kernelspec": { 567 | "display_name": "Python 3", 568 | "language": "python", 569 | "name": "python3" 570 | }, 571 | "language_info": { 572 | "codemirror_mode": { 573 | "name": "ipython", 574 | "version": 3 575 | }, 576 | "file_extension": ".py", 577 | "mimetype": "text/x-python", 578 | "name": "python", 579 | "nbconvert_exporter": "python", 580 | "pygments_lexer": "ipython3", 581 | "version": "3.9.2" 582 | } 583 | }, 584 | "nbformat": 4, 585 | "nbformat_minor": 5 586 | } 587 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/1.1-快速开始.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "capital-transmission", 6 | "metadata": {}, 7 | "source": [ 8 | "# JAX Quickstart\n", 9 | "\n", 10 | "`JAX` is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.\n", 11 | "\n", 12 | "With its updated version of [Autograd](https://github.com/hips/autograd), `JAX` can automatically differentiate native Python and NumPy code. It can differentiate through Python loops, conditionals, recursion, and closures, and it can even take derivatives of derivatives of derivatives. It supports both reverse-mode and forward-mode differentiation, and the two can be composed in any order.\n", 13 | "\n", 14 | "`JAX` uses [XLA](https://www.tensorflow.org/xla) 
to compile and run your code on GPU and TPU accelerators. Compilation happens under the hood by default, with library calls being just-in-time (JIT) compiled and executed. `JAX` also lets you JIT-compile your own Python functions into XLA-optimized kernels through a single-function API. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python.\n", 15 | "\n", 16 | "First, we import the commonly used JAX modules:" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "checked-auditor", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import jax.numpy as jnp\n", 27 | "from jax import grad, jit, vmap\n", 28 | "from jax import random" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "standing-payday", 34 | "metadata": {}, 35 | "source": [ 36 | "## Multiplying Matrices\n", 37 | "\n", 38 | "We generate random data in the examples below. One big difference between NumPy and JAX is how random numbers are generated. For details, see [JAX - The Sharp Bits: Random Numbers](https://render.githubusercontent.com/view/ipynb?color_mode=light&commit=6ac9c12ef0d554cbb52e5117d4a87ce431069d39&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f726173696e2d7473756b7562612f4a41585f6368696e6573655f7475746f7269616c2f366163396331326566306435353463626235326535313137643461383763653433313036396433392f636f64652f312e332d4a41582545392539342538422545382538412539322545362541462539352545392539432542322e6970796e62&nwo=rasin-tsukuba%2FJAX_chinese_tutorial&path=code%2F1.3-JAX%E9%94%8B%E8%8A%92%E6%AF%95%E9%9C%B2.ipynb&repository_id=349726397&repository_type=Repository#%F0%9F%94%AA-%E9%9A%8F%E6%9C%BA%E6%95%B0)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "id": "sporting-basics", 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "[-0.3721109 0.26423115 -0.18252768 -0.7368197 -0.44030377 -0.1521442\n", 52 | " -0.67135346 -0.5908641 0.73168886 0.5673026 ]\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "key = random.PRNGKey(0)\n", 58 | "x = random.normal(key, (10, ))\n", 59 | "print(x)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "returning-shade", 65 | "metadata": {}, 66 | "source": [ 67 | "Let's dive right in and multiply two big matrices:" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 
72 | "execution_count": 3, 73 | "id": "effective-buffalo", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "10.1 ms ± 243 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "size = 3000\n", 86 | "x = random.normal(key, (size, size), dtype=jnp.float32)\n", 87 | "%timeit jnp.dot(x, x.T).block_until_ready()" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "checked-exclusion", 93 | "metadata": {}, 94 | "source": [ 95 | "We added `block_until_ready` because JAX uses asynchronous execution by default (see Asynchronous Dispatch).\n", 96 | "\n", 97 | "JAX NumPy functions also work on regular NumPy arrays." 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "id": "confidential-broadcast", 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "28.4 ms ± 733 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "import numpy as np\n", 116 | "x = np.random.normal(size=(size, size)).astype(np.float32)\n", 117 | "%timeit jnp.dot(x, x.T).block_until_ready()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "selective-africa", 123 | "metadata": {}, 124 | "source": [ 125 | "That is slower because the data has to be transferred to the GPU every time. You can use `device_put()` to ensure that an `NDArray` is backed by device memory." 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 5, 131 | "id": "hungry-friendly", 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "9.6 ms ± 139 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "from jax import device_put\n", 144 | "\n", 145 | "x = np.random.normal(size=(size, size)).astype(np.float32)\n", 146 | "x = device_put(x)\n", 147 | "%timeit jnp.dot(x, x.T).block_until_ready()" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "stone-landscape", 153 | "metadata": {}, 154 | "source": [ 155 | "The output of `device_put()` still behaves like an `NDArray`, but its values are copied back to the CPU only when they are needed for printing, plotting, saving to disk, branching, and so on. `device_put()` behaves like the function `jit(lambda x: x)`, but it is faster.\n", 156 | "\n", 157 | "If you have a GPU or TPU, these calls run on the accelerator and can be much faster than on the CPU." 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 6, 163 | "id": "composed-indication", 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "65.7 ms ± 10.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "x = np.random.normal(size=(size, size)).astype(np.float32)\n", 176 | "%timeit np.dot(x, x.T)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "id": "compressed-comparative", 182 | "metadata": {}, 183 | "source": [ 184 | "JAX is much more than a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:\n", 185 | "\n", 186 | "* `jit()`, for speeding up your code\n", 187 | "* `grad()`, for taking derivatives\n", 188 | "* `vmap()`, for automatic vectorization or batching\n", 189 | "\n", 190 | "Let's go over them one by one. We'll also end up composing them in interesting ways.\n", 191 | "\n", 192 | "## Using `jit()` to Speed Up Functions\n", 193 | "\n", 194 | "JAX runs transparently on the GPU (or on the CPU if you don't have one; TPU support is coming soon!). In the example above, however, JAX dispatches kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile them together with XLA. Let's try that:" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "id": "printable-throat", 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "794 µs ± 18 µs per loop (mean ± std. dev. 
of 7 runs, 1000 loops each)\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "def selu(x, alpha=1.67, lmbda=1.05):\n", 213 | " return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", 214 | "\n", 215 | "x = random.normal(key, (1000000, ))\n", 216 | "%timeit selu(x).block_until_ready()" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "id": "personal-minutes", 222 | "metadata": {}, 223 | "source": [ 224 | "We can speed it up with `@jit`, which JIT-compiles `selu` the first time it is called and caches the result afterwards." 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 8, 230 | "id": "accepted-independence", 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "51.6 µs ± 904 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "selu_jit = jit(selu)\n", 243 | "%timeit selu_jit(x).block_until_ready()" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "id": "compound-niger", 249 | "metadata": {}, 250 | "source": [ 251 | "## Taking Derivatives with `grad()`\n", 252 | "\n", 253 | "In addition to evaluating numerical functions, we also want to transform them. One such transformation is [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation). In JAX, just as in [Autograd](https://github.com/HIPS/autograd), you can compute gradients with the `grad()` function." 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 9, 259 | "id": "dominican-maryland", 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "name": "stdout", 264 | "output_type": "stream", 265 | "text": [ 266 | "[0.25 0.19661197 0.10499357]\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "def sum_logistic(x):\n", 272 | " return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))\n", 273 | "\n", 274 | "x_small = jnp.arange(3.)\n", 275 | "derivative_fn = grad(sum_logistic)\n", 276 | "print(derivative_fn(x_small))" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "id": "lightweight-underground", 282 | "metadata": {}, 283 | "source": [ 284 | "Let's verify with finite differences that our result is correct:" 285 | ] 286 | }, 287 | { 288 
| "cell_type": "code", 289 | "execution_count": 10, 290 | "id": "little-reverse", 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "name": "stdout", 295 | "output_type": "stream", 296 | "text": [ 297 | "[0.24998187 0.1965761 0.10502338]\n" 298 | ] 299 | } 300 | ], 301 | "source": [ 302 | "def first_finite_differences(f, x):\n", 303 | " eps = 1e-3\n", 304 | " return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in jnp.eye(len(x))])\n", 305 | "\n", 306 | "print(first_finite_differences(sum_logistic, x_small))" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "id": "incident-spyware", 312 | "metadata": {}, 313 | "source": [ 314 | "Taking derivatives is as easy as calling `grad()`. `grad()` and `jit()` compose and can be mixed arbitrarily. In the example above we took the derivative of `sum_logistic`; we can go further and interleave the two transformations:" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 11, 320 | "id": "genuine-arrow", 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "-0.035325605\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "id": "seeing-roman", 338 | "metadata": {}, 339 | "source": [ 340 | "For more advanced autodiff, you can use `jax.vjp()` for reverse-mode vector-Jacobian products and `jax.jvp()` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another and with other JAX transformations. Here is one way to compose them into a function that efficiently computes full Hessian matrices:" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 12, 346 | "id": "christian-tucson", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "from jax import jacfwd, jacrev\n", 351 | "\n", 352 | "def hessian(fun):\n", 353 | " return jit(jacfwd(jacrev(fun)))" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "id": "administrative-posting", 359 | "metadata": {}, 360 | "source": [ 361 | "## Auto-Vectorization with `vmap()`\n", 362 | "\n", 363 | 
"JAX has one more transformation in its API that you might find useful: `vmap()`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into the primitive operations of the function for better performance. When composed with `jit()`, it can be just as fast as adding the batch dimensions by hand. Here we use a simple example and promote matrix-vector products to matrix-matrix products with `vmap()`. Although this is easy to do by hand in this specific case, the same technique applies to more complicated functions." 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 13, 369 | "id": "instant-recovery", 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "mat = random.normal(key, (150, 100))\n", 374 | "batched_x = random.normal(key, (10, 100))\n", 375 | "\n", 376 | "def apply_matrix(v):\n", 377 | " return jnp.dot(mat, v)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "id": "cardiovascular-height", 383 | "metadata": {}, 384 | "source": [ 385 | "Given a function such as `apply_matrix()`, we can loop over the batch dimension in Python, but the performance of doing so is usually poor." 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 14, 391 | "id": "advanced-gothic", 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "name": "stdout", 396 | "output_type": "stream", 397 | "text": [ 398 | "Naively batched\n", 399 | "2.99 ms ± 62.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "def naively_batched_apply_matrix(v_batched):\n", 405 | " return jnp.stack([apply_matrix(v) for v in v_batched])\n", 406 | "\n", 407 | "print('Naively batched')\n", 408 | "%timeit naively_batched_apply_matrix(batched_x).block_until_ready()" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "id": "junior-kenya", 414 | "metadata": {}, 415 | "source": [ 416 | "We know how to batch this operation by hand. In this case, `jnp.dot()` handles the extra batch dimension transparently." 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 15, 422 | "id": "dutch-movement", 423 | "metadata": {}, 424 | "outputs": [ 425 | { 426 | "name": "stdout", 427 | "output_type": "stream", 428 | "text": [ 429 | "Manually batched\n", 430 | "34.3 µs ± 2.2 µs per loop (mean ± std. dev. 
of 7 runs, 10000 loops each)\n" 431 | ] 432 | } 433 | ], 434 | "source": [ 435 | "@jit\n", 436 | "def batched_apply_matrix(v_batched):\n", 437 | " return jnp.dot(v_batched, mat.T)\n", 438 | "\n", 439 | "print('Manually batched')\n", 440 | "%timeit batched_apply_matrix(batched_x).block_until_ready()" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "id": "interior-worker", 446 | "metadata": {}, 447 | "source": [ 448 | "However, suppose we had a more complicated function without batching support. We can use `vmap()` to add batching support automatically:" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 16, 454 | "id": "meaningful-bowling", 455 | "metadata": {}, 456 | "outputs": [ 457 | { 458 | "name": "stdout", 459 | "output_type": "stream", 460 | "text": [ 461 | "Auto-vectorized with vmap\n", 462 | "36.8 µs ± 6.56 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 463 | ] 464 | } 465 | ], 466 | "source": [ 467 | "@jit\n", 468 | "def vmap_batched_apply_matrix(v_batched):\n", 469 | " return vmap(apply_matrix)(v_batched)\n", 470 | "\n", 471 | "print('Auto-vectorized with vmap')\n", 472 | "%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()" 473 | ] 474 | }, 475 | { 476 | "cell_type": "markdown", 477 | "id": "obvious-auction", 478 | "metadata": {}, 479 | "source": [ 480 | "Of course, `vmap()` can be composed arbitrarily with `jit()`, `grad()`, and any other JAX transformation. This is just a taste of what JAX can do. We're excited to see what you do with it!" 
481 | ] 482 | } 483 | ], 484 | "metadata": { 485 | "kernelspec": { 486 | "display_name": "Python 3", 487 | "language": "python", 488 | "name": "python3" 489 | }, 490 | "language_info": { 491 | "codemirror_mode": { 492 | "name": "ipython", 493 | "version": 3 494 | }, 495 | "file_extension": ".py", 496 | "mimetype": "text/x-python", 497 | "name": "python", 498 | "nbconvert_exporter": "python", 499 | "pygments_lexer": "ipython3", 500 | "version": "3.9.2" 501 | } 502 | }, 503 | "nbformat": 4, 504 | "nbformat_minor": 5 505 | } 506 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/.ipynb_checkpoints/1.4.1-加速版Numpy——JAX -checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/.ipynb_checkpoints/1.4.2-JAX的即时编译-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "technical-processing", 6 | "metadata": {}, 7 | "source": [ 8 | "# Just-In-Time Compilation with JAX\n", 9 | "\n", 10 | "> Authors: Rosalia Schneider & Vladimir Mikulik\n", 11 | "\n", 12 | "In this section, we take a closer look at how JAX works and how to make it performant. We discuss the `jax.jit()` transformation, which performs just-in-time (JIT) compilation of a JAX Python function so that it can be executed efficiently in XLA.\n", 13 | "\n", 14 | "## How JAX Transforms Work\n", 15 | "\n", 16 | "In the previous section, we saw that JAX lets us transform Python functions. It does so by first converting the Python function into a simple intermediate language called jaxpr. The transformations then operate on the jaxpr representation.\n", 17 | "\n", 18 | "We can show the jaxpr representation of a function with `jax.make_jaxpr`:" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "corrected-onion", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "{ lambda ; a.\n", 32 | " let b = log a\n", 33 | " c = log 2.0\n", 34 | " d = div b 
c\n", 35 | " in (d,) }\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "import jax\n", 41 | "import jax.numpy as jnp\n", 42 | "\n", 43 | "global_list = []\n", 44 | "\n", 45 | "def log2(x):\n", 46 | " global_list.append(x)\n", 47 | " ln_x = jnp.log(x)\n", 48 | " ln_2 = jnp.log(2.0)\n", 49 | " \n", 50 | " return ln_x / ln_2\n", 51 | "\n", 52 | "print(jax.make_jaxpr(log2)(3.0))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "southeast-spray", 58 | "metadata": {}, 59 | "source": [ 60 | "The [Understanding Jaxprs]() section of the documentation provides more information on the meaning of the output above.\n", 61 | "\n", 62 | "Importantly, note that the jaxpr does not capture the side effect of the function: nothing in it corresponds to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (that is, functionally pure) code. If *pure function* and *side effect* are unfamiliar terms, see [JAX - The Sharp Bits: 🔪 Pure Functions](https://render.githubusercontent.com/view/ipynb?color_mode=light&commit=fe4a5f85bf7936468ed39f20cced5b25a1612efb&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f726173696e2d7473756b7562612f4a41585f6368696e6573655f7475746f7269616c2f666534613566383562663739333634363865643339663230636365643562323561313631326566622f6f6666696369616c2d7475746f7269616c732f47657474696e67537461727465642f312e332d4a41582545392539342538422545382538412539322545362541462539352545392539432542322e6970796e62&nwo=rasin-tsukuba%2FJAX_chinese_tutorial&path=official-tutorials%2FGettingStarted%2F1.3-JAX%E9%94%8B%E8%8A%92%E6%AF%95%E9%9C%B2.ipynb&repository_id=349726397&repository_type=Repository#%F0%9F%94%AA%E7%BA%AF%E5%87%BD%E6%95%B0).\n", 63 | "\n", 64 | "Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behavior once they are converted to jaxpr. As a rule of thumb, you can expect (but should not rely on) the side effects of a JAX-transformed function to run once, during the first call, and never again. This follows from the way JAX generates jaxpr, a process called tracing.\n", 65 | "\n", 66 | "When tracing, JAX wraps each argument in a tracer object. These tracers record every JAX operation performed on them during the function call (which happens in regular Python). JAX then uses the tracer records to reconstruct the entire function, and the output of that reconstruction is the jaxpr. Because the tracers do not record Python side effects, those do not appear in the jaxpr. The side effects still happen during the trace itself, however.\n", 67 | "\n", 68 | "Note: Python's `print()` is not pure, since the text output is a side effect of the function. Any `print()` call therefore happens only during tracing and does not appear in the jaxpr:" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 2, 74 | "id": "geological-nicaragua", 75 | "metadata": {}, 76 | "outputs": [ 
77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "printed x: Tracedwith\n", 82 | "{ lambda ; a.\n", 83 | " let b = log a\n", 84 | " c = log 2.0\n", 85 | " d = convert_element_type[ new_dtype=float32\n", 86 | " weak_type=False ] b\n", 87 | " e = div d c\n", 88 | " in (e,) }\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "def log2_with_print(x):\n", 94 | " print(\"printed x: \", x)\n", 95 | " ln_x = jnp.log(x)\n", 96 | " ln_2 = jnp.log(2)\n", 97 | " return ln_x / ln_2\n", 98 | "\n", 99 | "print(jax.make_jaxpr(log2_with_print)(3.))" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "worse-london", 105 | "metadata": {}, 106 | "source": [ 107 | "See how the printed `x` is a `Traced` object? That is the JAX internals at work.\n", 108 | "\n", 109 | "The fact that the Python code runs at least once is strictly an implementation detail, so it should not be relied upon. It is still useful to understand, because you can use it when debugging to print out intermediate values of a computation.\n", 110 | "\n", 111 | "A key thing to understand is that the jaxpr captures the function as executed on the arguments given to it. For example, if we have a conditional, the jaxpr only knows about the branch we take:" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 3, 117 | "id": "parallel-physiology", 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "{ lambda ; a.\n", 125 | " let \n", 126 | " in (a,) }\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "def log2_if_rank_2(x):\n", 132 | " if x.ndim == 2:\n", 133 | " ln_x = jnp.log(x)\n", 134 | " ln_2 = jnp.log(2)\n", 135 | " return ln_x / ln_2\n", 136 | " else:\n", 137 | " return x\n", 138 | " \n", 139 | "print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1., 2., 3.])))" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "indirect-witness", 145 | "metadata": {}, 146 | "source": [ 147 | "## JIT Compiling a Function\n", 148 | "\n", 149 | "As explained before, JAX lets operations execute on the CPU, GPU, and TPU using the same code. Let's look at an example that computes a Scaled Exponential Linear Unit (SELU), an operation commonly used in deep learning:" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 5, 155 | "id": "swiss-basis", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": 
"stream", 161 | "text": [ 162 | "866 µs ± 20.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "import jax\n", 168 | "import jax.numpy as jnp\n", 169 | "\n", 170 | "def selu(x, alpha=1.67, lambda_=1.05):\n", 171 | " return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", 172 | "\n", 173 | "x = jnp.arange(1000000)\n", 174 | "%timeit selu(x).block_until_ready()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "human-neighbor", 180 | "metadata": {}, 181 | "source": [ 182 | "The code above sends one operation at a time to the accelerator, which limits the ability of the XLA compiler to optimize our function.\n", 183 | "\n", 184 | "Naturally, what we want is to give the XLA compiler as much code as possible so that it can fully optimize it. For this purpose, JAX provides the `jax.jit` transformation, which JIT-compiles a JAX-compatible function. The example below shows how to use JIT to speed up this function:" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "id": "written-spectrum", 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "52.3 µs ± 473 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "selu_jit = jax.jit(selu)\n", 203 | "\n", 204 | "# warm up\n", 205 | "# the first call triggers tracing and compilation\n", 206 | "selu_jit(x).block_until_ready()\n", 207 | "\n", 208 | "%timeit selu_jit(x).block_until_ready()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "id": "textile-editor", 214 | "metadata": {}, 215 | "source": [ 216 | "Here is what just happened:\n", 217 | "\n", 218 | "1. We defined `selu_jit` as the compiled version of `selu`.\n", 219 | "2. 
We ran `selu_jit` once on `x`. This is where JAX does its tracing; after all, it needs some inputs to wrap in tracers. The jaxpr is then compiled with XLA into very efficient code optimized for your GPU or TPU. Subsequent calls to `selu_jit` now use that code, skipping our old Python implementation entirely.\n", 220 | "\n" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "incomplete-exchange", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [] 230 | } 231 | ], 232 | "metadata": { 233 | "kernelspec": { 234 | "display_name": "Python 3", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.9.2" 249 | } 250 | }, 251 | "nbformat": 4, 252 | "nbformat_minor": 5 253 | } 254 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/.ipynb_checkpoints/1.4.3-JAX的自动向量化-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/.ipynb_checkpoints/1.4.4-JAX中的高级自动微分-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/1.4.1-加速版Numpy——JAX .ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "textile-finger", 6 | "metadata": {}, 7 | "source": [ 8 | "# JAX as Accelerated NumPy\n", 9 | "\n", 10 | "> 
Authors: Rosalia Schneider & Vladimir Mikulik\n", 11 | "\n", 12 | "In this first section, we learn the very fundamentals of JAX.\n", 13 | "\n", 14 | "## Getting Started with JAX NumPy\n", 15 | "\n", 16 | "Fundamentally, JAX is a library that enables transformations of array-manipulating programs written with a NumPy-like API. Over the course of this series of guides, we unpack exactly what that means. For now, you can think of JAX as differentiable NumPy that runs on accelerators.\n", 17 | "\n", 18 | "The code below shows how to import JAX and create a vector:" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "studied-dispatch", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "[0 1 2 3 4 5 6 7 8 9]\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "import jax\n", 37 | "import jax.numpy as jnp\n", 38 | "\n", 39 | "x = jnp.arange(10)\n", 40 | "print(x)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "id": "thousand-amendment", 46 | "metadata": {}, 47 | "source": [ 48 | "So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs run just as well in JAX if you substitute `jnp` for `np`. There are some important differences, however, which we touch on at the end of this section.\n", 49 | "\n", 50 | "You can notice the first difference if you check the type of `x`: it is a variable of type `DeviceArray`, which is the way JAX represents arrays." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "id": "floral-gnome", 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)" 63 | ] 64 | }, 65 | "execution_count": 2, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "x" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "surrounded-computer", 77 | "metadata": {}, 78 | "source": [ 79 | "One useful feature of JAX is that the same code can run on different backends: CPU, GPU, and TPU.\n", 80 | "\n", 81 | "We now perform a dot product to demonstrate that it can be done on different devices without changing the code. We use `%timeit` to check its performance.\n", 82 | "\n", 83 | "A technical detail: when a JAX function is called, the corresponding operation is dispatched to an accelerator to be computed asynchronously when possible. The returned array is therefore not necessarily filled in as soon as the function returns. If we don't require the result immediately, the computation does not block Python execution. So unless we call `block_until_ready`, we only time the dispatch, not the actual computation. See [Asynchronous Dispatch](https://github.com/rasin-tsukuba/JAX_chinese_tutorial/blob/main/official-tutorials/ReferenceDocumentation/2.3-%E5%BC%82%E6%AD%A5%E8%B0%83%E5%BA%A6.ipynb) for details." 84 | ] 85 | }, 86 | { 87 | 
"cell_type": "code", 88 | "execution_count": 3, 89 | "id": "entire-aberdeen", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "218 µs ± 1.97 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "long_vector = jnp.arange(int(1e7))\n", 102 | "\n", 103 | "%timeit jnp.dot(long_vector, long_vector).block_until_ready()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "id": "addressed-preparation", 109 | "metadata": {}, 110 | "source": [ 111 | "Tip: Try running the code above twice, once without an accelerator and once with a GPU runtime. (In Colab, click Runtime -> Change Runtime Type and choose GPU.) Notice how much faster it runs on a GPU.\n", 112 | "\n", 113 | "## JAX's First Transformation: `grad`\n", 114 | "\n", 115 | "A fundamental feature of JAX is that it lets you transform functions. One of the most commonly used transformations is `jax.grad`, which takes a numerical function written in Python and returns a new Python function that computes the gradient of the original function.\n", 116 | "\n", 117 | "We first define a function that takes an array and returns the sum of squares." 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "id": "laughing-senator", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def sum_of_squares(x):\n", 128 | " return jnp.sum(x ** 2)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "outside-graham", 134 | "metadata": {}, 135 | "source": [ 136 | "Applying `jax.grad` to `sum_of_squares` returns a different function, namely the gradient of `sum_of_squares` with respect to its first parameter `x`. You can then call that function on an array to get the derivatives with respect to each element of the array." 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 5, 142 | "id": "selected-discipline", 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "30.0\n", 150 | "[2. 4. 6. 
8.]\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "sum_of_squares_dx = jax.grad(sum_of_squares)\n", 156 | "x = jnp.asarray([1., 2., 3., 4.])\n", 157 | "print(sum_of_squares(x))\n", 158 | "print(sum_of_squares_dx(x))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "id": "cubic-maldives", 164 | "metadata": {}, 165 | "source": [ 166 | "You can think of `jax.grad` by analogy to the $\\nabla$ operator from vector calculus. Given a function $f(x)$, $\\nabla f$ represents the function that computes the gradient of $f$, that is:\n", 167 | "\n", 168 | "$$\n", 169 | "(\\nabla f)(x)_i = \\frac{\\partial f}{\\partial x_i}(x)\n", 170 | "$$\n", 171 | "\n", 172 | "Analogously, since `jax.grad(f)` is the function that computes the gradient, `jax.grad(f)(x)` is the gradient of `f` evaluated at `x`. (Like $\\nabla$, `jax.grad` only works on functions with a scalar output; it raises an error otherwise.)\n", 173 | "\n", 174 | "This makes the JAX API quite different from other autodiff libraries such as TensorFlow and PyTorch, where we compute the gradient from the loss tensor itself (for example by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you get used to this way of doing things, it feels natural: the loss function in your code really is a function of parameters and data, and you find its gradient just as you would in the math.\n", 175 | "\n", 176 | "This way of doing things makes it straightforward to control things such as which variables to differentiate with respect to. By default, `jax.grad` finds the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` is the gradient of `sum_squared_error` with respect to `x`:" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 6, 182 | "id": "unknown-pride", 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "[-0.20000005 -0.19999981 -0.19999981 -0.19999981]\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "def sum_squared_error(x, y):\n", 195 | " return jnp.sum((x-y) ** 2)\n", 196 | "\n", 197 | "sum_squared_error_dx = jax.grad(sum_squared_error)\n", 198 | "\n", 199 | "y = jnp.asarray([1.1, 2.1, 3.1, 4.1])\n", 200 | "print(sum_squared_error_dx(x, y))" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "id": "breeding-catalyst", 206 | "metadata": {}, 207 | "source": [ 208 | "To find the gradient with respect to a different argument (or several), you can set `argnums`:" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 7, 214 | "id": "fluid-lafayette", 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 
| "text/plain": [ 220 | "(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n", 221 | " DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))" 222 | ] 223 | }, 224 | "execution_count": 7, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "jax.grad(sum_squared_error, argnums=(0, 1))(x, y)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "id": "existing-panic", 236 | "metadata": {}, 237 | "source": [ 238 | "Does this mean that for machine learning we need to write functions with gigantic argument lists, with one argument for each model parameter array? No. JAX comes equipped with machinery for bundling arrays together in data structures called pytrees; see a [later guide]() for more. So, most often, `jax.grad` is used like this:" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "id": "pharmaceutical-provincial", 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "def loss_fn(params, data):\n", 249 | " ...\n", 250 | " \n", 251 | "grads = jax.grad(loss_fn)(params, data_batch)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "id": "fatal-centre", 257 | "metadata": {}, 258 | "source": [ 259 | "Here `params` is, for example, a nested dict of arrays, and the returned `grads` is another nested dict of arrays with the same structure.\n", 260 | "\n", 261 | "## Value and Grad\n", 262 | "\n", 263 | "Often you need both the value and the gradient of a function, for example if you want to log the training loss. JAX has a handy sister transformation for doing this efficiently:" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 9, 269 | "id": "limiting-integral", 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "data": { 274 | "text/plain": [ 275 | "(DeviceArray(0.03999995, dtype=float32),\n", 276 | " DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))" 277 | ] 278 | }, 279 | "execution_count": 9, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "jax.value_and_grad(sum_squared_error)(x, y)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "id": "canadian-threshold", 291 | "metadata": {}, 292 | "source": [ 293 | "This returns a tuple of `(value, grad)`. To be precise, for any `f`:" 294 | ] 295 | }, 
296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "cathedral-projector", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs)) " 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "id": "incorporated-muscle", 309 | "metadata": {}, 310 | "source": [ 311 | "## Auxiliary Data\n", 312 | "\n", 313 | "In addition to logging the value, we often want to report some intermediate results obtained while computing the loss function. If we try to do that with regular `jax.grad`, however, we run into trouble:\n", 314 | "\n" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 10, 320 | "id": "running-forward", 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "ename": "TypeError", 325 | "evalue": "Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).", 326 | "output_type": "error", 327 | "traceback": [ 328 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 329 | "\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)", 330 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msqared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 331 | "\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. 
Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames.\nThe following is the original exception that occurred, unmodified.\n\n--------------------", 332 | "\nThe above exception was the direct cause of the following exception:\n", 333 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 334 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/api.py\u001b[0m in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m \u001b[0maval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 842\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 335 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mget_aval\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 920\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 921\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mconcrete_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 922\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 336 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mconcrete_aval\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 912\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mconcrete_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__jax_array__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 913\u001b[0;31m raise TypeError(f\"Value {repr(x)} with type {type(x)} is not a valid JAX \"\n\u001b[0m\u001b[1;32m 914\u001b[0m \"type\")\n", 337 | "\u001b[0;31mTypeError\u001b[0m: Value (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)) with type is not a valid JAX type", 338 | "\nThe above exception was the direct cause of the following exception:\n", 339 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 340 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msum_squared_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msqared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 341 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 342 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 758\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mapi_boundary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 759\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 760\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 
761\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 762\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 343 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 344 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 824\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 825\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 826\u001b[0;31m \u001b[0m_check_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 827\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 828\u001b[0m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_output_dtype_grad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mholomorphic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 345 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/api.py\u001b[0m in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 841\u001b[0m \u001b[0maval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 842\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 843\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"was {x}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m 
\u001b[0;32mfrom\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 844\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 845\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mShapedArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 346 | "\u001b[0;31mTypeError\u001b[0m: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))." 347 | ] 348 | } 349 | ], 350 | "source": [ 351 | "def sqared_error_with_aux(x, y):\n", 352 | " return sum_squared_error(x, y), x-y\n", 353 | "\n", 354 | "jax.grad(sqared_error_with_aux)(x, y)" 355 | ] 356 | }, 357 | { 358 | "cell_type": "markdown", 359 | "id": "artistic-neutral", 360 | "metadata": {}, 361 | "source": [ 362 | "这是因为 `jax.grad` 仅在标量函数上定义,并且我们的新函数返回一个元组。但我们也需要返回一个元组以返回中间结果。所以我们需要 `has_aux`出现:" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 11, 368 | "id": "latest-species", 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "data": { 373 | "text/plain": [ 374 | "(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n", 375 | " DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))" 376 | ] 377 | }, 378 | "execution_count": 11, 379 | "metadata": {}, 380 | "output_type": "execute_result" 381 | } 382 | ], 383 | "source": [ 384 | "jax.grad(sqared_error_with_aux, has_aux=True)(x, y)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "id": "academic-bride", 390 | "metadata": {}, 391 | "source": [ 392 | "`has_aux`表示该函数返回一对 `(out, aux)`。它使`jax.grad`忽略`aux`,将其传递给用户,同时区分函数,就好像只返回了`out`一样。\n", 393 | "\n", 394 | "## 和NumPy的不同\n", 395 | "\n", 396 
|     "The `jax.numpy` API closely follows the NumPy API. There are some important differences, however. Later guides cover many of them; for now we point out one.\n",
397 |     "\n",
398 |     "The most important difference, and in some sense the root of all the others, is that JAX is designed to be functional, as in functional programming. The reason is that the kinds of program transformations JAX enables are much more feasible on functional-style programs.\n",
399 |     "\n",
400 |     "An introduction to functional programming (FP) is beyond the scope of this guide. If you are already familiar with FP, your intuition will help while learning JAX. If not, don't worry! The one feature of functional programming you need to grok when working with JAX is very simple: don't write code with side effects.\n",
401 |     "\n",
402 |     "A side effect is any effect of a function that does not appear in its output. One example is modifying an array in place:"
403 |    ]
404 |   },
405 |   {
406 |    "cell_type": "code",
407 |    "execution_count": 12,
408 |    "id": "matched-heart",
409 |    "metadata": {},
410 |    "outputs": [
411 |     {
412 |      "data": {
413 |       "text/plain": [
414 |        "array([123, 2, 3])"
415 |       ]
416 |      },
417 |      "execution_count": 12,
418 |      "metadata": {},
419 |      "output_type": "execute_result"
420 |     }
421 |    ],
422 |    "source": [
423 |     "import numpy as np\n",
424 |     "\n",
425 |     "x = np.array([1, 2, 3])\n",
426 |     "\n",
427 |     "def in_place_modify(x):\n",
428 |     "    x[0] = 123\n",
429 |     "    return None\n",
430 |     "\n",
431 |     "in_place_modify(x)\n",
432 |     "x"
433 |    ]
434 |   },
435 |   {
436 |    "cell_type": "markdown",
437 |    "id": "loose-shopping",
438 |    "metadata": {},
439 |    "source": [
440 |     "The side-effectful function modifies its argument but returns a completely unrelated value. The modification is a side effect.\n",
441 |     "\n",
442 |     "This code runs in NumPy, but JAX arrays do not allow in-place modification:"
443 |    ]
444 |   },
445 |   {
446 |    "cell_type": "code",
447 |    "execution_count": 14,
448 |    "id": "pointed-alberta",
449 |    "metadata": {},
450 |    "outputs": [
451 |     {
452 |      "ename": "TypeError",
453 |      "evalue": "'' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?",
454 |      "output_type": "error",
455 |      "traceback": [
456 |       "---------------------------------------------------------------------------",
457 |       "TypeError Traceback (most recent call last)",
458 |       " in \n----> 1 in_place_modify(jnp.array(x)) # Raises an error once the input is cast to a `jnp.ndarray`\n",
459 |       " in in_place_modify(x)\n 4 \n 5 def in_place_modify(x):\n----> 6 x[0] = 123\n 7 return None\n 8 \n",
460 |       "~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _unimplemented_setitem(self, i, x)\n 5264 \"immutable; perhaps you want jax.ops.index_update or \"\n 5265 \"jax.ops.index_add instead?\")\n-> 5266 raise TypeError(msg.format(type(self)))\n 5267 \n 5268 def _operator_round(number, ndigits=None):\n",
461 |       "TypeError: '' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?"
462 |      ]
463 |     }
464 |    ],
465 |    "source": [
466 |     "in_place_modify(jnp.array(x))  # Raises an error once the input is cast to a `jnp.ndarray`\n"
467 |    ]
468 |   },
469 |   {
470 |    "cell_type": "markdown",
471 |    "id": "alike-eight",
472 |    "metadata": {},
473 |    "source": [
474 |     "Helpfully, the error points us to JAX's side-effect-free way of doing the same thing, the `jax.ops.index_*` ops. They are analogous to in-place modification by index, but they create a new array with the corresponding modification applied. (Note that recent JAX releases have removed the `jax.ops.index_*` functions in favor of the `x.at[idx].set(value)` syntax; the outputs in this notebook were recorded with an older release.)"
475 |    ]
476 |   },
477 |   {
478 |    "cell_type": "code",
479 |    "execution_count": 15,
480 |    "id": "israeli-monitor",
481 |    "metadata": {},
482 |    "outputs": [
483 |     {
484 |      "data": {
485 |       "text/plain": [
486 |        "DeviceArray([123, 2, 3], dtype=int32)"
487 |       ]
488 |      },
489 |      "execution_count": 15,
490 |      "metadata": {},
491 |      "output_type": "execute_result"
492 |     }
493 |    ],
494 |    "source": [
495 |     "def jax_in_place_modify(x):\n",
496 |     "    return jax.ops.index_update(x, 0, 123)\n",
497 |     "\n",
498 |     "y = jnp.array([1, 2, 3])\n",
499 |     "jax_in_place_modify(y)"
500 |    ]
501 |   },
502 |   {
503 |    "cell_type": "markdown",
504 |    "id": "convertible-neighbor",
505 |    "metadata": {},
506 |    "source": [
507 |     "Note that the old array is untouched, so there is no side effect:"
508 |    ]
509 |   },
510 |   {
511 |    "cell_type": "code",
512 |    "execution_count": 17,
513 |    "id": "automated-summary",
514 |    "metadata": {},
515 |    "outputs": [
516 |     {
517 |      "data": {
518 |       "text/plain": [
519 |        "DeviceArray([1, 2, 3], dtype=int32)"
520 |       ]
521 |      },
522 |      "execution_count": 17,
523 |      "metadata": {},
524 |      "output_type": "execute_result"
525 |     }
526 |    ],
527 |    "source": [
528 |     "y"
529 |    ]
530 |   },
531 |   {
532 |    "cell_type": "markdown",
533 |    "id": "rational-creature",
534 |    "metadata": {},
535 |    "source": [
536 |     "Side-effect-free code is sometimes called functionally pure, or simply pure.\n",
537 |     "\n",
538 |     "Isn't the pure version less efficient? Strictly speaking, yes: we are creating a new array. However, as the next guide explains, JAX computations are often compiled before they run, using another transformation, `jax.jit`. If we do not use the old array after 'modifying' it with `jax.ops.index_update()`, the compiler can recognise that the operation can in fact be compiled to an in-place modification, which results in efficient code in the end.\n",
539 |     "\n",
540 |     "Of course, it is possible to mix side-effectful Python code with functionally pure JAX code, and we will cover this in more detail later. As you become more familiar with JAX, you will learn how and when this works. As a rule of thumb, any function intended to be transformed by JAX should avoid side effects, and the JAX primitives themselves will try to help you do that.\n",
541 |     "\n",
542 |     "We will explain other JAX idiosyncrasies as they become relevant. There is even a section devoted entirely to getting used to the functional style of handling state: [Part 7: The Problem of State]. If you are impatient, however, you can read [JAX: The Sharp Bits](https://github.com/rasin-tsukuba/JAX_chinese_tutorial/blob/main/official-tutorials/GettingStarted/1.3-JAX%E9%94%8B%E8%8A%92%E6%AF%95%E9%9C%B2.ipynb) in the JAX documentation.\n",
543 |     "\n",
544 |     "## Your first JAX training loop\n",
545 |     "\n",
546 |     "There is still a lot to learn about JAX, but you already know enough to understand how to build a simple training loop with it.\n",
547 |     "\n",
548 |     "To keep things simple, we start with linear regression. Our data is sampled according to $y = w_{true} x + b_{true} + \\epsilon$:"
549 |    ]
550 |   },
551 |   {
552 |    "cell_type": "code",
553 |    "execution_count": 18,
554 |    "id": "outer-graphic",
555 |    "metadata": {},
556 |    "outputs": [
557 |     {
558 |      "data": {
559 |       "text/plain": [
560 |        ""
561 |       ]
562 |      },
563 |      "execution_count": 18,
564 |      "metadata": {},
565 |      "output_type": "execute_result"
566 |     },
567 |     {
568 |      "data": {
569 |       "image/png": "\n",
570 |       "text/plain": [
571 |        ""
572 |       ]
573 |      },
574 |      "metadata": {
575 |       "needs_background": "light"
576 |      },
577 |      "output_type": "display_data"
578 |     }
579 |    ],
580 |    "source": [
581 |     "import numpy as np\n",
582 |     "import matplotlib.pyplot as plt\n",
583 |     "\n",
584 |     "xs = np.random.normal(size=(100,))\n",
585 |     "noise = np.random.normal(scale=0.1, size=(100,))\n",
586 |     "ys = xs * 3 - 1 + noise\n",
587 |     "\n",
588 |     "plt.scatter(xs, ys)"
589 |    ]
590 |   },
591 |   {
592 |    "cell_type": "markdown",
593 |    "id": "qualified-calcium",
594 |    "metadata": {},
595 |    "source": [
596 |     "Our model is therefore $\\hat{y}(x; \\theta) = wx + b$.\n",
597 |     "\n",
598 |     "We will use a single array, `theta = [w, b]`, to hold both parameters:"
599 |    ]
600 |   },
601 |   {
602 |    "cell_type": "code",
603 |    "execution_count": 19,
604 |    "id": "gentle-tunnel",
605 |    "metadata": {},
606 |    "outputs": [],
607 |    "source": [
608 |     "def model(theta, x):\n",
609 |     "    \"\"\"\n",
610 |     "    Computes wx + b on a batch of input x.\n",
612 |     "    \"\"\"\n",
613 |     "    \n",
614 |     "    w, b = theta\n",
615 |     "    return w * x + b\n",
616 |     "    "
617 |    ]
618 |   },
619 |   {
620 |    "cell_type": "markdown",
621 |    "id": "monetary-sampling",
622 |    "metadata": {},
623 |    "source": [
624 |     "The loss function is $J(x, y; \\theta) = (\\hat{y} - y)^2$."
625 |    ]
626 |   },
627 |   {
628 |    "cell_type": "code",
629 |    "execution_count": 20,
630 |    "id": "straight-spokesman",
631 |    "metadata": {},
632 |    "outputs": [],
633 |    "source": [
634 |     "def loss_fn(theta, x, y):\n",
635 |     "    prediction = model(theta, x)\n",
636 |     "    return jnp.mean((prediction - y) ** 2)"
637 |    ]
638 |   },
639 |   {
640 |    "cell_type": "markdown",
641 |    "id": "located-relief",
642 |    "metadata": {},
643 |    "source": [
644 |     "How do we optimize the loss function? With gradient descent. At each update step, we compute the gradient of the loss with respect to the parameters and take a small step in the direction of steepest descent, $\\theta_{new} = \\theta - 0.1 (\\nabla_{\\theta} J)(x, y; \\theta)$:"
645 |    ]
646 |   },
647 |   {
648 |    "cell_type": "code",
649 |    "execution_count": 21,
650 |    "id": "moral-buying",
651 |    "metadata": {},
652 |    "outputs": [],
653 |    "source": [
654 |     "def 
update(theta, x, y, lr=0.1):\n", 655 | " return theta - lr * jax.grad(loss_fn)(theta, x, y)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "markdown", 660 | "id": "running-straight", 661 | "metadata": {}, 662 | "source": [ 663 | "在JAX中,通常会定义一个`update()`函数,该函数将在每个步骤中调用,将当前参数作为输入并返回新参数。这是JAX函数式的自然结果,在[状态的问题]一节中将会有更详细的解释。\n", 664 | "\n", 665 | "之后,我们可以对该函数进行整体的JIT编译,以实现最高效率。下一节将会更加确切地解释`jax.jit`的工作原理,但如果您愿意,您可以尝试在`update()`定义之前添加 `@jax.jit`,并查看下面的训练循环如何更快地运行:" 666 | ] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "execution_count": 22, 671 | "id": "written-devon", 672 | "metadata": {}, 673 | "outputs": [ 674 | { 675 | "name": "stdout", 676 | "output_type": "stream", 677 | "text": [ 678 | "w: 2.98, b: -1.00\n" 679 | ] 680 | }, 681 | { 682 | "data": { 683 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAcDElEQVR4nO3deZSU9Z3v8fe3m0ILXJooLrQgZnQ0ooLa4j6JuOPudcGTzYkzRK9mdK5jBq4bGjXOtBo90US5msxJgsZxxBZFbWFiYkICI9Boi4pBRKQwURQQpcVevvePri6rq6u6q7qeqqeWz+scDvUs9POtaD78/D2/xdwdERGpfDVhFyAiIsWhwBcRqRIKfBGRKqHAFxGpEgp8EZEqMSTsAvqz8847+9ixY8MuQ0SkbCxZsmS9u49Md62kA3/s2LEsXrw47DJERMqGmb2T6Zq6dEREqoQCX0SkSijwRUSqhAJfRKRKKPBFRKqEAl9EpEqU9LBMEZFq0tQSo7F5Bes2tjGqLso1J+/L2QfXB/bzFfgiIiWgqSXG9NmttLV3AhDb2Mb02a0AgYW+unRERErA7c++kQj7Hm3tnTQ2rwjsGQp8EZGQ3fTUcv7y8Wdpr63b2BbYc9SlIyISkrfXf8pxd/y233tG1UUDe55a+CIiRebuXD5raa+wv/3cA4lGanvdF43Ucs3J+wb2XLXwRUSK6NXYJk7/8R8Sx3ddMJ5zD9kDgG0jtRqlIyJS7rq6nAse+BOL39kAwE7Dh7Jg2iS2TWrVn31wfaABn0qBLyJSYAtWrufrDy5KHP/84sM4br9dil5H3oFvZvsCjyad+jJwg7vfnXTP14Angbfjp2a7+835PltEpJS1d3Zx3B2/Ze2G7pE2++++A0997xhqayyUevIOfHdfAUwAMLNaIAY8kebW37v76fk+T0SkHDzT+h7/e9bSxPHjlx3FoXuOCLGi4Lt0jgfecveMO66IiFSyLZ93MP6m52nvdAAm7bcLD327AbNwWvXJgg78KcAjGa4daWYvA+uAf3H35eluMrOpwFSAMWPGBFyeiEjh/HLhO1zf9GrieN4//x377Lp9iBX1Zu4ezA8yG0p3mI9z97+mXNsB6HL3T8xsMnCPu+8z0M9saGhw7WkrIqVuw6efc/AP5iWOpxw2mtv/10Gh1GJmS9y9Id21IFv4pwJLU8MewN0/Tvr8jJn9xMx2dvf1AT5fRKTo7
pn/Z340/83E8YJpk6gPcHZskIIM/IvI0J1jZrsBf3V3N7OJdM/w/TDAZ4uIFNW6jW0cdftvEsffm7Q3V58U3KzYQggk8M1sOHAi8N2kc5cCuPv9wHnAZWbWAbQBUzyoviQRkSK7rqmVXy1ckzheev2JfGn40BAryk4gge/unwI7pZy7P+nzvcC9QTxLRCQsK9/fzAl3vZg4vunMcXz7qLHhFZQjzbQVERmAu/OPv1jC/Ne7X1HWGLTOOJnh25RXhJZXtSIiRdayZgPn/OSPieMfX3QwZ4wfFWJFg6fAFxFJo7PLOfu+BbTGNgEwasdt+e01xzF0SPmuKq/AFxFJ8bs3P+DbP/ufxPEvL5nIsfuMDLGiYCjwRUTitnZ0cuy/vcD7m7cCMGF0HbMvO4qakBY7C5oCX0QEeHJZjCt/veyL48uPZvzoutDqKQQFvohUtU+2dnDAjc2J41PG7cZPv3FISSx2FjQFvohUrZ8veJubnnotcfzfV3+Vvxm5XYgVFZYCX0SqzoefbOXQW+Ynjr915J7cfNYBIVZUHAp8Eak4TS2xjJuBNza/wX0vvJW4d+H049ltx23DKrWoFPgiUlGaWmJMn91KW3snALGNbUyf3cpHn37OzU9/0X1z9Yl/y/eOH3CV9oqiwBeRitLYvCIR9j3a2jt7hf2yG06kbljpL3YWNAW+iFSUdRvbMl679ZwD+PrhexaxmtJSvnOERUTSGJVh85Hdd9y2qsMeFPgiUmGO/8oufc5FI7X86yn7hVBNaVGXjohUhPbOLva59tk+5+tTRulUMwW+iJS9B3+/ilvmvp44/vfzDuKChtEhVlSaFPgiUrY+3drBuKRlEQBW3Ta5YhY7C5oCX0TK0i1Pv8aDf3g7cfzziw/juP369t/LFxT4IlLSUmfNXvrVL3P9k8sT14fW1vDmraeGWGH5UOCLSMlKN2s2OeybLj+aCRW2hHEhBRb4ZrYa2Ax0Ah3u3pBy3YB7gMnAFuBid18a1PNFpPKkmzULMKTGWHnb5BAqKm9Bt/CPc/f1Ga6dCuwT/3U48NP47yIiaWWaNdvZ5UWupDIUc+LVWcAvvNtCoM7Mdi/i80WkjMx95T0yxXqm2bTSvyBb+A48b2YOPODuM1Ou1wPvJh2vjZ97L/kmM5sKTAUYM2ZMgOWJSLkYO21uxmvRSC3XnLxvEaupHEG28I9x90Po7rq53Mz+bjA/xN1nunuDuzeMHFn+u8SLSHaaWmIcNKO5V9hvt80Q7r5wAvV1UYzuWbM/PPdAzZodpMBa+O4ei//+vpk9AUwEXky6JQYkT33bI35ORKrc40vWcvVjL/c6t+2QGm45+wDOPrheAR+QQFr4ZjbczLbv+QycBLyactsc4FvW7Qhgk7u/h4hUta81vtAn7AE+6+iisXlFCBVVrqBa+LsCT8R3eR8CPOzuz5nZpQDufj/wDN1DMlfSPSzz7wN6toiUoc2ftXPgjOf7vae/te0ld4EEvruvAsanOX9/0mcHLg/ieSJS3lJfyg6pMTrSDLXUaJxgaT18ESmadz/a0ifs37ptMnecP55opLbXeY3GCZ6WVhCRokgN+vMO3YM7zu/uGOh5KZu8Zo7WsA+eAl9ECqJn0bNYmn741bef1uecRuMUngJfRAKXuuhZj7MmjOKeKQeHVJWoD19EAnfDk6+mXfRs8eoNIVQjPdTCF5FA9bcsgoZZhkuBLyKBuOLhpTz9Sv9zKTXMMlzq0hGRvLg7Y6fN7RX2/zRpbw2zLEFq4YvIoE28dT7vb97a61zPCJwvj9xOwyxLjAJfRHK2taOTfa97rte5BdMmUZ/UZaNhlqVHgS8iOUn3UjbduHopPQp8EcnKB5u3ctit83ude+3mkxk2VDFSLvRPSkQSembHrtvYxo7RCGawcUt7n60Ga2uMt7SJeNlR4IsI0Hd27Ma29rT3rbptMjU1VszSJCAKfJEq19+aN6lGDIso7MuYAl+kimVa8yaTDVvSt/qlPGjil
UgVa2xekXXYS/lT4ItUsVzXtqmLRgpUiRSDAl+kiuWytk2kxphx5rgCViOFpj58kSrT1BLjpqeWZ+yPj0Zq+eG5BwLagarSKPBFqkhTS4x/eWwZHV3pr9enBLsCvrLkHfhmNhr4BbAr4MBMd78n5Z6vAU8Cb8dPzXb3m/N9toj0L3ki1ai6aL9DL+vroiyYNqmI1UmxBdHC7wCudvelZrY9sMTM5rn7ayn3/d7dTw/geSLSj+Rx9QaJWbIDjbPX5iSVL+/Ad/f3gPfinzeb2etAPZAa+CJSYKnj6lOXROiPNiepfIGO0jGzscDBwKI0l480s5fN7Fkzy/iq38ymmtliM1v8wQcfBFmeSMUb7Lj6SI1pc5IqEFjgm9l2wOPAVe7+ccrlpcCe7j4e+DHQlOnnuPtMd29w94aRI0cGVZ5IRWtqiXH07b/JenmEEcO+GE9fF43QeP54vaCtAoGM0jGzCN1hP8vdZ6deT/4LwN2fMbOfmNnO7r4+iOeLVLNcl0e48YxxCvcqlXcL38wMeAh43d3vynDPbvH7MLOJ8ed+mO+zRQRuemp51mFfF40o7KtYEC38o4FvAq1mtix+7v8CYwDc/X7gPOAyM+sA2oAp7p7L+yQRSaOpJZb1gmaRWs2UrXZBjNL5A9Dveqnufi9wb77PEpHernp0WcZrZtDTrBoxLKKuHNFMW5FylG5f2VQ/umCCAl560eJpImUmm7BXX72koxa+SInKZlmEuy+c0GeETjRSq756SUuBL1JCMm03mHr8zyf8LVeesE/iWKtaSjYU+CIl4rqmVmYtXDPgcgh3X9i7b/7sg+sV8JIV9eGLlICmllhWYQ8wfXYrTS2xgtcklUeBL1ICGptXZL3QWVt7J43NKwpaj1QmdemIhCT5pWyusxC1lLEMhgJfJATZ9tdnoqWMZTAU+CJFkmkETq6ikVotZSyDosAXKYKmlhjXPPYy7V25t+kjtcbwoUPY1NauYZeSFwW+SIFd19TKrxauyereWjPuvGC8xtVLQSjwRQool7AH6HTXuHopGA3LFCmQppZYTmEPUK+XsVJACnyRAujZhSoXehkrhaYuHZEApC509unWjpw2E9d69VIMCnyRQcp2obNstNxwUlBliWSkwBcZhFw3Du+P+u2lWNSHLzIIjc0rBhX2kZreu4Gq316KSYEvMgiDXcum8fzx1NdFMbpb9j8890D120vRqEtHJEvJL2ZrzOj03GbN1tdFNcZeQhVIC9/MTjGzFWa20sympbm+jZk9Gr++yMzGBvFckWLpWRohFl/ZMtewV9eNlIK8W/hmVgvcB5wIrAVeMrM57v5a0m2XABvcfW8zmwL8G3Bhvs8WKbTBLni2zZAaopFarX8jJSWILp2JwEp3XwVgZr8GzgKSA/8sYEb8838B95qZuefYTBIpomxH4owYFmHDlnYA6qIRZpyp8fRSmoII/Hrg3aTjtcDhme5x9w4z2wTsBKxP/WFmNhWYCjBmzJgAyhMZnGxH4mgMvZSLkntp6+4zgZkADQ0N+i8AKapcd6EaMSxS8JpEghJE4MeA0UnHe8TPpbtnrZkNAXYEPgzg2SJ5u66plUcWvZvzi1iAG88YV4CKRAojiFE6LwH7mNleZjYUmALMSblnDvDt+OfzgN+o/15KQc/yxbmGvQHfOGKM+uqlrOTdwo/3yV8BNAO1wM/cfbmZ3Qwsdvc5wEPAL81sJfAR3X8piITukUXvDnxTXG187H29Rt1ImQqkD9/dnwGeSTl3Q9Lnz4Dzg3iWSJCyadnX10VZMG1SEaoRKSwtrSBVq6kl9VVTX5owJZWk5EbpiBTDtx5axIt/7jMquA+tdSOVRIEvVWfstLlZ3dez9o1IpVDgS9XINuhBXTlSmRT4UhVyCXuNwpFKpcCXipC6p2xPYKcL+tW3n5Z2nZxopFZ99lLRrJTnPzU0NPjixYvDLkNK3HVNrcxauKbXUgjRSG3adXBW335a4nOmvyREypmZLXH3h
nTX1MKXstbUEusT9kCfsE8O+h7ajESqjcbhS1lrbF4x4CJn6cJepBop8KVsNbXE+t2YpL4uqrAXSaIuHSkruexApWGVIr0p8KVsZLsDVQ/1z4v0psCXktfUEuOmp5YnthHMRn1dtIAViZQnBb6UtKaWGNf818u0d2Y/fFizZEXS00tbKWmNzStyCvsa04JnIpmohS8lJ9d9ZXtEao3G88Yr7EUyUOBLSUk3azYbWv9GZGAKfCkJ1zW1MmvRGnJd6UPr34hkT4EvoevZSDxbBjhq1YvkSoEvoWhqiTFjznI2tmU/1BK6w/5tzZ4VGRQFvhRdU0uMax57mfau3FdqHaXx9SKDllfgm1kjcAbwOfAW8PfuvjHNfauBzUAn0JFp6U6pDtNnvzKosNf4epH85DsOfx5wgLsfBLwJTO/n3uPcfYLCvno1tcT4yvXP0tbeldX9kRoYMSyC0d1fr5ezIvnJq4Xv7s8nHS4EzsuvHKlUX6yDk13Yf+OIMdxy9oEFrkqkugTZh/8d4NEM1xx43swceMDdZ2b6IWY2FZgKMGbMmADLk2JLnkCFkfWQS4W9SGEMGPhmNh/YLc2la939yfg91wIdwKwMP+YYd4+Z2S7APDN7w91fTHdj/C+DmdC9xWEW30FKRL+LnGXxT9KAryvsRQpmwMB39xP6u25mFwOnA8d7hg1y3T0W//19M3sCmAikDXwpT4NZ5CyZWvUihZfXS1szOwX4PnCmu2/JcM9wM9u+5zNwEvBqPs+V0pPrImc9akxhL1Is+fbh3wtsQ3c3DcBCd7/UzEYBD7r7ZGBX4In49SHAw+7+XJ7PlRKzLosdqFLVRSMsu/GkAlQjIunkO0pn7wzn1wGT459XAePzeY6UvlF10ay2HewRjdQy48xxBaxIRFJppq3kLHVZhBHDIjntRjViWIQbzxinMfUiRabAl5ykWxahv7CvAXYcFmHjlnZGabEzkVAp8CUnjc0r+l0W4e4LJyTG3ivgRUqLAl9y0l8/vQFnH1yvgBcpUQp8SSt5lmxPSz3DNIuEumGRIlUnIoOhwJc+vlj3phPobtVf9eiyAf9crrtViUhxKfClT2t+y+cdibDPxaYcNzMRkeJS4Fe5dK35wdLmJCKlLd/18KXMNTavyLo1XxeNEI3Upr2mzUlESp9a+FUu2yURkmfGNjavILaxjVozOt21mbhImVDgV7lslkRIDXQFu0h5UpdOlRuoG6a+LsqCaZMU8iIVQC38KnbCXb9j5fufZLyufnmRyqIWfpUaO21un7C/+8IJ1NdFtWm4SIVSC7/KjJ02t8+51beflvisgBepXGrhV5HUsN9+2yG9wl5EKpta+BUqefZsuhUPFPQi1UeBX2Gua2pl1sI1aUMe4JJj9uL60/cvak0iUhoU+BXkuqZWfrVwTcbr9XVRhb1IFVMffgV5ZNG7/V4fzEbjIlI5FPgVpHOA9Ym1uJlIdcsr8M1shpnFzGxZ/NfkDPedYmYrzGylmU3L55nS1x9Xrk873DKZJlGJSBB9+D9y9zsyXTSzWuA+4ERgLfCSmc1x99cCeHbV6hmFk81yxsOH1nLrOZpEJVLtivHSdiKw0t1XAZjZr4GzAAX+IDW1xLjmsZf7bCZ+1wXjWbpmA48sepdOd2rNuOjw0dxy9oEhVSoipSSIwL/CzL4FLAaudvcNKdfrgeS3iWuBwzP9MDObCkwFGDNmTADllb/UHakytervfP5NFkybpIAXkbQGDHwzmw/slubStcBPgR8AHv/9TuA7+RTk7jOBmQANDQ1VuUtqcsDvGI3w6ecdtHd2/0/RXxdOPrtViUjlGzDw3f2EbH6Qmf0/4Ok0l2LA6KTjPeLnJI3ULQc35rBPbK1ZocoSkQqQ7yid3ZMOzwFeTXPbS8A+ZraXmQ0FpgBz8nluJctly8FUAw3LFJHqlm8f/r+b2QS6u3RWA98FMLNRwIPuPtndO8zsCqAZqAV+5u7L83xuxcpnclS9xtmLSD/yC
nx3/2aG8+uAyUnHzwDP5POsSpP6IrZnC8FsthyM1BgYiX590Dh7ERmY1tIJQWo/fWxjG9Nnt9LV5QOGfa0ZjeePB0j7F4aISCYK/BCk66dva+/k/zz2cr9/Lhqp7bULlQJeRHKhwA/BQP30v//+cYz+0rCM3T4iIoOhwA9B3bAIG7akH26Zut2gAl5EgqLVMousqSXGxgxh/40jNLNYRApHLfwiu+rRZRmvvfDGB8UrRESqjgK/CJpaYtwy9zXWf/J5v/dpgxIRKSQFfgGkroWT7fII2qBERApJgR+A5ICvGxbhk886EksXZxv2kVrTxCkRKSgFfp5SJ1FlGn3TnxHDItx4xjiNyBGRglLg52mwi53V10VZMG1SASoSEUlPwzLzNJgXrVr3RkTCoMDPUzYvWiM1xohhEYzuln3y8ggiIsWiLp08uDvRobUD3td4/ngFvIiEToE/SAtXfciUmQsTxzVAV5r76uuiCnsRKQkK/Bx1dHZx0o9eZNX6TwHYe5fteO7KY3n6lfd6jdYB9dWLSGlR4Oegeflf+O4vlySO//O7RzJxry8BXyxVrNUtRaRUKfCz8Fl7J4f8YB5bPu9uvR+990786pLDsZRNw7W6pYiUMgX+AB59aQ3/+nhr4vjZK4/lK7vvEGJFIiKDo8DPYNOWdsbf/Hzi+NxD6rnrggnhFSQikicFfhr3vbCSxuYVieOeHahERMpZXoFvZo8CPcNQ6oCN7j4hzX2rgc1AJ9Dh7g35PLdQ/vrxZxx+238nji/96t8w7dT9QqxIRCQ4eQW+u1/Y89nM7gQ29XP7ce6+Pp/nFdKMOcv5jz+uThy/dO0JjNx+m/AKEhEJWCBdOtY9XOUCoOxWA1v1wSdMuvN3iePrTvsK/3Dsl0OsSESkMILqwz8W+Ku7/znDdQeeNzMHHnD3mZl+kJlNBaYCjBlTuD1e3Z3LH17KM61/SZxrnXES228bKdgzRUTCNGDgm9l8YLc0l6519yfjny8CHunnxxzj7jEz2wWYZ2ZvuPuL6W6M/2UwE6ChocEHqm8wWtdu4ox7/5A4vuuC8Zx7yB6FeJSISMkYMPDd/YT+rpvZEOBc4NB+fkYs/vv7ZvYEMBFIG/iF1NXlnP/An1jyzgYAdho+lD9On8Q2QwZeAE1EpNwF0aVzAvCGu69Nd9HMhgM17r45/vkk4OYAnpuTBSvX8/UHFyWOf37xYRy33y7FLkNEJDRBBP4UUrpzzGwU8KC7TwZ2BZ6IL0MwBHjY3Z8L4LlZae/s4muNvyUW36hk/9134KnvHUNtjQ3wJ0VEKkvege/uF6c5tw6YHP+8Chif73Oylbyh+IhhQ/loy+eJa49fdhSH7jmiWKWIiJSUipppm7qheE/Y77/7Dsz9p2P6LHYmIlJNKmqLw0wbim9qa1fYi0jVq6jAz7Sh+GA2GhcRqTQVFfiZNhTPZqNxEZFKV1GBf83J+xKN9B5Tr20GRUS6VdRLW20zKCKSWUUFPmibQRGRTCqqS0dERDJT4IuIVAkFvohIlVDgi4hUCQW+iEiVMPeC7DESCDP7AHgn7DpS7AyU7N68edJ3K0/6buWpUN9tT3cfme5CSQd+KTKzxe7eEHYdhaDvVp703cpTGN9NXToiIlVCgS8iUiUU+LmbGXYBBaTvVp703cpT0b+b+vBFRKqEWvgiIlVCgS8iUiUU+INgZo1m9oaZvWJmT5hZXdg1BcXMzjez5WbWZWYVMRzOzE4xsxVmttLMpoVdT1DM7Gdm9r6ZvRp2LUEzs9Fm9oKZvRb/9/HKsGsKiplta2b/Y2Yvx7/bTcV6tgJ/cOYBB7j7QcCbwPSQ6wnSq8C5wIthFxIEM6sF7gNOBfYHLjKz/cOtKjD/AZwSdhEF0gFc7e77A0cAl1fQP7etwCR3Hw9MAE4xsyOK8WAF/iC4+/Pu3hE/XAjsEWY9QXL31919Rdh1BGgisNLdV7n758CvgbNCrikQ7v4i8FHYdRSCu7/n7kvjnzcDr
wMVsdGFd/skfhiJ/yrK6BkFfv6+AzwbdhGSUT3wbtLxWiokOKqFmY0FDgYWhVxKYMys1syWAe8D89y9KN+t4na8CoqZzQd2S3PpWnd/Mn7PtXT/p+esYtaWr2y+m0gpMLPtgMeBq9z947DrCYq7dwIT4u//njCzA9y94O9iFPgZuPsJ/V03s4uB04HjvcwmMwz03SpMDBiddLxH/JyUODOL0B32s9x9dtj1FIK7bzSzF+h+F1PwwFeXziCY2SnA94Ez3X1L2PVIv14C9jGzvcxsKDAFmBNyTTIAMzPgIeB1d78r7HqCZGYje0b2mVkUOBF4oxjPVuAPzr3A9sA8M1tmZveHXVBQzOwcM1sLHAnMNbPmsGvKR/zl+hVAM90v/v7T3ZeHW1UwzOwR4E/Avma21swuCbumAB0NfBOYFP//2DIzmxx2UQHZHXjBzF6hu0Eyz92fLsaDtbSCiEiVUAtfRKRKKPBFRKqEAl9EpEoo8EVEqoQCX0SkSijwRUSqhAJfRKRK/H/OJmaVFOq8CgAAAABJRU5ErkJggg==\n", 684 | "text/plain": [ 685 | "
" 686 | ] 687 | }, 688 | "metadata": { 689 | "needs_background": "light" 690 | }, 691 | "output_type": "display_data" 692 | } 693 | ], 694 | "source": [ 695 | "theta = jnp.array([1., 1.])\n", 696 | "\n", 697 | "for _ in range(1000):\n", 698 | " theta = update(theta, xs, ys)\n", 699 | " \n", 700 | "plt.scatter(xs, ys)\n", 701 | "plt.plot(xs, model(theta, xs))\n", 702 | "\n", 703 | "w, b = theta\n", 704 | "print(f\"w: {w:<.2f}, b: {b:<.2f}\")" 705 | ] 706 | }, 707 | { 708 | "cell_type": "markdown", 709 | "id": "accomplished-gather", 710 | "metadata": {}, 711 | "source": [ 712 | "正如您阅读时看到这样,这份代码基本就是将在JAX中实现所有训练循环的一个基础。该示例与实际训练循环之间的主要区别在于我们的模型较为简单:这里我们只用单个数组来容纳所有参数。我们将在后面的 [pytree指南]()中介绍如何管理更多参数。您可以直接跳过,在那儿了解如何在JAX中手动定义和训练简单的MLP。" 713 | ] 714 | } 715 | ], 716 | "metadata": { 717 | "kernelspec": { 718 | "display_name": "Python 3", 719 | "language": "python", 720 | "name": "python3" 721 | }, 722 | "language_info": { 723 | "codemirror_mode": { 724 | "name": "ipython", 725 | "version": 3 726 | }, 727 | "file_extension": ".py", 728 | "mimetype": "text/x-python", 729 | "name": "python", 730 | "nbconvert_exporter": "python", 731 | "pygments_lexer": "ipython3", 732 | "version": "3.9.2" 733 | } 734 | }, 735 | "nbformat": 4, 736 | "nbformat_minor": 5 737 | } 738 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/1.4.2-JAX的即时编译.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "technical-processing", 6 | "metadata": {}, 7 | "source": [ 8 | "# JAX的即时编译\n", 9 | "\n", 10 | "> 作者:Rosalia Schneider & Vladimir Mikulik\n", 11 | "\n", 12 | "在本节中,我们将进一步讨论JAX的工作原理,以及如何使其具有高性能。我们将讨论`jax.jit()`变换,该变换将执行JAX Python函数的即时编译(JIT),以便可以在XLA中有效地执行该转换。\n", 13 | "\n", 14 | "## 如何使用JAX变换\n", 15 | "\n", 16 | "在上一节中,我们讨论了JAX允许我们变换Python函数。这是通过首先将Python函数转换为一种简单的中间语言jaxpr来完成的。之后,转换将在jaxpr形式上进行。\n", 17 | "\n", 18 | "我们可以用 
`jax.make_jaxpr` to show the jaxpr representation of a function:" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "corrected-onion", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "{ lambda ; a.\n", 32 | "  let b = log a\n", 33 | "      c = log 2.0\n", 34 | "      d = div b c\n", 35 | "  in (d,) }\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "import jax\n", 41 | "import jax.numpy as jnp\n", 42 | "\n", 43 | "global_list = []\n", 44 | "\n", 45 | "def log2(x):\n", 46 | "    global_list.append(x)\n", 47 | "    ln_x = jnp.log(x)\n", 48 | "    ln_2 = jnp.log(2.0)\n", 49 | " \n", 50 | "    return ln_x / ln_2\n", 51 | "\n", 52 | "print(jax.make_jaxpr(log2)(3.0))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "southeast-spray", 58 | "metadata": {}, 59 | "source": [ 60 | "The [Understanding Jaxprs]() section of the documentation explains the meaning of the output above in more detail.\n", 61 | "\n", 62 | "Importantly, note that the jaxpr does not capture the side effect of the function: nothing in it corresponds to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free code. If the terms *pure function* and *side effect* are unfamiliar, see [JAX - The Sharp Bits: 🔪 Pure functions](https://render.githubusercontent.com/view/ipynb?color_mode=light&commit=fe4a5f85bf7936468ed39f20cced5b25a1612efb&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f726173696e2d7473756b7562612f4a41585f6368696e6573655f7475746f7269616c2f666534613566383562663739333634363865643339663230636365643562323561313631326566622f6f6666696369616c2d7475746f7269616c732f47657474696e67537461727465642f312e332d4a41582545392539342538422545382538412539322545362541462539352545392539432542322e6970796e62&nwo=rasin-tsukuba%2FJAX_chinese_tutorial&path=official-tutorials%2FGettingStarted%2F1.3-JAX%E9%94%8B%E8%8A%92%E6%AF%95%E9%9C%B2.ipynb&repository_id=349726397&repository_type=Repository#%F0%9F%94%AA%E7%BA%AF%E5%87%BD%E6%95%B0).\n", 63 | "\n", 64 | "Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once they are converted to jaxpr. As a rule of thumb, you can expect (but should not rely on) the side effects of a JAX-transformed function to run once, during the first call, and never again. This follows from the way JAX generates a jaxpr, using a process called *tracing*.\n", 65 | "\n", 66 | 
"When tracing, JAX wraps each argument in a tracer object. These tracers record every JAX operation performed on them during the function call (which happens in regular Python). JAX then uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record Python side effects, those do not appear in the jaxpr. The side effects do still happen during the trace itself, however.\n", 67 | "\n", 68 | "Note: Python's `print()` is not a pure function: the text output is a side effect of the function. Any `print()` call will therefore only happen during tracing and will not appear in the jaxpr:" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 2, 74 | "id": "geological-nicaragua", 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "printed x: Tracedwith\n", 82 | "{ lambda ; a.\n", 83 | "  let b = log a\n", 84 | "      c = log 2.0\n", 85 | "      d = convert_element_type[ new_dtype=float32\n", 86 | "                                weak_type=False ] b\n", 87 | "      e = div d c\n", 88 | "  in (e,) }\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "def log2_with_print(x):\n", 94 | "    print(\"printed x: \", x)\n", 95 | "    ln_x = jnp.log(x)\n", 96 | "    ln_2 = jnp.log(2)\n", 97 | "    return ln_x / ln_2\n", 98 | "\n", 99 | "print(jax.make_jaxpr(log2_with_print)(3.))" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "worse-london", 105 | "metadata": {}, 106 | "source": [ 107 | "See how the printed `x` is a `Traced` object? That is the JAX internals at work.\n", 108 | "\n", 109 | "The fact that the Python code runs at least once is strictly an implementation detail, so it should not be relied upon. It is still useful to understand, because you can use it when debugging to print out intermediate values of a computation.\n", 110 | "\n", 111 | "A key thing to understand is that a jaxpr captures the function as executed on the arguments given to it. For example, if the function contains a conditional, the jaxpr will only know about the branch that was taken:" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 3, 117 | "id": "parallel-physiology", 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "{ lambda ; a.\n", 125 | "  let \n", 126 | "  in (a,) }\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "def log2_if_rank_2(x):\n", 132 | "    if x.ndim == 2:\n", 133 | "        ln_x = jnp.log(x)\n", 134 | "        ln_2 = jnp.log(2)\n", 135 | "        return ln_x / ln_2\n", 136 | "    else:\n", 137 | "        return x\n", 138 | " \n", 139 | "print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1., 2., 3.])))" 140 | ] 141 | }, 142 | { 143 | "cell_type": 
"markdown", 144 | "id": "indirect-witness", 145 | "metadata": {}, 146 | "source": [ 147 | "## JIT compiling a function\n", 148 | "\n", 149 | "As explained before, JAX lets operations run on CPU, GPU and TPU using the same code. Let's look at an example that computes a Scaled Exponential Linear Unit (SELU), an operation commonly used in deep learning:" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 4, 155 | "id": "swiss-basis", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "852 µs ± 26 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "import jax\n", 168 | "import jax.numpy as jnp\n", 169 | "\n", 170 | "def selu(x, alpha=1.67, lambda_=1.05):\n", 171 | "    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", 172 | "\n", 173 | "x = jnp.arange(1000000)\n", 174 | "%timeit selu(x).block_until_ready()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "human-neighbor", 180 | "metadata": {}, 181 | "source": [ 182 | "The code above sends one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our function.\n", 183 | "\n", 184 | "Naturally, what we want is to give the XLA compiler as much code as possible, so that it can fully optimize it. For this purpose, JAX provides the `jax.jit` transformation, which JIT-compiles a JAX-compatible function. The example below shows how to use JIT to speed up the function above:" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 5, 190 | "id": "written-spectrum", 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "55.5 µs ± 3.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "selu_jit = jax.jit(selu)\n", 203 | "\n", 204 | "# warm up\n", 205 | "# (this first call triggers tracing and compilation)\n", 206 | "selu_jit(x).block_until_ready()\n", 207 | "\n", 208 | "%timeit selu_jit(x).block_until_ready()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "id": "textile-editor", 214 | "metadata": {}, 215 | "source": [ 216 | "Here is what just happened:\n", 217 | "\n", 218 | "1. We defined `selu_jit` as the compiled version of `selu`.\n", 219 | "2. 
We ran `selu_jit` once on `x`. This is where JAX does its tracing: it needs some inputs to wrap in tracers, after all. The jaxpr is then compiled with XLA into very efficient code optimized for your GPU or TPU. Subsequent calls to `selu_jit` now use that compiled code, skipping our previous Python implementation entirely.\n", 220 | "\n", 221 | "(If we did not make the warm-up call separately, everything would still work, but the compilation time would then be included in the benchmark. It would still come out faster, because the benchmark runs many loops, but it would not be a fair comparison.)\n", 222 | "\n", 223 | "3. We timed the execution speed of the compiled version. (Note the use of `block_until_ready()`, which is required because of JAX's asynchronous execution model.)\n", 224 | "\n", 225 | "## Why can't we just JIT everything?\n", 226 | "\n", 227 | "After going through the example above, you might wonder whether we should simply apply `jax.jit` to every function. To understand why that is not the case, and when we should or should not apply `jit`, let's first look at some cases where JIT does not work:" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 6, 233 | "id": "incomplete-exchange", 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "ename": "ConcretizationTypeError", 238 | "evalue": "Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. \nWhile tracing the function f at :4, this concrete value was not available in Python because it depends on the value of the arguments to f at :4 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).\n (https://jax.readthedocs.io/en/latest/errors.html#jax._src.errors.ConcretizationTypeError)", 239 | "output_type": "error", 240 | "traceback": [ 241 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 242 | "\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)", 243 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mf_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mf_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 应该会报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 244 | "\u001b[0;32m\u001b[0m 
in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 245 | "\u001b[0;31mFilteredStackTrace\u001b[0m: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. \nWhile tracing the function f at :4, this concrete value was not available in Python because it depends on the value of the arguments to f at :4 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).\n (https://jax.readthedocs.io/en/latest/errors.html#jax._src.errors.ConcretizationTypeError)\n\nThe stack trace above excludes JAX-internal frames.\nThe following is the original exception that occurred, unmodified.\n\n--------------------", 246 | "\nThe above exception was the direct cause of the following exception:\n", 247 | "\u001b[0;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)", 248 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mf_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mf_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 
应该会报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 249 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 250 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/api.py\u001b[0m in \u001b[0;36mf_jitted\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcpp_jitted_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 416\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcpp_jitted_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 417\u001b[0m \u001b[0mf_jitted\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cpp_jitted_f\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcpp_jitted_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 251 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(_, *args, **kwargs)\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[0m_check_arg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 297\u001b[0;31m out_flat = xla.xla_call(\n\u001b[0m\u001b[1;32m 298\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 299\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs_flat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 252 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1392\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1393\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1394\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1396\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 253 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1383\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1384\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mmaybe_new_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1385\u001b[0;31m \u001b[0mouts\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1386\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1387\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 254 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1396\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1397\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1398\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1399\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mpost_process\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 255 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 625\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 626\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 256 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(fun, device, backend, name, donated_invars, *args)\u001b[0m\n\u001b[1;32m 584\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 585\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_xla_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mlu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mWrappedFun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdonated_invars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 586\u001b[0;31m compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0m\u001b[1;32m 587\u001b[0m *unsafe_map(arg_spec, args))\n\u001b[1;32m 588\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 257 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 258\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpopulate_stores\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 260\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 261\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 258 | 
"\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0mabstract_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_devices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munzip2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg_specs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 661\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0momnistaging_enabled\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 662\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_final\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mabstract_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 663\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 664\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnexpectedTracerError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Encountered an unexpected tracer.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 259 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py\u001b[0m 
in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals)\u001b[0m\n\u001b[1;32m 1218\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msource_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun_sourceinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1219\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjaxpr_stack\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1220\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1221\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1222\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 260 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals)\u001b[0m\n\u001b[1;32m 1198\u001b[0m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDynamicJaxprTrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcur_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1199\u001b[0m \u001b[0min_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_arg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1200\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1201\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1202\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mframe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_jaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 261 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;31m# Some transformations yield from inside context managers, so we have to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 262 | "\u001b[0;32m\u001b[0m in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 263 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 552\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 553\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 554\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_bool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__int__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 556\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__long__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_long\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 264 
| "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 949\u001b[0m f\"or `jnp.array(x, {fun.__name__})` instead.\")\n\u001b[1;32m 950\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 951\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mConcretizationTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfname_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 952\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 953\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 265 | "\u001b[0;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. 
\nWhile tracing the function f at :4, this concrete value was not available in Python because it depends on the value of the arguments to f at :4 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).\n (https://jax.readthedocs.io/en/latest/errors.html#jax._src.errors.ConcretizationTypeError)" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "# Condition on value of x\n", 271 | "# (a Python `if` on a traced argument)\n", 272 | "\n", 273 | "def f(x):\n", 274 | "    if x > 0:\n", 275 | "        return x\n", 276 | "    else:\n", 277 | "        return 2 * x\n", 278 | " \n", 279 | "f_jit = jax.jit(f)\n", 280 | "f_jit(10)  # should raise an error" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 7, 286 | "id": "polish-dress", 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "ename": "ConcretizationTypeError", 291 | "evalue": "Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. 
\nWhile tracing the function g at :4, this concrete value was not available in Python because it depends on the value of the arguments to g at :4 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).\n (https://jax.readthedocs.io/en/latest/errors.html#jax._src.errors.ConcretizationTypeError)", 292 | "output_type": "error", 293 | "traceback": [ 294 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 295 | "\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)", 296 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mg_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mg_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 应该会报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 297 | "\u001b[0;32m\u001b[0m in \u001b[0;36mg\u001b[0;34m(x, n)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 298 | "\u001b[0;31mFilteredStackTrace\u001b[0m: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. 
\nWhile tracing the function g at :4, this concrete value was not available in Python because it depends on the value of the arguments to g at :4 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).\n (https://jax.readthedocs.io/en/latest/errors.html#jax._src.errors.ConcretizationTypeError)\n\nThe stack trace above excludes JAX-internal frames.\nThe following is the original exception that occurred, unmodified.\n\n--------------------", 299 | "\nThe above exception was the direct cause of the following exception:\n", 300 | "\u001b[0;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)", 301 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mg_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mg_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 应该会报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 302 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 
139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 303 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/api.py\u001b[0m in \u001b[0;36mf_jitted\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcpp_jitted_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 416\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcpp_jitted_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 417\u001b[0m \u001b[0mf_jitted\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cpp_jitted_f\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcpp_jitted_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 304 | 
"\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(_, *args, **kwargs)\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[0m_check_arg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 297\u001b[0;31m out_flat = xla.xla_call(\n\u001b[0m\u001b[1;32m 298\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 299\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs_flat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 305 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1392\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1393\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1394\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1396\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 306 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1383\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1384\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mmaybe_new_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1385\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1386\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1387\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 307 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1396\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1397\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1398\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1399\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_process\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 308 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 623\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 625\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 626\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 309 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(fun, device, backend, name, donated_invars, *args)\u001b[0m\n\u001b[1;32m 584\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 585\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_xla_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mWrappedFun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdonated_invars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 586\u001b[0;31m compiled_fun = _xla_callable(fun, device, backend, name, 
donated_invars,\n\u001b[0m\u001b[1;32m 587\u001b[0m *unsafe_map(arg_spec, args))\n\u001b[1;32m 588\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 310 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 258\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpopulate_stores\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 260\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 261\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 311 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0mabstract_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_devices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munzip2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg_specs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 661\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0momnistaging_enabled\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 662\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_final\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mabstract_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 663\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 664\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnexpectedTracerError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Encountered an unexpected tracer.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 312 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals)\u001b[0m\n\u001b[1;32m 1218\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msource_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun_sourceinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1219\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjaxpr_stack\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1220\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1221\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1222\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 313 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals)\u001b[0m\n\u001b[1;32m 1198\u001b[0m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDynamicJaxprTrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcur_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1199\u001b[0m \u001b[0min_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_arg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1200\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1201\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1202\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mframe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_jaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 314 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
168\u001b[0m \u001b[0;31m# Some transformations yield from inside context managers, so we have to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 315 | "\u001b[0;32m\u001b[0m in \u001b[0;36mg\u001b[0;34m(x, n)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 316 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 552\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 553\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 554\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_bool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__int__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 556\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__long__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_long\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 317 | "\u001b[0;32m~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 949\u001b[0m f\"or `jnp.array(x, {fun.__name__})` instead.\")\n\u001b[1;32m 950\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 951\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mConcretizationTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfname_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 952\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 953\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 318 | "\u001b[0;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. \nWhile tracing the function g at :4, this concrete value was not available in Python because it depends on the value of the arguments to g at :4 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).\n (https://jax.readthedocs.io/en/latest/errors.html#jax._src.errors.ConcretizationTypeError)" 319 | ] 320 | } 321 | ], 322 | "source": [ 323 | "# while loop conditioned on x and n\n", 324 | "# (the loop bound n is a traced argument)\n", 325 | "\n", 326 | "def g(x, n):\n", 327 | "    i = 0\n", 328 | "    while i < n:\n", 329 | "        i += 1\n", 330 | "    return x + i\n", 331 | "\n", 332 | "g_jit = jax.jit(g)\n", 333 | "g_jit(10, 20)  # should raise an error" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "id": "stretch-fundamental", 339 | "metadata": {}, 340 | "source": [ 341 | "The problem is that we tried to condition on the value of an input to the function being jitted. We cannot do this for the reason mentioned above: a jaxpr depends on the actual values used to trace it.\n", 342 | "\n", 343 | "The more specific the information about the values used in a trace, the more we can rely on standard Python control flow to express ourselves. However, being too specific means the same traced function cannot be reused for other values. JAX resolves this by tracing at different levels of abstraction for different purposes.\n", 344 | "\n", 345 | "For `jax.jit`, the default level is `ShapedArray`: each tracer has a concrete shape (which we are allowed to condition on) but no concrete value. This lets the compiled function work on all possible inputs with the same shape, which is the standard use case in machine learning. However, because the tracers have no concrete value, attempting to condition on one produces the error above.\n", 346 | "\n", 347 | 
"In `jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the strictest one. So, if you use `jit(grad(f))`, `f` must not condition on the value of its inputs. For more detail on the interaction between Python control flow and JAX, see [JAX锋芒毕露 (JAX: The Sharp Bits): 🔪 Control Flow](https://render.githubusercontent.com/view/ipynb?color_mode=light&commit=fe4a5f85bf7936468ed39f20cced5b25a1612efb&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f726173696e2d7473756b7562612f4a41585f6368696e6573655f7475746f7269616c2f666534613566383562663739333634363865643339663230636365643562323561313631326566622f6f6666696369616c2d7475746f7269616c732f47657474696e67537461727465642f312e332d4a41582545392539342538422545382538412539322545362541462539352545392539432542322e6970796e62&nwo=rasin-tsukuba%2FJAX_chinese_tutorial&path=official-tutorials%2FGettingStarted%2F1.3-JAX%E9%94%8B%E8%8A%92%E6%AF%95%E9%9C%B2.ipynb&repository_id=349726397&repository_type=Repository#%F0%9F%94%AA-%E6%8E%A7%E5%88%B6%E6%B5%81).\n", 348 | "\n", 349 | "One way to deal with this problem is to rewrite the code to avoid conditionals on input values. Another is to use special control flow operators such as `jax.lax.cond`. Sometimes, however, neither is possible. In that case, consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (be sure to read the section on caching below to avoid pitfalls):" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 8, 355 | "id": "modern-blood", 356 | "metadata": {}, 357 | "outputs": [ 358 | { 359 | "data": { 360 | "text/plain": [ 361 | "DeviceArray(30, dtype=int32)" 362 | ] 363 | }, 364 | "execution_count": 8, 365 | "metadata": {}, 366 | "output_type": "execute_result" 367 | } 368 | ], 369 | "source": [ 370 | "# while loop conditioned on x and n with a jitted body\n", 371 | "# (only loop_body is compiled; the Python loop itself runs eagerly)\n", 372 | "@jax.jit\n", 373 | "def loop_body(prev_i):\n", 374 | "    return prev_i + 1\n", 375 | "\n", 376 | "def g_inner_jitted(x, n):\n", 377 | "    i = 0\n", 378 | "    while i < n:\n", 379 | "        i = loop_body(i)\n", 380 | "    return x + i\n", 381 | "\n", 382 | "g_inner_jitted(10, 20)" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "id": "stunning-switch", 388 | "metadata": {}, 389 | "source": [ 390 | 
"If we really need to JIT a function that conditions on the value of an input, we can specify `static_argnums` to tell JAX to use a less abstract tracer for that particular input. The cost is that the resulting jaxpr is less flexible, so JAX has to recompile the function for every new value of the specified input. This is a good strategy only when the function is guaranteed to receive a limited number of distinct values." 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 9, 396 | "id": "spiritual-bennett", 397 | "metadata": {}, 398 | "outputs": [ 399 | { 400 | "name": "stdout", 401 | "output_type": "stream", 402 | "text": [ 403 | "10\n" 404 | ] 405 | } 406 | ], 407 | "source": [ 408 | "f_jit_correct = jax.jit(f, static_argnums=0)\n", 409 | "print(f_jit_correct(10))" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 10, 415 | "id": "equipped-complexity", 416 | "metadata": {}, 417 | "outputs": [ 418 | { 419 | "name": "stdout", 420 | "output_type": "stream", 421 | "text": [ 422 | "30\n" 423 | ] 424 | } 425 | ], 426 | "source": [ 427 | "g_jit_correct = jax.jit(g, static_argnums=1)\n", 428 | "print(g_jit_correct(10, 20))" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "id": "empirical-clear", 434 | "metadata": {}, 435 | "source": [ 436 | "## When to use JIT\n", 437 | "\n", 438 | "In many of the examples above, jitting is not worthwhile:" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 11, 444 | "id": "impossible-manor", 445 | "metadata": {}, 446 | "outputs": [ 447 | { 448 | "name": "stdout", 449 | "output_type": "stream", 450 | "text": [ 451 | "g jitted:\n", 452 | "51.8 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 453 | "g:\n", 454 | "591 ns ± 4.29 ns per loop (mean ± std. dev. 
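of 7 runs, 1000000 loops each)\n" 455 | ] 456 | } 457 | ], 458 | "source": [ 459 | "print(\"g jitted:\")\n", 460 | "%timeit g_jit_correct(10, 20).block_until_ready()\n", 461 | "\n", 462 | "print(\"g:\")\n", 463 | "%timeit g(10, 20)" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "id": "white-worst", 469 | "metadata": {}, 470 | "source": [ 471 | "This is because `jax.jit` itself introduces some overhead. It therefore usually saves time only when the compiled function is complex and is called many times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model and then run it for millions of iterations.\n", 472 | "\n", 473 | "In general, you want to jit the largest possible chunk of your computation; ideally, the entire `update` function. This gives the compiler the most freedom to optimize.\n", 474 | "\n", 475 | "## Caching\n", 476 | "\n", 477 | "It is important to understand the caching behavior of `jax.jit`.\n", 478 | "\n", 479 | "Suppose we define `f = jax.jit(g)`. When we first call `f`, it is compiled and the resulting XLA code is cached. Subsequent calls to `f` reuse the cached code. This is how `jax.jit` makes up for the up-front cost of compilation.\n", 480 | "\n", 481 | "If we specify `static_argnums`, the cached code is reused only for the same values of the arguments marked as static. If any of them changes, recompilation occurs. If there are many distinct values, your program may spend more time compiling than it would have spent executing the operations one by one.\n", 482 | "\n", 483 | "Avoid calling `jax.jit` inside loops. Doing so creates a new jitted wrapper at every call, which adds overhead and may trigger recompilation instead of reusing the same cached function:" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 12, 489 | "id": "meaningful-gibson", 490 | "metadata": {}, 491 | "outputs": [ 492 | { 493 | "name": "stdout", 494 | "output_type": "stream", 495 | "text": [ 496 | "jit called outside the loop:\n", 497 | "5.13 ms ± 123 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", 498 | "jit called inside the loop:\n", 499 | "7.46 ms ± 156 µs per loop (mean ± std. dev. 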
of 7 runs, 100 loops each)\n" 500 | ] 501 | } 502 | ], 503 | "source": [ 504 | "def unjitted_loop_body(prev_i):\n", 505 | "    return prev_i + 1\n", 506 | "\n", 507 | "def g_inner_jitted_poorly(x, n):\n", 508 | "    i = 0\n", 509 | "    while i < n:\n", 510 | "        # Don't do this\n", 511 | "        # (it wraps the function in jax.jit again on every iteration)\n", 512 | "        i = jax.jit(unjitted_loop_body)(i)\n", 513 | "    return x + i\n", 514 | "\n", 515 | "print(\"jit called outside the loop:\")\n", 516 | "%timeit g_inner_jitted(10, 20).block_until_ready()\n", 517 | "\n", 518 | "print(\"jit called inside the loop:\")\n", 519 | "%timeit g_inner_jitted_poorly(10, 20).block_until_ready() " 520 | ] 521 | } 522 | ], 523 | "metadata": { 524 | "kernelspec": { 525 | "display_name": "Python 3", 526 | "language": "python", 527 | "name": "python3" 528 | }, 529 | "language_info": { 530 | "codemirror_mode": { 531 | "name": "ipython", 532 | "version": 3 533 | }, 534 | "file_extension": ".py", 535 | "mimetype": "text/x-python", 536 | "name": "python", 537 | "nbconvert_exporter": "python", 538 | "pygments_lexer": "ipython3", 539 | "version": "3.9.2" 540 | } 541 | }, 542 | "nbformat": 4, 543 | "nbformat_minor": 5 544 | } 545 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/1.4.3-JAX的自动向量化.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "extra-membrane", 6 | "metadata": {}, 7 | "source": [ 8 | "# Automatic Vectorization in JAX\n", 9 | "\n", 10 | "> Author: Matteo Hessel\n", 11 | "\n", 12 | "In the previous section we discussed JIT compilation via the `jax.jit` function. This section discusses another JAX transformation: vectorization via `jax.vmap`.\n", 13 | "\n", 14 | "## Manual vectorization\n", 15 | "\n", 16 | "Consider the following simple code, which computes the convolution of two one-dimensional vectors:" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "fewer-carry", 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "text/plain": [ 28 | "DeviceArray([11., 20., 29.], dtype=float32)" 29 | ] 30 | }, 31 |
"execution_count": 1, 32 | "metadata": {}, 33 | "output_type": "execute_result" 34 | } 35 | ], 36 | "source": [ 37 | "import jax\n", 38 | "import jax.numpy as jnp\n", 39 | "\n", 40 | "x = jnp.arange(5)\n", 41 | "w = jnp.array([2., 3., 4.])\n", 42 | "\n", 43 | "def convolve(x, w):\n", 44 | " output = []\n", 45 | " for i in range(1, len(x)-1):\n", 46 | " output.append(jnp.dot(x[i-1:i+2], w))\n", 47 | " return jnp.array(output)\n", 48 | "\n", 49 | "convolve(x, w)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "elder-brush", 55 | "metadata": {}, 56 | "source": [ 57 | "假设我们想要将此函数应用于一批权重`w`和一批向量`x`上:" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "id": "liberal-destruction", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "xs = jnp.stack([x, x])\n", 68 | "ws = jnp.stack([w, w])" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "invisible-discovery", 74 | "metadata": {}, 75 | "source": [ 76 | "最简单的想法就是在Python中循环遍历该批处理:" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "id": "marked-pierre", 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "data": { 87 | "text/plain": [ 88 | "DeviceArray([[11., 20., 29.],\n", 89 | " [11., 20., 29.]], dtype=float32)" 90 | ] 91 | }, 92 | "execution_count": 3, 93 | "metadata": {}, 94 | "output_type": "execute_result" 95 | } 96 | ], 97 | "source": [ 98 | "def manually_batched_convolve(xs, ws):\n", 99 | " output = []\n", 100 | " for i in range(xs.shape[0]):\n", 101 | " output.append(convolve(xs[i], ws[i]))\n", 102 | " return jnp.stack(output)\n", 103 | "\n", 104 | "manually_batched_convolve(xs, ws)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "occupied-atlas", 110 | "metadata": {}, 111 | "source": [ 112 | "这虽然会产生正确的结果,但是效率不高。\n", 113 | "\n", 114 | "为了有效地批处理计算,通常必须手动重写函数以确保它以向量的形式完成。这并不是特别难实现,但确实设计更改函数如何处理索引,维度和输入的其他部分。\n", 115 | "\n", 116 | "例如,我们可以手动重写`convolve()`来支持跨批处理维度的向量化计算,如下所示:" 117 | ] 118 | }, 119 | 
{ 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "id": "graphic-polymer", 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "data": { 127 | "text/plain": [ 128 | "DeviceArray([[11., 20., 29.],\n", 129 | " [11., 20., 29.]], dtype=float32)" 130 | ] 131 | }, 132 | "execution_count": 4, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "def manually_vectorized_convolve(xs, ws):\n", 139 | " output = []\n", 140 | " for i in range(1, xs.shape[-1]-1):\n", 141 | " output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))\n", 142 | " return jnp.stack(output, axis=1)\n", 143 | "\n", 144 | "manually_vectorized_convolve(xs, ws)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "elect-damages", 150 | "metadata": {}, 151 | "source": [ 152 | "这种重新实现是混乱且容易出错的。 幸运的是,JAX提供了另一种方法。\n", 153 | "\n", 154 | "## 自动向量化\n", 155 | "\n", 156 | "在JAX中,`jax.vmap`转换旨在自动生成函数的向量化实现:" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 5, 162 | "id": "major-animation", 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "DeviceArray([[11., 20., 29.],\n", 169 | " [11., 20., 29.]], dtype=float32)" 170 | ] 171 | }, 172 | "execution_count": 5, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "auto_batch_convolve = jax.vmap(convolve)\n", 179 | "auto_batch_convolve(xs, ws)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "id": "studied-afternoon", 185 | "metadata": {}, 186 | "source": [ 187 | "它通过类似于`jax.jit`的函数追踪功能,并在每个输入的开头自动添加批处理来实现此目的。\n", 188 | "\n", 189 | "如果批次维度不是第一个,则可以使用`in_axes`和`out_axes`参数指定批次维度在输入和输出中的位置。如果批处理轴对于所有输入和输出(或列表)相同,则这些值可以是整数。" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 6, 195 | "id": "empty-metro", 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "data": { 200 | "text/plain": [ 201 | "DeviceArray([[11., 11.],\n", 202 | " [20., 
20.],\n", 203 | " [29., 29.]], dtype=float32)" 204 | ] 205 | }, 206 | "execution_count": 6, 207 | "metadata": {}, 208 | "output_type": "execute_result" 209 | } 210 | ], 211 | "source": [ 212 | "auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)\n", 213 | "\n", 214 | "xst = jnp.transpose(xs)\n", 215 | "wst = jnp.transpose(ws)\n", 216 | "\n", 217 | "auto_batch_convolve_v2(xst, wst)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "id": "beautiful-february", 223 | "metadata": {}, 224 | "source": [ 225 | "`jax.vmap`还支持仅对其中一个参数进行批处理的情况:例如,如果您想将一组权重`w`与一组向量`x`进行卷积,则在这种情况下,可以将`in_axes`参数设置为`None`:" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 7, 231 | "id": "precious-devil", 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "data": { 236 | "text/plain": [ 237 | "DeviceArray([[11., 20., 29.],\n", 238 | " [11., 20., 29.]], dtype=float32)" 239 | ] 240 | }, 241 | "execution_count": 7, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])\n", 248 | "\n", 249 | "batch_convolve_v3(xs, w)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "id": "boring-calgary", 255 | "metadata": {}, 256 | "source": [ 257 | "## 组合变换\n", 258 | "\n", 259 | "和所有JAX变换一样,`jax.jit`和`jax.vmap`被设计成可组合的,也就是说您可以使用`jit`包装被向量化的函数,或使用`vmap`包装一个被即时编译的函数,一切都将照常工作:" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 8, 265 | "id": "flexible-anger", 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "data": { 270 | "text/plain": [ 271 | "DeviceArray([[11., 20., 29.],\n", 272 | " [11., 20., 29.]], dtype=float32)" 273 | ] 274 | }, 275 | "execution_count": 8, 276 | "metadata": {}, 277 | "output_type": "execute_result" 278 | } 279 | ], 280 | "source": [ 281 | "jitted_batch_convolve = jax.jit(auto_batch_convolve)\n", 282 | "\n", 283 | "jitted_batch_convolve(xs, ws)" 284 | ] 285 | } 286 | ], 287 | 
"metadata": { 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | "name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.9.2" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 5 308 | } 309 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/1.4.4-JAX中的高级自动微分.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "available-province", 6 | "metadata": {}, 7 | "source": [ 8 | "# JAX中的高级自动微分\n", 9 | "\n", 10 | "> 作者: Vlatimir Mikulik & Matteo Hessel\n", 11 | "\n", 12 | "计算梯度是现代机器学习方法的关键部分。本节涵盖了与现代机器学习相关的自动微分中的一些高级主题。\n", 13 | "\n", 14 | "尽管了解自动微分的工作原理对于大多数情况下使用JAX也不是至关重要,但我们鼓励读者观看这个[视频](https://www.bilibili.com/video/BV1YX4y1G7V3/)来获得更深入的认识。\n", 15 | "\n", 16 | "[`Autodiff`指导手册]()是对JAX后端如何实现这些功能更高级更详细的解释。在JAX中进行大多数操作不需要了解这一点。但是,某些功能(例如[自定义微分]())依赖于对此的理解,因此如果您需要使用这些解释,则值得一看。\n", 17 | "\n", 18 | "## 导入" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "sitting-assist", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import jax\n", 29 | "import jax.numpy as jnp" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "departmental-framing", 35 | "metadata": {}, 36 | "source": [ 37 | "## 高阶求导\n", 38 | "\n", 39 | "JAX的`autodiff`使得计算高阶导数变得容易,因为计算导数的函数本身是可微的。因此,高阶导数与叠加变换一样容易。\n", 40 | "\n", 41 | "我们在单变量情况下说明这一点:\n", 42 | "\n", 43 | "函数 $f(x)=x^3 + 2x^2 - 3x + 1$ 可以被计算为:" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "id": "international-denmark", 50 | "metadata": {}, 51 | "outputs": [], 52 | 
"source": [ 53 | "f = lambda x: x ** 3 + 2 * x ** 2 - 3 * x + 1\n", 54 | "dfdx = jax.grad(f)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "id": "small-times", 60 | "metadata": {}, 61 | "source": [ 62 | "对于$f$的高阶求导为:\n", 63 | "\n", 64 | "$$\n", 65 | "f'(x)=3x^2+4x-3\\\\\n", 66 | "f''(x)=6x+4\\\\\n", 67 | "f'''(x)=6\\\\\n", 68 | "f^{iv}=0\\\\\n", 69 | "$$\n", 70 | "\n", 71 | "在JAX中计算任何导数都像链接`grad`函数一样容易:" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "id": "fancy-belize", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "d2fdx = jax.grad(dfdx)\n", 82 | "d3fdx = jax.grad(d2fdx)\n", 83 | "d4fdx = jax.grad(d3fdx)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "median-birthday", 89 | "metadata": {}, 90 | "source": [ 91 | "当`x=1`时,我们将会得到:\n", 92 | "\n", 93 | "$$\n", 94 | "f'(1)=4\\\\\n", 95 | "f''(1)=10\\\\\n", 96 | "f'''(1)=6\\\\\n", 97 | "f^{iv}(1)=0\\\\\n", 98 | "$$\n", 99 | "\n", 100 | "使用JAX:" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "id": "reliable-perspective", 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "4.0\n", 114 | "10.0\n", 115 | "6.0\n", 116 | "0.0\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "print(dfdx(1.))\n", 122 | "print(d2fdx(1.))\n", 123 | "print(d3fdx(1.))\n", 124 | "print(d4fdx(1.))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "id": "organized-observation", 130 | "metadata": {}, 131 | "source": [ 132 | "在多变量情况下,高阶导数更为复杂。 函数的二阶导数由其[Hessian矩阵](https://baike.baidu.com/item/黑塞矩阵/2248782)表示,根据:\n", 133 | "\n", 134 | "$$\n", 135 | "(\\mathbb{H}f)_{i,j}=\\frac{\\partial^2 f}{\\partial_i \\partial_j}\n", 136 | "$$\n", 137 | "\n", 138 | "多变量实值函数的Hessian,$f:\\mathbb{R}^n \\rightarrow \\mathbb{R}$,可以用其梯度的雅克比矩阵来识别。JAX提供了两种变换来计算雅克比矩阵,即`jax.jacfwd`和 
`jax.jacrev`,分别对应于正向和反向模式的`autodiff`。虽然他们给出的答案是相同的,但是在不同情况下,某种方法的效率会更高。请参阅上面的[视频链接](https://www.bilibili.com/video/BV1YX4y1G7V3/)来获得更多解释。" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 5, 144 | "id": "internal-geology", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "def hessian(f):\n", 149 | " return jax.jacfwd(jax.grad(f))" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "id": "significant-omaha", 155 | "metadata": {}, 156 | "source": [ 157 | "我们再次检查点乘是否正确: $f:x \\rightarrow x^Tx$\n", 158 | "\n", 159 | "如果 $i=j$ , $\\frac{\\partial^2 f}{\\partial_i \\partial_j}(x)=2$ 。否则, $\\frac{\\partial^2 f}{\\partial_i \\partial_j}(x)=0$ 。" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 6, 165 | "id": "beginning-keeping", 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "DeviceArray([[2., 0., 0.],\n", 172 | " [0., 2., 0.],\n", 173 | " [0., 0., 2.]], dtype=float32)" 174 | ] 175 | }, 176 | "execution_count": 6, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "def f(x):\n", 183 | " return jnp.dot(x, x)\n", 184 | "\n", 185 | "hessian(f)(jnp.array([1., 2., 3.]))" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "id": "tribal-shift", 191 | "metadata": {}, 192 | "source": [ 193 | "但是,通常我们对完整的Hessian本身不感兴趣,因此计算效率可能非常低。 [Autodiff指导手册]()解释了一些技巧,例如Hessian-vector乘法,可以在不具体化整个矩阵的情况下使用它。\n", 194 | "\n", 195 | "如果您打算在JAX中使用高阶导数,我们强烈建议您阅读[Autodiff指导手册]()。\n", 196 | "\n", 197 | "## 高阶优化\n", 198 | "\n", 199 | "一些元学习技术,例如与模型无关的元学习(MAML),需要通过梯度更新来区分。 在其他框架中,这可能非常麻烦,但是在JAX中,它要容易得多:" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 7, 205 | "id": "civic-patrick", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "def meta_loss_fn(params, data):\n", 210 | " \"\"\"\n", 211 | " Computes the loss after one step of SGD.\n", 212 | " 在SGD一步后计算损失\n", 213 | " 
\"\"\"\n", 214 | " grads = jax.grad(loss_fn)(params, data)\n", 215 | " return loss_fn(params - lr * grads, data)\n", 216 | "\n", 217 | "#meta_grads = jax.grad(meta_loss_fn)(params, data)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "id": "modern-means", 223 | "metadata": {}, 224 | "source": [ 225 | "## 梯度停止\n", 226 | "\n", 227 | "`Autodiff`使函数能够自动计算相对于输入的梯度。但有时我们可能需要一些其他控制,我们可能希望避免通过计算图的某些子集向后传播梯度。\n", 228 | "\n", 229 | "例如`TD(0)`(时差)强化学习更新。这用于从环境交互的经验中学习估计环境中状态的值。假设由状态 $s_{t-1}$ 中的值来估计 $v_\\theta(s_{t-1})$ 由线性函数设定参数。" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 8, 235 | "id": "frequent-breast", 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "# Value function and initial parameters\n", 240 | "# 值函数和参数初始化\n", 241 | "value_fn = lambda theta, state: jnp.dot(theta, state)\n", 242 | "theta = jnp.array([0.1, -0.1, 0.])" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "id": "fossil-russia", 248 | "metadata": {}, 249 | "source": [ 250 | "考虑从状态 $s_{t-1}$ 到 状态 $s_t$,在此过程中我们观察到奖励 $r_t$" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 9, 256 | "id": "proprietary-omega", 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "s_tm1 = jnp.array([1., 2., -1.])\n", 261 | "r_t = jnp.array(1.)\n", 262 | "s_t = jnp.array([2., 1., 0.])" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "id": "iraqi-indiana", 268 | "metadata": {}, 269 | "source": [ 270 | "网络参数的TD(0)更新为: $$\\triangle \\theta = (r_t + v_\\theta (s_t) - v_\\theta (s_{t-1})) \\triangledown v_\\theta (s_{t-1})$$\n", 271 | "\n", 272 | "此更新不是任何损失函数的梯度。但是它可以写成伪损失函数的梯度:\n", 273 | "\n", 274 | "$$\n", 275 | "L(\\theta) = [r_t + v_\\theta (s_t) - v_\\theta(s_{t-1})]^2\n", 276 | "$$\n", 277 | "\n", 278 | "如果忽略目标 $r_t + v_\\theta(s_t)$对于参数 $\\theta$的依赖性。\n", 279 | "\n", 280 | "我们如何在JAX中实现呢?如果我们简单地写出伪损失函数,我们将得到:" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 10, 286 | 
"id": "classified-procedure", 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "data": { 291 | "text/plain": [ 292 | "DeviceArray([ 2.4, -2.4, 2.4], dtype=float32)" 293 | ] 294 | }, 295 | "execution_count": 10, 296 | "metadata": {}, 297 | "output_type": "execute_result" 298 | } 299 | ], 300 | "source": [ 301 | "def td_loss(theta, s_tm1, r_t, s_t):\n", 302 | " v_tm1 = value_fn(theta, s_tm1)\n", 303 | " target = r_t + value_fn(theta, s_t)\n", 304 | " return (target - v_tm1) ** 2\n", 305 | "\n", 306 | "td_update = jax.grad(td_loss)\n", 307 | "delta_theta = td_update(theta, s_tm1, r_t, s_t)\n", 308 | "\n", 309 | "delta_theta" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "id": "ordinary-celebrity", 315 | "metadata": {}, 316 | "source": [ 317 | "但`td_update`不会计算TD(0)更新,因为梯度计算将包括目标对 $\\theta$的依赖性。\n", 318 | "我们可以使用 `jax.lax.stop_gradient` 来强制忽略目标对 $\\theta$的依赖:" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 11, 324 | "id": "latest-biotechnology", 325 | "metadata": {}, 326 | "outputs": [ 327 | { 328 | "data": { 329 | "text/plain": [ 330 | "DeviceArray([-2.4, -4.8, 2.4], dtype=float32)" 331 | ] 332 | }, 333 | "execution_count": 11, 334 | "metadata": {}, 335 | "output_type": "execute_result" 336 | } 337 | ], 338 | "source": [ 339 | "def td_loss(theta, s_tm1, r_t, s_t):\n", 340 | " v_tm1 = value_fn(theta, s_tm1)\n", 341 | " target = r_t + value_fn(theta, s_t)\n", 342 | " return (jax.lax.stop_gradient(target) - v_tm1) ** 2\n", 343 | "\n", 344 | "td_update = jax.grad(td_loss)\n", 345 | "delta_theta = td_update(theta, s_tm1, r_t, s_t)\n", 346 | "\n", 347 | "delta_theta" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "id": "corporate-fleece", 353 | "metadata": {}, 354 | "source": [ 355 | "这会将`target`视为不依赖于参数θ,并计算对参数的正确更新。\n", 356 | "\n", 357 | "`jax.lax.stop_gradient`在其他设置中也可能有用,例如,如果您希望某些损失的梯度仅影响神经网络参数的一个子集(因为其他参数是使用不同的损失函数训练)。\n", 358 | "\n", 359 | "## 使用`stop_gradient`的直通估算器\n", 360 | "\n", 361 | 
"直通估算器是一种定义不可微分函数梯度的技巧。给定不可微分函数 $f: \\mathbb{R}_n \\rightarrow \\mathbb{R}_n$ 作为我们希望找到其梯度的更大函数的一部分,我们简单地在反向传播过程中假定 `f`是恒等函数。这里可以使用 `jax.lax.stop_gradient` 清楚地实现:" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 13, 367 | "id": "closing-tomorrow", 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "name": "stdout", 372 | "output_type": "stream", 373 | "text": [ 374 | "f(x): 3.0\n", 375 | "straight_through_f(x): 3.0\n", 376 | "grad(f)(x): 0.0\n", 377 | "grad(straight_through_f)(x): 1.0\n" 378 | ] 379 | } 380 | ], 381 | "source": [ 382 | "def f(x):\n", 383 | " return jnp.round(x) # 不可微分\n", 384 | "\n", 385 | "def straight_through_f(x):\n", 386 | " return x + jax.lax.stop_gradient(f(x) - x)\n", 387 | "\n", 388 | "print(\"f(x): \", f(3.2))\n", 389 | "print(\"straight_through_f(x): \", straight_through_f(3.2))\n", 390 | "\n", 391 | "print(\"grad(f)(x): \", jax.grad(f)(3.2))\n", 392 | "print(\"grad(straight_through_f)(x): \", jax.grad(straight_through_f)(3.2))\n" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "id": "sapphire-regression", 398 | "metadata": {}, 399 | "source": [ 400 | "## 逐样本梯度\n", 401 | "\n", 402 | "尽管大多数ML系统从一批数据中计算梯度和更新,但出于计算效率和/或方差减少的原因,有时有时必须访问批中每个特定样本相关的梯度/更新。\n", 403 | "\n", 404 | "例如,需要根据梯度大小对数据进行优先级排序,或者对每个样本进行裁剪/归一化。\n", 405 | "\n", 406 | "在许多框架(PyTorch,TF,Theano)中,计算每个示例的梯度通常并不容易,因为该库直接在批处理中累积梯度。简单的解决方法通常效率很低(例如,每个样本计算一个单独的损失,然后汇总所得的梯度)。\n", 407 | "\n", 408 | "在JAX中,我们可以定义代码,以一种简单而有效的方式计算每个样本的梯度。只需将 `jax`, `vmap` 和 `grad` 结合在一起:" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 15, 414 | "id": "dying-mouth", 415 | "metadata": {}, 416 | "outputs": [ 417 | { 418 | "data": { 419 | "text/plain": [ 420 | "DeviceArray([[-2.4, -4.8, 2.4],\n", 421 | " [-2.4, -4.8, 2.4]], dtype=float32)" 422 | ] 423 | }, 424 | "execution_count": 15, 425 | "metadata": {}, 426 | "output_type": "execute_result" 427 | } 428 | ], 429 | "source": [ 430 | "perex_grads = 
jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))\n", 431 | "\n", 432 | "# Test it:\n", 433 | "batched_s_tm1 = jnp.stack([s_tm1, s_tm1])\n", 434 | "batched_r_t = jnp.stack([r_t, r_t])\n", 435 | "batched_s_t = jnp.stack([s_t, s_t])\n", 436 | "\n", 437 | "perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "id": "modern-fiber", 443 | "metadata": {}, 444 | "source": [ 445 | "Let's walk through this one transformation at a time.\n", 446 | "\n", 447 | "First, we apply `jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss with respect to the parameters on single (unbatched) inputs:" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 17, 453 | "id": "therapeutic-sharing", 454 | "metadata": {}, 455 | "outputs": [ 456 | { 457 | "data": { 458 | "text/plain": [ 459 | "DeviceArray([-2.4, -4.8, 2.4], dtype=float32)" 460 | ] 461 | }, 462 | "execution_count": 17, 463 | "metadata": {}, 464 | "output_type": "execute_result" 465 | } 466 | ], 467 | "source": [ 468 | "dtdloss_dtheta = jax.grad(td_loss)\n", 469 | "dtdloss_dtheta(theta, s_tm1, r_t, s_t)" 470 | ] 471 | }, 472 | { 473 | "cell_type": "markdown", 474 | "id": "behind-rachel", 475 | "metadata": {}, 476 | "source": [ 477 | "This function computes one row of the array above.\n", 478 | "\n", 479 | "Then, we vectorize this function using `jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs: each output in the batch is the gradient for the corresponding member of the input batch." 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": 18, 485 | "id": "immune-certificate", 486 | "metadata": {}, 487 | "outputs": [ 488 | { 489 | "data": { 490 | "text/plain": [ 491 | "DeviceArray([[-2.4, -4.8, 2.4],\n", 492 | "             [-2.4, -4.8, 2.4]], dtype=float32)" 493 | ] 494 | }, 495 | "execution_count": 18, 496 | "metadata": {}, 497 | "output_type": "execute_result" 498 | } 499 | ], 500 | "source": [ 501 | "almost_perex_grads = jax.vmap(dtdloss_dtheta)\n", 502 | "\n", 503 | "batched_theta = jnp.stack([theta, theta])\n", 504 | "almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)" 505 | ] 506 | }, 507 | { 508 | "cell_type":
"markdown", 509 | "id": "ready-patch", 510 | "metadata": {}, 511 | "source": [ 512 | "这不是我们最想要的,因为我们必须手动想该函数提供一批 `theta`,而实际上我们只想用一个`theta`。我们通过在`jax.vmap`中" 513 | ] 514 | } 515 | ], 516 | "metadata": { 517 | "kernelspec": { 518 | "display_name": "Python 3", 519 | "language": "python", 520 | "name": "python3" 521 | }, 522 | "language_info": { 523 | "codemirror_mode": { 524 | "name": "ipython", 525 | "version": 3 526 | }, 527 | "file_extension": ".py", 528 | "mimetype": "text/x-python", 529 | "name": "python", 530 | "nbconvert_exporter": "python", 531 | "pygments_lexer": "ipython3", 532 | "version": "3.9.2" 533 | } 534 | }, 535 | "nbformat": 4, 536 | "nbformat_minor": 5 537 | } 538 | -------------------------------------------------------------------------------- /official-tutorials/GettingStarted/Tutorial:Jax101/1.4.5-JAX中的伪随机数.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "informal-establishment", 6 | "metadata": {}, 7 | "source": [ 8 | "# JAX中的伪随机数\n", 9 | "\n", 10 | "*作者: Matteo Hessel & Rosalia Schneider*\n", 11 | "\n", 12 | "这一章中我们主要关注伪随机数生成(PRNG: pseudo random number generation);换句话说,就是通过算法来生成数字序列的过程,这个序列的性质近似于从恰当分布中采样的随机数字序列的性质。\n", 13 | "\n", 14 | "PRNG生成的序列并不是真实的随机,因为它们实际上由初始值决定,也就是我们通常提到的种子 `seed`,随机采样的每个步骤都是某个状态 `state` 的确定性函数,这个状态会从一个样本转移到下一个样本。\n", 15 | "\n", 16 | "在任何机器学习和科学计算框架中,伪随机数生成都是很重要的一个组成部分。一般来说,JAX尽可能与NumPy兼容,但伪随机数生成是一个显著的例外。\n", 17 | "\n", 18 | "为了更好理解随机数生成中JAX和NumPy的区别,在这一章中两种方法都会有所涉及。\n", 19 | "\n", 20 | "## NumPy中的随机数\n", 21 | "\n", 22 | "在NumPy中,伪随机中生成由 `numpy.random` 模块原生支持。伪随机数生成基于全局状态 `state`,它可以由 `random.seed(SEED)` 来指定初始状态。" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "id": "photographic-mobility", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import numpy as np\n", 33 | "\n", 34 | "np.random.seed(0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": 
"adapted-pavilion", 40 | "metadata": {}, 41 | "source": [ 42 | "我们可以使用一下命令来检查状态的内容:" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "id": "tropical-doctor", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,\n", 56 | " 2481403966, 4042607538, 337614300, 3232553940, 1018809052,\n", 57 | " 3202401494, 1775180719, 3192392114, 594215549, 184016991,\n", 58 | " 829906058, 610491522, 3879932251, 3139825610, 297902587,\n", 59 | " 4075895579, 2943625357, 3530655617, 1423771745, 2135928312,\n", 60 | " 2891506774, 1066338622, 135451537, 933040465, 2759011858,\n", 61 | " 2273819758, 3545703099, 2516396728, 127 ...\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "def print_truncated_random_state():\n", 67 | " \"\"\"\n", 68 | " To avoid spamming the outputs, print only part of the state.\n", 69 | " 为了避免输出过多冗杂输出,只打印部分状态\n", 70 | " \"\"\"\n", 71 | " \n", 72 | " full_random_state = np.random.get_state()\n", 73 | " print(str(full_random_state)[:460], '...')\n", 74 | " \n", 75 | "print_truncated_random_state()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "id": "hydraulic-fluid", 81 | "metadata": {}, 82 | "source": [ 83 | "每次调用一个随机函数,`state`就更新一次:" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "id": "explicit-schema", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,\n", 97 | " 2481403966, 4042607538, 337614300, 3232553940, 1018809052,\n", 98 | " 3202401494, 1775180719, 3192392114, 594215549, 184016991,\n", 99 | " 829906058, 610491522, 3879932251, 3139825610, 297902587,\n", 100 | " 4075895579, 2943625357, 3530655617, 1423771745, 2135928312,\n", 101 | " 2891506774, 1066338622, 135451537, 933040465, 2759011858,\n", 102 | " 2273819758, 3545703099, 
2516396728, 127 ...\n", 103 | "('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,\n", 104 | " 3904844661, 676747479, 2085143622, 1056793272, 3812477442,\n", 105 | " 2168787041, 275552121, 2696932952, 3432054210, 1657102335,\n", 106 | " 3518946594, 962584079, 1051271004, 3806145045, 1414436097,\n", 107 | " 2032348584, 1661738718, 1116708477, 2562755208, 3176189976,\n", 108 | " 696824676, 2399811678, 3992505346, 569184356, 2626558620,\n", 109 | " 136797809, 4273176064, 296167901, 343 ...\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "np.random.seed(0)\n", 115 | "print_truncated_random_state()\n", 116 | "_ = np.random.uniform()\n", 117 | "print_truncated_random_state()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "usual-costs", 123 | "metadata": {}, 124 | "source": [ 125 | "NumPy 允许您在单个函数调用中对单个数字或整个数字向量进行采样。例如,您可以通过执行以下操作从均匀分布中采样一个向量中的三个标量:" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 4, 131 | "id": "ambient-technique", 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "[0.5488135 0.71518937 0.60276338]\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "np.random.seed(0)\n", 144 | "print(np.random.uniform(size=3))" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "distributed-terrace", 150 | "metadata": {}, 151 | "source": [ 152 | "NumPy提供了一个 *顺序等效保证*,意思是单独对N个数字进行连续采样和对N个数字的向量采样将会产生相同的伪随机序列:" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 5, 158 | "id": "analyzed-broadcasting", 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "individually: [0.5488135 0.71518937 0.60276338]\n", 166 | "all at once: [0.5488135 0.71518937 0.60276338]\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "np.random.seed(0)\n", 172 | "print(\"individually:\", np.stack([np.random.uniform() for _ in 
range(3)]))\n", 173 | "\n", 174 | "np.random.seed(0)\n", 175 | "print(\"all at once:\", np.random.uniform(size=3))" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "legislative-italic", 181 | "metadata": {}, 182 | "source": [ 183 | "## Random numbers in JAX\n", 184 | "\n", 185 | "Random number generation in JAX differs from NumPy in important ways. The reason is that NumPy's PRNG design makes it hard to simultaneously guarantee a number of desirable JAX properties, specifically that code must be:\n", 186 | "\n", 187 | "1. reproducible,\n", 188 | "2. parallelizable,\n", 189 | "3. vectorizable.\n", 190 | "\n", 191 | "We discuss why below. First, we focus on the implications of a PRNG design based on a global state. Consider the following code:" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 6, 197 | "id": "reliable-worse", 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "1.9791922366721637\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "import numpy as np\n", 210 | "\n", 211 | "np.random.seed(0)\n", 212 | "\n", 213 | "def bar():\n", 214 | "    return np.random.uniform()\n", 215 | "\n", 216 | "def baz():\n", 217 | "    return np.random.uniform()\n", 218 | "\n", 219 | "def foo():\n", 220 | "    return bar() + 2 * baz()\n", 221 | "\n", 222 | "print(foo())" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "id": "verified-salem", 228 | "metadata": {}, 229 | "source": [ 230 | "The function `foo` sums two scalars sampled from a uniform distribution. The output of this code can only satisfy requirement #1 if we assume a fixed order of execution for `bar()` and `baz()`, as native Python does. This does not seem to be a major issue in NumPy, since it is already enforced by Python, but it becomes an issue in JAX.\n", 231 | "\n", 232 | "Making this code reproducible in JAX would require enforcing this specific order of execution. That would violate requirement #2, because JAX should be able to parallelize `bar` and `baz`, as these functions do not actually depend on each other. To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key`.\n", 233 | "\n" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 7, 239 | "id": "provincial-baking", 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "[ 0 42]\n" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "from jax import random\n", 252 | "\n", 253 | "key = random.PRNGKey(42)\n", 254 | "print(key)" 255 | ] 256 | }, 257 | {
258 | "cell_type": "markdown", 259 | "id": "collect-discovery", 260 | "metadata": {}, 261 | "source": [ 262 | "A `key` is just an array of shape `(2,)`.\n", 263 | "\n", 264 | "A `Random key` is essentially just another word for a `Random seed`. However, instead of setting it once as in NumPy, any call to a random function in JAX requires a `key` to be specified. Random functions consume the `key`, but do not modify it. Feeding the same `key` to a random function will always produce the same sample:" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 8, 270 | "id": "relevant-convenience", 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "-0.18471177\n", 278 | "-0.18471177\n" 279 | ] 280 | } 281 | ], 282 | "source": [ 283 | "print(random.normal(key))\n", 284 | "print(random.normal(key))" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "id": "prescribed-gauge", 290 | "metadata": {}, 291 | "source": [ 292 | "**Note**: Feeding the same `key` to different random functions can produce correlated outputs, which is generally undesirable.\n", 293 | "\n", 294 | "**The rule of thumb is: never reuse a `key`** (unless you want identical outputs).\n", 295 | "\n", 296 | "To generate different and independent samples, you must `split()` the key *yourself* whenever you want to call a random function:" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 9, 302 | "id": "upset-generation", 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "old key [ 0 42]\n", 310 | " \\---SPLIT--> new key [2465931498 3679230171]\n", 311 | " \\--> new subkey [255383827 267815257] --> normal 1.3694694\n" 312 | ] 313 | } 314 | ], 315 | "source": [ 316 | "print(\"old key\", key)\n", 317 | "new_key, subkey = random.split(key)\n", 318 | "# The old key is discarded -- we must never use it again.\n", 319 | "# (`del` is used here only to make that explicit.)\n", 320 | "del key \n", 321 | "normal_sample = random.normal(subkey)\n", 322 | "print(r\" \\---SPLIT--> new key \", new_key)\n", 323 | "print(r\" \\--> new subkey \", subkey, \"--> normal\", normal_sample)\n", 324 | "# The subkey is also discarded after use.\n", 325 | "# (Each subkey should feed exactly one random call.)\n", 326 | "del subkey\n", 327 | "\n", 328 | "# Note: you don't actually need to `del`
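keys -- that's just for emphasis.\n", 329 | "# Not reusing the same values is enough.\n", 330 | "# (`del` just removes the Python name, so an accidental\n", 331 | "# reuse would raise a NameError.)\n", 332 | "\n", 333 | "# If we wanted to do this again, we would use new_key as the key.\n", 334 | "# (`new_key` carries the PRNG state forward.)\n", 335 | "key = new_key " 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "id": "personalized-collective", 341 | "metadata": {}, 342 | "source": [ 343 | "`split()` is a deterministic function that converts one key into several keys that are independent in the pseudorandomness sense. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input to a random function, and then discard it forever.\n", 344 | "\n", 345 | "If you wanted another sample from the same normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same `PRNGKey` twice. Since `split()` takes a key as its argument, the old key must be thrown away once it has been split.\n", 346 | "\n", 347 | "It does not matter which part of the output of `split(key)` we call `key` and which we call `subkey`. They are all pseudorandom numbers with equal status. We use the `key/subkey` convention to keep track of how they are consumed later on. A `subkey` is destined for immediate consumption by a random function, while the `key` is retained to generate more randomness later.\n", 348 | "\n", 349 | "Usually, the example above would be written more concisely as:\n" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 10, 355 | "id": "excellent-producer", 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "key, subkey = random.split(key)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "id": "exempt-congo", 365 | "metadata": {}, 366 | "source": [ 367 | "This discards the old key automatically. It is worth noting that `split()` can create as many keys as you need, not just two:" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 11, 373 | "id": "ecological-occupation", 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "key, *forty_two_subkeys = random.split(key, num=43)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "id": "minimal-teach", 383 | "metadata": {}, 384 | "source": [ 385 | "Another difference between NumPy's and JAX's random modules relates to the *sequential equivalence guarantee* mentioned above. As in NumPy, JAX's random module also allows sampling vectors of numbers. However, JAX does not provide a *sequential equivalence guarantee*, because doing so would interfere with vectorization on SIMD hardware.\n", 386 | "\n", 387 | "In the example below, sampling three values from a normal distribution individually with three subkeys gives a different result from using a single key and specifying `shape=(3,)`:" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 12, 393 | "id": "pleasant-dealer", 394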
| "metadata": {}, 395 | "outputs": [ 396 | { 397 | "name": "stdout", 398 | "output_type": "stream", 399 | "text": [ 400 | "individually: [-0.04838832 0.10796154 -1.2226542 ]\n", 401 | "all at once: [ 0.18693547 -1.2806505 -1.5593132 ]\n" 402 | ] 403 | } 404 | ], 405 | "source": [ 406 | "key = random.PRNGKey(42)\n", 407 | "subkeys = random.split(key, 3)\n", 408 | "sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n", 409 | "print(\"individually:\", sequence)\n", 410 | "\n", 411 | "key = random.PRNGKey(42)\n", 412 | "print(\"all at once:\", random.normal(key, shape=(3,)))" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "id": "pursuant-night", 418 | "metadata": {}, 419 | "source": [ 420 | "Note that, contrary to our recommendation above, the second example uses `key` directly as the input to `random.normal()`. This is because we do not reuse it anywhere else, so the single-use principle is not violated." 421 | ] 422 | } 423 | ], 424 | "metadata": { 425 | "kernelspec": { 426 | "display_name": "Python 3", 427 | "language": "python", 428 | "name": "python3" 429 | }, 430 | "language_info": { 431 | "codemirror_mode": { 432 | "name": "ipython", 433 | "version": 3 434 | }, 435 | "file_extension": ".py", 436 | "mimetype": "text/x-python", 437 | "name": "python", 438 | "nbconvert_exporter": "python", 439 | "pygments_lexer": "ipython3", 440 | "version": "3.9.2" 441 | } 442 | }, 443 | "nbformat": 4, 444 | "nbformat_minor": 5 445 | } 446 | -------------------------------------------------------------------------------- /official-tutorials/ReferenceDocumentation/2.3-异步调度.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "streaming-drama", 6 | "metadata": {}, 7 | "source": [ 8 | "# Asynchronous dispatch\n", 9 | "\n", 10 | "JAX uses asynchronous dispatch to hide Python overheads. Consider the following program:" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "supreme-vitamin", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "text/plain": [ 22 | "DeviceArray([[258.01974, 249.6486 , 
257.13367, ..., 236.67946, 250.68948,\n", 23 | " 241.36853],\n", 24 | " [265.65985, 256.28912, 262.1825 , ..., 242.03188, 256.1676 ,\n", 25 | " 252.44131],\n", 26 | " [262.38904, 255.72743, 261.2306 , ..., 240.8356 , 255.41084,\n", 27 | " 249.62466],\n", 28 | " ...,\n", 29 | " [259.15814, 253.09195, 257.72174, ..., 242.23877, 250.72672,\n", 30 | " 247.16637],\n", 31 | " [271.2267 , 261.91208, 265.33398, ..., 248.26645, 262.0539 ,\n", 32 | " 261.33704],\n", 33 | " [257.16138, 254.75424, 259.083 , ..., 241.5985 , 248.626 ,\n", 34 | " 243.22357]], dtype=float32)" 35 | ] 36 | }, 37 | "execution_count": 1, 38 | "metadata": {}, 39 | "output_type": "execute_result" 40 | } 41 | ], 42 | "source": [ 43 | "import numpy as np\n", 44 | "import jax.numpy as jnp\n", 45 | "from jax import random\n", 46 | "\n", 47 | "x = random.uniform(random.PRNGKey(0), (1000, 1000))\n", 48 | "jnp.dot(x, x) + 3" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "supreme-still", 54 | "metadata": {}, 55 | "source": [ 56 | "When an operation such as `jnp.dot(x, x)` is executed, JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a `DeviceArray` value, which is a future: a value that will be produced on an accelerator device at some later point and is not necessarily available immediately. We can inspect the shape or type of a `DeviceArray` without waiting for the computation to finish, and we can even pass it to another JAX computation, as the addition in the example does. Only when we actually inspect the value of the array from the host (for example by printing it or by converting it to a plain `numpy.ndarray`) does JAX force the Python code to wait for the computation to complete.\n", 57 | "\n", 58 | "Asynchronous dispatch is useful because it allows Python code to run ahead of the accelerator device, which keeps Python code off the critical path. Provided that the Python code enqueues work on the device faster than the device can execute it, and provided that the Python code does not actually need to inspect the output of a computation on the host, a Python program can enqueue arbitrary amounts of work and avoid making the accelerator wait.\n", 59 | "\n", 60 | "Asynchronous dispatch has a somewhat surprising consequence for microbenchmarks." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 2, 66 | "id": "verified-terrorism", 67 | "metadata": { 68 | "scrolled": true 69 | }, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "CPU times: user 195 µs, sys: 231 µs, total: 426 µs\n", 76 | "Wall time: 259 µs\n" 77 | ] 78 | }, 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "DeviceArray([[255.01974, 246.6486 , 254.13365, ..., 233.67946, 247.68948,\n", 83 | " 238.36853],\n", 84 | 
" [262.65985, 253.28912, 259.1825 , ..., 239.03188, 253.1676 ,\n", 85 | " 249.44131],\n", 86 | " [259.38904, 252.72743, 258.2306 , ..., 237.8356 , 252.41084,\n", 87 | " 246.62466],\n", 88 | " ...,\n", 89 | " [256.15814, 250.09195, 254.72174, ..., 239.23877, 247.72672,\n", 90 | " 244.16637],\n", 91 | " [268.2267 , 258.91208, 262.33398, ..., 245.26645, 259.0539 ,\n", 92 | " 258.33704],\n", 93 | " [254.16138, 251.75424, 256.083 , ..., 238.5985 , 245.626 ,\n", 94 | " 240.22357]], dtype=float32)" 95 | ] 96 | }, 97 | "execution_count": 2, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "%time jnp.dot(x, x)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "id": "surgical-clothing", 109 | "metadata": {}, 110 | "source": [ 111 | "A `1000 x 1000` matrix multiplication on CPU appears to take only 259 µs (the wall time reported above)! It turns out that asynchronous dispatch is misleading us: we are not timing the execution of the matrix multiplication, only the time needed to dispatch the work. To measure the true cost of the operation, we must either read the value on the host (for example by converting it to a plain NumPy array) or call the `block_until_ready()` method on the `DeviceArray` value to wait for the computation to complete." 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 3, 117 | "id": "metallic-running", 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "CPU times: user 2.63 ms, sys: 0 ns, total: 2.63 ms\n", 125 | "Wall time: 2.37 ms\n" 126 | ] 127 | }, 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "array([[255.01974, 246.6486 , 254.13365, ..., 233.67946, 247.68948,\n", 132 | " 238.36853],\n", 133 | " [262.65985, 253.28912, 259.1825 , ..., 239.03188, 253.1676 ,\n", 134 | " 249.44131],\n", 135 | " [259.38904, 252.72743, 258.2306 , ..., 237.8356 , 252.41084,\n", 136 | " 246.62466],\n", 137 | " ...,\n", 138 | " [256.15814, 250.09195, 254.72174, ..., 239.23877, 247.72672,\n", 139 | " 244.16637],\n", 140 | " [268.2267 , 258.91208, 262.33398, ..., 245.26645, 259.0539 ,\n", 141 | " 258.33704],\n", 142 | " [254.16138, 251.75424, 256.083 , ..., 238.5985 , 245.626 ,\n", 143 | " 240.22357]], dtype=float32)" 144 | ] 145 | }, 146 | 
"execution_count": 3, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "%time np.asarray(jnp.dot(x, x))" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 4, 158 | "id": "prerequisite-primary", 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "CPU times: user 1.14 ms, sys: 0 ns, total: 1.14 ms\n", 166 | "Wall time: 824 µs\n" 167 | ] 168 | }, 169 | { 170 | "data": { 171 | "text/plain": [ 172 | "DeviceArray([[255.01974, 246.6486 , 254.13365, ..., 233.67946, 247.68948,\n", 173 | " 238.36853],\n", 174 | " [262.65985, 253.28912, 259.1825 , ..., 239.03188, 253.1676 ,\n", 175 | " 249.44131],\n", 176 | " [259.38904, 252.72743, 258.2306 , ..., 237.8356 , 252.41084,\n", 177 | " 246.62466],\n", 178 | " ...,\n", 179 | " [256.15814, 250.09195, 254.72174, ..., 239.23877, 247.72672,\n", 180 | " 244.16637],\n", 181 | " [268.2267 , 258.91208, 262.33398, ..., 245.26645, 259.0539 ,\n", 182 | " 258.33704],\n", 183 | " [254.16138, 251.75424, 256.083 , ..., 238.5985 , 245.626 ,\n", 184 | " 240.22357]], dtype=float32)" 185 | ] 186 | }, 187 | "execution_count": 4, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "%time jnp.dot(x, x).block_until_ready() " 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "id": "steady-broadcasting", 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python 3", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.9.2" 222 | } 223 | 
}, 224 | "nbformat": 4, 225 | "nbformat_minor": 5 226 | } 227 | --------------------------------------------------------------------------------